or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

core-programming.mddistributions.mdgaussian-processes.mdindex.mdinference.mdneural-networks.mdoptimization.mdtransforms-constraints.md

neural-networks.mddocs/

0

# Neural Networks Integration

1

2

Deep probabilistic models combining neural networks with probabilistic programming, enabling Bayesian neural networks, stochastic layers, and seamless integration between PyTorch modules and Pyro's probabilistic primitives.

3

4

## Capabilities

5

6

### Pyro Module System

7

8

Base classes and descriptors for creating probabilistic neural network modules that integrate seamlessly with Pyro's effect system.

9

10

```python { .api }

11

class PyroModule(torch.nn.Module):
    """
    ``torch.nn.Module`` subclass that understands Pyro parameters and samples.

    Attributes assigned as ``PyroParam`` or ``PyroSample`` objects are
    intercepted (see ``__setattr__``). This lets a network mix ordinary
    PyTorch layers with learnable Pyro parameters and random weights, and be
    handed directly to Pyro's inference algorithms.

    An existing PyTorch layer class can be wrapped with subscript syntax,
    e.g. ``PyroModule[torch.nn.Linear](10, 1)``.

    Examples:
        >>> class BayesianLinear(PyroModule):
        ...     def __init__(self, in_features, out_features):
        ...         super().__init__()
        ...         self.in_features = in_features
        ...         self.out_features = out_features
        ...
        ...         # Random weight matrix with a standard-normal prior
        ...         self.weight = PyroSample(
        ...             dist.Normal(0, 1).expand([out_features, in_features]).to_event(2)
        ...         )
        ...
        ...         # Point-estimated (learnable) bias
        ...         self.bias = PyroParam(torch.zeros(out_features))
        ...
        ...     def forward(self, x):
        ...         return torch.nn.functional.linear(x, self.weight, self.bias)
    """

    def __setattr__(self, name: str, value):
        """Intercept attribute assignment so PyroParam / PyroSample values get special handling."""

    def named_pyro_params(self, prefix: str = '', recurse: bool = True):
        """
        Yield the Pyro parameters held by this module.

        Parameters:
        - prefix (str): String prepended to every yielded name
        - recurse (bool): If True, also visit the parameters of submodules

        Yields:
            Tuple[str, torch.Tensor]: ``(name, parameter)`` pairs
        """

52

53

class PyroParam:
    """
    Marker for a learnable Pyro parameter living inside a ``PyroModule``.

    Assigning one to a ``PyroModule`` attribute registers the value with
    Pyro's parameter store. An optional constraint restricts the values the
    parameter may take.
    """

    def __init__(self, init_tensor, constraint=dist.constraints.real, event_dim=None):
        """
        Parameters:
        - init_tensor (Tensor): Starting value of the parameter
        - constraint (Constraint): Support the parameter is restricted to,
          e.g. ``positive`` or ``simplex``; the default ``real`` means
          unconstrained
        - event_dim (int, optional): How many rightmost dimensions form one event

        Examples:
            >>> # No constraint
            >>> self.mu = PyroParam(torch.tensor(0.0))
            >>>
            >>> # Strictly positive (e.g. a scale)
            >>> self.sigma = PyroParam(torch.tensor(1.0), constraint=dist.constraints.positive)
            >>>
            >>> # Probability vector on the simplex
            >>> self.probs = PyroParam(torch.ones(5), constraint=dist.constraints.simplex)
        """

    def __get__(self, obj, obj_type=None) -> torch.Tensor:
        """Read the parameter's current value from the Pyro parameter store."""

    def __set__(self, obj, value):
        """Write a new value for the parameter into the Pyro parameter store."""

84

85

class PyroSample:
    """
    Marker for a random (sampled) attribute living inside a ``PyroModule``.

    When the attribute is read during model execution, a value is drawn
    from the given prior, so the attribute behaves as a latent variable.
    """

    def __init__(self, prior):
        """
        Parameters:
        - prior (Distribution or callable): Either a distribution, or a
          callable that receives the owning module instance as its single
          argument and returns a distribution. The callable form lets the
          prior depend on other attributes of the module.

        Examples:
            >>> # Fixed prior distribution
            >>> self.weight = PyroSample(dist.Normal(0, 1))
            >>>
            >>> # Prior that depends on other module attributes. The callable
            >>> # is passed the module instance, so accept it as an argument
            >>> # rather than closing over ``self``.
            >>> self.weight = PyroSample(
            ...     lambda self: dist.Normal(self.weight_loc, self.weight_scale)
            ... )
            >>>
            >>> # Matrix-valued latent variable
            >>> self.W = PyroSample(dist.Normal(0, 1).expand([10, 5]).to_event(2))
        """

    def __get__(self, obj, obj_type=None) -> torch.Tensor:
        """Draw a value from the prior distribution."""

112

113

def pyro_method(fn):
    """
    Wrap a ``PyroModule`` method so that it runs with Pyro-aware name scoping.

    Sample statements executed inside the decorated method are named
    consistently with the owning module's parameter namespace.

    Parameters:
    - fn (callable): The method to wrap

    Returns:
        callable: The wrapped, Pyro-aware method

    Examples:
        >>> class MyModule(PyroModule):
        ...     @pyro_method
        ...     def model(self, x):
        ...         z = pyro.sample("z", dist.Normal(0, 1))
        ...         return self.forward(x, z)
    """

133

```

