tessl/pypi-optax

A gradient processing and optimization library in JAX

  • Workspace: tessl
  • Visibility: Public
  • Describes: pkg:pypi/optax@0.2.x (PyPI)

To install, run

npx @tessl/cli install tessl/pypi-optax@0.2.0

Optax

A gradient processing and optimization library in JAX. Optax provides modular building blocks that can be easily recombined to create custom optimizers and gradient processing components. The library offers implementations of many popular optimizers, loss functions, and gradient transformations with a focus on composability and research productivity.

Package Information

  • Package Name: optax
  • Language: Python
  • Installation: pip install optax
  • Documentation: https://optax.readthedocs.io/

Core Imports

import optax

Common usage patterns:

# Import specific optimizers
from optax import adam, sgd, adamw

# Import transformations and utilities
from optax import apply_updates, chain

# Import loss functions
from optax import l2_loss, softmax_cross_entropy

# Import schedules
from optax import linear_schedule, cosine_decay_schedule

Basic Usage

import jax
import jax.numpy as jnp
import optax

# Initialize model parameters
params = {'w': jnp.ones((10,)), 'b': jnp.zeros((1,))}

# Create an optimizer
optimizer = optax.adam(learning_rate=0.001)

# Initialize optimizer state
opt_state = optimizer.init(params)

# Define a simple loss function
def loss_fn(params, x, y):
    pred = params['w'].dot(x) + params['b']
    return optax.l2_loss(pred, y).mean()  # reduce to a scalar so jax.grad works

# Training step
def train_step(params, opt_state, x, y):
    # Compute gradients
    grads = jax.grad(loss_fn)(params, x, y)
    
    # Update parameters
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    
    return params, opt_state

# Example training data
x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0])
y = jnp.array([2.0])

# Perform training step
params, opt_state = train_step(params, opt_state, x, y)
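
In practice the training step is usually compiled with jax.jit and run in a loop. The snippet below is a minimal sketch that extends the example above, reusing the same optimizer, loss_fn, params, opt_state, x, and y.

# Compile the step once; subsequent calls reuse the compiled function
train_step_jit = jax.jit(train_step)

for step in range(100):
    params, opt_state = train_step_jit(params, opt_state, x, y)
    if step % 20 == 0:
        print(step, loss_fn(params, x, y))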

Architecture

Optax is built around three key concepts:

  • GradientTransformation: Core abstraction with init and update functions that process gradients
  • Composability: Transformations can be chained together using optax.chain() to create custom optimizers
  • Modularity: Small building blocks that can be recombined in custom ways for research flexibility

The library provides implementations at multiple levels of abstraction:

  • High-level optimizers (adam, sgd, etc.) that are ready to use
  • Mid-level gradient transformations that can be combined
  • Low-level utilities for building custom components
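
To make the GradientTransformation protocol concrete, here is a minimal, hand-rolled transformation with init and update functions. The helper name flip_sign is purely illustrative; optax.scale(-1.0) provides the same behaviour out of the box.

import jax
import optax

def flip_sign():
    # A stateless transformation: init returns an EmptyState, update negates
    # every leaf of the gradient pytree (turning ascent into descent).
    def init_fn(params):
        return optax.EmptyState()

    def update_fn(updates, state, params=None):
        negated = jax.tree_util.tree_map(lambda g: -g, updates)
        return negated, state

    return optax.GradientTransformation(init_fn, update_fn)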

Capabilities

Core Optimizers

Popular optimization algorithms including Adam, SGD, RMSprop, Adagrad, and many others. These are complete optimizers ready for immediate use in training loops.

def adam(learning_rate, b1=0.9, b2=0.999, eps=1e-8, *, nesterov=False): ...
def sgd(learning_rate, momentum=None, nesterov=False): ...
def adamw(learning_rate, b1=0.9, b2=0.999, eps=1e-8, weight_decay=1e-4, *, nesterov=False): ...
def rmsprop(learning_rate, decay=0.9, eps=1e-8): ...
def adagrad(learning_rate, initial_accumulator_value=0.1, eps=1e-7): ...
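
These optimizers all share the GradientTransformation interface, so swapping one for another only changes the constructor call. A brief sketch, reusing params, loss_fn, x, and y from the basic usage example:

# Drop-in alternatives to optax.adam from the basic usage example
optimizer = optax.sgd(learning_rate=0.01, momentum=0.9)
opt_state = optimizer.init(params)

# AdamW adds decoupled weight decay; its update also needs the current params
optimizer = optax.adamw(learning_rate=3e-4, weight_decay=1e-4)
opt_state = optimizer.init(params)
grads = jax.grad(loss_fn)(params, x, y)
updates, opt_state = optimizer.update(grads, opt_state, params)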

Advanced Optimizers

Specialized and experimental optimization algorithms including second-order methods, adaptive variants, and research optimizers.

def lion(learning_rate, b1=0.9, b2=0.99, weight_decay=0.0): ...
def lars(learning_rate, weight_decay=0., trust_coefficient=0.001, eps=0.): ...
def lamb(learning_rate, b1=0.9, b2=0.999, eps=1e-6, weight_decay=0., mask=None): ...
def lbfgs(learning_rate, ...): ...
def yogi(learning_rate, b1=0.9, b2=0.999, eps=1e-3, initial_accumulator=1e-6): ...
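
These follow the same interface as the core optimizers. A short sketch with lion, reusing params, loss_fn, x, and y from the basic usage example:

optimizer = optax.lion(learning_rate=1e-4)
opt_state = optimizer.init(params)
grads = jax.grad(loss_fn)(params, x, y)
# Pass params to update so weight decay (if enabled) can be applied
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)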

Gradient Transformations

Building blocks for creating custom optimizers including scaling, clipping, noise addition, and momentum accumulation. These can be combined using chain() to build custom optimization strategies.

def scale(step_size): ...
def scale_by_adam(b1=0.9, b2=0.999, eps=1e-8, *, nesterov=False): ...
def clip_by_global_norm(max_norm): ...
def add_decayed_weights(weight_decay, mask=None): ...
def trace(decay, nesterov=False, accumulator_dtype=None): ...
def chain(*transformations): ...
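
For example, a custom Adam-style optimizer with gradient clipping and decoupled weight decay can be assembled from these blocks. A sketch reusing params, loss_fn, x, and y from the basic usage example; note that add_decayed_weights requires params to be passed to update:

custom_optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),         # clip gradients by global norm
    optax.scale_by_adam(b1=0.9, b2=0.999),  # Adam-style rescaling
    optax.add_decayed_weights(1e-4),        # decoupled weight decay
    optax.scale(-1e-3),                     # apply the (negative) learning rate
)

opt_state = custom_optimizer.init(params)
grads = jax.grad(loss_fn)(params, x, y)
updates, opt_state = custom_optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)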

Loss Functions

Comprehensive collection of loss functions for classification, regression, and structured prediction tasks.

def l2_loss(predictions, targets): ...
def softmax_cross_entropy(logits, labels, axis=-1): ...
def sigmoid_binary_cross_entropy(logits, labels): ...
def huber_loss(predictions, targets, delta=1.0): ...
def hinge_loss(scores, labels): ...
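
The losses are computed per example, so a reduction such as mean is usually applied before differentiating. A small sketch with softmax_cross_entropy, which expects labels as probability distributions (e.g. one-hot):

import jax
import jax.numpy as jnp
import optax

logits = jnp.array([[2.0, 1.0, 0.1]])        # shape (batch, num_classes)
labels = jax.nn.one_hot(jnp.array([0]), 3)   # one-hot targets
loss = optax.softmax_cross_entropy(logits, labels).mean()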

Learning Rate Schedules

Flexible scheduling functions for learning rates and other hyperparameters including warmup, decay, and cyclic schedules.

def constant_schedule(value): ...
def linear_schedule(init_value, end_value, transition_steps): ...
def cosine_decay_schedule(init_value, decay_steps, alpha=0.0): ...
def exponential_decay(init_value, decay_rate, transition_steps, ...): ...
def warmup_cosine_decay_schedule(init_value, peak_value, warmup_steps, decay_steps, end_value): ...
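
A schedule is simply a callable from step count to value, and any optimizer that accepts a learning_rate also accepts a schedule in its place. A brief sketch:

schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,      # value at the start of warmup
    peak_value=1e-3,     # learning rate reached after warmup
    warmup_steps=100,
    decay_steps=1000,
    end_value=1e-5,      # final value after cosine decay
)
optimizer = optax.adam(learning_rate=schedule)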

Utilities and Tree Operations

Utility functions for parameter updates, tree operations, numerical stability, and working with JAX pytrees.

def apply_updates(params, updates): ...
def global_norm(updates): ...
def safe_norm(x, min_norm=0.0, ord=None): ...
class GradientTransformation: ...
class OptState: ...
class Params: ...
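
For example, global_norm is handy for monitoring gradient magnitudes before applying updates. A sketch reusing optimizer, opt_state, params, loss_fn, x, and y from the basic usage example:

grads = jax.grad(loss_fn)(params, x, y)
grad_norm = optax.global_norm(grads)     # scalar L2 norm over the whole pytree
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)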

Assignment Operations

Linear assignment algorithms including the Hungarian algorithm for solving optimal assignment problems.

def hungarian_algorithm(cost_matrix): ...
def base_hungarian_algorithm(cost_matrix): ...
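
These functions live in the optax.assignment submodule. A minimal sketch, assuming the (row_indices, col_indices) return convention listed under Types below:

import jax.numpy as jnp
import optax

cost = jnp.array([[4.0, 1.0, 3.0],
                  [2.0, 0.0, 5.0],
                  [3.0, 2.0, 2.0]])
rows, cols = optax.assignment.hungarian_algorithm(cost)
total_cost = cost[rows, cols].sum()   # minimal total assignment cost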

Monte Carlo Gradient Estimation

Utilities for Monte Carlo gradient estimation methods including score function, pathwise, and measure-valued estimators. Note: These functions are deprecated and will be removed in version 0.3.0.

def score_function_jacobians(function, params, dist_builder, rng, num_samples): ...
def pathwise_jacobians(function, params, dist_builder, rng, num_samples): ...
def measure_valued_jacobians(function, params, dist_builder, rng, num_samples, coupling=True): ...

Perturbation-Based Optimization

Utilities for making non-differentiable functions differentiable through stochastic perturbations.

def make_perturbed_fun(fun, num_samples=1000, sigma=0.1, noise=Gumbel(), use_baseline=True): ...
class Gumbel: ...
class Normal: ...

Constraint Projections

Projection functions for enforcing constraints in optimization by projecting parameters onto feasible sets.

def projection_l2_ball(params, radius=1.0): ...
def projection_simplex(params): ...
def projection_box(params, lower=None, upper=None): ...
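
In recent optax versions these are exposed under the optax.projections submodule. A minimal sketch using the default radius/scale of 1:

import jax.numpy as jnp
import optax

x = jnp.array([0.5, 2.0, -1.0])
on_simplex = optax.projections.projection_simplex(x)    # non-negative, sums to 1
in_unit_ball = optax.projections.projection_l2_ball(x)  # L2 norm at most 1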

Second-Order Methods

Utilities for second-order optimization including Hessian computations and Fisher information.

def hessian_diag(fun): ...
def fisher_diag(log_likelihood): ...
def hvp(fun, primals, tangents): ...

Tree Utilities

JAX PyTree manipulation utilities for working with nested parameter structures.

def tree_add(tree_a, tree_b): ...
def tree_scale(tree, scalar): ...
def tree_zeros_like(tree): ...
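
These helpers are exposed under the optax.tree_utils submodule; the exact set of functions there may differ slightly from the summary above. A short sketch using two of them:

import jax.numpy as jnp
import optax

tree_a = {'w': jnp.ones(3), 'b': jnp.zeros(1)}
tree_b = {'w': jnp.full(3, 2.0), 'b': jnp.ones(1)}

summed = optax.tree_utils.tree_add(tree_a, tree_b)   # leaf-wise addition
zeros = optax.tree_utils.tree_zeros_like(tree_a)     # zeros with the same structure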

Experimental Features

The optax.contrib module contains experimental optimizers and techniques under active development, including SAM, Prodigy, Sophia, and schedule-free optimizers.

# Sharpness-Aware Minimization
def sam(base_optimizer, rho=0.05, normalize=True): ...

# Advanced adaptive optimizers
def prodigy(learning_rate=1.0, eps=1e-8, beta1=0.9, beta2=0.999, weight_decay=0.0): ...
def sophia(learning_rate, beta1=0.965, beta2=0.99, eps=1e-8, weight_decay=1e-4): ...

# Schedule-free optimizers
def schedule_free_adamw(learning_rate=0.0025, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.0): ...

Types

# Core type aliases
OptState = chex.ArrayTree  # Optimizer state
Params = chex.ArrayTree    # Model parameters
Updates = Params           # Gradient updates
Schedule = Callable[[chex.Numeric], chex.Numeric]  # Schedule function
ScalarOrSchedule = Union[float, jax.Array, Schedule]

# Core classes
class GradientTransformation(NamedTuple):
    init: Callable[[Params], OptState]
    update: Callable[[Updates, OptState, Optional[Params]], Tuple[Updates, OptState]]

class GradientTransformationExtraArgs(NamedTuple):
    init: Callable[[Params], OptState]
    update: Callable[..., Tuple[Updates, OptState]]  # like update above, but accepts extra keyword arguments
    
class EmptyState(NamedTuple):
    """Empty state for stateless transformations"""
    pass

# Transformation function types
TransformInitFn = Callable[[Params], OptState]
TransformUpdateFn = Callable[[Updates, OptState, Optional[Params]], Tuple[Updates, OptState]]
TransformUpdateExtraArgsFn = Callable[..., Tuple[Updates, OptState]]

# Optimizer state classes
class ScaleByAdamState(NamedTuple):
    count: chex.Array
    mu: Updates
    nu: Updates

class ScaleByRmsState(NamedTuple):
    count: chex.Array
    nu: Updates

class ScaleByScheduleState(NamedTuple):
    count: chex.Array

class FactoredState(NamedTuple):
    v_row: chex.Array
    v_col: chex.Array
    v: chex.Array

class LookaheadParams(NamedTuple):
    slow: Params
LookaheadState = LookaheadParams

class ApplyEvery(NamedTuple):
    count: chex.Array
    grad_acc: Updates

# Tree and projection types  
MaskOrFn = Union[chex.Array, Callable[[Params], chex.Array]]
MaskedNode = Any

# Schedule types
WrappedSchedule = Callable[[chex.Numeric], chex.Numeric]

# Assignment types (from optax.assignment)
CostMatrix = chex.Array
Assignment = Tuple[chex.Array, chex.Array]  # (row_indices, col_indices)

# Monte Carlo types (from optax.monte_carlo) - deprecated
ControlVariate = Tuple[Callable, Callable, Callable]
CvState = Any

# Perturbation types (from optax.perturbations)
NoiseDistribution = Any  # Objects with sample() and log_prob() methods

# Contrib optimizer state classes (experimental)
class ScaleByAdemamixState(NamedTuple):
    count: chex.Array
    mu: Updates
    nu: Updates
    
class MuonState(NamedTuple):
    momentum: Updates
    
class COCOBState(NamedTuple):
    sum_grad_squared: Updates
    sum_grad: Updates
    
class DoGState(NamedTuple):
    momentum: Updates
    
# Additional state classes for contrib optimizers
DAdaptAdamWState = Any
MechanicState = Any
MomoState = Any
MomoAdamState = Any
DoWGState = Any
ScaleBySimplifiedAdEMAMixState = Any
DifferentiallyPrivateAggregateState = Any

# Linesearch types
class ScaleByBacktrackingLinesearchState(NamedTuple):
    count: chex.Array
    f_eval: chex.Array

class ScaleByZoomLinesearchState(NamedTuple):
    count: chex.Array
    f_eval: chex.Array
    
class ZoomLinesearchInfo(NamedTuple):
    failed: bool
    nfev: int
    ngev: int
    k: int