CtrlK
BlogDocsLog inGet started
Tessl Logo

tessl/pypi-numpyro

Lightweight probabilistic programming library with NumPy backend for Pyro, powered by JAX for automatic differentiation and JIT compilation.

Pending
Overview
Eval results
Files

primitives.mddocs/

Primitives

NumPyro's primitive functions provide the core building blocks for probabilistic models. These functions enable sampling from distributions, defining parameters, handling conditional independence, and marking deterministic computations. All primitives integrate with the effect handler system and support automatic differentiation through JAX.

Capabilities

Core Sampling Primitives

The fundamental primitives for probabilistic programming.

def sample(name: str, fn: Distribution, obs: Optional[ArrayLike] = None, 
          rng_key: Optional[Array] = None, sample_shape: tuple = (), 
          infer: Optional[dict] = None, obs_mask: Optional[ArrayLike] = None) -> ArrayLike:
    """
    Draw a value from ``fn`` at a named site, or condition that site on ``obs``.

    Args:
        name: Name of the sample site (must be unique within the model).
        fn: Probability distribution to sample from.
        obs: Observed value. When given, the site is conditioned on this
            value instead of being sampled.
        rng_key: Optional PRNG key used to draw the sample.
            NOTE(review): upstream NumPyro does not appear to create a key by
            itself; when this is None the key is expected to be supplied by
            an enclosing ``numpyro.handlers.seed`` handler -- confirm.
        sample_shape: Extra leading shape of independent draws to take.
        infer: Dictionary of inference hints for this site
            (e.g. ``{"is_auxiliary": True}``).
        obs_mask: Boolean mask for partially observed data.
            NOTE(review): upstream documents this as broadcastable with
            ``fn.batch_shape``; entries where the mask is True are
            conditioned on ``obs`` and the rest are imputed by sampling --
            confirm against the targeted version.

    Returns:
        The sampled value, or the observed value when ``obs`` is provided.

    Usage:
        # Sample from prior
        x = numpyro.sample("x", dist.Normal(0, 1))

        # Condition on observed data
        y = numpyro.sample("y", dist.Normal(x, 0.5), obs=observed_y)

        # Sample multiple values
        batch_samples = numpyro.sample("batch", dist.Normal(0, 1), sample_shape=(10,))

        # Configure inference behavior
        z = numpyro.sample("z", dist.Normal(0, 1), infer={"is_auxiliary": True})
    """

def param(name: str, init_value: Optional[Union[ArrayLike, Callable]] = None, 
         constraint: Constraint = constraints.real, event_dim: Optional[int] = None,
         **kwargs) -> Optional[ArrayLike]:
    """
    Declare a named, optimizable parameter site.

    Args:
        name: Parameter name (must be unique).
        init_value: Initial value, or a callable that takes a PRNG key and
            returns the initial value (see the second usage example).
        constraint: Support the parameter is restricted to
            (e.g. ``constraints.positive``); defaults to ``constraints.real``.
        event_dim: Number of rightmost dimensions treated as event shape.
        **kwargs: Additional arguments (e.g., for initialization functions).

    Returns:
        The current parameter value.
        NOTE(review): the original page states this is None during the
        initial model trace; upstream NumPyro appears to return
        ``init_value`` unchanged when no handler substitutes a value --
        confirm which applies.

    Usage:
        # Simple parameter with constraint
        sigma = numpyro.param("sigma", 1.0, constraint=constraints.positive)

        # Parameter with initialization function
        weights = numpyro.param("weights", 
                               lambda key: random.normal(key, (10, 5)),
                               constraint=constraints.real)

        # Simplex-constrained parameter
        probs = numpyro.param("probs", jnp.ones(3) / 3, constraint=constraints.simplex)
    """

Deterministic Sites

Primitives for marking deterministic computations and adding log probability factors.

def deterministic(name: str, value: ArrayLike) -> ArrayLike:
    """
    Record a deterministically computed value as a named site in the trace.

    The value is not random; naming it only makes it visible to tooling that
    reads the execution trace (for example, so it is returned with samples).

    Args:
        name: Name of the deterministic site.
        value: The already-computed value to record.

    Returns:
        The input value (unchanged).

    Usage:
        x = numpyro.sample("x", dist.Normal(0, 1))
        y = numpyro.sample("y", dist.Normal(0, 1))

        # Mark sum as deterministic for tracking
        sum_xy = numpyro.deterministic("sum", x + y)

        # Can be used for derived quantities
        mean_xy = numpyro.deterministic("mean", (x + y) / 2)
    """

