"""The results module contains functions for inspecting, summarising, and
tabulating inference data.
"""
from __future__ import annotations
import arviz as az
import astropy.units as u
import numpy as np
import pandas as pd
import xarray
from astropy.table import Table
from typing import Optional, Dict, Union, List, Tuple
# Public API of this module: the names exported by a star-import. The
# underscore-prefixed helpers defined below are deliberately left out.
__all__ = [
"get_dims",
"get_var_names",
"get_summary",
"get_table",
]
def _get_dim_vars(
    data, group: str = "posterior"
) -> Dict[Tuple[str, ...], List[str]]:
    """Group the variable names of an inference data group by dimensions.

    Args:
        data (arviz.InferenceData): Inference data object.
        group (str): Inference data group. Defaults to 'posterior'.

    Returns:
        dict[tuple[str, ...], list[str]]: Maps each distinct tuple of
            dimension names to the names of the variables which have exactly
            those dimensions. For groups holding samples the first two
            dimensions of every variable (chain and draw) are not counted.
    """
    # These groups hold data rather than samples, so their variables have no
    # leading chain and draw dimensions to strip.
    unsampled_groups = (
        "observed_data",
        "constant_data",
        "predictions_constant_data",
    )
    n_skip = 0 if group in unsampled_groups else 2

    dataset = data[group]
    dim_vars: Dict[Tuple[str, ...], List[str]] = {}
    for name in dataset.data_vars.keys():
        dims = dataset[name].dims[n_skip:]
        # Variable names are unique within a group, so a plain append can
        # never introduce a duplicate entry.
        dim_vars.setdefault(dims, []).append(name)
    return dim_vars
def get_dims(data, group: str = "posterior") -> List[Tuple[str]]:
    """List the distinct dimension groups found in an inference data group.

    Args:
        data (arviz.InferenceData): Inference data object.
        group (str): Inference data group. Defaults to 'posterior'.

    Returns:
        list[tuple]: One tuple of dimension names for every distinct
            combination of dimensions, i.e. the keys of the mapping built by
            :func:`_get_dim_vars` for this group.
    """
    # Iterating a dict yields its keys, so this lists the dimension tuples.
    return list(_get_dim_vars(data, group=group))
def get_var_names(
    data, group: str = "posterior", dims: Union[str, Tuple[str]] = "all"
) -> List[str]:
    """Get var names for a given group and dimensions.

    Args:
        data (arviz.InferenceData): Inference data object.
        group (str): Inference data group. Defaults to 'posterior'.
        dims (str, or tuple[str]): Dimensions by which to group variables.
            If 'all', returns variable names for all model dimensions. If a
            tuple (or other iterable) of dimension names, returns variable
            names in that dimension group. Any other string is treated as the
            name of a single dimension, e.g. 'n' is the same as ('n',).

    Returns:
        list[str]: Variable names for a given group and dimensions.

    Raises:
        KeyError: If no variable in the group has exactly the dimensions
            given by dims.
    """
    if isinstance(dims, str):
        if dims == "all":
            return list(data[group].data_vars.keys())
        # Dimension groups are keyed by tuples, so a bare name could never
        # match; promote it to a one-element tuple.
        dims = (dims,)
    else:
        # Lists and other iterables are unhashable or would not compare equal
        # to the tuple keys, so normalise them first.
        dims = tuple(dims)

    dim_vars = _get_dim_vars(data, group=group)
    try:
        return list(dim_vars[dims])
    except KeyError:
        # Same exception type as a plain lookup, but say what is available.
        raise KeyError(
            f"Dimensions {dims} not found in group '{group}'."
            f" Available dimensions are {list(dim_vars)}."
        ) from None
