Model evaluation and hyperparameter tuning through k-fold cross-validation. The CrossValidate class assesses neural network performance by training and testing one network configuration across multiple data partitions (folds). Running it over several candidate configurations, as in the examples below, identifies the best-performing parameters.
Implements k-fold cross-validation for neural network evaluation and hyperparameter optimization.
/**
 * Cross-validation utility for neural network evaluation.
 *
 * Declared with `declare` because this is a type-only API description: a
 * non-ambient class whose constructor has no body is a compile error (TS2390).
 */
declare class CrossValidate {
  /**
   * NOTE(review): `NetworkClass` is typed `typeof NeuralNetwork`, but the usage
   * examples also pass `brain.recurrent.LSTM` and `brain.recurrent.RNN`; widen
   * this type if recurrent network classes are meant to be accepted.
   *
   * @param NetworkClass - Neural network constructor (NeuralNetwork, LSTM, etc.)
   * @param options - Network configuration options
   */
  constructor(NetworkClass: typeof NeuralNetwork, options?: NetworkOptions);
}
/** Network configuration options accepted by the CrossValidate constructor. */
interface NetworkOptions {
  // Options depend on the network type being validated
  /** Sizes of the hidden layers, e.g. [4, 3] */
  hiddenLayers?: number[];
  /** Activation function name, e.g. 'sigmoid', 'relu', 'tanh' */
  activation?: string;
  learningRate?: number;
  // ... other network-specific options
}

// Usage Examples:
const brain = require('brain.js');
// Cross-validate a basic neural network
const crossValidator = new brain.CrossValidate(brain.NeuralNetwork, {
  hiddenLayers: [4, 3],
  activation: 'relu'
});
// Cross-validate an LSTM
const lstmValidator = new brain.CrossValidate(brain.recurrent.LSTM, {
  hiddenLayers: [10, 10],
  learningRate: 0.01
});

/**
 * Perform k-fold cross-validation training (CrossValidate instance method).
 * @param data - Training data array
 * @param trainingOptions - Training configuration options
 * @param k - Number of folds for cross-validation (default: 4)
 * @returns Validation statistics aggregated across all folds
 */
train(data: TrainingData[], trainingOptions: TrainingOptions, k?: number): ValidationStats;
/** A single labelled example used for training or testing. */
interface TrainingData {
  /** Input vector fed to the network */
  input: number[];
  /** Expected output vector */
  output: number[];
}

/** Training configuration options passed to `train` and `testPartition`. */
interface TrainingOptions {
  iterations?: number;
  errorThresh?: number;
  learningRate?: number;
  momentum?: number;
  log?: boolean;
  // ... other training options
}

/** Confusion-matrix counts and derived metrics returned by `train`. */
interface ValidationStats {
  /** True positive predictions */
  truePos: number;
  /** True negative predictions */
  trueNeg: number;
  /** False positive predictions */
  falsePos: number;
  /** False negative predictions */
  falseNeg: number;
  /** Total number of test cases */
  total: number;
  /** Average accuracy across all folds */
  accuracy: number;
  /** Precision metric */
  precision: number;
  /** Recall metric */
  recall: number;
  /** F1 score */
  f1Score: number;
}

// Usage Examples:
// Prepare training data
const data = [
  { input: [0, 0], output: [0] },
  { input: [0, 1], output: [1] },
  { input: [1, 0], output: [1] },
  { input: [1, 1], output: [0] },
  // ... more training examples
];
// Basic cross-validation
const validator = new brain.CrossValidate(brain.NeuralNetwork);
const stats = validator.train(data, {
  iterations: 5000,
  errorThresh: 0.01
});
console.log(`Accuracy: ${stats.accuracy}`);
console.log(`Precision: ${stats.precision}`);
console.log(`Recall: ${stats.recall}`);
console.log(`F1 Score: ${stats.f1Score}`);
// 10-fold cross-validation with custom options
// (10 folds need at least 10 samples so that every fold has a test case)
const detailedStats = validator.train(data, {
  iterations: 10000,
  learningRate: 0.3,
  momentum: 0.1,
  log: true
}, 10);
console.log(`Total samples: ${detailedStats.total}`);
console.log(`True positives: ${detailedStats.truePos}`);
console.log(`False positives: ${detailedStats.falsePos}`);

