CtrlK
Blog · Docs · Log in · Get started
Tessl Logo

tessl/pypi-numpyro

Lightweight probabilistic programming library with NumPy backend for Pyro, powered by JAX for automatic differentiation and JIT compilation.

Pending
Overview
Eval results
Files

docs/handlers.md

Handlers

NumPyro provides Pyro-style effect handlers that act as context managers to intercept and modify the execution of probabilistic programs. These handlers enable powerful model manipulation capabilities like conditioning on observed data, substituting values, applying transformations, and controlling inference behavior.

Capabilities

Core Handler Infrastructure

Base classes and utilities for the effect handling system.

class Messenger:
    """
    Common base for every effect handler; usable as a context manager.

    While active, a handler sees the message emitted at each primitive site
    and may alter it. Conditioning, substitution, masking and the other
    transformations described here are all built on this mechanism.
    """
    def __init__(self, fn: Optional[Callable] = None):
        """Create the handler, optionally wrapping ``fn`` so it can be called."""

    def __enter__(self):
        """Enter the handler's context."""

    def __exit__(self, exc_type, exc_value, traceback):
        """Leave the handler's context."""

    def process_message(self, msg: dict) -> None:
        """
        Handle the message produced at a single primitive site.

        Args:
            msg: Dictionary describing the site being processed
        """

    def __call__(self, *args, **kwargs):
        """Invoke the wrapped function while this handler is in effect."""

def default_process_message(msg: dict) -> None:
    """Apply the fallback handling to a primitive-site message."""

def apply_stack(msg: dict) -> dict:
    """Pass ``msg`` through every handler currently on the effect stack."""

Tracing and Replay

Handlers for recording and replaying model execution.

def trace(fn: Callable) -> Callable:
    """
    Record inputs and outputs at all primitive sites during model execution.

    Calling the returned handler simply runs ``fn`` (with the handler active)
    and returns ``fn``'s own return value. The recorded sites are obtained
    separately through ``get_trace`` (see ``TraceHandler.get_trace``).

    Args:
        fn: Function to trace

    Returns:
        A trace handler wrapping ``fn``

    Usage:
        traced_model = trace(model)
        trace_dict = traced_model.get_trace(*args, **kwargs)

        # Models that draw random samples usually also need a seed:
        trace_dict = trace(seed(model, rng_seed=0)).get_trace(*args, **kwargs)
    """

def replay(fn: Callable, trace: dict) -> Callable:
    """
    Re-run a function using a trace captured on an earlier execution.

    Args:
        fn: The function to re-run
        trace: Trace recorded during a previous run

    Returns:
        A wrapper around ``fn`` that replays ``trace``

    Usage:
        replayed_model = replay(model, trace_dict)
        result = replayed_model(*args, **kwargs)
    """

class TraceHandler(Messenger):
    """Messenger that records an execution trace of the sites it sees."""

    def __init__(self, fn: Optional[Callable] = None):
        """Create the handler, optionally wrapping ``fn``."""

    def get_trace(self) -> dict:
        """Return the trace recorded so far."""

class ReplayHandler(Messenger):
    """Messenger that replays values taken from a stored trace."""

    def __init__(self, trace: dict, fn: Optional[Callable] = None):
        """Store ``trace`` for replay, optionally wrapping ``fn``."""

Conditioning and Substitution

Handlers for conditioning models on observed data and substituting values.

def condition(fn: Callable, data: dict) -> Callable:
    """
    Tie selected sites of a probabilistic model to observed values.

    Args:
        fn: The model to condition
        data: Observed values, keyed by site name

    Returns:
        The model, conditioned on ``data``

    Usage:
        conditioned_model = condition(model, {"obs": observed_data})
        result = conditioned_model(*args, **kwargs)
    """

def substitute(fn: Callable, data: dict) -> Callable:
    """
    Replace the values at sample sites, skipping their distributions.

    Args:
        fn: The function to alter
        data: Replacement values, keyed by site name

    Returns:
        ``fn`` with the given values substituted

    Usage:
        substituted_model = substitute(model, {"param1": fixed_value})
        result = substituted_model(*args, **kwargs)
    """