134

135

### Neural Network Architectures

136

137

Specialized neural network architectures for probabilistic modeling and normalizing flows.

138

139

```python { .api }

140

class DenseNN(PyroModule):
    """
    Fully connected feed-forward network with a configurable layer stack.

    Typical uses: conditioner networks in normalizing flows, encoders and
    decoders in variational autoencoders, and generic function approximators
    inside probabilistic models.
    """

    def __init__(self, input_dim: int, hidden_dims: List[int], output_dim: int,
                 nonlinearity: torch.nn.Module = torch.nn.ReLU(),
                 residual_connections: bool = False, batch_norm: bool = False,
                 dropout_prob: float = 0.0):
        """
        Parameters:
        - input_dim (int): Size of the last axis of the input
        - hidden_dims (List[int]): Width of each hidden layer, in order
        - output_dim (int): Size of the last axis of the output
        - nonlinearity (Module): Activation applied between layers
        - residual_connections (bool): Add residual (skip) connections if True
        - batch_norm (bool): Insert batch normalization if True
        - dropout_prob (float): Dropout probability; 0 disables dropout

        Examples:
            >>> # Two hidden layers, scalar output
            >>> net = DenseNN(10, [64, 32], 1)
            >>>
            >>> # Deeper network regularized with batch norm and dropout
            >>> net = DenseNN(20, [128, 64, 32], 5,
            ...               batch_norm=True, dropout_prob=0.1)
        """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the input through every layer of the network.

        Parameters:
        - x (Tensor): Input of shape ``(..., input_dim)``

        Returns:
            Tensor: Output of shape ``(..., output_dim)``
        """

181

182

class ConditionalDenseNN(PyroModule):
    """
    Fully connected network that also consumes a context (conditioning) input.

    Intended for conditional normalizing flows and other function
    approximators whose output must depend on side information.
    """

    def __init__(self, input_dim: int, context_dim: int, hidden_dims: List[int],
                 output_dim: int, nonlinearity: torch.nn.Module = torch.nn.ReLU(),
                 residual_connections: bool = False):
        """
        Parameters:
        - input_dim (int): Size of the last axis of the primary input
        - context_dim (int): Size of the last axis of the context input
        - hidden_dims (List[int]): Width of each hidden layer, in order
        - output_dim (int): Size of the last axis of the output
        - nonlinearity (Module): Activation applied between layers
        - residual_connections (bool): Add residual (skip) connections if True

        Examples:
            >>> cond_net = ConditionalDenseNN(10, 5, [64, 32], 2)
            >>> x = torch.randn(8, 10)        # primary input
            >>> context = torch.randn(8, 5)   # conditioning input
            >>> output = cond_net(x, context)
        """

    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        """
        Evaluate the network on an input together with its context.

        Parameters:
        - x (Tensor): Primary input of shape ``(..., input_dim)``
        - context (Tensor): Context input of shape ``(..., context_dim)``

        Returns:
            Tensor: Output of shape ``(..., output_dim)``
        """

219

220

class AutoRegressiveNN(PyroModule):
    """
    Masked autoregressive network (MADE: Masked Autoencoder for Distribution Estimation).

    Connections are masked so that each output depends only on the inputs
    that precede it in the autoregressive ordering. This makes the network
    suitable for autoregressive density models and normalizing flows.
    """

    def __init__(self, input_dim: int, hidden_dims: List[int], output_dim_multiplier: int = 1,
                 nonlinearity: torch.nn.Module = torch.nn.ReLU(), residual_connections: bool = False,
                 random_mask: bool = False, activation: torch.nn.Module = None):
        """
        Parameters:
        - input_dim (int): Size of the last axis of the input
        - hidden_dims (List[int]): Width of each hidden layer, in order
        - output_dim_multiplier (int): Number of outputs produced per input
          dimension
        - nonlinearity (Module): Activation used in the hidden layers
        - residual_connections (bool): Add residual (skip) connections if True
        - random_mask (bool): Use a random autoregressive ordering if True
        - activation (Module): Activation applied to the final layer

        Examples:
            >>> # Two outputs per dimension of a 10-dimensional input
            >>> ar_net = AutoRegressiveNN(10, [64, 64], output_dim_multiplier=2)
            >>> # so the output has shape (..., 20)
        """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Evaluate the network while respecting the autoregressive ordering.

        Parameters:
        - x (Tensor): Input of shape ``(..., input_dim)``

        Returns:
            Tensor: Output in which each entry depends only on preceding inputs
        """

257

258

class ConditionalAutoRegressiveNN(AutoRegressiveNN):
    """
    Autoregressive (masked) network that additionally takes a context input.

    Combines MADE-style masking with conditioning, for autoregressive models
    whose conditionals depend on side information.
    """

    def __init__(self, input_dim: int, context_dim: int, hidden_dims: List[int],
                 output_dim_multiplier: int = 1, nonlinearity: torch.nn.Module = torch.nn.ReLU()):
        """
        Parameters:
        - input_dim (int): Size of the last axis of the primary input
        - context_dim (int): Size of the last axis of the context input
        - hidden_dims (List[int]): Width of each hidden layer, in order
        - output_dim_multiplier (int): Number of outputs per input dimension
        - nonlinearity (Module): Activation used in the hidden layers
        """

    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        """Evaluate on ``x`` conditioned on ``context``, preserving the autoregressive ordering."""

279

280

class MaskedLinear(torch.nn.Module):
    """
    Linear layer whose weight matrix is multiplied elementwise by a binary mask.

    Building block of autoregressive networks: zeroing selected weights
    removes exactly the connections that would violate the autoregressive
    ordering.
    """

    def __init__(self, in_features: int, out_features: int, mask: torch.Tensor = None,
                 bias: bool = True):
        """
        Parameters:
        - in_features (int): Input feature dimension
        - out_features (int): Output feature dimension
        - mask (Tensor, optional): Binary matrix shaped like the weight,
          ``(out_features, in_features)``; 1 keeps a connection, 0 removes it
        - bias (bool): Whether to include a bias parameter

        Examples:
            >>> # Strictly lower-triangular mask: output i only sees inputs j < i.
            >>> # (``torch.tril`` with its default ``diagonal=0`` would also keep
            >>> # the diagonal, letting output i see input i, which breaks the
            >>> # autoregressive property.)
            >>> mask = torch.tril(torch.ones(5, 5), diagonal=-1)
            >>> masked_layer = MaskedLinear(5, 5, mask)
        """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear map using the masked weight matrix."""

305

```

