A flexible, scalable deep probabilistic programming library built on PyTorch for universal probabilistic modeling and inference
```
npx @tessl/cli install tessl/pypi-pyro-ppl@1.9.0
```

Pyro is a flexible, scalable deep probabilistic programming library built on PyTorch that enables universal probabilistic modeling and inference. It combines the expressiveness of probabilistic programming with the power of deep learning, providing a comprehensive toolkit for Bayesian modeling, variational inference, and uncertainty quantification.

Install from PyPI:

```
pip install pyro-ppl
```

```python
import pyro
```

Common imports for probabilistic programming:

```python
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
import torch
```

A complete end-to-end example, inferring the bias of a coin with stochastic variational inference:

```python
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
import torch

# Define a simple Bayesian model
def model(data):
    # Prior on the parameter
    theta = pyro.sample("theta", dist.Beta(1.0, 1.0))
    # Likelihood
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Bernoulli(theta), obs=data)

# Define a variational guide (posterior approximation)
def guide(data):
    # Variational parameters
    alpha_q = pyro.param("alpha_q", torch.tensor(1.0), constraint=dist.constraints.positive)
    beta_q = pyro.param("beta_q", torch.tensor(1.0), constraint=dist.constraints.positive)
    # Variational distribution
    pyro.sample("theta", dist.Beta(alpha_q, beta_q))
# Generate some data
data = torch.tensor([1.0, 0.0, 1.0, 1.0, 0.0])
# Set up stochastic variational inference
svi = SVI(model, guide, Adam({"lr": 0.01}), loss=Trace_ELBO())
# Training loop
for step in range(1000):
    loss = svi.step(data)
    if step % 100 == 0:
        print(f"Step {step}, Loss: {loss}")
# Get posterior samples
from pyro.infer import Predictive
predictive = Predictive(model, guide=guide, num_samples=1000)
samples = predictive(data)
print(f"Posterior mean of theta: {samples['theta'].mean():.3f}")Pyro's architecture is built on several key design principles:
pyro.sample and pyro.parampyro.nnPrimary functions and constructs for building probabilistic programs, including sampling, parameter management, and independence declarations.
def sample(name: str, fn: dist.Distribution, *args, obs=None, obs_mask=None, infer=None, **kwargs):
    """
    Primitive stochastic function for probabilistic programming.

    Parameters:
    - name (str): Unique name for the sample site
    - fn (Distribution): Probability distribution to sample from
    - obs (Tensor, optional): Observed data to condition on
    - obs_mask (Tensor, optional): Mask for observed data
    - infer (dict, optional): Inference configuration

    Returns:
    Tensor: Sample from the distribution
    """

def param(name: str, init_tensor=None, constraint=None, event_dim=None):
    """
    Declare and retrieve learnable parameters.

    Parameters:
    - name (str): Parameter name
    - init_tensor (Tensor, optional): Initial value
    - constraint (Constraint, optional): Parameter constraint
    - event_dim (int, optional): Event dimension

    Returns:
    Tensor: Parameter tensor
    """

def plate(name: str, size: int, subsample_size=None, dim=None):
    """
    Independence context manager for vectorized computation.

    Parameters:
    - name (str): Plate name
    - size (int): Plate size
    - subsample_size (int, optional): Subsample size for minibatching
    - dim (int, optional): Tensor dimension

    Returns:
    PlateMessenger: Context manager
    """

def factor(name: str, log_factor):
    """
    Add an arbitrary factor to the joint log probability.

    Parameters:
    - name (str): Unique name for the factor site
    - log_factor (Tensor): Log probability factor
    """
```
Comprehensive collection of probability distributions, including continuous, discrete, multivariate, and specialized distributions for probabilistic modeling.

```python
# Core continuous distributions
class Normal(dist.Distribution):
    def __init__(self, loc: torch.Tensor, scale: torch.Tensor): ...

class Beta(dist.Distribution):
    def __init__(self, concentration1: torch.Tensor, concentration0: torch.Tensor): ...

class Gamma(dist.Distribution):
    def __init__(self, concentration: torch.Tensor, rate: torch.Tensor): ...

# Discrete distributions
class Bernoulli(dist.Distribution):
    def __init__(self, probs: torch.Tensor = None, logits: torch.Tensor = None): ...

class Categorical(dist.Distribution):
    def __init__(self, probs: torch.Tensor = None, logits: torch.Tensor = None): ...

# Multivariate distributions
class MultivariateNormal(dist.Distribution):
    def __init__(self, loc: torch.Tensor, covariance_matrix: torch.Tensor = None,
                 precision_matrix: torch.Tensor = None, scale_tril: torch.Tensor = None): ...
```
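One behavior worth illustrating alongside the class list: Pyro distributions distinguish batch shape (independent copies) from event shape (dimensions of a single draw). A small sketch using `to_event`, which reinterprets batch dimensions as event dimensions:

```python
import torch
import pyro.distributions as dist

# A batch of 3 independent Normals: batch_shape == (3,), event_shape == ()
d = dist.Normal(torch.zeros(3), torch.ones(3))

# Reinterpret the batch dimension as one 3-dimensional event,
# so log_prob sums over it: batch_shape == (), event_shape == (3,)
d_joint = d.to_event(1)
x = d_joint.sample()
assert x.shape == (3,)
assert d_joint.log_prob(x).shape == ()

# Multivariate distributions have an intrinsic event shape
mvn = dist.MultivariateNormal(torch.zeros(2), covariance_matrix=torch.eye(2))
assert mvn.event_shape == (2,)
```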
Scalable inference algorithms including stochastic variational inference (SVI), Markov chain Monte Carlo (MCMC), and importance sampling for posterior approximation.

```python
class SVI:
    """Stochastic Variational Inference."""

    def __init__(self, model, guide, optim, loss):
        """
        Parameters:
        - model: Generative model function
        - guide: Variational guide function
        - optim: Optimizer instance
        - loss: Loss function (typically an ELBO)
        """

    def step(self, *args, **kwargs) -> float:
        """Perform one SVI step and return the loss."""

class MCMC:
    """Markov chain Monte Carlo."""

    def __init__(self, kernel, num_samples: int, warmup_steps: int = None): ...

    def run(self, *args, **kwargs): ...

class Predictive:
    """Generate predictions from posterior samples."""

    def __init__(self, model, guide=None, posterior_samples=None, num_samples=None): ...
```
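A minimal sketch of `MCMC` in use, assuming the `NUTS` kernel from `pyro.infer` and the coin-flip model from the quickstart:

```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS

def model(data):
    theta = pyro.sample("theta", dist.Beta(1.0, 1.0))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Bernoulli(theta), obs=data)

data = torch.tensor([1.0, 0.0, 1.0, 1.0, 0.0])

# NUTS adaptively tunes step size and trajectory length during warmup
kernel = NUTS(model)
mcmc = MCMC(kernel, num_samples=500, warmup_steps=200)
mcmc.run(data)

posterior_theta = mcmc.get_samples()["theta"]
print(f"Posterior mean of theta: {posterior_theta.mean():.3f}")
```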
Deep probabilistic models combining neural networks with probabilistic programming, including Bayesian neural networks and stochastic layers.

```python
class PyroModule(torch.nn.Module):
    """Base class for Pyro modules with parameter/sample integration."""

class PyroParam:
    """Descriptor for Pyro parameters in modules."""

    def __init__(self, init_tensor, constraint=None, event_dim=None): ...

class PyroSample:
    """Descriptor for Pyro samples in modules."""

    def __init__(self, prior): ...

class DenseNN(PyroModule):
    """Dense neural network for use in flows and guides."""

    def __init__(self, input_dim: int, hidden_dims: List[int], output_dim: int): ...
```
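How the descriptors fit together in practice: a Bayesian linear regressor in the style of Pyro's regression tutorial. Assigning a `PyroSample` to a module attribute turns that attribute into a random variable drawn fresh on each forward pass. The class and site names here are illustrative:

```python
import torch
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample

class BayesianLinear(PyroModule):
    """Linear regressor whose weights and bias carry Normal priors."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = PyroModule[torch.nn.Linear](in_features, out_features)
        self.linear.weight = PyroSample(
            dist.Normal(0.0, 1.0).expand([out_features, in_features]).to_event(2))
        self.linear.bias = PyroSample(
            dist.Normal(0.0, 1.0).expand([out_features]).to_event(1))

    def forward(self, x, y=None):
        mean = self.linear(x).squeeze(-1)
        sigma = pyro.sample("sigma", dist.Uniform(0.0, 1.0))
        with pyro.plate("data", x.shape[0]):
            return pyro.sample("obs", dist.Normal(mean, sigma), obs=y)
```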
Bijective transformations and parameter constraints for reparametrization and constrained optimization in probabilistic models.

```python
# Core transforms
class Transform:
    """Base class for bijective transforms."""

    def __call__(self, x): ...

    def inv(self, y): ...

    def log_abs_det_jacobian(self, x, y): ...

class AffineTransform(Transform):
    def __init__(self, loc: torch.Tensor, scale: torch.Tensor): ...

# Flow-based transforms
class AffineAutoregressive(Transform):
    def __init__(self, autoregressive_nn, log_scale_min_clip: float = -5.0): ...

# Constraints
class Constraint:
    """Base class for parameter constraints."""

    def check(self, value): ...

positive: Constraint
unit_interval: Constraint
simplex: Constraint
```
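Transforms and constraints meet in `TransformedDistribution`: a short sketch building a log-normal by pushing a standard normal through `exp`, then checking its support against the `positive` constraint:

```python
import torch
import pyro.distributions as dist
import pyro.distributions.transforms as T
from pyro.distributions import constraints

base = dist.Normal(torch.tensor(0.0), torch.tensor(1.0))
log_normal = dist.TransformedDistribution(base, [T.ExpTransform()])

x = log_normal.sample()
assert constraints.positive.check(x)

# log_prob applies the inverse transform plus the log|det J| correction
lp = log_normal.log_prob(x)
```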
Gaussian process models for non-parametric Bayesian modeling, including kernels, likelihoods, and efficient GP inference.

```python
class Kernel:
    """Base class for GP kernels."""

    def forward(self, X, Z=None, diag: bool = False): ...

class RBF(Kernel):
    """Radial Basis Function kernel."""

    def __init__(self, input_dim: int, lengthscale=None, variance=None): ...

class GPModel:
    """Base Gaussian process model."""

    def __init__(self, X, y, kernel, likelihood): ...
```
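The concrete GP classes live in `pyro.contrib.gp`; a short regression sketch with the RBF kernel above, on synthetic data:

```python
import torch
import pyro.contrib.gp as gp

# Toy 1-D regression data (illustrative)
X = torch.linspace(0.0, 5.0, 20)
y = torch.sin(X) + 0.1 * torch.randn(20)

kernel = gp.kernels.RBF(input_dim=1,
                        lengthscale=torch.tensor(1.0),
                        variance=torch.tensor(1.0))
gpr = gp.models.GPRegression(X, y, kernel, noise=torch.tensor(0.1))

# Optimize kernel hyperparameters, then predict at new inputs
gp.util.train(gpr, num_steps=500)
X_new = torch.linspace(0.0, 5.0, 50)
mean, cov = gpr(X_new, full_cov=True)
```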
Optimization utilities and PyTorch optimizer wrappers for training probabilistic models with Pyro's parameter store system.

```python
class PyroOptim:
    """Base wrapper for PyTorch optimizers."""

    def __init__(self, optim_constructor, optim_args, clip_args=None): ...

class ClippedAdam(PyroOptim):
    """Adam optimizer with gradient clipping."""

    def __init__(self, optim_args, clip_args=None): ...

# PyTorch optimizer wrappers
def Adam(optim_args, clip_args=None): ...
def SGD(optim_args, clip_args=None): ...
def RMSprop(optim_args, clip_args=None): ...
```
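Wiring an optimizer into SVI, reusing `model` and `guide` from the quickstart. `ClippedAdam` reads `clip_norm` from its argument dict, and `optim_args` may instead be a callable returning per-parameter settings:

```python
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam, ClippedAdam

# Gradient norms are clipped to clip_norm before each update
optim = ClippedAdam({"lr": 0.005, "clip_norm": 10.0})
svi = SVI(model, guide, optim, loss=Trace_ELBO())

# Alternatively, assign settings per parameter by name
def per_param_args(param_name):
    return {"lr": 0.1 if param_name == "alpha_q" else 0.01}

svi = SVI(model, guide, Adam(per_param_args), loss=Trace_ELBO())
```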
Shared type aliases used across the API:

```python
from typing import Union, Optional, Dict, Any, Callable, Iterator, Sequence

import torch
import pyro.infer
import pyro.optim
from torch import Tensor
from torch.distributions import Transform
from torch.distributions.constraints import Constraint
from pyro.distributions import Distribution
from pyro.distributions.torch_distribution import TorchDistributionMixin
from pyro.params.param_store import ParamStoreDict
from pyro.poutine.runtime import InferDict
from pyro.poutine.plate_messenger import PlateMessenger

# Core types
ModelFunction = Callable[..., None]
GuideFunction = Callable[..., None]
InitFunction = Callable[[str], Tensor]

# Distribution types
DistributionType = Union[Distribution, type]
ConstraintType = Union[Constraint, type]
TransformType = Union[Transform, type]

# Inference types
OptimType = Union[torch.optim.Optimizer, pyro.optim.PyroOptim]
LossType = Union[pyro.infer.ELBO, Callable]
```
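Continuing from the aliases above, these types can annotate user code; a hypothetical convenience wrapper:

```python
def fit(model: ModelFunction, guide: GuideFunction,
        optim: OptimType, loss: LossType, *args, steps: int = 1000) -> float:
    """Run SVI for a fixed number of steps and return the final loss."""
    svi = pyro.infer.SVI(model, guide, optim, loss)
    final = 0.0
    for _ in range(steps):
        final = svi.step(*args)
    return final
```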