or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

core-programming.mddistributions.mdgaussian-processes.mdindex.mdinference.mdneural-networks.mdoptimization.mdtransforms-constraints.md

inference.mddocs/

0

# Inference Methods

1

2

Scalable inference algorithms for posterior approximation and model learning, including variational inference, Markov Chain Monte Carlo, and specialized sampling methods for probabilistic programs.

3

4

## Capabilities

5

6

### Stochastic Variational Inference

7

8

Gradient-based variational inference for scalable approximate posterior computation.

9

10

```python { .api }

11

class SVI:
    """
    Stochastic Variational Inference (SVI) for scalable posterior approximation.

    SVI tunes the variational parameters of a guide so that the guide moves
    closer to the true posterior, i.e. it minimizes the KL divergence between
    the guide and the posterior (in practice by minimizing a loss such as the
    negative ELBO).
    """

    def __init__(self, model, guide, optim, loss):
        """
        Build an SVI object from a model, a guide, an optimizer and a loss.

        Parameters:
        - model (callable): Generative model function
        - guide (callable): Variational guide function approximating the posterior
        - optim (PyroOptim): Pyro optimizer that wraps a PyTorch optimizer
        - loss (ELBO): Evidence Lower Bound loss function

        Examples:
        >>> svi = SVI(model, guide, Adam({"lr": 0.01}), Trace_ELBO())
        """

    def step(self, *args, **kwargs) -> float:
        """
        Take a single SVI optimization step.

        Parameters:
        - *args, **kwargs: Forwarded unchanged to both the model and the guide

        Returns:
        float: The loss (negative ELBO) measured at this step

        Examples:
        >>> loss = svi.step(data)
        >>> print(f"Loss: {loss}")
        """

    def evaluate_loss(self, *args, **kwargs) -> float:
        """
        Compute the current loss without updating any parameters.

        Parameters:
        - *args, **kwargs: Forwarded unchanged to both the model and the guide

        Returns:
        float: The current loss value
        """

58

59

def init_to_feasible(site: dict = None) -> torch.Tensor:
    """
    Initialize a latent sample site to an arbitrary feasible point.

    The value lies inside the support of the site's distribution, but the
    distribution's parameters are ignored. Initialization strategies like this
    are normally passed to inference code, for example as ``init_loc_fn`` of an
    autoguide or ``init_strategy`` of HMC/NUTS, rather than called directly.

    Parameters:
    - site (dict, optional): Pyro sample-site message whose ``"fn"`` entry is
      the site's distribution. When None, the strategy itself is returned so
      it can be passed around as a callable.

    Returns:
    Tensor: Feasible initial value for the site (the strategy callable when
    ``site`` is None)
    """

69

70

def init_to_mean(site: dict = None) -> torch.Tensor:
    """
    Initialize a latent sample site to the mean of its prior distribution.

    If the distribution's mean is undefined, upstream Pyro falls back to
    another initialization strategy (the median by default). Like the other
    ``init_to_*`` strategies, this is normally passed to inference code (e.g.
    as ``init_loc_fn`` of an autoguide) rather than called directly.

    Parameters:
    - site (dict, optional): Pyro sample-site message whose ``"fn"`` entry is
      the site's distribution. When None, an initialization callable is
      returned instead of a value.

    Returns:
    Tensor: Mean initialization value (an initialization callable when
    ``site`` is None)
    """

80

81

def init_to_sample(site: dict = None) -> torch.Tensor:
    """
    Initialize a latent sample site to a random draw from its prior.

    Like the other ``init_to_*`` strategies, this is normally passed to
    inference code (e.g. as ``init_loc_fn`` of an autoguide or
    ``init_strategy`` of HMC/NUTS) rather than called directly.

    Parameters:
    - site (dict, optional): Pyro sample-site message whose ``"fn"`` entry is
      the site's distribution. When None, the strategy itself is returned so
      it can be passed around as a callable.

    Returns:
    Tensor: Random sample initialization (the strategy callable when ``site``
    is None)
    """

91

```

92

93

### Evidence Lower Bound (ELBO)

94

95

Loss functions for variational inference based on the evidence lower bound.

96

97

```python { .api }

98

class ELBO:
    """
    Common base class for Evidence Lower Bound (ELBO) losses.

    The ELBO bounds the model evidence (marginal likelihood) from below.
    Variational inference maximizes it, which is the same as minimizing its
    negative, so the negative ELBO is what these losses report.
    """

    def differentiable_loss(self, model, guide, *args, **kwargs) -> torch.Tensor:
        """
        Compute the ELBO loss as a differentiable tensor.

        Parameters:
        - model (callable): Generative model function
        - guide (callable): Variational guide function
        - *args, **kwargs: Forwarded unchanged to both the model and the guide

        Returns:
        Tensor: Negative ELBO, the quantity to minimize
        """

113

114

class Trace_ELBO(ELBO):
    """
    Standard trace-based ELBO implementation.

    Records execution traces of the model and the guide, and estimates the
    ELBO as the log joint probability of the model minus the log probability
    of the guide.
    """

    def __init__(self, num_particles: int = 1, max_plate_nesting: float = float('inf'),
                 max_iarange_nesting: int = None, vectorize_particles: bool = False,
                 strict_enumeration_warning: bool = True):
        """
        Parameters:
        - num_particles (int): Number of Monte Carlo samples used to estimate
          the ELBO and its gradient
        - max_plate_nesting (int or float): Maximum depth of nested plates to
          vectorize over. The default ``float('inf')`` sets no explicit bound.
        - max_iarange_nesting (int, optional): Deprecated alias of
          ``max_plate_nesting``; use ``max_plate_nesting`` instead
        - vectorize_particles (bool): Whether to vectorize over particles
          instead of looping over them
        - strict_enumeration_warning (bool): Whether to warn about possible
          misuse of enumeration
        """

132

133

class TraceEnum_ELBO(ELBO):
    """
    ELBO with exact enumeration over discrete latent variables.

    Computes exact expectations over discrete variables that are marked for
    enumeration, and uses Monte Carlo estimates for continuous variables.
    """

    def __init__(self, max_plate_nesting: float = float('inf'), max_iarange_nesting: int = None,
                 strict_enumeration_warning: bool = True, ignore_jit_warnings: bool = False):
        """
        Parameters:
        - max_plate_nesting (int or float): Maximum plate nesting depth, needed
          for parallel enumeration. If left at the default ``float('inf')``,
          Pyro may guess a value by running the model and guide once; set it
          explicitly if their structure is dynamic.
        - max_iarange_nesting (int, optional): Deprecated alias of
          ``max_plate_nesting``; use ``max_plate_nesting`` instead
        - strict_enumeration_warning (bool): Whether to warn about possible
          misuse of enumeration
        - ignore_jit_warnings (bool): Whether to ignore JIT tracer warnings
        """

149

150

class TraceGraph_ELBO(ELBO):
    """
    Trace-based ELBO that exploits conditional dependency structure.

    Uses the fine-grained dependency information recorded in the model and
    guide traces to reduce the variance of the gradient estimator for
    non-reparameterizable random variables (for example discrete latents).
    It also supports baselines for those sites.
    """

158

159

class TraceMeanField_ELBO(ELBO):
    """
    ELBO for mean-field variational inference.

    Uses analytic KL divergences between the guide and the prior wherever they
    are available, which lowers the variance of the estimate compared with a
    purely Monte Carlo ELBO. It assumes a restricted, mean-field style
    dependency structure: variables outside a plate must never depend on
    variables inside that plate.
    """

167

168

class RenyiELBO(ELBO):
    """
    Renyi divergence-based variational objective.

    Uses the Renyi alpha-divergence instead of the KL divergence. Orders
    alpha < 1 give a tighter bound on the log evidence than the standard
    ELBO.
    """

    def __init__(self, alpha: float = 0.0, num_particles: int = 2,
                 max_plate_nesting: float = float('inf')):
        """
        Parameters:
        - alpha (float): Order of the Renyi divergence; must not equal 1.
          alpha = 0 gives the importance-weighted (IWAE) bound. The limit
          alpha -> 1 recovers the standard KL-based ELBO, and for that case
          Trace_ELBO should be used instead.
        - num_particles (int): Number of particles used to form the
          estimator. With a single particle the objective reduces to the
          standard ELBO for every alpha.
        - max_plate_nesting (int or float): Maximum plate nesting depth. The
          default ``float('inf')`` sets no explicit bound.
        """

183

```

184

185

### Markov Chain Monte Carlo

186

187

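MCMC methods that draw asymptotically exact samples from posterior distributions: the distribution of the samples converges to the true posterior as the chain runs longer.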

188

189

```python { .api }

190

class MCMC:
    """
    Markov Chain Monte Carlo interface for posterior sampling.

    MCMC generates correlated samples whose distribution converges to the
    exact posterior as the chain runs longer (asymptotically exact). It uses
    kernels such as HMC and NUTS.
    """

    def __init__(self, kernel, num_samples: int, warmup_steps: int = None,
                 initial_params: dict = None, chain_id: int = 0, mp_context=None,
                 disable_progbar: bool = False, disable_validation: bool = True,
                 transforms: dict = None, max_tree_depth: int = None,
                 target_accept_prob: float = 0.8, jit_compile: bool = False,
                 num_chains: int = 1):
        """
        Parameters:
        - kernel: MCMC kernel (e.g., HMC, NUTS, RandomWalkKernel)
        - num_samples (int): Number of MCMC samples to generate, excluding the
          discarded warmup samples
        - warmup_steps (int): Number of warmup/burn-in steps, whose samples are
          discarded. Upstream Pyro uses num_samples when this is None.
        - initial_params (dict): Initial parameter values for the chain
        - chain_id (int): Chain identifier for multiple chains
        - mp_context: Multiprocessing context used when chains run in
          separate processes (e.g. "spawn")
        - disable_progbar (bool): Whether to hide the progress bar
        - disable_validation (bool): Whether to disable distribution validation
          checks while sampling; divergent transitions could otherwise raise
          exceptions
        - transforms (dict): Parameter transforms for constrained sampling
        - max_tree_depth (int): Maximum NUTS tree depth
        - target_accept_prob (float): Target acceptance probability for
          adaptive kernels
        - jit_compile (bool): Whether to JIT compile the kernel
        - num_chains (int): Number of chains to run. With
          get_samples(group_by_chain=True), each tensor keeps a leading
          dimension of this size.

        NOTE(review): upstream Pyro configures max_tree_depth,
        target_accept_prob and jit_compile on the kernel (e.g.
        ``NUTS(model, target_accept_prob=0.9)``) rather than on MCMC. Verify
        these three before relying on them here.

        Examples:
        >>> kernel = NUTS(model)
        >>> mcmc = MCMC(kernel, num_samples=1000, warmup_steps=500)
        """

    def run(self, *args, **kwargs):
        """
        Run the MCMC chain(s).

        Parameters:
        - *args, **kwargs: Arguments to pass to the model

        Examples:
        >>> mcmc.run(data)
        >>> samples = mcmc.get_samples()
        """

    def get_samples(self, group_by_chain: bool = False) -> dict:
        """
        Get the post-warmup MCMC samples after running the chain(s).

        Parameters:
        - group_by_chain (bool): If True, keep the chain dimension so every
          tensor's leading dimension has size num_chains. Otherwise, samples
          from all chains are merged along the first dimension.

        Returns:
        dict: Dictionary mapping sample site names to sample tensors

        Examples:
        >>> samples = mcmc.get_samples()
        >>> theta_samples = samples["theta"]
        """

245

246

class HMC:
    """
    Hamiltonian Monte Carlo kernel.

    HMC uses gradient information to make efficient proposals in continuous
    parameter spaces.
    """

    def __init__(self, model, step_size: float = 1.0, num_steps: int = 1,
                 adapt_step_size: bool = True, adapt_mass_matrix: bool = True,
                 full_mass: bool = False, transforms: dict = None,
                 max_plate_nesting: int = None, jit_compile: bool = False,
                 jit_options: dict = None, ignore_jit_warnings: bool = False,
                 target_accept_prob: float = 0.8):
        """
        Parameters:
        - model (callable): Model to sample from
        - step_size (float): Integration (leapfrog) step size
        - num_steps (int): Number of leapfrog steps per iteration
        - adapt_step_size (bool): Whether to adapt the step size during warmup
        - adapt_mass_matrix (bool): Whether to adapt the mass matrix during warmup
        - full_mass (bool): Whether to use a dense mass matrix instead of a
          diagonal one
        - transforms (dict): Parameter transformations
        - max_plate_nesting (int): Optional bound on the number of nested
          plates; needed when the model has discrete sites enumerated in parallel
        - jit_compile (bool): Whether to JIT-trace the log density computation
        - jit_options (dict): Extra keyword arguments for the JIT tracer
        - ignore_jit_warnings (bool): Whether to ignore JIT tracer warnings
        - target_accept_prob (float): Target acceptance probability for step
          size adaptation. Higher values lead to smaller, more robust but
          slower steps.
        """

269

270

class NUTS:
    """
    No-U-Turn Sampler, an adaptive version of HMC.

    NUTS chooses the number of leapfrog steps automatically. It stops
    extending a trajectory when the trajectory starts to turn back on itself
    (a "U-turn").
    """

    def __init__(self, model, step_size: float = 1.0, adapt_step_size: bool = True,
                 adapt_mass_matrix: bool = True, full_mass: bool = False,
                 transforms: dict = None, max_plate_nesting: int = None,
                 max_tree_depth: int = 10, target_accept_prob: float = 0.8,
                 jit_compile: bool = False, jit_options: dict = None,
                 ignore_jit_warnings: bool = False):
        """
        Parameters:
        - model (callable): Model to sample from
        - step_size (float): Initial step size
        - adapt_step_size (bool): Whether to adapt the step size during warmup
        - adapt_mass_matrix (bool): Whether to adapt the mass matrix during warmup
        - full_mass (bool): Whether to use a dense mass matrix instead of a
          diagonal one
        - transforms (dict): Parameter transformations
        - max_plate_nesting (int): Optional bound on the number of nested
          plates; needed when the model has discrete sites enumerated in parallel
        - max_tree_depth (int): Maximum depth of the binary tree built by the
          trajectory-doubling scheme
        - target_accept_prob (float): Target acceptance probability for adaptation
        - jit_compile (bool): Whether to JIT-trace the log density computation
        - jit_options (dict): Extra keyword arguments for the JIT tracer
        - ignore_jit_warnings (bool): Whether to ignore JIT tracer warnings

        Examples:
        >>> nuts_kernel = NUTS(model)
        >>> mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=500)
        """

295

296

class RandomWalkKernel:
    """
    Random walk Metropolis-Hastings kernel.

    Simple MCMC kernel that proposes new states by adding random noise to the
    current state, and accepts or rejects each proposal with the
    Metropolis-Hastings rule. It does not need gradients of the model.

    NOTE(review): upstream Pyro's RandomWalkKernel is constructed as
    ``RandomWalkKernel(model, init_step_size=0.1, target_accept_prob=0.234)``.
    Verify the signature below against the installed version.
    """

    def __init__(self, model, step_size: dict = None, adapt_step_size: bool = True,
                 transforms: dict = None, max_plate_nesting: int = None):
        """
        Parameters:
        - model (callable): Model to sample from
        - step_size (dict): Step sizes for each parameter
        - adapt_step_size (bool): Whether to adapt the step size during warmup
        - transforms (dict): Parameter transforms for constrained sampling
        - max_plate_nesting (int): Optional bound on the number of nested plates
        """

312

```

313

314

### Predictive Sampling

315

316

Generate predictions and samples from trained models.

317

318

```python { .api }

319

class Predictive:
    """
    Generate predictive samples from posterior or prior distributions.

    Predictive enables posterior predictive checks, prior predictive checks,
    and out-of-sample predictions by sampling from the model with different
    parameter configurations.

    Sites that the model observes (``obs=...``) are clamped to the values
    passed to the model. To draw fresh observations, pass ``None`` for the
    observed data if the model supports it, or wrap the model with
    ``pyro.poutine.uncondition``.
    """

    def __init__(self, model, guide=None, posterior_samples: dict = None,
                 num_samples: int = None, return_sites: list = None,
                 parallel: bool = False, batch_ndims: int = 1):
        """
        Parameters:
        - model (callable): Generative model function
        - guide (callable, optional): Variational guide used to draw posterior
          samples of latent sites
        - posterior_samples (dict, optional): Pre-computed posterior samples
          keyed by site name (e.g. ``mcmc.get_samples()``)
        - num_samples (int, optional): Number of samples to generate. When
          posterior_samples is given, its leading dimension sets the count.
        - return_sites (list, optional): Sites to include in output
        - parallel (bool): Whether to vectorize sampling over an outer plate.
          This requires the model's batch dimensions to be annotated with plates.
        - batch_ndims (int): Number of batch dimensions in the posterior samples

        NOTE(review): upstream Pyro orders the arguments as (model,
        posterior_samples, guide, ...). Pass guide and posterior_samples by
        keyword so calls work with either order.

        Examples:
        >>> # Posterior predictive with guide
        >>> predictive = Predictive(model, guide=guide, num_samples=1000)
        >>> samples = predictive(data)
        >>>
        >>> # Prior predictive; uncondition so "obs" is sampled, not clamped
        >>> predictive = Predictive(pyro.poutine.uncondition(model), num_samples=100)
        >>> prior_samples = predictive(data)
        """

    def __call__(self, *args, **kwargs) -> dict:
        """
        Generate predictive samples.

        Parameters:
        - *args, **kwargs: Arguments to pass to the model (and guide)

        Returns:
        dict: Dictionary mapping site names to sample tensors

        Examples:
        >>> samples = predictive(test_data)
        >>> predictions = samples["obs"]  # equals test_data unless "obs" is unconditioned
        """

365

366

class WeighedPredictive:
    """
    Generate weighted predictive samples using importance sampling.

    Useful when posterior samples come from importance sampling or when
    samples have non-uniform weights.
    """

    def __init__(self, model, guide=None, posterior_samples: dict = None,
                 weights: torch.Tensor = None, num_samples: int = None,
                 return_sites: list = None, parallel: bool = False):
        """
        Parameters:
        - model (callable): Generative model function
        - guide (callable, optional): Guide function used to draw posterior samples
        - posterior_samples (dict, optional): Pre-computed posterior samples
          keyed by site name
        - weights (Tensor, optional): Sample weights; presumably one weight per
          posterior sample (TODO confirm whether these are linear or log weights)
        - num_samples (int, optional): Number of samples to generate
        - return_sites (list, optional): Sites to include in output
        - parallel (bool): Whether to vectorize sampling rather than drawing
          samples one at a time
        """

385

386

class EmpiricalMarginal:
    """
    Empirical marginal distribution from MCMC or SVI samples.

    Converts a collection of samples into a distribution object that can be
    used like any other Pyro distribution.

    NOTE(review): upstream Pyro's EmpiricalMarginal is built from a
    TracePosterior (``EmpiricalMarginal(trace_posterior, sites=None)``). The
    ``(samples, log_weights)`` signature documented here matches
    ``pyro.distributions.Empirical``. Confirm which one is intended before
    copying the example.
    """

    def __init__(self, samples: torch.Tensor, log_weights: torch.Tensor = None):
        """
        Parameters:
        - samples (Tensor): Sample values; presumably the leading dimension
          indexes the samples (TODO confirm)
        - log_weights (Tensor, optional): Log weights for the samples;
          presumably uniform weights are used when omitted (TODO confirm)

        Examples:
        >>> samples = mcmc.get_samples()["theta"]
        >>> marginal = EmpiricalMarginal(samples)
        >>> new_sample = marginal.sample()
        """

405

```

406

407

### Importance Sampling

408

409

Importance sampling and sequential Monte Carlo methods for marginal likelihood estimation, model comparison, and filtering in state space models.

410

411

```python { .api }

412

class Importance:
    """
    Importance sampling for posterior inference and marginal likelihood estimation.

    Draws samples from a proposal (the guide) and weights them by the model's
    probability relative to the proposal. The weights give an estimate of the
    model evidence (marginal likelihood), which is useful for model comparison
    and selection.
    """

    def __init__(self, model, guide=None, num_samples: int = None):
        """
        Parameters:
        - model (callable): Generative model
        - guide (callable, optional): Importance sampling distribution
          (proposal). When omitted, upstream Pyro proposes from the model's prior.
        - num_samples (int, optional): Number of importance samples. When
          omitted, upstream Pyro warns and uses 10.

        Examples:
        >>> importance = Importance(model, guide, num_samples=10000)
        >>> importance.run(data)
        >>> log_evidence = importance.get_log_normalizer()
        """

    def run(self, *args, **kwargs) -> "Importance":
        """
        Draw and weight the importance samples.

        Parameters:
        - *args, **kwargs: Arguments to pass to model and guide

        Returns:
        Importance: This object, so calls can be chained. The evidence
        estimate is read with ``get_log_normalizer()``; ``run`` itself does
        not return it.
        """

    def get_log_normalizer(self) -> torch.Tensor:
        """
        Estimate the log marginal likelihood from the samples drawn by ``run``.

        Returns:
        Tensor: Log of the mean importance weight, i.e. the log evidence estimate
        """

442

443

class SMCFilter:
    """
    Sequential Monte Carlo filtering for state space models.

    Implements particle filtering for sequential Bayesian inference in time
    series and state space models. The intended sequence is: call ``init``
    once, call ``step`` once per time step, and read the filtered state with
    ``get_empirical``.
    """

    def __init__(self, model, guide, num_particles: int, max_plate_nesting: int,
                 *, ess_threshold: float = 0.5):
        """
        Parameters:
        - model (object): State space model; an object with ``init`` and
          ``step`` methods that contain Pyro primitives
        - guide (object): Proposal for the particles; an object with ``init``
          and ``step`` methods mirroring the model's
        - num_particles (int): Number of particles to maintain
        - max_plate_nesting (int): Maximum plate nesting depth
        - ess_threshold (float): Resampling happens when the effective sample
          size drops below ``ess_threshold * num_particles``

        Examples:
        >>> smc = SMCFilter(model, guide, num_particles=100, max_plate_nesting=0)
        >>> smc.init(initial=torch.tensor(0.0))
        >>> for y in observations:
        ...     smc.step(y)
        >>> posterior = smc.get_empirical()
        """

    def init(self, *args, **kwargs):
        """
        Initialize the particles by running the model's and guide's ``init``.

        Parameters:
        - *args, **kwargs: Forwarded to ``model.init`` and ``guide.init``
        """

    def step(self, *args, **kwargs):
        """
        Advance the filter by one time step, resampling when needed.

        Parameters:
        - *args, **kwargs: Forwarded to ``model.step`` and ``guide.step``
        """

    def get_empirical(self) -> dict:
        """
        Return the current filtered state.

        Returns:
        dict: Mapping from state names to empirical distributions over the particles
        """

459

```

460

461

### Specialized Inference Methods

462

463

Advanced inference algorithms for specific model types and scenarios.

464

465

```python { .api }

466

class SVGD:
    """
    Stein Variational Gradient Descent for non-parametric inference.

    SVGD moves a set of particles so that together they approximate the
    posterior distribution, by minimizing the kernelized Stein discrepancy.
    The model may only contain continuous latent variables.
    """

    def __init__(self, model, kernel, optimizer, num_particles: int):
        """
        Parameters:
        - model (callable): Model function containing Pyro primitives
        - kernel: SVGD-compatible Stein kernel (e.g. an RBF Stein kernel)
        - optimizer (PyroOptim): Pyro optimizer used to update the particles
        - num_particles (int): Number of particles to optimize

        NOTE(review): upstream Pyro names the optimizer argument ``optim`` and
        also requires ``max_plate_nesting``. Verify this signature against the
        installed version.
        """

482

483

class ReweightedWakeSleep:
    """
    Reweighted Wake-Sleep (RWS) algorithm for deep generative models.

    In the wake phase, the model parameters are trained with an
    importance-weighted objective that uses the guide as the proposal. The
    guide parameters are trained with wake-phase and/or sleep-phase updates.
    Because it does not need reparameterized samples, RWS is an alternative to
    standard variational inference for models with discrete latents or
    stochastic control flow.

    NOTE(review): upstream Pyro implements ReweightedWakeSleep as an ELBO-style
    loss used with SVI (configured via num_particles, insomnia, ...) rather
    than taking (model, guide, wake_loss, sleep_loss). Confirm the interface
    below.
    """

    def __init__(self, model, guide, wake_loss, sleep_loss):
        """
        Parameters:
        - model (callable): Generative model
        - guide (callable): Recognition model (proposal for the latents)
        - wake_loss: Loss function for the wake phase
        - sleep_loss: Loss function for the sleep phase
        """

499

500

def config_enumerate(default: str = None, expand: bool = False, num_samples: int = None):
    """
    Configure automatic enumeration over discrete latent variables.

    Decorator that marks discrete sample sites for enumeration.
    Enumeration-aware inference (e.g. TraceEnum_ELBO) can then marginalize
    them exactly in models with both discrete and continuous latent variables.

    Parameters:
    - default (str): Enumeration strategy, "sequential" or "parallel"
      (upstream Pyro defaults to "parallel")
    - expand (bool): Whether to expand enumerated sample values. This only
      applies to exhaustive enumeration (num_samples=None).
    - num_samples (int): If set, draw this many local Monte Carlo samples
      instead of exhaustively enumerating the support

    Examples:
    >>> @config_enumerate(default="parallel")
    ... def model():
    ...     z = pyro.sample("z", dist.Categorical(torch.ones(3)))
    ...     # Categorical draws are integer indices; cast before using one
    ...     # as a Normal location.
    ...     return pyro.sample("x", dist.Normal(z.float(), 1.0))
    """

518

519

def infer_discrete(first_available_dim: int = None, temperature: float = 1.0,
                   cooler: callable = None, *, fn=None):
    """
    Infer discrete latent variables from their posterior.

    Effect handler that wraps a model whose discrete sites are marked for
    enumeration. The wrapped model either draws those sites from their
    posterior conditioned on the observations (temperature=1) or sets them
    to a jointly most likely (MAP) assignment (temperature=0). It relies on
    enumeration; it does not fall back to approximate sampling.

    Parameters:
    - first_available_dim (int): First tensor dimension, counting from the
      right (a negative integer), that is available for parallel enumeration
    - temperature (float): 1 samples from the posterior; 0 computes a MAP
      assignment
    - cooler (callable): Cooling schedule for simulated annealing.
      NOTE(review): upstream Pyro's infer_discrete has no such parameter;
      verify before relying on it.
    - fn (callable, keyword-only): Model to wrap

    Returns:
    callable: The wrapped model

    Examples:
    >>> # `model` has its discrete sites marked for enumeration,
    >>> # e.g. via config_enumerate(default="parallel").
    >>> serving_model = infer_discrete(fn=model, first_available_dim=-1)
    >>> trace = pyro.poutine.trace(serving_model).get_trace(data)
    >>> z = trace.nodes["z"]["value"]
    """

536

```

537

538

## Examples

539

540

### Basic SVI Training

541

542

```python

543

import torch

import pyro
import pyro.distributions as dist
from pyro.distributions import constraints
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

# Synthetic observations so the example runs on its own.
data = 3.0 + 0.5 * torch.randn(100)


def model(data):
    mu = pyro.sample("mu", dist.Normal(0, 10))
    sigma = pyro.sample("sigma", dist.LogNormal(0, 1))

    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(mu, sigma), obs=data)


def guide(data):
    # Learnable Normal approximation for mu.
    mu_q = pyro.param("mu_q", torch.tensor(0.0))
    sigma_q = pyro.param("sigma_q", torch.tensor(1.0), constraint=constraints.positive)
    pyro.sample("mu", dist.Normal(mu_q, sigma_q))

    # Learnable LogNormal approximation for sigma. Reusing the fixed prior
    # here would leave sigma's posterior unfitted.
    sigma_loc = pyro.param("sigma_loc", torch.tensor(0.0))
    sigma_scale = pyro.param("sigma_scale", torch.tensor(1.0), constraint=constraints.positive)
    pyro.sample("sigma", dist.LogNormal(sigma_loc, sigma_scale))


# Training
pyro.clear_param_store()  # start from fresh parameters when re-running
svi = SVI(model, guide, Adam({"lr": 0.01}), Trace_ELBO())
losses = []
for step in range(1000):
    loss = svi.step(data)
    losses.append(loss)

568

```

569

570

### MCMC Sampling

571

572

```python

573

import torch

import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS

# Synthetic observations so the example runs on its own.
data = 3.0 + 0.5 * torch.randn(100)


def model(data):
    mu = pyro.sample("mu", dist.Normal(0, 10))
    sigma = pyro.sample("sigma", dist.LogNormal(0, 1))

    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(mu, sigma), obs=data)


# MCMC sampling
nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=500)
mcmc.run(data)

# Get samples
samples = mcmc.get_samples()
mu_samples = samples["mu"]
sigma_samples = samples["sigma"]

591

```

592

593

### Posterior Predictive Checks

594

595

```python

596

import torch

from pyro import poutine
from pyro.infer import Predictive

# `model`, `guide`, `data` and `mcmc` come from the examples above.
# The model clamps "obs" to whatever data it receives, so wrap it with
# poutine.uncondition to draw fresh replicated observations instead.
predictive_model = poutine.uncondition(model)

# After training SVI: posterior predictive samples via the fitted guide.
predictive = Predictive(predictive_model, guide=guide, num_samples=1000)
posterior_samples = predictive(data)
replicated_obs = posterior_samples["obs"]

# After running MCMC: reuse the MCMC draws instead of a guide.
mcmc_predictive = Predictive(predictive_model, posterior_samples=mcmc.get_samples())
mcmc_replicated_obs = mcmc_predictive(data)["obs"]

# Generate predictions for new data. Because "obs" is resampled rather than
# clamped, `new_data` only sets how many observations to generate.
new_data = torch.zeros(20)
predictive_new = Predictive(predictive_model, guide=guide, num_samples=100)
predictions = predictive_new(new_data)["obs"]

605

```

606

607

### Model Comparison with Importance Sampling

608

609

```python

610

import torch

from pyro.infer import Importance

# Compare two models (model1/guide1 and model2/guide2 are user-defined).
# run() returns the Importance object itself; the log evidence estimate
# comes from get_log_normalizer().
importance1 = Importance(model1, guide1, num_samples=10000)
importance1.run(data)
log_evidence1 = importance1.get_log_normalizer()

importance2 = Importance(model2, guide2, num_samples=10000)
importance2.run(data)
log_evidence2 = importance2.get_log_normalizer()

# Bayes factor (work in log space first for numerical stability)
log_bayes_factor = log_evidence1 - log_evidence2
bayes_factor = torch.exp(log_bayes_factor)
print(f"Bayes factor (Model 1 vs Model 2): {bayes_factor}")

622

```