class ConditionHandler(Messenger):
    """Messenger that conditions sites on observed data."""

    def __init__(self, data: dict, fn: Optional[Callable] = None):
        """Keep the observed ``data`` and optionally wrap ``fn``."""

class SubstituteHandler(Messenger):
    """Messenger that swaps in fixed values at sample sites."""

    def __init__(self, data: dict, fn: Optional[Callable] = None):
        """Keep the replacement ``data`` and optionally wrap ``fn``."""

Random Seed Control

Handlers for controlling random number generation.

def seed(fn: Callable, rng_seed: int) -> Callable:
    """
    Run a function under a fixed random seed so its sampling is reproducible.

    Args:
        fn: The function to seed
        rng_seed: Seed value for random number generation

    Returns:
        ``fn`` with seeded random number generation

    Usage:
        seeded_model = seed(model, rng_seed=42)
        result = seeded_model(*args, **kwargs)
    """

class SeedHandler(Messenger):
    """Messenger that supplies a random-seed context."""

    def __init__(self, rng_seed: int, fn: Optional[Callable] = None):
        """Keep ``rng_seed`` and optionally wrap ``fn``."""

Blocking and Masking

Handlers for selectively blocking effects or masking computations.

def block(fn: Callable, hide_fn: Optional[Callable] = None,
          expose_fn: Optional[Callable] = None, hide_all: bool = True) -> Callable:
    """
    Block effects at the sites selected by the filter predicates.

    Args:
        fn: The function whose sites are filtered
        hide_fn: Predicate selecting the sites to hide
        expose_fn: Predicate selecting the sites to expose
        hide_all: Whether sites are hidden by default

    Returns:
        ``fn`` with the selected effects blocked

    Usage:
        # Hide every sample site other than "obs"
        blocked_model = block(model, expose_fn=lambda msg: msg["name"] == "obs")
        result = blocked_model(*args, **kwargs)
    """

def mask(fn: Callable, mask: ArrayLike) -> Callable:
    """
    Mask the log probability contributions of sample sites elementwise.

    Args:
        fn: Function to mask
        mask: Boolean scalar or boolean-valued array. Where ``mask`` is
            ``True`` the log probability is kept; where it is ``False`` it
            is excluded (same polarity as upstream ``numpyro.handlers.mask``).

    Returns:
        Function with masked log probabilities

    Usage:
        # Keep only the observed entries in the likelihood
        masked_model = mask(model, mask=is_observed_array)
        result = masked_model(*args, **kwargs)
    """

class BlockHandler(Messenger):
    """Messenger that blocks effects at selected sites."""

    def __init__(self, hide_fn: Optional[Callable] = None,
                 expose_fn: Optional[Callable] = None, hide_all: bool = True,
                 fn: Optional[Callable] = None):
        """Keep the hide/expose predicates and default; optionally wrap ``fn``."""

class MaskHandler(Messenger):
    """
    Messenger that masks log probabilities of sample sites elementwise.

    ``True`` entries of ``mask`` keep the corresponding log probability and
    ``False`` entries exclude it (same polarity as upstream
    ``numpyro.handlers.mask``).
    """
    def __init__(self, mask: ArrayLike, fn: Optional[Callable] = None):
        """Keep the boolean ``mask`` and optionally wrap ``fn``."""

Scaling and Transformation

Handlers for scaling log probabilities and applying transformations.

def scale(fn: Callable, scale: float) -> Callable:
    """
    Multiply log probabilities by a constant factor.

    Args:
        fn: The function to scale
        scale: Factor applied to every log probability

    Returns:
        ``fn`` with its log probabilities scaled

    Usage:
        scaled_model = scale(model, scale=0.1)  # tempered model
        result = scaled_model(*args, **kwargs)
    """

