CtrlK
Blog | Docs | Log in | Get started
Tessl Logo

tessl/pypi-optax

A gradient processing and optimization library in JAX

Pending
Overview
Eval results
Files

docs/utilities.md

Utilities and Tree Operations

Utility functions for parameter updates, tree operations, numerical stability, and working with JAX pytrees. These functions provide essential infrastructure for building and using optimizers effectively.

Capabilities

Parameter Updates

Core Update Functions

def apply_updates(params, updates):
    """
    Apply a pytree of updates to a matching pytree of parameters.

    Args:
        params: Current parameters (pytree).
        updates: Updates to apply; a pytree with the same structure as
            ``params``.

    Returns:
        The updated parameters (pytree).

    NOTE(review): upstream Optax adds the two trees leaf by leaf
    (``params + updates``) -- confirm against the installed release.
    """

def incremental_update(new_tensors, old_tensors, step_size):
    """
    Move ``old_tensors`` towards ``new_tensors`` by interpolation.

    Args:
        new_tensors: New tensor values.
        old_tensors: Old tensor values.
        step_size: Interpolation step size.

    Returns:
        The incrementally updated tensors.

    NOTE(review): presumably computes
    ``step_size * new + (1 - step_size) * old`` (Polyak averaging) --
    confirm against the Optax API reference.
    """

def periodic_update(new_tensors, old_tensors, steps, update_period):
    """
    Select between new and old tensors depending on the step count.

    Args:
        new_tensors: New tensor values.
        old_tensors: Old tensor values.
        steps: Current step count.
        update_period: Number of steps between updates.

    Returns:
        The conditionally updated tensors.

    NOTE(review): presumably returns ``new_tensors`` when
    ``steps % update_period == 0`` and ``old_tensors`` otherwise -- confirm.
    """

Numerical Utilities

Safe Operations

def safe_norm(x, min_norm=0.0, ord=None):
    """
    Compute a norm of ``x`` that is bounded below by ``min_norm``.

    Args:
        x: Input tensor.
        min_norm: Lower bound applied to the norm for stability
            (default: 0.0).
        ord: Norm order (None, 1, 2, 'fro', etc.); None selects the L2 norm.

    Returns:
        The numerically stable norm value.

    NOTE(review): upstream Optax appears to require ``min_norm`` (no
    default) and to accept ``axis``/``keepdims`` too -- confirm against the
    installed release.
    """

def safe_root_mean_squares(x, min_rms=0.0):
    """
    Compute the root mean square of ``x``, bounded below by ``min_rms``.

    Args:
        x: Input tensor.
        min_rms: Lower bound applied to the RMS for stability (default: 0.0).

    Returns:
        The numerically stable RMS value.

    NOTE(review): upstream Optax appears to require ``min_rms`` (no
    default) -- confirm against the installed release.
    """

def safe_increment(count):
    """
    Increment a counter while guarding against overflow.

    Args:
        count: Current counter value.

    Returns:
        The incremented counter value.

    NOTE(review): presumably the counter saturates at its dtype's maximum
    instead of wrapping around -- confirm.
    """

def safe_int32_increment(count):
    """
    Increment an int32 counter while guarding against overflow.

    Args:
        count: Current int32 counter value.

    Returns:
        The incremented int32 counter value.

    NOTE(review): presumably the counter stays at the int32 maximum once
    reached instead of wrapping to a negative value -- confirm.
    """

Linear Algebra

Matrix Operations

def global_norm(updates):
    """
    Compute a single norm over every leaf of a pytree.

    Args:
        updates: Parameter updates (pytree).

    Returns:
        The global norm as a scalar.

    NOTE(review): presumably the L2 norm of all leaves taken together
    (square root of the total sum of squares) -- confirm.
    """

