or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

core-programming.md, distributions.md, gaussian-processes.md, index.md, inference.md, neural-networks.md, optimization.md, transforms-constraints.md

docs/core-programming.md

0

# Core Probabilistic Programming

1

2

Core functions and constructs that form the foundation of Pyro's probabilistic programming language, enabling the creation of probabilistic models through composable primitives.

3

4

## Capabilities

5

6

### Sample Statements

7

8

The fundamental stochastic function for declaring random variables and observed data in probabilistic programs.

9

10

```python { .api }

11

def sample(
    name: str,
    fn: TorchDistributionMixin,
    *args,
    obs: Optional[torch.Tensor] = None,
    obs_mask: Optional[torch.BoolTensor] = None,
    infer: Optional[InferDict] = None,
    **kwargs
) -> torch.Tensor:
    """
    Primitive stochastic function for probabilistic programming.

    This is the core function for creating sample sites in probabilistic programs.
    It can be used to declare latent variables, observed data, and guide samples.

    Parameters:
    - name (str): Unique name for the sample site within the current context
    - fn (Distribution): Probability distribution to sample from
    - obs (Tensor, optional): Observed data to condition on. When provided,
      this becomes a conditioning site rather than a sampling site
    - obs_mask (BoolTensor, optional): Boolean mask, broadcastable with
      ``fn.batch_shape``, for partially observed data. Entries where the mask
      is True are conditioned on ``obs``. The remaining entries are imputed by
      sampling, which introduces an extra latent site named
      ``name + "_unobserved"`` that a guide must account for
    - infer (dict, optional): Inference configuration dictionary containing
      instructions for inference algorithms (e.g. ``{"enumerate": "parallel"}``)

    Returns:
    Tensor: Sample from the distribution, or the observed value if obs is
    provided (with obs_mask, observed entries come from obs and the rest
    are sampled)

    Examples:
    >>> # Latent variable
    >>> z = pyro.sample("z", dist.Normal(0, 1))
    >>>
    >>> # Observed data
    >>> pyro.sample("obs", dist.Normal(mu, sigma), obs=data)
    >>>
    >>> # Partially observed data (mask is True where data is present)
    >>> pyro.sample("obs", dist.Normal(mu, sigma), obs=data, obs_mask=present)
    >>>
    >>> # With inference configuration
    >>> pyro.sample("x", dist.Normal(0, 1), infer={"is_auxiliary": True})
    """

49

```

50

51

### Parameter Management

52

53

Functions for declaring and managing learnable parameters that persist across calls to the model.

54

55

```python { .api }

56

def param(
    name: str,
    init_tensor: Union[torch.Tensor, Callable[[], torch.Tensor], None] = None,
    constraint: constraints.Constraint = constraints.real,
    event_dim: Optional[int] = None,
) -> torch.Tensor:
    """
    Declare and retrieve learnable parameters from the global parameter store.

    Parameters persist across model calls and are automatically tracked for
    gradient-based optimization.

    Parameters:
    - name (str): Key under which the parameter is kept in the global
      parameter store. Calling param again with the same name returns the
      existing parameter rather than creating a new one
    - init_tensor (Tensor or callable, optional): Initial parameter value, or
      a zero-argument callable returning it. The callable form avoids building
      a large tensor when the parameter already exists. Only used the first
      time ``name`` is seen; if None, the parameter must already exist
    - constraint (Constraint): Constraint on parameter values, defaults to
      unconstrained real numbers
    - event_dim (int, optional): Number of rightmost dimensions that are
      part of the event shape

    Returns:
    Tensor: The (constrained) parameter value; gradients flow to the
    underlying unconstrained tensor held in the parameter store

    Examples:
    >>> # Scalar parameter
    >>> mu = pyro.param("mu", torch.tensor(0.0))
    >>>
    >>> # Vector parameter with constraint
    >>> theta = pyro.param("theta", torch.ones(5), constraint=constraints.positive)
    >>>
    >>> # Matrix parameter, lazily initialized
    >>> W = pyro.param("W", lambda: torch.randn(10, 5))
    """

90

91

def clear_param_store():
    """
    Remove every parameter from the global parameter store.

    Handy for wiping state between separate model runs or experiments.
    """

97

98

def get_param_store():
    """
    Get the global parameter store instance.

    Returns:
    ParamStoreDict: The global, dictionary-like store mapping parameter
    names to their values
    """

105

```

106

107

### Independence Declarations

108

109

