or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

configuration.md, dataset.md, distributed.md, fileio.md, index.md, logging.md, models.md, optimization.md, registry.md, training.md, visualization.md

docs/models.md

0

# Models and Hooks

1

2

Comprehensive model management system with base classes, weight initialization, model wrappers for distributed training, and an extensive hook system for customizing training behavior. Together these provide the foundation classes and utilities for building robust training pipelines.

3

4

## Capabilities

5

6

### Base Model Classes

7

8

Foundation classes for all models in MMEngine with standardized interfaces for training, validation, and testing.

9

10

```python { .api }

11

class BaseModel:
    def __init__(self, init_cfg: dict = None, data_preprocessor: dict = None):
        """
        Common base that every model builds on.

        Parameters:
        - init_cfg: Config describing how the weights get initialized
        - data_preprocessor: Config of the preprocessor applied to input data
        """

    def forward(self, *args, **kwargs):
        """
        Compute the forward pass of the model.

        Parameters:
        - *args: Positional inputs to the model
        - **kwargs: Keyword inputs to the model

        Returns:
        Whatever the model produces for the given inputs
        """

    def train_step(self, data, optim_wrapper):
        """
        Perform a single training step.

        Parameters:
        - data: One batch of input data
        - optim_wrapper: Wrapper around the optimizer used for the update

        Returns:
        A dict holding the loss together with the variables to log
        """

    def val_step(self, data):
        """
        Perform a single validation step.

        Parameters:
        - data: One batch of input data

        Returns:
        The outputs produced for validation
        """

    def test_step(self, data):
        """
        Perform a single test step.

        Parameters:
        - data: One batch of input data

        Returns:
        The outputs produced for testing
        """

    def init_weights(self):
        """Initialize the weights of the model."""

    @property
    def device(self):
        """Device the model currently lives on."""

    def cuda(self, device=None):
        """Move the model onto a CUDA device."""

    def cpu(self):
        """Move the model onto the CPU."""

    def train(self, mode: bool = True):
        """Switch the model into training mode."""

    def eval(self):
        """Switch the model into evaluation mode."""

85

```

86

87

### Data Preprocessors

88

89

Classes for preprocessing input data before feeding to models.

90

91

```python { .api }

92

class BaseDataPreprocessor:
    def __init__(
        self,
        mean: list = None,
        std: list = None,
        pad_size_divisor: int = 1,
        pad_value: float = 0,
        bgr_to_rgb: bool = False,
        rgb_to_bgr: bool = False,
        non_blocking: bool = False,
    ):
        """
        Base preprocessor applied to data before it reaches the model.

        Parameters:
        - mean: Mean values used for normalization
        - std: Standard deviation values used for normalization
        - pad_size_divisor: Divisor used when padding inputs
        - pad_value: Value written into padded regions
        - bgr_to_rgb: Whether channels are converted from BGR to RGB
        - rgb_to_bgr: Whether channels are converted from RGB to BGR
        - non_blocking: Whether data is moved with non-blocking transfers
        """

    def forward(self, data: dict, training: bool = False) -> dict:
        """
        Preprocess one batch of data.

        Parameters:
        - data: Dictionary of input data
        - training: Whether the model is in training mode

        Returns:
        The preprocessed data
        """

    def cast_data(self, data):
        """
        Convert data to the appropriate types and devices.

        Parameters:
        - data: Input data

        Returns:
        The converted data
        """

129

130

class ImgDataPreprocessor(BaseDataPreprocessor):
    def __init__(
        self,
        mean: list = None,
        std: list = None,
        pad_size_divisor: int = 1,
        pad_value: float = 0,
        bgr_to_rgb: bool = False,
        rgb_to_bgr: bool = False,
        batch_augments: list = None,
        non_blocking: bool = False,
    ):
        """
        Image data preprocessor.

        Parameters:
        - mean: RGB mean values for normalization
        - std: RGB std values for normalization
        - pad_size_divisor: Padding size divisor
        - pad_value: Padding value
        - bgr_to_rgb: Whether to convert BGR to RGB
        - rgb_to_bgr: Whether to convert RGB to BGR
        - batch_augments: Batch augmentation transforms
        - non_blocking: Whether to use non-blocking data movement. This matches
          the option documented on BaseDataPreprocessor. It is appended last
          with a default so existing positional calls are unaffected.
        """

144

```

