A flexible, scalable deep probabilistic programming library built on PyTorch for universal probabilistic modeling and inference.

Scalable inference algorithms for posterior approximation and model learning, including variational inference, Markov Chain Monte Carlo, and specialized sampling methods for probabilistic programs.

Gradient-based variational inference for scalable approximate posterior computation.
class SVI:
    """
    Stochastic Variational Inference for scalable posterior approximation.

    SVI optimizes variational parameters to minimize the KL divergence between
    a variational guide and the true posterior distribution.
    """

    def __init__(self, model, guide, optim, loss):
        """
        Initialize SVI with model, guide, optimizer and loss function.

        Parameters:
        - model (callable): Generative model function
        - guide (callable): Variational guide function that approximates the posterior
        - optim (PyroOptim): Pyro optimizer wrapping a PyTorch optimizer
        - loss (ELBO): Evidence Lower Bound loss function

        Examples:
        >>> svi = SVI(model, guide, Adam({"lr": 0.01}), Trace_ELBO())
        """

    def step(self, *args, **kwargs) -> float:
        """
        Perform one SVI optimization step (one gradient update of the
        variational parameters).

        Parameters:
        - *args, **kwargs: Arguments passed through to both the model and the guide

        Returns:
        float: Loss value for this step (negative ELBO estimate)

        Examples:
        >>> loss = svi.step(data)
        >>> print(f"Loss: {loss}")
        """

    def evaluate_loss(self, *args, **kwargs) -> float:
        """
        Evaluate the loss without taking an optimization step.

        Parameters:
        - *args, **kwargs: Arguments passed through to both the model and the guide

        Returns:
        float: Current loss value (no parameters are updated)
        """
def init_to_feasible(site: dict = None) -> torch.Tensor:
    """
    Initialize parameters to feasible values within the support constraints
    of the site's distribution.

    Parameters:
    - site (dict, optional): Sample site information (name, distribution, etc.)

    Returns:
    Tensor: A feasible initialization value for the site
    """
def init_to_mean(site: dict = None) -> torch.Tensor:
    """
    Initialize parameters to the mean of the site's prior distribution.

    Parameters:
    - site (dict, optional): Sample site information (name, distribution, etc.)

    Returns:
    Tensor: Mean initialization value
    """
def init_to_sample(site: dict = None) -> torch.Tensor:
    """
    Initialize parameters to random samples drawn from the site's prior.

    Parameters:
    - site (dict, optional): Sample site information (name, distribution, etc.)

    Returns:
    Tensor: Random sample used as the initialization value
    """


# Loss functions for variational inference based on the evidence lower bound.
class ELBO:
    """
    Base class for Evidence Lower Bound loss functions.

    ELBO provides a lower bound on the model evidence (marginal likelihood)
    and serves as the optimization objective for variational inference.
    """

    def differentiable_loss(self, model, guide, *args, **kwargs) -> torch.Tensor:
        """
        Compute a differentiable ELBO loss suitable for backpropagation.

        Parameters:
        - model (callable): Generative model function
        - guide (callable): Variational guide function
        - *args, **kwargs: Arguments passed through to model and guide

        Returns:
        Tensor: Negative ELBO (the loss to minimize)
        """
class Trace_ELBO(ELBO):
    """
    Standard trace-based ELBO implementation.

    Uses execution traces to compute the ELBO via the log probability of the
    joint model minus the log probability of the guide.
    """

    def __init__(self, num_particles: int = 1, max_plate_nesting: float = float('inf'),
                 max_iarange_nesting: int = None, vectorize_particles: bool = False,
                 strict_enumeration_warning: bool = True):
        """
        Parameters:
        - num_particles (int): Number of Monte Carlo samples for gradient estimation
        - max_plate_nesting (float): Maximum depth of nested plates to vectorize over
          (defaults to unbounded)
        - max_iarange_nesting (int): Legacy alias of max_plate_nesting — presumably
          deprecated, kept for backward compatibility; TODO confirm against the
          installed Pyro version
        - vectorize_particles (bool): Whether to vectorize over particles
        - strict_enumeration_warning (bool): Whether to warn about enumeration issues
        """
