CtrlK
Blog · Docs · Log in · Get started
Tessl Logo

tessl/pypi-numpyro

Lightweight probabilistic programming library with NumPy backend for Pyro, powered by JAX for automatic differentiation and JIT compilation.

Pending
Overview
Eval results
Files

docs/inference.md

Inference

NumPyro provides multiple inference algorithms for Bayesian posterior computation including Markov Chain Monte Carlo (MCMC) samplers, variational inference methods, ensemble techniques, and specialized algorithms. All inference methods are built on JAX for efficient automatic differentiation and JIT compilation.

Capabilities

MCMC Algorithms

Markov Chain Monte Carlo methods for sampling from posterior distributions.

Core MCMC Infrastructure

class MCMC:
    """
    Driver that runs an MCMC kernel and collects its samples.

    Args:
        kernel: MCMC kernel instance (e.g. ``NUTS``, ``HMC``). NumPyro names
            this first parameter ``sampler``.
        num_warmup: Number of warmup (adaptation) steps per chain.
        num_samples: Number of samples kept per chain after warmup.
        num_chains: Number of chains to run.
        postprocess_fn: Optional post-processing function applied to the
            collected samples.
        chain_method: How multiple chains are run: ``'parallel'``,
            ``'sequential'`` or ``'vectorized'``.
        progress_bar: Whether to display a progress bar.
        jit_model_args: If True, compile the potential-energy computation as a
            function of the model arguments, so that calling ``run`` again
            with same-shaped new data does not trigger recompilation.
        thinning: Keyword-only. Keep every ``thinning``-th post-warmup sample.
            The default of 1 keeps all samples.

    NOTE: in NumPyro every parameter after ``sampler`` is keyword-only, e.g.
    ``MCMC(NUTS(model), num_warmup=500, num_samples=1000)``.
    """

    def __init__(self, kernel, num_warmup: int, num_samples: int, num_chains: int = 1,
                 postprocess_fn: Optional[Callable] = None, chain_method: str = 'parallel',
                 progress_bar: bool = True, jit_model_args: bool = False, *,
                 thinning: int = 1): ...

    def run(self, rng_key: Array, *args, extra_fields=(), init_params=None, **kwargs) -> None:
        """
        Run warmup followed by sampling.

        Args:
            rng_key: JAX PRNG key that drives the sampler.
            *args: Positional arguments forwarded to the model.
            extra_fields: Names of additional kernel-state fields to record at
                every step. Retrieve them with ``get_extra_fields``.
            init_params: Initial parameter values. When omitted, the kernel's
                own initialization is used.
            **kwargs: Keyword arguments forwarded to the model.
        """

    def get_samples(self, group_by_chain: bool = False) -> dict:
        """
        Return the collected posterior samples.

        Args:
            group_by_chain: If True, keep the chain dimension so that every
                array has ``num_chains`` as its leading dimension. Otherwise
                the samples of all chains are merged.

        Returns:
            Dictionary mapping sample-site names to arrays of samples.
        """

    def get_extra_fields(self, group_by_chain: bool = False) -> dict:
        """Return the fields requested via ``extra_fields`` in ``run`` (e.g. diagnostics)."""

    def print_summary(self, prob: float = 0.9, exclude_deterministic: bool = True) -> None:
        """
        Print summary statistics of the posterior samples.

        Args:
            prob: Probability mass of the reported credible interval.
            exclude_deterministic: Whether to omit deterministic sites.
        """

Hamiltonian Monte Carlo

class HMC:
    """
    Hamiltonian Monte Carlo kernel with a fixed-length leapfrog trajectory.

    Args:
        model: Python callable containing NumPyro primitives.
        step_size: Initial leapfrog step size.
        num_steps: Number of leapfrog steps per iteration. When None, the
            number of steps is derived from ``trajectory_length`` and the
            step size.
        adapt_step_size: Whether to adapt the step size during warmup.
        adapt_mass_matrix: Whether to adapt the mass matrix during warmup.
        dense_mass: Whether to use a dense (rather than diagonal) mass matrix.
        target_accept_prob: Target acceptance probability for step-size
            adaptation.
        trajectory_length: Total integration time (``step_size * num_steps``).
            Only used when ``num_steps`` is None. NumPyro's ``HMC`` defaults
            this to ``2 * pi``.
        max_tree_depth: Kept for signature parity with ``NUTS`` only. Tree
            building is specific to NUTS, and NumPyro's ``HMC`` has no such
            parameter.
        find_heuristic_step_size: Whether to search for a reasonable initial
            step size before sampling.
        forward_mode_differentiation: Whether to use forward-mode instead of
            reverse-mode AD. Some JAX control-flow primitives, such as
            ``lax.while_loop``, only support forward mode.
        regularize_mass_matrix: Whether to regularize the adapted mass matrix.
    """

    def __init__(self, model, step_size=1.0, num_steps=None, adapt_step_size=True,
                 adapt_mass_matrix=True, dense_mass=False, target_accept_prob=0.8,
                 trajectory_length=None, max_tree_depth=10, find_heuristic_step_size=False,
                 forward_mode_differentiation=False, regularize_mass_matrix=True): ...

class NUTS:
    """
    No-U-Turn Sampler: an HMC variant that sets its trajectory length adaptively.

    Args:
        model: Python callable containing NumPyro primitives.
        step_size: Starting leapfrog step size.
        adapt_step_size: If True, tune the step size during warmup.
        adapt_mass_matrix: If True, tune the mass matrix during warmup.
        dense_mass: If True, use a full (dense) mass matrix instead of a
            diagonal one.
        target_accept_prob: Acceptance probability that step-size adaptation
            aims for.
        max_tree_depth: Upper bound on the depth of the trajectory tree.
        find_heuristic_step_size: If True, look for a good starting step size.
        forward_mode_differentiation: If True, differentiate in forward mode.
        regularize_mass_matrix: If True, regularize the mass matrix.
    """

    def __init__(
        self,
        model,
        step_size=1.0,
        adapt_step_size=True,
        adapt_mass_matrix=True,
        dense_mass=False,
        target_accept_prob=0.8,
        max_tree_depth=10,
        find_heuristic_step_size=False,
        forward_mode_differentiation=False,
        regularize_mass_matrix=True,
    ): ...

class SA:
    """
    Sample Adaptive MCMC kernel (Zhu, 2019). It is not simulated annealing.

    A gradient-free sampler. It maintains an "adaptive state" of points and
    proposes new points from a Gaussian fitted to that state. It is fast per
    effective sample but typically needs many warmup steps.

    Args:
        model: Python callable containing NumPyro primitives.
        adapt_state_size: Number of points kept in the adaptive state used to
            build the proposal. When None, NumPyro uses twice the number of
            latent dimensions.
        restart_interval: NOTE(review): not a parameter of NumPyro's ``SA``.
        cooling_schedule: NOTE(review): not a parameter of NumPyro's ``SA``.
            The sampler has no temperature schedule.

    NumPyro's ``SA`` also accepts ``potential_fn``, ``dense_mass`` and
    ``init_strategy``.
    """

    def __init__(self, model, adapt_state_size=None, restart_interval=100,
                 cooling_schedule=None): ...

class BarkerMH:
    """
    Metropolis-Hastings kernel using the gradient-informed Barker proposal.

    Proposals come from a skew-symmetric distribution that depends on the
    gradient of the potential energy.

    Args:
        model: Python callable containing NumPyro primitives.
        step_size: Initial proposal step size.
        adapt_step_size: Whether to adapt the step size during warmup.
        adapt_mass_matrix: Whether to adapt the mass matrix during warmup.
        dense_mass: Whether to use a dense mass matrix.
        target_accept_prob: Target acceptance probability for step-size
            adaptation. NOTE(review): NumPyro's ``BarkerMH`` defaults to 0.4.
            The 0.234 used here is the classical random-walk Metropolis
            optimum.
    """

    def __init__(self, model, step_size=1.0, adapt_step_size=True, adapt_mass_matrix=True,
                 dense_mass=False, target_accept_prob=0.234): ...

HMC Variants and Extensions

class HMCGibbs:
    """
    HMC-within-Gibbs sampler.

    Sites listed in ``gibbs_sites`` are updated by a user-supplied Gibbs
    function. All remaining latent sites are updated by ``inner_kernel``.
    The Gibbs sites may be discrete or continuous.

    Args:
        inner_kernel: Inner HMC/NUTS kernel for the non-Gibbs sites.
        gibbs_fn: Callable invoked as ``gibbs_fn(rng_key=..., gibbs_sites=...,
            hmc_sites=...)``. Both site arguments are dicts mapping names to
            current values. It must return a dict of new values for the Gibbs
            sites and should use ``rng_key`` for all of its randomness.
        gibbs_sites: Names of the sites updated by ``gibbs_fn``.

    NOTE(review): NumPyro requires both ``gibbs_fn`` and ``gibbs_sites``. The
    ``None`` defaults here exist only in this stub.
    """

    def __init__(self, inner_kernel, gibbs_fn=None, gibbs_sites=None): ...

class DiscreteHMCGibbs:
    """
    HMC-within-Gibbs that performs Metropolis updates of the discrete latent sites.

    Continuous sites are updated by ``inner_kernel``. The discrete sites are
    detected from the model.

    Args:
        inner_kernel: Inner HMC/NUTS kernel for the continuous variables.
        modified: Whether to use the modified proposal, which always proposes
            a new value for the current discrete site. NOTE(review): NumPyro
            defaults this to False.
        gibbs_sites: NOTE(review): NumPyro's ``DiscreteHMCGibbs`` has no such
            parameter; it updates every discrete latent site.
        random_walk: Keyword-only. If False, draw each new discrete value from
            its conditional distribution (Gibbs). If True, draw it uniformly
            from the site's support. Defaults to False, as in NumPyro.
    """

    def __init__(self, inner_kernel, modified=True, gibbs_sites=None, *,
                 random_walk=False): ...

class HMCECS:
    """
    HMC with Energy Conserving Subsampling, for models with large datasets.

    Args:
        model: Model containing subsampled plates. NOTE(review): NumPyro's
            ``HMCECS`` instead wraps an ``inner_kernel`` (an ``HMC`` or
            ``NUTS`` instance built from the model):
            ``HMCECS(inner_kernel, *, num_blocks=1, proxy=None)``.
        step_size: NOTE(review): configured on the inner kernel in NumPyro.
        trajectory_length: NOTE(review): configured on the inner kernel in NumPyro.
        num_blocks: Number of blocks the subsample is partitioned into.
            A single block is refreshed per iteration.
        proxy: Likelihood proxy used as a control variate (e.g.
            ``HMCECS.taylor_proxy(reference_params)``), or None for naive
            subsampling.
    """

    def __init__(self, model, step_size=1.0, trajectory_length=1.0, num_blocks=1, proxy=None): ...

class MixedHMC:
    """
    Mixed Hamiltonian Monte Carlo (Zhou, 2020) for models with both discrete
    and continuous latent variables. It is unrelated to mixed numeric
    precision.

    Discrete sites are updated in between HMC steps on the continuous sites.

    Args:
        inner_kernel: Base ``HMC`` kernel for the continuous variables.
        target_accept_prob: NOTE(review): not a ``MixedHMC`` parameter in
            NumPyro; set it on the inner HMC kernel.
        trajectory_length: NOTE(review): not a ``MixedHMC`` parameter in
            NumPyro; set it on the inner HMC kernel.
        num_discrete_updates: Keyword-only. Number of discrete-variable
            updates per iteration. None means one per discrete latent
            variable.
        random_walk: Keyword-only. If False, draw discrete values from their
            conditionals. If True, draw them uniformly from the support.
        modified: Keyword-only. Whether to use the modified proposal, which
            always proposes a new value for the current discrete site.
    """

    def __init__(self, inner_kernel, target_accept_prob=0.8, trajectory_length=1.0, *,
                 num_discrete_updates=None, random_walk=False, modified=False): ...

Ensemble Methods

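Ensemble samplers evolve a population of walkers jointly, using the other walkers to inform each walker's proposal. In NumPyro they are run with `MCMC(..., num_chains=N, chain_method='vectorized')`, where `num_chains` is the number of walkers.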

class ESS:
    """
    Ensemble Slice Sampling: a gradient-free slice sampler that chooses its
    slice directions from a population of walkers.

    Run it through ``MCMC`` with ``chain_method='vectorized'``. The walkers
    are the chains, so ``num_chains`` must be at least 2.

    Args:
        model: Python callable containing NumPyro primitives.
        max_slice_size: NOTE(review): not a NumPyro ``ESS`` parameter. NumPyro
            instead bounds the stepping-out/shrinking loops with ``max_steps``
            and ``max_iter``.
        num_slices: NOTE(review): not a NumPyro ``ESS`` parameter.
        moves: Dictionary mapping move objects (e.g. ``ESS.DifferentialMove()``)
            to their selection probabilities.
    """

    def __init__(self, model, max_slice_size=float('inf'), num_slices=1, moves=None): ...

class AIES:
    """
    Affine-Invariant Ensemble Sampler: gradient-free Metropolis-Hastings whose
    proposals are informed by an ensemble of walkers.

    Run it through ``MCMC`` with ``chain_method='vectorized'``.

    Args:
        model: Python callable containing NumPyro primitives.
        num_ensembles: NOTE(review): not a NumPyro ``AIES`` parameter. The
            ensemble size is the ``num_chains`` argument of ``MCMC``.
        moves: Dictionary mapping move objects (e.g. ``AIES.StretchMove()``,
            ``AIES.DEMove()``) to their selection probabilities.
    """

    def __init__(self, model, num_ensembles=100, moves=None): ...

Variational Inference

Stochastic variational inference for approximate posterior computation.

Core SVI Infrastructure

class SVI:
    """
    Stochastic Variational Inference: fits a guide to a model by minimizing an
    ELBO-style loss.

    Args:
        model: Model function containing NumPyro primitives.
        guide: Guide (variational family) function, e.g. an autoguide.
        optim: Optimizer for the variational parameters (a ``numpyro.optim``
            optimizer or an optax transformation).
        loss: Objective to minimize, e.g. ``Trace_ELBO()``.
        num_particles: NOTE(review): in NumPyro this is set on the loss
            (``Trace_ELBO(num_particles=...)``), not on ``SVI``.
        stable_update: NOTE(review): in NumPyro this is an argument of ``run``,
            not of the constructor.

    NumPyro's constructor is ``SVI(model, guide, optim, loss, **static_kwargs)``.
    The ``static_kwargs`` are fixed keyword arguments passed to model and
    guide at every step.
    """

    def __init__(self, model, guide, optim, loss, num_particles=1, stable_update=False): ...

    def init(self, rng_key: Array, *args, init_params=None, **kwargs):
        """
        Create the initial ``SVIState`` (optimizer state and PRNG key).

        Args:
            rng_key: JAX PRNG key.
            *args: Arguments passed to model and guide.
            init_params: Optional initial values for the parameters.
            **kwargs: Keyword arguments passed to model and guide.
        """

    def get_params(self, svi_state) -> dict:
        """Return the current parameter values held in ``svi_state``."""

    def run(self, rng_key: Array, num_steps: int, *args, progress_bar: bool = True,
            stable_update: bool = False, **kwargs):
        """
        Initialize and run ``num_steps`` optimization steps.

        Args:
            rng_key: JAX PRNG key for the stochastic optimization.
            num_steps: Number of optimization steps.
            *args: Arguments passed to model and guide.
            progress_bar: Whether to display a progress bar.
            stable_update: If True, skip any update that yields a non-finite
                loss or state instead of applying it.
            **kwargs: Keyword arguments passed to model and guide.

        Returns:
            ``SVIRunResult`` holding the final ``params``, the final ``state``
            and the per-step ``losses``.
        """

    def evaluate(self, rng_key: Array, *args, **kwargs) -> float:
        """
        Return the loss at the current parameters without updating them.

        NOTE(review): NumPyro's signature is ``evaluate(svi_state, *args, **kwargs)``.
        Its first argument is the current ``SVIState``, not a PRNG key.
        """

    def step(self, rng_key: Array, *args, **kwargs) -> float:
        """
        Take a single optimization step.

        NOTE(review): NumPyro's signature is ``step(svi_state, *args, **kwargs)``,
        and it returns a ``(new_svi_state, loss)`` tuple rather than a float.
        Thread the returned state into the next call.
        """

class SVIRunResult:
    """
    Result object returned by ``SVI.run()``.

    NumPyro defines it as the namedtuple ``SVIRunResult(params, state, losses)``.
    """
    losses: Array  # loss value recorded at each optimization step
    params: dict   # final variational parameter values
    state: "SVIState"  # final SVI state (can be passed back to resume optimization)

ELBO Objectives

class ELBO:
    """
    Common base class for Evidence Lower BOund (ELBO) objectives used by SVI.

    Args:
        num_particles: How many Monte Carlo draws are averaged into one loss
            estimate.
        vectorize_particles: If True, evaluate the particles in a vectorized
            way rather than one after another.
        ignore_jit_warnings: If True, silence warnings emitted during JIT
            compilation.
    """

    def __init__(
        self,
        num_particles: int = 1,
        vectorize_particles: bool = False,
        ignore_jit_warnings: bool = False,
    ): ...

    def loss(
        self,
        rng_key: Array,
        param_map: dict,
        model: Callable,
        guide: Callable,
        *args,
        **kwargs,
    ) -> float: ...

class Trace_ELBO(ELBO):
    """The standard ELBO: a Monte Carlo estimate whose gradients use the reparameterization trick."""

    def __init__(
        self,
        num_particles: int = 1,
        vectorize_particles: bool = False,
        ignore_jit_warnings: bool = False,
    ): ...

class TraceEnum_ELBO(ELBO):
    """
    ELBO that sums out discrete latent variables by exact enumeration.

    Args:
        num_particles: Monte Carlo draws used for the continuous variables.
        max_plate_nesting: Deepest plate nesting the enumeration must respect.
        max_iarange_nesting: Older, deprecated spelling of ``max_plate_nesting``.
        strict_enumeration_warning: If True, warn about enumeration problems.
        vectorize_particles: If True, vectorize the particle computation.
        ignore_jit_warnings: If True, silence JIT warnings.
    """

    def __init__(
        self,
        num_particles: int = 1,
        max_plate_nesting: Optional[int] = None,
        max_iarange_nesting: Optional[int] = None,
        strict_enumeration_warning: bool = True,
        vectorize_particles: bool = False,
        ignore_jit_warnings: bool = False,
    ): ...

class TraceGraph_ELBO(ELBO):
    """
    ELBO whose gradient estimator is Rao-Blackwellized to reduce its variance.

    Args:
        num_particles: Monte Carlo draws averaged per estimate.
        vectorize_particles: If True, vectorize the particle computation.
        ignore_jit_warnings: If True, silence JIT warnings.
    """

    def __init__(
        self,
        num_particles: int = 1,
        vectorize_particles: bool = False,
        ignore_jit_warnings: bool = False,
    ): ...

class TraceMeanField_ELBO(ELBO):
    """
    ELBO for mean-field guides, i.e. guides whose latent sites are mutually
    independent.

    Where a closed-form KL divergence between a guide site and its prior site
    exists, it is used analytically. Otherwise a Monte Carlo estimate is used.
    Results may be incorrect if the guide does not satisfy the mean-field
    assumption.

    Args:
        num_particles: Monte Carlo draws averaged per estimate.
        vectorize_particles: Whether to vectorize over particles.
        ignore_jit_warnings: Whether to silence JIT warnings.
    """
    def __init__(self, num_particles: int = 1, vectorize_particles: bool = False,
                 ignore_jit_warnings: bool = False): ...

class RenyiELBO(ELBO):
    """
    Variational objective based on the Rényi alpha-divergence.

    Args:
        alpha: Order of the alpha-divergence. Values ``alpha < 1`` give a
            tighter bound than the usual ELBO. The standard ELBO is only the
            limit ``alpha -> 1``: NumPyro rejects ``alpha == 1`` and points
            to the ordinary ELBO (``Trace_ELBO``) for that case.
        num_particles: Number of particles used to form the bound. NumPyro's
            default is 2.
        vectorize_particles: Whether to vectorize over particles.
    """
    def __init__(self, alpha: float = 0.0, num_particles: int = 1,
                 vectorize_particles: bool = False): ...

Automatic Guide Generation

# Located in numpyro.infer.autoguide module

class AutoGuide:
    """
    Abstract base for variational guides generated automatically from a model.

    Subclasses choose the variational family. This base provides helpers to
    sample from and summarize the fitted approximation.
    """

    def __init__(
        self,
        model: Callable,
        prefix: str = "auto",
        init_loc_fn=None,
        create_plates=None,
    ): ...

    def sample_posterior(self, rng_key: Array, params: dict, sample_shape=()) -> dict:
        """Draw samples of batch shape ``sample_shape`` from the fitted approximate posterior."""

    def median(self, params: dict) -> dict:
        """Return the median of the approximate posterior."""

    def quantiles(self, params: dict, quantiles) -> dict:
        """Return the requested ``quantiles`` of the approximate posterior."""

class AutoNormal(AutoGuide):
    """
    Guide with an independent Normal for each latent site, i.e. a normal
    family with diagonal covariance.

    Args:
        model: Model function.
        prefix: String prepended to the guide's parameter names.
        init_loc_fn: Strategy that initializes the location parameters.
        init_scale: Starting value of the scale parameters.
        create_plates: Callable that creates plates for batched parameters.
    """

    def __init__(
        self,
        model: Callable,
        prefix: str = "auto",
        init_loc_fn=None,
        init_scale: float = 0.1,
        create_plates=None,
    ): ...

class AutoMultivariateNormal(AutoGuide):
    """
    Guide with a single multivariate Normal over all latents, using a full
    covariance matrix.

    Args:
        model: Model function.
        prefix: String prepended to the guide's parameter names.
        init_loc_fn: Strategy that initializes the location parameters.
        init_scale: Starting value of the scale parameters.
    """

    def __init__(
        self,
        model: Callable,
        prefix: str = "auto",
        init_loc_fn=None,
        init_scale: float = 0.1,
    ): ...

class AutoLowRankMultivariateNormal(AutoGuide):
    """
    Guide with a multivariate Normal whose covariance has a low-rank structure.

    Args:
        model: Model function.
        prefix: String prepended to the guide's parameter names.
        init_loc_fn: Initialization strategy.
        rank: Rank of the low-rank covariance factor.
        init_scale: Starting value of the scale parameter.
    """

    def __init__(
        self,
        model: Callable,
        prefix: str = "auto",
        init_loc_fn=None,
        rank: int = 1,
        init_scale: float = 0.1,
    ): ...

class AutoDiagonalNormal(AutoGuide):
    """
    Normal variational family with diagonal covariance over the flattened
    latent vector.

    This is a distinct guide, not an alias of ``AutoNormal``. It concatenates
    all unconstrained latent variables into one vector and models that vector
    with a single diagonal Normal. ``AutoNormal`` instead keeps a separate
    Normal per sample site. The two describe the same mean-field family but
    store their parameters differently.
    """

class AutoLaplaceApproximation(AutoGuide):
    """
    Gaussian approximation centred on the MAP estimate (Laplace approximation).

    Args:
        model: Model function.
        prefix: String prepended to the guide's parameter names.
        init_loc_fn: Initialization strategy.
    """

    def __init__(
        self,
        model: Callable,
        prefix: str = "auto",
        init_loc_fn=None,
    ): ...

class AutoDelta(AutoGuide):
    """
    Point-mass guide: learns a single value per latent, i.e. a MAP estimate.

    Args:
        model: Model function.
        prefix: String prepended to the guide's parameter names.
        init_loc_fn: Strategy that initializes the point estimates.
    """

    def __init__(
        self,
        model: Callable,
        prefix: str = "auto",
        init_loc_fn=None,
    ): ...

class AutoIAFNormal(AutoGuide):
    """
    Normalizing-flow guide: a Normal base distribution transformed by Inverse
    Autoregressive Flows.

    Args:
        model: Model function.
        prefix: String prepended to the guide's parameter names.
        init_loc_fn: Initialization strategy.
        num_flows: How many flow transformations are stacked.
        hidden_dims: Hidden layer sizes of the autoregressive networks.
        skip_connections: If True, add skip connections to those networks.
    """

    def __init__(
        self,
        model: Callable,
        prefix: str = "auto",
        init_loc_fn=None,
        num_flows: int = 3,
        hidden_dims=None,
        skip_connections: bool = False,
    ): ...

class AutoBNAFNormal(AutoGuide):
    """
    Normalizing-flow guide: a Normal base distribution transformed by Block
    Neural Autoregressive Flows.

    Args:
        model: Model function.
        prefix: String prepended to the guide's parameter names.
        init_loc_fn: Initialization strategy.
        num_flows: How many flow layers are stacked.
        hidden_factors: Multipliers that set the hidden layer sizes.
        residual: Whether (and how) residual connections are used.
    """

    def __init__(
        self,
        model: Callable,
        prefix: str = "auto",
        init_loc_fn=None,
        num_flows: int = 1,
        hidden_factors=None,
        residual=None,
    ): ...

class AutoSurrogateLikelihoodDAG(AutoGuide):
    """
    Surrogate-likelihood guide. It corresponds to NumPyro's
    ``AutoSurrogateLikelihoodDAIS``; NumPyro has no class named ``...DAG``.

    The guide combines a cheap, user-provided surrogate likelihood with
    Differentiable Annealed Importance Sampling (DAIS), which makes it
    compatible with data subsampling on large datasets.

    NOTE(review): NumPyro's constructor also requires a ``surrogate_model``
    argument, which this stub does not expose.
    """
    def __init__(self, model: Callable, prefix: str = "auto"): ...


# Same class under the name NumPyro actually exports.
AutoSurrogateLikelihoodDAIS = AutoSurrogateLikelihoodDAG

Initialization Strategies

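Initialization strategies choose the starting values of MCMC chains and of autoguide parameters. They are passed to MCMC kernels as `init_strategy` and to autoguides as `init_loc_fn`.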

def init_to_feasible(model: Callable, *model_args, **model_kwargs):
    """
    Initialize each latent site to an arbitrary feasible point in its support,
    ignoring the distribution's parameters.

    Args:
        model: Model function.
        *model_args: Positional arguments for the model.
        **model_kwargs: Keyword arguments for the model.

    Returns:
        An initialization function.

    NOTE(review): NumPyro's signature is ``init_to_feasible(site=None)``. It is
    a per-site strategy passed directly, e.g. ``NUTS(model,
    init_strategy=init_to_feasible)``, not a function of the model and its
    arguments.
    """

def init_to_mean(model: Callable, *model_args, **model_kwargs):
    """
    Initialize parameters to their prior means when those are available.

    NOTE(review): NumPyro's signature is ``init_to_mean(site=None)``, a per-site
    strategy passed as ``init_strategy``/``init_loc_fn``. For priors without a
    mean it falls back to ``init_to_median``.
    """

def init_to_median(model: Callable, *model_args, **model_kwargs):
    """
    Initialize parameters to their prior medians when those are available.

    NOTE(review): NumPyro's signature is ``init_to_median(site=None,
    num_samples=15)``. The median is estimated from ``num_samples`` prior
    draws, and priors that cannot be sampled fall back to ``init_to_uniform``.
    """

def init_to_sample(model: Callable, *model_args, **model_kwargs):
    """
    Initialize parameters to draws from their priors.

    NOTE(review): NumPyro's signature is ``init_to_sample(site=None)``, a
    per-site strategy. Priors that cannot be sampled fall back to
    ``init_to_uniform``.
    """

def init_to_uniform(model: Callable, radius: float = 2.0, *model_args, **model_kwargs):
    """
    Initialize parameters uniformly at random within ``(-radius, radius)`` in
    unconstrained space.

    Args:
        model: Model function.
        radius: Half-width of the uniform initialization interval in
            unconstrained space.

    NOTE(review): because ``radius`` precedes ``*model_args`` here, the first
    positional model argument would be bound to ``radius``. NumPyro's actual
    signature is ``init_to_uniform(site=None, radius=2)``, a per-site strategy
    such as ``init_strategy=init_to_uniform(radius=1.0)``.
    """

def init_to_value(values: dict):
    """
    Initialize parameters to user-specified values.

    Args:
        values: Dictionary mapping site names to initial values. Sites absent
            from it are initialized with ``init_to_uniform``.

    NOTE(review): NumPyro's signature is ``init_to_value(site=None, values={})``,
    so pass ``values`` by keyword: ``init_to_value(values={"x": 0.0})``.
    """

Utilities

Utility functions for inference and posterior analysis.

class Predictive:
    """
    Draws prior- or posterior-predictive samples from a model.

    Args:
        model: Model function.
        posterior_samples: Dictionary of posterior samples (e.g. from
            ``MCMC.get_samples()``). Omit it for prior predictive sampling.
        guide: Optional guide. When given, latent values are drawn from it.
        params: Parameters for ``guide`` (e.g. ``SVIRunResult.params``).
        num_samples: Number of samples to draw. Required when
            ``posterior_samples`` is not provided.
        return_sites: Names of the sites to return. By default NumPyro
            returns the sample sites not already in ``posterior_samples``.
        infer_discrete: Whether to sample discrete latent sites from their
            posterior, conditioned on the observations and the other values.
        parallel: Whether to vectorize prediction over samples with
            ``jax.vmap`` instead of processing them sequentially.
        batch_ndims: Number of leading batch dimensions in
            ``posterior_samples``, e.g. 2 for ``(num_chains, num_samples, ...)``.
            NOTE(review): NumPyro's default is None, meaning 0 when a guide is
            given and 1 otherwise.
        exclude_deterministic: Keyword-only. Whether to ignore deterministic
            sites present in ``posterior_samples`` (default True, as in NumPyro).
    """

    def __init__(self, model: Callable, posterior_samples: Optional[dict] = None,
                 guide: Optional[Callable] = None, params: Optional[dict] = None,
                 num_samples: Optional[int] = None, return_sites: Optional[list] = None,
                 infer_discrete: bool = False, parallel: bool = False, batch_ndims: int = 1,
                 *, exclude_deterministic: bool = True): ...

    def __call__(self, rng_key: Array, *args, **kwargs) -> dict:
        """
        Generate predictions.

        Args:
            rng_key: JAX PRNG key for sampling.
            *args: Arguments passed to the model.
            **kwargs: Keyword arguments passed to the model.

        Returns:
            Dictionary mapping site names to arrays of predicted values.
        """

def log_likelihood(model: Callable, posterior_samples: dict, *args,
                   parallel: bool = False, batch_ndims: int = 1, **kwargs) -> dict:
    """
    Compute the log likelihood of the observed sites for each posterior sample.

    Args:
        model: Model function.
        posterior_samples: Dictionary of posterior samples.
        *args: Arguments passed to the model.
        parallel: Keyword-only. Whether to vectorize over samples with
            ``jax.vmap`` instead of processing them sequentially.
        batch_ndims: Keyword-only. Number of leading batch dimensions in
            ``posterior_samples``: 0 for a single sample, 1 for
            ``(num_samples, ...)``, 2 for ``(num_chains, num_samples, ...)``.
        **kwargs: Keyword arguments passed to the model.

    Returns:
        Dictionary mapping each observed site name to its log-likelihood values.
    """

def render_model(
    model: Callable,
    model_args=(),
    model_kwargs=None,
    filename=None,
    render_distributions: bool = False,
    render_params: bool = False,
    hide_deterministic: bool = True,
):
    """
    Draw the model's structure as a graphical diagram.

    Args:
        model: Model function whose structure is drawn.
        model_args: Positional arguments used to trace the model.
        model_kwargs: Keyword arguments used to trace the model.
        filename: Where to write the rendered graph, if anywhere.
        render_distributions: If True, annotate nodes with their distributions.
        render_params: If True, include parameter nodes.
        hide_deterministic: If True, leave deterministic sites out.
    """

Reparameterization

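Reparameterizers transform individual sample sites to give inference a better-conditioned posterior geometry. They are applied with the `numpyro.handlers.reparam` handler, e.g. `reparam(model, config={"x": LocScaleReparam(0)})`.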

# Located in numpyro.infer.reparam module

class Reparam:
    """
    Abstract base for reparameterizers.

    An instance is called with a site's ``name``, its distribution ``fn`` and
    its observed value ``obs``, and returns a tuple describing the
    replacement site.
    """

    def __call__(self, name: str, fn, obs) -> tuple: ...

class LocScaleReparam(Reparam):
    """
    Reparameterizer for location-scale distributions that interpolates
    between the centered and non-centered forms.

    Args:
        centered: Amount of centering. 0 means fully non-centered, 1 means
            centered, and None lets the centering be chosen adaptively.
    """

    def __init__(self, centered: Optional[float] = None): ...

class TransformReparam(Reparam):
    """
    Reparameterizer for ``TransformedDistribution`` sites. It samples the base
    distribution at an auxiliary site and applies the transforms
    deterministically.

    Args:
        transform: Bijective transformation.
        suffix: Suffix appended to the site name to form the auxiliary site.

    NOTE(review): NumPyro's ``TransformReparam()`` takes no constructor
    arguments. The transforms come from the site's ``TransformedDistribution``,
    and the auxiliary site is named ``f"{name}_base"``.
    """
    def __init__(self, transform, suffix: str = "_base"): ...

class NeuTraReparam(Reparam):
    """
    Neural Transport reparameterizer. It uses a trained guide's learned
    transform to make the posterior geometry easier to sample.

    Args:
        guide: Trained neural guide whose transform is reused.
        params: Fitted parameters of that guide.
    """

    def __init__(self, guide: Callable, params: dict): ...

class CircularReparam(Reparam):
    """Reparameterizer for circular (angle-valued) latent variables."""

class ProjectedNormalReparam(Reparam):
    """Reparameterizer for sites that follow a projected normal distribution."""

class ImplicitReparam(Reparam):
    """Implicit reparameterizer aimed at complex posteriors (TODO confirm it exists in numpyro.infer.reparam)."""

class SplitReparam(Reparam):
    """
    Reparameterizer that splits a multivariate site along ``dim``.

    ``sections`` presumably lists the sizes of the resulting pieces (TODO confirm).
    """

    def __init__(self, sections: list, dim: int = -1): ...

class SymmetricSplitReparam(Reparam):
    """
    Symmetric variant of the split reparameterizer, splitting along ``dim``
    into ``sections``.
    """

    def __init__(self, sections: list, dim: int = -1): ...

Types

from typing import Optional, Union, Callable, Dict, Any, Tuple
from jax import Array
import jax.numpy as jnp

ArrayLike = Union[Array, jnp.ndarray, float, int]
# Union of the concrete kernels documented above. It is named
# ``MCMCKernelType`` because the bare name ``MCMCKernel`` is reused by the
# kernel base-interface class further down, which would otherwise silently
# replace this alias.
MCMCKernelType = Union[HMC, NUTS, SA, BarkerMH, HMCGibbs, DiscreteHMCGibbs, HMCECS, MixedHMC, ESS, AIES]
Optimizer = Any  # From optax or numpyro.optim
# Every listed class derives from ELBO, so ``ELBO`` alone would suffice; the
# subclasses are spelled out for readability.
LossFunction = Union[ELBO, Trace_ELBO, TraceEnum_ELBO, TraceGraph_ELBO, TraceMeanField_ELBO, RenyiELBO]
# NOTE(review): NumPyro's init strategies are per-site callables (roughly
# ``init_fn(site) -> value``); confirm this signature before relying on it.
InitFunction = Callable[[Array, tuple, dict], dict]

class SVIState:
    """
    State object for SVI optimization.

    NumPyro defines it as the namedtuple
    ``SVIState(optim_state, mutable_state, rng_key)``.
    """
    optim_state: Any  # optimizer state, which holds the variational parameters
    mutable_state: Any  # values of ``numpyro.mutable`` sites, or None if there are none
    rng_key: Array  # PRNG key for the next step

class SVIRunResult:
    """What ``SVI.run()`` hands back: loss history, fitted parameters and final state."""

    losses: Array  # one loss value per optimization step
    params: dict  # fitted parameter values
    state: SVIState  # SVI state after the last step

class MCMCState:
    """Snapshot of an MCMC kernel's internal state between transitions."""

    z: dict  # latent values at the current position
    potential_energy: float  # potential energy at ``z``
    z_grad: dict  # gradients at ``z``
    adapt_state: Any  # state of the warmup adaptation
    rng_key: Array  # PRNG key for the next transition
    
# Kernel interfaces
class MCMCKernel:
    """Interface that every MCMC kernel implements so it can be driven by ``MCMC``."""

    def init(
        self,
        rng_key: Array,
        num_warmup: int,
        init_params: dict,
        model_args: tuple,
        model_kwargs: dict,
    ) -> MCMCState:
        """Build the kernel's initial state."""
        ...

    def sample(self, state: MCMCState, model_args: tuple, model_kwargs: dict) -> MCMCState:
        """Advance ``state`` by one transition and return the new state."""
        ...

    def postprocess_fn(self, args: tuple, kwargs: dict) -> Callable:
        """Return a callable used to post-process the raw samples."""
        ...

Install with Tessl CLI

npx tessl i tessl/pypi-numpyro

docs

diagnostics.md

distributions.md

handlers.md

index.md

inference.md

optimization.md

primitives.md

utilities.md

tile.json