# Utilities

Helper functions for distribution conversion, Monte Carlo estimation, and mathematical operations, plus a Hidden Markov Model implementation for sequential probabilistic modeling.

## Capabilities

### Conversion Utilities

#### Convert to Distrax Distribution

Converts distribution-like objects to Distrax distributions.

```python { .api }
def as_distribution(obj):
    """
    Convert distribution-like object to Distrax distribution.

    Parameters:
    - obj: DistributionLike object (Distrax or TFP distribution)

    Returns:
    Distrax Distribution
    """
```

#### Convert to Distrax Bijector

Converts bijector-like objects to Distrax bijectors.

```python { .api }
def as_bijector(obj):
    """
    Convert bijector-like object to Distrax bijector.

    Parameters:
    - obj: BijectorLike object (Distrax bijector, TFP bijector, or callable)

    Returns:
    Distrax Bijector
    """
```
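
For example, a TFP bijector or a plain callable can be promoted to a Distrax bijector (a minimal sketch based on the accepted input types listed above):

```python
import distrax
import jax.numpy as jnp
import tensorflow_probability.substrates.jax as tfp

# Wrap a TFP bijector as a Distrax bijector
distrax_tanh = distrax.as_bijector(tfp.bijectors.Tanh())
y = distrax_tanh.forward(jnp.array([0.0, 1.0]))

# A plain callable is also accepted and wrapped as a bijector
scale_by_two = distrax.as_bijector(lambda x: 2.0 * x)
```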

#### Convert to TensorFlow Probability

Converts Distrax objects to TFP-compatible equivalents.

```python { .api }
def to_tfp(obj, name=None):
    """
    Convert Distrax object to TFP-compatible equivalent.

    Parameters:
    - obj: Distrax distribution or bijector
    - name: optional name for the TFP object

    Returns:
    TFP-compatible distribution or bijector
    """
```

### Mathematical Utilities

#### Multiply with No NaN

Element-wise multiplication that returns 0 wherever the second argument is 0, even if the first operand is infinite or NaN.

```python { .api }
def multiply_no_nan(x, y):
    """
    Element-wise multiplication that returns 0 where y is 0, even if x is
    infinite or NaN.

    Parameters:
    - x: first operand (array)
    - y: second operand (array)

    Returns:
    Element-wise product with NaN-safe handling
    """
```
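
This is the standard trick for computing quantities like entropy, where a zero probability should null out an infinite log term. A minimal sketch (the import path is an assumption; adjust it to wherever the helper is exposed in your version of the library):

```python
import jax.numpy as jnp
# Import path is an assumption; adjust to your version of the library.
from distrax._src.utils.math import multiply_no_nan

probs = jnp.array([0.5, 0.5, 0.0])
log_probs = jnp.log(probs)  # log(0) = -inf

# The 0 * -inf term contributes 0 instead of NaN, giving entropy = log(2).
entropy = -jnp.sum(multiply_no_nan(log_probs, probs))
```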

### Monte Carlo Estimation

#### Best-Effort KL Divergence Estimation

Estimates the KL divergence exactly when a closed-form implementation is available, falling back to Monte Carlo estimation otherwise.

```python { .api }
def estimate_kl_best_effort(distribution_a, distribution_b, rng_key, num_samples, proposal_distribution=None):
    """
    Estimate KL divergence using the best available method.

    Parameters:
    - distribution_a: first distribution
    - distribution_b: second distribution
    - rng_key: JAX random key
    - num_samples: number of Monte Carlo samples
    - proposal_distribution: optional proposal distribution for importance sampling

    Returns:
    KL divergence estimate
    """
```
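
A minimal sketch, assuming the function is exported at the package top level like the conversion utilities:

```python
import distrax
import jax.random as random

key = random.PRNGKey(0)
p = distrax.Normal(0.0, 1.0)
q = distrax.Normal(1.0, 2.0)

# Normal-vs-Normal has a registered closed form, so this should return
# the exact KL; the Monte Carlo arguments only matter when it does not.
kl = distrax.estimate_kl_best_effort(p, q, key, num_samples=1000)
```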

#### Monte Carlo KL Divergence Estimation

Monte Carlo estimation of the KL divergence using the DiCE estimator, which keeps the estimate differentiable with respect to the distributions' parameters.

```python { .api }
def mc_estimate_kl(distribution_a, distribution_b, rng_key, num_samples, proposal_distribution=None):
    """
    Monte Carlo estimation of KL divergence.

    Parameters:
    - distribution_a: first distribution
    - distribution_b: second distribution
    - rng_key: JAX random key
    - num_samples: number of Monte Carlo samples
    - proposal_distribution: optional proposal distribution for importance sampling

    Returns:
    KL divergence estimate
    """
```

#### Monte Carlo KL with Reparameterized Distributions

Monte Carlo KL estimation with reparameterized distributions.

```python { .api }
def mc_estimate_kl_with_reparameterized(distribution_a, distribution_b, rng_key, num_samples):
    """
    Monte Carlo KL estimation with reparameterized distributions.

    Parameters:
    - distribution_a: first distribution (must be reparameterizable)
    - distribution_b: second distribution
    - rng_key: JAX random key
    - num_samples: number of Monte Carlo samples

    Returns:
    KL divergence estimate
    """
```
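
When the first distribution supports reparameterized sampling (as Normal does), gradients can flow through the samples themselves. A minimal sketch, under the same top-level-export assumption as above:

```python
import distrax
import jax.random as random

key = random.PRNGKey(0)
p = distrax.Normal(0.0, 1.0)
q = distrax.Normal(0.5, 1.2)

# Valid here because Normal is reparameterizable.
kl = distrax.mc_estimate_kl_with_reparameterized(p, q, key, num_samples=1000)
```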

#### Monte Carlo Mode Estimation

Monte Carlo estimation of a distribution's mode, obtained by drawing samples and keeping the one with the highest log probability.

```python { .api }
def mc_estimate_mode(distribution, rng_key, num_samples):
    """
    Monte Carlo estimation of distribution mode.

    Parameters:
    - distribution: distribution whose mode to estimate
    - rng_key: JAX random key
    - num_samples: number of Monte Carlo samples

    Returns:
    Mode estimate
    """
```
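
A minimal sketch (top-level export assumed, as above):

```python
import distrax
import jax.random as random

key = random.PRNGKey(0)
dist = distrax.Normal(loc=2.0, scale=1.0)

# The estimate approaches the true mode (2.0) as num_samples grows.
mode = distrax.mc_estimate_mode(dist, key, num_samples=1000)
```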

### Importance Sampling

#### Importance Sampling Ratios

Computes importance sampling ratios: the probability of each event under the target distribution divided by its probability under the sampling distribution.

```python { .api }
def importance_sampling_ratios(target_dist, sampling_dist, event):
    """
    Compute importance sampling ratios.

    Parameters:
    - target_dist: target distribution
    - sampling_dist: sampling distribution
    - event: sampled events (array)

    Returns:
    Importance sampling ratios
    """
```
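
A minimal sketch of reweighting samples drawn from a proposal so that expectations are taken as if under the target:

```python
import distrax
import jax.random as random

key = random.PRNGKey(0)
target = distrax.Normal(0.0, 1.0)
proposal = distrax.Normal(0.5, 1.5)

# Draw from the proposal, then reweight as if sampling from the target.
samples = proposal.sample(seed=key, sample_shape=(100,))
ratios = distrax.importance_sampling_ratios(target, proposal, samples)
weighted_mean = (ratios * samples).mean()  # estimates the target's mean (0.0)
```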

### Transformation Utilities

#### Register Inverse Functions

Registers inverse functions for JAX primitives, so that functions built from those primitives can be inverted automatically.

```python { .api }
def register_inverse(primitive, inverse_left, inverse_right=None):
    """
    Register inverse functions for JAX primitives.

    Parameters:
    - primitive: JAX primitive to register inverse for
    - inverse_left: left inverse function
    - inverse_right: optional right inverse function
    """
```
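
A sketch of how registration might look for a unary primitive, using `jax.lax.exp_p` purely as an illustration. Common primitives like this are typically pre-registered already, and the top-level export of `register_inverse` is an assumption:

```python
import distrax
import jax
import jax.numpy as jnp

# Declare log as the (left) inverse of the exp primitive.
distrax.register_inverse(jax.lax.exp_p, jnp.log)

# With the inverse registered, a Lambda bijector built from jnp.exp
# can invert itself by tracing the forward function.
bij = distrax.Lambda(jnp.exp)
x = bij.inverse(jnp.array([1.0, 2.718281828]))  # approximately [0., 1.]
```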

### Hidden Markov Models

#### HMM Class

Hidden Markov Model implementation for sequential modeling.

```python { .api }
class HMM:
    def __init__(self, init_dist, trans_dist, obs_dist):
        """
        Hidden Markov Model.

        Parameters:
        - init_dist: initial state distribution
        - trans_dist: transition distribution
        - obs_dist: observation distribution
        """

    def sample(self, *, seed, seq_len):
        """
        Sample a sequence from the HMM.

        Parameters:
        - seed: JAX random key
        - seq_len: length of sequence to sample

        Returns:
        Tuple of (states, observations)
        """

    def forward(self, obs_seq, length=None):
        """
        Forward algorithm for computing marginal likelihood.

        Parameters:
        - obs_seq: sequence of observations (array)
        - length: optional sequence length (for batched sequences)

        Returns:
        Forward probabilities and log marginal likelihood
        """

    def backward(self, obs_seq, length=None):
        """
        Backward algorithm for computing backward probabilities.

        Parameters:
        - obs_seq: sequence of observations (array)
        - length: optional sequence length (for batched sequences)

        Returns:
        Backward probabilities
        """

    def forward_backward(self, obs_seq, length=None):
        """
        Forward-backward algorithm for state posterior probabilities.

        Parameters:
        - obs_seq: sequence of observations (array)
        - length: optional sequence length (for batched sequences)

        Returns:
        State posterior probabilities and log marginal likelihood
        """

    def viterbi(self, obs_seq):
        """
        Viterbi algorithm for most likely state sequence.

        Parameters:
        - obs_seq: sequence of observations (array)

        Returns:
        Most likely state sequence and its log probability
        """

    @property
    def init_dist(self): ...

    @property
    def trans_dist(self): ...

    @property
    def obs_dist(self): ...

    @property
    def event_shape(self): ...

    @property
    def batch_shape(self): ...
```

## Usage Examples

### Converting Between Libraries

```python
import distrax
import tensorflow_probability.substrates.jax as tfp

# Convert TFP distribution to Distrax
tfp_normal = tfp.distributions.Normal(0.0, 1.0)
distrax_normal = distrax.as_distribution(tfp_normal)

# Convert Distrax distribution to TFP
distrax_normal = distrax.Normal(0.0, 1.0)
tfp_normal = distrax.to_tfp(distrax_normal)
```

### Monte Carlo KL Estimation

```python
import distrax
import jax.random as random

key = random.PRNGKey(42)
p = distrax.Normal(0.0, 1.0)
q = distrax.Normal(0.5, 1.2)

# Estimate KL divergence
kl_estimate = distrax.mc_estimate_kl(p, q, key, num_samples=10000)
```

### Hidden Markov Model

```python
import distrax
import jax.numpy as jnp
import jax.random as random

# Define HMM components
init_dist = distrax.Categorical(logits=jnp.array([0.0, 0.0]))
trans_dist = distrax.Categorical(logits=jnp.array([[1.0, -1.0], [-1.0, 1.0]]))
obs_dist = distrax.Normal(jnp.array([0.0, 3.0]), jnp.array([1.0, 0.5]))

# Create HMM
hmm = distrax.HMM(init_dist, trans_dist, obs_dist)

# Sample sequence
key = random.PRNGKey(42)
states, observations = hmm.sample(seed=key, seq_len=100)

# Compute forward probabilities
forward_probs, log_prob = hmm.forward(observations)

# Find most likely state sequence
viterbi_states, viterbi_log_prob = hmm.viterbi(observations)
```
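
State posteriors for the sampled sequence can be computed with the forward-backward algorithm. A minimal sketch continuing the example above; the two-value unpacking mirrors the return described in the API section and may need adjusting if your version also returns the forward and backward tables:

```python
# Posterior probability of each hidden state at every time step,
# plus the log marginal likelihood of the observations.
posteriors, log_marginal = hmm.forward_backward(observations)
```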