# Engine and Training Loop

Core training loop infrastructure with an event-driven architecture. The Engine is the central component of PyTorch Ignite: it runs a processing function over batches of data and fires events at each stage of the run, so training and evaluation logic can be attached as handlers.
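
To make the event-driven idea concrete, here is a toy loop in plain Python. It is only a sketch of the concept and not Ignite's implementation: `TinyEngine`, its `fire` method, and the string event names are invented for illustration.

```python
from collections import defaultdict


class TinyEngine:
    """Toy stand-in for an event-driven loop (hypothetical, not Ignite code)."""

    def __init__(self, process_function):
        self.process_function = process_function
        self.handlers = defaultdict(list)
        self.iteration = 0
        self.epoch = 0
        self.output = None

    def on(self, event_name):
        # Decorator that registers a handler for the given event name.
        def decorator(handler):
            self.handlers[event_name].append(handler)
            return handler
        return decorator

    def fire(self, event_name):
        # Call every handler registered for this event, in registration order.
        for handler in self.handlers[event_name]:
            handler(self)

    def run(self, data, max_epochs=1):
        self.fire("started")
        for _ in range(max_epochs):
            self.epoch += 1
            self.fire("epoch_started")
            for batch in data:
                self.iteration += 1
                self.fire("iteration_started")
                self.output = self.process_function(self, batch)
                self.fire("iteration_completed")
            self.fire("epoch_completed")
        self.fire("completed")
        return self


engine = TinyEngine(lambda engine, batch: batch + 1)
trace = []


@engine.on("epoch_completed")
def record(engine):
    trace.append((engine.epoch, engine.iteration, engine.output))


engine.run([10, 20], max_epochs=2)
print(trace)
```

The loop itself knows nothing about logging or checkpointing; handlers add that behavior from outside, which is the design the real Engine follows.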

## Capabilities

### Engine Class

The main Engine class, which manages training and evaluation loops and dispatches events to registered handlers.

```python { .api }
class Engine:
    """
    Core engine for training and evaluation loops with event system.

    Parameters:
    - process_function: callable taking (engine, batch); its return value is stored in state.output

    Attributes:
    - state: State object containing current training information
    - should_terminate: boolean flag to terminate training
    - should_terminate_single_epoch: boolean flag to terminate current epoch
    """
    def __init__(self, process_function):
        """Initialize engine with a process function."""

    def run(self, data, max_epochs=1, epoch_length=None, seed=None):
        """
        Run the engine on data for specified epochs.

        Parameters:
        - data: data loader or iterable
        - max_epochs: maximum number of epochs to run
        - epoch_length: number of iterations per epoch (optional)
        - seed: random seed for reproducibility

        Returns:
        State object with final training state
        """

    def add_event_handler(self, event_name, handler, *args, **kwargs):
        """
        Add an event handler for the specified event.

        Parameters:
        - event_name: name of the event
        - handler: callable to execute when event occurs
        - args, kwargs: arguments to pass to handler

        Returns:
        RemovableEventHandle object
        """

    def on(self, event_filter=None):
        """
        Decorator for adding event handlers.

        Parameters:
        - event_filter: event or event filter to listen for

        Returns:
        Decorator function
        """

    def fire_event(self, event_name):
        """Fire an event, executing all registered handlers."""

    def terminate(self):
        """Terminate the training loop."""

    def terminate_epoch(self):
        """Terminate the current epoch."""

    def has_event_handler(self, handler, event_name=None):
        """Check if handler is registered for event."""

    def remove_event_handler(self, handler, event_name):
        """Remove an event handler."""

class DeterministicEngine(Engine):
    """
    Deterministic version of Engine with reproducible behavior.

    Parameters:
    - process_function: callable that processes a batch of data
    - deterministic: enable deterministic behavior
    """
    def __init__(self, process_function, deterministic=True): ...
```

### Events Enum

Events fired by the engine at each stage of a run, with filters that control how often a handler executes.

```python { .api }
class Events:
    """Event types for engine lifecycle."""
    STARTED = 'started'
    EPOCH_STARTED = 'epoch_started'
    ITERATION_STARTED = 'iteration_started'
    ITERATION_COMPLETED = 'iteration_completed'
    EPOCH_COMPLETED = 'epoch_completed'
    COMPLETED = 'completed'
    EXCEPTION_RAISED = 'exception_raised'
    GET_BATCH_STARTED = 'get_batch_started'
    GET_BATCH_COMPLETED = 'get_batch_completed'
    DATALOADER_STOP_ITERATION = 'dataloader_stop_iteration'

    # The members above are also callable: calling one with `every` or `once`
    # returns a filtered event, e.g. Events.ITERATION_COMPLETED(every=100).
    # The stubs below describe those call signatures, not separate attributes.

    @staticmethod
    def ITERATION_STARTED(every=1, once=None):
        """Create event filter for iteration started events."""

    @staticmethod
    def ITERATION_COMPLETED(every=1, once=None):
        """Create event filter for iteration completed events."""

    @staticmethod
    def EPOCH_STARTED(every=1, once=None):
        """Create event filter for epoch started events."""

    @staticmethod
    def EPOCH_COMPLETED(every=1, once=None):
        """Create event filter for epoch completed events."""

class EventEnum:
    """
    Base class for creating custom event enums.

    Allows creation of custom events that integrate with the event system.
    """
    pass

class EventsList:
    """
    Container for multiple events.

    Allows grouping multiple events together for batch event handling.
    """
    def __init__(self, *events): ...

class CallableEventWithFilter:
    """
    Event with conditional execution based on filter function.

    Parameters:
    - event: base event to filter
    - filter_fn: function that determines when event should fire
    """
    def __init__(self, event, filter_fn, every=None, once=None): ...
```

### Engine State

Container for engine state information during training and evaluation.

```python { .api }
class State:
    """
    Engine state containing training information.

    Attributes:
    - iteration: current iteration number (global)
    - epoch: current epoch number
    - epoch_length: length of current epoch
    - max_epochs: maximum number of epochs
    - output: output from last process_function call
    - batch: current batch data
    - metrics: dictionary of computed metrics
    - dataloader: current data loader
    - seed: random seed used
    - times: dictionary of timing information
    """
    def __init__(self):
        self.iteration = 0
        self.epoch = 0
        self.epoch_length = None
        self.max_epochs = None
        self.output = None
        self.batch = None
        self.metrics = {}
        self.dataloader = None
        self.seed = None
        self.times = {}
```

### Supervised Training

Convenience functions for creating supervised training and evaluation engines.

```python { .api }
def create_supervised_trainer(model, optimizer, loss_fn, device=None, non_blocking=False, prepare_batch=None, output_transform=None, deterministic=False):
    """
    Create an engine for supervised training.

    Parameters:
    - model: PyTorch model to train
    - optimizer: PyTorch optimizer
    - loss_fn: loss function
    - device: device to move data to (optional)
    - non_blocking: non-blocking data transfer
    - prepare_batch: function to prepare batch data
    - output_transform: function to transform engine output
    - deterministic: use deterministic algorithms

    Returns:
    Engine configured for supervised training
    """

def create_supervised_evaluator(model, metrics=None, device=None, non_blocking=False, prepare_batch=None, output_transform=None):
    """
    Create an engine for supervised evaluation.

    Parameters:
    - model: PyTorch model to evaluate
    - metrics: dictionary of metrics to compute
    - device: device to move data to (optional)
    - non_blocking: non-blocking data transfer
    - prepare_batch: function to prepare batch data
    - output_transform: function to transform engine output

    Returns:
    Engine configured for supervised evaluation
    """
```

### Training Step Functions

Factory functions for creating training step functions with different precision and device support.

```python { .api }
def supervised_training_step(model, optimizer, loss_fn, device=None, non_blocking=False, prepare_batch=None, output_transform=None):
    """
    Factory function for supervised training step.

    Parameters:
    - model: PyTorch model
    - optimizer: PyTorch optimizer
    - loss_fn: loss function
    - device: device to run on
    - non_blocking: non-blocking tensor transfers
    - prepare_batch: function to prepare batch data
    - output_transform: function to transform engine output

    Returns:
    Process function for training step
    """

def supervised_training_step_amp(model, optimizer, loss_fn, device=None, non_blocking=False, prepare_batch=None, output_transform=None, scaler=None):
    """
    Factory function for supervised training step with automatic mixed precision.

    Parameters:
    - model: PyTorch model
    - optimizer: PyTorch optimizer
    - loss_fn: loss function
    - device: device to run on
    - non_blocking: non-blocking tensor transfers
    - prepare_batch: function to prepare batch data
    - output_transform: function to transform engine output
    - scaler: GradScaler for mixed precision

    Returns:
    Process function for AMP training step
    """

def supervised_training_step_apex(model, optimizer, loss_fn, device=None, non_blocking=False, prepare_batch=None, output_transform=None):
    """
    Factory function for supervised training step with NVIDIA Apex.

    Parameters:
    - model: PyTorch model
    - optimizer: PyTorch optimizer
    - loss_fn: loss function
    - device: device to run on
    - non_blocking: non-blocking tensor transfers
    - prepare_batch: function to prepare batch data
    - output_transform: function to transform engine output

    Returns:
    Process function for Apex training step
    """

def supervised_training_step_tpu(model, optimizer, loss_fn, device=None, non_blocking=False, prepare_batch=None, output_transform=None):
    """
    Factory function for supervised training step on TPU devices.

    Parameters:
    - model: PyTorch model
    - optimizer: PyTorch optimizer
    - loss_fn: loss function
    - device: device to run on
    - non_blocking: non-blocking tensor transfers
    - prepare_batch: function to prepare batch data
    - output_transform: function to transform engine output

    Returns:
    Process function for TPU training step
    """
```

### Evaluation Step Functions

Factory functions for creating evaluation step functions with different precision support.

```python { .api }
def supervised_evaluation_step(model, device=None, non_blocking=False, prepare_batch=None, output_transform=None):
    """
    Factory function for supervised evaluation step.

    Parameters:
    - model: PyTorch model
    - device: device to run on
    - non_blocking: non-blocking tensor transfers
    - prepare_batch: function to prepare batch data
    - output_transform: function to transform engine output

    Returns:
    Process function for evaluation step
    """

def supervised_evaluation_step_amp(model, device=None, non_blocking=False, prepare_batch=None, output_transform=None):
    """
    Factory function for supervised evaluation step with automatic mixed precision.

    Parameters:
    - model: PyTorch model
    - device: device to run on
    - non_blocking: non-blocking tensor transfers
    - prepare_batch: function to prepare batch data
    - output_transform: function to transform engine output

    Returns:
    Process function for AMP evaluation step
    """
```

### Event Handle

Handle returned by `Engine.add_event_handler`, used to remove the registered handler later.

```python { .api }
class RemovableEventHandle:
    """Handle for removable event handlers."""
    def remove(self):
        """Remove the associated event handler."""
```

## Usage Examples

### Basic Training Loop

```python
from ignite.engine import Engine, Events

# model, optimizer, criterion and train_loader are assumed to be defined elsewhere.

def process_function(engine, batch):
    model.train()
    optimizer.zero_grad()
    x, y = batch
    y_pred = model(x)
    loss = criterion(y_pred, y)
    loss.backward()
    optimizer.step()
    # The returned value is stored in engine.state.output
    return loss.item()

trainer = Engine(process_function)

@trainer.on(Events.ITERATION_COMPLETED(every=100))
def log_loss(engine):
    print(f"Iteration {engine.state.iteration}: Loss = {engine.state.output}")

trainer.run(train_loader, max_epochs=10)
```

### Event Filtering

```python
import torch

# trainer, evaluator, model and val_loader are assumed to be defined elsewhere.

# Execute every 50 iterations
@trainer.on(Events.ITERATION_COMPLETED(every=50))
def log_intermediate(engine):
    print(f"Iteration {engine.state.iteration}")

# Execute only once at iteration 100
@trainer.on(Events.ITERATION_COMPLETED(once=100))
def save_checkpoint(engine):
    torch.save(model.state_dict(), 'checkpoint.pth')

# Execute at the end of each epoch
@trainer.on(Events.EPOCH_COMPLETED)
def evaluate(engine):
    evaluator.run(val_loader)
```

### Exception Handling

```python
# trainer is assumed to be an Engine defined elsewhere.

@trainer.on(Events.EXCEPTION_RAISED)
def handle_exception(engine, e):
    # The raised exception is passed to the handler as the second argument
    print(f"Exception occurred: {e}")
    # Custom exception handling logic
    if isinstance(e, KeyboardInterrupt):
        print("Training interrupted by user")
    else:
        print("Unexpected error occurred")
```