"""File containing the copulAX implementation of the multivariate normal
distribution."""
import jax.numpy as jnp
from jax import random
from jax import Array
from jax.typing import ArrayLike
from copulax._src._distributions import Multivariate
from copulax._src.multivariate._utils import _multivariate_input
from copulax._src._utils import _resolve_key
from copulax._src.multivariate._shape import cov
class MvtNormal(Multivariate):
    r"""The multivariate normal / Gaussian distribution is a
    generalization of the univariate normal distribution to d > 1
    dimensions.

    https://en.wikipedia.org/wiki/Multivariate_normal_distribution

    .. math::

        f(x|\mu, \Sigma) = \frac{1}{(2\pi)^{d/2}|\Sigma|^{1/2}} \exp\left(-\frac{1}{2}(x - \mu)^T \Sigma^{-1} (x - \mu)\right)

    where :math:`\mu` is the mean vector, :math:`\Sigma` the
    variance-covariance matrix of the data distribution and :math:`d`
    the number of dimensions.
    """

    # Optionally stored parameters; both stay ``None`` until supplied.
    mu: Array = None
    sigma: Array = None

    def __init__(self, name="Mvt-Normal", *, mu=None, sigma=None):
        """Initialize with optional stored parameters ``mu`` and ``sigma``.

        Args:
            name: Name passed on to the parent class initializer.
            mu: Optional mean vector; converted to a float array when
                given, otherwise stored as ``None``.
            sigma: Optional variance-covariance matrix; converted to a
                float array when given, otherwise stored as ``None``.
        """
        super().__init__(name)
        self.mu = jnp.asarray(mu, dtype=float) if mu is not None else None
        self.sigma = jnp.asarray(sigma, dtype=float) if sigma is not None else None

    @property
    def _stored_params(self):
        """Return stored parameters dict if all are set, else None."""
        if self.mu is None or self.sigma is None:
            return None
        return {"mu": self.mu, "sigma": self.sigma}

    def _classify_params(self, params: dict) -> dict:
        """Classify parameters into vector and shape groups.

        ``mu`` is declared as a vector parameter and ``sigma`` as a
        (symmetric) shape parameter; the classification itself is done
        by the parent class.
        """
        return super()._classify_params(
            params=params,
            vector_names=("mu",),
            shape_names=("sigma",),
            symmetric_shape_names=("sigma",),
        )

    def _params_dict(self, mu: ArrayLike, sigma: ArrayLike) -> dict:
        """Construct a normalized parameters dict from ``mu`` and ``sigma``."""
        d: dict = {"mu": mu, "sigma": sigma}
        return self._args_transform(d)

    def _params_to_tuple(self, params: dict) -> tuple:
        """Extract `(mu, sigma)` tuple from a parameters dict."""
        params = self._args_transform(params)
        return params["mu"], params["sigma"]

    def example_params(self, dim: int = 3, *args, **kwargs) -> dict:
        r"""Example parameters for the multivariate normal distribution.

        This is a two parameter family, defined by the mean / location
        vector `mu` and the variance-covariance matrix `sigma`. The
        example uses a zero mean of shape ``(dim, 1)`` and the
        ``dim x dim`` identity matrix as covariance.

        Args:
            dim: int, number of dimensions of the multivariate normal
                distribution. Default is 3.
            *args: Accepted but not used by this method.
            **kwargs: Accepted but not used by this method.

        Returns:
            dict: Parameters dict with keys ``'mu'`` and ``'sigma'``.
        """
        return self._params_dict(mu=jnp.zeros((dim, 1)), sigma=jnp.eye(dim, dim))

    def support(self, params: dict = None) -> Array:
        """Return the support of the distribution: `(-inf, inf)` per dimension."""
        return super().support(params=params)

    def logpdf(self, x: ArrayLike, params: dict = None) -> Array:
        """Log-probability density function of the multivariate normal.

        Args:
            x: Input data of shape (n, d).
            params: Distribution parameters with keys 'mu' and 'sigma'.

        Returns:
            Array of log-density values with shape (n, 1).
        """
        # NOTE(review): presumably falls back to the stored parameters when
        # ``params`` is None -- confirm against ``_resolve_params``.
        params = self._resolve_params(params)
        x, yshape, n, d = _multivariate_input(x)
        mu, sigma = self._params_to_tuple(params)

        # Log normalizing constant: -0.5 * (d * log(2 * pi) + log|det(sigma)|).
        # ``slogdet`` returns ``(sign, logabsdet)``; only the log term is used.
        const: Array = -0.5 * (
            d * jnp.log(2 * jnp.pi) + jnp.linalg.slogdet(sigma)[1]
        )

        # Q is supplied by the inherited ``_calc_Q`` helper; it plays the role
        # of the quadratic form (x - mu)^T sigma^{-1} (x - mu) in the density.
        sigma_inv: Array = jnp.linalg.inv(sigma)
        Q: Array = self._calc_Q(x=x, mu=mu, sigma_inv=sigma_inv)

        logpdf: Array = -0.5 * Q + const
        return logpdf.reshape(yshape)

    # sampling
    def rvs(self, size: int, params: dict = None, key=None) -> Array:
        """Generate random samples from the multivariate normal.

        Args:
            size: Number of samples to draw.
            params: Distribution parameters.
            key: JAX random key.

        Returns:
            Array of shape (size, d).
        """
        params = self._resolve_params(params)
        key = _resolve_key(key)
        mu, sigma = self._params_to_tuple(params)
        # The sampler expects a 1-D mean vector, hence the flatten.
        return random.multivariate_normal(
            key=key, mean=mu.flatten(), cov=sigma, shape=(size,)
        )

    # stats
    def stats(self, params: dict = None) -> dict:
        """Compute distribution statistics (mean, median, mode, cov, skewness).

        Args:
            params: Distribution parameters with keys 'mu' and 'sigma'.

        Returns:
            dict: ``mean``, ``median`` and ``mode`` all equal ``mu``,
            ``cov`` equals ``sigma`` and ``skewness`` is an array of
            zeros shaped like ``mu``.
        """
        params = self._resolve_params(params)
        mu, sigma = self._params_to_tuple(params)
        return {
            "mean": mu,
            "median": mu,
            "mode": mu,
            "cov": sigma,
            "skewness": jnp.zeros_like(mu),
        }

    # fitting
    _supported_methods = frozenset({"mle"})

    def fit(
        self, x: ArrayLike, sigma_method: str = "pearson", *args, name: str = None, **kwargs
    ) -> dict:
        r"""Fit the multivariate normal to data via **closed-form** MLE:
        :math:`\hat\mu = \operatorname{mean}(x)`, taken over the rows
        (observations) of ``x`` so that there is one value per
        dimension, and :math:`\hat\Sigma` via
        :func:`copulax.multivariate.cov` using the estimator chosen by
        ``sigma_method``.

        Note:
            If you intend to jit wrap this function, ensure that
            ``sigma_method`` is a static argument.

        Args:
            x: Input data of shape ``(n, d)``.
            sigma_method: Covariance estimator name forwarded to
                :func:`copulax.multivariate.cov` (default
                ``'pearson'``).
            *args: Accepted but not used by this method.
            name: Optional custom name for the fitted instance.
            **kwargs: Accepted but not used by this method.

        Returns:
            MvtNormal: A fitted ``MvtNormal`` instance.
        """
        # Only the normalized data array is needed from the input helper.
        x, _, _, _ = _multivariate_input(x)
        mu: Array = jnp.mean(x, axis=0)
        sigma: Array = cov(x=x, method=sigma_method)
        params = self._params_dict(mu=mu, sigma=sigma)
        return self._fitted_instance(params, name=name)
# Module-level instance of the distribution, created without stored
# parameters (``mu`` and ``sigma`` are left at their ``None`` defaults).
mvt_normal = MvtNormal(name="Mvt-Normal")