def scope(fn: Callable, prefix: str) -> Callable:
    """
    Prefix the name of every site created inside ``fn``.

    Args:
        fn: The function to scope
        prefix: String prepended to site names

    Returns:
        ``fn`` with scoped site names

    Usage:
        scoped_model = scope(model, prefix="component1")
        result = scoped_model(*args, **kwargs)
    """

class ScaleHandler(Messenger):
    """Messenger that rescales log probabilities."""

    def __init__(self, scale: float, fn: Optional[Callable] = None):
        """Keep the ``scale`` factor and optionally wrap ``fn``."""

class ScopeHandler(Messenger):
    """Messenger that prepends a scope prefix to site names."""

    def __init__(self, prefix: str, fn: Optional[Callable] = None):
        """Keep ``prefix`` and optionally wrap ``fn``."""

Parameter and Distribution Manipulation

Handlers for manipulating parameters and distributions.

def lift(fn: Callable, prior: dict) -> Callable:
    """
    Turn param sites into sample sites drawn from the given priors.

    Args:
        fn: Function whose param sites are lifted
        prior: Prior distribution for each parameter, keyed by name

    Returns:
        ``fn`` with its parameters converted to sample sites

    Usage:
        lifted_model = lift(model, {"weight": dist.Normal(0, 1)})
        result = lifted_model(*args, **kwargs)
    """

def reparam(fn: Callable, config: dict) -> Callable:
    """
    Reparameterize the listed sites.

    Args:
        fn: The function to reparameterize
        config: Reparameterization strategy for each site, keyed by name

    Returns:
        ``fn`` with the reparameterizations applied

    Usage:
        from numpyro.infer.reparam import LocScaleReparam
        reparamed_model = reparam(model, {"x": LocScaleReparam(centered=0)})
        result = reparamed_model(*args, **kwargs)
    """

class LiftHandler(Messenger):
    """Messenger that lifts param sites to sample sites."""

    def __init__(self, prior: dict, fn: Optional[Callable] = None):
        """Keep the ``prior`` mapping and optionally wrap ``fn``."""

class ReparamHandler(Messenger):
    """Messenger that applies reparameterizations."""

    def __init__(self, config: dict, fn: Optional[Callable] = None):
        """Keep the per-site ``config`` and optionally wrap ``fn``."""

Enumeration and Collapse

Handlers for discrete variable enumeration and marginalization.

def collapse(fn: Callable, sites: Optional[list] = None) -> Callable:
    """
    Marginalize out discrete enumeration at the chosen sites.

    Args:
        fn: Function containing enumerated discrete variables
        sites: Names of the sites to collapse; ``None`` means every site

    Returns:
        ``fn`` with those discrete variables collapsed

    Usage:
        collapsed_model = collapse(enumerated_model, sites=["discrete_var"])
        result = collapsed_model(*args, **kwargs)
    """

class CollapseHandler(Messenger):
    """Messenger that collapses discrete enumeration."""

    def __init__(self, sites: Optional[list] = None, fn: Optional[Callable] = None):
        """Keep the ``sites`` to collapse and optionally wrap ``fn``."""

Inference Configuration

Handlers for configuring inference behavior.

def infer_config(fn: Callable, config_fn: Callable) -> Callable:
    """
    Attach inference configuration to sample sites.

    Args:
        fn: The function to configure
        config_fn: Maps a site to the inference config to apply to it

    Returns:
        ``fn`` with the inference configuration applied

    Usage:
        def config_fn(site):
            # Only site "x" is marked auxiliary; all others get no config.
            return {"is_auxiliary": True} if site["name"] == "x" else {}

        configured_model = infer_config(model, config_fn)
        result = configured_model(*args, **kwargs)
    """

class InferConfigHandler(Messenger):
    """Messenger that sets per-site inference configuration."""

    def __init__(self, config_fn: Callable, fn: Optional[Callable] = None):
        """Keep ``config_fn`` and optionally wrap ``fn``."""

Causal Intervention

Handlers for causal modeling and intervention.

