Distrax: Probability distributions in JAX.
—
Invertible transformations with known Jacobian determinants for creating complex distributions through composition. Bijectors enable the construction of sophisticated probability models by transforming simple base distributions.
Abstract base class defining the bijector interface.
class Bijector:
def __init__(self, event_ndims_in, event_ndims_out=None, is_constant_jacobian=False, is_constant_log_det=None):
"""
Base class for bijectors.
Parameters:
- event_ndims_in: number of dimensions in input events
- event_ndims_out: number of dimensions in output events (defaults to event_ndims_in)
- is_constant_jacobian: whether Jacobian is constant
- is_constant_log_det: whether log determinant is constant
"""
def forward(self, x):
"""Forward transformation y = f(x)."""
def inverse(self, y):
"""Inverse transformation x = f^{-1}(y)."""
def forward_and_log_det(self, x):
"""Forward transformation with log determinant: (y, log|det J|)."""
def inverse_and_log_det(self, y):
"""Inverse transformation with log determinant: (x, log|det J^{-1}|)."""
def forward_log_det_jacobian(self, x):
"""Log determinant of forward Jacobian."""
def inverse_log_det_jacobian(self, y):
"""Log determinant of inverse Jacobian."""
def same_as(self, other):
"""Check equality with another bijector."""
@property
def event_ndims_in(self): ...
@property
def event_ndims_out(self): ...
@property
def is_constant_jacobian(self): ...
@property
def is_constant_log_det(self): ...
@property
def name(self): ...
Elementwise affine transformation y = scale * x + shift.
class ScalarAffine(Bijector):
def __init__(self, shift, scale=None, log_scale=None):
"""
Scalar affine transformation.
Parameters:
- shift: translation parameter (float or array)
- scale: scale parameter (float or array, mutually exclusive with log_scale)
- log_scale: log scale parameter (float or array, mutually exclusive with scale)
Note: At most one of scale or log_scale may be specified; if neither is given, the scale defaults to 1.
"""
@property
def shift(self): ...
@property
def scale(self): ...
@property
def log_scale(self): ...
Translation bijector y = x + shift.
class Shift(Bijector):
def __init__(self, shift):
"""
Shift transformation.
Parameters:
- shift: translation parameter (float or array)
"""
@property
def shift(self): ...
General unconstrained affine transformation.
class UnconstrainedAffine(Bijector):
def __init__(self, shift, matrix):
"""
Unconstrained affine transformation.
Parameters:
- shift: translation vector (array)
- matrix: transformation matrix (array)
"""
@property
def shift(self): ...
@property
def matrix(self): ...
Linear transformation with diagonal matrix.
class DiagLinear(Bijector):
def __init__(self, diag):
"""
Diagonal linear transformation.
Parameters:
- diag: diagonal elements (array)
"""
@property
def diag(self): ...
Linear transformation with arbitrary matrix.
class Linear(Bijector):
def __init__(self, matrix):
"""
Linear transformation.
Parameters:
- matrix: transformation matrix (array)
"""
@property
def matrix(self): ...
Linear transformation with triangular matrix.
class TriangularLinear(Bijector):
def __init__(self, matrix, lower=True):
"""
Triangular linear transformation.
Parameters:
- matrix: triangular matrix (array)
- lower: whether matrix is lower triangular (bool, default True)
"""
@property
def matrix(self): ...
@property
def lower(self): ...
Linear transformation with diagonal plus low-rank structure.
class DiagPlusLowRankLinear(Bijector):
def __init__(self, diag, u_matrix, v_matrix):
"""
Diagonal plus low-rank linear transformation.
Parameters:
- diag: diagonal component (array)
- u_matrix: U matrix for low-rank component (array)
- v_matrix: V matrix for low-rank component (array)
"""
@property
def diag(self): ...
@property
def u_matrix(self): ...
@property
def v_matrix(self): ...
Affine transformation using LU decomposition.
class LowerUpperTriangularAffine(Bijector):
def __init__(self, shift, lower_upper, permutation):
"""
Lower-upper triangular affine transformation.
Parameters:
- shift: translation vector (array)
- lower_upper: combined L and U matrices (array)
- permutation: permutation for LU decomposition (array)
"""
@property
def shift(self): ...
@property
def lower_upper(self): ...
@property
def permutation(self): ...
Sigmoid activation function bijector.
class Sigmoid(Bijector):
def __init__(self):
"""Sigmoid bijector mapping (-∞, ∞) to (0, 1)."""
Hyperbolic tangent bijector.
class Tanh(Bijector):
def __init__(self):
"""Tanh bijector mapping (-∞, ∞) to (-1, 1)."""
Gumbel cumulative distribution function bijector.
class GumbelCDF(Bijector):
def __init__(self):
"""Gumbel CDF bijector."""
Composition of bijectors applied in reverse order.
class Chain(Bijector):
def __init__(self, bijectors):
"""
Chain of bijectors.
Parameters:
- bijectors: sequence of bijectors to compose (applied in reverse order)
"""
@property
def bijectors(self): ...
Inverts another bijector.
class Inverse(Bijector):
def __init__(self, bijector):
"""
Inverse bijector.
Parameters:
- bijector: bijector to invert
"""
@property
def bijector(self): ...
Wraps callable functions as bijectors.
class Lambda(Bijector):
def __init__(self, forward_fn, inverse_fn, forward_log_det_jacobian_fn,
inverse_log_det_jacobian_fn=None, event_ndims_in=0, event_ndims_out=None):
"""
Lambda bijector from functions.
Parameters:
- forward_fn: forward transformation function
- inverse_fn: inverse transformation function
- forward_log_det_jacobian_fn: forward log Jacobian determinant function
- inverse_log_det_jacobian_fn: inverse log Jacobian determinant function
- event_ndims_in: number of input event dimensions
- event_ndims_out: number of output event dimensions
"""
@property
def forward_fn(self): ...
@property
def inverse_fn(self): ...
Bijector that acts on a subset of input dimensions.
class Block(Bijector):
def __init__(self, bijector, ndims):
"""
Block bijector.
Parameters:
- bijector: bijector to apply to subset
- ndims: number of dimensions to transform
"""
@property
def bijector(self): ...
@property
def ndims(self): ...
Masked coupling layer for normalizing flows.
class MaskedCoupling(Bijector):
def __init__(self, mask, bijector_fn):
"""
Masked coupling layer.
Parameters:
- mask: binary mask for splitting input (array)
- bijector_fn: function that creates bijector from conditioning input
"""
@property
def mask(self): ...
@property
def bijector_fn(self): ...
Split coupling layer for normalizing flows.
class SplitCoupling(Bijector):
def __init__(self, split_index, bijector_fn):
"""
Split coupling layer.
Parameters:
- split_index: index at which to split input
- bijector_fn: function that creates bijector from conditioning input
"""
@property
def split_index(self): ...
@property
def bijector_fn(self): ...
Rational quadratic spline bijector for flexible transformations.
class RationalQuadraticSpline(Bijector):
def __init__(self, bin_widths, bin_heights, knot_slopes, range_min=-1.0, range_max=1.0):
"""
Rational quadratic spline bijector.
Parameters:
- bin_widths: widths of spline bins (array)
- bin_heights: heights of spline bins (array)
- knot_slopes: slopes at knot points (array)
- range_min: minimum of transformation range (float)
- range_max: maximum of transformation range (float)
"""
@property
def bin_widths(self): ...
@property
def bin_heights(self): ...
@property
def knot_slopes(self): ...
@property
def range_min(self): ...
@property
def range_max(self): ...
from typing import Union, Callable
from chex import Array
BijectorLike = Union[Bijector, 'tfb.Bijector', Callable[[Array], Array]]
Install with Tessl CLI
npx tessl i tessl/pypi-distrax