"""File containing the copulAX implementation of the generalized hyperbolic distribution."""
import jax.numpy as jnp
from jax import lax, custom_vjp, random, jit, value_and_grad
from jax import Array
from jax.typing import ArrayLike
from copy import deepcopy
from copulax._src._distributions import Univariate
from copulax._src.univariate._utils import _univariate_input
from copulax._src._utils import _resolve_key
from copulax._src.typing import Scalar
from copulax._src.univariate._cdf import _cdf, cdf_bwd, _cdf_fwd
from copulax.special import log_kv
from copulax._src.univariate._rvs import mean_variance_sampling
from copulax._src.univariate._normal_mixture import (
forward_reparam_1d,
invert_gamma_to_z_1d,
mean_variance_stats,
)
from copulax._src.univariate.gig import gig
from copulax._src.univariate.nig import NIG
from copulax._src.optimize import projected_gradient
def _nig_mom_gh_init(x: jnp.ndarray) -> tuple:
    r"""Starting point for GH fits derived from the NIG method of moments.

    The NIG distribution is the GH special case :math:`\lambda = -1/2`,
    obtained through the mapping
    :math:`(\chi, \psi, \mu_{GH}, \gamma_{GH}, \sigma) =
    (\hat{\delta}^2,\, \hat{\alpha}^2 - \hat{\beta}^2,\,
    \hat{\mu},\, \hat{\beta},\, 1)`.
    ``NIG._fit_mom`` (Karlis 2002) drops to its symmetric-NIG branch whenever
    the empirical skew/kurtosis pair leaves the NIG-feasible cone
    (:math:`3\kappa - 5\gamma_1^2 \le 0`), so the returned point is a usable
    initialisation for any data set.

    Args:
        x: Sample the initial point is computed from.

    Returns:
        Tuple ``(lamb, chi, psi, mu, sigma, gamma)`` in the same order as
        ``GH._params_to_tuple``. ``lamb`` and ``sigma`` take the dtype of
        the NIG location estimate.
    """
    nig = NIG._fit_mom(x)
    location = nig["mu"]
    alpha, beta, delta = nig["alpha"], nig["beta"], nig["delta"]
    dtype = location.dtype
    return (
        jnp.asarray(-0.5, dtype=dtype),  # lamb: the NIG member of the GH family
        delta ** 2,  # chi
        alpha ** 2 - beta ** 2,  # psi
        location,  # mu
        jnp.asarray(1.0, dtype=dtype),  # sigma: NIG has no separate scale
        beta,  # gamma
    )