def power_iteration(matrix, num_iters=10, error_tolerance=1e-6, precision=None):
    """
    Estimate the dominant eigenvalue and eigenvector by power iteration.

    Args:
        matrix: Input matrix.
        num_iters: Maximum number of iterations (default: 10).
        error_tolerance: Convergence tolerance; iteration can stop early
            once it is met (default: 1e-6).
        precision: Numerical precision (default: None).

    Returns:
        Tuple of (eigenvalue, eigenvector).

    NOTE(review): the defaults and the order of the returned pair differ
    between Optax releases -- confirm against the installed release.
    """

def matrix_inverse_pth_root(matrix, p, num_iters=15, ridge_epsilon=1e-6, error_tolerance=1e-6, precision=None):
    """
    Compute the inverse p-th root of a matrix using Newton's method.

    Args:
        matrix: Input positive definite matrix.
        p: Root order (e.g., 2 for the inverse square root).
        num_iters: Maximum iterations (default: 15).
        ridge_epsilon: Ridge regularization (default: 1e-6).
        error_tolerance: Convergence tolerance (default: 1e-6).
        precision: Numerical precision (default: None).

    Returns:
        Matrix inverse p-th root.

    NOTE(review): upstream Optax appears to return a
    ``(inverse_root, error)`` pair rather than the matrix alone, and to use
    different defaults -- confirm against the installed release.
    """

def nnls(a, b, max_iters=None, tol=1e-8):
    """
    Solve a least-squares problem under a non-negativity constraint.

    Args:
        a: Coefficient matrix.
        b: Target vector.
        max_iters: Maximum iterations; None lets the solver choose
            (default: None).
        tol: Convergence tolerance (default: 1e-8).

    Returns:
        The non-negative solution vector.

    NOTE(review): this signature could not be matched to the upstream Optax
    solver -- confirm parameter names against the installed release.
    """

Core Types and Base Functions

Base Transformations

def identity():
    """
    Build a transformation that passes gradients through unchanged.

    Returns:
        GradientTransformation
    """

def set_to_zero():
    """
    Build a transformation that replaces every gradient with zero.

    Returns:
        GradientTransformation
    """

def stateless(f):
    """
    Wrap a plain function as a transformation that keeps no state.

    Args:
        f: Function to convert to a transformation.

    Returns:
        GradientTransformation

    NOTE(review): upstream Optax calls ``f(updates, params)``, so ``f``
    should accept both arguments -- confirm against the API reference.
    """

def stateless_with_tree_map(f):
    """
    Wrap a per-leaf function as a transformation that keeps no state.

    Args:
        f: Function applied to each leaf of the parameter tree.

    Returns:
        GradientTransformation
    """

def with_extra_args_support(transformation):
    """
    Extend a transformation so that it supports extra arguments.

    Args:
        transformation: Base transformation to extend.

    Returns:
        GradientTransformationExtraArgs
    """

Utility Functions

Gradient Processing

def scale_gradient(inputs, scale):
    """
    Return ``inputs`` unchanged while scaling the gradient that flows back.

    Args:
        inputs: Input values; the forward pass is the identity.
        scale: Factor applied to the gradients in the backward pass.

    Returns:
        ``inputs``, unchanged in the forward pass.
    """

def value_and_grad_from_state(fun, argnums=0, has_aux=False):
    """
    Build a value-and-gradient function that works with optimizer state.

    Args:
        fun: Function to differentiate.
        argnums: Argument indices to differentiate (default: 0).
        has_aux: Whether ``fun`` returns auxiliary data (default: False).

    Returns:
        Function that returns a (value, grad) tuple.

    NOTE(review): upstream Optax appears to take a single ``value_fn``
    argument here -- confirm this signature against the installed release.
    """

Random Utilities

def multi_normal(loc, scale_tril, random_key):
    """
    Draw a sample from a multivariate normal distribution.

    Args:
        loc: Mean vector.
        scale_tril: Lower triangular scale matrix.
        random_key: JAX random key.

    Returns:
        A random sample from the multivariate normal.

    NOTE(review): upstream Optax appears to expose
    ``multi_normal(loc, log_scale)`` returning a distribution object --
    confirm this signature against the installed release.
    """

