CtrlK
BlogDocsLog inGet started
Tessl Logo

tessl/pypi-pyro-ppl

A flexible, scalable deep probabilistic programming library built on PyTorch for universal probabilistic modeling and inference

Pending
Overview
Eval results
Files

transforms-constraints.mddocs/

Transforms and Constraints

Bijective transformations and parameter constraints for reparametrization, constrained optimization, and normalizing flows in probabilistic models, enabling flexible and efficient inference over constrained parameter spaces.

Capabilities

Parameter Constraints

Constraints that define valid parameter domains and enable automatic constraint handling during optimization.

class Constraint:
    """
    Base class for parameter constraints.
    
    Constraints define the valid domain for parameters and provide
    methods for checking constraint satisfaction and projecting
    values onto the constraint set.

    NOTE(review): in torch.distributions.constraints, ``is_discrete`` and
    ``event_dim`` are attributes/properties rather than methods — confirm
    against the pyro/torch version this page documents.
    """
    
    def check(self, value: torch.Tensor) -> torch.Tensor:
        """
        Check if value satisfies the constraint.
        
        Parameters:
        - value (Tensor): Value to check
        
        Returns:
        Tensor: Boolean tensor indicating constraint satisfaction,
        broadcast over batch dimensions
        """
    
    def is_discrete(self) -> bool:
        """Whether this constraint is over discrete values."""
    
    def event_dim(self) -> int:
        """Number of rightmost dimensions that are part of the event."""

# Basic constraints (module-level singleton instances)
real: Constraint  # Unconstrained real numbers
boolean: Constraint  # Boolean values {0, 1}
nonnegative: Constraint  # Non-negative real numbers [0, ∞)
positive: Constraint  # Positive real numbers (0, ∞)
unit_interval: Constraint  # Unit interval [0, 1]
nonnegative_integer: Constraint  # Non-negative integers {0, 1, 2, ...}
positive_integer: Constraint  # Positive integers {1, 2, 3, ...}

# Interval constraints
def greater_than(lower_bound: float) -> Constraint:
    """
    Constraint for values strictly greater than a lower bound.
    
    NOTE: the bound is exclusive; for an inclusive bound torch provides
    ``constraints.greater_than_eq``.
    
    Parameters:
    - lower_bound (float): Lower bound (exclusive)
    
    Returns:
    Constraint: Greater than constraint
    
    Examples:
    >>> constraint = constraints.greater_than(0.0)  # Positive values
    >>> constraint = constraints.greater_than(-1.0)  # Values > -1
    """

def less_than(upper_bound: float) -> Constraint:
    """
    Constraint for values strictly less than an upper bound.
    
    Parameters:
    - upper_bound (float): Upper bound (exclusive)
    
    Returns:
    Constraint: Less than constraint
    
    Examples:
    >>> constraint = constraints.less_than(0.0)  # Negative values
    """

def interval(lower_bound: float, upper_bound: float) -> Constraint:
    """
    Constraint for values in a closed interval.
    
    NOTE(review): torch.distributions.constraints.interval checks
    lower_bound <= value <= upper_bound — BOTH bounds are inclusive.
    The previous "exclusive upper bound" wording was incorrect.
    
    Parameters:
    - lower_bound (float): Lower bound (inclusive)
    - upper_bound (float): Upper bound (inclusive)
    
    Returns:
    Constraint: Interval constraint
    
    Examples:
    >>> constraint = constraints.interval(-1.0, 1.0)  # Values in [-1, 1]
    """

# Matrix constraints
simplex: Constraint  # Probability simplex (non-negative, sum to 1)
positive_definite: Constraint  # Positive definite matrices
lower_cholesky: Constraint  # Lower triangular matrices with positive diagonal
corr_cholesky: Constraint  # Cholesky factors of correlation matrices

# Pyro-specific constraints (pyro.distributions.constraints adds these on top of torch's)
integer: Constraint  # Integer values
sphere: Constraint  # Unit sphere constraint (unit Euclidean norm)
corr_matrix: Constraint  # Correlation matrices
ordered_vector: Constraint  # Ordered vectors — NOTE(review): pyro appears to require strictly increasing entries (x[i] < x[i+1]); confirm
positive_ordered_vector: Constraint  # Positive, ordered vectors
softplus_positive: Constraint  # Positive values parametrized via softplus rather than exp
softplus_lower_cholesky: Constraint  # Lower Cholesky with softplus-parametrized diagonal
unit_lower_cholesky: Constraint  # Lower triangular matrices with all-ones diagonal

