jaxpint.fitters

WLS, GLS, and wideband fitters for JaxPINT.

class jaxpint.fitters.BaseFitter(model, toa_data, params, noise_model=None)

Bases: ABC

Abstract base for all JaxPINT fitters.

Subclasses must implement fit_toas() (the main entry point) and typically also implement _iteration and _build_result.

Parameters:
  • model (TimingModel) – JaxPINT timing model.

  • toa_data (TOAData) – Pre-extracted TOA data.

  • params (ParameterVector) – Initial parameter values (free/frozen flags determine what is fit).

  • noise_model (NoiseModel, optional) – Noise model.

abstractmethod fit_toas(maxiter=1, **kwargs)

Run the fit and return a result container.

Subclasses narrow the return type to their specific result class (e.g. WLSFitResult, GLSFitResult).

Parameters:
  • maxiter (int, optional) – Maximum number of Gauss-Newton iterations. Default is 1.

  • **kwargs – Subclass-specific options (e.g. threshold, full_cov).

Returns:

A dataclass containing updated parameters, covariance, uncertainties, chi-squared, and degrees of freedom.

Return type:

BaseFitResult

class jaxpint.fitters.BaseFitResult(params, covariance_matrix, correlation_matrix, parameter_uncertainties, chi2, dof, reduced_chi2)

Bases: object

Common fields shared by all fit results.

Parameters:
  • params (ParameterVector)

  • covariance_matrix (Float[Array, 'n_free n_free'])

  • correlation_matrix (Float[Array, 'n_free n_free'])

  • parameter_uncertainties (Float[Array, 'n_free'])

  • chi2 (float)

  • dof (int)

  • reduced_chi2 (float)

params: ParameterVector
covariance_matrix: Float[Array, 'n_free n_free']
correlation_matrix: Float[Array, 'n_free n_free']
parameter_uncertainties: Float[Array, 'n_free']
chi2: float
dof: int
reduced_chi2: float

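The derived fields follow from covariance_matrix, chi2, and dof under the usual conventions. The NumPy sketch below assumes parameter_uncertainties = sqrt(diag(covariance_matrix)), correlation_matrix = the covariance normalised by the outer product of those uncertainties, and reduced_chi2 = chi2 / dof; the helper derived_fields is hypothetical and not part of JaxPINT.

```python
import numpy as np


def derived_fields(covariance_matrix, chi2, dof):
    """Hypothetical helper: recompute the derived BaseFitResult fields."""
    cov = np.asarray(covariance_matrix, dtype=float)
    # 1-sigma uncertainties are the square roots of the diagonal variances.
    uncertainties = np.sqrt(np.diag(cov))
    # Normalise each covariance entry by the product of the two uncertainties.
    correlation = cov / np.outer(uncertainties, uncertainties)
    reduced_chi2 = chi2 / dof
    return uncertainties, correlation, reduced_chi2


cov = np.array([[4.0, 1.0],
                [1.0, 9.0]])
unc, corr, red = derived_fields(cov, chi2=12.0, dof=10)
print(unc)         # [2. 3.]
print(corr[0, 1])  # 1 / (2 * 3), about 0.1667
print(red)         # 1.2
```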
class jaxpint.fitters.WLSFitResult(params, covariance_matrix, correlation_matrix, parameter_uncertainties, chi2, dof, reduced_chi2, residuals)

Bases: BaseFitResult

Result of a WLS fit.

Parameters:
  • params (ParameterVector)

  • covariance_matrix (Float[Array, 'n_free n_free'])

  • correlation_matrix (Float[Array, 'n_free n_free'])

  • parameter_uncertainties (Float[Array, 'n_free'])

  • chi2 (float)

  • dof (int)

  • reduced_chi2 (float)

  • residuals (Float[Array, 'n_toas'])

residuals: Float[Array, 'n_toas']

class jaxpint.fitters.WLSFitter(model, toa_data, params, noise_model=None)

Bases: BaseFitter

Weighted Least Squares fitter (Gauss-Newton with SVD).

The fitter is an immutable configuration container. Calling fit_toas() returns a WLSFitResult without mutating the fitter itself. Only the diagonal (white-noise) part of the noise model is used.

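Parameters:
  • model (TimingModel) – JaxPINT timing model.

  • toa_data (TOAData) – Pre-extracted TOA data.

  • params (ParameterVector) – Initial parameter values (free/frozen flags determine what is fit).

  • noise_model (NoiseModel, optional) – Noise model.

fit_toas(maxiter=1, threshold=None)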

Run the WLS fit.

Parameters:
  • maxiter (int) – Number of Gauss-Newton iterations.

  • threshold (float, optional) – SVD threshold (default 1e-14 * max(n_toas, n_free)).

Returns:

Fit result containing updated parameters, covariance, uncertainties, chi-squared, and residuals.

Return type:

WLSFitResult
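The role of maxiter and of the default threshold can be illustrated with a toy Gauss-Newton loop. This is a sketch, not JaxPINT's implementation: resid_fn and design_fn are hypothetical callables, the solve uses numpy.linalg.lstsq (whose rcond argument likewise discards singular values below rcond * S_max) in place of JaxPINT's own SVD routine, and residual mean subtraction is omitted. For a model that is linear in its parameters, one iteration already reaches the least-squares solution; nonlinear models may need more.

```python
import numpy as np


def gauss_newton_wls(resid_fn, design_fn, p0, sigma, maxiter=1, threshold=None):
    """Toy Gauss-Newton WLS loop (hypothetical; not the JaxPINT API).

    resid_fn(p)  -> residuals r, shape (n_toas,)
    design_fn(p) -> M = -dr/dp, shape (n_toas, n_free) (PINT sign convention)
    """
    p = np.asarray(p0, dtype=float)
    sigma = np.asarray(sigma, dtype=float)
    for _ in range(maxiter):
        r = resid_fn(p)
        M = design_fn(p)
        # Documented default: 1e-14 * max(n_toas, n_free).
        thr = 1e-14 * max(M.shape) if threshold is None else threshold
        # Whiten rows by the TOA uncertainties, then solve M_w dp ~= r_w.
        dpars, *_ = np.linalg.lstsq(M / sigma[:, None], r / sigma, rcond=thr)
        # Because M = -dr/dp, adding dpars reduces the residuals.
        p = p + dpars
    return p


# Linear toy model: data = 2 + 3 t, fitted with p = (offset, slope).
t = np.linspace(0.0, 1.0, 20)
data = 2.0 + 3.0 * t
sigma = np.full_like(t, 0.1)


def resid_fn(p):
    return data - (p[0] + p[1] * t)


def design_fn(p):
    # M = -dr/dp for the linear model above; it does not depend on p.
    return np.column_stack([np.ones_like(t), t])


p_fit = gauss_newton_wls(resid_fn, design_fn, [0.0, 0.0], sigma, maxiter=1)
print(p_fit)  # close to [2. 3.]
```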

class jaxpint.fitters.GLSFitResult(params, covariance_matrix, correlation_matrix, parameter_uncertainties, chi2, dof, reduced_chi2, residuals, noise_realizations)

Bases: BaseFitResult

Result of a GLS fit.

Parameters:
  • params (ParameterVector)

  • covariance_matrix (Float[Array, 'n_free n_free'])

  • correlation_matrix (Float[Array, 'n_free n_free'])

  • parameter_uncertainties (Float[Array, 'n_free'])

  • chi2 (float)

  • dof (int)

  • reduced_chi2 (float)

  • residuals (Float[Array, 'n_toas'])

  • noise_realizations (Float[Array, 'n_epochs'] | None)

residuals: Float[Array, 'n_toas']
noise_realizations: Float[Array, 'n_epochs'] | None

class jaxpint.fitters.GLSFitter(model, toa_data, params, noise_model=None)

Bases: BaseFitter

Generalised Least Squares fitter.

The fitter is an immutable configuration container. Calling fit_toas() returns a GLSFitResult without mutating the fitter itself.

Supports arbitrary correlated noise sources (ECORR, red noise, etc.) through the NoiseModel interface. When no correlated components are present, the GLS fitter reduces to WLS.

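Parameters:
  • model (TimingModel) – JaxPINT timing model.

  • toa_data (TOAData) – Pre-extracted TOA data.

  • params (ParameterVector) – Initial parameter values (free/frozen flags determine what is fit).

  • noise_model (NoiseModel, optional) – Noise model.

fit_toas(maxiter=1, threshold=None, full_cov=False)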

Run the GLS fit.

Parameters:
  • maxiter (int) – Number of Gauss-Newton iterations.

  • threshold (float, optional) – SVD threshold (default 1e-14 * dim).

  • full_cov (bool) – If True, use Woodbury-based full covariance inversion. If False (default), use the augmented design-matrix approach.

Returns:

Fit result containing updated parameters, covariance, uncertainties, chi-squared, residuals, and noise realizations.

Return type:

GLSFitResult
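The statement that GLS reduces to WLS without correlated components can be checked numerically with a dense reference solve. The NumPy sketch below forms C = diag(Ndiag) + U diag(Phidiag) U^T explicitly (practical only for small problems; JaxPINT instead uses gls_step_augmented() or gls_step_fullcov()) and compares the generalised normal-equation solution with a plain weighted least-squares fit. Setting Phidiag to zero switches the correlated component off. The function name dense_gls_step and the random test data are hypothetical.

```python
import numpy as np


def dense_gls_step(residuals, Ndiag, U, Phidiag, M):
    """Hypothetical dense reference: dp = (M^T C^-1 M)^-1 M^T C^-1 r."""
    C = np.diag(Ndiag) + U @ np.diag(Phidiag) @ U.T
    Cinv_M = np.linalg.solve(C, M)
    Cinv_r = np.linalg.solve(C, residuals)
    return np.linalg.solve(M.T @ Cinv_M, M.T @ Cinv_r)


rng = np.random.default_rng(0)
n_toas, n_free, n_epochs = 30, 3, 5
M = rng.normal(size=(n_toas, n_free))
r = rng.normal(size=n_toas)
sigma = rng.uniform(0.5, 2.0, size=n_toas)
U = rng.normal(size=(n_toas, n_epochs))

# Correlated amplitudes switched off, so C = diag(sigma**2).
dp_gls = dense_gls_step(r, sigma**2, U, np.zeros(n_epochs), M)
# Ordinary WLS on the whitened system.
dp_wls, *_ = np.linalg.lstsq(M / sigma[:, None], r / sigma, rcond=None)
print(np.allclose(dp_gls, dp_wls))  # True
```

With a nonzero Phidiag the two solutions generally differ, which is the case the GLS fitter exists for.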

class jaxpint.fitters.WidebandGLSFitResult(params, covariance_matrix, correlation_matrix, parameter_uncertainties, chi2, dof, reduced_chi2, time_residuals, dm_residuals, noise_realizations)

Bases: BaseFitResult

Result of a wideband GLS fit.

Parameters:
  • params (ParameterVector)

  • covariance_matrix (Float[Array, 'n_free n_free'])

  • correlation_matrix (Float[Array, 'n_free n_free'])

  • parameter_uncertainties (Float[Array, 'n_free'])

  • chi2 (float)

  • dof (int)

  • reduced_chi2 (float)

  • time_residuals (Float[Array, 'n_toas'])

  • dm_residuals (Float[Array, 'n_toas'])

  • noise_realizations (Float[Array, 'n_epochs'] | None)

time_residuals: Float[Array, 'n_toas']
dm_residuals: Float[Array, 'n_toas']
noise_realizations: Float[Array, 'n_epochs'] | None

class jaxpint.fitters.WidebandGLSFitter(model, toa_data, params, noise_model=None)

Bases: BaseFitter

Wideband Generalised Least Squares fitter.

Jointly fits TOA and DM residuals using a stacked residual vector of shape (2N,) and a design matrix of shape (2N, n_free), where N is the number of TOAs. Reuses the same GLS solve routines as GLSFitter.

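Parameters:
  • model (TimingModel) – JaxPINT timing model.

  • toa_data (TOAData) – Pre-extracted TOA data.

  • params (ParameterVector) – Initial parameter values (free/frozen flags determine what is fit).

  • noise_model (NoiseModel, optional) – Noise model.

fit_toas(maxiter=1, threshold=None, full_cov=False)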

Run the wideband GLS fit.

Parameters:
  • maxiter (int) – Number of Gauss-Newton iterations.

  • threshold (float, optional) – SVD threshold (default 1e-14 * dim).

  • full_cov (bool) – If True, use Woodbury-based full covariance inversion. If False (default), use the augmented design-matrix approach.

Return type:

WidebandGLSFitResult

jaxpint.fitters.compute_chi2(residuals, sigma)

Compute the weighted chi-squared statistic.

Calculates sum((residuals / sigma) ** 2).

Parameters:
  • residuals (jax.Array, shape (n_toas,)) – Time residuals in seconds.

  • sigma (jax.Array, shape (n_toas,)) – TOA uncertainties in seconds.

Returns:

chi2 – Scalar weighted chi-squared value.

Return type:

jax.Array, shape ()
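A minimal NumPy equivalent of the documented formula (the helper name chi2_sketch is hypothetical; the same expression works with jax.numpy):

```python
import numpy as np


def chi2_sketch(residuals, sigma):
    """sum((residuals / sigma) ** 2), as documented for compute_chi2."""
    return np.sum((np.asarray(residuals) / np.asarray(sigma)) ** 2)


# Residuals of 1 and -2 microseconds with 1-microsecond uncertainties.
chi2 = chi2_sketch([1e-6, -2e-6], [1e-6, 1e-6])
print(chi2)  # approximately 1 + 4 = 5
```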

jaxpint.fitters.compute_design_matrix(model, toa_data, params)

Build the design matrix via autodiff.

Computes jax.jacobian of time residuals w.r.t. all parameter values, then extracts only the free-parameter columns.

Following PINT’s convention, the design matrix is negated so that M[i, j] = -d(time_resid_i) / d(param_j). This ensures that the WLS update p_new = p_old + dpars reduces residuals.

Parameters:
  • model (TimingModel) – JaxPINT timing model.

  • toa_data (TOAData) – Pre-extracted TOA data.

  • params (ParameterVector) – Current parameter values. The frozen_mask attribute determines which columns (free parameters) appear in the output.

Returns:

M – Negated Jacobian of time residuals with respect to free parameters.

Return type:

jax.Array, shape (n_toas, n_free)
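The sign convention and the free-column selection can be sketched without autodiff. In the sketch below, central finite differences stand in for jax.jacobian, a quadratic toy residual replaces the timing model, and frozen_mask is assumed to be a boolean array that is True for frozen parameters; all names are illustrative. The final check confirms that p + dpars reduces the residuals.

```python
import numpy as np


def design_matrix_fd(resid_fn, p, frozen_mask, h=1e-6):
    """Negated Jacobian -dr/dp by central differences, free columns only."""
    p = np.asarray(p, dtype=float)
    cols = []
    for j in range(p.size):
        step = np.zeros_like(p)
        step[j] = h
        cols.append((resid_fn(p + step) - resid_fn(p - step)) / (2 * h))
    J = np.column_stack(cols)                # J[i, j] = dr_i / dp_j
    return -J[:, ~np.asarray(frozen_mask)]   # PINT convention: M = -J


t = np.linspace(-1.0, 1.0, 15)
data = 1.0 + 0.5 * t + 0.2 * t**2


def resid_fn(p):
    return data - (p[0] + p[1] * t + p[2] * t**2)


p = np.array([0.0, 0.0, 0.2])            # quadratic term at its true value ...
frozen = np.array([False, False, True])  # ... and frozen
M = design_matrix_fd(resid_fn, p, frozen)
print(M.shape)  # (15, 2)

dpars, *_ = np.linalg.lstsq(M, resid_fn(p), rcond=None)
p_new = p.copy()
p_new[~frozen] += dpars                  # update only the free parameters
print(np.sum(resid_fn(p_new) ** 2) < np.sum(resid_fn(p) ** 2))  # True
```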

jaxpint.fitters.compute_dm_residuals(model, toa_data, params)

Compute DM residuals: measured DM - model DM (pc/cm³).

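Parameters:
  • model (TimingModel) – JaxPINT timing model.

  • toa_data (TOAData) – Pre-extracted TOA data.

  • params (ParameterVector) – Current parameter values.

Return type:

Float[Array, 'n_toas']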

jaxpint.fitters.compute_gls_chi2(residuals, Ndiag, U, Phidiag)

Compute the GLS chi-squared statistic: r^T C^{-1} r.

The covariance C = diag(Ndiag) + U diag(Phidiag) U^T is inverted via the Woodbury identity, without forming the full (n_toas, n_toas) matrix.

Parameters:
  • residuals (jax.Array, shape (n_toas,)) – Time residuals in seconds.

  • Ndiag (jax.Array, shape (n_toas,)) – Diagonal of the white-noise covariance matrix (variances).

  • U (jax.Array, shape (n_toas, n_epochs)) – Basis matrix for correlated noise components.

  • Phidiag (jax.Array, shape (n_epochs,)) – Diagonal prior covariance of the correlated noise amplitudes.

Returns:

chi2 – Scalar GLS chi-squared value.

Return type:

jax.Array, shape ()
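The Woodbury route to r^T C^{-1} r can be written out explicitly. With N = diag(Ndiag) and Φ = diag(Phidiag), C^{-1} = N^{-1} - N^{-1} U Σ^{-1} U^T N^{-1}, where Σ = Φ^{-1} + U^T N^{-1} U is only (n_epochs, n_epochs). The NumPy sketch below is a hypothetical re-implementation, not JaxPINT's code; it assumes every entry of Phidiag is positive and checks the result against a dense solve.

```python
import numpy as np


def gls_chi2_woodbury(residuals, Ndiag, U, Phidiag):
    """r^T C^-1 r with C = diag(Ndiag) + U diag(Phidiag) U^T, via Woodbury."""
    Ninv_r = residuals / Ndiag                 # N^-1 r
    Ninv_U = U / Ndiag[:, None]                # N^-1 U
    Sigma = np.diag(1.0 / Phidiag) + U.T @ Ninv_U
    UT_Ninv_r = U.T @ Ninv_r
    return residuals @ Ninv_r - UT_Ninv_r @ np.linalg.solve(Sigma, UT_Ninv_r)


rng = np.random.default_rng(1)
n_toas, n_epochs = 40, 6
r = rng.normal(size=n_toas)
Ndiag = rng.uniform(0.5, 2.0, size=n_toas)
U = rng.normal(size=(n_toas, n_epochs))
Phidiag = rng.uniform(0.1, 1.0, size=n_epochs)

C = np.diag(Ndiag) + U @ np.diag(Phidiag) @ U.T   # dense, for comparison only
chi2_dense = r @ np.linalg.solve(C, r)
print(np.isclose(gls_chi2_woodbury(r, Ndiag, U, Phidiag), chi2_dense))  # True
```

The only linear solve involves Σ, so the cost scales with n_epochs rather than n_toas in the expensive step.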

jaxpint.fitters.compute_phase_residuals(model, toa_data, params)

Compute phase residuals using nearest-pulse tracking.

Returns the fractional part of the model phase (in cycles), after adjusting for delta_pulse_number offsets stored in the TOA data.

Parameters:
  • model (TimingModel) – JaxPINT timing model used to compute pulse phase.

  • toa_data (TOAData) – Pre-extracted TOA data containing observation times and delta_pulse_number corrections.

  • params (ParameterVector) – Current parameter values for the timing model.

Returns:

residuals – Phase residuals in cycles (fractional part of adjusted phase).

Return type:

jax.Array, shape (n_toas,)

jaxpint.fitters.compute_time_residuals(model, toa_data, params)

Compute time residuals in seconds.

Converts phase residuals (cycles) to time by dividing by the spin frequency F0 (Hz).

Parameters:
  • model (TimingModel) – JaxPINT timing model.

  • toa_data (TOAData) – Pre-extracted TOA data.

  • params (ParameterVector) – Current parameter values (must include F0).

Returns:

residuals – Time residuals in seconds.

Return type:

jax.Array, shape (n_toas,)
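The phase-to-time chain can be sketched on plain arrays. Two assumptions are labeled in the code: delta_pulse_number is added to the model phase (a sign assumed to follow PINT's nearest-pulse mode), and "fractional part" means the offset from the nearest integer, so values lie in [-0.5, 0.5]. The function names are hypothetical. As a worked number, 0.02 cycles at F0 = 100 Hz corresponds to 0.02 / 100 = 2e-4 s.

```python
import numpy as np


def phase_residuals_sketch(model_phase, delta_pulse_number):
    """Nearest-pulse phase residuals in cycles (illustrative only)."""
    # ASSUMPTION: delta_pulse_number offsets are added to the model phase.
    adjusted = np.asarray(model_phase) + np.asarray(delta_pulse_number)
    # Offset from the nearest whole pulse.
    return adjusted - np.round(adjusted)


def time_residuals_sketch(phase_resid, F0):
    """Cycles -> seconds by dividing by the spin frequency F0 (Hz)."""
    return phase_resid / F0


phase = np.array([1000.02, 2000.97, 3000.4])
dpn = np.zeros(3)
ph = phase_residuals_sketch(phase, dpn)   # roughly [0.02, -0.03, 0.4] cycles
dt = time_residuals_sketch(ph, F0=100.0)  # roughly [2e-4, -3e-4, 4e-3] s
```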

jaxpint.fitters.compute_wideband_design_matrix(model, toa_data, params)

Build the wideband design matrix via autodiff, shape (2N, n_free), where N is the number of TOAs.

Uses jax.jacobian of the combined [time_resid; dm_resid] vector w.r.t. all parameters, then extracts free columns. Negated per PINT convention.

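Parameters:
  • model (TimingModel) – JaxPINT timing model.

  • toa_data (TOAData) – Pre-extracted TOA data.

  • params (ParameterVector) – Current parameter values. The frozen_mask attribute determines which columns (free parameters) appear in the output.

Return type:

Float[Array, 'n2_toas n_free']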

jaxpint.fitters.compute_wideband_residuals(model, toa_data, params)

Compute stacked [time_residuals; dm_residuals], shape (2N,), where N is the number of TOAs.

Time residuals are in seconds, DM residuals in pc/cm³.

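Parameters:
  • model (TimingModel) – JaxPINT timing model.

  • toa_data (TOAData) – Pre-extracted TOA data.

  • params (ParameterVector) – Current parameter values.

Return type:

Float[Array, 'n2_toas']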
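The stacking convention behind the (2N,) residual vector and (2N, n_free) design matrix can be illustrated with placeholder arrays, where N is the number of TOAs. Time rows come first and DM rows second, matching the documented [time_residuals; dm_residuals] order. The random values and the choice to whiten each row by its own uncertainty (sigma_t in seconds, sigma_dm in pc/cm³) are assumptions for illustration; whitening is what makes the two differently-unitted halves commensurable in a single solve.

```python
import numpy as np

rng = np.random.default_rng(2)
N, n_free = 8, 3

# Placeholder per-TOA quantities (hypothetical values).
time_resid = rng.normal(scale=1e-6, size=N)   # seconds
dm_resid = rng.normal(scale=1e-4, size=N)     # pc/cm^3
M_time = rng.normal(size=(N, n_free))         # rows for time residuals
M_dm = rng.normal(size=(N, n_free))           # rows for DM residuals
sigma_t = np.full(N, 1e-6)
sigma_dm = np.full(N, 1e-4)

# Stack in [time; dm] order.
r_wb = np.concatenate([time_resid, dm_resid])   # shape (2N,)
M_wb = np.vstack([M_time, M_dm])                # shape (2N, n_free)
sigma_wb = np.concatenate([sigma_t, sigma_dm])  # shape (2N,)

# Whiten each row by its own uncertainty before a joint least-squares solve.
dpars, *_ = np.linalg.lstsq(M_wb / sigma_wb[:, None], r_wb / sigma_wb, rcond=None)
print(r_wb.shape, M_wb.shape, dpars.shape)  # (16,) (16, 3) (3,)
```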

jaxpint.fitters.gls_step_augmented(residuals, Ndiag, U, Phidiag, M, threshold)

One GLS solve via the augmented design-matrix approach.

Augments the design matrix as M_aug = [M | U] and solves with diagonal white-noise weighting N^{-1} = diag(1/Ndiag), plus a prior with covariance diag(Phidiag) on the noise amplitudes (the U columns).

Parameters:
  • residuals (jax.Array, shape (n_toas,)) – Time residuals in seconds (GLS-weighted mean already subtracted).

  • Ndiag (jax.Array, shape (n_toas,)) – Diagonal of the white-noise covariance matrix (variances).

  • U (jax.Array, shape (n_toas, n_epochs)) – Basis matrix for correlated noise components.

  • Phidiag (jax.Array, shape (n_epochs,)) – Diagonal prior covariance of the correlated noise amplitudes.

  • M (jax.Array, shape (n_toas, n_free)) – Design matrix (free-parameter columns only).

  • threshold (float) – Singular values below threshold * S_max are discarded.

Returns:

  • dpars (jax.Array, shape (n_free,)) – Timing parameter updates.

  • covariance (jax.Array, shape (n_free, n_free)) – Timing parameter covariance.

  • norms (jax.Array, shape (n_free + n_epochs,)) – Column norms of the augmented system (diagnostic); n_aug = n_free + n_epochs.

  • noise_realizations (jax.Array, shape (n_epochs,)) – MAP noise amplitude estimates.

Return type:

tuple[Float[Array, 'n_free'], Float[Array, 'n_free n_free'], Float[Array, 'n_aug'], Float[Array, 'n_epochs']]
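The augmented approach treats the noise amplitudes as extra fit parameters with a Gaussian prior. The NumPy sketch below assumes a zero-mean prior with covariance diag(Phidiag) on the amplitudes and no prior on the timing parameters, and it solves the (n_free + n_epochs)-dimensional normal equations with numpy.linalg.solve rather than the thresholded SVD and column normalisation the documented function performs (so no norms are returned). Solving jointly for the amplitudes yields the same timing-parameter update and covariance as the dense GLS solution, which the check at the end confirms. The function name is hypothetical.

```python
import numpy as np


def gls_augmented_sketch(residuals, Ndiag, U, Phidiag, M):
    """Augmented-design GLS solve (illustrative; plain solve, no SVD)."""
    n_free = M.shape[1]
    M_aug = np.hstack([M, U])                  # [M | U]
    Ninv_M_aug = M_aug / Ndiag[:, None]        # N^-1 M_aug
    # Prior precision: zero for timing parameters, 1/Phi for noise amplitudes.
    prior = np.concatenate([np.zeros(n_free), 1.0 / Phidiag])
    A = M_aug.T @ Ninv_M_aug + np.diag(prior)
    b = Ninv_M_aug.T @ residuals
    x = np.linalg.solve(A, b)
    cov_aug = np.linalg.inv(A)
    dpars, noise = x[:n_free], x[n_free:]      # timing updates, MAP amplitudes
    return dpars, cov_aug[:n_free, :n_free], noise


rng = np.random.default_rng(3)
n_toas, n_free, n_epochs = 50, 3, 4
M = rng.normal(size=(n_toas, n_free))
U = rng.normal(size=(n_toas, n_epochs))
r = rng.normal(size=n_toas)
Ndiag = rng.uniform(0.5, 2.0, size=n_toas)
Phidiag = rng.uniform(0.1, 1.0, size=n_epochs)

dpars, cov, noise = gls_augmented_sketch(r, Ndiag, U, Phidiag, M)

# Dense reference: (M^T C^-1 M)^-1 M^T C^-1 r.
C = np.diag(Ndiag) + U @ np.diag(Phidiag) @ U.T
MtCinvM = M.T @ np.linalg.solve(C, M)
dpars_ref = np.linalg.solve(MtCinvM, M.T @ np.linalg.solve(C, r))
print(np.allclose(dpars, dpars_ref), np.allclose(cov, np.linalg.inv(MtCinvM)))  # True True
```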

jaxpint.fitters.gls_step_fullcov(residuals, Ndiag, U, Phidiag, M, threshold)

One GLS solve via full (Woodbury) covariance inversion + SVD.

Computes M^T C^{-1} M and M^T C^{-1} r using woodbury_solve(), then SVD-solves the (n_free, n_free) normal equations.

Parameters:
  • residuals (jax.Array, shape (n_toas,)) – Time residuals in seconds (GLS-weighted mean already subtracted).

  • Ndiag (jax.Array, shape (n_toas,)) – Diagonal of the white-noise covariance matrix (variances).

  • U (jax.Array, shape (n_toas, n_epochs)) – Basis matrix for correlated noise components.

  • Phidiag (jax.Array, shape (n_epochs,)) – Diagonal prior covariance of the correlated noise amplitudes.

  • M (jax.Array, shape (n_toas, n_free)) – Design matrix (free-parameter columns only).

  • threshold (float) – Singular values below threshold * S_max are discarded.

Returns:

  • dpars (jax.Array, shape (n_free,)) – Parameter updates.

  • covariance (jax.Array, shape (n_free, n_free)) – Parameter covariance matrix.

  • norms (jax.Array, shape (n_free,)) – Column norms used for normalisation (diagnostic).

Return type:

tuple[Float[Array, 'n_free'], Float[Array, 'n_free n_free'], Float[Array, 'n_free']]
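The full-covariance route can be sketched in the same spirit: apply C^{-1} to the columns of M and to r through the Woodbury identity, normalise the columns, and solve the small normal equations by SVD, discarding singular values below threshold * S_max. The helper woodbury_solve_sketch and the choice of Euclidean column norms of M are assumptions; JaxPINT's woodbury_solve() may differ in detail. The result is checked against a dense solve.

```python
import numpy as np


def woodbury_solve_sketch(Ndiag, U, Phidiag, X):
    """Return C^-1 X for 2-D X, with C = diag(Ndiag) + U diag(Phidiag) U^T."""
    Ninv_X = X / Ndiag[:, None]
    Ninv_U = U / Ndiag[:, None]
    Sigma = np.diag(1.0 / Phidiag) + U.T @ Ninv_U   # (n_epochs, n_epochs)
    return Ninv_X - Ninv_U @ np.linalg.solve(Sigma, U.T @ Ninv_X)


def gls_fullcov_sketch(residuals, Ndiag, U, Phidiag, M, threshold):
    """Full-covariance GLS step with a thresholded SVD solve (illustrative)."""
    norms = np.sqrt(np.sum(M**2, axis=0))           # ASSUMPTION: column norms of M
    Mn = M / norms
    mtcm = Mn.T @ woodbury_solve_sketch(Ndiag, U, Phidiag, Mn)
    mtcy = Mn.T @ woodbury_solve_sketch(Ndiag, U, Phidiag, residuals[:, None])[:, 0]
    u, s, vt = np.linalg.svd(mtcm)
    keep = s > threshold * s.max()                  # drop s below threshold * S_max
    s_inv = np.zeros_like(s)
    s_inv[keep] = 1.0 / s[keep]
    pinv = (vt.T * s_inv) @ u.T                     # thresholded pseudo-inverse
    dpars = (pinv @ mtcy) / norms                   # undo the column scaling
    covariance = pinv / np.outer(norms, norms)
    return dpars, covariance, norms


rng = np.random.default_rng(5)
n_toas, n_free, n_epochs = 50, 3, 4
M = rng.normal(size=(n_toas, n_free))
U = rng.normal(size=(n_toas, n_epochs))
r = rng.normal(size=n_toas)
Ndiag = rng.uniform(0.5, 2.0, size=n_toas)
Phidiag = rng.uniform(0.1, 1.0, size=n_epochs)

dpars, cov, norms = gls_fullcov_sketch(r, Ndiag, U, Phidiag, M, threshold=1e-14 * n_toas)

C = np.diag(Ndiag) + U @ np.diag(Phidiag) @ U.T    # dense reference
MtCinvM = M.T @ np.linalg.solve(C, M)
dpars_ref = np.linalg.solve(MtCinvM, M.T @ np.linalg.solve(C, r))
print(np.allclose(dpars, dpars_ref), np.allclose(cov, np.linalg.inv(MtCinvM)))  # True True
```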

jaxpint.fitters.wls_step(residuals, sigma, M, threshold)

One WLS solve via SVD.

Parameters:
  • residuals (jax.Array, shape (n_toas,)) – Time residuals in seconds (mean already subtracted).

  • sigma (jax.Array, shape (n_toas,)) – TOA uncertainties in seconds.

  • M (jax.Array, shape (n_toas, n_free)) – Design matrix.

  • threshold (float) – Singular values below threshold * S_max are discarded.

Returns:

  • dpars (jax.Array, shape (n_free,)) – Parameter updates.

  • covariance (jax.Array, shape (n_free, n_free)) – Parameter covariance matrix.

  • norms (jax.Array, shape (n_free,)) – Column norms used for normalisation (diagnostic).

Return type:

tuple[Float[Array, 'n_free'], Float[Array, 'n_free n_free'], Float[Array, 'n_free']]
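A NumPy sketch of one SVD-based WLS solve follows. It whitens by sigma, normalises each whitened column to unit Euclidean norm (an assumption about what norms holds), and discards singular values below threshold * S_max. The function name is hypothetical, and the result is checked against numpy.linalg.lstsq on the whitened system with deliberately mismatched column scales.

```python
import numpy as np


def wls_step_sketch(residuals, sigma, M, threshold):
    """One SVD-based WLS solve (illustrative re-implementation)."""
    Mw = M / sigma[:, None]                  # whiten design-matrix rows
    rw = residuals / sigma                   # whiten residuals
    norms = np.sqrt(np.sum(Mw**2, axis=0))   # ASSUMPTION: column normalisation
    Mn = Mw / norms
    u, s, vt = np.linalg.svd(Mn, full_matrices=False)
    keep = s > threshold * s.max()
    s_inv = np.zeros_like(s)
    s_inv[keep] = 1.0 / s[keep]
    # Solve in normalised coordinates, then undo the column scaling.
    dpars = (vt.T @ (s_inv * (u.T @ rw))) / norms
    covariance = ((vt.T * s_inv**2) @ vt) / np.outer(norms, norms)
    return dpars, covariance, norms


rng = np.random.default_rng(4)
n_toas, n_free = 25, 3
M = rng.normal(size=(n_toas, n_free)) * np.array([1.0, 10.0, 0.1])  # mixed column scales
sigma = rng.uniform(0.5, 2.0, size=n_toas)
r = rng.normal(size=n_toas)
threshold = 1e-14 * max(n_toas, n_free)

dpars, cov, norms = wls_step_sketch(r, sigma, M, threshold)

Mw = M / sigma[:, None]
dpars_ref, *_ = np.linalg.lstsq(Mw, r / sigma, rcond=None)
print(np.allclose(dpars, dpars_ref), np.allclose(cov, np.linalg.inv(Mw.T @ Mw)))  # True True
```

Normalising the columns first keeps parameters with very different units (for example F0 versus its derivatives) from dominating the singular-value cutoff.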