def _validate_var_names(
    data,
    group: str = "posterior",
    var_names: Optional[List[str]] = None,
    dims: Union[str, Tuple[str]] = "all",
) -> List[str]:
    """Resolve and check the variable names to use for a group and dims.

    Args:
        data (arviz.InferenceData): Inference data object.
        group (str): Inference data group. Defaults to 'posterior'.
        var_names (list[str], optional): Requested variable names. If None,
            every variable returned by :func:`get_var_names` for the group
            and dims is used.
        dims (str, or tuple[str]): 'all', or the dimension group the
            variables must belong to.

    Returns:
        list[str]: The requested names, or all available names when
            var_names is None.

    Raises:
        ValueError: If var_names is None and no variables are available, or
            if dims is not 'all' and any requested name is not available for
            those dims. Requested names are not checked when dims is 'all'.
    """
    candidates = get_var_names(data, group=group, dims=dims)
    restrict_dims = dims != "all"

    if var_names is not None:
        if restrict_dims:
            # Every requested name must belong to the chosen dimension group.
            missing = set(var_names).difference(set(candidates))
            if missing:
                raise ValueError(
                    f"Variable name(s) {missing} not available"
                    f" for dims {dims} in group '{group}'."
                )
        return var_names

    if len(candidates) > 0:
        return candidates

    # NOTE(review): get_var_names looks a dims tuple up with plain indexing,
    # so unknown dims appear to raise KeyError before reaching this point;
    # the dims-specific message below may be unreachable -- confirm.
    if restrict_dims:
        msg = f"No variables exist for dims {dims}" f" in group '{group}'."
    else:
        msg = f"No variables exist in group '{group}'."
    raise ValueError(msg)
def get_summary(
    data,
    group: str = "posterior",
    var_names: Optional[List[str]] = None,
    **kwargs,
) -> Union[xarray.Dataset, pd.DataFrame]:
    """Get a summary of the inference data for a chosen group.

    This wraps :func:`arviz.summary` with different defaults: the output
    format is 'xarray', no rounding is applied, only statistics (no
    diagnostics) are computed, and the statistics are the mean, the sample
    standard deviation (ddof=1) and the 0.16, 0.5 and 0.84 quantiles.

    Args:
        data (arviz.InferenceData): Inference data object.
        group (str): Inference data group to summarise. Defaults to
            'posterior'.
        var_names (list, optional): Names of the variables to summarise.
            Defaults to None (all variable names in the group).
        **kwargs: Keyword arguments to pass to :func:`arviz.summary`. The
            keys 'fmt', 'round_to', 'stat_funcs', 'extend', 'kind' and
            'circ_var_names' override the defaults chosen here.

    Returns:
        xarray.Dataset, or pandas.DataFrame: Summary of inference data.

    Raises:
        ValueError: If no variables exist in the group, or if the summary
            contains duplicated metric names.

    See Also:
        :func:`arviz.summary`: The function for which this wraps.
    """
    fmt = kwargs.pop("fmt", "xarray")
    round_to = kwargs.pop("round_to", "none")
    default_stat_funcs = {
        "mean": np.mean,
        "sd": lambda x: np.std(x, ddof=1),
        "16th": lambda x: np.quantile(x, 0.16),
        "50th": np.median,
        "84th": lambda x: np.quantile(x, 0.84),
    }
    stat_funcs = kwargs.pop("stat_funcs", default_stat_funcs)
    extend = kwargs.pop("extend", False)
    kind = kwargs.pop("kind", "stats")  # default just stats, no diagnostics

    var_names = _validate_var_names(data, group=group, var_names=var_names)

    # Variables whose attrs flag them as circular are summarised as such,
    # unless the caller supplies an explicit list.
    circ_var_names = [
        k for k in var_names if data[group][k].attrs.get("is_circular", 0) == 1
    ]
    circ_var_names = kwargs.pop("circ_var_names", circ_var_names)

    summary = az.summary(
        data,
        group=group,
        var_names=var_names,
        fmt=fmt,
        round_to=round_to,
        stat_funcs=stat_funcs,
        extend=extend,
        circ_var_names=circ_var_names,
        kind=kind,
        **kwargs,
    )

    # Check for duplicated metric names which are a pain to deal with. Where
    # the names live depends on the output format: a pandas summary has no
    # 'metric' attribute, so reading it unconditionally would fail.
    if isinstance(summary, pd.DataFrame):
        if fmt.lower() == "wide":
            metric_names = summary.columns
        else:
            # NOTE(review): in 'long' format the metric labels appear to sit
            # on the row index and may legitimately repeat once per element
            # of a multi-dimensional variable, so the check is skipped --
            # confirm against the arviz.summary documentation.
            metric_names = None
    else:
        metric_names = summary.metric

    if metric_names is not None:
        unique, counts = np.unique(metric_names, return_counts=True)
        is_dup = counts > 1
        if is_dup.any():
            dup = unique[is_dup].tolist()
            raise ValueError(f"Metric names {dup} are duplicated.")
    return summary
