CtrlK
Blog · Docs · Log in · Get started
Tessl Logo

tessl/pypi-optax

A gradient processing and optimization library in JAX

Pending
Overview
Eval results
Files

docs/transformations.md

Gradient Transformations

Building blocks for creating custom optimizers including scaling, clipping, noise addition, and momentum accumulation. These transformations can be combined using chain() to build custom optimization strategies with fine-grained control over gradient processing.

Capabilities

Chaining Transformations

Combine multiple gradient transformations into a single optimizer.

def chain(*args):
    """
    Compose several gradient transformations into a single one.

    The transformations are applied in the order they are given.

    Args:
        *args: Any number of GradientTransformation objects to compose.

    Returns:
        GradientTransformationExtraArgs: The composed transformation.
    """

def named_chain(**transformations):
    """
    Chain transformations, keeping each one's state under a name.

    Naming the states makes the combined optimizer state easier to inspect
    and debug.

    Args:
        **transformations: GradientTransformation objects keyed by name.

    Returns:
        GradientTransformation: Combined transformation with named states.

    NOTE(review): upstream optax appears to define this as
    ``named_chain(*transforms)``, where each positional argument is a
    ``(name, transformation)`` tuple, and to return
    GradientTransformationExtraArgs. Confirm against the installed version
    before relying on the keyword form shown here.
    """

Scaling Transformations

Basic Scaling

def scale(step_size):
    """
    Multiply every update by a fixed factor.

    Args:
        step_size: The constant multiplier. For gradient descent this is
            usually the negated learning rate.

    Returns:
        GradientTransformation
    """

def scale_by_learning_rate(learning_rate):
    """
    Multiply updates by the negated learning rate.

    Args:
        learning_rate: A constant learning rate, or a schedule producing one.

    Returns:
        GradientTransformation
    """

def scale_by_schedule(schedule):
    """
    Scale updates by a schedule function.

    The schedule's value multiplies the updates directly; no sign flip is
    applied. For gradient descent the schedule must therefore return
    negative values. Alternatively, use ``scale_by_learning_rate``, which
    accepts a schedule and negates it.

    Args:
        schedule: Callable mapping the step count to a scale factor.

    Returns:
        GradientTransformation
    """

Adaptive Scaling

def scale_by_adam(b1=0.9, b2=0.999, eps=1e-8, *, nesterov=False):
    """
    Rescale updates with Adam's running moment estimates.

    Args:
        b1: Decay of the running first-moment estimate (default: 0.9).
        b2: Decay of the running second-moment estimate (default: 0.999).
        eps: Numerical-stability constant (default: 1e-8).
        nesterov: If True, use the Nesterov-momentum variant
            (keyword-only, default: False).

    Returns:
        GradientTransformation
    """

def scale_by_rms(decay=0.9, eps=1e-8):
    """
    Divide updates by a running root-mean-square of the gradients.

    Args:
        decay: Decay factor of the moving average (default: 0.9).
        eps: Numerical-stability constant (default: 1e-8).

    Returns:
        GradientTransformation
    """

def scale_by_stddev(decay=0.9, eps=1e-8):
    """
    Divide updates by a running estimate of the gradients' standard deviation.

    Args:
        decay: Decay factor of the moving average (default: 0.9).
        eps: Numerical-stability constant (default: 1e-8).

    Returns:
        GradientTransformation
    """

Momentum and Accumulation

def trace(decay, nesterov=False, accumulator_dtype=None):
    """
    Add momentum/trace to gradient updates.

    Args:
        decay: Decay rate of the momentum trace. Required: this parameter
            has no default value (0.9 is a common choice).
        nesterov: Whether to use Nesterov momentum (default: False).
        accumulator_dtype: Optional dtype for the trace accumulator
            (default: None).

    Returns:
        GradientTransformation
    """

def ema(decay, debias=True, accumulator_dtype=None):
    """
    Exponential moving average of past updates.

    Like every GradientTransformation, this transforms the updates that flow
    through it; it does not average the parameter values themselves.

    Args:
        decay: Decay rate of the moving average. Required: this parameter
            has no default value.
        debias: Whether to debias the moving average (default: True).
        accumulator_dtype: Optional dtype for the accumulator
            (default: None).

    Returns:
        GradientTransformation
    """

Gradient Clipping

def clip(max_delta):
    """
    Clamp each update element into the range [-max_delta, max_delta].

    Args:
        max_delta: Largest absolute value any single update element may take.

    Returns:
        GradientTransformation
    """

def clip_by_global_norm(max_norm):
    """
    Clip updates so that their combined (global) norm does not exceed a bound.

    Args:
        max_norm: Upper bound on the global norm of the updates.

    Returns:
        GradientTransformation
    """

