"""A gradient processing and optimization library in JAX.

Utility functions for parameter updates, tree operations, numerical
stability, and working with JAX pytrees. These functions provide essential
infrastructure for building and using optimizers effectively.
"""
def apply_updates(params, updates):
    """Apply parameter updates to current parameters.

    Adds each leaf of ``updates`` to the corresponding leaf of ``params``.
    The result is cast back to the parameter leaf's dtype so mixed-precision
    updates do not silently promote the parameters.

    Args:
        params: Current parameters (pytree).
        updates: Parameter updates (pytree with same structure as params).

    Returns:
        Updated parameters (pytree with the same structure as ``params``).
    """
    return jax.tree_util.tree_map(
        lambda p, u: jnp.asarray(p + u).astype(jnp.asarray(p).dtype),
        params,
        updates,
    )
def incremental_update(new_tensors, old_tensors, step_size):
    """Interpolate between old and new tensors (Polyak-style averaging).

    Computes ``step_size * new + (1 - step_size) * old`` leaf-wise, so
    ``step_size=0`` keeps the old tensors and ``step_size=1`` adopts the
    new ones.

    Args:
        new_tensors: New tensor values (pytree).
        old_tensors: Old tensor values (pytree with the same structure).
        step_size: Step size for interpolation, typically in ``[0, 1]``.

    Returns:
        Incrementally updated tensors (pytree).
    """
    return jax.tree_util.tree_map(
        lambda new, old: step_size * new + (1.0 - step_size) * old,
        new_tensors,
        old_tensors,
    )
def periodic_update(new_tensors, old_tensors, steps, update_period):
    """Adopt ``new_tensors`` every ``update_period`` steps, else keep old.

    Args:
        new_tensors: New tensor values (pytree).
        old_tensors: Old tensor values (pytree with the same structure).
        steps: Current step count.
        update_period: Period (in steps) at which new values are taken.

    Returns:
        ``new_tensors`` when ``steps % update_period == 0``, otherwise
        ``old_tensors``.
    """
    take_new = steps % update_period == 0
    return jax.tree_util.tree_map(
        lambda new, old: jnp.where(take_new, new, old),
        new_tensors,
        old_tensors,
    )


def safe_norm(x, min_norm=0.0, ord=None):
    """Numerically stable norm computation.

    Returns ``max(norm(x), min_norm)``, evaluated so the gradient remains
    finite when ``x`` is (near) zero: where the true norm falls at or below
    ``min_norm``, the norm is re-evaluated on a dummy all-ones input whose
    result is discarded, avoiding the non-differentiable point at 0.

    Args:
        x: Input tensor.
        min_norm: Minimum norm value for stability (default: 0.0).
        ord: Norm order (None, 1, 2, 'fro', etc.; default None for L2).

    Returns:
        Norm value, never smaller than ``min_norm``.
    """
    norm = jnp.linalg.norm(x, ord=ord)
    # Mask out the degenerate region before re-taking the norm so the
    # gradient of jnp.linalg.norm is never evaluated at exactly zero.
    x = jnp.where(norm <= min_norm, jnp.ones_like(x), x)
    return jnp.where(norm <= min_norm, min_norm, jnp.linalg.norm(x, ord=ord))
def safe_root_mean_squares(x, min_rms=0.0):
    """Numerically stable root-mean-square computation.

    Returns ``max(rms(x), min_rms)``, evaluated so the gradient remains
    finite when ``x`` is (near) zero, using the same masking trick as
    ``safe_norm``.

    Args:
        x: Input tensor.
        min_rms: Minimum RMS value for stability (default: 0.0).

    Returns:
        RMS value, never smaller than ``min_rms``.
    """
    rms = jnp.sqrt(jnp.mean(jnp.square(x)))
    # Avoid differentiating sqrt at 0 by substituting a dummy input in the
    # region where the clamp is active anyway.
    x = jnp.where(rms <= min_rms, jnp.ones_like(x), x)
    return jnp.where(rms <= min_rms, min_rms, jnp.sqrt(jnp.mean(jnp.square(x))))
