# Training and Loops

A complete training orchestration system. Its flexible runners support epoch-based and iteration-based training, as well as validation and test loops, with built-in checkpointing and logging. The system provides a unified interface for managing the entire training pipeline.

## Capabilities

### Runner Class

Central coordinator that manages the entire training process, with flexible configuration and automatic component initialization.

```python { .api }
class Runner:
    """
    Central coordinator for training, validation and testing.

    A Runner is configured with the model, data loaders, loop
    configurations, optimization settings, evaluators, hooks and checkpoint
    options, and exposes ``train()``, ``val()`` and ``test()`` entry points.
    """

    def __init__(self, model, work_dir: str = None, train_dataloader=None,
                 val_dataloader=None, test_dataloader=None,
                 train_cfg: dict = None, val_cfg: dict = None,
                 test_cfg: dict = None, auto_scale_lr: dict = None,
                 optim_wrapper=None, param_scheduler=None,
                 val_evaluator=None, test_evaluator=None,
                 default_hooks: dict = None, custom_hooks: list = None,
                 data_preprocessor=None, load_from: str = None,
                 resume: bool = False, launcher: str = 'none',
                 env_cfg: dict = None, log_processor=None,
                 visualizer=None, default_scope: str = 'mmengine',
                 randomness: dict = None, experiment_name: str = None,
                 cfg: dict = None):
        """
        Initialize a Runner with the complete training configuration.

        Parameters:
        - model: Model to train (torch.nn.Module or config dict)
        - work_dir: Working directory where outputs are saved
        - train_dataloader: Training data loader
        - val_dataloader: Validation data loader
        - test_dataloader: Test data loader
        - train_cfg: Training loop configuration, e.g.
          dict(type='EpochBasedTrainLoop', max_epochs=100)
        - val_cfg: Validation loop configuration
        - test_cfg: Test loop configuration
        - auto_scale_lr: Automatic learning-rate scaling configuration
        - optim_wrapper: Optimizer wrapper configuration, e.g.
          dict(optimizer=dict(type='SGD', lr=0.1))
        - param_scheduler: Parameter scheduler configuration
        - val_evaluator: Validation evaluator configuration
        - test_evaluator: Test evaluator configuration
        - default_hooks: Default hooks configuration
        - custom_hooks: List of custom hooks
        - data_preprocessor: Data preprocessor configuration
        - load_from: Path of a checkpoint to load
        - resume: Whether to resume training (default False)
        - launcher: Distributed launcher type (default 'none')
        - env_cfg: Environment configuration
        - log_processor: Log processor configuration
        - visualizer: Visualizer configuration
        - default_scope: Default registry scope (default 'mmengine')
        - randomness: Randomness (seeding) configuration
        - experiment_name: Experiment name
        - cfg: Complete configuration object

        NOTE(review): MMEngine rejects a partial validation setup.
        ``val_dataloader``, ``val_cfg`` and ``val_evaluator`` must be either
        all None or all provided, and the same holds for the train and test
        groups. Confirm this against the installed version.
        """

    @classmethod
    def from_cfg(cls, cfg) -> 'Runner':
        """
        Create a Runner from a configuration.

        Parameters:
        - cfg: Configuration object or dict

        Returns:
        Initialized Runner instance
        """

    def train(self):
        """Run the training loop."""

    def val(self):
        """Run the validation loop."""

    def test(self):
        """Run the test loop."""

    def call_hook(self, fn_name: str, **kwargs):
        """
        Call the hook method named ``fn_name``, forwarding ``kwargs``.

        Parameters:
        - fn_name: Hook method name, e.g. 'before_train_epoch'
        - **kwargs: Arguments passed to the hook method
        """

    def register_hook(self, hook, priority: str = 'NORMAL'):
        """
        Register a hook with the runner.

        Parameters:
        - hook: Hook instance or hook config
        - priority: Hook priority name (default 'NORMAL'; e.g. 'LOW')
        """

    def load_or_resume(self):
        """
        Load a checkpoint or resume training.

        Controlled by the ``load_from`` and ``resume`` constructor arguments.
        """

    def save_checkpoint(self, out_dir: str, filename: str = None,
                        file_client_args: dict = None,
                        save_optimizer: bool = True,
                        save_param_scheduler: bool = True,
                        meta: dict = None, by_epoch: bool = True):
        """
        Save a checkpoint.

        Parameters:
        - out_dir: Output directory
        - filename: Checkpoint filename
        - file_client_args: File client arguments
        - save_optimizer: Whether to save the optimizer state (default True)
        - save_param_scheduler: Whether to save the parameter scheduler
          state (default True)
        - meta: Additional metadata to store in the checkpoint
        - by_epoch: Whether checkpoints are counted by epoch rather than by
          iteration (default True)
        """

    @property
    def epoch(self) -> int:
        """Current epoch number."""

    @property
    def iter(self) -> int:
        """Current iteration number."""

    @property
    def max_epochs(self) -> int:
        """Maximum number of epochs."""

    @property
    def max_iters(self) -> int:
        """Maximum number of iterations."""
```

