# Inference

NumPyro provides multiple inference algorithms for Bayesian posterior computation, including Markov Chain Monte Carlo (MCMC) samplers, variational inference methods, ensemble samplers, and specialized algorithms. All inference methods are built on JAX, which provides efficient automatic differentiation and JIT compilation.

## Capabilities

### MCMC Algorithms

Markov Chain Monte Carlo methods for sampling from posterior distributions.

#### Core MCMC Infrastructure

```python { .api }
class MCMC:
    """
    Wrapper class for Markov Chain Monte Carlo inference algorithms.

    Args:
        kernel: MCMC kernel (e.g., NUTS, HMC)
        num_warmup: Number of warmup steps
        num_samples: Number of samples to draw
        num_chains: Number of parallel chains
        postprocess_fn: Post-processing function for samples
        chain_method: Parallelization method ('parallel', 'sequential', 'vectorized')
        progress_bar: Whether to show progress bar
        jit_model_args: Whether to JIT compile model arguments

    NOTE(review): upstream NumPyro names the first argument ``sampler``, makes
    every later argument keyword-only, and also accepts ``thinning``; confirm
    against the installed version before relying on positional calls.
    """

    def __init__(self, kernel, num_warmup: int, num_samples: int, num_chains: int = 1,
                 postprocess_fn: Optional[Callable] = None, chain_method: str = 'parallel',
                 progress_bar: bool = True, jit_model_args: bool = False): ...

    def run(self, rng_key: Array, *args, extra_fields=(), init_params=None, **kwargs) -> None:
        """
        Run MCMC sampling.

        Args:
            rng_key: Random key for sampling
            *args: Arguments to pass to the model
            extra_fields: Additional fields to collect
            init_params: Initial parameter values
            **kwargs: Keyword arguments to pass to the model
        """

    def get_samples(self, group_by_chain: bool = False) -> dict:
        """
        Get posterior samples.

        Args:
            group_by_chain: Whether to group samples by chain

        Returns:
            Dictionary of posterior samples
        """

    def get_extra_fields(self, group_by_chain: bool = False) -> dict:
        """Get additional collected fields (e.g., diagnostics)."""

    def print_summary(self, prob: float = 0.9, exclude_deterministic: bool = True) -> None:
        """Print summary statistics of posterior samples."""
```

#### Hamiltonian Monte Carlo and Other MCMC Kernels

