# Validation and Metrics

Smile Core's validation framework provides cross-validation, bootstrap sampling, and performance metrics for classification, regression, and clustering tasks, so that models can be evaluated and compared consistently.

## Capabilities

### Cross-Validation Framework

Core framework for model validation using k-fold and stratified k-fold cross-validation.

```java { .api }
/**
 * Cross-validation utilities for model evaluation
 */
class CrossValidation {
    /** Create k-fold cross-validation splits */
    public static Bag[] of(int n, int k);

    /** Create stratified k-fold cross-validation splits */
    public static Bag[] stratify(int n, int k, int[] y);

    /** Classification cross-validation */
    public static <T, M extends Classifier<T>> ClassificationValidations classification(
        int k, Classifier.Trainer<T, M> trainer, T[] x, int[] y);

    /** Classification cross-validation with Formula */
    public static ClassificationValidations classification(
        int k, BiFunction<Formula, DataFrame, Classifier<Tuple>> trainer, Formula formula, DataFrame data);

    /** Regression cross-validation */
    public static <T, M extends Regression<T>> RegressionValidations regression(
        int k, Regression.Trainer<T, M> trainer, T[] x, double[] y);

    /** Regression cross-validation with Formula */
    public static RegressionValidations regression(
        int k, BiFunction<Formula, DataFrame, Regression<Tuple>> trainer, Formula formula, DataFrame data);
}

/**
 * Data bag containing training and test indices
 */
class Bag {
    /** Training sample indices */
    public final int[] samples;

    /** Out-of-bag (test) sample indices */
    public final int[] oob;

    /** Get training data subset */
    public <T> T[] trainSet(T[] data);

    /** Get test data subset */
    public <T> T[] testSet(T[] data);
}
```

**Usage Example:**

```java
import smile.validation.CrossValidation;
import smile.classification.RandomForest;
import smile.data.formula.Formula;

// Assumes `data` is a DataFrame with a "target" column and
// `labels` is an int[] holding the class label of each row.

// 10-fold cross-validation with DataFrame
Formula formula = Formula.lhs("target");
var results = CrossValidation.classification(10, RandomForest::fit, formula, data);

System.out.println("Accuracy: " + results.avg.accuracy);
System.out.println("Error: " + results.avg.error);

// Stratified cross-validation splits
var stratifiedSplits = CrossValidation.stratify(data.size(), 5, labels);
```
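
To make the split structure concrete, here is a minimal sketch of how k-fold index sets can be generated in plain Java. It is an illustration, not Smile's implementation: the class `KFoldSketch`, its `split` method, and the `seed` parameter are hypothetical names, and the sketch returns raw index arrays rather than `Bag` objects.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Hypothetical sketch of k-fold index generation (not part of Smile). */
final class KFoldSketch {
    /** Returns folds[f] = {trainIndices, testIndices} for fold f. */
    static int[][][] split(int n, int k, long seed) {
        if (k < 2 || k > n) {
            throw new IllegalArgumentException("need 2 <= k <= n");
        }
        // Shuffle the row indices once so that folds are random but disjoint.
        List<Integer> order = new ArrayList<>();
        for (int i = 0; i < n; i++) order.add(i);
        Collections.shuffle(order, new Random(seed));

        int[][][] folds = new int[k][2][];
        for (int f = 0; f < k; f++) {
            // Fold f tests on shuffled positions [start, end).
            int start = (int) ((long) f * n / k);
            int end = (int) ((long) (f + 1) * n / k);
            int[] test = new int[end - start];
            int[] train = new int[n - test.length];
            int ti = 0, tr = 0;
            for (int p = 0; p < n; p++) {
                if (p >= start && p < end) test[ti++] = order.get(p);
                else train[tr++] = order.get(p);
            }
            folds[f][0] = train;
            folds[f][1] = test;
        }
        return folds;
    }

    public static void main(String[] args) {
        int[][][] folds = split(10, 5, 42L);
        for (int f = 0; f < folds.length; f++) {
            System.out.println("fold " + f + " test=" + Arrays.toString(folds[f][1]));
        }
    }
}
```

Each index lands in exactly one test fold, and fold sizes differ by at most one when `n` is not divisible by `k`.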

### Bootstrap Sampling

Bootstrap resampling methods for model validation and uncertainty estimation.

```java { .api }
/**
 * Bootstrap sampling for model validation
 */
class Bootstrap {
    /** Create bootstrap samples */
    public static Bag[] of(int n, int size, int subsampleSize);

    /** Bootstrap with replacement */
    public static Bag[] of(int n, int size);

    /** Bootstrap validation for classification */
    public static <T, M extends Classifier<T>> ClassificationValidations classification(
        int round, Classifier.Trainer<T, M> trainer, T[] x, int[] y);

    /** Bootstrap validation for regression */
    public static <T, M extends Regression<T>> RegressionValidations regression(
        int round, Regression.Trainer<T, M> trainer, T[] x, double[] y);
}
```
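
As a rough illustration of what one bootstrap round involves, the sketch below draws `n` indices with replacement and collects the indices that were never drawn as the out-of-bag (test) set. `BootstrapSketch` and `sample` are hypothetical names used only for this example; the code does not reproduce Smile's `Bootstrap` class.

```java
import java.util.Arrays;
import java.util.Random;

/** Hypothetical sketch of a single bootstrap round (not part of Smile). */
final class BootstrapSketch {
    /** Returns {samples, outOfBag}: n indices drawn with replacement, plus the indices never drawn. */
    static int[][] sample(int n, Random rng) {
        int[] samples = new int[n];
        boolean[] drawn = new boolean[n];
        for (int i = 0; i < n; i++) {
            samples[i] = rng.nextInt(n);
            drawn[samples[i]] = true;
        }
        // Count and collect the indices that never appeared in the sample.
        int oobCount = 0;
        for (boolean d : drawn) {
            if (!d) oobCount++;
        }
        int[] oob = new int[oobCount];
        for (int i = 0, j = 0; i < n; i++) {
            if (!drawn[i]) oob[j++] = i;
        }
        return new int[][] {samples, oob};
    }

    public static void main(String[] args) {
        int[][] bag = sample(20, new Random(7));
        System.out.println("samples: " + Arrays.toString(bag[0]));
        System.out.println("out-of-bag: " + Arrays.toString(bag[1]));
    }
}
```

For large `n`, about 36.8% of the rows (a fraction of roughly 1/e) are expected to be out-of-bag in each round, which is why the out-of-bag rows can serve as a test set without a separate hold-out split.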

### Leave-One-Out Cross-Validation

Specialized cross-validation in which each sample is used as the test set exactly once.

```java { .api }
/**
 * Leave-one-out cross-validation
 */
class LOOCV implements CrossValidation {
    /** Create LOOCV splits */
    public Bag[] split(int n);

    /** Classification LOOCV */
    public static <T> ClassificationValidation classification(
        Classifier.Trainer<T, ?> trainer, T[] x, int[] y);

    /** Regression LOOCV */
    public static <T> RegressionValidation regression(
        Regression.Trainer<T, ?> trainer, T[] x, double[] y);
}
```
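
Leave-one-out is k-fold cross-validation with k equal to the number of samples. A minimal sketch of the resulting training index sets follows; `LoocvSketch` and `trainIndices` are hypothetical names for illustration and are not Smile's API.

```java
import java.util.Arrays;

/** Hypothetical sketch of leave-one-out index sets (not part of Smile). */
final class LoocvSketch {
    /** Row `held` of the result lists every index except `held`, which is the single test sample. */
    static int[][] trainIndices(int n) {
        int[][] train = new int[n][n - 1];
        for (int held = 0; held < n; held++) {
            for (int i = 0, j = 0; i < n; i++) {
                if (i != held) train[held][j++] = i;
            }
        }
        return train;
    }

    public static void main(String[] args) {
        int[][] train = trainIndices(4);
        for (int held = 0; held < train.length; held++) {
            System.out.println("test=" + held + " train=" + Arrays.toString(train[held]));
        }
    }
}
```

Because the model is refit once per sample, the cost grows linearly with the data size, so this scheme is usually reserved for small data sets.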

### Model Selection

Utilities for comparing and selecting optimal models and hyperparameters.

```java { .api }
/**
 * Model selection utilities
 */
class ModelSelection {
    /** Grid search with cross-validation */
    public static <T, M> M gridSearch(
        BiFunction<T[], int[], M> trainer,
        T[] x, int[] y,
        CrossValidation cv,
        Map<String, Object[]> paramGrid);

    /** Random search with cross-validation */
    public static <T, M> M randomSearch(
        BiFunction<T[], int[], M> trainer,
        T[] x, int[] y,
        CrossValidation cv,
        Map<String, Distribution> paramDist,
        int nIter);

    /** Bayesian optimization for hyperparameter tuning */
    public static <T, M> M bayesianOptimization(
        BiFunction<T[], int[], M> trainer,
        T[] x, int[] y,
        CrossValidation cv,
        Map<String, Double[]> bounds,
        int nIter);
}
```
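
The idea behind grid search can be sketched independently of any library: enumerate every combination in the parameter grid, score each one (typically with a cross-validated metric), and keep the best. In the sketch below, `GridSearchSketch`, `search`, and the `cvScore` callback are hypothetical, the grid is restricted to numeric values, and the scoring function in `main` is a stand-in for a real cross-validation run.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.ToDoubleFunction;

/** Hypothetical sketch of exhaustive grid search (not part of Smile). */
final class GridSearchSketch {
    /** Returns the parameter combination with the highest score; every grid entry must be non-empty. */
    static Map<String, Double> search(Map<String, double[]> grid,
                                      ToDoubleFunction<Map<String, Double>> cvScore) {
        String[] names = grid.keySet().toArray(new String[0]);
        int[] idx = new int[names.length]; // current position in each parameter's value list
        Map<String, Double> best = null;
        double bestScore = Double.NEGATIVE_INFINITY;
        while (true) {
            Map<String, Double> candidate = new LinkedHashMap<>();
            for (int i = 0; i < names.length; i++) {
                candidate.put(names[i], grid.get(names[i])[idx[i]]);
            }
            double score = cvScore.applyAsDouble(candidate);
            if (score > bestScore) {
                bestScore = score;
                best = candidate;
            }
            // Advance the index vector like an odometer; stop once it wraps around.
            int pos = names.length - 1;
            while (pos >= 0 && ++idx[pos] == grid.get(names[pos]).length) {
                idx[pos] = 0;
                pos--;
            }
            if (pos < 0) break;
        }
        return best;
    }

    public static void main(String[] args) {
        Map<String, double[]> grid = new LinkedHashMap<>();
        grid.put("C", new double[] {0.1, 1.0, 10.0});
        grid.put("gamma", new double[] {0.01, 0.1});
        // Stand-in score that peaks at C = 1 and gamma = 0.1; replace with a cross-validated metric.
        Map<String, Double> best = search(grid,
            p -> -Math.pow(Math.log10(p.get("C")), 2) - Math.pow(p.get("gamma") - 0.1, 2));
        System.out.println("best parameters: " + best);
    }
}
```

The number of evaluations is the product of the grid sizes, which is the usual motivation for switching to random search or Bayesian optimization on larger grids.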

### Classification Validation Results

Classes for storing and analyzing classification validation results.

```java { .api }
/**
 * Single classification validation result
 */
class ClassificationValidation {
    /** Accuracy score */
    public final double accuracy;

    /** Error rate */
    public final double error;

    /** Confusion matrix */
    public final ConfusionMatrix confusion;

    /** Class-wise precision scores */
    public final double[] precision;

    /** Class-wise recall scores */
    public final double[] recall;

    /** Class-wise F1 scores */
    public final double[] f1;

    /** Matthews correlation coefficient */
    public final double mcc;
}

/**
 * Multiple classification validation results
 */
class ClassificationValidations {
    /** Individual fold results */
    public final ClassificationValidation[] rounds;

    /** Average validation metrics */
    public final ClassificationValidation avg;

    /** Standard deviation of metrics */
    public final ClassificationValidation std;

    /** Get confidence interval for metric */
    public double[] confidenceInterval(String metric, double confidence);
}

/**
 * Classification metrics container
 */
class ClassificationMetrics {
    /** Calculate all metrics from predictions */
    public static ClassificationMetrics of(int[] truth, int[] prediction);

    /** Get accuracy */
    public double getAccuracy();

    /** Get error rate */
    public double getError();

    /** Get macro-averaged F1 score */
    public double getMacroF1();

    /** Get weighted F1 score */
    public double getWeightedF1();
}
```

### Regression Validation Results

Classes for storing and analyzing regression validation results.

```java { .api }
/**
 * Single regression validation result
 */
class RegressionValidation {
    /** Root mean square error */
    public final double rmse;

    /** Mean absolute error */
    public final double mae;

    /** Mean absolute deviation */
    public final double mad;

    /** R-squared coefficient */
    public final double r2;

    /** Adjusted R-squared */
    public final double adjustedR2;

    /** Residual sum of squares */
    public final double rss;

    /** Total sum of squares */
    public final double tss;
}

/**
 * Multiple regression validation results
 */
class RegressionValidations {
    /** Individual fold results */
    public final RegressionValidation[] rounds;

    /** Average validation metrics */
    public final RegressionValidation avg;

    /** Standard deviation of metrics */
    public final RegressionValidation std;
}

/**
 * Regression metrics container
 */
class RegressionMetrics {
    /** Calculate all metrics from predictions */
    public static RegressionMetrics of(double[] truth, double[] prediction);

    /** Get RMSE */
    public double getRMSE();

    /** Get MAE */
    public double getMAE();

    /** Get R-squared */
    public double getR2();
}
```

### Classification Metrics

Individual classification metrics for detailed model evaluation.

```java { .api }
/**
 * Base classification metric interface
 */
interface ClassificationMetric {
    /** Calculate metric from true and predicted labels */
    double score(int[] truth, int[] prediction);
}

/**
 * Probabilistic classification metric interface
 */
interface ProbabilisticClassificationMetric {
    /** Calculate metric from true labels and predicted probabilities */
    double score(int[] truth, double[][] probability);
}

/**
 * Accuracy metric
 */
class Accuracy implements ClassificationMetric {
    /** Calculate accuracy */
    public static double of(int[] truth, int[] prediction);
}

/**
 * Error rate metric
 */
class Error implements ClassificationMetric {
    /** Calculate error rate */
    public static double of(int[] truth, int[] prediction);
}

/**
 * Precision metric
 */
class Precision implements ClassificationMetric {
    /** Calculate macro-averaged precision */
    public static double of(int[] truth, int[] prediction);

    /** Calculate class-specific precision */
    public static double[] byClass(int[] truth, int[] prediction);
}

/**
 * Recall (Sensitivity) metric
 */
class Recall implements ClassificationMetric {
    /** Calculate macro-averaged recall */
    public static double of(int[] truth, int[] prediction);

    /** Calculate class-specific recall */
    public static double[] byClass(int[] truth, int[] prediction);
}

/**
 * Specificity metric
 */
class Specificity implements ClassificationMetric {
    /** Calculate macro-averaged specificity */
    public static double of(int[] truth, int[] prediction);

    /** Calculate class-specific specificity */
    public static double[] byClass(int[] truth, int[] prediction);
}

/**
 * F-score metric
 */
class FScore implements ClassificationMetric {
    /** Calculate macro-averaged F1 score */
    public static double of(int[] truth, int[] prediction);

    /** Calculate F-beta score */
    public static double of(int[] truth, int[] prediction, double beta);

    /** Calculate class-specific F1 scores */
    public static double[] byClass(int[] truth, int[] prediction);
}

/**
 * Matthews Correlation Coefficient
 */
class MatthewsCorrelation implements ClassificationMetric {
    /** Calculate MCC */
    public static double of(int[] truth, int[] prediction);
}

/**
 * Area Under ROC Curve
 */
class AUC implements ProbabilisticClassificationMetric {
    /** Calculate AUC for binary classification */
    public static double of(int[] truth, double[] probability);

    /** Calculate multi-class AUC (one-vs-rest) */
    public static double of(int[] truth, double[][] probability);
}

/**
 * Cross Entropy loss
 */
class CrossEntropy implements ProbabilisticClassificationMetric {
    /** Calculate cross entropy loss */
    public static double of(int[] truth, double[][] probability);
}

/**
 * Logarithmic loss
 */
class LogLoss implements ProbabilisticClassificationMetric {
    /** Calculate log loss */
    public static double of(int[] truth, double[][] probability);
}

/**
 * Confusion Matrix
 */
class ConfusionMatrix {
    /** Create confusion matrix */
    public static ConfusionMatrix of(int[] truth, int[] prediction);

    /** The confusion matrix */
    public final int[][] matrix;

    /** Number of classes */
    public final int classes;

    /** Get accuracy from confusion matrix */
    public double accuracy();

    /** Get error rate */
    public double error();

    /** Get class-specific precision */
    public double[] precision();

    /** Get class-specific recall */
    public double[] recall();
}

/**
 * False Discovery Rate
 */
class FDR implements ClassificationMetric {
    /** Calculate false discovery rate */
    public static double of(int[] truth, int[] prediction);
}

/**
 * Fallout (False Positive Rate)
 */
class Fallout implements ClassificationMetric {
    /** Calculate fallout */
    public static double of(int[] truth, int[] prediction);
}
```
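
Most of the label-based metrics listed above can be derived from a confusion matrix. The following plain-Java sketch shows the standard definitions of accuracy and per-class precision, recall, and F1. It assumes rows index the true class and columns the predicted class, and it returns 0 when a denominator is zero; both are choices made for this example, and `ClassificationMetricsSketch` is a hypothetical class, not Smile code.

```java
import java.util.Arrays;

/** Hypothetical sketch of confusion-matrix metrics (not part of Smile). */
final class ClassificationMetricsSketch {
    /** Builds a k x k matrix with rows = true class, columns = predicted class. */
    static int[][] confusion(int[] truth, int[] prediction, int k) {
        int[][] m = new int[k][k];
        for (int i = 0; i < truth.length; i++) m[truth[i]][prediction[i]]++;
        return m;
    }

    /** Fraction of samples on the diagonal. */
    static double accuracy(int[][] m) {
        int correct = 0, total = 0;
        for (int r = 0; r < m.length; r++) {
            for (int c = 0; c < m.length; c++) {
                total += m[r][c];
                if (r == c) correct += m[r][c];
            }
        }
        return (double) correct / total;
    }

    /** Per-class precision: true positives divided by the column sum. */
    static double[] precision(int[][] m) {
        double[] p = new double[m.length];
        for (int c = 0; c < m.length; c++) {
            int predicted = 0;
            for (int r = 0; r < m.length; r++) predicted += m[r][c];
            p[c] = predicted == 0 ? 0.0 : (double) m[c][c] / predicted;
        }
        return p;
    }

    /** Per-class recall: true positives divided by the row sum. */
    static double[] recall(int[][] m) {
        double[] rec = new double[m.length];
        for (int r = 0; r < m.length; r++) {
            int actual = 0;
            for (int c = 0; c < m.length; c++) actual += m[r][c];
            rec[r] = actual == 0 ? 0.0 : (double) m[r][r] / actual;
        }
        return rec;
    }

    /** Per-class F1: harmonic mean of precision and recall. */
    static double[] f1(double[] p, double[] r) {
        double[] f = new double[p.length];
        for (int c = 0; c < p.length; c++) {
            f[c] = p[c] + r[c] == 0.0 ? 0.0 : 2 * p[c] * r[c] / (p[c] + r[c]);
        }
        return f;
    }

    public static void main(String[] args) {
        int[] truth = {0, 0, 1, 1, 1, 2};
        int[] pred = {0, 1, 1, 1, 0, 2};
        int[][] m = confusion(truth, pred, 3);
        System.out.println("accuracy  = " + accuracy(m));
        System.out.println("precision = " + Arrays.toString(precision(m)));
        System.out.println("recall    = " + Arrays.toString(recall(m)));
    }
}
```

Macro-averaged variants are the unweighted mean of the per-class arrays; weighted variants scale each class by its row sum.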

### Regression Metrics

Individual regression metrics for model evaluation.

```java { .api }
/**
 * Base regression metric interface
 */
interface RegressionMetric {
    /** Calculate metric from true and predicted values */
    double score(double[] truth, double[] prediction);
}

/**
 * Mean Squared Error
 */
class MSE implements RegressionMetric {
    /** Calculate MSE */
    public static double of(double[] truth, double[] prediction);
}

/**
 * Root Mean Squared Error
 */
class RMSE implements RegressionMetric {
    /** Calculate RMSE */
    public static double of(double[] truth, double[] prediction);
}

/**
 * Mean Absolute Error
 */
class MAE implements RegressionMetric {
    /** Calculate MAE */
    public static double of(double[] truth, double[] prediction);
}

/**
 * Mean Absolute Deviation
 */
class MAD implements RegressionMetric {
    /** Calculate MAD */
    public static double of(double[] truth, double[] prediction);
}

/**
 * Residual Sum of Squares
 */
class RSS implements RegressionMetric {
    /** Calculate RSS */
    public static double of(double[] truth, double[] prediction);
}

/**
 * R-squared coefficient of determination
 */
class R2 implements RegressionMetric {
    /** Calculate R-squared */
    public static double of(double[] truth, double[] prediction);

    /** Calculate adjusted R-squared */
    public static double adjusted(double[] truth, double[] prediction, int p);
}
```
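
For reference, the standard formulas behind these metrics are short enough to write out directly. The sketch below uses the textbook definitions: R-squared as 1 - RSS/TSS and adjusted R-squared as 1 - (1 - R²)(n - 1)/(n - p - 1), where `p` is the number of predictors. `RegressionMetricsSketch` is a hypothetical class for illustration, and the example does not claim to match Smile's numerical details.

```java
/** Hypothetical sketch of common regression metrics (not part of Smile). */
final class RegressionMetricsSketch {
    /** Residual sum of squares. */
    static double rss(double[] truth, double[] prediction) {
        double sum = 0.0;
        for (int i = 0; i < truth.length; i++) {
            double e = truth[i] - prediction[i];
            sum += e * e;
        }
        return sum;
    }

    static double mse(double[] truth, double[] prediction) {
        return rss(truth, prediction) / truth.length;
    }

    static double rmse(double[] truth, double[] prediction) {
        return Math.sqrt(mse(truth, prediction));
    }

    static double mae(double[] truth, double[] prediction) {
        double sum = 0.0;
        for (int i = 0; i < truth.length; i++) sum += Math.abs(truth[i] - prediction[i]);
        return sum / truth.length;
    }

    /** R-squared: 1 - RSS / TSS, with TSS measured around the mean of the true values. */
    static double r2(double[] truth, double[] prediction) {
        double mean = 0.0;
        for (double t : truth) mean += t;
        mean /= truth.length;
        double tss = 0.0;
        for (double t : truth) tss += (t - mean) * (t - mean);
        return 1.0 - rss(truth, prediction) / tss;
    }

    /** Adjusted R-squared for a model with p predictors; requires n > p + 1. */
    static double adjustedR2(double[] truth, double[] prediction, int p) {
        int n = truth.length;
        return 1.0 - (1.0 - r2(truth, prediction)) * (n - 1) / (n - p - 1);
    }

    public static void main(String[] args) {
        double[] truth = {3.0, -0.5, 2.0, 7.0};
        double[] pred = {2.5, 0.0, 2.0, 8.0};
        System.out.println("MSE  = " + mse(truth, pred));
        System.out.println("RMSE = " + rmse(truth, pred));
        System.out.println("MAE  = " + mae(truth, pred));
        System.out.println("R2   = " + r2(truth, pred));
    }
}
```

RMSE is in the same units as the target, which often makes it easier to interpret than MSE, while R-squared is unitless and compares the model against always predicting the mean.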

### Clustering Metrics

Metrics for evaluating clustering quality and comparing clustering results.

```java { .api }
/**
 * Base clustering metric interface
 */
interface ClusteringMetric {
    /** Calculate metric from true and predicted cluster labels */
    double score(int[] truth, int[] prediction);
}

/**
 * Rand Index for clustering comparison
 */
class RandIndex implements ClusteringMetric {
    /** Calculate Rand Index */
    public static double of(int[] truth, int[] prediction);
}

/**
 * Adjusted Rand Index
 */
class AdjustedRandIndex implements ClusteringMetric {
    /** Calculate Adjusted Rand Index */
    public static double of(int[] truth, int[] prediction);
}

/**
 * Mutual Information between clusterings
 */
class MutualInformation implements ClusteringMetric {
    /** Calculate mutual information */
    public static double of(int[] truth, int[] prediction);
}

/**
 * Normalized Mutual Information
 */
class NormalizedMutualInformation implements ClusteringMetric {
    /** Normalization methods */
    enum Method { ARITHMETIC, GEOMETRIC, MAX, MIN }

    /** Calculate NMI with arithmetic mean normalization */
    public static double of(int[] truth, int[] prediction);

    /** Calculate NMI with specified normalization */
    public static double of(int[] truth, int[] prediction, Method method);
}

/**
 * Adjusted Mutual Information
 */
class AdjustedMutualInformation implements ClusteringMetric {
    /** Adjustment methods */
    enum Method { ARITHMETIC, GEOMETRIC, MAX, MIN }

    /** Calculate AMI with arithmetic mean adjustment */
    public static double of(int[] truth, int[] prediction);

    /** Calculate AMI with specified adjustment */
    public static double of(int[] truth, int[] prediction, Method method);
}

/**
 * Contingency table for clustering evaluation
 */
class ContingencyTable {
    /** Create contingency table */
    public static ContingencyTable of(int[] truth, int[] prediction);

    /** The contingency table matrix */
    public final int[][] table;

    /** Number of true clusters */
    public final int n;

    /** Number of predicted clusters */
    public final int m;

    /** Calculate mutual information */
    public double mutualInformation();

    /** Calculate entropy of true clustering */
    public double entropyX();

    /** Calculate entropy of predicted clustering */
    public double entropyY();
}
```
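
Clustering metrics compare partitions rather than label values, so they must be insensitive to how clusters are numbered. The Rand index illustrates this: it is the fraction of sample pairs on which two clusterings agree, where agreement means the pair is grouped together in both or separated in both. The sketch below is a direct O(n²) pair count for illustration; `RandIndexSketch` is a hypothetical class and not Smile's implementation.

```java
/** Hypothetical sketch of the Rand index by pair counting (not part of Smile). */
final class RandIndexSketch {
    /** Fraction of pairs (i, j), i < j, on which both clusterings agree. Requires at least two samples. */
    static double randIndex(int[] truth, int[] prediction) {
        int n = truth.length;
        long agree = 0;
        for (int i = 0; i < n; i++) {
            for (int j = i + 1; j < n; j++) {
                boolean sameTruth = truth[i] == truth[j];
                boolean samePred = prediction[i] == prediction[j];
                if (sameTruth == samePred) agree++;
            }
        }
        long pairs = (long) n * (n - 1) / 2;
        return (double) agree / pairs;
    }

    public static void main(String[] args) {
        int[] truth = {0, 0, 1, 1};
        // Same partition with the cluster ids swapped: perfect agreement.
        System.out.println(randIndex(truth, new int[] {1, 1, 0, 0}));
        // One cluster split in two: 5 of the 6 pairs agree.
        System.out.println(randIndex(truth, new int[] {0, 0, 1, 2}));
    }
}
```

The adjusted Rand index rescales this value so that random labelings score near zero, which makes it the safer choice when comparing clusterings with different numbers of clusters.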

**Comprehensive Usage Example:**

```java
import java.util.Arrays;

import smile.classification.LogisticRegression;
import smile.validation.*;
import smile.validation.metric.*;

// Complete validation pipeline.
// LogisticRegression stands in for RandomForest here because this example trains
// on arrays, whereas RandomForest.fit is used above with a Formula and DataFrame.
public class ModelValidation {
    public void validateModel(double[][] features, int[] labels) {
        // 1. 10-fold cross-validation, using the static signature listed above
        var cvResults = CrossValidation.classification(10, LogisticRegression::fit, features, labels);

        System.out.println("CV Accuracy: " + cvResults.avg.accuracy + " ± " + cvResults.std.accuracy);
        System.out.println("CV F1 Score (class 0): " + cvResults.avg.f1[0]);

        // 2. Bootstrap validation with 100 rounds
        var bootstrapResults = Bootstrap.classification(100, LogisticRegression::fit, features, labels);
        System.out.println("Bootstrap Accuracy: " + bootstrapResults.avg.accuracy);

        // 3. Detailed metrics analysis.
        // The model is scored on its own training data here, so these figures
        // are optimistic compared with the estimates from steps 1 and 2.
        LogisticRegression model = LogisticRegression.fit(features, labels);
        int[] predictions = Arrays.stream(features).mapToInt(model::predict).toArray();

        // Classification metrics
        double accuracy = Accuracy.of(labels, predictions);
        double[] precision = Precision.byClass(labels, predictions);
        double[] recall = Recall.byClass(labels, predictions);
        double[] f1 = FScore.byClass(labels, predictions);
        double mcc = MatthewsCorrelation.of(labels, predictions);

        // Confusion matrix
        ConfusionMatrix cm = ConfusionMatrix.of(labels, predictions);
        System.out.println("Confusion Matrix:");
        for (int[] row : cm.matrix) {
            System.out.println(Arrays.toString(row));
        }

        // 4. Confidence interval for the cross-validated accuracy
        double[] ci = cvResults.confidenceInterval("accuracy", 0.95);
        System.out.println("95% CI for accuracy: [" + ci[0] + ", " + ci[1] + "]");
    }
}
```

### Common Validation Patterns

Standard patterns for model validation in Smile Core:

**Basic Cross-Validation:**
```java
var results = CrossValidation.classification(5, trainer, features, labels);
```

**Stratified Cross-Validation:**
```java
// Stratified splits; each Bag holds training (samples) and test (oob) indices
Bag[] folds = CrossValidation.stratify(features.length, 10, labels);
```

**Time Series Validation:**
```java
// Note: timeSeries is not among the CrossValidation methods listed above
CrossValidation cv = CrossValidation.timeSeries(5);
var results = cv.regression(trainer, features, targets);
```
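
Time-ordered data should not be shuffled, because a model must never be trained on observations that come after its test period. One common scheme is expanding-window (forward-chaining) validation, sketched below in plain Java. The sketch is an assumption about what a time-series splitter would do, not a description of Smile's behaviour; `TimeSeriesSplitSketch` and `split` are hypothetical names.

```java
import java.util.Arrays;

/** Hypothetical sketch of expanding-window time-series splits (not part of Smile). */
final class TimeSeriesSplitSketch {
    /**
     * Cuts indices 0..n-1 into k + 1 contiguous blocks. Split i is returned as
     * {trainEnd, testEnd}: train on [0, trainEnd), test on [trainEnd, testEnd).
     */
    static int[][] split(int n, int k) {
        if (k < 1 || k + 1 > n) {
            throw new IllegalArgumentException("need 1 <= k < n");
        }
        int[][] splits = new int[k][2];
        for (int i = 0; i < k; i++) {
            splits[i][0] = (int) ((long) (i + 1) * n / (k + 1));
            splits[i][1] = (int) ((long) (i + 2) * n / (k + 1));
        }
        return splits;
    }

    public static void main(String[] args) {
        for (int[] s : split(12, 5)) {
            System.out.println("train [0, " + s[0] + "), test [" + s[0] + ", " + s[1] + ")");
        }
    }
}
```

Every test block lies strictly after its training window, and the training window grows by one block per split.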

**Bootstrap Validation:**
```java
var results = Bootstrap.classification(100, trainer, features, labels);
```

### Performance Analysis

Validation results support several kinds of performance analysis:

- **Point Estimates**: Mean performance across folds or bootstrap samples
- **Variability**: Standard deviation of each performance metric
- **Confidence Intervals**: Statistical bounds on performance estimates
- **Per-Class Metrics**: Per-class breakdown for multi-class problems
- **Confusion Analysis**: Error analysis through confusion matrices
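
The first three items can be computed from the per-fold scores alone. The sketch below derives the mean, the sample standard deviation, and a normal-approximation confidence interval (mean ± z · sd / √k). It is an illustrative calculation under stated assumptions, not the method behind `confidenceInterval`: `ScoreSummarySketch` is a hypothetical class, the fold scores are example values, and with only a handful of correlated folds the resulting interval should be read as a rough guide.

```java
/** Hypothetical sketch of summarising per-fold scores (not part of Smile). */
final class ScoreSummarySketch {
    static double mean(double[] scores) {
        double sum = 0.0;
        for (double s : scores) sum += s;
        return sum / scores.length;
    }

    /** Sample standard deviation (divides by k - 1); requires at least two scores. */
    static double sd(double[] scores) {
        double m = mean(scores);
        double ss = 0.0;
        for (double s : scores) ss += (s - m) * (s - m);
        return Math.sqrt(ss / (scores.length - 1));
    }

    /** Normal-approximation interval; z = 1.96 corresponds to roughly 95% coverage. */
    static double[] confidenceInterval(double[] scores, double z) {
        double m = mean(scores);
        double half = z * sd(scores) / Math.sqrt(scores.length);
        return new double[] {m - half, m + half};
    }

    public static void main(String[] args) {
        double[] foldAccuracy = {0.90, 0.92, 0.88, 0.91, 0.89}; // example values only
        double[] ci = confidenceInterval(foldAccuracy, 1.96);
        System.out.println("mean = " + mean(foldAccuracy) + ", sd = " + sd(foldAccuracy));
        System.out.println("approx. 95% CI = [" + ci[0] + ", " + ci[1] + "]");
    }
}
```

With few folds, a Student's t quantile in place of `z` gives a wider and more conservative interval.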