A gradient processing and optimization library in JAX
Popular optimization algorithms that are ready for immediate use in training loops. These optimizers combine multiple gradient transformations into complete optimization strategies with sensible defaults.
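For example, adam below is roughly the result of chaining its building blocks by hand (a sketch; the 1e-3 learning rate is illustrative):

import optax

# Roughly what optax.adam(1e-3) assembles internally:
optimizer = optax.chain(
    optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8),  # adaptive per-parameter scaling
    optax.scale(-1e-3),                               # step in the descent direction
)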
The Adam optimizer with optional Nesterov momentum. Combines adaptive learning rates with momentum for efficient optimization across a wide range of problems.
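For instance, with optax imported as above, the Nesterov variant is enabled with a single flag (learning rate illustrative; the full signature follows):

optimizer = optax.adam(learning_rate=1e-3, nesterov=True)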
def adam(learning_rate, b1=0.9, b2=0.999, eps=1e-8, eps_root=0.0, mu_dtype=None, *, nesterov=False):
"""
Adam optimizer.
Args:
learning_rate: Learning rate or schedule
b1: Exponential decay rate for first moment estimates (default: 0.9)
b2: Exponential decay rate for second moment estimates (default: 0.999)
eps: Small constant for numerical stability (default: 1e-8)
eps_root: Small constant for numerical stability in denominator (default: 0.0)
mu_dtype: Optional dtype for momentum accumulator (default: None)
nesterov: Whether to use Nesterov momentum (default: False)
Returns:
GradientTransformationExtraArgs
"""Adam optimizer with decoupled weight decay. Separates weight decay from gradient-based updates for better generalization.
def adamw(learning_rate, b1=0.9, b2=0.999, eps=1e-8, weight_decay=1e-4, *, nesterov=False):
"""
AdamW optimizer with decoupled weight decay.
Args:
learning_rate: Learning rate or schedule
b1: Exponential decay rate for first moment estimates (default: 0.9)
b2: Exponential decay rate for second moment estimates (default: 0.999)
eps: Small constant for numerical stability (default: 1e-8)
weight_decay: Weight decay coefficient (default: 1e-4)
nesterov: Whether to use Nesterov momentum (default: False)
Returns:
GradientTransformation
"""Classic SGD optimizer with optional momentum and Nesterov acceleration.
def sgd(learning_rate, momentum=None, nesterov=False):
"""
Stochastic gradient descent optimizer.
Args:
learning_rate: Learning rate or schedule
momentum: Momentum coefficient (default: None for no momentum)
nesterov: Whether to use Nesterov momentum (default: False)
Returns:
GradientTransformation
"""RMSprop optimizer with adaptive learning rates based on recent gradient magnitudes.
def rmsprop(learning_rate, decay=0.9, eps=1e-8):
"""
RMSprop optimizer.
Args:
learning_rate: Learning rate or schedule
decay: Decay rate for moving average of squared gradients (default: 0.9)
eps: Small constant for numerical stability (default: 1e-8)
Returns:
GradientTransformation
"""Adagrad optimizer with adaptive learning rates that decrease over time.
def adagrad(learning_rate, initial_accumulator_value=0.1, eps=1e-7):
"""
Adagrad optimizer.
Args:
learning_rate: Learning rate or schedule
initial_accumulator_value: Initial value for accumulator (default: 0.1)
eps: Small constant for numerical stability (default: 1e-7)
Returns:
GradientTransformation
"""Adadelta optimizer that adapts learning rates based on a moving window of gradient updates.
def adadelta(learning_rate=1.0, rho=0.9, eps=1e-6):
"""
Adadelta optimizer.
Args:
learning_rate: Learning rate (default: 1.0)
rho: Decay rate for moving averages (default: 0.9)
eps: Small constant for numerical stability (default: 1e-6)
Returns:
GradientTransformation
"""Adamax optimizer, a variant of Adam based on the infinity norm.
def adamax(learning_rate, b1=0.9, b2=0.999, eps=1e-8):
"""
Adamax optimizer.
Args:
learning_rate: Learning rate or schedule
b1: Exponential decay rate for first moment estimates (default: 0.9)
b2: Exponential decay rate for exponentially weighted infinity norm (default: 0.999)
eps: Small constant for numerical stability (default: 1e-8)
Returns:
GradientTransformation
"""Nesterov-accelerated Adam optimizer combining Adam with Nesterov momentum.
def nadam(learning_rate, b1=0.9, b2=0.999, eps=1e-8):
"""
Nadam optimizer (Nesterov-accelerated Adam).
Args:
learning_rate: Learning rate or schedule
b1: Exponential decay rate for first moment estimates (default: 0.9)
b2: Exponential decay rate for second moment estimates (default: 0.999)
eps: Small constant for numerical stability (default: 1e-8)
Returns:
GradientTransformation
"""AdaBelief optimizer that adapts the step size according to the "belief" in the observed gradients.
def adabelief(learning_rate, b1=0.9, b2=0.999, eps=1e-16, eps_root=1e-16, *, nesterov=False):
"""
AdaBelief optimizer.
Args:
learning_rate: Learning rate or schedule
b1: Exponential decay rate for first moment estimates (default: 0.9)
b2: Exponential decay rate for second moment estimates (default: 0.999)
eps: Small constant for numerical stability (default: 1e-16)
eps_root: Small constant for numerical stability in denominator (default: 1e-16)
nesterov: Whether to use Nesterov momentum (default: False)
Returns:
GradientTransformation
"""import optax
import jax.numpy as jnp
# Initialize parameters
params = {'weights': jnp.ones((10, 5)), 'bias': jnp.zeros((5,))}
# Create different optimizers
adam_opt = optax.adam(learning_rate=0.001)
sgd_opt = optax.sgd(learning_rate=0.01, momentum=0.9)
adamw_opt = optax.adamw(learning_rate=0.001, weight_decay=1e-4)
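# A schedule can be passed wherever a learning rate is accepted (a sketch;
# exponential_decay is one of several schedules optax provides).
lr_schedule = optax.exponential_decay(init_value=1e-3, transition_steps=1000, decay_rate=0.99)
scheduled_opt = optax.adam(learning_rate=lr_schedule)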
# Initialize optimizer state
adam_state = adam_opt.init(params)
sgd_state = sgd_opt.init(params)
adamw_state = adamw_opt.init(params)
# In training loop (example with Adam)
def training_step(params, opt_state, gradients):
    updates, new_opt_state = adam_opt.update(gradients, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state

Install with Tessl CLI
npx tessl i tessl/pypi-optax