
# Contrib Module


Specialized engines and utilities for advanced use cases. The contrib module provides experimental functionality that extends PyTorch Ignite's core for specific training scenarios.


## Capabilities


### Truncated Backpropagation Through Time (TBPTT)


Specialized trainer for recurrent neural networks using truncated backpropagation through time, enabling training of RNNs on very long sequences.
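Conceptually, TBPTT splits each batch along its time axis into chunks of `tbtt_step` steps, carries the hidden state from one chunk to the next but detaches it at every chunk boundary, and takes one optimizer step per chunk. The plain-PyTorch sketch below illustrates that loop under our own assumptions: the `Tagger` model, the helper `tbptt_update`, and the synthetic data are hypothetical, and this is not Ignite's implementation.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class Tagger(nn.Module):
    """Hypothetical model: batch-first LSTM plus a per-time-step classifier."""

    def __init__(self, input_size=8, hidden_size=16, num_classes=3):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.head = nn.Linear(hidden_size, num_classes)

    def forward(self, x, hidden=None):
        out, hidden = self.lstm(x, hidden)  # out: (batch, time, hidden_size)
        return self.head(out), hidden       # logits: (batch, time, num_classes)


def tbptt_update(model, optimizer, loss_fn, x, y, tbtt_step, dim=1):
    """Hypothetical helper: one TBPTT pass over a batch, returning per-chunk losses."""
    model.train()
    hidden = None
    losses = []
    for x_t, y_t in zip(x.split(tbtt_step, dim=dim), y.split(tbtt_step, dim=dim)):
        if hidden is not None:
            # Cut the graph at the chunk boundary: values still flow forward,
            # but gradients no longer reach earlier chunks.
            hidden = tuple(h.detach() for h in hidden)
        optimizer.zero_grad()
        logits, hidden = model(x_t, hidden)
        # Flatten (batch, time, classes) and (batch, time) for CrossEntropyLoss
        loss = loss_fn(logits.reshape(-1, logits.size(-1)), y_t.reshape(-1))
        loss.backward()  # backpropagates through at most tbtt_step time steps
        optimizer.step()
        losses.append(loss.item())
    return losses


model = Tagger()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x = torch.randn(4, 50, 8)         # (batch, time, features)
y = torch.randint(0, 3, (4, 50))  # one class label per time step
chunk_losses = tbptt_update(model, optimizer, nn.CrossEntropyLoss(), x, y, tbtt_step=20)
print(len(chunk_losses), sum(chunk_losses) / len(chunk_losses))
```

Because the graph is cut at every chunk boundary, the memory needed for backpropagation grows with `tbtt_step` rather than with the full sequence length.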


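```python { .api }
def create_supervised_tbptt_trainer(model, optimizer, loss_fn, tbtt_step, dim=0, device=None, non_blocking=False, prepare_batch=_prepare_batch):
    """
    Create a supervised trainer for truncated backpropagation through time.

    Parameters:
    - model: PyTorch model; called as model(x_t) on the first time chunk and
      model(x_t, hidden) on later chunks, and must return (y_pred, hidden)
    - optimizer: PyTorch optimizer
    - loss_fn: loss function, called as loss_fn(y_pred, y_t) on each chunk
    - tbtt_step: length of each time chunk (the last chunk may be shorter);
      gradients flow through at most this many time steps, and one optimizer
      step is taken per chunk
    - dim: index of the time dimension in the input and target tensors (default: 0)
    - device: device the batches are moved to (move the model yourself)
    - non_blocking: use non-blocking host-to-device tensor transfers
    - prepare_batch: function receiving (batch, device, non_blocking) and
      returning (x, y)

    Returns:
    Engine configured for TBPTT training; each iteration returns the average
    loss over the time chunks of the batch
    """

class Tbptt_Events(EventEnum):
    """Events fired once per time chunk within each training iteration."""
    TIME_ITERATION_STARTED = 'time_iteration_started'
    TIME_ITERATION_COMPLETED = 'time_iteration_completed'
```

### Usage Example

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from ignite.engine import Events
from ignite.contrib.engines import create_supervised_tbptt_trainer, Tbptt_Events


# The trainer calls model(x_t) on the first chunk and model(x_t, hidden) afterwards,
# and expects (y_pred, hidden) back, so wrap the LSTM with a per-step classifier.
class SequenceTagger(nn.Module):
    def __init__(self, input_size=100, hidden_size=256, num_classes=10):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers=2, batch_first=True)
        self.head = nn.Linear(hidden_size, num_classes)

    def forward(self, x, hidden=None):
        out, hidden = self.lstm(x, hidden)  # out: (batch, time, hidden_size)
        return self.head(out), hidden       # logits: (batch, time, num_classes)


model = SequenceTagger()
optimizer = torch.optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()


def loss_fn(y_pred, y):
    # Flatten (batch, time, classes) and (batch, time) to the (N, C) and (N,) shapes CrossEntropyLoss expects
    return criterion(y_pred.reshape(-1, y_pred.size(-1)), y.reshape(-1))


# Synthetic data for illustration: 64 sequences of 200 steps, one label per step
x = torch.randn(64, 200, 100)
y = torch.randint(0, 10, (64, 200))
data_loader = DataLoader(TensorDataset(x, y), batch_size=16)

# Create TBPTT trainer
trainer = create_supervised_tbptt_trainer(
    model=model,
    optimizer=optimizer,
    loss_fn=loss_fn,
    tbtt_step=20,  # backpropagate through, and step every, 20 time steps
    dim=1,         # time is dimension 1 because the data is batch-first
)

# Fired once per time chunk (here 200 / 20 = 10 times per batch)
@trainer.on(Tbptt_Events.TIME_ITERATION_COMPLETED)
def log_chunk(engine):
    print(f"Iteration {engine.state.iteration}, chunk loss: {engine.state.output:.4f}")

# Fired once per batch; output is the average loss over its chunks
@trainer.on(Events.ITERATION_COMPLETED)
def log_iteration(engine):
    print(f"Iteration {engine.state.iteration} completed, mean loss: {engine.state.output:.4f}")

# Train model
trainer.run(data_loader, max_epochs=10)
```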

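With `tbtt_step=20`, each batch is cut along its time dimension into chunks of 20 steps, the last one possibly shorter, and the trainer takes one optimizer step per chunk. The hypothetical helper `tbptt_chunks` below only computes that schedule in plain Python (no Ignite involved), which can help estimate how many optimizer steps and per-chunk handler calls a given sequence length produces.

```python
import math


def tbptt_chunks(seq_len: int, tbtt_step: int) -> list:
    """Hypothetical helper: (start, stop) time indices of each TBPTT chunk."""
    if tbtt_step <= 0:
        raise ValueError("tbtt_step must be positive")
    return [(start, min(start + tbtt_step, seq_len)) for start in range(0, seq_len, tbtt_step)]


chunks = tbptt_chunks(seq_len=50, tbtt_step=20)
print(chunks)  # [(0, 20), (20, 40), (40, 50)]

# One optimizer step per chunk, i.e. ceil(seq_len / tbtt_step) steps per batch
print(len(chunks), math.ceil(50 / 20))  # 3 3
```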

## Types


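The TBPTT trainer is a regular `Engine` and uses the standard `State`; there is no TBPTT-specific state class. Two standard attributes behave in a TBPTT-specific way:

- `engine.state.iteration` is incremented once per batch, so it stays the same for every time chunk of that batch.
- `engine.state.output` is set to the loss of each time chunk as soon as that chunk's optimizer step finishes; once the iteration completes, it holds the average loss over all chunks of the batch.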