or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

docs

cross-validation.md, index.md, neural-networks.md, recurrent-networks.md, utilities.md
tile.json

docs/neural-networks.md

Neural Networks

Feed-forward neural networks with backpropagation for classification and regression tasks. Brain.js provides both CPU and GPU-accelerated implementations with support for multiple activation functions and customizable architectures.

Capabilities

NeuralNetwork Class

Core feed-forward neural network implementation with backpropagation training.

/**
 * Feed-forward neural network trained with backpropagation.
 */
class NeuralNetwork {
  /**
   * Builds a network; every configuration option may be omitted.
   * @param options - Network configuration options
   */
  constructor(options?: {
    /** Sizes of the hidden layers, one entry per layer (default: [3]) */
    hiddenLayers?: number[];
    /** Activation function (default: 'sigmoid') */
    activation?: 'sigmoid' | 'relu' | 'leaky-relu' | 'tanh';
    /** Negative-side slope, only used by the 'leaky-relu' activation (default: 0.01) */
    leakyReluAlpha?: number;
    /** Threshold applied for binary classification (default: 0.5) */
    binaryThresh?: number;
  });
}

Usage Examples:

const brain = require('brain.js');

// A network with every option left at its default
const net = new brain.NeuralNetwork();

// A network with explicit options
const customNet = new brain.NeuralNetwork({
  activation: 'relu',    // ReLU instead of the default sigmoid
  hiddenLayers: [4, 3],  // two hidden layers: 4 neurons, then 3
  binaryThresh: 0.6,     // raise the threshold above the 0.5 default
});

Training Methods

Synchronous Training

/**
 * Trains the network synchronously, returning only once training has stopped.
 * @param data - Training examples: an array of input/output pairs, or an object
 * @param options - Overrides for the training defaults
 * @returns Statistics for the finished run (iterations performed and final error)
 */
train(
  data: TrainingData[] | object,
  options?: TrainingOptions
): TrainingStats;

/** One labelled example: an input vector plus the output expected for it. */
interface TrainingData {
  /** Values fed to the input layer */
  input: number[];
  /** Output the network should produce for `input` */
  output: number[];
}

/** Options accepted by train() and trainAsync(); every field is optional. */
interface TrainingOptions {
  /** Maximum training iterations (default: 20000) */
  iterations?: number;
  /** Stop early once the training error falls below this value (default: 0.005) */
  errorThresh?: number;
  /** `true` to log progress, or a function that receives the stats (default: false) */
  log?: boolean | ((stats: TrainingStats) => void);
  /** Iterations between log outputs (default: 10) */
  logPeriod?: number;
  /** Learning rate (default: 0.3) */
  learningRate?: number;
  /** Momentum factor (default: 0.1) */
  momentum?: number;
  /** Periodic callback function during training (default: null) */
  callback?: (stats: TrainingStats) => void;
  /** Iterations between callback calls (default: 10) */
  callbackPeriod?: number;
  /** Maximum training time in milliseconds (default: Infinity) */
  timeout?: number;
  /**
   * Training algorithm/optimizer (default: null).
   * 'adam' enables the Adam optimizer, which is tuned by beta1, beta2 and
   * epsilon below. null keeps the standard learningRate/momentum update.
   */
  praxis?: 'adam' | null;
  /** Adam optimizer beta1 parameter, used when praxis is 'adam' (default: 0.9) */
  beta1?: number;
  /** Adam optimizer beta2 parameter, used when praxis is 'adam' (default: 0.999) */
  beta2?: number;
  /** Adam optimizer epsilon parameter, used when praxis is 'adam' (default: 1e-8) */
  epsilon?: number;
}

/**
 * Training statistics. train() and trainAsync() return them, and the
 * log/callback hooks receive them.
 */
interface TrainingStats {
  /** Training error at the moment these stats were taken */
  error: number;
  /** How many training iterations have been completed */
  iterations: number;
}

Usage Examples:

// Basic training on the XOR truth table
const trainingData = [
  { input: [0, 0], output: [0] },
  { input: [0, 1], output: [1] },
  { input: [1, 0], output: [1] },
  { input: [1, 1], output: [0] },
];

const stats = net.train(trainingData);
console.log(`Trained in ${stats.iterations} iterations with error ${stats.error}`);

// Advanced training: override several defaults
const advancedStats = net.train(trainingData, {
  // when to stop
  iterations: 50000,
  errorThresh: 0.001,
  timeout: 30000, // 30 seconds at most
  // progress reporting
  log: true,
  logPeriod: 100,
  // weight-update tuning
  learningRate: 0.6,
  momentum: 0.5,
});

Asynchronous Training

/**
 * Non-blocking counterpart of train(): takes the same data and options, but
 * hands back a promise instead of the stats themselves.
 * @param data - Training examples: an array of input/output pairs, or an object
 * @param options - Overrides for the training defaults
 * @returns Promise that resolves with the training statistics
 */
trainAsync(
  data: TrainingData[] | object,
  options?: TrainingOptions
): Promise<TrainingStats>;

Usage Examples:

// Async training with await. These examples load brain.js with require()
// (CommonJS), and CommonJS modules do not allow top-level `await`, so the
// awaiting code lives inside an async function.
async function trainWithAwait() {
  const stats = await net.trainAsync(trainingData, {
    iterations: 10000,
    // `status` is deliberately not named `stats`, so it doesn't shadow the result above
    log: (status) => console.log(`Iteration ${status.iterations}, Error: ${status.error}`)
  });
  return stats;
}

// Async training with a promise chain (an alternative to the function above)
net.trainAsync(trainingData)
  .then(stats => console.log('Training complete:', stats))
  .catch(error => console.error('Training failed:', error));

Inference Methods

Run Network

/**
 * Feeds one input vector through the trained network.
 * @param input - Values for the input layer
 * @returns The network's output values (its prediction)
 */
run(
  input: number[]
): number[];

Usage Examples:

// One input vector in, one prediction out
const output = net.run([1, 0]);
console.log(output); // e.g. [0.987...]

// Several inputs: run the network on each in turn
const inputs = [[0, 0], [0, 1], [1, 0], [1, 1]];
const outputs = inputs.map((sample) => net.run(sample));

Testing and Evaluation

Test Method

/**
 * Checks the network's predictions against labelled data.
 * @param data - Input/output pairs the predictions are compared with
 * @returns The error and the cases the network got wrong
 */
test(
  data: TrainingData[]
): TestResults;

/** Result of test(): overall error plus the individual cases that failed. */
interface TestResults {
  /** Mean squared error over the test data */
  error: number;
  /**
   * The test cases the network misclassified. This is a list of cases, not
   * a rate; use `.length` for the count.
   */
  misclasses: TestCase[];
  // NOTE(review): brain.js's test() may also report `total`, plus
  // precision/recall/accuracy counts for single-output data. Confirm against
  // the installed version before documenting them.
}

/** One test example together with what the network produced for it. */
interface TestCase {
  /** Input that was fed to the network */
  input: number[];
  /** Expected output, taken from the test data */
  output: number[];
  /** Output the network actually produced */
  // NOTE(review): brain.js's own test() may report `actual` (and an `expected`
  // field) as class indices rather than arrays. Confirm for the version in use.
  actual: number[];
}

Usage Examples:

// Held-out examples to evaluate the network on
const testData = [
  { input: [0, 0], output: [0] },
  { input: [1, 1], output: [0] },
];

const results = net.test(testData);
const misclassCount = results.misclasses.length;
console.log(`Test error: ${results.error}`);
console.log(`Misclassifications: ${misclassCount}`);

Serialization Methods

JSON Serialization

/**
 * Exports the network as a plain object that can be passed to JSON.stringify.
 * @returns The network's JSON representation
 */
toJSON(): NetworkJSON;

/**
 * Restores the network from an object previously produced by toJSON().
 * @param json - The exported network
 * @returns This same network instance, so the call can be chained
 */
fromJSON(
  json: NetworkJSON
): NeuralNetwork;

/** Plain-object snapshot of a network, as returned by toJSON(). */
interface NetworkJSON {
  /** Layer sizes (presumably input, hidden and output, in order; confirm) */
  sizes: number[];
  /** Per-layer data; its exact shape is not specified here */
  layers: object[];
  /** Output/input lookup data (untyped here; NOTE(review): presumably used for object-keyed training data, so confirm) */
  outputLookup: any;
  inputLookup: any;
  /** Name of the activation function */
  activation: string;
  /** Training options the network was trained with */
  trainOpts: TrainingOptions;
  /** Leaky-ReLU alpha; may be absent */
  leakyReluAlpha?: number;
}

Usage Examples:

// Save: toJSON() gives a plain object; JSON.stringify turns it into text
const networkData = net.toJSON();
const jsonString = JSON.stringify(networkData);

// Load the text back into an existing instance
const loadedNet = new brain.NeuralNetwork();
const parsed = JSON.parse(jsonString);
loadedNet.fromJSON(parsed);

// ...or construct and load in one expression, since fromJSON returns the network
const newNet = new brain.NeuralNetwork().fromJSON(networkData);

Function Generation

/**
 * Convert the trained network to a standalone JavaScript function.
 * @returns A function that runs without brain.js. It is called like run():
 *   pass an input array and get the output array back.
 */
toFunction(): (input: number[]) => number[];

Usage Examples:

const runNet = net.toFunction(); // compile the network into a plain function

// Call it just like net.run(); brain.js is not needed at this point
const output = runNet([1, 0]);
console.log(output);

// Keep the generated source text so it can be reused later
const functionString = runNet.toString();

Activation Functions

Set Activation Function

/**
 * Switches the activation function the network uses.
 * @param activation - One of 'sigmoid', 'relu', 'leaky-relu' or 'tanh'
 */
setActivation(
  activation: 'sigmoid' | 'relu' | 'leaky-relu' | 'tanh'
): void;

Activation Function Details:

  • sigmoid: S-shaped curve, outputs between 0 and 1
  • relu: Rectified Linear Unit, outputs max(0, x)
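  • leaky-relu: Leaky ReLU; like relu for positive inputs, but negative inputs are multiplied by a small slope (leakyReluAlpha, default 0.01) instead of being clamped to 0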
  • tanh: Hyperbolic tangent, outputs between -1 and 1

Usage Examples:

// Switch to ReLU, then to tanh
for (const activation of ['relu', 'tanh']) {
  net.setActivation(activation);
}

NeuralNetworkGPU Class

GPU-accelerated version of NeuralNetwork using gpu.js for faster training.

/**
 * NeuralNetwork variant that runs its computations through gpu.js.
 * Its API is identical to NeuralNetwork's. It needs WebGL support or another
 * compatible GPU environment.
 */
class NeuralNetworkGPU extends NeuralNetwork {
  // No members of its own: everything is inherited from NeuralNetwork,
  // with GPU acceleration applied automatically where available.
}

Usage Examples:

// NeuralNetworkGPU takes the same options as NeuralNetwork...
const gpuNet = new brain.NeuralNetworkGPU({
  activation: 'relu',
  hiddenLayers: [10, 10],
});

// ...and is trained and run in exactly the same way
gpuNet.train(trainingData);
const output = gpuNet.run([1, 0]);

Configuration Defaults

Network Defaults

{
  binaryThresh: 0.5,
  hiddenLayers: [3],
  activation: 'sigmoid',
  leakyReluAlpha: 0.01
}

Training Defaults

{
  iterations: 20000,
  errorThresh: 0.005,
  log: false,
  logPeriod: 10,
  learningRate: 0.3,
  momentum: 0.1,
  callback: null,
  callbackPeriod: 10,
  timeout: Infinity
}