# Inference Methods

Scalable inference algorithms for posterior approximation and model learning, including variational inference, Markov Chain Monte Carlo, and specialized sampling methods for probabilistic programs.

## Capabilities

### Stochastic Variational Inference

Gradient-based variational inference for scalable approximate posterior computation.

```python { .api }
class SVI:
    """
    Stochastic Variational Inference for scalable posterior approximation.

    SVI optimizes variational parameters to minimize the KL divergence between
    a variational guide and the true posterior distribution.
    """

    def __init__(self, model, guide, optim, loss):
        """
        Initialize SVI with a model, guide, optimizer, and loss function.

        Parameters:
        - model (callable): Generative model function
        - guide (callable): Variational guide function that approximates the posterior
        - optim (PyroOptim): Pyro optimizer wrapping a PyTorch optimizer
        - loss (ELBO): Evidence lower bound loss function

        Examples:
        >>> svi = SVI(model, guide, Adam({"lr": 0.01}), Trace_ELBO())
        """

    def step(self, *args, **kwargs) -> float:
        """
        Perform one SVI optimization step.

        Parameters:
        - *args, **kwargs: Arguments to pass to the model and guide

        Returns:
        float: Loss value for this step (negative ELBO)

        Examples:
        >>> loss = svi.step(data)
        >>> print(f"Loss: {loss}")
        """

    def evaluate_loss(self, *args, **kwargs) -> float:
        """
        Evaluate the loss without taking an optimization step.

        Parameters:
        - *args, **kwargs: Arguments to pass to the model and guide

        Returns:
        float: Current loss value
        """

def init_to_feasible(site: dict = None) -> torch.Tensor:
    """
    Initialize parameters to feasible values within their constraints.

    Parameters:
    - site (dict, optional): Sample site information

    Returns:
    Tensor: Feasible initialization value
    """

def init_to_mean(site: dict = None) -> torch.Tensor:
    """
    Initialize parameters to the distribution mean.

    Parameters:
    - site (dict, optional): Sample site information

    Returns:
    Tensor: Mean initialization value
    """

def init_to_sample(site: dict = None) -> torch.Tensor:
    """
    Initialize parameters to random samples drawn from the prior.

    Parameters:
    - site (dict, optional): Sample site information

    Returns:
    Tensor: Random sample initialization
    """
```
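
The `init_to_*` strategies are typically supplied to an automatic guide rather than called directly. A minimal sketch, assuming Pyro's `pyro.infer.autoguide.AutoNormal` and its `init_loc_fn` argument (neither is documented above):

```python
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoNormal, init_to_mean  # assumed import path
from pyro.optim import Adam

def model(data):
    mu = pyro.sample("mu", dist.Normal(0.0, 10.0))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(mu, 1.0), obs=data)

# Start the guide's location parameters at the prior means instead of random draws
guide = AutoNormal(model, init_loc_fn=init_to_mean)
svi = SVI(model, guide, Adam({"lr": 0.01}), Trace_ELBO())
```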

### Evidence Lower Bound (ELBO)

Loss functions for variational inference based on the evidence lower bound.
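
All of the loss classes below estimate (the negative of) the same objective. For a model p(x, z) with latent variables z and a guide q(z), the bound is:

```latex
\log p(x) \;\ge\; \mathrm{ELBO} \;=\; \mathbb{E}_{q(z)}\big[\log p(x, z) - \log q(z)\big]
```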
```python { .api }
class ELBO:
    """
    Base class for evidence lower bound loss functions.

    ELBO provides a lower bound on the model evidence (marginal likelihood)
    and serves as the optimization objective for variational inference.
    """

    def differentiable_loss(self, model, guide, *args, **kwargs) -> torch.Tensor:
        """
        Compute a differentiable ELBO loss.

        Returns:
        Tensor: Negative ELBO (loss to minimize)
        """

class Trace_ELBO(ELBO):
    """
    Standard trace-based ELBO implementation.

    Uses execution traces to compute the ELBO as the log probability of the
    joint model minus the log probability of the guide.
    """

    def __init__(self, num_particles: int = 1, max_plate_nesting: int = float('inf'),
                 max_iarange_nesting: int = None, vectorize_particles: bool = False,
                 strict_enumeration_warning: bool = True):
        """
        Parameters:
        - num_particles (int): Number of Monte Carlo samples for gradient estimation
        - max_plate_nesting (int): Maximum depth of nested plates to vectorize over
        - vectorize_particles (bool): Whether to vectorize over particles
        - strict_enumeration_warning (bool): Whether to warn about enumeration issues
        """

class TraceEnum_ELBO(ELBO):
    """
    ELBO with exact enumeration over discrete latent variables.

    Computes exact expectations over discrete variables while using
    Monte Carlo estimation for continuous variables.
    """

    def __init__(self, max_plate_nesting: int = float('inf'), max_iarange_nesting: int = None,
                 strict_enumeration_warning: bool = True, ignore_jit_warnings: bool = False):
        """
        Parameters:
        - max_plate_nesting (int): Maximum plate nesting depth for enumeration
        - strict_enumeration_warning (bool): Whether to warn about enumeration issues
        - ignore_jit_warnings (bool): Whether to ignore JIT compilation warnings
        """

class TraceGraph_ELBO(ELBO):
    """
    Variance-reduced ELBO using dependency graphs.

    Exploits the conditional dependency structure of the model and guide to
    construct lower-variance gradient estimators, which is most useful for
    models with non-reparameterizable random variables.
    """

class TraceMeanField_ELBO(ELBO):
    """
    ELBO for mean-field variational inference.

    Assumes independence between latent variables in the guide,
    enabling more efficient computation.
    """

class RenyiELBO(ELBO):
    """
    Renyi divergence-based ELBO for more robust inference.

    Uses the Renyi alpha-divergence instead of the KL divergence, which
    can have better optimization properties for some models.
    """

    def __init__(self, alpha: float = 0.0, num_particles: int = 2, max_plate_nesting: int = float('inf')):
        """
        Parameters:
        - alpha (float): Order of the Renyi divergence; alpha -> 1 recovers the standard KL-based ELBO
        - num_particles (int): Number of particles for gradient estimation
        - max_plate_nesting (int): Maximum plate nesting depth
        """
```
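
Each ELBO variant drops into `SVI` as the `loss` argument. A minimal sketch of swapping estimators (`model` and `guide` as in the Basic SVI Training example below):

```python
from pyro.infer import SVI, Trace_ELBO, TraceMeanField_ELBO, RenyiELBO
from pyro.optim import Adam

# Standard Monte Carlo ELBO, averaged over several particles
svi = SVI(model, guide, Adam({"lr": 0.01}), Trace_ELBO(num_particles=4))

# Analytic KL terms where available; assumes a mean-field guide
svi_mf = SVI(model, guide, Adam({"lr": 0.01}), TraceMeanField_ELBO())

# Renyi-divergence bound; alpha != 1, with alpha -> 1 recovering the KL-based ELBO
svi_renyi = SVI(model, guide, Adam({"lr": 0.01}), RenyiELBO(alpha=0.5, num_particles=4))
```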

### Markov Chain Monte Carlo

MCMC methods for asymptotically exact sampling from posterior distributions.

```python { .api }
class MCMC:
    """
    Markov Chain Monte Carlo interface for posterior sampling.

    MCMC generates correlated samples whose distribution converges to the
    exact posterior, using kernel methods such as HMC and NUTS.
    """

    def __init__(self, kernel, num_samples: int, warmup_steps: int = None,
                 initial_params: dict = None, chain_id: int = 0, mp_context=None,
                 disable_progbar: bool = False, disable_validation: bool = True,
                 transforms: dict = None, max_tree_depth: int = None,
                 target_accept_prob: float = 0.8, jit_compile: bool = False):
        """
        Parameters:
        - kernel: MCMC kernel (e.g., HMC, NUTS, RandomWalkKernel)
        - num_samples (int): Number of MCMC samples to generate
        - warmup_steps (int): Number of warmup/burn-in steps
        - initial_params (dict): Initial parameter values
        - chain_id (int): Chain identifier for multiple chains
        - transforms (dict): Parameter transforms for constrained sampling
        - target_accept_prob (float): Target acceptance probability for adaptive kernels
        - jit_compile (bool): Whether to JIT compile the kernel

        Examples:
        >>> kernel = NUTS(model)
        >>> mcmc = MCMC(kernel, num_samples=1000, warmup_steps=500)
        """

    def run(self, *args, **kwargs):
        """
        Run the MCMC chain.

        Parameters:
        - *args, **kwargs: Arguments to pass to the model

        Examples:
        >>> mcmc.run(data)
        >>> samples = mcmc.get_samples()
        """

    def get_samples(self, group_by_chain: bool = False) -> dict:
        """
        Get MCMC samples after running the chain.

        Parameters:
        - group_by_chain (bool): Whether to group samples by chain

        Returns:
        dict: Dictionary mapping sample site names to sample tensors

        Examples:
        >>> samples = mcmc.get_samples()
        >>> theta_samples = samples["theta"]
        """

class HMC:
    """
    Hamiltonian Monte Carlo kernel.

    HMC uses gradient information to make efficient proposals in
    continuous parameter spaces.
    """

    def __init__(self, model, step_size: float = 1.0, num_steps: int = 1,
                 adapt_step_size: bool = True, adapt_mass_matrix: bool = True,
                 full_mass: bool = False, transforms: dict = None,
                 max_plate_nesting: int = None, jit_compile: bool = False,
                 jit_options: dict = None, ignore_jit_warnings: bool = False):
        """
        Parameters:
        - model (callable): Model to sample from
        - step_size (float): Integration step size
        - num_steps (int): Number of leapfrog steps per iteration
        - adapt_step_size (bool): Whether to adapt the step size during warmup
        - adapt_mass_matrix (bool): Whether to adapt the mass matrix
        - full_mass (bool): Whether to use a full mass matrix (vs. diagonal)
        - transforms (dict): Parameter transformations
        """

class NUTS:
    """
    No-U-Turn Sampler, an adaptive variant of HMC.

    NUTS automatically determines the number of leapfrog steps to take
    by detecting when the trajectory starts to reverse direction.
    """

    def __init__(self, model, step_size: float = 1.0, adapt_step_size: bool = True,
                 adapt_mass_matrix: bool = True, full_mass: bool = False,
                 transforms: dict = None, max_plate_nesting: int = None,
                 max_tree_depth: int = 10, target_accept_prob: float = 0.8,
                 jit_compile: bool = False, jit_options: dict = None,
                 ignore_jit_warnings: bool = False):
        """
        Parameters:
        - model (callable): Model to sample from
        - step_size (float): Initial step size
        - max_tree_depth (int): Maximum binary tree depth
        - target_accept_prob (float): Target acceptance probability for adaptation

        Examples:
        >>> nuts_kernel = NUTS(model)
        >>> mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=500)
        """

class RandomWalkKernel:
    """
    Random walk Metropolis-Hastings kernel.

    A simple MCMC kernel that proposes new states by adding random noise
    to the current state.
    """

    def __init__(self, model, step_size: dict = None, adapt_step_size: bool = True,
                 transforms: dict = None, max_plate_nesting: int = None):
        """
        Parameters:
        - model (callable): Model to sample from
        - step_size (dict): Step sizes for each parameter
        - adapt_step_size (bool): Whether to adapt the step size during warmup
        """
```
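
A minimal sketch of running the plain HMC kernel with explicit leapfrog settings, using only the parameters documented above (`model` and `data` as in the MCMC Sampling example below):

```python
from pyro.infer import MCMC, HMC

# Fixed-length leapfrog trajectories; the step size is still adapted during warmup
hmc_kernel = HMC(model, step_size=0.1, num_steps=20, adapt_step_size=True)
mcmc = MCMC(hmc_kernel, num_samples=1000, warmup_steps=500)
mcmc.run(data)
samples = mcmc.get_samples()
```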

### Predictive Sampling

Generate predictions and samples from trained models.

```python { .api }
class Predictive:
    """
    Generate predictive samples from posterior or prior distributions.

    Predictive enables posterior predictive checks, prior predictive checks,
    and out-of-sample predictions by sampling from the model under different
    parameter configurations.
    """

    def __init__(self, model, guide=None, posterior_samples: dict = None,
                 num_samples: int = None, return_sites: list = None,
                 parallel: bool = False, batch_ndims: int = 1):
        """
        Parameters:
        - model (callable): Generative model function
        - guide (callable, optional): Variational guide for posterior sampling
        - posterior_samples (dict, optional): Pre-computed posterior samples
        - num_samples (int, optional): Number of samples to generate
        - return_sites (list, optional): Sites to include in the output
        - parallel (bool): Whether to parallelize sampling
        - batch_ndims (int): Number of batch dimensions

        Examples:
        >>> # Posterior predictive with a guide
        >>> predictive = Predictive(model, guide=guide, num_samples=1000)
        >>> samples = predictive(data)
        >>>
        >>> # Prior predictive
        >>> predictive = Predictive(model, num_samples=100)
        >>> prior_samples = predictive(data)
        """

    def __call__(self, *args, **kwargs) -> dict:
        """
        Generate predictive samples.

        Parameters:
        - *args, **kwargs: Arguments to pass to the model

        Returns:
        dict: Dictionary mapping site names to sample tensors

        Examples:
        >>> samples = predictive(test_data)
        >>> predictions = samples["obs"]
        """

class WeighedPredictive:
    """
    Generate weighted predictive samples using importance sampling.

    Useful when posterior samples come from importance sampling or
    when samples carry non-uniform weights.
    """

    def __init__(self, model, guide=None, posterior_samples: dict = None,
                 weights: torch.Tensor = None, num_samples: int = None,
                 return_sites: list = None, parallel: bool = False):
        """
        Parameters:
        - model (callable): Generative model function
        - guide (callable, optional): Guide function
        - posterior_samples (dict, optional): Pre-computed samples
        - weights (Tensor, optional): Sample weights
        - num_samples (int, optional): Number of samples to generate
        """

class EmpiricalMarginal:
    """
    Empirical marginal distribution built from MCMC or SVI samples.

    Converts a collection of samples into a distribution object that
    can be used like any other Pyro distribution.
    """

    def __init__(self, samples: torch.Tensor, log_weights: torch.Tensor = None):
        """
        Parameters:
        - samples (Tensor): Sample values
        - log_weights (Tensor, optional): Log weights for the samples

        Examples:
        >>> samples = mcmc.get_samples()["theta"]
        >>> marginal = EmpiricalMarginal(samples)
        >>> new_sample = marginal.sample()
        """
```
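
Posterior predictive sampling also works directly from MCMC output by passing `posterior_samples` in place of a guide; a minimal sketch (`mcmc`, `model`, and `new_data` as in the examples below):

```python
from pyro.infer import Predictive

# Reuse the posterior draws collected by MCMC; no guide is needed
posterior_samples = mcmc.get_samples()
predictive = Predictive(model, posterior_samples=posterior_samples, return_sites=["obs"])
predictions = predictive(new_data)["obs"]
```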

### Importance Sampling

Importance sampling methods for model comparison and marginal likelihood estimation.

```python { .api }
class Importance:
    """
    Importance sampling for marginal likelihood estimation.

    Uses importance sampling to estimate the model evidence (marginal likelihood),
    which is useful for model comparison and selection.
    """

    def __init__(self, model, guide, num_samples: int):
        """
        Parameters:
        - model (callable): Generative model
        - guide (callable): Importance sampling distribution (proposal)
        - num_samples (int): Number of importance samples

        Examples:
        >>> importance = Importance(model, guide, num_samples=10000)
        >>> posterior = importance.run(data)
        >>> log_evidence = posterior.get_log_normalizer()
        """

    def run(self, *args, **kwargs):
        """
        Run importance sampling, collecting weighted execution traces.

        Parameters:
        - *args, **kwargs: Arguments to pass to model and guide

        Returns:
        Importance: self, with log weights populated
        """

    def get_log_normalizer(self) -> torch.Tensor:
        """
        Estimate the log marginal likelihood from the collected importance weights.

        Returns:
        Tensor: Log marginal likelihood estimate
        """

class SMCFilter:
    """
    Sequential Monte Carlo filtering for state space models.

    Implements particle filtering for sequential Bayesian inference
    in time series and state space models.
    """

    def __init__(self, model, guide, num_particles: int, max_plate_nesting: int):
        """
        Parameters:
        - model (callable): State space model
        - guide (callable): Proposal distribution for the particles
        - num_particles (int): Number of particles to maintain
        - max_plate_nesting (int): Maximum plate nesting depth
        """
```

### Specialized Inference Methods

Advanced inference algorithms for specific model types and scenarios.

```python { .api }
class SVGD:
    """
    Stein Variational Gradient Descent for non-parametric inference.

    SVGD optimizes a set of particles to approximate the posterior distribution
    by minimizing a kernelized Stein discrepancy.
    """

    def __init__(self, model, kernel, optimizer, num_particles: int):
        """
        Parameters:
        - model (callable): Model function
        - kernel: Kernel function for the Stein method
        - optimizer: Optimizer for particle updates
        - num_particles (int): Number of particles to optimize
        """

class ReweightedWakeSleep:
    """
    Reweighted Wake-Sleep algorithm for deep generative models.

    An alternative to standard variational inference that can handle
    more complex posterior approximations.
    """

    def __init__(self, model, guide, wake_loss, sleep_loss):
        """
        Parameters:
        - model (callable): Generative model
        - guide (callable): Recognition model
        - wake_loss: Loss function for the wake phase
        - sleep_loss: Loss function for the sleep phase
        """

def config_enumerate(default: str = None, expand: bool = False, num_samples: int = None):
    """
    Configure automatic enumeration over discrete latent variables.

    Decorator that enables exact marginalization over discrete variables
    in models with both discrete and continuous latent variables.

    Parameters:
    - default (str): Default enumeration strategy ("sequential" or "parallel")
    - expand (bool): Whether to expand enumerated dimensions
    - num_samples (int): Number of samples for approximate enumeration

    Examples:
    >>> @config_enumerate
    ... def model():
    ...     z = pyro.sample("z", dist.Categorical(torch.ones(3)))
    ...     return pyro.sample("x", dist.Normal(z, 1))
    """

def infer_discrete(first_available_dim: int = None, temperature: float = 1.0,
                   cooler: callable = None):
    """
    Infer discrete latent variables by enumeration or sampling.

    Effect handler that automatically handles discrete variable inference
    by choosing between exact enumeration and approximate sampling.

    Parameters:
    - first_available_dim (int): First tensor dimension available for enumeration
    - temperature (float): Temperature for discrete sampling (temperature=0 gives MAP values)
    - cooler (callable): Cooling schedule for simulated annealing

    Examples:
    >>> model_map = infer_discrete(first_available_dim=-1, temperature=0)(model)
    >>> z_map = model_map(data)  # discrete sites take their inferred values
    """
```
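
A minimal sketch of exact discrete marginalization during SVI, combining `config_enumerate` with `TraceEnum_ELBO`. The two-component mixture model and all parameter names here are illustrative:

```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate
from pyro.optim import Adam

@config_enumerate
def model(data):
    weights = pyro.sample("weights", dist.Dirichlet(torch.ones(2)))
    locs = pyro.sample("locs", dist.Normal(torch.zeros(2), 10.0).to_event(1))
    with pyro.plate("data", len(data)):
        z = pyro.sample("z", dist.Categorical(weights))  # enumerated exactly, never sampled
        pyro.sample("obs", dist.Normal(locs[z], 1.0), obs=data)

def guide(data):
    # The guide covers only the continuous latents; z is marginalized out by enumeration
    w_post = pyro.param("w_post", torch.ones(2), constraint=dist.constraints.positive)
    loc_post = pyro.param("loc_post", torch.zeros(2))
    pyro.sample("weights", dist.Dirichlet(w_post))
    pyro.sample("locs", dist.Normal(loc_post, 1.0).to_event(1))

svi = SVI(model, guide, Adam({"lr": 0.05}), TraceEnum_ELBO(max_plate_nesting=1))
loss = svi.step(torch.randn(100))  # synthetic data for illustration
```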

## Examples

### Basic SVI Training

```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

def model(data):
    mu = pyro.sample("mu", dist.Normal(0, 10))
    sigma = pyro.sample("sigma", dist.LogNormal(0, 1))

    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(mu, sigma), obs=data)

def guide(data):
    mu_q = pyro.param("mu_q", torch.tensor(0.0))
    sigma_q = pyro.param("sigma_q", torch.tensor(1.0), constraint=dist.constraints.positive)

    pyro.sample("mu", dist.Normal(mu_q, sigma_q))
    pyro.sample("sigma", dist.LogNormal(0, 1))  # Use the prior as the guide for sigma

# Training
data = torch.randn(100) + 3.0  # synthetic observations for illustration
svi = SVI(model, guide, Adam({"lr": 0.01}), Trace_ELBO())
losses = []
for step in range(1000):
    loss = svi.step(data)
    losses.append(loss)
```

### MCMC Sampling

```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS

def model(data):
    mu = pyro.sample("mu", dist.Normal(0, 10))
    sigma = pyro.sample("sigma", dist.LogNormal(0, 1))

    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(mu, sigma), obs=data)

# MCMC sampling
data = torch.randn(100) + 3.0  # synthetic observations for illustration
nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=500)
mcmc.run(data)

# Get samples
samples = mcmc.get_samples()
mu_samples = samples["mu"]
sigma_samples = samples["sigma"]
```
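
For a quick convergence check after `run`, recent Pyro versions expose a `summary` helper on the MCMC object (not part of the API listing above, so treat this as an assumption about your installed version):

```python
mcmc.summary()  # prints per-site mean, std, credible interval, n_eff, and r_hat
```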

### Posterior Predictive Checks

```python
from pyro.infer import Predictive

# After training SVI or MCMC
predictive = Predictive(model, guide=guide, num_samples=1000)
posterior_samples = predictive(data)

# Generate predictions for new data
predictive_new = Predictive(model, guide=guide, num_samples=100)
predictions = predictive_new(new_data)
```
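
To summarize the approximate posterior itself, latent sites can be requested explicitly via `return_sites` and then reduced with ordinary tensor ops; a minimal sketch (same `model`, `guide`, and `data` as above):

```python
import torch
from pyro.infer import Predictive

latent_predictive = Predictive(model, guide=guide, num_samples=1000,
                               return_sites=["mu", "sigma"])
draws = latent_predictive(data)

mu_draws = draws["mu"]
mu_mean = mu_draws.mean(dim=0)
mu_interval = mu_draws.quantile(torch.tensor([0.05, 0.95]), dim=0)  # 90% credible interval
```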

### Model Comparison with Importance Sampling

```python
import torch
from pyro.infer import Importance

# Compare two models: run() collects weighted traces, get_log_normalizer()
# turns them into a log evidence estimate
importance1 = Importance(model1, guide1, num_samples=10000)
log_evidence1 = importance1.run(data).get_log_normalizer()

importance2 = Importance(model2, guide2, num_samples=10000)
log_evidence2 = importance2.run(data).get_log_normalizer()

# Bayes factor
bayes_factor = torch.exp(log_evidence1 - log_evidence2)
print(f"Bayes factor (Model 1 vs Model 2): {bayes_factor}")
```