145

146

### Base Module Classes

147

148

Enhanced PyTorch module classes with initialization and utility features.

149

150

```python { .api }

151

class BaseModule:
    def __init__(self, init_cfg: dict = None):
        """
        Module base class that supports configurable weight initialization.

        Parameters:
        - init_cfg: Config describing how the weights get initialized
        """

    def init_weights(self):
        """Initialize the weights of this module."""

162

163

class ModuleDict:
    def __init__(self, modules: dict = None):
        """
        Container that stores modules under string keys.

        Parameters:
        - modules: Mapping from names to modules
        """

    def __getitem__(self, key: str):
        """Return the module stored under key."""

    def __setitem__(self, key: str, module):
        """Store module under key."""

    def __delitem__(self, key: str):
        """Remove the module stored under key."""

    def __len__(self) -> int:
        """Return how many modules are stored."""

    def __iter__(self):
        """Iterate over the stored keys."""

    def keys(self):
        """Return the stored keys."""

    def values(self):
        """Return the stored modules."""

    def items(self):
        """Return the stored (key, module) pairs."""

195

196

class ModuleList:
    def __init__(self, modules: list = None):
        """
        Container that stores modules in order.

        Parameters:
        - modules: Sequence of modules to hold
        """

    def __getitem__(self, idx: int):
        """Return the module at position idx."""

    def __setitem__(self, idx: int, module):
        """Replace the module at position idx."""

    def __delitem__(self, idx: int):
        """Remove the module at position idx."""

    def __len__(self) -> int:
        """Return how many modules are held."""

    def __iter__(self):
        """Iterate over the held modules."""

    def append(self, module):
        """Add module to the end of the list."""

    def extend(self, modules: list):
        """Add every module in modules to the end of the list."""

    def insert(self, index: int, module):
        """Place module at position index."""

228

229

class Sequential:
    def __init__(self, *args):
        """
        Container that chains modules one after another.

        Parameters:
        - *args: Modules, in the order they are applied
        """

    def forward(self, input):
        """Feed input through the chained modules in order."""

240

```

241

242

### Hook System

243

244

Comprehensive hook system for customizing training behaviors at different stages.

245

246

