jaxpint.pta#
PTA likelihood module for JaxPINT.
Composes jaxpint.likelihood.single_pulsar_logL() across multiple
pulsars with shared signal injections (CW sources, GWB, etc.).
- class jaxpint.pta.GlobalParams(values, names, _name_to_index)[source]#
Bases: Module
Shared parameter container for PTA-level signals.
An Equinox module where values is the only dynamic (traced) leaf. Static metadata (names, index mapping) is fixed at construction time.
Build incrementally via add_params() or the builder pattern:
    gp = GlobalParams.empty()
    for inj in signal_injectors:
        gp = inj.register_params(gp)
- Parameters:
- values: Float[Array, 'n_global']#
- add_params(names, values)[source]#
Append new parameters, returning a new GlobalParams.
- Parameters:
names (list[str]) – Parameter names to add.
values (list[float]) – Initial values for each parameter (same length as names).
- Returns:
New instance with the appended parameters.
- Return type:
GlobalParams
- Raises:
ValueError – If names and values have different lengths, or if any name is already present (prevents silent overwrites from duplicate injectors or prefix collisions).
- param_value(name)[source]#
Look up a single parameter value by name.
- Parameters:
name (str)
- Return type:
Float[Array, ‘’]
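The append-only contract described for add_params() can be illustrated with a small stand-in. This is a minimal sketch, not the jaxpint implementation: `MiniParams` is a hypothetical frozen dataclass that only mimics the documented behaviour (a new instance is returned, and mismatched lengths or duplicate names raise ValueError). The parameter names are built from the documented prefixes and are used here purely as examples.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class MiniParams:
    """Hypothetical stand-in for GlobalParams (illustration only)."""

    values: tuple = ()
    names: tuple = ()

    def add_params(self, names, values):
        # Reject mismatched inputs and duplicate names, as the docs describe.
        if len(names) != len(values):
            raise ValueError("names and values must have the same length")
        clash = set(names) & set(self.names)
        if clash or len(set(names)) != len(names):
            raise ValueError(f"duplicate parameter names: {sorted(clash) or names}")
        # Return a new instance; the original is left untouched.
        return MiniParams(self.values + tuple(values), self.names + tuple(names))

    def param_value(self, name):
        # Look up a value through the name -> index mapping.
        return self.values[self.names.index(name)]


gp = MiniParams()
gp = gp.add_params(["cw0_log10_h", "cw0_log10_fgw"], [-14.0, -8.0])
gp2 = gp.add_params(["gwb_log10_A"], [-15.0])
```

Because each call returns a new object, `gp` still holds only the two CW entries after `gp2` is created.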
- class jaxpint.pta.PTAConfig(toa_data_list, timing_models, noise_models, signal_injectors)[source]#
Bases: Module
Static (non-differentiated) configuration for PTA likelihood evaluation.
All fields are marked static=True so that JAX treats the entire config as a compile-time constant.
Raises ValueError at construction if the per-pulsar tuples have mismatched lengths.
- Parameters:
timing_models (tuple[TimingModel, ...])
noise_models (tuple[NoiseModel, ...])
signal_injectors (tuple[SignalInjector, ...])
- timing_models: tuple[TimingModel, ...]#
- noise_models: tuple[NoiseModel, ...]#
- signal_injectors: tuple[SignalInjector, ...]#
- class jaxpint.pta.SignalInjector[source]#
Bases: ABC
Abstract base class for PTA signal components.
Each injector:
Registers its own parameters into GlobalParams via register_params() (required; abstract).
Produces delay arrays and/or covariance (U, Phi) tuples per pulsar via delay() / covariance() (optional; the default implementations return None).
Subclasses must implement register_params(). Override delay() for deterministic signals (e.g. CW) and/or covariance() for stochastic signals (e.g. GWB). pta_logL() is agnostic to the signal type.
- abstractmethod register_params(global_params)[source]#
Append this signal’s parameters to global_params.
- Parameters:
global_params (GlobalParams) – Accumulator of shared PTA parameters. It is not modified in place; an updated copy is returned.
- Returns:
Updated copy with this signal’s parameters appended.
- Return type:
- delay(p, toa_data, pulsar_params, global_params)[source]#
Return deterministic delay for pulsar p, or None.
Override for deterministic signals. The default returns None (no delay contribution).
- Parameters:
p (int) – Pulsar index within the PTA.
toa_data (TOAData) – Pulse time-of-arrival data for pulsar p.
pulsar_params (ParameterVector) – Timing and noise parameters for pulsar p.
global_params (GlobalParams) – Shared PTA parameters (CW source properties, GWB spectrum, etc.).
- Returns:
Deterministic timing delay in seconds, or None if this injector does not contribute a delay.
- Return type:
(n_toas,) array or None
- covariance(p, toa_data, pulsar_params, global_params)[source]#
Return (U, Phi) covariance contribution for pulsar p, or None.
Override for stochastic signals. The default returns None (no covariance contribution).
- Parameters:
p (int) – Pulsar index within the PTA.
toa_data (TOAData) – Pulse time-of-arrival data for pulsar p.
pulsar_params (ParameterVector) – Timing and noise parameters for pulsar p.
global_params (GlobalParams) – Shared PTA parameters (CW source properties, GWB spectrum, etc.).
- Returns:
Design matrix U and diagonal PSD vector Phi, or None if this injector does not contribute covariance.
- Return type:
tuple of ((n_toas, n_basis) array, (n_basis,) array) or None
- jaxpint.pta.pta_logL(global_params, pulsar_params, config)[source]#
Multi-pulsar log-likelihood with signal injections.
For each pulsar, collects delay and covariance contributions from every SignalInjector in config, then delegates to jaxpint.likelihood.single_pulsar_logL().
- Parameters:
global_params (GlobalParams) – Shared parameters (CW source properties, GWB spectrum, etc.). This is the first differentiable argument.
pulsar_params (tuple of ParameterVector) – Per-pulsar timing and noise parameters. This is the second differentiable argument.
config (PTAConfig) – Static configuration (TOA data, models, injectors).
- Returns:
logL – Sum of per-pulsar log-likelihoods.
- Return type:
scalar
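The per-pulsar collection step described for pta_logL() can be sketched as follows. This is an illustration under stated assumptions, not the library's code: `ToyInjector`, `ToySinusoid`, `ToyRedNoise`, and `collect` are hypothetical names, the hooks take a simplified `(p, toas)` signature, and summing the delays from several injectors is an assumption about how contributions are combined.

```python
import numpy as np


class ToyInjector:
    """Hypothetical base: both hooks default to None, as documented."""

    def delay(self, p, toas):
        return None

    def covariance(self, p, toas):
        return None


class ToySinusoid(ToyInjector):
    def delay(self, p, toas):
        # Deterministic signal: one delay value per TOA.
        return 1e-7 * np.sin(2 * np.pi * 1e-8 * toas)


class ToyRedNoise(ToyInjector):
    def covariance(self, p, toas):
        # Stochastic signal: a (U, Phi) pair with two basis functions.
        U = np.stack([np.sin(1e-8 * toas), np.cos(1e-8 * toas)], axis=1)
        return U, np.array([1e-14, 1e-14])


def collect(p, toas, injectors):
    """Sum the delays and gather the (U, Phi) pairs for pulsar p."""
    total_delay = np.zeros_like(toas)
    cov_terms = []
    for inj in injectors:
        d = inj.delay(p, toas)
        if d is not None:
            total_delay = total_delay + d
        c = inj.covariance(p, toas)
        if c is not None:
            cov_terms.append(c)
    return total_delay, cov_terms


toas = np.linspace(0.0, 3.0e8, 5)
total_delay, cov_terms = collect(0, toas, [ToySinusoid(), ToyRedNoise()])
```

The loop never inspects the injector type, which is the sense in which the likelihood can stay agnostic to deterministic versus stochastic signals.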
- class jaxpint.pta.CorrelatedPTAConfig
Bases: Module
Static configuration for the correlated PTA likelihood.
Extends PTAConfig with a correlated_injectors field for cross-pulsar signals.
All fields are static=True (compile-time constants for JAX).
- Parameters:
timing_models (tuple[TimingModel, ...])
noise_models (tuple[NoiseModel, ...])
signal_injectors (tuple[SignalInjector, ...])
correlated_injectors (tuple[CorrelatedSignalInjector, ...])
- class jaxpint.pta.CorrelatedSignalInjector
Bases: ABC
Abstract base class for cross-pulsar correlated signal components.
Unlike SignalInjector, which produces per-pulsar covariance contributions, a CorrelatedSignalInjector provides the ingredients to build a PTA-wide covariance matrix with inter-pulsar correlations.
Append this signal’s parameters to global_params.
- Parameters:
global_params (GlobalParams) – Accumulator of shared PTA parameters.
- Returns:
Updated copy with this signal’s parameters appended.
- Return type:
Return the Fourier design matrix for a single pulsar.
- Parameters:
toa_data (TOAData) – Pulse time-of-arrival data for one pulsar.
- Returns:
F – Fourier design matrix (sin/cos columns).
- Return type:
(n_toas,n_basis) array
Return the GWB power spectral density vector.
- Parameters:
global_params (GlobalParams) – Shared PTA parameters (amplitude, spectral index, etc.).
- Returns:
S – PSD values for each Fourier basis function (sin and cos each get the same value for their frequency).
- Return type:
(n_basis,) array
Return the overlap reduction function matrix.
The matrix must be invertible (full rank). Rank-deficient ORFs such as the monopole (all ones) are not supported by the two-tier Woodbury scheme.
- Returns:
Gamma – Symmetric, positive-definite ORF matrix. Gamma[a, b] is the correlation coefficient between pulsars a and b.
- Return type:
(n_psr,n_psr) array
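The full-rank requirement can be checked numerically. The sketch below builds a small ORF matrix from the Hellings-Downs formula and compares it with the all-ones monopole matrix. It is an illustration only: the helper `hd`, the three orthogonal pulsar positions, and the unit auto-correlation placed on the diagonal are assumptions made for this example.

```python
import numpy as np


def hd(pos1, pos2):
    # Hellings-Downs value for two distinct pulsars.
    x = (1.0 - np.dot(pos1, pos2)) / 2.0
    return 1.5 * x * np.log(x) - x / 4.0 + 0.5


positions = np.eye(3)  # three mutually orthogonal unit vectors
n_psr = len(positions)

# Assumption: unit auto-correlation on the diagonal.
gamma_hd = np.eye(n_psr)
for a in range(n_psr):
    for b in range(a + 1, n_psr):
        gamma_hd[a, b] = gamma_hd[b, a] = hd(positions[a], positions[b])

gamma_mono = np.ones((n_psr, n_psr))  # monopole: every entry is 1

rank_hd = np.linalg.matrix_rank(gamma_hd)
rank_mono = np.linalg.matrix_rank(gamma_mono)
min_eig_hd = np.linalg.eigvalsh(gamma_hd).min()
```

Here `gamma_hd` has full rank with strictly positive eigenvalues, while `gamma_mono` has rank 1 and cannot be inverted.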
Multi-pulsar log-likelihood with cross-pulsar correlations.
Implements a two-tier Woodbury scheme:
Per-pulsar (inner tier): compute C_p^{-1} r_p, C_p^{-1} F_p, and per-pulsar log-likelihood contributions.
Cross-pulsar (outer tier): assemble and solve the compressed Sigma_gwb system to account for ORF-mediated correlations.
- Parameters:
global_params (GlobalParams) – Shared parameters (GWB amplitude/spectral index, CW source, etc.).
pulsar_params (tuple of ParameterVector) – Per-pulsar timing and noise parameters.
config (CorrelatedPTAConfig) – Static configuration.
- Returns:
logL – Log-likelihood value.
- Return type:
scalar
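The identity underlying both tiers is the Woodbury formula, which replaces an n_toas-sized solve with an n_basis-sized one. The sketch below checks it for a single pulsar only; it does not reproduce the two-tier scheme. It assumes a covariance of the form C = N + F Phi F^T with diagonal N and Phi, and every array is random test data.

```python
import numpy as np

rng = np.random.default_rng(0)
n_toas, n_basis = 40, 6

N = rng.uniform(0.5, 2.0, n_toas)        # diagonal white-noise variances
F = rng.normal(size=(n_toas, n_basis))   # basis (design) matrix
phi = rng.uniform(0.1, 1.0, n_basis)     # diagonal prior variances
r = rng.normal(size=n_toas)              # residuals

# Woodbury: only an (n_basis x n_basis) system is solved.
Ninv_r = r / N
Ninv_F = F / N[:, None]
Sigma = np.diag(1.0 / phi) + F.T @ Ninv_F
FNr = F.T @ Ninv_r
quad = r @ Ninv_r - FNr @ np.linalg.solve(Sigma, FNr)
# Matrix determinant lemma for log det C.
logdet = np.sum(np.log(N)) + np.sum(np.log(phi)) + np.linalg.slogdet(Sigma)[1]

# Dense reference for comparison.
C = np.diag(N) + F @ np.diag(phi) @ F.T
quad_dense = r @ np.linalg.solve(C, r)
logdet_dense = np.linalg.slogdet(C)[1]
```

Both `quad` and `logdet` agree with their dense counterparts to floating-point precision, while the Woodbury path never forms or factorises the full n_toas by n_toas matrix.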
- jaxpint.pta.fisher_matrix(global_params, pulsar_params, config)[source]#
Compute the Fisher information matrix via jax.hessian.
Internally flattens all parameters into a single array, wraps pta_logL() in a closure that unflattens, and calls jax.hessian on the flat function.
- Parameters:
global_params (GlobalParams) – Current global parameter values (evaluation point).
pulsar_params (tuple of ParameterVector) – Current per-pulsar parameter values (evaluation point).
config (PTAConfig) – Static PTA configuration.
- Returns:
fisher – Fisher matrix, where n_total = n_global + sum(n_pp_i).
- Return type:
(n_total, n_total) array
- jaxpint.pta.flatten_params(global_params, pulsar_params)[source]#
Pack all differentiable parameters into a single flat array.
Layout: [global_params.values | pp[0].values | pp[1].values | ...]
- Parameters:
global_params (GlobalParams) – Shared PTA parameters.
pulsar_params (tuple of ParameterVector) – Per-pulsar timing and noise parameters.
- Returns:
flat – Concatenated parameter values, where n_total = n_global + sum(n_pp_i).
- Return type:
(n_total,) array
- jaxpint.pta.unflatten_params(flat, global_template, pulsar_templates)[source]#
Unpack a flat array back into structured parameter objects.
Templates carry the static metadata (names, frozen mask, units, etc.). Only .values is replaced from slices of the flat array; everything else is preserved from the template.
Layout must match flatten_params():
    flat[0 : n_global] -> GlobalParams.values
    flat[n_global : n_global + n_pp0] -> pulsar_params[0].values
    ...
- Parameters:
flat ((n_total,) array) – Flat parameter vector.
global_template (GlobalParams) – Template with correct static metadata for the global params.
pulsar_templates (tuple of ParameterVector) – Templates with correct static metadata for each pulsar.
- Returns:
global_params (GlobalParams)
pulsar_params (tuple of ParameterVector)
- Return type:
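The layout shared by flatten_params() and unflatten_params() amounts to a concatenation followed by slicing at cumulative offsets. The round trip below is a minimal sketch on bare NumPy arrays; the real functions operate on GlobalParams and ParameterVector objects, and the variable names and values here are illustrative assumptions.

```python
import numpy as np

global_values = np.array([-14.0, -8.0])                        # n_global = 2
pulsar_values = (np.array([1.0, 2.0, 3.0]), np.array([4.0]))   # n_pp = 3, 1

# Flatten: [global | pp[0] | pp[1] | ...]
flat = np.concatenate([global_values, *pulsar_values])

# Unflatten: slice using the sizes that the templates would provide.
sizes = [len(global_values)] + [len(v) for v in pulsar_values]
offsets = np.cumsum([0] + sizes)
pieces = [flat[offsets[i]:offsets[i + 1]] for i in range(len(sizes))]
global_back, pulsar_back = pieces[0], tuple(pieces[1:])
```

Only the sizes are needed to recover the structure, which is why the templates can carry all other metadata unchanged.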
- class jaxpint.pta.CWInjector(pulsar_positions, dist_param_name='PX', prefix='cw0_', initial_values=None)[source]#
Bases: SignalInjector
Injects a single continuous gravitational wave source.
Subclasses SignalInjector. Uses a naming prefix (e.g. 'cw0_', 'cw1_') so that multiple CW sources can coexist in the same GlobalParams.
- Parameters:
pulsar_positions ((n_psr, 3) array) – Unit vectors pointing to each pulsar.
dist_param_name (str) – Name of the parallax parameter in each pulsar’s ParameterVector (default 'PX', in mas). The pulsar-term phase is computed from distance L_kpc = 1 / PX_mas internally (Ellis+2012).
prefix (str) – Naming prefix for this source in GlobalParams.
initial_values (dict, optional) – Override default initial values. Keys must be in CW_PARAM_DEFAULTS.
- param_defaults = {'cos_gwtheta': 0.0, 'cos_inc': 0.0, 'gwphi': 0.0, 'log10_fgw': -8.0, 'log10_h': -14.0, 'phase0': 0.0, 'psi': 0.0}#
- register_params(global_params)[source]#
Register CW source parameters into global_params.
- Parameters:
global_params (GlobalParams) – Accumulator of shared PTA parameters. It is not modified in place; an updated copy is returned.
- Returns:
Updated copy with this CW source’s parameters appended.
- Return type:
- delay(p, toa_data, pulsar_params, global_params)[source]#
Compute CW delay for pulsar p.
- Parameters:
p (int) – Pulsar index within the PTA.
toa_data (TOAData) – Pulse time-of-arrival data for pulsar p.
pulsar_params (ParameterVector) – Timing and noise parameters for pulsar p.
global_params (GlobalParams) – Shared PTA parameters containing this source’s CW values.
- Returns:
CW timing residual in seconds.
- Return type:
(n_toas,) array
- jaxpint.pta.cw_delay(toa_data, pos, pulsar_dist, global_params, prefix='cw0_')[source]#
CW-induced timing delay for one pulsar (Earth + pulsar term).
Implements the timing residual from Sesana & Vecchio (2010) [cw_sv10] Eq. 5, using the phase-averaging decomposition of Ellis (2013) [cw_e13] Eq. 4. The strain-to-residual scaling alpha = h / (2*pi*f) follows from Detweiler (1979) [cw_d79] Eq. 5.
The pulsar-term phase depends on pulsar distance, which is what makes the Fisher matrix informative for distance constraints.
- Parameters:
toa_data (TOAData) – Pulse time-of-arrival data (uses TDB timestamps).
pos ((3,) array) – Unit vector pointing to the pulsar.
pulsar_dist (scalar) – Pulsar parallax in mas (types.py convention). Converted internally to physical distance in kpc for the Ellis+2012 pulsar-term phase.
global_params (GlobalParams) – Shared PTA parameters (accessed by prefixed name).
prefix (str) – Naming prefix for this CW source in global_params.
- Returns:
delay – CW timing residual in seconds.
- Return type:
(n_toas,) array
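Two scalar conversions stated above can be worked through by hand: the parallax-to-distance rule L_kpc = 1 / PX_mas and the strain-to-residual scaling alpha = h / (2*pi*f). The snippet is a sketch with illustrative inputs. The parallax value is an arbitrary assumption, and the amplitude and frequency are taken from the documented CWInjector param_defaults.

```python
import math

px_mas = 0.5        # assumed parallax in milliarcseconds
log10_h = -14.0     # default strain amplitude
log10_fgw = -8.0    # default GW frequency (log10 of Hz)

dist_kpc = 1.0 / px_mas   # L_kpc = 1 / PX_mas

h = 10.0 ** log10_h
fgw = 10.0 ** log10_fgw
alpha = h / (2.0 * math.pi * fgw)   # strain-to-residual scaling, in seconds
```

With these inputs the pulsar sits at 2 kpc and alpha is roughly 1.6e-7 s, which sets the amplitude scale of the delay before antenna-pattern and phase factors are applied.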
References
- jaxpint.pta.fplus_fcross(pos, gwtheta, gwphi)[source]#
Compute F+ and Fx antenna pattern response for a single pulsar.
Implements Eqs. 4–10 of Ellis, Siemens & Creighton (2012) [cw_esc12].
- Parameters:
pos ((3,) array) – Unit vector pointing to the pulsar.
gwtheta (scalar) – GW source colatitude (radians, pi/2 - dec).
gwphi (scalar) – GW source right ascension (radians).
- Returns:
fplus, fcross – Antenna pattern coefficients.
- Return type:
scalars
References
[cw_esc12] Ellis, Siemens & Creighton (2012), ApJ 756, 175.
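For orientation, one convention for the antenna pattern that is common in pulsar-timing analyses is sketched below. Whether fplus_fcross() uses exactly these basis vectors and signs is an assumption; `antenna_pattern` is a hypothetical name and the test geometry is chosen only because the result is easy to verify by hand.

```python
import numpy as np


def antenna_pattern(pos, gwtheta, gwphi):
    # Polarisation basis vectors (m, n) and GW propagation direction (omhat).
    m = np.array([np.sin(gwphi), -np.cos(gwphi), 0.0])
    n = np.array([-np.cos(gwtheta) * np.cos(gwphi),
                  -np.cos(gwtheta) * np.sin(gwphi),
                  np.sin(gwtheta)])
    omhat = np.array([-np.sin(gwtheta) * np.cos(gwphi),
                      -np.sin(gwtheta) * np.sin(gwphi),
                      -np.cos(gwtheta)])
    denom = 1.0 + np.dot(omhat, pos)
    fplus = 0.5 * (np.dot(m, pos) ** 2 - np.dot(n, pos) ** 2) / denom
    fcross = np.dot(m, pos) * np.dot(n, pos) / denom
    return fplus, fcross


# Source on the x-axis (colatitude pi/2, RA 0), pulsar on the y-axis.
pos = np.array([0.0, 1.0, 0.0])
fplus, fcross = antenna_pattern(pos, np.pi / 2, 0.0)
```

In this geometry the pulsar direction is perpendicular to the propagation direction and aligned with m, so the plus response is 1/2 and the cross response vanishes.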
- class jaxpint.pta.CURNInjector(n_components, T_span, prefix='gwb_', initial_values=None)[source]#
Bases: SignalInjector
Uncorrelated common red noise (CURN, Gamma = I) injector.
Subclasses SignalInjector. Registers two global parameters (with prefix): {prefix}log10_A and {prefix}gamma.
- Parameters:
n_components (int) – Number of Fourier frequency components per pulsar.
T_span (float) – Observing time span in seconds.
prefix (str) – Naming prefix in GlobalParams.
initial_values (dict, optional) – Override default initial values (keys must be in CURN_PARAM_DEFAULTS).
- param_defaults = {'gamma': 4.33, 'log10_A': -15.0}#
- register_params(global_params)[source]#
Register CURN amplitude and spectral index into global_params.
- Parameters:
global_params (GlobalParams) – Accumulator of shared PTA parameters. It is not modified in place; an updated copy is returned.
- Returns:
Updated copy with {prefix}log10_A and {prefix}gamma appended.
- Return type:
- covariance(p, toa_data, pulsar_params, global_params)[source]#
Compute (U, Phi) GWB covariance contribution for pulsar p.
- Parameters:
p (int) – Pulsar index within the PTA (unused; CURN is identical for all pulsars).
toa_data (TOAData) – Pulse time-of-arrival data for pulsar p.
pulsar_params (ParameterVector) – Timing and noise parameters for pulsar p (unused).
global_params (GlobalParams) – Shared PTA parameters containing GWB amplitude and spectral index.
- Returns:
Fourier design matrix U and diagonal PSD vector Phi.
- Return type:
tuple of ((n_toas, 2*n_components) array, (2*n_components,) array)
Bases: CorrelatedSignalInjector
Correlated GWB injector with configurable overlap reduction function.
Registers two global parameters: {prefix}log10_A and {prefix}gamma. The ORF matrix is precomputed at construction time from the supplied pulsar positions and ORF function.
- Parameters:
pulsar_positions ((n_psr, 3) array) – Unit vectors pointing to each pulsar (ICRS).
n_components (int) – Number of Fourier frequency components.
T_span (float) – Observing time span in seconds.
orf_func (callable, optional) – Overlap reduction function (pos1, pos2) -> scalar. Defaults to hd_orf().
prefix (str) – Naming prefix for parameters in GlobalParams.
initial_values (dict, optional) – Override default initial values for log10_A and gamma.
Append this signal’s parameters to global_params.
- Parameters:
global_params (GlobalParams) – Accumulator of shared PTA parameters.
- Returns:
Updated copy with this signal’s parameters appended.
- Return type:
Return the Fourier design matrix for a single pulsar.
- Parameters:
toa_data (TOAData) – Pulse time-of-arrival data for one pulsar.
- Returns:
F – Fourier design matrix (sin/cos columns).
- Return type:
(n_toas,n_basis) array
Return the GWB power spectral density vector.
- Parameters:
global_params (GlobalParams) – Shared PTA parameters (amplitude, spectral index, etc.).
- Returns:
S – PSD values for each Fourier basis function (sin and cos each get the same value for their frequency).
- Return type:
(n_basis,) array
Return the overlap reduction function matrix.
The matrix must be invertible (full rank). Rank-deficient ORFs such as the monopole (all ones) are not supported by the two-tier Woodbury scheme.
- Returns:
Gamma – Symmetric, positive-definite ORF matrix. Gamma[a, b] is the correlation coefficient between pulsars a and b.
- Return type:
(n_psr,n_psr) array
- jaxpint.pta.fourier_basis(toas_seconds, n_components, T_span)[source]#
Fourier design matrix (sine/cosine pairs).
Constructs the basis used for Gaussian-process red noise modelling as described in Lentati et al. (2013) [gwb_l13] Section II.A and van Haasteren & Vallisneri (2014) [gwb_vh14].
- Parameters:
- Returns:
F ((n_toas, 2 * n_components) array) – Design matrix with alternating sin/cos columns.
freqs ((n_components,) array) – Frequencies in Hz.
- Return type:
tuple[Float[Array, ‘n_toas n_basis’], Float[Array, ‘n_freq’]]
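A design matrix with alternating sine and cosine columns can be assembled as below. This is a sketch rather than the library routine: `fourier_basis_sketch` is a hypothetical name, and the frequency grid f_k = k / T_span for k = 1..n_components is an assumption (a common choice for this kind of basis).

```python
import numpy as np


def fourier_basis_sketch(toas_seconds, n_components, T_span):
    # Assumed frequency grid: f_k = k / T_span for k = 1..n_components.
    freqs = np.arange(1, n_components + 1) / T_span
    arg = 2.0 * np.pi * toas_seconds[:, None] * freqs[None, :]
    F = np.zeros((len(toas_seconds), 2 * n_components))
    F[:, 0::2] = np.sin(arg)   # even columns: sine terms
    F[:, 1::2] = np.cos(arg)   # odd columns: cosine terms
    return F, freqs


T_span = 10 * 365.25 * 86400.0           # ten years in seconds
toas = np.linspace(0.0, T_span, 50)
F, freqs = fourier_basis_sketch(toas, 3, T_span)
```

Each frequency contributes two adjacent columns, so the matrix has 2 * n_components columns while `freqs` has n_components entries.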
References
- jaxpint.pta.gwb_covariance(toa_data, n_components, T_span, log10_A, gamma)[source]#
Compute (U, Phi) for CURN injection into single_pulsar_logL.
- Parameters:
- Returns:
U ((n_toas, 2 * n_components) array) – Fourier design matrix.
Phi ((2 * n_components,) array) – PSD values for each basis function.
- Return type:
tuple[Float[Array, ‘n_toas n_basis’], Float[Array, ‘n_basis’]]
- jaxpint.pta.powerlaw_psd(f, log10_A, gamma)[source]#
Power-law power spectral density (NANOGrav convention).
Follows the parameterisation of Arzoumanian et al. (2016) [gwb_a16] Eq. 1, derived from the characteristic-strain relation of Phinney (2001) [gwb_p01]: S(f) = h_c^2(f) / (12 pi^2 f^3).
\[S(f) = \frac{A^2}{12\pi^2} \left(\frac{f}{f_{\rm yr}}\right)^{-\gamma} f_{\rm yr}^{-3}\]
- Parameters:
f ((n_freq,) array) – Frequencies in Hz.
log10_A (scalar) – Log-10 of the dimensionless amplitude.
gamma (scalar) – Spectral index (positive for red noise).
- Returns:
psd – Power spectral density in units of s^3.
- Return type:
(n_freq,) array
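The formula above translates directly into a few lines of array code. The following is a sketch, not the library function: `powerlaw_psd_sketch` is a hypothetical name, and f_yr is taken to be one cycle per Julian year (365.25 days), which is an assumption about the constant used internally.

```python
import numpy as np

F_YR = 1.0 / (365.25 * 86400.0)   # assumed definition of 1/yr in Hz


def powerlaw_psd_sketch(f, log10_A, gamma):
    # S(f) = A^2 / (12 pi^2) * (f / f_yr)^(-gamma) * f_yr^(-3)
    A = 10.0 ** log10_A
    return A**2 / (12.0 * np.pi**2) * (f / F_YR) ** (-gamma) * F_YR ** (-3)


f = np.array([F_YR, 2.0 * F_YR])
psd = powerlaw_psd_sketch(f, -15.0, 4.33)
```

At f = f_yr the frequency ratio drops out, leaving A^2 / (12 pi^2 f_yr^3); doubling the frequency scales the PSD by 2^(-gamma).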
References
- jaxpint.pta.hd_orf(pos1, pos2)[source]#
Hellings-Downs overlap reduction function.
Implements Eq. 2 of Hellings & Downs (1983) [orf_hd83]:
\[C(\xi) = \frac{3}{2} x \ln x - \frac{x}{4} + \frac{1}{2}, \quad x = \frac{1 - \cos\xi}{2}\]
- Parameters:
pos1, pos2 ((3,) arrays) – Unit vectors pointing to the two pulsars.
- Returns:
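HD correlation coefficient. For distinct pulsars the formula above ranges from about −0.15 (its minimum, near 82° separation) to 1/2 (the limit as the separation approaches zero).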
- Return type:
References
[orf_hd83] Hellings & Downs (1983), ApJL 265, L39.
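The range of the formula can be confirmed by scanning it over angular separation. The snippet evaluates the expression exactly as written above; `hd_value` is a hypothetical helper that takes the angle directly rather than two position vectors, and the scan starts slightly above zero because x ln x is only defined as a limit at x = 0.

```python
import numpy as np


def hd_value(xi):
    # x = (1 - cos xi) / 2, evaluated for xi > 0 (x ln x -> 0 as x -> 0).
    x = (1.0 - np.cos(xi)) / 2.0
    return 1.5 * x * np.log(x) - x / 4.0 + 0.5


xi = np.linspace(1e-6, np.pi, 100001)
c = hd_value(xi)
xi_min_deg = np.degrees(xi[np.argmin(c)])
```

The curve starts near 1/2 at small separations, dips to roughly −0.152 around 82.5°, and rises to exactly 1/4 for antipodal pulsars.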
- jaxpint.pta.monopole_orf(pos1, pos2)[source]#
Monopole ORF (isotropic, unit correlation for all pairs).
Returns 1.0 for every pulsar pair regardless of angular separation.
- Parameters:
pos1, pos2 ((3,) arrays) – Unit vectors pointing to the two pulsars.
- Returns:
Always 1.0.
- Return type:
- jaxpint.pta.dipole_orf(pos1, pos2)[source]#
Dipole ORF (correlation proportional to cos(angle)).
Returns the cosine of the angular separation between the two pulsars, i.e. dot(pos1, pos2).
- Parameters:
pos1, pos2 ((3,) arrays) – Unit vectors pointing to the two pulsars.
- Returns:
Cosine of the angular separation, in [-1, 1].
- Return type: