A flexible, scalable deep probabilistic programming library built on PyTorch for universal probabilistic modeling and inference
Pyro optimization utilities for training probabilistic models, including custom optimizers and PyTorch optimizer wrappers for use with Pyro's parameter store.
Base classes and utilities for wrapping PyTorch optimizers to work with Pyro's parameter management system.
class PyroOptim:
    """
    Base wrapper class for PyTorch optimizers that works with Pyro's parameter store.
    Automatically manages parameter registration and optimization for Pyro models.
    """
    def __init__(self, optim_constructor, optim_args, clip_args=None):
        """
        Parameters:
        - optim_constructor: PyTorch optimizer constructor
        - optim_args (dict): Arguments to pass to optimizer
        - clip_args (dict, optional): Gradient clipping arguments
        """

    def __call__(self, params, *args, **kwargs):
        """Create optimizer instance for given parameters."""
class PyroLRScheduler:
    """
    Wrapper for PyTorch learning rate schedulers that works with PyroOptim.
    """
    def __init__(self, scheduler_constructor, optim_args, **kwargs):
        """
        Parameters:
        - scheduler_constructor: PyTorch scheduler constructor
        - optim_args (dict): Arguments for scheduler
        """

Custom optimizers designed specifically for probabilistic programming use cases.
class ClippedAdam(PyroOptim):
    """
    Adam optimizer with gradient clipping support.
    Particularly useful for training probabilistic models where gradients
    can become unstable.
    """
    def __init__(self, optim_args, clip_args=None):
        """
        Parameters:
        - optim_args (dict): Arguments for Adam optimizer (lr, betas, etc.)
        - clip_args (dict, optional): Gradient clipping configuration
        """
class AdagradRMSProp(PyroOptim):
    """
    Hybrid optimizer combining Adagrad and RMSprop advantages.
    Designed for sparse gradient scenarios common in probabilistic models.
    """
    def __init__(self, optim_args):
        """
        Parameters:
        - optim_args (dict): Optimizer arguments (lr, alpha, eps, etc.)
        """
class DCTAdam(PyroOptim):
    """
    Adam optimizer with Discrete Cosine Transform preconditioning.
    Useful for models with structured parameter spaces.
    """
    def __init__(self, optim_args):
        """
        Parameters:
        - optim_args (dict): Optimizer arguments and DCT configuration
        """

All standard PyTorch optimizers are available as Pyro-wrapped versions for seamless integration with the parameter store.
# Standard PyTorch optimizers wrapped for Pyro
def Adam(optim_args, clip_args=None):
    """Wrapped torch.optim.Adam for Pyro parameter store."""

def SGD(optim_args, clip_args=None):
    """Wrapped torch.optim.SGD for Pyro parameter store."""

def RMSprop(optim_args, clip_args=None):
    """Wrapped torch.optim.RMSprop for Pyro parameter store."""

def Adagrad(optim_args, clip_args=None):
    """Wrapped torch.optim.Adagrad for Pyro parameter store."""

def AdamW(optim_args, clip_args=None):
    """Wrapped torch.optim.AdamW for Pyro parameter store."""

# And many more PyTorch optimizers...
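These wrappers all share the (optim_args, clip_args) signature, so, as a quick sketch, gradient clipping attaches to any of them:

from pyro.optim import AdamW

# AdamW with weight decay plus gradient-norm clipping via the shared clip_args dict
optimizer = AdamW({"lr": 1e-3, "weight_decay": 0.01}, clip_args={"clip_norm": 10.0})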
Support for distributed optimization across multiple processes or machines.

class HorovodOptimizer(PyroOptim):
    """
    Wrapper for Horovod distributed training integration.
    Enables data-parallel training of Pyro models across multiple GPUs/nodes.
    """
    def __init__(self, optim_constructor, optim_args, clip_args=None):
        """
        Parameters:
        - optim_constructor: Base PyTorch optimizer
        - optim_args (dict): Optimizer arguments
        - clip_args (dict, optional): Gradient clipping configuration
        """
A basic end-to-end example with SVI:

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

def model(data):
    mu = pyro.param("mu", torch.tensor(0.0))
    sigma = pyro.param("sigma", torch.tensor(1.0), constraint=dist.constraints.positive)
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(mu, sigma), obs=data)

def guide(data):
    pass  # Empty guide for maximum likelihood

# Toy dataset for illustration
data = torch.randn(100)

# Setup optimizer
adam = Adam({"lr": 0.01})
svi = SVI(model, guide, adam, Trace_ELBO())

# Training loop
for step in range(1000):
    loss = svi.step(data)
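After training, the fitted values live in Pyro's global parameter store and can be read back by name:

# Inspect the learned maximum-likelihood estimates
print("mu:", pyro.param("mu").item())
print("sigma:", pyro.param("sigma").item())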
from pyro.optim import ClippedAdam

# Adam with gradient clipping
clipped_adam = ClippedAdam(
    optim_args={"lr": 0.01, "betas": (0.9, 0.999)},
    clip_args={"clip_norm": 10.0},
)
svi = SVI(model, guide, clipped_adam, Trace_ELBO())
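PyroOptim wrappers can also checkpoint the state of their underlying PyTorch optimizers, which helps when resuming long runs (a minimal sketch using PyroOptim's save/load):

# Persist optimizer state to disk, then restore it later
clipped_adam.save("optimizer_state.pt")
clipped_adam.load("optimizer_state.pt")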
from pyro.optim import PyroLRScheduler
import torch.optim as optim

# Wrap a torch scheduler; the base optimizer class and its arguments
# are passed through the optim_args dict
scheduler = PyroLRScheduler(
    optim.lr_scheduler.StepLR,
    {"optimizer": optim.Adam, "optim_args": {"lr": 0.1}, "step_size": 100, "gamma": 0.1},
)

# Use in SVI
svi = SVI(model, guide, scheduler, Trace_ELBO())

# Training with scheduled learning rate
for step in range(1000):
    loss = svi.step(data)
    scheduler.step()  # advance once per SVI step; StepLR decays the lr every step_size steps
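Pyro also ships pre-wrapped versions of the torch schedulers, so the same configuration can be written more compactly (a sketch assuming the wrapped pyro.optim.StepLR):

import torch
import pyro.optim

# Equivalent setup via the pre-wrapped scheduler shortcut
scheduler = pyro.optim.StepLR({
    "optimizer": torch.optim.Adam,
    "optim_args": {"lr": 0.1},
    "step_size": 100,
    "gamma": 0.1,
})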
from pyro.optim import multi
from pyro.optim import SGD
from pyro import poutine

# Different optimizers for different parameter groups
multi_optim = multi.MixedMultiOptimizer([
    (["mu"], Adam({"lr": 0.01})),
    (["sigma"], SGD({"lr": 0.001})),
])

# MultiOptimizers use a trace-based step interface rather than SVI
tr = poutine.trace(model).get_trace(data)
loss = -tr.log_prob_sum()
params = {name: site["value"].unconstrained()
          for name, site in tr.nodes.items() if site["type"] == "param"}
multi_optim.step(loss, params)

Install with Tessl CLI
npx tessl i tessl/pypi-pyro-ppl