or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

configuration.md dataset.md distributed.md fileio.md index.md logging.md models.md optimization.md registry.md training.md visualization.md

docs/training.md

0

# Training and Loops

1

2

A complete training orchestration system built around flexible runners that support epoch-based and iteration-based training, validation, and testing loops, with built-in checkpointing and logging. It provides a unified interface for managing the entire training pipeline.

3

4

## Capabilities

5

6

### Runner Class

7

8

Central coordinator managing the entire training process with flexible configuration and automatic component initialization.

9

10

```python { .api }

11

class Runner:

12

def __init__(self, model, work_dir: str = None, train_dataloader = None, val_dataloader = None, test_dataloader = None, train_cfg: dict = None, val_cfg: dict = None, test_cfg: dict = None, auto_scale_lr: dict = None, optim_wrapper = None, param_scheduler = None, val_evaluator = None, test_evaluator = None, default_hooks: dict = None, custom_hooks: list = None, data_preprocessor = None, load_from: str = None, resume: bool = False, launcher: str = 'none', env_cfg: dict = None, log_processor = None, visualizer = None, default_scope: str = 'mmengine', randomness: dict = None, experiment_name: str = None, cfg: dict = None):

13

"""

14

Initialize Runner with comprehensive training configuration.

15

16

Parameters:

17

- model: Model to train (torch.nn.Module or config dict)

18

- work_dir: Working directory for saving outputs

19

- train_dataloader: Training data loader

20

- val_dataloader: Validation data loader

21

- test_dataloader: Test data loader

22

- train_cfg: Training loop configuration

23

- val_cfg: Validation loop configuration

24

- test_cfg: Test loop configuration

25

- auto_scale_lr: Automatic learning rate scaling configuration

26

- optim_wrapper: Optimizer wrapper configuration

27

- param_scheduler: Parameter scheduler configuration

28

- val_evaluator: Validation evaluator configuration

29

- test_evaluator: Test evaluator configuration

30

- default_hooks: Default hooks configuration

31

- custom_hooks: Custom hooks list

32

- data_preprocessor: Data preprocessor configuration

33

- load_from: Checkpoint path to load

34

- resume: Whether to resume training

35

- launcher: Distributed launcher type

36

- env_cfg: Environment configuration

37

- log_processor: Log processor configuration

38

- visualizer: Visualizer configuration

39

- default_scope: Default registry scope

40

- randomness: Randomness configuration

41

- experiment_name: Experiment name

42

- cfg: Complete configuration object

43

"""

44

45

@classmethod

46

def from_cfg(cls, cfg) -> 'Runner':

47

"""

48

Create Runner from configuration.

49

50

Parameters:

51

- cfg: Configuration object or dict

52

53

Returns:

54

Initialized Runner instance

55

"""

56

57

def train(self):

58

"""Run training loop."""

59

60

def val(self):

61

"""Run validation loop."""

62

63

def test(self):

64

"""Run test loop."""

65

66

def call_hook(self, fn_name: str, **kwargs):

67

"""

68

Call hook method.

69

70

Parameters:

71

- fn_name: Hook method name

72

- **kwargs: Hook arguments

73

"""

74

75

def register_hook(self, hook, priority: str = 'NORMAL'):

76

"""

77

Register hook.

78

79

Parameters:

80

- hook: Hook instance or config

81

- priority: Hook priority

82

"""

83

84

def load_or_resume(self):

85

"""Load checkpoint or resume training."""

86

87

def save_checkpoint(self, out_dir: str, filename: str = None, file_client_args: dict = None, save_optimizer: bool = True, save_param_scheduler: bool = True, meta: dict = None, by_epoch: bool = True):

88

"""

89

Save checkpoint.

90

91

Parameters:

92

- out_dir: Output directory

93

- filename: Checkpoint filename

94

- file_client_args: File client arguments

95

- save_optimizer: Whether to save optimizer state

96

- save_param_scheduler: Whether to save scheduler state

97

- meta: Additional metadata

98

- by_epoch: Whether checkpoint is by epoch

99

"""

100

101

@property

102

def epoch(self) -> int:

103

"""Current epoch number."""

104

105

@property

106

def iter(self) -> int:

107

"""Current iteration number."""

108

109

@property

110

def max_epochs(self) -> int:

111

"""Maximum number of epochs."""

112

113

@property

114

def max_iters(self) -> int:

115

"""Maximum number of iterations."""

116

```

117

118

### Flexible Runner

119

120

Extended runner with additional flexibility for custom training workflows.

121

122

```python { .api }

123

class FlexibleRunner(Runner):

124

def __init__(self, **kwargs):

125

"""

126

Initialize FlexibleRunner with extended configuration options.

127

128

Parameters:

129

- **kwargs: Same as Runner plus additional flexibility options

130

"""

131

132

def run_loop(self, loop: 'BaseLoop'):

133

"""

134

Run custom training loop.

135

136

Parameters:

137

- loop: Loop instance to execute

138

"""

139

```

