jaxpint.pta

PTA likelihood module for JaxPINT.

Composes jaxpint.likelihood.single_pulsar_logL() across multiple pulsars with shared signal injections, such as continuous-wave (CW) sources and a gravitational-wave background (GWB).

class jaxpint.pta.GlobalParams(values, names, _name_to_index)

Bases: Module

Shared parameter container for PTA-level signals.

An Equinox module where values is the only dynamic (traced) leaf. Static metadata (names, index mapping) is fixed at construction time.

Build incrementally via add_params() or the builder pattern:

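```python
from jaxpint.pta import GlobalParams

# signal_injectors: any iterable of SignalInjector instances (e.g. CWInjector, CURNInjector)
gp = GlobalParams.empty()
for inj in signal_injectors:
    gp = inj.register_params(gp)
```

Parameters: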
  • values (Float[Array, 'n_global'])

  • names (tuple[str, ...])

  • _name_to_index (dict[str, int])

values: Float[Array, 'n_global']
names: tuple[str, ...]
static empty()

Create an empty GlobalParams with no parameters.

Return type:

GlobalParams

add_params(names, values)

Append new parameters, returning a new GlobalParams.

Parameters:
  • names (list[str]) – Parameter names to add.

  • values (list[float]) – Initial values for each parameter (same length as names).

Returns:

New instance with the appended parameters.

Return type:

GlobalParams

Raises:

ValueError – If names and values have different lengths, or if any name is already present (prevents silent overwrites from duplicate injectors or prefix collisions).

param_value(name)

Look up a single parameter value by name.

Parameters:

name (str)

Return type:

Float[Array, '']

with_value(name, val)

Return a copy with one parameter replaced.

Parameters:
  • name (str) – Name of the parameter to replace.

  • val – New value for that parameter.

Return type:

GlobalParams

property n_params: int

Total number of parameters.
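The add_params() / with_value() contract can be illustrated with a plain-Python toy. ToyGlobalParams and register_cw_like are hypothetical names. This sketch only mirrors the documented behaviour: immutable updates, ValueError when a name is already present, and prefixes letting several sources coexist. It is not jaxpint's Equinox implementation, which stores values as a JAX array.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyGlobalParams:
    """Hypothetical immutable stand-in that mimics the documented GlobalParams API."""

    values: tuple = ()
    names: tuple = ()

    @staticmethod
    def empty():
        return ToyGlobalParams()

    def add_params(self, names, values):
        names, values = tuple(names), tuple(values)
        if len(names) != len(values):
            raise ValueError("names and values must have the same length")
        clashes = [n for n in names if n in self.names]
        if clashes:
            raise ValueError(f"parameter(s) already present: {clashes}")
        # Return a new instance; self is never modified.
        return ToyGlobalParams(self.values + values, self.names + names)

    def param_value(self, name):
        return self.values[self.names.index(name)]

    def with_value(self, name, val):
        i = self.names.index(name)
        return replace(self, values=self.values[:i] + (val,) + self.values[i + 1:])

    @property
    def n_params(self):
        return len(self.names)


def register_cw_like(gp, prefix):
    """Hypothetical helper: registers prefixed names the way a CW injector might."""
    defaults = {"log10_h": -14.0, "log10_fgw": -8.0}
    return gp.add_params([prefix + k for k in defaults], list(defaults.values()))


gp = ToyGlobalParams.empty()
gp = register_cw_like(gp, "cw0_")
gp = register_cw_like(gp, "cw1_")  # distinct prefix, so no name collision
```

Registering the same prefix twice raises ValueError rather than silently overwriting values. This is the failure mode the duplicate-name check in add_params() is documented to prevent.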

class jaxpint.pta.PTAConfig(toa_data_list, timing_models, noise_models, signal_injectors)

Bases: Module

Static (non-differentiated) configuration for PTA likelihood evaluation.

All fields are marked static=True so that JAX treats the entire config as a compile-time constant.

Raises ValueError at construction if the per-pulsar tuples (toa_data_list, timing_models, noise_models) have mismatched lengths.

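Parameters:
  • toa_data_list (tuple[TOAData, ...])

  • timing_models (tuple[TimingModel, ...])

  • noise_models (tuple[NoiseModel, ...])

  • signal_injectors (tuple[SignalInjector, ...])

toa_data_list: tuple[TOAData, ...]
timing_models: tuple[TimingModel, ...]
noise_models: tuple[NoiseModel, ...]
signal_injectors: tuple[SignalInjector, ...]
property n_pulsars: int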

Number of pulsars in this PTA configuration.

Returns:

Length of toa_data_list.

Return type:

int

class jaxpint.pta.SignalInjector

Bases: ABC

Abstract base class for PTA signal components.

Each injector:

  1. Registers its own parameters into GlobalParams via register_params() (required — abstract).

  2. Produces delay arrays and/or covariance (U, Phi) tuples per pulsar via delay() / covariance() (optional — default implementations return None).

Subclasses must implement register_params(). Override delay() for deterministic signals (e.g. CW) and/or covariance() for stochastic signals (e.g. GWB).

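pta_logL() is agnostic to the signal type. It only calls delay() and covariance() on each injector and never needs to know which kind of signal the injector represents.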

abstractmethod register_params(global_params)

Append this signal’s parameters to global_params.

Parameters:

global_params (GlobalParams) – Accumulator of shared PTA parameters. GlobalParams is immutable, so implementations return an updated copy instead of modifying it in place.

Returns:

Updated copy with this signal’s parameters appended.

Return type:

GlobalParams

delay(p, toa_data, pulsar_params, global_params)

Return deterministic delay for pulsar p, or None.

Override for deterministic signals. The default returns None (no delay contribution).

Parameters:
  • p (int) – Pulsar index within the PTA.

  • toa_data (TOAData) – Pulse time-of-arrival data for pulsar p.

  • pulsar_params (ParameterVector) – Timing and noise parameters for pulsar p.

  • global_params (GlobalParams) – Shared PTA parameters (CW source properties, GWB spectrum, etc.).

Returns:

Deterministic timing delay in seconds, or None if this injector does not contribute a delay.

Return type:

(n_toas,) array or None

covariance(p, toa_data, pulsar_params, global_params)

Return (U, Phi) covariance contribution for pulsar p, or None.

Override for stochastic signals. The default returns None (no covariance contribution).

Parameters:
  • p (int) – Pulsar index within the PTA.

  • toa_data (TOAData) – Pulse time-of-arrival data for pulsar p.

  • pulsar_params (ParameterVector) – Timing and noise parameters for pulsar p.

  • global_params (GlobalParams) – Shared PTA parameters (CW source properties, GWB spectrum, etc.).

Returns:

Design matrix U and diagonal PSD vector Phi, or None if this injector does not contribute covariance.

Return type:

tuple of ((n_toas, n_basis) array, (n_basis,) array) or None

jaxpint.pta.pta_logL(global_params, pulsar_params, config)

Multi-pulsar log-likelihood with signal injections.

For each pulsar, collects delay and covariance contributions from every SignalInjector in config, then delegates to jaxpint.likelihood.single_pulsar_logL().

Parameters:
  • global_params (GlobalParams) – Shared parameters (CW source properties, GWB spectrum, etc.). This is the first differentiable argument.

  • pulsar_params (tuple of ParameterVector) – Per-pulsar timing and noise parameters. This is the second differentiable argument.

  • config (PTAConfig) – Static configuration (TOA data, models, injectors).

Returns:

logL – Sum of per-pulsar log-likelihoods.

Return type:

scalar
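The per-pulsar collection step described for pta_logL() can be sketched in NumPy. The injector classes and collect_contributions are hypothetical.

One assumption in this sketch is how the contributions are merged before single_pulsar_logL() is called: here, several (U, Phi) pairs are combined by concatenating their columns. That choice is exact for independent Gaussian processes.

```python
import numpy as np


class ToyCW:
    """Hypothetical deterministic injector: contributes a delay only."""

    def delay(self, t):
        return 1e-7 * np.sin(2.0 * np.pi * t / 1e7)

    def covariance(self, t):
        return None


class ToyRedNoise:
    """Hypothetical stochastic injector: one sin/cos pair with prior variance amp."""

    def __init__(self, period, amp):
        self.period, self.amp = period, amp

    def delay(self, t):
        return None

    def covariance(self, t):
        phase = 2.0 * np.pi * t / self.period
        U = np.column_stack([np.sin(phase), np.cos(phase)])
        return U, np.full(2, self.amp)


def collect_contributions(t, injectors):
    """Sum non-None delays; stack non-None (U, Phi) pairs column-wise.

    Stacking is exact for independent signals:
    [U1 U2] diag([Phi1, Phi2]) [U1 U2]^T = U1 diag(Phi1) U1^T + U2 diag(Phi2) U2^T.
    """
    total_delay = np.zeros_like(t)
    Us, Phis = [], []
    for inj in injectors:
        d = inj.delay(t)
        if d is not None:
            total_delay = total_delay + d
        cov = inj.covariance(t)
        if cov is not None:
            Us.append(cov[0])
            Phis.append(cov[1])
    if not Us:
        return total_delay, None
    return total_delay, (np.hstack(Us), np.concatenate(Phis))
```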

class jaxpint.pta.CorrelatedPTAConfig(toa_data_list, timing_models, noise_models, signal_injectors, correlated_injectors)

Bases: Module

Static configuration for the correlated PTA likelihood.

Has the same fields as PTAConfig plus a correlated_injectors field for cross-pulsar signals.

All fields are static=True (compile-time constants for JAX).

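Parameters:
  • toa_data_list (tuple[TOAData, ...])

  • timing_models (tuple[TimingModel, ...])

  • noise_models (tuple[NoiseModel, ...])

  • signal_injectors (tuple[SignalInjector, ...])

  • correlated_injectors (tuple[CorrelatedSignalInjector, ...])

toa_data_list: tuple[TOAData, ...]
timing_models: tuple[TimingModel, ...]
noise_models: tuple[NoiseModel, ...]
signal_injectors: tuple[SignalInjector, ...]
correlated_injectors: tuple[CorrelatedSignalInjector, ...]
property n_pulsars: int

Number of pulsars in this PTA configuration.

class jaxpint.pta.CorrelatedSignalInjector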

Bases: ABC

Abstract base class for cross-pulsar correlated signal components.

Unlike SignalInjector, which produces per-pulsar covariance contributions, a CorrelatedSignalInjector provides the ingredients to build a PTA-wide covariance matrix with inter-pulsar correlations.

abstractmethod register_params(global_params)

Append this signal’s parameters to global_params.

Parameters:

global_params (GlobalParams) – Accumulator of shared PTA parameters.

Returns:

Updated copy with this signal’s parameters appended.

Return type:

GlobalParams

abstractmethod get_fourier_basis(toa_data)

Return the Fourier design matrix for a single pulsar.

Parameters:

toa_data (TOAData) – Pulse time-of-arrival data for one pulsar.

Returns:

F – Fourier design matrix (sin/cos columns).

Return type:

(n_toas, n_basis) array

abstractmethod get_psd(global_params)

Return the GWB power spectral density vector.

Parameters:

global_params (GlobalParams) – Shared PTA parameters (amplitude, spectral index, etc.).

Returns:

S – PSD values for each Fourier basis function (sin and cos each get the same value for their frequency).

Return type:

(n_basis,) array

abstractmethod get_orf_matrix()

Return the overlap reduction function matrix.

The matrix must be invertible (full rank). Rank-deficient ORFs such as the monopole (all ones) are not supported by the two-tier Woodbury scheme.

Returns:

Gamma – Symmetric, positive-definite ORF matrix. Gamma[a, b] is the correlation coefficient between pulsars a and b.

Return type:

(n_psr, n_psr) array
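The prior on the stacked Fourier coefficients shows why get_orf_matrix() must return a full-rank matrix. This sketch assumes pulsar-major ordering: all basis functions of pulsar 0, then pulsar 1, and so on. Under that ordering, a GWB with ORF Gamma and PSD vector S has prior covariance Gamma ⊗ diag(S), whose inverse factorises as Gamma^{-1} ⊗ diag(1/S). That inverse exists only when Gamma is invertible.

The function names are hypothetical, and the ordering is an assumption rather than jaxpint's documented layout.

```python
import numpy as np


def gwb_prior(Gamma, S):
    """Prior covariance of all pulsars' Fourier coefficients (pulsar-major order).

    Block (a, b) equals Gamma[a, b] * diag(S): every pulsar shares the spectrum S,
    and the ORF Gamma couples the coefficients of different pulsars.
    """
    return np.kron(Gamma, np.diag(S))


def gwb_prior_inverse(Gamma, S):
    """Inverse via (A kron B)^{-1} = A^{-1} kron B^{-1}; needs Gamma invertible."""
    return np.kron(np.linalg.inv(Gamma), np.diag(1.0 / S))
```

With an all-ones (monopole) Gamma, np.linalg.inv raises LinAlgError. This is the numerical counterpart of the full-rank requirement above.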

jaxpint.pta.pta_logL_correlated(global_params, pulsar_params, config)

Multi-pulsar log-likelihood with cross-pulsar correlations.

Implements a two-tier Woodbury scheme:

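  1. Per-pulsar (inner tier): compute C_p^{-1} r_p, C_p^{-1} F_p, and per-pulsar log-likelihood contributions. Here r_p, F_p and C_p are pulsar p's residuals, its Fourier basis, and its covariance excluding the cross-pulsar correlated signal.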

  2. Cross-pulsar (outer tier): assemble and solve the compressed Sigma_gwb system to account for ORF-mediated correlations.

Parameters:
  • global_params (GlobalParams) – Shared parameters (GWB amplitude/spectral index, CW source, etc.).

  • pulsar_params (tuple of ParameterVector) – Per-pulsar timing and noise parameters.

  • config (CorrelatedPTAConfig) – Static configuration.

Returns:

logL – Log-likelihood value.

Return type:

scalar
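Both tiers rest on the Woodbury identity and the matrix determinant lemma. For C = N + U diag(Phi) U^T, these reduce an n × n factorisation to a k × k one.

The NumPy sketch below checks the single-tier version against a dense computation. The function names are hypothetical and this is not jaxpint's code. In the correlated case the same identity is presumably applied a second time, with the ORF-coupled prior; that outer step is not shown.

```python
import numpy as np


def gaussian_logL_dense(r, C):
    """Reference: log N(r; 0, C) via a dense Cholesky factorisation of C."""
    L = np.linalg.cholesky(C)
    z = np.linalg.solve(L, r)
    logdet = 2.0 * np.sum(np.log(np.diag(L)))
    return -0.5 * (z @ z + logdet + r.size * np.log(2.0 * np.pi))


def gaussian_logL_woodbury(r, N_diag, U, Phi):
    """log N(r; 0, N + U diag(Phi) U^T) using only a (k, k) factorisation.

    N_diag: (n,) white-noise variances, U: (n, k) basis, Phi: (k,) prior variances.
    """
    Ninv_r = r / N_diag
    Ninv_U = U / N_diag[:, None]
    Sigma = np.diag(1.0 / Phi) + U.T @ Ninv_U  # Phi^{-1} + U^T N^{-1} U
    Ls = np.linalg.cholesky(Sigma)
    y = np.linalg.solve(Ls, U.T @ Ninv_r)  # Ls^{-1} U^T N^{-1} r
    chi2 = r @ Ninv_r - y @ y  # r^T C^{-1} r (Woodbury identity)
    logdet = (np.sum(np.log(N_diag)) + np.sum(np.log(Phi))
              + 2.0 * np.sum(np.log(np.diag(Ls))))  # log det C (determinant lemma)
    return -0.5 * (chi2 + logdet + r.size * np.log(2.0 * np.pi))
```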

jaxpint.pta.fisher_matrix(global_params, pulsar_params, config)

Compute the Fisher information matrix via jax.hessian.

Internally flattens all parameters into a single array, wraps pta_logL() in a closure that unflattens, and calls jax.hessian on the flat function.

Parameters:
  • global_params (GlobalParams) – Current global parameter values (evaluation point).

  • pulsar_params (tuple of ParameterVector) – Current per-pulsar parameter values (evaluation point).

  • config (PTAConfig) – Static PTA configuration.

Returns:

fisher – Fisher matrix, where n_total = n_global + sum(n_pp_i) and n_pp_i is the number of parameters of pulsar i.

Return type:

(n_total, n_total) array

jaxpint.pta.flatten_params(global_params, pulsar_params)

Pack all differentiable parameters into a single flat array.

Layout: [global_params.values | pp[0].values | pp[1].values | ...]

Parameters:
  • global_params (GlobalParams) – Shared PTA parameters.

  • pulsar_params (tuple of ParameterVector) – Per-pulsar timing and noise parameters.

Returns:

flat – Concatenated parameter values, where n_total = n_global + sum(n_pp_i) and n_pp_i is the number of parameters of pulsar i.

Return type:

(n_total,) array

jaxpint.pta.unflatten_params(flat, global_template, pulsar_templates)

Unpack a flat array back into structured parameter objects.

Templates carry the static metadata (names, frozen mask, units, etc.). Only .values is replaced from slices of the flat array; everything else is preserved from the template.

Layout must match flatten_params():

flat[0 : n_global]                -> GlobalParams.values
flat[n_global : n_global + n_pp0] -> pulsar_params[0].values
...

Parameters:
  • flat ((n_total,) array) – Flat parameter vector.

  • global_template (GlobalParams) – Template with correct static metadata for the global params.

  • pulsar_templates (tuple of ParameterVector) – Templates with correct static metadata for each pulsar.

Returns:

global_params, pulsar_params – Structured parameter objects whose values are taken from flat.

Return type:

tuple[GlobalParams, tuple[ParameterVector, …]]
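The flat layout and the closure used by fisher_matrix() can be sketched with plain NumPy arrays. This sketch substitutes array sizes for the metadata-carrying templates. flatten_sketch, unflatten_sketch and make_flat_fn are hypothetical names, not jaxpint functions.

```python
import numpy as np


def flatten_sketch(global_values, pulsar_values):
    """Pack [global | psr0 | psr1 | ...] into one 1-D array."""
    return np.concatenate([global_values, *pulsar_values])


def unflatten_sketch(flat, n_global, pulsar_sizes):
    """Split a flat vector back into (global, (psr0, psr1, ...)) using template sizes."""
    if flat.size != n_global + sum(pulsar_sizes):
        raise ValueError("flat vector length does not match the templates")
    pieces, start = [], n_global
    for size in pulsar_sizes:
        pieces.append(flat[start:start + size])
        start += size
    return flat[:n_global], tuple(pieces)


def make_flat_fn(structured_fn, n_global, pulsar_sizes):
    """Wrap fn(global, pulsars) as a function of one flat vector, the form a
    Hessian routine such as jax.hessian differentiates."""
    def flat_fn(flat):
        g, pulsars = unflatten_sketch(flat, n_global, pulsar_sizes)
        return structured_fn(g, pulsars)
    return flat_fn
```

Because the split points depend only on template sizes, they are fixed at trace time. This is why the templates, not the flat array, carry the static metadata.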

class jaxpint.pta.CWInjector(pulsar_positions, dist_param_name='PX', prefix='cw0_', initial_values=None)

Bases: SignalInjector

Injects a single continuous gravitational wave source.

Subclasses SignalInjector. Uses a naming prefix (e.g. 'cw0_', 'cw1_') so that multiple CW sources can coexist in the same GlobalParams.

Parameters:
  • pulsar_positions ((n_psr, 3) array) – Unit vectors pointing to each pulsar.

  • dist_param_name (str) – Name of the parallax parameter in each pulsar’s ParameterVector (default 'PX', in mas). The pulsar-term phase is computed from distance L_kpc = 1 / PX_mas internally (Ellis+2012).

  • prefix (str) – Naming prefix for this source in GlobalParams.

  • initial_values (dict, optional) – Override default initial values. Keys must be in CW_PARAM_DEFAULTS.

param_defaults = {'cos_gwtheta': 0.0, 'cos_inc': 0.0, 'gwphi': 0.0, 'log10_fgw': -8.0, 'log10_h': -14.0, 'phase0': 0.0, 'psi': 0.0}
register_params(global_params)

Register CW source parameters into global_params.

Parameters:

global_params (GlobalParams) – Accumulator of shared PTA parameters (immutable; an updated copy is returned).

Returns:

Updated copy with this CW source’s parameters appended.

Return type:

GlobalParams

delay(p, toa_data, pulsar_params, global_params)

Compute CW delay for pulsar p.

Parameters:
  • p (int) – Pulsar index within the PTA.

  • toa_data (TOAData) – Pulse time-of-arrival data for pulsar p.

  • pulsar_params (ParameterVector) – Timing and noise parameters for pulsar p.

  • global_params (GlobalParams) – Shared PTA parameters containing this source’s CW values.

Returns:

CW timing residual in seconds.

Return type:

(n_toas,) array

jaxpint.pta.cw_delay(toa_data, pos, pulsar_dist, global_params, prefix='cw0_')

CW-induced timing delay for one pulsar (Earth + pulsar term).

Implements the timing residual from Sesana & Vecchio (2010) [cw_sv10] Eq. 5, using the phase-averaging decomposition of Ellis (2013) [cw_e13] Eq. 4. The strain-to-residual scaling alpha = h / (2*pi*f) follows from Detweiler (1979) [cw_d79] Eq. 5.

The pulsar-term phase depends on pulsar distance, which is what makes the Fisher matrix informative for distance constraints.

Parameters:
  • toa_data (TOAData) – Pulse time-of-arrival data (uses TDB timestamps).

  • pos ((3,) array) – Unit vector pointing to the pulsar.

  • pulsar_dist (scalar) – Despite the name, this is the pulsar parallax in mas (types.py convention). It is converted internally to a physical distance in kpc for the Ellis+2012 pulsar-term phase.

  • global_params (GlobalParams) – Shared PTA parameters (accessed by prefixed name).

  • prefix (str) – Naming prefix for this CW source in global_params.

Returns:

delay – CW timing residual in seconds.

Return type:

(n_toas,) array

References

[cw_sv10]

Sesana & Vecchio (2010), PRD 81, 104008.

[cw_e13]

Ellis (2013), CQG 30, 224004.

[cw_d79]

Detweiler (1979), ApJ 234, 1100.
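The distance dependence of the pulsar term comes from its retardation. The pulsar term is evaluated a light-travel time L (1 − cos mu) / c earlier than the Earth term.

This sketch assumes the geometry common in the PTA literature following Ellis, Siemens & Creighton (2012):

  • omhat is the GW propagation direction.
  • cos mu = −omhat · pos.

The function names are hypothetical, and whether cw_delay uses exactly this sign convention is an assumption.

```python
import numpy as np

PC_M = 3.0856775814913673e16  # metres per parsec
C_M_S = 299792458.0  # speed of light in m/s
KPC_S = 1e3 * PC_M / C_M_S  # light-travel time across 1 kpc, about 1.029e11 s


def parallax_to_kpc(px_mas):
    """Distance in kpc from parallax in mas: L_kpc = 1 / PX_mas."""
    return 1.0 / px_mas


def pulsar_term_lag(px_mas, pos, gwtheta, gwphi):
    """Pulsar-term retardation L * (1 - cos mu) / c in seconds.

    Assumed convention: omhat is the GW propagation direction (minus the source
    direction) and cos mu = -omhat . pos.
    """
    omhat = -np.array([np.sin(gwtheta) * np.cos(gwphi),
                       np.sin(gwtheta) * np.sin(gwphi),
                       np.cos(gwtheta)])
    cos_mu = -np.dot(omhat, pos)
    return parallax_to_kpc(px_mas) * KPC_S * (1.0 - cos_mu)
```

For a 1 mas parallax (1 kpc), the lag reaches about 2.06e11 s, roughly 6,500 years, when the pulsar is opposite the source. The GW phase advances by 2π f times this lag, so small fractional distance changes move the pulsar-term phase appreciably.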

jaxpint.pta.fplus_fcross(pos, gwtheta, gwphi)

Compute F+ and Fx antenna pattern response for a single pulsar.

Implements Eqs. 4–10 of Ellis, Siemens & Creighton (2012) [cw_esc12].

Parameters:
  • pos ((3,) array) – Unit vector pointing to the pulsar.

  • gwtheta (scalar) – GW source colatitude (radians, pi/2 - dec).

  • gwphi (scalar) – GW source right ascension (radians).

Returns:

fplus, fcross – Antenna pattern coefficients.

Return type:

scalars

References

[cw_esc12]

Ellis, Siemens & Creighton (2012), ApJ 756, 175.
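A NumPy sketch of the antenna patterns follows. It assumes the polarisation basis m, n and propagation direction omhat commonly quoted from Ellis, Siemens & Creighton (2012), with F+ = ½[(m·p)² − (n·p)²]/(1 + omhat·p) and Fx = (m·p)(n·p)/(1 + omhat·p).

fplus_fcross_sketch is a hypothetical name. jaxpint's fplus_fcross may differ in sign or polarisation-angle conventions.

```python
import numpy as np


def fplus_fcross_sketch(pos, gwtheta, gwphi):
    """Antenna patterns for one pulsar (assumed Ellis et al. 2012 convention)."""
    m = np.array([np.sin(gwphi), -np.cos(gwphi), 0.0])
    n = np.array([-np.cos(gwtheta) * np.cos(gwphi),
                  -np.cos(gwtheta) * np.sin(gwphi),
                  np.sin(gwtheta)])
    omhat = np.array([-np.sin(gwtheta) * np.cos(gwphi),
                      -np.sin(gwtheta) * np.sin(gwphi),
                      -np.cos(gwtheta)])
    m_p, n_p, o_p = m @ pos, n @ pos, omhat @ pos
    # The denominator vanishes when the pulsar lies along the propagation direction.
    denom = 1.0 + o_p
    fplus = 0.5 * (m_p**2 - n_p**2) / denom
    fcross = m_p * n_p / denom
    return fplus, fcross
```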

class jaxpint.pta.CURNInjector(n_components, T_span, prefix='gwb_', initial_values=None)

Bases: SignalInjector

Uncorrelated common red noise (CURN, Gamma = I) injector.

Subclasses SignalInjector. Registers two global parameters (with prefix): {prefix}log10_A and {prefix}gamma.

Parameters:
  • n_components (int) – Number of Fourier frequency components per pulsar.

  • T_span (float) – Observing time span in seconds.

  • prefix (str) – Naming prefix in GlobalParams.

  • initial_values (dict, optional) – Override default initial values (keys must be in CURN_PARAM_DEFAULTS).

param_defaults = {'gamma': 4.33, 'log10_A': -15.0}
register_params(global_params)

Register CURN amplitude and spectral index into global_params.

Parameters:

global_params (GlobalParams) – Accumulator of shared PTA parameters (immutable; an updated copy is returned).

Returns:

Updated copy with {prefix}log10_A and {prefix}gamma appended.

Return type:

GlobalParams

covariance(p, toa_data, pulsar_params, global_params)

Compute (U, Phi) GWB covariance contribution for pulsar p.

Parameters:
  • p (int) – Pulsar index within the PTA (unused; CURN is identical for all pulsars).

  • toa_data (TOAData) – Pulse time-of-arrival data for pulsar p.

  • pulsar_params (ParameterVector) – Timing and noise parameters for pulsar p (unused).

  • global_params (GlobalParams) – Shared PTA parameters containing GWB amplitude and spectral index.

Returns:

Fourier design matrix U and diagonal PSD vector Phi.

Return type:

tuple of ((n_toas, 2*n_components) array, (2*n_components,) array)

class jaxpint.pta.HDCorrelatedGWBInjector(pulsar_positions, n_components, T_span, orf_func=<function hd_orf>, prefix='gwb_', initial_values=None)

Bases: CorrelatedSignalInjector

Correlated GWB injector with configurable overlap reduction function.

Registers two global parameters: {prefix}log10_A and {prefix}gamma. The ORF matrix is precomputed at construction time from the supplied pulsar positions and ORF function.

Parameters:
  • pulsar_positions ((n_psr, 3) array) – Unit vectors pointing to each pulsar (ICRS).

  • n_components (int) – Number of Fourier frequency components.

  • T_span (float) – Observing time span in seconds.

  • orf_func (callable, optional) – Overlap reduction function (pos1, pos2) -> scalar. Defaults to hd_orf().

  • prefix (str) – Naming prefix for parameters in GlobalParams.

  • initial_values (dict, optional) – Override default initial values for log10_A and gamma.

register_params(global_params)

Append this signal’s parameters to global_params.

Parameters:

global_params (GlobalParams) – Accumulator of shared PTA parameters.

Returns:

Updated copy with this signal’s parameters appended.

Return type:

GlobalParams

get_fourier_basis(toa_data)

Return the Fourier design matrix for a single pulsar.

Parameters:

toa_data (TOAData) – Pulse time-of-arrival data for one pulsar.

Returns:

F – Fourier design matrix (sin/cos columns).

Return type:

(n_toas, n_basis) array

get_psd(global_params)

Return the GWB power spectral density vector.

Parameters:

global_params (GlobalParams) – Shared PTA parameters (amplitude, spectral index, etc.).

Returns:

S – PSD values for each Fourier basis function (sin and cos each get the same value for their frequency).

Return type:

(n_basis,) array

get_orf_matrix()

Return the overlap reduction function matrix.

The matrix must be invertible (full rank). Rank-deficient ORFs such as the monopole (all ones) are not supported by the two-tier Woodbury scheme.

Returns:

Gamma – Symmetric, positive-definite ORF matrix. Gamma[a, b] is the correlation coefficient between pulsars a and b.

Return type:

(n_psr, n_psr) array

jaxpint.pta.fourier_basis(toas_seconds, n_components, T_span)

Fourier design matrix (sine/cosine pairs).

Constructs the basis used for Gaussian-process red noise modelling as described in Lentati et al. (2013) [gwb_l13] Section II.A and van Haasteren & Vallisneri (2014) [gwb_vh14].

Parameters:
  • toas_seconds ((n_toas,) array) – TOA times in seconds.

  • n_components (int) – Number of frequency components.

  • T_span (float) – Observing time span in seconds.

Returns:

  • F ((n_toas, 2 * n_components) array) – Design matrix with alternating sin/cos columns.

  • freqs ((n_components,) array) – Frequencies in Hz.

Return type:

tuple[Float[Array, 'n_toas n_basis'], Float[Array, 'n_freq']]

References

[gwb_l13]

Lentati et al. (2013), PRD 87, 104021.

[gwb_vh14]

van Haasteren & Vallisneri (2014), PRD 90, 104012.
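A NumPy sketch of a sine/cosine design matrix follows. It rests on two assumptions, both common conventions but not confirmed for fourier_basis:

  • the frequencies are f_i = i / T_span for i = 1 … n_components;
  • each sine column precedes the cosine column at the same frequency.

fourier_basis_sketch is a hypothetical name.

```python
import numpy as np


def fourier_basis_sketch(toas_seconds, n_components, T_span):
    """Return F of shape (n_toas, 2*n_components) with sin/cos column pairs, and freqs."""
    freqs = np.arange(1, n_components + 1) / T_span
    phase = 2.0 * np.pi * np.outer(toas_seconds, freqs)  # (n_toas, n_components)
    F = np.empty((toas_seconds.size, 2 * n_components))
    F[:, 0::2] = np.sin(phase)  # even columns: sine
    F[:, 1::2] = np.cos(phase)  # odd columns: cosine
    return F, freqs
```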

jaxpint.pta.gwb_covariance(toa_data, n_components, T_span, log10_A, gamma)

Compute (U, Phi) for CURN injection into single_pulsar_logL.

Parameters:
  • toa_data (TOAData) – Pulse time-of-arrival data (uses TDB times).

  • n_components (int) – Number of Fourier frequency components.

  • T_span (float) – Observing time span in seconds.

  • log10_A (scalar) – Log-10 GWB amplitude.

  • gamma (scalar) – GWB spectral index.

Returns:

  • U ((n_toas, 2 * n_components) array) – Fourier design matrix.

  • Phi ((2 * n_components,) array) – PSD values for each basis function.

Return type:

tuple[Float[Array, 'n_toas n_basis'], Float[Array, 'n_basis']]

jaxpint.pta.powerlaw_psd(f, log10_A, gamma)

Power-law power spectral density (NANOGrav convention).

Follows the parameterisation of Arzoumanian et al. (2016) [gwb_a16] Eq. 1, derived from the characteristic-strain relation of Phinney (2001) [gwb_p01]: S(f) = h_c^2(f) / (12 pi^2 f^3).

\[S(f) = \frac{A^2}{12\pi^2} \left(\frac{f}{f_{\rm yr}}\right)^{-\gamma} f_{\rm yr}^{-3}\]
Parameters:
  • f ((n_freq,) array) – Frequencies in Hz.

  • log10_A (scalar) – Log-10 of the dimensionless amplitude.

  • gamma (scalar) – Spectral index (positive for red noise).

Returns:

psd – Power spectral density in units of s^3.

Return type:

(n_freq,) array

References

[gwb_a16]

Arzoumanian et al. (2016), ApJ 821, 13.

[gwb_p01]

Phinney (2001), astro-ph/0108028.
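The power-law formula can be checked against the characteristic-strain relation it is derived from. With h_c(f) = A (f/f_yr)^((3 − γ)/2), the expression S = h_c² / (12π² f³) reproduces the equation above.

This sketch has three assumptions beyond the formula itself:

  • f_yr is taken as one Julian year; the exact year length used by powerlaw_psd is not stated here.
  • The helper names are hypothetical.
  • The sketch omits any frequency-bin-width factor that a covariance routine might apply.

```python
import numpy as np

F_YR = 1.0 / (365.25 * 86400.0)  # 1/yr in Hz (Julian year, assumed)


def powerlaw_psd_sketch(f, log10_A, gamma):
    """S(f) = A^2 / (12 pi^2) * (f / f_yr)^(-gamma) * f_yr^(-3), in s^3."""
    A = 10.0 ** log10_A
    return A**2 / (12.0 * np.pi**2) * (f / F_YR) ** (-gamma) * F_YR ** (-3.0)


def psd_per_basis(S):
    """Repeat each frequency's PSD for its sin and cos column: (n_freq,) -> (2*n_freq,)."""
    return np.repeat(S, 2)
```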

jaxpint.pta.hd_orf(pos1, pos2)

Hellings-Downs overlap reduction function.

Implements Eq. 2 of Hellings & Downs (1983) [orf_hd83]:

\[C(\xi) = \frac{3}{2} x \ln x - \frac{x}{4} + \frac{1}{2}, \quad x = \frac{1 - \cos\xi}{2}\]
Parameters:
  • pos1 ((3,) array) – Unit vector pointing to the first pulsar.

  • pos2 ((3,) array) – Unit vector pointing to the second pulsar.

Returns:

HD correlation coefficient in [1/2 − (3/2)e^{−5/6}, 1/2] ≈ [−0.152, 1/2]. The minimum occurs at x = e^{−5/6}, i.e. ξ ≈ 82.5°.

Return type:

float

References

[orf_hd83]

Hellings & Downs (1983), ApJL 265, L39.
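A direct NumPy transcription of the Hellings-Downs formula follows. It guards the x ln x term, whose limit is 0 as x → 0 (zero separation). hd_orf_sketch is a hypothetical name; it is not jaxpint's implementation.

```python
import numpy as np


def hd_orf_sketch(pos1, pos2):
    """C(xi) = 1.5 x ln x - x / 4 + 1/2 with x = (1 - cos xi) / 2."""
    cos_xi = np.clip(np.dot(pos1, pos2), -1.0, 1.0)
    x = 0.5 * (1.0 - cos_xi)
    # x ln x -> 0 as x -> 0, so avoid evaluating log(0) for coincident directions
    xlogx = x * np.log(x) if x > 0 else 0.0
    return 1.5 * xlogx - 0.25 * x + 0.5
```

Scanning ξ over [0, π] gives 1/2 at ξ = 0, 1/4 at ξ = π, and a minimum of 1/2 − (3/2)e^{−5/6} ≈ −0.152 near 82.5°.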

jaxpint.pta.monopole_orf(pos1, pos2)

Monopole ORF (isotropic, unit correlation for all pairs).

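Returns 1.0 for every pulsar pair regardless of angular separation. The resulting ORF matrix is all ones (rank 1), so it cannot be used with pta_logL_correlated(), which requires a full-rank ORF.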

Parameters:
  • pos1 ((3,) array) – Unit vector pointing to the first pulsar.

  • pos2 ((3,) array) – Unit vector pointing to the second pulsar.

Returns:

Always 1.0.

Return type:

float

jaxpint.pta.dipole_orf(pos1, pos2)

Dipole ORF (correlation proportional to cos(angle)).

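Returns the cosine of the angular separation between the two pulsars, i.e. dot(pos1, pos2). The ORF matrix Gamma[a, b] = dot(pos_a, pos_b) is a Gram matrix of 3-vectors, so its rank is at most 3. For more than three pulsars it is therefore singular and does not meet the full-rank requirement of the two-tier Woodbury scheme.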

Parameters:
  • pos1 ((3,) array) – Unit vector pointing to the first pulsar.

  • pos2 ((3,) array) – Unit vector pointing to the second pulsar.

Returns:

Cosine of the angular separation, in [-1, 1].

Return type:

float
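The full-rank requirement stated for get_orf_matrix() can be checked numerically for these ORFs. The sketch builds ORF matrices for eight random unit vectors, assuming the matrix is filled by evaluating the ORF on every pulsar pair, diagonal included.

The results follow from the structure of each ORF:

  • The monopole matrix is all ones, so its rank is 1.
  • The dipole matrix is the Gram matrix P Pᵀ of a (n_psr, 3) array, so its rank is at most 3.

All names are hypothetical stand-ins for monopole_orf() and dipole_orf().

```python
import numpy as np


def monopole_orf_sketch(pos1, pos2):
    return 1.0  # unit correlation for every pair


def dipole_orf_sketch(pos1, pos2):
    return float(np.dot(pos1, pos2))  # cosine of the angular separation


def orf_matrix(positions, orf):
    """Evaluate orf on every pulsar pair, diagonal included."""
    n = len(positions)
    return np.array([[orf(positions[a], positions[b]) for b in range(n)]
                     for a in range(n)])


rng = np.random.default_rng(42)
pos = rng.normal(size=(8, 3))
pos /= np.linalg.norm(pos, axis=1, keepdims=True)  # 8 random unit vectors

rank_mono = np.linalg.matrix_rank(orf_matrix(pos, monopole_orf_sketch))  # 1
rank_dip = np.linalg.matrix_rank(orf_matrix(pos, dipole_orf_sketch))  # 3
```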