class TraceEnum_ELBO(ELBO):
    """
    ELBO with exact enumeration over discrete latent variables.

    Computes exact expectations over discrete variables while using
    Monte Carlo estimation for continuous variables.
    """

    def __init__(self, max_plate_nesting: float = float('inf'), max_iarange_nesting: int = None,
                 strict_enumeration_warning: bool = True, ignore_jit_warnings: bool = False):
        """
        Parameters:
        - max_plate_nesting (float): Maximum plate nesting depth for enumeration
          (defaults to unbounded)
        - max_iarange_nesting (int): Legacy alias of max_plate_nesting — presumably
          deprecated; TODO confirm against the installed Pyro version
        - strict_enumeration_warning (bool): Whether to warn about enumeration issues
        - ignore_jit_warnings (bool): Whether to ignore JIT compilation warnings
        """
class TraceGraph_ELBO(ELBO):
    """
    Memory-efficient ELBO using dependency graphs.

    Reduces memory usage by computing gradients using the dependency
    structure of the computational graph.
    """
    pass
class TraceMeanField_ELBO(ELBO):
    """
    ELBO for mean-field variational inference.

    Assumes independence between latent variables in the guide,
    enabling more efficient computation.
    """
    pass
class RenyiELBO(ELBO):
    """
    Renyi divergence-based ELBO for more robust inference.

    Uses the Renyi alpha-divergence instead of the KL divergence for
    potentially better optimization properties.
    """

    def __init__(self, alpha: float = 0.0, num_particles: int = 2,
                 max_plate_nesting: float = float('inf')):
        """
        Parameters:
        - alpha (float): Order of the Renyi alpha-divergence. The default
          alpha=0 corresponds to the IWAE objective; the standard (KL-based)
          ELBO is recovered in the limit alpha -> 1.
          NOTE(review): the earlier note "alpha=0 gives KL divergence"
          contradicted the Pyro documentation and was corrected.
        - num_particles (int): Number of particles for gradient estimation
        - max_plate_nesting (float): Maximum plate nesting depth
        """


# MCMC methods for exact sampling from posterior distributions.
class MCMC:
    """
    Markov Chain Monte Carlo interface for exact posterior sampling.

    MCMC generates correlated samples from the exact posterior distribution
    using various kernel methods like HMC and NUTS.
    """

    def __init__(self, kernel, num_samples: int, warmup_steps: int = None,
                 initial_params: dict = None, chain_id: int = 0, mp_context=None,
                 disable_progbar: bool = False, disable_validation: bool = True,
                 transforms: dict = None, max_tree_depth: int = None,
                 target_accept_prob: float = 0.8, jit_compile: bool = False):
        """
        Parameters:
        - kernel: MCMC kernel (e.g., HMC, NUTS, RandomWalkKernel)
        - num_samples (int): Number of MCMC samples to generate
        - warmup_steps (int): Number of warmup/burn-in steps; when None,
          presumably defaults to num_samples as in upstream Pyro — TODO confirm
        - initial_params (dict): Initial parameter values
        - chain_id (int): Chain identifier for multiple chains
        - mp_context: Multiprocessing context used when running multiple chains
        - disable_progbar (bool): Whether to suppress the progress bar
        - disable_validation (bool): Whether to skip model validation
        - transforms (dict): Parameter transforms for constrained sampling
        - max_tree_depth (int): Tree-depth limit forwarded to tree-based kernels
        - target_accept_prob (float): Target acceptance probability for adaptive kernels
        - jit_compile (bool): Whether to JIT compile the kernel

        Examples:
        >>> kernel = NUTS(model)
        >>> mcmc = MCMC(kernel, num_samples=1000, warmup_steps=500)
        """

    def run(self, *args, **kwargs):
        """
        Run the MCMC chain (warmup followed by sampling).

        Parameters:
        - *args, **kwargs: Arguments to pass to the model

        Examples:
        >>> mcmc.run(data)
        >>> samples = mcmc.get_samples()
        """

    def get_samples(self, group_by_chain: bool = False) -> dict:
        """
        Get MCMC samples after running the chain.

        Parameters:
        - group_by_chain (bool): Whether to group samples by chain

        Returns:
        dict: Dictionary mapping sample site names to sample tensors

        Examples:
        >>> samples = mcmc.get_samples()
        >>> theta_samples = samples["theta"]
        """
