# RNN Models

Recurrent neural network implementations, including LSTM models for text generation and sequence modeling tasks.

## Capabilities

### TextGenerationLSTM

LSTM architecture designed specifically for text generation tasks. The model can be trained on text corpora and used to generate new text sequences. The `numLabels` parameter represents the total number of unique characters in the vocabulary.

```java { .api }
/**
 * LSTM designed for text generation. Can be trained on a corpus of text.
 * Architecture follows the Keras LSTM text generation implementation.
 * Includes Walt Whitman pre-trained weights for generating text.
 */
class TextGenerationLSTM extends ZooModel {

    /**
     * Creates a TextGenerationLSTM with basic configuration.
     * @param numLabels  Total number of unique characters in the vocabulary
     * @param seed       Random seed for reproducibility
     * @param iterations Number of training iterations
     */
    TextGenerationLSTM(int numLabels, long seed, int iterations);

    /**
     * Creates a TextGenerationLSTM with a workspace mode configuration.
     * @param numLabels     Total number of unique characters in the vocabulary
     * @param seed          Random seed for reproducibility
     * @param iterations    Number of training iterations
     * @param workspaceMode Memory workspace configuration
     */
    TextGenerationLSTM(int numLabels, long seed, int iterations, WorkspaceMode workspaceMode);

    /**
     * Returns the LSTM network configuration.
     * @return MultiLayerConfiguration for the text generation LSTM
     */
    MultiLayerConfiguration conf();
}
```

**Architecture Details:**

- **Input Shape**: `[maxLength, totalUniqueCharacters]`, where `maxLength` defaults to 40
- **Layer 1**: GravesLSTM with 256 units and TANH activation
- **Layer 2**: GravesLSTM with 256 units and TANH activation
- **Output Layer**: RnnOutputLayer with SOFTMAX activation for character prediction
- **Loss Function**: MCXENT (multi-class cross-entropy)
- **Backprop Type**: Truncated BPTT with a forward/backward length of 50
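The input shape above implies that each training example is a one-hot-encoded character window: one timestep per character, one 1.0 per row at that character's vocabulary index. A minimal sketch of that encoding, using a hypothetical `encode` helper (not part of the zoo API):

```java
import java.util.Map;

/** Hypothetical sketch: one-hot encodes a character window into [maxLength, vocabSize]. */
public class OneHotWindow {

    public static double[][] encode(String window, Map<Character, Integer> charToIndex,
                                    int maxLength, int vocabSize) {
        double[][] features = new double[maxLength][vocabSize];
        for (int t = 0; t < window.length() && t < maxLength; t++) {
            Integer idx = charToIndex.get(window.charAt(t));
            if (idx != null) {
                features[t][idx] = 1.0; // one 1.0 per timestep, at the character's index
            }
        }
        return features;
    }

    public static void main(String[] args) {
        Map<Character, Integer> index = Map.of('a', 0, 'b', 1);
        double[][] encoded = encode("ab", index, 40, 2);
        System.out.println(encoded[0][0]); // 1.0 ('a' at timestep 0)
        System.out.println(encoded[1][1]); // 1.0 ('b' at timestep 1)
    }
}
```

Windows shorter than `maxLength` are simply zero-padded here; in practice you would also mask or align padding consistently with your training iterator.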

**Usage Examples:**

```java
// Create an LSTM for text generation with 47 unique characters
int vocabularySize = 47; // total unique characters in your text corpus
TextGenerationLSTM lstm = new TextGenerationLSTM(vocabularySize, 42, 1);

// Initialize the model
MultiLayerNetwork model = (MultiLayerNetwork) lstm.init();

// Get model metadata
ModelMetaData metadata = lstm.metaData();
int[][] inputShape = metadata.getInputShape(); // [40, 47] (sequence length, vocab size)
ZooType type = metadata.getZooType();          // ZooType.RNN

// Custom configuration with different parameters
TextGenerationLSTM customLSTM = new TextGenerationLSTM(
        100, // vocabulary size for a larger corpus
        123, // custom seed
        10,  // more iterations
        WorkspaceMode.SINGLE
);

MultiLayerConfiguration config = customLSTM.conf();
```

**Text Generation Workflow:**

```java
// 1. Prepare your text data
String corpus = "Your training text here...";
Map<Character, Integer> charToIndex = createCharacterIndex(corpus);
int vocabSize = charToIndex.size();

// 2. Create and train the model
TextGenerationLSTM lstm = new TextGenerationLSTM(vocabSize, 42, 100);
MultiLayerNetwork model = (MultiLayerNetwork) lstm.init();

// 3. Train on your text corpus (data preparation not shown)
// model.fit(trainingDataIterator);

// 4. Generate text (sampling logic not shown in the API)
// String generatedText = generateText(model, seedText, charToIndex);
```
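The workflow above references helpers that sit outside the zoo API: `createCharacterIndex`, and the sampling step inside `generateText` that turns the SOFTMAX output into a character choice. A minimal, hypothetical sketch of both (plain Java, no DL4J dependency):

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Random;

/** Hypothetical sketches of helpers referenced in the workflow but not part of the API. */
public class TextGenHelpers {

    /** Maps each distinct character in the corpus to a stable integer index. */
    public static Map<Character, Integer> createCharacterIndex(String corpus) {
        Map<Character, Integer> charToIndex = new LinkedHashMap<>();
        for (char c : corpus.toCharArray()) {
            // putIfAbsent keeps the first-seen index for each character
            charToIndex.putIfAbsent(c, charToIndex.size());
        }
        return charToIndex;
    }

    /** Samples a character index from one timestep's SOFTMAX output distribution. */
    public static int sampleFromDistribution(double[] probabilities, Random rng) {
        double r = rng.nextDouble();
        double cumulative = 0.0;
        for (int i = 0; i < probabilities.length; i++) {
            cumulative += probabilities[i];
            if (r < cumulative) {
                return i;
            }
        }
        return probabilities.length - 1; // guard against floating-point rounding
    }

    public static void main(String[] args) {
        Map<Character, Integer> index = createCharacterIndex("abab c");
        System.out.println(index.size()); // 4 distinct characters: 'a', 'b', ' ', 'c'

        int picked = sampleFromDistribution(new double[] {0.0, 1.0, 0.0}, new Random());
        System.out.println(picked); // always 1, since all mass is on index 1
    }
}
```

Sampling from the distribution (rather than always taking the argmax) is what gives generated text its variety; a temperature parameter can be layered on top by rescaling the probabilities before sampling.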

**Model Configuration Details:**

The LSTM uses the following configuration:

- **Optimization Algorithm**: Stochastic gradient descent
- **Learning Rate**: 0.01
- **Weight Initialization**: Xavier initialization
- **Regularization**: L2 with coefficient 0.001
- **Updater**: RMSProp
- **Sequence Handling**: Truncated BPTT for handling long sequences
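For orientation, the settings above map onto DL4J's configuration builder roughly as follows. This is an illustrative sketch in the pre-1.0 DL4J API style (the era that included `GravesLSTM` and an `iterations` setting), not the actual `conf()` implementation; method names may differ in your DL4J version.

```java
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions;

int vocabSize = 47; // numLabels
long seed = 42;
int iterations = 1;

MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .seed(seed)
        .iterations(iterations)
        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
        .learningRate(0.01)             // learning rate 0.01
        .updater(Updater.RMSPROP)       // RMSProp updater
        .weightInit(WeightInit.XAVIER)  // Xavier initialization
        .regularization(true).l2(0.001) // L2 regularization
        .list()
        .layer(0, new GravesLSTM.Builder().nIn(vocabSize).nOut(256)
                .activation(Activation.TANH).build())
        .layer(1, new GravesLSTM.Builder().nIn(256).nOut(256)
                .activation(Activation.TANH).build())
        .layer(2, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                .activation(Activation.SOFTMAX).nIn(256).nOut(vocabSize).build())
        .backpropType(BackpropType.TruncatedBPTT) // truncated BPTT, length 50
        .tBPTTForwardLength(50).tBPTTBackwardLength(50)
        .build();
```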

**Input Requirements:**

```java
TextGenerationLSTM lstm = new TextGenerationLSTM(47, 42, 1);
ModelMetaData metadata = lstm.metaData();

// Input shape: [sequenceLength, vocabularySize]
int[][] inputShape = metadata.getInputShape(); // [40, 47]
int sequenceLength = inputShape[0][0]; // 40 characters per sequence
int vocabSize = inputShape[0][1];      // 47 unique characters

// Output: a vocabulary-sized probability distribution
int numOutputs = metadata.getNumOutputs(); // 1 (single output per timestep)
```

**Memory and Performance:**

```java
// For memory-constrained environments
TextGenerationLSTM memoryLstm = new TextGenerationLSTM(
        vocabSize,
        42,
        1,
        WorkspaceMode.SEPARATE // better for memory usage
);

// For performance-optimized training
TextGenerationLSTM fastLstm = new TextGenerationLSTM(
        vocabSize,
        42,
        1,
        WorkspaceMode.SINGLE // better for speed
);
```