Source code for jaxpint.fitters.wideband

"""Wideband Generalised Least Squares fitter."""
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional

import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float

from jaxpint.model import TimingModel
from jaxpint.noise import NoiseModel
from jaxpint.types import TOAData, ParameterVector

from ._base import (
    BaseFitter,
    BaseFitResult,
    compute_time_residuals,
    _subtract_weighted_mean,
    wls_step,
    compute_chi2,
)
from .gls import (
    _subtract_gls_weighted_mean,
    gls_step_fullcov,
    gls_step_augmented,
    compute_gls_chi2,
)


# ---------------------------------------------------------------------------
# Wideband residual / design-matrix functions
# ---------------------------------------------------------------------------


def compute_dm_residuals(
    model: TimingModel,
    toa_data: TOAData,
    params: ParameterVector,
) -> Float[Array, " n_toas"]:
    """Return the DM residuals: measured DM minus model DM (pc/cm³).

    The measured values are read from ``toa_data.dm_values``; the model
    prediction comes from ``model.compute_dm(toa_data, params)``.
    """
    return toa_data.dm_values - model.compute_dm(toa_data, params)
def compute_wideband_residuals(
    model: TimingModel,
    toa_data: TOAData,
    params: ParameterVector,
) -> Float[Array, " n2_toas"]:
    """Return the stacked ``[time_residuals; dm_residuals]`` vector, shape ``(2N,)``.

    The first half holds the time residuals (seconds) and the second half
    the DM residuals (pc/cm³), in that order.
    """
    stacked_parts = (
        compute_time_residuals(model, toa_data, params),
        compute_dm_residuals(model, toa_data, params),
    )
    return jnp.concatenate(stacked_parts)
def compute_wideband_design_matrix(
    model: TimingModel,
    toa_data: TOAData,
    params: ParameterVector,
) -> Float[Array, "n2_toas n_free"]:
    """Return the wideband design matrix, shape ``(2N, n_free)``.

    The matrix is obtained by automatic differentiation: the Jacobian of the
    combined ``[time_resid; dm_resid]`` vector is taken with respect to all
    parameter values, the columns of the free parameters are selected, and
    the sign is flipped (PINT convention).
    """
    # Only the design matrix is needed here; the full Jacobian is discarded.
    _, design = _compute_wideband_jacobian_and_design(model, toa_data, params)
    return design
@eqx.filter_jit
def _compute_wideband_jacobian_and_design(
    model: TimingModel,
    toa_data: TOAData,
    params: ParameterVector,
) -> tuple[Float[Array, "n2_toas n_params"], Float[Array, "n2_toas n_free"]]:
    """JIT-compiled wideband Jacobian and design matrix computation.

    Returns the pair ``(J, M)``: ``J`` is the Jacobian of the stacked wideband
    residuals with respect to every entry of ``params.values`` and ``M`` is
    the negated sub-matrix of ``J`` restricted to the free-parameter columns.
    """
    free_cols = params.free_indices_array()

    def _stacked_residuals(values: Float[Array, " n_params"]):
        # Swap the trial values into the parameter container before evaluating.
        trial = eqx.tree_at(lambda pv: pv.values, params, values)
        return compute_wideband_residuals(model, toa_data, trial)

    jacobian = jax.jacobian(_stacked_residuals)(params.values)  # (2N, n_params)
    design = -jacobian[:, free_cols]
    return jacobian, design


# ---------------------------------------------------------------------------
# JIT-compiled iteration core
# ---------------------------------------------------------------------------


