"""Save and load fitted copulAX distribution and preprocessing objects.
Fitted objects are saved as ``.cpx`` files — ZIP archives containing a
human-readable ``metadata.json`` and binary NumPy ``.npy`` arrays for
each parameter. The format is cross-platform (Windows, macOS, Linux),
requires no additional dependencies, and avoids ``pickle``.
Supported families:
* Univariate, multivariate, and copula distributions (legacy behaviour).
* ``copulax.preprocessing`` objects such as :class:`DataScaler`, which
additionally may carry user-supplied callables on ``pre_fns`` /
``post_fns``. Those callables are serialised by import-path qualname —
lambdas and locally-defined functions are rejected at save time.
"""
import importlib
import io
import json
import warnings
import zipfile
from pathlib import Path
import numpy as np
import jax.numpy as jnp
# ---------------------------------------------------------------------------
# Registry lookup
# ---------------------------------------------------------------------------
def _get_singleton(class_name: str):
"""Look up an unparameterized template singleton by its class name.
Searches the univariate, multivariate, and copula registries in that
order.
Args:
class_name: Python class name (e.g. ``"Normal"``, ``"MvtNormal"``).
Returns:
An unparameterized distribution instance.
Raises:
ValueError: If *class_name* is not found in any registry.
"""
from copulax._src.univariate._registry import _registry as uvt_registry
for dist in uvt_registry:
if type(dist).__name__ == class_name:
return dist
from copulax._src.multivariate._registry import _registry as mvt_registry
for dist in mvt_registry:
if type(dist).__name__ == class_name:
return dist
from copulax._src.copulas._registry import _registry as cop_registry
for dist in cop_registry:
if type(dist).__name__ == class_name:
return dist
raise ValueError(
f"Unknown distribution class: {class_name!r}. "
"Ensure the distribution is registered in copulax."
)
# ---------------------------------------------------------------------------
# Save
# ---------------------------------------------------------------------------
def _save_distribution(dist, path) -> None:
"""Serialize a fitted distribution to a ``.cpx`` file.
Args:
dist: A fitted ``Distribution`` instance (must have stored
parameters).
path: Destination file path. The ``.cpx`` extension is appended
automatically when missing.
Raises:
ValueError: If the distribution has no stored parameters.
"""
path = Path(path)
if path.suffix != ".cpx":
path = path.with_suffix(path.suffix + ".cpx")
params = dist._stored_params
if params is None:
raise ValueError(
"Cannot save an unfitted distribution (no parameters set). "
"Fit the distribution first via .fit()."
)
metadata: dict = {
"dist_family": dist.dist_type,
"dist_dtype": dist.dtype,
"dist_class": type(dist).__name__,
"dist_name": dist.name,
}
arrays: dict[str, np.ndarray] = {}
if dist.dist_type in ("univariate", "multivariate"):
metadata["params"] = {}
for key, val in params.items():
arr = np.asarray(val)
arrays[key] = arr
metadata["params"][key] = {
"shape": list(arr.shape),
"dtype": str(arr.dtype),
}
elif dist.dist_type == "copula":
from copulax._src.copulas._mv_copulas import (
EllipticalCopula,
MeanVarianceCopula,
MeanVarianceCopulaBase,
)
# Tag with the most specific known taxonomic category. The
# serialiser-side reader only needs to distinguish "is this a
# mean-variance / elliptical-style copula" from "Archimedean",
# but we record the finer grain too for forward compatibility.
if isinstance(dist, EllipticalCopula):
metadata["copula_type"] = "elliptical"
elif isinstance(dist, MeanVarianceCopula):
metadata["copula_type"] = "mean_variance"
elif isinstance(dist, MeanVarianceCopulaBase):
# Future MV-base subclass that's neither Elliptical nor
# MeanVariance — fall back to the umbrella label.
metadata["copula_type"] = "mean_variance_base"
else:
metadata["copula_type"] = "archimedean"
# Copula parameters
copula_params = params["copula"]
metadata["copula_params"] = {}
for key, val in copula_params.items():
arr = np.asarray(val)
arrays[f"copula__{key}"] = arr
metadata["copula_params"][key] = {
"shape": list(arr.shape),
"dtype": str(arr.dtype),
}
# Marginal distributions
marginals = params["marginals"]
metadata["marginals"] = []
for i, (marginal_dist, marginal_params) in enumerate(marginals):
m_meta: dict = {
"dist_class": type(marginal_dist).__name__,
"params": {},
}
for key, val in marginal_params.items():
arr = np.asarray(val)
arrays[f"marginal_{i}__{key}"] = arr
m_meta["params"][key] = {
"shape": list(arr.shape),
"dtype": str(arr.dtype),
}
metadata["marginals"].append(m_meta)
else:
raise ValueError(f"Unsupported dist_type: {dist.dist_type!r}")
# Write ZIP archive
with zipfile.ZipFile(path, "w", zipfile.ZIP_DEFLATED) as zf:
zf.writestr("metadata.json", json.dumps(metadata, indent=2))
for name, arr in arrays.items():
buf = io.BytesIO()
np.save(buf, arr)
zf.writestr(f"arrays/{name}.npy", buf.getvalue())
# ---------------------------------------------------------------------------
# Callable serialisation (for preprocessing objects with pre_fns/post_fns)
# ---------------------------------------------------------------------------
def _serialise_callable(fn):
"""Serialise a callable by its import path.
Returns a ``{"module", "qualname"}`` dict, or ``None`` when *fn* is
``None``. Raises :class:`ValueError` when *fn* cannot be round-tripped
via ``importlib`` (lambdas, nested / locally-defined functions, or
callables whose qualname does not resolve back to the same object).
A :class:`UserWarning` is emitted when *fn* lives in ``__main__`` —
the save itself succeeds, but reloading in a different session
requires an identically-named callable to be defined in that
session's ``__main__`` too.
"""
if fn is None:
return None
qn = getattr(fn, "__qualname__", None)
mod = getattr(fn, "__module__", None)
if qn is None or mod is None:
raise ValueError(
f"Cannot serialise callable {fn!r}: missing __module__ or "
"__qualname__. Use a plain module-level function."
)
if "<lambda>" in qn:
raise ValueError(
f"Cannot serialise lambda {fn!r}: lambdas have no stable import "
"path. Use a named module-level function instead, or save the "
"scaler after removing pre_fns/post_fns."
)
if "<locals>" in qn:
raise ValueError(
f"Cannot serialise locally-defined callable {fn!r} (qualname "
f"{qn!r}): nested / closure-defined functions cannot be "
"round-tripped by qualname. Move the function to module scope."
)
try:
module = importlib.import_module(mod)
except ImportError as exc:
raise ValueError(
f"Cannot serialise callable {fn!r}: module {mod!r} is not "
f"importable ({exc})."
) from exc
resolved = module
for part in qn.split("."):
try:
resolved = getattr(resolved, part)
except AttributeError as exc:
raise ValueError(
f"Cannot serialise callable {fn!r}: qualname {qn!r} does "
f"not resolve under module {mod!r} ({exc})."
) from exc
if resolved is not fn:
raise ValueError(
f"Cannot serialise callable {fn!r}: qualname {mod}.{qn} "
"resolves to a different object. This happens when a function "
"has been monkey-patched, redefined, or wrapped after import."
)
if mod == "__main__":
warnings.warn(
f"Serialising callable {qn!r} from __main__. The save will "
"succeed, but loading in a different session will only work "
f"if an identically-named callable {qn!r} is defined in that "
"session's __main__ (e.g. the same script is re-run). For "
"portable loading, move the function to an importable module.",
UserWarning,
stacklevel=3,
)
return {"module": mod, "qualname": qn}
def _serialise_fn_pair(fns):
"""Serialise a ``(forward, inverse)`` tuple. Returns a list or ``None``."""
if fns is None:
return None
return [_serialise_callable(fns[0]), _serialise_callable(fns[1])]
def _deserialise_callable(entry):
"""Reverse of :func:`_serialise_callable`. ``None`` passes through."""
if entry is None:
return None
module = importlib.import_module(entry["module"])
obj = module
for part in entry["qualname"].split("."):
obj = getattr(obj, part)
return obj
def _deserialise_fn_pair(entry):
"""Reverse of :func:`_serialise_fn_pair`. Returns a tuple or ``None``."""
if entry is None:
return None
return (_deserialise_callable(entry[0]), _deserialise_callable(entry[1]))
# ---------------------------------------------------------------------------
# Save — preprocessing objects (DataScaler)
# ---------------------------------------------------------------------------
def _save_scaler(scaler, path) -> None:
"""Serialise a fitted preprocessing object to a ``.cpx`` file.
Currently only :class:`~copulax.preprocessing.DataScaler` is
supported. The file contains a ``metadata.json`` describing the
static configuration (method, quantile bounds, mode flags, function
pair qualnames) plus ``arrays/offset.npy`` and ``arrays/scale.npy``
holding the fitted parameters.
Args:
scaler: A fitted :class:`DataScaler` instance.
path: Destination file path. The ``.cpx`` extension is appended
automatically when missing.
Raises:
ValueError: If the scaler has not been fitted, or if any
``pre_fns`` / ``post_fns`` callable cannot be serialised by
qualname (lambdas, closures, etc.).
"""
if not scaler.is_fitted:
raise ValueError(
"Cannot save an unfitted DataScaler (offset/scale are None). "
"Call .fit(x) first."
)
path = Path(path)
if path.suffix != ".cpx":
path = path.with_suffix(path.suffix + ".cpx")
offset_arr = np.asarray(scaler.offset)
scale_arr = np.asarray(scaler.scale)
metadata = {
"dist_family": "preprocessing",
"scaler_class": type(scaler).__name__,
"method": scaler.method,
"q_low": scaler.q_low,
"q_high": scaler.q_high,
"offset_only": scaler.offset_only,
"scale_only": scaler.scale_only,
"pre_fns": _serialise_fn_pair(scaler.pre_fns),
"post_fns": _serialise_fn_pair(scaler.post_fns),
"arrays": {
"offset": {"shape": list(offset_arr.shape), "dtype": str(offset_arr.dtype)},
"scale": {"shape": list(scale_arr.shape), "dtype": str(scale_arr.dtype)},
},
}
with zipfile.ZipFile(path, "w", zipfile.ZIP_DEFLATED) as zf:
zf.writestr("metadata.json", json.dumps(metadata, indent=2))
for name, arr in (("offset", offset_arr), ("scale", scale_arr)):
buf = io.BytesIO()
np.save(buf, arr)
zf.writestr(f"arrays/{name}.npy", buf.getvalue())
# ---------------------------------------------------------------------------
# Load
# ---------------------------------------------------------------------------
[docs]
def load(path, name: str = None):
"""Load a fitted distribution from a ``.cpx`` file.
Args:
path: Path to the ``.cpx`` file.
name: Optional name for the loaded instance. When ``None`` the
name saved in the file is used.
Returns:
A fitted ``Distribution`` instance.
Raises:
FileNotFoundError: If *path* does not exist.
ValueError: If the file contains an unknown distribution class.
"""
path = Path(path)
if not path.exists():
raise FileNotFoundError(f"No such file: {path}")
with zipfile.ZipFile(path, "r") as zf:
metadata = json.loads(zf.read("metadata.json"))
def _read_array(array_name: str) -> jnp.ndarray:
buf = io.BytesIO(zf.read(f"arrays/{array_name}.npy"))
return jnp.asarray(np.load(buf))
dist_family = metadata["dist_family"]
if dist_family in ("univariate", "multivariate"):
dist_class = metadata["dist_class"]
dist_name = name if name is not None else metadata["dist_name"]
template = _get_singleton(dist_class)
params = {
key: _read_array(key) for key in metadata["params"]
}
return template._fitted_instance(params, name=dist_name)
elif dist_family == "copula":
dist_class = metadata["dist_class"]
dist_name = name if name is not None else metadata["dist_name"]
template = _get_singleton(dist_class)
# Copula parameters
copula_params = {
key: _read_array(f"copula__{key}")
for key in metadata["copula_params"]
}
# Marginal distributions
marginals = []
for i, m_meta in enumerate(metadata["marginals"]):
m_template = _get_singleton(m_meta["dist_class"])
m_params = {
key: _read_array(f"marginal_{i}__{key}")
for key in m_meta["params"]
}
marginals.append((m_template, m_params))
params = {
"marginals": tuple(marginals),
"copula": copula_params,
}
return template._fitted_instance(params, name=dist_name)
elif dist_family == "preprocessing":
# Local import to avoid a circular dependency at module import
# time (preprocessing imports nothing from _serialization, but
# _serialization must reference the class by name).
from copulax.preprocessing import DataScaler
scaler_class = metadata.get("scaler_class", "DataScaler")
if scaler_class != "DataScaler":
raise ValueError(
f"Unknown preprocessing scaler class: {scaler_class!r}."
)
offset = _read_array("offset")
scale = _read_array("scale")
return DataScaler(
method=metadata["method"],
q_low=metadata["q_low"],
q_high=metadata["q_high"],
offset_only=metadata["offset_only"],
scale_only=metadata["scale_only"],
pre_fns=_deserialise_fn_pair(metadata.get("pre_fns")),
post_fns=_deserialise_fn_pair(metadata.get("post_fns")),
offset=offset,
scale=scale,
)
else:
raise ValueError(
f"Unknown dist_family in metadata: {dist_family!r}"
)