CtrlK
Blog · Docs · Log in · Get started
Tessl Logo

tessl/pypi-pyro-ppl

A flexible, scalable deep probabilistic programming library built on PyTorch for universal probabilistic modeling and inference

Pending
Overview
Eval results
Files

docs/neural-networks.md

Neural Networks Integration

This page covers deep probabilistic models that combine neural networks with probabilistic programming: Bayesian neural networks, stochastic layers, and integration between PyTorch modules and Pyro's probabilistic primitives.

Capabilities

Pyro Module System

Base classes and descriptors for creating probabilistic neural network modules that integrate seamlessly with Pyro's effect system.

class PyroModule(torch.nn.Module):
    """
    Base class for Pyro modules with integrated parameter and sample management.
    
    PyroModule extends torch.nn.Module to support Pyro's parameter store and
    sample statements, enabling probabilistic neural networks and automatic
    integration with inference algorithms. Assigning a PyroParam or a
    PyroSample to an attribute declares, respectively, a learnable parameter
    or a random variable (see ``__setattr__``).
    
    Examples:
    >>> class BayesianLinear(PyroModule):
    ...     def __init__(self, in_features, out_features):
    ...         super().__init__()
    ...         self.in_features = in_features
    ...         self.out_features = out_features
    ...         
    ...         # Stochastic weights: a random variable with a standard-normal prior
    ...         self.weight = PyroSample(
    ...             dist.Normal(0, 1).expand([out_features, in_features]).to_event(2)
    ...         )
    ...         
    ...         # Learnable bias: a parameter, not a random variable
    ...         self.bias = PyroParam(torch.zeros(out_features))
    ...     
    ...     def forward(self, x):
    ...         return torch.nn.functional.linear(x, self.weight, self.bias)
    """
    
    def __setattr__(self, name: str, value):
        """
        Override attribute assignment to handle PyroParam and PyroSample descriptors.
        
        Parameters:
        - name (str): Attribute name being assigned
        - value: Assigned value. PyroParam and PyroSample values get special
          handling; other values are presumably stored as by
          torch.nn.Module — TODO confirm
        """
    
    def named_pyro_params(self, prefix: str = '', recurse: bool = True):
        """
        Iterate over Pyro parameters in the module.
        
        Parameters:
        - prefix (str): Prefix to prepend to parameter names
        - recurse (bool): Whether to also yield parameters of submodules
        
        Yields:
        Tuple[str, torch.Tensor]: (name, parameter) pairs
        """

class PyroParam:
    """
    Descriptor for Pyro parameters within PyroModule.
    
    PyroParam creates learnable parameters that are automatically registered
    with Pyro's parameter store and can be constrained or transformed.
    Instances are meant to be assigned as attributes of a PyroModule,
    typically inside its ``__init__`` (as in the examples below).
    """
    
    # The default ``constraint`` expression is evaluated once, when this class
    # body runs, so ``dist`` (pyro.distributions) must already be in scope then.
    def __init__(self, init_tensor, constraint=dist.constraints.real, event_dim=None):
        """
        Parameters:
        - init_tensor (Tensor): Initial parameter value
        - constraint (Constraint): Parameter constraint (e.g., positive, simplex);
          the default ``dist.constraints.real`` leaves the value unconstrained
        - event_dim (int, optional): Number of rightmost event dimensions
        
        Examples:
        >>> # Unconstrained parameter
        >>> self.mu = PyroParam(torch.tensor(0.0))
        >>>
        >>> # Positive parameter
        >>> self.sigma = PyroParam(torch.tensor(1.0), constraint=dist.constraints.positive)
        >>>
        >>> # Simplex parameter (probabilities)
        >>> self.probs = PyroParam(torch.ones(5), constraint=dist.constraints.simplex)
        """
    
    def __get__(self, obj, obj_type=None) -> torch.Tensor:
        """Get the parameter value for ``obj`` from the Pyro parameter store."""
    
    def __set__(self, obj, value):
        """Set the parameter value for ``obj`` in the Pyro parameter store."""

