# Source code for copulax._src.special

"""Modified Bessel function of the second kind and Student's t CDF.

References:
    - Bessel function integral representation:
      https://dlmf.nist.gov/10.32#E10
    - Asymptotic forms:
      https://dlmf.nist.gov/10.30
"""

from jax import lax, vmap
import jax.numpy as jnp
import numpy as np
from jax import Array
from jax.typing import ArrayLike
from jax.scipy import special
import jax

from copulax._src.typing import Scalar


# -----------------------------------------------------------------------------
# Legacy kv implementation (adaptive quadax quadrature), retained for reference.
# -----------------------------------------------------------------------------
# def _kv_integrand(w: Array, v: float, x: Array) -> Array:
#     r"""Integrand for the integral representation of $K_v(x)$.
#
#     Uses the substitution $w = e^t$ in the standard integral
#     $K_v(x) = \frac{1}{2}\int_0^\infty w^{v-1} \exp(-x(w + w^{-1})/2)\,dw$.
#     """
#     frac = jnp.pow(w, -1)
#     inner = -0.5 * x * (w + frac)
#     exp = lax.exp(inner)
#     return 0.5 * lax.pow(w, v - 1.0) * exp
#
#
# def _kv_single_x(v: float, xi: float) -> float:
#     """Evaluate $K_v$ at a single scalar point via quadrature."""
#     from quadax import quadgk
#
#     kv_val, _ = quadgk(_kv_integrand, interval=(0.0, jnp.inf), args=(v, xi))
#     return kv_val.reshape(())
#
#
# def kv_legacy(v: float, x: ArrayLike) -> Array:
#     r"""Legacy adaptive-quadrature implementation of $K_v(x)$."""
#     v = jnp.asarray(jnp.abs(v), dtype=float)
#     x = jnp.asarray(x, dtype=float)
#     xshape = x.shape
#     x_flat = x.flatten()
#     kv_raw = vmap(lambda xi: _kv_single_x(v, xi))(x_flat)
#     kv_adj = jnp.where(x_flat < 0, jnp.nan, kv_raw)
#     return kv_adj.reshape(xshape)


# ---------------------------------------------------------------------------
# K_v(x) quadrature nodes and regime thresholds
# ---------------------------------------------------------------------------

# Gauss-Legendre nodes/weights for the main quadrature (DLMF 10.32.9).
# ``leggauss`` returns the rule on the reference interval [-1, 1]; the
# quadrature routines rescale nodes and weights to their own [t_lo, t_hi].
# NOTE(review): ``dtype=float`` is resolved by JAX when this module is
# imported, so the precision of these constants presumably follows the
# x64 setting in effect at import time — confirm.
_KV_GL_ORDER = 64
_KV_GL_NODES_NP, _KV_GL_WEIGHTS_NP = np.polynomial.legendre.leggauss(
    _KV_GL_ORDER
)
_KV_GL_NODES = jnp.asarray(_KV_GL_NODES_NP, dtype=float)
_KV_GL_WEIGHTS = jnp.asarray(_KV_GL_WEIGHTS_NP, dtype=float)

# Floors for the v-dependent regime thresholds used by ``_log_kv_single``:
# the small-x asymptotic applies below max(_KV_SMALL_X, 1e-5 * v) and the
# large-x expansion above max(_KV_LARGE_X, 2 * v**2 + 20).
_KV_SMALL_X = jnp.asarray(1e-8, dtype=float)
_KV_LARGE_X = jnp.asarray(40.0, dtype=float)
# Euler-Mascheroni constant, used by the K_0 small-x leading term.
_KV_EULER_GAMMA = jnp.asarray(0.5772156649015329, dtype=float)
# Precomputed log(2) and log(pi) shared by the log-space prefactors.
_KV_LOG_2 = jnp.log(jnp.asarray(2.0, dtype=float))
_KV_LOG_PI = jnp.log(jnp.asarray(jnp.pi, dtype=float))


# ---------------------------------------------------------------------------
# Overflow-safe log(cosh(z))
# ---------------------------------------------------------------------------

def _log_cosh(z: Array) -> Array:
    r"""Overflow-safe evaluation of ``log(cosh(z))``.

    Two formulas are evaluated and one is selected element-wise:

    * :math:`|z| \le 20` — the direct expression ``log(cosh(z))``.
    * :math:`|z| > 20` — the rearranged identity

      .. math::

          \log\cosh z = |z| - \log 2 + \log(1 + e^{-2|z|}),

      which never forms ``cosh(z)`` and therefore cannot overflow
      (``cosh`` overflows near :math:`|z| \approx 710` in float64).

    Args:
        z: Real-valued array (any shape).

    Returns:
        Array of ``log(cosh(z))`` with the same shape as ``z``.
    """
    magnitude = jnp.abs(z)
    # Direct form, used where cosh(z) is comfortably representable.
    direct = jnp.log(jnp.cosh(z))
    # Rearranged form: |z| - log 2 plus the (tiny) log1p correction.
    rearranged = magnitude - _KV_LOG_2 + jnp.log1p(jnp.exp(-2.0 * magnitude))
    return jnp.where(magnitude > 20.0, rearranged, direct)


# ---------------------------------------------------------------------------
# Log-space regime-specific evaluation functions for log(K_v(x))
#
# Each function computes log(K_v(x)) directly, avoiding the exp() that
# causes underflow in K_v(x) for x >= ~710 in float64.
# ---------------------------------------------------------------------------