def factor(name: str, log_factor: ArrayLike) -> None:
    """
    Add an arbitrary log-probability term to the model's joint density.

    Args:
        name: Name of the factor site.
        log_factor: Log-probability value (possibly an array) to add to the
            joint density. The value is in log space, so a penalty is a
            negative number.

    Returns:
        None. The call is made only for its effect on the joint density.

    Usage:
        # Add log-likelihood term directly
        numpyro.factor("custom_loglik", -0.5 * jnp.sum((y - mu)**2) / sigma**2)

        # Add constraint violation penalty
        numpyro.factor("penalty", -1e6 * jnp.where(x < 0, 1.0, 0.0))

        # Add custom prior term
        numpyro.factor("custom_prior", dist.Gamma(2, 1).log_prob(sigma))
    """

Conditional Independence

Primitives for declaring conditional independence and for subsampling data.

class plate:
    """
    Context manager that declares conditionally independent variables.

    Sample sites created inside the context are broadcast along the plate's
    dimension, as shown by the shapes in the usage examples below.

    Args:
        name: Plate name (must be unique).
        size: Size of the independence dimension.
        subsample_size: Number of elements to draw per step when subsampling
            (optional).
        dim: Negative dimension index used for broadcasting (optional).
        subsample: Indices for subsampling (optional).
            NOTE(review): upstream ``numpyro.plate`` appears to accept only
            ``(name, size, subsample_size, dim)`` -- confirm this argument
            exists in the targeted version.

    Usage:
        # Basic conditional independence
        with numpyro.plate("data", 100):
            x = numpyro.sample("x", dist.Normal(0, 1))  # Shape: (100,)

        # Subsampling for large datasets
        with numpyro.plate("data", 10000, subsample_size=100) as idx:
            # idx contains the subsample indices
            x = numpyro.sample("x", dist.Normal(0, 1))  # Shape: (100,)

        # Nested plates for multidimensional independence
        with numpyro.plate("batch", 50, dim=-2):
            with numpyro.plate("features", 10, dim=-1):
                weights = numpyro.sample("w", dist.Normal(0, 1))  # Shape: (50, 10)
    """
    # Stub: stores the plate configuration; implementation not shown here.
    def __init__(self, name: str, size: int, subsample_size: Optional[int] = None, 
                dim: Optional[int] = None, subsample: Optional[ArrayLike] = None): ...
    
    def __enter__(self) -> Optional[Array]:
        """Enter the plate context and return the subsample indices.

        NOTE(review): upstream appears to return an index array even when no
        subsampling is requested (all indices), not None -- confirm before
        relying on an ``is None`` check.
        """
    
    # Stub: leaves the plate context; implementation not shown here.
    def __exit__(self, exc_type, exc_value, traceback): ...

def plate_stack(prefix: str, sizes: list[int], rightmost_dim: int = -1) -> list:
    """
    Create a stack of nested plates for multidimensional conditional independence.

    Args:
        prefix: Prefix for the generated plate names.
        sizes: List of sizes, one per plate dimension.
        rightmost_dim: Dimension index assigned to the last entry of
            ``sizes``; earlier entries are placed to its left.

    Returns:
        The stacked plate contexts.
        NOTE(review): upstream NumPyro implements ``plate_stack`` as a single
        context manager that enters every plate at once, not as a
        subscriptable list -- confirm the ``-> list`` annotation against the
        targeted version.

    Usage:
        # Create 3D tensor of independent samples.
        # With the default rightmost_dim=-1 the sizes map to dims -3, -2, -1.
        with numpyro.plate_stack("data", [20, 30, 40]):
            x = numpyro.sample("x", dist.Normal(0, 1))  # Shape: (20, 30, 40)
    """

def subsample(data: ArrayLike, event_dim: int) -> ArrayLike:
    """
    Select the mini-batch of ``data`` that matches the enclosing plates.

    The result is already the subsampled batch; it must not be indexed again
    with the indices yielded by the plate.

    Args:
        data: Full data array to subsample.
        event_dim: Number of rightmost dimensions that are event dimensions;
            dimensions to the left of those are treated as batch dimensions.

    Returns:
        Subsampled data array.

    Usage:
        # Subsample based on active plate
        with numpyro.plate("data", len(full_data), subsample_size=100):
            batch_data = numpyro.subsample(full_data, event_dim=0)
            x = numpyro.sample("x", dist.Normal(batch_data, 1))
    """

