0
# Core Probabilistic Programming
1
2
Core functions and constructs that form the foundation of Pyro's probabilistic programming language, enabling the creation of probabilistic models through composable primitives.
3
4
## Capabilities
5
6
### Sample Statements
7
8
The fundamental stochastic function for declaring random variables and observed data in probabilistic programs.
9
10
```python { .api }
11
def sample(
    name: str,
    fn: TorchDistributionMixin,
    *args,
    obs: Optional[torch.Tensor] = None,
    obs_mask: Optional[torch.BoolTensor] = None,
    infer: Optional[InferDict] = None,
    **kwargs
) -> torch.Tensor:
    """
    Primitive stochastic function for probabilistic programming.

    This is the core function for creating sample sites in probabilistic programs.
    It can be used to declare latent variables, observed data, and guide samples.

    Parameters:
    - name (str): Unique name for the sample site within the current context
    - fn (Distribution): Probability distribution to sample from
    - *args, **kwargs: Extra arguments forwarded to ``fn`` when it is called
      to draw a sample
    - obs (Tensor, optional): Observed data to condition on. When provided,
      this becomes a conditioning site rather than a sampling site
    - obs_mask (BoolTensor, optional): Boolean mask broadcastable with
      ``fn.batch_shape``, for partially observed (missing) data. Entries
      where the mask is True are conditioned on ``obs``; the remaining
      entries are imputed by sampling, which introduces an extra latent
      site named ``name + "_unobserved"`` that guides should provide
    - infer (dict, optional): Inference configuration dictionary containing
      instructions for inference algorithms (e.g. ``"enumerate"`` or
      ``"is_auxiliary"``)

    Returns:
    Tensor: Sample from the distribution (or observed value if obs is provided)

    Examples:
    >>> # Latent variable
    >>> z = pyro.sample("z", dist.Normal(0, 1))
    >>>
    >>> # Observed data
    >>> pyro.sample("obs", dist.Normal(mu, sigma), obs=data)
    >>>
    >>> # With inference configuration
    >>> pyro.sample("x", dist.Normal(0, 1), infer={"is_auxiliary": True})
    """
49
```
50
51
### Parameter Management
52
53
Functions for declaring and managing learnable parameters that persist across calls to the model.
54
55
```python { .api }
56
def param(
    name: str,
    init_tensor: Union[torch.Tensor, Callable[[], torch.Tensor], None] = None,
    constraint: constraints.Constraint = constraints.real,
    event_dim: Optional[int] = None,
) -> torch.Tensor:
    """
    Declare and retrieve learnable parameters from the global parameter store.

    Parameters persist across model calls and are automatically tracked for
    gradient-based optimization. The first call with a given name creates the
    parameter. Later calls with the same name return the stored value and
    ignore ``init_tensor``.

    Parameters:
    - name (str): Parameter name; identifies the parameter in the parameter store
    - init_tensor (Tensor or callable, optional): Initial parameter value, or a
      zero-argument callable that returns it. A callable is only invoked when
      the parameter does not exist yet. If None, the parameter must already
      exist in the store
    - constraint (Constraint): Constraint on parameter values, defaults to
      unconstrained real numbers. The parameter is optimized in an
      unconstrained space and mapped back so the returned value satisfies
      the constraint
    - event_dim (int, optional): Number of rightmost dimensions that are
      part of the event shape. Dimensions to the left are treated as batch
      dimensions and are subsampled when the statement sits inside a
      subsampled plate. If None, all dimensions are event dimensions and no
      subsampling is performed

    Returns:
    Tensor: Constrained parameter tensor with gradient tracking enabled

    Examples:
    >>> # Scalar parameter
    >>> mu = pyro.param("mu", torch.tensor(0.0))
    >>>
    >>> # Vector parameter with constraint
    >>> theta = pyro.param("theta", torch.ones(5), constraint=constraints.positive)
    >>>
    >>> # Matrix parameter
    >>> W = pyro.param("W", torch.randn(10, 5))
    >>>
    >>> # Lazily initialized parameter
    >>> L = pyro.param("L", lambda: torch.eye(3))
    """
90
91
def clear_param_store() -> None:
    """
    Clear all parameters from the global parameter store.

    Useful for resetting state between different model runs or experiments.

    Examples:
    >>> pyro.clear_param_store()
    """
97
98
def get_param_store() -> ParamStoreDict:
    """
    Get the global parameter store instance.

    Returns:
    ParamStoreDict: The global parameter store containing all named parameters

    Examples:
    >>> store = pyro.get_param_store()
    """
105
```
106
107
### Independence Declarations
108
109
Context managers for declaring conditional independence and enabling efficient vectorized computation.
110
111
```python { .api }
112
class plate(PlateMessenger):
    def __init__(
        self,
        name: str,
        size: Optional[int] = None,
        subsample_size: Optional[int] = None,
        subsample: Optional[torch.Tensor] = None,
        dim: Optional[int] = None,
        use_cuda: Optional[bool] = None,
        device: Optional[str] = None,
    ) -> None:
        """
        Context manager for declaring conditional independence assumptions.

        Plates enable vectorized computation and minibatch training by declaring
        that samples within the plate are conditionally independent.

        Parameters:
        - name (str): Unique name for the plate
        - size (int, optional): Total size of the independent dimension (the
          full data size when subsampling)
        - subsample_size (int, optional): Size of minibatch subsample. If provided,
          enables minibatch training. Log probabilities are automatically
          scaled by ``size / subsample_size``
        - subsample (Tensor, optional): Custom 1-D tensor of indices for
          user-defined subsampling schemes. ``subsample_size`` then defaults
          to ``len(subsample)``
        - dim (int, optional): Negative tensor dimension to use for
          broadcasting. If None, uses the rightmost dimension that lies to
          the left of all enclosing plates
        - use_cuda (bool, optional): Deprecated; use ``device`` instead
        - device (str, optional): Device on which to place the subsample
          indices and log probabilities

        Yields:
        Tensor: 1-D tensor of the (possibly subsampled) indices, bound via
        ``with pyro.plate(...) as ind``

        Examples:
        >>> # Basic independence
        >>> with pyro.plate("data", 100):
        ...     pyro.sample("obs", dist.Normal(mu, sigma), obs=data)
        >>>
        >>> # Minibatch training: select the batch with the yielded indices
        >>> with pyro.plate("data", 10000, subsample_size=32) as ind:
        ...     pyro.sample("obs", dist.Normal(mu, sigma), obs=data[ind])
        >>>
        >>> # Nested plates: "batch" is allocated dim -1, "features" dim -2
        >>> with pyro.plate("batch", N):
        ...     with pyro.plate("features", D):
        ...         pyro.sample("z", dist.Normal(0, 1))
        """
154
155
def plate_stack(prefix: str, sizes: Sequence[int], rightmost_dim: int = -1) -> Iterator[None]:
    """
    Create a stack of nested plates for multi-dimensional independence.

    Used as a context manager. The last entry of ``sizes`` is bound to
    ``rightmost_dim``, and each earlier entry is bound to the next dimension
    to its left.

    Parameters:
    - prefix (str): Name prefix for the generated plates
    - sizes (Sequence[int]): Sizes for each nested plate, ordered from the
      leftmost to the rightmost dimension
    - rightmost_dim (int): Rightmost (negative) tensor dimension to use

    Returns:
    ContextManager: Nested plate context

    Examples:
    >>> # K is bound to dim -1, D to dim -2 and N to dim -3
    >>> with pyro.plate_stack("plates", [N, D, K]):
    ...     pyro.sample("z", dist.Normal(0, 1))
    """
171
```
172
173
### Model Composition
174
175
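Primitives for adding non-sample sites to a probabilistic program: arbitrary log-probability factors, recorded deterministic values, and barriers that force lazy values to be evaluated.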
176
177
```python { .api }
178
def factor(
    name: str,
    log_factor: torch.Tensor,
    *,
    has_rsample: Optional[bool] = None,
) -> None:
    """
    Add an arbitrary log probability factor to the model.

    The factor is added to the program's joint log density. This makes it
    handy for custom log probability terms that have no standard
    distribution.

    Parameters:
    - name (str): Name of the factor site
    - log_factor (torch.Tensor): Log probability term added to the model's
      joint log probability
    - has_rsample (bool, optional): Whether the factor came from a fully
      reparametrized distribution (required in guides)

    Examples:
    >>> # Custom likelihood term
    >>> log_likelihood = -0.5 * torch.sum((data - mu) ** 2) / sigma ** 2
    >>> pyro.factor("custom_likelihood", log_likelihood)
    >>>
    >>> # Penalty term
    >>> penalty = -0.01 * torch.sum(params ** 2)
    >>> pyro.factor("l2_penalty", penalty)
    """
206
207
def deterministic(
    name: str, value: torch.Tensor, event_dim: Optional[int] = None
) -> torch.Tensor:
    """
    Create a deterministic sample site for tracking intermediate computations.

    The value is recorded in the trace under ``name``. This lets you inspect
    quantities that are fully determined by other sites.

    Parameters:
    - name (str): Name for the deterministic site
    - value (Tensor): Deterministic value to record
    - event_dim (int, optional): Number of rightmost event dimensions;
      defaults to ``value.ndim`` when None

    Returns:
    Tensor: The input value (pass-through)

    Examples:
    >>> z = pyro.sample("z", dist.Normal(0, 1))
    >>> z_squared = pyro.deterministic("z_squared", z ** 2)
    """
223
224
def barrier(data: torch.Tensor) -> torch.Tensor:
    """
    EXPERIMENTAL: ensure all values in ``data`` are ground (concrete) values
    rather than lazy funsor values.

    This is mainly useful together with ``pyro.poutine.collapse()``, which can
    otherwise leave values in a lazy form.

    Parameters:
    - data (Tensor): Value to make ground

    Returns:
    Tensor: ``data``, with any lazy values evaluated
    """
233
```
234
235
### PyTorch Module Integration
236
237
Functions for integrating PyTorch modules into probabilistic programs.
238
239
```python { .api }
240
def module(name: str, nn_module, update_module_params: bool = False):
    """
    Register all parameters of a PyTorch module with Pyro's parameter store.

    Each parameter of ``nn_module`` is stored under a name prefixed by
    ``name``. As a result, it is trained by Pyro optimizers and saved/loaded
    together with the parameter store.

    Parameters:
    - name (str): Name for the module, used to prefix its parameter names
    - nn_module (torch.nn.Module): PyTorch module to register
    - update_module_params (bool): Whether to overwrite the module's own
      parameters with values already present in the parameter store (if
      any). Defaults to False

    Returns:
    torch.nn.Module: The input module

    Examples:
    >>> neural_net = torch.nn.Linear(10, 1)
    >>> net = pyro.module("neural_net", neural_net)
    >>> output = net(input_tensor)
    """
258
259
def random_module(name: str, nn_module, prior, *args, **kwargs):
    """
    Place a prior over the parameters of a PyTorch module.

    Deprecated: prefer ``pyro.nn.PyroModule`` for Bayesian neural networks.

    Parameters:
    - name (str): Name for the random module
    - nn_module (torch.nn.Module): PyTorch module template
    - prior: A Pyro distribution or stochastic function used as the prior for
      every parameter, or a dict mapping parameter names to such
      distributions/stochastic functions

    Returns:
    callable: A lifted module. Each call samples parameters from the prior
    and returns a torch.nn.Module that uses those parameters

    Examples:
    >>> template = torch.nn.Linear(10, 1)
    >>> priors = {
    ...     "weight": dist.Normal(torch.zeros(1, 10), torch.ones(1, 10)).to_event(2),
    ...     "bias": dist.Normal(torch.zeros(1), torch.ones(1)).to_event(1),
    ... }
    >>> lifted_module = pyro.random_module("bnn", template, priors)
    >>> sampled_nn = lifted_module()
    """
278
```
279
280
### Subsampling and Utilities
281
282
Utilities for data subsampling and model visualization.
283
284
```python { .api }
285
def subsample(data: torch.Tensor, event_dim: int) -> torch.Tensor:
    """
    Subsample ``data`` according to the enclosing subsampling plates.

    Dimensions to the left of the ``event_dim`` rightmost dimensions are
    treated as batch dimensions. These are indexed by any enclosing plate
    that performs subsampling.

    Parameters:
    - data (Tensor): Batched data to subsample
    - event_dim (int): Number of rightmost event dimensions

    Returns:
    Tensor: Subsampled data when inside a subsampling plate
    """
296
297
def render_model(model, *args, **kwargs):
    """
    Draw a graphical representation of a probabilistic model.

    Parameters:
    - model (callable): Model function to visualize
    - *args, **kwargs: Arguments used to run the model

    Returns:
    Visualization object for the model structure

    NOTE(review): released Pyro versions appear to take the model inputs as
    ``model_args``/``model_kwargs`` keyword arguments (plus rendering options)
    instead of forwarding ``*args``/``**kwargs``, and return a
    ``graphviz.Digraph``. Confirm against the installed version.
    """
308
```
309
310
### Global State Management
311
312
Functions for managing global Pyro state and settings.
313
314
```python { .api }
315
def get_param_store() -> ParamStoreDict:
    """
    Return the global parameter store that holds every Pyro parameter.

    Returns:
    ParamStoreDict: The global parameter store dictionary

    Examples:
    >>> param_store = pyro.get_param_store()
    >>> print(list(param_store.keys()))  # names of all stored parameters
    """
326
327
def clear_param_store() -> None:
    """
    Remove every parameter from the global parameter store.

    Handy for starting from a clean slate between experiments or tests.

    Examples:
    >>> pyro.clear_param_store()  # drop all parameters
    """
336
337
def enable_validation(is_validate: bool = True):
    """
    Turn runtime validation of distributions and shapes on or off.

    Parameters:
    - is_validate (bool): True to enable validation, False to disable it

    Examples:
    >>> pyro.enable_validation(True)   # helpful while debugging
    >>> pyro.enable_validation(False)  # skip checks for speed
    """
348
349
def validation_enabled(is_validate: bool = True) -> Iterator[None]:
    """
    Context manager that temporarily enables or disables validation.

    Inside the ``with`` block, validation is set to ``is_validate``. The
    previous setting is restored on exit.

    Parameters:
    - is_validate (bool): Validation setting to use inside the context,
      defaults to True

    Examples:
    >>> with pyro.validation_enabled(False):
    ...     svi.step(data)
    """
356
357
def set_rng_seed(rng_seed: int):
    """
    Seed the random number generators so results are reproducible.

    Seeds Python's ``random`` module, NumPy and PyTorch.

    Parameters:
    - rng_seed (int): Seed value applied to every generator

    Examples:
    >>> pyro.set_rng_seed(42)  # reproducible experiment
    """
369
```
370
371
## Examples
372
373
### Basic Model Definition
374
375
```python
376
import pyro
377
import pyro.distributions as dist
378
import torch
379
380
def coin_flip_model(data):
    """Bernoulli coin-flip model with a flat Beta(1, 1) prior on the bias."""
    # Latent coin bias, uniform over [0, 1] a priori
    bias = pyro.sample("bias", dist.Beta(1.0, 1.0))

    # Flips are conditionally independent given the bias
    n_flips = len(data)
    with pyro.plate("data", n_flips):
        pyro.sample("obs", dist.Bernoulli(bias), obs=data)


# Usage: condition the model on five observed flips
data = torch.tensor([1.0, 0.0, 1.0, 1.0, 0.0])
coin_flip_model(data)
392
```
393
394
### Hierarchical Model
395
396
```python
397
def hierarchical_model(group_data):
    """Hierarchical model: per-group intercepts drawn from shared hyperpriors."""
    # Population-level hyperpriors shared by every group
    mu_alpha = pyro.sample("mu_alpha", dist.Normal(0, 10))
    sigma_alpha = pyro.sample("sigma_alpha", dist.HalfNormal(5))

    # One intercept per group, conditionally independent given the hyperpriors
    n_groups = len(group_data)
    with pyro.plate("groups", n_groups):
        alpha = pyro.sample("alpha", dist.Normal(mu_alpha, sigma_alpha))

    # Each group gets its own plate sized to its data, so groups can differ in size
    for idx, observations in enumerate(group_data):
        with pyro.plate(f"group_{idx}_data", len(observations)):
            pyro.sample(f"obs_{idx}", dist.Normal(alpha[idx], 1), obs=observations)
411
```
412
413
### Minibatch Training
414
415
```python
416
def minibatch_model(data, subsample_size=32):
    """
    Model that scores a random minibatch of ``data`` on each call.

    Call it once per optimization step (e.g. via ``svi.step(data)``). The plate
    is given the full dataset size. Pyro therefore rescales the minibatch log
    likelihood by ``len(data) / subsample_size``, which yields an unbiased
    estimate of the full-data log likelihood.

    Parameters:
    - data (Tensor): Full 1-D tensor of observations
    - subsample_size (int): Number of observations scored per call

    Note: looping over minibatches inside a single model call would reuse the
    "data"/"obs" site names (duplicate sites). Also, a plate whose size equals
    the batch size applies no rescaling.
    """
    # Global parameters
    mu = pyro.param("mu", torch.tensor(0.0))
    sigma = pyro.param("sigma", torch.tensor(1.0), constraint=dist.constraints.positive)

    # The plate draws random indices out of len(data); never ask for more than exist
    batch_size = min(subsample_size, len(data))
    with pyro.plate("data", len(data), subsample_size=batch_size) as ind:
        pyro.sample("obs", dist.Normal(mu, sigma), obs=data[ind])
426
```