Tree Operations

Basic Tree Arithmetic

# Tree-level operations in optax.tree module
def add(tree1, tree2):
    """
    Add two pytrees element-wise.

    Args:
        tree1: First pytree.
        tree2: Second pytree.

    Returns:
        Pytree holding the element-wise sums.
    """

def sub(tree1, tree2):
    """
    Subtract two pytrees element-wise.

    Args:
        tree1: Pytree to subtract from.
        tree2: Pytree to subtract.

    Returns:
        Pytree holding the element-wise differences.
    """

def mul(tree1, tree2):
    """
    Multiply two pytrees element-wise.

    Args:
        tree1: First pytree.
        tree2: Second pytree.

    Returns:
        Pytree holding the element-wise products.
    """

def div(tree1, tree2):
    """
    Divide two pytrees element-wise.

    Args:
        tree1: Numerator pytree.
        tree2: Denominator pytree.

    Returns:
        Pytree holding the element-wise quotients.
    """

def scale(tree, scalar):
    """
    Scale all elements in a pytree by a scalar.

    Args:
        tree: Pytree to scale.
        scalar: Scale factor.

    Returns:
        The scaled pytree.

    NOTE(review): upstream ``optax.tree.scale`` appears to take the scalar
    FIRST (``scale(scalar, tree)``) -- confirm the argument order before
    calling positionally.
    """

def norm(tree, ord=2):
    """
    Compute the norm of a pytree.

    Args:
        tree: Input pytree.
        ord: Norm order (default: 2).

    Returns:
        The norm of the pytree.
    """

def sum(tree):
    """
    Sum all elements in a pytree.

    Args:
        tree: Input pytree.

    Returns:
        The sum over all elements.
    """

def max(tree):
    """
    Find the maximum element in a pytree.

    Args:
        tree: Input pytree.

    Returns:
        The maximum element.
    """

Tree Utilities

def zeros_like(tree):
    """
    Create a pytree of zeros with the same structure as ``tree``.

    Args:
        tree: Pytree whose structure is copied.

    Returns:
        Pytree of zeros.
    """

def ones_like(tree):
    """
    Create a pytree of ones with the same structure as ``tree``.

    Args:
        tree: Pytree whose structure is copied.

    Returns:
        Pytree of ones.
    """

def full_like(tree, fill_value):
    """
    Create a pytree filled with a specified value.

    Args:
        tree: Pytree that the result is modelled on.
        fill_value: Value placed in every element.

    Returns:
        Pytree filled with ``fill_value``.
    """

Assignment Module

Hungarian Algorithm

def hungarian_algorithm(cost_matrix):
    """
    Solve an assignment problem with the Hungarian algorithm.

    Args:
        cost_matrix: 2D cost matrix for assignments.

    Returns:
        Optimal assignment indices.

    NOTE(review): presumably a pair of index arrays (rows, columns) --
    confirm against the Optax assignment documentation.
    """

Tree Utils Module

Parameter Tree Manipulation

def tree_map_params(fn, tree):
    """
    Map a function over the parameters held in a pytree.

    Args:
        fn: Function to apply to each parameter.
        tree: Parameter pytree.

    Returns:
        The transformed pytree.

    NOTE(review): upstream Optax appears to use
    ``tree_map_params(initable, f, state, ...)`` -- confirm this signature
    against the installed release.
    """

def tree_bias_correction(moment, decay, count):
    """
    Apply bias correction to a moment estimate.

    Args:
        moment: Moment estimate.
        decay: Decay rate that was used to accumulate the moment.
        count: Step count used for the correction.

    Returns:
        The bias-corrected moment.

    NOTE(review): presumably divides by ``1 - decay**count`` -- confirm.
    """

Moment Updates