Advanced Primitives

Specialized primitives for advanced modeling scenarios.

def mutable(name: str, init_value: Optional[ArrayLike] = None) -> ArrayLike:
    """
    Declare a named site holding a value that may change during execution.

    Args:
        name: Name of the mutable site.
        init_value: Initial value for the mutable storage.

    Returns:
        Current value of the mutable storage.

    Usage:
        # NOTE(review): upstream examples update the returned container in
        # place rather than calling mutable() a second time with the same
        # name -- confirm against the targeted version.
        state = numpyro.mutable("counter", {"count": 0})
        state["count"] = state["count"] + 1  # Update the counter in place
    """

def module(name: str, nn: tuple, input_shape: Optional[tuple] = None) -> Callable:
    """
    Register a neural network given as an ``(init_fn, apply_fn)`` pair.

    Args:
        name: Module name.
        nn: Tuple of ``(init_fn, apply_fn)`` describing the network.
        input_shape: Shape of the input, used to initialize the network.

    Returns:
        Function that applies the network to its inputs.

    Usage:
        # NOTE(review): upstream exposes this helper as
        # numpyro.contrib.module.module for stax-style (init_fn, apply_fn)
        # pairs, and a separate haiku_module for Haiku networks -- confirm
        # the import path for the targeted version.
        from jax.example_libraries import stax
        from numpyro.contrib.module import module

        net = stax.serial(stax.Dense(64), stax.Relu, stax.Dense(1))
        module_fn = module("mlp", net, input_shape=(batch_size, 10))

        # Use in model
        x = numpyro.sample("x", dist.Normal(0, 1).expand((batch_size, 10)))
        y_pred = module_fn(x)
    """

def prng_key() -> Optional[Array]:
    """
    Return a PRNG key drawn from the current execution context.

    Returns:
        A random key, or None if no key is available in the current context.
        NOTE(review): upstream ties availability to an enclosing
        ``numpyro.handlers.seed`` handler -- confirm.

    Usage:
        # Get key for manual random operations
        key = numpyro.prng_key()
        if key is not None:
            noise = random.normal(key, shape=(10,))
    """

def get_mask() -> Optional[ArrayLike]:
    """
    Return the mask currently imposed by the handler stack.

    Callers can use the result to skip work that only matters for masked-in
    entries, as in the usage example.

    Returns:
        Current mask, or None if no mask is active.

    Usage:
        # Check if masking is active
        current_mask = numpyro.get_mask()
        if current_mask is not None:
            # Handle masked computation
            pass
    """

Internal Utilities

Internal functions used by the primitive system (typically not used directly).

def _masked_observe(name: str, fn: Distribution, obs: ArrayLike, 
                   obs_mask: ArrayLike, **kwargs) -> ArrayLike:
    """
    Handle a sample site whose observations are only partially available.

    Internal helper; model code normally reaches it through the ``obs_mask``
    argument of ``sample`` rather than calling it directly.

    Args:
        name: Site name.
        fn: Distribution at the site.
        obs: Observed values.
        obs_mask: Boolean mask marking which entries of ``obs`` are valid.
        **kwargs: Additional arguments.

    Returns:
        Masked observed value.
        NOTE(review): upstream appears to return ``obs`` where the mask is
        True and imputed samples elsewhere -- confirm.
    """

def _subsample_fn(size: int, subsample_size: int, 
                 rng_key: Optional[Array] = None) -> Array:
    """
    Generate the indices used when a plate subsamples its dimension.

    Internal helper behind ``plate(..., subsample_size=...)``.

    Args:
        size: Full dataset size.
        subsample_size: Number of indices to produce.
        rng_key: Random key for sampling.
            NOTE(review): although annotated Optional, upstream appears to
            require a key here -- confirm.

    Returns:
        Array of subsample indices.
    """

def _inspect() -> dict:
    """
    Inspect the current Pyro stack (experimental).

    Internal helper; the exact keys of the returned dictionary are not
    specified on this page.

    Returns:
        Dictionary containing stack information.
    """

