Lightweight probabilistic programming library with NumPy backend for Pyro, powered by JAX for automatic differentiation and JIT compilation.
—
NumPyro provides Pyro-style effect handlers that act as context managers to intercept and modify the execution of probabilistic programs. These handlers enable powerful model manipulation capabilities like conditioning on observed data, substituting values, applying transformations, and controlling inference behavior.
Base classes and utilities for the effect handling system.
class Messenger:
    """
    Common base for every effect handler; usable as a context manager.

    While a handler is active it gets to see the message produced at each
    primitive site and may rewrite it. Conditioning, substitution, masking
    and the other transformations in this module are built on this.
    """

    def __init__(self, fn: Optional[Callable] = None):
        """
        Args:
            fn: Optional callable to wrap. When given, calling the handler
                runs ``fn`` with the handler active.
        """

    def __enter__(self):
        """Enter the handler context."""

    def __exit__(self, exc_type, exc_value, traceback):
        """Leave the handler context."""

    def process_message(self, msg: dict) -> None:
        """
        Inspect or modify the message for one primitive site.

        Args:
            msg: Dictionary describing the site being processed.
        """

    def __call__(self, *args, **kwargs):
        """Run the wrapped ``fn`` with this handler active."""
def default_process_message(msg: dict) -> None:
    """
    Fallback processing applied to a primitive-site message.

    Args:
        msg: Message dictionary for the site.
    """
def apply_stack(msg: dict) -> dict:
    """
    Pass a message through the currently active effect handler stack.

    Args:
        msg: Message dictionary for the site.

    Returns:
        The message after the active handlers have processed it.
    """


# Handlers for recording and replaying model execution.
def trace(fn: Callable) -> Callable:
    """
    Record inputs and outputs at all primitive sites during model execution.

    Calling the returned handler directly runs ``fn`` and returns ``fn``'s
    own return value, *not* the trace. To obtain the trace, call
    ``get_trace(*args, **kwargs)`` on the handler. It runs ``fn`` with those
    arguments and returns the recorded sites keyed by site name.

    Args:
        fn: Function to trace.

    Returns:
        A tracing handler wrapping ``fn``.

    Usage:
        trace_dict = trace(model).get_trace(*args, **kwargs)
        site_names = list(trace_dict)
    """
def replay(fn: Callable, trace: dict) -> Callable:
    """
    Re-run ``fn`` against a trace recorded on an earlier run.

    Args:
        fn: Function whose execution should be replayed.
        trace: Execution trace captured from a previous run.

    Returns:
        A version of ``fn`` that replays using ``trace``.

    Usage:
        replayed_model = replay(model, trace_dict)
        result = replayed_model(*args, **kwargs)
    """
class TraceHandler(Messenger):
    """Handler that records the primitive sites visited while it is active."""

    def __init__(self, fn: Optional[Callable] = None): ...

    def get_trace(self, *args, **kwargs) -> dict:
        """
        Run the wrapped ``fn`` and return the recorded execution trace.

        Args:
            *args: Positional arguments forwarded to the wrapped ``fn``.
            **kwargs: Keyword arguments forwarded to the wrapped ``fn``.

        Returns:
            The recorded trace, keyed by site name. A call with no arguments
            still works for functions that take none.
        """
class ReplayHandler(Messenger):
    """
    Handler that replays execution against a stored trace.

    Args:
        trace: Previously recorded trace to replay.
        fn: Optional callable to wrap.
    """

    def __init__(self, trace: dict, fn: Optional[Callable] = None):
        pass


# Handlers for conditioning models on observed data and substituting values.
def condition(
    fn: Optional[Callable] = None, data: Optional[dict] = None
) -> Callable:
    """
    Condition a probabilistic model on observed data.

    Args:
        fn: Model function to condition. It may be omitted to build a
            reusable handler (e.g. for ``compose``). In that case pass
            ``data`` by keyword; a positional dict would be bound to ``fn``.
        data: Dictionary mapping site names to observed values.

    Returns:
        Conditioned model function.

    Usage:
        conditioned_model = condition(model, {"obs": observed_data})
        result = conditioned_model(*args, **kwargs)

        # Without a model (partial application, e.g. inside ``compose``):
        conditioner = condition(data={"obs": observed_data})
    """
