or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

diagnostics.md distributions.md handlers.md index.md inference.md optimization.md primitives.md utilities.md

primitives.md docs/

0

# Primitives

1

2

NumPyro's primitive functions provide the core building blocks for probabilistic models. These functions enable sampling from distributions, defining parameters, handling conditional independence, and marking deterministic computations. All primitives integrate with the effect handler system and support automatic differentiation through JAX.

3

4

## Capabilities

5

6

### Core Sampling Primitives

7

8

The fundamental primitives for probabilistic programming.

9

10

```python { .api }

11

def sample(name: str, fn: Distribution, obs: Optional[ArrayLike] = None,

12

rng_key: Optional[Array] = None, sample_shape: tuple = (),

13

infer: Optional[dict] = None, obs_mask: Optional[ArrayLike] = None) -> ArrayLike:

14

"""

15

Sample a value from a distribution or condition on observed data.

16

17

Args:

18

name: Name of the sample site (must be unique within model)

19

fn: Probability distribution to sample from

20

obs: Observed value to condition on (optional)

21

rng_key: Random key for sampling (optional, auto-generated if None)

22

sample_shape: Shape of samples to draw (for multiple samples)

23

infer: Dictionary of inference hints and configuration

24

obs_mask: Boolean mask for partially observed data

25

26

Returns:

27

Sampled value or observed value (if obs is provided)

28

29

Usage:

30

# Sample from prior

31

x = numpyro.sample("x", dist.Normal(0, 1))

32

33

# Condition on observed data

34

y = numpyro.sample("y", dist.Normal(x, 0.5), obs=observed_y)

35

36

# Sample multiple values

37

batch_samples = numpyro.sample("batch", dist.Normal(0, 1), sample_shape=(10,))

38

39

# Configure inference behavior

40

z = numpyro.sample("z", dist.Normal(0, 1), infer={"is_auxiliary": True})

41

"""

42

43

def param(name: str, init_value: Optional[Union[ArrayLike, Callable]] = None,

44

constraint: Constraint = constraints.real, event_dim: Optional[int] = None,

45

**kwargs) -> Optional[ArrayLike]:

46

"""

47

Declare an optimizable parameter in the model.

48

49

Args:

50

name: Parameter name (must be unique)

51

init_value: Initial value or initialization function

52

constraint: Parameter constraint (e.g., constraints.positive)

53

event_dim: Number of rightmost dimensions treated as event shape

54

**kwargs: Additional arguments (e.g., for initialization functions)

55

56

Returns:

57

Parameter value (None during initial model trace)

58

59

Usage:

60

# Simple parameter with constraint

61

sigma = numpyro.param("sigma", 1.0, constraint=constraints.positive)

62

63

# Parameter with initialization function

64

weights = numpyro.param("weights",

65

lambda key: random.normal(key, (10, 5)),

66

constraint=constraints.real)

67

68

# Simplex-constrained parameter

69

probs = numpyro.param("probs", jnp.ones(3) / 3, constraint=constraints.simplex)

70

"""

71

```

72

73

### Deterministic Sites

74

75

Primitives for marking deterministic computations and adding log probability factors.

76

77

```python { .api }

78

def deterministic(name: str, value: ArrayLike) -> ArrayLike:

79

"""

80

Mark a deterministic computation site for tracking in traces.

81

82

Args:

83

name: Name of the deterministic site

84

value: Computed deterministic value

85

86

Returns:

87

The input value (unchanged)

88

89

Usage:

90

x = numpyro.sample("x", dist.Normal(0, 1))

91

y = numpyro.sample("y", dist.Normal(0, 1))

92

93

# Mark sum as deterministic for tracking

94

sum_xy = numpyro.deterministic("sum", x + y)

95

96

# Can be used for derived quantities

97

mean_xy = numpyro.deterministic("mean", (x + y) / 2)

98

"""

99

100

def factor(name: str, log_factor: ArrayLike) -> None:

101

"""

102

Add a log probability factor to the model's joint density.

103

104

Args:

105

name: Name of the factor site

106

log_factor: Log probability value to add to joint density

107

108

Usage:

109

# Add log-likelihood term directly

110

numpyro.factor("custom_loglik", -0.5 * jnp.sum((y - mu)**2) / sigma**2)

111

112

# Add constraint violation penalty

113

numpyro.factor("penalty", -1e6 * jnp.where(x < 0, 1.0, 0.0))

114

115

# Add custom prior term

116

numpyro.factor("custom_prior", dist.Gamma(2, 1).log_prob(sigma))

117

"""

118

```

