Utilities

Contains utility functions for the copulax package.

copulax._src._utils.get_random_key(bytestring_size=7)[source]

Returns a fresh JAX PRNG key seeded from os.urandom.

The hardware draw is wrapped in jax.pure_callback(), so each call inside an @jax.jit-compiled function receives a distinct seed at runtime. The quality of the randomness depends on the OS and on the jax_enable_x64 setting: when x64 is disabled (JAX’s default), the seed is restricted to the int32 range; enable x64 for the full int64 entropy budget.

Parameters:

bytestring_size (int) – Number of bytes to read from os.urandom. The bytes are interpreted as an integer seed; out-of-range integers are reduced modulo the active int range (int64 with x64, int32 without). Default 7.

Return type:

key

Note

Not vmap-safe: pure_callback is hoisted out of vmap, so a vmap’d call returns identical keys across the batch. To get a distinct key per batch element, pass an explicit key and split it with jax.random.split. The function is not differentiable.

Returns:

A fresh JAX PRNG key.
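The note on vmap can be illustrated with a minimal sketch that does not call copulax itself. It assumes a hypothetical helper, urandom_seed, that reads bytes from os.urandom and reduces the integer into the int32 range (an assumption made here so the seed is valid with x64 disabled). It then draws one key on the host and splits it so each batch element gets its own key.

```python
import os

import jax


def urandom_seed(bytestring_size: int = 7) -> int:
    """Hypothetical helper: turn OS randomness into an int32-safe seed."""
    raw = int.from_bytes(os.urandom(bytestring_size), "big")
    # Assumption: reduce into the non-negative int32 range, since x64 is
    # disabled by default in JAX.
    return raw % (2**31)


# Draw one key outside of vmap, then split it into one key per batch element.
key = jax.random.PRNGKey(urandom_seed())
keys = jax.random.split(key, 4)

# Each vmapped call receives its own key, so the draws are independent.
samples = jax.vmap(lambda k: jax.random.normal(k, ()))(keys)
```

Splitting on the host keeps the batched function pure, which is the pattern the note recommends in place of calling the key generator inside vmap.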

copulax._src.optimize.adam(grad, m, v, t, beta1=0.9, beta2=0.999, eps=1e-08)[source]

Adam optimiser.

Reference:

Kingma, D. P. & Ba, J. (2014). Adam: A Method for Stochastic Optimization. https://arxiv.org/abs/1412.6980

Parameters:
  • grad (Array) – the gradient at the current iteration.

  • m (Array) – the first moment estimate vector from the prior iteration.

  • v (Array) – the second raw moment estimate vector from the prior iteration.

  • t (int) – the prior iteration count.

  • beta1 (float) – exponential decay rate for the first-moment vector m.

  • beta2 (float) – exponential decay rate for the second-moment vector v.

  • eps (float) – small constant added to the denominator to prevent division by zero. Defaults to 1e-8.

Return type:

tuple[Array, Array, Array, int]

Returns:

Adam direction, the first moment estimate vector, the second moment estimate vector, the current iteration.
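For orientation, one step of the textbook Adam update from Kingma & Ba can be sketched as below. The function name adam_step_sketch is hypothetical, and the placement of eps and the bias correction follow the paper; the copulax implementation may differ in detail.

```python
import jax.numpy as jnp


def adam_step_sketch(grad, m, v, t, beta1=0.9, beta2=0.999, eps=1e-8):
    """One textbook Adam step (hypothetical sketch, not the library code)."""
    t = t + 1  # move from the prior iteration count to the current one
    m = beta1 * m + (1.0 - beta1) * grad      # first moment estimate
    v = beta2 * v + (1.0 - beta2) * grad**2   # second raw moment estimate
    m_hat = m / (1.0 - beta1**t)              # bias-corrected first moment
    v_hat = v / (1.0 - beta2**t)              # bias-corrected second moment
    direction = m_hat / (jnp.sqrt(v_hat) + eps)
    return direction, m, v, t


grad = jnp.array([1.0, -2.0])
direction, m, v, t = adam_step_sketch(grad, jnp.zeros(2), jnp.zeros(2), t=0)
```

On the first step from zero moments, the bias correction cancels the decay factors, so the direction is approximately the sign of the gradient.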

copulax._src.optimize.single_update(x, d, lr, projection, projection_options)[source]

Update the weights using the projected gradient method.

Parameters:
  • x (Array) – the current parameters from the previous iteration.

  • d (Array) – the descent direction in the unconstrained space.

  • lr (float) – learning rate used in projected gradient descent.

  • projection (Callable) – the projection function to use.

  • projection_options (dict) – dictionary of options for the projection function.

Return type:

Array

Returns:

The updated weights.
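A projected update of this kind is usually a plain descent step followed by a projection back onto the feasible set. The sketch below assumes that form and assumes projection_options is unpacked as keyword arguments; single_update_sketch and box_projection are hypothetical names used only for illustration.

```python
import jax.numpy as jnp


def box_projection(x, lower=0.0, upper=1.0):
    """Hypothetical projection: clip each coordinate into [lower, upper]."""
    return jnp.clip(x, lower, upper)


def single_update_sketch(x, d, lr, projection, projection_options):
    """Step along -d in the unconstrained space, then project back."""
    return projection(x - lr * d, **projection_options)


x = jnp.array([0.1, 0.9])
d = jnp.array([1.0, -1.0])
# The raw step gives [-0.2, 1.2]; both coordinates leave [0, 1] and are
# clipped back to the boundary.
new_x = single_update_sketch(x, d, 0.3, box_projection,
                             {"lower": 0.0, "upper": 1.0})
```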

copulax._src.optimize.projected_gradient(f, x0, projection_method, lr=1.0, maxiter=100, adam_options={}, jit_options={}, projection_options={}, **kwargs)[source]

Projected gradient descent for linearly constrained optimisation.

Minimises the objective function f using projected gradient descent with Adam gradient updates.

Parameters:
  • f (Callable) – objective function to minimise. Must be compatible with jax.grad and jax.jit and return a scalar value. The first argument must be the parameter vector to optimise.

  • x0 (Array) – initial guess. Must be a flattened array with the same size as the solution.

  • projection_method (str) – name of the projection function to use. All optax constrained-optimisation projection functions are supported.

  • lr (float) – learning rate used in projected gradient descent.

  • maxiter (int) – maximum number of iterations.

  • adam_options (dict) – dictionary of options for the Adam optimiser.

  • jit_options (dict) – kwargs to pass to jax.jit when compiling f.

  • projection_options (dict) – kwargs to pass to the specified projection function.

  • kwargs – additional arguments forwarded to the objective function.

Return type:

dict

Returns:

Dictionary containing the optimal results.
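The overall loop can be sketched by combining an Adam direction with a projection after every step. This is a self-contained illustration under stated assumptions, not the library routine: projected_gradient_sketch, the fixed box projection onto [0, 1], and the result keys "x" and "fun" are all hypothetical.

```python
import jax
import jax.numpy as jnp


def projected_gradient_sketch(f, x0, lr=0.01, maxiter=500,
                              beta1=0.9, beta2=0.999, eps=1e-8):
    """Adam-driven projected gradient descent onto the box [0, 1]^n."""
    grad_f = jax.grad(f)

    def body(i, state):
        x, m, v = state
        t = i + 1
        g = grad_f(x)
        # Textbook Adam moments with bias correction.
        m = beta1 * m + (1.0 - beta1) * g
        v = beta2 * v + (1.0 - beta2) * g**2
        d = (m / (1.0 - beta1**t)) / (jnp.sqrt(v / (1.0 - beta2**t)) + eps)
        # Descent step followed by projection back onto the feasible set.
        x = jnp.clip(x - lr * d, 0.0, 1.0)
        return x, m, v

    init = (x0, jnp.zeros_like(x0), jnp.zeros_like(x0))
    x, _, _ = jax.lax.fori_loop(0, maxiter, body, init)
    # Hypothetical result layout; the real dictionary keys are not documented here.
    return {"x": x, "fun": f(x)}


target = jnp.array([1.5, -0.5, 0.3])
res = projected_gradient_sketch(lambda x: jnp.sum((x - target) ** 2),
                                jnp.array([0.5, 0.5, 0.5]))
```

In this toy problem the unconstrained minimiser lies partly outside the box, so the first two coordinates settle on the boundary while the third converges to its interior optimum.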

copulax._src.optimize.brent(g, bounds, maxiter=20, tol=1e-12, **kwargs)[source]

Find a root of g in the interval bounds using Brent’s method.

Combines inverse quadratic interpolation, secant, and bisection with acceptance criteria that guarantee convergence. Gradients w.r.t. **kwargs are computed via the implicit function theorem, so this function is safe to use inside jax.grad.

Parameters:
  • g (Callable) – Scalar-valued function. Signature g(x, **kwargs) -> scalar.

  • bounds (Array) – Two-element array [a, b] bracketing a root of g.

  • maxiter (int) – Number of Brent iterations. The count is fixed in advance because the loop runs under jax.lax.scan.

  • tol (float) – Absolute convergence tolerance.

  • kwargs – Extra keyword arguments forwarded to g.

Return type:

Union[float, int, Array, ndarray, bool, number, complex]

Returns:

Scalar root estimate.

Reference:

Brent, R.P. (1973). Algorithms for Minimization without Derivatives, Chapter 4. Prentice-Hall.
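The implicit function theorem gradient mentioned above can be illustrated with a small sketch. If x*(θ) solves g(x*, θ) = 0, then dx*/dθ = -(∂g/∂θ) / (∂g/∂x) evaluated at the root. The code below substitutes plain bisection for Brent's method to stay short, takes a single positional parameter theta instead of keyword arguments, and hard-codes the bracket [0, 2]; all names are hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp


def bisect(g, a, b, theta, n_iter=40):
    """Fixed-iteration bisection under lax.scan (stand-in for Brent)."""
    def step(carry, _):
        lo, hi = carry
        mid = 0.5 * (lo + hi)
        same_sign = jnp.sign(g(mid, theta)) == jnp.sign(g(lo, theta))
        lo = jnp.where(same_sign, mid, lo)
        hi = jnp.where(same_sign, hi, mid)
        return (lo, hi), None

    init = (jnp.asarray(a, dtype=jnp.float32), jnp.asarray(b, dtype=jnp.float32))
    (lo, hi), _ = jax.lax.scan(step, init, None, length=n_iter)
    return 0.5 * (lo + hi)


@partial(jax.custom_jvp, nondiff_argnums=(0,))
def root(g, theta):
    return bisect(g, 0.0, 2.0, theta)


@root.defjvp
def root_jvp(g, primals, tangents):
    (theta,), (theta_dot,) = primals, tangents
    x_star = root(g, theta)
    # Implicit function theorem: differentiate g at the root, not the solver.
    dg_dx = jax.grad(g, argnums=0)(x_star, theta)
    dg_dtheta = jax.grad(g, argnums=1)(x_star, theta)
    return x_star, -dg_dtheta / dg_dx * theta_dot


def g(x, theta):
    return x**2 - theta  # positive root is sqrt(theta)


theta = jnp.float32(2.0)
x_star = root(g, theta)
dx_dtheta = jax.grad(lambda th: root(g, th))(theta)
```

Because the derivative comes from g at the converged root, no gradient has to flow through the solver iterations, which is why a fixed iteration count does not limit differentiability.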

Sample statistics implemented in JAX.

All functions are JIT-compatible and support gradient flow.

Serialization

copulax._src._serialization.load(path, name=None)[source]

Load a fitted distribution from a .cpx file.

Parameters:
  • path – Path to the .cpx file.

  • name (str) – Optional name for the loaded instance. When None the name saved in the file is used.

Returns:

A fitted Distribution instance.

Raises: