NumPyro: a lightweight probabilistic programming library that brings Pyro to a NumPy backend, powered by JAX for automatic differentiation and JIT compilation.

NumPyro provides essential utility functions for JAX configuration, control flow primitives, model validation, and development helpers. These utilities enable efficient probabilistic programming with proper hardware acceleration, memory management, and debugging capabilities.

## JAX Configuration

Functions for configuring JAX behavior and hardware acceleration.
def enable_x64(use_x64: bool = True) -> None:
    """
    Enable or disable 64-bit precision for JAX computations.

    By default, JAX uses 32-bit precision for performance. Enable 64-bit
    precision when higher numerical accuracy is needed.

    Args:
        use_x64: Whether to use 64-bit precision (default: True)

    Usage:
        # Enable double precision for numerical stability
        numpyro.enable_x64(True)

        # Disable to return to 32-bit (faster but less precise)
        numpyro.enable_x64(False)

        # Check current precision
        import jax
        print(f"Current precision: {jax.config.jax_enable_x64}")
    """
def set_platform(platform: Optional[str] = None) -> None:
    """
    Set the JAX platform for computations.

    Args:
        platform: Platform name ('cpu', 'gpu', 'tpu', or None for auto-detection)

    Usage:
        # Force CPU computation
        numpyro.set_platform('cpu')

        # Use GPU if available
        numpyro.set_platform('gpu')

        # Let JAX auto-detect best platform
        numpyro.set_platform(None)

        # Check current platform
        import jax
        print(f"Current platform: {jax.default_backend()}")
    """
def set_host_device_count(n: int) -> None:
    """
    Set the number of CPU devices for parallel computation.

    Useful for parallelizing MCMC chains across multiple CPU cores
    when GPU is not available or desired.

    Args:
        n: Number of CPU devices to use

    Usage:
        # Use 4 CPU devices for parallel chains
        numpyro.set_host_device_count(4)

        # Then run MCMC with multiple chains
        mcmc = MCMC(NUTS(model), num_warmup=1000, num_samples=1000, num_chains=4)
        mcmc.run(rng_key, data)  # Will use 4 CPU devices
    """
def set_rng_seed(rng_seed: Optional[int] = None) -> None:
    """
    Set global random seed for reproducible results.

    Args:
        rng_seed: Random seed value (None to use system entropy)

    Usage:
        # Set seed for reproducible experiments
        numpyro.set_rng_seed(42)

        # Clear seed to use random initialization
        numpyro.set_rng_seed(None)
    """


# JAX-compatible control flow functions for probabilistic programs.
def cond(pred: ArrayLike, true_operand: Any, true_fun: Callable,
         false_operand: Any, false_fun: Callable) -> Any:
    """
    JAX-compatible conditional execution primitive.

    Provides structured control flow that works with JAX transformations
    like JIT compilation and automatic differentiation.

    NOTE(review): the argument order here (operand before function) differs
    from ``jax.lax.cond`` — confirm against the installed numpyro version.

    Args:
        pred: Boolean condition for branching
        true_operand: Operand passed to true_fun if pred is True
        true_fun: Function to call if pred is True
        false_operand: Operand passed to false_fun if pred is False
        false_fun: Function to call if pred is False

    Returns:
        Result of the executed branch

    Usage:
        def model(x):
            # Conditional model structure
            def high_noise_model(x):
                return numpyro.sample("y", dist.Normal(x, 2.0))

            def low_noise_model(x):
                return numpyro.sample("y", dist.Normal(x, 0.1))

            # Switch based on input value
            is_high = x > 0.5
            return numpyro.cond(is_high, x, high_noise_model, x, low_noise_model)
    """
