A lightweight library to help with training neural networks in PyTorch.

Core training loop infrastructure with an event-driven architecture. The Engine is the central component of PyTorch Ignite, providing a flexible framework for training and evaluating neural networks with comprehensive lifecycle management.

The main Engine class manages training and evaluation loops with a sophisticated event system.
class Engine:
    """
    Core engine for training and evaluation loops with event system.

    Parameters:
    - process_function: callable that processes a batch of data

    Attributes:
    - state: State object containing current training information
    - should_terminate: boolean flag to terminate training
    - should_terminate_single_epoch: boolean flag to terminate current epoch
    """

    def __init__(self, process_function):
        """Initialize engine with a process function."""

    def run(self, data, max_epochs=1, epoch_length=None, seed=None):
        """
        Run the engine on data for specified epochs.

        Parameters:
        - data: data loader or iterable
        - max_epochs: maximum number of epochs to run
        - epoch_length: number of iterations per epoch (optional)
        - seed: random seed for reproducibility

        Returns:
        State object with final training state
        """

    def add_event_handler(self, event_name, handler, *args, **kwargs):
        """
        Add an event handler for the specified event.

        Parameters:
        - event_name: name of the event
        - handler: callable to execute when event occurs
        - args, kwargs: arguments to pass to handler

        Returns:
        RemovableEventHandle object
        """

    def on(self, event_filter=None):
        """
        Decorator for adding event handlers.

        Parameters:
        - event_filter: event or event filter to listen for

        Returns:
        Decorator function
        """

    def fire_event(self, event_name):
        """Fire an event, executing all registered handlers."""

    def terminate(self):
        """Terminate the training loop."""

    def terminate_epoch(self):
        """Terminate the current epoch."""

    def has_event_handler(self, handler, event_name=None):
        """Check if handler is registered for event."""

    def remove_event_handler(self, handler, event_name):
        """Remove an event handler."""
class DeterministicEngine(Engine):
    """
    Deterministic version of Engine with reproducible behavior.

    Parameters:
    - process_function: callable that processes a batch of data
    - deterministic: enable deterministic behavior (default True)
    """

    def __init__(self, process_function, deterministic=True): ...

# Comprehensive event system providing fine-grained control over training lifecycle.
class Events:
    """Event types for engine lifecycle."""

    STARTED = 'started'
    EPOCH_STARTED = 'epoch_started'
    ITERATION_STARTED = 'iteration_started'
    ITERATION_COMPLETED = 'iteration_completed'
    EPOCH_COMPLETED = 'epoch_completed'
    COMPLETED = 'completed'
    EXCEPTION_RAISED = 'exception_raised'
    GET_BATCH_STARTED = 'get_batch_started'
    GET_BATCH_COMPLETED = 'get_batch_completed'
    DATALOADER_STOP_ITERATION = 'dataloader_stop_iteration'

    # NOTE(review): the staticmethods below shadow the same-named string
    # constants above. In the real Ignite API each of these events is a
    # CallableEventWithFilter, usable both as a plain event and as a call
    # with `every=`/`once=` filters — this stub cannot express that dual
    # nature, so both forms are listed. Verify against the library source.

    @staticmethod
    def ITERATION_STARTED(every=1, once=None):
        """Create event filter for iteration started events."""

    @staticmethod
    def ITERATION_COMPLETED(every=1, once=None):
        """Create event filter for iteration completed events."""

    @staticmethod
    def EPOCH_STARTED(every=1, once=None):
        """Create event filter for epoch started events."""

    @staticmethod
    def EPOCH_COMPLETED(every=1, once=None):
        """Create event filter for epoch completed events."""
class EventEnum:
    """
    Base class for creating custom event enums.

    Allows creation of custom events that integrate with the event system.
    """
    pass
class EventsList:
    """
    Container for multiple events.

    Allows grouping multiple events together for batch event handling.
    """

    def __init__(self, *events): ...
class CallableEventWithFilter:
    """
    Event with conditional execution based on a filter function.

    Parameters:
    - event: base event to filter
    - filter_fn: function that determines when the event should fire
    - every: fire every N occurrences (optional)
    - once: fire only at occurrence N (optional)
    """

    def __init__(self, event, filter_fn, every=None, once=None): ...

# Container for engine state information during training and evaluation.
class State:
    """
    Engine state containing training information.

    Attributes:
    - iteration: current iteration number (global)
    - epoch: current epoch number
    - epoch_length: length of current epoch
    - max_epochs: maximum number of epochs
    - output: output from last process_function call
    - batch: current batch data
    - metrics: dictionary of computed metrics
    - dataloader: current data loader
    - seed: random seed used
    - times: dictionary of timing information
    """

    def __init__(self):
        self.iteration = 0      # global iteration counter, not reset per epoch
        self.epoch = 0
        self.epoch_length = None
        self.max_epochs = None
        self.output = None      # value returned by the last process_function call
        self.batch = None
        self.metrics = {}
        self.dataloader = None
        self.seed = None
        self.times = {}

# Convenience functions for creating supervised training and evaluation engines.
def create_supervised_trainer(model, optimizer, loss_fn, device=None, non_blocking=False, prepare_batch=None, output_transform=None, deterministic=False):
    """
    Create an engine for supervised training.

    Parameters:
    - model: PyTorch model to train
    - optimizer: PyTorch optimizer
    - loss_fn: loss function
    - device: device to move data to (optional)
    - non_blocking: non-blocking data transfer
    - prepare_batch: function to prepare batch data
    - output_transform: function to transform engine output
    - deterministic: use deterministic algorithms

    Returns:
    Engine configured for supervised training
    """
