Or run `npx @tessl/cli init`.
Log in

Version

Tile

Overview

Evals

Files

Files

docs/

advanced-analytics.md · classification.md · clustering.md · deep-learning.md · feature-engineering.md · index.md · regression.md · validation-metrics.md

docs/classification.md

# Classification

Comprehensive supervised learning algorithms for predicting categorical outcomes. Smile Core provides implementations of traditional algorithms such as decision trees and support vector machines, ensemble methods such as random forests and gradient boosting, and neural network approaches.

## Capabilities

### Core Classification Interface

All classification algorithms implement the unified `Classifier<T>` interface, which provides consistent prediction methods and optional online learning support.

```java { .api }

11

/**

12

* Base classification interface for all supervised learning algorithms

13

* @param <T> the type of input objects

14

*/

15

interface Classifier<T> extends ToIntFunction<T>, ToDoubleFunction<T>, Serializable {

16

/** Predict the class label for input */

17

int predict(T x);

18

19

/** Predict class label and return class probabilities */

20

int predict(T x, double[] posteriori);

21

22

/** Get number of classes */

23

default int numClasses();

24

25

/** Get class labels array */

26

default int[] classes();

27

28

/** Online learning update (if supported) */

29

default void update(T x, int y);

30

31

/** Create ensemble of multiple classifiers */

32

static <T> Classifier<T> ensemble(Classifier<T>... classifiers);

33

}

34

```

### Random Forest

Ensemble method that combines multiple decision trees trained with bootstrap sampling and random feature selection.

```java { .api }

41

/**

42

* Random Forest classifier implementing ensemble of decision trees

43

*/

44

class RandomForest implements Classifier<Tuple>, DataFrameClassifier, TreeSHAP {

45

/** Train random forest with formula on DataFrame */

46

public static RandomForest fit(Formula formula, DataFrame data);

47

48

/** Train with custom parameters */

49

public static RandomForest fit(Formula formula, DataFrame data, Properties params);

50

51

/** Predict class label for Tuple */

52

public int predict(Tuple x);

53

54

/** Predict with class probabilities for Tuple */

55

public int predict(Tuple x, double[] posteriori);

56

57

/** Calculate SHAP values for feature importance */

58

public double[] shap(Tuple x);

59

60

/** Get out-of-bag error estimate */

61

public double error();

62

63

/** Get feature importance scores */

64

public double[] importance();

65

66

/** Trim forest to specified number of trees */

67

public RandomForest trim(int ntrees);

68

69

/** Merge with another random forest */

70

public RandomForest merge(RandomForest other);

71

72

/** Prune forest using test data */

73

public RandomForest prune(DataFrame test);

74

}

75

```

**Usage Example:**

```java

80

import smile.classification.RandomForest;

81

import smile.data.DataFrame;

82

import smile.data.formula.Formula;

83

84

// Train on DataFrame

85

Formula formula = Formula.lhs("species");

86

RandomForest model = RandomForest.fit(formula, irisData);

87

88

// Predict new samples (using Tuple from DataFrame)

89

int prediction = model.predict(newTuple);

90

91

// Get prediction probabilities

92

double[] probabilities = new double[3];

93

int predicted = model.predict(newTuple, probabilities);

94

95

// Get SHAP values for feature importance

96

double[] shapValues = model.shap(newTuple);

97

```

### Decision Tree

Single decision tree classifier built with the CART (Classification and Regression Trees) algorithm.

```java { .api }

104

/**

105

* Decision tree classifier using CART algorithm

106

*/

107

class DecisionTree implements Classifier<double[]>, DataFrameClassifier {

108

/** Train decision tree with default parameters */

109

public static DecisionTree fit(double[][] x, int[] y);

110

111

/** Train decision tree with formula on DataFrame */

112

public static DecisionTree fit(Formula formula, DataFrame data);

113

114

/** Train with custom split rule and parameters */

115

public static DecisionTree fit(double[][] x, int[] y, SplitRule rule, int maxDepth, int maxNodes, int nodeSize);

116

117

/** Predict class label */

118

public int predict(double[] x);

119

120

/** Get tree structure as string */

121

public String toString();

122

123

/** Get feature importance scores */

124

public double[] importance();

125

}

126

```

### Support Vector Machine

Support vector machine classifier with configurable kernel functions and multi-class support.

```java { .api }

133

/**

134

* Support Vector Machine classifier

135

*/

136

class SVM implements Classifier<double[]> {

137

/** Train linear SVM */

138

public static SVM fit(double[][] x, int[] y);

139

140

/** Train SVM with RBF kernel */

141

public static SVM fit(double[][] x, int[] y, double gamma);

142

143

/** Train SVM with custom kernel and parameters */

144

public static SVM fit(double[][] x, int[] y, Kernel kernel, double C, double tol);

145

146

/** Predict class label */

147

public int predict(double[] x);

148

149

/** Get support vectors */

150

public SupportVector[] supportVectors();

151

152

/** Get number of support vectors */

153

public int numSupportVectors();

154

}

155

```

