# Distrax

Distrax is a lightweight, JAX-native library of probability distributions and bijectors. It reimplements a subset of TensorFlow Probability (TFP) with an emphasis on readability, extensibility, and cross-compatibility. The library provides a comprehensive set of probability distributions and bijectors (invertible functions with known Jacobian determinants) that can be combined to create complex distributions by transforming simpler ones.

## Package Information

- **Package Name**: distrax
- **Language**: Python
- **Installation**: `pip install distrax`

## Core Imports

```python
import distrax
```

Common patterns for distributions:

```python
from distrax import Normal, Bernoulli, Categorical
```

Common patterns for bijectors:

```python
from distrax import ScalarAffine, Chain, Sigmoid
```

## Basic Usage

```python
import distrax
import jax.numpy as jnp
import jax.random as random

# Create a simple distribution
key = random.PRNGKey(42)
key1, key2 = random.split(key)  # use a fresh key for each sampling call
dist = distrax.Normal(loc=0.0, scale=1.0)

# Sample from the distribution
samples = dist.sample(seed=key1, sample_shape=(100,))

# Compute log probabilities
log_probs = dist.log_prob(samples)

# Create a bijector for transformations
bijector = distrax.ScalarAffine(shift=2.0, scale=0.5)

# Transform values
x = jnp.array([1.0, 2.0, 3.0])
y = bijector.forward(x)
x_reconstructed = bijector.inverse(y)

# Create a transformed distribution
transformed_dist = distrax.Transformed(dist, bijector)
transformed_samples = transformed_dist.sample(seed=key2, sample_shape=(100,))
```

## Architecture

Distrax follows a clear architectural pattern built around two main abstractions:

- **Distribution**: Base class for probability distributions, providing sampling, density evaluation, and statistical properties
- **Bijector**: Base class for invertible functions with computable Jacobian determinants

This design enables:

- **Compositional flexibility**: Bijectors can be chained and combined with distributions
- **JAX integration**: Full compatibility with JAX transformations such as `jit`, `vmap`, and `grad` (see the sketch after this list)
- **TFP compatibility**: Seamless interoperability with TensorFlow Probability
- **Type safety**: Comprehensive type hints for a better development experience

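A minimal sketch of that JAX integration, using only `distrax.Normal` and standard JAX transformations; the printed shapes are what the standard broadcasting rules imply:

```python
import jax
import jax.numpy as jnp
import distrax

# A batch of three scalar Gaussians: batch_shape (3,), event_shape ().
dist = distrax.Normal(loc=jnp.array([0.0, 1.0, 2.0]), scale=1.0)

# jit: compile log-density evaluation.
fast_log_prob = jax.jit(dist.log_prob)
print(fast_log_prob(jnp.zeros(3)))  # shape (3,)

# grad: differentiate a log-density with respect to its input.
grad_log_prob = jax.grad(lambda x: distrax.Normal(0.0, 1.0).log_prob(x))
print(grad_log_prob(0.5))  # d/dx log N(x; 0, 1) = -x, so -0.5

# vmap: evaluate the same scalar log-density over a batch of inputs.
batched_log_prob = jax.vmap(lambda x: distrax.Normal(0.0, 1.0).log_prob(x))
print(batched_log_prob(jnp.linspace(-1.0, 1.0, 5)))  # shape (5,)
```
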
## Capabilities

### Continuous Distributions

Univariate and multivariate continuous probability distributions, including Normal, Beta, Gamma, Laplace, and multivariate normal variants with different covariance structures.

```python { .api }
class Normal(Distribution):
    def __init__(self, loc, scale): ...

class Beta(Distribution):
    def __init__(self, concentration1, concentration0): ...

class MultivariateNormalDiag(Distribution):
    def __init__(self, loc, scale_diag): ...
```

[Continuous Distributions](./continuous-distributions.md)

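A short sketch using the `MultivariateNormalDiag` constructor shown above; the shape comments assume the sampling and shape semantics described in the Types section:

```python
import jax
import jax.numpy as jnp
import distrax

key = jax.random.PRNGKey(0)

# Diagonal-covariance Gaussian over a 2-dimensional event.
mvn = distrax.MultivariateNormalDiag(
    loc=jnp.array([0.0, 1.0]),
    scale_diag=jnp.array([1.0, 0.5]))

x = mvn.sample(seed=key, sample_shape=(4,))  # shape (4, 2)
print(mvn.log_prob(x))                       # shape (4,)
print(mvn.event_shape, mvn.batch_shape)      # (2,) and ()
```
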
### Discrete Distributions

Discrete probability distributions for categorical and binary outcomes, including Bernoulli, Categorical, and Multinomial distributions with various parameterizations.

```python { .api }
class Bernoulli(Distribution):
    def __init__(self, logits=None, probs=None, dtype=int): ...

class Categorical(Distribution):
    def __init__(self, logits=None, probs=None, dtype=int): ...

class OneHotCategorical(Distribution):
    def __init__(self, logits=None, probs=None, dtype=float): ...
```

[Discrete Distributions](./discrete-distributions.md)

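A minimal sketch with `Categorical`, using only the constructor above and the base-class methods (`sample`, `log_prob`, `entropy`) documented later in this file:

```python
import jax
import jax.numpy as jnp
import distrax

key = jax.random.PRNGKey(1)

# Categorical over three classes, parameterized by unnormalized logits.
cat = distrax.Categorical(logits=jnp.array([0.5, 1.0, -1.0]))

draws = cat.sample(seed=key, sample_shape=(5,))  # integer class indices, shape (5,)
print(cat.log_prob(draws))                       # shape (5,)
print(cat.entropy())                             # scalar entropy of the distribution
```
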
### Bijectors

Invertible transformations with known Jacobian determinants for creating complex distributions through composition, including affine transformations, normalizing flows, and neural network layers.

```python { .api }
class Bijector:
    def forward(self, x): ...
    def inverse(self, y): ...
    def forward_and_log_det(self, x): ...

class ScalarAffine(Bijector):
    def __init__(self, shift, scale=None, log_scale=None): ...

class Chain(Bijector):
    def __init__(self, bijectors): ...
```

[Bijectors](./bijectors.md)

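A sketch of composition with `Chain`, `ScalarAffine`, and `Sigmoid` (all exported by distrax); it assumes `Chain` applies its bijectors right to left, as in TFP:

```python
import jax.numpy as jnp
import distrax

# Sigmoid is applied first, then ScalarAffine, mapping the real line
# into the interval (2.0, 2.5).
bijector = distrax.Chain([
    distrax.ScalarAffine(shift=2.0, scale=0.5),
    distrax.Sigmoid(),
])

x = jnp.array([-1.0, 0.0, 1.0])
y, log_det = bijector.forward_and_log_det(x)  # forward values and log|det J|
x_back = bijector.inverse(y)                  # recovers x up to numerical error

# The same bijector can reshape a base distribution into a new one.
squashed = distrax.Transformed(distrax.Normal(loc=0.0, scale=1.0), bijector)
```
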
### Mixture and Composite Distributions

Complex distributions created by combining simpler components, including mixture models, transformed distributions, and joint distributions for multi-component modeling.

```python { .api }
class Transformed(Distribution):
    def __init__(self, distribution, bijector): ...

class MixtureSameFamily(Distribution):
    def __init__(self, mixture_distribution, components_distribution): ...

class Independent(Distribution):
    def __init__(self, distribution, reinterpreted_batch_ndims): ...
```

[Mixture and Composite Distributions](./mixture-composite.md)

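A sketch of a Gaussian mixture built from the constructors above; it assumes the usual `MixtureSameFamily` convention that the number of mixture categories matches the last batch dimension of the components:

```python
import jax
import jax.numpy as jnp
import distrax

key = jax.random.PRNGKey(2)

# A 3-component Gaussian mixture: weights from a Categorical,
# component parameters batched along the last axis.
mixture = distrax.MixtureSameFamily(
    mixture_distribution=distrax.Categorical(probs=jnp.array([0.2, 0.5, 0.3])),
    components_distribution=distrax.Normal(
        loc=jnp.array([-2.0, 0.0, 3.0]),
        scale=jnp.array([0.5, 1.0, 0.8])))

samples = mixture.sample(seed=key, sample_shape=(1000,))  # shape (1000,)
print(mixture.log_prob(samples))                          # shape (1000,)
```
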
### Specialized Distributions

Task-specific distributions, including exploration policies for reinforcement learning, clipped distributions, and deterministic distributions for specialized modeling needs.

```python { .api }
class EpsilonGreedy(Distribution):
    def __init__(self, preferences, epsilon): ...

class ClippedNormal(Distribution):
    def __init__(self, loc, scale, low, high): ...

class Deterministic(Distribution):
    def __init__(self, loc): ...
```

[Specialized Distributions](./specialized-distributions.md)

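A minimal sketch of epsilon-greedy action selection using the `EpsilonGreedy` constructor shown above together with the base-class `sample` and `log_prob`; the interpretation of `preferences` and `epsilon` follows the usual epsilon-greedy convention:

```python
import jax
import jax.numpy as jnp
import distrax

key = jax.random.PRNGKey(3)

# With probability 1 - epsilon pick the highest-preference action,
# otherwise pick an action uniformly at random.
policy = distrax.EpsilonGreedy(
    preferences=jnp.array([1.0, 3.0, 0.5, 2.0]), epsilon=0.1)

actions = policy.sample(seed=key, sample_shape=(10,))  # integer actions, shape (10,)
print(policy.log_prob(actions))                        # shape (10,)
```
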
### Utilities

Helper functions for distribution conversion, Monte Carlo estimation, mathematical operations, and Hidden Markov Models for advanced probabilistic modeling.

```python { .api }
def as_distribution(obj: DistributionLike) -> Distribution: ...
def as_bijector(obj: BijectorLike) -> Bijector: ...
def to_tfp(obj, name=None): ...

class HMM:
    def __init__(self, init_dist, trans_dist, obs_dist): ...
```

[Utilities](./utilities.md)

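A sketch of the conversion helpers. It assumes TFP's JAX substrate is installed alongside distrax, and that a plain callable can be promoted to a bijector as suggested by the `BijectorLike` alias in the Type Aliases section:

```python
import jax.numpy as jnp
import distrax
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions

# Wrap a TFP distribution so it can be used wherever Distrax expects one.
dist = distrax.as_distribution(tfd.Normal(loc=0.0, scale=1.0))

# Convert a Distrax object into its TFP-compatible counterpart.
tfp_normal = distrax.to_tfp(distrax.Normal(loc=0.0, scale=1.0))

# Promote a plain callable to a bijector (forward use only shown here).
affine = distrax.as_bijector(lambda x: 2.0 * x + 1.0)
print(affine.forward(jnp.array([0.0, 1.0, 2.0])))  # [1.0, 3.0, 5.0]
```
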
## Types

### Base Classes

```python { .api }
class Distribution:
    """
    Abstract base class for probability distributions.

    Provides a common interface for sampling, density evaluation, and statistical properties.
    All distributions must implement the log_prob() and _sample_n() methods.
    """

    def sample(self, *, seed, sample_shape=()): ...
    def sample_and_log_prob(self, *, seed, sample_shape=()): ...
    def log_prob(self, value): ...
    def prob(self, value): ...
    def entropy(self): ...
    def mean(self): ...
    def variance(self): ...
    def cdf(self, value): ...
    def __getitem__(self, index): ...

    @property
    def event_shape(self): ...
    @property
    def batch_shape(self): ...
    @property
    def dtype(self): ...

class Bijector:
    """
    Abstract base class for invertible transformations with known Jacobian determinants.

    All bijectors must implement the forward_and_log_det() method.
    """

    def forward(self, x): ...
    def inverse(self, y): ...
    def forward_and_log_det(self, x): ...
    def inverse_and_log_det(self, y): ...

    @property
    def event_ndims_in(self): ...
    @property
    def event_ndims_out(self): ...
```

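A short sketch of the shape semantics behind this interface, using `Normal`, `Independent`, and `sample_and_log_prob` from the APIs documented above; the shape comments follow the standard batch/event conventions:

```python
import jax
import jax.numpy as jnp
import distrax

key = jax.random.PRNGKey(4)

# A (2, 3) batch of scalar Gaussians.
base = distrax.Normal(loc=jnp.zeros((2, 3)), scale=jnp.ones((2, 3)))
print(base.batch_shape, base.event_shape)    # (2, 3) and ()

# Independent folds the trailing batch dimension into the event.
joint = distrax.Independent(base, reinterpreted_batch_ndims=1)
print(joint.batch_shape, joint.event_shape)  # (2,) and (3,)

# Draw samples and evaluate their log-density in a single call.
x, log_p = joint.sample_and_log_prob(seed=key, sample_shape=(5,))
print(x.shape, log_p.shape)                  # (5, 2, 3) and (5, 2)
```
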
### Type Aliases

```python { .api }
from typing import Union, Callable
from chex import Array

DistributionLike = Union[Distribution, 'tfd.Distribution']
BijectorLike = Union[Bijector, 'tfb.Bijector', Callable[[Array], Array]]
```