# Constraint Projections
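Projection functions for enforcing constraints in optimization. Each function maps parameters to the nearest point of a feasible set, measured in Euclidean distance. This enables constrained optimization: after every optimizer step, the updated parameters are projected back onto the constraint set (projected gradient descent). `params` may be a single array or a pytree of arrays. Norms and sums are computed jointly over all entries, so a 2-D array is projected as a whole rather than row by row; wrap the call in `jax.vmap` for per-row projections.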
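
The following equation gives the Euclidean projection onto a set $C$ and the projected update it enables. The symbol $u_t$ is notation introduced here for the update the optimizer produces at step $t$.

$$
P_C(x) = \operatorname*{arg\,min}_{y \in C} \lVert x - y \rVert_2^2,
\qquad
x_{t+1} = P_C\left(x_t + u_t\right)
$$

For plain gradient descent with step size $\eta$, $u_t = -\eta \nabla f(x_t)$. Every set on this page except the two spheres is convex, so its projection is unique. On a sphere, some points (for example, the origin) have several equally close points.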
## Capabilities
### Box and Hypercube Projections
```python { .api }
def projection_box(params, lower, upper):
    """
    Project parameters onto the box {x : lower ≤ x ≤ upper}.

    Args:
        params: Parameters to project (an array or a pytree of arrays)
        lower: Lower bound, a scalar or a pytree with the same structure as params
        upper: Upper bound, a scalar or a pytree with the same structure as params

    Returns:
        Projected parameters, clipped element-wise to [lower, upper]
    """


def projection_hypercube(params, scale=1.0):
    """
    Project parameters onto the hypercube [0, scale]^d.

    Args:
        params: Parameters to project
        scale: Upper bound shared by all coordinates; the lower bound is 0 (default: 1.0)

    Returns:
        Projected parameters, clipped element-wise to [0, scale]
    """
```
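
The box constraint is separable, so its Euclidean projection is element-wise clipping. A hypercube is the special case where every coordinate shares the same bounds. The NumPy sketch below shows this on a flat array. `box_projection` is a hypothetical helper written for illustration and is not part of optax.

```python
import numpy as np


def box_projection(x, lower, upper):
    """Euclidean projection onto {y : lower <= y <= upper}.

    The constraint acts on each coordinate separately, so the nearest
    feasible point is obtained by clipping every entry independently.
    """
    return np.clip(x, lower, upper)


x = np.array([-2.0, 0.5, 3.0])
print(box_projection(x, -1.0, 2.0).tolist())  # [-1.0, 0.5, 2.0]
# Unit hypercube [0, 1]^d: the same bounds on every coordinate
print(box_projection(x, 0.0, 1.0).tolist())   # [0.0, 0.5, 1.0]
```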
### Lp-Norm Ball Projections
```python { .api }
def projection_l1_ball(params, scale=1.0):
    """
    Project parameters onto the L1 ball of radius `scale`.

    Args:
        params: Parameters to project
        scale: Radius of the L1 ball (default: 1.0)

    Returns:
        Projected parameters with L1 norm ≤ scale
    """


def projection_l2_ball(params, scale=1.0):
    """
    Project parameters onto the L2 ball of radius `scale`.

    Args:
        params: Parameters to project
        scale: Radius of the L2 ball (default: 1.0)

    Returns:
        Projected parameters with L2 norm ≤ scale
    """


def projection_linf_ball(params, scale=1.0):
    """
    Project parameters onto the L∞ ball of radius `scale`.

    Args:
        params: Parameters to project
        scale: Radius of the L∞ ball (default: 1.0)

    Returns:
        Projected parameters with L∞ norm ≤ scale, i.e. every entry clipped to [-scale, scale]
    """
```
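
The three balls pull an outside point back in different ways:

- The L2 ball rescales the point radially.
- The L∞ ball clips each coordinate to [-radius, radius].
- The L1 ball soft-thresholds the point, which is why it yields sparse solutions.

This NumPy sketch illustrates the idea on a flat array. The helper names are hypothetical, and this is not presented as optax's implementation. The L1 case reuses a sort-based simplex projection (the method described by Duchi et al., 2008).

```python
import numpy as np


def l2_ball_projection(x, radius=1.0):
    """Project x onto {y : ||y||_2 <= radius}."""
    norm = np.linalg.norm(x)
    # Inside the ball: unchanged. Outside: rescaled radially onto the boundary.
    return x if norm <= radius else x * (radius / norm)


def linf_ball_projection(x, radius=1.0):
    """Project x onto {y : max_i |y_i| <= radius}, i.e. the box [-radius, radius]^d."""
    return np.clip(x, -radius, radius)


def simplex_projection(v, total):
    """Project v onto {w : w >= 0, sum(w) = total} (sort-based method)."""
    u = np.sort(v)[::-1]                      # entries in decreasing order
    css = np.cumsum(u) - total
    k = np.arange(1, v.size + 1)
    rho = np.nonzero(u - css / k > 0)[0][-1]  # index of last entry kept positive
    theta = css[rho] / (rho + 1)              # common shift subtracted from v
    return np.maximum(v - theta, 0.0)


def l1_ball_projection(x, radius=1.0):
    """Project x onto {y : ||y||_1 <= radius}."""
    if np.abs(x).sum() <= radius:
        return x
    # Outside the ball: project the magnitudes onto the simplex of size
    # `radius` and restore the signs. This soft-thresholds x, setting small
    # entries exactly to zero.
    return np.sign(x) * simplex_projection(np.abs(x), radius)


x = np.array([-2.0, 1.5, 3.0, -0.5])
print(l2_ball_projection(x))
print(linf_ball_projection(x).tolist())  # [-1.0, 1.0, 1.0, -0.5]
print(l1_ball_projection(x))             # only the entry 3.0 survives, as 1.0
```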
### Sphere Projections
```python { .api }
def projection_l1_sphere(params, scale=1.0):
    """
    Project parameters onto the L1 sphere of radius `scale`.

    Args:
        params: Parameters to project
        scale: Radius of the L1 sphere (default: 1.0)

    Returns:
        Projected parameters with L1 norm = scale
    """


def projection_l2_sphere(params, scale=1.0):
    """
    Project parameters onto the L2 sphere of radius `scale`.

    Args:
        params: Parameters to project
        scale: Radius of the L2 sphere (default: 1.0)

    Returns:
        Projected parameters with L2 norm = scale
    """
```
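
Unlike the balls, the spheres are not convex, and projecting onto them moves interior points outward as well. The NumPy sketch below shows one way to compute both projections. The L2 case rescales radially. The L1 case projects the magnitudes onto a scaled simplex and restores the signs. The helpers are hypothetical and are not optax's code. At the origin every point of the L2 sphere is equally close, so the sketch simply raises an error there.

```python
import numpy as np


def simplex_projection(v, total):
    """Project v onto {w : w >= 0, sum(w) = total} (sort-based method)."""
    u = np.sort(v)[::-1]
    css = np.cumsum(u) - total
    k = np.arange(1, v.size + 1)
    rho = np.nonzero(u - css / k > 0)[0][-1]
    theta = css[rho] / (rho + 1)
    return np.maximum(v - theta, 0.0)


def l2_sphere_projection(x, radius=1.0):
    """Project x onto {y : ||y||_2 = radius}."""
    norm = np.linalg.norm(x)
    if norm == 0:
        # Every point of the sphere is equally close to the origin.
        raise ValueError("projection of the zero vector is not unique")
    # Points inside the sphere are pushed outward, points outside pulled in.
    return x * (radius / norm)


def l1_sphere_projection(x, radius=1.0):
    """Project x onto {y : ||y||_1 = radius}."""
    # The nearest point keeps the sign pattern of x (zeros may take either
    # sign; +1 is used here), so only the magnitudes need projecting onto
    # the simplex of size `radius`.
    signs = np.where(x >= 0, 1.0, -1.0)
    return signs * simplex_projection(np.abs(x), radius)


x = np.array([0.3, 0.4])
print(l2_sphere_projection(x))  # approximately [0.6, 0.8]
print(l1_sphere_projection(x))  # approximately [0.45, 0.55]
```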
### Simplex and Non-negativity Projections
```python { .api }
def projection_simplex(params, scale=1.0):
    """
    Project parameters onto the simplex {x : x ≥ 0, Σx = scale}.

    Args:
        params: Parameters to project (entries need not be non-negative)
        scale: Value the projected entries sum to (default: 1.0, the probability simplex)

    Returns:
        Projected parameters with non-negative entries that sum to scale
    """


def projection_non_negative(params):
    """
    Project parameters onto the non-negative orthant.

    Args:
        params: Parameters to project

    Returns:
        Projected parameters with negative entries set to 0
    """
```
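
The simplex projection has the closed form y = max(x − θ, 0) for a single scalar shift θ, which can be found exactly by sorting. The non-negative projection simply zeroes negative entries. Below is a NumPy sketch of both. The helper names are hypothetical, and the sketch's `scale` argument is the value the entries must sum to, with 1.0 giving the probability simplex.

```python
import numpy as np


def simplex_projection(x, scale=1.0):
    """Project x onto {y : y >= 0, sum(y) = scale}.

    The solution has the form y = max(x - theta, 0) for one scalar theta;
    sorting x lets us find theta exactly in O(d log d) time.
    """
    u = np.sort(x)[::-1]                      # entries in decreasing order
    css = np.cumsum(u) - scale
    k = np.arange(1, x.size + 1)
    # Index of the last sorted entry that stays positive after the shift
    rho = np.nonzero(u - css / k > 0)[0][-1]
    theta = css[rho] / (rho + 1)
    return np.maximum(x - theta, 0.0)


def non_negative_projection(x):
    """Project x onto the non-negative orthant by zeroing negative entries."""
    return np.maximum(x, 0.0)


x = np.array([-2.0, 1.5, 3.0, -0.5])
print(simplex_projection(x).tolist())       # [0.0, 0.0, 1.0, 0.0]
print(non_negative_projection(x).tolist())  # [0.0, 1.5, 3.0, 0.0]
```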
## Usage Examples
```python
import jax.numpy as jnp
import optax

# Example parameters
params = jnp.array([-2.0, 1.5, 3.0, -0.5])

# Project onto the unit L2 ball (second argument: radius of the ball)
projected_l2 = optax.projections.projection_l2_ball(params, 1.0)

# Project onto the probability simplex; entries need not be non-negative
projected_simplex = optax.projections.projection_simplex(params)

# Project onto box constraints
projected_box = optax.projections.projection_box(params, lower=-1.0, upper=2.0)


# Projected optimization step: unconstrained update, then projection
def constrained_optimization_step(params, grad, optimizer, opt_state):
    # Standard optax update
    updates, opt_state = optimizer.update(grad, opt_state, params)
    new_params = optax.apply_updates(params, updates)

    # Project back onto the feasible set (here: the unit L2 ball)
    constrained_params = optax.projections.projection_l2_ball(new_params, 1.0)

    return constrained_params, opt_state


# One step of SGD on f(p) = sum((p - 3)^2), constrained to the unit L2 ball
optimizer = optax.sgd(learning_rate=0.1)
opt_state = optimizer.init(params)
grad = 2.0 * (params - 3.0)
params, opt_state = constrained_optimization_step(params, grad, optimizer, opt_state)
```
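
The project-after-each-step pattern of `constrained_optimization_step` can also be sketched without optax, which makes the loop explicit. In this hypothetical NumPy example, the goal is to minimize ‖x − target‖² over the unit L2 ball. The result should converge to target / ‖target‖, because the constrained minimizer of a distance to `target` is the projection of `target` itself. The helper names and step size are illustrative choices.

```python
import numpy as np


def l2_ball_projection(x, radius=1.0):
    """Rescale x onto the L2 ball of the given radius if it lies outside."""
    norm = np.linalg.norm(x)
    return x if norm <= radius else x * (radius / norm)


def projected_gradient_descent(grad_fn, x0, project, lr=0.1, steps=100):
    """Alternate a plain gradient step with a projection onto the feasible set."""
    x = project(x0)
    for _ in range(steps):
        x = project(x - lr * grad_fn(x))
    return x


target = np.array([3.0, 4.0])


def grad_fn(x):
    # Gradient of f(x) = ||x - target||^2
    return 2.0 * (x - target)


x_star = projected_gradient_descent(grad_fn, np.zeros(2), l2_ball_projection)
print(x_star)  # approximately [0.6, 0.8] = target / ||target||
```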
## Constraint Types
| Projection | Constraint set | Typical use |
|------------|----------------|-------------|
| `projection_box` | {x: lower ≤ x ≤ upper} | Per-coordinate bounds |
| `projection_hypercube` | [a, b]^d | Same bounds on every coordinate |
| `projection_l1_ball` | {x: ‖x‖₁ ≤ r} | Sparse solutions |
| `projection_l2_ball` | {x: ‖x‖₂ ≤ r} | Bounded parameter norm |
| `projection_linf_ball` | {x: ‖x‖∞ ≤ r} | Symmetric element-wise bounds [-r, r] |
| `projection_l1_sphere` | {x: ‖x‖₁ = r} | Fixed L1 norm |
| `projection_l2_sphere` | {x: ‖x‖₂ = r} | Fixed L2 norm (e.g. unit-norm vectors) |
| `projection_simplex` | {x: x ≥ 0, Σx = 1} | Probability vectors |
| `projection_non_negative` | {x: x ≥ 0} | Non-negative parameters |
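
To sanity-check a projection's output against the table, you can test membership in each set up to a floating-point tolerance. The NumPy helpers below are hypothetical utilities written for illustration and are not part of optax. The tolerance `TOL` is an arbitrary choice.

```python
import numpy as np

TOL = 1e-6  # slack for floating-point round-off (arbitrary choice)


def in_box(x, lower, upper):
    return bool(np.all(x >= lower - TOL) and np.all(x <= upper + TOL))


def in_lp_ball(x, radius, p):
    # p may be 1, 2 or np.inf for the balls in the table
    return bool(np.linalg.norm(x, ord=p) <= radius + TOL)


def on_lp_sphere(x, radius, p):
    return bool(abs(np.linalg.norm(x, ord=p) - radius) <= TOL)


def in_simplex(x, total=1.0):
    return bool(np.all(x >= -TOL) and abs(x.sum() - total) <= TOL)


def is_non_negative(x):
    return bool(np.all(x >= -TOL))


x = np.array([0.0, 0.0, 1.0, 0.0])
print(in_lp_ball(x, 1.0, 1), on_lp_sphere(x, 1.0, 2), in_simplex(x))  # True True True
```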
## Import
```python
import optax.projections
# or
from optax.projections import (
    projection_l2_ball, projection_simplex, projection_box
)
```