A gradient processing and optimization library in JAX
—
Projection functions for enforcing constraints in optimization. Each function maps parameters to the nearest point (in Euclidean distance) of a feasible set. Applying a projection to the updated parameters after every optimizer step keeps the iterates inside the constraint set, which turns an unconstrained optimizer into a projected (constrained) one.
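As a rough illustration of that loop, the sketch below runs projected gradient descent in plain JAX on a toy problem. It does not use optax: `loss` and `projected_gradient_descent` are hypothetical helpers, the feasible set is assumed to be the box [0, 1]², and the projection is written inline with `jnp.clip`.

```python
import jax
import jax.numpy as jnp


def loss(x):
    # Toy quadratic whose unconstrained minimizer (3, -2) lies outside the box.
    target = jnp.array([3.0, -2.0])
    return jnp.sum((x - target) ** 2)


def projected_gradient_descent(x0, lower, upper, step_size=0.1, num_steps=50):
    grad_fn = jax.grad(loss)
    x = x0
    for _ in range(num_steps):
        x = x - step_size * grad_fn(x)  # ordinary (unconstrained) gradient step
        x = jnp.clip(x, lower, upper)   # project back onto the box [lower, upper]
    return x


x_star = projected_gradient_descent(jnp.zeros(2), 0.0, 1.0)
```

Because the loss is a squared distance to (3, -2), its constrained minimizer is the point of the box closest to (3, -2), namely (1, 0), and the iterates settle there.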
```python
def projection_box(params, lower, upper):
    """
    Project parameters onto a box constraint [lower, upper].

    Args:
        params: Parameters to project
        lower: Lower bound (required; pass -jnp.inf for no lower bound)
        upper: Upper bound (required; pass jnp.inf for no upper bound)

    Returns:
        Projected parameters clipped to [lower, upper]
    """
```
```python
def projection_hypercube(params, scale=1.0):
    """
    Project parameters onto the hypercube [0, scale]^d.

    Args:
        params: Parameters to project
        scale: Upper bound for all dimensions; the lower bound is 0 (default: 1.0)

    Returns:
        Projected parameters clipped to the hypercube
    """
```

```python
def projection_l1_ball(params, scale=1.0):
    """
    Project parameters onto the L1 ball of given radius.

    Args:
        params: Parameters to project
        scale: Radius of the L1 ball (default: 1.0)

    Returns:
        Projected parameters with L1 norm ≤ scale
    """
```
```python
def projection_l2_ball(params, scale=1.0):
    """
    Project parameters onto the L2 ball of given radius.

    Args:
        params: Parameters to project
        scale: Radius of the L2 ball (default: 1.0)

    Returns:
        Projected parameters with L2 norm ≤ scale
    """
```
```python
def projection_linf_ball(params, scale=1.0):
    """
    Project parameters onto the L∞ ball of given radius.

    Args:
        params: Parameters to project
        scale: Radius of the L∞ ball (default: 1.0)

    Returns:
        Projected parameters with L∞ norm ≤ scale
    """
```

```python
def projection_l1_sphere(params, scale=1.0):
    """
    Project parameters onto the L1 sphere of given radius.

    Args:
        params: Parameters to project
        scale: Radius of the L1 sphere (default: 1.0)

    Returns:
        Projected parameters with L1 norm = scale
    """
```
```python
def projection_l2_sphere(params, scale=1.0):
    """
    Project parameters onto the L2 sphere of given radius.

    Args:
        params: Parameters to project
        scale: Radius of the L2 sphere (default: 1.0)

    Returns:
        Projected parameters with L2 norm = scale
    """
```

```python
def projection_simplex(params):
    """
    Project parameters onto the probability simplex.

    Args:
        params: Parameters to project

    Returns:
        Projected parameters with non-negative values that sum to 1
    """
```
```python
def projection_non_negative(params):
    """
    Project parameters onto the non-negative orthant.

    Args:
        params: Parameters to project

    Returns:
        Projected parameters with all values ≥ 0
    """
```

```python
import jax
import jax.numpy as jnp
import optax

# Example parameters
params = jnp.array([-2.0, 1.5, 3.0, -0.5])

# Project onto the unit L2 ball (second argument is the radius)
projected_l2 = optax.projections.projection_l2_ball(params, 1.0)

# Project onto the probability simplex
projected_simplex = optax.projections.projection_simplex(jnp.abs(params))

# Project onto box constraints [-1, 2]
projected_box = optax.projections.projection_box(params, lower=-1.0, upper=2.0)

# Using projections in constrained optimization
def constrained_optimization_step(params, grad, optimizer, opt_state):
    # Standard optimization step
    updates, opt_state = optimizer.update(grad, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    # Project back onto the feasible set
    constrained_params = optax.projections.projection_l2_ball(new_params, 1.0)
    return constrained_params, opt_state

# One projected SGD step on a toy quadratic loss
optimizer = optax.sgd(learning_rate=0.1)
opt_state = optimizer.init(params)
grad = jax.grad(lambda p: jnp.sum((p - 3.0) ** 2))(params)
params, opt_state = constrained_optimization_step(params, grad, optimizer, opt_state)
```

| Projection | Constraint Set | Use Case |
|---|---|---|
| projection_box | [lower, upper] | Parameter bounds |
| projection_hypercube | [0, scale]^d | Uniform bounds |
| projection_l1_ball | {x: ‖x‖₁ ≤ r} | Sparse solutions |
| projection_l2_ball | {x: ‖x‖₂ ≤ r} | Bounded parameters |
| projection_linf_ball | {x: ‖x‖∞ ≤ r} | Element-wise bounds |
| projection_l1_sphere | {x: ‖x‖₁ = r} | Fixed L1 norm |
| projection_l2_sphere | {x: ‖x‖₂ = r} | Fixed L2 norm (unit sphere when r = 1) |
| projection_simplex | {x: x ≥ 0, Σx = 1} | Probabilities |
| projection_non_negative | {x: x ≥ 0} | Non-negative parameters |
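The last row, the non-negative orthant, has the simplest projection: negative entries are set to zero. The sketch below (plain JAX, with the hypothetical helper `project_non_negative`, not the optax function) also checks two properties shared by Euclidean projections onto any closed convex set, such as the balls, box, hypercube, simplex, and orthant in the table (the spheres are not convex). They make a cheap sanity test for a hand-written projection.

```python
import jax.numpy as jnp


def project_non_negative(x):
    # Zero out negative entries; the orthant {x >= 0} is a box with no upper bound.
    return jnp.maximum(x, 0.0)


x = jnp.array([-2.0, 1.5, 3.0, -0.5])
p = project_non_negative(x)

# 1. Idempotence: projecting a point that is already feasible leaves it unchanged.
idempotent = bool(jnp.array_equal(project_non_negative(p), p))

# 2. Variational inequality: <x - p, y - p> <= 0 for every feasible y.
y = jnp.array([0.5, 0.0, 2.0, 4.0])  # an arbitrary point of the orthant
inner = float(jnp.dot(x - p, y - p))
```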
```python
import optax.projections

# or
from optax.projections import (
    projection_l2_ball, projection_simplex, projection_box
)
```