# Source code for copulax._src.univariate.gen_normal

"""File containing the copulAX implementation of the Generalized normal distribution."""

import jax.numpy as jnp
from jax import random, scipy
from jax.scipy import special
from jax import Array
from jax.typing import ArrayLike

from copulax._src._distributions import Univariate
from copulax._src.special import igammainv, digamma
from copulax._src.typing import Scalar
from copulax._src.univariate._utils import _univariate_input
from copulax._src._utils import _resolve_key
from copulax._src.optimize import brent
from copulax._src.univariate.gamma import gamma


class GenNormal(Univariate):
    r"""The symmetric generalized normal distribution.

    A three-parameter continuous family that generalises the normal by
    allowing heavier or lighter tails. The normal (``beta = 2``) and Laplace
    (``beta = 1``) distributions arise as special cases; as ``beta -> inf``
    the density tends to a uniform on ``[mu - alpha, mu + alpha]``.

    The PDF is

    .. math::

        f(x | \mu, \alpha, \beta) = \frac{\beta}{2 \alpha \, \Gamma(1/\beta)}
        \exp\!\left(-\left(\frac{|x - \mu|}{\alpha}\right)^\beta\right)

    where :math:`\mu \in \mathbb{R}` is the location, :math:`\alpha > 0` is
    the scale, and :math:`\beta > 0` is the shape parameter controlling tail
    weight.

    https://en.wikipedia.org/wiki/Generalized_normal_distribution
    """

    mu: Array = None
    alpha: Array = None
    beta: Array = None

    # Search interval (and clipping range) used for the shape parameter
    # when fitting.
    _beta_bounds = (0.1, 10.0)

    def __init__(self, name="GenNormal", *, mu=None, alpha=None, beta=None):
        """Initialize the Generalized Normal distribution.

        Args:
            name: Display name for the distribution.
            mu: Location parameter.
            alpha: Scale parameter.
            beta: Shape parameter (beta=2 gives the normal distribution).
        """
        super().__init__(name)

        def _as_scalar(value):
            # Unset parameters stay None; anything else is stored as a
            # 0-d float array.
            if value is None:
                return None
            return jnp.asarray(value, dtype=float).reshape(())

        self.mu = _as_scalar(mu)
        self.alpha = _as_scalar(alpha)
        self.beta = _as_scalar(beta)

    @property
    def _stored_params(self):
        """Return stored parameters if all are set, else None."""
        if self.mu is None or self.alpha is None or self.beta is None:
            return None
        return {"mu": self.mu, "alpha": self.alpha, "beta": self.beta}

    @classmethod
    def _params_dict(cls, mu: Scalar, alpha: Scalar, beta: Scalar) -> dict:
        """Create a parameter dictionary from mu, alpha, and beta values."""
        d: dict = {"mu": mu, "alpha": alpha, "beta": beta}
        return cls._args_transform(d)

    def _params_to_tuple(self, params: dict) -> tuple:
        """Extract (mu, alpha, beta) from the parameter dictionary."""
        params = self._args_transform(params)
        return params["mu"], params["alpha"], params["beta"]

    def example_params(self, *args, **kwargs) -> dict:
        """Return example parameters: ``mu=0``, ``alpha=1``, ``beta=2``."""
        return self._params_dict(mu=0.0, alpha=1.0, beta=2.0)

    @classmethod
    def _support(cls, *args, **kwargs) -> Array:
        """Return the support ``[-inf, inf]``."""
        return jnp.array([-jnp.inf, jnp.inf])

    def _stable_logpdf(self, stability: Scalar, x: ArrayLike, params: dict) -> Array:
        """Compute the numerically stabilized log-PDF of the Generalized Normal.

        Args:
            stability: Small offset added to ``beta`` inside the log and
                log-gamma terms of the normalising constant.
            x: Points at which to evaluate the log-density.
            params: Parameter dictionary with keys ``mu``, ``alpha``, ``beta``.

        Returns:
            Log-density values, reshaped to the input's original shape.
        """
        x, xshape = _univariate_input(x)
        mu, alpha, beta = self._params_to_tuple(params)

        # log of beta / (2 * alpha * Gamma(1 / beta))
        log_c: Scalar = (
            jnp.log(beta + stability)
            - jnp.log(2.0 * alpha)
            - special.gammaln(1.0 / (beta + stability))
        )
        logpdf: Array = log_c - (jnp.abs(x - mu) / (alpha)) ** beta
        return logpdf.reshape(xshape)

    def cdf(self, x: ArrayLike, params: dict = None) -> Array:
        """Evaluate the CDF via the regularized lower incomplete gamma function.

        Args:
            x: Points at which to evaluate the CDF.
            params: Parameter dictionary; passed through ``_resolve_params``.

        Returns:
            CDF values with the same shape as ``x``.
        """
        params = self._resolve_params(params)
        x, xshape = _univariate_input(x)
        mu, alpha, beta = self._params_to_tuple(params)

        z: Array = (x - mu) / alpha
        incomplete_gamma_component = special.gammainc(
            a=1.0 / beta, x=(jnp.abs(z) ** beta)
        )
        # Symmetric about mu: half the mass plus/minus the signed tail term.
        cdf: Array = 0.5 * (1.0 + jnp.sign(z) * incomplete_gamma_component)
        return self._enforce_support_on_cdf(
            x=x, cdf=cdf.reshape(xshape), params=params
        )

    def _ppf(self, q: ArrayLike, params: dict = None, *args, **kwargs) -> Array:
        """Compute the PPF via the inverse regularized incomplete gamma function."""
        params = self._resolve_params(params)
        q, qshape = _univariate_input(q)
        mu, alpha, beta = self._params_to_tuple(params)

        # Map q in [0, 1] to z in [-1, 1]; the sign picks the side of mu.
        z = 2.0 * q - 1.0
        x = mu + jnp.sign(z) * alpha * jnp.power(
            igammainv(a=1.0 / beta, p=jnp.abs(z)), 1.0 / beta
        )
        return x.reshape(qshape)

    # sampling
    def rvs(
        self, size: tuple | Scalar, params: dict = None, key: Array = None
    ) -> Array:
        """Draw random samples from the distribution.

        Samples are built as ``mu + alpha * sign * G ** (1 / beta)`` where
        ``G`` is drawn from ``gamma`` with ``alpha = 1 / beta``, ``beta = 1``
        and ``sign`` is an independent fair ``+/-1`` draw.

        Args:
            size: Requested output size, forwarded to ``gamma.rvs``.
            params: Parameter dictionary; passed through ``_resolve_params``.
            key: PRNG key; passed through ``_resolve_key``.

        Returns:
            Array of samples.
        """
        params = self._resolve_params(params)
        key = _resolve_key(key)
        mu, alpha, beta = self._params_to_tuple(params)

        gamma_key, sign_key = random.split(key)
        G = gamma.rvs(
            size=size, key=gamma_key, params={"alpha": 1.0 / beta, "beta": 1.0}
        )
        # Use the shape of the gamma draws instead of the raw `size`:
        # `random.bernoulli` expects a shape sequence, while `size` may be a
        # scalar. This also guarantees the signs align with `G`.
        sign = (
            2.0 * random.bernoulli(sign_key, 0.5, shape=jnp.shape(G)).astype(float)
            - 1.0
        )
        return mu + alpha * sign * jnp.power(G, 1.0 / beta)

    # stats
    def stats(self, params: dict = None) -> dict:
        """Return summary statistics of the distribution.

        Variance and kurtosis are ratios of gamma functions. They are
        evaluated in log-space (``gammaln``) so that small ``beta`` values,
        for which the individual gamma terms overflow, still give finite
        results.

        Args:
            params: Parameter dictionary; passed through ``_resolve_params``.

        Returns:
            Dictionary with keys ``mean``, ``median``, ``mode``, ``variance``,
            ``skewness`` and ``kurtosis`` (excess kurtosis, i.e. minus 3).
        """
        params = self._resolve_params(params)
        mu, alpha, beta = self._params_to_tuple(params)

        lg1 = special.gammaln(1.0 / beta)
        lg3 = special.gammaln(3.0 / beta)
        lg5 = special.gammaln(5.0 / beta)

        # alpha^2 * Gamma(3/beta) / Gamma(1/beta)
        variance = alpha**2 * jnp.exp(lg3 - lg1)
        # Gamma(5/beta) * Gamma(1/beta) / Gamma(3/beta)^2 - 3
        kurtosis = jnp.exp(lg5 + lg1 - 2.0 * lg3) - 3.0

        return {
            "mean": mu,
            "median": mu,
            "mode": mu,
            "variance": variance,
            "skewness": jnp.float32(0.0),
            "kurtosis": kurtosis,
        }

    # fitting
    @staticmethod
    def _sample_moments(x: jnp.ndarray) -> Scalar:
        r"""Sample-median initial estimate for mu (robust under symmetry;
        preferred over sample mean for heavy-tailed / small-beta regimes)."""
        return jnp.median(x)

    @staticmethod
    def _mle_score(beta: Scalar, x: jnp.ndarray, mu: Scalar) -> Scalar:
        r"""Score function g(beta) whose root is the MLE of beta.

        From Wikipedia (Generalized normal distribution, Version 1):

        .. math::

            g(\beta) = 1 + \frac{\psi(1/\beta)}{\beta}
            - \frac{\sum |x_i - \mu|^\beta \log|x_i - \mu|}
                   {\sum |x_i - \mu|^\beta}
            + \frac{\log\!\bigl(\frac{\beta}{N}\sum |x_i-\mu|^\beta\bigr)}
                   {\beta}

        where psi is the digamma function.

        Args:
            beta: Shape parameter (scalar, > 0).
            x: Data array.
            mu: Location parameter (fixed).

        Returns:
            Scalar value of g(beta).
        """
        n = x.shape[0]
        abs_dev = jnp.abs(x - mu) + 1e-30  # avoid log(0)
        log_abs_dev = jnp.log(abs_dev)
        abs_dev_beta = abs_dev ** beta

        sum_abs_dev_beta = jnp.sum(abs_dev_beta)
        sum_weighted_log = jnp.sum(abs_dev_beta * log_abs_dev)

        inv_beta = 1.0 / beta
        psi_val = digamma(jnp.atleast_1d(inv_beta))[0]

        term1 = 1.0 + psi_val * inv_beta
        term2 = sum_weighted_log / sum_abs_dev_beta
        term3 = jnp.log(beta / n * sum_abs_dev_beta) * inv_beta
        return term1 - term2 + term3

    @staticmethod
    def _mu_score(mu: Scalar, x: jnp.ndarray, beta: Scalar) -> Scalar:
        r"""Derivative of :math:`\sum |x_i - \mu|^\beta` w.r.t. :math:`\mu`.

        .. math::

            \frac{d}{d\mu}\sum|x_i-\mu|^\beta
            = -\beta \sum |x_i-\mu|^{\beta-1}\,\mathrm{sign}(x_i-\mu)

        The root of this function is the MLE of mu given beta.

        Args:
            mu: Location parameter (scalar).
            x: Data array.
            beta: Shape parameter (fixed).

        Returns:
            Scalar derivative value.
        """
        diff = x - mu
        # Small offset keeps the power finite when beta < 1 and diff == 0;
        # the sign factor zeroes those terms anyway.
        abs_diff = jnp.abs(diff) + 1e-30
        return -beta * jnp.sum(abs_diff ** (beta - 1.0) * jnp.sign(diff))

    def _solve_beta(self, x: jnp.ndarray, mu: Scalar) -> Scalar:
        """Solve ``g(beta) = 0`` via Brent with mu fixed, clipped to the bounds."""
        lower, upper = self._beta_bounds
        beta = brent(
            g=self._mle_score,
            bounds=jnp.array([lower, upper]),
            maxiter=30,
            x=x,
            mu=mu,
        )
        return jnp.clip(beta, lower, upper)

    @staticmethod
    def _alpha_given(x: jnp.ndarray, mu: Scalar, beta: Scalar) -> Scalar:
        """Return ``alpha = (beta / N * sum|x_i - mu|^beta)^(1 / beta)``."""
        n = x.shape[0]
        return jnp.power(beta / n * jnp.sum(jnp.abs(x - mu) ** beta), 1.0 / beta)

    def _fit_mle(self, x: jnp.ndarray) -> dict:
        r"""Fit via Wikipedia's MLE algorithm using Brent's method.

        Algorithm (single pass):

        1. mu_0 = median(x)
        2. Solve g(beta) = 0 for beta via Brent (with mu fixed at mu_0)
        3. Solve d/dmu sum|x_i - mu|^beta = 0 for mu via Brent (with beta
           fixed)
        4. alpha = (beta/N * sum|x_i - mu|^beta)^(1/beta)

        Reference: https://en.wikipedia.org/wiki/Generalized_normal_distribution

        Args:
            x: Data array.

        Returns:
            Parameter dictionary with the fitted ``mu``, ``alpha``, ``beta``.
        """
        # Step 1: initial location estimate
        mu = self._sample_moments(x)

        # Step 2: solve g(beta) = 0 for beta with mu fixed
        beta = self._solve_beta(x, mu)

        # Step 3: solve d/dmu sum|x_i - mu|^beta = 0 for mu with beta fixed
        mu = brent(
            g=self._mu_score,
            bounds=jnp.array([jnp.min(x), jnp.max(x)]),
            maxiter=30,
            x=x,
            beta=beta,
        )

        # Step 4: derive alpha analytically
        alpha = self._alpha_given(x, mu, beta)
        return self._params_dict(mu=mu, alpha=alpha, beta=beta)

    def _fit_mom(self, x: jnp.ndarray) -> dict:
        """Fit via method of moments (no MLE refinement).

        Uses the sample median as mu, solves the MLE score equation for
        beta via Brent, then derives alpha analytically.

        Args:
            x: Data array.

        Returns:
            Parameter dictionary with MoM estimates.
        """
        mu = self._sample_moments(x)
        beta = self._solve_beta(x, mu)
        alpha = self._alpha_given(x, mu, beta)
        return self._params_dict(mu=mu, alpha=alpha, beta=beta)

    _supported_methods = frozenset({"mle", "mom"})

    def fit(self, x: ArrayLike, method: str = "mle", name: str = None):
        r"""Fit the distribution to data.

        Note:
            If you intend to jit wrap this function, ensure that ``method``
            is a static argument.

        Args:
            x: Input data to fit.
            method: Fitting method. One of:
                ``'mle'`` — MLE algorithm using Brent's method
                (derivative-free numerical root-finding; default);
                ``'mom'`` — sample median for mu, Brent root-finding for
                beta, analytic alpha (faster, no μ refinement step).
            name: Optional custom name for the fitted instance.

        Returns:
            GenNormal: A fitted ``GenNormal`` instance.

        Raises:
            ValueError: If ``method`` is not one of the accepted strings
                listed above.
        """
        self._check_method(method)
        data: jnp.ndarray = _univariate_input(x)[0]

        if method == "mle":
            fitted_params = self._fit_mle(data)
        elif method == "mom":
            fitted_params = self._fit_mom(data)
        else:
            raise ValueError(
                f"Unknown Gen-Normal fit method {method!r}. "
                f"Expected one of: {sorted(self._supported_methods)}."
            )
        return self._fitted_instance(fitted_params, name=name)
# Module-level GenNormal instance, created with the display name
# "Gen-Normal" and no stored parameters.
gen_normal = GenNormal("Gen-Normal")