# Source code for copulax._src.copulas._mv_copulas

"""CopulAX implementation of mean-variance (normal-mixture) copulas.

Houses the umbrella :class:`MeanVarianceCopulaBase` plus its two
taxonomic sub-bases — :class:`EllipticalCopula` (γ=0; pure normal
*variance* mixtures: Gaussian, Student-T) and
:class:`MeanVarianceCopula` (γ≠0; normal *mean-variance* mixtures:
GH, Skewed-T) — and all four concrete copula classes.

The shared :class:`CopulaBase` (Sklar / marginal fitting / sampling
machinery used by every copula family, including Archimedean) lives
in ``_distributions.py`` and is imported here.

Reference:
    McNeil, Frey, Embrechts (2005) *Quantitative Risk Management*,
    §3.2.1 (variance mixtures), §3.2.2 (mean-variance mixtures),
    §3.2.4 Algorithm 3.14 (EM / ECME for GH).
"""

from abc import abstractmethod
from jax import Array
from jax.typing import ArrayLike
from typing import Callable
import jax
from jax import numpy as jnp
from jax import jit, vmap, lax
import jax.nn as jnn

from copulax._src._distributions import Multivariate, Univariate
from copulax._src.copulas._distributions import CopulaBase
from copulax._src.multivariate._utils import _multivariate_input
from copulax._src._utils import _resolve_key
from copulax._src.typing import Scalar
from copulax._src.multivariate._shape import corr, _corr
from copulax._src.optimize import projected_gradient, adam
from functools import partial

from copulax._src.multivariate.mvt_normal import mvt_normal
from copulax._src.univariate.normal import normal
from copulax._src.multivariate.mvt_student_t import mvt_student_t
from copulax._src.univariate.student_t import student_t
from copulax._src.multivariate.mvt_gh import mvt_gh
from copulax._src.univariate.gh import gh, GH
from copulax._src.multivariate.mvt_skewed_t import mvt_skewed_t
from copulax._src.univariate.skewed_t import skewed_t
from copulax._src.copulas._mom_init import mom_nu_student_t, mom_gh_params

# Module-level constants for copula parameter constraints
# NOTE(review): neither constant is referenced in the part of the file
# reviewed here; presumably they are small margins that keep ``nu`` and
# strictly-positive parameters away from their boundaries — confirm
# against the subclass fitting code.
_NU_EPS: float = 1e-6
_POS_EPS: float = 1e-8

# Fitting constants
# ``_GRAD_CLIP`` is the symmetric element-wise bound applied to the NLL
# gradient in ``_adam_gradient_step``.  ``_EPS`` is the default ``eps``
# of ``_skewed_t_gig_posteriors`` (floor for ``psi_post``) and of
# ``_copula_inner_em_body`` (lower clip for the GIG expectations).
_GRAD_CLIP: float = 10.0
_EPS: float = 1e-8

# Per-method accepted kwargs for ``MeanVarianceCopulaBase.fit_copula``.
# Used to fail fast on inapplicable kwargs (e.g. passing ``brent`` with
# ``method='fc_mle'``) instead of silently dropping them.  ``corr_method``
# is accepted by every method (Stage 1 correlation estimator); it is a
# named parameter of ``fit_copula`` and therefore does not appear in the
# sets below.  Keys are the ``method`` strings; ``fit_copula`` raises
# ``ValueError`` for any kwarg outside the selected method's set.
_METHOD_KWARGS: dict[str, frozenset[str]] = {
    "fc_mle":            frozenset({"lr", "maxiter"}),
    "mle":               frozenset({"lr", "maxiter", "brent", "nodes",
                                    "shape_steps"}),
    "ecme":              frozenset({"lr", "maxiter", "brent", "nodes",
                                    "em_maxiter", "shape_steps"}),
    "ecme_double_gamma": frozenset({"lr", "maxiter", "brent", "nodes",
                                    "em_maxiter", "shape_steps"}),
    "ecme_outer_gamma":  frozenset({"lr", "maxiter", "brent", "nodes",
                                    "em_maxiter", "shape_steps"}),
}


def _inv_softplus(x: jnp.ndarray) -> jnp.ndarray:
    r"""Invert ``jax.nn.softplus`` in a numerically stable way.

    Two regimes are stitched together at ``x = 20``:

    * ``x > 20``: ``softplus(y) ≈ y`` there, so ``x`` itself is returned.
    * ``x <= 20``: the exact inverse ``log(expm1(x))`` is used.

    ``jnp.where`` evaluates both branches, so the argument handed to
    ``expm1`` is capped at 20; the branch that is discarded for large
    ``x`` therefore never overflows.

    Args:
        x: Input array (positive values).

    Returns:
        Array ``y`` such that ``softplus(y) ≈ x``.
    """
    threshold = 20.0
    capped = jnp.minimum(x, threshold)
    exact_inverse = jnp.log(jnp.expm1(capped))
    return jnp.where(x > threshold, x, exact_inverse)


###############################################################################
# Shared copula fitting helpers
###############################################################################
def _reset_adam_state(
    adam_state: tuple[jnp.ndarray, jnp.ndarray, int],
) -> tuple[jnp.ndarray, jnp.ndarray, int]:
    r"""Return a zeroed Adam state matching the moments of ``adam_state``.

    Intended for use between outer iterations of the copula EM/MLE
    fitting loops.  Each outer iteration starts a fresh subproblem from
    the EM-warm-started parameters, so momentum accumulated under the
    previous (gamma, sigma, x') configuration is stale — including for
    parameters that did not move themselves (e.g. nu), because the loss
    surface they live on has shifted.  Zeroing everything removes that
    staleness uniformly.

    Args:
        adam_state: Tuple ``(m, v, t)`` — first moment, second moment,
            step counter.  The incoming counter is discarded.

    Returns:
        ``(zeros_like(m), zeros_like(v), jnp.array(0))``.  The two
        moments keep the shape/dtype of the inputs; the step counter is
        a freshly created scalar zero.
    """
    first_moment, second_moment, _discarded_step = adam_state
    zero_first = jnp.zeros_like(first_moment)
    zero_second = jnp.zeros_like(second_moment)
    zero_step = jnp.array(0)
    return (zero_first, zero_second, zero_step)


def _adam_gradient_step(
    nll_fn: Callable,
    opt_arr: jnp.ndarray,
    adam_state: tuple[jnp.ndarray, jnp.ndarray, int],
    lr: float,
) -> tuple[jnp.ndarray, tuple[jnp.ndarray, jnp.ndarray, int]]:
    r"""Take one sanitised Adam step on ``nll_fn``.

    The gradient of ``nll_fn`` at ``opt_arr`` is cleaned up
    (NaN → 0, then element-wise clipping to ``±_GRAD_CLIP``) before it
    is handed to :func:`adam`; the returned direction is scaled by
    ``lr`` and subtracted from ``opt_arr``.

    This is a plain Python function (not JIT-decorated) meant to be
    called from inside JIT-compiled closures, so everything it does has
    to be JAX-traceable.

    Args:
        nll_fn: Scalar-valued negative log-likelihood function of
            ``opt_arr``.
        opt_arr: Current parameter vector.
        adam_state: Tuple ``(m, v, t)``.
        lr: Learning rate.

    Returns:
        Tuple of (updated opt_arr, new adam_state).
    """
    # Only the gradient is needed; the NLL value itself is not used.
    raw_grad = jax.grad(nll_fn)(opt_arr)

    # Neutralise NaNs first, then bound every component.
    bounded_grad = jnp.clip(
        jnp.nan_to_num(raw_grad, nan=0.0), -_GRAD_CLIP, _GRAD_CLIP
    )

    m_prev, v_prev, t_prev = adam_state
    step_dir, m_next, v_next, t_next = adam(
        bounded_grad, m_prev, v_prev, t_prev
    )
    updated = opt_arr - lr * step_dir
    return updated, (m_next, v_next, t_next)