class HMC:
    """
    Hamiltonian Monte Carlo kernel.

    HMC uses gradient information to make efficient proposals in
    continuous parameter spaces.
    """

    def __init__(self, model, step_size: float = 1.0, num_steps: int = 1,
                 adapt_step_size: bool = True, adapt_mass_matrix: bool = True,
                 full_mass: bool = False, transforms: dict = None,
                 max_plate_nesting: int = None, jit_compile: bool = False,
                 jit_options: dict = None, ignore_jit_warnings: bool = False):
        """
        Parameters:
        - model (callable): Model to sample from
        - step_size (float): Leapfrog integration step size
        - num_steps (int): Number of leapfrog steps per iteration
        - adapt_step_size (bool): Whether to adapt step size during warmup
        - adapt_mass_matrix (bool): Whether to adapt the mass matrix
        - full_mass (bool): Whether to use a full mass matrix (vs diagonal)
        - transforms (dict): Parameter transformations
        - max_plate_nesting (int): Maximum plate nesting depth
        - jit_compile (bool): Whether to JIT compile the potential-energy computation
        - jit_options (dict): Options forwarded to the JIT compiler
        - ignore_jit_warnings (bool): Whether to ignore JIT compilation warnings
        """
class NUTS:
    """
    No-U-Turn Sampler, an adaptive version of HMC.

    NUTS automatically determines the number of leapfrog steps to take
    by detecting when the trajectory starts to reverse direction.
    """

    def __init__(self, model, step_size: float = 1.0, adapt_step_size: bool = True,
                 adapt_mass_matrix: bool = True, full_mass: bool = False,
                 transforms: dict = None, max_plate_nesting: int = None,
                 max_tree_depth: int = 10, target_accept_prob: float = 0.8,
                 jit_compile: bool = False, jit_options: dict = None,
                 ignore_jit_warnings: bool = False):
        """
        Parameters:
        - model (callable): Model to sample from
        - step_size (float): Initial step size (adapted during warmup when
          adapt_step_size is True)
        - adapt_step_size (bool): Whether to adapt step size during warmup
        - adapt_mass_matrix (bool): Whether to adapt the mass matrix
        - full_mass (bool): Whether to use a full mass matrix (vs diagonal)
        - transforms (dict): Parameter transformations
        - max_plate_nesting (int): Maximum plate nesting depth
        - max_tree_depth (int): Maximum binary tree depth per trajectory
        - target_accept_prob (float): Target acceptance probability for adaptation
        - jit_compile (bool): Whether to JIT compile the potential-energy computation
        - jit_options (dict): Options forwarded to the JIT compiler
        - ignore_jit_warnings (bool): Whether to ignore JIT compilation warnings

        Examples:
        >>> nuts_kernel = NUTS(model)
        >>> mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=500)
        """
class RandomWalkKernel:
    """
    Random walk Metropolis-Hastings kernel.

    Simple MCMC kernel that proposes new states by adding random noise
    to the current state.
    """

    def __init__(self, model, step_size: dict = None, adapt_step_size: bool = True,
                 transforms: dict = None, max_plate_nesting: int = None):
        """
        Parameters:
        - model (callable): Model to sample from
        - step_size (dict): Per-parameter proposal step sizes
        - adapt_step_size (bool): Whether to adapt step size during warmup
        - transforms (dict): Parameter transformations
        - max_plate_nesting (int): Maximum plate nesting depth
        """


# Generate predictions and samples from trained models.
class Predictive:
    """
    Generate predictive samples from posterior or prior distributions.

    Predictive enables posterior predictive checks, prior predictive checks,
    and out-of-sample predictions by sampling from the model with different
    parameter configurations.
    """

    def __init__(self, model, guide=None, posterior_samples: dict = None,
                 num_samples: int = None, return_sites: list = None,
                 parallel: bool = False, batch_ndims: int = 1):
        """
        Parameters:
        - model (callable): Generative model function
        - guide (callable, optional): Variational guide for posterior sampling;
          when neither guide nor posterior_samples is given, sampling is from
          the prior
        - posterior_samples (dict, optional): Pre-computed posterior samples
        - num_samples (int, optional): Number of samples to generate
        - return_sites (list, optional): Sites to include in the output
        - parallel (bool): Whether to parallelize sampling
        - batch_ndims (int): Number of batch dimensions

        Examples:
        >>> # Posterior predictive with guide
        >>> predictive = Predictive(model, guide=guide, num_samples=1000)
        >>> samples = predictive(data)
        >>>
        >>> # Prior predictive
        >>> predictive = Predictive(model, num_samples=100)
        >>> prior_samples = predictive(data)
        """

    def __call__(self, *args, **kwargs) -> dict:
        """
        Generate predictive samples.

        Parameters:
        - *args, **kwargs: Arguments to pass to the model

        Returns:
        dict: Dictionary mapping site names to sample tensors

        Examples:
        >>> samples = predictive(test_data)
        >>> predictions = samples["obs"]
        """
