
# Handlers

NumPyro provides Pyro-style effect handlers that act as context managers to intercept and modify the execution of probabilistic programs. These handlers enable powerful model manipulation capabilities like conditioning on observed data, substituting values, applying transformations, and controlling inference behavior.

## Capabilities

### Core Handler Infrastructure

Base classes and utilities for the effect handling system.

```python { .api }
class Messenger:
    """
    Base class for effect handlers with context manager protocol.

    Handlers intercept messages at primitive sites and can modify their behavior.
    This enables conditioning, substitution, masking, and other transformations.
    """
    def __init__(self, fn: Optional[Callable] = None): ...

    def __enter__(self): ...
    def __exit__(self, exc_type, exc_value, traceback): ...

    def process_message(self, msg: dict) -> None:
        """
        Process a message at a primitive site.

        Args:
            msg: Message dictionary containing site information
        """

    def __call__(self, *args, **kwargs):
        """Call the wrapped function with handler active."""

def default_process_message(msg: dict) -> None:
    """Default message processing for primitive sites."""

def apply_stack(msg: dict) -> dict:
    """Apply the current effect handler stack to a message."""
```

### Tracing and Replay

Handlers for recording and replaying model execution.

```python { .api }
def trace(fn: Callable) -> Callable:
    """
    Record inputs and outputs at all primitive sites during model execution.

    Args:
        fn: Function to trace

    Returns:
        Handler wrapping fn; its get_trace() method runs fn and returns the execution trace

    Usage:
        traced_model = trace(model)
        trace_dict = traced_model.get_trace(*args, **kwargs)
    """

def replay(fn: Callable, trace: dict) -> Callable:
    """
    Replay a function with a recorded trace.

    Args:
        fn: Function to replay
        trace: Execution trace from previous run

    Returns:
        Function that replays with given trace

    Usage:
        replayed_model = replay(model, trace_dict)
        result = replayed_model(*args, **kwargs)
    """

class TraceHandler(Messenger):
    """Handler for recording execution traces."""
    def __init__(self, fn: Optional[Callable] = None): ...
    def get_trace(self, *args, **kwargs) -> dict: ...

class ReplayHandler(Messenger):
    """Handler for replaying with stored traces."""
    def __init__(self, trace: dict, fn: Optional[Callable] = None): ...
```

### Conditioning and Substitution

Handlers for conditioning models on observed data and substituting values.

```python { .api }
def condition(fn: Callable, data: dict) -> Callable:
    """
    Condition a probabilistic model on observed data.

    Args:
        fn: Model function to condition
        data: Dictionary mapping site names to observed values

    Returns:
        Conditioned model function

    Usage:
        conditioned_model = condition(model, {"obs": observed_data})
        result = conditioned_model(*args, **kwargs)
    """

def substitute(fn: Callable, data: dict) -> Callable:
    """
    Substitute values at sample sites, bypassing distributions.

    Args:
        fn: Function to modify
        data: Dictionary mapping site names to substitute values

    Returns:
        Function with substituted values

    Usage:
        substituted_model = substitute(model, {"param1": fixed_value})
        result = substituted_model(*args, **kwargs)
    """

class ConditionHandler(Messenger):
    """Handler for conditioning on observed data."""
    def __init__(self, data: dict, fn: Optional[Callable] = None): ...

class SubstituteHandler(Messenger):
    """Handler for substituting values at sample sites."""
    def __init__(self, data: dict, fn: Optional[Callable] = None): ...
```

### Random Seed Control

Handlers for controlling random number generation.

```python { .api }
def seed(fn: Callable, rng_seed: int) -> Callable:
    """
    Provide a random seed context for reproducible sampling.

    Args:
        fn: Function to seed
        rng_seed: Random seed value

    Returns:
        Function with seeded random number generation

    Usage:
        seeded_model = seed(model, rng_seed=42)
        result = seeded_model(*args, **kwargs)
    """

class SeedHandler(Messenger):
    """Handler for providing random seed context."""
    def __init__(self, rng_seed: int, fn: Optional[Callable] = None): ...
```

### Blocking and Masking

Handlers for selectively blocking effects or masking computations.

```python { .api }
def block(fn: Callable, hide_fn: Optional[Callable] = None,
          expose_fn: Optional[Callable] = None, hide_all: bool = True) -> Callable:
    """
    Block effects at specified sites based on filtering functions.

    Args:
        fn: Function to modify
        hide_fn: Function to determine which sites to hide
        expose_fn: Function to determine which sites to expose
        hide_all: Whether to hide all sites by default

    Returns:
        Function with blocked effects

    Usage:
        # Block all sample sites except "obs"
        blocked_model = block(model, expose_fn=lambda msg: msg["name"] == "obs")
        result = blocked_model(*args, **kwargs)
    """

def mask(fn: Callable, mask: ArrayLike) -> Callable:
    """
    Mask effects based on boolean conditions.

    Args:
        fn: Function to mask
        mask: Boolean array; log-probabilities are kept where True and zeroed where False

    Returns:
        Function with masked effects

    Usage:
        masked_model = mask(model, mask_array)
        result = masked_model(*args, **kwargs)
    """

class BlockHandler(Messenger):
    """Handler for blocking effects at specified sites."""
    def __init__(self, hide_fn: Optional[Callable] = None,
                 expose_fn: Optional[Callable] = None, hide_all: bool = True,
                 fn: Optional[Callable] = None): ...

class MaskHandler(Messenger):
    """Handler for masking effects based on conditions."""
    def __init__(self, mask: ArrayLike, fn: Optional[Callable] = None): ...
```