def while_loop(cond_fun: Callable, body_fun: Callable, init_val: Any) -> Any:
    """
    JAX-compatible while loop primitive.

    Executes body_fun repeatedly while cond_fun returns True.
    Compatible with JAX transformations.

    Args:
        cond_fun: Function that takes loop state and returns boolean
        body_fun: Function that takes loop state and returns new state
        init_val: Initial loop state

    Returns:
        Final loop state

    Usage:
        def iterative_sampler(key, n_steps):
            def cond_fun(state):
                step, _, _ = state
                return step < n_steps

            def body_fun(state):
                step, key, samples = state
                key, subkey = random.split(key)
                new_sample = numpyro.sample(f"x_{step}", dist.Normal(0, 1))
                return step + 1, key, samples.at[step].set(new_sample)

            init_samples = jnp.zeros(n_steps)
            _, _, final_samples = numpyro.while_loop(
                cond_fun, body_fun, (0, key, init_samples)
            )
            return final_samples
    """
def fori_loop(lower: int, upper: int, body_fun: Callable, init_val: Any) -> Any:
    """
    JAX-compatible for loop primitive.

    Executes body_fun for indices from lower to upper-1.

    Args:
        lower: Starting index (inclusive)
        upper: Ending index (exclusive)
        body_fun: Function that takes (index, state) and returns new state
        init_val: Initial loop state

    Returns:
        Final loop state

    Usage:
        def accumulate_samples(key, n_samples):
            def body_fun(i, state):
                key, total = state
                key, subkey = random.split(key)
                sample = random.normal(subkey)
                return key, total + sample

            key, final_total = numpyro.fori_loop(0, n_samples, body_fun, (key, 0.0))
            return final_total / n_samples
    """


# Functions for managing memory usage in large-scale computations.
def soft_vmap(fn: Callable, xs: ArrayLike, batch_ndims: int = 1,
              chunk_size: Optional[int] = None) -> ArrayLike:
    """
    Memory-efficient vectorized map that processes data in chunks.

    Alternative to jax.vmap that avoids memory issues with large datasets
    by processing inputs in smaller chunks.

    Args:
        fn: Function to vectorize
        xs: Input arrays to map over
        batch_ndims: Number of batch dimensions to map over
        chunk_size: Size of chunks to process (None for auto-selection)

    Returns:
        Vectorized results concatenated from chunks

    Usage:
        # Process large dataset without memory overflow
        def expensive_computation(x):
            return x @ weight_matrix  # Large matrix multiplication

        large_data = jnp.ones((10000, 1000))  # Would cause OOM with vmap

        # Process in chunks
        results = numpyro.soft_vmap(expensive_computation, large_data, chunk_size=100)
        # Shape: (10000, output_dim)
    """
def fori_collect(lower: int, upper: int, body_fun: Callable, init_val: Any,
                 transform: Optional[Callable] = None, progbar: bool = True,
                 return_last_val: bool = False, collection_size: Optional[int] = None,
                 **progbar_opts) -> Union[tuple, ArrayLike]:
    """
    For loop with collection and optional progress bar.

    Collects outputs from each iteration while optionally displaying progress.
    Useful for iterative algorithms where you need to track intermediate results.

    Args:
        lower: Starting index
        upper: Ending index
        body_fun: Function returning (new_state, collection_item)
        init_val: Initial state
        transform: Optional transform applied to collected items
        progbar: Whether to show progress bar
        return_last_val: Whether to return final state
        collection_size: Pre-allocate collection array size
        **progbar_opts: Additional progress bar options

    Returns:
        Collection of items; when return_last_val is True, a
        (collection, final_state) tuple.

    Usage:
        # Collect MCMC samples with progress tracking
        def mcmc_step(i, state):
            key, params = state
            key, subkey = random.split(key)
            # Single MCMC step
            new_params = mcmc_kernel_step(subkey, params)
            return (key, new_params), new_params  # (new_state, collect_item)

        init_state = (random.PRNGKey(0), init_params)
        samples = numpyro.fori_collect(0, 1000, mcmc_step, init_state, progbar=True)
    """


