# Primitives

NumPyro's primitive functions are the core building blocks of probabilistic models. They let a model sample from distributions (or condition on observed data), declare learnable parameters, express conditional independence, and record deterministic computations. Every primitive is routed through NumPyro's effect-handler stack, so inference algorithms can intercept it. All primitives work with JAX transformations such as automatic differentiation.

## Capabilities

### Core Sampling Primitives

The fundamental primitives for probabilistic programming: drawing random variables and declaring parameters.

```python { .api }
def sample(name: str, fn: Distribution, obs: Optional[ArrayLike] = None,
           rng_key: Optional[Array] = None, sample_shape: tuple = (),
           infer: Optional[dict] = None,
           obs_mask: Optional[ArrayLike] = None) -> ArrayLike:
    """
    Sample a value from a distribution or condition on observed data.

    Args:
        name: Name of the sample site (must be unique within the model)
        fn: Probability distribution to sample from
        obs: Observed value to condition on (optional)
        rng_key: Random key for sampling (optional). When None, the key is
            normally provided by an enclosing ``numpyro.handlers.seed`` handler.
        sample_shape: Additional leading shape of the draw, e.g. ``(10,)``
            for ten independent samples
        infer: Dictionary of inference hints and configuration
        obs_mask: Boolean mask for partially observed data. Entries where the
            mask is True are conditioned on ``obs``; the remaining entries
            are imputed by sampling.

    Returns:
        The sampled value, or the observed value if ``obs`` is provided

    Usage:
        # Sample from prior
        x = numpyro.sample("x", dist.Normal(0, 1))

        # Condition on observed data
        y = numpyro.sample("y", dist.Normal(x, 0.5), obs=observed_y)

        # Sample multiple values
        batch_samples = numpyro.sample("batch", dist.Normal(0, 1), sample_shape=(10,))

        # Configure inference behavior
        z = numpyro.sample("z", dist.Normal(0, 1), infer={"is_auxiliary": True})
    """

def param(name: str, init_value: Optional[Union[ArrayLike, Callable]] = None,
          constraint: Constraint = constraints.real, event_dim: Optional[int] = None,
          **kwargs) -> Optional[ArrayLike]:
    """
    Declare an optimizable parameter in the model.

    Args:
        name: Parameter name (must be unique)
        init_value: Initial value, or a callable that takes a JAX PRNG key and
            returns the initial value. A callable initializer receives its key
            from an enclosing ``numpyro.handlers.seed`` handler, so it must be
            used under one.
        constraint: Parameter constraint (e.g., constraints.positive)
        event_dim: Number of rightmost dimensions treated as event shape
        **kwargs: Additional keyword arguments recorded on the param site

    Returns:
        Current value of the parameter. This is ``init_value`` unless a handler
        (e.g. during SVI) substitutes a stored value. It is None only when no
        ``init_value`` is given and no value is substituted.

    Usage:
        # Simple parameter with constraint
        sigma = numpyro.param("sigma", 1.0, constraint=constraints.positive)

        # Parameter with initialization function (needs a seed handler)
        weights = numpyro.param("weights",
                                lambda key: random.normal(key, (10, 5)),
                                constraint=constraints.real)

        # Simplex-constrained parameter
        probs = numpyro.param("probs", jnp.ones(3) / 3, constraint=constraints.simplex)
    """
```

### Deterministic Sites

Primitives for recording deterministic values in the trace and for adding custom terms to the model's log density.

```python { .api }
def deterministic(name: str, value: ArrayLike) -> ArrayLike:
    """
    Record a deterministic value as a named site in the model trace.

    The value appears in traces and in the output of tools such as
    ``numpyro.infer.Predictive``. The computation itself is not changed.

    Args:
        name: Name of the deterministic site
        value: Computed deterministic value

    Returns:
        The input value (unchanged)

    Usage:
        x = numpyro.sample("x", dist.Normal(0, 1))
        y = numpyro.sample("y", dist.Normal(0, 1))

        # Mark sum as deterministic for tracking
        sum_xy = numpyro.deterministic("sum", x + y)

        # Can be used for derived quantities
        mean_xy = numpyro.deterministic("mean", (x + y) / 2)
    """

def factor(name: str, log_factor: ArrayLike) -> None:
    """
    Add a log probability factor to the model's joint density.

    Args:
        name: Name of the factor site
        log_factor: Log probability value to add to the joint density

    Returns:
        None

    Usage:
        # Add log-likelihood term directly
        numpyro.factor("custom_loglik", -0.5 * jnp.sum((y - mu)**2) / sigma**2)

        # Add constraint violation penalty
        numpyro.factor("penalty", -1e6 * jnp.where(x < 0, 1.0, 0.0))

        # Add custom prior term
        numpyro.factor("custom_prior", dist.Gamma(2, 1).log_prob(sigma))
    """
```

