tessl/pypi-optax

A gradient processing and optimization library in JAX

Core Optimizers

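Ready-to-use implementations of popular optimization algorithms for training loops. Each optimizer chains several gradient transformations into a single transformation with sensible defaults. The result exposes an init function, which builds the optimizer state from the parameters, and an update function, which turns gradients into parameter updates.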

Capabilities

Adam Optimizer

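The Adam optimizer, with optional Nesterov momentum. Adam keeps exponential moving averages of the gradients (first moment) and of the squared gradients (second moment), corrects both for their initialization at zero, and divides the first by the square root of the second, which gives each parameter its own effective step size.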

def adam(learning_rate, b1=0.9, b2=0.999, eps=1e-8, eps_root=0.0, mu_dtype=None, *, nesterov=False):
    """
    Adam optimizer.
    
    Args:
        learning_rate: Learning rate or schedule
        b1: Exponential decay rate for first moment estimates (default: 0.9)
        b2: Exponential decay rate for second moment estimates (default: 0.999)
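        eps: Small constant added to the denominator outside the square root, for numerical stability (default: 1e-8)
        eps_root: Small constant added inside the square root; improves numerical stability when differentiating through the optimizer (default: 0.0)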
        mu_dtype: Optional dtype for momentum accumulator (default: None)
        nesterov: Whether to use Nesterov momentum (default: False)
    
    Returns:
        GradientTransformationExtraArgs
    """

AdamW Optimizer

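Adam optimizer with decoupled weight decay. Instead of adding an L2 penalty to the loss, AdamW adds the decay term after Adam's rescaling, so the decay is not divided by the second-moment estimate. Because the decay term reads the current parameters, params must be passed to update.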

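def adamw(learning_rate, b1=0.9, b2=0.999, eps=1e-8, eps_root=0.0, mu_dtype=None, weight_decay=1e-4, mask=None, *, nesterov=False):
    """
    AdamW optimizer with decoupled weight decay.
    
    Args:
        learning_rate: Learning rate or schedule
        b1: Exponential decay rate for first moment estimates (default: 0.9)
        b2: Exponential decay rate for second moment estimates (default: 0.999)
        eps: Small constant added to the denominator outside the square root (default: 1e-8)
        eps_root: Small constant added inside the square root (default: 0.0)
        mu_dtype: Optional dtype for the first-moment accumulator (default: None)
        weight_decay: Weight decay coefficient (default: 1e-4)
        mask: Optional pytree of booleans, or a callable returning one, selecting which parameters are decayed (default: None, decay all)
        nesterov: Whether to use Nesterov momentum (default: False)
    
    Returns:
        GradientTransformationExtraArgs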
    """

Stochastic Gradient Descent

Classic SGD optimizer with optional momentum and Nesterov acceleration.

def sgd(learning_rate, momentum=None, nesterov=False):
    """
    Stochastic gradient descent optimizer.
    
    Args:
        learning_rate: Learning rate or schedule
        momentum: Momentum coefficient (default: None for no momentum)
        nesterov: Whether to use Nesterov momentum (default: False)
    
    Returns:
        GradientTransformation
    """

RMSprop Optimizer

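RMSprop optimizer, which divides each gradient by the square root of an exponential moving average of recent squared gradients, so parameters with consistently large gradients take smaller steps.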

def rmsprop(learning_rate, decay=0.9, eps=1e-8):
    """
    RMSprop optimizer.
    
    Args:
        learning_rate: Learning rate or schedule
        decay: Decay rate for moving average of squared gradients (default: 0.9)
        eps: Small constant for numerical stability (default: 1e-8)
    
    Returns:
        GradientTransformation
    """

Adagrad Optimizer

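Adagrad optimizer, which divides each gradient by the square root of the running sum of all past squared gradients. Because that sum only grows, each parameter's effective step size decreases over time.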

def adagrad(learning_rate, initial_accumulator_value=0.1, eps=1e-7):
    """
    Adagrad optimizer.
    
    Args:
        learning_rate: Learning rate or schedule
        initial_accumulator_value: Initial value for accumulator (default: 0.1)
        eps: Small constant for numerical stability (default: 1e-7)
    
    Returns:
        GradientTransformation
    """

Adadelta Optimizer

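Adadelta optimizer, which scales each step by the ratio of two exponential moving averages: one over past squared updates and one over past squared gradients. This keeps the update in the same units as the parameters and makes the method less sensitive to the choice of learning rate.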

def adadelta(learning_rate=1.0, rho=0.9, eps=1e-6):
    """
    Adadelta optimizer.
    
    Args:
        learning_rate: Learning rate (default: 1.0)
        rho: Decay rate for moving averages (default: 0.9)
        eps: Small constant for numerical stability (default: 1e-6)
    
    Returns:
        GradientTransformation
    """

Adamax Optimizer

Adamax optimizer, a variant of Adam based on the infinity norm.

def adamax(learning_rate, b1=0.9, b2=0.999, eps=1e-8):
    """
    Adamax optimizer.
    
    Args:
        learning_rate: Learning rate or schedule
        b1: Exponential decay rate for first moment estimates (default: 0.9)
        b2: Exponential decay rate for exponentially weighted infinity norm (default: 0.999)
        eps: Small constant for numerical stability (default: 1e-8)
    
    Returns:
        GradientTransformation
    """

Nadam Optimizer

Nadam optimizer: Adam with Nesterov momentum. In Optax it is equivalent to calling adam with nesterov=True.

def nadam(learning_rate, b1=0.9, b2=0.999, eps=1e-8):
    """
    Nadam optimizer (Nesterov-accelerated Adam).
    
    Args:
        learning_rate: Learning rate or schedule
        b1: Exponential decay rate for first moment estimates (default: 0.9)
        b2: Exponential decay rate for second moment estimates (default: 0.999)
        eps: Small constant for numerical stability (default: 1e-8)
    
    Returns:
        GradientTransformation
    """

AdaBelief Optimizer

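AdaBelief optimizer, a variant of Adam that adapts the step size according to its "belief" in the observed gradient. Instead of the squared gradient, its second-moment estimate tracks the squared difference between the gradient and its moving-average prediction, so steps are large when gradients agree with the prediction and small when they deviate from it.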

def adabelief(learning_rate, b1=0.9, b2=0.999, eps=1e-16, eps_root=1e-16, *, nesterov=False):
    """
    AdaBelief optimizer.
    
    Args:
        learning_rate: Learning rate or schedule
        b1: Exponential decay rate for first moment estimates (default: 0.9)
        b2: Exponential decay rate for second moment estimates (default: 0.999)
        eps: Small constant added to the denominator outside the square root (default: 1e-16)
        eps_root: Small constant added to the second-moment estimate inside the square root (default: 1e-16)
        nesterov: Whether to use Nesterov momentum (default: False)
    
    Returns:
        GradientTransformation
    """

Usage Example

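import jax
import jax.numpy as jnp
import optax

# Initialize parameters
params = {'weights': jnp.ones((10, 5)), 'bias': jnp.zeros((5,))}

# Create different optimizers
adam_opt = optax.adam(learning_rate=0.001)
sgd_opt = optax.sgd(learning_rate=0.01, momentum=0.9)
adamw_opt = optax.adamw(learning_rate=0.001, weight_decay=1e-4)

# Initialize optimizer state
adam_state = adam_opt.init(params)
sgd_state = sgd_opt.init(params)
adamw_state = adamw_opt.init(params)

# Toy least-squares loss, used only to produce real gradients
def loss_fn(params, x, y):
    preds = x @ params['weights'] + params['bias']
    return jnp.mean((preds - y) ** 2)

# One training step (example with Adam)
@jax.jit
def training_step(params, opt_state, x, y):
    loss, gradients = jax.value_and_grad(loss_fn)(params, x, y)
    # params is optional for Adam but required by AdamW, whose weight decay reads it
    updates, new_opt_state = adam_opt.update(gradients, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, loss

x, y = jnp.ones((4, 10)), jnp.zeros((4, 5))
params, adam_state, loss = training_step(params, adam_state, x, y)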

Install with Tessl CLI

npx tessl i tessl/pypi-optax