```python { .api }

247

class Hook:
    """
    Base class for runner hooks. Every callback is a no-op until overridden.

    The iteration-level callbacks accept batch_idx and data_batch, and the
    after_*_iter ones also accept outputs. The after_val_epoch and
    after_test_epoch callbacks accept metrics. These mirror the keyword
    arguments the MMEngine runner supplies. They default to None so code that
    calls these methods with only runner keeps working.
    """

    priority = 'NORMAL'  # Hook priority level

    def before_run(self, runner):
        """Called before the runner starts running."""

    def after_run(self, runner):
        """Called after the runner finishes running."""

    def before_train(self, runner):
        """Called before the training loop."""

    def after_train(self, runner):
        """Called after the training loop."""

    def before_train_epoch(self, runner):
        """Called before each training epoch."""

    def after_train_epoch(self, runner):
        """Called after each training epoch."""

    def before_train_iter(self, runner, batch_idx: int = None, data_batch=None):
        """
        Called before each training iteration.

        Parameters:
        - runner: The runner executing the loop
        - batch_idx: Index of the current batch in the loop
        - data_batch: The batch about to be processed
        """

    def after_train_iter(self, runner, batch_idx: int = None, data_batch=None, outputs=None):
        """
        Called after each training iteration.

        Parameters:
        - runner: The runner executing the loop
        - batch_idx: Index of the current batch in the loop
        - data_batch: The batch that was processed
        - outputs: Outputs returned by the model's train_step
        """

    def before_val(self, runner):
        """Called before validation."""

    def after_val(self, runner):
        """Called after validation."""

    def before_val_epoch(self, runner):
        """Called before the validation epoch."""

    def after_val_epoch(self, runner, metrics: dict = None):
        """
        Called after the validation epoch.

        Parameters:
        - runner: The runner executing the loop
        - metrics: Evaluation results of the epoch, if available
        """

    def before_val_iter(self, runner, batch_idx: int = None, data_batch=None):
        """
        Called before each validation iteration.

        Parameters:
        - runner: The runner executing the loop
        - batch_idx: Index of the current batch in the loop
        - data_batch: The batch about to be processed
        """

    def after_val_iter(self, runner, batch_idx: int = None, data_batch=None, outputs=None):
        """
        Called after each validation iteration.

        Parameters:
        - runner: The runner executing the loop
        - batch_idx: Index of the current batch in the loop
        - data_batch: The batch that was processed
        - outputs: Outputs returned by the model's val_step
        """

    def before_save_checkpoint(self, runner, checkpoint: dict):
        """Called before saving a checkpoint."""

    def after_load_checkpoint(self, runner, checkpoint: dict):
        """Called after loading a checkpoint."""

    def before_test(self, runner):
        """Called before testing."""

    def after_test(self, runner):
        """Called after testing."""

    def before_test_epoch(self, runner):
        """Called before the test epoch."""

    def after_test_epoch(self, runner, metrics: dict = None):
        """
        Called after the test epoch.

        Parameters:
        - runner: The runner executing the loop
        - metrics: Evaluation results of the epoch, if available
        """

    def before_test_iter(self, runner, batch_idx: int = None, data_batch=None):
        """
        Called before each test iteration.

        Parameters:
        - runner: The runner executing the loop
        - batch_idx: Index of the current batch in the loop
        - data_batch: The batch about to be processed
        """

    def after_test_iter(self, runner, batch_idx: int = None, data_batch=None, outputs=None):
        """
        Called after each test iteration.

        Parameters:
        - runner: The runner executing the loop
        - batch_idx: Index of the current batch in the loop
        - data_batch: The batch that was processed
        - outputs: Outputs returned by the model's test_step
        """

303

```

304

305

### Built-in Hooks

306

307

Collection of commonly used hooks for various training scenarios.

308

309