```python { .api }
class HMC:
    """
    Hamiltonian Monte Carlo kernel.

    Args:
        model: Python callable containing Pyro primitives
        step_size: Step size for leapfrog integrator
        num_steps: Number of leapfrog steps
        adapt_step_size: Whether to adapt step size during warmup
        adapt_mass_matrix: Whether to adapt mass matrix during warmup
        dense_mass: Whether to use dense mass matrix
        target_accept_prob: Target acceptance probability for step size adaptation
        trajectory_length: Alternative to num_steps, specifies trajectory length
        max_tree_depth: Maximum tree depth for trajectory building
        find_heuristic_step_size: Whether to find good initial step size
        forward_mode_differentiation: Whether to use forward-mode AD
        regularize_mass_matrix: Whether to regularize mass matrix

    NOTE(review): ``max_tree_depth`` only matters for tree-building samplers
    such as NUTS. Upstream ``HMC`` does not appear to accept it; confirm before
    documenting it as an HMC option.
    """

    def __init__(self, model, step_size=1.0, num_steps=None, adapt_step_size=True,
                 adapt_mass_matrix=True, dense_mass=False, target_accept_prob=0.8,
                 trajectory_length=None, max_tree_depth=10, find_heuristic_step_size=False,
                 forward_mode_differentiation=False, regularize_mass_matrix=True): ...

class NUTS:
    """
    No-U-Turn Sampler, an adaptive variant of HMC.

    Args:
        model: Python callable containing Pyro primitives
        step_size: Initial step size
        adapt_step_size: Whether to adapt step size during warmup
        adapt_mass_matrix: Whether to adapt mass matrix during warmup
        dense_mass: Whether to use dense mass matrix
        target_accept_prob: Target acceptance probability
        max_tree_depth: Maximum tree depth for trajectory building
        find_heuristic_step_size: Whether to find good initial step size
        forward_mode_differentiation: Whether to use forward-mode AD
        regularize_mass_matrix: Whether to regularize mass matrix
    """

    def __init__(self, model, step_size=1.0, adapt_step_size=True, adapt_mass_matrix=True,
                 dense_mass=False, target_accept_prob=0.8, max_tree_depth=10,
                 find_heuristic_step_size=False, forward_mode_differentiation=False,
                 regularize_mass_matrix=True): ...

class SA:
    """
    Sample Adaptive MCMC kernel, a gradient-free sampler.

    Args:
        model: Python callable containing Pyro primitives
        adapt_state_size: Size of the adaptive state, i.e. the number of
            points used to build the proposal distribution
        restart_interval: Interval for restarting annealing
        cooling_schedule: Temperature cooling schedule function

    NOTE(review): in NumPyro, ``SA`` stands for Sample Adaptive MCMC, not
    Simulated Annealing. ``restart_interval`` and ``cooling_schedule`` look
    like leftovers of the annealing description and do not appear in the
    upstream signature; confirm and drop them.
    """

    def __init__(self, model, adapt_state_size=None, restart_interval=100,
                 cooling_schedule=None): ...

class BarkerMH:
    """
    Barker Metropolis-Hastings kernel.

    Args:
        model: Python callable containing Pyro primitives
        step_size: Step size for proposals
        adapt_step_size: Whether to adapt step size
        adapt_mass_matrix: Whether to adapt mass matrix
        dense_mass: Whether to use dense mass matrix
        target_accept_prob: Target acceptance probability

    NOTE(review): upstream NumPyro documents a default ``target_accept_prob``
    of 0.4 for ``BarkerMH``, not 0.234; confirm against the installed version.
    """

    def __init__(self, model, step_size=1.0, adapt_step_size=True, adapt_mass_matrix=True,
                 dense_mass=False, target_accept_prob=0.234): ...
135
```
136
137
#### HMC Variants and Extensions
138
139
```python { .api }
140
class HMCGibbs:
    """
    HMC-within-Gibbs sampler.

    Combines an inner gradient-based kernel with a user-supplied Gibbs update
    for ``gibbs_sites`` (typically discrete latent variables).

    Args:
        inner_kernel: Inner MCMC kernel (e.g., NUTS, HMC)
        gibbs_fn: Gibbs sampling function for the Gibbs sites
        gibbs_sites: Names of the sites to sample with ``gibbs_fn``

    NOTE(review): upstream treats ``gibbs_fn`` and ``gibbs_sites`` as
    required; confirm whether the ``None`` defaults here are intentional.
    """

    def __init__(self, inner_kernel, gibbs_fn=None, gibbs_sites=None): ...

class DiscreteHMCGibbs:
    """
    Specialized HMC-within-Gibbs for models with discrete latent variables.

    Args:
        inner_kernel: Inner kernel for continuous variables
        modified: Whether to use modified proposal for discrete variables
        gibbs_sites: Sites to sample with discrete Gibbs

    NOTE(review): upstream ``DiscreteHMCGibbs(inner_kernel, *, random_walk=False,
    modified=False)`` defaults ``modified`` to False and selects the discrete
    sites itself instead of taking ``gibbs_sites``; confirm.
    """

    def __init__(self, inner_kernel, modified=True, gibbs_sites=None): ...

class HMCECS:
    """
    HMC with Energy Conserving Subsampling for large datasets.

    Args:
        model: Python callable containing Pyro primitives
        step_size: Step size for leapfrog integrator
        trajectory_length: Length of HMC trajectory
        num_blocks: Number of data blocks for subsampling
        proxy: Proxy function for likelihood approximation

    NOTE(review): upstream ``HMCECS`` wraps an ``inner_kernel`` (HMC/NUTS) and
    takes ``num_blocks``/``proxy`` as keyword-only options. ``step_size`` and
    ``trajectory_length`` are configured on that inner kernel; confirm.
    """

    def __init__(self, model, step_size=1.0, trajectory_length=1.0, num_blocks=1, proxy=None): ...

class MixedHMC:
    """
    Mixed Hamiltonian Monte Carlo for models with both discrete and continuous
    latent variables.

    Args:
        inner_kernel: Base HMC kernel
        target_accept_prob: Target acceptance probability
        trajectory_length: HMC trajectory length

    NOTE(review): this is Mixed HMC (Zhou, 2020); it has nothing to do with
    mixed-precision arithmetic. Upstream takes ``num_discrete_updates``,
    ``random_walk`` and ``modified`` instead of ``target_accept_prob`` and
    ``trajectory_length``, which belong to the inner HMC kernel; confirm.
    """

    def __init__(self, inner_kernel, target_accept_prob=0.8, trajectory_length=1.0): ...
```