### Conditional Independence

Primitives for declaring conditional independence and for subsampling data.

```python { .api }
class plate:
    """
    Context manager for conditionally independent variables with automatic broadcasting.

    Sample sites inside the context are broadcast along the plate dimension.
    When ``subsample_size`` is given, inference algorithms may rescale the log
    density to account for working on a minibatch.

    Args:
        name: Plate name (must be unique)
        size: Size of the independence dimension
        subsample_size: Size of the minibatch to subsample (optional)
        dim: Dimension used for broadcasting; must be negative (optional;
            the rightmost available dimension is used when omitted)
        subsample: Indices for subsampling (optional).
            NOTE(review): upstream ``numpyro.plate`` accepts only ``name``,
            ``size``, ``subsample_size`` and ``dim``. Confirm this argument
            is supported before relying on it.

    Usage:
        # Basic conditional independence
        with numpyro.plate("data", 100):
            x = numpyro.sample("x", dist.Normal(0, 1))  # Shape: (100,)

        # Subsampling for large datasets
        with numpyro.plate("data", 10000, subsample_size=100) as idx:
            # idx holds the 100 sampled indices into range(10000)
            x = numpyro.sample("x", dist.Normal(0, 1))  # Shape: (100,)

        # Nested plates for multidimensional independence
        with numpyro.plate("batch", 50, dim=-2):
            with numpyro.plate("features", 10, dim=-1):
                weights = numpyro.sample("w", dist.Normal(0, 1))  # Shape: (50, 10)
    """

    def __init__(self, name: str, size: int, subsample_size: Optional[int] = None,
                 dim: Optional[int] = None, subsample: Optional[ArrayLike] = None): ...

    def __enter__(self) -> Array:
        """Enter the plate context and return its index array: ``arange(size)``
        without subsampling, or the sampled minibatch indices otherwise."""

    def __exit__(self, exc_type, exc_value, traceback): ...

def plate_stack(prefix: str, sizes: list[int],
                rightmost_dim: int = -1) -> ContextManager[None]:
    """
    Create a contiguous stack of nested plates for multidimensional conditional independence.

    This is a context manager. Entering it enters one plate per entry of
    ``sizes``. The plates occupy the ``len(sizes)`` contiguous dimensions that
    end at ``rightmost_dim``, so ``sizes[-1]`` is placed at ``rightmost_dim``.

    Args:
        prefix: Prefix for the generated plate names
        sizes: Plate sizes, ordered from the leftmost to the rightmost dimension
        rightmost_dim: Rightmost (negative) dimension occupied by the stack

    Returns:
        A context manager that enters every plate of the stack

    Usage:
        # Create a 3D tensor of independent samples
        with numpyro.plate_stack("data", [20, 30, 40]):
            x = numpyro.sample("x", dist.Normal(0, 1))  # Shape: (20, 30, 40)
    """

def subsample(data: ArrayLike, event_dim: int) -> ArrayLike:
    """
    Subsample data based on the plates active in the current context.

    When no handlers (including plates) are active, ``data`` is returned
    unchanged.

    Args:
        data: Data tensor to subsample
        event_dim: Number of rightmost dimensions that are event dimensions

    Returns:
        Subsampled data tensor

    Usage:
        # Subsample based on active plate
        with numpyro.plate("data", len(full_data), subsample_size=100):
            batch_data = numpyro.subsample(full_data, event_dim=0)
            x = numpyro.sample("x", dist.Normal(batch_data, 1))
    """
```

### Advanced Primitives

Specialized primitives for advanced modeling scenarios: mutable state, neural-network modules, random keys, and masks.