# Composite constraints
def independent(constraint: Constraint, reinterpreted_batch_ndims: int) -> Constraint:
    """
    Reinterpret batch dimensions as event dimensions for a constraint.
    
    The resulting constraint's event_dim grows by reinterpreted_batch_ndims,
    mirroring dist.Independent for distributions.
    
    Parameters:
    - constraint (Constraint): Base constraint
    - reinterpreted_batch_ndims (int): Number of batch dims to treat as event dims
    
    Returns:
    Constraint: Independent constraint
    
    Examples:
    >>> # Vector of positive values
    >>> constraint = constraints.independent(constraints.positive, 1)
    """

def stack(constraints: List[Constraint], dim: int = 0) -> Constraint:
    """
    Stack multiple constraints along a dimension.
    
    Each slice along `dim` is checked against the corresponding constraint.
    (The parameter name shadows the `constraints` module used in examples
    elsewhere on this page; kept for API compatibility.)
    
    Parameters:
    - constraints (List[Constraint]): Constraints to stack
    - dim (int): Dimension to stack along
    
    Returns:
    Constraint: Stacked constraint
    """

Basic Transforms

Fundamental bijective transformations for reparametrization and normalizing flows.

class Transform:
    """
    Base class for bijective transformations.
    
    Transforms provide bijective mappings between different parameter spaces,
    enabling reparametrization tricks and normalizing flows.

    NOTE(review): in torch.distributions, ``inv`` is a property returning the
    inverse Transform object (itself callable), so ``t.inv(y)`` works as shown
    below — confirm against the version being documented.
    """
    
    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward transformation y = f(x).
        
        Parameters:
        - x (Tensor): Input tensor
        
        Returns:
        Tensor: Transformed tensor
        """
    
    def inv(self, y: torch.Tensor) -> torch.Tensor:
        """
        Inverse transformation x = f^-1(y).
        
        Parameters:
        - y (Tensor): Transformed tensor
        
        Returns:
        Tensor: Original tensor
        """
    
    def log_abs_det_jacobian(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Log absolute determinant of the Jacobian, log |det df/dx|.
        
        Used in the change-of-variables formula when computing densities
        of transformed distributions.
        
        Parameters:
        - x (Tensor): Input tensor
        - y (Tensor): Output tensor (usually result of __call__(x))
        
        Returns:
        Tensor: Log absolute Jacobian determinant
        """
    
    def with_cache(self) -> 'Transform':
        """Enable caching of forward/inverse computations, avoiding a
        redundant inverse pass (e.g. in log_prob after rsample)."""

# Identity transform (singleton); its inverse is itself and its
# log_abs_det_jacobian is zero everywhere
identity_transform: Transform  # Identity transformation (no-op)

class ExpTransform(Transform):
    """
    Exponential transform: y = exp(x).
    
    Maps real numbers to positive numbers; the inverse is log(y).
    Commonly used for ensuring positivity constraints.
    
    Examples:
    >>> transform = ExpTransform()
    >>> x = torch.tensor([-1.0, 0.0, 1.0])
    >>> y = transform(x)  # [exp(-1), 1, exp(1)]
    >>> x_recovered = transform.inv(y)
    """

class SigmoidTransform(Transform):
    """
    Sigmoid transform: y = sigmoid(x) = 1 / (1 + exp(-x)).
    
    Maps real numbers to the unit interval (0, 1); the inverse is the
    logit function. Useful for probability parameters.
    """

class TanhTransform(Transform):
    """
    Hyperbolic tangent transform: y = tanh(x).
    
    Maps real numbers to the interval (-1, 1); the inverse is atanh and
    may be numerically delicate for |y| near 1.
    """

class SoftmaxTransform(Transform):
    """
    Softmax transform for probability simplices.
    
    Maps unconstrained vectors to probability simplices where
    components are non-negative and sum to 1.

    NOTE(review): softmax is shift-invariant and therefore not bijective
    from R^n; torch's SoftmaxTransform is accordingly unsuitable where a
    true bijection is required (use StickBreakingTransform instead) — confirm.
    """

class StickBreakingTransform(Transform):
    """
    Stick-breaking transform for probability simplices.
    
    Alternative to softmax that constructs probability vectors using the
    stick-breaking construction. Unlike softmax it is bijective
    (mapping an n-vector to an (n+1)-simplex — confirm dimensions).
    """

