# Transforms and Constraints

Bijective transformations and parameter constraints for reparametrization, constrained optimization, and normalizing flows in probabilistic models, enabling flexible and efficient inference over constrained parameter spaces.

## Capabilities

### Parameter Constraints

Constraints that define valid parameter domains and enable automatic constraint handling during optimization.

```python { .api }
class Constraint:
    """
    Base class for parameter constraints.

    Constraints define the valid domain for parameters and provide
    a `check` method for testing whether values lie in that domain.
    """

    def check(self, value: torch.Tensor) -> torch.Tensor:
        """
        Check if value satisfies the constraint.

        Parameters:
        - value (Tensor): Value to check

        Returns:
        Tensor: Boolean tensor indicating constraint satisfaction
        """

    @property
    def is_discrete(self) -> bool:
        """Whether this constraint is over discrete values."""

    @property
    def event_dim(self) -> int:
        """Number of rightmost dimensions that are part of the event."""

# Basic constraints
real: Constraint  # Unconstrained real numbers
boolean: Constraint  # Boolean values {0, 1}
nonnegative: Constraint  # Non-negative real numbers [0, ∞)
positive: Constraint  # Positive real numbers (0, ∞)
unit_interval: Constraint  # Unit interval [0, 1]
nonnegative_integer: Constraint  # Non-negative integers {0, 1, 2, ...}
positive_integer: Constraint  # Positive integers {1, 2, 3, ...}

# Interval constraints
def greater_than(lower_bound: float) -> Constraint:
    """
    Constraint for values greater than a lower bound.

    Parameters:
    - lower_bound (float): Lower bound (exclusive)

    Returns:
    Constraint: Greater than constraint

    Examples:
    >>> constraint = constraints.greater_than(0.0)  # Positive values
    >>> constraint = constraints.greater_than(-1.0)  # Values > -1
    """

def less_than(upper_bound: float) -> Constraint:
    """
    Constraint for values less than an upper bound.

    Parameters:
    - upper_bound (float): Upper bound (exclusive)

    Returns:
    Constraint: Less than constraint
    """

def interval(lower_bound: float, upper_bound: float) -> Constraint:
    """
    Constraint for values in a closed interval.

    For a half-open interval [lower_bound, upper_bound), use
    constraints.half_open_interval instead.

    Parameters:
    - lower_bound (float): Lower bound (inclusive)
    - upper_bound (float): Upper bound (inclusive)

    Returns:
    Constraint: Interval constraint

    Examples:
    >>> constraint = constraints.interval(-1.0, 1.0)  # Values in [-1, 1]
    """

# Matrix constraints
simplex: Constraint  # Probability simplex (non-negative, sum to 1)
positive_definite: Constraint  # Positive definite matrices
lower_cholesky: Constraint  # Lower triangular matrices with positive diagonal
corr_cholesky: Constraint  # Cholesky factors of correlation matrices

# Pyro-specific constraints
integer: Constraint  # Integer values
sphere: Constraint  # Unit sphere constraint
corr_matrix: Constraint  # Correlation matrices
ordered_vector: Constraint  # Ordered vectors (x[i] <= x[i+1])
positive_ordered_vector: Constraint  # Positive ordered vectors
softplus_positive: Constraint  # Softplus-transformed positive values
softplus_lower_cholesky: Constraint  # Softplus-transformed lower Cholesky
unit_lower_cholesky: Constraint  # Unit lower Cholesky constraint

# Composite constraints
def independent(constraint: Constraint, reinterpreted_batch_ndims: int) -> Constraint:
    """
    Reinterpret batch dimensions as event dimensions for a constraint.

    Parameters:
    - constraint (Constraint): Base constraint
    - reinterpreted_batch_ndims (int): Number of batch dims to treat as event dims

    Returns:
    Constraint: Independent constraint

    Examples:
    >>> # Vector of positive values
    >>> constraint = constraints.independent(constraints.positive, 1)
    """

def stack(constraints: List[Constraint], dim: int = 0) -> Constraint:
    """
    Stack multiple constraints along a dimension.

    Parameters:
    - constraints (List[Constraint]): Constraints to stack
    - dim (int): Dimension to stack along

    Returns:
    Constraint: Stacked constraint
    """
```
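
To make the semantics of `check` concrete, here is a minimal dependency-free sketch of three of the constraints listed above, written for plain Python numbers rather than tensors. The helper names (`is_positive`, `in_unit_interval`, `on_simplex`) and the tolerance `tol` are illustrative assumptions and are not part of the library API.

```python
import math
from typing import List


def is_positive(x: float) -> bool:
    """Mirrors `positive`: the open interval (0, inf)."""
    return x > 0.0


def in_unit_interval(x: float) -> bool:
    """Mirrors `unit_interval`: the closed interval [0, 1]."""
    return 0.0 <= x <= 1.0


def on_simplex(v: List[float], tol: float = 1e-6) -> bool:
    """Mirrors `simplex`: non-negative entries that sum to 1 (up to tol)."""
    non_negative = all(p >= 0.0 for p in v)
    sums_to_one = math.isclose(sum(v), 1.0, abs_tol=tol)
    return non_negative and sums_to_one


if __name__ == "__main__":
    print(is_positive(0.0))              # boundary value is excluded
    print(in_unit_interval(1.0))         # boundary value is included
    print(on_simplex([0.2, 0.3, 0.5]))   # valid probability vector
```

The library versions operate on tensors and return a boolean tensor, so the same test is applied element-wise (or per event for multivariate constraints such as the simplex).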

### Basic Transforms

Fundamental bijective transformations for reparametrization and normalizing flows.

```python { .api }
class Transform:
    """
    Base class for bijective transformations.

    Transforms provide bijective mappings between different parameter spaces,
    enabling reparametrization tricks and normalizing flows.
    """

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward transformation.

        Parameters:
        - x (Tensor): Input tensor

        Returns:
        Tensor: Transformed tensor
        """

    @property
    def inv(self) -> 'Transform':
        """
        Inverse transformation.

        Returns:
        Transform: The inverse transform, so that `transform.inv(y)`
        maps a transformed tensor y back to the original tensor x
        """

    def log_abs_det_jacobian(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Log absolute determinant of the Jacobian matrix, log |dy/dx|.

        Parameters:
        - x (Tensor): Input tensor
        - y (Tensor): Output tensor (usually result of __call__(x))

        Returns:
        Tensor: Log absolute Jacobian determinant
        """

    def with_cache(self, cache_size: int = 1) -> 'Transform':
        """Enable caching of forward/inverse computations."""

# Identity transform
identity_transform: Transform  # Identity transformation (no-op)

class ExpTransform(Transform):
    """
    Exponential transform: y = exp(x).

    Maps real numbers to positive numbers. Commonly used for
    ensuring positivity constraints.

    Examples:
    >>> transform = ExpTransform()
    >>> x = torch.tensor([-1.0, 0.0, 1.0])
    >>> y = transform(x)  # [exp(-1), 1, exp(1)]
    >>> x_recovered = transform.inv(y)
    """

class SigmoidTransform(Transform):
    """
    Sigmoid transform: y = sigmoid(x) = 1 / (1 + exp(-x)).

    Maps real numbers to the unit interval (0, 1). Useful for
    probability parameters.
    """

class TanhTransform(Transform):
    """
    Hyperbolic tangent transform: y = tanh(x).

    Maps real numbers to the interval (-1, 1).
    """

class SoftmaxTransform(Transform):
    """
    Softmax transform for probability simplices.

    Maps unconstrained vectors to probability simplices where
    components are non-negative and sum to 1.
    """

class StickBreakingTransform(Transform):
    """
    Stick-breaking transform for probability simplices.

    Alternative to softmax that constructs probability vectors
    using the stick-breaking construction.
    """

class AffineTransform(Transform):
    """
    Affine transformation: y = scale * x + loc.

    Linear transformation with location and scale parameters.
    """

    def __init__(self, loc: torch.Tensor, scale: torch.Tensor, event_dim: int = 0):
        """
        Parameters:
        - loc (Tensor): Location/shift parameter
        - scale (Tensor): Scale parameter
        - event_dim (int): Number of rightmost event dimensions

        Examples:
        >>> # Standardization transform: y = (x - mean) / std
        >>> transform = AffineTransform(loc=-mean / std, scale=1 / std)
        >>>
        >>> # Scale and shift
        >>> transform = AffineTransform(loc=5.0, scale=2.0)
        """

class PowerTransform(Transform):
    """
    Power transform: y = x^exponent.

    Defined on positive inputs. Generalizes square and cube transformations.
    """

    def __init__(self, exponent: float):
        """
        Parameters:
        - exponent (float): Power exponent
        """

class AbsTransform(Transform):
    """
    Absolute value transform: y = |x|.

    Maps real numbers to non-negative numbers. Not injective, since
    x and -x map to the same value.
    """
```
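
The contract between the forward map, its inverse, and `log_abs_det_jacobian` can be illustrated on scalars. The sketch below hand-codes the exponential and sigmoid maps in plain Python, using the closed-form derivatives d exp(x)/dx = exp(x) and d sigmoid(x)/dx = y(1 - y). All function names are hypothetical and exist only for this example; the library classes work on tensors.

```python
import math


def exp_forward(x: float) -> float:
    return math.exp(x)


def exp_inverse(y: float) -> float:
    return math.log(y)


def exp_log_abs_det_jacobian(x: float, y: float) -> float:
    # log |d exp(x) / dx| = log exp(x) = x
    return x


def sigmoid_forward(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def sigmoid_inverse(y: float) -> float:
    # logit(y) = log(y) - log(1 - y)
    return math.log(y) - math.log1p(-y)


def sigmoid_log_abs_det_jacobian(x: float, y: float) -> float:
    # d sigmoid(x) / dx = y * (1 - y)
    return math.log(y) + math.log1p(-y)


if __name__ == "__main__":
    x = 0.3
    y = sigmoid_forward(x)
    print(sigmoid_inverse(y))                  # recovers x up to rounding
    print(sigmoid_log_abs_det_jacobian(x, y))  # log of the local slope
```

Passing both `x` and `y` to the Jacobian function, as the `Transform` API does, lets each transform pick whichever argument gives the cheaper or more stable formula.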

### Constraint-Based Transforms

Transforms that map between unconstrained and constrained parameter spaces.

```python { .api }
class SoftplusTransform(Transform):
    """
    Softplus transform: y = log(1 + exp(x)).

    Smooth approximation to ReLU that maps real numbers to positive numbers.
    More numerically stable than exp() for large x.

    Examples:
    >>> transform = SoftplusTransform()
    >>> constraint = constraints.positive
    >>> # Use together for constrained parameters
    """

class CholeskyTransform(Transform):
    """
    Transform to Cholesky decomposition of positive definite matrices.

    Maps positive definite matrices to their lower triangular
    Cholesky factors, which have positive diagonal elements.
    """

class CorrCholeskyTransform(Transform):
    """
    Transform to Cholesky factor of correlation matrices.

    Maps unconstrained vectors to Cholesky factors of correlation
    matrices: lower triangular matrices with positive diagonal whose
    rows have unit norm, so the resulting correlation matrix has a
    unit diagonal.
    """

class LowerCholeskyTransform(Transform):
    """
    Transform to lower triangular matrices with positive diagonal.

    Ensures the result is a valid Cholesky factor.
    """

class OrderedTransform(Transform):
    """
    Transform to ordered vectors where x[i] <= x[i+1].

    Useful for ordered parameters like quantiles or cutpoints.

    Examples:
    >>> transform = OrderedTransform()
    >>> x = torch.randn(5)  # Unconstrained
    >>> y = transform(x)  # Ordered: y[0] <= y[1] <= ... <= y[4]
    """

class SimplexToOrderedTransform(Transform):
    """
    Transform from probability simplex to ordered vector.

    Maps a probability vector to ordered cutpoints by applying the
    logit to its cumulative sums (optionally shifted by an anchor point).
    """

def biject_to(constraint: Constraint) -> Transform:
    """
    Get bijective transform to a constrained space.

    Returns the appropriate transform that maps from unconstrained
    real numbers to the specified constraint space.

    Parameters:
    - constraint (Constraint): Target constraint

    Returns:
    Transform: Bijective transform to constraint space

    Examples:
    >>> # Transform to positive reals
    >>> transform = biject_to(constraints.positive)  # Returns ExpTransform
    >>>
    >>> # Transform to unit interval
    >>> transform = biject_to(constraints.unit_interval)  # Returns SigmoidTransform
    >>>
    >>> # Transform to probability simplex
    >>> transform = biject_to(constraints.simplex)  # Returns StickBreakingTransform
    """

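def transform_to(constraint: Constraint) -> Transform:
    """
    Get a transform to a constrained space that need not be bijective.

    Unlike biject_to(), the returned transform may be overparameterized
    (for example, SoftmaxTransform for constraints.simplex). It is
    intended for unconstrained optimization of constrained parameters
    rather than for density computations that need a Jacobian.

    Parameters:
    - constraint (Constraint): Target constraint

    Returns:
    Transform: Transform to constraint space
    """
```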
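
The two simplex parameterizations differ in whether they are one-to-one. The following plain-Python sketch contrasts a softmax with a simplified stick-breaking map. It is an illustration only: it omits the input offset that the library's `StickBreakingTransform` applies, and every function name here is a hypothetical helper.

```python
import math
from typing import List


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def softmax(x: List[float]) -> List[float]:
    """Maps K reals to the K-simplex; adding a constant leaves it unchanged."""
    m = max(x)
    e = [math.exp(v - m) for v in x]
    total = sum(e)
    return [v / total for v in e]


def stick_breaking(x: List[float]) -> List[float]:
    """Maps K-1 reals to the K-simplex by repeatedly breaking off a
    sigmoid(x[i]) fraction of the remaining stick."""
    remaining = 1.0
    y = []
    for v in x:
        z = sigmoid(v)
        y.append(z * remaining)
        remaining *= 1.0 - z
    y.append(remaining)  # whatever is left is the last component
    return y


def stick_breaking_inverse(y: List[float]) -> List[float]:
    """Recovers the K-1 unconstrained values from a simplex point."""
    remaining = 1.0
    x = []
    for p in y[:-1]:
        z = p / remaining            # fraction that was broken off
        x.append(math.log(z) - math.log1p(-z))
        remaining -= p
    return x


if __name__ == "__main__":
    print(softmax([1.0, 2.0, 3.0]))
    print(softmax([11.0, 12.0, 13.0]))   # same point: softmax is many-to-one
    y = stick_breaking([0.5, -1.0, 2.0])
    print(y, stick_breaking_inverse(y))  # round trip recovers the input
```

Because softmax is invariant to shifting its input, it has no unique inverse, which is why a bijective lookup uses a stick-breaking construction while an optimization-oriented lookup can afford the simpler softmax.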

### Normalizing Flows

Advanced transforms for flexible density modeling and variational inference.

```python { .api }
class ComposeTransform(Transform):
    """
    Compose multiple transforms sequentially.

    Chains transforms together: f3(f2(f1(x))) for transforms [f1, f2, f3].
    """

    def __init__(self, parts: List[Transform]):
        """
        Parameters:
        - parts (List[Transform]): List of transforms to compose

        Examples:
        >>> # Compose affine and exponential transforms
        >>> transform = ComposeTransform([
        ...     AffineTransform(loc=0.0, scale=2.0),
        ...     ExpTransform()
        ... ])
        """

class ConditionalTransform(Transform):
    """
    Base class for transforms that depend on context/conditioning variables.

    Enables context-dependent transformations for conditional normalizing flows.
    """

    def condition(self, context: torch.Tensor) -> Transform:
        """
        Condition the transform on context variables.

        Parameters:
        - context (Tensor): Context/conditioning variables

        Returns:
        Transform: Conditioned transform
        """

class AffineAutoregressive(Transform):
    """
    Affine autoregressive transform for normalizing flows.

    Implements the inverse autoregressive flow (IAF) transform: each
    output dimension is an affine function of the matching input
    dimension, with shift and scale computed by an autoregressive
    network from the preceding input dimensions.
    """

    def __init__(self, autoregressive_nn: torch.nn.Module, log_scale_min_clip: float = -5.0):
        """
        Parameters:
        - autoregressive_nn (Module): Neural network that outputs scale and shift
        - log_scale_min_clip (float): Minimum value for log scale to prevent numerical issues

        Examples:
        >>> from pyro.nn import AutoRegressiveNN
        >>> ar_nn = AutoRegressiveNN(10, [50, 50], param_dims=[1, 1])
        >>> transform = AffineAutoregressive(ar_nn)
        """

class AffineCoupling(Transform):
    """
    Affine coupling transform for normalizing flows.

    Implements Real NVP-style coupling layers where some dimensions are
    transformed as affine functions of the other, unchanged dimensions.
    """

    def __init__(self, split_dim: int, hypernet: torch.nn.Module, log_scale_min_clip: float = -5.0):
        """
        Parameters:
        - split_dim (int): Dimension to split for coupling
        - hypernet (Module): Network that computes transformation parameters
        - log_scale_min_clip (float): Minimum log scale value
        """

class Spline(Transform):
    """
    Monotonic rational spline transform.

    Implements element-wise neural spline flows with learnable rational
    splines of linear or quadratic order for flexible and invertible
    transformations.
    """

    def __init__(self, input_dim: int, count_bins: int = 8,
                 bound: float = 3.0, order: str = 'linear'):
        """
        Parameters:
        - input_dim (int): Input dimension
        - count_bins (int): Number of spline bins
        - bound (float): Domain bound for the spline
        - order (str): Spline order, 'linear' or 'quadratic'
        """

class SplineAutoregressive(Transform):
    """
    Autoregressive spline transform for normalizing flows.

    Combines spline transformations with autoregressive structure
    for flexible density modeling.
    """

    def __init__(self, input_dim: int, autoregressive_nn: torch.nn.Module,
                 count_bins: int = 8, bound: float = 3.0):
        """
        Parameters:
        - input_dim (int): Input dimension
        - autoregressive_nn (Module): Neural network for autoregressive parameters
        - count_bins (int): Number of spline bins
        - bound (float): Spline domain bound
        """

class Planar(Transform):
    """
    Planar normalizing flow transform.

    Implements planar flows for variational inference with flexible
    posterior approximations.
    """

    def __init__(self, input_dim: int):
        """
        Parameters:
        - input_dim (int): Input dimension

        Examples:
        >>> planar = Planar(10)
        >>> # Use in normalizing flow
        >>> flows = [Planar(10) for _ in range(5)]
        >>> flow = ComposeTransform(flows)
        """

class Radial(Transform):
    """
    Radial normalizing flow transform.

    Implements radial flows that apply transformations based on
    distance from a reference point.
    """

    def __init__(self, input_dim: int):
        """
        Parameters:
        - input_dim (int): Input dimension
        """

class Householder(Transform):
    """
    Householder normalizing flow transform.

    Uses Householder reflections for volume-preserving transformations
    in normalizing flows.
    """

    def __init__(self, input_dim: int, count_transforms: int = 1):
        """
        Parameters:
        - input_dim (int): Input dimension
        - count_transforms (int): Number of Householder transforms to compose
        """
```
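
To show what a single flow layer computes and how composed layers accumulate their Jacobian terms, here is a small plain-Python sketch of a planar layer, f(z) = z + u * tanh(w . z + b), whose log-determinant follows from the matrix determinant lemma. The parameter values are arbitrary, the function names are hypothetical, and the sketch does not include the reparameterization of `u` that a real implementation needs to guarantee invertibility.

```python
import math
from typing import List, Tuple

Vector = List[float]
Layer = Tuple[Vector, Vector, float]  # (u, w, b)


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def planar_forward(z: Vector, u: Vector, w: Vector, b: float) -> Vector:
    """f(z) = z + u * tanh(w . z + b)."""
    activation = math.tanh(dot(w, z) + b)
    return [zi + ui * activation for zi, ui in zip(z, u)]


def planar_log_abs_det_jacobian(z: Vector, u: Vector, w: Vector, b: float) -> float:
    """log |det(I + u psi^T)| = log |1 + psi . u|, psi = (1 - tanh^2) * w."""
    slope = 1.0 - math.tanh(dot(w, z) + b) ** 2
    return math.log(abs(1.0 + slope * dot(u, w)))


def flow_forward(z: Vector, layers: List[Layer]) -> Tuple[Vector, float]:
    """Applies layers in order and sums their log-determinants, each
    evaluated at the input of its own layer."""
    total = 0.0
    for u, w, b in layers:
        total += planar_log_abs_det_jacobian(z, u, w, b)
        z = planar_forward(z, u, w, b)
    return z, total


if __name__ == "__main__":
    layers = [([0.5, 0.2], [1.0, -0.4], 0.1),
              ([-0.3, 0.4], [0.2, 0.9], -0.2)]
    print(flow_forward([0.3, -0.7], layers))
```

The summation in `flow_forward` is the same bookkeeping a composed transform performs: the log-determinant of a chain is the sum of the per-layer terms.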

### Conditional Transforms

Transforms that depend on context variables for conditional density modeling.

```python { .api }
class ConditionalAffineAutoregressive(ConditionalTransform):
    """
    Conditional version of affine autoregressive transform.

    Autoregressive transform that conditions on additional context variables.
    """

    def __init__(self, context_nn: torch.nn.Module, log_scale_min_clip: float = -5.0):
        """
        Parameters:
        - context_nn (Module): Neural network that takes context and outputs parameters
        - log_scale_min_clip (float): Minimum log scale value
        """

class ConditionalAffineCoupling(ConditionalTransform):
    """
    Conditional version of affine coupling transform.

    Coupling transform that conditions on context variables.
    """

    def __init__(self, split_dim: int, context_nn: torch.nn.Module):
        """
        Parameters:
        - split_dim (int): Dimension to split for coupling
        - context_nn (Module): Context-dependent neural network
        """

class ConditionalSpline(ConditionalTransform):
    """
    Conditional spline transform with context dependence.

    Spline transform where spline parameters depend on context variables.
    """

    def __init__(self, input_dim: int, context_dim: int, count_bins: int = 8,
                 bound: float = 3.0, hidden_dims: List[int] = None):
        """
        Parameters:
        - input_dim (int): Input dimension
        - context_dim (int): Context dimension
        - count_bins (int): Number of spline bins
        - bound (float): Spline domain bound
        - hidden_dims (List[int]): Hidden dimensions for context network
        """

class ConditionalPlanar(ConditionalTransform):
    """
    Conditional planar flow with context dependence.

    Planar flow where transformation parameters are functions of context.
    """

    def __init__(self, input_dim: int, context_dim: int):
        """
        Parameters:
        - input_dim (int): Input dimension
        - context_dim (int): Context dimension
        """
```
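
The `condition(context)` pattern can be sketched without any neural network: a conditional transform is a factory that, given a context, returns an ordinary transform with its parameters fixed. In the toy example below the shift and log-scale are linear in a scalar context. The classes `ToyAffine` and `ToyConditionalAffine` and their weights are hypothetical stand-ins, not library classes.

```python
import math


class ToyAffine:
    """Scalar transform y = loc + exp(log_scale) * x."""

    def __init__(self, loc: float, log_scale: float):
        self.loc = loc
        self.log_scale = log_scale

    def __call__(self, x: float) -> float:
        return self.loc + math.exp(self.log_scale) * x

    def inverse(self, y: float) -> float:
        return (y - self.loc) * math.exp(-self.log_scale)

    def log_abs_det_jacobian(self, x: float, y: float) -> float:
        # dy/dx = exp(log_scale), independent of x
        return self.log_scale


class ToyConditionalAffine:
    """Produces a ToyAffine whose parameters are linear in the context."""

    def __init__(self, loc_weight: float, log_scale_weight: float):
        self.loc_weight = loc_weight
        self.log_scale_weight = log_scale_weight

    def condition(self, context: float) -> ToyAffine:
        return ToyAffine(loc=self.loc_weight * context,
                         log_scale=self.log_scale_weight * context)


if __name__ == "__main__":
    conditional = ToyConditionalAffine(loc_weight=2.0, log_scale_weight=0.5)
    transform = conditional.condition(1.0)   # fix the context
    y = transform(3.0)
    print(y, transform.inverse(y))           # inverse recovers the input
```

In the library classes a context network plays the role of the two weights, but the calling pattern is the same: condition first, then use the result like any other transform.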

### Utility Functions

Helper functions for working with transforms and constraints.

```python { .api }
def iterated(repeats: int, base_fn: callable, *args, **kwargs) -> Transform:
    """
    Create iterated composition of transforms.

    Calls base_fn `repeats` times and composes the resulting transforms
    in sequence, so each repetition has its own parameters.

    Parameters:
    - repeats (int): Number of repetitions
    - base_fn (callable): Function that creates base transform
    - *args, **kwargs: Arguments for base transform constructor

    Returns:
    Transform: Composed transform

    Examples:
    >>> # Compose 5 separately parameterized planar flows
    >>> flow = iterated(5, Planar, input_dim=10)
    """

def permute(permutation: torch.Tensor) -> Transform:
    """
    Create permutation transform.

    Parameters:
    - permutation (Tensor): Permutation indices

    Returns:
    Transform: Permutation transform
    """

def reshape(input_shape: torch.Size, output_shape: torch.Size) -> Transform:
    """
    Create reshape transform.

    Parameters:
    - input_shape (Size): Input tensor shape
    - output_shape (Size): Output tensor shape

    Returns:
    Transform: Reshape transform
    """
```
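
A permutation is the simplest of these utilities to reason about: it only reorders coordinates, so its Jacobian is a permutation matrix with determinant +1 or -1 and its log absolute determinant is zero. The plain-Python sketch below, with hypothetical helper names, shows the forward map and how the inverse permutation is built.

```python
from typing import List


def permute_forward(x: List[float], permutation: List[int]) -> List[float]:
    """Output position i takes the value from input position permutation[i]."""
    return [x[src] for src in permutation]


def inverse_permutation(permutation: List[int]) -> List[int]:
    """Index list that undoes permute_forward."""
    inverse = [0] * len(permutation)
    for position, src in enumerate(permutation):
        inverse[src] = position
    return inverse


if __name__ == "__main__":
    perm = [2, 0, 3, 1]
    x = [10.0, 20.0, 30.0, 40.0]
    y = permute_forward(x, perm)
    print(y)
    print(permute_forward(y, inverse_permutation(perm)))  # back to x
```

Interleaving permutations between autoregressive or coupling layers changes which coordinates each layer conditions on, at no cost to the density computation.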

## Examples

### Constrained Parameter Optimization

```python
import pyro
import pyro.distributions as dist
import torch
from torch.distributions import constraints

def model():
    # Positive parameter using constraint
    sigma = pyro.param("sigma", torch.tensor(1.0),
                       constraint=constraints.positive)

    # Probability parameter
    p = pyro.param("p", torch.tensor(0.5),
                   constraint=constraints.unit_interval)

    # Simplex parameter (probabilities that sum to 1)
    probs = pyro.param("probs", torch.ones(5) / 5,
                       constraint=constraints.simplex)

    return pyro.sample("x", dist.Normal(0, sigma))
```

### Manual Transform Usage

```python
import torch
from torch.distributions import biject_to, constraints

# Transform between unconstrained and constrained spaces
constraint = constraints.positive
transform = biject_to(constraint)

# Unconstrained parameter
unconstrained_param = torch.tensor(-1.0)

# Transform to positive space
positive_param = transform(unconstrained_param)  # exp(-1.0)

# Transform back
recovered = transform.inv(positive_param)  # -1.0

# Jacobian for change of variables
log_det_J = transform.log_abs_det_jacobian(unconstrained_param, positive_param)
```
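
The Jacobian term is what turns a density on the unconstrained space into a density on the constrained space: log p_Y(y) = log p_X(x) - log |dy/dx| with x the inverse image of y. As a hedged, dependency-free check of that formula, the sketch below pushes a standard normal through y = exp(x) and compares the result with the textbook log-normal density. The function names are illustrative only.

```python
import math


def standard_normal_log_prob(x: float) -> float:
    return -0.5 * x * x - 0.5 * math.log(2.0 * math.pi)


def log_prob_via_exp_transform(y: float) -> float:
    """Density of y = exp(x), x ~ Normal(0, 1), by change of variables."""
    x = math.log(y)   # inverse transform
    log_det = x       # log |d exp(x) / dx| = x
    return standard_normal_log_prob(x) - log_det


def log_normal_log_prob(y: float) -> float:
    """Closed-form log-density of a LogNormal(0, 1) variable."""
    log_y = math.log(y)
    return -log_y - 0.5 * math.log(2.0 * math.pi) - 0.5 * log_y * log_y


if __name__ == "__main__":
    for y in (0.1, 1.0, 2.5):
        print(y, log_prob_via_exp_transform(y), log_normal_log_prob(y))
```

The two columns agree, which is the scalar version of what a transformed distribution does when it combines the base log-density with the transform's log-determinant.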

### Normalizing Flow

```python
import torch
import pyro.distributions as dist
from pyro.distributions.transforms import AffineAutoregressive, ComposeTransform, Permute
from pyro.nn import AutoRegressiveNN

# Create autoregressive neural networks
# (param_dims=[1, 1]: one shift and one log-scale output per input dimension)
ar_nn1 = AutoRegressiveNN(10, [50, 50], param_dims=[1, 1])
ar_nn2 = AutoRegressiveNN(10, [50, 50], param_dims=[1, 1])

# Create flow transforms
flow_transforms = [
    AffineAutoregressive(ar_nn1),
    Permute(torch.randperm(10)),  # Permutation between layers
    AffineAutoregressive(ar_nn2)
]

# Compose into normalizing flow
flow_transform = ComposeTransform(flow_transforms)

# Use in transformed distribution
base_dist = dist.Normal(torch.zeros(10), torch.ones(10))
flow_dist = dist.TransformedDistribution(base_dist, flow_transform)

# Sample from flow
samples = flow_dist.sample((1000,))
log_probs = flow_dist.log_prob(samples)
```

### Conditional Normalizing Flow

```python
import torch
import pyro.distributions as dist
from pyro.distributions.transforms import ConditionalAffineAutoregressive
from pyro.nn import ConditionalAutoRegressiveNN

# Conditional flow for context-dependent transformations
context_dim = 5
input_dim = 10

conditional_transform = ConditionalAffineAutoregressive(
    ConditionalAutoRegressiveNN(input_dim, context_dim, [64, 64],
                                param_dims=[1, 1])
)

# Condition on context
context = torch.randn(32, context_dim)  # Batch of contexts
conditioned_transform = conditional_transform.condition(context)

# Use in model
base_dist = dist.Normal(torch.zeros(input_dim), torch.ones(input_dim))
conditional_dist = dist.TransformedDistribution(base_dist, conditioned_transform)
```