306

307

### Bayesian Neural Networks

308

309

Tools for creating and working with Bayesian neural networks where weights and biases are treated as random variables.

310

311

```python { .api }

312

def lift_module(nn_module: torch.nn.Module, prior: callable, guide: callable = None):
    """
    Lift a PyTorch module to a Bayesian neural network.

    Turns the module's deterministic parameters into random variables whose
    prior distributions are supplied by ``prior``.

    Parameters:
    - nn_module (Module): PyTorch module to convert
    - prior (callable): Called as ``prior(name, shape)`` for each parameter;
      returns that parameter's prior distribution
    - guide (callable, optional): Function that returns guide distributions

    Returns:
        PyroModule: Bayesian version of the input module

    Examples:
        >>> # Define deterministic network
        >>> net = torch.nn.Linear(10, 1)
        >>>
        >>> # Define priors: one independent normal per parameter tensor
        >>> def prior(name, shape):
        ...     return dist.Normal(0, 1).expand(shape).to_event(len(shape))
        >>>
        >>> # Create Bayesian network
        >>> bnn = lift_module(net, prior)
        >>>
        >>> # Use the lifted network in a probabilistic model. Call ``bnn``
        >>> # directly rather than re-lifting ``net`` inside the model.
        >>> def model(x, y):
        ...     prediction = bnn(x).squeeze(-1)
        ...     with pyro.plate("data", x.shape[0]):
        ...         pyro.sample("obs", dist.Normal(prediction, 0.1), obs=y)
    """

344

345

def sample_module_outputs(model: PyroModule, input_data: torch.Tensor,
                          num_samples: int = 100) -> torch.Tensor:
    """
    Run a Bayesian neural network repeatedly to collect an ensemble of outputs.

    Each run uses a different draw of the network's random parameters, so
    the spread of the stacked outputs reflects the model's uncertainty.

    Parameters:
    - model (PyroModule): The Bayesian neural network
    - input_data (Tensor): Inputs to feed to every run
    - num_samples (int): How many posterior draws (runs) to make

    Returns:
        Tensor: Stacked outputs of shape ``(num_samples, batch_size, output_dim)``

    Examples:
        >>> outputs = sample_module_outputs(bnn, test_data, num_samples=50)
        >>> mean_prediction = outputs.mean(dim=0)
        >>> uncertainty = outputs.std(dim=0)
    """

363

364

class BayesianModule(PyroModule):
    """
    Base class for hand-written Bayesian neural-network layers.

    Offers hooks for drawing parameter samples and for estimating predictive
    uncertainty by running the layer under several samples.
    """

    def __init__(self, name: str):
        """
        Parameters:
        - name (str): Name used to scope this module's Pyro parameters
        """
        super(BayesianModule, self).__init__()
        # Stored on the instance; used as the scope name for Pyro sites.
        self._pyro_name = name

    def sample_parameters(self):
        """Draw a fresh set of parameters from their prior/posterior distributions."""

    def forward_with_samples(self, x: torch.Tensor, num_samples: int = 1) -> torch.Tensor:
        """
        Evaluate the layer under several parameter draws to expose uncertainty.

        Parameters:
        - x (Tensor): Input data
        - num_samples (int): How many parameter draws to use

        Returns:
            Tensor: One output per draw, capturing the predictive uncertainty
        """

394

```

