# Experimental Optimizers (contrib)

The `optax.contrib` module contains experimental optimizers and techniques under active development. These methods implement recent research results and may be less stable than the core optimizers.

**Note**: Experimental features may have API changes in future versions.

## Capabilities

### Advanced Adaptive Optimizers

#### Sharpness-Aware Minimization (SAM)
```python { .api }
def sam(base_optimizer, rho=0.05, normalize=True):
    """
    Sharpness-Aware Minimization optimizer.

    Args:
        base_optimizer: Base optimizer to use (e.g., SGD, Adam)
        rho: Neighborhood size for sharpness computation (default: 0.05)
        normalize: Whether to normalize the perturbation (default: True)

    Returns:
        GradientTransformation: SAM optimizer
    """
```
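
The two-step SAM update (ascend to a worst-case neighbor, then descend using the gradient measured there) can be sketched in plain Python on a 1-D quadratic. This is an illustration of the update rule only, not the optax implementation; `sam_step` and `grad_fn` are illustrative names:

```python
def sam_step(w, grad_fn, lr=0.1, rho=0.05):
    """One SAM step: perturb toward the worst-case neighbor, then take the
    base (plain SGD) step using the gradient at the perturbed point."""
    g = grad_fn(w)
    # Normalized ascent perturbation of size rho.
    eps = rho * g / abs(g)
    # Gradient at the perturbed ("sharpness-aware") point.
    g_adv = grad_fn(w + eps)
    # Base optimizer step, applied from the original point.
    return w - lr * g_adv

# Example: f(w) = w**2, so grad f(w) = 2*w.
w = sam_step(1.0, lambda w: 2.0 * w)
```

Starting from `w = 1.0`, the perturbed point is `1.05` and the descent uses its gradient `2.1`, landing at `0.79`.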

#### Prodigy Optimizer

```python { .api }
def prodigy(learning_rate=1.0, eps=1e-8, beta1=0.9, beta2=0.999, weight_decay=0.0):
    """
    Prodigy adaptive learning rate optimizer.

    Args:
        learning_rate: Initial learning rate (default: 1.0)
        eps: Numerical stability parameter (default: 1e-8)
        beta1: First moment decay rate (default: 0.9)
        beta2: Second moment decay rate (default: 0.999)
        weight_decay: Weight decay coefficient (default: 0.0)

    Returns:
        GradientTransformation: Prodigy optimizer
    """
```

#### Sophia Optimizer

```python { .api }
def sophia(learning_rate, beta1=0.965, beta2=0.99, eps=1e-8, weight_decay=1e-4):
    """
    Sophia optimizer using second-order information.

    Args:
        learning_rate: Learning rate
        beta1: First moment decay rate (default: 0.965)
        beta2: Second moment decay rate (default: 0.99)
        eps: Numerical stability parameter (default: 1e-8)
        weight_decay: Weight decay coefficient (default: 1e-4)

    Returns:
        GradientTransformation: Sophia optimizer
    """
```
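
Sophia's clipped, Hessian-preconditioned update can be sketched on a scalar parameter. This is a simplified illustration (no bias correction, and an exact 1-D second derivative in place of the stochastic Hessian estimator); all names are illustrative:

```python
def sophia_step(w, m, grad_fn, hess_fn, lr=0.1, beta1=0.965, gamma=0.05, eps=1e-8):
    """One simplified Sophia step on a scalar: precondition the gradient
    momentum by a Hessian estimate, then clip the update elementwise."""
    g = grad_fn(w)
    m = beta1 * m + (1.0 - beta1) * g       # momentum (EMA of gradients)
    h = hess_fn(w)                          # (diagonal) Hessian estimate
    ratio = m / max(gamma * h, eps)         # Hessian preconditioning
    update = max(-1.0, min(1.0, ratio))     # per-coordinate clipping to [-1, 1]
    return w - lr * update, m

# Example: f(w) = w**2 with exact derivatives (grad 2w, Hessian 2).
w, m = sophia_step(1.0, 0.0, lambda w: 2.0 * w, lambda w: 2.0)
```

The clipping means a single coordinate can move at most `lr` per step regardless of how small the Hessian estimate is, which is the source of Sophia's robustness to noisy curvature.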

### Schedule-Free Optimizers

#### Schedule-Free AdamW

```python { .api }
def schedule_free_adamw(learning_rate=0.0025, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.0):
    """
    Schedule-free AdamW optimizer that doesn't require a learning rate schedule.

    Args:
        learning_rate: Learning rate (default: 0.0025)
        beta1: First moment decay rate (default: 0.9)
        beta2: Second moment decay rate (default: 0.999)
        eps: Numerical stability parameter (default: 1e-8)
        weight_decay: Weight decay coefficient (default: 0.0)

    Returns:
        GradientTransformation: Schedule-free AdamW optimizer
    """

def schedule_free_sgd(learning_rate=1.0, momentum=0.9, weight_decay=0.0):
    """
    Schedule-free SGD optimizer.

    Args:
        learning_rate: Learning rate (default: 1.0)
        momentum: Momentum coefficient (default: 0.9)
        weight_decay: Weight decay coefficient (default: 0.0)

    Returns:
        GradientTransformation: Schedule-free SGD optimizer
    """

def schedule_free_eval_params(state, params):
    """
    Extract evaluation parameters from a schedule-free optimizer state.

    Args:
        state: State from a schedule-free optimizer
        params: Current training parameters

    Returns:
        Parameters suitable for evaluation/inference
    """
```
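
The schedule-free scheme keeps a fast iterate `z` and a running average `x`, and evaluates gradients at an interpolation of the two; `x` plays the role of the evaluation parameters that `schedule_free_eval_params` returns. A minimal scalar sketch of the SGD variant (illustrative names, not the optax code):

```python
def schedule_free_sgd_run(grad_fn, w0, lr=0.1, beta=0.9, steps=10):
    """Minimal schedule-free SGD on a scalar: the gradient is taken at an
    interpolation y of the fast iterate z and the running average x;
    x (the Polyak-style average) is what you evaluate."""
    z = x = w0
    for t in range(1, steps + 1):
        y = (1.0 - beta) * z + beta * x   # interpolated point for the gradient
        z = z - lr * grad_fn(y)           # base SGD step on the fast iterate
        c = 1.0 / t                       # running-average weight
        x = (1.0 - c) * x + c * z         # update the evaluation parameters
    return x

# Example: f(w) = w**2 starting from w = 1.0.
x = schedule_free_sgd_run(lambda w: 2.0 * w, 1.0)
```

Because `x` is an average rather than the raw iterate, no decaying schedule is needed to get the usual averaged-SGD behavior at evaluation time.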

### Momentum-Based Methods

#### Muon Optimizer

```python { .api }
def muon(learning_rate, momentum=0.95, nesterov=False):
    """
    Muon optimizer: momentum SGD whose matrix-shaped updates are
    orthogonalized via Newton-Schulz iterations.

    Args:
        learning_rate: Learning rate
        momentum: Momentum coefficient (default: 0.95)
        nesterov: Whether to use Nesterov momentum (default: False)

    Returns:
        GradientTransformation: Muon optimizer
    """
```

#### MoMo (Model-based Momentum)

```python { .api }
def momo(learning_rate, momentum=0.9):
    """
    MoMo optimizer: momentum SGD with a model-based adaptive learning rate.

    Args:
        learning_rate: Learning rate (acts as an upper bound on the adaptive step size)
        momentum: Momentum coefficient (default: 0.9)

    Returns:
        GradientTransformation: MoMo optimizer
    """

def momo_adam(learning_rate, beta1=0.9, beta2=0.999, eps=1e-8):
    """
    MoMo-Adam: Adam combined with a model-based adaptive learning rate.

    Args:
        learning_rate: Learning rate (acts as an upper bound on the adaptive step size)
        beta1: First moment decay rate (default: 0.9)
        beta2: Second moment decay rate (default: 0.999)
        eps: Numerical stability parameter (default: 1e-8)

    Returns:
        GradientTransformation: MoMo-Adam optimizer
    """
```

### Specialized Methods

#### DoG (Distance over Gradients) and DoWG

```python { .api }
def dog(learning_rate, rho=0.05, eps=1e-8):
    """
    DoG (Distance over Gradients) parameter-free optimizer.

    Args:
        learning_rate: Scale applied to the parameter-free step size
        rho: Initial distance estimate (default: 0.05)
        eps: Numerical stability parameter (default: 1e-8)

    Returns:
        GradientTransformation: DoG optimizer
    """

def dowg(learning_rate, rho=0.05, eps=1e-8, weight_decay=0.0):
    """
    DoWG (Distance over Weighted Gradients) optimizer.

    Args:
        learning_rate: Scale applied to the parameter-free step size
        rho: Initial distance estimate (default: 0.05)
        eps: Numerical stability parameter (default: 1e-8)
        weight_decay: Weight decay coefficient (default: 0.0)

    Returns:
        GradientTransformation: DoWG optimizer
    """
```
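
DoG's parameter-free step size is the maximum distance traveled from the initial point divided by the root of the accumulated squared gradient norms. A 1-D sketch (illustrative names only; `r0` stands in for the method's small initial distance estimate):

```python
import math

def dog_run(grad_fn, w0, r0=1e-4, steps=100, eps=1e-8):
    """Minimal 1-D DoG: step size = (max distance from w0 so far) /
    sqrt(accumulated squared gradient magnitudes)."""
    w, r, G = w0, r0, 0.0
    for _ in range(steps):
        g = grad_fn(w)
        G += g * g                      # accumulate squared gradient norms
        eta = r / (math.sqrt(G) + eps)  # parameter-free step size
        w = w - eta * g
        r = max(r, abs(w - w0))         # max distance traveled so far
    return w

# Example: f(w) = w**2 starting from w = 1.0.
w = dog_run(lambda w: 2.0 * w, 1.0)
```

The distance estimate `r` grows automatically as the iterate moves, so the step size "warms up" from the tiny initial value without any tuned schedule.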

#### ADOPT

```python { .api }
def adopt(learning_rate, eps=1e-8, beta1=0.9, beta2=0.9999, weight_decay=0.0):
    """
    ADOPT optimizer: a modified Adam that normalizes the gradient with the
    previous second-moment estimate before applying momentum, giving
    convergence for any choice of beta2.

    Args:
        learning_rate: Learning rate
        eps: Numerical stability parameter (default: 1e-8)
        beta1: First moment decay rate (default: 0.9)
        beta2: Second moment decay rate (default: 0.9999)
        weight_decay: Weight decay coefficient (default: 0.0)

    Returns:
        GradientTransformation: ADOPT optimizer
    """
```
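
ADOPT's key change relative to Adam is the update order: the gradient is normalized by the *previous* second-moment estimate before entering the momentum buffer, and the second moment is refreshed afterwards. A simplified scalar sketch (illustrative names):

```python
import math

def adopt_run(grad_fn, w0, lr=0.01, beta1=0.9, beta2=0.9999, eps=1e-6, steps=50):
    """Minimal 1-D ADOPT: normalize each gradient with the previous
    second-moment estimate, then apply momentum."""
    w = w0
    v = grad_fn(w) ** 2   # initialize second moment from the first gradient
    m = 0.0
    for _ in range(steps):
        g = grad_fn(w)
        normed = g / max(math.sqrt(v), eps)    # normalize with *previous* v
        m = beta1 * m + (1.0 - beta1) * normed
        w = w - lr * m
        v = beta2 * v + (1.0 - beta2) * g * g  # refresh v afterwards
    return w

# Example: f(w) = w**2 starting from w = 1.0.
w = adopt_run(lambda w: 2.0 * w, 1.0)
```

Decorrelating the normalizer from the current gradient is what removes Adam's dependence on the choice of `beta2`.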

### Privacy-Preserving Methods

#### Differential Privacy

```python { .api }
def differentially_private_aggregate(
    l2_norm_clip,
    noise_multiplier,
    seed=None
):
    """
    Differentially private gradient aggregation in the style of DP-SGD:
    per-example gradients are clipped to an L2 bound and Gaussian noise
    proportional to that bound is added to the aggregate.

    Args:
        l2_norm_clip: L2 norm bound for per-example gradient clipping
        noise_multiplier: Ratio of noise standard deviation to the clipping bound
        seed: Random seed (default: None)

    Returns:
        GradientTransformation: DP aggregation transformation
    """
```
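
The clip-and-noise aggregation behind DP-SGD can be sketched directly: clip each per-example gradient to the L2 bound, sum, add Gaussian noise scaled by `noise_multiplier * l2_norm_clip`, and average. A pure-Python sketch (illustrative, not the optax implementation):

```python
import math
import random

def dp_aggregate(per_example_grads, l2_norm_clip, noise_multiplier, seed=0):
    """Minimal DP-SGD aggregation over flat gradient vectors."""
    rng = random.Random(seed)
    n = len(per_example_grads)
    dim = len(per_example_grads[0])
    total = [0.0] * dim
    for g in per_example_grads:
        norm = math.sqrt(sum(x * x for x in g))
        scale = min(1.0, l2_norm_clip / max(norm, 1e-12))  # clip to the L2 bound
        for i in range(dim):
            total[i] += g[i] * scale
    sigma = noise_multiplier * l2_norm_clip  # noise scales with the clip bound
    return [(t + rng.gauss(0.0, sigma)) / n for t in total]

# With noise_multiplier=0 the result is just the clipped mean:
# [6, 8] has norm 10 and is scaled to [3, 4]; [1, 0] is within the bound.
agg = dp_aggregate([[6.0, 8.0], [1.0, 0.0]], l2_norm_clip=5.0, noise_multiplier=0.0)
```

Clipping bounds each example's influence on the aggregate, which is what makes the added Gaussian noise sufficient for a differential privacy guarantee.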

### Experimental Adaptive Methods

#### AdEMAMix

```python { .api }
def ademamix(learning_rate, beta1=0.9, beta2=0.999, eps=1e-8, alpha=5.0):
    """
    AdEMAMix optimizer, which mixes a fast and a slow exponential moving
    average of gradients in an Adam-style update.

    Args:
        learning_rate: Learning rate
        beta1: First moment decay rate (default: 0.9)
        beta2: Second moment decay rate (default: 0.999)
        eps: Numerical stability parameter (default: 1e-8)
        alpha: Mixing coefficient for the slow EMA (default: 5.0)

    Returns:
        GradientTransformation: AdEMAMix optimizer
    """
```
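
AdEMAMix's core idea, mixing a fast and a slow gradient EMA in the numerator of an Adam-style update, can be sketched on a scalar (bias correction omitted for brevity; all names are illustrative):

```python
import math

def ademamix_step(w, state, grad_fn, lr=0.01, beta1=0.9, beta2=0.999,
                  beta3=0.9999, alpha=5.0, eps=1e-8):
    """One simplified AdEMAMix step on a scalar: an Adam-style update whose
    numerator mixes a fast EMA (beta1) and a slow EMA (beta3) of the gradient."""
    m_fast, m_slow, v = state
    g = grad_fn(w)
    m_fast = beta1 * m_fast + (1.0 - beta1) * g    # fast gradient EMA
    m_slow = beta3 * m_slow + (1.0 - beta3) * g    # slow gradient EMA
    v = beta2 * v + (1.0 - beta2) * g * g          # second moment
    update = (m_fast + alpha * m_slow) / (math.sqrt(v) + eps)
    return w - lr * update, (m_fast, m_slow, v)

# Example: one step on f(w) = w**2 starting from w = 1.0.
w, state = ademamix_step(1.0, (0.0, 0.0, 0.0), lambda w: 2.0 * w)
```

The slow EMA lets very old gradients keep contributing (weighted by `alpha`) without slowing the fast EMA's reaction to recent ones.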

#### COCOB

```python { .api }
def cocob():
    """
    COCOB (Continuous Coin Betting) optimizer.

    A parameter-free method that requires no learning rate: step sizes are
    derived from a coin-betting formulation.

    Returns:
        GradientTransformation: COCOB optimizer (parameter-free)
    """
```

## Usage Examples

```python
import optax

# Using SAM for better generalization
base_optimizer = optax.sgd(0.1)
sam_optimizer = optax.contrib.sam(base_optimizer, rho=0.05)

# Using schedule-free optimizers
sf_adamw = optax.contrib.schedule_free_adamw(learning_rate=0.001)

# Using experimental adaptive methods
prodigy_opt = optax.contrib.prodigy(learning_rate=1.0)
sophia_opt = optax.contrib.sophia(learning_rate=0.001)

# Training loop with a schedule-free optimizer
opt_state = sf_adamw.init(params)
for step in range(num_steps):
    grads = compute_gradients(params, data)
    updates, opt_state = sf_adamw.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)

    # Extract evaluation parameters (for schedule-free methods)
    if step % eval_interval == 0:
        eval_params = optax.contrib.schedule_free_eval_params(opt_state, params)
        eval_loss = evaluate(eval_params, eval_data)
```

## Import

```python
import optax.contrib
# or
from optax.contrib import sam, prodigy, schedule_free_adamw
```

## Research Papers

Many contrib optimizers are based on recent research:

- **SAM**: "Sharpness-Aware Minimization for Efficiently Improving Generalization"
- **Prodigy**: "Prodigy: An Expeditiously Adaptive Parameter-Free Learner"
- **Sophia**: "Sophia: A Scalable Stochastic Second-order Optimizer for Language Model Pre-training"
- **Schedule-Free**: "The Road Less Scheduled"
- **DoG**: "DoG is SGD's Best Friend: A Parameter-Free Dynamic Step Size Schedule"
- **AdEMAMix**: "The AdEMAMix Optimizer: Better, Faster, Older"

Refer to the respective papers for detailed algorithmic descriptions and theoretical analysis.