def substitute(
    fn: Optional[Callable] = None, data: Optional[dict] = None
) -> Callable:
    """
    Substitute values at sample sites, bypassing distributions.

    Args:
        fn: Function to modify. It may be omitted to build a reusable handler
            (e.g. for ``compose``). In that case pass ``data`` by keyword;
            a positional dict would be bound to ``fn``.
        data: Dictionary mapping site names to substitute values.

    Returns:
        Function with substituted values.

    Usage:
        substituted_model = substitute(model, {"param1": fixed_value})
        result = substituted_model(*args, **kwargs)

        # Without a model (partial application, e.g. inside ``compose``):
        substituter = substitute(data={"param1": fixed_value})
    """
class ConditionHandler(Messenger):
    """
    Handler that conditions sites on observed data.

    Args:
        data: Mapping from site name to observed value.
        fn: Optional callable to wrap.
    """

    def __init__(self, data: dict, fn: Optional[Callable] = None):
        pass
class SubstituteHandler(Messenger):
    """
    Handler that substitutes fixed values at sample sites.

    Args:
        data: Mapping from site name to substitute value.
        fn: Optional callable to wrap.
    """

    def __init__(self, data: dict, fn: Optional[Callable] = None):
        pass


# Handlers for controlling random number generation.
def seed(
    fn: Optional[Callable] = None, rng_seed: Optional[int] = None
) -> Callable:
    """
    Provide a random seed context for reproducible sampling.

    Args:
        fn: Function to seed. It may be omitted to build a reusable handler,
            e.g. ``seed(rng_seed=42)`` inside ``compose``.
        rng_seed: Random seed value. It must be supplied; the ``None``
            default exists only so that ``fn`` can be omitted.

    Returns:
        Function with seeded random number generation.

    Usage:
        seeded_model = seed(model, rng_seed=42)
        result = seeded_model(*args, **kwargs)
    """
class SeedHandler(Messenger):
    """
    Handler that supplies a random seed context.

    Args:
        rng_seed: Seed value to use.
        fn: Optional callable to wrap.
    """

    def __init__(self, rng_seed: int, fn: Optional[Callable] = None):
        pass


# Handlers for selectively blocking effects or masking computations.
def block(
    fn: Callable,
    hide_fn: Optional[Callable] = None,
    expose_fn: Optional[Callable] = None,
    hide_all: bool = True,
) -> Callable:
    """
    Suppress effects at selected sites, chosen by predicate functions.

    Args:
        fn: Function to modify.
        hide_fn: Predicate over a site message selecting sites to hide.
        expose_fn: Predicate over a site message selecting sites to expose.
        hide_all: Whether sites are hidden by default.

    Returns:
        Function with blocked effects.

    Usage:
        # Block all sample sites except "obs"
        blocked_model = block(model, expose_fn=lambda msg: msg["name"] == "obs")
        result = blocked_model(*args, **kwargs)
    """
def mask(fn: Callable, mask: ArrayLike) -> Callable:
    """
    Mask the log-probability contributions of sample sites.

    Args:
        fn: Function to mask.
        mask: Boolean scalar or boolean array. ``True`` entries are *kept*;
            their log probability contributes as usual. ``False`` entries
            are excluded and contribute nothing. Note the polarity: this
            is a "keep" mask, not a list of elements to drop.

    Returns:
        Function with masked effects.

    Usage:
        # Ignore observations where y is missing (NaN)
        masked_model = mask(model, mask=~jnp.isnan(y))
        result = masked_model(*args, **kwargs)
    """
class BlockHandler(Messenger):
    """
    Handler that blocks effects at selected sites.

    Args:
        hide_fn: Predicate selecting sites to hide.
        expose_fn: Predicate selecting sites to expose.
        hide_all: Whether sites are hidden by default.
        fn: Optional callable to wrap.
    """

    def __init__(
        self,
        hide_fn: Optional[Callable] = None,
        expose_fn: Optional[Callable] = None,
        hide_all: bool = True,
        fn: Optional[Callable] = None,
    ):
        pass