### Ensemble Methods

Ensemble samplers, which update a population of interacting chains (walkers) jointly rather than running independent chains.

```python { .api }
class ESS:
    """
    Ensemble Slice Sampling.

    Args:
        model: Python callable containing Pyro primitives
        max_slice_size: Maximum size of slice
        num_slices: Number of slices per step
        moves: Dictionary of move types and probabilities

    NOTE(review): upstream ``ESS`` is configured through ``moves``,
    ``randomize_split``, ``max_steps`` and ``max_iter``. ``max_slice_size`` and
    ``num_slices`` do not appear there; confirm.
    """

    def __init__(self, model, max_slice_size=float('inf'), num_slices=1, moves=None): ...

class AIES:
    """
    Affine Invariant Ensemble Sampler.

    Args:
        model: Python callable containing Pyro primitives
        num_ensembles: Number of ensemble members
        moves: Dictionary of move types and their configurations

    NOTE(review): upstream ``AIES`` has no ``num_ensembles``; the ensemble size
    is the ``num_chains`` given to ``MCMC`` (with ``chain_method='vectorized'``).
    Confirm.
    """

    def __init__(self, model, num_ensembles=100, moves=None): ...
```

### Variational Inference

Stochastic variational inference (SVI) for approximate posterior computation.

#### Core SVI Infrastructure

```python { .api }
class SVI:
    """
    Stochastic Variational Inference.

    Args:
        model: Model function containing Pyro primitives
        guide: Guide (variational family) function
        optim: Optimizer for variational parameters
        loss: Loss function (ELBO variant)
        num_particles: Number of particles for gradient estimation
        stable_update: Whether to use numerically stable updates

    NOTE(review): upstream ``SVI(model, guide, optim, loss, **static_kwargs)``
    sets ``num_particles`` on the ELBO and ``stable_update`` only on ``run``.
    Its ``step``/``evaluate`` take an ``SVIState`` rather than an rng key, and
    ``step`` returns ``(svi_state, loss)``. Confirm before relying on the stub
    signatures below.
    """

    def __init__(self, model, guide, optim, loss, num_particles=1, stable_update=False): ...

    def run(self, rng_key: Array, num_steps: int, *args, progress_bar: bool = True,
            stable_update: bool = False, **kwargs) -> "SVIRunResult":
        """
        Run stochastic variational inference.

        Args:
            rng_key: Random key for stochastic optimization
            num_steps: Number of optimization steps
            *args: Arguments to pass to model and guide
            progress_bar: Whether to show progress bar
            stable_update: Whether to use numerically stable updates
            **kwargs: Keyword arguments to pass to model and guide

        Returns:
            SVIRunResult with losses and parameters
        """

    def evaluate(self, rng_key: Array, *args, **kwargs) -> float:
        """Evaluate the current loss."""

    def step(self, rng_key: Array, *args, **kwargs) -> float:
        """Take a single SVI step."""

class SVIRunResult:
    """Result object from SVI.run()."""
    losses: Array  # Loss values over optimization
    params: dict  # Final parameter values
    state: "SVIState"  # Final SVI state (see SVIState)
```

#### ELBO Objectives

