NumPyro: a lightweight probabilistic programming library providing a NumPy backend for Pyro, powered by JAX for automatic differentiation and JIT compilation.

NumPyro's primitive functions provide the core building blocks for probabilistic models. These functions enable sampling from distributions, defining parameters, handling conditional independence, and marking deterministic computations. All primitives integrate with the effect handler system and support automatic differentiation through JAX.

The fundamental primitives for probabilistic programming are described below.
def sample(name: str, fn: Distribution, obs: Optional[ArrayLike] = None,
rng_key: Optional[Array] = None, sample_shape: tuple = (),
infer: Optional[dict] = None, obs_mask: Optional[ArrayLike] = None) -> ArrayLike:
"""
Sample a value from a distribution or condition on observed data.
Args:
name: Name of the sample site (must be unique within model)
fn: Probability distribution to sample from
obs: Observed value to condition on (optional)
rng_key: Random key for sampling (optional, auto-generated if None)
sample_shape: Shape of samples to draw (for multiple samples)
infer: Dictionary of inference hints and configuration
obs_mask: Boolean mask for partially observed data
Returns:
Sampled value or observed value (if obs is provided)
Usage:
# Sample from prior
x = numpyro.sample("x", dist.Normal(0, 1))
# Condition on observed data
y = numpyro.sample("y", dist.Normal(x, 0.5), obs=observed_y)
# Sample multiple values
batch_samples = numpyro.sample("batch", dist.Normal(0, 1), sample_shape=(10,))
# Configure inference behavior
z = numpyro.sample("z", dist.Normal(0, 1), infer={"is_auxiliary": True})
"""
def param(name: str, init_value: Optional[Union[ArrayLike, Callable]] = None,
constraint: Constraint = constraints.real, event_dim: Optional[int] = None,
**kwargs) -> Optional[ArrayLike]:
"""
Declare an optimizable parameter in the model.
Args:
name: Parameter name (must be unique)
init_value: Initial value or initialization function
constraint: Parameter constraint (e.g., constraints.positive)
event_dim: Number of rightmost dimensions treated as event shape
**kwargs: Additional arguments (e.g., for initialization functions)
Returns:
Parameter value (None during initial model trace)
Usage:
# Simple parameter with constraint
sigma = numpyro.param("sigma", 1.0, constraint=constraints.positive)
# Parameter with initialization function
weights = numpyro.param("weights",
lambda key: random.normal(key, (10, 5)),
constraint=constraints.real)
# Simplex-constrained parameter
probs = numpyro.param("probs", jnp.ones(3) / 3, constraint=constraints.simplex)
"""Primitives for marking deterministic computations and adding log probability factors.
def deterministic(name: str, value: ArrayLike) -> ArrayLike:
"""
Mark a deterministic computation site for tracking in traces.
Args:
name: Name of the deterministic site
value: Computed deterministic value
Returns:
The input value (unchanged)
Usage:
x = numpyro.sample("x", dist.Normal(0, 1))
y = numpyro.sample("y", dist.Normal(0, 1))
# Mark sum as deterministic for tracking
sum_xy = numpyro.deterministic("sum", x + y)
# Can be used for derived quantities
mean_xy = numpyro.deterministic("mean", (x + y) / 2)
"""
def factor(name: str, log_factor: ArrayLike) -> None:
"""
Add a log probability factor to the model's joint density.
Args:
name: Name of the factor site
log_factor: Log probability value to add to joint density
Usage:
# Add log-likelihood term directly
numpyro.factor("custom_loglik", -0.5 * jnp.sum((y - mu)**2) / sigma**2)
# Add constraint violation penalty
numpyro.factor("penalty", -1e6 * jnp.where(x < 0, 1.0, 0.0))
# Add custom prior term
numpyro.factor("custom_prior", dist.Gamma(2, 1).log_prob(sigma))
"""Primitives for handling conditional independence and subsetting.
class plate:
"""
Context manager for conditionally independent variables with automatic broadcasting.
Args:
name: Plate name (must be unique)
size: Size of the independence dimension
subsample_size: Size of subsample (for subsampling, optional)
dim: Dimension for broadcasting (negative, optional)
subsample: Indices for subsampling (optional)
Usage:
# Basic conditional independence
with numpyro.plate("data", 100):
x = numpyro.sample("x", dist.Normal(0, 1)) # Shape: (100,)
# Subsampling for large datasets
with numpyro.plate("data", 10000, subsample_size=100) as idx:
# idx contains the subsample indices
x = numpyro.sample("x", dist.Normal(0, 1)) # Shape: (100,)
# Nested plates for multidimensional independence
with numpyro.plate("batch", 50, dim=-2):
with numpyro.plate("features", 10, dim=-1):
weights = numpyro.sample("w", dist.Normal(0, 1)) # Shape: (50, 10)
"""
def __init__(self, name: str, size: int, subsample_size: Optional[int] = None,
dim: Optional[int] = None, subsample: Optional[ArrayLike] = None): ...
def __enter__(self) -> Optional[Array]:
"""Enter plate context, returning subsample indices if subsampling."""
def __exit__(self, exc_type, exc_value, traceback): ...
def plate_stack(prefix: str, sizes: list[int], rightmost_dim: int = -1) -> list:
"""
Create a stack of nested plates for multidimensional conditional independence.
Args:
prefix: Prefix for plate names
sizes: List of sizes for each dimension
rightmost_dim: Rightmost dimension index
Returns:
List of plate contexts
Usage:
# Create 3D tensor of independent samples
plates = numpyro.plate_stack("data", [20, 30, 40], rightmost_dim=-3)
with plates[0]:
with plates[1]:
with plates[2]:
x = numpyro.sample("x", dist.Normal(0, 1)) # Shape: (20, 30, 40)
"""
def subsample(data: ArrayLike, event_dim: int) -> ArrayLike:
"""
Subsample data based on active plates in the context.
Args:
data: Data tensor to subsample
event_dim: Number of rightmost dimensions that are event dimensions
Returns:
Subsampled data tensor
Usage:
# Subsample based on active plate
with numpyro.plate("data", len(full_data), subsample_size=100):
batch_data = numpyro.subsample(full_data, event_dim=0)
x = numpyro.sample("x", dist.Normal(batch_data, 1))
"""Specialized primitives for advanced modeling scenarios.
def mutable(name: str, init_value: Optional[ArrayLike] = None) -> ArrayLike:
"""
Create mutable storage that persists across function calls.
Args:
name: Name of the mutable site
init_value: Initial value for the mutable storage
Returns:
Current value of mutable storage
Usage:
# Counter that increments each call
count = numpyro.mutable("counter", 0)
numpyro.mutable("counter", count + 1) # Update the counter
"""
def module(name: str, nn: tuple, input_shape: Optional[tuple] = None) -> Callable:
"""
Register neural network modules for use with JAX transformations.
Args:
name: Module name
nn: Tuple of (init_fn, apply_fn) for neural network
input_shape: Input shape for module initialization
Returns:
Module function that can be called with inputs
Usage:
# Haiku neural network
import haiku as hk
def net_fn(x):
return hk.nets.MLP([64, 32, 1])(x)
net = hk.transform(net_fn)
module_fn = numpyro.module("mlp", net, input_shape=(10,))
# Use in model
x = numpyro.sample("x", dist.Normal(0, 1).expand((batch_size, 10)))
y_pred = module_fn(x)
"""
def prng_key() -> Optional[Array]:
"""
Get the current PRNG key from the execution context.
Returns:
Current random key or None if not available
Usage:
# Get key for manual random operations
key = numpyro.prng_key()
if key is not None:
noise = random.normal(key, shape=(10,))
"""
def get_mask() -> Optional[ArrayLike]:
"""
Get the current mask from the handler stack.
Returns:
Current mask array or None if no mask is active
Usage:
# Check if masking is active
current_mask = numpyro.get_mask()
if current_mask is not None:
# Handle masked computation
pass
"""Internal functions used by the primitive system (typically not used directly).
def _masked_observe(name: str, fn: Distribution, obs: ArrayLike,
obs_mask: ArrayLike, **kwargs) -> ArrayLike:
"""
Handle masked observations in sample sites.
Args:
name: Site name
fn: Distribution
obs: Observed values
obs_mask: Boolean mask for valid observations
**kwargs: Additional arguments
Returns:
Masked observed value
"""
def _subsample_fn(size: int, subsample_size: int,
rng_key: Optional[Array] = None) -> Array:
"""
Generate subsample indices for plate subsampling.
Args:
size: Full dataset size
subsample_size: Size of subsample
rng_key: Random key for sampling
Returns:
Array of subsample indices
"""
def _inspect() -> dict:
"""
Inspect the current Pyro stack (experimental).
Returns:
Dictionary containing stack information
"""
class CondIndepStackFrame:
"""
Named tuple representing a conditional independence stack frame.
Attributes:
name: Frame name
dim: Broadcasting dimension
size: Frame size
counter: Frame counter for tracking
"""
name: str
dim: int
size: int
counter: int

