or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

diagnostics.md, distributions.md, handlers.md, index.md, inference.md, optimization.md, primitives.md, utilities.md

docs/distributions.md

0

# Distributions

1

2

NumPyro provides a comprehensive collection of probability distributions, organized below into several categories. All distributions inherit from a common base class (`Distribution`) and provide consistent interfaces for sampling, log probability computation, and parameter validation.

3

4

## Capabilities

5

6

### Base Distribution Classes

7

8

Foundation classes that provide the core distribution interface and specialized distribution wrappers.

9

10

```python { .api }

11

class Distribution:

12

"""

13

Base class for probability distributions in NumPyro.

14

15

Properties:

16

- batch_shape: Shape of batch dimensions

17

- event_shape: Shape of event dimensions

18

- support: Support constraint for the distribution

19

- has_rsample: Whether reparameterized sampling is supported

20

"""

21

def __init__(self, batch_shape=(), event_shape=(), validate_args=None): ...

22

def sample(self, key, sample_shape=()) -> Array: ...

23

def log_prob(self, value) -> Array: ...

24

def cdf(self, value) -> Array: ...

25

def icdf(self, q) -> Array: ...

26

def expand(self, batch_shape) -> 'Distribution': ...

27

def mask(self, mask) -> 'MaskedDistribution': ...

28

29

class ExpandedDistribution(Distribution):

30

"""Distribution with expanded batch dimensions."""

31

def __init__(self, base_distribution: Distribution, batch_shape: tuple): ...

32

33

class Independent(Distribution):

34

"""Reinterprets batch dimensions as event dimensions."""

35

def __init__(self, base_distribution: Distribution, reinterpreted_batch_ndims: int): ...

36

37

class TransformedDistribution(Distribution):

38

"""Distribution transformed by a bijective transformation."""

39

def __init__(self, base_distribution: Distribution, transforms): ...

40

41

class MaskedDistribution(Distribution):

42

"""Distribution with masked values."""

43

def __init__(self, base_distribution: Distribution, mask): ...

44

45

class FoldedDistribution(Distribution):

46

"""Distribution folded around zero by taking absolute value."""

47

def __init__(self, base_distribution: Distribution): ...

48

```

49

50

### Continuous Distributions

51

52

Continuous probability distributions for modeling real-valued random variables.

53

54

#### Basic Continuous Distributions

55

56

```python { .api }

57

class Normal(Distribution):

58

"""

59

Normal (Gaussian) distribution.

60

61

Args:

62

loc: Mean of the distribution

63

scale: Standard deviation of the distribution

64

"""

65

def __init__(self, loc=0.0, scale=1.0, validate_args=None): ...

66

67

class Uniform(Distribution):

68

"""

69

Uniform distribution over an interval.

70

71

Args:

72

low: Lower bound of the distribution

73

high: Upper bound of the distribution

74

"""

75

def __init__(self, low=0.0, high=1.0, validate_args=None): ...

76

77

class Exponential(Distribution):

78

"""

79

Exponential distribution.

80

81

Args:

82

rate: Rate parameter (inverse scale)

83

"""

84

def __init__(self, rate=1.0, validate_args=None): ...

85

86

class Laplace(Distribution):

87

"""

88

Laplace (double exponential) distribution.

89

90

Args:

91

loc: Location parameter (mean)

92

scale: Scale parameter

93

"""

94

def __init__(self, loc=0.0, scale=1.0, validate_args=None): ...

95

96

class Logistic(Distribution):

97

"""

98

Logistic distribution.

99

100

Args:

101

loc: Location parameter

102

scale: Scale parameter

103

"""

104

def __init__(self, loc=0.0, scale=1.0, validate_args=None): ...

105

106

class LogNormal(Distribution):

107

"""

108

Log-normal distribution.

109

110

Args:

111

loc: Mean of underlying normal distribution

112

scale: Standard deviation of underlying normal distribution

113

"""

114

def __init__(self, loc=0.0, scale=1.0, validate_args=None): ...

115

116

class Cauchy(Distribution):

117

"""

118

Cauchy distribution.

119

120

Args:

121

loc: Location parameter (median)

122

scale: Scale parameter

123

"""

124

def __init__(self, loc=0.0, scale=1.0, validate_args=None): ...

125

126

class StudentT(Distribution):

127

"""

128

Student's t-distribution.

129

130

Args:

131

df: Degrees of freedom

132

loc: Location parameter (mean when df > 1)

133

scale: Scale parameter

134

"""

135

def __init__(self, df, loc=0.0, scale=1.0, validate_args=None): ...

136

```

137

138

#### Beta and Gamma Family

139

140

```python { .api }

141

class Beta(Distribution):

142

"""

143

Beta distribution.

144

145

Args:

146

concentration1: First concentration parameter (alpha)

147

concentration0: Second concentration parameter (beta)

148

"""

149

def __init__(self, concentration1, concentration0, validate_args=None): ...

150

151

class BetaProportion(Distribution):

152

"""

153

Beta distribution parameterized by mean and concentration.

154

155

Args:

156

mean: Mean of the distribution

157

concentration: Total concentration parameter

158

"""

159

def __init__(self, mean, concentration, validate_args=None): ...

160

161

class Gamma(Distribution):

162

"""

163

Gamma distribution.

164

165

Args:

166

concentration: Shape parameter (alpha)

167

rate: Rate parameter (beta), inverse of scale

168

"""

169

def __init__(self, concentration, rate=1.0, validate_args=None): ...

170

171

class InverseGamma(Distribution):

172

"""

173

Inverse Gamma distribution.

174

175

Args:

176

concentration: Shape parameter

177

rate: Rate parameter

178

"""

179

def __init__(self, concentration, rate, validate_args=None): ...

180

181

class Chi2(Distribution):

182

"""

183

Chi-squared distribution.

184

185

Args:

186

df: Degrees of freedom

187

"""

188

def __init__(self, df, validate_args=None): ...

189

190

class Dirichlet(Distribution):

191

"""

192

Dirichlet distribution over probability simplexes.

193

194

Args:

195

concentration: Concentration parameters

196

"""

197

def __init__(self, concentration, validate_args=None): ...

198

```

199

200

#### Multivariate Continuous Distributions

201

202

```python { .api }

203

class MultivariateNormal(Distribution):

204

"""

205

Multivariate normal distribution.

206

207

Args:

208

loc: Mean vector

209

covariance_matrix: Covariance matrix (optional)

210

precision_matrix: Precision matrix (optional)

211

scale_tril: Lower triangular Cholesky factor (optional)

212

"""

213

def __init__(self, loc, covariance_matrix=None, precision_matrix=None,

214

scale_tril=None, validate_args=None): ...

215

216

class LowRankMultivariateNormal(Distribution):

217

"""

218

Low-rank multivariate normal distribution.

219

220

Args:

221

loc: Mean vector

222

cov_factor: Low-rank covariance factor

223

cov_diag: Diagonal covariance component

224

"""

225

def __init__(self, loc, cov_factor, cov_diag, validate_args=None): ...

226

227

class MultivariateStudentT(Distribution):

228

"""

229

Multivariate Student's t-distribution.

230

231

Args:

232

df: Degrees of freedom

233

loc: Location vector

234

scale_tril: Lower triangular scale matrix

235

"""

236

def __init__(self, df, loc=0.0, scale_tril=None, validate_args=None): ...

237

238

class MatrixNormal(Distribution):

239

"""

240

Matrix normal distribution.

241

242

Args:

243

loc: Mean matrix

244

scale_tril_row: Row scale matrix (lower triangular)

245

scale_tril_col: Column scale matrix (lower triangular)

246

"""

247

def __init__(self, loc, scale_tril_row=None, scale_tril_col=None, validate_args=None): ...

248

249

class Wishart(Distribution):

250

"""

251

Wishart distribution over positive definite matrices.

252

253

Args:

254

df: Degrees of freedom

255

scale_tril: Lower triangular scale matrix

256

"""

257

def __init__(self, df, scale_tril, validate_args=None): ...

258

259

class LKJ(Distribution):

260

"""

261

LKJ distribution over correlation matrices.

262

263

Args:

264

dimension: Dimension of correlation matrices

265

concentration: Concentration parameter

266

"""

267

def __init__(self, dimension, concentration, validate_args=None): ...

268

269

class LKJCholesky(Distribution):

270

"""

271

LKJ distribution over Cholesky factors of correlation matrices.

272

273

Args:

274

dimension: Dimension of correlation matrices

275

concentration: Concentration parameter

276

"""

277

def __init__(self, dimension, concentration, validate_args=None): ...

278

```

279

280

#### Specialized Continuous Distributions

281

282

```python { .api }

283

class HalfNormal(Distribution):

284

"""Half-normal distribution (normal folded at zero)."""

285

def __init__(self, scale=1.0, validate_args=None): ...

286

287

class HalfCauchy(Distribution):

288

"""Half-Cauchy distribution (Cauchy folded at zero)."""

289

def __init__(self, scale=1.0, validate_args=None): ...

290

291

class Pareto(Distribution):

292

"""

293

Pareto distribution.

294

295

Args:

296

scale: Scale parameter (minimum value)

297

alpha: Shape parameter

298

"""

299

def __init__(self, scale, alpha, validate_args=None): ...

300

301

class Weibull(Distribution):

302

"""

303

Weibull distribution.

304

305

Args:

306

scale: Scale parameter

307

concentration: Shape parameter

308

"""

309

def __init__(self, scale, concentration, validate_args=None): ...

310

311

class Gumbel(Distribution):

312

"""

313

Gumbel distribution.

314

315

Args:

316

loc: Location parameter

317

scale: Scale parameter

318

"""

319

def __init__(self, loc=0.0, scale=1.0, validate_args=None): ...

320

321

class Levy(Distribution):

322

"""

323

Lévy distribution.

324

325

Args:

326

loc: Location parameter

327

scale: Scale parameter

328

"""

329

def __init__(self, loc=0.0, scale=1.0, validate_args=None): ...

330

331

class Kumaraswamy(Distribution):

332

"""

333

Kumaraswamy distribution.

334

335

Args:

336

concentration1: First shape parameter

337

concentration0: Second shape parameter

338

"""

339

def __init__(self, concentration1, concentration0, validate_args=None): ...

340

341

class Gompertz(Distribution):

342

"""

343

Gompertz distribution.

344

345

Args:

346

scale: Scale parameter

347

concentration: Shape parameter

348

"""

349

def __init__(self, scale, concentration, validate_args=None): ...

350

351

class AsymmetricLaplace(Distribution):

352

"""

353

Asymmetric Laplace distribution.

354

355

Args:

356

loc: Location parameter

357

scale: Scale parameter

358

asymmetry: Asymmetry parameter

359

"""

360

def __init__(self, loc, scale, asymmetry, validate_args=None): ...

361

362

class SoftLaplace(Distribution):

363

"""Smooth continuous distribution with Laplace-like tails; unlike Laplace it is infinitely differentiable, which suits HMC."""

364

def __init__(self, loc, scale, validate_args=None): ...

365

```

366

367

#### Time Series and Spatial Distributions

368

369

```python { .api }

370

class GaussianRandomWalk(Distribution):

371

"""

372

Gaussian random walk distribution.

373

374

Args:

375

scale: Step size scale

376

num_steps: Number of time steps

377

"""

378

def __init__(self, scale=1.0, num_steps=1, validate_args=None): ...

379

380

class GaussianStateSpace(Distribution):

381

"""

382

Linear Gaussian state space model.

383

384

Args:

385

initial_state_mean: Initial state mean

386

initial_state_cov: Initial state covariance

387

transition_matrix: State transition matrix

388

transition_cov: Transition noise covariance

389

observation_matrix: Observation matrix

390

observation_cov: Observation noise covariance

391

"""

392

def __init__(self, initial_state_mean, initial_state_cov, transition_matrix,

393

transition_cov, observation_matrix, observation_cov, validate_args=None): ...

394

395

class EulerMaruyama(Distribution):

396

"""

397

Euler-Maruyama method for SDEs.

398

399

Args:

400

drift: Drift function

401

diffusion: Diffusion function

402

dt: Time step size

403

num_steps: Number of steps

404

"""

405

def __init__(self, drift, diffusion, dt, num_steps, validate_args=None): ...

406

407

class CAR(Distribution):

408

"""

409

Conditional Autoregressive (CAR) distribution.

410

411

Args:

412

loc: Location parameter

413

precision: Precision parameter

414

adjacency_matrix: Spatial adjacency matrix

415

"""

416

def __init__(self, loc, precision, adjacency_matrix, validate_args=None): ...

417

```

418

419

### Discrete Distributions

420

421

Discrete probability distributions for modeling integer-valued random variables.

422

423

#### Basic Discrete Distributions

424

425

```python { .api }

426

class Bernoulli(Distribution):

427

"""

428

Bernoulli distribution.

429

430

Args:

431

probs: Success probability (optional)

432

logits: Log-odds (optional)

433

"""

434

def __init__(self, probs=None, logits=None, validate_args=None): ...

435

436

class Categorical(Distribution):

437

"""

438

Categorical distribution over integers.

439

440

Args:

441

probs: Category probabilities (optional)

442

logits: Log probabilities (optional)

443

"""

444

def __init__(self, probs=None, logits=None, validate_args=None): ...

445

446

class Binomial(Distribution):

447

"""

448

Binomial distribution.

449

450

Args:

451

total_count: Number of trials

452

probs: Success probability (optional)

453

logits: Log-odds (optional)

454

"""

455

def __init__(self, total_count=1, probs=None, logits=None, validate_args=None): ...

456

457

class Multinomial(Distribution):

458

"""

459

Multinomial distribution.

460

461

Args:

462

total_count: Number of trials

463

probs: Category probabilities (optional)

464

logits: Log probabilities (optional)

465

"""

466

def __init__(self, total_count=1, probs=None, logits=None, validate_args=None): ...

467

468

class Poisson(Distribution):

469

"""

470

Poisson distribution.

471

472

Args:

473

rate: Rate parameter (mean)

474

"""

475

def __init__(self, rate, validate_args=None): ...

476

477

class Geometric(Distribution):

478

"""

479

Geometric distribution (number of failures before first success).

480

481

Args:

482

probs: Success probability (optional)

483

logits: Log-odds (optional)

484

"""

485

def __init__(self, probs=None, logits=None, validate_args=None): ...

486

487

class DiscreteUniform(Distribution):

488

"""

489

Discrete uniform distribution.

490

491

Args:

492

low: Lower bound (inclusive)

493

high: Upper bound (inclusive)

494

"""

495

def __init__(self, low=0, high=1, validate_args=None): ...

496

497

class OrderedLogistic(Distribution):

498

"""

499

Ordered logistic distribution for ordinal data.

500

501

Args:

502

predictor: Linear predictor

503

cutpoints: Ordered cutpoints

504

"""

505

def __init__(self, predictor, cutpoints, validate_args=None): ...

506

```

507

508

#### Zero-Inflated Distributions

509

510

```python { .api }

511

class ZeroInflatedDistribution(Distribution):

512

"""

513

Zero-inflated wrapper for any discrete distribution.

514

515

Args:

516

base_dist: Base discrete distribution

517

gate: Probability of extra zeros

518

"""

519

def __init__(self, base_dist, gate=None, gate_logits=None, validate_args=None): ...

520

521

class ZeroInflatedPoisson(Distribution):

522

"""

523

Zero-inflated Poisson distribution.

524

525

Args:

526

rate: Poisson rate parameter

527

gate: Probability of extra zeros

528

"""

529

def __init__(self, rate, gate=None, gate_logits=None, validate_args=None): ...

530

```

531

532

### Conjugate Distributions

533

534

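Compound distributions obtained by integrating out a conjugate prior (e.g. a Beta prior on a Binomial success probability, or a Gamma prior on a Poisson rate); they are useful for modeling overdispersed count data.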

535

536

```python { .api }

537

class BetaBinomial(Distribution):

538

"""

539

Beta-binomial distribution (binomial with beta prior on probability).

540

541

Args:

542

concentration1: Beta alpha parameter

543

concentration0: Beta beta parameter

544

total_count: Number of trials

545

"""

546

def __init__(self, concentration1, concentration0, total_count=1, validate_args=None): ...

547

548

class DirichletMultinomial(Distribution):

549

"""

550

Dirichlet-multinomial distribution.

551

552

Args:

553

concentration: Dirichlet concentration parameters

554

total_count: Number of trials

555

"""

556

def __init__(self, concentration, total_count=1, validate_args=None): ...

557

558

class GammaPoisson(Distribution):

559

"""

560

Gamma-Poisson (negative binomial) distribution.

561

562

Args:

563

concentration: Gamma shape parameter

564

rate: Gamma rate parameter

565

"""

566

def __init__(self, concentration, rate, validate_args=None): ...

567

568

class NegativeBinomial2(Distribution):

569

"""

570

Negative binomial distribution (NB2 parameterization).

571

572

Args:

573

mean: Mean parameter

574

concentration: Concentration parameter

575

"""

576

def __init__(self, mean, concentration, validate_args=None): ...

577

578

class ZeroInflatedNegativeBinomial2(Distribution):

579

"""Zero-inflated negative binomial distribution."""

580

def __init__(self, mean, concentration, gate=None, gate_logits=None, validate_args=None): ...

581

```

582

583

### Directional Distributions

584

585

Distributions for directional data on the circle, the torus, and the sphere.

586

587

```python { .api }

588

class VonMises(Distribution):

589

"""

590

Von Mises distribution for circular data.

591

592

Args:

593

loc: Mean direction

594

concentration: Concentration parameter

595

"""

596

def __init__(self, loc, concentration, validate_args=None): ...

597

598

class ProjectedNormal(Distribution):

599

"""

600

Projected normal distribution on unit sphere.

601

602

Args:

603

concentration: Concentration vector

604

"""

605

def __init__(self, concentration, validate_args=None): ...

606

607

class SineBivariateVonMises(Distribution):

608

"""Sine bivariate von Mises distribution."""

609

def __init__(self, phi_loc, psi_loc, phi_concentration, psi_concentration,

610

correlation, validate_args=None): ...

611

612

class SineSkewed(Distribution):

613

"""Sine-skewed circular distribution."""

614

def __init__(self, base_dist, skewness, validate_args=None): ...

615

```

616

617

### Mixture Distributions

618

619

Finite mixture models for modeling multi-modal data.

620

621

```python { .api }

622

class Mixture(Distribution):

623

"""

624

Finite mixture distribution.

625

626

Args:

627

mixing_distribution: Categorical mixing distribution

628

component_distributions: List of component distributions

629

"""

630

def __init__(self, mixing_distribution, component_distributions, validate_args=None): ...

631

632

class MixtureGeneral(Distribution):

633

"""General mixture distribution with flexible component selection."""

634

def __init__(self, mixing_distribution, component_distributions,

635

support=None, validate_args=None): ...

636

637

class MixtureSameFamily(Distribution):

638

"""

639

Mixture of distributions from the same family.

640

641

Args:

642

mixing_distribution: Categorical mixing distribution

643

component_distribution: Batch of component distributions

644

"""

645

def __init__(self, mixing_distribution, component_distribution, validate_args=None): ...

646

```

647

648

### Truncated Distributions

649

650

Distributions with restricted support through truncation.

651

652

```python { .api }

653

class TruncatedDistribution(Distribution):

654

"""

655

Generic truncated distribution.

656

657

Args:

658

base_distribution: Base distribution to truncate

659

low: Lower truncation bound

660

high: Upper truncation bound

661

"""

662

def __init__(self, base_distribution, low=None, high=None, validate_args=None): ...

663

664

class LeftTruncatedDistribution(Distribution):

665

"""Left-truncated distribution (truncated below)."""

666

def __init__(self, base_distribution, low, validate_args=None): ...

667

668

class RightTruncatedDistribution(Distribution):

669

"""Right-truncated distribution (truncated above)."""

670

def __init__(self, base_distribution, high, validate_args=None): ...

671

672

class TwoSidedTruncatedDistribution(Distribution):

673

"""Two-sided truncated distribution."""

674

def __init__(self, base_distribution, low, high, validate_args=None): ...

675

676

class TruncatedNormal(Distribution):

677

"""Truncated normal distribution."""

678

def __init__(self, loc=0.0, scale=1.0, low=None, high=None, validate_args=None): ...

679

680

class TruncatedCauchy(Distribution):

681

"""Truncated Cauchy distribution."""

682

def __init__(self, loc=0.0, scale=1.0, low=None, high=None, validate_args=None): ...

683

684

class LowerTruncatedPowerLaw(Distribution):

685

"""Lower truncated power law distribution."""

686

def __init__(self, alpha, scale, validate_args=None): ...

687

688

class DoublyTruncatedPowerLaw(Distribution):

689

"""Doubly truncated power law distribution."""

690

def __init__(self, alpha, low, high, validate_args=None): ...

691

```

692

693

### Copula Distributions

694

695

Copula-based distributions for modeling dependence structures.

696

697

```python { .api }

698

class GaussianCopula(Distribution):

699

"""

700

Gaussian copula distribution.

701

702

Args:

703

correlation_matrix: Correlation matrix

704

marginals: List of marginal distributions

705

"""

706

def __init__(self, correlation_matrix, marginals, validate_args=None): ...

707

708

class GaussianCopulaBeta(Distribution):

709

"""Gaussian copula with Beta marginals."""

710

def __init__(self, correlation_matrix, concentration1, concentration0, validate_args=None): ...

711

```

712

713

### Special Distributions

714

715

Utility distributions for specific modeling needs.

716

717

```python { .api }

718

class Delta(Distribution):

719

"""

720

Point mass (Dirac delta) distribution.

721

722

Args:

723

v: Point mass location

724

log_density: Log density value at the point

725

event_dim: Number of event dimensions

726

"""

727

def __init__(self, v, log_density=0.0, event_dim=0, validate_args=None): ...

728

729

class Unit(Distribution):

730

"""Unit distribution for adding log probability factors."""

731

def __init__(self, log_factor, validate_args=None): ...

732

733

class ImproperUniform(Distribution):

734

"""

735

Improper (unnormalized) uniform distribution over the given support constraint.

736

737

Args:

738

support: Support constraint

739

batch_shape: Batch shape

740

event_shape: Event shape

741

"""

742

def __init__(self, support, batch_shape, event_shape, validate_args=None): ...

743

744

class CirculantNormal(Distribution):

745

"""Normal distribution with circulant covariance matrix."""

746

def __init__(self, loc, circulant_cov, validate_args=None): ...

747

748

class ZeroSumNormal(Distribution):

749

"""Normal distribution with zero-sum constraint."""

750

def __init__(self, scale, validate_args=None): ...

751

752

class RelaxedBernoulli(Distribution):

753

"""Relaxed Bernoulli distribution (continuous relaxation)."""

754

def __init__(self, temperature, probs=None, logits=None, validate_args=None): ...

755

```

756

757

### Distribution Utilities

758

759

```python { .api }

760

def enable_validation(is_validate: bool) -> None:

761

"""Enable or disable distribution parameter validation."""

762

763

def validation_enabled() -> bool:

764

"""Check if distribution validation is currently enabled."""

765

766

def kl_divergence(p: Distribution, q: Distribution) -> Array:

767

"""Compute KL divergence between two distributions."""

768

769

def biject_to(constraint) -> Transform:

770

"""Get bijective transform to given constraint."""

771

```

772

773

## Types

774

775

```python { .api }

776

from typing import Optional, Union, Callable, Sequence

777

from jax import Array

778

import jax.numpy as jnp

779

780

ArrayLike = Union[Array, jnp.ndarray, float, int]

781

Constraint = numpyro.distributions.constraints.Constraint

782

Transform = numpyro.distributions.transforms.Transform

783

784

# Distribution parameter types

785

Concentration = ArrayLike # Positive real numbers

786

Rate = ArrayLike # Positive real numbers

787

Scale = ArrayLike # Positive real numbers

788

Probability = ArrayLike # Numbers in [0, 1]

789

Logits = ArrayLike # Real numbers

790

Location = ArrayLike # Real numbers

791

```