def _log_kv_small_x(v: Array, x: Array) -> Array:
    r"""Small-argument asymptotics of ``log K_v(x)``, selected by order ``v``.

    Three formulas cover the order range; the choice is made with nested
    ``lax.cond`` calls on ``v``.

    **Order ≈ 0 (v < 10⁻⁴), DLMF 10.31.2**

    .. math::

        K_0(x) \approx -\log(x/2) - \gamma

    Relative error below 4 × 10⁻⁸ for v < 10⁻⁴.

    **Fractional order (10⁻⁴ ≤ v < 1), DLMF 10.27.4**

    .. math::

        K_v(x) \approx \frac{\pi}{2}
        \frac{(x/2)^{-v}/\Gamma(1-v) - (x/2)^{v}/\Gamma(v+1)}{\sin(v\pi)}

    This keeps only the m = 0 term of each I_{±v} series in
    K_v = (π/2)[I_{-v} - I_v]/sin(vπ), which is exact to O(x²).  The two
    terms are well separated on this range, so the subtraction is done in
    log space without catastrophic cancellation (< 10⁻¹⁴ relative error
    against scipy).

    **Order ≥ 1, DLMF 10.30.2**

    .. math::

        K_v(x) \approx \frac{\Gamma(v)}{2}\left(\frac{2}{x}\right)^v

    The (x/2)^{-v} term dominates overwhelmingly, so a single term
    reaches < 10⁻¹⁴ relative error.

    Args:
        v: Scalar order, ``v >= 0``.
        x: Scalar argument, small and positive.

    Returns:
        Scalar ``log K_v(x)``.
    """
    log_x_over_2 = jnp.log(0.5 * x)

    def _order_zero(_: None) -> Array:
        # -log(x/2) - γ is strictly positive for tiny x, so its log is real.
        return jnp.log(-log_x_over_2 - _KV_EULER_GAMMA)

    def _fractional_order(_: None) -> Array:
        # Logs of the two bracketed terms; the (x/2)^{-v} one is the larger.
        log_major = -v * log_x_over_2 - special.gammaln(1.0 - v)
        log_minor = v * log_x_over_2 - special.gammaln(v + 1.0)
        # log(major - minor) = log_major + log1p(-exp(log_minor - log_major))
        log_gap = jnp.log1p(-jnp.exp(log_minor - log_major))
        log_sine = jnp.log(jnp.sin(v * jnp.pi))
        return jnp.log(jnp.pi / 2.0) + log_major + log_gap - log_sine

    def _order_at_least_one(_: None) -> Array:
        # log(Γ(v)/2 · (2/x)^v) = lgamma(v) + (v - 1) log 2 - v log x
        order = jnp.maximum(v, jnp.asarray(1e-6, dtype=float))
        return (
            special.gammaln(order)
            + (order - 1.0) * _KV_LOG_2
            - order * jnp.log(x)
        )

    def _positive_order(_: None) -> Array:
        return lax.cond(
            v < 1.0, _fractional_order, _order_at_least_one, operand=None
        )

    return lax.cond(v < 1e-4, _order_zero, _positive_order, operand=None)


def _log_kv_large_x(v: Array, x: Array) -> Array:
    r"""Four-term large-argument (Hankel) expansion of ``log K_v(x)``.

    .. math::

        \log K_v(x) \approx \tfrac{1}{2}(\log\pi - \log 2 - \log x) - x
                           + \log\!\bigl(\sum_{k=0}^{3} a_k / (8x)^k\bigr)

    with :math:`a_k = \prod_{j=0}^{k-1} (4v^2 - (2j+1)^2) / k!`
    (DLMF 10.40.2).

    Args:
        v: Order.
        x: Argument, large and positive.

    Returns:
        ``log K_v(x)`` evaluated in log space (no ``exp(-x)`` is formed).
    """
    four_v_sq = 4.0 * v * v
    recip_8x = 1.0 / (8.0 * x)

    # Running products (4v² - 1)(4v² - 9)(4v² - 25) shared by the a_k.
    prod1 = four_v_sq - 1.0
    prod2 = prod1 * (four_v_sq - 9.0)
    prod3 = prod2 * (four_v_sq - 25.0)

    correction = (
        1.0
        + prod1 * recip_8x
        + prod2 * 0.5 * (recip_8x ** 2)
        + prod3 * (1.0 / 6.0) * (recip_8x ** 3)
    )
    # Leading factor sqrt(pi / (2x)) * exp(-x), kept as a logarithm.
    log_leading = 0.5 * (_KV_LOG_PI - _KV_LOG_2 - jnp.log(x)) - x
    return log_leading + jnp.log(correction)


def _log_kv_debye(v: Array, x: Array) -> Array:
    r"""Log of K_v(x) via the Debye uniform asymptotic expansion.

    Uses the uniform expansion (DLMF 10.41.3):

    .. math::

        K_v(v\,z) \sim \sqrt{\frac{\pi}{2v}}\,
        \frac{e^{-v\,\eta(z)}}{(1+z^2)^{1/4}}\,
        \sum_{k=0}^{5} \frac{(-1)^k\,U_k(p)}{v^k}

    where :math:`z = x/v`, :math:`p = 1/\sqrt{1+z^2}`, and
    :math:`\eta(z) = \sqrt{1+z^2} + \ln\!\bigl(z/(1+\sqrt{1+z^2})\bigr)`.

    The Debye polynomials :math:`U_k(p)` are given explicitly in
    DLMF 10.41.10 / Olver (1954).  Six terms (k=0..5) give ~14-digit
    accuracy for v ≥ 15.

    Returns ``log(K_v(x))`` directly, avoiding the ``exp(-v·η)`` that
    underflows for large v or x.

    Args:
        v: Order.  ``_log_kv_single`` dispatches here only when
            ``v >= 15``; the ``1e-30`` floors below merely keep the
            divisions finite for other inputs.
        x: Argument, positive.

    Returns:
        ``log K_v(x)`` from the six-term series.

    References:
        DLMF §10.41; Olver, F.W.J. (1954) "The asymptotic expansion of
        Bessel functions of large order", Phil. Trans. R. Soc. A 247, 328-368.
    """
    # Scaled argument z = x / v (floor on v avoids division by zero).
    z = x / jnp.maximum(v, 1e-30)
    z2 = z * z
    sqrt1z2 = jnp.sqrt(1.0 + z2)
    p = 1.0 / sqrt1z2

    # Debye phase eta(z) = sqrt(1+z^2) + ln(z / (1 + sqrt(1+z^2)))
    eta = sqrt1z2 + jnp.log(z / (1.0 + sqrt1z2))

    # Debye polynomials U_k(p) — coefficients from DLMF 10.41.10
    # Powers of p are built by repeated multiplication; p11 is not stored
    # and is formed inline as ``p10 * p`` inside u5.
    p2 = p * p
    p3 = p2 * p
    p4 = p2 * p2
    p5 = p4 * p
    p6 = p3 * p3
    p7 = p6 * p
    p8 = p4 * p4
    p9 = p8 * p
    p10 = p5 * p5
    p12 = p6 * p6
    p13 = p12 * p
    p15 = p12 * p3

    u0 = 1.0
    u1 = (3.0 * p - 5.0 * p3) / 24.0
    u2 = (81.0 * p2 - 462.0 * p4 + 385.0 * p6) / 1152.0
    u3 = (30375.0 * p3 - 369603.0 * p5 + 765765.0 * p7
           - 425425.0 * p9) / 414720.0
    u4 = (4465125.0 * p4 - 94121676.0 * p6 + 349922430.0 * p8
           - 446185740.0 * p10 + 185910725.0 * p12) / 39813120.0
    u5 = (1519035525.0 * p5 - 49286948607.0 * p7
           + 284499769554.0 * p9 - 614135872350.0 * p10 * p
           + 566098157625.0 * p13 - 188699385875.0 * p15) / 6688604160.0

    # Series with alternating signs: sum_{k=0}^5 (-1)^k U_k(p) / v^k
    inv_v = 1.0 / jnp.maximum(v, 1e-30)
    series = (u0
              - u1 * inv_v
              + u2 * inv_v ** 2
              - u3 * inv_v ** 3
              + u4 * inv_v ** 4
              - u5 * inv_v ** 5)

    # Log-space prefactor: avoids exp(-v*eta) underflow
    # (log of sqrt(pi / (2v)) * exp(-v * eta) / (1 + z^2)^(1/4)).
    # Note that ``jnp.log(v)`` here is not floored, unlike z and inv_v.
    log_pref = (0.5 * (_KV_LOG_PI - _KV_LOG_2 - jnp.log(v))
                - v * eta
                - 0.25 * jnp.log(1.0 + z2))

    return log_pref + jnp.log(series)


