0
# Variational Inference
1
2
PyMC3's variational inference module provides scalable approximate inference methods that optimize a tractable approximation to the posterior distribution. These methods trade some accuracy for computational efficiency, making Bayesian inference feasible for large datasets and complex models.
3
4
## Capabilities
5
6
### Main Inference Function
7
8
High-level interface for variational inference: it resolves the requested method (a name such as `'advi'` or an `Inference` instance), initializes it, and runs the optimization.
9
10
```python { .api }
11
def fit(n=10000, method='advi', model=None, random_seed=None,
12
start=None, inf_kwargs=None, **kwargs):
13
"""
14
Fit model using variational inference.
15
16
Main interface for approximate inference using various VI methods.
17
Automatically handles method selection, initialization, and optimization.
18
19
Parameters:
20
- n: int, number of optimization iterations
21
- method: str or Inference class, VI method ('advi', 'fullrank_advi', 'svgd', 'asvgd', 'nfvi')
22
- model: Model, model to fit (current context if None)
23
- random_seed: int, random seed for reproducibility
24
- start: dict, initial parameter values
25
- inf_kwargs: dict, arguments passed to inference method
26
- kwargs: additional arguments for specific methods
27
28
Returns:
29
- Approximation: fitted approximation object
30
31
Example:
32
with pm.Model() as model:
33
mu = pm.Normal('mu', 0, 1)
34
y_obs = pm.Normal('y_obs', mu, 1, observed=data)
35
36
# Fit using ADVI
37
approx = pm.fit(n=20000, method='advi')
38
39
# Sample from approximation
40
trace = approx.sample(1000)
41
"""
42
43
def sample_approx(approx, draws=1000, include_transformed=True):
44
"""
45
Sample from fitted approximation.
46
47
Parameters:
48
- approx: Approximation, fitted approximation object
49
- draws: int, number of samples to draw
50
- include_transformed: bool, include transformed variables
51
52
Returns:
53
- MultiTrace: samples from approximation
54
"""
55
```
56
57
### Inference Classes
58
59
Core variational inference algorithms with different approximation strategies.
60
61
```python { .api }
62
class ADVI:
63
"""
64
Automatic Differentiation Variational Inference.
65
66
Mean-field variational inference using automatic differentiation
67
for gradient-based optimization. Assumes independence between
68
parameters in the approximating distribution.
69
"""
70
71
def __init__(self, local_rv=None, model=None, cost_part_grad_scale=1,
72
scale_cost_to_minibatch=False, random_seed=None,
73
start=None, **kwargs):
74
"""
75
Initialize ADVI inference.
76
77
Parameters:
78
- local_rv: dict, local random variables for minibatch inference
79
- model: Model, model to approximate
80
- cost_part_grad_scale: float, scaling for cost gradient
81
- scale_cost_to_minibatch: bool, scale cost to minibatch size
82
- random_seed: int, random seed
83
- start: dict, starting parameter values
84
"""
85
86
def fit(self, n=10000, score=None, callbacks=None,
87
progressbar=True, **kwargs):
88
"""
89
Fit the approximation.
90
91
Parameters:
92
- n: int, number of optimization iterations
93
- score: bool, compute ELBO score during fitting
94
- callbacks: list, callback functions for monitoring
95
- progressbar: bool, display progress bar
96
97
Returns:
98
- Approximation: fitted approximation
99
"""
100
101
class FullRankADVI:
102
"""
103
Full-rank Automatic Differentiation Variational Inference.
104
105
Variational inference with full covariance structure in the
106
approximating distribution, capturing correlations between parameters.
107
"""
108
109
def __init__(self, local_rv=None, model=None, cost_part_grad_scale=1,
110
scale_cost_to_minibatch=False, random_seed=None,
111
start=None, **kwargs):
112
"""Initialize full-rank ADVI inference."""
113
114
class SVGD:
115
"""
116
Stein Variational Gradient Descent.
117
118
Non-parametric variational method that uses a set of particles
119
to approximate the posterior through deterministic updates
120
guided by Stein's method.
121
"""
122
123
def __init__(self, n_particles=100, jitter=1e-6, model=None,
124
start=None, random_seed=None, **kwargs):
125
"""
126
Initialize SVGD inference.
127
128
Parameters:
129
- n_particles: int, number of particles
130
- jitter: float, noise for numerical stability
131
- model: Model, model to approximate
132
- start: dict, initial particle positions
133
- random_seed: int, random seed
134
"""
135
136
class ASVGD:
137
"""
138
Amortized Stein Variational Gradient Descent.
139
140
Extension of SVGD in which the Stein gradients are used to train the
141
parameters of a parametric approximation instead of a fixed set of particles.
142
"""
143
144
class NFVI:
145
"""
146
Normalizing Flow Variational Inference.
147
148
Variational inference using normalizing flows to create
149
flexible approximating distributions beyond simple parametric forms.
150
"""
151
152
def __init__(self, flow='planar', model=None, random_seed=None,
153
**kwargs):
154
"""
155
Initialize normalizing flow VI.
156
157
Parameters:
158
- flow: str or Flow, type of normalizing flow
159
- model: Model, model to approximate
160
- random_seed: int, random seed
161
"""
162
163
class ImplicitGradient:
164
"""
165
Implicit gradient variational inference.
166
167
Fits an arbitrary approximation using a kernel-based gradient
168
estimate (kernelized Stein discrepancy, RBF kernel by default).
169
"""
170
171
class Inference:
172
"""
173
Base class for all variational inference methods.
174
175
Provides common interface and functionality for different
176
VI algorithms including optimization, callbacks, and diagnostics.
177
"""
178
179
@property
180
def approx(self):
181
"""Current approximation object."""
182
183
@property
184
def hist(self):
185
"""Optimization history."""
186
187
def fit(self, n, **kwargs):
188
"""Abstract method for fitting approximation."""
189
190
class KLqp:
191
"""
192
Kullback-Leibler divergence minimization.
193
194
General approach for fitting approximations that define logq by
195
maximizing the ELBO, i.e. minimizing the KL divergence KL(q || p).
196
"""
197
```
198
199
### Approximation Classes
200
201
Approximation families for the posterior distribution: parametric (mean-field, full-rank, normalizing flow) and particle-based (empirical).
202
203
```python { .api }
204
class MeanField:
205
"""
206
Mean-field Gaussian approximation.
207
208
Independent normal distributions for each parameter,
209
assuming no correlations in the approximating distribution.
210
"""
211
212
def __init__(self, local_rv=None, model=None, cost_part_grad_scale=1,
213
scale_cost_to_minibatch=False, random_seed=None,
214
start=None):
215
"""Initialize mean-field approximation."""
216
217
def sample(self, draws=1000, include_transformed=True):
218
"""
219
Sample from approximation.
220
221
Parameters:
222
- draws: int, number of samples
223
- include_transformed: bool, include transformed variables
224
225
Returns:
226
- MultiTrace: samples from approximation
227
"""
228
229
@property
230
def mean(self):
231
"""Mean of approximating distribution."""
232
233
@property
234
def std(self):
235
"""Standard deviation of approximating distribution."""
236
237
def sample_node(self, node, size=None, more_replacements=None):
238
"""Sample specific model node."""
239
240
class FullRank:
241
"""
242
Full-rank Gaussian approximation.
243
244
Multivariate normal distribution with full covariance matrix,
245
capturing correlations between parameters in approximation.
246
"""
247
248
def __init__(self, local_rv=None, model=None, cost_part_grad_scale=1,
249
scale_cost_to_minibatch=False, random_seed=None,
250
start=None):
251
"""Initialize full-rank approximation."""
252
253
@property
254
def cov(self):
255
"""Covariance matrix of approximating distribution."""
256
257
@property
258
def L_chol(self):
259
"""Cholesky factor of covariance matrix."""
260
261
class Empirical:
262
"""
263
Empirical approximation using sample particles.
264
265
Non-parametric approximation represented by a set of
266
equally weighted samples, used in particle-based methods like SVGD.
267
"""
268
269
def __init__(self, trace=None, size=None, jitter=1e-6,
270
local_rv=None, model=None, **kwargs):
271
"""
272
Initialize empirical approximation.
273
274
Parameters:
275
- trace: MultiTrace, initial particles
276
- size: int, number of particles
277
- jitter: float, noise for numerical stability
278
- local_rv: dict, local random variables
279
- model: Model, model context
280
"""
281
282
@property
283
def histogram(self):
284
"""Histogram representation of samples."""
285
286
class NormalizingFlow:
287
"""
288
Normalizing flow approximation.
289
290
Flexible approximation using invertible transformations
291
to map simple base distribution to complex target.
292
"""
293
294
def __init__(self, flow='planar', local_rv=None, model=None,
295
**kwargs):
296
"""
297
Initialize normalizing flow approximation.
298
299
Parameters:
300
- flow: str or Flow, flow architecture
301
- local_rv: dict, local random variables
302
- model: Model, model context
303
"""
304
305
class Approximation:
306
"""
307
Base class for all approximations.
308
309
Provides common interface for sampling, evaluation, and
310
manipulation of approximating distributions.
311
"""
312
313
def sample(self, draws=1000, **kwargs):
314
"""Sample from approximation."""
315
316
def sample_vp(self, draws=1000, hide_transformed=True):
317
"""Sample variational parameters."""
318
319
def apply_replacements(self, node, more_replacements=None):
320
"""Apply parameter replacements to computational graph."""
321
322
@property
323
def logq(self):
324
"""Log-probability of approximating distribution."""
325
326
@property
327
def logq_norm(self):
328
"""Normalized log-probability."""
329
330
@property
331
def shared_params(self):
332
"""Shared parameter tensors."""
333
334
class Group:
335
"""
336
Variable grouping for approximation.
337
338
Groups related variables together for joint approximation,
339
useful for hierarchical models and structured approximations.
340
"""
341
342
def __init__(self, group, vfam=None, params=None):
343
"""
344
Initialize variable group.
345
346
Parameters:
347
- group: list, variables to group together
348
- vfam: str, variational family for group
349
- params: dict, group-specific parameters
350
"""
351
```
352
353
### Optimization Methods
354
355
Stochastic optimization algorithms for variational parameter updates.
356
357
```python { .api }
358
def adam(loss_or_grads, params, learning_rate=0.001, beta1=0.9,
359
beta2=0.999, epsilon=1e-8):
360
"""
361
Adam optimizer for variational parameters.
362
363
Adaptive moment estimation with bias correction
364
for efficient optimization of variational objectives.
365
366
Parameters:
367
- loss_or_grads: loss function or gradients
368
- params: list, parameters to optimize
369
- learning_rate: float, learning rate
370
- beta1: float, exponential decay rate for first moment
371
- beta2: float, exponential decay rate for second moment
372
- epsilon: float, numerical stability constant
373
374
Returns:
375
- OrderedDict: mapping from each parameter to its update expression
376
"""
377
378
def adamax(loss_or_grads, params, learning_rate=0.002, beta1=0.9,
379
beta2=0.999, epsilon=1e-8):
380
"""
381
Adamax optimizer (Adam with infinity norm).
382
383
Parameters:
384
- loss_or_grads: loss function or gradients
385
- params: list, parameters to optimize
386
- learning_rate: float, learning rate
387
- beta1: float, exponential decay rate for first moment
388
- beta2: float, exponential decay rate for second raw moment
389
- epsilon: float, numerical stability constant
390
391
Returns:
392
- OrderedDict: mapping from each parameter to its update expression
393
"""
394
395
def sgd(loss_or_grads, params, learning_rate):
396
"""
397
Stochastic gradient descent optimizer.
398
399
Parameters:
400
- loss_or_grads: loss function or gradients
401
- params: list, parameters to optimize
402
- learning_rate: float, learning rate
403
404
Returns:
405
- OrderedDict: mapping from each parameter to its update expression
406
"""
407
408
def momentum(loss_or_grads, params, learning_rate, momentum=0.9):
409
"""
410
SGD with momentum optimizer.
411
412
Parameters:
413
- loss_or_grads: loss function or gradients
414
- params: list, parameters to optimize
415
- learning_rate: float, learning rate
416
- momentum: float, momentum coefficient
417
418
Returns:
419
- OrderedDict: mapping from each parameter to its update expression
420
"""
421
422
def nesterov_momentum(loss_or_grads, params, learning_rate, momentum=0.9):
423
"""
424
SGD with Nesterov momentum optimizer.
425
426
Parameters:
427
- loss_or_grads: loss function or gradients
428
- params: list, parameters to optimize
429
- learning_rate: float, learning rate
430
- momentum: float, momentum coefficient
431
432
Returns:
433
- OrderedDict: mapping from each parameter to its update expression
434
"""
435
436
def adagrad(loss_or_grads, params, learning_rate=1.0, epsilon=1e-6):
437
"""
438
Adagrad optimizer with adaptive learning rates.
439
440
Parameters:
441
- loss_or_grads: loss function or gradients
442
- params: list, parameters to optimize
443
- learning_rate: float, initial learning rate
444
- epsilon: float, numerical stability constant
445
446
Returns:
447
- OrderedDict: mapping from each parameter to its update expression
448
"""
449
450
def adagrad_window(loss_or_grads, params, learning_rate=1.0,
451
epsilon=1e-6, n_win=10):
452
"""
453
Adagrad with windowed accumulation.
454
455
Parameters:
456
- loss_or_grads: loss function or gradients
457
- params: list, parameters to optimize
458
- learning_rate: float, initial learning rate
459
- epsilon: float, numerical stability constant
460
- n_win: int, window size for gradient accumulation
461
462
Returns:
463
- OrderedDict: mapping from each parameter to its update expression
464
"""
465
466
def adadelta(loss_or_grads, params, learning_rate=1.0, rho=0.95,
467
epsilon=1e-6):
468
"""
469
Adadelta optimizer with parameter-specific learning rates.
470
471
Parameters:
472
- loss_or_grads: loss function or gradients
473
- params: list, parameters to optimize
474
- learning_rate: float, learning rate scaling factor
475
- rho: float, decay rate for moving averages
476
- epsilon: float, numerical stability constant
477
478
Returns:
479
- OrderedDict: mapping from each parameter to its update expression
480
"""
481
482
def rmsprop(loss_or_grads, params, learning_rate=0.001, rho=0.9,
483
epsilon=1e-6):
484
"""
485
RMSprop optimizer with adaptive learning rates.
486
487
Parameters:
488
- loss_or_grads: loss function or gradients
489
- params: list, parameters to optimize
490
- learning_rate: float, learning rate
491
- rho: float, decay rate for moving average
492
- epsilon: float, numerical stability constant
493
494
Returns:
495
- OrderedDict: mapping from each parameter to its update expression
496
"""
497
498
def apply_momentum(updates, params=None, momentum=0.9):
499
"""
500
Apply momentum to parameter updates.
501
502
Parameters:
503
- updates: dict, parameter updates
504
- params: list, parameters (inferred if None)
505
- momentum: float, momentum coefficient
506
507
Returns:
508
- dict: momentum-adjusted updates
509
"""
510
511
def apply_nesterov_momentum(updates, params=None, momentum=0.9):
512
"""
513
Apply Nesterov momentum to parameter updates.
514
515
Parameters:
516
- updates: dict, parameter updates
517
- params: list, parameters (inferred if None)
518
- momentum: float, momentum coefficient
519
520
Returns:
521
- dict: Nesterov momentum-adjusted updates
522
"""
523
524
def norm_constraint(tensor, max_norm, norm_axes=None, epsilon=1e-7):
525
"""
526
Rescale a tensor so that its norm does not exceed max_norm (usable for weight-norm constraints and gradient clipping).
527
528
Parameters:
529
- tensor: tensor to constrain
530
- max_norm: float, maximum allowed norm
531
- norm_axes: tuple, axes for norm computation
532
- epsilon: float, numerical stability
533
534
Returns:
535
- tensor: norm-constrained tensor
536
"""
537
538
def total_norm_constraint(tensor_vars, max_norm, epsilon=1e-7,
539
return_norm=False):
540
"""
541
Apply total norm constraint across multiple tensors.
542
543
Parameters:
544
- tensor_vars: list, tensors to constrain
545
- max_norm: float, maximum total norm
546
- epsilon: float, numerical stability
547
- return_norm: bool, return computed norm
548
549
Returns:
550
- list or tuple: constrained tensors and optionally norm
551
"""
552
```
553
554
### Stein Methods
555
556
Stein variational inference and related particle-based methods.
557
558
```python { .api }
559
class Stein:
560
"""
561
Stein method for particle-based inference.
562
563
Base class for Stein variational methods that use
564
reproducing kernel Hilbert spaces for particle updates.
565
"""
566
567
def __init__(self, approx=None, kernel='rbf', **kwargs):
568
"""
569
Initialize Stein method.
570
571
Parameters:
572
- approx: Approximation, empirical approximation
573
- kernel: str or callable, kernel function
574
"""
575
576
def updates(self, obj, **kwargs):
577
"""Compute Stein variational updates."""
578
```
579
580
## Usage Examples
581
582
### Basic ADVI
583
584
```python
585
import pymc3 as pm
586
import numpy as np
587
import theano.tensor as tt
588
589
# Simple model with ADVI
590
with pm.Model() as simple_model:
591
mu = pm.Normal('mu', mu=0, sigma=10)
592
sigma = pm.HalfNormal('sigma', sigma=5)
593
y_obs = pm.Normal('y_obs', mu=mu, sigma=sigma, observed=data)
594
595
# Fit using ADVI
596
approx = pm.fit(n=20000, method='advi')
597
598
# Sample from approximation
599
trace_vi = approx.sample(2000)
600
601
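# Inspect the final loss (approx.hist stores the negative ELBO per iteration; plot the whole history to judge convergence)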
602
print(f"ELBO: {approx.hist[-1]:.2f}")
603
```
604
605
### Full-Rank ADVI
606
607
```python
608
# Full-rank ADVI for correlated parameters
609
with pm.Model() as corr_model:
610
# Correlated parameters
611
theta = pm.MvNormal('theta',
612
mu=np.zeros(3),
613
cov=np.eye(3),
614
shape=3)
615
616
# Transform to create correlation
617
x = tt.dot(theta, correlation_matrix)
618
619
y_obs = pm.Normal('y_obs', mu=x.sum(), sigma=1, observed=target)
620
621
# Full-rank to capture correlations
622
approx = pm.fit(n=30000, method='fullrank_advi')
623
624
# Access covariance structure
625
cov_matrix = approx.cov.eval()
626
print("Posterior covariance:\n", cov_matrix)
627
```
628
629
### Stein Variational Gradient Descent
630
631
```python
632
# SVGD for complex posteriors
633
with pm.Model() as complex_model:
634
# Multi-modal posterior setup
635
mu1 = pm.Normal('mu1', mu=-2, sigma=1)
636
mu2 = pm.Normal('mu2', mu=2, sigma=1)
637
638
# Create multi-modality
639
mixture = pm.Mixture('mixture',
640
w=np.array([0.3, 0.7]),
641
comp_dists=[pm.Normal.dist(mu=mu1, sigma=0.5),
642
pm.Normal.dist(mu=mu2, sigma=0.5)],
643
observed=mixture_data)
644
645
# SVGD with particles
646
approx = pm.fit(n=10000,
647
method='svgd',
648
inf_kwargs={'n_particles': 200})
649
650
# Analyze particle distribution
651
particles = approx.sample(5000)
652
print("Particle means:", particles['mu1'].mean(), particles['mu2'].mean())
653
```
654
655
### Normalizing Flow VI
656
657
```python
658
# Normalizing flow for flexible approximation
659
with pm.Model() as flow_model:
660
# Skewed posterior
661
x = pm.SkewNormal('x', mu=0, sigma=1, alpha=5)
662
y_obs = pm.Normal('y_obs', mu=x**2, sigma=0.5, observed=skewed_data)
663
664
# Normalizing flow approximation
665
approx = pm.fit(n=25000, method='nfvi',
666
inf_kwargs={'flow': 'planar'})
667
668
# Sample from flexible approximation
669
trace_flow = approx.sample(2000)
670
```
671
672
### Minibatch Variational Inference
673
674
```python
675
# Large dataset with minibatch VI
676
batch_size = 500
677
n_data = len(large_dataset)
678
679
# Minibatch containers
680
x_minibatch = pm.Minibatch(X_large, batch_size=batch_size)
681
y_minibatch = pm.Minibatch(y_large, batch_size=batch_size)
682
683
with pm.Model() as minibatch_model:
684
# Model parameters
685
weights = pm.Normal('weights', mu=0, sigma=1, shape=n_features)
686
intercept = pm.Normal('intercept', mu=0, sigma=10)
687
688
# Predictions with minibatch
689
mu = intercept + tt.dot(x_minibatch, weights)
690
691
# Scaled likelihood
692
y_obs = pm.Normal('y_obs', mu=mu, sigma=1,
693
observed=y_minibatch,
694
total_size=n_data)
695
696
# Minibatch ADVI
697
approx = pm.fit(n=50000,
698
method='advi',
699
inf_kwargs={'scale_cost_to_minibatch': True})
700
```
701
702
### Custom Optimization and Callbacks
703
704
```python
705
# Custom optimization with monitoring
706
def elbo_callback(approx, loss, i):
707
"""Custom callback for monitoring ELBO."""
708
if i % 1000 == 0:
709
print(f"Iteration {i}: ELBO = {loss:.2f}")
710
711
# Custom diagnostics
712
samples = approx.sample(100)
713
mean_est = {var: samples[var].mean()
714
for var in samples.varnames}
715
print(f"Current estimates: {mean_est}")
716
717
with pm.Model() as callback_model:
718
theta = pm.Beta('theta', alpha=2, beta=3)
719
y_obs = pm.Binomial('y_obs', n=20, p=theta, observed=successes)
720
721
# Custom ADVI with callbacks
722
inference = pm.ADVI()
723
approx = pm.fit(n=15000,
724
method=inference,
725
callbacks=[elbo_callback])
726
```
727
728
### Hierarchical VI
729
730
```python
731
# Hierarchical model with VI
732
n_groups = 10
733
734
with pm.Model() as hierarchical_vi:
735
# Hyperpriors
736
mu_mu = pm.Normal('mu_mu', mu=0, sigma=10)
737
sigma_mu = pm.HalfNormal('sigma_mu', sigma=5)
738
739
# Group-level parameters
740
group_mu = pm.Normal('group_mu',
741
mu=mu_mu,
742
sigma=sigma_mu,
743
shape=n_groups)
744
745
# Observations
746
y_obs = pm.Normal('y_obs',
747
mu=group_mu[group_idx],
748
sigma=1,
749
observed=hierarchical_data)
750
751
# Hierarchical ADVI
752
approx = pm.fit(n=25000, method='advi')
753
754
# Analyze group-level effects
755
trace_hier = approx.sample(2000)
756
group_effects = trace_hier['group_mu']
757
758
print("Group effect means:", group_effects.mean(axis=0))
759
print("Group effect stds:", group_effects.std(axis=0))
760
```
761
762
### Model Comparison with VI
763
764
```python
765
# Compare models using VI
766
models = {}
767
approx_results = {}
768
769
# Model 1: Simple linear
770
with pm.Model() as model1:
771
alpha = pm.Normal('alpha', 0, 10)
772
beta = pm.Normal('beta', 0, 10)
773
sigma = pm.HalfNormal('sigma', 1)
774
775
mu = alpha + beta * x_data
776
y_obs = pm.Normal('y_obs', mu=mu, sigma=sigma, observed=y_data)
777
778
models['linear'] = model1
779
approx_results['linear'] = pm.fit(n=20000)
780
781
# Model 2: Quadratic
782
with pm.Model() as model2:
783
alpha = pm.Normal('alpha', 0, 10)
784
beta1 = pm.Normal('beta1', 0, 10)
785
beta2 = pm.Normal('beta2', 0, 10)
786
sigma = pm.HalfNormal('sigma', 1)
787
788
mu = alpha + beta1 * x_data + beta2 * x_data**2
789
y_obs = pm.Normal('y_obs', mu=mu, sigma=sigma, observed=y_data)
790
791
models['quadratic'] = model2
792
approx_results['quadratic'] = pm.fit(n=20000)
793
794
# Compare final objective values (approx.hist stores the loss, i.e. the negative ELBO, so lower is better)
795
for name, approx in approx_results.items():
796
final_elbo = approx.hist[-1]
797
print(f"{name} model ELBO: {final_elbo:.2f}")
798
```
799
800
### Advanced Approximation Analysis
801
802
```python
803
# Detailed approximation analysis
804
with pm.Model() as analysis_model:
805
mu = pm.Normal('mu', mu=0, sigma=5)
806
tau = pm.Gamma('tau', alpha=2, beta=1)
807
y_obs = pm.Normal('y_obs', mu=mu, tau=tau, observed=obs_data)
808
809
# Fit approximation
810
approx = pm.fit(n=30000, method='fullrank_advi')
811
812
# Extract approximation parameters
813
mean_params = approx.mean.eval()
814
std_params = approx.std.eval()
815
cov_params = approx.cov.eval()
816
817
print("Variational means:", mean_params)
818
print("Variational stds:", std_params)
819
print("Variational covariance:\n", cov_params)
820
821
# Sample and compare to MCMC
822
vi_samples = approx.sample(5000)
823
mcmc_samples = pm.sample(2000, tune=1000, chains=2)
824
825
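# Diagnostic comparison (NOTE: az.compare ranks models by LOO/WAIC and needs pointwise log-likelihoods; to check VI against MCMC, compare posterior summaries such as az.summary instead)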
826
import arviz as az
827
comparison = az.compare({
828
'VI': vi_samples,
829
'MCMC': mcmc_samples
830
})
831
print(comparison)
832
```