Source code for asterion.messengers
"""Messengers module.
"""
import numpy as np
from numpyro.primitives import Messenger, apply_stack
from typing import Callable, Optional, Dict
from numpy.typing import ArrayLike
[docs]class dimension(Messenger):
"""Context manager for a model dimension.
Args:
name (str): Name of the dimension.
size (int): Size of the dimension.
coords (:term:`array_like`, optional): Coordinates for points in the
dimension. Defaults to :code:`np.arange(size)`.
dim (int, optional): Where to place the dimension. Defaults to
:code:`-1` which corresponds to the rightmost dimension. Must be
negative.
"""
def __init__(
self,
name: str,
size: int,
coords: Optional[ArrayLike] = None,
dim: Optional[ArrayLike] = None,
):
self.name: str = name
self.size: int = size
self.dim: int = -1 if dim is None else dim
"""int: Location in which to insert the dimension."""
assert self.dim < 0
if coords is None:
coords = np.arange(self.size)
self.coords: np.ndarray = np.array(coords)
"""numpy.ndarray: Coordinates for the dimension."""
msg = self._get_message()
apply_stack(msg)
super().__init__()
def _get_message(self) -> dict:
msg = {
"name": self.name,
"type": "dimension",
"dim": self.dim,
"value": self.coords,
}
return msg
def __enter__(self) -> dict:
super().__enter__()
return self._get_message()
[docs] def process_message(self, msg: dict):
"""Process the message.
Args:
msg (dict): Message.
Raises:
ValueError: If the corresponding dimension of the site is of
incorrect size.
"""
if msg["type"] not in ("param", "sample", "deterministic"):
# We don't add dimensions to dimensions
return
if msg["value"] is None:
shape = ()
if "fn" in msg.keys():
sample_shape = msg["kwargs"].get("sample_shape", ())
shape = msg["fn"].shape(sample_shape)
else:
shape = msg["value"].shape
if "dims" not in msg.keys():
dims = [f"{msg['name']}_dim_{i}" for i in range(len(shape))]
msg["dims"] = dims
if "dim_stack" not in msg.keys():
msg["dim_stack"] = []
dim = self.dim
while dim in msg["dim_stack"]:
dim -= 1
msg["dim_stack"].append(dim)
msg["dims"][dim] = self.name
if shape[dim] != self.size:
raise ValueError(
f"Dimension {dim} of site '{msg['name']}' should have "
+ f"length {self.size}"
)