# Training and Optimization

Optimizers, learning rate schedulers, and training utilities for model optimization and parameter updates. The torch.optim module provides optimization algorithms and learning rate scheduling strategies.

## Capabilities

### Optimizers

Optimization algorithms for updating model parameters during training.
```python { .api }
class Optimizer:
    """Base class for all optimizers."""
    def __init__(self, params, defaults): ...
    def state_dict(self):
        """Return optimizer state dictionary."""
    def load_state_dict(self, state_dict):
        """Load optimizer state."""
    def zero_grad(self, set_to_none: bool = False):
        """Set gradients to zero."""
    def step(self, closure=None):
        """Perform optimization step."""
    def add_param_group(self, param_group):
        """Add parameter group."""
```
### SGD Optimizers

Stochastic Gradient Descent and variants.

```python { .api }
class SGD(Optimizer):
    """Stochastic Gradient Descent optimizer."""
    def __init__(self, params, lr, momentum=0, dampening=0, weight_decay=0, nesterov=False):
        """
        Parameters:
        - params: Iterable of parameters or parameter groups
        - lr: Learning rate
        - momentum: Momentum factor (default: 0)
        - dampening: Dampening for momentum (default: 0)
        - weight_decay: Weight decay (L2 penalty) (default: 0)
        - nesterov: Enable Nesterov momentum (default: False)
        """
    def step(self, closure=None): ...

class ASGD(Optimizer):
    """Averaged Stochastic Gradient Descent."""
    def __init__(self, params, lr=1e-2, lambd=1e-4, alpha=0.75, t0=1e6, weight_decay=0):
        """
        Parameters:
        - params: Iterable of parameters
        - lr: Learning rate (default: 1e-2)
        - lambd: Decay term (default: 1e-4)
        - alpha: Power for eta update (default: 0.75)
        - t0: Point at which to start averaging (default: 1e6)
        - weight_decay: Weight decay (default: 0)
        """
    def step(self, closure=None): ...
```
### Adam-family Optimizers

Adam and its variants for adaptive learning rates.

```python { .api }
class Adam(Optimizer):
    """Adam optimizer."""
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False):
        """
        Parameters:
        - params: Iterable of parameters
        - lr: Learning rate (default: 1e-3)
        - betas: Coefficients for momentum and squared gradient averaging (default: (0.9, 0.999))
        - eps: Term for numerical stability (default: 1e-8)
        - weight_decay: Weight decay (default: 0)
        - amsgrad: Use AMSGrad variant (default: False)
        """
    def step(self, closure=None): ...

class AdamW(Optimizer):
    """AdamW optimizer with decoupled weight decay."""
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False):
        """
        Parameters:
        - params: Iterable of parameters
        - lr: Learning rate (default: 1e-3)
        - betas: Coefficients for momentum and squared gradient averaging
        - eps: Term for numerical stability
        - weight_decay: Weight decay coefficient (default: 1e-2)
        - amsgrad: Use AMSGrad variant
        """
    def step(self, closure=None): ...

class Adamax(Optimizer):
    """Adamax optimizer (Adam based on infinity norm)."""
    def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        """
        Parameters:
        - params: Iterable of parameters
        - lr: Learning rate (default: 2e-3)
        - betas: Coefficients for momentum and squared gradient averaging
        - eps: Term for numerical stability
        - weight_decay: Weight decay
        """
    def step(self, closure=None): ...

class NAdam(Optimizer):
    """NAdam optimizer (Adam with Nesterov momentum)."""
    def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, momentum_decay=4e-3):
        """
        Parameters:
        - params: Iterable of parameters
        - lr: Learning rate (default: 2e-3)
        - betas: Coefficients for momentum and squared gradient averaging
        - eps: Term for numerical stability
        - weight_decay: Weight decay
        - momentum_decay: Momentum decay
        """
    def step(self, closure=None): ...

class RAdam(Optimizer):
    """RAdam optimizer (Rectified Adam)."""
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        """
        Parameters:
        - params: Iterable of parameters
        - lr: Learning rate (default: 1e-3)
        - betas: Coefficients for momentum and squared gradient averaging
        - eps: Term for numerical stability
        - weight_decay: Weight decay
        """
    def step(self, closure=None): ...
```
### Adaptive Learning Rate Optimizers

