# PyMC Variational Inference

PyMC provides comprehensive variational inference methods for fast approximate Bayesian inference. Variational methods are particularly useful for large datasets and complex models where MCMC sampling may be computationally prohibitive.
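The core idea can be sketched without PyMC at all: pick a tractable family q, then adjust its parameters to minimize the KL divergence to the posterior. The following pure-Python sketch (all names hypothetical, not PyMC API) does this for a posterior that is itself Gaussian, so the answer is known:

```python
import math

# Target posterior (known here so we can check the answer): N(2.0, 0.5^2)
post_mu, post_sd = 2.0, 0.5

# Variational family q = N(m, exp(log_s)); minimize KL(q || p)
# by gradient descent using the closed-form Gaussian KL gradients.
m, log_s = 0.0, 0.0
lr = 0.05
for _ in range(2000):
    s = math.exp(log_s)
    grad_m = (m - post_mu) / post_sd**2       # dKL/dm
    grad_log_s = -1.0 + s**2 / post_sd**2     # dKL/dlog_s (chain rule)
    m -= lr * grad_m
    log_s -= lr * grad_log_s

print(round(m, 3), round(math.exp(log_s), 3))  # → 2.0 0.5
```

ADVI follows the same recipe, but estimates the gradients with automatic differentiation and Monte Carlo instead of closed forms.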
## Main Variational Interface

### Primary Fitting Function

```python { .api }
import pymc as pm

def fit(n=10000, method='advi', model=None, random_seed=None,
        start=None, inf_kwargs=None, **kwargs):
    """
    Fit variational approximation to the posterior.

    Parameters:
    - n (int): Number of optimization iterations (default: 10000)
    - method (str): Inference method ('advi', 'fullrank_advi', 'svgd', 'asvgd')
    - model: PyMC model (default: current context model)
    - random_seed (int): Random seed for reproducibility
    - start (dict): Starting parameter values
    - inf_kwargs (dict): Method-specific keyword arguments
    - **kwargs: Passed to the objective, e.g. obj_optimizer, callbacks

    Returns:
    - approximation: Fitted variational approximation object
    """

# Basic variational inference
with pm.Model() as model:
    # Define model...
    approx = pm.fit(n=50000)

    # Advanced configuration
    approx = pm.fit(
        n=100000,
        method='fullrank_advi',
        obj_optimizer=pm.adam(learning_rate=0.01),
        callbacks=[pm.callbacks.CheckParametersConvergence()],
        progressbar=True
    )
```
42
43
## Automatic Differentiation Variational Inference (ADVI)
44
45
### Mean-Field ADVI
46
47
The default variational inference method using mean-field approximation:
48
49
```python { .api }
50
from pymc.variational import ADVI
51
52
class ADVI:
53
"""
54
Automatic Differentiation Variational Inference with mean-field approximation.
55
56
Parameters:
57
- model: PyMC model
58
- random_seed (int): Random seed
59
- start (dict): Initial parameter values
60
61
Methods:
62
- fit: Optimize variational parameters
63
- sample: Draw samples from approximation
64
"""
65
66
def __init__(self, model=None, random_seed=None, start=None):
67
pass
68
69
def fit(self, n, optimizer=None, callbacks=None, progressbar=True, **kwargs):
70
"""
71
Fit the variational approximation.
72
73
Parameters:
74
- n (int): Number of optimization steps
75
- optimizer: Optimization algorithm
76
- callbacks (list): Callback functions
77
- progressbar (bool): Show progress bar
78
79
Returns:
80
- approximation: Fitted approximation
81
"""
82
pass
83
84
# Explicit ADVI usage
85
with pm.Model() as model:
86
# Model definition...
87
88
# Create ADVI inference object
89
inference = pm.ADVI()
90
91
# Fit approximation
92
approx = inference.fit(n=50000, optimizer=pm.adam(learning_rate=0.01))
93
94
# Draw samples from approximation
95
trace = approx.sample(2000)
96
```
97
98
### Full-Rank ADVI
99
100
ADVI with full covariance structure:
101
102
```python { .api }
103
from pymc.variational import FullRankADVI
104
105
class FullRankADVI:
106
"""
107
Full-rank ADVI with correlated posterior approximation.
108
109
Parameters:
110
- model: PyMC model
111
- random_seed (int): Random seed
112
"""
113
114
# Full-rank approximation for capturing correlations
115
with pm.Model() as model:
116
# Model with correlated parameters...
117
118
inference = pm.FullRankADVI()
119
approx = inference.fit(n=75000)
120
121
# Full covariance matrix available
122
cov_matrix = approx.cov.eval()
123
```
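Why full-rank matters: for a correlated posterior, the KL-optimal mean-field q matches the *conditional* variance of each parameter (the inverse of the precision-matrix diagonal), which is smaller than the true marginal variance. A pure-Python check for a 2-D Gaussian (illustrative, not PyMC code):

```python
rho = 0.9
cov = [[1.0, rho], [rho, 1.0]]  # posterior covariance with correlation 0.9

# Invert the 2x2 covariance to get the precision matrix Lambda
det = cov[0][0] * cov[1][1] - cov[0][1] * cov[1][0]
lam00 = cov[1][1] / det

marginal_var = cov[0][0]     # true posterior marginal variance
meanfield_var = 1.0 / lam00  # variance used by the optimal mean-field q

print(marginal_var, round(meanfield_var, 4))  # → 1.0 0.19
```

With rho = 0.9, mean-field assigns variance 1 − rho² = 0.19 where the marginal variance is 1.0, a roughly five-fold underestimate of uncertainty; full-rank ADVI avoids this by modeling the correlation.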
124
125
## Stein Variational Gradient Descent
126
127
### Standard SVGD
128
129
Particle-based variational inference:
130
131
```python { .api }
132
from pymc.variational import SVGD
133
134
class SVGD:
135
"""
136
Stein Variational Gradient Descent.
137
138
Parameters:
139
- n_particles (int): Number of particles (default: 100)
140
- jitter (float): Jitter for numerical stability
141
- model: PyMC model
142
"""
143
144
def __init__(self, n_particles=100, jitter=1e-6, model=None):
145
pass
146
147
# SVGD for complex posteriors
148
with pm.Model() as complex_model:
149
# Complex model definition...
150
151
inference = pm.SVGD(n_particles=200)
152
approx = inference.fit(n=20000)
153
154
# Particles represent the posterior
155
particles = approx.sample(1000)
156
```
157
158
### Amortized SVGD
159
160
```python { .api }
161
from pymc.variational import ASVGD
162
163
class ASVGD:
164
"""
165
Amortized Stein Variational Gradient Descent.
166
167
Parameters:
168
- n_particles (int): Number of particles
169
- batch_size (int): Mini-batch size
170
"""
171
172
# ASVGD for large datasets with mini-batching
173
with pm.Model() as large_model:
174
# Model with large dataset...
175
176
inference = pm.ASVGD(n_particles=50, batch_size=128)
177
approx = inference.fit(n=30000)
178
```
179
180
## Variational Approximations
181
182
### Mean-Field Approximation
183
184
Independent normal distributions for each parameter:
185
186
```python { .api }
187
from pymc.variational.approximations import MeanField
188
189
class MeanField:
190
"""
191
Mean-field approximation with independent normal distributions.
192
193
Parameters:
194
- local_rv (dict): Local random variables
195
- model: PyMC model
196
197
Methods:
198
- sample: Draw samples from approximation
199
- apply_replacements: Apply variational replacements
200
"""
201
202
def sample(self, draws=1000, include_transformed=True):
203
"""
204
Sample from mean-field approximation.
205
206
Parameters:
207
- draws (int): Number of samples to draw
208
- include_transformed (bool): Include transformed variables
209
210
Returns:
211
- samples: Dictionary of parameter samples
212
"""
213
pass
214
215
# Access approximation directly
216
with pm.Model() as model:
217
# Model definition...
218
219
# Create mean-field approximation
220
mean_field = pm.MeanField()
221
222
# Fit using KL divergence minimization
223
approx = pm.KLqp(mean_field).fit(n=50000)
224
```
225
226
### Full-Rank Approximation
227
228
Multivariate normal with full covariance:
229
230
```python { .api }
231
from pymc.variational.approximations import FullRank
232
233
class FullRank:
234
"""
235
Full-rank multivariate normal approximation.
236
237
Parameters:
238
- local_rv (dict): Local random variables
239
- model: PyMC model
240
241
Attributes:
242
- cov: Covariance matrix
243
- mean: Mean vector
244
"""
245
246
# Full-rank for capturing parameter correlations
247
with pm.Model() as correlated_model:
248
# Model with strong parameter correlations...
249
250
full_rank = pm.FullRank()
251
approx = pm.KLqp(full_rank).fit(n=75000)
252
253
# Access covariance structure
254
posterior_cov = approx.cov.eval()
255
posterior_corr = approx.std_to_corr(posterior_cov)
256
```
257
258
### Empirical Approximation
259
260
Empirical distribution from particle samples:
261
262
```python { .api }
263
from pymc.variational.approximations import Empirical
264
265
class Empirical:
266
"""
267
Empirical approximation using particle samples.
268
269
Parameters:
270
- local_rv (dict): Local random variables
271
- size (int): Number of particles
272
"""
273
274
# Empirical approximation from SVGD
275
with pm.Model() as model:
276
# Model definition...
277
278
empirical = pm.Empirical(size=500)
279
approx = pm.SVGD(approximation=empirical).fit(n=25000)
280
```
281
282
## Optimization Algorithms
283
284
### Adam Optimizer
285
286
```python { .api }
287
def adam(learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8):
288
"""
289
Adam optimizer for variational inference.
290
291
Parameters:
292
- learning_rate (float): Step size
293
- beta1 (float): Exponential decay rate for 1st moment
294
- beta2 (float): Exponential decay rate for 2nd moment
295
- epsilon (float): Small constant for numerical stability
296
297
Returns:
298
- optimizer: Adam optimizer object
299
"""
300
301
# Custom Adam configuration
302
optimizer = pm.adam(
303
learning_rate=0.005,
304
beta1=0.95,
305
beta2=0.999
306
)
307
308
approx = pm.fit(n=50000, optimizer=optimizer)
309
```
310
311
### Other Optimizers
312
313
```python { .api }
314
# Stochastic Gradient Descent
315
sgd_optimizer = pm.sgd(learning_rate=0.01)
316
317
# AdaGrad
318
adagrad_optimizer = pm.adagrad(learning_rate=0.1)
319
320
# RMSprop
321
rmsprop_optimizer = pm.rmsprop(learning_rate=0.001, decay=0.9)
322
323
# Adamax
324
adamax_optimizer = pm.adamax(learning_rate=0.002)
325
326
# AdaDelta
327
adadelta_optimizer = pm.adadelta(learning_rate=1.0, decay=0.95)
328
```
329
330
## Advanced Variational Methods
331
332
### Custom Inference Classes
333
334
```python { .api }
335
from pymc.variational.inference import KLqp, Inference
336
337
class KLqp(Inference):
338
"""
339
Kullback-Leibler divergence minimization.
340
341
Parameters:
342
- approx: Variational approximation
343
- beta (float): Regularization parameter
344
"""
345
346
def __init__(self, approx, beta=1.0):
347
pass
348
349
# Custom inference setup
350
with pm.Model() as model:
351
# Model definition...
352
353
# Custom approximation
354
custom_approx = pm.MeanField()
355
356
# KL divergence inference
357
inference = pm.KLqp(custom_approx, beta=0.9)
358
approx = inference.fit(n=40000)
359
```
360
361
### Implicit Gradient Methods
362
363
```python { .api }
364
from pymc.variational.inference import ImplicitGradient
365
366
class ImplicitGradient(Inference):
367
"""
368
Implicit gradient variational inference.
369
370
Parameters:
371
- approx: Variational approximation
372
- tk (float): Temperature parameter
373
"""
374
375
# Implicit gradient inference for difficult posteriors
376
with pm.Model() as difficult_model:
377
# Model with complex geometry...
378
379
implicit = pm.ImplicitGradient(pm.MeanField(), tk=1.5)
380
approx = implicit.fit(n=60000)
381
```
382
383
## Variational Groups and Structured Approximations
384
385
### Grouping Variables
386
387
```python { .api }
388
from pymc.variational.opvi import Group
389
390
class Group:
391
"""
392
Group variables for structured approximations.
393
394
Parameters:
395
- group_vars (list): Variables in the group
396
- approximation: Group-specific approximation
397
"""
398
399
# Group correlated parameters together
400
with pm.Model() as hierarchical_model:
401
# Hierarchical model...
402
403
# Group 1: Hyperparameters (mean-field)
404
hyper_group = pm.Group([mu_alpha, sigma_alpha], pm.MeanField())
405
406
# Group 2: Group effects (full-rank)
407
group_effects = pm.Group([alpha], pm.FullRank())
408
409
# Combined approximation
410
approximation = hyper_group + group_effects
411
approx = pm.KLqp(approximation).fit(n=50000)
412
```
413
414
## Callbacks and Monitoring
415
416
### Built-in Callbacks
417
418
```python { .api }
419
# Parameter convergence monitoring
420
convergence_cb = pm.CheckParametersConvergence(tolerance=0.01)
421
422
# Early stopping
423
early_stop_cb = pm.CheckParametersConvergence(tolerance=0.001, patience=5000)
424
425
# Custom callback function
426
def custom_callback(approx, loss_history, i):
427
if i % 1000 == 0:
428
current_loss = loss_history[-1]
429
print(f"Iteration {i}: Loss = {current_loss:.4f}")
430
431
# Use callbacks during fitting
432
approx = pm.fit(
433
n=50000,
434
callbacks=[convergence_cb, custom_callback],
435
progressbar=True
436
)
437
```
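A callback can end optimization early by raising `StopIteration` (this is how the built-in convergence check stops `pm.fit`). The driver loop looks roughly like the following schematic (not PyMC source; names are illustrative):

```python
# Schematic of the fit loop's callback protocol: each callback receives
# (approx, loss_history, iteration) and may raise StopIteration.
def run_fit(n, step, callbacks):
    losses = []
    for i in range(n):
        losses.append(step(i))
        try:
            for cb in callbacks:
                cb(None, losses, i)
        except StopIteration:
            break
    return losses

# Toy "optimizer": loss decays geometrically
step = lambda i: 100.0 * (0.99 ** i)

def stop_when_flat(approx, losses, i):
    if i > 0 and abs(losses[-1] - losses[-2]) < 1e-3:
        raise StopIteration

history = run_fit(10000, step, [stop_when_flat])
print(len(history))  # stops long before 10000 iterations
```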
438
439
## Sampling from Approximations
440
441
### Drawing Samples
442
443
```python { .api }
444
def sample_approx(n, approximation, more_replacements=None,
445
return_inferencedata=True, **kwargs):
446
"""
447
Sample from variational approximation.
448
449
Parameters:
450
- n (int): Number of samples
451
- approximation: Fitted approximation
452
- more_replacements (dict): Additional variable replacements
453
- return_inferencedata (bool): Return ArviZ InferenceData
454
455
Returns:
456
- samples: Samples from approximation
457
"""
458
459
# Sample from fitted approximation
460
samples = pm.sample_approx(n=5000, approximation=approx)
461
462
# Sample with additional replacements
463
custom_samples = pm.sample_approx(
464
n=3000,
465
approximation=approx,
466
more_replacements={'custom_var': custom_replacement}
467
)
468
```
469
470
### Integration with MCMC
471
472
```python { .api }
473
# Use variational approximation to initialize MCMC
474
with pm.Model() as model:
475
# Model definition...
476
477
# Fit variational approximation
478
approx = pm.fit(n=30000)
479
480
# Use as MCMC initialization
481
vi_samples = approx.sample(1000)
482
start_point = {var: samples[var][-1] for var, samples in vi_samples.items()}
483
484
# MCMC with VI initialization
485
mcmc_trace = pm.sample(initvals=start_point, tune=1000, draws=2000)
486
```
## Model Comparison and Diagnostics

### ELBO Monitoring

```python { .api }
# Track the objective (negative ELBO) during optimization
with pm.Model() as model:
    # Model definition...

    # Fit with loss tracking
    approx = pm.fit(n=50000, progressbar=True)

# Access the loss history (negative ELBO)
loss_history = approx.hist

# Plot convergence
import matplotlib.pyplot as plt
plt.plot(loss_history)
plt.xlabel('Iteration')
plt.ylabel('Loss (negative ELBO)')
plt.title('Variational Inference Convergence')
```
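Eyeballing the loss curve can be automated with a simple flatness check: compare the mean loss over the last two windows of the history. This is a hypothetical helper, not a PyMC function:

```python
def loss_converged(hist, window=100, rel_tol=0.01):
    """Compare mean loss of the last two windows; a flat tail means converged."""
    if len(hist) < 2 * window:
        return False
    prev = sum(hist[-2 * window:-window]) / window
    last = sum(hist[-window:]) / window
    return abs(prev - last) <= rel_tol * abs(prev)

# Synthetic histories: one still decreasing, one flat with small noise
decreasing = [1000.0 / (i + 1) for i in range(50)]
flat = [50.0 + 0.001 * (i % 2) for i in range(400)]

print(loss_converged(decreasing), loss_converged(flat))  # → False True
```

In practice this would be run on `approx.hist` after fitting, or wrapped in a callback to stop early.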
510
511
### Approximation Quality Assessment
512
513
```python { .api }
514
# Compare VI approximation with true posterior (if available)
515
def assess_approximation_quality(approx, true_trace, var_names):
516
"""Compare VI approximation with MCMC samples."""
517
vi_samples = approx.sample(5000)
518
519
for var in var_names:
520
vi_mean = vi_samples[var].mean()
521
vi_std = vi_samples[var].std()
522
523
mcmc_mean = true_trace[var].mean()
524
mcmc_std = true_trace[var].std()
525
526
print(f"{var}:")
527
print(f" VI: mean={vi_mean:.3f}, std={vi_std:.3f}")
528
print(f" MCMC: mean={mcmc_mean:.3f}, std={mcmc_std:.3f}")
529
530
# Usage
531
assess_approximation_quality(approx, mcmc_trace, ['alpha', 'beta', 'sigma'])
532
```
## Large-Scale Variational Inference

### Mini-batch Variational Inference

```python { .api }
# Mini-batch VI for large datasets
with pm.Model() as large_scale_model:
    # Mini-batched views of the data
    X_mb, y_mb = pm.Minibatch(X_large, y_large, batch_size=256)

    # Model with mini-batched data
    alpha = pm.Normal('alpha', 0, 1)
    beta = pm.Normal('beta', 0, 1, shape=p)
    mu = alpha + pm.math.dot(X_mb, beta)

    # total_size rescales the likelihood so each batch
    # stands in for the full dataset
    n_total = X_large.shape[0]
    y_obs = pm.Normal('y_obs', mu=mu, sigma=1, observed=y_mb,
                      total_size=n_total)

    # Variational inference with mini-batches
    approx = pm.fit(n=100000, method='advi')
```
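Why `total_size` matters: the mini-batch ELBO replaces the full-data log-likelihood with `(N / B) * (batch log-likelihood)`, an unbiased estimate of the full sum. A pure-Python check (illustrative, not PyMC internals) that averaging the scaled estimate over all disjoint batches recovers the full-data value exactly:

```python
import math

def normal_logpdf(x, mu=0.0, sigma=1.0):
    return -0.5 * math.log(2 * math.pi * sigma**2) - (x - mu) ** 2 / (2 * sigma**2)

data = [0.1 * i for i in range(1000)]   # N = 1000
N, B = len(data), 250

# Full-data log-likelihood
full = sum(normal_logpdf(x) for x in data)

# Scaled estimate (N / B) * batch log-likelihood, one per disjoint batch
batches = [data[i:i + B] for i in range(0, N, B)]
scaled = [(N / B) * sum(normal_logpdf(x) for x in b) for b in batches]
avg_estimate = sum(scaled) / len(batches)

print(math.isclose(full, avg_estimate))  # → True
```

Without the scaling, the likelihood term would be underweighted by a factor of N/B relative to the prior, and the approximation would be pulled toward the prior.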
561
562
### Parallel Variational Inference
563
564
```python { .api }
565
# Parallel VI with multiple chains
566
import multiprocessing as mp
567
568
with pm.Model() as model:
569
# Model definition...
570
571
# Parallel VI approximations
572
n_chains = mp.cpu_count()
573
approximations = []
574
575
for chain in range(n_chains):
576
approx_chain = pm.fit(
577
n=25000,
578
random_seed=chain,
579
progressbar=False
580
)
581
approximations.append(approx_chain)
582
583
# Combine approximations (ensemble)
584
ensemble_samples = []
585
for approx in approximations:
586
samples = approx.sample(1000)
587
ensemble_samples.append(samples)
588
```
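When pooling draws from the ensemble, the pooled mean is the average of per-chain means, and the pooled variance must also include the spread *between* chain means (law of total variance). A small pure-Python sketch with made-up per-chain summaries:

```python
# Hypothetical per-chain posterior summaries for one parameter
chain_means = [1.0, 1.2, 0.8, 1.0]
chain_vars = [0.25, 0.30, 0.20, 0.25]

k = len(chain_means)
pooled_mean = sum(chain_means) / k
within = sum(chain_vars) / k                                   # avg within-chain variance
between = sum((m - pooled_mean) ** 2 for m in chain_means) / k  # spread of chain means
pooled_var = within + between                                   # law of total variance

print(pooled_mean, round(pooled_var, 4))  # → 1.0 0.27
```

Large between-chain spread relative to within-chain variance is also a warning sign that individual VI fits are landing in different modes.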
## Usage Patterns and Best Practices

### Hierarchical Models with VI

```python { .api }
# Efficient VI for hierarchical models
with pm.Model() as hierarchical_vi:
    # Hyperparameters
    mu_mu = pm.Normal('mu_mu', 0, 10)
    sigma_mu = pm.HalfNormal('sigma_mu', 5)

    # Group parameters (non-centered parameterization)
    mu_raw = pm.Normal('mu_raw', 0, 1, shape=n_groups)
    mu = pm.Deterministic('mu', mu_mu + sigma_mu * mu_raw)

    # Likelihood
    y_obs = pm.Normal('y_obs', mu=mu[group_idx], sigma=1, observed=data)

    # VI works well with non-centered parameterization
    approx = pm.fit(n=50000, method='advi')
```
611
612
### Model Selection with Variational Methods
613
614
```python { .api }
615
# Compare models using variational inference
616
models_vi = {}
617
approximations = {}
618
619
for model_name, model in candidate_models.items():
620
with model:
621
approx = pm.fit(n=40000)
622
approximations[model_name] = approx
623
624
# Store ELBO for comparison
625
models_vi[model_name] = {
626
'elbo': approx.hist[-1],
627
'n_params': len(model.free_RVs),
628
'approximation': approx
629
}
630
631
# Select best model by ELBO
632
best_model = max(models_vi.keys(), key=lambda k: models_vi[k]['elbo'])
633
```
PyMC's variational inference framework provides efficient approximate inference methods suitable for large-scale Bayesian modeling, offering significant computational advantages over MCMC while maintaining reasonable approximation quality for many practical applications.