jaxpint.bridge

Bridge layer: converts PINT objects to JaxPINT JAX-native types.

PINT’s role is limited to I/O-side tasks: .par/.tim parsing, the observatory database, clock corrections, ephemeris lookups, and coordinate transforms. JaxPINT owns all numerical computation. This module is the boundary between the two, and the only place that touches astropy units. It runs once per fit setup; after conversion, everything is a plain float64 array whose units are fixed by convention (see Plans/Units.md for the unit contract).
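The float64 convention has one numerical consequence worth seeing concretely. The short check below uses only the standard library (math.ulp needs Python 3.9+) and is an illustration, not JaxPINT code: it compares the resolution of a single float64 MJD near 60000 with that of a fractional day on its own. The gap suggests one plausible reason epochs are carried as an integer day plus a fractional day.

```python
import math

SECONDS_PER_DAY = 86400.0
ONE_NS_IN_DAYS = 1e-9 / SECONDS_PER_DAY

# Gap between adjacent float64 values near MJD 60000, converted to seconds.
single_mjd_resolution_s = math.ulp(60000.0) * SECONDS_PER_DAY   # about 6.3e-7 s

# Gap between adjacent float64 values near a fractional day of 0.5, in seconds.
frac_day_resolution_s = math.ulp(0.5) * SECONDS_PER_DAY         # about 9.6e-12 s

# A 1 ns offset is rounded away when added to a single float64 MJD...
lost_in_full_mjd = (60000.5 + ONE_NS_IN_DAYS) == 60000.5
# ...but survives when added to the fractional day alone.
kept_in_fraction = (0.5 + ONE_NS_IN_DAYS) != 0.5

print(single_mjd_resolution_s, frac_day_resolution_s, lost_in_full_mjd, kept_in_fraction)
```

At roughly 0.6 µs per step, a single float64 MJD cannot represent nanosecond-level offsets, while the fractional part alone resolves about 10 ps.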

jaxpint.bridge.build_timing_model(pint_model, toas=None)

Construct a JaxPINT TimingModel from a PINT model.

Converts the PINT model to a ParResult and optional TOAData, then delegates to jaxpint.bridge._model_builder.build_model().

Parameters:
  • pint_model (pint.models.TimingModel) – The PINT timing model to convert.

  • toas (pint.toa.TOAs, optional) – If provided, TOA-dependent components (ECORR, red noise, etc.) will be constructed.

Returns:

A tuple of the timing model and a NoiseModel that aggregates all noise sources (white noise and correlated components).

Return type:

(TimingModel, NoiseModel)

jaxpint.bridge.extract_tzr_toa(model, toas)

Extract the TZR TOA data from PINT’s AbsPhase component.

If the model does not already have an AbsPhase component, one is auto-generated from the TOAs (first TOA after PEPOCH), matching PINT’s guarantee in timing_model.phase().

Parameters:
  • model (pint.models.TimingModel) – PINT timing model, which should contain an AbsPhase component (or one will be auto-generated).

  • toas (pint.toa.TOAs) – The TOA set used to generate the TZR TOA if AbsPhase is absent.

Returns:

Dictionary with keys:

  • tdb_int (float) – Integer MJD day of the TZR TOA in TDB.

  • tdb_frac (float) – Fractional MJD day of the TZR TOA in TDB.

  • freq (float) – Observing frequency in MHz.

  • ssb_obs_pos (numpy.ndarray, shape (3,)) – Observer position relative to the solar-system barycenter (SSB), in km.

  • obs_sun_pos (numpy.ndarray, shape (3,)) – Observer-to-Sun position in km (zeros for barycentric observations).

Return type:

dict
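A minimal sketch of the selection and packaging that extract_tzr_toa describes, assuming the TOA quantities are already plain arrays (TDB split into integer and fractional days, positions in km). The function name build_tzr_dict and its arguments are hypothetical; only the returned keys come from the documentation above. The sketch reads "first TOA after PEPOCH" as the earliest TOA strictly later than PEPOCH and raises if there is none, which may differ from the real implementation.

```python
import numpy as np


def build_tzr_dict(tdb_int, tdb_frac, freqs_mhz, ssb_obs_pos_km, obs_sun_pos_km,
                   barycentric, pepoch_int, pepoch_frac):
    """Hypothetical sketch: choose the TZR TOA and package it with the
    keys documented for extract_tzr_toa."""
    tdb_int = np.asarray(tdb_int, dtype=np.float64)
    tdb_frac = np.asarray(tdb_frac, dtype=np.float64)
    # Compare (integer day, fractional day) pairs without forming one float64 MJD.
    after = (tdb_int > pepoch_int) | ((tdb_int == pepoch_int) & (tdb_frac > pepoch_frac))
    candidates = np.flatnonzero(after)
    if candidates.size == 0:
        raise ValueError("no TOA falls after PEPOCH")
    # Earliest candidate: np.lexsort uses the last key as the primary sort key.
    order = np.lexsort((tdb_frac[candidates], tdb_int[candidates]))
    i = int(candidates[order[0]])
    return {
        "tdb_int": float(tdb_int[i]),
        "tdb_frac": float(tdb_frac[i]),
        "freq": float(freqs_mhz[i]),
        "ssb_obs_pos": np.asarray(ssb_obs_pos_km[i], dtype=np.float64),
        # Barycentric TOAs get a zero observer-to-Sun vector.
        "obs_sun_pos": (np.zeros(3) if barycentric[i]
                        else np.asarray(obs_sun_pos_km[i], dtype=np.float64)),
    }


tzr = build_tzr_dict(
    tdb_int=[59999, 60001, 60000],
    tdb_frac=[0.9, 0.1, 0.75],
    freqs_mhz=[1400.0, 820.0, 1500.0],
    ssb_obs_pos_km=np.arange(9.0).reshape(3, 3),
    obs_sun_pos_km=np.ones((3, 3)),
    barycentric=[False, False, True],
    pepoch_int=60000,
    pepoch_frac=0.5,
)
```

Here the third TOA (MJD 60000.75) is picked because it is the earliest one after PEPOCH 60000.5, even though it is not first in array order.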

jaxpint.bridge.params_to_pint_model(params, model)

Write JaxPINT parameter values back into a PINT TimingModel.

Modifies model in-place and returns it. The caller should copy the model first (copy.deepcopy(model)) if the original must be preserved.

Parameters:
  • params (ParameterVector) – The (possibly fitted) parameter values.

  • model (pint.models.TimingModel) – The PINT model to update.

Returns:

The same model instance, modified in-place with the updated parameter values (angles converted back to degrees, epochs reconstructed from integer + fractional day, etc.).

Return type:

pint.models.TimingModel
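The write-back direction, and the reason to copy first, can be sketched with a plain dict standing in for the PINT model. The helper write_back and its angle_names/epoch_names arguments are hypothetical; like params_to_pint_model, it mutates the model it receives and returns that same object.

```python
import copy
import math


def write_back(params: dict, model: dict, angle_names=frozenset(), epoch_names=frozenset()) -> dict:
    """Hypothetical sketch: mutate `model` in place and return the same object."""
    for name, value in params.items():
        if name in angle_names:
            model[name] = math.degrees(value)      # radians -> degrees
        elif name in epoch_names:
            day_int, day_frac = value
            # A plain float cannot hold every (int, frac) pair exactly; PINT
            # keeps epochs at higher precision than this toy dict does.
            model[name] = day_int + day_frac
        else:
            model[name] = value
    return model


original = {"OM": 0.0, "PEPOCH": 55000.0, "F0": 100.0}
working = copy.deepcopy(original)                  # copy first so `original` survives
fitted = {"OM": math.pi / 2, "PEPOCH": (55000, 0.25), "F0": 100.5}
updated = write_back(fitted, working, angle_names={"OM"}, epoch_names={"PEPOCH"})
```

Because the returned object is the one passed in, skipping the deepcopy would overwrite the caller's only copy of the pre-fit values.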

jaxpint.bridge.pint_model_to_params(model)

Convert a PINT TimingModel to a JaxPINT ParResult.

Iterates over all parameters. Values that are not floating-point (str, bool, int, func) are left out of the ParameterVector; the str, bool, and int values are collected into the metadata, bool_params, and int_params dicts, respectively. MJD epochs are split into a static integer day and a dynamic fractional day. Angles are converted to radians.

Parameters:
  • model (pint.models.TimingModel) – The timing model to extract parameters from.

Returns:

A container holding the ParameterVector, the set of detected timing-model components, optional binary model identifier, string metadata, mask info, integer parameters, and boolean parameters extracted from the PINT model.

Return type:

ParResult
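A sketch of the type-based partitioning described above, using a plain dict of raw values in place of a PINT model. Partitioned, partition, and the toy convention that epochs arrive as decimal strings are all assumptions made for illustration; the real function works on PINT parameter objects. One detail any isinstance-based implementation has to get right: bool is a subclass of int, so the bool check must come before the int check.

```python
import math
from dataclasses import dataclass, field


@dataclass
class Partitioned:
    """Hypothetical stand-in for a subset of ParResult's fields."""
    floats: dict = field(default_factory=dict)      # values destined for the float vector
    epoch_int: dict = field(default_factory=dict)   # static integer days of epochs
    metadata: dict = field(default_factory=dict)    # str-valued parameters
    int_params: dict = field(default_factory=dict)
    bool_params: dict = field(default_factory=dict)


def partition(raw: dict, angle_names=frozenset(), epoch_names=frozenset()) -> Partitioned:
    out = Partitioned()
    for name, value in raw.items():
        if name in epoch_names:
            # Toy convention: epochs arrive as decimal strings such as "55000.25",
            # so the integer day is split off exactly before any float conversion.
            day_str, _, frac_str = value.partition(".")
            out.epoch_int[name] = int(day_str)                   # static integer day
            out.floats[name] = float("0." + (frac_str or "0"))   # dynamic fractional day
        elif isinstance(value, bool):   # must precede the int check: bool subclasses int
            out.bool_params[name] = value
        elif isinstance(value, int):
            out.int_params[name] = value
        elif isinstance(value, str):
            out.metadata[name] = value
        elif name in angle_names:
            out.floats[name] = math.radians(value)               # degrees -> radians
        else:
            out.floats[name] = float(value)
    return out


raw = {"PSR": "J0000+0000", "PEPOCH": "55000.25", "F0": 100.0,
       "OM": 180.0, "NTOA": 5, "PLANET_SHAPIRO": True}
result = partition(raw, angle_names={"OM"}, epoch_names={"PEPOCH"})
```

Swapping the bool and int branches would silently file PLANET_SHAPIRO under int_params.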

jaxpint.bridge.pint_toas_to_jax(toas, model=None)

Convert PINT TOAs to a JaxPINT TOAData.

All unit conversion and validation happens here. After this call everything is raw float64 JAX arrays following the conventions documented in jaxpint.types.TOAData.

Parameters:
  • toas (pint.toa.TOAs) – Should already have had compute_TDBs() and compute_posvels() called; if not, this function calls them.

  • model (pint.models.TimingModel, optional) – If provided, pre-computes boolean flag masks for all maskParameter instances (JUMP, EFAC, EQUAD, DMX, etc.) and extracts TZR TOA data for absolute phase computation.

Returns:

A frozen container of JAX float64 arrays holding MJD times, TDB times, uncertainties, frequencies, SSB positions/velocities, observatory indices, flag masks, and optional fields (planet positions, wideband DM, troposphere data, TZR TOA).

Return type:

TOAData
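The flag masks that pint_toas_to_jax precomputes can be pictured as one boolean array per selection. The sketch below is an assumption about the general idea, not JaxPINT's code: flag_mask is hypothetical, flags are modeled as one dict per TOA, and the JUMP value is made up. Once the mask exists, applying a mask-style parameter reduces to array arithmetic.

```python
import numpy as np


def flag_mask(toa_flags, key, value):
    """Hypothetical sketch: True where a TOA's flag `key` equals `value`."""
    return np.array([flags.get(key) == value for flags in toa_flags], dtype=bool)


# One flag dict per TOA (toy values).
toa_flags = [{"fe": "L-wide", "be": "PUPPI"}, {"fe": "430"}, {"fe": "L-wide"}]

# Computed once at conversion time...
mask = flag_mask(toa_flags, "fe", "L-wide")

# ...then a JUMP-like offset is pure float64 arithmetic.
jump_seconds = 1.5e-6                          # made-up JUMP value
delay = jump_seconds * mask.astype(np.float64)
```

Doing the string comparisons once on the host, rather than inside traced JAX code, is a plausible motivation for building the masks at conversion time.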

class jaxpint.bridge.ParResult(params, component_set=<factory>, binary_model=None, metadata=<factory>, mask_info=<factory>, int_params=<factory>, bool_params=<factory>)

Bases: object

Complete result of converting a timing model to JaxPINT’s internal format.

Produced by jaxpint.bridge.model_conversion.pint_model_to_params() and consumed by jaxpint.bridge._model_builder.build_model().

Parameters:
  • params: ParameterVector
  • component_set: set[Component]
  • binary_model: BinaryModel | None = None
  • metadata: dict[str, str]
  • mask_info: dict[str, MaskInfo]
  • int_params: dict[str, int]
  • bool_params: dict[str, bool]
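The <factory> entries in the ParResult signature are how Python's dataclasses render fields declared with field(default_factory=...). The mirror class below is hypothetical and uses placeholder types (dict for ParameterVector, str for BinaryModel), but it reproduces that signature shape and shows why factories are used: each instance gets its own containers instead of sharing one mutable default.

```python
import inspect
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ParResultSketch:
    """Hypothetical mirror of ParResult's field layout with placeholder types."""
    params: dict                                      # placeholder for ParameterVector
    component_set: set = field(default_factory=set)
    binary_model: Optional[str] = None                # placeholder for BinaryModel | None
    metadata: dict = field(default_factory=dict)
    mask_info: dict = field(default_factory=dict)
    int_params: dict = field(default_factory=dict)
    bool_params: dict = field(default_factory=dict)


# The generated __init__ signature shows <factory> for default_factory fields.
print(inspect.signature(ParResultSketch))

first = ParResultSketch(params={"F0": 100.0})
second = ParResultSketch(params={})
first.metadata["PSR"] = "J0000+0000"   # does not leak into `second`
```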