
tessl/pypi-pytorch-ignite

- **Workspace**: tessl
- **Visibility**: Public
- **Describes**: `pkg:pypi/pytorch-ignite@0.5.x`

To install, run `npx @tessl/cli install tessl/pypi-pytorch-ignite@0.5.0`.

# PyTorch Ignite

A lightweight library to help with training neural networks in PyTorch. Ignite provides a flexible, extensible API that reduces boilerplate code compared to pure PyTorch implementations through its event-driven architecture and handler system.

## Package Information

- **Package Name**: pytorch-ignite
- **Package Type**: pypi
- **Language**: Python
- **Installation**: `pip install pytorch-ignite`

## Core Imports

```python
import ignite
```

Common imports for working with engines and training:

```python
from ignite.engine import Engine, Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss
from ignite.handlers import ModelCheckpoint, EarlyStopping
```

## Basic Usage

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss
from ignite.handlers import ModelCheckpoint

# Placeholder data (a random 2-class problem) so the example runs; replace with your own loaders
X, y = torch.randn(1000, 10), torch.randint(0, 2, (1000,))
train_loader = DataLoader(TensorDataset(X[:800], y[:800]), batch_size=32, shuffle=True)
val_loader = DataLoader(TensorDataset(X[800:], y[800:]), batch_size=64)

# Define model, optimizer, and loss (a classifier, so that Accuracy is well defined)
model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

# Create training and validation engines
trainer = create_supervised_trainer(model, optimizer, criterion)
evaluator = create_supervised_evaluator(model, metrics={'accuracy': Accuracy(), 'loss': Loss(criterion)})

# Add event handlers
@trainer.on(Events.ITERATION_COMPLETED(every=100))
def log_training_loss(engine):
    # By default, engine.state.output is the batch loss as a Python float
    print(f"Epoch[{engine.state.epoch}] Iteration[{engine.state.iteration}] Loss: {engine.state.output:.2f}")

@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(engine):
    evaluator.run(val_loader)
    metrics = evaluator.state.metrics
    print(f"Validation Results - Epoch: {engine.state.epoch} Avg accuracy: {metrics['accuracy']:.2f} Avg loss: {metrics['loss']:.2f}")

# Keep the best model by validation loss (higher scores are better, hence the negation);
# require_empty=False lets the script be re-run into an existing "models" directory
checkpoint = ModelCheckpoint('models', 'mymodel', score_function=lambda engine: -engine.state.metrics['loss'], require_empty=False)
evaluator.add_event_handler(Events.COMPLETED, checkpoint, {'model': model})

# Start training
trainer.run(train_loader, max_epochs=100)
```

## Architecture

PyTorch Ignite is built around several core architectural components:

- **Engine**: The central component that runs training and evaluation loops through an event-driven architecture
- **Events**: An event system that gives fine-grained control over the training lifecycle
- **Handlers**: Pluggable components for checkpointing, logging, scheduling, and training enhancements
- **Metrics**: An extensive collection of evaluation metrics for various machine learning tasks
- **Distributed**: Built-in support for distributed training across multiple backends

This design keeps the training loop flexible while removing boilerplate, so researchers and practitioners can focus on model development rather than training infrastructure.

## Capabilities

### Engine and Training Loop

Core training-loop infrastructure built on an event-driven architecture, including factory functions that create supervised training and evaluation engines.

```python { .api }
class Engine:
    def __init__(self, process_function): ...
    def run(self, data=None, max_epochs=None, epoch_length=None): ...
    def add_event_handler(self, event_name, handler, *args, **kwargs): ...
    def on(self, event_name, *args, **kwargs): ...

def create_supervised_trainer(model, optimizer, loss_fn, device=None, non_blocking=False, prepare_batch=None, output_transform=None, deterministic=False): ...
def create_supervised_evaluator(model, metrics=None, device=None, non_blocking=False, prepare_batch=None, output_transform=None): ...

class Events:
    STARTED = 'started'
    EPOCH_STARTED = 'epoch_started'
    ITERATION_STARTED = 'iteration_started'
    ITERATION_COMPLETED = 'iteration_completed'
    EPOCH_COMPLETED = 'epoch_completed'
    COMPLETED = 'completed'
    EXCEPTION_RAISED = 'exception_raised'
    # ...plus others such as GET_BATCH_STARTED, GET_BATCH_COMPLETED and TERMINATE
```

[Engine and Training Loop](./engine.md)

### Metrics Collection

Comprehensive metric collection system covering classification, regression, NLP, computer vision, clustering, and GAN evaluation with 80+ built-in metrics.

```python { .api }
class Metric:
    def reset(self): ...
    def update(self, output): ...
    def compute(self): ...

