# Handlers

NumPyro provides Pyro-style effect handlers: objects that intercept and modify the execution of probabilistic programs. They can be used either as function wrappers or as context managers. With them you can condition a model on observed data, substitute fixed values, apply reparameterizations, and control inference behavior without editing the model code.

## Capabilities

### Core Handler Infrastructure

Base classes and utilities for the effect handling system.

```python { .api }
class Messenger:
    """
    Base class for effect handlers.

    A Messenger either wraps a function or is used as a context manager.
    While active, it intercepts messages at primitive sites and can modify
    them, which enables conditioning, substitution, masking, and other
    transformations.
    """
    def __init__(self, fn: Optional[Callable] = None): ...

    def __enter__(self): ...
    def __exit__(self, exc_type, exc_value, traceback): ...

    def process_message(self, msg: dict) -> None:
        """
        Process a message at a primitive site.

        Args:
            msg: Message dictionary containing site information
        """

    def __call__(self, *args, **kwargs):
        """Call the wrapped function with the handler active."""

def default_process_message(msg: dict) -> None:
    """
    Default message processing for primitive sites: if no handler has set
    msg["value"], compute it by calling msg["fn"] (drawing a sample at
    sample sites).
    """

def apply_stack(msg: dict) -> dict:
    """
    Pass a message through the active handlers, innermost first, stopping
    early if a handler sets msg["stop"]; then apply default processing and
    return the message.
    """
```

### Tracing and Replay

Handlers for recording and replaying model execution.

```python { .api }
def trace(fn: Callable) -> Callable:
    """
    Record inputs and outputs at all primitive sites during model execution.

    Args:
        fn: Function to trace

    Returns:
        Traced function. Calling it returns the model's own output; call
        its get_trace(*args, **kwargs) method to run the model and obtain
        the execution trace, an ordered dict mapping site names to messages.

    Usage:
        traced_model = trace(model)
        trace_dict = traced_model.get_trace(*args, **kwargs)
    """

def replay(fn: Callable, trace: dict) -> Callable:
    """
    Replay a function with a recorded trace.

    Args:
        fn: Function to replay
        trace: Execution trace from a previous run

    Returns:
        Function whose sample sites take their values from the given trace
        instead of drawing new samples

    Usage:
        replayed_model = replay(model, trace_dict)
        result = replayed_model(*args, **kwargs)
    """

class TraceHandler(Messenger):
    """Handler for recording execution traces."""
    def __init__(self, fn: Optional[Callable] = None): ...
    def get_trace(self, *args, **kwargs) -> dict: ...

class ReplayHandler(Messenger):
    """Handler for replaying with stored traces."""
    def __init__(self, trace: dict, fn: Optional[Callable] = None): ...
```

### Conditioning and Substitution

Handlers for conditioning models on observed data and substituting values.

```python { .api }
def condition(fn: Callable, data: dict) -> Callable:
    """
    Condition a probabilistic model on observed data.

    The named sample sites take the given values and are marked as
    observed, so they contribute to the log density as likelihood terms.

    Args:
        fn: Model function to condition
        data: Dictionary mapping site names to observed values

    Returns:
        Conditioned model function

    Usage:
        conditioned_model = condition(model, {"obs": observed_data})
        result = conditioned_model(*args, **kwargs)
    """

def substitute(fn: Callable, data: dict) -> Callable:
    """
    Substitute fixed values at sample and param sites instead of drawing
    them from the site's distribution. Unlike condition, substituted
    sample sites are not marked as observed.

    Args:
        fn: Function to modify
        data: Dictionary mapping site names to substitute values

    Returns:
        Function with substituted values

    Usage:
        substituted_model = substitute(model, {"param1": fixed_value})
        result = substituted_model(*args, **kwargs)
    """

class ConditionHandler(Messenger):
    """Handler for conditioning on observed data."""
    def __init__(self, data: dict, fn: Optional[Callable] = None): ...

class SubstituteHandler(Messenger):
    """Handler for substituting values at sample sites."""
    def __init__(self, data: dict, fn: Optional[Callable] = None): ...
```