class MaskHandler(Messenger):
    """
    Handler that masks site effects according to a boolean condition.

    Args:
        mask: Boolean mask to apply.
        fn: Optional callable to wrap.
    """

    def __init__(self, mask: ArrayLike, fn: Optional[Callable] = None):
        pass


# Handlers for scaling log probabilities and applying transformations.
def scale(fn: Callable, scale: float) -> Callable:
    """
    Multiply log probabilities by a constant factor.

    Args:
        fn: Function whose log probabilities are scaled.
        scale: Factor applied to each log probability.

    Returns:
        Function with scaled log probabilities.

    Usage:
        scaled_model = scale(model, scale=0.1)  # Tempered model
        result = scaled_model(*args, **kwargs)
    """
def scope(fn: Callable, prefix: str) -> Callable:
    """
    Prepend a scope prefix to every site name used inside ``fn``.

    Args:
        fn: Function to scope.
        prefix: Prefix added to site names.

    Returns:
        Function with scoped site names.

    Usage:
        scoped_model = scope(model, prefix="component1")
        result = scoped_model(*args, **kwargs)
    """
class ScaleHandler(Messenger):
    """
    Handler that scales log probabilities.

    Args:
        scale: Scaling factor.
        fn: Optional callable to wrap.
    """

    def __init__(self, scale: float, fn: Optional[Callable] = None):
        pass
class ScopeHandler(Messenger):
    """
    Handler that adds a scope prefix to site names.

    Args:
        prefix: Prefix to prepend.
        fn: Optional callable to wrap.
    """

    def __init__(self, prefix: str, fn: Optional[Callable] = None):
        pass


# Handlers for manipulating parameters and distributions.
def lift(fn: Callable, prior: dict) -> Callable:
    """
    Turn param sites into sample sites drawn from the given priors.

    Args:
        fn: Function containing the param sites to lift.
        prior: Mapping from parameter name to prior distribution.

    Returns:
        Function whose parameters have become sample sites.

    Usage:
        lifted_model = lift(model, {"weight": dist.Normal(0, 1)})
        result = lifted_model(*args, **kwargs)
    """
def reparam(fn: Callable, config: dict) -> Callable:
    """
    Reparameterize the sites named in ``config``.

    Args:
        fn: Function to reparameterize.
        config: Mapping from site name to reparameterization strategy.

    Returns:
        Function with the reparameterizations applied.

    Usage:
        from numpyro.infer.reparam import LocScaleReparam
        reparamed_model = reparam(model, {"x": LocScaleReparam(centered=0)})
        result = reparamed_model(*args, **kwargs)
    """
class LiftHandler(Messenger):
    """
    Handler that lifts parameters to sample sites.

    Args:
        prior: Mapping from parameter name to prior distribution.
        fn: Optional callable to wrap.
    """

    def __init__(self, prior: dict, fn: Optional[Callable] = None):
        pass
class ReparamHandler(Messenger):
    """
    Handler that applies reparameterizations.

    Args:
        config: Mapping from site name to reparameterization strategy.
        fn: Optional callable to wrap.
    """

    def __init__(self, config: dict, fn: Optional[Callable] = None):
        pass


# Handlers for discrete variable enumeration and marginalization.
def collapse(fn: Callable, sites: Optional[list] = None) -> Callable:
    """
    Marginalize out (collapse) discrete enumeration at the given sites.

    Args:
        fn: Function with enumerated discrete variables.
        sites: Names of the sites to collapse; ``None`` collapses all.

    Returns:
        Function with the discrete variables collapsed.

    Usage:
        collapsed_model = collapse(enumerated_model, sites=["discrete_var"])
        result = collapsed_model(*args, **kwargs)
    """
class CollapseHandler(Messenger):
    """
    Handler that collapses discrete enumeration.

    Args:
        sites: Names of the sites to collapse, or ``None`` for all.
        fn: Optional callable to wrap.
    """

    def __init__(self, sites: Optional[list] = None,
                 fn: Optional[Callable] = None):
        pass