class CondIndepStackFrame:
    """
    Record describing one frame of the conditional independence stack.

    Described as a named tuple, although this stub declares it as a plain
    annotated class.
    NOTE(review): upstream NumPyro's named tuple appears to carry only
    ``name``, ``dim`` and ``size`` -- confirm whether ``counter`` applies.

    Attributes:
        name: Frame name
        dim: Broadcasting dimension
        size: Frame size
        counter: Frame counter for tracking
    """
    name: str  # frame name
    dim: int  # broadcasting dimension
    size: int  # frame size
    counter: int  # frame counter for tracking

Validation and Inspection

Utilities for validating models and inspecting execution.

def validate_model(model: Callable, *model_args, **model_kwargs) -> dict:
    """
    Validate model structure and return trace information.

    NOTE(review): this function could not be matched to NumPyro's documented
    public API; verify that it exists in the targeted version before relying
    on it.

    Args:
        model: Model function to validate.
        *model_args: Positional arguments passed to ``model``.
        **model_kwargs: Keyword arguments passed to ``model``.

    Returns:
        Dictionary containing validation results and trace information; the
        usage example reads a ``'sites'`` entry from it.

    Usage:
        def my_model():
            x = numpyro.sample("x", dist.Normal(0, 1))
            y = numpyro.sample("y", dist.Normal(x, 1))

        validation_info = numpyro.validate_model(my_model)
        print(f"Model has {len(validation_info['sites'])} sites")
    """

def inspect_fn(fn: Callable, *args, **kwargs) -> dict:
    """
    Inspect function execution and return detailed information.

    NOTE(review): this function could not be matched to NumPyro's documented
    public API; verify that it exists in the targeted version before relying
    on it.

    Args:
        fn: Function to inspect.
        *args: Positional arguments passed to ``fn``.
        **kwargs: Keyword arguments passed to ``fn``.

    Returns:
        Dictionary with execution information including sites and dependencies.
    """

Usage Examples

import numpyro
import numpyro.distributions as dist
import jax.numpy as jnp
from jax import random

# Basic linear regression model
def linear_regression(X, y=None):
    """Linear regression of ``y`` on ``X`` with Normal priors.

    Args:
        X: Predictor values; ``X.shape[0]`` is used as the number of data
            points. (Assumes a 1-D array so ``beta * X`` lines up with the
            "data" plate -- TODO confirm for other shapes.)
        y: Observed responses, or None to sample them from the model.
    """
    # Prior parameters
    alpha = numpyro.sample("alpha", dist.Normal(0, 10))
    beta = numpyro.sample("beta", dist.Normal(0, 10))
    # `constraints` is not imported by this example, so reach it through the
    # already-imported `dist` alias instead of a bare name.
    sigma = numpyro.param("sigma", 1.0, constraint=dist.constraints.positive)

    # Linear prediction
    mu = alpha + beta * X

    # Record the prediction in the trace; the return value is not needed.
    numpyro.deterministic("prediction", mu)

    # Likelihood with conditional independence over data points
    with numpyro.plate("data", X.shape[0]):
        numpyro.sample("y", dist.Normal(mu, sigma), obs=y)

# Hierarchical model with nested plates
def hierarchical_model(group_idx, y=None):
    """Two-level model: a global mean, per-group means, per-observation noise.

    Args:
        group_idx: Group index of every observation; also used to index the
            group-level means.
        y: Observed values, or None to sample them from the model.
    """
    # Sizes of the two plates.
    num_obs = len(group_idx) if y is None else len(y)
    num_groups = len(jnp.unique(group_idx))

    # Global hyperparameters
    mu_global = numpyro.sample("mu_global", dist.Normal(0, 1))
    sigma_global = numpyro.sample("sigma_global", dist.Exponential(1))

    # Group-level parameters
    group_prior = dist.Normal(mu_global, sigma_global)
    with numpyro.plate("groups", num_groups):
        mu_group = numpyro.sample("mu_group", group_prior)

    # Look up each observation's group mean, then score the observations.
    obs_mean = mu_group[group_idx]
    with numpyro.plate("obs", num_obs):
        numpyro.sample("y", dist.Normal(obs_mean, 1), obs=y)