# Utilities for validating models and debugging probabilistic programs.
def format_shapes(trace: dict, last_site: Optional[str] = None) -> str:
    """
    Format trace shapes for debugging model structure.

    Provides a readable summary of all sites in a model trace with their
    shapes, which is useful for debugging broadcasting and plate issues.

    Args:
        trace: Execution trace from model
        last_site: Name of last site to include (None for all sites)

    Returns:
        Formatted string showing site shapes

    Usage:
        # Debug model shapes
        from numpyro.handlers import trace

        def model():
            with numpyro.plate("batch", 10):
                x = numpyro.sample("x", dist.Normal(0, 1))  # Should be (10,)
                with numpyro.plate("features", 5):
                    y = numpyro.sample("y", dist.Normal(x.expand((5,)), 1))  # Should be (10, 5)

        traced_model = trace(model)
        trace_dict = traced_model()
        shape_info = numpyro.format_shapes(trace_dict)
        print(shape_info)
        # Output:
        # Site shapes:
        #   x: (10,)
        #   y: (10, 5)
    """
def check_model_guide_match(model_trace: dict, guide_trace: dict) -> None:
    """
    Validate that model and guide have compatible structure.

    Ensures that the guide provides variational distributions for all
    sample sites in the model, which is required for SVI.

    Args:
        model_trace: Trace from model execution
        guide_trace: Trace from guide execution

    Raises:
        ValueError: If model and guide are incompatible

    Usage:
        # Validate model-guide compatibility before SVI
        from numpyro.handlers import trace

        model_trace = trace(model).get_trace(data)
        guide_trace = trace(guide).get_trace(data)
        try:
            numpyro.check_model_guide_match(model_trace, guide_trace)
            print("✓ Model and guide are compatible")
        except ValueError as e:
            print(f"✗ Compatibility error: {e}")
    """
def validate_model(model: Callable, *model_args, **model_kwargs) -> dict:
    """
    Comprehensive model validation and structure analysis.

    Args:
        model: Model function to validate
        *model_args: Arguments to pass to model
        **model_kwargs: Keyword arguments to pass to model

    Returns:
        Dictionary containing validation results and model information

    Usage:
        def my_model(x, y=None):
            alpha = numpyro.sample("alpha", dist.Normal(0, 1))
            with numpyro.plate("data", len(x)):
                numpyro.sample("y", dist.Normal(alpha + x, 1), obs=y)

        x_data = jnp.linspace(0, 1, 100)
        validation = numpyro.validate_model(my_model, x_data)
        print(f"Number of sample sites: {len(validation['sample_sites'])}")
        print(f"Model structure: {validation['structure']}")
        print(f"Validation passed: {validation['is_valid']}")
    """


# Helper functions for development and performance optimization.
def maybe_jit(fn: Callable, *args, **kwargs) -> Callable:
    """
    Conditionally apply JIT compilation based on context.

    Automatically determines whether to JIT compile based on the computational
    context and function characteristics.

    Args:
        fn: Function to potentially JIT compile
        *args: Arguments that would be passed to function
        **kwargs: Keyword arguments

    Returns:
        JIT-compiled or original function

    Usage:
        # Automatically optimize based on usage pattern
        def expensive_computation(x):
            return jnp.sum(x ** 2)

        optimized_fn = numpyro.maybe_jit(expensive_computation)
        result = optimized_fn(large_array)  # Will be JIT compiled if beneficial
    """
def progress_bar_factory(num_samples: int, num_chains: int = 1) -> Callable:
    """
    Create progress bar decorators for iterative algorithms.

    Args:
        num_samples: Total number of samples/iterations
        num_chains: Number of parallel chains

    Returns:
        Progress bar decorator function

    Usage:
        # Add progress bars to custom sampling loops
        progress_bar = numpyro.progress_bar_factory(1000, num_chains=4)

        @progress_bar
        def sampling_step(i, state):
            # Custom sampling logic
            return new_state

        # Progress will be displayed automatically
        final_state = fori_loop(0, 1000, sampling_step, init_state)
    """
