# Optimization and Scheduling

Comprehensive optimization framework with support for multiple optimizers, learning rate schedulers, momentum schedulers, automatic mixed precision, and gradient accumulation strategies. The system provides flexible optimization configurations for various training scenarios.

## Capabilities

### Optimizer Wrappers

Wrapper classes that provide a unified interface for different optimization strategies.

```python { .api }
class BaseOptimWrapper:
    def __init__(self, optimizer, accumulative_counts: int = 1, clip_grad: dict = None):
        """
        Base optimizer wrapper.

        Parameters:
        - optimizer: PyTorch optimizer instance
        - accumulative_counts: Number of steps for gradient accumulation
        - clip_grad: Gradient clipping configuration
        """

    def update_params(self, loss):
        """
        Update model parameters.

        Parameters:
        - loss: Loss tensor for backward pass
        """

    def zero_grad(self):
        """Zero gradients."""

    def step(self):
        """Optimizer step."""

    def backward(self, loss):
        """
        Backward pass.

        Parameters:
        - loss: Loss tensor
        """

    def get_lr(self) -> dict:
        """
        Get current learning rates.

        Returns:
        Dictionary of learning rates
        """

    def get_momentum(self) -> dict:
        """
        Get current momentum values.

        Returns:
        Dictionary of momentum values
        """

    @property
    def param_groups(self) -> list:
        """Parameter groups."""

class OptimWrapper(BaseOptimWrapper):
    def __init__(self, optimizer, accumulative_counts: int = 1, clip_grad: dict = None):
        """
        Standard optimizer wrapper.

        Parameters:
        - optimizer: PyTorch optimizer
        - accumulative_counts: Gradient accumulation steps
        - clip_grad: Gradient clipping config
        """

class AmpOptimWrapper(BaseOptimWrapper):
    def __init__(self, loss_scale: str = 'dynamic', **kwargs):
        """
        Automatic mixed precision optimizer wrapper.

        Parameters:
        - loss_scale: Loss scaling strategy ('dynamic' or float value)
        - **kwargs: Base wrapper arguments
        """

    def backward(self, loss):
        """Scaled backward pass for AMP."""

class ApexOptimWrapper(BaseOptimWrapper):
    def __init__(self, **kwargs):
        """
        Apex optimizer wrapper for FP16 training.

        Parameters:
        - **kwargs: Base wrapper arguments
        """

class OptimWrapperDict:
    def __init__(self, **kwargs):
        """
        Dictionary of optimizer wrappers for multi-optimizer training.

        Parameters:
        - **kwargs: Named optimizer wrapper configurations
        """

    def update_params(self, loss_dict: dict):
        """
        Update parameters for multiple optimizers.

        Parameters:
        - loss_dict: Dictionary of losses for each optimizer
        """

    def zero_grad(self):
        """Zero gradients for all optimizers."""

    def step(self):
        """Step for all optimizers."""
```

### Learning Rate Schedulers

Comprehensive collection of learning rate scheduling strategies.

