Source code for copulax._src.univariate._gof

"""JAX-jittable goodness-of-fit tests for univariate distributions.

Implements the Kolmogorov-Smirnov and Cramér-von Mises one-sample tests.
Both test statistics and asymptotic p-values are fully jit-compilable.
"""

import jax.numpy as jnp
from jax import jit, lax
from jax.scipy import special

from copulax._src.typing import Scalar
from copulax._src.univariate._utils import _univariate_input
from copulax._src.special import kv

# Kolmogorov series small-lambda threshold (mirrors scipy cephes/kolmogorov.c).
# Below lam_min, exp(-2*k^2*lam^2) ≈ 1 for all k and the alternating series
# diverges.  Derived from the dtype's log-underflow limit:
#   lam_min = pi / sqrt(8 * |log(tiny)|)
# float64 → ~0.042, float32 → ~0.119.
def _ks_lam_min(d: jnp.ndarray) -> Scalar:
    """Series-cutoff threshold for the dtype of ``d``.

    Uses ``jnp.finfo(d.dtype).tiny`` so the threshold tracks whatever
    precision the input is actually using — robust to ``jax_enable_x64``
    being on or off, and to lower-precision dtypes (float16, bfloat16).
    The dtype is part of the JAX abstract value, so this is a trace-time
    constant under ``@jit`` and contributes nothing at runtime.
    """
    return jnp.pi / jnp.sqrt(8.0 * (-jnp.log(jnp.finfo(d.dtype).tiny)))


###############################################################################
# Kolmogorov-Smirnov test
###############################################################################
def _ks_pvalue(d: Scalar, n: Scalar) -> Scalar:
    r"""Two-sided Kolmogorov-Smirnov p-value (asymptotic).

    Uses the Kolmogorov survival function:

    .. math::

        P(D_n \ge d) \approx 2 \sum_{k=1}^{K} (-1)^{k+1} e^{-2 k^2 \lambda^2}

    where :math:`\lambda = (\sqrt{n} + 0.12 + 0.11 / \sqrt{n}) \cdot d`.

    For small :math:`\lambda`, the series terms ``exp(-2 k^2 \lambda^2)``
    all approach 1 and the alternating sum fails to converge.  Following
    scipy's ``cephes/kolmogorov.c``, we return ``p = 1`` when
    :math:`\lambda \le \pi / \sqrt{8 \, |\!\log(\mathrm{tiny})|}`, the
    point below which every term underflows for the input dtype.  The
    threshold is computed once at module load from ``jnp.finfo`` (see
    ``_KS_LAM_MIN``).

    References
    ----------
    Marsaglia, G., Tsang, W. W., & Wang, J. (2003).
    "Evaluating Kolmogorov's Distribution". JOSS 8(18).
    """
    sqrt_n = jnp.sqrt(n)
    lam = (sqrt_n + 0.12 + 0.11 / sqrt_n) * d

    lam_min = _ks_lam_min(d)

    ks = jnp.arange(1, 101)
    signs = jnp.where(ks % 2 == 1, 1.0, -1.0)
    terms = signs * jnp.exp(-2.0 * ks**2 * lam**2)
    p_series = 2.0 * jnp.sum(terms)

    # For lam <= lam_min the series is unreliable; true p-value is 1.0.
    # jnp.where evaluates both branches but selects based on condition,
    # keeping gradients clean (no NaN from the unused branch).
    p = jnp.where(lam <= lam_min, 1.0, p_series)
    return jnp.clip(p, 0.0, 1.0)


