0
# Classification
1
2
Supervised learning algorithms for predicting categorical outcomes, including binary and multiclass classification with probabilistic predictions and comprehensive model evaluation.
3
4
## Capabilities
5
6
### Logistic Regression
7
8
Linear classification algorithm using the logistic function for binary and multiclass problems, with L1, L2, and elastic-net regularization support.
9
10
```scala { .api }
11
/**
12
* Logistic regression classifier with regularization support
13
*/
14
class LogisticRegression extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel] {
15
def setMaxIter(value: Int): this.type
16
def setRegParam(value: Double): this.type
17
def setElasticNetParam(value: Double): this.type
18
def setTol(value: Double): this.type
19
def setFitIntercept(value: Boolean): this.type
20
def setStandardization(value: Boolean): this.type
21
def setThreshold(value: Double): this.type
22
def setThresholds(value: Array[Double]): this.type
23
def setWeightCol(value: String): this.type
24
def setAggregationDepth(value: Int): this.type
25
def setFamily(value: String): this.type
26
def setLowerBoundsOnCoefficients(value: Matrix): this.type
27
def setUpperBoundsOnCoefficients(value: Matrix): this.type
28
def setLowerBoundsOnIntercepts(value: Vector): this.type
29
def setUpperBoundsOnIntercepts(value: Vector): this.type
30
def setMaxBlockSizeInMB(value: Double): this.type
31
def setInitialModel(model: LogisticRegressionModel): this.type
32
}
33
34
class LogisticRegressionModel extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel] with LogisticRegressionParams {
35
def coefficients: Vector
36
def intercept: Double
37
def coefficientMatrix: Matrix
38
def interceptVector: Vector
39
def summary: LogisticRegressionTrainingSummary
40
def hasSummary: Boolean
41
def evaluate(dataset: Dataset[_]): LogisticRegressionSummary
42
}
43
44
sealed trait LogisticRegressionSummary {
45
def predictions: DataFrame
46
def probabilityCol: String
47
def labelCol: String
48
def featuresCol: String
49
def predictionCol: String
50
}
51
52
sealed trait BinaryLogisticRegressionSummary extends LogisticRegressionSummary {
53
def areaUnderROC: Double
54
def roc: DataFrame
55
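// Area under the PR curve is not a summary member; compute it with BinaryClassificationEvaluator (metricName "areaUnderPR")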
56
def pr: DataFrame
57
def fMeasureByThreshold: DataFrame
58
def precisionByThreshold: DataFrame
59
def recallByThreshold: DataFrame
60
}
61
```
62
63
**Usage Example:**
64
65
```scala
66
import org.apache.spark.ml.classification.LogisticRegression
67
68
val lr = new LogisticRegression()
69
.setMaxIter(20)
70
.setRegParam(0.3)
71
.setElasticNetParam(0.8)
72
.setFamily("binomial")
73
74
val lrModel = lr.fit(trainingData)
75
val predictions = lrModel.transform(testData)
76
77
// Access model coefficients
78
println(s"Coefficients: ${lrModel.coefficients}")
79
println(s"Intercept: ${lrModel.intercept}")
80
81
// Get training summary
82
val trainingSummary = lrModel.summary
83
println(s"Number of iterations: ${trainingSummary.totalIterations}")
84
```
85
86
### Decision Tree Classifier
87
88
Tree-based classifier using recursive binary splits with support for categorical and continuous features.
89
90
```scala { .api }
91
/**
92
* Decision tree classifier with configurable splitting criteria
93
*/
94
class DecisionTreeClassifier extends Classifier[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel] {
95
def setMaxDepth(value: Int): this.type
96
def setMaxBins(value: Int): this.type
97
def setMinInstancesPerNode(value: Int): this.type
98
def setMinInfoGain(value: Double): this.type
99
def setMaxMemoryInMB(value: Int): this.type
100
def setCacheNodeIds(value: Boolean): this.type
101
def setCheckpointInterval(value: Int): this.type
102
def setImpurity(value: String): this.type
103
def setSeed(value: Long): this.type
104
}
105
106
class DecisionTreeClassificationModel extends ClassificationModel[Vector, DecisionTreeClassificationModel] with DecisionTreeClassifierParams {
107
def rootNode: Node
108
def depth: Int
109
def numNodes: Int
110
def toDebugString: String
111
def featureImportances: Vector
112
}
113
114
abstract class Node extends Serializable {
115
def prediction: Double
116
def impurity: Double
117
def impurityStats: ImpurityStats
118
def isLeaf: Boolean
119
def deepCopy(): Node
120
}
121
```
122
123
**Usage Example:**
124
125
```scala
126
import org.apache.spark.ml.classification.DecisionTreeClassifier
127
128
val dt = new DecisionTreeClassifier()
129
.setLabelCol("indexedLabel")
130
.setFeaturesCol("indexedFeatures")
131
.setMaxDepth(5)
132
.setMaxBins(32)
133
.setMinInstancesPerNode(1)
134
.setMinInfoGain(0.0)
135
.setImpurity("gini")
136
137
val dtModel = dt.fit(trainingData)
138
val predictions = dtModel.transform(testData)
139
140
// Print the learned classification tree model
141
println(s"Learned classification tree model:\n ${dtModel.toDebugString}")
142
143
// Get feature importances
144
println(s"Feature importances: ${dtModel.featureImportances}")
145
```
146
147
### Random Forest Classifier
148
149
Ensemble method combining multiple decision trees with bootstrap aggregating and random feature selection.
150
151
```scala { .api }
152
/**
153
* Random Forest classifier using ensemble of decision trees
154
*/
155
class RandomForestClassifier extends Classifier[Vector, RandomForestClassifier, RandomForestClassificationModel] {
156
def setNumTrees(value: Int): this.type
157
def setMaxDepth(value: Int): this.type
158
def setMaxBins(value: Int): this.type
159
def setMinInstancesPerNode(value: Int): this.type
160
def setMinInfoGain(value: Double): this.type
161
def setMaxMemoryInMB(value: Int): this.type
162
def setCacheNodeIds(value: Boolean): this.type
163
def setCheckpointInterval(value: Int): this.type
164
def setImpurity(value: String): this.type
165
def setSubsamplingRate(value: Double): this.type
166
def setSeed(value: Long): this.type
167
def setFeatureSubsetStrategy(value: String): this.type
168
}
169
170
class RandomForestClassificationModel extends ClassificationModel[Vector, RandomForestClassificationModel] with RandomForestClassifierParams {
171
def trees: Array[DecisionTreeClassificationModel]
172
def treeWeights: Array[Double]
173
def numFeatures: Int
174
def totalNumNodes: Int
175
def toDebugString: String
176
def featureImportances: Vector
177
}
178
```
179
180
**Usage Example:**
181
182
```scala
183
import org.apache.spark.ml.classification.RandomForestClassifier
184
185
val rf = new RandomForestClassifier()
186
.setLabelCol("indexedLabel")
187
.setFeaturesCol("indexedFeatures")
188
.setNumTrees(20)
189
.setMaxDepth(5)
190
.setMaxBins(32)
191
.setFeatureSubsetStrategy("auto")
192
193
val rfModel = rf.fit(trainingData)
194
val predictions = rfModel.transform(testData)
195
196
// Print feature importances
197
println(s"Feature importances: ${rfModel.featureImportances}")
198
199
// Access individual trees
200
println(s"Number of trees: ${rfModel.trees.length}")
201
```
202
203
### Gradient Boosted Tree Classifier
204
205
Ensemble method that builds models sequentially where each new model corrects errors from previous models.
206
207
```scala { .api }
208
/**
209
* Gradient-boosted tree classifier for binary classification
210
*/
211
class GBTClassifier extends Classifier[Vector, GBTClassifier, GBTClassificationModel] {
212
def setLossType(value: String): this.type
213
def setMaxIter(value: Int): this.type
214
def setStepSize(value: Double): this.type
215
def setMaxDepth(value: Int): this.type
216
def setMaxBins(value: Int): this.type
217
def setMinInstancesPerNode(value: Int): this.type
218
def setMinInfoGain(value: Double): this.type
219
def setMaxMemoryInMB(value: Int): this.type
220
def setCacheNodeIds(value: Boolean): this.type
221
def setCheckpointInterval(value: Int): this.type
222
def setImpurity(value: String): this.type
223
def setSubsamplingRate(value: Double): this.type
224
def setSeed(value: Long): this.type
225
def setFeatureSubsetStrategy(value: String): this.type
226
def setValidationTol(value: Double): this.type
227
def setValidationIndicatorCol(value: String): this.type
228
}
229
230
class GBTClassificationModel extends ClassificationModel[Vector, GBTClassificationModel] with GBTClassifierParams {
231
def trees: Array[DecisionTreeRegressionModel]
232
def treeWeights: Array[Double]
233
def numFeatures: Int
234
def totalNumNodes: Int
235
def toDebugString: String
236
def featureImportances: Vector
237
}
238
```
239
240
### Support Vector Machine
241
242
Linear support vector classifier with L2 regularization for binary classification problems.
243
244
```scala { .api }
245
/**
246
* Linear Support Vector Machine classifier
247
*/
248
class LinearSVC extends Classifier[Vector, LinearSVC, LinearSVCModel] {
249
def setRegParam(value: Double): this.type
250
def setMaxIter(value: Int): this.type
251
def setTol(value: Double): this.type
252
def setFitIntercept(value: Boolean): this.type
253
def setStandardization(value: Boolean): this.type
254
def setThreshold(value: Double): this.type
255
def setWeightCol(value: String): this.type
256
def setAggregationDepth(value: Int): this.type
257
}
258
259
class LinearSVCModel extends ClassificationModel[Vector, LinearSVCModel] with LinearSVCParams {
260
def coefficients: Vector
261
def intercept: Double
262
}
263
```
264
265
### Naive Bayes
266
267
Probabilistic classifier based on Bayes' theorem with a naive independence assumption between features.
268
269
```scala { .api }
270
/**
271
* Naive Bayes classifier with multiple model types
272
*/
273
class NaiveBayes extends ProbabilisticClassifier[Vector, NaiveBayes, NaiveBayesModel] {
274
def setModelType(value: String): this.type
275
def setSmoothing(value: Double): this.type
276
def setThresholds(value: Array[Double]): this.type
277
def setWeightCol(value: String): this.type
278
}
279
280
class NaiveBayesModel extends ProbabilisticClassificationModel[Vector, NaiveBayesModel] with NaiveBayesParams {
281
def pi: Vector
282
def theta: Matrix
283
def sigma: Matrix
284
def numFeatures: Int
285
def numClasses: Int
286
}
287
```
288
289
### Neural Network Classifier
290
291
Multilayer perceptron classifier using backpropagation for training feed-forward neural networks.
292
293
```scala { .api }
294
/**
295
* Multilayer perceptron classifier
296
*/
297
class MultilayerPerceptronClassifier extends Classifier[Vector, MultilayerPerceptronClassifier, MultilayerPerceptronClassificationModel] {
298
def setLayers(value: Array[Int]): this.type
299
def setBlockSize(value: Int): this.type
300
def setSolver(value: String): this.type
301
def setMaxIter(value: Int): this.type
302
def setTol(value: Double): this.type
303
def setSeed(value: Long): this.type
304
def setInitialWeights(value: Vector): this.type
305
def setStepSize(value: Double): this.type
306
}
307
308
class MultilayerPerceptronClassificationModel extends ClassificationModel[Vector, MultilayerPerceptronClassificationModel] with MultilayerPerceptronClassifierParams {
309
def layers: Array[Int]
310
def weights: Vector
311
}
312
```
313
314
### One-vs-Rest Strategy
315
316
Meta-algorithm that enables binary classifiers to handle multiclass problems by training one classifier per class.
317
318
```scala { .api }
319
/**
320
* One-vs-Rest multiclass classification strategy
321
*/
322
class OneVsRest extends Estimator[OneVsRestModel] with OneVsRestParams {
323
def setClassifier(value: Classifier[_, _, _]): this.type
324
def setLabelCol(value: String): this.type
325
def setFeaturesCol(value: String): this.type
326
def setPredictionCol(value: String): this.type
327
def setRawPredictionCol(value: String): this.type
328
def setParallelism(value: Int): this.type
329
def setWeightCol(value: String): this.type
330
}
331
332
class OneVsRestModel extends Model[OneVsRestModel] with OneVsRestParams {
333
def models: Array[_ <: ClassificationModel[_, _]]
334
def labelMetadata: Metadata
335
}
336
```
337
338
### Factorization Machine Classifier
339
340
Factorization machine classifier that efficiently models interactions between features.
341
342
```scala { .api }
343
/**
344
* Factorization Machine classifier for binary classification
345
*/
346
class FMClassifier extends Classifier[Vector, FMClassifier, FMClassificationModel] {
347
def setFactorSize(value: Int): this.type
348
def setFitIntercept(value: Boolean): this.type
349
def setFitLinear(value: Boolean): this.type
350
def setRegParam(value: Double): this.type
351
def setMiniBatchFraction(value: Double): this.type
352
def setInitStd(value: Double): this.type
353
def setMaxIter(value: Int): this.type
354
def setStepSize(value: Double): this.type
355
def setTol(value: Double): this.type
356
def setSolver(value: String): this.type
357
def setThreshold(value: Double): this.type
358
def setSeed(value: Long): this.type
359
}
360
361
class FMClassificationModel extends ClassificationModel[Vector, FMClassificationModel] with FMClassifierParams {
362
def intercept: Double
363
def linear: Vector
364
def factors: Matrix
365
}
366
```
367
368
## Shared Classification Components
369
370
### Base Classes and Traits
371
372
```scala { .api }
373
/**
374
* Base classifier abstraction
375
*/
376
abstract class Classifier[
377
FeaturesType,
378
E <: Classifier[FeaturesType, E, M],
379
M <: ClassificationModel[FeaturesType, M]
380
] extends Estimator[M] with ClassifierParams {
381
def fit(dataset: Dataset[_]): M
382
}
383
384
/**
385
* Base classification model
386
*/
387
abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[FeaturesType, M]]
388
extends Model[M] with ClassifierParams {
389
def numClasses: Int
390
def predictRaw(features: FeaturesType): Vector
391
def getRawPredictionCol: String
392
}
393
394
/**
395
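* Probabilistic classifier with probability predictions (the base of LogisticRegression, DecisionTreeClassifier, RandomForestClassifier, GBTClassifier, NaiveBayes, MultilayerPerceptronClassifier, and FMClassifier)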
396
*/
397
abstract class ProbabilisticClassifier[
398
FeaturesType,
399
E <: ProbabilisticClassifier[FeaturesType, E, M],
400
M <: ProbabilisticClassificationModel[FeaturesType, M]
401
] extends Classifier[FeaturesType, E, M] with ProbabilisticClassifierParams
402
403
/**
404
* Probabilistic classification model
405
*/
406
abstract class ProbabilisticClassificationModel[FeaturesType, M <: ProbabilisticClassificationModel[FeaturesType, M]]
407
extends ClassificationModel[FeaturesType, M] with ProbabilisticClassifierParams {
408
def predictProbability(features: FeaturesType): Vector
409
def getProbabilityCol: String
410
}
411
```
412
413
## Types
414
415
```scala { .api }
416
// Classification-specific imports
417
import org.apache.spark.ml.classification._
418
import org.apache.spark.ml.linalg.{Vector, Matrix}
419
import org.apache.spark.sql.{DataFrame, Dataset}
420
421
// Parameter traits
422
import org.apache.spark.ml.param.shared._
423
424
// Model summary types
425
import org.apache.spark.ml.classification.{
426
LogisticRegressionSummary,
427
BinaryLogisticRegressionSummary,
428
MulticlassLogisticRegressionSummary
429
}
430
431
// Tree model components
432
import org.apache.spark.ml.tree.{Node, InternalNode, LeafNode}
433
import org.apache.spark.mllib.tree.impurity.{Gini, Entropy} // ImpurityStats lives in org.apache.spark.mllib.tree.model and is not part of the public API
434
```