# Constraint Projections

Projection functions for enforcing constraints in optimization. Each function maps parameters to the nearest point of a feasible set, so applying it to the updated parameters after every optimization step keeps the iterates inside the constraint set.

## Capabilities

### Box and Hypercube Projections

```python { .api }
def projection_box(params, lower, upper):
    """
    Project parameters onto the box constraint [lower, upper].

    Args:
        params: Parameters to project
        lower: Lower bounds (a scalar, or the same structure as params)
        upper: Upper bounds (a scalar, or the same structure as params)

    Returns:
        Projected parameters clipped to [lower, upper]
    """

def projection_hypercube(params, scale=1.0):
    """
    Project parameters onto the hypercube [0, scale]^d.

    Args:
        params: Parameters to project
        scale: Upper bound for all dimensions (default: 1.0)

    Returns:
        Projected parameters clipped to the hypercube
    """
```
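
For intuition, the Euclidean projection onto a box is an element-wise clip. The following is a minimal NumPy sketch of that idea, not optax's implementation; the helper name `box_project` is hypothetical.

```python
import numpy as np

def box_project(x, lower, upper):
    """Nearest point to x (in Euclidean distance) inside [lower, upper]."""
    # The box constraint is separable across coordinates, so clipping each
    # coordinate independently solves the projection exactly.
    return np.clip(x, lower, upper)

x = np.array([-2.0, 1.5, 3.0, -0.5])
boxed = box_project(x, -1.0, 2.0)   # [-1.0, 1.5, 2.0, -0.5]
cubed = box_project(x, 0.0, 1.0)    # unit hypercube: [0.0, 1.0, 1.0, 0.0]
```

Entries already inside the bounds are returned unchanged, which is why a projection can be applied after every step without disturbing feasible iterates.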

### Lp-Norm Ball Projections

```python { .api }
def projection_l1_ball(params, scale=1.0):
    """
    Project parameters onto the L1 ball of radius scale.

    Args:
        params: Parameters to project
        scale: Radius of the L1 ball (default: 1.0)

    Returns:
        Projected parameters with L1 norm ≤ scale
    """

def projection_l2_ball(params, scale=1.0):
    """
    Project parameters onto the L2 ball of radius scale.

    Args:
        params: Parameters to project
        scale: Radius of the L2 ball (default: 1.0)

    Returns:
        Projected parameters with L2 norm ≤ scale
    """

def projection_linf_ball(params, scale=1.0):
    """
    Project parameters onto the L∞ ball of radius scale.

    Args:
        params: Parameters to project
        scale: Radius of the L∞ ball (default: 1.0)

    Returns:
        Projected parameters with L∞ norm ≤ scale
    """
```
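
Two of these projections have simple closed forms, sketched below in plain NumPy to show what the result looks like. This is an illustration under the usual Euclidean-projection definition, not optax's code; `l2_ball_project` and `linf_ball_project` are hypothetical helper names.

```python
import numpy as np

def l2_ball_project(x, radius=1.0):
    """Rescale x onto the L2 ball if it lies outside; otherwise return it."""
    norm = np.linalg.norm(x)
    # Points already inside the ball are their own projection.
    return x if norm <= radius else x * (radius / norm)

def linf_ball_project(x, radius=1.0):
    """Projection onto the L-infinity ball is a clip to [-radius, radius]."""
    return np.clip(x, -radius, radius)

x = np.array([3.0, 4.0])             # L2 norm is 5
p2 = l2_ball_project(x, 1.0)         # [0.6, 0.8]: same direction, norm 1
pinf = linf_ball_project(x, 1.0)     # [1.0, 1.0]: each entry clipped
inside = l2_ball_project(np.array([0.3, 0.4]), 1.0)  # unchanged
```

The L1 ball has no such one-line formula; it is commonly computed by a sort-based thresholding step similar to the simplex projection.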

### Sphere Projections

```python { .api }
def projection_l1_sphere(params, scale=1.0):
    """
    Project parameters onto the L1 sphere of radius scale.

    Args:
        params: Parameters to project
        scale: Radius of the L1 sphere (default: 1.0)

    Returns:
        Projected parameters with L1 norm = scale
    """

def projection_l2_sphere(params, scale=1.0):
    """
    Project parameters onto the L2 sphere of radius scale.

    Args:
        params: Parameters to project
        scale: Radius of the L2 sphere (default: 1.0)

    Returns:
        Projected parameters with L2 norm = scale
    """
```
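
The difference from a ball projection is that a sphere projection also moves interior points outward to the surface. A minimal NumPy sketch for the L2 case follows; `l2_sphere_project` is a hypothetical helper, and raising on the zero vector is a choice made for this sketch, not documented library behavior.

```python
import numpy as np

def l2_sphere_project(x, radius=1.0):
    """Rescale x so that its L2 norm equals radius."""
    norm = np.linalg.norm(x)
    if norm == 0.0:
        # Every point on the sphere is equally close to the origin,
        # so the projection is not unique there.
        raise ValueError("projection of the zero vector is not unique")
    return x * (radius / norm)

x = np.array([0.3, 0.4])              # L2 norm is 0.5, inside the unit ball
on_sphere = l2_sphere_project(x)      # [0.6, 0.8]: pushed out to norm 1
```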

### Simplex and Non-negativity Projections

```python { .api }
def projection_simplex(params, scale=1.0):
    """
    Project parameters onto the probability simplex.

    Args:
        params: Parameters to project
        scale: Value the projected entries sum to (default: 1.0)

    Returns:
        Projected parameters with non-negative values that sum to scale
    """

def projection_non_negative(params):
    """
    Project parameters onto the non-negative orthant.

    Args:
        params: Parameters to project

    Returns:
        Projected parameters with all values ≥ 0
    """
```
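
The simplex projection is not a plain normalization: it subtracts a common shift from every entry and zeroes whatever becomes negative. The sketch below uses the well-known sort-based method in NumPy for a 1-D array. It is an illustration only and makes no claim about how optax computes the result; `simplex_project` and its `total` argument are hypothetical names.

```python
import numpy as np

def simplex_project(x, total=1.0):
    """Euclidean projection of a 1-D array onto {p : p >= 0, sum(p) = total}."""
    u = np.sort(x)[::-1]                      # entries in decreasing order
    shifted = np.cumsum(u) - total
    k = np.arange(1, x.size + 1)
    # Index of the last sorted entry that stays positive after the shift.
    rho = np.nonzero(u - shifted / k > 0)[0][-1]
    tau = shifted[rho] / (rho + 1)            # common shift for every entry
    return np.maximum(x - tau, 0.0)

x = np.array([-2.0, 1.5, 3.0, -0.5])
p = simplex_project(x)                        # [0.0, 0.0, 1.0, 0.0]
uniform = simplex_project(np.array([0.5, 0.5, 0.5]))  # each entry 1/3

# The non-negative orthant projection is simply an element-wise maximum.
nonneg = np.maximum(x, 0.0)                   # [0.0, 1.5, 3.0, 0.0]
```

Negative inputs are handled directly, so there is no need to take absolute values before projecting.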

## Usage Examples

```python
import optax
import jax.numpy as jnp

# Example parameters
params = jnp.array([-2.0, 1.5, 3.0, -0.5])

# Project onto the unit L2 ball (radius passed positionally)
projected_l2 = optax.projections.projection_l2_ball(params, 1.0)

# Project onto the probability simplex (negative entries are allowed)
projected_simplex = optax.projections.projection_simplex(params)

# Project onto box constraints
projected_box = optax.projections.projection_box(params, lower=-1.0, upper=2.0)

# Using in constrained optimization
def constrained_optimization_step(params, grad, optimizer, opt_state):
    # Standard optimization step
    updates, opt_state = optimizer.update(grad, opt_state, params)
    new_params = optax.apply_updates(params, updates)

    # Project back to the feasible set
    constrained_params = optax.projections.projection_l2_ball(new_params, 1.0)

    return constrained_params, opt_state
```
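
To see the update-then-project pattern run end to end, here is a small self-contained sketch in plain NumPy. A hand-written gradient step stands in for the optimizer, and the objective, step size, and helper name `l2_ball_project` are all assumptions chosen for illustration.

```python
import numpy as np

def l2_ball_project(x, radius=1.0):
    """Rescale x onto the L2 ball if it lies outside; otherwise return it."""
    norm = np.linalg.norm(x)
    return x if norm <= radius else x * (radius / norm)

# Minimize ||x - target||^2 subject to ||x||_2 <= 1.
target = np.array([3.0, 4.0])    # unconstrained minimizer, outside the ball
x = np.array([1.0, 0.0])         # feasible starting point

for _ in range(100):
    grad = 2.0 * (x - target)              # gradient of the objective
    x = l2_ball_project(x - 0.1 * grad)    # gradient step, then projection

# x approaches target / ||target|| = [0.6, 0.8], the closest feasible point.
```

Every iterate stays inside the unit ball, and the loop settles on the boundary point closest to the unconstrained minimizer.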

## Constraint Types

| Projection | Constraint Set | Use Case |
|------------|----------------|----------|
| `projection_box` | {x: lower ≤ x ≤ upper} | Parameter bounds |
| `projection_hypercube` | [0, 1]^d by default | Uniform bounds |
| `projection_l1_ball` | {x: ‖x‖₁ ≤ r} | Sparse solutions |
| `projection_l2_ball` | {x: ‖x‖₂ ≤ r} | Bounded parameters |
| `projection_linf_ball` | {x: ‖x‖∞ ≤ r} | Element-wise bounds |
| `projection_l1_sphere` | {x: ‖x‖₁ = r} | Fixed L1 norm |
| `projection_l2_sphere` | {x: ‖x‖₂ = r} | Fixed L2 norm (e.g. unit sphere) |
| `projection_simplex` | {x: x ≥ 0, Σx = 1} | Probabilities |
| `projection_non_negative` | {x: x ≥ 0} | Non-negative parameters |

## Import

```python
import optax.projections
# or
from optax.projections import (
    projection_l2_ball, projection_simplex, projection_box
)
```