140

141

### Base Loop Class

142

143

Abstract base class for all training loops providing common interface and functionality.

144

145

```python { .api }

146

class BaseLoop:

147

def __init__(self, runner, dataloader):

148

"""

149

Initialize base loop.

150

151

Parameters:

152

- runner: Runner instance

153

- dataloader: Data loader for the loop

154

"""

155

156

def run(self):

157

"""Execute the loop."""

158

159

@property

160

def iter(self) -> int:

161

"""Current iteration number."""

162

163

@property

164

def max_iters(self) -> int:

165

"""Maximum iterations for this loop."""

166

```

167

168

### Training Loops

169

170

Specialized training loops for different training strategies.

171

172

```python { .api }

173

class EpochBasedTrainLoop(BaseLoop):

174

def __init__(self, runner, dataloader, max_epochs: int, val_begin: int = 1, val_interval: int = 1, dynamic_intervals: list = None):

175

"""

176

Epoch-based training loop.

177

178

Parameters:

179

- runner: Runner instance

180

- dataloader: Training data loader

181

- max_epochs: Maximum number of epochs

182

- val_begin: Epoch to begin validation

183

- val_interval: Validation interval in epochs

184

- dynamic_intervals: Dynamic validation intervals

185

"""

186

187

def run_epoch(self):

188

"""Run one training epoch."""

189

190

def run_iter(self, idx: int, data_batch):

191

"""

192

Run one training iteration.

193

194

Parameters:

195

- idx: Iteration index

196

- data_batch: Input data batch

197

"""

198

199

class IterBasedTrainLoop(BaseLoop):

200

def __init__(self, runner, dataloader, max_iters: int, val_begin: int = 1, val_interval: int = 1, dynamic_intervals: list = None):

201

"""

202

Iteration-based training loop.

203

204

Parameters:

205

- runner: Runner instance

206

- dataloader: Training data loader

207

- max_iters: Maximum number of iterations

208

- val_begin: Iteration to begin validation

209

- val_interval: Validation interval in iterations

210

- dynamic_intervals: Dynamic validation intervals

211

"""

212

213

def run_iter(self, data_batch):

214

"""

215

Run one training iteration.

216

217

Parameters:

218

- data_batch: Input data batch

219

"""

220

```

221

222

### Validation and Test Loops

223

224

Loops for model evaluation during training or standalone testing.

225

226

```python { .api }

227

class ValLoop(BaseLoop):

228

def __init__(self, runner, dataloader, evaluator, fp16: bool = False):

229

"""

230

Validation loop.

231

232

Parameters:

233

- runner: Runner instance

234

- dataloader: Validation data loader

235

- evaluator: Evaluator for validation metrics

236

- fp16: Whether to use FP16 precision

237

"""

238

239

def run(self) -> dict:

240

"""

241

Run validation loop.

242

243

Returns:

244

Dictionary of validation metrics

245

"""

246

247

class TestLoop(BaseLoop):

248

def __init__(self, runner, dataloader, evaluator, fp16: bool = False):

249

"""

250

Test loop.

251

252

Parameters:

253

- runner: Runner instance

254

- dataloader: Test data loader

255

- evaluator: Evaluator for test metrics

256

- fp16: Whether to use FP16 precision

257

"""

258

259

def run(self) -> dict:

260

"""

261

Run test loop.

262

263

Returns:

264

Dictionary of test metrics

265

"""

266

```

267

268

### Checkpoint Management

269

270

Comprehensive checkpoint loading and saving functionality.

271

272

```python { .api }

273

def load_checkpoint(filename: str, map_location: str = None, logger = None, revise_keys: list = None) -> dict:

274

"""

275

Load checkpoint from file.

276

277

Parameters:

278

- filename: Checkpoint file path

279

- map_location: Device to load checkpoint

280

- logger: Logger instance

281

- revise_keys: List of (pattern, replacement) regular-expression pairs used to rewrite state-dict keys during loading

282

283

Returns:

284

Checkpoint dictionary

285

"""

286

287

def save_checkpoint(model, filename: str, optimizer = None, lr_scheduler = None, meta: dict = None, file_client_args: dict = None):

288

"""

289

Save checkpoint to file.

290

291

Parameters:

292

- model: Model to save

293

- filename: Output filename

294

- optimizer: Optimizer state to save

295

- lr_scheduler: Learning rate scheduler to save

296

- meta: Additional metadata

297

- file_client_args: File client arguments

298

"""

299

300

def weights_to_cpu(state_dict: dict) -> dict:

301

"""

302

Move weights to CPU.

303

304

Parameters:

305

- state_dict: Model state dictionary

306

307

Returns:

308

CPU state dictionary

309

"""

310

311

def get_state_dict(module, destination: dict = None, prefix: str = '', keep_vars: bool = False) -> dict:

312

"""

313

Get model state dictionary.

314

315

Parameters:

316

- module: PyTorch module

317

- destination: Destination dictionary

318

- prefix: Key prefix

319

- keep_vars: Whether to keep the returned tensors attached to autograd (if False, they are detached)

320

321

Returns:

322

State dictionary

323

"""

324

325

def find_latest_checkpoint(path: str, suffix: str = 'pth') -> str:

326

"""

327

Find latest checkpoint in directory.

328

329

Parameters:

330

- path: Directory path

331

- suffix: Checkpoint file suffix

332

333

Returns:

334

Latest checkpoint path

335

"""

336

```

