# Neural Networks Integration

Deep probabilistic models combining neural networks with probabilistic programming, enabling Bayesian neural networks, stochastic layers, and seamless integration between PyTorch modules and Pyro's probabilistic primitives.

## Capabilities

### Pyro Module System

Base classes and descriptors for creating probabilistic neural network modules that integrate seamlessly with Pyro's effect system.

```python { .api }
class PyroModule(torch.nn.Module):
    """
    Base class for Pyro modules with integrated parameter and sample management.

    PyroModule extends torch.nn.Module to support Pyro's parameter store and
    sample statements, enabling probabilistic neural networks and automatic
    integration with inference algorithms.

    Examples:
        >>> class BayesianLinear(PyroModule):
        ...     def __init__(self, in_features, out_features):
        ...         super().__init__()
        ...         self.in_features = in_features
        ...         self.out_features = out_features
        ...
        ...         # Stochastic weights
        ...         self.weight = PyroSample(
        ...             dist.Normal(0, 1).expand([out_features, in_features]).to_event(2)
        ...         )
        ...
        ...         # Learnable bias
        ...         self.bias = PyroParam(torch.zeros(out_features))
        ...
        ...     def forward(self, x):
        ...         return torch.nn.functional.linear(x, self.weight, self.bias)
    """

    def __setattr__(self, name: str, value):
        """Override to handle PyroParam and PyroSample descriptors."""

    def named_pyro_params(self, prefix: str = '', recurse: bool = True):
        """
        Iterate over Pyro parameters in the module.

        Parameters:
        - prefix (str): Prefix to prepend to parameter names
        - recurse (bool): Whether to recurse into submodules

        Yields:
            Tuple[str, torch.Tensor]: (name, parameter) pairs
        """

class PyroParam:
    """
    Descriptor for Pyro parameters within PyroModule.

    PyroParam creates learnable parameters that are automatically registered
    with Pyro's parameter store and can be constrained or transformed.
    """

    def __init__(self, init_tensor, constraint=dist.constraints.real, event_dim=None):
        """
        Parameters:
        - init_tensor (Tensor): Initial parameter value
        - constraint (Constraint): Parameter constraint (e.g., positive, simplex)
        - event_dim (int, optional): Number of rightmost event dimensions

        Examples:
            >>> # Unconstrained parameter
            >>> self.mu = PyroParam(torch.tensor(0.0))
            >>>
            >>> # Positive parameter
            >>> self.sigma = PyroParam(torch.tensor(1.0), constraint=dist.constraints.positive)
            >>>
            >>> # Simplex parameter (probabilities)
            >>> self.probs = PyroParam(torch.ones(5), constraint=dist.constraints.simplex)
        """

    def __get__(self, obj, obj_type=None) -> torch.Tensor:
        """Get parameter value from Pyro parameter store."""

    def __set__(self, obj, value):
        """Set parameter value in Pyro parameter store."""

class PyroSample:
    """
    Descriptor for Pyro samples within PyroModule.

    PyroSample creates stochastic variables that are automatically sampled
    from specified prior distributions during model execution.
    """

    def __init__(self, prior):
        """
        Parameters:
        - prior (Distribution or callable): Prior distribution, or a function
          that takes the owning module (``self``) and returns a distribution

        Examples:
            >>> # Fixed prior distribution
            >>> self.weight = PyroSample(dist.Normal(0, 1))
            >>>
            >>> # Parameterized prior: the callable receives the module
            >>> # instance, so other attributes are read at sample time
            >>> self.weight = PyroSample(
            ...     lambda self: dist.Normal(self.weight_loc, self.weight_scale)
            ... )
            >>>
            >>> # Matrix-valued sample
            >>> self.W = PyroSample(dist.Normal(0, 1).expand([10, 5]).to_event(2))
        """

    def __get__(self, obj, obj_type=None) -> torch.Tensor:
        """Sample from prior distribution."""

def pyro_method(fn):
    """
    Decorator to create Pyro-aware methods in PyroModule.

    Ensures that sample statements within decorated methods use appropriate
    name scoping and integration with the module's parameter namespace.

    Parameters:
    - fn (callable): Method to decorate

    Returns:
        callable: Decorated method with Pyro integration

    Examples:
        >>> class MyModule(PyroModule):
        ...     @pyro_method
        ...     def model(self, x):
        ...         z = pyro.sample("z", dist.Normal(0, 1))
        ...         return self.forward(x, z)
    """
```