def get_table(
    data,
    *,
    dims: Tuple[str],
    group: str = "posterior",
    var_names: Optional[List[str]] = None,
    fmt: str = "pandas",
    round_to: Union[str, int] = "auto",
    **kwargs,
) -> Union[pd.DataFrame, Table]:
    """Get a table of results for parameters in data corresponding to a chosen
    model dimension.

    The table is built from the summary returned by :func:`get_summary`,
    with one column per variable.

    Args:
        data (arviz.InferenceData): Inference data object.
        dims (tuple[str]): The parameter dimensions for the table. E.g. pass
            () to return a table of 0-dimensional parameters in data, or pass
            ('n',) for 1-dimensional parameters along dimension 'n'.
        group (str): Group in data to tabulate. Defaults to
            'posterior'.
        var_names (list[str], optional): Variable names in data to show in
            table. By default all variables along the chosen dim are shown.
            Defaults to None.
        fmt (str): Table format, one of ['pandas', 'astropy'].
            Defaults to 'pandas'. Any value other than 'astropy' returns the
            pandas table.
        round_to (str, or int): Precision of table data. Defaults to
            'auto' which chooses the precision for each variable based on the
            error on the mean. Pass 'none' to disable rounding, or an int
            number of decimal places to apply to every variable.
        **kwargs: Keyword arguments to pass to :func:`get_summary`.

    Returns:
        pandas.DataFrame, or astropy.table.Table: Table of summary metrics
            for the chosen variables. For 'astropy' the index is moved into
            ordinary columns and each variable column takes its unit from the
            'unit' entry of the variable's attrs (dimensionless if absent).

    Raises:
        ValueError: If the variable names are not valid for dims and group,
            or if round_to is 'auto' and the summary has no 'sd' metric.
    """
    var_names = _validate_var_names(
        data, group=group, var_names=var_names, dims=dims
    )
    summary = get_summary(data, group=group, var_names=var_names, **kwargs)
    table = summary[var_names].to_dataframe()

    if round_to == "auto":
        # Rounds to the error on the mean. I.e. if mean_err is in range
        # (0.01, 0.1] then the metrics are rounded to 2 decimal places
        if "sd" not in summary.metric:
            raise ValueError(
                "Automatic rounding requires the standard deviation 'sd'."
            )
        # NOTE(review): the error on the mean uses the number of draws only,
        # not chains x draws -- confirm this is intended.
        mean_err = table.loc["sd"] / np.sqrt(data[group].draw.size)

        # A zero, NaN or infinite error has no meaningful order of magnitude
        # and cannot be cast to an integer precision. Mask those entries so
        # that the affected variables are simply left unrounded.
        usable = (mean_err > 0) & np.isfinite(mean_err)
        log_err = np.log10(mean_err.where(usable))
        precision = np.trunc(log_err) - 1  # truncates towards zero
        if isinstance(table.index, pd.MultiIndex):
            # Several elements per variable: choose min precision = max
            # decimal precision for each column (NaN entries are skipped).
            precision = precision.min()
        decimals = (-precision.dropna()).astype(int)
        table = table.round(decimals)
    elif round_to != "none":
        table = table.round(round_to)

    if fmt == "astropy":
        units = {
            k: u.Unit(data[group][k].attrs.get("unit", "")) for k in var_names
        }
        table = Table.from_pandas(table.reset_index(), units=units)
    return table