"""Probabilistic models for asteroseismic oscillation mode frequencies.
"""
from __future__ import annotations
import numpyro
import numpy as np
import astropy.units as u
import jax.numpy as jnp
from numpy.typing import ArrayLike
from typing import Optional, Union
from jax import random
from numpyro.infer import Predictive
from tinygp import kernels, GaussianProcess
from .typing import DistLike
from .priors import (
AsyFunction,
CZGlitchFunction,
HeGlitchFunction,
TauPrior,
Prior,
)
from .messengers import dimension
from .utils import distribution
# Names exported by ``from <module> import *``.
__all__ = ["Model", "GlitchModel"]
class Model(Prior):
    """Abstract base class for models.

    A model is a probabilistic object which may be given to Inference. It
    does not need to return anything during inference, but should have at
    least one observed sample site.
    """

    def __call__(self, n, nu=None, nu_err=None, n_pred=None):
        """Call the model during inference.

        Subclasses override this method; the base implementation always
        raises.

        Args:
            n (:term:`array_like`): Radial order of the observations.
            nu (:term:`array_like`, optional): Observed radial mode
                frequencies.
            nu_err (:term:`array_like`, optional): Gaussian observational
                uncertainties (sigma) for nu.
            n_pred (:term:`array_like`, optional): Radial order at which to
                make predictions.

        Raises:
            NotImplementedError: This is an abstract class and cannot be
                called.
        """
        raise NotImplementedError()
