A gradient processing and optimization library in JAX
The optax.contrib module contains experimental optimizers and techniques under active development. These methods may not be as stable as the core optimizers, but they track recent research in optimization.
Note: Experimental features may have API changes in future versions.
def sam(base_optimizer, rho=0.05, normalize=True):
"""
Sharpness-Aware Minimization optimizer.
Args:
base_optimizer: Base optimizer to use (e.g., SGD, Adam)
rho: Neighborhood size for sharpness computation (default: 0.05)
normalize: Whether to normalize perturbation (default: True)
Returns:
GradientTransformation: SAM optimizer
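Example:
A construction sketch following the signature documented above; SAM wraps a base optimizer and takes an extra gradient step at adversarially perturbed parameters, so each update is roughly twice as expensive as the base optimizer alone.
  import optax
  base = optax.sgd(learning_rate=0.1)
  opt = optax.contrib.sam(base, rho=0.05, normalize=True)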
"""def prodigy(learning_rate=1.0, eps=1e-8, beta1=0.9, beta2=0.999, weight_decay=0.0):
"""
Prodigy adaptive learning rate optimizer.
Args:
learning_rate: Initial learning rate (default: 1.0)
eps: Numerical stability parameter (default: 1e-8)
beta1: First moment decay rate (default: 0.9)
beta2: Second moment decay rate (default: 0.999)
weight_decay: Weight decay coefficient (default: 0.0)
Returns:
GradientTransformation: Prodigy optimizer
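Example:
A sketch using the documented defaults; Prodigy estimates the effective step size on its own, so learning_rate is typically left at 1.0 and acts as a multiplier.
  import optax
  opt = optax.contrib.prodigy(learning_rate=1.0, weight_decay=0.01)
  state = opt.init(params)  # params is a placeholder pytree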
"""def sophia(learning_rate, beta1=0.965, beta2=0.99, eps=1e-8, weight_decay=1e-4):
"""
Sophia (Second-order Clipped Stochastic Optimization) optimizer, which preconditions updates with a lightweight diagonal estimate of second-order information.
Args:
learning_rate: Learning rate
beta1: First moment decay rate (default: 0.965)
beta2: Second moment decay rate (default: 0.99)
eps: Numerical stability parameter (default: 1e-8)
weight_decay: Weight decay coefficient (default: 1e-4)
Returns:
GradientTransformation: Sophia optimizer
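Example:
A construction sketch following the documented signature; note that learning_rate has no default here and must be passed explicitly.
  import optax
  opt = optax.contrib.sophia(learning_rate=1e-4, weight_decay=1e-4)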
"""def schedule_free_adamw(learning_rate=0.0025, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.0):
"""
Schedule-free AdamW optimizer that doesn't require learning rate schedules.
Args:
learning_rate: Learning rate (default: 0.0025)
beta1: First moment decay rate (default: 0.9)
beta2: Second moment decay rate (default: 0.999)
eps: Numerical stability parameter (default: 1e-8)
weight_decay: Weight decay coefficient (default: 0.0)
Returns:
GradientTransformation: Schedule-free AdamW optimizer
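Example:
A construction sketch; train against the optimizer state as usual, and extract evaluation parameters with schedule_free_eval_params (documented below, and shown in the full training loop in the usage example at the end).
  import optax
  opt = optax.contrib.schedule_free_adamw(learning_rate=0.0025, weight_decay=0.01)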
"""
def schedule_free_sgd(learning_rate=1.0, momentum=0.9, weight_decay=0.0):
"""
Schedule-free SGD optimizer.
Args:
learning_rate: Learning rate (default: 1.0)
momentum: Momentum coefficient (default: 0.9)
weight_decay: Weight decay coefficient (default: 0.0)
Returns:
GradientTransformation: Schedule-free SGD optimizer
"""
def schedule_free_eval_params(optimizer_state, step_count):
"""
Extract evaluation parameters from schedule-free optimizer state.
Args:
optimizer_state: State from schedule-free optimizer
step_count: Current training step count
Returns:
Parameters suitable for evaluation/inference
"""def muon(learning_rate, momentum=0.95, nesterov=False):
"""
Muon (MomentUm Orthogonalized by Newton-Schulz) optimizer, which orthogonalizes the momentum-based update for matrix-shaped parameters.
Args:
learning_rate: Learning rate
momentum: Momentum coefficient (default: 0.95)
nesterov: Whether to use Nesterov momentum (default: False)
Returns:
GradientTransformation: Muon optimizer
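Example:
A sketch of a common pattern (an assumption, not required by the signature above): apply Muon to matrix-shaped parameters and a standard optimizer to everything else, using optax.multi_transform to route the updates.
  import jax
  import optax
  label_fn = lambda params: jax.tree_util.tree_map(
      lambda p: 'muon' if p.ndim == 2 else 'adam', params)
  opt = optax.multi_transform(
      {'muon': optax.contrib.muon(learning_rate=0.02, momentum=0.95),
       'adam': optax.adam(3e-4)},
      label_fn)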
"""def momo(learning_rate, momentum=0.9):
"""
MoMo (Model-based Momentum) optimizer, which derives an adaptive learning rate from momentum estimates of the loss and gradients.
Args:
learning_rate: Learning rate
momentum: Base momentum coefficient (default: 0.9)
Returns:
GradientTransformation: MoMo optimizer
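Example:
A usage sketch of the model-based update; MoMo-style methods use the current loss value, so this assumes the update call accepts the loss through a `value` keyword (check the installed optax version).
  import optax
  opt = optax.contrib.momo(learning_rate=1.0)
  state = opt.init(params)                       # params, grads, loss are placeholders
  updates, state = opt.update(grads, state, params, value=loss)
  params = optax.apply_updates(params, updates)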
"""
def momo_adam(learning_rate, beta1=0.9, beta2=0.999, eps=1e-8):
"""
MoMo-Adam, which applies the model-based MoMo learning rate adaptation on top of Adam.
Args:
learning_rate: Learning rate
beta1: First moment decay rate (default: 0.9)
beta2: Second moment decay rate (default: 0.999)
eps: Numerical stability parameter (default: 1e-8)
Returns:
GradientTransformation: MoMo-Adam optimizer
"""def dog(learning_rate, rho=0.05, eps=1e-8):
"""
DoG (Distance over Gradients) optimizer with parameter-free step sizes.
Args:
learning_rate: Learning rate
rho: Difference parameter (default: 0.05)
eps: Numerical stability parameter (default: 1e-8)
Returns:
GradientTransformation: DoG optimizer
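Example:
A construction sketch following the documented signature; DoG derives a parameter-free step size from the distance travelled from initialization and the accumulated gradient norms, so the explicit learning_rate mainly acts as an overall scale.
  import optax
  opt = optax.contrib.dog(learning_rate=1.0)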
"""
def dowg(learning_rate, rho=0.05, eps=1e-8, weight_decay=0.0):
"""
DoWG (Distance over Weighted Gradients) optimizer.
Args:
learning_rate: Learning rate
rho: Difference parameter (default: 0.05)
eps: Numerical stability parameter (default: 1e-8)
weight_decay: Weight decay coefficient (default: 0.0)
Returns:
GradientTransformation: DoWG optimizer
"""def adopt(learning_rate, eps=1e-8, beta1=0.9, beta2=0.9999, weight_decay=0.0):
"""
ADOPT optimizer, a modified Adam-style update designed to converge for any choice of beta2.
Args:
learning_rate: Learning rate
eps: Numerical stability parameter (default: 1e-8)
beta1: First moment decay rate (default: 0.9)
beta2: Second moment decay rate (default: 0.9999)
weight_decay: Weight decay coefficient (default: 0.0)
Returns:
GradientTransformation: ADOPT optimizer
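Example:
A sketch showing that contrib optimizers compose with core transforms like any GradientTransformation, here chaining global-norm gradient clipping with ADOPT.
  import optax
  opt = optax.chain(
      optax.clip_by_global_norm(1.0),
      optax.contrib.adopt(learning_rate=1e-3))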
"""def differentially_private_aggregate(
inner_agg_factory,
l2_norm_bound,
noise_multiplier,
seed=None
):
"""
Differentially private gradient aggregation.
Args:
inner_agg_factory: Base aggregation function
l2_norm_bound: L2 norm bound for gradient clipping
noise_multiplier: Noise multiplier for privacy
seed: Random seed (default: None)
Returns:
GradientTransformation: DP aggregation function
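Example:
A heavily hedged sketch of the DP-SGD pattern: clip and noise gradients, then apply a standard optimizer. The `...` placeholder stands in for the base aggregation argument documented above, and the gradients fed to this transform are assumed to be per-example gradients (a leading batch axis), e.g. obtained via jax.vmap over jax.grad.
  import optax
  dp_step = optax.chain(
      optax.contrib.differentially_private_aggregate(
          inner_agg_factory=...,      # placeholder: base aggregation (see Args above)
          l2_norm_bound=1.0,
          noise_multiplier=1.1,
          seed=0),
      optax.sgd(learning_rate=0.1))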
"""def ademamix(learning_rate, beta1=0.9, beta2=0.999, eps=1e-8, alpha=5.0):
"""
AdEMAMix optimizer, which mixes a fast and a slow exponential moving average of the gradients.
Args:
learning_rate: Learning rate
beta1: First moment decay rate (default: 0.9)
beta2: Second moment decay rate (default: 0.999)
eps: Numerical stability parameter (default: 1e-8)
alpha: Mixing parameter (default: 5.0)
Returns:
GradientTransformation: AdEMAMix optimizer
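Example:
A construction sketch following the documented signature; alpha controls how strongly the slow moving average contributes to the update.
  import optax
  opt = optax.contrib.ademamix(learning_rate=1e-3, alpha=5.0)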
"""def cocob():
"""
COCOB (Continuous Coin Betting) optimizer.
Returns:
GradientTransformation: COCOB optimizer (parameter-free)
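Example:
A construction sketch; COCOB is parameter-free, so it takes no learning rate at all.
  import optax
  opt = optax.contrib.cocob()
  state = opt.init(params)   # params is a placeholder pytree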
"""import optax
# Using SAM for better generalization
base_optimizer = optax.sgd(0.1)
sam_optimizer = optax.contrib.sam(base_optimizer, rho=0.05)
# Using schedule-free optimizers
sf_adamw = optax.contrib.schedule_free_adamw(learning_rate=0.001)
# Using experimental adaptive methods
prodigy_opt = optax.contrib.prodigy(learning_rate=1.0)
sophia_opt = optax.contrib.sophia(learning_rate=0.001)
# Training loop with schedule-free optimizer
opt_state = sf_adamw.init(params)
for step in range(num_steps):
    grads = compute_gradients(params, data)
    updates, opt_state = sf_adamw.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    # Extract evaluation parameters (for schedule-free methods)
    if step % eval_interval == 0:
        eval_params = optax.contrib.schedule_free_eval_params(opt_state, step)
        eval_loss = evaluate(eval_params, eval_data)

import optax.contrib
# or
from optax.contrib import sam, prodigy, schedule_free_adamw

Many contrib optimizers are based on recent research; refer to the respective papers for detailed algorithmic descriptions and theoretical analysis.
Install with Tessl CLI
npx tessl i tessl/pypi-optax