Recurrent neural networks for sequence processing, time-series prediction, and natural-language tasks. Brain.js provides three recurrent variants — a basic RNN, LSTM (Long Short-Term Memory), and GRU (Gated Recurrent Unit) — each available in two forms: a string/sequence form (`brain.recurrent.RNN`, `LSTM`, `GRU`) and a numeric time-step form (`RNNTimeStep`, `LSTMTimeStep`, `GRUTimeStep`).
Basic recurrent neural network for sequence data processing.
/**
 * Configuration accepted by the recurrent network constructors.
 * Every field is optional.
 */
interface RNNOptions {
  /** Input layer size */
  inputSize?: number;
  /** Input range for character-based data */
  inputRange?: number;
  /** Array of hidden layer sizes (default: [20, 20]) */
  hiddenLayers?: number[];
  /** Output layer size */
  outputSize?: number;
  /** Learning rate (default: 0.01) */
  learningRate?: number;
  /**
   * Decay rate (default: 0.999).
   * NOTE(review): documented here as a learning-rate decay; confirm the exact
   * role against the brain.js source before relying on it.
   */
  decayRate?: number;
}

/**
 * Basic recurrent neural network for processing sequences.
 *
 * Declared with `declare class` because this listing documents the API shape
 * only: a regular class whose constructor has no body does not compile.
 */
declare class RNN {
  /**
   * @param options - RNN configuration options
   */
  constructor(options?: RNNOptions);
}
Long Short-Term Memory network for handling long sequences and mitigating vanishing gradients.
/**
 * LSTM network: an `RNN` whose recurrent layers use LSTM cells for better
 * long-term memory.
 * Takes the same constructor options as `RNN` and inherits all of its methods.
 */
declare class LSTM extends RNN {
  // No members of its own: constructor and methods come from RNN;
  // only the cell type used internally differs.
}
Gated Recurrent Unit network, a simplified alternative to LSTM.
/**
 * GRU network: an `RNN` whose recurrent layers use gated recurrent units.
 * Takes the same constructor options as `RNN` and inherits all of its methods.
 */
declare class GRU extends RNN {
  // No members of its own: constructor and methods come from RNN;
  // only the cell type used internally differs.
}
Usage Examples:
const brain = require('brain.js');

// Basic RNN
const rnn = new brain.recurrent.RNN({
  hiddenLayers: [10, 10],
  learningRate: 0.01
});

// LSTM for longer sequences
const lstm = new brain.recurrent.LSTM({
  hiddenLayers: [20, 20],
  learningRate: 0.005
});

// GRU for efficient processing
const gru = new brain.recurrent.GRU({
  hiddenLayers: [15, 15]
});

// Instance method of RNN, LSTM and GRU:
/**
 * Train the network on string sequences.
 * @param data - Either plain strings, or `{ input, output }` string pairs
 * @param options - Training configuration options
 * @returns Training statistics
 */
train(data: string[] | StringSequenceData[], options?: TrainingOptions): TrainingStats;
/**
 * One supervised training pair for the string-based networks,
 * e.g. `{ input: 'I am happy', output: 'positive' }`.
 */
interface StringSequenceData {
/** Text presented to the network */
input: string;
/** Text the network should learn to produce for `input` */
output: string;
}
/** Options accepted by `train`. Every field is optional. */
interface TrainingOptions {
  /** Maximum number of training iterations */
  iterations?: number;
  /** Target training error at which training may stop */
  errorThresh?: number;
  /**
   * `true` to log progress, or a custom logging function.
   * NOTE(review): the argument handed to a custom logger is not specified in
   * these docs; confirm it against the brain.js typings before narrowing
   * `Function` to a precise signature.
   */
  log?: boolean | Function;
  /** Number of iterations between log calls */
  logPeriod?: number;
  /** Learning rate to use for this training run */
  learningRate?: number;
  /** Function invoked periodically while training */
  callback?: Function;
  /** Number of iterations between `callback` invocations */
  callbackPeriod?: number;
}
Usage Examples:
// Train on string patterns.
// Each example uses its own network: they train on different kinds of data,
// so reusing one instance would make the examples depend on each other.

// Simple string array training
const textNet = new brain.recurrent.LSTM();
textNet.train(['hello world', 'goodbye world', 'hello universe']);

// Structured input-output training
const pairNet = new brain.recurrent.LSTM();
pairNet.train([
  { input: 'I am happy', output: 'positive' },
  { input: 'I am sad', output: 'negative' },
  { input: 'I am excited', output: 'positive' }
]);