```python { .api }

310

class CheckpointHook(Hook):
    def __init__(
        self,
        interval: int = -1,
        by_epoch: bool = True,
        save_optimizer: bool = True,
        save_param_scheduler: bool = True,
        out_dir: str = None,
        max_keep_ckpts: int = -1,
        save_last: bool = True,
        save_best: str = 'auto',
        rule: str = 'greater',
        greater_keys: list = None,
        less_keys: list = None,
        file_client_args: dict = None,
        published_keys: list = None,
    ):
        """
        Hook that writes checkpoints during training.

        Parameters:
        - interval: How often a checkpoint is saved
        - by_epoch: Whether interval counts epochs (otherwise iterations)
        - save_optimizer: Whether the optimizer state is included
        - save_param_scheduler: Whether the scheduler state is included
        - out_dir: Directory the checkpoints are written to
        - max_keep_ckpts: Upper bound on how many checkpoints are kept
        - save_last: Whether the final checkpoint is saved
        - save_best: Strategy for choosing the best checkpoint
        - rule: Comparison rule used to pick the best checkpoint
        - greater_keys: Metric keys where a larger value is better
        - less_keys: Metric keys where a smaller value is better
        - file_client_args: Arguments for the file client
        - published_keys: Keys kept when publishing a checkpoint
        """

330

331

class LoggerHook(Hook):
    def __init__(
        self,
        interval: int = 10,
        ignore_last: bool = True,
        reset_flag: bool = False,
        by_epoch: bool = True,
    ):
        """
        Hook that logs information about training.

        Parameters:
        - interval: How often information is logged
        - ignore_last: Whether a trailing incomplete interval is skipped
        - reset_flag: Whether the log flag is reset
        - by_epoch: Whether logging is organized by epoch
        """

342

343

class IterTimerHook(Hook):
    def __init__(self):
        """Hook that measures how long each training iteration takes."""

346

347

class DistSamplerSeedHook(Hook):
    def __init__(self):
        """Hook that sets the seed of the distributed sampler."""

350

351

class ParamSchedulerHook(Hook):
    def __init__(self):
        """Hook that drives the parameter schedulers."""

354

355

class EMAHook(Hook):
    def __init__(
        self,
        ema_type: str = 'ExponentialMovingAverage',
        momentum: float = 0.0002,
        update_buffers: bool = False,
        priority: int = 49,
    ):
        """
        Hook that maintains an exponential moving average of the model.

        Parameters:
        - ema_type: Which EMA variant to use ('ExponentialMovingAverage' or 'MomentumAnnealingEMA')
        - momentum: Momentum of the moving average
        - update_buffers: Whether buffers are averaged as well
        - priority: Priority assigned to this hook
        """

366

367

class EmptyCacheHook(Hook):
    def __init__(
        self,
        before_epoch: bool = False,
        after_epoch: bool = True,
        after_iter: bool = False,
    ):
        """
        Hook that empties the CUDA cache.

        Parameters:
        - before_epoch: Whether the cache is emptied before each epoch
        - after_epoch: Whether the cache is emptied after each epoch
        - after_iter: Whether the cache is emptied after each iteration
        """

377

378

class SyncBuffersHook(Hook):
    def __init__(self):
        """Hook that keeps model buffers in sync during distributed training."""

381

382

class RuntimeInfoHook(Hook):
    def __init__(self, enable_tensorboard: bool = True):
        """
        Hook that gathers runtime information.

        Parameters:
        - enable_tensorboard: Whether TensorBoard logging is turned on
        """

390

391

class EarlyStoppingHook(Hook):
    def __init__(
        self,
        monitor: str,
        min_delta: float = 0,
        patience: int = 5,
        verbose: bool = False,
        mode: str = 'min',
        baseline: float = None,
        restore_best_weights: bool = False,
    ):
        """
        Hook that stops training once the monitored metric stops improving.

        Parameters:
        - monitor: Name of the metric that is watched
        - min_delta: Smallest change that still counts as an improvement
        - patience: Number of epochs without improvement before training stops
        - verbose: Whether early-stopping messages are printed
        - mode: Either 'min' or 'max'
        - baseline: Reference value for the monitored quantity
        - restore_best_weights: Whether the best weights are restored
        """

405

406

class ProfilerHook(Hook):
    def __init__(
        self,
        by_epoch: bool = True,
        profile_iters: int = 1,
        activities: list = None,
        schedule: dict = None,
        on_trace_ready: callable = None,
        record_shapes: bool = False,
        profile_memory: bool = False,
        with_stack: bool = False,
        with_flops: bool = False,
        json_trace_path: str = None,
    ):
        """
        Hook that profiles training performance.

        Parameters:
        - by_epoch: Whether profiling is organized by epoch
        - profile_iters: How many iterations are profiled
        - activities: Activities that are profiled
        - schedule: Schedule of the profiler
        - on_trace_ready: Callback invoked when a trace is ready
        - record_shapes: Whether tensor shapes are recorded
        - profile_memory: Whether memory usage is profiled
        - with_stack: Whether stack traces are recorded
        - with_flops: Whether FLOPs are recorded
        - json_trace_path: Where the JSON trace is written
        """

423

```

424

425

### Model Utilities

426

427

Utility functions for model operations and management.

428

429