Context managers for declaring conditional independence and enabling efficient vectorized computation.

110

111

```python { .api }

112

class plate(PlateMessenger):
    def __init__(
        self,
        name: str,
        size: Optional[int] = None,
        subsample_size: Optional[int] = None,
        subsample: Optional[torch.Tensor] = None,
        dim: Optional[int] = None,
        use_cuda: Optional[bool] = None,
        device: Optional[str] = None,
    ) -> None:
        """
        Context manager for declaring conditional independence assumptions.

        Plates enable vectorized computation and minibatch training by declaring
        that samples within the plate are conditionally independent.

        Parameters:
        - name (str): Unique name for the plate
        - size (int, optional): Size of the full independent dimension, even
          when only a subsample is used
        - subsample_size (int, optional): Size of the minibatch subsample. If
          provided, enables minibatch training; log probabilities inside the
          plate are scaled by size / subsample_size
        - subsample (Tensor, optional): Explicit subsample indices for
          user-defined subsampling; subsample_size is then len(subsample)
        - dim (int, optional): Negative tensor dimension (counted from the
          right) to use for broadcasting. If None, uses the rightmost dimension
          to the left of all enclosing plates
        - use_cuda (bool, optional): Deprecated; use ``device`` instead
        - device (str, optional): Device on which to place the subsample indices

        Entering the plate yields the (possibly subsampled) index tensor, which
        can be bound with ``with pyro.plate(...) as ind:``.

        Examples:
        >>> # Basic independence
        >>> with pyro.plate("data", 100):
        ...     pyro.sample("obs", dist.Normal(mu, sigma), obs=data)
        >>>
        >>> # Minibatch training: index the full data with the yielded indices
        >>> with pyro.plate("data", 10000, subsample_size=32) as ind:
        ...     pyro.sample("obs", dist.Normal(mu, sigma), obs=data[ind])
        >>>
        >>> # Nested plates
        >>> with pyro.plate("batch", N):
        ...     with pyro.plate("features", D):
        ...         pyro.sample("z", dist.Normal(0, 1))
        """

154

155

def plate_stack(prefix: str, sizes: Sequence[int], rightmost_dim: int = -1) -> Iterator[None]:
    """
    Create a stack of nested plates for multi-dimensional independence.

    Parameters:
    - prefix (str): Name prefix for the generated plates
    - sizes (Sequence[int]): One plate size per batch dimension. The last
      size is placed at ``rightmost_dim`` and earlier sizes at successively
      more negative dims
    - rightmost_dim (int): Negative index of the rightmost tensor dimension
      used by the stack (defaults to -1)

    Returns:
    ContextManager: Context manager that enters all plates at once

    Examples:
    >>> with pyro.plate_stack("plates", [N, D, K]):
    ...     pyro.sample("z", dist.Normal(0, 1))  # batch shape (N, D, K)
    """

171

```

172

173

### Model Composition

174

175

Primitives for adding custom log-probability terms, recording derived (deterministic) values, and controlling evaluation inside probabilistic programs.

176

177

```python { .api }

178

def factor(name: str, log_factor: torch.Tensor, *, has_rsample: Optional[bool] = None) -> None:
    """
    Add an arbitrary term to the model's joint log probability.

    Use this for likelihood or penalty terms that are not expressed
    through a standard distribution.

    Parameters:
    - name (str): Site name for the factor
    - log_factor (torch.Tensor): Value added to the model's joint log
      probability
    - has_rsample (bool, optional): Whether the factor came from a fully
      reparametrized distribution; needed when the factor is used in a guide

    Examples:
    >>> # Hand-written Gaussian log-likelihood (up to an additive constant)
    >>> log_likelihood = -0.5 * torch.sum((data - mu) ** 2) / sigma ** 2
    >>> pyro.factor("custom_likelihood", log_likelihood)
    >>>
    >>> # L2 penalty on parameters
    >>> penalty = -0.01 * torch.sum(params ** 2)
    >>> pyro.factor("l2_penalty", penalty)
    """

206

207

def deterministic(
    name: str,
    value: torch.Tensor,
    event_dim: Optional[int] = None,
) -> torch.Tensor:
    """
    Create a deterministic sample site for tracking intermediate computations.

    The value is recorded as a site in the trace, which is useful for values
    completely determined by their parents.

    Parameters:
    - name (str): Name for the deterministic site
    - value (Tensor): Deterministic value to record
    - event_dim (int, optional): Number of rightmost event dimensions;
      defaults to value.ndim (the whole tensor is treated as one event)

    Returns:
    Tensor: The input value (pass-through)

    Examples:
    >>> z = pyro.sample("z", dist.Normal(0, 1))
    >>> z_squared = pyro.deterministic("z_squared", z ** 2)
    """

223

224

def barrier(data: torch.Tensor) -> torch.Tensor:
    """
    EXPERIMENTAL: Ensure all values in ``data`` are ground (concrete) values
    rather than lazy funsor values.

    Mainly useful together with ``pyro.poutine.collapse``, where values may
    otherwise be represented lazily.

    Parameters:
    - data (Tensor): Value to make ground

    Returns:
    Tensor: ``data`` with any lazy values evaluated
    """

233

```

