Statistical Machine Intelligence and Learning Engine providing comprehensive machine learning algorithms for classification, regression, clustering, and feature engineering in Java
—
Comprehensive model validation framework with cross-validation, bootstrap sampling, and extensive performance metrics for classification, regression, and clustering tasks. Smile Core provides robust tools for model evaluation and comparison.
Core framework for model validation using various cross-validation strategies.
/**
* Cross-validation utilities for model evaluation
*/
class CrossValidation {
/** Create k-fold cross-validation splits */
public static Bag[] of(int n, int k);
/** Create stratified k-fold cross-validation splits */
public static Bag[] stratify(int n, int k, int[] y);
/** Classification cross-validation */
public static <T, M extends Classifier<T>> ClassificationValidations classification(
int k, Classifier.Trainer<T, M> trainer, T[] x, int[] y);
/** Classification cross-validation with Formula */
public static ClassificationValidations classification(
int k, BiFunction<Formula, DataFrame, Classifier<Tuple>> trainer, Formula formula, DataFrame data);
/** Regression cross-validation */
public static <T, M extends Regression<T>> RegressionValidations regression(
int k, Regression.Trainer<T, M> trainer, T[] x, double[] y);
/** Regression cross-validation with Formula */
public static RegressionValidations regression(
int k, BiFunction<Formula, DataFrame, Regression<Tuple>> trainer, Formula formula, DataFrame data);
}
/**
* Data bag containing training and test indices
*/
class Bag {
/** Training sample indices */
public final int[] samples;
/** Out-of-bag (test) sample indices */
public final int[] oob;
/** Get training data subset */
public <T> T[] trainSet(T[] data);
/** Get test data subset */
public <T> T[] testSet(T[] data);
}

Usage Example:
import smile.validation.CrossValidation;
import smile.classification.RandomForest;
import smile.data.formula.Formula;
// 10-fold cross-validation with DataFrame
Formula formula = Formula.lhs("target");
var results = CrossValidation.classification(10, RandomForest::fit, formula, data);
System.out.println("Accuracy: " + results.avg.accuracy);
System.out.println("Error: " + results.avg.error);
// Stratified cross-validation splits
var stratifiedSplits = CrossValidation.stratify(data.size(), 5, labels);

Bootstrap resampling methods for model validation and uncertainty estimation.
/**
* Bootstrap sampling for model validation
*/
class Bootstrap {
/** Create bootstrap samples */
public static Bag[] of(int n, int size, int subsampleSize);
/** Bootstrap with replacement */
public static Bag[] of(int n, int size);
/** Bootstrap validation for classification */
public static <T, M extends Classifier<T>> ClassificationValidations classification(
int round, Classifier.Trainer<T, M> trainer, T[] x, int[] y);
/** Bootstrap validation for regression */
public static <T, M extends Regression<T>> RegressionValidations regression(
int round, Regression.Trainer<T, M> trainer, T[] x, double[] y);
}

Specialized cross-validation where each sample is used as the test set exactly once.
/**
* Leave-one-out cross-validation
*/
class LOOCV {
/** Create LOOCV splits */
public Bag[] split(int n);
/** Classification LOOCV */
public static <T> ClassificationValidation classification(
Classifier.Trainer<T, ?> trainer, T[] x, int[] y);
/** Regression LOOCV */
public static <T> RegressionValidation regression(
Regression.Trainer<T, ?> trainer, T[] x, double[] y);
}

Utilities for comparing and selecting optimal models and hyperparameters.
/**
* Model selection utilities
*/
class ModelSelection {
/** Grid search with cross-validation */
public static <T, M> M gridSearch(
BiFunction<T[], int[], M> trainer,
T[] x, int[] y,
CrossValidation cv,
Map<String, Object[]> paramGrid);
/** Random search with cross-validation */
public static <T, M> M randomSearch(
BiFunction<T[], int[], M> trainer,
T[] x, int[] y,
CrossValidation cv,
Map<String, Distribution> paramDist,
int nIter);
/** Bayesian optimization for hyperparameter tuning */
public static <T, M> M bayesianOptimization(
BiFunction<T[], int[], M> trainer,
T[] x, int[] y,
CrossValidation cv,
Map<String, Double[]> bounds,
int nIter);
}

