
tessl/pypi-distrax

Distrax: Probability distributions in JAX.


Bijectors

Bijectors are invertible transformations whose Jacobian determinants can be computed in closed form. Applying them to simple base distributions, alone or composed, yields more expressive distributions whose densities remain tractable.

Capabilities

Base Bijector Class

Abstract base class defining the bijector interface.

class Bijector:
    def __init__(self, event_ndims_in, event_ndims_out=None, is_constant_jacobian=False, is_constant_log_det=None):
        """
        Base class for bijectors.
        
        Parameters:
        - event_ndims_in: number of dimensions in input events
        - event_ndims_out: number of dimensions in output events (defaults to event_ndims_in)
        - is_constant_jacobian: whether the Jacobian is constant, i.e. independent of the input
        - is_constant_log_det: whether the log determinant is constant (defaults to is_constant_jacobian)
        """

    def forward(self, x):
        """Forward transformation y = f(x)."""

    def inverse(self, y):
        """Inverse transformation x = f^{-1}(y)."""

    def forward_and_log_det(self, x):
        """Forward transformation with log determinant: returns (y, log|det J_f(x)|)."""

    def inverse_and_log_det(self, y):
        """Inverse transformation with log determinant: returns (x, log|det J_{f^{-1}}(y)|)."""

    def forward_log_det_jacobian(self, x):
        """Log absolute determinant of the forward Jacobian, log|det J_f(x)|."""

    def inverse_log_det_jacobian(self, y):
        """Log absolute determinant of the inverse Jacobian, log|det J_{f^{-1}}(y)|."""

    def same_as(self, other):
        """Check equality with another bijector."""

    @property
    def event_ndims_in(self): ...
    @property
    def event_ndims_out(self): ...
    @property
    def is_constant_jacobian(self): ...
    @property
    def is_constant_log_det(self): ...
    @property
    def name(self): ...
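The interface above supplies what the change-of-variables formula needs: for y = f(x), log p_Y(y) = log p_X(f^{-1}(y)) + log|det J_{f^{-1}}(y)|. The plain-Python sketch below mimics forward_and_log_det and inverse_and_log_det with a hypothetical scalar ExpBijector (not a Distrax class) and checks that pushing a standard normal through exp reproduces the log-normal density. It illustrates the math only, not Distrax's implementation.

```python
import math

class ExpBijector:
    """Hypothetical scalar bijector y = exp(x), mirroring the Bijector interface."""

    def forward_and_log_det(self, x):
        y = math.exp(x)
        return y, x  # log|dy/dx| = log(exp(x)) = x

    def inverse_and_log_det(self, y):
        x = math.log(y)
        return x, -x  # log|dx/dy| = -log(y)

def std_normal_log_prob(x):
    return -0.5 * x * x - 0.5 * math.log(2.0 * math.pi)

def transformed_log_prob(bijector, y):
    # Change of variables: log p_Y(y) = log p_X(f^{-1}(y)) + log|det J_{f^{-1}}(y)|
    x, inverse_log_det = bijector.inverse_and_log_det(y)
    return std_normal_log_prob(x) + inverse_log_det

def lognormal_log_prob(y):
    # Closed-form log density of LogNormal(0, 1), used as a reference.
    return -math.log(y) - 0.5 * math.log(2.0 * math.pi) - 0.5 * math.log(y) ** 2
```

Note that the forward log determinant at x and the inverse log determinant at y = f(x) are negatives of each other, which is why a bijector only needs to implement one of them efficiently.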

Affine Transformations

Scalar Affine Transformation

Elementwise affine transformation y = scale * x + shift.

class ScalarAffine(Bijector):
    def __init__(self, shift, scale=None, log_scale=None):
        """
        Scalar affine transformation.
        
        Parameters:
        - shift: translation parameter (float or array)
        - scale: scale parameter (float or array, mutually exclusive with log_scale)
        - log_scale: log scale parameter (float or array, mutually exclusive with scale)
        
        Note: at most one of scale or log_scale may be specified (log_scale sets scale = exp(log_scale)); if neither is given, scale defaults to 1.
        """

    @property
    def shift(self): ...
    @property
    def scale(self): ...
    @property
    def log_scale(self): ...
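A NumPy sketch of what a scalar affine bijector computes. Because the transformation is elementwise, the log determinant here is reported per element, one log|scale| for each entry of x. The function names are hypothetical, and falling back to scale = 1 when neither scale nor log_scale is given is an assumption of this sketch.

```python
import numpy as np

def scalar_affine_forward_and_log_det(x, shift, scale=None, log_scale=None):
    """Hypothetical NumPy sketch of y = scale * x + shift (not Distrax code)."""
    if scale is not None and log_scale is not None:
        raise ValueError("Specify at most one of scale and log_scale.")
    if log_scale is not None:
        scale = np.exp(log_scale)  # log_scale is unconstrained; exp keeps scale > 0
    elif scale is None:
        scale = 1.0  # assumed default when neither argument is given
    x = np.asarray(x, dtype=float)
    y = scale * x + shift
    # Elementwise bijector: one log|scale| per element, broadcast to y's shape.
    log_det = np.broadcast_to(np.log(np.abs(scale)), y.shape)
    return y, log_det

def scalar_affine_inverse(y, shift, scale):
    # Inverse of y = scale * x + shift; scale must be non-zero.
    return (y - shift) / scale
```

Parameterizing by log_scale is convenient when the scale is learned, since any real value maps to a valid positive scale.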

Shift Transformation

Translation bijector y = x + shift.

class Shift(Bijector):
    def __init__(self, shift):
        """
        Shift transformation.
        
        Parameters:
        - shift: translation parameter (float or array)
        """

    @property
    def shift(self): ...

Unconstrained Affine Transformation

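Affine transformation y = matrix @ x + bias with an arbitrary square matrix. Both the inverse and the log determinant cost O(D^3) for D-dimensional events, so this bijector is best suited to small D.

class UnconstrainedAffine(Bijector):
    def __init__(self, matrix, bias):
        """
        Unconstrained affine transformation.

        Parameters:
        - matrix: D x D transformation matrix (array); must be invertible, which is not checked
        - bias: translation vector of length D (array)
        """

    @property
    def matrix(self): ...
    @property
    def bias(self): ...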
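A NumPy sketch of the computation an unconstrained affine bijector performs on one D-dimensional event, y = A x + b, with hypothetical helper and parameter names. The Jacobian is A itself, so the log determinant does not depend on x; both it and the inverse cost O(D^3), computed here with np.linalg.slogdet and np.linalg.solve.

```python
import numpy as np

def affine_forward_and_log_det(x, matrix, bias):
    """Sketch of y = matrix @ x + bias for a single D-dimensional event."""
    y = matrix @ x + bias
    # The Jacobian is `matrix` itself, so log|det J| is the same for every x.
    _, log_abs_det = np.linalg.slogdet(matrix)
    return y, log_abs_det

def affine_inverse(y, matrix, bias):
    # Solve matrix @ x = y - bias rather than forming the inverse explicitly.
    return np.linalg.solve(matrix, y - bias)
```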

Linear Transformations

Diagonal Linear Transformation

Linear transformation with diagonal matrix.

class DiagLinear(Bijector):
    def __init__(self, diag):
        """
        Diagonal linear transformation.
        
        Parameters:
        - diag: diagonal elements (array)
        """

    @property
    def diag(self): ...
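For a diagonal matrix the forward map, the inverse and the log determinant are all elementwise and cost O(D). A NumPy check of that claim with illustrative values, compared against the dense computation:

```python
import numpy as np

diag = np.array([2.0, -0.5, 4.0])   # illustrative diagonal; entries must be non-zero
x = np.array([1.0, 2.0, 3.0])

y = diag * x                                # same result as np.diag(diag) @ x, in O(D)
x_back = y / diag                           # inverse: elementwise division
log_det = np.sum(np.log(np.abs(diag)))      # log|det diag(d)| = sum_i log|d_i|
```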

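Linear Base Class

Abstract base class for linear bijectors y = matrix @ x. It is not instantiated directly: DiagLinear, TriangularLinear and DiagPlusLowRankLinear are concrete subclasses, and UnconstrainedAffine covers an arbitrary square matrix.

class Linear(Bijector):
    def __init__(self, event_dims, batch_shape, dtype):
        """
        Base class for linear bijectors.

        Parameters:
        - event_dims: dimensionality D of the (vector) events
        - batch_shape: batch shape of the bijector's parameters
        - dtype: data type of the bijector's parameters
        """

    @property
    def matrix(self): ...  # provided by subclasses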

Triangular Linear Transformation

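Linear transformation y = L @ x, where L is a lower (or upper) triangular matrix with a non-zero diagonal. Its log determinant is the sum of log|L_ii|, which costs O(D).

class TriangularLinear(Bijector):
    def __init__(self, matrix, is_lower=True):
        """
        Triangular linear transformation.

        Parameters:
        - matrix: square matrix whose lower (or upper) triangular part defines L (array)
        - is_lower: whether L is the lower triangular part of matrix (bool, default True); if False, the upper triangular part is used
        """

    @property
    def matrix(self): ...
    @property
    def is_lower(self): ...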
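A NumPy sketch of the triangular case for a single event (hypothetical helper; the flag name is illustrative). The determinant of a triangular matrix is the product of its diagonal, so the log determinant needs no factorization.

```python
import numpy as np

def triangular_linear_forward_and_log_det(x, matrix, is_lower=True):
    """Sketch: apply the lower (or upper) triangular part of `matrix` to x."""
    tri = np.tril(matrix) if is_lower else np.triu(matrix)
    y = tri @ x
    # det of a triangular matrix = product of its diagonal entries.
    log_det = np.sum(np.log(np.abs(np.diag(tri))))
    return y, log_det
```

A dedicated implementation would invert with a triangular solve in O(D^2); np.linalg.solve on the triangular matrix gives the same answer for checking purposes.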

Diagonal Plus Low-Rank Linear

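Linear transformation y = (diag(diag) + u_matrix @ v_matrix.T) @ x, i.e. a diagonal matrix plus a rank-K update. The low-rank structure allows the determinant to be computed without factorizing a dense D x D matrix.

class DiagPlusLowRankLinear(Bijector):
    def __init__(self, diag, u_matrix, v_matrix):
        """
        Diagonal plus low-rank linear transformation.

        Parameters:
        - diag: diagonal component, a vector of length D (array)
        - u_matrix: D x K matrix U of the low-rank component (array)
        - v_matrix: D x K matrix V of the low-rank component (array)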
        """

    @property
    def diag(self): ...
    @property
    def u_matrix(self): ...
    @property
    def v_matrix(self): ...
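The usual reason for this structure is the matrix determinant lemma, det(S + U V^T) = det(I_K + V^T S^{-1} U) * det(S), which replaces a D x D determinant with a K x K one. The NumPy sketch below checks the identity on random values; it shows the math and is not necessarily how Distrax evaluates the determinant.

```python
import numpy as np

rng = np.random.default_rng(0)
D, K = 5, 2
diag = rng.uniform(1.0, 2.0, size=D)   # diagonal S (kept away from zero)
u = 0.3 * rng.normal(size=(D, K))      # U: D x K
v = 0.3 * rng.normal(size=(D, K))      # V: D x K

# Dense matrix, O(D^3) to factorize; used here only as a reference.
matrix = np.diag(diag) + u @ v.T

# Matrix determinant lemma: only a K x K determinant plus an O(D) sum.
small = np.eye(K) + v.T @ (u / diag[:, None])   # I_K + V^T S^{-1} U
_, log_det_small = np.linalg.slogdet(small)
log_det_lemma = log_det_small + np.sum(np.log(np.abs(diag)))
```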

Lower-Upper Triangular Affine

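Affine transformation y = A @ x + bias, where A = L @ U is parameterized by its LU decomposition: L is lower triangular with ones on the diagonal and U is upper triangular. Because det(L) = 1, log|det A| is the sum of log|U_ii|, which costs O(D).

class LowerUpperTriangularAffine(Bijector):
    def __init__(self, matrix, bias):
        """
        Lower-upper triangular affine transformation.

        Parameters:
        - matrix: square matrix whose strictly lower triangular part defines L (its unit diagonal is implied) and whose upper triangular part, including the diagonal, defines U (array)
        - bias: translation vector (array)
        """

    @property
    def lower(self): ...
    @property
    def upper(self): ...
    @property
    def matrix(self): ...  # the product A = L @ U
    @property
    def bias(self): ...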
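A NumPy sketch of the LU parameterization, assuming L takes its strictly lower triangle from a packed square matrix with an implied unit diagonal and U takes the upper triangle including the diagonal. The helper is hypothetical; it shows why the log determinant reduces to a sum over diag(U).

```python
import numpy as np

def lu_affine_forward_and_log_det(x, matrix, bias):
    """Sketch: A = L @ U with L unit-lower-triangular and U upper-triangular."""
    d = matrix.shape[-1]
    lower = np.tril(matrix, k=-1) + np.eye(d)   # strictly-lower part plus unit diagonal
    upper = np.triu(matrix)                     # upper part including the diagonal
    y = lower @ (upper @ x) + bias
    # det(L) = 1, so log|det A| = sum log|diag(U)|: O(D) instead of O(D^3).
    log_det = np.sum(np.log(np.abs(np.diag(upper))))
    return y, log_det, lower @ upper
```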

Activation Function Bijectors

Sigmoid Bijector

Sigmoid activation function bijector.

class Sigmoid(Bijector):
    def __init__(self):
        """Sigmoid bijector mapping (-∞, ∞) to (0, 1)."""
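For the sigmoid, dy/dx = sigmoid(x) * sigmoid(-x), so log|dy/dx| = -softplus(-x) - softplus(x), and the inverse is the logit. The stdlib sketch below (hypothetical helpers, scalar inputs) uses that numerically stable form and checks it against a finite difference.

```python
import math

def softplus(z):
    # Numerically stable log(1 + exp(z)).
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def logit(y):
    # Inverse of the sigmoid on (0, 1).
    return math.log(y) - math.log1p(-y)

def sigmoid_forward_log_det(x):
    # log sigmoid(x) + log sigmoid(-x) = -softplus(-x) - softplus(x)
    return -softplus(-x) - softplus(x)
```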

Tanh Bijector

Hyperbolic tangent bijector.

class Tanh(Bijector):
    def __init__(self):
        """Tanh bijector mapping (-∞, ∞) to (-1, 1)."""
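The derivative of tanh is 1 - tanh(x)^2, which rounds to zero for large |x| because tanh(x) rounds to exactly 1. A stable equivalent is log(1 - tanh(x)^2) = 2 * (log 2 - x - softplus(-2x)). The sketch below (hypothetical helper names) checks this identity; whether Distrax uses exactly this form is not confirmed here.

```python
import math

def softplus(z):
    # Numerically stable log(1 + exp(z)).
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))

def tanh_forward_log_det(x):
    # log(1 - tanh(x)^2) = log(4 e^{-2x} / (1 + e^{-2x})^2)
    #                    = 2 * (log 2 - x - softplus(-2x))
    return 2.0 * (math.log(2.0) - x - softplus(-2.0 * x))
```

The inverse is math.atanh, defined on (-1, 1).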

CDF Bijectors

Gumbel CDF Bijector

Gumbel cumulative distribution function bijector.

class GumbelCDF(Bijector):
    def __init__(self):
        """Gumbel CDF bijector."""
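Assuming the standard Gumbel distribution (location 0, scale 1), the bijector is y = exp(-exp(-x)), with inverse x = -log(-log y) and log|dy/dx| = -x - exp(-x). A stdlib sketch with hypothetical helper names:

```python
import math

def gumbel_cdf(x):
    # Standard Gumbel CDF: maps the real line onto (0, 1).
    return math.exp(-math.exp(-x))

def gumbel_cdf_inverse(y):
    return -math.log(-math.log(y))

def gumbel_cdf_forward_log_det(x):
    # dy/dx = exp(-x) * exp(-exp(-x)), so log|dy/dx| = -x - exp(-x).
    return -x - math.exp(-x)
```

Applied in the inverse direction, this maps uniform samples on (0, 1) to standard Gumbel samples.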

Composition and Meta-Bijectors

Chain Bijector

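Composition of bijectors. Chain([f, g]) computes f(g(x)): the last bijector in the sequence is applied first, and the log determinants of the individual steps add up.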

class Chain(Bijector):
    def __init__(self, bijectors):
        """
        Chain of bijectors.
        
        Parameters:
        - bijectors: sequence of bijectors to compose (applied in reverse order)
        """

    @property
    def bijectors(self): ...
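The ordering convention is easy to get backwards. The sketch below models each scalar bijector as a hypothetical (forward, log_det) pair of plain functions and reproduces the convention of applying the sequence in reverse: Chain([f, g]) evaluates f(g(x)), and the total log determinant is the sum of each step's log determinant evaluated at that step's input.

```python
import math

# Hypothetical scalar bijectors as (forward, forward_log_det) pairs.
exp_bij = (math.exp, lambda x: x)                  # y = exp(x),  log|dy/dx| = x
affine_bij = (lambda x: 2.0 * x + 1.0,             # y = 2x + 1,  log|dy/dx| = log 2
              lambda x: math.log(2.0))

def chain_forward_and_log_det(bijectors, x):
    """Mimic Chain([f, g]): apply the LAST bijector first, i.e. f(g(x))."""
    log_det = 0.0
    for forward, forward_log_det in reversed(bijectors):
        log_det += forward_log_det(x)  # evaluated at this step's input
        x = forward(x)
    return x, log_det
```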

Inverse Bijector

Inverts another bijector.

class Inverse(Bijector):
    def __init__(self, bijector):
        """
        Inverse bijector.
        
        Parameters:
        - bijector: bijector to invert
        """

    @property
    def bijector(self): ...
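Inverting a bijector swaps its two directions, so the inverted bijector's forward log determinant at y equals the original's inverse log determinant at y. A minimal plain-Python sketch, where ExpBijector and InverseSketch are hypothetical stand-ins rather than Distrax classes:

```python
import math

class ExpBijector:
    """Hypothetical scalar bijector y = exp(x)."""
    def forward_and_log_det(self, x):
        return math.exp(x), x
    def inverse_and_log_det(self, y):
        return math.log(y), -math.log(y)

class InverseSketch:
    """Swaps the roles of forward and inverse of the wrapped bijector."""
    def __init__(self, bijector):
        self.bijector = bijector
    def forward_and_log_det(self, x):
        return self.bijector.inverse_and_log_det(x)
    def inverse_and_log_det(self, y):
        return self.bijector.forward_and_log_det(y)
```

Here InverseSketch(ExpBijector()) behaves as a log bijector: y = log(x) with log|dy/dx| = -log(x).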

Lambda Bijector

Wraps callable functions as bijectors.

class Lambda(Bijector):
    def __init__(self, forward_fn, inverse_fn, forward_log_det_jacobian_fn, 
                 inverse_log_det_jacobian_fn=None, event_ndims_in=0, event_ndims_out=None):
        """
        Lambda bijector from functions.
        
        Parameters:
        - forward_fn: forward transformation function
        - inverse_fn: inverse transformation function
        - forward_log_det_jacobian_fn: forward log Jacobian determinant function
        - inverse_log_det_jacobian_fn: inverse log Jacobian determinant function
        - event_ndims_in: number of input event dimensions
        - event_ndims_out: number of output event dimensions
        """

    @property
    def forward_fn(self): ...
    @property
    def inverse_fn(self): ...

Block Bijector

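Treats ndims additional trailing dimensions as part of the event: the transformation itself is unchanged, but the log determinant is summed over those dimensions. A typical use is lifting an elementwise (scalar) bijector to act on vector-valued events.

class Block(Bijector):
    def __init__(self, bijector, ndims):
        """
        Block bijector.

        Parameters:
        - bijector: bijector to wrap
        - ndims: number of additional trailing dimensions to treat as part of the event
        """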

    @property
    def bijector(self): ...
    @property
    def ndims(self): ...
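A NumPy sketch of the log-determinant bookkeeping, assuming an elementwise inner bijector y = exp(x) (whose per-element log|dy/dx| is x) and a hypothetical helper block_log_det. Treating more trailing axes as the event changes only how many axes are summed:

```python
import numpy as np

x = np.arange(12.0).reshape(3, 4) / 10.0   # a batch of 3 events, each of shape (4,)
elementwise_log_det = x                    # for y = exp(x), log|dy/dx| = x per element

def block_log_det(log_det, ndims):
    # Sum per-element log dets over the trailing `ndims` axes (the event axes).
    if ndims == 0:
        return log_det
    return log_det.sum(axis=tuple(range(-ndims, 0)))
```

With ndims=1 each of the 3 events gets a single log determinant; with ndims=2 the whole (3, 4) array is one event.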

Normalizing Flow Bijectors

Masked Coupling Layer

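Coupling layer for normalizing flows. Elements where mask is True pass through unchanged and are fed to a conditioner; the conditioner's output parameterizes an inner bijector that transforms the remaining elements. Because the conditioner sees only unchanged elements, the layer is invertible for any conditioner, and its log determinant is that of the inner bijector on the transformed elements.

class MaskedCoupling(Bijector):
    def __init__(self, mask, conditioner, bijector, event_ndims=None, inner_event_ndims=0):
        """
        Masked coupling layer.

        Parameters:
        - mask: boolean mask with the event shape; True marks elements that are left unchanged (array)
        - conditioner: function mapping the masked input to the inner bijector's parameters (e.g. a neural network)
        - bijector: function mapping those parameters to the inner bijector
        - event_ndims: number of event dimensions (optional)
        - inner_event_ndims: number of event dimensions of the inner bijector (default 0, i.e. elementwise)
        """

    @property
    def conditioner(self): ...
    @property
    def bijector(self): ...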
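A NumPy sketch of one masked coupling step on a single 1-D event, assuming mask entries set to True mark the unchanged (conditioning) elements. The conditioner is hypothetical, and an affine transformation x * exp(log_scale) + shift stands in for the inner bijector. Because the conditioner only sees entries the layer leaves unchanged, the inverse can recompute the same parameters from y.

```python
import numpy as np

def conditioner(masked_x):
    """Hypothetical conditioner; in practice usually a neural network."""
    total = masked_x.sum()
    log_scale = np.full(masked_x.shape, np.tanh(total))
    shift = np.full(masked_x.shape, 0.5 * total)
    return log_scale, shift

def masked_coupling_forward_and_log_det(x, mask):
    masked_x = np.where(mask, x, 0.0)              # conditioner only sees unchanged entries
    log_scale, shift = conditioner(masked_x)
    y_inner = x * np.exp(log_scale) + shift        # affine inner transformation
    y = np.where(mask, x, y_inner)                 # mask == True: passed through unchanged
    log_det = np.sum(np.where(mask, 0.0, log_scale))  # only transformed entries contribute
    return y, log_det

def masked_coupling_inverse(y, mask):
    masked_y = np.where(mask, y, 0.0)              # equals masked_x: those entries are unchanged
    log_scale, shift = conditioner(masked_y)
    x_inner = (y - shift) * np.exp(-log_scale)
    return np.where(mask, y, x_inner)
```

In practice, consecutive coupling layers alternate or shuffle the mask so that every element is eventually transformed.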

Split Coupling Layer

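Coupling layer that splits the input at split_index along split_axis. By default the first part, x[..., :split_index], is left unchanged and conditions an inner bijector that transforms the second part; swap=True exchanges the two roles.

class SplitCoupling(Bijector):
    def __init__(self, split_index, event_ndims, conditioner, bijector, swap=False, split_axis=-1):
        """
        Split coupling layer.

        Parameters:
        - split_index: index at which to split the input along split_axis
        - event_ndims: number of event dimensions
        - conditioner: function mapping the unchanged part to the inner bijector's parameters
        - bijector: function mapping those parameters to the inner bijector
        - swap: if True, transform the first part and condition on the second (bool, default False)
        - split_axis: axis along which to split (int, default -1)
        """

    @property
    def conditioner(self): ...
    @property
    def bijector(self): ...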
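The split variant does the same bookkeeping with slicing instead of a mask. A NumPy sketch for a single 1-D event, assuming the first split_index entries are kept and condition an affine transformation of the rest; the conditioner and helper names are hypothetical:

```python
import numpy as np

def conditioner(x1, size):
    """Hypothetical conditioner producing `size` (log_scale, shift) values from x1."""
    return np.full(size, np.tanh(x1.sum())), np.full(size, x1.mean())

def split_coupling_forward_and_log_det(x, split_index):
    x1, x2 = x[:split_index], x[split_index:]      # x1 is kept, x2 is transformed
    log_scale, shift = conditioner(x1, x2.shape[0])
    y2 = x2 * np.exp(log_scale) + shift
    return np.concatenate([x1, y2]), np.sum(log_scale)

def split_coupling_inverse(y, split_index):
    y1, y2 = y[:split_index], y[split_index:]      # y1 == x1, so the parameters are recovered
    log_scale, shift = conditioner(y1, y2.shape[0])
    return np.concatenate([y1, (y2 - shift) * np.exp(-log_scale)])
```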

Rational Quadratic Spline

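Rational quadratic spline bijector for flexible elementwise transformations, following the neural spline flow construction. The spline is monotonic on [range_min, range_max], and its bin widths, bin heights and knot slopes are packed into a single parameter array.

class RationalQuadraticSpline(Bijector):
    def __init__(self, params, range_min, range_max, boundary_slopes='unconstrained', min_bin_size=1e-4, min_knot_slope=1e-4):
        """
        Rational quadratic spline bijector.

        Parameters:
        - params: unconstrained spline parameters whose last axis has size 3 * num_bins + 1 (bin widths, bin heights and knot slopes) (array)
        - range_min: lower bound of the spline's range (float)
        - range_max: upper bound of the spline's range (float)
        - boundary_slopes: how the slopes at the range boundaries are constrained (str, default 'unconstrained')
        - min_bin_size: minimum bin width and height (float)
        - min_knot_slope: minimum slope at the knots (float)
        """

    @property
    def num_bins(self): ...
    @property
    def knot_slopes(self): ...
    @property
    def x_pos(self): ...
    @property
    def y_pos(self): ...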
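Within each bin, the neural spline flow construction (Durkan et al., 2019) uses a ratio of quadratics that is monotonic whenever the knot derivatives are positive. A stdlib sketch of that per-bin formula, with hypothetical argument names (knot positions x_k, x_k1, y_k, y_k1 and knot derivatives d_k, d_k1); it illustrates the math only and is not Distrax's implementation, which also has to locate the bin containing each input.

```python
def rq_spline_bin(x, x_k, x_k1, y_k, y_k1, d_k, d_k1):
    """Map x in the bin [x_k, x_k1] onto [y_k, y_k1]; returns (y, dy/dx)."""
    w = x_k1 - x_k
    s = (y_k1 - y_k) / w                  # average slope across the bin
    xi = (x - x_k) / w                    # relative position in the bin, in [0, 1]
    denom = s + (d_k1 + d_k - 2.0 * s) * xi * (1.0 - xi)
    y = y_k + (y_k1 - y_k) * (s * xi**2 + d_k * xi * (1.0 - xi)) / denom
    dy_dx = s**2 * (d_k1 * xi**2 + 2.0 * s * xi * (1.0 - xi)
                    + d_k * (1.0 - xi) ** 2) / denom**2
    return y, dy_dx
```

At the bin edges the map hits the knots exactly and its slope equals d_k and d_k1, so adjacent bins join with matching values and derivatives; the forward log determinant is log(dy_dx).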

Types

from typing import Union, Callable
from chex import Array

BijectorLike = Union[Bijector, 'tfb.Bijector', Callable[[Array], Array]]

Install with Tessl CLI

npx tessl i tessl/pypi-distrax