# Model with subsampling for large datasets
def large_dataset_model(X, y=None):
    """Linear model that scores a random mini-batch of a large dataset.

    Args:
        X: Design matrix; unpacked as ``(n_data, n_features)``.
        y: Observed responses with one entry per row of ``X``, or None to
            sample them from the model.
    """
    n_data, n_features = X.shape

    # Parameters
    weights = numpyro.sample("weights", dist.Normal(0, 1).expand((n_features,)))

    # Subsample for computational efficiency
    with numpyro.plate("data", n_data, subsample_size=min(1000, n_data)):
        # numpyro.subsample already returns the mini-batch chosen by the
        # enclosing plate. Indexing that result again with the plate's
        # full-dataset indices would select the wrong rows.
        X_batch = numpyro.subsample(X, event_dim=1)
        y_batch = numpyro.subsample(y, event_dim=0) if y is not None else None

        mu = X_batch @ weights
        numpyro.sample("y", dist.Normal(mu, 0.1), obs=y_batch)

# Custom factor for non-standard likelihoods
def custom_likelihood_model(data):
    """Score ``data`` with a hand-written log-likelihood added via ``factor``.

    Args:
        data: Array combined element-wise with ``log(theta)`` and
            ``log(1 - theta)``. (Presumably 0/1 outcomes -- verify against
            the caller.)
    """
    theta = numpyro.sample("theta", dist.Beta(1, 1))

    # Build the per-element terms separately, then sum them into one factor.
    success_term = data * jnp.log(theta)
    failure_term = (1 - data) * jnp.log(1 - theta)
    numpyro.factor("custom_lik", jnp.sum(success_term + failure_term))

Types

from typing import Any, Callable, Dict, Literal, NamedTuple, Optional, Tuple, Union

import jax.numpy as jnp
from jax import Array
from numpyro.distributions import Distribution, constraints

# Anything accepted where an array value is expected: a JAX array or a
# Python scalar.
ArrayLike = Union[Array, jnp.ndarray, float, int]
# Alias for the constraint base class from numpyro.distributions.constraints.
Constraint = constraints.Constraint
# Either a concrete initial value or a callable that maps an Array (the
# PRNG key, judging by the `param` usage example) to one.
InitFunction = Union[ArrayLike, Callable[[Array], ArrayLike]]

class CondIndepStackFrame(NamedTuple):
    """Frame in the conditional independence stack.

    Declared as a ``NamedTuple`` so it can be constructed with its fields,
    unpacked, and compared by value, matching its description on this page
    as a named tuple.
    """
    name: str  # frame name
    dim: int  # broadcasting dimension
    size: int  # frame size
    # Defaulted so three-field construction (name, dim, size) also works.
    counter: int = 0

class PlateMessenger:
    """Messenger for plate context management.

    The fields mirror the constructor arguments of ``plate``.
    NOTE(review): no class of this name could be matched to upstream
    NumPyro; treat this as a descriptive schema and verify.
    """
    name: str  # plate name
    size: int  # size of the independence dimension
    subsample_size: Optional[int]  # subsample size, if subsampling
    dim: Optional[int]  # broadcasting dimension, if specified
    subsample: Optional[Array]  # subsample indices, if provided
    
# Site types for different primitive operations.
# A closed set of string tags is spelled with Literal; Union over bare
# strings would create forward references to types that do not exist.
SiteType = Literal["sample", "param", "deterministic", "factor", "mutable"]

class SiteInfo:
    """Information about a primitive site.

    NOTE(review): no class of this name could be matched to upstream
    NumPyro; treat this as a descriptive schema and verify.
    """
    name: str  # site name
    type: SiteType  # which primitive created the site
    fn: Optional[Distribution]  # distribution at the site, when there is one
    args: tuple  # positional arguments recorded for the site
    kwargs: dict  # keyword arguments recorded for the site
    value: Any  # value recorded at the site
    is_observed: bool  # whether the site is observed
    infer: dict  # inference hints for the site
    scale: Optional[float]  # scale recorded for the site, if any
    
class ValidationResult:
    """Result from model validation.

    NOTE(review): no class of this name could be matched to upstream
    NumPyro, and ``validate_model`` on this page is annotated as returning
    a plain dict -- verify how the two relate.
    """
    sites: dict  # sites found in the model
    dependencies: dict  # dependency information between sites
    plate_stack: list  # plates encountered
    is_valid: bool  # overall validation outcome
    warnings: list  # non-fatal findings
    errors: list  # fatal findings

Install with Tessl CLI

npx tessl i tessl/pypi-numpyro

docs

diagnostics.md

distributions.md

handlers.md

index.md

inference.md

optimization.md

primitives.md

utilities.md

tile.json