CtrlK
Blog · Docs · Log in · Get started
Tessl Logo

tessl/pypi-pytorch-ignite

A lightweight library to help with training neural networks in PyTorch.

Pending
Overview
Eval results
Files

docs/engine.md

Engine and Training Loop

Core training loop infrastructure with event-driven architecture. The Engine is the central component of PyTorch Ignite, providing a flexible framework for training and evaluating neural networks with comprehensive lifecycle management.

Capabilities

Engine Class

The main Engine class that manages training and evaluation loops with a sophisticated event system.

class Engine:
    """
    Core engine for training and evaluation loops with event system.
    
    Parameters:
    - process_function: callable that processes a batch of data
    
    Attributes:
    - state: State object containing current training information
    - should_terminate: boolean flag to terminate training
    - should_terminate_single_epoch: boolean flag to terminate current epoch
    """
    def __init__(self, process_function):
        """Initialize engine with a process function.

        The process function is invoked once per batch as
        ``process_function(engine, batch)`` and its return value is exposed
        as ``engine.state.output`` (see the usage examples below).
        """
        
    def run(self, data, max_epochs=1, epoch_length=None, seed=None):
        """
        Run the engine on data for specified epochs.
        
        Parameters:
        - data: data loader or iterable
        - max_epochs: maximum number of epochs to run
        - epoch_length: number of iterations per epoch (optional; presumably
          inferred from the data iterable when omitted — TODO confirm)
        - seed: random seed for reproducibility
        
        Returns:
        State object with final training state
        """
        
    def add_event_handler(self, event_name, handler, *args, **kwargs):
        """
        Add an event handler for the specified event.
        
        Parameters:
        - event_name: name of the event
        - handler: callable to execute when event occurs
        - args, kwargs: arguments to pass to handler
        
        Returns:
        RemovableEventHandle object (its ``remove()`` detaches the handler)
        """
        
    def on(self, event_filter=None):
        """
        Decorator for adding event handlers.
        
        Parameters:
        - event_filter: event or event filter to listen for, e.g.
          ``Events.EPOCH_COMPLETED`` or ``Events.ITERATION_COMPLETED(every=100)``
        
        Returns:
        Decorator function
        """
        
    def fire_event(self, event_name):
        """Fire an event, executing all registered handlers."""
        
    def terminate(self):
        """Terminate the training loop.

        Presumably works by setting the ``should_terminate`` flag — verify.
        """
        
    def terminate_epoch(self):
        """Terminate the current epoch.

        Presumably works by setting ``should_terminate_single_epoch`` — verify.
        """
        
    def has_event_handler(self, handler, event_name=None):
        """Check if handler is registered for event.

        When ``event_name`` is None, presumably checks all events — verify.
        """
        
    def remove_event_handler(self, handler, event_name):
        """Remove an event handler."""

class DeterministicEngine(Engine):
    """
    Deterministic version of Engine with reproducible behavior.
    
    Parameters:
    - process_function: callable that processes a batch of data
    - deterministic: enable deterministic behavior
    """
    # NOTE(review): spec stub — inherits the full Engine interface; only the
    # constructor gains the `deterministic` flag (default True).
    def __init__(self, process_function, deterministic=True): ...

Events Enum

Comprehensive event system providing fine-grained control over training lifecycle.

class _FilterableEvent(str):
    """Event name that is a plain string and is also callable.

    Calling it with ``every``/``once`` builds a filtered variant of the
    event, matching both usages shown in the examples below
    (``Events.EPOCH_COMPLETED`` and ``Events.ITERATION_COMPLETED(every=100)``).
    """

    def __call__(self, every=1, once=None):
        """Create an event filter for this event.

        Parameters:
        - every: fire on every N-th occurrence of the event
        - once: fire only at occurrence N (takes precedence over `every`)

        Returns:
        Filter spec as a ``(event_name, every, once)`` tuple.
        """
        return (str(self), every, once)


class Events:
    """Event types for engine lifecycle.

    Members compare equal to their string values.  The iteration/epoch
    events are additionally callable to create filtered events
    (e.g. ``Events.ITERATION_COMPLETED(every=100)``).

    Fixed: the previous spec redefined ITERATION_STARTED / ITERATION_COMPLETED /
    EPOCH_STARTED / EPOCH_COMPLETED as staticmethods below the string
    constants, shadowing them so the plain-event usage lost its string value.
    """
    STARTED = 'started'
    EPOCH_STARTED = _FilterableEvent('epoch_started')
    ITERATION_STARTED = _FilterableEvent('iteration_started')
    ITERATION_COMPLETED = _FilterableEvent('iteration_completed')
    EPOCH_COMPLETED = _FilterableEvent('epoch_completed')
    COMPLETED = 'completed'
    EXCEPTION_RAISED = 'exception_raised'
    GET_BATCH_STARTED = 'get_batch_started'
    GET_BATCH_COMPLETED = 'get_batch_completed'
    DATALOADER_STOP_ITERATION = 'dataloader_stop_iteration'

