Lightweight probabilistic programming library providing a NumPy backend for Pyro, powered by JAX for automatic differentiation and JIT compilation.
—
NumPyro provides a collection of gradient-based optimizers for parameter learning in variational inference and maximum likelihood estimation. All optimizers are built on JAX for efficient automatic differentiation and support JIT compilation for high-performance optimization.
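As a quick orientation before the reference below, here is a minimal, hedged sketch of the typical workflow: construct an optimizer and hand it to SVI. The model, guide, data, and hyperparameters are illustrative, not prescribed defaults.
# Minimal sketch: any optimizer in this module plugs into SVI the same way.
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI, Trace_ELBO
from numpyro.optim import Adam

def model(y):
    mu = numpyro.sample("mu", dist.Normal(0.0, 10.0))
    with numpyro.plate("data", y.shape[0]):
        numpyro.sample("obs", dist.Normal(mu, 1.0), obs=y)

def guide(y):
    loc = numpyro.param("loc", 0.0)
    numpyro.sample("mu", dist.Normal(loc, 1.0))

y = jnp.array([0.8, 1.1, 0.9, 1.3])  # toy observations
svi = SVI(model, guide, Adam(step_size=0.005), Trace_ELBO())
svi_result = svi.run(random.PRNGKey(0), 500, y)  # svi_result.params, svi_result.losses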
Base classes and utilities for the optimization system.
class Optimizer:
"""
Base class for optimizers in NumPyro.
All optimizers follow the same interface pattern for consistency
with JAX optimization libraries like optax.
"""
def init(self, params: dict) -> Any:
"""
Initialize optimizer state.
Args:
params: Initial parameter values
Returns:
Initial optimizer state
"""
def update(self, grads: dict, state: Any, params: dict) -> tuple:
"""
Update parameters based on gradients.
Args:
grads: Parameter gradients
state: Current optimizer state
params: Current parameter values
Returns:
Tuple of (updates, new_state)
"""
def get_params(self, state: Any) -> dict:
"""Get current parameter values from optimizer state."""Optimizers that adapt learning rates based on gradient history.
Optimizers that adapt learning rates based on gradient history.
class Adam:
"""
Adaptive Moment Estimation (Adam) optimizer.
Computes individual adaptive learning rates for different parameters from
estimates of first and second moments of the gradients.
Args:
step_size: Learning rate (default: 0.001)
b1: Exponential decay rate for first moment estimates (default: 0.9)
b2: Exponential decay rate for second moment estimates (default: 0.999)
eps: Small constant for numerical stability (default: 1e-8)
Usage:
optimizer = Adam(step_size=0.01)
opt_state = optimizer.init(params)
for step in range(num_steps):
grads = compute_gradients(params)
updates, opt_state = optimizer.update(grads, opt_state, params)
params = apply_updates(params, updates)
"""
def __init__(self, step_size: float = 0.001, b1: float = 0.9,
b2: float = 0.999, eps: float = 1e-8): ...
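For reference, a hedged sketch of the per-parameter update rule Adam applies; the helper name and bookkeeping are illustrative.
import jax.numpy as jnp

def adam_step(p, g, m, v, t, step_size=0.001, b1=0.9, b2=0.999, eps=1e-8):
    m = b1 * m + (1 - b1) * g           # biased first moment estimate
    v = b2 * v + (1 - b2) * g ** 2      # biased second moment estimate
    m_hat = m / (1 - b1 ** t)           # bias correction (t counts from 1)
    v_hat = v / (1 - b2 ** t)
    p = p - step_size * m_hat / (jnp.sqrt(v_hat) + eps)
    return p, m, v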
class ClippedAdam:
"""
Adam optimizer with gradient clipping for improved stability.
Args:
step_size: Learning rate
b1: First moment decay rate
b2: Second moment decay rate
eps: Numerical stability constant
clip_norm: Maximum gradient norm for clipping
Usage:
# Useful for training on unstable loss landscapes
optimizer = ClippedAdam(step_size=0.01, clip_norm=1.0)
opt_state = optimizer.init(params)
"""
def __init__(self, step_size: float = 0.001, b1: float = 0.9,
b2: float = 0.999, eps: float = 1e-8, clip_norm: float = 10.0): ...
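A hedged sketch of the clipping step applied before the Adam update: the whole gradient pytree is rescaled so its global norm does not exceed clip_norm (the small epsilon guard is illustrative).
import jax
import jax.numpy as jnp

def clip_by_global_norm(grads, clip_norm):
    leaves = jax.tree_util.tree_leaves(grads)
    global_norm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in leaves))
    scale = jnp.minimum(1.0, clip_norm / (global_norm + 1e-12))
    return jax.tree_util.tree_map(lambda g: g * scale, grads)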
class Adagrad:
"""
Adaptive Gradient Algorithm (Adagrad) optimizer.
Adapts learning rate to parameters, performing smaller updates for parameters
associated with frequently occurring features.
Args:
step_size: Initial learning rate (default: 0.01)
eps: Small constant for numerical stability (default: 1e-8)
Usage:
# Good for sparse data and features
optimizer = Adagrad(step_size=0.1)
opt_state = optimizer.init(params)
"""
def __init__(self, step_size: float = 0.01, eps: float = 1e-8): ...
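A hedged sketch of the Adagrad rule: the per-parameter sum of squared gradients only grows, so frequently updated parameters receive progressively smaller steps (helper name illustrative).
import jax.numpy as jnp

def adagrad_step(p, g, acc, step_size=0.01, eps=1e-8):
    acc = acc + g ** 2                              # accumulated squared gradients
    p = p - step_size * g / (jnp.sqrt(acc) + eps)   # effective rate shrinks over time
    return p, acc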
class RMSProp:
"""
Root Mean Square Propagation (RMSProp) optimizer.
Maintains a moving average of squared gradients to normalize the gradient.
Args:
step_size: Learning rate (default: 0.01)
decay: Decay rate for moving average (default: 0.9)
eps: Small constant for numerical stability (default: 1e-8)
Usage:
# Good for non-stationary objectives
optimizer = RMSProp(step_size=0.01, decay=0.9)
opt_state = optimizer.init(params)
"""
def __init__(self, step_size: float = 0.01, decay: float = 0.9, eps: float = 1e-8): ...
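A hedged sketch of one RMSProp step: an exponential moving average of squared gradients normalizes the update, avoiding Adagrad's unbounded accumulator (helper name illustrative).
import jax.numpy as jnp

def rmsprop_step(p, g, v, step_size=0.01, decay=0.9, eps=1e-8):
    v = decay * v + (1 - decay) * g ** 2        # moving average of squared gradients
    p = p - step_size * g / (jnp.sqrt(v) + eps)
    return p, v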
class RMSPropMomentum:
"""
RMSProp with momentum for improved convergence.
Args:
step_size: Learning rate
decay: Decay rate for squared gradient moving average
momentum: Momentum coefficient
eps: Numerical stability constant
centered: Whether to use centered RMSProp variant
Usage:
# Combines benefits of RMSProp and momentum
optimizer = RMSPropMomentum(step_size=0.01, momentum=0.9)
opt_state = optimizer.init(params)
"""
def __init__(self, step_size: float = 0.01, decay: float = 0.9,
momentum: float = 0.0, eps: float = 1e-8, centered: bool = False): ...
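A hedged sketch of the non-centered variant: the normalized gradient feeds a velocity term rather than updating the parameter directly; the centered variant additionally tracks a moving average of the raw gradient (helper name illustrative).
import jax.numpy as jnp

def rmsprop_momentum_step(p, g, v, mom, step_size=0.01, decay=0.9,
                          momentum=0.9, eps=1e-8):
    v = decay * v + (1 - decay) * g ** 2                         # squared-gradient average
    mom = momentum * mom + step_size * g / (jnp.sqrt(v) + eps)   # velocity
    p = p - mom
    return p, v, mom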
Optimizers that use momentum to accelerate convergence.
class SGD:
"""
Stochastic Gradient Descent optimizer.
Basic gradient descent with optional momentum.
Args:
step_size: Learning rate (default: 0.01)
momentum: Momentum coefficient (default: 0.0)
Usage:
# Simple gradient descent
optimizer = SGD(step_size=0.01)
# With momentum for faster convergence
optimizer = SGD(step_size=0.01, momentum=0.9)
opt_state = optimizer.init(params)
"""
def __init__(self, step_size: float = 0.01, momentum: float = 0.0): ...
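A hedged sketch of one SGD step with optional momentum; momentum=0.0 reduces to plain gradient descent (helper name illustrative).
import jax.numpy as jnp

def sgd_step(p, g, velocity, step_size=0.01, momentum=0.0):
    velocity = momentum * velocity + g   # with momentum=0.0 this is just the gradient
    p = p - step_size * velocity
    return p, velocity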
class Momentum:
"""
Stochastic Gradient Descent with momentum.
Accelerates gradient descent by accumulating a velocity vector in directions
of persistent reduction in the objective function.
Args:
step_size: Learning rate (default: 0.01)
mass: Momentum coefficient (default: 0.9)
Usage:
# Classical momentum SGD
optimizer = Momentum(step_size=0.01, mass=0.9)
opt_state = optimizer.init(params)
"""
def __init__(self, step_size: float = 0.01, mass: float = 0.9): ...
Advanced optimizers for specific use cases.
class SM3:
"""
SM3 (Square-root of Minima of Sums of Maxima of Squared-gradients) optimizer.
Memory-efficient adaptive optimizer that shares accumulators across groups of
parameters (e.g. rows and columns of a weight matrix) instead of keeping
separate per-entry first and second moment estimates.
Args:
step_size: Learning rate (default: 0.01)
eps: Small constant for numerical stability (default: 1e-8)
Usage:
# Memory-efficient alternative to Adam for large models
optimizer = SM3(step_size=0.01)
opt_state = optimizer.init(params)
"""
def __init__(self, step_size: float = 0.01, eps: float = 1e-8): ...
class Minimize:
"""
Wrapper for JAX's minimize function for direct optimization.
Uses JAX's built-in optimization routines (e.g. BFGS via jax.scipy.optimize.minimize)
for direct minimization of objective functions.
Args:
method: Optimization method for the underlying JAX solver (e.g. 'BFGS')
options: Additional options passed to the underlying solver
Usage:
# For objectives where full-batch optimization is preferred over stochastic gradients
optimizer = Minimize(method='BFGS')
# Direct minimization (different interface)
result = optimizer.minimize(loss_fn, init_params)
"""
def __init__(self, method: str = 'BFGS', options: Optional[dict] = None): ...
def minimize(self, fun: Callable, x0: dict, *args, **kwargs) -> dict:
"""
Minimize objective function.
Args:
fun: Objective function to minimize
x0: Initial parameter values
*args: Additional arguments to objective function
**kwargs: Additional keyword arguments
Returns:
Optimization result with final parameters and metadata
"""Utility functions for working with optimizers and optimization schedules.
Utility functions for working with optimizers and optimization schedules.
def multi_transform(transforms: dict, param_labels: dict) -> Optimizer:
"""
Apply different optimizers to different parameter groups.
Args:
transforms: Dictionary mapping labels to optimizers
param_labels: Dictionary mapping parameter names to labels
Returns:
Combined optimizer that applies appropriate transform to each parameter group
Usage:
# Different learning rates for different parameter groups
transforms = {
'weights': Adam(0.01),
'biases': Adam(0.1)
}
param_labels = {
'layer1.weight': 'weights',
'layer1.bias': 'biases'
}
optimizer = multi_transform(transforms, param_labels)
"""
def exponential_decay(step_size: float, decay_steps: int,
decay_rate: float, staircase: bool = False) -> Callable:
"""
Create exponential learning rate decay schedule.
Args:
step_size: Initial learning rate
decay_steps: Number of steps after which to apply decay
decay_rate: Decay factor
staircase: Whether to apply decay in discrete steps
Returns:
Learning rate schedule function
Usage:
schedule = exponential_decay(0.1, decay_steps=1000, decay_rate=0.96)
optimizer = Adam(step_size=schedule)
"""
def polynomial_decay(step_size: float, transition_steps: int,
transition_begin: int = 0, power: float = 1.0,
end_value: float = 0.0) -> Callable:
"""
Create polynomial learning rate decay schedule.
Args:
step_size: Initial learning rate
transition_steps: Number of steps over which to decay
transition_begin: Step at which to begin decay
power: Power of polynomial decay
end_value: Final learning rate value
Returns:
Learning rate schedule function
"""
def warmup_schedule(warmup_steps: int, peak_value: float,
end_value: float = 0.0) -> Callable:
"""
Create learning rate warmup schedule.
Args:
warmup_steps: Number of warmup steps
peak_value: Peak learning rate after warmup
end_value: Final learning rate value
Returns:
Learning rate schedule function
Usage:
# Linear warmup to peak, then decay
schedule = warmup_schedule(1000, peak_value=0.01)
optimizer = Adam(step_size=schedule)
"""Examples of how optimizers integrate with Stochastic Variational Inference.
# Usage with SVI
import numpyro
import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.infer import SVI, Trace_ELBO
from numpyro.optim import Adam, ClippedAdam, RMSProp, SGD
from jax import random
def example_svi_usage():
"""Example of using optimizers with SVI."""
# Define model and guide
def model(data):
mu = numpyro.sample("mu", dist.Normal(0, 1))
with numpyro.plate("data", len(data)):
numpyro.sample("obs", dist.Normal(mu, 1), obs=data)
def guide(data):
mu_loc = numpyro.param("mu_loc", 0.0)
mu_scale = numpyro.param("mu_scale", 1.0, constraint=constraints.positive)
numpyro.sample("mu", dist.Normal(mu_loc, mu_scale))
# Various optimizer configurations
optimizers = {
# Basic Adam
'adam': Adam(0.01),
# Adam with gradient clipping
'clipped_adam': ClippedAdam(0.01, clip_norm=1.0),
# RMSProp for non-stationary problems
'rmsprop': RMSProp(0.01, decay=0.9),
# SGD with momentum
'sgd_momentum': SGD(0.01, momentum=0.9),
# Different rates for different parameters
'multi_rate': multi_transform({
'loc': Adam(0.01),
'scale': Adam(0.001)
}, {
'mu_loc': 'loc',
'mu_scale': 'scale'
})
}
# Run SVI with chosen optimizer
optimizer = optimizers['adam']
svi = SVI(model, guide, optimizer, Trace_ELBO())
# Training loop on synthetic data (generated here so the example is self-contained)
data = 2.0 + random.normal(random.PRNGKey(1), (100,))
svi_result = svi.run(random.PRNGKey(0), 1000, data)
return svi_result
import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI, Trace_ELBO
from numpyro.optim import Adam, RMSProp, SGD
from numpyro.distributions import constraints
import jax.numpy as jnp
from jax import random
# Basic optimizer usage
def simple_optimization_example():
# Define simple model
def model(x, y):
a = numpyro.sample("a", dist.Normal(0, 1))
b = numpyro.sample("b", dist.Normal(0, 1))
mu = a * x + b
numpyro.sample("y", dist.Normal(mu, 0.1), obs=y)
def guide(x, y):
a_loc = numpyro.param("a_loc", 0.0)
a_scale = numpyro.param("a_scale", 1.0, constraint=constraints.positive)
b_loc = numpyro.param("b_loc", 0.0)
b_scale = numpyro.param("b_scale", 1.0, constraint=constraints.positive)
numpyro.sample("a", dist.Normal(a_loc, a_scale))
numpyro.sample("b", dist.Normal(b_loc, b_scale))
# Generate synthetic data
true_a, true_b = 2.0, 1.0
x = jnp.linspace(0, 1, 100)
y = true_a * x + true_b + 0.1 * random.normal(random.PRNGKey(0), (100,))
# Compare different optimizers
optimizers = {
'Adam': Adam(0.01),
'RMSProp': RMSProp(0.01),
'SGD': SGD(0.01, momentum=0.9)
}
results = {}
for name, optimizer in optimizers.items():
svi = SVI(model, guide, optimizer, Trace_ELBO())
svi_result = svi.run(random.PRNGKey(1), 1000, x, y)
results[name] = svi_result
# Print final loss
print(f"{name} final loss: {svi_result.losses[-1]:.4f}")
return results
# Advanced optimizer configuration
def advanced_optimization_example():
# Complex model with multiple parameter groups
def hierarchical_model(group_idx, y):
# Global parameters
mu_global = numpyro.sample("mu_global", dist.Normal(0, 10))
sigma_global = numpyro.sample("sigma_global", dist.Exponential(1))
# Group parameters
n_groups = len(jnp.unique(group_idx))
with numpyro.plate("groups", n_groups):
mu_group = numpyro.sample("mu_group", dist.Normal(mu_global, sigma_global))
# Observations
with numpyro.plate("data", len(y)):
mu = mu_group[group_idx]
numpyro.sample("y", dist.Normal(mu, 1), obs=y)
def hierarchical_guide(group_idx, y):
# Global parameter variational families
mu_global_loc = numpyro.param("mu_global_loc", 0.0)
mu_global_scale = numpyro.param("mu_global_scale", 1.0, constraint=constraints.positive)
sigma_global_rate = numpyro.param("sigma_global_rate", 1.0, constraint=constraints.positive)
# Group parameter variational families
n_groups = len(jnp.unique(group_idx))
mu_group_loc = numpyro.param("mu_group_loc", jnp.zeros(n_groups))
mu_group_scale = numpyro.param("mu_group_scale", jnp.ones(n_groups), constraint=constraints.positive)
# Sample from variational distributions
numpyro.sample("mu_global", dist.Normal(mu_global_loc, mu_global_scale))
numpyro.sample("sigma_global", dist.Exponential(sigma_global_rate))
with numpyro.plate("groups", n_groups):
numpyro.sample("mu_group", dist.Normal(mu_group_loc, mu_group_scale))
# Multi-rate optimization: different learning rates for global vs group parameters
optimizer = multi_transform({
'global': Adam(0.01), # Slower for global parameters
'group': Adam(0.05) # Faster for group parameters
}, {
'mu_global_loc': 'global',
'mu_global_scale': 'global',
'sigma_global_rate': 'global',
'mu_group_loc': 'group',
'mu_group_scale': 'group'
})
# Learning rate schedule
schedule = exponential_decay(step_size=0.01, decay_steps=500, decay_rate=0.96)
scheduled_optimizer = Adam(step_size=schedule)
return optimizer, scheduled_optimizer
from typing import Optional, Union, Callable, Dict, Any, Tuple, Protocol
from jax import Array
import jax.numpy as jnp
ArrayLike = Union[Array, jnp.ndarray, float, int]
Params = Dict[str, ArrayLike]
Grads = Dict[str, ArrayLike]
Updates = Dict[str, ArrayLike]
OptState = Any # Optimizer-specific state type
class OptimizerState:
"""Base optimizer state interface."""
step: int
params: Params
class AdamState(OptimizerState):
"""State for Adam optimizer."""
step: int
params: Params
m: Params # First moment estimates
v: Params # Second moment estimates
class SGDState(OptimizerState):
"""State for SGD optimizer."""
step: int
params: Params
momentum: Optional[Params] # Momentum terms
class RMSPropState(OptimizerState):
"""State for RMSProp optimizer."""
step: int
params: Params
v: Params # Squared gradient moving average
# Optimizer interface
class OptimizerProtocol(Protocol):
"""Protocol for NumPyro optimizers."""
def init(self, params: Params) -> OptState: ...
def update(self, grads: Grads, state: OptState, params: Params) -> Tuple[Updates, OptState]: ...
def get_params(self, state: OptState) -> Params: ...
# Schedule functions
ScheduleFunction = Callable[[int], float]
# Optimizer factory functions
OptimizerFactory = Callable[..., OptimizerProtocol]
Install with Tessl CLI
npx tessl i tessl/pypi-numpyro