Pyro: a flexible, scalable deep probabilistic programming library built on PyTorch for universal probabilistic modeling and inference.
—
Core functions and constructs that form the foundation of Pyro's probabilistic programming language, enabling the creation of probabilistic models through composable primitives.
The fundamental stochastic function for declaring random variables and observed data in probabilistic programs.
def sample(
    name: str,
    fn: TorchDistributionMixin,
    *args,
    obs: Optional[torch.Tensor] = None,
    obs_mask: Optional[torch.BoolTensor] = None,
    infer: Optional[InferDict] = None,
    **kwargs,
) -> torch.Tensor:
    """
    Primitive stochastic function for probabilistic programming.

    This is the core function for creating sample sites in probabilistic
    programs. It can be used to declare latent variables, observed data, and
    guide samples.

    Parameters:
    - name (str): Unique name for the sample site within the current context
    - fn (Distribution): Probability distribution to sample from
    - *args, **kwargs: Extra arguments recorded with the site and passed
      through to ``fn`` when the site is sampled
    - obs (Tensor, optional): Observed data to condition on. When provided,
      this becomes a conditioning site rather than a sampling site
    - obs_mask (Tensor, optional): Boolean mask broadcastable with
      ``fn.batch_shape``, useful for missing-data scenarios. Entries where the
      mask is True are conditioned on ``obs``; the remaining entries are
      imputed by sampling (upstream Pyro records them at a latent site named
      ``name + "_unobserved"``, which guides should cover)
    - infer (dict, optional): Inference configuration dictionary containing
      instructions for inference algorithms

    Returns:
    Tensor: Sample from the distribution (or observed value if obs is provided)

    Examples:
    >>> # Latent variable
    >>> z = pyro.sample("z", dist.Normal(0, 1))
    >>>
    >>> # Observed data
    >>> pyro.sample("obs", dist.Normal(mu, sigma), obs=data)
    >>>
    >>> # With inference configuration
    >>> pyro.sample("x", dist.Normal(0, 1), infer={"is_auxiliary": True})
    """


# Functions for declaring and managing learnable parameters that persist across calls to the model.
def param(
    name: str,
    init_tensor: Union[torch.Tensor, Callable[[], torch.Tensor], None] = None,
    constraint: constraints.Constraint = constraints.real,
    event_dim: Optional[int] = None,
) -> torch.Tensor:
    """
    Declare and retrieve learnable parameters from the global parameter store.

    Parameters persist across model calls and are automatically tracked for
    gradient-based optimization.

    Parameters:
    - name (str): Parameter name. It is the key in the parameter store, so
      repeated calls with the same name refer to the same parameter
    - init_tensor (Tensor or callable, optional): Initial parameter value, or
      a zero-argument callable returning it (a lazy initializer, e.g.
      ``lambda: torch.randn(100000)``, avoids building a large tensor on
      every call). If None, the parameter must already exist in the store
    - constraint (Constraint): Constraint on parameter values, defaults to
      unconstrained real numbers
    - event_dim (int, optional): Number of rightmost dimensions that are
      part of the event shape

    Returns:
    Tensor: Parameter tensor with gradient tracking enabled

    Examples:
    >>> # Scalar parameter
    >>> mu = pyro.param("mu", torch.tensor(0.0))
    >>>
    >>> # Vector parameter with constraint
    >>> theta = pyro.param("theta", torch.ones(5), constraint=constraints.positive)
    >>>
    >>> # Matrix parameter, lazily initialized
    >>> W = pyro.param("W", lambda: torch.randn(10, 5))
    """
def clear_param_store():
    """Remove every parameter held in the global parameter store.

    Handy for resetting state between separate model runs or experiments.
    """
def get_param_store():
    """
    Get the global parameter store instance.

    Returns:
    ParamStoreDict: The global parameter store containing all named parameters
    """


