
tessl/pypi-optax

A gradient processing and optimization library in JAX


Constraint Projections

Projection functions for enforcing constraints in optimization. Each function maps its input to the closest point (in Euclidean distance) of a feasible set. Applying one to the parameters after every optimizer update keeps them inside the constraint set, which is the pattern behind projected gradient methods.
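Written out, each function computes a nearest-point map onto its constraint set C, and a constrained training loop alternates an ordinary optimizer update with that map. The symbols P_C, u_t and η are notation introduced here for illustration, not names from the library:

```latex
P_C(x) = \operatorname*{arg\,min}_{p \in C} \; \|x - p\|_2^2,
\qquad
x_{t+1} = P_C\!\left(x_t + u_t\right)
```

Here u_t is the update returned by the optimizer; for plain gradient descent u_t = -η ∇f(x_t), which gives projected gradient descent. For the two sphere sets, which are not convex, the minimizer need not be unique (for example at the origin).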

Capabilities

Box and Hypercube Projections

def projection_box(params, lower=None, upper=None):
    """
    Project parameters onto a box constraint [lower, upper].
    
    Args:
        params: Parameters to project
        lower: Lower bounds (default: None for no lower bound)
        upper: Upper bounds (default: None for no upper bound)
    
    Returns:
        Projected parameters clipped to [lower, upper]
    """

def projection_hypercube(params, lower=0.0, upper=1.0):
    """
    Project parameters onto a hypercube [lower, upper]^d.
    
    Args:
        params: Parameters to project
        lower: Lower bound for all dimensions (default: 0.0)
        upper: Upper bound for all dimensions (default: 1.0)
    
    Returns:
        Projected parameters clipped to hypercube
    """

Lp-Norm Ball Projections

def projection_l1_ball(params, radius=1.0):
    """
    Project parameters onto the L1 ball of given radius.
    
    Args:
        params: Parameters to project
        radius: Radius of the L1 ball (default: 1.0)
    
    Returns:
        Projected parameters with L1 norm ≤ radius
    """

def projection_l2_ball(params, radius=1.0):
    """
    Project parameters onto the L2 ball of given radius.
    
    Args:
        params: Parameters to project
        radius: Radius of the L2 ball (default: 1.0)
    
    Returns:
        Projected parameters with L2 norm ≤ radius
    """

def projection_linf_ball(params, radius=1.0):
    """
    Project parameters onto the L∞ ball of given radius.
    
    Args:
        params: Parameters to project
        radius: Radius of the L∞ ball (default: 1.0)
    
    Returns:
        Projected parameters with L∞ norm ≤ radius
    """

Sphere Projections

def projection_l1_sphere(params, radius=1.0):
    """
    Project parameters onto the L1 sphere of given radius.
    
    Args:
        params: Parameters to project
        radius: Radius of the L1 sphere (default: 1.0)
    
    Returns:
        Projected parameters with L1 norm = radius
    """

def projection_l2_sphere(params, radius=1.0):
    """
    Project parameters onto the L2 sphere of given radius.
    
    Args:
        params: Parameters to project
        radius: Radius of the L2 sphere (default: 1.0)
    
    Returns:
        Projected parameters with L2 norm = radius
    """

Simplex and Non-negativity Projections

def projection_simplex(params):
    """
    Project parameters onto the probability simplex.
    
    Args:
        params: Parameters to project
    
    Returns:
        Projected parameters with non-negative values that sum to 1
    """

def projection_non_negative(params):
    """
    Project parameters onto the non-negative orthant.
    
    Args:
        params: Parameters to project
    
    Returns:
        Projected parameters with all values ≥ 0
    """

Usage Examples

import jax
import jax.numpy as jnp
import optax

# Example parameters
params = jnp.array([-2.0, 1.5, 3.0, -0.5])

# Project onto the unit L2 ball (radius 1.0)
projected_l2 = optax.projections.projection_l2_ball(params, 1.0)

# Project onto the probability simplex
projected_simplex = optax.projections.projection_simplex(jnp.abs(params))

# Project onto box constraints [-1.0, 2.0]
projected_box = optax.projections.projection_box(params, lower=-1.0, upper=2.0)

# Using projections in constrained optimization
def constrained_optimization_step(params, grad, optimizer, opt_state):
    """One optimizer step followed by a projection back onto the L2 ball."""
    # Standard optimization step
    updates, opt_state = optimizer.update(grad, opt_state, params)
    new_params = optax.apply_updates(params, updates)

    # Project back to the feasible set (L2 ball of radius 1.0)
    constrained_params = optax.projections.projection_l2_ball(new_params, 1.0)

    return constrained_params, opt_state

# One projected SGD step on a toy quadratic loss
optimizer = optax.sgd(learning_rate=0.1)
opt_state = optimizer.init(params)
grad = jax.grad(lambda p: jnp.sum(p ** 2))(params)
params, opt_state = constrained_optimization_step(params, grad, optimizer, opt_state)
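The same update-then-project pattern can be sketched without optax, which makes it easy to check the result. Below, a hypothetical hand-written L2-ball projection is combined with plain gradient descent on 0.5‖x - target‖²; the constrained minimizer is the projection of `target` onto the unit ball, target / ‖target‖ = [0.6, 0.8]. The helper names and the toy problem are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def project_l2_ball(x, radius=1.0):
    # Rescale only if x lies outside the ball.
    return x * jnp.minimum(1.0, radius / jnp.linalg.norm(x))


def loss(x, target):
    return 0.5 * jnp.sum((x - target) ** 2)


target = jnp.array([3.0, 4.0])
x = jnp.zeros(2)
learning_rate = 0.1
grad_fn = jax.jit(jax.grad(loss))  # gradient with respect to x

for _ in range(100):
    x = x - learning_rate * grad_fn(x, target)  # ordinary gradient step
    x = project_l2_ball(x)                      # restore feasibility
```

After the loop, `x` sits on the boundary of the ball in the direction of `target`, which is where the unconstrained minimizer (`target` itself, norm 5) gets projected.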

Constraint Types

| Projection | Constraint set (r = radius) | Use case |
|---|---|---|
| projection_box | {x: lower ≤ x ≤ upper} | Parameter bounds |
| projection_hypercube | [lower, upper]^d | Uniform bounds |
| projection_l1_ball | {x: ‖x‖₁ ≤ r} | Sparse solutions |
| projection_l2_ball | {x: ‖x‖₂ ≤ r} | Bounded parameters |
| projection_linf_ball | {x: ‖x‖∞ ≤ r} | Element-wise bounds |
| projection_l1_sphere | {x: ‖x‖₁ = r} | Fixed L1 norm |
| projection_l2_sphere | {x: ‖x‖₂ = r} | Fixed L2 norm |
| projection_simplex | {x: x ≥ 0, Σx = 1} | Probabilities |
| projection_non_negative | {x: x ≥ 0} | Non-negative parameters |

Import

import optax.projections
# or
from optax.projections import (
    projection_l2_ball, projection_simplex, projection_box
)

Install with Tessl CLI

npx tessl i tessl/pypi-optax