def safe_increment(count):
    """Increment a counter by one, saturating at the dtype's maximum.

    A plain ``count + 1`` wraps around on integer overflow; this clamps the
    counter at the maximum representable value of its dtype instead.

    Args:
        count: Current counter value (integer or floating-point array).

    Returns:
        ``count + 1``, or the dtype maximum once it has been reached.
    """
    count = jnp.asarray(count)
    if jnp.issubdtype(count.dtype, jnp.integer):
        max_value = jnp.iinfo(count.dtype).max
    else:
        max_value = jnp.finfo(count.dtype).max
    # Both branches are evaluated; the wrapped value is masked out.
    return jnp.where(count < max_value, count + 1, max_value)
def safe_int32_increment(count):
    """Increment an int32 counter by one, saturating at int32 max.

    Args:
        count: Current int32 counter value.

    Returns:
        ``count + 1``, clamped at ``2**31 - 1`` to avoid wrap-around.
    """
    max_int32 = jnp.iinfo(jnp.int32).max
    return jnp.where(count < max_int32, count + 1, max_int32)


def global_norm(updates):
    """Compute the global L2 norm across all leaves of a pytree.

    Equivalent to flattening every leaf into one vector and taking its
    Euclidean norm.

    Args:
        updates: Parameter updates (pytree).

    Returns:
        Global norm as a scalar.
    """
    leaves = jax.tree_util.tree_leaves(updates)
    # jnp.stack + jnp.sum rather than builtin sum(): this file also defines
    # a module-level `sum` (optax.tree API), which would shadow the builtin.
    total = jnp.sum(jnp.stack([jnp.sum(jnp.square(g)) for g in leaves]))
    return jnp.sqrt(total)
def power_iteration(matrix, num_iters=10, error_tolerance=1e-6, precision=None):
    """Compute the dominant eigenvalue and eigenvector using power iteration.

    Stub: this documentation file declares the interface only; no
    implementation is provided.

    Args:
        matrix: Square input matrix.
        num_iters: Maximum number of iterations (default: 10).
        error_tolerance: Convergence tolerance on successive iterates
            (default: 1e-6).
        precision: Optional numerical precision for matrix products.

    Returns:
        Tuple of ``(eigenvalue, eigenvector)`` for the dominant eigenvalue.
    """
def matrix_inverse_pth_root(matrix, p, num_iters=15, ridge_epsilon=1e-6, error_tolerance=1e-6, precision=None):
    """Compute the matrix inverse p-th root via a Newton-type iteration.

    Stub: this documentation file declares the interface only; no
    implementation is provided.

    Args:
        matrix: Input positive-definite matrix.
        p: Root order (e.g. 2 for the inverse square root).
        num_iters: Maximum iterations (default: 15).
        ridge_epsilon: Ridge regularization added for invertibility
            (default: 1e-6).
        error_tolerance: Convergence tolerance (default: 1e-6).
        precision: Optional numerical precision for matrix products.

    Returns:
        The matrix inverse p-th root, ``matrix**(-1/p)``.
    """
def nnls(a, b, max_iters=None, tol=1e-8):
    """Non-negative least squares: minimize ``||a @ x - b||`` with ``x >= 0``.

    Stub: this documentation file declares the interface only; no
    implementation is provided.

    Args:
        a: Coefficient matrix.
        b: Target vector.
        max_iters: Maximum iterations (default: None for automatic choice).
        tol: Convergence tolerance (default: 1e-8).

    Returns:
        Non-negative solution vector.
    """


def identity():
    """Build the identity transformation: gradients pass through unchanged.

    Stub: this documentation file declares the interface only.

    Returns:
        GradientTransformation
    """
def set_to_zero():
    """Build a transformation that replaces every gradient with zeros.

    Stub: this documentation file declares the interface only.

    Returns:
        GradientTransformation
    """
def stateless(f):
    """Create a stateless gradient transformation from a function.

    Stub: this documentation file declares the interface only.

    Args:
        f: Function mapping updates (and optionally params) to new updates.

    Returns:
        GradientTransformation
    """
def stateless_with_tree_map(f):
    """Create a stateless transformation that maps ``f`` over each leaf.

    Stub: this documentation file declares the interface only.

    Args:
        f: Function applied to each leaf of the update pytree.

    Returns:
        GradientTransformation
    """
def with_extra_args_support(transformation):
    """Extend a transformation so its ``update`` accepts extra arguments.

    Stub: this documentation file declares the interface only.

    Args:
        transformation: Base transformation to extend.

    Returns:
        GradientTransformationExtraArgs
    """