def do(fn: Callable, data: dict) -> Callable:
    """
    Perform causal interventions (the do-operator) on chosen variables.

    Args:
        fn: The model to intervene on
        data: Intervention values, keyed by variable name

    Returns:
        ``fn`` with the interventions applied

    Usage:
        # Force X to the value 5
        intervened_model = do(causal_model, {"X": 5})
        result = intervened_model(*args, **kwargs)
    """

class DoHandler(Messenger):
    """Messenger that applies causal interventions."""

    def __init__(self, data: dict, fn: Optional[Callable] = None):
        """Keep the intervention ``data`` and optionally wrap ``fn``."""

Handler Composition and Utilities

Utilities for composing and managing multiple handlers.

def compose(*handlers) -> Callable:
    """
    Compose multiple handlers into a single handler.

    Each element of ``handlers`` is a function that takes a callable (e.g. a
    model) and returns a wrapped callable. Since ``seed``, ``substitute`` and
    ``condition`` take the wrapped function as their first positional
    argument, bind their remaining arguments by keyword (e.g. with a lambda)
    instead of passing a dict positionally, which would bind it to ``fn``.

    Args:
        *handlers: Handler functions to compose

    Returns:
        Composed handler function

    Usage:
        composed = compose(
            lambda f: seed(f, rng_seed=42),
            lambda f: substitute(f, data={"param": value}),
            lambda f: condition(f, data={"obs": data}),
        )
        result = composed(model)(*args, **kwargs)

    NOTE(review): whether the first handler ends up outermost or innermost
    is not specified here - confirm. Upstream ``numpyro.handlers`` appears
    not to export ``compose``; nesting the calls directly, e.g.
    ``seed(substitute(condition(model, data=...), data=...), rng_seed=42)``,
    is the portable alternative.
    """

def enable_validation(is_validate: bool = True):
    """
    Enable or disable distribution validation.

    Args:
        is_validate: Whether to enable validation

    Usage:
        enable_validation(True)
        result = model(*args, **kwargs)

    NOTE(review): this spec originally presented the function as a context
    manager (``with enable_validation(True): ...``). Upstream NumPyro's
    ``numpyro.enable_validation`` appears to be a plain function that sets
    a global flag. The scoped variant is ``numpyro.validation_enabled``
    (``with numpyro.validation_enabled(): ...``). Confirm which behavior is
    intended here.
    """

class DynamicHandler(Messenger):
    """Messenger whose behavior depends on conditions evaluated at run time."""

    def __init__(self, handler_fn: Callable, fn: Optional[Callable] = None):
        """Keep ``handler_fn`` (presumably the runtime logic - confirm) and optionally wrap ``fn``."""

def get_mask() -> Optional[ArrayLike]:
    """Return the mask currently in effect on the handler stack, if any."""

def get_dependencies() -> dict:
    """Return dependency information derived from the current trace."""

Advanced Handler Patterns

Advanced patterns for specialized use cases.

def escape(fn: Callable, escape_fn: Callable) -> Callable:
    """
    Let selected operations bypass the current handler context.

    Args:
        fn: The function to alter
        escape_fn: Predicate deciding when to escape

    Returns:
        A version of ``fn`` able to escape handler effects
    """

def plate_messenger(name: str, size: int, subsample_size: Optional[int] = None,
                    dim: Optional[int] = None) -> Messenger:
    """
    Build a plate messenger that declares conditional independence.

    Args:
        name: Name of the plate
        size: Full size of the plate
        subsample_size: Size of the subsample, if subsampling
        dim: Dimension used for broadcasting

    Returns:
        A plate messenger expressing conditional independence
    """

class CustomHandler(Messenger):
    """
    Starting point for writing your own effect handler.

    Subclass it and override ``process_message``, branching on the
    message's ``type``:

    class MyHandler(CustomHandler):
        def process_message(self, msg):
            if msg["type"] == "sample":
                pass  # handle sample sites here
            elif msg["type"] == "param":
                pass  # handle param sites here
    """
    def process_message(self, msg: dict) -> None:
        """Hook invoked for each site message; override in subclasses."""