class PyroSample:
    """
    Descriptor for Pyro samples within PyroModule.
    
    PyroSample creates stochastic variables that are automatically sampled
    from specified prior distributions during model execution.
    """
    
    def __init__(self, prior):
        """
        Parameters:
        - prior (Distribution or callable): Prior distribution, or a function
          that receives the owning module instance and returns a distribution.
          Use the callable form when the prior depends on other attributes of
          the module.
        
        Examples:
        >>> # Fixed prior distribution
        >>> self.weight = PyroSample(dist.Normal(0, 1))
        >>>
        >>> # Parameterized prior: the callable is given the module instance
        >>> self.weight = PyroSample(lambda self: dist.Normal(self.weight_loc, self.weight_scale))
        >>>
        >>> # Matrix-valued parameter
        >>> self.W = PyroSample(dist.Normal(0, 1).expand([10, 5]).to_event(2))
        """
    
    def __get__(self, obj, obj_type=None) -> torch.Tensor:
        """Sample from the prior distribution for ``obj``."""

def pyro_method(fn):
    """
    Decorator to create Pyro-aware methods in PyroModule.
    
    Ensures that sample statements within decorated methods use appropriate
    name scoping and integration with the module's parameter namespace.
    Intended for methods other than ``forward``, such as a ``model`` method.
    
    Parameters:
    - fn (callable): Method to decorate
    
    Returns:
    callable: Decorated method with Pyro integration
    
    Examples:
    >>> class MyModule(PyroModule):
    ...     @pyro_method
    ...     def model(self, x):
    ...         z = pyro.sample("z", dist.Normal(0, 1))
    ...         # assumes the module defines forward(x, z)
    ...         return self.forward(x, z)
    """

Neural Network Architectures

Specialized neural network architectures for probabilistic modeling and normalizing flows.

class DenseNN(PyroModule):
    """
    Dense (fully-connected) neural network with configurable architecture.
    
    Commonly used in normalizing flows, variational autoencoders, and as
    function approximators in probabilistic models.
    """
    
    # NOTE(review): the default ``nonlinearity=torch.nn.ReLU()`` is evaluated
    # once, when this class body runs, so every DenseNN built with the default
    # shares a single ReLU instance. ReLU has no parameters, so that is harmless.
    # NOTE(review): upstream pyro.nn.DenseNN appears to take ``param_dims``
    # rather than ``output_dim`` — confirm this signature against the
    # installed Pyro version.
    def __init__(self, input_dim: int, hidden_dims: List[int], output_dim: int,
                 nonlinearity: torch.nn.Module = torch.nn.ReLU(),
                 residual_connections: bool = False, batch_norm: bool = False,
                 dropout_prob: float = 0.0):
        """
        Parameters:
        - input_dim (int): Input dimension
        - hidden_dims (List[int]): List of hidden layer dimensions, one entry per hidden layer
        - output_dim (int): Output dimension
        - nonlinearity (Module): Activation function between layers
        - residual_connections (bool): Whether to add residual connections
        - batch_norm (bool): Whether to use batch normalization
        - dropout_prob (float): Dropout probability (0 = no dropout)
        
        Examples:
        >>> # Two hidden layers (64 and 32 units) plus the output layer
        >>> net = DenseNN(10, [64, 32], 1)
        >>>
        >>> # Network with batch norm and dropout
        >>> net = DenseNN(20, [128, 64, 32], 5, 
        ...                batch_norm=True, dropout_prob=0.1)
        """
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the network.
        
        Parameters:
        - x (Tensor): Input tensor of shape (..., input_dim)
        
        Returns:
        Tensor: Output tensor of shape (..., output_dim)
        """

class ConditionalDenseNN(PyroModule):
    """
    Conditional dense neural network that takes additional context input.
    
    Useful for conditional normalizing flows and context-dependent function
    approximation in probabilistic models.
    """
    
    def __init__(self, input_dim: int, context_dim: int, hidden_dims: List[int], 
                 output_dim: int, nonlinearity: torch.nn.Module = torch.nn.ReLU(),
                 residual_connections: bool = False):
        """
        Parameters:
        - input_dim (int): Primary input dimension
        - context_dim (int): Context/condition dimension
        - hidden_dims (List[int]): Hidden layer dimensions
        - output_dim (int): Output dimension
        - nonlinearity (Module): Activation function
        - residual_connections (bool): Whether to use residual connections
        
        Examples:
        >>> # Conditional network: 10-dim input, 5-dim context, 2-dim output
        >>> cond_net = ConditionalDenseNN(10, 5, [64, 32], 2)
        >>> output = cond_net(x, context)
        """
    
    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        """
        Forward pass with context input.
        
        Parameters:
        - x (Tensor): Primary input of shape (..., input_dim)
        - context (Tensor): Context input of shape (..., context_dim); the
          leading dimensions presumably must match those of ``x`` — TODO
          confirm whether broadcasting is supported
        
        Returns:
        Tensor: Output tensor of shape (..., output_dim)
        """

class AutoRegressiveNN(PyroModule):
    """
    Autoregressive neural network with masked connections.
    
    Implements MADE (Masked Autoencoder for Distribution Estimation) for
    autoregressive density modeling and normalizing flows.
    """
    
    # NOTE(review): upstream pyro.nn.AutoRegressiveNN appears to take
    # ``param_dims``/``permutation``/``skip_connections`` rather than
    # ``output_dim_multiplier``/``random_mask``/``residual_connections`` —
    # confirm this signature against the installed Pyro version.
    def __init__(self, input_dim: int, hidden_dims: List[int], output_dim_multiplier: int = 1,
                 nonlinearity: torch.nn.Module = torch.nn.ReLU(), residual_connections: bool = False,
                 random_mask: bool = False, activation: torch.nn.Module = None):
        """
        Parameters:
        - input_dim (int): Input dimension
        - hidden_dims (List[int]): Hidden layer dimensions  
        - output_dim_multiplier (int): Output dimension multiplier (for multiple outputs per input)
        - nonlinearity (Module): Activation applied between hidden layers
        - residual_connections (bool): Whether to use residual connections
        - random_mask (bool): Whether to use random ordering for autoregressive mask
        - activation (Module): Activation applied to the final layer; the
          default ``None`` presumably means no final activation — TODO confirm
        
        Examples:
        >>> # Autoregressive network for 10-dimensional data
        >>> ar_net = AutoRegressiveNN(10, [64, 64], output_dim_multiplier=2)
        >>> # Output has shape (..., 20) for 2 outputs per input dimension
        """
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass preserving autoregressive property.
        
        Parameters:
        - x (Tensor): Input tensor of shape (..., input_dim)
        
        Returns:
        Tensor: Output respecting autoregressive ordering, of shape
        (..., input_dim * output_dim_multiplier) per the constructor example
        """

