NumPyro: a lightweight probabilistic programming library that provides a NumPy backend for Pyro, using JAX for automatic differentiation and JIT compilation.
—
NumPyro provides a comprehensive collection of 150+ probability distributions organized across multiple categories. All distributions inherit from a common base class and provide consistent interfaces for sampling, log probability computation, and parameter validation.
Base classes: the core `Distribution` interface, plus wrapper distributions that modify an existing (base) distribution.
class Distribution:
    """
    Base class for probability distributions in NumPyro.

    Properties:
        - batch_shape: Shape of batch dimensions
        - event_shape: Shape of event dimensions
        - support: Support constraint for the distribution
        - has_rsample: Whether reparameterized sampling is supported
    """

    def __init__(self, batch_shape=(), event_shape=(), validate_args=None): ...

    # ``Array`` is quoted in the annotations below because ``jax.Array`` is
    # only imported near the end of this file. An unquoted name is evaluated
    # when each ``def`` executes and would raise NameError at class creation.
    def sample(self, key, sample_shape=()) -> "Array": ...

    def log_prob(self, value) -> "Array": ...

    def cdf(self, value) -> "Array": ...

    def icdf(self, q) -> "Array": ...

    def expand(self, batch_shape) -> 'Distribution': ...

    def mask(self, mask) -> 'MaskedDistribution': ...
class ExpandedDistribution(Distribution):
    """
    Distribution with expanded batch dimensions.

    Args:
        base_distribution: Distribution being expanded
        batch_shape: Target batch shape
    """

    def __init__(self, base_distribution: Distribution, batch_shape: tuple): ...


class Independent(Distribution):
    """
    Reinterprets batch dimensions as event dimensions.

    Args:
        base_distribution: Distribution being wrapped
        reinterpreted_batch_ndims: Number of batch dimensions to treat as event dimensions
    """

    def __init__(self, base_distribution: Distribution, reinterpreted_batch_ndims: int): ...


class TransformedDistribution(Distribution):
    """
    Distribution transformed by a bijective transformation.

    Args:
        base_distribution: Distribution being transformed
        transforms: Bijective transformation(s) applied to samples of the base distribution
    """

    def __init__(self, base_distribution: Distribution, transforms): ...


class MaskedDistribution(Distribution):
    """
    Distribution with masked values.

    Args:
        base_distribution: Distribution being masked
        mask: Mask selecting which values are kept
    """

    def __init__(self, base_distribution: Distribution, mask): ...


class FoldedDistribution(Distribution):
    """
    Distribution folded around zero by taking the absolute value.

    Args:
        base_distribution: Distribution being folded
    """

    def __init__(self, base_distribution: Distribution): ...


# Continuous probability distributions for modeling real-valued random variables.
class Normal(Distribution):
    """
    Gaussian (normal) distribution.

    Args:
        loc: Mean of the distribution.
        scale: Standard deviation of the distribution.
    """

    def __init__(self, loc=0.0, scale=1.0, validate_args=None):
        ...


class Uniform(Distribution):
    """
    Uniform distribution on the interval between ``low`` and ``high``.

    Args:
        low: Lower end of the interval.
        high: Upper end of the interval.
    """

    def __init__(self, low=0.0, high=1.0, validate_args=None):
        ...


class Exponential(Distribution):
    """
    Exponential distribution.

    Args:
        rate: Rate parameter, i.e. the inverse of the scale.
    """

    def __init__(self, rate=1.0, validate_args=None):
        ...


class Laplace(Distribution):
    """
    Laplace distribution, also known as the double exponential distribution.

    Args:
        loc: Location parameter, equal to the mean.
        scale: Scale parameter.
    """

    def __init__(self, loc=0.0, scale=1.0, validate_args=None):
        ...


class Logistic(Distribution):
    """
    Logistic distribution.

    Args:
        loc: Location parameter.
        scale: Scale parameter.
    """

    def __init__(self, loc=0.0, scale=1.0, validate_args=None):
        ...


