or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

docs

cross-validation.md index.md neural-networks.md recurrent-networks.md utilities.md
tile.json

docs/cross-validation.md

Cross-Validation

Model evaluation and hyperparameter tuning through k-fold cross-validation. The CrossValidate class helps assess neural network performance and find optimal configurations by systematically testing different parameter combinations across multiple data partitions.

Capabilities

CrossValidate Class

Implements k-fold cross-validation for neural network evaluation and hyperparameter optimization.

/**
 * Cross-validation utility for neural network evaluation and
 * hyperparameter tuning through k-fold cross-validation.
 */
class CrossValidate {
  /**
   * @param NetworkClass - Neural network constructor (NeuralNetwork, LSTM, etc.)
   * @param options - Optional network configuration options; which options
   *   apply depends on the network type being validated
   */
  constructor(NetworkClass: typeof NeuralNetwork, options?: NetworkOptions);
}

/**
 * Network configuration accepted by the CrossValidate constructor.
 * Which options apply depends on the network type being validated;
 * every listed field is optional.
 */
interface NetworkOptions {
  hiddenLayers?: Array<number>;
  activation?: string;
  learningRate?: number;
  // ... other network-specific options
}

Usage Examples:

const brain = require('brain.js');

// Network options used for the basic (feed-forward) neural network
const feedForwardOptions = {
  hiddenLayers: [4, 3],
  activation: 'relu'
};

// Network options used for the LSTM
const lstmOptions = {
  hiddenLayers: [10, 10],
  learningRate: 0.01
};

// Cross-validator wrapping a basic neural network
const crossValidator = new brain.CrossValidate(brain.NeuralNetwork, feedForwardOptions);

// Cross-validator wrapping an LSTM
const lstmValidator = new brain.CrossValidate(brain.recurrent.LSTM, lstmOptions);

Training with Cross-Validation

Cross-Validation Training

/**
 * Perform k-fold cross-validation training.
 *
 * Invoked on a CrossValidate instance, e.g. `validator.train(data, options, 10)`.
 * @param data - Training data array; each element carries numeric `input` and `output` arrays
 * @param trainingOptions - Training configuration options (iterations, errorThresh, learningRate, momentum, log, ...)
 * @param k - Number of folds for cross-validation (optional; default: 4)
 * @returns Validation statistics aggregated across all folds
 */
train(data: TrainingData[], trainingOptions: TrainingOptions, k?: number): ValidationStats;

/** One training sample: a numeric `input` array paired with a numeric `output` array. */
interface TrainingData {
  input: Array<number>;
  output: Array<number>;
}

/**
 * Training configuration accepted by `train` and `testPartition`.
 * Every listed field is optional.
 */
interface TrainingOptions {
  iterations?: number;
  errorThresh?: number;
  learningRate?: number;
  momentum?: number;
  log?: boolean;
  // ... other training options
}

/**
 * Validation statistics returned by `train`, aggregated across all folds.
 */
interface ValidationStats {
  /** True positive predictions */
  truePos: number;
  /** True negative predictions */
  trueNeg: number;
  /** False positive predictions */
  falsePos: number;
  /** False negative predictions */
  falseNeg: number;
  /** Total number of test cases */
  total: number;
  /** Average accuracy across all folds */
  accuracy: number;
  /** Precision metric (conventionally truePos / (truePos + falsePos); TODO confirm against the implementation) */
  precision: number;
  /** Recall metric (conventionally truePos / (truePos + falseNeg); TODO confirm against the implementation) */
  recall: number;
  /** F1 score (conventionally the harmonic mean of precision and recall; TODO confirm against the implementation) */
  f1Score: number;
}

Usage Examples:

// Training samples (XOR-style pairs; extend with more examples as needed)
const data = [
  { input: [0, 0], output: [0] },
  { input: [0, 1], output: [1] },
  { input: [1, 0], output: [1] },
  { input: [1, 1], output: [0] },
  // ... more training examples
];

// Basic cross-validation, relying on the default number of folds
const validator = new brain.CrossValidate(brain.NeuralNetwork);
const basicTrainingOptions = {
  iterations: 5000,
  errorThresh: 0.01
};
const stats = validator.train(data, basicTrainingOptions);

// Report the headline metrics
console.log(`Accuracy: ${stats.accuracy}`);
console.log(`Precision: ${stats.precision}`);
console.log(`Recall: ${stats.recall}`);
console.log(`F1 Score: ${stats.f1Score}`);

// Cross-validation with custom training options and an explicit fold count
const foldCount = 10;
const detailedTrainingOptions = {
  iterations: 10000,
  learningRate: 0.3,
  momentum: 0.1,
  log: true
};
const detailedStats = validator.train(data, detailedTrainingOptions, foldCount);

// Report the raw counts
console.log(`Total samples: ${detailedStats.total}`);
console.log(`True positives: ${detailedStats.truePos}`);
console.log(`False positives: ${detailedStats.falsePos}`);

Partition Testing

Test Single Partition

