Source code for copulax._src.univariate.nig

"""File containing the copulAX implementation of the Normal-Inverse Gaussian distribution."""

import jax.numpy as jnp
from jax import lax, custom_vjp, random
from jax import Array
from jax.typing import ArrayLike
from copy import deepcopy

from copulax._src._distributions import Univariate
from copulax._src.univariate._utils import _univariate_input
from copulax._src._utils import _resolve_key
from copulax._src.typing import Scalar
from copulax._src.univariate._cdf import _cdf, cdf_bwd, _cdf_fwd
from copulax.special import log_kv
from copulax._src.univariate.wald import wald
from copulax._src.optimize import projected_gradient
from copulax._src.stats import skew, kurtosis


class NIG(Univariate):
    r"""The Normal-Inverse Gaussian distribution.

    This is a flexible, continuous 4-parameter distribution that can capture
    skewness and heavy tails. It is a special case of the Generalized
    Hyperbolic distribution, obtained by fixing
    :math:`\lambda = -\tfrac{1}{2}`.

    We adopt the parameterization used on Wikipedia (and by Karlis 2002):

    .. math::

        f(x|\mu, \alpha, \beta, \delta) = \frac{\alpha \delta
        K_{1}\left(\alpha \sqrt{\delta^2 + (x-\mu)^2}\right)}
        {\pi \sqrt{\delta^2 + (x-\mu)^2}}
        e^{\delta \sqrt{\alpha^2 - \beta^2} + \beta (x-\mu)}

    where :math:`K_{1}` is the modified Bessel function of the second kind of
    order 1, :math:`\mu \in \mathbb{R}` is the location parameter,
    :math:`\delta > 0` is the scale parameter, :math:`\alpha > 0` controls the
    tail heaviness, and :math:`\beta \in (-\alpha, \alpha)` controls the
    asymmetry / skewness.
    """

    mu: Array = None
    alpha: Array = None
    beta: Array = None
    delta: Array = None

    def __init__(
        self,
        name="NIG",
        *,
        mu=None,
        alpha=None,
        beta=None,
        delta=None,
    ):
        """Initialize the NIG distribution.

        Args:
            name: Name of the distribution.
            mu: Location parameter (real-valued).
            alpha: Tail heaviness parameter (positive).
            beta: Asymmetry parameter (between -alpha and alpha).
            delta: Scale parameter (positive).
        """
        super().__init__(name=name)

        def _scalar_or_none(value):
            # Unset parameters stay None; everything else becomes a 0-d
            # float array so downstream code sees a uniform scalar type.
            if value is None:
                return None
            return jnp.asarray(value, dtype=float).reshape(())

        self.mu = _scalar_or_none(mu)
        self.alpha = _scalar_or_none(alpha)
        self.beta = _scalar_or_none(beta)
        self.delta = _scalar_or_none(delta)

    @property
    def _stored_params(self):
        """Return stored parameters as a dict if all are set, else None."""
        stored = {
            "mu": self.mu,
            "alpha": self.alpha,
            "beta": self.beta,
            "delta": self.delta,
        }
        if any(value is None for value in stored.values()):
            return None
        return stored

    @classmethod
    def _params_dict(cls, mu: Scalar, alpha: Scalar, beta: Scalar, delta: Scalar) -> dict:
        """Pack the four parameters into a dict and pass it through ``_args_transform``."""
        d: dict = {"mu": mu, "alpha": alpha, "beta": beta, "delta": delta}
        return cls._args_transform(d)

    @staticmethod
    def _params_to_tuple(params: dict) -> tuple:
        """Return ``(mu, alpha, beta, delta)`` from a (transformed) params dict."""
        params = NIG._args_transform(params)
        return params["mu"], params["alpha"], params["beta"], params["delta"]

    @staticmethod
    def _params_to_array(params: dict) -> Array:
        """Return the parameters as a flat float array ordered (mu, alpha, beta, delta)."""
        return jnp.asarray(NIG._params_to_tuple(params), dtype=float).flatten()

    @classmethod
    def _support(cls, *args, **kwargs) -> Array:
        """Return the support bounds: the whole real line."""
        return jnp.array([-jnp.inf, jnp.inf])

    def example_params(self, *args, **kwargs) -> dict:
        """Return a fixed, admissible example parameter set (|beta| < alpha)."""
        return self._params_dict(mu=0.0, alpha=2.5, beta=1.5, delta=1.0)

    @staticmethod
    def _stable_logpdf(stability: Scalar, x: ArrayLike, params: dict) -> Array:
        """Compute the numerically stabilised log-pdf of the NIG distribution.

        Args:
            stability: Small non-negative constant used as a floor for
                ``alpha^2 - beta^2`` and added inside the logarithms.
            x: Points at which to evaluate the log-density.
            params: Parameter dict with keys mu, alpha, beta, delta.

        Returns:
            The log-density, reshaped to the shape reported by
            ``_univariate_input`` for ``x``.
        """
        mu, alpha, beta, delta = NIG._params_to_tuple(params)
        x, xshape = _univariate_input(x)

        # gamma = sqrt(alpha^2 - beta^2), floored so the sqrt argument cannot
        # go negative when |beta| is numerically at (or past) alpha.
        gamma = jnp.sqrt(jnp.maximum(alpha ** 2 - beta ** 2, stability))
        diff = x - mu
        r = jnp.sqrt(delta ** 2 + diff ** 2)

        log_exponent = delta * gamma + beta * diff
        # log of the numerator: K_1(alpha * r) * alpha * delta
        T = log_kv(1, alpha * r) + jnp.log(alpha + stability) + jnp.log(delta + stability)
        # log of the denominator: pi * r
        B = jnp.log(jnp.pi) + jnp.log(r + stability)

        logpdf = log_exponent + T - B
        return logpdf.reshape(xshape)

    # sampling
    def rvs(self, size: tuple | Scalar, params: dict | None = None, key: Array | None = None) -> Array:
        r"""Generate random variates via an IG-normal variance-mean mixture.

        Draws ``W`` from ``wald`` with ``mu = delta / gamma`` and
        ``lamb = delta^2`` (``gamma = sqrt(alpha^2 - beta^2)``), draws an
        independent standard normal ``Z``, and returns
        ``mu + beta * W + sqrt(W) * Z``.

        Args:
            size: Shape of the requested sample.
            params: Parameter dict; resolved via ``_resolve_params`` when omitted.
            key: PRNG key; resolved via ``_resolve_key`` when omitted.

        Returns:
            Array of NIG variates with the shape of the mixing draw ``W``.
        """
        params = self._resolve_params(params)
        key = _resolve_key(key)
        mu, alpha, beta, delta = NIG._params_to_tuple(params)
        gamma = jnp.sqrt(alpha ** 2 - beta ** 2)

        # Independent keys for the mixing variable and the Gaussian noise.
        key1, key2 = random.split(key)
        W = wald.rvs(
            size=size,
            params={"mu": delta / gamma, "lamb": delta ** 2},
            key=key1,
        )
        Z = random.normal(key2, shape=W.shape)
        return mu + beta * W + jnp.sqrt(W) * Z

    # stats
    def stats(self, params: dict | None = None) -> dict:
        """Return closed-form moments of the NIG distribution.

        Args:
            params: Parameter dict; resolved via ``_resolve_params`` when omitted.

        Returns:
            Dict with keys ``"mean"``, ``"variance"``, ``"skewness"`` and
            ``"kurtosis"``, passed through ``_scalar_transform``. The
            ``"kurtosis"`` entry is the *excess* kurtosis: it tends to 0 as
            ``delta * gamma`` grows.
        """
        params = self._resolve_params(params)
        mu, alpha, beta, delta = NIG._params_to_tuple(params)
        gamma = jnp.sqrt(alpha ** 2 - beta ** 2)

        mean = mu + delta * beta / gamma
        variance = delta * alpha ** 2 / gamma ** 3
        skewness = 3.0 * beta / (alpha * jnp.sqrt(delta * gamma))
        kurt = 3.0 * (1.0 + 4.0 * beta ** 2 / alpha ** 2) / (delta * gamma)

        return self._scalar_transform({
            "mean": mean,
            "variance": variance,
            "skewness": skewness,
            "kurtosis": kurt,
        })

    # -------------------------------------------------------------------- #
    # Fitting
    # -------------------------------------------------------------------- #
    @staticmethod
    def _fit_mom(x: jnp.ndarray) -> dict:
        """Fit the NIG distribution to data using method of moments (Karlis 2002, §3.1).

        For the moment estimator to exist we need ``3·kurt − 5·skew² > 0``.
        When this condition fails we fall back to a symmetric-NIG moment match
        ``(β = 0, α = 1/s, δ = s)`` where ``s`` is the sample standard
        deviation; this keeps the estimator in the admissible region for any
        input and gives a safe EM/MLE starting point.

        Args:
            x: Sample data.

        Returns:
            Parameter dict with keys mu, alpha, beta, delta.
        """
        eps = 1e-8
        sample_mean = jnp.mean(x)
        sample_var = jnp.var(x, ddof=1)
        sample_std = jnp.sqrt(jnp.maximum(sample_var, eps))
        sample_skew = skew(x)
        sample_kurt = kurtosis(x, fisher=True)  # excess kurtosis γ₂

        # Existence condition of the moment estimator.
        cond_value = 3.0 * sample_kurt - 5.0 * sample_skew ** 2

        def _regular_branch(_):
            gamma = 3.0 / (sample_std * jnp.sqrt(jnp.maximum(cond_value, eps)))
            beta = sample_skew * sample_std * gamma ** 2 / 3.0
            delta = sample_var * gamma ** 3 / jnp.maximum(beta ** 2 + gamma ** 2, eps)
            mu = sample_mean - beta * delta / jnp.maximum(gamma, eps)
            alpha = jnp.sqrt(beta ** 2 + gamma ** 2)
            return mu, alpha, beta, delta

        def _fallback_branch(_):
            mu = sample_mean
            delta = sample_std
            alpha = 1.0 / jnp.maximum(sample_std, eps)
            # Match the dtype of the other outputs so both lax.cond branches
            # return identically-typed values.
            beta = jnp.asarray(0.0, dtype=sample_std.dtype)
            return mu, alpha, beta, delta

        mu, alpha, beta, delta = lax.cond(
            cond_value > 0.0, _regular_branch, _fallback_branch, operand=None
        )
        return NIG._params_dict(mu=mu, alpha=alpha, beta=beta, delta=delta)

    @staticmethod
    def _em_body(carry: tuple, _: None, x: Array) -> tuple:
        """Single Karlis EM iteration as a pure function, suitable for ``lax.scan``.

        Every update is closed form — no inner gradient step, no ECME.

        Args:
            carry: Current ``(mu, alpha, beta, delta)``.
            _: Unused per-step scan input.
            x: Sample data.

        Returns:
            ``((mu_new, alpha_new, beta_new, delta_new), None)``.
        """
        eps = 1e-12
        mu, alpha, beta, delta = carry

        # --- E-step: posterior expectations of the IG mixing variable.
        # Karlis (2002) eqs (4)-(5): the posterior of Z|x is GIG(-1, δ√φ(x), α),
        # whose first moments reduce to ratios of Bessel K functions.
        diff = x - mu
        t = jnp.sqrt(delta ** 2 + diff ** 2)  # = δ·√φ(x)
        u = alpha * t  # argument shared by every Bessel K in the E-step

        # Log-space ratios protect against underflow in K_v(u) at large u
        # and overflow at small u (c.f. the skewed-T Bessel underflow fix).
        log_r_s = log_kv(0, u) - log_kv(1, u)
        log_r_w = log_kv(2, u) - log_kv(1, u)
        s_i = (t / alpha) * jnp.exp(log_r_s)  # E[Z|x_i]
        w_i = (alpha / t) * jnp.exp(log_r_w)  # E[Z^{-1}|x_i]

        # --- M-step: closed-form updates (Karlis 2002 p. 47-48).
        x_bar = jnp.mean(x)
        s_bar = jnp.mean(s_i)
        w_bar = jnp.mean(w_i)
        xw_bar = jnp.mean(x * w_i)

        # δ update: Λ̂ = 1 / mean(w_i − 1/s̄).
        inv_term = w_bar - 1.0 / jnp.maximum(s_bar, eps)
        lam = 1.0 / jnp.maximum(inv_term, eps)
        delta_new = jnp.sqrt(jnp.maximum(lam, eps))
        gamma_new = delta_new / jnp.maximum(s_bar, eps)

        # β update: ML regression coefficient for E[x|z] = μ + β·z with Var(x|z)=z.
        denom_b = 1.0 - s_bar * w_bar
        # Keep the denominator away from zero WITHOUT changing its sign: a
        # tiny negative value is replaced by -eps rather than +eps, so the
        # guard cannot flip the sign of the β update.
        signed_eps = jnp.where(denom_b < 0.0, -eps, eps)
        denom_b = jnp.where(jnp.abs(denom_b) < eps, signed_eps, denom_b)
        beta_new = (xw_bar - x_bar * w_bar) / denom_b

        mu_new = x_bar - beta_new * s_bar
        alpha_new = jnp.sqrt(gamma_new ** 2 + beta_new ** 2)
        return (mu_new, alpha_new, beta_new, delta_new), None

    def _fit_em(self, x: jnp.ndarray, maxiter: int) -> dict:
        """Fit the NIG distribution via the Karlis (2002) EM algorithm.

        The IG mixing variable ``Z`` is treated as latent. The GIG conjugacy
        of the IG prior gives a closed-form posterior, so both the E-step and
        M-step are analytic. Compiles via ``lax.scan``.

        Args:
            x: Sample data.
            maxiter: Number of EM iterations to run (fixed-length scan).

        Returns:
            Parameter dict with keys mu, alpha, beta, delta.
        """
        # Method-of-moments estimate is the starting point of the iteration.
        init_params = self._fit_mom(x)
        init_carry = NIG._params_to_tuple(init_params)

        def em_step(carry, step_input):
            return NIG._em_body(carry, step_input, x)

        final_carry, _ = lax.scan(em_step, init_carry, None, length=maxiter)
        mu, alpha, beta, delta = final_carry
        return NIG._params_dict(mu=mu, alpha=alpha, beta=beta, delta=delta)

    def _mle_objective_3p(
        self,
        params_arr: jnp.ndarray,
        x: jnp.ndarray,
        sample_mean: Scalar,
    ) -> Scalar:
        """3-parameter NIG objective exploiting the exact β-score identity.

        Karlis (2002) Lemma: ``∂L/∂β = 0`` gives ``x̄ = μ + δβ/γ`` exactly, so
        the observed-data MLE over ``(μ, α, β, δ)`` equals the MLE over
        ``(γ, β, δ)`` with ``μ = x̄ − δβ/γ`` and ``α = √(γ² + β²)``. We
        optimise in ``(γ, β, log δ)`` so ``δ`` stays strictly positive
        without a boundary constraint.

        Args:
            params_arr: Array unpacking to ``(gamma, beta, log_delta)``.
            x: Sample data.
            sample_mean: Sample mean of ``x``, used to eliminate ``mu``.

        Returns:
            Value of ``_mle_objective`` at the reconstructed
            ``(mu, alpha, beta, delta)``.
        """
        gamma, beta, log_delta = params_arr
        delta = jnp.exp(log_delta)
        alpha = jnp.sqrt(gamma ** 2 + beta ** 2)
        mu = sample_mean - delta * beta / gamma
        full_params = jnp.array([mu, alpha, beta, delta])
        return self._mle_objective(params_arr=full_params, x=x)

    def _fit_mle(self, x: jnp.ndarray, lr: float, maxiter: int) -> dict:
        """Fit via projected-gradient MLE over ``(γ, β, log δ)``.

        This is a *genuine* MLE — the β-score identity ``μ = x̄ − δβ/γ`` is an
        exact first-order condition, so eliminating ``μ`` loses no optimality
        relative to a 4-D search.

        Args:
            x: Sample data.
            lr: Learning rate forwarded to ``projected_gradient``.
            maxiter: Maximum iterations forwarded to ``projected_gradient``.

        Returns:
            Parameter dict with keys mu, alpha, beta, delta.
        """
        eps = 1e-6
        # Box constraints as column vectors: only γ is bounded (below, by eps).
        constraints = (
            jnp.array([[eps, -jnp.inf, -jnp.inf]]).T,
            jnp.array([[jnp.inf, jnp.inf, jnp.inf]]).T,
        )
        projection_options = {"lower": constraints[0], "upper": constraints[1]}

        # Initialise from MoM so γ, β, log δ start in the admissible region.
        # The MoM location is not needed: μ is re-derived from the sample mean.
        mom = self._fit_mom(x)
        _, alpha0, beta0, delta0 = NIG._params_to_tuple(mom)
        gamma0 = jnp.sqrt(jnp.maximum(alpha0 ** 2 - beta0 ** 2, eps))
        params0 = jnp.array([gamma0, beta0, jnp.log(jnp.maximum(delta0, eps))])

        sample_mean = x.mean()
        res = projected_gradient(
            f=self._mle_objective_3p,
            x0=params0,
            projection_method="projection_box",
            projection_options=projection_options,
            x=x,
            sample_mean=sample_mean,
            lr=lr,
            maxiter=maxiter,
        )

        # Map the optimiser's (γ, β, log δ) back to (μ, α, β, δ).
        gamma, beta, log_delta = res["x"]
        delta = jnp.exp(log_delta)
        alpha = jnp.sqrt(gamma ** 2 + beta ** 2)
        mu = sample_mean - delta * beta / gamma
        return NIG._params_dict(mu=mu, alpha=alpha, beta=beta, delta=delta)

    _supported_methods = frozenset({"em", "mle", "mom"})

    def fit(
        self,
        x: ArrayLike,
        method: str = "em",
        lr: float = 0.1,
        maxiter: int = 100,
        name: str | None = None,
    ):
        r"""Fit the NIG distribution to the input data.

        Note:
            If you intend to ``jit``-wrap this function, ensure that
            ``method`` is a static argument.

        Args:
            x (ArrayLike): The input data to fit the distribution to.
            method (str): Fitting method. One of:
                ``'em'`` — iterated Karlis (2002) EM step (numerical;
                **default**);
                ``'mle'`` — 3-parameter projected-gradient MLE via the exact
                β-score identity (numerical);
                ``'mom'`` — **closed-form** method of moments.
            lr (float): Learning rate for the projected-gradient MLE.
                Ignored for ``'em'`` and ``'mom'``.
            maxiter (int): Maximum number of iterations for iterative
                methods. Ignored for ``'mom'``.
            name (str): Optional custom name for the fitted instance.

        Returns:
            NIG: A fitted ``NIG`` instance.

        Raises:
            ValueError: If ``method`` is not one of the accepted strings
                listed above.
        """
        self._check_method(method)
        x = _univariate_input(x)[0]

        if method == "mle":
            fitted = self._fit_mle(x, lr=lr, maxiter=maxiter)
        elif method == "em":
            fitted = self._fit_em(x, maxiter=maxiter)
        elif method == "mom":
            fitted = self._fit_mom(x)
        else:
            # Defensive: kept in case _check_method lets a value through.
            raise ValueError(
                f"Unknown NIG fit method {method!r}. "
                f"Expected one of: {sorted(self._supported_methods)}."
            )
        return self._fitted_instance(fitted, name=name)

    # -------------------------------------------------------------------- #
    # CDF (numerical integration with custom VJP)
    # -------------------------------------------------------------------- #
    @staticmethod
    def _params_from_array(params_arr: jnp.ndarray, *args, **kwargs) -> dict:
        """Rebuild a params dict from an array ordered (mu, alpha, beta, delta)."""
        mu, alpha, beta, delta = params_arr
        return NIG._params_dict(mu=mu, alpha=alpha, beta=beta, delta=delta)

    @staticmethod
    def _pdf_for_cdf(x: ArrayLike, *params_tuple) -> Array:
        """PDF evaluator for the CDF integrator.

        Overrides the base to call the static ``_stable_logpdf`` directly
        (the base assumes ``pdf`` is a classmethod).
        """
        params_array: jnp.ndarray = jnp.asarray(params_tuple).flatten()
        params: dict = NIG._params_from_array(params_array)
        return lax.exp(NIG._stable_logpdf(stability=0.0, x=x, params=params))

    def _cdf_anchor_scales(self, params: dict) -> Array:
        """Use the intrinsic scale parameter delta.

        The NIG variance is ``delta * alpha^2 / (alpha^2 - beta^2)^(3/2)``,
        so a variance-based scale blows up as ``|beta|`` approaches ``alpha``
        (near-boundary case). The scale parameter ``delta`` is always finite
        and positive and gives a numerically robust bulk scale for the
        t-space breakpoint grid.

        Returns:
            ``delta`` as an array of shape ``(1,)``.
        """
        _, _, _, delta = NIG._params_to_tuple(params)
        return jnp.asarray(delta).reshape((1,))

    def cdf(self, x: ArrayLike, params: dict | None = None) -> Array:
        """Compute the CDF via numerical integration with a custom VJP.

        Args:
            x: Points at which to evaluate the CDF.
            params: Parameter dict; resolved via ``_resolve_params`` when omitted.

        Returns:
            CDF values after ``_enforce_support_on_cdf`` has been applied.
        """
        params = self._resolve_params(params)
        cdf_vals = _vjp_cdf(x=x, params=params)
        return self._enforce_support_on_cdf(x=x, cdf=cdf_vals, params=params)
