# Source code for copulax._src.multivariate.mvt_gh

"""File containing the copulAX implementation of the multivariate
generalized hyperbolic (GH) distribution."""

import jax.numpy as jnp
import jax.nn as jnn
from jax import lax, random, jit, value_and_grad
from jax import Array
from jax.typing import ArrayLike
from jax.scipy import special

from copulax._src._distributions import NormalMixture
from copulax._src.special import log_kv
from copulax._src.typing import Scalar
from copulax._src.multivariate._utils import _multivariate_input
from copulax._src._utils import _resolve_key
from copulax._src.multivariate._shape import cov, _corr
from copulax._src.multivariate._normal_mixture import (
    prepare_sample_cov,
    forward_reparam,
    invert_gamma_to_z,
)
from copulax._src.univariate.gig import gig
from copulax._src.univariate.gh import GH
from copulax.special import kv

_POS_EPS = 1e-8
_POS_INIT = 1.0


class MvtGH(NormalMixture):
    r"""The multivariate generalized hyperbolic (GH) distribution is a
    generalization of the univariate GH distribution to d > 1 dimensions.
    This is a flexible, continuous 6-parameter family of distributions that
    can model a variety of data behaviors, including heavy tails and
    skewness. It contains a number of popular distributions as special
    cases, including the multivariate normal, multivariate student-t and
    multivariate skewed-T distributions.

    The distribution is a normal mean-variance mixture whose mixing
    variable W follows a GIG(lamb, chi, psi) distribution. We adopt the
    parameterization used by McNeil et al. (2005).
    """

    # Scalar parameters of the GIG mixing distribution (0-d arrays when set).
    lamb: Array = None
    chi: Array = None
    psi: Array = None
    # Location vector, skewness vector and dispersion (shape) matrix.
    mu: Array = None
    gamma: Array = None
    sigma: Array = None

    def __init__(
        self,
        name="Mvt-GH",
        *,
        lamb=None,
        chi=None,
        psi=None,
        mu=None,
        gamma=None,
        sigma=None,
    ):
        """Initialize with optional stored parameters.

        Args:
            name: Name of the distribution instance.
            lamb: Optional scalar GIG index parameter.
            chi: Optional scalar GIG parameter.
            psi: Optional scalar GIG parameter.
            mu: Optional location vector.
            gamma: Optional skewness vector.
            sigma: Optional dispersion matrix.

        Scalars are stored as float 0-d arrays; ``mu``, ``gamma`` and
        ``sigma`` are stored as float arrays with their given shape. Any
        parameter left unset stays ``None``.
        """
        super().__init__(name)
        self.lamb = (
            jnp.asarray(lamb, dtype=float).reshape(()) if lamb is not None else None
        )
        self.chi = (
            jnp.asarray(chi, dtype=float).reshape(()) if chi is not None else None
        )
        self.psi = (
            jnp.asarray(psi, dtype=float).reshape(()) if psi is not None else None
        )
        self.mu = jnp.asarray(mu, dtype=float) if mu is not None else None
        self.gamma = jnp.asarray(gamma, dtype=float) if gamma is not None else None
        self.sigma = jnp.asarray(sigma, dtype=float) if sigma is not None else None

    @property
    def _stored_params(self):
        """Return stored parameters dict if all are set, else None."""
        if any(
            v is None
            for v in [self.lamb, self.chi, self.psi, self.mu, self.gamma, self.sigma]
        ):
            return None
        return {
            "lamb": self.lamb,
            "chi": self.chi,
            "psi": self.psi,
            "mu": self.mu,
            "gamma": self.gamma,
            "sigma": self.sigma,
        }

    def _classify_params(self, params: dict) -> tuple:
        """Classify parameters into scalar, vector, and shape groups."""
        # return (lamb, chi, psi,), (mu, gamma), (sigma,)
        return super()._classify_params(
            params=params,
            scalar_names=("lamb", "chi", "psi"),
            vector_names=("mu", "gamma"),
            shape_names=("sigma",),
            symmetric_shape_names=("sigma",),
        )

    def _params_dict(
        self,
        lamb: Scalar,
        chi: Scalar,
        psi: Scalar,
        mu: ArrayLike,
        gamma: ArrayLike,
        sigma: ArrayLike,
    ) -> dict:
        """Construct a normalized parameters dict from all six GH parameters."""
        d: dict = {
            "lamb": lamb,
            "chi": chi,
            "psi": psi,
            "mu": mu,
            "gamma": gamma,
            "sigma": sigma,
        }
        return self._args_transform(d)

    def _params_to_tuple(self, params: dict) -> tuple:
        """Extract ``(lamb, chi, psi, mu, gamma, sigma)`` from a params dict."""
        params = self._args_transform(params)
        return (
            params["lamb"],
            params["chi"],
            params["psi"],
            params["mu"],
            params["gamma"],
            params["sigma"],
        )

    def example_params(self, dim: int = 3, *args, **kwargs) -> dict:
        r"""Example parameters for the multivariate GH distribution.

        This is a six parameter family, defined by the scalar parameters
        `lamb`, `chi`, `psi`, the location vector `mu`, the skewness vector
        `gamma` and the shape matrix `sigma`.

        Args:
            dim: int, number of dimensions of the multivariate GH
                distribution. Default is 3.
        """
        return self._params_dict(
            lamb=0.0,
            chi=1.0,
            psi=1.0,
            mu=jnp.zeros((dim, 1)),
            gamma=jnp.zeros((dim, 1)),
            sigma=jnp.eye(dim, dim),
        )

    def support(self, params: dict = None) -> Array:
        """Return the support: ``(-inf, inf)`` per dimension."""
        return super().support(params=params)

    @staticmethod
    def _logpdf_core(
        stability: Scalar,
        x: Array,
        lamb: Scalar,
        chi: Scalar,
        psi: Scalar,
        mu: Array,
        gamma: Array,
        sigma: Array,
    ) -> Array:
        """Core log-PDF computation for the multivariate GH distribution.

        This is a static, pure function suitable for use inside
        ``value_and_grad``. Both the public ``_stable_logpdf`` and the ECME
        shape-parameter gradient call this.

        Args:
            stability: Small constant added inside logarithms/denominators
                for numerical stability.
            x: Input data of shape (n, d).
            lamb: Shape parameter lamb.
            chi: Shape parameter chi.
            psi: Shape parameter psi.
            mu: Location vector of shape (d, 1).
            gamma: Skewness vector of shape (d, 1).
            sigma: Dispersion (shape) matrix of shape (d, d).

        Returns:
            Array of log-density values with shape (n,).
        """
        d: int = x.shape[1]
        sigma_inv: Array = jnp.linalg.inv(sigma)
        diff: Array = x - mu.flatten()  # (n, d)
        # (x - mu)' Sigma^{-1} feeds both the Mahalanobis term Q and the
        # skewness term H; compute it once instead of twice.
        diff_sinv: Array = diff @ sigma_inv  # (n, d)

        Q: Array = chi + jnp.sum(diff_sinv * diff, axis=1)  # chi + Mahalanobis
        R: Array = psi + (gamma.T @ sigma_inv @ gamma).squeeze()
        QR: Array = Q * R
        H: Array = (diff_sinv @ gamma).flatten()  # (x - mu)' Sigma^{-1} gamma
        log_det_sigma: Scalar = jnp.linalg.slogdet(sigma)[1]
        s: Scalar = lamb - d / 2.0

        # Log of the normalising constant (depends on parameters only).
        log_c: Scalar = (
            0.5 * lamb * lax.log((psi / (chi + stability)) + stability)
            - s * lax.log(R + stability)
            - 0.5 * d * lax.log(2 * jnp.pi)
            - 0.5 * log_det_sigma
            - log_kv(lamb, lax.sqrt(chi * psi))
        )
        logpdf: Array = (
            log_c + log_kv(s, lax.sqrt(QR)) + H + 0.5 * s * lax.log(QR + stability)
        )
        return logpdf

    def _stable_logpdf(self, stability: Scalar, x: ArrayLike, params: dict) -> Array:
        """Numerically stable log-PDF of the multivariate GH distribution.

        Args:
            stability: Small constant for numerical stability.
            x: Input data of shape (n, d).
            params: Distribution parameters.

        Returns:
            Array of log-density values reshaped to the output shape reported
            by ``_multivariate_input`` (documented as (n, 1)).
        """
        x, yshape, _, _ = _multivariate_input(x)
        lamb, chi, psi, mu, gamma, sigma = self._params_to_tuple(params)
        logpdf = MvtGH._logpdf_core(stability, x, lamb, chi, psi, mu, gamma, sigma)
        return logpdf.reshape(yshape)

    # sampling
    def rvs(self, size: int, params: dict = None, key: ArrayLike = None) -> Array:
        """Generate random samples via the GIG normal-variance mixture.

        Args:
            size: Number of samples to draw.
            params: Distribution parameters.
            key: JAX random key.

        Returns:
            Array of shape (size, d).
        """
        params = self._resolve_params(params)
        key = _resolve_key(key)
        lamb, chi, psi, mu, gamma, sigma = self._params_to_tuple(params)

        # Independent streams for the mixing variable and the normal part.
        gig_key, mixture_key = random.split(key)
        W: Array = gig.rvs(
            size=(size,), key=gig_key, params={"lamb": lamb, "chi": chi, "psi": psi}
        )
        return super()._rvs(
            key=mixture_key, n=size, W=W, mu=mu, gamma=gamma, sigma=sigma
        )

    # stats
    def stats(self, params: dict = None) -> dict:
        """Compute distribution statistics using GIG mixing moments."""
        params = self._resolve_params(params)
        lamb, chi, psi, mu, gamma, sigma = self._params_to_tuple(params)
        gig_stats = gig.stats(params={"lamb": lamb, "chi": chi, "psi": psi})
        return self._stats(w_stats=gig_stats, mu=mu, gamma=gamma, sigma=sigma)

    # fitting — ECME algorithm (McNeil et al. 2005, Algorithm 3.14)
    @staticmethod
    @jit
    def _nll_shape_value_and_grad(
        shape_params: Array, mu: Array, gamma: Array, sigma: Array, x: Array
    ) -> tuple:
        """Compute NLL and gradient w.r.t. shape parameters [lamb, chi, psi].

        This implements the ECME variant of CM-step 2 from McNeil et al.
        (2005, p. 83): "instead of maximizing Q2 we may maximize the original
        likelihood (3.33) with respect to lamb, chi and psi with the other
        parameters held fixed."

        Args:
            shape_params: Array of shape (3,) containing [lamb, chi, psi].
            mu: Location vector of shape (d, 1).
            gamma: Skewness vector of shape (d, 1).
            sigma: Dispersion matrix of shape (d, d).
            x: Data array of shape (n, d).

        Returns:
            Tuple of (nll_value, gradient) where gradient has shape (3,).
            The NLL is the mean (not the sum) over observations.
        """

        def _nll(sp, mu, gamma, sigma, x):
            lamb, chi, psi = sp
            logpdf = MvtGH._logpdf_core(1e-30, x, lamb, chi, psi, mu, gamma, sigma)
            return -jnp.mean(logpdf)

        return value_and_grad(_nll)(shape_params, mu, gamma, sigma, x)

    @staticmethod
    def _em_body(
        carry: tuple,
        _: None,
        x: Array,
        log_det_S: Scalar,
        lr: float,
        shape_steps: int,
    ) -> tuple:
        """Single ECME iteration following McNeil et al. Algorithm 3.14.

        Notation follows the book (eq. 3.37):
            - delta_i = E[W_i^{-1} | X_i; theta^{[k]}]
            - eta_i = E[W_i | X_i; theta^{[k]}]

        Steps:
            (2) E-step — compute weights delta_i, eta_i from posterior
                W_i | X_i ~ GIG(lamb - d/2, chi + Q_i, psi + R)
            (3) Update gamma in closed form (the book's symmetric variant,
                which fixes gamma = 0, is not used here)
            (4) Update mu, Psi, then Sigma with determinant constraint
                ``|Sigma| = |S|``
            (5)-(6) CM-step 2 — ECME: maximize observed log-likelihood
                w.r.t. (lamb, chi, psi) via gradient descent

        Args:
            carry: Tuple of (lamb, chi, psi, mu, gamma, sigma).
            _: Unused scan input.
            x: Data array of shape (n, d) (static).
            log_det_S: ``log|S|`` where ``S`` is the initial sample
                covariance (static).
            lr: Shape learning rate (static).
            shape_steps: Number of inner gradient steps (static).

        Returns:
            Updated carry and None (no stacked output).
        """
        eps: float = 1e-8
        lamb, chi, psi, mu, gamma, sigma = carry
        n, d = x.shape[0], x.shape[1]

        # --- Step (2): E-step — posterior GIG expectations (eq. 3.36) ---
        # W_i | X_i ~ GIG(lamb - d/2, chi + Q_i, psi + gamma' Sigma^{-1} gamma)
        sigma_inv: Array = jnp.linalg.inv(sigma)
        diff: Array = x - mu.flatten()  # (n, d)
        Q: Array = jnp.sum(diff @ sigma_inv * diff, axis=1)  # (n,)
        R: Scalar = (gamma.T @ sigma_inv @ gamma).squeeze()  # scalar

        lam_post: Scalar = lamb - d / 2.0
        chi_post: Array = chi + Q  # (n,)
        psi_post: Scalar = psi + R  # scalar

        # delta_i = E[1/W_i | X_i] (eq. 3.37); clipped to keep weights finite.
        delta: Array = jnp.clip(
            GH._gig_expected_inv_w(lam_post, chi_post, psi_post), eps, 1e10
        )
        # eta_i = E[W_i | X_i] (eq. 3.37)
        eta: Array = jnp.clip(
            GH._gig_expected_w(lam_post, chi_post, psi_post), eps, 1e10
        )

        delta_bar: Scalar = jnp.mean(delta)
        eta_bar: Scalar = jnp.mean(eta)
        x_bar: Array = jnp.mean(x, axis=0).reshape((d, 1))

        # --- Step (3): gamma update (Algorithm 3.14, step 3) ---
        # gamma = [n^{-1} sum delta_i (X_bar - X_i)] / (delta_bar * eta_bar - 1)
        x_delta_bar: Array = jnp.mean(x * delta[:, None], axis=0).reshape((d, 1))
        denom: Scalar = delta_bar * eta_bar - 1.0
        # Guard against a vanishing denominator.
        denom = jnp.where(jnp.abs(denom) < eps, eps, denom)
        gamma = (delta_bar * x_bar - x_delta_bar) / denom

        # --- Step (4): mu, Psi, Sigma update (Algorithm 3.14, step 4) ---
        # mu = (n^{-1} sum delta_i X_i - gamma) / delta_bar
        mu = (x_delta_bar - gamma) / delta_bar

        # Psi = (1/n) sum delta_i (X_i - mu)(X_i - mu)' - eta_bar * gamma gamma'
        diff = x - mu.flatten()  # (n, d) — recompute with updated mu
        # Weighted scatter as a single matmul: avoids materialising the
        # (n, d, d) stack of per-observation outer products.
        weighted_scatter: Array = (diff * delta[:, None]).T @ diff / n
        psi_mat: Array = weighted_scatter - eta_bar * (gamma @ gamma.T)

        # PSD repair first, then determinant constraint (order matters:
        # _rm_incomplete changes eigenvalues which changes the determinant,
        # so we must apply it before the rescaling, not after).
        psi_mat = _corr._rm_incomplete(psi_mat, 1e-5)

        # Determinant constraint: |Sigma| = |S| (identifiability, McNeil p. 82)
        # Sigma = |S|^{1/d} * Psi / |Psi|^{1/d}
        log_det_psi: Scalar = jnp.linalg.slogdet(psi_mat)[1]
        scale: Scalar = jnp.exp((log_det_S - log_det_psi) / d)
        sigma = scale * psi_mat

        # --- Steps (5)-(6): CM-step 2 — ECME variant (McNeil p. 83) ---
        # Maximize original log-likelihood w.r.t. (lamb, chi, psi)
        # with (mu, gamma, Sigma) held fixed.
        def _shape_step(shape_carry, _):
            lamb_k, chi_k, psi_k = shape_carry
            _, grad = MvtGH._nll_shape_value_and_grad(
                jnp.array([lamb_k, chi_k, psi_k]), mu, gamma, sigma, x
            )
            grad = jnp.nan_to_num(grad, nan=0.0)
            lamb_k = lamb_k - lr * grad[0]
            # chi and psi are kept strictly positive by projection.
            chi_k = jnp.maximum(chi_k - lr * grad[1], eps)
            psi_k = jnp.maximum(psi_k - lr * grad[2], eps)
            return (lamb_k, chi_k, psi_k), None

        (lamb, chi, psi), _ = lax.scan(
            _shape_step, (lamb, chi, psi), None, length=shape_steps
        )

        return (lamb, chi, psi, mu, gamma, sigma), None

    def _fit_em(
        self,
        x: jnp.ndarray,
        lr: float = 0.1,
        maxiter: int = 100,
        cov_method: str = "pearson",
        shape_steps: int = 10,
    ) -> dict:
        """Fit via ECME algorithm (McNeil et al. 2005, Algorithm 3.14).

        The EM algorithm treats the GIG mixing variable W as latent data.
        Steps (3)-(4) update (gamma, mu, Sigma) in closed form from the
        expected sufficient statistics, with Sigma constrained so that
        ``|Sigma| = |S|`` for identifiability. Steps (5)-(6) use the ECME
        variant: maximize the observed log-likelihood w.r.t. (lamb, chi, psi)
        via gradient descent. The entire loop is compiled via ``lax.scan``
        for performance.

        Args:
            x: Input data array of shape (n, d).
            lr: Learning rate for shape parameter gradient steps.
            maxiter: Number of EM iterations.
            cov_method: Covariance estimator for the starting value ``S``
                (forwarded to ``cov``). Must be static under ``jit``.
            shape_steps: Number of inner gradient steps on (lamb, chi, psi)
                per EM iteration. Must be static under ``jit``.

        Returns:
            Fitted parameter dictionary.
        """
        x, _, _, d = _multivariate_input(x)
        sample_mean: Array = jnp.mean(x, axis=0).reshape((d, 1))
        # PSD repair (as done in ``_ldmle_inputs``) keeps log|S| finite and
        # the initial Sigma invertible when the raw estimate is singular or
        # indefinite (e.g. n <= d or collinear columns); otherwise the
        # determinant constraint would collapse Sigma to zero.
        sample_cov: Array = _corr._rm_incomplete(
            cov(x=x, method=cov_method), 1e-5
        )
        log_det_S: Scalar = jnp.linalg.slogdet(sample_cov)[1]

        # Step (1): starting values (Algorithm 3.14, step 1)
        init_carry: tuple = (
            jnp.array(0.0),  # lamb
            jnp.array(1.0),  # chi
            jnp.array(1.0),  # psi
            sample_mean,  # mu = X_bar
            jnp.zeros((d, 1)),  # gamma = 0
            sample_cov,  # sigma = S
        )

        def em_step(carry, step_input):
            return self._em_body(carry, step_input, x, log_det_S, lr, shape_steps)

        final_carry, _ = lax.scan(em_step, init_carry, None, length=maxiter)
        lamb, chi, psi, mu, gamma, sigma = final_carry
        return self._params_dict(
            lamb=lamb,
            chi=chi,
            psi=psi,
            mu=mu,
            gamma=gamma,
            sigma=sigma,
        )

    _supported_methods = frozenset({"em", "ldmle"})

    def fit(
        self,
        x: ArrayLike,
        method: str = "em",
        cov_method: str = "pearson",
        lr: float = 0.1,
        maxiter: int = 100,
        name: str = None,
    ):
        r"""Fit the multivariate GH distribution to data.

        Note:
            If you intend to jit wrap this function, ensure that ``method``
            and ``cov_method`` are static arguments.

        Args:
            x: Input data of shape ``(n, d)``.
            method: Fitting method. One of:
                ``'em'`` — ECME algorithm (McNeil et al. 2005, Section
                3.4.2); updates ``(mu, gamma, Sigma)`` in closed form via
                E-step sufficient statistics and ``(lamb, chi, psi)`` via
                gradient descent; generally more robust and
                faster-converging than LDMLE (default);
                ``'ldmle'`` — low-dimensional MLE via projected ADAM
                gradient descent, optimising ``(lamb, chi, psi, gamma)``
                while deriving ``(mu, Sigma)`` analytically from sample
                moments.
            cov_method: Covariance estimator used for initialisation (both
                methods) and throughout the LDMLE path. Forwarded to
                :func:`copulax.multivariate.cov`.
            lr: Learning rate. Default ``0.1`` is tuned for EM; LDMLE may
                require a lower rate.
            maxiter: Maximum number of iterations.
            name: Optional custom name for the fitted instance.

        Returns:
            MvtGH: A fitted ``MvtGH`` instance.

        Raises:
            ValueError: If ``method`` is not one of the accepted strings
                listed above.
        """
        self._check_method(method)

        if method == "em":
            params = self._fit_em(
                x=x, lr=lr, maxiter=maxiter, cov_method=cov_method
            )
            return self._fitted_instance(params, name=name)

        # 'ldmle': (mu, Sigma) are derived from the sample moments inside
        # ``_reconstruct_ldmle_params``; L is the Cholesky factor used there.
        x_arr, _, _, d = _multivariate_input(x)
        sample_mean, L = prepare_sample_cov(x_arr, cov_method)
        params = self._general_fit(
            x=x_arr,
            d=d,
            loc=sample_mean,
            shape=L,
            lr=lr,
            maxiter=maxiter,
        )
        return self._fitted_instance(params, name=name)

    def _ldmle_inputs(self, d, x=None):
        """Generate initial parameter array and bounds for LD-MLE optimisation.

        Initial ``(lamb, chi, psi)`` match the ECME starting point of
        :py:meth:`_fit_em` (McNeil et al. 2005, Algorithm 3.14 step 1):
        ``(0, 1, 1)``. ``gamma`` is a heuristic taken from the marginal
        sample skewness (scaled by the marginal std and 0.25) when data
        ``x`` is supplied, otherwise zero. The slot for ``gamma`` stores the
        unconstrained ``z`` vector driving the feasibility
        reparametrisation; the init inverts ``gamma0`` through the same map.

        Args:
            d: Dimension of the data.
            x: Optional data array of shape (n, d) used for the gamma init.

        Returns:
            Tuple ``(bounds, params0)`` where ``bounds`` has keys ``"lower"``
            and ``"upper"`` (both unbounded, shape (d + 3, 1)) and
            ``params0`` is the flat vector ``[lamb, chi_raw, psi_raw, z]``
            of length ``d + 3``.
        """
        lc = jnp.full((d + 3, 1), -jnp.inf)
        uc = jnp.full((d + 3, 1), jnp.inf)

        # Match ECME init (McNeil et al. 2005, Algorithm 3.14, step 1):
        # (lamb, chi, psi) = (0, 1, 1). The raw values invert
        # ``softplus + _POS_EPS`` so ``_reconstruct_ldmle_params``
        # recovers chi = psi = _POS_INIT.
        lamb0 = jnp.asarray(0.0)
        pos0_raw_value = jnp.log(jnp.expm1(_POS_INIT - _POS_EPS))
        pos0_raw = jnp.array([pos0_raw_value, pos0_raw_value])

        if x is not None:
            x_std = jnp.std(x, axis=0)
            # Avoid dividing by a (near-)zero std for constant columns.
            z_data = (x - jnp.mean(x, axis=0)) / jnp.where(x_std > 1e-8, x_std, 1.0)
            skew = jnp.mean(z_data**3, axis=0)
            gamma0 = skew * x_std * 0.25
            # NOTE(review): Pearson is used here regardless of the
            # ``cov_method`` passed to ``fit``, whereas the reconstruction
            # uses L from ``cov_method``; only the starting z is affected.
            sample_cov0 = _corr._rm_incomplete(cov(x=x, method="pearson"), 1e-5)
        else:
            gamma0 = jnp.zeros((d,))
            sample_cov0 = jnp.eye(d)
        L0 = jnp.linalg.cholesky(sample_cov0)

        chi0 = jnn.softplus(pos0_raw[0]) + _POS_EPS
        psi0 = jnn.softplus(pos0_raw[1]) + _POS_EPS
        w_var0 = gig.stats(params={"lamb": lamb0, "chi": chi0, "psi": psi0})[
            "variance"
        ]
        z0 = invert_gamma_to_z(gamma0, L0, w_var0)

        # Flatten every piece explicitly so the layout is independent of
        # whether ``z0`` comes back as (d,) or (d, 1).
        params0 = jnp.concatenate(
            [jnp.atleast_1d(lamb0), pos0_raw, jnp.ravel(z0)]
        )
        return {"lower": lc, "upper": uc}, params0

    def _reconstruct_ldmle_params(self, params_arr, loc, shape):
        """Reconstruct lamb, chi, psi, mu, gamma, sigma from LD-MLE output.

        ``shape`` is ``L = chol(sample_cov_pd)``, precomputed in ``fit``.
        gamma is obtained via the feasibility reparametrisation so the
        reconstructed sigma is strictly PD by construction; no silent
        repair and no per-step matrix decomposition.

        Args:
            params_arr: Flat vector ``[lamb, chi_raw, psi_raw, z]``.
            loc: Sample mean.
            shape: Cholesky factor ``L`` of the sample covariance.

        Returns:
            Tuple ``(lamb, chi, psi, mu, gamma, sigma)``.
        """
        L: Array = shape
        d: int = L.shape[0]

        scalars = lax.dynamic_slice_in_dim(params_arr, 0, 3)
        lamb, chi_, psi_ = scalars
        # Map the unconstrained raw values to strictly positive chi, psi.
        chi = jnn.softplus(chi_) + _POS_EPS
        psi = jnn.softplus(psi_) + _POS_EPS

        z: Array = lax.dynamic_slice_in_dim(params_arr, 3, d)

        gig_stats: dict = gig.stats(params={"lamb": lamb, "chi": chi, "psi": psi})
        gamma, sigma = forward_reparam(
            z, L, gig_stats["mean"], gig_stats["variance"]
        )
        # Match the sample mean: E[X] = mu + E[W] * gamma.
        mu: Array = loc - gig_stats["mean"] * gamma
        return lamb, chi, psi, mu, gamma, sigma
# Module-level default instance with no stored parameters.
mvt_gh = MvtGH(name="Mvt-GH")