# PyMC Variational Inference

PyMC provides comprehensive variational inference methods for fast approximate Bayesian inference. Variational methods are particularly useful for large datasets and complex models where MCMC sampling may be computationally prohibitive.

## Main Variational Interface

### Primary Fitting Function

```python { .api }
import pymc as pm

def fit(n=10000, method='advi', model=None, random_seed=None,
        start=None, inf_kwargs=None, **kwargs):
    """
    Fit a variational approximation to the posterior.

    Parameters:
    - n (int): Number of optimization iterations (default: 10000)
    - method (str): Inference method ('advi', 'fullrank_advi', 'svgd', 'asvgd')
    - model: PyMC model (default: current context model)
    - random_seed (int): Random seed for reproducibility
    - start (dict): Starting parameter values
    - inf_kwargs (dict): Method-specific keyword arguments
    - **kwargs: Forwarded to the underlying inference's fit()
      (e.g. obj_optimizer, callbacks, progressbar)

    Returns:
    - approximation: Fitted variational approximation object
    """

# Basic variational inference
with pm.Model() as model:
    # Define model...
    approx = pm.fit(n=50000)

    # Advanced configuration
    approx = pm.fit(
        n=100000,
        method='fullrank_advi',
        obj_optimizer=pm.adam(learning_rate=0.01),
        callbacks=[pm.callbacks.CheckParametersConvergence()],
        progressbar=True
    )
```

## Automatic Differentiation Variational Inference (ADVI)

### Mean-Field ADVI

The default variational inference method, using a mean-field approximation:

```python { .api }
from pymc.variational import ADVI

class ADVI:
    """
    Automatic Differentiation Variational Inference with mean-field approximation.

    Parameters:
    - model: PyMC model
    - random_seed (int): Random seed
    - start (dict): Initial parameter values

    Methods:
    - fit: Optimize variational parameters
    - sample: Draw samples from the approximation
    """

    def __init__(self, model=None, random_seed=None, start=None):
        pass

    def fit(self, n, obj_optimizer=None, callbacks=None, progressbar=True, **kwargs):
        """
        Fit the variational approximation.

        Parameters:
        - n (int): Number of optimization steps
        - obj_optimizer: Optimization algorithm
        - callbacks (list): Callback functions
        - progressbar (bool): Show progress bar

        Returns:
        - approximation: Fitted approximation
        """
        pass

# Explicit ADVI usage
with pm.Model() as model:
    # Model definition...

    # Create ADVI inference object
    inference = pm.ADVI()

    # Fit approximation
    approx = inference.fit(n=50000, obj_optimizer=pm.adam(learning_rate=0.01))

    # Draw samples from the approximation
    trace = approx.sample(2000)
```

### Full-Rank ADVI

ADVI with a full covariance structure:

```python { .api }
from pymc.variational import FullRankADVI

class FullRankADVI:
    """
    Full-rank ADVI with a correlated posterior approximation.

    Parameters:
    - model: PyMC model
    - random_seed (int): Random seed
    """

# Full-rank approximation for capturing correlations
with pm.Model() as model:
    # Model with correlated parameters...

    inference = pm.FullRankADVI()
    approx = inference.fit(n=75000)

    # Full covariance matrix available
    cov_matrix = approx.cov.eval()
```

## Stein Variational Gradient Descent

### Standard SVGD

Particle-based variational inference:

```python { .api }
from pymc.variational import SVGD

class SVGD:
    """
    Stein Variational Gradient Descent.

    Parameters:
    - n_particles (int): Number of particles (default: 100)
    - jitter (float): Jitter for numerical stability
    - model: PyMC model
    """

    def __init__(self, n_particles=100, jitter=1e-6, model=None):
        pass

# SVGD for complex posteriors
with pm.Model() as complex_model:
    # Complex model definition...

    inference = pm.SVGD(n_particles=200)
    approx = inference.fit(n=20000)

    # Particles represent the posterior
    particles = approx.sample(1000)
```
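The SVGD update rule itself is simple: each particle moves along a kernel-weighted average of the other particles' score gradients (attraction toward high density) plus the kernel's own gradient (repulsion between particles). A minimal pure-Python sketch, independent of PyMC, for a 1-D standard-normal target:

```python
import math
import random

def svgd_step(particles, grad_logp, stepsize=0.05, h=1.0):
    """One SVGD update: attractive (score) term + repulsive (kernel-gradient) term."""
    n = len(particles)
    new = []
    for xi in particles:
        phi = 0.0
        for xj in particles:
            k = math.exp(-(xj - xi) ** 2 / (2 * h))  # RBF kernel k(xj, xi)
            dk = -(xj - xi) / h * k                  # d/dxj k(xj, xi): repulsion
            phi += k * grad_logp(xj) + dk
        new.append(xi + stepsize * phi / n)
    return new

# Target: standard normal, so grad log p(x) = -x
random.seed(0)
particles = [random.uniform(2.0, 4.0) for _ in range(30)]  # deliberately biased start
for _ in range(1000):
    particles = svgd_step(particles, lambda x: -x)

mean = sum(particles) / len(particles)
var = sum((x - mean) ** 2 for x in particles) / len(particles)
```

After the updates, the particle cloud drifts from its biased start toward the target and spreads out to roughly the target's variance, which is the behavior PyMC's `SVGD` realizes at scale.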

### Amortized SVGD

```python { .api }
from pymc.variational import ASVGD

class ASVGD:
    """
    Amortized Stein Variational Gradient Descent.

    Parameters:
    - approx: Variational approximation to train (default: FullRank)
    """

# ASVGD for large datasets (mini-batching is handled via pm.Minibatch)
with pm.Model() as large_model:
    # Model with a large, mini-batched dataset...

    inference = pm.ASVGD()
    approx = inference.fit(n=30000)
```

## Variational Approximations

### Mean-Field Approximation

Independent normal distributions for each parameter:

```python { .api }
from pymc.variational.approximations import MeanField

class MeanField:
    """
    Mean-field approximation with independent normal distributions.

    Parameters:
    - local_rv (dict): Local random variables
    - model: PyMC model

    Methods:
    - sample: Draw samples from the approximation
    - apply_replacements: Apply variational replacements
    """

    def sample(self, draws=1000, include_transformed=True):
        """
        Sample from the mean-field approximation.

        Parameters:
        - draws (int): Number of samples to draw
        - include_transformed (bool): Include transformed variables

        Returns:
        - samples: InferenceData containing posterior samples
        """
        pass

# Access the approximation directly
with pm.Model() as model:
    # Model definition...

    # Create mean-field approximation
    mean_field = pm.MeanField()

    # Fit using KL divergence minimization
    approx = pm.KLqp(mean_field).fit(n=50000)
```

### Full-Rank Approximation

Multivariate normal with full covariance:

```python { .api }
from pymc.variational.approximations import FullRank
import numpy as np

class FullRank:
    """
    Full-rank multivariate normal approximation.

    Parameters:
    - local_rv (dict): Local random variables
    - model: PyMC model

    Attributes:
    - cov: Covariance matrix
    - mean: Mean vector
    """

# Full-rank for capturing parameter correlations
with pm.Model() as correlated_model:
    # Model with strong parameter correlations...

    full_rank = pm.FullRank()
    approx = pm.KLqp(full_rank).fit(n=75000)

    # Access the covariance structure and derive correlations
    posterior_cov = approx.cov.eval()
    stds = np.sqrt(np.diag(posterior_cov))
    posterior_corr = posterior_cov / np.outer(stds, stds)
```

### Empirical Approximation

Empirical distribution backed by particle samples:

```python { .api }
from pymc.variational.approximations import Empirical

class Empirical:
    """
    Empirical approximation backed by a set of particle samples.

    Parameters:
    - trace: MCMC trace to initialize the particles from
    - size (int): Number of particles
    """

# SVGD maintains an Empirical approximation of its particles
with pm.Model() as model:
    # Model definition...

    approx = pm.SVGD(n_particles=500).fit(n=25000)  # resulting approx is Empirical
```

## Optimization Algorithms

### Adam Optimizer

```python { .api }
def adam(learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8):
    """
    Adam optimizer for variational inference.

    Parameters:
    - learning_rate (float): Step size
    - beta1 (float): Exponential decay rate for the 1st-moment estimate
    - beta2 (float): Exponential decay rate for the 2nd-moment estimate
    - epsilon (float): Small constant for numerical stability

    Returns:
    - optimizer: Adam optimizer object
    """

# Custom Adam configuration
optimizer = pm.adam(
    learning_rate=0.005,
    beta1=0.95,
    beta2=0.999
)

approx = pm.fit(n=50000, obj_optimizer=optimizer)
```
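The roles of `learning_rate`, `beta1`, `beta2`, and `epsilon` are easiest to see in the raw update rule. A minimal pure-Python sketch of Adam minimizing a 1-D function (an illustration of the algorithm, not PyMC's implementation):

```python
def adam_minimize(grad, x0, steps=200, learning_rate=0.1,
                  beta1=0.9, beta2=0.999, epsilon=1e-8):
    """Minimize a 1-D function given its gradient using Adam updates."""
    x, m, v = x0, 0.0, 0.0
    for t in range(1, steps + 1):
        g = grad(x)
        m = beta1 * m + (1 - beta1) * g        # 1st-moment (mean) estimate
        v = beta2 * v + (1 - beta2) * g * g    # 2nd-moment (uncentered variance) estimate
        m_hat = m / (1 - beta1 ** t)           # bias corrections for zero init
        v_hat = v / (1 - beta2 ** t)
        x -= learning_rate * m_hat / (v_hat ** 0.5 + epsilon)
    return x

# Minimize f(x) = (x - 3)^2, whose gradient is 2*(x - 3); optimum at x = 3
x_opt = adam_minimize(lambda x: 2 * (x - 3), x0=0.0)
```

Because the step is normalized by `v_hat ** 0.5`, the effective step size stays close to `learning_rate` regardless of the gradient's scale, which is why Adam is a robust default for noisy ELBO gradients.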

### Other Optimizers

```python { .api }
# Stochastic Gradient Descent
sgd_optimizer = pm.sgd(learning_rate=0.01)

# AdaGrad
adagrad_optimizer = pm.adagrad(learning_rate=0.1)

# RMSprop
rmsprop_optimizer = pm.rmsprop(learning_rate=0.001, rho=0.9)

# Adamax
adamax_optimizer = pm.adamax(learning_rate=0.002)

# AdaDelta
adadelta_optimizer = pm.adadelta(learning_rate=1.0, rho=0.95)
```

## Advanced Variational Methods

### Custom Inference Classes

```python { .api }
from pymc.variational.inference import KLqp, Inference

class KLqp(Inference):
    """
    Kullback-Leibler divergence minimization.

    Parameters:
    - approx: Variational approximation
    - beta (float): Scaling of the regularization term (beta-VI)
    """

    def __init__(self, approx, beta=1.0):
        pass

# Custom inference setup
with pm.Model() as model:
    # Model definition...

    # Custom approximation
    custom_approx = pm.MeanField()

    # KL divergence inference
    inference = pm.KLqp(custom_approx, beta=0.9)
    approx = inference.fit(n=40000)
```

### Implicit Gradient Methods

```python { .api }
from pymc.variational.inference import ImplicitGradient

class ImplicitGradient(Inference):
    """
    Implicit gradient variational inference over a particle-based
    approximation. SVGD and ASVGD are implemented as subclasses.

    Parameters:
    - approx: Variational approximation (typically Empirical)
    - estimator: Gradient estimator (default: kernelized Stein discrepancy)
    - kernel: Kernel function used by the estimator
    """

# SVGD and ASVGD are the practical entry points to implicit gradients
with pm.Model() as difficult_model:
    # Model with complex geometry...

    approx = pm.SVGD(n_particles=100).fit(n=60000)
```

## Variational Groups and Structured Approximations

### Grouping Variables

```python { .api }
from pymc.variational.opvi import Group

class Group:
    """
    Group variables for structured approximations.

    Parameters:
    - group_vars (list): Variables in the group
    - vfam (str): Variational family for the group
      (e.g. 'mean_field', 'full_rank')
    """

# Give correlated parameters a richer approximation family
with pm.Model() as hierarchical_model:
    # Hierarchical model...

    # Group 1: hyperparameters (mean-field)
    hyper_group = pm.Group([mu_alpha, sigma_alpha], vfam='mean_field')

    # Group 2: group effects (full-rank)
    group_effects = pm.Group([alpha], vfam='full_rank')

    # Combined approximation
    approximation = pm.Approximation([hyper_group, group_effects])
    approx = pm.KLqp(approximation).fit(n=50000)
```

## Callbacks and Monitoring

### Built-in Callbacks

```python { .api }
# Parameter convergence monitoring
convergence_cb = pm.callbacks.CheckParametersConvergence(tolerance=0.01)

# Stricter convergence check, evaluated every 100 iterations
early_stop_cb = pm.callbacks.CheckParametersConvergence(every=100, tolerance=0.001)

# Custom callback function
def custom_callback(approx, loss_history, i):
    if i % 1000 == 0:
        current_loss = loss_history[-1]
        print(f"Iteration {i}: Loss = {current_loss:.4f}")

# Use callbacks during fitting
approx = pm.fit(
    n=50000,
    callbacks=[convergence_cb, custom_callback],
    progressbar=True
)
```
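The convergence callback's logic can be sketched in pure Python. This is a simplified stand-in for `CheckParametersConvergence` (PyMC's version compares flattened parameter vectors and signals convergence by raising `StopIteration`, which `fit` catches to stop early):

```python
class SimpleConvergenceCheck:
    """Stop optimization when the relative parameter change falls below tolerance.

    Simplified stand-in for pm.callbacks.CheckParametersConvergence."""
    def __init__(self, every=100, tolerance=1e-3):
        self.every = every
        self.tolerance = tolerance
        self.prev = None

    def __call__(self, params, loss_history, i):
        if i % self.every:
            return  # only check every `every` iterations
        if self.prev is not None:
            num = sum((p - q) ** 2 for p, q in zip(params, self.prev)) ** 0.5
            den = sum(q * q for q in self.prev) ** 0.5 or 1.0
            if num / den < self.tolerance:
                raise StopIteration(f"converged at iteration {i}")
        self.prev = list(params)

# Simulated optimization: parameters approach a fixed point at 1.0
check = SimpleConvergenceCheck(every=100, tolerance=1e-3)
params = [10.0, -4.0]
stopped_at = None
for i in range(1, 5001):
    params = [1.0 + 0.9 * (p - 1.0) for p in params]  # geometric approach to 1.0
    try:
        check(params, [], i)
    except StopIteration:
        stopped_at = i
        break
```

Checking only every `every` iterations keeps the callback cheap, and the relative (rather than absolute) change makes the tolerance scale-free.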

## Sampling from Approximations

### Drawing Samples

```python { .api }
def sample(self, draws=500, random_seed=None, return_inferencedata=True, **kwargs):
    """
    Draw samples from a fitted variational approximation.

    Parameters:
    - draws (int): Number of samples
    - random_seed (int): Random seed
    - return_inferencedata (bool): Return ArviZ InferenceData (default: True)

    Returns:
    - samples: Samples from the approximation
    """

# Sample from a fitted approximation
idata = approx.sample(5000)

# Posterior samples as an xarray Dataset
posterior = idata.posterior
```

### Integration with MCMC

```python { .api }
# Use a variational approximation to initialize MCMC
with pm.Model() as model:
    # Model definition...

    # Fit variational approximation
    approx = pm.fit(n=30000)

    # Use the last VI draw as the MCMC starting point
    idata_vi = approx.sample(1000)
    start_point = {
        var: idata_vi.posterior[var].isel(chain=0, draw=-1).values
        for var in idata_vi.posterior.data_vars
    }

    # MCMC with VI initialization
    # (pm.sample(init='advi+adapt_diag') automates a similar scheme)
    mcmc_trace = pm.sample(initvals=start_point, tune=1000, draws=2000)
```

## Model Comparison and Diagnostics

### ELBO Monitoring

```python { .api }
# Track the Evidence Lower Bound during optimization
with pm.Model() as model:
    # Model definition...

    approx = pm.fit(n=50000, progressbar=True)

# approx.hist stores the optimization loss (the negative ELBO)
elbo_history = -approx.hist

# Plot convergence
import matplotlib.pyplot as plt
plt.plot(elbo_history)
plt.xlabel('Iteration')
plt.ylabel('ELBO')
plt.title('Variational Inference Convergence')
```
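What the ELBO measures can be checked directly on a conjugate model: ELBO = log evidence − KL(q ‖ posterior), so it equals the log evidence exactly when q is the true posterior and is strictly smaller otherwise. A pure-Python Monte Carlo check on the model z ~ N(0, 1), x | z ~ N(z, 1), whose posterior is available in closed form:

```python
import math
import random

def log_normal_pdf(x, mu, var):
    return -0.5 * math.log(2 * math.pi * var) - (x - mu) ** 2 / (2 * var)

# Model: z ~ N(0, 1), x | z ~ N(z, 1); observe x = 1.0
x = 1.0
# Exact posterior: z | x ~ N(x/2, 1/2); evidence: x ~ N(0, 2)
log_evidence = log_normal_pdf(x, 0.0, 2.0)

def elbo(q_mu, q_var, n_samples=2000, seed=0):
    """Monte Carlo ELBO = E_q[log p(x, z) - log q(z)]."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        z = rng.gauss(q_mu, math.sqrt(q_var))
        log_joint = log_normal_pdf(x, z, 1.0) + log_normal_pdf(z, 0.0, 1.0)
        total += log_joint - log_normal_pdf(z, q_mu, q_var)
    return total / n_samples

exact = elbo(0.5, 0.5)  # q = exact posterior: ELBO equals log evidence
worse = elbo(0.0, 1.0)  # q = prior: ELBO is smaller by KL(prior || posterior)
```

With q equal to the exact posterior, the integrand is constant (log p(x) for every z), so the estimate has no Monte Carlo error; any other q yields a strictly lower value, which is why maximizing the ELBO drives q toward the posterior.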

### Approximation Quality Assessment

```python { .api }
# Compare the VI approximation with MCMC samples (when available)
def assess_approximation_quality(approx, mcmc_idata, var_names):
    """Compare VI approximation with MCMC samples."""
    vi_idata = approx.sample(5000)

    for var in var_names:
        vi_mean = float(vi_idata.posterior[var].mean())
        vi_std = float(vi_idata.posterior[var].std())

        mcmc_mean = float(mcmc_idata.posterior[var].mean())
        mcmc_std = float(mcmc_idata.posterior[var].std())

        print(f"{var}:")
        print(f"  VI:   mean={vi_mean:.3f}, std={vi_std:.3f}")
        print(f"  MCMC: mean={mcmc_mean:.3f}, std={mcmc_std:.3f}")

# Usage
assess_approximation_quality(approx, mcmc_trace, ['alpha', 'beta', 'sigma'])
```

## Large-Scale Variational Inference

### Mini-batch Variational Inference

```python { .api }
# Mini-batch VI for large datasets
with pm.Model() as large_scale_model:
    # Draw aligned mini-batches from the full dataset
    X_mb, y_mb = pm.Minibatch(X_large, y_large, batch_size=256)

    # Model with mini-batched data
    alpha = pm.Normal('alpha', 0, 1)
    beta = pm.Normal('beta', 0, 1, shape=p)
    mu = alpha + pm.math.dot(X_mb, beta)

    # total_size rescales the mini-batch log-likelihood to the
    # full dataset (by the factor n_total / batch_size)
    n_total = X_large.shape[0]
    y_obs = pm.Normal('y_obs', mu=mu, sigma=1, observed=y_mb,
                      total_size=n_total)

    # Variational inference with mini-batches
    approx = pm.fit(n=100000, method='advi')
```
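The `total_size` rescaling works because the scaled mini-batch log-likelihood is an unbiased estimate of the full-data log-likelihood. A quick pure-Python check of the `n_total / batch_size` factor, using an arbitrary per-observation term as a stand-in log-likelihood:

```python
# Toy "log-likelihood": a sum of per-observation terms
data = [float(i) for i in range(1, 21)]           # n_total = 20
def loglik(points):
    return sum(-0.5 * v * v for v in points)      # arbitrary per-point term

n_total, batch_size = len(data), 5
scale = n_total / batch_size                      # the total_size scaling factor

# Partition into disjoint mini-batches and scale each batch's log-likelihood
batches = [data[i:i + batch_size] for i in range(0, n_total, batch_size)]
scaled = [scale * loglik(b) for b in batches]

# Averaging the scaled batch values recovers the full-data log-likelihood
average_scaled = sum(scaled) / len(scaled)
full = loglik(data)
```

Each optimization step therefore sees a noisy but unbiased gradient of the full-data objective, which is what lets ADVI run on datasets far too large to evaluate per iteration.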

### Multiple Independent VI Runs

```python { .api }
# Multiple independent VI runs (executed sequentially here; distribute
# across processes for true parallelism)
import multiprocessing as mp

with pm.Model() as model:
    # Model definition...

    n_runs = mp.cpu_count()
    approximations = []

    for run in range(n_runs):
        approx_run = pm.fit(
            n=25000,
            random_seed=run,
            progressbar=False
        )
        approximations.append(approx_run)

    # Combine approximations (ensemble)
    ensemble_samples = []
    for approx in approximations:
        samples = approx.sample(1000)
        ensemble_samples.append(samples)
```

## Usage Patterns and Best Practices

### Hierarchical Models with VI

```python { .api }
# Efficient VI for hierarchical models
with pm.Model() as hierarchical_vi:
    # Hyperparameters
    mu_mu = pm.Normal('mu_mu', 0, 10)
    sigma_mu = pm.HalfNormal('sigma_mu', 5)

    # Group parameters (non-centered parameterization)
    mu_raw = pm.Normal('mu_raw', 0, 1, shape=n_groups)
    mu = pm.Deterministic('mu', mu_mu + sigma_mu * mu_raw)

    # Likelihood
    y_obs = pm.Normal('y_obs', mu=mu[group_idx], sigma=1, observed=data)

    # VI works well with the non-centered parameterization
    approx = pm.fit(n=50000, method='advi')
```

### Model Selection with Variational Methods

```python { .api }
# Compare models using variational inference
models_vi = {}

for model_name, model in candidate_models.items():
    with model:
        approx = pm.fit(n=40000)

        # approx.hist is the loss (negative ELBO); store the final ELBO
        models_vi[model_name] = {
            'elbo': -approx.hist[-1],
            'n_params': len(model.free_RVs),
            'approximation': approx
        }

# Select the best model by final ELBO (a heuristic: the ELBO only
# lower-bounds the evidence, with a model-dependent gap)
best_model = max(models_vi.keys(), key=lambda k: models_vi[k]['elbo'])
```

PyMC's variational inference framework provides efficient approximate inference methods suitable for large-scale Bayesian modeling, offering significant computational advantages over MCMC while maintaining reasonable approximation quality for many practical applications.