Lightweight probabilistic programming library with NumPy backend for Pyro, powered by JAX for automatic differentiation and JIT compilation.
npx @tessl/cli install tessl/pypi-numpyro@0.19.0
# NumPyro

NumPyro is a lightweight probabilistic programming library that provides a NumPy backend for Pyro. It is powered by JAX, which supplies automatic differentiation and JIT compilation to CPU, GPU, and TPU. NumPyro supports Bayesian modeling and statistical inference through MCMC algorithms such as Hamiltonian Monte Carlo (HMC) and the No-U-Turn Sampler (NUTS), variational inference methods, and a comprehensive distributions module. It is designed for machine learning researchers and practitioners who need efficient probabilistic modeling that scales across hardware platforms.

## Package Information

- **Package Name**: numpyro
- **Package Type**: pypi
- **Language**: Python
- **Installation**: `pip install numpyro`
- **Version**: 0.19.0
- **License**: Apache-2.0
- **Dependencies**: JAX, JAXLib, NumPy, tqdm, multipledispatch

## Core Imports

```python
import numpyro
```

Common patterns for probabilistic modeling:

```python
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO
from numpyro import sample, param, plate
```

JAX integration:

```python
import jax
import jax.numpy as jnp
from jax import random
```

## Basic Usage

```python
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
import jax.numpy as jnp
from jax import random

# Define a simple Bayesian linear regression model
def linear_regression(X, y=None):
    # Priors
    alpha = numpyro.sample('alpha', dist.Normal(0, 10))
    beta = numpyro.sample('beta', dist.Normal(0, 10))
    sigma = numpyro.sample('sigma', dist.Exponential(1))

    # Linear model
    mu = alpha + beta * X

    # Likelihood
    with numpyro.plate('data', X.shape[0]):
        numpyro.sample('y', dist.Normal(mu, sigma), obs=y)

# Generate synthetic data
key = random.PRNGKey(0)
X = jnp.linspace(0, 1, 100)
true_alpha, true_beta = 1.0, 2.0
y = true_alpha + true_beta * X + 0.1 * random.normal(key, shape=(100,))

# Run MCMC inference
kernel = NUTS(linear_regression)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000)
mcmc.run(random.PRNGKey(1), X, y)

# Get posterior samples
samples = mcmc.get_samples()
print(f"Posterior mean for alpha: {jnp.mean(samples['alpha']):.3f}")
print(f"Posterior mean for beta: {jnp.mean(samples['beta']):.3f}")
```

Variational inference example:

```python
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal
import optax

# Reuses linear_regression, X, y, and random from the MCMC example above

# Define guide (variational family)
guide = AutoNormal(linear_regression)

# Set up SVI
optimizer = optax.adam(0.01)
svi = SVI(linear_regression, guide, optimizer, Trace_ELBO())

# Run variational inference
svi_result = svi.run(random.PRNGKey(2), 2000, X, y)
```

## Architecture

NumPyro's architecture is built on several key design principles:

### Effect Handler System

NumPyro uses Pyro-style effect handlers that act as context managers to intercept and modify the execution of probabilistic programs. This enables powerful model manipulation capabilities like conditioning on observed data, substituting values, and applying transformations.

### JAX Integration

Built on JAX, NumPyro leverages automatic differentiation, JIT compilation, and vectorization for high-performance numerical computing. This enables efficient gradient-based inference algorithms and scalable computations across CPU, GPU, and TPU.

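The three JAX features named above can be seen in isolation on a small log-density function. This is a plain JAX sketch rather than NumPyro code; `normal_logpdf` is a helper written for this example.

```python
import jax
import jax.numpy as jnp


def normal_logpdf(x, loc, scale):
    """Log density of a univariate normal distribution."""
    z = (x - loc) / scale
    return -0.5 * z**2 - jnp.log(scale) - 0.5 * jnp.log(2.0 * jnp.pi)


# Automatic differentiation with respect to the first argument (x),
# compiled with JIT.
grad_logpdf = jax.jit(jax.grad(normal_logpdf))

# Vectorization: map over a batch of x values while sharing loc and scale.
batched_grad = jax.vmap(grad_logpdf, in_axes=(0, None, None))

xs = jnp.array([-1.0, 0.0, 2.0])
grads = batched_grad(xs, 0.0, 1.0)  # analytically, d/dx log N(x | 0, 1) = -x
print(grads)
```

Gradient-based samplers such as HMC and NUTS need exactly this kind of gradient of a log density, which is why the JAX foundation matters for inference speed.
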
### Distribution Library

A comprehensive collection of 150+ probability distributions organized by type (continuous, discrete, conjugate, directional, mixture, truncated) with consistent interfaces and support for batching and broadcasting.

### Inference Algorithms

Multiple inference backends including:

- **MCMC**: Hamiltonian Monte Carlo (HMC), No-U-Turn Sampler (NUTS), ensemble methods
- **Variational Inference**: Stochastic Variational Inference (SVI) with automatic guide generation
- **Specialized methods**: Nested sampling, Stein variational inference

### Primitives and Control Flow

Core primitives (`sample`, `param`, `plate`) for model construction with support for probabilistic control flow through JAX's functional programming primitives.

## Capabilities

### Probabilistic Primitives

Core primitives for defining probabilistic models including sampling from distributions, defining parameters, and handling conditional independence through plates.

```python { .api }
def sample(name: str, fn: Distribution, obs: Optional[ArrayLike] = None,
           rng_key: Optional[Array] = None, sample_shape: tuple = (),
           infer: Optional[dict] = None, obs_mask: Optional[ArrayLike] = None) -> ArrayLike
def param(name: str, init_value: Optional[Union[ArrayLike, Callable]] = None,
          constraint: Constraint = constraints.real, event_dim: Optional[int] = None) -> ArrayLike
def plate(name: str, size: int, subsample_size: Optional[int] = None,
          dim: Optional[int] = None) -> CondIndepStackFrame
def deterministic(name: str, value: ArrayLike) -> ArrayLike
def factor(name: str, log_factor: ArrayLike) -> None
```

[Primitives](./primitives.md)

### Probability Distributions

Comprehensive collection of 150+ probability distributions across continuous, discrete, conjugate, directional, mixture, and truncated families with consistent interfaces and extensive parameterization options.

```python { .api }
# Continuous distributions
class Normal(Distribution): ...
class Beta(Distribution): ...
class Gamma(Distribution): ...
class MultivariateNormal(Distribution): ...

# Discrete distributions
class Bernoulli(Distribution): ...
class Categorical(Distribution): ...
class Poisson(Distribution): ...

# Specialized distributions
class Mixture(Distribution): ...
class TruncatedDistribution(Distribution): ...
```

[Distributions](./distributions.md)

### Inference Algorithms

Multiple inference backends including MCMC samplers, variational inference methods, and ensemble techniques for Bayesian posterior computation.

```python { .api }
class MCMC:
    def __init__(self, kernel, num_warmup: int, num_samples: int,
                 num_chains: int = 1, postprocess_fn: Optional[Callable] = None): ...
    def run(self, rng_key: Array, *args, **kwargs) -> None: ...
    def get_samples(self, group_by_chain: bool = False) -> dict: ...

class SVI:
    def __init__(self, model, guide, optim, loss, **kwargs): ...
    def run(self, rng_key: Array, num_steps: int, *args, **kwargs): ...
```

[Inference](./inference.md)

### Effect Handlers

Pyro-style effect handlers for intercepting and modifying probabilistic program execution, enabling conditioning, substitution, masking, and other model transformations.

```python { .api }
def trace(fn: Callable) -> Callable: ...
def replay(fn: Callable, trace: dict) -> Callable: ...
def condition(fn: Callable, data: dict) -> Callable: ...
def substitute(fn: Callable, data: dict) -> Callable: ...
def seed(fn: Callable, rng_seed: int) -> Callable: ...
def block(fn: Callable, hide_fn: Optional[Callable] = None,
          expose_fn: Optional[Callable] = None, hide_all: bool = True) -> Callable: ...
```

[Handlers](./handlers.md)

### Optimization

Collection of gradient-based optimizers for parameter learning in variational inference and maximum likelihood estimation.

```python { .api }
class Adam:
    def __init__(self, step_size: float, b1: float = 0.9, b2: float = 0.999, eps: float = 1e-8): ...

class SGD:
    def __init__(self, step_size: float, momentum: float = 0): ...

class RMSProp:
    def __init__(self, step_size: float, decay: float = 0.9, eps: float = 1e-8): ...
```

[Optimization](./optimization.md)

### Diagnostics

Diagnostic utilities for assessing MCMC convergence, effective sample size, and posterior summary statistics.

```python { .api }
def effective_sample_size(x: NDArray) -> NDArray: ...
def gelman_rubin(x: NDArray) -> NDArray: ...
def split_gelman_rubin(x: NDArray) -> NDArray: ...
def hpdi(x: NDArray, prob: float = 0.9, axis: int = 0) -> NDArray: ...
def print_summary(samples: dict, prob: float = 0.9, group_by_chain: bool = True) -> None: ...
```

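For intuition about what a Gelman-Rubin style diagnostic measures, here is a NumPy sketch of the textbook R-hat formula. It is an assumption that this matches what `gelman_rubin` computes in every detail; `gelman_rubin_sketch` is a hypothetical helper, and the input is assumed to have shape `(num_chains, num_draws)`.

```python
import numpy as np


def gelman_rubin_sketch(x):
    """Textbook R-hat for draws of shape (num_chains, num_draws)."""
    num_chains, num_draws = x.shape
    within = x.var(axis=1, ddof=1).mean()  # W: mean within-chain variance
    between = x.mean(axis=1).var(ddof=1)   # B / N: variance of the chain means
    pooled = within * (num_draws - 1) / num_draws + between
    return np.sqrt(pooled / within)


rng = np.random.default_rng(0)
mixed = rng.normal(size=(4, 1000))     # four chains exploring the same target
stuck = mixed + np.arange(4)[:, None]  # chains centered at 0, 1, 2, and 3

rhat_mixed = gelman_rubin_sketch(mixed)
rhat_stuck = gelman_rubin_sketch(stuck)
print(rhat_mixed)  # close to 1 when the chains agree
print(rhat_stuck)  # well above 1 when the chains disagree
```

The statistic compares the variance between chain means to the variance within chains, so it only carries information when at least two chains are available.
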
[Diagnostics](./diagnostics.md)

### Utilities

JAX configuration utilities, control flow primitives, and helper functions for model development and debugging.

```python { .api }
def enable_x64(use_x64: bool = True) -> None: ...
def set_platform(platform: Optional[str] = None) -> None: ...
def set_host_device_count(n: int) -> None: ...
def cond(pred, true_operand, true_fun, false_operand, false_fun): ...
def while_loop(cond_fun, body_fun, init_val): ...
```

[Utilities](./utilities.md)

## Types

```python { .api }
from typing import Optional, Union, Callable, Dict, Any
from jax import Array
import jax.numpy as jnp
import numpyro

ArrayLike = Union[Array, jnp.ndarray, float, int]
NDArray = jnp.ndarray
Distribution = numpyro.distributions.Distribution
Constraint = numpyro.distributions.constraints.Constraint

class CondIndepStackFrame:
    name: str
    dim: int
    size: int
    subsample_size: Optional[int]

class Messenger:
    def __enter__(self): ...
    def __exit__(self, exc_type, exc_value, traceback): ...
    def process_message(self, msg: dict) -> None: ...
```