119

120

### Conditional Independence

121

122

Primitives for handling conditional independence and subsetting.

123

124

```python { .api }

125

class plate:

126

"""

127

Context manager for conditionally independent variables with automatic broadcasting.

128

129

Args:

130

name: Plate name (must be unique)

131

size: Size of the independence dimension

132

subsample_size: Size of subsample (for subsampling, optional)

133

dim: Dimension for broadcasting (negative, optional)

134

subsample: Indices for subsampling (optional)

135

136

Usage:

137

# Basic conditional independence

138

with numpyro.plate("data", 100):

139

x = numpyro.sample("x", dist.Normal(0, 1)) # Shape: (100,)

140

141

# Subsampling for large datasets

142

with numpyro.plate("data", 10000, subsample_size=100) as idx:

143

# idx contains the subsample indices

144

x = numpyro.sample("x", dist.Normal(0, 1)) # Shape: (100,)

145

146

# Nested plates for multidimensional independence

147

with numpyro.plate("batch", 50, dim=-2):

148

with numpyro.plate("features", 10, dim=-1):

149

weights = numpyro.sample("w", dist.Normal(0, 1)) # Shape: (50, 10)

150

"""

151

def __init__(self, name: str, size: int, subsample_size: Optional[int] = None,

152

dim: Optional[int] = None, subsample: Optional[ArrayLike] = None): ...

153

154

def __enter__(self) -> Optional[Array]:

155

"""Enter plate context, returning subsample indices if subsampling."""

156

157

def __exit__(self, exc_type, exc_value, traceback): ...

158

159

def plate_stack(prefix: str, sizes: list[int], rightmost_dim: int = -1) -> ContextManager[None]:

160

"""

161

Create a contiguous stack of nested plates and enter them all with a single `with` statement.

162

163

Args:

164

prefix: Prefix for plate names

165

sizes: List of sizes for each dimension

166

rightmost_dim: Dimension (negative, counted from the right) given to the last entry of sizes; earlier entries take the dimensions to its left

167

168

Returns:

169

Context manager that enters every plate in the stack (it is not a list and cannot be indexed)

170

171

Usage:

172

# Create 3D tensor of independent samples

173

with numpyro.plate_stack("data", [20, 30, 40]):

174

x = numpyro.sample("x", dist.Normal(0, 1)) # Shape: (20, 30, 40)

178

"""

179

180

def subsample(data: ArrayLike, event_dim: int) -> ArrayLike:

181

"""

182

Subsample data based on active plates in the context.

183

184

Args:

185

data: Data tensor to subsample

186

event_dim: Number of rightmost dimensions that are event dimensions

187

188

Returns:

189

Subsampled data tensor

190

191

Usage:

192

# Subsample based on active plate

193

with numpyro.plate("data", len(full_data), subsample_size=100):

194

batch_data = numpyro.subsample(full_data, event_dim=0)

195

x = numpyro.sample("x", dist.Normal(batch_data, 1))

196

"""

197

```

198

199

### Advanced Primitives

200

201

Specialized primitives for advanced modeling scenarios.

202

203