class GlitchModel(Model):
    r"""Asteroseismic glitch model.

    .. math::

        \nu_\mathrm{obs} \sim \mathcal{GP}(m(n), k(n, n') +
        \sigma^2\mathcal{I})

    Where the mean function is,

    .. math::

        m(n) &= \nu_\mathrm{bkg} + \delta_\mathrm{He} + \delta_\mathrm{CZ},\\
        \nu_\mathrm{bkg} &= f_\mathrm{bkg}(n),\\
        \delta_\mathrm{He} &= f_\mathrm{He}(\nu_\mathrm{bkg}),\\
        \delta_\mathrm{CZ} &= f_\mathrm{CZ}(\nu_\mathrm{bkg}),

    and the kernel function is,

    .. math::

        k(n, n') = \sigma_k^2 \exp\left( - \frac{(n' - n)^2}{l^2} \right).

    Args:
        nu_max (:term:`dist_like`): Prior on the frequency at maximum power.
            Its first element is read directly (``nu_max[0]``) to centre the
            priors on the glitch acoustic depths.
        delta_nu (:term:`dist_like`): Prior on the large frequency separation.
        teff (:term:`dist_like`, optional): Prior on the effective
            temperature. Only its first element (``teff[0]``) is used here,
            to centre the priors on the glitch acoustic depths. If None
            (default), a location of 5000 is used, i.e. the location of the
            documented default prior Normal(5000, 700).
        epsilon (:term:`dist_like`, optional): Prior on the asymptotic phase
            parameter.
        seed (int): Forwarded to the base class constructor. It is not
            otherwise used by this class.
        window_width (str or float): The number of delta_nu either side of
            nu_max over which to average the glitch amplitudes. If the
            string 'full' (default), the window spans the entire range of
            the frequencies.

    Attributes:
        background (Prior): Prior on the background function.
        he_glitch (Prior): Prior on the helium glitch function.
        cz_glitch (Prior): Prior on the base of convective zone glitch
            function.
        window_width (str or float): The number of delta_nu either side of
            nu_max over which to average the glitch amplitudes. If string,
            'full', the window is chosen over the entire range in frequency.
        units (dict): Units for the model variables, keyed by name.
        symbols (dict): LaTeX symbols for the model variables, keyed by name.
    """

    def __init__(
        self,
        nu_max: DistLike,
        delta_nu: DistLike,
        teff: Optional[DistLike] = None,
        epsilon: Optional[DistLike] = None,
        seed: int = 0,
        window_width: Union[str, float] = "full",
    ):
        super().__init__(
            nu_max,
            delta_nu,
            teff=teff,
            epsilon=epsilon,
            seed=seed,
        )
        self.background: Prior = AsyFunction(delta_nu, epsilon=epsilon)

        # NOTE(review): nu_max and teff are indexed as sequences below, so
        # this assumes (loc, scale)-like input - confirm against DistLike.
        log_numax = jnp.log10(nu_max[0])
        # Previously ``teff=None`` (the default) raised a TypeError on
        # ``teff[0]``; fall back to the documented default location instead.
        teff_loc = 5000.0 if teff is None else teff[0]
        log_teff = jnp.log10(teff_loc)

        # Location and scale of the prior on log tau for each glitch. The
        # coefficients are the "from models" relations; the superseded
        # alternatives are kept for reference.
        # mu_he = 0.184 - 0.964 * log_numax
        # mu_he = 4.981 - 1.334 * log_teff - 0.899 * log_numax  # 50 stars
        mu_he = 3.82 - 0.99 * log_teff - 0.93 * log_numax  # from models
        sigma_he = 0.05  # previously 0.085
        log_tau_he = (mu_he, sigma_he)

        # mu_cz = 0.449 - 0.909 * log_numax
        # mu_cz = 10.331 - 2.747 * log_teff - 0.774 * log_numax  # 50 stars
        mu_cz = 0.31 + (1.69 - 0.68 * log_teff) * log_numax  # from models
        sigma_cz = 0.08  # previously 0.14
        log_tau_cz = (mu_cz, sigma_cz)

        self.he_glitch: Prior = HeGlitchFunction(nu_max, log_tau=log_tau_he)
        self.cz_glitch: Prior = CZGlitchFunction(nu_max, log_tau=log_tau_cz)

        self._nu_max = distribution(nu_max)
        # Initial values for the GP kernel parameters (see ``__call__``).
        self._kernel_var = 0.1 * self.background.delta_nu.mean
        self._kernel_length = 5.0
        self.window_width = window_width

        self.units = {
            "nu_obs": u.microhertz,
            "nu": u.microhertz,
            "nu_bkg": u.microhertz,
            "dnu_he": u.microhertz,
            "dnu_cz": u.microhertz,
            "he_amplitude": u.microhertz,
            "cz_amplitude": u.microhertz,
            "nu_max": u.microhertz,
        }
        self.symbols = {
            "nu_obs": r"$\nu_\mathrm{obs}$",
            "nu": r"$\nu$",
            "nu_bkg": r"$\nu_\mathrm{bkg}$",
            "dnu_he": r"$\delta\nu_\mathrm{He}$",
            "dnu_cz": r"$\delta\nu_\mathrm{BCZ}$",
            "he_amplitude": r"$\langle A_\mathrm{He} \rangle$",
            "cz_amplitude": r"$\langle A_\mathrm{BCZ} \rangle$",
            "nu_max": r"$\nu_\mathrm{max}$",
        }
        for prior in [self.background, self.he_glitch, self.cz_glitch]:
            # Inherit units and symbols from the component priors.
            self.units.update(prior.units)
            self.symbols.update(prior.symbols)

    def _init_tau(self, rng_key, tau_prior, num_samples=5000):
        """Summarise samples of ``log_tau`` drawn from a prior.

        Draws ``num_samples`` samples from ``tau_prior`` and returns
        distributions built from the sample mean and standard deviation of
        the first two ``log_tau`` components. Not called by ``__init__``.

        Args:
            rng_key: Random key passed to the predictive sampler.
            tau_prior: Model from which the site "log_tau" is sampled.
            num_samples (int): Number of samples to draw.

        Returns:
            tuple: Distributions for (tau_he, tau_cz).
        """
        predictive = Predictive(tau_prior, num_samples=num_samples)
        pred = predictive(rng_key)
        log_tau = pred["log_tau"] - 6  # Convert from seconds to mega seconds
        loc = log_tau.mean(axis=0)
        scale = log_tau.std(axis=0, ddof=1)
        return (
            distribution((loc[0], scale[0])),  # tau_he
            distribution((loc[1], scale[1])),  # tau_cz
        )

    def _glitch_amplitudes(self, nu):
        """Record the window-averaged glitch amplitudes as deterministics.

        Args:
            nu: Mode frequencies; only used to set the window limits when
                ``window_width`` is 'full'.

        Returns:
            tuple: The values recorded as "he_amplitude" and "cz_amplitude".
        """
        if self.window_width == "full":
            low, high = nu.min(), nu.max()
        else:
            # Window of +/- window_width * delta_nu about the nu_max mean.
            half_width = self.window_width * self.background._delta_nu
            low = self._nu_max.mean - half_width
            high = self._nu_max.mean + half_width

        he_amp = numpyro.deterministic(
            "he_amplitude", self.he_glitch._average_amplitude(low, high)
        )
        cz_amp = numpyro.deterministic(
            "cz_amplitude", self.cz_glitch._average_amplitude(low, high)
        )
        return he_amp, cz_amp

    def _amplitude_prior(self, he_amp, cz_amp):
        """Add a prior factor on the ratio of the glitch amplitudes.

        The prior is such that log(he_amp) == log(cz_amp) is a 2-sigma event
        and the He amplitude is ~ 4 times the BCZ amplitude
        (log10(4) ~ 0.6).

        Args:
            he_amp: Helium glitch amplitude.
            cz_amp: Base of convective zone glitch amplitude.
        """
        delta = 2.0 * (jnp.log10(he_amp) - jnp.log10(cz_amp) - 0.6) / 0.6
        logp = numpyro.distributions.Normal().log_prob(delta)
        numpyro.factor("amp", logp)

    def __call__(
        self,
        n: ArrayLike,
        nu: Optional[ArrayLike] = None,
        nu_err: Optional[ArrayLike] = None,
        n_pred: Optional[ArrayLike] = None,
    ):
        """Sample the model for given observables.

        Args:
            n (:term:`array_like`): Radial order of the observations. Its
                ``shape`` attribute is read, so it must be an array.
            nu (:term:`array_like`, optional): Observed radial mode
                frequencies.
            nu_err (:term:`array_like`, optional): Gaussian observational
                uncertainties (sigma) for nu. If None, a diagonal term of
                1e-6 is used instead.
            n_pred (:term:`array_like`, optional): Radial order at which to
                make predictions. If given, the sites "nu" and "nu_pred"
                and the "*_pred" deterministics are also recorded.
        """
        # TODO it may be more general for all models to take an obs dict as
        # argument and every parameter to do obs.get('name', None)
        bkg_func = self.background()
        he_glitch_func = self.he_glitch()
        cz_glitch_func = self.cz_glitch()

        # The mean function for the GP
        def mean(n):
            nu_bkg = bkg_func(n)[0]  # shape of bkg_func is (1, num_orders)
            return nu_bkg + he_glitch_func(nu_bkg) + cz_glitch_func(nu_bkg)

        var = numpyro.param("kernel_var", self._kernel_var)
        length = numpyro.param("kernel_length", self._kernel_length)
        kernel = var * kernels.ExpSquared(length)

        # The observational variance doubles as jitter when provided.
        diag = 1e-6 if nu_err is None else nu_err**2
        gp = GaussianProcess(kernel, n, mean=mean, diag=diag)

        with dimension("n", n.shape[-1], coords=n):
            # Rebinds nu: the observed values, or a sample if nu is None.
            nu = numpyro.sample("nu_obs", gp.numpyro_dist(), obs=nu)
            nu_bkg = numpyro.deterministic("nu_bkg", bkg_func(n))
            numpyro.deterministic("dnu_he", he_glitch_func(nu_bkg))
            numpyro.deterministic("dnu_cz", cz_glitch_func(nu_bkg))

        if n_pred is not None:
            with dimension("n", n.shape[-1], coords=n):
                numpyro.sample("nu", gp.condition(nu, n).gp.numpyro_dist())

            with dimension("n_pred", n_pred.shape[-1], coords=n_pred):
                numpyro.sample(
                    "nu_pred",
                    gp.condition(nu, n_pred).gp.numpyro_dist()
                )
                nu_bkg = numpyro.deterministic(
                    "nu_bkg_pred", bkg_func(n_pred)
                )
                numpyro.deterministic("dnu_he_pred", he_glitch_func(nu_bkg))
                numpyro.deterministic("dnu_cz_pred", cz_glitch_func(nu_bkg))

        # Other deterministics. The amplitude prior is intentionally not
        # applied here (see ``_amplitude_prior``).
        self._glitch_amplitudes(nu)