395

396

### Variational Layers

397

398

Specialized layers for variational inference and amortized inference in deep generative models.

399

400

```python { .api }

401

class VariationalLinear(PyroModule):
    """
    Linear layer with a learnable variational distribution (mean and variance) over its weights.

    Uses the local reparameterization trick, i.e. it samples the layer's
    pre-activations instead of the individual weights, to make variational
    inference in neural networks efficient.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 prior_scale: float = 1.0):
        """
        Parameters:
        - in_features (int): Input feature dimension
        - out_features (int): Output feature dimension
        - bias (bool): Include a bias term if True
        - prior_scale (float): Scale of the prior placed on the weights

        Examples:
            >>> var_layer = VariationalLinear(10, 5, prior_scale=0.1)
        """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the layer output via the local reparameterization trick."""

424

425

class AmortizedLDA(PyroModule):
    """
    Latent Dirichlet Allocation fitted with an amortized neural guide.

    Rather than keeping separate variational parameters for every document,
    a neural network predicts them from the document itself (neural
    variational inference for topic models).
    """

    def __init__(self, vocab_size: int, num_topics: int, hidden_dim: int = 100,
                 dropout: float = 0.2):
        """
        Parameters:
        - vocab_size (int): Number of distinct words in the vocabulary
        - num_topics (int): Number of latent topics
        - hidden_dim (int): Hidden-layer width of the encoder network
        - dropout (float): Dropout probability
        """

    def model(self, docs: torch.Tensor, doc_lengths: torch.Tensor):
        """Generative LDA model over the documents."""

    def guide(self, docs: torch.Tensor, doc_lengths: torch.Tensor):
        """Amortized variational guide whose parameters come from the encoder network."""

448

```

