Source code for copulax._src.univariate.univariate_fitter

"""contains the copulAX implementation of a univariate fitter object."""

from functools import lru_cache, partial
from typing import Iterable

import jax.numpy as jnp
from jax import jit, lax, vmap

from copulax._src.univariate._registry import _dist_tree, _registry
from copulax._src.typing import Scalar
from copulax._src._distributions import Univariate
from copulax._src.univariate._gof import ks_test, cvm_test


# Goodness-of-fit test callables, keyed by the name accepted for ``gof_test``.
_GOF_FUNCS = {"ks": ks_test, "cvm": cvm_test}

# Module-level alias of the imported distribution registry.
_DIST_REGISTRY: tuple = _registry
# Number of registered distributions; the fitters below size their index and
# mask arrays to this length.
_MAX_DISTS: int = len(_DIST_REGISTRY)
# Largest number of entries returned by ``example_params()`` across the
# registry; passed as ``max_params`` when padding parameter arrays.
_MAX_PARAMS: int = max(len(d.example_params()) for d in _DIST_REGISTRY)
# Reverse lookup: distribution ``name`` -> its position in ``_DIST_REGISTRY``.
_DIST_NAME_TO_INDEX: dict = {d.name: i for i, d in enumerate(_DIST_REGISTRY)}


def _get_dist_objects(dists: Iterable | str) -> tuple:
    """Resolve distribution specifier (string or iterable) to a tuple of Univariate objects."""
    if isinstance(dists, str):
        dists: str = dists.lower().strip()
        if dists not in (
            "all",
            "common",
            "continuous",
            "discrete",
            "common continuous",
            "common discrete",
        ):
            raise ValueError(
                f"Invalid value for 'dists' argument: {dists}."
                "If a string, dists must be one of 'all', "
                "'common', 'continuous', 'discrete', "
                "'common continuous' or 'common disrete'."
            )

        elif dists == "all":
            dists_objs: tuple = (
                *_dist_tree["continuous"].values(),
                *_dist_tree["discrete"].values(),
            )
        elif dists in ("common continuous", "common discrete"):
            dists_objs: tuple = tuple(_dist_tree["common"][dists.split()[-1]].values())
        elif dists == "common":
            dists_objs: tuple = tuple(
                (
                    *_dist_tree["common"]["continuous"].values(),
                    *_dist_tree["common"]["discrete"].values(),
                )
            )
        else:
            dists_objs: tuple = tuple(_dist_tree[dists].values())

    elif isinstance(dists, Iterable):
        dists_objs: tuple = tuple(dists)
        for dist in dists:
            if not isinstance(dist, Univariate):
                raise ValueError(
                    f"Invalid distribution object provided "
                    f"within 'dists' iterable: {dist}. "
                    f"Distribution objects must be univariate "
                    "copulax distributions."
                )
    else:
        raise ValueError(
            f"Invalid value for 'dists' argument: {dists}. "
            "Dists must be a string or an iterable of "
            "copulAX distribution objects."
        )

    return dists_objs


def _dist_to_indices(dists_objs: tuple) -> jnp.ndarray:
    """Map distribution objects to their integer indices in _DIST_REGISTRY."""
    return jnp.array([_DIST_NAME_TO_INDEX[d.name] for d in dists_objs], dtype=jnp.int32)


# ── Branch factory ─────────────────────────────────────────────────────────
@lru_cache(maxsize=None)
def _make_branches(metric: str, gof_test: str | None) -> tuple:
    """Build one branch function per registered distribution.

    Each branch has an identical return pytree so that ``lax.switch`` can
    dispatch across them.  The *metric* and *gof_test* strings are
    captured in closures (resolved at trace time).

    The result is memoised per ``(metric, gof_test)`` pair.  Callers pass
    the branches to ``jit`` as a static argument, and functions hash by
    identity: building fresh closures on every call would make each fitter
    call miss the JIT cache and re-trace.  Returning the same tuple object
    for the same arguments lets the compiled graph be reused.

    Args:
        metric: Name of the distribution method used to score a fit; it is
            looked up with ``getattr`` and called as
            ``method(x=x, params=params)``.
        gof_test: Key into ``_GOF_FUNCS``, or ``None`` for no test.

    Returns:
        Tuple of callables, one per entry of ``_DIST_REGISTRY`` and in the
        same order.  Each maps ``x`` to
        ``(params_arr, metric_val, gof_stat, gof_pval)``; the two GoF
        entries are NaN when no test is selected.
    """
    gof_func = _GOF_FUNCS.get(gof_test)

    def _build_branch(dist):
        # A factory call binds `dist` per branch (avoids late binding).
        def _branch(x):
            params = dist.fit(x).params
            params_arr = dist._padded_params_to_array(params, max_params=_MAX_PARAMS)
            metric_val = getattr(dist, metric)(x=x, params=params)

            if gof_func is None:
                # Placeholders keep the return pytree identical across modes.
                return params_arr, metric_val, jnp.nan, jnp.nan

            gof_result = gof_func(x=x, dist=dist, params=params)
            return (
                params_arr,
                metric_val,
                gof_result["statistic"],
                gof_result["p_value"],
            )

        return _branch

    # A tuple (not a list) so the cached value is immutable and hashable.
    return tuple(_build_branch(dist) for dist in _DIST_REGISTRY)