/**
 * Test a single partition of the cross-validation.
 *
 * Invoked on a CrossValidate instance, e.g. `validator.testPartition(opts, trainSet, testSet)`.
 * Note that the training options come first, unlike `train`, which takes the data first.
 * @param trainOpts - Training options for this partition
 * @param trainSet - Training data subset
 * @param testSet - Test data subset
 * @returns Results for this specific partition (timings, iterations, error, trained network, accuracy)
 */
testPartition(trainOpts: TrainingOptions, trainSet: TrainingData[], testSet: TrainingData[]): PartitionResults;

/**
 * Outcome of training and testing one partition, as returned by `testPartition`.
 */
interface PartitionResults {
  /** Training duration, in milliseconds */
  trainTime: number;
  /** Testing duration, in milliseconds */
  testTime: number;
  /** How many training iterations were completed */
  iterations: number;
  /** Training error at the end of training */
  trainError: number;
  /** Learning rate that was used */
  learningRate: number;
  /** Hidden layer configuration */
  hidden: Array<number>;
  /** The trained network instance */
  network: NeuralNetwork;
  /** Accuracy measured on the test set */
  accuracy: number;
  /** Test cases recorded for this partition (presumably the misclassified ones; TODO confirm) */
  misclasses: Array<TestCase>;
}

/**
 * A single recorded test case with three numeric arrays.
 * NOTE(review): presumably `output` is the expected value and `actual` is
 * what the network produced for `input`; confirm against the implementation.
 */
interface TestCase {
  input: Array<number>;
  output: Array<number>;
  actual: Array<number>;
}

Usage Examples:

// Manual partition testing
// Derive the split point from the data size so the 80/20 ratio holds for any
// number of samples; a fixed index of 80 is only an 80% split when there are
// exactly 100 samples, and leaves the test set empty for smaller data sets.
const splitIndex = Math.floor(data.length * 0.8);
const trainSet = data.slice(0, splitIndex);  // 80% for training
const testSet = data.slice(splitIndex);      // remaining 20% for testing

const partitionResult = validator.testPartition(
  { iterations: 5000, learningRate: 0.3 },
  trainSet,
  testSet
);

// Report timings and training outcome for this partition
console.log(`Training took ${partitionResult.trainTime}ms`);
console.log(`Testing took ${partitionResult.testTime}ms`);
console.log(`Training iterations: ${partitionResult.iterations}`);
console.log(`Final error: ${partitionResult.trainError}`);
console.log(`Test accuracy: ${partitionResult.accuracy}`);

// Access the trained network
const trainedNet = partitionResult.network;
const prediction = trainedNet.run([1, 0]);

Serialization

JSON Serialization

/**
 * Serialize cross-validation results to JSON.
 *
 * Invoked on a CrossValidate instance; the result can be passed to `JSON.stringify`.
 * @returns JSON representation of validation results: per-fold results (`sets`),
 *   their averages (`avgs`), and the aggregated statistics (`stats`)
 */
toJSON(): CrossValidateJSON;

/**
 * Create neural network from cross-validation results.
 *
 * NOTE(review): the criterion used to pick the "best" network is not shown
 * here; confirm against the implementation.
 * @param json - Cross-validation JSON data, in the shape produced by `toJSON`
 * @returns Best performing network from validation
 */
fromJSON(json: CrossValidateJSON): NeuralNetwork;

/**
 * Serialized cross-validation results, as produced by `toJSON` and
 * consumed by `fromJSON`.
 */
interface CrossValidateJSON {
  /** Results averaged over every fold */
  avgs: PartitionResults;
  /** Validation statistics aggregated over the whole run */
  stats: ValidationStats;
  /** One result entry per individual fold */
  sets: Array<PartitionResults>;
}

Usage Examples:

// Capture the cross-validation results and serialize them to a string
const validationData = validator.toJSON();
const jsonString = JSON.stringify(validationData);

// Print each section of the results under its label
const resultSections = [
  ['Average results:', 'avgs'],
  ['Overall stats:', 'stats'],
  ['Individual folds:', 'sets']
];
for (const [label, key] of resultSections) {
  console.log(label, validationData[key]);
}

// Rebuild the best network from the results and run it on a sample input
const bestNetwork = validator.fromJSON(validationData);
const sampleInput = [1, 0];
const output = bestNetwork.run(sampleInput);

Network Conversion

Convert to Neural Network

/**
 * Convert cross-validation results to a trained neural network.
 *
 * Invoked on a CrossValidate instance after `train` has been run
 * (see the usage example: `validator.train(...)` then `validator.toNeuralNetwork()`).
 * @returns Best performing network from validation process
 */
toNeuralNetwork(): NeuralNetwork;

Usage Examples:

// Build a validator for a feed-forward network with two hidden layers
const networkOptions = { hiddenLayers: [5, 3] };
const validator = new brain.CrossValidate(brain.NeuralNetwork, networkOptions);

// Run cross-validation, then extract the best network it produced
const trainingOptions = { iterations: 5000 };
validator.train(data, trainingOptions);
const bestNet = validator.toNeuralNetwork();

