# Source code for copulax._src.univariate.wald

"""File containing the copulAX implementation of the Wald/Inverse Gaussian distribution."""

import jax.numpy as jnp
from jax import random
from jax import Array
from jax.typing import ArrayLike
from jax.scipy import special

from copulax._src._distributions import Univariate
from copulax._src.typing import Scalar
from copulax._src.univariate._utils import _univariate_input
from copulax._src._utils import _resolve_key
from copulax._src.univariate.normal import normal


class Wald(Univariate):
    r"""The Wald distribution, also known as the Inverse Gaussian distribution,
    is a continuous 2 parameter family.

    The Wald distribution is defined as:

    .. math::

        f(x|\mu, \lambda) = \sqrt{\frac{\lambda}{2\pi x^3}}
        \exp\left(-\frac{\lambda(x-\mu)^2}{2\mu^2 x}\right)

    where :math:`\mu` is the mean and :math:`\lambda` the shape parameter of
    the distribution.

    https://en.wikipedia.org/wiki/Inverse_Gaussian_distribution
    """

    mu: Array = None
    lamb: Array = None

    def __init__(self, name="Wald", *, mu=None, lamb=None):
        """Initialize the Wald distribution.

        Args:
            name: Display name for the distribution.
            mu: Location parameter (mean). If provided, stored on the
                instance as a 0-d float array.
            lamb: Shape parameter. If provided, stored on the instance as a
                0-d float array.
        """
        super().__init__(name=name)
        # reshape(()) forces a scalar; a non-scalar input raises here.
        self.mu = (
            jnp.asarray(mu, dtype=float).reshape(()) if mu is not None else None
        )
        self.lamb = (
            jnp.asarray(lamb, dtype=float).reshape(()) if lamb is not None else None
        )

    @property
    def _stored_params(self):
        """Return stored parameters if all are set, else None."""
        if self.mu is None or self.lamb is None:
            return None
        return self._params_dict(self.mu, self.lamb)

    @classmethod
    def _params_dict(cls, mu: Scalar, lamb: Scalar) -> dict:
        """Create a parameter dictionary from mu and lamb values.

        The raw ``{"mu": ..., "lamb": ...}`` mapping is passed through
        ``cls._args_transform`` before being returned.
        """
        d: dict = {"mu": mu, "lamb": lamb}
        return cls._args_transform(d)

    def _params_to_tuple(self, params: dict) -> tuple:
        """Extract ``(mu, lamb)`` from the parameter dictionary."""
        params = self._args_transform(params)
        return params["mu"], params["lamb"]

    def example_params(self, *args, **kwargs) -> dict:
        """Return example parameters for the Wald / Inverse Gaussian distribution.

        Positional and keyword arguments are accepted and ignored.
        """
        return self._params_dict(mu=1.0, lamb=1.0)

    @classmethod
    def _support(cls, *args, **kwargs) -> Array:
        """Return the support ``[0, inf]`` of the Wald distribution."""
        return jnp.array([0.0, jnp.inf])

    def logpdf(self, x: ArrayLike, params: dict = None) -> Array:
        """Compute the log probability density function of the Wald distribution.

        Args:
            x: Input values at which to evaluate the log-PDF.
            params: Dictionary containing the parameters of the distribution.
                Uses stored parameters if None.

        Returns:
            Log-PDF values with the same shape as ``x``.
        """
        params = self._resolve_params(params)
        x, xshape = _univariate_input(x)
        mu, lamb = self._params_to_tuple(params)

        # log f(x) = 0.5*log(lamb) - 1.5*log(x) - 0.5*log(2*pi)
        #            - lamb * (x - mu)^2 / (2 * mu^2 * x)
        diff = x - mu
        log_exponent = -0.5 * lamb * diff ** 2 / (mu**2 * x)
        log_prefactor = (
            0.5 * jnp.log(lamb) - 1.5 * jnp.log(x) - 0.5 * jnp.log(2 * jnp.pi)
        )
        log_pdf = log_prefactor + log_exponent

        # Values outside the support are handled by the base-class helper.
        return self._enforce_support_on_logpdf(
            x=x, logpdf=log_pdf.reshape(xshape), params=params
        )

    def cdf(self, x: ArrayLike, params: dict = None) -> Array:
        """Compute the cumulative distribution function of the Wald / Inverse
        Gaussian distribution.

        Args:
            x: Input values at which to evaluate the CDF.
            params: Dictionary containing the parameters of the distribution.
                Uses stored parameters if None.

        Returns:
            CDF values with the same shape as ``x``.
        """
        params = self._resolve_params(params)
        x, xshape = _univariate_input(x)
        mu, lamb = self._params_to_tuple(params)

        sqrt_lamb_over_x = jnp.sqrt(lamb / x)
        x_over_mu = x / mu
        z1 = sqrt_lamb_over_x * (x_over_mu - 1)
        z2 = -sqrt_lamb_over_x * (x_over_mu + 1)

        # Stable form: exp(2λ/μ)·Φ(z2) = exp(2λ/μ + log_ndtr(z2)).
        # z2 < 0 for x > 0, so log_ndtr(z2) → -∞ as 2λ/μ → ∞, keeping the
        # sum bounded.
        mirror_term = jnp.exp(2 * lamb / mu + special.log_ndtr(z2))
        cdf = special.ndtr(z1) + mirror_term

        return self._enforce_support_on_cdf(
            x=x, cdf=cdf.reshape(xshape), params=params
        )

    # sampling
    def rvs(self, size: tuple | Scalar, params: dict = None, key: Array = None) -> Array:
        """Generate random variates from the Wald distribution via
        Michael-Schucany-Haas.

        Args:
            size: Shape of the requested sample; forwarded to the standard
                normal sampler.
            params: Dictionary containing the parameters of the distribution.
                Uses stored parameters if None.
            key: PRNG key. It is passed through ``_resolve_key`` before use.

        Returns:
            Array of samples with the shape produced by the normal sampler
            for ``size``.
        """
        params = self._resolve_params(params)
        key = _resolve_key(key)
        mu, lamb = self._params_to_tuple(params)

        key1, key2 = random.split(key)

        # Step 1: Sample y = z^2, z ~ N(0, 1)
        z = normal.rvs(size=size, params={"mu": 0.0, "sigma": 1.0}, key=key1)
        y = z ** 2

        # Step 2: Compute the smaller root of the MSH quadratic.
        # With w = mu*y/(2*lamb) the root is mu*(1 + w - sqrt(w^2 + 2w)).
        # That difference cancels catastrophically for large w (it can round
        # to 0 or below, making mu^2/x_candidate inf or negative). Since
        # (1 + w)^2 - (w^2 + 2w) = 1, the root equals the rationalised form
        # below, which is always in (0, mu] and has no cancellation.
        w = mu * y / (2 * lamb)
        x_candidate = mu / (1 + w + jnp.sqrt(w * (w + 2)))

        # Step 3: Accept x_candidate with prob mu/(mu + x_candidate),
        # else return the other root mu^2/x_candidate.
        u = random.uniform(key2, shape=z.shape)
        return jnp.where(
            u <= mu / (mu + x_candidate), x_candidate, (mu**2) / x_candidate
        )

    # stats
    def stats(self, params: dict = None) -> dict:
        """Compute summary statistics of the Wald distribution.

        Args:
            params: Dictionary containing the parameters of the distribution.
                Uses stored parameters if None.

        Returns:
            Dictionary with keys ``"mean"``, ``"variance"``, ``"mode"``,
            ``"skewness"`` and ``"kurtosis"``, passed through
            ``_scalar_transform``.
        """
        params = self._resolve_params(params)
        mu, lamb = self._params_to_tuple(params)

        mean = mu
        mode = mu * (
            jnp.sqrt(1 + (9 * mu**2) / (4 * lamb**2)) - (3 * mu) / (2 * lamb)
        )
        variance = (mu**3) / lamb
        skewness = 3 * jnp.sqrt(mu / lamb)
        # 15*mu/lamb is the excess kurtosis of the inverse Gaussian.
        kurtosis = 15 * mu / lamb

        return self._scalar_transform(
            {
                "mean": mean,
                "variance": variance,
                "mode": mode,
                "skewness": skewness,
                "kurtosis": kurtosis,
            }
        )

    # fitting
    _supported_methods = frozenset({"mle"})

    def fit(self, x: ArrayLike, *args, name: str = None, **kwargs) -> "Wald":
        r"""Fit the Wald distribution to data via **closed-form** MLE:
        ``μ̂ = mean(x)``, ``λ̂ = 1 / (mean(1/x) − 1/mean(x))``.

        The closed-form estimator takes no tuning parameters; extra
        positional and keyword arguments are accepted and ignored.

        Args:
            x: Input data to fit.
            name: Optional custom name for the fitted distribution instance.

        Returns:
            Wald: A fitted ``Wald`` instance.
        """
        x = _univariate_input(x)[0]

        mean_x = x.mean()
        inv_mean_x = (1 / x).mean()

        mu: jnp.ndarray = mean_x
        lamb: jnp.ndarray = 1 / (inv_mean_x - (1 / mean_x))

        return self._fitted_instance(self._params_dict(mu=mu, lamb=lamb), name=name)
# Module-level Wald instance created without stored parameters.
wald = Wald(name="Wald")