# Handlers and Training Enhancement

Training enhancement utilities including checkpointing, early stopping, logging, learning rate scheduling, and experiment tracking. PyTorch Ignite provides 40+ built-in handlers that plug into the event system to enhance training workflows.

## Capabilities

### Checkpointing

Model and training state checkpointing with flexible save strategies.

```python { .api }
class Checkpoint:
    """
    Flexible checkpointing handler.

    Parameters:
    - to_save: dictionary of objects to save
    - save_handler: handler for saving (DiskSaver, etc.)
    - filename_prefix: prefix for checkpoint filenames
    - score_function: function to compute checkpoint score
    - score_name: name of the score metric
    - n_saved: number of checkpoints to keep
    - atomic: whether to use atomic saves
    - require_empty: require empty directory
    - archived: whether to archive old checkpoints
    - greater_or_equal: score comparison direction
    """
    def __init__(self, to_save, save_handler, filename_prefix="", score_function=None, score_name=None, n_saved=1, atomic=True, require_empty=True, archived=False, greater_or_equal=False): ...

class DiskSaver:
    """
    Disk-based checkpoint saver.

    Parameters:
    - dirname: directory to save checkpoints
    - atomic: whether to use atomic saves
    - create_dir: whether to create the directory if it doesn't exist
    - require_empty: require empty directory
    """
    def __init__(self, dirname, atomic=True, create_dir=True, require_empty=True): ...

class ModelCheckpoint:
    """
    Model checkpoint handler (deprecated - use Checkpoint instead).

    Parameters:
    - dirname: directory to save checkpoints
    - filename_prefix: prefix for checkpoint filenames
    - score_function: function to compute checkpoint score
    - score_name: name of the score metric
    - n_saved: number of checkpoints to keep
    - atomic: whether to use atomic saves
    - require_empty: require empty directory
    - create_dir: whether to create directory
    - save_as_state_dict: save as state dict instead of full model
    - global_step_transform: function to transform global step
    """
    def __init__(self, dirname, filename_prefix, score_function=None, score_name=None, n_saved=1, atomic=True, require_empty=True, create_dir=True, save_as_state_dict=True, global_step_transform=None): ...
```
### Early Stopping

Early stopping based on validation metrics to prevent overfitting.

```python { .api }
class EarlyStopping:
    """
    Early stopping handler to prevent overfitting.

    Parameters:
    - patience: number of events to wait for an improvement before stopping
    - score_function: function taking an engine and returning the score to monitor (higher is better)
    - trainer: trainer engine to terminate when the score stops improving
    - min_delta: minimum increase in the score to qualify as an improvement
    - cumulative_delta: if True, min_delta is measured against the best score since the last patience reset rather than the previous event
    """
    def __init__(self, patience, score_function, trainer, min_delta=0.0, cumulative_delta=False): ...
```
### Learning Rate Scheduling

Learning rate scheduling with various strategies and warmup support.