### Neural Network Architectures

Specialized neural network architectures for probabilistic modeling and normalizing flows.

```python { .api }
class DenseNN(PyroModule):
    """
    Dense (fully-connected) neural network with configurable architecture.

    Commonly used in normalizing flows, variational autoencoders, and as
    function approximators in probabilistic models.
    """

    def __init__(self, input_dim: int, hidden_dims: List[int], output_dim: int,
                 nonlinearity: torch.nn.Module = torch.nn.ReLU(),
                 residual_connections: bool = False, batch_norm: bool = False,
                 dropout_prob: float = 0.0):
        """
        Parameters:
        - input_dim (int): Input dimension
        - hidden_dims (List[int]): List of hidden layer dimensions
        - output_dim (int): Output dimension
        - nonlinearity (Module): Activation function between layers
        - residual_connections (bool): Whether to add residual connections
        - batch_norm (bool): Whether to use batch normalization
        - dropout_prob (float): Dropout probability (0 = no dropout)

        Examples:
            >>> # Simple 3-layer network
            >>> net = DenseNN(10, [64, 32], 1)
            >>>
            >>> # Network with batch norm and dropout
            >>> net = DenseNN(20, [128, 64, 32], 5,
            ...               batch_norm=True, dropout_prob=0.1)
        """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the network.

        Parameters:
        - x (Tensor): Input tensor of shape (..., input_dim)

        Returns:
            Tensor: Output tensor of shape (..., output_dim)
        """

class ConditionalDenseNN(PyroModule):
    """
    Conditional dense neural network that takes additional context input.

    Useful for conditional normalizing flows and context-dependent function
    approximation in probabilistic models.
    """

    def __init__(self, input_dim: int, context_dim: int, hidden_dims: List[int],
                 output_dim: int, nonlinearity: torch.nn.Module = torch.nn.ReLU(),
                 residual_connections: bool = False):
        """
        Parameters:
        - input_dim (int): Primary input dimension
        - context_dim (int): Context/condition dimension
        - hidden_dims (List[int]): Hidden layer dimensions
        - output_dim (int): Output dimension
        - nonlinearity (Module): Activation function
        - residual_connections (bool): Whether to use residual connections

        Examples:
            >>> # Conditional network
            >>> cond_net = ConditionalDenseNN(10, 5, [64, 32], 2)
            >>> output = cond_net(x, context)
        """

    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        """
        Forward pass with context input.

        Parameters:
        - x (Tensor): Primary input of shape (..., input_dim)
        - context (Tensor): Context input of shape (..., context_dim)

        Returns:
            Tensor: Output tensor of shape (..., output_dim)
        """

class AutoRegressiveNN(PyroModule):
    """
    Autoregressive neural network with masked connections.

    Implements MADE (Masked Autoencoder for Distribution Estimation) for
    autoregressive density modeling and normalizing flows.
    """

    def __init__(self, input_dim: int, hidden_dims: List[int], output_dim_multiplier: int = 1,
                 nonlinearity: torch.nn.Module = torch.nn.ReLU(), residual_connections: bool = False,
                 random_mask: bool = False, activation: torch.nn.Module = None):
        """
        Parameters:
        - input_dim (int): Input dimension
        - hidden_dims (List[int]): Hidden layer dimensions
        - output_dim_multiplier (int): Output dimension multiplier (for multiple outputs per input)
        - nonlinearity (Module): Hidden layer activation
        - residual_connections (bool): Whether to use residual connections
        - random_mask (bool): Whether to use random ordering for autoregressive mask
        - activation (Module, optional): Final layer activation

        Examples:
            >>> # Autoregressive network for 10-dimensional data
            >>> ar_net = AutoRegressiveNN(10, [64, 64], output_dim_multiplier=2)
            >>> # Output has shape (..., 20) for 2 outputs per input dimension
        """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass preserving autoregressive property.

        Parameters:
        - x (Tensor): Input tensor of shape (..., input_dim)

        Returns:
            Tensor: Output respecting autoregressive ordering
        """

class ConditionalAutoRegressiveNN(AutoRegressiveNN):
    """
    Conditional autoregressive neural network with context input.

    Combines autoregressive masking with conditional computation for
    context-dependent autoregressive models.
    """

    def __init__(self, input_dim: int, context_dim: int, hidden_dims: List[int],
                 output_dim_multiplier: int = 1, nonlinearity: torch.nn.Module = torch.nn.ReLU()):
        """
        Parameters:
        - input_dim (int): Primary input dimension
        - context_dim (int): Context dimension
        - hidden_dims (List[int]): Hidden layer dimensions
        - output_dim_multiplier (int): Output multiplier per input dimension
        - nonlinearity (Module): Activation function
        """

    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        """Forward pass with context input maintaining autoregressive property."""

class MaskedLinear(torch.nn.Module):
    """
    Linear layer with learnable or fixed mask for autoregressive networks.

    Used as a building block in autoregressive neural networks where
    connections must respect the autoregressive ordering.
    """

    def __init__(self, in_features: int, out_features: int, mask: torch.Tensor = None,
                 bias: bool = True):
        """
        Parameters:
        - in_features (int): Input feature dimension
        - out_features (int): Output feature dimension
        - mask (Tensor, optional): Binary mask matrix (1=keep, 0=mask)
        - bias (bool): Whether to include bias parameter

        Examples:
            >>> # Create mask for autoregressive ordering
            >>> mask = torch.tril(torch.ones(5, 5))  # Lower triangular
            >>> masked_layer = MaskedLinear(5, 5, mask)
        """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass with masked weight matrix."""
```