```python { .api }
def mutable(name: str, init_value: Optional[Any] = None) -> Any:
    """
    Declare a mutable site whose value is recorded in the model trace.

    Unlike ``sample`` or ``param`` sites, the stored value may be changed by
    the model itself. Keep state in a container such as a dict and update it
    in place. Inference utilities such as SVI carry the recorded mutable state
    from one step to the next.

    Args:
        name: Name of the mutable site
        init_value: Initial value to record (typically a dict of arrays)

    Returns:
        Current value of the mutable site (``init_value`` unless a handler
        supplies a stored value)

    Usage:
        # Running state kept in a dict and updated in place; calling
        # numpyro.mutable again with a new value does not update the state.
        state = numpyro.mutable("state", {"count": 0})
        state["count"] = state["count"] + 1
    """

def module(name: str, nn: tuple, input_shape: Optional[tuple] = None) -> Callable:
    """
    Declare a stax-style neural network inside a model and register its
    weights as a NumPyro ``param`` site.

    Args:
        name: Module name; the weights are stored in the param site
            ``name + "$params"``
        nn: Tuple ``(init_fn, apply_fn)`` as returned by a
            ``jax.example_libraries.stax`` constructor, where
            ``init_fn(rng_key, input_shape)`` returns ``(output_shape, params)``
        input_shape: Input shape used to initialize the weights; required the
            first time the module is created

    Returns:
        ``apply_fn`` with the registered parameters bound; call it with an
        input array

    Note:
        Haiku and Flax modules do not follow the stax ``init_fn`` convention.
        Use ``numpyro.contrib.module.haiku_module`` or ``flax_module`` for them.

    Usage:
        from jax.example_libraries import stax

        net = stax.serial(stax.Dense(64), stax.Relu, stax.Dense(1))
        module_fn = numpyro.module("mlp", net, input_shape=(batch_size, 10))

        # Use in model
        x = numpyro.sample("x", dist.Normal(0, 1).expand((batch_size, 10)))
        y_pred = module_fn(x)
    """

def prng_key() -> Optional[Array]:
    """
    Get a PRNG key from the enclosing ``numpyro.handlers.seed`` handler.

    Returns:
        A random key, or None when no seed handler is active

    Usage:
        # Get key for manual random operations
        key = numpyro.prng_key()
        if key is not None:
            noise = random.normal(key, shape=(10,))
    """

def get_mask() -> Optional[ArrayLike]:
    """
    Get the mask set by enclosing ``numpyro.handlers.mask`` handlers.

    This helps skip expensive ``numpyro.factor`` computations when their log
    density would be discarded anyway, e.g. under
    ``numpyro.handlers.mask(mask=False)``.

    Returns:
        The active mask (a bool or a boolean array), or None if no mask is active

    Usage:
        x = numpyro.sample("x", dist.Normal(0, 1))
        # Only compute the expensive term when it is not masked out entirely
        if numpyro.get_mask() is not False:
            log_density = my_expensive_computation(x)
            numpyro.factor("foo", log_density)
    """
```

### Internal Utilities

Private helpers used by the primitive system; they are not intended to be called directly.

```python { .api }
def _masked_observe(name: str, fn: Distribution, obs: ArrayLike,
284
obs_mask: ArrayLike, **kwargs) -> ArrayLike:
285
"""
286
Handle masked observations in sample sites.
287
288
Args:
289
name: Site name
290
fn: Distribution
291
obs: Observed values
292
obs_mask: Boolean mask for valid observations
293
**kwargs: Additional arguments
294
295
Returns:
296
Masked observed value
297
"""
298
299
def _subsample_fn(size: int, subsample_size: int,
300
rng_key: Optional[Array] = None) -> Array:
301
"""
302
Generate subsample indices for plate subsampling.
303
304
Args:
305
size: Full dataset size
306
subsample_size: Size of subsample
307
rng_key: Random key for sampling
308
309
Returns:
310
Array of subsample indices
311
"""
312
313
def _inspect() -> dict:
314
"""
315
Inspect the current Pyro stack (experimental).
316
317
Returns:
318
Dictionary containing stack information
319
"""
320
321
class CondIndepStackFrame:
    """
    Named tuple representing a conditional independence stack frame.

    Attributes:
        name: Frame (plate) name
        dim: Broadcasting dimension
        size: Frame size
        counter: Frame counter for tracking.
            NOTE(review): upstream NumPyro's frame appears to carry only
            ``name``, ``dim`` and ``size``. Confirm this field exists.
    """

    name: str
    dim: int
    size: int
    counter: int
```

### Validation and Inspection

Utilities for validating models and inspecting their execution.

```python { .api }
def validate_model(model: Callable, *model_args, **model_kwargs) -> dict:
    """
    Validate model structure and return trace information.

    NOTE(review): this function is not confirmed to be part of NumPyro's
    public API; check before relying on it. Model structure can also be
    inspected with ``numpyro.infer.inspect.get_model_relations`` and
    ``numpyro.render_model``.

    Args:
        model: Model function to validate
        *model_args: Arguments to pass to model
        **model_kwargs: Keyword arguments to pass to model

    Returns:
        Dictionary containing validation results and trace information

    Usage:
        def my_model():
            x = numpyro.sample("x", dist.Normal(0, 1))
            y = numpyro.sample("y", dist.Normal(x, 1))

        validation_info = numpyro.validate_model(my_model)
        print(f"Model has {len(validation_info['sites'])} sites")
    """

def inspect_fn(fn: Callable, *args, **kwargs) -> dict:
    """
    Inspect function execution and return detailed information.

    NOTE(review): this function is not confirmed to be part of NumPyro's
    public API; check before relying on it. Site dependencies can also be
    obtained with ``numpyro.infer.inspect.get_model_relations``.

    Args:
        fn: Function to inspect
        *args: Arguments to pass to function
        **kwargs: Keyword arguments to pass to function

    Returns:
        Dictionary with execution information including sites and dependencies
    """
```

