A lightweight library to help with training neural networks in PyTorch.

Install with the Tessl CLI:

npx @tessl/cli install tessl/pypi-pytorch-ignite@0.5.0

Ignite provides a flexible, extensible API that reduces boilerplate code compared to pure PyTorch implementations through its event-driven architecture and handler system.

Install from PyPI:

pip install pytorch-ignite

Import the package:

import ignite

Common imports for working with engines and training:
from ignite.engine import Engine, Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss
from ignite.handlers import ModelCheckpoint, EarlyStopping

# Quick-start example: train a small regression model, report validation
# metrics every epoch, and checkpoint the best weights by validation loss.
from ignite.engine import Engine, Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Loss, MeanAbsoluteError
from ignite.handlers import ModelCheckpoint
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Synthetic regression data so the snippet is self-contained; replace these
# loaders with real ones for actual training.
features = torch.randn(1000, 10)
targets = features.sum(dim=1, keepdim=True) + 0.1 * torch.randn(1000, 1)
train_loader = DataLoader(TensorDataset(features[:800], targets[:800]), batch_size=32, shuffle=True)
val_loader = DataLoader(TensorDataset(features[800:], targets[800:]), batch_size=32)

# Define model, optimizer, and loss: one continuous output trained with MSE,
# i.e. a regression setup.
model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()

# Create training and validation engines. Accuracy is a classification metric
# that expects class labels, so regression metrics are reported instead.
trainer = create_supervised_trainer(model, optimizer, criterion)
evaluator = create_supervised_evaluator(
    model,
    metrics={"mae": MeanAbsoluteError(), "loss": Loss(criterion)},
)


# Add event handlers.
@trainer.on(Events.ITERATION_COMPLETED(every=100))
def log_training_loss(engine):
    # The trainer's state.output holds the batch loss that is printed here.
    print(f"Epoch[{engine.state.epoch}] Loss: {engine.state.output:.2f}")


@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(engine):
    # Metrics are computed on val_loader, so they are labelled as validation results.
    evaluator.run(val_loader)
    metrics = evaluator.state.metrics
    print(
        f"Validation Results - Epoch: {engine.state.epoch} "
        f"Avg MAE: {metrics['mae']:.2f} Avg loss: {metrics['loss']:.2f}"
    )


# Add model checkpointing. The validation loss is negated so that the
# lowest-loss model gets the highest score.
checkpoint = ModelCheckpoint("models", "mymodel", score_function=lambda engine: -engine.state.metrics["loss"])
evaluator.add_event_handler(Events.COMPLETED, checkpoint, {"model": model})

# Start training.
trainer.run(train_loader, max_epochs=100)

# PyTorch Ignite is built around several core architectural components:
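The main components are the Engine and its Events, Metrics, Handlers, distributed-computing helpers, general utilities, and contrib engines such as TBPTT, each summarized below. This design enables maximum flexibility while reducing boilerplate code, allowing researchers and practitioners to focus on model development rather than training infrastructure.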
Core training-loop infrastructure built on an event-driven architecture, providing supervised training and evaluation engines with full lifecycle management.
class Engine:
    """Signature summary of ignite's ``Engine``; every method body is elided.

    In the quick-start example, engines come from ``create_supervised_trainer``
    and ``create_supervised_evaluator``. Handlers are attached with ``on`` or
    ``add_event_handler``, and the loop is started with ``run``.
    """

    def __init__(self, process_function):
        """Accept the ``process_function`` that drives the engine (elided here)."""

    def run(self, data, max_epochs=1, epoch_length=None, seed=None):
        """Run over ``data`` for up to ``max_epochs`` epochs (elided; returns ``None`` here).

        NOTE(review): recent ignite releases appear to declare
        ``run(data=None, max_epochs=None, epoch_length=None)`` without ``seed``.
        Confirm against the installed version.
        """

    def add_event_handler(self, event_name, handler, *args, **kwargs):
        """Register ``handler`` for ``event_name`` (elided).

        The example passes an extra positional argument
        (``{'model': model}``), presumably forwarded to the handler.
        """

    def on(self, event_filter=None):
        """Decorator-style registration, as in ``@trainer.on(Events.EPOCH_COMPLETED)``.

        Elided here, so this stub returns ``None``.
        NOTE(review): ignite itself appears to declare
        ``on(self, event_name, *args, **kwargs)``. Verify.
        """
def create_supervised_trainer(
    model,
    optimizer,
    loss_fn,
    device=None,
    non_blocking=False,
    prepare_batch=None,
    output_transform=None,
    deterministic=False,
):
    """Factory for a supervised training engine (body elided; returns ``None`` here).

    In the quick-start example, the trainer's ``state.output`` is printed as the loss.
    NOTE(review): in ignite, ``prepare_batch`` and ``output_transform`` appear to
    default to callables rather than ``None``, and newer releases accept extra
    options. Verify against the installed version.
    """
def create_supervised_evaluator(
    model,
    metrics=None,
    device=None,
    non_blocking=False,
    prepare_batch=None,
    output_transform=None,
):
    """Factory for a supervised evaluation engine (body elided; returns ``None`` here).

    ``metrics`` maps names to metric objects. The quick-start example later reads
    the computed values back from ``evaluator.state.metrics`` by those names.
    """
class Events:
    """Names of the built-in engine lifecycle events (simplified).

    NOTE(review): the quick-start example calls
    ``Events.ITERATION_COMPLETED(every=100)``, which a plain ``str`` constant
    cannot support. In ignite these are ``EventEnum`` members that accept
    ``every=``/``once=`` filters. Only the event names are reproduced here.
    """

    STARTED = 'started'
    EPOCH_STARTED = 'epoch_started'
    ITERATION_STARTED = 'iteration_started'
    ITERATION_COMPLETED = 'iteration_completed'
    EPOCH_COMPLETED = 'epoch_completed'
    COMPLETED = 'completed'
    # Error-path event.
    EXCEPTION_RAISED = 'exception_raised'


# Comprehensive metric collection system covering classification, regression,
# NLP, computer vision, clustering, and GAN evaluation with 80+ built-in metrics.
class Metric:
    """Simplified interface of ignite's ``Metric`` base class (bodies elided).

    Judging by the hook names, this is the usual reset -> update -> compute cycle:
    clear accumulated state, fold in one ``output`` at a time, then produce the
    final value. NOTE(review): confirm the exact contracts in the ignite docs.
    """

    def reset(self):
        """Clear accumulated state (elided)."""

    def update(self, output):
        """Accumulate a single ``output`` (elided)."""

    def compute(self):
        """Produce the metric value (elided; this stub returns ``None``)."""


class Accuracy(Metric):
    """Stub for ignite's ``Accuracy`` classification metric."""


class Precision(Metric):
    """Stub for ignite's ``Precision`` metric."""


class Recall(Metric):
    """Stub for ignite's ``Recall`` metric."""


class Loss(Metric):
    """Stub for ignite's ``Loss`` metric."""


class MeanSquaredError(Metric):
    """Stub for ignite's ``MeanSquaredError`` regression metric."""


class RootMeanSquaredError(Metric):
    """Stub for ignite's ``RootMeanSquaredError`` regression metric."""


class RocAuc(Metric):
    """Stub for a ROC AUC metric.

    NOTE(review): ignite names this metric ``ROC_AUC``. ``RocAuc`` may not exist
    in the installed package, so verify before importing it.
    """


class ConfusionMatrix(Metric):
    """Stub for ignite's ``ConfusionMatrix`` metric."""


# Training enhancement utilities including checkpointing, early stopping,
# logging, learning rate scheduling, and experiment tracking with 40+ built-in
# handlers.
class Checkpoint:
    """Stub for ignite's ``Checkpoint`` handler (constructor body elided).

    ``score_function`` maps an engine to a score. The quick-start example's
    ``ModelCheckpoint`` passes the negated validation loss.
    NOTE(review): ``atomic``, ``require_empty`` and ``archived`` look like
    ``ModelCheckpoint``/disk-saver options rather than ``Checkpoint`` ones.
    Verify this parameter list against the installed ignite version.
    """

    def __init__(
        self,
        to_save,
        save_handler,
        filename_prefix="",
        score_function=None,
        score_name=None,
        n_saved=1,
        atomic=True,
        require_empty=True,
        archived=False,
        greater_or_equal=False,
    ):
        """Accept the objects in ``to_save`` and a ``save_handler`` (elided)."""
class EarlyStopping:
    """Stub for ignite's ``EarlyStopping`` handler (constructor body elided).

    Takes a ``patience``, a ``score_function`` and the ``trainer`` engine it
    acts on. ``min_delta`` and ``cumulative_delta`` presumably tune what counts
    as an improvement. Confirm in the ignite docs.
    """

    def __init__(self, patience, score_function, trainer, min_delta=0.0, cumulative_delta=False):
        """Store the early-stopping configuration (elided)."""
class LRScheduler:
    """Stub for ignite's ``LRScheduler`` wrapper (constructor body elided).

    Wraps an ``lr_scheduler``. ``save_history`` and any extra keyword options
    are accepted but unused in this summary.
    """

    def __init__(self, lr_scheduler, save_history=False, **kwds):
        """Accept the wrapped ``lr_scheduler`` (elided)."""
class ProgressBar:
    """Stub for ignite's ``ProgressBar`` handler (constructor body elided).

    Extra keyword options are collected as ``tqdm_kwargs``, presumably for the
    underlying tqdm bar.
    """

    def __init__(self, persist=False, bar_format=None, **tqdm_kwargs):
        """Accept display options (elided)."""


# Handlers and Training Enhancement
Comprehensive distributed-computing support for multiple backends, including native PyTorch DDP, Horovod, and XLA/TPU.
def initialize(backend=None, **kwargs):
    """Set up the distributed context for ``backend`` (elided)."""


def finalize():
    """Tear down the distributed context (elided)."""


def all_reduce(tensor, group=None, op='SUM'):
    """Reduce ``tensor`` across processes using ``op`` (elided).

    NOTE(review): ignite appears to declare ``all_reduce(tensor, op="SUM", group=None)``.
    Positional calls would bind differently, so verify.
    """


def all_gather(tensor, group=None):
    """Gather ``tensor`` from all processes (elided)."""


def broadcast(tensor, src=0, group=None):
    """Broadcast ``tensor`` from rank ``src`` (elided)."""


def barrier(group=None):
    """Synchronisation point for all processes (elided)."""


def spawn(fn, args=(), nprocs=1, join=True, daemon=False, start_method='spawn'):
    """Launch ``fn`` in ``nprocs`` processes (elided).

    NOTE(review): ``ignite.distributed.spawn`` appears to take the backend first,
    e.g. ``spawn(backend, fn, args, kwargs_dict=None, nproc_per_node=1, **kwargs)``.
    Verify this signature against the installed version.
    """


# Helper utilities for tensor operations, type conversions, logging setup, and
# reproducibility management.
# setup_logger's default ``level=logging.INFO`` is evaluated when the function
# is defined, so ``logging`` must be imported before these stubs.
import logging


def convert_tensor(input_, device=None, non_blocking=False):
    """Tensor conversion helper for ``input_`` and ``device`` (body elided)."""


def to_onehot(indices, num_classes):
    """One-hot encoding helper over ``num_classes`` classes (body elided)."""


def setup_logger(
    name=None,
    level=logging.INFO,
    stream=None,
    format="%(asctime)s %(name)s %(levelname)s %(message)s",
    filepath=None,
    distributed_rank=None,
):
    """Logger setup helper (body elided).

    Defaults: ``level`` is ``logging.INFO``; ``format`` includes time, logger
    name, level and message.
    """


def manual_seed(seed):
    """Reproducibility helper that takes a ``seed`` (body elided)."""


def apply_to_tensor(input_, func):
    """Apply ``func`` to tensors in ``input_`` (body elided)."""


def apply_to_type(input_, input_type, func):
    """Apply ``func`` to values of ``input_type`` in ``input_`` (body elided)."""


# Specialized engines and utilities for advanced use cases including truncated
# backpropagation through time (TBPTT) for recurrent neural networks.
def create_supervised_tbptt_trainer(
    model,
    optimizer,
    loss_fn,
    tbtt_step,
    device=None,
    non_blocking=False,
    prepare_batch=None,
    output_transform=None,
    deterministic=False,
):
    """Factory for a TBPTT training engine (body elided; returns ``None`` here).

    ``tbtt_step`` is presumably the truncation length in time steps. Confirm.
    NOTE(review): ignite's version appears to take a ``dim`` argument and no
    ``output_transform``/``deterministic``. Verify against the installed version.
    """
class Tbptt_Events:
    """Event names associated with the TBPTT trainer (simplified).

    NOTE(review): ignite's ``Tbptt_Events`` appears to define
    ``TIME_ITERATION_STARTED`` and ``TIME_ITERATION_COMPLETED``. Verify these
    names against the installed version.
    """

    TBPTT_STEP_COMPLETED = 'tbptt_step_completed'
    TIME_STEP_COMPLETED = 'time_step_completed'


# Core base classes and exception types providing fundamental functionality and
# error handling.
class Serializable:
    """Interface for objects that export and restore a state dictionary (bodies elided)."""

    def state_dict(self):
        """Return this object's state (elided; returns ``None`` here)."""

    def load_state_dict(self, state_dict):
        """Restore state from ``state_dict`` (elided)."""
class NotComputableError(RuntimeError):
    """Error type for a metric whose value cannot be computed."""


class State:
    """Attribute bag describing an engine run.

    The quick-start example reads ``state.epoch``, ``state.output`` (printed as
    the loss) and ``state.metrics`` (looked up by metric name).
    """

    def __init__(self):
        # Counters start at zero (chained assignment binds left to right).
        self.iteration = self.epoch = 0
        # Run configuration starts unset.
        self.epoch_length = self.max_epochs = None
        # Latest step output and batch start unset.
        self.output = self.batch = None
        # Fresh dict per instance; filled with metric values keyed by name.
        self.metrics = {}
        self.dataloader = self.seed = None
        # Fresh dict per instance; presumably timing information.
        self.times = {}
class RemovableEventHandle:
    """Handle returned for an event handler that can later be detached."""

    def remove(self):
        """Detach the associated handler (elided; returns ``None`` here)."""
class NotComputableError(RuntimeError):
    """Signals that a metric's value cannot be computed."""
class Serializable:
    """Mixin for objects that can save and restore their state (bodies elided)."""

    def state_dict(self):
        """Export state (elided; returns ``None`` here)."""

    def load_state_dict(self, state_dict):
        """Import state from ``state_dict`` (elided)."""
class DeterministicEngine(Engine):
    """``Engine`` variant intended for reproducible runs (constructor elided)."""

    def __init__(self, process_function, deterministic=True):
        """Accept ``process_function``; ``deterministic`` defaults to ``True`` (elided)."""
class EventEnum:
    """Base class to derive custom event enumerations from."""
class EventsList:
    """Groups several events together (constructor elided)."""

    def __init__(self, *events):
        """Accept any number of ``events`` (elided)."""
class CallableEventWithFilter:
    """An event whose handlers run conditionally, per ``filter_fn``/``every``/``once``."""

    def __init__(self, event, filter_fn, every=None, once=None):
        """Accept the base ``event`` and its filtering options (elided)."""
class TbpttState:
    """State extension for TBPTT training: two step counters, both starting at zero."""

    def __init__(self):
        self.tbptt_step, self.time_step = 0, 0
class Parallel:
    """Launcher for running a function in parallel for distributed training (bodies elided)."""

    def __init__(self, backend=None, nprocs=None, **kwargs):
        """Accept the ``backend`` and process count ``nprocs`` (elided)."""

    def run(self, fn, *args, **kwargs):
        """Execute ``fn`` with the given arguments (elided; returns ``None`` here)."""
class DistributedProxySampler:
    """Wraps an existing ``sampler`` for distributed data sampling (constructor elided)."""

    def __init__(self, sampler, num_replicas=None, rank=None, seed=0):
        """Accept the wrapped ``sampler`` plus replica/rank/seed settings (elided)."""