class WeighedPredictive:
    """
    Generate weighted predictive samples using importance sampling.

    Useful when posterior samples come from importance sampling or
    when samples have non-uniform weights.

    NOTE: the spelling "Weighed" (not "Weighted") matches the actual
    class name in the upstream Pyro API.
    """

    def __init__(self, model, guide=None, posterior_samples: dict = None,
                 weights: torch.Tensor = None, num_samples: int = None,
                 return_sites: list = None, parallel: bool = False):
        """
        Parameters:
        - model (callable): Generative model function
        - guide (callable, optional): Guide function
        - posterior_samples (dict, optional): Pre-computed samples
        - weights (Tensor, optional): Sample weights
        - num_samples (int, optional): Number of samples to generate
        - return_sites (list, optional): Sites to include in the output
        - parallel (bool): Whether to parallelize sampling
        """
class EmpiricalMarginal:
    """
    Empirical marginal distribution from MCMC or SVI samples.

    Converts a collection of samples into a distribution object that
    can be used like any other Pyro distribution.
    """

    def __init__(self, samples: torch.Tensor, log_weights: torch.Tensor = None):
        """
        Parameters:
        - samples (Tensor): Sample values
        - log_weights (Tensor, optional): Log weights for the samples;
          presumably uniform weighting when omitted — TODO confirm

        Examples:
        >>> samples = mcmc.get_samples()["theta"]
        >>> marginal = EmpiricalMarginal(samples)
        >>> new_sample = marginal.sample()
        """


# Importance sampling methods for model comparison and marginal likelihood estimation.
class Importance:
    """
    Importance sampling for marginal likelihood estimation.

    Uses importance sampling to estimate the model evidence (marginal likelihood),
    which is useful for model comparison and selection.
    """

    def __init__(self, model, guide, num_samples: int):
        """
        Parameters:
        - model (callable): Generative model
        - guide (callable): Importance sampling distribution (proposal)
        - num_samples (int): Number of importance samples

        Examples:
        >>> importance = Importance(model, guide, num_samples=10000)
        >>> log_evidence = importance.run(data)
        """

    def run(self, *args, **kwargs) -> torch.Tensor:
        """
        Run importance sampling to estimate the log marginal likelihood.

        Parameters:
        - *args, **kwargs: Arguments to pass to model and guide

        Returns:
        Tensor: Log marginal likelihood estimate
        """
class SMCFilter:
    """
    Sequential Monte Carlo filtering for state space models.

    Implements particle filtering for sequential Bayesian inference
    in time series and state space models.
    """

    def __init__(self, model, guide, num_particles: int, max_plate_nesting: int):
        """
        Parameters:
        - model (callable): State space model
        - guide (callable): Proposal distribution for particles
        - num_particles (int): Number of particles to maintain
        - max_plate_nesting (int): Maximum plate nesting depth
        """


# Advanced inference algorithms for specific model types and scenarios.
class SVGD:
    """
    Stein Variational Gradient Descent for non-parametric inference.

    SVGD optimizes a set of particles to approximate the posterior distribution
    using kernelized Stein discrepancy minimization.
    """

    def __init__(self, model, kernel, optimizer, num_particles: int):
        """
        Parameters:
        - model (callable): Model function
        - kernel: Kernel function for the Stein method
        - optimizer: Optimizer for particle updates
        - num_particles (int): Number of particles to optimize
        """
class ReweightedWakeSleep:
    """
    Reweighted Wake-Sleep algorithm for deep generative models.

    Alternative to standard variational inference that can handle
    more complex posterior approximations.
    """

    def __init__(self, model, guide, wake_loss, sleep_loss):
        """
        Parameters:
        - model (callable): Generative model
        - guide (callable): Recognition model (inference network)
        - wake_loss: Loss function for the wake phase
        - sleep_loss: Loss function for the sleep phase
        """