449

450

### Integration Utilities

451

452

Functions for seamless integration between PyTorch modules and Pyro probabilistic programs.

453

454

```python { .api }

455

def to_pyro_module_(nn_module: torch.nn.Module, prior: callable = None) -> PyroModule:
    """
    Convert a PyTorch module to a PyroModule in place.

    As the trailing underscore indicates (PyTorch's convention for in-place
    operations), ``nn_module`` itself is mutated, so keep using the original
    reference after the call.

    Parameters:
    - nn_module (Module): PyTorch module to convert
    - prior (callable, optional): Prior distribution generator for parameters

    Returns:
        PyroModule: The converted module, i.e. the same object as ``nn_module``.
        NOTE(review): upstream ``pyro.nn.module.to_pyro_module_`` returns
        ``None``. Relying on the in-place mutation, as in the example,
        works either way.

    Examples:
        >>> net = torch.nn.Linear(10, 1)
        >>> to_pyro_module_(net)
        >>> isinstance(net, PyroModule)   # ``net`` itself was converted
        True
    """

470

471

def clear_module_hooks(module: torch.nn.Module):
    """
    Strip every Pyro-related hook that has been attached to a PyTorch module.

    Parameters:
    - module (Module): The module whose hooks should be removed
    """

478

479

def module_prior(module_name: str, module: torch.nn.Module,
                 prior_fn: callable) -> torch.nn.Module:
    """
    Attach a prior distribution to every parameter of a PyTorch module.

    Parameters:
    - module_name (str): Prefix used when naming the Pyro sample sites
    - module (Module): The PyTorch module to make stochastic
    - prior_fn (callable): Called as ``prior_fn(name, param)``; returns the
      prior distribution for that parameter

    Returns:
        Module: The module with its parameters turned into random variables

    Examples:
        >>> def weight_prior(name, param):
        ...     return dist.Normal(0, 1).expand(param.shape).to_event(param.dim())
        >>>
        >>> net = torch.nn.Linear(10, 1)
        >>> stochastic_net = module_prior("net", net, weight_prior)
    """

499

500

class PyroModuleList(torch.nn.ModuleList, PyroModule):
    """
    ``torch.nn.ModuleList`` that is also a ``PyroModule``.

    Use it to hold a sequence of Pyro modules so that every element takes
    part in Pyro's parameter management and effect handling.

    Examples:
        >>> layers = PyroModuleList([
        ...     BayesianLinear(10, 20),
        ...     BayesianLinear(20, 1),
        ... ])
    """

    def __init__(self, modules=None):
        """
        Parameters:
        - modules (iterable, optional): Modules to place in the list initially
        """

519

```

520

521

## Examples

522

523

### Simple Bayesian Neural Network

524

525