def create_supervised_evaluator(model, metrics=None, device=None, non_blocking=False, prepare_batch=None, output_transform=None):
    """
    Create an engine for supervised evaluation.

    Parameters:
    - model: PyTorch model to evaluate
    - metrics: dictionary of metrics to compute
    - device: device to move data to (optional)
    - non_blocking: non-blocking data transfer
    - prepare_batch: function to prepare batch data
    - output_transform: function to transform engine output

    Returns:
    Engine configured for supervised evaluation
    """

# Factory functions for creating training step functions with different precision and device support.
def supervised_training_step(model, optimizer, loss_fn, device=None, non_blocking=False, prepare_batch=None, output_transform=None):
    """
    Factory function for a supervised training step.

    Parameters:
    - model: PyTorch model
    - optimizer: PyTorch optimizer
    - loss_fn: loss function
    - device: device to run on
    - non_blocking: non-blocking tensor transfers
    - prepare_batch: function to prepare batch data
    - output_transform: function to transform engine output

    Returns:
    Process function for training step
    """
def supervised_training_step_amp(model, optimizer, loss_fn, device=None, non_blocking=False, prepare_batch=None, output_transform=None, scaler=None):
    """
    Factory function for a supervised training step with automatic mixed precision.

    Parameters:
    - model: PyTorch model
    - optimizer: PyTorch optimizer
    - loss_fn: loss function
    - device: device to run on
    - non_blocking: non-blocking tensor transfers
    - prepare_batch: function to prepare batch data
    - output_transform: function to transform engine output
    - scaler: GradScaler for mixed precision

    Returns:
    Process function for AMP training step
    """
def supervised_training_step_apex(model, optimizer, loss_fn, device=None, non_blocking=False, prepare_batch=None, output_transform=None):
    """
    Factory function for a supervised training step with NVIDIA Apex.

    Parameters:
    - model: PyTorch model
    - optimizer: PyTorch optimizer
    - loss_fn: loss function
    - device: device to run on
    - non_blocking: non-blocking tensor transfers
    - prepare_batch: function to prepare batch data
    - output_transform: function to transform engine output

    Returns:
    Process function for Apex training step
    """
def supervised_training_step_tpu(model, optimizer, loss_fn, device=None, non_blocking=False, prepare_batch=None, output_transform=None):
    """
    Factory function for a supervised training step on TPU devices.

    Parameters:
    - model: PyTorch model
    - optimizer: PyTorch optimizer
    - loss_fn: loss function
    - device: device to run on
    - non_blocking: non-blocking tensor transfers
    - prepare_batch: function to prepare batch data
    - output_transform: function to transform engine output

    Returns:
    Process function for TPU training step
    """

# Factory functions for creating evaluation step functions with different precision support.
def supervised_evaluation_step(model, device=None, non_blocking=False, prepare_batch=None, output_transform=None):
    """
    Factory function for a supervised evaluation step.

    Parameters:
    - model: PyTorch model
    - device: device to run on
    - non_blocking: non-blocking tensor transfers
    - prepare_batch: function to prepare batch data
    - output_transform: function to transform engine output

    Returns:
    Process function for evaluation step
    """
def supervised_evaluation_step_amp(model, device=None, non_blocking=False, prepare_batch=None, output_transform=None):
    """
    Factory function for a supervised evaluation step with automatic mixed precision.

    Parameters:
    - model: PyTorch model
    - device: device to run on
    - non_blocking: non-blocking tensor transfers
    - prepare_batch: function to prepare batch data
    - output_transform: function to transform engine output

    Returns:
    Process function for AMP evaluation step
    """

# Handle for removable event handlers.
class RemovableEventHandle:
    """Handle for removable event handlers, returned by add_event_handler."""

    def remove(self):
        """Remove the associated event handler."""

# Usage example (illustrative — not part of the API surface above).
from ignite.engine import Engine, Events
def process_function(engine, batch):
    """Single training step: forward pass, loss, backward pass, optimizer update."""
    # NOTE(review): `model`, `optimizer`, and `criterion` are assumed to be
    # defined in the surrounding scope — this is an illustrative snippet.
    model.train()
    optimizer.zero_grad()
    x, y = batch
    y_pred = model(x)
    loss = criterion(y_pred, y)
    loss.backward()
    optimizer.step()
    return loss.item()

trainer = Engine(process_function)

@trainer.on(Events.ITERATION_COMPLETED(every=100))
def log_loss(engine):
    # engine.state.output holds the value returned by process_function.
    print(f"Iteration {engine.state.iteration}: Loss = {engine.state.output}")

trainer.run(train_loader, max_epochs=10)

# Execute every 50 iterations
@trainer.on(Events.ITERATION_COMPLETED(every=50))
def log_intermediate(engine):
    print(f"Iteration {engine.state.iteration}")

# Execute only once at iteration 100
@trainer.on(Events.ITERATION_COMPLETED(once=100))
def save_checkpoint(engine):
    torch.save(model.state_dict(), 'checkpoint.pth')

# Execute at the end of each epoch
@trainer.on(Events.EPOCH_COMPLETED)
def evaluate(engine):
    evaluator.run(val_loader)

# EXCEPTION_RAISED handlers receive the exception as a second argument.
@trainer.on(Events.EXCEPTION_RAISED)
def handle_exception(engine, e):
    print(f"Exception occurred: {e}")
    # Custom exception handling logic
    if isinstance(e, KeyboardInterrupt):
        print("Training interrupted by user")
    else:
        print("Unexpected error occurred")

# Install with Tessl CLI
npx tessl i tessl/pypi-pytorch-ignite