@eqx.filter_jit
def _wideband_iteration_core(
    model: TimingModel,
    toa_data: TOAData,
    params: ParameterVector,
    noise_model: Optional[NoiseModel],
    threshold: float,
    full_cov: bool,
) -> tuple[
    Float[Array, " n_params"],
    Float[Array, "n_free n_free"],
    Float[Array, " n_basis"],
]:
    """JIT-compiled core of one wideband GLS Gauss-Newton iteration.

    Returns ``(new_values, covariance, noise_realizations)``.
    ``new_values`` is ``params.values`` with the solved step added to the
    free entries. ``noise_realizations`` is an empty array unless the
    augmented solve was used (``full_cov`` false and a non-empty basis).
    """
    n_toas = toa_data.n_toas
    free_idx = params.free_indices_array()

    # Noise: TOA-side variances/basis/prior and DM-side variances.
    if noise_model is None:
        toa_sigma = toa_data.error
        toa_var = toa_sigma ** 2
        toa_basis = jnp.zeros((n_toas, 0))
        phi_diag = jnp.zeros(0)
        dm_var = toa_data.dm_errors ** 2
    else:
        toa_sigma = noise_model.scaled_sigma(toa_data, params)
        toa_var, toa_basis, phi_diag, dm_var = (
            noise_model.wideband_covariance(toa_data, params)
        )

    n_basis = toa_basis.shape[1]
    stacked_var = jnp.concatenate([toa_var, dm_var])
    # The DM rows get an all-zero basis block so the basis has 2N rows.
    stacked_basis = jnp.concatenate(
        [toa_basis, jnp.zeros((n_toas, n_basis))], axis=0
    )

    # Residuals: only the time half has its weighted mean removed.
    time_part = compute_time_residuals(model, toa_data, params)
    dm_part = compute_dm_residuals(model, toa_data, params)
    if n_basis == 0:
        time_part = _subtract_weighted_mean(time_part, toa_sigma)
    else:
        time_part = _subtract_gls_weighted_mean(
            time_part, toa_var, toa_basis, phi_diag
        )
    stacked_resid = jnp.concatenate([time_part, dm_part])

    # Design matrix: negated free-parameter columns of the residual Jacobian.
    def _stacked_residuals(values: Float[Array, " n_params"]):
        trial = eqx.tree_at(lambda pv: pv.values, params, values)
        return compute_wideband_residuals(model, toa_data, trial)

    jacobian = jax.jacobian(_stacked_residuals)(params.values)
    design = -jacobian[:, free_idx]

    # Solve for the parameter step with the requested strategy.
    realizations = jnp.zeros(0)
    if full_cov:
        step, covariance, _ = gls_step_fullcov(
            stacked_resid, stacked_var, stacked_basis, phi_diag, design, threshold
        )
    elif n_basis > 0:
        step, covariance, _, realizations = gls_step_augmented(
            stacked_resid, stacked_var, stacked_basis, phi_diag, design, threshold
        )
    else:
        step, covariance, _ = wls_step(
            stacked_resid, jnp.sqrt(stacked_var), design, threshold
        )

    updated_free = params.values[free_idx] + step
    new_values = params.values.at[free_idx].set(updated_free)
    return new_values, covariance, realizations


# ---------------------------------------------------------------------------
# Wideband GLS result container
# ---------------------------------------------------------------------------
@dataclass
class WidebandGLSFitResult(BaseFitResult):
    """Result of a wideband GLS fit.

    Extends :class:`BaseFitResult` with the two halves of the wideband
    residual vector and the optional noise realizations.
    """

    # Time residuals, one entry per TOA.
    time_residuals: Float[Array, " n_toas"]
    # DM residuals, one entry per TOA.
    dm_residuals: Float[Array, " n_toas"]
    # Noise realizations from the solve, or ``None`` when none are available.
    noise_realizations: Optional[Float[Array, " n_epochs"]]