## Usage Examples

```python
import jax.numpy as jnp
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.distributions import constraints

# Basic linear regression model
def linear_regression(X, y=None):
    # Priors on the regression coefficients
    alpha = numpyro.sample("alpha", dist.Normal(0, 10))
    beta = numpyro.sample("beta", dist.Normal(0, 10))
    # Noise scale declared as a positive-constrained parameter
    sigma = numpyro.param("sigma", 1.0, constraint=constraints.positive)

    # Linear prediction
    mu = alpha + beta * X

    # Record the prediction in the trace for later inspection
    numpyro.deterministic("prediction", mu)

    # Likelihood with conditional independence over data points
    with numpyro.plate("data", X.shape[0]):
        numpyro.sample("y", dist.Normal(mu, sigma), obs=y)

# Hierarchical model with group-level and observation-level plates
def hierarchical_model(group_idx, y=None):
    # Size the group axis from the largest id (ids assumed to be integers
    # starting at 0) so that mu_group[group_idx] is always in bounds;
    # counting unique ids undercounts when some ids are unused, and JAX would
    # then silently clamp the out-of-range index.
    n_groups = int(jnp.max(group_idx)) + 1
    n_obs = len(y) if y is not None else len(group_idx)

    # Global hyperparameters
    mu_global = numpyro.sample("mu_global", dist.Normal(0, 1))
    sigma_global = numpyro.sample("sigma_global", dist.Exponential(1))

    # Group-level parameters
    with numpyro.plate("groups", n_groups):
        mu_group = numpyro.sample("mu_group", dist.Normal(mu_global, sigma_global))

    # Observation-level likelihood
    with numpyro.plate("obs", n_obs):
        mu = mu_group[group_idx]
        numpyro.sample("y", dist.Normal(mu, 1), obs=y)

# Model with subsampling for large datasets
def large_dataset_model(X, y=None):
    n_data, n_features = X.shape

    # Parameters
    weights = numpyro.sample("weights", dist.Normal(0, 1).expand((n_features,)))

    # Subsample for computational efficiency
    with numpyro.plate("data", n_data, subsample_size=min(1000, n_data)):
        # numpyro.subsample already selects this plate's minibatch rows.
        # Indexing its result again with the plate indices (which point into
        # the full dataset) would select the wrong rows.
        X_batch = numpyro.subsample(X, event_dim=1)
        y_batch = numpyro.subsample(y, event_dim=0) if y is not None else None

        mu = X_batch @ weights
        numpyro.sample("y", dist.Normal(mu, 0.1), obs=y_batch)

# Custom factor for non-standard likelihoods
def custom_likelihood_model(data):
    theta = numpyro.sample("theta", dist.Beta(1, 1))

    # Hand-written log-likelihood added via factor (a Bernoulli likelihood
    # written out explicitly for illustration). log1p(-theta) is more accurate
    # than log(1 - theta) when theta is small.
    log_lik = jnp.sum(data * jnp.log(theta) + (1 - data) * jnp.log1p(-theta))
    numpyro.factor("custom_lik", log_lik)
```

## Types

```python { .api }
from typing import Any, Callable, Dict, Literal, Optional, Tuple, Union
from jax import Array
import jax.numpy as jnp
from numpyro.distributions import Distribution, constraints

# Values accepted wherever the primitives take array-like inputs.
ArrayLike = Union[Array, jnp.ndarray, float, int]
Constraint = constraints.Constraint
# Either a concrete initial value or a callable mapping a PRNG key to one.
InitFunction = Union[ArrayLike, Callable[[Array], ArrayLike]]


class CondIndepStackFrame:
    """Frame in the conditional independence stack."""
    name: str
    dim: int
    size: int
    counter: int


class PlateMessenger:
    """Messenger for plate context management."""
    name: str
    size: int
    subsample_size: Optional[int]
    dim: Optional[int]
    subsample: Optional[Array]


# Site types for different primitive operations. These are string values, so
# they must be spelled with Literal; Union["sample", ...] would instead treat
# each string as a forward reference to a class of that name.
SiteType = Literal["sample", "param", "deterministic", "factor", "mutable"]


class SiteInfo:
    """Information about a primitive site."""
    name: str
    type: SiteType
    fn: Optional[Distribution]
    args: tuple
    kwargs: dict
    value: Any
    is_observed: bool
    infer: dict
    scale: Optional[float]


class ValidationResult:
    """Result from model validation."""
    sites: dict
    dependencies: dict
    plate_stack: list
    is_valid: bool
    warnings: list
    errors: list
```