Lightweight probabilistic programming library with NumPy backend for Pyro, powered by JAX for automatic differentiation and JIT compilation.
—
NumPyro provides multiple inference algorithms for Bayesian posterior computation including Markov Chain Monte Carlo (MCMC) samplers, variational inference methods, ensemble techniques, and specialized algorithms. All inference methods are built on JAX for efficient automatic differentiation and JIT compilation.
Markov Chain Monte Carlo methods for sampling from posterior distributions.
class MCMC:
    """
    Wrapper class for Markov Chain Monte Carlo inference algorithms.

    Args:
        kernel: MCMC kernel (e.g., NUTS, HMC)
        num_warmup: Number of warmup steps
        num_samples: Number of samples to draw
        num_chains: Number of parallel chains
        postprocess_fn: Post-processing function for samples
        chain_method: Parallelization method ('parallel', 'sequential', 'vectorized')
        progress_bar: Whether to show progress bar
        jit_model_args: Whether to JIT compile model arguments
    """
    def __init__(self, kernel, num_warmup: int, num_samples: int, num_chains: int = 1,
                 postprocess_fn: Optional[Callable] = None, chain_method: str = 'parallel',
                 progress_bar: bool = True, jit_model_args: bool = False): ...

    def run(self, rng_key: Array, *args, extra_fields=(), init_params=None, **kwargs) -> None:
        """
        Run MCMC sampling.

        Args:
            rng_key: Random key for sampling
            *args: Arguments to pass to the model
            extra_fields: Additional fields to collect
            init_params: Initial parameter values
            **kwargs: Keyword arguments to pass to the model
        """

    def get_samples(self, group_by_chain: bool = False) -> dict:
        """
        Get posterior samples.

        Args:
            group_by_chain: Whether to group samples by chain
        Returns:
            Dictionary of posterior samples
        """

    def get_extra_fields(self, group_by_chain: bool = False) -> dict:
        """Get additional collected fields (e.g., diagnostics)."""

    def print_summary(self, prob: float = 0.9, exclude_deterministic: bool = True) -> None:
        """Print summary statistics of posterior samples."""


class HMC:
    """
    Hamiltonian Monte Carlo kernel.

    Args:
        model: Python callable containing Pyro primitives
        step_size: Step size for leapfrog integrator
        num_steps: Number of leapfrog steps
        adapt_step_size: Whether to adapt step size during warmup
        adapt_mass_matrix: Whether to adapt mass matrix during warmup
        dense_mass: Whether to use dense mass matrix
        target_accept_prob: Target acceptance probability for step size adaptation
        trajectory_length: Alternative to num_steps, specifies trajectory length
        max_tree_depth: Maximum tree depth for trajectory building
        find_heuristic_step_size: Whether to find good initial step size
        forward_mode_differentiation: Whether to use forward-mode AD
        regularize_mass_matrix: Whether to regularize mass matrix
    """
    def __init__(self, model, step_size=1.0, num_steps=None, adapt_step_size=True,
                 adapt_mass_matrix=True, dense_mass=False, target_accept_prob=0.8,
                 trajectory_length=None, max_tree_depth=10, find_heuristic_step_size=False,
                 forward_mode_differentiation=False, regularize_mass_matrix=True): ...
class NUTS:
    """
    No-U-Turn Sampler, an adaptive variant of HMC.

    Args:
        model: Python callable containing Pyro primitives
        step_size: Initial step size
        adapt_step_size: Whether to adapt step size during warmup
        adapt_mass_matrix: Whether to adapt mass matrix during warmup
        dense_mass: Whether to use dense mass matrix
        target_accept_prob: Target acceptance probability
        max_tree_depth: Maximum tree depth for trajectory building
        find_heuristic_step_size: Whether to find good initial step size
        forward_mode_differentiation: Whether to use forward-mode AD
        regularize_mass_matrix: Whether to regularize mass matrix
    """
    def __init__(self, model, step_size: float = 1.0, adapt_step_size: bool = True,
                 adapt_mass_matrix: bool = True, dense_mass: bool = False,
                 target_accept_prob: float = 0.8, max_tree_depth: int = 10,
                 find_heuristic_step_size: bool = False,
                 forward_mode_differentiation: bool = False,
                 regularize_mass_matrix: bool = True): ...
