Source code for copulax._src.optimize

from functools import partial
from typing import Callable, Optional

import jax
import jax.numpy as jnp
import optax.projections as proj
from jax import Array

from copulax._src.typing import Scalar


###############################################################################
# ADAM optimizer
###############################################################################
@jax.jit
def adam(
    grad: jnp.ndarray,
    m: jnp.ndarray,
    v: jnp.ndarray,
    t: int,
    beta1: float = 0.9,
    beta2: float = 0.999,
    eps: float = 1e-8,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, int]:
    """Compute one Adam step direction.

    Reference:
        Kingma, D. P. & Ba, J. (2014). Adam: A Method for Stochastic
        Optimization. https://arxiv.org/abs/1412.6980

    Args:
        grad: gradient evaluated at the current iterate.
        m: first-moment (mean) estimate carried over from the previous step.
        v: second raw-moment estimate carried over from the previous step.
        t: number of steps taken so far (before this one).
        beta1: exponential decay rate applied to ``m``.
        beta2: exponential decay rate applied to ``v``.
        eps: constant added to the denominator to avoid division by zero.
            Defaults to ``1e-8``.

    Returns:
        A 4-tuple ``(direction, m, v, t)``: the bias-corrected Adam
        direction, the updated first- and second-moment estimates, and the
        incremented step counter.
    """
    step = t + 1

    # Exponential moving averages of the gradient and its square.
    first = beta1 * m + (1 - beta1) * grad
    second = beta2 * v + (1 - beta2) * grad**2

    # Undo the bias towards zero caused by initialising the moments at 0.
    first_unbiased = first / (1 - beta1**step)
    second_unbiased = second / (1 - beta2**step)

    direction = first_unbiased / (jnp.sqrt(second_unbiased) + eps)
    return direction, first, second, step
@partial(jax.jit, static_argnames=("projection",))
def single_update(
    x: jnp.ndarray,
    d: jnp.ndarray,
    lr: float,
    projection: Callable,
    projection_options: dict,
) -> jnp.ndarray:
    """Take one projected-gradient step.

    Args:
        x: parameters from the previous iteration.
        d: descent direction in the unconstrained space.
        lr: learning rate (step size) for the gradient step.
        projection: projection function mapping the stepped parameters back
            onto the feasible set. Passed as a static argument to ``jax.jit``,
            so it must be hashable.
        projection_options: keyword arguments forwarded to ``projection``.

    Returns:
        The projected parameters, flattened to one dimension.
    """
    # Unconstrained gradient step.
    stepped = x - lr * d

    # The projection receives a column-shaped array: a leading axis is added
    # and all axes are then reversed, so an (n,) vector becomes (n, 1).
    column = jnp.transpose(jnp.expand_dims(stepped, 0))

    return jnp.ravel(projection(column, **projection_options))