234

235

### PyTorch Module Integration

236

237

Functions for integrating PyTorch modules into probabilistic programs.

238

239

```python { .api }

240

def module(name: str, nn_module, update_module_params: bool = False):
    """
    Integrate a PyTorch module into a probabilistic program.

    Registers the module's parameters with Pyro's parameter store so they
    are optimized alongside other Pyro parameters.

    Parameters:
    - name (str): Name for the module
    - nn_module (torch.nn.Module): PyTorch module to integrate
    - update_module_params (bool): If True, the module's parameters are
      overwritten with any values already stored in the parameter store
      under the corresponding names. Defaults to False

    Returns:
    torch.nn.Module: The input module

    Examples:
    >>> neural_net = torch.nn.Linear(10, 1)
    >>> net = pyro.module("neural_net", neural_net)
    >>> output = net(input_tensor)
    """

258

259

def random_module(name: str, nn_module, prior, *args, **kwargs):
    """
    Place priors over the parameters of a PyTorch module.

    Deprecated: Pyro emits a deprecation warning for this primitive and
    recommends ``pyro.nn.PyroModule`` for Bayesian neural networks instead.

    Parameters:
    - name (str): Name for the random module
    - nn_module (torch.nn.Module): PyTorch module template
    - prior: A distribution, a stochastic function, or a dict mapping
      parameter names (e.g. "weight", "bias") to distributions or stochastic
      functions. Sampled values must match the shapes of the corresponding
      module parameters

    Returns:
    callable: A function that, each time it is called, samples parameters
    from the prior and returns a ``torch.nn.Module`` that uses them

    Examples:
    >>> template = torch.nn.Linear(10, 1)  # weight (1, 10), bias (1,)
    >>> prior = {
    ...     "weight": dist.Normal(torch.zeros(1, 10), 1.0).to_event(2),
    ...     "bias": dist.Normal(torch.zeros(1), 1.0).to_event(1),
    ... }
    >>> lifted = pyro.random_module("bnn", template, prior)
    >>> sampled_nn = lifted()  # one network drawn from the prior
    """

278

```

279

280

### Subsampling and Utilities

281

282

Utilities for data subsampling and model visualization.

283

284

```python { .api }

285

def subsample(data: torch.Tensor, event_dim: int) -> torch.Tensor:
    """
    Subsample ``data`` according to the enclosing subsampling plate(s).

    Parameters:
    - data (Tensor): Full data tensor to subsample
    - event_dim (int): How many rightmost dimensions of ``data`` are event
      dimensions; dimensions to their left are batch dimensions

    Returns:
    Tensor: The subsampled data when called inside a subsampling plate
    """

296

297

def render_model(model, *args, **kwargs):
    """
    Render a graphical representation of the probabilistic model.

    Parameters:
    - model (callable): Model function to visualize
    - *args, **kwargs: Rendering options. The model's own inputs are NOT
      splatted here: pass them as ``model_args`` (tuple) and ``model_kwargs``
      (dict). Other options include ``filename`` (str, where to save the
      rendered graph) and ``render_distributions`` (bool, label sites with
      their distributions)

    Returns:
    graphviz.Digraph: Graph of the model's sample sites and plates
    (requires the ``graphviz`` package)

    Examples:
    >>> graph = pyro.render_model(model, model_args=(data,), render_distributions=True)
    """

308

```

309

310

### Global State Management

311

312

Functions for managing global Pyro state and settings.

313

314