def scale_gradient(inputs, scale):
    """Identity in the forward pass; scales gradients in the backward pass.

    Implemented leaf-wise as ``scale * x + stop_gradient((1 - scale) * x)``:
    the two terms sum to ``x`` in the forward pass, while only the first
    term contributes to the gradient, which is therefore multiplied by
    ``scale``.

    Args:
        inputs: Input pytree (forward pass is the identity).
        scale: Scale factor applied to gradients in the backward pass.

    Returns:
        ``inputs``, unchanged in the forward pass.
    """
    return jax.tree_util.tree_map(
        lambda x: scale * x + jax.lax.stop_gradient((1.0 - scale) * x),
        inputs,
    )
def value_and_grad_from_state(fun, argnums=0, has_aux=False):
    """Compute value and gradient while reusing results stored in state.

    Stub: this documentation file declares the interface only.

    Args:
        fun: Function to differentiate.
        argnums: Argument indices to differentiate with respect to
            (default: 0).
        has_aux: Whether ``fun`` returns auxiliary data (default: False).

    Returns:
        Function that returns a ``(value, grad)`` tuple.
    """


def multi_normal(loc, scale_tril, random_key):
    """Sample once from a multivariate normal distribution.

    The sample is the affine transform ``loc + scale_tril @ z`` where ``z``
    is a standard normal draw, so the sample covariance is
    ``scale_tril @ scale_tril.T``.

    Args:
        loc: Mean vector.
        scale_tril: Lower-triangular (Cholesky) scale matrix.
        random_key: JAX PRNG key.

    Returns:
        Random sample with the same shape as ``loc``.
    """
    # NOTE(review): assumes ``loc`` is a 1-D vector — confirm with callers.
    z = jax.random.normal(random_key, jnp.shape(loc))
    return loc + scale_tril @ z

# Tree-level operations in optax.tree module
def add(tree1, tree2):
    """Leaf-wise addition of two pytrees with identical structure."""
    return jax.tree_util.tree_map(lambda a, b: a + b, tree1, tree2)


def sub(tree1, tree2):
    """Leaf-wise subtraction (``tree1 - tree2``)."""
    return jax.tree_util.tree_map(lambda a, b: a - b, tree1, tree2)


def mul(tree1, tree2):
    """Leaf-wise multiplication of two pytrees."""
    return jax.tree_util.tree_map(lambda a, b: a * b, tree1, tree2)


def div(tree1, tree2):
    """Leaf-wise division (``tree1 / tree2``)."""
    return jax.tree_util.tree_map(lambda a, b: a / b, tree1, tree2)


def scale(tree, scalar):
    """Multiply every leaf of ``tree`` by ``scalar``."""
    return jax.tree_util.tree_map(lambda a: scalar * a, tree)


def norm(tree, ord=2):
    """Norm of the pytree viewed as a single flat vector (default L2)."""
    flat = [jnp.ravel(leaf) for leaf in jax.tree_util.tree_leaves(tree)]
    return jnp.linalg.norm(jnp.concatenate(flat), ord=ord)


def sum(tree):  # noqa: A001 - name mirrors the optax.tree API
    """Sum of all elements across every leaf of ``tree``."""
    leaf_totals = [jnp.sum(leaf) for leaf in jax.tree_util.tree_leaves(tree)]
    return jnp.sum(jnp.stack(leaf_totals))


def max(tree):  # noqa: A001 - name mirrors the optax.tree API
    """Maximum element across every leaf of ``tree``."""
    leaf_maxima = [jnp.max(leaf) for leaf in jax.tree_util.tree_leaves(tree)]
    return jnp.max(jnp.stack(leaf_maxima))


def zeros_like(tree):
    """Pytree of zeros with the same structure, shapes, and dtypes."""
    return jax.tree_util.tree_map(jnp.zeros_like, tree)


def ones_like(tree):
    """Pytree of ones with the same structure, shapes, and dtypes."""
    return jax.tree_util.tree_map(jnp.ones_like, tree)


def full_like(tree, fill_value):
    """Pytree filled with ``fill_value``; same structure, shapes, dtypes."""
    return jax.tree_util.tree_map(lambda a: jnp.full_like(a, fill_value), tree)