class SA:
    """
    Simulated Annealing kernel.

    NOTE(review): in upstream NumPyro, ``SA`` is the Sample Adaptive MCMC
    sampler (a gradient-free kernel); the "Simulated Annealing" description
    and the ``restart_interval``/``cooling_schedule`` parameters do not match
    the upstream ``numpyro.infer.SA`` API — verify before relying on them.

    Args:
        model: Python callable containing Pyro primitives
        adapt_state_size: Size of adaptive state
        restart_interval: Interval for restarting annealing
        cooling_schedule: Temperature cooling schedule function
    """
    def __init__(self, model, adapt_state_size=None, restart_interval: int = 100,
                 cooling_schedule=None): ...
class BarkerMH:
    """
    Barker Metropolis-Hastings kernel.

    Args:
        model: Python callable containing Pyro primitives
        step_size: Step size for proposals
        adapt_step_size: Whether to adapt step size
        adapt_mass_matrix: Whether to adapt mass matrix
        dense_mass: Whether to use dense mass matrix
        target_accept_prob: Target acceptance probability
    """
    # 0.234 is the classical optimal acceptance rate for random-walk-style proposals.
    def __init__(self, model, step_size=1.0, adapt_step_size=True, adapt_mass_matrix=True,
                 dense_mass=False, target_accept_prob=0.234): ...


class HMCGibbs:
    """
    HMC-within-Gibbs sampler for models with discrete latent variables.

    Args:
        inner_kernel: Inner MCMC kernel (e.g., NUTS, HMC)
        gibbs_fn: Gibbs sampling function for discrete variables
        gibbs_sites: Names of discrete sites to sample with Gibbs
    """
    def __init__(self, inner_kernel, gibbs_fn=None, gibbs_sites=None): ...
class DiscreteHMCGibbs:
    """
    Specialized HMC-Gibbs for discrete variables.

    Args:
        inner_kernel: Inner kernel for continuous variables
        modified: Whether to use modified proposal for discrete variables
        gibbs_sites: Sites to sample with discrete Gibbs
    """
    def __init__(self, inner_kernel, modified: bool = True, gibbs_sites=None): ...
class HMCECS:
    """
    HMC with Energy Conserving Subsampling for large datasets.

    Args:
        model: Python callable containing Pyro primitives
        step_size: Step size for leapfrog integrator
        trajectory_length: Length of HMC trajectory
        num_blocks: Number of data blocks for subsampling
        proxy: Proxy function for likelihood approximation
    """
    def __init__(self, model, step_size: float = 1.0, trajectory_length: float = 1.0,
                 num_blocks: int = 1, proxy=None): ...
class MixedHMC:
    """
    HMC for models with mixed support: continuous latent variables are updated
    by the inner HMC kernel while discrete latent variables are updated within
    the same trajectory (Zhou, "Mixed Hamiltonian Monte Carlo for Mixed
    Discrete and Continuous Variables", 2020). The original description
    "mixed precision HMC" was incorrect.

    Args:
        inner_kernel: Base HMC kernel
        target_accept_prob: Target acceptance probability
        trajectory_length: HMC trajectory length
    """
    def __init__(self, inner_kernel, target_accept_prob=0.8, trajectory_length=1.0): ...


# Ensemble sampling algorithms for parallel chain sampling.
class ESS:
    """
    Ensemble Slice Sampling.

    Args:
        model: Python callable containing Pyro primitives
        max_slice_size: Maximum size of slice
        num_slices: Number of slices per step
        moves: Dictionary of move types and probabilities
    """
    def __init__(self, model, max_slice_size: float = float('inf'),
                 num_slices: int = 1, moves=None): ...
class AIES:
    """
    Affine Invariant Ensemble Sampler.

    Args:
        model: Python callable containing Pyro primitives
        num_ensembles: Number of ensemble members
        moves: Dictionary of move types and their configurations
    """
    def __init__(self, model, num_ensembles=100, moves=None): ...


