Brain.js is a comprehensive JavaScript neural network library that provides multiple types of neural networks, including feed-forward networks, recurrent neural networks (RNN), Long Short-Term Memory networks (LSTM), and Gated Recurrent Units (GRU). It enables developers to create, train, and run neural networks in both Node.js and browser environments, with built-in support for time series prediction, pattern recognition, and classification tasks.
npm install brain.js

The package also ships prebuilt browser bundles (browser.js, browser.min.js).

ES6 modules:
import brain from 'brain.js';
// Or import specific components
import { NeuralNetwork, CrossValidate, likely, utilities } from 'brain.js';

CommonJS:
const brain = require('brain.js');
// Or destructure specific components
const { NeuralNetwork, CrossValidate, likely, utilities } = require('brain.js');

Browser (global):
<script src="node_modules/brain.js/browser.js"></script>
<script>
// brain is available as global variable
const net = new brain.NeuralNetwork();
</script>

const brain = require('brain.js');
// Create a simple feed-forward neural network
const net = new brain.NeuralNetwork();
// Train the network with XOR data
net.train([
{ input: [0, 0], output: [0] },
{ input: [0, 1], output: [1] },
{ input: [1, 0], output: [1] },
{ input: [1, 1], output: [0] }
]);
// Run the network
const output = net.run([1, 0]); // approximately [0.987]
// Train a recurrent network for sequences
const rnn = new brain.recurrent.LSTM();
rnn.train(['hello world', 'goodbye world']);
const result = rnn.run('hello'); // continues the pattern

Brain.js is organized around several key components:
Feed-forward neural networks with backpropagation for classification and regression tasks. Supports multiple activation functions and customizable architectures.
class NeuralNetwork {
constructor(options?: {
/** Binary classification threshold (default: 0.5) */
binaryThresh?: number;
/** Array of hidden layer sizes (default: [3]) */
hiddenLayers?: number[];
/** Activation function (default: 'sigmoid') */
activation?: 'sigmoid' | 'relu' | 'leaky-relu' | 'tanh';
/** Alpha parameter for leaky-relu activation (default: 0.01) */
leakyReluAlpha?: number;
});
train(data, options?: TrainingOptions): TrainingStats;
trainAsync(data, options?: TrainingOptions): Promise<TrainingStats>;
run(input: number[]): number[];
test(data): TestResults;
toJSON(): NetworkJSON;
fromJSON(json: NetworkJSON): NeuralNetwork;
toFunction(): Function;
}
interface TrainingOptions {
iterations?: number;
errorThresh?: number;
log?: boolean | Function;
logPeriod?: number;
learningRate?: number;
momentum?: number;
callback?: Function;
callbackPeriod?: number;
timeout?: number;
}
interface TrainingStats {
iterations: number;
error: number;
}

Recurrent neural networks for sequence processing, time series prediction, and natural language tasks. Includes RNN, LSTM, and GRU variants.
namespace recurrent {
class RNN {
constructor(options?: RNNOptions);
train(data: string[] | SequenceData[], options?: TrainingOptions): TrainingStats;
run(input: string | number[]): string | number[];
}
class LSTM extends RNN {}
class GRU extends RNN {}
class RNNTimeStep extends RNN {
run(input: number[] | number[][]): number[] | number[][];
train(data: TimeStepData[], options?: TrainingOptions): TrainingStats;
}
class LSTMTimeStep extends RNNTimeStep {}
class GRUTimeStep extends RNNTimeStep {}
}
interface RNNOptions {
inputSize?: number;
inputRange?: number;
hiddenLayers?: number[];
outputSize?: number;
learningRate?: number;
decayRate?: number;
}

Model evaluation and hyperparameter tuning through k-fold cross-validation.
class CrossValidate {
constructor(NetworkClass, options?);
train(data, trainingOptions, k?: number): ValidationStats;
testPartition(trainOpts, trainSet, testSet): PartitionResults;
toJSON(): CrossValidateJSON;
fromJSON(json): NeuralNetwork;
toNeuralNetwork(): NeuralNetwork;
}
interface ValidationStats {
truePos: number;
trueNeg: number;
falsePos: number;
falseNeg: number;
total: number;
}

Data processing utilities, lookup tables, mathematical functions, and streaming capabilities.
// Utility functions
function likely(input, network: NeuralNetwork): string;
class lookup {
static buildLookup(hashes: object[]): object;
static lookupFromHash(hash: object): object;
static toArray(lookup: object, hash: object): number[];
static toHash(lookup: object, array: number[]): object;
static lookupFromArray(array): object;
}
class TrainStream {
constructor(options: { neuralNetwork: NeuralNetwork });
write(data): void;
endInputs(): void;
}
// Mathematical utilities
const utilities: {
max: Function;
mse: Function;
ones: Function;
random: Function;
randomWeight: Function;
randos: Function;
range: Function;
toArray: Function;
DataFormatter: Class;
zeros: Function;
};

class NeuralNetworkGPU extends NeuralNetwork {
// Same interface as NeuralNetwork but with GPU acceleration
}

The NeuralNetworkGPU class provides the same API as NeuralNetwork but leverages GPU acceleration through gpu.js for faster training on supported hardware.