tessl/pypi-distrax

Distrax: Probability distributions in JAX.

Utilities

Helper functions for distribution conversion, Monte Carlo estimation, and mathematical operations, plus a Hidden Markov Model implementation for sequential probabilistic modeling.

Capabilities

Conversion Utilities

Convert to Distrax Distribution

Converts distribution-like objects to Distrax distributions.

def as_distribution(obj):
    """
    Convert distribution-like object to Distrax distribution.
    
    Parameters:
    - obj: DistributionLike object (Distrax or TFP distribution)
    
    Returns:
    Distrax Distribution
    """

Convert to Distrax Bijector

Converts bijector-like objects to Distrax bijectors.

def as_bijector(obj):
    """
    Convert bijector-like object to Distrax bijector.
    
    Parameters:
    - obj: BijectorLike object (Distrax bijector, TFP bijector, or callable)
    
    Returns:
    Distrax Bijector
    """

Convert to TensorFlow Probability

Converts Distrax objects to TFP-compatible equivalents.

def to_tfp(obj, name=None):
    """
    Convert Distrax object to TFP-compatible equivalent.
    
    Parameters:
    - obj: Distrax distribution or bijector
    - name: optional name for the TFP object
    
    Returns:
    TFP-compatible distribution or bijector
    """

Mathematical Utilities

Multiply with No NaN

Element-wise multiplication that returns 0 wherever the second argument is zero, even if the first is infinite or NaN.

def multiply_no_nan(x, y):
    """
    Element-wise multiplication that returns 0 if y is 0.
    
    Parameters:
    - x: first operand (array)
    - y: second operand (array)
    
    Returns:
    Element-wise product with NaN-safe handling
    """

Monte Carlo Estimation

Best-Effort KL Divergence Estimation

Estimates KL divergence exactly if possible, otherwise uses Monte Carlo.

def estimate_kl_best_effort(distribution_a, distribution_b, rng_key, num_samples, proposal_distribution=None):
    """
    Estimate KL divergence using best available method.
    
    Parameters:
    - distribution_a: first distribution
    - distribution_b: second distribution
    - rng_key: JAX random key
    - num_samples: number of Monte Carlo samples
    - proposal_distribution: optional proposal distribution for importance sampling
    
    Returns:
    KL divergence estimate
    """

Monte Carlo KL Divergence Estimation

Monte Carlo estimation of KL divergence using the DiCE estimator.

def mc_estimate_kl(distribution_a, distribution_b, rng_key, num_samples, proposal_distribution=None):
    """
    Monte Carlo estimation of KL divergence.
    
    Parameters:
    - distribution_a: first distribution
    - distribution_b: second distribution
    - rng_key: JAX random key
    - num_samples: number of Monte Carlo samples
    - proposal_distribution: optional proposal distribution for importance sampling
    
    Returns:
    KL divergence estimate
    """

Monte Carlo KL with Reparameterized Distributions

Monte Carlo KL estimation with reparameterized distributions.

def mc_estimate_kl_with_reparameterized(distribution_a, distribution_b, rng_key, num_samples):
    """
    Monte Carlo KL estimation with reparameterized distributions.
    
    Parameters:
    - distribution_a: first distribution (must be reparameterizable)
    - distribution_b: second distribution
    - rng_key: JAX random key
    - num_samples: number of Monte Carlo samples
    
    Returns:
    KL divergence estimate
    """

Monte Carlo Mode Estimation

Monte Carlo estimation of distribution mode.

def mc_estimate_mode(distribution, rng_key, num_samples):
    """
    Monte Carlo estimation of distribution mode.
    
    Parameters:
    - distribution: distribution whose mode to estimate
    - rng_key: JAX random key
    - num_samples: number of Monte Carlo samples
    
    Returns:
    Mode estimate
    """

Importance Sampling

Importance Sampling Ratios

Compute importance sampling ratios between distributions.

def importance_sampling_ratios(target_dist, sampling_dist, event):
    """
    Compute importance sampling ratios.
    
    Parameters:
    - target_dist: target distribution
    - sampling_dist: sampling distribution
    - event: sampled events (array)
    
    Returns:
    Importance sampling ratios
    """