# Stochastic variational inference for approximate posterior computation.
class SVI:
    """
    Stochastic Variational Inference.

    Args:
        model: Model function containing Pyro primitives
        guide: Guide (variational family) function
        optim: Optimizer for variational parameters
        loss: Loss function (ELBO variant)
        num_particles: Number of particles for gradient estimation
        stable_update: Whether to use numerically stable updates
    """
    def __init__(self, model, guide, optim, loss, num_particles: int = 1,
                 stable_update: bool = False): ...

    def run(self, rng_key: Array, num_steps: int, *args, progress_bar: bool = True,
            stable_update: bool = False, **kwargs):
        """
        Run stochastic variational inference.

        Args:
            rng_key: Random key for stochastic optimization
            num_steps: Number of optimization steps
            *args: Arguments to pass to model and guide
            progress_bar: Whether to show progress bar
            stable_update: Whether to use numerically stable updates
            **kwargs: Keyword arguments to pass to model and guide
        Returns:
            SVIRunResult with losses and parameters
        """

    def evaluate(self, rng_key: Array, *args, **kwargs) -> float:
        """Evaluate the current loss."""

    def step(self, rng_key: Array, *args, **kwargs) -> float:
        """Take single SVI step."""
class SVIRunResult:
    """Result object from SVI.run().

    NOTE(review): a second, richer ``SVIRunResult`` (with a ``state`` field)
    is declared near the end of this file; the later definition rebinds the
    name at runtime — the two should be consolidated.
    """
    losses: Array  # Loss values over optimization
    params: dict   # Final parameter values


class ELBO:
    """
    Base class for Evidence Lower BOund objectives.

    Args:
        num_particles: Number of particles for Monte Carlo estimation
        vectorize_particles: Whether to vectorize over particles
        ignore_jit_warnings: Whether to ignore JIT compilation warnings
    """
    def __init__(self, num_particles: int = 1, vectorize_particles: bool = False,
                 ignore_jit_warnings: bool = False): ...

    def loss(self, rng_key: Array, param_map: dict, model: Callable, guide: Callable,
             *args, **kwargs) -> float: ...
class Trace_ELBO(ELBO):
    """Standard ELBO using Monte Carlo estimation with reparameterized gradients.

    Args:
        num_particles: Number of particles for Monte Carlo estimation
        vectorize_particles: Whether to vectorize over particles
        ignore_jit_warnings: Whether to ignore JIT warnings
    """
    def __init__(self, num_particles: int = 1, vectorize_particles: bool = False,
                 ignore_jit_warnings: bool = False): ...
class TraceEnum_ELBO(ELBO):
    """
    ELBO with exact enumeration over discrete latent variables.

    Args:
        num_particles: Number of particles for continuous variables
        max_plate_nesting: Maximum nesting level for enumeration
        max_iarange_nesting: Deprecated alias for max_plate_nesting
        strict_enumeration_warning: Whether to warn about enumeration issues
        vectorize_particles: Whether to vectorize over particles
        ignore_jit_warnings: Whether to ignore JIT warnings
    """
    def __init__(self, num_particles: int = 1, max_plate_nesting: Optional[int] = None,
                 max_iarange_nesting: Optional[int] = None, strict_enumeration_warning: bool = True,
                 vectorize_particles: bool = False, ignore_jit_warnings: bool = False): ...
class TraceGraph_ELBO(ELBO):
    """
    ELBO using Rao-Blackwellized gradient estimator.

    Args:
        num_particles: Number of particles for Monte Carlo estimation
        vectorize_particles: Whether to vectorize over particles
        ignore_jit_warnings: Whether to ignore JIT warnings
    """
    def __init__(self, num_particles: int = 1, vectorize_particles: bool = False,
                 ignore_jit_warnings: bool = False): ...
class TraceMeanField_ELBO(ELBO):
    """ELBO for mean field variational families.

    Args:
        num_particles: Number of particles for Monte Carlo estimation
        vectorize_particles: Whether to vectorize over particles
        ignore_jit_warnings: Whether to ignore JIT warnings
    """
    def __init__(self, num_particles: int = 1, vectorize_particles: bool = False,
                 ignore_jit_warnings: bool = False): ...
class RenyiELBO(ELBO):
    """
    Rényi divergence-based ELBO for more robust variational inference.

    Args:
        alpha: Rényi divergence parameter. The original doc said "alpha=1
            recovers standard ELBO", which conflicted with the default of
            0.0; more precisely, the alpha -> 1 limit recovers the standard
            ELBO, and the default here is 0.0.
        num_particles: Number of particles for Monte Carlo estimation
        vectorize_particles: Whether to vectorize over particles
    """
    def __init__(self, alpha: float = 0.0, num_particles: int = 1,
                 vectorize_particles: bool = False): ...


