tessl/pypi-pytorch-ignite

Describes: pkg:pypi/pytorch-ignite@0.5.x

To install, run

npx @tessl/cli install tessl/pypi-pytorch-ignite@0.5.0


PyTorch Ignite

A lightweight library to help with training neural networks in PyTorch. Ignite's event-driven architecture and handler system provide a flexible, extensible API that removes much of the boilerplate of a hand-written PyTorch training loop.

Package Information

  • Package Name: pytorch-ignite
  • Package Type: pypi
  • Language: Python
  • Installation: pip install pytorch-ignite

Core Imports

import ignite

Common imports for working with engines and training:

from ignite.engine import Engine, Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss
from ignite.handlers import ModelCheckpoint, EarlyStopping

Basic Usage

from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss
from ignite.handlers import ModelCheckpoint
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Define a small binary classifier, optimizer, and loss
model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

# Toy datasets so the example is self-contained; swap in real loaders
train_loader = DataLoader(TensorDataset(torch.randn(100, 10), torch.randint(0, 2, (100,))), batch_size=10)
val_loader = DataLoader(TensorDataset(torch.randn(40, 10), torch.randint(0, 2, (40,))), batch_size=10)

# Create training and validation engines
trainer = create_supervised_trainer(model, optimizer, criterion)
evaluator = create_supervised_evaluator(model, metrics={'accuracy': Accuracy(), 'loss': Loss(criterion)})

# Log the training loss every 100 iterations
@trainer.on(Events.ITERATION_COMPLETED(every=100))
def log_training_loss(engine):
    print(f"Epoch[{engine.state.epoch}] Loss: {engine.state.output:.2f}")

# Run validation at the end of every epoch
@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(engine):
    evaluator.run(val_loader)
    metrics = evaluator.state.metrics
    print(f"Validation Results - Epoch: {engine.state.epoch} Avg accuracy: {metrics['accuracy']:.2f} Avg loss: {metrics['loss']:.2f}")

# Keep the best model by validation loss (the score is maximized, so negate)
checkpoint = ModelCheckpoint('models', 'mymodel', score_function=lambda engine: -engine.state.metrics['loss'],
                             score_name='neg_val_loss', require_empty=False)
evaluator.add_event_handler(Events.COMPLETED, checkpoint, {'model': model})

# Start training
trainer.run(train_loader, max_epochs=100)

Architecture

PyTorch Ignite is built around several core architectural components:

  • Engine: The central component that manages training/evaluation loops with event-driven architecture
  • Events: Comprehensive event system allowing fine-grained control over training lifecycle
  • Handlers: Pluggable components for checkpointing, logging, scheduling, and training enhancements
  • Metrics: Extensive collection of evaluation metrics for various machine learning tasks
  • Distributed: Built-in support for distributed training across multiple backends

This design enables maximum flexibility while reducing boilerplate code, allowing researchers and practitioners to focus on model development rather than training infrastructure.
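
To make the event-driven flow concrete, here is a minimal sketch using only the public Engine and Events API: any per-batch function can be wrapped in an Engine, its return value becomes engine.state.output, and handlers hook into lifecycle events.

from ignite.engine import Engine, Events

def process_batch(engine, batch):
    # A real step would run the forward/backward pass here; whatever is
    # returned becomes engine.state.output for the current iteration.
    return {"processed": batch}

engine = Engine(process_batch)

@engine.on(Events.EPOCH_COMPLETED)
def report(engine):
    print(f"epoch {engine.state.epoch} done, last output: {engine.state.output}")

engine.run([1, 2, 3], max_epochs=2)  # any iterable can serve as the data source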

Capabilities

Engine and Training Loop

Core training loop infrastructure with event-driven architecture, providing supervised training and evaluation engines with comprehensive lifecycle management.

class Engine:
    def __init__(self, process_function): ...
    def run(self, data=None, max_epochs=None, epoch_length=None): ...
    def add_event_handler(self, event_name, handler, *args, **kwargs): ...
    def on(self, event_name, *args, **kwargs): ...

def create_supervised_trainer(model, optimizer, loss_fn, device=None, non_blocking=False, prepare_batch=None, output_transform=None, deterministic=False): ...
def create_supervised_evaluator(model, metrics=None, device=None, non_blocking=False, prepare_batch=None, output_transform=None): ...

class Events:
    STARTED = 'started'
    EPOCH_STARTED = 'epoch_started'
    ITERATION_STARTED = 'iteration_started'
    ITERATION_COMPLETED = 'iteration_completed'
    EPOCH_COMPLETED = 'epoch_completed'
    COMPLETED = 'completed'
    EXCEPTION_RAISED = 'exception_raised'


Metrics Collection

Comprehensive metric collection system covering classification, regression, NLP, computer vision, clustering, and GAN evaluation with 80+ built-in metrics.

class Metric:
    def reset(self): ...
    def update(self, output): ...
    def compute(self): ...

class Accuracy(Metric): ...
class Precision(Metric): ...
class Recall(Metric): ...
class Loss(Metric): ...
class MeanSquaredError(Metric): ...
class RootMeanSquaredError(Metric): ...
class RocAuc(Metric): ...
class ConfusionMatrix(Metric): ...
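
Metrics follow the reset/update/compute lifecycle and can be used standalone or attached to an engine, which accumulates them automatically over an evaluation run. A minimal standalone sketch:

import torch
from ignite.metrics import Accuracy

acc = Accuracy()
acc.reset()
# update() expects a (y_pred, y) pair; multiclass predictions are argmax'd
acc.update((torch.tensor([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]]),
            torch.tensor([0, 1, 1])))
print(acc.compute())  # 2 of 3 correct -> 0.666...

# Attached to an engine, the same object accumulates over a whole run:
# acc.attach(evaluator, "accuracy")  -> evaluator.state.metrics["accuracy"]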


Handlers and Training Enhancement

Training enhancement utilities including checkpointing, early stopping, logging, learning rate scheduling, and experiment tracking with 40+ built-in handlers.

class Checkpoint:
    def __init__(self, to_save, save_handler, filename_prefix="", score_function=None, score_name=None, n_saved=1, atomic=True, require_empty=True, archived=False, greater_or_equal=False): ...

class EarlyStopping:
    def __init__(self, patience, score_function, trainer, min_delta=0.0, cumulative_delta=False): ...

class LRScheduler:
    def __init__(self, lr_scheduler, save_history=False, **kwds): ...

class ProgressBar:
    def __init__(self, persist=False, bar_format=None, **tqdm_kwargs): ...
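
For example, EarlyStopping is attached to the evaluator and given a score function that it tries to maximize; it stops the trainer it was constructed with. A short sketch, reusing the trainer and evaluator from Basic Usage above:

from ignite.engine import Events
from ignite.handlers import EarlyStopping

def score_function(engine):
    return -engine.state.metrics["loss"]  # lower validation loss -> higher score

early_stopping = EarlyStopping(patience=5, score_function=score_function, trainer=trainer)
evaluator.add_event_handler(Events.COMPLETED, early_stopping)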


Distributed Training

Comprehensive distributed computing support with multiple backends including native PyTorch DDP, Horovod, and XLA/TPU support.

def initialize(backend=None, **kwargs): ...
def finalize(): ...
def all_reduce(tensor, group=None, op='SUM'): ...
def all_gather(tensor, group=None): ...
def broadcast(tensor, src=0, group=None): ...
def barrier(group=None): ...
def spawn(backend, fn, args=(), kwargs_dict=None, nproc_per_node=1, **kwargs): ...
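
These functions live in the ignite.distributed module, conventionally imported as idist. A minimal sketch, assuming the gloo backend is available locally:

import torch
import ignite.distributed as idist

def training(local_rank):
    rank = idist.get_rank()
    t = idist.all_reduce(torch.tensor([rank + 1.0]))  # SUM across processes by default
    print(f"rank {rank}: all_reduce result {t.item()}")

if __name__ == "__main__":
    # Parallel initializes the backend, spawns the workers, and finalizes on exit
    with idist.Parallel(backend="gloo", nproc_per_node=2) as parallel:
        parallel.run(training)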


Utilities and Helpers

Helper utilities for tensor operations, type conversions, logging setup, and reproducibility management.

def convert_tensor(input_, device=None, non_blocking=False): ...
def to_onehot(indices, num_classes): ...
def setup_logger(name=None, level=logging.INFO, stream=None, format="%(asctime)s %(name)s %(levelname)s %(message)s", filepath=None, distributed_rank=None): ...
def manual_seed(seed): ...
def apply_to_tensor(input_, func): ...
def apply_to_type(input_, input_type, func): ...
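
A short sketch of the most common helpers:

import torch
from ignite.utils import manual_seed, setup_logger, to_onehot

manual_seed(42)  # seeds random, numpy, and torch in one call

print(to_onehot(torch.tensor([0, 2, 1]), num_classes=3))
# tensor([[1, 0, 0],
#         [0, 0, 1],
#         [0, 1, 0]], dtype=torch.uint8)

logger = setup_logger("my-experiment")
logger.info("reproducibility and logging configured")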


Contrib Module

Specialized engines and utilities for advanced use cases including truncated backpropagation through time (TBPTT) for recurrent neural networks.

def create_supervised_tbptt_trainer(model, optimizer, loss_fn, tbtt_step, device=None, non_blocking=False, prepare_batch=None, output_transform=None, deterministic=False): ...

class Tbptt_Events:
    TBPTT_STEP_COMPLETED = 'tbptt_step_completed'
    TIME_STEP_COMPLETED = 'time_step_completed'
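
A hedged sketch of wiring up a TBPTT trainer; model, optimizer, and loss_fn are assumed to already be defined for sequence data (see docs/contrib.md for the exact model contract):

from ignite.contrib.engines import Tbptt_Events, create_supervised_tbptt_trainer

# model, optimizer, and loss_fn are assumed to be defined elsewhere
trainer = create_supervised_tbptt_trainer(model, optimizer, loss_fn, tbtt_step=4)

@trainer.on(Tbptt_Events.TIME_STEP_COMPLETED)
def on_time_step(engine):
    # fires after each time step processed within a batch
    print(f"iteration {engine.state.iteration}")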


Base Classes and Exceptions

Core base classes and exception types providing fundamental functionality and error handling.

class Serializable:
    def state_dict(self): ...
    def load_state_dict(self, state_dict): ...

class NotComputableError(RuntimeError):
    pass
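
These are the hooks for user-defined components: for instance, a custom metric subclasses Metric and raises NotComputableError when asked for a value before it has seen any data. A minimal sketch:

from ignite.exceptions import NotComputableError
from ignite.metrics import Metric

class MeanOutput(Metric):
    """Toy custom metric: running mean of scalar engine outputs."""

    def reset(self):
        self._sum = 0.0
        self._count = 0

    def update(self, output):
        self._sum += float(output)
        self._count += 1

    def compute(self):
        if self._count == 0:
            raise NotComputableError("MeanOutput has no data to compute from.")
        return self._sum / self._count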


Types

class State:
    """Engine state containing training information."""
    def __init__(self):
        self.iteration = 0
        self.epoch = 0
        self.epoch_length = None
        self.max_epochs = None
        self.output = None
        self.batch = None
        self.metrics = {}
        self.dataloader = None
        self.seed = None
        self.times = {}

class RemovableEventHandle:
    """Handle for removable event handlers."""
    def remove(self): ...
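
add_event_handler returns a RemovableEventHandle, so a handler can be detached later (it also works as a context manager). A quick sketch:

from ignite.engine import Engine, Events

engine = Engine(lambda eng, batch: batch)
handle = engine.add_event_handler(Events.ITERATION_COMPLETED, lambda eng: print("tick"))
handle.remove()  # the handler no longer fires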

class NotComputableError(RuntimeError):
    """Raised when a metric cannot be computed."""
    pass

class Serializable:
    """Mixin for serializable objects."""
    def state_dict(self): ...
    def load_state_dict(self, state_dict): ...

class DeterministicEngine(Engine):
    """Deterministic version of Engine with reproducible behavior."""
    def __init__(self, process_function): ...

class EventEnum:
    """Base class for creating custom event enums."""
    pass

class EventsList:
    """Container for multiple events."""
    def __init__(self, *events): ...

class CallableEventWithFilter:
    """Event with conditional execution based on filter function."""
    def __init__(self, event, filter_fn, every=None, once=None): ...

class TbpttState:
    """Extended state for TBPTT training."""
    def __init__(self):
        self.tbptt_step = 0
        self.time_step = 0

class Parallel:
    """Parallel execution launcher for distributed training."""
    def __init__(self, backend=None, nproc_per_node=None, **kwargs): ...
    def run(self, fn, *args, **kwargs): ...

class DistributedProxySampler:
    """Distributed sampler proxy for automatic distributed data sampling."""
    def __init__(self, sampler, num_replicas=None, rank=None, seed=0): ...