CtrlK
BlogDocsLog inGet started
Tessl Logo

tessl/pypi-pyro-ppl

A flexible, scalable deep probabilistic programming library built on PyTorch for universal probabilistic modeling and inference

Pending
Overview
Eval results
Files

docs/core-programming.md

Core Probabilistic Programming

Core functions and constructs that form the foundation of Pyro's probabilistic programming language, enabling the creation of probabilistic models through composable primitives.

Capabilities

Sample Statements

The fundamental stochastic function for declaring random variables and observed data in probabilistic programs.

def sample(
    name: str,
    fn: TorchDistributionMixin,
    *args,
    obs: Optional[torch.Tensor] = None,
    obs_mask: Optional[torch.BoolTensor] = None,
    infer: Optional[InferDict] = None,
    **kwargs
) -> torch.Tensor:
    """Declare a named stochastic site in a probabilistic program.

    Each call creates one sample site: a latent random variable when
    ``obs`` is omitted, or a conditioning site that scores observed data
    when ``obs`` is provided. This is the core primitive from which
    models and guides are built.

    Args:
        name: Unique site name within the current execution context.
        fn: Distribution to sample from (or to score ``obs`` under).
        obs: Observed value to condition on. When given, the site
            records the data instead of drawing a new sample.
        obs_mask: Boolean mask selecting which entries of ``obs`` are
            actually observed; useful for partially missing data.
        infer: Per-site configuration dictionary consumed by inference
            algorithms.

    Returns:
        A sample drawn from ``fn``, or ``obs`` when it is provided.

    Examples:
        >>> z = pyro.sample("z", dist.Normal(0, 1))               # latent
        >>> pyro.sample("obs", dist.Normal(mu, sigma), obs=data)  # observed
        >>> pyro.sample("x", dist.Normal(0, 1), infer={"is_auxiliary": True})
    """

Parameter Management

Functions for declaring and managing learnable parameters that persist across calls to the model.

def param(
    name: str,
    init_tensor: Union[torch.Tensor, Callable[[], torch.Tensor], None] = None,
    constraint: constraints.Constraint = constraints.real,
    event_dim: Optional[int] = None,
) -> torch.Tensor:
    """Register or look up a learnable parameter in the global store.

    A parameter persists across model invocations and is tracked for
    gradient-based optimization. The first call with a given name stores
    the initial value; later calls retrieve the stored tensor.

    Args:
        name: Parameter name, unique within the parameter store.
        init_tensor: Initial value (a tensor, or a zero-argument callable
            producing one). If None, the parameter must already exist.
        constraint: Constraint on the parameter's values; defaults to
            unconstrained reals.
        event_dim: Number of rightmost dimensions belonging to the event
            shape.

    Returns:
        The parameter tensor, with gradient tracking enabled.

    Examples:
        >>> mu = pyro.param("mu", torch.tensor(0.0))                 # scalar
        >>> theta = pyro.param("theta", torch.ones(5),
        ...                    constraint=constraints.positive)      # constrained
        >>> W = pyro.param("W", torch.randn(10, 5))                  # matrix
    """

def clear_param_store():
    """Remove every parameter from the global parameter store.

    Call this between experiments or test cases so that parameters
    registered by one model run cannot leak into the next.
    """

def get_param_store():
    """Return the global parameter store singleton.

    Returns:
        ParamStore: The store holding every named Pyro parameter.
    """

Independence Declarations

Context managers for declaring conditional independence and enabling efficient vectorized computation.

