CtrlK
Blog · Docs · Log in · Get started
Tessl Logo

tessl/pypi-pyro-ppl

A flexible, scalable deep probabilistic programming library built on PyTorch for universal probabilistic modeling and inference

Pending
Overview
Eval results
Files

docs/inference.md

Inference Methods

Scalable inference algorithms for posterior approximation and model learning, including variational inference, Markov Chain Monte Carlo, and specialized sampling methods for probabilistic programs.

Capabilities

Stochastic Variational Inference

Gradient-based variational inference for scalable approximate posterior computation.

class SVI:
    """
    Stochastic Variational Inference for scalable posterior approximation.
    
    SVI optimizes variational parameters to minimize the KL divergence between
    a variational guide and the true posterior distribution. Optimization is
    stochastic: each step uses Monte Carlo estimates of the loss gradient.
    """
    
    def __init__(self, model, guide, optim, loss):
        """
        Initialize SVI with model, guide, optimizer and loss function.
        
        Parameters:
        - model (callable): Generative model function
        - guide (callable): Variational guide function that approximates posterior;
          must accept the same arguments as the model
        - optim (PyroOptim): Pyro optimizer wrapping PyTorch optimizer
        - loss (ELBO): Evidence Lower Bound loss function
        
        Examples:
        >>> svi = SVI(model, guide, Adam({"lr": 0.01}), Trace_ELBO())
        """
    
    def step(self, *args, **kwargs) -> float:
        """
        Perform one SVI optimization step.
        
        Parameters:
        - *args, **kwargs: Arguments to pass to model and guide
        
        Returns:
        float: Loss value for this step (an estimate of the negative ELBO;
        it is stochastic and will fluctuate from step to step)
        
        Examples:
        >>> loss = svi.step(data)
        >>> print(f"Loss: {loss}")
        """
    
    def evaluate_loss(self, *args, **kwargs) -> float:
        """
        Evaluate loss without taking optimization step.
        
        Useful for monitoring convergence (e.g. on held-out data) without
        perturbing the variational parameters.
        
        Parameters:
        - *args, **kwargs: Arguments to pass to model and guide
        
        Returns:
        float: Current loss value
        """

def init_to_feasible(site: dict = None) -> torch.Tensor:
    """
    Initialization strategy: pick a value that satisfies the site's constraints.
    
    Parameters:
    - site (dict, optional): Sample site information for the latent variable
    
    Returns:
    Tensor: A value lying inside the site's support
    """

def init_to_mean(site: dict = None) -> torch.Tensor:
    """
    Initialization strategy: use the mean of the site's distribution.
    
    Parameters:
    - site (dict, optional): Sample site information for the latent variable
    
    Returns:
    Tensor: The distribution's mean, used as the initial value
    """

def init_to_sample(site: dict = None) -> torch.Tensor:
    """
    Initialization strategy: draw a random sample from the site's prior.
    
    Parameters:
    - site (dict, optional): Sample site information for the latent variable
    
    Returns:
    Tensor: A prior sample, used as the initial value
    """

Evidence Lower Bound (ELBO)

Loss functions for variational inference based on the evidence lower bound.

class ELBO:
    """
    Base class for Evidence Lower Bound loss functions.

    The ELBO is a lower bound on the model evidence (marginal likelihood);
    concrete subclasses implement different estimators of it, and it serves
    as the optimization objective for variational inference.
    """

    def differentiable_loss(self, model, guide, *args, **kwargs) -> torch.Tensor:
        """
        Compute a differentiable estimate of the ELBO loss.

        Parameters:
        - model (callable): Generative model function
        - guide (callable): Variational guide function
        - *args, **kwargs: Arguments forwarded to model and guide

        Returns:
        Tensor: Negative ELBO (the loss to minimize)
        """

