Source code for copulax._src.univariate.skewed_t

"""File containing the copulAX implementation of the skewed-T distribution."""

import jax.numpy as jnp
import jax.nn as jnn
from jax import lax, custom_vjp, random, jit, value_and_grad
from jax import Array
from jax.typing import ArrayLike
from copy import deepcopy

from copulax._src._distributions import Univariate
from copulax._src.typing import Scalar
from copulax._src.univariate._utils import _univariate_input
from copulax._src.special import log_kv_plus_s_log_r
from copulax._src._utils import _resolve_key
from copulax._src.univariate._cdf import _cdf, cdf_bwd, _cdf_fwd
from copulax._src.optimize import projected_gradient
from copulax._src.univariate.ig import ig
from copulax._src.univariate._normal_mixture import (
    forward_reparam_1d,
    invert_gamma_to_z_1d,
    mean_variance_stats,
)
from copulax._src.univariate._rvs import mean_variance_sampling
from copulax._src.univariate.gh import GH

_NU_EPS = 1e-8
_NU_INIT = 4.0
_NU_LDMLE_MIN = 4.0 + 1e-3


class SkewedT(Univariate):
    r"""The skewed-t distribution is a four-parameter continuous
    generalisation of the Student's t that admits asymmetry via a
    mean-variance mixture of normals with an inverse-gamma mixing variable.
    It arises as the limiting case of the generalised hyperbolic
    distribution with :math:`\psi \to 0` and :math:`\lambda = -\nu/2`, and
    collapses to the symmetric Student's t at :math:`\gamma = 0`.

    The four-parameter McNeil et al (2005) specification is used. The PDF is

    .. math::

        f(x | \nu, \mu, \sigma, \gamma) =
        \frac{2^{1 - s}}{\Gamma(\nu/2)\,\sqrt{\nu \pi}\,\sigma}\;
        K_s(r)\, r^{s}\,
        \frac{\exp\!\bigl(\gamma (x - \mu) / \sigma^2\bigr)}
             {\bigl(1 + Q/\nu\bigr)^{s}},
        \qquad
        s = \frac{\nu + 1}{2}, \quad
        Q = \frac{(x - \mu)^2}{\sigma^2}, \quad
        r = \sqrt{(\nu + Q)\, \gamma^2 / \sigma^2}

    where :math:`K_s` is the modified Bessel function of the second kind,
    :math:`\nu > 0` is the degrees of freedom, :math:`\mu \in \mathbb{R}`
    is the location, :math:`\sigma > 0` is the scale, and
    :math:`\gamma \in \mathbb{R}` is the skewness parameter.

    https://en.wikipedia.org/wiki/Skew-t_distribution
    """

    nu: Array = None
    mu: Array = None
    sigma: Array = None
    gamma: Array = None

    def __init__(self, name="Skewed-T", *, nu=None, mu=None, sigma=None, gamma=None):
        """Initialize the Skewed-T distribution.

        Args:
            name: Display name for the distribution.
            nu: Degrees of freedom parameter.
            mu: Location parameter.
            sigma: Scale parameter.
            gamma: Skewness parameter (zero recovers the Student-T).
        """
        super().__init__(name)

        def _scalar(value):
            # Optional scalar-like -> 0-d float array; None stays None so
            # that an unfitted instance has no stored parameters.
            if value is None:
                return None
            return jnp.asarray(value, dtype=float).reshape(())

        self.nu = _scalar(nu)
        self.mu = _scalar(mu)
        self.sigma = _scalar(sigma)
        self.gamma = _scalar(gamma)

    @property
    def _stored_params(self):
        """Return stored parameters if all are set, else None."""
        if any(v is None for v in [self.nu, self.mu, self.sigma, self.gamma]):
            return None
        return {"nu": self.nu, "mu": self.mu, "sigma": self.sigma, "gamma": self.gamma}

    @classmethod
    def _params_dict(cls, nu: Scalar, mu: Scalar, sigma: Scalar, gamma: Scalar) -> dict:
        """Create a parameter dictionary from nu, mu, sigma, and gamma values."""
        d: dict = {"nu": nu, "mu": mu, "sigma": sigma, "gamma": gamma}
        return cls._args_transform(d)

    @staticmethod
    def _params_to_tuple(params: dict) -> tuple:
        """Extract (nu, mu, sigma, gamma) from the parameter dictionary."""
        params = SkewedT._args_transform(params)
        return params["nu"], params["mu"], params["sigma"], params["gamma"]

    @staticmethod
    def _params_to_array(params: dict) -> Array:
        """Convert the parameter dictionary to a flat array."""
        return jnp.asarray(SkewedT._params_to_tuple(params)).flatten()

    @classmethod
    def _support(cls, *args, **kwargs) -> Array:
        """Return the support ``[-inf, inf]``."""
        return jnp.array([-jnp.inf, jnp.inf])

    def example_params(self, *args, **kwargs) -> dict:
        """Return an illustrative parameter set (nu=4.5, mu=0, sigma=1, gamma=1)."""
        return self._params_dict(nu=4.5, mu=0.0, sigma=1.0, gamma=1.0)

    @staticmethod
    def _stable_logpdf(stability: float, x: ArrayLike, params: dict) -> Array:
        r"""Skewed-t log-PDF (McNeil, Frey & Embrechts 2005, §3.2).

        .. math::

            \log f(x) = c + \bigl[\log K_s(r) + s \log r\bigr]
                        + P \gamma - s \log\!\bigl(1 + Q/\nu\bigr),

        where ``s = (ν+1)/2``, ``P = (x-μ)/σ²``, ``Q = P·(x-μ)``,
        ``R = (γ/σ)²`` and ``r = sqrt((ν+Q)·R)``.

        The bracketed combination ``log K_s(r) + s log r`` has divergent
        individual terms as ``γ → 0`` (``log K_s(0) = +∞``,
        ``s log 0 = −∞``) but a finite analytical limit
        ``log Γ(s) + (s−1) log 2`` — computed as a single
        cancellation-stable object by :py:func:`log_kv_plus_s_log_r`.
        At ``γ = 0`` exactly the formula evaluates to the Student-t
        log-PDF to float64 eps.

        Args:
            stability: Small constant added inside the normalising-constant
                logs and the ``Q/ν`` denominator to avoid ``log(0)`` and
                division by zero.
            x: Points at which to evaluate the log-density.
            params: Parameter dictionary with keys ``nu``, ``mu``,
                ``sigma`` and ``gamma``.

        Returns:
            Log-density values, reshaped to the shape that
            ``_univariate_input`` reports for ``x``.
        """
        nu, mu, sigma, gamma = SkewedT._params_to_tuple(params)
        x, xshape = _univariate_input(x)

        s: float = 0.5 * (nu + 1)
        c: float = (
            jnp.log(2.0) * (1 - s)
            - lax.lgamma(0.5 * nu)
            - 0.5 * jnp.log(jnp.pi * nu + stability)
            - jnp.log(sigma + stability)
        )

        P: jnp.ndarray = (x - mu) * lax.pow(sigma, -2)
        Q: jnp.ndarray = P * (x - mu)
        R: jnp.ndarray = lax.pow(gamma / sigma, 2)

        # jnp.maximum on the sqrt argument keeps ∂r/∂γ finite at γ=0
        # (otherwise ∂√z/∂z = 1/(2√z) → ∞ at z=0 multiplies against
        # the upstream ∂z/∂γ = 0 to give NaN). The 1e-24 floor is
        # the square of log_kv_plus_s_log_r's internal r-floor, so
        # the helper's direct-sum path is reached for any γ > 0.
        r = lax.sqrt(jnp.maximum((nu + Q) * R, 1e-24))
        log_kv_plus = log_kv_plus_s_log_r(s, r)

        # log1p(Q/ν) instead of log(1 + Q/ν): mathematically identical, but
        # keeps full relative precision when Q/ν is tiny (x close to μ),
        # where forming 1 + Q/ν first would round away the small term.
        logpdf: jnp.ndarray = (
            c + log_kv_plus + P * gamma - s * jnp.log1p(Q / (nu + stability))
        )
        return logpdf.reshape(xshape)

    # sampling
    def rvs(
        self, size: tuple | Scalar, params: dict = None, key: Array = None
    ) -> Array:
        """Generate random variates via mean-variance mixture of normals.

        Draws the mixing variable ``W ~ IG(nu/2, nu/2)`` and passes it to
        ``mean_variance_sampling`` together with mu, sigma and gamma.

        Args:
            size: Output sample shape.
            params: Optional parameter dictionary; resolved against the
                stored parameters via ``_resolve_params``.
            key: Optional JAX PRNG key; resolved via ``_resolve_key``.

        Returns:
            Array of random variates.
        """
        params = self._resolve_params(params)
        key = _resolve_key(key)
        nu, mu, sigma, gamma = self._params_to_tuple(params)

        key1, key2 = random.split(key)
        W: jnp.ndarray = ig.rvs(
            size=size, key=key1, params={"alpha": nu * 0.5, "beta": nu * 0.5}
        )
        return mean_variance_sampling(
            key=key2, W=W, shape=size, mu=mu, sigma=sigma, gamma=gamma
        )

    # stats
    def _get_w_stats(self, nu: float) -> dict:
        """Compute mean and variance of the inverse-gamma mixing variable W.

        ``W ~ IG(nu/2, nu/2)``. Divergent moments propagate as ``+inf``:
        mean diverges for ``nu <= 2``, variance diverges for ``nu <= 4``.
        """
        ig_params: dict = {"alpha": nu * 0.5, "beta": nu * 0.5}
        ig_stats: dict = ig.stats(params=ig_params)
        return {"mean": ig_stats["mean"], "variance": ig_stats["variance"]}

    def stats(self, params: dict = None) -> dict:
        """Compute distribution statistics derived from the mean-variance
        mixture representation."""
        params = self._resolve_params(params)
        nu, mu, sigma, gamma = self._params_to_tuple(params)
        w_stats: dict = self._get_w_stats(nu)
        return self._scalar_transform(
            mean_variance_stats(mu=mu, sigma=sigma, gamma=gamma, w_stats=w_stats)
        )

    # fitting
    @staticmethod
    def _sample_moments(x: jnp.ndarray) -> tuple:
        """Sample (mean, std, skew, excess kurtosis) used for
        method-of-moments initialisation of the 4-parameter fit."""
        sample_mean = x.mean()
        sample_std = x.std()
        z = (x - sample_mean) / sample_std
        sample_skew = jnp.mean(z ** 3)
        sample_kurt = jnp.mean(z ** 4) - 3.0
        return sample_mean, sample_std, sample_skew, sample_kurt

    @staticmethod
    def _nu_moment_init(sample_kurt: Scalar, lower: Scalar) -> Array:
        """Method-of-moments starting value for nu.

        Inverts the Student-t excess-kurtosis relation
        ``kappa = 6 / (nu - 4)`` (valid for ``nu > 4``), giving
        ``nu = 4 + 6 / kappa``. The sample excess kurtosis is floored at
        0.1 so light-tailed samples produce a large but finite guess, and
        the result is clipped to ``[lower, 60]``.

        Args:
            sample_kurt: Sample excess kurtosis.
            lower: Lower clip for the returned nu (fit-method specific).

        Returns:
            Initial nu estimate.
        """
        return jnp.clip(4.0 + 6.0 / jnp.maximum(sample_kurt, 0.1), lower, 60.0)

    @staticmethod
    @jit
    def _nll_value_and_grad(all_params: Array, x: Array) -> tuple:
        """Compute negative log-likelihood and its gradient w.r.t. all 4
        parameters, ordered (nu, mu, sigma, gamma)."""

        def _nll(params_arr, x):
            params = SkewedT._params_from_array(params_arr)
            return -jnp.mean(SkewedT._stable_logpdf(1e-30, x, params))

        return value_and_grad(_nll)(all_params, x)

    @staticmethod
    def _em_body(carry: tuple, _: None, x: Array, lr: float, shape_steps: int) -> tuple:
        """Single ECME iteration for skewed-t, compatible with lax.scan.

        E-step: posterior W_i|x_i ~ GIG(-(nu+1)/2, nu+Q_i, gamma^2/sigma^2)
        CM-step 1: closed-form update for mu, gamma, sigma
        CM-step 2: gradient descent on nu

        In the E-step ``delta`` holds E[W_i | x_i] and ``eta`` holds
        E[1/W_i | x_i]; both are clipped to ``[1e-8, 1e10]``.

        Args:
            carry: Tuple of (nu, mu, sigma, gamma).
            _: Unused scan input.
            x: Data array (static).
            lr: Shape learning rate (static).
            shape_steps: Number of inner gradient steps for nu (static).

        Returns:
            Updated carry and None.
        """
        eps: float = 1e-8
        nu, mu, sigma, gamma = carry

        # --- E-step ---
        Q = lax.pow(lax.div(lax.sub(x, mu), sigma), 2)
        psi_bar = lax.pow(lax.div(gamma, sigma), 2)
        lam_post = -(nu + 1.0) / 2.0
        chi_post = nu + Q
        delta = jnp.clip(GH._gig_expected_w(lam_post, chi_post, psi_bar), eps, 1e10)
        eta = jnp.clip(GH._gig_expected_inv_w(lam_post, chi_post, psi_bar), eps, 1e10)

        # --- CM-step 1: closed-form update for mu, gamma, sigma ---
        delta_bar = jnp.mean(delta)
        eta_bar = jnp.mean(eta)
        x_bar = jnp.mean(x)
        x_eta_bar = jnp.mean(x * eta)

        # Guard against a (near-)zero denominator before dividing.
        denom = eta_bar - 1.0 / delta_bar
        denom = jnp.where(jnp.abs(denom) < eps, eps, denom)
        mu = (x_eta_bar - x_bar / delta_bar) / denom
        gamma = (x_bar - mu) / delta_bar

        sigma_sq = jnp.mean(
            (x - mu) ** 2 * eta - 2 * (x - mu) * gamma + delta * gamma ** 2
        )
        sigma = jnp.sqrt(jnp.maximum(sigma_sq, eps))

        # --- CM-step 2: gradient descent for nu ---
        def _shape_step(shape_carry, _):
            n = shape_carry[0]
            all_p = jnp.array([n, mu, sigma, gamma])
            _, g = SkewedT._nll_value_and_grad(all_p, x)
            # A NaN gradient leaves nu unchanged for this step.
            g_nu = jnp.nan_to_num(g[0], nan=0.0)
            n = jnp.maximum(n - lr * g_nu, eps)
            return (n,), None

        (nu,), _ = lax.scan(_shape_step, (nu,), None, length=shape_steps)

        return (nu, mu, sigma, gamma), None

    def _fit_em(self, x: jnp.ndarray, lr: float, maxiter: int) -> dict:
        """Fit via ECME algorithm (McNeil et al. 2005, Section 3.4.2).

        The EM algorithm treats the IG mixing variable W as latent data.
        It avoids the mu/gamma/sigma identifiability ridge by updating
        these parameters in closed form, while nu is updated via gradient
        descent. The entire loop is compiled via ``lax.scan`` for
        performance.

        Args:
            x: Input data array.
            lr: Learning rate for nu gradient steps.
            maxiter: Number of EM iterations.

        Returns:
            Fitted parameter dictionary.
        """
        sample_mean, sample_std, sample_skew, sample_kurt = self._sample_moments(x)
        nu0 = self._nu_moment_init(sample_kurt, _NU_INIT)
        init_carry: tuple = (
            nu0,
            sample_mean,
            sample_std,
            sample_skew * sample_std * 0.25,
        )
        shape_steps: int = 10

        def em_step(carry, unused):
            return self._em_body(carry, unused, x, lr, shape_steps)

        final_carry, _ = lax.scan(em_step, init_carry, None, length=maxiter)
        nu, mu, sigma, gamma = final_carry
        return self._params_dict(nu=nu, mu=mu, sigma=sigma, gamma=gamma)

    def _fit_mle(self, x: jnp.ndarray, lr: float, maxiter: int) -> dict:
        """Fit all four parameters via projected gradient MLE with box
        constraints.

        The box is data-driven: mu within ``mean ± 2 std``, gamma within
        ``± 2 std``, sigma at least ``0.1 std`` and nu positive.
        """
        eps: float = 1e-8
        sample_mean, sample_std, sample_skew, sample_kurt = self._sample_moments(x)

        # Data-driven box constraints to prevent divergence
        sigma_lo = 0.1 * sample_std
        gamma_bound = 2.0 * sample_std
        mu_bound = 2.0 * sample_std
        constraints: tuple = (
            jnp.array([[eps, sample_mean - mu_bound, sigma_lo + eps, -gamma_bound]]).T,
            jnp.array([[jnp.inf, sample_mean + mu_bound, jnp.inf, gamma_bound]]).T,
        )
        projection_options: dict = {"lower": constraints[0], "upper": constraints[1]}

        # Method-of-moments initial estimates; ew approximates E[W] = nu/(nu-2).
        nu0 = self._nu_moment_init(sample_kurt, 4.0 + eps)
        gamma0 = jnp.clip(sample_skew * sample_std * 0.5, -gamma_bound, gamma_bound)
        ew = nu0 / (nu0 - 2.0)
        mu0 = jnp.clip(
            sample_mean - ew * gamma0, sample_mean - mu_bound, sample_mean + mu_bound
        )
        sigma0 = jnp.maximum(
            jnp.sqrt(jnp.maximum(sample_std**2 - ew * gamma0**2, eps) / ew),
            sigma_lo + eps,
        )
        params0: jnp.ndarray = jnp.array([nu0, mu0, sigma0, gamma0])

        res: dict = projected_gradient(
            f=self._mle_objective,
            x0=params0,
            projection_method="projection_box",
            projection_options=projection_options,
            x=x,
            lr=lr,
            maxiter=maxiter,
        )
        nu, mu, sigma, gamma = res["x"]
        return self._params_dict(nu=nu, mu=mu, sigma=sigma, gamma=gamma)

    def _ldmle_params(
        self,
        raw_nu: Scalar,
        z: Scalar,
        sample_mean: Scalar,
        sample_variance: Scalar,
    ) -> tuple:
        """Map LDMLE coordinates (raw_nu, z) to (nu, mu, sigma, gamma).

        nu = softplus(raw_nu) + _NU_LDMLE_MIN; gamma and sigma follow from
        z via ``forward_reparam_1d``; mu matches the sample mean through
        ``mu = mean - E[W] * gamma``.
        """
        nu = jnn.softplus(raw_nu) + _NU_LDMLE_MIN
        sigma_hat = jnp.sqrt(sample_variance)
        ig_stats: dict = self._get_w_stats(nu=nu)
        gamma, sigma = forward_reparam_1d(
            z,
            sigma_hat,
            ig_stats["mean"],
            ig_stats["variance"],
        )
        mu = sample_mean - ig_stats["mean"] * gamma
        return nu, mu, sigma, gamma

    def _ldmle_objective(
        self,
        params: jnp.ndarray,
        x: jnp.ndarray,
        sample_mean: Scalar,
        sample_variance: Scalar,
    ) -> jnp.ndarray:
        """LDMLE objective over (raw_nu, z).

        gamma follows from z via the feasibility reparam; mu and sigma
        follow from moment-matching. sigma is strictly positive by
        construction.
        """
        raw_nu, z = params
        nu, mu, sigma, gamma = self._ldmle_params(
            raw_nu, z, sample_mean, sample_variance
        )
        return self._mle_objective(params_arr=jnp.array([nu, mu, sigma, gamma]), x=x)

    def _fit_ldmle(self, x: jnp.ndarray, lr: float, maxiter: int) -> dict:
        """Fit via LDMLE.

        Optimises (raw_nu, z): gamma is reparametrised so feasibility of
        the moment-matching reconstruction is structural.
        """
        _, sample_std, sample_skew, sample_kurt = self._sample_moments(x)

        # z is unconstrained in both directions.
        constraints: tuple = (
            jnp.array([[-jnp.inf, -jnp.inf]]).T,
            jnp.array([[jnp.inf, jnp.inf]]).T,
        )

        # Method-of-moments initial estimates. nu floored above _NU_LDMLE_MIN so
        # the moment-matching reconstruction stays in the Var[W] < inf regime.
        nu_lower = _NU_LDMLE_MIN + 0.5
        nu0 = self._nu_moment_init(sample_kurt, nu_lower)
        # Inverse softplus: softplus(raw_nu0) + _NU_LDMLE_MIN == nu0.
        raw_nu0 = jnp.log(jnp.expm1(nu0 - _NU_LDMLE_MIN))
        gamma0 = sample_skew * sample_std * 0.5
        w_var0 = self._get_w_stats(nu=nu0)["variance"]
        z0 = invert_gamma_to_z_1d(gamma0, sample_std, w_var0)
        params0: jnp.ndarray = jnp.array([raw_nu0, z0])

        projection_options: dict = {"lower": constraints[0], "upper": constraints[1]}
        sample_mean, sample_variance = x.mean(), x.var()
        res: dict = projected_gradient(
            f=self._ldmle_objective,
            x0=params0,
            projection_method="projection_box",
            projection_options=projection_options,
            x=x,
            sample_mean=sample_mean,
            sample_variance=sample_variance,
            lr=lr,
            maxiter=maxiter,
        )
        raw_nu, z = res["x"]
        nu, mu, sigma, gamma = self._ldmle_params(
            raw_nu, z, sample_mean, sample_variance
        )
        return self._params_dict(nu=nu, mu=mu, sigma=sigma, gamma=gamma)

    _supported_methods = frozenset({"em", "mle", "ldmle"})

    def fit(
        self,
        x: ArrayLike,
        method: str = "em",
        lr: float = 0.1,
        maxiter: int = 100,
        name: str = None,
    ):
        r"""Fit the distribution to the input data using the selected estimator.

        Note:
            If you intend to jit wrap this function, ensure that ``method``
            is a static argument.

        Args:
            x (ArrayLike): The input data to fit the distribution to.
            method (str): The fitting method to use. One of:
                ``'em'`` — the ECME algorithm (McNeil et al. 2005);
                ``'mle'`` — projected-gradient maximum likelihood;
                ``'ldmle'`` — low-dimensional maximum likelihood estimation.
                Defaults to ``'em'``.
            lr (float): Learning rate for the optimiser (for ``'em'`` it is
                the step size of the inner nu gradient steps).
            maxiter (int): Maximum number of iterations for the optimiser.
            name (str): Optional custom name for the fitted instance.

        Returns:
            SkewedT: A fitted ``SkewedT`` instance.

        Raises:
            ValueError: If ``method`` is not one of the accepted strings
                listed above.
        """
        self._check_method(method)
        x = _univariate_input(x)[0]
        if method == "mle":
            return self._fitted_instance(
                self._fit_mle(x=x, lr=lr, maxiter=maxiter), name=name
            )
        elif method == "em":
            return self._fitted_instance(
                self._fit_em(x=x, lr=lr, maxiter=maxiter), name=name
            )
        elif method == "ldmle":
            return self._fitted_instance(
                self._fit_ldmle(x=x, lr=lr, maxiter=maxiter), name=name
            )
        else:
            raise ValueError(
                f"Unknown Skewed-T fit method {method!r}. "
                f"Expected one of: {sorted(self._supported_methods)}."
            )

    # cdf
    @staticmethod
    def _params_from_array(params_arr: jnp.ndarray, *args, **kwargs) -> dict:
        """Reconstruct a parameter dictionary from a flat array ordered
        (nu, mu, sigma, gamma)."""
        nu, mu, sigma, gamma = params_arr
        return SkewedT._params_dict(nu=nu, mu=mu, sigma=sigma, gamma=gamma)

    @staticmethod
    def _pdf_for_cdf(x: ArrayLike, *params_tuple) -> Array:
        """Evaluate the PDF for numerical CDF integration."""
        params_array: jnp.ndarray = jnp.asarray(params_tuple).flatten()
        params: dict = SkewedT._params_from_array(params_array)
        return jnp.exp(SkewedT._stable_logpdf(stability=1e-30, x=x, params=params))

    def _cdf_anchor_scales(self, params: dict) -> Array:
        """Use the intrinsic scale parameter sigma, not sqrt(variance).

        The default sqrt(variance) formula for skewed-T requires ``nu > 2``
        to be finite. For ``1 < nu <= 2`` the mean exists but variance is
        infinite; the base-class default would produce ``inf`` or ``nan``
        for the scale and break the breakpoint grid. The sigma scale
        parameter is always positive and well-defined regardless of ``nu``,
        giving a clean bulk scale.
        """
        _, _, sigma, _ = SkewedT._params_to_tuple(params)
        return jnp.asarray(sigma).reshape((1,))

    def cdf(self, x: ArrayLike, params: dict = None) -> Array:
        """Compute the CDF via numerical integration with a custom VJP."""
        params = self._resolve_params(params)
        cdf = _vjp_cdf(x=x, params=params)
        return self._enforce_support_on_cdf(x=x, cdf=cdf, params=params)