def projected_gradient(
    f: Callable,
    x0: jnp.ndarray,
    projection_method: str,
    lr: float = 1.0,
    maxiter: int = 100,
    adam_options: Optional[dict] = None,
    jit_options: Optional[dict] = None,
    projection_options: Optional[dict] = None,
    **kwargs,
) -> dict:
    """Projected gradient descent for linearly constrained optimisation.

    Minimises the objective function ``f`` using projected gradient descent
    with Adam gradient updates, run for exactly ``maxiter`` iterations via
    ``jax.lax.scan``.

    Args:
        f: objective function to minimise. Must be ``jax.grad`` and
            ``jax.jit`` compatible and return a scalar value. The first
            argument must be the parameter vector to optimise.
        x0: initial guess. Must be a flattened array with the same size as
            the solution. Non-floating (e.g. integer) inputs are promoted to
            the default floating dtype, because the Adam moment estimates
            carried through the loop are floating point.
        projection_method: name of the projection function to use. All
            ``optax`` constrained-optimisation projection functions are
            supported.
        lr: learning rate used in projected gradient descent.
        maxiter: number of iterations.
        adam_options: dictionary of options for the Adam optimiser.
            ``None`` (default) means no options.
        jit_options: kwargs to pass to ``jax.jit`` when compiling ``f``.
            ``None`` (default) means no options.
        projection_options: kwargs to pass to the specified projection
            function. ``None`` (default) means no options.
        kwargs: additional arguments forwarded to the objective function.

    Returns:
        Dictionary with keys ``"x"`` (final parameters) and ``"val"``
        (objective value at those parameters).
    """
    # Fresh dicts per call instead of shared mutable defaults.
    adam_options = {} if adam_options is None else adam_options
    jit_options = {} if jit_options is None else jit_options
    projection_options = {} if projection_options is None else projection_options

    # An integer x0 would give integer zero moments, while ``adam`` returns
    # floating moments; the scan carry dtype would then change after the
    # first step and ``jax.lax.scan`` would raise. Promote up front.
    x0 = jnp.asarray(x0)
    if not jnp.issubdtype(x0.dtype, jnp.inexact):
        x0 = x0.astype(float)

    # JIT compiling the projection and the value-and-gradient function.
    projection: Callable = jax.jit(getattr(proj, projection_method))
    f_vg: Callable = jax.jit(jax.value_and_grad(f, argnums=0), **jit_options)

    def _iter(carry: tuple, _) -> tuple:
        # carry = (estimate, last objective value, first moment,
        #          second moment, iteration count)
        x, _, m, v, t = carry

        # Value and gradient in a single forward+backward pass.
        f_val, f_grad = f_vg(x, **kwargs)
        # NaNs -> 0 (and +/-inf -> largest finite values, per nan_to_num).
        f_grad = jnp.nan_to_num(f_grad)

        # Adam direction, then the projected gradient step.
        d, m, v, t = adam(grad=f_grad, m=m, v=v, t=t, **adam_options)
        x = single_update(
            x=x,
            d=d,
            lr=lr,
            projection=projection,
            projection_options=projection_options,
        )
        return (x, f_val, m, v, t), None

    # Initialise the optimisation loop.
    init = (x0, jnp.inf, jnp.zeros_like(x0), jnp.zeros_like(x0), 0)

    # Run the fixed-length projected gradient descent loop.
    res, _ = jax.lax.scan(_iter, init, None, length=maxiter)

    # Report the objective at the final iterate.
    x_opt = res[0]
    val_opt = f_vg(x_opt, **kwargs)[0]
    return {"x": x_opt, "val": val_opt}
###############################################################################
# Brent's method (Brent 1973, Algorithm 4.1)
###############################################################################
_DENOM_EPS = 1e-30


def _safe_div(num: Scalar, denom: Scalar) -> Scalar:
    """Division guarded against a (near-)zero denominator.

    Denominators whose magnitude is below ``_DENOM_EPS`` are replaced by the
    positive constant ``_DENOM_EPS`` (the sign of ``denom`` is not kept).
    """
    safe_denom = jnp.where(jnp.abs(denom) < _DENOM_EPS, _DENOM_EPS, denom)
    return num / safe_denom


