or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

core-programming.mddistributions.mdgaussian-processes.mdindex.mdinference.mdneural-networks.mdoptimization.mdtransforms-constraints.md

transforms-constraints.mddocs/

# Transforms and Constraints

Bijective transformations and parameter constraints for reparametrization, constrained optimization, and normalizing flows in probabilistic models, enabling flexible and efficient inference over constrained parameter spaces.

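## Capabilities

### Parameter Constraints

Constraints define the valid domain of a parameter. Attaching a constraint to a parameter (for example via `pyro.param(..., constraint=...)`) lets optimization and inference operate in an unconstrained space, with the mapping back to the valid domain handled automatically.

```python { .api }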

class Constraint:
    """
    Base class for parameter constraints.

    A constraint describes the valid domain of a parameter and can check
    whether values lie in it. It does not project values itself; use the
    transform returned by ``biject_to(constraint)`` or
    ``transform_to(constraint)`` to map unconstrained values into the domain.
    """

    # Plain class attributes (not methods): access them as
    # ``constraint.is_discrete`` and ``constraint.event_dim``.
    is_discrete: bool = False  # Whether this constraint is over discrete values.
    event_dim: int = 0  # Number of rightmost dimensions that form one event.

    def check(self, value: torch.Tensor) -> torch.Tensor:
        """
        Check if value satisfies the constraint.

        Parameters:
        - value (Tensor): Value to check

        Returns:
        Tensor: Boolean tensor indicating, for each event in ``value``,
        whether it satisfies the constraint (the rightmost ``event_dim``
        dimensions are reduced away)
        """

# Basic constraints
real: Constraint  # Unconstrained real numbers (-∞, ∞)
boolean: Constraint  # Boolean values {0, 1} (discrete)
nonnegative: Constraint  # Non-negative real numbers [0, ∞)
positive: Constraint  # Positive real numbers (0, ∞)
unit_interval: Constraint  # Closed unit interval [0, 1]
nonnegative_integer: Constraint  # Non-negative integers {0, 1, 2, ...} (discrete)
positive_integer: Constraint  # Positive integers {1, 2, 3, ...} (discrete)

# Interval constraints
def greater_than(lower_bound: float) -> Constraint:
    """
    Constraint for values strictly greater than a lower bound.

    Parameters:
    - lower_bound (float): Lower bound (exclusive)

    Returns:
    Constraint: Constraint over the open half-line (lower_bound, ∞).
    Use ``greater_than_eq`` when the bound itself should be allowed.

    Examples:
    >>> constraint = constraints.greater_than(0.0)  # Positive values
    >>> constraint = constraints.greater_than(-1.0)  # Values > -1
    """

def less_than(upper_bound: float) -> Constraint:
    """
    Constraint for values strictly less than an upper bound.

    Parameters:
    - upper_bound (float): Upper bound (exclusive)

    Returns:
    Constraint: Constraint over the open half-line (-∞, upper_bound)

    Examples:
    >>> constraint = constraints.less_than(1.0)  # Values < 1
    """

def interval(lower_bound: float, upper_bound: float) -> Constraint:
    """
    Constraint for values in the closed interval [lower_bound, upper_bound].

    Parameters:
    - lower_bound (float): Lower bound (inclusive)
    - upper_bound (float): Upper bound (inclusive)

    Returns:
    Constraint: Interval constraint. For the half-open interval
    [lower_bound, upper_bound) use ``constraints.half_open_interval``.

    Examples:
    >>> constraint = constraints.interval(-1.0, 1.0)  # Values in [-1, 1]
    """

# Vector and matrix constraints
simplex: Constraint  # Probability simplex: non-negative entries summing to 1 (last dim)
positive_definite: Constraint  # Symmetric positive definite matrices
lower_cholesky: Constraint  # Lower triangular matrices with positive diagonal
corr_cholesky: Constraint  # Cholesky factors of correlation matrices (unit-norm rows)

# Pyro-specific constraints
integer: Constraint  # Integer values (discrete)
sphere: Constraint  # Vectors on the unit Euclidean sphere
corr_matrix: Constraint  # Correlation matrices (positive definite, unit diagonal)
ordered_vector: Constraint  # Vectors monotonically increasing along the last dim
positive_ordered_vector: Constraint  # Increasing vectors with all entries positive
softplus_positive: Constraint  # Positive reals, parameterized via SoftplusTransform
softplus_lower_cholesky: Constraint  # Lower Cholesky factors, softplus on the diagonal
unit_lower_cholesky: Constraint  # Lower triangular matrices with unit diagonal

# Composite constraints
def independent(constraint: Constraint, reinterpreted_batch_ndims: int) -> Constraint:
    """
    Reinterpret batch dimensions as event dimensions for a constraint.

    The resulting constraint checks whole events: a value satisfies it only
    if every element across the reinterpreted dimensions satisfies the base
    constraint.

    Parameters:
    - constraint (Constraint): Base constraint
    - reinterpreted_batch_ndims (int): Number of batch dims to treat as event dims

    Returns:
    Constraint: Independent constraint

    Examples:
    >>> # Vector of positive values
    >>> constraint = constraints.independent(constraints.positive, 1)
    """

def stack(constraints: List[Constraint], dim: int = 0) -> Constraint:
    """
    Stack multiple constraints along a dimension.

    ``constraints[i]`` applies to the i-th slice of the value along ``dim``,
    mirroring how ``torch.stack`` assembles tensors.

    Parameters:
    - constraints (List[Constraint]): Constraints to stack, one per slice
    - dim (int): Dimension to stack along

    Returns:
    Constraint: Stacked constraint
    """

```

### Basic Transforms

Fundamental transformations for reparametrization and normalizing flows. Most are bijective; `SoftmaxTransform` and `AbsTransform` are notable exceptions.

```python { .api }

class Transform:
    """
    Base class for transformations between parameter spaces.

    Transforms provide mappings between different parameter spaces, enabling
    reparametrization tricks and normalizing flows. Most transforms are
    bijective, but not all (e.g. ``SoftmaxTransform``, ``AbsTransform``).
    """

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward transformation.

        Parameters:
        - x (Tensor): Input tensor

        Returns:
        Tensor: Transformed tensor
        """

    @property
    def inv(self) -> "Transform":
        """
        Inverse of this transform, itself a ``Transform``.

        Because the returned object is callable, ``transform.inv(y)`` applies
        the inverse mapping to ``y``; it should also satisfy
        ``transform.inv.inv is transform``.

        Returns:
        Transform: Inverse transform
        """

    def log_abs_det_jacobian(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Log absolute determinant of the Jacobian, log |dy/dx|.

        Parameters:
        - x (Tensor): Input tensor
        - y (Tensor): Output tensor (usually result of __call__(x))

        Returns:
        Tensor: Log absolute Jacobian determinant
        """

    def with_cache(self, cache_size: int = 1) -> "Transform":
        """
        Return this transform with the given caching behavior.

        Parameters:
        - cache_size (int): 0 disables caching; 1 caches the most recent
          (x, y) pair so that ``inv`` of a just-computed output is free

        Returns:
        Transform: Transform using the requested cache size
        """

# Identity transform
identity_transform: Transform  # Identity transformation (no-op): y = x

class ExpTransform(Transform):
    """
    Exponential transform: y = exp(x).

    Maps real numbers to positive numbers. Commonly used for ensuring
    positivity constraints. Its inverse is log(y), and log|dy/dx| = x.

    Examples:
    >>> transform = ExpTransform()
    >>> x = torch.tensor([-1.0, 0.0, 1.0])
    >>> y = transform(x)  # [exp(-1), 1, exp(1)]
    >>> x_recovered = transform.inv(y)  # log(y) == x
    """

class SigmoidTransform(Transform):
    """
    Sigmoid transform: y = sigmoid(x) = 1 / (1 + exp(-x)).

    Maps real numbers to the open unit interval (0, 1); its inverse is
    logit(y) = log(y / (1 - y)). Useful for probability parameters.
    """

class TanhTransform(Transform):
    """
    Hyperbolic tangent transform: y = tanh(x).

    Maps real numbers to the open interval (-1, 1); its inverse is atanh(y).
    """

class SoftmaxTransform(Transform):
    """
    Softmax transform for probability simplices.

    Maps unconstrained vectors (along the last dimension) to probability
    simplices whose components are non-negative and sum to 1.

    Not bijective: adding a constant to every input component gives the same
    output. It is therefore suitable for unconstrained optimization
    (``transform_to``), but not where an exact Jacobian is required (e.g.
    HMC). ``StickBreakingTransform`` is the bijective alternative.
    """

class StickBreakingTransform(Transform):
    """
    Stick-breaking transform for probability simplices.

    Bijective alternative to softmax: maps an unconstrained vector of length
    K - 1 to a probability vector of length K using the stick-breaking
    construction. This is the transform ``biject_to(constraints.simplex)``
    returns.
    """

class AffineTransform(Transform):
    """
    Affine transformation: y = loc + scale * x.

    Linear transformation with location and scale parameters; invertible
    whenever ``scale`` is non-zero.
    """

    def __init__(self, loc: torch.Tensor, scale: torch.Tensor, event_dim: int = 0):
        """
        Parameters:
        - loc (Tensor or float): Location/shift parameter
        - scale (Tensor or float): Scale parameter
        - event_dim (int): Number of rightmost event dimensions

        Examples:
        >>> # Standardization transform: y = (x - mean) / std
        >>> transform = AffineTransform(loc=-mean / std, scale=1 / std)
        >>>
        >>> # Scale and shift: y = 5 + 2 * x
        >>> transform = AffineTransform(loc=5.0, scale=2.0)
        """

class PowerTransform(Transform):
    """
    Power transform: y = x ** exponent.

    Defined on the positive reals: both its domain and its codomain are
    ``constraints.positive``. It is not a signed power, so negative inputs
    are outside its domain. For example, exponent=2 squares and exponent=0.5
    takes the square root.
    """

    def __init__(self, exponent: float):
        """
        Parameters:
        - exponent (float or Tensor): Power exponent
        """

class AbsTransform(Transform):
    """
    Absolute value transform: y = |x|.

    Maps real numbers to non-negative numbers. Not bijective: x and -x map to
    the same y, so its inverse can only return the non-negative preimage.
    """

```

### Constraint-Based Transforms

Transforms that map between unconstrained and constrained parameter spaces, together with the `biject_to()` and `transform_to()` registries that select such a transform for a given constraint.

```python { .api }

class SoftplusTransform(Transform):
    """
    Softplus transform: y = log(1 + exp(x)).

    Smooth approximation to ReLU that maps real numbers to positive numbers.
    It grows linearly rather than exponentially for large x, which makes it
    less prone to overflow and to very large gradients than exp().

    Examples:
    >>> transform = SoftplusTransform()
    >>> y = transform(torch.tensor([-1.0, 0.0, 1.0]))  # all entries > 0
    >>> x = transform.inv(y)  # log(exp(y) - 1) recovers the input
    """

class CholeskyTransform(Transform):
    """
    Cholesky decomposition of positive definite matrices: y = cholesky(x).

    Maps symmetric positive definite matrices (``constraints.positive_definite``)
    to their lower triangular Cholesky factors with positive diagonal
    (``constraints.lower_cholesky``); the inverse is y @ y^T. To map
    *unconstrained* matrices to Cholesky factors, use
    ``LowerCholeskyTransform`` instead.
    """

class CorrCholeskyTransform(Transform):
    """
    Transform to Cholesky factors of correlation matrices.

    Maps an unconstrained real vector of length D * (D - 1) / 2 to the
    Cholesky factor L of a D x D correlation matrix. L is lower triangular
    with positive diagonal, and each of its rows has unit Euclidean norm, so
    L @ L^T has a unit diagonal. L itself does not have a unit diagonal.
    """

class LowerCholeskyTransform(Transform):
    """
    Transform from unconstrained square matrices to lower triangular matrices
    with positive diagonal.

    The strictly lower triangular entries are kept and the diagonal is
    exponentiated, so the result is a valid Cholesky factor. This is useful
    for parameterizing a positive definite matrix as L @ L^T.
    """

class OrderedTransform(Transform):
    """
    Transform to increasing (ordered) vectors along the last dimension.

    Uses y[0] = x[0] and y[i] = y[i-1] + exp(x[i]) for i >= 1. Since
    exp(x[i]) > 0, the result is strictly increasing. Useful for ordered
    parameters like quantiles or cutpoints.

    Examples:
    >>> transform = OrderedTransform()
    >>> x = torch.randn(5)  # Unconstrained
    >>> y = transform(x)  # Ordered: y[0] < y[1] < ... < y[4]
    """

class SimplexToOrderedTransform(Transform):
    """
    Transform from a probability simplex to an ordered vector.

    Maps a K-dimensional probability vector x to a (K-1)-dimensional
    increasing vector via y[i] = logit(x[0] + ... + x[i]) + anchor_point,
    i.e. the logits of its cumulative sums, shifted by an anchor point.
    Useful for cutpoints in ordinal regression.
    """

    def __init__(self, anchor_point: torch.Tensor = None):
        """
        Parameters:
        - anchor_point (Tensor, optional): Shift added to every output
          element, used to improve identifiability. Defaults to 0.
        """

def biject_to(constraint: Constraint) -> Transform:
    """
    Get a bijective transform to a constrained space.

    Looks up, in the ``biject_to`` registry, a bijective transform that maps
    unconstrained real numbers to the specified constraint space. Because the
    result is a true bijection with a log-Jacobian, it is suitable for
    algorithms such as HMC/NUTS that need exact change-of-variables
    corrections.

    Parameters:
    - constraint (Constraint): Target constraint

    Returns:
    Transform: Bijective transform to constraint space

    Examples:
    >>> # Transform to positive reals
    >>> transform = biject_to(constraints.positive)  # exp-based map onto (0, ∞)
    >>>
    >>> # Transform to unit interval
    >>> transform = biject_to(constraints.unit_interval)  # Returns SigmoidTransform
    >>>
    >>> # Transform to probability simplex
    >>> transform = biject_to(constraints.simplex)  # Returns StickBreakingTransform
    """

def transform_to(constraint: Constraint) -> Transform:
    """
    Get a (not necessarily bijective) transform to a constrained space.

    This is a separate registry, not an alias of ``biject_to()``. The
    returned transform maps unconstrained real numbers into the constraint
    space, but it need not be bijective and need not implement
    ``log_abs_det_jacobian()``. It is meant for unconstrained optimization of
    constrained parameters. For example, ``transform_to(constraints.simplex)``
    returns a ``SoftmaxTransform``, whereas ``biject_to(constraints.simplex)``
    returns a ``StickBreakingTransform``.

    Parameters:
    - constraint (Constraint): Target constraint

    Returns:
    Transform: Transform to constraint space
    """

```

### Normalizing Flows

Advanced transforms for flexible density modeling and variational inference.

```python { .api }

class ComposeTransform(Transform):
    """
    Compose multiple transforms sequentially.

    Chains transforms together: f3(f2(f1(x))) for transforms [f1, f2, f3].
    The inverse applies the parts' inverses in reverse order, and the
    log-Jacobian is the sum of the parts' log-Jacobians.
    """

    def __init__(self, parts: List[Transform]):
        """
        Parameters:
        - parts (List[Transform]): List of transforms to compose, applied first to last

        Examples:
        >>> # Compose affine and exponential transforms: y = exp(2 * x)
        >>> transform = ComposeTransform([
        ...     AffineTransform(loc=0.0, scale=2.0),
        ...     ExpTransform()
        ... ])
        """

class ConditionalTransform(Transform):
    """
    Base class for transforms that depend on context/conditioning variables.

    Enables context-dependent transformations for conditional normalizing
    flows: call ``condition(context)`` to obtain an ordinary transform for a
    particular context value.

    NOTE(review): in Pyro, ``ConditionalTransform`` is an abstract base class
    rather than a ``Transform`` subclass. Confirm this before relying on
    ``isinstance(obj, Transform)`` for conditional transforms.
    """

    def condition(self, context: torch.Tensor) -> Transform:
        """
        Condition the transform on context variables.

        Parameters:
        - context (Tensor): Context/conditioning variables

        Returns:
        Transform: Conditioned transform
        """

class AffineAutoregressive(Transform):
    """
    Affine autoregressive transform for normalizing flows.

    Implements an Inverse Autoregressive Flow (IAF) layer: y = loc + scale * x,
    where the loc and scale for each dimension are produced by an
    autoregressive network from the preceding dimensions of x. Real
    NVP-style coupling layers are provided by ``AffineCoupling`` instead.

    The forward direction (and hence sampling) needs a single network pass.
    The inverse must be computed one dimension at a time, so scoring
    arbitrary values is slow.
    """

    def __init__(self, autoregressive_nn: torch.nn.Module, log_scale_min_clip: float = -5.0):
        """
        Parameters:
        - autoregressive_nn (Module): Autoregressive network that outputs the
          shift (loc) and log scale for each input dimension
        - log_scale_min_clip (float): Minimum value for log scale to prevent numerical issues

        Examples:
        >>> from pyro.nn import AutoRegressiveNN
        >>> ar_nn = AutoRegressiveNN(10, [50, 50])  # default param_dims=[1, 1]: (loc, log_scale)
        >>> transform = AffineAutoregressive(ar_nn)
        """

class AffineCoupling(Transform):
    """
    Affine coupling transform for normalizing flows (Real NVP).

    The first ``split_dim`` dimensions pass through unchanged. The remaining
    dimensions are transformed affinely (y = loc + scale * x), with loc and
    scale computed by ``hypernet`` from the unchanged dimensions. Both
    directions therefore need only a single network pass.
    """

    def __init__(self, split_dim: int, hypernet: torch.nn.Module, log_scale_min_clip: float = -5.0):
        """
        Parameters:
        - split_dim (int): Number of leading dimensions left unchanged and fed to ``hypernet``
        - hypernet (Module): Network that computes the loc and log scale for the remaining dimensions
        - log_scale_min_clip (float): Minimum log scale value
        """

class Spline(Transform):
    """
    Element-wise monotonic rational spline transform.

    Implements the rational splines of neural spline flows, of linear
    (default) or quadratic order. Inside [-bound, bound] each dimension is
    mapped by a learnable monotonic spline; outside that range the transform
    is the identity. The spline widths, heights and knot derivatives are
    learnable parameters created by the transform, not constructor arguments.
    """

    def __init__(self, input_dim: int, count_bins: int = 8, bound: float = 3.0,
                 order: str = "linear"):
        """
        Parameters:
        - input_dim (int): Dimension of the input vector
        - count_bins (int): Number of spline segments (bins)
        - bound (float): The spline acts on [-bound, bound]; identity outside
        - order (str): Rational spline order, 'linear' or 'quadratic'

        Examples:
        >>> transform = Spline(10, count_bins=4, bound=3.0)
        """

class SplineAutoregressive(Transform):
    """
    Autoregressive spline transform for normalizing flows.

    Applies an element-wise monotonic rational spline to each dimension. The
    spline parameters for each dimension are produced by an autoregressive
    network from the preceding dimensions. As with ``AffineAutoregressive``,
    the forward direction is a single network pass and the inverse is
    sequential.
    """

    def __init__(self, input_dim: int, autoregressive_nn: torch.nn.Module,
                 count_bins: int = 8, bound: float = 3.0, order: str = "linear"):
        """
        Parameters:
        - input_dim (int): Input dimension
        - autoregressive_nn (Module): Neural network for autoregressive parameters
        - count_bins (int): Number of spline bins
        - bound (float): Spline domain bound; identity outside [-bound, bound]
        - order (str): Rational spline order, 'linear' or 'quadratic'
        """

class Planar(Transform):
    """
    Planar normalizing flow transform.

    Implements planar flows for variational inference with flexible posterior
    approximations. The inverse has no analytical form and is not
    implemented. Values produced by the forward pass (e.g. samples) can still
    be scored, because the latest forward result is cached.
    """

    def __init__(self, input_dim: int):
        """
        Parameters:
        - input_dim (int): Input dimension

        Examples:
        >>> planar = Planar(10)
        >>> # Use in normalizing flow
        >>> flows = [Planar(10) for _ in range(5)]
        >>> flow = ComposeTransform(flows)
        """

class Radial(Transform):
    """
    Radial normalizing flow transform.

    Implements radial flows, which transform points based on their distance
    from a learnable reference point. As with ``Planar``, the inverse has no
    analytical form; only values produced by the forward pass can be scored,
    via the cache.
    """

    def __init__(self, input_dim: int):
        """
        Parameters:
        - input_dim (int): Input dimension
        """

class Householder(Transform):
    """
    Householder normalizing flow transform.

    Uses Householder reflections, which are orthogonal and therefore
    volume-preserving (log |det J| = 0), in normalizing flows.
    """

    def __init__(self, input_dim: int, count_transforms: int = 1):
        """
        Parameters:
        - input_dim (int): Input dimension
        - count_transforms (int): Number of Householder reflections to compose
        """

```

### Conditional Transforms

Transforms that depend on context variables for conditional density modeling. Calling `.condition(context)` returns an ordinary transform for that context.

```python { .api }

class ConditionalAffineAutoregressive(ConditionalTransform):
    """
    Conditional version of affine autoregressive transform.

    IAF-style autoregressive transform whose shift and scale also depend on
    additional context variables supplied through ``condition(context)``.
    """

    def __init__(self, context_nn: torch.nn.Module, log_scale_min_clip: float = -5.0):
        """
        Parameters:
        - context_nn (Module): Conditional autoregressive network taking the
          input and the context and returning the shift and log scale
          (e.g. ``pyro.nn.ConditionalAutoRegressiveNN``)
        - log_scale_min_clip (float): Minimum log scale value
        """

class ConditionalAffineCoupling(ConditionalTransform):
    """
    Conditional version of affine coupling transform.

    Coupling transform whose shift and scale for the transformed dimensions
    are computed from the unchanged dimensions and the context variables.
    """

    def __init__(self, split_dim: int, context_nn: torch.nn.Module):
        """
        Parameters:
        - split_dim (int): Number of leading dimensions left unchanged
        - context_nn (Module): Context-dependent network producing the shift
          and log scale for the remaining dimensions
        """

class ConditionalSpline(ConditionalTransform):
    """
    Conditional spline transform with context dependence.

    Element-wise spline transform whose spline parameters are produced from
    the context variables by a neural network.

    NOTE(review): this signature matches Pyro's ``conditional_spline`` helper
    function. The ``ConditionalSpline`` class itself may instead take a
    prebuilt network; verify against the installed Pyro version.
    """

    def __init__(self, input_dim: int, context_dim: int, count_bins: int = 8,
                 bound: float = 3.0, hidden_dims: List[int] = None):
        """
        Parameters:
        - input_dim (int): Input dimension
        - context_dim (int): Context dimension
        - count_bins (int): Number of spline bins
        - bound (float): Spline domain bound
        - hidden_dims (List[int], optional): Hidden dimensions for the context
          network; None leaves the choice to the library default
        """

class ConditionalPlanar(ConditionalTransform):
    """
    Conditional planar flow with context dependence.

    Planar flow whose transformation parameters are functions of the context.

    NOTE(review): this signature matches Pyro's ``conditional_planar`` helper
    function. The ``ConditionalPlanar`` class itself may instead take a
    prebuilt network; verify against the installed Pyro version.
    """

    def __init__(self, input_dim: int, context_dim: int):
        """
        Parameters:
        - input_dim (int): Input dimension
        - context_dim (int): Context dimension
        """

```

### Utility Functions

Helper functions for working with transforms and constraints.

```python { .api }

def iterated(repeats: int, base_fn: callable, *args, **kwargs) -> Transform:
    """
    Create iterated composition of transforms.

    Calls ``base_fn(*args, **kwargs)`` ``repeats`` times and composes the
    results in order. Each copy is therefore a separate instance with its
    own parameters, rather than one transform applied repeatedly.

    Parameters:
    - repeats (int): Number of repetitions
    - base_fn (callable): Function that creates base transform
    - *args, **kwargs: Arguments for base transform constructor

    Returns:
    Transform: Composed transform

    Examples:
    >>> # Create 5 planar flows, each with its own parameters
    >>> flow = iterated(5, Planar, input_dim=10)
    """

def permute(input_dim: int, permutation: torch.Tensor = None, dim: int = -1) -> Transform:
    """
    Create a permutation transform (a ``Permute`` instance).

    Permutations are typically placed between autoregressive or coupling
    layers so that every dimension gets transformed. To supply an explicit
    permutation directly, ``Permute(permutation)`` can also be constructed.

    Parameters:
    - input_dim (int): Size of the dimension being permuted
    - permutation (Tensor, optional): Permutation indices; if None, a random
      permutation of ``range(input_dim)`` is generated
    - dim (int): Dimension to permute

    Returns:
    Transform: Permutation transform
    """

def reshape(input_shape: torch.Size, output_shape: torch.Size) -> Transform:
    """
    Create reshape transform.

    Reshapes the rightmost ``input_shape`` dimensions into ``output_shape``.
    Both shapes must contain the same number of elements.

    NOTE(review): PyTorch exposes this as
    ``torch.distributions.transforms.ReshapeTransform(in_shape, out_shape)``;
    confirm that a ``reshape`` helper exists in the installed Pyro version.

    Parameters:
    - input_shape (Size): Input event shape
    - output_shape (Size): Output event shape

    Returns:
    Transform: Reshape transform
    """

```

## Examples

### Constrained Parameter Optimization

```python

import pyro
import pyro.distributions as dist
import torch
from torch.distributions import constraints


def model():
    # Positive parameter using constraint
    sigma = pyro.param("sigma", torch.tensor(1.0),
                       constraint=constraints.positive)

    # Probability parameter
    p = pyro.param("p", torch.tensor(0.5),
                   constraint=constraints.unit_interval)

    # Simplex parameter (probabilities that sum to 1)
    probs = pyro.param("probs", torch.ones(5) / 5,
                       constraint=constraints.simplex)

    # p and probs are registered in the param store for illustration; this
    # minimal model only uses sigma.
    return pyro.sample("x", dist.Normal(0, sigma))

```

### Manual Transform Usage

```python

import torch
from torch.distributions import biject_to, constraints

# Transform between unconstrained and constrained spaces
constraint = constraints.positive
transform = biject_to(constraint)  # exp-based bijection onto (0, ∞)

# Unconstrained parameter
unconstrained_param = torch.tensor(-1.0)

# Transform to positive space
positive_param = transform(unconstrained_param)  # exp(-1.0)

# Transform back
recovered = transform.inv(positive_param)  # -1.0

# Log |dy/dx| for the change of variables (equals x = -1.0 for y = exp(x))
log_det_J = transform.log_abs_det_jacobian(unconstrained_param, positive_param)

```

### Normalizing Flow

```python

import torch
import pyro.distributions as dist
from pyro.distributions.transforms import AffineAutoregressive, ComposeTransform, Permute
from pyro.nn import AutoRegressiveNN

# Create autoregressive neural networks. The default param_dims=[1, 1]
# outputs the (loc, log_scale) pair that AffineAutoregressive expects.
ar_nn1 = AutoRegressiveNN(10, [50, 50])
ar_nn2 = AutoRegressiveNN(10, [50, 50])

# Create flow transforms
flow_transforms = [
    AffineAutoregressive(ar_nn1),
    Permute(torch.randperm(10)),  # Permutation between layers
    AffineAutoregressive(ar_nn2),
]

# Compose into normalizing flow
flow_transform = ComposeTransform(flow_transforms)

# Use in transformed distribution
base_dist = dist.Normal(torch.zeros(10), torch.ones(10))
flow_dist = dist.TransformedDistribution(base_dist, flow_transform)

# Sample from flow
samples = flow_dist.sample((1000,))
log_probs = flow_dist.log_prob(samples)

```

### Conditional Normalizing Flow

```python

import torch
import pyro.distributions as dist
from pyro.distributions.transforms import ConditionalAffineAutoregressive
from pyro.nn import ConditionalAutoRegressiveNN

# Conditional flow for context-dependent transformations
context_dim = 5
input_dim = 10

# The default param_dims=[1, 1] outputs the (loc, log_scale) pair
conditional_transform = ConditionalAffineAutoregressive(
    ConditionalAutoRegressiveNN(input_dim, context_dim, [64, 64])
)

# Condition on context
context = torch.randn(32, context_dim)  # Batch of contexts
conditioned_transform = conditional_transform.condition(context)

# Use in model
base_dist = dist.Normal(torch.zeros(input_dim), torch.ones(input_dim))
conditional_dist = dist.TransformedDistribution(base_dist, conditioned_transform)

```