class LogNormal(Distribution):
    """
    Log-normal distribution.

    Args:
        loc: Mean of the underlying normal distribution.
        scale: Standard deviation of the underlying normal distribution.
    """

    def __init__(self, loc=0.0, scale=1.0, validate_args=None):
        ...


class Cauchy(Distribution):
    """
    Cauchy distribution.

    Args:
        loc: Location parameter, equal to the median.
        scale: Scale parameter.
    """

    def __init__(self, loc=0.0, scale=1.0, validate_args=None):
        ...
class StudentT(Distribution):
    """
    Student's t-distribution.

    Args:
        df: Degrees of freedom
        loc: Location parameter (the mean when df > 1)
        scale: Scale parameter
    """

    def __init__(self, df, loc=0.0, scale=1.0, validate_args=None): ...


class Beta(Distribution):
    """
    Beta distribution.

    Args:
        concentration1: First concentration parameter (alpha)
        concentration0: Second concentration parameter (beta)
    """

    def __init__(self, concentration1, concentration0, validate_args=None): ...
class BetaProportion(Distribution):
    """
    Beta distribution expressed in terms of its mean and concentration.

    Args:
        mean: Mean of the distribution.
        concentration: Total concentration parameter.
    """

    def __init__(self, mean, concentration, validate_args=None):
        ...


class Gamma(Distribution):
    """
    Gamma distribution.

    Args:
        concentration: Shape parameter (alpha).
        rate: Rate parameter (beta), the inverse of the scale.
    """

    def __init__(self, concentration, rate=1.0, validate_args=None):
        ...


class InverseGamma(Distribution):
    """
    Inverse gamma distribution.

    Args:
        concentration: Shape parameter.
        rate: Rate parameter.
    """

    def __init__(self, concentration, rate, validate_args=None):
        ...


class Chi2(Distribution):
    """
    Chi-squared distribution.

    Args:
        df: Degrees of freedom.
    """

    def __init__(self, df, validate_args=None):
        ...
class Dirichlet(Distribution):
    """
    Dirichlet distribution over the probability simplex.

    Args:
        concentration: Concentration parameters
    """

    def __init__(self, concentration, validate_args=None): ...


class MultivariateNormal(Distribution):
    """
    Multivariate normal distribution.

    The covariance structure is given by one of ``covariance_matrix``,
    ``precision_matrix`` or ``scale_tril``.

    Args:
        loc: Mean vector
        covariance_matrix: Covariance matrix (optional)
        precision_matrix: Precision matrix (optional)
        scale_tril: Lower triangular Cholesky factor (optional)
    """

    def __init__(self, loc, covariance_matrix=None, precision_matrix=None,
                 scale_tril=None, validate_args=None): ...
class LowRankMultivariateNormal(Distribution):
    """
    Multivariate normal whose covariance has a low-rank plus diagonal form.

    Args:
        loc: Mean vector.
        cov_factor: Low-rank factor of the covariance.
        cov_diag: Diagonal component of the covariance.
    """

    def __init__(self, loc, cov_factor, cov_diag, validate_args=None):
        ...


class MultivariateStudentT(Distribution):
    """
    Multivariate Student's t-distribution.

    Args:
        df: Degrees of freedom.
        loc: Location vector.
        scale_tril: Lower triangular scale matrix.
    """

    def __init__(self, df, loc=0.0, scale_tril=None, validate_args=None):
        ...


class MatrixNormal(Distribution):
    """
    Matrix normal distribution.

    Args:
        loc: Mean matrix.
        scale_tril_row: Lower triangular scale matrix for the rows.
        scale_tril_col: Lower triangular scale matrix for the columns.
    """

    def __init__(self, loc, scale_tril_row=None, scale_tril_col=None, validate_args=None):
        ...


class Wishart(Distribution):
    """
    Wishart distribution over positive definite matrices.

    Args:
        df: Degrees of freedom.
        scale_tril: Lower triangular scale matrix.
    """

    def __init__(self, df, scale_tril, validate_args=None):
        ...