def _log_kv_legendre(v: Array, x: Array) -> Array:
    r"""``log K_v(x)`` by 64-point Gauss-Legendre quadrature of DLMF 10.32.9.

    .. math::

        K_v(x) = \int_0^\infty \cosh(v\,t)\,\exp(-x\,\cosh t)\,\mathrm{d}t

    The integrand has no singularity for v, x > 0 and decays
    exponentially in t, so the infinite range is truncated to a finite
    window ``[t_lo, t_hi]`` chosen from the shape of the integrand:

    - *Sharp peak*: the window is centred on the saddle point
      :math:`t^* = \mathrm{arcsinh}(v/x)` with half-width
      ``max(8 / sqrt(x cosh t*), 4)``.
    - *Broad peak*: the window is ``[0, T_decay]`` with
      ``T_decay = max(log(92/x), 10)``.

    The sharp-peak window is used whenever its half-width is smaller
    than ``T_decay``.  For ``x < 20`` the lower limit is always reset to
    zero so the contribution near ``t = 0`` is never cut off.

    The weighted sum is accumulated with the log-sum-exp trick, so the
    result is ``log K_v(x)`` itself and stays finite for large ``x``.

    Args:
        v: Scalar order, ``v >= 0``.
        x: Scalar argument, positive.

    Returns:
        Scalar ``log K_v(x)``.

    References:
        DLMF 10.32.9; Watson (1944) §6.22; Abramowitz & Stegun 9.6.24.
    """
    x_floor = jnp.maximum(x, 1e-10)

    # Broad-peak window end: exp(-x cosh T) is negligible near T = log(92/x).
    decay_end = jnp.maximum(jnp.log(92.0 / x_floor), 10.0)

    # Sharp-peak window: centre and half-width from the saddle point.
    saddle = jnp.arcsinh(v / x_floor)
    width = 1.0 / jnp.sqrt(jnp.maximum(x_floor * jnp.cosh(saddle), 1e-30))
    half_window = jnp.maximum(8.0 * width, 4.0)
    sharp_peak = half_window < decay_end

    lower = jnp.where(
        sharp_peak,
        jnp.maximum(saddle - half_window, 0.0),
        jnp.asarray(0.0, dtype=float),
    )
    upper = jnp.where(sharp_peak, saddle + half_window, decay_end)

    # For small x the integrand at t = 0 is about 1 and must be included,
    # even if saddle-centring would have started the window above zero.
    lower = jnp.where(x_floor < 20.0, 0.0, lower)

    # Affine map of the reference rule on [-1, 1] onto [lower, upper].
    half_span = 0.5 * (upper - lower)
    nodes = half_span * (_KV_GL_NODES + 1.0) + lower
    weights = half_span * _KV_GL_WEIGHTS

    # log of cosh(v t) * exp(-x cosh t); _log_cosh guards against overflow.
    log_f = _log_cosh(v * nodes) - x * jnp.cosh(nodes)

    # log(sum_i w_i exp(f_i)) = peak + log(sum_i w_i exp(f_i - peak))
    peak = jnp.max(log_f)
    scaled_sum = jnp.sum(weights * jnp.exp(log_f - peak))
    return peak + jnp.log(scaled_sum)


# ---------------------------------------------------------------------------
# Single-point dispatcher and public API
# ---------------------------------------------------------------------------

# Orders ``v`` at or above this value are routed by ``_log_kv_single`` to
# the Debye uniform asymptotic expansion instead of the x-based regimes.
_KV_DEBYE_V_THRESH = jnp.asarray(15.0, dtype=float)


def _log_kv_single(v: Array, x: Array) -> Array:
    """Scalar ``log K_v(x)``: pick an evaluation regime and apply it.

    Selection order:

    1. ``v >= 15`` — Debye uniform asymptotic expansion (DLMF 10.41.3).
    2. ``x < max(1e-8, 1e-5 * v)`` — small-x asymptotics
       (DLMF 10.27.4 / 10.30 / 10.31).
    3. ``x > max(40, 2 * v**2 + 20)`` — large-x expansion (DLMF 10.40.2).
    4. Anything else — Gauss-Legendre quadrature of DLMF 10.32.9.

    Why the cut-offs depend on ``v``:

    - **Small-x**: for larger v the leading term Γ(v)/2·(2/x)^v has a
      relative correction of O(x² / (4(v+1))), negligible even at
      moderately small x.  Widening the cut-off keeps the quadrature
      away from integrands whose peak is too sharp for 64 nodes.
    - **Large-x**: the 4-term Hankel series is only accurate when
      x >> v²/2.

    Out-of-domain inputs are patched after dispatch: ``x < 0`` gives
    ``nan`` and ``x == 0`` gives ``+inf``.

    Args:
        v: Scalar order, ``v >= 0``.
        x: Scalar argument.

    Returns:
        Scalar ``log K_v(x)``.
    """
    # Branches only ever see a strictly positive argument.
    x_clamped = jnp.maximum(x, jnp.asarray(1e-30, dtype=float))

    lower_cut = jnp.maximum(_KV_SMALL_X, v * 1e-5)
    upper_cut = jnp.maximum(_KV_LARGE_X, 2.0 * v * v + 20.0)

    def _debye(xi):
        return _log_kv_debye(v, xi)

    def _small(xi):
        return _log_kv_small_x(v, xi)

    def _large(xi):
        return _log_kv_large_x(v, xi)

    def _quadrature(xi):
        return _log_kv_legendre(v, xi)

    def _large_or_quadrature(xi):
        return lax.cond(xi > upper_cut, _large, _quadrature, xi)

    def _below_debye(xi):
        # v < 15: choose between the two asymptotics and the quadrature.
        return lax.cond(xi < lower_cut, _small, _large_or_quadrature, xi)

    value = lax.cond(
        v >= _KV_DEBYE_V_THRESH, _debye, _below_debye, x_clamped
    )
    value = jnp.where(x < 0.0, jnp.nan, value)
    return jnp.where(x == 0.0, jnp.inf, value)