def _brent_classical(
    g: Callable, bounds: jnp.ndarray, maxiter: int = 20, tol: float = 1e-12, **kwargs
) -> Scalar:
    r"""Classical Brent's root-finding algorithm.

    Adaptively selects between inverse quadratic interpolation, secant, and
    bisection with acceptance criteria that guarantee convergence. Exactly
    ``maxiter + 2`` function evaluations are performed (2 for the initial
    bracket, 1 per scan iteration).

    Args:
        g: Scalar-valued function whose root is sought.
        bounds: Two-element array ``[a, b]`` bracketing the root.
        maxiter: Fixed number of iterations (for ``jax.lax.scan``).
        tol: Absolute convergence tolerance.
        kwargs: Extra keyword arguments forwarded to *g*.

    Returns:
        Best root estimate (the bracket endpoint with smallest ``|g|``).

    Reference:
        Brent, R.P. (1973). *Algorithms for Minimization without
        Derivatives*, Chapter 4.
    """
    a, b = bounds
    fa = g(a, **kwargs)
    fb = g(b, **kwargs)

    # Ensure |f(b)| <= |f(a)| so b is the best guess.
    swap = jnp.abs(fa) < jnp.abs(fb)
    a, b = jnp.where(swap, b, a), jnp.where(swap, a, b)
    fa, fb = jnp.where(swap, fb, fa), jnp.where(swap, fa, fb)

    c, fc = a, fa
    # Placeholder: d only influences the result via conditions 3/5, which are
    # masked out while mflag is set (as it is on the first iteration).
    d = b - a
    mflag = jnp.array(1.0)  # 1.0 = last step was bisection

    init = (a, b, fa, fb, c, fc, d, mflag)

    def _step(carry, _):
        a_, b_, fa_, fb_, c_, fc_, d_, mflag_ = carry

        # --- interpolation attempt ---
        # IQI when three distinct function values; secant otherwise.
        use_iqi = (fa_ != fc_) & (fb_ != fc_)

        # Secant step: s = b - fb*(b-a)/(fb-fa)
        s_sec = b_ - _safe_div(fb_ * (b_ - a_), fb_ - fa_)

        # Inverse quadratic interpolation through (a, fa), (b, fb), (c, fc).
        d1 = (fa_ - fb_) * (fa_ - fc_)
        d2 = (fb_ - fa_) * (fb_ - fc_)
        d3 = (fc_ - fa_) * (fc_ - fb_)
        s_iqi = (
            _safe_div(a_ * fb_ * fc_, d1)
            + _safe_div(b_ * fa_ * fc_, d2)
            + _safe_div(c_ * fa_ * fb_, d3)
        )
        s_interp = jnp.where(use_iqi, s_iqi, s_sec)

        # --- bisection fallback ---
        s_bisect = 0.5 * (a_ + b_)

        # --- Brent's acceptance criteria ---
        # Condition 1: s must lie strictly between (3a+b)/4 and b. Written as
        # the negation of "strictly inside" so that a NaN candidate (e.g.
        # from an overflowing interpolation, inf - inf) is rejected; the
        # form (s <= lo) | (s >= hi) would be False for NaN and accept it.
        lo = jnp.minimum(0.75 * a_ + 0.25 * b_, b_)
        hi = jnp.maximum(0.75 * a_ + 0.25 * b_, b_)
        cond1 = ~((s_interp > lo) & (s_interp < hi))

        # Step-size conditions (depend on mflag).
        abs_sb = jnp.abs(s_interp - b_)
        abs_bc = jnp.abs(b_ - c_)
        abs_cd = jnp.abs(c_ - d_)
        cond2 = (mflag_ > 0.5) & (abs_sb >= 0.5 * abs_bc)
        cond3 = (mflag_ <= 0.5) & (abs_sb >= 0.5 * abs_cd)
        cond4 = (mflag_ > 0.5) & (abs_bc < tol)
        cond5 = (mflag_ <= 0.5) & (abs_cd < tol)

        use_bisection = cond1 | cond2 | cond3 | cond4 | cond5
        s = jnp.where(use_bisection, s_bisect, s_interp)
        new_mflag = jnp.where(use_bisection, 1.0, 0.0)

        # --- single function evaluation ---
        fs = g(s, **kwargs)

        # --- update bracket ---
        new_d = c_
        new_c = b_
        new_fc = fb_

        # If fa*fs < 0 the root lies between a and s, so b <- s.
        root_left = fa_ * fs < 0
        new_a = jnp.where(root_left, a_, s)
        new_fa = jnp.where(root_left, fa_, fs)
        new_b = jnp.where(root_left, s, b_)
        new_fb = jnp.where(root_left, fs, fb_)

        # Swap to maintain |f(b)| <= |f(a)|.
        need_swap = jnp.abs(new_fa) < jnp.abs(new_fb)
        fin_a = jnp.where(need_swap, new_b, new_a)
        fin_fa = jnp.where(need_swap, new_fb, new_fa)
        fin_b = jnp.where(need_swap, new_a, new_b)
        fin_fb = jnp.where(need_swap, new_fa, new_fb)

        return (fin_a, fin_b, fin_fa, fin_fb, new_c, new_fc, new_d, new_mflag), None

    final, _ = jax.lax.scan(_step, init, None, length=maxiter)
    return final[1]


