
tessl/pypi-distrax

Distrax: Probability distributions in JAX.

Specialized Distributions

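Distributions for specific modeling tasks: action-selection policies for reinforcement learning, distributions whose samples are clipped to an interval, a point-mass (deterministic) distribution, and a wrapper that gives samples straight-through gradients.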

Capabilities

Reinforcement Learning Distributions

Epsilon-Greedy Distribution

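A categorical distribution that is epsilon-greedy with respect to a set of action preferences. It mixes a greedy distribution (weight 1 - epsilon) with a uniform distribution over all actions (weight epsilon). With N actions and k actions tied for the maximum preference, each greedy action has probability (1 - epsilon) / k + epsilon / N and every other action has probability epsilon / N.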

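```python
from distrax import Categorical


class EpsilonGreedy(Categorical):
    def __init__(self, preferences, epsilon, dtype=int):
        """
        Epsilon-greedy distribution over actions.

        Parameters:
        - preferences: unnormalized preference scores (array; the last axis indexes actions)
        - epsilon: exploration probability, i.e. the weight of the uniform component (float in [0, 1])
        - dtype: data type of sampled actions (default int)
        """

    @property
    def preferences(self): ...
    @property
    def epsilon(self): ...
    # The greedy action is available through the inherited mode() method.
```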
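
The following minimal JAX sketch computes epsilon-greedy action probabilities from preferences. It mirrors our reading of the distrax behavior (greedy mass split evenly across tied maxima, then mixed with a uniform distribution using weight epsilon); the helpers `greedy_probs` and `epsilon_greedy_probs` are hypothetical and are not part of the distrax API.

```python
import jax.numpy as jnp


def greedy_probs(preferences):
    # Hypothetical helper: 1.0 for every action tied for the maximum, 0.0 elsewhere,
    # normalized so that tied actions share the probability mass equally.
    is_max = (preferences == preferences.max(axis=-1, keepdims=True)).astype(jnp.float32)
    return is_max / is_max.sum(axis=-1, keepdims=True)


def epsilon_greedy_probs(preferences, epsilon):
    # Hypothetical helper: mix the greedy probabilities with a uniform
    # distribution, using weights (1 - epsilon) and epsilon respectively.
    num_actions = preferences.shape[-1]
    return (1.0 - epsilon) * greedy_probs(preferences) + epsilon / num_actions


preferences = jnp.array([1.0, 3.0, 3.0, 0.0])
probs = epsilon_greedy_probs(preferences, epsilon=0.2)
# Actions 1 and 2 tie for the maximum: each gets 0.8 * 0.5 + 0.2 / 4 = 0.45,
# and the two non-greedy actions each get 0.2 / 4 = 0.05.
```

With distrax installed, the library object would be built as `distrax.EpsilonGreedy(preferences, epsilon=0.2)`, and its `probs` should agree with this sketch.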

Greedy Distribution

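A categorical distribution that puts all of its probability mass on the action with the highest preference. When several actions tie for the maximum, the mass is split evenly among them.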

```python
from distrax import Categorical


class Greedy(Categorical):
    def __init__(self, preferences, dtype=int):
        """
        Greedy distribution.

        Parameters:
        - preferences: unnormalized preference scores (array; the last axis indexes actions)
        - dtype: data type of sampled actions (default int)
        """

    @property
    def preferences(self): ...
    # The greedy action is available through the inherited mode() method.
```
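
To show how ties behave when sampling, here is a small JAX sketch (not distrax code) of a greedy policy. It gives every maximal action logit 0 and every other action logit -inf. Assuming tied maxima share the mass evenly, samples land only on the tied actions and split roughly 50/50. `sample_greedy` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def sample_greedy(key, preferences, num_samples):
    # Logit 0 for every action tied for the maximum and -inf for all others,
    # so only maximal actions can be drawn, each with equal probability.
    is_max = preferences == preferences.max(axis=-1, keepdims=True)
    logits = jnp.where(is_max, 0.0, -jnp.inf)
    return jax.random.categorical(key, logits, shape=(num_samples,))


preferences = jnp.array([2.0, 5.0, 5.0, 1.0])
samples = sample_greedy(jax.random.PRNGKey(0), preferences, num_samples=4000)
# Fraction of draws that picked action 1 (one of the two tied maxima).
frac_action_1 = float(jnp.mean(samples == 1))
```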

Clipped Distributions

Base Clipped Distribution

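Wraps a distribution so that samples falling outside the interval between the lower and upper bounds are clipped to the nearest bound. Clipping turns each tail into a point mass. The log-probability at the lower bound is the base distribution's log-CDF there, and the log-probability at the upper bound is its log survival function there. Values strictly inside the interval keep the base log-density, and values outside it have log-probability -inf.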

```python
from distrax import Distribution


class Clipped(Distribution):
    def __init__(self, distribution, minimum, maximum):
        """
        Base clipped distribution.

        Parameters:
        - distribution: the Distrax or TFP distribution to wrap
        - minimum: lower clipping bound (scalar or array broadcastable to the batch shape)
        - maximum: upper clipping bound (scalar or array broadcastable to the batch shape)
        """

    @property
    def distribution(self): ...
    @property
    def minimum(self): ...
    @property
    def maximum(self): ...
```
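
Here is a minimal sketch of the clipped log-probability, using a normal base distribution for concreteness (the wrapper itself accepts other bases). It follows our understanding of the semantics: each bound carries the base distribution's tail mass as a point mass, interior points use the base log-density, and values outside the interval get -inf. The function name `clipped_normal_log_prob` and its arguments are hypothetical.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def clipped_normal_log_prob(x, loc, scale, minimum, maximum):
    # Point mass at the lower bound: P(base <= minimum), i.e. the log-CDF.
    log_at_min = norm.logcdf(minimum, loc, scale)
    # Point mass at the upper bound: P(base >= maximum), the log survival
    # function, obtained from the normal's symmetry about loc.
    log_at_max = norm.logcdf(2.0 * loc - maximum, loc, scale)
    # Strictly inside the interval, the base log-density applies unchanged.
    log_inside = norm.logpdf(x, loc, scale)
    log_prob = jnp.where(x == minimum, log_at_min,
                         jnp.where(x == maximum, log_at_max, log_inside))
    # Clipped samples can never fall outside [minimum, maximum].
    return jnp.where((x < minimum) | (x > maximum), -jnp.inf, log_prob)


x = jnp.array([-2.0, -1.0, 0.0, 1.0])
lp = clipped_normal_log_prob(x, loc=0.0, scale=1.0, minimum=-1.0, maximum=1.0)
```

At the two boundary points these values are log point masses rather than log densities, which is why they can exceed the interior log-density of nearby points.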

Clipped Normal Distribution

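A normal distribution whose samples are clipped to the interval between the lower and upper bounds; it is `Clipped` applied to a `Normal(loc, scale)`.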

```python
from distrax import Clipped


class ClippedNormal(Clipped):
    def __init__(self, loc, scale, minimum, maximum):
        """
        Clipped normal distribution: Clipped(Normal(loc, scale), minimum, maximum).

        Parameters:
        - loc: mean of the underlying normal (float or array)
        - scale: standard deviation of the underlying normal (float or array, must be positive)
        - minimum: lower clipping bound (float or array)
        - maximum: upper clipping bound (float or array)
        """

    # Inherits distribution, minimum and maximum from Clipped;
    # loc and scale are reachable via self.distribution.loc and self.distribution.scale.
```
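
Sampling can be sketched as drawing from the unclipped normal and clamping with `jnp.clip`. The fraction of draws that land exactly on a bound should then approach that bound's tail mass. This is an illustrative JAX sketch under that assumption, not the distrax implementation, and `sample_clipped_normal` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def sample_clipped_normal(key, loc, scale, minimum, maximum, num_samples):
    # Draw from the unclipped Normal(loc, scale) ...
    raw = loc + scale * jax.random.normal(key, (num_samples,))
    # ... then move every out-of-range draw onto the nearest bound.
    return jnp.clip(raw, minimum, maximum)


samples = sample_clipped_normal(jax.random.PRNGKey(0), loc=0.0, scale=1.0,
                                minimum=-0.5, maximum=0.5, num_samples=20000)
# For a standard normal, P(X >= 0.5) is about 0.3085, so roughly that
# fraction of the samples should sit exactly on the upper bound.
frac_at_max = float(jnp.mean(samples == 0.5))
```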

Clipped Logistic Distribution

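A logistic distribution whose samples are clipped to the interval between the lower and upper bounds; it is `Clipped` applied to a `Logistic(loc, scale)`.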

```python
from distrax import Clipped


class ClippedLogistic(Clipped):
    def __init__(self, loc, scale, minimum, maximum):
        """
        Clipped logistic distribution: Clipped(Logistic(loc, scale), minimum, maximum).

        Parameters:
        - loc: location of the underlying logistic (float or array)
        - scale: scale of the underlying logistic (float or array, must be positive)
        - minimum: lower clipping bound (float or array)
        - maximum: upper clipping bound (float or array)
        """

    # Inherits distribution, minimum and maximum from Clipped;
    # loc and scale are reachable via self.distribution.loc and self.distribution.scale.
```
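
For the logistic case the boundary masses have a closed form, because the logistic CDF is the sigmoid of the standardized value. The sketch below assumes that samples outside the interval are clamped to the nearest bound. It computes both boundary masses and checks them against a Monte Carlo estimate drawn with `jax.random.logistic`. The helper names are hypothetical.

```python
import jax
import jax.numpy as jnp


def clipped_logistic_boundary_mass(loc, scale, minimum, maximum):
    # Logistic CDF: sigmoid((x - loc) / scale).
    mass_at_min = jax.nn.sigmoid((minimum - loc) / scale)
    # Survival function at the upper bound: 1 - CDF = sigmoid(-(x - loc) / scale).
    mass_at_max = jax.nn.sigmoid(-(maximum - loc) / scale)
    return mass_at_min, mass_at_max


def sample_clipped_logistic(key, loc, scale, minimum, maximum, num_samples):
    # Standard logistic draws, shifted and scaled, then clamped to the bounds.
    raw = loc + scale * jax.random.logistic(key, (num_samples,))
    return jnp.clip(raw, minimum, maximum)


mass_min, mass_max = clipped_logistic_boundary_mass(0.0, 1.0, -1.0, 2.0)
samples = sample_clipped_logistic(jax.random.PRNGKey(1), 0.0, 1.0, -1.0, 2.0, 20000)
```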

Deterministic Distribution

A point-mass (Dirac delta) distribution: every sample equals `loc`, and the log-probability is 0 at `loc` and -inf everywhere else.

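```python
from distrax import Distribution


class Deterministic(Distribution):
    def __init__(self, loc, atol=None, rtol=None):
        """
        Deterministic distribution (Dirac delta).

        Parameters:
        - loc: the point that carries all of the probability mass (float or array; an array defines a batch)
        - atol: absolute tolerance for comparing values to loc (optional, treated as 0 if None)
        - rtol: relative tolerance for comparing values to loc (optional, treated as 0 if None)
        """

    @property
    def loc(self): ...
    @property
    def atol(self): ...
    @property
    def rtol(self): ...
    @property
    def event_shape(self): ...  # always (): each batch member is a scalar
```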
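
A point mass is simple enough to sketch directly in JAX. The helpers below are hypothetical (they are not distrax functions). Sampling broadcasts `loc` to the requested shape, and the log-probability is 0 where a value matches `loc` and -inf elsewhere. The optional `atol` argument is an illustrative assumption that loosens the exact-equality check.

```python
import jax.numpy as jnp


def deterministic_log_prob(x, loc, atol=0.0):
    # Log-probability of a point mass: 0 (probability 1) within atol of loc, -inf elsewhere.
    return jnp.where(jnp.abs(x - loc) <= atol, 0.0, -jnp.inf)


def deterministic_sample(loc, sample_shape):
    # Every draw is loc itself, broadcast to sample_shape + loc.shape.
    loc = jnp.asarray(loc)
    return jnp.broadcast_to(loc, tuple(sample_shape) + loc.shape)
```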

Straight-Through Wrapper

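Builds a subclass of a discrete distribution whose `sample` returns the same values as the original but passes gradients to the distribution's `probs` (the straight-through estimator). The typical use is with one-hot categorical distributions, where sampling is otherwise not differentiable.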

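```python
def straight_through_wrapper(distribution_cls):
    """
    Wraps a distribution class so that its samples use straight-through gradients.

    Parameters:
    - distribution_cls: the distribution class (not an instance) to wrap; it must expose a probs property

    Returns:
    A new distribution class whose samples have the same values as those of
    distribution_cls, but whose gradients flow through probs
    """
```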

Install with Tessl CLI

npx tessl i tessl/pypi-distrax
