Preprocessing

CopulAX provides a small set of jittable, autograd-compatible preprocessing objects that compose cleanly with the rest of the library. All preprocessors are equinox.Module PyTrees — their fitted parameters are traced JAX arrays, so they can be passed through jax.jit, jax.grad, jax.vmap, and equinox serialisation.

Data Scaling

DataScaler fits an affine rescaling to input data and exposes transform / inverse_transform for later observations. Four methods are supported (z-score, min-max, robust, max-abs), all reducing to a uniform (x - offset) / scale representation. Optional user-supplied pre-/post-transform function pairs allow, for example, z-score normalisation over log-transformed data with a faithful round-trip.

CopulAX preprocessing utilities.

Public API for data preprocessing objects.

class copulax.preprocessing.DataScaler(method='zscore', *, q_low=0.25, q_high=0.75, offset_only=False, scale_only=False, pre_fns=None, post_fns=None, offset=None, scale=None)

Jittable, autograd-compatible data scaler.

Fits an affine rescaling of the form \(z = (x - \text{offset}) / \text{scale}\) to input data under one of four methods, then applies the same rescaling (or its inverse) to later observations.

All scaling statistics are reduced over axis 0 (the sample axis). Any trailing axes are treated as feature dimensions and are preserved in the fitted offset / scale arrays. Transform and inverse-transform operations broadcast naturally over any leading batch shape as long as the trailing feature dims match.
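The shape convention above can be made concrete with a small sketch. It does not call DataScaler; NumPy stands in for jax.numpy (an assumption made so the snippet runs without JAX), and the array sizes are arbitrary illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(0)

# Training data: 100 samples, each with feature dims (3, 2).
x = rng.normal(size=(100, 3, 2))

# Statistics are reduced over axis 0 only, so the feature dims survive.
offset = x.mean(axis=0)   # shape (3, 2)
scale = x.std(axis=0)     # shape (3, 2)

# Later observations may carry any leading batch shape, here (4, 5),
# provided the trailing dims are still (3, 2).
batch = rng.normal(size=(4, 5, 3, 2))
z = (batch - offset) / scale      # broadcasts to shape (4, 5, 3, 2)
restored = z * scale + offset     # inverse of the affine map
```

Because offset and scale only carry the feature dims, the same fitted statistics apply to a single observation of shape (3, 2) or to any batched stack of them.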

Four methods are supported:

  • "zscore": offset = mean(x, axis=0), scale = std(x, axis=0).

  • "minmax": offset = min(x, axis=0), scale = max(x, axis=0) - min(x, axis=0).

  • "robust": offset = median(x, axis=0), scale = quantile(x, q_high, axis=0) - quantile(x, q_low, axis=0).

  • "maxabs": offset = 0, scale = max(|x|, axis=0).
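As a rough illustration of the four rules listed above, the following sketch computes (offset, scale) with NumPy in place of jax.numpy. The helper name fit_stats is hypothetical and is not part of the library. Two details are assumptions rather than documented behaviour: the standard deviation uses the default ddof=0, and the quantiles use NumPy's default linear interpolation.

```python
import numpy as np

def fit_stats(x, method="zscore", q_low=0.25, q_high=0.75):
    """Return (offset, scale) for one of the four documented methods."""
    x = np.asarray(x, dtype=float)
    if method == "zscore":
        return x.mean(axis=0), x.std(axis=0)
    if method == "minmax":
        lo = x.min(axis=0)
        return lo, x.max(axis=0) - lo
    if method == "robust":
        # Difference of the two quantile values, not of the quantile levels.
        lo, hi = np.quantile(x, [q_low, q_high], axis=0)
        return np.median(x, axis=0), hi - lo
    if method == "maxabs":
        return np.zeros(x.shape[1:]), np.abs(x).max(axis=0)
    raise ValueError(f"unknown method: {method!r}")

x = np.array([[0.0, 1.0], [1.0, 3.0], [2.0, 5.0]])
for m in ("zscore", "minmax", "robust", "maxabs"):
    offset, scale = fit_stats(x, m)
    print(m, offset, scale)
```

Whatever the method, the result feeds the same affine map z = (x - offset) / scale, which is why a single transform implementation can serve all four.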

Zero-variance features (a fitted scale of 0) are silently clamped to 1.0 so division does not break autograd or produce NaNs. Optional offset_only / scale_only flags restrict fitting to centring-only or rescaling-only behaviour. Optional pre_fns / post_fns tuples attach JAX-compliant forward and inverse functions to the pipeline (for example, z-scoring over log-transformed data).
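The zero-scale clamp can be sketched as follows. This is a minimal illustration under the assumption that the clamp is an element-wise replacement of exact zeros; the helper name safe_scale is hypothetical, and NumPy stands in for jax.numpy.

```python
import numpy as np

def safe_scale(scale):
    """Replace exact zeros with 1.0 so the later division is well defined."""
    return np.where(scale == 0.0, 1.0, scale)

# The second feature is constant, so its standard deviation is exactly 0.
x = np.array([[1.0, 7.0], [2.0, 7.0], [3.0, 7.0]])
offset = x.mean(axis=0)
scale = safe_scale(x.std(axis=0))

# The constant column maps to 0.0 instead of NaN (0 / 0).
z = (x - offset) / scale
```

With the clamp in place, the constant feature still round-trips: z * scale + offset returns the original constant value.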

Parameters:
  • method (str) – One of "zscore" (default), "minmax", "robust", or "maxabs".

  • q_low (float) – Lower quantile for the "robust" method. Must satisfy 0 < q_low < q_high < 1. Defaults to 0.25.

  • q_high (float) – Upper quantile for the "robust" method. Defaults to 0.75.

  • offset_only (bool) – If True, the fitted scale is forced to 1 so transform performs centring only. Mutually exclusive with scale_only. Defaults to False.

  • scale_only (bool) – If True, the fitted offset is forced to 0 so transform performs rescaling only. Mutually exclusive with offset_only. Defaults to False.

  • pre_fns (Optional[Tuple[Optional[Callable], Optional[Callable]]]) – Optional (forward, inverse) tuple of JAX-compliant functions applied to the data before the affine scaling. The forward function runs during both fit() and transform(); the inverse runs at the end of inverse_transform(). Either element may be None to skip that direction. Defaults to None.

  • post_fns (Optional[Tuple[Optional[Callable], Optional[Callable]]]) – Optional (forward, inverse) tuple applied after the affine scaling during transform() and inverted first in inverse_transform(). post_fns is not applied during fit(). Same None-skip semantics as pre_fns. Defaults to None.

  • offset (Optional[ArrayLike]) – Pre-fitted offset array. Normally populated by fit() rather than passed directly.

  • scale (Optional[ArrayLike]) – Pre-fitted scale array. Normally populated by fit() rather than passed directly.
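The effect of the offset_only and scale_only flags described in the parameter list can be sketched as below. The helper apply_flags is hypothetical, NumPy stands in for jax.numpy, and raising ValueError for the conflicting combination is an assumption (the documentation only states that the flags are mutually exclusive).

```python
import numpy as np

def apply_flags(offset, scale, offset_only=False, scale_only=False):
    """Force scale to 1 (centring only) or offset to 0 (rescaling only)."""
    if offset_only and scale_only:
        # Exception type is an assumption, not documented behaviour.
        raise ValueError("offset_only and scale_only are mutually exclusive")
    if offset_only:
        scale = np.ones_like(scale)
    if scale_only:
        offset = np.zeros_like(offset)
    return offset, scale

x = np.array([[0.0, 1.0], [1.0, 3.0], [2.0, 5.0]])
offset, scale = x.mean(axis=0), x.std(axis=0)

c_offset, c_scale = apply_flags(offset, scale, offset_only=True)
centred = (x - c_offset) / c_scale     # mean removed, spread untouched

s_offset, s_scale = apply_flags(offset, scale, scale_only=True)
rescaled = (x - s_offset) / s_scale    # unit spread, mean not removed
```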

offset

Fitted offset, shape x.shape[1:]. None until fit.

scale

Fitted scale, shape x.shape[1:]. None until fit.

is_fitted

Whether both offset and scale are populated.

Notes

method, q_low, q_high, offset_only, scale_only, pre_fns, and post_fns are static PyTree fields (eqx.field(static=True)). Only offset and scale are traced leaves. Branching on method is therefore safe under jit — JIT specialises per method.

Example

>>> import jax.numpy as jnp
>>> from copulax.preprocessing import DataScaler
>>> x = jnp.asarray([[0.0, 1.0], [1.0, 3.0], [2.0, 5.0]])
>>> scaler, z = DataScaler("zscore").fit_transform(x)
>>> bool(jnp.allclose(z.mean(axis=0), 0.0, atol=1e-6))
True
>>> bool(jnp.allclose(scaler.inverse_transform(z), x, atol=1e-6))
True

method: str
q_low: float
q_high: float
offset_only: bool
scale_only: bool
pre_fns: Tuple[Callable | None, Callable | None] | None
post_fns: Tuple[Callable | None, Callable | None] | None
offset: Array | None
scale: Array | None
property is_fitted: bool

Whether offset and scale have both been populated.

fit(x)

Fit the scaler to x and return a new fitted instance.

Parameters:

x (ArrayLike) – Input data of shape (n, *feature_dims). Axis 0 is the sample axis; all remaining axes are feature dims.

Return type:

DataScaler

Returns:

A new DataScaler instance with offset and scale populated. The original instance is unchanged (pure functional).
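The pure-functional contract (fit returns a new object and leaves the receiver untouched) can be illustrated with a frozen dataclass. ToyScaler below is a hypothetical stand-in written in plain Python, not the library class, and it uses a min-max rule only to keep the arithmetic short.

```python
from dataclasses import dataclass, replace
from typing import Optional, Tuple

@dataclass(frozen=True)
class ToyScaler:
    """Hypothetical immutable scaler used only to illustrate the pattern."""
    offset: Optional[Tuple[float, ...]] = None
    scale: Optional[Tuple[float, ...]] = None

    @property
    def is_fitted(self) -> bool:
        return self.offset is not None and self.scale is not None

    def fit(self, rows):
        # Column-wise min and range; returns a copy rather than mutating self.
        cols = list(zip(*rows))
        offset = tuple(min(c) for c in cols)
        scale = tuple(max(c) - min(c) for c in cols)
        return replace(self, offset=offset, scale=scale)

unfitted = ToyScaler()
fitted = unfitted.fit([[0.0, 1.0], [2.0, 5.0]])
```

The practical consequence is that the return value of fit must be kept; calling fit and discarding the result leaves the original instance unfitted.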

transform(x)

Apply the fitted scaling to x.

The pipeline is post_forward((pre_forward(x) - offset) / scale); missing halves of pre_fns / post_fns are no-ops.

Parameters:

x (ArrayLike) – Data to scale. Trailing dims must match offset / scale.

Return type:

Array

Returns:

The scaled data, same shape as x.

Raises:

ValueError – If the scaler has not been fitted.
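The transform pipeline described above can be sketched as a free function. This is an illustration rather than the library's implementation: the function below is hypothetical, NumPy stands in for jax.numpy, and the statistics are fitted on the log-transformed data to mirror how pre_fns participates in fitting.

```python
import numpy as np

def transform(x, offset, scale, pre_forward=None, post_forward=None):
    """post_forward((pre_forward(x) - offset) / scale); None means no-op."""
    y = pre_forward(x) if pre_forward is not None else x
    z = (y - offset) / scale
    return post_forward(z) if post_forward is not None else z

x = np.array([[1.0, 10.0], [10.0, 100.0], [100.0, 1000.0]])

# z-scoring over log-transformed data: statistics come from log(x).
logx = np.log(x)
offset, scale = logx.mean(axis=0), logx.std(axis=0)

z = transform(x, offset, scale, pre_forward=np.log)
```

Here z has zero mean and unit standard deviation per feature in log space, even though the raw columns span several orders of magnitude.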

inverse_transform(z)

Undo the fitted scaling on z.

The pipeline is pre_inverse(post_inverse(z) * scale + offset); missing halves of pre_fns / post_fns are silently skipped — the caller is responsible for providing inverses when full round-trip fidelity is required.

Parameters:

z (ArrayLike) – Previously scaled data. Trailing dims must match offset / scale.

Return type:

Array

Returns:

The unscaled data, same shape as z.

Raises:

ValueError – If the scaler has not been fitted.
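The inverse pipeline, and the consequence of omitting an inverse function, can be sketched in the same way. The function below is a hypothetical illustration with NumPy standing in for jax.numpy; it is not the library's code.

```python
import numpy as np

def inverse_transform(z, offset, scale, pre_inverse=None, post_inverse=None):
    """pre_inverse(post_inverse(z) * scale + offset); None means skipped."""
    y = post_inverse(z) if post_inverse is not None else z
    x = y * scale + offset
    return pre_inverse(x) if pre_inverse is not None else x

x = np.array([[1.0, 10.0], [10.0, 100.0], [100.0, 1000.0]])

# Forward pass with a log pre-transform followed by z-scoring.
logx = np.log(x)
offset, scale = logx.mean(axis=0), logx.std(axis=0)
z = (logx - offset) / scale

# With the matching inverse (exp) the round trip recovers x.
with_inverse = inverse_transform(z, offset, scale, pre_inverse=np.exp)

# Without it, the result stays in log space: log(x) comes back, not x.
without_inverse = inverse_transform(z, offset, scale)
```

No error is raised in the second case, which is why supplying both halves of each function pair matters when an exact round trip is needed.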

fit_transform(x)

Fit the scaler to x and return (fitted_scaler, scaled_x).

Equivalent to fitted = self.fit(x); return fitted, fitted.transform(x) but applies pre_fns forward only once (fit and transform would otherwise each apply it).

Parameters:

x (ArrayLike) – Input data of shape (n, *feature_dims).

Return type:

Tuple[DataScaler, Array]

Returns:

A tuple (fitted, scaled) where fitted is the fitted scaler and scaled is the transformed data.
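The saving from the fused call can be made visible with a call counter. The three helpers below are hypothetical free functions written for illustration (NumPy in place of jax.numpy); they mimic the documented behaviour rather than reproduce the library's code.

```python
import numpy as np

calls = {"pre": 0}

def counted_log(x):
    """np.log with a call counter, used only to make the saving visible."""
    calls["pre"] += 1
    return np.log(x)

def fit(x, pre):
    y = pre(x)
    return y.mean(axis=0), y.std(axis=0)

def transform(x, offset, scale, pre):
    return (pre(x) - offset) / scale

def fit_transform(x, pre):
    y = pre(x)                                   # forward pre-transform, once
    offset, scale = y.mean(axis=0), y.std(axis=0)
    return (offset, scale), (y - offset) / scale

x = np.array([[1.0, 2.0], [3.0, 4.0], [9.0, 16.0]])

# Two-step route: the pre-transform runs in fit and again in transform.
offset, scale = fit(x, counted_log)
z_two_step = transform(x, offset, scale, counted_log)
two_step_calls = calls["pre"]

# Fused route: the pre-transform runs a single time.
calls["pre"] = 0
(_, _), z_fused = fit_transform(x, counted_log)
fused_calls = calls["pre"]
```

Both routes produce the same scaled data; the fused route simply avoids evaluating the forward pre-transform a second time, which matters when that function is expensive.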

save(path)

Save the fitted scaler to a .cpx file.

The file can be loaded back with copulax.load(). The .cpx extension is appended automatically when missing.

Any pre_fns / post_fns callables are serialised by their import path ({module}.{qualname}) so they can be rehydrated on load without pickle. Lambdas and locally-defined closures cannot be serialised this way and will cause a ValueError at save time — use a module-level function instead, or clear the callable(s) before saving.

Parameters:

path (str) – Destination file path.

Raises:

ValueError – If the scaler has not been fitted, or any attached callable cannot be round-tripped by qualname.

Return type:

None
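To see why lambdas and local closures cannot be stored by import path, the sketch below rebuilds a {module}.{qualname} string and tries to resolve it again using only the standard library. The helper names are hypothetical and the check is an approximation of the idea, not the library's actual serialisation logic; in particular, it treats methods nested in classes as unresolvable.

```python
import importlib
import json

def import_path(fn):
    """Build the '{module}.{qualname}' string for a callable."""
    return f"{fn.__module__}.{fn.__qualname__}"

def resolve(path):
    """Import the module part of the path and look up the final attribute."""
    module_name, _, attr = path.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)

def is_round_trippable(fn):
    """True only if resolving the import path returns the same object."""
    qualname = getattr(fn, "__qualname__", "")
    if "<lambda>" in qualname or "<locals>" in qualname:
        return False
    try:
        return resolve(import_path(fn)) is fn
    except (ImportError, AttributeError, ValueError):
        return False

def make_closure(k):
    def scaled(x):
        return k * x
    return scaled

print(import_path(json.dumps))                  # json.dumps
print(is_round_trippable(json.dumps))           # True
print(is_round_trippable(lambda x: x))          # False
print(is_round_trippable(make_closure(2.0)))    # False
```

A lambda's qualified name is `<lambda>` and a closure's contains `<locals>`, so neither names an attribute that can be looked up on a module after a fresh import. A module-level function has a stable path and survives the round trip.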

JAX-native, jittable, autograd-compatible data scaler.

This module provides DataScaler, an equinox-based PyTree object that fits an affine rescaling to input data and exposes transform / inverse_transform for applying and undoing the rescaling on later observations.

Four scaling methods are supported — all reducing to a uniform z = (x - offset) / scale representation:

  • "zscore" (default): centre at the mean, scale by the standard deviation.

  • "minmax": shift so the minimum is zero, scale so the range is one.

  • "robust": centre at the median, scale by the inter-quantile range (default 25/75).

  • "maxabs": no centring, scale by the element-wise absolute maximum.

The class is a proper equinox.Module, so fitted instances compose cleanly with jax.jit, jax.grad, jax.vmap, and equinox PyTree utilities.