class EventEnum:
    """
    Base class for defining custom event enums.

    Subclass this to declare project-specific events that participate in
    the engine's event system alongside the built-in ones.
    """

class EventsList:
    """
    Grouping container for several events.

    Lets a single handler be attached to a batch of events at once.
    """
    def __init__(self, *events):
        """Collect the given events into this list."""

class CallableEventWithFilter:
    """
    An event whose firing is gated by a filter function.

    Parameters:
    - event: the base event being filtered
    - filter_fn: predicate deciding when the event should fire
    - every: fire on every N-th occurrence (optional)
    - once: fire only at occurrence N (optional)
    """
    def __init__(self, event, filter_fn, every=None, once=None):
        """Create the filtered event wrapper."""

Engine State

Container for engine state information during training and evaluation.

class State:
    """
    Mutable snapshot of an engine's progress during training/evaluation.

    Attributes:
    - iteration: global iteration counter
    - epoch: current epoch number
    - epoch_length: number of iterations in the current epoch
    - max_epochs: maximum number of epochs to run
    - output: result of the last process_function call
    - batch: batch currently being processed
    - metrics: computed metric values keyed by name
    - dataloader: data loader currently in use
    - seed: random seed in effect
    - times: timing information keyed by name
    """
    def __init__(self):
        # Counters begin at zero.
        self.iteration = 0
        self.epoch = 0
        # Everything configured or produced during a run starts unset.
        for name in ('epoch_length', 'max_epochs', 'output', 'batch',
                     'dataloader', 'seed'):
            setattr(self, name, None)
        # Aggregated results and timings start empty.
        self.metrics = {}
        self.times = {}

Supervised Training

Convenience functions for creating supervised training and evaluation engines.

def create_supervised_trainer(model, optimizer, loss_fn, device=None, non_blocking=False, prepare_batch=None, output_transform=None, deterministic=False):
    """Build an Engine configured for supervised training.

    Args:
        model: PyTorch model to train.
        optimizer: PyTorch optimizer.
        loss_fn: loss function.
        device: device to move data to (optional).
        non_blocking: whether data transfers are non-blocking.
        prepare_batch: function that prepares batch data.
        output_transform: function that transforms the engine output.
        deterministic: whether to use deterministic algorithms.

    Returns:
        Engine configured for supervised training.
    """

def create_supervised_evaluator(model, metrics=None, device=None, non_blocking=False, prepare_batch=None, output_transform=None):
    """Build an Engine configured for supervised evaluation.

    Args:
        model: PyTorch model to evaluate.
        metrics: dictionary of metrics to compute.
        device: device to move data to (optional).
        non_blocking: whether data transfers are non-blocking.
        prepare_batch: function that prepares batch data.
        output_transform: function that transforms the engine output.

    Returns:
        Engine configured for supervised evaluation.
    """

Training Step Functions

Factory functions for creating training step functions with different precision and device support.

def supervised_training_step(model, optimizer, loss_fn, device=None, non_blocking=False, prepare_batch=None, output_transform=None):
    """Build the per-batch process function for plain supervised training.

    Args:
        model: PyTorch model.
        optimizer: PyTorch optimizer.
        loss_fn: loss function.
        device: device to run on.
        non_blocking: whether tensor transfers are non-blocking.
        prepare_batch: function that prepares batch data.
        output_transform: function that transforms the engine output.

    Returns:
        Process function implementing one training step.
    """

def supervised_training_step_amp(model, optimizer, loss_fn, device=None, non_blocking=False, prepare_batch=None, output_transform=None, scaler=None):
    """Build the per-batch process function for AMP supervised training.

    Args:
        model: PyTorch model.
        optimizer: PyTorch optimizer.
        loss_fn: loss function.
        device: device to run on.
        non_blocking: whether tensor transfers are non-blocking.
        prepare_batch: function that prepares batch data.
        output_transform: function that transforms the engine output.
        scaler: GradScaler used for mixed-precision loss scaling.

    Returns:
        Process function implementing one AMP training step.
    """

