jaxpint.utils

Utility functions for JaxPINT.

Pure JAX ports of selected functions from pint.utils. Unless noted otherwise (build_quantization_matrix, for example, is NumPy-only), the functions are JIT-compatible and operate on raw float64 arrays (no units).

jaxpint.utils.taylor_horner(x, coeffs)

Evaluate a Taylor series at x via the Horner scheme.

The Taylor series is:

coeffs[0] + coeffs[1]*x/1! + coeffs[2]*x^2/2! + ...

Example:

taylor_horner(2.0, jnp.array([10., 3., 4., 12.]))  # -> 40.0

Parameters:
  • x (array) – Evaluation point(s).

  • coeffs (1-D array, shape (n_coeffs,)) – Taylor coefficients. coeffs[i] multiplies x**i / i!.

Return type:

array, same shape as x
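
A minimal sketch of the Horner evaluation, assuming the coefficient convention above. The name taylor_horner_sketch is hypothetical and this is not the jaxpint source; it walks the coefficients from highest order down, so each step costs one multiply, one divide and one add.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # float64, matching the module's conventions


def taylor_horner_sketch(x, coeffs):
    """Hypothetical re-implementation: sum_k coeffs[k] * x**k / k! via Horner's rule."""
    x = jnp.asarray(x, dtype=jnp.float64)
    result = jnp.zeros_like(x)
    fact = float(coeffs.shape[0])
    # Walk from the highest-order coefficient down:
    #   result <- result * x / k + coeffs[k - 1],  for k = n, n-1, ..., 1
    for coeff in coeffs[::-1]:
        result = result * x / fact + coeff
        fact -= 1.0
    return result


print(float(taylor_horner_sketch(2.0, jnp.array([10.0, 3.0, 4.0, 12.0]))))  # 40.0
```

The loop length is fixed by coeffs.shape, so under jax.jit it is unrolled at trace time.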

jaxpint.utils.taylor_horner_deriv(x, coeffs, deriv_order=1)

Evaluate the deriv_order-th derivative of a Taylor series.

Uses the Horner scheme with jax.lax.fori_loop for JIT efficiency (see https://en.wikipedia.org/wiki/Horner%27s_method).

Example:

taylor_horner_deriv(2.0, jnp.array([10., 3., 4., 12.]), 1)  # -> 35.0

Parameters:
  • x (array) – Evaluation point(s).

  • coeffs (1-D array, shape (n_coeffs,)) – Taylor coefficients.

  • deriv_order (int) – Derivative order (non-negative).

Return type:

array, same shape as x
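
A hedged sketch of how the derivative can reuse the same Horner loop: differentiating x**k / k! m times gives x**(k-m) / (k-m)!, so the m-th derivative is the original series with its first m coefficients dropped. taylor_horner_deriv_sketch is a hypothetical name, deriv_order is assumed to be a static Python int, and the loop uses jax.lax.fori_loop as described above.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def taylor_horner_deriv_sketch(x, coeffs, deriv_order=1):
    """Hypothetical sketch: deriv_order-th derivative of sum_k coeffs[k] * x**k / k!."""
    c = coeffs[deriv_order:]          # shifting the series differentiates it deriv_order times
    n = c.shape[0]
    x = jnp.asarray(x, dtype=c.dtype)
    if n == 0:                        # derivative order exceeds the polynomial degree
        return jnp.zeros_like(x)

    def body(i, acc):
        k = n - i                     # k runs n, n-1, ..., 1
        return acc * x / k + c[k - 1]

    return jax.lax.fori_loop(0, n, body, jnp.zeros_like(x))


print(float(taylor_horner_deriv_sketch(2.0, jnp.array([10.0, 3.0, 4.0, 12.0]), 1)))  # 35.0
```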

jaxpint.utils.taylor_horner_phase(dt_int_days, dt_frac_days, delay, coeffs)

Evaluate a Taylor series with phase precision via int/frac Horner.

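Uses the day decomposition dt = dt_int_days * 86400 + dt_frac_s, where dt_frac_s is the remaining sub-day time in seconds, to split each Horner multiplication into an integer (exact) part and a fractional (precise) part. This avoids the precision loss that occurs when a large absolute phase (~10^10 cycles) is computed as a single float64: at that magnitude adjacent float64 values are about 2×10^-6 cycles apart.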

Parameters:
  • dt_int_days ((n_toas,)) – Integer MJD day difference from epoch (exact).

  • dt_frac_days ((n_toas,)) – Fractional MJD day difference from epoch.

  • delay ((n_toas,)) – Accumulated signal delay in seconds.

  • coeffs ((n_coeffs,)) – Taylor coefficients: coeffs[k] multiplies dt**k / k!.

Returns:

Phase in cycles, split as integer + fractional part.

Return type:

DualFloat
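
To illustrate why the integer/fractional split matters, here is a hedged, pure-Python sketch for the linear term F0 * dt only (no delay, no higher-order terms). It is not the taylor_horner_phase algorithm: it swaps in a Veltkamp-style split of F0 into a short high part (whose product with the whole-day seconds is exact in float64) and a small low part, keeps the integer cycles as a Python int, and compares with exact rational arithmetic from fractions.Fraction. The names split_hi_lo and linear_phase_int_frac and the value of f0 are hypothetical.

```python
import math
from fractions import Fraction


def split_hi_lo(a, bits=24):
    """Split a float so hi has at most `bits` significant bits and a == hi + lo exactly."""
    m, e = math.frexp(a)                             # a = m * 2**e, 0.5 <= |m| < 1
    hi = math.ldexp(round(m * 2**bits), e - bits)
    return hi, a - hi                                # this subtraction is exact


def linear_phase_int_frac(f0, dt_int_days, dt_frac_s):
    """Hypothetical sketch: F0 * dt as (integer cycles, fractional cycles in [0, 1)).

    dt = dt_int_days * 86400 + dt_frac_s; assumes 0 <= dt_int_days * 86400 < 2**29.
    """
    n_s = dt_int_days * 86400                        # whole days in seconds, an exact integer
    hi, lo = split_hi_lo(f0)
    big = hi * n_s                                   # exact: 24-bit hi times an integer < 2**29
    big_int = math.floor(big)
    # The remaining terms are small, so float64 rounding stays far below one cycle.
    small = (big - big_int) + lo * n_s + f0 * dt_frac_s
    small_int = math.floor(small)
    return big_int + small_int, small - small_int


f0 = 218.81184385472                                 # hypothetical spin frequency (Hz)
i, frac = linear_phase_int_frac(f0, 4000, 1234.56789)
exact = Fraction(f0) * (4000 * 86400 + Fraction(1234.56789))
print(i, frac, float(exact - (i + Fraction(frac))))
```

Here the phase is about 7.6×10^10 cycles, where a single float64 resolves only about 1.5×10^-5 cycles; the split result stays within 10^-8 cycles of the exact value.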

jaxpint.utils.weighted_mean(arrin, weights_in, inputmean=None, calcerr=False)

Compute weighted mean and error of arrin.

Parameters:
  • arrin (1-D array) – Data values.

  • weights_in (1-D array) – Weights (typically 1 / sigma**2).

  • inputmean (float, optional) – If given, use this as the mean instead of computing it.

  • calcerr (bool) – If True, compute error from weighted scatter rather than 1 / sqrt(sum(weights)).

Return type:

(wmean, werr)

jaxpint.utils.weighted_mean_sdev(arrin, weights_in, inputmean=None, calcerr=False)

Compute weighted mean, error, and standard deviation of arrin.

Parameters:
  • arrin (1-D array) – Data values.

  • weights_in (1-D array) – Weights (typically 1 / sigma**2).

  • inputmean (float, optional) – If given, use this as the mean instead of computing it.

  • calcerr (bool) – If True, compute error from weighted scatter rather than 1 / sqrt(sum(weights)).

Return type:

(wmean, werr, wsdev)
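
A hedged JAX sketch covering both routines in one function. The formulas are assumptions modeled on the usual pint.utils conventions: werr is 1 / sqrt(sum w) by default, or sqrt(sum w**2 (x - mean)**2) / sum w when calcerr is set, and wsdev is the square root of the weighted variance. weighted_mean would return only the first two values. The name weighted_mean_sdev_sketch is hypothetical; inputmean and calcerr are resolved in Python, so they would be static arguments under jax.jit.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def weighted_mean_sdev_sketch(arrin, weights_in, inputmean=None, calcerr=False):
    """Hypothetical sketch of (wmean, werr, wsdev) for weights w_i ~ 1 / sigma_i**2."""
    wtot = jnp.sum(weights_in)
    if inputmean is None:
        wmean = jnp.sum(weights_in * arrin) / wtot
    else:
        wmean = jnp.asarray(inputmean, dtype=arrin.dtype)
    resid = arrin - wmean
    if calcerr:
        # Error from the weighted scatter: sqrt(sum w_i**2 (x_i - mean)**2) / sum w_i
        werr = jnp.sqrt(jnp.sum(weights_in**2 * resid**2)) / wtot
    else:
        # Formal error of a weighted mean when w_i = 1 / sigma_i**2
        werr = 1.0 / jnp.sqrt(wtot)
    # Weighted standard deviation: sqrt(sum w_i (x_i - mean)**2 / sum w_i)
    wsdev = jnp.sqrt(jnp.sum(weights_in * resid**2) / wtot)
    return wmean, werr, wsdev
```

The two error estimates agree when the scatter matches the quoted uncertainties: if (x_i - mean)**2 ≈ 1 / w_i, then sqrt(sum w**2 (x - mean)**2) / sum w ≈ sqrt(sum w) / sum w = 1 / sqrt(sum w).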

jaxpint.utils.normalize_designmatrix(M)

Column-normalize the design matrix for numerical stability.

The normalized matrix Mn and the original M are related by M = Mn * norms (broadcasting over rows). GLS expressions of the form M @ inv(M.T @ Ninv @ M) @ M.T are invariant under this rescaling.

Columns with zero norm (degenerate parameters) are left as-is.

Parameters:

M (2-D array, shape (n_toas, n_params))

Returns:

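M_normalized is the column-normalized matrix; norms holds the per-column scale factors (shape (n_params,)) such that M = M_normalized * norms; degenerate is a boolean mask that is True for columns with zero norm (i.e. parameters that have no effect on the residuals).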

Return type:

(M_normalized, norms, degenerate)
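
A minimal sketch, assuming columns are scaled by their Euclidean (L2) norms, with a numerical check of the invariance claim for a diagonal Ninv. The invariance follows because M = Mn D with D = diag(norms), so Mn D (D Mn^T Ninv Mn D)^{-1} D Mn^T = Mn (Mn^T Ninv Mn)^{-1} Mn^T. normalize_designmatrix_sketch and gls_projector are hypothetical helpers, not jaxpint API; in this sketch degenerate columns get a norm of 1 so that M = M_normalized * norms still holds.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def normalize_designmatrix_sketch(M):
    """Hypothetical sketch: divide each column by its L2 norm, leaving zero columns as-is."""
    norms = jnp.sqrt(jnp.sum(M**2, axis=0))          # shape (n_params,)
    degenerate = norms == 0.0
    norms = jnp.where(degenerate, 1.0, norms)        # zero columns are divided by 1
    return M / norms, norms, degenerate


def gls_projector(M, Ninv_diag):
    """M @ inv(M.T @ Ninv @ M) @ M.T for a diagonal Ninv (solve instead of an explicit inverse)."""
    A = M.T @ (Ninv_diag[:, None] * M)
    return M @ jnp.linalg.solve(A, M.T)


M = jax.random.normal(jax.random.PRNGKey(0), (20, 3)) * jnp.array([10.0, 1.0, 0.1])
Ninv = jnp.linspace(0.5, 2.0, 20)
Mn, norms, degenerate = normalize_designmatrix_sketch(M)
print(jnp.max(jnp.abs(gls_projector(M, Ninv) - gls_projector(Mn, Ninv))))
```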

jaxpint.utils.sherman_morrison_dot(Ndiag, v, w, x, y)

Compute \(x^T C^{-1} y\) where \(C = \mathrm{diag}(N) + w\,v\,v^T\).

Uses the Sherman–Morrison identity to avoid forming or inverting C.

Parameters:
  • Ndiag (1-D array) – Diagonal of N (positive).

  • v (1-D array) – Rank-1 update vector.

  • w (scalar) – Weight of the rank-1 update.

  • x (1-D arrays) – Vectors for the inner product.

  • y (1-D arrays) – Vectors for the inner product.

Returns:

The inner product and the log-determinant of C.

Return type:

(result, logdet_C)
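
A minimal JAX sketch of the identity, assuming the definitions above. It uses C^{-1} = N^{-1} - w N^{-1} v v^T N^{-1} / (1 + w v^T N^{-1} v) for the inner product and the matrix determinant lemma, det C = det N (1 + w v^T N^{-1} v), for the log-determinant, so everything is O(n). sherman_morrison_dot_sketch is a hypothetical name; comparing against a dense jnp.linalg.solve on small random inputs is a quick way to validate such a routine.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def sherman_morrison_dot_sketch(Ndiag, v, w, x, y):
    """Hypothetical sketch: x^T C^{-1} y and log det C for C = diag(N) + w v v^T."""
    Ninv_v = v / Ndiag
    denom = 1.0 + w * jnp.dot(v, Ninv_v)             # 1 + w v^T N^{-1} v
    # Sherman-Morrison: C^{-1} = N^{-1} - w N^{-1} v v^T N^{-1} / denom
    result = jnp.dot(x, y / Ndiag) - w * jnp.dot(x, Ninv_v) * jnp.dot(Ninv_v, y) / denom
    # Matrix determinant lemma: det C = det N * denom
    logdet_C = jnp.sum(jnp.log(Ndiag)) + jnp.log(denom)
    return result, logdet_C
```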

jaxpint.utils.woodbury_dot(Ndiag, U, Phidiag, x, y)

Compute \(x^T C^{-1} y\) where \(C = \mathrm{diag}(N) + U\,\mathrm{diag}(\Phi)\,U^T\).

Uses the Woodbury identity and Cholesky factorisation of the reduced-rank matrix \(\Sigma = \Phi^{-1} + U^T N^{-1} U\).

Parameters:
  • Ndiag (1-D array, shape (n,)) – Diagonal of N (positive).

  • U (2-D array, shape (n, k)) – Low-rank update basis.

  • Phidiag (1-D array, shape (k,)) – Diagonal of \(\Phi\) (positive).

  • x (1-D arrays, shape (n,)) – Vectors for the inner product.

  • y (1-D arrays, shape (n,)) – Vectors for the inner product.

Returns:

The inner product and the log-determinant of C.

Return type:

(result, logdet_C)
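
A hedged sketch of the Woodbury inner product using a Cholesky factor L of Σ (Σ = L L^T). Writing a = L^{-1} U^T N^{-1} x and b = L^{-1} U^T N^{-1} y, the correction term is simply a · b, and the generalized matrix determinant lemma gives log det C = log det N + log det Φ + 2 sum_i log L_ii. woodbury_dot_sketch is a hypothetical name; only k×k matrices are factorised.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

jax.config.update("jax_enable_x64", True)


def woodbury_dot_sketch(Ndiag, U, Phidiag, x, y):
    """Hypothetical sketch: x^T C^{-1} y and log det C for C = diag(N) + U diag(Phi) U^T."""
    NinvU = U / Ndiag[:, None]                       # N^{-1} U, shape (n, k)
    Sigma = jnp.diag(1.0 / Phidiag) + U.T @ NinvU     # Phi^{-1} + U^T N^{-1} U, shape (k, k)
    L = jnp.linalg.cholesky(Sigma)                    # Sigma = L L^T, L lower-triangular
    # a = L^{-1} U^T N^{-1} x and b = L^{-1} U^T N^{-1} y, so the correction is a . b
    a = solve_triangular(L, NinvU.T @ x, lower=True)
    b = solve_triangular(L, NinvU.T @ y, lower=True)
    result = jnp.dot(x, y / Ndiag) - jnp.dot(a, b)
    # det C = det N * det Phi * det Sigma
    logdet_C = (jnp.sum(jnp.log(Ndiag)) + jnp.sum(jnp.log(Phidiag))
                + 2.0 * jnp.sum(jnp.log(jnp.diag(L))))
    return result, logdet_C
```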

jaxpint.utils.woodbury_solve(Ndiag, U, Phidiag, B)

Compute \(C^{-1} B\) where \(C = \mathrm{diag}(N) + U\,\mathrm{diag}(\Phi)\,U^T\).

Uses the Woodbury identity:

\[C^{-1} = N^{-1} - N^{-1} U\,\Sigma^{-1} U^T N^{-1}\]

where \(\Sigma = \Phi^{-1} + U^T N^{-1} U\).

Parameters:
  • Ndiag (1-D array, shape (n,)) – Diagonal of N (positive).

  • U (2-D array, shape (n, k)) – Low-rank update basis.

  • Phidiag (1-D array, shape (k,)) – Diagonal of \(\Phi\) (positive).

  • B (2-D array, shape (n, m)) – Right-hand side matrix.

Returns:

Cinv_B – The product \(C^{-1} B\).

Return type:

array, shape (n, m)
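
A minimal sketch of the solve, assuming the identity above and using jax.scipy.linalg.cho_factor and cho_solve on the k×k matrix Σ. The name woodbury_solve_sketch is hypothetical. The payoff is cost: forming U^T N^{-1} U is O(n k^2) and factorising Σ is O(k^3), versus O(n^3) for factorising C directly when k ≪ n.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

jax.config.update("jax_enable_x64", True)


def woodbury_solve_sketch(Ndiag, U, Phidiag, B):
    """Hypothetical sketch: C^{-1} B for C = diag(N) + U diag(Phi) U^T, via Woodbury."""
    NinvB = B / Ndiag[:, None]                       # N^{-1} B, shape (n, m)
    NinvU = U / Ndiag[:, None]                       # N^{-1} U, shape (n, k)
    Sigma = jnp.diag(1.0 / Phidiag) + U.T @ NinvU     # k x k reduced-rank matrix
    Sigma_cf = cho_factor(Sigma, lower=True)
    # C^{-1} B = N^{-1} B - N^{-1} U Sigma^{-1} U^T N^{-1} B
    return NinvB - NinvU @ cho_solve(Sigma_cf, U.T @ NinvB)
```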

jaxpint.utils.ecl_to_icrs_rotation(obliquity_arcsec)

Rotation matrix from ecliptic to ICRS (row-vector convention).

Usage: L_icrs = L_ecl @ ecl_to_icrs_rotation(obl)

This is derived from the transpose of astropy’s rotation_matrix(obl, 'x') (the column-vector ecliptic → ICRS rotation), adapted for row-vector multiplication.

Parameters:

obliquity_arcsec (float) – Obliquity of the ecliptic in arcseconds.

Return type:

Float[Array, '3 3']
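
A hedged sketch of such a matrix, assuming the usage L_icrs = L_ecl @ R. It builds R with the same entries as astropy's rotation_matrix(ε, 'x'), [[1, 0, 0], [0, cos ε, sin ε], [0, -sin ε, cos ε]]; right-multiplying a row vector by R applies Rᵀ, which is the column-vector ecliptic → ICRS rotation. ecl_to_icrs_rotation_sketch is a hypothetical name, and 84381.406 arcsec (the IAU 2006 mean obliquity at J2000) is used only as a sample value.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def ecl_to_icrs_rotation_sketch(obliquity_arcsec):
    """Hypothetical sketch of a row-vector ecliptic -> ICRS rotation (L_icrs = L_ecl @ R)."""
    eps = jnp.deg2rad(obliquity_arcsec / 3600.0)
    c, s = jnp.cos(eps), jnp.sin(eps)
    # Entries of astropy's rotation_matrix(eps, 'x'). As a column-vector matrix this maps
    # ICRS -> ecliptic; right-multiplying row vectors applies its transpose, ecliptic -> ICRS.
    return jnp.array([[1.0, 0.0, 0.0],
                      [0.0, c, s],
                      [0.0, -s, c]])


R = ecl_to_icrs_rotation_sketch(84381.406)
nep_icrs = jnp.array([0.0, 0.0, 1.0]) @ R   # north ecliptic pole expressed in ICRS
```

As a sanity check, the north ecliptic pole should land at RA 270° (18h) and Dec 90° − ε.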

jaxpint.utils.compute_pulsar_direction(toa_data, params, raj_name, decj_name, pmra_name, pmdec_name, posepoch_name)

Unit vector from SSB to pulsar in ICRS Cartesian coordinates.

Without proper motion the direction is constant; with proper motion a linear correction is applied per TOA.

Parameters:
  • toa_data (TOAData) – Pre-extracted TOA data (needs tdb_int, tdb_frac, n_toas).

  • params (ParameterVector) – Timing-model parameters.

  • raj_name (str) – Name of the right-ascension parameter (radians).

  • decj_name (str) – Name of the declination parameter (radians).

  • pmra_name (str or None) – Name of the RA proper-motion parameter (mas/yr). None disables proper motion.

  • pmdec_name (str or None) – Name of the Dec proper-motion parameter (mas/yr). None disables proper motion.

  • posepoch_name (str or None) – Epoch parameter for proper-motion reference.

Return type:

Float[Array, 'n_toas 3']
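
A minimal sketch of the geometry with plain arrays standing in for TOAData and ParameterVector (whose fields are not described here). Assumptions, also labelled in the code: PMRA is μ_α cos δ (the usual pulsar-timing convention), so the RA offset is divided by cos δ; a year is 365.25 days; and dt_days is the TOA time minus POSEPOCH. pulsar_direction_sketch, lonlat_to_xyz and MAS_YR_TO_RAD_DAY are hypothetical names.

```python
import math

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# mas/yr -> rad/day, taking a Julian year of 365.25 days (assumption)
MAS_YR_TO_RAD_DAY = math.radians(1.0 / 3.6e6) / 365.25


def lonlat_to_xyz(lon, lat):
    """Unit vector(s) from longitude/latitude in radians; output shape (..., 3)."""
    return jnp.stack([jnp.cos(lat) * jnp.cos(lon),
                      jnp.cos(lat) * jnp.sin(lon),
                      jnp.sin(lat)], axis=-1)


def pulsar_direction_sketch(ra, dec, pmra, pmdec, dt_days):
    """Hypothetical sketch: SSB -> pulsar unit vectors with a linear proper-motion correction.

    pmra is assumed to be mu_alpha * cos(dec), hence the division by cos(dec).
    """
    ra_t = ra + pmra * MAS_YR_TO_RAD_DAY * dt_days / jnp.cos(dec)
    dec_t = dec + pmdec * MAS_YR_TO_RAD_DAY * dt_days
    return lonlat_to_xyz(ra_t, dec_t)                # shape (n_toas, 3)
```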

jaxpint.utils.compute_pulsar_direction_ecl(toa_data, params, elong_name, elat_name, pmelong_name, pmelat_name, posepoch_name, obliquity_arcsec)

Unit vector from SSB to pulsar in ICRS, computed from ecliptic coordinates.

Computes the direction in ecliptic frame (reusing the same lon/lat → xyz math as compute_pulsar_direction), then rotates to ICRS.

Parameters:
  • toa_data (TOAData)

  • params (ParameterVector)

  • elong_name (str) – Name of the ecliptic-longitude parameter (radians).

  • elat_name (str) – Name of the ecliptic-latitude parameter (radians).

  • pmelong_name (str or None) – Name of the proper-motion-in-longitude parameter (mas/yr). None disables proper motion.

  • pmelat_name (str or None) – Name of the proper-motion-in-latitude parameter (mas/yr). None disables proper motion.

  • posepoch_name (str or None) – Epoch parameter for proper-motion reference.

  • obliquity_arcsec (float) – Obliquity of the ecliptic in arcseconds.

Return type:

Float[Array, 'n_toas 3']
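
A hedged sketch of the ecliptic path without proper motion: the same lon/lat → xyz formula evaluated in the ecliptic frame, followed by a row-vector rotation whose matrix has the entries of astropy's rotation_matrix(ε, 'x'). pulsar_direction_ecl_sketch is a hypothetical name. As a check, the ecliptic point at longitude 90° should map to RA 90°, Dec +ε.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def pulsar_direction_ecl_sketch(elong, elat, obliquity_arcsec):
    """Hypothetical sketch (no proper motion): ecliptic lon/lat -> ICRS unit vector."""
    # Same lon/lat -> xyz formula as the equatorial case, evaluated in the ecliptic frame
    L_ecl = jnp.stack([jnp.cos(elat) * jnp.cos(elong),
                       jnp.cos(elat) * jnp.sin(elong),
                       jnp.sin(elat)], axis=-1)
    eps = jnp.deg2rad(obliquity_arcsec / 3600.0)
    c, s = jnp.cos(eps), jnp.sin(eps)
    R = jnp.array([[1.0, 0.0, 0.0], [0.0, c, s], [0.0, -s, c]])  # row-vector ecl -> ICRS
    return L_ecl @ R
```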

jaxpint.utils.fourier_sum(dt_days, wx_freqs, wx_sins, wx_coses)

Evaluate a Fourier sum at each TOA.

Computes:

result[t] = Σ_i (wx_sins[i] * sin(2π * wx_freqs[i] * dt_days[t])
               + wx_coses[i] * cos(2π * wx_freqs[i] * dt_days[t]))

Parameters:
  • dt_days ((n_toas,)) – Time differences from the reference epoch in days.

  • wx_freqs ((n_components,)) – Fourier frequencies in 1/day.

  • wx_sins ((n_components,)) – Sine amplitudes.

  • wx_coses ((n_components,)) – Cosine amplitudes.

Returns:

Fourier sum evaluated at each TOA.

Return type:

(n_toas,)
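
A minimal broadcasting sketch of the sum above (fourier_sum_sketch is a hypothetical name): forming an (n_toas, n_components) phase grid turns both sums into matrix-vector products. Memory grows as n_toas × n_components, which is fine when the number of components is modest.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def fourier_sum_sketch(dt_days, wx_freqs, wx_sins, wx_coses):
    """Hypothetical sketch: sum_i s_i sin(2 pi f_i dt) + c_i cos(2 pi f_i dt) per TOA."""
    phase = 2.0 * jnp.pi * dt_days[:, None] * wx_freqs[None, :]   # (n_toas, n_components)
    return jnp.sin(phase) @ wx_sins + jnp.cos(phase) @ wx_coses    # (n_toas,)
```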

jaxpint.utils.build_quantization_matrix(tdb_times_s, ecorr_masks, dt=1.0, nmin=2)

Build the ECORR quantization matrix (NumPy, not JIT-compatible).

Groups TOAs within dt seconds into epochs and creates a binary matrix U mapping TOAs to epochs. Only epochs with at least nmin TOAs are kept.

Parameters:
  • tdb_times_s ((n_toas,) float64) – TOA times in TDB seconds.

  • ecorr_masks (dict[str, ndarray]) – Boolean masks keyed by ECORR parameter name.

  • dt (float) – Epoch grouping threshold in seconds.

  • nmin (int) – Minimum number of TOAs an epoch must contain to be kept.

Returns:

  • U ((n_toas, n_total_epochs)) – Binary quantization matrix.

  • epoch_slices (dict[str, (int, int)]) – Column-index range for each ECORR parameter.

Return type:

tuple[ndarray, dict[str, tuple[int, int]]]
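
A hedged NumPy sketch of the grouping described above. Assumptions: each epoch is anchored at its earliest TOA (a TOA joins the current epoch if it lies less than dt seconds after that first TOA), epochs are formed separately for each ECORR mask, and epoch_slices holds half-open (start, stop) column ranges. group_epochs and build_quantization_matrix_sketch are hypothetical names.

```python
import numpy as np


def group_epochs(times, dt=1.0, nmin=2):
    """Group indices of `times` into epochs anchored at each epoch's earliest TOA."""
    epochs, ref = [], None
    for i in np.argsort(times):
        if ref is not None and times[i] - ref < dt:
            epochs[-1].append(i)                     # still within dt of the epoch start
        else:
            ref = times[i]                           # start a new epoch
            epochs.append([i])
    return [e for e in epochs if len(e) >= nmin]    # drop sparsely populated epochs


def build_quantization_matrix_sketch(tdb_times_s, ecorr_masks, dt=1.0, nmin=2):
    """Hypothetical sketch returning (U, epoch_slices) with half-open column ranges."""
    n_toas = len(tdb_times_s)
    columns, epoch_slices = [], {}
    for name, mask in ecorr_masks.items():
        idx = np.flatnonzero(mask)                   # TOAs this ECORR applies to
        start = len(columns)
        for epoch in group_epochs(tdb_times_s[idx], dt, nmin):
            col = np.zeros(n_toas)
            col[idx[epoch]] = 1.0                    # local epoch indices -> global TOA rows
            columns.append(col)
        epoch_slices[name] = (start, len(columns))
    U = np.stack(columns, axis=1) if columns else np.zeros((n_toas, 0))
    return U, epoch_slices
```

TOAs that fall in dropped epochs (fewer than nmin members) end up with no nonzero entry in that ECORR's columns.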

jaxpint.utils.build_fourier_basis(tdb_times_s, n_freqs, T)

Build an alternating sin/cos Fourier design matrix.

Parameters:
  • tdb_times_s ((n_toas,)) – TOA times in TDB seconds.

  • n_freqs (int) – Number of frequency modes.

  • T (float) – Time span in seconds (sets the fundamental frequency 1/T).

Returns:

  • F ((n_toas, 2 * n_freqs)) – Fourier design matrix with columns [sin(2πf₁t), cos(2πf₁t), sin(2πf₂t), ...].

  • freqs ((n_freqs,)) – Frequency array in Hz.

  • freq_bin_widths ((n_freqs,)) – Δf for each frequency bin.

Return type:

tuple[ndarray, ndarray, ndarray]
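
A minimal NumPy sketch, assuming the modes are the harmonics k/T for k = 1, ..., n_freqs and that each bin width is the spacing to the previous frequency (starting from zero), which gives Δf = 1/T throughout. build_fourier_basis_sketch is a hypothetical name.

```python
import numpy as np


def build_fourier_basis_sketch(tdb_times_s, n_freqs, T):
    """Hypothetical sketch: harmonics k / T with interleaved sin/cos columns."""
    freqs = np.arange(1, n_freqs + 1) / T                          # Hz
    phase = 2.0 * np.pi * tdb_times_s[:, None] * freqs[None, :]     # (n_toas, n_freqs)
    F = np.empty((len(tdb_times_s), 2 * n_freqs))
    F[:, 0::2] = np.sin(phase)                                      # sin(2 pi f_k t)
    F[:, 1::2] = np.cos(phase)                                      # cos(2 pi f_k t)
    # Spacing from the previous frequency (starting at 0); equals 1/T on this grid
    freq_bin_widths = np.diff(np.concatenate([[0.0], freqs]))
    return F, freqs, freq_bin_widths
```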