### Logistic Regression

Linear classifier based on logistic regression, with optional regularization controlled by `lambda`.

```java { .api }

162

/**

163

* Logistic regression classifier

164

*/

165

class LogisticRegression implements Classifier<double[]> {

166

/** Train logistic regression */

167

public static LogisticRegression fit(double[][] x, int[] y);

168

169

/** Train with regularization parameters */

170

public static LogisticRegression fit(double[][] x, int[] y, double lambda, double tolerance, int maxIter);

171

172

/** Predict class label */

173

public int predict(double[] x);

174

175

/** Predict with class probabilities */

176

public int predict(double[] x, double[] posteriori);

177

178

/** Get model coefficients */

179

public double[] coefficients();

180

181

/** Get intercept term */

182

public double intercept();

183

}

184

```

### Naive Bayes Classifiers

Family of probabilistic classifiers based on Bayes' theorem, assuming features are conditionally independent given the class.

```java { .api }

191

/**

192

* Gaussian Naive Bayes classifier for continuous features

193

*/

194

class NaiveBayes implements Classifier<double[]> {

195

/** Train Gaussian Naive Bayes */

196

public static NaiveBayes fit(double[][] x, int[] y);

197

198

/** Train with Laplace smoothing */

199

public static NaiveBayes fit(double[][] x, int[] y, Model model, int numClasses, double sigma);

200

201

/** Predict class label */

202

public int predict(double[] x);

203

204

/** Online learning update */

205

public void update(double[] x, int y);

206

}

207

208

/**

209

* Discrete Naive Bayes for categorical features

210

*/

211

class DiscreteNaiveBayes implements Classifier<int[]> {

212

enum Model { BERNOULLI, MULTINOMIAL, CNB, WCNB, TWCNB }

213

214

/** Train discrete Naive Bayes */

215

public static DiscreteNaiveBayes fit(int[][] x, int[] y, Model model);

216

217

/** Predict class label */

218

public int predict(int[] x);

219

220

/** Online learning update */

221

public void update(int[] x, int y);

222

}

223

```

### Neural Networks

Multi-layer perceptron (MLP) classifier with configurable architecture and training options.

```java { .api }

230

/**

231

* Multi-Layer Perceptron classifier

232

*/

233

class MLP implements Classifier<double[]> {

234

/** Train MLP with default architecture */

235

public static MLP fit(double[][] x, int[] y);

236

237

/** Train MLP with custom architecture */

238

public static MLP fit(double[][] x, int[] y, int[] hiddenLayers, ActivationFunction activation);

239

240

/** Train with full configuration */

241

public static MLP fit(double[][] x, int[] y, Properties params);

242

243

/** Predict class label */

244

public int predict(double[] x);

245

246

/** Online learning update */

247

public void update(double[] x, int y);

248

249

/** Get network weights */

250

public double[][] getWeights(int layer);

251

}

252

```

### Ensemble Methods

Boosting ensembles of decision trees (AdaBoost and gradient tree boosting) for improved prediction accuracy and robustness.

```java { .api }

259

/**

260

* Adaptive Boosting (AdaBoost) classifier

261

*/

262

class AdaBoost implements Classifier<double[]> {

263

/** Train AdaBoost with decision stumps */

264

public static AdaBoost fit(double[][] x, int[] y);

265

266

/** Train with custom weak learner and iterations */

267

public static AdaBoost fit(double[][] x, int[] y, int numTrees, int maxDepth);

268

269

/** Predict class label */

270

public int predict(double[] x);

271

272

/** Get weak learner weights */

273

public double[] importance();

274

}

275

276

/**

277

* Gradient Tree Boosting classifier

278

*/

279

class GradientTreeBoost implements Classifier<double[]>, DataFrameClassifier {

280

/** Train gradient boosting */

281

public static GradientTreeBoost fit(double[][] x, int[] y);

282

283

/** Train with formula on DataFrame */

284

public static GradientTreeBoost fit(Formula formula, DataFrame data);

285

286

/** Train with custom parameters */

287

public static GradientTreeBoost fit(double[][] x, int[] y, int numTrees, int maxDepth, double shrinkage);

288

289

/** Predict class label */

290

public int predict(double[] x);

291

292

/** Get feature importance */

293

public double[] importance();

294

}

295

```

### K-Nearest Neighbors

Instance-based classifier that predicts a label from the k nearest training samples.