class plate(PlateMessenger):
    """Context manager declaring conditional independence along one dim.

    Sample statements executed inside the ``with`` block are treated as
    conditionally independent along the plate dimension, enabling
    vectorized computation and (via ``subsample_size``) minibatch
    training with automatic rescaling of log probabilities.

    NOTE(review): in the original listing the docstring sat at class
    level directly after the ``__init__`` signature, leaving ``__init__``
    with no body — a SyntaxError. It is now the class docstring.

    Args:
        name: Unique name for the plate.
        size: Total size of the independent dimension.
        subsample_size: Size of the minibatch subsample. When provided,
            enables minibatch training with automatic scaling of log
            probabilities.
        subsample: Explicit tensor of subsample indices, overriding
            automatic subsampling.
        dim: Tensor dimension to use for broadcasting. If None, the
            rightmost available dimension is used.
        use_cuda: Legacy CUDA flag — presumably superseded by ``device``;
            confirm against the Pyro reference.
        device: Device on which subsample indices are created —
            TODO confirm against the Pyro reference.

    Examples:
        >>> # Basic independence
        >>> with pyro.plate("data", 100):
        ...     pyro.sample("obs", dist.Normal(mu, sigma), obs=data)
        >>>
        >>> # Minibatch training
        >>> with pyro.plate("data", 10000, subsample_size=32):
        ...     pyro.sample("obs", dist.Normal(mu, sigma), obs=data_batch)
        >>>
        >>> # Nested plates
        >>> with pyro.plate("batch", N):
        ...     with pyro.plate("features", D):
        ...         pyro.sample("z", dist.Normal(0, 1))
    """

    def __init__(
        self,
        name: str,
        size: Optional[int] = None,
        subsample_size: Optional[int] = None,
        subsample: Optional[torch.Tensor] = None,
        dim: Optional[int] = None,
        use_cuda: Optional[bool] = None,
        device: Optional[str] = None,
    ) -> None:
        """Configure the plate; see the class docstring for parameters."""

def plate_stack(prefix: str, sizes: Sequence[int], rightmost_dim: int = -1) -> Iterator[None]:
    """Create a stack of nested plates for multi-dimensional independence.

    Fix: the original docstring documented a ``name`` parameter, but the
    actual parameter is ``prefix``; the generated plates derive their
    names from this prefix.

    Args:
        prefix: Base name from which each nested plate's name is derived.
        sizes: Size of each nested plate, ordered outermost first.
        rightmost_dim: Rightmost tensor dimension used by the stack
            (presumably must be negative — confirm against the Pyro
            reference).

    Returns:
        A context manager entering all of the nested plates.

    Examples:
        >>> with pyro.plate_stack("plates", [N, D, K]):
        ...     pyro.sample("z", dist.Normal(0, 1))
    """

Model Composition

Functions for composing and manipulating probabilistic programs.

def factor(
    name: str, 
    log_factor: torch.Tensor, 
    *, 
    has_rsample: Optional[bool] = None
) -> None:
    """Add an arbitrary log-probability term to the model's joint density.

    Lets a model include custom log-probability contributions that do
    not correspond to any standard distribution (likelihood tweaks,
    penalties, soft constraints, ...).

    Args:
        name: Name of the factor site.
        log_factor: Log-probability term added to the joint log density.
        has_rsample: Whether the factor arose from a fully
            reparametrized distribution (required in guides).

    Examples:
        >>> # Custom likelihood term
        >>> ll = -0.5 * torch.sum((data - mu) ** 2) / sigma ** 2
        >>> pyro.factor("custom_likelihood", ll)
        >>>
        >>> # L2 penalty on parameters
        >>> pyro.factor("l2_penalty", -0.01 * torch.sum(params ** 2))
    """

def deterministic(name: str, value: torch.Tensor, event_dim: Optional[int] = None) -> torch.Tensor:
    """Create a deterministic sample site for tracking intermediate computations.

    Records ``value`` in the execution trace so downstream tooling can
    inspect derived quantities; the value itself passes through
    unchanged.

    Fix: the original docstring documented an ``event_dim`` parameter
    that was missing from the signature. It is now accepted with a
    backward-compatible default, matching the documented contract.

    Args:
        name: Name for the deterministic site.
        value: Deterministic value to record.
        event_dim: Number of rightmost event dimensions (when None,
            presumably the whole value is treated as the event —
            TODO confirm against the Pyro reference).

    Returns:
        The input ``value``, unchanged (pass-through).

    Examples:
        >>> z = pyro.sample("z", dist.Normal(0, 1))
        >>> z_squared = pyro.deterministic("z_squared", z ** 2)
    """