def cached_by(outer_fn: Callable, *keys) -> Callable:
    """
    Function caching decorator with custom cache keys.

    Caches function results based on specified keys to avoid recomputation
    of expensive operations.

    NOTE(review): the usage example below passes a single key-builder lambda,
    which does not obviously match the (outer_fn, *keys) signature — confirm
    the intended calling convention against numpyro.util.cached_by.

    Args:
        outer_fn: Function to cache
        *keys: Keys to use for cache lookup

    Returns:
        Cached version of the function

    Usage:
        # Cache expensive model compilations
        @numpyro.cached_by(lambda model, data_shape: (model.__name__, data_shape))
        def compile_model(model, data_shape):
            # Expensive JIT compilation
            return jit(model)

        compiled_model = compile_model(my_model, (100,))  # Compiled once
        compiled_model = compile_model(my_model, (100,))  # Retrieved from cache
    """
def identity(x: Any, *args, **kwargs) -> Any:
    """
    Identity function that returns input unchanged.

    Useful as a placeholder or default function in conditional contexts.

    Args:
        x: Input value
        *args: Ignored additional arguments
        **kwargs: Ignored keyword arguments

    Returns:
        Input value unchanged
    """
def not_jax_tracer(x: Any) -> bool:
    """
    Check if value is not a JAX tracer.

    Useful for conditional logic that depends on whether values are
    concrete or abstract (traced) in JAX transformations.

    Args:
        x: Value to check

    Returns:
        True if x is not a JAX tracer, False otherwise

    Usage:
        def conditional_computation(x):
            if numpyro.not_jax_tracer(x):
                # This branch only executes with concrete values
                print(f"Concrete value: {x}")
            return x ** 2
    """
def is_prng_key(key: Any) -> bool:
    """
    Validate that input is a proper PRNG key.

    Args:
        key: Potential PRNG key to validate

    Returns:
        True if key is a valid PRNG key

    Usage:
        from jax import random

        key = random.PRNGKey(0)
        if numpyro.is_prng_key(key):
            subkey = random.split(key)[0]
        else:
            raise ValueError("Invalid PRNG key")
    """


# Utilities for context management and execution control.
def optional(condition: bool, context_manager: Any) -> Any:
    """
    Conditionally apply a context manager.

    Args:
        condition: Whether to apply the context manager
        context_manager: Context manager to apply if condition is True

    Returns:
        Context manager or no-op context

    Usage:
        # Conditionally enable validation
        use_validation = True
        with numpyro.optional(use_validation, numpyro.validation_enabled()):
            result = model()  # Validation applied only if use_validation=True
    """
def control_flow_prims_disabled() -> bool:
    """
    Check if control flow primitives are disabled.

    Returns:
        True if control flow primitives (cond, while_loop) are disabled

    Usage:
        if numpyro.control_flow_prims_disabled():
            # Use alternative implementation without control flow
            result = alternative_implementation()
        else:
            result = numpyro.cond(pred, true_op, true_fn, false_op, false_fn)
    """
def nested_attrgetter(*collect_fields: str) -> Callable:
    """
    Create getter for nested attributes in complex data structures.

    Args:
        *collect_fields: Dot-separated field paths to extract

    Returns:
        Function that extracts specified fields from objects

    Usage:
        # Extract nested fields from complex results
        getter = numpyro.nested_attrgetter("params.mu.loc", "losses")

        # Apply to SVI results
        svi_result = svi.run(key, 1000, data)
        extracted = getter(svi_result)  # Gets params.mu.loc and losses
    """
def find_stack_level() -> int:
    """
    Find appropriate stack level for warnings.

    Helper function for issuing warnings at the correct stack level
    in complex call hierarchies.

    Returns:
        Appropriate stack level for warnings
    """


