Source code for copulax._src.univariate.uniform
"""File containing the copulAX implementation of the continuous uniform
distribution."""
import jax.numpy as jnp
from jax import lax, random
from jax import Array
from jax.typing import ArrayLike
from copulax._src._distributions import Univariate
from copulax._src.univariate._utils import _univariate_input
from copulax._src._utils import _resolve_key
from copulax._src.typing import Scalar
[docs]
class Uniform(Univariate):
r"""The continuous uniform distribution.
The continuous uniform distribution is defined as:
.. math::
f(x|a, b) = \frac{1}{b - a}
where :math:`a` is the lower bound of the distribution and :math:`b` is the
upper bound.
"""
a: Array = None
b: Array = None
def __init__(self, name="Uniform", *, a=None, b=None):
"""Initialize the Uniform distribution.
Args:
name: Display name for the distribution.
a: Lower bound of the interval.
b: Upper bound of the interval.
"""
super().__init__(name)
self.a = jnp.asarray(a, dtype=float).reshape(()) if a is not None else None
self.b = jnp.asarray(b, dtype=float).reshape(()) if b is not None else None
@property
def _stored_params(self):
"""Return stored parameters if all are set, else None."""
if self.a is None or self.b is None:
return None
return {"a": self.a, "b": self.b}
@classmethod
def _params_dict(cls, a: Scalar, b: Scalar) -> dict:
"""Create a parameter dictionary from lower bound ``a`` and upper bound ``b``."""
d: dict = {"a": a, "b": b}
return cls._args_transform(d)
@classmethod
def _params_to_tuple(cls, params: dict) -> tuple:
"""Extract (a, b) from the parameter dictionary."""
params = cls._args_transform(params)
return params["a"], params["b"]
@classmethod
def _support(cls, params: dict) -> Array:
"""Return the support ``[a, b]`` from the given parameters."""
a, b = cls._params_to_tuple(params)
return jnp.array([a, b])
[docs]
def logpdf(self, x: ArrayLike, params: dict = None) -> Array:
"""Compute the log probability density function.
Returns ``log(1 / (b - a))`` inside the support, ``-inf`` outside.
"""
params = self._resolve_params(params)
x, xshape = _univariate_input(x)
a, b = self._params_to_tuple(params)
log_pdf: jnp.ndarray = -jnp.log(lax.sub(b, a))
log_pdf = jnp.where(jnp.logical_and(x >= a, x <= b), log_pdf, -jnp.inf)
return log_pdf.reshape(xshape)
[docs]
def logcdf(self, x: ArrayLike, params: dict = None) -> Array:
"""Compute the log cumulative distribution function."""
params = self._resolve_params(params)
return jnp.log(self.cdf(x=x, params=params))
[docs]
def cdf(self, x: ArrayLike, params: dict = None) -> Array:
"""Compute the cumulative distribution function."""
params = self._resolve_params(params)
x, xshape = _univariate_input(x)
a, b = self._params_to_tuple(params)
cdf: jnp.ndarray = (x - a) / (b - a)
return self._enforce_support_on_cdf(
x=x, cdf=cdf.reshape(xshape), params=params
)
# ppf
def _ppf(self, q: ArrayLike, params: dict, *args, **kwargs) -> Array:
"""Compute the percent-point function (inverse CDF) via linear interpolation."""
q, qshape = _univariate_input(q)
a, b = self._params_to_tuple(params)
ppf_values: jnp.ndarray = lax.add(a, lax.mul(q, lax.sub(b, a)))
ppf_values = jnp.where(jnp.logical_and(q >= 0, q <= 1), ppf_values, jnp.nan)
return ppf_values.reshape(qshape)
# sampling
[docs]
def rvs(self, size: tuple | Scalar, params: dict = None, key=None) -> Array:
"""Generate random variates from the uniform distribution.
Args:
size: Shape of the output array.
params: Distribution parameters. Uses stored parameters if None.
key: JAX PRNG key. A default key is used if None.
Returns:
Array of random samples in ``[a, b)``.
"""
params = self._resolve_params(params)
key = _resolve_key(key)
a, b = self._params_to_tuple(params)
return random.uniform(key=key, shape=size, minval=a, maxval=b)
# stats
[docs]
def stats(self, params: dict = None) -> dict:
"""Compute distribution statistics (mean, median, variance, std, skewness, kurtosis)."""
params = self._resolve_params(params)
a, b = self._params_to_tuple(params)
mean: Scalar = (a + b) / 2
variance: Scalar = lax.pow(b - a, 2) / 12
std: Scalar = jnp.sqrt(variance)
return self._scalar_transform(
{
"mean": mean,
"median": mean,
"variance": variance,
"std": std,
"skewness": 0.0,
"kurtosis": -6 / 5,
}
)
# fitting
_supported_methods = frozenset({"mle"})
[docs]
def fit(self, x: ArrayLike, *args, name: str = None, **kwargs):
r"""Fit the distribution to data via **closed-form** MLE:
``â = min(x)``, ``b̂ = max(x)``.
The closed-form estimator takes no tuning parameters.
Args:
x: Input data to fit.
name: Optional custom name for the fitted instance.
Returns:
Uniform: A fitted ``Uniform`` instance.
"""
x: jnp.ndarray = _univariate_input(x)[0]
a: Scalar = jnp.min(x)
b: Scalar = jnp.max(x)
return self._fitted_instance(self._params_dict(a=a, b=b), name=name)
uniform = Uniform("Uniform")