```python { .api }
class LRScheduler:
    """
    Learning rate scheduler wrapper.

    Parameters:
    - lr_scheduler: PyTorch learning rate scheduler
    - save_history: whether to save LR history
    - **kwds: additional arguments
    """
    def __init__(self, lr_scheduler, save_history=False, **kwds): ...

def create_lr_scheduler_with_warmup(lr_scheduler, warmup_start_value, warmup_end_value, warmup_duration, save_history=False):
    """
    Create learning rate scheduler with warmup.

    Parameters:
    - lr_scheduler: base learning rate scheduler
    - warmup_start_value: starting learning rate for warmup
    - warmup_end_value: ending learning rate for warmup
    - warmup_duration: duration of warmup phase
    - save_history: whether to save LR history

    Returns:
    Combined scheduler with warmup
    """

class CosineAnnealingScheduler:
    """
    Cosine annealing scheduler.

    Parameters:
    - optimizer: PyTorch optimizer
    - param_name: parameter name to schedule
    - start_value: starting parameter value
    - end_value: ending parameter value
    - cycle_size: size of one cycle
    - cycle_mult: cycle size multiplier
    - start_value_mult: start value multiplier per cycle
    - end_value_mult: end value multiplier per cycle
    - save_history: whether to save parameter history
    """
    def __init__(self, optimizer, param_name, start_value, end_value, cycle_size, cycle_mult=1.0, start_value_mult=1.0, end_value_mult=1.0, save_history=False): ...

class LinearCyclicalScheduler:
    """
    Linear cyclical scheduler.

    Parameters:
    - optimizer: PyTorch optimizer
    - param_name: parameter name to schedule
    - start_value: starting parameter value
    - end_value: ending parameter value
    - cycle_size: size of one cycle
    - cycle_mult: cycle size multiplier
    - start_value_mult: start value multiplier per cycle
    - end_value_mult: end value multiplier per cycle
    - save_history: whether to save parameter history
    """
    def __init__(self, optimizer, param_name, start_value, end_value, cycle_size, cycle_mult=1.0, start_value_mult=1.0, end_value_mult=1.0, save_history=False): ...

class ConcatScheduler:
    """
    Concatenated scheduler that runs several parameter schedulers one after another.

    Parameters:
    - schedulers: list of parameter schedulers to run in sequence
    - durations: number of events each scheduler runs before handing over to the next (the final scheduler runs indefinitely)
    - save_history: whether to save parameter history
    """
    def __init__(self, schedulers, durations, save_history=False): ...

class PiecewiseLinear:
    """
    Piecewise linear scheduler.

    Parameters:
    - optimizer: PyTorch optimizer
    - param_name: parameter name to schedule
    - milestones_values: list of (milestone, value) tuples
    - save_history: whether to save parameter history
    """
    def __init__(self, optimizer, param_name, milestones_values, save_history=False): ...
```
### Parameter Scheduling

General parameter scheduling framework for optimizers.

```python { .api }
class ParamScheduler:
    """
    Base parameter scheduler class.

    Parameters:
    - optimizer: PyTorch optimizer
    - param_name: parameter name to schedule
    - save_history: whether to save parameter history
    """
    def __init__(self, optimizer, param_name, save_history=False): ...

class ParamGroupScheduler:
    """
    Parameter group scheduler for different parameter groups.

    Parameters:
    - schedulers: list of schedulers for each parameter group
    - names: names for each parameter group
    """
    def __init__(self, schedulers, names=None): ...

class StateParamScheduler:
    """
    State-based parameter scheduler.

    Parameters:
    - param_scheduler: base parameter scheduler
    - param_name: parameter name
    - save_history: whether to save parameter history
    """
    def __init__(self, param_scheduler, param_name, save_history=False): ...

class LambdaStateScheduler(StateParamScheduler):
    """
    Lambda-based state parameter scheduler.

    Parameters:
    - lambda_func: lambda function for scheduling
    - param_name: parameter name
    - save_history: whether to save parameter history
    """
    def __init__(self, lambda_func, param_name, save_history=False): ...

class ExpStateScheduler(StateParamScheduler):
    """
    Exponential decay state parameter scheduler.

    Parameters:
    - gamma: exponential decay factor
    - param_name: parameter name
    - save_history: whether to save parameter history
    """
    def __init__(self, gamma, param_name, save_history=False): ...

class StepStateScheduler(StateParamScheduler):
    """
    Step-based state parameter scheduler.

    Parameters:
    - step_size: step size for scheduling
    - gamma: decay factor
    - param_name: parameter name
    - save_history: whether to save parameter history
    """
    def __init__(self, step_size, gamma, param_name, save_history=False): ...

class MultiStepStateScheduler(StateParamScheduler):
    """
    Multi-step state parameter scheduler.

    Parameters:
    - milestones: list of milestones
    - gamma: decay factor
    - param_name: parameter name
    - save_history: whether to save parameter history
    """
    def __init__(self, milestones, gamma, param_name, save_history=False): ...

class PiecewiseLinearStateScheduler(StateParamScheduler):
    """
    Piecewise linear state parameter scheduler.

    Parameters:
    - milestones_values: list of (milestone, value) tuples
    - param_name: parameter name
    - save_history: whether to save parameter history
    """
    def __init__(self, milestones_values, param_name, save_history=False): ...
```
### Logging and Tracking

Integration with popular experiment tracking and logging frameworks.