def clip_by_block_rms(threshold):
    """
    Clip updates block by block according to their root-mean-square value.

    Args:
        threshold: RMS level used as the clipping threshold.

    Returns:
        GradientTransformation
    """

def adaptive_grad_clip(clipping, eps=1e-3):
    """
    Clip gradients adaptively (adaptive gradient clipping).

    Args:
        clipping: The clipping threshold.
        eps: Numerical-stability constant (default: 1e-3).

    Returns:
        GradientTransformation
    """

def per_example_global_norm_clip(l2_norm_clip, single_batch_element=False):
    """
    Clip gradients separately for each example, as used in differential privacy.

    Args:
        l2_norm_clip: Threshold applied to each example's L2 norm.
        single_batch_element: Set to True when the input holds a single
            batch element (default: False).

    Returns:
        GradientTransformation
    """

Regularization

def add_decayed_weights(weight_decay, mask=None):
    """
    Add an L2 weight-decay (weight regularization) term to the updates.

    Args:
        weight_decay: Coefficient of the decay term.
        mask: Optional mask choosing which parameters are decayed
            (default: None).

    Returns:
        GradientTransformation
    """

def add_noise(eta, gamma, seed):
    """
    Inject annealed noise into the gradients to help generalization.

    Args:
        eta: Scale of the injected noise.
        gamma: Rate at which the noise is annealed.
        seed: Seed for the random number generator.

    Returns:
        GradientTransformation
    """

Conditioning and Normalization

def centralize():
    """
    Subtract the mean from the gradients (gradient centralization).

    Returns:
        GradientTransformation
    """

def normalize_by_update_norm():
    """
    Divide the updates by their own norm.

    Returns:
        GradientTransformation
    """

def scale_by_trust_ratio():
    """
    Multiply updates by the trust ratio: parameter norm divided by update norm.

    Returns:
        GradientTransformation
    """

Conditional Operations

def apply_if_finite(transformation):
    """
    Apply a transformation only when the gradients are finite.

    Args:
        transformation: Inner transformation to apply conditionally.

    Returns:
        GradientTransformation

    NOTE(review): upstream optax appears to declare this as
    ``apply_if_finite(inner, max_consecutive_errors)``, with
    ``max_consecutive_errors`` (the number of consecutive non-finite steps
    tolerated) being required. Calls that pass only the inner transformation
    may fail there. Confirm against the installed version.
    """

def apply_every(k, transformation):
    """
    Apply transformation every k steps.

    Args:
        k: Step interval.
        transformation: Transformation to apply periodically.

    Returns:
        GradientTransformation

    NOTE(review): upstream optax appears to define ``apply_every(k=1)``,
    which takes no inner transformation. It accumulates updates, emits them
    every ``k`` steps, and emits zeros in between. To run an optimizer once
    per ``k`` accumulated steps, ``optax.MultiSteps(opt, every_k_schedule=k)``
    is the usual alternative. Confirm against the installed version.
    """

def conditionally_transform(condition_fn, transformation):
    """
    Apply a transformation only when a condition function says so.

    Args:
        condition_fn: Function returning a boolean that decides whether the
            transformation is applied.
        transformation: Transformation to apply conditionally.

    Returns:
        GradientTransformation

    NOTE(review): upstream optax appears to order the arguments as
    ``conditionally_transform(inner, should_transform_fn,
    forward_extra_args=False)``, with ``should_transform_fn`` receiving the
    step count. Confirm before calling positionally.
    """

Parameter Partitioning

def partition(selector_fn, *transformations):
    """
    Apply different transformations to different parameter subsets.

    Args:
        selector_fn: Function to select parameter subsets.
        *transformations: Transformations for each subset.

    Returns:
        GradientTransformation

    NOTE(review): upstream optax appears to define
    ``partition(transforms, param_labels)``. There, ``transforms`` is a
    mapping from label to transformation, and ``param_labels`` is a pytree
    of labels matching the parameters (or a function of the parameters
    returning one). Confirm before calling positionally.
    """

def masked(mask_fn, transformation):
    """
    Apply a transformation only to the parameters selected by a mask.

    Args:
        mask_fn: Function to generate the parameter mask.
        transformation: Transformation to apply where the mask is set.

    Returns:
        GradientTransformation

    NOTE(review): upstream optax appears to order the arguments as
    ``masked(inner, mask)``, where ``mask`` is a pytree of booleans (or a
    function of the parameters returning one). Confirm before calling
    positionally.
    """

Parameter Constraints

def keep_params_nonnegative():
    """
    Constrain parameters to be non-negative by projecting onto the positive orthant.

    Returns:
        GradientTransformation
    """

def zero_nans():
    """
    Replace NaN entries in the gradients with zeros.

    Returns:
        GradientTransformation
    """

Multi-Step Accumulation