Optimizers that adapt learning rates based on gradient history.

```python { .api }
class Adagrad(Optimizer):
    """Adagrad optimizer."""
    def __init__(self, params, lr=1e-2, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10):
        """
        Parameters:
        - params: Iterable of parameters
        - lr: Learning rate (default: 1e-2)
        - lr_decay: Learning rate decay (default: 0)
        - weight_decay: Weight decay (default: 0)
        - initial_accumulator_value: Initial value for accumulator
        - eps: Term for numerical stability
        """
    def step(self, closure=None): ...

class Adadelta(Optimizer):
    """Adadelta optimizer."""
    def __init__(self, params, lr=1.0, rho=0.9, eps=1e-6, weight_decay=0):
        """
        Parameters:
        - params: Iterable of parameters
        - lr: Coefficient that scales delta (default: 1.0)
        - rho: Coefficient for squared gradient averaging (default: 0.9)
        - eps: Term for numerical stability (default: 1e-6)
        - weight_decay: Weight decay (default: 0)
        """
    def step(self, closure=None): ...

class RMSprop(Optimizer):
    """RMSprop optimizer."""
    def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False):
        """
        Parameters:
        - params: Iterable of parameters
        - lr: Learning rate (default: 1e-2)
        - alpha: Smoothing constant (default: 0.99)
        - eps: Term for numerical stability (default: 1e-8)
        - weight_decay: Weight decay (default: 0)
        - momentum: Momentum factor (default: 0)
        - centered: Compute centered RMSprop (default: False)
        """
    def step(self, closure=None): ...

class Rprop(Optimizer):
    """Rprop optimizer."""
    def __init__(self, params, lr=1e-2, etas=(0.5, 1.2), step_sizes=(1e-6, 50)):
        """
        Parameters:
        - params: Iterable of parameters
        - lr: Learning rate (default: 1e-2)
        - etas: Pair of (etaminus, etaplus) for multiplicative increase/decrease
        - step_sizes: Pair of minimal and maximal allowed step sizes
        """
    def step(self, closure=None): ...
```
### Advanced Optimizers

Specialized optimization algorithms.

```python { .api }
class LBFGS(Optimizer):
    """Limited-memory BFGS optimizer."""
    def __init__(self, params, lr=1, max_iter=20, max_eval=None, tolerance_grad=1e-7,
                 tolerance_change=1e-9, history_size=100, line_search_fn=None):
        """
        Parameters:
        - params: Iterable of parameters
        - lr: Learning rate (default: 1)
        - max_iter: Maximum number of iterations per optimization step
        - max_eval: Maximum number of function evaluations per step
        - tolerance_grad: Termination tolerance on first order optimality
        - tolerance_change: Termination tolerance on function/parameter changes
        - history_size: Update history size
        - line_search_fn: Line search function ('strong_wolfe' or None)
        """
    def step(self, closure): ...

class SparseAdam(Optimizer):
    """Adam optimizer for sparse tensors."""
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
        """
        Parameters:
        - params: Iterable of parameters
        - lr: Learning rate (default: 1e-3)
        - betas: Coefficients for momentum and squared gradient averaging
        - eps: Term for numerical stability
        """
    def step(self, closure=None): ...

class Adafactor(Optimizer):
    """Adafactor optimizer for memory-efficient training."""
    def __init__(self, params, lr=None, eps2=1e-30, cliping_threshold=1.0, decay_rate=-0.8,
                 beta1=None, weight_decay=0.0, scale_parameter=True, relative_step=True):
        """
        Parameters:
        - params: Iterable of parameters
        - lr: Learning rate (None for automatic scaling)
        - eps2: Regularization constant for second moment
        - cliping_threshold: Threshold of root mean square of final gradient update
        - decay_rate: Coefficient for moving average of squared gradient
        - beta1: Coefficient for moving average of gradient
        - weight_decay: Weight decay
        - scale_parameter: Scale learning rate by root mean square of parameter
        - relative_step: Set learning rate relative to current step
        """
    def step(self, closure=None): ...
```
### Learning Rate Schedulers

