# Source code for jaxpint.noise.red_noise
"""Power-law red noise model for JaxPINT.

Implements achromatic red noise with a power-law power spectral density
using an alternating Fourier basis (sin/cos pairs), matching PINT's
``PLRedNoise`` component.

The noise covariance is decomposed as::

    C_rn = F · diag(w) · Fᵀ

where *F* is a Fourier design matrix (pre-computed by the bridge) and
*w* are the power-law PSD weights computed from the amplitude and
spectral index parameters.
"""
from __future__ import annotations
import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float
from jaxpint.components import NoiseComponent
from jaxpint.constants import FYR
from jaxpint.types import TOAData, ParameterVector
[docs]
class PLRedNoise(NoiseComponent):
    """Achromatic power-law red noise on an alternating sin/cos Fourier basis.

    The design matrix is held as a fixed array on the component, while the
    per-column PSD weights are rebuilt from the amplitude (``TNREDAMP``)
    and spectral-index (``TNREDGAM``) parameters on every call, so they
    remain differentiable with respect to those parameters.

    Parameters
    ----------
    fourier_basis : (n_toas, 2 * n_freqs)
        Pre-computed Fourier design matrix whose columns alternate
        sin/cos: ``[sin(2πf₁t), cos(2πf₁t), sin(2πf₂t), ...]``.
    freqs : (n_freqs,)
        Frequencies of the basis, in Hz.
    freq_bin_widths : (n_freqs,)
        Bin width Δf of each frequency; multiplies the PSD to form the
        weights.
    tnredamp_name : str
        Name under which the log10 amplitude is looked up in the
        parameter vector.
    tnredgam_name : str
        Name under which the spectral index is looked up in the
        parameter vector.
    """

    # Field order is part of the constructor signature; keep it stable.
    fourier_basis: Float[Array, "n_toas n_basis"]
    freqs: Float[Array, " n_freqs"]
    freq_bin_widths: Float[Array, " n_freqs"]
    # Parameter names are plain strings, so they are marked static
    # rather than treated as array leaves.
    tnredamp_name: str = eqx.field(static=True)
    tnredgam_name: str = eqx.field(static=True)

    def psd_weights(
        self,
        params: ParameterVector,
    ) -> Float[Array, " n_basis"]:
        """Evaluate the power-law PSD weight of every basis column.

        The spectrum follows::

            P(f) = (A² / 12π²) · f_yr^(γ-3) · f^(-γ)

        and the weight of a frequency is ``P(f) · Δf``.  The sin and
        cos columns of one frequency share that weight, so each value
        appears twice in a row in the output.

        Parameters
        ----------
        params : ParameterVector
            Source of the log10 amplitude (looked up under
            ``tnredamp_name``) and the spectral index (looked up under
            ``tnredgam_name``).

        Returns
        -------
        weights : (2 * n_freqs,)
            One PSD weight per column of the Fourier basis.
        """
        # The amplitude is stored as log10(A); undo the logarithm first.
        amplitude = 10.0 ** params.param_value(self.tnredamp_name)
        spectral_index = params.param_value(self.tnredgam_name)

        normalisation = amplitude ** 2 / (12.0 * jnp.pi ** 2)
        reference_scale = FYR ** (spectral_index - 3.0)
        frequency_scale = self.freqs ** (-spectral_index)

        # Multiply left to right: P(f) first, then the bin width Δf.
        per_frequency = (
            normalisation
            * reference_scale
            * frequency_scale
            * self.freq_bin_widths
        )

        # Duplicate each entry so sin and cos columns get equal weight.
        return jnp.repeat(per_frequency, repeats=2)

    def covariance(
        self,
        toa_data: TOAData,
        params: ParameterVector,
    ) -> tuple[
        Float[Array, " n_toas"],
        Float[Array, "n_toas n_basis"],
        Float[Array, " n_basis"],
    ]:
        """Build the Woodbury ``(Ndiag, U, Phidiag)`` triple for red noise.

        The red-noise contribution is entirely low-rank, so the diagonal
        term is identically zero.

        Parameters
        ----------
        toa_data : TOAData
            Only ``n_toas`` is read, to size the zero diagonal.
        params : ParameterVector
            Current amplitude and spectral-index values.

        Returns
        -------
        Ndiag : (n_toas,)
            All zeros (no white-noise part).
        U : (n_toas, 2 * n_freqs)
            The stored Fourier design matrix.
        Phidiag : (2 * n_freqs,)
            Power-law PSD weights from :meth:`psd_weights`.
        """
        return (
            jnp.zeros(toa_data.n_toas),
            self.fourier_basis,
            self.psd_weights(params),
        )

    def generate(
        self,
        toa_data: TOAData,
        params: ParameterVector,
        key: jax.Array,
    ) -> Float[Array, " n_toas"]:
        """Sample one red-noise realization.

        Standard-normal coefficients are drawn for every basis column,
        scaled by the square root of the matching PSD weight, and
        projected through the Fourier design matrix.

        Parameters
        ----------
        toa_data : TOAData
            Observed TOA data.  Not read by this implementation; the
            output length comes from the stored basis.
        params : ParameterVector
            Current amplitude and spectral-index values.
        key : jax.Array
            PRNG key used for the normal draw.

        Returns
        -------
        noise : (n_toas,)
            Red noise realization in seconds.
        """
        column_sigma = jnp.sqrt(self.psd_weights(params))
        n_columns = self.fourier_basis.shape[1]
        unit_draws = jax.random.normal(key, shape=(n_columns,))
        return self.fourier_basis @ (column_sigma * unit_draws)