class Trace_ELBO(ELBO):
    """
    Standard trace-based ELBO implementation.
    
    Uses execution traces to compute ELBO via the log probability of the
    joint model minus the log probability of the guide.
    """
    
    def __init__(self, num_particles: int = 1, max_plate_nesting: int = float('inf'), 
                 max_iarange_nesting: int = None, vectorize_particles: bool = False,
                 strict_enumeration_warning: bool = True):
        """
        Parameters:
        - num_particles (int): Number of Monte Carlo samples for gradient estimation
        - max_plate_nesting (int): Maximum depth of nested plates to vectorize over
        - max_iarange_nesting (int): Presumably a legacy alias for max_plate_nesting
          kept for backward compatibility (iarange is plate's old name) -- confirm
        - vectorize_particles (bool): Whether to vectorize over particles
        - strict_enumeration_warning (bool): Whether to warn about enumeration issues
        """

class TraceEnum_ELBO(ELBO):
    """
    ELBO with exact enumeration over discrete latent variables.
    
    Computes exact expectations over discrete variables while using
    Monte Carlo for continuous variables.
    """
    
    def __init__(self, max_plate_nesting: int = float('inf'), max_iarange_nesting: int = None,
                 strict_enumeration_warning: bool = True, ignore_jit_warnings: bool = False):
        """
        Parameters:
        - max_plate_nesting (int): Maximum plate nesting depth for enumeration
        - max_iarange_nesting (int): Presumably a legacy alias for max_plate_nesting
          (iarange is plate's old name) -- confirm
        - strict_enumeration_warning (bool): Whether to warn about enumeration issues
        - ignore_jit_warnings (bool): Whether to ignore JIT compilation warnings
        """

class TraceGraph_ELBO(ELBO):
    """
    Memory-efficient ELBO based on dependency graphs.

    Computes gradients by exploiting the dependency structure of the
    computational graph, which reduces memory usage relative to a plain
    trace-based estimator.
    """

class TraceMeanField_ELBO(ELBO):
    """
    ELBO specialized for mean-field variational inference.

    Assumes the guide factorizes independently over latent variables,
    which enables a more efficient computation of the objective.
    """

class RenyiELBO(ELBO):
    """
    Renyi divergence-based ELBO for more robust inference.
    
    Uses the Renyi alpha-divergence family instead of the KL divergence,
    which can offer better optimization properties.
    """
    
    def __init__(self, alpha: float = 0.0, num_particles: int = 2, max_plate_nesting: int = float('inf')):
        """
        Parameters:
        - alpha (float): Order of the Renyi divergence. The standard KL-based
          ELBO is recovered in the limit alpha -> 1 (alpha must not equal 1);
          the default alpha=0 gives an importance-weighted bound.
        - num_particles (int): Number of particles for gradient estimation
        - max_plate_nesting (int): Maximum plate nesting depth
        """

Markov Chain Monte Carlo

MCMC methods for exact sampling from posterior distributions.