class GH(Univariate):
    r"""The generalized hyperbolic distribution. This is a flexible,
    continuous 6-parameter family of distributions that can model a variety
    of data behaviors, including heavy tails and skewness. It contains
    a number of popular distributions as special cases, including the
    normal, student-t, hyperbolic, laplace, and skewed-T distributions.

    We adopt the parameterization used by McNeil et al. (2005):

    .. math::

        f(x|\mu, \sigma, \chi, \psi, \gamma, \lambda) \propto
        e^{\gamma (x-\mu) / \sigma^2}\,
        \frac{K_{\lambda - 0.5}(\sqrt{A})}
        {(\sqrt{A})^{0.5 - \lambda}},
        \qquad
        A = \left(\chi + \left(\tfrac{x-\mu}{\sigma}\right)^2\right)
        \left(\psi + \left(\tfrac{\gamma}{\sigma}\right)^2\right)

    where :math:`K_{\lambda}` is the modified Bessel function of the second
    kind, :math:`\mu` is the location parameter, :math:`\sigma` is the scale,
    :math:`\gamma` is the skewness and :math:`\lambda`, :math:`\chi` and
    :math:`\psi` relate to the shape of the distribution.

    The GH law is a normal mean-variance mixture,
    :math:`X = \mu + W\gamma + \sqrt{W}\sigma Z` with
    :math:`W \sim \mathrm{GIG}(\lambda, \chi, \psi)`. Sampling, moments and
    the EM fit below all work through this mixing variable :math:`W`.
    """

    lamb: Array = None
    chi: Array = None
    psi: Array = None
    mu: Array = None
    sigma: Array = None
    gamma: Array = None

    def __init__(
        self,
        name="GH",
        *,
        lamb=None,
        chi=None,
        psi=None,
        mu=None,
        sigma=None,
        gamma=None,
    ):
        """Initialize the Generalized Hyperbolic distribution.

        Every supplied parameter is stored as a 0-d float array. Parameters
        left as ``None`` stay ``None``, in which case ``_stored_params``
        returns ``None``.

        Args:
            name: Display name for the distribution.
            lamb: Shape parameter (real-valued).
            chi: Concentration parameter (strictly positive).
            psi: Rate parameter (strictly positive).
            mu: Location parameter.
            sigma: Scale / dispersion parameter.
            gamma: Skewness parameter.
        """
        super().__init__(name)

        def _scalar(value):
            # Normalise any scalar-like input to a 0-d float array; keep None.
            if value is None:
                return None
            return jnp.asarray(value, dtype=float).reshape(())

        self.lamb = _scalar(lamb)
        self.chi = _scalar(chi)
        self.psi = _scalar(psi)
        self.mu = _scalar(mu)
        self.sigma = _scalar(sigma)
        self.gamma = _scalar(gamma)

    @property
    def _stored_params(self):
        """Return the stored parameters as a dict, or ``None`` if any is unset."""
        if any(
            v is None
            for v in [self.lamb, self.chi, self.psi, self.mu, self.sigma, self.gamma]
        ):
            return None
        return {
            "lamb": self.lamb,
            "chi": self.chi,
            "psi": self.psi,
            "mu": self.mu,
            "sigma": self.sigma,
            "gamma": self.gamma,
        }

    @classmethod
    def _params_dict(
        cls,
        lamb: Scalar,
        chi: Scalar,
        psi: Scalar,
        mu: Scalar,
        sigma: Scalar,
        gamma: Scalar,
    ) -> dict:
        r"""Bundle the six GH parameters into a dict passed through
        ``_args_transform``."""
        d: dict = {
            "lamb": lamb,
            "chi": chi,
            "psi": psi,
            "mu": mu,
            "sigma": sigma,
            "gamma": gamma,
        }
        return cls._args_transform(d)

    @staticmethod
    def _params_to_tuple(params: dict) -> tuple:
        """Extract (lamb, chi, psi, mu, sigma, gamma) from the parameter dictionary."""
        params = GH._args_transform(params)
        return (
            params["lamb"],
            params["chi"],
            params["psi"],
            params["mu"],
            params["sigma"],
            params["gamma"],
        )

    @staticmethod
    def _params_to_array(params: dict) -> Array:
        """Convert the parameter dictionary to a flat array ordered like
        ``_params_to_tuple``."""
        return jnp.asarray(GH._params_to_tuple(params)).flatten()

    @classmethod
    def _support(cls, *args, **kwargs) -> Array:
        """Return the support ``[-inf, inf]``."""
        return jnp.array([-jnp.inf, jnp.inf])

    def example_params(self, *args, **kwargs) -> dict:
        """Return a simple valid parameter set.

        ``lamb=0, chi=1, psi=1, mu=0, sigma=1, gamma=0``: a symmetric GH
        centred at zero with unit scale. Extra arguments are ignored.
        """
        return self._params_dict(
            lamb=0.0, chi=1.0, psi=1.0, mu=0.0, sigma=1.0, gamma=0.0
        )

    @staticmethod
    def _stable_logpdf(stability: Scalar, x: ArrayLike, params: dict) -> Array:
        """Compute the log-PDF of the GH distribution.

        ``stability`` is added inside the logarithms of ``m`` and of
        ``psi / r`` (and to ``r`` in that ratio) so that gradient-based
        fitting can keep those logs finite. ``logpdf`` passes ``0.0``.

        Args:
            stability: Small non-negative constant added inside the logs.
            x: Evaluation points.
            params: GH parameter dictionary.

        Returns:
            Log-density values reshaped to the shape reported by
            ``_univariate_input`` for ``x``.
        """
        lamb, chi, psi, mu, sigma, gamma = GH._params_to_tuple(params)
        x, xshape = _univariate_input(x)
        # r = sqrt(chi * psi): argument of the normalising Bessel function.
        r: Array = lax.sqrt(lax.mul(chi, psi))
        # s = 1/2 - lamb, so the kernel's Bessel order lamb - 1/2 equals -s.
        s: Array = 0.5 - lamb
        # h = psi + (gamma / sigma)^2
        h: Array = lax.add(psi, lax.pow(lax.div(gamma, sigma), 2))
        # g * gamma = gamma * (x - mu) / sigma^2 is the skewness exponent.
        g = lax.div(lax.sub(x, mu), lax.pow(sigma, 2))
        # m = sqrt(A), A = (chi + ((x - mu) / sigma)^2) * h.
        m = lax.sqrt(lax.mul(lax.add(chi, lax.mul(g, lax.sub(x, mu))), h))
        # Kernel: log K_{lamb - 1/2}(m) + gamma (x - mu) / sigma^2 - s log m.
        T = lax.add(log_kv(-s, m), lax.mul(g, gamma))
        B = lax.mul(lax.log(m + stability), s)
        # Log normalising constant:
        #   lamb * log(psi / r) + s * log(h)
        #   - [log(sigma) + log(sqrt(2 pi)) + log K_lamb(r)]
        cT = lax.add(
            lax.mul(lamb, lax.log((psi / (r + stability)) + stability)),
            lax.mul(lax.log(h), s),
        )
        cB = lax.add(
            lax.add(lax.log(sigma), lax.log(lax.sqrt(2 * jnp.pi))),
            log_kv(lamb, r),
        )
        c = lax.sub(cT, cB)
        logpdf: jnp.ndarray = lax.add(lax.sub(T, B), c)
        return logpdf.reshape(xshape)

    def logpdf(self, x: ArrayLike, params: dict = None) -> Array:
        """Compute the log probability density function."""
        params = self._resolve_params(params)
        logpdf = GH._stable_logpdf(stability=0.0, x=x, params=params)
        return self._enforce_support_on_logpdf(x=x, logpdf=logpdf, params=params)

    # sampling
    def rvs(
        self, size: tuple | Scalar, params: dict = None, key: Array = None
    ) -> Array:
        """Generate random variates via GIG-normal mean-variance mixture.

        Draws the mixing variable ``W`` from ``gig`` with one key and passes
        it to ``mean_variance_sampling`` with a second, independent key.
        """
        params = self._resolve_params(params)
        key = _resolve_key(key)
        lamb, chi, psi, mu, sigma, gamma = self._params_to_tuple(params)
        key1, key2 = random.split(key)
        W = gig.rvs(
            key=key1, size=size, params={"lamb": lamb, "chi": chi, "psi": psi}
        )
        return mean_variance_sampling(
            key=key2, W=W, shape=size, mu=mu, sigma=sigma, gamma=gamma
        )

    # stats
    def _get_w_stats(self, lamb: Scalar, chi: Scalar, psi: Scalar) -> dict:
        """Return ``gig.stats`` for the mixing variable W ~ GIG(lamb, chi, psi)."""
        return gig.stats(params={"lamb": lamb, "chi": chi, "psi": psi})

    def stats(self, params: dict = None) -> dict:
        """Compute distribution statistics derived from the GIG-normal mixture representation."""
        params = self._resolve_params(params)
        lamb, chi, psi, mu, sigma, gamma = self._params_to_tuple(params)
        gig_stats: dict = self._get_w_stats(lamb=lamb, chi=chi, psi=psi)
        return self._scalar_transform(
            mean_variance_stats(w_stats=gig_stats, mu=mu, sigma=sigma, gamma=gamma)
        )

    # fitting
    @staticmethod
    def _gig_expected_w(lamb: Scalar, chi: Scalar, psi: Scalar) -> Array:
        """E[W] for W ~ GIG(lamb, chi, psi) using log_kv ratios.

        E[W] = sqrt(chi / psi) * K_{lamb+1}(r) / K_lamb(r), r = sqrt(chi psi).
        """
        # chi * psi is floored at 1e-8 so log_kv is never evaluated at r = 0.
        r = lax.sqrt(jnp.maximum(lax.mul(chi, psi), 1e-8))
        log_ew = 0.5 * lax.log(lax.div(chi, psi)) + log_kv(lamb + 1, r) - log_kv(lamb, r)
        return jnp.exp(log_ew)

    @staticmethod
    def _gig_expected_inv_w(lamb: Scalar, chi: Scalar, psi: Scalar) -> Array:
        """E[1/W] for W ~ GIG(lamb, chi, psi) using log_kv ratios.

        E[1/W] = sqrt(psi / chi) * K_{lamb-1}(r) / K_lamb(r), r = sqrt(chi psi).
        """
        # Same floor on chi * psi as in ``_gig_expected_w``.
        r = lax.sqrt(jnp.maximum(lax.mul(chi, psi), 1e-8))
        log_einv = 0.5 * lax.log(lax.div(psi, chi)) + log_kv(lamb - 1, r) - log_kv(lamb, r)
        return jnp.exp(log_einv)

    @staticmethod
    @jit
    def _nll_value_and_grad(all_params: Array, x: Array) -> tuple:
        """Return ``(mean NLL, gradient)`` for the flat parameter vector.

        ``all_params`` is ordered ``(lamb, chi, psi, mu, sigma, gamma)``.
        The log-density is evaluated with ``stability=1e-30``.
        """
        def _nll(params_arr, x):
            params = GH._params_from_array(params_arr)
            return -jnp.mean(GH._stable_logpdf(1e-30, x, params))

        return value_and_grad(_nll)(all_params, x)

    def _fit_mle(self, x: jnp.ndarray, lr: float, maxiter: int) -> dict:
        """Fit all six parameters via projected gradient MLE with box
        constraints (chi, psi, sigma >= 1e-8; lamb, mu, gamma unbounded).
        Initial point from NIG MoM (see :func:`_nig_mom_gh_init`)."""
        eps: float = 1e-8
        constraints: tuple = (
            jnp.array([[-jnp.inf, eps, eps, -jnp.inf, eps, -jnp.inf]]).T,
            jnp.array([[jnp.inf, jnp.inf, jnp.inf, jnp.inf, jnp.inf, jnp.inf]]).T,
        )
        projection_options: dict = {"lower": constraints[0], "upper": constraints[1]}
        lamb0, chi0, psi0, mu0, sigma0, gamma0 = _nig_mom_gh_init(x)
        # Clamp the start onto the feasible box before projecting.
        params0: jnp.ndarray = jnp.array([
            lamb0,
            jnp.maximum(chi0, eps),
            jnp.maximum(psi0, eps),
            mu0,
            jnp.maximum(sigma0, eps),
            gamma0,
        ])
        res: dict = projected_gradient(
            f=self._mle_objective,
            x0=params0,
            projection_method="projection_box",
            projection_options=projection_options,
            x=x,
            lr=lr,
            maxiter=maxiter,
        )
        lamb, chi, psi, mu, sigma, gamma = res["x"]
        return GH._params_dict(
            lamb=lamb, chi=chi, psi=psi, mu=mu, sigma=sigma, gamma=gamma
        )

    @staticmethod
    def _em_body(carry: tuple, _: None, x: Array, lr: float, shape_steps: int) -> tuple:
        """Single ECME iteration as a pure function for use with lax.scan.

        Args:
            carry: Tuple of (lamb, chi, psi, mu, sigma, gamma).
            _: Unused scan input.
            x: Data array, bound via closure by the caller.
            lr: Step size for the inner gradient updates of (lamb, chi, psi).
            shape_steps: Number of inner gradient steps; used as a
                ``lax.scan`` length, so it must be a Python int.

        Returns:
            Updated carry and None (no stacked output).
        """
        eps: float = 1e-8
        lamb, chi, psi, mu, sigma, gamma = carry

        # --- E-step ---
        # Posterior of the mixing variable given x:
        #   W | x ~ GIG(lamb - 1/2, chi + Q, psi + (gamma / sigma)^2),
        # with Q = ((x - mu) / sigma)^2.
        Q = lax.pow(lax.div(lax.sub(x, mu), sigma), 2)
        psi_bar = psi + lax.pow(lax.div(gamma, sigma), 2)
        lam_post = lamb - 0.5
        chi_post = chi + Q
        # delta_i = E[W | x_i], eta_i = E[1/W | x_i], clipped to [eps, 1e10]
        # to keep the closed-form updates finite.
        delta = jnp.clip(GH._gig_expected_w(lam_post, chi_post, psi_bar), eps, 1e10)
        eta = jnp.clip(GH._gig_expected_inv_w(lam_post, chi_post, psi_bar), eps, 1e10)

        # --- CM-step 1: closed-form update for mu, gamma, sigma^2 ---
        # Maximisers of the expected complete-data log-likelihood
        #   -1/(2 sigma^2) * sum[(x - mu)^2 eta - 2 (x - mu) gamma + delta gamma^2].
        delta_bar = jnp.mean(delta)
        eta_bar = jnp.mean(eta)
        x_bar = jnp.mean(x)
        x_eta_bar = jnp.mean(x * eta)
        denom = eta_bar - 1.0 / delta_bar
        # Guard against a vanishing denominator.
        denom = jnp.where(jnp.abs(denom) < eps, eps, denom)
        mu = (x_eta_bar - x_bar / delta_bar) / denom
        gamma = (x_bar - mu) / delta_bar
        sigma_sq = jnp.mean(
            (x - mu) ** 2 * eta - 2 * (x - mu) * gamma + delta * gamma ** 2
        )
        sigma = jnp.sqrt(jnp.maximum(sigma_sq, eps))

        # --- CM-step 2: gradient descent for lamb, chi, psi ---
        # (mu, sigma, gamma) stay fixed at the CM-step 1 values; only the
        # first three gradient components are used.
        def _shape_step(shape_carry, _):
            l, c, p = shape_carry
            all_p = jnp.array([l, c, p, mu, sigma, gamma])
            _, g = GH._nll_value_and_grad(all_p, x)
            # NaN gradient components become no-op steps.
            g_shape = jnp.nan_to_num(g[:3], nan=0.0)
            l = l - lr * g_shape[0]
            c = jnp.maximum(c - lr * g_shape[1], eps)
            p = jnp.maximum(p - lr * g_shape[2], eps)
            return (l, c, p), None

        (lamb, chi, psi), _ = lax.scan(
            _shape_step, (lamb, chi, psi), None, length=shape_steps
        )
        return (lamb, chi, psi, mu, sigma, gamma), None

    def _fit_em(self, x: jnp.ndarray, lr: float, maxiter: int) -> dict:
        """Fit via ECME algorithm (McNeil et al. 2005, Section 3.4.2).

        The EM algorithm treats the GIG mixing variable W as latent data.
        It avoids the mu/gamma/sigma identifiability ridge by updating these
        parameters in closed form from the expected sufficient statistics,
        while the shape parameters (lamb, chi, psi) are updated via
        gradient descent on the observed log-likelihood (10 inner steps per
        EM iteration). The entire loop is compiled via ``lax.scan``.

        Args:
            x: Input data array.
            lr: Learning rate for shape parameter gradient steps.
            maxiter: Number of EM iterations.

        Returns:
            Fitted parameter dictionary.
        """
        # Initialise from sample moments: symmetric-ish shape
        # (lamb=0, chi=psi=1), location/scale from mean/std, and gamma
        # proportional to the sample skewness.
        sample_mean: Scalar = x.mean()
        sample_std: Scalar = x.std()
        z: jnp.ndarray = (x - sample_mean) / sample_std
        sample_skew: Scalar = jnp.mean(z ** 3)
        init_carry: tuple = (
            jnp.array(0.0),  # lamb
            jnp.array(1.0),  # chi
            jnp.array(1.0),  # psi
            sample_mean,  # mu
            sample_std,  # sigma
            sample_skew * sample_std * 0.25,  # gamma
        )
        shape_steps: int = 10

        def _em_step(carry, _):
            return self._em_body(carry, _, x, lr, shape_steps)

        final_carry, _ = lax.scan(_em_step, init_carry, None, length=maxiter)
        lamb, chi, psi, mu, sigma, gamma = final_carry
        return GH._params_dict(
            lamb=lamb, chi=chi, psi=psi, mu=mu, sigma=sigma, gamma=gamma
        )

    def _ldmle_reconstruct(
        self,
        lamb: Scalar,
        chi: Scalar,
        psi: Scalar,
        z: Scalar,
        sample_mean: Scalar,
        sample_variance: Scalar,
    ) -> tuple:
        """Map LDMLE coordinates ``(lamb, chi, psi, z)`` to ``(mu, sigma, gamma)``.

        ``gamma`` and ``sigma`` come from ``forward_reparam_1d`` applied to
        ``z``, the sample standard deviation and the GIG mean/variance of W;
        ``mu`` then matches the sample mean via ``mu = mean - E[W] * gamma``.
        Shared by the LDMLE objective and the final reconstruction in
        ``_fit_ldmle`` so both use exactly the same formulas.

        Returns:
            Tuple ``(mu, sigma, gamma)``.
        """
        sigma_hat = jnp.sqrt(sample_variance)
        gig_stats: dict = self._get_w_stats(lamb=lamb, chi=chi, psi=psi)
        gamma, sigma = forward_reparam_1d(
            z, sigma_hat, gig_stats["mean"], gig_stats["variance"],
        )
        mu = sample_mean - gig_stats["mean"] * gamma
        return mu, sigma, gamma

    def _ldmle_objective(
        self,
        params: jnp.ndarray,
        x: jnp.ndarray,
        sample_mean: Scalar,
        sample_variance: Scalar,
    ) -> Scalar:
        """LDMLE objective over (lamb, chi, psi, z). gamma follows from z via
        the feasibility reparam; mu and sigma follow from moment-matching.
        """
        lamb, chi, psi, z = params
        mu, sigma, gamma = self._ldmle_reconstruct(
            lamb, chi, psi, z, sample_mean, sample_variance
        )
        return self._mle_objective(
            params_arr=jnp.array([lamb, chi, psi, mu, sigma, gamma]), x=x
        )

    def _fit_ldmle(self, x: jnp.ndarray, lr: float, maxiter: int) -> dict:
        """Fit via LDMLE. Optimises ``(lamb, chi, psi, z)``; gamma is
        reparametrised so feasibility of the moment-matching
        reconstruction is structural. Initial ``(lamb, chi, psi)`` from
        NIG MoM (see :func:`_nig_mom_gh_init`); ``z₀`` inverts the
        reparam at the NIG-MoM γ̂."""
        eps = 1e-8
        constraints: tuple = (
            jnp.array([[-jnp.inf, eps, eps, -jnp.inf]]).T,
            jnp.array([[jnp.inf, jnp.inf, jnp.inf, jnp.inf]]).T,
        )
        sample_mean, sample_variance = x.mean(), x.var()
        lamb0, chi0, psi0, _mu0, _sigma0, gamma0 = _nig_mom_gh_init(x)
        chi0 = jnp.maximum(chi0, eps)
        psi0 = jnp.maximum(psi0, eps)
        # Invert gamma -> z under the LDMLE reparam, evaluated at the
        # NIG-MoM (lamb0, chi0, psi0). ``invert_gamma_to_z_1d`` needs
        # ``Var[W]`` at those starting GIG params, obtained through the
        # same ``_get_w_stats`` helper used on the forward path.
        gig_stats0 = self._get_w_stats(lamb=lamb0, chi=chi0, psi=psi0)
        # NOTE: the variance floor applies only to this initial inversion;
        # the objective and final reconstruction use the raw sample variance.
        sigma_hat = jnp.sqrt(jnp.maximum(sample_variance, eps))
        z0 = invert_gamma_to_z_1d(gamma0, sigma_hat, gig_stats0["variance"])
        params0: jnp.ndarray = jnp.array([lamb0, chi0, psi0, z0])
        projection_options: dict = {"lower": constraints[0], "upper": constraints[1]}
        res = projected_gradient(
            f=self._ldmle_objective,
            x0=params0,
            x=x,
            lr=lr,
            maxiter=maxiter,
            projection_method="projection_box",
            projection_options=projection_options,
            sample_mean=sample_mean,
            sample_variance=sample_variance,
        )
        lamb, chi, psi, z = res["x"]
        mu, sigma, gamma = self._ldmle_reconstruct(
            lamb, chi, psi, z, sample_mean, sample_variance
        )
        return self._params_dict(
            lamb=lamb, chi=chi, psi=psi, mu=mu, sigma=sigma, gamma=gamma
        )

    _supported_methods = frozenset({"em", "mle", "ldmle"})

    def fit(
        self,
        x: ArrayLike,
        method: str = "em",
        lr: float = 0.1,
        maxiter: int = 100,
        name: str = None,
    ):
        r"""Fit the distribution to the input data.

        Note:
            If you intend to jit wrap this function, ensure that
            ``method`` is a static argument.

        Args:
            x (ArrayLike): The input data to fit the distribution to.
            method (str): The fitting method to use. One of:
                ``'em'`` — the ECME algorithm (McNeil et al. 2005);
                ``'mle'`` — direct projected-gradient maximum likelihood;
                ``'ldmle'`` — low-dimensional maximum likelihood
                estimation. Defaults to ``'em'``.
            lr (float): Learning rate. For ``'em'`` it is the step size of
                the inner gradient updates of ``(lamb, chi, psi)``; for
                ``'mle'`` and ``'ldmle'`` it is passed to the projected
                gradient optimiser.
            maxiter (int): For ``'em'``, the number of ECME iterations;
                for ``'mle'`` and ``'ldmle'``, the optimiser's maximum
                number of iterations.
            name (str): Optional custom name for the fitted instance.

        Returns:
            GH: A fitted ``GH`` instance.

        Raises:
            ValueError: If ``method`` is not one of the accepted
                strings listed above.
        """
        self._check_method(method)
        x = _univariate_input(x)[0]
        fitters = {
            "mle": self._fit_mle,
            "em": self._fit_em,
            "ldmle": self._fit_ldmle,
        }
        fit_fn = fitters.get(method)
        if fit_fn is None:
            raise ValueError(
                f"Unknown GH fit method {method!r}. "
                f"Expected one of: {sorted(self._supported_methods)}."
            )
        return self._fitted_instance(fit_fn(x, lr=lr, maxiter=maxiter), name=name)

    # cdf
    @staticmethod
    def _params_from_array(params_arr: jnp.ndarray, *args, **kwargs) -> dict:
        """Reconstruct a parameter dictionary from a flat array ordered
        ``(lamb, chi, psi, mu, sigma, gamma)``."""
        lamb, chi, psi, mu, sigma, gamma = params_arr
        return GH._params_dict(
            lamb=lamb, chi=chi, psi=psi, mu=mu, sigma=sigma, gamma=gamma
        )

    @staticmethod
    def _pdf_for_cdf(x: ArrayLike, *params_tuple) -> Array:
        """Evaluate the PDF for numerical CDF integration."""
        params_array: jnp.ndarray = jnp.asarray(params_tuple).flatten()
        params: dict = GH._params_from_array(params_array)
        return lax.exp(GH._stable_logpdf(stability=0.0, x=x, params=params))

    def _cdf_anchor_scales(self, params: dict) -> Array:
        """Use the intrinsic sigma scale parameter, not sqrt(variance).

        The default base-class scale computes sqrt(variance) where the
        GH variance formula involves a ratio of modified Bessel
        functions (via the GIG mixing variable). That ratio is
        numerically delicate for extreme shape parameters; the sigma
        scale parameter is always well-defined and gives a cleaner
        bulk scale for the t-space breakpoint grid.
        """
        _, _, _, _, sigma, _ = GH._params_to_tuple(params)
        return jnp.asarray(sigma).reshape((1,))

    def cdf(self, x: ArrayLike, params: dict = None) -> Array:
        """Compute the CDF via numerical integration with a custom VJP."""
        params = self._resolve_params(params)
        cdf = _vjp_cdf(x=x, params=params)
        return self._enforce_support_on_cdf(x=x, cdf=cdf, params=params)
