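# Contrib Module

Specialized engines and utilities for advanced use cases. The contrib module provides experimental functionality that extends PyTorch Ignite's core capabilities for specific training scenarios, such as training recurrent networks on long sequences.

## Capabilities

### Truncated Backpropagation Through Time (TBPTT)

A specialized trainer for recurrent neural networks that uses truncated backpropagation through time. It does not backpropagate through an entire sequence at once. Instead, it splits each sequence into chunks of `tbtt_step` time steps and computes gradients and updates parameters chunk by chunk, which makes it practical to train RNNs on very long sequences.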
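The per-batch update described above can be sketched in plain PyTorch. This is a minimal illustration written for this page, not Ignite's implementation. `TinyRNN` and `tbptt_update` are hypothetical names, and the sketch assumes that the model returns `(prediction, hidden)` and that time is axis 0 of each batch tensor.

```python
import torch
import torch.nn as nn


class TinyRNN(nn.Module):
    """Hypothetical model following the (prediction, hidden) return convention."""

    def __init__(self, n_in=8, n_hidden=16, n_out=1):
        super().__init__()
        self.rnn = nn.GRU(n_in, n_hidden)  # expects (T, B, n_in): time on axis 0
        self.head = nn.Linear(n_hidden, n_out)

    def forward(self, x, hidden=None):
        out, hidden = self.rnn(x, hidden)
        return self.head(out), hidden


def tbptt_update(model, optimizer, loss_fn, x, y, tbtt_step, dim=0):
    """One TBPTT update over a batch (x, y); returns (mean loss, per-chunk losses)."""
    model.train()
    hidden = None
    chunk_losses = []
    for x_t, y_t in zip(x.split(tbtt_step, dim=dim), y.split(tbtt_step, dim=dim)):
        if hidden is not None:
            # Keep the hidden values but cut the autograd graph at the chunk
            # boundary, so backward() never reaches earlier chunks.
            hidden = hidden.detach() if torch.is_tensor(hidden) else tuple(h.detach() for h in hidden)
        optimizer.zero_grad()
        y_pred_t, hidden = model(x_t, hidden)
        loss_t = loss_fn(y_pred_t, y_t)
        loss_t.backward()  # gradients flow through at most tbtt_step time steps
        optimizer.step()   # one parameter update per chunk
        chunk_losses.append(loss_t.item())
    return sum(chunk_losses) / len(chunk_losses), chunk_losses


torch.manual_seed(0)
model = TinyRNN()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
x = torch.randn(50, 4, 8)  # 50 time steps, batch of 4, 8 features
y = torch.randn(50, 4, 1)
mean_loss, losses = tbptt_update(model, optimizer, nn.MSELoss(), x, y, tbtt_step=20)
print(len(losses))  # 3 chunks: 20 + 20 + 10 time steps
```

Detaching the hidden state instead of resetting it is the key design choice. Information still flows forward across chunk boundaries, but each `backward()` call spans only the current chunk. As a result, autograd memory is bounded by `tbtt_step` instead of by the full sequence length.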
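```python { .api }
def create_supervised_tbptt_trainer(model, optimizer, loss_fn, tbtt_step, dim=0, device=None, non_blocking=False, prepare_batch=_prepare_batch):
    """
    Create a supervised trainer for truncated backpropagation through time.

    Each batch (x, y) is split into chunks of `tbtt_step` time steps along `dim`.
    For every chunk the trainer runs a forward pass, backpropagates that chunk's
    loss, and takes an optimizer step. The hidden state is detached and carried
    over to the next chunk.

    Parameters:
    - model: PyTorch model, called as model(x) or model(x, hidden); must return (y_pred, hidden)
    - optimizer: PyTorch optimizer
    - loss_fn: loss function, applied to each chunk
    - tbtt_step: length of each time chunk (the last chunk may be shorter)
    - dim: axis representing the time dimension
    - device: device to run on
    - non_blocking: if True, host-to-device copies may be asynchronous where supported
    - prepare_batch: function that receives (batch, device, non_blocking) and returns (x, y)

    Returns:
    Engine configured for TBPTT training; each iteration's output is the mean loss over its chunks
    """

class Tbptt_Events(EventEnum):
    """Events specific to truncated backpropagation through time training."""
    TIME_ITERATION_STARTED = 'time_iteration_started'      # fired before each time chunk
    TIME_ITERATION_COMPLETED = 'time_iteration_completed'  # fired after each time chunk
```

### Usage Example

```python
import torch
import torch.nn as nn
from ignite.contrib.engines import create_supervised_tbptt_trainer, Tbptt_Events
from ignite.engine import Events

# The trainer calls model(x) or model(x, hidden) and expects (y_pred, hidden) back,
# so wrap the LSTM with a classification head instead of passing nn.LSTM directly.
class SequenceClassifier(nn.Module):
    def __init__(self, input_size=100, hidden_size=256, num_classes=100):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers=2)  # time on axis 0
        self.head = nn.Linear(hidden_size, num_classes)

    def forward(self, x, hidden=None):
        out, hidden = self.lstm(x, hidden)  # out: (T, B, hidden_size)
        return self.head(out), hidden       # logits: (T, B, num_classes)

model = SequenceClassifier()
optimizer = torch.optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()

def loss_fn(logits, target):
    # Flatten time and batch: CrossEntropyLoss expects (N, C) logits and (N,) targets
    return criterion(logits.reshape(-1, logits.size(-1)), target.reshape(-1))

# Create TBPTT trainer
trainer = create_supervised_tbptt_trainer(
    model=model,
    optimizer=optimizer,
    loss_fn=loss_fn,
    tbtt_step=20,  # backpropagate through chunks of 20 time steps
)

@trainer.on(Tbptt_Events.TIME_ITERATION_COMPLETED)
def log_time_chunk(engine):
    # engine.state.output holds the loss of the chunk that just finished
    print(f"Chunk completed in iteration {engine.state.iteration}: loss={engine.state.output:.4f}")

@trainer.on(Events.ITERATION_COMPLETED)
def log_iteration(engine):
    # After the update function returns, output is the mean loss over all chunks
    print(f"Iteration {engine.state.iteration}: mean loss={engine.state.output:.4f}")

# data_loader must yield (x, y) with x: (T, B, 100) floats and y: (T, B) class indices
trainer.run(data_loader, max_epochs=10)
```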
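The usage example leaves `data_loader` undefined. The sketch below shows one way to build it from a single long token stream. `make_tbptt_batches` is a hypothetical helper. The sketch assumes next-token prediction over a 100-symbol vocabulary, one-hot inputs of shape `(T, B, 100)`, and time on axis 0 to match the default layout of `nn.LSTM`. A plain list of `(x, y)` tuples is used because `Engine.run` accepts a list of batches.

```python
import torch
import torch.nn.functional as F


def make_tbptt_batches(tokens, batch_size, seq_len, num_classes):
    """Hypothetical helper: cut a 1-D LongTensor of token ids into (x, y) batches.

    x: (<= seq_len, batch_size, num_classes) one-hot floats
    y: (<= seq_len, batch_size) index of the next token at each position
    """
    # Each of the batch_size columns is a contiguous slice of the stream;
    # one extra token is kept so targets can be shifted by one step.
    n_steps = (tokens.numel() - 1) // batch_size
    inputs = tokens[: n_steps * batch_size].view(batch_size, n_steps).t().contiguous()
    targets = tokens[1 : n_steps * batch_size + 1].view(batch_size, n_steps).t().contiguous()

    batches = []
    for start in range(0, n_steps, seq_len):
        x = F.one_hot(inputs[start : start + seq_len], num_classes).float()
        y = targets[start : start + seq_len]
        batches.append((x, y))
    return batches


torch.manual_seed(0)
tokens = torch.randint(0, 100, (10_001,))
data_loader = make_tbptt_batches(tokens, batch_size=16, seq_len=200, num_classes=100)
x, y = data_loader[0]
print(len(data_loader), x.shape, y.shape)  # 4 torch.Size([200, 16, 100]) torch.Size([200, 16])
```

With `seq_len=200` and `tbtt_step=20`, each full iteration performs 10 chunk-level updates. The final batch here has only 25 time steps, so it yields 2 chunks of 20 and 5 steps.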
## Types

```python { .api }
class TbpttState:
    """Extended state for TBPTT training."""
    def __init__(self):
        self.tbptt_step = 0
        self.time_step = 0
        # Inherits all attributes from base State class
```
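When interpreting chunk-level counters such as those in the state above, it helps to know how many chunks a single iteration contains. A time axis of length `T` split with `tbtt_step` yields `ceil(T / tbtt_step)` chunks, and the last chunk may be shorter. The pure-Python helper below is a hypothetical illustration of that arithmetic. It assumes the trainer partitions time the way `torch.Tensor.split` does.

```python
import math


def tbptt_chunks(seq_len, tbtt_step):
    """Hypothetical helper: (start, stop) indices of each time chunk.

    Mirrors how Tensor.split(tbtt_step) partitions an axis of length seq_len:
    equal chunks of tbtt_step, with a shorter final chunk if needed.
    """
    if tbtt_step <= 0:
        raise ValueError("tbtt_step must be a positive integer")
    return [(start, min(start + tbtt_step, seq_len)) for start in range(0, seq_len, tbtt_step)]


chunks = tbptt_chunks(seq_len=45, tbtt_step=10)
print(chunks)  # [(0, 10), (10, 20), (20, 30), (30, 40), (40, 45)]
assert len(chunks) == math.ceil(45 / 10)  # one chunk-level step per chunk
```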