```python { .api }

315

def get_param_store() -> ParamStoreDict:
    """
    Return the global store that holds every Pyro parameter.

    Returns:
    ParamStoreDict: The global, dictionary-like parameter store

    Examples:
    >>> store = pyro.get_param_store()
    >>> print(list(store.keys()))  # names of all parameters
    """

326

327

def clear_param_store() -> None:
    """
    Remove every parameter from the global parameter store.

    Call this to start from a clean slate between experiments or tests.

    Examples:
    >>> pyro.clear_param_store()  # forget all stored parameters
    """

336

337

def enable_validation(is_validate: bool = True):
    """
    Turn Pyro's runtime checks of distributions and shapes on or off.

    Parameters:
    - is_validate (bool): True to enable validation, False to disable it

    Examples:
    >>> pyro.enable_validation(True)   # extra checks while debugging
    >>> pyro.enable_validation(False)  # skip checks for speed
    """

348

349

def validation_enabled(is_validate: bool = True) -> Iterator[None]:
    """
    Context manager that temporarily enables or disables validation checks.

    This does not report the current setting. The requested setting applies
    only inside the ``with`` block, and the previous setting is restored on
    exit.

    Parameters:
    - is_validate (bool): Validation setting to use inside the ``with``
      block (defaults to True)

    Examples:
    >>> with pyro.validation_enabled(False):
    ...     run_inference()  # validation disabled only inside this block
    """

356

357

def set_rng_seed(rng_seed: int):
    """
    Seed every random number generator Pyro relies on, for reproducibility.

    The Python ``random`` module, NumPy and PyTorch generators all receive
    the same seed.

    Parameters:
    - rng_seed (int): Seed shared by all generators

    Examples:
    >>> pyro.set_rng_seed(42)  # reproducible experiment
    """

369

```

370

371

## Examples

372

373

### Basic Model Definition

374

375

```python

376

import pyro

377

import pyro.distributions as dist

378

import torch

379

380

def coin_flip_model(data):
    """Bernoulli coin-flip model with a uniform Beta(1, 1) prior on the bias."""
    # Uniform prior over the probability of heads.
    uniform_prior = dist.Beta(1.0, 1.0)
    bias = pyro.sample("bias", uniform_prior)

    # Flips are conditionally independent given the bias.
    num_flips = len(data)
    with pyro.plate("data", num_flips):
        pyro.sample("obs", dist.Bernoulli(bias), obs=data)

388

389

# Usage

390

data = torch.tensor([1., 0., 1., 1., 0.])  # three heads, two tails
coin_flip_model(data=data)

392

```

393

394

### Hierarchical Model

395

396

```python

397

def hierarchical_model(group_data):
    """Two-level model: per-group intercepts drawn around shared hyperpriors."""
    # Shared (global) hyperparameters.
    mu_alpha = pyro.sample("mu_alpha", dist.Normal(0, 10))
    sigma_alpha = pyro.sample("sigma_alpha", dist.HalfNormal(5))

    # One intercept per group, conditionally independent given the hyperpriors.
    num_groups = len(group_data)
    with pyro.plate("groups", num_groups):
        alpha = pyro.sample("alpha", dist.Normal(mu_alpha, sigma_alpha))

    # A separate plate per group lets groups have different numbers of points.
    for idx, observations in enumerate(group_data):
        group_mean = alpha[idx]
        with pyro.plate(f"group_{idx}_data", len(observations)):
            pyro.sample(f"obs_{idx}", dist.Normal(group_mean, 1), obs=observations)

411

```

412

413

### Minibatch Training

414

415

```python

416

def minibatch_model(data_loader):
    """Score every batch yielded by ``data_loader`` in a single model run.

    Pyro requires plate and sample-site names to be unique within one model
    execution. Each batch therefore gets its own ``data_{i}`` plate and
    ``obs_{i}`` site. Reusing "data"/"obs" on every iteration raises a
    duplicate-site error as soon as the model is traced (e.g. by SVI).

    Because every batch is visited, this scores the full dataset and needs no
    subsampling. For stochastic minibatch training, instead run the model on
    ONE batch per step with
    ``pyro.plate("data", full_size, subsample_size=len(batch))`` so that the
    log-likelihood is rescaled by full_size / len(batch).

    Parameters:
    - data_loader (iterable of Tensor): Batches of observations
    """
    # Global parameters
    mu = pyro.param("mu", torch.tensor(0.0))
    sigma = pyro.param("sigma", torch.tensor(1.0), constraint=dist.constraints.positive)

    # Unique names per batch avoid duplicate sites within one trace.
    for i, batch in enumerate(data_loader):
        with pyro.plate(f"data_{i}", len(batch)):
            pyro.sample(f"obs_{i}", dist.Normal(mu, sigma), obs=batch)

426

```