Distrax: Probability distributions in JAX.

Helper functions for distribution conversion, Monte Carlo estimation, mathematical operations, and Hidden Markov Models for advanced probabilistic modeling.
Converts distribution-like objects to Distrax distributions.
def as_distribution(obj):
"""
Convert distribution-like object to Distrax distribution.
Parameters:
- obj: DistributionLike object (Distrax or TFP distribution)
Returns:
Distrax Distribution
"""Converts bijector-like objects to Distrax bijectors.
def as_bijector(obj):
"""
Convert bijector-like object to Distrax bijector.
Parameters:
- obj: BijectorLike object (Distrax bijector, TFP bijector, or callable)
Returns:
Distrax Bijector
"""Converts Distrax objects to TFP-compatible equivalents.
def to_tfp(obj, name=None):
"""
Convert Distrax object to TFP-compatible equivalent.
Parameters:
- obj: Distrax distribution or bijector
- name: optional name for the TFP object
Returns:
TFP-compatible distribution or bijector
"""Element-wise multiplication that returns 0 if second argument is zero.
def multiply_no_nan(x, y):
"""
Element-wise multiplication that returns 0 if y is 0.
Parameters:
- x: first operand (array)
- y: second operand (array)
Returns:
Element-wise product with NaN-safe handling
"""Estimates KL divergence exactly if possible, otherwise uses Monte Carlo.
def estimate_kl_best_effort(distribution_a, distribution_b, rng_key, num_samples, proposal_distribution=None):
"""
Estimate KL divergence using best available method.
Parameters:
- distribution_a: first distribution
- distribution_b: second distribution
- rng_key: JAX random key
- num_samples: number of Monte Carlo samples
- proposal_distribution: optional proposal distribution for importance sampling
Returns:
KL divergence estimate
"""Monte Carlo estimation of KL divergence using DiCE estimator.
def mc_estimate_kl(distribution_a, distribution_b, rng_key, num_samples, proposal_distribution=None):
"""
Monte Carlo estimation of KL divergence.
Parameters:
- distribution_a: first distribution
- distribution_b: second distribution
- rng_key: JAX random key
- num_samples: number of Monte Carlo samples
- proposal_distribution: optional proposal distribution for importance sampling
Returns:
KL divergence estimate
"""Monte Carlo KL estimation with reparameterized distributions.
def mc_estimate_kl_with_reparameterized(distribution_a, distribution_b, rng_key, num_samples):
"""
Monte Carlo KL estimation with reparameterized distributions.
Parameters:
- distribution_a: first distribution (must be reparameterizable)
- distribution_b: second distribution
- rng_key: JAX random key
- num_samples: number of Monte Carlo samples
Returns:
KL divergence estimate
"""Monte Carlo estimation of distribution mode.
def mc_estimate_mode(distribution, rng_key, num_samples):
"""
Monte Carlo estimation of distribution mode.
Parameters:
- distribution: distribution to estimate mode
- rng_key: JAX random key
- num_samples: number of Monte Carlo samples
Returns:
Mode estimate
"""Compute importance sampling ratios between distributions.
def importance_sampling_ratios(target_dist, sampling_dist, event):
"""
Compute importance sampling ratios.
Parameters:
- target_dist: target distribution
- sampling_dist: sampling distribution
- event: sampled events (array)
Returns:
Importance sampling ratios
"""Register inverse functions for JAX primitives.
def register_inverse(primitive, inverse_left, inverse_right=None):
"""
Register inverse functions for JAX primitives.
Parameters:
- primitive: JAX primitive to register inverse for
- inverse_left: left inverse function
- inverse_right: optional right inverse function
"""Hidden Markov Model implementation for sequential modeling.
class HMM:
def __init__(self, init_dist, trans_dist, obs_dist):
"""
Hidden Markov Model.
Parameters:
- init_dist: initial state distribution
- trans_dist: transition distribution
- obs_dist: observation distribution
"""
def sample(self, *, seed, seq_len):
"""
Sample a sequence from the HMM.
Parameters:
- seed: JAX random key
- seq_len: length of sequence to sample
Returns:
Tuple of (states, observations)
"""
def forward(self, obs_seq, length=None):
"""
Forward algorithm for computing marginal likelihood.
Parameters:
- obs_seq: sequence of observations (array)
- length: optional sequence length (for batched sequences)
Returns:
Forward probabilities and log marginal likelihood
"""
def backward(self, obs_seq, length=None):
"""
Backward algorithm for computing backward probabilities.
Parameters:
- obs_seq: sequence of observations (array)
- length: optional sequence length (for batched sequences)
Returns:
Backward probabilities
"""
def forward_backward(self, obs_seq, length=None):
"""
Forward-backward algorithm for state posterior probabilities.
Parameters:
- obs_seq: sequence of observations (array)
- length: optional sequence length (for batched sequences)
Returns:
State posterior probabilities and log marginal likelihood
"""
def viterbi(self, obs_seq):
"""
Viterbi algorithm for most likely state sequence.
Parameters:
- obs_seq: sequence of observations (array)
Returns:
Most likely state sequence and its log probability
"""
@property
def init_dist(self): ...
@property
def trans_dist(self): ...
@property
def obs_dist(self): ...
@property
def event_shape(self): ...
@property
    def batch_shape(self): ...

Example: converting between Distrax and TFP objects.

import distrax
import tensorflow_probability.substrates.jax as tfp
# Convert TFP distribution to Distrax
tfp_normal = tfp.distributions.Normal(0.0, 1.0)
distrax_normal = distrax.as_distribution(tfp_normal)
# Convert Distrax distribution to TFP
distrax_normal = distrax.Normal(0.0, 1.0)
tfp_normal = distrax.to_tfp(distrax_normal)

Example: Monte Carlo KL divergence estimation.

import distrax
import jax.random as random
key = random.PRNGKey(42)
p = distrax.Normal(0.0, 1.0)
q = distrax.Normal(0.5, 1.2)
# Estimate KL divergence
kl_estimate = distrax.mc_estimate_kl(p, q, key, num_samples=10000)

Example: building and using a Hidden Markov Model.

import distrax
import jax.numpy as jnp
import jax.random as random
# Define HMM components
init_dist = distrax.Categorical(logits=jnp.array([0.0, 0.0]))
trans_dist = distrax.Categorical(logits=jnp.array([[1.0, -1.0], [-1.0, 1.0]]))
obs_dist = distrax.Normal(jnp.array([0.0, 3.0]), jnp.array([1.0, 0.5]))
# Create HMM
hmm = distrax.HMM(init_dist, trans_dist, obs_dist)
# Sample sequence
key = random.PRNGKey(42)
states, observations = hmm.sample(seed=key, seq_len=100)
# Compute forward probabilities
forward_probs, log_prob = hmm.forward(observations)
# Find most likely state sequence
viterbi_states, viterbi_log_prob = hmm.viterbi(observations)

Install with the Tessl CLI:
npx tessl i tessl/pypi-distrax