A lightweight probabilistic programming library providing a NumPy backend for Pyro, powered by JAX for automatic differentiation and JIT compilation.
```bash
npx @tessl/cli install tessl/pypi-numpyro@0.19.0
```

NumPyro is a lightweight probabilistic programming library that provides a NumPy backend for Pyro, powered by JAX for automatic differentiation and JIT compilation on CPU, GPU, and TPU. It enables Bayesian modeling and statistical inference through MCMC algorithms such as Hamiltonian Monte Carlo and the No-U-Turn Sampler (NUTS), variational inference methods, and a comprehensive distributions module. The library is designed for machine learning researchers and practitioners who need efficient probabilistic modeling that scales across hardware platforms.
```bash
pip install numpyro
```

```python
import numpyro
```

Common patterns for probabilistic modeling:

```python
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO
from numpyro import sample, param, plate
```

JAX integration:

```python
import jax
import jax.numpy as jnp
from jax import random
```

A complete worked example, Bayesian linear regression with NUTS:

```python
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
import jax.numpy as jnp
from jax import random
# Define a simple Bayesian linear regression model
def linear_regression(X, y=None):
    # Priors
    alpha = numpyro.sample('alpha', dist.Normal(0, 10))
    beta = numpyro.sample('beta', dist.Normal(0, 10))
    sigma = numpyro.sample('sigma', dist.Exponential(1))
    # Linear model
    mu = alpha + beta * X
    # Likelihood
    with numpyro.plate('data', X.shape[0]):
        numpyro.sample('y', dist.Normal(mu, sigma), obs=y)
# Generate synthetic data
key = random.PRNGKey(0)
X = jnp.linspace(0, 1, 100)
true_alpha, true_beta = 1.0, 2.0
y = true_alpha + true_beta * X + 0.1 * random.normal(key, shape=(100,))
# Run MCMC inference
kernel = NUTS(linear_regression)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000)
mcmc.run(random.PRNGKey(1), X, y)
# Get posterior samples
samples = mcmc.get_samples()
print(f"Posterior mean for alpha: {jnp.mean(samples['alpha']):.3f}")
print(f"Posterior mean for beta: {jnp.mean(samples['beta']):.3f}")Variational inference example:
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal
import optax
# Define guide (variational family)
guide = AutoNormal(linear_regression)
# Set up SVI
optimizer = optax.adam(0.01)
svi = SVI(linear_regression, guide, optimizer, Trace_ELBO())
# Run variational inference
svi_result = svi.run(random.PRNGKey(2), 2000, X, y)
```

NumPyro's architecture is built on several key design principles:
NumPyro uses Pyro-style effect handlers that act as context managers to intercept and modify the execution of probabilistic programs. This enables powerful model manipulation capabilities like conditioning on observed data, substituting values, and applying transformations.
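For example, a model can be conditioned on fixed values for some of its latent sites and its execution trace inspected afterwards (a minimal sketch reusing the `linear_regression` model and `X` from above):

```python
from numpyro import handlers

# Pin 'alpha' and 'beta' to fixed values; they behave as observed downstream
conditioned = handlers.condition(linear_regression, data={'alpha': 1.0, 'beta': 2.0})

# seed supplies randomness; trace records every site that fires during execution
with handlers.seed(rng_seed=0):
    exec_trace = handlers.trace(conditioned).get_trace(X)

print(exec_trace['sigma']['value'])  # value drawn at the 'sigma' site
```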
Built on JAX, NumPyro leverages automatic differentiation, JIT compilation, and vectorization for high-performance numerical computing. This enables efficient gradient-based inference algorithms and scalable computations across CPU, GPU, and TPU.
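One consequence is that a model's joint log density is an ordinary JAX function: it can be JIT-compiled and differentiated. A minimal sketch using `numpyro.infer.util.log_density` with the model and data above:

```python
import jax
from numpyro.infer.util import log_density

@jax.jit
def joint_log_prob(params):
    # log_density substitutes `params` into the model and returns (log_joint, trace)
    log_joint, _ = log_density(linear_regression, (X, y), {}, params)
    return log_joint

params = {'alpha': 1.0, 'beta': 2.0, 'sigma': 0.5}
print(joint_log_prob(params))            # scalar log joint density
print(jax.grad(joint_log_prob)(params))  # gradients w.r.t. each parameter
```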
A comprehensive collection of 150+ probability distributions organized by type (continuous, discrete, conjugate, directional, mixture, truncated) with consistent interfaces and support for batching and broadcasting.
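Batching follows NumPy broadcasting rules; for instance, broadcasting the `loc` parameter yields a batch of distributions:

```python
import jax.numpy as jnp
from jax import random
import numpyro.distributions as dist

# Broadcasting loc against a scalar scale gives a batch of 3 Normals
d = dist.Normal(loc=jnp.array([0.0, 1.0, 2.0]), scale=1.0)
print(d.batch_shape)            # (3,)

draws = d.sample(random.PRNGKey(0), sample_shape=(5,))
print(draws.shape)              # (5, 3)
print(d.log_prob(draws).shape)  # (5, 3)
```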
Multiple inference backends, including MCMC kernels such as NUTS and HMC, stochastic variational inference (SVI) with automatic guide generation, and ensemble sampling techniques; kernels are interchangeable within the same MCMC driver, as the sketch below shows.
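For example, the NUTS kernel in the regression example above can be swapped for plain HMC without touching the model (a minimal sketch):

```python
from jax import random
from numpyro.infer import HMC, MCMC

# Same MCMC driver, different transition kernel
hmc_mcmc = MCMC(HMC(linear_regression), num_warmup=500, num_samples=500)
hmc_mcmc.run(random.PRNGKey(3), X, y)
```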
Core primitives (sample, param, plate, deterministic, factor) for model construction: sampling from distributions, declaring learnable parameters, and expressing conditional independence through plates, with probabilistic control flow supported by JAX's functional primitives.

```python
def sample(name: str, fn: Distribution, obs: Optional[ArrayLike] = None,
           rng_key: Optional[Array] = None, sample_shape: tuple = (),
           infer: Optional[dict] = None, obs_mask: Optional[ArrayLike] = None) -> ArrayLike

def param(name: str, init_value: Optional[Union[ArrayLike, Callable]] = None,
          constraint: Constraint = constraints.real, event_dim: Optional[int] = None) -> ArrayLike

def plate(name: str, size: int, subsample_size: Optional[int] = None,
          dim: Optional[int] = None) -> CondIndepStackFrame

def deterministic(name: str, value: ArrayLike) -> ArrayLike

def factor(name: str, log_factor: ArrayLike) -> None
```
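A short sketch of how deterministic and factor are used inside a model body (the model name and penalty term are illustrative):

```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def model_with_extras():
    theta = numpyro.sample('theta', dist.Normal(0.0, 1.0))
    # Record a derived quantity so it shows up alongside posterior samples
    numpyro.deterministic('theta_squared', theta ** 2)
    # Add an arbitrary term to the joint log density (e.g. a soft constraint)
    numpyro.factor('l1_penalty', -jnp.abs(theta))
```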
Comprehensive collection of 150+ probability distributions across continuous, discrete, conjugate, directional, mixture, and truncated families with consistent interfaces and extensive parameterization options.

```python
# Continuous distributions
class Normal(Distribution): ...
class Beta(Distribution): ...
class Gamma(Distribution): ...
class MultivariateNormal(Distribution): ...
# Discrete distributions
class Bernoulli(Distribution): ...
class Categorical(Distribution): ...
class Poisson(Distribution): ...
# Specialized distributions
class Mixture(Distribution): ...
class TruncatedDistribution(Distribution): ...
```
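A brief sketch of the specialized families, assuming the MixtureSameFamily and TruncatedNormal variants:

```python
import jax.numpy as jnp
from jax import random
import numpyro.distributions as dist

# Two-component Gaussian mixture with weights [0.3, 0.7]
mix = dist.MixtureSameFamily(
    dist.Categorical(probs=jnp.array([0.3, 0.7])),
    dist.Normal(loc=jnp.array([-2.0, 2.0]), scale=jnp.array([1.0, 0.5])),
)
print(mix.sample(random.PRNGKey(0), (4,)))

# Normal truncated to the positive half-line
trunc = dist.TruncatedNormal(loc=0.0, scale=1.0, low=0.0)
print(trunc.log_prob(jnp.array(0.5)))
```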
Multiple inference backends including MCMC samplers, variational inference methods, and ensemble techniques for Bayesian posterior computation.

```python
class MCMC:
    def __init__(self, kernel, num_warmup: int, num_samples: int,
                 num_chains: int = 1, postprocess_fn: Optional[Callable] = None): ...
    def run(self, rng_key: Array, *args, **kwargs) -> None: ...
    def get_samples(self, group_by_chain: bool = False) -> dict: ...

class SVI:
    def __init__(self, model, guide, optim, loss, **kwargs): ...
    def run(self, rng_key: Array, num_steps: int, *args, **kwargs): ...
```
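A minimal sketch of accessing results, reusing the model, guide, and data from the examples above; with more than one chain, chains run sequentially unless additional devices are made available:

```python
# Per-chain samples keep a leading chain dimension
mcmc = MCMC(NUTS(linear_regression), num_warmup=500, num_samples=500, num_chains=2)
mcmc.run(random.PRNGKey(4), X, y)
by_chain = mcmc.get_samples(group_by_chain=True)
print(by_chain['alpha'].shape)  # (2, 500)

# SVI returns the learned variational parameters and the loss trajectory
svi_result = svi.run(random.PRNGKey(5), 1000, X, y)
print(list(svi_result.params), svi_result.losses.shape)
```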
Pyro-style effect handlers for intercepting and modifying probabilistic program execution, enabling conditioning, substitution, masking, and other model transformations.

```python
def trace(fn: Callable) -> Callable: ...
def replay(fn: Callable, trace: dict) -> Callable: ...
def condition(fn: Callable, data: dict) -> Callable: ...
def substitute(fn: Callable, data: dict) -> Callable: ...
def seed(fn: Callable, rng_seed: int) -> Callable: ...
def block(fn: Callable, hide_fn: Optional[Callable] = None,
          expose_fn: Optional[Callable] = None, hide_all: bool = True) -> Callable: ...
```
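A sketch of composing handlers: substitute fixes a latent site, seed supplies randomness, and trace records what executed (reusing `linear_regression` and `X` from above):

```python
import jax.numpy as jnp
from numpyro import handlers

substituted = handlers.substitute(linear_regression, data={'alpha': 0.5})
with handlers.seed(rng_seed=0):
    tr = handlers.trace(substituted).get_trace(X)

# Each trace entry describes one site: its type, value, and metadata
for name, site in tr.items():
    print(name, site['type'], jnp.shape(site['value']))
```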
Collection of gradient-based optimizers for parameter learning in variational inference and maximum likelihood estimation.

```python
class Adam:
    def __init__(self, step_size: float, b1: float = 0.9, b2: float = 0.999, eps: float = 1e-8): ...

class SGD:
    def __init__(self, step_size: float, momentum: float = 0): ...

class RMSProp:
    def __init__(self, step_size: float, decay: float = 0.9, eps: float = 1e-8): ...
```
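These live in numpyro.optim and can be passed to SVI in place of an optax optimizer; a minimal sketch reusing the model and guide above:

```python
from numpyro.optim import Adam

svi_np = SVI(linear_regression, guide, Adam(step_size=0.005), Trace_ELBO())
result = svi_np.run(random.PRNGKey(6), 1000, X, y)
```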
Diagnostic utilities for assessing MCMC convergence, effective sample size, and posterior summary statistics.

```python
def effective_sample_size(x: NDArray) -> NDArray: ...
def gelman_rubin(x: NDArray) -> NDArray: ...
def split_gelman_rubin(x: NDArray) -> NDArray: ...
def hpdi(x: NDArray, prob: float = 0.9, axis: int = 0) -> NDArray: ...
def print_summary(samples: dict, prob: float = 0.9, group_by_chain: bool = True) -> None: ...
```
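A short sketch applying these to the grouped samples from the MCMC sketch above (the convergence diagnostics expect arrays with a leading chain dimension):

```python
from numpyro.diagnostics import effective_sample_size, hpdi, print_summary

by_chain = mcmc.get_samples(group_by_chain=True)
print_summary(by_chain)                          # mean, std, quantiles, n_eff, r_hat
print(effective_sample_size(by_chain['alpha']))
print(hpdi(samples['alpha'], prob=0.9))          # 90% highest posterior density interval
```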
JAX configuration utilities, control flow primitives, and helper functions for model development and debugging.

```python
def enable_x64(use_x64: bool = True) -> None: ...
def set_platform(platform: Optional[str] = None) -> None: ...
def set_host_device_count(n: int) -> None: ...
def cond(pred, true_operand, true_fun, false_operand, false_fun): ...
def while_loop(cond_fun, body_fun, init_val): ...
```
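Typical configuration calls, best made once at startup before any JAX computation runs:

```python
import numpyro

numpyro.enable_x64()              # compute in double precision throughout
numpyro.set_platform('cpu')       # or 'gpu' / 'tpu' when available
numpyro.set_host_device_count(4)  # expose 4 host devices for parallel chains
```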
Type definitions used in the signatures above:

```python
from typing import Optional, Union, Callable, Dict, Any
from jax import Array
import jax.numpy as jnp
ArrayLike = Union[Array, jnp.ndarray, float, int]
NDArray = jnp.ndarray
Distribution = numpyro.distributions.Distribution
Constraint = numpyro.distributions.constraints.Constraint
class CondIndepStackFrame:
    name: str
    dim: int
    size: int
    subsample_size: Optional[int]

class Messenger:
    def __enter__(self): ...
    def __exit__(self, exc_type, exc_value, traceback): ...
    def process_message(self, msg: dict) -> None: ...
```