
# Utilities

Helper functions for distribution conversion, Monte Carlo estimation, and NaN-safe mathematical operations, plus a Hidden Markov Model implementation for sequential probabilistic modeling.

## Capabilities

### Conversion Utilities

#### Convert to Distrax Distribution

Converts a distribution-like object to a Distrax distribution.

```python { .api }
def as_distribution(obj):
    """
    Convert a distribution-like object to a Distrax distribution.

    Parameters:
    - obj: DistributionLike object (Distrax or TFP distribution)

    Returns:
    Distrax Distribution
    """
```

#### Convert to Distrax Bijector

Converts a bijector-like object to a Distrax bijector.

```python { .api }
def as_bijector(obj):
    """
    Convert a bijector-like object to a Distrax bijector.

    Parameters:
    - obj: BijectorLike object (Distrax bijector, TFP bijector, or callable)

    Returns:
    Distrax Bijector
    """
```
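For instance, wrapping a TFP bijector yields an object with the standard Distrax bijector interface. A minimal sketch; that raw callables are wrapped as `Lambda`-style bijectors is an assumption based on the `BijectorLike` description above:

```python
import distrax
import jax.numpy as jnp
import tensorflow_probability.substrates.jax as tfp

# Wrap a TFP bijector so it exposes the Distrax bijector interface.
exp_bij = distrax.as_bijector(tfp.bijectors.Exp())
y = exp_bij.forward(jnp.array([0.0, 1.0]))  # element-wise exp
x = exp_bij.inverse(y)                      # element-wise log

# Per the signature above, a plain callable is also accepted.
tanh_bij = distrax.as_bijector(jnp.tanh)
```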

#### Convert to TensorFlow Probability

Converts Distrax objects to TFP-compatible equivalents.

```python { .api }
def to_tfp(obj, name=None):
    """
    Convert a Distrax object to a TFP-compatible equivalent.

    Parameters:
    - obj: Distrax distribution or bijector
    - name: optional name for the TFP object

    Returns:
    TFP-compatible distribution or bijector
    """
```

### Mathematical Utilities

#### Multiply with No NaN

Element-wise multiplication that returns 0 wherever the second argument is 0, even if the first argument is infinite or NaN, avoiding the NaN that plain multiplication would propagate.

```python { .api }
def multiply_no_nan(x, y):
    """
    Element-wise multiplication that returns 0 wherever y is 0.

    Parameters:
    - x: first operand (array)
    - y: second operand (array)

    Returns:
    Element-wise product with NaN-safe handling
    """
```
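A small sketch of what this buys you over plain `*`, which produces NaN when 0 meets infinity or NaN:

```python
import distrax
import jax.numpy as jnp

x = jnp.array([jnp.inf, 2.0, jnp.nan])
y = jnp.array([0.0, 3.0, 0.0])

print(x * y)                          # [nan  6. nan]
print(distrax.multiply_no_nan(x, y))  # [ 0.  6.  0.]
```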

### Monte Carlo Estimation

#### Best-Effort KL Divergence Estimation

Estimates the KL divergence exactly when a closed-form rule is registered for the distribution pair, and falls back to Monte Carlo estimation otherwise.

```python { .api }
def estimate_kl_best_effort(distribution_a, distribution_b, rng_key, num_samples, proposal_distribution=None):
    """
    Estimate KL divergence using the best available method.

    Parameters:
    - distribution_a: first distribution
    - distribution_b: second distribution
    - rng_key: JAX random key
    - num_samples: number of Monte Carlo samples
    - proposal_distribution: optional proposal distribution for importance sampling

    Returns:
    KL divergence estimate
    """
```
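A minimal sketch: for a pair with no registered closed-form KL (assumed here for Normal/Laplace), the call should silently fall back to the sampled estimate:

```python
import distrax
import jax.random as random

key = random.PRNGKey(0)
p = distrax.Normal(loc=0.0, scale=1.0)
q = distrax.Laplace(loc=0.0, scale=1.0)

# Uses the exact KL if one is registered for this pair,
# otherwise falls back to a Monte Carlo estimate.
kl = distrax.estimate_kl_best_effort(p, q, key, num_samples=1000)
```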

#### Monte Carlo KL Divergence Estimation

Monte Carlo estimation of the KL divergence using the DiCE estimator, which keeps the estimate usefully differentiable with respect to the distribution parameters.

```python { .api }
def mc_estimate_kl(distribution_a, distribution_b, rng_key, num_samples, proposal_distribution=None):
    """
    Monte Carlo estimation of KL divergence.

    Parameters:
    - distribution_a: first distribution
    - distribution_b: second distribution
    - rng_key: JAX random key
    - num_samples: number of Monte Carlo samples
    - proposal_distribution: optional proposal distribution for importance sampling

    Returns:
    KL divergence estimate
    """
```
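The Usage Examples section below shows the basic call; the `proposal_distribution` argument deserves its own sketch, since it enables importance sampling when the two distributions overlap poorly (the particular proposal here is illustrative, not prescriptive):

```python
import distrax
import jax.random as random

key = random.PRNGKey(0)
p = distrax.Normal(0.0, 1.0)
q = distrax.Normal(3.0, 1.0)

# A broad proposal covering both p and q; samples are drawn from it and
# reweighted, which can reduce variance when p and q barely overlap.
proposal = distrax.Normal(1.5, 3.0)
kl = distrax.mc_estimate_kl(p, q, key, num_samples=10_000,
                            proposal_distribution=proposal)
```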

#### Monte Carlo KL with Reparameterized Distributions

Monte Carlo KL estimation using reparameterized (pathwise) samples from the first distribution, which must support reparameterized sampling.

```python { .api }
def mc_estimate_kl_with_reparameterized(distribution_a, distribution_b, rng_key, num_samples):
    """
    Monte Carlo KL estimation with reparameterized distributions.

    Parameters:
    - distribution_a: first distribution (must be reparameterizable)
    - distribution_b: second distribution
    - rng_key: JAX random key
    - num_samples: number of Monte Carlo samples

    Returns:
    KL divergence estimate
    """
```
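A minimal sketch; `Normal` supports reparameterized sampling, which is what the first argument requires. Pathwise gradients are expected to flow through the reparameterized samples:

```python
import distrax
import jax
import jax.random as random

key = random.PRNGKey(0)

def kl_to_standard_normal(mean):
    p = distrax.Normal(mean, 1.0)  # reparameterizable, as required
    q = distrax.Normal(0.0, 1.0)
    return distrax.mc_estimate_kl_with_reparameterized(
        p, q, key, num_samples=1000)

# Differentiate the KL estimate with respect to the mean of p.
grad = jax.grad(kl_to_standard_normal)(0.5)
```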

#### Monte Carlo Mode Estimation

Monte Carlo estimation of a distribution's mode.

```python { .api }
def mc_estimate_mode(distribution, rng_key, num_samples):
    """
    Monte Carlo estimation of distribution mode.

    Parameters:
    - distribution: distribution whose mode is estimated
    - rng_key: JAX random key
    - num_samples: number of Monte Carlo samples

    Returns:
    Mode estimate
    """
```
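This is useful when no closed-form mode exists, e.g. for a mixture. A sketch under that assumption:

```python
import distrax
import jax.numpy as jnp
import jax.random as random

key = random.PRNGKey(0)

# A bimodal mixture has no closed-form mode, so estimate it from samples.
mixture = distrax.MixtureSameFamily(
    mixture_distribution=distrax.Categorical(probs=jnp.array([0.3, 0.7])),
    components_distribution=distrax.Normal(jnp.array([-2.0, 2.0]),
                                           jnp.array([0.5, 0.5])),
)
mode = distrax.mc_estimate_mode(mixture, key, num_samples=1000)
```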

### Importance Sampling

#### Importance Sampling Ratios

Computes importance sampling ratios of the target density to the sampling density, evaluated at the given events.

```python { .api }
def importance_sampling_ratios(target_dist, sampling_dist, event):
    """
    Compute importance sampling ratios.

    Parameters:
    - target_dist: target distribution
    - sampling_dist: sampling distribution
    - event: sampled events (array)

    Returns:
    Importance sampling ratios
    """
```
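A sketch of the standard use: reweighting samples from one distribution to estimate an expectation under another (the ratio is assumed to be target density over sampling density at each event):

```python
import distrax
import jax.random as random

key = random.PRNGKey(0)
target = distrax.Normal(1.0, 1.0)
sampling = distrax.Normal(0.0, 2.0)

samples = sampling.sample(seed=key, sample_shape=1000)
ratios = distrax.importance_sampling_ratios(target, sampling, samples)

# Importance-weighted estimate of E_target[x^2]
estimate = (ratios * samples**2).mean()
```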

### Transformation Utilities

#### Register Inverse Functions

Registers inverse functions for JAX primitives, extending the set of functions that Distrax can invert automatically (e.g. when building `Lambda` bijectors from JAX callables).

```python { .api }
def register_inverse(primitive, inverse_left, inverse_right=None):
    """
    Register inverse functions for JAX primitives.

    Parameters:
    - primitive: JAX primitive to register an inverse for
    - inverse_left: left inverse function
    - inverse_right: optional right inverse function
    """
```
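A hypothetical sketch of the calling pattern, assuming a unary primitive's inverse takes the primitive's output and returns its input; `expm1`/`log1p` may well be registered already, so this only illustrates the shape of the call:

```python
import distrax
import jax
import jax.numpy as jnp

# Hypothetical: teach the inversion machinery that log1p undoes expm1.
distrax.register_inverse(jax.lax.expm1_p, jnp.log1p)

# A Lambda bijector built from expm1 could then be inverted automatically.
bij = distrax.Lambda(jnp.expm1)
x = bij.inverse(jnp.array([1.0, 2.0]))
```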

### Hidden Markov Models

#### HMM Class

Hidden Markov Model implementation for sequential modeling.

```python { .api }
class HMM:
    def __init__(self, init_dist, trans_dist, obs_dist):
        """
        Hidden Markov Model.

        Parameters:
        - init_dist: initial state distribution
        - trans_dist: transition distribution
        - obs_dist: observation distribution
        """

    def sample(self, *, seed, seq_len):
        """
        Sample a sequence from the HMM.

        Parameters:
        - seed: JAX random key
        - seq_len: length of sequence to sample

        Returns:
        Tuple of (states, observations)
        """

    def forward(self, obs_seq, length=None):
        """
        Forward algorithm for computing the marginal likelihood.

        Parameters:
        - obs_seq: sequence of observations (array)
        - length: optional sequence length (for batched sequences)

        Returns:
        Forward probabilities and log marginal likelihood
        """

    def backward(self, obs_seq, length=None):
        """
        Backward algorithm for computing backward probabilities.

        Parameters:
        - obs_seq: sequence of observations (array)
        - length: optional sequence length (for batched sequences)

        Returns:
        Backward probabilities
        """

    def forward_backward(self, obs_seq, length=None):
        """
        Forward-backward algorithm for state posterior probabilities.

        Parameters:
        - obs_seq: sequence of observations (array)
        - length: optional sequence length (for batched sequences)

        Returns:
        State posterior probabilities and log marginal likelihood
        """

    def viterbi(self, obs_seq):
        """
        Viterbi algorithm for the most likely state sequence.

        Parameters:
        - obs_seq: sequence of observations (array)

        Returns:
        Most likely state sequence and its log probability
        """

    @property
    def init_dist(self): ...
    @property
    def trans_dist(self): ...
    @property
    def obs_dist(self): ...

    @property
    def event_shape(self): ...
    @property
    def batch_shape(self): ...
```

## Usage Examples

### Converting Between Libraries

```python
import distrax
import tensorflow_probability.substrates.jax as tfp

# Convert a TFP distribution to Distrax
tfp_normal = tfp.distributions.Normal(0.0, 1.0)
distrax_normal = distrax.as_distribution(tfp_normal)

# Convert a Distrax distribution to TFP
distrax_normal = distrax.Normal(0.0, 1.0)
tfp_normal = distrax.to_tfp(distrax_normal)
```

### Monte Carlo KL Estimation

```python
import distrax
import jax.random as random

key = random.PRNGKey(42)
p = distrax.Normal(0.0, 1.0)
q = distrax.Normal(0.5, 1.2)

# Estimate KL(p || q) from samples of p
kl_estimate = distrax.mc_estimate_kl(p, q, key, num_samples=10000)
```

### Hidden Markov Model

```python
import distrax
import jax.numpy as jnp
import jax.random as random

# Define HMM components: two hidden states with Gaussian observations
init_dist = distrax.Categorical(logits=jnp.array([0.0, 0.0]))
trans_dist = distrax.Categorical(logits=jnp.array([[1.0, -1.0], [-1.0, 1.0]]))
obs_dist = distrax.Normal(jnp.array([0.0, 3.0]), jnp.array([1.0, 0.5]))

# Create the HMM (keyword arguments avoid relying on parameter order)
hmm = distrax.HMM(init_dist=init_dist, trans_dist=trans_dist, obs_dist=obs_dist)

# Sample a sequence of hidden states and observations
key = random.PRNGKey(42)
states, observations = hmm.sample(seed=key, seq_len=100)

# Compute forward probabilities and the log marginal likelihood
forward_probs, log_prob = hmm.forward(observations)

# Find the most likely state sequence
viterbi_states, viterbi_log_prob = hmm.viterbi(observations)
```
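Continuing the example, the forward-backward algorithm gives smoothed posteriors over the hidden states; the two-value return follows the signature documented above:

```python
# Posterior marginals over hidden states given all observations (smoothing)
posteriors, log_marginal = hmm.forward_backward(observations)
```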