import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO
import jax.numpy as jnp
from jax import random
# JAX configuration for optimal performance
def setup_jax_environment():
    """Configure JAX for optimal NumPyro performance.

    Sets precision, platform, host device count, and the global RNG seed,
    then prints the resulting JAX configuration.
    """
    # `jax` is used for the diagnostics below but is not imported at module
    # level in this file — import it locally so the example is runnable.
    import jax

    # Enable 64-bit precision for numerical stability
    numpyro.enable_x64(True)
    # Use GPU if available
    numpyro.set_platform('gpu')  # Falls back to CPU if GPU unavailable
    # Set up multiple CPU devices for parallel chains
    numpyro.set_host_device_count(4)
    # Set random seed for reproducibility
    numpyro.set_rng_seed(42)

    print(f"JAX platform: {jax.default_backend()}")
    print(f"JAX devices: {jax.device_count()}")
    print(f"64-bit enabled: {jax.config.jax_enable_x64}")
# Control flow in probabilistic models
def control_flow_example():
    """Example using JAX-compatible control flow.

    Returns:
        (adaptive_model, iterative_sampler) example callables.
    """
    # `handlers` is used inside the while-loop body but was never imported.
    from numpyro import handlers

    def adaptive_model(x):
        # Model switches behavior based on input
        def simple_model(x):
            return numpyro.sample("y", dist.Normal(x, 0.1))

        def complex_model(x):
            hidden = numpyro.sample("hidden", dist.Normal(0, 1))
            return numpyro.sample("y", dist.Normal(x + hidden, 0.5))

        # Use control flow primitive
        is_complex = x > 0.5
        return numpyro.cond(is_complex, x, complex_model, x, simple_model)

    # Iterative sampling with while loop
    def iterative_sampler(key, threshold=1.0):
        def cond_fun(state):
            _, _, total = state
            return jnp.abs(total) < threshold

        def body_fun(state):
            step, key, total = state
            key, subkey = random.split(key)
            with handlers.seed(rng_seed=subkey):
                new_sample = numpyro.sample(f"x_{step}", dist.Normal(0, 1))
            return step + 1, key, total + new_sample

        _, _, final_total = numpyro.while_loop(cond_fun, body_fun, (0, key, 0.0))
        return final_total

    return adaptive_model, iterative_sampler
# Memory-efficient processing
def large_scale_example():
    """Example of memory-efficient utilities for large datasets.

    Returns:
        (final_sum, intermediate_sums) from the chunked computation.
    """
    # Simulate large dataset
    n_data = 100000
    x_large = random.normal(random.PRNGKey(0), (n_data, 50))

    def expensive_transform(x_batch):
        # Simulate expensive computation
        return jnp.sum(x_batch ** 2, axis=1)

    # Process in chunks to avoid memory issues
    results = numpyro.soft_vmap(
        expensive_transform,
        x_large,
        chunk_size=1000  # Process 1000 samples at a time
    )
    print(f"Processed {n_data} samples in chunks")
    print(f"Result shape: {results.shape}")

    # Collect results with progress tracking
    def progressive_computation():
        def compute_step(i, state):
            current_sum = state
            # Simulate computation
            new_value = jnp.sum(results[i*1000:(i+1)*1000])
            return current_sum + new_value, new_value

        # With return_last_val=True, fori_collect returns
        # (collection, final_state); the original unpacked this tuple in the
        # opposite order, swapping the per-chunk values and the running total.
        intermediate_sums, final_sum = numpyro.fori_collect(
            0, n_data // 1000,
            compute_step,
            0.0,
            progbar=True,
            return_last_val=True
        )
        return final_sum, intermediate_sums

    return progressive_computation()