### Flexible Runner

Extended runner with additional flexibility for custom training workflows.

```python { .api }
class FlexibleRunner(Runner):
    """Runner variant with extra flexibility for custom training workflows."""

    def __init__(self, **kwargs):
        """
        Initialize FlexibleRunner with extended configuration options.

        Parameters:
        - **kwargs: Same keyword arguments as Runner, plus additional
          flexibility options
        """

    def run_loop(self, loop: 'BaseLoop'):
        """
        Run a custom training loop.

        Parameters:
        - loop: Loop instance to execute
        """
```

### Base Loop Class

Abstract base class for all training loops, providing a common interface and shared functionality.

```python { .api }
class BaseLoop:
    """Common interface shared by training, validation and test loops."""

    def __init__(self, runner, dataloader):
        """
        Initialize the base loop.

        Parameters:
        - runner: Runner instance that drives this loop
        - dataloader: Data loader iterated by the loop
        """

    def run(self):
        """Execute the loop."""

    @property
    def iter(self) -> int:
        """Current iteration number."""

    @property
    def max_iters(self) -> int:
        """Maximum number of iterations for this loop."""
```

### Training Loops

Training loops for epoch-based and iteration-based training strategies.

```python { .api }
class EpochBasedTrainLoop(BaseLoop):
    """Training loop that iterates over the data loader epoch by epoch."""

    def __init__(self, runner, dataloader, max_epochs: int,
                 val_begin: int = 1, val_interval: int = 1,
                 dynamic_intervals: list = None):
        """
        Epoch-based training loop.

        Parameters:
        - runner: Runner instance
        - dataloader: Training data loader
        - max_epochs: Maximum number of epochs
        - val_begin: Epoch at which validation starts (default 1)
        - val_interval: Validation interval in epochs (default 1)
        - dynamic_intervals: Dynamic validation intervals; presumably a list
          of (milestone, interval) pairs. TODO: confirm.
        """

    def run_epoch(self):
        """Run one training epoch."""

    def run_iter(self, idx: int, data_batch):
        """
        Run one training iteration.

        Parameters:
        - idx: Iteration index within the current epoch
        - data_batch: Input data batch
        """


class IterBasedTrainLoop(BaseLoop):
    """Training loop that runs for a fixed number of iterations."""

    def __init__(self, runner, dataloader, max_iters: int,
                 val_begin: int = 1, val_interval: int = 1,
                 dynamic_intervals: list = None):
        """
        Iteration-based training loop.

        Parameters:
        - runner: Runner instance
        - dataloader: Training data loader
        - max_iters: Maximum number of iterations
        - val_begin: Iteration at which validation starts (default 1)
        - val_interval: Validation interval in iterations (default 1)
        - dynamic_intervals: Dynamic validation intervals; presumably a list
          of (milestone, interval) pairs. TODO: confirm.

        NOTE(review): upstream MMEngine declares ``val_interval=1000`` for
        this loop. Confirm which default applies before relying on it.
        """

    def run_iter(self, data_batch):
        """
        Run one training iteration.

        Parameters:
        - data_batch: Input data batch
        """
```

### Validation and Test Loops