Utilities for validating models and inspecting execution.
def validate_model(model: Callable, *model_args, **model_kwargs) -> dict:
"""
Validate model structure and return trace information.
Args:
model: Model function to validate
*model_args: Arguments to pass to model
**model_kwargs: Keyword arguments to pass to model
Returns:
Dictionary containing validation results and trace information
Usage:
def my_model():
x = numpyro.sample("x", dist.Normal(0, 1))
y = numpyro.sample("y", dist.Normal(x, 1))
validation_info = numpyro.validate_model(my_model)
print(f"Model has {len(validation_info['sites'])} sites")
"""
def inspect_fn(fn: Callable, *args, **kwargs) -> dict:
"""
Inspect function execution and return detailed information.
Args:
fn: Function to inspect
*args: Arguments to pass to function
**kwargs: Keyword arguments to pass to function
Returns:
Dictionary with execution information including sites and dependencies
"""import numpyro
import numpyro.distributions as dist
import jax.numpy as jnp
from jax import random
# Basic linear regression model
def linear_regression(X, y=None):
    """Bayesian linear regression with Normal priors on intercept and slope.

    Args:
        X: 1-D array of predictor values.
        y: Optional array of observed responses; pass None to sample from
           the prior / generate predictions.
    """
    # Priors on intercept and slope
    alpha = numpyro.sample("alpha", dist.Normal(0, 10))
    beta = numpyro.sample("beta", dist.Normal(0, 10))
    # `constraints` is not imported by name in this example's import block,
    # so reference it through numpyro.distributions instead.
    sigma = numpyro.param("sigma", 1.0, constraint=dist.constraints.positive)
    # Linear prediction
    mu = alpha + beta * X
    # Mark prediction as a deterministic site so it appears in traces
    numpyro.deterministic("prediction", mu)
    # Likelihood with conditional independence over data points
    with numpyro.plate("data", X.shape[0]):
        numpyro.sample("y", dist.Normal(mu, sigma), obs=y)