# Handlers for configuring inference behavior.
def infer_config(fn: Callable, config_fn: Callable) -> Callable:
    """
    Attach inference configuration to sample sites.

    Args:
        fn: Function to configure.
        config_fn: Callable that receives a site and returns its inference
            configuration dict.

    Returns:
        Function with the inference configuration applied.

    Usage:
        def config_fn(site):
            if site["name"] == "x":
                return {"is_auxiliary": True}
            return {}

        configured_model = infer_config(model, config_fn)
        result = configured_model(*args, **kwargs)
    """
class InferConfigHandler(Messenger):
    """
    Handler that sets inference configuration on sites.

    Args:
        config_fn: Callable mapping a site to its inference config.
        fn: Optional callable to wrap.
    """

    def __init__(self, config_fn: Callable, fn: Optional[Callable] = None):
        pass


# Handlers for causal modeling and intervention.
def do(fn: Callable, data: dict) -> Callable:
    """
    Perform causal interventions (the do-operator) on the named variables.

    Args:
        fn: Model function to intervene on.
        data: Mapping from variable name to intervention value.

    Returns:
        Function with the interventions applied.

    Usage:
        # Intervene by setting X = 5
        intervened_model = do(causal_model, {"X": 5})
        result = intervened_model(*args, **kwargs)
    """
class DoHandler(Messenger):
    """
    Handler for causal interventions.

    Args:
        data: Mapping from variable name to intervention value.
        fn: Optional callable to wrap.
    """

    def __init__(self, data: dict, fn: Optional[Callable] = None):
        pass


# Utilities for composing and managing multiple handlers.
def compose(*handlers) -> Callable:
    """
    Compose multiple handlers into a single handler.

    Each handler must be created *without* its ``fn`` argument (partially
    applied). ``fn`` is the first positional parameter of every handler
    function, so the remaining arguments such as ``data`` and ``rng_seed``
    must be passed by keyword. ``substitute({...})`` would bind the dict to
    ``fn``.

    Args:
        *handlers: Partially applied handlers to compose.

    Returns:
        Composed handler function. Apply it to a model to get the model
        wrapped in all of ``handlers``.

    Usage:
        composed = compose(
            seed(rng_seed=42),
            substitute(data={"param": value}),
            condition(data={"obs": data}),
        )
        result = composed(model)(*args, **kwargs)
    """
def enable_validation(is_validate: bool = True):
    """
    Globally enable or disable distribution validation.

    This is a plain function: it changes a global setting and returns
    ``None``. It is *not* a context manager, so ``with
    enable_validation(...):`` fails. For a temporary, scoped change use
    NumPyro's ``numpyro.validation_enabled`` context manager instead.

    Args:
        is_validate: Whether to enable validation.

    Usage:
        enable_validation(True)  # affects all subsequent code
        result = model(*args, **kwargs)

        with numpyro.validation_enabled(True):  # scoped alternative
            result = model(*args, **kwargs)
    """
class DynamicHandler(Messenger):
    """
    Handler whose behavior is decided at runtime by ``handler_fn``.

    Args:
        handler_fn: Callable supplying the dynamic behavior.
        fn: Optional callable to wrap.
    """

    def __init__(self, handler_fn: Callable, fn: Optional[Callable] = None):
        pass
def get_mask() -> Optional[ArrayLike]:
    """
    Look up the mask currently in effect on the handler stack.

    Returns:
        The active mask, or ``None``.
    """
def get_dependencies() -> dict:
    """
    Collect dependency information from the current trace.

    Returns:
        Dictionary of dependency information.
    """


# Advanced patterns for specialized use cases.
def escape(fn: Callable, escape_fn: Callable) -> Callable:
    """
    Let selected operations escape the current handler context.

    Args:
        fn: Function to modify.
        escape_fn: Predicate deciding when to escape.

    Returns:
        Function that can escape handler effects.
    """