Loops for evaluating the model during training or in standalone testing.

```python { .api }
class ValLoop(BaseLoop):
    """Loop that evaluates the model on the validation data loader."""

    def __init__(self, runner, dataloader, evaluator, fp16: bool = False):
        """
        Validation loop.

        Parameters:
        - runner: Runner instance
        - dataloader: Validation data loader
        - evaluator: Evaluator that computes validation metrics
        - fp16: Whether to use FP16 precision (default False)
        """

    def run(self) -> dict:
        """
        Run the validation loop.

        Returns:
        Dictionary of validation metrics
        """


class TestLoop(BaseLoop):
    """Loop that evaluates the model on the test data loader."""

    def __init__(self, runner, dataloader, evaluator, fp16: bool = False):
        """
        Test loop.

        Parameters:
        - runner: Runner instance
        - dataloader: Test data loader
        - evaluator: Evaluator that computes test metrics
        - fp16: Whether to use FP16 precision (default False)
        """

    def run(self) -> dict:
        """
        Run the test loop.

        Returns:
        Dictionary of test metrics
        """
```

### Checkpoint Management

Functions for loading, saving, and locating checkpoints.

```python { .api }
def load_checkpoint(filename: str, map_location: str = None, logger=None,
                    revise_keys: list = None) -> dict:
    """
    Load a checkpoint from a file.

    Parameters:
    - filename: Checkpoint file path
    - map_location: Device to map loaded tensors to, e.g. 'cpu'
    - logger: Logger instance
    - revise_keys: Key revision rules applied while loading; presumably
      (pattern, replacement) pairs. TODO: confirm.

    Returns:
    Checkpoint dictionary

    NOTE(review): in current MMEngine releases, ``load_checkpoint`` takes the
    target model as its first argument and loads the weights into it. The
    dict-returning form shown here matches
    ``mmengine.runner.checkpoint._load_checkpoint``. Confirm this against the
    installed version.
    """


def save_checkpoint(model, filename: str, optimizer=None, lr_scheduler=None,
                    meta: dict = None, file_client_args: dict = None):
    """
    Save a checkpoint to a file.

    Parameters:
    - model: Model whose state is saved
    - filename: Output filename
    - optimizer: Optimizer whose state is saved, if given
    - lr_scheduler: Learning-rate scheduler whose state is saved, if given
    - meta: Additional metadata, e.g. {'epoch': 10, 'best_acc': 0.95}
    - file_client_args: File client arguments

    NOTE(review): current MMEngine's ``save_checkpoint`` takes an
    already-assembled checkpoint dict, as ``save_checkpoint(checkpoint,
    filename, ...)``. The (model, optimizer, meta) form resembles the older
    MMCV API. Confirm this against the installed version.
    """


def weights_to_cpu(state_dict: dict) -> dict:
    """
    Move the weights of a state dictionary to CPU.

    Parameters:
    - state_dict: Model state dictionary

    Returns:
    State dictionary whose weights are on CPU
    """


def get_state_dict(module, destination: dict = None, prefix: str = '',
                   keep_vars: bool = False) -> dict:
    """
    Get the state dictionary of a module.

    Parameters:
    - module: PyTorch module
    - destination: Dictionary to write entries into, if given
    - prefix: Prefix prepended to every key (default '')
    - keep_vars: Whether to keep variables, following the ``keep_vars``
      flag of ``torch.nn.Module.state_dict`` (default False)

    Returns:
    State dictionary
    """


def find_latest_checkpoint(path: str, suffix: str = 'pth') -> str:
    """
    Find the latest checkpoint in a directory.

    Parameters:
    - path: Directory to search
    - suffix: Checkpoint file suffix (default 'pth')

    Returns:
    Path of the latest checkpoint (presumably None when none is found;
    TODO: confirm)
    """
```

### Model Loading Utilities

Utilities for loading pre-trained models and querying model information.