# Module-level default instance; also used as ``dist`` by the CDF helpers.
skewed_t = SkewedT("Skewed-T")


def _vjp_cdf(x: ArrayLike, params: dict) -> Array:
    """Primal skewed-t CDF: canonicalise ``params`` and integrate numerically."""
    return _cdf(dist=skewed_t, x=x, params=SkewedT._args_transform(params))


# ``copy.deepcopy`` returns function objects unchanged, so this "copy" is a
# second reference to the undecorated primal above. It is handed to
# ``_cdf_fwd`` as ``cdf_func`` — presumably so the forward pass does not
# re-enter the custom_vjp wrapper created on the next line.
_vjp_cdf_copy = deepcopy(_vjp_cdf)
_vjp_cdf = custom_vjp(_vjp_cdf)
def cdf_fwd(x: ArrayLike, params: dict) -> tuple[Array, tuple]:
    """Forward rule for the skewed-t CDF custom VJP.

    Canonicalises ``params`` and delegates to the shared ``_cdf_fwd``
    helper. That helper receives the undecorated primal
    (``_vjp_cdf_copy``) as ``cdf_func``.

    Args:
        x: Points at which to evaluate the CDF.
        params: Skewed-t parameter dictionary.

    Returns:
        The ``(primal_output, residuals)`` pair produced by ``_cdf_fwd``.
    """
    canonical_params = SkewedT._args_transform(params)
    return _cdf_fwd(
        dist=skewed_t, cdf_func=_vjp_cdf_copy, x=x, params=canonical_params
    )


# Register the forward rule above and the shared backward rule on the
# custom_vjp-wrapped primal.
_vjp_cdf.defvjp(cdf_fwd, cdf_bwd)