0
# Classification Algorithms
1
2
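Supervised learning algorithms for predicting categorical labels. MLlib provides a broad set of classifiers: logistic regression, decision trees and tree ensembles (random forests and gradient-boosted trees), linear support vector machines, multilayer perceptrons, naive Bayes, factorization machines, and the one-vs-rest meta-classifier that reduces multiclass problems to binary ones.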
3
4
## Capabilities
5
6
### Logistic Regression
7
8
Binary and multinomial logistic regression with elastic net regularization for linear classification problems.
9
10
```scala { .api }
11
/**
 * Logistic regression classifier for binary and multinomial problems.
 *
 * The penalty is configured by `regParam` (overall strength) together with
 * `elasticNetParam` (the mix between the L2 and L1 terms), which covers pure
 * L1, pure L2 and elastic-net regularization.
 */
class LogisticRegression extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel] {

  // --- Regularization ---
  def setRegParam(value: Double): this.type
  def setElasticNetParam(value: Double): this.type
  def setStandardization(value: Boolean): this.type

  // --- Model form ---
  def setFamily(value: String): this.type
  def setFitIntercept(value: Boolean): this.type

  // --- Optimizer control ---
  def setMaxIter(value: Int): this.type
  def setTol(value: Double): this.type
  def setAggregationDepth(value: Int): this.type
  def setMaxBlockSizeInMB(value: Double): this.type
  def setInitialModel(model: LogisticRegressionModel): this.type

  // --- Decision thresholds (single binary threshold or one per class) ---
  def setThreshold(value: Double): this.type
  def setThresholds(value: Array[Double]): this.type

  // --- Instance weighting ---
  def setWeightCol(value: String): this.type

  // --- Box constraints on the fitted parameters ---
  def setLowerBoundsOnCoefficients(value: Matrix): this.type
  def setUpperBoundsOnCoefficients(value: Matrix): this.type
  def setLowerBoundsOnIntercepts(value: Vector): this.type
  def setUpperBoundsOnIntercepts(value: Vector): this.type
}
34
35
/**
 * Fitted logistic regression model.
 *
 * `coefficientMatrix` / `interceptVector` hold the parameters for every model
 * family. `coefficients` / `intercept` are the binomial-only views and are not
 * available on multinomial models.
 */
class LogisticRegressionModel extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel] {

  // Fitted parameters (all families)
  def coefficientMatrix: Matrix
  def interceptVector: Vector

  // Fitted parameters (binomial models only)
  def coefficients: Vector
  def intercept: Double

  // Dimensions
  def numClasses: Int
  def numFeatures: Int

  // Training summary (check `hasSummary` first) and re-evaluation on new data
  def hasSummary: Boolean
  def summary: LogisticRegressionSummary
  def evaluate(dataset: Dataset[_]): LogisticRegressionSummary
}
46
47
// Model summaries

/** Summary of a logistic regression model's predictions on a dataset. */
class LogisticRegressionSummary extends Serializable {

  /** The scored dataset. */
  def predictions: DataFrame

  // Column names used inside `predictions`
  def labelCol: String
  def featuresCol: String
  def predictionCol: String

  /** Number of iterations the training run performed. */
  def totalIterations: Int
}
55
56
/** Additional threshold-based metrics available for binary logistic regression. */
class BinaryLogisticRegressionSummary extends LogisticRegressionSummary {

  // Areas under the curves
  def areaUnderROC: Double
  def areaUnderPR: Double

  // The curves themselves
  def roc: DataFrame
  def pr: DataFrame

  // Metrics evaluated at each candidate threshold
  def precisionByThreshold: DataFrame
  def recallByThreshold: DataFrame
  def fMeasureByThreshold: DataFrame
}
65
```
66
67
**Usage Example:**
68
69
```scala
70
import org.apache.spark.ml.classification.{BinaryLogisticRegressionSummary, LogisticRegression}

val lr = new LogisticRegression()
  .setLabelCol("label")
  .setFeaturesCol("features")
  .setRegParam(0.01)
  .setElasticNetParam(0.5)
  .setMaxIter(100)
  .setFamily("multinomial")

val model = lr.fit(trainingData)
val predictions = model.transform(testData)

// Access model parameters. A "multinomial" model always stores a coefficient
// matrix (one row per class) and an intercept vector; the binomial-only
// accessors `coefficients` / `intercept` throw on such a model.
println(s"Coefficient matrix:\n${model.coefficientMatrix}")
println(s"Intercepts: ${model.interceptVector}")

// Binary classification metrics are only present when the training summary is
// binary; match on the type instead of casting so multiclass models do not fail.
if (model.hasSummary) {
  model.summary match {
    case binary: BinaryLogisticRegressionSummary =>
      println(s"AUC: ${binary.areaUnderROC}")
    case _ =>
      println("Multiclass model: binary metrics (AUC) are not available")
  }
}
92
```
93
94
### Decision Tree Classifier
95
96
Decision tree algorithm for classification with support for categorical and continuous features.
97
98
```scala { .api }
99
/**
 * Single decision tree classifier for binary and multiclass labels.
 * Categorical features are handled automatically.
 */
class DecisionTreeClassifier extends ProbabilisticClassifier[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel] {

  // Tree shape and split criteria
  def setMaxDepth(value: Int): this.type
  def setMaxBins(value: Int): this.type
  def setMinInstancesPerNode(value: Int): this.type
  def setMinInfoGain(value: Double): this.type
  def setImpurity(value: String): this.type

  // Training resources / fault tolerance
  def setMaxMemoryInMB(value: Int): this.type
  def setCacheNodeIds(value: Boolean): this.type
  def setCheckpointInterval(value: Int): this.type

  // Reproducibility and instance weights
  def setSeed(value: Long): this.type
  def setWeightCol(value: String): this.type
}
115
116
/** Fitted decision tree classifier. */
class DecisionTreeClassificationModel
    extends ProbabilisticClassificationModel[Vector, DecisionTreeClassificationModel]
    with DecisionTreeModel
    with MLWritable {

  // Tree structure
  def rootNode: Node
  def depth: Int
  def numNodes: Int

  // Inspection helpers
  def featureImportances: Vector
  def toDebugString: String
}
124
```
125
126
**Usage Example:**
127
128
```scala
129
import org.apache.spark.ml.classification.DecisionTreeClassifier

// Depth-limited tree using Gini impurity and 32 bins for continuous features.
val dt = new DecisionTreeClassifier()
  .setFeaturesCol("features")
  .setLabelCol("label")
  .setImpurity("gini")
  .setMaxDepth(5)
  .setMaxBins(32)

// Fit on the training split, then score the held-out split.
val model = dt.fit(trainingData)
val predictions = model.transform(testData)

// Inspect the tree
println(s"Learned classification tree model:\n${model.toDebugString}")
println(s"Feature importances: ${model.featureImportances}")
144
```
145
146
### Random Forest Classifier
147
148
Ensemble of decision trees using bootstrap aggregating and random feature selection.
149
150
```scala { .api }
151
/**
 * Random forest classifier: an ensemble of decision trees, each trained on a
 * bootstrap sample and considering a random subset of features per split,
 * which improves accuracy and reduces overfitting compared with a single tree.
 */
class RandomForestClassifier extends ProbabilisticClassifier[Vector, RandomForestClassifier, RandomForestClassificationModel] {

  // Ensemble construction
  def setNumTrees(value: Int): this.type
  def setSubsamplingRate(value: Double): this.type
  def setFeatureSubsetStrategy(value: String): this.type

  // Per-tree shape and split criteria
  def setMaxDepth(value: Int): this.type
  def setMaxBins(value: Int): this.type
  def setMinInstancesPerNode(value: Int): this.type
  def setMinInfoGain(value: Double): this.type
  def setImpurity(value: String): this.type

  // Training resources / fault tolerance
  def setMaxMemoryInMB(value: Int): this.type
  def setCacheNodeIds(value: Boolean): this.type
  def setCheckpointInterval(value: Int): this.type

  // Reproducibility and instance weights
  def setSeed(value: Long): this.type
  def setWeightCol(value: String): this.type
}
170
171
/** Fitted random forest: a weighted collection of classification trees. */
class RandomForestClassificationModel
    extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel]
    with TreeEnsembleModel[DecisionTreeClassificationModel]
    with MLWritable {

  // Ensemble members and their weights
  def numTrees: Int
  def trees: Array[DecisionTreeClassificationModel]
  def treeWeights: Array[Double]

  // Inspection helpers
  def featureImportances: Vector
  def toDebugString: String
}
179
```
180
181
### Gradient-Boosted Tree Classifier
182
183
Boosting ensemble that trains decision trees sequentially, each new tree fitted to correct the errors of the ensemble built so far. Currently supports binary classification only.
184
185
```scala { .api }
186
/**
 * Gradient-boosted tree (GBT) classifier: iteratively boosts an ensemble of
 * weak decision-tree learners (regression trees), each new tree fitted to the
 * errors of the ensemble built so far.
 *
 * Only binary labels are supported; for multiclass problems wrap it in
 * [[OneVsRest]]. Since Spark 2.2 it is a probabilistic classifier, so fitted
 * models also produce `rawPrediction` and `probability` columns.
 */
class GBTClassifier extends ProbabilisticClassifier[Vector, GBTClassifier, GBTClassificationModel] {

  // Boosting
  def setMaxIter(value: Int): this.type
  def setStepSize(value: Double): this.type
  def setLossType(value: String): this.type
  def setSubsamplingRate(value: Double): this.type
  def setFeatureSubsetStrategy(value: String): this.type

  // Per-tree shape and split criteria
  def setMaxDepth(value: Int): this.type
  def setMaxBins(value: Int): this.type
  def setMinInstancesPerNode(value: Int): this.type
  def setMinInfoGain(value: Double): this.type

  // Early stopping on a validation subset
  def setValidationIndicatorCol(value: String): this.type
  def setValidationTol(value: Double): this.type

  // Training resources / fault tolerance
  def setMaxMemoryInMB(value: Int): this.type
  def setCacheNodeIds(value: Boolean): this.type
  def setCheckpointInterval(value: Int): this.type

  // Reproducibility and instance weights
  def setSeed(value: Long): this.type
  def setWeightCol(value: String): this.type
}

/**
 * Fitted GBT classifier: a weighted sum of regression trees whose combined
 * margin is turned into class probabilities.
 */
class GBTClassificationModel
    extends ProbabilisticClassificationModel[Vector, GBTClassificationModel]
    with TreeEnsembleModel[DecisionTreeRegressionModel]
    with MLWritable {

  // Ensemble members and their weights
  def trees: Array[DecisionTreeRegressionModel]
  def treeWeights: Array[Double]
  def numTrees: Int

  // Inspection helpers
  def featureImportances: Vector
  def toDebugString: String
}
217
```
218
219
### Linear Support Vector Classifier
220
221
Linear SVM for binary classification that minimizes the hinge loss with the OWLQN optimizer (a variant of L-BFGS); only L2 regularization is supported.
222
223
```scala { .api }
224
/**
 * Linear Support Vector Classifier for binary classification with a linear
 * decision boundary.
 *
 * Minimizes the hinge loss with the OWLQN optimizer (an L-BFGS variant); only
 * L2 regularization (`regParam`) is supported. The model outputs raw margins
 * but no class probabilities, which is why it extends [[Classifier]] rather
 * than `ProbabilisticClassifier`.
 */
class LinearSVC extends Classifier[Vector, LinearSVC, LinearSVCModel] {

  // Regularization and model form
  def setRegParam(value: Double): this.type
  def setStandardization(value: Boolean): this.type
  def setFitIntercept(value: Boolean): this.type

  // Optimizer control
  def setMaxIter(value: Int): this.type
  def setTol(value: Double): this.type
  def setAggregationDepth(value: Int): this.type

  // Decision threshold applied to the raw margin
  def setThreshold(value: Double): this.type

  // Instance weights
  def setWeightCol(value: String): this.type
}
238
239
/** Fitted linear SVM: a weight vector plus an intercept. */
class LinearSVCModel
    extends ClassificationModel[Vector, LinearSVCModel]
    with LinearSVCParams
    with MLWritable {

  // Fitted hyperplane
  def coefficients: Vector
  def intercept: Double

  // Dimensions
  def numFeatures: Int
  def numClasses: Int
}
245
```
246
247
### Multilayer Perceptron Classifier
248
249
Feed-forward artificial neural network with configurable hidden layers.
250
251
```scala { .api }
252
/**
 * Multilayer perceptron classifier: a feed-forward artificial neural network
 * trained with backpropagation, whose architecture is given by `layers`.
 */
class MultilayerPerceptronClassifier
    extends ProbabilisticClassifier[Vector, MultilayerPerceptronClassifier, MultilayerPerceptronClassificationModel] {

  // Network architecture and starting point
  def setLayers(value: Array[Int]): this.type
  def setInitialWeights(value: Vector): this.type

  // Optimizer control (stepSize applies to the gradient-descent solver)
  def setSolver(value: String): this.type
  def setMaxIter(value: Int): this.type
  def setTol(value: Double): this.type
  def setStepSize(value: Double): this.type
  def setBlockSize(value: Int): this.type

  // Reproducibility and instance weights
  def setSeed(value: Long): this.type
  def setWeightCol(value: String): this.type
}
267
268
/** Fitted multilayer perceptron: layer sizes plus all weights flattened into one vector. */
class MultilayerPerceptronClassificationModel
    extends ProbabilisticClassificationModel[Vector, MultilayerPerceptronClassificationModel] {

  def layers: Array[Int]
  def weights: Vector
  def numFeatures: Int
}
273
```
274
275
**Usage Example:**
276
277
```scala
278
import org.apache.spark.ml.classification.MultilayerPerceptronClassifier

// Specify layers for the neural network:
// input layer (features) -> hidden layers -> output layer (classes).
// The first entry must equal the feature-vector size and the last entry the
// number of classes.
val layers = Array[Int](4, 5, 4, 3) // 4 features, hidden layers of 5 and 4 neurons, 3 classes

val trainer = new MultilayerPerceptronClassifier()
  .setLayers(layers)
  .setBlockSize(128)
  .setSeed(1234L)
  .setMaxIter(100)

val model = trainer.fit(trainingData)
val predictions = model.transform(testData)
292
```
293
294
### Naive Bayes Classifier
295
296
Probabilistic classifier based on Bayes' theorem with strong (naive) independence assumptions between features.
297
298
```scala { .api }
299
/**
 * Naive Bayes classifier.
 *
 * `modelType` selects the event model: "multinomial" (default), "complement",
 * "bernoulli", or "gaussian" (continuous features). `smoothing` sets the
 * additive (Laplace) smoothing used to avoid zero probabilities for the
 * discrete model types.
 */
class NaiveBayes extends ProbabilisticClassifier[Vector, NaiveBayes, NaiveBayesModel] {

  def setModelType(value: String): this.type
  def setSmoothing(value: Double): this.type

  // Instance weights
  def setWeightCol(value: String): this.type
}
308
309
/**
 * Fitted naive Bayes model.
 *
 *  - `pi`: log class priors, one entry per class.
 *  - `theta`: per-class, per-feature parameters (numClasses x numFeatures).
 *  - `sigma`: per-class feature variances, populated only for the Gaussian model type.
 */
class NaiveBayesModel
    extends ProbabilisticClassificationModel[Vector, NaiveBayesModel]
    with NaiveBayesParams
    with MLWritable {

  // Learned parameters
  def pi: Vector
  def theta: Matrix
  def sigma: Matrix

  // Dimensions
  def numClasses: Int
  def numFeatures: Int
}
316
```
317
318
### One-vs-Rest Classifier
319
320
Meta-classifier for extending binary classifiers to multiclass problems.
321
322
```scala { .api }
323
/**
 * One-vs-Rest reduction of multiclass classification to binary classification.
 * For N classes it trains N copies of the base classifier, each separating one
 * class from all the others.
 */
class OneVsRest extends Estimator[OneVsRestModel] with OneVsRestParams with MLWritable {

  /** Base binary classifier that is trained once per class. */
  def setClassifier(value: Classifier[_, _, _]): this.type

  // Column names
  def setFeaturesCol(value: String): this.type
  def setLabelCol(value: String): this.type
  def setPredictionCol(value: String): this.type
  def setWeightCol(value: String): this.type

  /** How many per-class models may be trained concurrently. */
  def setParallelism(value: Int): this.type
}
335
336
/** Fitted One-vs-Rest model holding one binary model per class. */
class OneVsRestModel extends Model[OneVsRestModel] with OneVsRestParams with MLWritable {

  def models: Array[_ <: ClassificationModel[_, _]]

  // Dimensions
  def numClasses: Int
  def numFeatures: Int
}
341
```
342
343
### Factorization Machine Classifier
344
345
Factorization machines for binary classification, modeling pairwise feature interactions through low-rank factorized weights.
346
347
```scala { .api }
348
/**
 * Factorization machine classifier. Pairwise feature interactions are modeled
 * through low-rank factor vectors, which keeps the model effective on sparse data.
 */
class FMClassifier extends ProbabilisticClassifier[Vector, FMClassifier, FMClassificationModel] {

  // Model form
  def setFactorSize(value: Int): this.type
  def setFitIntercept(value: Boolean): this.type
  def setFitLinear(value: Boolean): this.type
  def setInitStd(value: Double): this.type

  // Optimizer control and regularization
  def setSolver(value: String): this.type
  def setMaxIter(value: Int): this.type
  def setStepSize(value: Double): this.type
  def setTol(value: Double): this.type
  def setMiniBatchFraction(value: Double): this.type
  def setRegParam(value: Double): this.type

  // Reproducibility
  def setSeed(value: Long): this.type
}
365
366
/** Fitted factorization machine: global bias, linear weights and interaction factors. */
class FMClassificationModel
    extends ProbabilisticClassificationModel[Vector, FMClassificationModel]
    with FMClassifierParams
    with MLWritable {

  def intercept: Double
  def linear: Vector
  def factors: Matrix
}
371
```
372
373
## Base Classes and Traits
374
375
```scala { .api }
376
// Core classification abstractions

/**
 * Base class for classifier estimators.
 *
 * @tparam FeaturesType type of one input example's features
 * @tparam Learner      the concrete classifier type (F-bounded)
 * @tparam M            the model type this classifier produces
 */
abstract class Classifier[
    FeaturesType,
    Learner <: Classifier[FeaturesType, Learner, M],
    M <: ClassificationModel[FeaturesType, M]]
  extends Predictor[FeaturesType, Learner, M]
  with ClassifierParams
382
383
/** Base class for fitted classifiers. */
abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[FeaturesType, M]]
  extends PredictionModel[FeaturesType, M]
  with ClassifierParams {

  /** Number of distinct labels the model can predict. */
  def numClasses: Int

  /** Raw per-class scores (confidence in each label) for one example. */
  def predictRaw(features: FeaturesType): Vector
}
388
389
/**
 * Base class for classifier estimators whose models also output class
 * probabilities.
 */
abstract class ProbabilisticClassifier[
    FeaturesType,
    Learner <: ProbabilisticClassifier[FeaturesType, Learner, M],
    M <: ProbabilisticClassificationModel[FeaturesType, M]]
  extends Classifier[FeaturesType, Learner, M]
  with ProbabilisticClassifierParams
394
395
/**
 * Base class for fitted classifiers that also output class probabilities.
 *
 * The output column names come from the mixed-in param traits:
 * `probabilityCol` and `rawPredictionCol` are `Param[String]` members, and their
 * current values are read with `getProbabilityCol` / `getRawPredictionCol`.
 * They are therefore not redeclared here as `String` accessors, which would
 * conflict with those params.
 */
abstract class ProbabilisticClassificationModel[FeaturesType, M <: ProbabilisticClassificationModel[FeaturesType, M]]
  extends ClassificationModel[FeaturesType, M]
  with ProbabilisticClassifierParams {

  /** Per-class probability vector for one example. */
  def predictProbability(features: FeaturesType): Vector
}
401
402
// Parameter traits

/** Params shared by every classifier: the predictor columns plus the raw-prediction column. */
trait ClassifierParams extends PredictorParams with HasRawPredictionCol

/**
 * Params of probabilistic classifiers: adds the probability column and the
 * per-class `thresholds` used to turn probabilities into a predicted label.
 */
trait ProbabilisticClassifierParams extends ClassifierParams with HasProbabilityCol with HasThresholds
405
```
406
407
## Parameter Types
408
409
```scala { .api }
410
// Common parameter traits for classification

/** Mixin for a single binary decision threshold. */
trait HasThreshold extends Params {
  final val threshold: DoubleParam

  def getThreshold: Double
  def setThreshold(value: Double): this.type
}
416
417
/** Mixin for per-class decision thresholds (one entry per class). */
trait HasThresholds extends Params {
  final val thresholds: DoubleArrayParam

  def getThresholds: Array[Double]
  def setThresholds(value: Array[Double]): this.type
}
422
423
/** Mixin for the name of the raw-prediction (per-class score) output column. */
trait HasRawPredictionCol extends Params {
  final val rawPredictionCol: Param[String]

  def getRawPredictionCol: String
  def setRawPredictionCol(value: String): this.type
}
428
429
/** Mixin for the name of the class-probability output column. */
trait HasProbabilityCol extends Params {
  final val probabilityCol: Param[String]

  def getProbabilityCol: String
  def setProbabilityCol(value: String): this.type
}
434
435
// Tree-specific parameters

/** Params controlling the shape and split criteria of a single decision tree. */
trait DecisionTreeParams extends Params {
  // Size limits
  final val maxDepth: IntParam
  final val maxBins: IntParam
  final val minInstancesPerNode: IntParam

  // Split quality
  final val minInfoGain: DoubleParam
  final val impurity: Param[String]
}
443
444
/** Extra params for tree ensembles: row subsampling and per-split feature subsetting. */
trait TreeEnsembleParams extends DecisionTreeParams {
  final val subsamplingRate: DoubleParam
  final val featureSubsetStrategy: Param[String]
}
448
```
449
450
## Common Usage Patterns
451
452
### Basic Classification Workflow
453
454
```scala
455
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator

// 1. Configure a 20-tree random forest
val rf = new RandomForestClassifier()
  .setNumTrees(20)
  .setFeaturesCol("features")
  .setLabelCol("label")

// 2. Fit on the training data, then score the test data
val model = rf.fit(trainingData)
val predictions = model.transform(testData)

// 3. Compare the "prediction" column with "label" and report accuracy
val evaluator = new MulticlassClassificationEvaluator()
  .setMetricName("accuracy")
  .setPredictionCol("prediction")
  .setLabelCol("label")

val accuracy = evaluator.evaluate(predictions)
println(s"Test Accuracy = $accuracy")
478
```
479
480
### Cross-Validation with Parameter Tuning
481
482
```scala
483
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

// Reuses `rf`, `evaluator` and `trainingData` from the workflow above.
// 3 x 3 grid over forest size and tree depth.
val paramGrid = new ParamGridBuilder()
  .addGrid(rf.numTrees, Array(10, 20, 30))
  .addGrid(rf.maxDepth, Array(5, 10, 15))
  .build()

// 3-fold cross-validation over every grid point, scored by `evaluator`.
val cv = new CrossValidator()
  .setEstimator(rf)
  .setEstimatorParamMaps(paramGrid)
  .setEvaluator(evaluator)
  .setNumFolds(3)

val cvModel = cv.fit(trainingData)
val bestModel = cvModel.bestModel
498
```
499
500
### Accessing Model Information
501
502
```scala
503
import org.apache.spark.ml.classification.{LogisticRegressionModel, RandomForestClassificationModel}

// For tree-based models: match on the concrete type instead of casting, so an
// unexpected model type is reported rather than throwing ClassCastException.
model match {
  case treeModel: RandomForestClassificationModel =>
    println(s"Number of trees: ${treeModel.numTrees}")
    println(s"Feature importances: ${treeModel.featureImportances}")
  case other =>
    println(s"Not a random forest model: ${other.getClass.getSimpleName}")
}

// For linear models (`lrModel` is assumed to be a fitted logistic regression model).
// coefficientMatrix / interceptVector work for both binomial and multinomial
// models, whereas coefficients / intercept throw for multinomial ones.
lrModel match {
  case linearModel: LogisticRegressionModel =>
    println(s"Coefficients: ${linearModel.coefficientMatrix}")
    println(s"Intercept: ${linearModel.interceptVector}")
  case other =>
    println(s"Not a logistic regression model: ${other.getClass.getSimpleName}")
}

// For probabilistic models: raw scores and class probabilities sit next to the prediction
val probPredictions = model.transform(testData)
  .select("features", "label", "rawPrediction", "probability", "prediction")
516
```