tessl/pypi-distrax

Distrax: Probability distributions in JAX.


Discrete Distributions

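Discrete probability distributions for binary, categorical, and count outcomes. Most classes accept either logits or probabilities as their parameterization. The variants differ in how samples are encoded (integer index or one-hot vector), how many trials are counted, and whether the logits are temperature-scaled.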

Capabilities

Bernoulli Distribution

Distribution over binary outcomes {0, 1}, parameterized either by the log-odds or by the probability of observing 1.

class Bernoulli(Distribution):
    def __init__(self, logits=None, probs=None, dtype=int):
        """
        Bernoulli distribution for binary outcomes.
        
        Parameters:
        - logits: log-odds parameter (float or array, mutually exclusive with probs)
        - probs: probability parameter (float or array in [0,1], mutually exclusive with logits)
        - dtype: output data type (int or float)
        
        Note: Exactly one of logits or probs must be specified.
        """

    @property
    def logits(self): ...
    @property
    def probs(self): ...
    @property
    def dtype(self): ...
    @property
    def event_shape(self): ...
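The two parameterizations are related by the logistic sigmoid, probs = 1 / (1 + exp(-logits)). The plain-Python sketch below illustrates that relationship and the log-probability of an outcome computed directly from the log-odds, which avoids taking the log of a probability that has rounded to 0 or 1. It shows the standard Bernoulli math only, not distrax's implementation; `sigmoid`, `logit`, `softplus`, and `bernoulli_log_prob` are hypothetical helper names.

```python
import math

def sigmoid(x: float) -> float:
    """Map log-odds to a probability, avoiding overflow for large |x|."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)

def logit(p: float) -> float:
    """Inverse of sigmoid: map a probability in (0, 1) to log-odds."""
    return math.log(p) - math.log1p(-p)

def softplus(x: float) -> float:
    """log(1 + exp(x)), written to avoid overflow for large x."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def bernoulli_log_prob(value: int, logits: float) -> float:
    """log p(value) for value in {0, 1}, computed directly from log-odds.

    Uses log sigmoid(x) = -softplus(-x) and log(1 - sigmoid(x)) = -softplus(x).
    """
    return -softplus(-logits) if value == 1 else -softplus(logits)

print(sigmoid(0.0))                                  # 0.5
print(math.exp(bernoulli_log_prob(1, logit(0.8))))   # ≈ 0.8
```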

Categorical Distribution

Distribution over k categories whose samples are category indices in {0, ..., k-1}.

class Categorical(Distribution):
    def __init__(self, logits=None, probs=None, dtype=int):
        """
        Categorical distribution for discrete outcomes.
        
        Parameters:
        - logits: unnormalized log-probabilities (array of shape [..., k], mutually exclusive with probs)
        - probs: probabilities (array of shape [..., k] summing to 1 along the last axis, mutually exclusive with logits)
        - dtype: output data type (int or float)
        
        Note: Exactly one of logits or probs must be specified.
        """

    @property
    def logits(self): ...
    @property
    def probs(self): ...
    @property
    def dtype(self): ...
    @property
    def num_categories(self): ...
    @property
    def event_shape(self): ...
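Given logits, the category probabilities are softmax(logits) along the last axis, and the log-probability of category k is the k-th entry of log_softmax(logits). Adding the same constant to every logit therefore leaves the distribution unchanged. A plain-Python sketch for a single (unbatched) distribution, with hypothetical helper names; it illustrates the math rather than distrax's implementation, and samples by inverting the cumulative distribution.

```python
import math
import random
from typing import List, Sequence

def log_softmax(logits: Sequence[float]) -> List[float]:
    """Normalize logits into log-probabilities (log-sum-exp trick for stability)."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]

def categorical_log_prob(k: int, logits: Sequence[float]) -> float:
    """log p(k) for an integer category index k in [0, len(logits))."""
    return log_softmax(logits)[k]

def categorical_sample(logits: Sequence[float], rng: random.Random) -> int:
    """Draw a category index by inverting the cumulative distribution."""
    u = rng.random()
    cumulative = 0.0
    probs = [math.exp(lp) for lp in log_softmax(logits)]
    for k, p in enumerate(probs):
        cumulative += p
        if u < cumulative:
            return k
    return len(probs) - 1  # guard against floating-point round-off

logits = [0.0, math.log(3.0)]  # corresponds to probs of about [0.25, 0.75]
print([math.exp(lp) for lp in log_softmax(logits)])
print(categorical_sample(logits, random.Random(0)))
```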

One-Hot Categorical Distribution

Categorical distribution whose samples are one-hot vectors of length k instead of integer category indices.

class OneHotCategorical(Distribution):
    def __init__(self, logits=None, probs=None, dtype=float):
        """
        One-hot categorical distribution.
        
        Parameters:
        - logits: unnormalized log-probabilities (array of shape [..., k], mutually exclusive with probs)
        - probs: probabilities (array of shape [..., k] summing to 1 along the last axis, mutually exclusive with logits)
        - dtype: output data type (float or int)
        
        Note: Exactly one of logits or probs must be specified.
        """

    @property
    def logits(self): ...
    @property
    def probs(self): ...
    @property
    def dtype(self): ...
    @property
    def num_categories(self): ...
    @property
    def event_shape(self): ...
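A one-hot sample encodes the same category index a plain categorical sample would return, as a length-k vector with a single 1. Its log-probability is therefore the dot product of that vector with the log-probabilities. A plain-Python sketch for a single unbatched distribution; `one_hot` and `one_hot_log_prob` are hypothetical helpers, not distrax functions.

```python
import math
from typing import List, Sequence

def log_softmax(logits: Sequence[float]) -> List[float]:
    """Normalize logits into log-probabilities (log-sum-exp trick for stability)."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]

def one_hot(k: int, num_categories: int) -> List[float]:
    """Encode category index k as a length-num_categories one-hot vector."""
    return [1.0 if i == k else 0.0 for i in range(num_categories)]

def one_hot_log_prob(value: Sequence[float], logits: Sequence[float]) -> float:
    """log p(value) for a one-hot vector: sum_i value[i] * log_softmax(logits)[i]."""
    return sum(v * lp for v, lp in zip(value, log_softmax(logits)))

logits = [1.0, 2.0, 0.5]
x = one_hot(1, len(logits))
print(x)  # [0.0, 1.0, 0.0]
# Matches the log-probability of integer category 1
print(one_hot_log_prob(x, logits), log_softmax(logits)[1])
```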

Categorical Uniform Distribution

Uniform categorical distribution over all categories.

class CategoricalUniform(Distribution):
    def __init__(self, num_categories, dtype=int):
        """
        Uniform categorical distribution.
        
        Parameters:
        - num_categories: number of categories (positive integer)
        - dtype: output data type (int or float)
        """

    @property
    def num_categories(self): ...
    @property
    def dtype(self): ...
    @property
    def logits(self): ...
    @property
    def probs(self): ...

Multinomial Distribution

Multinomial distribution over the counts observed in each of k categories after total_count independent categorical trials.

class Multinomial(Distribution):
    def __init__(self, total_count, logits=None, probs=None, dtype=int):
        """
        Multinomial distribution.
        
        Parameters:
        - total_count: total number of trials (positive integer or array)
        - logits: log-probabilities for each category (array, mutually exclusive with probs)
        - probs: probabilities for each category (array summing to 1 along the last axis, mutually exclusive with logits)
        - dtype: output data type (int or float)
        
        Note: Exactly one of logits or probs must be specified.
        """

    @property
    def total_count(self): ...
    @property
    def logits(self): ...
    @property
    def probs(self): ...
    @property
    def event_shape(self): ...
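For counts x_1, ..., x_k that sum to total_count n, the multinomial log-probability is log n! - sum_i log x_i! + sum_i x_i log p_i. The plain-Python sketch below evaluates this with `math.lgamma`, using log Γ(m + 1) = log m!. `multinomial_log_prob` is a hypothetical helper for a single distribution. It assumes the probabilities are already normalized and takes n to be the sum of the counts.

```python
import math
from typing import Sequence

def multinomial_log_prob(counts: Sequence[int], probs: Sequence[float]) -> float:
    """log p(counts) for a multinomial with total_count = sum(counts)."""
    n = sum(counts)
    # log of the multinomial coefficient n! / (x_1! ... x_k!)
    log_coeff = math.lgamma(n + 1) - sum(math.lgamma(x + 1) for x in counts)
    # Categories with zero count contribute nothing, even where p_i == 0
    log_terms = sum(x * math.log(p) for x, p in zip(counts, probs) if x > 0)
    return log_coeff + log_terms

# Two equally likely categories, two trials: P(one of each) = 2 * 0.5 * 0.5
print(math.exp(multinomial_log_prob([1, 1], [0.5, 0.5])))
```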

Softmax Distribution

Categorical distribution whose probabilities are a softmax over temperature-scaled logits.

class Softmax(Distribution):
    def __init__(self, logits, temperature=1.0):
        """
        Softmax distribution.
        
        Parameters:
        - logits: unnormalized log-probabilities (array of shape [..., k])
        - temperature: temperature parameter for softmax (positive float, default 1.0)
        """

    @property
    def logits(self): ...
    @property
    def temperature(self): ...
    @property
    def probs(self): ...
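Assuming the usual convention probs = softmax(logits / temperature), temperatures below 1.0 concentrate probability on the largest logit and temperatures above 1.0 flatten the distribution toward uniform. The plain-Python sketch below illustrates that effect. `softmax_with_temperature` is a hypothetical helper, not a distrax function.

```python
import math
from typing import List, Sequence

def softmax_with_temperature(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    """Return softmax(logits / temperature); temperature must be positive."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.0]
for t in (0.5, 1.0, 5.0):
    # The largest probability shrinks as the temperature grows
    print(t, softmax_with_temperature(logits, t))
```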

Install with Tessl CLI

npx tessl i tessl/pypi-distrax
