A lightweight library to help with training neural networks in PyTorch.
npx @tessl/cli install tessl/pypi-pytorch-ignite@0.5.00
# PyTorch Ignite

A lightweight library to help with training neural networks in PyTorch. Ignite provides a flexible, extensible API that reduces boilerplate code compared to pure PyTorch implementations through its event-driven architecture and handler system.

## Package Information

- **Package Name**: pytorch-ignite
- **Package Type**: pypi
- **Language**: Python
- **Installation**: `pip install pytorch-ignite`

## Core Imports

```python
import ignite
```

Common imports for working with engines and training:

```python
from ignite.engine import Engine, Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss
from ignite.handlers import ModelCheckpoint, EarlyStopping
```

## Basic Usage

```python
from ignite.engine import Engine, Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss
from ignite.handlers import ModelCheckpoint
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Synthetic two-class data so the example runs end to end; swap in real loaders
train_loader = DataLoader(TensorDataset(torch.randn(256, 10), torch.randint(0, 2, (256,))), batch_size=32, shuffle=True)
val_loader = DataLoader(TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,))), batch_size=32)

# Define model, optimizer, and loss (a two-class classifier, so Accuracy applies)
model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

# Create training and validation engines
trainer = create_supervised_trainer(model, optimizer, criterion)
evaluator = create_supervised_evaluator(model, metrics={'accuracy': Accuracy(), 'loss': Loss(criterion)})

# Add event handlers
@trainer.on(Events.ITERATION_COMPLETED(every=100))
def log_training_loss(engine):
    print(f"Epoch[{engine.state.epoch}] Loss: {engine.state.output:.2f}")

@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(engine):
    # Run the evaluator on the validation set after every training epoch
    evaluator.run(val_loader)
    metrics = evaluator.state.metrics
    print(f"Validation Results - Epoch: {engine.state.epoch} Avg accuracy: {metrics['accuracy']:.2f} Avg loss: {metrics['loss']:.2f}")

# Add model checkpointing (higher score is better, hence the negated loss)
checkpoint = ModelCheckpoint('models', 'mymodel', score_function=lambda engine: -engine.state.metrics['loss'])
evaluator.add_event_handler(Events.COMPLETED, checkpoint, {'model': model})

# Start training
trainer.run(train_loader, max_epochs=100)
```

## Architecture

PyTorch Ignite is built around several core architectural components:

- **Engine**: The central component that manages training/evaluation loops with event-driven architecture
- **Events**: Comprehensive event system allowing fine-grained control over training lifecycle
- **Handlers**: Pluggable components for checkpointing, logging, scheduling, and training enhancements
- **Metrics**: Extensive collection of evaluation metrics for various machine learning tasks
- **Distributed**: Built-in support for distributed training across multiple backends

This design enables maximum flexibility while reducing boilerplate code, allowing researchers and practitioners to focus on model development rather than training infrastructure.
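To make the engine, event, and handler relationship concrete, here is a minimal plain-Python sketch of an event-driven loop. `MiniEngine` and its attributes are hypothetical names used only for illustration; this is not Ignite's implementation, and it assumes nothing beyond the event names listed in this document. The engine owns the loop, fires named events at fixed points, and handlers registered for those events read the engine's state.

```python
from collections import defaultdict


class MiniEngine:
    """Toy event-driven loop (hypothetical; not ignite.engine.Engine)."""

    def __init__(self, process_function):
        self.process_function = process_function
        self.handlers = defaultdict(list)  # event name -> registered callables
        self.epoch = 0
        self.iteration = 0
        self.output = None

    def on(self, event_name):
        # Decorator form of handler registration.
        def decorator(handler):
            self.handlers[event_name].append(handler)
            return handler
        return decorator

    def _fire(self, event_name):
        for handler in self.handlers[event_name]:
            handler(self)

    def run(self, data, max_epochs=1):
        self._fire("started")
        for _ in range(max_epochs):
            self.epoch += 1
            self._fire("epoch_started")
            for batch in data:
                self.iteration += 1
                self._fire("iteration_started")
                self.output = self.process_function(self, batch)
                self._fire("iteration_completed")
            self._fire("epoch_completed")
        self._fire("completed")
        return self


# The "process function" here just doubles each batch.
engine = MiniEngine(lambda engine, batch: batch * 2)
log = []


@engine.on("epoch_completed")
def record_epoch(engine):
    # Handlers read loop state from the engine they are attached to.
    log.append((engine.epoch, engine.iteration, engine.output))


engine.run([1, 2, 3], max_epochs=2)
```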

## Capabilities

### Engine and Training Loop

Core training loop infrastructure with event-driven architecture, providing supervised training and evaluation engines with comprehensive lifecycle management.

```python { .api }
class Engine:
    def __init__(self, process_function): ...
    def run(self, data, max_epochs=1, epoch_length=None, seed=None): ...
    def add_event_handler(self, event_name, handler, *args, **kwargs): ...
    def on(self, event_name, *args, **kwargs): ...

def create_supervised_trainer(model, optimizer, loss_fn, device=None, non_blocking=False, prepare_batch=None, output_transform=None, deterministic=False): ...
def create_supervised_evaluator(model, metrics=None, device=None, non_blocking=False, prepare_batch=None, output_transform=None): ...

class Events:
    STARTED = 'started'
    EPOCH_STARTED = 'epoch_started'
    ITERATION_STARTED = 'iteration_started'
    ITERATION_COMPLETED = 'iteration_completed'
    EPOCH_COMPLETED = 'epoch_completed'
    COMPLETED = 'completed'
    EXCEPTION_RAISED = 'exception_raised'
```
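Events can also be filtered, as in `Events.ITERATION_COMPLETED(every=100)` from the Basic Usage example. The sketch below shows the idea in plain Python, assuming that `every=n` means "fire when the event count is a multiple of n" and `once=k` means "fire only on the k-th occurrence". The helper names `every` and `once` are hypothetical stand-ins, not the library's functions.

```python
def every(n):
    """Return a predicate that is true on counts n, 2n, 3n, ... (assumed semantics)."""
    if n < 1:
        raise ValueError("n must be a positive integer")
    return lambda count: count % n == 0


def once(k):
    """Return a predicate that is true only on the k-th occurrence (assumed semantics)."""
    return lambda count: count == k


# Simulate ten occurrences of an event and record which ones pass each filter.
fires_every_3 = [count for count in range(1, 11) if every(3)(count)]
fires_once_4 = [count for count in range(1, 11) if once(4)(count)]
```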

[Engine and Training Loop](./engine.md)

### Metrics Collection

Comprehensive metric collection system covering classification, regression, NLP, computer vision, clustering, and GAN evaluation with 80+ built-in metrics.

```python { .api }
class Metric:
    def reset(self): ...
    def update(self, output): ...
    def compute(self): ...

class Accuracy(Metric): ...
class Precision(Metric): ...
class Recall(Metric): ...
class Loss(Metric): ...
class MeanSquaredError(Metric): ...
class RootMeanSquaredError(Metric): ...
class ROC_AUC(Metric): ...
class ConfusionMatrix(Metric): ...
```
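The `reset` / `update` / `compute` protocol can be illustrated without the library. The following is a minimal sketch, assuming `output` is a `(y_pred, y)` pair of equal-length label sequences. `RunningAccuracy` is a hypothetical class, not `ignite.metrics.Accuracy`, and it raises a plain `RuntimeError` where Ignite would use `NotComputableError`. State is accumulated batch by batch and only turned into a number when `compute` is called.

```python
class RunningAccuracy:
    """Toy metric following the reset/update/compute protocol (hypothetical)."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Clear accumulated state, e.g. at the start of each evaluation run.
        self._correct = 0
        self._total = 0

    def update(self, output):
        # `output` is assumed to be a (predictions, targets) pair for one batch.
        y_pred, y = output
        self._correct += sum(int(p == t) for p, t in zip(y_pred, y))
        self._total += len(y)

    def compute(self):
        if self._total == 0:
            raise RuntimeError("compute() called before any update()")
        return self._correct / self._total


metric = RunningAccuracy()
batches = [([1, 0, 1], [1, 1, 1]), ([0, 0], [0, 0])]
for batch in batches:
    metric.update(batch)
accuracy = metric.compute()
```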

[Metrics Collection](./metrics.md)

### Handlers and Training Enhancement

Training enhancement utilities including checkpointing, early stopping, logging, learning rate scheduling, and experiment tracking with 40+ built-in handlers.

```python { .api }
class Checkpoint:
    def __init__(self, to_save, save_handler, filename_prefix="", score_function=None, score_name=None, n_saved=1, atomic=True, require_empty=True, archived=False, greater_or_equal=False): ...

class EarlyStopping:
    def __init__(self, patience, score_function, trainer, min_delta=0.0, cumulative_delta=False): ...

class LRScheduler:
    def __init__(self, lr_scheduler, save_history=False, **kwds): ...

class ProgressBar:
    def __init__(self, persist=False, bar_format=None, **tqdm_kwargs): ...
```
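The `patience` and `min_delta` parameters of `EarlyStopping` suggest a simple counter: stop once the score has failed to improve by more than `min_delta` for `patience` consecutive checks. The sketch below shows that logic in plain Python, assuming a higher score is better (which is why the Basic Usage example negates the loss). `PatienceCounter` is a hypothetical class, not the library's handler, and it ignores `cumulative_delta`.

```python
class PatienceCounter:
    """Toy early-stopping rule (hypothetical; higher score is assumed better)."""

    def __init__(self, patience, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_score = None
        self.counter = 0

    def step(self, score):
        """Record one score and return True when training should stop."""
        if self.best_score is None or score > self.best_score + self.min_delta:
            # Improvement: remember it and reset the patience counter.
            self.best_score = score
            self.counter = 0
            return False
        self.counter += 1
        return self.counter >= self.patience


stopper = PatienceCounter(patience=2)
scores = [-0.9, -0.7, -0.8, -0.75, -0.6]  # e.g. negated validation losses
stopped_at = None
for epoch, score in enumerate(scores, start=1):
    if stopper.step(score):
        stopped_at = epoch
        break
```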

[Handlers and Training Enhancement](./handlers.md)

### Distributed Training

Comprehensive distributed computing support with multiple backends including native PyTorch DDP, Horovod, and XLA/TPU support.

```python { .api }
def initialize(backend=None, **kwargs): ...
def finalize(): ...
def all_reduce(tensor, group=None, op='SUM'): ...
def all_gather(tensor, group=None): ...
def broadcast(tensor, src=0, group=None): ...
def barrier(group=None): ...
def spawn(fn, args=(), nprocs=1, join=True, daemon=False, start_method='spawn'): ...
```
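For readers new to collective operations, the semantics of `all_reduce`, `all_gather`, and `broadcast` can be simulated in a single process. In this sketch each list index stands for one rank's local value and the return value is what every rank would hold afterwards. The `simulate_*` helpers are hypothetical and involve no real communication; they only illustrate the expected results under the usual meaning of these collectives.

```python
def simulate_all_reduce_sum(per_rank_values):
    """Every rank ends up with the sum of all ranks' values."""
    total = sum(per_rank_values)
    return [total for _ in per_rank_values]


def simulate_all_gather(per_rank_values):
    """Every rank ends up with the full list of values, ordered by rank."""
    return [list(per_rank_values) for _ in per_rank_values]


def simulate_broadcast(per_rank_values, src=0):
    """Every rank ends up with the value held by rank `src`."""
    return [per_rank_values[src] for _ in per_rank_values]


local_values = [1, 2, 3]  # one value per simulated rank
reduced = simulate_all_reduce_sum(local_values)
gathered = simulate_all_gather(local_values)
broadcasted = simulate_broadcast(local_values, src=2)
```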

[Distributed Training](./distributed.md)

### Utilities and Helpers

Helper utilities for tensor operations, type conversions, logging setup, and reproducibility management.

```python { .api }
def convert_tensor(input_, device=None, non_blocking=False): ...
def to_onehot(indices, num_classes): ...
def setup_logger(name=None, level=logging.INFO, stream=None, format="%(asctime)s %(name)s %(levelname)s %(message)s", filepath=None, distributed_rank=None): ...
def manual_seed(seed): ...
def apply_to_tensor(input_, func): ...
def apply_to_type(input_, input_type, func): ...
```
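The signature `apply_to_type(input_, input_type, func)` suggests a recursive walk over a nested batch that applies `func` to every element of the given type. Below is a plain-Python sketch of that idea. `apply_to_type_sketch` is a hypothetical helper, limited to dicts, lists, and tuples, and it leaves other values untouched, which may differ from how the library treats unsupported types.

```python
from collections.abc import Mapping


def apply_to_type_sketch(x, input_type, func):
    """Recursively apply `func` to every element of `input_type` (illustrative only)."""
    if isinstance(x, input_type):
        return func(x)
    if isinstance(x, Mapping):
        return {key: apply_to_type_sketch(value, input_type, func) for key, value in x.items()}
    if isinstance(x, (list, tuple)):
        # Rebuild the same container type from the transformed items.
        return type(x)(apply_to_type_sketch(item, input_type, func) for item in x)
    return x  # Sketch choice: values of other types pass through unchanged.


batch = {"x": [1.0, 2.0], "y": (3.0, "label")}
doubled = apply_to_type_sketch(batch, float, lambda value: value * 2)
```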

[Utilities and Helpers](./utils.md)

### Contrib Module

Specialized engines and utilities for advanced use cases including truncated backpropagation through time (TBPTT) for recurrent neural networks.

```python { .api }
def create_supervised_tbptt_trainer(model, optimizer, loss_fn, tbtt_step, device=None, non_blocking=False, prepare_batch=None, output_transform=None, deterministic=False): ...

class Tbptt_Events:
    TBPTT_STEP_COMPLETED = 'tbptt_step_completed'
    TIME_STEP_COMPLETED = 'time_step_completed'
```
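Truncated backpropagation through time processes a long sequence in fixed-size chunks and carries the recurrent state from one chunk to the next. The following plain-Python sketch shows only the chunking and state carry-over, assuming `tbtt_step` is the chunk length along the time axis. `split_time_steps` is a hypothetical helper and a running sum stands in for the hidden state; in a real trainer the gradient would be cut at each chunk boundary.

```python
def split_time_steps(sequence, tbtt_step):
    """Split a sequence into consecutive chunks of at most `tbtt_step` items."""
    if tbtt_step < 1:
        raise ValueError("tbtt_step must be a positive integer")
    return [sequence[i:i + tbtt_step] for i in range(0, len(sequence), tbtt_step)]


chunks = split_time_steps(list(range(10)), tbtt_step=4)

# A running sum plays the role of the hidden state carried across chunks.
hidden = 0
hidden_after_chunk = []
for chunk in chunks:
    hidden = hidden + sum(chunk)  # state flows forward; gradients would stop here
    hidden_after_chunk.append(hidden)
```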

[Contrib Module](./contrib.md)

### Base Classes and Exceptions

Core base classes and exception types providing fundamental functionality and error handling.

```python { .api }
class Serializable:
    def state_dict(self): ...
    def load_state_dict(self, state_dict): ...

class NotComputableError(RuntimeError):
    pass
```
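The `state_dict` / `load_state_dict` pair describes a save-and-restore contract. The sketch below shows one way an object might honor it, using a JSON round trip as a stand-in for writing a checkpoint. `EpochCounter` and its `epochs_seen` key are hypothetical and do not subclass the library's `Serializable`.

```python
import json


class EpochCounter:
    """Toy object following the state_dict/load_state_dict contract (hypothetical)."""

    def __init__(self):
        self.epochs_seen = 0

    def step(self):
        self.epochs_seen += 1

    def state_dict(self):
        # Return only plain data, so the state can be stored anywhere.
        return {"epochs_seen": self.epochs_seen}

    def load_state_dict(self, state_dict):
        if "epochs_seen" not in state_dict:
            raise ValueError("state_dict is missing the 'epochs_seen' key")
        self.epochs_seen = state_dict["epochs_seen"]


original = EpochCounter()
for _ in range(3):
    original.step()

saved = json.dumps(original.state_dict())  # stand-in for a checkpoint on disk

restored = EpochCounter()
restored.load_state_dict(json.loads(saved))
```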

[Base Classes and Exceptions](./base-exceptions.md)

## Types

```python { .api }
class State:
    """Engine state containing training information."""
    def __init__(self):
        self.iteration = 0
        self.epoch = 0
        self.epoch_length = None
        self.max_epochs = None
        self.output = None
        self.batch = None
        self.metrics = {}
        self.dataloader = None
        self.seed = None
        self.times = {}

class RemovableEventHandle:
    """Handle for removable event handlers."""
    def remove(self): ...

class NotComputableError(RuntimeError):
    """Raised when a metric cannot be computed."""
    pass

class Serializable:
    """Mixin for serializable objects."""
    def state_dict(self): ...
    def load_state_dict(self, state_dict): ...

class DeterministicEngine(Engine):
    """Deterministic version of Engine with reproducible behavior."""
    def __init__(self, process_function, deterministic=True): ...

class EventEnum:
    """Base class for creating custom event enums."""
    pass

class EventsList:
    """Container for multiple events."""
    def __init__(self, *events): ...

class CallableEventWithFilter:
    """Event with conditional execution based on filter function."""
    def __init__(self, event, filter_fn, every=None, once=None): ...

class TbpttState:
    """Extended state for TBPTT training."""
    def __init__(self):
        self.tbptt_step = 0
        self.time_step = 0

class Parallel:
    """Parallel execution launcher for distributed training."""
    def __init__(self, backend=None, nprocs=None, **kwargs): ...
    def run(self, fn, *args, **kwargs): ...

class DistributedProxySampler:
    """Distributed sampler proxy for automatic distributed data sampling."""
    def __init__(self, sampler, num_replicas=None, rank=None, seed=0): ...
```
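The `num_replicas` and `rank` parameters of `DistributedProxySampler` point at the usual idea behind distributed sampling: each process reads a disjoint slice of the sampled indices. The sketch below assumes a strided split in which rank `r` takes every `num_replicas`-th index starting at `r`. `shard_indices` is a hypothetical helper that skips the padding and shuffling a real sampler may apply.

```python
def shard_indices(indices, num_replicas, rank):
    """Return the slice of `indices` that one rank would read (illustrative only)."""
    if not 0 <= rank < num_replicas:
        raise ValueError("rank must be in [0, num_replicas)")
    return indices[rank::num_replicas]


all_indices = list(range(10))
shards = [shard_indices(all_indices, num_replicas=3, rank=r) for r in range(3)]

# Together the shards cover every index exactly once.
covered = sorted(index for shard in shards for index in shard)
```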