
# NumPyro


NumPyro is a lightweight probabilistic programming library that provides a NumPy backend for Pyro, powered by JAX for automatic differentiation and JIT compilation to GPU/TPU/CPU. It enables Bayesian modeling and statistical inference through MCMC algorithms such as Hamiltonian Monte Carlo and the No-U-Turn Sampler, variational inference methods, and a comprehensive distributions module. The library is designed for machine learning researchers and practitioners who need efficient probabilistic modeling and the ability to scale computations across different hardware platforms.


## Package Information

- **Package Name**: numpyro
- **Package Type**: pypi
- **Language**: Python
- **Installation**: `pip install numpyro`
- **Version**: 0.19.0
- **License**: Apache-2.0
- **Dependencies**: JAX, JAXLib, NumPy, tqdm, multipledispatch

## Core Imports

```python
import numpyro
```


Common patterns for probabilistic modeling:

```python
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO
from numpyro import sample, param, plate
```


JAX integration:

```python
import jax
import jax.numpy as jnp
from jax import random
```


## Basic Usage

```python
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
import jax.numpy as jnp
from jax import random

# Define a simple Bayesian linear regression model
def linear_regression(X, y=None):
    # Priors
    alpha = numpyro.sample('alpha', dist.Normal(0, 10))
    beta = numpyro.sample('beta', dist.Normal(0, 10))
    sigma = numpyro.sample('sigma', dist.Exponential(1))

    # Linear model
    mu = alpha + beta * X

    # Likelihood
    with numpyro.plate('data', X.shape[0]):
        numpyro.sample('y', dist.Normal(mu, sigma), obs=y)


# Generate synthetic data
key = random.PRNGKey(0)
X = jnp.linspace(0, 1, 100)
true_alpha, true_beta = 1.0, 2.0
y = true_alpha + true_beta * X + 0.1 * random.normal(key, shape=(100,))

# Run MCMC inference
kernel = NUTS(linear_regression)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000)
mcmc.run(random.PRNGKey(1), X, y)

# Get posterior samples
samples = mcmc.get_samples()
print(f"Posterior mean for alpha: {jnp.mean(samples['alpha']):.3f}")
print(f"Posterior mean for beta: {jnp.mean(samples['beta']):.3f}")
```


Variational inference example:

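```python
from jax import random
import optax  # optional dependency: pip install optax

from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal

# Reuses `linear_regression`, `X`, and `y` from the MCMC example above.

# Define guide (variational family)
guide = AutoNormal(linear_regression)

# Set up SVI
optimizer = optax.adam(0.01)
svi = SVI(linear_regression, guide, optimizer, Trace_ELBO())

# Run variational inference
svi_result = svi.run(random.PRNGKey(2), 2000, X, y)

# Learned variational parameters and the guide's posterior medians
params = svi_result.params
print(guide.median(params))
```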


## Architecture


NumPyro's architecture is built on several key design principles:


### Effect Handler System


NumPyro uses Pyro-style effect handlers that act as context managers to intercept and modify the execution of probabilistic programs. This enables powerful model manipulation capabilities like conditioning on observed data, substituting values, and applying transformations.


### JAX Integration


Built on JAX, NumPyro leverages automatic differentiation, JIT compilation, and vectorization for high-performance numerical computing. This enables efficient gradient-based inference algorithms and scalable computations across CPU, GPU, and TPU.


### Distribution Library


A comprehensive collection of 150+ probability distributions organized by type (continuous, discrete, conjugate, directional, mixture, truncated) with consistent interfaces and support for batching and broadcasting.


### Inference Algorithms

Multiple inference backends including:

- **MCMC**: Hamiltonian Monte Carlo (HMC), No-U-Turn Sampler (NUTS), ensemble methods
- **Variational Inference**: Stochastic Variational Inference (SVI) with automatic guide generation
- **Specialized methods**: Nested sampling, Stein variational inference

### Primitives and Control Flow


Core primitives (`sample`, `param`, `plate`) for model construction with support for probabilistic control flow through JAX's functional programming primitives.


## Capabilities


### Probabilistic Primitives


Core primitives for defining probabilistic models including sampling from distributions, defining parameters, and handling conditional independence through plates.

```python { .api }
def sample(name: str, fn: Distribution, obs: Optional[ArrayLike] = None,
           rng_key: Optional[Array] = None, sample_shape: tuple = (),
           infer: Optional[dict] = None, obs_mask: Optional[ArrayLike] = None) -> ArrayLike: ...
def param(name: str, init_value: Optional[Union[ArrayLike, Callable]] = None,
          constraint: Constraint = constraints.real, event_dim: Optional[int] = None) -> ArrayLike: ...
class plate(Messenger):
    def __init__(self, name: str, size: int, subsample_size: Optional[int] = None,
                 dim: Optional[int] = None): ...
    def __enter__(self) -> Array: ...  # returns the (possibly subsampled) indices
def deterministic(name: str, value: ArrayLike) -> ArrayLike: ...
def factor(name: str, log_factor: ArrayLike) -> None: ...
```


[Primitives](./primitives.md)


### Probability Distributions


Comprehensive collection of 150+ probability distributions across continuous, discrete, conjugate, directional, mixture, and truncated families with consistent interfaces and extensive parameterization options.

```python { .api }
# Continuous distributions
class Normal(Distribution): ...
class Beta(Distribution): ...
class Gamma(Distribution): ...
class MultivariateNormal(Distribution): ...

# Discrete distributions
# (Bernoulli and Categorical are factory functions that return a probs- or
# logits-parameterized Distribution subclass)
class Bernoulli(Distribution): ...
class Categorical(Distribution): ...
class Poisson(Distribution): ...

# Specialized distributions (also factory functions returning a Distribution)
class Mixture(Distribution): ...
class TruncatedDistribution(Distribution): ...
```


[Distributions](./distributions.md)


### Inference Algorithms


Multiple inference backends including MCMC samplers, variational inference methods, and ensemble techniques for Bayesian posterior computation.

```python { .api }
class MCMC:
    def __init__(self, sampler, *, num_warmup: int, num_samples: int,
                 num_chains: int = 1, postprocess_fn: Optional[Callable] = None): ...
    def run(self, rng_key: Array, *args, **kwargs) -> None: ...
    def get_samples(self, group_by_chain: bool = False) -> dict: ...

class SVI:
    def __init__(self, model, guide, optim, loss, **kwargs): ...
    def run(self, rng_key: Array, num_steps: int, *args, **kwargs): ...
```


[Inference](./inference.md)


### Effect Handlers


Pyro-style effect handlers for intercepting and modifying probabilistic program execution, enabling conditioning, substitution, masking, and other model transformations.

```python { .api }
def trace(fn: Callable) -> Callable: ...
def replay(fn: Callable, trace: dict) -> Callable: ...
def condition(fn: Callable, data: dict) -> Callable: ...
def substitute(fn: Callable, data: dict) -> Callable: ...
def seed(fn: Callable, rng_seed: int) -> Callable: ...
def block(fn: Callable, hide_fn: Optional[Callable] = None,
          expose_fn: Optional[Callable] = None, hide_all: bool = True) -> Callable: ...
```


[Handlers](./handlers.md)


### Optimization


Collection of gradient-based optimizers for parameter learning in variational inference and maximum likelihood estimation.

```python { .api }
# numpyro.optim
class Adam:
    def __init__(self, step_size: float, b1: float = 0.9, b2: float = 0.999, eps: float = 1e-8): ...

class SGD:
    def __init__(self, step_size: float): ...

class RMSProp:
    def __init__(self, step_size: float, gamma: float = 0.9, eps: float = 1e-8): ...
```


[Optimization](./optimization.md)


### Diagnostics


Diagnostic utilities for assessing MCMC convergence, effective sample size, and posterior summary statistics.

```python { .api }
def effective_sample_size(x: NDArray) -> NDArray: ...
def gelman_rubin(x: NDArray) -> NDArray: ...
def split_gelman_rubin(x: NDArray) -> NDArray: ...
def hpdi(x: NDArray, prob: float = 0.9, axis: int = 0) -> NDArray: ...
def print_summary(samples: dict, prob: float = 0.9, group_by_chain: bool = True) -> None: ...
```


[Diagnostics](./diagnostics.md)


### Utilities


JAX configuration utilities, control flow primitives, and helper functions for model development and debugging.

```python { .api }
def enable_x64(use_x64: bool = True) -> None: ...
def set_platform(platform: Optional[str] = None) -> None: ...
def set_host_device_count(n: int) -> None: ...
def cond(pred, true_operand, true_fun, false_operand, false_fun): ...
def while_loop(cond_fun, body_fun, init_val): ...
```


[Utilities](./utilities.md)


## Types

```python { .api }
from collections import namedtuple
from typing import Optional, Union, Callable, Dict, Any
from jax import Array
import jax.numpy as jnp
import numpyro.distributions

ArrayLike = Union[Array, jnp.ndarray, float, int]
NDArray = jnp.ndarray
Distribution = numpyro.distributions.Distribution
Constraint = numpyro.distributions.constraints.Constraint

# One frame per enclosing plate in a site's `cond_indep_stack`;
# `size` is the plate's (sub)sample size
CondIndepStackFrame = namedtuple("CondIndepStackFrame", ["name", "dim", "size"])

class Messenger:
    def __enter__(self): ...
    def __exit__(self, exc_type, exc_value, traceback): ...
    def process_message(self, msg: dict) -> None: ...
```