class LKJ(Distribution):
    """
    LKJ distribution over correlation matrices.

    Args:
        dimension: Size of the correlation matrices.
        concentration: Concentration parameter.
    """

    def __init__(self, dimension, concentration, validate_args=None):
        ...
class LKJCholesky(Distribution):
    """
    LKJ distribution over Cholesky factors of correlation matrices.

    Args:
        dimension: Dimension of the correlation matrices
        concentration: Concentration parameter
    """

    def __init__(self, dimension, concentration, validate_args=None): ...


class HalfNormal(Distribution):
    """
    Half-normal distribution (normal folded at zero).

    Args:
        scale: Scale parameter
    """

    def __init__(self, scale=1.0, validate_args=None): ...
class HalfCauchy(Distribution):
    """Half-Cauchy distribution, i.e. a Cauchy distribution folded at zero."""

    def __init__(self, scale=1.0, validate_args=None):
        ...


class Pareto(Distribution):
    """
    Pareto distribution.

    Args:
        scale: Scale parameter, the minimum attainable value.
        alpha: Shape parameter.
    """

    def __init__(self, scale, alpha, validate_args=None):
        ...


class Weibull(Distribution):
    """
    Weibull distribution.

    Args:
        scale: Scale parameter.
        concentration: Shape parameter.
    """

    def __init__(self, scale, concentration, validate_args=None):
        ...


class Gumbel(Distribution):
    """
    Gumbel distribution.

    Args:
        loc: Location parameter.
        scale: Scale parameter.
    """

    def __init__(self, loc=0.0, scale=1.0, validate_args=None):
        ...


class Levy(Distribution):
    """
    Lévy distribution.

    Args:
        loc: Location parameter.
        scale: Scale parameter.
    """

    def __init__(self, loc=0.0, scale=1.0, validate_args=None):
        ...


class Kumaraswamy(Distribution):
    """
    Kumaraswamy distribution.

    Args:
        concentration1: First shape parameter.
        concentration0: Second shape parameter.
    """

    def __init__(self, concentration1, concentration0, validate_args=None):
        ...


class Gompertz(Distribution):
    """
    Gompertz distribution.

    Args:
        scale: Scale parameter.
        concentration: Shape parameter.
    """

    def __init__(self, scale, concentration, validate_args=None):
        ...


class AsymmetricLaplace(Distribution):
    """
    Asymmetric Laplace distribution.

    Args:
        loc: Location parameter.
        scale: Scale parameter.
        asymmetry: Asymmetry parameter.
    """

    def __init__(self, loc, scale, asymmetry, validate_args=None):
        ...
class SoftLaplace(Distribution):
    """
    Smooth (continuous) distribution with Laplace-like tail behavior.

    Args:
        loc: Location parameter
        scale: Scale parameter
    """

    def __init__(self, loc, scale, validate_args=None): ...


class GaussianRandomWalk(Distribution):
    """
    Gaussian random walk distribution.

    Args:
        scale: Scale of each step
        num_steps: Number of time steps
    """

    def __init__(self, scale=1.0, num_steps=1, validate_args=None): ...
class GaussianStateSpace(Distribution):
    """
    Linear Gaussian state space model.

    Args:
        initial_state_mean: Initial state mean
        initial_state_cov: Initial state covariance
        transition_matrix: State transition matrix
        transition_cov: Transition noise covariance
        observation_matrix: Observation matrix
        observation_cov: Observation noise covariance
    """

    def __init__(self, initial_state_mean, initial_state_cov, transition_matrix,
                 transition_cov, observation_matrix, observation_cov, validate_args=None): ...


class EulerMaruyama(Distribution):
    """
    Distribution over paths of an SDE discretized with the Euler-Maruyama method.

    Args:
        drift: Drift function
        diffusion: Diffusion function
        dt: Time step size
        num_steps: Number of steps
    """

    def __init__(self, drift, diffusion, dt, num_steps, validate_args=None): ...


class CAR(Distribution):
    """
    Conditional Autoregressive (CAR) distribution.

    Args:
        loc: Location parameter
        precision: Precision parameter
        adjacency_matrix: Spatial adjacency matrix
    """

    def __init__(self, loc, precision, adjacency_matrix, validate_args=None): ...