# ---------------------------------------------------------------------------
# Wideband GLS fitter
# ---------------------------------------------------------------------------
class WidebandGLSFitter(BaseFitter):
    """Wideband Generalised Least Squares fitter.

    Fits TOA and DM residuals together by stacking them into one ``(2N,)``
    residual vector with a matching design matrix. The solve itself reuses
    the GLS routines that :class:`GLSFitter` is built on.
    """

    def _get_wideband_noise(
        self,
        params: ParameterVector,
    ) -> tuple[
        Float[Array, " n_toas"],  # sigma_toa
        Float[Array, " n2_toas"],  # Ndiag (2N)
        Float[Array, "n2_toas n_basis"],  # U (2N, K)
        Float[Array, " n_basis"],  # Phidiag
        Float[Array, " n_toas"],  # sigma_dm
    ]:
        """Return ``(sigma_toa, Ndiag, U, Phidiag, sigma_dm)`` for ``params``.

        ``Ndiag`` and ``U`` are stacked to ``2N`` rows (TOA rows first, DM
        rows second); the DM rows of ``U`` are zero. Without a noise model
        the raw TOA and DM uncertainties from ``toa_data`` are used and the
        basis is empty.
        """
        data = self.toa_data
        noise = self.noise_model
        n_toas = data.n_toas

        if noise is None:
            sigma_toa = data.error
            toa_var = sigma_toa ** 2
            toa_basis = jnp.zeros((n_toas, 0))
            phi_diag = jnp.zeros(0)
            dm_var = data.dm_errors ** 2
            sigma_dm = data.dm_errors
        else:
            sigma_toa = noise.scaled_sigma(data, params)
            toa_var, toa_basis, phi_diag, dm_var = (
                noise.wideband_covariance(data, params)
            )
            sigma_dm = noise.scaled_dm_sigma(data, params)

        # Stack into a (2N,) diagonal and a (2N, K) basis with zero DM rows.
        stacked_var = jnp.concatenate([toa_var, dm_var])
        dm_basis = jnp.zeros((n_toas, toa_basis.shape[1]))
        stacked_basis = jnp.concatenate([toa_basis, dm_basis], axis=0)
        return sigma_toa, stacked_var, stacked_basis, phi_diag, sigma_dm

    def _iteration(
        self,
        params: ParameterVector,
        threshold: float,
        full_cov: bool,
    ) -> tuple[
        ParameterVector,
        Float[Array, "n_free n_free"],
        Optional[Float[Array, " n_basis"]],
    ]:
        """Run one Gauss-Newton iteration.

        Returns the updated parameters, the parameter covariance, and the
        noise realizations (``None`` when the core returned an empty array).
        """
        updated_values, covariance, realizations = _wideband_iteration_core(
            self.model,
            self.toa_data,
            params,
            self.noise_model,
            threshold,
            full_cov,
        )
        updated_params = eqx.tree_at(
            lambda pv: pv.values, params, updated_values
        )
        if realizations.size == 0:
            # An empty array means the solve produced no realizations.
            realizations = None
        return updated_params, covariance, realizations

    def _build_result(
        self,
        params: ParameterVector,
        covariance: Float[Array, "n_free n_free"],
        noise_realizations: Optional[Float[Array, " n_basis"]],
    ) -> WidebandGLSFitResult:
        """Compute final residuals and chi2, and package them as a result.

        The time residuals have their weighted mean removed (GLS-weighted
        when a noise basis is present); the DM residuals are used as-is.
        Degrees of freedom are ``2 * n_toas - params.n_free``.
        """
        sigma_toa, stacked_var, stacked_basis, phi_diag, sigma_dm = (
            self._get_wideband_noise(params)
        )
        n_toas = self.toa_data.n_toas
        has_basis = stacked_basis.shape[1] > 0

        time_part = compute_time_residuals(self.model, self.toa_data, params)
        dm_part = compute_dm_residuals(self.model, self.toa_data, params)

        if has_basis:
            # Only the TOA rows of the stacked noise apply to the time half.
            time_part = _subtract_gls_weighted_mean(
                time_part,
                stacked_var[:n_toas],
                stacked_basis[:n_toas, :],
                phi_diag,
            )
        else:
            time_part = _subtract_weighted_mean(time_part, sigma_toa)

        stacked_resid = jnp.concatenate([time_part, dm_part])

        if has_basis:
            chi2_val = float(
                compute_gls_chi2(
                    stacked_resid, stacked_var, stacked_basis, phi_diag
                )
            )
        else:
            chi2_val = float(
                compute_chi2(stacked_resid, jnp.sqrt(stacked_var))
            )

        dof = 2 * n_toas - params.n_free
        errors, correlation = self._covariance_to_correlation(covariance)
        reduced_chi2 = self._reduced_chi2(chi2_val, dof)

        return WidebandGLSFitResult(
            params=params,
            covariance_matrix=covariance,
            correlation_matrix=correlation,
            parameter_uncertainties=errors,
            chi2=chi2_val,
            dof=dof,
            reduced_chi2=reduced_chi2,
            time_residuals=time_part,
            dm_residuals=dm_part,
            noise_realizations=noise_realizations,
        )

    def fit_toas(
        self,
        maxiter: int = 1,
        threshold: Optional[float] = None,
        full_cov: bool = False,
    ) -> WidebandGLSFitResult:
        """Run the wideband GLS fit.

        Parameters
        ----------
        maxiter : int
            Number of Gauss-Newton iterations. Values below 1 are treated
            as 1.
        threshold : float, optional
            SVD threshold. When omitted it defaults to
            ``1e-14 * max(2 * n_toas, dim)``, where ``dim`` is the number of
            free parameters, plus the number of noise basis columns when
            ``full_cov`` is False.
        full_cov : bool
            If True, use Woodbury-based full covariance inversion.
            If False (default), use the augmented design-matrix approach.

        Returns
        -------
        WidebandGLSFitResult
        """
        n_toas = self.toa_data.n_toas

        # Count the correlated-noise basis columns (zero when there are none).
        n_basis = 0
        if self.noise_model is not None and self.noise_model.has_correlated:
            _, _, initial_basis, _ = self.noise_model.covariance(
                self.toa_data, self.params
            )
            n_basis = initial_basis.shape[1]

        if threshold is None:
            dim = self.params.n_free
            if not full_cov:
                # The augmented solve also fits the noise basis amplitudes.
                dim = dim + n_basis
            threshold = 1e-14 * max(2 * n_toas, dim)

        n_iterations = max(1, maxiter)

        current = self.params
        covariance = None
        noise_realizations = None
        for _ in range(n_iterations):
            current, covariance, noise_realizations = self._iteration(
                current, threshold, full_cov
            )

        return self._build_result(current, covariance, noise_realizations)