[docs] @jit def ks_test(x: jnp.ndarray, dist, params: dict) -> dict: r"""One-sample Kolmogorov-Smirnov goodness-of-fit test. Tests whether *x* was drawn from the distribution described by *dist* and *params*. Args: x (ArrayLike): Observed sample (1-D). dist: A copulAX ``Univariate`` distribution object. params (dict): Distribution parameters (as returned by ``fit``). Returns: dict: ``{'statistic': D_n, 'p_value': p}`` """ x, _ = _univariate_input(x) x_sorted = jnp.sort(x.flatten()) n = jnp.asarray(x_sorted.size, dtype=float) # theoretical CDF at the sorted observations f_x = dist.cdf(x=x_sorted, params=params).flatten() # empirical CDF: i/n (i = 1, ..., n) i = jnp.arange(1, x_sorted.size + 1, dtype=float) ecdf_upper = i / n # F_n(x_i) ecdf_lower = (i - 1.0) / n # F_n(x_{i-1}) d_plus = jnp.max(ecdf_upper - f_x) d_minus = jnp.max(f_x - ecdf_lower) d_n = jnp.maximum(d_plus, d_minus) p_value = _ks_pvalue(d_n, n) return {"statistic": d_n, "p_value": p_value}
############################################################################### # Cramér-von Mises test ############################################################################### def _cvm_pvalue(w2: Scalar) -> Scalar: r"""Asymptotic Cramér-von Mises p-value. Uses the representation by Csörgő & Faraway (1996) eq. 1.2 based on the eigenvalue expansion. The CDF is a sum of **all-positive** terms (no alternating signs): .. math:: F(w) = \sum_{j=0}^{\infty} \frac{\Gamma(j + \tfrac{1}{2})}{j!\,\pi} \, \sqrt{\frac{4j+1}{\pi\,w}} \; e^{-\frac{(4j+1)^2}{16\,w}} \; K_{1/4}\!\!\left(\frac{(4j+1)^2}{16\,w}\right) The p-value is :math:`1 - F(w)`. For :math:`W^2 \ge 5` the 20-term truncation of the series is sufficient to establish :math:`F \approx 1` (true *p* < 3.1e-12). To prevent a spurious rise in *p* for very large :math:`W^2` (caused by :math:`K_{1/4}(z \to 0)` divergence in the truncated tail), the CDF is clamped to 1.0 above that threshold. References ---------- Csörgő, S. & Faraway, J. (1996). "The Exact and Asymptotic Distributions of Cramér-von Mises Statistics." JRSS-B 58(1), 221-234. """ # Guard against w2 <= 0 (degenerate perfect fit): avoid division by # zero in z = a^2 / (16 * w2). Use a safe denominator for tracing, # then select p = 1.0 at the end via jnp.where. w2_safe = jnp.where(w2 > 0, w2, 1.0) j = jnp.arange(0, 20, dtype=float) a = 4.0 * j + 1.0 z = a**2 / (16.0 * w2_safe) # K_{1/4}(z) via quadrature k_val = kv(0.25, z) # coefficients: Gamma(j+1/2) / (j! * pi) — ALL POSITIVE (no (-1)^j) # Reference: Csörgő & Faraway (1996) eq. 1.2; confirmed by scipy source log_gamma_half = special.gammaln(j + 0.5) log_factorial = special.gammaln(j + 1.0) log_coeff = log_gamma_half - log_factorial - jnp.log(jnp.pi) coeff = jnp.exp(log_coeff) sqrt_term = jnp.sqrt(a / (jnp.pi * w2)) exp_term = jnp.exp(-z) terms = coeff * sqrt_term * exp_term * k_val cdf = jnp.sum(terms) # Enforce CDF monotonicity. For W² > ~30 the 20-term truncation # becomes unreliable because K_{1/4}(z → 0) diverges, dragging the # partial sum spuriously below 1. At W² = 5 the true p < 3.1e-12, # so clamping CDF to 1.0 here loses nothing in practice. cdf = jnp.where(w2 >= 5.0, 1.0, cdf) p = jnp.where(w2 <= 0, 1.0, 1.0 - cdf) return jnp.clip(p, 0.0, 1.0)
[docs] @jit def cvm_test(x: jnp.ndarray, dist, params: dict) -> dict: r"""One-sample Cramér-von Mises goodness-of-fit test. Tests whether *x* was drawn from the distribution described by *dist* and *params*. Args: x (ArrayLike): Observed sample (1-D). dist: A copulAX ``Univariate`` distribution object. params (dict): Distribution parameters (as returned by ``fit``). Returns: dict: ``{'statistic': W2, 'p_value': p}`` """ x, _ = _univariate_input(x) x_sorted = jnp.sort(x.flatten()) n = jnp.asarray(x_sorted.size, dtype=float) # theoretical CDF at sorted observations f_x = dist.cdf(x=x_sorted, params=params).flatten() # CvM statistic i = jnp.arange(1, x_sorted.size + 1, dtype=float) w2 = jnp.sum((f_x - (2.0 * i - 1.0) / (2.0 * n)) ** 2) + 1.0 / (12.0 * n) p_value = _cvm_pvalue(w2) return {"statistic": w2, "p_value": p_value}