# Bijectors

Invertible transformations with known Jacobian determinants for creating complex distributions through composition. Bijectors enable the construction of sophisticated probability models by transforming simple base distributions.

## Capabilities

### Base Bijector Class

Abstract base class defining the bijector interface.

```python { .api }
class Bijector:
    def __init__(self, event_ndims_in, event_ndims_out=None, is_constant_jacobian=False, is_constant_log_det=None):
        """
        Base class for bijectors.

        Parameters:
        - event_ndims_in: number of dimensions in input events
        - event_ndims_out: number of dimensions in output events (defaults to event_ndims_in)
        - is_constant_jacobian: whether Jacobian is constant
        - is_constant_log_det: whether log determinant is constant
        """

    def forward(self, x):
        """Forward transformation y = f(x)."""

    def inverse(self, y):
        """Inverse transformation x = f^{-1}(y)."""

    def forward_and_log_det(self, x):
        """Forward transformation with log determinant: (y, log|det J|)."""

    def inverse_and_log_det(self, y):
        """Inverse transformation with log determinant: (x, log|det J^{-1}|)."""

    def forward_log_det_jacobian(self, x):
        """Log determinant of forward Jacobian."""

    def inverse_log_det_jacobian(self, y):
        """Log determinant of inverse Jacobian."""

    def same_as(self, other):
        """Check equality with another bijector."""

    @property
    def event_ndims_in(self): ...
    @property
    def event_ndims_out(self): ...
    @property
    def is_constant_jacobian(self): ...
    @property
    def is_constant_log_det(self): ...
    @property
    def name(self): ...
```
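The interface above can be illustrated with a toy scalar bijector. The following is a minimal sketch in plain Python, not the library's implementation: `ExpBijector`, `standard_normal_log_prob`, and `transformed_log_prob` are hypothetical names introduced here for illustration. It shows that the forward and inverse log-determinants cancel at matching points, and how the inverse log-determinant enters the change-of-variables formula.

```python
import math


class ExpBijector:
    """Toy scalar bijector y = exp(x); hypothetical, for illustration only."""

    def forward(self, x):
        return math.exp(x)

    def inverse(self, y):
        return math.log(y)

    def forward_log_det_jacobian(self, x):
        # dy/dx = exp(x), so log|dy/dx| = x.
        return x

    def inverse_log_det_jacobian(self, y):
        # dx/dy = 1 / y, so log|dx/dy| = -log(y).
        return -math.log(y)

    def forward_and_log_det(self, x):
        return self.forward(x), self.forward_log_det_jacobian(x)

    def inverse_and_log_det(self, y):
        return self.inverse(y), self.inverse_log_det_jacobian(y)


def standard_normal_log_prob(x):
    """Log density of a standard normal base distribution."""
    return -0.5 * x * x - 0.5 * math.log(2.0 * math.pi)


def transformed_log_prob(bijector, base_log_prob, y):
    """Change of variables: log p_Y(y) = log p_X(f^{-1}(y)) + log|det J_{f^{-1}}(y)|."""
    x, inverse_log_det = bijector.inverse_and_log_det(y)
    return base_log_prob(x) + inverse_log_det
```

Under these assumptions, pushing a standard normal through `ExpBijector` reproduces the log-normal log density.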

### Affine Transformations

#### Scalar Affine Transformation

Elementwise affine transformation y = scale * x + shift.

```python { .api }
class ScalarAffine(Bijector):
    def __init__(self, shift, scale=None, log_scale=None):
        """
        Scalar affine transformation.

        Parameters:
        - shift: translation parameter (float or array)
        - scale: scale parameter (float or array, mutually exclusive with log_scale)
        - log_scale: log scale parameter (float or array, mutually exclusive with scale)

        Note: Exactly one of scale or log_scale must be specified.
        """

    @property
    def shift(self): ...
    @property
    def scale(self): ...
    @property
    def log_scale(self): ...
```
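The two parametrizations can be sketched for the scalar case as follows. This is an illustration of the math only, written in plain Python; `resolve_scale` and the two `scalar_affine_*` functions are hypothetical helpers, and the "exactly one of `scale` or `log_scale`" check simply mirrors the note above.

```python
import math


def resolve_scale(scale=None, log_scale=None):
    """Return (scale, log|scale|) from exactly one of the two parametrizations."""
    if (scale is None) == (log_scale is None):
        raise ValueError("Specify exactly one of scale or log_scale.")
    if log_scale is not None:
        # exp(log_scale) is always positive, so log|scale| is log_scale itself.
        return math.exp(log_scale), log_scale
    if scale == 0:
        raise ValueError("scale must be nonzero for the map to be invertible.")
    return scale, math.log(abs(scale))


def scalar_affine_forward_and_log_det(x, shift, scale=None, log_scale=None):
    """y = scale * x + shift, with log|dy/dx| = log|scale|."""
    scale, log_abs_scale = resolve_scale(scale, log_scale)
    return scale * x + shift, log_abs_scale


def scalar_affine_inverse_and_log_det(y, shift, scale=None, log_scale=None):
    """x = (y - shift) / scale, with log|dx/dy| = -log|scale|."""
    scale, log_abs_scale = resolve_scale(scale, log_scale)
    return (y - shift) / scale, -log_abs_scale
```

The `log_scale` form avoids both the zero-scale check and the logarithm, which is one reason it tends to be convenient when the scale is a learned parameter.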

#### Shift Transformation

Translation bijector y = x + shift.

```python { .api }
class Shift(Bijector):
    def __init__(self, shift):
        """
        Shift transformation.

        Parameters:
        - shift: translation parameter (float or array)
        """

    @property
    def shift(self): ...
```

#### Unconstrained Affine Transformation

General unconstrained affine transformation.

```python { .api }
class UnconstrainedAffine(Bijector):
    def __init__(self, shift, matrix):
        """
        Unconstrained affine transformation.

        Parameters:
        - shift: translation vector (array)
        - matrix: transformation matrix (array)
        """

    @property
    def shift(self): ...
    @property
    def matrix(self): ...
```

### Linear Transformations

#### Diagonal Linear Transformation

Linear transformation with diagonal matrix.

```python { .api }
class DiagLinear(Bijector):
    def __init__(self, diag):
        """
        Diagonal linear transformation.

        Parameters:
        - diag: diagonal elements (array)
        """

    @property
    def diag(self): ...
```

#### General Linear Transformation

Linear transformation with arbitrary matrix.

```python { .api }
class Linear(Bijector):
    def __init__(self, matrix):
        """
        Linear transformation.

        Parameters:
        - matrix: transformation matrix (array)
        """

    @property
    def matrix(self): ...
```

#### Triangular Linear Transformation

Linear transformation with triangular matrix.

```python { .api }
class TriangularLinear(Bijector):
    def __init__(self, matrix, lower=True):
        """
        Triangular linear transformation.

        Parameters:
        - matrix: triangular matrix (array)
        - lower: whether matrix is lower triangular (bool, default True)
        """

    @property
    def matrix(self): ...
    @property
    def lower(self): ...
```
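A triangular matrix is attractive for flows because both the inverse and the log-determinant are cheap. The sketch below illustrates this for the lower-triangular case in plain Python, with matrices as lists of rows. The function names are hypothetical and this is not the library's code.

```python
import math


def lower_triangular_forward(matrix, x):
    """Compute y = L @ x, reading only the lower triangle of L."""
    return [sum(matrix[i][j] * x[j] for j in range(i + 1)) for i in range(len(x))]


def lower_triangular_inverse(matrix, y):
    """Solve L @ x = y by forward substitution in O(n^2) operations."""
    x = []
    for i in range(len(y)):
        partial = sum(matrix[i][j] * x[j] for j in range(i))
        x.append((y[i] - partial) / matrix[i][i])
    return x


def triangular_log_det(matrix):
    """log|det L| is the sum of log|L_ii| over the diagonal."""
    return sum(math.log(abs(matrix[i][i])) for i in range(len(matrix)))
```

The map is invertible exactly when every diagonal entry is nonzero, which is also the condition for the logarithms above to be defined.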

#### Diagonal Plus Low-Rank Linear

Linear transformation with diagonal plus low-rank structure.

```python { .api }
class DiagPlusLowRankLinear(Bijector):
    def __init__(self, diag, u_matrix, v_matrix):
        """
        Diagonal plus low-rank linear transformation.

        Parameters:
        - diag: diagonal component (array)
        - u_matrix: U matrix for low-rank component (array)
        - v_matrix: V matrix for low-rank component (array)
        """

    @property
    def diag(self): ...
    @property
    def u_matrix(self): ...
    @property
    def v_matrix(self): ...
```
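To see why this structure is useful, assume the matrix has the form A = diag(d) + U Vᵀ (the text above does not spell out the exact form, so this is an assumption). The matrix determinant lemma then gives det A = det(diag(d)) · det(I + Vᵀ diag(d)⁻¹ U), which avoids forming A. The sketch below covers only the rank-one case, where U and V are single vectors `u` and `v`; the function names are hypothetical.

```python
import math


def diag_plus_rank_one_forward(diag, u, v, x):
    """Compute y = (diag(d) + u v^T) x without building the full matrix."""
    v_dot_x = sum(vi * xi for vi, xi in zip(v, x))
    return [di * xi + ui * v_dot_x for di, ui, xi in zip(diag, u, x)]


def rank_one_update_log_abs_det(diag, u, v):
    """log|det(diag(d) + u v^T)| via the matrix determinant lemma.

    det = prod(d) * (1 + sum_i v_i * u_i / d_i), an O(n) computation.
    """
    capacitance = 1.0 + sum(vi * ui / di for di, ui, vi in zip(diag, u, v))
    return sum(math.log(abs(di)) for di in diag) + math.log(abs(capacitance))
```

For a rank-k update the scalar `capacitance` becomes a k × k matrix, so the cost grows with k rather than with the full dimension.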

#### Lower-Upper Triangular Affine

Affine transformation using LU decomposition.

```python { .api }
class LowerUpperTriangularAffine(Bijector):
    def __init__(self, shift, lower_upper, permutation):
        """
        Lower-upper triangular affine transformation.

        Parameters:
        - shift: translation vector (array)
        - lower_upper: combined L and U matrices (array)
        - permutation: permutation for LU decomposition (array)
        """

    @property
    def shift(self): ...
    @property
    def lower_upper(self): ...
    @property
    def permutation(self): ...
```

### Activation Function Bijectors

#### Sigmoid Bijector

Sigmoid activation function bijector.

```python { .api }
class Sigmoid(Bijector):
    def __init__(self):
        """Sigmoid bijector mapping (-∞, ∞) to (0, 1)."""
```
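The scalar math behind a sigmoid bijector can be sketched in plain Python as below. This is an illustration rather than the library's implementation, and `softplus` and the `sigmoid_*` functions are hypothetical helpers. It uses the identities log σ(x) = -softplus(-x) and log(1 - σ(x)) = -softplus(x), so the log-derivative log(σ(x)(1 - σ(x))) stays finite for large |x|.

```python
import math


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def sigmoid_forward_and_log_det(x):
    """y = sigmoid(x), with log|dy/dx| = log(y) + log(1 - y)."""
    log_y = -softplus(-x)
    log_one_minus_y = -softplus(x)
    return math.exp(log_y), log_y + log_one_minus_y


def sigmoid_inverse_and_log_det(y):
    """x = logit(y) for y in (0, 1), with log|dx/dy| = -log(y) - log(1 - y)."""
    log_y = math.log(y)
    log_one_minus_y = math.log1p(-y)
    return log_y - log_one_minus_y, -log_y - log_one_minus_y
```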

#### Tanh Bijector

Hyperbolic tangent bijector.

```python { .api }
class Tanh(Bijector):
    def __init__(self):
        """Tanh bijector mapping (-∞, ∞) to (-1, 1)."""
```
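For tanh, the derivative is 1 - tanh²(x), and evaluating its logarithm naively fails once tanh(x) rounds to ±1. A commonly used stable form is log(1 - tanh²(x)) = 2 (log 2 - x - softplus(-2x)). The plain-Python sketch below illustrates that identity; the helper names are hypothetical and this is not presented as the library's code.

```python
import math


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def tanh_forward_and_log_det(x):
    """y = tanh(x), with log|dy/dx| computed in a form that stays finite for large |x|."""
    log_det = 2.0 * (math.log(2.0) - x - softplus(-2.0 * x))
    return math.tanh(x), log_det


def tanh_inverse_and_log_det(y):
    """x = atanh(y) for y in (-1, 1), with log|dx/dy| = -log(1 - y^2)."""
    return math.atanh(y), -math.log1p(-y * y)
```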

### CDF Bijectors

#### Gumbel CDF Bijector

Gumbel cumulative distribution function bijector.

```python { .api }
class GumbelCDF(Bijector):
    def __init__(self):
        """Gumbel CDF bijector."""
```
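Assuming the standard Gumbel distribution (location 0, scale 1; the constructor above takes no parameters, but the text does not state this explicitly), the CDF is y = exp(-exp(-x)) and maps the real line to (0, 1). The following plain-Python sketch shows the forward and inverse maps with their log-derivatives; the function names are hypothetical.

```python
import math


def gumbel_cdf_forward_and_log_det(x):
    """y = exp(-exp(-x)), with log|dy/dx| = -x - exp(-x)."""
    exp_neg_x = math.exp(-x)
    return math.exp(-exp_neg_x), -x - exp_neg_x


def gumbel_cdf_inverse_and_log_det(y):
    """x = -log(-log(y)) for y in (0, 1), with log|dx/dy| = -log(y) - log(-log(y))."""
    log_y = math.log(y)
    neg_log_neg_log_y = -math.log(-log_y)
    return neg_log_neg_log_y, neg_log_neg_log_y - log_y
```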

### Composition and Meta-Bijectors

#### Chain Bijector

Composition of bijectors, applied in reverse order: the last bijector in the sequence is applied first.

```python { .api }
class Chain(Bijector):
    def __init__(self, bijectors):
        """
        Chain of bijectors.

        Parameters:
        - bijectors: sequence of bijectors to compose (applied in reverse order)
        """

    @property
    def bijectors(self): ...
```
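The ordering convention and the way log-determinants accumulate can be sketched as follows. In this plain-Python illustration each "bijector" is represented by a hypothetical callable returning `(y, log_det)`; it is a stand-in for the idea, not the library's `Chain`.

```python
import math


def chain_forward_and_log_det(bijectors, x):
    """Apply bijectors last-to-first, summing their log|det J| terms."""
    total_log_det = 0.0
    for forward_and_log_det in reversed(bijectors):
        x, log_det = forward_and_log_det(x)
        total_log_det += log_det
    return x, total_log_det


def scale_by_two(x):
    """y = 2x, with log|dy/dx| = log 2."""
    return 2.0 * x, math.log(2.0)


def exponentiate(x):
    """y = exp(x), with log|dy/dx| = x."""
    return math.exp(x), x
```

With this convention, `chain_forward_and_log_det([exponentiate, scale_by_two], x)` computes exp(2x), matching the mathematical notation for composition where the rightmost function acts first.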

#### Inverse Bijector

Inverts another bijector.

```python { .api }
class Inverse(Bijector):
    def __init__(self, bijector):
        """
        Inverse bijector.

        Parameters:
        - bijector: bijector to invert
        """

    @property
    def bijector(self): ...
```
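Conceptually, inverting a bijector only swaps its forward and inverse methods. The sketch below shows this with a hypothetical wrapper `InverseOf` around a toy `ExpBijector`; both classes are illustrative stand-ins written in plain Python, not the library's classes.

```python
import math


class ExpBijector:
    """Toy scalar bijector y = exp(x); hypothetical, for illustration only."""

    def forward_and_log_det(self, x):
        return math.exp(x), x

    def inverse_and_log_det(self, y):
        return math.log(y), -math.log(y)


class InverseOf:
    """Wrap a bijector so that its forward and inverse directions are exchanged."""

    def __init__(self, bijector):
        self._bijector = bijector

    @property
    def bijector(self):
        return self._bijector

    def forward_and_log_det(self, x):
        # The wrapped bijector's inverse becomes this bijector's forward.
        return self._bijector.inverse_and_log_det(x)

    def inverse_and_log_det(self, y):
        return self._bijector.forward_and_log_det(y)
```

Here `InverseOf(ExpBijector())` behaves as a log bijector, so a bijector defined in one direction can be reused in the other without writing a second class.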

#### Lambda Bijector

Wraps callable functions as bijectors.

```python { .api }
class Lambda(Bijector):
    def __init__(self, forward_fn, inverse_fn, forward_log_det_jacobian_fn,
                 inverse_log_det_jacobian_fn=None, event_ndims_in=0, event_ndims_out=None):
        """
        Lambda bijector from functions.

        Parameters:
        - forward_fn: forward transformation function
        - inverse_fn: inverse transformation function
        - forward_log_det_jacobian_fn: forward log Jacobian determinant function
        - inverse_log_det_jacobian_fn: inverse log Jacobian determinant function
        - event_ndims_in: number of input event dimensions
        - event_ndims_out: number of output event dimensions
        """

    @property
    def forward_fn(self): ...
    @property
    def inverse_fn(self): ...
```

#### Block Bijector

Bijector that acts on a subset of input dimensions.

```python { .api }
class Block(Bijector):
    def __init__(self, bijector, ndims):
        """
        Block bijector.

        Parameters:
        - bijector: bijector to apply to subset
        - ndims: number of dimensions to transform
        """

    @property
    def bijector(self): ...
    @property
    def ndims(self): ...
```

### Normalizing Flow Bijectors

#### Masked Coupling Layer

Masked coupling layer for normalizing flows.

```python { .api }
class MaskedCoupling(Bijector):
    def __init__(self, mask, bijector_fn):
        """
        Masked coupling layer.

        Parameters:
        - mask: binary mask for splitting input (array)
        - bijector_fn: function that creates bijector from conditioning input
        """

    @property
    def mask(self): ...
    @property
    def bijector_fn(self): ...
```
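The idea of a masked coupling layer can be sketched in plain Python as below. Several things here are assumptions made for illustration: entries where the mask is `True` pass through unchanged and act as the conditioning input, the inner transformation is an elementwise affine map, and the role of `bijector_fn` is played by a hypothetical `params_fn` that returns per-element shifts and log-scales. Because `params_fn` only ever sees the unchanged entries, the layer can be inverted exactly.

```python
import math


def masked_coupling_forward(x, mask, params_fn):
    """Transform unmasked entries of x, conditioned on the masked entries."""
    conditioning = [xi if m else 0.0 for xi, m in zip(x, mask)]
    shifts, log_scales = params_fn(conditioning)
    y, log_det = [], 0.0
    for xi, m, shift, log_scale in zip(x, mask, shifts, log_scales):
        if m:
            y.append(xi)  # masked entries pass through unchanged
        else:
            y.append(math.exp(log_scale) * xi + shift)
            log_det += log_scale
    return y, log_det


def masked_coupling_inverse(y, mask, params_fn):
    """Invert the layer; the conditioning entries are identical in x and y."""
    conditioning = [yi if m else 0.0 for yi, m in zip(y, mask)]
    shifts, log_scales = params_fn(conditioning)
    x, log_det = [], 0.0
    for yi, m, shift, log_scale in zip(y, mask, shifts, log_scales):
        if m:
            x.append(yi)
        else:
            x.append((yi - shift) * math.exp(-log_scale))
            log_det -= log_scale
    return x, log_det


def toy_params_fn(conditioning):
    """Hypothetical conditioner: parameters depend only on the sum of the inputs."""
    total = sum(conditioning)
    n = len(conditioning)
    return [total] * n, [0.1 * total] * n
```

In practice the conditioner would be a neural network, and layers with complementary masks are typically stacked so that every dimension is transformed at least once.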

#### Split Coupling Layer

Split coupling layer for normalizing flows.

```python { .api }
class SplitCoupling(Bijector):
    def __init__(self, split_index, bijector_fn):
        """
        Split coupling layer.

        Parameters:
        - split_index: index at which to split input
        - bijector_fn: function that creates bijector from conditioning input
        """

    @property
    def split_index(self): ...
    @property
    def bijector_fn(self): ...
```

#### Rational Quadratic Spline

Rational quadratic spline bijector for flexible transformations.

```python { .api }
class RationalQuadraticSpline(Bijector):
    def __init__(self, bin_widths, bin_heights, knot_slopes, range_min=-1.0, range_max=1.0):
        """
        Rational quadratic spline bijector.

        Parameters:
        - bin_widths: widths of spline bins (array)
        - bin_heights: heights of spline bins (array)
        - knot_slopes: slopes at knot points (array)
        - range_min: minimum of transformation range (float)
        - range_max: maximum of transformation range (float)
        """

    @property
    def bin_widths(self): ...
    @property
    def bin_heights(self): ...
    @property
    def knot_slopes(self): ...
    @property
    def range_min(self): ...
    @property
    def range_max(self): ...
```

## Types

```python { .api }
from typing import Union, Callable
from chex import Array

BijectorLike = Union[Bijector, 'tfb.Bijector', Callable[[Array], Array]]
```