# RNN Models

Recurrent neural network implementations including LSTM models for text generation and sequence modeling tasks.

## Capabilities

### TextGenerationLSTM

LSTM architecture designed specifically for text generation tasks. The model can be trained on text corpora and used to generate new text sequences. The `numLabels` parameter represents the total number of unique characters in the vocabulary.

```java { .api }
/**
 * LSTM designed for text generation. Can be trained on a corpus of text.
 * Architecture follows the Keras LSTM text generation implementation.
 * Includes Walt Whitman pre-trained weights for generating text.
 */
class TextGenerationLSTM extends ZooModel {
    /**
     * Creates TextGenerationLSTM with basic configuration
     * @param numLabels Total number of unique characters in vocabulary
     * @param seed Random seed for reproducibility
     * @param iterations Number of training iterations
     */
    TextGenerationLSTM(int numLabels, long seed, int iterations);

    /**
     * Creates TextGenerationLSTM with workspace mode configuration
     * @param numLabels Total number of unique characters in vocabulary
     * @param seed Random seed for reproducibility
     * @param iterations Number of training iterations
     * @param workspaceMode Memory workspace configuration
     */
    TextGenerationLSTM(int numLabels, long seed, int iterations, WorkspaceMode workspaceMode);

    /**
     * Returns the LSTM network configuration
     * @return MultiLayerConfiguration for the text generation LSTM
     */
    MultiLayerConfiguration conf();
}
```

**Architecture Details:**

- **Input Shape**: `[maxLength, totalUniqueCharacters]` where `maxLength` defaults to 40
- **Layer 1**: GravesLSTM with 256 units and TANH activation
- **Layer 2**: GravesLSTM with 256 units and TANH activation
- **Output Layer**: RnnOutputLayer with SOFTMAX activation for character prediction
- **Loss Function**: MCXENT (Multi-Class Cross Entropy)
- **Backprop Type**: Truncated BPTT with forward/backward length of 50

**Usage Examples:**

```java
// Create LSTM for text generation with 47 unique characters
int vocabularySize = 47; // Total unique characters in your text corpus
TextGenerationLSTM lstm = new TextGenerationLSTM(vocabularySize, 42, 1);

// Initialize the model
MultiLayerNetwork model = (MultiLayerNetwork) lstm.init();

// Get model metadata
ModelMetaData metadata = lstm.metaData();
int[][] inputShape = metadata.getInputShape(); // [[40, 47]] (sequence length, vocab size)
ZooType type = metadata.getZooType(); // ZooType.RNN

// Custom configuration with different parameters
TextGenerationLSTM customLSTM = new TextGenerationLSTM(
        100, // vocabulary size for larger corpus
        123, // custom seed
        10,  // more iterations
        WorkspaceMode.SINGLE
);

MultiLayerConfiguration config = customLSTM.conf();
```

**Text Generation Workflow:**

```java
// 1. Prepare your text data
String corpus = "Your training text here...";
Map<Character, Integer> charToIndex = createCharacterIndex(corpus);
int vocabSize = charToIndex.size();

// 2. Create and train the model
TextGenerationLSTM lstm = new TextGenerationLSTM(vocabSize, 42, 100);
MultiLayerNetwork model = (MultiLayerNetwork) lstm.init();

// 3. Train on your text corpus (data preparation not shown)
// model.fit(trainingDataIterator);

// 4. Generate text (sampling logic not shown in API)
// String generatedText = generateText(model, seedText, charToIndex);
```

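The workflow above leaves `createCharacterIndex` and the sampling step undefined; they are not part of the zoo API. The sketch below shows one plain-JDK way to implement these hypothetical helpers: building the character-to-index vocabulary, and sampling a character index from the softmax distribution the output layer produces.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Random;

public class TextGenHelpers {

    // Assigns each distinct character in the corpus a stable index,
    // in order of first appearance. The map's size is the vocabulary size
    // to pass as numLabels.
    static Map<Character, Integer> createCharacterIndex(String corpus) {
        Map<Character, Integer> charToIndex = new LinkedHashMap<>();
        for (char c : corpus.toCharArray()) {
            charToIndex.putIfAbsent(c, charToIndex.size());
        }
        return charToIndex;
    }

    // Samples an index from a probability distribution (e.g. one timestep
    // of the SOFTMAX output) by inverse-CDF sampling.
    static int sampleFromDistribution(double[] probabilities, Random rng) {
        double r = rng.nextDouble();
        double cumulative = 0.0;
        for (int i = 0; i < probabilities.length; i++) {
            cumulative += probabilities[i];
            if (r < cumulative) {
                return i;
            }
        }
        return probabilities.length - 1; // guard against rounding error
    }

    public static void main(String[] args) {
        Map<Character, Integer> index = createCharacterIndex("abcabd");
        System.out.println(index.size()); // 4 distinct characters: a, b, c, d

        int sampled = sampleFromDistribution(new double[]{0.0, 1.0, 0.0}, new Random(42));
        System.out.println(sampled); // always 1: all probability mass is on index 1
    }
}
```

To generate text, repeatedly feed the current window through the network, sample the next character from the final timestep's distribution, and append it to the window; an inverted index (index to character) maps sampled indices back to text.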
**Model Configuration Details:**

The LSTM uses the following configuration:

- **Optimization Algorithm**: Stochastic Gradient Descent
- **Learning Rate**: 0.01
- **Weight Initialization**: Xavier initialization
- **Regularization**: L2 with 0.001 coefficient
- **Updater**: RMSProp
- **Sequence Handling**: Truncated BPTT for handling long sequences

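Taken together with the architecture details earlier, the configuration corresponds roughly to a builder chain like the following. This is an illustrative reconstruction assuming a DL4J 0.9.x-era API (`GravesLSTM`, `Updater.RMSPROP`, `regularization(true)`), not the literal zoo source; builder method names differ between DL4J versions.

```java
// Hypothetical reconstruction of what conf() builds; numLabels = vocabulary size.
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .seed(seed)
        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
        .learningRate(0.01)
        .weightInit(WeightInit.XAVIER)
        .regularization(true).l2(0.001)
        .updater(Updater.RMSPROP)
        .list()
        .layer(0, new GravesLSTM.Builder().nIn(numLabels).nOut(256)
                .activation(Activation.TANH).build())
        .layer(1, new GravesLSTM.Builder().nIn(256).nOut(256)
                .activation(Activation.TANH).build())
        .layer(2, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                .activation(Activation.SOFTMAX).nIn(256).nOut(numLabels).build())
        .backpropType(BackpropType.TruncatedBPTT)
        .tBPTTForwardLength(50)
        .tBPTTBackwardLength(50)
        .build();
```

The truncated-BPTT length of 50 bounds how far gradients flow back through time, which keeps memory use and training cost manageable on long corpora.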
**Input Requirements:**

```java
TextGenerationLSTM lstm = new TextGenerationLSTM(47, 42, 1);
ModelMetaData metadata = lstm.metaData();

// Input shape: [sequenceLength, vocabularySize]
int[][] inputShape = metadata.getInputShape(); // [[40, 47]]
int sequenceLength = inputShape[0][0]; // 40 characters per sequence
int vocabSize = inputShape[0][1]; // 47 unique characters

// Output: vocabulary-sized probability distribution
int numOutputs = metadata.getNumOutputs(); // 1 (single output per timestep)
```

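To satisfy this input shape, each training window is typically one-hot encoded: one column per timestep, with a single `1.0` in the row of that timestep's character. The library-free sketch below illustrates the layout using nested arrays; in actual training you would build the analogous DL4J `INDArray` (batched as `[miniBatch, vocabSize, sequenceLength]`), and the class and method names here are hypothetical.

```java
import java.util.Map;

public class OneHotEncoder {

    // One-hot encodes a text window into a [vocabSize][sequenceLength] matrix:
    // column t holds the one-hot vector for the character at timestep t.
    static float[][] encodeWindow(String window, Map<Character, Integer> charToIndex) {
        int vocabSize = charToIndex.size();
        float[][] encoded = new float[vocabSize][window.length()];
        for (int t = 0; t < window.length(); t++) {
            int row = charToIndex.get(window.charAt(t));
            encoded[row][t] = 1.0f;
        }
        return encoded;
    }

    public static void main(String[] args) {
        Map<Character, Integer> charToIndex = Map.of('a', 0, 'b', 1, 'c', 2);
        float[][] x = encodeWindow("ab", charToIndex);
        System.out.println(x.length);  // 3 (vocab size)
        System.out.println(x[0][0]);   // 1.0 ('a' at timestep 0)
        System.out.println(x[1][1]);   // 1.0 ('b' at timestep 1)
    }
}
```

The training labels use the same encoding, shifted one character ahead, so the network learns to predict the next character at every timestep.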
**Memory and Performance:**

```java
// For memory-constrained environments
TextGenerationLSTM lowMemoryLstm = new TextGenerationLSTM(
        vocabSize,
        42,
        1,
        WorkspaceMode.SEPARATE // Better for memory usage
);

// For performance-optimized training
TextGenerationLSTM fastLstm = new TextGenerationLSTM(
        vocabSize,
        42,
        1,
        WorkspaceMode.SINGLE // Better for speed
);
```