def barrier(data: torch.Tensor) -> torch.Tensor:
    """Create a barrier for sequential execution in models.

    Useful for enforcing execution order in complex models.

    Fix: the original docstring documented a nonexistent ``name``
    parameter; the only parameter is ``data``.

    Args:
        data: Tensor passed through the barrier site.

    Returns:
        torch.Tensor: presumably ``data``, passed through unchanged —
        confirm against the Pyro reference.
    """

PyTorch Module Integration

Functions for integrating PyTorch modules into probabilistic programs.

def module(name: str, nn_module, update_module_params: bool = False):
    """Register a PyTorch module inside a probabilistic program.

    Args:
        name: Name under which the module is registered.
        nn_module: The ``torch.nn.Module`` to integrate.
        update_module_params: If True, the module's parameters are
            registered with Pyro's parameter store.

    Returns:
        torch.nn.Module: The same module that was passed in.

    Examples:
        >>> net = torch.nn.Linear(10, 1)
        >>> nn = pyro.module("neural_net", net, update_module_params=True)
        >>> output = nn(input_tensor)
    """

def random_module(name: str, nn_module, prior, *args, **kwargs):
    """Lift a PyTorch module into a stochastic (Bayesian) neural network.

    Places the supplied priors over the module's parameters, producing a
    module whose weights are sampled rather than fixed.

    Args:
        name: Name for the random module.
        nn_module: ``torch.nn.Module`` used as the template.
        prior: Callable returning prior distributions for the
            template's parameters.

    Returns:
        torch.nn.Module: A module with stochastic parameters.

    Examples:
        >>> def prior(name, shape):
        ...     return dist.Normal(0, 1).expand(shape).to_event(len(shape))
        >>>
        >>> template = torch.nn.Linear(10, 1)
        >>> bayesian_nn = pyro.random_module("bnn", template, prior)
    """

Subsampling and Utilities

Utilities for data subsampling and model visualization.

def subsample(data: torch.Tensor, event_dim: int) -> torch.Tensor:
    """Mark a tensor for automatic subsampling inside plates.

    Args:
        data: Tensor to be subsampled.
        event_dim: Number of rightmost event dimensions (never
            subsampled).

    Returns:
        torch.Tensor: The subsampled data when executed inside a
        subsampling plate.
    """

def render_model(model, *args, **kwargs):
    """Render a graphical representation of a probabilistic model.

    Args:
        model: Model function to visualize.
        *args: Positional arguments forwarded to the model.
        **kwargs: Keyword arguments forwarded to the model.

    Returns:
        A visualization object describing the model structure.
    """

Global State Management

Functions for managing global Pyro state and settings.

def get_param_store() -> ParamStoreDict:
    """Return the global parameter store holding all Pyro parameters.

    Returns:
        ParamStoreDict: The global parameter store dictionary.

    Examples:
        >>> store = pyro.get_param_store()
        >>> print(list(store.keys()))  # every registered parameter name
    """

def clear_param_store() -> None:
    """Delete all parameters from the global parameter store.

    Handy for starting fresh between experiments or tests.

    Examples:
        >>> pyro.clear_param_store()  # drop every registered parameter
    """

def enable_validation(is_validate: bool = True):
    """Globally enable or disable runtime validation of distributions/shapes.

    Args:
        is_validate: True to turn validation on, False to turn it off.

    Examples:
        >>> pyro.enable_validation(True)   # debugging
        >>> pyro.enable_validation(False)  # production speed
    """

