Distrax: Probability distributions in JAX.
npx @tessl/cli install tessl/pypi-distrax@0.1.0
Distrax is a lightweight JAX-native library of probability distributions and bijectors that acts as a reimplementation of a subset of TensorFlow Probability (TFP) with emphasis on readability, extensibility, and cross-compatibility. The library provides a comprehensive set of probability distributions and bijectors (invertible functions with known Jacobian determinants) that can be used to create complex distributions by transforming simpler ones.
pip install distrax
import distrax
Common patterns for distributions:
from distrax import Normal, Bernoulli, Categorical
Common patterns for bijectors:
from distrax import ScalarAffine, Chain, Sigmoid
import distrax
import jax.numpy as jnp
import jax.random as random
# Create a simple distribution
key = random.PRNGKey(42)
dist = distrax.Normal(loc=0.0, scale=1.0)
# Sample from the distribution
samples = dist.sample(seed=key, sample_shape=(100,))
# Compute log probabilities
log_probs = dist.log_prob(samples)
# Create a bijector for transformations
bijector = distrax.ScalarAffine(shift=2.0, scale=0.5)
# Transform values
x = jnp.array([1.0, 2.0, 3.0])
y = bijector.forward(x)
x_reconstructed = bijector.inverse(y)
# Create transformed distributions
transformed_dist = distrax.Transformed(dist, bijector)
transformed_samples = transformed_dist.sample(seed=key, sample_shape=(100,))
Distrax follows a clear architectural pattern based on two main abstractions:
This design enables:
Univariate and multivariate continuous probability distributions including Normal, Beta, Gamma, Laplace, and multivariate normal variants with different covariance structures.
class Normal(Distribution):
def __init__(self, loc, scale): ...
class Beta(Distribution):
def __init__(self, concentration1, concentration0): ...
class MultivariateNormalDiag(Distribution):
def __init__(self, loc, scale_diag): ...
Discrete probability distributions for categorical and binary outcomes, including Bernoulli, Categorical, and Multinomial distributions with various parameterizations.
class Bernoulli(Distribution):
def __init__(self, logits=None, probs=None, dtype=int): ...
class Categorical(Distribution):
def __init__(self, logits=None, probs=None, dtype=int): ...
class OneHotCategorical(Distribution):
def __init__(self, logits=None, probs=None, dtype=float): ...
Invertible transformations with known Jacobian determinants for creating complex distributions through composition, including affine transformations, normalizing flows, and neural network layers.
class Bijector:
def forward(self, x): ...
def inverse(self, y): ...
def forward_and_log_det(self, x): ...
class ScalarAffine(Bijector):
def __init__(self, shift, scale=None, log_scale=None): ...
class Chain(Bijector):
def __init__(self, bijectors): ...
Complex distributions created by combining simpler components, including mixture models, transformed distributions, and joint distributions for multi-component modeling.
class Transformed(Distribution):
def __init__(self, distribution, bijector): ...
class MixtureSameFamily(Distribution):
def __init__(self, mixture_distribution, components_distribution): ...
class Independent(Distribution):
def __init__(self, distribution, reinterpreted_batch_ndims): ...
Mixture and Composite Distributions
Task-specific distributions for reinforcement learning, clipped distributions, and deterministic distributions for specialized modeling needs.
class EpsilonGreedy(Distribution):
def __init__(self, preferences, epsilon): ...
class ClippedNormal(Distribution):
def __init__(self, loc, scale, low, high): ...
class Deterministic(Distribution):
def __init__(self, loc): ...
Helper functions for distribution conversion, Monte Carlo estimation, mathematical operations, and Hidden Markov Models for advanced probabilistic modeling.
def as_distribution(obj: DistributionLike) -> Distribution: ...
def as_bijector(obj: BijectorLike) -> Bijector: ...
def to_tfp(obj, name=None): ...
class HMM:
def __init__(self, init_dist, trans_dist, obs_dist): ...
class Distribution:
"""
Abstract base class for probability distributions.
Provides common interface for sampling, density evaluation, and statistical properties.
All distributions must implement log_prob() and _sample_n() methods.
"""
def sample(self, *, seed, sample_shape=()): ...
def sample_and_log_prob(self, *, seed, sample_shape=()): ...
def log_prob(self, value): ...
def prob(self, value): ...
def entropy(self): ...
def mean(self): ...
def variance(self): ...
def cdf(self, value): ...
def __getitem__(self, index): ...
@property
def event_shape(self): ...
@property
def batch_shape(self): ...
@property
def dtype(self): ...
class Bijector:
"""
Abstract base class for invertible transformations with known Jacobian determinants.
All bijectors must implement forward_and_log_det() method.
"""
def forward(self, x): ...
def inverse(self, y): ...
def forward_and_log_det(self, x): ...
def inverse_and_log_det(self, y): ...
@property
def event_ndims_in(self): ...
@property
def event_ndims_out(self): ...
from typing import Union, Callable
from chex import Array
DistributionLike = Union[Distribution, 'tfd.Distribution']
BijectorLike = Union[Bijector, 'tfb.Bijector', Callable[[Array], Array]]