def plate_messenger(
    name: str,
    size: int,
    subsample_size: Optional[int] = None,
    dim: Optional[int] = None,
) -> Messenger:
    """
    Build a plate messenger that marks conditional independence.

    Args:
        name: Name of the plate.
        size: Size of the plate.
        subsample_size: Subsample size, if subsampling.
        dim: Dimension used for broadcasting.

    Returns:
        Plate messenger for conditional independence.
    """
class CustomHandler(Messenger):
    """
    Starting point for writing your own effect handler.

    Subclass it and override ``process_message()``:

        class MyHandler(CustomHandler):
            def process_message(self, msg):
                if msg["type"] == "sample":
                    # Custom logic for sample sites
                    pass
                elif msg["type"] == "param":
                    # Custom logic for param sites
                    pass
    """

    def process_message(self, msg: dict) -> None:
        """Handle one site message; the template itself does nothing."""


# Conditioning on observed data
import numpyro
import numpyro.distributions as dist
from numpyro.handlers import condition, substitute, seed, trace


def model():
    x = numpyro.sample("x", dist.Normal(0, 1))
    y = numpyro.sample("y", dist.Normal(x, 1))
    return y


# Condition on observed y
observed_data = {"y": 2.0}
conditioned_model = condition(model, observed_data)

# Substitute a fixed value for x
substituted_model = substitute(model, {"x": 1.5})

# Set random seed for reproducibility
seeded_model = seed(model, rng_seed=42)

# Trace execution to see all sites. Calling the traced model would return
# the model's own value (y); get_trace() runs it and returns the trace.
traced_model = trace(seeded_model)
trace_dict = traced_model.get_trace()

# Compose multiple handlers. Handlers built without a model must receive
# their arguments by keyword; a positional dict would be taken as ``fn``.
from numpyro.handlers import compose

composed_model = compose(
    seed(rng_seed=42),
    substitute(data={"x": 1.0}),
    condition(data={"y": 2.0}),
)(model)
result = composed_model()


from typing import Optional, Union, Callable, Dict, Any
from jax import Array
import jax.numpy as jnp
# Values accepted where an array is expected: JAX arrays and Python
# numeric scalars.
ArrayLike = Union[
    Array,
    jnp.ndarray,
    float,
    int,
]

# Functional form of a handler: takes a callable, returns a wrapped one.
HandlerFunction = Callable[[Callable], Callable]
class Message:
    """
    Schema of the message dictionary passed between effect handlers.

    Each annotation below names one field of a site message.
    """

    # Site name.
    name: str
    # Message type: "sample", "param", "deterministic", etc.
    type: str
    # Distribution or function at the site.
    fn: Any
    # Positional and keyword arguments passed to ``fn``.
    args: tuple
    kwargs: dict
    # Sampled or computed value.
    value: Any
    # Whether the site is observed.
    is_observed: bool
    # Inference configuration.
    infer: dict
    # Probability scale factor.
    scale: Optional[float]
    # Optional mask for the site.
    mask: Optional[ArrayLike]
    # NOTE(review): the fields below were undocumented; the descriptions
    # are inferred from their names only and should be confirmed.
    cond_indep_stack: list  # presumably the enclosing plate frames
    done: bool  # presumably: processing of this message is finished
    stop: bool  # presumably: stop propagating to further handlers
    continuation: Optional[Callable]  # presumably a post-processing callback
class Site:
    """One primitive site of a model: its identity, inputs and value."""

    name: str  # site name
    type: str  # site kind, e.g. "sample" or "param"
    fn: Any  # distribution or function at the site
    args: tuple  # positional arguments to ``fn``
    kwargs: dict  # keyword arguments to ``fn``
    value: Any  # value produced at the site
class Trace(dict):
    """
    Record of one execution: a ``dict`` mapping site names to sites.
    """

    def log_prob_sum(self) -> float:
        """Sum of log probabilities over the recorded sites."""

    def copy(self) -> 'Trace':
        """Copy of this trace."""

    def nodes(self) -> dict:
        """Mapping of the recorded sites."""


# Install with Tessl CLI
npx tessl i tessl/pypi-numpyro