# Discrete probability distributions for modeling integer-valued random variables.
class Bernoulli(Distribution):
    """
    Bernoulli distribution.

    Args:
        probs: Probability of success (optional).
        logits: Log-odds of success (optional).
    """

    def __init__(self, probs=None, logits=None, validate_args=None):
        ...


class Categorical(Distribution):
    """
    Categorical distribution over integer category indices.

    Args:
        probs: Probability of each category (optional).
        logits: Log-probability of each category (optional).
    """

    def __init__(self, probs=None, logits=None, validate_args=None):
        ...


class Binomial(Distribution):
    """
    Binomial distribution.

    Args:
        total_count: Number of trials.
        probs: Probability of success (optional).
        logits: Log-odds of success (optional).
    """

    def __init__(self, total_count=1, probs=None, logits=None, validate_args=None):
        ...


class Multinomial(Distribution):
    """
    Multinomial distribution.

    Args:
        total_count: Number of trials.
        probs: Probability of each category (optional).
        logits: Log-probability of each category (optional).
    """

    def __init__(self, total_count=1, probs=None, logits=None, validate_args=None):
        ...


class Poisson(Distribution):
    """
    Poisson distribution.

    Args:
        rate: Rate parameter, equal to the mean.
    """

    def __init__(self, rate, validate_args=None):
        ...


class Geometric(Distribution):
    """
    Geometric distribution, counting failures before the first success.

    Args:
        probs: Probability of success (optional).
        logits: Log-odds of success (optional).
    """

    def __init__(self, probs=None, logits=None, validate_args=None):
        ...
class DiscreteUniform(Distribution):
    """
    Discrete uniform distribution over the integers ``low, low + 1, ..., high``.

    Both bounds are inclusive. With the defaults ``low=0, high=1`` this is a
    fair choice between 0 and 1, not a point mass at 0.

    Args:
        low: Lower bound (inclusive)
        high: Upper bound (inclusive)
    """

    def __init__(self, low=0, high=1, validate_args=None): ...
class OrderedLogistic(Distribution):
    """
    Ordered logistic distribution for ordinal data.

    Args:
        predictor: Linear predictor
        cutpoints: Ordered cutpoints
    """

    def __init__(self, predictor, cutpoints, validate_args=None): ...


class ZeroInflatedDistribution(Distribution):
    """
    Zero-inflated wrapper for any discrete distribution.

    Args:
        base_dist: Base discrete distribution
        gate: Probability of extra zeros (optional)
        gate_logits: Log-odds of extra zeros, an alternative to ``gate`` (optional)
    """

    def __init__(self, base_dist, gate=None, gate_logits=None, validate_args=None): ...
class ZeroInflatedPoisson(Distribution):
    """
    Zero-inflated Poisson distribution.

    Args:
        rate: Poisson rate parameter
        gate: Probability of extra zeros (optional)
        gate_logits: Log-odds of extra zeros, an alternative to ``gate`` (optional)
    """

    def __init__(self, rate, gate=None, gate_logits=None, validate_args=None): ...


# Compound distributions: a likelihood whose parameter has a conjugate prior
# (e.g. beta-binomial, gamma-Poisson), enabling efficient Bayesian inference.
class BetaBinomial(Distribution):
    """
    Beta-binomial distribution (binomial with a beta prior on the probability).

    Args:
        concentration1: Beta alpha parameter
        concentration0: Beta beta parameter
        total_count: Number of trials
    """

    def __init__(self, concentration1, concentration0, total_count=1, validate_args=None): ...


class DirichletMultinomial(Distribution):
    """
    Dirichlet-multinomial distribution.

    Args:
        concentration: Dirichlet concentration parameters
        total_count: Number of trials
    """

    def __init__(self, concentration, total_count=1, validate_args=None): ...