class Accuracy(Metric): ...
class Precision(Metric): ...
class Recall(Metric): ...
class Loss(Metric): ...
class MeanSquaredError(Metric): ...
class RootMeanSquaredError(Metric): ...
class ROC_AUC(Metric): ...  # implemented via EpochMetric; requires scikit-learn
class ConfusionMatrix(Metric): ...
```

[Metrics Collection](./metrics.md)

### Handlers and Training Enhancement

Training enhancement utilities including checkpointing, early stopping, logging, learning rate scheduling, and experiment tracking with 40+ built-in handlers.

```python { .api }
class Checkpoint:
    def __init__(self, to_save, save_handler, filename_prefix="", score_function=None, score_name=None, n_saved=1, global_step_transform=None, filename_pattern=None, include_self=False, greater_or_equal=False): ...
    # atomic, create_dir and require_empty are options of DiskSaver / ModelCheckpoint, not of Checkpoint

class EarlyStopping:
    def __init__(self, patience, score_function, trainer, min_delta=0.0, cumulative_delta=False): ...

class LRScheduler:
    def __init__(self, lr_scheduler, save_history=False, **kwds): ...

class ProgressBar:
    def __init__(self, persist=False, bar_format=None, **tqdm_kwargs): ...
```

[Handlers and Training Enhancement](./handlers.md)

### Distributed Training

Distributed computing support with multiple backends, including native `torch.distributed` (e.g. NCCL, Gloo), Horovod, and XLA/TPU.

```python { .api }
# Functions of ignite.distributed (conventionally imported as `idist`)
def initialize(backend, **kwargs): ...
def finalize(): ...
def all_reduce(tensor, op='SUM', group=None): ...
def all_gather(tensor, group=None): ...
def broadcast(tensor, src=0, group=None): ...
def barrier(group=None): ...
def spawn(backend, fn, args, kwargs_dict=None, nproc_per_node=1, **kwargs): ...
```

[Distributed Training](./distributed.md)

### Utilities and Helpers

Helper utilities for tensor operations, type conversions, logging setup, and reproducibility management.

```python { .api }
import logging

def convert_tensor(input_, device=None, non_blocking=False): ...
def to_onehot(indices, num_classes): ...
def setup_logger(name=None, level=logging.INFO, stream=None, format="%(asctime)s %(name)s %(levelname)s %(message)s", filepath=None, distributed_rank=None): ...
def manual_seed(seed): ...
def apply_to_tensor(input_, func): ...
def apply_to_type(input_, input_type, func): ...
```

[Utilities and Helpers](./utils.md)

### Contrib Module

Specialized engines and utilities for advanced use cases, including truncated backpropagation through time (TBPTT) for recurrent neural networks.

```python { .api }
# `tbtt_step` (sic) is the parameter name used by the library
def create_supervised_tbptt_trainer(model, optimizer, loss_fn, tbtt_step, dim=0, device=None, non_blocking=False, prepare_batch=None): ...

class Tbptt_Events:
    TIME_ITERATION_STARTED = 'time_iteration_started'
    TIME_ITERATION_COMPLETED = 'time_iteration_completed'
```

[Contrib Module](./contrib.md)

### Base Classes and Exceptions

Base classes shared across the library, such as `Serializable` (which defines `state_dict`/`load_state_dict`), and exception types such as `NotComputableError`.

```python { .api }
class Serializable:
    def state_dict(self): ...
    def load_state_dict(self, state_dict): ...

class NotComputableError(RuntimeError):
    pass
```

[Base Classes and Exceptions](./base-exceptions.md)

## Types

```python { .api }
class State:
    """Engine state containing training information."""
    def __init__(self):
        self.iteration = 0
        self.epoch = 0
        self.epoch_length = None
        self.max_epochs = None
        self.output = None
        self.batch = None
        self.metrics = {}
        self.dataloader = None
        self.seed = None
        self.times = {}

class RemovableEventHandle:
    """Handle for removable event handlers."""
    def remove(self): ...

class NotComputableError(RuntimeError):
    """Raised when a metric cannot be computed."""
    pass

class Serializable:
    """Mixin for serializable objects."""
    def state_dict(self): ...
    def load_state_dict(self, state_dict): ...

class DeterministicEngine(Engine):
    """Deterministic version of Engine with reproducible behavior."""
    def __init__(self, process_function): ...

class EventEnum:
    """Base class for creating custom event enums."""
    pass

class EventsList:
    """Container for multiple events."""
    def __init__(self, *events): ...

class CallableEventWithFilter:
    """Event with conditional execution based on a filter function."""
    def __init__(self, value, event_filter=None, name=None): ...
    def __call__(self, event_filter=None, every=None, once=None): ...

class TbpttState:
    """Extended state for TBPTT training."""
    def __init__(self):
        self.tbptt_step = 0
        self.time_step = 0

class Parallel:
    """Parallel execution launcher for distributed training."""
    def __init__(self, backend=None, nproc_per_node=None, **kwargs): ...
    def run(self, func, *args, **kwargs): ...

class DistributedProxySampler:
    """Distributed sampler proxy for automatic distributed data sampling."""
    def __init__(self, sampler, num_replicas=None, rank=None, seed=0): ...
```