class ConditionalAutoRegressiveNN(AutoRegressiveNN):
    """
    Conditional autoregressive neural network with context input.
    
    Combines autoregressive masking with conditional computation for
    context-dependent autoregressive models.
    """
    
    # NOTE(review): unlike AutoRegressiveNN, this constructor exposes no
    # residual_connections, random_mask or activation arguments — confirm
    # whether those are intentionally fixed.
    def __init__(self, input_dim: int, context_dim: int, hidden_dims: List[int],
                 output_dim_multiplier: int = 1, nonlinearity: torch.nn.Module = torch.nn.ReLU()):
        """
        Parameters:
        - input_dim (int): Primary input dimension
        - context_dim (int): Context dimension
        - hidden_dims (List[int]): Hidden layer dimensions
        - output_dim_multiplier (int): Output multiplier per input dimension
        - nonlinearity (Module): Activation function
        """
    
    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        """
        Forward pass with context input, maintaining the autoregressive property.
        
        Parameters:
        - x (Tensor): Primary input of shape (..., input_dim)
        - context (Tensor): Context input of shape (..., context_dim)
        """

class MaskedLinear(torch.nn.Module):
    """
    Linear layer with learnable or fixed mask for autoregressive networks.
    
    Used as a building block in autoregressive neural networks where
    connections must respect the autoregressive ordering.
    """
    
    def __init__(self, in_features: int, out_features: int, mask: torch.Tensor = None,
                 bias: bool = True):
        """
        Parameters:
        - in_features (int): Input feature dimension
        - out_features (int): Output feature dimension  
        - mask (Tensor, optional): Binary mask matrix (1=keep, 0=mask), applied
          elementwise to the weight matrix, i.e. shaped like a
          torch.nn.Linear weight (out_features, in_features)
        - bias (bool): Whether to include bias parameter
        
        Examples:
        >>> # Strictly lower-triangular mask: output i may only depend on
        >>> # inputs j < i. Keeping the diagonal (plain torch.tril) would let
        >>> # output i see input i and break the autoregressive ordering.
        >>> mask = torch.tril(torch.ones(5, 5), diagonal=-1)
        >>> masked_layer = MaskedLinear(5, 5, mask)
        """
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass with masked weight matrix."""

Bayesian Neural Networks

Tools for creating and working with Bayesian neural networks where weights and biases are treated as random variables.

def lift_module(nn_module: torch.nn.Module, prior: callable, guide: callable = None):
    """
    Lift a PyTorch module to a Bayesian neural network.
    
    Converts deterministic neural network parameters to random variables
    with specified prior distributions.
    
    Parameters:
    - nn_module (Module): PyTorch module to convert
    - prior (callable): Function prior(name, shape) returning the prior
      distribution for the parameter called ``name`` with shape ``shape``
    - guide (callable, optional): Function that returns guide distributions
    
    Returns:
    PyroModule: Bayesian version of the input module
    
    Examples:
    >>> # Define deterministic network
    >>> net = torch.nn.Linear(10, 1)
    >>>
    >>> # Define priors
    >>> def prior(name, shape):
    ...     return dist.Normal(0, 1).expand(shape).to_event(len(shape))
    >>>
    >>> # Create Bayesian network
    >>> bnn = lift_module(net, prior)
    >>> 
    >>> # Use the lifted network in a probabilistic model
    >>> def model(x, y):
    ...     prediction = bnn(x).squeeze(-1)  # drop only the output-feature axis
    ...     with pyro.plate("data", x.shape[0]):
    ...         pyro.sample("obs", dist.Normal(prediction, 0.1), obs=y)
    """
    # NOTE(review): confirm this helper exists in the installed Pyro; upstream
    # offers PyroModule + PyroSample (and the deprecated pyro.random_module)
    # for this purpose.

def sample_module_outputs(model: PyroModule, input_data: torch.Tensor, 
                         num_samples: int = 100) -> torch.Tensor:
    """
    Sample multiple outputs from a Bayesian neural network.
    
    Parameters:
    - model (PyroModule): Bayesian neural network model
    - input_data (Tensor): Input data; presumably (batch_size, input_dim) given
      the documented output shape — TODO confirm for other ranks
    - num_samples (int): Number of posterior samples to generate
    
    Returns:
    Tensor: Sampled outputs with shape (num_samples, batch_size, output_dim);
    dimension 0 indexes the samples
    
    Examples:
    >>> outputs = sample_module_outputs(bnn, test_data, num_samples=50)
    >>> mean_prediction = outputs.mean(dim=0)  # average over samples
    >>> uncertainty = outputs.std(dim=0)       # spread across samples
    """

class BayesianModule(PyroModule):
    """
    Base class for hand-written Bayesian neural-network layers.

    Subclasses treat their weights as random variables. ``sample_parameters``
    draws them, and ``forward_with_samples`` pushes several draws through the
    layer to expose predictive uncertainty.
    """

    def __init__(self, name: str):
        """
        Parameters:
        - name (str): Name used to scope this module's parameters; kept on
          the instance as ``_pyro_name``
        """
        super(BayesianModule, self).__init__()
        self._pyro_name = name

    def sample_parameters(self):
        """Draw the layer's parameters from their prior/posterior distributions."""

    def forward_with_samples(self, x: torch.Tensor, num_samples: int = 1) -> torch.Tensor:
        """
        Run the forward pass under several parameter draws to estimate uncertainty.

        Parameters:
        - x (Tensor): Input data
        - num_samples (int): How many parameter draws to use

        Returns:
        Tensor: Output samples carrying the uncertainty information (the exact
        stacking layout is not specified here)
        """

