jaxpint.types

Core data types for JaxPINT.

Defines the three foundational types:

  • DualFloat: extended-precision value stored as integer + fractional parts

  • TOAData: pre-extracted TOA data as JAX arrays

  • ParameterVector: timing model parameters as a flat JAX array with metadata

All types are equinox Modules (automatic JAX pytrees) and are compatible with jax.jit, jax.grad, jax.vmap, etc.
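Why split a timestamp into integer and fractional parts? The plain-Python check below does not use DualFloat itself, and the epochs in it are made up. It compares float64 resolution near MJD 60000 with resolution inside a single day. It then shows a 1 ns offset that a single float loses but an (integer day, fractional day) pair keeps.

```python
import math

SECONDS_PER_DAY = 86400.0

# Spacing between adjacent float64 values near a modern MJD (~60000 days)
# versus near a fractional day in [0, 1).
ulp_full_seconds = math.ulp(60000.0) * SECONDS_PER_DAY  # 2**-37 days
ulp_frac_seconds = math.ulp(0.5) * SECONDS_PER_DAY      # 2**-53 days

# Two made-up epochs 1 ns apart, held as (integer day, fractional day) pairs.
a_int, a_frac = 60000.0, 0.25
b_int, b_frac = 60000.0, 0.25 + 1e-9 / SECONDS_PER_DAY

# Split arithmetic: integer and fractional parts are subtracted separately.
dt_pair = ((b_int - a_int) + (b_frac - a_frac)) * SECONDS_PER_DAY

# Single-float arithmetic: the 1 ns offset is rounded away when a_int + a_frac is formed.
dt_single = ((b_int + b_frac) - (a_int + a_frac)) * SECONDS_PER_DAY

print(f"full-MJD resolution: {ulp_full_seconds:.2e} s")
print(f"frac-day resolution: {ulp_frac_seconds:.2e} s")
print(f"split difference:    {dt_pair:.3e} s")
print(f"single difference:   {dt_single:.3e} s")
```

Near MJD 60000 a lone float64 resolves only about 0.6 µs. A fractional day in [0, 1) resolves about 10 ps or better, which is the precision the split representation preserves.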

class jaxpint.types.TOAData(mjd_int, mjd_frac, tdb_int, tdb_frac, error, freq, delta_pulse_number, ssb_obs_pos, ssb_obs_vel, obs_sun_pos, obs_indices, flag_masks, planet_positions, dm_values, dm_errors, tropo_alt, tropo_alt_valid, obs_geodetic_lat, obs_height_km, n_toas, obs_names, tzr_tdb_int=None, tzr_tdb_frac=None, tzr_freq=None, tzr_ssb_obs_pos=None, tzr_obs_sun_pos=None)

Bases: Module

Pre-extracted TOA data as JAX arrays.

Created by the bridge layer from PINT TOAs objects. All astropy units are stripped; see unit conventions below.

Unit conventions (enforced by the bridge, not by this class):

  • mjd_int, mjd_frac: days (integer MJD + fractional day in [0, 1))

  • tdb_int, tdb_frac: days (TDB timescale, same split)

  • error: seconds

  • freq: MHz (barycentric, Doppler-corrected)

  • ssb_obs_pos: km, shape (n_toas, 3)

  • ssb_obs_vel: km/s, shape (n_toas, 3)

  • obs_sun_pos: km, shape (n_toas, 3)

  • delta_pulse_number: dimensionless (cycles)

  • dm_values, dm_errors: pc/cm^3
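The bridge layer is responsible for producing arrays in these units. The NumPy sketch below shows two such preparation steps under stated assumptions. It keeps the (integer, fractional) split normalized so the fraction stays in [0, 1), and it converts TOA uncertainties from microseconds (the usual .tim-file unit) to seconds. normalize_split is a hypothetical helper, not the actual bridge code.

```python
import numpy as np

def normalize_split(day_int, day_frac):
    """Hypothetical helper: carry whole days out of day_frac so that 0 <= frac < 1.

    Useful after any shift that pushes the fractional part outside [0, 1).
    Both parts stay float64, matching the Float arrays documented above.
    """
    day_int = np.asarray(day_int, dtype=np.float64)
    day_frac = np.asarray(day_frac, dtype=np.float64)
    carry = np.floor(day_frac)
    int_out = day_int + carry
    frac_out = day_frac - carry
    # Rounding can leave frac_out == 1.0 for tiny negative inputs; carry once more.
    spill = frac_out >= 1.0
    return int_out + spill, np.where(spill, frac_out - 1.0, frac_out)


# Fractional parts that spilled past either end of the day.
mjd_int, mjd_frac = normalize_split([58000.0, 58000.0], [1.25, -0.25])

# Assumed input: uncertainties in microseconds; TOAData.error is in seconds.
error_us = np.array([1.5, 0.8])
error = error_us * 1e-6

# Observatory positions must have shape (n_toas, 3), in km.
ssb_obs_pos = np.zeros((mjd_int.shape[0], 3))
```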

Parameters:
  • mjd_int (Float[Array, 'n_toas'])

  • mjd_frac (Float[Array, 'n_toas'])

  • tdb_int (Float[Array, 'n_toas'])

  • tdb_frac (Float[Array, 'n_toas'])

  • error (Float[Array, 'n_toas'])

  • freq (Float[Array, 'n_toas'])

  • delta_pulse_number (Float[Array, 'n_toas'])

  • ssb_obs_pos (Float[Array, 'n_toas 3'])

  • ssb_obs_vel (Float[Array, 'n_toas 3'])

  • obs_sun_pos (Float[Array, 'n_toas 3'])

  • obs_indices (Int[Array, 'n_toas'])

  • flag_masks (dict[str, Bool[Array, 'n_toas']])

  • planet_positions (dict[str, Float[Array, 'n_toas 3']] | None)

  • dm_values (Float[Array, 'n_toas'] | None)

  • dm_errors (Float[Array, 'n_toas'] | None)

  • tropo_alt (Float[Array, 'n_toas'] | None)

  • tropo_alt_valid (Bool[Array, 'n_toas'] | None)

  • obs_geodetic_lat (Float[Array, 'n_toas'] | None)

  • obs_height_km (Float[Array, 'n_toas'] | None)

  • n_toas (int)

  • obs_names (tuple[str, ...])

  • tzr_tdb_int (float | None)

  • tzr_tdb_frac (float | None)

  • tzr_freq (float | None)

  • tzr_ssb_obs_pos (Float[Array, '3'] | None)

  • tzr_obs_sun_pos (Float[Array, '3'] | None)

mjd_int: Float[Array, 'n_toas']
mjd_frac: Float[Array, 'n_toas']
tdb_int: Float[Array, 'n_toas']
tdb_frac: Float[Array, 'n_toas']
error: Float[Array, 'n_toas']
freq: Float[Array, 'n_toas']
delta_pulse_number: Float[Array, 'n_toas']
ssb_obs_pos: Float[Array, 'n_toas 3']
ssb_obs_vel: Float[Array, 'n_toas 3']
obs_sun_pos: Float[Array, 'n_toas 3']
obs_indices: Int[Array, 'n_toas']
flag_masks: dict[str, Bool[Array, 'n_toas']]
planet_positions: dict[str, Float[Array, 'n_toas 3']] | None
dm_values: Float[Array, 'n_toas'] | None
dm_errors: Float[Array, 'n_toas'] | None
tropo_alt: Float[Array, 'n_toas'] | None
tropo_alt_valid: Bool[Array, 'n_toas'] | None
obs_geodetic_lat: Float[Array, 'n_toas'] | None
obs_height_km: Float[Array, 'n_toas'] | None
n_toas: int
obs_names: tuple[str, ...]
property tdb: DualFloat

TDB timestamp as a DualFloat (int day + fractional day).

property mjd: DualFloat

MJD timestamp as a DualFloat (int day + fractional day).
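The tdb and mjd properties expose the split arrays as DualFloat pairs. The plain NumPy sketch below shows how such pairs are typically consumed: integer parts and fractional parts are subtracted separately, so the large day numbers cancel exactly before the small fractions are combined. seconds_since is a hypothetical name, and the reference epoch could be, for example, the TZR time or an epoch parameter. The real DualFloat type may provide its own arithmetic.

```python
import numpy as np

SECONDS_PER_DAY = 86400.0

def seconds_since(tdb_int, tdb_frac, ref_int, ref_frac):
    """Hypothetical helper: (TOA - reference) in seconds via split-day arithmetic."""
    d_int = np.asarray(tdb_int) - ref_int     # exact for integer-valued floats
    d_frac = np.asarray(tdb_frac) - ref_frac  # small numbers keep full float64 precision
    return (d_int + d_frac) * SECONDS_PER_DAY


# Illustrative TOAs, offsets measured from MJD 58000.5 (TDB).
tdb_int = np.array([58000.0, 58001.0, 58010.0])
tdb_frac = np.array([0.5, 0.25, 0.0])
dt = seconds_since(tdb_int, tdb_frac, ref_int=58000.0, ref_frac=0.5)
```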

tzr_tdb_int: float | None = None
tzr_tdb_frac: float | None = None
tzr_freq: float | None = None
tzr_ssb_obs_pos: Float[Array, '3'] | None = None
tzr_obs_sun_pos: Float[Array, '3'] | None = None
class jaxpint.types.ParameterVector(values, frozen_mask, names, units, epoch_int_values, _name_to_index=<factory>, _free_indices=())

Bases: Module

Timing model parameters as a flat JAX array with metadata.

Stores ALL parameters (free and frozen) in a single array. Epoch-type parameters (PEPOCH, T0, TASC, POSEPOCH, GLEP_*) are split: the integer MJD day is in epoch_int_values (static, not differentiated) and only the fractional day is in values. The bridge layer handles splitting on input and recombining on output.
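The split-and-recombine step described above can be sketched in plain Python. split_epochs and recombine_epochs are hypothetical names, the epoch list is abbreviated, and inputs are plain floats for brevity. A real bridge would start from higher-precision values taken from PINT.

```python
import math

# Abbreviated epoch list; GLEP_* names are matched by prefix below.
EPOCH_NAMES = ("PEPOCH", "T0", "TASC", "POSEPOCH")

def split_epochs(par: dict[str, float]) -> tuple[dict[str, float], dict[str, float]]:
    """Hypothetical: move integer MJD days of epoch parameters into a separate dict."""
    values, epoch_int_values = {}, {}
    for name, value in par.items():
        if name in EPOCH_NAMES or name.startswith("GLEP_"):
            day = math.floor(value)
            epoch_int_values[name] = float(day)  # static, not differentiated
            values[name] = value - day           # fractional day, differentiable
        else:
            values[name] = value
    return values, epoch_int_values

def recombine_epochs(values: dict[str, float], epoch_int_values: dict[str, float]) -> dict[str, float]:
    """Hypothetical inverse: add the integer days back for output."""
    return {name: epoch_int_values.get(name, 0.0) + v for name, v in values.items()}


par = {"F0": 100.0, "PEPOCH": 55000.5, "GLEP_1": 55100.25}
values, epoch_int_values = split_epochs(par)
```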

Pytree: only values is a dynamic leaf (participates in jax.grad). All other fields are static metadata frozen into JIT traces.

Unit conventions

All values are stored as raw float64 in a fixed internal unit system. Components assume these units unconditionally; no unit conversion happens at runtime.

Parameter(s)                                      Unit
Angles (RAJ, DECJ, OM, KIN, KOM, ELONG, ELAT)     radians
Angular rates (OMDOT, XOMDOT)                     rad/s
Proper motion (PMRA, PMDEC, PMELONG, PMELAT)      mas/yr
Parallax (PX)                                     mas
Epochs (PEPOCH, T0, TASC, POSEPOCH, …)            fractional day
Spin frequency (F0, F1, F2, …)                    Hz/s^N for FN
Dispersion (DM, DM1, DMX_*, CM, CMX_*)            pc/cm^3
Orbital period (PB)                               day
Projected semi-major axis (A1)                    ls
Masses (M2, MTOT)                                 Msun
TOA error scaling (EQUAD, ECORR)                  seconds
Frequencies (WXFREQ_*, DMWXFREQ_*, …)             1/day
Delay amplitudes (WXSIN_*, WXCOS_*, FD*, …)       seconds
Dimensionless (EFAC, SINI, ECC, STIGMA, …)        (none)
Everything else                                   .par-native units
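Because components never convert units at runtime, any conversion has to happen before values enter the vector. The sketch below shows a few conversions into the internal units listed above, using hypothetical helper names. It assumes the tempo-style .par conventions of RAJ as hh:mm:ss, DECJ as dd:mm:ss and OMDOT in deg/yr, and it assumes a Julian year of 365.25 days for the OMDOT conversion. PX and PM* are already in mas and mas/yr, so they would pass through unchanged.

```python
import math

SECONDS_PER_JULIAN_YEAR = 365.25 * 86400.0  # assumption: Julian year

def sexagesimal_to_float(text: str) -> float:
    """'hh:mm:ss.s' or '[+-]dd:mm:ss.s' -> decimal hours or degrees."""
    text = text.strip()
    sign = -1.0 if text.startswith("-") else 1.0
    fields = [float(p) for p in text.lstrip("+-").split(":")]
    fields += [0.0] * (3 - len(fields))  # allow 'hh:mm' or 'hh'
    whole, minutes, seconds = fields
    return sign * (whole + minutes / 60.0 + seconds / 3600.0)

def raj_to_rad(text: str) -> float:
    """Hypothetical: RAJ in hours -> radians (1 hour = 15 degrees)."""
    return math.radians(15.0 * sexagesimal_to_float(text))

def decj_to_rad(text: str) -> float:
    """Hypothetical: DECJ in degrees -> radians."""
    return math.radians(sexagesimal_to_float(text))

def omdot_to_rad_per_s(deg_per_yr: float) -> float:
    """Hypothetical: OMDOT in deg/yr -> rad/s."""
    return math.radians(deg_per_yr) / SECONDS_PER_JULIAN_YEAR
```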

Epoch integer MJD days are stored separately in epoch_int_values to preserve precision; only the fractional day enters values.

raises ValueError:

If lengths of frozen_mask, units, or values don’t match names, or if epoch_int_values contains keys not in names.
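The documented ValueError conditions amount to a few length and membership checks. The sketch below re-creates them in plain Python. validate_metadata is a hypothetical name, and the actual constructor may word its errors differently.

```python
def validate_metadata(values, frozen_mask, names, units, epoch_int_values):
    """Hypothetical re-creation of the documented ParameterVector checks."""
    n = len(names)
    for label, seq in (("values", values), ("frozen_mask", frozen_mask), ("units", units)):
        if len(seq) != n:
            raise ValueError(f"len({label})={len(seq)} does not match len(names)={n}")
    unknown = set(epoch_int_values) - set(names)
    if unknown:
        raise ValueError(f"epoch_int_values has keys not in names: {sorted(unknown)}")


# Consistent metadata passes silently.
validate_metadata([100.0, 0.5], (False, True), ("F0", "PEPOCH"), ("Hz", "day"), {"PEPOCH": 55000.0})
```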

values: Float[Array, 'n_params']
frozen_mask: tuple[bool, ...]
names: tuple[str, ...]
units: tuple[str, ...]
epoch_int_values: dict[str, float]
param_index(name)

Index of parameter name in the values array.

Parameters:

name (str) – Parameter name.

Returns:

Zero-based index into values.

Return type:

int

param_value(name)

Value of a single parameter. JIT-compatible if name is a static string.

Parameters:

name (str) – Parameter name.

Returns:

The parameter’s current value.

Return type:

scalar

param_value_or(name, default=0.0)

Value of a parameter if name is not None, otherwise default.

Convenient for optional parameters stored as Optional[str] field names on components:

pbdot = params.param_value_or(self.pbdot_name, 0.0)
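Parameters:
  • name (str | None) – Parameter name, or None if the parameter is not present.

  • default (float) – Value returned when name is None.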
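param_index, param_value and param_value_or reduce to a name-to-index lookup followed by array indexing. The NumPy stand-in below is hypothetical. The `_name_to_index=<factory>` field in the signature suggests a dictionary built at construction, but that is an inference. Because a Python string resolves to a fixed integer before any tracing, the index would be a constant under jax.jit, which is consistent with the "JIT-compatible if name is a static string" note.

```python
from typing import Optional

import numpy as np

class ParamsSketch:
    """Hypothetical stand-in for ParameterVector's lookup methods (NumPy, no JAX)."""

    def __init__(self, names, values):
        self.names = tuple(names)
        self.values = np.asarray(values, dtype=np.float64)
        # Built once at construction; each lookup is then a dict access.
        self._name_to_index = {n: i for i, n in enumerate(self.names)}

    def param_index(self, name: str) -> int:
        return self._name_to_index[name]

    def param_value(self, name: str) -> float:
        return self.values[self.param_index(name)]

    def param_value_or(self, name: Optional[str], default: float = 0.0) -> float:
        # Optional parameters are represented by a field name that may be None.
        return default if name is None else self.param_value(name)


params = ParamsSketch(["F0", "F1", "PB"], [100.0, -1e-15, 1.5])
pbdot = params.param_value_or(None, 0.0)  # e.g. a component with no PBDOT
```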
epoch_value(name)

For epoch parameters: returns (integer_mjd_day, fractional_day).

The full MJD is int_day + frac_day. Only frac_day is differentiable.

Parameters:

name (str) – Epoch parameter name (e.g. "PEPOCH", "T0").

Returns:

(integer_mjd_day, fractional_day).

Return type:

tuple of (float, scalar)

epoch_dual(name)

For epoch parameters: returns a DualFloat(integer_mjd, fractional_day).

The full MJD is result.int + result.frac. Only the fractional part is differentiable.

Parameters:

name (str) – Epoch parameter name (e.g. "PEPOCH", "T0").

Returns:

Extended-precision epoch value.

Return type:

DualFloat
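To see why only the fractional part of an epoch needs to be differentiable, consider spin phase as a Taylor series about PEPOCH, phase = F0·dt + F1·dt²/2 with dt = TOA − PEPOCH in seconds. The plain-Python sketch below treats the integer day as fixed metadata and the fractional day as the variable. It checks the analytic derivative against a central finite difference. The function names are hypothetical, and the jaxpint spindown component may be organized differently; in JAX this derivative would come from jax.grad instead.

```python
SECONDS_PER_DAY = 86400.0

def spin_phase(toa_int, toa_frac, pepoch_int, pepoch_frac, f0, f1):
    """Hypothetical: spin phase in cycles from a two-term Taylor series about PEPOCH.

    pepoch_int plays the role of epoch_int_values (static); pepoch_frac plays the
    role of the fractional day stored in values (differentiable).
    """
    dt = ((toa_int - pepoch_int) + (toa_frac - pepoch_frac)) * SECONDS_PER_DAY
    return f0 * dt + 0.5 * f1 * dt**2

def dphase_dpepoch_frac(toa_int, toa_frac, pepoch_int, pepoch_frac, f0, f1):
    """Analytic d(phase)/d(pepoch_frac), in cycles per day."""
    dt = ((toa_int - pepoch_int) + (toa_frac - pepoch_frac)) * SECONDS_PER_DAY
    return -(f0 + f1 * dt) * SECONDS_PER_DAY


# Illustrative values: one day after PEPOCH, F0 = 100 Hz, F1 = -1e-14 Hz/s.
args = (55001.0, 0.0, 55000.0)
analytic = dphase_dpepoch_frac(*args, 0.0, 100.0, -1e-14)
h = 1e-6  # step in days
numeric = (spin_phase(*args, h, 100.0, -1e-14) - spin_phase(*args, -h, 100.0, -1e-14)) / (2 * h)
```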

free_mask_array()

Boolean array: True where parameter is free (not frozen).

Return type:

Bool[Array, 'n_params']

free_indices_array()

Integer indices of free parameters as a JAX array (JIT-safe).

Return type:

Int[Array, 'n_free']

free_values()

Extract values of free (unfrozen) parameters.

Return type:

Float[Array, 'n_free']

free_names()

Names of free parameters (Python-level, not JIT-compatible).

Return type:

tuple[str, ...]
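The four free-parameter accessors can all be derived from the static frozen_mask. The NumPy sketch below uses illustrative parameter values and mirrors each accessor in a comment; it is not the jaxpint implementation. The fixed-length integer index array is presumably what makes free_indices_array JIT-safe. Inside jax.jit, JAX cannot index with a traced boolean mask because the result shape would depend on data, while integer indices derived from a static tuple have a known length.

```python
import numpy as np

names = ("F0", "F1", "PEPOCH", "DM")
frozen_mask = (False, False, True, False)
values = np.array([100.0, -1e-15, 0.5, 10.0])

free_mask = ~np.asarray(frozen_mask)   # mirrors free_mask_array()
free_idx = np.flatnonzero(free_mask)   # mirrors free_indices_array()
free_vals = values[free_idx]           # mirrors free_values()
free_names = tuple(n for n, frozen in zip(names, frozen_mask) if not frozen)  # mirrors free_names()
```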

with_free_values(new_free)

Return a new ParameterVector with free parameter values replaced.

Parameters:

new_free (array, shape (n_free,)) – Replacement values for the free (unfrozen) parameters.

Returns:

Copy with updated free-parameter values.

Return type:

ParameterVector

with_value(name, val)

Return a new ParameterVector with one parameter updated.

Parameters:
  • name (str) – Parameter name.

  • val (float) – New value.

Returns:

Copy with the specified parameter updated.

Return type:

ParameterVector
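Both with_free_values and with_value return a modified copy rather than mutating in place, which matches how JAX arrays work. JAX arrays are immutable, and the idiomatic update is `values.at[idx].set(new)`, which returns a new array. The frozen-dataclass stand-in below is hypothetical and uses NumPy copies in place of `.at[...].set(...)`.

```python
from dataclasses import dataclass, replace

import numpy as np

@dataclass(frozen=True)
class PVSketch:
    """Hypothetical immutable stand-in for ParameterVector (NumPy, no JAX)."""
    values: np.ndarray
    frozen_mask: tuple
    names: tuple

    def with_free_values(self, new_free):
        free_idx = np.flatnonzero(~np.asarray(self.frozen_mask))
        new_values = self.values.copy()    # never mutate the original array
        new_values[free_idx] = new_free    # JAX: values.at[free_idx].set(new_free)
        return replace(self, values=new_values)

    def with_value(self, name, val):
        new_values = self.values.copy()
        new_values[self.names.index(name)] = val  # JAX: values.at[i].set(val)
        return replace(self, values=new_values)


pv = PVSketch(np.array([100.0, -1e-15, 0.5]), (False, False, True), ("F0", "F1", "PEPOCH"))
pv_fit = pv.with_free_values(np.array([101.0, -2e-15]))
pv_epoch = pv.with_value("PEPOCH", 0.75)
```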

property n_params: int
property n_free: int
Parameters: