# Advanced Optimizers

Specialized and experimental optimization algorithms, including second-order methods, adaptive variants, and research optimizers. These optimizers implement cutting-edge techniques and may require more careful tuning than the core optimizers.

## Capabilities

### Lion Optimizer

The Lion (Evolved Sign Momentum) optimizer uses sign-based updates for memory efficiency while remaining competitive with Adam-style methods.

```python { .api }
def lion(learning_rate, b1=0.9, b2=0.99, weight_decay=0.0):
    """
    Lion optimizer (Evolved Sign Momentum).

    Args:
        learning_rate: Learning rate or schedule
        b1: Exponential decay rate for momentum (default: 0.9)
        b2: Exponential decay rate for the moving average (default: 0.99)
        weight_decay: Weight decay coefficient (default: 0.0)

    Returns:
        GradientTransformation
    """
```

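As a rough illustration of the update rule (not the library implementation), a scalar sketch of one Lion step might look like this; `lion_step` and its argument names are invented for the example:

```python
def sign(x):
    # Sign function: -1, 0, or +1.
    return (x > 0) - (x < 0)

def lion_step(param, grad, m, lr=0.1, b1=0.9, b2=0.99, weight_decay=0.0):
    """One Lion update for a scalar parameter (illustrative sketch)."""
    # The update direction is the sign of an interpolation between
    # the momentum buffer and the current gradient.
    update = sign(b1 * m + (1 - b1) * grad)
    # The momentum buffer itself decays with the second coefficient b2.
    new_m = b2 * m + (1 - b2) * grad
    # Decoupled weight decay, then a fixed-magnitude signed step.
    new_param = param - lr * (update + weight_decay * param)
    return new_param, new_m
```

Lion keeps a single momentum buffer (versus Adam's two), which is where its memory savings come from.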
### LARS Optimizer

Layer-wise Adaptive Rate Scaling (LARS) optimizer for large batch training.

```python { .api }
def lars(learning_rate, weight_decay=0., trust_coefficient=0.001, eps=0.):
    """
    LARS (Layer-wise Adaptive Rate Scaling) optimizer.

    Args:
        learning_rate: Learning rate or schedule
        weight_decay: Weight decay coefficient (default: 0.0)
        trust_coefficient: Trust coefficient for layer-wise adaptation (default: 0.001)
        eps: Small constant for numerical stability (default: 0.0)

    Returns:
        GradientTransformation
    """
```

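The core of LARS is the layer-wise trust ratio, which rescales each layer's step by the ratio of parameter norm to gradient norm. A minimal pure-Python sketch for one layer, with invented helper names (published LARS variants differ slightly in where weight decay and eps enter):

```python
import math

def lars_layer_update(param, grad, lr=0.01, trust_coefficient=0.001,
                      weight_decay=0.0, eps=0.0):
    """One LARS step for a single layer, given as lists of floats (sketch)."""
    # Fold weight decay into the gradient before computing norms.
    g = [gi + weight_decay * pi for gi, pi in zip(grad, param)]
    p_norm = math.sqrt(sum(pi * pi for pi in param))
    g_norm = math.sqrt(sum(gi * gi for gi in g))
    # Layer-wise trust ratio; fall back to 1.0 when either norm is zero.
    if p_norm > 0 and g_norm > 0:
        trust = trust_coefficient * p_norm / (g_norm + eps)
    else:
        trust = 1.0
    return [pi - lr * trust * gi for pi, gi in zip(param, g)]
```

Because the step magnitude is set per layer, layers with very different scales (embeddings vs. biases) stay balanced at large batch sizes.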
### LAMB Optimizer

LAMB (Layer-wise Adaptive Moments optimizer for Batch training) combines Adam-style moment estimates with LARS-style layer-wise scaling, and is designed for large batch sizes.

```python { .api }
def lamb(learning_rate, b1=0.9, b2=0.999, eps=1e-6, weight_decay=0., mask=None):
    """
    LAMB (Layer-wise Adaptive Moments optimizer for Batch training) optimizer.

    Args:
        learning_rate: Learning rate or schedule
        b1: Exponential decay rate for first moment estimates (default: 0.9)
        b2: Exponential decay rate for second moment estimates (default: 0.999)
        eps: Small constant for numerical stability (default: 1e-6)
        weight_decay: Weight decay coefficient (default: 0.0)
        mask: Optional mask for parameter selection

    Returns:
        GradientTransformation
    """
```

### L-BFGS Optimizer

Limited-memory Broyden-Fletcher-Goldfarb-Shanno quasi-Newton method.

```python { .api }
def lbfgs(learning_rate, memory_size=10, scale_init_preconditioner=True):
    """
    L-BFGS quasi-Newton optimizer.

    Args:
        learning_rate: Learning rate or schedule
        memory_size: Number of previous gradients to store (default: 10)
        scale_init_preconditioner: Whether to scale the initial preconditioner (default: True)

    Returns:
        GradientTransformation
    """
```

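The heart of L-BFGS is the two-loop recursion, which applies an approximate inverse Hessian built from the last `memory_size` parameter/gradient difference pairs. A self-contained sketch (function and variable names are invented for the example):

```python
def lbfgs_direction(grad, s_hist, y_hist):
    """Two-loop recursion: approximate -H^{-1} @ grad (illustrative sketch).

    s_hist: past parameter differences x_{k+1} - x_k (newest last)
    y_hist: past gradient differences  g_{k+1} - g_k (newest last)
    """
    dot = lambda a, b: sum(x * y for x, y in zip(a, b))
    q = list(grad)
    history = []
    # First loop: walk the history from newest to oldest.
    for s, y in reversed(list(zip(s_hist, y_hist))):
        rho = 1.0 / dot(y, s)
        alpha = rho * dot(s, q)
        q = [qi - alpha * yi for qi, yi in zip(q, y)]
        history.append((rho, alpha, s, y))
    # Scale by gamma = s.y / y.y, a cheap estimate of the inverse Hessian.
    s, y = s_hist[-1], y_hist[-1]
    r = [dot(s, y) / dot(y, y) * qi for qi in q]
    # Second loop: oldest to newest.
    for rho, alpha, s, y in reversed(history):
        beta = rho * dot(y, r)
        r = [ri + (alpha - beta) * si for ri, si in zip(r, s)]
    return [-ri for ri in r]
```

On a 1-D quadratic such as f(x) = 2x², a single stored pair already recovers the exact Newton step.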
### Yogi Optimizer

The Yogi optimizer controls the growth of the effective learning rate, improving convergence in settings where Adam's second-moment estimate adapts too aggressively.

```python { .api }
def yogi(learning_rate, b1=0.9, b2=0.999, eps=1e-3, initial_accumulator=1e-6):
    """
    Yogi optimizer.

    Args:
        learning_rate: Learning rate or schedule
        b1: Exponential decay rate for first moment estimates (default: 0.9)
        b2: Exponential decay rate for second moment estimates (default: 0.999)
        eps: Small constant for numerical stability (default: 1e-3)
        initial_accumulator: Initial value for the accumulator (default: 1e-6)

    Returns:
        GradientTransformation
    """
```

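What distinguishes Yogi from Adam is the second-moment update: instead of an exponential average, `v` moves toward the squared gradient additively, so the effective learning rate cannot blow up when gradients suddenly grow. A scalar sketch with invented names (not the library code):

```python
import math

def yogi_step(param, grad, m, v, lr=0.01, b1=0.9, b2=0.999, eps=1e-3):
    """One Yogi update for a scalar parameter (illustrative sketch)."""
    g2 = grad * grad
    m = b1 * m + (1 - b1) * grad
    # Adam would use v = b2*v + (1-b2)*g2. Yogi instead nudges v toward
    # g2 by a bounded amount, controlling growth of the effective step.
    sgn = (v - g2 > 0) - (v - g2 < 0)
    v = v - (1 - b2) * sgn * g2
    return param - lr * m / (math.sqrt(v) + eps), m, v
```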
### NovoGrad Optimizer

The NovoGrad optimizer combines adaptive learning rates with gradient normalization.

```python { .api }
def novograd(learning_rate, b1=0.9, b2=0.25, eps=1e-6, weight_decay=0.):
    """
    NovoGrad optimizer.

    Args:
        learning_rate: Learning rate or schedule
        b1: Exponential decay rate for first moment estimates (default: 0.9)
        b2: Exponential decay rate for second moment estimates (default: 0.25)
        eps: Small constant for numerical stability (default: 1e-6)
        weight_decay: Weight decay coefficient (default: 0.0)

    Returns:
        GradientTransformation
    """
```

### RAdam Optimizer

The RAdam (Rectified Adam) optimizer addresses the high variance of the adaptive learning rate in the early stages of training.

```python { .api }
def radam(learning_rate, b1=0.9, b2=0.999, eps=1e-8, threshold=5.0):
    """
    RAdam (Rectified Adam) optimizer.

    Args:
        learning_rate: Learning rate or schedule
        b1: Exponential decay rate for first moment estimates (default: 0.9)
        b2: Exponential decay rate for second moment estimates (default: 0.999)
        eps: Small constant for numerical stability (default: 1e-8)
        threshold: Threshold for variance tractability (default: 5.0)

    Returns:
        GradientTransformation
    """
```

### SM3 Optimizer

SM3 optimizer designed for sparse gradients, with memory-efficient second moments.

```python { .api }
def sm3(learning_rate, momentum=0.9):
    """
    SM3 optimizer for sparse gradients.

    Args:
        learning_rate: Learning rate or schedule
        momentum: Momentum coefficient (default: 0.9)

    Returns:
        GradientTransformation
    """
```

### Fromage Optimizer

Frobenius matched gradient descent optimizer.

```python { .api }
def fromage(learning_rate):
    """
    Fromage (Frobenius matched gradient descent) optimizer.

    Args:
        learning_rate: Learning rate or schedule

    Returns:
        GradientTransformation
    """
```

### Specialized SGD Variants

#### Noisy SGD

SGD with gradient noise injection for improved generalization.

```python { .api }
def noisy_sgd(learning_rate, eta=0.01):
    """
    Noisy SGD with gradient noise injection.

    Args:
        learning_rate: Learning rate or schedule
        eta: Noise scaling parameter (default: 0.01)

    Returns:
        GradientTransformation
    """
```

#### Sign SGD

SGD using only the sign of gradients.

```python { .api }
def sign_sgd(learning_rate):
    """
    Sign SGD optimizer using gradient signs only.

    Args:
        learning_rate: Learning rate or schedule

    Returns:
        GradientTransformation
    """
```

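Sign SGD discards gradient magnitudes entirely: every coordinate moves by exactly the learning rate, or not at all. A one-line scalar sketch with an invented helper name:

```python
def sign_sgd_step(param, grad, lr=0.01):
    # Only the sign of the gradient is used, so the step size per
    # coordinate is exactly lr regardless of gradient magnitude.
    return param - lr * ((grad > 0) - (grad < 0))
```

This makes the method robust to gradient scaling and cheap to communicate in distributed settings (one bit per coordinate).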
#### Polyak SGD

SGD with Polyak momentum.

```python { .api }
def polyak_sgd(learning_rate, polyak_momentum=0.9):
    """
    SGD with Polyak momentum.

    Args:
        learning_rate: Learning rate or schedule
        polyak_momentum: Polyak momentum coefficient (default: 0.9)

    Returns:
        GradientTransformation
    """
```

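A scalar sketch of the classic heavy-ball (Polyak momentum) formulation, v ← μ·v + g followed by p ← p − lr·v; the function name is invented, and libraries differ on whether the learning rate scales the gradient or the velocity:

```python
def polyak_sgd_step(param, grad, velocity, lr=0.1, momentum=0.9):
    # Heavy-ball momentum: the velocity is a decaying sum of past
    # gradients, which damps oscillation and accelerates along valleys.
    velocity = momentum * velocity + grad
    return param - lr * velocity, velocity
```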
### RProp Optimizer

The RProp (resilient backpropagation) optimizer adapts a per-coordinate step size using only gradient signs.

```python { .api }
def rprop(learning_rate, eta_minus=0.5, eta_plus=1.2, min_step_size=1e-6, max_step_size=50.):
    """
    RProp (Resilient backpropagation) optimizer.

    Args:
        learning_rate: Initial step size
        eta_minus: Factor for decreasing the step size (default: 0.5)
        eta_plus: Factor for increasing the step size (default: 1.2)
        min_step_size: Minimum step size (default: 1e-6)
        max_step_size: Maximum step size (default: 50.0)

    Returns:
        GradientTransformation
    """
```

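RProp maintains a per-coordinate step size that grows while the gradient sign is stable and shrinks when it flips; the gradient magnitude is never used. A scalar sketch of the simplest (RProp⁻) variant, with invented names:

```python
def rprop_step(param, grad, prev_grad, step_size,
               eta_minus=0.5, eta_plus=1.2,
               min_step=1e-6, max_step=50.0):
    """One RProp update for a scalar parameter (illustrative sketch)."""
    if grad * prev_grad > 0:
        # Same sign as last step: accelerate, up to max_step.
        step_size = min(step_size * eta_plus, max_step)
    elif grad * prev_grad < 0:
        # Sign flip: we overshot a minimum, so back off, down to min_step.
        step_size = max(step_size * eta_minus, min_step)
    sgn = (grad > 0) - (grad < 0)
    return param - step_size * sgn, step_size
```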
### Optimistic Methods

#### Optimistic Gradient Descent

Optimistic gradient descent for saddle point problems.

```python { .api }
def optimistic_gradient_descent(learning_rate, alpha=1.0, beta=1.0):
    """
    Optimistic gradient descent.

    Args:
        learning_rate: Learning rate or schedule
        alpha: Extrapolation coefficient (default: 1.0)
        beta: Update coefficient (default: 1.0)

    Returns:
        GradientTransformation
    """
```

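Optimistic gradient descent extrapolates using the previous gradient, which stabilizes min-max and saddle-point dynamics (e.g. GAN training). A scalar sketch of one common parameterization (the exact role of `alpha` and `beta` varies between sources); with alpha = beta = 1 it reduces to the classic update x ← x − lr·(2·gₜ − gₜ₋₁):

```python
def ogd_step(param, grad, prev_grad, lr=0.1, alpha=1.0, beta=1.0):
    # Over-weight the current gradient and subtract part of the previous
    # one, anticipating the next move of the adversary (or landscape).
    return param - lr * ((1 + alpha) * grad - beta * prev_grad)
```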
#### Optimistic Adam

Optimistic variant of the Adam optimizer.

```python { .api }
def optimistic_adam(learning_rate, b1=0.9, b2=0.999, eps=1e-8):
    """
    Optimistic Adam optimizer.

    Args:
        learning_rate: Learning rate or schedule
        b1: Exponential decay rate for first moment estimates (default: 0.9)
        b2: Exponential decay rate for second moment estimates (default: 0.999)
        eps: Small constant for numerical stability (default: 1e-8)

    Returns:
        GradientTransformation
    """
```

### Lookahead Wrapper

The Lookahead wrapper augments any base optimizer with slow weights that periodically interpolate toward the fast weights.

```python { .api }
def lookahead(fast_optimizer, lookahead_steps=5, lookahead_alpha=0.5):
    """
    Lookahead optimizer wrapper.

    Args:
        fast_optimizer: Base optimizer to wrap
        lookahead_steps: Number of fast optimizer steps between lookahead updates (default: 5)
        lookahead_alpha: Interpolation factor for the lookahead update (default: 0.5)

    Returns:
        GradientTransformation
    """
```

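Lookahead keeps two sets of weights: the wrapped ("fast") optimizer runs for `lookahead_steps` updates, then the "slow" weights move a fraction `lookahead_alpha` toward the fast weights, and the fast weights restart from there. A scalar sketch using plain SGD as the fast optimizer (names invented):

```python
def lookahead_sgd(param, grads, lr=0.1, k=5, alpha=0.5):
    """Plain SGD as the fast optimizer; sync slow weights every k steps."""
    slow = fast = param
    for i, g in enumerate(grads, start=1):
        fast = fast - lr * g                     # fast (inner) optimizer step
        if i % k == 0:
            slow = slow + alpha * (fast - slow)  # interpolate toward fast
            fast = slow                          # restart fast from slow
    return slow
```

The slow weights smooth out the variance of the fast trajectory, which is why Lookahead often tolerates a more aggressive inner optimizer.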
## Usage Example

```python
import optax
import jax.numpy as jnp

# Initialize parameters
params = {'weights': jnp.ones((100, 50)), 'bias': jnp.zeros((50,))}

# Advanced optimizers for different scenarios
lion_opt = optax.lion(learning_rate=0.0001)   # Memory efficient
lars_opt = optax.lars(learning_rate=0.01)     # Large batch training
lamb_opt = optax.lamb(learning_rate=0.001)    # Large batch training
lbfgs_opt = optax.lbfgs(learning_rate=1.0)    # Second-order method

# Lookahead wrapper around a base optimizer
base_opt = optax.adam(learning_rate=0.001)
lookahead_opt = optax.lookahead(base_opt, lookahead_steps=5, lookahead_alpha=0.5)

# Initialize optimizer states
lion_state = lion_opt.init(params)
lookahead_state = lookahead_opt.init(params)

# Usage in a training loop
def training_step(params, opt_state, gradients, optimizer):
    # Pass params to update: transformations that apply weight decay need them.
    updates, new_opt_state = optimizer.update(gradients, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state
```