Distrax: Probability distributions in JAX.
—
Continuous probability distributions for modeling real-valued random variables, including univariate and multivariate distributions with various parameterizations and covariance structures.
Normal (Gaussian) distribution with location and scale parameters.
class Normal(Distribution):
def __init__(self, loc, scale):
"""
Normal distribution.
Parameters:
- loc: mean parameter (float or array)
- scale: standard deviation parameter (float or array, must be positive)
"""
@property
def loc(self): ...
@property
def scale(self): ...
@property
def event_shape(self): ...
@property
def batch_shape(self): ...

Normal distribution parameterized by log standard deviation for numerical stability.
class LogStddevNormal(Distribution):
def __init__(self, loc, log_scale):
"""
Normal distribution parameterized by log standard deviation.
Parameters:
- loc: mean parameter (float or array)
- log_scale: log standard deviation parameter (float or array)
"""
@property
def loc(self): ...
@property
def log_scale(self): ...
@property
def scale(self): ...

Beta distribution for modeling probabilities and proportions.
class Beta(Distribution):
def __init__(self, concentration1, concentration0):
"""
Beta distribution.
Parameters:
- concentration1: first concentration parameter (must be positive)
- concentration0: second concentration parameter (must be positive)
"""
@property
def concentration1(self): ...
@property
def concentration0(self): ...

Gamma distribution for modeling positive continuous variables.
class Gamma(Distribution):
def __init__(self, concentration, rate):
"""
Gamma distribution.
Parameters:
- concentration: shape parameter (must be positive)
- rate: rate parameter (must be positive)
"""
@property
def concentration(self): ...
@property
def rate(self): ...

Laplace (double exponential) distribution.
class Laplace(Distribution):
def __init__(self, loc, scale):
"""
Laplace distribution.
Parameters:
- loc: location parameter (float or array)
- scale: scale parameter (float or array, must be positive)
"""
@property
def loc(self): ...
@property
def scale(self): ...

Logistic distribution for sigmoid-shaped densities.
class Logistic(Distribution):
def __init__(self, loc, scale):
"""
Logistic distribution.
Parameters:
- loc: location parameter (float or array)
- scale: scale parameter (float or array, must be positive)
"""
@property
def loc(self): ...
@property
def scale(self): ...

Gumbel distribution for modeling extreme values.
class Gumbel(Distribution):
def __init__(self, loc, scale):
"""
Gumbel distribution.
Parameters:
- loc: location parameter (float or array)
- scale: scale parameter (float or array, must be positive)
"""
@property
def loc(self): ...
@property
def scale(self): ...

Uniform distribution over a specified interval.
class Uniform(Distribution):
def __init__(self, low, high):
"""
Uniform distribution.
Parameters:
- low: lower bound (float or array)
- high: upper bound (float or array, must be > low)
"""
@property
def low(self): ...
@property
def high(self): ...

Von Mises (circular normal) distribution for circular data.
class VonMises(Distribution):
def __init__(self, loc, concentration):
"""
Von Mises distribution.
Parameters:
- loc: mean direction parameter (float or array)
- concentration: concentration parameter (float or array, must be >= 0)
"""
@property
def loc(self): ...
@property
def concentration(self): ...

Multivariate normal distribution with diagonal covariance matrix.
class MultivariateNormalDiag(Distribution):
def __init__(self, loc, scale_diag):
"""
Multivariate normal with diagonal covariance.
Parameters:
- loc: mean vector (array of shape [..., d])
- scale_diag: diagonal standard deviations (array of shape [..., d], must be positive)
"""
@property
def loc(self): ...
@property
def scale_diag(self): ...
@property
def event_shape(self): ...

Multivariate normal distribution with full covariance matrix.
class MultivariateNormalFullCovariance(Distribution):
def __init__(self, loc, covariance_matrix):
"""
Multivariate normal with full covariance matrix.
Parameters:
- loc: mean vector (array of shape [..., d])
- covariance_matrix: covariance matrix (array of shape [..., d, d], must be positive definite)
"""
@property
def loc(self): ...
@property
def covariance_matrix(self): ...

Multivariate normal distribution parameterized by triangular matrices.
class MultivariateNormalTri(Distribution):
def __init__(self, loc, scale_tri, scale_diag):
"""
Multivariate normal with triangular parameterization.
Parameters:
- loc: mean vector (array of shape [..., d])
- scale_tri: lower triangular matrix (array of shape [..., d, d])
- scale_diag: diagonal scale factors (array of shape [..., d])
"""
@property
def loc(self): ...
@property
def scale_tri(self): ...
@property
def scale_diag(self): ...

Multivariate normal with diagonal plus low-rank covariance structure.
class MultivariateNormalDiagPlusLowRank(Distribution):
def __init__(self, loc, scale_diag, scale_identity_multiplier, scale_perturb_factor):
"""
Multivariate normal with diagonal plus low-rank covariance.
Parameters:
- loc: mean vector (array of shape [..., d])
- scale_diag: diagonal component (array of shape [..., d])
- scale_identity_multiplier: scalar multiplier for identity (float or array)
- scale_perturb_factor: low-rank perturbation factor (array of shape [..., d, k])
"""
@property
def loc(self): ...
@property
def scale_diag(self): ...
@property
def scale_identity_multiplier(self): ...
@property
def scale_perturb_factor(self): ...

Multivariate normal distribution constructed using a bijector transformation.
class MultivariateNormalFromBijector(Distribution):
def __init__(self, shift, bijector):
"""
Multivariate normal constructed via bijector.
Parameters:
- shift: location parameter (array of shape [..., d])
- bijector: bijector defining the transformation from standard normal
"""
@property
def shift(self): ...
@property
def bijector(self): ...

Dirichlet distribution for modeling probability vectors.
class Dirichlet(Distribution):
def __init__(self, concentration):
"""
Dirichlet distribution.
Parameters:
- concentration: concentration parameters (array of shape [..., k], must be positive)
"""
@property
def concentration(self): ...
@property
def event_shape(self): ...

Install with Tessl CLI:
npx tessl i tessl/pypi-distrax