# Located in numpyro.infer.autoguide module
class AutoGuide:
    """Base class for automatic variational guides.

    Args:
        model: Model function containing Pyro primitives
        prefix: Prefix for auto-generated parameter names
        init_loc_fn: Initialization function for location parameters
        create_plates: Function to create plates for batched parameters
    """
    def __init__(self, model: Callable, prefix: str = "auto", init_loc_fn=None,
                 create_plates=None): ...

    def sample_posterior(self, rng_key: Array, params: dict, sample_shape=()) -> dict:
        """Sample from the approximate posterior."""

    def median(self, params: dict) -> dict:
        """Compute median of the approximate posterior."""

    def quantiles(self, params: dict, quantiles) -> dict:
        """Compute quantiles of the approximate posterior."""
class AutoNormal(AutoGuide):
    """
    Multivariate normal variational family with diagonal covariance.

    Args:
        model: Model function
        prefix: Prefix for parameter names
        init_loc_fn: Initialization function for location parameters
        init_scale: Initial scale for variational parameters
        create_plates: Function to create plates for batched parameters
    """
    def __init__(self, model: Callable, prefix: str = "auto", init_loc_fn=None,
                 init_scale: float = 0.1, create_plates=None): ...
class AutoMultivariateNormal(AutoGuide):
    """
    Multivariate normal variational family with full covariance matrix.

    Args:
        model: Model function
        prefix: Prefix for parameter names
        init_loc_fn: Initialization function for location parameters
        init_scale: Initial scale for variational parameters
    """
    def __init__(self, model: Callable, prefix: str = "auto", init_loc_fn=None,
                 init_scale: float = 0.1): ...
class AutoLowRankMultivariateNormal(AutoGuide):
    """
    Low-rank multivariate normal variational family.

    Args:
        model: Model function
        prefix: Prefix for parameter names
        init_loc_fn: Initialization function
        rank: Rank of low-rank approximation
        init_scale: Initial scale parameter
    """
    def __init__(self, model: Callable, prefix: str = "auto", init_loc_fn=None,
                 rank: int = 1, init_scale: float = 0.1): ...
class AutoDiagonalNormal(AutoGuide):
    """Diagonal normal variational family (alias for AutoNormal).

    NOTE(review): in upstream NumPyro, ``AutoDiagonalNormal`` is a distinct
    guide (a single joint Normal with diagonal covariance), not a strict
    alias of ``AutoNormal`` — verify against ``numpyro.infer.autoguide``.
    """
class AutoLaplaceApproximation(AutoGuide):
    """
    Laplace approximation around MAP estimate.

    Args:
        model: Model function
        prefix: Prefix for parameter names
        init_loc_fn: Initialization function
    """
    def __init__(self, model: Callable, prefix: str = "auto", init_loc_fn=None): ...
class AutoDelta(AutoGuide):
    """
    Point estimate guide (MAP approximation).

    Args:
        model: Model function
        prefix: Prefix for parameter names
        init_loc_fn: Initialization function for point estimates
    """
    def __init__(self, model: Callable, prefix: str = "auto", init_loc_fn=None): ...
class AutoIAFNormal(AutoGuide):
    """
    Inverse Autoregressive Flow with normal base distribution.

    Args:
        model: Model function
        prefix: Prefix for parameter names
        init_loc_fn: Initialization function
        num_flows: Number of flow transformations
        hidden_dims: Hidden dimensions for autoregressive networks
        skip_connections: Whether to use skip connections
    """
    def __init__(self, model: Callable, prefix: str = "auto", init_loc_fn=None,
                 num_flows: int = 3, hidden_dims=None, skip_connections: bool = False): ...
class AutoBNAFNormal(AutoGuide):
    """
    Block Neural Autoregressive Flow with normal base distribution.

    Args:
        model: Model function
        prefix: Prefix for parameter names
        init_loc_fn: Initialization function
        num_flows: Number of flow layers
        hidden_factors: Hidden layer size factors
        residual: Whether to use residual connections
    """
    def __init__(self, model: Callable, prefix: str = "auto", init_loc_fn=None,
                 num_flows: int = 1, hidden_factors=None, residual=None): ...