```python { .api }
class TensorboardLogger:
    """
    TensorBoard logging handler.

    Parameters:
    - log_dir: directory for TensorBoard logs
    - **kwargs: additional arguments for SummaryWriter
    """
    def __init__(self, log_dir=None, **kwargs): ...

    def attach_output_handler(self, engine, event_name, tag, output_transform=None, metric_names=None, global_step_transform=None):
        """Attach output logging handler."""

    def attach_opt_params_handler(self, engine, event_name, optimizer, param_name="lr"):
        """Attach optimizer parameter logging handler."""

class VisdomLogger:
    """
    Visdom logging handler.

    Parameters:
    - server: Visdom server URL
    - port: server port
    - **kwargs: additional Visdom arguments
    """
    def __init__(self, server=None, port=8097, **kwargs): ...

class MLflowLogger:
    """
    MLflow experiment tracking.

    Parameters:
    - tracking_uri: MLflow tracking server URI
    - experiment_name: name of the experiment
    - run_name: name of the run
    - artifact_location: artifact storage location
    - **kwargs: additional MLflow arguments
    """
    def __init__(self, tracking_uri=None, experiment_name=None, run_name=None, artifact_location=None, **kwargs): ...

class NeptuneLogger:
    """
    Neptune experiment tracking.

    Parameters:
    - api_token: Neptune API token
    - project_name: Neptune project name
    - experiment_name: name of the experiment
    - **kwargs: additional Neptune arguments
    """
    def __init__(self, api_token=None, project_name=None, experiment_name=None, **kwargs): ...

class WandBLogger:
    """
    Weights & Biases experiment tracking.

    Parameters:
    - project: W&B project name
    - entity: W&B entity name
    - config: configuration dictionary
    - **kwargs: additional W&B arguments
    """
    def __init__(self, project=None, entity=None, config=None, **kwargs): ...

class ClearMLLogger:
    """
    ClearML experiment tracking.

    Parameters:
    - project_name: ClearML project name
    - task_name: task name
    - **kwargs: additional ClearML arguments
    """
    def __init__(self, project_name=None, task_name=None, **kwargs): ...

class PolyaxonLogger:
    """
    Polyaxon experiment tracking.

    Parameters:
    - **kwargs: Polyaxon configuration arguments
    """
    def __init__(self, **kwargs): ...
```
### Progress and Timing

Progress bars and timing utilities for monitoring training.

```python { .api }
class ProgressBar:
    """
    Progress bar for training monitoring.

    Parameters:
    - persist: whether to persist after completion
    - bar_format: custom bar format string
    - **tqdm_kwargs: additional tqdm arguments
    """
    def __init__(self, persist=False, bar_format=None, **tqdm_kwargs): ...

class Timer:
    """
    Timer for measuring elapsed time.

    Parameters:
    - average: whether to compute running average
    """
    def __init__(self, average=False): ...

    def value(self):
        """Get current timer value."""

    def reset(self):
        """Reset timer."""

    def pause(self):
        """Pause timer."""

    def resume(self):
        """Resume timer."""

class BasicTimeProfiler:
    """
    Basic profiler for timing engine operations.

    Parameters:
    - dataflow_profiling: whether to profile data loading
    """
    def __init__(self, dataflow_profiling=False): ...

    def print_results(self, results_dict):
        """Print profiling results."""

class HandlersTimeProfiler:
    """
    Profiler for timing handler execution.
    """
    def __init__(self): ...
```
### Model Enhancement

Handlers for enhancing model training behavior.

```python { .api }
class GradientAccumulation:
    """
    Gradient accumulation handler.

    Parameters:
    - accumulation_steps: number of steps to accumulate gradients
    """
    def __init__(self, accumulation_steps): ...

class EMAHandler:
    """
    Exponential Moving Average handler for model parameters.

    Parameters:
    - model: PyTorch model
    - decay: decay factor for EMA
    - device: device to store EMA parameters
    """
    def __init__(self, model, decay=0.9999, device=None): ...

class FastaiLRFinder:
    """
    Learning rate finder inspired by fastai.

    Parameters:
    - engine: training engine
    - optimizer: PyTorch optimizer
    - criterion: loss function
    - device: device to run on
    """
    def __init__(self, engine, optimizer, criterion, device=None): ...

    def range_test(self, data_loader, start_lr=1e-7, end_lr=10, num_iter=100, step_mode="exp"):
        """Perform learning rate range test."""

class TerminateOnNan:
    """
    Terminate training when NaN values are encountered.
    """
    def __init__(self): ...

class TimeLimit:
    """
    Terminate training after specified time limit.

    Parameters:
    - limit: time limit in seconds
    """
    def __init__(self, limit): ...
```
### Base Classes

