Preprocessing
CopulAX provides a small set of jittable, autograd-compatible
preprocessing objects that compose cleanly with the rest of the
library. All preprocessors are equinox.Module PyTrees — their
fitted parameters are traced JAX arrays, so they can be passed through
jax.jit, jax.grad, jax.vmap, and equinox serialisation.
Data Scaling
DataScaler fits an affine rescaling to
input data and exposes transform / inverse_transform for later
observations. Four methods are supported (z-score, min-max, robust,
max-abs), all reducing to a uniform (x - offset) / scale
representation. Optional user-supplied pre-/post-transform function
pairs allow, for example, z-score normalisation over log-transformed
data with a faithful round-trip.
copulAX preprocessing utilities.
Public API for data preprocessing objects.
- class copulax.preprocessing.DataScaler(method='zscore', *, q_low=0.25, q_high=0.75, offset_only=False, scale_only=False, pre_fns=None, post_fns=None, offset=None, scale=None)
Jittable, autograd-compatible data scaler.
Fits an affine rescaling of the form \(z = (x - \text{offset}) / \text{scale}\) to input data under one of four methods, then applies the same rescaling (or its inverse) to later observations.
All scaling statistics are reduced over axis 0 (the sample axis). Any trailing axes are treated as feature dimensions and are preserved in the fitted offset / scale arrays. Transform and inverse-transform operations broadcast naturally over any leading batch shape as long as the trailing feature dims match.
Four methods are supported:
- "zscore": offset = mean(x, axis=0), scale = std(x, axis=0).
- "minmax": offset = min(x, axis=0), scale = max - min.
- "robust": offset = median(x, axis=0), scale = q_high - q_low (the difference between the q_high and q_low quantiles of x along axis 0).
- "maxabs": offset = 0, scale = max(|x|, axis=0).
Zero-variance features (a fitted scale of 0) are silently clamped to 1.0 so division does not break autograd or produce NaNs. Optional offset_only / scale_only flags restrict fitting to centring-only or rescaling-only behaviour. Optional pre_fns / post_fns tuples attach JAX-compliant forward and inverse functions to the pipeline (for example, z-scoring over log-transformed data).
- Parameters:
  - method (str) – One of "zscore" (default), "minmax", "robust", or "maxabs".
  - q_low (float) – Lower quantile for the "robust" method. Must satisfy 0 < q_low < q_high < 1. Defaults to 0.25.
  - q_high (float) – Upper quantile for the "robust" method. Defaults to 0.75.
  - offset_only (bool) – If True, the fitted scale is forced to 1 so transform performs centring only. Mutually exclusive with scale_only. Defaults to False.
  - scale_only (bool) – If True, the fitted offset is forced to 0 so transform performs rescaling only. Mutually exclusive with offset_only. Defaults to False.
  - pre_fns (Optional[Tuple[Optional[Callable], Optional[Callable]]]) – Optional (forward, inverse) tuple of JAX-compliant functions applied to the data before the affine scaling. The forward function runs during both fit() and transform(); the inverse runs at the end of inverse_transform(). Either element may be None to skip that direction. Defaults to None.
  - post_fns (Optional[Tuple[Optional[Callable], Optional[Callable]]]) – Optional (forward, inverse) tuple applied after the affine scaling during transform() and inverted first in inverse_transform(). post_fns is not applied during fit(). Same None-skip semantics as pre_fns. Defaults to None.
  - offset (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – Pre-fitted offset array. Normally populated by fit() rather than passed directly.
  - scale (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – Pre-fitted scale array. Normally populated by fit() rather than passed directly.
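The four fitting rules and the zero-scale clamp can be sketched in plain NumPy. This is an illustrative reference, not the library's implementation: the helper name fit_affine is hypothetical, the population standard deviation (ddof=0) and NumPy's default linear quantile interpolation are assumptions, and the clamp is modelled here as an exact-zero check.

```python
import numpy as np

def fit_affine(x, method="zscore", q_low=0.25, q_high=0.75):
    """Return (offset, scale), both reduced over axis 0 (hypothetical helper)."""
    x = np.asarray(x, dtype=float)
    if method == "zscore":
        # Assumption: population standard deviation (ddof=0).
        offset, scale = x.mean(axis=0), x.std(axis=0)
    elif method == "minmax":
        offset = x.min(axis=0)
        scale = x.max(axis=0) - offset
    elif method == "robust":
        offset = np.median(x, axis=0)
        scale = np.quantile(x, q_high, axis=0) - np.quantile(x, q_low, axis=0)
    elif method == "maxabs":
        offset = np.zeros(x.shape[1:])
        scale = np.abs(x).max(axis=0)
    else:
        raise ValueError(f"unknown method: {method!r}")
    # Clamp zero scales to 1.0 so the later division is always defined.
    scale = np.where(scale == 0.0, 1.0, scale)
    return offset, scale

# Third column is constant, so its fitted scale would be 0 without the clamp.
x = np.array([[0.0, 1.0, 7.0], [1.0, 3.0, 7.0], [2.0, 5.0, 7.0]])
offset, scale = fit_affine(x, "minmax")
z = (x - offset) / scale  # finite everywhere; the constant column maps to 0
```

Whatever the method, the result is a single (offset, scale) pair, which is why transform and inverse_transform can share one code path.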
- offset
Fitted offset, shape x.shape[1:]. None until fit.
- scale
Fitted scale, shape x.shape[1:]. None until fit.
- is_fitted
Whether both offset and scale are populated.
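The shape convention can be checked with a short NumPy sketch. It mimics only the axis-0 reduction and the broadcasting rule described above, using z-score statistics as an assumed example; it does not call the library.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(100, 3, 2))   # n = 100 samples, feature dims (3, 2)

# Reducing over axis 0 leaves exactly the feature dims.
offset = x.mean(axis=0)            # shape (3, 2) == x.shape[1:]
scale = x.std(axis=0)              # shape (3, 2)

# Later data may carry extra leading batch axes: only the trailing
# feature dims (3, 2) have to match for broadcasting to work.
batch = rng.normal(size=(4, 5, 3, 2))
z = (batch - offset) / scale       # shape (4, 5, 3, 2)
back = z * scale + offset          # inverse affine map
```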
Notes
method, q_low, q_high, offset_only, scale_only, pre_fns, and post_fns are static PyTree fields (eqx.field(static=True)). Only offset and scale are traced leaves. Branching on method is therefore safe under jit: JIT specialises per method.
Example
>>> import jax.numpy as jnp
>>> from copulax.preprocessing import DataScaler
>>> x = jnp.asarray([[0.0, 1.0], [1.0, 3.0], [2.0, 5.0]])
>>> scaler, z = DataScaler("zscore").fit_transform(x)
>>> bool(jnp.allclose(z.mean(axis=0), 0.0, atol=1e-6))
True
>>> bool(jnp.allclose(scaler.inverse_transform(z), x, atol=1e-6))
True
- fit(x)
Fit the scaler to x and return a new fitted instance.
- Parameters:
  - x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – Input data of shape (n, *feature_dims). Axis 0 is the sample axis; all remaining axes are feature dims.
- Return type: DataScaler
- Returns: A new DataScaler instance with offset and scale populated. The original instance is unchanged (pure functional).
- transform(x)
Apply the fitted scaling to x.
The pipeline is post_forward((pre_forward(x) - offset) / scale); missing halves of pre_fns / post_fns are no-ops.
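The forward pipeline can be written out as a free function. This is a NumPy sketch under stated assumptions, not the library's code: the standalone transform function and its argument layout are hypothetical, and log10 / tanh are chosen only as example hooks.

```python
import numpy as np

def transform(x, offset, scale, pre_fns=None, post_fns=None):
    """post_forward((pre_forward(x) - offset) / scale); None halves are skipped."""
    pre_forward = pre_fns[0] if pre_fns is not None else None
    post_forward = post_fns[0] if post_fns is not None else None
    if pre_forward is not None:
        x = pre_forward(x)
    z = (x - offset) / scale
    if post_forward is not None:
        z = post_forward(z)
    return z

x = np.array([1.0, 10.0, 100.0])
logx = np.log10(x)
# Statistics are fitted on the pre-transformed data, as fit() is described to do.
offset, scale = logx.mean(), logx.std()

z_plain = transform(x, offset, scale)                           # no hooks at all
z_log = transform(x, offset, scale, pre_fns=(np.log10, None))   # forward half only
z_squash = transform(x, offset, scale,
                     pre_fns=(np.log10, None),
                     post_fns=(np.tanh, np.arctanh))            # post hook runs last
```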
- inverse_transform(z)
Undo the fitted scaling on z.
The pipeline is pre_inverse(post_inverse(z) * scale + offset); missing halves of pre_fns / post_fns are silently skipped. The caller is responsible for providing inverses when full round-trip fidelity is required.
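The consequence of a missing inverse is easiest to see numerically. The NumPy sketch below uses a hypothetical standalone inverse_transform that follows the pipeline stated above; it illustrates the behaviour and is not the library's implementation.

```python
import numpy as np

def inverse_transform(z, offset, scale, pre_fns=None, post_fns=None):
    """pre_inverse(post_inverse(z) * scale + offset); None halves are skipped."""
    pre_inverse = pre_fns[1] if pre_fns is not None else None
    post_inverse = post_fns[1] if post_fns is not None else None
    if post_inverse is not None:
        z = post_inverse(z)
    x = z * scale + offset
    if pre_inverse is not None:
        x = pre_inverse(x)
    return x

x = np.array([1.0, 2.0, 4.0, 8.0])
logx = np.log(x)
offset, scale = logx.mean(), logx.std()
z = (logx - offset) / scale   # forward pass with pre_forward = np.log

# With both halves supplied, the original data is recovered.
full = inverse_transform(z, offset, scale, pre_fns=(np.log, np.exp))
# With the inverse half missing, the result stays in log space.
partial = inverse_transform(z, offset, scale, pre_fns=(np.log, None))
```

Here full matches x, while partial equals log(x): no error is raised, so a dropped inverse only shows up as wrong values downstream.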
- fit_transform(x)
Fit the scaler to x and return (fitted_scaler, scaled_x).
Equivalent to fitted = self.fit(x); return fitted, fitted.transform(x) but applies pre_fns forward only once (fit and transform would otherwise each apply it).
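The saving can be made concrete by counting calls to the forward function. The sketch below uses hypothetical free functions (fit, transform, fit_transform) with z-score statistics in NumPy; it mirrors the described behaviour rather than the library's internals.

```python
import numpy as np

calls = {"n": 0}

def counted_log(x):
    """np.log with a call counter, standing in for an expensive pre_forward."""
    calls["n"] += 1
    return np.log(x)

def fit(x, pre_forward):
    y = pre_forward(x)
    return y.mean(axis=0), y.std(axis=0)

def transform(x, offset, scale, pre_forward):
    return (pre_forward(x) - offset) / scale

def fit_transform(x, pre_forward):
    y = pre_forward(x)                     # applied once, then reused
    offset, scale = y.mean(axis=0), y.std(axis=0)
    return (offset, scale), (y - offset) / scale

x = np.array([[1.0], [2.0], [4.0]])

# Two-step route: fit and transform each apply the forward function.
offset, scale = fit(x, counted_log)
z_two_step = transform(x, offset, scale, counted_log)
calls_two_step = calls["n"]

# Fused route: a single application of the forward function.
calls["n"] = 0
_, z_fused = fit_transform(x, counted_log)
calls_fused = calls["n"]
```

Both routes produce the same scaled output; only the number of forward evaluations differs (two versus one).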
- save(path)
Save the fitted scaler to a .cpx file.
The file can be loaded back with copulax.load(). The .cpx extension is appended automatically when missing.
Any pre_fns / post_fns callables are serialised by their import path ({module}.{qualname}) so they can be rehydrated on load without pickle. Lambdas and locally-defined closures cannot be serialised this way and will cause a ValueError at save time; use a module-level function instead, or clear the callable(s) before saving.
- Parameters:
  - path (str) – Destination file path.
- Raises:
  - ValueError – If the scaler has not been fitted, or any attached callable cannot be round-tripped by qualname.
- Return type:
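Serialising a callable by import path can be sketched with the standard library alone. The helper names to_import_path and from_import_path are hypothetical, the check for lambdas and closures is an assumption about how such a guard might look, and the actual .cpx layout is not shown here. The sketch handles top-level functions only.

```python
import importlib
import math

def to_import_path(fn):
    """Return '{module}.{qualname}', or raise ValueError if fn cannot be re-imported."""
    module = getattr(fn, "__module__", None)
    qualname = getattr(fn, "__qualname__", None)
    # Lambdas and locally-defined functions carry these markers in their qualname.
    if not module or not qualname or "<lambda>" in qualname or "<locals>" in qualname:
        raise ValueError(f"cannot serialise {fn!r} by import path")
    return f"{module}.{qualname}"

def from_import_path(path):
    """Re-import a top-level callable from its '{module}.{name}' string."""
    module_name, _, attr = path.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)

path = to_import_path(math.exp)      # a plain string, safe to store in a file
restored = from_import_path(path)    # the same function object, no pickle involved

try:
    to_import_path(lambda v: v)
    lambda_ok = True
except ValueError:
    lambda_ok = False                # lambdas have no importable name
```

Storing only a string keeps the file free of executable payloads, at the cost that the named function must still be importable when the file is loaded.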
JAX-native, jittable, autograd-compatible data scaler.
This module provides DataScaler, an equinox-based PyTree
object that fits an affine rescaling to input data and exposes
transform / inverse_transform for applying and undoing the
rescaling on later observations.
Four scaling methods are supported, all reducing to a uniform
z = (x - offset) / scale representation:
- "zscore" (default): centre at the mean, scale by the standard deviation.
- "minmax": shift so the minimum is zero, scale so the range is one.
- "robust": centre at the median, scale by the inter-quantile range (default 25/75).
- "maxabs": no centring, scale by the element-wise absolute maximum.
The class is a proper equinox.Module, so fitted instances compose
cleanly with jax.jit, jax.grad, jax.vmap, and equinox PyTree
utilities.