# Distrax

Distrax is a lightweight, JAX-native library of probability distributions and bijectors. It reimplements a subset of TensorFlow Probability (TFP) with an emphasis on readability, extensibility, and cross-compatibility. The library provides probability distributions and bijectors (invertible functions with known Jacobian determinants) that can be combined to build complex distributions by transforming simpler ones.

## Package Information

- **Package Name**: distrax
- **Language**: Python
- **Installation**: `pip install distrax`

## Core Imports

```python
import distrax
```

Common patterns for distributions:

```python
from distrax import Normal, Bernoulli, Categorical
```

Common patterns for bijectors:

```python
from distrax import ScalarAffine, Chain, Sigmoid
```

## Basic Usage

```python
import distrax
import jax.numpy as jnp
import jax.random as random

# Create a simple distribution
key = random.PRNGKey(42)
dist = distrax.Normal(loc=0.0, scale=1.0)

# Split the key so that each sampling call uses independent randomness
sample_key, transformed_key = random.split(key)

# Sample from the distribution
samples = dist.sample(seed=sample_key, sample_shape=(100,))

# Compute log probabilities
log_probs = dist.log_prob(samples)

# Create a bijector for transformations
bijector = distrax.ScalarAffine(shift=2.0, scale=0.5)

# Transform values
x = jnp.array([1.0, 2.0, 3.0])
y = bijector.forward(x)
x_reconstructed = bijector.inverse(y)

# Create transformed distributions
transformed_dist = distrax.Transformed(dist, bijector)
transformed_samples = transformed_dist.sample(seed=transformed_key, sample_shape=(100,))
```

## Architecture

Distrax follows a clear architectural pattern based on two main abstractions:

- **Distribution**: Base class for probability distributions providing sampling, density evaluation, and statistical properties
- **Bijector**: Base class for invertible functions with computable Jacobian determinants

This design enables:

- **Compositional flexibility**: Bijectors can be chained together and applied to distributions
- **JAX integration**: Compatible with JAX transformations (`jit`, `vmap`, `grad`)
- **TFP compatibility**: Interoperability with TensorFlow Probability distributions and bijectors
- **Type safety**: Type hints throughout the API for a better development experience

## Capabilities

### Continuous Distributions

Univariate and multivariate continuous probability distributions including Normal, Beta, Gamma, Laplace, and multivariate normal variants with different covariance structures.

```python { .api }
class Normal(Distribution):
    def __init__(self, loc, scale): ...

class Beta(Distribution):
    def __init__(self, alpha, beta): ...

class MultivariateNormalDiag(Distribution):
    def __init__(self, loc, scale_diag): ...
```

[Continuous Distributions](./continuous-distributions.md)

### Discrete Distributions

Discrete probability distributions for categorical and binary outcomes, including Bernoulli, Categorical, and Multinomial distributions with various parameterizations.

```python { .api }
class Bernoulli(Distribution):
    def __init__(self, logits=None, probs=None, dtype=int): ...

class Categorical(Distribution):
    def __init__(self, logits=None, probs=None, dtype=int): ...

class OneHotCategorical(Distribution):
    def __init__(self, logits=None, probs=None, dtype=float): ...
```

[Discrete Distributions](./discrete-distributions.md)

### Bijectors

Invertible transformations with known Jacobian determinants for creating complex distributions through composition, including affine transformations, normalizing flows, and neural network layers.

```python { .api }
class Bijector:
    def forward(self, x): ...
    def inverse(self, y): ...
    def forward_and_log_det(self, x): ...

class ScalarAffine(Bijector):
    def __init__(self, shift, scale=None, log_scale=None): ...

class Chain(Bijector):
    def __init__(self, bijectors): ...
```

[Bijectors](./bijectors.md)

### Mixture and Composite Distributions

Complex distributions created by combining simpler components, including mixture models, transformed distributions, and joint distributions for multi-component modeling.

```python { .api }
class Transformed(Distribution):
    def __init__(self, distribution, bijector): ...

class MixtureSameFamily(Distribution):
    def __init__(self, mixture_distribution, components_distribution): ...

class Independent(Distribution):
    def __init__(self, distribution, reinterpreted_batch_ndims): ...
```

[Mixture and Composite Distributions](./mixture-composite.md)

### Specialized Distributions

Task-specific distributions for reinforcement learning, clipped distributions, and deterministic distributions for specialized modeling needs.

```python { .api }
class EpsilonGreedy(Distribution):
    def __init__(self, preferences, epsilon): ...

class ClippedNormal(Distribution):
    def __init__(self, loc, scale, minimum, maximum): ...

class Deterministic(Distribution):
    def __init__(self, loc): ...
```

[Specialized Distributions](./specialized-distributions.md)

### Utilities

Helper functions for distribution conversion, Monte Carlo estimation, mathematical operations, and Hidden Markov Models for advanced probabilistic modeling.

```python { .api }
def as_distribution(obj: DistributionLike) -> Distribution: ...
def as_bijector(obj: BijectorLike) -> Bijector: ...
def to_tfp(obj, name=None): ...

class HMM:
    def __init__(self, init_dist, trans_dist, obs_dist): ...
```

[Utilities](./utilities.md)

## Types

### Base Classes

```python { .api }
class Distribution:
    """
    Abstract base class for probability distributions.

    Provides common interface for sampling, density evaluation, and statistical properties.
    All distributions must implement log_prob() and _sample_n() methods.
    """

    def sample(self, *, seed, sample_shape=()): ...
    def sample_and_log_prob(self, *, seed, sample_shape=()): ...
    def log_prob(self, value): ...
    def prob(self, value): ...
    def entropy(self): ...
    def mean(self): ...
    def variance(self): ...
    def cdf(self, value): ...
    def __getitem__(self, index): ...

    @property
    def event_shape(self): ...
    @property
    def batch_shape(self): ...
    @property
    def dtype(self): ...

class Bijector:
    """
    Abstract base class for invertible transformations with known Jacobian determinants.

    All bijectors must implement forward_and_log_det() method.
    """

    def forward(self, x): ...
    def inverse(self, y): ...
    def forward_and_log_det(self, x): ...
    def inverse_and_log_det(self, y): ...

    @property
    def event_ndims_in(self): ...
    @property
    def event_ndims_out(self): ...
```

### Type Aliases

```python { .api }
from typing import Union, Callable
from chex import Array

DistributionLike = Union[Distribution, 'tfd.Distribution']
BijectorLike = Union[Bijector, 'tfb.Bijector', Callable[[Array], Array]]
```