Utilities
Contains utility functions for the copulax package.
- copulax._src._utils.get_random_key(bytestring_size=7)
Returns a fresh JAX PRNG key seeded from os.urandom. The entropy draw is wrapped in jax.pure_callback(), so each call inside an @jax.jit-compiled function receives a distinct seed at runtime. Quality of randomness depends on the OS and on the jax_enable_x64 setting: when x64 is disabled (JAX’s default), the seed is clamped to the int32 range; enable x64 for the full int64 entropy budget.
- Parameters:
bytestring_size (int) – Length of the byte string drawn from os.urandom. Out-of-range integers are reduced modulo the active int range (int64 with x64, int32 without). Defaults to 7.
- Return type:
key
- Returns:
A fresh JAX PRNG key.
Note
Not vmap-safe: pure_callback is hoisted out of vmap, so a vmap’d call returns identical keys across the batch. For per-leaf entropy, pass an explicit key and split it with jax.random.split. No autograd support.
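The mechanism described above (OS entropy drawn on the host through jax.pure_callback, then turned into a key) can be sketched as follows. This is a minimal illustration, not copulax's implementation: `random_key_sketch` and its modulo reduction into the non-negative int32 range are assumptions. The second half shows the explicit-key-plus-split pattern the note recommends for vmap.

```python
import os

import jax
import jax.numpy as jnp
import numpy as np


def random_key_sketch(bytestring_size: int = 7) -> jax.Array:
    """Hypothetical stand-in for get_random_key; not copulax's code."""

    def _draw_seed() -> np.ndarray:
        # Runs on the host every time the compiled program executes.
        raw = int.from_bytes(os.urandom(bytestring_size), "little")
        # Assumption: reduce into the non-negative int32 range (x64 disabled).
        return np.asarray(raw % (2**31 - 1), dtype=np.int32)

    seed = jax.pure_callback(_draw_seed, jax.ShapeDtypeStruct((), jnp.int32))
    return jax.random.key(seed)


# Each execution of the jitted function re-runs the callback, so keys differ.
make_key = jax.jit(lambda: random_key_sketch())
k1, k2 = make_key(), make_key()

# Under vmap, split one key explicitly to get per-leaf entropy.
keys = jax.random.split(random_key_sketch(), 4)
samples = jax.vmap(lambda k: jax.random.normal(k))(keys)
```

Because the callback runs at execution time rather than at trace time, the cached compiled function still produces a new key on each call.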
- copulax._src.optimize.adam(grad, m, v, t, beta1=0.9, beta2=0.999, eps=1e-08)
Adam optimiser.
- Reference:
Kingma, D. P. & Ba, J. (2014). Adam: A Method for Stochastic Optimization. https://arxiv.org/abs/1412.6980
- Parameters:
grad (Array) – the gradient at the current iteration.
m (Array) – the first moment estimate vector from the prior iteration.
v (Array) – the second raw moment estimate vector from the prior iteration.
t (int) – the prior iteration count.
beta1 (float) – exponential decay rate for the first-moment vector m.
beta2 (float) – exponential decay rate for the second-moment vector v.
eps (float) – small constant added to the denominator to prevent division by zero. Defaults to 1e-8.
- Returns:
The Adam direction, the first moment estimate vector, the second moment estimate vector and the current iteration count.
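For reference, one textbook Adam step from Kingma & Ba, matching the inputs and outputs listed above, can be sketched as below. It assumes copulax follows the standard bias-corrected update; `adam_step_sketch` is a hypothetical name, not the library function.

```python
import jax.numpy as jnp


def adam_step_sketch(grad, m, v, t, beta1=0.9, beta2=0.999, eps=1e-8):
    """One textbook Adam step; assumed, not copied from copulax."""
    t = t + 1                                   # current iteration count (1-based)
    m = beta1 * m + (1.0 - beta1) * grad        # biased first moment
    v = beta2 * v + (1.0 - beta2) * grad**2     # biased second raw moment
    m_hat = m / (1.0 - beta1**t)                # bias-corrected first moment
    v_hat = v / (1.0 - beta2**t)                # bias-corrected second moment
    direction = m_hat / (jnp.sqrt(v_hat) + eps)
    return direction, m, v, t


g = jnp.array([2.0, -3.0])
d, m, v, t = adam_step_sketch(g, jnp.zeros(2), jnp.zeros(2), 0)
```

On the first step the bias corrections cancel the (1 - beta) factors, so the direction is approximately sign(grad), here close to [1, -1].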
- copulax._src.optimize.single_update(x, d, lr, projection, projection_options)
Update the weights using the projected gradient method.
- Parameters:
x (Array) – the current parameters from the previous iteration.
d (Array) – the descent direction in the unconstrained space.
lr (float) – learning rate used in projected gradient descent.
projection (Callable) – the projection function to use.
projection_options (dict) – dictionary of options for the projection function.
- Returns:
The updated weights.
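A projected update takes a step in the unconstrained space and then maps the result back onto the feasible set. The sketch below assumes d is the gradient-aligned Adam direction (so the step is x - lr * d) and that projection_options are passed as keyword arguments; `single_update_sketch` and the box projection `project_box` are hypothetical names chosen for illustration.

```python
import jax.numpy as jnp


def project_box(x, lower, upper):
    """Hypothetical projection onto the box [lower, upper]."""
    return jnp.minimum(jnp.maximum(x, lower), upper)


def single_update_sketch(x, d, lr, projection, projection_options):
    # Assumption: d points along the gradient, so we step against it,
    # then project the unconstrained result back onto the feasible set.
    return projection(x - lr * d, **projection_options)


x = jnp.array([0.5, 0.2, 0.9])
d = jnp.array([1.0, 1.0, -1.0])
x_new = single_update_sketch(x, d, 0.3, project_box, {"lower": 0.0, "upper": 1.0})
```

The unconstrained step gives [0.2, -0.1, 1.2]; projection clips the last two entries back to the bounds.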
- copulax._src.optimize.projected_gradient(f, x0, projection_method, lr=1.0, maxiter=100, adam_options={}, jit_options={}, projection_options={}, **kwargs)
Projected gradient descent for linearly constrained optimisation.
Minimises the objective function f using projected gradient descent with Adam gradient updates.
- Parameters:
f (Callable) – objective function to minimise. Must be jax.grad- and jax.jit-compatible and return a scalar value. The first argument must be the parameter vector to optimise.
x0 (Array) – initial guess. Must be a flattened array with the same size as the solution.
projection_method (str) – name of the projection function to use. All optax constrained-optimisation projection functions are supported.
lr (float) – learning rate used in projected gradient descent.
maxiter (int) – maximum number of iterations.
adam_options (dict) – dictionary of options for the Adam optimiser.
jit_options (dict) – kwargs to pass to jax.jit when compiling f.
projection_options (dict) – kwargs to pass to the specified projection function.
kwargs – additional arguments forwarded to the objective function.
- Returns:
Dictionary containing the optimal results.
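Putting the pieces together, projected gradient descent with Adam steps can be sketched as a fixed-length jax.lax.fori_loop. This is a self-contained illustration under assumptions: the projection is passed as a callable rather than looked up by name, and the result keys "x" and "val" are hypothetical, since the docstring does not list them.

```python
import jax
import jax.numpy as jnp


def projected_adam_sketch(f, x0, projection, lr=0.1, maxiter=500,
                          beta1=0.9, beta2=0.999, eps=1e-8):
    """Hypothetical projected gradient descent with Adam steps (not copulax's code)."""
    grad_f = jax.grad(f)

    def body(i, state):
        x, m, v = state
        t = jnp.asarray(i + 1, dtype=x0.dtype)  # 1-based step for bias correction
        g = grad_f(x)
        m = beta1 * m + (1.0 - beta1) * g
        v = beta2 * v + (1.0 - beta2) * g**2
        direction = (m / (1.0 - beta1**t)) / (jnp.sqrt(v / (1.0 - beta2**t)) + eps)
        # Step in the unconstrained space, then project back onto the feasible set.
        x = projection(x - lr * direction)
        return x, m, v

    init = (x0, jnp.zeros_like(x0), jnp.zeros_like(x0))
    x, _, _ = jax.lax.fori_loop(0, maxiter, body, init)
    return {"x": x, "val": f(x)}  # hypothetical result keys


c = jnp.array([1.0, -2.0])
res = projected_adam_sketch(
    lambda x: jnp.sum((x - c) ** 2),  # unconstrained minimiser is c
    jnp.zeros(2),
    lambda x: jnp.maximum(x, 0.0),    # feasible set: x >= 0
)
```

The constrained minimiser is [1, 0]: the first coordinate reaches its unconstrained optimum, while the second is pinned at the bound by the projection.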
- copulax._src.optimize.brent(g, bounds, maxiter=20, tol=1e-12, **kwargs)
Find a root of g in the interval bounds using Brent’s method.
Combines inverse quadratic interpolation, secant, and bisection steps with acceptance criteria that guarantee convergence. Gradients w.r.t. **kwargs are computed via the implicit function theorem, so this function is safe to use inside jax.grad.
- Parameters:
g (Callable) – Scalar-valued function with signature g(x, **kwargs) -> scalar.
bounds (Array) – Two-element array [a, b] bracketing a root of g.
maxiter (int) – Number of Brent iterations (fixed, for jax.lax.scan).
tol (float) – Absolute convergence tolerance.
kwargs – Extra keyword arguments forwarded to g.
- Return type:
Union[float, int, Array, ndarray, bool, number, bool, complex]
- Returns:
Scalar root estimate.
- Reference:
Brent, R.P. (1973). Algorithms for Minimization without Derivatives, Chapter 4. Prentice-Hall.
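The implicit-function-theorem gradient can be illustrated with a much simpler root finder. The sketch below swaps in plain fixed-iteration bisection (not Brent's method) run under jax.lax.scan, and attaches a custom JVP so that, for g(x(θ), θ) = 0, dx/dθ = -(∂g/∂θ)/(∂g/∂x). The names `bisect` and `root`, and passing a scalar θ positionally instead of through **kwargs, are assumptions made to keep the example short.

```python
from functools import partial

import jax
import jax.numpy as jnp


def bisect(g, lo, hi, theta, maxiter=60):
    """Fixed-iteration bisection (a stand-in for Brent), scan-friendly."""

    def step(carry, _):
        lo, hi = carry
        mid = 0.5 * (lo + hi)
        # Move lo up if g(mid) has the same sign as g(lo); otherwise move hi down.
        same = jnp.sign(g(mid, theta)) == jnp.sign(g(lo, theta))
        return (jnp.where(same, mid, lo), jnp.where(same, hi, mid)), None

    (lo, hi), _ = jax.lax.scan(step, (lo, hi), None, length=maxiter)
    return 0.5 * (lo + hi)


@partial(jax.custom_jvp, nondiff_argnums=(0,))
def root(g, lo, hi, theta):
    return bisect(g, lo, hi, theta)


@root.defjvp
def root_jvp(g, primals, tangents):
    lo, hi, theta = primals
    _, _, theta_dot = tangents
    x = root(g, lo, hi, theta)
    # Implicit function theorem on g(x(theta), theta) = 0 (scalar theta):
    # dx/dtheta = -(dg/dtheta) / (dg/dx); no differentiation through the loop.
    dg_dx = jax.grad(g, argnums=0)(x, theta)
    dg_dtheta = jax.grad(g, argnums=1)(x, theta)
    return x, -dg_dtheta / dg_dx * theta_dot


def g(x, theta):
    return x**2 - theta  # root is sqrt(theta)


lo, hi = jnp.array(0.0), jnp.array(2.0)
x_star = root(g, lo, hi, jnp.array(2.0))
dx_dtheta = jax.grad(lambda th: root(g, lo, hi, th))(2.0)
```

Here x* ≈ √2 and dx*/dθ = 1/(2√θ) ≈ 0.3536, obtained from the root alone rather than by backpropagating through every iteration.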
Sample statistics implemented in JAX.
All functions are JIT-compatible and support gradient flow.
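As an illustration of what "JIT-compatible with gradient flow" means for a sample statistic, the sketch below jit-compiles an unbiased sample variance and differentiates it with jax.grad. `sample_variance` is a hypothetical example, not one of the module's functions.

```python
import jax
import jax.numpy as jnp


@jax.jit
def sample_variance(x):
    """Hypothetical unbiased sample variance (ddof=1); not a copulax function."""
    n = x.shape[0]
    return jnp.sum((x - jnp.mean(x)) ** 2) / (n - 1)


x = jnp.array([1.0, 2.0, 3.0, 4.0])
var = sample_variance(x)
# d var / d x_i = 2 (x_i - mean) / (n - 1)
grad_var = jax.grad(sample_variance)(x)
```

For this input the variance is 5/3 and the gradient is [-1, -1/3, 1/3, 1].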
Serialization
- copulax._src._serialization.load(path, name=None)
Load a fitted distribution from a .cpx file.
- Parameters:
path – Path to the .cpx file.
name (str) – Optional name for the loaded instance. When None, the name saved in the file is used.
- Returns:
A fitted Distribution instance.
- Raises:
FileNotFoundError – If path does not exist.
ValueError – If the file contains an unknown distribution class.