A gradient processing and optimization library in JAX
npx @tessl/cli install tessl/pypi-optax@0.2.0

A gradient processing and optimization library in JAX. Optax provides modular building blocks that can be easily recombined to create custom optimizers and gradient processing components. The library offers implementations of many popular optimizers, loss functions, and gradient transformations with a focus on composability and research productivity.
pip install optax

import optax

Common usage patterns:
# Import specific optimizers
from optax import adam, sgd, adamw
# Import transformations and utilities
from optax import apply_updates, chain
# Import loss functions
from optax import l2_loss, softmax_cross_entropy
# Import schedules
from optax import linear_schedule, cosine_decay_schedule

import jax
import jax.numpy as jnp
import optax
# Initialize model parameters
params = {'w': jnp.ones((10,)), 'b': jnp.zeros((1,))}
# Create an optimizer
optimizer = optax.adam(learning_rate=0.001)
# Initialize optimizer state
opt_state = optimizer.init(params)
# Define a simple loss function
def loss_fn(params, x, y):
    pred = params['w'].dot(x) + params['b']
    # Reduce to a scalar so the loss can be differentiated with jax.grad
    return optax.l2_loss(pred, y).mean()
# Training step
def train_step(params, opt_state, x, y):
    # Compute gradients
    grads = jax.grad(loss_fn)(params, x, y)
    # Update parameters
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state
# Example training data
x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0])
y = jnp.array([2.0])
# Perform training step
params, opt_state = train_step(params, opt_state, x, y)

Optax is built around a few key concepts: gradient transformations defined by paired init and update functions that process gradients, and optax.chain() to combine these transformations into custom optimizers.

The library provides implementations at multiple levels of abstraction:
Popular optimization algorithms including Adam, SGD, RMSprop, Adagrad, and many others. These are complete optimizers ready for immediate use in training loops.
def adam(learning_rate, b1=0.9, b2=0.999, eps=1e-8, *, nesterov=False): ...
def sgd(learning_rate, momentum=None, nesterov=False): ...
def adamw(learning_rate, b1=0.9, b2=0.999, eps=1e-8, weight_decay=1e-4, *, nesterov=False): ...
def rmsprop(learning_rate, decay=0.9, eps=1e-8): ...
def adagrad(learning_rate, initial_accumulator_value=0.1, eps=1e-7): ...

Specialized and experimental optimization algorithms including second-order methods, adaptive variants, and research optimizers.
def lion(learning_rate, b1=0.9, b2=0.99, weight_decay=0.0): ...
def lars(learning_rate, weight_decay=0., trust_coefficient=0.001, eps=0.): ...
def lamb(learning_rate, b1=0.9, b2=0.999, eps=1e-6, weight_decay=0., mask=None): ...
def lbfgs(learning_rate, ...): ...
def yogi(learning_rate, b1=0.9, b2=0.999, eps=1e-3, initial_accumulator=1e-6): ...

Building blocks for creating custom optimizers, including scaling, clipping, noise addition, and momentum accumulation. These can be combined using chain() to build custom optimization strategies.
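For example, a minimal sketch of composing the transformations listed below into an AdamW-style optimizer (the hyperparameter values are illustrative):

import optax

learning_rate = 1e-3
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),   # clip gradients by their global norm
    optax.scale_by_adam(),            # rescale using Adam's moment estimates
    optax.add_decayed_weights(1e-4),  # add a weight-decay term to the updates
    optax.scale(-learning_rate),      # flip the sign and apply the learning rate (descent)
)
# Because add_decayed_weights needs the parameter values, call
# optimizer.update(grads, opt_state, params) with params supplied.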
def scale(step_size): ...
def scale_by_adam(b1=0.9, b2=0.999, eps=1e-8, *, nesterov=False): ...
def clip_by_global_norm(max_norm): ...
def add_decayed_weights(weight_decay, mask=None): ...
def trace(decay, nesterov=False, accumulator_dtype=None): ...
def chain(*transformations): ...

Comprehensive collection of loss functions for classification, regression, and structured prediction tasks.
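As a quick illustration of two of the losses listed below: they are plain functions of arrays and typically return per-example (or per-element) values that you reduce yourself (a minimal sketch):

import jax.numpy as jnp
import optax

logits = jnp.array([[2.0, 0.5, -1.0], [0.1, 1.2, 0.3]])  # (batch, classes)
labels = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])   # one-hot targets
# softmax_cross_entropy returns one value per example; reduce to a scalar loss
ce = optax.softmax_cross_entropy(logits, labels).mean()

preds = jnp.array([1.5, -0.3])
targets = jnp.array([1.0, 0.0])
# l2_loss returns the element-wise squared error scaled by 0.5
reg = optax.l2_loss(preds, targets).sum()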
def l2_loss(predictions, targets): ...
def softmax_cross_entropy(logits, labels, axis=-1): ...
def sigmoid_binary_cross_entropy(logits, labels): ...
def huber_loss(predictions, targets, delta=1.0): ...
def hinge_loss(scores, labels): ...

Flexible scheduling functions for learning rates and other hyperparameters, including warmup, decay, and cyclic schedules.
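Schedules are plain functions from a step count to a value, and most optimizers accept a schedule directly as their learning_rate. A minimal sketch using two of the schedules listed below:

import optax

# Linearly anneal the learning rate from 1e-3 to 1e-4 over 1000 steps
schedule = optax.linear_schedule(init_value=1e-3, end_value=1e-4, transition_steps=1000)
values = [schedule(0), schedule(500), schedule(1000)]  # 1e-3, 5.5e-4, 1e-4

# Warmup followed by cosine decay, passed directly to an optimizer
lr = optax.warmup_cosine_decay_schedule(
    init_value=0.0, peak_value=1e-3, warmup_steps=100,
    decay_steps=10_000, end_value=1e-5,
)
optimizer = optax.adam(learning_rate=lr)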
def constant_schedule(value): ...
def linear_schedule(init_value, end_value, transition_steps): ...
def cosine_decay_schedule(init_value, decay_steps, alpha=0.0): ...
def exponential_decay(init_value, decay_rate, transition_steps, ...): ...
def warmup_cosine_decay_schedule(init_value, peak_value, warmup_steps, decay_steps, end_value): ...

Utility functions for parameter updates, tree operations, numerical stability, and working with JAX pytrees.
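For instance, global_norm is convenient for monitoring gradient magnitudes across a whole pytree (a minimal sketch using the utilities listed below):

import jax.numpy as jnp
import optax

grads = {'w': jnp.ones((10,)), 'b': jnp.array([0.5])}
# Global L2 norm over every leaf in the gradient pytree: sqrt(10 * 1.0**2 + 0.5**2)
norm = optax.global_norm(grads)
# safe_norm avoids NaN gradients when the norm is close to zero
safe = optax.safe_norm(jnp.zeros(3), min_norm=1e-6)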
def apply_updates(params, updates): ...
def global_norm(updates): ...
def safe_norm(x, min_norm=0.0, ord=None): ...
class GradientTransformation: ...
class OptState: ...
class Params: ...

Linear assignment algorithms, including the Hungarian algorithm for solving optimal assignment problems.
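A small sketch of solving an assignment problem, following the signature and the (row_indices, col_indices) return convention documented here; it assumes the functions are exposed under optax.assignment, as in recent releases:

import jax.numpy as jnp
import optax

# Cost of assigning each worker (row) to each task (column)
cost_matrix = jnp.array([[4.0, 1.0, 3.0],
                         [2.0, 0.0, 5.0],
                         [3.0, 2.0, 2.0]])
# Matched (row_indices, col_indices) minimizing the total assignment cost
row_ind, col_ind = optax.assignment.hungarian_algorithm(cost_matrix)
total_cost = cost_matrix[row_ind, col_ind].sum()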
def hungarian_algorithm(cost_matrix): ...
def base_hungarian_algorithm(cost_matrix): ...

Utilities for Monte Carlo gradient estimation methods, including score function, pathwise, and measure-valued estimators. Note: these functions are deprecated and will be removed in version 0.3.0.
def score_function_jacobians(function, params, dist_builder, rng, num_samples): ...
def pathwise_jacobians(function, params, dist_builder, rng, num_samples): ...
def measure_valued_jacobians(function, params, dist_builder, rng, num_samples, coupling=True): ...

Utilities for making non-differentiable functions differentiable through stochastic perturbations.
def make_perturbed_fun(fun, num_samples=1000, sigma=0.1, noise=Gumbel(), use_baseline=True): ...
class Gumbel: ...
class Normal: ...

Projection functions for enforcing constraints in optimization by projecting parameters onto feasible sets.
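A sketch following the signatures listed below; it assumes these functions are exposed under optax.projections, as in recent releases, and keyword names may differ between versions:

import jax.numpy as jnp
import optax

v = jnp.array([0.8, 2.0, -0.4])
# Project onto the probability simplex: non-negative entries that sum to 1
p = optax.projections.projection_simplex(v)
# Clip every entry into the box constraint [0, 1]
boxed = optax.projections.projection_box(v, lower=0.0, upper=1.0)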
def projection_l2_ball(params, radius=1.0): ...
def projection_simplex(params): ...
def projection_box(params, lower=None, upper=None): ...

Utilities for second-order optimization, including Hessian computations and Fisher information.
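These helpers build on JAX's composable differentiation. As a concept-only sketch in plain JAX (not the optax API itself), a Hessian-vector product can be computed without ever materializing the Hessian:

import jax
import jax.numpy as jnp

def loss(w):
    return jnp.sum(jnp.tanh(w) ** 2)

w = jnp.array([0.1, -0.5, 2.0])
v = jnp.array([1.0, 0.0, 0.0])
# Forward-over-reverse: differentiate grad(loss) along the direction v
hvp_value = jax.jvp(jax.grad(loss), (w,), (v,))[1]  # equals H(w) @ v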
def hessian_diag(fun): ...
def fisher_diag(log_likelihood): ...
def hvp(fun, primals, tangents): ...

JAX pytree manipulation utilities for working with nested parameter structures.
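A minimal sketch of leaf-wise pytree arithmetic; it assumes these helpers are exposed under optax.tree_utils, as in recent releases, where the flat names listed below may map to slightly different names:

import jax.numpy as jnp
import optax

a = {'w': jnp.ones((3,)), 'b': jnp.zeros((1,))}
b = {'w': jnp.full((3,), 2.0), 'b': jnp.ones((1,))}
# Leaf-wise addition of two pytrees with identical structure
summed = optax.tree_utils.tree_add(a, b)
# A pytree of zeros with the same structure and shapes as `a`
zeros = optax.tree_utils.tree_zeros_like(a)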
def tree_add(tree_a, tree_b): ...
def tree_scale(tree, scalar): ...
def tree_zeros_like(tree): ...

The optax.contrib module contains experimental optimizers and techniques under active development, including SAM, Prodigy, Sophia, and schedule-free optimizers.
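Contrib APIs are experimental and change between releases. As a hedged sketch, a contrib optimizer drops into the same init/update loop as any other optimizer; keyword arguments follow the signatures listed below and may differ in your installed version:

import jax.numpy as jnp
import optax

params = {'w': jnp.ones((4,))}
# Prodigy adapts its own step size; learning_rate acts as a multiplier
optimizer = optax.contrib.prodigy(learning_rate=1.0)
opt_state = optimizer.init(params)

grads = {'w': jnp.full((4,), 0.1)}
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)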
# Sharpness-Aware Minimization
def sam(base_optimizer, rho=0.05, normalize=True): ...
# Advanced adaptive optimizers
def prodigy(learning_rate=1.0, eps=1e-8, beta1=0.9, beta2=0.999, weight_decay=0.0): ...
def sophia(learning_rate, beta1=0.965, beta2=0.99, eps=1e-8, weight_decay=1e-4): ...
# Schedule-free optimizers
def schedule_free_adamw(learning_rate=0.0025, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.0): ...

# Core type aliases
OptState = chex.ArrayTree # Optimizer state
Params = chex.ArrayTree # Model parameters
Updates = Params # Gradient updates
Schedule = Callable[[chex.Numeric], chex.Numeric] # Schedule function
ScalarOrSchedule = Union[float, jax.Array, Schedule]
# Core classes
class GradientTransformation(NamedTuple):
    init: Callable[[Params], OptState]
    update: Callable[[Updates, OptState, Optional[Params]], Tuple[Updates, OptState]]

class GradientTransformationExtraArgs(NamedTuple):
    init: Callable[[Params], OptState]
    update: Callable[..., Tuple[Updates, OptState]]  # like update above, but also accepts extra keyword arguments

class EmptyState(NamedTuple):
    """Empty state for stateless transformations."""
# Transformation function types
TransformInitFn = Callable[[Params], OptState]
TransformUpdateFn = Callable[[Updates, OptState, Optional[Params]], Tuple[Updates, OptState]]
TransformUpdateExtraArgsFn = Callable[..., Tuple[Updates, OptState]]
# Optimizer state classes
class ScaleByAdamState(NamedTuple):
    count: chex.Array
    mu: Updates
    nu: Updates

class ScaleByRmsState(NamedTuple):
    count: chex.Array
    nu: Updates

class ScaleByScheduleState(NamedTuple):
    count: chex.Array

class FactoredState(NamedTuple):
    v_row: chex.Array
    v_col: chex.Array
    v: chex.Array
class LookaheadParams(NamedTuple):
    fast: Params
    slow: Params

LookaheadState = LookaheadParams

class ApplyEvery(NamedTuple):
    count: chex.Array
    grad_acc: Updates
# Tree and projection types
MaskOrFn = Union[chex.Array, Callable[[Params], chex.Array]]
MaskedNode = Any
# Schedule types
WrappedSchedule = Callable[[chex.Numeric], chex.Numeric]
# Assignment types (from optax.assignment)
CostMatrix = chex.Array
Assignment = Tuple[chex.Array, chex.Array] # (row_indices, col_indices)
# Monte Carlo types (from optax.monte_carlo) - deprecated
ControlVariate = Tuple[Callable, Callable, Callable]
CvState = Any
# Perturbation types (from optax.perturbations)
NoiseDistribution = Any # Objects with sample() and log_prob() methods
# Contrib optimizer state classes (experimental)
class ScaleByAdemamixState(NamedTuple):
    count: chex.Array
    mu: Updates
    nu: Updates

class MuonState(NamedTuple):
    momentum: Updates

class COCOBState(NamedTuple):
    sum_grad_squared: Updates
    sum_grad: Updates

class DoGState(NamedTuple):
    momentum: Updates
# Additional state classes for contrib optimizers
DAdaptAdamWState = Any
MechanicState = Any
MomoState = Any
MomoAdamState = Any
DoWGState = Any
ScaleBySimplifiedAdEMAMixState = Any
DifferentiallyPrivateAggregateState = Any
# Linesearch types
class ScaleByBacktrackingLinesearchState(NamedTuple):
    count: chex.Array
    f_eval: chex.Array

class ScaleByZoomLinesearchState(NamedTuple):
    count: chex.Array
    f_eval: chex.Array

class ZoomLinesearchInfo(NamedTuple):
    failed: bool
    nfev: int
    ngev: int
    k: int
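To tie the core types together, here is a minimal sketch of a hand-written, stateless gradient transformation built from the init/update protocol; the names scale_by_constant and factor are illustrative, not part of the optax API:

import jax
import optax

def scale_by_constant(factor: float) -> optax.GradientTransformation:
    """A stateless transformation that multiplies every update by `factor`."""

    def init_fn(params):
        del params  # no state is needed for a pure rescaling
        return optax.EmptyState()

    def update_fn(updates, state, params=None):
        del params  # unused, kept to satisfy the update signature
        updates = jax.tree_util.tree_map(lambda g: factor * g, updates)
        return updates, state

    return optax.GradientTransformation(init_fn, update_fn)

# It composes with built-in transformations like any other building block
optimizer = optax.chain(scale_by_constant(0.5), optax.scale(-1e-3))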