A flexible, scalable deep probabilistic programming library built on PyTorch for universal probabilistic modeling and inference
—
Deep probabilistic models combining neural networks with probabilistic programming, enabling Bayesian neural networks, stochastic layers, and seamless integration between PyTorch modules and Pyro's probabilistic primitives.
Base classes and descriptors for creating probabilistic neural network modules that integrate seamlessly with Pyro's effect system.
# API stub of pyro.nn.PyroModule. Attribute assignment is intercepted so that
# PyroParam / PyroSample values become Pyro-managed parameters and latent samples.
class PyroModule(torch.nn.Module):
"""
Base class for Pyro modules with integrated parameter and sample management.
PyroModule extends torch.nn.Module to support Pyro's parameter store and
sample statements, enabling probabilistic neural networks and automatic
integration with inference algorithms.
Examples:
>>> class BayesianLinear(PyroModule):
... def __init__(self, in_features, out_features):
... super().__init__()
... self.in_features = in_features
... self.out_features = out_features
...
... # Stochastic weights
... self.weight = PyroSample(
... dist.Normal(0, 1).expand([out_features, in_features]).to_event(2)
... )
...
... # Learnable bias
... self.bias = PyroParam(torch.zeros(out_features))
...
... def forward(self, x):
... return torch.nn.functional.linear(x, self.weight, self.bias)
"""
# Intercepts attribute assignment so descriptors are registered instead of
# being stored as plain attributes.
def __setattr__(self, name: str, value):
"""Override to handle PyroParam and PyroSample descriptors."""
# Generator analogous to torch.nn.Module.named_parameters, but for
# Pyro-registered parameters.
def named_pyro_params(self, prefix: str = '', recurse: bool = True):
"""
Iterate over Pyro parameters in the module.
Parameters:
- prefix (str): Prefix to prepend to parameter names
- recurse (bool): Whether to recurse into submodules
Yields:
Tuple[str, torch.Tensor]: (name, parameter) pairs
"""
# API stub of pyro.nn.PyroParam: a descriptor marking an attribute of a
# PyroModule as a learnable (optionally constrained) parameter.
class PyroParam:
"""
Descriptor for Pyro parameters within PyroModule.
PyroParam creates learnable parameters that are automatically registered
with Pyro's parameter store and can be constrained or transformed.
"""
def __init__(self, init_tensor, constraint=dist.constraints.real, event_dim=None):
"""
Parameters:
- init_tensor (Tensor): Initial parameter value
- constraint (Constraint): Parameter constraint (e.g., positive, simplex)
- event_dim (int, optional): Number of rightmost event dimensions
Examples:
>>> # Unconstrained parameter
>>> self.mu = PyroParam(torch.tensor(0.0))
>>>
>>> # Positive parameter
>>> self.sigma = PyroParam(torch.tensor(1.0), constraint=dist.constraints.positive)
>>>
>>> # Simplex parameter (probabilities)
>>> self.probs = PyroParam(torch.ones(5), constraint=dist.constraints.simplex)
"""
# Descriptor protocol: reads resolve through Pyro's parameter store.
def __get__(self, obj, obj_type=None) -> torch.Tensor:
"""Get parameter value from Pyro parameter store."""
# Descriptor protocol: writes update the stored (constrained) value.
def __set__(self, obj, value):
"""Set parameter value in Pyro parameter store."""
# API stub of pyro.nn.PyroSample: a descriptor marking an attribute of a
# PyroModule as a latent random variable drawn from a prior on each access.
class PyroSample:
"""
Descriptor for Pyro samples within PyroModule.
PyroSample creates stochastic variables that are automatically sampled
from specified prior distributions during model execution.
"""
def __init__(self, prior):
"""
Parameters:
- prior (Distribution or callable): Prior distribution or function
returning a distribution
Examples:
>>> # Fixed prior distribution
>>> self.weight = PyroSample(dist.Normal(0, 1))
>>>
>>> # Parameterized prior
>>> self.weight = PyroSample(lambda: dist.Normal(self.weight_loc, self.weight_scale))
>>>
>>> # Matrix-valued parameter
>>> self.W = PyroSample(dist.Normal(0, 1).expand([10, 5]).to_event(2))
"""
# Descriptor protocol: each read triggers a pyro.sample from the prior.
def __get__(self, obj, obj_type=None) -> torch.Tensor:
"""Sample from prior distribution."""
# API stub: decorator that scopes sample-site names inside PyroModule methods.
def pyro_method(fn):
"""
Decorator to create Pyro-aware methods in PyroModule.
Ensures that sample statements within decorated methods use appropriate
name scoping and integration with the module's parameter namespace.
Parameters:
- fn (callable): Method to decorate
Returns:
callable: Decorated method with Pyro integration
Examples:
>>> class MyModule(PyroModule):
... @pyro_method
... def model(self, x):
... z = pyro.sample("z", dist.Normal(0, 1))
... return self.forward(x, z)
"""Specialized neural network architectures for probabilistic modeling and normalizing flows.
# API stub: generic MLP building block used by flows/VAEs.
class DenseNN(PyroModule):
"""
Dense (fully-connected) neural network with configurable architecture.
Commonly used in normalizing flows, variational autoencoders, and as
function approximators in probabilistic models.
"""
def __init__(self, input_dim: int, hidden_dims: List[int], output_dim: int,
nonlinearity: torch.nn.Module = torch.nn.ReLU(),
residual_connections: bool = False, batch_norm: bool = False,
dropout_prob: float = 0.0):
"""
Parameters:
- input_dim (int): Input dimension
- hidden_dims (List[int]): List of hidden layer dimensions
- output_dim (int): Output dimension
- nonlinearity (Module): Activation function between layers
- residual_connections (bool): Whether to add residual connections
- batch_norm (bool): Whether to use batch normalization
- dropout_prob (float): Dropout probability (0 = no dropout)
Examples:
>>> # Simple 3-layer network
>>> net = DenseNN(10, [64, 32], 1)
>>>
>>> # Network with batch norm and dropout
>>> net = DenseNN(20, [128, 64, 32], 5,
... batch_norm=True, dropout_prob=0.1)
"""
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass through the network.
Parameters:
- x (Tensor): Input tensor of shape (..., input_dim)
Returns:
Tensor: Output tensor of shape (..., output_dim)
"""
# API stub: DenseNN variant that also consumes a context tensor
# (used by conditional normalizing flows).
class ConditionalDenseNN(PyroModule):
"""
Conditional dense neural network that takes additional context input.
Useful for conditional normalizing flows and context-dependent function
approximation in probabilistic models.
"""
def __init__(self, input_dim: int, context_dim: int, hidden_dims: List[int],
output_dim: int, nonlinearity: torch.nn.Module = torch.nn.ReLU(),
residual_connections: bool = False):
"""
Parameters:
- input_dim (int): Primary input dimension
- context_dim (int): Context/condition dimension
- hidden_dims (List[int]): Hidden layer dimensions
- output_dim (int): Output dimension
- nonlinearity (Module): Activation function
- residual_connections (bool): Whether to use residual connections
Examples:
>>> # Conditional network
>>> cond_net = ConditionalDenseNN(10, 5, [64, 32], 2)
>>> output = cond_net(x, context)
"""
def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
"""
Forward pass with context input.
Parameters:
- x (Tensor): Primary input of shape (..., input_dim)
- context (Tensor): Context input of shape (..., context_dim)
Returns:
Tensor: Output tensor of shape (..., output_dim)
"""
# API stub: MADE-style masked network; outputs for dimension i depend only
# on inputs with lower ordering.
class AutoRegressiveNN(PyroModule):
"""
Autoregressive neural network with masked connections.
Implements MADE (Masked Autoencoder for Distribution Estimation) for
autoregressive density modeling and normalizing flows.
"""
def __init__(self, input_dim: int, hidden_dims: List[int], output_dim_multiplier: int = 1,
nonlinearity: torch.nn.Module = torch.nn.ReLU(), residual_connections: bool = False,
random_mask: bool = False, activation: torch.nn.Module = None):
"""
Parameters:
- input_dim (int): Input dimension
- hidden_dims (List[int]): Hidden layer dimensions
- output_dim_multiplier (int): Output dimension multiplier (for multiple outputs per input)
- nonlinearity (Module): Hidden layer activation
- residual_connections (bool): Whether to use residual connections
- random_mask (bool): Whether to use random ordering for autoregressive mask
- activation (Module): Final layer activation
Examples:
>>> # Autoregressive network for 10-dimensional data
>>> ar_net = AutoRegressiveNN(10, [64, 64], output_dim_multiplier=2)
>>> # Output has shape (..., 20) for 2 outputs per input dimension
"""
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass preserving autoregressive property.
Parameters:
- x (Tensor): Input tensor of shape (..., input_dim)
Returns:
Tensor: Output respecting autoregressive ordering
"""
# API stub: AutoRegressiveNN variant that additionally conditions on a
# context tensor.
class ConditionalAutoRegressiveNN(AutoRegressiveNN):
"""
Conditional autoregressive neural network with context input.
Combines autoregressive masking with conditional computation for
context-dependent autoregressive models.
"""
def __init__(self, input_dim: int, context_dim: int, hidden_dims: List[int],
output_dim_multiplier: int = 1, nonlinearity: torch.nn.Module = torch.nn.ReLU()):
"""
Parameters:
- input_dim (int): Primary input dimension
- context_dim (int): Context dimension
- hidden_dims (List[int]): Hidden layer dimensions
- output_dim_multiplier (int): Output multiplier per input dimension
- nonlinearity (Module): Activation function
"""
def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
"""Forward pass with context input maintaining autoregressive property."""
# API stub: linear layer whose weight matrix is elementwise-masked;
# the building block of MADE-style autoregressive networks.
class MaskedLinear(torch.nn.Module):
"""
Linear layer with learnable or fixed mask for autoregressive networks.
Used as a building block in autoregressive neural networks where
connections must respect the autoregressive ordering.
"""
def __init__(self, in_features: int, out_features: int, mask: torch.Tensor = None,
bias: bool = True):
"""
Parameters:
- in_features (int): Input feature dimension
- out_features (int): Output feature dimension
- mask (Tensor, optional): Binary mask matrix (1=keep, 0=mask)
- bias (bool): Whether to include bias parameter
Examples:
>>> # Create mask for autoregressive ordering
>>> mask = torch.tril(torch.ones(5, 5)) # Lower triangular
>>> masked_layer = MaskedLinear(5, 5, mask)
"""
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass with masked weight matrix."""Tools for creating and working with Bayesian neural networks where weights and biases are treated as random variables.
# API stub: wraps a deterministic torch.nn.Module so each parameter is
# replaced by a sample from `prior`.
def lift_module(nn_module: torch.nn.Module, prior: callable, guide: callable = None):
"""
Lift a PyTorch module to a Bayesian neural network.
Converts deterministic neural network parameters to random variables
with specified prior distributions.
Parameters:
- nn_module (Module): PyTorch module to convert
- prior (callable): Function that returns prior distributions for parameters
- guide (callable, optional): Function that returns guide distributions
Returns:
PyroModule: Bayesian version of the input module
Examples:
>>> # Define deterministic network
>>> net = torch.nn.Linear(10, 1)
>>>
>>> # Define priors
>>> def prior(name, shape):
... return dist.Normal(0, 1).expand(shape).to_event(len(shape))
>>>
>>> # Create Bayesian network
>>> bnn = lift_module(net, prior)
>>>
>>> # Use in probabilistic model
>>> def model(x, y):
... lifted_nn = pyro.random_module("nn", net, prior)
... prediction = lifted_nn(x)
... pyro.sample("obs", dist.Normal(prediction.squeeze(), 0.1), obs=y)
"""
# API stub: runs the model num_samples times and stacks the outputs so the
# spread across samples reflects parameter uncertainty.
def sample_module_outputs(model: PyroModule, input_data: torch.Tensor,
num_samples: int = 100) -> torch.Tensor:
"""
Sample multiple outputs from a Bayesian neural network.
Parameters:
- model (PyroModule): Bayesian neural network model
- input_data (Tensor): Input data
- num_samples (int): Number of posterior samples to generate
Returns:
Tensor: Sampled outputs with shape (num_samples, batch_size, output_dim)
Examples:
>>> outputs = sample_module_outputs(bnn, test_data, num_samples=50)
>>> mean_prediction = outputs.mean(dim=0)
>>> uncertainty = outputs.std(dim=0)
"""
# API stub: convenience base class for hand-written Bayesian layers.
class BayesianModule(PyroModule):
"""
Base class for implementing custom Bayesian neural network layers.
Provides utilities for parameter sampling and uncertainty quantification
in neural network layers.
"""
def __init__(self, name: str):
"""
Parameters:
- name (str): Module name for parameter scoping
"""
super().__init__()
# Stored name is used as the prefix for Pyro sample/param sites.
self._pyro_name = name
# Draws fresh values for all stochastic parameters.
def sample_parameters(self):
"""Sample parameters from their prior/posterior distributions."""
# Repeats the forward pass under different parameter draws.
def forward_with_samples(self, x: torch.Tensor, num_samples: int = 1) -> torch.Tensor:
"""
Forward pass with multiple parameter samples for uncertainty estimation.
Parameters:
- x (Tensor): Input data
- num_samples (int): Number of parameter samples
Returns:
Tensor: Output samples with uncertainty
"""Specialized layers for variational inference and amortized inference in deep generative models.
# API stub: linear layer with a variational posterior over weights; samples
# pre-activations directly (local reparameterization) for lower variance.
class VariationalLinear(PyroModule):
"""
Variational linear layer with learnable mean and variance parameters.
Implements local reparameterization trick for efficient variational
inference in neural networks.
"""
def __init__(self, in_features: int, out_features: int, bias: bool = True,
prior_scale: float = 1.0):
"""
Parameters:
- in_features (int): Input feature dimension
- out_features (int): Output feature dimension
- bias (bool): Whether to include bias term
- prior_scale (float): Scale of prior distribution on weights
Examples:
>>> var_layer = VariationalLinear(10, 5, prior_scale=0.1)
"""
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass using local reparameterization trick."""
# API stub: LDA where an encoder network amortizes the per-document
# variational parameters.
class AmortizedLDA(PyroModule):
"""
Amortized Latent Dirichlet Allocation using neural networks.
Implements neural variational inference for topic modeling where
the variational parameters are predicted by neural networks.
"""
def __init__(self, vocab_size: int, num_topics: int, hidden_dim: int = 100,
dropout: float = 0.2):
"""
Parameters:
- vocab_size (int): Vocabulary size
- num_topics (int): Number of topics
- hidden_dim (int): Hidden dimension for encoder network
- dropout (float): Dropout probability
"""
def model(self, docs: torch.Tensor, doc_lengths: torch.Tensor):
"""LDA generative model."""
def guide(self, docs: torch.Tensor, doc_lengths: torch.Tensor):
"""Neural variational guide for LDA."""Functions for seamless integration between PyTorch modules and Pyro probabilistic programs.
# API stub: trailing underscore follows the PyTorch convention for
# in-place operations — the module itself is mutated and returned.
def to_pyro_module_(nn_module: torch.nn.Module, prior: callable = None) -> PyroModule:
"""
Convert PyTorch module to PyroModule in-place.
Parameters:
- nn_module (Module): PyTorch module to convert
- prior (callable, optional): Prior distribution generator for parameters
Returns:
PyroModule: Converted module (same object)
Examples:
>>> net = torch.nn.Linear(10, 1)
>>> pyro_net = to_pyro_module_(net)
"""
# API stub: removes Pyro instrumentation from a module.
def clear_module_hooks(module: torch.nn.Module):
"""
Clear all Pyro-related hooks from a PyTorch module.
Parameters:
- module (Module): Module to clear hooks from
"""
# API stub: applies prior_fn to every parameter, turning the module's
# parameters into sample sites named under module_name.
def module_prior(module_name: str, module: torch.nn.Module,
prior_fn: callable) -> torch.nn.Module:
"""
Apply prior distributions to all parameters in a PyTorch module.
Parameters:
- module_name (str): Name prefix for Pyro sample sites
- module (Module): PyTorch module
- prior_fn (callable): Function returning prior distributions
Returns:
Module: Module with stochastic parameters
Examples:
>>> def weight_prior(name, param):
... return dist.Normal(0, 1).expand(param.shape).to_event(param.dim())
>>>
>>> net = torch.nn.Linear(10, 1)
>>> stochastic_net = module_prior("net", net, weight_prior)
"""
# API stub: multiple inheritance combines ModuleList container semantics
# with PyroModule's parameter/sample handling.
class PyroModuleList(torch.nn.ModuleList, PyroModule):
"""
ModuleList that supports PyroModule functionality.
Enables lists of PyroModules to work correctly with Pyro's
parameter management and effect handling.
Examples:
>>> layers = PyroModuleList([
... BayesianLinear(10, 20),
... BayesianLinear(20, 1)
... ])
"""
def __init__(self, modules=None):
"""
Parameters:
- modules (iterable, optional): Iterable of modules to add
"""import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample, PyroParam
import torch.nn.functional as F
class BayesianLinear(PyroModule):
    """Linear layer whose weight and bias are latent variables with standard-normal priors."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Standard-normal priors; to_event marks the trailing dims as event dims
        # so each full tensor is a single multivariate sample site.
        weight_prior = dist.Normal(0., 1.).expand([out_features, in_features]).to_event(2)
        bias_prior = dist.Normal(0., 1.).expand([out_features]).to_event(1)
        # PyroSample attributes are re-sampled on each access within a trace.
        self.weight = PyroSample(weight_prior)
        self.bias = PyroSample(bias_prior)

    def forward(self, x):
        # Reading self.weight / self.bias triggers the corresponding
        # pyro.sample statements.
        return F.linear(x, self.weight, self.bias)
# Usage in a model
# Generative model: Bayesian linear regression with fixed observation noise 0.1.
def model(x, y):
fc = BayesianLinear(3, 1)
# Forward pass
# Calling fc() samples weight/bias from their priors inside the trace.
mean = fc(x).squeeze()
# Likelihood
# plate declares conditional independence across the len(x) data points.
with pyro.plate("data", len(x)):
pyro.sample("obs", dist.Normal(mean, 0.1), obs=y)
def guide(x, y):
# Use a simpler guide or let AutoGuides handle it
pass

class VAE(PyroModule):
def __init__(self, input_dim=784, hidden_dim=400, z_dim=20):
    """Build encoder/decoder networks for a VAE with z_dim latent units.

    Parameters:
    - input_dim (int): Flattened observation dimension (e.g. 784 for MNIST)
    - hidden_dim (int): Hidden layer width for both encoder and decoder
    - z_dim (int): Latent code dimension
    """
    super().__init__()
    # Bug fix: model() reads self.z_dim when constructing the prior, but the
    # original __init__ never stored it, causing an AttributeError at runtime.
    self.z_dim = z_dim
    # Encoder
    self.encoder_fc1 = torch.nn.Linear(input_dim, hidden_dim)
    self.encoder_mu = torch.nn.Linear(hidden_dim, z_dim)
    self.encoder_sigma = torch.nn.Linear(hidden_dim, z_dim)
    # Decoder
    self.decoder_fc1 = torch.nn.Linear(z_dim, hidden_dim)
    self.decoder_fc2 = torch.nn.Linear(hidden_dim, input_dim)
# Generative model p(x, z): standard-normal prior over z, Bernoulli likelihood
# over pixels, decoder parameters registered with Pyro via pyro.module.
def model(self, x):
# Register parameters with Pyro
pyro.module("decoder", self)
batch_size = x.shape[0]
# Prior
with pyro.plate("data", batch_size):
z_loc = torch.zeros(batch_size, self.z_dim)
z_scale = torch.ones(batch_size, self.z_dim)
# to_event(1) treats the z_dim dimension as one event per datum.
z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
# Decode
hidden = F.relu(self.decoder_fc1(z))
# Sigmoid maps decoder output to Bernoulli probabilities in (0, 1).
mu_img = torch.sigmoid(self.decoder_fc2(hidden))
# Likelihood
pyro.sample("obs", dist.Bernoulli(mu_img).to_event(1), obs=x)
# Variational guide q(z|x): encoder predicts the mean and scale of a
# diagonal Gaussian posterior; site name "latent" must match the model.
def guide(self, x):
# Register parameters with Pyro
pyro.module("encoder", self)
batch_size = x.shape[0]
# Encode
hidden = F.relu(self.encoder_fc1(x))
z_mu = self.encoder_mu(hidden)
# softplus keeps the predicted scale strictly positive.
z_sigma = F.softplus(self.encoder_sigma(hidden))
# Variational distribution
with pyro.plate("data", batch_size):
pyro.sample("latent", dist.Normal(z_mu, z_sigma).to_event(1))class UncertaintyNet(PyroModule):
# Wraps a 10->1 linear layer as a PyroModule and learns the observation
# noise scale as a positive-constrained Pyro parameter.
def __init__(self):
super().__init__()
# PyroModule[...] creates a Pyro-aware subclass of torch.nn.Linear.
self.linear = PyroModule[torch.nn.Linear](10, 1)
# Learnable noise parameter
self.sigma = PyroParam(torch.tensor(1.0),
constraint=dist.constraints.positive)
# Samples the linear layer's weights from N(0, 1) priors, then (when y is
# given) conditions a Normal likelihood with learned scale self.sigma.
# NOTE(review): pyro.random_module appears to be the legacy lifting API
# (PyroSample attributes are the PyroModule-era equivalent) — confirm
# against the installed Pyro version.
def forward(self, x, y=None):
# Sample network weights
lifted_module = pyro.random_module("module", self.linear,
lambda name, p: dist.Normal(0, 1)
.expand(p.shape).to_event(p.dim()))
# Forward pass
prediction = lifted_module(x).squeeze()
# Likelihood
# The "obs" site only exists when targets are supplied.
if y is not None:
with pyro.plate("data", len(x)):
pyro.sample("obs", dist.Normal(prediction, self.sigma), obs=y)
return prediction
# Usage with uncertainty quantification
net = UncertaintyNet()
# Training with SVI
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
# The lambda is a trivial (empty) guide; an AutoGuide would normally be used.
svi = SVI(net.forward, lambda x, y: None, Adam({"lr": 0.01}), Trace_ELBO())
# Get predictions with uncertainty
from pyro.infer import Predictive
predictive = Predictive(net.forward, num_samples=100)
# NOTE(review): test_x is assumed to be defined elsewhere in the example.
samples = predictive(test_x)
# NOTE(review): forward() only creates the "obs" site when y is passed;
# with predictive(test_x) alone, "obs" may be absent — consider
# Predictive(..., return_sites=("_RETURN",)) and reading samples["_RETURN"].
mean_pred = samples["obs"].mean(0)
std_pred = samples["obs"].std(0)Install with Tessl CLI
npx tessl i tessl/pypi-pyro-ppl