# ── Core implementation ────────────────────────────────────────────────────
def _core_impl(
    x, dist_indices, active_mask, significance_level, branches, ascending, has_gof
):
    """Fit all distributions, score, filter and rank — fully on-device.

    Not JIT-decorated so it can be composed with ``vmap``.

    Args:
        x: Input data array, shape (n,).
        dist_indices: Integer indices into _DIST_REGISTRY, shape (MAX_DISTS,).
        active_mask: Boolean mask, shape (MAX_DISTS,).
        significance_level: GoF p-value threshold.
        branches: Tuple of branch callables (static).
        ascending: Whether lower metric is better (static).
        has_gof: Whether GoF filtering is active (static).

    Returns:
        Tuple of (sorted_order, params_arrs, metrics, gof_stats,
        gof_pvals, final_mask, n_pass).
    """

    def _fit_one(idx):
        return lax.switch(idx, branches, x)

    # lax.map applies _fit_one sequentially over the leading axis
    all_results = lax.map(_fit_one, dist_indices)
    params_arrs, metrics, gof_stats, gof_pvals = all_results

    # GoF filtering
    if has_gof:
        gof_passed = gof_pvals >= significance_level
    else:
        gof_passed = jnp.ones(_MAX_DISTS, dtype=jnp.bool_)

    final_mask = active_mask & gof_passed & jnp.isfinite(metrics)

    # Fill inactive / failed slots with sentinel for sorting
    sentinel = jnp.where(ascending, jnp.inf, -jnp.inf)
    scored = jnp.where(final_mask, metrics, sentinel)

    # Sort (best first)
    if ascending:
        order = jnp.argsort(scored)
    else:
        order = jnp.argsort(-scored)
    n_pass = jnp.sum(final_mask)

    return order, params_arrs, metrics, gof_stats, gof_pvals, final_mask, n_pass


# ── JIT core (single variable) ────────────────────────────────────────────
@partial(jit, static_argnames=("branches", "ascending", "has_gof"))
def _jit_core(
    x, dist_indices, active_mask, significance_level, branches, ascending, has_gof
):
    """Single-variable entry point: ``_core_impl`` compiled with ``jit``.

    ``branches``, ``ascending`` and ``has_gof`` are static arguments; all
    arguments are forwarded to ``_core_impl`` unchanged and its result is
    returned as-is.
    """
    return _core_impl(
        x=x,
        dist_indices=dist_indices,
        active_mask=active_mask,
        significance_level=significance_level,
        branches=branches,
        ascending=ascending,
        has_gof=has_gof,
    )


# ── JIT core (batched across variables) ───────────────────────────────────
@partial(jit, static_argnames=("branches", "ascending", "has_gof"))
def _batched_jit_core(
    x_batch,
    dist_indices,
    active_mask,
    significance_level,
    branches,
    ascending,
    has_gof,
):
    """Run :func:`_core_impl` on every row of *x_batch* via ``vmap``.

    *x_batch* has shape ``(d, n)`` and is mapped over its leading axis;
    every other argument is shared by all rows.  Each returned array gains
    a leading axis of length ``d``.
    """
    return vmap(
        lambda row: _core_impl(
            row,
            dist_indices,
            active_mask,
            significance_level,
            branches,
            ascending,
            has_gof,
        )
    )(x_batch)


