A gradient processing and optimization library in JAX
Specialized and experimental optimization algorithms including second-order methods, adaptive variants, and research optimizers. These optimizers implement cutting-edge techniques and may require more careful tuning than core optimizers.
Lion (Evolved Sign Momentum) optimizer that uses sign-based updates for memory efficiency and competitive performance.
def lion(learning_rate, b1=0.9, b2=0.99, weight_decay=0.0):
"""
Lion optimizer (Evolved Sign Momentum).
Args:
learning_rate: Learning rate or schedule
b1: Exponential decay rate for momentum (default: 0.9)
b2: Exponential decay rate for moving average (default: 0.99)
weight_decay: Weight decay coefficient (default: 0.0)
Returns:
GradientTransformation
"""Layer-wise Adaptive Rate Scaling (LARS) optimizer for large batch training.
Layer-wise Adaptive Rate Scaling (LARS) optimizer for large batch training.
def lars(learning_rate, weight_decay=0., trust_coefficient=0.001, eps=0.):
"""
LARS (Layer-wise Adaptive Rate Scaling) optimizer.
Args:
learning_rate: Learning rate or schedule
weight_decay: Weight decay coefficient (default: 0.0)
trust_coefficient: Trust coefficient for layer-wise adaptation (default: 0.001)
eps: Small constant for numerical stability (default: 0.0)
Returns:
GradientTransformation
"""Layer-wise Adaptive Moments optimizer for Batch training, designed for large batch sizes.
def lamb(learning_rate, b1=0.9, b2=0.999, eps=1e-6, weight_decay=0., mask=None):
"""
LAMB (Layer-wise Adaptive Moments optimizer for Batch training) optimizer.
Args:
learning_rate: Learning rate or schedule
b1: Exponential decay rate for first moment estimates (default: 0.9)
b2: Exponential decay rate for second moment estimates (default: 0.999)
eps: Small constant for numerical stability (default: 1e-6)
weight_decay: Weight decay coefficient (default: 0.0)
mask: Optional mask for parameter selection
Returns:
GradientTransformation
"""Limited-memory Broyden-Fletcher-Goldfarb-Shanno quasi-Newton method.
Limited-memory Broyden-Fletcher-Goldfarb-Shanno (L-BFGS) quasi-Newton method.
def lbfgs(learning_rate, memory_size=10, scale_init_preconditioner=True):
"""
L-BFGS quasi-Newton optimizer.
Args:
learning_rate: Learning rate or schedule
memory_size: Number of previous gradients to store (default: 10)
scale_init_preconditioner: Whether to scale initial preconditioner (default: True)
Returns:
GradientTransformation
"""Yogi optimizer that controls the increase in effective learning rate to avoid rapid convergence.
Yogi optimizer, an Adam variant that controls the increase in effective learning rate for more stable convergence.
def yogi(learning_rate, b1=0.9, b2=0.999, eps=1e-3, initial_accumulator=1e-6):
"""
Yogi optimizer.
Args:
learning_rate: Learning rate or schedule
b1: Exponential decay rate for first moment estimates (default: 0.9)
b2: Exponential decay rate for second moment estimates (default: 0.999)
eps: Small constant for numerical stability (default: 1e-3)
initial_accumulator: Initial value for accumulator (default: 1e-6)
Returns:
GradientTransformation
"""NovoGrad optimizer that combines adaptive learning rates with gradient normalization.
def novograd(learning_rate, b1=0.9, b2=0.25, eps=1e-6, weight_decay=0.):
"""
NovoGrad optimizer.
Args:
learning_rate: Learning rate or schedule
b1: Exponential decay rate for first moment estimates (default: 0.9)
b2: Exponential decay rate for second moment estimates (default: 0.25)
eps: Small constant for numerical stability (default: 1e-6)
weight_decay: Weight decay coefficient (default: 0.0)
Returns:
GradientTransformation
"""Rectified Adam optimizer that addresses the variance issue in early training stages.
def radam(learning_rate, b1=0.9, b2=0.999, eps=1e-8, threshold=5.0):
"""
RAdam (Rectified Adam) optimizer.
Args:
learning_rate: Learning rate or schedule
b1: Exponential decay rate for first moment estimates (default: 0.9)
b2: Exponential decay rate for second moment estimates (default: 0.999)
eps: Small constant for numerical stability (default: 1e-8)
threshold: Threshold for variance tractability (default: 5.0)
Returns:
GradientTransformation
"""SM3 optimizer designed for sparse gradients with memory-efficient second moments.
SM3 optimizer designed for sparse gradients with memory-efficient second moments.
def sm3(learning_rate, momentum=0.9):
"""
SM3 optimizer for sparse gradients.
Args:
learning_rate: Learning rate or schedule
momentum: Momentum coefficient (default: 0.9)
Returns:
GradientTransformation
"""Frobenius matched gradient descent optimizer.
Frobenius matched gradient descent (Fromage) optimizer.
def fromage(learning_rate):
"""
Fromage (Frobenius matched gradient descent) optimizer.
Args:
learning_rate: Learning rate or schedule
Returns:
GradientTransformation
"""SGD with gradient noise injection for improved generalization.
def noisy_sgd(learning_rate, eta=0.01):
"""
Noisy SGD with gradient noise injection.
Args:
learning_rate: Learning rate or schedule
eta: Noise scaling parameter (default: 0.01)
Returns:
GradientTransformation
"""SGD using only the sign of gradients.
def sign_sgd(learning_rate):
"""
Sign SGD optimizer using gradient signs only.
Args:
learning_rate: Learning rate or schedule
Returns:
GradientTransformation
"""SGD with Polyak momentum.
def polyak_sgd(learning_rate, polyak_momentum=0.9):
"""
SGD with Polyak momentum.
Args:
learning_rate: Learning rate or schedule
polyak_momentum: Polyak momentum coefficient (default: 0.9)
Returns:
GradientTransformation
"""Resilient backpropagation optimizer that uses only gradient signs.
Resilient backpropagation (Rprop) optimizer that uses only gradient signs.
def rprop(learning_rate, eta_minus=0.5, eta_plus=1.2, min_step_size=1e-6, max_step_size=50.):
"""
RProp (Resilient backpropagation) optimizer.
Args:
learning_rate: Initial step size
eta_minus: Factor for decreasing step size (default: 0.5)
eta_plus: Factor for increasing step size (default: 1.2)
min_step_size: Minimum step size (default: 1e-6)
max_step_size: Maximum step size (default: 50.0)
Returns:
GradientTransformation
"""Optimistic gradient descent for saddle point problems.
Optimistic gradient descent for saddle point problems.
def optimistic_gradient_descent(learning_rate, alpha=1.0, beta=1.0):
"""
Optimistic gradient descent.
Args:
learning_rate: Learning rate or schedule
alpha: Extrapolation coefficient (default: 1.0)
beta: Update coefficient (default: 1.0)
Returns:
GradientTransformation
"""Optimistic variant of Adam optimizer.
def optimistic_adam(learning_rate, b1=0.9, b2=0.999, eps=1e-8):
"""
Optimistic Adam optimizer.
Args:
learning_rate: Learning rate or schedule
b1: Exponential decay rate for first moment estimates (default: 0.9)
b2: Exponential decay rate for second moment estimates (default: 0.999)
eps: Small constant for numerical stability (default: 1e-8)
Returns:
GradientTransformation
"""Lookahead optimizer that can wrap any base optimizer.
Lookahead optimizer that can wrap any base optimizer.
def lookahead(fast_optimizer, lookahead_steps=5, lookahead_alpha=0.5):
"""
Lookahead optimizer wrapper.
Args:
fast_optimizer: Base optimizer to wrap
lookahead_steps: Number of fast optimizer steps before lookahead (default: 5)
lookahead_alpha: Interpolation factor for lookahead (default: 0.5)
Returns:
GradientTransformation
"""import optax
import jax.numpy as jnp
# Initialize parameters
params = {'weights': jnp.ones((100, 50)), 'bias': jnp.zeros((50,))}
# Advanced optimizers for different scenarios
lion_opt = optax.lion(learning_rate=0.0001) # Memory efficient
lars_opt = optax.lars(learning_rate=0.01) # Large batch training
lamb_opt = optax.lamb(learning_rate=0.001) # Large batch training
lbfgs_opt = optax.lbfgs(learning_rate=1.0) # Second-order method
# Lookahead wrapper
base_opt = optax.adam(learning_rate=0.001)
lookahead_opt = optax.lookahead(base_opt, lookahead_steps=5)
# Initialize states
lion_state = lion_opt.init(params)
lookahead_state = lookahead_opt.init(params)
# Usage in training loop
def training_step(params, opt_state, gradients, optimizer):
    # Pass params so transformations that use weight decay (e.g. lion, lamb) can apply it
    updates, new_opt_state = optimizer.update(gradients, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state
Install with Tessl CLI
npx tessl i tessl/pypi-optax