```python { .api }
class ConstantLR:
    def __init__(self, factor: float = 1.0, **kwargs):
        """
        Constant learning rate scheduler.

        Parameters:
        - factor: Multiplicative factor for learning rate
        - **kwargs: Base scheduler arguments
        """

class CosineAnnealingLR:
    def __init__(self, T_max: int, eta_min: float = 0, **kwargs):
        """
        Cosine annealing learning rate scheduler.

        Parameters:
        - T_max: Maximum number of iterations
        - eta_min: Minimum learning rate
        - **kwargs: Base scheduler arguments
        """

class ExponentialLR:
    def __init__(self, gamma: float, **kwargs):
        """
        Exponential learning rate scheduler.

        Parameters:
        - gamma: Multiplicative factor of learning rate decay
        - **kwargs: Base scheduler arguments
        """

class LinearLR:
    def __init__(self, start_factor: float = 1.0, end_factor: float = 0.0, total_iters: int = 5, **kwargs):
        """
        Linear learning rate scheduler.

        Parameters:
        - start_factor: Starting multiplicative factor
        - end_factor: Ending multiplicative factor
        - total_iters: Number of iterations for linear decay
        - **kwargs: Base scheduler arguments
        """

class MultiStepLR:
    def __init__(self, milestones: list, gamma: float = 0.1, **kwargs):
        """
        Multi-step learning rate scheduler.

        Parameters:
        - milestones: List of epoch indices for LR decay
        - gamma: Multiplicative factor of learning rate decay
        - **kwargs: Base scheduler arguments
        """

class StepLR:
    def __init__(self, step_size: int, gamma: float = 0.1, **kwargs):
        """
        Step learning rate scheduler.

        Parameters:
        - step_size: Period of learning rate decay
        - gamma: Multiplicative factor of learning rate decay
        - **kwargs: Base scheduler arguments
        """

class OneCycleLR:
    def __init__(self, max_lr: float, total_steps: int = None, epochs: int = None, steps_per_epoch: int = None, pct_start: float = 0.3, anneal_strategy: str = 'cos', cycle_momentum: bool = True, base_momentum: float = 0.85, max_momentum: float = 0.95, div_factor: float = 25.0, final_div_factor: float = 10000.0, **kwargs):
        """
        One cycle learning rate scheduler.

        Parameters:
        - max_lr: Upper learning rate boundaries
        - total_steps: Total number of steps
        - epochs: Number of epochs
        - steps_per_epoch: Steps per epoch
        - pct_start: Percentage of cycle spent increasing learning rate
        - anneal_strategy: Annealing strategy ('cos' or 'linear')
        - cycle_momentum: Whether to cycle momentum
        - base_momentum: Lower momentum boundary
        - max_momentum: Upper momentum boundary
        - div_factor: Initial learning rate divisor
        - final_div_factor: Final learning rate divisor
        - **kwargs: Base scheduler arguments
        """

class PolyLR:
    def __init__(self, power: float = 1.0, min_lr: float = 0.0, **kwargs):
        """
        Polynomial learning rate scheduler.

        Parameters:
        - power: Polynomial power
        - min_lr: Minimum learning rate
        - **kwargs: Base scheduler arguments
        """

class ReduceOnPlateauLR:
    def __init__(self, mode: str = 'min', factor: float = 0.1, patience: int = 10, threshold: float = 1e-4, threshold_mode: str = 'rel', cooldown: int = 0, min_lr: float = 0, eps: float = 1e-8, **kwargs):
        """
        Reduce on plateau learning rate scheduler.

        Parameters:
        - mode: 'min' or 'max' for monitoring metric
        - factor: Factor to reduce learning rate
        - patience: Number of epochs with no improvement
        - threshold: Threshold for measuring improvement
        - threshold_mode: 'rel' or 'abs' threshold mode
        - cooldown: Number of epochs to wait before resuming
        - min_lr: Minimum learning rate
        - eps: Minimum decay applied to learning rate
        - **kwargs: Base scheduler arguments
        """
```

### Momentum Schedulers

Schedulers for momentum parameter in optimizers.

```python { .api }
class ConstantMomentum:
    def __init__(self, factor: float = 1.0, **kwargs):
        """
        Constant momentum scheduler.

        Parameters:
        - factor: Multiplicative factor for momentum
        - **kwargs: Base scheduler arguments
        """

class CosineAnnealingMomentum:
    def __init__(self, T_max: int, eta_min: float = 0, **kwargs):
        """
        Cosine annealing momentum scheduler.

        Parameters:
        - T_max: Maximum number of iterations
        - eta_min: Minimum momentum
        - **kwargs: Base scheduler arguments
        """

class ExponentialMomentum:
    def __init__(self, gamma: float, **kwargs):
        """
        Exponential momentum scheduler.

        Parameters:
        - gamma: Multiplicative factor of momentum decay
        - **kwargs: Base scheduler arguments
        """

class LinearMomentum:
    def __init__(self, start_factor: float = 1.0, end_factor: float = 0.0, total_iters: int = 5, **kwargs):
        """
        Linear momentum scheduler.

        Parameters:
        - start_factor: Starting multiplicative factor
        - end_factor: Ending multiplicative factor
        - total_iters: Number of iterations
        - **kwargs: Base scheduler arguments
        """

class MultiStepMomentum:
    def __init__(self, milestones: list, gamma: float = 0.1, **kwargs):
        """
        Multi-step momentum scheduler.

        Parameters:
        - milestones: List of epoch indices
        - gamma: Multiplicative factor
        - **kwargs: Base scheduler arguments
        """

class StepMomentum:
    def __init__(self, step_size: int, gamma: float = 0.1, **kwargs):
        """
        Step momentum scheduler.

        Parameters:
        - step_size: Period of momentum decay
        - gamma: Multiplicative factor
        - **kwargs: Base scheduler arguments
        """
```