# Shared GH instance handed as ``dist`` to the numerical CDF helpers below.
gh = GH("GH")


def _vjp_cdf(x: ArrayLike, params: dict) -> Array:
    """Numerically integrated GH CDF (wrapped in ``custom_vjp`` below).

    Args:
        x: Points at which to evaluate the CDF.
        params: GH parameter dictionary, normalised with
            ``GH._args_transform`` before integration.

    Returns:
        The result of the shared ``_cdf`` integrator for ``gh``.
    """
    return _cdf(dist=gh, x=x, params=GH._args_transform(params))


# ``copy.deepcopy`` returns plain functions unchanged, so this keeps a handle
# to the undecorated implementation (used by ``cdf_fwd``) before the name
# ``_vjp_cdf`` is rebound to the ``custom_vjp`` wrapper.
_vjp_cdf_copy = deepcopy(_vjp_cdf)
_vjp_cdf = custom_vjp(_vjp_cdf)
def cdf_fwd(x: ArrayLike, params: dict) -> tuple[Array, tuple]:
    """Forward pass of the custom VJP registered on ``_vjp_cdf``.

    Normalises ``params`` with ``GH._args_transform`` and delegates to the
    shared ``_cdf_fwd`` helper. It passes ``_vjp_cdf_copy`` (the function
    captured before ``custom_vjp`` wrapping) rather than the wrapped
    ``_vjp_cdf`` as the CDF implementation.

    Args:
        x: Points at which to evaluate the CDF.
        params: GH parameter dictionary.

    Returns:
        Whatever ``_cdf_fwd`` returns: the primal output together with the
        residuals consumed by ``cdf_bwd``.
    """
    params = GH._args_transform(params)
    return _cdf_fwd(dist=gh, cdf_func=_vjp_cdf_copy, x=x, params=params)


_vjp_cdf.defvjp(cdf_fwd, cdf_bwd)