```python

526

import pyro
import pyro.distributions as dist
import torch
import torch.nn.functional as F
from pyro.nn import PyroModule, PyroSample, PyroParam

530

531

class BayesianLinear(PyroModule):
    """Fully connected layer whose weight and bias are both latent random variables."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        # Independent standard-normal priors. ``to_event`` makes each whole
        # tensor a single joint sample site rather than a batch of scalars.
        weight_prior = dist.Normal(0., 1.).expand([out_features, in_features]).to_event(2)
        bias_prior = dist.Normal(0., 1.).expand([out_features]).to_event(1)
        self.weight = PyroSample(weight_prior)
        self.bias = PyroSample(bias_prior)

    def forward(self, x):
        """Affine map of ``x`` using freshly sampled weight and bias."""
        return F.linear(x, self.weight, self.bias)

547

548

# Usage in a model

549

def model(x, y):
    """
    Bayesian linear-regression model built on ``BayesianLinear``.

    Parameters:
    - x (Tensor): Inputs, one row per datum; the layer maps 3 features to 1 output
    - y (Tensor): Observed targets, one per row of ``x``
    """
    fc = BayesianLinear(3, 1)

    # Forward pass. Drop only the trailing output-feature axis; a bare
    # ``squeeze()`` would also drop the batch axis when there is one datum.
    mean = fc(x).squeeze(-1)

    # Likelihood: rows of x are conditionally independent given the weights
    with pyro.plate("data", len(x)):
        pyro.sample("obs", dist.Normal(mean, 0.1), obs=y)

558

559

# An empty guide would leave the latent "weight" and "bias" sites of
# BayesianLinear unguided, so SVI would have nothing to learn (Pyro warns about
# model sample sites missing from the guide when validation is enabled).
# AutoNormal builds a mean-field normal posterior for every latent site of
# model(), and is called exactly like the model: guide(x, y).
from pyro.infer.autoguide import AutoNormal

guide = AutoNormal(model)

562

```

563

564

### Variational Autoencoder

565

566

```python

567

class VAE(PyroModule):
    """
    Variational autoencoder with a Gaussian latent code and a Bernoulli likelihood.

    ``model`` defines p(z) p(x | z) using the decoder layers. ``guide``
    defines the amortized posterior q(z | x) using the encoder layers.
    """

    def __init__(self, input_dim=784, hidden_dim=400, z_dim=20):
        super().__init__()
        # model() reads self.z_dim to build the prior, so it must be stored.
        self.z_dim = z_dim

        # Encoder: x -> (z_mu, z_sigma)
        self.encoder_fc1 = torch.nn.Linear(input_dim, hidden_dim)
        self.encoder_mu = torch.nn.Linear(hidden_dim, z_dim)
        self.encoder_sigma = torch.nn.Linear(hidden_dim, z_dim)

        # Decoder: z -> per-feature Bernoulli probabilities
        self.decoder_fc1 = torch.nn.Linear(z_dim, hidden_dim)
        self.decoder_fc2 = torch.nn.Linear(hidden_dim, input_dim)

    def model(self, x):
        """Generative model p(z) p(x | z) for a batch ``x`` of shape (batch, input_dim)."""
        # Register all network parameters with Pyro. pyro.module registers
        # every parameter of ``self``, so model and guide share one name.
        # Different names would store each tensor twice, and the encoder
        # weights would appear under "decoder".
        pyro.module("vae", self)

        batch_size = x.shape[0]

        with pyro.plate("data", batch_size):
            # Standard-normal prior over the latent code (same dtype/device as x)
            z_loc = x.new_zeros((batch_size, self.z_dim))
            z_scale = x.new_ones((batch_size, self.z_dim))
            z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))

            # Decode the latent code to per-feature probabilities
            hidden = F.relu(self.decoder_fc1(z))
            mu_img = torch.sigmoid(self.decoder_fc2(hidden))

            # Likelihood, inside the plate so each datum is scored independently.
            # NOTE(review): assumes x is binary-valued; Bernoulli argument
            # validation may reject fractional pixel intensities. Confirm.
            pyro.sample("obs", dist.Bernoulli(mu_img).to_event(1), obs=x)

    def guide(self, x):
        """Amortized variational posterior q(z | x) computed by the encoder."""
        # Same registration name as in model() (see the comment there).
        pyro.module("vae", self)

        batch_size = x.shape[0]

        # Encode x into the mean and (positive, via softplus) scale of q(z | x)
        hidden = F.relu(self.encoder_fc1(x))
        z_mu = self.encoder_mu(hidden)
        z_sigma = F.softplus(self.encoder_sigma(hidden))

        with pyro.plate("data", batch_size):
            pyro.sample("latent", dist.Normal(z_mu, z_sigma).to_event(1))

613

```

