Source code for jaxpint.phase.piecewise_spindown
"""Piecewise spindown phase component.
Ports PINT's ``PiecewiseSpindown`` class as a pure Equinox module. The
phase is modelled as a Taylor expansion within user-defined time bins:
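phase(t) = Σ_n [ PWPH_n + PWF0_n*dt + PWF1_n*dt^2/2! + PWF2_n*dt^3/3! ]
for t in [PWSTART_n, PWSTOP_n), where dt = (t - PWEP_n) converted from
days to seconds, minus the accumulated delay. Segments that do not
contain t contribute zero.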
All derivatives are handled by ``jax.jacobian`` through ``__call__``.
"""
from __future__ import annotations
import equinox as eqx
import jax.numpy as jnp
from jaxtyping import Array, Float
from jaxpint.components import PhaseComponent
from jaxpint.constants import SECS_PER_DAY
from jaxpint.dual_float import DualFloat
from jaxpint.types import TOAData, ParameterVector
from jaxpint.utils import taylor_horner
[docs]
class PiecewiseSpindown(PhaseComponent):
    """Spindown phase built from independent per-segment Taylor series.

    Each segment ``n`` owns a half-open validity window
    ``[PWSTART_n, PWSTOP_n)``, a reference epoch ``PWEP_n`` and four
    Taylor coefficients (``PWPH_n``, ``PWF0_n``, ``PWF1_n``, ``PWF2_n``).
    The component stores only the *names* of those parameters; their
    values are looked up in the ``ParameterVector`` at call time.

    Parameters
    ----------
    n_pieces : int
        Number of piecewise segments.
    pwstart_names : tuple[str, ...]
        Names of segment start epoch parameters (MJD).
    pwstop_names : tuple[str, ...]
        Names of segment stop epoch parameters (MJD).
    pwep_names : tuple[str, ...]
        Names of segment reference epoch parameters (MJD).
    pwph_names : tuple[str, ...]
        Names of segment phase offset parameters (dimensionless cycles).
    pwf0_names : tuple[str, ...]
        Names of segment frequency parameters (Hz).
    pwf1_names : tuple[str, ...]
        Names of segment frequency derivative parameters (Hz/s).
    pwf2_names : tuple[str, ...]
        Names of segment second derivative parameters (Hz/s^2).

    Raises
    ------
    ValueError
        If ``n_pieces`` is less than 1.
    ValueError
        If the length of any segment parameter name tuple does not match
        ``n_pieces``.
    """

    # Every field is declared static via ``eqx.field(static=True)``.
    n_pieces: int = eqx.field(static=True)
    pwstart_names: tuple[str, ...] = eqx.field(static=True)
    pwstop_names: tuple[str, ...] = eqx.field(static=True)
    pwep_names: tuple[str, ...] = eqx.field(static=True)
    pwph_names: tuple[str, ...] = eqx.field(static=True)
    pwf0_names: tuple[str, ...] = eqx.field(static=True)
    pwf1_names: tuple[str, ...] = eqx.field(static=True)
    pwf2_names: tuple[str, ...] = eqx.field(static=True)

    def __check_init__(self):
        """Validate the segment count and the per-segment name tuples.

        Raises
        ------
        ValueError
            If ``n_pieces`` is less than 1, or if any name tuple has a
            length different from ``n_pieces``.
        """
        if self.n_pieces < 1:
            raise ValueError("PiecewiseSpindown requires at least one piece")

        name_fields = (
            "pwstart_names",
            "pwstop_names",
            "pwep_names",
            "pwph_names",
            "pwf0_names",
            "pwf1_names",
            "pwf2_names",
        )
        for field_name in name_fields:
            n_names = len(getattr(self, field_name))
            if n_names != self.n_pieces:
                raise ValueError(
                    f"Length of {field_name} ({n_names}) "
                    f"does not match n_pieces ({self.n_pieces})"
                )

    def __call__(
        self,
        toa_data: TOAData,
        params: ParameterVector,
        delay: Float[Array, " n_toas"],
    ) -> DualFloat:
        """Evaluate the piecewise spindown phase at every TOA.

        Parameters
        ----------
        toa_data : TOAData
            Pre-extracted TOA data; ``tdb`` and ``n_toas`` are read.
        params : ParameterVector
            Timing-model parameters, queried by name for each segment.
        delay : array, shape (n_toas,)
            Accumulated signal delay from prior components in seconds.

        Returns
        -------
        DualFloat
            Result of ``DualFloat.cycles`` called with an all-zero array
            as its first argument and the summed segment phase as its
            second. TOAs that fall in no segment get a phase of zero;
            TOAs covered by several segments get the sum of those
            segments' contributions.
        """
        tdb_mjd = toa_data.tdb.total
        total_phase = jnp.zeros(toa_data.n_toas)

        # __check_init__ guarantees every tuple has n_pieces entries, so
        # zipping them visits exactly one tuple of names per segment.
        segments = zip(
            self.pwstart_names,
            self.pwstop_names,
            self.pwep_names,
            self.pwph_names,
            self.pwf0_names,
            self.pwf1_names,
            self.pwf2_names,
        )
        for start_nm, stop_nm, ep_nm, ph_nm, f0_nm, f1_nm, f2_nm in segments:
            # Half-open window: start is inclusive, stop is exclusive.
            seg_start = params.epoch_dual(start_nm).total
            seg_stop = params.epoch_dual(stop_nm).total
            in_segment = (tdb_mjd >= seg_start) & (tdb_mjd < seg_stop)

            # Subtract the epochs before collapsing to ``.total``, then
            # convert the day offset to seconds and remove the delay.
            ref_epoch = params.epoch_dual(ep_nm)
            elapsed_days = (toa_data.tdb - ref_epoch).total
            dt_sec = elapsed_days * SECS_PER_DAY - delay

            # Coefficient order is [PWPH, PWF0, PWF1, PWF2].
            taylor_coeffs = jnp.array(
                [params.param_value(nm) for nm in (ph_nm, f0_nm, f1_nm, f2_nm)]
            )
            seg_phase = taylor_horner(dt_sec, taylor_coeffs)

            # The series is evaluated for all TOAs; the mask keeps only
            # those inside this segment's window.
            total_phase = total_phase + jnp.where(in_segment, seg_phase, 0.0)

        return DualFloat.cycles(jnp.zeros(toa_data.n_toas), total_phase)