def hungarian_algorithm(cost_matrix):
    """Hungarian algorithm for solving assignment problems.

    Stub: this documentation file declares the interface only.

    Args:
        cost_matrix: 2D cost matrix for assignments.

    Returns:
        Optimal assignment indices.
    """


def tree_map_params(fn, tree):
    """Map ``fn`` over each parameter leaf of ``tree``.

    NOTE(review): implemented as a plain leaf-wise map, matching the
    docstring here; confirm callers do not rely on optimizer-state-aware
    mapping.

    Args:
        fn: Function to apply to each parameter.
        tree: Parameter pytree.

    Returns:
        Transformed pytree.
    """
    return jax.tree_util.tree_map(fn, tree)
def tree_bias_correction(moment, decay, count):
    """Debias an exponential-moving-average moment estimate.

    Divides each leaf by ``1 - decay**count``, the standard Adam-style bias
    correction for an EMA initialized at zero.

    Args:
        moment: Moment estimate (pytree).
        decay: Decay rate used for the moment.
        count: Step count for bias correction.

    Returns:
        Bias-corrected moment (pytree).
    """
    bias_correction = 1.0 - decay**count
    return jax.tree_util.tree_map(lambda m: m / bias_correction, moment)


def tree_update_moment(updates, moments, decay, order):
    """Exponential-moving-average update of the order-th moment.

    Computes ``(1 - decay) * updates**order + decay * moments`` leaf-wise.

    Args:
        updates: Current gradient updates (pytree).
        moments: Previous moment estimates (pytree, same structure).
        decay: Exponential decay rate.
        order: Moment order (1 for mean, 2 for variance).

    Returns:
        Updated moment estimates (pytree).
    """
    return jax.tree_util.tree_map(
        lambda u, m: (1 - decay) * (u**order) + decay * m,
        updates,
        moments,
    )
def tree_update_moment_per_elem_norm(updates, moments, decay, order):
    """Update moment estimates using a per-element norm of the updates.

    Stub: this documentation file declares the interface only; the exact
    per-element normalization is not specified here.

    Args:
        updates: Current gradient updates (pytree).
        moments: Previous moment estimates (pytree, same structure).
        decay: Exponential decay rate.
        order: Moment order.

    Returns:
        Updated moment estimates with per-element normalization.
    """
def tree_update_infinity_moment(updates, moments, decay):
    """Update infinity-norm moments (max of decayed moment and |update|).

    Computes ``maximum(decay * moments, |updates|)`` leaf-wise, the
    Adamax-style infinity-moment recurrence.

    Args:
        updates: Current gradient updates (pytree).
        moments: Previous infinity moments (pytree, same structure).
        decay: Exponential decay rate.

    Returns:
        Updated infinity moments (pytree).
    """
    return jax.tree_util.tree_map(
        lambda u, m: jnp.maximum(decay * m, jnp.abs(u)),
        updates,
        moments,
    )

# Type aliases
# Pytree-based type aliases for the optimizer API.
# NOTE(review): `chex`, `Callable`, `Optional`, `Tuple`, `Union`, `Any`, and
# `jax` are referenced here but never imported in this file — confirm upstream.
OptState = chex.ArrayTree  # Optimizer state (arbitrary pytree of arrays)
Params = chex.ArrayTree  # Model parameters (pytree)
Updates = Params  # Gradient updates share the parameter pytree structure
Schedule = Callable[[chex.Numeric], chex.Numeric]  # Step count -> scalar value
ScalarOrSchedule = Union[float, jax.Array, Schedule]  # Fixed value or schedule
MaskOrFn = Union[Any, Callable[[Params], Any]]  # Mask pytree or masking function
# Function type definitions for the init/update protocol
TransformInitFn = Callable[[Params], OptState]  # params -> initial state
TransformUpdateFn = Callable[[Updates, OptState, Optional[Params]], Tuple[Updates, OptState]]
# NOTE(review): `...` inside a Callable argument list is not valid typing
# syntax — probably intended as Callable[..., Tuple[Updates, OptState]].
TransformUpdateExtraArgsFn = Callable[[Updates, OptState, Optional[Params], ...], Tuple[Updates, OptState]]
# Core classes
class GradientTransformation(NamedTuple):
    """Core gradient transformation: an ``init``/``update`` function pair."""
    # Builds the initial optimizer state from the parameters.
    init: TransformInitFn
    # Maps (updates, state, params) -> (new_updates, new_state).
    update: TransformUpdateFn