class AutoSurrogateLikelihoodDAG(AutoGuide):
    """Surrogate likelihood guide for DAG models.

    NOTE(review): upstream NumPyro names this family
    ``AutoSurrogateLikelihoodDAIS`` (Differentiable Annealed Importance
    Sampling); "DAG" here looks like a generation typo — verify against
    ``numpyro.infer.autoguide`` before use.
    """
    def __init__(self, model: Callable, prefix: str = "auto"): ...


# Functions for initializing MCMC chains and variational parameters.
def init_to_feasible(model: Callable, *model_args, **model_kwargs):
    """
    Initialize to feasible values within parameter constraints.

    NOTE(review): upstream NumPyro init strategies take ``site=None`` rather
    than a model plus its arguments — verify this signature.

    Args:
        model: Model function
        *model_args: Arguments to the model
        **model_kwargs: Keyword arguments to the model
    Returns:
        Initialization function
    """
def init_to_mean(model: Callable, *model_args, **model_kwargs):
    """Initialize parameters to their prior means (when available)."""


def init_to_median(model: Callable, *model_args, **model_kwargs):
    """Initialize parameters to their prior medians (when available)."""


def init_to_sample(model: Callable, *model_args, **model_kwargs):
    """Initialize parameters to samples from their priors."""
def init_to_uniform(model: Callable, radius: float = 2.0, *model_args, **model_kwargs):
    """
    Initialize parameters uniformly within their support.

    NOTE(review): because ``radius`` precedes ``*model_args``, a positional
    model argument would bind to ``radius`` — pass ``radius`` by keyword.

    Args:
        model: Model function
        radius: Radius for uniform initialization in unconstrained space
    """
def init_to_value(values: dict):
    """
    Initialize parameters to specified values.

    Args:
        values: Dictionary mapping parameter names to initial values
    """


# Utility functions for inference and posterior analysis.
class Predictive:
    """
    Utility for posterior and prior predictive sampling.

    Args:
        model: Model function
        posterior_samples: Dictionary of posterior samples (optional)
        guide: Guide function for variational inference (optional)
        params: Parameters for guide (when using variational inference)
        num_samples: Number of samples to draw
        return_sites: Sites to return in predictions
        infer_discrete: Whether to infer discrete latent variables
        parallel: Whether to run predictions in parallel
        batch_ndims: Number of batch dimensions in posterior samples
    """
    def __init__(self, model: Callable, posterior_samples: Optional[dict] = None,
                 guide: Optional[Callable] = None, params: Optional[dict] = None,
                 num_samples: Optional[int] = None, return_sites: Optional[list] = None,
                 infer_discrete: bool = False, parallel: bool = False, batch_ndims: int = 1): ...

    def __call__(self, rng_key: Array, *args, **kwargs) -> dict:
        """
        Generate predictions.

        Args:
            rng_key: Random key for sampling
            *args: Arguments to pass to model
            **kwargs: Keyword arguments to pass to model
        Returns:
            Dictionary of predicted values
        """
def log_likelihood(model: Callable, posterior_samples: dict, *args, **kwargs) -> dict:
    """
    Compute log likelihood of observations given posterior samples.

    Args:
        model: Model function
        posterior_samples: Dictionary of posterior samples
        *args: Arguments to pass to model
        **kwargs: Keyword arguments to pass to model
    Returns:
        Dictionary of log likelihood values for each observed site
    """
def render_model(model: Callable, model_args=(), model_kwargs=None, filename=None,
                 render_distributions: bool = False, render_params: bool = False,
                 hide_deterministic: bool = True):
    """
    Render model structure as a graphical diagram.

    Args:
        model: Model function to render
        model_args: Arguments to pass to model
        model_kwargs: Keyword arguments to pass to model
        filename: Output filename for rendered graph
        render_distributions: Whether to show distribution details
        render_params: Whether to show parameter nodes
        hide_deterministic: Whether to hide deterministic sites
    """


# Reparameterization strategies for improving inference efficiency.
# Located in numpyro.infer.reparam module
class Reparam:
    """Base class for reparameterizations.

    Subclasses are called per sample site and return the reparameterized
    site as a tuple.
    """
    def __call__(self, name: str, fn, obs) -> tuple: ...