Usage Examples

# Conditioning on observed data
import numpyro
import numpyro.distributions as dist
from numpyro.handlers import condition, substitute, seed, trace

def model():
    x = numpyro.sample("x", dist.Normal(0, 1))
    y = numpyro.sample("y", dist.Normal(x, 1))
    return y

# Condition on observed y
observed_data = {"y": 2.0}
conditioned_model = condition(model, observed_data)

# Substitute a fixed value for x
substituted_model = substitute(model, {"x": 1.5})

# Set random seed for reproducibility
seeded_model = seed(model, rng_seed=42)

# Trace execution to see all sites. Calling a trace handler returns the
# model's own return value; the recorded sites come from get_trace().
traced_model = trace(seeded_model)
trace_dict = traced_model.get_trace()

# Compose multiple handlers by nesting them: the innermost call wraps the
# model first. Data is passed by keyword so it cannot be mistaken for the
# wrapped function (the first positional parameter of each handler).
composed_model = seed(
    substitute(condition(model, data={"y": 2.0}), data={"x": 1.0}),
    rng_seed=42,
)

result = composed_model()

Types

from typing import Optional, Union, Callable, Dict, Any
from jax import Array
import jax.numpy as jnp

# Numeric input accepted where this API expects array data: JAX arrays or
# plain Python scalars. NOTE(review): in recent JAX releases jnp.ndarray is
# presumably the same type as jax.Array - confirm and drop the duplicate.
ArrayLike = Union[Array, jnp.ndarray, float, int]
# A handler in functional form: takes a callable (e.g. a model) and returns
# a wrapped callable.
HandlerFunction = Callable[[Callable], Callable]

class Message:
    """
    Message dictionary structure for effect handlers.

    Each primitive site produces a message that the active handlers inspect
    and may modify. Fields:

    - name: Site name
    - type: Message type ("sample", "param", "deterministic", etc.)
    - fn: Distribution or function at the site
    - args: Positional arguments to ``fn``
    - kwargs: Keyword arguments to ``fn``
    - value: Sampled or computed value
    - is_observed: Whether the site is observed
    - infer: Inference configuration
    - scale: Probability scale factor, or None
    - mask: Optional mask for the site's log probability, or None
    - cond_indep_stack: Conditional-independence frames enclosing the site
      (presumably the active plates - confirm)
    - done: Bookkeeping flag; presumably marks a message whose value has
      already been produced - confirm
    - stop: Bookkeeping flag; presumably stops the message from reaching
      further handlers - confirm
    - continuation: Optional callback; presumably run once the message has
      been processed - confirm
    """
    name: str
    type: str
    fn: Any
    args: tuple
    kwargs: dict
    value: Any
    is_observed: bool
    infer: dict
    scale: Optional[float]
    mask: Optional[ArrayLike]
    cond_indep_stack: list
    done: bool
    stop: bool
    continuation: Optional[Callable]

class Site:
    """Description of one primitive site visited while running a model."""
    name: str    # identifier of the site
    type: str    # kind of site, e.g. "sample" or "param"
    fn: Any      # distribution or callable attached to the site
    args: tuple  # positional arguments passed to fn
    kwargs: dict # keyword arguments passed to fn
    value: Any   # value produced at the site

class Trace(dict):
    """
    Record of every primitive site visited during one execution.

    Maps each site name to the Site describing it.
    """
    def log_prob_sum(self) -> float:
        """Sum of log probabilities across the recorded sites."""

    def copy(self) -> 'Trace':
        """Return a copy of this trace."""

    def nodes(self) -> dict:
        """Return the recorded sites."""

Install with Tessl CLI

npx tessl i tessl/pypi-numpyro

docs

diagnostics.md

distributions.md

handlers.md

index.md

inference.md

optimization.md

primitives.md

utilities.md

tile.json