# Context managers for declaring conditional independence and enabling efficient vectorized computation.
class plate(PlateMessenger):
    def __init__(
        self,
        name: str,
        size: Optional[int] = None,
        subsample_size: Optional[int] = None,
        subsample: Optional[torch.Tensor] = None,
        dim: Optional[int] = None,
        use_cuda: Optional[bool] = None,
        device: Optional[str] = None,
    ) -> None:
        """
        Context manager for declaring conditional independence assumptions.

        Plates enable vectorized computation and minibatch training by
        declaring that samples within the plate are conditionally independent.
        A plate instance modifies the behavior of sample sites executed inside
        it; entering it yields the (possibly subsampled) 1-D tensor of indices.

        Parameters:
        - name (str): Unique name for the plate
        - size (int, optional): Total size of the independent dimension
        - subsample_size (int, optional): Size of the minibatch subsample. If
          provided, enables minibatch training with automatic scaling of log
          probabilities
        - subsample (Tensor, optional): Explicit subsample indices for
          user-defined subsampling schemes; when given, the subsample size is
          its length
        - dim (int, optional): Tensor dimension to use for broadcasting,
          counted from the right (negative). If None, uses the rightmost
          available dimension
        - use_cuda (bool, optional): Deprecated; use ``device`` instead
        - device (str, optional): Device on which subsample indices and
          log-probabilities are placed

        Examples:
        >>> # Basic independence
        >>> with pyro.plate("data", 100):
        ...     pyro.sample("obs", dist.Normal(mu, sigma), obs=data)
        >>>
        >>> # Minibatch training: the plate yields the subsample indices
        >>> with pyro.plate("data", 10000, subsample_size=32) as ind:
        ...     pyro.sample("obs", dist.Normal(mu, sigma), obs=data[ind])
        >>>
        >>> # Nested plates
        >>> with pyro.plate("batch", N):
        ...     with pyro.plate("features", D):
        ...         pyro.sample("z", dist.Normal(0, 1))
        """
def plate_stack(prefix: str, sizes: Sequence[int], rightmost_dim: int = -1) -> Iterator[None]:
    """
    Create a stack of nested plates for multi-dimensional independence.

    Parameters:
    - prefix (str): Name prefix shared by the generated plates
    - sizes (Sequence[int]): Sizes for each nested plate
    - rightmost_dim (int): Rightmost tensor dimension to use

    Returns:
    ContextManager: Nested plate context

    Examples:
    >>> with pyro.plate_stack("plates", [N, D, K]):
    ...     pyro.sample("z", dist.Normal(0, 1))
    """


# Functions for composing and manipulating probabilistic programs.
def factor(name: str, log_factor: torch.Tensor, *, has_rsample: Optional[bool] = None) -> None:
    """
    Add an arbitrary log-probability term to the model.

    Use this for custom log-density contributions that do not correspond to
    any standard distribution.

    Parameters:
    - name (str): Name of the factor site
    - log_factor (torch.Tensor): Term added to the model's joint log
      probability
    - has_rsample (bool, optional): Whether the factor came from a fully
      reparametrized distribution (required in guides)

    Examples:
    >>> # Hand-written likelihood term
    >>> log_likelihood = -0.5 * torch.sum((data - mu) ** 2) / sigma ** 2
    >>> pyro.factor("custom_likelihood", log_likelihood)
    >>>
    >>> # L2 penalty on parameters
    >>> penalty = -0.01 * torch.sum(params ** 2)
    >>> pyro.factor("l2_penalty", penalty)
    """
def deterministic(
    name: str, value: torch.Tensor, event_dim: Optional[int] = None
) -> torch.Tensor:
    """
    Create a deterministic sample site for tracking intermediate computations.

    Parameters:
    - name (str): Name for the deterministic site
    - value (Tensor): Deterministic value to record
    - event_dim (int, optional): Number of rightmost event dimensions of
      ``value``. If None, upstream Pyro treats every dimension of ``value``
      as an event dimension (``value.ndim``)

    Returns:
    Tensor: The input value (pass-through)

    Examples:
    >>> z = pyro.sample("z", dist.Normal(0, 1))
    >>> z_squared = pyro.deterministic("z_squared", z ** 2)
    """