# ---------------------------------------------------------------------------
# Public log_kv with a custom JVP
#
# Rationale: the 4-regime ``lax.cond`` cascade inside ``_log_kv_single`` is
# expensive for ``jax.grad`` / ``jax.value_and_grad`` to trace through — JAX
# has to build the derivative of every branch (including the 64-node
# Gauss-Legendre quadrature and the Debye series), even though only one
# branch is active for any given input.
# Tracing that derivative graph dominates compile time in copula EM fitting
# (confirmed: the ``mom_gh_params`` ``_solve_gig_moments`` Adam scan spent
# nearly all its compile time inside ``value_and_grad(log_kv)``).
#
# The analytical derivatives below use classical Bessel identities, so the
# resulting gradients are mathematically identical to those autograd would
# compute — only the trace shrinks.
# ---------------------------------------------------------------------------


def _stable_log_sinh(y: Array) -> Array:
    r"""Evaluate ``log(sinh(y))`` accurately for any ``y >= 0``.

    Two mathematically equivalent expressions are computed and one is
    picked element-wise at the threshold ``y > 10``:

    * **``y <= 10`` — direct** ``log(sinh(y))``.  The library ``sinh``
      handles tiny ``y`` without cancellation, and ``y == 0`` yields
      ``-inf`` (``log(0)``).

    * **``y > 10`` — rearranged**
      :math:`\log\sinh y = y + \log1p(-e^{-2y}) - \log 2`, from
      ``sinh y = e^y (1 - e^{-2y}) / 2``.  No growing exponential is
      formed, so the value stays finite beyond the point where
      ``sinh`` overflows (``y ~= 709`` in float64).  This form is
      avoided for small ``y`` because ``1 - e^{-2y}`` loses digits when
      ``e^{-2y}`` is close to 1.

    The threshold is uncritical: at ``y = 10``, ``sinh`` is only about
    ``1.1e4`` and ``e^{-20}`` is about ``2e-9``, so both expressions are
    fully accurate on either side of it.

    Used by :py:func:`_dlog_kv_dv_single` with ``y = v · t``, which can
    become large enough at high order for the overflow-safe form to
    matter.

    Args:
        y: Non-negative array.

    Returns:
        Array of ``log(sinh(y))`` with the same shape as ``y``.
    """
    direct = jnp.log(jnp.sinh(y))
    rearranged = y + jnp.log1p(-jnp.exp(-2.0 * y)) - _KV_LOG_2
    return jnp.where(y > 10.0, rearranged, direct)


def _log_kv_primal(v: Array, x: Array) -> Array:
    """Plain forward evaluation of ``log K_v(x)`` for ``v >= 0``.

    Flattens ``x``, maps the scalar regime dispatcher over it with a
    shared scalar order, and restores the original shape.

    This lives outside the ``@jax.custom_jvp`` function so that the JVP
    rule can evaluate ``log K_{v-1}`` and ``log K_{v+1}`` as plain
    values without re-entering the custom-gradient code path.

    Args:
        v: Order; must hold exactly one element (reshaped to a scalar).
        x: Argument array of any shape.

    Returns:
        Array of ``log K_v(x)`` shaped like ``x``.
    """
    order = v.reshape(())
    # The order is shared (not mapped); only the flattened x is batched.
    flat_vals = vmap(_log_kv_single, in_axes=(None, 0))(order, x.reshape(-1))
    return flat_vals.reshape(x.shape)