def _skewed_t_gig_posteriors(
    nu: jnp.ndarray,
    sigma_inv: jnp.ndarray,
    x: jnp.ndarray,
    gamma: jnp.ndarray,
    eps: float = _EPS,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    r"""Posterior GIG parameters of the mixing variable (Skewed-T copula).

    For the Skewed-T distribution,
    :math:`W_i | X_i \sim \text{GIG}(-(\nu+d)/2, \nu + Q_i, R_\gamma)`
    with :math:`Q_i = x_i^\top \Sigma^{-1} x_i` and
    :math:`R_\gamma = \gamma^\top \Sigma^{-1} \gamma`.

    Args:
        nu: Degrees of freedom (scalar).
        sigma_inv: Inverse correlation matrix, shape ``(d, d)``.
        x: Centred data (mu=0), shape ``(n, d)``.
        gamma: Skewness vector, shape ``(d, 1)``.
        eps: Floor for psi_post.

    Returns:
        ``(lam_post, chi_post, psi_post)``.  Only ``chi_post`` varies
        per sample (it carries ``Q``); ``lam_post`` and ``psi_post`` are
        built from scalars.
    """
    n_dims = x.shape[1]

    # Row-wise quadratic forms x_i' Sigma^{-1} x_i.
    mahalanobis_sq = ((x @ sigma_inv) * x).sum(axis=1)
    # Quadratic form of the skewness vector, reduced to a scalar.
    gamma_quad = (gamma.T @ sigma_inv @ gamma).squeeze()

    half_dims = n_dims / 2.0
    lam_post = -nu / 2.0 - half_dims
    chi_post = nu + mahalanobis_sq
    # Floored at eps so psi_post stays positive when gamma is (near) zero.
    psi_post = jnp.maximum(gamma_quad, eps)
    return lam_post, chi_post, psi_post


def _gh_gig_posteriors(
    lamb: jnp.ndarray,
    chi: jnp.ndarray,
    psi: jnp.ndarray,
    sigma_inv: jnp.ndarray,
    x: jnp.ndarray,
    gamma: jnp.ndarray,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    r"""Posterior GIG parameters of the mixing variable (GH copula).

    For the GH distribution,
    :math:`W_i | X_i \sim \text{GIG}(\lambda - d/2,
    \chi + Q_i, \psi + R_\gamma)`
    with :math:`Q_i = x_i^\top \Sigma^{-1} x_i` and
    :math:`R_\gamma = \gamma^\top \Sigma^{-1} \gamma`.

    Args:
        lamb: GH lamb parameter (scalar).
        chi: GH chi parameter (scalar).
        psi: GH psi parameter (scalar).
        sigma_inv: Inverse correlation matrix, shape ``(d, d)``.
        x: Centred data (mu=0), shape ``(n, d)``.
        gamma: Skewness vector, shape ``(d, 1)``.

    Returns:
        ``(lam_post, chi_post, psi_post)``.  Only ``chi_post`` varies
        per sample (it carries ``Q``).
    """
    n_dims = x.shape[1]

    # Row-wise quadratic forms x_i' Sigma^{-1} x_i.
    mahalanobis_sq = ((x @ sigma_inv) * x).sum(axis=1)
    # Quadratic form of the skewness vector, reduced to a scalar.
    gamma_quad = (gamma.T @ sigma_inv @ gamma).squeeze()

    return (
        lamb - n_dims / 2.0,
        chi + mahalanobis_sq,
        psi + gamma_quad,
    )


@partial(jax.jit, static_argnames=("update_gamma",))
def _copula_inner_em_body(
    gamma: jnp.ndarray,
    sigma: jnp.ndarray,
    x: jnp.ndarray,
    lam_post: jnp.ndarray,
    chi_post: jnp.ndarray,
    psi_post: jnp.ndarray,
    update_gamma: bool,
    eps: float = _EPS,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    r"""Family-independent inner EM update used by the copula fitters.

    Given the per-family GIG posterior parameters, this evaluates the
    two posterior weights (``GH._gig_expected_inv_w`` → δ and
    ``GH._gig_expected_w`` → η), optionally refreshes gamma under the
    mu=0 constraint, and rebuilds the correlation matrix.

    Args:
        gamma: Skewness vector, shape ``(d, 1)``.
        sigma: Current correlation matrix, shape ``(d, d)``.  The
            incoming value is not read; the name is simply rebound to
            the new estimate.
        x: Centred data (mu=0), shape ``(n, d)``.
        lam_post: GIG lamb posterior, shape ``(n,)``.
        chi_post: GIG chi posterior, shape ``(n,)``.
        psi_post: GIG psi posterior, shape ``(n,)`` or scalar.
        update_gamma: Whether to update gamma.  Static under JIT.
            Going by the ``fit_copula`` method descriptions this is
            True for ``ecme`` / ``ecme_double_gamma`` and False for
            ``ecme_outer_gamma`` (gamma frozen in the inner EM).
        eps: Numerical stability constant (lower clip for δ/η and floor
            for the η average used as a divisor).

    Returns:
        ``(gamma_new, sigma_new)``.
    """
    n_dims = x.shape[1]

    # Posterior weights, clipped into [eps, 1e10].
    inv_w_weight = jnp.clip(
        GH._gig_expected_inv_w(lam_post, chi_post, psi_post), eps, 1e10
    )
    w_weight = jnp.clip(
        GH._gig_expected_w(lam_post, chi_post, psi_post), eps, 1e10
    )
    w_mean = jnp.mean(w_weight)

    # ``update_gamma`` is a static argument, so this is an ordinary
    # Python branch resolved at trace time.
    if update_gamma:
        sample_mean = jnp.mean(x, axis=0).reshape((n_dims, 1))
        divisor = jnp.maximum(w_mean, eps)
        gamma = jnp.clip(sample_mean / divisor, -10.0, 10.0)

    # δ-weighted average of the outer products x_i x_i', minus the
    # skewness contribution η̄ · γγ'.
    outer_products = x[:, :, None] * x[:, None, :]
    weighted_scatter = jnp.mean(
        inv_w_weight[:, None, None] * outer_products,
        axis=0,
    )
    scatter = weighted_scatter - w_mean * (gamma @ gamma.T)

    # Post-process into a correlation matrix via the ``_corr`` helpers.
    scatter = _corr._rm_incomplete(scatter, 1e-5)
    sigma = _corr._ensure_valid(_corr._corr_from_cov(scatter))

    return gamma, sigma


@partial(jax.jit, static_argnames=("update_gamma",))
def _inner_em_step_skewed_t(
    gamma: jnp.ndarray,
    sigma: jnp.ndarray,
    x: jnp.ndarray,
    nu: jnp.ndarray,
    update_gamma: bool,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    r"""One inner EM iteration for the Skewed-T copula.

    Inverts ``sigma``, derives the Skewed-T GIG posterior parameters
    with :func:`_skewed_t_gig_posteriors`, and forwards them to
    :func:`_copula_inner_em_body`.

    Args:
        gamma: Skewness vector, shape ``(d, 1)``.
        sigma: Correlation matrix, shape ``(d, d)``.
        x: Centred data (mu=0), shape ``(n, d)``.
        nu: Degrees of freedom (scalar).
        update_gamma: Whether to update gamma.

    Returns:
        ``(gamma_new, sigma_new)``.
    """
    precision = jnp.linalg.inv(sigma)
    posterior = _skewed_t_gig_posteriors(nu, precision, x, gamma)
    # ``posterior`` is the 3-tuple (lam_post, chi_post, psi_post).
    return _copula_inner_em_body(gamma, sigma, x, *posterior, update_gamma)


@partial(jax.jit, static_argnames=("update_gamma",))
def _inner_em_step_gh(
    gamma: jnp.ndarray,
    sigma: jnp.ndarray,
    x: jnp.ndarray,
    lamb: jnp.ndarray,
    chi: jnp.ndarray,
    psi: jnp.ndarray,
    update_gamma: bool,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    r"""One inner EM iteration for the GH copula.

    Inverts ``sigma``, derives the GH GIG posterior parameters with
    :func:`_gh_gig_posteriors`, and forwards them to
    :func:`_copula_inner_em_body`.

    Args:
        gamma: Skewness vector, shape ``(d, 1)``.
        sigma: Correlation matrix, shape ``(d, d)``.
        x: Centred data (mu=0), shape ``(n, d)``.
        lamb: GH lamb (scalar).
        chi: GH chi (scalar).
        psi: GH psi (scalar).
        update_gamma: Whether to update gamma.

    Returns:
        ``(gamma_new, sigma_new)``.
    """
    precision = jnp.linalg.inv(sigma)
    posterior = _gh_gig_posteriors(lamb, chi, psi, precision, x, gamma)
    # ``posterior`` is the 3-tuple (lam_post, chi_post, psi_post).
    return _copula_inner_em_body(gamma, sigma, x, *posterior, update_gamma)


###############################################################################
# Mean-Variance Copula Base Hierarchy
###############################################################################
[docs] class MeanVarianceCopulaBase(CopulaBase): r"""Umbrella base class for normal-mixture copula distributions. Holds the shared ``fit_copula`` dispatcher, ``_METHOD_KWARGS`` validation, correlation estimation, and other machinery common to both normal *variance* mixture copulas (true elliptical, γ=0; see :class:`EllipticalCopula`) and normal *mean-variance* mixture copulas (with skewness γ; see :class:`MeanVarianceCopula`). Reference: McNeil, Frey, Embrechts (2005) *Quantitative Risk Management*, §3.2.1 (variance mixtures) and §3.2.2 (mean-variance mixtures). """ # Concrete sub-bases override. The umbrella alone supports nothing. _supported_methods: frozenset = frozenset() _mvt: Multivariate _uvt: Univariate # initialisation def __init__( self, name, mvt: Multivariate, uvt: Univariate, *, marginals=None, copula=None, ): # MeanVarianceCopulaBase and its two sub-bases (EllipticalCopula, # MeanVarianceCopula) are abstract: they carry the dispatcher / # taxonomic role but have no concrete ``_mvt`` / ``_uvt`` pair # and their ``_supported_methods`` set is empty (umbrella) or # cannot be fit without a concrete subclass's ``_fit_copula_*`` # implementations. Refuse direct instantiation so users get a # clear error instead of a silently-broken object. if type(self) in _ABSTRACT_MV_BASES: raise TypeError( f"{type(self).__name__} is abstract; instantiate one of " f"its concrete subclasses (GaussianCopula, StudentTCopula, " f"GHCopula, SkewedTCopula)." ) super().__init__(name) self._mvt: Multivariate = mvt # multivariate pytree object self._uvt: Univariate = uvt # univariate pytree object self._marginals = marginals if marginals is not None else None self._copula_params = copula if copula is not None else None def _fitted_instance(self, params_dict: dict, name: str = None): """Create a fitted Copula instance (passes mvt/uvt positional args). Args: params_dict: Fitted parameter values. name: Optional custom name for the fitted instance. 
If ``None``, an auto-generated name is used. Returns: A new Copula instance with the given parameters. """ cls = type(self) if name is None: name = f"Fitted{cls.__name__}-{id(params_dict):x}" return cls(name, self._mvt, self._uvt, **params_dict) def _params_to_tuple(self, params: dict) -> tuple: """Return an empty tuple (elliptical copula params held in dict).""" return tuple()
[docs] def example_params(self, dim: int = 3, *args, **kwargs): r"""Example parameters for the copula distribution. Generates example marginal and copula parameters for the overall joint distribution. Args: dim: int, number of dimensions of the copula distribution. Default is 3. """ # copula parameters mvt_params: dict = self._mvt.example_params(dim=dim, *args, **kwargs) mvt_params["sigma"] = jnp.eye(dim, dim) # marginal parameters marginal_params: tuple = tuple( (self._uvt, self._uvt.example_params(dim=dim)) for _ in range(dim) ) # joint parameters return {"marginals": marginal_params, "copula": mvt_params}
def _get_uvt_params(self, params: dict) -> tuple: """Returns the univariate distribution parameters.""" return tuple() def _scan_uvt_func(self, func: Callable, x: Array, params: dict, **kwargs) -> Array: """Applies func per dimension, vectorized with vmap.""" batched_params: dict = self._get_uvt_params(params) def _per_dim(xi_col, p_slice): return func(xi_col, params=p_slice, **kwargs) return vmap(_per_dim, in_axes=(1, 0), out_axes=1)(x, batched_params)
[docs] def get_x_dash( self, u: ArrayLike, params: dict, brent: bool = False, nodes: int = 100, ) -> Array: r"""Computes x' values, which represent the mappings of the independent marginal cdf values (U) to the domain of the joint multivariate distribution. Routes through :py:meth:`Univariate.ppf`, so distributions with an analytical inverse CDF (Normal, Gamma, LogNormal, IG, Uniform, Gen-Normal) use the closed-form path automatically and ignore ``nodes``. Note: If you intend to jit wrap this function, both ``brent`` and ``nodes`` must be static arguments. Args: u (ArrayLike): The independent univariate marginal cdf values (U) for each dimension, shape ``(n, d)``. params (dict): The copula and marginal distribution parameters. brent (bool): If ``False`` (default), use the analytical inverse CDF when available and otherwise the Chebyshev-node cubic spline approximation. If ``True``, force per-quantile Brent root-finding (machine-epsilon accurate, slower). nodes (int): Number of Chebyshev-Lobatto nodes used by the cubic spline path. Ignored for analytical marginals and when ``brent=True``. Returns: ``x'`` values of shape ``(n, d)``. """ u_raw: jnp.ndarray = _multivariate_input(u)[0] eps: float = 1e-4 u_clipped: jnp.ndarray = jnp.clip(u_raw, eps, 1 - eps) uvt = self._uvt batched_params: dict = self._get_uvt_params(params) def _per_dim(xi_col, p_slice): p = uvt._resolve_params(p_slice) return uvt.ppf(xi_col, params=p, brent=brent, nodes=nodes) return vmap(_per_dim, in_axes=(1, 0), out_axes=1)(u_clipped, batched_params)
# densities
[docs] def copula_logpdf( self, u: ArrayLike, params: dict = None, brent: bool = False, nodes: int = 100, ) -> Array: r"""Computes the log-pdf of the copula distribution. Note: If you intend to jit wrap this function, both ``brent`` and ``nodes`` must be static arguments. Args: u (ArrayLike): The independent univariate marginal cdf values (u) for each dimension. params (dict): The copula and marginal distribution parameters. brent (bool): Forwarded to :py:meth:`get_x_dash`. ``False`` (default) uses the analytical inverse CDF when available and otherwise the Chebyshev cubic spline; ``True`` forces per-quantile Brent root-finding. nodes (int): Number of Chebyshev-Lobatto nodes used by the cubic spline path. Ignored for analytical marginals and when ``brent=True``. Returns: logpdf (Array): The log-pdf values of the copula distribution. """ # mapping u to x' space params = self._resolve_params(params) x_dash: jnp.ndarray = self.get_x_dash(u, params, brent=brent, nodes=nodes) # computing univariate logpdfs uvt_logpdf: jnp.ndarray = self._scan_uvt_func( func=self._uvt.logpdf, x=x_dash, params=params ) # computing copula logpdf mvt_params: dict = params["copula"] mvt_logpdf: jnp.ndarray = self._mvt.logpdf(x_dash, params=mvt_params) return mvt_logpdf - uvt_logpdf.sum(axis=1, keepdims=True)
# sampling
[docs] def copula_rvs(self, size: Scalar, params: dict = None, key: Array = None) -> Array: r"""Generates random samples from the copula distribution. Note: If you intend to jit wrap this function, ensure that 'size' is a static argument. Args: size (Scalar): size (Scalar): The size / shape of the generated output array of random numbers. Must be scalar. Generates an (size, d) array of random numbers, where d is the number of dimensions inferred from the provided distribution parameters. params (dict): The copula and marginal distribution parameters. key (Array): The Key for random number generation. """ params = self._resolve_params(params) key = _resolve_key(key) # generating random samples from x' x_dash: jnp.ndarray = self._mvt.rvs(size=size, key=key, params=params["copula"]) # projecting x' to u space return self._scan_uvt_func(self._uvt.cdf, x=x_dash, params=params)
# fitting def _estimate_copula_correlation( self, u: jnp.ndarray, corr_method: str ) -> Array: r"""Estimate the copula correlation matrix from pseudo-observations. For elliptical copulas, the recommended method is ``rm_pp_kendall`` which computes Kendall's tau, applies :math:`\sin(\pi/2 \cdot \tau)` to recover the linear correlation parameter (Proposition 5.37, McNeil et al. 2005), and denoises via eigenvalue clamping to ensure positive semi-definiteness. Args: u: Pseudo-observations of shape ``(n, d)`` in ``[0, 1]``. corr_method: Correlation estimation method. Recommended: ``'rm_pp_kendall'`` (default). See ``copulax.multivariate.corr`` for all methods. Returns: Estimated correlation matrix of shape ``(d, d)``. """ return corr(x=u, method=corr_method) def _build_initial_copula_params(self, d: int, sigma: Array) -> dict: r"""Construct initial copula parameters with the estimated correlation matrix and sensible defaults for other parameters. Subclasses must override to add distribution-specific parameters (e.g. nu for Student-t, gamma for skewed-t). Args: d: Dimensionality. sigma: Estimated correlation matrix of shape ``(d, d)``. Returns: Initial copula parameter dictionary. """ return self._mvt._params_dict( mu=jnp.zeros((d, 1)), sigma=sigma ) def _copula_nll( self, opt_arr: jnp.ndarray, u: jnp.ndarray, sigma: jnp.ndarray, dummy_marginals: tuple, ) -> Scalar: r"""Negative copula log-likelihood for optimisation. Gradients flow through the PPF via the implicit function theorem (custom JVP on the PPF), giving exact derivatives without differentiating through the root-finder or cubic spline. Args: opt_arr: Flat array of parameters being optimised. u: Pseudo-observations, shape ``(n, d)``. sigma: Fixed correlation matrix, shape ``(d, d)``. dummy_marginals: Tuple of (dist, params) for dimension inference. Returns: Scalar negative log-likelihood. 
""" d: int = sigma.shape[0] copula_params: dict = self._reconstruct_copula_opt_params( opt_arr, sigma, d ) full_params: dict = { "marginals": dummy_marginals, "copula": copula_params } logpdf: Array = self.copula_logpdf(u, params=full_params) n = logpdf.shape[0] finite_mask = jnp.isfinite(logpdf) safe_logpdf = jnp.where(finite_mask, logpdf, 0.0) n_invalid = (~finite_mask).astype(float).sum() return -safe_logpdf.sum() / n + 1e6 * n_invalid / n def _reconstruct_copula_opt_params( self, opt_arr: jnp.ndarray, sigma: jnp.ndarray, d: int, ) -> dict: r"""Rebuild copula params dict from optimiser output + fixed sigma. Called by :py:meth:`_copula_nll`. Only subclasses that route ``_fit_copula_fc_mle`` through :func:`projected_gradient` on :py:meth:`_copula_nll` need to override (StudentT, GH, SkewedT). Gaussian's trivial fc_mle doesn't reach this path, so raising here is fine for the umbrella default. """ raise NotImplementedError( f"{type(self).__name__} does not implement " f"_reconstruct_copula_opt_params; its fc_mle path should " f"not reach _copula_nll." )
[docs] def fit_copula( self, u: ArrayLike, corr_method: str = "rm_pp_kendall", method: str = "fc_mle", **kwargs, ) -> dict: r"""Fit copula parameters from pseudo-observations. Two-stage estimation following McNeil, Frey & Embrechts (2005), Section 5.5: **Stage 1** — Estimate the copula correlation matrix *P* from the pseudo-observations *u* using rank correlation. The default ``rm_pp_kendall`` computes Kendall's tau, applies the inversion :math:`\hat\rho_{ij} = \sin(\tfrac{\pi}{2}\,\hat\tau_{ij})` (Proposition 5.37), and ensures positive semi-definiteness via eigenvalue clamping. **Stage 2** — Estimate remaining parameters (e.g. *ν*, *γ*) by maximising the copula log-likelihood, with *P* either held fixed at the Stage 1 estimate or jointly re-optimised, depending on ``method``. Args: u: Pseudo-observations of shape ``(n, d)`` in ``[0, 1]``. corr_method: Correlation estimation method for Stage 1. Default ``'rm_pp_kendall'``. See ``copulax.multivariate.corr`` for all methods. method: Fitting algorithm for Stage 2. One of: ``'fc_mle'`` — *Fixed-Correlation MLE*: shape parameters optimised via projected gradient with Σ held at the Stage 1 Kendall-τ estimate. Available for every concrete subclass. ``'mle'`` — *Full joint MLE*: all parameters (shape + Σ off-diagonals) optimised together via Adam, with Σ tanh-parameterised onto the correlation manifold. Mean-variance subclasses (GH, SkewedT) only. ``'ecme'`` — Inner EM updates (P, γ); outer gradient descent on the copula log-likelihood for the remaining shape parameters (McNeil §3.2.4 ECME variant). Mean-variance subclasses only. ``'ecme_double_gamma'`` — Like ``ecme`` but γ is additionally re-optimised in the outer numerical M-step (so γ is updated twice per outer iteration). Mean-variance subclasses only. ``'ecme_outer_gamma'`` — Inner EM updates Σ only (γ frozen); outer MLE on all shape parameters including γ. Mean-variance subclasses only. **kwargs: Method-specific keyword arguments. 
Each ``method`` accepts only its own set of kwargs; Common kwargs: ``lr`` (float, all methods), ``maxiter`` (int, all), ``brent`` (bool, all except ``fc_mle``), ``nodes`` (int, all except ``fc_mle``). Returns: dict with key ``'copula'`` containing fitted parameters. Raises: ValueError: If ``method`` is not accepted by this subclass, or if ``kwargs`` contains a key not accepted by the chosen method. """ # --- Validate method + kwargs (Python-level; happens at trace # time when fit_copula is JIT-wrapped with method as a static # arg, so the dispatcher remains JIT- and autograd-safe). --- self._check_method(method) allowed = _METHOD_KWARGS[method] unknown = set(kwargs) - allowed if unknown: raise ValueError( f"Method {method!r} does not accept kwargs " f"{sorted(unknown)}. Accepted: {sorted(allowed)}." ) # --- Resolve kwargs with documented defaults. --- lr = kwargs.get("lr", 1e-2) maxiter = kwargs.get("maxiter", 200) brent = kwargs.get("brent", False) nodes = kwargs.get("nodes", 100) em_maxiter = kwargs.get("em_maxiter", 5) shape_steps = kwargs.get("shape_steps", 10) u_arr, _, n, d = _multivariate_input(u) # Stage 1: estimate correlation matrix P sigma: jnp.ndarray = self._estimate_copula_correlation( u_arr, corr_method ) # Stage 2: estimate remaining parameters if method == "fc_mle": copula_params = self._fit_copula_fc_mle( u_arr, sigma, d, lr, maxiter, ) elif method == "ecme": copula_params = self._fit_copula_ecme( u_arr, sigma, d, lr, maxiter, brent, nodes, em_maxiter, shape_steps, ) elif method == "ecme_double_gamma": copula_params = self._fit_copula_ecme_double_gamma( u_arr, sigma, d, lr, maxiter, brent, nodes, em_maxiter, shape_steps, ) elif method == "ecme_outer_gamma": copula_params = self._fit_copula_ecme_outer_gamma( u_arr, sigma, d, lr, maxiter, brent, nodes, em_maxiter, shape_steps, ) elif method == "mle": copula_params = self._fit_copula_mle( u_arr, sigma, d, lr, maxiter, brent, nodes, shape_steps, ) else: # Should be unreachable thanks to the 
_supported_methods # guard above, but kept as a defensive backstop. raise ValueError( f"Unhandled supported method {method!r} on " f"{type(self).__name__}; implementation missing." ) return {"copula": copula_params}
def _fit_copula_fc_mle(
    self,
    u: jnp.ndarray,
    sigma: jnp.ndarray,
    d: int,
    lr: float,
    maxiter: int,
) -> dict:
    r"""Default Fixed-Correlation MLE: return the initial parameters as-is.

    The correlation matrix ``sigma`` (the Stage 1 estimate handed over
    by ``fit_copula``) is passed straight to
    ``_build_initial_copula_params`` and the resulting dict is returned
    without any optimisation. This is the behaviour a family without
    extra shape parameters needs (e.g. the Gaussian copula); families
    with shape parameters override this method.

    Args:
        u: Pseudo-observations. Not used by this default implementation.
        sigma: Correlation matrix to embed in the returned parameters.
        d: Dimensionality.
        lr: Learning rate. Not used by this default implementation.
        maxiter: Iteration budget. Not used by this default
            implementation.

    Returns:
        The dict produced by ``_build_initial_copula_params(d, sigma)``.
    """
    # ``u``, ``lr`` and ``maxiter`` only exist for signature parity with
    # the overriding subclasses.
    initial_params: dict = self._build_initial_copula_params(d, sigma)
    return initial_params
############################################################################### # Sub-base classes (taxonomic split: variance vs mean-variance mixtures) ###############################################################################
class EllipticalCopula(MeanVarianceCopulaBase):
    r"""Sub-base for true elliptical copulas (normal *variance* mixtures, γ=0).

    The concrete subclasses :class:`GaussianCopula` and
    :class:`StudentTCopula` accept ``'fc_mle'`` as their only Stage 2
    fitting method.

    Reference:
        McNeil, Frey, Embrechts (2005) *Quantitative Risk Management*,
        §3.2.1 Normal Variance Mixtures.
    """

    # Stage 2 fitting methods accepted by this family.
    _supported_methods: frozenset = frozenset(("fc_mle",))
class MeanVarianceCopula(MeanVarianceCopulaBase):
    r"""Sub-base for normal mean-variance mixture copulas (skewness γ).

    The concrete subclasses :class:`GHCopula` and :class:`SkewedTCopula`
    provide, in addition to ``fc_mle``, the γ-aware fitting methods
    ``mle``, ``ecme``, ``ecme_double_gamma`` and ``ecme_outer_gamma``.

    Note:
        ``MeanVarianceCopulaBase`` is the wider umbrella that covers both
        this class **and** :class:`EllipticalCopula` (the γ=0 special
        case); this class is the proper γ≠0 specialisation.

    Reference:
        McNeil, Frey, Embrechts (2005) *Quantitative Risk Management*,
        §3.2.2 Normal Mean-Variance Mixtures, §3.2.4 Algorithm 3.14.
    """

    # Stage 2 fitting methods accepted by this family.
    _supported_methods: frozenset = frozenset(
        ("fc_mle", "mle", "ecme", "ecme_double_gamma", "ecme_outer_gamma")
    )

    # ------------------------------------------------------------------
    # fc_mle optimisation machinery
    # ------------------------------------------------------------------
    # On the mean-variance side ``fc_mle`` is a shape-parameter MLE with
    # Σ frozen at the Stage 1 Kendall-τ estimate, driven by
    # :func:`projected_gradient` on :py:meth:`_copula_nll`.
    #
    # Two pieces bracket the optimiser call:
    #   - :py:meth:`_get_opt_params_and_bounds` supplies the starting
    #     vector and the box constraints (subclass-specific, abstract).
    #   - :py:meth:`_optimize_copula_params` wires them through
    #     :func:`projected_gradient` and rebuilds the fitted params
    #     (concrete, shared by the subclasses).
    #
    # They sit on ``MeanVarianceCopula`` rather than on the umbrella
    # because the :class:`EllipticalCopula` subclasses either need no
    # optimisation at all (Gaussian) or call :func:`projected_gradient`
    # inline themselves (StudentT).

    @abstractmethod
    def _get_opt_params_and_bounds(
        self, d: int
    ) -> tuple[jnp.ndarray, dict]:
        r"""Provide the starting vector and box bounds for ``fc_mle``.

        Returns:
            Tuple of (initial_params_array, projection_options_dict).
        """

    def _optimize_copula_params(
        self,
        u: jnp.ndarray,
        sigma: jnp.ndarray,
        d: int,
        lr: float,
        maxiter: int,
    ) -> dict:
        r"""Fit the non-correlation copula parameters by ML, Σ held fixed.

        Minimises the negative copula log-likelihood
        (:py:meth:`_copula_nll`) with ``projected_gradient`` under the
        box constraints returned by
        :py:meth:`_get_opt_params_and_bounds`.

        Args:
            u: Pseudo-observations, shape ``(n, d)``.
            sigma: Fixed correlation matrix, shape ``(d, d)``.
            d: Dimensionality.
            lr: Learning rate.
            maxiter: Maximum optimisation iterations.

        Returns:
            Fitted copula parameter dictionary.
        """
        start, box = self._get_opt_params_and_bounds(d)

        # One (distribution, example-params) pair per dimension; forwarded
        # to ``_copula_nll`` as ``dummy_marginals``.
        placeholder_marginals: tuple = tuple(
            (self._uvt, self._uvt.example_params()) for _ in range(d)
        )

        outcome: dict = projected_gradient(
            f=self._copula_nll,
            x0=start,
            projection_method="projection_box",
            projection_options=box,
            u=u,
            sigma=sigma,
            dummy_marginals=placeholder_marginals,
            lr=lr,
            maxiter=maxiter,
        )
        fitted_vector = outcome["x"]
        return self._reconstruct_copula_opt_params(fitted_vector, sigma, d)

    # ------------------------------------------------------------------
    # γ-aware fitting methods (ECME variants + full joint MLE)
    # ------------------------------------------------------------------

    @abstractmethod
    def _fit_copula_mle(
        self,
        u,
        sigma,
        d,
        lr,
        maxiter,
        brent: bool = False,
        nodes: int = 100,
        shape_steps: int = 10,
    ) -> dict:
        r"""Joint MLE over the Σ off-diagonals **and** every shape /
        skewness parameter.

        Σ is re-optimised through a tanh-parameterisation of the
        off-diagonal correlations that is projected onto the correlation
        manifold, together with the shape parameters, using Adam steps
        inside a fixed ``maxiter`` outer loop.

        In contrast to :py:meth:`_fit_copula_fc_mle`, where Σ stays at
        the Kendall-τ rank-correlation estimate passed in by the base
        dispatcher, ``mle`` re-optimises Σ jointly with the shape
        parameters. ``shape_steps`` is the number of inner Adam steps per
        outer iteration (default 10).

        Implemented by the subclasses; declared abstract here for safety.
        """

    @abstractmethod
    def _fit_copula_ecme(
        self,
        u,
        sigma,
        d,
        lr,
        maxiter,
        brent: bool = False,
        nodes: int = 100,
        em_maxiter: int = 5,
        shape_steps: int = 10,
    ) -> dict:
        r"""ECME fit: the inner EM updates (Σ, γ); the outer step
        numerically maximises the original copula log-likelihood over the
        remaining shape parameters (e.g. λ, χ, ψ for GH; ν for SkewedT)
        while γ and Σ stay at their inner-EM values (McNeil §3.2.4 ECME
        variant).

        ``em_maxiter`` is the inner EM scan length (default 5);
        ``shape_steps`` is the outer Adam scan length (default 10).
        """

    @abstractmethod
    def _fit_copula_ecme_double_gamma(
        self,
        u,
        sigma,
        d,
        lr,
        maxiter,
        brent: bool = False,
        nodes: int = 100,
        em_maxiter: int = 5,
        shape_steps: int = 10,
    ) -> dict:
        r"""ECME variant updating γ *twice* per outer iteration: once in
        the inner EM step (together with Σ) and once more in the outer
        numerical M-step together with the other shape parameters.

        ``em_maxiter`` / ``shape_steps`` as in
        :py:meth:`_fit_copula_ecme`.
        """

    @abstractmethod
    def _fit_copula_ecme_outer_gamma(
        self,
        u,
        sigma,
        d,
        lr,
        maxiter,
        brent: bool = False,
        nodes: int = 100,
        em_maxiter: int = 5,
        shape_steps: int = 10,
    ) -> dict:
        r"""ECME variant whose inner EM updates Σ only (γ is frozen in
        the inner step); γ is optimised with the other shape parameters
        in the outer numerical M-step.

        ``em_maxiter`` / ``shape_steps`` as in
        :py:meth:`_fit_copula_ecme`.
        """
# Set of classes that must not be directly instantiated. Checked in # :py:meth:`MeanVarianceCopulaBase.__init__`. Populated here, after all # three base classes are defined; resolved lazily at instantiation time. _ABSTRACT_MV_BASES: frozenset = frozenset( {MeanVarianceCopulaBase, EllipticalCopula, MeanVarianceCopula} ) ############################################################################### # Copula Distributions ############################################################################### # Normal Mixture Copulas
class GaussianCopula(EllipticalCopula):
    r"""Gaussian copula: dependence between the random variables is
    modelled with the multivariate normal distribution.

    The only copula parameter is the correlation matrix *P*. Fitting
    estimates *P* from pseudo-observations via rank correlation
    (McNeil et al. 2005, Example 5.58).

    https://en.wikipedia.org/wiki/Copula_(statistics)
    """

    @jit
    def _get_uvt_params(self, params: dict) -> dict:
        """Return the parameters of the Gaussian copula's univariate margins.

        Every margin gets ``mu = 0`` and ``sigma = 1``; the arrays have
        one entry per dimension reported by ``self._get_dim(params)``.
        """
        dim: int = self._get_dim(params)
        locations = jnp.zeros(dim)
        scales = jnp.ones(dim)
        return {"mu": locations, "sigma": scales}
# Module-level Gaussian copula object, constructed with the display name
# "Gaussian-Copula" and the imported ``mvt_normal`` / ``normal``
# distribution objects.
gaussian_copula = GaussianCopula("Gaussian-Copula", mvt_normal, normal)
class StudentTCopula(EllipticalCopula):
    r"""The Student-T Copula is a copula that uses the multivariate
    Student-T distribution to model the dependencies between random
    variables.

    The copula is parameterised by degrees of freedom *ν* and correlation
    matrix *P*. Fitting estimates *P* via Kendall's tau inversion and *ν*
    by maximising the copula log-likelihood (McNeil et al. 2005,
    Examples 5.54 and 5.59).

    https://en.wikipedia.org/wiki/Copula_(statistics)
    """

    @jit
    def _get_uvt_params(self, params: dict) -> dict:
        """Extract univariate parameters for the student-t copula margins.

        The scalar ``nu`` stored under ``params["copula"]`` is broadcast
        to every margin; each margin has ``mu = 0`` and ``sigma = 1``.
        """
        nu: Scalar = params["copula"]["nu"]
        d: int = self._get_dim(params)
        return {"nu": jnp.full(d, nu), "mu": jnp.zeros(d), "sigma": jnp.ones(d)}

    def _build_initial_copula_params(self, d: int, sigma: Array) -> dict:
        """Return a default parameter dict: ``nu = 5``, zero ``mu`` of
        shape ``(d, 1)`` and the supplied ``sigma``."""
        return self._mvt._params_dict(
            nu=jnp.array(5.0),
            mu=jnp.zeros((d, 1)),
            sigma=sigma,
        )

    def _reconstruct_copula_opt_params(self, opt_arr, sigma, d):
        r"""Rebuild the Student-T copula params dict from the optimised
        ``raw_nu`` entry produced by :py:meth:`_fit_copula_fc_mle` and
        fed back through the umbrella's :py:meth:`_copula_nll`.

        ``raw_nu`` (``opt_arr[0]``) is mapped to ``nu`` via
        ``softplus(raw_nu) + _NU_EPS``.
        """
        raw_nu = opt_arr[0]
        nu = jnn.softplus(raw_nu) + _NU_EPS
        return self._mvt._params_dict(
            nu=nu,
            mu=jnp.zeros((d, 1)),
            sigma=sigma,
        )

    def _fit_copula_fc_mle(self, u, sigma, d, lr, maxiter):
        r"""Fixed-Correlation MLE for the Student-T copula.

        Σ is held at the Stage 1 Kendall-τ estimate; only ν (degrees of
        freedom) is optimised. Warm-start comes from a method-of-moments
        estimator (:func:`mom_nu_student_t`) matching the empirical
        quadratic-form median, clipped to ``[2.5, 200]`` before being
        mapped through ``inv_softplus`` for unconstrained optimisation.

        The optimisation uses :func:`projected_gradient` with a box
        constraint of ``[-10, 200]`` on the unconstrained ν parameter.
        The upper bound matches the warm-start clip range: because
        ``softplus(x) ≈ x`` for large ``x``, an upper bound of ``10``
        would cap the fitted ν at roughly 10 and would make every
        warm start above that value infeasible from the first step.

        Args:
            u: Pseudo-observations, shape ``(n, d)``.
            sigma: Fixed correlation matrix, shape ``(d, d)``.
            d: Dimensionality.
            lr: Learning rate.
            maxiter: Maximum optimisation iterations.

        Returns:
            Fitted copula parameter dictionary (``nu``, ``mu``, ``sigma``).
        """
        # Admissible range for ν used for both the warm start and the box.
        nu_min, nu_max = 2.5, 200.0

        R_inv = jnp.linalg.inv(sigma)
        nu_hat = mom_nu_student_t(u, R_inv, d)
        raw_nu0 = _inv_softplus(jnp.clip(nu_hat, nu_min, nu_max))
        params0 = raw_nu0.reshape((1,))

        # inv_softplus(x) <= x, so a raw upper bound of ``nu_max`` always
        # contains the warm start; softplus(nu_max) equals nu_max to
        # floating-point precision, so ν itself is bounded by ~nu_max.
        proj_opts = {
            "lower": jnp.full((1, 1), -10.0),
            "upper": jnp.full((1, 1), nu_max),
        }
        dummy_marginals = tuple(
            (self._uvt, self._uvt.example_params()) for _ in range(d)
        )
        res = projected_gradient(
            f=self._copula_nll,
            x0=params0,
            projection_method="projection_box",
            projection_options=proj_opts,
            u=u,
            sigma=sigma,
            dummy_marginals=dummy_marginals,
            lr=lr,
            maxiter=maxiter,
        )
        return self._reconstruct_copula_opt_params(res["x"], sigma, d)
# Module-level Student-T copula object, constructed with the display name
# "Student-T-Copula" and the imported ``mvt_student_t`` / ``student_t``
# distribution objects.
student_t_copula = StudentTCopula("Student-T-Copula", mvt_student_t, student_t)
class GHCopula(MeanVarianceCopula):
    r"""The GH Copula is a copula that uses the multivariate generalized
    hyperbolic (GH) distribution to model the dependencies between random
    variables.

    The copula is parameterised by (λ, χ, ψ, γ) and correlation matrix
    *P*. Fitting estimates *P* via Kendall's tau inversion and the
    remaining parameters via ML or EM (McNeil et al. 2005, Section 5.5).

    https://en.wikipedia.org/wiki/Copula_(statistics)
    """

    @jit
    def _get_uvt_params(self, params: dict) -> dict:
        """Extract univariate parameters for the GH copula margins.

        The scalars λ, χ, ψ are broadcast to every margin, γ is
        flattened, and each margin has ``mu = 0`` and ``sigma = 1``.
        """
        d: int = self._get_dim(params)
        lamb: Scalar = params["copula"]["lamb"]
        chi: Scalar = params["copula"]["chi"]
        psi: Scalar = params["copula"]["psi"]
        gamma: Array = params["copula"]["gamma"]
        return {
            "lamb": jnp.full(d, lamb),
            "chi": jnp.full(d, chi),
            "psi": jnp.full(d, psi),
            "mu": jnp.zeros(d),
            "sigma": jnp.ones(d),
            "gamma": gamma.flatten(),
        }

    def _build_initial_copula_params(self, d: int, sigma: Array) -> dict:
        """Return a default parameter dict: λ=0, χ=ψ=1, zero ``mu`` and
        ``gamma`` of shape ``(d, 1)`` and the supplied ``sigma``."""
        return self._mvt._params_dict(
            lamb=jnp.array(0.0),
            chi=jnp.array(1.0),
            psi=jnp.array(1.0),
            mu=jnp.zeros((d, 1)),
            gamma=jnp.zeros((d, 1)),
            sigma=sigma,
        )

    def _get_opt_params_and_bounds(self, d: int):
        """Return the ``fc_mle`` starting vector and box bounds.

        The vector is ``[lamb, raw_chi, raw_psi, gamma_1..gamma_d]``
        (``3 + d`` entries); every entry is boxed to ``[-10, 10]``.
        """
        # log(expm1(1)) is the inverse softplus of 1, i.e. χ = ψ = 1.
        params0 = jnp.concatenate([
            jnp.array([0.0]),                       # lamb
            jnp.log(jnp.expm1(jnp.array([1.0]))),   # raw_chi
            jnp.log(jnp.expm1(jnp.array([1.0]))),   # raw_psi
            jnp.zeros(d),                           # gamma
        ])
        n_params = 3 + d
        proj_opts = {
            "lower": jnp.full((n_params, 1), -10.0),
            "upper": jnp.full((n_params, 1), 10.0),
        }
        return params0, proj_opts

    def _reconstruct_copula_opt_params(self, opt_arr, sigma, d):
        """Map an optimised ``fc_mle`` vector back to a parameter dict.

        χ and ψ are recovered as ``softplus(raw) + _POS_EPS``; λ and γ
        are read off directly.
        """
        lamb = opt_arr[0]
        chi = jnn.softplus(opt_arr[1]) + _POS_EPS
        psi = jnn.softplus(opt_arr[2]) + _POS_EPS
        gamma = opt_arr[3:3 + d].reshape((d, 1))
        return self._mvt._params_dict(
            lamb=lamb,
            chi=chi,
            psi=psi,
            mu=jnp.zeros((d, 1)),
            gamma=gamma,
            sigma=sigma,
        )

    def _fit_copula_fc_mle(self, u, sigma, d, lr, maxiter):
        r"""Fixed-Correlation MLE for the GH copula.

        Σ is held at the Stage 1 Kendall-τ estimate; the shape / mixing
        parameters (λ, χ, ψ) and skewness vector γ (3+d params total) are
        optimised jointly via :func:`projected_gradient` on the negative
        copula log-likelihood. The initial vector and box constraints
        come from :py:meth:`_get_opt_params_and_bounds`.
        """
        return self._optimize_copula_params(u, sigma, d, lr, maxiter)

    # ------------------------------------------------------------------
    # Shared building blocks for the likelihood closures and the
    # ECME / MLE fitting loops.
    # ------------------------------------------------------------------

    def _gh_copula_logpdf_fn(self, d, mu):
        r"""Build the pointwise copula log-density for the GH family.

        Returns ``logpdf(x, lamb, chi, psi, gamma, sigma) -> array``: the
        multivariate log-density minus the sum over dimensions of the
        univariate margin log-densities. Non-finite entries are *not*
        masked here; the callers do that.
        """
        mvt = self._mvt
        uvt = self._uvt

        def _logpdf(x, lamb, chi, psi, gamma, sigma_):
            copula_p = mvt._params_dict(
                lamb=lamb, chi=chi, psi=psi, mu=mu, gamma=gamma,
                sigma=sigma_,
            )
            mvt_ll = mvt.logpdf(x, params=copula_p)
            uvt_params = {
                "lamb": jnp.full(d, lamb),
                "chi": jnp.full(d, chi),
                "psi": jnp.full(d, psi),
                "mu": jnp.zeros(d),
                "sigma": jnp.ones(d),
                "gamma": gamma.flatten(),
            }
            # Map over the columns of x and the per-margin parameters.
            uvt_ll = vmap(
                lambda xi, pr: uvt.logpdf(xi, params=pr),
                in_axes=(1, 0),
                out_axes=1,
            )(x, uvt_params).sum(axis=1, keepdims=True)
            return mvt_ll - uvt_ll

        return _logpdf

    def _gh_copula_nll_closure(self, d, mu, eps=_EPS):
        r"""Build a copula NLL function for the GH family.

        Returns a function ``nll(opt_arr, sigma, x) -> scalar`` where
        ``opt_arr = [lamb, raw_chi, raw_psi, gamma_1..gamma_d]``. The
        returned function is a plain Python closure (it is not wrapped
        in ``jax.jit`` here; callers trace it inside their own jitted
        scans). Non-finite log-density entries contribute zero to the
        mean and add a penalty of ``1e6`` each (divided by ``n``).
        """
        logpdf_fn = self._gh_copula_logpdf_fn(d, mu)

        def _copula_nll(opt_arr, sigma_, x):
            l = opt_arr[0]
            c = jnn.softplus(opt_arr[1]) + eps
            p = jnn.softplus(opt_arr[2]) + eps
            g = opt_arr[3:].reshape((d, 1))
            logpdf = logpdf_fn(x, l, c, p, g, sigma_)
            n = logpdf.shape[0]
            finite_mask = jnp.isfinite(logpdf)
            safe = jnp.where(finite_mask, logpdf, 0.0)
            return -safe.sum() / n + 1e6 * (~finite_mask).sum() / n

        return _copula_nll

    def _gh_copula_ll(self, d, mu):
        r"""Build a JIT-compiled copula LL evaluator for convergence
        monitoring.

        Returns ``ll(x, lamb, chi, psi, gamma, sigma) -> scalar``: the
        sum of the finite log-density entries divided by ``n``.
        """
        logpdf_fn = self._gh_copula_logpdf_fn(d, mu)

        @jax.jit
        def _ll(x, lamb, chi, psi, gamma, sigma_):
            logpdf = logpdf_fn(x, lamb, chi, psi, gamma, sigma_)
            n = logpdf.shape[0]
            finite_mask = jnp.isfinite(logpdf)
            return jnp.where(finite_mask, logpdf, 0.0).sum() / n

        return _ll

    def _gh_mom_init(self, u, sigma, d):
        """Method-of-moments warm start ``(lamb, chi, psi)`` computed via
        :func:`mom_nu_student_t` and :func:`mom_gh_params`."""
        R_inv = jnp.linalg.inv(sigma)
        nu_hat = mom_nu_student_t(u, R_inv, d)
        lamb, chi, psi = mom_gh_params(u, R_inv, d, nu_hat)
        return lamb, chi, psi

    def _gh_x_dash_fn(self, u, d, mu, brent, nodes):
        r"""Build ``x_dash(lamb, chi, psi, gamma, sigma)``.

        The returned function assembles the full parameter dict (dummy
        marginals + GH copula params) and evaluates a jitted
        ``self.get_x_dash`` on ``u`` with ``brent`` / ``nodes`` passed
        as static arguments.
        """
        mvt = self._mvt
        dummy_marginals = tuple(
            (self._uvt, self._uvt.example_params()) for _ in range(d)
        )
        get_x_dash_jit = jax.jit(
            self.get_x_dash, static_argnames=("brent", "nodes")
        )

        def _x_dash(lamb, chi, psi, gamma, sigma_):
            copula_params = mvt._params_dict(
                lamb=lamb, chi=chi, psi=psi, mu=mu, gamma=gamma,
                sigma=sigma_,
            )
            full_params = {
                "marginals": dummy_marginals,
                "copula": copula_params,
            }
            return get_x_dash_jit(
                u, full_params, brent=brent, nodes=nodes
            )

        return _x_dash

    def _gh_ecme_driver(
        self,
        u,
        sigma,
        d,
        lr,
        maxiter,
        brent,
        nodes,
        em_maxiter,
        shape_steps,
        *,
        em_updates_gamma: bool,
        outer_updates_gamma: bool,
    ):
        r"""Shared ECME loop behind the three ``_fit_copula_ecme*``
        methods.

        Each of the ``maxiter`` outer iterations (run as a ``lax.scan``)
        does:

        1. recompute ``x_dash`` from the current parameters;
        2. run ``em_maxiter`` inner EM steps (``_inner_em_step_gh``);
           Σ is always taken from the result, γ only when
           ``em_updates_gamma`` is true;
        3. reset the Adam state and run ``shape_steps`` Adam steps on
           the copula NLL over ``[λ, raw_χ, raw_ψ]``, extended by γ when
           ``outer_updates_gamma`` is true.

        Args:
            em_updates_gamma: Whether the inner EM step updates γ (the
                flag is forwarded to ``_inner_em_step_gh``).
            outer_updates_gamma: Whether γ is part of the vector
                optimised by the outer Adam steps.

        Returns:
            Fitted copula parameter dictionary.
        """
        eps = _EPS
        mu = jnp.zeros((d, 1))
        copula_nll_fn = self._gh_copula_nll_closure(d, mu, eps)
        x_dash_fn = self._gh_x_dash_fn(u, d, mu, brent, nodes)

        # --- JIT: inner EM scan ---
        @jax.jit
        def _run_inner_em(gamma, sigma_, x_dash, lamb, chi, psi):
            def _scan_body(carry, _):
                g, s = carry
                g, s = _inner_em_step_gh(
                    g, s, x_dash, lamb, chi, psi, em_updates_gamma
                )
                return (g, s), None

            (g, s), _ = lax.scan(
                _scan_body, (gamma, sigma_), None, length=em_maxiter
            )
            return g, s

        # --- JIT: outer numerical M-step scan ---
        @jax.jit
        def _run_outer_mle(lamb, chi, psi, gamma, sigma_, adam_state,
                           x_dash):
            def _scan_body(carry, _):
                l, c, p, g, a_s = carry
                raw_c = _inv_softplus(jnp.maximum(c, eps))
                raw_p = _inv_softplus(jnp.maximum(p, eps))
                shape_arr = jnp.array([l, raw_c, raw_p])
                g_flat = g.flatten()

                # Static (Python-level) branch: it selects which
                # parameters Adam is allowed to move.
                if outer_updates_gamma:
                    opt_arr = jnp.concatenate([shape_arr, g_flat])

                    def _nll(arr):
                        return copula_nll_fn(arr, sigma_, x_dash)
                else:
                    opt_arr = shape_arr

                    def _nll(arr):
                        full = jnp.concatenate([arr, g_flat])
                        return copula_nll_fn(full, sigma_, x_dash)

                opt_arr, a_s = _adam_gradient_step(_nll, opt_arr, a_s, lr)
                l = jnp.clip(opt_arr[0], -10.0, 10.0)
                c = jnp.clip(jnn.softplus(opt_arr[1]) + eps, eps, 100.0)
                p = jnp.clip(jnn.softplus(opt_arr[2]) + eps, eps, 100.0)
                if outer_updates_gamma:
                    g = opt_arr[3:].reshape((d, 1))
                return (l, c, p, g, a_s), None

            (lamb, chi, psi, gamma, adam_state), _ = lax.scan(
                _scan_body,
                (lamb, chi, psi, gamma, adam_state),
                None,
                length=shape_steps,
            )
            return lamb, chi, psi, gamma, adam_state

        # --- MoM initialization ---
        lamb, chi, psi = self._gh_mom_init(u, sigma, d)

        # --- Outer loop as lax.scan ---
        gamma_init = jnp.zeros((d, 1))
        n_opt = 3 + d if outer_updates_gamma else 3
        adam_init = (jnp.zeros(n_opt), jnp.zeros(n_opt), jnp.array(0))

        def _outer_body(carry, _):
            lamb, chi, psi, gamma, sigma_, adam_state = carry
            x_dash = x_dash_fn(lamb, chi, psi, gamma, sigma_)
            gamma_em, sigma_ = _run_inner_em(
                gamma, sigma_, x_dash, lamb, chi, psi
            )
            if em_updates_gamma:
                gamma = gamma_em
            adam_state = _reset_adam_state(adam_state)
            lamb, chi, psi, gamma, adam_state = _run_outer_mle(
                lamb, chi, psi, gamma, sigma_, adam_state, x_dash
            )
            return (lamb, chi, psi, gamma, sigma_, adam_state), None

        (lamb, chi, psi, gamma, sigma, _), _ = lax.scan(
            _outer_body,
            (lamb, chi, psi, gamma_init, sigma, adam_init),
            None,
            length=maxiter,
        )

        return self._mvt._params_dict(
            lamb=lamb, chi=chi, psi=psi, mu=mu, gamma=gamma, sigma=sigma,
        )

    # ------------------------------------------------------------------
    # γ-aware fitting methods
    # ------------------------------------------------------------------

    def _fit_copula_ecme(
        self,
        u,
        sigma,
        d,
        lr,
        maxiter,
        brent: bool = False,
        nodes: int = 100,
        em_maxiter: int = 5,
        shape_steps: int = 10,
    ):
        r"""ECME fitting for the GH copula (McNeil §3.2.4 ECME variant).

        Alternates between inner EM (analytic updates of Σ and γ) and an
        outer numerical M-step that performs gradient descent on the
        **original copula log-likelihood** with respect to the remaining
        shape parameters λ, χ, ψ (with Σ and γ held fixed at the inner-
        EM values).

        This matches McNeil-Frey-Embrechts' explicitly-named ECME
        variant described in §3.2.4: *"instead of maximizing Q₂ we
        maximize the original likelihood ... with the other parameters
        held fixed"*.
        """
        return self._gh_ecme_driver(
            u, sigma, d, lr, maxiter, brent, nodes, em_maxiter,
            shape_steps,
            em_updates_gamma=True,
            outer_updates_gamma=False,
        )

    def _fit_copula_ecme_double_gamma(
        self,
        u,
        sigma,
        d,
        lr,
        maxiter,
        brent: bool = False,
        nodes: int = 100,
        em_maxiter: int = 5,
        shape_steps: int = 10,
    ):
        r"""ECME-with-double-γ fitting for the GH copula.

        Like :py:meth:`_fit_copula_ecme`, but γ is re-optimised in the
        outer numerical M-step alongside (λ, χ, ψ) — so γ is updated
        *twice* per outer iteration (once by the inner EM, once by the
        outer MLE). The inner EM update of γ therefore acts as a
        warm-start for the outer numerical optimisation.

        Inner EM updates (Σ, γ); outer MLE optimises (λ, χ, ψ, γ).
        """
        return self._gh_ecme_driver(
            u, sigma, d, lr, maxiter, brent, nodes, em_maxiter,
            shape_steps,
            em_updates_gamma=True,
            outer_updates_gamma=True,
        )

    def _fit_copula_ecme_outer_gamma(
        self,
        u,
        sigma,
        d,
        lr,
        maxiter,
        brent: bool = False,
        nodes: int = 100,
        em_maxiter: int = 5,
        shape_steps: int = 10,
    ):
        r"""ECME-with-outer-γ fitting for the GH copula.

        Inner EM updates Σ only (γ is held fixed in the inner step).
        The outer numerical M-step then optimises (λ, χ, ψ, γ) jointly
        via MLE.

        Contrast:
            - :py:meth:`_fit_copula_ecme` — γ in inner EM; outer MLE on
              (λ,χ,ψ).
            - :py:meth:`_fit_copula_ecme_double_gamma` — γ in both inner
              and outer.
            - this method — γ in outer only.

        McNeil's canonical Algorithm 3.14 places γ in the *inner*
        analytic M-step; this variant defers γ to the numerical outer
        step alongside the GIG mixing-distribution parameters.
        """
        return self._gh_ecme_driver(
            u, sigma, d, lr, maxiter, brent, nodes, em_maxiter,
            shape_steps,
            em_updates_gamma=False,
            outer_updates_gamma=True,
        )

    def _fit_copula_mle(
        self,
        u,
        sigma,
        d,
        lr,
        maxiter,
        brent: bool = False,
        nodes: int = 100,
        shape_steps: int = 10,
    ):
        r"""Full joint MLE for the GH copula.

        Optimises all copula parameters (λ, χ, ψ, γ, and the Σ
        off-diagonals) jointly via gradient descent on the negative
        copula log-likelihood. Σ is represented via the ``d*(d-1)/2``
        free off-diagonal correlations, mapped through ``tanh`` into
        ``(-1, 1)`` and then projected onto the correlation manifold
        (unit diagonal + positive-semidefiniteness) each Adam step.

        Unlike :py:meth:`_fit_copula_fc_mle`, which holds Σ fixed at the
        Kendall-τ rank-correlation estimate supplied by the
        :py:meth:`MeanVarianceCopulaBase.fit_copula` dispatcher, this
        method re-optimises Σ jointly with the shape parameters.

        Each of the ``maxiter`` outer iterations recomputes ``x_dash``,
        resets the Adam state and runs ``shape_steps`` Adam steps on the
        vector ``[λ, raw_χ, raw_ψ, γ_1..γ_d, raw_corr...]``.
        """
        eps = _EPS
        mu = jnp.zeros((d, 1))
        tril_rows, tril_cols = jnp.tril_indices(d, k=-1)
        # Same NLL as the ECME variants; Σ is rebuilt from the raw
        # correlation entries before each evaluation.
        copula_nll_fn = self._gh_copula_nll_closure(d, mu, eps)
        x_dash_fn = self._gh_x_dash_fn(u, d, mu, brent, nodes)

        def _sigma_from_raw(raw_corr):
            rho = jnp.tanh(raw_corr)
            P = jnp.eye(d)
            P = P.at[tril_rows, tril_cols].set(rho)
            P = P.at[tril_cols, tril_rows].set(rho)
            P = _corr._rm_incomplete(P, 1e-5)
            P = _corr._corr_from_cov(P)
            P = _corr._ensure_valid(P)
            return P

        def _raw_from_sigma(sigma_):
            rho = sigma_[tril_rows, tril_cols]
            # Clip so that arctanh stays finite for correlations at ±1.
            return jnp.arctanh(jnp.clip(rho, -0.999, 0.999))

        @jax.jit
        def _run_mle_steps(opt_arr, adam_state, x_dash):
            def _copula_nll(arr):
                return copula_nll_fn(
                    arr[:3 + d], _sigma_from_raw(arr[3 + d:]), x_dash
                )

            def _scan_body(carry, _):
                arr, a_s = carry
                arr, a_s = _adam_gradient_step(_copula_nll, arr, a_s, lr)
                return (arr, a_s), None

            (opt_arr, adam_state), _ = lax.scan(
                _scan_body, (opt_arr, adam_state), None,
                length=shape_steps,
            )
            return opt_arr, adam_state

        # --- MoM initialization ---
        lamb, chi, psi = self._gh_mom_init(u, sigma, d)

        n_corr = d * (d - 1) // 2
        n_opt = 3 + d + n_corr
        gamma_init = jnp.zeros((d, 1))
        adam_init = (jnp.zeros(n_opt), jnp.zeros(n_opt), jnp.array(0))

        # --- Outer loop as lax.scan ---
        def _outer_body(carry, _):
            lamb, chi, psi, gamma, sigma_, adam_state = carry
            x_dash = x_dash_fn(lamb, chi, psi, gamma, sigma_)

            raw_chi = _inv_softplus(jnp.maximum(chi, eps))
            raw_psi = _inv_softplus(jnp.maximum(psi, eps))
            raw_corr = _raw_from_sigma(sigma_)
            opt_arr = jnp.concatenate([
                jnp.array([lamb, raw_chi, raw_psi]),
                gamma.flatten(),
                raw_corr,
            ])

            adam_state = _reset_adam_state(adam_state)
            opt_arr, adam_state = _run_mle_steps(
                opt_arr, adam_state, x_dash
            )

            lamb = jnp.clip(opt_arr[0], -10.0, 10.0)
            chi = jnp.clip(jnn.softplus(opt_arr[1]) + eps, eps, 100.0)
            psi = jnp.clip(jnn.softplus(opt_arr[2]) + eps, eps, 100.0)
            gamma = opt_arr[3:3 + d].reshape((d, 1))
            sigma_ = _sigma_from_raw(opt_arr[3 + d:])
            return (lamb, chi, psi, gamma, sigma_, adam_state), None

        (lamb, chi, psi, gamma, sigma, _), _ = lax.scan(
            _outer_body,
            (lamb, chi, psi, gamma_init, sigma, adam_init),
            None,
            length=maxiter,
        )

        return self._mvt._params_dict(
            lamb=lamb, chi=chi, psi=psi, mu=mu, gamma=gamma, sigma=sigma,
        )
# Module-level GH copula object, constructed with the display name
# "GH-Copula" and the imported ``mvt_gh`` / ``gh`` distribution objects.
gh_copula = GHCopula("GH-Copula", mvt_gh, gh)
[docs] class SkewedTCopula(MeanVarianceCopula): r"""The Skewed-T Copula is a copula that uses the multivariate skewed-T distribution to model the dependencies between random variables. The copula is parameterised by degrees of freedom *ν*, skewness vector *γ*, and correlation matrix *P*. Fitting estimates *P* via Kendall's tau inversion and (*ν*, *γ*) via ML or EM (McNeil et al. 2005, Section 5.5). https://en.wikipedia.org/wiki/Copula_(statistics) """ def _get_uvt_params(self, params: dict) -> dict: """Extract univariate parameters for the skewed-t copula margins.""" d: int = self._get_dim(params) nu: Scalar = params["copula"]["nu"] gamma: Array = params["copula"]["gamma"] return { "nu": jnp.full(d, nu), "mu": jnp.zeros(d), "sigma": jnp.ones(d), "gamma": gamma.flatten(), } def _build_initial_copula_params(self, d: int, sigma: Array) -> dict: return self._mvt._params_dict( nu=jnp.array(5.0), mu=jnp.zeros((d, 1)), gamma=jnp.zeros((d, 1)), sigma=sigma, ) def _get_opt_params_and_bounds(self, d: int): # Optimise [raw_nu, gamma_1..gamma_d] raw_nu0 = jnp.log(jnp.expm1(jnp.array(5.0))) params0 = jnp.concatenate([ raw_nu0.reshape((1,)), jnp.zeros(d), ]) n_params = 1 + d proj_opts = { "lower": jnp.full((n_params, 1), -10.0), "upper": jnp.full((n_params, 1), 10.0), } return params0, proj_opts def _reconstruct_copula_opt_params(self, opt_arr, sigma, d): raw_nu = opt_arr[0] nu = jnn.softplus(raw_nu) + _NU_EPS gamma = opt_arr[1:1 + d].reshape((d, 1)) return self._mvt._params_dict( nu=nu, mu=jnp.zeros((d, 1)), gamma=gamma, sigma=sigma, ) def _fit_copula_fc_mle(self, u, sigma, d, lr, maxiter): r"""Fixed-Correlation MLE for the Skewed-T copula. Σ is held at the Stage 1 Kendall-τ estimate; ν and the skewness vector γ (1+d params total) are optimised jointly via :func:`projected_gradient` on the negative copula log-likelihood. The initial vector and box constraints come from :py:meth:`_get_opt_params_and_bounds`. 
""" return self._optimize_copula_params(u, sigma, d, lr, maxiter) def _st_copula_nll_closure(self, d, mu, eps=_EPS): r"""Build a copula NLL function for the Skewed-T family. Returns ``nll(opt_arr, sigma, x) -> scalar`` where ``opt_arr = [raw_nu, gamma_1..gamma_d]``. """ mvt = self._mvt uvt = self._uvt def _copula_nll(opt_arr, sigma_, x): n_val = jnn.softplus(opt_arr[0]) + eps g = opt_arr[1:].reshape((d, 1)) copula_p = mvt._params_dict( nu=n_val, mu=mu, gamma=g, sigma=sigma_, ) mvt_ll = mvt.logpdf(x, params=copula_p) uvt_params = { "nu": jnp.full(d, n_val), "mu": jnp.zeros(d), "sigma": jnp.ones(d), "gamma": g.flatten(), } uvt_ll = vmap( lambda xi, pr: uvt.logpdf(xi, params=pr), in_axes=(1, 0), out_axes=1, )(x, uvt_params).sum(axis=1, keepdims=True) logpdf = mvt_ll - uvt_ll n = logpdf.shape[0] finite_mask = jnp.isfinite(logpdf) safe = jnp.where(finite_mask, logpdf, 0.0) return -safe.sum() / n + 1e6 * (~finite_mask).sum() / n return _copula_nll def _st_copula_ll(self, d, mu): r"""Build a JIT-compiled copula LL evaluator for convergence monitoring. Returns ``ll(x, nu, gamma, sigma) -> scalar``.""" mvt = self._mvt uvt = self._uvt @jax.jit def _ll(x, nu, gamma, sigma_): copula_p = mvt._params_dict( nu=nu, mu=mu, gamma=gamma, sigma=sigma_, ) mvt_ll = mvt.logpdf(x, params=copula_p) uvt_params = { "nu": jnp.full(d, nu), "mu": jnp.zeros(d), "sigma": jnp.ones(d), "gamma": gamma.flatten(), } uvt_ll = vmap( lambda xi, pr: uvt.logpdf(xi, params=pr), in_axes=(1, 0), out_axes=1, )(x, uvt_params).sum(axis=1, keepdims=True) logpdf = mvt_ll - uvt_ll n = logpdf.shape[0] finite_mask = jnp.isfinite(logpdf) return jnp.where(finite_mask, logpdf, 0.0).sum() / n return _ll def _fit_copula_ecme( self, u, sigma, d, lr, maxiter, brent: bool = False, nodes: int = 100, em_maxiter: int = 5, shape_steps: int = 10, ): r"""ECME fitting for the Skewed-T copula (McNeil §3.2.4 ECME variant). 
Alternates between inner EM (analytic updates of Σ and γ) and an outer numerical M-step that performs gradient descent on the **original copula log-likelihood** with respect to the remaining shape parameter ν (with Σ and γ held fixed at the inner-EM values). This matches the ECME variant described in McNeil-Frey-Embrechts §3.2.4. """ eps = _EPS mu = jnp.zeros((d, 1)) dummy_marginals = tuple( (self._uvt, self._uvt.example_params()) for _ in range(d) ) copula_nll_fn = self._st_copula_nll_closure(d, mu, eps) # --- JIT: inner EM scan (gamma + sigma) --- @jax.jit def _run_inner_em(gamma, sigma_, x_dash, nu): def _scan_body(carry, _): g, s = carry g, s = _inner_em_step_skewed_t(g, s, x_dash, nu, True) return (g, s), None (g, s), _ = lax.scan( _scan_body, (gamma, sigma_), None, length=em_maxiter ) return g, s # --- JIT: shape CM scan (nu only) --- @jax.jit def _run_shape_steps(nu, gamma, sigma_, adam_state, x_dash): def _copula_nll_nu(raw_nu_arr): opt = jnp.concatenate([raw_nu_arr, gamma.flatten()]) return copula_nll_fn(opt, sigma_, x_dash) def _scan_body(carry, _): n, a_s = carry raw_nu = _inv_softplus(jnp.maximum(n, eps)) raw_arr = raw_nu.reshape((1,)) raw_arr, a_s = _adam_gradient_step( _copula_nll_nu, raw_arr, a_s, lr ) n = jnn.softplus(raw_arr[0]) + eps return (n, a_s), None (nu, adam_state), _ = lax.scan( _scan_body, (nu, adam_state), None, length=shape_steps ) return nu, adam_state _get_x_dash_jit = jax.jit( self.get_x_dash, static_argnames=("brent", "nodes") ) # --- MoM initialization --- R_inv = jnp.linalg.inv(sigma) nu = jnp.clip(mom_nu_student_t(u, R_inv, d), 2.5, 60.0) # --- Outer loop as lax.scan (mirrors gh._fit_em / mvt_gh._fit_em) --- gamma_init = jnp.zeros((d, 1)) adam_init = (jnp.zeros(1), jnp.zeros(1), jnp.array(0)) def _outer_body(carry, _): nu, gamma, sigma, adam_state = carry copula_params = self._mvt._params_dict( nu=nu, mu=mu, gamma=gamma, sigma=sigma, ) full_params = { "marginals": dummy_marginals, "copula": copula_params, } x_dash = 
_get_x_dash_jit( u, full_params, brent=brent, nodes=nodes ) gamma, sigma = _run_inner_em(gamma, sigma, x_dash, nu) adam_state = _reset_adam_state(adam_state) nu, adam_state = _run_shape_steps( nu, gamma, sigma, adam_state, x_dash ) return (nu, gamma, sigma, adam_state), None (nu, gamma, sigma, _), _ = lax.scan( _outer_body, (nu, gamma_init, sigma, adam_init), None, length=maxiter, ) return self._mvt._params_dict( nu=nu, mu=mu, gamma=gamma, sigma=sigma, ) def _fit_copula_ecme_double_gamma( self, u, sigma, d, lr, maxiter, brent: bool = False, nodes: int = 100, em_maxiter: int = 5, shape_steps: int = 10, ): r"""ECME-with-double-γ fitting for the Skewed-T copula. Like :py:meth:`_fit_copula_ecme`, but γ is re-optimised in the outer numerical M-step alongside ν — so γ is updated *twice* per outer iteration (once by the inner EM, once by the outer MLE). The inner EM update of γ acts as a warm-start for the outer numerical optimisation. Inner EM updates (Σ, γ); outer MLE optimises (ν, γ). """ eps = _EPS mu = jnp.zeros((d, 1)) dummy_marginals = tuple( (self._uvt, self._uvt.example_params()) for _ in range(d) ) copula_nll_fn = self._st_copula_nll_closure(d, mu, eps) # --- JIT: inner EM scan (gamma + sigma) --- @jax.jit def _run_inner_em(gamma, sigma_, x_dash, nu): def _scan_body(carry, _): g, s = carry g, s = _inner_em_step_skewed_t(g, s, x_dash, nu, True) return (g, s), None (g, s), _ = lax.scan( _scan_body, (gamma, sigma_), None, length=em_maxiter ) return g, s # --- JIT: outer MLE scan (nu, gamma) --- @jax.jit def _run_outer_mle(nu, gamma, sigma_, adam_state, x_dash): def _scan_body(carry, _): n, g, a_s = carry raw_nu = _inv_softplus(jnp.maximum(n, eps)) opt_arr = jnp.concatenate([raw_nu.reshape((1,)), g.flatten()]) opt_arr, a_s = _adam_gradient_step( lambda arr: copula_nll_fn(arr, sigma_, x_dash), opt_arr, a_s, lr, ) n = jnn.softplus(opt_arr[0]) + eps g = opt_arr[1:].reshape((d, 1)) return (n, g, a_s), None (nu, gamma, adam_state), _ = lax.scan( _scan_body, (nu, gamma, 
adam_state), None, length=shape_steps, ) return nu, gamma, adam_state _get_x_dash_jit = jax.jit( self.get_x_dash, static_argnames=("brent", "nodes") ) # --- MoM initialization --- R_inv = jnp.linalg.inv(sigma) nu = jnp.clip(mom_nu_student_t(u, R_inv, d), 2.5, 60.0) # --- Outer loop as lax.scan --- gamma_init = jnp.zeros((d, 1)) adam_init = (jnp.zeros(1 + d), jnp.zeros(1 + d), jnp.array(0)) def _outer_body(carry, _): nu, gamma, sigma, adam_state = carry copula_params = self._mvt._params_dict( nu=nu, mu=mu, gamma=gamma, sigma=sigma, ) full_params = { "marginals": dummy_marginals, "copula": copula_params, } x_dash = _get_x_dash_jit( u, full_params, brent=brent, nodes=nodes ) gamma, sigma = _run_inner_em(gamma, sigma, x_dash, nu) adam_state = _reset_adam_state(adam_state) nu, gamma, adam_state = _run_outer_mle( nu, gamma, sigma, adam_state, x_dash ) return (nu, gamma, sigma, adam_state), None (nu, gamma, sigma, _), _ = lax.scan( _outer_body, (nu, gamma_init, sigma, adam_init), None, length=maxiter, ) return self._mvt._params_dict( nu=nu, mu=mu, gamma=gamma, sigma=sigma, ) def _fit_copula_ecme_outer_gamma( self, u, sigma, d, lr, maxiter, brent: bool = False, nodes: int = 100, em_maxiter: int = 5, shape_steps: int = 10, ): r"""ECME-with-outer-γ fitting for the Skewed-T copula. Inner EM updates Σ only (γ is held fixed in the inner step). The outer numerical M-step then optimises (ν, γ) jointly via MLE. Contrast: - :py:meth:`_fit_copula_ecme` — γ in inner EM; outer MLE on ν. - :py:meth:`_fit_copula_ecme_double_gamma` — γ in both inner and outer. - this method — γ in outer only. 
""" eps = _EPS mu = jnp.zeros((d, 1)) dummy_marginals = tuple( (self._uvt, self._uvt.example_params()) for _ in range(d) ) copula_nll_fn = self._st_copula_nll_closure(d, mu, eps) # --- JIT: inner EM scan (sigma only, gamma frozen) --- @jax.jit def _run_inner_em(gamma, sigma_, x_dash, nu): def _scan_body(carry, _): g, s = carry g, s = _inner_em_step_skewed_t(g, s, x_dash, nu, False) return (g, s), None (_, s), _ = lax.scan( _scan_body, (gamma, sigma_), None, length=em_maxiter ) return s # --- JIT: outer MLE scan (nu, gamma) --- @jax.jit def _run_outer_mle(nu, gamma, sigma_, adam_state, x_dash): def _scan_body(carry, _): n, g, a_s = carry raw_nu = _inv_softplus(jnp.maximum(n, eps)) opt_arr = jnp.concatenate([raw_nu.reshape((1,)), g.flatten()]) opt_arr, a_s = _adam_gradient_step( lambda arr: copula_nll_fn(arr, sigma_, x_dash), opt_arr, a_s, lr, ) n = jnn.softplus(opt_arr[0]) + eps g = opt_arr[1:].reshape((d, 1)) return (n, g, a_s), None (nu, gamma, adam_state), _ = lax.scan( _scan_body, (nu, gamma, adam_state), None, length=shape_steps, ) return nu, gamma, adam_state _get_x_dash_jit = jax.jit( self.get_x_dash, static_argnames=("brent", "nodes") ) # --- MoM initialization --- R_inv = jnp.linalg.inv(sigma) nu = jnp.clip(mom_nu_student_t(u, R_inv, d), 2.5, 60.0) # --- Outer loop as lax.scan --- gamma_init = jnp.zeros((d, 1)) adam_init = (jnp.zeros(1 + d), jnp.zeros(1 + d), jnp.array(0)) def _outer_body(carry, _): nu, gamma, sigma, adam_state = carry copula_params = self._mvt._params_dict( nu=nu, mu=mu, gamma=gamma, sigma=sigma, ) full_params = { "marginals": dummy_marginals, "copula": copula_params, } x_dash = _get_x_dash_jit( u, full_params, brent=brent, nodes=nodes ) sigma = _run_inner_em(gamma, sigma, x_dash, nu) adam_state = _reset_adam_state(adam_state) nu, gamma, adam_state = _run_outer_mle( nu, gamma, sigma, adam_state, x_dash ) return (nu, gamma, sigma, adam_state), None (nu, gamma, sigma, _), _ = lax.scan( _outer_body, (nu, gamma_init, sigma, adam_init), None, 
length=maxiter, ) return self._mvt._params_dict( nu=nu, mu=mu, gamma=gamma, sigma=sigma, ) def _fit_copula_mle( self, u, sigma, d, lr, maxiter, brent: bool = False, nodes: int = 100, shape_steps: int = 10, ): r"""Full joint MLE for the Skewed-T copula. Optimises all copula parameters (ν, γ, and the Σ off-diagonals) jointly via gradient descent on the negative copula log-likelihood. Σ is represented via the ``d*(d-1)/2`` free off-diagonal correlations, mapped through ``tanh`` into ``(-1, 1)`` and projected onto the correlation manifold each Adam step. Unlike :py:meth:`_fit_copula_fc_mle`, which holds Σ fixed at the Kendall-τ rank-correlation estimate supplied by the :py:meth:`MeanVarianceCopulaBase.fit_copula` dispatcher, this method re-optimises Σ jointly with the shape parameters. """ eps = _EPS mu = jnp.zeros((d, 1)) dummy_marginals = tuple( (self._uvt, self._uvt.example_params()) for _ in range(d) ) tril_rows, tril_cols = jnp.tril_indices(d, k=-1) mvt = self._mvt uvt = self._uvt def _sigma_from_raw(raw_corr): rho = jnp.tanh(raw_corr) P = jnp.eye(d) P = P.at[tril_rows, tril_cols].set(rho) P = P.at[tril_cols, tril_rows].set(rho) P = _corr._rm_incomplete(P, 1e-5) P = _corr._corr_from_cov(P) P = _corr._ensure_valid(P) return P def _raw_from_sigma(sigma_): rho = sigma_[tril_rows, tril_cols] return jnp.arctanh(jnp.clip(rho, -0.999, 0.999)) @jax.jit def _run_mle_steps(opt_arr, adam_state, x_dash): def _copula_nll(arr): n_val = jnn.softplus(arr[0]) + eps g = arr[1:1 + d].reshape((d, 1)) sigma_ = _sigma_from_raw(arr[1 + d:]) copula_p = mvt._params_dict( nu=n_val, mu=mu, gamma=g, sigma=sigma_, ) mvt_ll = mvt.logpdf(x_dash, params=copula_p) uvt_params = { "nu": jnp.full(d, n_val), "mu": jnp.zeros(d), "sigma": jnp.ones(d), "gamma": g.flatten(), } uvt_ll = vmap( lambda xi, pr: uvt.logpdf(xi, params=pr), in_axes=(1, 0), out_axes=1, )(x_dash, uvt_params).sum(axis=1, keepdims=True) logpdf = mvt_ll - uvt_ll n = logpdf.shape[0] finite_mask = jnp.isfinite(logpdf) safe = 
jnp.where(finite_mask, logpdf, 0.0) return -safe.sum() / n + 1e6 * (~finite_mask).sum() / n def _scan_body(carry, _): arr, a_s = carry arr, a_s = _adam_gradient_step(_copula_nll, arr, a_s, lr) return (arr, a_s), None (opt_arr, adam_state), _ = lax.scan( _scan_body, (opt_arr, adam_state), None, length=shape_steps ) return opt_arr, adam_state # --- MoM initialization --- R_inv = jnp.linalg.inv(sigma) nu = jnp.clip(mom_nu_student_t(u, R_inv, d), 2.5, 60.0) n_corr = d * (d - 1) // 2 n_opt = 1 + d + n_corr gamma_init = jnp.zeros((d, 1)) adam_init = (jnp.zeros(n_opt), jnp.zeros(n_opt), jnp.array(0)) _get_x_dash_jit = jax.jit( self.get_x_dash, static_argnames=("brent", "nodes") ) # --- Outer loop as lax.scan --- def _outer_body(carry, _): nu, gamma, sigma, adam_state = carry copula_params = self._mvt._params_dict( nu=nu, mu=mu, gamma=gamma, sigma=sigma, ) full_params = { "marginals": dummy_marginals, "copula": copula_params, } x_dash = _get_x_dash_jit( u, full_params, brent=brent, nodes=nodes ) raw_nu = _inv_softplus(jnp.maximum(nu, eps)) raw_corr = _raw_from_sigma(sigma) opt_arr = jnp.concatenate([ raw_nu.reshape((1,)), gamma.flatten(), raw_corr, ]) adam_state = _reset_adam_state(adam_state) opt_arr, adam_state = _run_mle_steps(opt_arr, adam_state, x_dash) nu = jnn.softplus(opt_arr[0]) + eps gamma = opt_arr[1:1 + d].reshape((d, 1)) sigma = _sigma_from_raw(opt_arr[1 + d:]) return (nu, gamma, sigma, adam_state), None (nu, gamma, sigma, _), _ = lax.scan( _outer_body, (nu, gamma_init, sigma, adam_init), None, length=maxiter, ) return self._mvt._params_dict( nu=nu, mu=mu, gamma=gamma, sigma=sigma, )
# Module-level Skewed-T copula instance, constructed from the multivariate
# (`mvt_skewed_t`) and univariate (`skewed_t`) skewed-t distribution objects
# imported at the top of this module.
skewed_t_copula = SkewedTCopula("Skewed-T-Copula", mvt_skewed_t, skewed_t)