Source code for copulax._src.univariate.lognormal

"""File containing the copulAX implementation of the lognormal distribution."""

import jax.numpy as jnp
from jax import Array
from jax.typing import ArrayLike

from copulax._src._distributions import Univariate
from copulax._src.typing import Scalar
from copulax._src.univariate._utils import _univariate_input
from copulax._src._utils import _resolve_key
from copulax._src.univariate.normal import normal


[docs] class LogNormal(Univariate): r"""The log-normal distribution on :math:`(0, \infty)` describes a positive variate :math:`X` whose logarithm :math:`Y = \log X` is normally distributed with mean :math:`\mu` and standard deviation :math:`\sigma`. Two-parameter continuous family. The PDF is .. math:: f(x | \mu, \sigma) = \frac{1}{x \sigma \sqrt{2\pi}} \exp\!\left(-\frac{(\log x - \mu)^2}{2 \sigma^2}\right), \qquad x > 0 where :math:`\mu \in \mathbb{R}` and :math:`\sigma > 0` are the mean and standard deviation **of the underlying normal** :math:`\log X` (not of :math:`X` itself; the mean of :math:`X` is :math:`\exp(\mu + \sigma^2 / 2)`). https://en.wikipedia.org/wiki/Log-normal_distribution """ mu: Array = None sigma: Array = None def __init__(self, name="LogNormal", *, mu=None, sigma=None): """Initialize the LogNormal distribution. Args: name: Display name for the distribution. mu: Mean of the underlying normal distribution. sigma: Standard deviation of the underlying normal distribution. """ super().__init__(name) self.mu = jnp.asarray(mu, dtype=float).reshape(()) if mu is not None else None self.sigma = ( jnp.asarray(sigma, dtype=float).reshape(()) if sigma is not None else None ) @property def _stored_params(self): """Return stored parameters if all are set, else None.""" if self.mu is None or self.sigma is None: return None return {"mu": self.mu, "sigma": self.sigma} def _params_to_tuple(self, params: dict): """Extract (mu, sigma) from the parameter dictionary.""" return normal._params_to_tuple(params)
[docs] def example_params(self, *args, **kwargs): return normal.example_params()
@classmethod def _support(cls, *args, **kwargs) -> Array: """Return the support ``[0, inf)``.""" return jnp.array([0.0, jnp.inf])
[docs] def logpdf(self, x: ArrayLike, params: dict = None) -> Array: """Compute the log-PDF by transforming to the underlying normal.""" params = self._resolve_params(params) x, xshape = _univariate_input(x) x = x.reshape(xshape) logpdf = normal.logpdf(x=jnp.log(x), params=params) - jnp.log(x) return self._enforce_support_on_logpdf(x=x, logpdf=logpdf, params=params)
[docs] def logcdf(self, x: ArrayLike, params: dict = None) -> Array: """Compute the log-CDF by transforming to the underlying normal.""" params = self._resolve_params(params) return normal.logcdf(x=jnp.log(x), params=params)
[docs] def cdf(self, x: ArrayLike, params: dict = None) -> Array: """Compute the CDF by transforming to the underlying normal.""" params = self._resolve_params(params) cdf = normal.cdf(x=jnp.log(x), params=params) return self._enforce_support_on_cdf(x=x, cdf=cdf, params=params)
# ppf def _ppf(self, q: ArrayLike, params: dict, *args, **kwargs) -> Array: """Compute the PPF as ``exp(normal_ppf(q))``.""" return jnp.exp(normal._ppf(q=q, params=params, *args, **kwargs)) # sampling
[docs] def rvs( self, size: tuple | Scalar, params: dict = None, key: Array = None ) -> Array: """Generate random variates as ``exp(normal_rvs)``.""" params = self._resolve_params(params) key = _resolve_key(key) return jnp.exp(normal.rvs(size=size, key=key, params=params))
# stats
[docs] def stats(self, params: dict = None) -> dict: """Compute distribution statistics (mean, median, mode, variance, std, skewness, kurtosis).""" params = self._resolve_params(params) mu, sigma = self._params_to_tuple(params) mean: float = jnp.exp(mu + jnp.pow(sigma, 2) / 2) median: float = jnp.exp(mu) mode: float = jnp.exp(mu - jnp.pow(sigma, 2)) variance: float = (jnp.exp(jnp.pow(sigma, 2)) - 1) * jnp.exp( 2 * mu + jnp.pow(sigma, 2) ) std: float = jnp.sqrt(variance) skewness: float = (jnp.exp(jnp.pow(sigma, 2)) + 2) * jnp.sqrt( jnp.exp(jnp.pow(sigma, 2)) - 1 ) kurtosis: float = ( jnp.exp(4 * jnp.pow(sigma, 2)) + 2 * jnp.exp(3 * jnp.pow(sigma, 2)) + 3 * jnp.exp(2 * jnp.pow(sigma, 2)) - 6 ) return self._scalar_transform( { "mean": mean, "median": median, "mode": mode, "variance": variance, "std": std, "skewness": skewness, "kurtosis": kurtosis, } )
# fitting _supported_methods = frozenset({"mle"})
[docs] def fit(self, x: ArrayLike, *args, name: str = None, **kwargs): r"""Fit by applying the normal **closed-form** MLE to ``log(x)``. Delegates to :meth:`Normal.fit` on the log-transformed data, which has no tuning parameters. Args: x: Input data to fit (must be positive). name: Optional custom name for the fitted instance. Returns: LogNormal: A fitted ``LogNormal`` instance. """ fitted_normal = normal.fit(jnp.log(x)) return self._fitted_instance(fitted_normal.params, name=name)
lognormal = LogNormal("LogNormal")