Source code for asterion.inference

"""The inference module contains a handy :class:`Inference` class which wraps
:mod:`numpyro` inference methods to facilitate the Bayesian workflow.
"""
from __future__ import annotations

import arviz as az
import astropy.units as u
import numpy as np
import numpyro, time, xarray, warnings

from collections import OrderedDict
from jax import random
from numpyro import handlers as hdl
from numpyro.infer import MCMC, NUTS, init_to_median, Predictive, SVI
from numpyro.distributions import constraints
from numpyro.contrib.nested_sampling import NestedSampler
from numpyro.infer.reparam import Reparam, CircularReparam
from typing import Optional, Union, List

from .models import Model

__all__ = [
    "Inference",
]


class Inference:
    """Perform inference on a given model.

    Args:
        model (Model): Model which predicts the asteroseismic mode
            frequencies.
        n (:term:`array_like`): Values passed to the model as its first
            positional argument, one per observed mode (see the example
            below).
        nu (:term:`array_like`): Asteroseismic mode frequencies.
        nu_err (:term:`array_like`, optional): Observational uncertainty on
            the asteroseismic mode frequencies.
        seed (int): Seed for pseudo-random number generation.

    Example:

        .. code-block:: python

            from asterion import Model, Inference

            n = [9, 10, 11, 12, 13]
            nu = [111.1, 122.2, 133.3, 144.4, 155.5]
            nu_err = 0.01

            model = Model(...)  # Construct the model here

            infer = Inference(model, n=n, nu=nu, nu_err=nu_err, seed=42)

            # Sample from prior predictive
            infer.prior_predictive(num_samples=2000)

            # Sample from posterior
            infer.sample()

            # Sample from posterior predictive
            infer.posterior_predictive()

            # Get data from inference
            data = infer.get_data()

            # Save data
            data.to_netcdf('inference_data.nc')

    Attributes:
        model (Model): Model with which to perform inference.
        n (numpy.ndarray): Values passed to the model as its first argument.
        n_pred (numpy.ndarray): 250 points linearly spaced between the first
            and last elements of ``n``, passed to the model as ``n_pred``
            when making predictions.
        nu (numpy.ndarray): Observed mode frequencies.
        nu_err (numpy.ndarray, optional): Uncertainty on observed mode
            frequencies.
        samples (dict, optional): Posterior samples.
        weighted_samples (dict, optional): Posterior weighted samples.
        sample_stats (dict, optional): Posterior sample statistics.
        prior_predictive_samples (dict, optional): Prior predictive samples.
        predictive_samples (dict, optional): Posterior predictive samples.
        sample_metadata (dict): Information about the most recent sampling
            run (e.g. the ``'method'`` and ``'sampler'`` used). Empty until
            a sampling method has been run.
    """

    def __init__(self, model: Model, *, n, nu, nu_err=None, seed: int = 0):
        self._rng_key = random.PRNGKey(seed)
        self.model: Model = model

        self.n = np.asarray(n)
        self.n_pred = np.linspace(self.n[0], self.n[-1], 250)
        self.nu = np.asarray(nu)
        self.nu_err = None if nu_err is None else np.asarray(nu_err)

        self.samples: Optional[dict] = None
        self.weighted_samples: Optional[dict] = None
        self.sample_stats: Optional[dict] = None
        self.prior_predictive_samples: Optional[dict] = None
        self.predictive_samples: Optional[dict] = None

        # Populated by find_map().
        self._map_loss: Optional[dict] = None
        self._map_guide: Optional[dict] = None
        self._map_params: Optional[dict] = None

        self.sample_metadata = {}

    def _mcmc_support_warnings(self):
        """Warn about sample sites whose support looks circular.

        A warning is issued for every sample site whose support compares
        equal to ``constraints.circular`` without being that exact object.
        """
        for site in self.get_trace().values():
            if site["type"] != "sample":
                continue
            support = site["fn"].support
            # Equality without identity: circular-like support on a
            # distribution which is not itself declared circular.
            if (
                support == constraints.circular
                and support is not constraints.circular
            ):
                warnings.warn(
                    f"Parameter '{site['name']}' has "
                    + "circular-like support but the distribution is "
                    + "not circular. Consider changing its distribution "
                    + "to numpyro.distributions.VonMises for better "
                    + "performance during MCMC."
                )

    def _get_dims(self):
        """Collect dimension names and coordinates from the model trace.

        Returns:
            tuple: ``(dims, coords)`` where ``coords`` maps the name of each
            site of type ``'dimension'`` to its value, and ``dims`` maps the
            name of every other site carrying a ``'dims'`` entry to that
            entry.
        """
        coords = {}
        dims = {}
        for site in self.get_trace(pred=True).values():
            name = site["name"]
            if site["type"] == "dimension":
                coords[name] = site["value"]
            elif "dims" in site:
                dims[name] = site["dims"]
        return dims, coords

    def get_trace(self, pred=False) -> OrderedDict:
        """Trace the model once with a fresh random key.

        Args:
            pred (bool): Whether to trace the predictive model or not. If
                True, ``n_pred`` is passed to the model, otherwise None is
                passed in its place. Default is False.

        Returns:
            OrderedDict: Model trace.
        """
        rng_key, self._rng_key = random.split(self._rng_key)
        n_pred = self.n_pred if pred else None
        model = hdl.trace(hdl.seed(self.model, rng_key))
        trace = model.get_trace(
            self.n, nu=self.nu, nu_err=self.nu_err, n_pred=n_pred
        )
        return trace

    def get_circ_var_names(self) -> List[str]:
        """Get the names of sample sites with circular support.

        Returns:
            list: Circular variable names in the model.
        """
        return [
            site["name"]
            for site in self.get_trace().values()
            if site["type"] == "sample"
            and site["fn"].support is constraints.circular
        ]

    def _auto_reparam(self) -> numpyro.handlers.reparam:
        """Automatically reparameterise circular parameters."""
        return hdl.reparam(
            config={k: CircularReparam() for k in self.get_circ_var_names()}
        )

    def _init_handlers(
        self,
        handlers: Optional[list],
        reparam: Union[str, hdl.reparam] = "auto",
    ) -> list:
        """Initialise handlers with/without reparameterisation.

        Args:
            handlers (list, optional): Handlers to start from. The list is
                copied, so the caller's list is never modified.
            reparam (str, or numpyro.handlers.reparam): If 'auto', the
                automatic circular reparameterisation is appended. If 'none'
                or None, nothing is appended. Anything else is appended
                as-is.

        Returns:
            list: New list of handlers.
        """
        # Copy so that appending below does not mutate the caller's list.
        handlers = [] if handlers is None else list(handlers)

        if reparam == "auto":
            handlers.append(self._auto_reparam())
        elif reparam == "none" or reparam is None:
            pass
        else:
            handlers.append(reparam)
        return handlers

    def _expand_batch_dims(self, samples):
        """Expand the batch dimensions of samples (add leading dimension)."""
        # Hack because InferenceData assumes a chains dimension.
        return {k: v[None, ...] for k, v in samples.items()}

    def init_mcmc(
        self,
        model,
        num_warmup: int = 1000,
        num_samples: int = 1000,
        num_chains: int = 1,
        sampler="NUTS",
        sampler_kwargs: Optional[dict] = None,
        **kwargs,
    ) -> MCMC:
        """Initialises the MCMC sampler.

        Args:
            model (callable): Model passed to the sampler. Only used when
                ``sampler`` is given as a string.
            num_warmup (int): Number of warmup steps. Defaults to 1000.
            num_samples (int): Number of samples. Defaults to 1000.
            num_chains (int): Number of chains. Defaults to 1.
            sampler (str, or numpyro.infer.mcmc.MCMCKernel): Choose one of
                ['NUTS'], or pass a numpyro mcmc kernel.
            sampler_kwargs (dict, optional): Keyword arguments to pass to
                the chosen sampler. Only used when ``sampler`` is given as
                a string. The dictionary is not modified.
            **kwargs: Keyword arguments to pass to mcmc instance.

        Returns:
            numpyro.infer.MCMC: The MCMC instance, not yet run.

        Raises:
            ValueError: If ``sampler`` is a string other than 'NUTS'
                (case-insensitive).
        """
        self._mcmc_support_warnings()

        if isinstance(sampler, str):
            sampler = sampler.lower()
            if sampler != "nuts":
                raise ValueError(f"Sampler '{sampler}' not supported.")

            # Work on a copy so that the caller's dictionary is left intact.
            nuts_kwargs = dict(sampler_kwargs or {})
            target_accept_prob = nuts_kwargs.pop("target_accept_prob", 0.98)
            init_strategy = nuts_kwargs.pop(
                "init_strategy",
                lambda site=None: init_to_median(site=site, num_samples=100),
            )
            step_size = nuts_kwargs.pop("step_size", 0.1)

            sampler = NUTS(
                model,
                target_accept_prob=target_accept_prob,
                init_strategy=init_strategy,
                step_size=step_size,
                **nuts_kwargs,
            )

        return MCMC(
            sampler,
            num_warmup=num_warmup,
            num_samples=num_samples,
            num_chains=num_chains,
            **kwargs,
        )

    def run_mcmc(
        self,
        model,
        num_warmup: int = 1000,
        num_samples: int = 1000,
        num_chains: int = 1,
        sampler="NUTS",
        sampler_kwargs: Optional[dict] = None,
        extra_fields: tuple = (),
        init_params: dict = None,
        **kwargs,
    ) -> None:
        """Runs MCMC and stores the results on the instance.

        Sets :attr:`samples`, :attr:`sample_stats` and
        :attr:`sample_metadata`.

        Args:
            model (callable): Model passed to the sampler.
            num_warmup (int): Number of warmup steps. Defaults to 1000.
            num_samples (int): Number of samples. Defaults to 1000.
            num_chains (int): Number of chains. Defaults to 1.
            sampler (str, or numpyro.infer.mcmc.MCMCKernel): Choose one of
                ['NUTS'], or pass a numpyro mcmc kernel.
            sampler_kwargs (dict, optional): Keyword arguments to pass to
                the chosen sampler.
            extra_fields (tuple): Extra fields to report in sample_stats.
                Defaults to ().
            init_params (dict): Initial parameter values prior to sampling.
                Defaults to None.
            **kwargs: Keyword arguments to pass to mcmc instance.
        """
        mcmc = self.init_mcmc(
            model,
            num_warmup=num_warmup,
            num_samples=num_samples,
            num_chains=num_chains,
            sampler=sampler,
            sampler_kwargs=sampler_kwargs,
            **kwargs,
        )

        rng_key, self._rng_key = random.split(self._rng_key)
        mcmc.run(
            rng_key,
            self.n,
            nu=self.nu,
            nu_err=self.nu_err,
            extra_fields=extra_fields,
            init_params=init_params,
        )

        self.samples = mcmc.get_samples(group_by_chain=True)
        self.sample_stats = mcmc.get_extra_fields(group_by_chain=True)

        if isinstance(sampler, str):
            sampler_name = sampler
        else:
            # Kernel instances have no __name__; fall back to the name of
            # their class so that recording metadata cannot fail after the
            # sampling has already completed.
            sampler_name = getattr(
                sampler, "__name__", type(sampler).__name__
            )
        self.sample_metadata = {
            "method": "MCMC",
            "sampler": sampler_name,
        }

    def init_nested(
        self,
        model: Model,
        num_live_points: int = 50,
        max_samples: int = 50000,
        sampler: str = "multi_ellipsoid",
        **kwargs,
    ) -> NestedSampler:
        """Initialises the nested sampler.

        Args:
            model (Model): Model passed to the nested sampler.
            num_live_points (int): Number of live points. Defaults to 50.
            max_samples (int): Maximum number of samples. Defaults to 50000.
            sampler (str): Name of the sampler, passed as ``sampler_name``.
                Defaults to "multi_ellipsoid".
            **kwargs: Keyword arguments to pass to nested sampler instance.
                If 'depth' is not given, a value of 5 is used.

        Returns:
            numpyro.contrib.nested_sampling.NestedSampler: The nested
            sampler instance, not yet run.
        """
        depth = kwargs.pop("depth", 5)
        return NestedSampler(
            model,
            num_live_points=num_live_points,
            max_samples=max_samples,
            sampler_name=sampler,
            depth=depth,
            **kwargs,
        )

    def run_nested(
        self,
        model: Model,
        num_live_points: int = 50,
        num_samples: int = 1000,
        max_samples: int = 50000,
        sampler: str = "multi_ellipsoid",
        **kwargs,
    ):
        """Runs nested sampling and stores the results on the instance.

        Sets :attr:`samples`, :attr:`weighted_samples`,
        :attr:`sample_stats` and :attr:`sample_metadata`, and prints the
        progress and elapsed time.

        Args:
            model (Model): Model passed to the nested sampler.
            num_live_points (int): Number of live points. Defaults to 50.
            num_samples (int): Number of samples to draw from the nested
                sampler once it has run. Defaults to 1000.
            max_samples (int): Maximum number of samples. Defaults to 50000.
            sampler (str): Name of the sampler. Defaults to
                "multi_ellipsoid".
            **kwargs: Keyword arguments to pass to nested sampler instance.
        """
        nested_sampler = self.init_nested(
            model,
            num_live_points=num_live_points,
            max_samples=max_samples,
            sampler=sampler,
            **kwargs,
        )

        key1, key2, self._rng_key = random.split(self._rng_key, 3)

        print(
            f"Running nested sampling using the '{sampler}' sampler "
            + f"with {num_live_points} live points "
            + f"and {max_samples} maximum samples..."
        )
        # Monotonic clock, so the elapsed time is immune to system clock
        # adjustments.
        start_time = time.perf_counter()
        nested_sampler.run(key1, self.n, nu=self.nu, nu_err=self.nu_err)
        end_time = time.perf_counter()
        print(f"Completed in {end_time - start_time:.1f} seconds.")

        samples = nested_sampler.get_samples(key2, num_samples)
        weighted_samples = nested_sampler.get_weighted_samples()[0]

        # NOTE: relies on the private ``_results`` attribute of the sampler.
        results = nested_sampler._results
        num_weighted_samples = results.num_samples
        logX = results.log_X[:num_weighted_samples]
        logL = results.log_L_samples[:num_weighted_samples]
        logp = results.log_p[:num_weighted_samples]
        eff = results.sampler_efficiency[:num_weighted_samples]

        self.samples = self._expand_batch_dims(samples)
        self.weighted_samples = self._expand_batch_dims(weighted_samples)
        self.sample_metadata = {
            "method": "nested",
            "sampler": sampler,
            "num_likelihood_evals": int(results.num_likelihood_evaluations),
            "num_weighted_samples": int(num_weighted_samples),
            "logZ": float(results.logZ),  # evidence
            "logZ_err": float(results.logZerr),
            "ESS": float(results.ESS),
        }
        self.sample_stats = {
            "logX": logX,
            "logL": logL,
            "lp": logp,  # log joint posterior (i.e. sample weights)
            "sampler_efficiency": eff,
        }

    def _add_handlers_to_model(
        self,
        handlers: Optional[list] = None,
        reparam: Union[str, Reparam] = "auto",
    ) -> Model:
        """Wrap the model in each handler in turn.

        Args:
            handlers (list, optional): Handlers to apply to the model, in
                order. Defaults to None.
            reparam (str, or numpyro.infer.reparam.Reparam): Passed to
                :meth:`_init_handlers`. Defaults to 'auto'.

        Returns:
            Model: The model wrapped by all handlers.
        """
        model = self.model
        for handler in self._init_handlers(handlers, reparam=reparam):
            model = handler(model)
        return model

    def sample(
        self,
        num_samples: int = 1000,
        method: str = "nested",
        handlers: Optional[list] = None,
        reparam="auto",
        **kwargs,
    ):
        """Sample from the posterior using the chosen method.

        Args:
            num_samples (int): Number of samples after warmup.
            method (str): Sampling method, choose from ['mcmc', 'nested'].
            handlers (list, optional): Handlers to apply to the model during
                inference.
            reparam (str, or numpyro.infer.reparam.Reparam): Default is
                'auto' will automatically reparameterise the model to
                improve sampling during MCMC.
            **kwargs: Keyword arguments to pass to the sampling method.

        Raises:
            ValueError: If ``method`` is not one of ['mcmc', 'nested'].
        """
        # Add handlers to model
        model = self._add_handlers_to_model(handlers=handlers, reparam=reparam)

        if method == "mcmc":
            self.run_mcmc(
                model,
                num_samples=num_samples,
                **kwargs,
            )
        elif method == "nested":
            self.run_nested(model, num_samples=num_samples, **kwargs)
        else:
            raise ValueError(
                f"Method '{method}' not one of ['mcmc', 'nested']."
            )
        # TODO: diagnostics and warnings

    def find_map(
        self,
        num_steps: int = 10000,
        handlers: Optional[list] = None,
        reparam: Union[str, hdl.reparam] = "auto",
        svi_kwargs: Optional[dict] = None,
    ):
        """EXPERIMENTAL: find MAP.

        Runs SVI with an ``AutoDelta`` guide and stores the losses, guide
        and parameters on the instance for use by :meth:`map_predictive`.

        Args:
            num_steps (int): Number of SVI steps. Defaults to 10000.
            handlers (list, optional): Handlers to apply to the model.
                Defaults to None.
            reparam (str, or numpyro.handlers.reparam): Passed to
                :meth:`_init_handlers`. Defaults to 'auto'.
            svi_kwargs (dict, optional): Keyword arguments to pass to the
                SVI instance. The keys 'optim' and 'loss' override the
                default optimiser and loss. The dictionary is not modified.
        """
        # Work on a copy so that the caller's dictionary is left intact.
        svi_kwargs = dict(svi_kwargs or {})

        model = self._add_handlers_to_model(handlers=handlers, reparam=reparam)
        guide = numpyro.infer.autoguide.AutoDelta(model)

        optim = svi_kwargs.pop("optim", numpyro.optim.Minimize())
        loss = svi_kwargs.pop("loss", numpyro.infer.Trace_ELBO())
        map_svi = SVI(model, guide, optim, loss=loss, **svi_kwargs)

        rng_key, self._rng_key = random.split(self._rng_key)
        map_result = map_svi.run(
            rng_key, num_steps, self.n, nu=self.nu, nu_err=self.nu_err
        )

        self._map_loss = map_result.losses
        self._map_guide = map_svi.guide
        self._map_params = map_result.params

    def predictive(
        self, n, nu=None, nu_err=None, n_pred=None, **kwargs
    ) -> dict:
        """Draw predictive samples from the model.

        Args:
            n: First positional argument to pass to the model.
            nu (optional): Passed to the model as ``nu``.
            nu_err (optional): Passed to the model as ``nu_err``.
            n_pred (optional): Passed to the model as ``n_pred``.
            **kwargs: Kwargs to pass to Predictive. The keys
                'posterior_samples', 'num_samples', 'batch_ndims' (default
                2) and 'return_sites' are handled explicitly. If
                'return_sites' is not given, all deterministic sites are
                returned along with the non-observed sample sites which are
                not already in 'posterior_samples'.

        Returns:
            dict: Predictive samples.
        """
        posterior_samples = kwargs.pop("posterior_samples", None)
        num_samples = kwargs.pop("num_samples", None)
        batch_ndims = kwargs.pop("batch_ndims", 2)
        return_sites = kwargs.pop("return_sites", None)

        posterior = {} if posterior_samples is None else posterior_samples

        if return_sites is None:
            return_sites = []
            for name, site in self.get_trace(pred=True).items():
                # Only return non-observed sample sites not in samples and
                # all deterministic sites.
                if site["type"] == "sample":
                    if not site["is_observed"] and name not in posterior:
                        return_sites.append(name)
                elif site["type"] == "deterministic":
                    return_sites.append(name)

        predictive = Predictive(
            self.model,
            posterior_samples=posterior_samples,
            num_samples=num_samples,
            return_sites=return_sites,
            batch_ndims=batch_ndims,
            **kwargs,
        )

        if predictive.batch_ndims == 0:
            # Fix bug in Predictive for computing batch shape
            predictive._batch_shape = ()

        rng_key, self._rng_key = random.split(self._rng_key)
        samples = predictive(rng_key, n, nu=nu, nu_err=nu_err, n_pred=n_pred)
        return samples

    def prior_predictive(self, num_samples: int = 1000, **kwargs):
        """Sample from the prior predictive.

        The result is stored in :attr:`prior_predictive_samples`.

        Args:
            num_samples (int): Number of samples to take from the prior.
            **kwargs: Keyword arguments to pass to Predictive instance.
        """
        self.prior_predictive_samples = self.predictive(
            self.n, n_pred=self.n_pred, num_samples=num_samples, **kwargs
        )

    def posterior_predictive(self, **kwargs):
        """Sample from the posterior predictive.

        Uses :attr:`samples` as the posterior samples and stores the result
        in :attr:`predictive_samples`.

        Args:
            **kwargs: Keyword arguments to pass to Predictive instance.
        """
        self.predictive_samples = self.predictive(
            self.n,
            nu=self.nu,
            nu_err=self.nu_err,
            n_pred=self.n_pred,
            posterior_samples=self.samples,
            **kwargs,
        )

    def map_predictive(self, **kwargs) -> xarray.Dataset:
        """EXPERIMENTAL: Get predictive from MAP.

        Args:
            **kwargs: Keyword arguments to pass to :meth:`predictive`.

        Returns:
            xarray.Dataset: One variable per predictive site, with
            dimensions and coordinates taken from the model trace.

        Raises:
            RuntimeError: If :meth:`find_map` has not been run.
        """
        guide = self._map_guide
        params = self._map_params
        if guide is None or params is None:
            # Without a guide the call below would silently return a single
            # draw from the prior rather than the MAP prediction.
            raise RuntimeError(
                "No MAP estimate available; call find_map() first."
            )

        map_pred = self.predictive(
            self.n,
            n_pred=self.n_pred,
            guide=guide,
            params=params,
            num_samples=1,
            batch_ndims=0,
            **kwargs,
        )

        dims, coords = self._get_dims()
        ds = {
            k: xarray.DataArray(v, dims=dims.get(k))
            for k, v in map_pred.items()
        }
        return xarray.Dataset(ds, coords=coords)

    def get_data(self) -> az.InferenceData:
        """Get inference data.

        Every variable in every group is given the attributes 'unit',
        'symbol' and 'is_circular'. Variables with names ending in '_pred'
        take these from the variable of the same name without that suffix.

        Returns:
            arviz.InferenceData: Inference data.
        """
        dims, coords = self._get_dims()

        observed_data = {"nu": self.nu}
        constant_data = {
            "nu_err": np.zeros_like(self.nu)
            if self.nu_err is None
            else self.nu_err
        }

        data = az.from_dict(
            posterior=self.samples,
            prior_predictive=self.prior_predictive_samples,
            posterior_predictive=self.predictive_samples,
            sample_stats=self.sample_stats,
            observed_data=observed_data,
            constant_data=constant_data,
            dims=dims,
            coords=coords,
        )

        # Add sample metadata info
        if self.sample_stats is not None:
            data.sample_stats.attrs.update(self.sample_metadata)

        if self.sample_metadata.get("method", None) == "nested":
            # The weights are just the logP in sampler stats
            with warnings.catch_warnings():
                # Catch user warnings
                warnings.filterwarnings("ignore", category=UserWarning)
                data.add_groups(
                    {"weighted_posterior": self.weighted_samples},
                    coords=coords,
                    dims=dims,
                )

        circ_var_names = set(self.get_circ_var_names())
        suffix = "_pred"

        # Add unit, symbol and circular attributes to groups
        for group in data.groups():
            for key in data[group].keys():
                if key.endswith(suffix):
                    sub_key = key[: -len(suffix)]
                else:
                    sub_key = key

                unit = self.model.units.get(sub_key, u.Unit())
                sym = self.model.symbols.get(sub_key, "")
                circ = 1 if sub_key in circ_var_names else 0

                attrs = data[group][key].attrs
                attrs["unit"] = unit.to_string()
                attrs["symbol"] = sym
                attrs["is_circular"] = circ

        return data