Variational Layers

Specialized layers for variational inference and amortized inference in deep generative models.

class VariationalLinear(PyroModule):
    """
    Variational linear layer with learnable mean and variance parameters.
    
    Implements local reparameterization trick for efficient variational
    inference in neural networks.
    """
    
    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 prior_scale: float = 1.0):
        """
        Parameters:
        - in_features (int): Input feature dimension
        - out_features (int): Output feature dimension
        - bias (bool): Whether to include bias term
        - prior_scale (float): Scale of prior distribution on weights;
          presumably the standard deviation of a zero-mean Normal prior —
          TODO confirm
        
        Examples:
        >>> # 10 inputs, 5 outputs, tighter-than-default weight prior
        >>> var_layer = VariationalLinear(10, 5, prior_scale=0.1)
        """
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass using local reparameterization trick."""

class AmortizedLDA(PyroModule):
    """
    Amortized Latent Dirichlet Allocation using neural networks.
    
    Implements neural variational inference for topic modeling where
    the variational parameters are predicted by neural networks.
    """
    
    def __init__(self, vocab_size: int, num_topics: int, hidden_dim: int = 100,
                 dropout: float = 0.2):
        """
        Parameters:
        - vocab_size (int): Vocabulary size
        - num_topics (int): Number of topics
        - hidden_dim (int): Hidden dimension for encoder network
        - dropout (float): Dropout probability
        """
    
    # NOTE(review): the encoding of ``docs`` (e.g. per-document word counts vs.
    # token indices) and its shape are not specified here — confirm with the
    # implementation before relying on either.
    def model(self, docs: torch.Tensor, doc_lengths: torch.Tensor):
        """LDA generative model over ``docs`` with per-document ``doc_lengths``."""
    
    def guide(self, docs: torch.Tensor, doc_lengths: torch.Tensor):
        """Neural variational guide for LDA; takes the same arguments as ``model``."""

Integration Utilities

Functions for seamless integration between PyTorch modules and Pyro probabilistic programs.

def to_pyro_module_(nn_module: torch.nn.Module, prior: callable = None) -> PyroModule:
    """
    Convert PyTorch module to PyroModule in-place.
    
    Parameters:
    - nn_module (Module): PyTorch module to convert
    - prior (callable, optional): Prior distribution generator for parameters
    
    Returns:
    PyroModule: Converted module (same object)
    
    Examples:
    >>> net = torch.nn.Linear(10, 1)
    >>> to_pyro_module_(net)  # converts `net` itself; keep using `net`
    """
    # NOTE(review): upstream pyro.nn.module.to_pyro_module_ appears to return
    # None (the conversion is purely in place) — confirm before relying on the
    # documented return value; using the original reference works either way.

def clear_module_hooks(module: torch.nn.Module):
    """
    Clear all Pyro-related hooks from a PyTorch module.
    
    Parameters:
    - module (Module): Module to clear hooks from
    
    Returns:
    None (documented as acting on ``module`` itself) — TODO confirm
    """

def module_prior(module_name: str, module: torch.nn.Module, 
                 prior_fn: callable) -> torch.nn.Module:
    """
    Apply prior distributions to all parameters in a PyTorch module.
    
    Parameters:
    - module_name (str): Name prefix for Pyro sample sites
    - module (Module): PyTorch module
    - prior_fn (callable): Function prior_fn(name, param) that receives each
      parameter's name and tensor and returns its prior distribution
    
    Returns:
    Module: Module with stochastic parameters
    
    Examples:
    >>> def weight_prior(name, param):
    ...     # One standard-normal prior over the whole parameter tensor
    ...     return dist.Normal(0, 1).expand(param.shape).to_event(param.dim())
    >>>
    >>> net = torch.nn.Linear(10, 1)
    >>> stochastic_net = module_prior("net", net, weight_prior)
    """

class PyroModuleList(torch.nn.ModuleList, PyroModule):
    """
    ModuleList that supports PyroModule functionality.
    
    Enables lists of PyroModules to work correctly with Pyro's
    parameter management and effect handling.
    
    Examples:
    >>> # BayesianLinear: a PyroModule layer such as the one in the Examples section
    >>> layers = PyroModuleList([
    ...     BayesianLinear(10, 20),
    ...     BayesianLinear(20, 1)
    ... ])
    """
    
    # Base order matters: torch.nn.ModuleList precedes PyroModule in the MRO,
    # so ModuleList's methods take priority where both define one.
    def __init__(self, modules=None):
        """
        Parameters:
        - modules (iterable, optional): Iterable of modules to add
        """

Examples

Simple Bayesian Neural Network

import pyro
import pyro.distributions as dist
import torch
import torch.nn.functional as F
from pyro.nn import PyroModule, PyroParam, PyroSample

class BayesianLinear(PyroModule):
    """Fully connected layer whose weight and bias are random variables.

    Both are declared with PyroSample and given independent standard-normal
    priors: the weight as one (out_features, in_features) event, the bias as
    one (out_features,) event.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.in_features, self.out_features = in_features, out_features

        weight_prior = dist.Normal(0., 1.).expand([out_features, in_features]).to_event(2)
        bias_prior = dist.Normal(0., 1.).expand([out_features]).to_event(1)
        self.weight = PyroSample(weight_prior)
        self.bias = PyroSample(bias_prior)

    def forward(self, x):
        # Affine map x @ weight.T + bias with the sampled weight and bias.
        return F.linear(x, self.weight, self.bias)