### Scaling and Transformation

Handlers for scaling log probabilities and applying transformations.

```python { .api }
def scale(fn: Callable, scale: float) -> Callable:
    """
    Scale log probabilities by a constant factor.

    Args:
        fn: Function to scale
        scale: Scaling factor for log probabilities

    Returns:
        Function with scaled log probabilities

    Usage:
        scaled_model = scale(model, scale=0.1)  # Tempered model
        result = scaled_model(*args, **kwargs)
    """

def scope(fn: Callable, prefix: str) -> Callable:
    """
    Add a scope prefix to all site names within the function.

    Args:
        fn: Function to scope
        prefix: Prefix to add to site names

    Returns:
        Function with scoped site names

    Usage:
        scoped_model = scope(model, prefix="component1")
        result = scoped_model(*args, **kwargs)
    """

class ScaleHandler(Messenger):
    """Handler for scaling log probabilities."""
    def __init__(self, scale: float, fn: Optional[Callable] = None): ...

class ScopeHandler(Messenger):
    """Handler for adding scope prefixes to site names."""
    def __init__(self, prefix: str, fn: Optional[Callable] = None): ...
```

### Parameter and Distribution Manipulation

Handlers for manipulating parameters and distributions.

```python { .api }
def lift(fn: Callable, prior: dict) -> Callable:
    """
    Lift parameters to sample sites with specified priors.

    Args:
        fn: Function containing param sites to lift
        prior: Dictionary mapping parameter names to prior distributions

    Returns:
        Function with parameters converted to sample sites

    Usage:
        lifted_model = lift(model, {"weight": dist.Normal(0, 1)})
        result = lifted_model(*args, **kwargs)
    """

def reparam(fn: Callable, config: dict) -> Callable:
    """
    Apply reparameterizations to specified sites.

    Args:
        fn: Function to reparameterize
        config: Dictionary mapping site names to reparameterization strategies

    Returns:
        Function with applied reparameterizations

    Usage:
        from numpyro.infer.reparam import LocScaleReparam
        reparamed_model = reparam(model, {"x": LocScaleReparam(centered=0)})
        result = reparamed_model(*args, **kwargs)
    """

class LiftHandler(Messenger):
    """Handler for lifting parameters to sample sites."""
    def __init__(self, prior: dict, fn: Optional[Callable] = None): ...

class ReparamHandler(Messenger):
    """Handler for applying reparameterizations."""
    def __init__(self, config: dict, fn: Optional[Callable] = None): ...
```

### Enumeration and Collapse

Handlers for discrete variable enumeration and marginalization.

```python { .api }
def collapse(fn: Callable, sites: Optional[list] = None) -> Callable:
    """
    Collapse (marginalize out) discrete enumeration at specified sites.

    Args:
        fn: Function with enumerated discrete variables
        sites: List of site names to collapse (None for all)

    Returns:
        Function with collapsed discrete variables

    Usage:
        collapsed_model = collapse(enumerated_model, sites=["discrete_var"])
        result = collapsed_model(*args, **kwargs)
    """

class CollapseHandler(Messenger):
    """Handler for collapsing discrete enumeration."""
    def __init__(self, sites: Optional[list] = None, fn: Optional[Callable] = None): ...
```

### Inference Configuration

Handlers for configuring inference behavior.

```python { .api }
def infer_config(fn: Callable, config_fn: Callable) -> Callable:
    """
    Configure inference behavior at sample sites.

    Args:
        fn: Function to configure
        config_fn: Function that takes a site and returns inference config

    Returns:
        Function with inference configuration applied

    Usage:
        def config_fn(site):
            if site["name"] == "x":
                return {"is_auxiliary": True}
            return {}

        configured_model = infer_config(model, config_fn)
        result = configured_model(*args, **kwargs)
    """

class InferConfigHandler(Messenger):
    """Handler for setting inference configuration."""
    def __init__(self, config_fn: Callable, fn: Optional[Callable] = None): ...
```

### Causal Intervention

Handlers for causal modeling and intervention.