614

615

### Neural Network with Uncertainty

616

617

```python

618

class UncertaintyNet(PyroModule):
    """
    Bayesian linear regression (10 features -> 1 output) with learned observation noise.

    The weight and bias of the linear layer are latent random variables. The
    observation noise ``sigma`` is a learnable positive parameter.
    """

    def __init__(self):
        super().__init__()
        self.linear = PyroModule[torch.nn.Linear](10, 1)

        # Put priors directly on the layer's weight and bias. PyroModule turns
        # them into sample sites each time forward() runs. This replaces the
        # deprecated pyro.random_module "lifting" of the layer. Shapes match
        # torch.nn.Linear(10, 1): weight (1, 10), bias (1,).
        self.linear.weight = PyroSample(
            dist.Normal(0., 1.).expand([1, 10]).to_event(2))
        self.linear.bias = PyroSample(
            dist.Normal(0., 1.).expand([1]).to_event(1))

        # Learnable noise parameter, constrained to be positive
        self.sigma = PyroParam(torch.tensor(1.0),
                               constraint=dist.constraints.positive)

    def forward(self, x, y=None):
        """
        Run the regression model.

        Parameters:
        - x (Tensor): Inputs, one row per datum with 10 features
        - y (Tensor, optional): Observed targets. When None, "obs" is sampled
          rather than observed, which is what ``Predictive`` relies on.

        Returns:
            Tensor: Noise-free predicted mean, one value per row of ``x``
        """
        # Drop only the output-feature axis so a single-row batch keeps its batch axis
        prediction = self.linear(x).squeeze(-1)

        # Always declare the likelihood site. With y=None it is sampled, so
        # posterior-predictive draws of "obs" are available.
        with pyro.plate("data", len(x)):
            pyro.sample("obs", dist.Normal(prediction, self.sigma), obs=y)

        return prediction

642

643

# Usage with uncertainty quantification
from pyro.infer import SVI, Trace_ELBO, Predictive
from pyro.infer.autoguide import AutoNormal
from pyro.optim import Adam

pyro.clear_param_store()
net = UncertaintyNet()

# Toy regression data so the snippet runs end to end; replace with real data.
x_train = torch.randn(100, 10)
y_train = x_train @ torch.randn(10) + 0.1 * torch.randn(100)
test_x = torch.randn(20, 10)

# Training with SVI. The model has latent weight/bias sites, so it needs a
# real guide. An empty guide (``lambda x, y: None``) leaves them unguided,
# so nothing would be learned about the weights.
guide = AutoNormal(net)
svi = SVI(net, guide, Adam({"lr": 0.01}), Trace_ELBO())
for step in range(1000):
    loss = svi.step(x_train, y_train)

# Get predictions with uncertainty. Passing the guide makes Predictive draw
# weights from the fitted posterior instead of the prior. "_RETURN" holds
# forward()'s return value (the noise-free predicted mean), so its spread
# across samples measures uncertainty about the weights.
predictive = Predictive(net, guide=guide, num_samples=100, return_sites=("_RETURN",))
samples = predictive(test_x)
mean_pred = samples["_RETURN"].mean(0)
std_pred = samples["_RETURN"].std(0)

658

```