tessl/pypi-optax

A gradient processing and optimization library in JAX

Advanced Optimizers

Specialized and experimental optimization algorithms. These include a quasi-Newton method (L-BFGS), layer-wise adaptive methods for large-batch training (LARS, LAMB), Adam variants (Yogi, RAdam), optimistic methods for min-max problems, and a Lookahead wrapper. These optimizers often need more careful hyperparameter tuning than the core optimizers.

Capabilities

Lion Optimizer

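Lion (EvoLved Sign Momentum) uses the sign of an interpolated momentum as its update, so every coordinate moves by the same magnitude. It keeps a single moment buffer per parameter, so its optimizer state is about half the size of Adam's.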

def lion(learning_rate, b1=0.9, b2=0.99, weight_decay=1e-3):
    """
    Lion optimizer (EvoLved Sign Momentum).
    
    Args:
        learning_rate: Learning rate or schedule
        b1: Interpolation rate between the momentum and the current gradient when forming the update direction (default: 0.9)
        b2: Exponential decay rate of the momentum buffer (default: 0.99)
        weight_decay: Decoupled weight decay coefficient (default: 1e-3)
    
    Returns:
        GradientTransformation
    """

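The NumPy sketch below illustrates the update rule by performing one Lion step on a single array. It follows the rule from the Lion paper rather than optax's own code, and the helper name lion_step is hypothetical. The update is the sign of an interpolation between the momentum buffer and the current gradient (controlled by b1), and weight decay is added in decoupled form. Only after that is the momentum buffer refreshed using b2.

```python
import numpy as np


def lion_step(param, grad, momentum, learning_rate, b1=0.9, b2=0.99, weight_decay=1e-3):
    """One Lion step for a single array (hypothetical helper, not optax's code)."""
    # Update direction: sign of an interpolation between momentum and gradient (b1).
    direction = np.sign(b1 * momentum + (1.0 - b1) * grad)
    # Decoupled weight decay is added to the sign update, then scaled by the learning rate.
    new_param = param - learning_rate * (direction + weight_decay * param)
    # The single momentum buffer is an exponential moving average of gradients (b2).
    new_momentum = b2 * momentum + (1.0 - b2) * grad
    return new_param, new_momentum


# Gradients of very different magnitudes produce steps of identical size.
param = np.array([1.0, -2.0, 0.5])
grad = np.array([0.3, -10.0, 0.0])
new_param, new_momentum = lion_step(param, grad, np.zeros(3), learning_rate=0.1, weight_decay=0.0)
print(new_param)
```

Because of the sign, every coordinate with a nonzero direction moves by exactly learning_rate (plus the decay term). This is why Lion is usually run with a smaller learning rate than Adam.
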
LARS Optimizer

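Layer-wise Adaptive Rate Scaling (LARS) for large-batch training. Each layer's update is rescaled by a trust ratio proportional to the norm of the layer's parameters divided by the norm of its update, so all layers take steps of comparable relative size.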

def lars(learning_rate, weight_decay=0., trust_coefficient=0.001, eps=0.):
    """
    LARS (Layer-wise Adaptive Rate Scaling) optimizer.
    
    Args:
        learning_rate: Learning rate or schedule
        weight_decay: Weight decay coefficient (default: 0.0)
        trust_coefficient: Trust coefficient for layer-wise adaptation (default: 0.001)
        eps: Small constant for numerical stability (default: 0.0)
    
    Returns:
        GradientTransformation
    """

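The per-layer trust ratio is the core of LARS. The minimal NumPy sketch below computes it for one parameter array. It assumes weight decay is folded into the update before the ratio is computed, as in the LARS paper. The learning-rate scaling and momentum that would follow are omitted, and lars_scaled_update is a hypothetical helper name.

```python
import numpy as np


def lars_scaled_update(param, grad, weight_decay=0.0, trust_coefficient=0.001, eps=0.0):
    """Scale one layer's update by its LARS trust ratio (hypothetical helper, sketch only)."""
    update = grad + weight_decay * param
    param_norm = np.linalg.norm(param)
    update_norm = np.linalg.norm(update)
    if param_norm == 0.0 or update_norm == 0.0:
        # Fall back to the unscaled update when either norm is zero.
        return update
    trust_ratio = trust_coefficient * param_norm / (update_norm + eps)
    return trust_ratio * update


# Two layers with identical weights but gradients 1000x apart in scale.
w = np.ones(4)
small = lars_scaled_update(w, np.full(4, 0.001))
large = lars_scaled_update(w, np.full(4, 1.0))
print(np.linalg.norm(small), np.linalg.norm(large))
```

Both layers end up with an update norm of trust_coefficient * ||w|| = 0.002, even though their raw gradients differ by a factor of 1000.
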
LAMB Optimizer

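LAMB (Layer-wise Adaptive Moments optimizer for Batch training) combines Adam's moment estimates with a LARS-style per-layer trust ratio. It was designed for training with large batch sizes.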

def lamb(learning_rate, b1=0.9, b2=0.999, eps=1e-6, weight_decay=0., mask=None):
    """
    LAMB (Layer-wise Adaptive Moments optimizer for Batch training) optimizer.
    
    Args:
        learning_rate: Learning rate or schedule
        b1: Exponential decay rate for first moment estimates (default: 0.9)
        b2: Exponential decay rate for second moment estimates (default: 0.999)
        eps: Small constant for numerical stability (default: 1e-6)
        weight_decay: Weight decay coefficient (default: 0.0)
        mask: Optional mask for parameter selection
    
    Returns:
        GradientTransformation
    """

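One LAMB step for a single layer can be sketched in NumPy as follows. It computes Adam's bias-corrected moments, adds weight decay to the resulting direction, then rescales the step by ||w|| / ||r||. This follows the published rule rather than optax's code, and lamb_step is a hypothetical name.

```python
import numpy as np


def lamb_step(param, grad, m, v, t, learning_rate,
              b1=0.9, b2=0.999, eps=1e-6, weight_decay=0.0):
    """One LAMB step for a single layer (hypothetical helper); t is the 1-based step count."""
    # Adam's first and second moment estimates with bias correction.
    m = b1 * m + (1.0 - b1) * grad
    v = b2 * v + (1.0 - b2) * grad ** 2
    m_hat = m / (1.0 - b1 ** t)
    v_hat = v / (1.0 - b2 ** t)
    # Adam direction plus weight decay.
    r = m_hat / (np.sqrt(v_hat) + eps) + weight_decay * param
    # Layer-wise trust ratio; fall back to 1 if either norm is zero.
    w_norm, r_norm = np.linalg.norm(param), np.linalg.norm(r)
    trust_ratio = w_norm / r_norm if w_norm > 0 and r_norm > 0 else 1.0
    return param - learning_rate * trust_ratio * r, m, v


param = np.array([3.0, 4.0])
new_param, m, v = lamb_step(param, np.array([0.1, -0.2]), np.zeros(2), np.zeros(2),
                            t=1, learning_rate=0.01)
print(np.linalg.norm(new_param - param))
```

The step has norm learning_rate * ||w||, so each layer changes by the same relative amount regardless of its gradient scale.
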
L-BFGS Optimizer

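Limited-memory Broyden-Fletcher-Goldfarb-Shanno (L-BFGS) quasi-Newton method. It approximates the inverse Hessian from a short history of parameter and gradient differences. By default it chooses the step length with a zoom line search.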

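def lbfgs(learning_rate=None, memory_size=10, scale_init_precond=True,
          linesearch=optax.scale_by_zoom_linesearch(max_linesearch_steps=20, initial_guess_strategy='one')):
    """
    L-BFGS quasi-Newton optimizer.
    
    Args:
        learning_rate: Optional learning rate or schedule; if None, no extra scaling is applied and the step length comes from the line search (default: None)
        memory_size: Number of past (parameter difference, gradient difference) pairs used to approximate the inverse Hessian (default: 10)
        scale_init_precond: Whether to scale the initial inverse-Hessian approximation using the most recent curvature pair (default: True)
        linesearch: Line search transformation, or None to disable it (default: zoom line search)
    
    Returns:
        GradientTransformationExtraArgs; with a line search, update() also needs params and the value, grad and value_fn keyword arguments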
    """

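The heart of L-BFGS is the two-loop recursion. It applies the implicit inverse-Hessian approximation, built from recent pairs s_k = x_{k+1} - x_k and y_k = g_{k+1} - g_k, to the current gradient. Below is a NumPy sketch of that recursion only, without the line search. The function name lbfgs_direction and its scale_init flag (standing in for the initial-preconditioner scaling option) are hypothetical.

```python
import numpy as np


def lbfgs_direction(grad, s_list, y_list, scale_init=True):
    """Return the L-BFGS direction -H @ grad via the two-loop recursion (hypothetical helper).

    s_list and y_list hold recent parameter and gradient differences, oldest first.
    """
    q = grad.copy()
    rhos = [1.0 / np.dot(y, s) for s, y in zip(s_list, y_list)]
    alphas = []
    # First loop: newest pair to oldest.
    for s, y, rho in reversed(list(zip(s_list, y_list, rhos))):
        alpha = rho * np.dot(s, q)
        alphas.append(alpha)
        q = q - alpha * y
    # Initial inverse-Hessian approximation: a (possibly scaled) identity.
    gamma = 1.0
    if scale_init and s_list:
        gamma = np.dot(s_list[-1], y_list[-1]) / np.dot(y_list[-1], y_list[-1])
    r = gamma * q
    # Second loop: oldest pair to newest (alphas were stored newest first).
    for (s, y, rho), alpha in zip(zip(s_list, y_list, rhos), reversed(alphas)):
        beta = rho * np.dot(y, r)
        r = r + s * (alpha - beta)
    return -r


# Curvature pairs from the quadratic f(x) = 0.5 * x @ A @ x, where y = A @ s.
A = np.diag([1.0, 10.0])
s_list = [np.array([1.0, 0.0]), np.array([0.5, 0.5])]
y_list = [A @ s for s in s_list]
direction = lbfgs_direction(np.array([1.0, 1.0]), s_list, y_list)
print(direction)
```

The implicit matrix H satisfies the secant equation H y = s for the newest pair. It stays positive definite while every pair has y·s > 0, so the result is a descent direction.
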
Yogi Optimizer

Yogi is an Adam variant whose second-moment estimate changes additively, by at most (1 - b2) * g^2 per step, in a direction set by whether the estimate is above or below g^2. This prevents the effective learning rate from increasing abruptly, a convergence problem of exponential-moving-average methods such as Adam.

def yogi(learning_rate, b1=0.9, b2=0.999, eps=1e-3):
    """
    Yogi optimizer.
    
    Args:
        learning_rate: Learning rate or schedule
        b1: Exponential decay rate for first moment estimates (default: 0.9)
        b2: Exponential decay rate for second moment estimates (default: 0.999)
        eps: Small constant for numerical stability (default: 1e-3)
    
    Returns:
        GradientTransformation
    """

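Yogi differs from Adam only in the second-moment update. The NumPy sketch below compares the two rules, following the Yogi paper (both function names are hypothetical). It starts from a large estimate, as after a burst of large gradients, and then feeds in small gradients.

```python
import numpy as np


def adam_second_moment(v, grad, b2=0.999):
    # Adam: exponential moving average of squared gradients.
    return b2 * v + (1.0 - b2) * grad ** 2


def yogi_second_moment(v, grad, b2=0.999):
    # Yogi: additive update whose sign depends on whether v is above or below g**2,
    # so v changes by at most (1 - b2) * g**2 per step.
    g2 = grad ** 2
    return v - (1.0 - b2) * np.sign(v - g2) * g2


v_adam = v_yogi = np.array(100.0)
for _ in range(1000):
    v_adam = adam_second_moment(v_adam, np.array(0.1))
    v_yogi = yogi_second_moment(v_yogi, np.array(0.1))
print(v_adam, v_yogi)
```

Adam's estimate decays to roughly a third of its starting value, which would raise the effective learning rate by about 1.6x. Yogi's estimate barely moves.
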
NovoGrad Optimizer

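NovoGrad uses layer-wise second moments. Each layer's gradient is normalized by a running average of that layer's squared gradient norm, weight decay is added to the normalized gradient, and the result is accumulated into momentum.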

def novograd(learning_rate, b1=0.9, b2=0.25, eps=1e-6, weight_decay=0.):
    """
    NovoGrad optimizer.
    
    Args:
        learning_rate: Learning rate or schedule
        b1: Exponential decay rate for first moment estimates (default: 0.9)
        b2: Exponential decay rate for second moment estimates (default: 0.25)
        eps: Small constant for numerical stability (default: 1e-6)
        weight_decay: Weight decay coefficient (default: 0.0)
    
    Returns:
        GradientTransformation
    """

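The NumPy sketch below performs one NovoGrad step for a single layer, following the NovoGrad paper rather than optax's code; novograd_step is a hypothetical name. The second moment v is one scalar per layer rather than one value per parameter.

```python
import numpy as np


def novograd_step(param, grad, m, v, t, learning_rate,
                  b1=0.9, b2=0.25, eps=1e-6, weight_decay=0.0):
    """One NovoGrad step for one layer (hypothetical helper); t is the 1-based step count."""
    grad_sq_norm = np.sum(grad ** 2)
    # Layer-wise second moment: moving average of the squared gradient norm.
    v = grad_sq_norm if t == 1 else b2 * v + (1.0 - b2) * grad_sq_norm
    # Normalize the whole layer's gradient by one scalar, then add weight decay.
    normalized = grad / (np.sqrt(v) + eps) + weight_decay * param
    # Momentum accumulates the normalized gradients.
    m = normalized if t == 1 else b1 * m + normalized
    return param - learning_rate * m, m, v


param = np.array([1.0, 2.0, 3.0])
g = np.array([0.01, -0.02, 0.03])
p_small, _, v_small = novograd_step(param, g, 0.0, 0.0, t=1, learning_rate=0.1)
p_large, _, _ = novograd_step(param, 1000.0 * g, 0.0, 0.0, t=1, learning_rate=0.1)
print(p_small, p_large)
```

Scaling a layer's gradient by 1000 leaves the step essentially unchanged, because the normalization divides by that layer's gradient norm.
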
RAdam Optimizer

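Rectified Adam (RAdam) corrects for the large variance of Adam's adaptive learning rate early in training, when only a few gradients have been averaged. While that variance is not yet tractable, it takes un-adapted, momentum-only steps; afterwards it multiplies the adaptive step by a rectification factor.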

def radam(learning_rate, b1=0.9, b2=0.999, eps=1e-8, threshold=5.0):
    """
    RAdam (Rectified Adam) optimizer.
    
    Args:
        learning_rate: Learning rate or schedule
        b1: Exponential decay rate for first moment estimates (default: 0.9)
        b2: Exponential decay rate for second moment estimates (default: 0.999)
        eps: Small constant for numerical stability (default: 1e-8)
        threshold: Threshold for variance tractability (default: 5.0)
    
    Returns:
        GradientTransformation
    """

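The rectification factor from the RAdam paper can be computed directly, as in the sketch below (radam_rectification is a hypothetical name, not an optax function). The quantity rho_t approximates the length of the simple moving average behind the second-moment estimate. For any b2, rho_1 = 1, so the first few steps skip the adaptive term. The factor then rises toward 1 as training proceeds.

```python
import math


def radam_rectification(t, b2=0.999, threshold=5.0):
    """RAdam rectification factor r_t for 1-based step t, or None if the adaptive step is skipped.

    Hypothetical helper illustrating the published formula, not optax's code.
    """
    rho_inf = 2.0 / (1.0 - b2) - 1.0
    rho_t = rho_inf - 2.0 * t * b2 ** t / (1.0 - b2 ** t)
    if rho_t <= threshold:
        # Variance not tractable yet: fall back to an un-adapted, momentum-only step.
        return None
    return math.sqrt(
        (rho_t - 4.0) * (rho_t - 2.0) * rho_inf
        / ((rho_inf - 4.0) * (rho_inf - 2.0) * rho_t)
    )


for t in (1, 10, 1000, 100000):
    print(t, radam_rectification(t))
```
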
SM3 Optimizer

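SM3 (Square-root of Minima of Sums of Maxima of Squared-gradients Method) is a memory-efficient adaptive optimizer. Instead of a full per-parameter second-moment accumulator, it keeps one accumulator per row and per column of each tensor (more generally, one per cover set), which reduces optimizer memory for very large models.

def sm3(learning_rate, momentum=0.9):
    """
    SM3 memory-efficient adaptive optimizer.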
    
    Args:
        learning_rate: Learning rate or schedule
        momentum: Momentum coefficient (default: 0.9)
    
    Returns:
        GradientTransformation
    """

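For a 2-D parameter, the SM3 idea can be sketched in NumPy as below. This follows the SM3-II rule from the SM3 paper (momentum omitted), and sm3_step is a hypothetical name. Each entry's second-moment estimate is rebuilt from the smaller of its row and column statistics, and then compressed back to per-row and per-column maxima. This way an m x n matrix needs only m + n stored values.

```python
import numpy as np


def sm3_step(grad, row_acc, col_acc, eps=1e-30):
    """One SM3-II preconditioning step for a 2-D gradient (hypothetical helper, sketch only)."""
    # Per-entry estimate: the tighter of its row and column statistics, plus g**2.
    nu = np.minimum(row_acc[:, None], col_acc[None, :]) + grad ** 2
    # Compress back to per-row and per-column maxima (shapes [m] and [n]).
    new_row = nu.max(axis=1)
    new_col = nu.max(axis=0)
    preconditioned = grad / np.sqrt(nu + eps)
    return preconditioned, new_row, new_col


grad = np.array([[1.0, 0.0, 2.0], [0.5, 3.0, 0.0]])
update, row_acc, col_acc = sm3_step(grad, np.zeros(2), np.zeros(3))
print(update, row_acc, col_acc)
```
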
Fromage Optimizer

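Fromage (Frobenius matched gradient descent) rescales each layer's gradient so that its Frobenius norm matches that of the layer's weights. The relative change in each layer's weights per step is therefore roughly the learning rate.

def fromage(learning_rate, min_norm=1e-6):
    """
    Fromage (Frobenius matched gradient descent) optimizer.
    
    Args:
        learning_rate: Learning rate (a scalar value)
        min_norm: Lower bound applied to the weight and gradient norms when computing the per-layer ratio (default: 1e-6)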
    
    Returns:
        GradientTransformation
    """

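The NumPy sketch below performs one Fromage step for a single layer, following the Fromage paper (fromage_step is a hypothetical name). The weights move by learning_rate times their own norm, and the result is divided by sqrt(1 + learning_rate**2). When the gradient is orthogonal to the weights, this division exactly undoes the norm growth.

```python
import numpy as np


def fromage_step(param, grad, learning_rate, min_norm=1e-6):
    """One Fromage step for a single layer (hypothetical helper, sketch only)."""
    param_norm = max(np.linalg.norm(param), min_norm)
    grad_norm = max(np.linalg.norm(grad), min_norm)
    # Rescale the gradient so its norm matches the weight norm, then step by learning_rate.
    stepped = param - learning_rate * (param_norm / grad_norm) * grad
    # Dividing by sqrt(1 + lr**2) counteracts the norm growth the step would otherwise cause.
    return stepped / np.sqrt(1.0 + learning_rate ** 2)


w = np.array([3.0, 4.0])
g_orth = np.array([-4.0, 3.0])  # orthogonal to w
w_new = fromage_step(w, g_orth, learning_rate=0.1)
print(w_new, np.linalg.norm(w_new))
```
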
Specialized SGD Variants

Noisy SGD

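SGD with annealed Gaussian noise added to the gradients; the noise variance decays as eta / (1 + t)^gamma over training steps t.

def noisy_sgd(learning_rate, eta=0.01, gamma=0.55):
    """
    Noisy SGD with gradient noise injection.
    
    Args:
        learning_rate: Learning rate or schedule
        eta: Initial variance of the gradient noise (default: 0.01)
        gamma: Decay exponent of the noise variance (default: 0.55)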
    
    Returns:
        GradientTransformation
    """

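The NumPy sketch below performs one noisy-SGD step with the annealed schedule variance = eta / (1 + t)^gamma (noisy_sgd_step is a hypothetical name, not optax's code). Here step is the 0-based iteration index, so the first step uses variance eta.

```python
import numpy as np


def noisy_sgd_step(param, grad, step, learning_rate, rng, eta=0.01, gamma=0.55):
    """One noisy-SGD step (hypothetical helper); step is the 0-based iteration index."""
    # Annealed noise: the variance shrinks as training proceeds.
    variance = eta / (1.0 + step) ** gamma
    noise = rng.normal(0.0, np.sqrt(variance), size=grad.shape)
    return param - learning_rate * (grad + noise)


rng = np.random.default_rng(0)
new_param = noisy_sgd_step(np.zeros(3), np.ones(3), step=0, learning_rate=0.1, rng=rng)
print(new_param)
```
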
Sign SGD

SGD using only the sign of gradients.

def sign_sgd(learning_rate):
    """
    Sign SGD optimizer using gradient signs only.
    
    Args:
        learning_rate: Learning rate or schedule
    
    Returns:
        GradientTransformation
    """

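Sign SGD discards gradient magnitudes entirely, as the short NumPy sketch below shows (sign_sgd_step is a hypothetical name). Each coordinate moves by exactly learning_rate, or not at all if its gradient is zero. As a result, two gradients with the same signs but wildly different scales produce the same update.

```python
import numpy as np


def sign_sgd_step(param, grad, learning_rate):
    """One sign-SGD step (hypothetical helper): each coordinate moves by learning_rate or 0."""
    return param - learning_rate * np.sign(grad)


param = np.ones(3)
a = sign_sgd_step(param, np.array([1e-6, -5.0, 0.0]), 0.1)
b = sign_sgd_step(param, np.array([1e6, -5e-3, 0.0]), 0.1)
print(a, b)
```
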
Polyak SGD

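SGD with a Polyak step size. The step length is scaling * (f(x) - f_min) / (||g||^2 + eps), computed from the current loss value and capped at max_learning_rate.

def polyak_sgd(max_learning_rate=1.0, scaling=1.0, f_min=0.0, eps=0.0):
    """
    SGD with Polyak step size.
    
    Args:
        max_learning_rate: Upper bound on the step size (default: 1.0)
        scaling: Multiplier applied to the Polyak step size (default: 1.0)
        f_min: Lower bound on (or estimate of) the optimal loss value (default: 0.0)
        eps: Constant added to the squared gradient norm for numerical stability (default: 0.0)
    
    Returns:
        GradientTransformationExtraArgs; update() also needs the current loss passed as the value keyword argument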
    """

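optax's polyak_sgd uses the Polyak step size, which sets the step length from the current loss value rather than from a fixed learning rate. The NumPy sketch below shows that rule (polyak_step_size is a hypothetical name). On f(x) = 0.5 * ||x||^2, whose minimum value is 0 and gradient is x, the step size works out to exactly 0.5.

```python
import numpy as np


def polyak_step_size(loss, grad, max_learning_rate=1.0, scaling=1.0, f_min=0.0, eps=0.0):
    """Polyak step size, capped at max_learning_rate (hypothetical helper, sketch only)."""
    step = scaling * (loss - f_min) / (np.sum(grad ** 2) + eps)
    return min(max_learning_rate, step)


x = np.array([2.0, -4.0])
loss, grad = 0.5 * np.sum(x ** 2), x
x_new = x - polyak_step_size(loss, grad) * grad
print(x_new)
```
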
RProp Optimizer

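Resilient backpropagation (Rprop). Each parameter has its own step size, which grows by eta_plus while the sign of its gradient stays the same and shrinks by eta_minus when the sign flips. Updates use only the gradient's sign, never its magnitude.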

def rprop(learning_rate, eta_minus=0.5, eta_plus=1.2, min_step_size=1e-6, max_step_size=50.):
    """
    RProp (Resilient backpropagation) optimizer.
    
    Args:
        learning_rate: Initial step size
        eta_minus: Factor for decreasing step size (default: 0.5)
        eta_plus: Factor for increasing step size (default: 1.2)
        min_step_size: Minimum step size (default: 1e-6)
        max_step_size: Maximum step size (default: 50.0)
    
    Returns:
        GradientTransformation
    """

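The per-parameter step-size adaptation can be sketched in NumPy as below (rprop_step is a hypothetical name). This is the classic Rprop variant that skips a coordinate's update and forgets its gradient right after a sign flip; optax's exact handling of that case may differ.

```python
import numpy as np


def rprop_step(param, grad, prev_grad, step_size,
               eta_minus=0.5, eta_plus=1.2, min_step_size=1e-6, max_step_size=50.0):
    """One Rprop step (hypothetical helper, sketch of the skip-after-sign-flip variant)."""
    agreement = np.sign(grad * prev_grad)
    # Grow the step where the sign is unchanged, shrink it where the sign flipped.
    step_size = np.where(agreement > 0, np.minimum(step_size * eta_plus, max_step_size), step_size)
    step_size = np.where(agreement < 0, np.maximum(step_size * eta_minus, min_step_size), step_size)
    # After a sign flip, skip this coordinate's update and store a zero gradient.
    grad = np.where(agreement < 0, 0.0, grad)
    new_param = param - np.sign(grad) * step_size
    return new_param, grad, step_size


new_param, stored_grad, new_step = rprop_step(
    np.zeros(3), np.array([2.0, -3.0, 1.0]), np.array([1.0, 1.0, 0.0]), np.full(3, 0.1))
print(new_param, new_step)
```
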
Optimistic Methods

Optimistic Gradient Descent

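Optimistic gradient descent for min-max (saddle-point) problems such as GAN training. Each step combines the current gradient with the change from the previous gradient, which damps the cycling that simultaneous gradient descent-ascent exhibits on such problems.

def optimistic_gradient_descent(learning_rate, alpha=1.0, beta=1.0):
    """
    Optimistic gradient descent.
    
    Args:
        learning_rate: Learning rate or schedule
        alpha: Weight on the current gradient g_t (default: 1.0)
        beta: Weight on the optimistic correction g_t - g_{t-1} (default: 1.0)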
    
    Returns:
        GradientTransformation
    """

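The bilinear game min_x max_y x*y shows the difference between plain and optimistic updates. The NumPy sketch below assumes the optimistic step lr * ((alpha + beta) * g_t - beta * g_{t-1}), with the previous gradient taken as zero on the first step. The helper names game_field and run are hypothetical. Plain simultaneous gradient descent-ascent spirals away from the equilibrium at (0, 0), while the optimistic update converges to it.

```python
import numpy as np


def game_field(z):
    """Descent direction for min_x max_y x*y: (df/dx, -df/dy) = (y, -x)."""
    x, y = z
    return np.array([y, -x])


def run(optimistic, steps=1000, learning_rate=0.1, alpha=1.0, beta=1.0):
    """Return the distance from the equilibrium (0, 0) after the given number of steps."""
    z = np.array([1.0, 1.0])
    prev_g = np.zeros(2)  # no previous gradient on the first step
    for _ in range(steps):
        g = game_field(z)
        if optimistic:
            z = z - learning_rate * ((alpha + beta) * g - beta * prev_g)
        else:
            z = z - learning_rate * g  # plain simultaneous gradient descent-ascent
        prev_g = g
    return np.linalg.norm(z)


print(run(optimistic=False), run(optimistic=True))
```
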
Optimistic Adam

Optimistic variant of the Adam optimizer, intended for min-max problems such as GAN training.

def optimistic_adam(learning_rate, b1=0.9, b2=0.999, eps=1e-8):
    """
    Optimistic Adam optimizer.
    
    Args:
        learning_rate: Learning rate or schedule
        b1: Exponential decay rate for first moment estimates (default: 0.9)
        b2: Exponential decay rate for second moment estimates (default: 0.999)
        eps: Small constant for numerical stability (default: 1e-8)
    
    Returns:
        GradientTransformation
    """

Lookahead Wrapper

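Lookahead wrapper around any base ("fast") optimizer. The fast weights are updated every step; every sync_period steps, a set of slow weights moves toward them by slow_step_size. Parameters must be wrapped in optax.LookaheadParams, which holds the fast and slow copies.

def lookahead(fast_optimizer, sync_period, slow_step_size, reset_state=False):
    """
    Lookahead optimizer wrapper.
    
    Args:
        fast_optimizer: Base optimizer to wrap
        sync_period: Number of fast optimizer steps between synchronizations of the slow weights
        slow_step_size: Interpolation factor for moving the slow weights toward the fast weights
        reset_state: Whether to reset the fast optimizer's state after each synchronization (default: False)
    
    Returns:
        GradientTransformation operating on optax.LookaheadParams rather than on raw parameters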
    """

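The NumPy sketch below implements the Lookahead algorithm around an arbitrary inner update rule. It is not optax's implementation, and the helper lookahead_optimize and its inner_step argument are hypothetical. Its sync_period and slow_step_size parameters play the roles of the wrapper's two hyperparameters. After each synchronization, the fast weights restart from the updated slow weights.

```python
import numpy as np


def lookahead_optimize(grad_fn, params, inner_step, sync_period=5, slow_step_size=0.5, num_steps=100):
    """Lookahead around an inner update rule (hypothetical helper, sketch only).

    inner_step(params, grad) -> new params plays the role of the fast optimizer.
    """
    slow = params.copy()
    fast = params.copy()
    for step in range(1, num_steps + 1):
        fast = inner_step(fast, grad_fn(fast))
        if step % sync_period == 0:
            # Move the slow weights part of the way toward the fast weights, then restart from them.
            slow = slow + slow_step_size * (fast - slow)
            fast = slow.copy()
    return slow


# Plain SGD as the fast optimizer on f(x) = 0.5 * ||x||^2, whose gradient is x.
result = lookahead_optimize(lambda x: x, np.array([4.0, -2.0]), lambda p, g: p - 0.1 * g)
print(result)
```
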
Usage Example

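import jax.numpy as jnp
import optax

# Initialize parameters
params = {'weights': jnp.ones((100, 50)), 'bias': jnp.zeros((50,))}

# Advanced optimizers for different scenarios
lion_opt = optax.lion(learning_rate=1e-4)   # Single moment buffer (less state than Adam)
lars_opt = optax.lars(learning_rate=0.01)   # Large batch training
lamb_opt = optax.lamb(learning_rate=1e-3)   # Large batch training
lbfgs_opt = optax.lbfgs()                   # Quasi-Newton; update() also needs value, grad, value_fn

# Lookahead wrapper: it operates on LookaheadParams (fast and slow copies of params)
base_opt = optax.adam(learning_rate=1e-3)
lookahead_opt = optax.lookahead(base_opt, sync_period=5, slow_step_size=0.5)
lookahead_params = optax.LookaheadParams.init_synced(params)

# Initialize states
lion_state = lion_opt.init(params)
lookahead_state = lookahead_opt.init(lookahead_params)

# Usage in training loop (first-order optimizers such as lion, lars and lamb).
# For lookahead, pass lookahead_params as params and gradients taken w.r.t. lookahead_params.fast.
def training_step(params, opt_state, gradients, optimizer):
    # params is required: weight decay and trust ratios (lion, lars, lamb) read the current parameters
    updates, new_opt_state = optimizer.update(gradients, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state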

Install with Tessl CLI

npx tessl i tessl/pypi-optax
