# Source code for copulax._src.univariate.gig

"""File containing the copulAX implementation of the Generalized Inverse
Gaussian distribution."""

import jax.numpy as jnp
from jax import random, lax, custom_vjp, jit
from jax import Array
from jax.typing import ArrayLike
from copy import deepcopy

from copulax._src._distributions import Univariate
from copulax._src.typing import Scalar
from copulax._src.univariate._utils import _univariate_input
from copulax._src._utils import _resolve_key
from copulax._src.univariate._cdf import _cdf, cdf_bwd, _cdf_fwd
from copulax._src.optimize import projected_gradient
from copulax.special import kv, log_kv


class GIG(Univariate):
    r"""The Generalized Inverse Gaussian distribution is a three-parameter
    continuous family on :math:`(0, \infty)` that arises as the mixing
    distribution of the generalised hyperbolic family. The inverse-gamma,
    gamma, and inverse-Gaussian distributions are special / limiting cases.

    The McNeil et al (2005) parameterisation is used. The PDF is

    .. math::

        f(x | \lambda, \chi, \psi) = \frac{(\psi / \chi)^{\lambda / 2}}
        {2\, K_{\lambda}\!\bigl(\sqrt{\chi \psi}\bigr)}\,
        x^{\lambda - 1} \exp\!\left(-\tfrac{1}{2}(\chi / x + \psi x)\right),
        \qquad x > 0

    where :math:`K_{\lambda}` is the modified Bessel function of the second
    kind, :math:`\lambda \in \mathbb{R}` is the shape / order parameter,
    :math:`\chi > 0` is the concentration (or penalty on small :math:`x`),
    and :math:`\psi > 0` is the rate (penalty on large :math:`x`).

    https://en.wikipedia.org/wiki/Generalized_inverse_Gaussian_distribution
    """

    # Optional stored parameters; each is a 0-d float array once set via
    # ``__init__`` and ``None`` otherwise.
    lamb: Array = None
    chi: Array = None
    psi: Array = None

    def __init__(self, name="GIG", *, lamb=None, chi=None, psi=None):
        """Initialize the Generalized Inverse Gaussian distribution.

        Args:
            name: Display name for the distribution.
            lamb: Shape parameter (real-valued).
            chi: Concentration parameter (strictly positive).
            psi: Rate parameter (strictly positive).
        """
        super().__init__(name)
        # Every supplied parameter is coerced to a scalar (shape ``()``)
        # float array; omitted parameters stay ``None``.
        self.lamb = (
            jnp.asarray(lamb, dtype=float).reshape(()) if lamb is not None else None
        )
        self.chi = (
            jnp.asarray(chi, dtype=float).reshape(()) if chi is not None else None
        )
        self.psi = (
            jnp.asarray(psi, dtype=float).reshape(()) if psi is not None else None
        )

    @property
    def _stored_params(self):
        """Return stored parameters if all are set, else None."""
        if self.lamb is None or self.chi is None or self.psi is None:
            return None
        return {"lamb": self.lamb, "chi": self.chi, "psi": self.psi}

    @classmethod
    def _params_dict(cls, lamb: Scalar, chi: Scalar, psi: Scalar) -> dict:
        """Create a parameter dictionary from lamb, chi, and psi values."""
        d: dict = {"lamb": lamb, "chi": chi, "psi": psi}
        return cls._args_transform(d)

    @staticmethod
    def _params_to_tuple(params: dict) -> tuple:
        """Extract (lamb, chi, psi) from the parameter dictionary."""
        params = GIG._args_transform(params)
        return params["lamb"], params["chi"], params["psi"]

    @staticmethod
    def _params_to_array(params: dict) -> Array:
        """Convert the parameter dictionary to a flat array."""
        return jnp.asarray(GIG._params_to_tuple(params)).flatten()

    @classmethod
    def _support(cls, *args, **kwargs) -> Array:
        """Return the support ``(0, inf)``."""
        return jnp.array([0.0, jnp.inf])

    def example_params(self, *args, **kwargs):
        """Return an example parameter dictionary (``lamb = chi = psi = 1``)."""
        return self._params_dict(lamb=1.0, chi=1.0, psi=1.0)

    @staticmethod
    def _stable_logpdf(stability: Scalar, x: ArrayLike, params: dict) -> Array:
        """Compute the numerically stabilized log-PDF of the GIG distribution.

        Args:
            stability: Small constant added to ``chi`` in the denominator and
                inside the logarithm of the normalising term; ``0.0``
                recovers the exact log-density.
            x: Points at which to evaluate the log-density.
            params: Parameter dictionary with keys ``lamb``, ``chi``, ``psi``.

        Returns:
            Log-density values reshaped to the shape reported by
            ``_univariate_input``; NaN entries are replaced by ``-inf``.
        """
        lamb, chi, psi = GIG._params_to_tuple(params)
        x, xshape = _univariate_input(x)

        # x-dependent part: (lamb - 1) * log(x) - (chi / x + psi * x) / 2
        var = lax.add(
            lax.mul(lamb - 1, lax.log(x)),
            -0.5 * (lax.mul(chi, lax.pow(x, -1)) + lax.mul(psi, x)),
        )

        # log normalising constant:
        #   (lamb / 2) * log(psi / chi) - log(2 * K_lamb(sqrt(chi * psi)))
        cT = lax.mul(0.5 * lamb, lax.log((psi / (chi + stability)) + stability))
        cB = log_kv(lamb, lax.pow(lax.mul(chi, psi), 0.5)) + jnp.log(2.0)
        c = lax.sub(cT, cB)

        logpdf_raw = lax.add(var, c)
        logpdf: jnp.ndarray = jnp.where(jnp.isnan(logpdf_raw), -jnp.inf, logpdf_raw)
        return logpdf.reshape(xshape)

    def logpdf(self, x: ArrayLike, params: dict = None) -> Array:
        """Compute the log probability density function."""
        params = self._resolve_params(params)
        logpdf = GIG._stable_logpdf(stability=0.0, x=x, params=params)
        return self._enforce_support_on_logpdf(x=x, logpdf=logpdf, params=params)

    # sampling
    # Uses the method outlined by Luc Devroye in "Random variate generation for
    # the generalized inverse Gaussian distribution" (2014).
    def _devroye(self, x, alpha, lamb):
        """Evaluate the Devroye (2014) acceptance log-density."""
        return -alpha * (jnp.cosh(x) - 1) - lamb * (jnp.exp(x) - x - 1)

    def _devroye_grad(self, x, alpha, lamb):
        """Gradient of the Devroye acceptance log-density."""
        return -alpha * jnp.sinh(x) - lamb * (jnp.exp(x) - 1)

    def _new_single_rv(self, carry, _):
        """One iteration of the Devroye rejection sampler.

        Args:
            carry: ``(key, x, stop, count, constants)``; the incoming ``x``
                and ``stop`` are overwritten by this iteration.
            _: Unused scan input.

        Returns:
            ``(new_carry, None)`` where ``new_carry`` holds the advanced key,
            the fresh proposal, its acceptance flag, and ``count + 1``.
        """
        key, _, stop, count, constants = carry
        lamb, alpha, t, s, t_, s_, eta, zeta, theta, xi, p, r, q = constants

        key, subkey = random.split(key)
        u, v, w = random.uniform(subkey, shape=(3,))

        # Choose one of the three hat pieces with ``u``: right exponential
        # tail / left exponential tail first, then the uniform middle piece
        # overrides when ``u`` falls in its share.
        x = jnp.where(
            u < (q + r) / (q + p + r),
            t_ + r * lax.log(1 / v),
            -s_ - p * lax.log(1 / v),
        )
        x = jnp.where(u < q / (q + p + r), -s_ + q * v, x)

        # checking stopping condition
        chi = (
            jnp.where(jnp.logical_and(-s_ < x, x < t_), 1.0, 0.0)
            + jnp.where(t_ < x, jnp.exp(-eta - zeta * (x - t)), 0.0)
            + jnp.where(x < -s_, jnp.exp(-theta + xi * (x + s)), 0.0)
        )
        stop = w * chi <= jnp.exp(self._devroye(x, alpha, lamb))
        return (key, x, stop, count + 1, constants), None

    @jit
    def _generate_single_rv(self, key: Array, constants: tuple) -> tuple[Array, Array]:
        """Generate a single GIG random variate using the Devroye (2014) algorithm.

        Args:
            key: JAX PRNG key.
            constants: Precomputed sampler constants, in the order unpacked
                by ``_new_single_rv``.

        Returns:
            ``(key, x)``: the advanced key and the drawn (log-scale) variate.
        """
        # NOTE(review): the loop is capped at ``maxiter`` proposals. If none
        # is accepted, the value returned is the last *rejected* proposal,
        # because ``x`` in the carry is overwritten on every iteration.
        maxiter = 10
        init = (key, jnp.array(jnp.nan), False, 0, constants)
        res = lax.scan(
            (
                # Once ``stop`` (carry[2]) is set, pass the carry through
                # unchanged; otherwise draw a new proposal.
                lambda carry, _: lax.cond(
                    carry[2],
                    (lambda carry, _: (carry, _)),
                    self._new_single_rv,
                    carry,
                    None,
                )
            ),
            init,
            None,
            maxiter,
        )[0]
        return res[0], res[1]

    def rvs(
        self, size: tuple | Scalar, params: dict = None, key: Array = None
    ) -> Array:
        """Generate random variates using the Devroye (2014) rejection algorithm.

        Args:
            size: Shape of the output array. Either a sequence of dimensions
                or a single scalar count (int, float, or NumPy/0-d scalar),
                which yields a 1-D output of that length.
            params: Distribution parameters. Uses stored parameters if None.
            key: JAX PRNG key. A default key is used if None.

        Returns:
            Array of GIG random samples.
        """
        params = self._resolve_params(params)
        key = _resolve_key(key)

        # getting parameters
        lamb, chi, psi = self._params_to_tuple(params)
        # The sampler works with |lamb|; a negative lamb is handled at the
        # end by inverting the variate (exponent ``sign_lamb``).
        sign_lamb = jnp.where(jnp.sign(lamb) >= 0, 1, -1)
        lamb = jnp.abs(lamb)
        omega = lax.sqrt(chi * psi)
        alpha = lax.sqrt(jnp.pow(omega, 2) + jnp.pow(lamb, 2)) - lamb

        # getting positive constant t
        _devroye_1 = self._devroye(x=1, alpha=alpha, lamb=lamb)
        t = jnp.where(-_devroye_1 > 2, lax.sqrt(2 / (alpha + lamb)), 1)
        t = jnp.where(-_devroye_1 < 0.5, lax.log(4 / (alpha + 2 * lamb)), t)

        # getting positive constant s
        _devroye_minus_1 = self._devroye(x=-1, alpha=alpha, lamb=lamb)
        s = jnp.where(
            -_devroye_minus_1 > 2, lax.sqrt(4 / (alpha * jnp.cosh(1) + lamb)), 1
        )
        s = jnp.where(
            -_devroye_minus_1 < 0.5,
            jnp.min(
                jnp.array(
                    [
                        1 / lamb,
                        lax.log(
                            1 + (1 / alpha) + lax.sqrt(jnp.pow(alpha, -2) + (2 / alpha))
                        ),
                    ]
                )
            ),
            s,
        )

        # Computing constants
        eta, zeta, theta, xi = (
            -self._devroye(x=t, alpha=alpha, lamb=lamb),
            -self._devroye_grad(x=t, alpha=alpha, lamb=lamb),
            -self._devroye(x=-s, alpha=alpha, lamb=lamb),
            self._devroye_grad(x=-s, alpha=alpha, lamb=lamb),
        )
        p, r = 1 / xi, 1 / zeta
        t_ = t - r * eta
        s_ = s - p * theta
        q = t_ + s_

        # Generating random variables
        constants: tuple = (lamb, alpha, t, s, t_, s_, eta, zeta, theta, xi, p, r, q)

        # Normalise ``size`` to a tuple of ints so that the sample count and
        # the final reshape agree. A scalar (including a float such as 5.0 or
        # a NumPy scalar) is not iterable and becomes a 1-D shape.
        try:
            shape = tuple(int(dim) for dim in size)
        except TypeError:
            shape = (int(size),)
        num_samples = 1
        for dim in shape:
            num_samples *= dim

        # Draw the variates sequentially, threading the PRNG key through the
        # scan carry.
        X: jnp.ndarray = lax.scan(
            (lambda key, _: self._generate_single_rv(key, constants)),
            key,
            None,
            num_samples,
        )[1]

        # Map the log-scale draws back to the (lamb, chi, psi) scale.
        frac = lax.div(lamb, omega)
        c = frac + lax.sqrt(1 + lax.pow(frac, 2))
        scale = lax.sqrt(lax.div(chi, psi))
        return (scale * jnp.pow((c * jnp.exp(X)), sign_lamb)).reshape(shape)

    # stats
    @staticmethod
    def _mode(params: dict) -> Array:
        """Closed-form mode ``((lamb - 1) + sqrt((lamb - 1)^2 + chi * psi)) / psi``
        (valid for ``chi, psi > 0``)."""
        lamb, chi, psi = GIG._params_to_tuple(params)
        return lax.div(
            (lamb - 1) + lax.sqrt(lax.pow(lamb - 1, 2) + lax.mul(chi, psi)), psi
        )

    def stats(self, params: dict = None) -> dict:
        """Compute distribution statistics (mean, variance, std, mode).

        Uses analytical formulas based on modified Bessel functions,
        evaluated in log space via ``log_kv``:

            E[X]   = sqrt(chi / psi) * K_{lamb+1}(r) / K_lamb(r)
            E[X^2] = (chi / psi)     * K_{lamb+2}(r) / K_lamb(r)

        with ``r = sqrt(chi * psi)``. The values are returned as computed;
        no fallback is applied if they are not finite.

        Args:
            params: Distribution parameters. Uses stored parameters if None.

        Returns:
            Dictionary with keys ``mean``, ``variance``, ``std`` and ``mode``.
        """
        params = self._resolve_params(params)
        lamb, chi, psi = self._params_to_tuple(params)

        # calculating mean
        r = lax.sqrt(lax.mul(chi, psi))
        log_frac = lax.log(chi) - lax.log(psi)
        log_kv_lamb = log_kv(lamb, r)
        log_kv_lamb_plus_1 = log_kv(lamb + 1, r)
        log_mean = 0.5 * log_frac + log_kv_lamb_plus_1 - log_kv_lamb
        mean = jnp.exp(log_mean)

        # calculating variance
        log_kv_lamb_plus_2 = log_kv(lamb + 2, r)
        log_second_moment = log_frac + log_kv_lamb_plus_2 - log_kv_lamb
        second_moment = jnp.exp(log_second_moment)
        variance = lax.sub(second_moment, lax.pow(mean, 2))
        std = jnp.sqrt(variance)

        return self._scalar_transform(
            {"mean": mean, "variance": variance, "std": std, "mode": GIG._mode(params)}
        )

    # fitting
    @staticmethod
    def _sample_moments(x: jnp.ndarray) -> tuple:
        """Compute method-of-moments initial estimates for (lamb, chi, psi).

        Uses the large-r asymptotic approximation (r = sqrt(chi*psi)) where
        K_{λ+1}(r)/K_λ(r) ≈ 1:

            E[X]   ≈ sqrt(chi/psi)
            Var(X) ≈ (chi/psi) / sqrt(chi*psi) = E[X]² / r

        Solving for chi and psi:

            r   ≈ mean² / var   (from Var ≈ mean² / r)
            chi ≈ mean * r      (from chi = sqrt(chi/psi * chi*psi))
            psi ≈ r / mean      (from psi = sqrt(chi*psi / (chi/psi)))

        The estimates are clipped to fixed ranges and ``lamb`` is started
        at ``1.0``.
        """
        m = jnp.mean(x)
        v = jnp.var(x)
        # r = sqrt(chi*psi) ≈ mean^2 / var; the 1e-10 terms guard against
        # division by zero.
        r0 = jnp.clip(m ** 2 / (v + 1e-10), 0.5, 50.0)
        chi0 = jnp.clip(m * r0, 1e-4, 100.0)
        psi0 = jnp.clip(r0 / (m + 1e-10), 1e-4, 100.0)
        lamb0 = 1.0
        return lamb0, chi0, psi0

    def _fit_mle(self, x: jnp.ndarray, lr: float, maxiter: int) -> dict:
        """Fit via projected gradient MLE with box constraints on chi and psi.

        Args:
            x: Data to fit.
            lr: Learning rate passed to ``projected_gradient``.
            maxiter: Maximum number of optimizer iterations.

        Returns:
            Parameter dictionary built from the optimizer's ``"x"`` result.
        """
        eps = 1e-8
        # Column vectors of lower / upper bounds for (lamb, chi, psi):
        # lamb is unconstrained, chi and psi are kept >= eps.
        constraints: tuple = (
            jnp.array([[-jnp.inf, eps, eps]]).T,
            jnp.array([[jnp.inf, jnp.inf, jnp.inf]]).T,
        )
        projection_options: dict = {"lower": constraints[0], "upper": constraints[1]}

        lamb0, chi0, psi0 = self._sample_moments(x)
        params0: jnp.ndarray = jnp.array([lamb0, chi0, psi0])

        res = projected_gradient(
            f=self._mle_objective,
            x0=params0,
            projection_method="projection_box",
            projection_options=projection_options,
            x=x,
            lr=lr,
            maxiter=maxiter,
        )
        lamb, chi, psi = res["x"]
        return self._params_dict(lamb=lamb, chi=chi, psi=psi)

    _supported_methods = frozenset({"mle"})

    def fit(
        self, x: ArrayLike, lr: float = 0.1, maxiter: int = 100, name: str = None
    ):
        r"""Fit the Generalized Inverse Gaussian distribution to data via
        **numerical** MLE (projected gradient on the negative log-likelihood).

        Args:
            x (ArrayLike): The input data to fit the distribution to.
            lr (float): Learning rate for optimization.
            maxiter (int): Maximum number of iterations for optimization.
            name (str): Optional custom name for the fitted instance.

        Returns:
            GIG: A fitted ``GIG`` instance.
        """
        x: jnp.ndarray = _univariate_input(x)[0]
        return self._fitted_instance(
            self._fit_mle(x=x, lr=lr, maxiter=maxiter), name=name
        )

    # cdf
    @staticmethod
    def _params_from_array(params_arr, *args, **kwargs) -> dict:
        """Reconstruct a parameter dictionary from a flat array."""
        lamb, chi, psi = params_arr
        return GIG._params_dict(lamb=lamb, chi=chi, psi=psi)

    @staticmethod
    def _pdf_for_cdf(x: ArrayLike, *params_tuple) -> Array:
        """Evaluate the PDF for numerical CDF integration."""
        params_array: jnp.ndarray = jnp.asarray(params_tuple).flatten()
        params: dict = GIG._params_from_array(params_array)
        return lax.exp(GIG._stable_logpdf(stability=0.0, x=x, params=params))

    def _cdf_anchors(self, params: dict) -> Array:
        """Use the closed-form mode (delegates to ``_mode``) as the bulk anchor.

        For GIG, the mode is a tighter bulk anchor than the mean — GIG is
        strongly skewed for small ``lamb`` (e.g. ``lamb < 0`` puts the mean
        out in the right tail while the mode sits near the lower support
        bound).
        """
        return jnp.asarray(GIG._mode(params)).reshape((1,))

    def cdf(self, x: ArrayLike, params: dict = None) -> Array:
        """Compute the CDF via numerical integration with a custom VJP."""
        params = self._resolve_params(params)
        cdf = _vjp_cdf(x=x, params=params)
        return self._enforce_support_on_cdf(x=x, cdf=cdf, params=params)
