# Optimization

Specialized optimizers and learning rate schedulers for training and fine-tuning transformer models. These tools implement common practices for training large language models: decoupled weight decay, learning rate warmup, and learning rate decay.

## Capabilities

### AdamW Optimizer

Adam optimizer with the weight decay fix, widely used for transformer models. Standard Adam with L2 regularization adds the decay term to the gradient, where Adam's adaptive denominator then rescales it. AdamW instead applies weight decay directly to the parameters, separately from the gradient-based update.

```python { .api }
class AdamW:
    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-6,
        weight_decay=0.0,
        correct_bias=True
    ):
        """
        Initialize AdamW optimizer.

        Parameters:
        - params: Iterable of parameters to optimize, or of dicts defining parameter groups
        - lr (float): Learning rate
        - betas (Tuple[float, float]): Decay rates for the moving averages of the gradient and of the squared gradient
        - eps (float): Term added to the denominator for numerical stability
        - weight_decay (float): Decoupled weight decay coefficient
        - correct_bias (bool): Whether to bias-correct the moment estimates; set to False to reproduce the original BERT TensorFlow optimizer
        """

    def step(self, closure=None):
        """
        Perform a single optimization step.

        Parameters:
        - closure (callable, optional): Closure that reevaluates the model and returns the loss

        Returns:
        The loss returned by the closure, or None if no closure is provided
        """

    def zero_grad(self):
        """
        Clear the gradients of all optimized parameters.
        """
```
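
To make the weight decay fix concrete, the following pure-Python sketch performs one update on a single scalar parameter with AdamW and, for comparison, with Adam plus L2 regularization. The helpers `adam_core`, `adamw_step`, and `adam_l2_step` are hypothetical and are not the library's implementation; the sketch assumes the decoupled decay is applied after the adaptive step and scaled by `lr * weight_decay`.

```python
import math


def adam_core(grad, m, v, t, lr, betas, eps, correct_bias):
    """Return (update, m, v) for one Adam step on a scalar parameter (sketch)."""
    beta1, beta2 = betas
    m = beta1 * m + (1.0 - beta1) * grad         # moving average of the gradient
    v = beta2 * v + (1.0 - beta2) * grad * grad  # moving average of the squared gradient
    step_size = lr
    if correct_bias:
        # Compensate for m and v starting at zero (t is the 1-based step count)
        step_size = lr * math.sqrt(1.0 - beta2 ** t) / (1.0 - beta1 ** t)
    return step_size * m / (math.sqrt(v) + eps), m, v


def adamw_step(p, grad, m, v, t, lr=1e-3, betas=(0.9, 0.999), eps=1e-6,
               weight_decay=0.0, correct_bias=True):
    """AdamW (hypothetical helper): decay shrinks p directly and never enters m or v."""
    update, m, v = adam_core(grad, m, v, t, lr, betas, eps, correct_bias)
    p = p - update
    p = p - lr * weight_decay * p  # decoupled weight decay
    return p, m, v


def adam_l2_step(p, grad, m, v, t, lr=1e-3, betas=(0.9, 0.999), eps=1e-6,
                 weight_decay=0.0, correct_bias=True):
    """Adam + L2 (hypothetical helper): the decay term is folded into the gradient."""
    update, m, v = adam_core(grad + weight_decay * p, m, v, t, lr, betas, eps, correct_bias)
    return p - update, m, v


# With a zero gradient, AdamW shrinks p by exactly lr * weight_decay (1% here),
# while Adam + L2 normalizes the decay term and takes a step of roughly lr.
p_adamw, _, _ = adamw_step(1.0, 0.0, 0.0, 0.0, t=1, lr=0.1, weight_decay=0.1)
p_l2, _, _ = adam_l2_step(1.0, 0.0, 0.0, 0.0, t=1, lr=0.1, weight_decay=0.1)
print(f"AdamW: {p_adamw:.4f}  Adam+L2: {p_l2:.4f}")
```

With L2 regularization the decay term shares Adam's adaptive denominator with the gradient, so the effective decay depends on each parameter's gradient history. The decoupled version shrinks every parameter by the same fraction, `lr * weight_decay`, per step.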

**Usage Example:**

```python
import torch
from pytorch_transformers import AdamW, BertForSequenceClassification

# Load model
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)

# Initialize optimizer
optimizer = AdamW(
    model.parameters(),
    lr=2e-5,
    weight_decay=0.01,
    correct_bias=False  # skip bias correction, as in the original BERT TensorFlow optimizer
)

# Training step on dummy data
inputs = torch.randint(0, 1000, (8, 128))  # dummy token IDs: batch of 8, sequence length 128
labels = torch.randint(0, 2, (8,))  # dummy binary labels

optimizer.zero_grad()
outputs = model(inputs, labels=labels)
loss = outputs[0]  # pytorch_transformers models return tuples; the loss comes first when labels are given
loss.backward()
optimizer.step()

print(f"Loss: {loss.item():.4f}")
```

### Learning Rate Schedulers

Learning rate schedules commonly used in transformer training. Each one wraps an optimizer and multiplies its initial learning rate by a step-dependent factor every time `scheduler.step()` is called, typically once per optimization step rather than once per epoch. Most combine a linear warmup phase with a decay pattern.

#### ConstantLRSchedule

Keeps the learning rate constant throughout training.

```python { .api }
def ConstantLRSchedule(optimizer, last_epoch=-1):
    """
    Create a constant learning rate schedule.

    Parameters:
    - optimizer: Wrapped optimizer
    - last_epoch (int): Index of the last step taken, for resuming a schedule; -1 starts from scratch

    Returns:
    LambdaLR: Learning rate scheduler
    """
```

#### WarmupConstantSchedule

Linear warmup from 0 to the optimizer's learning rate over `warmup_steps`, followed by a constant learning rate.

```python { .api }
def WarmupConstantSchedule(optimizer, warmup_steps, last_epoch=-1):
    """
    Create a schedule with linear warmup followed by a constant learning rate.

    Parameters:
    - optimizer: Wrapped optimizer
    - warmup_steps (int): Number of warmup steps
    - last_epoch (int): Index of the last step taken, for resuming a schedule; -1 starts from scratch

    Returns:
    LambdaLR: Learning rate scheduler
    """
```
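
Because each schedule is a `LambdaLR`, the learning rate at a given step is the optimizer's initial learning rate times a multiplier. A minimal sketch of the warmup-constant multiplier, written from the description above; the function name is hypothetical, and the library's handling of edge cases may differ:

```python
def warmup_constant_multiplier(step, warmup_steps):
    """Factor applied to the base learning rate at `step` (hypothetical helper)."""
    if step < warmup_steps:
        return step / warmup_steps  # linear ramp from 0 to 1 during warmup
    return 1.0  # hold the full learning rate afterwards


base_lr = 2e-5
for step in (0, 50, 100, 5000):
    print(step, base_lr * warmup_constant_multiplier(step, warmup_steps=100))
```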

#### WarmupLinearSchedule

Linear warmup followed by linear decay to zero at step `t_total`.

```python { .api }
def WarmupLinearSchedule(optimizer, warmup_steps, t_total, last_epoch=-1):
    """
    Create a schedule with linear warmup followed by linear decay.

    Parameters:
    - optimizer: Wrapped optimizer
    - warmup_steps (int): Number of warmup steps
    - t_total (int): Total number of training steps
    - last_epoch (int): Index of the last step taken, for resuming a schedule; -1 starts from scratch

    Returns:
    LambdaLR: Learning rate scheduler
    """
```
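
The linear schedule's multiplier can be sketched the same way: it ramps up during warmup, then falls in a straight line from 1 to 0 between the end of warmup and `t_total`. The function name is hypothetical and the clamping at zero is an assumption about behavior past `t_total`:

```python
def warmup_linear_multiplier(step, warmup_steps, t_total):
    """Factor applied to the base learning rate at `step` (hypothetical helper)."""
    if step < warmup_steps:
        return step / warmup_steps  # linear ramp from 0 to 1
    # Linear decay from 1 at the end of warmup to 0 at t_total, clamped at 0 afterwards
    return max(0.0, (t_total - step) / max(1, t_total - warmup_steps))


base_lr = 2e-5
lrs = [base_lr * warmup_linear_multiplier(s, warmup_steps=100, t_total=1000)
       for s in (0, 50, 100, 550, 1000, 1200)]
print(lrs)
```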

#### WarmupCosineSchedule

Linear warmup followed by cosine decay. With the default `cycles=0.5`, the learning rate follows half a cosine period, falling from its peak to zero at `t_total`.

```python { .api }
def WarmupCosineSchedule(optimizer, warmup_steps, t_total, cycles=0.5, last_epoch=-1):
    """
    Create a schedule with linear warmup followed by cosine annealing.

    Parameters:
    - optimizer: Wrapped optimizer
    - warmup_steps (int): Number of warmup steps
    - t_total (int): Total number of training steps
    - cycles (float): Number of cosine periods after warmup (default 0.5: a single half-cosine decay to zero)
    - last_epoch (int): Index of the last step taken, for resuming a schedule; -1 starts from scratch

    Returns:
    LambdaLR: Learning rate scheduler
    """
```
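
A sketch of the cosine multiplier, assuming `progress` measures the fraction of the post-warmup phase completed and the cosine argument is `2 * pi * cycles * progress`, so that the default `cycles=0.5` gives one half-cosine from 1 down to 0. The function name is hypothetical:

```python
import math


def warmup_cosine_multiplier(step, warmup_steps, t_total, cycles=0.5):
    """Factor applied to the base learning rate at `step` (hypothetical helper)."""
    if step < warmup_steps:
        return step / warmup_steps  # linear ramp from 0 to 1
    # Fraction of the post-warmup phase completed, from 0 to 1
    progress = (step - warmup_steps) / max(1, t_total - warmup_steps)
    # cycles=0.5 traces half a cosine period: 1 at progress 0, 0 at progress 1
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * cycles * 2.0 * progress)))


for step in (100, 325, 550, 775, 1000):
    print(step, round(warmup_cosine_multiplier(step, warmup_steps=100, t_total=1000), 3))
```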

#### WarmupCosineWithHardRestartsSchedule

Linear warmup followed by cosine annealing with hard restarts: the post-warmup phase is split into `cycles` cosine decays, and at the start of each one the learning rate jumps back to its peak.

```python { .api }
def WarmupCosineWithHardRestartsSchedule(optimizer, warmup_steps, t_total, cycles=1.0, last_epoch=-1):
    """
    Create a schedule with linear warmup followed by cosine annealing with hard restarts.

    Parameters:
    - optimizer: Wrapped optimizer
    - warmup_steps (int): Number of warmup steps
    - t_total (int): Total number of training steps
    - cycles (float): Number of cosine decay cycles (hard restarts) after warmup
    - last_epoch (int): Index of the last step taken, for resuming a schedule; -1 starts from scratch

    Returns:
    LambdaLR: Learning rate scheduler
    """
```
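
A sketch of the hard-restart multiplier, assuming each cycle is a half-cosine from 1 to 0 and the position within the current cycle wraps back to 0 at every restart. The function name is hypothetical and the exact formula is an assumption:

```python
import math


def warmup_cosine_hard_restarts_multiplier(step, warmup_steps, t_total, cycles=1.0):
    """Factor applied to the base learning rate at `step` (hypothetical helper)."""
    if step < warmup_steps:
        return step / warmup_steps  # linear ramp from 0 to 1
    progress = (step - warmup_steps) / max(1, t_total - warmup_steps)
    if progress >= 1.0:
        return 0.0  # past the end of training
    # Position within the current cycle; it wraps from just below 1 back to 0 at each restart
    cycle_progress = (cycles * progress) % 1.0
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * cycle_progress)))


# Two cycles over 100 steps with no warmup: the factor falls toward 0, then jumps back to 1 at step 50
for step in (0, 25, 49, 50, 75, 99, 100):
    print(step, round(warmup_cosine_hard_restarts_multiplier(step, 0, 100, cycles=2.0), 3))
```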

**Usage Examples:**

```python
from pytorch_transformers import (
    AdamW,
    BertForSequenceClassification,
    WarmupConstantSchedule,
    WarmupCosineSchedule,
    WarmupLinearSchedule,
)

# Set up model and optimizer
model = BertForSequenceClassification.from_pretrained("bert-base-uncased")
optimizer = AdamW(model.parameters(), lr=2e-5)

# Training configuration
num_epochs = 3
num_training_steps = 1000
warmup_steps = 100

# Attach one scheduler per optimizer: each scheduler.step() overwrites the optimizer's learning rate
schedule_type = "linear"
if schedule_type == "linear":
    # Linear warmup, then linear decay to zero at t_total
    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=warmup_steps, t_total=num_training_steps)
elif schedule_type == "cosine":
    # Linear warmup, then half-cosine decay (cycles=0.5)
    scheduler = WarmupCosineSchedule(
        optimizer, warmup_steps=warmup_steps, t_total=num_training_steps, cycles=0.5
    )
else:
    # Linear warmup, then a constant learning rate
    scheduler = WarmupConstantSchedule(optimizer, warmup_steps=warmup_steps)

# Training loop example
steps_per_epoch = num_training_steps // num_epochs
for epoch in range(num_epochs):
    for step in range(steps_per_epoch):
        optimizer.zero_grad()
        # ... forward pass, loss calculation, loss.backward() ...
        optimizer.step()
        scheduler.step()  # update the learning rate after every optimizer step

        # Log the current learning rate
        if step % 100 == 0:
            current_lr = optimizer.param_groups[0]["lr"]
            print(f"Epoch {epoch}, Step {step}, LR: {current_lr:.2e}")
```

## Optimization Best Practices

### Learning Rate Selection

**Fine-tuning Pre-trained Models:**
- BERT/RoBERTa: 2e-5, 3e-5, 5e-5
- GPT-2: 1e-4, 2e-4, 5e-4
- Smaller models: Higher learning rates (up to 1e-3)

**Warmup Steps:**
- Typically 10% of total training steps
- For short training: 500-1000 steps
- For long training: 5000-10000 steps

```python
from pytorch_transformers import AdamW, WarmupLinearSchedule

# Recommended setup for BERT fine-tuning
# Assumes `model`, `train_dataloader`, and `num_epochs` are already defined
total_steps = len(train_dataloader) * num_epochs  # one optimizer step per batch
warmup_steps = int(0.1 * total_steps)  # warm up over the first 10% of training

optimizer = AdamW(
    model.parameters(),
    lr=2e-5,
    weight_decay=0.01,
    correct_bias=False  # match the original BERT TensorFlow optimizer
)

scheduler = WarmupLinearSchedule(
    optimizer,
    warmup_steps=warmup_steps,
    t_total=total_steps
)
```

### Weight Decay Configuration

**Recommended weight decay values:**
- General starting point: 0.01
- Larger models: 0.1
- Smaller models: 0.001

**Parameter groups with different weight decay:**

```python
from pytorch_transformers import AdamW

# Apply weight decay only to weights, not to biases or LayerNorm weights
# Assumes `model` is already defined
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    {
        # Parameters whose names match none of the no_decay patterns
        "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        "weight_decay": 0.01,
    },
    {
        # Biases and LayerNorm weights
        "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]

optimizer = AdamW(optimizer_grouped_parameters, lr=2e-5)
```
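
The grouping relies on substring matching against parameter names. The sketch below applies the same test to a few hand-written names in the style of BERT's `named_parameters()` output (illustrative only, not read from a model), showing that every bias and every LayerNorm weight lands in the no-decay group:

```python
no_decay = ["bias", "LayerNorm.weight"]

# Hand-written example names in the style of BERT's named_parameters() (illustrative)
names = [
    "bert.embeddings.word_embeddings.weight",
    "bert.embeddings.LayerNorm.weight",
    "bert.embeddings.LayerNorm.bias",
    "bert.encoder.layer.0.attention.self.query.weight",
    "bert.encoder.layer.0.attention.self.query.bias",
    "classifier.weight",
    "classifier.bias",
]

# Same test as the grouped-parameter comprehensions: any substring match disables decay
decay = [n for n in names if not any(nd in n for nd in no_decay)]
skip_decay = [n for n in names if any(nd in n for nd in no_decay)]

print("decay:", decay)
print("no decay:", skip_decay)
```

Because the match is by substring, a model that names its normalization layers differently (for example `ln_1` in GPT-2) needs its own entries in `no_decay`.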

### Gradient Clipping

Clip after `loss.backward()` and before `optimizer.step()`, so the optimizer sees the clipped gradients.

```python
import torch.nn.utils as nn_utils

# Training step with gradient clipping
# Assumes `model`, `optimizer`, `scheduler`, and `loss` are already defined
optimizer.zero_grad()
loss.backward()

# Rescale gradients so their combined L2 norm is at most 1.0 (guards against exploding gradients)
nn_utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

optimizer.step()
scheduler.step()
```
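
`clip_grad_norm_` rescales all gradients together by one factor whenever their combined L2 norm exceeds `max_norm`, which preserves the overall update direction. A pure-Python sketch of that rule on lists of floats instead of tensors; the helper name is hypothetical, and the `1e-6` term mirrors the small constant PyTorch adds to the norm for stability:

```python
import math


def clip_grad_norm(grads, max_norm):
    """Scale gradient lists in place so their global L2 norm is at most max_norm.

    Hypothetical helper; returns the total norm measured before clipping.
    """
    total_norm = math.sqrt(sum(g * g for vec in grads for g in vec))
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1.0:
        # One shared factor for every gradient keeps the overall direction unchanged
        for vec in grads:
            for i in range(len(vec)):
                vec[i] *= clip_coef
    return total_norm


grads = [[3.0], [4.0]]  # two "parameters" whose combined norm is 5
norm = clip_grad_norm(grads, max_norm=1.0)
print(norm, grads)
```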
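
### Mixed Precision Training

```python
from torch.cuda.amp import autocast, GradScaler

# Initialize gradient scaler for mixed precision
scaler = GradScaler()

# Training step with mixed precision
# Assumes `model`, `optimizer`, `scheduler` exist and `inputs` is a dict of tensors that includes `labels`
optimizer.zero_grad()

with autocast():
    outputs = model(**inputs)
    loss = outputs[0]  # the loss is the first element of the output tuple

# Scale the loss before backward to avoid fp16 gradient underflow
scaler.scale(loss).backward()
# To clip gradients here, call scaler.unscale_(optimizer) first, then clip_grad_norm_
scaler.step(optimizer)  # unscales gradients and skips the update if they contain inf/NaN
scaler.update()  # adjust the scale factor for the next iteration
scheduler.step()
```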
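
`GradScaler` multiplies the loss by a scale factor so small fp16 gradients do not underflow, skips any step whose gradients overflowed, and adapts the scale over time. The class below is a simplified, hypothetical model of that policy, not the PyTorch implementation: it folds the skip decision made in `scaler.step()` and the scale adjustment made in `scaler.update()` into one method. Its defaults follow `GradScaler`'s (`init_scale=2**16`, `growth_factor=2.0`, `backoff_factor=0.5`, `growth_interval=2000`).

```python
class DynamicLossScale:
    """Simplified model of dynamic loss scaling (hypothetical, for illustration)."""

    def __init__(self, init_scale=2.0 ** 16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self.good_steps = 0  # consecutive iterations without inf/NaN gradients

    def update(self, found_inf):
        """Record one iteration; return True if the optimizer step should be applied."""
        if found_inf:
            # Overflow: skip this update and back off to a smaller scale
            self.scale *= self.backoff_factor
            self.good_steps = 0
            return False
        self.good_steps += 1
        if self.good_steps == self.growth_interval:
            # A long run of clean iterations: try a larger scale
            self.scale *= self.growth_factor
            self.good_steps = 0
        return True


demo = DynamicLossScale(growth_interval=3)
for found_inf in (True, False, False, False):
    applied = demo.update(found_inf)
    print(f"found_inf={found_inf} applied={applied} scale={demo.scale}")
```

Because `scheduler.step()` is called unconditionally in the training step, a skipped iteration still advances the learning rate schedule.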

## Schedule Visualization

How each schedule shapes the learning rate after warmup:

**Linear Schedule**: Steady decrease after warmup
- Best for: Most fine-tuning tasks
- Characteristics: Predictable, stable convergence

**Cosine Schedule**: Smooth decay following a cosine curve
- Best for: Long training runs; often associated with better final performance
- Characteristics: Slow decay at first, fastest decay mid-training, and a gradual approach to zero at the end

**Constant Schedule**: Holds the rate fixed after warmup
- Best for: Continued pre-training, domain adaptation
- Characteristics: No decay, so training keeps exploring at the full learning rate

**Cosine with Restarts**: Periodic learning rate increases
- Best for: Escaping plateaus and exploring other local minima
- Characteristics: The learning rate jumps back to its peak at each restart, giving multiple convergence phases
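
The cosine schedule's slow, fast, then slow decay follows from differentiating its post-warmup multiplier. Assuming progress $p$ runs from 0 at the end of warmup to 1 at `t_total`, the default `cycles` $= 0.5$, and the half-cosine form, compare it with the linear schedule:

$$
m_{\text{cos}}(p) = \tfrac{1}{2}\bigl(1 + \cos(\pi p)\bigr), \qquad
m_{\text{cos}}'(p) = -\tfrac{\pi}{2}\sin(\pi p)
$$

$$
m_{\text{lin}}(p) = 1 - p, \qquad
m_{\text{lin}}'(p) = -1
$$

The cosine slope is zero at $p = 0$ and $p = 1$ and steepest at $p = 1/2$, where its magnitude $\pi/2 \approx 1.57$ exceeds the linear schedule's constant slope of 1.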