Learning rate scheduling strategies for training optimization.

```python { .api }
class LRScheduler:
    """Base class for learning rate schedulers."""
    def __init__(self, optimizer, last_epoch=-1, verbose=False): ...
    def state_dict(self):
        """Return scheduler state dictionary."""
    def load_state_dict(self, state_dict):
        """Load scheduler state."""
    def get_last_lr(self):
        """Return last computed learning rates."""
    def step(self, epoch=None):
        """Update learning rates."""

class StepLR(LRScheduler):
    """Decay learning rate by gamma every step_size epochs."""
    def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1, verbose=False):
        """
        Parameters:
        - optimizer: Wrapped optimizer
        - step_size: Period of learning rate decay
        - gamma: Multiplicative factor of learning rate decay (default: 0.1)
        - last_epoch: Index of last epoch (default: -1)
        - verbose: Print message on every update (default: False)
        """

class MultiStepLR(LRScheduler):
    """Decay learning rate by gamma at specified milestones."""
    def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1, verbose=False):
        """
        Parameters:
        - optimizer: Wrapped optimizer
        - milestones: List of epoch indices for decay
        - gamma: Multiplicative factor of learning rate decay
        - last_epoch: Index of last epoch
        - verbose: Print message on every update
        """

class ExponentialLR(LRScheduler):
    """Decay learning rate by gamma every epoch."""
    def __init__(self, optimizer, gamma, last_epoch=-1, verbose=False):
        """
        Parameters:
        - optimizer: Wrapped optimizer
        - gamma: Multiplicative factor of learning rate decay
        - last_epoch: Index of last epoch
        - verbose: Print message on every update
        """

class CosineAnnealingLR(LRScheduler):
    """Cosine annealing learning rate schedule."""
    def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1, verbose=False):
        """
        Parameters:
        - optimizer: Wrapped optimizer
        - T_max: Maximum number of iterations
        - eta_min: Minimum learning rate (default: 0)
        - last_epoch: Index of last epoch
        - verbose: Print message on every update
        """

class CosineAnnealingWarmRestarts(LRScheduler):
    """Cosine annealing with warm restarts."""
    def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1, verbose=False):
        """
        Parameters:
        - optimizer: Wrapped optimizer
        - T_0: Number of iterations for first restart
        - T_mult: Factor to increase T_i after restart (default: 1)
        - eta_min: Minimum learning rate (default: 0)
        - last_epoch: Index of last epoch
        - verbose: Print message on every update
        """

class ReduceLROnPlateau:
    """Reduce learning rate when metric stops improving."""
    def __init__(self, optimizer, mode='min', factor=0.1, patience=10, verbose=False,
                 threshold=1e-4, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-8):
        """
        Parameters:
        - optimizer: Wrapped optimizer
        - mode: 'min' or 'max' for metric improvement direction
        - factor: Factor to reduce learning rate (default: 0.1)
        - patience: Number of epochs with no improvement to wait
        - verbose: Print message when reducing lr
        - threshold: Threshold for measuring new optimum
        - threshold_mode: 'rel' or 'abs' for threshold comparison
        - cooldown: Number of epochs to wait before resuming normal operation
        - min_lr: Lower bound on learning rate
        - eps: Minimal decay applied to lr
        """
    def step(self, metrics, epoch=None): ...

class CyclicLR(LRScheduler):
    """Cyclical learning rate policy."""
    def __init__(self, optimizer, base_lr, max_lr, step_size_up=2000, step_size_down=None,
                 mode='triangular', gamma=1.0, scale_fn=None, scale_mode='cycle', cycle_momentum=True,
                 base_momentum=0.8, max_momentum=0.9, last_epoch=-1, verbose=False):
        """
        Parameters:
        - optimizer: Wrapped optimizer
        - base_lr: Lower learning rate boundary
        - max_lr: Upper learning rate boundary
        - step_size_up: Number of training iterations in increasing half
        - step_size_down: Number of training iterations in decreasing half
        - mode: 'triangular', 'triangular2', or 'exp_range'
        - gamma: Constant in 'exp_range' scaling function
        - scale_fn: Custom scaling policy function
        - scale_mode: 'cycle' or 'iterations'
        - cycle_momentum: Cycle momentum inversely to learning rate
        - base_momentum: Lower momentum boundary
        - max_momentum: Upper momentum boundary
        - last_epoch: Index of last epoch
        - verbose: Print message on every update
        """

class OneCycleLR(LRScheduler):
    """One cycle learning rate policy."""
    def __init__(self, optimizer, max_lr, total_steps=None, epochs=None, steps_per_epoch=None,
                 pct_start=0.3, anneal_strategy='cos', cycle_momentum=True, base_momentum=0.85,
                 max_momentum=0.95, div_factor=25.0, final_div_factor=1e4, three_phase=False, last_epoch=-1, verbose=False):
        """
        Parameters:
        - optimizer: Wrapped optimizer
        - max_lr: Upper learning rate boundary
        - total_steps: Total number of steps in cycle
        - epochs: Number of epochs (alternative to total_steps)
        - steps_per_epoch: Steps per epoch (with epochs)
        - pct_start: Percentage of cycle spent increasing learning rate
        - anneal_strategy: 'cos' or 'linear' annealing strategy
        - cycle_momentum: Cycle momentum inversely to learning rate
        - base_momentum: Lower momentum boundary
        - max_momentum: Upper momentum boundary
        - div_factor: Determines initial learning rate (max_lr/div_factor)
        - final_div_factor: Determines minimum learning rate (max_lr/(div_factor*final_div_factor))
        - three_phase: Use three phase schedule
        - last_epoch: Index of last epoch
        - verbose: Print message on every update
        """
```
### Gradient Processing

Utilities for gradient manipulation and processing.

```python { .api }
def clip_grad_norm_(parameters, max_norm, norm_type=2.0, error_if_nonfinite=False):
    """
    Clip gradient norm of parameters.

    Parameters:
    - parameters: Iterable of parameters or single tensor
    - max_norm: Maximum norm of gradients
    - norm_type: Type of norm (default: 2.0)
    - error_if_nonfinite: Raise error if total norm is NaN or inf

    Returns:
    Total norm of the parameters
    """

def clip_grad_value_(parameters, clip_value):
    """
    Clip gradient values to specified range.

    Parameters:
    - parameters: Iterable of parameters or single tensor
    - clip_value: Maximum absolute value for gradients
    """
```
### Stochastic Weight Averaging

Utilities for stochastic weight averaging to improve generalization.

```python { .api }
class AveragedModel(nn.Module):
    """Averaged model for stochastic weight averaging."""
    def __init__(self, model, device=None, avg_fn=None, use_buffers=False):
        """
        Parameters:
        - model: Model to average
        - device: Device to store averaged parameters
        - avg_fn: Function to compute running average
        - use_buffers: Whether to average buffers
        """
    def update_parameters(self, model): ...

class SWALR(LRScheduler):
    """Learning rate scheduler for stochastic weight averaging."""
    def __init__(self, optimizer, swa_lr, anneal_epochs=10, anneal_strategy='cos', last_epoch=-1):
        """
        Parameters:
        - optimizer: Wrapped optimizer
        - swa_lr: SWA learning rate
        - anneal_epochs: Number of epochs for annealing (default: 10)
        - anneal_strategy: 'cos' or 'linear' annealing strategy
        - last_epoch: Index of last epoch
        """
```
## Usage Examples