```python { .api }

430

def stack_batch(tensors: list, pad_size_divisor: int = 1, pad_value: float = 0) -> torch.Tensor:
    """
    Stack list of tensors into batch tensor.

    Parameters:
    - tensors: List of tensors to stack
    - pad_size_divisor: Padding size divisor. Defaults to 1, which matches the
      pad_size_divisor default of the data preprocessors. A divisor of 0 is
      not a meaningful size to pad to.
    - pad_value: Padding value

    Returns:
    Stacked batch tensor
    """

442

443

def merge_dict(*dicts: dict) -> dict:
    """
    Combine several dictionaries into one.

    Parameters:
    - *dicts: The dictionaries to combine

    Returns:
    The combined dictionary
    """

453

454

def detect_anomalous_params(loss: torch.Tensor, model: torch.nn.Module) -> dict:
    """
    Find parameters that are anomalous (NaN or Inf).

    NOTE(review): upstream MMEngine documents this function as reporting
    parameters that are not part of loss's autograd graph (logging them and
    returning None), not as a NaN/Inf check. Confirm which behavior this
    page intends.

    Parameters:
    - loss: The loss tensor
    - model: The model whose parameters are inspected

    Returns:
    Dictionary describing the anomalous parameters
    """

465

466

def convert_sync_batchnorm(model: torch.nn.Module, process_group=None) -> torch.nn.Module:
    """
    Replace BatchNorm layers with SyncBatchNorm for distributed training.

    Parameters:
    - model: The model to convert
    - process_group: Process group used for synchronization

    Returns:
    The model using SyncBatchNorm
    """

477

478

def revert_sync_batchnorm(model: torch.nn.Module) -> torch.nn.Module:
    """
    Turn SyncBatchNorm layers back into plain BatchNorm.

    Parameters:
    - model: The model to revert

    Returns:
    The model using BatchNorm
    """

488

```

489

490

### Model Wrappers

491

492

Utilities for working with models that are wrapped for distributed training and other special scenarios.

493

494

```python { .api }

495

def is_model_wrapper(model) -> bool:
    """
    Tell whether a model has been wrapped.

    Parameters:
    - model: The model to inspect

    Returns:
    True when the model is wrapped
    """

505

```

506

507

### Test-Time Augmentation

508

509

Base class for test-time augmentation models.

510

511

```python { .api }

512

class BaseTTAModel:
    def __init__(self, module, tta_cfg: dict = None):
        """
        Base model for test-time augmentation.

        Parameters:
        - module: The underlying model
        - tta_cfg: Test-time augmentation config
        """

    def test_step(self, data):
        """
        Run a test step over augmented inputs.

        Parameters:
        - data: The input data

        Returns:
        Test results obtained with augmentation
        """

    def merge_preds(self, data_samples_list: list):
        """
        Combine the predictions produced by the different augmentations.

        Parameters:
        - data_samples_list: List of predictions from the augmentations

        Returns:
        The combined predictions
        """

543

```

544

545

## Usage Examples

546

547

### Basic Model Implementation

548

549

```python

550

from mmengine.model import BaseModel

551

import torch.nn as nn

552

553

class MyModel(BaseModel):
    """Toy classifier: one conv layer, global average pooling and a linear head."""

    def __init__(self, num_classes=10, init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
        )
        self.head = nn.Linear(64, num_classes)
        # Build the criterion once instead of re-creating it on every step.
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, inputs):
        """Return class logits for a batch of 3-channel images."""
        features = self.backbone(inputs)
        # Pooling leaves a (N, 64, 1, 1) map; flatten it for the linear head.
        features = features.flatten(1)
        return self.head(features)

    def train_step(self, data, optim_wrapper):
        """
        Run one optimization step.

        Parameters:
        - data: Batch dict with 'inputs' and 'labels' entries
        - optim_wrapper: Optimizer wrapper whose update_params applies the loss

        Returns:
        Flat dict of logged scalars. parse_losses already records the total
        under 'loss'.
        """
        logits = self(data['inputs'])
        loss = self.loss_fn(logits, data['labels'])

        parsed_loss, log_vars = self.parse_losses({'loss': loss})
        optim_wrapper.update_params(parsed_loss)

        # Return the flat log dict rather than nesting it under 'log_vars':
        # the runner forwards each returned value to the message hub as a scalar.
        return log_vars

579

```

580

581

### Custom Hook Implementation

582

583