### Parameter Schedulers

Generic parameter scheduling framework for any optimizer parameter.

```python { .api }
class _ParamScheduler:
    def __init__(self, optimizer, param_name: str, **kwargs):
        """
        Base parameter scheduler.

        Parameters:
        - optimizer: Optimizer instance
        - param_name: Parameter name to schedule
        - **kwargs: Scheduler arguments
        """

    def step(self):
        """Execute scheduler step."""

    def get_value(self) -> list:
        """
        Get current parameter values.

        Returns:
        List of current parameter values
        """

class ConstantParamScheduler(_ParamScheduler):
    def __init__(self, optimizer, param_name: str, factor: float = 1.0, **kwargs):
        """
        Constant parameter scheduler.

        Parameters:
        - optimizer: Optimizer instance
        - param_name: Parameter name
        - factor: Multiplicative factor
        - **kwargs: Base scheduler arguments
        """

class CosineAnnealingParamScheduler(_ParamScheduler):
    def __init__(self, optimizer, param_name: str, T_max: int, eta_min: float = 0, **kwargs):
        """
        Cosine annealing parameter scheduler.

        Parameters:
        - optimizer: Optimizer instance
        - param_name: Parameter name
        - T_max: Maximum iterations
        - eta_min: Minimum parameter value
        - **kwargs: Base scheduler arguments
        """
```

### Optimizer Constructor

Builder class for creating optimizer wrappers from configuration.

```python { .api }
class DefaultOptimWrapperConstructor:
    def __init__(self, optim_wrapper_cfg: dict, paramwise_cfg: dict = None):
        """
        Default optimizer wrapper constructor.

        Parameters:
        - optim_wrapper_cfg: Optimizer wrapper configuration
        - paramwise_cfg: Parameter-wise configuration
        """

    def __call__(self, model) -> BaseOptimWrapper:
        """
        Build optimizer wrapper for model.

        Parameters:
        - model: PyTorch model

        Returns:
        Optimizer wrapper instance
        """

def build_optim_wrapper(model, cfg: dict) -> BaseOptimWrapper:
    """
    Build optimizer wrapper from configuration.

    Parameters:
    - model: PyTorch model
    - cfg: Optimizer wrapper configuration

    Returns:
    Built optimizer wrapper
    """
```

### Zero Redundancy Optimizer

Implementation of ZeRO (Zero Redundancy Optimizer) for memory-efficient training.

```python { .api }
class ZeroRedundancyOptimizer:
    def __init__(self, params, optimizer_class, process_group=None, parameters_as_bucket_view: bool = False, overlap_with_ddp: bool = False, **defaults):
        """
        Zero Redundancy Optimizer wrapper.

        Parameters:
        - params: Model parameters
        - optimizer_class: Base optimizer class
        - process_group: Process group for distributed training
        - parameters_as_bucket_view: Whether to use bucket view
        - overlap_with_ddp: Whether to overlap with DDP
        - **defaults: Default optimizer arguments
        """

    def step(self, closure=None):
        """
        Optimizer step with gradient synchronization.

        Parameters:
        - closure: Optional closure function
        """

    def zero_grad(self):
        """Zero gradients across all processes."""

    def consolidate_state_dict(self, to: int = 0):
        """
        Consolidate optimizer state dictionary.

        Parameters:
        - to: Target rank for consolidation
        """
```