def tree_update_moment(updates, moments, decay, order):
    """
    Update the moment estimates kept in optimizer state.

    Args:
        updates: Current gradient updates.
        moments: Previous moment estimates.
        decay: Exponential decay rate.
        order: Moment order (1 for mean, 2 for variance).

    Returns:
        The updated moment estimates.

    NOTE(review): presumably
    ``(1 - decay) * updates**order + decay * moments`` -- confirm.
    """

def tree_update_moment_per_elem_norm(updates, moments, decay, order):
    """
    Update moment estimates using per-element normalization.

    Args:
        updates: Current gradient updates.
        moments: Previous moment estimates.
        decay: Exponential decay rate.
        order: Moment order.

    Returns:
        The updated moment estimates with per-element normalization.
    """

def tree_update_infinity_moment(updates, moments, decay):
    """
    Update the infinity moments (running maximum absolute values).

    Args:
        updates: Current gradient updates.
        moments: Previous infinity moments.
        decay: Exponential decay rate.

    Returns:
        The updated infinity moments.

    NOTE(review): upstream Optax appears to take an additional ``eps``
    argument -- confirm against the installed release.
    """

Type Definitions

# Type aliases
OptState = chex.ArrayTree       # Optimizer state
Params = chex.ArrayTree         # Model parameters
Updates = Params                # Gradient updates
Schedule = Callable[[chex.Numeric], chex.Numeric]  # Schedule function
ScalarOrSchedule = Union[float, jax.Array, Schedule]  # Flexible numeric type
MaskOrFn = Union[Any, Callable[[Params], Any]]  # Mask or masking function

# Function type definitions
TransformInitFn = Callable[[Params], OptState]
TransformUpdateFn = Callable[[Updates, OptState, Optional[Params]], Tuple[Updates, OptState]]
# An ellipsis is only valid in ``Callable`` as the ENTIRE parameter list, not
# as a trailing element of it, so the extra-args variant uses an open
# parameter list. It is called as (updates, state, params, **extra_args).
TransformUpdateExtraArgsFn = Callable[..., Tuple[Updates, OptState]]

# Core classes
class GradientTransformation(NamedTuple):
    """Core gradient transformation: a pair of ``init`` and ``update`` functions."""
    # Builds the optimizer state from the parameters.
    init: TransformInitFn
    # Maps (updates, state, optional params) to (new updates, new state).
    update: TransformUpdateFn

class GradientTransformationExtraArgs(NamedTuple):
    """Gradient transformation whose ``update`` function supports extra arguments."""
    # Builds the optimizer state from the parameters.
    init: TransformInitFn
    # Like the plain update function, but also accepts extra arguments.
    update: TransformUpdateExtraArgsFn

class EmptyState(NamedTuple):
    """Field-less state for transformations that keep no state."""

class FactoredState(NamedTuple):
    """State for factorized operations."""
    # presumably the number of steps taken so far -- TODO confirm
    count: chex.Array
    # presumably the row-wise factor of the factored statistics -- TODO confirm
    v_row: chex.ArrayTree
    # presumably the column-wise factor of the factored statistics -- TODO confirm
    v_col: chex.ArrayTree

Usage Examples

Basic Parameter Updates

import jax
import jax.numpy as jnp
import optax

# Parameters and a pytree of updates with the same keys and shapes
params = {'w': jnp.ones((5, 3)), 'b': jnp.zeros((3,))}
updates = {'w': jnp.ones((5, 3)) * 0.01, 'b': jnp.ones((3,)) * 0.001}

# Apply the updates to the parameters
new_params = optax.apply_updates(params, updates)

# Compute one global norm across the whole updates pytree
grad_norm = optax.global_norm(updates)
print(f"Global gradient norm: {grad_norm}")

Numerical Stability

# Safe operations for numerical stability
x = jnp.array([1e-8, 1e-6, 1.0, 1e6])

# min_norm / min_rms act as lower bounds on the returned values
safe_norm_val = optax.safe_norm(x, min_norm=1e-8)
safe_rms_val = optax.safe_root_mean_squares(x, min_rms=1e-8)

