tessl/pypi-distrax

Distrax: Probability distributions in JAX.

Workspace: tessl
Visibility: Public
Describes: pkg:pypi/distrax@0.1.x

To install, run

npx @tessl/cli install tessl/pypi-distrax@0.1.0


Distrax

Distrax is a lightweight, JAX-native library of probability distributions and bijectors. It reimplements a subset of TensorFlow Probability (TFP) with an emphasis on readability, extensibility, and cross-compatibility. The library provides a comprehensive set of probability distributions and bijectors (invertible functions with known Jacobian determinants) that can be composed to build complex distributions by transforming simpler ones.

Package Information

  • Package Name: distrax
  • Language: Python
  • Installation: pip install distrax

Core Imports

import distrax

Common patterns for distributions:

from distrax import Normal, Bernoulli, Categorical

Common patterns for bijectors:

from distrax import ScalarAffine, Chain, Sigmoid

Basic Usage

import distrax
import jax.numpy as jnp
import jax.random as random

# Create a simple distribution
key = random.PRNGKey(42)
dist = distrax.Normal(loc=0.0, scale=1.0)

# Sample from the distribution
samples = dist.sample(seed=key, sample_shape=(100,))

# Compute log probabilities
log_probs = dist.log_prob(samples)

# Create a bijector for transformations
bijector = distrax.ScalarAffine(shift=2.0, scale=0.5)

# Transform values
x = jnp.array([1.0, 2.0, 3.0])
y = bijector.forward(x)
x_reconstructed = bijector.inverse(y)

# Create transformed distributions
transformed_dist = distrax.Transformed(dist, bijector)
transformed_samples = transformed_dist.sample(seed=key, sample_shape=(100,))

Architecture

Distrax follows a clear architectural pattern based on two main abstractions:

  • Distribution: Base class for probability distributions providing sampling, density evaluation, and statistical properties
  • Bijector: Base class for invertible functions with computable Jacobian determinants

This design enables:

  • Compositional flexibility: Bijectors can be chained and combined with distributions
  • JAX integration: Full compatibility with JAX transformations (jit, vmap, grad); see the sketch after this list
  • TFP compatibility: Seamless interoperability with TensorFlow Probability
  • Type safety: Comprehensive type hints for better development experience
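
As a brief illustration of the JAX-integration point above, here is a minimal sketch that jit-compiles, vmaps, and differentiates a Normal log-density; the parameter values are illustrative only.

import jax
import jax.numpy as jnp
import distrax

# Distributions are ordinary JAX-compatible objects, so functions that
# build and query them can be transformed with jit, vmap, and grad.
def log_density(loc, x):
    return distrax.Normal(loc=loc, scale=1.0).log_prob(x)

fast_log_density = jax.jit(log_density)
batched_log_density = jax.vmap(log_density, in_axes=(0, None))

locs = jnp.linspace(-1.0, 1.0, 5)
fast_log_density(0.0, 1.5)        # compiled scalar log-density
batched_log_density(locs, 1.5)    # one log-density per loc
jax.grad(log_density)(0.0, 1.5)   # gradient with respect to loc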

Capabilities

Continuous Distributions

Univariate and multivariate continuous probability distributions including Normal, Beta, Gamma, Laplace, and multivariate normal variants with different covariance structures.

class Normal(Distribution):
    def __init__(self, loc, scale): ...

class Beta(Distribution):
    def __init__(self, concentration1, concentration0): ...

class MultivariateNormalDiag(Distribution):
    def __init__(self, loc, scale_diag): ...

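A short usage sketch for these constructors (parameter values are illustrative only):

import distrax
import jax.numpy as jnp
import jax.random as random

key = random.PRNGKey(0)

# Diagonal-covariance multivariate normal over 3 dimensions.
mvn = distrax.MultivariateNormalDiag(loc=jnp.zeros(3), scale_diag=jnp.ones(3))

x = mvn.sample(seed=key, sample_shape=(4,))  # shape (4, 3)
log_p = mvn.log_prob(x)                      # shape (4,), one value per sample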

Discrete Distributions

Discrete probability distributions for categorical and binary outcomes, including Bernoulli, Categorical, and Multinomial distributions with various parameterizations.

class Bernoulli(Distribution):
    def __init__(self, logits=None, probs=None, dtype=int): ...

class Categorical(Distribution):
    def __init__(self, logits=None, probs=None, dtype=int): ...

class OneHotCategorical(Distribution):
    def __init__(self, logits=None, probs=None, dtype=float): ...

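A short sketch of sampling from the discrete distributions above (logits and probabilities are illustrative only):

import distrax
import jax.numpy as jnp
import jax.random as random

key = random.PRNGKey(0)

# Three-way categorical parameterized by unnormalized logits.
cat = distrax.Categorical(logits=jnp.array([0.5, 1.0, -1.0]))
draws = cat.sample(seed=key, sample_shape=(10,))  # integer indices in {0, 1, 2}
ent = cat.entropy()

# Bernoulli parameterized by the success probability.
bern = distrax.Bernoulli(probs=0.3)
flips = bern.sample(seed=key, sample_shape=(10,))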

Bijectors

Invertible transformations with known Jacobian determinants for creating complex distributions through composition, including affine transformations, normalizing flows, and neural network layers.

class Bijector:
    def forward(self, x): ...
    def inverse(self, y): ...
    def forward_and_log_det(self, x): ...

class ScalarAffine(Bijector):
    def __init__(self, shift, scale=None, log_scale=None): ...

class Chain(Bijector):
    def __init__(self, bijectors): ...

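A composition sketch using the bijectors above; it assumes Chain composes right to left (as in TFP), so the last bijector in the list is applied first:

import distrax
import jax.numpy as jnp

# Sigmoid is applied first, then the affine map: y = 2 * sigmoid(x) + 1.
bij = distrax.Chain([distrax.ScalarAffine(shift=1.0, scale=2.0),
                     distrax.Sigmoid()])

x = jnp.array([-1.0, 0.0, 1.0])
y, log_det = bij.forward_and_log_det(x)  # forward pass plus log|det J|
x_back = bij.inverse(y)                  # recovers x up to floating-point error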

Mixture and Composite Distributions

Complex distributions created by combining simpler components, including mixture models, transformed distributions, and joint distributions for multi-component modeling.

class Transformed(Distribution):
    def __init__(self, distribution, bijector): ...

class MixtureSameFamily(Distribution):
    def __init__(self, mixture_distribution, components_distribution): ...

class Independent(Distribution):
    def __init__(self, distribution, reinterpreted_batch_ndims): ...

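A sketch that combines the components above into a two-component Gaussian mixture and a factorized multivariate distribution (all parameter values are illustrative):

import distrax
import jax.numpy as jnp
import jax.random as random

key = random.PRNGKey(0)

# Mixing weights come from a Categorical; component parameters are
# batched along the last axis of the components distribution.
mixture = distrax.MixtureSameFamily(
    mixture_distribution=distrax.Categorical(probs=jnp.array([0.3, 0.7])),
    components_distribution=distrax.Normal(loc=jnp.array([-2.0, 2.0]),
                                           scale=jnp.array([0.5, 1.0])))
s = mixture.sample(seed=key, sample_shape=(5,))

# Independent reinterprets a batch of scalar normals as one 3-dimensional event.
diag = distrax.Independent(distrax.Normal(loc=jnp.zeros(3), scale=jnp.ones(3)),
                           reinterpreted_batch_ndims=1)
lp = diag.log_prob(jnp.zeros(3))  # a single joint log-density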

Specialized Distributions

Task-specific distributions for reinforcement learning, clipped distributions, and deterministic distributions for specialized modeling needs.

class EpsilonGreedy(Distribution):
    def __init__(self, preferences, epsilon): ...

class ClippedNormal(Distribution):
    def __init__(self, loc, scale, low, high): ...

class Deterministic(Distribution):
    def __init__(self, loc): ...

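A sketch of the reinforcement-learning oriented distributions above (preferences, epsilon, and shapes are illustrative):

import distrax
import jax.numpy as jnp
import jax.random as random

key = random.PRNGKey(0)

# Epsilon-greedy policy over 4 actions: with probability 1 - epsilon take
# the highest-preference action, otherwise pick uniformly at random.
policy = distrax.EpsilonGreedy(preferences=jnp.array([1.0, 3.0, 0.5, 2.0]),
                               epsilon=0.1)
actions = policy.sample(seed=key, sample_shape=(8,))

# A point mass, useful for deterministic policies or fixed observations.
point = distrax.Deterministic(loc=2.5)
lp = point.log_prob(2.5)  # log(1) = 0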

Utilities

Helper functions for distribution conversion, Monte Carlo estimation, mathematical operations, and Hidden Markov Models for advanced probabilistic modeling.

def as_distribution(obj: DistributionLike) -> Distribution: ...
def as_bijector(obj: BijectorLike) -> Bijector: ...
def to_tfp(obj, name=None): ...

class HMM:
    def __init__(self, init_dist, trans_dist, obs_dist): ...

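A conversion sketch based on the helpers listed above; it assumes TensorFlow Probability (JAX substrate) is installed alongside distrax:

import distrax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions

# Expose a distrax distribution through the TFP interface.
tfp_view = distrax.to_tfp(distrax.Normal(loc=0.0, scale=1.0))

# Wrap a TFP distribution behind the distrax Distribution interface.
dx_view = distrax.as_distribution(tfd.Normal(loc=0.0, scale=1.0))
lp = dx_view.log_prob(jnp.array([0.0, 1.0]))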

Types

Base Classes

class Distribution:
    """
    Abstract base class for probability distributions.
    
    Provides common interface for sampling, density evaluation, and statistical properties.
    All distributions must implement log_prob() and _sample_n() methods.
    """
    
    def sample(self, *, seed, sample_shape=()): ...
    def sample_and_log_prob(self, *, seed, sample_shape=()): ...
    def log_prob(self, value): ...
    def prob(self, value): ...
    def entropy(self): ...
    def mean(self): ...
    def variance(self): ...
    def cdf(self, value): ...
    def __getitem__(self, index): ...
    
    @property
    def event_shape(self): ...
    @property
    def batch_shape(self): ...
    @property
    def dtype(self): ...
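
A brief sketch of the shape semantics and the combined sampling call on this interface (a batch of two scalar normals, values illustrative):

import distrax
import jax.numpy as jnp
import jax.random as random

key = random.PRNGKey(0)

# Batch of two scalar normals: batch_shape (2,), event_shape ().
dist = distrax.Normal(loc=jnp.array([0.0, 1.0]), scale=jnp.array([1.0, 2.0]))
dist.batch_shape  # (2,)
dist.event_shape  # ()

# Draw samples and their log-densities in one call; both have shape (3, 2).
samples, log_p = dist.sample_and_log_prob(seed=key, sample_shape=(3,))

# Indexing slices the batch dimension.
first = dist[0]  # Normal(loc=0.0, scale=1.0)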

class Bijector:
    """
    Abstract base class for invertible transformations with known Jacobian determinants.
    
    All bijectors must implement forward_and_log_det() method.
    """
    
    def forward(self, x): ...
    def inverse(self, y): ...
    def forward_and_log_det(self, x): ...
    def inverse_and_log_det(self, y): ...
    
    @property
    def event_ndims_in(self): ...
    @property
    def event_ndims_out(self): ...
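
A brief sketch of the inverse direction and the event-dimension properties, using the Sigmoid bijector as an example:

import distrax
import jax.numpy as jnp

bij = distrax.Sigmoid()
bij.event_ndims_in   # 0: acts elementwise on scalar events
bij.event_ndims_out  # 0

y = bij.forward(jnp.array([0.0, 2.0]))
x, inv_log_det = bij.inverse_and_log_det(y)
# inv_log_det is the log-determinant of the inverse map, the negative of
# the forward log-determinant at the corresponding point.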

Type Aliases

from typing import Union, Callable
from chex import Array

DistributionLike = Union[Distribution, 'tfd.Distribution']
BijectorLike = Union[Bijector, 'tfb.Bijector', Callable[[Array], Array]]
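
These aliases mean TFP objects and, for bijectors, plain callables can be passed wherever a DistributionLike or BijectorLike is accepted. A minimal sketch using the as_bijector helper listed under Utilities (behavior for arbitrary callables is an assumption; common elementwise functions are the intended use):

import distrax
import jax.numpy as jnp

# Wrap a plain callable as a Bijector.
tanh_bij = distrax.as_bijector(jnp.tanh)
y = tanh_bij.forward(jnp.array([0.0, 0.5]))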