def _dlog_kv_dv_single(v: Array, x_val: Array) -> Array:
    r"""Scalar ``∂ log K_v(x) / ∂v`` for ``v >= 0``.

    Uses the ratio of integral representations
    (DLMF 10.32.9 and its v-derivative):

    .. math::

        \frac{\partial K_v(x)}{\partial v}
          = \int_0^\infty t\,\sinh(vt)\,e^{-x\cosh t}\,dt,
        \qquad
        K_v(x) = \int_0^\infty \cosh(vt)\,e^{-x\cosh t}\,dt.

    Both integrals are evaluated on the same 64-node Gauss-Legendre rule
    over the same integration window as ``_log_kv_legendre``, giving

    .. math::

        \frac{\partial \log K_v(x)}{\partial v}
          = \frac{\sum_i w_i\,t_i\,\sinh(v t_i)\,e^{-x\cosh t_i}}
                 {\sum_i w_i\,\cosh(v t_i)\,e^{-x\cosh t_i}}.

    The two sums are rescaled by one shared log-sum-exp maximum, so the
    common ``e^{-x\cosh t}`` factor cancels without under/overflow.  At
    ``v == 0`` every numerator term is ``exp(-inf) = 0`` while the
    denominator is positive, so the result is exactly 0, matching the
    even-in-ν symmetry of ``K_v``.

    Args:
        v: Scalar order, ``v >= 0``.
        x_val: Scalar argument.

    Returns:
        Scalar derivative of ``log K_v(x)`` with respect to ``v``.
    """
    x_safe = jnp.maximum(x_val, 1e-10)
    # Integration window — identical to ``_log_kv_legendre``.
    t_hi_decay = jnp.maximum(jnp.log(92.0 / x_safe), 10.0)
    t_star = jnp.arcsinh(v / x_safe)
    cosh_tstar = jnp.cosh(t_star)
    peak_width = 1.0 / jnp.sqrt(jnp.maximum(x_safe * cosh_tstar, 1e-30))
    saddle_half = jnp.maximum(8.0 * peak_width, 4.0)
    use_saddle = saddle_half < t_hi_decay
    t_lo_saddle = jnp.maximum(t_star - saddle_half, 0.0)
    t_hi_saddle = t_star + saddle_half
    t_lo = jnp.where(use_saddle, t_lo_saddle, jnp.asarray(0.0, dtype=float))
    t_hi = jnp.where(use_saddle, t_hi_saddle, t_hi_decay)
    # Small x: keep t = 0 inside the window (integrand is about 1 there).
    t_lo = jnp.where(x_safe < 20.0, 0.0, t_lo)

    # Affine map of the reference rule on [-1, 1] onto [t_lo, t_hi].
    t = 0.5 * (t_hi - t_lo) * (_KV_GL_NODES + 1.0) + t_lo
    w = 0.5 * (t_hi - t_lo) * _KV_GL_WEIGHTS

    vt = v * t
    # Numerator integrand ``t · sinh(vt) · e^{-x cosh t}`` and denominator
    # ``cosh(vt) · e^{-x cosh t}`` in log-space, sharing a common maximum
    # so the ``e^{-x cosh t}`` factor cancels stably.
    log_num = jnp.log(t) + _stable_log_sinh(vt) - x_val * jnp.cosh(t)
    log_den = _log_cosh(vt) - x_val * jnp.cosh(t)

    # Shared maximum over both log-arrays.  ``log_num`` is all ``-inf`` at
    # v == 0, but the maximum is then supplied by ``log_den``; the
    # ``isfinite`` fallback only protects ``exp`` should the maximum itself
    # ever be non-finite.
    m = jnp.maximum(jnp.max(log_num), jnp.max(log_den))
    m_safe = jnp.where(jnp.isfinite(m), m, 0.0)

    num = jnp.sum(w * jnp.exp(log_num - m_safe))
    den = jnp.sum(w * jnp.exp(log_den - m_safe))

    # Belt-and-braces floor on the denominator.  A literal 1e-300 rounds to
    # 0.0 in float32 (JAX's default precision), which would silently disable
    # the guard, so the floor is never allowed below the smallest normal
    # number of the working dtype.  In float64 it remains 1e-300.
    den_floor = max(1e-300, float(jnp.finfo(den.dtype).tiny))
    return num / jnp.maximum(den, den_floor)


@jax.custom_jvp
def _log_kv_pos(v: Array, x: Array) -> Array:
    r"""Internal ``log K_v(x)`` with ``v >= 0`` and a hand-written JVP.

    The forward path simply calls ``_log_kv_primal``, i.e. the 4-regime
    dispatcher.  The custom JVP registered on this function replaces
    differentiation through the dispatcher's ``lax.cond`` branches with
    the Bessel recurrence for the x-tangent (two extra forward
    evaluations at neighbouring orders) plus a single auxiliary
    Gauss-Legendre quadrature for the ν-tangent.

    Not intended for direct use — call :py:func:`log_kv` instead, which
    also handles negative ``v`` via ``jnp.abs``.

    Args:
        v: Scalar order, ``v >= 0``.
        x: Argument array.

    Returns:
        Array of ``log K_v(x)`` with the same shape as ``x``.
    """
    return _log_kv_primal(v, x)


@_log_kv_pos.defjvp
def _log_kv_pos_jvp(primals, tangents):
    r"""Forward-mode derivative rule for ``_log_kv_pos``.

    The x-direction uses the classical recurrence for ``K_v'``,
    expressed through ratios of neighbouring orders:

    .. math::

        \frac{\partial \log K_v(x)}{\partial x}
          = -\tfrac{1}{2}\bigl(e^{\log K_{v-1}(x) - \log K_v(x)}
                             + e^{\log K_{v+1}(x) - \log K_v(x)}\bigr).

    The ν-direction is obtained from :py:func:`_dlog_kv_dv_single`,
    mapped over the flattened ``x``.

    Args:
        primals: Pair ``(v, x)`` at which the function is evaluated.
        tangents: Pair ``(v_dot, x_dot)`` of input perturbations.

    Returns:
        Tuple ``(log K_v(x), directional derivative)``.
    """
    (v, x), (v_dot, x_dot) = primals, tangents

    log_k = _log_kv_primal(v, x)

    # Neighbouring orders for the recurrence.  K is even in its order, so
    # K_{v-1} is evaluated as K_{|v-1|}; the abs is required because
    # ``_log_kv_primal`` assumes a non-negative order.
    log_k_below = _log_kv_primal(jnp.abs(v - 1.0), x)
    log_k_above = _log_kv_primal(v + 1.0, x)
    dlog_dx = -0.5 * (
        jnp.exp(log_k_below - log_k) + jnp.exp(log_k_above - log_k)
    )

    # Order derivative: scalar quadrature applied element-wise over x.
    per_point = vmap(lambda xi: _dlog_kv_dv_single(v, xi))
    dlog_dv = per_point(x.reshape(-1)).reshape(x.shape)

    return log_k, dlog_dv * v_dot + dlog_dx * x_dot


def log_kv(v: float, x: ArrayLike) -> Array:
    r"""Log of the modified Bessel function of the second kind, :math:`\log K_v(x)`.

    The value is computed directly in log space, so it stays finite and
    accurate for arbitrarily large :math:`x`, where :math:`K_v(x)` itself
    would underflow.  Pure JAX, JIT-compatible, and differentiable w.r.t.
    both *v* and *x*.

    Gradients come from a hand-written :py:func:`~jax.custom_jvp` rule:
    the classical Bessel recurrence for ``∂/∂x`` and the integral
    representation for ``∂/∂v``.  Under ``jax.grad`` the backward trace is
    therefore small (no ``lax.cond`` unrolling, no differentiation
    through the 64-node quadrature).

    Four forward evaluation regimes, selected automatically:

    1. **v ≥ 15**: Debye uniform asymptotic expansion (DLMF 10.41.3),
       6-term series with Olver's polynomials.
    2. **x < 10⁻⁸**: small-x leading asymptotics (DLMF 10.30/10.31).
    3. **x > max(40, 2v²+20)**: large-x Hankel expansion (DLMF 10.40.2),
       4-term series.
    4. **Otherwise**: Gauss-Legendre quadrature (64 points) on the
       integral
       :math:`K_v(x) = \int_0^\infty \cosh(vt)\,e^{-x\cosh t}\,dt`
       (DLMF 10.32.9), with saddle-point-centred integration interval.

    Args:
        v: Order (real scalar, may be negative — K_{-v} = K_v).
        x: Argument (array-like, must be ≥ 0).

    Returns:
        Array of log(K_v(x)) values with the same shape as *x*.
    """
    # K is even in its order, so only |v| matters; the internal routine
    # expects a 0-d float array.
    order = jnp.abs(v)
    order = jnp.asarray(order, dtype=float).reshape(())
    arg = jnp.asarray(x, dtype=float)
    return _log_kv_pos(order, arg)