/**
 * Test a single partition of the cross-validation
 * @param trainOpts - Training options for this partition
 * @param trainSet - Training data subset
 * @param testSet - Test data subset
 * @returns Results for this specific partition
 */
testPartition(trainOpts: TrainingOptions, trainSet: TrainingData[], testSet: TrainingData[]): PartitionResults;
/** Per-partition results returned by `testPartition`. */
interface PartitionResults {
  /** Time taken for training in milliseconds */
  trainTime: number;
  /** Time taken for testing in milliseconds */
  testTime: number;
  /** Number of training iterations completed */
  iterations: number;
  /** Final training error */
  trainError: number;
  /** Learning rate used */
  learningRate: number;
  /** Hidden layer configuration */
  hidden: number[];
  /** Trained network instance */
  network: NeuralNetwork;
  /** Accuracy on this partition's test set */
  accuracy: number;
  /** Test cases the network presumably got wrong (per the field name) — TODO confirm */
  misclasses: TestCase[];
}

/** A test case paired with the network's actual output for it. */
interface TestCase {
  /** Input vector fed to the network */
  input: number[];
  /** Expected output (mirrors `TrainingData.output`) */
  output: number[];
  /** Output the network actually produced — presumably `network.run(input)`; confirm */
  actual: number[];
}

// Usage Examples:
// Manual partition testing
// Split by proportion rather than a fixed count so the 80/20 ratio holds for any data size
// (a literal `slice(0, 80)` takes 80 items, which is only 80% when data has exactly 100).
const splitIndex = Math.floor(data.length * 0.8);
const trainSet = data.slice(0, splitIndex); // first 80% for training
const testSet = data.slice(splitIndex); // remaining 20% for testing
const partitionResult = validator.testPartition(
  { iterations: 5000, learningRate: 0.3 },
  trainSet,
  testSet
);
console.log(`Training took ${partitionResult.trainTime}ms`);
console.log(`Testing took ${partitionResult.testTime}ms`);
console.log(`Training iterations: ${partitionResult.iterations}`);
console.log(`Final error: ${partitionResult.trainError}`);
console.log(`Test accuracy: ${partitionResult.accuracy}`);
// Access the trained network
const trainedNet = partitionResult.network;
const prediction = trainedNet.run([1, 0]);

/**
 * Serialize cross-validation results to JSON
 * @returns JSON representation of validation results
 */
toJSON(): CrossValidateJSON;
/**
 * Create neural network from cross-validation results
 * @param json - Cross-validation JSON data
 * @returns Best performing network from validation
 */
fromJSON(json: CrossValidateJSON): NeuralNetwork;
/** Serialized cross-validation results, as returned by `toJSON()` and accepted by `fromJSON()`. */
interface CrossValidateJSON {
  /** Average results across all folds */
  avgs: PartitionResults;
  /** Aggregated validation statistics */
  stats: ValidationStats;
  /** Results for each individual fold */
  sets: PartitionResults[];
}

// Usage Examples:
// Save cross-validation results
const validationData = validator.toJSON();
const jsonString = JSON.stringify(validationData);
// Examine detailed results
console.log('Average results:', validationData.avgs);
console.log('Overall stats:', validationData.stats);
console.log('Individual folds:', validationData.sets);
// Load and create network from best results
const bestNetwork = validator.fromJSON(validationData);
const output = bestNetwork.run([1, 0]);

/**
 * Convert cross-validation results to a trained neural network
 * @returns Best performing network from validation process
 */
toNeuralNetwork(): NeuralNetwork;

// Usage Examples:
// Get best network after cross-validation
const validator = new brain.CrossValidate(brain.NeuralNetwork, {
  hiddenLayers: [5, 3]
});
validator.train(data, { iterations: 5000 });
const bestNet = validator.toNeuralNetwork();
// Use the best network for predictions
const prediction = bestNet.run([0.5, 0.8]);
console.log('Best network prediction:', prediction);
// Save the best network
const networkJSON = bestNet.toJSON();

