A gradient processing and optimization library in JAX.

Building blocks for creating custom optimizers, including scaling, clipping, noise addition, and momentum accumulation. These transformations can be combined using chain() to build custom optimization strategies with fine-grained control over gradient processing. chain() combines multiple gradient transformations into a single optimizer.
def chain(*args):
    """Chain multiple gradient transformations into a single optimizer.

    Args:
        *args: Variable number of GradientTransformation objects.

    Returns:
        GradientTransformationExtraArgs: Combined transformation.
    """
def named_chain(**transformations):
    """Chain transformations with names for easier debugging.

    Args:
        **transformations: Named GradientTransformation objects.

    Returns:
        GradientTransformation: Combined transformation with named states.
    """


def scale(step_size):
    """Scale updates by a constant factor.

    Args:
        step_size: Scaling factor (typically the negative learning rate).

    Returns:
        GradientTransformation
    """
def scale_by_learning_rate(learning_rate):
    """Scale updates by the learning rate (with negative sign).

    Args:
        learning_rate: Learning rate value or schedule.

    Returns:
        GradientTransformation
    """
def scale_by_schedule(schedule):
    """Scale updates by a schedule function.

    Args:
        schedule: Schedule function taking the step count and returning a
            scale factor.

    Returns:
        GradientTransformation
    """


def scale_by_adam(b1=0.9, b2=0.999, eps=1e-8, *, nesterov=False):
    """Scale updates using Adam-style adaptive scaling.

    Args:
        b1: Exponential decay rate for first moment estimates (default: 0.9).
        b2: Exponential decay rate for second moment estimates
            (default: 0.999).
        eps: Small constant for numerical stability (default: 1e-8).
        nesterov: Whether to use Nesterov momentum (default: False).

    Returns:
        GradientTransformation
    """
def scale_by_rms(decay=0.9, eps=1e-8):
    """Scale updates by the root mean square of past gradients.

    Args:
        decay: Decay rate for the moving average (default: 0.9).
        eps: Small constant for numerical stability (default: 1e-8).

    Returns:
        GradientTransformation
    """
def scale_by_stddev(decay=0.9, eps=1e-8):
    """Scale updates by the standard deviation of past gradients.

    Args:
        decay: Decay rate for the moving average (default: 0.9).
        eps: Small constant for numerical stability (default: 1e-8).

    Returns:
        GradientTransformation
    """


def trace(decay, nesterov=False, accumulator_dtype=None):
    """Add momentum/trace to gradient updates.

    Args:
        decay: Decay rate for the momentum accumulator (required — the
            signature declares no default, so the earlier "(default: 0.9)"
            claim was incorrect).
        nesterov: Whether to use Nesterov momentum (default: False).
        accumulator_dtype: Data type for the accumulator (default: None).

    Returns:
        GradientTransformation
    """
def ema(decay, debias=True, accumulator_dtype=None):
    """Exponential moving average of parameters.

    Args:
        decay: Decay rate for the moving average (required — the signature
            declares no default, so the earlier "(default: 0.9)" claim was
            incorrect).
        debias: Whether to debias the moving average (default: True).
        accumulator_dtype: Data type for the accumulator (default: None).

    Returns:
        GradientTransformation
    """


def clip(max_delta):
    """Clip updates element-wise to a maximum absolute value.

    Args:
        max_delta: Maximum absolute value for updates.

    Returns:
        GradientTransformation
    """
def clip_by_global_norm(max_norm):
    """Clip updates by their global norm.

    Args:
        max_norm: Maximum global norm for updates.

    Returns:
        GradientTransformation
    """
def clip_by_block_rms(threshold):
    """Clip updates by block-wise RMS.

    Args:
        threshold: RMS threshold for clipping.

    Returns:
        GradientTransformation
    """
def adaptive_grad_clip(clipping, eps=1e-3):
    """Adaptive gradient clipping.

    Args:
        clipping: Clipping threshold.
        eps: Small constant for numerical stability (default: 1e-3).

    Returns:
        GradientTransformation
    """
def per_example_global_norm_clip(l2_norm_clip, single_batch_element=False):
    """Per-example gradient clipping for differential privacy.

    Args:
        l2_norm_clip: L2 norm clipping threshold.
        single_batch_element: Whether the input is a single batch element
            (default: False).

    Returns:
        GradientTransformation
    """


def add_decayed_weights(weight_decay, mask=None):
    """Add L2 weight decay (weight regularization) to the updates.

    Args:
        weight_decay: Weight decay coefficient.
        mask: Optional mask selecting which parameters to decay.

    Returns:
        GradientTransformation
    """