// Run the best network on a sample input and show the result
const sampleInput = [0.5, 0.8];
const prediction = bestNet.run(sampleInput);
console.log('Best network prediction:', prediction);

// Serialize the best network
const networkJSON = bestNet.toJSON();

Advanced Usage Examples

Hyperparameter Optimization

const brain = require('brain.js');

// Candidate network configurations to compare
const configurations = [
  { hiddenLayers: [3], activation: 'sigmoid' },
  { hiddenLayers: [5, 3], activation: 'sigmoid' },
  { hiddenLayers: [4, 4], activation: 'relu' },
  { hiddenLayers: [8, 4, 2], activation: 'tanh' }
];

const data = [
  // ... your training data
];

let bestConfig = null;
// Start below any possible score so the first evaluated configuration is
// always recorded; starting at 0 with a strict `>` would leave bestConfig
// null whenever every configuration scores an accuracy of 0.
let bestAccuracy = -Infinity;

for (const config of configurations) {
  // A fresh validator per configuration, since the network options are fixed at construction
  const validator = new brain.CrossValidate(brain.NeuralNetwork, config);
  
  const stats = validator.train(data, {
    iterations: 5000,
    errorThresh: 0.01
  }, 5); // 5-fold cross-validation
  
  console.log(`Config ${JSON.stringify(config)}: Accuracy ${stats.accuracy}`);
  
  // Strict comparison: on ties the earlier configuration is kept
  if (stats.accuracy > bestAccuracy) {
    bestAccuracy = stats.accuracy;
    bestConfig = config;
  }
}

console.log(`Best configuration:`, bestConfig);
console.log(`Best accuracy: ${bestAccuracy}`);

Learning Rate Optimization

const learningRates = [0.1, 0.3, 0.5, 0.7, 0.9];

// Cross-validate once per candidate rate and keep the headline metrics
const results = learningRates.map((lr) => {
  const validator = new brain.CrossValidate(brain.NeuralNetwork, {
    hiddenLayers: [4, 3]
  });

  const { accuracy, f1Score } = validator.train(data, {
    iterations: 3000,
    learningRate: lr
  });

  return { learningRate: lr, accuracy, f1Score };
});

// Select the entry with the highest accuracy; the earlier entry wins on ties
const bestLR = results.reduce((best, current) => {
  if (current.accuracy > best.accuracy) {
    return current;
  }
  return best;
});

console.log(`Optimal learning rate: ${bestLR.learningRate}`);
console.log(`Accuracy: ${bestLR.accuracy}`);

Model Comparison

// Network constructors to compare, each with a display name
const networkTypes = [
  { type: brain.NeuralNetwork, name: 'FeedForward' },
  { type: brain.recurrent.RNN, name: 'RNN' },
  { type: brain.recurrent.LSTM, name: 'LSTM' }
];

// Cross-validate every network type with the same options and collect its metrics
const comparisonResults = networkTypes.map((candidate) => {
  const validator = new brain.CrossValidate(candidate.type, {
    hiddenLayers: [10, 10]
  });

  const stats = validator.train(data, {
    iterations: 5000
  });

  return {
    networkType: candidate.name,
    accuracy: stats.accuracy,
    precision: stats.precision,
    recall: stats.recall,
    f1Score: stats.f1Score
  };
});

// Show the results side by side
console.table(comparisonResults);

// Select the entry with the highest F1 score; the earlier entry wins on ties
const bestNetwork = comparisonResults.reduce((best, current) => {
  if (current.f1Score > best.f1Score) {
    return current;
  }
  return best;
});

console.log(`Best network type: ${bestNetwork.networkType}`);

Statistical Analysis

const validator = new brain.CrossValidate(brain.NeuralNetwork);

// Perform multiple validation runs for statistical significance
const runs = 10;
const accuracies = [];

for (let i = 0; i < runs; i++) {
  const stats = validator.train(data, {
    iterations: 5000
  });
  accuracies.push(stats.accuracy);
}

// Calculate statistics
const mean = accuracies.reduce((sum, acc) => sum + acc, 0) / runs;
// Sample variance (divide by n - 1): the runs are a sample of possible
// outcomes, so the population divisor n would underestimate the spread.
// Requires runs > 1.
const variance = accuracies.reduce((sum, acc) => sum + Math.pow(acc - mean, 2), 0) / (runs - 1);
const stdDev = Math.sqrt(variance);
// A confidence interval for the MEAN is built from the standard error
// (stdDev / sqrt(n)), not from the standard deviation itself.
const stdError = stdDev / Math.sqrt(runs);
// 1.96 is the normal-approximation critical value for 95%; with this few runs
// a Student-t critical value would give a slightly wider interval.
const marginOfError = 1.96 * stdError;

console.log(`Mean accuracy: ${mean.toFixed(4)}`);
console.log(`Standard deviation: ${stdDev.toFixed(4)}`);
console.log(`95% confidence interval: ${(mean - marginOfError).toFixed(4)} - ${(mean + marginOfError).toFixed(4)}`);