# Source code for copulax._src.univariate.asym_gen_normal

"""File containing the copulAX implementation of the Asymmetric Generalized Normal distribution."""

import jax.numpy as jnp
from jax import Array
from jax.typing import ArrayLike

from copulax._src._distributions import Univariate
from copulax._src.typing import Scalar
from copulax._src.univariate._utils import _univariate_input
from copulax._src.optimize import projected_gradient, brent
from copulax._src.univariate.normal import normal
from copulax._src.stats import skew, kurtosis as sample_kurtosis


class AsymGenNormal(Univariate):
    r"""The asymmetric generalized normal distribution.

    A three-parameter continuous family that adds skewness to the normal via
    a log-link transformation of a standard normal variate
    :math:`X = \zeta + \alpha (1 - e^{-\kappa Z}) / \kappa`. The normal
    (``kappa = 0``) arises as the limiting symmetric case. The PDF is

    .. math::

        f(x | \zeta, \alpha, \kappa) = \frac{\phi(y)}{\alpha - \kappa (x - \zeta)},
        \qquad
        y = \begin{cases}
            (x - \zeta) / \alpha, & \kappa = 0 \\
            -\,\kappa^{-1} \log\!\left(1 - \kappa (x - \zeta) / \alpha\right),
                & \kappa \ne 0
        \end{cases}

    where :math:`\phi` is the standard-normal PDF, :math:`\zeta \in \mathbb{R}`
    is the location, :math:`\alpha > 0` is the scale, and
    :math:`\kappa \in \mathbb{R}` is the shape parameter controlling skewness
    (positive :math:`\kappa` skews left, negative right). The support is
    right-bounded when :math:`\kappa > 0` and left-bounded when
    :math:`\kappa < 0`.

    Implementation note: the ``kappa == 0`` branches below divide by
    :meth:`_kappa_or_one` instead of ``kappa`` so the unselected branch of
    each ``jnp.where`` stays finite and cannot inject ``nan`` into gradients.

    https://en.wikipedia.org/wiki/Generalized_normal_distribution
    """

    zeta: Array = None
    alpha: Array = None
    kappa: Array = None

    def __init__(self, name="AsymGenNormal", *, zeta=None, alpha=None, kappa=None):
        """Initialize the Asymmetric Generalized Normal distribution.

        Args:
            name: Display name for the distribution.
            zeta: Location parameter.
            alpha: Scale parameter.
            kappa: Shape parameter controlling skewness.
        """
        super().__init__(name)
        # Each supplied parameter is stored as a 0-d float array; omitted
        # parameters stay None (see ``_stored_params``).
        self.zeta = (
            jnp.asarray(zeta, dtype=float).reshape(()) if zeta is not None else None
        )
        self.alpha = (
            jnp.asarray(alpha, dtype=float).reshape(()) if alpha is not None else None
        )
        self.kappa = (
            jnp.asarray(kappa, dtype=float).reshape(()) if kappa is not None else None
        )

    @property
    def _stored_params(self):
        """Return stored parameters if all are set, else None."""
        if self.zeta is None or self.alpha is None or self.kappa is None:
            return None
        return {"zeta": self.zeta, "alpha": self.alpha, "kappa": self.kappa}

    @classmethod
    def _params_dict(cls, zeta: Scalar, alpha: Scalar, kappa: Scalar) -> dict:
        """Create a parameter dictionary from zeta, alpha, and kappa values."""
        d: dict = {"zeta": zeta, "alpha": alpha, "kappa": kappa}
        return cls._args_transform(d)

    @classmethod
    def _params_to_tuple(cls, params: dict) -> tuple:
        """Extract (zeta, alpha, kappa) from the parameter dictionary."""
        params = cls._args_transform(params)
        return params["zeta"], params["alpha"], params["kappa"]

    @staticmethod
    def _kappa_or_one(kappa: Array) -> Array:
        """Return ``kappa`` with exact zeros replaced by ``1.0``.

        ``jnp.where`` evaluates both of its branches, so a ``1 / kappa`` in
        the ``kappa != 0`` branch yields ``inf``/``nan`` when ``kappa == 0``.
        That value never reaches the forward result but does poison
        gradients; dividing by this substitute keeps the branch finite.
        For any non-zero ``kappa`` the returned value equals ``kappa``.
        """
        return jnp.where(kappa == 0, 1.0, kappa)

    def example_params(self, *args, **kwargs) -> dict:
        """Return example parameters ``zeta=0, alpha=1, kappa=-0.5``.

        A negative ``kappa`` gives a right-skewed, left-bounded distribution.
        """
        return self._params_dict(zeta=0.0, alpha=1.0, kappa=-0.5)

    @classmethod
    def _support(cls, params: dict) -> Array:
        """Return the support, which depends on kappa.

        When ``kappa < 0`` the support is ``[zeta + alpha/kappa, inf)``; when
        ``kappa > 0`` it is ``(-inf, zeta + alpha/kappa]``; when
        ``kappa == 0`` it is ``(-inf, inf)``.
        """
        zeta, alpha, kappa = cls._params_to_tuple(params)
        bound = jnp.where(
            kappa == 0, jnp.inf, zeta + alpha / cls._kappa_or_one(kappa)
        )
        support = jnp.where(
            kappa < 0, jnp.array([bound, jnp.inf]), jnp.array([-jnp.inf, bound])
        )
        return support

    def _stable_logpdf(self, stability: Scalar, x: ArrayLike, params: dict) -> Array:
        """Compute the numerically stabilized log-PDF of the Asymmetric Generalized Normal.

        Args:
            stability: Small offset added to the ``alpha`` and ``kappa``
                denominators.
            x: Evaluation points.
            params: Parameter dictionary with ``zeta``, ``alpha``, ``kappa``.

        Returns:
            log f(x) with the same shape as ``x``; ``-inf`` at and beyond a
            finite support boundary.
        """
        x, xshape = _univariate_input(x)
        zeta, alpha, kappa = self._params_to_tuple(params)
        kappa_nz = self._kappa_or_one(kappa)

        z = (x - zeta) / (alpha + stability)
        # 1 - kappa*z is 0 exactly at the support boundary and negative beyond.
        inside = (1.0 - kappa * z) > 0

        # Substitute x with zeta (always strictly inside the support) at and
        # beyond the boundary so log1p / log arguments stay valid. The
        # masked logpdf is restored to -inf at the end, keeping both
        # branches finite for clean autograd through jnp.where.
        safe_x = jnp.where(inside, x, zeta)
        z_safe = (safe_x - zeta) / (alpha + stability)
        y = jnp.where(
            kappa == 0,
            z_safe,
            (-1.0 / (kappa_nz + stability)) * jnp.log1p(-kappa * z_safe),
        )
        # log-Jacobian of x -> y: dy/dx = 1 / (alpha - kappa (x - zeta)).
        raw = normal.logpdf(y, params={"mu": 0.0, "sigma": 1.0}) - jnp.log(
            alpha - kappa * (safe_x - zeta)
        )
        log_pdf = jnp.where(inside, raw, -jnp.inf)
        return log_pdf.reshape(xshape)

    def cdf(self, x: ArrayLike, params: dict = None) -> Array:
        """Compute the CDF via transformation to the standard normal.

        ``F(x) = Phi(y)`` with ``y`` as in the class docstring. At and beyond
        a finite support boundary (``1 - kappa (x - zeta) / alpha <= 0``),
        where ``log1p`` would produce ``-inf``/``nan``, the limiting value is
        returned instead: ``1`` past the upper bound (``kappa > 0``) and ``0``
        below the lower bound (``kappa < 0``).

        Args:
            x: Evaluation points.
            params: Optional parameter dictionary; stored parameters are used
                when omitted.

        Returns:
            CDF values with the same shape as ``x``.
        """
        params = self._resolve_params(params)
        x, xshape = _univariate_input(x)
        zeta, alpha, kappa = self._params_to_tuple(params)
        kappa_nz = self._kappa_or_one(kappa)

        z = (x - zeta) / alpha
        inside = (1.0 - kappa * z) > 0
        # Evaluate out-of-support points at z = 0 to keep log1p finite; they
        # are overwritten with the limiting CDF value below.
        z_safe = jnp.where(inside, z, 0.0)
        y = jnp.where(
            kappa == 0, z_safe, (-1.0 / kappa_nz) * jnp.log1p(-kappa * z_safe)
        )
        cdf = normal.cdf(y, params={"mu": 0.0, "sigma": 1.0}).reshape(xshape)
        beyond_value = jnp.where(kappa > 0, 1.0, 0.0)
        cdf = jnp.where(inside.reshape(xshape), cdf, beyond_value)
        return self._enforce_support_on_cdf(x=x, cdf=cdf, params=params)

    # sampling
    def rvs(
        self, size: tuple | Scalar, params: dict = None, key: Array = None
    ) -> Array:
        """Generate random variates via transformation of standard normals.

        Uses ``X = zeta + alpha (1 - exp(-kappa Z)) / kappa`` with ``Z`` standard
        normal, written as ``-expm1(-kappa Z)`` to stay accurate for small
        ``|kappa|``; ``X = zeta + alpha Z`` when ``kappa == 0``.
        """
        params = self._resolve_params(params)
        zeta, alpha, kappa = self._params_to_tuple(params)
        kappa_nz = self._kappa_or_one(kappa)
        Z = normal.rvs(size=size, key=key, params={"mu": 0.0, "sigma": 1.0})
        X = jnp.where(
            kappa == 0,
            zeta + alpha * Z,
            zeta - alpha * jnp.expm1(-kappa * Z) / kappa_nz,
        )
        return X

    # stats
    def stats(self, params: dict = None) -> dict:
        r"""Compute distribution statistics (mean, median, mode, variance, skewness, kurtosis).

        Since :math:`X = \zeta + (\alpha/\kappa)(1 - W)` with
        :math:`W = e^{-\kappa Z}` lognormal, all moments follow from lognormal
        moments (with :math:`u = e^{\kappa^2}`):

        - mean :math:`\zeta - (\alpha/\kappa)(e^{\kappa^2/2} - 1)`
        - median :math:`\zeta`
        - mode :math:`\zeta + (\alpha/\kappa)(1 - e^{-\kappa^2})`
        - variance :math:`(\alpha/\kappa)^2 u (u - 1)`
        - skewness :math:`-\mathrm{sign}(\kappa)(u + 2)\sqrt{u - 1}`
        - excess kurtosis :math:`u^4 + 2u^3 + 3u^2 - 6`

        Each reduces to the normal value at ``kappa == 0``. ``expm1`` is used
        for every ``e^{(.)} - 1`` term to avoid cancellation for small
        ``|kappa|`` in single precision.
        """
        params = self._resolve_params(params)
        zeta, alpha, kappa = self._params_to_tuple(params)
        kappa_nz = self._kappa_or_one(kappa)
        k2 = kappa**2
        k2_nz = kappa_nz**2
        scale = alpha / kappa_nz
        is_normal = kappa == 0

        mean = jnp.where(is_normal, zeta, zeta - scale * jnp.expm1(0.5 * k2))
        # Mode solves y = kappa, i.e. 1 - kappa (x - zeta)/alpha = exp(-kappa^2).
        mode = jnp.where(is_normal, zeta, zeta - scale * jnp.expm1(-k2))
        variance = jnp.where(
            is_normal, alpha**2, scale**2 * jnp.exp(k2) * jnp.expm1(k2)
        )
        # Factored form of sign(k)(3u - u^3 - 2)/(u - 1)^1.5; avoids the
        # catastrophic cancellation in the numerator for small |kappa|.
        skewness = jnp.where(
            is_normal,
            0.0,
            -jnp.sign(kappa_nz) * (jnp.exp(k2_nz) + 2.0) * jnp.sqrt(jnp.expm1(k2_nz)),
        )
        # u^4 + 2u^3 + 3u^2 - 6 rewritten via expm1 (the constants sum to 0).
        kurtosis = (
            jnp.expm1(4 * k2) + 2 * jnp.expm1(3 * k2) + 3 * jnp.expm1(2 * k2)
        )
        return {
            "mean": mean,
            "median": zeta,
            "mode": mode,
            "variance": variance,
            "skewness": skewness,
            "kurtosis": kurtosis,
        }

    # fitting
    @staticmethod
    def _kurtosis_score(kappa_abs: Scalar, sample_kurt: Scalar) -> Scalar:
        r"""Residual of the excess kurtosis equation for ``|kappa|``.

        The excess kurtosis of the AsymGenNormal is purely a function of
        ``kappa^2``:

        .. math::

            \kappa_4(\kappa) = e^{4\kappa^2} + 2e^{3\kappa^2} + 3e^{2\kappa^2} - 6

        This is monotonically increasing in ``|kappa|``, so Brent's method
        can find the unique root on ``[0, 2]``. It is evaluated via
        ``expm1`` for accuracy near ``kappa = 0``.

        Args:
            kappa_abs: Absolute value of the shape parameter (scalar, >= 0).
            sample_kurt: Sample excess kurtosis to match.

        Returns:
            Residual: theoretical kurtosis - sample kurtosis.
        """
        k2 = kappa_abs**2
        theoretical = (
            jnp.expm1(4 * k2) + 2 * jnp.expm1(3 * k2) + 3 * jnp.expm1(2 * k2)
        )
        return theoretical - sample_kurt

    @staticmethod
    def _sample_moments(x: jnp.ndarray) -> dict:
        r"""Method-of-moments estimates for (zeta, alpha, kappa).

        zeta = median(x); ``|kappa|`` from sample excess kurtosis via Brent
        inversion on ``[0, 2]`` (kurtosis is symmetric in kappa and monotone
        in ``|kappa|``); sign of kappa from sample skew; alpha from sample
        variance and
        ``Var(X) = (alpha/kappa)^2 * exp(kappa^2) * (exp(kappa^2) - 1)``.
        If kappa has to be shrunk so that all data lie inside the support,
        alpha is re-derived from the variance formula for the final kappa.
        """
        sample_std = jnp.std(x)
        sample_kurt = sample_kurtosis(x, fisher=True, bias=True)
        sample_skew = skew(x, bias=True)

        def alpha_for(k):
            """Scale alpha that reproduces the sample variance for shape k."""
            k2 = k**2
            is_zero = k2 == 0
            k2_nz = jnp.where(is_zero, 1.0, k2)
            # var_scale -> 1 as k -> 0 (normal limit); guarded for k == 0.
            var_scale = jnp.where(
                is_zero, 1.0, k2_nz / (jnp.exp(k2_nz) * jnp.expm1(k2_nz))
            )
            return jnp.sqrt(var_scale) * sample_std

        # Clip kurtosis to valid range [0, kurtosis(2)]
        # kurtosis(0) = 0, kurtosis(2) ≈ 9.2M
        sample_kurt = jnp.clip(sample_kurt, 0.01, 9e6)

        # Solve for |kappa| via Brent
        kappa_abs = brent(
            g=AsymGenNormal._kurtosis_score,
            bounds=jnp.array([0.0, 2.0]),
            maxiter=30,
            sample_kurt=sample_kurt,
        )
        kappa_abs = jnp.clip(kappa_abs, 0.01, 2.0)

        # Sign: negative skew => kappa > 0, positive skew => kappa < 0
        kappa = jnp.where(sample_skew < 0, kappa_abs, -kappa_abs)

        # zeta = median
        zeta = jnp.median(x)

        # alpha from variance
        alpha = alpha_for(kappa)

        # Support safety: ensure all data is within the implied support.
        # For kappa < 0: support is [zeta + alpha/kappa, inf).
        #   Need: min(x) > zeta + alpha/kappa  =>  |kappa| < alpha / (zeta - min(x))
        # For kappa > 0: support is (-inf, zeta + alpha/kappa].
        #   Need: max(x) < zeta + alpha/kappa  =>  kappa < alpha / (max(x) - zeta)
        # Scale |kappa| down with 0.95 safety margin if it violates.
        margin = 0.95
        kappa_max_neg = margin * alpha / (zeta - jnp.min(x) + 1e-30)
        kappa_max_pos = margin * alpha / (jnp.max(x) - zeta + 1e-30)
        kappa = jnp.where(
            kappa < 0,
            jnp.maximum(kappa, -kappa_max_neg),  # clamp toward 0
            jnp.minimum(kappa, kappa_max_pos),  # clamp toward 0
        )

        # Re-match the sample variance for the (possibly shrunk) kappa. The
        # variance scale is decreasing in |kappa|, so this can only increase
        # alpha, which moves the support bound further out: the data stay
        # strictly inside the support.
        alpha = alpha_for(kappa)

        return AsymGenNormal._params_dict(zeta=zeta, alpha=alpha, kappa=kappa)

    def _fit_mom(self, x: jnp.ndarray) -> dict:
        """Fit via method of moments (no MLE refinement).

        Returns parameter estimates derived purely from sample moments:
        kurtosis → ``|kappa|`` via Brent, sign from skewness, zeta from
        median, alpha from the variance formula.

        Args:
            x: Data array.

        Returns:
            Parameter dictionary with MoM estimates.
        """
        return self._sample_moments(x)

    def _fit_mle(self, x: jnp.ndarray, lr: float, maxiter: int) -> dict:
        """Fit via projected gradient MLE, initialized from method of moments.

        Uses MoM estimates (kurtosis inversion for kappa, median for zeta,
        variance formula for alpha) as starting point, then refines all three
        parameters via projected gradient descent on the negative
        log-likelihood.

        Args:
            x: Data array.
            lr: Learning rate for optimization.
            maxiter: Maximum number of iterations.

        Returns:
            Parameter dictionary with MLE estimates.
        """
        eps: float = 1e-8
        # Box constraints on (zeta, alpha, kappa): only alpha is bounded
        # (alpha >= eps).
        constraints: tuple = (
            jnp.array([[-jnp.inf, eps, -jnp.inf]]).T,
            jnp.array([[jnp.inf, jnp.inf, jnp.inf]]).T,
        )
        projection_options: dict = {"lower": constraints[0], "upper": constraints[1]}

        # MoM initialization
        mom_params = self._sample_moments(x)
        zeta0, alpha0, kappa0 = self._params_to_tuple(mom_params)
        params0: jnp.ndarray = jnp.array([zeta0, alpha0, kappa0])

        res: dict = projected_gradient(
            f=self._mle_objective,
            x0=params0,
            projection_method="projection_box",
            projection_options=projection_options,
            x=x,
            lr=lr,
            maxiter=maxiter,
        )
        zeta, alpha, kappa = res["x"]
        return self._params_dict(zeta=zeta, alpha=alpha, kappa=kappa)

    _supported_methods = frozenset({"mle", "mom"})

    def fit(
        self,
        x: ArrayLike,
        method: str = "mle",
        lr: float = 0.1,
        maxiter: int = 100,
        name: str = None,
    ):
        r"""Fit the distribution to data.

        Note:
            If you intend to jit wrap this function, ensure that ``method``
            is a static argument.

        Args:
            x: Input data to fit.
            method: Fitting method. One of:
                ``'mle'`` — projected-gradient maximum likelihood with MoM
                initialisation (numerical; default);
                ``'mom'`` — **closed-form** method of moments (faster, no
                gradient refinement).
            lr: Learning rate for optimisation (MLE only; ignored for
                ``'mom'``). Default ``0.1``.
            maxiter: Maximum number of iterations (MLE only; ignored for
                ``'mom'``).
            name: Optional custom name for the fitted instance.

        Returns:
            AsymGenNormal: A fitted ``AsymGenNormal`` instance.

        Raises:
            ValueError: If ``method`` is not one of the accepted strings
                listed above.
        """
        self._check_method(method)
        x: jnp.ndarray = _univariate_input(x)[0]
        if method == "mle":
            return self._fitted_instance(self._fit_mle(x, lr, maxiter), name=name)
        elif method == "mom":
            return self._fitted_instance(self._fit_mom(x), name=name)
        else:
            raise ValueError(
                f"Unknown Asym-Gen-Normal fit method {method!r}. "
                f"Expected one of: {sorted(self._supported_methods)}."
            )
# Module-level instance created without zeta/alpha/kappa, so all three
# stored parameters are None until it is fitted or given params explicitly.
asym_gen_normal = AsymGenNormal(name="Asym-Gen-Normal")