"""File containing the copulAX implementation of the Gamma distribution."""
import jax.numpy as jnp
from jax import lax, random, scipy
from jax import Array
from jax.typing import ArrayLike
from copulax._src._distributions import Univariate
from copulax._src.special import igammainv
from copulax._src.typing import Scalar
from copulax._src.univariate._utils import _univariate_input
from copulax._src._utils import _resolve_key
from copulax._src.optimize import projected_gradient
[docs]
class Gamma(Univariate):
    r"""The gamma distribution is a two-parameter continuous family on
    :math:`(0, \infty)` that includes the exponential, Erlang, and
    chi-squared distributions as special cases. The rate parameterisation
    of McNeil et al (2005) is used.

    The PDF is

    .. math::

        f(x | \alpha, \beta) =
        \frac{\beta^{\alpha}}{\Gamma(\alpha)}\,
        x^{\alpha - 1} e^{-\beta x},
        \qquad x > 0

    where :math:`\alpha > 0` is the shape parameter and
    :math:`\beta > 0` is the rate parameter (so the mean is
    :math:`\alpha / \beta`).

    https://en.wikipedia.org/wiki/Gamma_distribution
    """

    # Stored shape / rate parameters; None until set via __init__ or fitting.
    alpha: Array = None
    beta: Array = None

    def __init__(self, name="Gamma", *, alpha=None, beta=None):
        """Initialize the Gamma distribution.

        Args:
            name: Display name for the distribution.
            alpha: Shape parameter. Stored as a 0-d float array, or None.
            beta: Rate parameter. Stored as a 0-d float array, or None.
        """
        super().__init__(name)
        self.alpha = (
            jnp.asarray(alpha, dtype=float).reshape(()) if alpha is not None else None
        )
        self.beta = (
            jnp.asarray(beta, dtype=float).reshape(()) if beta is not None else None
        )

    @property
    def _stored_params(self):
        """Return stored parameters if all are set, else None."""
        if self.alpha is None or self.beta is None:
            return None
        return {"alpha": self.alpha, "beta": self.beta}

    @classmethod
    def _params_dict(cls, alpha: Scalar, beta: Scalar) -> dict:
        """Create a parameter dictionary from alpha (shape) and beta (rate)."""
        d: dict = {"alpha": alpha, "beta": beta}
        return cls._args_transform(d)

    def _params_to_tuple(self, params: dict):
        """Extract ``(alpha, beta)`` from the parameter dictionary."""
        params = self._args_transform(params)
        return params["alpha"], params["beta"]

    def example_params(self, *args, **kwargs):
        """Return example parameters ``alpha=1.0, beta=1.0``.

        With unit shape and unit rate this is the standard exponential
        distribution. Extra positional/keyword arguments are accepted and
        ignored.
        """
        return self._params_dict(alpha=1.0, beta=1.0)

    @classmethod
    def _support(cls, *args, **kwargs) -> Array:
        """Return the support ``[0, inf)`` as a length-2 array."""
        return jnp.array([0.0, jnp.inf])

    def _stable_logpdf(self, stability: Scalar, x: ArrayLike, params: dict) -> Array:
        """Compute the numerically stabilized log-PDF of the Gamma distribution.

        Args:
            stability: Non-negative constant added to ``beta`` inside the
                logarithm so that ``log(beta)`` stays finite if ``beta`` is 0
                (the non-negative projection used when fitting allows this).
            x: Points at which to evaluate the log-density.
            params: Parameter dictionary with keys ``"alpha"`` and ``"beta"``.

        Returns:
            The log-density, reshaped to the shape reported by
            ``_univariate_input``.
        """
        x, xshape = _univariate_input(x)
        alpha, beta = self._params_to_tuple(params)
        # xlogy(0, 0) == 0, so alpha == 1 (exponential) at x == 0 yields
        # log(beta) rather than 0 * log(0) = 0 * -inf = NaN.
        logpdf: jnp.ndarray = (
            alpha * jnp.log(beta + stability)
            - lax.lgamma(alpha)
            + scipy.special.xlogy(alpha - 1, x)
            - beta * x
        )
        return logpdf.reshape(xshape)

    def cdf(self, x: ArrayLike, params: dict | None = None) -> Array:
        """Compute the CDF via the regularized lower incomplete gamma function.

        Evaluates ``P(alpha, beta * x)`` and then applies the base-class
        support enforcement (``_enforce_support_on_cdf``).
        """
        params = self._resolve_params(params)
        x, xshape = _univariate_input(x)
        alpha, beta = self._params_to_tuple(params)
        cdf: jnp.ndarray = scipy.special.gammainc(a=alpha, x=beta * x)
        return self._enforce_support_on_cdf(
            x=x, cdf=cdf.reshape(xshape), params=params
        )

    # ppf
    def _ppf(self, q: ArrayLike, params: dict, *args, **kwargs) -> Array:
        """Compute the percent-point function (inverse CDF).

        Uses ``igammainv`` (presumably the inverse of the regularized lower
        incomplete gamma in its second argument) on the unit-rate scale, then
        divides by the rate ``beta``.
        """
        alpha, beta = self._params_to_tuple(params)
        return igammainv(a=alpha, p=q) / beta

    # sampling
    def rvs(
        self, size: tuple | Scalar, params: dict | None = None, key: Array = None
    ) -> Array:
        """Generate random variates from the Gamma distribution.

        Unit-rate gamma variates with shape ``alpha`` are drawn and then
        divided by the rate ``beta``.

        Args:
            size: Sample shape, passed as ``shape`` to ``jax.random.gamma``.
            params: Optional parameter dictionary, resolved through
                ``_resolve_params``.
            key: Optional PRNG key, resolved through ``_resolve_key``.
        """
        params = self._resolve_params(params)
        key = _resolve_key(key)
        alpha, beta = self._params_to_tuple(params)
        unscaled_rvs: jnp.ndarray = random.gamma(key, shape=size, a=alpha)
        return unscaled_rvs / beta

    # stats
    def stats(self, params: dict | None = None) -> dict:
        """Compute distribution statistics.

        Returns a dict with ``mean`` (alpha/beta), ``mode`` ((alpha-1)/beta
        for alpha >= 1, else 0), ``variance`` (alpha/beta**2), ``std``,
        ``skewness`` (2/sqrt(alpha)) and ``kurtosis``. The ``kurtosis`` value
        is 6/alpha, which is the *excess* kurtosis of the gamma distribution.
        """
        params = self._resolve_params(params)
        alpha, beta = self._params_to_tuple(params)
        mean: Array = alpha / beta
        # For alpha < 1 the density is unbounded at 0, so the mode is 0.
        mode: Array = jnp.where(alpha >= 1.0, (alpha - 1) / beta, 0.0)
        variance: Array = alpha / (beta**2)
        std: Array = jnp.sqrt(variance)
        skewness: Array = 2.0 / jnp.sqrt(alpha)
        kurtosis: Array = 6.0 / alpha
        return self._scalar_transform(
            {
                "mean": mean,
                "mode": mode,
                "variance": variance,
                "std": std,
                "skewness": skewness,
                "kurtosis": kurtosis,
            }
        )

    # fitting
    @staticmethod
    def _sample_moments(x: jnp.ndarray) -> tuple:
        """Method-of-moments starting point under the rate parameterisation.

        ``beta = mean(x) / var(x)`` and ``alpha = mean(x) * beta``. The mean
        and variance are floored at ``1e-8`` to avoid division by zero and
        non-positive starting values.
        """
        eps: float = 1e-8
        m: jnp.ndarray = jnp.maximum(x.mean(), eps)
        v: jnp.ndarray = jnp.maximum(x.var(), eps)
        beta0: jnp.ndarray = m / v
        alpha0: jnp.ndarray = m * beta0
        return alpha0, beta0

    def _fit_mle(self, x: ArrayLike, lr: float, maxiter: int) -> dict:
        """Fit alpha and beta via projected-gradient MLE.

        Starts from the method-of-moments estimates and projects the
        ``[alpha, beta]`` iterate onto the non-negative orthant.
        """
        alpha0, beta0 = self._sample_moments(x)
        params0: jnp.ndarray = jnp.array([alpha0, beta0])
        res = projected_gradient(
            f=self._mle_objective,
            x0=params0,
            projection_method="projection_non_negative",
            x=x,
            lr=lr,
            maxiter=maxiter,
        )
        alpha, beta = res["x"]
        return self._params_dict(alpha=alpha, beta=beta)

    _supported_methods = frozenset({"mle"})

    def fit(
        self, x: ArrayLike, lr: float = 0.1, maxiter: int = 100, name: str = None
    ):
        r"""Fit the Gamma distribution to data via **numerical** MLE
        (projected gradient on the negative log-likelihood).

        Args:
            x (ArrayLike): The input data to fit the distribution to.
            lr (float): Learning rate for the fitting process.
            maxiter (int): Maximum number of iterations for the fitting process.
            name (str): Optional custom name for the fitted instance.

        Returns:
            Gamma: A fitted ``Gamma`` instance.
        """
        x: jnp.ndarray = _univariate_input(x)[0]
        return self._fitted_instance(
            self._fit_mle(x=x, lr=lr, maxiter=maxiter), name=name
        )
# Module-level default instance; alpha and beta are left unset (None).
gamma = Gamma(name="Gamma")