"""Correlation and covariance matrix estimation with optional denoising.
Provides 12 correlation estimators (4 base methods x 3 denoising
variants) and corresponding pseudo-covariance estimators, plus random
PSD matrix generation utilities.
Public API:
corr — correlation matrix estimation
cov — covariance matrix estimation
random_correlation — generate a random valid correlation matrix
random_covariance — generate a random valid covariance matrix
"""
import jax.numpy as jnp
from jax import lax, random, jit, vmap
from jax import Array
from jax.typing import ArrayLike
import jax.scipy.stats as stats
import equinox as eqx
from itertools import combinations
from typing import Callable
from copulax._src.univariate._utils import _univariate_input
from copulax._src.typing import Scalar
from copulax._src._utils import _resolve_key
class Correlation(eqx.Module):
r"""Class for computing correlation matrices."""
# Standard correlation matrix implementations
def _ensure_valid(self, A: Array) -> Array:
"""Enforce symmetry and unit diagonal on a correlation matrix."""
lower_triangular: jnp.ndarray = jnp.tril(A)
return jnp.fill_diagonal(
lower_triangular + lower_triangular.T, 1.0, inplace=False
)
def pearson(self, x: jnp.ndarray) -> Array:
r"""Pearson correlation matrix."""
pearson: jnp.ndarray = jnp.corrcoef(x, rowvar=False)
return self._ensure_valid(pearson)
def spearman(self, x: jnp.ndarray) -> Array:
r"""Spearman-rank correlation matrix."""
ranks: jnp.ndarray = stats.rankdata(x, axis=0)
return self.pearson(ranks)
@staticmethod
@jit
def _kendall_pair_vectorized(x_col: jnp.ndarray, y_col: jnp.ndarray) -> Scalar:
r"""Compute Kendall's tau for a single pair of variables.
Uses fully vectorized pairwise concordance via broadcasting,
counting only the upper triangle (:math:`i < j`).
"""
n = x_col.shape[0]
# (n,) -> (n,1) - (1,n) = (n,n) pairwise differences
dx = x_col[:, None] - x_col[None, :]
dy = y_col[:, None] - y_col[None, :]
concordance = jnp.sign(dx) * jnp.sign(dy)
# zero out lower triangle + diagonal, sum upper triangle
mask = jnp.triu(jnp.ones((n, n)), k=1)
return (concordance * mask).sum() * 2.0 / (n * (n - 1))
def kendall(self, x: jnp.ndarray) -> Array:
r"""Kendall's tau correlation matrix.
Vectorized: pairwise concordances are computed via broadcasting
for each dimension pair, then ``vmap`` parallelizes across all
:math:`\binom{d}{2}` pairs.
"""
n, d = x.shape
indices = jnp.array(list(combinations(range(d), 2)))
# Pre-extract column pairs: (num_pairs, n)
cols_i = x[:, indices[:, 0]].T
cols_j = x[:, indices[:, 1]].T
taus = vmap(self._kendall_pair_vectorized)(cols_i, cols_j)
# fill symmetric matrix
kendall = jnp.eye(d)
kendall = kendall.at[indices[:, 0], indices[:, 1]].set(taus)
kendall = kendall.at[indices[:, 1], indices[:, 0]].set(taus)
return self._ensure_valid(kendall)
# Alternative correlation matrix implementations
def pp_kendall(self, x: jnp.ndarray) -> Array:
"""Pseudo-Pearson Kendall correlation matrix.
Note:
This assumes that the data is elliptically distributed and
hence has no skewness. It does however provide a method of
estimating the correlation matrix when variances/covariances
are undefined or infinate.
"""
kendall: jnp.ndarray = self.kendall(x)
pp_kendall: jnp.ndarray = jnp.sin(0.5 * jnp.pi * kendall)
return self._ensure_valid(pp_kendall)
# Rousseeuw and Molenberghs's denoising technique
def _rm_denoising(self, A: jnp.ndarray, delta) -> tuple:
"""Rousseeuw-Molenberghs eigenvalue denoising.
Replaces non-positive eigenvalues with `delta` to ensure
positive semi-definiteness.
Args:
A: Input matrix.
delta: Replacement value for non-positive eigenvalues.
Returns:
Tuple of (clamped eigenvalues, eigenvectors).
"""
eigenvalues, eigenvectors = jnp.linalg.eigh(A)
positive_eigenvalues = jnp.where(eigenvalues > 0.0, eigenvalues, delta)
return positive_eigenvalues.real, eigenvectors.real
def _rm_incomplete(self, A: jnp.ndarray, delta: Scalar) -> Array:
"""Rousseeuw-Molenberghs denoising without enforcing unit diagonal."""
positive_eigenvalues, eigenvectors = self._rm_denoising(A, delta)
new_A: jnp.ndarray = (
eigenvectors @ jnp.diag(positive_eigenvalues) @ eigenvectors.T
)
return new_A
def _rm(self, A: jnp.ndarray, delta: Scalar) -> Array:
"""Full Rousseeuw-Molenberghs denoising with valid correlation output.
Uses diagonal rescaling (Rebonato-Jackel, 1999) to restore unit
diagonal. This is a congruence transformation (D⁻¹AD⁻¹) which
is guaranteed to preserve positive semi-definiteness.
"""
new_A: jnp.ndarray = self._rm_incomplete(A, delta)
return self._corr_from_cov(new_A)
def rm_pearson(self, x: jnp.ndarray, delta: Scalar = 1e-5) -> Array:
"""Denoised Pearson correlation matrix via Rousseeuw-Molenberghs."""
return self._rm(self.pearson(x), delta)
def rm_spearman(self, x: jnp.ndarray, delta: Scalar = 1e-5) -> Array:
"""Denoised Spearman correlation matrix via Rousseeuw-Molenberghs."""
return self._rm(self.spearman(x), delta)
def rm_kendall(self, x: jnp.ndarray, delta: Scalar = 1e-5) -> Array:
"""Denoised Kendall correlation matrix via Rousseeuw-Molenberghs."""
return self._rm(self.kendall(x), delta)
def rm_pp_kendall(self, x: jnp.ndarray, delta: Scalar = 1e-5) -> Array:
"""Denoised pseudo-Pearson Kendall matrix via Rousseeuw-Molenberghs."""
return self._rm(self.pp_kendall(x), delta)
# Laloux et al.'s denoising technique
def _laloux(self, x: jnp.ndarray, A: jnp.ndarray, delta: Scalar) -> Array:
"""Laloux et al. random-matrix-theory denoising.
Eigenvalues inside the Marchenko-Pastur bulk are replaced by
their mean, while signal eigenvalues above the bulk upper
bound are preserved.
Args:
x: Input data of shape (n, d), used to compute Q = n/d.
A: Correlation matrix to denoise.
delta: Floor for non-positive eigenvalues.
Returns:
Denoised correlation matrix.
"""
# performing RM denoising
positive_eigenvalues, eigenvectors = self._rm_denoising(A, delta)
# calculating the Bulk
n, d = x.shape
Q: Scalar = n / d
bulk_ub: Scalar = (1 + jnp.pow(Q, -0.5)) ** 2
# replacing eigenvalues with mean
cond: jnp.ndarray = positive_eigenvalues > bulk_ub
k: Scalar = jnp.sum(cond)
denominator: Scalar = jnp.where(d - k > 0, d - k, 1.0)
fill_val: Scalar = (
jnp.where(~cond, positive_eigenvalues, 0.0).sum() / denominator
)
new_eigenvalues: jnp.ndarray = jnp.where(cond, positive_eigenvalues, fill_val)
# reconstructing the matrix
laloux: jnp.ndarray = (
eigenvectors @ jnp.diag(new_eigenvalues) @ eigenvectors.T
)
return self._corr_from_cov(laloux)
def laloux_pearson(self, x: jnp.ndarray, delta: Scalar = 1e-5) -> Array:
"""Denoised Pearson correlation matrix via Laloux et al."""
return self._laloux(x, self.pearson(x), delta)
def laloux_spearman(self, x: jnp.ndarray, delta: Scalar = 1e-5) -> Array:
"""Denoised Spearman correlation matrix via Laloux et al."""
return self._laloux(x, self.spearman(x), delta)
def laloux_kendall(self, x: jnp.ndarray, delta: Scalar = 1e-5) -> Array:
"""Denoised Kendall correlation matrix via Laloux et al."""
return self._laloux(x, self.kendall(x), delta)
def laloux_pp_kendall(self, x: jnp.ndarray, delta: Scalar = 1e-5) -> Array:
"""Denoised pseudo-Pearson Kendall matrix via Laloux et al."""
return self._laloux(x, self.pp_kendall(x), delta)
# helper functions
def _corr_from_cov(self, C: jnp.ndarray) -> Array:
"""Convert covariance matrix to correlation matrix."""
sigma_inv: jnp.ndarray = 1.0 / jnp.sqrt(jnp.diag(C))
R: jnp.ndarray = C * jnp.outer(sigma_inv, sigma_inv)
return R
def _cov_from_vars(self, vars: jnp.ndarray, R: jnp.ndarray) -> Array:
"""Convert variances and correlation matrix to covariance matrix."""
# calculating the diagonal matrix of standard deviations
sigma_diag: jnp.ndarray = jnp.diag(jnp.sqrt(vars.flatten()))
# returning the implied pseudo covariance matrix
return sigma_diag @ R @ sigma_diag
def _cov_from_corr(self, x: jnp.ndarray, R: jnp.ndarray) -> Array:
"""Convert correlation matrix to covariance matrix."""
# calculating the variances of the input data
vars: jnp.ndarray = jnp.var(x, axis=0, ddof=1)
return self._cov_from_vars(vars=vars, R=R)
_corr: Correlation = Correlation()
[docs]
def corr(x: ArrayLike, method: str = "pearson", **kwargs) -> Array:
r"""Compute the correlation matrix of the input data.
Returns a symmetric, positive semi-definite matrix with unit
diagonal and entries in [-1, 1].
Four base estimators are available, each optionally combined with
one of two eigenvalue-denoising techniques:
**Base estimators:**
- ``'pearson'`` — standard linear (Pearson) correlation.
- ``'spearman'`` — Spearman rank correlation (Pearson applied to
ranks).
- ``'kendall'`` — Kendall's tau, a concordance-based rank
correlation. More robust to outliers than Pearson/Spearman.
- ``'pp_kendall'`` — pseudo-Pearson Kendall: converts Kendall's
tau to Pearson via the elliptical identity
:math:`\rho = \sin(\pi \tau / 2)`. Useful when
variances/covariances are undefined or infinite (e.g. heavy-
tailed elliptical distributions).
**Denoising variants** (prefix + base estimator name):
- ``'rm_*'`` — Rousseeuw-Molenberghs (1993) denoising. Clamps
non-positive eigenvalues to ``delta`` (default 1e-5), then
rescales to restore unit diagonal. Guarantees positive semi-
definiteness. Use when the raw estimator may produce a non-PSD
matrix (e.g. Kendall/Spearman on small samples).
- ``'laloux_*'`` — Laloux et al. (1999) random-matrix-theory
denoising. Eigenvalues inside the Marchenko-Pastur noise bulk
are replaced by their mean; signal eigenvalues above the bulk
upper bound :math:`(1 + \sqrt{d/n})^2` are preserved. Use when
n/d is moderate and you want to separate signal from sampling
noise.
Both denoising methods accept a ``delta`` keyword argument
(default 1e-5) controlling the eigenvalue floor.
Args:
x (ArrayLike): Input data of shape ``(n, d)`` where ``n`` is
the number of observations and ``d`` is the number of
variables.
method (str): Correlation method. One of ``'pearson'``,
``'spearman'``, ``'kendall'``, ``'pp_kendall'``,
``'rm_pearson'``, ``'rm_spearman'``, ``'rm_kendall'``,
``'rm_pp_kendall'``, ``'laloux_pearson'``,
``'laloux_spearman'``, ``'laloux_kendall'``,
``'laloux_pp_kendall'``.
**kwargs: Passed to the underlying method (e.g. ``delta``
for denoised variants).
Returns:
Array: Correlation matrix of shape ``(d, d)``.
Raises:
ValueError: If ``method`` is not a recognised method name.
Note:
If you intend to jit wrap this function, ensure that ``method``
is a static argument.
"""
method: str = method.lower().strip()
func: Callable = getattr(_corr, method, None)
if func is None:
raise ValueError(
f"Unknown correlation method '{method}'."
)
return func(x=x, **kwargs)
[docs]
def cov(x: ArrayLike, method: str = "pearson", **kwargs) -> Array:
r"""Compute the covariance matrix of the input data.
Constructs the covariance matrix as
:math:`\Sigma = D \, R \, D` where :math:`R` is the correlation
matrix from :func:`corr` and :math:`D = \text{diag}(\hat\sigma)`
is the diagonal matrix of sample standard deviations (``ddof=1``).
When ``method='pearson'`` this is equivalent to the standard
sample covariance matrix (i.e. ``numpy.cov(x, rowvar=False)``).
For non-Pearson methods the result is a *pseudo-covariance*:
sample variances combined with an alternative correlation
estimator.
Args:
x (ArrayLike): Input data of shape ``(n, d)`` where ``n`` is
the number of observations and ``d`` is the number of
variables.
method (str): Correlation method passed to :func:`corr`.
See :func:`corr` for available options.
**kwargs: Passed to :func:`corr` (e.g. ``delta`` for denoised
variants).
Returns:
Array: Covariance matrix of shape ``(d, d)``.
Raises:
ValueError: If ``method`` is not a recognised method name.
Note:
If you intend to jit wrap this function, ensure that ``method``
is a static argument.
"""
# calculating correlation matrix
corr_matrix: jnp.ndarray = corr(x=x, method=method, **kwargs)
# returning the implied pseudo covariance matrix
return _corr._cov_from_corr(x=x, R=corr_matrix)
[docs]
def random_correlation(size: int, key: Array = None) -> Array:
r"""Generate a random positive-definite correlation matrix.
Produces a symmetric matrix with unit diagonal, entries in
[-1, 1], and strictly positive eigenvalues. Useful for testing,
simulation, and initialisation of multivariate models.
Uses the factors method: :math:`C = W W^\top + D` where
:math:`W \sim \text{Uniform}(-1, 1)^{d \times d}` and
:math:`D` is diagonal with entries in [0, 1]. The PSD matrix
:math:`C` is then rescaled to a correlation matrix via
:math:`R_{ij} = C_{ij} / \sqrt{C_{ii} C_{jj}}`.
Args:
size (int): Dimension ``d`` of the ``(d, d)`` output matrix.
key (jax.random.PRNGKey, optional): JAX PRNG key. If ``None``,
a key is generated automatically.
Returns:
Array: Random correlation matrix of shape ``(size, size)``.
Note:
If you intend to jit wrap this function, ensure that ``size``
is a static argument.
"""
key = _resolve_key(key)
# generating random covariance matrix
key, subkey = random.split(key)
W: Array = random.uniform(key=key, shape=(size, size), minval=-1.0, maxval=1.0)
D: Array = jnp.diag(
random.uniform(key=subkey, shape=(size,), minval=0.0, maxval=1.0)
)
C: Array = W @ W.T + D
# converting covariance matrix to correlation matrix
R: Array = _corr._corr_from_cov(C=C)
return R
[docs]
def random_covariance(vars: Array, key: Array = None) -> Array:
r"""Generate a random positive-definite covariance matrix with
prescribed variances.
Constructs :math:`\Sigma = D \, R \, D` where :math:`R` is a
random correlation matrix from :func:`random_correlation` and
:math:`D = \text{diag}(\sqrt{\text{vars}})`. The diagonal of the
output equals the input ``vars``.
Args:
vars (Array): Variances of each variable. A 1-d array of
length ``d``; the output shape will be ``(d, d)``.
key (jax.random.PRNGKey, optional): JAX PRNG key. If ``None``,
a key is generated automatically.
Returns:
Array: Random covariance matrix of shape ``(d, d)``.
"""
key = _resolve_key(key)
# we could simply use the same approach as in random_correlation,
# to generate the covariance matrix C. However, whilst this would
# be more efficient and would negate the need for the vars argument,
# the scale of the covariances in C can become large and disjoint
# from any relevant data distribution.
vars, _ = _univariate_input(vars)
R: Array = random_correlation(size=vars.size, key=key)
return _corr._cov_from_vars(vars=vars, R=R)