
tessl/pypi-optax

A gradient processing and optimization library in JAX


Experimental Optimizers (contrib)

The optax.contrib module contains experimental optimizers and techniques under active development. They implement recent methods from the optimization literature and may be less mature and less stable than the core optimizers.

Note: Experimental features may have API changes in future versions.

Capabilities

Advanced Adaptive Optimizers

Sharpness-Aware Minimization (SAM)

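def sam(optimizer, adv_optimizer, sync_period=2, reset_state=True, opaque_mode=False, batch_axis_name=None):
    """
    Sharpness-Aware Minimization optimizer.
    
    Args:
        optimizer: Outer optimizer that performs the actual descent step (e.g., SGD, Adam)
        adv_optimizer: Adversarial optimizer that perturbs the parameters toward higher loss;
            the neighborhood size rho is its step size, e.g.
            optax.chain(optax.contrib.normalize(), optax.sgd(rho))
        sync_period: Length of one SAM cycle: sync_period - 1 adversarial steps followed
            by one outer step (default: 2)
        reset_state: Whether to reset the adversarial optimizer state after each outer step (default: True)
        opaque_mode: If True, the adversarial steps are computed inside update(), which then
            requires a grad_fn argument (default: False)
        batch_axis_name: Optional axis name used to average gradients across devices (default: None)
    
    Returns:
        GradientTransformationExtraArgs: SAM optimizer
    """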

Prodigy Optimizer

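def prodigy(learning_rate=1.0, betas=(0.9, 0.999), beta3=None, weight_decay=0.0, eps=1e-8, estim_lr0=1e-6, estim_lr_coef=1.0, safeguard_warmup=False):
    """
    Prodigy optimizer: Adam with a step size estimated automatically from an
    estimate of the distance to the solution (learning-rate-free).
    
    Args:
        learning_rate: Multiplier on the estimated step size; usually left at 1.0 (default: 1.0)
        betas: Decay rates of the first and second moment estimates (default: (0.9, 0.999))
        beta3: Decay rate used for the step-size estimate; None uses sqrt(betas[1]) (default: None)
        weight_decay: Weight decay coefficient (default: 0.0)
        eps: Numerical stability parameter (default: 1e-8)
        estim_lr0: Initial value of the step-size estimate (default: 1e-6)
        estim_lr_coef: Coefficient applied to the step-size estimate (default: 1.0)
        safeguard_warmup: Whether to guard the step-size estimate during learning-rate warmup (default: False)
    
    Returns:
        GradientTransformation: Prodigy optimizer
    """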

Sophia Optimizer

def sophia(learning_rate, beta1=0.965, beta2=0.99, eps=1e-8, weight_decay=1e-4):
    """
    Sophia optimizer, which preconditions updates with a lightweight estimate
    of the diagonal Hessian and clips the update per coordinate.
    
    Args:
        learning_rate: Learning rate
        beta1: First moment decay rate (default: 0.965)
        beta2: Second moment decay rate (default: 0.99)  
        eps: Numerical stability parameter (default: 1e-8)
        weight_decay: Weight decay coefficient (default: 1e-4)
    
    Returns:
        GradientTransformation: Sophia optimizer
    """

Schedule-Free Optimizers

Schedule-Free AdamW and SGD

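def schedule_free_adamw(learning_rate=0.0025, warmup_steps=None, b1=0.9, b2=0.999, eps=1e-8, weight_decay=0.0, weight_lr_power=2.0, state_dtype=None):
    """
    Schedule-free AdamW optimizer, which does not need a learning rate decay schedule.
    
    Args:
        learning_rate: Learning rate (default: 0.0025)
        warmup_steps: Number of linear warmup steps; None disables warmup (default: None)
        b1: Interpolation (momentum-like) parameter of the schedule-free update (default: 0.9)
        b2: Second moment decay rate (default: 0.999)
        eps: Numerical stability parameter (default: 1e-8)
        weight_decay: Weight decay coefficient (default: 0.0)
        weight_lr_power: Power of the learning rate used to weight iterates in the average (default: 2.0)
        state_dtype: Optional dtype for the optimizer state (default: None)
    
    Returns:
        GradientTransformationExtraArgs: Schedule-free AdamW optimizer
    """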

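def schedule_free_sgd(learning_rate=1.0, warmup_steps=None, b1=0.9, weight_decay=None, weight_lr_power=2.0, state_dtype=None):
    """
    Schedule-free SGD optimizer.
    
    Args:
        learning_rate: Learning rate (default: 1.0)
        warmup_steps: Number of linear warmup steps; None disables warmup (default: None)
        b1: Interpolation (momentum-like) parameter of the schedule-free update (default: 0.9)
        weight_decay: Optional weight decay coefficient (default: None)
        weight_lr_power: Power of the learning rate used to weight iterates in the average (default: 2.0)
        state_dtype: Optional dtype for the optimizer state (default: None)
    
    Returns:
        GradientTransformationExtraArgs: Schedule-free SGD optimizer
    """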

def schedule_free_eval_params(state, params):
    """
    Compute the parameters to use for evaluation from a schedule-free optimizer.
    
    Args:
        state: State of a schedule-free optimizer
        params: Current parameters (the points at which training gradients are computed)
    
    Returns:
        Parameters suitable for evaluation/inference
    """

Momentum-Based Methods

Muon Optimizer

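def muon(learning_rate, beta=0.95, nesterov=True):
    """
    Muon (MomentUm Orthogonalized by Newton-Schulz) optimizer. Momentum updates
    of 2D parameter matrices are orthogonalized with a few Newton-Schulz
    iterations; other parameters are updated with Adam.
    
    Simplified signature: further options (Newton-Schulz coefficients and
    steps, weight decay, Adam settings) are omitted, so pass the listed
    arguments by keyword.
    
    Args:
        learning_rate: Learning rate
        beta: Momentum decay rate (default: 0.95)
        nesterov: Whether to use Nesterov momentum (default: True)
    
    Returns:
        GradientTransformation: Muon optimizer
    """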

MoMo (Momentum Models)

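def momo(learning_rate=1.0, beta=0.9, lower_bound=0.0, weight_decay=0.0, adapt_lower_bound=False):
    """
    MoMo optimizer: SGD with momentum and an adaptive, Polyak-type step size
    derived from a model of the loss. update() requires the current loss
    value, passed as value=loss.
    
    Args:
        learning_rate: Upper bound on the step size (default: 1.0)
        beta: Momentum coefficient (default: 0.9)
        lower_bound: Lower bound on the loss, e.g. 0.0 for non-negative losses (default: 0.0)
        weight_decay: Weight decay coefficient (default: 0.0)
        adapt_lower_bound: Whether to update the lower bound estimate during training (default: False)
    
    Returns:
        GradientTransformationExtraArgs: MoMo optimizer
    """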

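def momo_adam(learning_rate=1e-2, b1=0.9, b2=0.999, eps=1e-8, lower_bound=0.0, weight_decay=0.0, adapt_lower_bound=False):
    """
    MoMo-Adam: the MoMo adaptive step size combined with Adam's preconditioning.
    Like momo, update() requires the current loss value (value=loss).
    
    Args:
        learning_rate: Upper bound on the step size (default: 1e-2)
        b1: First moment decay rate (default: 0.9)
        b2: Second moment decay rate (default: 0.999)
        eps: Numerical stability parameter (default: 1e-8)
        lower_bound: Lower bound on the loss (default: 0.0)
        weight_decay: Weight decay coefficient (default: 0.0)
        adapt_lower_bound: Whether to update the lower bound estimate during training (default: False)
    
    Returns:
        GradientTransformationExtraArgs: MoMo-Adam optimizer
    """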

Specialized Methods

DoG (Distance over Gradients) and DoWG (Distance over Weighted Gradients)

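def dog(learning_rate=1.0, reps_rel=1e-6, eps=1e-8):
    """
    DoG (Distance over Gradients) optimizer: SGD whose step size is the largest
    distance traveled from the initial parameters divided by the square root of
    the sum of squared gradient norms. Other keyword arguments (e.g. weight
    decay) are omitted here.
    
    Args:
        learning_rate: Multiplier on the DoG step size (default: 1.0)
        reps_rel: Relative size of the initial movement, which sets the first step (default: 1e-6)
        eps: Numerical stability parameter (default: 1e-8)
    
    Returns:
        GradientTransformation: DoG optimizer
    """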

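def dowg(learning_rate=1.0, init_estim_sq_dist=None, eps=1e-4):
    """
    DoWG (Distance over Weighted Gradients) optimizer: a DoG variant that
    weights each squared gradient norm by the current distance estimate.
    Other keyword arguments (e.g. weight decay) are omitted here.
    
    Args:
        learning_rate: Multiplier on the DoWG step size (default: 1.0)
        init_estim_sq_dist: Optional initial estimate of the squared distance to the solution (default: None)
        eps: Numerical stability parameter (default: 1e-4)
    
    Returns:
        GradientTransformation: DoWG optimizer
    """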

ADOPT

def adopt(learning_rate, eps=1e-8, beta1=0.9, beta2=0.9999, weight_decay=0.0):
    """
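    ADOPT optimizer: an Adam variant that normalizes each gradient by the
    second-moment estimate from the previous step and applies momentum after
    normalization.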
    
    Args:
        learning_rate: Learning rate
        eps: Numerical stability parameter (default: 1e-8)
        beta1: First moment decay rate (default: 0.9)
        beta2: Second moment decay rate (default: 0.9999)
        weight_decay: Weight decay coefficient (default: 0.0)
    
    Returns:
        GradientTransformation: ADOPT optimizer
    """

Privacy-Preserving Methods

Differential Privacy

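def differentially_private_aggregate(l2_norm_clip, noise_multiplier, seed):
    """
    Differentially private gradient aggregation (as in DP-SGD).
    
    Expects per-example gradients, with the batch dimension as the leading
    axis of every leaf. Each example's gradient is clipped to L2 norm at most
    l2_norm_clip, the clipped gradients are summed, Gaussian noise with
    standard deviation l2_norm_clip * noise_multiplier is added, and the
    result is divided by the batch size.
    
    Args:
        l2_norm_clip: Maximum L2 norm of each per-example gradient
        noise_multiplier: Ratio of the noise standard deviation to l2_norm_clip
        seed: Seed for the noise random number generator
    
    Returns:
        GradientTransformation: DP aggregation transformation
    """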

Experimental Adaptive Methods

AdEMAMix

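def ademamix(learning_rate, b1=0.9, b2=0.999, b3=0.9999, alpha=5.0, eps=1e-8, weight_decay=0.0):
    """
    AdEMAMix optimizer: Adam with a mixture of two exponential moving averages
    of the gradient, a fast one (b1) and a slow one (b3), combined with weight
    alpha. Other keyword arguments (e.g. schedulers for b3 and alpha) are
    omitted here, so pass the listed options by keyword.
    
    Args:
        learning_rate: Learning rate
        b1: Decay rate of the fast EMA (default: 0.9)
        b2: Second moment decay rate (default: 0.999)
        b3: Decay rate of the slow EMA (default: 0.9999)
        alpha: Weight of the slow EMA in the update (default: 5.0)
        eps: Numerical stability parameter (default: 1e-8)
        weight_decay: Weight decay coefficient (default: 0.0)
    
    Returns:
        GradientTransformation: AdEMAMix optimizer
    """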

COCOB

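def cocob(learning_rate=1.0, alpha=100, eps=1e-8, weight_decay=0.0):
    """
    COCOB (COntinuous COin Betting) optimizer, a learning-rate-free method
    that treats each parameter update as a bet whose size grows with past winnings.
    
    Args:
        learning_rate: Multiplier on the update (default: 1.0)
        alpha: Controls how cautiously the algorithm bets in early iterations (default: 100)
        eps: Numerical stability parameter (default: 1e-8)
        weight_decay: Weight decay coefficient (default: 0.0)
    
    Returns:
        GradientTransformation: COCOB optimizer (parameter-free)
    """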

Usage Examples

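import jax
import jax.numpy as jnp
import optax

# Using SAM for better generalization: rho is the step size of the
# adversarial optimizer, whose normalized steps perturb the parameters
base_optimizer = optax.sgd(0.1)
adv_optimizer = optax.chain(optax.contrib.normalize(), optax.sgd(0.05))
sam_optimizer = optax.contrib.sam(base_optimizer, adv_optimizer, sync_period=2)

# Using schedule-free optimizers
sf_adamw = optax.contrib.schedule_free_adamw(learning_rate=0.001)

# Using experimental adaptive methods
prodigy_opt = optax.contrib.prodigy(learning_rate=1.0)
sophia_opt = optax.contrib.sophia(learning_rate=0.001)

# Toy problem standing in for a real model and dataset (hypothetical)
def loss_fn(params, batch):
    return jnp.mean((batch @ params) ** 2)

params = jnp.ones(3)
data = eval_data = jnp.eye(3)
num_steps, eval_interval = 100, 10

# Training loop with schedule-free optimizer
opt_state = sf_adamw.init(params)
for step in range(num_steps):
    grads = jax.grad(loss_fn)(params, data)
    updates, opt_state = sf_adamw.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)

    # Schedule-free methods are evaluated at a different point than the one
    # where gradients are computed; derive it from the state and params
    if step % eval_interval == 0:
        eval_params = optax.contrib.schedule_free_eval_params(opt_state, params)
        eval_loss = loss_fn(eval_params, eval_data)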

Import

import optax.contrib
# or
from optax.contrib import sam, prodigy, schedule_free_adamw

Research Papers

Many contrib optimizers are based on recent research:

  • SAM: "Sharpness-Aware Minimization for Efficiently Improving Generalization"
  • Prodigy: "Prodigy: An Expeditiously Adaptive Parameter-Free Learner"
  • Sophia: "Sophia: A Scalable Stochastic Second-order Optimizer"
  • Schedule-Free: "The Road Less Scheduled"
  • AdEMAMix: "The AdEMAMix Optimizer: Better, Faster, Older"

Refer to the respective papers for detailed algorithmic descriptions and theoretical analysis.

Install with Tessl CLI

npx tessl i tessl/pypi-optax