# ── Public API ─────────────────────────────────────────────────────────────
def univariate_fitter(
    x: jnp.ndarray,
    metric: str = "bic",
    distributions: Iterable | str = "common continuous",
    gof_test: str | None = None,
    significance_level: float = 0.05,
) -> tuple:
    r"""Find and fit the 'best' univariate distribution to the input data
    according to a specified metric.

    A single ``jax.jit``-traced function fits the distributions via
    ``lax.switch``, computes the chosen metric, optionally applies a
    goodness-of-fit filter, and sorts the results on-device. The ranked
    results are then rebuilt as Python dicts on the host.

    Args:
        x (ArrayLike): The input data to fit a distribution to.
        metric (str): The metric to use when selecting the 'best'
            distribution. Must be one of 'aic', 'bic' or 'loglikelihood'.
            Default is 'bic'.
        distributions (Iterable | str): The distribution(s) to fit to the
            data. If a string, must be one of 'all', 'common',
            'continuous', 'discrete', 'common continuous' or
            'common discrete'. If an iterable, must contain copulAX
            distribution objects. Default is 'common continuous'.
        gof_test (str | None): Optional goodness-of-fit test to apply after
            fitting. One of 'ks' (Kolmogorov-Smirnov), 'cvm'
            (Cramér-von Mises), or None (no test). When set, distributions
            that fail the test at the given *significance_level* are
            removed from the results. Default is None.
        significance_level (float): Significance level for the
            goodness-of-fit test. Distributions with a p-value below this
            threshold are removed. Only used when *gof_test* is not None.
            Default is 0.05.

    Returns:
        res (tuple): The index of the best distribution fit (always 0) and
        a tuple of fitted distribution results sorted by the metric (best
        first). Each result is a dict with keys 'params', 'metric', 'dist',
        and optionally 'gof'. Returns ``(None, ())`` when no distribution
        survives, i.e. none was requested, every metric is non-finite, or
        all are filtered out by the goodness-of-fit test.

    Raises:
        ValueError: If *metric*, *gof_test* or *distributions* is invalid.

    Examples:
        >>> import jax.numpy as jnp
        >>> import numpy as np
        >>> from copulax.univariate import univariate_fitter
        >>> x = np.random.normal(0, 1, 100)
        >>> univariate_fitter(x)
        >>> univariate_fitter(x, gof_test='ks', significance_level=0.05)
    """
    # ── Validation (Python-side, not traced) ──
    dists_objs = _get_dist_objects(distributions)

    if metric not in ("aic", "bic", "loglikelihood"):
        raise ValueError(
            f"Invalid value for 'metric' argument: {metric}. "
            "Must be one of 'aic', 'bic' or 'loglikelihood'."
        )

    if gof_test is not None and gof_test not in _GOF_FUNCS:
        raise ValueError(
            f"Invalid value for 'gof_test' argument: {gof_test}. "
            "Must be one of 'ks', 'cvm' or None."
        )

    # Log-likelihood is the only metric where larger is better.
    ascending = metric != "loglikelihood"
    has_gof = gof_test is not None
    n_active = len(dists_objs)

    # ── Map distributions → fixed-size index & mask arrays ──
    raw_indices = _dist_to_indices(dists_objs)
    # Pad to _MAX_DISTS with index 0 (inactive slots are masked out)
    dist_indices = jnp.zeros(_MAX_DISTS, dtype=jnp.int32).at[:n_active].set(raw_indices)
    active_mask = jnp.arange(_MAX_DISTS) < n_active

    # ── Build branches & call JIT core ──
    branches = tuple(_make_branches(metric, gof_test))

    order, params_arrs, metrics, gof_stats, gof_pvals, final_mask, n_pass = _jit_core(
        x=jnp.asarray(x, dtype=float),
        dist_indices=dist_indices,
        active_mask=active_mask,
        significance_level=jnp.asarray(significance_level, dtype=float),
        branches=branches,
        ascending=ascending,
        has_gof=has_gof,
    )

    # ── Reconstruct Python result dicts from JIT output ──
    # n_pass counts the surviving slots. With none left there is no "best"
    # fit, so index 0 would point into an empty tuple.
    if int(n_pass) == 0:
        return None, ()

    # Pull the small bookkeeping arrays to the host once rather than
    # converting one element at a time inside the loop.
    order_host = order.tolist()
    kept_host = final_mask.tolist()
    registry_idx_host = dist_indices.tolist()

    output = []
    for idx in order_host:
        if not kept_host[idx]:
            continue

        dist = _DIST_REGISTRY[registry_idx_host[idx]]
        keys = tuple(dist.example_params().keys())
        # Only the first len(keys) columns of the padded row are used.
        params = dist._args_transform(
            {k: params_arrs[idx, j] for j, k in enumerate(keys)}
        )

        result = {
            "params": params,
            "metric": metrics[idx],
            "dist": dist,
        }
        if has_gof:
            result["gof"] = {
                "statistic": gof_stats[idx],
                "p_value": gof_pvals[idx],
            }
        output.append(result)

    return 0, tuple(output)