class MCMC:
    """
    Markov Chain Monte Carlo interface for exact posterior sampling.
    
    MCMC generates correlated samples from the exact posterior distribution
    using various kernel methods like HMC and NUTS.
    """
    
    def __init__(self, kernel, num_samples: int, warmup_steps: int = None, 
                 initial_params: dict = None, chain_id: int = 0, mp_context=None,
                 disable_progbar: bool = False, disable_validation: bool = True,
                 transforms: dict = None, max_tree_depth: int = None, 
                 target_accept_prob: float = 0.8, jit_compile: bool = False):
        """
        Parameters:
        - kernel: MCMC kernel (e.g., HMC, NUTS, RandomWalkKernel)
        - num_samples (int): Number of MCMC samples to generate
        - warmup_steps (int): Number of warmup/burn-in steps
        - initial_params (dict): Initial parameter values
        - chain_id (int): Chain identifier for multiple chains
          (NOTE(review): upstream Pyro's MCMC takes num_chains instead --
          confirm against this package's implementation)
        - mp_context: Multiprocessing context used when running chains in
          separate processes
        - disable_progbar (bool): Whether to suppress the progress bar
        - disable_validation (bool): Whether to skip model validation
        - transforms (dict): Parameter transforms for constrained sampling
        - max_tree_depth (int): Maximum tree depth for tree-based kernels
        - target_accept_prob (float): Target acceptance probability for adaptive kernels
        - jit_compile (bool): Whether to JIT compile the kernel
        
        Examples:
        >>> kernel = NUTS(model)
        >>> mcmc = MCMC(kernel, num_samples=1000, warmup_steps=500)
        """
    
    def run(self, *args, **kwargs):
        """
        Run the MCMC chain (warmup followed by sampling).
        
        Parameters:
        - *args, **kwargs: Arguments to pass to the model
        
        Examples:
        >>> mcmc.run(data)
        >>> samples = mcmc.get_samples()
        """
    
    def get_samples(self, group_by_chain: bool = False) -> dict:
        """
        Get MCMC samples after running the chain.
        
        Parameters:
        - group_by_chain (bool): Whether to group samples by chain
        
        Returns:
        dict: Dictionary mapping sample site names to sample tensors
        
        Examples:
        >>> samples = mcmc.get_samples()
        >>> theta_samples = samples["theta"]
        """

class HMC:
    """
    Hamiltonian Monte Carlo kernel.
    
    HMC uses gradient information to make efficient proposals in
    continuous parameter spaces.
    """
    
    def __init__(self, model, step_size: float = 1.0, num_steps: int = 1,
                 adapt_step_size: bool = True, adapt_mass_matrix: bool = True,
                 full_mass: bool = False, transforms: dict = None, 
                 max_plate_nesting: int = None, jit_compile: bool = False,
                 jit_options: dict = None, ignore_jit_warnings: bool = False):
        """
        Parameters:
        - model (callable): Model to sample from
        - step_size (float): Integration step size
        - num_steps (int): Number of leapfrog steps per iteration
        - adapt_step_size (bool): Whether to adapt step size during warmup
        - adapt_mass_matrix (bool): Whether to adapt mass matrix
        - full_mass (bool): Whether to use full mass matrix (vs diagonal)
        - transforms (dict): Parameter transformations
        - max_plate_nesting (int): Maximum plate nesting depth
        - jit_compile (bool): Whether to JIT compile the potential-energy computation
        - jit_options (dict): Extra options forwarded to the JIT compiler
        - ignore_jit_warnings (bool): Whether to suppress JIT compilation warnings
        """

class NUTS:
    """
    No-U-Turn Sampler, an adaptive version of HMC.
    
    NUTS automatically determines the number of leapfrog steps to take
    by detecting when the trajectory starts to reverse direction.
    """
    
    def __init__(self, model, step_size: float = 1.0, adapt_step_size: bool = True,
                 adapt_mass_matrix: bool = True, full_mass: bool = False,
                 transforms: dict = None, max_plate_nesting: int = None,
                 max_tree_depth: int = 10, target_accept_prob: float = 0.8,
                 jit_compile: bool = False, jit_options: dict = None,
                 ignore_jit_warnings: bool = False):
        """
        Parameters:
        - model (callable): Model to sample from
        - step_size (float): Initial step size
        - adapt_step_size (bool): Whether to adapt step size during warmup
        - adapt_mass_matrix (bool): Whether to adapt the mass matrix during warmup
        - full_mass (bool): Whether to use a full mass matrix (vs diagonal)
        - transforms (dict): Parameter transformations
        - max_plate_nesting (int): Maximum plate nesting depth
        - max_tree_depth (int): Maximum binary tree depth
        - target_accept_prob (float): Target acceptance probability for adaptation
        - jit_compile (bool): Whether to JIT compile the potential-energy computation
        - jit_options (dict): Extra options forwarded to the JIT compiler
        - ignore_jit_warnings (bool): Whether to suppress JIT compilation warnings
        
        Examples:
        >>> nuts_kernel = NUTS(model)
        >>> mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=500)
        """