```python { .api }

204

def mutable(name: str, init_value: Optional[ArrayLike] = None) -> ArrayLike:

205

"""

206

Create mutable storage that persists across function calls.

207

208

Args:

209

name: Name of the mutable site

210

init_value: Initial value for the mutable storage

211

212

Returns:

213

Current value of mutable storage

214

215

Usage:

216

# Counter that increments each call

217

count = numpyro.mutable("counter", {"value": 0})

218

count["value"] = count["value"] + 1 # Update the counter in place

219

"""

220

221

def module(name: str, nn: tuple, input_shape: Optional[tuple] = None) -> Callable:

222

"""

223

Register neural network modules for use with JAX transformations.

224

225

Args:

226

name: Module name

227

nn: Tuple of (init_fn, apply_fn) for neural network

228

input_shape: Input shape for module initialization

229

230

Returns:

231

Module function that can be called with inputs

232

233

Usage:

234

# Haiku neural network

235

import haiku as hk

236

237

def net_fn(x):

238

return hk.nets.MLP([64, 32, 1])(x)

239

240

net = hk.transform(net_fn)

241

module_fn = numpyro.module("mlp", net, input_shape=(10,))

242

243

# Use in model

244

x = numpyro.sample("x", dist.Normal(0, 1).expand((batch_size, 10)))

245

y_pred = module_fn(x)

246

"""

247

248

def prng_key() -> Optional[Array]:

249

"""

250

Get the current PRNG key from the execution context.

251

252

Returns:

253

Current random key or None if not available

254

255

Usage:

256

# Get key for manual random operations

257

key = numpyro.prng_key()

258

if key is not None:

259

noise = random.normal(key, shape=(10,))

260

"""

261

262

def get_mask() -> Optional[ArrayLike]:

263

"""

264

Get the current mask from the handler stack.

265

266

Returns:

267

Current mask array or None if no mask is active

268

269

Usage:

270

# Check if masking is active

271

current_mask = numpyro.get_mask()

272

if current_mask is not None:

273

# Handle masked computation

274

pass

275

"""

276

```

277

278

### Internal Utilities

279

280

Internal functions used by the primitive system (typically not used directly).

281

282

```python { .api }

283

def _masked_observe(name: str, fn: Distribution, obs: ArrayLike,

284

obs_mask: ArrayLike, **kwargs) -> ArrayLike:

285

"""

286

Handle masked observations in sample sites.

287

288

Args:

289

name: Site name

290

fn: Distribution

291

obs: Observed values

292

obs_mask: Boolean mask for valid observations

293

**kwargs: Additional arguments

294

295

Returns:

296

Masked observed value

297

"""

298

299

def _subsample_fn(size: int, subsample_size: int,

300

rng_key: Optional[Array] = None) -> Array:

301

"""

302

Generate subsample indices for plate subsampling.

303

304

Args:

305

size: Full dataset size

306

subsample_size: Size of subsample

307

rng_key: Random key for sampling

308

309

Returns:

310

Array of subsample indices

311

"""

312

313

def _inspect() -> dict:

314

"""

315

Inspect the current Pyro stack (experimental).

316

317

Returns:

318

Dictionary containing stack information

319

"""

320

321

class CondIndepStackFrame:

322

"""

323

Named tuple representing a conditional independence stack frame.

324

325

Attributes:

326

name: Frame name

327

dim: Broadcasting dimension

328

size: Frame size

329

counter: Frame counter for tracking

330

"""

331

name: str

332

dim: int

333

size: int

334

counter: int

335

```

336

337

### Validation and Inspection

338

339

Utilities for validating models and inspecting execution.

340

341

```python { .api }

342

def validate_model(model: Callable, *model_args, **model_kwargs) -> dict:

343

"""

344

Validate model structure and return trace information.

345

346

Args:

347

model: Model function to validate

348

*model_args: Arguments to pass to model

349

**model_kwargs: Keyword arguments to pass to model

350

351

Returns:

352

Dictionary containing validation results and trace information

353

354

Usage:

355

def my_model():

356

x = numpyro.sample("x", dist.Normal(0, 1))

357

y = numpyro.sample("y", dist.Normal(x, 1))

358

359

validation_info = numpyro.validate_model(my_model)

360

print(f"Model has {len(validation_info['sites'])} sites")

361

"""

362

363

def inspect_fn(fn: Callable, *args, **kwargs) -> dict:

364

"""

365

Inspect function execution and return detailed information.

366

367

Args:

368

fn: Function to inspect

369

*args: Arguments to pass to function

370

**kwargs: Keyword arguments to pass to function

371

372

Returns:

373

Dictionary with execution information including sites and dependencies

374

"""

375

```

376

377

## Usage Examples

378

379