Transformation Utilities

Register Inverse Functions

Register inverse functions for JAX primitives.

def register_inverse(primitive, inverse_left, inverse_right=None):
    """
    Register inverse functions for JAX primitives.
    
    Parameters:
    - primitive: JAX primitive to register inverse for
    - inverse_left: left inverse function
    - inverse_right: optional right inverse function
    """

Hidden Markov Models

HMM Class

Hidden Markov Model implementation for sequential modeling.

class HMM:
    def __init__(self, init_dist, trans_dist, obs_dist):
        """
        Hidden Markov Model.
        
        Parameters:
        - init_dist: initial state distribution
        - trans_dist: transition distribution
        - obs_dist: observation distribution
        """

    def sample(self, *, seed, seq_len):
        """
        Sample a sequence from the HMM.
        
        Parameters:
        - seed: JAX random key
        - seq_len: length of sequence to sample
        
        Returns:
        Tuple of (states, observations)
        """

    def forward(self, obs_seq, length=None):
        """
        Forward algorithm for computing marginal likelihood.
        
        Parameters:
        - obs_seq: sequence of observations (array)
        - length: optional sequence length (for batched sequences)
        
        Returns:
    Log marginal likelihood and forward probabilities (alphas)
        """

    def backward(self, obs_seq, length=None):
        """
        Backward algorithm for computing backward probabilities.
        
        Parameters:
        - obs_seq: sequence of observations (array)
        - length: optional sequence length (for batched sequences)
        
        Returns:
        Backward probabilities
        """

    def forward_backward(self, obs_seq, length=None):
        """
        Forward-backward algorithm for state posterior probabilities.
        
        Parameters:
        - obs_seq: sequence of observations (array)
        - length: optional sequence length (for batched sequences)
        
        Returns:
    Forward probabilities, backward probabilities, state posterior (smoothing) probabilities, and log marginal likelihood
        """

    def viterbi(self, obs_seq):
        """
        Viterbi algorithm for most likely state sequence.
        
        Parameters:
        - obs_seq: sequence of observations (array)
        
        Returns:
    Most likely state sequence
        """

    @property
    def init_dist(self): ...
    @property
    def trans_dist(self): ...
    @property
    def obs_dist(self): ...

    @property
    def event_shape(self): ...
    @property
    def batch_shape(self): ...

Usage Examples

Converting Between Libraries

import distrax
import tensorflow_probability.substrates.jax as tfp

# Convert TFP distribution to Distrax
tfp_normal = tfp.distributions.Normal(0.0, 1.0)
distrax_normal = distrax.as_distribution(tfp_normal)

# Convert Distrax distribution to TFP
distrax_normal = distrax.Normal(0.0, 1.0)
tfp_normal = distrax.to_tfp(distrax_normal)

Monte Carlo KL Estimation

import distrax
import jax.random as random

key = random.PRNGKey(42)
p = distrax.Normal(0.0, 1.0)
q = distrax.Normal(0.5, 1.2)

# Estimate KL divergence
kl_estimate = distrax.mc_estimate_kl(p, q, key, num_samples=10000)

Hidden Markov Model

import distrax
import jax.numpy as jnp
import jax.random as random

# Define HMM components
init_dist = distrax.Categorical(logits=jnp.array([0.0, 0.0]))
trans_dist = distrax.Categorical(logits=jnp.array([[1.0, -1.0], [-1.0, 1.0]]))
obs_dist = distrax.Normal(jnp.array([0.0, 3.0]), jnp.array([1.0, 0.5]))

# Create HMM
hmm = distrax.HMM(init_dist=init_dist, trans_dist=trans_dist, obs_dist=obs_dist)

# Sample sequence
key = random.PRNGKey(42)
states, observations = hmm.sample(seed=key, seq_len=100)

# Compute log marginal likelihood and forward probabilities
log_prob, forward_probs = hmm.forward(observations)

# Find most likely state sequence
viterbi_states = hmm.viterbi(observations)

Install with Tessl CLI

npx tessl i tessl/pypi-distrax