def validation_enabled(is_validate: bool = True) -> Iterator[None]:
    """Context manager that temporarily applies a validation setting.

    Fix: the original docstring described a boolean query ("check if
    validation is currently enabled" returning ``bool``), which
    contradicts the signature — this takes the desired flag and yields
    nothing, i.e. it is a context manager. Presumably the previous
    setting is restored on exit — confirm against the Pyro reference.

    Args:
        is_validate: Validation setting in effect inside the ``with``
            block.

    Yields:
        None.

    Examples:
        >>> with pyro.validation_enabled(False):
        ...     run_model()  # no runtime shape/distribution checks here
    """

def set_rng_seed(rng_seed: int):
    """Seed all relevant random number generators for reproducibility.

    Seeds the Python ``random`` module, NumPy, and PyTorch generators.

    Args:
        rng_seed: Seed value used for every generator.

    Examples:
        >>> pyro.set_rng_seed(42)  # reproducible experiments
    """

Examples

Basic Model Definition

import pyro
import pyro.distributions as dist
import torch

def coin_flip_model(data):
    """Beta-Bernoulli model: latent coin bias, independent flips."""
    # Beta(1, 1) prior on the coin's bias (uniform over [0, 1]).
    coin_bias = pyro.sample("bias", dist.Beta(1.0, 1.0))

    # Each flip is conditionally independent given the bias.
    num_flips = len(data)
    with pyro.plate("data", num_flips):
        pyro.sample("obs", dist.Bernoulli(coin_bias), obs=data)

# Run the model once on a small dataset of five flips.
data = torch.tensor([1.0, 0.0, 1.0, 1.0, 0.0])
coin_flip_model(data)

Hierarchical Model

def hierarchical_model(group_data):
    """Hierarchical model with group-level parameters.

    Assumes group_data is a sequence of per-group observation tensors,
    possibly of different lengths — TODO confirm against the caller.
    """
    # Global hyperpriors
    mu_alpha = pyro.sample("mu_alpha", dist.Normal(0, 10))
    sigma_alpha = pyro.sample("sigma_alpha", dist.HalfNormal(5))
    
    # Group-specific parameters
    with pyro.plate("groups", len(group_data)):
        # Inside the plate, "alpha" is drawn once per group; presumably
        # shape (len(group_data),) along the plate dim — confirm.
        alpha = pyro.sample("alpha", dist.Normal(mu_alpha, sigma_alpha))
        
        # Observations within each group
        # Each group gets its own uniquely named inner plate and obs site
        # because the groups may differ in length.
        for i, group in enumerate(group_data):
            with pyro.plate(f"group_{i}_data", len(group)):
                # alpha[i] selects group i's mean; unit observation noise.
                pyro.sample(f"obs_{i}", dist.Normal(alpha[i], 1), obs=group)

Minibatch Training

def minibatch_model(data_loader, num_data=None):
    """Model with minibatch training support.

    Args:
        data_loader: Iterable of observation batches (tensors).
        num_data: Total number of observations in the full dataset.
            When provided, ``pyro.plate`` rescales each batch's log
            probability by ``num_data / len(batch)`` — the point of
            minibatch training. When omitted, the original behavior is
            kept (plate size equals the batch size, so no rescaling).
    """
    # Global parameters
    mu = pyro.param("mu", torch.tensor(0.0))
    sigma = pyro.param("sigma", torch.tensor(1.0), constraint=dist.constraints.positive)

    # Process each minibatch under a subsampling plate.
    for batch in data_loader:
        # Bug in the original example: the plate size was len(batch)
        # with subsample_size=len(batch), so log probabilities were
        # never scaled up to the full dataset. Pass num_data to get
        # correct minibatch scaling.
        size = num_data if num_data is not None else len(batch)
        with pyro.plate("data", size, subsample_size=len(batch)):
            pyro.sample("obs", dist.Normal(mu, sigma), obs=batch)

Install with Tessl CLI

npx tessl i tessl/pypi-pyro-ppl

docs

core-programming.md

distributions.md

gaussian-processes.md

index.md

inference.md

neural-networks.md

optimization.md

transforms-constraints.md

tile.json