tessl/maven-com-github-haifengl--smile-core

Statistical Machine Intelligence and Learning Engine providing comprehensive machine learning algorithms for classification, regression, clustering, and feature engineering in Java

Validation and Metrics

Model validation framework covering cross-validation, bootstrap sampling, and leave-one-out validation, together with performance metrics for classification, regression, and clustering tasks. The same result types are used to evaluate a single model and to compare several models.

Capabilities

Cross-Validation Framework

Static utilities that generate k-fold and stratified k-fold splits and run cross-validated classification or regression.

/**
 * Cross-validation utilities for model evaluation
 */
class CrossValidation {
    /** Create k-fold cross-validation splits */
    public static Bag[] of(int n, int k);
    
    /** Create stratified k-fold cross-validation splits */
    public static Bag[] stratify(int n, int k, int[] y);
    
    /** Classification cross-validation */
    public static <T, M extends Classifier<T>> ClassificationValidations classification(
        int k, Classifier.Trainer<T, M> trainer, T[] x, int[] y);
    
    /** Classification cross-validation with Formula */
    public static ClassificationValidations classification(
        int k, BiFunction<Formula, DataFrame, Classifier<Tuple>> trainer, Formula formula, DataFrame data);
    
    /** Regression cross-validation */
    public static <T, M extends Regression<T>> RegressionValidations regression(
        int k, Regression.Trainer<T, M> trainer, T[] x, double[] y);
    
    /** Regression cross-validation with Formula */
    public static RegressionValidations regression(
        int k, BiFunction<Formula, DataFrame, Regression<Tuple>> trainer, Formula formula, DataFrame data);
}

/**
 * Data bag containing training and test indices
 */
class Bag {
    /** Training sample indices */
    public final int[] samples;
    
    /** Out-of-bag (test) sample indices */
    public final int[] oob;
    
    /** Get training data subset */
    public <T> T[] trainSet(T[] data);
    
    /** Get test data subset */
    public <T> T[] testSet(T[] data);
}

Usage Example:

import smile.validation.CrossValidation;
import smile.classification.RandomForest;
import smile.data.formula.Formula;

// 10-fold cross-validation; data is a smile.data.DataFrame with a "target" column
Formula formula = Formula.lhs("target");
var results = CrossValidation.classification(10, RandomForest::fit, formula, data);

System.out.println("Accuracy: " + results.avg.accuracy);
System.out.println("Error: " + results.avg.error);

// Stratified 5-fold splits; labels is the int[] of class labels, one per row of data
var stratifiedSplits = CrossValidation.stratify(data.size(), 5, labels);
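
To make the shape of a k-fold split concrete, here is a minimal sketch in plain Java. It is not Smile's implementation: the class name KFoldSketch, the seed parameter, and the nested-array return type are assumptions made for illustration. Each fold holds a training index array and a test index array, which corresponds to the samples and oob fields of Bag.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Hypothetical helper, not part of Smile: shuffled k-fold index splits. */
final class KFoldSketch {
    /** Returns k folds; folds[f][0] holds training indices, folds[f][1] holds test indices. */
    static int[][][] split(int n, int k, long seed) {
        if (k < 2 || k > n) throw new IllegalArgumentException("need 2 <= k <= n");
        List<Integer> order = new ArrayList<>();
        for (int i = 0; i < n; i++) order.add(i);
        Collections.shuffle(order, new Random(seed));

        int[][][] folds = new int[k][2][];
        int start = 0;
        for (int f = 0; f < k; f++) {
            // Spread the remainder so that fold sizes differ by at most one.
            int size = n / k + (f < n % k ? 1 : 0);
            int[] test = new int[size];
            int[] train = new int[n - size];
            int t = 0, r = 0;
            for (int i = 0; i < n; i++) {
                if (i >= start && i < start + size) test[t++] = order.get(i);
                else train[r++] = order.get(i);
            }
            folds[f][0] = train;
            folds[f][1] = test;
            start += size;
        }
        return folds;
    }
}
```

Every index appears in exactly one test fold, so each sample is scored once by a model that did not see it during training.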

Bootstrap Sampling

Bootstrap resampling methods for model validation and uncertainty estimation.

/**
 * Bootstrap sampling for model validation
 */
class Bootstrap {
    /** Create bootstrap samples */
    public static Bag[] of(int n, int size, int subsampleSize);
    
    /** Bootstrap with replacement */
    public static Bag[] of(int n, int size);
    
    /** Bootstrap validation for classification */
    public static <T, M extends Classifier<T>> ClassificationValidations classification(
        int round, Classifier.Trainer<T, M> trainer, T[] x, int[] y);
    
    /** Bootstrap validation for regression */
    public static <T, M extends Regression<T>> RegressionValidations regression(
        int round, Regression.Trainer<T, M> trainer, T[] x, double[] y);
}
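
The idea behind one bootstrap round can be sketched without the library: draw n indices with replacement for training and keep the indices that were never drawn as the out-of-bag test set. The class BootstrapSketch and its return layout are hypothetical and only illustrate the technique.

```java
import java.util.Random;

/** Hypothetical helper, not part of Smile: one bootstrap round with out-of-bag indices. */
final class BootstrapSketch {
    /** Returns {samples, oob}: n indices drawn with replacement, and the indices never drawn. */
    static int[][] sample(int n, Random rng) {
        int[] samples = new int[n];
        boolean[] drawn = new boolean[n];
        int distinct = 0;
        for (int i = 0; i < n; i++) {
            int idx = rng.nextInt(n);
            samples[i] = idx;
            if (!drawn[idx]) { drawn[idx] = true; distinct++; }
        }
        // Everything that was never drawn forms the out-of-bag set.
        int[] oob = new int[n - distinct];
        int j = 0;
        for (int i = 0; i < n; i++) {
            if (!drawn[i]) oob[j++] = i;
        }
        return new int[][] { samples, oob };
    }
}
```

For large n, roughly 36.8% of the samples end up out-of-bag in each round, because the chance that a given index is never drawn is (1 - 1/n)^n, which approaches 1/e.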

Leave-One-Out Cross-Validation

Cross-validation in which each sample serves as the test set exactly once, with the remaining n - 1 samples used for training.

/**
 * Leave-one-out cross-validation
 */
class LOOCV {
    /** Create LOOCV splits */
    public Bag[] split(int n);
    
    /** Classification LOOCV */
    public static <T> ClassificationValidation classification(
        Classifier.Trainer<T, ?> trainer, T[] x, int[] y);
    
    /** Regression LOOCV */
    public static <T> RegressionValidation regression(
        Regression.Trainer<T, ?> trainer, T[] x, double[] y);
}
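
As an illustration of the splits this produces, the sketch below builds them in plain Java. LoocvSketch is a hypothetical name and the nested-array layout is an assumption; it shows the technique, not the library's code.

```java
/** Hypothetical helper, not part of Smile: leave-one-out index splits. */
final class LoocvSketch {
    /** Returns n rounds; rounds[i][0] holds the n - 1 training indices, rounds[i][1] holds {i}. */
    static int[][][] split(int n) {
        int[][][] rounds = new int[n][2][];
        for (int i = 0; i < n; i++) {
            int[] train = new int[n - 1];
            int r = 0;
            for (int j = 0; j < n; j++) {
                if (j != i) train[r++] = j;
            }
            rounds[i][0] = train;
            rounds[i][1] = new int[] { i };
        }
        return rounds;
    }
}
```

Because this requires n model fits, it is usually practical only for small datasets or models that train quickly.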

Model Selection

Utilities for comparing and selecting optimal models and hyperparameters.

/**
 * Model selection utilities
 */
class ModelSelection {
    /** Grid search with cross-validation */
    public static <T, M> M gridSearch(
        BiFunction<T[], int[], M> trainer,
        T[] x, int[] y,
        CrossValidation cv,
        Map<String, Object[]> paramGrid);
    
    /** Random search with cross-validation */
    public static <T, M> M randomSearch(
        BiFunction<T[], int[], M> trainer,
        T[] x, int[] y,
        CrossValidation cv,
        Map<String, Distribution> paramDist,
        int nIter);
    
    /** Bayesian optimization for hyperparameter tuning */
    public static <T, M> M bayesianOptimization(
        BiFunction<T[], int[], M> trainer,
        T[] x, int[] y,
        CrossValidation cv,
        Map<String, Double[]> bounds,
        int nIter);
}
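
At its core, grid search evaluates each candidate setting with a validation score and keeps the best one. The sketch below shows only that selection step in plain Java; GridSearchSketch, select, and the cvScore callback are hypothetical names, and the callback stands in for whatever cross-validated score is being maximized.

```java
import java.util.List;
import java.util.function.ToDoubleFunction;

/** Hypothetical helper, not part of Smile: pick the candidate with the highest score. */
final class GridSearchSketch {
    /** Returns the highest-scoring candidate; ties keep the earliest candidate. */
    static <P> P select(List<P> candidates, ToDoubleFunction<P> cvScore) {
        if (candidates.isEmpty()) throw new IllegalArgumentException("no candidates");
        P chosen = null;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (P p : candidates) {
            double s = cvScore.applyAsDouble(p);
            if (chosen == null || s > bestScore) {
                chosen = p;
                bestScore = s;
            }
        }
        return chosen;
    }
}
```

A call such as GridSearchSketch.select(List.of(50, 100, 200), trees -> crossValidatedAccuracy(trees)) would return the tree count with the best score, where crossValidatedAccuracy is a placeholder for the caller's own evaluation routine.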

Classification Validation Results

Classes for storing and analyzing classification validation results.

/**
 * Single classification validation result
 */
class ClassificationValidation {
    /** Accuracy score */
    public final double accuracy;
    
    /** Error rate */
    public final double error;
    
    /** Confusion matrix */
    public final ConfusionMatrix confusion;
    
    /** Class-wise precision scores */
    public final double[] precision;
    
    /** Class-wise recall scores */
    public final double[] recall;
    
    /** Class-wise F1 scores */
    public final double[] f1;
    
    /** Matthews correlation coefficient */
    public final double mcc;
}

/**
 * Multiple classification validation results
 */
class ClassificationValidations {
    /** Individual fold results */
    public final ClassificationValidation[] rounds;
    
    /** Average validation metrics */
    public final ClassificationValidation avg;
    
    /** Standard deviation of metrics */
    public final ClassificationValidation std;
    
    /** Get confidence interval for metric */
    public double[] confidenceInterval(String metric, double confidence);
}

/**
 * Classification metrics container
 */
class ClassificationMetrics {
    /** Calculate all metrics from predictions */
    public static ClassificationMetrics of(int[] truth, int[] prediction);
    
    /** Get accuracy */
    public double getAccuracy();
    
    /** Get error rate */
    public double getError();
    
    /** Get macro-averaged F1 score */
    public double getMacroF1();
    
    /** Get weighted F1 score */
    public double getWeightedF1();
}
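
The listing does not say how confidenceInterval is computed. One common choice, shown here purely as an assumption, is a normal-approximation interval built from the per-round scores: mean plus or minus z times the standard error. FoldIntervalSketch and its parameters are hypothetical.

```java
/** Hypothetical helper, not part of Smile: normal-approximation interval over per-round scores. */
final class FoldIntervalSketch {
    /** Returns {lower, upper} = mean -/+ z * s / sqrt(k), where s is the sample standard deviation. */
    static double[] interval(double[] scores, double z) {
        int k = scores.length;
        if (k < 2) throw new IllegalArgumentException("need at least two scores");
        double mean = 0.0;
        for (double s : scores) mean += s;
        mean /= k;
        double ss = 0.0;
        for (double s : scores) ss += (s - mean) * (s - mean);
        // Sample variance uses k - 1; the standard error divides by sqrt(k).
        double stdErr = Math.sqrt(ss / (k - 1)) / Math.sqrt(k);
        return new double[] { mean - z * stdErr, mean + z * stdErr };
    }
}
```

With z = 1.96 this gives an approximate 95% interval. Scores from cross-validation folds are not independent because the training sets overlap, so the interval should be read as a rough guide rather than an exact bound.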

Regression Validation Results

Classes for storing and analyzing regression validation results.

/**
 * Single regression validation result
 */
class RegressionValidation {
    /** Root mean square error */
    public final double rmse;
    
    /** Mean absolute error */
    public final double mae;
    
    /** Mean absolute deviation */
    public final double mad;
    
    /** R-squared coefficient */
    public final double r2;
    
    /** Adjusted R-squared */
    public final double adjustedR2;
    
    /** Residual sum of squares */
    public final double rss;
    
    /** Total sum of squares */
    public final double tss;
}

/**
 * Multiple regression validation results
 */
class RegressionValidations {
    /** Individual fold results */
    public final RegressionValidation[] rounds;
    
    /** Average validation metrics */
    public final RegressionValidation avg;
    
    /** Standard deviation of metrics */
    public final RegressionValidation std;
}

/**
 * Regression metrics container
 */
class RegressionMetrics {
    /** Calculate all metrics from predictions */
    public static RegressionMetrics of(double[] truth, double[] prediction);
    
    /** Get RMSE */
    public double getRMSE();
    
    /** Get MAE */
    public double getMAE();
    
    /** Get R-squared */
    public double getR2();
}

Classification Metrics

Individual classification metrics for detailed model evaluation.

/**
 * Base classification metric interface
 */
interface ClassificationMetric {
    /** Calculate metric from true and predicted labels */
    double score(int[] truth, int[] prediction);
}

/**
 * Probabilistic classification metric interface
 */
interface ProbabilisticClassificationMetric {
    /** Calculate metric from true labels and predicted probabilities */
    double score(int[] truth, double[][] probability);
}

/**
 * Accuracy metric
 */
class Accuracy implements ClassificationMetric {
    /** Calculate accuracy */
    public static double of(int[] truth, int[] prediction);
}

/**
 * Error rate metric
 */
class Error implements ClassificationMetric {
    /** Calculate error rate */
    public static double of(int[] truth, int[] prediction);
}

/**
 * Precision metric
 */
class Precision implements ClassificationMetric {
    /** Calculate macro-averaged precision */
    public static double of(int[] truth, int[] prediction);
    
    /** Calculate class-specific precision */
    public static double[] byClass(int[] truth, int[] prediction);
}

/**
 * Recall (Sensitivity) metric
 */
class Recall implements ClassificationMetric {
    /** Calculate macro-averaged recall */
    public static double of(int[] truth, int[] prediction);
    
    /** Calculate class-specific recall */
    public static double[] byClass(int[] truth, int[] prediction);
}

/**
 * Specificity metric
 */
class Specificity implements ClassificationMetric {
    /** Calculate macro-averaged specificity */
    public static double of(int[] truth, int[] prediction);
    
    /** Calculate class-specific specificity */
    public static double[] byClass(int[] truth, int[] prediction);
}

/**
 * F-score metric
 */
class FScore implements ClassificationMetric {
    /** Calculate macro-averaged F1 score */
    public static double of(int[] truth, int[] prediction);
    
    /** Calculate F-beta score */
    public static double of(int[] truth, int[] prediction, double beta);
    
    /** Calculate class-specific F1 scores */
    public static double[] byClass(int[] truth, int[] prediction);
}

/**
 * Matthews Correlation Coefficient
 */
class MatthewsCorrelation implements ClassificationMetric {
    /** Calculate MCC */
    public static double of(int[] truth, int[] prediction);
}

/**
 * Area Under ROC Curve
 */
class AUC implements ProbabilisticClassificationMetric {
    /** Calculate AUC for binary classification */
    public static double of(int[] truth, double[] probability);
    
    /** Calculate multi-class AUC (one-vs-rest) */
    public static double of(int[] truth, double[][] probability);
}

/**
 * Cross Entropy loss
 */
class CrossEntropy implements ProbabilisticClassificationMetric {
    /** Calculate cross entropy loss */
    public static double of(int[] truth, double[][] probability);
}

/**
 * Logarithmic loss
 */
class LogLoss implements ProbabilisticClassificationMetric {
    /** Calculate log loss */
    public static double of(int[] truth, double[][] probability);
}

/**
 * Confusion Matrix
 */
class ConfusionMatrix {
    /** Create confusion matrix */
    public static ConfusionMatrix of(int[] truth, int[] prediction);
    
    /** The confusion matrix */
    public final int[][] matrix;
    
    /** Number of classes */
    public final int classes;
    
    /** Get accuracy from confusion matrix */
    public double accuracy();
    
    /** Get error rate */
    public double error();
    
    /** Get class-specific precision */
    public double[] precision();
    
    /** Get class-specific recall */
    public double[] recall();
}

/**
 * False Discovery Rate
 */
class FDR implements ClassificationMetric {
    /** Calculate false discovery rate */
    public static double of(int[] truth, int[] prediction);
}

/**
 * Fallout (False Positive Rate)
 */
class Fallout implements ClassificationMetric {
    /** Calculate fallout */
    public static double of(int[] truth, int[] prediction);
}
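
For readers who want to see what a few of these metrics compute, the following plain-Java sketch implements textbook definitions of the Matthews correlation coefficient, binary AUC, and log loss. It is not Smile's code: the class ScoreSketch, the 0/1 label convention, the zero-denominator fallback, and the clipping constant are all assumptions.

```java
/** Hypothetical helpers, not Smile's implementation. Binary labels are 0 (negative) and 1 (positive). */
final class ScoreSketch {
    /** Matthews correlation coefficient; returns 0 when any marginal count is zero. */
    static double mcc(int[] truth, int[] prediction) {
        long tp = 0, tn = 0, fp = 0, fn = 0;
        for (int i = 0; i < truth.length; i++) {
            if (truth[i] == 1) {
                if (prediction[i] == 1) tp++; else fn++;
            } else {
                if (prediction[i] == 1) fp++; else tn++;
            }
        }
        double denom = Math.sqrt((double) (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn));
        return denom == 0.0 ? 0.0 : (tp * tn - fp * fn) / denom;
    }

    /** AUC as the fraction of (positive, negative) pairs ranked correctly; ties count one half. */
    static double auc(int[] truth, double[] probability) {
        double wins = 0.0;
        long pairs = 0;
        for (int i = 0; i < truth.length; i++) {
            if (truth[i] != 1) continue;
            for (int j = 0; j < truth.length; j++) {
                if (truth[j] != 0) continue;
                pairs++;
                if (probability[i] > probability[j]) wins += 1.0;
                else if (probability[i] == probability[j]) wins += 0.5;
            }
        }
        if (pairs == 0) throw new IllegalArgumentException("need at least one sample of each class");
        return wins / pairs;
    }

    /** Mean negative log-probability of the true class; probabilities are clipped away from 0. */
    static double logLoss(int[] truth, double[][] probability) {
        double eps = 1e-15;
        double sum = 0.0;
        for (int i = 0; i < truth.length; i++) {
            double p = Math.max(eps, probability[i][truth[i]]);
            sum -= Math.log(p);
        }
        return sum / truth.length;
    }
}
```

The pairwise AUC loop is quadratic in the number of samples. It is written that way for clarity; a rank-based computation gives the same value in O(n log n).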

Regression Metrics

Individual regression metrics for model evaluation.

/**
 * Base regression metric interface
 */
interface RegressionMetric {
    /** Calculate metric from true and predicted values */
    double score(double[] truth, double[] prediction);
}

/**
 * Mean Squared Error
 */
class MSE implements RegressionMetric {
    /** Calculate MSE */
    public static double of(double[] truth, double[] prediction);
}

/**
 * Root Mean Squared Error
 */
class RMSE implements RegressionMetric {
    /** Calculate RMSE */
    public static double of(double[] truth, double[] prediction);
}

/**
 * Mean Absolute Error
 */
class MAE implements RegressionMetric {
    /** Calculate MAE */
    public static double of(double[] truth, double[] prediction);
}

/**
 * Mean Absolute Deviation
 */
class MAD implements RegressionMetric {
    /** Calculate MAD */
    public static double of(double[] truth, double[] prediction);
}

/**
 * Residual Sum of Squares
 */
class RSS implements RegressionMetric {
    /** Calculate RSS */
    public static double of(double[] truth, double[] prediction);
}

/**
 * R-squared coefficient of determination
 */
class R2 implements RegressionMetric {
    /** Calculate R-squared */
    public static double of(double[] truth, double[] prediction);
    
    /** Calculate adjusted R-squared */
    public static double adjusted(double[] truth, double[] prediction, int p);
}
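
The formulas behind several of these metrics are short enough to write out. The sketch below uses the standard definitions (RSS as the sum of squared residuals, R-squared as 1 - RSS/TSS, adjusted R-squared with p predictors) in plain Java. RegressionScoreSketch is a hypothetical class and makes no claim about how Smile implements them.

```java
/** Hypothetical helpers, not Smile's implementation: standard regression error formulas. */
final class RegressionScoreSketch {
    /** Residual sum of squares. */
    static double rss(double[] truth, double[] prediction) {
        double sum = 0.0;
        for (int i = 0; i < truth.length; i++) {
            double e = truth[i] - prediction[i];
            sum += e * e;
        }
        return sum;
    }

    static double mse(double[] truth, double[] prediction) {
        return rss(truth, prediction) / truth.length;
    }

    static double rmse(double[] truth, double[] prediction) {
        return Math.sqrt(mse(truth, prediction));
    }

    /** Mean absolute error. */
    static double mae(double[] truth, double[] prediction) {
        double sum = 0.0;
        for (int i = 0; i < truth.length; i++) sum += Math.abs(truth[i] - prediction[i]);
        return sum / truth.length;
    }

    /** R-squared = 1 - RSS / TSS, where TSS is the squared deviation of truth from its mean. */
    static double r2(double[] truth, double[] prediction) {
        double mean = 0.0;
        for (double t : truth) mean += t;
        mean /= truth.length;
        double tss = 0.0;
        for (double t : truth) tss += (t - mean) * (t - mean);
        return 1.0 - rss(truth, prediction) / tss;
    }

    /** Adjusted R-squared for p predictors; requires n > p + 1. */
    static double adjustedR2(double[] truth, double[] prediction, int p) {
        int n = truth.length;
        if (n <= p + 1) throw new IllegalArgumentException("need n > p + 1");
        return 1.0 - (1.0 - r2(truth, prediction)) * (n - 1) / (n - p - 1);
    }
}
```

R-squared is undefined when all true values are equal, because TSS is then zero; the sketch does not special-case that input.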

Clustering Metrics

Metrics for evaluating clustering quality and comparing clustering results.

/**
 * Base clustering metric interface
 */
interface ClusteringMetric {
    /** Calculate metric from true and predicted cluster labels */
    double score(int[] truth, int[] prediction);
}

/**
 * Rand Index for clustering comparison
 */
class RandIndex implements ClusteringMetric {
    /** Calculate Rand Index */
    public static double of(int[] truth, int[] prediction);
}

/**
 * Adjusted Rand Index
 */
class AdjustedRandIndex implements ClusteringMetric {
    /** Calculate Adjusted Rand Index */
    public static double of(int[] truth, int[] prediction);
}

/**
 * Mutual Information between clusterings
 */
class MutualInformation implements ClusteringMetric {
    /** Calculate mutual information */
    public static double of(int[] truth, int[] prediction);
}

/**
 * Normalized Mutual Information
 */
class NormalizedMutualInformation implements ClusteringMetric {
    /** Normalization methods */
    enum Method { ARITHMETIC, GEOMETRIC, MAX, MIN }
    
    /** Calculate NMI with arithmetic mean normalization */
    public static double of(int[] truth, int[] prediction);
    
    /** Calculate NMI with specified normalization */
    public static double of(int[] truth, int[] prediction, Method method);
}

/**
 * Adjusted Mutual Information
 */
class AdjustedMutualInformation implements ClusteringMetric {
    /** Adjustment methods */
    enum Method { ARITHMETIC, GEOMETRIC, MAX, MIN }
    
    /** Calculate AMI with arithmetic mean adjustment */
    public static double of(int[] truth, int[] prediction);
    
    /** Calculate AMI with specified adjustment */
    public static double of(int[] truth, int[] prediction, Method method);
}

/**
 * Contingency table for clustering evaluation
 */
class ContingencyTable {
    /** Create contingency table */
    public static ContingencyTable of(int[] truth, int[] prediction);
    
    /** The contingency table matrix */
    public final int[][] table;
    
    /** Number of true clusters */
    public final int n;
    
    /** Number of predicted clusters */
    public final int m;
    
    /** Calculate mutual information */
    public double mutualInformation();
    
    /** Calculate entropy of true clustering */
    public double entropyX();
    
    /** Calculate entropy of predicted clustering */
    public double entropyY();
}
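
Two of these comparisons can be sketched directly from their definitions: the Rand index family counts pairs of samples that the two labelings treat the same way, and normalized mutual information is built from the contingency table. The plain-Java code below is illustrative only. ClusterAgreementSketch is a hypothetical class, labels are assumed to be integers starting at 0, at least two samples are assumed, and returning 1.0 for degenerate inputs is an assumed convention.

```java
/** Hypothetical helpers, not Smile's implementation. Labels are assumed to be 0, 1, 2, ... */
final class ClusterAgreementSketch {
    /** Returns pair counts {together in both, together in a, together in b, all pairs}. */
    private static long[] pairCounts(int[] a, int[] b) {
        long both = 0, sameA = 0, sameB = 0, total = 0;
        for (int i = 0; i < a.length; i++) {
            for (int j = i + 1; j < a.length; j++) {
                boolean inA = a[i] == a[j], inB = b[i] == b[j];
                if (inA) sameA++;
                if (inB) sameB++;
                if (inA && inB) both++;
                total++;
            }
        }
        return new long[] { both, sameA, sameB, total };
    }

    /** Fraction of pairs on which the two labelings agree (together in both or apart in both). */
    static double randIndex(int[] truth, int[] prediction) {
        long[] c = pairCounts(truth, prediction);
        long agree = c[0] + (c[3] - c[1] - c[2] + c[0]);
        return (double) agree / c[3];
    }

    /** Rand index corrected for chance; 1.0 when both labelings are trivially identical. */
    static double adjustedRandIndex(int[] truth, int[] prediction) {
        long[] c = pairCounts(truth, prediction);
        double expected = (double) c[1] * c[2] / c[3];
        double max = 0.5 * (c[1] + c[2]);
        return max == expected ? 1.0 : (c[0] - expected) / (max - expected);
    }

    /** NMI with arithmetic-mean normalization: 2 * MI / (H(truth) + H(prediction)). */
    static double normalizedMutualInformation(int[] truth, int[] prediction) {
        int n = truth.length, rows = 0, cols = 0;
        for (int i = 0; i < n; i++) {
            rows = Math.max(rows, truth[i] + 1);
            cols = Math.max(cols, prediction[i] + 1);
        }
        int[][] table = new int[rows][cols];
        int[] a = new int[rows], b = new int[cols];
        for (int i = 0; i < n; i++) {
            table[truth[i]][prediction[i]]++;
            a[truth[i]]++;
            b[prediction[i]]++;
        }
        double mi = 0.0;
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                if (table[i][j] == 0) continue;
                double joint = (double) table[i][j] / n;
                mi += joint * Math.log((double) table[i][j] * n / ((double) a[i] * b[j]));
            }
        }
        double h = entropy(a, n) + entropy(b, n);
        return h == 0.0 ? 1.0 : 2.0 * mi / h;
    }

    private static double entropy(int[] counts, int n) {
        double h = 0.0;
        for (int c : counts) {
            if (c > 0) {
                double p = (double) c / n;
                h -= p * Math.log(p);
            }
        }
        return h;
    }
}
```

All three scores depend only on which samples are grouped together, so renaming the cluster labels does not change the result.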

Comprehensive Usage Example:

import java.util.Arrays;

import smile.validation.*;
import smile.validation.metric.*;
import smile.classification.RandomForest;

// Complete validation pipeline
public class ModelValidation {
    public void validateModel(double[][] features, int[] labels) {
        // 1. 10-fold cross-validation via the static method listed under Cross-Validation Framework
        //    (stratify(n, k, y) in that listing only produces index splits, so it is not used here)
        var cvResults = CrossValidation.classification(10, RandomForest::fit, features, labels);
        
        System.out.println("CV Accuracy: " + cvResults.avg.accuracy + " ± " + cvResults.std.accuracy);
        System.out.println("CV F1 Score: " + cvResults.avg.f1[0]);
        
        // 2. Bootstrap validation
        var bootstrapResults = Bootstrap.classification(100, RandomForest::fit, features, labels);
        System.out.println("Bootstrap Accuracy: " + bootstrapResults.avg.accuracy);
        
        // 3. Detailed metrics analysis
        RandomForest model = RandomForest.fit(features, labels);
        int[] predictions = Arrays.stream(features).mapToInt(model::predict).toArray();
        
        // Classification metrics
        double accuracy = Accuracy.of(labels, predictions);
        double[] precision = Precision.byClass(labels, predictions);
        double[] recall = Recall.byClass(labels, predictions);
        double[] f1 = FScore.byClass(labels, predictions);
        double mcc = MatthewsCorrelation.of(labels, predictions);
        
        // Confusion matrix
        ConfusionMatrix cm = ConfusionMatrix.of(labels, predictions);
        System.out.println("Confusion Matrix:");
        for (int[] row : cm.matrix) {
            System.out.println(Arrays.toString(row));
        }
        
        // 4. Confidence interval for the cross-validated accuracy
        double[] ci = cvResults.confidenceInterval("accuracy", 0.95);
        System.out.println("95% CI for accuracy: [" + ci[0] + ", " + ci[1] + "]");
    }
}

Common Validation Patterns

Standard patterns for model validation in Smile Core:

Basic Cross-Validation:

// 5-fold cross-validation; trainer, features, and labels are supplied by the caller
var results = CrossValidation.classification(5, trainer, features, labels);

Stratified Cross-Validation:

// stratify returns index splits (Bag[]) that preserve class proportions in each fold
Bag[] folds = CrossValidation.stratify(labels.length, 10, labels);

Time Series Validation:

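// Note: timeSeries is not part of the CrossValidation listing above; confirm it exists in your Smile version
CrossValidation cv = CrossValidation.timeSeries(5);
var results = cv.regression(trainer, features, targets);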

Bootstrap Validation:

var results = Bootstrap.classification(100, trainer, features, labels);

Performance Analysis

The validation result classes listed above expose the following:

  • Point Estimates: avg holds the mean of each metric across folds or bootstrap rounds
  • Variability: std holds the standard deviation of each metric
  • Confidence Intervals: ClassificationValidations.confidenceInterval(metric, confidence) returns bounds for a named metric
  • Per-Class Metrics: the precision, recall, and f1 arrays in ClassificationValidation give a breakdown for multi-class problems
  • Confusion Analysis: the confusion field holds the confusion matrix for error analysis
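
The per-class metrics and the confusion matrix are closely linked: precision, recall, and F1 for each class can all be read off the matrix. The plain-Java sketch below shows that derivation. PerClassSketch is a hypothetical class, the rows-are-truth orientation is an assumption rather than a statement about Smile's layout, and undefined ratios are reported as 0 by convention.

```java
/** Hypothetical helpers, not Smile's implementation: per-class scores from a confusion matrix. */
final class PerClassSketch {
    /** matrix[t][p] counts samples whose true class is t and predicted class is p. */
    static int[][] confusion(int[] truth, int[] prediction, int classes) {
        int[][] m = new int[classes][classes];
        for (int i = 0; i < truth.length; i++) m[truth[i]][prediction[i]]++;
        return m;
    }

    /** Returns {precision, recall, f1}, each indexed by class; an undefined ratio is reported as 0. */
    static double[][] scores(int[][] m) {
        int k = m.length;
        double[] precision = new double[k], recall = new double[k], f1 = new double[k];
        for (int c = 0; c < k; c++) {
            int predicted = 0, actual = 0;
            for (int j = 0; j < k; j++) {
                predicted += m[j][c]; // column sum: everything predicted as c
                actual += m[c][j];    // row sum: everything that truly is c
            }
            precision[c] = predicted == 0 ? 0.0 : (double) m[c][c] / predicted;
            recall[c] = actual == 0 ? 0.0 : (double) m[c][c] / actual;
            double pr = precision[c] + recall[c];
            f1[c] = pr == 0.0 ? 0.0 : 2 * precision[c] * recall[c] / pr;
        }
        return new double[][] { precision, recall, f1 };
    }
}
```

Averaging the f1 array with equal weights gives the macro-averaged F1; weighting each class by its row sum gives the weighted F1.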

Install with Tessl CLI

npx tessl i tessl/maven-com-github-haifengl--smile-core