### Bayesian Neural Networks

Tools for creating and working with Bayesian neural networks where weights and biases are treated as random variables.

```python { .api }
def lift_module(nn_module: torch.nn.Module, prior: callable, guide: callable = None):
    """
    Lift a PyTorch module to a Bayesian neural network.

    Converts deterministic neural network parameters to random variables
    with specified prior distributions.

    Parameters:
    - nn_module (Module): PyTorch module to convert
    - prior (callable): Function that returns prior distributions for parameters
    - guide (callable, optional): Function that returns guide distributions

    Returns:
        PyroModule: Bayesian version of the input module

    Examples:
        >>> # Define deterministic network
        >>> net = torch.nn.Linear(10, 1)
        >>>
        >>> # Define priors
        >>> def prior(name, shape):
        ...     return dist.Normal(0, 1).expand(shape).to_event(len(shape))
        >>>
        >>> # Create Bayesian network
        >>> bnn = lift_module(net, prior)
        >>>
        >>> # Use the lifted network in a probabilistic model
        >>> def model(x, y):
        ...     prediction = bnn(x)
        ...     pyro.sample("obs", dist.Normal(prediction.squeeze(-1), 0.1), obs=y)
    """

def sample_module_outputs(model: PyroModule, input_data: torch.Tensor,
                          num_samples: int = 100) -> torch.Tensor:
    """
    Sample multiple outputs from a Bayesian neural network.

    Parameters:
    - model (PyroModule): Bayesian neural network model
    - input_data (Tensor): Input data
    - num_samples (int): Number of posterior samples to generate

    Returns:
        Tensor: Sampled outputs with shape (num_samples, batch_size, output_dim)

    Examples:
        >>> outputs = sample_module_outputs(bnn, test_data, num_samples=50)
        >>> mean_prediction = outputs.mean(dim=0)
        >>> uncertainty = outputs.std(dim=0)
    """

class BayesianModule(PyroModule):
    """
    Base class for implementing custom Bayesian neural network layers.

    Provides utilities for parameter sampling and uncertainty quantification
    in neural network layers.
    """

    def __init__(self, name: str):
        """
        Parameters:
        - name (str): Module name for parameter scoping
        """
        super().__init__()
        # Stored after base initialisation so it overrides whatever name
        # the parent constructor assigned.
        self._pyro_name = name

    def sample_parameters(self):
        """Sample parameters from their prior/posterior distributions."""

    def forward_with_samples(self, x: torch.Tensor, num_samples: int = 1) -> torch.Tensor:
        """
        Forward pass with multiple parameter samples for uncertainty estimation.

        Parameters:
        - x (Tensor): Input data
        - num_samples (int): Number of parameter samples

        Returns:
            Tensor: Output samples with uncertainty
        """
```

