A flexible, scalable deep probabilistic programming library built on PyTorch for universal probabilistic modeling and inference
—
Bijective transformations and parameter constraints for reparametrization, constrained optimization, and normalizing flows in probabilistic models, enabling flexible and efficient inference over constrained parameter spaces.
Constraints that define valid parameter domains and enable automatic constraint handling during optimization.
class Constraint:
"""
Base class for parameter constraints.
Constraints define the valid domain for parameters and provide
methods for checking constraint satisfaction and projecting
values onto the constraint set.
"""
# NOTE: the methods below are interface stubs; concrete constraint objects
# (e.g. `positive`, `interval(...)`) supply the actual implementations.
def check(self, value: torch.Tensor) -> torch.Tensor:
"""
Check if value satisfies the constraint.
The result is computed elementwise over `value`.
Parameters:
- value (Tensor): Value to check
Returns:
Tensor: Boolean tensor indicating constraint satisfaction
"""
# NOTE(review): in torch.distributions.constraints, `is_discrete` and
# `event_dim` are attributes/properties rather than methods — verify
# against the installed torch version before calling them as methods.
def is_discrete(self) -> bool:
"""Whether this constraint is over discrete values."""
def event_dim(self) -> int:
"""Number of rightmost dimensions that are part of the event."""
# Basic constraints
real: Constraint # Unconstrained real numbers
boolean: Constraint # Boolean values {0, 1}
nonnegative: Constraint # Non-negative real numbers [0, ∞)
positive: Constraint # Positive real numbers (0, ∞)
unit_interval: Constraint # Unit interval [0, 1]
nonnegative_integer: Constraint # Non-negative integers {0, 1, 2, ...}
positive_integer: Constraint # Positive integers {1, 2, 3, ...}
# Interval constraints
def greater_than(lower_bound: float) -> Constraint:
"""
Constraint for values greater than a lower bound.
Parameters:
- lower_bound (float): Lower bound (exclusive)
Returns:
Constraint: Greater than constraint
Examples:
>>> constraint = constraints.greater_than(0.0) # Positive values
>>> constraint = constraints.greater_than(-1.0) # Values > -1
"""
def less_than(upper_bound: float) -> Constraint:
"""
Constraint for values less than an upper bound.
Parameters:
- upper_bound (float): Upper bound (exclusive)
Returns:
Constraint: Less than constraint
"""
def interval(lower_bound: float, upper_bound: float) -> Constraint:
"""
Constraint for values in an interval.
Parameters:
- lower_bound (float): Lower bound (inclusive)
- upper_bound (float): Upper bound (inclusive; torch's interval constraint
  checks `lower_bound <= value <= upper_bound`)
Returns:
Constraint: Interval constraint
Examples:
>>> constraint = constraints.interval(-1.0, 1.0) # Values in [-1, 1]
"""
# Matrix constraints
simplex: Constraint # Probability simplex (non-negative, sum to 1)
positive_definite: Constraint # Positive definite matrices
lower_cholesky: Constraint # Lower triangular matrices with positive diagonal
corr_cholesky: Constraint # Cholesky factors of correlation matrices
# Pyro-specific constraints
integer: Constraint # Integer values
sphere: Constraint # Unit sphere constraint
corr_matrix: Constraint # Correlation matrices
ordered_vector: Constraint # Ordered vectors (x[i] <= x[i+1])
positive_ordered_vector: Constraint # Positive ordered vectors
softplus_positive: Constraint # Softplus-transformed positive values
softplus_lower_cholesky: Constraint # Softplus-transformed lower Cholesky
unit_lower_cholesky: Constraint # Unit lower Cholesky constraint
# Composite constraints
def independent(constraint: Constraint, reinterpreted_batch_ndims: int) -> Constraint:
"""
Reinterpret batch dimensions as event dimensions for a constraint.
Parameters:
- constraint (Constraint): Base constraint
- reinterpreted_batch_ndims (int): Number of batch dims to treat as event dims
Returns:
Constraint: Independent constraint
Examples:
>>> # Vector of positive values
>>> constraint = constraints.independent(constraints.positive, 1)
"""
def stack(constraints: List[Constraint], dim: int = 0) -> Constraint:
"""
Stack multiple constraints along a dimension.
Parameters:
- constraints (List[Constraint]): Constraints to stack
- dim (int): Dimension to stack along
Returns:
Constraint: Stacked constraint
"""Fundamental bijective transformations for reparametrization and normalizing flows.
class Transform:
"""
Base class for bijective transformations.
Transforms provide bijective mappings between different parameter spaces,
enabling reparametrization tricks and normalizing flows.
"""
def __call__(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward transformation.
Parameters:
- x (Tensor): Input tensor
Returns:
Tensor: Transformed tensor
"""
def inv(self, y: torch.Tensor) -> torch.Tensor:
"""
Inverse transformation.
Parameters:
- y (Tensor): Transformed tensor
Returns:
Tensor: Original tensor
"""
def log_abs_det_jacobian(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""
Log absolute determinant of the Jacobian matrix.
Parameters:
- x (Tensor): Input tensor
- y (Tensor): Output tensor (usually result of __call__(x))
Returns:
Tensor: Log absolute Jacobian determinant
"""
def with_cache(self) -> 'Transform':
"""Enable caching of forward/inverse computations."""
# Identity transform
identity_transform: Transform # Identity transformation (no-op)
class ExpTransform(Transform):
"""
Exponential transform: y = exp(x).
Maps real numbers to positive numbers. Commonly used for
ensuring positivity constraints.
Examples:
>>> transform = ExpTransform()
>>> x = torch.tensor([-1.0, 0.0, 1.0])
>>> y = transform(x) # [exp(-1), 1, exp(1)]
>>> x_recovered = transform.inv(y)
"""
class SigmoidTransform(Transform):
"""
Sigmoid transform: y = sigmoid(x) = 1 / (1 + exp(-x)).
Maps real numbers to the unit interval (0, 1). Useful for
probability parameters.
"""
class TanhTransform(Transform):
"""
Hyperbolic tangent transform: y = tanh(x).
Maps real numbers to the interval (-1, 1).
"""
class SoftmaxTransform(Transform):
"""
Softmax transform for probability simplices.
Maps unconstrained vectors to probability simplices where
components are non-negative and sum to 1.
"""
class StickBreakingTransform(Transform):
"""
Stick-breaking transform for probability simplices.
Alternative to softmax that constructs probability vectors
using the stick-breaking construction.
"""
class AffineTransform(Transform):
"""
Affine transformation: y = scale * x + loc.
Linear transformation with location and scale parameters.
"""
def __init__(self, loc: torch.Tensor, scale: torch.Tensor, event_dim: int = 0):
"""
Parameters:
- loc (Tensor): Location/shift parameter
- scale (Tensor): Scale parameter
- event_dim (int): Number of rightmost event dimensions
Examples:
>>> # Standardization transform: y = (x - mean) / std
>>> # Since y = scale * x + loc, this needs loc = -mean/std, scale = 1/std
>>> transform = AffineTransform(loc=-mean/std, scale=1/std)
>>>
>>> # Scale and shift
>>> transform = AffineTransform(loc=5.0, scale=2.0)
"""
class PowerTransform(Transform):
"""
Power transform: y = sign(x) * |x|^exponent.
Generalizes square and cube transformations.
"""
def __init__(self, exponent: float):
"""
Parameters:
- exponent (float): Power exponent
"""
class AbsTransform(Transform):
"""
Absolute value transform: y = |x|.
Maps real numbers to non-negative numbers.
"""Transforms that map between unconstrained and constrained parameter spaces.
class SoftplusTransform(Transform):
"""
Softplus transform: y = log(1 + exp(x)).
Smooth approximation to ReLU that maps real numbers to positive numbers.
More numerically stable than exp() for large x.
Examples:
>>> transform = SoftplusTransform()
>>> constraint = constraints.positive
>>> # Use together for constrained parameters
"""
class CholeskyTransform(Transform):
"""
Transform to Cholesky decomposition of positive definite matrices.
Maps unconstrained matrices to lower triangular matrices with
positive diagonal elements.
"""
class CorrCholeskyTransform(Transform):
"""
Transform to Cholesky factor of correlation matrices.
Maps unconstrained vectors to Cholesky factors of correlation
matrices (unit diagonal).
"""
class LowerCholeskyTransform(Transform):
"""
Transform to lower triangular matrices with positive diagonal.
Ensures the result is a valid Cholesky factor.
"""
class OrderedTransform(Transform):
"""
Transform to ordered vectors where x[i] <= x[i+1].
Useful for ordered parameters like quantiles or cutpoints.
Examples:
>>> transform = OrderedTransform()
>>> x = torch.randn(5) # Unconstrained
>>> y = transform(x) # Ordered: y[0] <= y[1] <= ... <= y[4]
"""
class SimplexToOrderedTransform(Transform):
"""
Transform from probability simplex to ordered vector.
Maps probability vectors to their cumulative sums (quantiles).
"""
def biject_to(constraint: Constraint) -> Transform:
"""
Get bijective transform to a constrained space.
Returns the appropriate transform that maps from unconstrained
real numbers to the specified constraint space.
Parameters:
- constraint (Constraint): Target constraint
Returns:
Transform: Bijective transform to constraint space
Examples:
>>> # Transform to positive reals
>>> transform = biject_to(constraints.positive) # Returns ExpTransform
>>>
>>> # Transform to unit interval
>>> transform = biject_to(constraints.unit_interval) # Returns SigmoidTransform
>>>
>>> # Transform to probability simplex
>>> transform = biject_to(constraints.simplex) # Returns StickBreakingTransform
"""
def transform_to(constraint: Constraint) -> Transform:
"""
Alias for biject_to() for backward compatibility.
Parameters:
- constraint (Constraint): Target constraint
Returns:
Transform: Transform to constraint space
"""Advanced transforms for flexible density modeling and variational inference.
class ComposeTransform(Transform):
"""
Compose multiple transforms sequentially.
Chains transforms together: f3(f2(f1(x))) for transforms [f1, f2, f3].
"""
def __init__(self, parts: List[Transform]):
"""
Parameters:
- parts (List[Transform]): List of transforms to compose
Examples:
>>> # Compose affine and exponential transforms
>>> transform = ComposeTransform([
... AffineTransform(loc=0.0, scale=2.0),
... ExpTransform()
... ])
"""
class ConditionalTransform(Transform):
"""
Base class for transforms that depend on context/conditioning variables.
Enables context-dependent transformations for conditional normalizing flows.
"""
def condition(self, context: torch.Tensor) -> Transform:
"""
Condition the transform on context variables.
Parameters:
- context (Tensor): Context/conditioning variables
Returns:
Transform: Conditioned transform
"""
class AffineAutoregressive(Transform):
"""
Affine autoregressive transform for normalizing flows.
Implements Real NVP-style coupling layers with affine transformations
that preserve autoregressive structure.
"""
def __init__(self, autoregressive_nn: torch.nn.Module, log_scale_min_clip: float = -5.0):
"""
Parameters:
- autoregressive_nn (Module): Neural network that outputs scale and shift
- log_scale_min_clip (float): Minimum value for log scale to prevent numerical issues
Examples:
>>> from pyro.nn import AutoRegressiveNN
>>> ar_nn = AutoRegressiveNN(10, [50, 50], output_dim_multiplier=2)
>>> transform = AffineAutoregressive(ar_nn)
"""
class AffineCoupling(Transform):
"""
Affine coupling transform for normalizing flows.
Implements coupling layers where some dimensions are transformed
as functions of other dimensions.
"""
def __init__(self, split_dim: int, hypernet: torch.nn.Module, log_scale_min_clip: float = -5.0):
"""
Parameters:
- split_dim (int): Dimension to split for coupling
- hypernet (Module): Network that computes transformation parameters
- log_scale_min_clip (float): Minimum log scale value
"""
class Spline(Transform):
"""
Monotonic rational-quadratic spline transform.
Implements neural spline flows with rational-quadratic splines
for flexible and invertible transformations.
"""
def __init__(self, widths: torch.Tensor, heights: torch.Tensor,
derivatives: torch.Tensor, bound: float = 3.0):
"""
Parameters:
- widths (Tensor): Spline bin widths
- heights (Tensor): Spline bin heights
- derivatives (Tensor): Spline derivatives at knots
- bound (float): Domain bound for the spline
"""
class SplineAutoregressive(Transform):
"""
Autoregressive spline transform for normalizing flows.
Combines spline transformations with autoregressive structure
for flexible density modeling.
"""
def __init__(self, input_dim: int, autoregressive_nn: torch.nn.Module,
count_bins: int = 8, bound: float = 3.0):
"""
Parameters:
- input_dim (int): Input dimension
- autoregressive_nn (Module): Neural network for autoregressive parameters
- count_bins (int): Number of spline bins
- bound (float): Spline domain bound
"""
class Planar(Transform):
"""
Planar normalizing flow transform.
Implements planar flows for variational inference with flexible
posterior approximations.
"""
def __init__(self, input_dim: int):
"""
Parameters:
- input_dim (int): Input dimension
Examples:
>>> planar = Planar(10)
>>> # Use in normalizing flow
>>> flows = [Planar(10) for _ in range(5)]
>>> flow = ComposeTransform(flows)
"""
class Radial(Transform):
"""
Radial normalizing flow transform.
Implements radial flows that apply transformations based on
distance from a reference point.
"""
def __init__(self, input_dim: int):
"""
Parameters:
- input_dim (int): Input dimension
"""
class Householder(Transform):
"""
Householder normalizing flow transform.
Uses Householder reflections for volume-preserving transformations
in normalizing flows.
"""
def __init__(self, input_dim: int, count_transforms: int = 1):
"""
Parameters:
- input_dim (int): Input dimension
- count_transforms (int): Number of Householder transforms to compose
"""Transforms that depend on context variables for conditional density modeling.
class ConditionalAffineAutoregressive(ConditionalTransform):
"""
Conditional version of affine autoregressive transform.
Autoregressive transform that conditions on additional context variables.
"""
def __init__(self, context_nn: torch.nn.Module, log_scale_min_clip: float = -5.0):
"""
Parameters:
- context_nn (Module): Neural network that takes context and outputs parameters
- log_scale_min_clip (float): Minimum log scale value
"""
class ConditionalAffineCoupling(ConditionalTransform):
"""
Conditional version of affine coupling transform.
Coupling transform that conditions on context variables.
"""
def __init__(self, split_dim: int, context_nn: torch.nn.Module):
"""
Parameters:
- split_dim (int): Dimension to split for coupling
- context_nn (Module): Context-dependent neural network
"""
class ConditionalSpline(ConditionalTransform):
"""
Conditional spline transform with context dependence.
Spline transform where spline parameters depend on context variables.
"""
def __init__(self, input_dim: int, context_dim: int, count_bins: int = 8,
bound: float = 3.0, hidden_dims: List[int] = None):
"""
Parameters:
- input_dim (int): Input dimension
- context_dim (int): Context dimension
- count_bins (int): Number of spline bins
- bound (float): Spline domain bound
- hidden_dims (List[int]): Hidden dimensions for context network
"""
class ConditionalPlanar(ConditionalTransform):
"""
Conditional planar flow with context dependence.
Planar flow where transformation parameters are functions of context.
"""
def __init__(self, input_dim: int, context_dim: int):
"""
Parameters:
- input_dim (int): Input dimension
- context_dim (int): Context dimension
"""Helper functions for working with transforms and constraints.
def iterated(repeats: int, base_fn: callable, *args, **kwargs) -> Transform:
"""
Create iterated composition of transforms.
Applies the same transform multiple times in sequence.
Parameters:
- repeats (int): Number of repetitions
- base_fn (callable): Function that creates base transform
- *args, **kwargs: Arguments for base transform constructor
Returns:
Transform: Composed transform
Examples:
>>> # Create 5 repeated planar flows
>>> flow = iterated(5, Planar, input_dim=10)
"""
def permute(permutation: torch.Tensor) -> Transform:
"""
Create permutation transform.
Parameters:
- permutation (Tensor): Permutation indices
Returns:
Transform: Permutation transform
"""
def reshape(input_shape: torch.Size, output_shape: torch.Size) -> Transform:
"""
Create reshape transform.
Parameters:
- input_shape (Size): Input tensor shape
- output_shape (Size): Output tensor shape
Returns:
Transform: Reshape transform
"""import pyro
import pyro.distributions as dist
import torch
def model():
    """Toy model: register constrained parameters and sample an observation.

    Demonstrates `pyro.param` with positivity, unit-interval, and simplex
    constraints.
    NOTE(review): assumes `constraints` is in scope (e.g. via
    `from pyro.distributions import constraints`) — the imports above only
    bring in `pyro`, `pyro.distributions`, and `torch`.
    """
    # Positive parameter using constraint
    sigma = pyro.param("sigma", torch.tensor(1.0),
                       constraint=constraints.positive)
    # Probability parameter
    p = pyro.param("p", torch.tensor(0.5),
                   constraint=constraints.unit_interval)
    # Simplex parameter (probabilities that sum to 1)
    probs = pyro.param("probs", torch.ones(5) / 5,
                       constraint=constraints.simplex)
    return pyro.sample("x", dist.Normal(0, sigma))


# Transform between unconstrained and constrained spaces
constraint = constraints.positive
transform = biject_to(constraint)
# Unconstrained parameter
unconstrained_param = torch.tensor(-1.0)
# Transform to positive space
positive_param = transform(unconstrained_param) # exp(-1.0)
# Transform back
recovered = transform.inv(positive_param) # -1.0
# Jacobian for change of variables
log_det_J = transform.log_abs_det_jacobian(unconstrained_param, positive_param)
from pyro.distributions.transforms import AffineAutoregressive, ComposeTransform
from pyro.nn import AutoRegressiveNN
# Fix: `Permute` is used below but was never imported.
from pyro.distributions.transforms import Permute

# Create autoregressive neural networks: 10-dim input, two 50-unit hidden
# layers; output multiplier 2 yields one shift and one log-scale per dim.
# NOTE(review): recent Pyro versions use `param_dims=[1, 1]` instead of
# `output_dim_multiplier` — verify against the installed version.
ar_nn1 = AutoRegressiveNN(10, [50, 50], output_dim_multiplier=2)
ar_nn2 = AutoRegressiveNN(10, [50, 50], output_dim_multiplier=2)

# Create flow transforms; permuting between autoregressive layers lets each
# layer condition on a different variable ordering.
flow_transforms = [
    AffineAutoregressive(ar_nn1),
    Permute(torch.randperm(10)),  # Permutation between layers
    AffineAutoregressive(ar_nn2),
]

# Compose into normalizing flow
flow_transform = ComposeTransform(flow_transforms)

# Use in transformed distribution
base_dist = dist.Normal(torch.zeros(10), torch.ones(10))
flow_dist = dist.TransformedDistribution(base_dist, flow_transform)

# Sample from flow
samples = flow_dist.sample((1000,))
log_probs = flow_dist.log_prob(samples)
# Conditional flow for context-dependent transformations
# Fix: neither name below is imported anywhere in the runnable examples.
from pyro.distributions.transforms import ConditionalAffineAutoregressive
from pyro.nn import ConditionalAutoRegressiveNN

context_dim = 5
input_dim = 10
# NOTE(review): recent Pyro's ConditionalAutoRegressiveNN takes `param_dims`
# rather than `output_dim_multiplier` — verify against the installed version.
conditional_transform = ConditionalAffineAutoregressive(
    ConditionalAutoRegressiveNN(input_dim, context_dim, [64, 64],
                                output_dim_multiplier=2)
)

# Condition on context
context = torch.randn(32, context_dim)  # Batch of contexts
conditioned_transform = conditional_transform.condition(context)

# Use in model
base_dist = dist.Normal(torch.zeros(input_dim), torch.ones(input_dim))
conditional_dist = dist.TransformedDistribution(base_dist, conditioned_transform)
Install with Tessl CLI
npx tessl i tessl/pypi-pyro-ppl