```python { .api }
class ELBO:
    """
    Base class for Evidence Lower BOund objectives.

    Args:
        num_particles: Number of particles for Monte Carlo estimation
        vectorize_particles: Whether to vectorize over particles
        ignore_jit_warnings: Whether to ignore JIT compilation warnings

    NOTE(review): upstream NumPyro's ``ELBO`` may default
    ``vectorize_particles`` to True and does not take ``ignore_jit_warnings``
    (a Pyro option); confirm against the installed version.
    """

    def __init__(self, num_particles: int = 1, vectorize_particles: bool = False,
                 ignore_jit_warnings: bool = False): ...

    def loss(self, rng_key: Array, param_map: dict, model: Callable, guide: Callable,
             *args, **kwargs) -> float:
        """
        Estimate the objective to minimize (the negative ELBO).

        Args:
            rng_key: Random key for the Monte Carlo estimate
            param_map: Mapping from parameter names to current values
            model: Model function
            guide: Guide function
            *args: Arguments to pass to model and guide
            **kwargs: Keyword arguments to pass to model and guide
        """

class Trace_ELBO(ELBO):
    """Standard ELBO using Monte Carlo estimation with reparameterized gradients."""

    def __init__(self, num_particles: int = 1, vectorize_particles: bool = False,
                 ignore_jit_warnings: bool = False): ...

class TraceEnum_ELBO(ELBO):
    """
    ELBO with exact enumeration over discrete latent variables.

    Args:
        num_particles: Number of particles for continuous variables
        max_plate_nesting: Maximum nesting level for enumeration
        max_iarange_nesting: Deprecated alias for max_plate_nesting
        strict_enumeration_warning: Whether to warn about enumeration issues
        vectorize_particles: Whether to vectorize over particles
        ignore_jit_warnings: Whether to ignore JIT warnings

    NOTE(review): ``max_iarange_nesting`` is a deprecated Pyro spelling, and
    NumPyro's ``TraceEnum_ELBO`` may not accept it; confirm.
    """

    def __init__(self, num_particles: int = 1, max_plate_nesting: Optional[int] = None,
                 max_iarange_nesting: Optional[int] = None, strict_enumeration_warning: bool = True,
                 vectorize_particles: bool = False, ignore_jit_warnings: bool = False): ...

class TraceGraph_ELBO(ELBO):
    """
    ELBO using a Rao-Blackwellized gradient estimator.

    Args:
        num_particles: Number of particles for Monte Carlo estimation
        vectorize_particles: Whether to vectorize over particles
        ignore_jit_warnings: Whether to ignore JIT warnings
    """

    def __init__(self, num_particles: int = 1, vectorize_particles: bool = False,
                 ignore_jit_warnings: bool = False): ...

class TraceMeanField_ELBO(ELBO):
    """ELBO for mean-field variational families (mutually independent guide sites)."""

    def __init__(self, num_particles: int = 1, vectorize_particles: bool = False,
                 ignore_jit_warnings: bool = False): ...

class RenyiELBO(ELBO):
    """
    Rényi divergence-based ELBO for more robust variational inference.

    Args:
        alpha: Rényi divergence parameter (alpha=1 recovers standard ELBO)
        num_particles: Number of particles for Monte Carlo estimation
        vectorize_particles: Whether to vectorize over particles

    NOTE(review): with a single particle the Rényi estimate reduces to the
    standard single-sample ELBO for any alpha. Upstream may default
    ``num_particles`` to 2 for this reason; confirm.
    """

    def __init__(self, alpha: float = 0.0, num_particles: int = 1,
                 vectorize_particles: bool = False): ...
```

#### Automatic Guide Generation

