0
# Models and Hooks
1
2
MMEngine provides a comprehensive model management system: base classes, weight initialization, model wrappers for distributed training, and an extensive hook system for customizing training behavior. Together, these foundation classes and utilities form the basis for building robust training pipelines.
3
4
## Capabilities
5
6
### Base Model Classes
7
8
Foundation classes for all models in MMEngine with standardized interfaces for training, validation, and testing.
9
10
```python { .api }
11
class BaseModel:
12
def __init__(self, init_cfg: dict = None, data_preprocessor: dict = None):
13
"""
14
Base class for all models.
15
16
Parameters:
17
- init_cfg: Weight initialization configuration
18
- data_preprocessor: Data preprocessor configuration
19
"""
20
21
def forward(self, *args, **kwargs):
22
"""
23
Forward pass implementation.
24
25
Parameters:
26
- *args: Input arguments
27
- **kwargs: Input keyword arguments
28
29
Returns:
30
Model outputs
31
"""
32
33
def train_step(self, data, optim_wrapper):
34
"""
35
Training step implementation.
36
37
Parameters:
38
- data: Input data batch
39
- optim_wrapper: Optimizer wrapper
40
41
Returns:
42
Dictionary containing loss and log variables
43
"""
44
45
def val_step(self, data):
46
"""
47
Validation step implementation.
48
49
Parameters:
50
- data: Input data batch
51
52
Returns:
53
Validation outputs
54
"""
55
56
def test_step(self, data):
57
"""
58
Test step implementation.
59
60
Parameters:
61
- data: Input data batch
62
63
Returns:
64
Test outputs
65
"""
66
67
def init_weights(self):
68
"""Initialize model weights."""
69
70
@property
71
def device(self):
72
"""Get model device."""
73
74
def cuda(self, device=None):
75
"""Move model to CUDA device."""
76
77
def cpu(self):
78
"""Move model to CPU."""
79
80
def train(self, mode: bool = True):
81
"""Set training mode."""
82
83
def eval(self):
84
"""Set evaluation mode."""
85
```
86
87
### Data Preprocessors
88
89
Classes for preprocessing input data before feeding to models.
90
91
```python { .api }
92
class BaseDataPreprocessor:
93
def __init__(self, mean: list = None, std: list = None, pad_size_divisor: int = 1, pad_value: float = 0, bgr_to_rgb: bool = False, rgb_to_bgr: bool = False, non_blocking: bool = False):
94
"""
95
Base data preprocessor.
96
97
Parameters:
98
- mean: Mean values for normalization
99
- std: Standard deviation values for normalization
100
- pad_size_divisor: Padding size divisor
101
- pad_value: Padding value
102
- bgr_to_rgb: Whether to convert BGR to RGB
103
- rgb_to_bgr: Whether to convert RGB to BGR
104
- non_blocking: Whether to use non-blocking data movement
105
"""
106
107
def forward(self, data: dict, training: bool = False) -> dict:
108
"""
109
Forward pass for data preprocessing.
110
111
Parameters:
112
- data: Input data dictionary
113
- training: Whether in training mode
114
115
Returns:
116
Preprocessed data
117
"""
118
119
def cast_data(self, data):
120
"""
121
Cast data to appropriate types and devices.
122
123
Parameters:
124
- data: Input data
125
126
Returns:
127
Casted data
128
"""
129
130
class ImgDataPreprocessor(BaseDataPreprocessor):
131
def __init__(self, mean: list = None, std: list = None, pad_size_divisor: int = 1, pad_value: float = 0, bgr_to_rgb: bool = False, rgb_to_bgr: bool = False, batch_augments: list = None):
132
"""
133
Image data preprocessor.
134
135
Parameters:
136
- mean: RGB mean values for normalization
137
- std: RGB std values for normalization
138
- pad_size_divisor: Padding size divisor
139
- pad_value: Padding value
140
- bgr_to_rgb: Whether to convert BGR to RGB
141
- rgb_to_bgr: Whether to convert RGB to BGR
142
- batch_augments: Batch augmentation transforms
143
"""
144
```
145
146
### Base Module Classes
147
148
Enhanced PyTorch module classes with initialization and utility features.
149
150
```python { .api }
151
class BaseModule:
152
def __init__(self, init_cfg: dict = None):
153
"""
154
Base module with weight initialization support.
155
156
Parameters:
157
- init_cfg: Initialization configuration
158
"""
159
160
def init_weights(self):
161
"""Initialize module weights."""
162
163
class ModuleDict:
164
def __init__(self, modules: dict = None):
165
"""
166
Module dictionary container.
167
168
Parameters:
169
- modules: Dictionary of modules
170
"""
171
172
def __getitem__(self, key: str):
173
"""Get module by key."""
174
175
def __setitem__(self, key: str, module):
176
"""Set module by key."""
177
178
def __delitem__(self, key: str):
179
"""Delete module by key."""
180
181
def __len__(self) -> int:
182
"""Get number of modules."""
183
184
def __iter__(self):
185
"""Iterate over module keys."""
186
187
def keys(self):
188
"""Get module keys."""
189
190
def values(self):
191
"""Get module values."""
192
193
def items(self):
194
"""Get module items."""
195
196
class ModuleList:
197
def __init__(self, modules: list = None):
198
"""
199
Module list container.
200
201
Parameters:
202
- modules: List of modules
203
"""
204
205
def __getitem__(self, idx: int):
206
"""Get module by index."""
207
208
def __setitem__(self, idx: int, module):
209
"""Set module by index."""
210
211
def __delitem__(self, idx: int):
212
"""Delete module by index."""
213
214
def __len__(self) -> int:
215
"""Get number of modules."""
216
217
def __iter__(self):
218
"""Iterate over modules."""
219
220
def append(self, module):
221
"""Append module to list."""
222
223
def extend(self, modules: list):
224
"""Extend list with modules."""
225
226
def insert(self, index: int, module):
227
"""Insert module at index."""
228
229
class Sequential:
230
def __init__(self, *args):
231
"""
232
Sequential module container.
233
234
Parameters:
235
- *args: Modules to add sequentially
236
"""
237
238
def forward(self, input):
239
"""Sequential forward pass."""
240
```
241
242
### Hook System
243
244
Comprehensive hook system for customizing training behaviors at different stages.
245
246
```python { .api }
247
class Hook:
248
priority = 'NORMAL' # Hook priority level
249
250
def before_run(self, runner):
251
"""Called once before the training, validation, or testing process starts."""
252
253
def after_run(self, runner):
254
"""Called once after the training, validation, or testing process ends."""
255
256
def before_train(self, runner):
257
"""Called before training loop."""
258
259
def after_train(self, runner):
260
"""Called after training loop."""
261
262
def before_train_epoch(self, runner):
263
"""Called before each training epoch."""
264
265
def after_train_epoch(self, runner):
266
"""Called after each training epoch."""
267
268
def before_train_iter(self, runner):
269
"""Called before each training iteration."""
270
271
def after_train_iter(self, runner):
272
"""Called after each training iteration."""
273
274
def before_val(self, runner):
275
"""Called before validation."""
276
277
def after_val(self, runner):
278
"""Called after validation."""
279
280
def before_val_epoch(self, runner):
281
"""Called before validation epoch."""
282
283
def after_val_epoch(self, runner):
284
"""Called after validation epoch."""
285
286
def before_val_iter(self, runner):
287
"""Called before validation iteration."""
288
289
def after_val_iter(self, runner):
290
"""Called after validation iteration."""
291
292
def before_save_checkpoint(self, runner, checkpoint: dict):
293
"""Called before saving checkpoint."""
294
295
def after_load_checkpoint(self, runner, checkpoint: dict):
296
"""Called after loading checkpoint."""
297
298
def before_test(self, runner):
299
"""Called before testing."""
300
301
def after_test(self, runner):
302
"""Called after testing."""
303
```
304
305
### Built-in Hooks
306
307
Collection of commonly used hooks for various training scenarios.
308
309
```python { .api }
310
class CheckpointHook(Hook):
311
def __init__(self, interval: int = -1, by_epoch: bool = True, save_optimizer: bool = True, save_param_scheduler: bool = True, out_dir: str = None, max_keep_ckpts: int = -1, save_last: bool = True, save_best: str = 'auto', rule: str = 'greater', greater_keys: list = None, less_keys: list = None, file_client_args: dict = None, published_keys: list = None):
312
"""
313
Hook for saving checkpoints.
314
315
Parameters:
316
- interval: Save interval
317
- by_epoch: Whether to save by epoch
318
- save_optimizer: Whether to save optimizer state
319
- save_param_scheduler: Whether to save scheduler state
320
- out_dir: Output directory
321
- max_keep_ckpts: Maximum checkpoints to keep
322
- save_last: Whether to save last checkpoint
323
- save_best: Best checkpoint strategy
324
- rule: Comparison rule for best checkpoint
325
- greater_keys: Keys that should be greater for best
326
- less_keys: Keys that should be less for best
327
- file_client_args: File client arguments
328
- published_keys: Keys to publish in checkpoint
329
"""
330
331
class LoggerHook(Hook):
332
def __init__(self, interval: int = 10, ignore_last: bool = True, reset_flag: bool = False, by_epoch: bool = True):
333
"""
334
Hook for logging training information.
335
336
Parameters:
337
- interval: Logging interval
338
- ignore_last: Whether to ignore last incomplete interval
339
- reset_flag: Whether to reset log flag
340
- by_epoch: Whether to log by epoch
341
"""
342
343
class IterTimerHook(Hook):
344
def __init__(self):
345
"""Hook for timing training iterations."""
346
347
class DistSamplerSeedHook(Hook):
348
def __init__(self):
349
"""Hook for setting distributed sampler seed."""
350
351
class ParamSchedulerHook(Hook):
352
def __init__(self):
353
"""Hook for parameter scheduling."""
354
355
class EMAHook(Hook):
356
def __init__(self, ema_type: str = 'ExponentialMovingAverage', momentum: float = 0.0002, update_buffers: bool = False, priority: int = 49):
357
"""
358
Hook for exponential moving average.
359
360
Parameters:
361
- ema_type: Type of EMA ('ExponentialMovingAverage', 'MomentumAnnealingEMA')
362
- momentum: EMA momentum
363
- update_buffers: Whether to update buffers
364
- priority: Hook priority
365
"""
366
367
class EmptyCacheHook(Hook):
368
def __init__(self, before_epoch: bool = False, after_epoch: bool = True, after_iter: bool = False):
369
"""
370
Hook for emptying CUDA cache.
371
372
Parameters:
373
- before_epoch: Whether to empty before epoch
374
- after_epoch: Whether to empty after epoch
375
- after_iter: Whether to empty after iteration
376
"""
377
378
class SyncBuffersHook(Hook):
379
def __init__(self):
380
"""Hook for synchronizing model buffers in distributed training."""
381
382
class RuntimeInfoHook(Hook):
383
def __init__(self, enable_tensorboard: bool = True):
384
"""
385
Hook for collecting runtime information.
386
387
Parameters:
388
- enable_tensorboard: Whether to enable tensorboard logging
389
"""
390
391
class EarlyStoppingHook(Hook):
392
def __init__(self, monitor: str, min_delta: float = 0, patience: int = 5, verbose: bool = False, mode: str = 'min', baseline: float = None, restore_best_weights: bool = False):
393
"""
394
Hook for early stopping.
395
396
Parameters:
397
- monitor: Metric to monitor
398
- min_delta: Minimum change to qualify as improvement
399
- patience: Number of epochs with no improvement after which training stops
400
- verbose: Whether to print early stopping messages
401
- mode: 'min' or 'max' mode
402
- baseline: Baseline value for the monitored quantity
403
- restore_best_weights: Whether to restore best weights
404
"""
405
406
class ProfilerHook(Hook):
407
def __init__(self, by_epoch: bool = True, profile_iters: int = 1, activities: list = None, schedule: dict = None, on_trace_ready: callable = None, record_shapes: bool = False, profile_memory: bool = False, with_stack: bool = False, with_flops: bool = False, json_trace_path: str = None):
408
"""
409
Hook for profiling training performance.
410
411
Parameters:
412
- by_epoch: Whether to profile by epoch
413
- profile_iters: Number of iterations to profile
414
- activities: List of activities to profile
415
- schedule: Profiling schedule
416
- on_trace_ready: Callback for trace ready
417
- record_shapes: Whether to record tensor shapes
418
- profile_memory: Whether to profile memory
419
- with_stack: Whether to record stack traces
420
- with_flops: Whether to record FLOPs
421
- json_trace_path: Path to save JSON trace
422
"""
423
```
424
425
### Model Utilities
426
427
Utility functions for model operations and management.
428
429
```python { .api }
430
def stack_batch(tensors: list, pad_size_divisor: int = 1, pad_value: float = 0) -> torch.Tensor:
431
"""
432
Stack list of tensors into batch tensor.
433
434
Parameters:
435
- tensors: List of tensors to stack
436
- pad_size_divisor: Padding size divisor
437
- pad_value: Padding value
438
439
Returns:
440
Stacked batch tensor
441
"""
442
443
def merge_dict(*dicts: dict) -> dict:
444
"""
445
Merge multiple dictionaries.
446
447
Parameters:
448
- *dicts: Dictionaries to merge
449
450
Returns:
451
Merged dictionary
452
"""
453
454
def detect_anomalous_params(loss: torch.Tensor, model: torch.nn.Module) -> dict:
455
"""
456
Detect anomalous parameters, i.e. trainable parameters that are not included in the computational graph of the loss.
457
458
Parameters:
459
- loss: Loss tensor
460
- model: Model to check
461
462
Returns:
463
Dictionary of anomalous parameters
464
"""
465
466
def convert_sync_batchnorm(model: torch.nn.Module, process_group=None) -> torch.nn.Module:
467
"""
468
Convert BatchNorm to SyncBatchNorm for distributed training.
469
470
Parameters:
471
- model: Model to convert
472
- process_group: Process group for synchronization
473
474
Returns:
475
Model with SyncBatchNorm
476
"""
477
478
def revert_sync_batchnorm(model: torch.nn.Module) -> torch.nn.Module:
479
"""
480
Revert SyncBatchNorm back to BatchNorm.
481
482
Parameters:
483
- model: Model to revert
484
485
Returns:
486
Model with BatchNorm
487
"""
488
```
489
490
### Model Wrappers
491
492
Wrappers for models to handle distributed training and other special scenarios.
493
494
```python { .api }
495
def is_model_wrapper(model) -> bool:
496
"""
497
Check if model is wrapped.
498
499
Parameters:
500
- model: Model to check
501
502
Returns:
503
True if model is wrapped
504
"""
505
```
506
507
### Test-Time Augmentation
508
509
Base class for test-time augmentation models.
510
511
```python { .api }
512
class BaseTTAModel:
513
def __init__(self, module, tta_cfg: dict = None):
514
"""
515
Base test-time augmentation model.
516
517
Parameters:
518
- module: Base model module
519
- tta_cfg: TTA configuration
520
"""
521
522
def test_step(self, data):
523
"""
524
Test step with augmentation.
525
526
Parameters:
527
- data: Input data
528
529
Returns:
530
Augmented test results
531
"""
532
533
def merge_preds(self, data_samples_list: list):
534
"""
535
Merge predictions from different augmentations.
536
537
Parameters:
538
- data_samples_list: List of predictions
539
540
Returns:
541
Merged predictions
542
"""
543
```
544
545
## Usage Examples
546
547
### Basic Model Implementation
548
549
```python
550
from mmengine.model import BaseModel
551
import torch.nn as nn
552
553
class MyModel(BaseModel):
554
def __init__(self, num_classes=10, init_cfg=None):
555
super().__init__(init_cfg=init_cfg)
556
self.backbone = nn.Sequential(
557
nn.Conv2d(3, 64, 3, padding=1),
558
nn.ReLU(),
559
nn.AdaptiveAvgPool2d(1)
560
)
561
self.head = nn.Linear(64, num_classes)
562
563
def forward(self, inputs):
564
x = self.backbone(inputs)
565
x = x.flatten(1)
566
return self.head(x)
567
568
def train_step(self, data, optim_wrapper):
569
inputs = data['inputs']
570
labels = data['labels']
571
572
logits = self(inputs)
573
loss = nn.CrossEntropyLoss()(logits, labels)
574
575
parsed_loss, log_vars = self.parse_losses({'loss': loss})
576
optim_wrapper.update_params(parsed_loss)
577
578
return {'loss': parsed_loss, 'log_vars': log_vars}
579
```
580
581
### Custom Hook Implementation
582
583
```python
584
from mmengine.hooks import Hook
585
586
class CustomValidationHook(Hook):
587
def __init__(self, val_interval=1):
588
self.val_interval = val_interval
589
590
def after_train_epoch(self, runner):
591
if (runner.epoch + 1) % self.val_interval == 0:
592
runner.val()
593
594
# Custom validation logic
595
val_metrics = runner.message_hub.get_scalar('val_acc')
596
if val_metrics.current() > 0.95:
597
runner.logger.info("High accuracy achieved!")
598
599
# Register and use hook
600
runner.register_hook(CustomValidationHook(val_interval=5))
601
```
602
603
### Model with Data Preprocessor
604
605
```python
606
from mmengine.model import BaseModel, ImgDataPreprocessor
607
608
model = BaseModel(
609
data_preprocessor=dict(
610
type='ImgDataPreprocessor',
611
mean=[123.675, 116.28, 103.53],
612
std=[58.395, 57.12, 57.375],
613
bgr_to_rgb=True,
614
pad_size_divisor=32
615
)
616
)
617
```
618
619
### Using Built-in Hooks
620
621
```python
622
from mmengine.runner import Runner
623
from mmengine.hooks import CheckpointHook, LoggerHook, EMAHook
624
625
# Configure hooks
626
default_hooks = dict(
627
timer=dict(type='IterTimerHook'),
628
logger=dict(type='LoggerHook', interval=100),
629
param_scheduler=dict(type='ParamSchedulerHook'),
630
sampler_seed=dict(type='DistSamplerSeedHook'),
631
checkpoint=dict(
632
type='CheckpointHook',
633
interval=1,
634
save_best='auto',
635
max_keep_ckpts=3
636
)
637
)
638
639
custom_hooks = [
640
dict(type='EMAHook', momentum=0.0002, priority=49)
641
]
642
643
runner = Runner(
644
model=model,
645
default_hooks=default_hooks,
646
custom_hooks=custom_hooks
647
)
648
```
649
650
### Model Utilities Usage
651
652
```python
653
from mmengine.model import convert_sync_batchnorm, detect_anomalous_params
654
655
# Convert model for distributed training
656
model = convert_sync_batchnorm(model)
657
658
# Check for anomalous parameters during training
659
def training_step(model, data, optimizer):
660
loss = model(data)
661
662
# Check for anomalies
663
anomalous = detect_anomalous_params(loss, model)
664
if anomalous:
665
print(f"Anomalous parameters detected: {anomalous}")
666
667
loss.backward()
668
optimizer.step()
669
```
670
671
### Priority-based Hook Ordering
672
673
```python
674
from mmengine.hooks import Hook
675
from mmengine.runner import get_priority
676
677
class HighPriorityHook(Hook):
678
priority = 'HIGH' # or get_priority('HIGH')
679
680
def before_train_iter(self, runner):
681
# This runs before normal priority hooks
682
pass
683
684
class LowPriorityHook(Hook):
685
priority = 'LOW'
686
687
def after_train_iter(self, runner):
688
# This runs after normal priority hooks
689
pass
690
```