def kv(v: float, x: ArrayLike) -> Array:
    r"""Modified Bessel function of the second kind, :math:`K_v(x)`.

    Convenience wrapper: ``kv(v, x) = exp(log_kv(v, x))``.  All of the
    numerical work (regime selection, quadrature, custom gradient rule)
    happens inside :py:func:`log_kv`; see its documentation for details.

    Because the log-space value is exponentiated at the end, the result
    is subject to ordinary floating-point range limits (e.g. it underflows
    to ``0`` once :math:`\log K_v(x)` is sufficiently negative).  Call
    :py:func:`log_kv` directly when the log-scale value is what is needed.

    Pure JAX, JIT-compatible, and differentiable w.r.t. both *v* and *x*
    through :py:func:`log_kv`.

    Args:
        v: Order (real scalar, may be negative — K_{-v} = K_v).
        x: Argument (array-like, must be ≥ 0).

    Returns:
        Array of K_v(x) values with the same shape as *x*.
    """
    log_value = log_kv(v, x)
    return jnp.exp(log_value)
def kv_asymptotic(v: float, x: ArrayLike) -> Array:
    """Alias of :py:func:`kv`, retained for backward compatibility.

    No separate asymptotic formula is evaluated here; the call is
    forwarded to ``kv(v, x)`` unchanged.

    Args:
        v: Order, forwarded to :py:func:`kv`.
        x: Argument, forwarded to :py:func:`kv`.

    Returns:
        The result of ``kv(v, x)``.
    """
    bessel_value = kv(v, x)
    return bessel_value
# Lower clamp applied to ``r`` inside ``log_kv_plus_s_log_r`` before either
# ``log_kv`` or ``jnp.log`` is evaluated, so that ``r = 0`` never reaches them.
_LOG_KV_PLUS_S_LOG_R_FLOOR = 1e-12
def log_kv_plus_s_log_r(s: Scalar, r: ArrayLike) -> Array:
    r"""Numerically stable ``log K_s(r) + s · log(r)`` for ``s > 0``.

    Motivation
    ----------
    The skewed-t (and multivariate skewed-t) log-PDF contains the sum

    .. math::

        \log K_s(r) + \tfrac{s}{2}\log\bigl((\nu+Q)\,R\bigr)
          = \log K_s(r) + s \log r

    where ``r = sqrt((ν+Q)·R)`` and ``R = (γ/σ)²``.  As ``γ → 0`` both
    terms diverge individually (``log K_s(r) → +∞``, ``s log r → -∞``)
    while their sum stays finite.  This helper evaluates the sum as a
    single quantity at a floored argument so that neither term is ever
    evaluated at ``r = 0``.

    Closed-form limit
    -----------------
    From DLMF 10.30.2 (``K_ν(z) ~ ½ Γ(ν) (2/z)^ν`` as ``z → 0⁺`` for
    ``ν > 0``):

    .. math::

        \log K_s(r) + s \log r
          = \log\Gamma(s) + (s-1)\log 2 + O(r^{2s})
          \to \log\Gamma(s) + (s-1)\log 2
          \quad \text{as } r \to 0.

    Implementation
    --------------
    Returns ``log_kv(s, r_safe) + s · log(r_safe)`` with

    .. math::

        r_{\text{safe}} = \max(r, 10^{-12}).

    - For ``r`` at or below the floor (including ``r = 0``) the result is
      the value of the sum at ``r = 10^{-12}``, which differs from the
      ``r → 0`` limit only by the small-``r`` tail above.
    - For ``r`` above the floor the result is exactly the direct sum
      ``log K_s(r) + s log r``.

    Gradient safety
    ---------------
    Because of the ``jnp.maximum`` floor, ``log(0)`` and ``log_kv`` at
    ``x == 0`` are never evaluated, so ``jax.grad`` through the helper
    does not pick up ``nan`` / ``inf`` from those points.  Below the floor
    ``r_safe`` is a constant, hence ``∂output/∂r`` is ``0`` there.

    Args:
        s: Bessel order (scalar, ``s > 0``).
        r: Non-negative argument (array-like).

    Returns:
        Array of ``log K_s(r) + s · log(r)`` values, same shape as ``r``.

    Notes:
        - Not a drop-in replacement for ``log_kv`` — it returns the sum,
          not ``log K_s`` alone.  Use it only where the caller needs the
          combined quantity (e.g. the skewed-t log-PDF).
        - Requires ``s > 0``.  The degenerate case ``s = 0`` (where
          ``K_0`` is logarithmic rather than a power law near zero) is not
          handled here.

    References:
        DLMF 10.30.2 (leading asymptotic), 10.30/10.31 (next-order
        correction).
    """
    order = jnp.asarray(s, dtype=float).reshape(())
    radius = jnp.asarray(r, dtype=float)
    # Clamp away from zero so both terms below are evaluated at a strictly
    # positive argument.
    floored = jnp.maximum(radius, _LOG_KV_PLUS_S_LOG_R_FLOOR)
    log_bessel = log_kv(order, floored)
    return log_bessel + order * jnp.log(floored)
########################################################################
# Digamma and trigamma functions
########################################################################
def digamma(x: ArrayLike) -> Array:
    r"""Digamma function, :math:`\psi(x) = \frac{d}{dx} \ln \Gamma(x)`.

    Obtained by automatic differentiation of ``gammaln`` (an exact
    derivative, not a finite-difference approximation).

    Args:
        x: Argument (array-like, must be > 0).

    Returns:
        Array of digamma values with the same shape as *x*.
    """

    def _log_gamma(z):
        return special.gammaln(z)

    values = jnp.asarray(x, dtype=float)
    # ``jax.grad`` differentiates scalar-valued functions only, so the
    # input is flattened, the derivative is mapped over every element,
    # and the original shape is restored afterwards.
    elementwise_derivative = jax.vmap(jax.grad(_log_gamma))
    flat_result = elementwise_derivative(values.reshape(-1))
    return flat_result.reshape(values.shape)