const brain = require('brain.js');
// Test different network configurations
const configurations = [
  { hiddenLayers: [3], activation: 'sigmoid' },
  { hiddenLayers: [5, 3], activation: 'sigmoid' },
  { hiddenLayers: [4, 4], activation: 'relu' },
  { hiddenLayers: [8, 4, 2], activation: 'tanh' }
];
const data = [
  // ... your training data
];
let bestConfig = null;
// Start below any possible accuracy so the first evaluated config is always recorded;
// starting at 0 with a strict `>` would leave bestConfig null if every config scores 0.
let bestAccuracy = -Infinity;
for (const config of configurations) {
  const validator = new brain.CrossValidate(brain.NeuralNetwork, config);
  const stats = validator.train(data, {
    iterations: 5000,
    errorThresh: 0.01
  }, 5); // 5-fold cross-validation
  console.log(`Config ${JSON.stringify(config)}: Accuracy ${stats.accuracy}`);
  if (stats.accuracy > bestAccuracy) {
    bestAccuracy = stats.accuracy;
    bestConfig = config;
  }
}
console.log(`Best configuration:`, bestConfig);
console.log(`Best accuracy: ${bestAccuracy}`);

const learningRates = [0.1, 0.3, 0.5, 0.7, 0.9];
const results = [];
for (const lr of learningRates) {
  const validator = new brain.CrossValidate(brain.NeuralNetwork, {
    hiddenLayers: [4, 3]
  });
  const stats = validator.train(data, {
    iterations: 3000,
    learningRate: lr
  });
  results.push({
    learningRate: lr,
    accuracy: stats.accuracy,
    f1Score: stats.f1Score
  });
}
// Find optimal learning rate
// Strict `>` keeps the earlier entry on ties, i.e. the smaller learning rate here.
const bestLR = results.reduce((best, current) =>
  current.accuracy > best.accuracy ? current : best
);
console.log(`Optimal learning rate: ${bestLR.learningRate}`);
console.log(`Accuracy: ${bestLR.accuracy}`);

// Compare different network types
const networkTypes = [
  { type: brain.NeuralNetwork, name: 'FeedForward' },
  { type: brain.recurrent.RNN, name: 'RNN' },
  { type: brain.recurrent.LSTM, name: 'LSTM' }
];
const comparisonResults = [];
for (const { type: NetworkType, name } of networkTypes) {
  const validator = new brain.CrossValidate(NetworkType, {
    hiddenLayers: [10, 10]
  });
  const stats = validator.train(data, {
    iterations: 5000
  });
  comparisonResults.push({
    networkType: name,
    accuracy: stats.accuracy,
    precision: stats.precision,
    recall: stats.recall,
    f1Score: stats.f1Score
  });
}
// Display comparison
console.table(comparisonResults);
// Find best performing network type
// Strict `>` keeps the earlier network type in `networkTypes` on ties.
const bestNetwork = comparisonResults.reduce((best, current) =>
  current.f1Score > best.f1Score ? current : best
);
console.log(`Best network type: ${bestNetwork.networkType}`);

const validator = new brain.CrossValidate(brain.NeuralNetwork);
// Perform multiple validation runs to estimate run-to-run variability
const runs = 10;
const accuracies = [];
for (let i = 0; i < runs; i++) {
  const stats = validator.train(data, {
    iterations: 5000
  });
  accuracies.push(stats.accuracy);
}
// Calculate statistics
const mean = accuracies.reduce((sum, acc) => sum + acc, 0) / runs;
// Sample variance (divide by runs - 1): the runs are a sample of possible training outcomes.
const variance = accuracies.reduce((sum, acc) => sum + Math.pow(acc - mean, 2), 0) / (runs - 1);
const stdDev = Math.sqrt(variance);
// A confidence interval for the *mean* uses the standard error, not the per-run spread;
// mean ± 1.96 * stdDev would describe where individual runs fall, not the mean.
const stdErr = stdDev / Math.sqrt(runs);
// Two-sided 95% Student's t critical value for runs - 1 = 9 degrees of freedom.
// Update it if `runs` changes (it approaches 1.96 as runs grows).
const tCritical = 2.262;
console.log(`Mean accuracy: ${mean.toFixed(4)}`);
console.log(`Standard deviation: ${stdDev.toFixed(4)}`);
console.log(`95% confidence interval for mean accuracy: ${(mean - tCritical * stdErr).toFixed(4)} - ${(mean + tCritical * stdErr).toFixed(4)}`);