class GammaPoisson(Distribution):
    """
    Gamma-Poisson (negative binomial) distribution.

    Args:
        concentration: Gamma shape parameter
        rate: Gamma rate parameter
    """

    def __init__(self, concentration, rate, validate_args=None): ...


class NegativeBinomial2(Distribution):
    """
    Negative binomial distribution (NB2 parameterization).

    Args:
        mean: Mean parameter
        concentration: Concentration parameter
    """

    def __init__(self, mean, concentration, validate_args=None): ...


class ZeroInflatedNegativeBinomial2(Distribution):
    """
    Zero-inflated negative binomial distribution (NB2 parameterization).

    Args:
        mean: Mean parameter
        concentration: Concentration parameter
        gate: Probability of extra zeros (optional)
        gate_logits: Log-odds of extra zeros, an alternative to ``gate`` (optional)
    """

    def __init__(self, mean, concentration, gate=None, gate_logits=None, validate_args=None): ...


# Distributions for circular and spherical data.
class VonMises(Distribution):
    """
    Von Mises distribution for circular data.

    Args:
        loc: Mean direction
        concentration: Concentration parameter
    """

    def __init__(self, loc, concentration, validate_args=None): ...


class ProjectedNormal(Distribution):
    """
    Projected normal distribution on the unit sphere.

    Args:
        concentration: Concentration vector
    """

    def __init__(self, concentration, validate_args=None): ...


class SineBivariateVonMises(Distribution):
    """
    Sine bivariate von Mises distribution.

    Args:
        phi_loc: Location of the first angle
        psi_loc: Location of the second angle
        phi_concentration: Concentration of the first angle
        psi_concentration: Concentration of the second angle
        correlation: Correlation between the two angles
    """

    def __init__(self, phi_loc, psi_loc, phi_concentration, psi_concentration,
                 correlation, validate_args=None): ...


class SineSkewed(Distribution):
    """
    Sine-skewed circular distribution.

    Args:
        base_dist: Base circular distribution being skewed
        skewness: Skewness parameter
    """

    def __init__(self, base_dist, skewness, validate_args=None): ...


# Finite mixture models for modeling multi-modal data.
class Mixture(Distribution):
    """
    Finite mixture distribution.

    Args:
        mixing_distribution: Categorical mixing distribution
        component_distributions: List of component distributions
    """

    def __init__(self, mixing_distribution, component_distributions, validate_args=None): ...


class MixtureGeneral(Distribution):
    """
    General mixture distribution with flexible component selection.

    Args:
        mixing_distribution: Categorical mixing distribution
        component_distributions: List of component distributions
        support: Support constraint of the mixture (optional)
    """

    def __init__(self, mixing_distribution, component_distributions,
                 support=None, validate_args=None): ...


class MixtureSameFamily(Distribution):
    """
    Mixture of distributions from the same family.

    Args:
        mixing_distribution: Categorical mixing distribution
        component_distribution: Batch of component distributions
    """

    def __init__(self, mixing_distribution, component_distribution, validate_args=None): ...


# Distributions whose support is restricted by truncation.
class TruncatedDistribution(Distribution):
    """
    Generic truncated distribution.

    Args:
        base_distribution: Base distribution to truncate
        low: Lower truncation bound (optional)
        high: Upper truncation bound (optional)
    """

    def __init__(self, base_distribution, low=None, high=None, validate_args=None): ...


class LeftTruncatedDistribution(Distribution):
    """Left-truncated distribution (truncated below ``low``)."""

    def __init__(self, base_distribution, low, validate_args=None): ...


class RightTruncatedDistribution(Distribution):
    """Right-truncated distribution (truncated above ``high``)."""

    def __init__(self, base_distribution, high, validate_args=None): ...


class TwoSidedTruncatedDistribution(Distribution):
    """Distribution truncated both below ``low`` and above ``high``."""

    def __init__(self, base_distribution, low, high, validate_args=None): ...


class TruncatedNormal(Distribution):
    """Truncated normal distribution."""

    def __init__(self, loc=0.0, scale=1.0, low=None, high=None, validate_args=None): ...


