jaxpint.fitters
WLS, GLS, and wideband fitters for JaxPINT.
- class jaxpint.fitters.BaseFitter(model, toa_data, params, noise_model=None)
Bases: ABC
Abstract base for all JaxPINT fitters.
Subclasses must implement fit_toas() (the main entry point) and typically also implement _iteration and _build_result.
- Parameters:
model (TimingModel) – JaxPINT timing model.
toa_data (TOAData) – Pre-extracted TOA data.
params (ParameterVector) – Initial parameter values (free/frozen flags determine what is fit).
noise_model (NoiseModel, optional) – Noise model.
- abstractmethod fit_toas(maxiter=1, **kwargs)
Run the fit and return a result container.
Subclasses narrow the return type to their specific result class (e.g. WLSFitResult, GLSFitResult).
- Parameters:
maxiter (int, optional) – Maximum number of Gauss-Newton iterations. Default is 1.
**kwargs – Subclass-specific options (e.g. threshold, full_cov).
- Returns:
A dataclass containing updated parameters, covariance, uncertainties, chi-squared, and degrees of freedom.
- Return type:
- class jaxpint.fitters.BaseFitResult(params, covariance_matrix, correlation_matrix, parameter_uncertainties, chi2, dof, reduced_chi2)
Bases: object
Common fields shared by all fit results.
- Parameters:
params (ParameterVector)
covariance_matrix (Float[Array, 'n_free n_free'])
correlation_matrix (Float[Array, 'n_free n_free'])
parameter_uncertainties (Float[Array, 'n_free'])
chi2 (float)
dof (int)
reduced_chi2 (float)
- params: ParameterVector
- covariance_matrix: Float[Array, 'n_free n_free']
- correlation_matrix: Float[Array, 'n_free n_free']
- parameter_uncertainties: Float[Array, 'n_free']
- class jaxpint.fitters.WLSFitResult(params, covariance_matrix, correlation_matrix, parameter_uncertainties, chi2, dof, reduced_chi2, residuals)
Bases: BaseFitResult
Result of a WLS fit.
- Parameters:
params (ParameterVector)
covariance_matrix (Float[Array, 'n_free n_free'])
correlation_matrix (Float[Array, 'n_free n_free'])
parameter_uncertainties (Float[Array, 'n_free'])
chi2 (float)
dof (int)
reduced_chi2 (float)
residuals (Float[Array, 'n_toas'])
- residuals: Float[Array, 'n_toas']
- class jaxpint.fitters.WLSFitter(model, toa_data, params, noise_model=None)
Bases: BaseFitter
Weighted Least Squares fitter (Gauss-Newton with SVD).
The fitter is an immutable configuration container. Calling fit_toas() returns a WLSFitResult without mutating the fitter itself. Only the diagonal (white-noise) part of the noise model is used.
- Parameters:
model (TimingModel)
toa_data (TOAData)
params (ParameterVector)
noise_model (Optional[NoiseModel])
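The Gauss-Newton iteration that maxiter controls can be sketched on a toy problem. This is a NumPy illustration only, not the JaxPINT implementation: the exponential model, the gauss_newton helper, and all values are hypothetical, the Jacobian is written by hand rather than obtained by autodiff, and a plain least-squares solve stands in for the thresholded SVD.

```python
import numpy as np

def model(p, t):
    """Hypothetical toy model: a * exp(b * t)."""
    a, b = p
    return a * np.exp(b * t)

def gauss_newton(p0, t, y, sigma, maxiter=1):
    """Return updated parameters; the inputs are left unmodified."""
    p = np.array(p0, dtype=float)            # copy, so p0 is never mutated
    for _ in range(maxiter):
        a, b = p
        r = y - model(p, t)                  # residuals: data minus model
        # M = -d(r)/d(p) = d(model)/d(p), one column each for a and b
        M = np.stack([np.exp(b * t), a * t * np.exp(b * t)], axis=1)
        # Weighted least squares: scale rows by 1 / sigma before solving
        dpars, *_ = np.linalg.lstsq(M / sigma[:, None], r / sigma, rcond=None)
        p = p + dpars
    return p

t = np.linspace(0.0, 1.0, 20)
p_true = np.array([2.0, -1.5])
y = model(p_true, t)                         # noiseless synthetic data
sigma = np.full_like(t, 0.1)

p_start = np.array([1.9, -1.4])
p_many = gauss_newton(p_start, t, y, sigma, maxiter=10)
```

Returning a fresh parameter array rather than updating state in place mirrors the "immutable configuration container" behaviour described above.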
- class jaxpint.fitters.GLSFitResult(params, covariance_matrix, correlation_matrix, parameter_uncertainties, chi2, dof, reduced_chi2, residuals, noise_realizations)
Bases: BaseFitResult
Result of a GLS fit.
- Parameters:
params (ParameterVector)
covariance_matrix (Float[Array, 'n_free n_free'])
correlation_matrix (Float[Array, 'n_free n_free'])
parameter_uncertainties (Float[Array, 'n_free'])
chi2 (float)
dof (int)
reduced_chi2 (float)
residuals (Float[Array, 'n_toas'])
noise_realizations (Float[Array, 'n_epochs'] | None)
- residuals: Float[Array, 'n_toas']
- class jaxpint.fitters.GLSFitter(model, toa_data, params, noise_model=None)
Bases: BaseFitter
Generalised Least Squares fitter.
The fitter is an immutable configuration container. Calling fit_toas() returns a GLSFitResult without mutating the fitter itself.
Supports arbitrary correlated noise sources (ECORR, red noise, etc.) through the NoiseModel interface. When no correlated components are present the GLS fitter reduces to WLS.
- Parameters:
model (TimingModel)
toa_data (TOAData)
params (ParameterVector)
noise_model (Optional[NoiseModel])
- class jaxpint.fitters.WidebandGLSFitResult(params, covariance_matrix, correlation_matrix, parameter_uncertainties, chi2, dof, reduced_chi2, time_residuals, dm_residuals, noise_realizations)
Bases: BaseFitResult
Result of a wideband GLS fit.
- Parameters:
params (ParameterVector)
covariance_matrix (Float[Array, 'n_free n_free'])
correlation_matrix (Float[Array, 'n_free n_free'])
parameter_uncertainties (Float[Array, 'n_free'])
chi2 (float)
dof (int)
reduced_chi2 (float)
time_residuals (Float[Array, 'n_toas'])
dm_residuals (Float[Array, 'n_toas'])
noise_realizations (Float[Array, 'n_epochs'] | None)
- time_residuals: Float[Array, 'n_toas']
- dm_residuals: Float[Array, 'n_toas']
- class jaxpint.fitters.WidebandGLSFitter(model, toa_data, params, noise_model=None)
Bases: BaseFitter
Wideband Generalised Least Squares fitter.
Jointly fits TOA and DM residuals using a combined (2N,) residual vector and design matrix. Reuses the same GLS solve routines as GLSFitter.
- Parameters:
model (TimingModel)
toa_data (TOAData)
params (ParameterVector)
noise_model (Optional[NoiseModel])
- jaxpint.fitters.compute_chi2(residuals, sigma)
Compute the weighted chi-squared statistic.
Calculates sum((residuals / sigma) ** 2).
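As a worked instance of the formula above, a minimal NumPy sketch follows. The weighted_chi2 helper and the numbers are made up for illustration and are not part of JaxPINT.

```python
import numpy as np

def weighted_chi2(residuals, sigma):
    """Weighted chi-squared: sum((residuals / sigma) ** 2)."""
    residuals = np.asarray(residuals, dtype=float)
    sigma = np.asarray(sigma, dtype=float)
    return float(np.sum((residuals / sigma) ** 2))

# Toy residuals and uncertainties in seconds (made-up values).
resid = np.array([1.0e-6, -2.0e-6, 0.5e-6])
sigma = np.array([1.0e-6, 1.0e-6, 0.5e-6])

# Normalised residuals are 1, -2 and 1, so the sum of squares is 1 + 4 + 1.
chi2 = weighted_chi2(resid, sigma)
```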
- jaxpint.fitters.compute_design_matrix(model, toa_data, params)
Build the design matrix via autodiff.
Computes jax.jacobian of time residuals w.r.t. all parameter values, then extracts only the free-parameter columns.
Following PINT’s convention, the design matrix is negated so that M[i, j] = -d(time_resid_i) / d(param_j). This ensures that the WLS update p_new = p_old + dpars reduces residuals.
- Parameters:
model (TimingModel) – JaxPINT timing model.
toa_data (TOAData) – Pre-extracted TOA data.
params (ParameterVector) – Current parameter values. The frozen_mask attribute determines which columns (free parameters) appear in the output.
- Returns:
M – Negated Jacobian of time residuals with respect to free parameters.
- Return type:
jax.Array, shape (n_toas, n_free)
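The sign convention can be checked on a toy linear model. The sketch below is a NumPy illustration under stated assumptions: the model, data, and variable names are hypothetical, and the Jacobian is written analytically where the library would obtain it from jax.jacobian.

```python
import numpy as np

# Toy linear "timing model": model(p) = A @ p, residual r(p) = y - A @ p.
t = np.linspace(-1.0, 1.0, 5)
A = np.stack([np.ones_like(t), t], axis=1)   # columns: offset, slope
p_true = np.array([0.3, -1.2])
y = A @ p_true                               # noiseless synthetic data

def residuals(p):
    return y - A @ p

p_old = np.array([0.0, 0.0])
jac = -A        # d(residual_i) / d(p_j) for this linear toy model
M = -jac        # negated Jacobian, following the convention described above

# Least-squares solve of M @ dpars ~= r, followed by an additive update.
dpars, *_ = np.linalg.lstsq(M, residuals(p_old), rcond=None)
p_new = p_old + dpars
```

Because r(p + dp) ≈ r(p) - M dp, solving M dp ≈ r and adding dp drives the residuals toward zero. Without the negation the update would have to be subtracted.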
- jaxpint.fitters.compute_dm_residuals(model, toa_data, params)
Compute DM residuals: measured DM - model DM (pc/cm³).
- Parameters:
model (TimingModel)
toa_data (TOAData)
params (ParameterVector)
- Return type:
Float[Array, 'n_toas']
- jaxpint.fitters.compute_gls_chi2(residuals, Ndiag, U, Phidiag)
Compute the GLS chi-squared statistic: r^T C^{-1} r.
The covariance C = diag(N) + U diag(Phi) U^T is inverted via the Woodbury identity without forming the full matrix.
- Parameters:
residuals (jax.Array, shape (n_toas,)) – Time residuals in seconds.
Ndiag (jax.Array, shape (n_toas,)) – Diagonal of the white-noise covariance matrix (variances).
U (jax.Array, shape (n_toas, n_epochs)) – Basis matrix for correlated noise components.
Phidiag (jax.Array, shape (n_epochs,)) – Diagonal prior covariance of the correlated noise amplitudes.
- Returns:
chi2 – Scalar GLS chi-squared value.
- Return type:
jax.Array, shape ()
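The Woodbury route can be sketched in NumPy and compared with a dense solve. The helper name gls_chi2_woodbury and the random test data are assumptions for illustration. The identity used is C^{-1} = N^{-1} - N^{-1} U (Phi^{-1} + U^T N^{-1} U)^{-1} U^T N^{-1}, which only requires solving an (n_epochs, n_epochs) system.

```python
import numpy as np

def gls_chi2_woodbury(residuals, Ndiag, U, Phidiag):
    """r^T C^{-1} r for C = diag(Ndiag) + U diag(Phidiag) U^T, via Woodbury."""
    Ninv_r = residuals / Ndiag                 # N^{-1} r, shape (n_toas,)
    Ninv_U = U / Ndiag[:, None]                # N^{-1} U, shape (n_toas, n_epochs)
    # Small (n_epochs, n_epochs) system: Sigma = Phi^{-1} + U^T N^{-1} U
    Sigma = np.diag(1.0 / Phidiag) + U.T @ Ninv_U
    UT_Ninv_r = U.T @ Ninv_r
    correction = UT_Ninv_r @ np.linalg.solve(Sigma, UT_Ninv_r)
    return float(residuals @ Ninv_r - correction)

rng = np.random.default_rng(0)
n_toas, n_epochs = 8, 3
r = rng.normal(size=n_toas)
Ndiag = rng.uniform(0.5, 2.0, size=n_toas)
U = rng.normal(size=(n_toas, n_epochs))
Phidiag = rng.uniform(0.1, 1.0, size=n_epochs)

chi2_fast = gls_chi2_woodbury(r, Ndiag, U, Phidiag)

# Dense reference, only practical for small n_toas.
C = np.diag(Ndiag) + U @ np.diag(Phidiag) @ U.T
chi2_dense = float(r @ np.linalg.solve(C, r))
```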
- jaxpint.fitters.compute_phase_residuals(model, toa_data, params)
Compute phase residuals using nearest-pulse tracking.
Returns the fractional part of the model phase (in cycles), after adjusting for delta_pulse_number offsets stored in the TOA data.
- Parameters:
model (TimingModel) – JaxPINT timing model used to compute pulse phase.
toa_data (TOAData) – Pre-extracted TOA data containing observation times and delta_pulse_number corrections.
params (ParameterVector) – Current parameter values for the timing model.
- Returns:
residuals – Phase residuals in cycles (fractional part of adjusted phase).
- Return type:
jax.Array, shape (n_toas,)
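Nearest-pulse tracking amounts to measuring each phase relative to the closest integer pulse. The sketch below is a NumPy illustration with made-up phases. It leaves out the delta_pulse_number adjustment, since the sign and ordering of that step are not specified here, and the helper name is hypothetical.

```python
import numpy as np

def nearest_pulse_residuals(model_phase):
    """Phase in cycles relative to the nearest integer pulse."""
    model_phase = np.asarray(model_phase, dtype=float)
    return model_phase - np.round(model_phase)   # lies in [-0.5, 0.5]

# Made-up model phases in cycles.
phase = np.array([1000.25, 2000.9, 3000.0])
resid = nearest_pulse_residuals(phase)
# 1000.25 is a quarter cycle after pulse 1000; 2000.9 is about a tenth
# of a cycle before pulse 2001; 3000.0 sits exactly on a pulse.
```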
- jaxpint.fitters.compute_time_residuals(model, toa_data, params)
Compute time residuals in seconds.
Converts phase residuals (cycles) to time by dividing by the spin frequency F0 (Hz).
- Parameters:
model (TimingModel) – JaxPINT timing model.
toa_data (TOAData) – Pre-extracted TOA data.
params (ParameterVector) – Current parameter values (must include F0).
- Returns:
residuals – Time residuals in seconds.
- Return type:
jax.Array, shape (n_toas,)
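The unit conversion is a single division, shown here as a short NumPy example with a made-up spin frequency and made-up phase residuals.

```python
import numpy as np

F0 = 200.0                                   # spin frequency in Hz (made-up value)
phase_resid = np.array([0.25, -0.1, 0.0])    # phase residuals in cycles

# cycles / (cycles per second) = seconds
time_resid = phase_resid / F0
# A quarter cycle at 200 Hz is 1.25 ms; a tenth of a cycle is 0.5 ms.
```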
- jaxpint.fitters.compute_wideband_design_matrix(model, toa_data, params)
Build the wideband design matrix via autodiff, shape (2N, n_free).
Uses jax.jacobian of the combined [time_resid; dm_resid] vector w.r.t. all parameters, then extracts free columns. Negated per PINT convention.
- Parameters:
model (TimingModel)
toa_data (TOAData)
params (ParameterVector)
- Return type:
Float[Array, 'n2_toas n_free']
- jaxpint.fitters.compute_wideband_residuals(model, toa_data, params)
Compute stacked [time_residuals; dm_residuals], shape (2N,).
Time residuals are in seconds, DM residuals in pc/cm³.
- Parameters:
model (TimingModel)
toa_data (TOAData)
params (ParameterVector)
- Return type:
Float[Array, 'n2_toas']
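The stacking layout used by the wideband functions can be pictured with a small NumPy example. All arrays below are placeholders with made-up values; only the shapes and the ordering (time block first, DM block second) are meant to be illustrative.

```python
import numpy as np

n_toas, n_free = 4, 2
time_resid = np.array([1e-6, -2e-6, 0.5e-6, 0.0])   # seconds
dm_resid = np.array([1e-4, 0.0, -3e-4, 2e-4])       # pc/cm^3

# Stacked residual vector [time_residuals; dm_residuals], shape (2N,)
r_wb = np.concatenate([time_resid, dm_resid])

# Matching design matrix: time block on top, DM block below, shape (2N, n_free)
M_time = np.ones((n_toas, n_free))                  # placeholder derivatives
M_dm = np.zeros((n_toas, n_free))                   # placeholder derivatives
M_wb = np.vstack([M_time, M_dm])
```

Since the two halves carry different units, any per-row weighting applied to such a system needs uncertainties expressed in the matching unit for each half.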
- jaxpint.fitters.gls_step_augmented(residuals, Ndiag, U, Phidiag, M, threshold)
One GLS solve via the augmented design-matrix approach.
Augments the design matrix as M_aug = [M | U] and solves with diagonal weighting N^{-1} plus a prior on noise amplitudes.
- Parameters:
residuals (jax.Array, shape (n_toas,)) – Time residuals in seconds (GLS-weighted mean already subtracted).
Ndiag (jax.Array, shape (n_toas,)) – Diagonal of the white-noise covariance matrix (variances).
U (jax.Array, shape (n_toas, n_epochs)) – Basis matrix for correlated noise components.
Phidiag (jax.Array, shape (n_epochs,)) – Diagonal prior covariance of the correlated noise amplitudes.
M (jax.Array, shape (n_toas, n_free)) – Design matrix (free-parameter columns only).
threshold (float) – Singular values below threshold * S_max are discarded.
- Returns:
dpars (jax.Array, shape (n_free,)) – Timing parameter updates.
covariance (jax.Array, shape (n_free, n_free)) – Timing parameter covariance.
norms (jax.Array, shape (n_free + n_epochs,)) – Column norms of the augmented system (diagnostic).
noise_realizations (jax.Array, shape (n_epochs,)) – MAP noise amplitude estimates.
- Return type:
tuple[Float[Array, 'n_free'], Float[Array, 'n_free n_free'], Float[Array, 'n_aug'], Float[Array, 'n_epochs']]
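A minimal NumPy sketch of the augmented approach follows, under stated assumptions: a flat prior (zero prior precision) on the timing parameters, a direct linear solve in place of the thresholded SVD, and no column normalisation. The helper gls_augmented and the random data are hypothetical and are not the library routine.

```python
import numpy as np

def gls_augmented(residuals, Ndiag, U, Phidiag, M):
    """Solve for timing updates and noise amplitudes with M_aug = [M | U]."""
    n_free = M.shape[1]
    M_aug = np.hstack([M, U])
    # Prior precision: zero on timing parameters (assumed flat prior),
    # 1 / Phi on the correlated-noise amplitudes.
    prior = np.concatenate([np.zeros(n_free), 1.0 / Phidiag])
    A = M_aug.T @ (M_aug / Ndiag[:, None]) + np.diag(prior)
    b = M_aug.T @ (residuals / Ndiag)
    x = np.linalg.solve(A, b)
    cov = np.linalg.inv(A)
    return x[:n_free], cov[:n_free, :n_free], x[n_free:]

rng = np.random.default_rng(1)
n_toas, n_free, n_epochs = 12, 2, 3
M = rng.normal(size=(n_toas, n_free))
U = rng.normal(size=(n_toas, n_epochs))
Ndiag = rng.uniform(0.5, 2.0, size=n_toas)
Phidiag = rng.uniform(0.1, 1.0, size=n_epochs)
r = rng.normal(size=n_toas)

dpars, cov, noise = gls_augmented(r, Ndiag, U, Phidiag, M)

# Dense cross-check against the full-covariance GLS solution.
C = np.diag(Ndiag) + U @ np.diag(Phidiag) @ U.T
Cinv_M = np.linalg.solve(C, M)
cov_dense = np.linalg.inv(M.T @ Cinv_M)
dpars_dense = cov_dense @ (Cinv_M.T @ r)
```

Marginalising the noise amplitudes out of the augmented system gives the same timing-parameter estimate and covariance as inverting C = diag(N) + U diag(Phi) U^T directly. The dense cross-check at the end illustrates this.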
- jaxpint.fitters.gls_step_fullcov(residuals, Ndiag, U, Phidiag, M, threshold)
One GLS solve via full (Woodbury) covariance inversion + SVD.
Computes M^T C^{-1} M and M^T C^{-1} r using woodbury_solve(), then SVD-solves the (n_free, n_free) normal equations.
- Parameters:
residuals (jax.Array, shape (n_toas,)) – Time residuals in seconds (GLS-weighted mean already subtracted).
Ndiag (jax.Array, shape (n_toas,)) – Diagonal of the white-noise covariance matrix (variances).
U (jax.Array, shape (n_toas, n_epochs)) – Basis matrix for correlated noise components.
Phidiag (jax.Array, shape (n_epochs,)) – Diagonal prior covariance of the correlated noise amplitudes.
M (jax.Array, shape (n_toas, n_free)) – Design matrix (free-parameter columns only).
threshold (float) – Singular values below threshold * S_max are discarded.
- Returns:
- Return type:
tuple[Float[Array, 'n_free'], Float[Array, 'n_free n_free'], Float[Array, 'n_free']]
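The full-covariance route can be sketched in NumPy as follows. This is an illustration, not the library code: woodbury_apply is a hypothetical stand-in for woodbury_solve(), whose actual signature is not shown here, and the (dpars, covariance, norms) return order and the diagonal normalisation before the SVD are assumptions made by analogy with wls_step.

```python
import numpy as np

def woodbury_apply(Ndiag, U, Phidiag, X):
    """Return C^{-1} X for C = diag(Ndiag) + U diag(Phidiag) U^T (X is 2-D)."""
    Ninv_X = X / Ndiag[:, None]
    Ninv_U = U / Ndiag[:, None]
    Sigma = np.diag(1.0 / Phidiag) + U.T @ Ninv_U
    return Ninv_X - Ninv_U @ np.linalg.solve(Sigma, U.T @ Ninv_X)

def gls_fullcov(residuals, Ndiag, U, Phidiag, M, threshold):
    Cinv_M = woodbury_apply(Ndiag, U, Phidiag, M)
    mtcm = M.T @ Cinv_M                      # M^T C^{-1} M, (n_free, n_free)
    mtcr = Cinv_M.T @ residuals              # M^T C^{-1} r (C is symmetric)
    # Normalise by the diagonal so the threshold is scale independent.
    norms = np.sqrt(np.diag(mtcm))
    mtcm_n = mtcm / np.outer(norms, norms)
    Us, S, Vt = np.linalg.svd(mtcm_n)
    S_inv = np.where(S > threshold * S.max(), 1.0 / S, 0.0)
    cov = ((Vt.T * S_inv) @ Us.T) / np.outer(norms, norms)
    return cov @ mtcr, cov, norms

rng = np.random.default_rng(2)
n_toas, n_free, n_epochs = 12, 2, 3
M = rng.normal(size=(n_toas, n_free))
U = rng.normal(size=(n_toas, n_epochs))
Ndiag = rng.uniform(0.5, 2.0, size=n_toas)
Phidiag = rng.uniform(0.1, 1.0, size=n_epochs)
r = rng.normal(size=n_toas)

dpars, cov, norms = gls_fullcov(r, Ndiag, U, Phidiag, M, threshold=1e-14)

# Dense reference for comparison.
C = np.diag(Ndiag) + U @ np.diag(Phidiag) @ U.T
cov_dense = np.linalg.inv(M.T @ np.linalg.solve(C, M))
dpars_dense = cov_dense @ (M.T @ np.linalg.solve(C, r))
```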
- jaxpint.fitters.wls_step(residuals, sigma, M, threshold)
One WLS solve via SVD.
- Parameters:
residuals ((n_toas,)) – Time residuals in seconds (mean already subtracted).
sigma ((n_toas,)) – TOA uncertainties in seconds.
M ((n_toas, n_free)) – Design matrix.
threshold (float) – Singular values below threshold * S_max are discarded.
- Returns:
dpars ((n_free,)) – Parameter updates.
covariance ((n_free, n_free)) – Parameter covariance matrix.
norms ((n_free,)) – Column norms used for normalisation (diagnostic).
- Return type:
tuple[Float[Array, 'n_free'], Float[Array, 'n_free n_free'], Float[Array, 'n_free']]
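One way to realise a thresholded-SVD WLS solve with column normalisation is sketched below in NumPy. The exact order of whitening and normalisation inside wls_step is not documented here, so the details of wls_svd (a hypothetical helper) are assumptions, and the data are random.

```python
import numpy as np

def wls_svd(residuals, sigma, M, threshold):
    """One weighted least-squares solve via a thresholded SVD."""
    Mw = M / sigma[:, None]                  # whiten rows by the TOA uncertainties
    rw = residuals / sigma
    norms = np.sqrt(np.sum(Mw**2, axis=0))   # column norms of the whitened matrix
    Mn = Mw / norms                          # unit-norm columns
    Us, S, Vt = np.linalg.svd(Mn, full_matrices=False)
    # Discard singular values below threshold * S_max.
    S_inv = np.where(S > threshold * S.max(), 1.0 / S, 0.0)
    dpars = (Vt.T @ (S_inv * (Us.T @ rw))) / norms
    cov = ((Vt.T * S_inv**2) @ Vt) / np.outer(norms, norms)
    return dpars, cov, norms

rng = np.random.default_rng(3)
n_toas, n_free = 10, 3
# Columns with different scales, as timing-model derivatives often have.
M = rng.normal(size=(n_toas, n_free)) * np.array([1.0, 10.0, 0.1])
sigma = rng.uniform(0.5, 2.0, size=n_toas)
r = rng.normal(size=n_toas)

dpars, cov, norms = wls_svd(r, sigma, M, threshold=1e-14)

# Reference: ordinary least squares on the whitened system.
Mw = M / sigma[:, None]
dpars_ref, *_ = np.linalg.lstsq(Mw, r / sigma, rcond=None)
cov_ref = np.linalg.inv(Mw.T @ Mw)
```

Normalising the columns before the SVD keeps the threshold meaningful when parameters differ by many orders of magnitude. Dividing by norms afterwards maps the solution and covariance back to the original parameter units.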