Classes for storing and analyzing classification validation results.
/**
* Single classification validation result
*/
class ClassificationValidation {
/** Accuracy score */
public final double accuracy;
/** Error rate */
public final double error;
/** Confusion matrix */
public final ConfusionMatrix confusion;
/** Class-wise precision scores */
public final double[] precision;
/** Class-wise recall scores */
public final double[] recall;
/** Class-wise F1 scores */
public final double[] f1;
/** Matthews correlation coefficient */
public final double mcc;
}
/**
* Multiple classification validation results
*/
class ClassificationValidations {
/** Individual fold results */
public final ClassificationValidation[] rounds;
/** Average validation metrics */
public final ClassificationValidation avg;
/** Standard deviation of metrics */
public final ClassificationValidation std;
/** Get confidence interval for metric */
public double[] confidenceInterval(String metric, double confidence);
}
/**
* Classification metrics container
*/
class ClassificationMetrics {
/** Calculate all metrics from predictions */
public static ClassificationMetrics of(int[] truth, int[] prediction);
/** Get accuracy */
public double getAccuracy();
/** Get error rate */
public double getError();
/** Get macro-averaged F1 score */
public double getMacroF1();
/** Get weighted F1 score */
public double getWeightedF1();
}

Classes for storing and analyzing regression validation results.
/**
* Single regression validation result
*/
class RegressionValidation {
/** Root mean square error */
public final double rmse;
/** Mean absolute error */
public final double mae;
/** Mean absolute deviation */
public final double mad;
/** R-squared coefficient */
public final double r2;
/** Adjusted R-squared */
public final double adjustedR2;
/** Residual sum of squares */
public final double rss;
/** Total sum of squares */
public final double tss;
}
/**
* Multiple regression validation results
*/
class RegressionValidations {
/** Individual fold results */
public final RegressionValidation[] rounds;
/** Average validation metrics */
public final RegressionValidation avg;
/** Standard deviation of metrics */
public final RegressionValidation std;
}
/**
* Regression metrics container
*/
class RegressionMetrics {
/** Calculate all metrics from predictions */
public static RegressionMetrics of(double[] truth, double[] prediction);
/** Get RMSE */
public double getRMSE();
/** Get MAE */
public double getMAE();
/** Get R-squared */
public double getR2();
}