### Variational Layers

Specialized layers for variational inference and amortized inference in deep generative models.

```python { .api }
class VariationalLinear(PyroModule):
    """
    Variational linear layer with learnable mean and variance parameters.

    Implements local reparameterization trick for efficient variational
    inference in neural networks.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 prior_scale: float = 1.0):
        """
        Parameters:
        - in_features (int): Input feature dimension
        - out_features (int): Output feature dimension
        - bias (bool): Whether to include bias term
        - prior_scale (float): Scale of prior distribution on weights

        Examples:
            >>> var_layer = VariationalLinear(10, 5, prior_scale=0.1)
        """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass using local reparameterization trick."""

class AmortizedLDA(PyroModule):
    """
    Amortized Latent Dirichlet Allocation using neural networks.

    Implements neural variational inference for topic modeling where
    the variational parameters are predicted by neural networks.
    """

    def __init__(self, vocab_size: int, num_topics: int, hidden_dim: int = 100,
                 dropout: float = 0.2):
        """
        Parameters:
        - vocab_size (int): Vocabulary size
        - num_topics (int): Number of topics
        - hidden_dim (int): Hidden dimension for encoder network
        - dropout (float): Dropout probability
        """

    def model(self, docs: torch.Tensor, doc_lengths: torch.Tensor):
        """LDA generative model."""

    def guide(self, docs: torch.Tensor, doc_lengths: torch.Tensor):
        """Neural variational guide for LDA."""
```

### Integration Utilities

Functions for seamless integration between PyTorch modules and Pyro probabilistic programs.

```python { .api }
def to_pyro_module_(nn_module: torch.nn.Module, prior: callable = None) -> PyroModule:
    """
    Convert PyTorch module to PyroModule in-place.

    Parameters:
    - nn_module (Module): PyTorch module to convert
    - prior (callable, optional): Prior distribution generator for parameters

    Returns:
        PyroModule: Converted module (same object)

    Examples:
        >>> net = torch.nn.Linear(10, 1)
        >>> pyro_net = to_pyro_module_(net)
    """

def clear_module_hooks(module: torch.nn.Module):
    """
    Clear all Pyro-related hooks from a PyTorch module.

    Parameters:
    - module (Module): Module to clear hooks from
    """

def module_prior(module_name: str, module: torch.nn.Module,
                 prior_fn: callable) -> torch.nn.Module:
    """
    Apply prior distributions to all parameters in a PyTorch module.

    Parameters:
    - module_name (str): Name prefix for Pyro sample sites
    - module (Module): PyTorch module
    - prior_fn (callable): Function returning prior distributions

    Returns:
        Module: Module with stochastic parameters

    Examples:
        >>> def weight_prior(name, param):
        ...     return dist.Normal(0, 1).expand(param.shape).to_event(param.dim())
        >>>
        >>> net = torch.nn.Linear(10, 1)
        >>> stochastic_net = module_prior("net", net, weight_prior)
    """

class PyroModuleList(torch.nn.ModuleList, PyroModule):
    """
    ModuleList that supports PyroModule functionality.

    Enables lists of PyroModules to work correctly with Pyro's
    parameter management and effect handling.

    Examples:
        >>> layers = PyroModuleList([
        ...     BayesianLinear(10, 20),
        ...     BayesianLinear(20, 1)
        ... ])
    """

    def __init__(self, modules=None):
        """
        Parameters:
        - modules (iterable, optional): Iterable of modules to add
        """
```