# Model validation workflow
def validation_workflow_example():
    """Comprehensive model validation example.

    Runs structure validation, shape analysis, model-guide compatibility
    checks, and a quick MCMC smoke test under both precisions.
    """
    # Both names are used below but were never imported in this file.
    import jax
    from numpyro.distributions import constraints

    def potentially_problematic_model(x, y=None):
        # Model with potential issues
        alpha = numpyro.sample("alpha", dist.Normal(0, 1))
        beta = numpyro.sample("beta", dist.Normal(0, 1))
        # Potential broadcasting issue
        with numpyro.plate("data", len(x)):
            mu = alpha + beta * x  # Check shapes here
            numpyro.sample("y", dist.Normal(mu, 1), obs=y)

    def guide(x, y=None):
        # Variational guide
        alpha_loc = numpyro.param("alpha_loc", 0.0)
        alpha_scale = numpyro.param("alpha_scale", 1.0, constraint=constraints.positive)
        beta_loc = numpyro.param("beta_loc", 0.0)
        beta_scale = numpyro.param("beta_scale", 1.0, constraint=constraints.positive)
        numpyro.sample("alpha", dist.Normal(alpha_loc, alpha_scale))
        numpyro.sample("beta", dist.Normal(beta_loc, beta_scale))

    # Generate test data
    x_test = jnp.linspace(0, 1, 100)
    y_test = 1.5 + 2.0 * x_test + 0.1 * random.normal(random.PRNGKey(0), (100,))

    print("=== Model Validation Report ===")

    # 1. Validate model structure
    try:
        validation_result = numpyro.validate_model(potentially_problematic_model, x_test, y_test)
        print("✓ Model structure validation passed")
        print(f" Sample sites: {len(validation_result.get('sample_sites', []))}")
    except Exception as e:
        print(f"✗ Model validation failed: {e}")
        return

    # 2. Check model shapes
    from numpyro.handlers import trace
    # Initialize so step 3 cannot hit an unbound name if this step fails.
    model_trace = None
    try:
        model_trace = trace(potentially_problematic_model).get_trace(x_test, y_test)
        shape_info = numpyro.format_shapes(model_trace)
        print("✓ Shape analysis:")
        print(shape_info)
    except Exception as e:
        print(f"✗ Shape analysis failed: {e}")

    # 3. Validate model-guide compatibility
    try:
        if model_trace is None:
            raise ValueError("model trace unavailable (shape analysis failed)")
        guide_trace = trace(guide).get_trace(x_test, y_test)
        numpyro.check_model_guide_match(model_trace, guide_trace)
        print("✓ Model-guide compatibility verified")
    except Exception as e:
        print(f"✗ Model-guide compatibility failed: {e}")

    # 4. Test with different JAX configurations
    original_x64 = jax.config.jax_enable_x64
    for use_x64 in [False, True]:
        numpyro.enable_x64(use_x64)
        precision = "64-bit" if use_x64 else "32-bit"
        try:
            # Quick MCMC test
            mcmc = MCMC(NUTS(potentially_problematic_model),
                        num_warmup=100, num_samples=100, num_chains=2)
            mcmc.run(random.PRNGKey(0), x_test, y_test)
            print(f"✓ {precision} MCMC test passed")
        except Exception as e:
            print(f"✗ {precision} MCMC test failed: {e}")

    # Restore original precision
    numpyro.enable_x64(original_x64)
