# Learning Rate Schedules

Flexible scheduling functions for learning rates and other hyperparameters, including warmup, decay, and cyclic schedules. These schedules help stabilize training dynamics and improve convergence.
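
A schedule is just a function from the training step count to a value, so it can be evaluated directly or passed to an optimizer wherever a constant would go. A minimal sketch:

```python
import optax

# A schedule maps a step count to a value.
schedule = optax.linear_schedule(init_value=1.0, end_value=0.1, transition_steps=100)

schedule(0)    # 1.0
schedule(50)   # 0.55 (halfway through the transition)
schedule(100)  # 0.1
```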

## Capabilities

### Basic Schedules

#### Constant Schedule

```python { .api }
def constant_schedule(value):
    """
    Constant value schedule.

    Args:
        value: Constant value to return

    Returns:
        Schedule function
    """
```

#### Linear Schedule

```python { .api }
def linear_schedule(init_value, end_value, transition_steps):
    """
    Linear interpolation between two values.

    Args:
        init_value: Initial value
        end_value: Final value
        transition_steps: Number of steps for transition

    Returns:
        Schedule function
    """
```

#### Polynomial Schedule

```python { .api }
def polynomial_schedule(init_value, end_value, power, transition_steps):
    """
    Polynomial decay schedule.

    Args:
        init_value: Initial value
        end_value: Final value
        power: Polynomial power (1.0 = linear, 2.0 = quadratic, etc.)
        transition_steps: Number of steps for transition

    Returns:
        Schedule function
    """
```
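
With `power=2.0`, for example, the value follows a quadratic curve between the endpoints; a short sketch of the resulting values:

```python
# Quadratic decay from 1e-3 to 1e-4 over 1000 steps
poly = optax.polynomial_schedule(
    init_value=1e-3, end_value=1e-4, power=2.0, transition_steps=1000
)

poly(0)     # 1e-3
poly(500)   # 3.25e-4, i.e. (1e-3 - 1e-4) * (1 - 500/1000)**2 + 1e-4
poly(1000)  # 1e-4
```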

### Exponential Decay

```python { .api }
def exponential_decay(init_value, transition_steps, decay_rate, transition_begin=0, staircase=False, end_value=None):
    """
    Exponential decay schedule.

    Args:
        init_value: Initial value
        transition_steps: Steps over which one decay factor is applied
        decay_rate: Decay rate (e.g., 0.96 for a 4% decay per transition_steps)
        transition_begin: Step at which decay begins (default: 0)
        staircase: Whether to apply decay in discrete steps (default: False)
        end_value: Minimum value to decay to (default: None)

    Returns:
        Schedule function
    """
```
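
The `staircase` flag switches between continuous decay and discrete jumps every `transition_steps`; a short sketch:

```python
# Decay by 4% per 100 steps, continuously vs. in discrete steps
smooth = optax.exponential_decay(
    init_value=1e-3, transition_steps=100, decay_rate=0.96
)
stepped = optax.exponential_decay(
    init_value=1e-3, transition_steps=100, decay_rate=0.96, staircase=True
)

smooth(50)   # 1e-3 * 0.96**0.5, roughly 9.8e-4
stepped(50)  # 1e-3 (the first discrete decay lands at step 100)
```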

### Cosine Schedules

#### Cosine Decay

```python { .api }
def cosine_decay_schedule(init_value, decay_steps, alpha=0.0):
    """
    Cosine decay schedule following half a cosine wave from init_value
    down to alpha * init_value.

    Args:
        init_value: Initial value
        decay_steps: Number of steps over which the value decays
        alpha: Minimum value as a fraction of init_value (default: 0.0)

    Returns:
        Schedule function
    """
```
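
Setting `alpha` keeps the schedule from decaying all the way to zero:

```python
# Decay to 10% of the initial value instead of zero
cosine = optax.cosine_decay_schedule(init_value=1e-3, decay_steps=1000, alpha=0.1)

cosine(0)     # 1e-3
cosine(1000)  # 1e-4 (alpha * init_value)
```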

#### Cosine One-Cycle

```python { .api }
def cosine_onecycle_schedule(transition_steps, peak_value, pct_start=0.3, div_factor=25.0, final_div_factor=1e4):
    """
    One-cycle cosine schedule: cosine warmup to peak_value, then cosine decay.

    Args:
        transition_steps: Total number of steps
        peak_value: Maximum value at the peak
        pct_start: Fraction of steps spent in the warmup phase (default: 0.3)
        div_factor: Determines the initial value, peak_value / div_factor (default: 25.0)
        final_div_factor: Determines the final value, initial value / final_div_factor (default: 1e4)

    Returns:
        Schedule function
    """
```

### Piecewise Schedules

#### Piecewise Constant

```python { .api }
def piecewise_constant_schedule(init_value, boundaries_and_scales=None):
    """
    Piecewise constant schedule with different values in different intervals.

    Args:
        init_value: Initial value
        boundaries_and_scales: Dict mapping step boundaries to scale factors;
            at each boundary the current value is multiplied by the
            corresponding factor, so the scales compound

    Returns:
        Schedule function
    """
```

#### Piecewise Interpolate

```python { .api }
def piecewise_interpolate_schedule(interpolate_type, init_value, boundaries_and_scales=None):
    """
    Piecewise schedule with interpolation between boundaries.

    Args:
        interpolate_type: Type of interpolation ('linear' or 'cosine')
        init_value: Initial value
        boundaries_and_scales: Dict mapping step boundaries to scale factors

    Returns:
        Schedule function
    """
```

### Warmup Schedules

#### Warmup + Constant

```python { .api }
def warmup_constant_schedule(init_value, peak_value, warmup_steps):
    """
    Linear warmup followed by a constant value.

    Args:
        init_value: Initial value during warmup
        peak_value: Constant value after warmup
        warmup_steps: Number of warmup steps

    Returns:
        Schedule function
    """
```

#### Warmup + Cosine Decay

```python { .api }
def warmup_cosine_decay_schedule(init_value, peak_value, warmup_steps, decay_steps, end_value=0.0):
    """
    Linear warmup followed by cosine decay.

    Args:
        init_value: Initial value during warmup
        peak_value: Peak value after warmup
        warmup_steps: Number of warmup steps
        decay_steps: Total schedule length in steps; the cosine decay runs
            for decay_steps - warmup_steps steps after the warmup
        end_value: Final value after decay (default: 0.0)

    Returns:
        Schedule function
    """
```

#### Warmup + Exponential Decay

```python { .api }
def warmup_exponential_decay_schedule(init_value, peak_value, warmup_steps, transition_steps, decay_rate, transition_begin=0, staircase=False, end_value=None):
    """
    Linear warmup followed by exponential decay.

    Args:
        init_value: Initial value during warmup
        peak_value: Peak value after warmup
        warmup_steps: Number of warmup steps
        transition_steps: Steps between decay applications
        decay_rate: Exponential decay rate
        transition_begin: Step to begin decay (default: 0)
        staircase: Whether to apply decay in discrete steps (default: False)
        end_value: Minimum decay value (default: None)

    Returns:
        Schedule function
    """
```
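
For example, warming up over 500 steps and then decaying by 5% every 1000 steps with a lower bound:

```python
sched = optax.warmup_exponential_decay_schedule(
    init_value=0.0,
    peak_value=1e-3,
    warmup_steps=500,
    transition_steps=1000,
    decay_rate=0.95,
    end_value=1e-5,  # never decay below this floor
)
```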

### Advanced Schedules

#### Linear One-Cycle

```python { .api }
def linear_onecycle_schedule(transition_steps, peak_value, pct_start=0.3, pct_final=0.85, div_factor=25.0, final_div_factor=1e4):
    """
    One-cycle linear schedule (warmup, decay, final decay).

    Args:
        transition_steps: Total number of steps
        peak_value: Maximum value at the peak
        pct_start: Fraction of steps spent in the warmup phase (default: 0.3)
        pct_final: Fraction of steps before the final decay phase (default: 0.85)
        div_factor: Determines the initial value, peak_value / div_factor (default: 25.0)
        final_div_factor: Determines the final value, initial value / final_div_factor (default: 1e4)

    Returns:
        Schedule function
    """
```
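
This mirrors `cosine_onecycle_schedule` but interpolates linearly between the phase boundaries:

```python
# Ramp up for the first 30% of steps, decay until 85%, then a final
# decay toward peak_value / (div_factor * final_div_factor)
onecycle_lin = optax.linear_onecycle_schedule(
    transition_steps=5000,
    peak_value=0.01,
)
```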

#### SGDR Schedule

```python { .api }
def sgdr_schedule(cosine_kwargs):
    """
    Stochastic Gradient Descent with Warm Restarts (SGDR) schedule.

    Builds one warmup + cosine decay cycle per entry and joins them, so the
    value restarts at each cycle boundary.

    Args:
        cosine_kwargs: Iterable of keyword-argument dicts, one per cycle,
            each passed to warmup_cosine_decay_schedule

    Returns:
        Schedule function
    """
```

### Schedule Composition

#### Join Schedules

```python { .api }
def join_schedules(schedules, boundaries):
    """
    Join multiple schedules at specified boundaries.

    Args:
        schedules: List of schedule functions
        boundaries: List of step boundaries for schedule transitions; each
            subsequent schedule sees its step count restart from 0 at the
            boundary where it takes over

    Returns:
        Combined schedule function
    """
```

### Hyperparameter Injection

#### Static Hyperparameters

```python { .api }
def inject_hyperparams(inner_factory, static_args=(), hyperparam_dtype=None):
    """
    Wrap an optimizer factory so that its numeric hyperparameters can be
    supplied as schedule functions.

    Args:
        inner_factory: Callable that returns a GradientTransformation
        static_args: Names of arguments that should not be injected
        hyperparam_dtype: Optional dtype to cast hyperparameters to

    Returns:
        Wrapped factory; calling it with floats or schedules yields a
        GradientTransformation whose hyperparameters are re-evaluated from
        the step count on every update
    """
```

#### Stateful Hyperparameters

```python { .api }
def inject_stateful_hyperparams(inner_factory, static_args=(), hyperparam_dtype=None):
    """
    Like inject_hyperparams, but also accepts stateful schedules whose own
    state is tracked alongside the optimizer state.

    Args:
        inner_factory: Callable that returns a GradientTransformation
        static_args: Names of arguments that should not be injected
        hyperparam_dtype: Optional dtype to cast hyperparameters to

    Returns:
        Wrapped factory producing a GradientTransformation with stateful
        scheduled hyperparameters
    """
```

### Schedule State Classes

```python { .api }
class InjectHyperparamsState:
    """State for hyperparameter injection."""
    count: int
    hyperparams: dict
    inner_state: OptState

class InjectStatefulHyperparamsState:
    """State for stateful hyperparameter injection."""
    count: int
    hyperparams: dict
    hyperparams_states: dict
    inner_state: OptState

class WrappedSchedule:
    """Adapter giving a plain schedule function the stateful schedule interface."""
    schedule_fn: Schedule
```
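
Because the injected values are stored on the state, the current learning rate can be read off directly, which is convenient for logging. A minimal sketch:

```python
import jax.numpy as jnp
import optax

optimizer = optax.inject_hyperparams(optax.sgd)(
    learning_rate=optax.linear_schedule(0.1, 0.01, 100)
)

params = {'w': jnp.zeros(3)}
opt_state = optimizer.init(params)

# InjectHyperparamsState exposes the evaluated hyperparameters
print(opt_state.hyperparams['learning_rate'])  # 0.1 at step 0
```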

## Usage Examples

### Basic Schedule Usage

```python
import optax

# Create different schedules
constant_lr = optax.constant_schedule(0.001)
linear_decay = optax.linear_schedule(0.001, 0.0001, 1000)
cosine_decay = optax.cosine_decay_schedule(0.001, 1000)
exponential_decay = optax.exponential_decay(
    init_value=0.001, transition_steps=100, decay_rate=0.96
)

# Use a schedule with an optimizer
optimizer = optax.adam(learning_rate=cosine_decay)

# Evaluate schedules at different steps
step_0_lr = constant_lr(0)         # 0.001
step_500_lr = linear_decay(500)    # 0.00055 (halfway between 0.001 and 0.0001)
step_1000_lr = cosine_decay(1000)  # 0.0 (alpha defaults to 0.0)
```

### Warmup Schedules

```python
# Warmup followed by cosine decay; decay_steps counts from step 0, so the
# cosine phase here lasts 9000 - 1000 = 8000 steps
warmup_cosine = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=0.001,
    warmup_steps=1000,
    decay_steps=9000,
    end_value=0.00001,
)

# Warmup followed by a constant value
warmup_constant = optax.warmup_constant_schedule(
    init_value=0.0,
    peak_value=0.001,
    warmup_steps=500,
)

# Use with an optimizer
optimizer = optax.adamw(learning_rate=warmup_cosine, weight_decay=0.01)
```

### Piecewise Schedules

```python
# Scale the learning rate down at fixed training milestones; each factor
# multiplies the value in effect at its boundary
boundaries_and_scales = {
    500: 1.0,   # unchanged at step 500
    1000: 0.5,  # halved at step 1000
    1500: 0.1,  # cut to a tenth at step 1500 (0.05x init_value overall)
}

piecewise_sched = optax.piecewise_constant_schedule(0.001, boundaries_and_scales)

# With interpolation between boundaries
piecewise_interp = optax.piecewise_interpolate_schedule(
    'linear', 0.001, boundaries_and_scales
)
```

### Advanced Scheduling

```python
# One-cycle schedule
onecycle = optax.cosine_onecycle_schedule(
    transition_steps=5000,
    peak_value=0.01,
    pct_start=0.3,  # 30% of steps spent warming up
)

# SGDR with warm restarts: one warmup + cosine cycle per dict, with the
# second cycle twice as long as the first
sgdr = optax.sgdr_schedule([
    dict(init_value=0.0, peak_value=0.001, warmup_steps=100, decay_steps=1000),
    dict(init_value=0.0, peak_value=0.001, warmup_steps=100, decay_steps=2000),
])

# Join multiple schedules
schedules = [
    optax.constant_schedule(0.001),              # first 1000 steps
    optax.linear_schedule(0.001, 0.0001, 1000),  # next 1000 steps
]
joined = optax.join_schedules(schedules, [1000])
```

### Hyperparameter Scheduling

```python
# Wrap the optimizer factory, then pass schedules (or constants) for any
# numeric hyperparameter when constructing the optimizer
optimizer = optax.inject_hyperparams(optax.adamw)(
    learning_rate=optax.cosine_decay_schedule(0.001, 1000),
    b1=optax.linear_schedule(0.9, 0.95, 500),
    b2=0.999,
    weight_decay=0.01,
)

# The current values are re-evaluated from the step count on every update
# and exposed as opt_state.hyperparams['learning_rate'], etc.
```

### Training Loop Integration

```python
import jax

# Create schedule
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=0.001,
    warmup_steps=1000,
    decay_steps=9000,
)

optimizer = optax.adam(learning_rate=schedule)

def train_step(params, opt_state, batch, step):
    """Training step with a scheduled learning rate."""

    def loss_fn(p):
        # compute_loss(params, batch) is assumed to be defined elsewhere
        return compute_loss(p, batch)

    loss_val, grads = jax.value_and_grad(loss_fn)(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)

    # The optimizer tracks its own step count internally; `step` is only
    # needed here to report the current learning rate for logging
    current_lr = schedule(step)

    return params, opt_state, loss_val, current_lr
```