jaxpint.utils
Utility functions for JaxPINT.
Pure JAX ports of selected functions from pint.utils. Unless noted otherwise (build_quantization_matrix, for example, uses NumPy and is not JIT-compatible), functions are JIT-compatible and operate on raw float64 arrays (no units).
- jaxpint.utils.taylor_horner(x, coeffs)
Evaluate a Taylor series at x via the Horner scheme.
The Taylor series is:
coeffs[0] + coeffs[1]*x/1! + coeffs[2]*x^2/2! + ...
Example:
taylor_horner(2.0, jnp.array([10., 3., 4., 12.])) # -> 40.0
- Parameters:
x (array) – Evaluation point(s).
coeffs (1-D array, shape (n_coeffs,)) – Taylor coefficients; coeffs[i] multiplies x**i / i!.
- Return type:
array, same shape as x
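Horner's rule rewrites the series as a nested product, so each coefficient costs one multiply-add and no power or factorial is ever formed. The following is a minimal JAX sketch of that recurrence (the name `taylor_horner_sketch` is hypothetical, not the jaxpint source); it reproduces the docstring example.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # float64 throughout, matching the module's convention


def taylor_horner_sketch(x, coeffs):
    """Evaluate sum_i coeffs[i] * x**i / i! by Horner's rule.

    The series is nested as c0 + x/1 * (c1 + x/2 * (c2 + x/3 * (c3 + ...))),
    so the loop runs from the highest-order coefficient down to coeffs[1].
    """
    x = jnp.asarray(x, dtype=jnp.float64)
    result = jnp.zeros_like(x)
    for i in range(coeffs.shape[0] - 1, 0, -1):  # static length: unrolled under jit
        result = (result + coeffs[i]) * x / i
    return result + coeffs[0]


coeffs = jnp.array([10.0, 3.0, 4.0, 12.0])
print(taylor_horner_sketch(2.0, coeffs))                                  # 40.0
print(jax.jit(taylor_horner_sketch)(jnp.array([0.0, 1.0, 2.0]), coeffs))  # [10. 17. 40.]
```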
- jaxpint.utils.taylor_horner_deriv(x, coeffs, deriv_order=1)
Evaluate the deriv_order-th derivative of a Taylor series.
Uses the Horner scheme with jax.lax.fori_loop for JIT efficiency (see https://en.wikipedia.org/wiki/Horner%27s_method).
Example:
taylor_horner_deriv(2.0, jnp.array([10., 3., 4., 12.]), 1) # -> 35.0
- Parameters:
x (array) – Evaluation point(s).
coeffs (1-D array, shape (n_coeffs,)) – Taylor coefficients.
deriv_order (int) – Derivative order (non-negative).
- Return type:
array, same shape as x
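Differentiating the series n times only shifts the coefficients, since \(\frac{d^n}{dx^n}\sum_k c_k x^k/k! = \sum_j c_{j+n} x^j/j!\). A hedged sketch of that idea runs the Horner recurrence over coeffs[deriv_order:] inside jax.lax.fori_loop; `taylor_horner_deriv_sketch` is a hypothetical name and the jaxpint implementation may organise the loop differently. In this sketch deriv_order must be static (a Python int) because it determines an array slice.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def taylor_horner_deriv_sketch(x, coeffs, deriv_order=1):
    """deriv_order-th derivative of sum_k coeffs[k] * x**k / k!.

    The derivative equals sum_j coeffs[j + n] * x**j / j! with n = deriv_order,
    so Horner's rule is applied to the shifted coefficients.
    """
    x = jnp.asarray(x, dtype=jnp.float64)
    shifted = coeffs[deriv_order:]
    m = shifted.shape[0]
    if m == 0:  # derivative order exceeds the polynomial degree
        return jnp.zeros_like(x)

    def body(step, acc):
        i = m - 1 - step  # visits indices m-1, m-2, ..., 1
        return (acc + shifted[i]) * x / i

    acc = jax.lax.fori_loop(0, m - 1, body, jnp.zeros_like(x))
    return acc + shifted[0]


coeffs = jnp.array([10.0, 3.0, 4.0, 12.0])
print(taylor_horner_deriv_sketch(2.0, coeffs, 1))  # 35.0
deriv_jit = jax.jit(taylor_horner_deriv_sketch, static_argnums=2)
print(deriv_jit(jnp.array([0.0, 2.0]), coeffs, 1))  # [ 3. 35.]
```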
- jaxpint.utils.taylor_horner_phase(dt_int_days, dt_frac_days, delay, coeffs)
Evaluate a Taylor series with phase precision via int/frac Horner.
Uses the day decomposition dt = dt_int_days * 86400 + dt_frac_s (dt_frac_s being the fractional part expressed in seconds) to split each Horner multiplication into integer (exact) and fractional (precise) parts. This avoids the precision loss that occurs when a large absolute phase (~10^10 cycles) is computed as a single float64.
- Parameters:
dt_int_days ((n_toas,)) – Integer MJD day difference from epoch (exact).
dt_frac_days ((n_toas,)) – Fractional MJD day difference from epoch.
delay ((n_toas,)) – Accumulated signal delay in seconds.
coeffs ((n_coeffs,)) – Taylor coefficients; coeffs[k] multiplies dt**k / k!.
- Returns:
Phase in cycles, split as integer + fractional part.
- Return type:
DualFloat
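The precision problem is easy to see directly: near 10^10 the gap between adjacent float64 values is 2^-19 ≈ 1.9e-6, so a phase stored as one float64 cannot resolve its fractional part more finely than that. The sketch below only illustrates the rounding; it does not reproduce jaxpint's DualFloat type, and the phase values are made up.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Made-up phase: ~1e10 whole cycles plus a fractional part.
int_cycles = jnp.float64(1e10)
frac_cycles = jnp.float64(0.123456789)

# Gap between 1e10 and the next representable float64 (one ulp).
ulp = jnp.nextafter(int_cycles, jnp.inf) - int_cycles

# Storing the phase as a single float64 rounds it to a multiple of that gap.
single = int_cycles + frac_cycles
recovered_frac = single - int_cycles  # exact subtraction; the loss happened above

# Keeping (int_cycles, frac_cycles) as a pair would avoid the rounding entirely.
print(float(ulp))                                # 1.9073486328125e-06 (= 2**-19)
print(float(abs(recovered_frac - frac_cycles)))  # nonzero, at most ulp / 2
```

For a ~300 Hz pulsar, 2^-19 cycles corresponds to roughly 6 ns, which is why the integer and fractional parts are carried separately.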
- jaxpint.utils.weighted_mean(arrin, weights_in, inputmean=None, calcerr=False)
Compute weighted mean and error of arrin.
- Return type:
(wmean, werr)
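For orientation, here is a minimal sketch of the default weighted-mean computation, assuming the weights are inverse variances (1/σ²) so that the standard error of the mean is \(1/\sqrt{\sum w}\). `weighted_mean_sketch` is hypothetical and deliberately omits the inputmean and calcerr options, whose exact behaviour is not documented here.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def weighted_mean_sketch(arr, weights):
    """Weighted mean and its error, assuming weights = 1 / sigma**2.

    Covers only the default path: no inputmean override, no calcerr option.
    """
    wtot = jnp.sum(weights)
    wmean = jnp.sum(weights * arr) / wtot
    werr = 1.0 / jnp.sqrt(wtot)  # standard error for inverse-variance weights
    return wmean, werr


wmean, werr = weighted_mean_sketch(jnp.array([1.0, 2.0, 3.0]), jnp.array([1.0, 1.0, 2.0]))
print(wmean, werr)  # 2.25 0.5
```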
- jaxpint.utils.weighted_mean_sdev(arrin, weights_in, inputmean=None, calcerr=False)
Compute weighted mean, error, and standard deviation of arrin.
- Return type:
(wmean, werr, wsdev)
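A companion sketch for the standard deviation uses the common weighted definition \(\sqrt{\sum w (x - m)^2 / \sum w}\). Whether weighted_mean_sdev uses exactly this normalisation is an assumption, and the name `weighted_sdev_sketch` is hypothetical.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def weighted_sdev_sketch(arr, weights):
    """Weighted mean m and weighted standard deviation sqrt(sum w (x - m)^2 / sum w)."""
    wtot = jnp.sum(weights)
    wmean = jnp.sum(weights * arr) / wtot
    wsdev = jnp.sqrt(jnp.sum(weights * (arr - wmean) ** 2) / wtot)  # assumed normalisation
    return wmean, wsdev


m, s = weighted_sdev_sketch(jnp.array([1.0, 2.0, 3.0]), jnp.array([1.0, 1.0, 2.0]))
print(m, s)  # mean 2.25; sdev sqrt(11)/4, about 0.829
```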
- jaxpint.utils.normalize_designmatrix(M)
Column-normalize the design matrix for numerical stability.
The normalized matrix Mn and the original M are related by M = Mn * norms (broadcasting over rows). GLS expressions of the form M @ inv(M.T @ Ninv @ M) @ M.T are invariant under this rescaling. Columns with zero norm (degenerate parameters) are left as-is.
- Parameters:
M (2-D array, shape (n_toas, n_params))
- Returns:
degenerate is a boolean mask that is True for columns with zero norm (i.e. parameters that have no effect on the residuals).
- Return type:
(M_normalized, norms, degenerate)
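The invariance claim follows from M = Mn D with D = diag(norms): the D factors cancel inside the inverse. The sketch below normalizes columns and checks that cancellation numerically. `normalize_designmatrix_sketch` and `gls_projector` are hypothetical helpers. Returning 1 as the norm of a zero column, so that M = Mn * norms holds for every column, is an assumption about the degenerate case.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def normalize_designmatrix_sketch(M):
    """Divide each column by its Euclidean norm; zero-norm columns pass through."""
    norms = jnp.sqrt(jnp.sum(M**2, axis=0))
    degenerate = norms == 0.0
    safe_norms = jnp.where(degenerate, 1.0, norms)  # assumption: report 1 for degenerate columns
    return M / safe_norms, safe_norms, degenerate


def gls_projector(M, Ndiag):
    """M @ inv(M.T @ Ninv @ M) @ M.T for a diagonal noise covariance N."""
    A = M.T @ (M / Ndiag[:, None])
    return M @ jnp.linalg.solve(A, M.T)


# Hypothetical design matrix whose columns differ in scale by ~3 orders of magnitude.
M = jnp.array([[1.0, 1000.0], [2.0, -500.0], [3.0, 250.0], [4.0, 10.0]])
Ndiag = jnp.array([1.0, 2.0, 3.0, 4.0])
Mn, norms, degenerate = normalize_designmatrix_sketch(M)
print(jnp.allclose(gls_projector(M, Ndiag), gls_projector(Mn, Ndiag)))  # True
```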
- jaxpint.utils.sherman_morrison_dot(Ndiag, v, w, x, y)
Compute \(x^T C^{-1} y\) where \(C = \mathrm{diag}(N) + w\,v\,v^T\).
Uses the Sherman–Morrison identity to avoid forming or inverting C.
- Parameters:
Ndiag (1-D array) – Diagonal of N (positive).
v (1-D array) – Rank-1 update vector.
w (scalar) – Weight of the rank-1 update.
x, y (1-D arrays) – Vectors for the inner product.
- Returns:
The inner product and the log-determinant of C.
- Return type:
(result, logdet_C)
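Sherman–Morrison gives \(C^{-1} = N^{-1} - w\,N^{-1} v v^T N^{-1} / (1 + w\,v^T N^{-1} v)\), and the matrix determinant lemma gives \(\det C = \det N \,(1 + w\,v^T N^{-1} v)\). Here is a minimal O(n) sketch of both, checked against a dense solve; `sherman_morrison_dot_sketch` is a hypothetical name and the test vectors are arbitrary.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def sherman_morrison_dot_sketch(Ndiag, v, w, x, y):
    """x^T C^{-1} y and log det C for C = diag(Ndiag) + w v v^T, without forming C."""
    Ninv_v = v / Ndiag
    denom = 1.0 + w * jnp.dot(v, Ninv_v)  # 1 + w v^T N^{-1} v
    result = jnp.dot(x, y / Ndiag) - w * jnp.dot(x, Ninv_v) * jnp.dot(Ninv_v, y) / denom
    logdet_C = jnp.sum(jnp.log(Ndiag)) + jnp.log(denom)  # matrix determinant lemma
    return result, logdet_C


Ndiag = jnp.array([1.0, 2.0, 3.0, 4.0])
v = jnp.array([1.0, -1.0, 0.5, 2.0])
w = 0.7
x = jnp.array([1.0, 2.0, 3.0, 4.0])
y = jnp.array([0.5, -1.0, 2.0, 0.0])
res, logdet = sherman_morrison_dot_sketch(Ndiag, v, w, x, y)

C = jnp.diag(Ndiag) + w * jnp.outer(v, v)  # dense reference, for checking only
print(jnp.allclose(res, x @ jnp.linalg.solve(C, y)))   # True
print(jnp.allclose(logdet, jnp.linalg.slogdet(C)[1]))  # True
```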
- jaxpint.utils.woodbury_dot(Ndiag, U, Phidiag, x, y)
Compute \(x^T C^{-1} y\) where \(C = \mathrm{diag}(N) + U\,\mathrm{diag}(\Phi)\,U^T\).
Uses the Woodbury identity and Cholesky factorisation of the reduced-rank matrix \(\Sigma = \Phi^{-1} + U^T N^{-1} U\).
- Parameters:
Ndiag (1-D array, shape (n,)) – Diagonal of N (positive).
U (2-D array, shape (n, k)) – Low-rank update basis.
Phidiag (1-D array, shape (k,)) – Diagonal of \(\Phi\) (positive).
x, y (1-D arrays, shape (n,)) – Vectors for the inner product.
- Returns:
The inner product and the log-determinant of C.
- Return type:
(result, logdet_C)
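With \(\Sigma = \Phi^{-1} + U^T N^{-1} U\), the Woodbury identity gives \(x^T C^{-1} y = x^T N^{-1} y - (U^T N^{-1} x)^T \Sigma^{-1} (U^T N^{-1} y)\), and \(\det C = \det N \,\det \Phi\, \det \Sigma\). The following minimal sketch uses one Cholesky factor of Σ for both the solve and the log-determinant. `woodbury_dot_sketch` is a hypothetical name and the dense comparison is only there for checking.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve

jax.config.update("jax_enable_x64", True)


def woodbury_dot_sketch(Ndiag, U, Phidiag, x, y):
    """x^T C^{-1} y and log det C for C = diag(Ndiag) + U diag(Phidiag) U^T."""
    NinvU = U / Ndiag[:, None]                      # N^{-1} U, shape (n, k)
    Sigma = jnp.diag(1.0 / Phidiag) + U.T @ NinvU  # Phi^{-1} + U^T N^{-1} U, shape (k, k)
    L = jnp.linalg.cholesky(Sigma)                 # lower-triangular factor
    Ux, Uy = NinvU.T @ x, NinvU.T @ y              # U^T N^{-1} x and U^T N^{-1} y
    result = jnp.dot(x, y / Ndiag) - jnp.dot(Ux, cho_solve((L, True), Uy))
    logdet_C = (jnp.sum(jnp.log(Ndiag)) + jnp.sum(jnp.log(Phidiag))
                + 2.0 * jnp.sum(jnp.log(jnp.diag(L))))  # log det Sigma from its Cholesky factor
    return result, logdet_C


Ndiag = jnp.array([1.0, 2.0, 1.5, 0.5, 3.0])
U = jnp.array([[1.0, 0.0], [0.5, 1.0], [-1.0, 2.0], [0.3, -0.7], [2.0, 1.0]])
Phidiag = jnp.array([4.0, 0.25])
x = jnp.array([1.0, -2.0, 0.5, 3.0, 1.0])
y = jnp.array([0.0, 1.0, 2.0, -1.0, 0.5])
res, logdet = woodbury_dot_sketch(Ndiag, U, Phidiag, x, y)

C = jnp.diag(Ndiag) + U @ jnp.diag(Phidiag) @ U.T  # dense reference, O(n^3)
print(jnp.allclose(res, x @ jnp.linalg.solve(C, y)))   # True
print(jnp.allclose(logdet, jnp.linalg.slogdet(C)[1]))  # True
```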
- jaxpint.utils.woodbury_solve(Ndiag, U, Phidiag, B)
Compute \(C^{-1} B\) where \(C = \mathrm{diag}(N) + U\,\mathrm{diag}(\Phi)\,U^T\).
Uses the Woodbury identity:
\(C^{-1} = N^{-1} - N^{-1} U \Sigma^{-1} U^T N^{-1}\)
where \(\Sigma = \Phi^{-1} + U^T N^{-1} U\).
- Parameters:
Ndiag (1-D array, shape (n,)) – Diagonal of N (positive).
U (2-D array, shape (n, k)) – Low-rank update basis.
Phidiag (1-D array, shape (k,)) – Diagonal of \(\Phi\) (positive).
B (2-D array, shape (n, m)) – Right-hand side matrix.
- Returns:
Cinv_B – The product \(C^{-1} B\).
- Return type:
array, shape (n, m)
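The same identity handles a block of right-hand sides. Below is a sketch of \(C^{-1}B\) in which only the k×k matrix Σ is factorised, so for fixed k the cost grows linearly with n. `woodbury_solve_sketch` is hypothetical and the data are arbitrary.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve

jax.config.update("jax_enable_x64", True)


def woodbury_solve_sketch(Ndiag, U, Phidiag, B):
    """C^{-1} B for C = diag(Ndiag) + U diag(Phidiag) U^T, via the Woodbury identity."""
    NinvB = B / Ndiag[:, None]                      # N^{-1} B
    NinvU = U / Ndiag[:, None]                      # N^{-1} U
    Sigma = jnp.diag(1.0 / Phidiag) + U.T @ NinvU  # Phi^{-1} + U^T N^{-1} U
    L = jnp.linalg.cholesky(Sigma)
    # C^{-1} B = N^{-1} B - N^{-1} U Sigma^{-1} U^T N^{-1} B
    return NinvB - NinvU @ cho_solve((L, True), U.T @ NinvB)


Ndiag = jnp.array([1.0, 2.0, 1.5, 0.5, 3.0])
U = jnp.array([[1.0, 0.0], [0.5, 1.0], [-1.0, 2.0], [0.3, -0.7], [2.0, 1.0]])
Phidiag = jnp.array([4.0, 0.25])
B = jnp.array([[1.0, 0.0], [-2.0, 1.0], [0.5, 2.0], [3.0, -1.0], [1.0, 0.5]])
X = woodbury_solve_sketch(Ndiag, U, Phidiag, B)

C = jnp.diag(Ndiag) + U @ jnp.diag(Phidiag) @ U.T  # dense reference
print(jnp.allclose(X, jnp.linalg.solve(C, B)))  # True
```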
- jaxpint.utils.ecl_to_icrs_rotation(obliquity_arcsec)
Rotation matrix from ecliptic to ICRS (row-vector convention).
Usage:
L_icrs = L_ecl @ ecl_to_icrs_rotation(obl)
This is the transpose of astropy's rotation_matrix(obl, 'x'), adapted for row-vector multiplication.
- Parameters:
obliquity_arcsec (float)
- Return type:
Float[Array, '3 3']
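Under the row-vector convention, the matrix must send the ecliptic north pole (0, 0, 1) to ICRS (0, -sin ε, cos ε), i.e. RA 18h, Dec 90° - ε. The minimal sketch below builds a rotation about the shared x-axis (the equinox direction) that satisfies this property and checks it. `ecl_to_icrs_rotation_sketch` is hypothetical, and 84381.406 arcsec (the IAU 2006 J2000 mean obliquity) serves only as a sample input.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

ARCSEC_TO_RAD = jnp.pi / (180.0 * 3600.0)


def ecl_to_icrs_rotation_sketch(obliquity_arcsec):
    """Row-vector rotation about x: L_icrs = L_ecl @ R."""
    eps = obliquity_arcsec * ARCSEC_TO_RAD
    c, s = jnp.cos(eps), jnp.sin(eps)
    return jnp.array([[1.0, 0.0, 0.0],
                      [0.0, c, s],
                      [0.0, -s, c]])


R = ecl_to_icrs_rotation_sketch(84381.406)
ecl_pole = jnp.array([0.0, 0.0, 1.0])
print(ecl_pole @ R)  # approximately [0, -0.398, 0.917]
```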
- jaxpint.utils.compute_pulsar_direction(toa_data, params, raj_name, decj_name, pmra_name, pmdec_name, posepoch_name)
Unit vector from SSB to pulsar in ICRS Cartesian coordinates.
Without proper motion the direction is constant; with proper motion a linear correction is applied per TOA.
- Parameters:
toa_data (TOAData) – Pre-extracted TOA data (needs tdb_int, tdb_frac, n_toas).
params (ParameterVector) – Timing-model parameters.
raj_name, decj_name (str) – Parameter names for RA and DEC (radians).
pmra_name, pmdec_name (str or None) – Parameter names for proper motion (mas/yr). None disables PM.
posepoch_name (str or None) – Epoch parameter for proper-motion reference.
- Return type:
Float[Array, 'n_toas 3']
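A sketch of the lon/lat → xyz conversion plus a linear proper-motion correction, under stated assumptions: the position is displaced along the local east and north unit vectors by μ·Δt and then renormalised, pmra is taken to already include the cos δ factor, and a year is 365.25 days. jaxpint reads these quantities from a ParameterVector by name; the sketch takes them as plain numbers, and `pulsar_direction_sketch` and `lonlat_to_xyz` are hypothetical.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

MAS_PER_YR_TO_RAD_PER_DAY = (jnp.pi / (180.0 * 3600.0 * 1000.0)) / 365.25  # assumed Julian year


def lonlat_to_xyz(lon, lat):
    """Spherical angles in radians -> Cartesian unit vector (last axis is xyz)."""
    return jnp.stack([jnp.cos(lat) * jnp.cos(lon),
                      jnp.cos(lat) * jnp.sin(lon),
                      jnp.sin(lat)], axis=-1)


def pulsar_direction_sketch(ra, dec, dt_days, pmra=0.0, pmdec=0.0):
    """(n_toas, 3) unit vectors with a linear proper-motion term (assumed model)."""
    n0 = lonlat_to_xyz(ra, dec)                                # direction at the position epoch
    east = jnp.array([-jnp.sin(ra), jnp.cos(ra), 0.0])         # unit vector of increasing RA
    north = jnp.array([-jnp.sin(dec) * jnp.cos(ra),
                       -jnp.sin(dec) * jnp.sin(ra),
                       jnp.cos(dec)])                          # unit vector of increasing Dec
    dt = jnp.asarray(dt_days)[:, None]                         # days since the position epoch
    n = n0 + (pmra * east + pmdec * north) * MAS_PER_YR_TO_RAD_PER_DAY * dt
    return n / jnp.linalg.norm(n, axis=-1, keepdims=True)


dirs = pulsar_direction_sketch(0.0, 0.0, jnp.array([0.0, 365.25]), pmdec=1000.0)
print(dirs)  # first row is [1, 0, 0]; the second has moved ~1 arcsec north
```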
- jaxpint.utils.compute_pulsar_direction_ecl(toa_data, params, elong_name, elat_name, pmelong_name, pmelat_name, posepoch_name, obliquity_arcsec)
Unit vector from SSB to pulsar in ICRS, computed from ecliptic coordinates.
Computes the direction in the ecliptic frame (reusing the same lon/lat → xyz math as compute_pulsar_direction), then rotates it to ICRS.
- Parameters:
toa_data (TOAData)
params (ParameterVector)
elong_name, elat_name (str) – Parameter names for ecliptic longitude and latitude (radians).
pmelong_name, pmelat_name (str or None) – Proper-motion parameter names (mas/yr). None disables PM.
posepoch_name (str or None) – Epoch parameter for proper-motion reference.
obliquity_arcsec (float) – Obliquity of the ecliptic in arcseconds.
- Return type:
Float[Array, 'n_toas 3']
- jaxpint.utils.fourier_sum(dt_days, wx_freqs, wx_sins, wx_coses)
Evaluate a Fourier sum at each TOA.
Computes:
result[t] = Σ_i (wx_sins[i] * sin(2π * wx_freqs[i] * dt_days[t]) + wx_coses[i] * cos(2π * wx_freqs[i] * dt_days[t]))
- Parameters:
dt_days ((n_toas,)) – Time differences from the reference epoch in days.
wx_freqs ((n_components,)) – Fourier frequencies in 1/day.
wx_sins ((n_components,)) – Sine amplitudes.
wx_coses ((n_components,)) – Cosine amplitudes.
- Returns:
Fourier sum evaluated at each TOA.
- Return type:
(n_toas,)
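The formula maps directly onto broadcasting: form the (n_toas, n_components) argument matrix once, then contract the component axis with the amplitude vectors. A minimal sketch follows (`fourier_sum_sketch` is hypothetical, and the single test component is made up).

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def fourier_sum_sketch(dt_days, wx_freqs, wx_sins, wx_coses):
    """sum_i [s_i sin(2 pi f_i t) + c_i cos(2 pi f_i t)] for every TOA t."""
    arg = 2.0 * jnp.pi * dt_days[:, None] * wx_freqs[None, :]  # (n_toas, n_components)
    # Matrix-vector products collapse the component axis.
    return jnp.sin(arg) @ wx_sins + jnp.cos(arg) @ wx_coses


vals = jax.jit(fourier_sum_sketch)(
    jnp.array([0.0, 1.0]),  # dt_days
    jnp.array([0.25]),      # one component with a 4-day period
    jnp.array([2.0]),       # sine amplitude
    jnp.array([1.0]),       # cosine amplitude
)
print(vals)  # approximately [1. 2.]
```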
- jaxpint.utils.build_quantization_matrix(tdb_times_s, ecorr_masks, dt=1.0, nmin=2)
Build the ECORR quantization matrix (NumPy, not JIT-compatible).
Groups TOAs within dt seconds into epochs and creates a binary matrix U mapping TOAs to epochs. Only epochs with at least nmin TOAs are kept.
- Parameters:
tdb_times_s ((n_toas,) float64) – TOA times in TDB seconds.
ecorr_masks (dict[str, ndarray]) – Boolean masks keyed by ECORR parameter name.
dt (float) – Epoch grouping threshold in seconds.
nmin (int) – Minimum number of TOAs per epoch.
- Returns:
U ((n_toas, n_total_epochs)) – Binary quantization matrix.
epoch_slices (dict[str, (int, int)]) – Column-index range for each ECORR parameter.
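The NumPy sketch below implements the grouping rule under assumptions that the docstring does not pin down: an epoch starts at its earliest TOA and absorbs later TOAs less than dt seconds after it, each ECORR mask is grouped independently, and epoch_slices holds half-open (start, stop) column ranges. `build_quantization_matrix_sketch`, `group_epochs`, and the mask names are hypothetical.

```python
import numpy as np


def group_epochs(times, dt, nmin):
    """Lists of indices into `times`; an epoch collects TOAs < dt after its first TOA."""
    epochs = []
    for idx in np.argsort(times):
        if epochs and times[idx] - times[epochs[-1][0]] < dt:
            epochs[-1].append(idx)
        else:
            epochs.append([idx])  # start a new epoch at this TOA
    return [ep for ep in epochs if len(ep) >= nmin]


def build_quantization_matrix_sketch(tdb_times_s, ecorr_masks, dt=1.0, nmin=2):
    tdb_times_s = np.asarray(tdb_times_s, dtype=np.float64)
    columns, epoch_slices = [], {}
    for name, mask in ecorr_masks.items():
        idx = np.flatnonzero(mask)  # global indices of TOAs in this ECORR group
        start = len(columns)
        for ep in group_epochs(tdb_times_s[idx], dt, nmin):
            col = np.zeros(tdb_times_s.size)
            col[idx[ep]] = 1.0  # map local epoch indices back to global TOA rows
            columns.append(col)
        epoch_slices[name] = (start, len(columns))
    U = np.column_stack(columns) if columns else np.zeros((tdb_times_s.size, 0))
    return U, epoch_slices


times = [0.0, 0.5, 10.0, 10.2, 20.0, 100.0, 100.3]
masks = {"ECORR_A": np.array([1, 1, 1, 1, 1, 0, 0], dtype=bool),
         "ECORR_B": np.array([0, 0, 0, 0, 0, 1, 1], dtype=bool)}
U, slices = build_quantization_matrix_sketch(times, masks, dt=1.0, nmin=2)
print(U.shape, slices)  # (7, 3) {'ECORR_A': (0, 2), 'ECORR_B': (2, 3)}
```

The lone TOA at 20.0 s forms a one-TOA epoch and is dropped by nmin=2.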
- jaxpint.utils.build_fourier_basis(tdb_times_s, n_freqs, T)
Build an alternating sin/cos Fourier design matrix.
- Returns:
F ((n_toas, 2 * n_freqs)) – Fourier design matrix with columns [sin(2πf₁t), cos(2πf₁t), sin(2πf₂t), ...].
freqs ((n_freqs,)) – Frequency array in Hz.
freq_bin_widths ((n_freqs,)) – Δf for each frequency bin.
- Return type:
tuple[ndarray, ndarray, ndarray]
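The NumPy sketch below shows the alternating column layout. It assumes harmonics \(f_j = j/T\) for \(j = 1, \dots, n_\mathrm{freqs}\) with T the time span in seconds, and equal bin widths Δf = 1/T. jaxpint's frequency grid and time origin may differ, and `build_fourier_basis_sketch` is hypothetical.

```python
import numpy as np


def build_fourier_basis_sketch(tdb_times_s, n_freqs, T):
    """Alternating sin/cos design matrix at assumed harmonics f_j = j / T (T in seconds)."""
    t = np.asarray(tdb_times_s, dtype=np.float64)
    freqs = np.arange(1, n_freqs + 1) / T                      # Hz
    arg = 2.0 * np.pi * t[:, None] * freqs[None, :]            # (n_toas, n_freqs)
    F = np.empty((t.size, 2 * n_freqs))
    F[:, 0::2] = np.sin(arg)                                   # columns 0, 2, 4, ...
    F[:, 1::2] = np.cos(arg)                                   # columns 1, 3, 5, ...
    freq_bin_widths = np.diff(np.concatenate([[0.0], freqs]))  # every bin is 1/T wide here
    return F, freqs, freq_bin_widths


F, freqs, widths = build_fourier_basis_sketch([0.0, 25.0, 50.0], n_freqs=2, T=100.0)
print(F.shape, freqs, widths)  # (3, 4) [0.01 0.02] [0.01 0.01]
```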