class AffineTransform(Transform):
    """
    Affine transformation: y = scale * x + loc.
    
    Linear transformation with location and scale parameters.
    """
    
    def __init__(self, loc: torch.Tensor, scale: torch.Tensor, event_dim: int = 0):
        """
        Parameters:
        - loc (Tensor): Location/shift parameter
        - scale (Tensor): Scale parameter
        - event_dim (int): Number of rightmost event dimensions
        
        Examples:
        >>> # Standardization transform: (x - mean) / std
        >>> # (note loc=-mean/std, since y = scale*x + loc is scale-first)
        >>> transform = AffineTransform(loc=-mean/std, scale=1/std)
        >>>
        >>> # Scale and shift
        >>> transform = AffineTransform(loc=5.0, scale=2.0)
        """

class PowerTransform(Transform):
    """
    Power transform: y = x ** exponent.
    
    NOTE(review): torch's PowerTransform operates on the positive reals
    (y = x.pow(exponent)); the earlier "sign(x) * |x|^exponent" description
    does not match that implementation — confirm against the version in use.
    """
    
    def __init__(self, exponent: float):
        """
        Parameters:
        - exponent (float): Power exponent
        """

class AbsTransform(Transform):
    """
    Absolute value transform: y = |x|.
    
    Maps real numbers to non-negative numbers. Not injective: both x and -x
    map to the same y, so the inverse returns the non-negative preimage
    (presumably; confirm against torch's AbsTransform).
    """

Constraint-Based Transforms

Transforms that map between unconstrained and constrained parameter spaces.

class SoftplusTransform(Transform):
    """
    Softplus transform: y = log(1 + exp(x)).
    
    Smooth approximation to ReLU that maps real numbers to positive numbers.
    More numerically stable than exp() for large x (softplus grows linearly
    rather than exponentially).
    
    Examples:
    >>> transform = SoftplusTransform()
    >>> constraint = constraints.positive
    >>> # Use together for constrained parameters
    """

class CholeskyTransform(Transform):
    """
    Transform to the Cholesky factor of positive definite matrices.
    
    NOTE(review): torch's CholeskyTransform maps POSITIVE DEFINITE matrices
    to their lower-triangular Cholesky factor (positive diagonal) — the
    input space is the PD cone, not unconstrained matrices; confirm.
    """

class CorrCholeskyTransform(Transform):
    """
    Transform to Cholesky factor of correlation matrices.
    
    Maps unconstrained vectors (presumably of length n*(n-1)/2 for an n x n
    matrix — confirm) to Cholesky factors of correlation matrices, whose
    rows have unit Euclidean norm.
    """

class LowerCholeskyTransform(Transform):
    """
    Transform to lower triangular matrices with positive diagonal.
    
    Ensures the result is a valid Cholesky factor (lower-triangular
    entries pass through; diagonal entries are made positive).
    """

class OrderedTransform(Transform):
    """
    Transform to ordered (monotone increasing) vectors.
    
    Useful for ordered parameters like quantiles or cutpoints.
    NOTE(review): pyro's ordered_vector constraint appears to be strict
    (x[i] < x[i+1]); the non-strict "<=" shown below may be imprecise — confirm.
    
    Examples:
    >>> transform = OrderedTransform()
    >>> x = torch.randn(5)  # Unconstrained
    >>> y = transform(x)    # Ordered: y[0] <= y[1] <= ... <= y[4]
    """

class SimplexToOrderedTransform(Transform):
    """
    Transform from probability simplex to ordered vector.
    
    Maps probability vectors to their cumulative sums, yielding an
    increasing sequence of cutpoints (an optional anchor point may shift
    the result — confirm against the pyro implementation).
    """

def biject_to(constraint: Constraint) -> Transform:
    """
    Get a bijective transform onto a constrained space.
    
    Returns the appropriate transform that maps from unconstrained
    real numbers to the specified constraint space. The mapping is
    registry-based, so new constraint types can register their own
    transforms (presumably via a register decorator — confirm).
    
    Parameters:
    - constraint (Constraint): Target constraint
    
    Returns:
    Transform: Bijective transform to constraint space
    
    Examples:
    >>> # Transform to positive reals
    >>> transform = biject_to(constraints.positive)  # Returns ExpTransform
    >>>
    >>> # Transform to unit interval
    >>> transform = biject_to(constraints.unit_interval)  # Returns SigmoidTransform
    >>>
    >>> # Transform to probability simplex
    >>> transform = biject_to(constraints.simplex)  # Returns StickBreakingTransform
    """