```python { .api }
# Located in numpyro.infer.autoguide module

class AutoGuide:
    """
    Base class for automatic variational guides.

    Args:
        model: Model function
        prefix: Prefix for parameter names
        init_loc_fn: Initialization function for location parameters
        create_plates: Function to create plates for batched parameters
    """

    def __init__(self, model: Callable, prefix: str = "auto", init_loc_fn=None,
                 create_plates=None): ...

    def sample_posterior(self, rng_key: Array, params: dict, sample_shape=()) -> dict:
        """Sample from the approximate posterior."""

    def median(self, params: dict) -> dict:
        """Compute median of the approximate posterior."""

    def quantiles(self, params: dict, quantiles) -> dict:
        """Compute quantiles of the approximate posterior."""

class AutoNormal(AutoGuide):
    """
    Mean-field normal variational family (a normal with diagonal covariance).

    Args:
        model: Model function
        prefix: Prefix for parameter names
        init_loc_fn: Initialization function for location parameters
        init_scale: Initial scale for variational parameters
        create_plates: Function to create plates for batched parameters
    """

    def __init__(self, model: Callable, prefix: str = "auto", init_loc_fn=None,
                 init_scale: float = 0.1, create_plates=None): ...

class AutoMultivariateNormal(AutoGuide):
    """
    Multivariate normal variational family with full covariance matrix.

    Args:
        model: Model function
        prefix: Prefix for parameter names
        init_loc_fn: Initialization function for location parameters
        init_scale: Initial scale for variational parameters
    """

    def __init__(self, model: Callable, prefix: str = "auto", init_loc_fn=None,
                 init_scale: float = 0.1): ...

class AutoLowRankMultivariateNormal(AutoGuide):
    """
    Low-rank multivariate normal variational family.

    Args:
        model: Model function
        prefix: Prefix for parameter names
        init_loc_fn: Initialization function
        rank: Rank of low-rank approximation
        init_scale: Initial scale parameter
    """

    def __init__(self, model: Callable, prefix: str = "auto", init_loc_fn=None,
                 rank: int = 1, init_scale: float = 0.1): ...

class AutoDiagonalNormal(AutoGuide):
    """
    Normal variational family with diagonal covariance over the flattened,
    unconstrained latent vector.

    This is not an alias for ``AutoNormal``. ``AutoNormal`` keeps separate
    location and scale parameters per sample site, while this guide packs all
    latent sites into a single vector.
    """

class AutoLaplaceApproximation(AutoGuide):
    """
    Laplace approximation around MAP estimate.

    Args:
        model: Model function
        prefix: Prefix for parameter names
        init_loc_fn: Initialization function
    """

    def __init__(self, model: Callable, prefix: str = "auto", init_loc_fn=None): ...

class AutoDelta(AutoGuide):
    """
    Point estimate guide (MAP approximation).

    Args:
        model: Model function
        prefix: Prefix for parameter names
        init_loc_fn: Initialization function for point estimates
    """

    def __init__(self, model: Callable, prefix: str = "auto", init_loc_fn=None): ...

class AutoIAFNormal(AutoGuide):
    """
    Inverse Autoregressive Flow with normal base distribution.

    Args:
        model: Model function
        prefix: Prefix for parameter names
        init_loc_fn: Initialization function
        num_flows: Number of flow transformations
        hidden_dims: Hidden dimensions for autoregressive networks
        skip_connections: Whether to use skip connections
    """

    def __init__(self, model: Callable, prefix: str = "auto", init_loc_fn=None,
                 num_flows: int = 3, hidden_dims=None, skip_connections: bool = False): ...

class AutoBNAFNormal(AutoGuide):
    """
    Block Neural Autoregressive Flow with normal base distribution.

    Args:
        model: Model function
        prefix: Prefix for parameter names
        init_loc_fn: Initialization function
        num_flows: Number of flow layers
        hidden_factors: Hidden layer size factors
        residual: Whether to use residual connections
    """

    def __init__(self, model: Callable, prefix: str = "auto", init_loc_fn=None,
                 num_flows: int = 1, hidden_factors=None, residual=None): ...

class AutoSurrogateLikelihoodDAG(AutoGuide):
    """
    Surrogate likelihood guide for DAG models.

    NOTE(review): upstream NumPyro ships ``AutoSurrogateLikelihoodDAIS``
    (Differentiable Annealed Importance Sampling with a surrogate likelihood).
    "DAG" here may be a typo for "DAIS"; confirm the class name and purpose.
    """

    def __init__(self, model: Callable, prefix: str = "auto"): ...
```

### Initialization Strategies

Functions for initializing MCMC chains and variational parameters.