### Random Seed Control

Handlers for controlling random number generation.

```python { .api }
def seed(fn: Callable, rng_seed: Union[int, Array]) -> Callable:
    """
    Provide a random seed context for reproducible sampling.

    Args:
        fn: Function to seed
        rng_seed: Random seed, either an integer or a JAX PRNG key

    Returns:
        Function with seeded random number generation

    Usage:
        seeded_model = seed(model, rng_seed=42)
        result = seeded_model(*args, **kwargs)
    """

class SeedHandler(Messenger):
    """Handler for providing random seed context."""
    def __init__(self, rng_seed: Union[int, Array], fn: Optional[Callable] = None): ...
```

### Blocking and Masking

Handlers for selectively blocking effects or masking computations.

```python { .api }
def block(fn: Callable, hide_fn: Optional[Callable] = None,
          hide: Optional[list] = None, expose_types: Optional[list] = None) -> Callable:
    """
    Hide selected primitive sites from handlers further up the stack.
    If none of hide_fn, hide, or expose_types is given, all sites are hidden.

    Args:
        fn: Function to modify
        hide_fn: Function that takes a site message and returns True if the
            site should be hidden
        hide: List of site names to hide
        expose_types: List of site types to expose (e.g. ["param"]); all
            other sites are hidden

    Returns:
        Function with blocked effects

    Usage:
        # Hide every site except "obs"
        blocked_model = block(model, hide_fn=lambda site: site["name"] != "obs")
        result = blocked_model(*args, **kwargs)
    """

def mask(fn: Callable, mask: ArrayLike) -> Callable:
    """
    Mask the log-probability contributions of sample sites.

    Args:
        fn: Function to mask
        mask: Boolean scalar or array, broadcastable to each site's batch
            shape; entries that are False contribute zero log probability,
            entries that are True are kept

    Returns:
        Function with masked log probabilities

    Usage:
        masked_model = mask(model, mask_array)
        result = masked_model(*args, **kwargs)
    """

class BlockHandler(Messenger):
    """Handler for blocking effects at specified sites."""
    def __init__(self, hide_fn: Optional[Callable] = None,
                 hide: Optional[list] = None, expose_types: Optional[list] = None,
                 fn: Optional[Callable] = None): ...

class MaskHandler(Messenger):
    """Handler for masking effects based on conditions."""
    def __init__(self, mask: ArrayLike, fn: Optional[Callable] = None): ...
```

### Scaling and Transformation

Handlers for scaling log probabilities and for prefixing (scoping) site names.

```python { .api }
def scale(fn: Callable, scale: float) -> Callable:
    """
    Scale log probabilities by a constant factor.

    Args:
        fn: Function to scale
        scale: Positive scaling factor for log probabilities

    Returns:
        Function with scaled log probabilities

    Usage:
        scaled_model = scale(model, scale=0.1)  # Tempered model
        result = scaled_model(*args, **kwargs)
    """

def scope(fn: Callable, prefix: str) -> Callable:
    """
    Add a scope prefix to all site names within the function.

    Args:
        fn: Function to scope
        prefix: Prefix to add to site names, joined with "/" by default

    Returns:
        Function with scoped site names

    Usage:
        scoped_model = scope(model, prefix="component1")  # "x" -> "component1/x"
        result = scoped_model(*args, **kwargs)
    """

class ScaleHandler(Messenger):
    """Handler for scaling log probabilities."""
    def __init__(self, scale: float, fn: Optional[Callable] = None): ...

class ScopeHandler(Messenger):
    """Handler for adding scope prefixes to site names."""
    def __init__(self, prefix: str, fn: Optional[Callable] = None): ...
```

### Parameter and Distribution Manipulation

Handlers for manipulating parameters and distributions.