class TruncatedCauchy(Distribution):
    """Truncated Cauchy distribution."""

    def __init__(self, loc=0.0, scale=1.0, low=None, high=None, validate_args=None): ...


class LowerTruncatedPowerLaw(Distribution):
    """Lower truncated power law distribution."""

    def __init__(self, alpha, scale, validate_args=None): ...


class DoublyTruncatedPowerLaw(Distribution):
    """Power law distribution truncated between ``low`` and ``high``."""

    def __init__(self, alpha, low, high, validate_args=None): ...


# Copula-based distributions for modeling dependence structures.
class GaussianCopula(Distribution):
    """
    Gaussian copula distribution.

    Args:
        correlation_matrix: Correlation matrix
        marginals: List of marginal distributions
    """

    def __init__(self, correlation_matrix, marginals, validate_args=None): ...


class GaussianCopulaBeta(Distribution):
    """
    Gaussian copula with Beta marginals.

    Args:
        correlation_matrix: Correlation matrix
        concentration1: Alpha parameters of the Beta marginals
        concentration0: Beta parameters of the Beta marginals
    """

    def __init__(self, correlation_matrix, concentration1, concentration0, validate_args=None): ...


# Utility distributions for specific modeling needs.
class Delta(Distribution):
    """
    Point mass (Dirac delta) distribution.

    Args:
        v: Point mass location
        log_density: Log density value at the point
        event_dim: Number of event dimensions
    """

    def __init__(self, v, log_density=0.0, event_dim=0, validate_args=None): ...


class Unit(Distribution):
    """
    Unit distribution for adding log probability factors.

    Args:
        log_factor: Log probability factor to add
    """

    def __init__(self, log_factor, validate_args=None): ...


class ImproperUniform(Distribution):
    """
    Improper uniform distribution over the set given by ``support``.

    Args:
        support: Support constraint
        batch_shape: Batch shape
        event_shape: Event shape
    """

    def __init__(self, support, batch_shape, event_shape, validate_args=None): ...


class CirculantNormal(Distribution):
    """Normal distribution with a circulant covariance matrix."""

    def __init__(self, loc, circulant_cov, validate_args=None): ...


class ZeroSumNormal(Distribution):
    """Normal distribution with a zero-sum constraint."""

    def __init__(self, scale, validate_args=None): ...
class RelaxedBernoulli(Distribution):
    """
    Relaxed Bernoulli distribution (continuous relaxation).

    Args:
        temperature: Relaxation temperature
        probs: Success probability (optional)
        logits: Log-odds (optional)
    """

    def __init__(self, temperature, probs=None, logits=None, validate_args=None): ...


def enable_validation(is_validate: bool) -> None:
    """Enable or disable distribution parameter validation."""


def validation_enabled() -> bool:
    """Check whether distribution validation is currently enabled."""


# The return annotations below are quoted because ``Array`` is imported, and
# ``Transform`` is defined, only further down this file. Unquoted names would
# raise NameError when these ``def`` statements execute.
def kl_divergence(p: Distribution, q: Distribution) -> "Array":
    """Compute the KL divergence between two distributions."""


def biject_to(constraint) -> "Transform":
    """Get the bijective transform onto the given constraint."""


from typing import Optional, Union, Callable, Sequence
from jax import Array
import jax.numpy as jnp

ArrayLike = Union[Array, jnp.ndarray, float, int]

# NOTE(review): ``numpyro`` is never imported in this file, so the two
# aliases below raise NameError as written. Import the package before this
# point; confirm the exact module paths against the installed numpyro.
Constraint = numpyro.distributions.constraints.Constraint
Transform = numpyro.distributions.transforms.Transform

# Distribution parameter types
Concentration = ArrayLike  # Positive real numbers
Rate = ArrayLike  # Positive real numbers
Scale = ArrayLike  # Positive real numbers
Probability = ArrayLike  # Numbers in [0, 1]
Logits = ArrayLike  # Real numbers
Location = ArrayLike  # Real numbers

# Install with the Tessl CLI:
#     npx tessl i tessl/pypi-numpyro