A lightweight library to help with training neural networks in PyTorch.

The contrib module provides specialized engines and utilities for advanced use cases: experimental functionality that extends PyTorch Ignite's core capabilities for specific training scenarios.

`create_supervised_tbptt_trainer` is a specialized trainer factory for recurrent neural networks using truncated backpropagation through time (TBPTT), enabling training of RNNs on very long sequences.
def create_supervised_tbptt_trainer(model, optimizer, loss_fn, tbtt_step, device=None, non_blocking=False, prepare_batch=None, output_transform=None, deterministic=False):
    """Create a supervised trainer for truncated backpropagation through time (TBPTT).

    NOTE(review): this is a documentation stub mirroring the public signature of
    ``ignite.contrib.engines.create_supervised_tbptt_trainer``; the real
    implementation lives in ignite's contrib module.

    Parameters:
    - model: PyTorch model to train.
    - optimizer: PyTorch optimizer updating the model's parameters.
    - loss_fn: loss function applied to each truncated segment.
    - tbtt_step: number of time steps after which backpropagation is truncated
      (name kept as ``tbtt_step`` to match the upstream API).
    - device: device to run on (optional).
    - non_blocking: whether tensor transfers are non-blocking.
    - prepare_batch: optional batch-preparation function.
    - output_transform: optional transformation applied to the step output.
    - deterministic: whether to create a deterministic engine.

    Returns:
        Engine configured for TBPTT training.
    """
class Tbptt_Events:
    """Events specific to truncated backpropagation through time training."""

    # Fired after each truncated backpropagation segment completes.
    TBPTT_STEP_COMPLETED = 'tbptt_step_completed'
    # Fired after each individual time step of the input sequence.
    TIME_STEP_COMPLETED = 'time_step_completed'


# Example usage — import the public API from ignite's contrib module:
from ignite.contrib.engines import create_supervised_tbptt_trainer, Tbptt_Events
import torch
import torch.nn as nn

# Define RNN model: a 2-layer LSTM over 100-dimensional inputs.
model = nn.LSTM(input_size=100, hidden_size=256, num_layers=2)
optimizer = torch.optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()

# Create TBPTT trainer: gradients are truncated every `tbtt_step` time steps.
trainer = create_supervised_tbptt_trainer(
    model=model,
    optimizer=optimizer,
    loss_fn=criterion,
    tbtt_step=20,  # Truncate backprop every 20 steps
)


# Add TBPTT-specific event handlers
@trainer.on(Tbptt_Events.TBPTT_STEP_COMPLETED)
def log_tbptt_step(engine):
    print(f"TBPTT step completed: {engine.state.iteration}")


@trainer.on(Tbptt_Events.TIME_STEP_COMPLETED)
def log_time_step(engine):
    print(f"Time step completed: {engine.state.iteration}")


# Train model
trainer.run(data_loader, max_epochs=10)  # NOTE(review): `data_loader` must be defined by the caller — presumably yields (sequence, target) batches; confirm against ignite docs.


class TbpttState:
    """Extended state for TBPTT training.

    Inherits all attributes from the base State class (per the original note)
    and adds counters matching the two TBPTT-specific events.
    """

    def __init__(self):
        # Count of completed truncated-backprop segments.
        self.tbptt_step = 0
        # Count of completed individual time steps.
        self.time_step = 0


# Install with Tessl CLI
npx tessl i tessl/pypi-pytorch-ignite