```python

380

import numpyro

381

import numpyro.distributions as dist

382

import jax.numpy as jnp

383

from jax import random

384

385

# Basic linear regression model

386

def linear_regression(X, y=None):

387

# Prior parameters

388

alpha = numpyro.sample("alpha", dist.Normal(0, 10))

389

beta = numpyro.sample("beta", dist.Normal(0, 10))

390

sigma = numpyro.param("sigma", 1.0, constraint=dist.constraints.positive)

391

392

# Linear prediction

393

mu = alpha + beta * X

394

395

# Mark prediction for tracking

396

prediction = numpyro.deterministic("prediction", mu)

397

398

# Likelihood with conditional independence over data points

399

with numpyro.plate("data", X.shape[0]):

400

numpyro.sample("y", dist.Normal(mu, sigma), obs=y)

401

402

# Hierarchical model with nested plates

403

def hierarchical_model(group_idx, y=None):

404

n_groups = len(jnp.unique(group_idx))

405

n_obs = len(y) if y is not None else len(group_idx)

406

407

# Global hyperparameters

408

mu_global = numpyro.sample("mu_global", dist.Normal(0, 1))

409

sigma_global = numpyro.sample("sigma_global", dist.Exponential(1))

410

411

# Group-level parameters

412

with numpyro.plate("groups", n_groups):

413

mu_group = numpyro.sample("mu_group", dist.Normal(mu_global, sigma_global))

414

415

# Observation-level likelihood

416

with numpyro.plate("obs", n_obs):

417

mu = mu_group[group_idx]

418

numpyro.sample("y", dist.Normal(mu, 1), obs=y)

419

420

# Model with subsampling for large datasets

421

def large_dataset_model(X, y=None):

422

n_data, n_features = X.shape

423

424

# Parameters

425

weights = numpyro.sample("weights", dist.Normal(0, 1).expand((n_features,)))

426

427

# Subsample for computational efficiency

428

with numpyro.plate("data", n_data, subsample_size=min(1000, n_data)) as idx:

429

X_batch = numpyro.subsample(X, event_dim=1)

430

y_batch = numpyro.subsample(y, event_dim=0) if y is not None else None

431

432

mu = X_batch @ weights

433

numpyro.sample("y", dist.Normal(mu, 0.1), obs=y_batch)

434

435

# Custom factor for non-standard likelihoods

436

def custom_likelihood_model(data):

437

theta = numpyro.sample("theta", dist.Beta(1, 1))

438

439

# Custom log-likelihood that doesn't fit standard distributions

440

log_lik = jnp.sum(data * jnp.log(theta) + (1 - data) * jnp.log(1 - theta))

441

numpyro.factor("custom_lik", log_lik)

442

```

443

444

## Types

445

446

```python { .api }

447

from typing import Optional, Union, Callable, Dict, Any, Tuple

448

from jax import Array

449

import jax.numpy as jnp

450

from numpyro.distributions import Distribution, constraints

451

452

ArrayLike = Union[Array, jnp.ndarray, float, int]

453

Constraint = constraints.Constraint

454

InitFunction = Union[ArrayLike, Callable[[Array], ArrayLike]]

455

456

class CondIndepStackFrame:

457

"""Frame in the conditional independence stack."""

458

name: str

459

dim: int

460

size: int

461

counter: int

462

463

class PlateMessenger:

464

"""Messenger for plate context management."""

465

name: str

466

size: int

467

subsample_size: Optional[int]

468

dim: Optional[int]

469

subsample: Optional[Array]

470

471

# Site types for different primitive operations

472

SiteType = str  # one of: "sample", "param", "deterministic", "factor", "mutable"

473

474

class SiteInfo:

475

"""Information about a primitive site."""

476

name: str

477

type: SiteType

478

fn: Optional[Distribution]

479

args: tuple

480

kwargs: dict

481

value: Any

482

is_observed: bool

483

infer: dict

484

scale: Optional[float]

485

486

class ValidationResult:

487

"""Result from model validation."""

488

sites: dict

489

dependencies: dict

490

plate_stack: list

491

is_valid: bool

492

warnings: list

493

errors: list

494

```