class LocScaleReparam(Reparam):
    """
    Reparameterization for location-scale distributions.

    Args:
        centered: Parameterization type (0=non-centered, 1=centered, None=adaptive)
    """
    def __init__(self, centered: Optional[float] = None): ...
class TransformReparam(Reparam):
    """
    Reparameterization using bijective transforms.

    NOTE(review): upstream ``numpyro.infer.reparam.TransformReparam`` takes
    no constructor arguments (the transform comes from the site's
    ``TransformedDistribution``) — verify this signature.

    Args:
        transform: Bijective transformation
        suffix: Suffix for transformed variable names
    """
    def __init__(self, transform, suffix: str = "_base"): ...
class NeuTraReparam(Reparam):
    """
    Neural Transport reparameterization.

    Args:
        guide: Neural guide for reparameterization
        params: Parameters for the guide
    """
    def __init__(self, guide: Callable, params: dict): ...
class CircularReparam(Reparam):
    """Reparameterization for circular variables."""


class ProjectedNormalReparam(Reparam):
    """Reparameterization for projected normal distributions."""


class ImplicitReparam(Reparam):
    """Implicit reparameterization for complex posteriors."""


class SplitReparam(Reparam):
    """Split reparameterization for multivariate distributions."""
    def __init__(self, sections: list, dim: int = -1): ...


class SymmetricSplitReparam(Reparam):
    """Symmetric split reparameterization.

    NOTE(review): not found in upstream ``numpyro.infer.reparam`` — verify.
    """
    def __init__(self, sections: list, dim: int = -1): ...


# NOTE(review): this import belongs at the top of the file (PEP 8); it was
# fused onto the preceding line in the original.
from typing import Optional, Union, Callable, Dict, Any, Tuple
from jax import Array
import jax.numpy as jnp

# Convenience type aliases for the public API.
ArrayLike = Union[Array, jnp.ndarray, float, int]
# NOTE(review): a class named ``MCMCKernel`` is declared later in this file;
# that definition rebinds this alias name — rename one of the two.
MCMCKernel = Union[HMC, NUTS, SA, BarkerMH, HMCGibbs, DiscreteHMCGibbs, HMCECS, MixedHMC, ESS, AIES]
Optimizer = Any  # From optax or numpyro.optim
LossFunction = Union[ELBO, Trace_ELBO, TraceEnum_ELBO, TraceGraph_ELBO, TraceMeanField_ELBO, RenyiELBO]
# Signature: (rng_key, model_args, model_kwargs) -> dict of initial values.
InitFunction = Callable[[Array, tuple, dict], dict]
class SVIState:
    """State object for SVI optimization."""
    optim_state: Any  # Optimizer-internal state
    rng_key: Array    # PRNG key for the next step
class SVIRunResult:
    """Result from SVI.run().

    NOTE(review): duplicates the earlier ``SVIRunResult`` declaration in this
    file (which lacks the ``state`` field); this later definition is the one
    in effect at runtime — consolidate the two.
    """
    losses: Array    # Per-step loss values
    params: dict     # Final variational parameters
    state: SVIState  # Final optimizer/RNG state
class MCMCState:
    """Internal state carried between steps of an MCMC kernel."""
    z: dict                  # Current parameter values
    potential_energy: float  # Potential energy at ``z``
    z_grad: dict             # Current gradients
    adapt_state: Any         # Adaptation state
    rng_key: Array           # PRNG key for the next step


# Kernel interfaces
class MCMCKernel:
    """Base interface for MCMC kernels."""

    def init(self, rng_key: Array, num_warmup: int, init_params: dict,
             model_args: tuple, model_kwargs: dict) -> MCMCState:
        """Initialize kernel state before warmup."""
        ...

    def sample(self, state: MCMCState, model_args: tuple, model_kwargs: dict) -> MCMCState:
        """Advance the chain by one step and return the new state."""
        ...

    def postprocess_fn(self, args: tuple, kwargs: dict) -> Callable:
        """Return a function mapping unconstrained samples to the constrained space."""
        ...


# Install with Tessl CLI
npx tessl i tessl/pypi-numpyro