```python { .api }
def lift(fn: Callable, prior: dict) -> Callable:
    """
    Lift param sites to sample sites with specified priors.

    Args:
        fn: Function containing param sites to lift
        prior: Dictionary mapping param names to prior distributions (a
            single distribution, applied to every param site, is also
            accepted); param sites without a prior are left unchanged

    Returns:
        Function with parameters converted to sample sites

    Usage:
        lifted_model = lift(model, {"weight": dist.Normal(0, 1)})
        result = lifted_model(*args, **kwargs)
    """

def reparam(fn: Callable, config: dict) -> Callable:
    """
    Apply reparameterizations to specified sites.

    Args:
        fn: Function to reparameterize
        config: Dictionary mapping site names to reparameterization
            strategies (Reparam instances), or a callable that takes a site
            and returns a Reparam or None

    Returns:
        Function with applied reparameterizations

    Usage:
        from numpyro.infer.reparam import LocScaleReparam
        # centered=0 gives a fully non-centered parameterization
        reparamed_model = reparam(model, {"x": LocScaleReparam(centered=0)})
        result = reparamed_model(*args, **kwargs)
    """

class LiftHandler(Messenger):
    """Handler for lifting parameters to sample sites."""
    def __init__(self, prior: dict, fn: Optional[Callable] = None): ...

class ReparamHandler(Messenger):
    """Handler for applying reparameterizations."""
    def __init__(self, config: dict, fn: Optional[Callable] = None): ...
```

### Enumeration and Collapse

Handlers for marginalizing out latent variables.

```python { .api }
def collapse(fn: Callable) -> Callable:
    """
    EXPERIMENTAL. Collapse (marginalize out) the sample sites in the context
    by sampling them lazily and exploiting conjugacy relations; fails if no
    conjugacy relation is known. Requires funsor to be installed, and code
    that uses the values of collapsed sites must accept funsors rather
    than arrays.

    Args:
        fn: Function whose sample sites should be collapsed

    Returns:
        Function with collapsed latent variables

    Usage:
        collapsed_model = collapse(model)
        result = collapsed_model(*args, **kwargs)
    """

class CollapseHandler(Messenger):
    """Handler for collapsing latent variables via conjugacy."""
    def __init__(self, fn: Optional[Callable] = None): ...
```

### Inference Configuration

Handlers for configuring inference behavior.

```python { .api }
def infer_config(fn: Callable, config_fn: Callable) -> Callable:
    """
    Configure inference behavior at sample sites.

    Args:
        fn: Function to configure
        config_fn: Function that takes a site message and returns a dict,
            which is merged into the site's "infer" field

    Returns:
        Function with inference configuration applied

    Usage:
        def config_fn(site):
            if site["name"] == "x":
                return {"is_auxiliary": True}
            return {}

        configured_model = infer_config(model, config_fn)
        result = configured_model(*args, **kwargs)
    """

class InferConfigHandler(Messenger):
    """Handler for setting inference configuration."""
    def __init__(self, config_fn: Callable, fn: Optional[Callable] = None): ...
```

### Causal Intervention

Handlers for causal modeling and intervention.

```python { .api }
def do(fn: Callable, data: dict) -> Callable:
    """
    Apply causal interventions (do-operator) to specified variables.

    Args:
        fn: Model function to intervene on
        data: Dictionary mapping variable names to intervention values

    Returns:
        Function with causal interventions applied

    Usage:
        # Intervene by setting X = 5
        intervened_model = do(causal_model, {"X": 5})
        result = intervened_model(*args, **kwargs)
    """

class DoHandler(Messenger):
    """Handler for causal interventions."""
    def __init__(self, data: dict, fn: Optional[Callable] = None): ...
```

### Handler Composition and Utilities

Utilities for composing and managing multiple handlers.

```python { .api }
def compose(*handlers) -> Callable:
    """
    Compose multiple handlers into a single handler.

    Args:
        *handlers: Handler functions to compose

    Returns:
        Composed handler function

    Usage:
        composed = compose(
            seed(rng_seed=42),
            substitute(data={"param": value}),
            condition(data={"obs": data})
        )
        result = composed(model)(*args, **kwargs)
    """

def enable_validation(is_validate: bool = True):
    """
    Globally enable or disable distribution validation. This sets a global
    flag; it is a plain function, not a context manager.

    Args:
        is_validate: Whether to enable validation

    Usage:
        enable_validation(True)
        result = model(*args, **kwargs)
    """

class DynamicHandler(Messenger):
    """Handler with dynamic behavior based on runtime conditions."""
    def __init__(self, handler_fn: Callable, fn: Optional[Callable] = None): ...

def get_mask() -> Optional[ArrayLike]:
    """Get the current mask from the handler stack."""

def get_dependencies() -> dict:
    """Get dependency information from the current trace."""
```