def config_enumerate(default: str = None, expand: bool = False, num_samples: int = None):
    """
    Configure automatic enumeration over discrete latent variables.

    Decorator that enables exact marginalization over discrete variables
    in models with both discrete and continuous latent variables.

    Parameters:
    - default (str): Default enumeration strategy ("sequential" or "parallel")
    - expand (bool): Whether to expand enumerated dimensions
    - num_samples (int): Number of samples for approximate enumeration

    Examples:
    >>> @config_enumerate
    >>> def model():
    ...     z = pyro.sample("z", dist.Categorical(torch.ones(3)))
    ...     return pyro.sample("x", dist.Normal(z, 1))
    """
def infer_discrete(first_available_dim: int = None, temperature: float = 1.0,
                   cooler: callable = None):
    """
    Infer discrete latent variables by enumeration or sampling.

    Effect handler that automatically handles discrete-variable inference
    by choosing between exact enumeration and approximate sampling.

    Parameters:
    - first_available_dim (int): First tensor dimension available for enumeration
    - temperature (float): Temperature for discrete sampling (1.0 draws posterior
      samples; 0.0 presumably corresponds to MAP as in upstream Pyro — TODO confirm)
    - cooler (callable): Cooling schedule for simulated annealing

    Examples:
    >>> with infer_discrete():
    ...     svi.step(data)
    """


import pyro
# Example: fitting a simple Gaussian model with Stochastic Variational Inference.
import pyro
import pyro.distributions as dist
import torch  # FIX: torch.tensor is used below but torch was never imported
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam


def model(data):
    # Priors over the unknown location and scale of the observations.
    mu = pyro.sample("mu", dist.Normal(0, 10))
    sigma = pyro.sample("sigma", dist.LogNormal(0, 1))
    # Conditionally independent likelihood over the observed data points.
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(mu, sigma), obs=data)


def guide(data):
    # Learnable variational parameters for the approximate posterior over mu.
    mu_q = pyro.param("mu_q", torch.tensor(0.0))
    sigma_q = pyro.param("sigma_q", torch.tensor(1.0), constraint=dist.constraints.positive)
    pyro.sample("mu", dist.Normal(mu_q, sigma_q))
    pyro.sample("sigma", dist.LogNormal(0, 1))  # Use prior as guide


# Training
svi = SVI(model, guide, Adam({"lr": 0.01}), Trace_ELBO())
losses = []
for step in range(1000):
    loss = svi.step(data)  # NOTE(review): `data` must be defined by the surrounding context
    losses.append(loss)

from pyro.infer import MCMC, NUTS
# Example: drawing exact posterior samples with the No-U-Turn Sampler.
def model(data):
    # Priors over the unknown location and scale of the observations.
    mu = pyro.sample("mu", dist.Normal(0, 10))
    sigma = pyro.sample("sigma", dist.LogNormal(0, 1))
    # Conditionally independent likelihood over the observed data points.
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(mu, sigma), obs=data)


# MCMC sampling
nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=500)
mcmc.run(data)

# Get samples
samples = mcmc.get_samples()
mu_samples = samples["mu"]
sigma_samples = samples["sigma"]

from pyro.infer import Predictive
# Example: posterior-predictive sampling with a trained guide.
# After training SVI or MCMC
predictive = Predictive(model, guide=guide, num_samples=1000)
posterior_samples = predictive(data)

# Generate predictions for new data
predictive_new = Predictive(model, guide=guide, num_samples=100)
predictions = predictive_new(new_data)

from pyro.infer import Importance
# Example: model comparison via importance-sampled log evidence.
import torch  # FIX: torch.exp is used below but torch was never imported in this snippet

# Compare two models
importance1 = Importance(model1, guide1, num_samples=10000)
log_evidence1 = importance1.run(data)
importance2 = Importance(model2, guide2, num_samples=10000)
log_evidence2 = importance2.run(data)

# Bayes factor: ratio of the two models' marginal likelihoods.
bayes_factor = torch.exp(log_evidence1 - log_evidence2)
print(f"Bayes factor (Model 1 vs Model 2): {bayes_factor}")

# Install with Tessl CLI:
#   npx tessl i tessl/pypi-pyro-ppl