class MultiSteps:
    """Multi-step gradient accumulation.

    NOTE(review): upstream optax's constructor appears to be
    ``MultiSteps(opt, every_k_schedule, use_grad_mean=True,
    should_skip_update_fn=None)``. It takes the inner optimizer ``opt`` as
    its first argument. Confirm against the installed version.
    """

    def __init__(self, every_k_schedule, use_grad_mean=True):
        """
        Initialize multi-step accumulation.

        Args:
            every_k_schedule: Schedule for the number of accumulation steps.
            use_grad_mean: If True, use the mean of the accumulated
                gradients; otherwise use their sum (default: True).
        """

def skip_not_finite(updates, state, params=None):
    """
    Skip updates that are not finite.

    Args:
        updates: Gradient updates.
        state: Optimizer state.
        params: Optional parameters (default: None).

    Returns:
        Tuple of (updates, state).

    NOTE(review): upstream optax appears to declare
    ``skip_not_finite(updates, gradient_step, params)``. It is intended as a
    ``should_skip_update_fn`` for ``MultiSteps`` and returns
    ``(should_skip, info_dict)`` rather than ``(updates, state)``. Confirm
    against the installed version.
    """

def skip_large_updates(updates, state, max_norm):
    """
    Skip updates whose norm exceeds a threshold.

    Args:
        updates: Gradient updates.
        state: Optimizer state.
        max_norm: Maximum allowed update norm.

    Returns:
        Tuple of (updates, state).

    NOTE(review): upstream optax appears to declare
    ``skip_large_updates(updates, gradient_step, params, max_squared_norm)``.
    Note that its threshold is on the *squared* norm. It is intended as a
    ``should_skip_update_fn`` for ``MultiSteps`` and returns
    ``(should_skip, info_dict)``. Confirm against the installed version.
    """

Usage Examples

Custom Optimizer with Chaining

import jax.numpy as jnp  # needed for the example parameters below
import optax

# Create custom optimizer by chaining transformations
custom_optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),                # Gradient clipping
    # Placed before scale_by_adam, the decay term is rescaled by Adam
    # (coupled L2 regularization). optax.adamw instead adds decayed weights
    # *after* scale_by_adam (decoupled weight decay).
    optax.add_decayed_weights(weight_decay=1e-4),  # Weight decay
    optax.scale_by_adam(b1=0.9, b2=0.999),         # Adam scaling
    optax.scale(-0.001),                           # Learning rate (negative => descent)
)

# Initialize with parameters
params = {'w': jnp.ones((10, 5)), 'b': jnp.zeros((5,))}
opt_state = custom_optimizer.init(params)

Conditional and Partitioned Updates

import jax
import optax

# Apply different learning rates to different parameter groups.
# optax.partition takes a mapping {label: transformation} plus a pytree of
# labels (or a function of the params returning one).
def is_bias(path, param):
    # `path` is a tuple of key objects, so match on its string form.
    return 'bias' in jax.tree_util.keystr(path)

def label_params(params):
    return jax.tree_util.tree_map_with_path(
        lambda path, param: 'bias' if is_bias(path, param) else 'weight',
        params,
    )

bias_tx = optax.scale(-0.01)      # Higher learning rate for biases
weight_tx = optax.scale(-0.001)   # Lower learning rate for weights

partitioned_optimizer = optax.partition(
    {'bias': bias_tx, 'weight': weight_tx},
    label_params,
)

# Accumulate gradients over 5 steps and apply the inner optimizer once per
# 5 steps (optax.apply_every takes only `k`; MultiSteps wraps an optimizer).
sparse_optimizer = optax.MultiSteps(optax.adam(0.001), every_k_schedule=5)

Robust Training Setup

import optax

# Robust optimizer with multiple safeguards
robust_optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),           # Prevent exploding gradients
    optax.apply_if_finite(                     # Skip non-finite updates
        optax.chain(
            optax.centralize(),                # Center gradients
            optax.scale_by_adam(),             # Adaptive scaling
            optax.add_decayed_weights(1e-4),   # Weight regularization
        ),
        max_consecutive_errors=5,              # Required: consecutive non-finite steps tolerated
    ),
    # scale_by_learning_rate negates the schedule's value so updates descend;
    # scale_by_schedule would multiply by the positive rate and ascend the loss.
    optax.scale_by_learning_rate(              # Learning rate schedule
        optax.cosine_decay_schedule(0.001, 1000)
    ),
)

Install with Tessl CLI

npx tessl i tessl/pypi-optax

docs

advanced-optimizers.md

assignment.md

contrib.md

index.md

losses.md

monte-carlo.md

optimizers.md

perturbations.md

projections.md

schedules.md

second-order.md

transformations.md

tree-utilities.md

utilities.md

tile.json