# Usage in a model
def model(x, y):
    """Bayesian linear regression: y ~ Normal(BayesianLinear(3, 1)(x), 0.1).

    ``x`` has 3 features in its last dimension, as required by
    ``BayesianLinear(3, 1)``. Passing ``y=None`` samples "obs" instead of
    conditioning on it.
    """
    fc = BayesianLinear(3, 1)

    # Forward pass. squeeze(-1) removes only the size-1 output-feature axis,
    # so a single-row batch keeps its batch dimension and lines up with the
    # "data" plate (plain .squeeze() would also drop it).
    mean = fc(x).squeeze(-1)

    # Likelihood: one conditionally independent observation per row of x.
    with pyro.plate("data", x.shape[0]):
        pyro.sample("obs", dist.Normal(mean, 0.1), obs=y)

def guide(x, y):
    """Mean-field Normal guide for the latent ``weight`` and ``bias`` of ``model``.

    ``model`` builds ``BayesianLinear(3, 1)``, so the variational parameters
    below use the matching shapes (1, 3) and (1,). The sample-site names must
    match the model's sites. ``x`` and ``y`` are unused, but SVI calls the
    model and the guide with the same arguments.
    """
    weight_loc = pyro.param("weight_loc", torch.zeros(1, 3))
    weight_scale = pyro.param(
        "weight_scale", torch.full((1, 3), 0.1), constraint=dist.constraints.positive
    )
    bias_loc = pyro.param("bias_loc", torch.zeros(1))
    bias_scale = pyro.param(
        "bias_scale", torch.full((1,), 0.1), constraint=dist.constraints.positive
    )
    # Event dims match the priors declared in BayesianLinear (2 for the
    # weight matrix, 1 for the bias vector).
    pyro.sample("weight", dist.Normal(weight_loc, weight_scale).to_event(2))
    pyro.sample("bias", dist.Normal(bias_loc, bias_scale).to_event(1))