def brent(
    g: Callable,
    bounds: jnp.ndarray,
    maxiter: int = 20,
    tol: float = 1e-12,
    **kwargs,
) -> Scalar:
    r"""Find a root of *g* in the interval *bounds* using Brent's method.

    Combines inverse quadratic interpolation, secant, and bisection with
    acceptance criteria that guarantee convergence. Gradients w.r.t.
    ``**kwargs`` are computed via the implicit function theorem, so this
    function is safe to use inside ``jax.grad``.

    Args:
        g: Scalar-valued function. Signature ``g(x, **kwargs) -> scalar``.
        bounds: Two-element array ``[a, b]`` bracketing a root of *g*.
        maxiter: Number of Brent iterations (fixed, for ``jax.lax.scan``).
        tol: Absolute convergence tolerance.
        kwargs: Extra keyword arguments forwarded to *g*.

    Returns:
        Scalar root estimate.

    Notes:
        ``dg/dx`` at the root is estimated by finite differences with a
        step of ``eps**(1/3) * max(1, |x*|)`` (``eps`` = machine epsilon of
        the root's dtype). The probe points are clipped into ``bounds``.

    Reference:
        Brent, R.P. (1973). *Algorithms for Minimization without
        Derivatives*, Chapter 4. Prentice-Hall.
    """
    bounds = jnp.sort(jnp.asarray(bounds, dtype=float).flatten())
    lower, upper = bounds[0], bounds[1]

    # Forward solve (no gradients through the iterative loop).
    x_star = jax.lax.stop_gradient(
        _brent_classical(g, bounds, maxiter, tol, **kwargs)
    )

    # Implicit differentiation via IFT:
    #   x_out = x* - g(x*,θ) / stop_gradient(∂g/∂x)
    # Forward value is x* plus a Newton correction (≈0 when g(x*)≈0); the
    # gradient w.r.t. θ is the IFT result.
    #
    # ∂g/∂x is estimated by finite differences rather than AD, so g need
    # not be differentiable w.r.t. x (e.g. betainc-based CDFs).
    #
    # The step must be resolvable in the working precision: a fixed tiny
    # step such as 1e-8 is below float32 resolution once |x*| >= 0.25, so
    # x* ± h == x* and the estimate collapses to exactly 0. eps**(1/3),
    # scaled by |x*|, balances the O(h^2) truncation error against the
    # O(eps/h) rounding error of a central difference.
    eps = float(jnp.finfo(x_star.dtype).eps)
    h = eps ** (1.0 / 3.0) * jnp.maximum(1.0, jnp.abs(x_star))
    # Keep the probes inside the bracket Brent searched (one-sided near an
    # endpoint) and divide by the spacing actually realised after rounding.
    x_hi = jnp.minimum(x_star + h, upper)
    x_lo = jnp.maximum(x_star - h, lower)
    dg_dx = _safe_div(g(x_hi, **kwargs) - g(x_lo, **kwargs), x_hi - x_lo)

    g_val = g(x_star, **kwargs)
    correction = _safe_div(g_val, jax.lax.stop_gradient(dg_dx))

    # When the root hasn't converged (g_val far from 0), the correction
    # can overflow. Clamp to zero so the forward value falls back to
    # x_star (the best Brent found). The IFT is only valid when
    # g(x*) ≈ 0, so a large correction signals non-convergence.
    bracket_width = jnp.abs(upper - lower)
    correction = jnp.where(jnp.abs(correction) > bracket_width, 0.0, correction)
    correction = jnp.nan_to_num(correction, nan=0.0, posinf=0.0, neginf=0.0)
    return x_star - correction