# Classification

Supervised learning algorithms for predicting categorical outcomes. Smile Core implements traditional algorithms such as decision trees and support vector machines, ensemble methods such as random forests and gradient boosting, and neural networks.

## Capabilities

### Core Classification Interface

All classification algorithms implement the unified `Classifier<T>` interface, which provides consistent prediction methods and optional online learning support.

```java { .api }
/**
 * Base classification interface for all supervised learning algorithms
 * @param <T> the type of input objects
 */
interface Classifier<T> extends ToIntFunction<T>, ToDoubleFunction<T>, Serializable {
    /** Predict the class label for input */
    int predict(T x);

    /** Predict the class label and fill posteriori with the class probabilities */
    int predict(T x, double[] posteriori);

    /** Get number of classes */
    default int numClasses();

    /** Get class labels array */
    default int[] classes();

    /** Online learning update (if supported) */
    default void update(T x, int y);

    /** Create ensemble of multiple classifiers */
    static <T> Classifier<T> ensemble(Classifier<T>... classifiers);
}
```
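
The `ensemble` method combines several classifiers into one. How Smile aggregates their predictions is not stated here, so the following is only a minimal sketch of one common choice, majority voting. `MajorityVoteSketch` and `majorityVote` are hypothetical names, not part of Smile; because `Classifier<T>` extends `ToIntFunction<T>`, the sketch accepts any `ToIntFunction<T>` and needs only the JDK.

```java
import java.util.List;
import java.util.function.ToIntFunction;

/** Hypothetical helper, not part of Smile: combines classifiers by majority vote. */
final class MajorityVoteSketch {
    /** Returns the label chosen by the most models; ties go to the smallest label. */
    static <T> int majorityVote(List<ToIntFunction<T>> models, T x, int numClasses) {
        int[] votes = new int[numClasses];
        for (ToIntFunction<T> model : models) {
            votes[model.applyAsInt(x)]++;
        }
        int best = 0;
        for (int k = 1; k < numClasses; k++) {
            if (votes[k] > votes[best]) best = k;
        }
        return best;
    }

    public static void main(String[] args) {
        // Three toy threshold rules standing in for trained classifiers.
        List<ToIntFunction<Double>> models = List.of(
            v -> v > 0.5 ? 1 : 0,
            v -> v > 0.7 ? 1 : 0,
            v -> v > 0.2 ? 1 : 0);
        // For 0.6, two of the three rules vote for class 1.
        System.out.println(majorityVote(models, 0.6, 2));
    }
}
```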

### Random Forest

Ensemble method combining multiple decision trees with bootstrap sampling and random feature selection.

```java { .api }
/**
 * Random Forest classifier implementing ensemble of decision trees
 */
class RandomForest implements Classifier<Tuple>, DataFrameClassifier, TreeSHAP {
    /** Train random forest with formula on DataFrame */
    public static RandomForest fit(Formula formula, DataFrame data);

    /** Train with custom parameters */
    public static RandomForest fit(Formula formula, DataFrame data, Properties params);

    /** Predict class label for Tuple */
    public int predict(Tuple x);

    /** Predict with class probabilities for Tuple */
    public int predict(Tuple x, double[] posteriori);

    /** Calculate SHAP values for feature importance */
    public double[] shap(Tuple x);

    /** Get out-of-bag error estimate */
    public double error();

    /** Get feature importance scores */
    public double[] importance();

    /** Trim forest to specified number of trees */
    public RandomForest trim(int ntrees);

    /** Merge with another random forest */
    public RandomForest merge(RandomForest other);

    /** Prune forest using test data */
    public RandomForest prune(DataFrame test);
}
```

**Usage Example:**

```java
import smile.classification.RandomForest;
import smile.data.DataFrame;
import smile.data.formula.Formula;

// Train on DataFrame (irisData is a previously loaded DataFrame with a "species" column)
Formula formula = Formula.lhs("species");
RandomForest model = RandomForest.fit(formula, irisData);

// Predict new samples (newTuple is a Tuple with the same schema as irisData)
int prediction = model.predict(newTuple);

// Get prediction probabilities (one array entry per class; three classes here)
double[] probabilities = new double[3];
int predicted = model.predict(newTuple, probabilities);

// Get SHAP values for feature importance
double[] shapValues = model.shap(newTuple);
```
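
The bootstrap sampling behind a random forest, and the out-of-bag rows that `error()` is estimated from, can be illustrated without Smile. This is a JDK-only sketch of the general bagging idea, not Smile's implementation; `BootstrapSketch` and its methods are hypothetical names.

```java
import java.util.Arrays;
import java.util.Random;

/** Hypothetical illustration of bagging; not Smile's implementation. */
final class BootstrapSketch {
    /** Draws n row indices with replacement from [0, n). */
    static int[] bootstrap(int n, Random rng) {
        int[] sample = new int[n];
        for (int i = 0; i < n; i++) sample[i] = rng.nextInt(n);
        return sample;
    }

    /** Flags the rows that were never drawn; these are the out-of-bag rows. */
    static boolean[] outOfBag(int n, int[] sample) {
        boolean[] oob = new boolean[n];
        Arrays.fill(oob, true);
        for (int i : sample) oob[i] = false;
        return oob;
    }

    public static void main(String[] args) {
        Random rng = new Random(42); // fixed seed so the draw is repeatable
        int n = 150;
        int[] sample = bootstrap(n, rng);
        int count = 0;
        for (boolean b : outOfBag(n, sample)) if (b) count++;
        // On average roughly a third of the rows end up out-of-bag.
        System.out.println("out-of-bag rows: " + count + " of " + n);
    }
}
```

Each tree is trained on its own bootstrap sample, and the rows it never saw serve as a built-in validation set for that tree.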

### Decision Tree

Single decision tree classifier using the CART (Classification and Regression Trees) algorithm.

```java { .api }
/**
 * Decision tree classifier using CART algorithm
 */
class DecisionTree implements Classifier<double[]>, DataFrameClassifier {
    /** Train decision tree with default parameters */
    public static DecisionTree fit(double[][] x, int[] y);

    /** Train decision tree with formula on DataFrame */
    public static DecisionTree fit(Formula formula, DataFrame data);

    /** Train with custom split rule and parameters */
    public static DecisionTree fit(double[][] x, int[] y, SplitRule rule, int maxDepth, int maxNodes, int nodeSize);

    /** Predict class label */
    public int predict(double[] x);

    /** Get tree structure as string */
    public String toString();

    /** Get feature importance scores */
    public double[] importance();
}
```
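
The `SplitRule` argument selects the impurity measure a tree uses to score candidate splits. As a rough illustration of what the three usual criteria compute from the class counts in a node, here is a JDK-only sketch; `ImpuritySketch` is a hypothetical name and the code is not taken from Smile.

```java
/** Hypothetical illustration of node impurity measures; not Smile's code. */
final class ImpuritySketch {
    private static int total(int[] counts) {
        int n = 0;
        for (int c : counts) n += c;
        return n;
    }

    /** Gini impurity: 1 - sum of squared class proportions. */
    static double gini(int[] counts) {
        double n = total(counts), sumSq = 0.0;
        for (int c : counts) sumSq += (c / n) * (c / n);
        return 1.0 - sumSq;
    }

    /** Shannon entropy in bits; empty classes contribute nothing. */
    static double entropy(int[] counts) {
        double n = total(counts), h = 0.0;
        for (int c : counts) {
            if (c > 0) h -= (c / n) * (Math.log(c / n) / Math.log(2));
        }
        return h;
    }

    /** Classification error: 1 - proportion of the majority class. */
    static double classificationError(int[] counts) {
        double n = total(counts);
        int max = 0;
        for (int c : counts) max = Math.max(max, c);
        return 1.0 - max / n;
    }

    public static void main(String[] args) {
        int[] node = {40, 10}; // class counts in a hypothetical node
        System.out.println(gini(node) + " " + entropy(node) + " " + classificationError(node));
    }
}
```

All three measures are zero for a pure node and largest when the classes are evenly mixed; a split is preferred when it lowers the weighted impurity of the child nodes.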

### Support Vector Machine

Support Vector Machine classifier with various kernel functions and multi-class support.

```java { .api }
/**
 * Support Vector Machine classifier
 */
class SVM implements Classifier<double[]> {
    /** Train linear SVM */
    public static SVM fit(double[][] x, int[] y);

    /** Train SVM with RBF kernel */
    public static SVM fit(double[][] x, int[] y, double gamma);

    /** Train SVM with custom kernel and parameters */
    public static SVM fit(double[][] x, int[] y, Kernel kernel, double C, double tol);

    /** Predict class label */
    public int predict(double[] x);

    /** Get support vectors */
    public SupportVector[] supportVectors();

    /** Get number of support vectors */
    public int numSupportVectors();
}
```
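
For intuition about the RBF option, the kernel value between two points can be computed directly. The sketch below assumes the common convention K(a, b) = exp(-gamma * ||a - b||^2); how Smile interprets the `gamma` argument of `fit` is not stated here, so treat the parameterization as an assumption. `RbfKernelSketch` is a hypothetical name.

```java
/** Hypothetical illustration of an RBF kernel; the gamma convention is an assumption. */
final class RbfKernelSketch {
    /** K(a, b) = exp(-gamma * squared Euclidean distance between a and b). */
    static double rbf(double[] a, double[] b, double gamma) {
        double sq = 0.0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sq += d * d;
        }
        return Math.exp(-gamma * sq);
    }

    public static void main(String[] args) {
        double[] a = {0.0, 0.0};
        double[] b = {1.0, 0.0};
        // Identical points give 1; the value decays toward 0 as points move apart.
        System.out.println(rbf(a, a, 0.5) + " " + rbf(a, b, 0.5));
    }
}
```

A larger gamma makes the kernel decay faster with distance, which gives a more flexible and more easily overfit decision boundary.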

### Logistic Regression

Linear classifier using logistic regression with L1/L2 regularization options.

```java { .api }
/**
 * Logistic regression classifier
 */
class LogisticRegression implements Classifier<double[]> {
    /** Train logistic regression */
    public static LogisticRegression fit(double[][] x, int[] y);

    /** Train with regularization parameters */
    public static LogisticRegression fit(double[][] x, int[] y, double lambda, double tolerance, int maxIter);

    /** Predict class label */
    public int predict(double[] x);

    /** Predict with class probabilities */
    public int predict(double[] x, double[] posteriori);

    /** Get model coefficients */
    public double[] coefficients();

    /** Get intercept term */
    public double intercept();
}
```
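
To show how coefficients and an intercept turn into a label and class probabilities, here is a minimal binary logistic prediction written against plain arrays. It mirrors the shape of `predict(x, posteriori)` but is a sketch only; `LogisticSketch` is a hypothetical name, and the weights in `main` are made up rather than fitted.

```java
/** Hypothetical illustration of binary logistic prediction; not Smile's code. */
final class LogisticSketch {
    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /** Fills posteriori with {P(y=0), P(y=1)} and returns the more probable label. */
    static int predict(double[] w, double b, double[] x, double[] posteriori) {
        double z = b;
        for (int i = 0; i < w.length; i++) z += w[i] * x[i];
        double p1 = sigmoid(z);
        posteriori[0] = 1.0 - p1;
        posteriori[1] = p1;
        return p1 >= 0.5 ? 1 : 0;
    }

    public static void main(String[] args) {
        double[] w = {1.5, -0.5}; // made-up coefficients
        double[] posteriori = new double[2];
        int label = predict(w, 0.1, new double[] {1.0, 2.0}, posteriori);
        System.out.println(label + " " + posteriori[0] + " " + posteriori[1]);
    }
}
```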

### Naive Bayes Classifiers

Family of probabilistic classifiers based on Bayes' theorem with independence assumptions.

```java { .api }
/**
 * Gaussian Naive Bayes classifier for continuous features
 */
class NaiveBayes implements Classifier<double[]> {
    /** Train Gaussian Naive Bayes */
    public static NaiveBayes fit(double[][] x, int[] y);

    /** Train with Laplace smoothing */
    public static NaiveBayes fit(double[][] x, int[] y, Model model, int numClasses, double sigma);

    /** Predict class label */
    public int predict(double[] x);

    /** Online learning update */
    public void update(double[] x, int y);
}

/**
 * Discrete Naive Bayes for categorical features
 */
class DiscreteNaiveBayes implements Classifier<int[]> {
    enum Model { BERNOULLI, MULTINOMIAL, CNB, WCNB, TWCNB }

    /** Train discrete Naive Bayes */
    public static DiscreteNaiveBayes fit(int[][] x, int[] y, Model model);

    /** Predict class label */
    public int predict(int[] x);

    /** Online learning update */
    public void update(int[] x, int y);
}
```
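
The independence assumption means the class score is a sum of per-feature log-likelihoods plus the log prior. The following JDK-only sketch shows that computation for Gaussian features; `GaussianNbSketch` is a hypothetical name, and the priors, means, and variances are passed in directly instead of being estimated from data.

```java
/** Hypothetical illustration of Gaussian naive Bayes prediction; not Smile's code. */
final class GaussianNbSketch {
    /** Log density of a univariate normal with the given mean and variance. */
    static double logGaussian(double x, double mean, double var) {
        double d = x - mean;
        return -0.5 * (Math.log(2 * Math.PI * var) + d * d / var);
    }

    /** Returns argmax over classes k of log prior[k] + sum_j log N(x[j]; mean[k][j], var[k][j]). */
    static int predict(double[] prior, double[][] mean, double[][] var, double[] x) {
        int best = 0;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int k = 0; k < prior.length; k++) {
            double score = Math.log(prior[k]);
            for (int j = 0; j < x.length; j++) {
                score += logGaussian(x[j], mean[k][j], var[k][j]);
            }
            if (score > bestScore) {
                bestScore = score;
                best = k;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double[] prior = {0.5, 0.5};
        double[][] mean = {{0.0, 0.0}, {4.0, 4.0}}; // made-up class means
        double[][] var = {{1.0, 1.0}, {1.0, 1.0}};
        System.out.println(predict(prior, mean, var, new double[] {3.5, 3.0}));
    }
}
```

Working in log space avoids underflow when many small per-feature likelihoods are multiplied together.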

### Neural Networks

Multi-layer perceptron classifier with configurable architecture and training options.

```java { .api }
/**
 * Multi-Layer Perceptron classifier
 */
class MLP implements Classifier<double[]> {
    /** Train MLP with default architecture */
    public static MLP fit(double[][] x, int[] y);

    /** Train MLP with custom architecture */
    public static MLP fit(double[][] x, int[] y, int[] hiddenLayers, ActivationFunction activation);

    /** Train with full configuration */
    public static MLP fit(double[][] x, int[] y, Properties params);

    /** Predict class label */
    public int predict(double[] x);

    /** Online learning update */
    public void update(double[] x, int y);

    /** Get network weights */
    public double[][] getWeights(int layer);
}
```
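
Classification networks commonly turn their final-layer outputs into class probabilities with a softmax. Whether a given Smile `MLP` configuration does so is not stated here, so the sketch below only illustrates the general technique; `SoftmaxSketch` is a hypothetical name.

```java
/** Hypothetical illustration of a softmax output layer; not Smile's code. */
final class SoftmaxSketch {
    /** Maps raw output scores to probabilities that sum to 1. */
    static double[] softmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : logits) max = Math.max(max, v);
        double[] p = new double[logits.length];
        double sum = 0.0;
        for (int i = 0; i < logits.length; i++) {
            p[i] = Math.exp(logits[i] - max); // shift by the max for numerical stability
            sum += p[i];
        }
        for (int i = 0; i < p.length; i++) p[i] /= sum;
        return p;
    }

    public static void main(String[] args) {
        double[] probs = softmax(new double[] {2.0, 1.0, 0.1}); // made-up output scores
        System.out.println(java.util.Arrays.toString(probs));
    }
}
```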

### Ensemble Methods

Advanced ensemble techniques for improved prediction accuracy and robustness.

```java { .api }
/**
 * Adaptive Boosting (AdaBoost) classifier
 */
class AdaBoost implements Classifier<double[]> {
    /** Train AdaBoost with decision stumps */
    public static AdaBoost fit(double[][] x, int[] y);

    /** Train with custom weak learner and iterations */
    public static AdaBoost fit(double[][] x, int[] y, int numTrees, int maxDepth);

    /** Predict class label */
    public int predict(double[] x);

    /** Get weak learner weights */
    public double[] importance();
}

/**
 * Gradient Tree Boosting classifier
 */
class GradientTreeBoost implements Classifier<double[]>, DataFrameClassifier {
    /** Train gradient boosting */
    public static GradientTreeBoost fit(double[][] x, int[] y);

    /** Train with formula on DataFrame */
    public static GradientTreeBoost fit(Formula formula, DataFrame data);

    /** Train with custom parameters */
    public static GradientTreeBoost fit(double[][] x, int[] y, int numTrees, int maxDepth, double shrinkage);

    /** Predict class label */
    public int predict(double[] x);

    /** Get feature importance */
    public double[] importance();
}
```
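
Boosting trains weak learners in sequence and shifts sample weight toward the examples the previous learner got wrong. One round of the textbook binary AdaBoost update can be sketched as follows. This is the classic two-class formulation, not necessarily the variant Smile implements, and `AdaBoostStepSketch` is a hypothetical name.

```java
/** Hypothetical illustration of one binary AdaBoost round; not Smile's code. */
final class AdaBoostStepSketch {
    /** Weight of a weak learner with weighted error err, where 0 < err < 0.5. */
    static double learnerWeight(double err) {
        return 0.5 * Math.log((1.0 - err) / err);
    }

    /** Scales weights down for correct samples and up for mistakes, then renormalizes. */
    static double[] reweight(double[] w, boolean[] correct, double alpha) {
        double[] out = new double[w.length];
        double sum = 0.0;
        for (int i = 0; i < w.length; i++) {
            out[i] = w[i] * Math.exp(correct[i] ? -alpha : alpha);
            sum += out[i];
        }
        for (int i = 0; i < out.length; i++) out[i] /= sum;
        return out;
    }

    public static void main(String[] args) {
        double[] w = {0.25, 0.25, 0.25, 0.25};
        boolean[] correct = {true, true, true, false}; // one mistake, so err = 0.25
        double alpha = learnerWeight(0.25);
        System.out.println(java.util.Arrays.toString(reweight(w, correct, alpha)));
    }
}
```

After the update the misclassified samples carry half of the total weight, so the next weak learner is pushed to focus on them.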

### K-Nearest Neighbors

Instance-based classifier using k-nearest neighbors for prediction.

```java { .api }
/**
 * K-Nearest Neighbors classifier
 */
class KNN implements Classifier<double[]> {
    /** Train KNN classifier */
    public static KNN fit(double[][] x, int[] y, int k);

    /** Train with custom distance metric */
    public static KNN fit(double[][] x, int[] y, int k, Distance<double[]> distance);

    /** Predict class label */
    public int predict(double[] x);

    /** Predict with neighbor distances */
    public int predict(double[] x, double[] distances);

    /** Get k parameter */
    public int k();
}
```
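
The prediction rule itself is short enough to write out. The sketch below is a brute-force version using Euclidean distance and a simple majority vote, with no spatial index; it illustrates the idea only and is not how Smile searches for neighbors. `KnnSketch` is a hypothetical name.

```java
import java.util.Arrays;
import java.util.Comparator;

/** Hypothetical brute-force k-nearest-neighbors vote; not Smile's code. */
final class KnnSketch {
    static double squaredDistance(double[] a, double[] b) {
        double sq = 0.0;
        for (int i = 0; i < a.length; i++) sq += (a[i] - b[i]) * (a[i] - b[i]);
        return sq;
    }

    /** Returns the majority label among the k training points closest to query. */
    static int predict(double[][] x, int[] y, int numClasses, int k, double[] query) {
        Integer[] order = new Integer[x.length];
        for (int i = 0; i < order.length; i++) order[i] = i;
        // Sort training row indices by distance to the query point.
        Arrays.sort(order, Comparator.comparingDouble((Integer i) -> squaredDistance(x[i], query)));
        int[] votes = new int[numClasses];
        for (int r = 0; r < k; r++) votes[y[order[r]]]++;
        int best = 0;
        for (int c = 1; c < numClasses; c++) if (votes[c] > votes[best]) best = c;
        return best;
    }

    public static void main(String[] args) {
        double[][] x = {{0, 0}, {0, 1}, {5, 5}, {5, 6}}; // toy training points
        int[] y = {0, 0, 1, 1};
        System.out.println(predict(x, y, 2, 3, new double[] {0.2, 0.2}));
    }
}
```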

### Multi-class Strategies

Techniques for extending binary classifiers to multi-class problems.

```java { .api }
/**
 * One-versus-one multi-class strategy
 */
class OneVersusOne implements Classifier<double[]> {
    /** Train OvO with binary classifier trainer */
    public static OneVersusOne fit(Classifier.Trainer<double[], ?> trainer, double[][] x, int[] y);

    /** Predict class label */
    public int predict(double[] x);
}

/**
 * One-versus-rest multi-class strategy
 */
class OneVersusRest implements Classifier<double[]> {
    /** Train OvR with binary classifier trainer */
    public static OneVersusRest fit(Classifier.Trainer<double[], ?> trainer, double[][] x, int[] y);

    /** Predict class label */
    public int predict(double[] x);
}
```
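
The bookkeeping behind the two strategies is simple: one-versus-rest trains one binary problem per class and picks the class with the highest score, while one-versus-one trains one binary problem per pair of classes. The JDK-only sketch below shows the relabeling, the final argmax, and the pair count; `MultiClassSketch` and its methods are hypothetical names, not Smile API.

```java
/** Hypothetical illustration of multi-class decomposition bookkeeping; not Smile's code. */
final class MultiClassSketch {
    /** One-versus-rest relabeling: 1 where the label equals target, 0 elsewhere. */
    static int[] binarize(int[] y, int target) {
        int[] out = new int[y.length];
        for (int i = 0; i < y.length; i++) out[i] = (y[i] == target) ? 1 : 0;
        return out;
    }

    /** One-versus-rest decision: index of the binary model with the highest score. */
    static int argmax(double[] scores) {
        int best = 0;
        for (int k = 1; k < scores.length; k++) if (scores[k] > scores[best]) best = k;
        return best;
    }

    /** Number of binary models one-versus-one needs for k classes. */
    static int numPairs(int k) {
        return k * (k - 1) / 2;
    }

    public static void main(String[] args) {
        int[] y = {0, 1, 2, 1};
        System.out.println(java.util.Arrays.toString(binarize(y, 1)));
        System.out.println(argmax(new double[] {0.1, 0.7, 0.2})); // made-up per-class scores
        System.out.println(numPairs(4));
    }
}
```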

### Linear Discriminant Analysis

Linear, quadratic, and regularized discriminant analysis for Gaussian-distributed classes.

```java { .api }
/**
 * Linear Discriminant Analysis
 */
class LDA implements Classifier<double[]> {
    /** Train LDA classifier */
    public static LDA fit(double[][] x, int[] y);

    /** Train with prior probabilities */
    public static LDA fit(double[][] x, int[] y, double[] priori);

    /** Predict class label */
    public int predict(double[] x);

    /** Get discriminant projection */
    public double[] project(double[] x);
}

/**
 * Quadratic Discriminant Analysis
 */
class QDA implements Classifier<double[]> {
    /** Train QDA classifier */
    public static QDA fit(double[][] x, int[] y);

    /** Train with prior probabilities */
    public static QDA fit(double[][] x, int[] y, double[] priori);

    /** Predict class label */
    public int predict(double[] x);
}

/**
 * Regularized Discriminant Analysis
 */
class RDA implements Classifier<double[]> {
    /** Train RDA with regularization parameters */
    public static RDA fit(double[][] x, int[] y, double alpha, double gamma);

    /** Predict class label */
    public int predict(double[] x);
}
```

### Utility Classes

Helper classes for classification tasks and probability calibration.

```java { .api }
/**
 * Class label encoding utilities
 */
class ClassLabels {
    /** Fit encoder from class labels */
    public static ClassLabels fit(int[] y);

    /** Encoded class labels */
    public final int[] classes;

    /** Number of classes */
    public final int numClasses;
}

/**
 * Platt scaling for probability calibration
 */
class PlattScaling {
    /** Fit Platt scaling from classifier outputs */
    public static PlattScaling fit(double[] scores, int[] y);

    /** Apply calibration to classifier output */
    public double calibrate(double score);
}

/**
 * Isotonic regression scaling for probability calibration
 */
class IsotonicRegressionScaling {
    /** Fit isotonic regression scaling */
    public static IsotonicRegressionScaling fit(double[] scores, int[] y);

    /** Apply calibration to classifier output */
    public double calibrate(double score);
}
```
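
Platt scaling maps a raw classifier score to a probability through a fitted sigmoid, p = 1 / (1 + exp(A * score + B)). The sketch below only applies that mapping; the parameters A and B are made-up values rather than the result of a fit, and `PlattSketch` is a hypothetical name, not Smile's class.

```java
/** Hypothetical illustration of applying a Platt sigmoid; not Smile's code. */
final class PlattSketch {
    private final double a;
    private final double b;

    PlattSketch(double a, double b) {
        this.a = a;
        this.b = b;
    }

    /** Maps a raw score to a probability with p = 1 / (1 + exp(a * score + b)). */
    double calibrate(double score) {
        return 1.0 / (1.0 + Math.exp(a * score + b));
    }

    public static void main(String[] args) {
        // Made-up parameters; a negative A makes larger scores more probable.
        PlattSketch scaling = new PlattSketch(-1.0, 0.0);
        System.out.println(scaling.calibrate(-2.0) + " " + scaling.calibrate(2.0));
    }
}
```

This is useful for classifiers such as SVMs whose outputs are margins, not probabilities.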

### Training Patterns

All classifiers follow consistent training patterns. In the snippets below, `Algorithm` is a placeholder for a concrete classifier class.

**Array-based training:**
```java
// x: double[][] feature matrix, y: int[] class labels, params: java.util.Properties
Classifier<double[]> model = Algorithm.fit(x, y);
Classifier<double[]> tuned = Algorithm.fit(x, y, params);
```

**DataFrame-based training:**
```java
// formula: Formula naming the response column, data: DataFrame
Classifier<Tuple> model = Algorithm.fit(formula, data);
```

**Prediction patterns:**
```java
// x: double[] sample; posteriori: double[] with one entry per class, filled by the call
int prediction = model.predict(x);
int predictionWithProbabilities = model.predict(x, posteriori);
```

### Common Parameters

The tree-based and ensemble classifiers share most of these configuration options:

- **maxDepth**: Maximum tree depth (tree-based algorithms)
- **numTrees**: Number of trees in the ensemble
- **nodeSize**: Minimum number of samples per leaf node
- **subsample**: Fraction of samples drawn for each bootstrap sample
- **mtry**: Number of features considered at each split
- **splitRule**: Splitting criterion (GINI, ENTROPY, CLASSIFICATION_ERROR)
- **seed**: Random seed for reproducibility