"""File containing the copulAX implementation of the multivariate
student-t distribution."""
import jax.numpy as jnp
import jax.nn as jnn
from jax import lax, random, jit
from jax import Array
from jax.typing import ArrayLike
from jax.scipy import special
from copulax._src._distributions import NormalMixture
from copulax._src.typing import Scalar
from copulax._src.multivariate._utils import _multivariate_input
from copulax._src._utils import _resolve_key
from copulax._src.stats import kurtosis
from copulax._src.multivariate._shape import cov
from copulax._src.univariate.ig import ig
# Small positive offset added to ``softplus(raw_nu)`` when the degrees of
# freedom are rebuilt from the unconstrained optimizer variable (see
# ``MvtStudentT._reconstruct_ldmle_params``), so that ``nu`` stays strictly
# positive even if the softplus underflows to zero.
_NU_EPS = 1e-8
class MvtStudentT(NormalMixture):
    r"""The multivariate student-t distribution is a generalization of
    the univariate student-t distribution to d > 1 dimensions.

    https://en.wikipedia.org/wiki/Multivariate_t-distribution

    :math:`\mu` is the mean vector and :math:`\sigma` the shape matrix,
    which for this parameterization is not the variance-covariance
    matrix of the data distribution (``stats`` reports the covariance as
    :math:`\frac{\nu}{\nu - 2}\sigma` when :math:`\nu > 2`).
    :math:`\nu` is the degrees of freedom parameter.
    """

    # Optional stored parameters; each stays ``None`` unless it is passed
    # to ``__init__``.
    nu: Array = None
    mu: Array = None
    sigma: Array = None

    def __init__(self, name="Mvt-Student-T", *, nu=None, mu=None, sigma=None):
        """Initialize with optional stored parameters ``nu``, ``mu``, and ``sigma``.

        Args:
            name: Name forwarded to the base-class constructor.
            nu: Optional degrees of freedom; stored as a 0-d float array.
            mu: Optional location vector; stored as a float array.
            sigma: Optional shape matrix; stored as a float array.
        """
        super().__init__(name)
        # ``nu`` is forced to a 0-d array so it always behaves as a scalar.
        self.nu = jnp.asarray(nu, dtype=float).reshape(()) if nu is not None else None
        self.mu = jnp.asarray(mu, dtype=float) if mu is not None else None
        self.sigma = jnp.asarray(sigma, dtype=float) if sigma is not None else None

    @property
    def _stored_params(self):
        """Return stored parameters dict if all are set, else None."""
        if self.nu is None or self.mu is None or self.sigma is None:
            return None
        return {"nu": self.nu, "mu": self.mu, "sigma": self.sigma}

    def _classify_params(self, params: dict) -> dict:
        """Classify parameters into scalar, vector, and shape groups.

        ``nu`` is declared as the scalar parameter, ``mu`` as the vector
        parameter and ``sigma`` as the (symmetric) shape parameter; the
        grouping itself is delegated to the base class.
        """
        return super()._classify_params(
            params=params,
            scalar_names=("nu",),
            vector_names=("mu",),
            shape_names=("sigma",),
            symmetric_shape_names=("sigma",),
        )

    def _params_dict(self, nu: Scalar, mu: ArrayLike, sigma: ArrayLike) -> dict:
        """Construct a normalized parameters dict from ``nu``, ``mu``, and ``sigma``.

        The raw values are passed through ``self._args_transform`` before
        being returned.
        """
        d: dict = {"nu": nu, "mu": mu, "sigma": sigma}
        return self._args_transform(d)

    def _params_to_tuple(self, params: dict) -> tuple:
        """Extract ``(nu, mu, sigma)`` tuple from a parameters dict.

        The dict is first normalized with ``self._args_transform``.
        """
        params = self._args_transform(params)
        return params["nu"], params["mu"], params["sigma"]

    def example_params(self, dim: int = 3, *args, **kwargs) -> dict:
        r"""Example parameters for the multivariate student-t distribution.

        This is a three parameter family, defined by the degrees of
        freedom scalar ``nu``, the mean / location vector ``mu`` and the
        shape matrix ``sigma``.

        The example uses ``nu = 2.5``, a zero location vector of shape
        ``(dim, 1)`` and the ``dim x dim`` identity as shape matrix.

        Args:
            dim: int, number of dimensions of the multivariate student-t
                distribution. Default is 3.

        Returns:
            Parameters dict with keys ``"nu"``, ``"mu"`` and ``"sigma"``.
        """
        return self._params_dict(
            nu=2.5, mu=jnp.zeros((dim, 1)), sigma=jnp.eye(dim, dim)
        )

    def support(self, params: dict = None) -> Array:
        """Return the support: ``(-inf, inf)`` per dimension.

        The computation is delegated to the base class.
        """
        return super().support(params=params)

    def _stable_logpdf(self, stability: Scalar, x: ArrayLike, params: dict) -> Array:
        """Numerically stable log-PDF of the multivariate student-t.

        Args:
            stability: Small constant added inside the ``log(pi * nu)``
                normalizing term for numerical stability.
            x: Input data of shape (n, d).
            params: Distribution parameters.

        Returns:
            Array of log-density values, reshaped to the output shape
            reported by ``_multivariate_input``.
        """
        x, yshape, _, d = _multivariate_input(x)
        nu, mu, sigma = self._params_to_tuple(params)

        s: Scalar = 0.5 * (nu + d)
        sigma_inv: Array = jnp.linalg.inv(sigma)
        # Q is produced by the base-class helper from x, mu and sigma^{-1}.
        Q: Array = self._calc_Q(x=x, mu=mu, sigma_inv=sigma_inv)
        # slogdet returns (sign, log|det|); only the log-magnitude is used.
        log_det_sigma: Scalar = jnp.linalg.slogdet(sigma)[1]

        logpdf: Array = (
            lax.lgamma(s)
            - lax.lgamma(0.5 * nu)
            - 0.5 * d * jnp.log(jnp.pi * nu + stability)
            - 0.5 * log_det_sigma
            - s * jnp.log1p(Q / nu)
        )
        return logpdf.reshape(yshape)

    # sampling
    def rvs(self, size: int, params: dict = None, key: ArrayLike = None) -> Array:
        """Generate random samples via the normal-variance mixture.

        Sampling uses an inverse-gamma mixing variable W, drawn with
        ``alpha = beta = nu / 2``, and the base class normal-variance
        mixture sampler with a zero skewness vector.

        Args:
            size: Number of samples to draw.
            params: Distribution parameters.
            key: JAX random key.

        Returns:
            Array of shape (size, d).
        """
        params = self._resolve_params(params)
        key = _resolve_key(key)
        nu, mu, sigma = self._params_to_tuple(params)

        # Two independent keys: one for the mixing variable, one for the
        # mixture sampler.
        key, subkey = random.split(key)
        W: Array = ig.rvs(
            size=(size,), key=key, params={"alpha": 0.5 * nu, "beta": 0.5 * nu}
        )
        gamma: Array = jnp.zeros_like(mu)
        return super()._rvs(key=subkey, n=size, W=W, mu=mu, gamma=gamma, sigma=sigma)

    # stats
    def stats(self, params: dict = None) -> dict:
        """Compute distribution statistics (mean, median, mode, cov, skewness).

        The mean is ``mu`` when ``nu > 1`` and NaN otherwise; the
        covariance is ``nu / (nu - 2) * sigma`` when ``nu > 2`` and NaN
        otherwise. Median and mode are ``mu`` and the skewness is zero.

        Args:
            params: Distribution parameters.

        Returns:
            Dict with keys ``"mean"``, ``"median"``, ``"mode"``, ``"cov"``
            and ``"skewness"``.
        """
        params = self._resolve_params(params)
        nu, mu, sigma = self._params_to_tuple(params)

        mean: Array = jnp.where(nu > 1, mu, jnp.full_like(mu, jnp.nan))
        scale: Scalar = jnp.where(nu > 2, nu / (nu - 2), jnp.nan)
        # Named ``covariance`` so the module-level ``cov`` import is not shadowed.
        covariance: Array = scale * sigma
        return {
            "mean": mean,
            "median": mu,
            "mode": mu,
            "cov": covariance,
            "skewness": jnp.zeros_like(mu),
        }

    # fitting
    def _ldmle_inputs(self, d, x=None):
        """Generate initial parameter array and bounds for LD-MLE optimization.

        Only ``nu`` is optimized, through an unconstrained variable
        ``raw_nu`` that ``_reconstruct_ldmle_params`` maps back with a
        softplus; the bounds are therefore ``(-inf, inf)``.

        Args:
            d: Number of dimensions (columns of ``x``).
            x: Data array indexed as ``x[:, j]``. Despite the ``None``
                default it must be provided, since the starting value is
                estimated from the data.

        Returns:
            Tuple ``(bounds, params0)`` where ``bounds`` is a dict with
            ``"lower"`` / ``"upper"`` arrays of shape (1, 1) and
            ``params0`` is the length-1 array holding the initial
            ``raw_nu``.
        """
        lc = jnp.full((1, 1), -jnp.inf)
        uc = jnp.full((1, 1), jnp.inf)

        # MoM: average marginal excess kurtosis -> nu = 4 + 6/kappa
        kappas = jnp.array([kurtosis(x[:, j], fisher=True) for j in range(d)])
        kappa = jnp.mean(kappas)
        # kappa is floored at 0.06 and the estimate clipped to [2.5, 100].
        nu0 = jnp.clip(4.0 + 6.0 / jnp.maximum(kappa, 0.06), 2.5, 100.0)

        # Inverse softplus, written as x + log(1 - exp(-x)). The direct form
        # log(expm1(x)) overflows to inf in float32 once x exceeds ~88.7,
        # which the upper clip of 100 allows; this form is exact and finite
        # for every positive x.
        raw_nu0 = nu0 + jnp.log(-jnp.expm1(-nu0))

        params0: Array = jnp.array([raw_nu0])
        return {"lower": lc, "upper": uc}, params0

    def _reconstruct_ldmle_params(self, params_arr, loc, shape):
        """Reconstruct nu, mu, sigma from LD-MLE optimizer output.

        Args:
            params_arr: Optimizer output holding the single unconstrained
                ``raw_nu`` value.
            loc: Location vector, returned unchanged as ``mu``.
            shape: Matrix that is rescaled to give ``sigma``.

        Returns:
            Tuple ``(nu, loc, sigma)`` with ``nu = softplus(raw_nu) + eps``
            and ``sigma = (nu - 2) / nu * shape`` when ``nu > 2``, else
            ``sigma = shape``.
        """
        raw_nu: Scalar = params_arr.reshape(())
        nu: Scalar = jnn.softplus(raw_nu) + _NU_EPS
        # Inverse of the nu / (nu - 2) covariance factor used in ``stats``;
        # left at 1 when nu <= 2, where that factor is undefined.
        scale: Scalar = jnp.where(nu > 2, (nu - 2) / nu, 1.0)
        return nu, loc, scale * shape

    # Fitting methods accepted for this distribution.
    _supported_methods = frozenset({"mle"})
# Module-level instance of the distribution, created without stored
# parameters (``nu``, ``mu`` and ``sigma`` are left as ``None``).
mvt_student_t = MvtStudentT("Mvt-Student-T")