### Basic Training Loop

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# Setup model, loss, and optimizer
model = nn.Sequential(
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
def train_epoch(model, dataloader, criterion, optimizer):
    model.train()
    total_loss = 0

    for batch_idx, (data, targets) in enumerate(dataloader):
        # Zero gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(data)
        loss = criterion(outputs, targets)

        # Backward pass
        loss.backward()

        # Gradient clipping (optional)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        # Update parameters
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(dataloader)

# Example usage
# train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
# loss = train_epoch(model, train_loader, criterion, optimizer)
# print(f"Training loss: {loss:.4f}")
```
### Learning Rate Scheduling

```python
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau

# Setup optimizer and scheduler
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
scheduler = StepLR(optimizer, step_size=30, gamma=0.1)

# Alternative: Reduce on plateau
# scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

# Training loop with scheduler
for epoch in range(100):
    train_loss = train_epoch(model, train_loader, criterion, optimizer)
    val_loss = validate(model, val_loader, criterion)

    # Step scheduler
    scheduler.step()  # For StepLR
    # scheduler.step(val_loss)  # For ReduceLROnPlateau

    current_lr = optimizer.param_groups[0]['lr']
    print(f"Epoch {epoch}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, LR: {current_lr:.6f}")
```
### Advanced Optimization with Multiple Parameter Groups

```python
import torch
import torch.optim as optim

# Different learning rates for different parts of the model
model = nn.Sequential(
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)

# Create parameter groups
params = [
    {'params': model[0].parameters(), 'lr': 0.001},  # First layer
    {'params': model[2].parameters(), 'lr': 0.01}    # Last layer
]

optimizer = optim.Adam(params, weight_decay=1e-4)

# Training with different learning rates
for epoch in range(100):
    for batch_idx, (data, targets) in enumerate(train_loader):
        optimizer.zero_grad()
        outputs = model(data)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
```
### Stochastic Weight Averaging

```python
import torch
import torch.optim as optim
from torch.optim.swa_utils import AveragedModel, SWALR

# Setup model and optimizer
model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 1))
optimizer = optim.SGD(model.parameters(), lr=0.1)

# Create averaged model and SWA scheduler
swa_model = AveragedModel(model)
swa_scheduler = SWALR(optimizer, swa_lr=0.05)

# Training with SWA
swa_start_epoch = 80
for epoch in range(100):
    train_loss = train_epoch(model, train_loader, criterion, optimizer)

    if epoch >= swa_start_epoch:
        swa_model.update_parameters(model)
        swa_scheduler.step()
    else:
        # Regular scheduler before SWA
        regular_scheduler.step()

    print(f"Epoch {epoch}: Loss: {train_loss:.4f}")

# Update SWA batch normalization statistics
torch.optim.swa_utils.update_bn(train_loader, swa_model)

# Use SWA model for inference
swa_model.eval()
```
### One Cycle Learning Rate Policy

```python
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import OneCycleLR

# Setup optimizer
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# One cycle scheduler
steps_per_epoch = len(train_loader)
scheduler = OneCycleLR(
    optimizer,
    max_lr=0.1,
    epochs=100,
    steps_per_epoch=steps_per_epoch,
    pct_start=0.3,
    div_factor=25,
    final_div_factor=1e4
)

# Training loop
for epoch in range(100):
    for batch_idx, (data, targets) in enumerate(train_loader):
        optimizer.zero_grad()
        outputs = model(data)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # Step after each batch
        scheduler.step()

    print(f"Epoch {epoch}: LR: {optimizer.param_groups[0]['lr']:.6f}")
```