Distrax: Probability distributions in JAX.
—
Composite distributions built by combining simpler components: transformed distributions, mixture models, batch-to-event reinterpretation (`Independent`), joint distributions for multi-component modeling, and quantized versions of continuous distributions.
Distribution transformed by a bijector.
class Transformed(Distribution):
    """A base distribution transformed by a bijector."""

    def __init__(self, distribution, bijector):
        """Builds the transformed distribution.

        Args:
            distribution: Base distribution to transform.
            bijector: Bijector defining the transformation.
        """

    @property
    def distribution(self):
        """The base distribution that is being transformed."""

    @property
    def bijector(self):
        """The bijector that defines the transformation."""

    @property
    def event_shape(self):
        """Event shape of the transformed distribution."""

    @property
    def batch_shape(self):
        """Batch shape of the transformed distribution."""


# Mixture of exactly two distributions.
class MixtureOfTwo(Distribution):
    """Mixture of exactly two component distributions."""

    def __init__(self, mixing_probs, dist1, dist2):
        """Builds the two-component mixture.

        Args:
            mixing_probs: Mixing probabilities, an array of shape ``[..., 2]``
                whose entries sum to 1 over the last axis.
            dist1: First component distribution.
            dist2: Second component distribution.
        """

    @property
    def mixing_probs(self):
        """Mixing probabilities over the two components."""

    @property
    def dist1(self):
        """The first component distribution."""

    @property
    def dist2(self):
        """The second component distribution."""


# Mixture of distributions from the same parametric family.
class MixtureSameFamily(Distribution):
    """Mixture whose components all come from one parametric family."""

    def __init__(self, mixture_distribution, components_distribution):
        """Builds the same-family mixture.

        Args:
            mixture_distribution: Categorical distribution over the mixture
                components.
            components_distribution: Batch of component distributions.
        """

    @property
    def mixture_distribution(self):
        """Categorical distribution that selects a mixture component."""

    @property
    def components_distribution(self):
        """The batch of component distributions."""

    @property
    def event_shape(self):
        """Event shape of the mixture."""


# Reinterprets batch dimensions as event dimensions.
class Independent(Distribution):
    """Wraps a distribution, reinterpreting batch dimensions as event dimensions."""

    def __init__(self, distribution, reinterpreted_batch_ndims):
        """Builds the independent distribution.

        Args:
            distribution: Base distribution.
            reinterpreted_batch_ndims: Number of batch dimensions to
                reinterpret as event dimensions.
        """

    @property
    def distribution(self):
        """The wrapped base distribution."""

    @property
    def reinterpreted_batch_ndims(self):
        """How many batch dimensions are reinterpreted as event dimensions."""

    @property
    def event_shape(self):
        """Event shape after the reinterpretation."""

    @property
    def batch_shape(self):
        """Batch shape after the reinterpretation."""


# Joint distribution of multiple components.
class Joint(Distribution):
    """Joint distribution over several component distributions."""

    def __init__(self, distributions, name="Joint"):
        """Builds the joint distribution.

        Args:
            distributions: Sequence or dict of component distributions.
            name: Name for the distribution. Defaults to ``"Joint"``.
        """

    @property
    def distributions(self):
        """The component distributions (sequence or dict) given at construction."""

    @property
    def event_shape(self):
        """Event shape of the joint distribution."""


# Quantized version of a continuous distribution.
class Quantized(Distribution):
    """Quantized version of a continuous base distribution."""

    def __init__(self, distribution, low=None, high=None):
        """Builds the quantized distribution.

        Args:
            distribution: Base continuous distribution.
            low: Optional lower quantization bound. Defaults to ``None``.
            high: Optional upper quantization bound. Defaults to ``None``.
        """

    @property
    def distribution(self):
        """The base continuous distribution being quantized."""

    @property
    def low(self):
        """Lower quantization bound (``None`` when not supplied)."""

    @property
    def high(self):
        """Upper quantization bound (``None`` when not supplied)."""


# Install with Tessl CLI
npx tessl i tessl/pypi-distrax