
tessl/pypi-distrax

Distrax: Probability distributions in JAX.


Continuous Distributions

Continuous probability distributions for modeling real-valued random variables, including univariate and multivariate distributions with various parameterizations and covariance structures.

Capabilities

Univariate Normal Distribution

Univariate normal (Gaussian) distribution with location (mean) and scale (standard deviation) parameters.

```python
class Normal(Distribution):
    def __init__(self, loc, scale):
        """
        Normal distribution.

        Parameters:
        - loc: mean parameter (float or array)
        - scale: standard deviation parameter (float or array, must be positive)
        """

    @property
    def loc(self): ...
    @property
    def scale(self): ...
    @property
    def event_shape(self): ...
    @property
    def batch_shape(self): ...
```

Alternative Normal Parameterizations

Normal distribution parameterized by the log of its standard deviation, so the scale `exp(log_scale)` is positive for any real-valued `log_scale`.

```python
class LogStddevNormal(Distribution):
    def __init__(self, loc, log_scale):
        """
        Normal distribution parameterized by log standard deviation.

        Parameters:
        - loc: mean parameter (float or array)
        - log_scale: log standard deviation parameter (float or array)
        """

    @property
    def loc(self): ...
    @property
    def log_scale(self): ...
    @property
    def scale(self): ...
```

Beta Distribution

Beta distribution for modeling probabilities and proportions.

```python
class Beta(Distribution):
    def __init__(self, alpha, beta):
        """
        Beta distribution on the interval [0, 1].

        Parameters:
        - alpha: concentration parameter attached to x (must be positive)
        - beta: concentration parameter attached to (1 - x) (must be positive)
        """

    @property
    def alpha(self): ...
    @property
    def beta(self): ...
```

Gamma Distribution

Gamma distribution for modeling positive continuous variables.

```python
class Gamma(Distribution):
    def __init__(self, concentration, rate):
        """
        Gamma distribution.

        Parameters:
        - concentration: shape parameter (must be positive)
        - rate: rate (inverse scale) parameter (must be positive)
        """

    @property
    def concentration(self): ...
    @property
    def rate(self): ...
```

Laplace Distribution

Laplace (double exponential) distribution.

```python
class Laplace(Distribution):
    def __init__(self, loc, scale):
        """
        Laplace distribution.

        Parameters:
        - loc: location parameter (float or array)
        - scale: scale parameter (float or array, must be positive)
        """

    @property
    def loc(self): ...
    @property
    def scale(self): ...
```

Logistic Distribution

Logistic distribution, whose cumulative distribution function is the logistic sigmoid. Its density is bell-shaped, with heavier tails than the normal.

```python
class Logistic(Distribution):
    def __init__(self, loc, scale):
        """
        Logistic distribution.

        Parameters:
        - loc: location parameter (float or array)
        - scale: scale parameter (float or array, must be positive)
        """

    @property
    def loc(self): ...
    @property
    def scale(self): ...
```

Gumbel Distribution

Gumbel distribution for modeling extreme values.

```python
class Gumbel(Distribution):
    def __init__(self, loc, scale):
        """
        Gumbel distribution.

        Parameters:
        - loc: location parameter (float or array)
        - scale: scale parameter (float or array, must be positive)
        """

    @property
    def loc(self): ...
    @property
    def scale(self): ...
```

Uniform Distribution

Uniform distribution over a specified interval.

```python
class Uniform(Distribution):
    def __init__(self, low, high):
        """
        Uniform distribution.

        Parameters:
        - low: lower bound (float or array)
        - high: upper bound (float or array, must be > low)
        """

    @property
    def low(self): ...
    @property
    def high(self): ...
```

Von Mises Distribution

Von Mises (circular normal) distribution for circular data.

```python
class VonMises(Distribution):
    def __init__(self, loc, concentration):
        """
        Von Mises distribution.

        Parameters:
        - loc: mean direction parameter (float or array)
        - concentration: concentration parameter (float or array, must be >= 0)
        """

    @property
    def loc(self): ...
    @property
    def concentration(self): ...
```

Multivariate Normal with Diagonal Covariance

Multivariate normal distribution with diagonal covariance matrix.

```python
class MultivariateNormalDiag(Distribution):
    def __init__(self, loc, scale_diag):
        """
        Multivariate normal with diagonal covariance.

        Parameters:
        - loc: mean vector (array of shape [..., d])
        - scale_diag: diagonal standard deviations (array of shape [..., d], must be positive)
        """

    @property
    def loc(self): ...
    @property
    def scale_diag(self): ...
    @property
    def event_shape(self): ...
```

Multivariate Normal with Full Covariance

Multivariate normal distribution with full covariance matrix.

```python
class MultivariateNormalFullCovariance(Distribution):
    def __init__(self, loc, covariance_matrix):
        """
        Multivariate normal with full covariance matrix.

        Parameters:
        - loc: mean vector (array of shape [..., d])
        - covariance_matrix: covariance matrix (array of shape [..., d, d], must be positive definite)
        """

    @property
    def loc(self): ...
    @property
    def covariance_matrix(self): ...
```

Multivariate Normal with Triangular Parameterization

Multivariate normal distribution parameterized by a triangular scale matrix.

```python
class MultivariateNormalTri(Distribution):
    def __init__(self, loc, scale_tri, is_lower=True):
        """
        Multivariate normal with triangular scale parameterization.

        Parameters:
        - loc: mean vector (array of shape [..., d])
        - scale_tri: triangular scale matrix S (array of shape [..., d, d]); the covariance is S @ S.T
        - is_lower: whether scale_tri is lower (True) or upper (False) triangular
        """

    @property
    def loc(self): ...
    @property
    def scale_tri(self): ...
    @property
    def is_lower(self): ...
```

Multivariate Normal with Low-Rank Plus Diagonal

Multivariate normal with diagonal plus low-rank covariance structure.

```python
class MultivariateNormalDiagPlusLowRank(Distribution):
    def __init__(self, loc, scale_diag, scale_u_matrix, scale_v_matrix):
        """
        Multivariate normal whose scale matrix is S = diag(scale_diag) + U @ V.T.
        The covariance is S @ S.T.

        Parameters:
        - loc: mean vector (array of shape [..., d])
        - scale_diag: diagonal component of S (array of shape [..., d])
        - scale_u_matrix: low-rank factor U (array of shape [..., d, k])
        - scale_v_matrix: low-rank factor V (array of shape [..., d, k])
        """

    @property
    def loc(self): ...
    @property
    def scale_diag(self): ...
    @property
    def scale_u_matrix(self): ...
    @property
    def scale_v_matrix(self): ...
```

Multivariate Normal from Bijector

Multivariate normal distribution constructed using a bijector transformation.

```python
class MultivariateNormalFromBijector(Distribution):
    def __init__(self, loc, scale):
        """
        Multivariate normal defined as Y = scale(X) + loc, where X ~ N(0, I).

        Parameters:
        - loc: mean vector (array of shape [..., d])
        - scale: linear bijector mapping the standard normal X to the centered variable
        """

    @property
    def loc(self): ...
    @property
    def scale(self): ...
```

Dirichlet Distribution

Dirichlet distribution for modeling probability vectors.

```python
class Dirichlet(Distribution):
    def __init__(self, concentration):
        """
        Dirichlet distribution.

        Parameters:
        - concentration: concentration parameters (array of shape [..., k], must be positive)
        """

    @property
    def concentration(self): ...
    @property
    def event_shape(self): ...
```

Install with Tessl CLI

npx tessl i tessl/pypi-distrax