# Performance optimization example
def performance_optimization_example():
    """Example of performance optimization utilities.

    Compares wall-clock time of an uncompiled model, a cached+JIT-compiled
    version, and a maybe_jit-wrapped version, then prints the speedup.
    """
    # `jit` was used below but never imported in this file.
    from jax import jit
    import time

    def expensive_model(x):
        # Model with expensive computations
        weights = numpyro.sample("weights", dist.Normal(0, 1).expand((100, 50)))
        # Expensive matrix operations. x is (1000, 100) and weights is
        # (100, 50), so the product must be x @ weights -> (1000, 50);
        # the original `x @ weights.T` had incompatible shapes
        # ((1000, 100) @ (50, 100)).
        transformed = x @ weights
        result = numpyro.sample("result", dist.Normal(transformed, 0.1))
        return result

    # Create cached version
    @numpyro.cached_by(lambda x_shape: x_shape)  # Cache by input shape
    def compile_model(x_shape):
        def compiled_fn(x):
            return expensive_model(x)
        return jit(compiled_fn)

    # Use maybe_jit for conditional optimization
    adaptive_model = numpyro.maybe_jit(expensive_model)

    # Test data
    x = random.normal(random.PRNGKey(0), (1000, 100))

    print("Performance comparison:")

    # Time original model
    start_time = time.time()
    result1 = expensive_model(x)
    original_time = time.time() - start_time
    print(f"Original model: {original_time:.3f}s")

    # Time cached/compiled model
    start_time = time.time()
    compiled_fn = compile_model(x.shape)
    result2 = compiled_fn(x)
    cached_time = time.time() - start_time
    print(f"Cached/compiled: {cached_time:.3f}s")

    # Time adaptive model
    start_time = time.time()
    result3 = adaptive_model(x)
    adaptive_time = time.time() - start_time
    print(f"Adaptive JIT: {adaptive_time:.3f}s")

    speedup = original_time / min(cached_time, adaptive_time)
    print(f"Speedup: {speedup:.1f}x")


from typing import Optional, Union, Callable, Dict, Any, Tuple, ContextManager
from typing import Any, Dict, Literal, Union

from jax import Array
import jax.numpy as jnp

# Any array-like input accepted by the utilities above.
ArrayLike = Union[Array, jnp.ndarray, float, int]
# Closed set of platform names accepted by set_platform. The original
# `Union["cpu", "gpu", "tpu"]` is not a valid type — Union takes types,
# not string values; Literal is the correct construct here.
Platform = Literal["cpu", "gpu", "tpu"]
# Keyword options forwarded to progress-bar helpers.
ProgressBarOptions = Dict[str, Any]
class ValidationResult:
    """Result from model validation (see validate_model)."""

    is_valid: bool  # overall pass/fail flag
    sample_sites: Dict[str, Any]  # presumably keyed by site name — confirm
    param_sites: Dict[str, Any]  # presumably keyed by param name — confirm
    deterministic_sites: Dict[str, Any]  # presumably keyed by site name — confirm
    warnings: list  # non-fatal issues found during validation
    errors: list  # fatal issues found during validation
    structure: Dict[str, Any]  # model-structure summary
class TraceInfo:
    """Information about model trace structure."""

    sites: Dict[str, Any]  # trace sites, keyed by name
    shapes: Dict[str, tuple]  # per-site shape tuples
    plate_stack: list  # active plates, presumably outermost first — confirm
    dependencies: Dict[str, list]  # per-site dependency lists
# Control flow function types
CondFun = Callable[[Any], bool]  # while_loop condition: state -> bool
BodyFun = Callable[[Any], Any]  # while_loop body: state -> new state
TrueFun = Callable[[Any], Any]  # cond true-branch callable
FalseFun = Callable[[Any], Any]  # cond false-branch callable

# Loop types
LoopState = Any
LoopIndex = int
ForBodyFun = Callable[[LoopIndex, LoopState], LoopState]
CollectBodyFun = Callable[[LoopIndex, LoopState], Tuple[LoopState, Any]]

# Utility types
CacheKey = Any
CacheFun = Callable[..., CacheKey]
TransformFun = Optional[Callable[[Any], Any]]
ProgressBarFun = Callable[[Callable], Callable]

# Context manager types
ConditionalContext = Union[ContextManager, None]
OptionalContext = ContextManager

# Validation types
ModelFun = Callable[..., Any]
GuideFun = Callable[..., Any]
TraceDict = Dict[str, Any]
SiteDict = Dict[str, Any]

# Install with Tessl CLI
# Install with Tessl CLI: npx tessl i tessl/pypi-numpyro