### Advanced Handler Patterns

Advanced patterns for specialized use cases.

```python { .api }
def escape(fn: Callable, escape_fn: Callable) -> Callable:
    """
    Escape from the current handler context for specified operations.

    Args:
        fn: Function to modify
        escape_fn: Function to determine when to escape

    Returns:
        Function that can escape handler effects
    """

def plate_messenger(name: str, size: int, subsample_size: Optional[int] = None,
                    dim: Optional[int] = None) -> Messenger:
    """
    Create a plate messenger for conditional independence.

    Args:
        name: Plate name
        size: Plate size
        subsample_size: Subsampling size
        dim: Dimension for broadcasting

    Returns:
        Plate messenger for conditional independence
    """

class CustomHandler(Messenger):
    """
    Template for creating custom effect handlers.

    Override process_message() to implement custom behavior:

        class MyHandler(CustomHandler):
            def process_message(self, msg):
                if msg["type"] == "sample":
                    # Custom logic for sample sites
                    pass
                elif msg["type"] == "param":
                    # Custom logic for param sites
                    pass
    """
    def process_message(self, msg: dict) -> None: ...
```

## Usage Examples

```python
import numpyro
import numpyro.distributions as dist
from numpyro.handlers import condition, substitute, seed, trace

def model():
    x = numpyro.sample("x", dist.Normal(0, 1))
    y = numpyro.sample("y", dist.Normal(x, 1))
    return y

# Condition on observed y
observed_data = {"y": 2.0}
conditioned_model = condition(model, observed_data)

# Substitute a fixed value for x
substituted_model = substitute(model, {"x": 1.5})

# Set random seed for reproducibility
seeded_model = seed(model, rng_seed=42)

# Trace execution to see all sites; get_trace() returns the trace itself,
# while calling traced_model() would return the model's output
traced_model = trace(seeded_model)
trace_dict = traced_model.get_trace()

# Compose multiple handlers by nesting them; the innermost handler
# (condition) processes each site first
composed_model = seed(
    substitute(condition(model, {"y": 2.0}), {"x": 1.0}),
    rng_seed=42,
)

result = composed_model()  # 2.0, the conditioned value of y
```

## Types

```python { .api }
from typing import Optional, Union, Callable, Dict, Any
from jax import Array
import jax.numpy as jnp

ArrayLike = Union[Array, jnp.ndarray, float, int]
HandlerFunction = Callable[[Callable], Callable]

class Message:
    """
    Message dictionary structure for effect handlers.

    Common fields:
    - name: Site name
    - type: Message type ("sample", "param", "deterministic", etc.)
    - fn: Distribution or function at the site
    - args: Arguments to the function
    - kwargs: Keyword arguments to the function
    - value: Sampled or computed value
    - is_observed: Whether the site is observed
    - infer: Inference configuration
    - scale: Probability scale factor
    """
    name: str
    type: str
    fn: Any
    args: tuple
    kwargs: dict
    value: Any
    is_observed: bool
    infer: dict
    scale: Optional[float]
    mask: Optional[ArrayLike]
    cond_indep_stack: list
    done: bool
    stop: bool
    continuation: Optional[Callable]

class Site:
    """Information about a primitive site in the model."""
    name: str
    type: str
    fn: Any
    args: tuple
    kwargs: dict
    value: Any

class Trace(dict):
    """
    Execution trace containing all primitive sites.

    Keys are site names, values are Site objects.
    """
    def log_prob_sum(self) -> float: ...
    def copy(self) -> 'Trace': ...
    def nodes(self) -> dict: ...
```