or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

diagnostics.mddistributions.mdhandlers.mdindex.mdinference.mdoptimization.mdprimitives.mdutilities.md

inference.mddocs/

0

# Inference

1

2

NumPyro provides multiple inference algorithms for Bayesian posterior computation including Markov Chain Monte Carlo (MCMC) samplers, variational inference methods, ensemble techniques, and specialized algorithms. All inference methods are built on JAX for efficient automatic differentiation and JIT compilation.

3

4

## Capabilities

5

6

### MCMC Algorithms

7

8

Markov Chain Monte Carlo methods for sampling from posterior distributions.

9

10

#### Core MCMC Infrastructure

11

12

```python { .api }

13

class MCMC:
    """
    Wrapper class for Markov Chain Monte Carlo inference algorithms.

    Args:
        sampler: MCMC kernel instance, e.g. ``NUTS(model)`` or ``HMC(model)``.
        num_warmup: Number of warmup (adaptation) steps. Keyword-only.
        num_samples: Number of samples to draw after warmup. Keyword-only.
        num_chains: Number of chains to run.
        thinning: Keep every ``thinning``-th post-warmup sample (1 keeps all).
        postprocess_fn: Post-processing callable that converts unconstrained
            sample values into constrained values (and computes deterministic
            sites); when None the sampler's own post-processing is used.
        chain_method: How multiple chains are run: 'parallel' (one device per
            chain), 'sequential', or 'vectorized' (``jax.vmap`` on one device).
            'parallel' falls back to drawing chains sequentially, with a
            warning, when fewer devices than chains are available.
        progress_bar: Whether to show a progress bar.
        jit_model_args: If True, compile the potential energy as a function of
            the model arguments, so calling ``run`` again with new data of the
            same shape does not trigger recompilation.
    """

    def __init__(self, sampler, *, num_warmup: int, num_samples: int,
                 num_chains: int = 1, thinning: int = 1,
                 postprocess_fn: Optional[Callable] = None,
                 chain_method: str = 'parallel', progress_bar: bool = True,
                 jit_model_args: bool = False): ...

    def run(self, rng_key: Array, *args, extra_fields=(), init_params=None, **kwargs) -> None:
        """
        Run MCMC sampling.

        Args:
            rng_key: Random key for sampling. With several chains, either a
                single key (split into one key per chain) or a batch of
                ``num_chains`` keys.
            *args: Arguments to pass to the model.
            extra_fields: Names of additional kernel-state fields to collect
                during sampling (e.g. ``'diverging'``, ``'potential_energy'``).
            init_params: Initial parameter values; for a NumPyro model these are
                latent values in unconstrained space. When None the kernel's
                initialization strategy is used.
            **kwargs: Keyword arguments to pass to the model.
        """

    def get_samples(self, group_by_chain: bool = False) -> dict:
        """
        Get posterior samples.

        Args:
            group_by_chain: If True, each array has leading dimensions
                ``(num_chains, num_samples)``; otherwise the chains are
                concatenated along a single leading dimension.

        Returns:
            Dictionary mapping site names to arrays of posterior samples.
        """

    def get_extra_fields(self, group_by_chain: bool = False) -> dict:
        """Get the fields requested via ``extra_fields`` in ``run`` (e.g. diagnostics)."""

    def print_summary(self, prob: float = 0.9, exclude_deterministic: bool = True) -> None:
        """Print summary statistics (mean, std, credible interval of mass ``prob``, n_eff, r_hat)."""

59

```

60

61

#### Hamiltonian Monte Carlo

62

63

```python { .api }

64

import math


class HMC:
    """
    Hamiltonian Monte Carlo kernel with a fixed trajectory length.

    Provide either ``model`` or ``potential_fn``.

    Args:
        model: Python callable containing NumPyro primitives.
        potential_fn: Alternative to ``model``: callable returning the potential
            energy (negative log joint density) of unconstrained parameters.
        kinetic_fn: Callable returning the kinetic energy from the inverse mass
            matrix and momentum; Euclidean kinetic energy when None.
        step_size: Initial step size for the leapfrog integrator.
        inverse_mass_matrix: Initial inverse mass matrix (identity when None).
        adapt_step_size: Whether to adapt the step size during warmup.
        adapt_mass_matrix: Whether to adapt the mass matrix during warmup.
        dense_mass: Whether to use a dense (rather than diagonal) mass matrix;
            may also be a list of tuples of site names for a block-structured
            mass matrix.
        target_accept_prob: Target acceptance probability for step size adaptation.
        num_steps: Number of leapfrog steps per iteration; when None it is
            ``trajectory_length / step_size``.
        trajectory_length: Length of the simulated trajectory (default 2*pi).
        init_strategy: Per-site initialization strategy (numpyro's default is
            ``init_to_uniform``; None stands in for it in this stub).
        find_heuristic_step_size: Whether to search for a reasonable initial step size.
        forward_mode_differentiation: Use forward-mode instead of reverse-mode AD
            (needed e.g. when the model uses ``jax.lax.while_loop``).
        regularize_mass_matrix: Whether to regularize the mass matrix estimated
            during warmup for numerical stability.
    """

    def __init__(self, model=None, potential_fn=None, kinetic_fn=None,
                 step_size=1.0, inverse_mass_matrix=None, adapt_step_size=True,
                 adapt_mass_matrix=True, dense_mass=False, target_accept_prob=0.8,
                 num_steps=None, trajectory_length=2 * math.pi, init_strategy=None,
                 find_heuristic_step_size=False, forward_mode_differentiation=False,
                 regularize_mass_matrix=True): ...


class NUTS:
    """
    No-U-Turn Sampler: an HMC variant that chooses the trajectory length
    adaptively by doubling a tree of leapfrog steps until a U-turn is detected.

    Provide either ``model`` or ``potential_fn``.

    Args:
        model: Python callable containing NumPyro primitives.
        potential_fn: Alternative to ``model`` (see ``HMC``).
        kinetic_fn: Kinetic energy callable (see ``HMC``).
        step_size: Initial step size.
        inverse_mass_matrix: Initial inverse mass matrix (identity when None).
        adapt_step_size: Whether to adapt the step size during warmup.
        adapt_mass_matrix: Whether to adapt the mass matrix during warmup.
        dense_mass: Dense/diagonal/block-structured mass matrix (see ``HMC``).
        target_accept_prob: Target acceptance probability.
        trajectory_length: Accepted for signature parity with ``HMC``; NUTS
            does not use it.
        max_tree_depth: Maximum depth of the trajectory tree; caps the number
            of leapfrog steps per iteration at roughly ``2**max_tree_depth``.
        init_strategy: Per-site initialization strategy (numpyro default:
            ``init_to_uniform``).
        find_heuristic_step_size: Whether to search for a reasonable initial step size.
        forward_mode_differentiation: Use forward-mode AD.
        regularize_mass_matrix: Whether to regularize the adapted mass matrix.
    """

    def __init__(self, model=None, potential_fn=None, kinetic_fn=None,
                 step_size=1.0, inverse_mass_matrix=None, adapt_step_size=True,
                 adapt_mass_matrix=True, dense_mass=False, target_accept_prob=0.8,
                 trajectory_length=None, max_tree_depth=10, init_strategy=None,
                 find_heuristic_step_size=False, forward_mode_differentiation=False,
                 regularize_mass_matrix=True): ...


class SA:
    """
    Sample Adaptive MCMC kernel: a gradient-free sampler that keeps a set of
    ``adapt_state_size`` points and proposes new points from a Gaussian
    approximation built from that set.

    Provide either ``model`` or ``potential_fn``.

    Args:
        model: Python callable containing NumPyro primitives.
        potential_fn: Alternative to ``model`` (see ``HMC``).
        adapt_state_size: Number of points kept in the adaptive state; when
            None it defaults to twice the latent dimension.
        dense_mass: Whether the Gaussian proposal uses a dense covariance.
        init_strategy: Per-site initialization strategy (numpyro default:
            ``init_to_uniform``).
    """

    def __init__(self, model=None, potential_fn=None, adapt_state_size=None,
                 dense_mass=True, init_strategy=None): ...


class BarkerMH:
    """
    Metropolis-Hastings kernel with the gradient-informed Barker proposal,
    which is robust to the choice of step size.

    Provide either ``model`` or ``potential_fn``.

    Args:
        model: Python callable containing NumPyro primitives.
        potential_fn: Alternative to ``model`` (see ``HMC``).
        step_size: Initial proposal step size.
        inverse_mass_matrix: Initial inverse mass matrix (identity when None).
        adapt_step_size: Whether to adapt the step size during warmup.
        adapt_mass_matrix: Whether to adapt the mass matrix during warmup.
        dense_mass: Whether to use a dense mass matrix.
        target_accept_prob: Target acceptance probability (default 0.4).
        init_strategy: Per-site initialization strategy (numpyro default:
            ``init_to_uniform``).
    """

    def __init__(self, model=None, potential_fn=None, step_size=1.0,
                 inverse_mass_matrix=None, adapt_step_size=True,
                 adapt_mass_matrix=True, dense_mass=False,
                 target_accept_prob=0.4, init_strategy=None): ...

135

```

136

137

#### HMC Variants and Extensions

138

139

```python { .api }

140

class HMCGibbs:
    """
    HMC-within-Gibbs: alternates user-supplied Gibbs updates of ``gibbs_sites``
    with HMC/NUTS updates of the remaining latent sites.

    Args:
        inner_kernel: HMC or NUTS kernel used for the non-Gibbs sites.
        gibbs_fn: Callable ``gibbs_fn(rng_key, gibbs_sites, hmc_sites)`` that
            returns a dict of new values for the Gibbs sites, conditioned on
            the current values of the HMC sites.
        gibbs_sites: Names of the sites updated by ``gibbs_fn``.
    """

    def __init__(self, inner_kernel, gibbs_fn, gibbs_sites): ...


class DiscreteHMCGibbs:
    """
    HMC-within-Gibbs in which every discrete latent site of the model is
    updated by Gibbs steps; the continuous sites use the inner kernel.

    Args:
        inner_kernel: HMC or NUTS kernel for the continuous latent sites.
        random_walk: If False, draw each discrete site from its conditional
            given the other sites; if True, propose uniformly from its support.
        modified: Whether to use the modified ("Metropolised") Gibbs proposal,
            which always proposes a value different from the current one.
    """

    def __init__(self, inner_kernel, *, random_walk=False, modified=False): ...


class HMCECS:
    """
    HMC with Energy Conserving Subsampling for large datasets: the likelihood
    is estimated on data subsamples declared with
    ``numpyro.plate(..., subsample_size=...)`` and the subsample indices are
    refreshed by Gibbs steps between HMC updates.

    Args:
        inner_kernel: HMC or NUTS kernel.
        num_blocks: Number of blocks the subsample indices are partitioned into.
        proxy: Likelihood proxy used as a control variate to reduce estimator
            variance (e.g. ``HMCECS.taylor_proxy(...)``); None uses naive
            subsampling.
    """

    def __init__(self, inner_kernel, *, num_blocks=1, proxy=None): ...


class MixedHMC:
    """
    Mixed Hamiltonian Monte Carlo for models with both discrete and continuous
    latent variables: discrete-site updates are interleaved with the HMC
    trajectory of the continuous sites.

    Args:
        inner_kernel: An ``HMC`` kernel for the continuous latent sites.
        num_discrete_updates: Number of discrete-site updates per iteration;
            defaults to the number of discrete latent sites when None.
        random_walk: Proposal type for discrete sites (as in ``DiscreteHMCGibbs``).
        modified: Whether to use the modified discrete proposal
            (as in ``DiscreteHMCGibbs``).
    """

    def __init__(self, inner_kernel, *, num_discrete_updates=None,
                 random_walk=False, modified=False): ...

185

```

186

187

### Ensemble Methods

188

189

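Gradient-free ensemble samplers: each chain is one walker of the ensemble, and walkers are moved using the positions of the other walkers. Run them in `MCMC` with an even `num_chains` greater than 1 and `chain_method="vectorized"`.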

190

191

```python { .api }

192

class ESS:
    """
    Ensemble Slice Sampler: a gradient-free ensemble kernel in which every
    chain is one walker and each walker is moved by slice sampling along
    directions derived from the other walkers.

    Use with an even ``num_chains`` greater than 1 and
    ``chain_method="vectorized"`` in ``MCMC``.

    Args:
        model: Python callable containing NumPyro primitives.
        potential_fn: Alternative to ``model``: callable returning the
            potential energy of unconstrained parameters.
        randomize_split: Whether to randomly re-split the walkers into the two
            halves of the ensemble at each iteration.
        moves: Dict mapping ESS moves (e.g. ``ESS.DifferentialMove()``) to their
            selection probabilities; None uses the default move.
        max_steps: Maximum number of stepping-out steps per slice update.
        max_iter: Maximum number of expansions/contractions per slice update.
        init_mu: Initial scale factor of the slice width.
        tune_mu: Whether to adapt the scale factor during warmup.
    """

    def __init__(self, model=None, potential_fn=None, randomize_split=True,
                 moves=None, max_steps=10000, max_iter=10000, init_mu=1.0,
                 tune_mu=True): ...


class AIES:
    """
    Affine-Invariant Ensemble Sampler: a gradient-free ensemble kernel in which
    every chain is one walker.

    Use with an even ``num_chains`` greater than 1 and
    ``chain_method="vectorized"`` in ``MCMC``.

    Args:
        model: Python callable containing NumPyro primitives.
        potential_fn: Alternative to ``model``.
        randomize_split: Whether to randomly re-split the walkers into the two
            halves of the ensemble at each iteration.
        moves: Dict mapping AIES moves (e.g. ``AIES.DEMove()``,
            ``AIES.StretchMove()``) to their selection probabilities; None uses
            the default move.
    """

    def __init__(self, model=None, potential_fn=None, randomize_split=False,
                 moves=None): ...

214

```

215

216

### Variational Inference

217

218

Stochastic variational inference for approximate posterior computation.

219

220

#### Core SVI Infrastructure

221

222

```python { .api }

223

class SVI:
    """
    Stochastic Variational Inference.

    Args:
        model: Model function containing NumPyro primitives.
        guide: Guide (variational family) function.
        optim: Optimizer: a ``numpyro.optim`` optimizer or an optax
            ``GradientTransformation``.
        loss: ELBO instance defining the objective, e.g. ``Trace_ELBO()``; the
            number of particles is configured on the loss, not here.
        **static_kwargs: Keyword arguments passed to model and guide on every
            call; they stay constant during fitting.
    """

    def __init__(self, model, guide, optim, loss, **static_kwargs): ...

    def init(self, rng_key: Array, *args, init_params=None, **kwargs) -> "SVIState":
        """Initialize parameters and optimizer state; returns an ``SVIState``."""

    def run(self, rng_key: Array, num_steps: int, *args, progress_bar: bool = True,
            stable_update: bool = False, forward_mode_differentiation: bool = False,
            init_state=None, init_params=None, **kwargs) -> "SVIRunResult":
        """
        Run stochastic variational inference for ``num_steps`` steps.

        Args:
            rng_key: Random key for stochastic optimization.
            num_steps: Number of optimization steps.
            *args: Arguments to pass to model and guide.
            progress_bar: Whether to show a progress bar.
            stable_update: If True, use ``stable_update`` so that steps with a
                non-finite loss or gradient leave the state unchanged.
            forward_mode_differentiation: Use forward-mode AD for the gradients.
            init_state: Optional ``SVIState`` to resume optimization from.
            init_params: Optional initial values for the parameters.
            **kwargs: Keyword arguments to pass to model and guide.

        Returns:
            ``SVIRunResult`` with the final ``params``, final ``state`` and the
            per-step ``losses``.
        """

    def step(self, svi_state, *args, **kwargs) -> tuple:
        """Take a single SVI step; returns ``(new_svi_state, loss)``."""

    def stable_update(self, svi_state, *args, **kwargs) -> tuple:
        """Like ``step``, but keep the old state when the loss or gradient is not finite."""

    def evaluate(self, svi_state, *args, **kwargs):
        """Evaluate the loss at ``svi_state`` without updating it."""

    def get_params(self, svi_state) -> dict:
        """Return the (constrained) parameter values stored in ``svi_state``."""


class SVIRunResult:
    """Result object from SVI.run()."""
    params: dict  # Final parameter values
    state: "SVIState"  # Final SVI state, usable as init_state to resume
    losses: Array  # Loss value at each optimization step

264

```

265

266

#### ELBO Objectives

267

268

```python { .api }

269

class ELBO:
    """
    Base class for Evidence Lower BOund objectives; ``loss`` returns the
    negative ELBO, which SVI minimizes.

    Args:
        num_particles: Number of particles (samples) used to estimate the ELBO.
        vectorize_particles: If True, evaluate the particles in parallel with
            ``jax.vmap``; otherwise use ``jax.lax.map``.
    """

    def __init__(self, num_particles: int = 1, vectorize_particles: bool = True): ...

    def loss(self, rng_key: Array, param_map: dict, model: Callable, guide: Callable,
             *args, **kwargs) -> float:
        """Return a Monte Carlo estimate of the negative ELBO for ``param_map``."""


class Trace_ELBO(ELBO):
    """Standard Monte Carlo ELBO estimator; the usual choice for continuous latent variables."""

    def __init__(self, num_particles: int = 1, vectorize_particles: bool = True): ...


class TraceEnum_ELBO(ELBO):
    """
    ELBO with exact enumeration over discrete latent variables.

    Args:
        num_particles: Number of particles for the remaining (non-enumerated) variables.
        max_plate_nesting: Maximum number of nested plates in the model;
            enumeration dimensions are allocated to the left of them.
        vectorize_particles: Whether to vectorize over particles.
    """

    def __init__(self, num_particles: int = 1, max_plate_nesting: float = float("inf"),
                 vectorize_particles: bool = True): ...


class TraceGraph_ELBO(ELBO):
    """
    ELBO whose score-function gradient terms are Rao-Blackwellized using the
    dependency structure of model and guide; intended for non-reparameterizable
    latent variables.

    Args:
        num_particles: Number of particles for Monte Carlo estimation.
        vectorize_particles: Whether to vectorize over particles.
    """

    def __init__(self, num_particles: int = 1, vectorize_particles: bool = True): ...


class TraceMeanField_ELBO(ELBO):
    """
    ELBO that uses analytic KL divergences where available; results may be
    incorrect if the mean-field condition between model and guide does not hold.
    """

    def __init__(self, num_particles: int = 1, vectorize_particles: bool = True): ...


class RenyiELBO(ELBO):
    """
    Variational Rényi bound (alpha-divergence) objective.

    Args:
        alpha: Order of the alpha-divergence. ``alpha = 0`` gives the importance
            weighted autoencoder objective; ``alpha = 1`` is not supported
            (use ``Trace_ELBO``) and ``alpha >= 0`` keeps a lower bound.
        num_particles: Number of particles; the bound is only tighter than the
            ELBO with more than one particle, hence the default of 2.
        vectorize_particles: Whether to vectorize over particles.
    """

    def __init__(self, alpha: float = 0, num_particles: int = 2,
                 vectorize_particles: bool = True): ...

333

```

334

335

#### Automatic Guide Generation

336

337

```python { .api }

338

# Located in numpyro.infer.autoguide module.
#
# In numpyro the default ``init_loc_fn`` is ``init_to_uniform`` (``init_to_median``
# for AutoDelta); ``None`` stands in for that default in these signature stubs.
# All constructor options after ``model`` are keyword-only.


class AutoGuide:
    """
    Base class for automatic variational guides.

    Args:
        model: Model function.
        prefix: Prefix for the names of the guide's parameters.
        init_loc_fn: Per-site initialization strategy for the guide's locations.
        create_plates: Optional callable taking the model's ``*args, **kwargs``
            and returning a plate or list of plates, so the guide subsamples
            local latent variables the same way the model does.
    """

    def __init__(self, model: Callable, *, prefix: str = "auto", init_loc_fn=None,
                 create_plates=None): ...

    def sample_posterior(self, rng_key: Array, params: dict, *args,
                         sample_shape=(), **kwargs) -> dict:
        """Draw ``sample_shape`` samples of the latent sites from the approximate posterior."""

    def median(self, params: dict) -> dict:
        """Compute the median of each latent site under the approximate posterior."""

    def quantiles(self, params: dict, quantiles) -> dict:
        """Compute the given quantiles of each latent site under the approximate posterior."""


class AutoNormal(AutoGuide):
    """
    Mean-field guide: an independent Normal distribution per latent site (in
    unconstrained space), with per-site location and scale parameters.

    Args:
        model: Model function.
        prefix: Prefix for parameter names.
        init_loc_fn: Initialization strategy for the locations.
        init_scale: Initial scale of each Normal.
        create_plates: Callable creating plates for subsampled local variables.
    """

    def __init__(self, model: Callable, *, prefix: str = "auto", init_loc_fn=None,
                 init_scale: float = 0.1, create_plates=None): ...


class AutoMultivariateNormal(AutoGuide):
    """
    Multivariate Normal with full covariance over the flattened vector of all
    latent sites (in unconstrained space).

    Args:
        model: Model function.
        prefix: Prefix for parameter names.
        init_loc_fn: Initialization strategy for the location.
        init_scale: Initial scale of the covariance's Cholesky factor.
    """

    def __init__(self, model: Callable, *, prefix: str = "auto", init_loc_fn=None,
                 init_scale: float = 0.1): ...


class AutoLowRankMultivariateNormal(AutoGuide):
    """
    Multivariate Normal over the flattened latent vector with a
    low-rank-plus-diagonal covariance.

    Args:
        model: Model function.
        prefix: Prefix for parameter names.
        init_loc_fn: Initialization strategy for the location.
        init_scale: Initial scale parameter.
        rank: Rank of the low-rank factor; when None it is derived from the
            latent dimension (about its square root).
    """

    def __init__(self, model: Callable, *, prefix: str = "auto", init_loc_fn=None,
                 init_scale: float = 0.1, rank=None): ...


class AutoDiagonalNormal(AutoGuide):
    """
    Normal distribution with diagonal covariance over the flattened vector of
    all latent sites. Not an alias of ``AutoNormal``: its parameters are a
    single flattened loc/scale pair instead of per-site parameters.

    Args:
        model: Model function.
        prefix: Prefix for parameter names.
        init_loc_fn: Initialization strategy for the location.
        init_scale: Initial scale.
    """

    def __init__(self, model: Callable, *, prefix: str = "auto", init_loc_fn=None,
                 init_scale: float = 0.1): ...


class AutoLaplaceApproximation(AutoGuide):
    """
    Laplace approximation: finds a MAP point and uses the inverse Hessian of
    the negative log joint there as the covariance of a multivariate Normal.

    Args:
        model: Model function.
        prefix: Prefix for parameter names.
        init_loc_fn: Initialization strategy for the MAP search.
        hessian_fn: Optional callable ``hessian_fn(f, x)`` returning the Hessian
            of ``f`` at ``x``; ``jax.hessian`` is used when None.
    """

    def __init__(self, model: Callable, *, prefix: str = "auto", init_loc_fn=None,
                 hessian_fn=None): ...


class AutoDelta(AutoGuide):
    """
    Point-estimate guide (MAP approximation).

    Args:
        model: Model function.
        prefix: Prefix for parameter names.
        init_loc_fn: Initialization strategy for the point estimates
            (numpyro default: ``init_to_median``).
        create_plates: Callable creating plates for subsampled local variables.
    """

    def __init__(self, model: Callable, *, prefix: str = "auto", init_loc_fn=None,
                 create_plates=None): ...


class AutoIAFNormal(AutoGuide):
    """
    Inverse Autoregressive Flow with a standard Normal base distribution.

    Args:
        model: Model function.
        prefix: Prefix for parameter names.
        init_loc_fn: Initialization strategy.
        num_flows: Number of flow transformations.
        hidden_dims: Hidden sizes of each autoregressive network; None uses
            ``[latent_dim, latent_dim]``.
        skip_connections: Whether to add skip connections to the networks.
    """

    def __init__(self, model: Callable, *, prefix: str = "auto", init_loc_fn=None,
                 num_flows: int = 3, hidden_dims=None, skip_connections: bool = False): ...


class AutoBNAFNormal(AutoGuide):
    """
    Block Neural Autoregressive Flow with a standard Normal base distribution.

    Args:
        model: Model function.
        prefix: Prefix for parameter names.
        init_loc_fn: Initialization strategy.
        num_flows: Number of flow layers.
        hidden_factors: Hidden-layer width multipliers relative to the latent
            dimension (numpyro default ``[8, 8]``).
        residual: Residual connection type.
            NOTE(review): confirm AutoBNAFNormal exposes this option.
    """

    def __init__(self, model: Callable, *, prefix: str = "auto", init_loc_fn=None,
                 num_flows: int = 1, hidden_factors=(8, 8), residual=None): ...


class AutoSurrogateLikelihoodDAG(AutoGuide):
    """
    Surrogate likelihood guide.

    NOTE(review): numpyro.infer.autoguide provides
    ``AutoSurrogateLikelihoodDAIS(model, surrogate_model, ...)``; no class named
    ``AutoSurrogateLikelihoodDAG`` is known — verify the name and signature
    before relying on this entry.
    """

    def __init__(self, model: Callable, prefix: str = "auto"): ...

453

```

454

455

### Initialization Strategies

456

457

Initialization strategies choose the starting value of each latent site. Pass them as `init_strategy` to MCMC kernels or as `init_loc_fn` to autoguides.

458

459

```python { .api }

460

# Every strategy below is called per trace site. Calling it with ``site=None``
# (or passing the function itself) yields the strategy callable, so both
# ``init_to_median`` and ``init_to_median(num_samples=50)`` can be passed as
# ``init_strategy`` (MCMC kernels) or ``init_loc_fn`` (autoguides).


def init_to_feasible(site: Optional[dict] = None):
    """
    Initialize to an arbitrary feasible point, ignoring distribution parameters.

    Args:
        site: Trace site being initialized; None returns the strategy itself.

    Returns:
        Initial value for ``site``, or the strategy callable when ``site`` is None.
    """


def init_to_mean(site: Optional[dict] = None):
    """Initialize to the prior mean; falls back to ``init_to_median`` when no mean is available."""


def init_to_median(site: Optional[dict] = None, num_samples: int = 15):
    """
    Initialize to the prior median, estimated from ``num_samples`` prior draws.

    Falls back to ``init_to_uniform`` for priors that cannot be sampled.
    """


def init_to_sample(site: Optional[dict] = None):
    """Initialize to a sample drawn from the prior."""


def init_to_uniform(site: Optional[dict] = None, radius: float = 2):
    """
    Initialize to a uniformly random point in ``(-radius, radius)`` of the
    unconstrained domain.

    Args:
        site: Trace site being initialized; None returns the strategy itself.
        radius: Half-width of the interval in unconstrained space.
    """


def init_to_value(site: Optional[dict] = None, values: Optional[dict] = None):
    """
    Initialize sites to user-specified values.

    Args:
        site: Trace site being initialized; None returns the strategy itself.
        values: Mapping from site names to initial (constrained) values; sites
            not listed fall back to ``init_to_uniform``. An empty mapping when
            None.
    """

498

```

499

500

### Utilities

501

502

Utility functions for inference and posterior analysis.

503

504

```python { .api }

505

class Predictive:
    """
    Utility for posterior and prior predictive sampling.

    Args:
        model: Model function.
        posterior_samples: Dictionary of posterior samples (e.g. from
            ``MCMC.get_samples()``); omit for prior predictive sampling.
        guide: Guide to draw latent values from (variational inference). Keyword-only.
        params: Parameters for the guide (and model ``param`` sites).
        num_samples: Number of samples to draw; inferred from
            ``posterior_samples`` when those are given.
        return_sites: Sites to return; by default, sample sites not already
            present in ``posterior_samples``.
        infer_discrete: Whether to sample discrete sites from their posterior,
            conditioned on observations and the other latent values.
        parallel: Whether to predict in parallel with ``jax.vmap``.
        batch_ndims: Number of batch dimensions in ``posterior_samples`` or
            ``params``; when None, 0 if a guide is given and 1 otherwise.
        exclude_deterministic: Whether to ignore deterministic sites found in
            ``posterior_samples``.
    """

    def __init__(self, model: Callable, posterior_samples: Optional[dict] = None, *,
                 guide: Optional[Callable] = None, params: Optional[dict] = None,
                 num_samples: Optional[int] = None, return_sites: Optional[list] = None,
                 infer_discrete: bool = False, parallel: bool = False,
                 batch_ndims: Optional[int] = None, exclude_deterministic: bool = True): ...

    def __call__(self, rng_key: Array, *args, **kwargs) -> dict:
        """
        Generate predictions.

        Args:
            rng_key: Random key for sampling.
            *args: Arguments to pass to the model.
            **kwargs: Keyword arguments to pass to the model.

        Returns:
            Dictionary mapping site names to predicted values, with the sample
            (batch) dimensions leading.
        """


def log_likelihood(model: Callable, posterior_samples: dict, *args,
                   parallel: bool = False, batch_ndims: int = 1, **kwargs) -> dict:
    """
    Compute the log likelihood of the observations for each posterior sample.

    Args:
        model: Model function.
        posterior_samples: Dictionary of posterior samples.
        *args: Arguments to pass to the model.
        parallel: Whether to evaluate samples in parallel with ``jax.vmap``.
        batch_ndims: Number of batch dimensions in ``posterior_samples``.
        **kwargs: Keyword arguments to pass to the model.

    Returns:
        Dictionary of log likelihood values for each observed site.
    """


def render_model(model: Callable, model_args=(), model_kwargs=None, filename=None,
                 render_distributions: bool = False, render_params: bool = False,
                 hide_deterministic: bool = True):
    """
    Render the model's structure as a graphical diagram (requires graphviz).

    Args:
        model: Model function to render.
        model_args: Arguments to pass to the model.
        model_kwargs: Keyword arguments to pass to the model.
        filename: Output filename for the rendered graph.
        render_distributions: Whether to annotate nodes with their distributions.
        render_params: Whether to show ``param`` sites as nodes.
        hide_deterministic: Whether to hide deterministic sites.
            NOTE(review): confirm this option exists in numpyro's render_model.
    """

567

```

568

569

### Reparameterization

570

571

Reparameterization strategies for improving inference efficiency.

572

573

```python { .api }

574

# Located in numpyro.infer.reparam module.
# Reparameterizers are applied with the ``numpyro.handlers.reparam`` handler,
# e.g. ``reparam(model, config={"x": LocScaleReparam(0)})``.


class Reparam:
    """Base class for reparameterizations."""

    def __call__(self, name: str, fn, obs) -> tuple:
        """Return ``(new_fn, value)`` replacing the site's distribution and value."""


class LocScaleReparam(Reparam):
    """
    Generic decentering reparameterization for location-scale distributions.

    Args:
        centered: Degree of centering in ``[0, 1]``: 0 fully decenters the
            distribution, 1 leaves it unchanged, and None learns a per-element
            centering parameter.
        shape_params: Names of additional distribution parameters (e.g. the
            ``df`` of ``StudentT``) copied unchanged to the decentered
            distribution.
    """

    def __init__(self, centered: Optional[float] = None, shape_params=()): ...


class TransformReparam(Reparam):
    """
    Reparameterizer for ``TransformedDistribution`` latent variables: samples
    the base distribution at site ``{name}_base`` and applies the transforms
    deterministically. Takes no constructor arguments; latent sites only.
    """


class NeuTraReparam(Reparam):
    """
    Neural Transport reparameterization.

    Args:
        guide: A trained flattening autoguide (e.g. ``AutoIAFNormal``,
            ``AutoBNAFNormal``) whose transform defines the reparameterization.
        params: Trained parameters for the guide.
    """

    def __init__(self, guide: Callable, params: dict): ...


class CircularReparam(Reparam):
    """Reparameterization for circular (``VonMises``) latent variables."""


class ProjectedNormalReparam(Reparam):
    """Reparameterization for ``ProjectedNormal`` latent variables."""


class ImplicitReparam(Reparam):
    """
    Implicit reparameterization.

    NOTE(review): not confirmed to exist in numpyro.infer.reparam — verify.
    """


class SplitReparam(Reparam):
    """
    Split a multivariate site into several sites along ``dim``.

    NOTE(review): not confirmed to exist in numpyro.infer.reparam — verify.
    """

    def __init__(self, sections: list, dim: int = -1): ...


class SymmetricSplitReparam(Reparam):
    """
    Symmetric split reparameterization.

    NOTE(review): not confirmed to exist in numpyro.infer.reparam — verify.
    """

    def __init__(self, sections: list, dim: int = -1): ...

625

```

626

627

## Types

628

629

```python { .api }

630

from typing import Optional, Union, Callable, Dict, Any, Tuple

631

from jax import Array

632

import jax.numpy as jnp

633

634

ArrayLike = Union[Array, jnp.ndarray, float, int]
# Kernels accepted by MCMC. Named SupportedKernel (not MCMCKernel) so that it is
# not silently shadowed by the MCMCKernel base class defined below.
SupportedKernel = Union[HMC, NUTS, SA, BarkerMH, HMCGibbs, DiscreteHMCGibbs, HMCECS, MixedHMC, ESS, AIES]
Optimizer = Any  # numpyro.optim optimizer or optax GradientTransformation
LossFunction = Union[ELBO, Trace_ELBO, TraceEnum_ELBO, TraceGraph_ELBO, TraceMeanField_ELBO, RenyiELBO]
# Initialization strategy: called with a trace site dict, returns the site's
# initial value (may return None for sites it does not handle).
InitFunction = Callable[[dict], Optional[Array]]


class SVIState:
    """State object for SVI optimization."""
    optim_state: Any  # Optimizer state (holds the unconstrained parameters)
    mutable_state: Optional[dict]  # Values of mutable sites, if any
    rng_key: Array


class SVIRunResult:
    """Result from SVI.run()."""
    params: dict  # Final parameter values
    state: SVIState  # Final state, usable to resume optimization
    losses: Array  # Loss value at each optimization step


class MCMCState:
    """
    Illustrative kernel state; concrete kernels define their own state records
    (e.g. HMC's state also tracks energy, acceptance and divergence info).
    """
    z: dict  # Current parameter values
    potential_energy: float
    z_grad: dict  # Current gradients
    adapt_state: Any  # Adaptation state
    rng_key: Array


# Kernel interfaces
class MCMCKernel:
    """Base interface for MCMC kernels."""

    def init(self, rng_key: Array, num_warmup: int, init_params: dict,
             model_args: tuple, model_kwargs: dict) -> MCMCState:
        """Initialize the kernel and return its initial state."""

    def sample(self, state: MCMCState, model_args: tuple, model_kwargs: dict) -> MCMCState:
        """Run one transition from ``state`` and return the new state."""

    def postprocess_fn(self, args: tuple, kwargs: dict) -> Callable:
        """Return a callable mapping raw (unconstrained) samples to constrained values."""

666

```