## Examples

### Simple Bayesian Neural Network

```python
import pyro
import pyro.distributions as dist
import torch
import torch.nn.functional as F
from pyro.nn import PyroModule, PyroSample, PyroParam


class BayesianLinear(PyroModule):
    """Linear layer whose weight and bias are drawn from standard-normal priors."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        # Stochastic weights and biases; to_event marks every dimension as
        # part of a single multivariate sample site.
        self.weight = PyroSample(
            dist.Normal(0., 1.).expand([out_features, in_features]).to_event(2)
        )
        self.bias = PyroSample(
            dist.Normal(0., 1.).expand([out_features]).to_event(1)
        )

    def forward(self, x):
        """Apply the affine map using the weight and bias attributes."""
        return F.linear(x, self.weight, self.bias)


# Usage in a model
def model(x, y=None):
    """Bayesian linear regression: Normal likelihood around a BayesianLinear mean.

    Parameters:
    - x (Tensor): Inputs with 3 features per row
    - y (Tensor, optional): Observed targets; when None, "obs" is sampled
      instead of conditioned on
    """
    fc = BayesianLinear(3, 1)

    # Forward pass. Drop only the trailing feature dimension: a bare
    # squeeze() would also collapse a batch of size 1 to a scalar and
    # no longer line up with the "data" plate.
    mean = fc(x).squeeze(-1)

    # Likelihood
    with pyro.plate("data", len(x)):
        pyro.sample("obs", dist.Normal(mean, 0.1), obs=y)


def guide(x, y=None):
    """Mean-field Normal guide over the latent weight and bias of ``model``.

    An empty guide is not usable for SVI because the model has latent sample
    sites that the guide must also sample. ``pyro.infer.autoguide.AutoNormal(model)``
    is an equivalent ready-made alternative.

    Parameters:
    - x (Tensor): Inputs (unused; present to match the model's signature)
    - y (Tensor, optional): Targets (unused)
    """
    # NOTE(review): site names "weight"/"bias" and shapes (1, 3)/(1,) assume the
    # model builds an unnamed BayesianLinear(3, 1) - confirm if the model changes.
    weight_loc = pyro.param("weight_loc", torch.zeros(1, 3))
    weight_scale = pyro.param(
        "weight_scale", torch.full((1, 3), 0.1), constraint=dist.constraints.positive
    )
    bias_loc = pyro.param("bias_loc", torch.zeros(1))
    bias_scale = pyro.param(
        "bias_scale", torch.full((1,), 0.1), constraint=dist.constraints.positive
    )

    pyro.sample("weight", dist.Normal(weight_loc, weight_scale).to_event(2))
    pyro.sample("bias", dist.Normal(bias_loc, bias_scale).to_event(1))
```

### Variational Autoencoder

