0
# Classification
1
2
MLlib provides comprehensive classification algorithms for supervised learning tasks with categorical labels. All classifiers follow the Estimator-Transformer pattern and support the Pipeline API.
3
4
## Logistic Regression
5
6
### Estimator
7
8
```scala { .api }
9
class LogisticRegression(override val uid: String) extends Classifier[Vector, LogisticRegression, LogisticRegressionModel]
10
with LogisticRegressionParams with DefaultParamsWritable {
11
12
def this() = this(Identifiable.randomUID("logreg"))
13
14
def setRegParam(value: Double): LogisticRegression
15
def setElasticNetParam(value: Double): LogisticRegression
16
def setMaxIter(value: Int): LogisticRegression
17
def setTol(value: Double): LogisticRegression
18
def setFitIntercept(value: Boolean): LogisticRegression
19
def setFamily(value: String): LogisticRegression
20
def setStandardization(value: Boolean): LogisticRegression
21
def setThreshold(value: Double): LogisticRegression
22
def setThresholds(value: Array[Double]): LogisticRegression
23
def setWeightCol(value: String): LogisticRegression
24
def setAggregationDepth(value: Int): LogisticRegression
25
26
// Bounded optimization methods (expert parameters)
27
def setLowerBoundsOnCoefficients(value: Matrix): LogisticRegression
28
def setUpperBoundsOnCoefficients(value: Matrix): LogisticRegression
29
def setLowerBoundsOnIntercepts(value: Vector): LogisticRegression
30
def setUpperBoundsOnIntercepts(value: Vector): LogisticRegression
31
32
override def fit(dataset: Dataset[_]): LogisticRegressionModel
33
override def copy(extra: ParamMap): LogisticRegression
34
}
35
```
36
37
### Model
38
39
```scala { .api }
40
class LogisticRegressionModel private[spark] (
41
override val uid: String,
42
val coefficientMatrix: Matrix,
43
val interceptVector: Vector,
44
override val numClasses: Int,
45
private val isMultinomial: Boolean)
46
extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel] with LogisticRegressionParams
47
with MLWritable {
48
49
// Convenience constructor for binary classification (deprecated)
50
private[spark] def this(uid: String, coefficients: Vector, intercept: Double) =
51
this(uid, new DenseMatrix(1, coefficients.size, coefficients.toArray, isTransposed = true),
52
Vectors.dense(intercept), 2, isMultinomial = false)
53
54
// Model coefficients for binary classification (throws exception if multinomial)
55
def coefficients: Vector
56
def intercept: Double
57
58
lazy val summary: LogisticRegressionTrainingSummary
59
def hasSummary: Boolean
60
def binarySummary: BinaryLogisticRegressionTrainingSummary
61
def evaluate(dataset: Dataset[_]): LogisticRegressionSummary
62
63
override def predict(features: Vector): Double
64
override def predictRaw(features: Vector): Vector
65
override def predictProbability(features: Vector): Vector
66
override def transform(dataset: Dataset[_]): DataFrame
67
override def copy(extra: ParamMap): LogisticRegressionModel
68
def write: MLWriter
69
}
70
71
class LogisticRegressionTrainingSummary(predictions: DataFrame, predictionCol: String,
72
labelCol: String, featuresCol: String,
73
val objectiveHistory: Array[Double])
74
extends LogisticRegressionSummary(predictions, predictionCol, labelCol, featuresCol) {
75
76
def totalIterations: Int
77
}
78
79
// Evaluation metrics for a logistic regression model over a predictions DataFrame.
// Constructor arguments are forwarded to the ClassificationSummary base class,
// matching how LogisticRegressionTrainingSummary forwards them to this class.
class LogisticRegressionSummary(predictions: DataFrame, predictionCol: String,
    labelCol: String, featuresCol: String)
  extends ClassificationSummary(predictions, predictionCol, labelCol, featuresCol) {

  def probabilityCol: String

  // Threshold-based curves and metrics (same members as BinaryClassificationSummary)
  def fMeasureByThreshold: DataFrame
  def precisionByThreshold: DataFrame
  def recallByThreshold: DataFrame
  def roc: DataFrame
  def areaUnderROC: Double
  def pr: DataFrame

  // Per-label metrics: one Double per label, the same type ClassificationSummary declares
  def fMeasureByLabel: Array[Double]
  def precisionByLabel: Array[Double]
  def recallByLabel: Array[Double]
}
93
```
94
95
## Decision Tree Classifier
96
97
### Estimator
98
99
```scala { .api }
100
class DecisionTreeClassifier(override val uid: String)
101
extends Classifier[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel]
102
with DecisionTreeClassifierParams with DefaultParamsWritable {
103
104
def this() = this(Identifiable.randomUID("dtc"))
105
106
def setMaxDepth(value: Int): DecisionTreeClassifier
107
def setMaxBins(value: Int): DecisionTreeClassifier
108
def setMinInstancesPerNode(value: Int): DecisionTreeClassifier
109
def setMinInfoGain(value: Double): DecisionTreeClassifier
110
def setMaxMemoryInMB(value: Int): DecisionTreeClassifier
111
def setCacheNodeIds(value: Boolean): DecisionTreeClassifier
112
def setCheckpointInterval(value: Int): DecisionTreeClassifier
113
def setImpurity(value: String): DecisionTreeClassifier
114
def setSeed(value: Long): DecisionTreeClassifier
115
116
override def fit(dataset: Dataset[_]): DecisionTreeClassificationModel
117
override def copy(extra: ParamMap): DecisionTreeClassifier
118
}
119
```
120
121
### Model
122
123
```scala { .api }
124
// Fitted single decision tree for classification.
// The parent is ProbabilisticClassificationModel because predictProbability is
// overridden below; DecisionTreeModel (not TreeEnsembleModel) is the single-tree
// trait that goes with rootNode, depth and numNodes.
class DecisionTreeClassificationModel(override val uid: String, val rootNode: Node,
    val numFeatures: Int, val numClasses: Int)
  extends ProbabilisticClassificationModel[Vector, DecisionTreeClassificationModel]
  with DecisionTreeClassifierParams with DecisionTreeModel with MLWritable {

  // Prediction
  override def predict(features: Vector): Double
  override def predictRaw(features: Vector): Vector
  override def predictProbability(features: Vector): Vector

  // Tree structure
  def depth: Int
  def numNodes: Int
  def toDebugString: String

  override def copy(extra: ParamMap): DecisionTreeClassificationModel
  def write: MLWriter
}
138
```
139
140
## Random Forest Classifier
141
142
### Estimator
143
144
```scala { .api }
145
class RandomForestClassifier(override val uid: String)
146
extends Classifier[Vector, RandomForestClassifier, RandomForestClassificationModel]
147
with RandomForestClassifierParams with DefaultParamsWritable {
148
149
def this() = this(Identifiable.randomUID("rfc"))
150
151
def setNumTrees(value: Int): RandomForestClassifier
152
def setMaxDepth(value: Int): RandomForestClassifier
153
def setMaxBins(value: Int): RandomForestClassifier
154
def setMinInstancesPerNode(value: Int): RandomForestClassifier
155
def setMinInfoGain(value: Double): RandomForestClassifier
156
def setMaxMemoryInMB(value: Int): RandomForestClassifier
157
def setCacheNodeIds(value: Boolean): RandomForestClassifier
158
def setCheckpointInterval(value: Int): RandomForestClassifier
159
def setImpurity(value: String): RandomForestClassifier
160
def setSubsamplingRate(value: Double): RandomForestClassifier
161
def setSeed(value: Long): RandomForestClassifier
162
def setFeatureSubsetStrategy(value: String): RandomForestClassifier
163
164
override def fit(dataset: Dataset[_]): RandomForestClassificationModel
165
override def copy(extra: ParamMap): RandomForestClassifier
166
}
167
```
168
169
### Model
170
171
```scala { .api }
172
// Fitted random forest for classification: an ensemble of decision tree classifiers.
// The parent is ProbabilisticClassificationModel because predictProbability is
// overridden below; TreeEnsembleModel is parameterized by the member tree type.
class RandomForestClassificationModel(override val uid: String,
    private val _trees: Array[DecisionTreeClassificationModel],
    val numFeatures: Int, val numClasses: Int)
  extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel]
  with RandomForestClassifierParams
  with TreeEnsembleModel[DecisionTreeClassificationModel] with MLWritable {

  // Ensemble members
  def trees: Array[DecisionTreeClassificationModel]
  def treeWeights: Array[Double]
  def featureImportances: Vector

  // Prediction
  override def predict(features: Vector): Double
  override def predictRaw(features: Vector): Vector
  override def predictProbability(features: Vector): Vector

  def totalNumNodes: Int
  def toDebugString: String
  override def copy(extra: ParamMap): RandomForestClassificationModel
  def write: MLWriter
}
189
```
190
191
## Gradient Boosted Tree Classifier
192
193
### Estimator
194
195
```scala { .api }
196
class GBTClassifier(override val uid: String)
197
extends Classifier[Vector, GBTClassifier, GBTClassificationModel]
198
with GBTClassifierParams with DefaultParamsWritable {
199
200
def this() = this(Identifiable.randomUID("gbtc"))
201
202
def setMaxDepth(value: Int): GBTClassifier
203
def setMaxBins(value: Int): GBTClassifier
204
def setMinInstancesPerNode(value: Int): GBTClassifier
205
def setMinInfoGain(value: Double): GBTClassifier
206
def setMaxMemoryInMB(value: Int): GBTClassifier
207
def setCacheNodeIds(value: Boolean): GBTClassifier
208
def setCheckpointInterval(value: Int): GBTClassifier
209
def setLossType(value: String): GBTClassifier
210
def setMaxIter(value: Int): GBTClassifier
211
def setStepSize(value: Double): GBTClassifier
212
def setSubsamplingRate(value: Double): GBTClassifier
213
def setSeed(value: Long): GBTClassifier
214
def setFeatureSubsetStrategy(value: String): GBTClassifier
215
def setValidationTol(value: Double): GBTClassifier
216
def setValidationIndicatorCol(value: String): GBTClassifier
217
218
override def fit(dataset: Dataset[_]): GBTClassificationModel
219
override def copy(extra: ParamMap): GBTClassifier
220
}
221
```
222
223
### Model
224
225
```scala { .api }
226
class GBTClassificationModel(override val uid: String, private val _trees: Array[DecisionTreeRegressionModel],
227
private val _treeWeights: Array[Double], val numFeatures: Int)
228
extends ClassificationModel[Vector, GBTClassificationModel]
229
with GBTClassifierParams with TreeEnsembleModel with MLWritable {
230
231
def trees: Array[DecisionTreeRegressionModel]
232
def treeWeights: Array[Double]
233
def featureImportances: Vector
234
def totalNumNodes: Int
235
def getNumTrees: Int
236
237
override def predict(features: Vector): Double
238
override def predictRaw(features: Vector): Vector
239
def toDebugString: String
240
override def copy(extra: ParamMap): GBTClassificationModel
241
def write: MLWriter
242
}
243
```
244
245
## Naive Bayes
246
247
### Estimator
248
249
```scala { .api }
250
class NaiveBayes(override val uid: String)
251
extends Classifier[Vector, NaiveBayes, NaiveBayesModel]
252
with NaiveBayesParams with DefaultParamsWritable {
253
254
def this() = this(Identifiable.randomUID("nb"))
255
256
def setModelType(value: String): NaiveBayes
257
def setSmoothing(value: Double): NaiveBayes
258
def setThresholds(value: Array[Double]): NaiveBayes
259
def setWeightCol(value: String): NaiveBayes
260
261
override def fit(dataset: Dataset[_]): NaiveBayesModel
262
override def copy(extra: ParamMap): NaiveBayes
263
}
264
```
265
266
### Model
267
268
```scala { .api }
269
// Fitted Naive Bayes model.
// The parent is ProbabilisticClassificationModel because predictProbability is
// overridden below.
class NaiveBayesModel(override val uid: String, val pi: Vector, val theta: Matrix,
    val sigma: Matrix)
  extends ProbabilisticClassificationModel[Vector, NaiveBayesModel]
  with NaiveBayesParams with MLWritable {

  val numFeatures: Int
  val numClasses: Int

  // Prediction
  override def predict(features: Vector): Double
  override def predictRaw(features: Vector): Vector
  override def predictProbability(features: Vector): Vector

  override def copy(extra: ParamMap): NaiveBayesModel
  def write: MLWriter
}
281
```
282
283
## Multilayer Perceptron Classifier
284
285
### Estimator
286
287
```scala { .api }
288
class MultilayerPerceptronClassifier(override val uid: String)
289
extends Classifier[Vector, MultilayerPerceptronClassifier, MultilayerPerceptronClassificationModel]
290
with MultilayerPerceptronParams with DefaultParamsWritable {
291
292
def this() = this(Identifiable.randomUID("mlpc"))
293
294
def setLayers(value: Array[Int]): MultilayerPerceptronClassifier
295
def setBlockSize(value: Int): MultilayerPerceptronClassifier
296
def setSeed(value: Long): MultilayerPerceptronClassifier
297
def setMaxIter(value: Int): MultilayerPerceptronClassifier
298
def setTol(value: Double): MultilayerPerceptronClassifier
299
def setStepSize(value: Double): MultilayerPerceptronClassifier
300
def setSolver(value: String): MultilayerPerceptronClassifier
301
def setInitialWeights(value: Vector): MultilayerPerceptronClassifier
302
303
override def fit(dataset: Dataset[_]): MultilayerPerceptronClassificationModel
304
override def copy(extra: ParamMap): MultilayerPerceptronClassifier
305
}
306
```
307
308
### Model
309
310
```scala { .api }
311
// Fitted multilayer perceptron classifier.
// The parent is ProbabilisticClassificationModel because predictProbability is
// overridden below.
class MultilayerPerceptronClassificationModel(override val uid: String,
    val layers: Array[Int], val weights: Vector)
  extends ProbabilisticClassificationModel[Vector, MultilayerPerceptronClassificationModel]
  with MultilayerPerceptronParams with MLWritable {

  val numFeatures: Int
  val numClasses: Int

  // Prediction
  override def predict(features: Vector): Double
  override def predictRaw(features: Vector): Vector
  override def predictProbability(features: Vector): Vector

  override def copy(extra: ParamMap): MultilayerPerceptronClassificationModel
  def write: MLWriter
}
325
```
326
327
## Linear Support Vector Classifier
328
329
### Estimator
330
331
```scala { .api }
332
class LinearSVC(override val uid: String)
333
extends Classifier[Vector, LinearSVC, LinearSVCModel]
334
with LinearSVCParams with DefaultParamsWritable {
335
336
def this() = this(Identifiable.randomUID("linearsvc"))
337
338
def setRegParam(value: Double): LinearSVC
339
def setMaxIter(value: Int): LinearSVC
340
def setTol(value: Double): LinearSVC
341
def setFitIntercept(value: Boolean): LinearSVC
342
def setStandardization(value: Boolean): LinearSVC
343
def setThreshold(value: Double): LinearSVC
344
def setWeightCol(value: String): LinearSVC
345
def setAggregationDepth(value: Int): LinearSVC
346
347
override def fit(dataset: Dataset[_]): LinearSVCModel
348
override def copy(extra: ParamMap): LinearSVC
349
}
350
```
351
352
### Model
353
354
```scala { .api }
355
class LinearSVCModel(override val uid: String, val coefficients: Vector, val intercept: Double)
356
extends ClassificationModel[Vector, LinearSVCModel] with LinearSVCParams with MLWritable {
357
358
val numClasses: Int = 2
359
val numFeatures: Int
360
361
override def predict(features: Vector): Double
362
override def predictRaw(features: Vector): Vector
363
override def copy(extra: ParamMap): LinearSVCModel
364
def write: MLWriter
365
}
366
```
367
368
## One-vs-Rest Meta-Classifier
369
370
### Estimator
371
372
```scala { .api }
373
class OneVsRest(override val uid: String)
374
extends Estimator[OneVsRestModel] with OneVsRestParams with MLWritable {
375
376
def this() = this(Identifiable.randomUID("oneVsRest"))
377
378
def setClassifier(value: Classifier[_, _, _]): OneVsRest
379
def setLabelCol(value: String): OneVsRest
380
def setFeaturesCol(value: String): OneVsRest
381
def setPredictionCol(value: String): OneVsRest
382
def setRawPredictionCol(value: String): OneVsRest
383
def setParallelism(value: Int): OneVsRest
384
385
override def fit(dataset: Dataset[_]): OneVsRestModel
386
override def copy(extra: ParamMap): OneVsRest
387
override def transformSchema(schema: StructType): StructType
388
def write: MLWriter
389
}
390
```
391
392
### Model
393
394
```scala { .api }
395
class OneVsRestModel(override val uid: String, private val labelMetadata: Metadata, val models: Array[_ <: ClassificationModel[_, _]])
396
extends Model[OneVsRestModel] with OneVsRestParams with MLWritable {
397
398
val numClasses: Int
399
400
override def transform(dataset: Dataset[_]): DataFrame
401
override def transformSchema(schema: StructType): StructType
402
override def copy(extra: ParamMap): OneVsRestModel
403
def write: MLWriter
404
}
405
```
406
407
## Classification Summary Classes
408
409
### Base Summary
410
411
```scala { .api }
412
// Base class for classification metrics computed from a predictions DataFrame.
// Members are declared as `def` because Scala does not allow an abstract `lazy val`;
// callers read them with the same syntax (e.g. `summary.accuracy`).
abstract class ClassificationSummary(predictions: DataFrame, predictionCol: String,
    labelCol: String, featuresCol: String) extends Serializable {

  // Aggregate metrics
  def accuracy: Double
  def weightedPrecision: Double
  def weightedRecall: Double
  def weightedFMeasure: Double
  // Alias: defined as weightedRecall
  def weightedTruePositiveRate: Double = weightedRecall
  def weightedFalsePositiveRate: Double

  // Per-label metrics
  def fMeasureByLabel(beta: Double = 1.0): Array[Double]
  def precisionByLabel: Array[Double]
  def recallByLabel: Array[Double]
  // Alias: defined as recallByLabel
  def truePositiveRateByLabel: Array[Double] = recallByLabel
  def falsePositiveRateByLabel: Array[Double]

  def labels: Array[Double]
}
430
```
431
432
### Binary Classification Summary
433
434
```scala { .api }
435
trait BinaryClassificationSummary extends ClassificationSummary {
436
def scoreCol: String
437
438
def roc: DataFrame
439
def areaUnderROC: Double
440
def pr: DataFrame
441
def fMeasureByThreshold: DataFrame
442
def precisionByThreshold: DataFrame
443
def recallByThreshold: DataFrame
444
}
445
```
446
447
## Usage Examples
448
449
### Basic Classification Pipeline
450
451
```scala
452
import org.apache.spark.ml.Pipeline
453
import org.apache.spark.ml.classification.{LogisticRegression, RandomForestClassifier}
454
import org.apache.spark.ml.feature.{VectorAssembler, StringIndexer}
455
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
456
457
// Prepare features
458
val assembler = new VectorAssembler()
459
.setInputCols(Array("feature1", "feature2", "feature3"))
460
.setOutputCol("features")
461
462
// Index labels if they are strings
463
val labelIndexer = new StringIndexer()
464
.setInputCol("label")
465
.setOutputCol("indexedLabel")
466
467
// Create classifier
468
val lr = new LogisticRegression()
469
.setLabelCol("indexedLabel")
470
.setFeaturesCol("features")
471
.setMaxIter(100)
472
.setRegParam(0.1)
473
474
// Create pipeline
475
val pipeline = new Pipeline()
476
.setStages(Array(labelIndexer, assembler, lr))
477
478
// Split data
479
val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))
480
481
// Train model
482
val model = pipeline.fit(trainingData)
483
484
// Make predictions
485
val predictions = model.transform(testData)
486
487
// Evaluate
488
val evaluator = new BinaryClassificationEvaluator()
489
.setLabelCol("indexedLabel")
490
.setRawPredictionCol("rawPrediction")
491
.setMetricName("areaUnderROC")
492
493
val auc = evaluator.evaluate(predictions)
494
println(s"Area under ROC curve: $auc")
495
```
496
497
### Multiclass Classification with Random Forest
498
499
```scala
500
import org.apache.spark.ml.classification.{RandomForestClassifier, RandomForestClassificationModel}
501
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
502
503
val rf = new RandomForestClassifier()
504
.setLabelCol("indexedLabel")
505
.setFeaturesCol("features")
506
.setNumTrees(100)
507
.setMaxDepth(10)
508
.setMaxBins(32)
509
.setMinInstancesPerNode(1)
510
.setMinInfoGain(0.0)
511
.setSubsamplingRate(1.0)
512
.setFeatureSubsetStrategy("auto")
513
.setSeed(42)
514
515
val rfModel = rf.fit(trainingData)
516
517
// Get feature importances
518
val featureImportances = rfModel.featureImportances
519
println(s"Feature importances: $featureImportances")
520
521
val predictions = rfModel.transform(testData)
522
523
// Evaluate multiclass metrics
524
val evaluator = new MulticlassClassificationEvaluator()
525
.setLabelCol("indexedLabel")
526
.setPredictionCol("prediction")
527
528
val metrics = Array("accuracy", "weightedPrecision", "weightedRecall", "f1")
529
metrics.foreach { metric =>
530
evaluator.setMetricName(metric)
531
val result = evaluator.evaluate(predictions)
532
println(s"$metric: $result")
533
}
534
```
535
536
### Neural Network Classification
537
538
```scala
539
import org.apache.spark.ml.classification.MultilayerPerceptronClassifier
540
541
// Define network architecture: input layer (4 features) -> 2 hidden layers (5, 4 nodes) -> output layer (3 classes)
542
val layers = Array[Int](4, 5, 4, 3)
543
544
val mlp = new MultilayerPerceptronClassifier()
545
.setLayers(layers)
546
.setBlockSize(128)
547
.setSeed(1234L)
548
.setMaxIter(100)
549
.setStepSize(0.03)
550
.setSolver("l-bfgs")
551
552
val mlpModel = mlp.fit(trainingData)
553
val mlpPredictions = mlpModel.transform(testData)
554
555
// Show predictions
556
mlpPredictions.select("features", "label", "prediction", "probability").show(20)
557
```
558
559
### One-vs-Rest Multi-class Classification
560
561
```scala
562
import org.apache.spark.ml.classification.{OneVsRest, LogisticRegression}
563
564
// Create base classifier
565
val classifier = new LogisticRegression()
566
.setMaxIter(10)
567
.setTol(1E-6)
568
.setFitIntercept(true)
569
570
// Create One-vs-Rest wrapper
571
val ovr = new OneVsRest()
572
.setClassifier(classifier)
573
.setLabelCol("label")
574
.setFeaturesCol("features")
575
576
val ovrModel = ovr.fit(trainingData)
577
val ovrPredictions = ovrModel.transform(testData)
578
579
// The model contains one binary classifier per class
580
println(s"Number of classes: ${ovrModel.models.length}")
581
```
582
583
### Model Summary and Metrics
584
585
```scala
586
import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
587
588
val lr = new LogisticRegression()
589
.setMaxIter(100)
590
.setRegParam(0.01)
591
.setElasticNetParam(0.0)
592
593
val lrModel = lr.fit(trainingData)
594
595
// Get training summary
596
val summary = lrModel.summary
597
println(s"Total iterations: ${summary.totalIterations}")
598
println(s"Objective history: ${summary.objectiveHistory.mkString(", ")}")
599
600
// Binary classification metrics (if binary classification)
601
// A single pattern match replaces the isInstanceOf/asInstanceOf pair: the type test
// and the typed binding happen together, so there is no unchecked cast.
summary match {
  case binarySummary: org.apache.spark.ml.classification.BinaryLogisticRegressionSummary =>
    println(s"Area Under ROC: ${binarySummary.areaUnderROC}")

    // Show ROC curve points
    binarySummary.roc.show()

    // Show precision-recall curve
    binarySummary.pr.show()

  case _ =>
    // Not a binary summary: nothing to print here, same as the original `if` with no else
    ()
}
611
612
// Evaluate on test data
613
val testSummary = lrModel.evaluate(testData)
614
println(s"Test Accuracy: ${testSummary.accuracy}")
615
println(s"Test Weighted Precision: ${testSummary.weightedPrecision}")
616
println(s"Test Weighted Recall: ${testSummary.weightedRecall}")
617
```