or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

base-exceptions.md, contrib.md, distributed.md, engine.md, handlers.md, index.md, metrics.md, utils.md

docs/engine.md

0

# Engine and Training Loop

1

2

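Core training-loop infrastructure built on an event-driven architecture. The `Engine` is the central component of PyTorch Ignite. It calls a user-supplied process function on each batch of a data loader or other iterable. It also fires lifecycle events (run start, epoch and iteration start/completion, run completion, exceptions), and handlers can be attached to any of them.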

3

4

## Capabilities

5

6

### Engine Class

7

8

The `Engine` class runs training and evaluation loops and dispatches lifecycle events to the handlers registered for them.

9

10

```python { .api }

11

class Engine:

12

"""

13

Core engine for training and evaluation loops with event system.

14

15

Parameters:

16

- process_function: callable that processes a batch of data

17

18

Attributes:

19

- state: State object containing current training information

20

- should_terminate: boolean flag to terminate training

21

- should_terminate_single_epoch: boolean flag to terminate current epoch

22

"""

23

def __init__(self, process_function):

24

"""Initialize engine with a process function."""

25

26

def run(self, data, max_epochs=1, epoch_length=None, seed=None):

27

"""

28

Run the engine on data for specified epochs.

29

30

Parameters:

31

- data: data loader or iterable

32

- max_epochs: maximum number of epochs to run

33

- epoch_length: number of iterations per epoch (optional)

34

- seed: random seed for reproducibility

35

36

Returns:

37

State object with final training state

38

"""

39

40

def add_event_handler(self, event_name, handler, *args, **kwargs):

41

"""

42

Add an event handler for the specified event.

43

44

Parameters:

45

- event_name: name of the event

46

- handler: callable to execute when event occurs

47

- args, kwargs: arguments to pass to handler

48

49

Returns:

50

RemovableEventHandle object

51

"""

52

53

def on(self, event_filter=None):

54

"""

55

Decorator for adding event handlers.

56

57

Parameters:

58

- event_filter: event or event filter to listen for

59

60

Returns:

61

Decorator function

62

"""

63

64

def fire_event(self, event_name):

65

"""Fire an event, executing all registered handlers."""

66

67

def terminate(self):

68

"""Terminate the training loop."""

69

70

def terminate_epoch(self):

71

"""Terminate the current epoch."""

72

73

def has_event_handler(self, handler, event_name=None):

74

"""Check if handler is registered for event."""

75

76

def remove_event_handler(self, handler, event_name):

77

"""Remove an event handler."""

78

79

class DeterministicEngine(Engine):

80

"""

81

Deterministic version of Engine with reproducible behavior.

82

83

Parameters:

84

- process_function: callable that processes a batch of data

85

- deterministic: enable deterministic behavior

86

"""

87

def __init__(self, process_function, deterministic=True): ...

88

```

89

90

### Events Enum

91

92

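Events give fine-grained control over the training lifecycle. Each event can be used in two ways:
- Directly, to run a handler on every occurrence (e.g. `Events.EPOCH_COMPLETED`).
- Called with `every` or `once`, to create a filtered event (e.g. `Events.ITERATION_COMPLETED(every=100)`).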

93

94

```python { .api }

95

class Events:

96

"""Event types for engine lifecycle."""

97

STARTED = 'started'

98

EPOCH_STARTED = 'epoch_started'

99

ITERATION_STARTED = 'iteration_started'

100

ITERATION_COMPLETED = 'iteration_completed'

101

EPOCH_COMPLETED = 'epoch_completed'

102

COMPLETED = 'completed'

103

EXCEPTION_RAISED = 'exception_raised'

104

GET_BATCH_STARTED = 'get_batch_started'

105

GET_BATCH_COMPLETED = 'get_batch_completed'

106

DATALOADER_STOP_ITERATION = 'dataloader_stop_iteration'

107

108

@staticmethod

109

def ITERATION_STARTED(every=1, once=None):

110

"""Create event filter for iteration started events."""

111

112

@staticmethod

113

def ITERATION_COMPLETED(every=1, once=None):

114

"""Create event filter for iteration completed events."""

115

116

@staticmethod

117

def EPOCH_STARTED(every=1, once=None):

118

"""Create event filter for epoch started events."""

119

120

@staticmethod

121

def EPOCH_COMPLETED(every=1, once=None):

122

"""Create event filter for epoch completed events."""

123

124

class EventEnum:

125

"""

126

Base class for creating custom event enums.

127

128

Allows creation of custom events that integrate with the event system.

129

"""

130

pass

131

132

class EventsList:

133

"""

134

Container for multiple events.

135

136

Allows grouping multiple events together for batch event handling.

137

"""

138

def __init__(self, *events): ...

139

140

class CallableEventWithFilter:

141

"""

142

Event with conditional execution based on filter function.

143

144

Parameters:

145

- event: base event to filter

146

- filter_fn: function that determines when event should fire

147

"""

148

def __init__(self, event, filter_fn, every=None, once=None): ...

149

```

150

151

### Engine State

152

153

Container for the engine's run-time state: counters, the current batch and last output, computed metrics, and timings. Handlers read it through `engine.state`.

154

155

```python { .api }

156

class State:

157

"""

158

Engine state containing training information.

159

160

Attributes:

161

- iteration: current iteration number (global)

162

- epoch: current epoch number

163

- epoch_length: length of current epoch

164

- max_epochs: maximum number of epochs

165

- output: output from last process_function call

166

- batch: current batch data

167

- metrics: dictionary of computed metrics

168

- dataloader: current data loader

169

- seed: random seed used

170

- times: dictionary of timing information

171

"""

172

def __init__(self):

173

self.iteration = 0

174

self.epoch = 0

175

self.epoch_length = None

176

self.max_epochs = None

177

self.output = None

178

self.batch = None

179

self.metrics = {}

180

self.dataloader = None

181

self.seed = None

182

self.times = {}

183

```

184

185

### Supervised Training

186

187

Convenience factories that return an `Engine` already configured for supervised training or evaluation of a PyTorch model.

188

189

```python { .api }

190

def create_supervised_trainer(model, optimizer, loss_fn, device=None, non_blocking=False, prepare_batch=None, output_transform=None, deterministic=False):

191

"""

192

Create an engine for supervised training.

193

194

Parameters:

195

- model: PyTorch model to train

196

- optimizer: PyTorch optimizer

197

- loss_fn: loss function

198

- device: device to move data to (optional)

199

- non_blocking: non-blocking data transfer

200

- prepare_batch: function to prepare batch data

201

- output_transform: function to transform engine output

202

- deterministic: use deterministic algorithms

203

204

Returns:

205

Engine configured for supervised training

206

"""

207

208

def create_supervised_evaluator(model, metrics=None, device=None, non_blocking=False, prepare_batch=None, output_transform=None):

209

"""

210

Create an engine for supervised evaluation.

211

212

Parameters:

213

- model: PyTorch model to evaluate

214

- metrics: dictionary of metrics to compute

215

- device: device to move data to (optional)

216

- non_blocking: non-blocking data transfer

217

- prepare_batch: function to prepare batch data

218

- output_transform: function to transform engine output

219

220

Returns:

221

Engine configured for supervised evaluation

222

"""

223

```

224

225

### Training Step Functions

226

227

Factory functions that build training-step (process) functions. Variants exist for standard precision, automatic mixed precision (AMP), NVIDIA Apex, and TPU devices.

228

229

```python { .api }

230

def supervised_training_step(model, optimizer, loss_fn, device=None, non_blocking=False, prepare_batch=None, output_transform=None):

231

"""

232

Factory function for supervised training step.

233

234

Parameters:

235

- model: PyTorch model

236

- optimizer: PyTorch optimizer

237

- loss_fn: loss function

238

- device: device to run on

239

- non_blocking: non-blocking tensor transfers

240

- prepare_batch: function to prepare batch data

241

- output_transform: function to transform engine output

242

243

Returns:

244

Process function for training step

245

"""

246

247

def supervised_training_step_amp(model, optimizer, loss_fn, device=None, non_blocking=False, prepare_batch=None, output_transform=None, scaler=None):

248

"""

249

Factory function for supervised training step with automatic mixed precision.

250

251

Parameters:

252

- model: PyTorch model

253

- optimizer: PyTorch optimizer

254

- loss_fn: loss function

255

- device: device to run on

256

- non_blocking: non-blocking tensor transfers

257

- prepare_batch: function to prepare batch data

258

- output_transform: function to transform engine output

259

- scaler: GradScaler for mixed precision

260

261

Returns:

262

Process function for AMP training step

263

"""

264

265

def supervised_training_step_apex(model, optimizer, loss_fn, device=None, non_blocking=False, prepare_batch=None, output_transform=None):

266

"""

267

Factory function for supervised training step with NVIDIA Apex.

268

269

Parameters:

270

- model: PyTorch model

271

- optimizer: PyTorch optimizer

272

- loss_fn: loss function

273

- device: device to run on

274

- non_blocking: non-blocking tensor transfers

275

- prepare_batch: function to prepare batch data

276

- output_transform: function to transform engine output

277

278

Returns:

279

Process function for Apex training step

280

"""

281

282

def supervised_training_step_tpu(model, optimizer, loss_fn, device=None, non_blocking=False, prepare_batch=None, output_transform=None):

283

"""

284

Factory function for supervised training step on TPU devices.

285

286

Parameters:

287

- model: PyTorch model

288

- optimizer: PyTorch optimizer

289

- loss_fn: loss function

290

- device: device to run on

291

- non_blocking: non-blocking tensor transfers

292

- prepare_batch: function to prepare batch data

293

- output_transform: function to transform engine output

294

295

Returns:

296

Process function for TPU training step

297

"""

298

```

299

300

### Evaluation Step Functions

301

302

Factory functions that build evaluation-step (process) functions. Variants exist for standard precision and automatic mixed precision (AMP).

303

304

```python { .api }

305

def supervised_evaluation_step(model, device=None, non_blocking=False, prepare_batch=None, output_transform=None):

306

"""

307

Factory function for supervised evaluation step.

308

309

Parameters:

310

- model: PyTorch model

311

- device: device to run on

312

- non_blocking: non-blocking tensor transfers

313

- prepare_batch: function to prepare batch data

314

- output_transform: function to transform engine output

315

316

Returns:

317

Process function for evaluation step

318

"""

319

320

def supervised_evaluation_step_amp(model, device=None, non_blocking=False, prepare_batch=None, output_transform=None):

321

"""

322

Factory function for supervised evaluation step with automatic mixed precision.

323

324

Parameters:

325

- model: PyTorch model

326

- device: device to run on

327

- non_blocking: non-blocking tensor transfers

328

- prepare_batch: function to prepare batch data

329

- output_transform: function to transform engine output

330

331

Returns:

332

Process function for AMP evaluation step

333

"""

334

```

335

336

### Event Handle

337

338

`Engine.add_event_handler` returns this handle. Calling its `remove()` method detaches the associated handler from the engine.

339

340

```python { .api }

341

class RemovableEventHandle:

342

"""Handle for removable event handlers."""

343

def remove(self):

344

"""Remove the associated event handler."""

345

```

346

347

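## Usage Examples

The snippets below assume the following are already defined: `model`, `optimizer`, `criterion`, `train_loader`, `val_loader`, and an `evaluator` engine. They also assume `torch` is imported.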

348

349

### Basic Training Loop

350

351

```python

352

from ignite.engine import Engine, Events

353

354

def process_function(engine, batch):

355

model.train()

356

optimizer.zero_grad()

357

x, y = batch

358

y_pred = model(x)

359

loss = criterion(y_pred, y)

360

loss.backward()

361

optimizer.step()

362

return loss.item()

363

364

trainer = Engine(process_function)

365

366

@trainer.on(Events.ITERATION_COMPLETED(every=100))

367

def log_loss(engine):

368

print(f"Iteration {engine.state.iteration}: Loss = {engine.state.output}")

369

370

trainer.run(train_loader, max_epochs=10)

371

```

372

373

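### Event Filtering

Calling an event with a filter controls when its handlers run:
- `every=n` runs the handler on every n-th occurrence of the event.
- `once=n` runs it only on the n-th occurrence.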

374

375

```python

376

# Execute every 50 iterations

377

@trainer.on(Events.ITERATION_COMPLETED(every=50))

378

def log_intermediate(engine):

379

print(f"Iteration {engine.state.iteration}")

380

381

# Execute only once at iteration 100

382

@trainer.on(Events.ITERATION_COMPLETED(once=100))

383

def save_checkpoint(engine):

384

torch.save(model.state_dict(), 'checkpoint.pth')

385

386

# Execute at the end of each epoch

387

@trainer.on(Events.EPOCH_COMPLETED)

388

def evaluate(engine):

389

evaluator.run(val_loader)

390

```

391

392

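### Exception Handling

A handler registered for `Events.EXCEPTION_RAISED` receives the engine and the exception. Once such a handler exists, the engine passes the exception to it instead of re-raising it. If the run should still fail after handling, re-raise inside the handler (`raise e`).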

393

394

```python

395

@trainer.on(Events.EXCEPTION_RAISED)

396

def handle_exception(engine, e):

397

print(f"Exception occurred: {e}")

398

# Custom exception handling logic

399

if isinstance(e, KeyboardInterrupt):

400

print("Training interrupted by user")

401

else:

402

print("Unexpected error occurred")

403

```