## Usage Examples

### Basic Optimizer Wrapper Usage

```python
import torch
from mmengine.optim import OptimWrapper

# Create model and optimizer
model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Wrap optimizer
optim_wrapper = OptimWrapper(
    optimizer=optimizer,
    accumulative_counts=4,  # Gradient accumulation
    clip_grad=dict(max_norm=1.0)  # Gradient clipping
)

# Training step
loss = model(data)
optim_wrapper.update_params(loss)
```

### Automatic Mixed Precision

```python
from mmengine.optim import AmpOptimWrapper

# Create AMP optimizer wrapper
optim_wrapper = AmpOptimWrapper(
    optimizer=optimizer,
    loss_scale='dynamic',
    accumulative_counts=2
)

# Training with AMP
loss = model(data)
optim_wrapper.update_params(loss)  # Automatic scaling
```

### Learning Rate Scheduling

```python
from mmengine.optim import CosineAnnealingLR, MultiStepLR

# Cosine annealing scheduler
scheduler_cfg = dict(
    type='CosineAnnealingLR',
    T_max=100,
    eta_min=1e-6,
    by_epoch=True,
    begin=0,
    end=100
)

# Multi-step scheduler
scheduler_cfg = dict(
    type='MultiStepLR',
    milestones=[30, 60, 90],
    gamma=0.1,
    by_epoch=True
)

# Use in runner configuration
runner = Runner(
    model=model,
    optim_wrapper=dict(
        optimizer=dict(type='SGD', lr=0.1, momentum=0.9),
        clip_grad=dict(max_norm=1.0)
    ),
    param_scheduler=scheduler_cfg
)
```

### Multiple Optimizers

```python
from mmengine.optim import OptimWrapperDict

# Multiple optimizer configuration
optim_wrapper_dict = OptimWrapperDict(
    generator=dict(
        optimizer=dict(type='Adam', lr=0.0002, betas=(0.5, 0.999))
    ),
    discriminator=dict(
        optimizer=dict(type='Adam', lr=0.0002, betas=(0.5, 0.999))
    )
)

# Training step with multiple losses
losses = {
    'generator': gen_loss,
    'discriminator': disc_loss
}
optim_wrapper_dict.update_params(losses)
```

### Custom Parameter-wise Configuration

```python
from mmengine.optim import DefaultOptimWrapperConstructor

# Parameter-wise learning rate configuration
paramwise_cfg = dict(
    bias_lr_mult=2.0,  # 2x learning rate for bias
    bias_decay_mult=0.0,  # No weight decay for bias
    norm_decay_mult=0.0,  # No weight decay for normalization
    custom_keys={
        '.backbone': dict(lr_mult=0.1),  # 0.1x LR for backbone
        '.head': dict(lr_mult=1.0)  # 1x LR for head
    }
)

# Create optimizer constructor
constructor = DefaultOptimWrapperConstructor(
    optim_wrapper_cfg=dict(
        type='OptimWrapper',
        optimizer=dict(type='SGD', lr=0.01, momentum=0.9)
    ),
    paramwise_cfg=paramwise_cfg
)

# Build optimizer wrapper
optim_wrapper = constructor(model)
```

### Advanced Scheduling

```python
from mmengine.optim import OneCycleLR

# One cycle learning rate policy
scheduler_cfg = dict(
    type='OneCycleLR',
    max_lr=0.1,
    total_steps=1000,
    pct_start=0.3,
    anneal_strategy='cos',
    cycle_momentum=True,
    base_momentum=0.85,
    max_momentum=0.95
)

# Polynomial learning rate decay
poly_scheduler_cfg = dict(
    type='PolyLR',
    power=0.9,
    min_lr=1e-6,
    by_epoch=False,
    begin=0,
    end=1000
)
```