def trigamma(x: ArrayLike) -> Array:
    r"""Trigamma function, :math:`\psi'(x) = \frac{d^2}{dx^2} \ln \Gamma(x)`.

    Obtained by differentiating ``gammaln`` twice with automatic
    differentiation (an exact derivative, not a finite-difference
    approximation).

    Args:
        x: Argument (array-like, must be > 0).

    Returns:
        Array of trigamma values with the same shape as *x*.
    """

    def _log_gamma(z):
        return special.gammaln(z)

    values = jnp.asarray(x, dtype=float)
    # Second derivative of a scalar function, mapped over the flattened
    # input and reshaped back to the caller's layout.
    second_derivative = jax.grad(jax.grad(_log_gamma))
    flat_result = jax.vmap(second_derivative)(values.reshape(-1))
    return flat_result.reshape(values.shape)
########################################################################
# igammainv / igammacinv implementation
########################################################################


def _igammainv_impl(a, p, q):
    """Core computation for igammainv.

    Finds x such that gammainc(a, x) = p, where p + q = 1.  Both ``p``
    and ``q`` are accepted so that callers can preserve full precision of
    whichever value they have directly (avoiding catastrophic
    cancellation in ``1 - p`` when p ≈ 1).

    Uses a hybrid initial approximation (Wilson-Hilferty for moderate p
    with a >= 1, log-space left-tail asymptotic otherwise) refined by 6
    unrolled Newton-Halley iterations.  Every Halley update is confined
    to the trust region ``[0.1·x, 1.9·x]`` so a single step can neither
    push the iterate below zero nor blow it up.

    Args:
        a: Shape parameter, as an array (its ``dtype`` selects the
            floating-point floor used for the guards).
        p: Target value of the regularized lower incomplete gamma
            function.
        q: Complement ``1 - p`` of the target.

    Returns:
        Array ``x`` with ``gammainc(a, x) ≈ p``; exactly ``0`` where
        ``p <= 0`` and ``inf`` where ``p >= 1``.

    References:
        Didonato, A. and Morris, A. (1986). Computation of the Incomplete
        Gamma Function Ratios and their Inverse. ACM Trans. Math. Softw.
        12(4), 377-393.
    """
    TINY = jnp.finfo(a.dtype).tiny  # ~2.2e-308 for float64

    # --- Safe helpers (avoid log(0), ndtri(0), ndtri(1)) ---
    p_pos = jnp.maximum(p, TINY)
    q_pos = jnp.maximum(q, TINY)
    log_p = jnp.log(p_pos)

    # Normal quantile: use symmetry ndtri(1-q) = -ndtri(q) to always pass
    # the smaller of p, q to ndtri (avoiding precision loss near 1).
    s = jnp.where(
        p <= 0.5,
        special.ndtri(jnp.clip(p_pos, TINY, 0.5)),
        -special.ndtri(jnp.clip(q_pos, TINY, 0.5)),
    )

    # --- Initial approximation: Wilson-Hilferty (good for a >= 1, moderate p) ---
    a_safe = jnp.maximum(a, 1.0)
    t = 1.0 / (9.0 * a_safe)
    w = 1.0 - t + s * jnp.sqrt(t)
    x_wh = a * jnp.power(jnp.maximum(w, TINY), 3.0)

    # --- Initial approximation: log-space left-tail asymptotic (any a, small p) ---
    # From gammainc(a,x) ~ x^a / (a * Gamma(a)) for small x:
    #   x ~ (p * a * Gamma(a))^(1/a) = (p * Gamma(a+1))^(1/a)
    #   log(x) = (log(p) + gammaln(a+1)) / a
    log_x_left = (log_p + special.gammaln(a + 1.0)) / jnp.maximum(a, TINY)
    # Allow underflow to 0 for extremely small values (correct for a << 1).
    # The 708 cap is sized for float64 ``exp``.
    x_left = jnp.exp(jnp.minimum(log_x_left, 708.0))

    # --- Initial approximation: exponential-like for right tail ---
    # For a <= 1 with p close to 1, the left-tail asymptotic is wrong.
    # x ~ -log(q) is the exponential (a=1) exact formula and serves as a
    # reasonable starting point for a < 1 right-tail cases.
    x_exp = -jnp.log(q_pos)

    # --- Select initial guess ---
    # Wilson-Hilferty is dramatically better for moderate-to-large p with
    # a >= 1 (up to 360,000x at a=100, p=0.5).  But when the cube base w
    # collapses (extreme left tail), it fails and the log-space formula
    # takes over.
    # For a < 1 with p > 0.5, the left-tail formula can break down for the
    # right tail (e.g. a=0.75, p=0.99: gives ~0.8 instead of ~3.3).  But
    # x_exp = -log(q) can be wrong for very small a (e.g. a=0.01, p=0.9:
    # gives 2.3 instead of 1.5e-5).  Neither heuristic is reliable across
    # all (a, p) pairs, so we evaluate gammainc once to pick the better
    # starting point.
    use_wh = (a >= 1.0) & (w > 0.2)
    x_left_safe = jnp.maximum(x_left, TINY)
    x_exp_safe = jnp.maximum(x_exp, TINY)
    err_left = jnp.abs(special.gammainc(a, x_left_safe) - p)
    err_exp = jnp.abs(special.gammainc(a, x_exp_safe) - p)
    use_right = (a < 1.0) & (p > 0.5) & (err_exp < err_left)
    x = jnp.where(use_wh, x_wh, x_left)
    x = jnp.where(use_right, x_exp, x)
    # a == 1 is the exponential distribution, for which -log(q) is exact.
    x = jnp.where(jnp.equal(a, 1.0), x_exp, x)
    x = jnp.maximum(x, TINY)

    # --- Newton-Halley refinement (6 iterations, unrolled) ---
    lgamma_a = special.gammaln(a)
    for _ in range(6):
        # fac = x^a * exp(-x) / Gamma(a)   (proportional to x * gamma_pdf)
        log_fac = a * jnp.log(jnp.maximum(x, TINY)) - x - lgamma_a
        fac = jnp.exp(log_fac)
        fac_safe = jnp.maximum(fac, TINY)

        # f / f' using gammainc for p <= 0.5, gammaincc for p > 0.5
        # (uses q directly in the upper branch to avoid cancellation)
        f_over_fprime = jnp.where(
            p <= 0.5,
            (special.gammainc(a, x) - p) * x / fac_safe,
            -(special.gammaincc(a, x) - q) * x / fac_safe,
        )

        # f'' / f' = -1 + (a - 1) / x   (guard overflow for tiny x)
        fprime2_over_fprime = jnp.where(
            x > 1e-100,
            -1.0 + (a - 1.0) / x,
            0.0,
        )

        # Halley step with safe denominator
        halley_denom = 1.0 - 0.5 * f_over_fprime * fprime2_over_fprime
        halley_denom = jnp.where(
            jnp.abs(halley_denom) < 1e-10, 1.0, halley_denom
        )
        step = f_over_fprime / halley_denom

        # Trust region for the update ``x - step``: the lower bound keeps
        # a single step from growing x beyond 1.9·x, and the upper bound
        # keeps x - step >= 0.1·x > 0.  Without the upper bound an
        # overshooting step drives x negative; the floor below then pins
        # it to TINY, where ``fac`` can underflow to 0 and freeze the
        # iteration.
        step = jnp.clip(step, -0.9 * x, 0.9 * x)

        # Only move where the density factor is non-zero (otherwise the
        # Newton ratio above is meaningless).
        x = jnp.where(fac > 0.0, x - step, x)
        x = jnp.maximum(x, TINY)

    # Boundary conditions
    x = jnp.where(p <= 0.0, 0.0, x)
    x = jnp.where(p >= 1.0, jnp.inf, x)
    return x