```python { .api }
def get_torchvision_models() -> list:
    """
    Get the list of available torchvision models.

    Returns:
    List of model names
    """


def get_external_models() -> list:
    """
    Get the list of available external models.

    Returns:
    List of external model names
    """


def get_mmcls_models() -> list:
    """
    Get the list of available MMClassification models.

    Returns:
    List of MMCls model names
    """


def get_deprecated_model_names() -> list:
    """
    Get the list of deprecated model names.

    Returns:
    List of deprecated model names
    """


class CheckpointLoader:
    """Namespace for checkpoint-loading helpers."""

    @staticmethod
    def load_checkpoint(filename: str, map_location: str = None) -> dict:
        """
        Load a checkpoint with advanced options.

        Parameters:
        - filename: Checkpoint file path
        - map_location: Device mapping for loaded tensors

        Returns:
        Loaded checkpoint
        """
```

### Training Utilities

Additional utilities for reproducibility, memory efficiency, and mixed-precision training.

```python { .api }
def set_random_seed(seed: int, deterministic: bool = False,
                    diff_rank_seed: bool = False):
    """
    Set the random seed for reproducibility.

    Parameters:
    - seed: Random seed value
    - deterministic: Whether to use deterministic algorithms (default False)
    - diff_rank_seed: Whether to use a different seed on each distributed
      rank (default False)
    """


def turn_on_activation_checkpointing(model, **kwargs):
    """
    Enable activation checkpointing to reduce memory usage.

    Parameters:
    - model: Model to apply activation checkpointing to
    - **kwargs: Checkpointing configuration
    """


def autocast(*args, **kwargs):
    """
    Context manager for automatic mixed precision.

    Parameters:
    - *args: Positional arguments
    - **kwargs: Keyword arguments

    Returns:
    Autocast context manager
    """
```

## Usage Examples

### Basic Training Setup

```python
# Runner lives in mmengine.runner; the top-level mmengine package does not
# re-export it, so ``from mmengine import Runner`` fails.
from mmengine.config import Config
from mmengine.runner import Runner

# Load configuration
cfg = Config.fromfile('config.py')

# Create runner
runner = Runner.from_cfg(cfg)

# Start training
runner.train()
```

### Custom Training Loop

```python
from mmengine.runner import EpochBasedTrainLoop, Runner

# ``model``, ``train_loader``, ``val_loader`` and ``val_evaluator`` are
# assumed to be defined beforehand.

# Create runner with custom configuration
runner = Runner(
    model=model,
    work_dir='./work_dir',
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    # The loop type may be a registered name or the loop class itself.
    train_cfg=dict(type=EpochBasedTrainLoop, max_epochs=100),
    # Validation needs its dataloader, loop config and evaluator together;
    # passing only ``val_dataloader`` is rejected by the Runner.
    val_cfg=dict(type='ValLoop'),
    val_evaluator=val_evaluator,
    optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.1)),
)

# Run training
runner.train()
```

### Checkpoint Operations

```python
# find_latest_checkpoint is used below, so it must be imported too.
from mmengine.runner import (find_latest_checkpoint, load_checkpoint,
                             save_checkpoint)

# ``model`` and ``optimizer`` are assumed to be defined beforehand.
# NOTE(review): these calls follow the signatures documented in this file.
# Recent MMEngine releases use ``load_checkpoint(model, filename, ...)`` and
# ``save_checkpoint(checkpoint_dict, filename)``; confirm against the
# installed version.

# Load checkpoint
checkpoint = load_checkpoint('model.pth', map_location='cpu')

# Save checkpoint with metadata
save_checkpoint(
    model,
    'checkpoint.pth',
    optimizer=optimizer,
    meta={'epoch': 10, 'best_acc': 0.95},
)

# Find latest checkpoint
latest_ckpt = find_latest_checkpoint('./checkpoints')
```

### Custom Hook Registration

```python
from mmengine.config import Config
from mmengine.hooks import Hook
from mmengine.runner import Runner


class CustomHook(Hook):
    """Print the current epoch number before each training epoch."""

    def before_train_epoch(self, runner):
        print(f"Starting epoch {runner.epoch}")


# ``cfg`` was previously undefined in this example; load it explicitly.
cfg = Config.fromfile('config.py')

runner = Runner.from_cfg(cfg)
runner.register_hook(CustomHook(), priority='LOW')
runner.train()
```