```python

584

from mmengine.hooks import Hook

585

586

class CustomValidationHook(Hook):
    """Validate every val_interval training epochs and log when accuracy is high."""

    def __init__(self, val_interval=1):
        self.val_interval = val_interval

    def after_train_epoch(self, runner):
        # runner.epoch is assumed 0-based here, hence the +1 to count
        # completed epochs.
        if (runner.epoch + 1) % self.val_interval != 0:
            return

        runner.val()

        # Custom validation logic. It only runs on epochs where validation
        # actually happened, so the metric read below is fresh rather than
        # stale or missing.
        val_metrics = runner.message_hub.get_scalar('val_acc')
        # current is a method of the returned history buffer, not a property.
        if val_metrics.current() > 0.95:
            runner.logger.info("High accuracy achieved!")

598

599

# Register the hook with the runner; it validates every 5 training epochs
validation_hook = CustomValidationHook(val_interval=5)
runner.register_hook(validation_hook)

601

```

602

603

### Model with Data Preprocessor

604

605

```python

606

from mmengine.model import BaseModel, ImgDataPreprocessor

607

608

class ToyModel(BaseModel):
    """
    Minimal concrete model used only to show data-preprocessor configuration.

    BaseModel only declares forward, so the example subclasses it and
    implements forward instead of instantiating BaseModel directly.
    """

    def forward(self, inputs, data_samples=None, mode='tensor'):
        return inputs


model = ToyModel(
    data_preprocessor=dict(
        type='ImgDataPreprocessor',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True,
        pad_size_divisor=32,
    )
)

617

```

618

619

### Using Built-in Hooks

620

621

```python

622

# Runner lives in the mmengine.runner subpackage.
from mmengine.runner import Runner
from mmengine.hooks import CheckpointHook, LoggerHook, EMAHook

# Configure hooks. Each entry is a config dict whose 'type' names the hook
# class by string.
default_hooks = dict(
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=100),
    param_scheduler=dict(type='ParamSchedulerHook'),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    checkpoint=dict(
        type='CheckpointHook',
        interval=1,
        save_best='auto',
        max_keep_ckpts=3,
    ),
)

custom_hooks = [
    dict(type='EMAHook', momentum=0.0002, priority=49),
]

runner = Runner(
    model=model,
    # work_dir is a required Runner argument; logs and checkpoints go here.
    work_dir='./work_dirs/example',
    default_hooks=default_hooks,
    custom_hooks=custom_hooks,
)

648

```

649

650

### Model Utilities Usage

651

652

```python

653

from mmengine.model import convert_sync_batchnorm, detect_anomalous_params

654

655

# Convert model for distributed training

656

# Rebind model to the converted module, whose BatchNorm layers are replaced
# by SyncBatchNorm for distributed training.
model = convert_sync_batchnorm(model)

657

658

# Check for anomalous parameters during training
def training_step(model, data, optimizer):
    """
    Run one manual optimization step, reporting anomalous parameters first.

    Parameters:
    - model: Module whose output for data is a scalar loss tensor
    - data: Batch passed straight to model
    - optimizer: Optimizer over model's parameters
    """
    # Clear gradients left from the previous step. Otherwise backward()
    # accumulates into them and every update after the first is wrong.
    optimizer.zero_grad()
    loss = model(data)

    # Check for anomalies
    anomalous = detect_anomalous_params(loss, model)
    if anomalous:
        print(f"Anomalous parameters detected: {anomalous}")

    loss.backward()
    optimizer.step()

669

```

670

671

### Priority-based Hook Ordering

672

673

```python

674

from mmengine.hooks import Hook

675

from mmengine.runner import get_priority

676

677

class HighPriorityHook(Hook):
    """Example hook that runs ahead of NORMAL-priority hooks."""

    priority = 'HIGH'  # or get_priority('HIGH')

    def before_train_iter(self, runner, batch_idx=None, data_batch=None):
        # This runs before normal priority hooks. The runner passes batch_idx
        # and data_batch as keyword arguments, so the override must accept them.
        pass

683

684

class LowPriorityHook(Hook):
    """Example hook that runs after NORMAL-priority hooks."""

    priority = 'LOW'

    def after_train_iter(self, runner, batch_idx=None, data_batch=None, outputs=None):
        # This runs after normal priority hooks. The runner passes batch_idx,
        # data_batch and outputs as keyword arguments, so the override must
        # accept them.
        pass

690

```