```python { .api }
def do(fn: Callable, data: dict) -> Callable:
    """
    Apply causal interventions (do-operator) to specified variables.

    Args:
        fn: Model function to intervene on
        data: Dictionary mapping variable names to intervention values

    Returns:
        Function with causal interventions applied

    Usage:
        # Intervene by setting X = 5
        intervened_model = do(causal_model, {"X": 5})
        result = intervened_model(*args, **kwargs)
    """

class DoHandler(Messenger):
    """Handler for causal interventions."""
    def __init__(self, data: dict, fn: Optional[Callable] = None): ...
```

### Handler Composition and Utilities

Utilities for composing and managing multiple handlers.

```python { .api }
def compose(*handlers) -> Callable:
    """
    Compose multiple handlers into a single handler.

    Args:
        *handlers: Handler functions to compose

    Returns:
        Composed handler function

    Usage:
        composed = compose(
            seed(rng_seed=42),
            substitute({"param": value}),
            condition({"obs": data})
        )
        result = composed(model)(*args, **kwargs)
    """

def enable_validation(is_validate: bool = True):
    """
    Context manager to enable/disable distribution validation.

    Args:
        is_validate: Whether to enable validation

    Usage:
        with enable_validation(True):
            result = model(*args, **kwargs)
    """

class DynamicHandler(Messenger):
    """Handler with dynamic behavior based on runtime conditions."""
    def __init__(self, handler_fn: Callable, fn: Optional[Callable] = None): ...

def get_mask() -> Optional[ArrayLike]:
    """Get the current mask from the handler stack."""

def get_dependencies() -> dict:
    """Get dependency information from the current trace."""
```

### Advanced Handler Patterns

Advanced patterns for specialized use cases.

```python { .api }
def escape(fn: Callable, escape_fn: Callable) -> Callable:
    """
    Escape from the current handler context for specified operations.

    Args:
        fn: Function to modify
        escape_fn: Function to determine when to escape

    Returns:
        Function that can escape handler effects
    """

def plate_messenger(name: str, size: int, subsample_size: Optional[int] = None,
                    dim: Optional[int] = None) -> Messenger:
    """
    Create a plate messenger for conditional independence.

    Args:
        name: Plate name
        size: Plate size
        subsample_size: Subsampling size
        dim: Dimension for broadcasting

    Returns:
        Plate messenger for conditional independence
    """

class CustomHandler(Messenger):
    """
    Template for creating custom effect handlers.

    Override process_message() to implement custom behavior:

    class MyHandler(CustomHandler):
        def process_message(self, msg):
            if msg["type"] == "sample":
                # Custom logic for sample sites
                pass
            elif msg["type"] == "param":
                # Custom logic for param sites
                pass
    """
    def process_message(self, msg: dict) -> None: ...
```

## Usage Examples

```python
# Conditioning on observed data
import numpyro
import numpyro.distributions as dist
from numpyro.handlers import condition, substitute, seed, trace

def model():
    x = numpyro.sample("x", dist.Normal(0, 1))
    y = numpyro.sample("y", dist.Normal(x, 1))
    return y

# Condition on observed y
observed_data = {"y": 2.0}
conditioned_model = condition(model, observed_data)

# Substitute a fixed value for x
substituted_model = substitute(model, {"x": 1.5})

# Set random seed for reproducibility
seeded_model = seed(model, rng_seed=42)

# Trace execution to see all sites; get_trace() runs the model and
# returns the recorded sites (calling the handler returns only y)
traced_model = trace(seeded_model)
trace_dict = traced_model.get_trace()

# Combine multiple handlers by nesting them
composed_model = seed(
    substitute(condition(model, {"y": 2.0}), {"x": 1.0}),
    rng_seed=42,
)

result = composed_model()
```

## Types

```python { .api }
from typing import Optional, Union, Callable, Dict, Any
from jax import Array
import jax.numpy as jnp

ArrayLike = Union[Array, jnp.ndarray, float, int]
HandlerFunction = Callable[[Callable], Callable]

class Message:
    """
    Message dictionary structure for effect handlers.

    Common fields:
    - name: Site name
    - type: Message type ("sample", "param", "deterministic", etc.)
    - fn: Distribution or function at the site
    - args: Arguments to the function
    - kwargs: Keyword arguments to the function
    - value: Sampled or computed value
    - is_observed: Whether the site is observed
    - infer: Inference configuration
    - scale: Probability scale factor
    """
    name: str
    type: str
    fn: Any
    args: tuple
    kwargs: dict
    value: Any
    is_observed: bool
    infer: dict
    scale: Optional[float]
    mask: Optional[ArrayLike]
    cond_indep_stack: list
    done: bool
    stop: bool
    continuation: Optional[Callable]

class Site:
    """Information about a primitive site in the model."""
    name: str
    type: str
    fn: Any
    args: tuple
    kwargs: dict
    value: Any

class Trace(dict):
    """
    Execution trace containing all primitive sites.

    Keys are site names, values are Site objects.
    """
    def log_prob_sum(self) -> float: ...
    def copy(self) -> 'Trace': ...
    def nodes(self) -> dict: ...
```