def supervised_training_step_apex(model, optimizer, loss_fn, device=None, non_blocking=False, prepare_batch=None, output_transform=None):
    """Build the per-batch process function for NVIDIA Apex supervised training.

    Args:
        model: PyTorch model.
        optimizer: PyTorch optimizer.
        loss_fn: loss function.
        device: device to run on.
        non_blocking: whether tensor transfers are non-blocking.
        prepare_batch: function that prepares batch data.
        output_transform: function that transforms the engine output.

    Returns:
        Process function implementing one Apex training step.
    """

def supervised_training_step_tpu(model, optimizer, loss_fn, device=None, non_blocking=False, prepare_batch=None, output_transform=None):
    """Build the per-batch process function for supervised training on TPUs.

    Args:
        model: PyTorch model.
        optimizer: PyTorch optimizer.
        loss_fn: loss function.
        device: device to run on.
        non_blocking: whether tensor transfers are non-blocking.
        prepare_batch: function that prepares batch data.
        output_transform: function that transforms the engine output.

    Returns:
        Process function implementing one TPU training step.
    """

Evaluation Step Functions

Factory functions for creating evaluation step functions with different precision support.

def supervised_evaluation_step(model, device=None, non_blocking=False, prepare_batch=None, output_transform=None):
    """Build the per-batch process function for supervised evaluation.

    Args:
        model: PyTorch model.
        device: device to run on.
        non_blocking: whether tensor transfers are non-blocking.
        prepare_batch: function that prepares batch data.
        output_transform: function that transforms the engine output.

    Returns:
        Process function implementing one evaluation step.
    """

def supervised_evaluation_step_amp(model, device=None, non_blocking=False, prepare_batch=None, output_transform=None):
    """Build the per-batch process function for AMP supervised evaluation.

    Args:
        model: PyTorch model.
        device: device to run on.
        non_blocking: whether tensor transfers are non-blocking.
        prepare_batch: function that prepares batch data.
        output_transform: function that transforms the engine output.

    Returns:
        Process function implementing one AMP evaluation step.
    """

Event Handle

Handle for removable event handlers.

class RemovableEventHandle:
    """Handle returned when registering a handler; allows detaching it later."""
    def remove(self):
        """Detach the associated event handler from its engine."""

Usage Examples

Basic Training Loop

from ignite.engine import Engine, Events

# Process function: called once per batch; the returned loss value becomes
# engine.state.output.  NOTE(review): `model`, `optimizer`, `criterion` and
# `train_loader` are assumed to be defined by the surrounding script.
def process_function(engine, batch):
    model.train()
    optimizer.zero_grad()
    x, y = batch
    y_pred = model(x)
    loss = criterion(y_pred, y)
    loss.backward()
    optimizer.step()
    return loss.item()

trainer = Engine(process_function)

# Log the loss every 100 iterations using a filtered event.
@trainer.on(Events.ITERATION_COMPLETED(every=100))
def log_loss(engine):
    print(f"Iteration {engine.state.iteration}: Loss = {engine.state.output}")

trainer.run(train_loader, max_epochs=10)

Event Filtering

# Execute every 50 iterations
@trainer.on(Events.ITERATION_COMPLETED(every=50))
def log_intermediate(engine):
    print(f"Iteration {engine.state.iteration}")

# Execute only once at iteration 100.
# NOTE(review): example assumes `torch` and `model` are in scope.
@trainer.on(Events.ITERATION_COMPLETED(once=100))
def save_checkpoint(engine):
    torch.save(model.state_dict(), 'checkpoint.pth')

# Execute at the end of each epoch (unfiltered event)
@trainer.on(Events.EPOCH_COMPLETED)
def evaluate(engine):
    evaluator.run(val_loader)

Exception Handling

# Handler receives the engine and the raised exception when the engine
# fires the EXCEPTION_RAISED lifecycle event.
@trainer.on(Events.EXCEPTION_RAISED)
def handle_exception(engine, e):
    print(f"Exception occurred: {e}")
    # Custom exception handling logic
    if isinstance(e, KeyboardInterrupt):
        print("Training interrupted by user")
    else:
        print("Unexpected error occurred")

Install with Tessl CLI

npx tessl i tessl/pypi-pytorch-ignite

docs

base-exceptions.md

contrib.md

distributed.md

engine.md

handlers.md

index.md

metrics.md

utils.md

tile.json