Optimization

NumPyro provides a collection of gradient-based optimizers for parameter learning in variational inference and maximum likelihood estimation. All optimizers are built on JAX for efficient automatic differentiation and support JIT compilation for high-performance optimization.

Capabilities

Core Optimizer Infrastructure

Base classes and utilities for the optimization system.

class Optimizer:
    """
    Base class for optimizers in NumPyro.
    
    All optimizers follow the same interface pattern for consistency
    with JAX optimization libraries like optax.
    """
    def init(self, params: dict) -> Any:
        """
        Initialize optimizer state.
        
        Args:
            params: Initial parameter values
            
        Returns:
            Initial optimizer state
        """
    
    def update(self, grads: dict, state: Any, params: dict) -> tuple:
        """
        Update parameters based on gradients.
        
        Args:
            grads: Parameter gradients
            state: Current optimizer state
            params: Current parameter values
            
        Returns:
            Tuple of (updates, new_state)
        """
    
    def get_params(self, state: Any) -> dict:
        """Get current parameter values from optimizer state."""

Adaptive Learning Rate Optimizers

Optimizers that adapt learning rates based on gradient history.

class Adam:
    """
    Adaptive Moment Estimation (Adam) optimizer.
    
    Computes individual adaptive learning rates for different parameters from 
    estimates of first and second moments of the gradients.
    
    Args:
        step_size: Learning rate (default: 0.001)
        b1: Exponential decay rate for first moment estimates (default: 0.9)
        b2: Exponential decay rate for second moment estimates (default: 0.999)
        eps: Small constant for numerical stability (default: 1e-8)
        
    Usage:
        optimizer = Adam(step_size=0.01)
        opt_state = optimizer.init(params)
        
        for step in range(num_steps):
            grads = compute_gradients(params)
            updates, opt_state = optimizer.update(grads, opt_state, params)
            params = apply_updates(params, updates)
    """
    def __init__(self, step_size: float = 0.001, b1: float = 0.9, 
                b2: float = 0.999, eps: float = 1e-8): ...

class ClippedAdam:
    """
    Adam optimizer with gradient clipping for improved stability.
    
    Args:
        step_size: Learning rate
        b1: First moment decay rate
        b2: Second moment decay rate  
        eps: Numerical stability constant
        clip_norm: Maximum gradient norm for clipping
        
    Usage:
        # Useful for training on unstable loss landscapes
        optimizer = ClippedAdam(step_size=0.01, clip_norm=1.0)
        opt_state = optimizer.init(params)
    """
    def __init__(self, step_size: float = 0.001, b1: float = 0.9,
                b2: float = 0.999, eps: float = 1e-8, clip_norm: float = 10.0): ...

class Adagrad:
    """
    Adaptive Gradient Algorithm (Adagrad) optimizer.
    
    Adapts learning rate to parameters, performing smaller updates for parameters
    associated with frequently occurring features.
    
    Args:
        step_size: Initial learning rate (default: 0.01)
        eps: Small constant for numerical stability (default: 1e-8)
        
    Usage:
        # Good for sparse data and features
        optimizer = Adagrad(step_size=0.1)
        opt_state = optimizer.init(params)
    """
    def __init__(self, step_size: float = 0.01, eps: float = 1e-8): ...

class RMSProp:
    """
    Root Mean Square Propagation (RMSProp) optimizer.
    
    Maintains a moving average of squared gradients to normalize the gradient.
    
    Args:
        step_size: Learning rate (default: 0.01)
        decay: Decay rate for moving average (default: 0.9)
        eps: Small constant for numerical stability (default: 1e-8)
        
    Usage:
        # Good for non-stationary objectives
        optimizer = RMSProp(step_size=0.01, decay=0.9)
        opt_state = optimizer.init(params)
    """
    def __init__(self, step_size: float = 0.01, decay: float = 0.9, eps: float = 1e-8): ...

class RMSPropMomentum:
    """
    RMSProp with momentum for improved convergence.
    
    Args:
        step_size: Learning rate
        decay: Decay rate for squared gradient moving average
        momentum: Momentum coefficient
        eps: Numerical stability constant
        centered: Whether to use centered RMSProp variant
        
    Usage:
        # Combines benefits of RMSProp and momentum
        optimizer = RMSPropMomentum(step_size=0.01, momentum=0.9)
        opt_state = optimizer.init(params)
    """
    def __init__(self, step_size: float = 0.01, decay: float = 0.9, 
                momentum: float = 0.0, eps: float = 1e-8, centered: bool = False): ...

Momentum-Based Optimizers

Optimizers that use momentum to accelerate convergence.

class SGD:
    """
    Stochastic Gradient Descent optimizer.
    
    Basic gradient descent with optional momentum.
    
    Args:
        step_size: Learning rate (default: 0.01)
        momentum: Momentum coefficient (default: 0.0)
        
    Usage:
        # Simple gradient descent
        optimizer = SGD(step_size=0.01)
        
        # With momentum for faster convergence
        optimizer = SGD(step_size=0.01, momentum=0.9)
        opt_state = optimizer.init(params)
    """
    def __init__(self, step_size: float = 0.01, momentum: float = 0.0): ...

class Momentum:
    """
    Stochastic Gradient Descent with momentum.
    
    Accelerates gradient descent by accumulating a velocity vector in directions
    of persistent reduction in the objective function.
    
    Args:
        step_size: Learning rate (default: 0.01)  
        mass: Momentum coefficient (default: 0.9)
        
    Usage:
        # Classical momentum SGD
        optimizer = Momentum(step_size=0.01, mass=0.9)
        opt_state = optimizer.init(params)
    """
    def __init__(self, step_size: float = 0.01, mass: float = 0.9): ...
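
For reference, the velocity accumulation described above corresponds to the classical heavy-ball update. The sketch below states that rule with the step_size and mass arguments of the constructor; it is an illustration of the standard formula, not necessarily the exact library implementation.

# Classical (heavy-ball) momentum:
#   v_{t+1} = mass * v_t + grad
#   p_{t+1} = p_t - step_size * v_{t+1}
def momentum_step(param, velocity, grad, step_size=0.01, mass=0.9):
    velocity = mass * velocity + grad
    param = param - step_size * velocity
    return param, velocity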

Specialized Optimizers

Advanced optimizers for specific use cases.

class SM3:
    """
    SM3 (Square-root of Minima of Sums of Maxima of Squared-gradients Method) optimizer.
    
    Memory-efficient adaptive optimizer that shares second-moment accumulators
    across parameter slices instead of storing a full per-entry accumulator.
    
    Args:
        step_size: Learning rate (default: 0.01)
        eps: Small constant for numerical stability (default: 1e-8)
        
    Usage:
        # Memory-efficient alternative to Adam for large models
        optimizer = SM3(step_size=0.01)
        opt_state = optimizer.init(params)
    """
    def __init__(self, step_size: float = 0.01, eps: float = 1e-8): ...

class Minimize:
    """
    Wrapper for JAX's minimize function for direct optimization.
    
    Uses JAX's built-in optimization routines (e.g. BFGS) for direct
    minimization of objective functions.
    
    Args:
        method: Optimization method passed to the underlying JAX minimizer (default: 'BFGS')
        options: Additional options forwarded to the underlying minimizer
        
    Usage:
        # For objectives where full optimization is preferred over SGD
        optimizer = Minimize(method='BFGS')
        
        # Direct minimization (different interface)
        result = optimizer.minimize(loss_fn, init_params)
    """
    def __init__(self, method: str = 'BFGS', options: Optional[dict] = None): ...
    
    def minimize(self, fun: Callable, x0: dict, *args, **kwargs) -> dict:
        """
        Minimize objective function.
        
        Args:
            fun: Objective function to minimize
            x0: Initial parameter values
            *args: Additional arguments to objective function
            **kwargs: Additional keyword arguments
            
        Returns:
            Optimization result with final parameters and metadata
        """

Optimizer Utilities

Utility functions for working with optimizers and optimization schedules.

def multi_transform(transforms: dict, param_labels: dict) -> Optimizer:
    """
    Apply different optimizers to different parameter groups.
    
    Args:
        transforms: Dictionary mapping labels to optimizers
        param_labels: Dictionary mapping parameter names to labels
        
    Returns:
        Combined optimizer that applies appropriate transform to each parameter group
        
    Usage:
        # Different learning rates for different parameter groups
        transforms = {
            'weights': Adam(0.01),
            'biases': Adam(0.1)
        }
        param_labels = {
            'layer1.weight': 'weights',
            'layer1.bias': 'biases'
        }
        optimizer = multi_transform(transforms, param_labels)
    """

def exponential_decay(step_size: float, decay_steps: int, 
                     decay_rate: float, staircase: bool = False) -> Callable:
    """
    Create exponential learning rate decay schedule.
    
    Args:
        step_size: Initial learning rate
        decay_steps: Number of steps after which to apply decay
        decay_rate: Decay factor
        staircase: Whether to apply decay in discrete steps
        
    Returns:
        Learning rate schedule function
        
    Usage:
        schedule = exponential_decay(0.1, decay_steps=1000, decay_rate=0.96)
        optimizer = Adam(step_size=schedule)
    """

def polynomial_decay(step_size: float, transition_steps: int, 
                    transition_begin: int = 0, power: float = 1.0,
                    end_value: float = 0.0) -> Callable:  
    """
    Create polynomial learning rate decay schedule.
    
    Args:
        step_size: Initial learning rate
        transition_steps: Number of steps over which to decay
        transition_begin: Step at which to begin decay
        power: Power of polynomial decay
        end_value: Final learning rate value
        
    Returns:
        Learning rate schedule function
    """

def warmup_schedule(warmup_steps: int, peak_value: float, 
                   end_value: float = 0.0) -> Callable:
    """
    Create learning rate warmup schedule.
    
    Args:
        warmup_steps: Number of warmup steps
        peak_value: Peak learning rate after warmup
        end_value: Final learning rate value
        
    Returns:
        Learning rate schedule function
        
    Usage:
        # Linear warmup to peak, then decay
        schedule = warmup_schedule(1000, peak_value=0.01)
        optimizer = Adam(step_size=schedule)
    """

Integration with SVI

Examples of how optimizers integrate with Stochastic Variational Inference.

# Usage with SVI
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.infer import SVI, Trace_ELBO
from numpyro.optim import Adam, ClippedAdam, RMSProp, SGD
# multi_transform is the utility documented in Optimizer Utilities above

def example_svi_usage():
    """Example of using optimizers with SVI."""
    
    # Define model and guide
    def model(data):
        mu = numpyro.sample("mu", dist.Normal(0, 1))
        with numpyro.plate("data", len(data)):
            numpyro.sample("obs", dist.Normal(mu, 1), obs=data)
    
    def guide(data):
        mu_loc = numpyro.param("mu_loc", 0.0)
        mu_scale = numpyro.param("mu_scale", 1.0, constraint=constraints.positive)
        numpyro.sample("mu", dist.Normal(mu_loc, mu_scale))
    
    # Various optimizer configurations
    optimizers = {
        # Basic Adam
        'adam': Adam(0.01),
        
        # Adam with gradient clipping
        'clipped_adam': ClippedAdam(0.01, clip_norm=1.0),
        
        # RMSProp for non-stationary problems
        'rmsprop': RMSProp(0.01, decay=0.9),
        
        # SGD with momentum
        'sgd_momentum': SGD(0.01, momentum=0.9),
        
        # Different rates for different parameters
        'multi_rate': multi_transform({
            'loc': Adam(0.01),
            'scale': Adam(0.001)
        }, {
            'mu_loc': 'loc',
            'mu_scale': 'scale'
        })
    }
    
    # Synthetic observations for the example
    data = 0.5 + random.normal(random.PRNGKey(42), (100,))
    
    # Run SVI with chosen optimizer
    optimizer = optimizers['adam']
    svi = SVI(model, guide, optimizer, Trace_ELBO())
    
    # Training loop
    svi_result = svi.run(random.PRNGKey(0), 1000, data)
    
    return svi_result

Usage Examples

import numpyro
import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.infer import SVI, Trace_ELBO
from numpyro.optim import Adam, RMSProp, SGD
import jax.numpy as jnp
from jax import random

# Basic optimizer usage
def simple_optimization_example():
    # Define simple model
    def model(x, y):
        a = numpyro.sample("a", dist.Normal(0, 1))
        b = numpyro.sample("b", dist.Normal(0, 1))
        mu = a * x + b
        numpyro.sample("y", dist.Normal(mu, 0.1), obs=y)
    
    def guide(x, y):
        a_loc = numpyro.param("a_loc", 0.0)
        a_scale = numpyro.param("a_scale", 1.0, constraint=constraints.positive)
        b_loc = numpyro.param("b_loc", 0.0)  
        b_scale = numpyro.param("b_scale", 1.0, constraint=constraints.positive)
        
        numpyro.sample("a", dist.Normal(a_loc, a_scale))
        numpyro.sample("b", dist.Normal(b_loc, b_scale))
    
    # Generate synthetic data
    true_a, true_b = 2.0, 1.0
    x = jnp.linspace(0, 1, 100)
    y = true_a * x + true_b + 0.1 * random.normal(random.PRNGKey(0), (100,))
    
    # Compare different optimizers
    optimizers = {
        'Adam': Adam(0.01),
        'RMSProp': RMSProp(0.01),
        'SGD': SGD(0.01, momentum=0.9)
    }
    
    results = {}
    for name, optimizer in optimizers.items():
        svi = SVI(model, guide, optimizer, Trace_ELBO())
        svi_result = svi.run(random.PRNGKey(1), 1000, x, y)
        results[name] = svi_result
        
        # Print final loss
        print(f"{name} final loss: {svi_result.losses[-1]:.4f}")
    
    return results

# Advanced optimizer configuration
def advanced_optimization_example():
    # Complex model with multiple parameter groups
    def hierarchical_model(group_idx, y):
        # Global parameters
        mu_global = numpyro.sample("mu_global", dist.Normal(0, 10))
        sigma_global = numpyro.sample("sigma_global", dist.Exponential(1))
        
        # Group parameters  
        n_groups = len(jnp.unique(group_idx))
        with numpyro.plate("groups", n_groups):
            mu_group = numpyro.sample("mu_group", dist.Normal(mu_global, sigma_global))
        
        # Observations
        with numpyro.plate("data", len(y)):
            mu = mu_group[group_idx]
            numpyro.sample("y", dist.Normal(mu, 1), obs=y)
    
    def hierarchical_guide(group_idx, y):
        # Global parameter variational families
        mu_global_loc = numpyro.param("mu_global_loc", 0.0)
        mu_global_scale = numpyro.param("mu_global_scale", 1.0, constraint=constraints.positive)
        sigma_global_rate = numpyro.param("sigma_global_rate", 1.0, constraint=constraints.positive)
        
        # Group parameter variational families
        n_groups = len(jnp.unique(group_idx))
        mu_group_loc = numpyro.param("mu_group_loc", jnp.zeros(n_groups))
        mu_group_scale = numpyro.param("mu_group_scale", jnp.ones(n_groups), constraint=constraints.positive)
        
        # Sample from variational distributions
        numpyro.sample("mu_global", dist.Normal(mu_global_loc, mu_global_scale))
        numpyro.sample("sigma_global", dist.Exponential(sigma_global_rate))
        
        with numpyro.plate("groups", n_groups):
            numpyro.sample("mu_group", dist.Normal(mu_group_loc, mu_group_scale))
    
    # Multi-rate optimization: different learning rates for global vs group parameters
    # (multi_transform and exponential_decay are the utilities documented above)
    optimizer = multi_transform({
        'global': Adam(0.01),    # Slower for global parameters
        'group': Adam(0.05)      # Faster for group parameters  
    }, {
        'mu_global_loc': 'global',
        'mu_global_scale': 'global',
        'sigma_global_rate': 'global',
        'mu_group_loc': 'group',
        'mu_group_scale': 'group'
    })
    
    # Learning rate schedule
    schedule = exponential_decay(step_size=0.01, decay_steps=500, decay_rate=0.96)
    scheduled_optimizer = Adam(step_size=schedule)
    
    return optimizer, scheduled_optimizer

Types

from typing import Optional, Union, Callable, Dict, Any, Tuple, Protocol
from jax import Array
import jax.numpy as jnp

ArrayLike = Union[Array, jnp.ndarray, float, int]
Params = Dict[str, ArrayLike]
Grads = Dict[str, ArrayLike]
Updates = Dict[str, ArrayLike]
OptState = Any  # Optimizer-specific state type

class OptimizerState:
    """Base optimizer state interface."""
    step: int
    params: Params

class AdamState(OptimizerState):
    """State for Adam optimizer."""
    step: int
    params: Params
    m: Params  # First moment estimates
    v: Params  # Second moment estimates

class SGDState(OptimizerState):
    """State for SGD optimizer."""
    step: int
    params: Params
    momentum: Optional[Params]  # Momentum terms

class RMSPropState(OptimizerState):
    """State for RMSProp optimizer."""
    step: int
    params: Params
    v: Params  # Squared gradient moving average

# Optimizer interface
class OptimizerProtocol(Protocol):
    """Protocol for NumPyro optimizers."""
    def init(self, params: Params) -> OptState: ...
    def update(self, grads: Grads, state: OptState, params: Params) -> Tuple[Updates, OptState]: ...
    def get_params(self, state: OptState) -> Params: ...

# Schedule functions
ScheduleFunction = Callable[[int], float]

# Optimizer factory functions
OptimizerFactory = Callable[..., OptimizerProtocol]
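
Any callable matching the ScheduleFunction signature can stand in for a fixed step_size. A minimal sketch under that convention; the linear_warmup name and constants are illustrative.

def linear_warmup(step: int) -> float:
    # Ramp the learning rate from 0 to 0.01 over the first 500 steps, then hold.
    return 0.01 * min(step, 500) / 500.0

optimizer = Adam(step_size=linear_warmup)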