// With training options
const tunedNet = new brain.recurrent.LSTM();
tunedNet.train(['pattern1', 'pattern2'], {
  iterations: 5000,
  errorThresh: 0.01,
  log: true
});

// Instance method of RNN, LSTM and GRU:
/**
 * Run the network on string input.
 * @param input - Input string to process
 * @returns Generated output string
 */
run(input: string): string;
Usage Examples:
// Text generation
const rnn = new brain.recurrent.LSTM();
rnn.train(['hello world', 'goodbye world']);
const output = rnn.run('hello');
console.log(output); // Continues the pattern
// Classification
const classifier = new brain.recurrent.RNN();
classifier.train([
{ input: 'I feel great', output: 'happy' },
{ input: 'I feel terrible', output: 'sad' }
]);
const sentiment = classifier.run('I feel amazing');
console.log(sentiment); // 'happy'For numerical time series and multi-dimensional sequence data.
/**
 * RNN designed for time-step numerical data.
 * Processes sequences of numbers or of number arrays.
 * Declared ambiently because only the API shape is documented here.
 */
declare class RNNTimeStep extends RNN {
  constructor(options?: RNNOptions);
}

/**
 * LSTM variant for time-step numerical data.
 * Better suited to long numerical sequences.
 */
declare class LSTMTimeStep extends RNNTimeStep {
  constructor(options?: RNNOptions);
}

/**
 * GRU variant for time-step numerical data.
 * Efficient processing of numerical sequences.
 */
declare class GRUTimeStep extends RNNTimeStep {
  constructor(options?: RNNOptions);
}

// Instance method of RNNTimeStep, LSTMTimeStep and GRUTimeStep:
/**
 * Train on numerical sequences or multi-dimensional arrays.
 * @param data - Time series training pairs; per the signature, all pairs in
 *   one call share the same dimensionality (1-D, 2-D or 3-D)
 * @param options - Training options
 * @returns Training statistics
 */
train(data: TimeStepData[] | TimeStepData2D[] | TimeStepData3D[], options?: TrainingOptions): TrainingStats;
/** One-dimensional training pair, e.g. `{ input: [1, 2, 3], output: [4] }`. */
interface TimeStepData {
/** Sequence of scalar values fed to the network */
input: number[];
/** Value(s) the network should predict for `input` */
output: number[];
}
/**
 * Two-dimensional training pair: each inner array holds several feature values
 * for one step, e.g. `{ input: [[1, 0.5], [2, 0.6], [3, 0.7]], output: [[4, 0.8]] }`.
 */
interface TimeStepData2D {
/** Sequence of feature vectors fed to the network */
input: number[][];
/** Feature vector(s) the network should predict for `input` */
output: number[][];
}
/**
 * Three-dimensional training pair.
 * NOTE(review): these docs do not show what each nesting level means —
 * presumably one 2-D array per time step; confirm before relying on it.
 */
interface TimeStepData3D {
  input: number[][][];
  output: number[][][];
}
Usage Examples:
const lstm = new brain.recurrent.LSTMTimeStep();

// 1D time series
lstm.train([
  { input: [1, 2, 3], output: [4] },
  { input: [2, 3, 4], output: [5] },
  { input: [3, 4, 5], output: [6] }
]);

// 2D sequences (e.g., multiple features over time).
// Uses its own network: each step carries 2 values, so the input and output
// layer sizes are 2 instead of the 1-value network trained above.
const lstm2D = new brain.recurrent.LSTMTimeStep({ inputSize: 2, outputSize: 2 });
lstm2D.train([
  {
    input: [[1, 0.5], [2, 0.6], [3, 0.7]],
    output: [[4, 0.8]]
  }
]);

// Stock price prediction example (same 1-value layout as the 1D series)
const stockData = [
  { input: [100, 101, 102], output: [103] },
  { input: [101, 102, 103], output: [104] },
  { input: [102, 103, 104], output: [105] }
];
lstm.train(stockData, {
  iterations: 10000,
  errorThresh: 0.001
});

// Instance method of the time-step networks:
/**
 * Run inference on numerical sequences.
 * @param input - Numerical input sequence, with the same dimensionality as the
 *   data the network was trained on
 * @returns Predicted numerical output
 */
run(input: number[] | number[][] | number[][][]): number[] | number[][] | number[][][];
Usage Examples:
// 1D sequence prediction
const result = lstm.run([4, 5, 6]);
console.log(result); // [7] (approximate)