# Hierarchical model with nested plates
def hierarchical_model(group_idx, y=None):
    """Partial-pooling model: per-group means drawn from shared global hyperpriors.

    Args:
        group_idx: Integer array mapping each observation to its group.
        y: Optional array of observations; None to sample from the prior.
    """
    num_groups = len(jnp.unique(group_idx))
    num_obs = len(group_idx) if y is None else len(y)
    # Global hyperparameters shared across groups
    mu_global = numpyro.sample("mu_global", dist.Normal(0, 1))
    sigma_global = numpyro.sample("sigma_global", dist.Exponential(1))
    # One mean per group, drawn from the global hyperprior
    with numpyro.plate("groups", num_groups):
        mu_group = numpyro.sample("mu_group", dist.Normal(mu_global, sigma_global))
    # Each observation is conditionally independent given its group mean
    with numpyro.plate("obs", num_obs):
        numpyro.sample("y", dist.Normal(mu_group[group_idx], 1), obs=y)
# Model with subsampling for large datasets
def large_dataset_model(X, y=None):
    """Linear model that evaluates the likelihood on a random minibatch.

    Args:
        X: 2-D design matrix of shape (n_data, n_features).
        y: Optional 1-D response vector of length n_data.
    """
    n_data, n_features = X.shape
    # Regression weights
    weights = numpyro.sample("weights", dist.Normal(0, 1).expand((n_features,)))
    # Subsample for computational efficiency. numpyro.subsample already
    # returns the rows selected by the plate's subsample indices, so
    # indexing the result with `idx` again (as the original did) would
    # re-index an already-reduced batch and can go out of bounds.
    with numpyro.plate("data", n_data, subsample_size=min(1000, n_data)):
        X_batch = numpyro.subsample(X, event_dim=1)
        y_batch = numpyro.subsample(y, event_dim=0) if y is not None else None
        mu = X_batch @ weights
        numpyro.sample("y", dist.Normal(mu, 0.1), obs=y_batch)
# Custom factor for non-standard likelihoods
def custom_likelihood_model(data):
    """Bernoulli-style log-likelihood added directly via numpyro.factor.

    Args:
        data: Array of 0/1 observations.
    """
    theta = numpyro.sample("theta", dist.Beta(1, 1))
    # Custom log-likelihood that doesn't fit standard distributions
    log_lik = jnp.sum(data * jnp.log(theta) + (1 - data) * jnp.log(1 - theta))
    numpyro.factor("custom_lik", log_lik)

from typing import Optional, Union, Callable, Dict, Any, Tuple
from jax import Array
import jax.numpy as jnp
from numpyro.distributions import Distribution, constraints
ArrayLike = Union[Array, jnp.ndarray, float, int]
Constraint = constraints.Constraint
InitFunction = Union[ArrayLike, Callable[[Array], ArrayLike]]
class CondIndepStackFrame:
"""Frame in the conditional independence stack."""
name: str
dim: int
size: int
counter: int
class PlateMessenger:
"""Messenger for plate context management."""
name: str
size: int
subsample_size: Optional[int]
dim: Optional[int]
subsample: Optional[Array]
# Site types for different primitive operations
# NOTE(review): a Union of bare string literals is parsed as forward
# references; typing.Literal is the correct construct here — confirm
# against the package before changing.
SiteType = Union["sample", "param", "deterministic", "factor", "mutable"]
class SiteInfo:
"""Information about a primitive site."""
name: str
type: SiteType
fn: Optional[Distribution]
args: tuple
kwargs: dict
value: Any
is_observed: bool
infer: dict
scale: Optional[float]
class ValidationResult:
"""Result from model validation."""
sites: dict
dependencies: dict
plate_stack: list
is_valid: bool
warnings: list
errors: list

Install with Tessl CLI:
npx tessl i tessl/pypi-numpyro