A model zoo for deeplearning4j that provides access to pre-trained neural network models
Recurrent neural network implementations, including LSTM models for text generation and sequence modeling.

TextGenerationLSTM is an LSTM architecture designed specifically for text generation. The model can be trained on a text corpus and then used to generate new character sequences; the numLabels parameter is the number of unique characters in the vocabulary.
/**
* LSTM designed for text generation. Can be trained on a corpus of text.
* Architecture follows Keras LSTM text generation implementation.
* Includes Walt Whitman pre-trained weights for generating text.
*/
class TextGenerationLSTM extends ZooModel {
/**
* Creates TextGenerationLSTM with basic configuration
* @param numLabels Total number of unique characters in vocabulary
* @param seed Random seed for reproducibility
* @param iterations Number of training iterations
*/
TextGenerationLSTM(int numLabels, long seed, int iterations);
/**
* Creates TextGenerationLSTM with workspace mode configuration
* @param numLabels Total number of unique characters in vocabulary
* @param seed Random seed for reproducibility
* @param iterations Number of training iterations
* @param workspaceMode Memory workspace configuration
*/
TextGenerationLSTM(int numLabels, long seed, int iterations, WorkspaceMode workspaceMode);
/**
* Returns the LSTM network configuration
* @return MultiLayerConfiguration for the text generation LSTM
*/
MultiLayerConfiguration conf();
}

Architecture Details:

Input shape: [maxLength, totalUniqueCharacters], where maxLength defaults to 40.

Usage Examples:
// Create LSTM for text generation with 47 unique characters
int vocabularySize = 47; // Total unique characters in your text corpus
TextGenerationLSTM lstm = new TextGenerationLSTM(vocabularySize, 42, 1);
// Initialize the model
MultiLayerNetwork model = (MultiLayerNetwork) lstm.init();
// Get model metadata
ModelMetaData metadata = lstm.metaData();
int[][] inputShape = metadata.getInputShape(); // [40, 47] (sequence length, vocab size)
ZooType type = metadata.getZooType(); // ZooType.RNN
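The [40, 47] input shape implies one-hot encoded character windows. As a minimal sketch, a pure-Java encoder for a single window might look like the following (OneHotEncoder and encodeWindow are hypothetical helpers, not part of the deeplearning4j-zoo API):

```java
import java.util.Map;

public class OneHotEncoder {
    /**
     * Encodes a character window as a one-hot matrix of shape
     * [sequenceLength][vocabSize], matching the [40, 47] input shape above.
     * Hypothetical helper; the zoo API does not include data preparation.
     */
    public static double[][] encodeWindow(String window, Map<Character, Integer> charToIndex) {
        int vocabSize = charToIndex.size();
        double[][] oneHot = new double[window.length()][vocabSize];
        for (int t = 0; t < window.length(); t++) {
            // Set a single 1.0 at the index of the character at timestep t
            oneHot[t][charToIndex.get(window.charAt(t))] = 1.0;
        }
        return oneHot;
    }
}
```

In practice the encoded window would still need to be wrapped in an INDArray with a leading mini-batch dimension before being fed to the network.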
// Custom configuration with different parameters
TextGenerationLSTM customLSTM = new TextGenerationLSTM(
100, // vocabulary size for larger corpus
123, // custom seed
10, // more iterations
WorkspaceMode.SINGLE
);
MultiLayerConfiguration config = customLSTM.conf();

Text Generation Workflow:
// 1. Prepare your text data
String corpus = "Your training text here...";
Map<Character, Integer> charToIndex = createCharacterIndex(corpus);
int vocabSize = charToIndex.size();
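The createCharacterIndex helper used above is not part of the zoo API; a minimal sketch of such a helper, assuming indices are assigned in order of first occurrence, might be:

```java
import java.util.LinkedHashMap;
import java.util.Map;

public class CharIndexer {
    /**
     * Builds a character-to-index map over the distinct characters in a corpus.
     * Hypothetical helper; data preparation is left to the caller by the zoo API.
     */
    public static Map<Character, Integer> createCharacterIndex(String corpus) {
        // LinkedHashMap keeps first-occurrence order, so indices are stable
        Map<Character, Integer> charToIndex = new LinkedHashMap<>();
        for (char c : corpus.toCharArray()) {
            charToIndex.putIfAbsent(c, charToIndex.size());
        }
        return charToIndex;
    }
}
```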
// 2. Create and train the model
TextGenerationLSTM lstm = new TextGenerationLSTM(vocabSize, 42, 100);
MultiLayerNetwork model = (MultiLayerNetwork) lstm.init();
// 3. Train on your text corpus (data preparation not shown)
// model.fit(trainingDataIterator);
// 4. Generate text (sampling logic not shown in API)
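The sampling step is not provided by the API. A minimal temperature-sampling sketch, assuming the network's softmax output for a timestep is available as a normalized double[] (CharSampler and sample are hypothetical names, mirroring the Keras lstm_text_generation sampling step):

```java
import java.util.Random;

public class CharSampler {
    /**
     * Samples a character index from a probability distribution rescaled by
     * temperature (lower temperature = more conservative choices).
     * Hypothetical helper; not part of the deeplearning4j-zoo API.
     */
    public static int sample(double[] probs, double temperature, Random rng) {
        double[] scaled = new double[probs.length];
        double sum = 0.0;
        for (int i = 0; i < probs.length; i++) {
            // Rescale in log space; clamp to avoid log(0)
            scaled[i] = Math.exp(Math.log(Math.max(probs[i], 1e-12)) / temperature);
            sum += scaled[i];
        }
        // Draw from the rescaled distribution via its cumulative sum
        double r = rng.nextDouble() * sum;
        double cumulative = 0.0;
        for (int i = 0; i < scaled.length; i++) {
            cumulative += scaled[i];
            if (r <= cumulative) {
                return i;
            }
        }
        return scaled.length - 1; // numerical fallback
    }
}
```

A generateText helper would repeatedly feed the current window to the network, sample the next character with this method, and append it to the output.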
// String generatedText = generateText(model, seedText, charToIndex);

Model Configuration Details:
The network configuration follows the Keras LSTM text generation implementation; the full MultiLayerConfiguration can be inspected via conf().
Input Requirements:
TextGenerationLSTM lstm = new TextGenerationLSTM(47, 42, 1);
ModelMetaData metadata = lstm.metaData();
// Input shape: [sequenceLength, vocabularySize]
int[][] inputShape = metadata.getInputShape(); // [40, 47]
int sequenceLength = inputShape[0][0]; // 40 characters per sequence
int vocabSize = inputShape[0][1]; // 47 unique characters
// Output: vocabulary-sized probability distribution
int numOutputs = metadata.getNumOutputs(); // 1 (single output per timestep)

Memory and Performance:
// For memory-constrained environments
TextGenerationLSTM lstm = new TextGenerationLSTM(
vocabSize,
42,
1,
WorkspaceMode.SEPARATE // Better for memory usage
);
// For performance-optimized training
TextGenerationLSTM lstm = new TextGenerationLSTM(
vocabSize,
42,
1,
WorkspaceMode.SINGLE // Better for speed
);Install with Tessl CLI
npx tessl i tessl/maven-org-deeplearning4j--deeplearning4j-zoo