# Safe counting
step_count = jnp.array(2147483647, dtype=jnp.int32)  # Exactly the int32 maximum (2**31 - 1)
# A plain `step_count + 1` would overflow here; the safe variant protects against that
next_count = optax.safe_int32_increment(step_count)

Tree Operations

# Tree arithmetic
tree1 = {'a': jnp.array([1, 2, 3]), 'b': jnp.array([4, 5])}
tree2 = {'a': jnp.array([6, 7, 8]), 'b': jnp.array([9, 10])}

# Element-wise operations
sum_tree = optax.tree.add(tree1, tree2)
# optax.tree.scale takes the scalar first and the tree second; passing them
# the other way round tries to multiply a dict by a float and fails.
scaled_tree = optax.tree.scale(0.5, tree1)
tree_norm = optax.tree.norm(tree1)

# Tree utilities
zero_tree = optax.tree.zeros_like(tree1)
ones_tree = optax.tree.ones_like(tree1)

Custom Transformations

# Create custom stateless transformation
def my_scaling_fn(updates, params=None):
    """
    Scale every leaf of the updates pytree by 0.01.

    Args:
        updates: Gradient updates (pytree).
        params: Current parameters; unused, but accepted because
            ``optax.stateless`` calls the wrapped function as
            ``f(updates, params)`` (default: None).

    Returns:
        Pytree with the same structure as ``updates``, each leaf multiplied
        by 0.01.
    """
    del params  # Not needed for a pure rescaling.
    # jax.tree_map was removed from JAX; jax.tree_util.tree_map is available
    # in every release.
    return jax.tree_util.tree_map(lambda x: 0.01 * x, updates)

# Wrap the plain function as a GradientTransformation
my_transform = optax.stateless(my_scaling_fn)

# Use with other transformations
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    my_transform,
    optax.scale_by_adam()
)

Advanced Usage

# Matrix operations for second-order methods
def compute_preconditioner(gradients):
    """
    Build a dense preconditioner from the outer product of the gradients.

    Args:
        gradients: Gradient pytree.

    Returns:
        The value returned by ``optax.matrix_inverse_pth_root`` with ``p=2``
        for the damped outer product of the flattened gradients.
    """
    # jax.flatten_util is a submodule that a bare `import jax` does not load,
    # so import the helper explicitly instead of using attribute access.
    from jax.flatten_util import ravel_pytree

    # Flatten gradients into one 1-D vector for matrix operations
    flat_grads, _ = ravel_pytree(gradients)

    # Compute outer product approximation
    outer_prod = jnp.outer(flat_grads, flat_grads)

    # Compute matrix inverse square root; the small diagonal term keeps the
    # rank-one outer product positive definite.
    # NOTE(review): upstream Optax appears to return an (inverse_root, error)
    # pair here -- if so, unpack it before using the matrix. Confirm.
    inv_sqrt = optax.matrix_inverse_pth_root(
        outer_prod + 1e-6 * jnp.eye(flat_grads.shape[0]),
        p=2,
        num_iters=10
    )

    return inv_sqrt

# Gradient scaling with state
def scale_with_state(inputs, state):
    """
    Pass ``inputs`` through ``optax.scale_gradient`` with a state-derived factor.

    Args:
        inputs: Values handed to ``optax.scale_gradient``.
        state: Mapping with a ``'step_count'`` entry.

    Returns:
        The result of ``optax.scale_gradient`` with a scale of
        ``sqrt(state['step_count'])``.
    """
    return optax.scale_gradient(inputs, jnp.sqrt(state['step_count']))

Install with Tessl CLI

npx tessl i tessl/pypi-optax

docs

advanced-optimizers.md

assignment.md

contrib.md

index.md

losses.md

monte-carlo.md

optimizers.md

perturbations.md

projections.md

schedules.md

second-order.md

transformations.md

tree-utilities.md

utilities.md

tile.json