```python
class VAE(PyroModule):
    """Variational autoencoder with a one-hidden-layer encoder and decoder."""

    def __init__(self, input_dim=784, hidden_dim=400, z_dim=20):
        super().__init__()
        # model() reads self.z_dim to build the prior, so it must be stored.
        self.z_dim = z_dim

        # Encoder
        self.encoder_fc1 = torch.nn.Linear(input_dim, hidden_dim)
        self.encoder_mu = torch.nn.Linear(hidden_dim, z_dim)
        self.encoder_sigma = torch.nn.Linear(hidden_dim, z_dim)

        # Decoder
        self.decoder_fc1 = torch.nn.Linear(z_dim, hidden_dim)
        self.decoder_fc2 = torch.nn.Linear(hidden_dim, input_dim)

    def model(self, x):
        """Generative model: standard-normal latent decoded to Bernoulli pixels."""
        # Register parameters with Pyro
        pyro.module("decoder", self)

        batch_size = x.shape[0]

        # Prior
        with pyro.plate("data", batch_size):
            # Allocate on x's device so the model also runs on GPU inputs.
            z_loc = torch.zeros(batch_size, self.z_dim, device=x.device)
            z_scale = torch.ones(batch_size, self.z_dim, device=x.device)
            z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))

            # Decode
            hidden = F.relu(self.decoder_fc1(z))
            mu_img = torch.sigmoid(self.decoder_fc2(hidden))

            # Likelihood
            pyro.sample("obs", dist.Bernoulli(mu_img).to_event(1), obs=x)

    def guide(self, x):
        """Amortized guide: the encoder predicts the latent's mean and scale."""
        # Register parameters with Pyro
        pyro.module("encoder", self)

        batch_size = x.shape[0]

        # Encode
        hidden = F.relu(self.encoder_fc1(x))
        z_mu = self.encoder_mu(hidden)
        z_sigma = F.softplus(self.encoder_sigma(hidden))

        # Variational distribution
        with pyro.plate("data", batch_size):
            pyro.sample("latent", dist.Normal(z_mu, z_sigma).to_event(1))
```

### Neural Network with Uncertainty

```python
class UncertaintyNet(PyroModule):
    """Bayesian linear regressor with a learnable observation-noise scale."""

    def __init__(self):
        super().__init__()
        self.linear = PyroModule[torch.nn.Linear](10, 1)

        # Standard-normal priors on the layer's weight and bias. Assigning
        # PyroSample attributes replaces the deprecated pyro.random_module,
        # whose return value is a module factory and cannot be called on data.
        self.linear.weight = PyroSample(
            dist.Normal(0., 1.).expand([1, 10]).to_event(2)
        )
        self.linear.bias = PyroSample(
            dist.Normal(0., 1.).expand([1]).to_event(1)
        )

        # Learnable noise parameter
        self.sigma = PyroParam(torch.tensor(1.0),
                               constraint=dist.constraints.positive)

    def forward(self, x, y=None):
        """Return the predicted mean and register the "obs" likelihood site.

        Parameters:
        - x (Tensor): Inputs with 10 features per row
        - y (Tensor, optional): Observed targets; when None, "obs" is sampled
          so predictive sampling can read it
        """
        # Forward pass; weights are sampled when the attributes are accessed
        prediction = self.linear(x).squeeze(-1)

        # Likelihood. The site is always executed: with obs=None it draws a
        # sample instead of conditioning.
        with pyro.plate("data", len(x)):
            pyro.sample("obs", dist.Normal(prediction, self.sigma), obs=y)

        return prediction


# Usage with uncertainty quantification
net = UncertaintyNet()

# Training with SVI
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.optim import Adam

# The guide must sample every latent site of the model, so a no-op guide
# cannot be used; an autoguide builds a matching variational family.
guide = AutoDiagonalNormal(net)
svi = SVI(net, guide, Adam({"lr": 0.01}), Trace_ELBO())

# train_x / train_y are your training tensors
for _ in range(1000):
    svi.step(train_x, train_y)

# Get predictions with uncertainty
from pyro.infer import Predictive

# Passing the guide draws weights from the fitted posterior rather than the prior
predictive = Predictive(net, guide=guide, num_samples=100)
samples = predictive(test_x)
mean_pred = samples["obs"].mean(0)
std_pred = samples["obs"].std(0)
```