337

338

### Model Loading Utilities

339

340

Utilities for loading pre-trained models and model information.

341

342

```python { .api }

343

def get_torchvision_models() -> list:

344

"""

345

Get list of available torchvision models.

346

347

Returns:

348

List of model names

349

"""

350

351

def get_external_models() -> list:

352

"""

353

Get list of available external models.

354

355

Returns:

356

List of external model names

357

"""

358

359

def get_mmcls_models() -> list:

360

"""

361

Get list of available MMClassification models.

362

363

Returns:

364

List of MMCls model names

365

"""

366

367

def get_deprecated_model_names() -> list:

368

"""

369

Get list of deprecated model names.

370

371

Returns:

372

List of deprecated model names

373

"""

374

375

class CheckpointLoader:

376

@staticmethod

377

def load_checkpoint(filename: str, map_location: str = None) -> dict:

378

"""

379

Load checkpoint with advanced options.

380

381

Parameters:

382

- filename: Checkpoint file path

383

- map_location: Device mapping

384

385

Returns:

386

Loaded checkpoint

387

"""

388

```

389

390

### Training Utilities

391

392

Additional utilities for training management.

393

394

```python { .api }

395

def set_random_seed(seed: int, deterministic: bool = False, diff_rank_seed: bool = False):

396

"""

397

Set random seed for reproducibility.

398

399

Parameters:

400

- seed: Random seed value

401

- deterministic: Whether to use deterministic algorithms

402

- diff_rank_seed: Whether to use different seeds for different ranks

403

"""

404

405

def turn_on_activation_checkpointing(model, **kwargs):

406

"""

407

Enable activation checkpointing for memory efficiency.

408

409

Parameters:

410

- model: Model to apply checkpointing

411

- **kwargs: Checkpointing configuration

412

"""

413

414

def autocast(*args, **kwargs):

415

"""

416

Automatic mixed precision context manager.

417

418

Parameters:

419

- *args: Positional arguments

420

- **kwargs: Keyword arguments

421

422

Returns:

423

Autocast context manager

424

"""

425

```

426

427

## Usage Examples

428

429

### Basic Training Setup

430

431

```python

432

from mmengine.config import Config
from mmengine.runner import Runner

433

434

# Load configuration

435

cfg = Config.fromfile('config.py')

436

437

# Create runner

438

runner = Runner.from_cfg(cfg)

439

440

# Start training

441

runner.train()

442

```

443

444

### Custom Training Loop

445

446

```python

447

from mmengine.runner import Runner, EpochBasedTrainLoop

448

449

# Create runner with custom configuration

450

runner = Runner(

451

model=model,

452

work_dir='./work_dir',

453

train_dataloader=train_loader,

454

val_dataloader=val_loader,

455

train_cfg=dict(type='EpochBasedTrainLoop', max_epochs=100),

456

optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.1)),

457

)

458

459

# Run training

460

runner.train()

461

```

462

463

### Checkpoint Operations

464

465

```python

466

from mmengine.runner import find_latest_checkpoint, load_checkpoint, save_checkpoint

467

468

# Load checkpoint

469

checkpoint = load_checkpoint('model.pth', map_location='cpu')

470

471

# Save checkpoint with metadata

472

save_checkpoint(

473

model,

474

'checkpoint.pth',

475

optimizer=optimizer,

476

meta={'epoch': 10, 'best_acc': 0.95}

477

)

478

479

# Find latest checkpoint

480

latest_ckpt = find_latest_checkpoint('./checkpoints')

481

```

482

483

### Custom Hook Registration

484

485

```python

486

from mmengine.runner import Runner

487

from mmengine.hooks import Hook

488

489

class CustomHook(Hook):

490

def before_train_epoch(self, runner):

491

print(f"Starting epoch {runner.epoch}")

492

493

runner = Runner.from_cfg(cfg)

494

runner.register_hook(CustomHook(), priority='LOW')

495

runner.train()

496

```