```python { .api }
def init_to_feasible(model: Callable, *model_args, **model_kwargs):
    """
    Initialize to feasible values within parameter constraints.

    Args:
        model: Model function
        *model_args: Arguments to the model
        **model_kwargs: Keyword arguments to the model

    Returns:
        Initialization function

    NOTE(review): upstream NumPyro init strategies have the form
    ``init_to_feasible(site=None)`` and are passed uncalled, e.g.
    ``NUTS(model, init_strategy=init_to_feasible)``. Confirm the
    ``(model, *model_args, **model_kwargs)`` form documented here.
    """

def init_to_mean(model: Callable, *model_args, **model_kwargs):
    """Initialize parameters to their prior means (when available)."""

def init_to_median(model: Callable, *model_args, **model_kwargs):
    """Initialize parameters to their prior medians (when available)."""

def init_to_sample(model: Callable, *model_args, **model_kwargs):
    """Initialize parameters to samples from their priors."""

def init_to_uniform(model: Callable, radius: float = 2.0, *model_args, **model_kwargs):
    """
    Initialize parameters uniformly within their support.

    Args:
        model: Model function
        radius: Radius for uniform initialization in unconstrained space
        *model_args: Arguments to the model
        **model_kwargs: Keyword arguments to the model

    NOTE(review): because ``radius`` is positional and precedes
    ``*model_args``, ``init_to_uniform(model, x)`` binds ``x`` to ``radius``
    rather than forwarding it to the model. Consider making ``radius``
    keyword-only. (Upstream uses ``init_to_uniform(site=None, radius=2)``.)
    """

def init_to_value(values: dict):
    """
    Initialize parameters to specified values.

    Args:
        values: Dictionary mapping parameter names to initial values
    """
```

### Utilities

Utility functions for inference and posterior analysis.

```python { .api }
class Predictive:
    """
    Utility for posterior and prior predictive sampling.

    Args:
        model: Model function
        posterior_samples: Dictionary of posterior samples (optional)
        guide: Guide function for variational inference (optional)
        params: Parameters for guide (when using variational inference)
        num_samples: Number of samples to draw
        return_sites: Sites to return in predictions
        infer_discrete: Whether to infer discrete latent variables
        parallel: Whether to run predictions in parallel
        batch_ndims: Number of batch dimensions in posterior samples
    """

    def __init__(self, model: Callable, posterior_samples: Optional[dict] = None,
                 guide: Optional[Callable] = None, params: Optional[dict] = None,
                 num_samples: Optional[int] = None, return_sites: Optional[list] = None,
                 infer_discrete: bool = False, parallel: bool = False, batch_ndims: int = 1): ...

    def __call__(self, rng_key: Array, *args, **kwargs) -> dict:
        """
        Generate predictions.

        Args:
            rng_key: Random key for sampling
            *args: Arguments to pass to model
            **kwargs: Keyword arguments to pass to model

        Returns:
            Dictionary of predicted values
        """

def log_likelihood(model: Callable, posterior_samples: dict, *args, **kwargs) -> dict:
    """
    Compute log likelihood of observations given posterior samples.

    Args:
        model: Model function
        posterior_samples: Dictionary of posterior samples
        *args: Arguments to pass to model
        **kwargs: Keyword arguments to pass to model

    Returns:
        Dictionary of log likelihood values for each observed site
    """

def render_model(model: Callable, model_args=(), model_kwargs=None, filename=None,
                 render_distributions: bool = False, render_params: bool = False,
                 hide_deterministic: bool = True):
    """
    Render model structure as a graphical diagram.

    Args:
        model: Model function to render
        model_args: Arguments to pass to model
        model_kwargs: Keyword arguments to pass to model
        filename: Output filename for rendered graph
        render_distributions: Whether to show distribution details
        render_params: Whether to show parameter nodes
        hide_deterministic: Whether to hide deterministic sites
    """
```

### Reparameterization

Reparameterization strategies for improving inference efficiency.