Variational Autoencoder

class VAE(PyroModule):
    """Variational autoencoder with a Gaussian latent code and Bernoulli pixels.

    ``model`` decodes a standard-normal latent ``z`` (size ``z_dim``) into
    per-pixel Bernoulli probabilities. ``guide`` encodes ``x`` into a
    diagonal-Normal distribution over ``z``.
    """

    def __init__(self, input_dim=784, hidden_dim=400, z_dim=20):
        super().__init__()
        # ``model`` needs z_dim to build the prior; without storing it here,
        # ``self.z_dim`` raised AttributeError.
        self.z_dim = z_dim

        # Encoder
        self.encoder_fc1 = torch.nn.Linear(input_dim, hidden_dim)
        self.encoder_mu = torch.nn.Linear(hidden_dim, z_dim)
        self.encoder_sigma = torch.nn.Linear(hidden_dim, z_dim)

        # Decoder
        self.decoder_fc1 = torch.nn.Linear(z_dim, hidden_dim)
        self.decoder_fc2 = torch.nn.Linear(hidden_dim, input_dim)

    def model(self, x):
        # Register only the decoder layers with Pyro. Registering ``self``
        # would store every VAE parameter, encoder included, under this
        # prefix, and again under the guide's prefix.
        pyro.module("decoder_fc1", self.decoder_fc1)
        pyro.module("decoder_fc2", self.decoder_fc2)

        batch_size = x.shape[0]

        with pyro.plate("data", batch_size):
            # Standard-normal prior; new_zeros/new_ones follow x's dtype and device.
            z_loc = x.new_zeros((batch_size, self.z_dim))
            z_scale = x.new_ones((batch_size, self.z_dim))
            z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))

            # Decode to per-pixel probabilities in (0, 1).
            hidden = F.relu(self.decoder_fc1(z))
            mu_img = torch.sigmoid(self.decoder_fc2(hidden))

            # Likelihood
            pyro.sample("obs", dist.Bernoulli(mu_img).to_event(1), obs=x)

    def guide(self, x):
        # Register only the encoder layers with Pyro.
        pyro.module("encoder_fc1", self.encoder_fc1)
        pyro.module("encoder_mu", self.encoder_mu)
        pyro.module("encoder_sigma", self.encoder_sigma)

        batch_size = x.shape[0]

        # Encode; softplus maps the raw scale output to a positive value.
        hidden = F.relu(self.encoder_fc1(x))
        z_mu = self.encoder_mu(hidden)
        z_sigma = F.softplus(self.encoder_sigma(hidden))

        # Variational distribution
        with pyro.plate("data", batch_size):
            pyro.sample("latent", dist.Normal(z_mu, z_sigma).to_event(1))