def transform_to(constraint: Constraint) -> Transform:
    """
    Get a transform onto a constrained space, for unconstrained optimization.
    
    NOTE(review): in torch.distributions this is a SEPARATE registry from
    biject_to(), not an alias — transform_to() may return a cheaper,
    not-necessarily-bijective transform (used e.g. for constrained
    parameter optimization), while biject_to() guarantees a bijection
    (used e.g. by HMC). Confirm against the version documented here.
    
    Parameters:
    - constraint (Constraint): Target constraint
    
    Returns:
    Transform: Transform to constraint space
    """

Normalizing Flows

Advanced transforms for flexible density modeling and variational inference.

class ComposeTransform(Transform):
    """
    Compose multiple transforms sequentially.
    
    Chains transforms together: f3(f2(f1(x))) for transforms [f1, f2, f3] —
    i.e. the first transform in the list is applied first.
    """
    
    def __init__(self, parts: List[Transform]):
        """
        Parameters:
        - parts (List[Transform]): List of transforms to compose, applied
          in list order
        
        Examples:
        >>> # Compose affine and exponential transforms
        >>> transform = ComposeTransform([
        ...     AffineTransform(loc=0.0, scale=2.0),
        ...     ExpTransform()
        ... ])
        """

class ConditionalTransform(Transform):
    """
    Base class for transforms that depend on context/conditioning variables.
    
    Enables context-dependent transformations for conditional normalizing
    flows: call condition(context) to obtain an ordinary Transform bound
    to a specific context.
    """
    
    def condition(self, context: torch.Tensor) -> Transform:
        """
        Condition the transform on context variables.
        
        Parameters:
        - context (Tensor): Context/conditioning variables
        
        Returns:
        Transform: Ordinary (non-conditional) transform bound to context
        """

class AffineAutoregressive(Transform):
    """
    Affine autoregressive transform for normalizing flows.
    
    Implements an inverse-autoregressive-flow (IAF)-style transform in
    which each output dimension is an affine function of the preceding
    input dimensions, as computed by a MADE-style autoregressive network.
    NOTE(review): the previous "Real NVP-style coupling layers" description
    was incorrect — coupling layers are AffineCoupling below.
    """
    
    def __init__(self, autoregressive_nn: torch.nn.Module, log_scale_min_clip: float = -5.0):
        """
        Parameters:
        - autoregressive_nn (Module): Neural network that outputs scale and shift
        - log_scale_min_clip (float): Minimum value for log scale to prevent numerical issues
        
        Examples:
        >>> from pyro.nn import AutoRegressiveNN
        >>> ar_nn = AutoRegressiveNN(10, [50, 50], output_dim_multiplier=2)
        >>> transform = AffineAutoregressive(ar_nn)
        """

class AffineCoupling(Transform):
    """
    Affine coupling transform for normalizing flows (Real NVP-style).
    
    Implements coupling layers: the first split_dim dimensions pass through
    unchanged while the remaining dimensions are transformed affinely as a
    function of them, giving a cheap analytic inverse.
    """
    
    def __init__(self, split_dim: int, hypernet: torch.nn.Module, log_scale_min_clip: float = -5.0):
        """
        Parameters:
        - split_dim (int): Dimension to split for coupling
        - hypernet (Module): Network that computes transformation parameters
        - log_scale_min_clip (float): Minimum log scale value
        """

class Spline(Transform):
    """
    Monotonic rational-quadratic spline transform (neural spline flows).
    
    Implements element-wise invertible transformations built from monotone
    rational-quadratic splines.
    
    NOTE(review): pyro's Spline constructor is typically
    Spline(input_dim, count_bins=8, bound=3.0, ...) with widths/heights/
    derivatives learned internally; the signature shown below may not match
    the actual API — confirm against the pyro version documented.
    """
    
    def __init__(self, widths: torch.Tensor, heights: torch.Tensor, 
                 derivatives: torch.Tensor, bound: float = 3.0):
        """
        Parameters:
        - widths (Tensor): Spline bin widths
        - heights (Tensor): Spline bin heights  
        - derivatives (Tensor): Spline derivatives at knots
        - bound (float): Domain bound for the spline
        """