def barrier(data: torch.Tensor) -> torch.Tensor:
    """
    Ensure all values in ``data`` are ground (concrete) rather than lazy.

    Marked EXPERIMENTAL upstream; mainly useful in combination with
    ``pyro.poutine.collapse``.

    Parameters:
    - data (Tensor): Value(s) to make ground

    Returns:
    Tensor: ``data`` with all values ground
    """


# Functions for integrating PyTorch modules into probabilistic programs.
def module(name: str, nn_module, update_module_params: bool = False):
    """
    Integrate a PyTorch module into a probabilistic program.

    The module's parameters are registered with Pyro's parameter store under
    ``name`` so that inference algorithms can optimize them.

    Parameters:
    - name (str): Name for the module
    - nn_module (torch.nn.Module): PyTorch module to integrate
    - update_module_params (bool): If True, the module's parameters are
      overridden with values already present in Pyro's parameter store (if
      any). Defaults to False

    Returns:
    torch.nn.Module: The input module

    Examples:
    >>> neural_net = torch.nn.Linear(10, 1)
    >>> net = pyro.module("neural_net", neural_net, update_module_params=True)
    >>> output = net(input_tensor)
    """
def random_module(name: str, nn_module, prior, *args, **kwargs):
    """
    Create a stochastic neural network by placing priors over module parameters.

    Deprecated upstream in favor of ``pyro.nn.PyroModule``.

    Parameters:
    - name (str): Name for the random module
    - nn_module (torch.nn.Module): PyTorch module template
    - prior: A distribution, a stochastic function, or a dict mapping the
      parameter names of ``nn_module`` to distributions / stochastic functions

    Returns:
    callable: A distribution over modules; each call returns a
    ``torch.nn.Module`` whose parameters are sampled from the prior

    Examples:
    >>> template = torch.nn.Linear(10, 1)
    >>> priors = {
    ...     "weight": dist.Normal(torch.zeros(1, 10), torch.ones(1, 10)).to_event(2),
    ...     "bias": dist.Normal(torch.zeros(1), torch.ones(1)).to_event(1),
    ... }
    >>> lifted_module = pyro.random_module("bnn", template, priors)
    >>> bayesian_nn = lifted_module()  # draw one sampled network
    """


# Utilities for data subsampling and model visualization.
def subsample(data: torch.Tensor, event_dim: int) -> torch.Tensor:
    """
    Flag a data tensor so that enclosing plates subsample it automatically.

    Parameters:
    - data (Tensor): Full data tensor to be subsampled
    - event_dim (int): Count of rightmost dimensions that belong to the event

    Returns:
    Tensor: The subsampled data whenever this runs inside a subsampling plate
    """
def render_model(model, *args, **kwargs):
    """
    Render a graphical representation of the probabilistic model.

    Parameters:
    - model (callable): Model function to visualize
    - *args, **kwargs: Arguments to pass to the model. NOTE(review): upstream
      ``pyro.render_model`` takes the model inputs as ``model_args`` (tuple)
      and ``model_kwargs`` (dict) plus rendering options such as ``filename``;
      confirm against the Pyro API before relying on raw pass-through

    Returns:
    Visualization object for the model structure (upstream: ``graphviz.Digraph``)
    """


# Functions for managing global Pyro state and settings.
def get_param_store() -> ParamStoreDict:
    """
    Return the global store that holds every Pyro parameter.

    Returns:
    ParamStoreDict: The global parameter store dictionary

    Examples:
    >>> store = pyro.get_param_store()
    >>> print(list(store.keys()))  # names of all registered parameters
    """
