or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

advanced-analytics.md · classification.md · clustering.md · deep-learning.md · feature-engineering.md · index.md · regression.md · validation-metrics.md

docs/validation-metrics.md

0

# Validation and Metrics

1

2

Smile Core provides a comprehensive model validation framework: cross-validation, bootstrap sampling, and an extensive set of performance metrics for classification, regression, and clustering tasks. Together these give you robust tools for evaluating and comparing models.

3

4

## Capabilities

5

6

### Cross-Validation Framework

7

8

Core framework for model validation using various cross-validation strategies.

9

10

```java { .api }

11

/**

12

* Cross-validation utilities for model evaluation

13

*/

14

class CrossValidation {

15

/** Create k-fold cross-validation splits */

16

public static Bag[] of(int n, int k);

17

18

/** Create stratified k-fold cross-validation splits */

19

public static Bag[] stratify(int n, int k, int[] y);

20

21

/** Classification cross-validation */

22

public static <T, M extends Classifier<T>> ClassificationValidations classification(

23

int k, Classifier.Trainer<T, M> trainer, T[] x, int[] y);

24

25

/** Classification cross-validation with Formula */

26

public static ClassificationValidations classification(

27

int k, BiFunction<Formula, DataFrame, Classifier<Tuple>> trainer, Formula formula, DataFrame data);

28

29

/** Regression cross-validation */

30

public static <T, M extends Regression<T>> RegressionValidations regression(

31

int k, Regression.Trainer<T, M> trainer, T[] x, double[] y);

32

33

/** Regression cross-validation with Formula */

34

public static RegressionValidations regression(

35

int k, BiFunction<Formula, DataFrame, Regression<Tuple>> trainer, Formula formula, DataFrame data);

36

}

37

38

/**

39

* Data bag containing training and test indices

40

*/

41

class Bag {

42

/** Training sample indices */

43

public final int[] samples;

44

45

/** Out-of-bag (test) sample indices */

46

public final int[] oob;

47

48

/** Get training data subset */

49

public <T> T[] trainSet(T[] data);

50

51

/** Get test data subset */

52

public <T> T[] testSet(T[] data);

53

}

54

```

55

56

**Usage Example:**

57

58

```java

59

import smile.validation.CrossValidation;

60

import smile.classification.RandomForest;

61

import smile.data.formula.Formula;

62

63

// 10-fold cross-validation with DataFrame

64

Formula formula = Formula.lhs("target");

65

var results = CrossValidation.classification(10, RandomForest::fit, formula, data);

66

67

System.out.println("Accuracy: " + results.avg.accuracy);

68

System.out.println("Error: " + results.avg.error);

69

70

// Stratified cross-validation splits

71

var stratifiedSplits = CrossValidation.stratify(data.size(), 5, labels);

72

```

73

74

### Bootstrap Sampling

75

76

Bootstrap resampling methods for model validation and uncertainty estimation.

77

78

```java { .api }

79

/**

80

* Bootstrap sampling for model validation

81

*/

82

class Bootstrap {

83

/** Create bootstrap samples */

84

public static Bag[] of(int n, int size, int subsampleSize);

85

86

/** Bootstrap with replacement */

87

public static Bag[] of(int n, int size);

88

89

/** Bootstrap validation for classification */

90

public static <T, M extends Classifier<T>> ClassificationValidations classification(

91

int round, Classifier.Trainer<T, M> trainer, T[] x, int[] y);

92

93

/** Bootstrap validation for regression */

94

public static <T, M extends Regression<T>> RegressionValidations regression(

95

int round, Regression.Trainer<T, M> trainer, T[] x, double[] y);

96

}

97

```

98

99

### Leave-One-Out Cross-Validation

100

101

A special case of cross-validation in which each sample is used as the test set exactly once, with the model trained on all remaining samples.

102

103

```java { .api }

104

/**

105

* Leave-one-out cross-validation

106

*/

107

class LOOCV implements CrossValidation {

108

/** Create LOOCV splits */

109

public Bag[] split(int n);

110

111

/** Classification LOOCV */

112

public static <T> ClassificationValidation classification(

113

Classifier.Trainer<T, ?> trainer, T[] x, int[] y);

114

115

/** Regression LOOCV */

116

public static <T> RegressionValidation regression(

117

Regression.Trainer<T, ?> trainer, T[] x, double[] y);

118

}

119

```

120

121

### Model Selection

122

123

Utilities for comparing and selecting optimal models and hyperparameters.

124

125

```java { .api }

126

/**

127

* Model selection utilities

128

*/

129

class ModelSelection {

130

/** Grid search with cross-validation */

131

public static <T, M> M gridSearch(

132

BiFunction<T[], int[], M> trainer,

133

T[] x, int[] y,

134

CrossValidation cv,

135

Map<String, Object[]> paramGrid);

136

137

/** Random search with cross-validation */

138

public static <T, M> M randomSearch(

139

BiFunction<T[], int[], M> trainer,

140

T[] x, int[] y,

141

CrossValidation cv,

142

Map<String, Distribution> paramDist,

143

int nIter);

144

145

/** Bayesian optimization for hyperparameter tuning */

146

public static <T, M> M bayesianOptimization(

147

BiFunction<T[], int[], M> trainer,

148

T[] x, int[] y,

149

CrossValidation cv,

150

Map<String, Double[]> bounds,

151

int nIter);

152

}

153

```

154

155

### Classification Validation Results

156

157

Classes for storing and analyzing classification validation results.

158

159

```java { .api }

160

/**

161

* Single classification validation result

162

*/

163

class ClassificationValidation {

164

/** Accuracy score */

165

public final double accuracy;

166

167

/** Error rate */

168

public final double error;

169

170

/** Confusion matrix */

171

public final ConfusionMatrix confusion;

172

173

/** Class-wise precision scores */

174

public final double[] precision;

175

176

/** Class-wise recall scores */

177

public final double[] recall;

178

179

/** Class-wise F1 scores */

180

public final double[] f1;

181

182

/** Matthews correlation coefficient */

183

public final double mcc;

184

}

185

186

/**

187

* Multiple classification validation results

188

*/

189

class ClassificationValidations {

190

/** Individual fold results */

191

public final ClassificationValidation[] rounds;

192

193

/** Average validation metrics */

194

public final ClassificationValidation avg;

195

196

/** Standard deviation of metrics */

197

public final ClassificationValidation std;

198

199

/** Get confidence interval for metric */

200

public double[] confidenceInterval(String metric, double confidence);

201

}

202

203

/**

204

* Classification metrics container

205

*/

206

class ClassificationMetrics {

207

/** Calculate all metrics from predictions */

208

public static ClassificationMetrics of(int[] truth, int[] prediction);

209

210

/** Get accuracy */

211

public double getAccuracy();

212

213

/** Get error rate */

214

public double getError();

215

216

/** Get macro-averaged F1 score */

217

public double getMacroF1();

218

219

/** Get weighted F1 score */

220

public double getWeightedF1();

221

}

222

```

223

224

### Regression Validation Results

225

226

Classes for storing and analyzing regression validation results.

227

228

```java { .api }

229

/**

230

* Single regression validation result

231

*/

232

class RegressionValidation {

233

/** Root mean square error */

234

public final double rmse;

235

236

/** Mean absolute error */

237

public final double mae;

238

239

/** Mean absolute deviation */

240

public final double mad;

241

242

/** R-squared coefficient */

243

public final double r2;

244

245

/** Adjusted R-squared */

246

public final double adjustedR2;

247

248

/** Residual sum of squares */

249

public final double rss;

250

251

/** Total sum of squares */

252

public final double tss;

253

}

254

255

/**

256

* Multiple regression validation results

257

*/

258

class RegressionValidations {

259

/** Individual fold results */

260

public final RegressionValidation[] rounds;

261

262

/** Average validation metrics */

263

public final RegressionValidation avg;

264

265

/** Standard deviation of metrics */

266

public final RegressionValidation std;

267

}

268

269

/**

270

* Regression metrics container

271

*/

272

class RegressionMetrics {

273

/** Calculate all metrics from predictions */

274

public static RegressionMetrics of(double[] truth, double[] prediction);

275

276

/** Get RMSE */

277

public double getRMSE();

278

279

/** Get MAE */

280

public double getMAE();

281

282

/** Get R-squared */

283

public double getR2();

284

}

285

```

286

287

### Classification Metrics

288

289

Individual classification metrics for detailed model evaluation.

290

291

```java { .api }

292

/**

293

* Base classification metric interface

294

*/

295

interface ClassificationMetric {

296

/** Calculate metric from true and predicted labels */

297

double score(int[] truth, int[] prediction);

298

}

299

300

/**

301

* Probabilistic classification metric interface

302

*/

303

interface ProbabilisticClassificationMetric {

304

/** Calculate metric from true labels and predicted probabilities */

305

double score(int[] truth, double[][] probability);

306

}

307

308

/**

309

* Accuracy metric

310

*/

311

class Accuracy implements ClassificationMetric {

312

/** Calculate accuracy */

313

public static double of(int[] truth, int[] prediction);

314

}

315

316

/**

317

* Error rate metric

318

*/

319

class Error implements ClassificationMetric {

320

/** Calculate error rate */

321

public static double of(int[] truth, int[] prediction);

322

}

323

324

/**

325

* Precision metric

326

*/

327

class Precision implements ClassificationMetric {

328

/** Calculate macro-averaged precision */

329

public static double of(int[] truth, int[] prediction);

330

331

/** Calculate class-specific precision */

332

public static double[] byClass(int[] truth, int[] prediction);

333

}

334

335

/**

336

* Recall (Sensitivity) metric

337

*/

338

class Recall implements ClassificationMetric {

339

/** Calculate macro-averaged recall */

340

public static double of(int[] truth, int[] prediction);

341

342

/** Calculate class-specific recall */

343

public static double[] byClass(int[] truth, int[] prediction);

344

}

345

346

/**

347

* Specificity metric

348

*/

349

class Specificity implements ClassificationMetric {

350

/** Calculate macro-averaged specificity */

351

public static double of(int[] truth, int[] prediction);

352

353

/** Calculate class-specific specificity */

354

public static double[] byClass(int[] truth, int[] prediction);

355

}

356

357

/**

358

* F-score metric

359

*/

360

class FScore implements ClassificationMetric {

361

/** Calculate macro-averaged F1 score */

362

public static double of(int[] truth, int[] prediction);

363

364

/** Calculate F-beta score */

365

public static double of(int[] truth, int[] prediction, double beta);

366

367

/** Calculate class-specific F1 scores */

368

public static double[] byClass(int[] truth, int[] prediction);

369

}

370

371

/**

372

* Matthews Correlation Coefficient

373

*/

374

class MatthewsCorrelation implements ClassificationMetric {

375

/** Calculate MCC */

376

public static double of(int[] truth, int[] prediction);

377

}

378

379

/**

380

* Area Under ROC Curve

381

*/

382

class AUC implements ProbabilisticClassificationMetric {

383

/** Calculate AUC for binary classification */

384

public static double of(int[] truth, double[] probability);

385

386

/** Calculate multi-class AUC (one-vs-rest) */

387

public static double of(int[] truth, double[][] probability);

388

}

389

390

/**

391

* Cross Entropy loss

392

*/

393

class CrossEntropy implements ProbabilisticClassificationMetric {

394

/** Calculate cross entropy loss */

395

public static double of(int[] truth, double[][] probability);

396

}

397

398

/**

399

* Logarithmic loss

400

*/

401

class LogLoss implements ProbabilisticClassificationMetric {

402

/** Calculate log loss */

403

public static double of(int[] truth, double[][] probability);

404

}

405

406

/**

407

* Confusion Matrix

408

*/

409

class ConfusionMatrix {

410

/** Create confusion matrix */

411

public static ConfusionMatrix of(int[] truth, int[] prediction);

412

413

/** The confusion matrix */

414

public final int[][] matrix;

415

416

/** Number of classes */

417

public final int classes;

418

419

/** Get accuracy from confusion matrix */

420

public double accuracy();

421

422

/** Get error rate */

423

public double error();

424

425

/** Get class-specific precision */

426

public double[] precision();

427

428

/** Get class-specific recall */

429

public double[] recall();

430

}

431

432

/**

433

* False Discovery Rate

434

*/

435

class FDR implements ClassificationMetric {

436

/** Calculate false discovery rate */

437

public static double of(int[] truth, int[] prediction);

438

}

439

440

/**

441

* Fallout (False Positive Rate)

442

*/

443

class Fallout implements ClassificationMetric {

444

/** Calculate fallout */

445

public static double of(int[] truth, int[] prediction);

446

}

447

```

448

449

### Regression Metrics

450

451

Individual regression metrics for model evaluation.

452

453

```java { .api }

454

/**

455

* Base regression metric interface

456

*/

457

interface RegressionMetric {

458

/** Calculate metric from true and predicted values */

459

double score(double[] truth, double[] prediction);

460

}

461

462

/**

463

* Mean Squared Error

464

*/

465

class MSE implements RegressionMetric {

466

/** Calculate MSE */

467

public static double of(double[] truth, double[] prediction);

468

}

469

470

/**

471

* Root Mean Squared Error

472

*/

473

class RMSE implements RegressionMetric {

474

/** Calculate RMSE */

475

public static double of(double[] truth, double[] prediction);

476

}

477

478

/**

479

* Mean Absolute Error

480

*/

481

class MAE implements RegressionMetric {

482

/** Calculate MAE */

483

public static double of(double[] truth, double[] prediction);

484

}

485

486

/**

487

* Mean Absolute Deviation

488

*/

489

class MAD implements RegressionMetric {

490

/** Calculate MAD */

491

public static double of(double[] truth, double[] prediction);

492

}

493

494

/**

495

* Residual Sum of Squares

496

*/

497

class RSS implements RegressionMetric {

498

/** Calculate RSS */

499

public static double of(double[] truth, double[] prediction);

500

}

501

502

/**

503

* R-squared coefficient of determination

504

*/

505

class R2 implements RegressionMetric {

506

/** Calculate R-squared */

507

public static double of(double[] truth, double[] prediction);

508

509

/** Calculate adjusted R-squared */

510

public static double adjusted(double[] truth, double[] prediction, int p);

511

}

512

```

513

514

### Clustering Metrics

515

516

Metrics for evaluating clustering quality and comparing clustering results.

517

518

```java { .api }

519

/**

520

* Base clustering metric interface

521

*/

522

interface ClusteringMetric {

523

/** Calculate metric from true and predicted cluster labels */

524

double score(int[] truth, int[] prediction);

525

}

526

527

/**

528

* Rand Index for clustering comparison

529

*/

530

class RandIndex implements ClusteringMetric {

531

/** Calculate Rand Index */

532

public static double of(int[] truth, int[] prediction);

533

}

534

535

/**

536

* Adjusted Rand Index

537

*/

538

class AdjustedRandIndex implements ClusteringMetric {

539

/** Calculate Adjusted Rand Index */

540

public static double of(int[] truth, int[] prediction);

541

}

542

543

/**

544

* Mutual Information between clusterings

545

*/

546

class MutualInformation implements ClusteringMetric {

547

/** Calculate mutual information */

548

public static double of(int[] truth, int[] prediction);

549

}

550

551

/**

552

* Normalized Mutual Information

553

*/

554

class NormalizedMutualInformation implements ClusteringMetric {

555

/** Normalization methods */

556

enum Method { ARITHMETIC, GEOMETRIC, MAX, MIN }

557

558

/** Calculate NMI with arithmetic mean normalization */

559

public static double of(int[] truth, int[] prediction);

560

561

/** Calculate NMI with specified normalization */

562

public static double of(int[] truth, int[] prediction, Method method);

563

}

564

565

/**

566

* Adjusted Mutual Information

567

*/

568

class AdjustedMutualInformation implements ClusteringMetric {

569

/** Adjustment methods */

570

enum Method { ARITHMETIC, GEOMETRIC, MAX, MIN }

571

572

/** Calculate AMI with arithmetic mean adjustment */

573

public static double of(int[] truth, int[] prediction);

574

575

/** Calculate AMI with specified adjustment */

576

public static double of(int[] truth, int[] prediction, Method method);

577

}

578

579

/**

580

* Contingency table for clustering evaluation

581

*/

582

class ContingencyTable {

583

/** Create contingency table */

584

public static ContingencyTable of(int[] truth, int[] prediction);

585

586

/** The contingency table matrix */

587

public final int[][] table;

588

589

/** Number of true clusters */

590

public final int n;

591

592

/** Number of predicted clusters */

593

public final int m;

594

595

/** Calculate mutual information */

596

public double mutualInformation();

597

598

/** Calculate entropy of true clustering */

599

public double entropyX();

600

601

/** Calculate entropy of predicted clustering */

602

public double entropyY();

603

}

604

```

605

606

**Comprehensive Usage Example:**

607

608

```java

609

import smile.validation.*;

610

import smile.validation.metric.*;

611

import smile.classification.RandomForest;

612

613

// Complete validation pipeline

614

public class ModelValidation {

615

public void validateModel(double[][] features, int[] labels) {

616

// 1. Cross-validation

617

CrossValidation cv = CrossValidation.stratify(10, labels);

618

var cvResults = cv.classification(RandomForest::fit, features, labels);

619

620

System.out.println("CV Accuracy: " + cvResults.avg.accuracy + " ± " + cvResults.std.accuracy);

621

System.out.println("CV F1 Score: " + cvResults.avg.f1[0]);

622

623

// 2. Bootstrap validation

624

var bootstrapResults = Bootstrap.classification(100, RandomForest::fit, features, labels);

625

System.out.println("Bootstrap Accuracy: " + bootstrapResults.avg.accuracy);

626

627

// 3. Detailed metrics analysis

628

RandomForest model = RandomForest.fit(features, labels);

629

int[] predictions = Arrays.stream(features).mapToInt(model::predict).toArray();

630

631

// Classification metrics

632

double accuracy = Accuracy.of(labels, predictions);

633

double[] precision = Precision.byClass(labels, predictions);

634

double[] recall = Recall.byClass(labels, predictions);

635

double[] f1 = FScore.byClass(labels, predictions);

636

double mcc = MatthewsCorrelation.of(labels, predictions);

637

638

// Confusion matrix

639

ConfusionMatrix cm = ConfusionMatrix.of(labels, predictions);

640

System.out.println("Confusion Matrix:");

641

for (int[] row : cm.matrix) {

642

System.out.println(Arrays.toString(row));

643

}

644

645

// 4. Confidence interval for the cross-validated accuracy

646

double[] ci = cvResults.confidenceInterval("accuracy", 0.95);

647

System.out.println("95% CI for accuracy: [" + ci[0] + ", " + ci[1] + "]");

648

}

649

}

650

```

651

652

### Common Validation Patterns

653

654

Standard patterns for model validation in Smile Core:

655

656

**Basic Cross-Validation:**

657

```java

658

CrossValidation cv = CrossValidation.of(5);

659

var results = cv.classification(trainer, features, labels);

660

```

661

662

**Stratified Cross-Validation:**

663

```java

664

CrossValidation cv = CrossValidation.stratify(10, labels);

665

var results = cv.classification(trainer, features, labels);

666

```

667

668

**Time Series Validation:**

669

```java

670

CrossValidation cv = CrossValidation.timeSeries(5);

671

var results = cv.regression(trainer, features, targets);

672

```

673

674

**Bootstrap Validation:**

675

```java

676

var results = Bootstrap.classification(100, trainer, features, labels);

677

```

678

679

### Performance Analysis

680

681

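Validation results provide comprehensive performance analysis. Point estimates and variability are available for both classification and regression results; confidence intervals, per-class metrics, and confusion analysis are available for classification results only: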

682

683

- **Point Estimates**: Mean performance across folds/bootstrap samples

684

- **Variability**: Standard deviation of performance metrics

685

- **Confidence Intervals**: Statistical bounds for performance estimates

686

- **Per-Class Metrics**: Detailed breakdown for multi-class problems

687

- **Confusion Analysis**: Detailed error analysis through confusion matrices