class SplineAutoregressive(Transform):
    """
    Autoregressive spline transform for normalizing flows.
    
    Combines monotone rational-quadratic spline transformations with an
    autoregressive parameter network for flexible density modeling.
    """
    
    def __init__(self, input_dim: int, autoregressive_nn: torch.nn.Module, 
                 count_bins: int = 8, bound: float = 3.0):
        """
        Parameters:
        - input_dim (int): Input dimension
        - autoregressive_nn (Module): Neural network for autoregressive parameters
        - count_bins (int): Number of spline bins
        - bound (float): Spline domain bound
        """

class Planar(Transform):
    """
    Planar normalizing flow transform (Rezende & Mohamed style).
    
    Implements planar flows for variational inference with flexible
    posterior approximations. NOTE(review): planar flows generally lack an
    analytic inverse, so they are used sampling-forward only — confirm.
    """
    
    def __init__(self, input_dim: int):
        """
        Parameters:
        - input_dim (int): Input dimension
        
        Examples:
        >>> planar = Planar(10)
        >>> # Use in normalizing flow
        >>> flows = [Planar(10) for _ in range(5)]
        >>> flow = ComposeTransform(flows)
        """

class Radial(Transform):
    """
    Radial normalizing flow transform.
    
    Implements radial flows that apply transformations based on distance
    from a learned reference point.
    """
    
    def __init__(self, input_dim: int):
        """
        Parameters:
        - input_dim (int): Input dimension
        """

class Householder(Transform):
    """
    Householder normalizing flow transform.
    
    Uses Householder reflections (orthogonal, hence volume-preserving with
    zero log-det-Jacobian) in normalizing flows.
    """
    
    def __init__(self, input_dim: int, count_transforms: int = 1):
        """
        Parameters:
        - input_dim (int): Input dimension
        - count_transforms (int): Number of Householder reflections to compose
        """

Conditional Transforms

Transforms that depend on context variables for conditional density modeling.

class ConditionalAffineAutoregressive(ConditionalTransform):
    """
    Conditional version of the affine autoregressive transform.
    
    Autoregressive transform whose scale/shift parameters additionally
    depend on context variables supplied via condition().
    """
    
    def __init__(self, context_nn: torch.nn.Module, log_scale_min_clip: float = -5.0):
        """
        Parameters:
        - context_nn (Module): Neural network that takes context and outputs parameters
          (e.g. pyro.nn.ConditionalAutoRegressiveNN)
        - log_scale_min_clip (float): Minimum log scale value
        """

class ConditionalAffineCoupling(ConditionalTransform):
    """
    Conditional version of the affine coupling transform.
    
    Coupling transform whose parameters are computed from both the
    pass-through dimensions and the context variables.
    """
    
    def __init__(self, split_dim: int, context_nn: torch.nn.Module):
        """
        Parameters:
        - split_dim (int): Dimension to split for coupling
        - context_nn (Module): Context-dependent neural network
        """

class ConditionalSpline(ConditionalTransform):
    """
    Conditional spline transform with context dependence.
    
    Spline transform where the spline parameters (widths, heights,
    derivatives) are produced by a network over context variables.
    """
    
    def __init__(self, input_dim: int, context_dim: int, count_bins: int = 8,
                 bound: float = 3.0, hidden_dims: List[int] = None):
        """
        Parameters:
        - input_dim (int): Input dimension
        - context_dim (int): Context dimension
        - count_bins (int): Number of spline bins
        - bound (float): Spline domain bound
        - hidden_dims (List[int]): Hidden dimensions for context network
          (None selects a library default — confirm)
        """

class ConditionalPlanar(ConditionalTransform):
    """
    Conditional planar flow with context dependence.
    
    Planar flow whose transformation parameters are functions of context.
    """
    
    def __init__(self, input_dim: int, context_dim: int):
        """
        Parameters:
        - input_dim (int): Input dimension  
        - context_dim (int): Context dimension
        """

Utility Functions

Helper functions for working with transforms and constraints.

def iterated(repeats: int, base_fn: callable, *args, **kwargs) -> Transform:
    """
    Create an iterated composition of transforms.
    
    Constructs `repeats` fresh transforms via base_fn(*args, **kwargs)
    and composes them in sequence (each repetition has its own parameters).
    
    Parameters:
    - repeats (int): Number of repetitions
    - base_fn (callable): Function that creates base transform
    - *args, **kwargs: Arguments for base transform constructor
    
    Returns:
    Transform: Composed transform
    
    Examples:
    >>> # Create 5 repeated planar flows
    >>> flow = iterated(5, Planar, input_dim=10)
    """