def clear_param_store() -> None:
    """
    Remove every parameter from the global parameter store.

    Call this to start from a clean slate between experiments or tests.

    Examples:
    >>> pyro.clear_param_store()  # drop all parameters
    """
def enable_validation(is_validate: bool = True):
    """
    Turn runtime validation of distributions and shapes on or off.

    Parameters:
    - is_validate (bool): True to enable validation, False to disable it

    Examples:
    >>> pyro.enable_validation(True)   # helpful while debugging
    >>> pyro.enable_validation(False)  # faster once the model is trusted
    """
def validation_enabled(is_validate: bool = True) -> Iterator[None]:
    """
    Context manager that temporarily enables or disables runtime validation.

    Parameters:
    - is_validate (bool): Validation setting to apply inside the context

    Returns:
    ContextManager: Context within which validation is set to ``is_validate``

    Examples:
    >>> with pyro.validation_enabled(False):
    ...     run_inference()  # executes with validation disabled
    """
def set_rng_seed(rng_seed: int):
    """
    Set random number generator seeds for reproducible results.

    Sets seeds for Python random, NumPy, and PyTorch random number generators.

    Parameters:
    - rng_seed (int): Seed value for reproducible randomness

    Examples:
    >>> pyro.set_rng_seed(42)  # For reproducible experiments
    """


import pyro
import pyro.distributions as dist
import torch
def coin_flip_model(data):
    """Beta-Bernoulli model for a sequence of coin flips."""
    # Uniform Beta(1, 1) prior over the coin's probability of heads.
    p = pyro.sample("bias", dist.Beta(1.0, 1.0))
    # Flips are conditionally independent given the bias.
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Bernoulli(p), obs=data)


# Usage
data = torch.tensor([1.0, 0.0, 1.0, 1.0, 0.0])
coin_flip_model(data)


def hierarchical_model(group_data):
    """Hierarchical model with group-level parameters.

    ``group_data`` is a sequence of 1-D observation tensors, one per group.
    Groups may differ in length, so each gets its own plate and its own
    observed site ``obs_{i}``.
    """
    # Global hyperpriors
    mu_alpha = pyro.sample("mu_alpha", dist.Normal(0, 10))
    sigma_alpha = pyro.sample("sigma_alpha", dist.HalfNormal(5))
    # Group-specific parameters
    with pyro.plate("groups", len(group_data)):
        alpha = pyro.sample("alpha", dist.Normal(mu_alpha, sigma_alpha))
    # Observations within each group
    for i, group in enumerate(group_data):
        with pyro.plate(f"group_{i}_data", len(group)):
            pyro.sample(f"obs_{i}", dist.Normal(alpha[i], 1), obs=group)


def minibatch_model(data_loader):
    """Model that observes every batch yielded by ``data_loader``.

    Each batch gets its own plate and sample site (``data_{i}`` / ``obs_{i}``)
    because Pyro rejects two sample sites with the same name in one execution
    trace. Every batch is observed in a single call, so the likelihood already
    covers the whole dataset and no subsample scaling is applied.

    For stochastic minibatch SVI, observe ONE batch per model call instead and
    pass the full dataset size so log-likelihoods are rescaled::

        with pyro.plate("data", full_size, subsample_size=len(batch)):
            pyro.sample("obs", dist.Normal(mu, sigma), obs=batch)
    """
    # Global parameters
    mu = pyro.param("mu", torch.tensor(0.0))
    sigma = pyro.param("sigma", torch.tensor(1.0), constraint=dist.constraints.positive)
    for i, batch in enumerate(data_loader):
        # Distinct names per batch: reusing "obs" would be a duplicate site.
        with pyro.plate(f"data_{i}", len(batch)):
            pyro.sample(f"obs_{i}", dist.Normal(mu, sigma), obs=batch)


# Install with Tessl CLI
# npx tessl i tessl/pypi-pyro-ppl