class RandomWalkKernel:
    """
    Random walk Metropolis-Hastings kernel.

    A simple MCMC kernel: each proposal perturbs the current state with
    random noise and is accepted or rejected Metropolis-Hastings style.
    """

    def __init__(self, model, step_size: dict = None, adapt_step_size: bool = True,
                 transforms: dict = None, max_plate_nesting: int = None):
        """
        Parameters:
        - model (callable): Model to sample from
        - step_size (dict): Per-parameter proposal step sizes
        - adapt_step_size (bool): Whether to adapt step size during warmup
        - transforms (dict): Parameter transformations
        - max_plate_nesting (int): Maximum plate nesting depth
        """

Predictive Sampling

Generate predictions and samples from trained models.

class Predictive:
    """
    Generate predictive samples from posterior or prior distributions.
    
    Predictive enables posterior predictive checks, prior predictive checks,
    and out-of-sample predictions by sampling from the model with different
    parameter configurations.
    """
    
    def __init__(self, model, guide=None, posterior_samples: dict = None, 
                 num_samples: int = None, return_sites: list = None,
                 parallel: bool = False, batch_ndims: int = 1):
        """
        Parameters:
        - model (callable): Generative model function
        - guide (callable, optional): Variational guide for posterior sampling
        - posterior_samples (dict, optional): Pre-computed posterior samples
          (NOTE(review): presumably guide and posterior_samples are mutually
          exclusive ways to supply the posterior -- confirm)
        - num_samples (int, optional): Number of samples to generate
        - return_sites (list, optional): Sites to include in output
        - parallel (bool): Whether to parallelize sampling
        - batch_ndims (int): Number of batch dimensions
        
        Examples:
        >>> # Posterior predictive with guide
        >>> predictive = Predictive(model, guide=guide, num_samples=1000)
        >>> samples = predictive(data)
        >>>
        >>> # Prior predictive (no guide or posterior samples)
        >>> predictive = Predictive(model, num_samples=100)
        >>> prior_samples = predictive(data)
        """
    
    def __call__(self, *args, **kwargs) -> dict:
        """
        Generate predictive samples.
        
        Parameters:
        - *args, **kwargs: Arguments to pass to the model
        
        Returns:
        dict: Dictionary mapping site names to sample tensors
        
        Examples:
        >>> samples = predictive(test_data)
        >>> predictions = samples["obs"]
        """

class WeighedPredictive:
    """
    Generate weighted predictive samples using importance sampling.
    
    Useful when posterior samples come from importance sampling or
    when samples have non-uniform weights.
    
    NOTE: the spelling "Weighed" (not "Weighted") matches the upstream
    Pyro class name; do not rename.
    """
    
    def __init__(self, model, guide=None, posterior_samples: dict = None,
                 weights: torch.Tensor = None, num_samples: int = None,
                 return_sites: list = None, parallel: bool = False):
        """
        Parameters:
        - model (callable): Generative model function
        - guide (callable, optional): Guide function
        - posterior_samples (dict, optional): Pre-computed samples
        - weights (Tensor, optional): Sample weights
        - num_samples (int, optional): Number of samples to generate
        - return_sites (list, optional): Sites to include in output
        - parallel (bool): Whether to parallelize sampling
        """

class EmpiricalMarginal:
    """
    Empirical marginal distribution built from a collection of samples.

    Wraps samples produced by MCMC or SVI as a distribution object that
    can be used like any other Pyro distribution.
    """

    def __init__(self, samples: torch.Tensor, log_weights: torch.Tensor = None):
        """
        Parameters:
        - samples (Tensor): Sample values
        - log_weights (Tensor, optional): Per-sample log weights
          (presumably uniform weighting when omitted -- confirm)

        Examples:
        >>> samples = mcmc.get_samples()["theta"]
        >>> marginal = EmpiricalMarginal(samples)
        >>> new_sample = marginal.sample()
        """