Individual classification metrics for detailed model evaluation.
/**
* Base classification metric interface
*/
interface ClassificationMetric {
/** Calculate metric from true and predicted labels */
double score(int[] truth, int[] prediction);
}
/**
* Probabilistic classification metric interface
*/
interface ProbabilisticClassificationMetric {
/** Calculate metric from true labels and predicted probabilities */
double score(int[] truth, double[][] probability);
}
/**
* Accuracy metric
*/
class Accuracy implements ClassificationMetric {
/** Calculate accuracy */
public static double of(int[] truth, int[] prediction);
}
/**
* Error rate metric
*/
class Error implements ClassificationMetric {
/** Calculate error rate */
public static double of(int[] truth, int[] prediction);
}
/**
* Precision metric
*/
class Precision implements ClassificationMetric {
/** Calculate macro-averaged precision */
public static double of(int[] truth, int[] prediction);
/** Calculate class-specific precision */
public static double[] byClass(int[] truth, int[] prediction);
}
/**
* Recall (Sensitivity) metric
*/
class Recall implements ClassificationMetric {
/** Calculate macro-averaged recall */
public static double of(int[] truth, int[] prediction);
/** Calculate class-specific recall */
public static double[] byClass(int[] truth, int[] prediction);
}
/**
* Specificity metric
*/
class Specificity implements ClassificationMetric {
/** Calculate macro-averaged specificity */
public static double of(int[] truth, int[] prediction);
/** Calculate class-specific specificity */
public static double[] byClass(int[] truth, int[] prediction);
}
/**
* F-score metric
*/
class FScore implements ClassificationMetric {
/** Calculate macro-averaged F1 score */
public static double of(int[] truth, int[] prediction);
/** Calculate F-beta score */
public static double of(int[] truth, int[] prediction, double beta);
/** Calculate class-specific F1 scores */
public static double[] byClass(int[] truth, int[] prediction);
}
/**
* Matthews Correlation Coefficient
*/
class MatthewsCorrelation implements ClassificationMetric {
/** Calculate MCC */
public static double of(int[] truth, int[] prediction);
}
/**
* Area Under ROC Curve
*/
class AUC implements ProbabilisticClassificationMetric {
/** Calculate AUC for binary classification */
public static double of(int[] truth, double[] probability);
/** Calculate multi-class AUC (one-vs-rest) */
public static double of(int[] truth, double[][] probability);
}
/**
* Cross Entropy loss
*/
class CrossEntropy implements ProbabilisticClassificationMetric {
/** Calculate cross entropy loss */
public static double of(int[] truth, double[][] probability);
}
/**
* Logarithmic loss
*/
class LogLoss implements ProbabilisticClassificationMetric {
/** Calculate log loss */
public static double of(int[] truth, double[][] probability);
}
/**
* Confusion Matrix
*/
class ConfusionMatrix {
/** Create confusion matrix */
public static ConfusionMatrix of(int[] truth, int[] prediction);
/** The confusion matrix */
public final int[][] matrix;
/** Number of classes */
public final int classes;
/** Get accuracy from confusion matrix */
public double accuracy();
/** Get error rate */
public double error();
/** Get class-specific precision */
public double[] precision();
/** Get class-specific recall */
public double[] recall();
}
/**
* False Discovery Rate
*/
class FDR implements ClassificationMetric {
/** Calculate false discovery rate */
public static double of(int[] truth, int[] prediction);
}
/**
* Fallout (False Positive Rate)
*/
class Fallout implements ClassificationMetric {
/** Calculate fallout */
public static double of(int[] truth, int[] prediction);
}

Individual regression metrics for model evaluation.
/**
* Base regression metric interface
*/
interface RegressionMetric {
/** Calculate metric from true and predicted values */
double score(double[] truth, double[] prediction);
}
/**
* Mean Squared Error
*/
class MSE implements RegressionMetric {
/** Calculate MSE */
public static double of(double[] truth, double[] prediction);
}
/**
* Root Mean Squared Error
*/
class RMSE implements RegressionMetric {
/** Calculate RMSE */
public static double of(double[] truth, double[] prediction);
}
/**
* Mean Absolute Error
*/
class MAE implements RegressionMetric {
/** Calculate MAE */
public static double of(double[] truth, double[] prediction);
}
/**
* Mean Absolute Deviation
*/
class MAD implements RegressionMetric {
/** Calculate MAD */
public static double of(double[] truth, double[] prediction);
}
/**
* Residual Sum of Squares
*/
class RSS implements RegressionMetric {
/** Calculate RSS */
public static double of(double[] truth, double[] prediction);
}
/**
* R-squared coefficient of determination
*/
class R2 implements RegressionMetric {
/** Calculate R-squared */
public static double of(double[] truth, double[] prediction);
/** Calculate adjusted R-squared */
public static double adjusted(double[] truth, double[] prediction, int p);
}

