A lightweight library to help with training neural networks in PyTorch.
Training enhancement utilities including checkpointing, early stopping, logging, learning rate scheduling, and experiment tracking. PyTorch Ignite provides 40+ built-in handlers that plug into the event system to enhance training workflows.
Model and training state checkpointing with flexible save strategies.
class Checkpoint:
"""
Flexible checkpointing handler.
Parameters:
- to_save: dictionary of objects to save
- save_handler: handler for saving (DiskSaver, etc.)
- filename_prefix: prefix for checkpoint filenames
- score_function: function to compute checkpoint score
- score_name: name of the score metric
- n_saved: number of checkpoints to keep
- atomic: whether to use atomic saves
- require_empty: require empty directory
- archived: whether to archive old checkpoints
- greater_or_equal: score comparison direction
"""
def __init__(self, to_save, save_handler, filename_prefix="", score_function=None, score_name=None, n_saved=1, atomic=True, require_empty=True, archived=False, greater_or_equal=False): ...
class DiskSaver:
"""
Disk-based checkpoint saver.
Parameters:
- dirname: directory to save checkpoints
- atomic: whether to use atomic saves
- create_dir: whether to create directory if it doesn't exist
- require_empty: require empty directory
"""
def __init__(self, dirname, atomic=True, create_dir=True, require_empty=True): ...
class ModelCheckpoint:
"""
Model checkpoint handler (deprecated - use Checkpoint instead).
Parameters:
- dirname: directory to save checkpoints
- filename_prefix: prefix for checkpoint filenames
- score_function: function to compute checkpoint score
- score_name: name of the score metric
- n_saved: number of checkpoints to keep
- atomic: whether to use atomic saves
- require_empty: require empty directory
- create_dir: whether to create directory
- save_as_state_dict: save as state dict instead of full model
- global_step_transform: function to transform global step
"""
def __init__(self, dirname, filename_prefix, score_function=None, score_name=None, n_saved=1, atomic=True, require_empty=True, create_dir=True, save_as_state_dict=True, global_step_transform=None): ...
Early stopping based on validation metrics to prevent overfitting.
class EarlyStopping:
"""
Early stopping handler to prevent overfitting.
Parameters:
- patience: number of events to wait before stopping
- score_function: function to compute stopping score
- trainer: trainer engine to stop
- min_delta: minimum change required to reset patience
- cumulative_delta: whether to use cumulative delta
"""
def __init__(self, patience, score_function, trainer, min_delta=0.0, cumulative_delta=False): ...
Learning rate scheduling with various strategies and warmup support.
class LRScheduler:
"""
Learning rate scheduler wrapper.
Parameters:
- lr_scheduler: PyTorch learning rate scheduler
- save_history: whether to save LR history
- **kwds: additional arguments
"""
def __init__(self, lr_scheduler, save_history=False, **kwds): ...
def create_lr_scheduler_with_warmup(lr_scheduler, warmup_start_value, warmup_end_value, warmup_duration, save_history=False):
"""
Create learning rate scheduler with warmup.
Parameters:
- lr_scheduler: base learning rate scheduler
- warmup_start_value: starting learning rate for warmup
- warmup_end_value: ending learning rate for warmup
- warmup_duration: duration of warmup phase
- save_history: whether to save LR history
Returns:
Combined scheduler with warmup
"""
class CosineAnnealingScheduler:
"""
Cosine annealing scheduler.
Parameters:
- optimizer: PyTorch optimizer
- param_name: parameter name to schedule
- start_value: starting parameter value
- end_value: ending parameter value
- cycle_size: size of one cycle
- cycle_mult: cycle size multiplier
- start_value_mult: start value multiplier per cycle
- end_value_mult: end value multiplier per cycle
- save_history: whether to save parameter history
"""
def __init__(self, optimizer, param_name, start_value, end_value, cycle_size, cycle_mult=1.0, start_value_mult=1.0, end_value_mult=1.0, save_history=False): ...
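For example (a sketch; assumes an existing trainer and optimizer):
from ignite.engine import Events
from ignite.handlers import CosineAnnealingScheduler
# Anneal the learning rate from 1e-3 to 1e-5 over cycles of 1000 iterations
scheduler = CosineAnnealingScheduler(optimizer, "lr", start_value=1e-3, end_value=1e-5, cycle_size=1000)
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)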
class LinearCyclicalScheduler:
"""
Linear cyclical scheduler.
Parameters:
- optimizer: PyTorch optimizer
- param_name: parameter name to schedule
- start_value: starting parameter value
- end_value: ending parameter value
- cycle_size: size of one cycle
- cycle_mult: cycle size multiplier
- start_value_mult: start value multiplier per cycle
- end_value_mult: end value multiplier per cycle
- save_history: whether to save parameter history
"""
def __init__(self, optimizer, param_name, start_value, end_value, cycle_size, cycle_mult=1.0, start_value_mult=1.0, end_value_mult=1.0, save_history=False): ...
class ConcatScheduler:
"""
Concatenated scheduler combining multiple schedulers.
Parameters:
- schedulers: list of parameter schedulers to run in sequence
- durations: list of event counts after which to switch to the next scheduler
- save_history: whether to save parameter history
"""
def __init__(self, schedulers, durations, save_history=False): ...
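For example, a linear ramp followed by cosine annealing (a sketch; assumes an existing trainer and optimizer):
from ignite.engine import Events
from ignite.handlers import ConcatScheduler, CosineAnnealingScheduler, LinearCyclicalScheduler
# Ramp up for the first 100 iterations, then switch to annealing
ramp = LinearCyclicalScheduler(optimizer, "lr", start_value=0.0, end_value=1e-3, cycle_size=200)
anneal = CosineAnnealingScheduler(optimizer, "lr", start_value=1e-3, end_value=1e-5, cycle_size=1000)
scheduler = ConcatScheduler(schedulers=[ramp, anneal], durations=[100])
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)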
class PiecewiseLinear:
"""
Piecewise linear scheduler.
Parameters:
- optimizer: PyTorch optimizer
- param_name: parameter name to schedule
- milestones_values: list of (milestone, value) tuples
- save_history: whether to save parameter history
"""
def __init__(self, optimizer, param_name, milestones_values, save_history=False): ...
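For example, PiecewiseLinear can express a warmup-hold-decay schedule (a sketch; assumes an existing trainer and optimizer):
from ignite.engine import Events
from ignite.handlers import PiecewiseLinear
# Warm up to 1e-3 over 100 iterations, hold until iteration 900, then decay to 1e-5
scheduler = PiecewiseLinear(optimizer, "lr", milestones_values=[(0, 0.0), (100, 1e-3), (900, 1e-3), (1000, 1e-5)])
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
General parameter scheduling framework for optimizers.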
class ParamScheduler:
"""
Base parameter scheduler class.
Parameters:
- optimizer: PyTorch optimizer
- param_name: parameter name to schedule
- save_history: whether to save parameter history
"""
def __init__(self, optimizer, param_name, save_history=False): ...
class ParamGroupScheduler:
"""
Parameter group scheduler for different parameter groups.
Parameters:
- schedulers: list of schedulers for each parameter group
- names: names for each parameter group
"""
def __init__(self, schedulers, names=None): ...
class StateParamScheduler:
"""
State-based parameter scheduler.
Parameters:
- param_scheduler: base parameter scheduler
- param_name: parameter name
- save_history: whether to save parameter history
"""
def __init__(self, param_scheduler, param_name, save_history=False): ...
class LambdaStateScheduler(StateParamScheduler):
"""
Lambda-based state parameter scheduler.
Parameters:
- lambda_func: lambda function for scheduling
- param_name: parameter name
- save_history: whether to save parameter history
"""
def __init__(self, lambda_func, param_name, save_history=False): ...
class ExpStateScheduler(StateParamScheduler):
"""
Exponential decay state parameter scheduler.
Parameters:
- gamma: exponential decay factor
- param_name: parameter name
- save_history: whether to save parameter history
"""
def __init__(self, gamma, param_name, save_history=False): ...
class StepStateScheduler(StateParamScheduler):
"""
Step-based state parameter scheduler.
Parameters:
- step_size: step size for scheduling
- gamma: decay factor
- param_name: parameter name
- save_history: whether to save parameter history
"""
def __init__(self, step_size, gamma, param_name, save_history=False): ...
class MultiStepStateScheduler(StateParamScheduler):
"""
Multi-step state parameter scheduler.
Parameters:
- milestones: list of milestones
- gamma: decay factor
- param_name: parameter name
- save_history: whether to save parameter history
"""
def __init__(self, milestones, gamma, param_name, save_history=False): ...
class PiecewiseLinearStateScheduler(StateParamScheduler):
"""
Piecewise linear state parameter scheduler.
Parameters:
- milestones_values: list of (milestone, value) tuples
- param_name: parameter name
- save_history: whether to save parameter history
"""
def __init__(self, milestones_values, param_name, save_history=False): ...
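State parameter schedulers modify attributes of the engine state rather than optimizer parameters. A sketch following the attach pattern of recent Ignite releases (assumes an existing trainer; "loss_weight" is a hypothetical state attribute, and create_new is a recent-release flag not listed in the signature above):
from ignite.engine import Events
from ignite.handlers import PiecewiseLinearStateScheduler
# Schedule trainer.state.loss_weight from 1.0 down to 0.0 over 20 epochs
scheduler = PiecewiseLinearStateScheduler(
    milestones_values=[(0, 1.0), (10, 0.5), (20, 0.0)],
    param_name="loss_weight",
    create_new=True,  # create the state attribute if it does not exist yet
)
scheduler.attach(trainer, Events.EPOCH_COMPLETED)
Integration with popular experiment tracking and logging frameworks.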
class TensorboardLogger:
"""
TensorBoard logging handler.
Parameters:
- log_dir: directory for TensorBoard logs
- **kwargs: additional arguments for SummaryWriter
"""
def __init__(self, log_dir=None, **kwargs): ...
def attach_output_handler(self, engine, event_name, tag, output_transform=None, metric_names=None, global_step_transform=None):
"""Attach output logging handler."""
def attach_opt_params_handler(self, engine, event_name, optimizer, param_name="lr"):
"""Attach optimizer parameter logging handler."""
class VisdomLogger:
"""
Visdom logging handler.
Parameters:
- server: Visdom server URL
- port: server port
- **kwargs: additional Visdom arguments
"""
def __init__(self, server=None, port=8097, **kwargs): ...
class MLflowLogger:
"""
MLflow experiment tracking.
Parameters:
- tracking_uri: MLflow tracking server URI
- experiment_name: name of the experiment
- run_name: name of the run
- artifact_location: artifact storage location
- **kwargs: additional MLflow arguments
"""
def __init__(self, tracking_uri=None, experiment_name=None, run_name=None, artifact_location=None, **kwargs): ...
class NeptuneLogger:
"""
Neptune experiment tracking.
Parameters:
- api_token: Neptune API token
- project_name: Neptune project name
- experiment_name: name of the experiment
- **kwargs: additional Neptune arguments
"""
def __init__(self, api_token=None, project_name=None, experiment_name=None, **kwargs): ...
class WandBLogger:
"""
Weights & Biases experiment tracking.
Parameters:
- project: W&B project name
- entity: W&B entity name
- config: configuration dictionary
- **kwargs: additional W&B arguments
"""
def __init__(self, project=None, entity=None, config=None, **kwargs): ...
class ClearMLLogger:
"""
ClearML experiment tracking.
Parameters:
- project_name: ClearML project name
- task_name: task name
- **kwargs: additional ClearML arguments
"""
def __init__(self, project_name=None, task_name=None, **kwargs): ...
class PolyaxonLogger:
"""
Polyaxon experiment tracking.
Parameters:
- **kwargs: Polyaxon configuration arguments
"""
def __init__(self, **kwargs): ...
Progress bars and timing utilities for monitoring training.
class ProgressBar:
"""
Progress bar for training monitoring.
Parameters:
- persist: whether to persist after completion
- bar_format: custom bar format string
- **tqdm_kwargs: additional tqdm arguments
"""
def __init__(self, persist=False, bar_format=None, **tqdm_kwargs): ...
class Timer:
"""
Timer for measuring elapsed time.
Parameters:
- average: whether to compute running average
"""
def __init__(self, average=False): ...
def value(self):
"""Get current timer value."""
def reset(self):
"""Reset timer."""
def pause(self):
"""Pause timer."""
def resume(self):
"""Resume timer."""
class BasicTimeProfiler:
"""
Basic profiler for timing engine operations.
Parameters:
- dataflow_profiling: whether to profile data loading
"""
def __init__(self, dataflow_profiling=False): ...
def print_results(self, results_dict):
"""Print profiling results."""
class HandlersTimeProfiler:
"""
Profiler for timing handler execution.
"""
def __init__(self): ...
Handlers for enhancing model training behavior.
class GradientAccumulation:
"""
Gradient accumulation handler.
Parameters:
- accumulation_steps: number of steps to accumulate gradients
"""
def __init__(self, accumulation_steps): ...
class EMAHandler:
"""
Exponential Moving Average handler for model parameters.
Parameters:
- model: PyTorch model
- decay: decay factor for EMA
- device: device to store EMA parameters
"""
def __init__(self, model, decay=0.9999, device=None): ...
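A usage sketch (recent Ignite releases parameterize the EMA by a momentum argument, roughly 1 - decay; assumes an existing model and trainer):
from ignite.engine import Events
from ignite.handlers import EMAHandler
ema_handler = EMAHandler(model, momentum=0.0001)  # momentum ~ 1 - decay, i.e. decay=0.9999
ema_handler.attach(trainer, name="ema_momentum", event=Events.ITERATION_COMPLETED)
ema_model = ema_handler.ema_model  # averaged copy, e.g. for evaluation or checkpointing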
class FastaiLRFinder:
"""
Learning rate finder inspired by fastai.
Parameters:
- engine: training engine
- optimizer: PyTorch optimizer
- criterion: loss function
- device: device to run on
"""
def __init__(self, engine, optimizer, criterion, device=None): ...
def range_test(self, data_loader, start_lr=1e-7, end_lr=10, num_iter=100, step_mode="exp"):
"""Perform learning rate range test."""
class TerminateOnNan:
"""
Terminate training when NaN values are encountered.
"""
def __init__(self): ...
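For example (a sketch; assumes an existing trainer):
from ignite.engine import Events
from ignite.handlers import TerminateOnNan
# Stop training as soon as the step output contains NaN or infinite values
trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())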
class TimeLimit:
"""
Terminate training after specified time limit.
Parameters:
- limit: time limit in seconds
"""
def __init__(self, limit): ...
Base classes for creating custom handlers and loggers.
class BaseLogger:
"""Base class for loggers."""
def __init__(self): ...
class BaseOptimizerParams:
"""Base class for optimizer parameter handlers."""
def __init__(self): ...
class BaseOutputTransform:
"""Base class for output transformations."""
def __init__(self): ...
Helper functions for handlers and training enhancement.
def global_step_from_engine(engine):
"""
Get global step from engine state.
Parameters:
- engine: engine instance
Returns:
Global step number
"""from ignite.handlers import Checkpoint, DiskSaver
# Create checkpoint handler
to_save = {'model': model, 'optimizer': optimizer}
save_handler = DiskSaver('checkpoints', create_dir=True)
checkpoint = Checkpoint(
to_save,
save_handler,
filename_prefix='best',
score_function=lambda engine: -engine.state.metrics['loss'],
score_name='neg_loss',
n_saved=3
)
# Attach to evaluator
evaluator.add_event_handler(Events.COMPLETED, checkpoint)
from ignite.engine import Events
from ignite.handlers import EarlyStopping
# Create early stopping handler
early_stopping = EarlyStopping(
patience=10,
score_function=lambda engine: engine.state.metrics['accuracy'],
trainer=trainer
)
# Attach to evaluator
evaluator.add_event_handler(Events.COMPLETED, early_stopping)
from ignite.engine import Events
from ignite.handlers import LRScheduler
from torch.optim.lr_scheduler import StepLR
# Create PyTorch scheduler
torch_scheduler = StepLR(optimizer, step_size=10, gamma=0.1)
# Wrap with Ignite scheduler
lr_scheduler = LRScheduler(torch_scheduler)
# Attach to trainer
trainer.add_event_handler(Events.EPOCH_COMPLETED, lr_scheduler)
# Record LR history by passing save_history=True
lr_scheduler = LRScheduler(torch_scheduler, save_history=True)
# ... after training, per-event values are stored on the engine state
print(trainer.state.param_history["lr"])
from ignite.engine import Events
from ignite.handlers import TensorboardLogger, global_step_from_engine
# Create TensorBoard logger
tb_logger = TensorboardLogger(log_dir='tb_logs')
# Log training loss
tb_logger.attach_output_handler(
trainer,
event_name=Events.ITERATION_COMPLETED(every=100),
tag="training",
output_transform=lambda loss: {"loss": loss}
)
# Log validation metrics
tb_logger.attach_output_handler(
evaluator,
event_name=Events.COMPLETED,
tag="validation",
metric_names=["accuracy", "loss"],
global_step_transform=global_step_from_engine(trainer)
)
# Log learning rate
tb_logger.attach_opt_params_handler(
trainer,
event_name=Events.ITERATION_COMPLETED(every=100),
optimizer=optimizer,
param_name="lr"
)
# Don't forget to close
trainer.add_event_handler(Events.COMPLETED, lambda _: tb_logger.close())
from ignite.handlers import ProgressBar
# Create progress bar
pbar = ProgressBar(persist=True)
# Attach to trainer
pbar.attach(trainer, metric_names=['loss'])
# Or with custom output transform
pbar.attach(trainer, output_transform=lambda x: {'loss': x})
Install with Tessl CLI
npx tessl i tessl/pypi-pytorch-ignite