Files: docs/ (diagnostics.md, distributions.md, handlers.md, index.md, inference.md, optimization.md, primitives.md, utilities.md), tile.json

tessl/pypi-numpyro

Lightweight probabilistic programming library with NumPy backend for Pyro, powered by JAX for automatic differentiation and JIT compilation.

  • Workspace: tessl
  • Visibility: Public
  • Describes: pkg:pypi/numpyro@0.19.x

To install, run

npx @tessl/cli install tessl/pypi-numpyro@0.19.0


NumPyro

NumPyro is a lightweight probabilistic programming library that provides a NumPy backend for Pyro, powered by JAX for automatic differentiation and JIT compilation on CPU, GPU, and TPU. It enables Bayesian modeling and statistical inference through MCMC algorithms such as Hamiltonian Monte Carlo (HMC) and the No-U-Turn Sampler (NUTS), variational inference methods, and a comprehensive distributions module. The library is designed for machine learning researchers and practitioners who need efficient probabilistic modeling that scales across hardware platforms.

Package Information

  • Package Name: numpyro
  • Package Type: pypi
  • Language: Python
  • Installation: pip install numpyro
  • Version: 0.19.0
  • License: Apache-2.0
  • Dependencies: JAX, JAXLib, NumPy, tqdm, multipledispatch

Core Imports

import numpyro

Common patterns for probabilistic modeling:

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO
from numpyro import sample, param, plate

JAX integration:

import jax
import jax.numpy as jnp
from jax import random

Basic Usage

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
import jax.numpy as jnp
from jax import random

# Define a simple Bayesian linear regression model
def linear_regression(X, y=None):
    # Priors
    alpha = numpyro.sample('alpha', dist.Normal(0, 10))
    beta = numpyro.sample('beta', dist.Normal(0, 10))
    sigma = numpyro.sample('sigma', dist.Exponential(1))
    
    # Linear model
    mu = alpha + beta * X
    
    # Likelihood
    with numpyro.plate('data', X.shape[0]):
        numpyro.sample('y', dist.Normal(mu, sigma), obs=y)

# Generate synthetic data
key = random.PRNGKey(0)
X = jnp.linspace(0, 1, 100)
true_alpha, true_beta = 1.0, 2.0
y = true_alpha + true_beta * X + 0.1 * random.normal(key, shape=(100,))

# Run MCMC inference
kernel = NUTS(linear_regression)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000)
mcmc.run(random.PRNGKey(1), X, y)

# Get posterior samples
samples = mcmc.get_samples()
print(f"Posterior mean for alpha: {jnp.mean(samples['alpha']):.3f}")
print(f"Posterior mean for beta: {jnp.mean(samples['beta']):.3f}")

Variational inference example:

from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal
import optax

# Define guide (variational family)
guide = AutoNormal(linear_regression)

# Set up SVI
optimizer = optax.adam(0.01)
svi = SVI(linear_regression, guide, optimizer, Trace_ELBO())

# Run variational inference
svi_result = svi.run(random.PRNGKey(2), 2000, X, y)

Architecture

NumPyro's architecture is built on several key design principles:

Effect Handler System

NumPyro uses Pyro-style effect handlers that act as context managers to intercept and modify the execution of probabilistic programs. This enables powerful model manipulation capabilities like conditioning on observed data, substituting values, and applying transformations.
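
A minimal sketch of the handler mechanism, using the numpyro.handlers.seed and numpyro.handlers.substitute context managers (the toy model here is illustrative):

import numpyro
import numpyro.distributions as dist
from numpyro import handlers

def toy_model():
    return numpyro.sample('x', dist.Normal(0.0, 1.0))

# seed supplies a PRNG key so sample sites can draw values
with handlers.seed(rng_seed=0):
    draw = toy_model()  # x drawn from Normal(0, 1)

# substitute intercepts the site and returns the supplied value instead
with handlers.seed(rng_seed=0), handlers.substitute(data={'x': 2.0}):
    pinned = toy_model()  # returns 2.0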

JAX Integration

Built on JAX, NumPyro leverages automatic differentiation, JIT compilation, and vectorization for high-performance numerical computing. This enables efficient gradient-based inference algorithms and scalable computations across CPU, GPU, and TPU.
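
For instance, a model's joint log density composes directly with jax.grad and jax.jit; this sketch uses numpyro.infer.util.log_density on a hypothetical one-parameter model:

import jax
import numpyro
import numpyro.distributions as dist
from numpyro.infer.util import log_density

def toy(y):
    mu = numpyro.sample('mu', dist.Normal(0.0, 1.0))
    numpyro.sample('y', dist.Normal(mu, 1.0), obs=y)

# log_density evaluates the joint log probability at the given latent values;
# jax.grad differentiates it and jax.jit compiles the whole computation
def logp(mu, y):
    return log_density(toy, (y,), {}, {'mu': mu})[0]

grad_logp = jax.jit(jax.grad(logp))
print(grad_logp(0.5, 1.2))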

Distribution Library

A comprehensive collection of 150+ probability distributions organized by type (continuous, discrete, conjugate, directional, mixture, truncated) with consistent interfaces and support for batching and broadcasting.
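
A short sketch of the batching and broadcasting behavior (batch_shape, event_shape, sample, log_prob, and to_event are part of the common Distribution interface):

import jax.numpy as jnp
import numpyro.distributions as dist
from jax import random

# Broadcasting the parameters yields a batch of three Normal distributions
d = dist.Normal(loc=jnp.array([0.0, 1.0, 2.0]), scale=1.0)
print(d.batch_shape, d.event_shape)  # (3,) ()

# sample_shape is prepended to the batch shape
x = d.sample(random.PRNGKey(0), sample_shape=(10,))
print(x.shape, d.log_prob(x).shape)  # (10, 3) (10, 3)

# to_event reinterprets batch dimensions as one multivariate event
print(dist.Normal(jnp.zeros(3), 1.0).to_event(1).event_shape)  # (3,)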

Inference Algorithms

Multiple inference backends (see the multi-chain sketch after this list), including:

  • MCMC: Hamiltonian Monte Carlo (HMC), No-U-Turn Sampler (NUTS), ensemble methods
  • Variational Inference: Stochastic Variational Inference (SVI) with automatic guide generation
  • Specialized methods: Nested sampling, Stein variational inference
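
A sketch of running several NUTS chains in parallel, reusing the linear_regression model, X, and y from Basic Usage; note that set_host_device_count must be called before any JAX computation:

import numpyro
numpyro.set_host_device_count(4)  # must run before any JAX computation

from jax import random
from numpyro.infer import MCMC, NUTS

kernel = NUTS(linear_regression)
mcmc = MCMC(kernel, num_warmup=500, num_samples=500, num_chains=4)
mcmc.run(random.PRNGKey(0), X, y)
mcmc.print_summary()  # per-site mean, std, quantiles, n_eff, r_hat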

Primitives and Control Flow

Core primitives (sample, param, plate) for model construction with support for probabilistic control flow through JAX's functional programming primitives.
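
As an example of probabilistic control flow, numpyro.contrib.control_flow.scan threads sample statements through a JAX-compatible loop; this Gaussian random-walk model is a sketch:

import numpyro
import numpyro.distributions as dist
from numpyro.contrib.control_flow import scan

def random_walk(num_steps=10):
    def transition(x_prev, _):
        x = numpyro.sample('x', dist.Normal(x_prev, 1.0))
        return x, x  # (new carry, per-step output)

    x0 = numpyro.sample('x0', dist.Normal(0.0, 1.0))
    # scan stacks the per-step draws along a leading time dimension
    _, xs = scan(transition, x0, None, length=num_steps)
    return xs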

Capabilities

Probabilistic Primitives

Core primitives for defining probabilistic models including sampling from distributions, defining parameters, and handling conditional independence through plates.

def sample(name: str, fn: Distribution, obs: Optional[ArrayLike] = None, 
          rng_key: Optional[Array] = None, sample_shape: tuple = (), 
          infer: Optional[dict] = None, obs_mask: Optional[ArrayLike] = None) -> ArrayLike
def param(name: str, init_value: Optional[Union[ArrayLike, Callable]] = None, 
         constraint: Constraint = constraints.real, event_dim: Optional[int] = None) -> ArrayLike
def plate(name: str, size: int, subsample_size: Optional[int] = None, 
         dim: Optional[int] = None) -> CondIndepStackFrame
def deterministic(name: str, value: ArrayLike) -> ArrayLike
def factor(name: str, log_factor: ArrayLike) -> None
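
A sketch tying these primitives together: deterministic records a derived quantity in the trace, and factor adds an arbitrary term to the joint log density (the ridge penalty here is purely illustrative):

import numpyro
import numpyro.distributions as dist

def model(X, y=None):
    beta = numpyro.sample('beta', dist.Normal(0.0, 1.0))
    # deterministic values are stored in the trace and posterior output
    mu = numpyro.deterministic('mu', beta * X)
    # factor contributes directly to the joint log density
    numpyro.factor('ridge_penalty', -0.5 * beta ** 2)
    with numpyro.plate('data', X.shape[0]):
        numpyro.sample('y', dist.Normal(mu, 1.0), obs=y)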

Primitives

Probability Distributions

Comprehensive collection of 150+ probability distributions across continuous, discrete, conjugate, directional, mixture, and truncated families with consistent interfaces and extensive parameterization options.

# Continuous distributions
class Normal(Distribution): ...
class Beta(Distribution): ...
class Gamma(Distribution): ...
class MultivariateNormal(Distribution): ...

# Discrete distributions  
class Bernoulli(Distribution): ...
class Categorical(Distribution): ...
class Poisson(Distribution): ...

# Specialized distributions
class Mixture(Distribution): ...
class TruncatedDistribution(Distribution): ...
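
As a concrete sketch, the mixture and truncated families are typically used through classes such as MixtureSameFamily and TruncatedNormal:

import jax.numpy as jnp
import numpyro.distributions as dist
from jax import random

# Two-component Gaussian mixture
mix = dist.MixtureSameFamily(
    dist.Categorical(probs=jnp.array([0.3, 0.7])),
    dist.Normal(loc=jnp.array([-2.0, 2.0]), scale=jnp.array([1.0, 0.5])),
)
samples = mix.sample(random.PRNGKey(0), (1000,))

# Normal truncated to the positive half-line
trunc = dist.TruncatedNormal(loc=0.0, scale=1.0, low=0.0)
print(trunc.log_prob(jnp.array(0.5)))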

Distributions

Inference Algorithms

Multiple inference backends including MCMC samplers, variational inference methods, and ensemble techniques for Bayesian posterior computation.

class MCMC:
    def __init__(self, kernel, num_warmup: int, num_samples: int, 
                num_chains: int = 1, postprocess_fn: Optional[Callable] = None): ...
    def run(self, rng_key: Array, *args, **kwargs) -> None: ...
    def get_samples(self, group_by_chain: bool = False) -> dict: ...

class SVI:
    def __init__(self, model, guide, optim, loss, **kwargs): ...
    def run(self, rng_key: Array, num_steps: int, *args, **kwargs): ...
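
Posterior predictive sampling is handled by numpyro.infer.Predictive; this sketch reuses the linear_regression model and the mcmc object from Basic Usage:

from numpyro.infer import Predictive
from jax import random

# Draw from the posterior predictive using the stored MCMC samples
predictive = Predictive(linear_regression, posterior_samples=mcmc.get_samples())
post_pred = predictive(random.PRNGKey(3), X)
print(post_pred['y'].shape)  # (num_samples, len(X))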

Inference

Effect Handlers

Pyro-style effect handlers for intercepting and modifying probabilistic program execution, enabling conditioning, substitution, masking, and other model transformations.

def trace(fn: Callable) -> Callable: ...
def replay(fn: Callable, trace: dict) -> Callable: ...
def condition(fn: Callable, data: dict) -> Callable: ...
def substitute(fn: Callable, data: dict) -> Callable: ...
def seed(fn: Callable, rng_seed: int) -> Callable: ...
def block(fn: Callable, hide_fn: Optional[Callable] = None, 
         expose_fn: Optional[Callable] = None, hide_all: bool = True) -> Callable: ...
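
A sketch of the wrapper style, again reusing linear_regression, X, and y from Basic Usage: condition fixes a site to observed data, seed supplies randomness, and trace records every site of the execution:

from numpyro import handlers
from jax import random

conditioned = handlers.condition(linear_regression, data={'alpha': 1.0})
exec_trace = handlers.trace(handlers.seed(conditioned, rng_seed=0)).get_trace(X, y)

for name, site in exec_trace.items():
    print(name, site['type'])  # e.g. alpha (sample), data (plate), y (sample)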

Handlers

Optimization

Collection of gradient-based optimizers for parameter learning in variational inference and maximum likelihood estimation.

class Adam:
    def __init__(self, step_size: float, b1: float = 0.9, b2: float = 0.999, eps: float = 1e-8): ...

class SGD:
    def __init__(self, step_size: float): ...

class Momentum:
    def __init__(self, step_size: float, mass: float): ...

class RMSProp:
    def __init__(self, step_size: float, gamma: float = 0.9, eps: float = 1e-8): ...
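
These numpyro.optim wrappers are a drop-in alternative to the optax optimizer used earlier; a sketch, again reusing linear_regression, X, and y from Basic Usage:

import numpyro
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal
from jax import random

guide = AutoNormal(linear_regression)
svi = SVI(linear_regression, guide, numpyro.optim.Adam(step_size=0.01), Trace_ELBO())
svi_result = svi.run(random.PRNGKey(4), 2000, X, y)
print(sorted(svi_result.params))  # learned variational parameters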

Optimization

Diagnostics

Diagnostic utilities for assessing MCMC convergence, effective sample size, and posterior summary statistics.

def effective_sample_size(x: NDArray) -> NDArray: ...
def gelman_rubin(x: NDArray) -> NDArray: ...
def split_gelman_rubin(x: NDArray) -> NDArray: ...
def hpdi(x: NDArray, prob: float = 0.9, axis: int = 0) -> NDArray: ...
def print_summary(samples: dict, prob: float = 0.9, group_by_chain: bool = True) -> None: ...
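
A sketch applying these diagnostics to the mcmc object from Basic Usage; the convergence statistics expect chain-grouped samples of shape (num_chains, num_samples, ...):

from numpyro.diagnostics import effective_sample_size, hpdi, print_summary

chain_samples = mcmc.get_samples(group_by_chain=True)
print_summary(chain_samples, prob=0.9)
print(effective_sample_size(chain_samples['beta']))

# 90% highest posterior density interval from the flattened samples
print(hpdi(mcmc.get_samples()['beta'], prob=0.9))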

Diagnostics

Utilities

JAX configuration utilities, control flow primitives, and helper functions for model development and debugging.

def enable_x64(use_x64: bool = True) -> None: ...
def set_platform(platform: Optional[str] = None) -> None: ...
def set_host_device_count(n: int) -> None: ...
def cond(pred, true_fun, false_fun, operand): ...
def while_loop(cond_fun, body_fun, init_val): ...
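
The configuration helpers are typically called once at program start, before any JAX computation; a sketch:

import numpyro

numpyro.enable_x64()               # use 64-bit floats for extra precision
numpyro.set_platform('cpu')        # or 'gpu' / 'tpu' when available
numpyro.set_host_device_count(4)   # expose 4 host devices for parallel chains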

Utilities

Types

from typing import Optional, Union, Callable, Dict, Any
import numpyro
from jax import Array
import jax.numpy as jnp

ArrayLike = Union[Array, jnp.ndarray, float, int]
NDArray = jnp.ndarray
Distribution = numpyro.distributions.Distribution
Constraint = numpyro.distributions.constraints.Constraint

class CondIndepStackFrame:
    name: str
    dim: int
    size: int
    subsample_size: Optional[int]

class Messenger:
    def __enter__(self): ...
    def __exit__(self, exc_type, exc_value, traceback): ...
    def process_message(self, msg: dict) -> None: ...