Metrics for evaluating clustering quality and comparing clustering results.
/**
* Base clustering metric interface
*/
interface ClusteringMetric {
/** Calculate metric from true and predicted cluster labels */
double score(int[] truth, int[] prediction);
}
/**
* Rand Index for clustering comparison
*/
class RandIndex implements ClusteringMetric {
/** Calculate Rand Index */
public static double of(int[] truth, int[] prediction);
}
/**
* Adjusted Rand Index
*/
class AdjustedRandIndex implements ClusteringMetric {
/** Calculate Adjusted Rand Index */
public static double of(int[] truth, int[] prediction);
}
/**
* Mutual Information between clusterings
*/
class MutualInformation implements ClusteringMetric {
/** Calculate mutual information */
public static double of(int[] truth, int[] prediction);
}
/**
* Normalized Mutual Information
*/
class NormalizedMutualInformation implements ClusteringMetric {
/** Normalization methods */
enum Method { ARITHMETIC, GEOMETRIC, MAX, MIN }
/** Calculate NMI with arithmetic mean normalization */
public static double of(int[] truth, int[] prediction);
/** Calculate NMI with specified normalization */
public static double of(int[] truth, int[] prediction, Method method);
}
/**
* Adjusted Mutual Information
*/
class AdjustedMutualInformation implements ClusteringMetric {
/** Adjustment methods */
enum Method { ARITHMETIC, GEOMETRIC, MAX, MIN }
/** Calculate AMI with arithmetic mean adjustment */
public static double of(int[] truth, int[] prediction);
/** Calculate AMI with specified adjustment */
public static double of(int[] truth, int[] prediction, Method method);
}
/**
* Contingency table for clustering evaluation
*/
class ContingencyTable {
/** Create contingency table */
public static ContingencyTable of(int[] truth, int[] prediction);
/** The contingency table matrix */
public final int[][] table;
/** Number of true clusters */
public final int n;
/** Number of predicted clusters */
public final int m;
/** Calculate mutual information */
public double mutualInformation();
/** Calculate entropy of true clustering */
public double entropyX();
/** Calculate entropy of predicted clustering */
public double entropyY();
}

Comprehensive Usage Example:
import smile.validation.*;
import smile.validation.metric.*;
import smile.classification.RandomForest;
import java.util.Arrays;
// Complete validation pipeline
public class ModelValidation {
public void validateModel(double[][] features, int[] labels) {
// 1. Cross-validation
// Stratified k-fold splits take (n, k, labels) per the API above; classification is a static helper
Bag[] folds = CrossValidation.stratify(features.length, 10, labels);
var cvResults = CrossValidation.classification(10, RandomForest::fit, features, labels);
System.out.println("CV Accuracy: " + cvResults.avg.accuracy + " ± " + cvResults.std.accuracy);
System.out.println("CV F1 Score: " + cvResults.avg.f1[0]);
// 2. Bootstrap validation
var bootstrapResults = Bootstrap.classification(100, RandomForest::fit, features, labels);
System.out.println("Bootstrap Accuracy: " + bootstrapResults.avg.accuracy);
// 3. Detailed metrics analysis
RandomForest model = RandomForest.fit(features, labels);
int[] predictions = Arrays.stream(features).mapToInt(model::predict).toArray();
// Classification metrics
double accuracy = Accuracy.of(labels, predictions);
double[] precision = Precision.byClass(labels, predictions);
double[] recall = Recall.byClass(labels, predictions);
double[] f1 = FScore.byClass(labels, predictions);
double mcc = MatthewsCorrelation.of(labels, predictions);
// Confusion matrix
ConfusionMatrix cm = ConfusionMatrix.of(labels, predictions);
System.out.println("Confusion Matrix:");
for (int[] row : cm.matrix) {
System.out.println(Arrays.toString(row));
}
// 4. Statistical significance testing
double[] ci = cvResults.confidenceInterval("accuracy", 0.95);
System.out.println("95% CI for accuracy: [" + ci[0] + ", " + ci[1] + "]");
}
}

Standard patterns for model validation in Smile Core:
Basic Cross-Validation:
Bag[] folds = CrossValidation.of(n, 5);
var results = CrossValidation.classification(5, trainer, features, labels);

Stratified Cross-Validation:
Bag[] folds = CrossValidation.stratify(n, 10, labels);
var results = CrossValidation.classification(10, trainer, features, labels);

Time Series Validation:
// NOTE(review): timeSeries(k) is not declared in the API documented above — verify it exists in this Smile version
CrossValidation cv = CrossValidation.timeSeries(5);
var results = cv.regression(trainer, features, targets);

Bootstrap Validation:
var results = Bootstrap.classification(100, trainer, features, labels);

All validation results provide comprehensive performance analysis:
Install with Tessl CLI
npx tessl i tessl/maven-com-github-haifengl--smile-core