Base classes for creating custom handlers and loggers.

```python { .api }
class BaseLogger:
    """Base class for loggers."""
    def __init__(self): ...

class BaseOptimizerParams:
    """Base class for optimizer parameter handlers."""
    def __init__(self): ...

class BaseOutputTransform:
    """Base class for output transformations."""
    def __init__(self): ...
```
### Utility Functions

Helper functions for handlers and training enhancement.

```python { .api }
def global_step_from_engine(engine):
    """
    Build a global step transform backed by another engine's state, typically
    used to log validation metrics against the trainer's iteration or epoch.

    Parameters:
    - engine: engine whose state provides the global step

    Returns:
    Callable (engine, event_name) -> int, suitable as a global_step_transform
    """
```
## Usage Examples

### Model Checkpointing

```python
from ignite.engine import Events
from ignite.handlers import Checkpoint, DiskSaver

# Create checkpoint handler
to_save = {'model': model, 'optimizer': optimizer}
save_handler = DiskSaver('checkpoints', create_dir=True)

checkpoint = Checkpoint(
    to_save,
    save_handler,
    filename_prefix='best',
    score_function=lambda engine: -engine.state.metrics['loss'],
    score_name='neg_loss',
    n_saved=3
)

# Attach to evaluator
evaluator.add_event_handler(Events.COMPLETED, checkpoint)
```
### Early Stopping

```python
from ignite.engine import Events
from ignite.handlers import EarlyStopping

# Create early stopping handler
early_stopping = EarlyStopping(
    patience=10,
    score_function=lambda engine: engine.state.metrics['accuracy'],
    trainer=trainer
)

# Attach to evaluator
evaluator.add_event_handler(Events.COMPLETED, early_stopping)
```
### Learning Rate Scheduling

```python
from ignite.engine import Events
from ignite.handlers import LRScheduler
from torch.optim.lr_scheduler import StepLR

# Create PyTorch scheduler
torch_scheduler = StepLR(optimizer, step_size=10, gamma=0.1)

# Wrap with Ignite scheduler
lr_scheduler = LRScheduler(torch_scheduler)

# Attach to trainer
trainer.add_event_handler(Events.EPOCH_COMPLETED, lr_scheduler)

# With save_history=True, LR values are recorded on the engine state
lr_scheduler = LRScheduler(torch_scheduler, save_history=True)
# ... after training
print(trainer.state.param_history["lr"])
```
### TensorBoard Logging

```python
from ignite.engine import Events
from ignite.handlers import TensorboardLogger, global_step_from_engine

# Create TensorBoard logger
tb_logger = TensorboardLogger(log_dir='tb_logs')

# Log training loss
tb_logger.attach_output_handler(
    trainer,
    event_name=Events.ITERATION_COMPLETED(every=100),
    tag="training",
    output_transform=lambda loss: {"loss": loss}
)

# Log validation metrics
tb_logger.attach_output_handler(
    evaluator,
    event_name=Events.COMPLETED,
    tag="validation",
    metric_names=["accuracy", "loss"],
    global_step_transform=global_step_from_engine(trainer)
)

# Log learning rate
tb_logger.attach_opt_params_handler(
    trainer,
    event_name=Events.ITERATION_COMPLETED(every=100),
    optimizer=optimizer,
    param_name="lr"
)

# Don't forget to close
trainer.add_event_handler(Events.COMPLETED, lambda _: tb_logger.close())
```
### Progress Bar

```python
from ignite.handlers import ProgressBar

# Create progress bar
pbar = ProgressBar(persist=True)

# Attach to trainer
pbar.attach(trainer, metric_names=['loss'])

# Or with custom output transform
pbar.attach(trainer, output_transform=lambda x: {'loss': x})
```