```python { .api }
# Located in numpyro.infer.reparam module

class Reparam:
    """Base class for reparameterizations."""

    def __call__(self, name: str, fn, obs) -> tuple: ...

class LocScaleReparam(Reparam):
    """
    Reparameterization for location-scale distributions.

    Args:
        centered: Parameterization type (0=non-centered, 1=centered, None=adaptive)
    """

    def __init__(self, centered: Optional[float] = None): ...

class TransformReparam(Reparam):
    """
    Reparameterization using bijective transforms.

    Args:
        transform: Bijective transformation
        suffix: Suffix for transformed variable names

    NOTE(review): upstream ``TransformReparam()`` takes no constructor
    arguments; it acts on sites whose distribution is already a
    ``TransformedDistribution``. Confirm ``transform``/``suffix``.
    """

    def __init__(self, transform, suffix: str = "_base"): ...

class NeuTraReparam(Reparam):
    """
    Neural Transport reparameterization.

    Args:
        guide: Neural guide for reparameterization
        params: Parameters for the guide
    """

    def __init__(self, guide: Callable, params: dict): ...

class CircularReparam(Reparam):
    """Reparameterization for circular variables."""

class ProjectedNormalReparam(Reparam):
    """Reparameterization for projected normal distributions."""

class ImplicitReparam(Reparam):
    """
    Implicit reparameterization for complex posteriors.

    NOTE(review): confirm that ``numpyro.infer.reparam`` actually exports
    this class before documenting it.
    """

class SplitReparam(Reparam):
    """
    Split reparameterization for multivariate distributions.

    Args:
        sections: Sizes of the pieces the site is split into
        dim: Dimension along which the site is split
    """

    def __init__(self, sections: list, dim: int = -1): ...

class SymmetricSplitReparam(Reparam):
    """
    Symmetric split reparameterization.

    NOTE(review): confirm that ``numpyro.infer.reparam`` actually exports
    this class before documenting it.
    """

    def __init__(self, sections: list, dim: int = -1): ...
```

## Types

```python { .api }
from typing import Optional, Union, Callable, Dict, Any, Tuple
from jax import Array
import jax.numpy as jnp

# Values accepted where an array is expected (JAX arrays or Python scalars).
ArrayLike = Union[Array, jnp.ndarray, float, int]
# Union of the concrete kernels documented above. It deliberately does not
# reuse the name ``MCMCKernel``: the base-interface class of that name defined
# later in this section would silently rebind it. String forward references
# keep this snippet importable without the kernel classes in scope.
AnyMCMCKernel = Union["HMC", "NUTS", "SA", "BarkerMH", "HMCGibbs", "DiscreteHMCGibbs",
                      "HMCECS", "MixedHMC", "ESS", "AIES"]
Optimizer = Any  # From optax or numpyro.optim
# Every ELBO variant subclasses ``ELBO``; the members are listed for readability.
LossFunction = Union["ELBO", "Trace_ELBO", "TraceEnum_ELBO", "TraceGraph_ELBO",
                     "TraceMeanField_ELBO", "RenyiELBO"]
# NOTE(review): upstream init strategies are called per sample site
# (e.g. ``init_to_median(site=None, ...)``); confirm this callable shape.
InitFunction = Callable[[Array, tuple, dict], dict]

class SVIState:
    """
    State object for SVI optimization.

    NOTE(review): upstream ``SVIState`` also carries a ``mutable_state``
    field between ``optim_state`` and ``rng_key``; confirm.
    """
    optim_state: Any  # Optimizer state holding the variational parameters
    rng_key: Array  # Random key for the next step

class SVIRunResult:
    """Result from SVI.run()."""
    losses: Array  # Loss values over optimization
    params: dict  # Final parameter values
    state: "SVIState"  # Final SVI state

class MCMCState:
    """Internal state for MCMC kernels."""
    z: dict  # Current parameter values
    potential_energy: float  # Potential energy at z
    z_grad: dict  # Current gradients
    adapt_state: Any  # Adaptation state
    rng_key: Array  # Random key for the next transition

# Kernel interfaces
class MCMCKernel:
    """Base interface for MCMC kernels."""

    def init(self, rng_key: Array, num_warmup: int, init_params: dict,
             model_args: tuple, model_kwargs: dict) -> "MCMCState":
        """Create the initial kernel state."""

    def sample(self, state: "MCMCState", model_args: tuple, model_kwargs: dict) -> "MCMCState":
        """Advance ``state`` by one MCMC transition."""

    def postprocess_fn(self, args: tuple, kwargs: dict) -> Callable:
        """Return a function that post-processes raw samples."""
```