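# Bijectors

Bijectors are invertible transformations whose Jacobian determinants are tractable. Composing them and applying them to simple base distributions yields more expressive distributions. If `y = f(x)` and `x` has density `p_X`, then `log p_Y(y) = log p_X(f^{-1}(y)) + log|det J_{f^{-1}}(y)|`. This is why every bijector exposes both directions of the transformation together with the matching log-determinant.

## Capabilities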

### Base Bijector Class

Abstract base class defining the bijector interface. A bijector describes a transformation `y = f(x)` together with the log absolute determinant of its Jacobian.

```python { .api }
class Bijector:
    def __init__(self, event_ndims_in, event_ndims_out=None, is_constant_jacobian=False, is_constant_log_det=None):
        """
        Base class for bijectors.

        Parameters:
        - event_ndims_in: number of dimensions in input events
        - event_ndims_out: number of dimensions in output events (defaults to event_ndims_in)
        - is_constant_jacobian: whether the Jacobian is constant (true for affine transformations)
        - is_constant_log_det: whether the log-determinant of the Jacobian is constant (defaults to is_constant_jacobian)
        """

    def forward(self, x):
        """Forward transformation y = f(x)."""

    def inverse(self, y):
        """Inverse transformation x = f^{-1}(y)."""

    def forward_and_log_det(self, x):
        """Forward transformation with its log-determinant: returns (y, log|det J_f(x)|)."""

    def inverse_and_log_det(self, y):
        """Inverse transformation with its log-determinant: returns (x, log|det J_{f^{-1}}(y)|)."""

    def forward_log_det_jacobian(self, x):
        """Log absolute determinant of the forward Jacobian, log|det J_f(x)|."""

    def inverse_log_det_jacobian(self, y):
        """Log absolute determinant of the inverse Jacobian, log|det J_{f^{-1}}(y)|."""

    def same_as(self, other):
        """Checks whether this bijector is the same as `other`."""

    @property
    def event_ndims_in(self): ...

    @property
    def event_ndims_out(self): ...

    @property
    def is_constant_jacobian(self): ...

    @property
    def is_constant_log_det(self): ...

    @property
    def name(self): ...
```
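
To make the interface concrete, here is a minimal NumPy sketch of a hypothetical `ExpBijector` for `y = exp(x)` that mirrors the method names above without subclassing the real `Bijector` class. The hypothetical helper `transformed_log_prob` shows how `inverse_and_log_det` feeds the change-of-variables formula.

```python
import numpy as np


class ExpBijector:
    """Hypothetical elementwise bijector y = exp(x), mirroring the Bijector interface."""

    event_ndims_in = 0
    event_ndims_out = 0

    def forward_and_log_det(self, x):
        y = np.exp(x)
        # dy/dx = exp(x), so log|dy/dx| = x elementwise.
        return y, x

    def inverse_and_log_det(self, y):
        x = np.log(y)
        # dx/dy = 1/y, so log|dx/dy| = -log(y) = -x.
        return x, -x

    def forward(self, x):
        return self.forward_and_log_det(x)[0]

    def inverse(self, y):
        return self.inverse_and_log_det(y)[0]

    def forward_log_det_jacobian(self, x):
        return self.forward_and_log_det(x)[1]

    def inverse_log_det_jacobian(self, y):
        return self.inverse_and_log_det(y)[1]


def std_normal_log_prob(x):
    return -0.5 * x**2 - 0.5 * np.log(2.0 * np.pi)


def transformed_log_prob(bijector, base_log_prob, y):
    """log p_Y(y) = log p_X(f^{-1}(y)) + log|det J_{f^{-1}}(y)| (hypothetical helper)."""
    x, ildj = bijector.inverse_and_log_det(y)
    return base_log_prob(x) + ildj


bij = ExpBijector()
y = np.array([0.5, 1.0, 2.0])
# Pushing a standard normal through exp yields a standard log-normal density.
log_p = transformed_log_prob(bij, std_normal_log_prob, y)
```

The forward and inverse log-determinants are negatives of each other at corresponding points, which is a useful sanity check for any hand-written bijector.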

### Affine Transformations

#### Scalar Affine Transformation

Elementwise affine transformation `y = scale * x + shift`. The log-determinant is `log|scale|` per element and does not depend on `x`.

```python { .api }
class ScalarAffine(Bijector):
    def __init__(self, shift, scale=None, log_scale=None):
        """
        Scalar affine transformation.

        Parameters:
        - shift: translation parameter (float or array)
        - scale: scale parameter (float or array, mutually exclusive with log_scale); must be non-zero
        - log_scale: log of the scale parameter (float or array, mutually exclusive with scale); the scale is exp(log_scale)

        Note: At most one of scale or log_scale may be specified. If neither is given, the scale defaults to 1.
        """

    @property
    def shift(self): ...

    @property
    def scale(self): ...

    @property
    def log_scale(self): ...
```
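
The scalar affine math fits in a few lines. This NumPy sketch uses a hypothetical factory `scalar_affine` (not the library class) to show the `scale`/`log_scale` alternatives, the identity scale when neither is given, and that the log-determinant is constant in `x`.

```python
import numpy as np


def scalar_affine(shift, scale=None, log_scale=None):
    """Hypothetical helper returning (forward_and_log_det, inverse_and_log_det) for y = scale * x + shift."""
    if scale is not None and log_scale is not None:
        raise ValueError("Specify at most one of scale and log_scale.")
    if log_scale is not None:
        scale = np.exp(log_scale)
    elif scale is None:
        scale = 1.0  # neither given: identity scaling

    def forward_and_log_det(x):
        y = scale * x + shift
        # The Jacobian is diagonal with entries `scale`; broadcast log|scale| to y's shape.
        return y, np.broadcast_to(np.log(np.abs(scale)), np.shape(y))

    def inverse_and_log_det(y):
        x = (y - shift) / scale
        return x, np.broadcast_to(-np.log(np.abs(scale)), np.shape(x))

    return forward_and_log_det, inverse_and_log_det


# log_scale = log(2) gives y = 2x + 1.
fwd, inv = scalar_affine(shift=1.0, log_scale=np.log(2.0))
y, fldj = fwd(np.array([0.0, 1.0, 2.0]))
x, ildj = inv(y)
```

Parameterizing by `log_scale` keeps the scale positive without constraining the parameter, which is convenient when it is learned by gradient descent.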

#### Shift Transformation

Translation bijector `y = x + shift`. Its Jacobian is the identity, so the log-determinant is zero.

```python { .api }
class Shift(Bijector):
    def __init__(self, shift):
        """
        Shift transformation.

        Parameters:
        - shift: translation parameter (float or array)
        """

    @property
    def shift(self): ...
```

#### Unconstrained Affine Transformation

General affine transformation `y = A x + b` with an arbitrary square matrix `A`. It is invertible only if `A` is non-singular, and ensuring this is the caller's responsibility. The log-determinant is `log|det A|`; for a D x D matrix, computing it or the inverse costs O(D^3).

```python { .api }
class UnconstrainedAffine(Bijector):
    def __init__(self, matrix, bias):
        """
        Unconstrained affine transformation y = matrix @ x + bias.

        Parameters:
        - matrix: square transformation matrix A (array)
        - bias: translation vector b (array)
        """

    @property
    def matrix(self): ...

    @property
    def bias(self): ...
```
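
A minimal NumPy sketch of the dense affine case, using a hypothetical factory `unconstrained_affine`; it assumes a non-singular matrix and is not the library implementation.

```python
import numpy as np


def unconstrained_affine(matrix, bias):
    """Hypothetical helper for y = A x + b on vectors of length D (A assumed non-singular)."""
    # slogdet returns (sign, log|det A|); the log-determinant does not depend on x.
    _, log_abs_det = np.linalg.slogdet(matrix)

    def forward_and_log_det(x):
        return matrix @ x + bias, log_abs_det

    def inverse_and_log_det(y):
        # Solve A x = y - b rather than forming A^{-1} explicitly.
        return np.linalg.solve(matrix, y - bias), -log_abs_det

    return forward_and_log_det, inverse_and_log_det


A = np.array([[2.0, 1.0], [0.5, 3.0]])   # det A = 5.5
b = np.array([1.0, -1.0])
fwd, inv = unconstrained_affine(A, b)
y, fldj = fwd(np.array([1.0, 2.0]))
x, ildj = inv(y)
```

Using `slogdet` instead of `log(abs(det(A)))` avoids overflow and underflow of the raw determinant for larger matrices.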

### Linear Transformations

#### Diagonal Linear Transformation

Linear transformation `y = diag(d) x` with a diagonal matrix, i.e. per-coordinate scaling of a vector event. The log-determinant is `sum(log|d|)`.

```python { .api }
class DiagLinear(Bijector):
    def __init__(self, diag):
        """
        Diagonal linear transformation.

        Parameters:
        - diag: diagonal elements (array); all entries must be non-zero
        """

    @property
    def diag(self): ...
```
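
For a single vector event, the computation reduces to elementwise operations. This NumPy sketch uses illustrative values and compares the O(D) log-determinant with a dense O(D^3) one.

```python
import numpy as np

diag = np.array([2.0, -0.5, 4.0])
x = np.array([1.0, 2.0, 3.0])

# Forward: multiply each coordinate by its diagonal entry.
y = diag * x
# log|det diag(d)| = sum_i log|d_i|, identical for every x.
fldj = np.sum(np.log(np.abs(diag)))
# Inverse: divide elementwise; its log-determinant is the negation.
x_rec = y / diag
ildj = -fldj

# Dense reference, only for comparison: builds the full D x D matrix.
dense_fldj = np.linalg.slogdet(np.diag(diag))[1]
```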
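
#### Linear Base Class

Abstract base class for linear bijectors `y = A x`, where `A` is a D x D matrix acting on D-dimensional vector events. It is not constructed from a matrix directly: concrete subclasses such as `DiagLinear`, `TriangularLinear` and `DiagPlusLowRankLinear` define how `A` is parameterized. For an arbitrary dense matrix, use `UnconstrainedAffine`.

```python { .api }
class Linear(Bijector):
    def __init__(self, event_dims, batch_shape, dtype):
        """
        Base class for linear bijectors.

        Parameters:
        - event_dims: dimensionality D of the vector events
        - batch_shape: batch shape of the bijector's parameters
        - dtype: dtype of the transformation matrix
        """

    @property
    def matrix(self): ...  # the D x D matrix A, provided by subclasses

    @property
    def event_dims(self): ...

    @property
    def batch_shape(self): ...

    @property
    def dtype(self): ...
```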

#### Triangular Linear Transformation

Linear transformation `y = A x` with a triangular matrix `A`. The log-determinant is the sum of `log|A_ii|` over the diagonal, and the inverse only needs a triangular solve, which costs O(D^2).

```python { .api }
class TriangularLinear(Bijector):
    def __init__(self, matrix, is_lower=True):
        """
        Triangular linear transformation.

        Parameters:
        - matrix: square matrix; only its lower (or upper) triangle, including the diagonal, is used
        - is_lower: whether to use the lower triangle (bool, default True); if False, the upper triangle is used
        """

    @property
    def matrix(self): ...

    @property
    def is_lower(self): ...
```
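
Triangular structure helps in two places: the log-determinant only needs the diagonal, and the inverse is a forward substitution rather than a general solve. This NumPy sketch uses an illustrative random matrix, and `solve_lower` is a hypothetical helper written out for clarity.

```python
import numpy as np

rng = np.random.default_rng(0)
D = 4
raw = rng.normal(size=(D, D))
# Illustrative lower-triangular A; the diagonal is kept >= 1 so A is invertible.
A = np.tril(raw, k=-1) + np.diag(1.0 + np.abs(np.diag(raw)))
x = rng.normal(size=D)
y = A @ x

# The determinant of a triangular matrix is the product of its diagonal entries.
fldj = np.sum(np.log(np.abs(np.diag(A))))


def solve_lower(L, b):
    """Hypothetical helper: forward substitution for L x = b in O(D^2) operations."""
    out = np.zeros_like(b)
    for i in range(len(b)):
        # Row i only involves the already-solved entries out[:i].
        out[i] = (b[i] - L[i, :i] @ out[:i]) / L[i, i]
    return out


x_rec = solve_lower(A, y)
```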

#### Diagonal Plus Low-Rank Linear

Linear transformation `y = (diag(d) + U V^T) x`, where `U` and `V` are D x K matrices, typically with K much smaller than D. By the matrix determinant lemma, the log-determinant can be obtained from a K x K determinant instead of a D x D one.

```python { .api }
class DiagPlusLowRankLinear(Bijector):
    def __init__(self, diag, u_matrix, v_matrix):
        """
        Diagonal plus low-rank linear transformation.

        Parameters:
        - diag: diagonal component d (array of shape [D])
        - u_matrix: U matrix of the low-rank component (array of shape [D, K])
        - v_matrix: V matrix of the low-rank component (array of shape [D, K])
        """

    @property
    def diag(self): ...

    @property
    def u_matrix(self): ...

    @property
    def v_matrix(self): ...
```
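
The matrix determinant lemma states `det(diag(d) + U V^T) = det(I_K + V^T diag(d)^{-1} U) * prod(d)`. This NumPy sketch checks it numerically with random illustrative parameters, assuming the `diag(d) + U V^T` form; the dense matrix is built only for comparison.

```python
import numpy as np

rng = np.random.default_rng(1)
D, K = 5, 2
d = 1.0 + rng.uniform(size=D)          # positive diagonal, so diag(d) is invertible
U = 0.3 * rng.normal(size=(D, K))
V = 0.3 * rng.normal(size=(D, K))

# Dense D x D matrix, used only as a reference.
A = np.diag(d) + U @ V.T

# K x K "capacitance" matrix I_K + V^T diag(d)^{-1} U; dividing row i of U by d_i
# applies diag(d)^{-1} without forming it.
capacitance = np.eye(K) + V.T @ (U / d[:, None])
_, logdet_c = np.linalg.slogdet(capacitance)
fldj = logdet_c + np.sum(np.log(np.abs(d)))
```

The cheap path costs O(D K^2 + K^3) instead of the O(D^3) needed for a dense determinant.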

#### Lower-Upper Triangular Affine

Affine transformation `y = L U x + b` whose weight matrix is parameterized by an LU factorization: `L` is lower triangular with a unit diagonal and `U` is upper triangular. Since `det L = 1`, the log-determinant is `sum(log|diag(U)|)`, and the inverse requires two triangular solves.

```python { .api }
class LowerUpperTriangularAffine(Bijector):
    def __init__(self, matrix, bias):
        """
        Lower-upper triangular affine transformation.

        Parameters:
        - matrix: square matrix M encoding both factors: the strictly lower triangle of M
          gives L (whose diagonal is fixed to ones) and the upper triangle of M, including
          the diagonal, gives U. M is generally not equal to the product LU.
        - bias: translation vector b (array)
        """

    @property
    def matrix(self): ...

    @property
    def lower(self): ...

    @property
    def upper(self): ...

    @property
    def bias(self): ...
```
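
A NumPy sketch of the LU parameterization, assuming a single square matrix `M` packs both factors (strict lower triangle gives `L` with a unit diagonal, upper triangle gives `U`). Values are random and illustrative; this is not the library implementation.

```python
import numpy as np

rng = np.random.default_rng(2)
n = 3
M = rng.normal(size=(n, n))
# Keep |diag(M)| >= 1 so that U, and hence LU, is invertible in this illustration.
M[np.diag_indices(n)] = 1.0 + np.abs(np.diag(M))
b = rng.normal(size=n)

# Assumed packing: strict lower triangle -> L (unit diagonal); upper triangle -> U.
L = np.tril(M, k=-1) + np.eye(n)
U = np.triu(M)

x = rng.normal(size=n)
y = L @ U @ x + b

# det(L) = 1, so log|det(LU)| only needs the diagonal of U.
fldj = np.sum(np.log(np.abs(np.diag(U))))

# Inverse via two triangular systems: L z = y - b, then U x = z.
# np.linalg.solve ignores the triangular structure; a dedicated triangular
# solver would reduce the cost to O(n^2).
z = np.linalg.solve(L, y - b)
x_rec = np.linalg.solve(U, z)
```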

### Activation Function Bijectors

#### Sigmoid Bijector

Elementwise logistic sigmoid `y = 1 / (1 + exp(-x))`. Its inverse is the logit function `x = log(y) - log(1 - y)`.

```python { .api }
class Sigmoid(Bijector):
    def __init__(self):
        """Sigmoid bijector mapping (-∞, ∞) to (0, 1)."""
```
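
The sigmoid's log-derivative, `log(sigmoid(x) * (1 - sigmoid(x)))`, is best evaluated with softplus terms so that it stays finite for large `|x|`. A NumPy sketch with hypothetical function names:

```python
import numpy as np


def softplus(z):
    # Numerically stable log(1 + exp(z)).
    return np.logaddexp(0.0, z)


def sigmoid_forward_and_log_det(x):
    y = 1.0 / (1.0 + np.exp(-x))
    # log sigmoid(x) = -softplus(-x) and log(1 - sigmoid(x)) = -softplus(x),
    # so the log-derivative is their sum, computed without log(0).
    log_det = -softplus(-x) - softplus(x)
    return y, log_det


def sigmoid_inverse(y):
    # The inverse of the sigmoid is the logit function.
    return np.log(y) - np.log1p(-y)


x = np.array([-3.0, 0.0, 2.5])
y, fldj = sigmoid_forward_and_log_det(x)
```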

#### Tanh Bijector

Elementwise hyperbolic tangent `y = tanh(x)`, with inverse `x = arctanh(y)`. Because `tanh` saturates, inverting outputs very close to ±1 is numerically unstable.

```python { .api }
class Tanh(Bijector):
    def __init__(self):
        """Tanh bijector mapping (-∞, ∞) to (-1, 1)."""
```
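
The naive log-derivative `log(1 - tanh(x)^2)` becomes `-inf` once `tanh(x)` rounds to ±1. Rewriting it as `2 * (log 2 - x - softplus(-2x))` avoids this. A NumPy sketch with hypothetical function names:

```python
import numpy as np


def tanh_forward_and_log_det(x):
    y = np.tanh(x)
    # 1 - tanh(x)^2 = 4 exp(-2x) / (1 + exp(-2x))^2, so its log is
    # 2 * (log 2 - x - softplus(-2x)), which stays finite for large |x|.
    log_det = 2.0 * (np.log(2.0) - x - np.logaddexp(0.0, -2.0 * x))
    return y, log_det


def tanh_inverse(y):
    return np.arctanh(y)


x = np.array([-1.0, 0.0, 0.5, 30.0])
y, fldj = tanh_forward_and_log_det(x)
```

At `x = 30` the forward value is indistinguishable from 1 in float64, yet the stable formula still returns a finite log-determinant.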

### CDF Bijectors

#### Gumbel CDF Bijector

Bijector given by the CDF of the standard Gumbel distribution, `y = exp(-exp(-x))`. Its inverse is `x = -log(-log(y))`.

```python { .api }
class GumbelCDF(Bijector):
    def __init__(self):
        """Gumbel CDF bijector mapping (-∞, ∞) to (0, 1)."""
```
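
Differentiating `exp(-exp(-x))` gives `exp(-exp(-x)) * exp(-x)`, so the log-determinant is `-x - exp(-x)`. A NumPy sketch with hypothetical function names:

```python
import numpy as np


def gumbel_cdf_forward_and_log_det(x):
    y = np.exp(-np.exp(-x))
    # dy/dx = exp(-exp(-x)) * exp(-x), so log|dy/dx| = -x - exp(-x).
    log_det = -x - np.exp(-x)
    return y, log_det


def gumbel_cdf_inverse(y):
    return -np.log(-np.log(y))


x = np.array([-1.0, 0.0, 2.0])
y, fldj = gumbel_cdf_forward_and_log_det(x)
```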

### Composition and Meta-Bijectors

#### Chain Bijector

Composition of bijectors. For `Chain([f1, f2, f3])` the forward map is `f1(f2(f3(x)))`, so the bijectors are applied in reverse list order, last one first. The log-determinant of the composition is the sum of the parts' log-determinants, each evaluated at its own input.

```python { .api }
class Chain(Bijector):
    def __init__(self, bijectors):
        """
        Chain of bijectors.

        Parameters:
        - bijectors: sequence of bijectors to compose; in the forward direction
          the last element is applied first
        """

    @property
    def bijectors(self): ...
```
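
The ordering convention is easy to get backwards. In this NumPy sketch each bijector is represented by a hypothetical forward function returning `(y, log_det)`; `chain_forward_and_log_det` is illustrative, not the library implementation.

```python
import numpy as np


def scale_by_3(x):
    # y = 3x, log|dy/dx| = log 3 for every element.
    return 3.0 * x, np.log(3.0) * np.ones_like(x)


def exp_(x):
    # y = exp(x), log|dy/dx| = x.
    return np.exp(x), x


def chain_forward_and_log_det(bijectors, x):
    """Apply bijectors in reverse list order, accumulating their log-determinants."""
    total = np.zeros_like(x)
    for fwd in reversed(bijectors):
        x, log_det = fwd(x)
        total = total + log_det
    return x, total


x = np.array([0.0, 1.0])
# A chain over [scale_by_3, exp_] computes scale_by_3(exp_(x)) = 3 * exp(x).
y, fldj = chain_forward_and_log_det([scale_by_3, exp_], x)
```

The accumulated log-determinant `log 3 + x` matches the derivative of `3 * exp(x)` directly.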

#### Inverse Bijector

Wraps a bijector and swaps its directions: the forward map of `Inverse(b)` is `b`'s inverse, its inverse is `b`'s forward map, and the log-determinants swap accordingly.

```python { .api }
class Inverse(Bijector):
    def __init__(self, bijector):
        """
        Inverse bijector.

        Parameters:
        - bijector: bijector to invert
        """

    @property
    def bijector(self): ...
```
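
Inverting a bijector only requires swapping which method is called for each direction. This NumPy sketch uses hypothetical `Exp` and `InverseOf` classes; wrapping `Exp` this way yields a logarithm bijector.

```python
import numpy as np


class Exp:
    """Hypothetical elementwise bijector y = exp(x)."""

    def forward_and_log_det(self, x):
        return np.exp(x), x

    def inverse_and_log_det(self, y):
        return np.log(y), -np.log(y)


class InverseOf:
    """Hypothetical wrapper that swaps the directions of another bijector."""

    def __init__(self, bijector):
        self.bijector = bijector

    def forward_and_log_det(self, x):
        return self.bijector.inverse_and_log_det(x)

    def inverse_and_log_det(self, y):
        return self.bijector.forward_and_log_det(y)


log_bij = InverseOf(Exp())          # behaves like y = log(x)
y, fldj = log_bij.forward_and_log_det(np.array([1.0, np.e]))
```

The forward log-determinant comes out as `-log(x)`, which is `log|d log(x)/dx|`.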

#### Lambda Bijector

Wraps plain functions as a bijector, which is convenient for transformations that have no dedicated class.

```python { .api }
class Lambda(Bijector):
    def __init__(self, forward=None, inverse=None, forward_log_det_jacobian=None,
                 inverse_log_det_jacobian=None, event_ndims_in=None, event_ndims_out=None,
                 is_constant_jacobian=None):
        """
        Lambda bijector from functions.

        Parameters:
        - forward: forward transformation function
        - inverse: inverse transformation function
        - forward_log_det_jacobian: function computing log|det J_f(x)|
        - inverse_log_det_jacobian: function computing log|det J_{f^{-1}}(y)|
        - event_ndims_in: number of input event dimensions
        - event_ndims_out: number of output event dimensions
        - is_constant_jacobian: whether the Jacobian is constant
        """
```
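
When the log-determinant function is written by hand, it is worth checking it against the forward function. This NumPy sketch assumes elementwise (scalar-event) functions; `check_log_det` is a hypothetical helper, not part of the API.

```python
import numpy as np


def check_log_det(forward, forward_log_det_jacobian, x, h=1e-6, rtol=1e-5):
    """Compare a hand-written elementwise log|f'(x)| with a central finite difference."""
    numeric = (forward(x + h) - forward(x - h)) / (2.0 * h)
    return np.allclose(forward_log_det_jacobian(x), np.log(np.abs(numeric)), rtol=rtol)


# y = x**3 + x is strictly increasing, with f'(x) = 3x^2 + 1 > 0 everywhere.
forward = lambda x: x**3 + x
good_fldj = lambda x: np.log(3.0 * x**2 + 1.0)
bad_fldj = lambda x: np.log(3.0 * x**2)   # forgets the "+ 1" term

x = np.array([-1.5, -0.5, 0.5, 2.0])
good_ok = check_log_det(forward, good_fldj, x)
bad_ok = check_log_det(forward, bad_fldj, x)
```

The check covers only the log-determinant; the inverse function needs its own round-trip test.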
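
#### Block Bijector

Wraps a bijector so that the rightmost `ndims` dimensions of its input are treated as part of a single event. The transformation itself is unchanged; `event_ndims_in` and `event_ndims_out` both increase by `ndims`, and the log-determinant is summed over those dimensions. For example, `Block(ScalarAffine(shift, scale), ndims=1)` acts on vectors and returns one log-determinant per vector instead of one per element.

```python { .api }
class Block(Bijector):
    def __init__(self, bijector, ndims):
        """
        Block bijector.

        Parameters:
        - bijector: bijector to wrap
        - ndims: number of rightmost dimensions to treat as part of the event
        """

    @property
    def bijector(self): ...

    @property
    def ndims(self): ...
```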
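
Treating trailing axes as one event only changes how the elementwise log-determinants are reduced. A NumPy sketch with illustrative shapes:

```python
import numpy as np

# Elementwise scaling y = s * x gives one log-determinant per element.
s = np.array([2.0, 0.5, 3.0])
x = np.ones((4, 3))                                            # batch of 4 vectors of length 3
elementwise_ldj = np.broadcast_to(np.log(np.abs(s)), x.shape)  # shape (4, 3)

# Treating the last `ndims` axes as one event sums the log-dets over those axes.
ndims = 1
block_ldj = elementwise_ldj.sum(axis=tuple(range(-ndims, 0)))  # shape (4,)
```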
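
### Normalizing Flow Bijectors

#### Masked Coupling Layer

Masked coupling layer for normalizing flows. A boolean `mask` selects the elements that are left unchanged; the `conditioner` maps those elements to parameters, and `bijector` turns the parameters into an inner bijector that transforms the remaining elements. Because the conditioning elements pass through unchanged, the layer stays invertible even when the conditioner is an arbitrary neural network, and its log-determinant only involves the transformed elements.

```python { .api }
class MaskedCoupling(Bijector):
    def __init__(self, mask, conditioner, bijector, event_ndims=None, inner_event_ndims=0):
        """
        Masked coupling layer.

        Parameters:
        - mask: boolean array; elements where the mask is True are kept unchanged and used for conditioning
        - conditioner: function mapping the masked input to the inner bijector's parameters
        - bijector: function mapping those parameters to the inner bijector applied to the remaining elements
        - event_ndims: number of event dimensions (if None, inferred from the mask's shape)
        - inner_event_ndims: number of event dimensions of the inner bijector (default 0)
        """
```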
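
A self-contained NumPy sketch of a masked affine coupling layer in the style of RealNVP. The toy linear conditioner, its random weights, and the elementwise affine inner bijector are illustrative assumptions; in practice the conditioner is usually a neural network.

```python
import numpy as np

rng = np.random.default_rng(3)
D = 4
mask = np.array([True, False, True, False])   # True: kept unchanged and used for conditioning
W = 0.5 * rng.normal(size=(2 * D, D))          # hypothetical conditioner weights


def conditioner(masked_x):
    # Toy "network": a linear map producing a log-scale and a shift per element.
    out = np.tanh(W @ masked_x)
    return out[:D], out[D:]                    # (log_scale, shift)


def coupling_forward_and_log_det(x):
    log_scale, shift = conditioner(np.where(mask, x, 0.0))
    y_inner = np.exp(log_scale) * x + shift    # inner elementwise affine bijector
    y = np.where(mask, x, y_inner)
    # Only the transformed (unmasked) elements contribute to the log-determinant.
    log_det = np.sum(np.where(mask, 0.0, log_scale))
    return y, log_det


def coupling_inverse(y):
    # Masked entries of y equal those of x, so the conditioner sees identical input.
    log_scale, shift = conditioner(np.where(mask, y, 0.0))
    x_inner = (y - shift) * np.exp(-log_scale)
    return np.where(mask, y, x_inner)


x = rng.normal(size=D)
y, fldj = coupling_forward_and_log_det(x)
```

Flows typically stack several such layers with alternating masks so that every element is eventually transformed.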
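
#### Split Coupling Layer

Split coupling layer for normalizing flows. The input is split into two contiguous parts at `split_index` along `split_axis`; one part is kept unchanged and passed to the `conditioner`, whose output parameterizes the inner `bijector` that transforms the other part. `swap` selects which part is transformed.

```python { .api }
class SplitCoupling(Bijector):
    def __init__(self, split_index, event_ndims, conditioner, bijector, swap=False, split_axis=-1):
        """
        Split coupling layer.

        Parameters:
        - split_index: index along split_axis at which the input is split
        - event_ndims: number of event dimensions of the layer
        - conditioner: function mapping the unchanged part to the inner bijector's parameters
        - bijector: function mapping those parameters to the inner bijector
        - swap: if False, the first part is kept unchanged and the second part is transformed;
          if True, the roles are reversed
        - split_axis: axis along which the input is split (default -1)
        """
```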
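
The same idea with a contiguous split can be sketched in NumPy. This version keeps `x[:split_index]` fixed and applies an affine map to the rest; it assumes a length-4 input, and the conditioner is a hypothetical toy function.

```python
import numpy as np

split_index = 2


def conditioner(x1):
    # Hypothetical conditioner: any function of the unchanged part works.
    log_scale = 0.5 * np.tanh(x1.sum()) * np.ones(2)
    shift = np.array([x1[0], x1[1] ** 2])
    return log_scale, shift


def split_coupling_forward(x):
    x1, x2 = x[:split_index], x[split_index:]
    log_scale, shift = conditioner(x1)
    y2 = np.exp(log_scale) * x2 + shift
    return np.concatenate([x1, y2]), np.sum(log_scale)


def split_coupling_inverse(y):
    y1, y2 = y[:split_index], y[split_index:]
    log_scale, shift = conditioner(y1)      # y1 == x1, so the parameters match
    x2 = (y2 - shift) * np.exp(-log_scale)
    return np.concatenate([y1, x2])


x = np.array([0.3, -1.2, 2.0, 0.7])
y, fldj = split_coupling_forward(x)
```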

#### Rational Quadratic Spline

Elementwise monotonic rational-quadratic spline, as used in neural spline flows. Inside `[range_min, range_max]` each element is mapped through a piecewise rational-quadratic function with K bins whose knots are derived from `params`; outside that interval the transformation is the identity.

```python { .api }
class RationalQuadraticSpline(Bijector):
    def __init__(self, params, range_min, range_max, boundary_slopes='unconstrained',
                 min_bin_size=1e-4, min_knot_slope=1e-4):
        """
        Rational quadratic spline bijector.

        Parameters:
        - params: unconstrained spline parameters; the last axis has size 3 * K + 1 and
          encodes the K bin widths, K bin heights and K + 1 knot slopes
        - range_min: lower bound of the interval on which the spline acts (float)
        - range_max: upper bound of the interval on which the spline acts (float)
        - boundary_slopes: how the slopes at the two outermost knots are set (default 'unconstrained')
        - min_bin_size: minimum bin width and height (float)
        - min_knot_slope: minimum knot slope (float)
        """

    @property
    def num_bins(self): ...

    @property
    def knot_slopes(self): ...

    @property
    def x_pos(self): ...

    @property
    def y_pos(self): ...
```
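
Once the knots are known, the spline itself is short. This NumPy sketch implements the rational-quadratic interpolant from Neural Spline Flows (Durkan et al., 2019) for a scalar input with explicit knot positions and slopes. It deliberately skips how raw parameters become knots (normalization, minimum bin sizes, boundary slopes), and `rqs_forward` is a hypothetical helper.

```python
import numpy as np


def rqs_forward(x, x_knots, y_knots, slopes):
    """Rational-quadratic spline forward pass for a scalar x.

    x_knots, y_knots: increasing knot positions (K + 1 each); slopes: K + 1 positive
    derivatives at the knots. Outside [x_knots[0], x_knots[-1]] the map is the identity.
    """
    if x <= x_knots[0] or x >= x_knots[-1]:
        return x, 0.0
    k = np.searchsorted(x_knots, x) - 1          # index of the bin containing x
    w = x_knots[k + 1] - x_knots[k]
    h = y_knots[k + 1] - y_knots[k]
    s = h / w                                     # average slope of the bin
    d0, d1 = slopes[k], slopes[k + 1]
    xi = (x - x_knots[k]) / w                     # relative position within the bin
    denom = s + (d1 + d0 - 2.0 * s) * xi * (1.0 - xi)
    y = y_knots[k] + h * (s * xi**2 + d0 * xi * (1.0 - xi)) / denom
    deriv = s**2 * (d1 * xi**2 + 2.0 * s * xi * (1.0 - xi) + d0 * (1.0 - xi) ** 2) / denom**2
    return y, np.log(deriv)


# Illustrative knots on [-1, 1] with three bins; outer slopes of 1 match the identity tails.
x_knots = np.array([-1.0, -0.2, 0.5, 1.0])
y_knots = np.array([-1.0, 0.1, 0.4, 1.0])
slopes = np.array([1.0, 0.5, 2.0, 1.0])
y_mid, ldj_mid = rqs_forward(0.2, x_knots, y_knots, slopes)
```

With positive knot slopes the interpolant is strictly increasing in every bin, which is what makes the spline invertible.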

## Types

```python { .api }
from typing import Callable, Union

from chex import Array

# 'tfb' refers to the TensorFlow Probability bijectors module (JAX substrate).
BijectorLike = Union[Bijector, 'tfb.Bijector', Callable[[Array], Array]]
```