```java { .api }

302

/**

303

* K-Nearest Neighbors classifier

304

*/

305

class KNN implements Classifier<double[]> {

306

/** Train KNN classifier */

307

public static KNN fit(double[][] x, int[] y, int k);

308

309

/** Train with custom distance metric */

310

public static KNN fit(double[][] x, int[] y, int k, Distance<double[]> distance);

311

312

/** Predict class label */

313

public int predict(double[] x);

314

315

/** Predict with neighbor distances */

316

public int predict(double[] x, double[] distances);

317

318

/** Get k parameter */

319

public int k();

320

}

321

```

### Multi-class Strategies

Techniques for extending binary classifiers to multi-class problems.

```java { .api }

328

/**

329

* One-versus-one multi-class strategy

330

*/

331

class OneVersusOne implements Classifier<double[]> {

332

/** Train OvO with binary classifier trainer */

333

public static OneVersusOne fit(Classifier.Trainer<double[], ?> trainer, double[][] x, int[] y);

334

335

/** Predict class label */

336

public int predict(double[] x);

337

}

338

339

/**

340

* One-versus-rest multi-class strategy

341

*/

342

class OneVersusRest implements Classifier<double[]> {

343

/** Train OvR with binary classifier trainer */

344

public static OneVersusRest fit(Classifier.Trainer<double[], ?> trainer, double[][] x, int[] y);

345

346

/** Predict class label */

347

public int predict(double[] x);

348

}

349

```

### Linear Discriminant Analysis

Linear (LDA), quadratic (QDA), and regularized (RDA) discriminant analysis for Gaussian-distributed classes.

```java { .api }

356

/**

357

* Linear Discriminant Analysis

358

*/

359

class LDA implements Classifier<double[]> {

360

/** Train LDA classifier */

361

public static LDA fit(double[][] x, int[] y);

362

363

/** Train with prior probabilities */

364

public static LDA fit(double[][] x, int[] y, double[] priori);

365

366

/** Predict class label */

367

public int predict(double[] x);

368

369

/** Get discriminant projection */

370

public double[] project(double[] x);

371

}

372

373

/**

374

* Quadratic Discriminant Analysis

375

*/

376

class QDA implements Classifier<double[]> {

377

/** Train QDA classifier */

378

public static QDA fit(double[][] x, int[] y);

379

380

/** Train with prior probabilities */

381

public static QDA fit(double[][] x, int[] y, double[] priori);

382

383

/** Predict class label */

384

public int predict(double[] x);

385

}

386

387

/**

388

* Regularized Discriminant Analysis

389

*/

390

class RDA implements Classifier<double[]> {

391

/** Train RDA with regularization parameters */

392

public static RDA fit(double[][] x, int[] y, double alpha, double gamma);

393

394

/** Predict class label */

395

public int predict(double[] x);

396

}

397

```

### Utility Classes

Helper classes for class-label encoding and probability calibration.

```java { .api }

404

/**

405

* Class label encoding utilities

406

*/

407

class ClassLabels {

408

/** Fit encoder from class labels */

409

public static ClassLabels fit(int[] y);

410

411

/** Encoded class labels */

412

public final int[] classes;

413

414

/** Number of classes */

415

public final int numClasses;

416

}

417

418

/**

419

* Platt scaling for probability calibration

420

*/

421

class PlattScaling {

422

/** Fit Platt scaling from classifier outputs */

423

public static PlattScaling fit(double[] scores, int[] y);

424

425

/** Apply calibration to classifier output */

426

public double calibrate(double score);

427

}

428

429

/**

430

* Isotonic regression scaling for probability calibration

431

*/

432

class IsotonicRegressionScaling {

433

/** Fit isotonic regression scaling */

434

public static IsotonicRegressionScaling fit(double[] scores, int[] y);

435

436

/** Apply calibration to classifier output */

437

public double calibrate(double score);

438

}

439

```

### Training Patterns

Most classifiers follow these training patterns; some require extra arguments (for example, `KNN.fit` takes `k`, and `RDA.fit` takes `alpha` and `gamma`):

**Array-based training:**

```java

447

Classifier model = Algorithm.fit(double[][] x, int[] y);

448

Classifier model = Algorithm.fit(double[][] x, int[] y, Properties params);

449

```

**DataFrame-based training:**

```java

453

Classifier model = Algorithm.fit(Formula formula, DataFrame data);

454

```

**Prediction patterns:**

```java

458

int prediction = model.predict(double[] x);

459

int prediction = model.predict(double[] x, double[] posteriori);

460

```

### Common Parameters

Many classification algorithms, particularly tree-based and ensemble methods, support these common configuration options:

- **maxDepth**: Maximum tree depth (tree-based algorithms)
- **numTrees**: Number of trees in an ensemble
- **nodeSize**: Minimum number of samples per leaf node
- **subsample**: Fraction of samples drawn for bootstrap sampling
- **mtry**: Number of features considered at each split
- **splitRule**: Splitting criterion (`GINI`, `ENTROPY`, `CLASSIFICATION_ERROR`)
- **seed**: Random seed for reproducibility