def igammainv(a: ArrayLike, p: ArrayLike) -> Array:
    r"""Inverse of the regularized lower incomplete gamma function.

    Finds :math:`x` such that :math:`\mathrm{gammainc}(a, x) = p`.

    Args:
        a: positive shape parameter.
        p: probability values in :math:`[0, 1]`.

    Returns:
        Array of the same shape as the broadcast of ``a`` and ``p``.
    """
    shape_param = jnp.asarray(a, dtype=float)
    lower_prob = jnp.asarray(p, dtype=float)
    # The solver takes both tail probabilities; here the lower one is
    # the given value and the upper one is derived from it.
    upper_prob = 1.0 - lower_prob
    return _igammainv_impl(shape_param, lower_prob, upper_prob)
def igammacinv(a: ArrayLike, p: ArrayLike) -> Array:
    r"""Inverse of the regularized upper incomplete gamma function.

    Finds :math:`x` such that :math:`\mathrm{gammaincc}(a, x) = p`,
    equivalently :math:`\mathrm{gammainc}(a, x) = 1 - p`.

    Args:
        a: positive shape parameter.
        p: probability values in :math:`[0, 1]`.

    Returns:
        Array of the same shape as the broadcast of ``a`` and ``p``.
    """
    shape_param = jnp.asarray(a, dtype=float)
    upper_prob = jnp.asarray(p, dtype=float)
    # The caller's value is the upper-tail probability, so it is handed
    # to the solver untouched as ``q``; only the lower tail is derived.
    lower_prob = 1.0 - upper_prob
    return _igammainv_impl(shape_param, lower_prob, upper_prob)
########################################################################
# stdtr implementation
########################################################################


def _stdtr_impl(df: Scalar, t: Array) -> Array:
    """Primal Student-t CDF implementation.

    Works with the complementary ``betainc`` form so that the small tail
    probability is the quantity computed directly, which avoids
    catastrophic cancellation when ``betainc ≈ 1`` in the deep tails:

        x   = df / (df + t²)
        ib  = I_x(df/2, 1/2)          — small in both tails
        CDF = 0.5·ib                  if t < 0
            = 1 − 0.5·ib              if t ≥ 0

    Reference: A&S 26.5.27; DLMF 8.17.4.
    """
    beta_arg = df / (df + t * t)
    two_sided_tail = special.betainc(0.5 * df, 0.5, beta_arg)
    one_sided_tail = 0.5 * two_sided_tail
    return jnp.where(t < 0, one_sided_tail, 1.0 - one_sided_tail)


def _stdtr_pdf_t(df: Scalar, t: Array) -> Array:
    """Derivative of stdtr(df, t) w.r.t. t (Student-t PDF at t)."""
    # Log of the normalising constant Γ((df+1)/2) / (Γ(df/2) · sqrt(df·π)).
    log_gamma_ratio = special.gammaln(0.5 * (df + 1.0)) - special.gammaln(0.5 * df)
    log_norm = log_gamma_ratio - 0.5 * (jnp.log(df) + jnp.log(jnp.pi))
    # Log of the kernel (1 + t²/df)^(-(df+1)/2).
    exponent = -0.5 * (df + 1.0)
    log_kernel = exponent * jnp.log1p((t * t) / df)
    return jnp.exp(log_norm + log_kernel)


@jax.custom_vjp
def stdtr(df: Scalar, t: Array) -> Array:
    r"""Compute the cdf of the standard Student's t-distribution.

    Note:
        Gradient flow is supported for ``t``. Gradient flow for ``df`` is
        explicitly disabled (set to zero) because
        ``jax.scipy.special.betainc`` does not support gradients w.r.t.
        its first two arguments.

    Args:
        df (scalar): degrees of freedom.
        t (Array): values at which to evaluate the cdf.

    Returns:
        Array: cdf values of the standard Student's t-distribution.
    """
    # Normalise the inputs: ``df`` to a 0-d float array, ``t`` to a float array.
    dof = jnp.asarray(df, dtype=float).reshape(())
    points = jnp.asarray(t, dtype=float)
    return _stdtr_impl(dof, points)


def _stdtr_fwd(df: Scalar, t: Array) -> tuple[Array, tuple[Scalar, Array]]:
    """Forward pass of the custom VJP: CDF value plus ``(df, t)`` residuals."""
    dof = jnp.asarray(df, dtype=float).reshape(())
    points = jnp.asarray(t, dtype=float)
    residuals = (dof, points)
    return _stdtr_impl(dof, points), residuals


def _stdtr_bwd(res: tuple[Scalar, Array], g: Array) -> tuple[Scalar, Array]:
    """Backward pass: zero cotangent for ``df``, ``g · pdf(t)`` for ``t``."""
    dof, points = res
    cotangent_t = g * _stdtr_pdf_t(dof, points)
    cotangent_df = jnp.zeros_like(dof)
    return cotangent_df, cotangent_t


stdtr.defvjp(_stdtr_fwd, _stdtr_bwd)