# Module-level NIG instance handed to the generic CDF helpers as ``dist``.
nig = NIG("NIG")


def _vjp_cdf(x: ArrayLike, params: dict) -> Array:
    """Evaluate the NIG CDF through the shared ``_cdf`` helper.

    Parameters are passed through ``NIG._args_transform`` first. This name is
    rebound to its ``custom_vjp``-wrapped version below.
    """
    params = NIG._args_transform(params)
    return _cdf(dist=nig, x=x, params=params)


# Order matters: keep a handle to the plain function BEFORE the name is
# rebound to the custom_vjp wrapper; the forward rule passes this handle on
# as ``cdf_func``.
_vjp_cdf_copy = deepcopy(_vjp_cdf)
_vjp_cdf = custom_vjp(_vjp_cdf)
def cdf_fwd(x: ArrayLike, params: dict) -> tuple[Array, tuple]:
    """Forward rule of the custom VJP for the NIG CDF.

    Transforms ``params`` with ``NIG._args_transform`` and delegates to the
    shared ``_cdf_fwd`` helper, supplying the module-level ``nig`` instance
    and the un-wrapped CDF function.

    Args:
        x: Points at which the CDF is evaluated.
        params: Parameter dict with keys mu, alpha, beta, delta.

    Returns:
        Whatever ``_cdf_fwd`` returns (annotated as the CDF values plus a
        tuple of residuals for the backward rule).
    """
    transformed = NIG._args_transform(params)
    return _cdf_fwd(
        dist=nig,
        cdf_func=_vjp_cdf_copy,
        x=x,
        params=transformed,
    )
# Register the forward (cdf_fwd) and backward (cdf_bwd) rules on the
# custom_vjp-wrapped CDF.
_vjp_cdf.defvjp(cdf_fwd, cdf_bwd)