# Distributions

NumPyro provides a large collection of probability distributions, grouped below by category. Every distribution class derives from a common `Distribution` base class and exposes a consistent interface for sampling, log-probability evaluation, and parameter validation.

## Capabilities

### Base Distribution Classes

Foundation classes that provide the core distribution interface and specialized distribution wrappers.

```python { .api }
class Distribution:
    """
    Base class for probability distributions in NumPyro.

    Properties:
    - batch_shape: Shape of batch dimensions
    - event_shape: Shape of event dimensions
    - support: Support constraint for the distribution
    - has_rsample: Whether reparameterized sampling is supported
    """
    def __init__(self, batch_shape=(), event_shape=(), validate_args=None): ...
    def sample(self, key, sample_shape=()) -> Array: ...
    def log_prob(self, value) -> Array: ...
    def cdf(self, value) -> Array: ...
    def icdf(self, q) -> Array: ...
    def expand(self, batch_shape) -> 'Distribution': ...
    def mask(self, mask) -> 'MaskedDistribution': ...

class ExpandedDistribution(Distribution):
    """Distribution with expanded batch dimensions."""
    def __init__(self, base_dist: Distribution, batch_shape: tuple): ...

class Independent(Distribution):
    """Reinterprets batch dimensions as event dimensions."""
    def __init__(self, base_dist: Distribution, reinterpreted_batch_ndims: int): ...

class TransformedDistribution(Distribution):
    """Distribution transformed by a bijective transformation."""
    def __init__(self, base_distribution: Distribution, transforms): ...

class MaskedDistribution(Distribution):
    """Distribution whose log-probability is set to zero wherever `mask` is False."""
    def __init__(self, base_dist: Distribution, mask): ...

class FoldedDistribution(Distribution):
    """Distribution folded around zero by taking absolute value."""
    def __init__(self, base_dist: Distribution): ...
```

### Continuous Distributions

Continuous probability distributions for modeling real-valued random variables.

#### Basic Continuous Distributions

```python { .api }
class Normal(Distribution):
    """
    Normal (Gaussian) distribution.

    Args:
        loc: Mean of the distribution
        scale: Standard deviation of the distribution
    """
    def __init__(self, loc=0.0, scale=1.0, validate_args=None): ...

class Uniform(Distribution):
    """
    Uniform distribution over an interval.

    Args:
        low: Lower bound of the distribution
        high: Upper bound of the distribution
    """
    def __init__(self, low=0.0, high=1.0, validate_args=None): ...

class Exponential(Distribution):
    """
    Exponential distribution.

    Args:
        rate: Rate parameter (inverse scale)
    """
    def __init__(self, rate=1.0, validate_args=None): ...

class Laplace(Distribution):
    """
    Laplace (double exponential) distribution.

    Args:
        loc: Location parameter (mean)
        scale: Scale parameter
    """
    def __init__(self, loc=0.0, scale=1.0, validate_args=None): ...

class Logistic(Distribution):
    """
    Logistic distribution.

    Args:
        loc: Location parameter
        scale: Scale parameter
    """
    def __init__(self, loc=0.0, scale=1.0, validate_args=None): ...

class LogNormal(Distribution):
    """
    Log-normal distribution.

    Args:
        loc: Mean of underlying normal distribution
        scale: Standard deviation of underlying normal distribution
    """
    def __init__(self, loc=0.0, scale=1.0, validate_args=None): ...

class Cauchy(Distribution):
    """
    Cauchy distribution.

    Args:
        loc: Location parameter (median)
        scale: Scale parameter
    """
    def __init__(self, loc=0.0, scale=1.0, validate_args=None): ...

class StudentT(Distribution):
    """
    Student's t-distribution.

    Args:
        df: Degrees of freedom
        loc: Location parameter (mean when df > 1)
        scale: Scale parameter
    """
    def __init__(self, df, loc=0.0, scale=1.0, validate_args=None): ...
```

#### Beta and Gamma Family

```python { .api }
class Beta(Distribution):
    """
    Beta distribution.

    Args:
        concentration1: First concentration parameter (alpha)
        concentration0: Second concentration parameter (beta)
    """
    def __init__(self, concentration1, concentration0, validate_args=None): ...

class BetaProportion(Distribution):
    """
    Beta distribution parameterized by mean and concentration.

    Args:
        mean: Mean of the distribution
        concentration: Total concentration parameter
    """
    def __init__(self, mean, concentration, validate_args=None): ...

class Gamma(Distribution):
    """
    Gamma distribution.

    Args:
        concentration: Shape parameter (alpha)
        rate: Rate parameter (beta), inverse of scale
    """
    def __init__(self, concentration, rate=1.0, validate_args=None): ...

class InverseGamma(Distribution):
    """
    Inverse Gamma distribution.

    Args:
        concentration: Shape parameter
        rate: Rate parameter
    """
    def __init__(self, concentration, rate, validate_args=None): ...

class Chi2(Distribution):
    """
    Chi-squared distribution.

    Args:
        df: Degrees of freedom
    """
    def __init__(self, df, validate_args=None): ...

class Dirichlet(Distribution):
    """
    Dirichlet distribution over probability simplexes.

    Args:
        concentration: Concentration parameters
    """
    def __init__(self, concentration, validate_args=None): ...
```

#### Multivariate Continuous Distributions

```python { .api }
class MultivariateNormal(Distribution):
    """
    Multivariate normal distribution. Specify one of covariance_matrix,
    precision_matrix, or scale_tril.

    Args:
        loc: Mean vector
        covariance_matrix: Covariance matrix (optional)
        precision_matrix: Precision matrix (optional)
        scale_tril: Lower triangular Cholesky factor (optional)
    """
    def __init__(self, loc, covariance_matrix=None, precision_matrix=None,
                 scale_tril=None, validate_args=None): ...

class LowRankMultivariateNormal(Distribution):
    """
    Low-rank multivariate normal distribution with covariance
    cov_factor @ cov_factor.T + diag(cov_diag).

    Args:
        loc: Mean vector
        cov_factor: Low-rank covariance factor
        cov_diag: Diagonal covariance component
    """
    def __init__(self, loc, cov_factor, cov_diag, validate_args=None): ...

class MultivariateStudentT(Distribution):
    """
    Multivariate Student's t-distribution.

    Args:
        df: Degrees of freedom
        loc: Location vector
        scale_tril: Lower triangular scale matrix
    """
    def __init__(self, df, loc=0.0, scale_tril=None, validate_args=None): ...

class MatrixNormal(Distribution):
    """
    Matrix normal distribution.

    Args:
        loc: Mean matrix
        scale_tril_row: Row scale matrix (lower triangular)
        scale_tril_col: Column scale matrix (lower triangular)
    """
    def __init__(self, loc, scale_tril_row=None, scale_tril_col=None, validate_args=None): ...

class Wishart(Distribution):
    """
    Wishart distribution over positive definite matrices.

    Args:
        df: Degrees of freedom
        scale_tril: Lower triangular scale matrix
    """
    def __init__(self, df, scale_tril, validate_args=None): ...

class LKJ(Distribution):
    """
    LKJ distribution over correlation matrices.

    Args:
        dimension: Dimension of correlation matrices
        concentration: Concentration parameter
    """
    def __init__(self, dimension, concentration, validate_args=None): ...

class LKJCholesky(Distribution):
    """
    LKJ distribution over Cholesky factors of correlation matrices.

    Args:
        dimension: Dimension of correlation matrices
        concentration: Concentration parameter
    """
    def __init__(self, dimension, concentration, validate_args=None): ...
```

#### Specialized Continuous Distributions

```python { .api }
class HalfNormal(Distribution):
    """Half-normal distribution (normal folded at zero)."""
    def __init__(self, scale=1.0, validate_args=None): ...

class HalfCauchy(Distribution):
    """Half-Cauchy distribution (Cauchy folded at zero)."""
    def __init__(self, scale=1.0, validate_args=None): ...

class Pareto(Distribution):
    """
    Pareto distribution.

    Args:
        scale: Scale parameter (minimum value)
        alpha: Shape parameter
    """
    def __init__(self, scale, alpha, validate_args=None): ...

class Weibull(Distribution):
    """
    Weibull distribution.

    Args:
        scale: Scale parameter
        concentration: Shape parameter
    """
    def __init__(self, scale, concentration, validate_args=None): ...

class Gumbel(Distribution):
    """
    Gumbel distribution.

    Args:
        loc: Location parameter
        scale: Scale parameter
    """
    def __init__(self, loc=0.0, scale=1.0, validate_args=None): ...

class Levy(Distribution):
    """
    Lévy distribution.

    Args:
        loc: Location parameter
        scale: Scale parameter
    """
    def __init__(self, loc=0.0, scale=1.0, validate_args=None): ...

class Kumaraswamy(Distribution):
    """
    Kumaraswamy distribution.

    Args:
        concentration1: First shape parameter
        concentration0: Second shape parameter
    """
    def __init__(self, concentration1, concentration0, validate_args=None): ...

class Gompertz(Distribution):
    """
    Gompertz distribution.

    Args:
        scale: Scale parameter
        concentration: Shape parameter
    """
    def __init__(self, scale, concentration, validate_args=None): ...

class AsymmetricLaplace(Distribution):
    """
    Asymmetric Laplace distribution.

    Args:
        loc: Location parameter
        scale: Scale parameter
        asymmetry: Asymmetry parameter
    """
    def __init__(self, loc, scale, asymmetry, validate_args=None): ...

class SoftLaplace(Distribution):
    """
    Smooth, log-convex distribution with Laplace-like tails. Unlike Laplace,
    its density is infinitely differentiable everywhere, which suits HMC.
    """
    def __init__(self, loc, scale, validate_args=None): ...
```

#### Time Series and Spatial Distributions

```python { .api }
class GaussianRandomWalk(Distribution):
    """
    Gaussian random walk distribution.

    Args:
        scale: Step size scale
        num_steps: Number of time steps
    """
    def __init__(self, scale=1.0, num_steps=1, validate_args=None): ...

class GaussianStateSpace(Distribution):
    """
    Linear Gaussian state space model z_t = A z_{t-1} + eps_t,
    where eps_t is Gaussian innovation noise.

    Args:
        num_steps: Number of time steps
        transition_matrix: State transition matrix A
        covariance_matrix: Innovation noise covariance (optional)
        precision_matrix: Innovation noise precision (optional)
        scale_tril: Cholesky factor of the innovation noise covariance (optional)
    """
    def __init__(self, num_steps, transition_matrix, covariance_matrix=None,
                 precision_matrix=None, scale_tril=None, validate_args=None): ...

class EulerMaruyama(Distribution):
    """
    Euler-Maruyama approximation to the solution of a stochastic
    differential equation (SDE).

    Args:
        t: Discretized time points
        sde_fn: Function returning the drift and diffusion coefficients
        init_dist: Distribution of the initial values
    """
    def __init__(self, t, sde_fn, init_dist, validate_args=None): ...

class CAR(Distribution):
    """
    Conditional autoregressive (CAR) distribution: a multivariate normal whose
    precision is structured by a spatial adjacency matrix.

    Args:
        loc: Mean of the multivariate normal
        correlation: Autoregression parameter
        conditional_precision: Positive precision parameter
        adj_matrix: Symmetric spatial adjacency matrix
        is_sparse: Whether to use a sparse representation of adj_matrix
    """
    def __init__(self, loc, correlation, conditional_precision, adj_matrix,
                 is_sparse=False, validate_args=None): ...
```

### Discrete Distributions

Discrete probability distributions for modeling integer-valued random variables.

#### Basic Discrete Distributions

```python { .api }
class Bernoulli(Distribution):
    """
    Bernoulli distribution.

    Args:
        probs: Success probability (optional)
        logits: Log-odds (optional)
    """
    def __init__(self, probs=None, logits=None, validate_args=None): ...

class Categorical(Distribution):
    """
    Categorical distribution over integers {0, ..., K - 1}.

    Args:
        probs: Category probabilities (optional)
        logits: Unnormalized log probabilities (optional)
    """
    def __init__(self, probs=None, logits=None, validate_args=None): ...

class Binomial(Distribution):
    """
    Binomial distribution.

    Args:
        total_count: Number of trials
        probs: Success probability (optional)
        logits: Log-odds (optional)
    """
    def __init__(self, total_count=1, probs=None, logits=None, validate_args=None): ...

class Multinomial(Distribution):
    """
    Multinomial distribution.

    Args:
        total_count: Number of trials
        probs: Category probabilities (optional)
        logits: Unnormalized log probabilities (optional)
    """
    def __init__(self, total_count=1, probs=None, logits=None, validate_args=None): ...

class Poisson(Distribution):
    """
    Poisson distribution.

    Args:
        rate: Rate parameter (mean)
    """
    def __init__(self, rate, validate_args=None): ...

class Geometric(Distribution):
    """
    Geometric distribution (number of failures before first success).

    Args:
        probs: Success probability (optional)
        logits: Log-odds (optional)
    """
    def __init__(self, probs=None, logits=None, validate_args=None): ...

class DiscreteUniform(Distribution):
    """
    Discrete uniform distribution over the integers low, ..., high.

    Args:
        low: Lower bound (inclusive)
        high: Upper bound (inclusive)
    """
    def __init__(self, low=0, high=1, validate_args=None): ...

class OrderedLogistic(Distribution):
    """
    Ordered logistic distribution for ordinal data.

    Args:
        predictor: Linear predictor
        cutpoints: Ordered cutpoints
    """
    def __init__(self, predictor, cutpoints, validate_args=None): ...
```

#### Zero-Inflated Distributions

```python { .api }
class ZeroInflatedDistribution(Distribution):
    """
    Zero-inflated wrapper for any discrete distribution.
    Specify exactly one of gate or gate_logits.

    Args:
        base_dist: Base discrete distribution
        gate: Probability of extra zeros (optional)
        gate_logits: Log-odds of extra zeros (optional)
    """
    def __init__(self, base_dist, gate=None, gate_logits=None, validate_args=None): ...

class ZeroInflatedPoisson(Distribution):
    """
    Zero-inflated Poisson distribution.

    Args:
        gate: Probability of extra zeros
        rate: Poisson rate parameter
    """
    def __init__(self, gate, rate=1.0, validate_args=None): ...
```
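
### Conjugate Distributions

Compound distributions obtained by integrating a conjugate prior out of a likelihood (for example, a Beta prior on the Binomial success probability). The latent parameter therefore never has to be sampled.
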
```python { .api }
class BetaBinomial(Distribution):
    """
    Beta-binomial distribution (binomial with beta prior on probability).

    Args:
        concentration1: Beta alpha parameter
        concentration0: Beta beta parameter
        total_count: Number of trials
    """
    def __init__(self, concentration1, concentration0, total_count=1, validate_args=None): ...

class DirichletMultinomial(Distribution):
    """
    Dirichlet-multinomial distribution.

    Args:
        concentration: Dirichlet concentration parameters
        total_count: Number of trials
    """
    def __init__(self, concentration, total_count=1, validate_args=None): ...

class GammaPoisson(Distribution):
    """
    Gamma-Poisson (negative binomial) distribution.

    Args:
        concentration: Gamma shape parameter
        rate: Gamma rate parameter
    """
    def __init__(self, concentration, rate, validate_args=None): ...

class NegativeBinomial2(Distribution):
    """
    Negative binomial distribution (NB2 parameterization).

    Args:
        mean: Mean parameter
        concentration: Concentration parameter
    """
    def __init__(self, mean, concentration, validate_args=None): ...

class ZeroInflatedNegativeBinomial2(Distribution):
    """Zero-inflated negative binomial distribution."""
    def __init__(self, mean, concentration, gate=None, gate_logits=None, validate_args=None): ...
```

### Directional Distributions

Distributions for circular and spherical data.

```python { .api }
class VonMises(Distribution):
    """
    Von Mises distribution for circular data.

    Args:
        loc: Mean direction
        concentration: Concentration parameter
    """
    def __init__(self, loc, concentration, validate_args=None): ...

class ProjectedNormal(Distribution):
    """
    Projected normal distribution on the unit sphere.

    Args:
        concentration: Combined location and concentration vector; its direction
            is the mean direction and its norm is the concentration
    """
    def __init__(self, concentration, validate_args=None): ...

class SineBivariateVonMises(Distribution):
    """Sine bivariate von Mises distribution."""
    def __init__(self, phi_loc, psi_loc, phi_concentration, psi_concentration,
                 correlation, validate_args=None): ...

class SineSkewed(Distribution):
    """Sine-skewed circular distribution."""
    def __init__(self, base_dist, skewness, validate_args=None): ...
```

### Mixture Distributions

Finite mixture models for modeling multimodal data.

```python { .api }
class Mixture(Distribution):
    """
    Finite mixture distribution.

    Args:
        mixing_distribution: Categorical mixing distribution
        component_distributions: List of component distributions
    """
    def __init__(self, mixing_distribution, component_distributions, validate_args=None): ...

class MixtureGeneral(Distribution):
    """Mixture whose components are given as a list and may come from different families."""
    def __init__(self, mixing_distribution, component_distributions,
                 support=None, validate_args=None): ...

class MixtureSameFamily(Distribution):
    """
    Mixture of distributions from the same family.

    Args:
        mixing_distribution: Categorical mixing distribution
        component_distribution: Single batched distribution whose rightmost
            batch dimension indexes the components
    """
    def __init__(self, mixing_distribution, component_distribution, validate_args=None): ...
```

### Truncated Distributions

Distributions whose support is restricted by truncating a base distribution.

```python { .api }
class TruncatedDistribution(Distribution):
    """
    Generic truncated distribution.

    Args:
        base_dist: Base distribution to truncate
        low: Lower truncation bound
        high: Upper truncation bound
    """
    def __init__(self, base_dist, low=None, high=None, validate_args=None): ...

class LeftTruncatedDistribution(Distribution):
    """Left-truncated distribution (truncated below)."""
    def __init__(self, base_dist, low, validate_args=None): ...

class RightTruncatedDistribution(Distribution):
    """Right-truncated distribution (truncated above)."""
    def __init__(self, base_dist, high, validate_args=None): ...

class TwoSidedTruncatedDistribution(Distribution):
    """Two-sided truncated distribution."""
    def __init__(self, base_dist, low, high, validate_args=None): ...

class TruncatedNormal(Distribution):
    """Truncated normal distribution."""
    def __init__(self, loc=0.0, scale=1.0, low=None, high=None, validate_args=None): ...

class TruncatedCauchy(Distribution):
    """Truncated Cauchy distribution."""
    def __init__(self, loc=0.0, scale=1.0, low=None, high=None, validate_args=None): ...

class LowerTruncatedPowerLaw(Distribution):
    """Lower truncated power law distribution."""
    def __init__(self, alpha, scale, validate_args=None): ...

class DoublyTruncatedPowerLaw(Distribution):
    """Doubly truncated power law distribution."""
    def __init__(self, alpha, low, high, validate_args=None): ...
```

### Copula Distributions

Copula-based distributions for modeling dependence structures.

```python { .api }
class GaussianCopula(Distribution):
    """
    Gaussian copula distribution.

    Args:
        marginal_dist: Marginal distribution whose last batch axis is coupled
        correlation_matrix: Correlation matrix of the coupling Gaussian (optional)
        correlation_cholesky: Cholesky factor of that correlation matrix (optional)
    """
    def __init__(self, marginal_dist, correlation_matrix=None,
                 correlation_cholesky=None, validate_args=None): ...

class GaussianCopulaBeta(Distribution):
    """Gaussian copula with Beta marginals."""
    def __init__(self, concentration1, concentration0, correlation_matrix=None,
                 correlation_cholesky=None, validate_args=None): ...
```

### Special Distributions

Utility distributions for specific modeling needs.

```python { .api }
class Delta(Distribution):
    """
    Point mass (Dirac delta) distribution.

    Args:
        v: Point mass location
        log_density: Log density value at the point
        event_dim: Number of event dimensions
    """
    def __init__(self, v, log_density=0.0, event_dim=0, validate_args=None): ...

class Unit(Distribution):
    """Unit distribution for adding log probability factors."""
    def __init__(self, log_factor, validate_args=None): ...

class ImproperUniform(Distribution):
    """
    Improper uniform distribution with a flat (zero) log-density over the
    given support. Its sample method is not implemented.

    Args:
        support: Support constraint
        batch_shape: Batch shape
        event_shape: Event shape
    """
    def __init__(self, support, batch_shape, event_shape, validate_args=None): ...

class CirculantNormal(Distribution):
    """Normal distribution with circulant covariance matrix."""
    def __init__(self, loc, circulant_cov, validate_args=None): ...

class ZeroSumNormal(Distribution):
    """Normal distribution whose draws sum to zero along the event dimensions."""
    def __init__(self, scale, event_shape, validate_args=None): ...

class RelaxedBernoulli(Distribution):
    """Relaxed Bernoulli distribution (continuous relaxation)."""
    def __init__(self, temperature, probs=None, logits=None, validate_args=None): ...
```

### Distribution Utilities

```python { .api }
def enable_validation(is_validate: bool = True) -> None:
    """Enable or disable distribution parameter validation globally."""

@contextmanager
def validation_enabled(is_validate: bool = True):
    """Context manager that temporarily enables (or disables) validation."""

def kl_divergence(p: Distribution, q: Distribution) -> Array:
    """Compute KL(p || q) for a pair of distributions with a registered closed form."""

def biject_to(constraint) -> Transform:
    """Get bijective transform to given constraint."""
```

## Types

```python { .api }
from typing import Optional, Union, Callable, Sequence
from jax import Array
import jax.numpy as jnp
from numpyro.distributions.constraints import Constraint
from numpyro.distributions.transforms import Transform

ArrayLike = Union[Array, jnp.ndarray, float, int]

# Distribution parameter types
Concentration = ArrayLike  # Positive real numbers
Rate = ArrayLike           # Positive real numbers
Scale = ArrayLike          # Positive real numbers
Probability = ArrayLike    # Numbers in [0, 1]
Logits = ArrayLike         # Real numbers
Location = ArrayLike       # Real numbers
```