# ── Batched public API ─────────────────────────────────────────────────────
def batch_univariate_fitter(
    x: jnp.ndarray,
    metric: str = "bic",
    distributions: Iterable | str = "common continuous",
    gof_test: str | None = None,
    significance_level: float = 0.05,
) -> list[tuple]:
    r"""Fit univariate distributions to every column of *x* simultaneously.

    Equivalent to calling :func:`univariate_fitter` on each column, but
    uses ``jax.vmap`` to process all columns in a **single device call**.

    Args:
        x (ArrayLike): Input data of shape ``(n, d)``.
        metric (str): Selection metric — ``'aic'``, ``'bic'``, or
            ``'loglikelihood'``. Default ``'bic'``.
        distributions (Iterable | str): Distributions to try (same as
            :func:`univariate_fitter`).
        gof_test (str | None): Optional goodness-of-fit test (``'ks'``,
            ``'cvm'``, or ``None``).
        significance_level (float): GoF p-value threshold (default 0.05).

    Returns:
        list[tuple]: One ``(best_index, fitted)`` tuple per column, in the
        same format as :func:`univariate_fitter`. A column for which no
        distribution survives yields ``(None, ())``.

    Raises:
        ValueError: If *metric*, *gof_test* or *distributions* is invalid,
            or if *x* is not two-dimensional.
    """
    # ── Shared validation (done once) ──
    dists_objs = _get_dist_objects(distributions)

    if metric not in ("aic", "bic", "loglikelihood"):
        raise ValueError(
            f"Invalid value for 'metric' argument: {metric}. "
            "Must be one of 'aic', 'bic' or 'loglikelihood'."
        )

    if gof_test is not None and gof_test not in _GOF_FUNCS:
        raise ValueError(
            f"Invalid value for 'gof_test' argument: {gof_test}. "
            "Must be one of 'ks', 'cvm' or None."
        )

    x_arr = jnp.asarray(x, dtype=float)
    if x_arr.ndim != 2:
        # Fail early with a clear message rather than deep inside vmap.
        raise ValueError(
            f"Invalid shape for 'x' argument: {x_arr.shape}. "
            "x must be a two-dimensional array of shape (n, d)."
        )

    # Log-likelihood is the only metric where larger is better.
    ascending = metric != "loglikelihood"
    has_gof = gof_test is not None
    n_active = len(dists_objs)

    raw_indices = _dist_to_indices(dists_objs)
    # Pad to _MAX_DISTS with index 0 (inactive slots are masked out)
    dist_indices = jnp.zeros(_MAX_DISTS, dtype=jnp.int32).at[:n_active].set(raw_indices)
    active_mask = jnp.arange(_MAX_DISTS) < n_active

    branches = tuple(_make_branches(metric, gof_test))

    # ── Single batched device call ──
    x_batch = x_arr.T  # (d, n)

    (
        orders,
        params_arrs_all,
        metrics_all,
        gof_stats_all,
        gof_pvals_all,
        final_masks,
        n_passes,
    ) = _batched_jit_core(
        x_batch=x_batch,
        dist_indices=dist_indices,
        active_mask=active_mask,
        significance_level=jnp.asarray(significance_level, dtype=float),
        branches=branches,
        ascending=ascending,
        has_gof=has_gof,
    )

    # ── Reconstruct per-column Python result dicts ──
    # Pull the small bookkeeping arrays to the host once rather than
    # converting one element at a time inside the nested loops.
    orders_host = orders.tolist()
    kept_host = final_masks.tolist()
    n_pass_host = n_passes.tolist()
    registry_idx_host = dist_indices.tolist()

    results = []
    for dim, (order, kept) in enumerate(zip(orders_host, kept_host)):
        # With no surviving slot there is no "best" fit for this column,
        # so index 0 would point into an empty tuple.
        if n_pass_host[dim] == 0:
            results.append((None, ()))
            continue

        params_arrs = params_arrs_all[dim]
        metrics = metrics_all[dim]

        output = []
        for idx in order:
            if not kept[idx]:
                continue

            dist = _DIST_REGISTRY[registry_idx_host[idx]]
            keys = tuple(dist.example_params().keys())
            # Only the first len(keys) columns of the padded row are used.
            params = dist._args_transform(
                {k: params_arrs[idx, j] for j, k in enumerate(keys)}
            )

            result = {
                "params": params,
                "metric": metrics[idx],
                "dist": dist,
            }
            if has_gof:
                result["gof"] = {
                    "statistic": gof_stats_all[dim, idx],
                    "p_value": gof_pvals_all[dim, idx],
                }
            output.append(result)

        results.append((0, tuple(output)))

    return results