Importance Sampling

Importance sampling methods for model comparison and marginal likelihood estimation.

class Importance:
    """
    Importance sampling for marginal likelihood estimation.
    
    Uses importance sampling to estimate the model evidence (marginal likelihood)
    which is useful for model comparison and selection.
    """
    
    def __init__(self, model, guide, num_samples: int):
        """
        Parameters:
        - model (callable): Generative model
        - guide (callable): Importance sampling distribution (proposal)
        - num_samples (int): Number of importance samples
        
        Examples:
        >>> importance = Importance(model, guide, num_samples=10000)
        >>> log_evidence = importance.run(data)
        """
    
    def run(self, *args, **kwargs) -> torch.Tensor:
        """
        Run importance sampling to estimate log marginal likelihood.
        
        NOTE(review): in upstream Pyro, Importance.run() returns the traced
        posterior object, and the log evidence is obtained via
        get_log_normalizer() -- confirm that this package's run() really
        returns the log-evidence tensor directly as documented here.
        
        Parameters:
        - *args, **kwargs: Arguments to pass to model and guide
        
        Returns:
        Tensor: Log marginal likelihood estimate
        """

class SMCFilter:
    """
    Sequential Monte Carlo filtering for state space models.

    A particle filter for sequential Bayesian inference in time series
    and state space models.
    """

    def __init__(self, model, guide, num_particles: int, max_plate_nesting: int):
        """
        Parameters:
        - model (callable): State space model
        - guide (callable): Proposal distribution for particles
        - num_particles (int): Number of particles to maintain
        - max_plate_nesting (int): Maximum plate nesting depth
        """

Specialized Inference Methods

Advanced inference algorithms for specific model types and scenarios.

class SVGD:
    """
    Stein Variational Gradient Descent for non-parametric inference.

    Maintains a set of particles and optimizes them to approximate the
    posterior by minimizing a kernelized Stein discrepancy.
    """

    def __init__(self, model, kernel, optimizer, num_particles: int):
        """
        Parameters:
        - model (callable): Model function
        - kernel: Kernel function for the Stein method
        - optimizer: Optimizer used to update the particles
        - num_particles (int): Number of particles to optimize
        """

class ReweightedWakeSleep:
    """
    Reweighted Wake-Sleep algorithm for deep generative models.

    An alternative to standard variational inference that alternates
    wake and sleep phases and can handle more complex posterior
    approximations.
    """

    def __init__(self, model, guide, wake_loss, sleep_loss):
        """
        Parameters:
        - model (callable): Generative model
        - guide (callable): Recognition model
        - wake_loss: Loss function used during the wake phase
        - sleep_loss: Loss function used during the sleep phase
        """

def config_enumerate(default: str = None, expand: bool = False, num_samples: int = None):
    """
    Configure automatic enumeration over discrete latent variables.
    
    Decorator that enables exact marginalization over discrete variables
    in models with both discrete and continuous latent variables.
    
    Parameters:
    - default (str): Default enumeration strategy ("sequential" or "parallel")
    - expand (bool): Whether to expand enumerated dimensions
    - num_samples (int): Number of samples for approximate enumeration
      (presumably replaces exact enumeration with Monte Carlo when set -- confirm)
    
    Examples:
    >>> @config_enumerate
    ... def model():
    ...     z = pyro.sample("z", dist.Categorical(torch.ones(3)))
    ...     return pyro.sample("x", dist.Normal(z, 1))
    """