# Module-level GIG instance (no stored parameters) that is handed to the
# generic numerical-CDF routines as ``dist``.
gig = GIG("GIG")


def _vjp_cdf(x: ArrayLike, params: dict) -> Array:
    """Evaluate the GIG CDF at ``x`` through the shared ``_cdf`` routine.

    Args:
        x: Points at which to evaluate the CDF.
        params: Parameter dictionary; passed through ``GIG._args_transform``
            before use.

    Returns:
        The value returned by ``_cdf`` for the module-level ``gig`` instance.
    """
    params = GIG._args_transform(params)
    return _cdf(dist=gig, x=x, params=params)


# Keep a handle on the undecorated function for use in the forward pass, then
# rebind the public name to its ``custom_vjp`` wrapper. ``deepcopy`` treats
# plain functions as atomic, so ``_vjp_cdf_copy`` is the very same function
# object that was defined above.
_vjp_cdf_copy = deepcopy(_vjp_cdf)
_vjp_cdf = custom_vjp(_vjp_cdf)
def cdf_fwd(x: ArrayLike, params: dict) -> tuple[Array, tuple]:
    """Forward pass of the custom VJP for the GIG CDF.

    Args:
        x: Points at which to evaluate the CDF.
        params: Parameter dictionary; passed through ``GIG._args_transform``
            before use.

    Returns:
        Whatever ``_cdf_fwd`` returns for the module-level ``gig`` instance
        and the undecorated CDF function; annotated as the primal output
        together with a tuple of residuals for the backward pass.
    """
    params = GIG._args_transform(params)
    return _cdf_fwd(dist=gig, cdf_func=_vjp_cdf_copy, x=x, params=params)
# Attach the forward (``cdf_fwd``) and backward (``cdf_bwd``) rules to the
# ``custom_vjp``-wrapped CDF so reverse-mode differentiation uses them.
_vjp_cdf.defvjp(cdf_fwd, cdf_bwd)