Distrax: Probability distributions in JAX.
—
Task-specific distributions for reinforcement learning, clipped distributions, and deterministic distributions for specialized modeling needs.
Epsilon-greedy distribution for exploration in reinforcement learning.
class EpsilonGreedy(Distribution):
def __init__(self, preferences, epsilon):
"""
Epsilon-greedy distribution.
Parameters:
- preferences: preference scores for actions (array)
- epsilon: exploration probability (float in [0, 1])
"""
@property
def preferences(self): ...
@property
def epsilon(self): ...
@property
def most_likely_action(self): ...

Greedy distribution that always selects the highest-scoring action.
class Greedy(Distribution):
def __init__(self, preferences, dtype=int):
"""
Greedy distribution.
Parameters:
- preferences: preference scores for actions (array)
- dtype: output data type (int or float)
"""
@property
def preferences(self): ...
@property
def most_likely_action(self): ...

Base class for distributions with clipped support.
class Clipped(Distribution):
def __init__(self, distribution, low, high):
"""
Base clipped distribution.
Parameters:
- distribution: base distribution to clip
- low: lower clipping bound (float or array)
- high: upper clipping bound (float or array)
"""
@property
def distribution(self): ...
@property
def low(self): ...
@property
def high(self): ...

Normal distribution with clipped support.
class ClippedNormal(Distribution):
def __init__(self, loc, scale, low, high):
"""
Clipped normal distribution.
Parameters:
- loc: mean parameter (float or array)
- scale: standard deviation parameter (float or array, must be positive)
- low: lower clipping bound (float or array)
- high: upper clipping bound (float or array)
"""
@property
def loc(self): ...
@property
def scale(self): ...
@property
def low(self): ...
@property
def high(self): ...

Logistic distribution with clipped support.
class ClippedLogistic(Distribution):
def __init__(self, loc, scale, low, high):
"""
Clipped logistic distribution.
Parameters:
- loc: location parameter (float or array)
- scale: scale parameter (float or array, must be positive)
- low: lower clipping bound (float or array)
- high: upper clipping bound (float or array)
"""
@property
def loc(self): ...
@property
def scale(self): ...
@property
def low(self): ...
@property
def high(self): ...

Distribution that always returns the same value.
class Deterministic(Distribution):
def __init__(self, loc):
"""
Deterministic distribution (Dirac delta).
Parameters:
- loc: deterministic value (float or array)
"""
@property
def loc(self): ...
@property
def event_shape(self): ...

Wrapper that uses straight-through gradients for samples.
def straight_through_wrapper(distribution_cls):
"""
Wraps a distribution to use straight-through gradients for samples.
Parameters:
- distribution_cls: distribution class to wrap
Returns:
Wrapped distribution class with straight-through gradients
"""

Install with Tessl CLI
npx tessl i tessl/pypi-distrax