def infer_discrete(first_available_dim: int = None, temperature: float = 1.0,
                   cooler: callable = None):
    """
    Infer discrete latent variables by enumeration or sampling.
    
    Effect handler that automatically handles discrete variable inference
    by choosing between exact enumeration and approximate sampling.
    
    Parameters:
    - first_available_dim (int): First tensor dimension available for enumeration
      (presumably a negative dim to the left of all plate dims, as in upstream
      Pyro -- confirm)
    - temperature (float): Temperature for discrete sampling; 1.0 samples,
      and presumably 0 selects MAP values as in upstream Pyro -- confirm
    - cooler (callable): Cooling schedule for simulated annealing
    
    Examples:
    >>> with infer_discrete():
    ...     svi.step(data)
    """

Examples

Basic SVI Training

import pyro
import pyro.distributions as dist
import torch
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

def model(data):
    # Global latents: location and (positive) scale of the observation noise.
    loc = pyro.sample("mu", dist.Normal(0, 10))
    scale = pyro.sample("sigma", dist.LogNormal(0, 1))

    # Observations are conditionally independent given the shared latents.
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(loc, scale), obs=data)

def guide(data):
    """Variational guide: learnable Normal for mu; the prior is reused for sigma."""
    # Learnable variational parameters. Requires `import torch` at the top of
    # the example -- the original snippet used torch.tensor without importing it.
    mu_q = pyro.param("mu_q", torch.tensor(0.0))
    sigma_q = pyro.param("sigma_q", torch.tensor(1.0), constraint=dist.constraints.positive)
    
    pyro.sample("mu", dist.Normal(mu_q, sigma_q))
    pyro.sample("sigma", dist.LogNormal(0, 1))  # Use prior as guide

# Training: 1000 gradient steps; each step returns a noisy loss estimate.
svi = SVI(model, guide, Adam({"lr": 0.01}), Trace_ELBO())
losses = [svi.step(data) for _ in range(1000)]

MCMC Sampling

from pyro.infer import MCMC, NUTS

def model(data):
    # Shared latent mean and (positive) noise scale.
    mu = pyro.sample("mu", dist.Normal(0, 10))
    sigma = pyro.sample("sigma", dist.LogNormal(0, 1))
    
    # Each observation is conditionally independent given mu and sigma.
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(mu, sigma), obs=data)

# MCMC sampling: 1000 posterior draws with NUTS after 500 warmup iterations.
nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=500)
mcmc.run(data)

# Get samples: a dict keyed by sample-site name ("mu", "sigma").
samples = mcmc.get_samples()
mu_samples = samples["mu"]
sigma_samples = samples["sigma"]

Posterior Predictive Checks

from pyro.infer import Predictive

# After training SVI or MCMC: draw joint samples from the guide's
# approximate posterior and replay them through the model.
predictive = Predictive(model, guide=guide, num_samples=1000)
posterior_samples = predictive(data)

# Generate predictions for new data.
# NOTE: `new_data` must be defined by the caller; it is not created above.
predictive_new = Predictive(model, guide=guide, num_samples=100)
predictions = predictive_new(new_data)

Model Comparison with Importance Sampling

from pyro.infer import Importance

# Compare two models by their estimated log evidence.
importance1 = Importance(model1, guide1, num_samples=10000)
log_evidence1 = importance1.run(data)

importance2 = Importance(model2, guide2, num_samples=10000)  
log_evidence2 = importance2.run(data)

# Bayes factor: evidence ratio of model 1 to model 2.
# NOTE: this snippet also requires `import torch`.
# NOTE(review): upstream Pyro's Importance.run() returns the posterior object
# (log evidence via get_log_normalizer()) -- confirm run() returns a tensor here.
bayes_factor = torch.exp(log_evidence1 - log_evidence2)
print(f"Bayes factor (Model 1 vs Model 2): {bayes_factor}")

Install with Tessl CLI

npx tessl i tessl/pypi-pyro-ppl

docs

core-programming.md

distributions.md

gaussian-processes.md

index.md

inference.md

neural-networks.md

optimization.md

transforms-constraints.md

tile.json