CtrlK
Blog | Docs | Log in | Get started
Tessl Logo

tessl/pypi-numpyro

Lightweight probabilistic programming library with NumPy backend for Pyro, powered by JAX for automatic differentiation and JIT compilation.

Pending
Overview
Eval results
Files

docs/distributions.md

Distributions

NumPyro provides a comprehensive collection of 150+ probability distributions organized across multiple categories. All distributions inherit from a common base class and provide consistent interfaces for sampling, log probability computation, and parameter validation.

Capabilities

Base Distribution Classes

Foundation classes that provide the core distribution interface and specialized distribution wrappers.

class Distribution:
    """
    Base class for probability distributions in NumPyro.

    Properties:
    - batch_shape: Shape of batch dimensions
    - event_shape: Shape of event dimensions
    - support: Support constraint for the distribution
    - has_rsample: Whether reparameterized sampling is supported

    Args:
        batch_shape: Shape of batch dimensions (defaults to an empty tuple)
        event_shape: Shape of event dimensions (defaults to an empty tuple)
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, batch_shape=(), event_shape=(), validate_args=None): ...
    # Draw samples. `key` is presumably a JAX PRNG key and `sample_shape`
    # the leading shape of the draw -- confirm against upstream.
    def sample(self, key, sample_shape=()) -> Array: ...
    # Log probability (density or mass) evaluated at `value`.
    def log_prob(self, value) -> Array: ...
    # Cumulative distribution function evaluated at `value`.
    def cdf(self, value) -> Array: ...
    # Inverse cumulative distribution function evaluated at `q`.
    def icdf(self, q) -> Array: ...
    # Return a distribution with batch dimensions expanded to `batch_shape`.
    def expand(self, batch_shape) -> 'Distribution': ...
    # Return a MaskedDistribution built from `mask`.
    def mask(self, mask) -> 'MaskedDistribution': ...

class ExpandedDistribution(Distribution):
    """
    Distribution with expanded batch dimensions.

    Args:
        base_distribution: Distribution whose batch dimensions are expanded
        batch_shape: Target batch shape
    """
    # NOTE(review): upstream NumPyro appears to name the first parameter
    # `base_dist` -- verify before passing it by keyword.
    def __init__(self, base_distribution: Distribution, batch_shape: tuple): ...

class Independent(Distribution):
    """
    Reinterprets batch dimensions as event dimensions.

    Args:
        base_distribution: Distribution whose batch dimensions are reinterpreted
        reinterpreted_batch_ndims: Number of batch dimensions to treat as
            event dimensions
    """
    # NOTE(review): upstream NumPyro appears to name the first parameter
    # `base_dist` -- verify before passing it by keyword.
    def __init__(self, base_distribution: Distribution, reinterpreted_batch_ndims: int): ...

class TransformedDistribution(Distribution):
    """
    Distribution transformed by a bijective transformation.

    Args:
        base_distribution: Distribution to transform
        transforms: Transformation(s) applied to the base distribution
    """
    def __init__(self, base_distribution: Distribution, transforms): ...

class MaskedDistribution(Distribution):
    """
    Distribution with masked values.

    Args:
        base_distribution: Distribution to mask
        mask: Mask to apply (presumably boolean -- confirm against upstream)
    """
    # NOTE(review): upstream NumPyro appears to name the first parameter
    # `base_dist` -- verify before passing it by keyword.
    def __init__(self, base_distribution: Distribution, mask): ...

class FoldedDistribution(Distribution):
    """
    Distribution folded around zero by taking absolute value.

    Args:
        base_distribution: Distribution to fold
    """
    # NOTE(review): upstream NumPyro appears to name the parameter
    # `base_dist` -- verify before passing it by keyword.
    def __init__(self, base_distribution: Distribution): ...

Continuous Distributions

Continuous probability distributions for modeling real-valued random variables.

Basic Continuous Distributions

class Normal(Distribution):
    """
    Normal (Gaussian) distribution.

    Args:
        loc: Mean of the distribution
        scale: Standard deviation of the distribution
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, loc=0.0, scale=1.0, validate_args=None): ...

class Uniform(Distribution):
    """
    Uniform distribution over an interval.

    Args:
        low: Lower bound of the distribution
        high: Upper bound of the distribution
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, low=0.0, high=1.0, validate_args=None): ...

class Exponential(Distribution):
    """
    Exponential distribution.

    Args:
        rate: Rate parameter (inverse scale)
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, rate=1.0, validate_args=None): ...

class Laplace(Distribution):
    """
    Laplace (double exponential) distribution.

    Args:
        loc: Location parameter (mean)
        scale: Scale parameter
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, loc=0.0, scale=1.0, validate_args=None): ...

class Logistic(Distribution):
    """
    Logistic distribution.

    Args:
        loc: Location parameter
        scale: Scale parameter
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, loc=0.0, scale=1.0, validate_args=None): ...

class LogNormal(Distribution):
    """
    Log-normal distribution.

    Args:
        loc: Mean of underlying normal distribution
        scale: Standard deviation of underlying normal distribution
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, loc=0.0, scale=1.0, validate_args=None): ...

class Cauchy(Distribution):
    """
    Cauchy distribution.

    Args:
        loc: Location parameter (median)
        scale: Scale parameter
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, loc=0.0, scale=1.0, validate_args=None): ...

class StudentT(Distribution):
    """
    Student's t-distribution.

    Args:
        df: Degrees of freedom (required; no default)
        loc: Location parameter (mean when df > 1)
        scale: Scale parameter
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, df, loc=0.0, scale=1.0, validate_args=None): ...

Beta and Gamma Family

class Beta(Distribution):
    """
    Beta distribution.

    Args:
        concentration1: First concentration parameter (alpha)
        concentration0: Second concentration parameter (beta)
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, concentration1, concentration0, validate_args=None): ...

class BetaProportion(Distribution):
    """
    Beta distribution parameterized by mean and concentration.

    Args:
        mean: Mean of the distribution
        concentration: Total concentration parameter
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, mean, concentration, validate_args=None): ...

class Gamma(Distribution):
    """
    Gamma distribution.

    Args:
        concentration: Shape parameter (alpha)
        rate: Rate parameter (beta), inverse of scale
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, concentration, rate=1.0, validate_args=None): ...

class InverseGamma(Distribution):
    """
    Inverse Gamma distribution.

    Args:
        concentration: Shape parameter
        rate: Rate parameter
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, concentration, rate, validate_args=None): ...

class Chi2(Distribution):
    """
    Chi-squared distribution.

    Args:
        df: Degrees of freedom
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, df, validate_args=None): ...

class Dirichlet(Distribution):
    """
    Dirichlet distribution over probability simplexes.

    Args:
        concentration: Concentration parameters
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, concentration, validate_args=None): ...

Multivariate Continuous Distributions

class MultivariateNormal(Distribution):
    """
    Multivariate normal distribution.

    The covariance structure can be given in three alternative forms; each
    of them defaults to None. Presumably only one should be supplied --
    confirm against upstream.

    Args:
        loc: Mean vector
        covariance_matrix: Covariance matrix (optional)
        precision_matrix: Precision matrix (optional)
        scale_tril: Lower triangular Cholesky factor (optional)
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, loc, covariance_matrix=None, precision_matrix=None, 
                scale_tril=None, validate_args=None): ...

class LowRankMultivariateNormal(Distribution):
    """
    Low-rank multivariate normal distribution.

    Args:
        loc: Mean vector
        cov_factor: Low-rank covariance factor
        cov_diag: Diagonal covariance component
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, loc, cov_factor, cov_diag, validate_args=None): ...

class MultivariateStudentT(Distribution):
    """
    Multivariate Student's t-distribution.

    Args:
        df: Degrees of freedom
        loc: Location vector
        scale_tril: Lower triangular scale matrix (defaults to None)
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, df, loc=0.0, scale_tril=None, validate_args=None): ...

class MatrixNormal(Distribution):
    """
    Matrix normal distribution.

    Args:
        loc: Mean matrix
        scale_tril_row: Row scale matrix (lower triangular)
        scale_tril_col: Column scale matrix (lower triangular)
        validate_args: Whether to validate the parameters (optional)
    """
    # NOTE(review): upstream NumPyro appears to spell the third parameter
    # `scale_tril_column` -- verify before passing it by keyword.
    def __init__(self, loc, scale_tril_row=None, scale_tril_col=None, validate_args=None): ...

class Wishart(Distribution):
    """
    Wishart distribution over positive definite matrices.

    Args:
        df: Degrees of freedom
        scale_tril: Lower triangular scale matrix
        validate_args: Whether to validate the parameters (optional)
    """
    # NOTE(review): upstream NumPyro appears to declare
    # (concentration, scale_matrix=None, rate_matrix=None, scale_tril=None)
    # rather than (df, scale_tril) -- verify against the installed version.
    def __init__(self, df, scale_tril, validate_args=None): ...

class LKJ(Distribution):
    """
    LKJ distribution over correlation matrices.

    Args:
        dimension: Dimension of correlation matrices
        concentration: Concentration parameter
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, dimension, concentration, validate_args=None): ...

class LKJCholesky(Distribution):
    """
    LKJ distribution over Cholesky factors of correlation matrices.

    Args:
        dimension: Dimension of correlation matrices
        concentration: Concentration parameter
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, dimension, concentration, validate_args=None): ...

Specialized Continuous Distributions

class HalfNormal(Distribution):
    """
    Half-normal distribution (normal folded at zero).

    Args:
        scale: Scale parameter
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, scale=1.0, validate_args=None): ...

class HalfCauchy(Distribution):
    """
    Half-Cauchy distribution (Cauchy folded at zero).

    Args:
        scale: Scale parameter
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, scale=1.0, validate_args=None): ...

class Pareto(Distribution):
    """
    Pareto distribution.

    Args:
        scale: Scale parameter (minimum value)
        alpha: Shape parameter
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, scale, alpha, validate_args=None): ...

class Weibull(Distribution):
    """
    Weibull distribution.

    Args:
        scale: Scale parameter
        concentration: Shape parameter
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, scale, concentration, validate_args=None): ...

class Gumbel(Distribution):
    """
    Gumbel distribution.

    Args:
        loc: Location parameter
        scale: Scale parameter
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, loc=0.0, scale=1.0, validate_args=None): ...

class Levy(Distribution):
    """
    Lévy distribution.

    Args:
        loc: Location parameter
        scale: Scale parameter
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, loc=0.0, scale=1.0, validate_args=None): ...

class Kumaraswamy(Distribution):
    """
    Kumaraswamy distribution.

    Args:
        concentration1: First shape parameter
        concentration0: Second shape parameter
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, concentration1, concentration0, validate_args=None): ...

class Gompertz(Distribution):
    """
    Gompertz distribution.

    Args:
        scale: Scale parameter
        concentration: Shape parameter
        validate_args: Whether to validate the parameters (optional)
    """
    # NOTE(review): upstream NumPyro appears to declare
    # (concentration, rate=1.0) rather than (scale, concentration) --
    # verify the parameter names and order against the installed version.
    def __init__(self, scale, concentration, validate_args=None): ...

class AsymmetricLaplace(Distribution):
    """
    Asymmetric Laplace distribution.

    Args:
        loc: Location parameter
        scale: Scale parameter
        asymmetry: Asymmetry parameter
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, loc, scale, asymmetry, validate_args=None): ...

class SoftLaplace(Distribution):
    """
    Smooth distribution with Laplace-like tail behavior.

    Args:
        loc: Location parameter
        scale: Scale parameter
        validate_args: Whether to validate the parameters (optional)
    """
    # NOTE(review): the summary previously read "for relaxed discrete
    # variables", which does not match the upstream description of
    # SoftLaplace (a smooth continuous density) -- confirm the wording.
    def __init__(self, loc, scale, validate_args=None): ...

Time Series Distributions

class GaussianRandomWalk(Distribution):
    """
    Gaussian random walk distribution.

    Args:
        scale: Step size scale
        num_steps: Number of time steps
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, scale=1.0, num_steps=1, validate_args=None): ...

class GaussianStateSpace(Distribution):
    """
    Linear Gaussian state space model.

    Args:
        initial_state_mean: Initial state mean
        initial_state_cov: Initial state covariance
        transition_matrix: State transition matrix
        transition_cov: Transition noise covariance
        observation_matrix: Observation matrix
        observation_cov: Observation noise covariance
        validate_args: Whether to validate the parameters (optional)
    """
    # NOTE(review): upstream NumPyro appears to declare
    # (num_steps, transition_matrix, covariance_matrix=None,
    #  precision_matrix=None, scale_tril=None); the parameter list shown
    # here does not match -- verify against the installed version.
    def __init__(self, initial_state_mean, initial_state_cov, transition_matrix,
                transition_cov, observation_matrix, observation_cov, validate_args=None): ...

class EulerMaruyama(Distribution):
    """
    Euler-Maruyama method for SDEs.

    Args:
        drift: Drift function
        diffusion: Diffusion function
        dt: Time step size
        num_steps: Number of steps
        validate_args: Whether to validate the parameters (optional)
    """
    # NOTE(review): upstream NumPyro appears to declare
    # (t, sde_fn, init_dist) rather than (drift, diffusion, dt, num_steps)
    # -- verify against the installed version.
    def __init__(self, drift, diffusion, dt, num_steps, validate_args=None): ...

class CAR(Distribution):
    """
    Conditional Autoregressive (CAR) distribution.

    Args:
        loc: Location parameter
        precision: Precision parameter
        adjacency_matrix: Spatial adjacency matrix
        validate_args: Whether to validate the parameters (optional)
    """
    # NOTE(review): upstream NumPyro appears to declare
    # (loc, correlation, conditional_precision, adj_matrix, is_sparse=False)
    # -- verify the parameter names against the installed version.
    def __init__(self, loc, precision, adjacency_matrix, validate_args=None): ...

Discrete Distributions

Discrete probability distributions for modeling integer-valued random variables.

Basic Discrete Distributions

class Bernoulli(Distribution):
    """
    Bernoulli distribution.

    Both parameterizations default to None. Presumably exactly one of
    `probs` or `logits` should be supplied -- confirm against upstream.

    Args:
        probs: Success probability (optional)
        logits: Log-odds (optional)
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, probs=None, logits=None, validate_args=None): ...

class Categorical(Distribution):
    """
    Categorical distribution over integers.

    Both parameterizations default to None. Presumably exactly one of
    `probs` or `logits` should be supplied -- confirm against upstream.

    Args:
        probs: Category probabilities (optional)
        logits: Log probabilities (optional)
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, probs=None, logits=None, validate_args=None): ...

class Binomial(Distribution):
    """
    Binomial distribution.

    Args:
        total_count: Number of trials
        probs: Success probability (optional)
        logits: Log-odds (optional)
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, total_count=1, probs=None, logits=None, validate_args=None): ...

class Multinomial(Distribution):
    """
    Multinomial distribution.

    Args:
        total_count: Number of trials
        probs: Category probabilities (optional)
        logits: Log probabilities (optional)
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, total_count=1, probs=None, logits=None, validate_args=None): ...

class Poisson(Distribution):
    """
    Poisson distribution.

    Args:
        rate: Rate parameter (mean)
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, rate, validate_args=None): ...

class Geometric(Distribution):
    """
    Geometric distribution (number of failures before first success).

    Args:
        probs: Success probability (optional)
        logits: Log-odds (optional)
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, probs=None, logits=None, validate_args=None): ...

class DiscreteUniform(Distribution):
    """
    Discrete uniform distribution.

    Args:
        low: Lower bound (inclusive)
        high: Upper bound (inclusive in upstream NumPyro, whose support is
            the integer interval [low, high]; this page previously labelled
            it "exclusive", which appears to be a mistake -- verify)
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, low=0, high=1, validate_args=None): ...

class OrderedLogistic(Distribution):
    """
    Ordered logistic distribution for ordinal data.

    Args:
        predictor: Linear predictor
        cutpoints: Ordered cutpoints
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, predictor, cutpoints, validate_args=None): ...

Zero-Inflated Distributions

class ZeroInflatedDistribution(Distribution):
    """
    Zero-inflated wrapper for any discrete distribution.

    Args:
        base_dist: Base discrete distribution
        gate: Probability of extra zeros
        gate_logits: Presumably the log-odds form of `gate`, given instead
            of `gate` -- confirm against upstream
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, base_dist, gate=None, gate_logits=None, validate_args=None): ...

class ZeroInflatedPoisson(Distribution):
    """
    Zero-inflated Poisson distribution.

    Args:
        rate: Poisson rate parameter
        gate: Probability of extra zeros
        gate_logits: Presumably the log-odds form of `gate`, given instead
            of `gate` -- confirm against upstream
        validate_args: Whether to validate the parameters (optional)
    """
    # NOTE(review): upstream NumPyro appears to declare (gate, rate=1.0),
    # i.e. `gate` first and no `gate_logits` -- verify the argument order
    # before calling positionally.
    def __init__(self, rate, gate=None, gate_logits=None, validate_args=None): ...

Conjugate Distributions

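Compound distributions obtained by analytically marginalizing a conjugate prior (for example, a binomial whose success probability has a beta prior), which allows efficient Bayesian inference.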

class BetaBinomial(Distribution):
    """
    Beta-binomial distribution (binomial with beta prior on probability).

    Args:
        concentration1: Beta alpha parameter
        concentration0: Beta beta parameter
        total_count: Number of trials
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, concentration1, concentration0, total_count=1, validate_args=None): ...

class DirichletMultinomial(Distribution):
    """
    Dirichlet-multinomial distribution.

    Args:
        concentration: Dirichlet concentration parameters
        total_count: Number of trials
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, concentration, total_count=1, validate_args=None): ...

class GammaPoisson(Distribution):
    """
    Gamma-Poisson (negative binomial) distribution.

    Args:
        concentration: Gamma shape parameter
        rate: Gamma rate parameter
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, concentration, rate, validate_args=None): ...

class NegativeBinomial2(Distribution):
    """
    Negative binomial distribution (NB2 parameterization).

    Args:
        mean: Mean parameter
        concentration: Concentration parameter
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, mean, concentration, validate_args=None): ...

class ZeroInflatedNegativeBinomial2(Distribution):
    """
    Zero-inflated negative binomial distribution.

    Args:
        mean: Mean parameter of the negative binomial component
        concentration: Concentration parameter of the negative binomial
            component
        gate: Probability of extra zeros
        gate_logits: Presumably the log-odds form of `gate`, given instead
            of `gate` -- confirm against upstream
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, mean, concentration, gate=None, gate_logits=None, validate_args=None): ...

Directional Distributions

Distributions for circular and spherical data.

class VonMises(Distribution):
    """
    Von Mises distribution for circular data.

    Args:
        loc: Mean direction
        concentration: Concentration parameter
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, loc, concentration, validate_args=None): ...

class ProjectedNormal(Distribution):
    """
    Projected normal distribution on unit sphere.

    Args:
        concentration: Concentration vector
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, concentration, validate_args=None): ...

class SineBivariateVonMises(Distribution):
    """
    Sine bivariate von Mises distribution.

    Args:
        phi_loc: Location of the first angle (presumably -- confirm)
        psi_loc: Location of the second angle (presumably -- confirm)
        phi_concentration: Concentration for the first angle
        psi_concentration: Concentration for the second angle
        correlation: Correlation between the two angles
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, phi_loc, psi_loc, phi_concentration, psi_concentration, 
                correlation, validate_args=None): ...

class SineSkewed(Distribution):
    """
    Sine-skewed circular distribution.

    Args:
        base_dist: Base distribution to skew (presumably a circular
            distribution -- confirm against upstream)
        skewness: Skewness parameter
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, base_dist, skewness, validate_args=None): ...

Mixture Distributions

Finite mixture models for modeling multi-modal data.

class Mixture(Distribution):
    """
    Finite mixture distribution.

    Args:
        mixing_distribution: Categorical mixing distribution
        component_distributions: List of component distributions
        validate_args: Whether to validate the parameters (optional)
    """
    # NOTE(review): upstream NumPyro appears to expose `Mixture` as a factory
    # function that returns a MixtureSameFamily or MixtureGeneral instance
    # rather than as a class -- verify before subclassing or isinstance checks.
    def __init__(self, mixing_distribution, component_distributions, validate_args=None): ...

class MixtureGeneral(Distribution):
    """
    General mixture distribution with flexible component selection.

    Args:
        mixing_distribution: Mixing distribution over the components
        component_distributions: Component distributions
        support: Support constraint for the mixture (optional)
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, mixing_distribution, component_distributions, 
                support=None, validate_args=None): ...

class MixtureSameFamily(Distribution):
    """
    Mixture of distributions from the same family.

    Args:
        mixing_distribution: Categorical mixing distribution
        component_distribution: Batch of component distributions
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, mixing_distribution, component_distribution, validate_args=None): ...

Truncated Distributions

Distributions with restricted support through truncation.

class TruncatedDistribution(Distribution):
    """
    Generic truncated distribution.

    Args:
        base_distribution: Base distribution to truncate
        low: Lower truncation bound (optional)
        high: Upper truncation bound (optional)
        validate_args: Whether to validate the parameters (optional)
    """
    # NOTE(review): upstream NumPyro appears to name the first parameter
    # `base_dist` -- verify before passing it by keyword.
    def __init__(self, base_distribution, low=None, high=None, validate_args=None): ...

class LeftTruncatedDistribution(Distribution):
    """
    Left-truncated distribution (truncated below).

    Args:
        base_distribution: Base distribution to truncate
        low: Lower truncation bound
        validate_args: Whether to validate the parameters (optional)
    """
    # NOTE(review): upstream NumPyro appears to name the first parameter
    # `base_dist` -- verify before passing it by keyword.
    def __init__(self, base_distribution, low, validate_args=None): ...

class RightTruncatedDistribution(Distribution):
    """
    Right-truncated distribution (truncated above).

    Args:
        base_distribution: Base distribution to truncate
        high: Upper truncation bound
        validate_args: Whether to validate the parameters (optional)
    """
    # NOTE(review): upstream NumPyro appears to name the first parameter
    # `base_dist` -- verify before passing it by keyword.
    def __init__(self, base_distribution, high, validate_args=None): ...

class TwoSidedTruncatedDistribution(Distribution):
    """
    Two-sided truncated distribution.

    Args:
        base_distribution: Base distribution to truncate
        low: Lower truncation bound
        high: Upper truncation bound
        validate_args: Whether to validate the parameters (optional)
    """
    # NOTE(review): upstream NumPyro appears to name the first parameter
    # `base_dist` -- verify before passing it by keyword.
    def __init__(self, base_distribution, low, high, validate_args=None): ...

class TruncatedNormal(Distribution):
    """
    Truncated normal distribution.

    Args:
        loc: Location parameter of the normal being truncated
        scale: Scale parameter of the normal being truncated
        low: Lower truncation bound (optional)
        high: Upper truncation bound (optional)
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, loc=0.0, scale=1.0, low=None, high=None, validate_args=None): ...

class TruncatedCauchy(Distribution):
    """
    Truncated Cauchy distribution.

    Args:
        loc: Location parameter of the Cauchy being truncated
        scale: Scale parameter of the Cauchy being truncated
        low: Lower truncation bound (optional)
        high: Upper truncation bound (optional)
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, loc=0.0, scale=1.0, low=None, high=None, validate_args=None): ...

class LowerTruncatedPowerLaw(Distribution):
    """
    Lower truncated power law distribution.

    Args:
        alpha: Power-law index (presumably -- confirm against upstream)
        scale: Presumably the lower truncation point -- confirm
        validate_args: Whether to validate the parameters (optional)
    """
    # NOTE(review): upstream NumPyro appears to declare (alpha, low), with
    # the second parameter named `low` rather than `scale` -- verify.
    def __init__(self, alpha, scale, validate_args=None): ...

class DoublyTruncatedPowerLaw(Distribution):
    """
    Doubly truncated power law distribution.

    Args:
        alpha: Power-law index (presumably -- confirm against upstream)
        low: Lower truncation bound
        high: Upper truncation bound
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, alpha, low, high, validate_args=None): ...

Copula Distributions

Copula-based distributions for modeling dependence structures.

class GaussianCopula(Distribution):
    """
    Gaussian copula distribution.

    Args:
        correlation_matrix: Correlation matrix
        marginals: List of marginal distributions
        validate_args: Whether to validate the parameters (optional)
    """
    # NOTE(review): upstream NumPyro appears to declare
    # (marginal_dist, correlation_matrix=None, correlation_cholesky=None),
    # i.e. the marginal comes first and is a single (batched) distribution
    # -- verify against the installed version.
    def __init__(self, correlation_matrix, marginals, validate_args=None): ...

class GaussianCopulaBeta(Distribution):
    """
    Gaussian copula with Beta marginals.

    Args:
        correlation_matrix: Correlation matrix of the copula
        concentration1: First concentration parameter of the Beta marginals
        concentration0: Second concentration parameter of the Beta marginals
        validate_args: Whether to validate the parameters (optional)
    """
    # NOTE(review): upstream NumPyro appears to declare
    # (concentration1, concentration0, correlation_matrix=None,
    #  correlation_cholesky=None) -- verify the argument order before
    # calling positionally.
    def __init__(self, correlation_matrix, concentration1, concentration0, validate_args=None): ...

Special Distributions

Utility distributions for specific modeling needs.

class Delta(Distribution):
    """
    Point mass (Dirac delta) distribution.

    Args:
        v: Point mass location
        log_density: Log density value at the point
        event_dim: Number of event dimensions
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, v, log_density=0.0, event_dim=0, validate_args=None): ...

class Unit(Distribution):
    """
    Unit distribution for adding log probability factors.

    Args:
        log_factor: Log probability factor to add
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, log_factor, validate_args=None): ...

class ImproperUniform(Distribution):
    """
    Improper uniform distribution over real numbers.

    Args:
        support: Support constraint
        batch_shape: Batch shape
        event_shape: Event shape
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, support, batch_shape, event_shape, validate_args=None): ...

class CirculantNormal(Distribution):
    """
    Normal distribution with circulant covariance matrix.

    Args:
        loc: Mean of the distribution
        circulant_cov: Specification of the circulant covariance (exact
            form not shown here -- confirm against upstream)
        validate_args: Whether to validate the parameters (optional)
    """
    # NOTE(review): upstream NumPyro appears to declare
    # (loc, covariance_row=None, covariance_rfft=None) -- verify the
    # parameter names against the installed version.
    def __init__(self, loc, circulant_cov, validate_args=None): ...

class ZeroSumNormal(Distribution):
    """
    Normal distribution with zero-sum constraint.

    Args:
        scale: Scale parameter
        validate_args: Whether to validate the parameters (optional)
    """
    # NOTE(review): upstream NumPyro appears to require a second positional
    # parameter, `event_shape` -- verify against the installed version.
    def __init__(self, scale, validate_args=None): ...

class RelaxedBernoulli(Distribution):
    """
    Relaxed Bernoulli distribution (continuous relaxation).

    Args:
        temperature: Temperature of the relaxation
        probs: Success probability (optional)
        logits: Log-odds (optional)
        validate_args: Whether to validate the parameters (optional)
    """
    def __init__(self, temperature, probs=None, logits=None, validate_args=None): ...

Distribution Utilities

def enable_validation(is_validate: bool) -> None:
    """
    Enable or disable distribution parameter validation.

    Args:
        is_validate: True to enable validation, False to disable it
    """

# NOTE(review): upstream NumPyro appears to define
# `validation_enabled(is_validate=True)` as a context manager that
# temporarily toggles validation, not as a getter returning bool --
# verify against the installed version.
def validation_enabled() -> bool:
    """
    Check if distribution validation is currently enabled.

    Returns:
        True when validation is enabled, False otherwise (as declared by
        this stub's return annotation).
    """

def kl_divergence(p: Distribution, q: Distribution) -> Array:
    """
    Compute KL divergence between two distributions.

    Args:
        p: First distribution
        q: Second distribution

    Returns:
        The divergence as an array. Presumably KL(p || q), which is not
        symmetric in its arguments -- confirm the direction against upstream.
    """

def biject_to(constraint) -> Transform:
    """
    Get bijective transform to given constraint.

    Args:
        constraint: Constraint describing the target domain

    Returns:
        A Transform that is bijective onto the constrained domain
        (presumably from unconstrained real space -- confirm against upstream).
    """

Types

from typing import Optional, Union, Callable, Sequence
from jax import Array
import jax.numpy as jnp

# Anything accepted where an array-valued parameter is expected:
# arrays as well as plain Python scalars.
ArrayLike = Union[Array, jnp.ndarray, float, int]
# NOTE(review): the two aliases below reference `numpyro`, but no
# `import numpyro` appears among the imports listed with this snippet --
# it must be imported for these names to resolve.
Constraint = numpyro.distributions.constraints.Constraint
Transform = numpyro.distributions.transforms.Transform

# Distribution parameter types
# These are plain aliases of ArrayLike; the ranges in the trailing comments
# are documentation only and are not enforced by the alias itself.
Concentration = ArrayLike  # Positive real numbers
Rate = ArrayLike  # Positive real numbers
Scale = ArrayLike  # Positive real numbers
Probability = ArrayLike  # Numbers in [0, 1]
Logits = ArrayLike  # Real numbers
Location = ArrayLike  # Real numbers

Install with Tessl CLI

npx tessl i tessl/pypi-numpyro

docs

diagnostics.md

distributions.md

handlers.md

index.md

inference.md

optimization.md

primitives.md

utilities.md

tile.json