class GlitchModelComparison(GlitchModel):
    r"""Asteroseismic glitch model comparison. Compare the glitch model with
    a glitchless model. The frequencies are modelled using a GP with the same
    kernel function but different mean functions.

    The glitch model is the same as :class:`GlitchModel`. The glitchless model
    is the same except that the mean function is,

    .. math::

        m_0(n) = f_\mathrm{bkg}(n),

    The two models are compared using the Bayes' factor,

    .. math::

        K = \frac{p(\nu_\mathrm{obs} \mid \mathcal{GP}_1)}
        {p(\nu_\mathrm{obs} \mid \mathcal{GP}_0)}

    where :math:`\mathcal{GP}_0` is the glitchless model and
    :math:`\mathcal{GP}_1` is the glitch model.

    Args:
        nu_max (:term:`dist_like`): Prior on the frequency at maximum power.
        delta_nu (:term:`dist_like`): Prior on the large frequency separation.
        teff (:term:`dist_like`, optional): Prior on the effective
            temperature, forwarded to :class:`GlitchModel`.
        epsilon (:term:`dist_like`, optional): Prior on the asymptotic phase
            parameter.
        seed (int): Forwarded to :class:`GlitchModel`.
        window_width (str or float): The number of delta_nu either side of
            nu_max over which to average the glitch amplitudes, or 'full'
            (default). Forwarded to :class:`GlitchModel`.

    Attributes:
        background (Prior): Prior on the background function.
        he_glitch (Prior): Prior on the helium glitch function.
        cz_glitch (Prior): Prior on the base of convective zone glitch
            function.
        window_width (str or float): The number of delta_nu either side of
            nu_max over which to average the glitch amplitudes, or 'full'.
    """

    def __init__(
        self,
        nu_max: DistLike,
        delta_nu: DistLike,
        teff: Optional[DistLike] = None,
        epsilon: Optional[DistLike] = None,
        seed: int = 0,
        window_width: Union[str, float] = "full",
    ):
        super().__init__(nu_max, delta_nu, teff, epsilon, seed, window_width)
        # Sites of the glitchless (null) model are named
        # "<prefix><divider><name>".
        self._prefix = "null"
        self._divider = "."

        units = {
            "log_k": u.LogUnit(u.dimensionless_unscaled),
        }
        symbols = {"log_k": r"$\log(k)$"}

        # The null-model variables share units and symbols with their
        # glitch-model counterparts.
        null_vars = ["nu", "nu_obs", "nu_bkg"]
        for var_name in null_vars:
            key = self._divider.join([self._prefix, var_name])
            units[key] = self.units[var_name]
            symbols[key] = self.symbols[var_name]

        self.units.update(units)
        self.symbols.update(symbols)

    def __call__(
        self,
        n: ArrayLike,
        nu: Optional[ArrayLike] = None,
        nu_err: Optional[ArrayLike] = None,
        n_pred: Optional[ArrayLike] = None,
    ):
        """Sample both models for given observables.

        Args:
            n (:term:`array_like`): Radial order of the observations. Its
                ``shape`` attribute is read, so it must be an array.
            nu (:term:`array_like`, optional): Observed radial mode
                frequencies.
            nu_err (:term:`array_like`, optional): Gaussian observational
                uncertainties (sigma) for nu. If None, a diagonal term of
                1e-6 is used instead.
            n_pred (:term:`array_like`, optional): Radial order at which to
                make predictions for both models.
        """
        # Same kernel function for both models
        var = numpyro.param("kernel_var", self._kernel_var)
        length = numpyro.param("kernel_length", self._kernel_length)
        kernel = var * kernels.ExpSquared(length)

        # The observational variance doubles as jitter when provided.
        diag = 1e-6 if nu_err is None else nu_err**2

        args = ("models", 2)
        with dimension(*args):
            with numpyro.plate(*args):
                # Broadcast background function to both models; index 0 is
                # used by the null model and index 1 by the glitch model.
                bkg_func = self.background()

        # MODEL 0
        # Use the same divider as the unit/symbol keys registered in
        # __init__ so that site names and metadata keys always agree.
        with numpyro.handlers.scope(
            prefix=self._prefix, divider=self._divider
        ):
            # Contain null model parameters in the null scope
            def mean0(n):
                return jnp.squeeze(bkg_func(n)[0])

            gp0 = GaussianProcess(kernel, n, mean=mean0, diag=diag)
            dist0 = gp0.numpyro_dist()

            with dimension("n", n.shape[-1], coords=n):
                nu0 = numpyro.sample("nu_obs", dist0, obs=nu)
                numpyro.deterministic("nu_bkg", bkg_func(n)[0])

            if n_pred is not None:
                with dimension("n", n.shape[-1], coords=n):
                    numpyro.sample(
                        "nu",
                        gp0.condition(nu0, n).gp.numpyro_dist(),
                    )

                with dimension("n_pred", n_pred.shape[-1], coords=n_pred):
                    numpyro.sample(
                        "nu_pred",
                        gp0.condition(nu0, n_pred).gp.numpyro_dist()
                    )
                    numpyro.deterministic(
                        "nu_bkg_pred", bkg_func(n_pred)[0]
                    )

        # MODEL 1
        he_glitch_func = self.he_glitch()
        cz_glitch_func = self.cz_glitch()

        def mean(n):
            nu_bkg = jnp.squeeze(bkg_func(n)[1])
            return nu_bkg + he_glitch_func(nu_bkg) + cz_glitch_func(nu_bkg)

        gp = GaussianProcess(kernel, n, mean=mean, diag=diag)
        dist = gp.numpyro_dist()

        with dimension("n", n.shape[-1], coords=n):
            nu = numpyro.sample("nu_obs", dist, obs=nu)  # redefines nu!
            nu_bkg = numpyro.deterministic("nu_bkg", bkg_func(n)[1])
            numpyro.deterministic("dnu_he", he_glitch_func(nu_bkg))
            numpyro.deterministic("dnu_cz", cz_glitch_func(nu_bkg))

        if n_pred is not None:
            with dimension("n", n.shape[-1], coords=n):
                numpyro.sample("nu", gp.condition(nu, n).gp.numpyro_dist())

            with dimension("n_pred", n_pred.shape[-1], coords=n_pred):
                numpyro.sample(
                    "nu_pred",
                    gp.condition(nu, n_pred).gp.numpyro_dist()
                )
                nu_bkg = numpyro.deterministic(
                    "nu_bkg_pred", bkg_func(n_pred)[1]
                )
                numpyro.deterministic("dnu_he_pred", he_glitch_func(nu_bkg))
                numpyro.deterministic("dnu_cz_pred", cz_glitch_func(nu_bkg))

        # Other deterministics and priors
        self._amplitude_prior(*self._glitch_amplitudes(nu))

        # LIKELIHOOD
        # Model comparison - if nu is not None, then nu0 == nu
        logL0 = dist0.log_prob(nu0)
        logL = dist.log_prob(nu)
        # NOTE(review): both "nu_obs" sites are already observed sample
        # sites, so this factor may count each likelihood a second time -
        # confirm this is intended.
        numpyro.factor("obs", (logL0 + logL).sum())

        # Log10 Bayes factor
        numpyro.deterministic("log_k", (logL - logL0).sum() / np.log(10.0))