def permute(permutation: torch.Tensor) -> Transform:
    """
    Create a permutation transform.
    
    Permutations are volume-preserving (zero log-det-Jacobian) and are
    commonly inserted between autoregressive/coupling layers so every
    dimension gets transformed.
    
    Parameters:
    - permutation (Tensor): Permutation indices
    
    Returns:
    Transform: Permutation transform
    """

def reshape(input_shape: torch.Size, output_shape: torch.Size) -> Transform:
    """
    Create a reshape transform.
    
    The two shapes must contain the same number of elements; reshaping is
    volume-preserving.
    
    Parameters:
    - input_shape (Size): Input tensor shape
    - output_shape (Size): Output tensor shape
    
    Returns:
    Transform: Reshape transform
    """

Examples

Constrained Parameter Optimization

import pyro
import pyro.distributions as dist
import torch

def model():
    """
    Toy model demonstrating constrained pyro.param declarations.

    Fix: the original snippet referenced `constraints` without importing it
    anywhere in the example; the local import makes the snippet runnable.
    """
    from pyro.distributions import constraints  # fix: was never imported

    # Positive parameter using constraint
    sigma = pyro.param("sigma", torch.tensor(1.0),
                       constraint=constraints.positive)

    # Probability parameter in [0, 1]
    p = pyro.param("p", torch.tensor(0.5),
                   constraint=constraints.unit_interval)

    # Simplex parameter (non-negative, sums to 1)
    probs = pyro.param("probs", torch.ones(5) / 5,
                       constraint=constraints.simplex)

    return pyro.sample("x", dist.Normal(0, sigma))

Manual Transform Usage

# Transform between unconstrained and constrained spaces
# (this snippet additionally requires:
#    import torch
#    from pyro.distributions import constraints
#    from torch.distributions import biject_to)
constraint = constraints.positive
transform = biject_to(constraint)

# Unconstrained parameter  
unconstrained_param = torch.tensor(-1.0)

# Transform to positive space
positive_param = transform(unconstrained_param)  # exp(-1.0)

# Transform back
recovered = transform.inv(positive_param)  # -1.0

# Jacobian for change of variables: log |d(positive_param)/d(unconstrained_param)|
log_det_J = transform.log_abs_det_jacobian(unconstrained_param, positive_param)

Normalizing Flow

# Fixes: Permute was used below but not imported; `torch` and `dist` were
# also referenced without imports in this self-contained snippet.
import torch
import pyro.distributions as dist
from pyro.distributions.transforms import (
    AffineAutoregressive,
    ComposeTransform,
    Permute,
)
from pyro.nn import AutoRegressiveNN

# Create autoregressive neural networks (output_dim_multiplier=2 -> scale and shift)
ar_nn1 = AutoRegressiveNN(10, [50, 50], output_dim_multiplier=2)
ar_nn2 = AutoRegressiveNN(10, [50, 50], output_dim_multiplier=2)

# Create flow transforms; the permutation between layers lets every
# dimension be transformed by some autoregressive layer
flow_transforms = [
    AffineAutoregressive(ar_nn1),
    Permute(torch.randperm(10)),  # Permutation between layers
    AffineAutoregressive(ar_nn2)
]

# Compose into normalizing flow (applied first-to-last)
flow_transform = ComposeTransform(flow_transforms)

# Use in transformed distribution
base_dist = dist.Normal(torch.zeros(10), torch.ones(10))
flow_dist = dist.TransformedDistribution(base_dist, flow_transform)

# Sample from flow and score the samples
samples = flow_dist.sample((1000,))
log_probs = flow_dist.log_prob(samples)

Conditional Normalizing Flow

# Conditional flow for context-dependent transformations
context_dim = 5
input_dim = 10

conditional_transform = ConditionalAffineAutoregressive(
    ConditionalAutoRegressiveNN(input_dim, context_dim, [64, 64], 
                               output_dim_multiplier=2)
)

# Condition on context
context = torch.randn(32, context_dim)  # Batch of contexts
conditioned_transform = conditional_transform.condition(context)

# Use in model
base_dist = dist.Normal(torch.zeros(input_dim), torch.ones(input_dim))
conditional_dist = dist.TransformedDistribution(base_dist, conditioned_transform)

Install with Tessl CLI

npx tessl i tessl/pypi-pyro-ppl

docs

core-programming.md

distributions.md

gaussian-processes.md

index.md

inference.md

neural-networks.md

optimization.md

transforms-constraints.md

tile.json