def add_noise(eta, gamma, seed):
    """Add gradient noise for improved generalization.

    Args:
        eta: Noise scaling parameter.
        gamma: Annealing rate for the noise.
        seed: Random seed.

    Returns:
        GradientTransformation
    """


def centralize():
    """Centralize gradients by subtracting their mean.

    Returns:
        GradientTransformation
    """
def normalize_by_update_norm():
    """Normalize updates by their norm.

    Returns:
        GradientTransformation
    """
def scale_by_trust_ratio():
    """Scale updates by the trust ratio (parameter norm / update norm).

    Returns:
        GradientTransformation
    """


def apply_if_finite(transformation):
    """Apply a transformation only if the gradients are finite.

    Args:
        transformation: Transformation to apply conditionally.

    Returns:
        GradientTransformation
    """
def apply_every(k, transformation):
    """Apply a transformation every k steps.

    Args:
        k: Step interval.
        transformation: Transformation to apply periodically.

    Returns:
        GradientTransformation
    """
def conditionally_transform(condition_fn, transformation):
    """Apply a transformation based on a condition function.

    Args:
        condition_fn: Function that returns a boolean condition.
        transformation: Transformation to apply conditionally.

    Returns:
        GradientTransformation
    """


def partition(selector_fn, *transformations):
    """Apply different transformations to different parameter subsets.

    Args:
        selector_fn: Function that selects parameter subsets.
        *transformations: Transformations, one for each subset.

    Returns:
        GradientTransformation
    """
def masked(mask_fn, transformation):
    """Apply a transformation with parameter masking.

    Args:
        mask_fn: Function that generates the parameter mask.
        transformation: Transformation to apply under the mask.

    Returns:
        GradientTransformation
    """


def keep_params_nonnegative():
    """Keep parameters non-negative by projecting onto the positive orthant.

    Returns:
        GradientTransformation
    """
def zero_nans():
    """Set NaN gradients to zero.

    Returns:
        GradientTransformation
    """


class MultiSteps:
    """Multi-step gradient accumulation."""

    def __init__(self, every_k_schedule, use_grad_mean=True):
        """Initialize multi-step accumulation.

        Args:
            every_k_schedule: Schedule for the number of accumulation steps.
            use_grad_mean: Whether to accumulate the gradient mean instead
                of the sum (default: True).
        """
def skip_not_finite(updates, state, params=None):
    """Skip updates that are not finite.

    Args:
        updates: Gradient updates.
        state: Optimizer state.
        params: Optional parameters.

    Returns:
        Tuple of (updates, state).
    """
def skip_large_updates(updates, state, max_norm):
    """Skip updates whose norm is larger than a threshold.

    Args:
        updates: Gradient updates.
        state: Optimizer state.
        max_norm: Maximum allowed update norm.

    Returns:
        Tuple of (updates, state).
    """


import optax
# Create a custom optimizer by chaining transformations.
custom_optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),                # gradient clipping
    optax.add_decayed_weights(weight_decay=1e-4),  # weight decay
    optax.scale_by_adam(b1=0.9, b2=0.999),         # Adam scaling
    optax.scale(-0.001),                           # learning rate (negated)
)

# Initialize with parameters.
import jax.numpy as jnp  # needed for the example parameter pytree

params = {'w': jnp.ones((10, 5)), 'b': jnp.zeros((5,))}
opt_state = custom_optimizer.init(params)

# Apply different learning rates to different parameter groups
def is_bias(path, param):
    """Return True when the parameter path identifies a bias term."""
    return 'bias' in path


bias_tx = optax.scale(-0.01)     # higher learning rate for biases
weight_tx = optax.scale(-0.001)  # lower learning rate for weights
partitioned_optimizer = optax.partition(is_bias, bias_tx, weight_tx)

# Apply a transformation only every 5 steps.
sparse_optimizer = optax.apply_every(5, optax.adam(0.001))

# Robust optimizer with multiple safeguards
robust_optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),          # prevent exploding gradients
    optax.apply_if_finite(                   # skip non-finite updates
        optax.chain(
            optax.centralize(),              # center gradients
            optax.scale_by_adam(),           # adaptive scaling
            optax.add_decayed_weights(1e-4), # weight regularization
        )
    ),
    optax.scale_by_schedule(                 # learning-rate schedule
        optax.cosine_decay_schedule(0.001, 1000)
    ),
)
Install with the Tessl CLI: npx tessl i tessl/pypi-optax