class GradientTransformationExtraArgs(NamedTuple):
    """Gradient transformation whose ``update`` accepts extra arguments."""
    init: TransformInitFn
    update: TransformUpdateExtraArgsFn


class EmptyState(NamedTuple):
    """Empty state for stateless transformations."""


class FactoredState(NamedTuple):
    """State for factorized operations (row/column second-moment factors)."""
    count: chex.Array  # step counter
    v_row: chex.ArrayTree  # row statistics
    v_col: chex.ArrayTree  # column statistics

import jax.numpy as jnp
import jax
import jax.numpy as jnp

import optax
# Example: applying updates and measuring gradient magnitude.
# Parameters and updates
params = {'w': jnp.ones((5, 3)), 'b': jnp.zeros((3,))}
updates = {'w': jnp.ones((5, 3)) * 0.01, 'b': jnp.ones((3,)) * 0.001}
# Apply updates (leaf-wise params + updates)
new_params = optax.apply_updates(params, updates)
# Compute global norm (L2 norm over every leaf of the pytree)
grad_norm = optax.global_norm(updates)
print(f"Global gradient norm: {grad_norm}")

# Safe operations for numerical stability
x = jnp.array([1e-8, 1e-6, 1.0, 1e6])
safe_norm_val = optax.safe_norm(x, min_norm=1e-8)
safe_rms_val = optax.safe_root_mean_squares(x, min_rms=1e-8)
# Safe counting: saturates instead of wrapping on int32 overflow
step_count = jnp.array(2147483647, dtype=jnp.int32)  # int32 max
next_count = optax.safe_int32_increment(step_count)

# Tree arithmetic
tree1 = {'a': jnp.array([1, 2, 3]), 'b': jnp.array([4, 5])}
tree2 = {'a': jnp.array([6, 7, 8]), 'b': jnp.array([9, 10])}
# Element-wise operations (leaf-wise over matching pytree structures)
sum_tree = optax.tree.add(tree1, tree2)
scaled_tree = optax.tree.scale(tree1, 0.5)
tree_norm = optax.tree.norm(tree1)
# Tree utilities: same structure/shape/dtype, constant-filled leaves
zero_tree = optax.tree.zeros_like(tree1)
ones_tree = optax.tree.ones_like(tree1)

# Create custom stateless transformation
def my_scaling_fn(updates):
    """Scale every update leaf by a fixed factor of 0.01."""
    # jax.tree_util.tree_map is the stable API; the jax.tree_map alias was
    # deprecated and removed in recent JAX releases.
    return jax.tree_util.tree_map(lambda x: 0.01 * x, updates)


my_transform = optax.stateless(my_scaling_fn)

# Compose with other transformations: clip first, then custom scaling,
# then Adam preconditioning.
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    my_transform,
    optax.scale_by_adam(),
)

# Matrix operations for second-order methods
def compute_preconditioner(gradients):
    """Build an inverse-square-root preconditioner from a gradient pytree.

    NOTE(review): assumes ``optax.matrix_inverse_pth_root`` returns the
    matrix directly (as documented in this file); some optax versions
    return an ``(inv_root, error)`` pair — confirm before use.
    NOTE(review): ``jax.flatten_util`` may require an explicit
    ``import jax.flatten_util`` — confirm.
    """
    # Flatten the gradient pytree into one vector for dense matrix ops.
    flat_grads = jax.flatten_util.ravel_pytree(gradients)[0]
    # Rank-1 outer-product approximation of the gradient covariance.
    outer_prod = jnp.outer(flat_grads, flat_grads)
    # Inverse square root (p=2); the ridge term keeps the matrix invertible.
    inv_sqrt = optax.matrix_inverse_pth_root(
        outer_prod + 1e-6 * jnp.eye(len(flat_grads)),
        p=2,
        num_iters=10
    )
    return inv_sqrt
# Gradient scaling with state
def scale_with_state(inputs, state):
    """Scale gradients by the square root of the step count held in state.

    NOTE(review): assumes ``state['step_count']`` is a non-negative numeric
    leaf — confirm against the caller.
    """
    scale_factor = jnp.sqrt(state['step_count'])
    return optax.scale_gradient(inputs, scale_factor)

# Install with Tessl CLI:
#   npx tessl i tessl/pypi-optax