Distrax: Probability distributions in JAX.
—
Discrete probability distributions for modeling binary and categorical outcomes. Most are parameterized by either logits or probabilities, with one-hot, uniform, multinomial, and softmax variants.
Bernoulli distribution for modeling binary outcomes.
class Bernoulli(Distribution):
    """Bernoulli distribution for binary outcomes.

    Signature reference only: method and property bodies are elided.
    """

    def __init__(self, logits=None, probs=None, dtype=int):
        """Create a Bernoulli distribution.

        Args:
            logits: Log-odds parameter (float or array). Mutually exclusive
                with ``probs``.
            probs: Probability parameter (float or array with values in
                [0, 1]). Mutually exclusive with ``logits``.
            dtype: Output data type (int or float).

        Note:
            Exactly one of ``logits`` or ``probs`` must be specified.
        """

    @property
    def logits(self): ...

    @property
    def probs(self): ...

    @property
    def dtype(self): ...

    @property
    def event_shape(self): ...


# Categorical distribution for discrete outcomes with multiple categories.
class Categorical(Distribution):
    """Categorical distribution for discrete outcomes over ``k`` categories.

    Signature reference only: method and property bodies are elided.
    """

    def __init__(self, logits=None, probs=None, dtype=int):
        """Create a categorical distribution.

        Args:
            logits: Log-probabilities, an array of shape ``[..., k]``.
                Mutually exclusive with ``probs``.
            probs: Probabilities, an array of shape ``[..., k]`` that sums
                to 1. Mutually exclusive with ``logits``.
            dtype: Output data type (int or float).

        Note:
            Exactly one of ``logits`` or ``probs`` must be specified.
        """

    @property
    def logits(self): ...

    @property
    def probs(self): ...

    @property
    def dtype(self): ...

    @property
    def num_categories(self): ...

    @property
    def event_shape(self): ...


# Categorical distribution with one-hot encoded outcomes.
class OneHotCategorical(Distribution):
    """Categorical distribution whose outcomes are one-hot encoded.

    Signature reference only: method and property bodies are elided.
    """

    def __init__(self, logits=None, probs=None, dtype=float):
        """Create a one-hot categorical distribution.

        Args:
            logits: Log-probabilities, an array of shape ``[..., k]``.
                Mutually exclusive with ``probs``.
            probs: Probabilities, an array of shape ``[..., k]`` that sums
                to 1. Mutually exclusive with ``logits``.
            dtype: Output data type (float or int). Note the default here
                is ``float``, unlike ``Categorical``.

        Note:
            Exactly one of ``logits`` or ``probs`` must be specified.
        """

    @property
    def logits(self): ...

    @property
    def probs(self): ...

    @property
    def dtype(self): ...

    @property
    def num_categories(self): ...

    @property
    def event_shape(self): ...


# Uniform categorical distribution over all categories.
class CategoricalUniform(Distribution):
    """Categorical distribution with equal probability for every category.

    Signature reference only: method and property bodies are elided.
    """

    def __init__(self, num_categories, dtype=int):
        """Create a uniform categorical distribution.

        Args:
            num_categories: Number of categories (positive integer).
            dtype: Output data type (int or float).
        """

    @property
    def num_categories(self): ...

    @property
    def dtype(self): ...

    @property
    def logits(self): ...

    @property
    def probs(self): ...


# Multinomial distribution for modeling counts across multiple categories.
class Multinomial(Distribution):
    """Multinomial distribution over per-category counts.

    Signature reference only: method and property bodies are elided.
    """

    def __init__(self, total_count, logits=None, probs=None, dtype=int):
        """Create a multinomial distribution.

        Args:
            total_count: Total number of trials (positive integer or array).
            logits: Log-probabilities for each category (array). Mutually
                exclusive with ``probs``.
            probs: Probabilities for each category (array that sums to 1).
                Mutually exclusive with ``logits``.
            dtype: Output data type (int or float).

        Note:
            Exactly one of ``logits`` or ``probs`` must be specified.
        """

    @property
    def total_count(self): ...

    @property
    def logits(self): ...

    @property
    def probs(self): ...

    @property
    def event_shape(self): ...


# Softmax distribution for normalized discrete outcomes.
class Softmax(Distribution):
    """Discrete distribution given by a temperature-scaled softmax of logits.

    Signature reference only: method and property bodies are elided.
    """

    def __init__(self, logits, temperature=1.0):
        """Create a softmax distribution.

        Args:
            logits: Unnormalized log-probabilities, an array of shape
                ``[..., k]``.
            temperature: Temperature parameter for the softmax (positive
                float, default 1.0).
        """

    @property
    def logits(self): ...

    @property
    def temperature(self): ...

    @property
    def probs(self): ...


# Install with Tessl CLI
npx tessl i tessl/pypi-distrax