Neural Network with Uncertainty

class UncertaintyNet(PyroModule):
    """Bayesian linear regression on 10 input features with learned noise.

    ``linear.weight`` and ``linear.bias`` are random variables with
    standard-normal priors. ``sigma`` is a learnable positive observation
    noise scale.
    """

    def __init__(self):
        super().__init__()
        self.linear = PyroModule[torch.nn.Linear](10, 1)
        # Declare the layer's parameters as random variables directly on the
        # PyroModule instead of lifting it with the deprecated
        # pyro.random_module. Event dims cover each full parameter tensor
        # (weight (1, 10), bias (1,)).
        self.linear.weight = PyroSample(dist.Normal(0., 1.).expand([1, 10]).to_event(2))
        self.linear.bias = PyroSample(dist.Normal(0., 1.).expand([1]).to_event(1))

        # Learnable noise parameter
        self.sigma = PyroParam(torch.tensor(1.0),
                               constraint=dist.constraints.positive)

    def forward(self, x, y=None):
        # squeeze(-1) drops only the size-1 output axis, keeping the batch
        # axis even when x has a single row.
        prediction = self.linear(x).squeeze(-1)

        # Always emit the "obs" site. It is observed when y is given
        # (training) and sampled when y is None, so Predictive can return
        # samples["obs"] for new inputs.
        with pyro.plate("data", x.shape[0]):
            pyro.sample("obs", dist.Normal(prediction, self.sigma), obs=y)

        return prediction

# Usage with uncertainty quantification
net = UncertaintyNet()

# Training with SVI
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.optim import Adam

# A no-op guide leaves the model's latent weight/bias sites without a
# variational distribution; an autoguide fits a diagonal-Normal posterior.
guide = AutoDiagonalNormal(net)
# Pass the module itself rather than ``net.forward`` so PyroModule.__call__
# runs and sets up the module's Pyro context.
svi = SVI(net, guide, Adam({"lr": 0.01}), Trace_ELBO())

# train_x, train_y and test_x are placeholders for your data: inputs with 10
# features per row (matching the linear layer) and one target per row.
pyro.clear_param_store()
num_steps = 1000
for _ in range(num_steps):
    svi.step(train_x, train_y)

# Get predictions with uncertainty from the fitted guide (posterior), not the prior
from pyro.infer import Predictive
predictive = Predictive(net, guide=guide, num_samples=100)
samples = predictive(test_x)
mean_pred = samples["obs"].mean(0)
std_pred = samples["obs"].std(0)

Install with Tessl CLI

npx tessl i tessl/pypi-pyro-ppl

docs

core-programming.md

distributions.md

gaussian-processes.md

index.md

inference.md

neural-networks.md

optimization.md

transforms-constraints.md

tile.json