// 2D sequence prediction
// (requires a network trained on 2-value steps, i.e. inputSize/outputSize of 2)
const result2D = lstm.run([[4, 0.8], [5, 0.9], [6, 1.0]]);
console.log(result2D); // [[7, 1.1]] (approximate)

// Multi-step prediction: slide the window forward, feeding each prediction
// back in as the newest value.
const sequence = [10, 11, 12];
const predictions = [];
let current = sequence;
for (let i = 0; i < 5; i++) {
  const next = lstm.run(current);
  // NOTE(review): for 1-D input the network may return a bare number rather
  // than a one-element array; accept both so the window stays numeric.
  const value = Array.isArray(next) ? next[0] : next;
  predictions.push(value);
  current = [...current.slice(1), value];
}
console.log(predictions); // [13, 14, 15, 16, 17] (approximate)

// Instance methods of the recurrent networks:
/**
 * Serialize RNN to JSON format.
 * @returns JSON representation
 */
toJSON(): RNNJson;
/**
 * Load RNN from JSON format.
 * @param json - JSON representation
 * @returns Network instance
 */
fromJSON(json: RNNJson): RNN;
Usage Examples:
// Save trained RNN
const rnnData = lstm.toJSON();
const jsonString = JSON.stringify(rnnData);
// Load RNN
const newLSTM = new brain.recurrent.LSTMTimeStep();
newLSTM.fromJSON(JSON.parse(jsonString));
// Test loaded network
const output = newLSTM.run([1, 2, 3]);Utility for preparing data for recurrent network training.
/**
 * Formats data for RNN training by mapping each distinct character/value to
 * an index and back.
 * Declared ambiently because only the API shape is documented here.
 * @param values - Input values to format
 * @param maxThreshold - Threshold used when filtering while building tables
 *   (NOTE(review): exact filtering rule not shown here — confirm)
 */
declare class DataFormatter {
  constructor(values: string[] | number[], maxThreshold?: number);
  /** Build character tables from input data */
  buildCharactersFromIterable(values: any[]): void;
  /** Build lookup tables with threshold filtering */
  buildTables(maxThreshold: number): void;
  /** Maps each character/value to its index, e.g. `{ h: 0, e: 1, ... }` */
  indexTable: object;
  /** Reverse mapping from index back to character/value */
  characterTable: object;
  /** Unique characters/values, listed in order of first appearance */
  characters: any[];
}
Usage Examples:
// Format text data
const formatter = new brain.utilities.DataFormatter([
  'hello world',
  'goodbye world'
]);
console.log(formatter.characters); // ['h', 'e', 'l', 'o', ' ', 'w', 'r', 'd', 'g', 'b', 'y']
console.log(formatter.indexTable); // {'h': 0, 'e': 1, ...}

// Use with RNN
const rnn = new brain.recurrent.LSTM();
rnn.dataFormatter = formatter;

// Text generation
const textGenerator = new brain.recurrent.LSTM({
  hiddenLayers: [50, 50],
  learningRate: 0.01,
  decayRate: 0.999
});
const texts = [
  'The quick brown fox',
  'jumps over the lazy dog',
  'Pack my box with',
  'five dozen liquor jugs'
];
textGenerator.train(texts, {
  iterations: 10000,
  // NOTE(review): brain.js's recurrent `train` appears to pass `log` a
  // preformatted status string and `callback` an `{ iterations, error }`
  // object, so structured progress is read from `callback` — confirm
  // against the installed brain.js version.
  callback: (stats) => console.log(`Iteration ${stats.iterations}, Error: ${stats.error}`)
});
const generated = textGenerator.run('The quick');
console.log(generated);

// Time-series prediction
const timeSeries = new brain.recurrent.LSTMTimeStep({
  inputSize: 1,
  hiddenLayers: [20, 20],
  outputSize: 1
});

// Generate sine wave data: each sample is a window of `windowSize`
// consecutive points, and the target is the point right after the window.
const windowSize = 10;
const stepSize = 0.1;
const data = Array.from({ length: 100 }, (_, i) => ({
  input: Array.from({ length: windowSize }, (_, j) => Math.sin((i + j) * stepSize)),
  output: [Math.sin((i + windowSize) * stepSize)]
}));

timeSeries.train(data, {
  iterations: 5000,
  errorThresh: 0.005
});

// Predict next values
const testInput = data[0].input;
const prediction = timeSeries.run(testInput);
// NOTE(review): for 1-D input the network may return a bare number rather
// than a one-element array; accept both.
const predicted = Array.isArray(prediction) ? prediction[0] : prediction;
console.log(`Predicted: ${predicted}, Actual: ${data[0].output[0]}`);