jaxpint.types
Core data types for JaxPINT.
Defines the three foundational types:
- DualFloat: extended-precision value stored as integer + fractional parts
- TOAData: pre-extracted TOA data as JAX arrays
- ParameterVector: timing model parameters as a flat JAX array with metadata
All types are equinox Modules (automatic JAX pytrees) and are compatible with jax.jit, jax.grad, jax.vmap, etc.
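The need for the integer/fractional split can be checked with plain float64 arithmetic. The sketch below uses only the Python standard library and is not jaxpint code. It compares the float64 spacing at a whole MJD near 58000 with the spacing of a fractional day below 1. The MJD value is an arbitrary illustration.

```python
import math

SECONDS_PER_DAY = 86400.0


def resolution_ns(x_days: float) -> float:
    """Spacing between adjacent float64 values at x_days, in nanoseconds."""
    return math.ulp(x_days) * SECONDS_PER_DAY * 1e9


# A full MJD held in one float64 (58000 is an arbitrary example epoch)
full_mjd_ns = resolution_ns(58000.0)

# Only the fractional day held in a float64, as in the DualFloat split
frac_day_ns = resolution_ns(0.999)

print(f"full MJD:      {full_mjd_ns:.3g} ns")
print(f"fraction only: {frac_day_ns:.3g} ns")
```

A single float64 MJD near 58000 resolves only about 0.63 µs. The fractional part alone resolves about 0.01 ns. This is why the integer day is kept separate. `math.ulp` requires Python 3.9 or later.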
- class jaxpint.types.TOAData(mjd_int, mjd_frac, tdb_int, tdb_frac, error, freq, delta_pulse_number, ssb_obs_pos, ssb_obs_vel, obs_sun_pos, obs_indices, flag_masks, planet_positions, dm_values, dm_errors, tropo_alt, tropo_alt_valid, obs_geodetic_lat, obs_height_km, n_toas, obs_names, tzr_tdb_int=None, tzr_tdb_frac=None, tzr_freq=None, tzr_ssb_obs_pos=None, tzr_obs_sun_pos=None)
Bases: Module
Pre-extracted TOA data as JAX arrays.
Created by the bridge layer from PINT TOAs objects. All astropy units are stripped; see the unit conventions below.
- Unit conventions (enforced by the bridge, not by this class):
  - mjd_int, mjd_frac: days (integer MJD + fractional day in [0, 1))
  - tdb_int, tdb_frac: days (TDB timescale, same split)
  - error: seconds
  - freq: MHz (barycentric, Doppler-corrected)
  - ssb_obs_pos: km, shape (n_toas, 3)
  - ssb_obs_vel: km/s, shape (n_toas, 3)
  - obs_sun_pos: km, shape (n_toas, 3)
  - delta_pulse_number: dimensionless (cycles)
  - dm_values, dm_errors: pc/cm^3
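The [0, 1) convention for mjd_frac and tdb_frac can be met by splitting a high-precision MJD before it is ever rounded to float64. The helper below is hypothetical and is not the jaxpint bridge. It assumes the MJD arrives as a decimal string, for example as copied from a .tim file.

```python
from decimal import ROUND_FLOOR, Decimal
from typing import Tuple


def split_mjd(mjd_str: str) -> Tuple[float, float]:
    """Hypothetical helper: split a decimal MJD string into (int_day, frac_day).

    int_day is a whole number (exact in float64 for any realistic MJD);
    frac_day lies in [0, 1).
    """
    value = Decimal(mjd_str)  # decimal parse, no float64 rounding yet
    int_day = value.to_integral_value(rounding=ROUND_FLOOR)  # floor keeps frac >= 0
    frac_day = float(value - int_day)  # the only float64 rounding step
    if frac_day >= 1.0:  # e.g. a fraction of 0.999...9 can round up to 1.0
        return float(int_day + 1), 0.0
    return float(int_day), frac_day


int_day, frac_day = split_mjd("58000.123456789012345678")
```

Only the fraction is rounded, so its error is bounded by the float64 spacing below 1 rather than the spacing near 58000.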
- Parameters:
mjd_int (Float[Array, 'n_toas'])
mjd_frac (Float[Array, 'n_toas'])
tdb_int (Float[Array, 'n_toas'])
tdb_frac (Float[Array, 'n_toas'])
error (Float[Array, 'n_toas'])
freq (Float[Array, 'n_toas'])
delta_pulse_number (Float[Array, 'n_toas'])
ssb_obs_pos (Float[Array, 'n_toas 3'])
ssb_obs_vel (Float[Array, 'n_toas 3'])
obs_sun_pos (Float[Array, 'n_toas 3'])
obs_indices (Int[Array, 'n_toas'])
planet_positions (dict[str, Float[Array, 'n_toas 3']] | None)
dm_values (Float[Array, 'n_toas'] | None)
dm_errors (Float[Array, 'n_toas'] | None)
tropo_alt (Float[Array, 'n_toas'] | None)
tropo_alt_valid (Bool[Array, 'n_toas'] | None)
obs_geodetic_lat (Float[Array, 'n_toas'] | None)
obs_height_km (Float[Array, 'n_toas'] | None)
n_toas (int)
tzr_tdb_int (float | None)
tzr_tdb_frac (float | None)
tzr_freq (float | None)
tzr_ssb_obs_pos (Float[Array, '3'] | None)
tzr_obs_sun_pos (Float[Array, '3'] | None)
- mjd_int: Float[Array, 'n_toas']
- mjd_frac: Float[Array, 'n_toas']
- tdb_int: Float[Array, 'n_toas']
- tdb_frac: Float[Array, 'n_toas']
- error: Float[Array, 'n_toas']
- freq: Float[Array, 'n_toas']
- delta_pulse_number: Float[Array, 'n_toas']
- ssb_obs_pos: Float[Array, 'n_toas 3']
- ssb_obs_vel: Float[Array, 'n_toas 3']
- obs_sun_pos: Float[Array, 'n_toas 3']
- obs_indices: Int[Array, 'n_toas']
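The jaxtyping annotations above encode shapes. 'n_toas' means a 1-D array of length n_toas, and 'n_toas 3' means an (n_toas, 3) array. A pre-flight check over the required fields could look like the following. The helper is hypothetical and uses NumPy for illustration; the same shapes apply to JAX arrays.

```python
import numpy as np

# Trailing shape per required field, transcribed from the annotations above.
# Hypothetical helper for illustration; not part of jaxpint.
TRAILING_SHAPES = {
    "mjd_int": (), "mjd_frac": (), "tdb_int": (), "tdb_frac": (),
    "error": (), "freq": (), "delta_pulse_number": (),
    "ssb_obs_pos": (3,), "ssb_obs_vel": (3,), "obs_sun_pos": (3,),
    "obs_indices": (),
}


def check_toa_shapes(fields: dict, n_toas: int) -> None:
    """Raise ValueError if any required field has the wrong shape."""
    for name, trailing in TRAILING_SHAPES.items():
        expected = (n_toas,) + trailing
        actual = np.shape(fields[name])
        if actual != expected:
            raise ValueError(f"{name}: expected {expected}, got {actual}")


# Dummy all-zero arrays with the right shapes pass silently
n = 4
check_toa_shapes({k: np.zeros((n,) + t) for k, t in TRAILING_SHAPES.items()}, n)
```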
- class jaxpint.types.ParameterVector(values, frozen_mask, names, units, epoch_int_values, _name_to_index=<factory>, _free_indices=())
Bases: Module
Timing model parameters as a flat JAX array with metadata.
Stores ALL parameters (free and frozen) in a single array. Epoch-type parameters (PEPOCH, T0, TASC, POSEPOCH, GLEP_*) are split: the integer MJD day is stored in epoch_int_values (static, not differentiated) and only the fractional day is stored in values. The bridge layer handles the splitting on input and the recombination on output.
Pytree: only values is a dynamic leaf (it participates in jax.grad). All other fields are static metadata frozen into JIT traces.
Unit conventions
All values are stored as raw float64 in a fixed internal unit system. Components assume these units unconditionally – no runtime conversion.
Parameter(s) | Unit
--- | ---
Angles (RAJ, DECJ, OM, KIN, KOM, ELONG, ELAT) | radians
Angular rates (OMDOT, XOMDOT) | rad/s
Proper motion (PMRA, PMDEC, PMELONG, PMELAT) | mas/yr
Parallax (PX) | mas
Epochs (PEPOCH, T0, TASC, POSEPOCH, …) | fractional day
Spin frequency (F0, F1, F2, …) | Hz/s^N for FN
Dispersion (DM, DM1, DMX_*, CM, CMX_*) | pc/cm^3
Orbital period (PB) | day
Projected semi-major axis (A1) | ls (light-seconds)
Companion and total mass (M2, MTOT) | Msun
TOA error scaling (EQUAD, ECORR) | seconds
Frequencies (WXFREQ_*, DMWXFREQ_*, …) | 1/day
Delay amplitudes (WXSIN_*, WXCOS_*, FD*, …) | seconds
Dimensionless (EFAC, SINI, ECC, STIGMA, …) | –
Everything else | .par native
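Components assume these internal units unconditionally, so conversion has to happen before values reach a ParameterVector. The sketch below shows three such conversions under stated assumptions. It assumes RAJ and DECJ arrive as tempo-style sexagesimal strings (hh:mm:ss and dd:mm:ss). It also assumes OMDOT arrives in deg/yr, with a Julian year of 365.25 days. The function names are hypothetical and are not the jaxpint bridge API.

```python
import math

SECONDS_PER_DAY = 86400.0
DAYS_PER_JULIAN_YEAR = 365.25  # assumption: Julian year (astropy's `yr`)


def hms_to_rad(hms: str) -> float:
    """Right ascension 'hh:mm:ss.s' -> radians (1 hour = 15 degrees)."""
    h, m, s = (float(part) for part in hms.split(":"))
    return math.radians(15.0 * (h + m / 60.0 + s / 3600.0))


def dms_to_rad(dms: str) -> float:
    """Declination '[+-]dd:mm:ss.s' -> radians; the sign applies to all fields."""
    sign = -1.0 if dms.strip().startswith("-") else 1.0
    d, m, s = (abs(float(part)) for part in dms.split(":"))
    return sign * math.radians(d + m / 60.0 + s / 3600.0)


def omdot_to_rad_per_s(omdot_deg_per_yr: float) -> float:
    """Periastron advance deg/yr -> rad/s."""
    return math.radians(omdot_deg_per_yr) / (DAYS_PER_JULIAN_YEAR * SECONDS_PER_DAY)
```

The sign is read from the string rather than from the degrees field. This keeps a declination such as "-00:30:00" negative.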
Epoch integer MJD days are stored separately in epoch_int_values to preserve precision; only the fractional day enters values.
- raises ValueError:
If the lengths of frozen_mask, units, or values don't match names, or if epoch_int_values contains keys not in names.
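The epoch split and the ValueError conditions above can be sketched in plain Python. Everything here is hypothetical. The epoch-name rule covers the four named epochs plus any GLEP_* name, as listed in the class description. Values are assumed to arrive as .par-style strings. The functions are illustrations, not the bridge or constructor code.

```python
from decimal import ROUND_FLOOR, Decimal

NAMED_EPOCHS = {"PEPOCH", "T0", "TASC", "POSEPOCH"}


def is_epoch(name):
    return name in NAMED_EPOCHS or name.startswith("GLEP_")


def split_epochs(names, raw_strings):
    """Return (values, epoch_int_values) with epochs reduced to fractional days."""
    values, epoch_int_values = [], {}
    for name, raw in zip(names, raw_strings):
        if is_epoch(name):
            mjd = Decimal(raw)
            int_day = mjd.to_integral_value(rounding=ROUND_FLOOR)
            frac = float(mjd - int_day)
            if frac >= 1.0:  # float rounding can push 0.999... up to 1.0
                int_day, frac = int_day + 1, 0.0
            epoch_int_values[name] = int(int_day)  # static, never differentiated
            values.append(frac)  # only the fractional day enters values
        else:
            values.append(float(raw))
    return values, epoch_int_values


def recombine_epoch(name, values, names, epoch_int_values):
    """Full MJD as a Decimal: integer day plus the stored fractional day."""
    return Decimal(epoch_int_values[name]) + Decimal(values[names.index(name)])


def validate(names, values, frozen_mask, units, epoch_int_values):
    """Raise ValueError on the mismatches listed under 'raises'."""
    for label, seq in (("frozen_mask", frozen_mask), ("units", units), ("values", values)):
        if len(seq) != len(names):
            raise ValueError(f"len({label}) = {len(seq)}, expected {len(names)}")
    unknown = set(epoch_int_values) - set(names)
    if unknown:
        raise ValueError(f"epoch_int_values keys not in names: {sorted(unknown)}")
```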
- values: Float[Array, 'n_params']
- param_value(name)
Value of a single parameter. JIT-compatible if name is a static string.
- Parameters:
name (str) – Parameter name.
- Returns:
The parameter's current value.
- Return type:
scalar
- param_value_or(name, default=0.0)
Value of a parameter if name is not None, otherwise default.
Convenient for optional parameters stored as Optional[str] field names on components:
pbdot = params.param_value_or(self.pbdot_name, 0.0)
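A name-based lookup like param_value can stay JIT-compatible because the string-to-index mapping is ordinary Python. That mapping runs once at trace time, and only the resulting integer index touches the array. The class below is a hypothetical NumPy stand-in, not the real implementation. Its _name_to_index mirrors the field in the signature, and it also shows the param_value_or fallback.

```python
import numpy as np


class ParamsSketch:
    """Hypothetical stand-in for ParameterVector's lookup methods."""

    def __init__(self, names, values):
        self.names = tuple(names)
        self.values = np.asarray(values, dtype=np.float64)
        # Static metadata: a plain dict built once, never traced
        self._name_to_index = {name: i for i, name in enumerate(self.names)}

    def param_value(self, name):
        # String -> int happens in Python; the array sees a static index
        return self.values[self._name_to_index[name]]

    def param_value_or(self, name, default=0.0):
        # Optional parameters are represented by name=None
        return default if name is None else self.param_value(name)


params = ParamsSketch(["F0", "F1", "PBDOT"], [100.0, -1.0e-15, 2.0e-13])
pbdot_name = None  # e.g. a component whose model has no PBDOT
pbdot = params.param_value_or(pbdot_name, 0.0)
```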
- epoch_value(name)
For epoch parameters: returns (integer_mjd_day, fractional_day).
The full MJD is integer_mjd_day + fractional_day. Only fractional_day is differentiable.
- epoch_dual(name)
For epoch parameters: returns a DualFloat(integer_mjd, fractional_day).
The full MJD is result.int + result.frac. Only the fractional part is differentiable.
- Parameters:
name (str) – Epoch parameter name (e.g. "PEPOCH", "T0").
- Returns:
Extended-precision epoch value.
- Return type:
DualFloat
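An epoch's integer/fractional pair from epoch_value (or the DualFloat from epoch_dual) is meant to be combined with the matching tdb_int/tdb_frac split in TOAData. Like parts are differenced before anything is scaled. A NumPy sketch follows; the function name and the example values are hypothetical.

```python
import numpy as np

SECONDS_PER_DAY = 86400.0


def seconds_since_epoch(tdb_int, tdb_frac, epoch_int, epoch_frac):
    """TOA time minus an epoch, in seconds, without ever forming a full MJD."""
    # Whole days are exact integers in float64, so this difference is exact
    d_int = np.asarray(tdb_int, dtype=np.float64) - epoch_int
    # Fractional days are both in [0, 1), so this difference lies in (-1, 1)
    d_frac = np.asarray(tdb_frac, dtype=np.float64) - epoch_frac
    return d_int * SECONDS_PER_DAY + d_frac * SECONDS_PER_DAY


# Two TOAs against a PEPOCH of MJD 58000.0, split as (58000, 0.0)
dt = seconds_since_epoch([58001.0, 58000.0], [0.5, 0.25], 58000.0, 0.0)
```

With jax.numpy in place of NumPy, the same arithmetic should carry over under jax.jit. Only the fractional terms are meant to carry gradients.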
- free_mask_array()
Boolean array: True where the parameter is free (not frozen).
- Return type:
Bool[Array, 'n_params']
- free_indices_array()
Integer indices of the free parameters as a JAX array (JIT-safe).
- Return type:
Int[Array, 'n_free']
- free_values()
Extract the values of the free (unfrozen) parameters.
- Return type:
Float[Array, 'n_free']
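The free/frozen bookkeeping reduces to four steps: a boolean mask, its integer positions, a gather, and, for with_free_values, a scatter into a copy. Below is a NumPy sketch with hypothetical function names. In JAX, which has immutable arrays, the scatter would be written functionally as values.at[idx].set(new_free).

```python
import numpy as np


def free_mask(frozen_mask):
    """True where a parameter is free (the complement of frozen_mask)."""
    return ~np.asarray(frozen_mask, dtype=bool)


def free_indices(frozen_mask):
    """Integer positions of the free parameters."""
    return np.flatnonzero(free_mask(frozen_mask))


def free_values(values, frozen_mask):
    """Gather: the free entries of values, in their original order."""
    return np.asarray(values, dtype=np.float64)[free_indices(frozen_mask)]


def with_free_values(values, frozen_mask, new_free):
    """Scatter new_free into a copy of values; the input is left untouched."""
    out = np.array(values, dtype=np.float64, copy=True)
    out[free_indices(frozen_mask)] = new_free
    return out


values = [100.0, -1.0e-15, 0.25, 12.5]
frozen = [False, True, True, False]
x = free_values(values, frozen)  # the vector an optimiser would see
updated = with_free_values(values, frozen, x * 2.0)
```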
- with_free_values(new_free)
Return a new ParameterVector with the free parameter values replaced.
- Parameters:
new_free (array, shape (n_free,)) – Replacement values for the free (unfrozen) parameters.
- Returns:
Copy with updated free-parameter values.
- Return type: