0
# Model Evaluation
1
2
MLlib provides comprehensive evaluation tools for assessing model performance including metrics for classification, regression, and clustering tasks, along with model validation techniques like cross-validation and train-validation split.
3
4
## Evaluators
5
6
### BinaryClassificationEvaluator
7
8
```scala { .api }
9
class BinaryClassificationEvaluator(override val uid: String) extends Evaluator
10
with HasLabelCol with HasRawPredictionCol with DefaultParamsWritable {
11
12
def this() = this(Identifiable.randomUID("binEval"))
13
14
final val metricName: Param[String]
15
16
def setMetricName(value: String): BinaryClassificationEvaluator
17
def setLabelCol(value: String): BinaryClassificationEvaluator
18
def setRawPredictionCol(value: String): BinaryClassificationEvaluator
19
20
override def evaluate(dataset: Dataset[_]): Double
21
override def isLargerBetter: Boolean
22
override def copy(extra: ParamMap): BinaryClassificationEvaluator
23
}
24
```
25
26
### MulticlassClassificationEvaluator
27
28
```scala { .api }
29
class MulticlassClassificationEvaluator(override val uid: String) extends Evaluator
30
with HasLabelCol with HasPredictionCol with DefaultParamsWritable {
31
32
def this() = this(Identifiable.randomUID("mcEval"))
33
34
final val metricName: Param[String]
35
final val metricLabel: DoubleParam
36
final val beta: DoubleParam
37
final val eps: DoubleParam
38
39
def setMetricName(value: String): MulticlassClassificationEvaluator
40
def setMetricLabel(value: Double): MulticlassClassificationEvaluator
41
def setBeta(value: Double): MulticlassClassificationEvaluator
42
def setEps(value: Double): MulticlassClassificationEvaluator
43
def setLabelCol(value: String): MulticlassClassificationEvaluator
44
def setPredictionCol(value: String): MulticlassClassificationEvaluator
45
46
override def evaluate(dataset: Dataset[_]): Double
47
override def isLargerBetter: Boolean
48
override def copy(extra: ParamMap): MulticlassClassificationEvaluator
49
}
50
```
51
52
### RegressionEvaluator
53
54
```scala { .api }
55
class RegressionEvaluator(override val uid: String) extends Evaluator
56
with HasLabelCol with HasPredictionCol with DefaultParamsWritable {
57
58
def this() = this(Identifiable.randomUID("regEval"))
59
60
final val metricName: Param[String]
61
final val throughOrigin: BooleanParam
62
63
def setMetricName(value: String): RegressionEvaluator
64
def setThroughOrigin(value: Boolean): RegressionEvaluator
65
def setLabelCol(value: String): RegressionEvaluator
66
def setPredictionCol(value: String): RegressionEvaluator
67
68
override def evaluate(dataset: Dataset[_]): Double
69
override def isLargerBetter: Boolean
70
override def copy(extra: ParamMap): RegressionEvaluator
71
}
72
```
73
74
### ClusteringEvaluator
75
76
```scala { .api }
77
class ClusteringEvaluator(override val uid: String) extends Evaluator
78
with HasPredictionCol with HasFeaturesCol with DefaultParamsWritable {
79
80
def this() = this(Identifiable.randomUID("clustering"))
81
82
final val metricName: Param[String]
83
final val distanceMeasure: Param[String]
84
85
def setMetricName(value: String): ClusteringEvaluator
86
def setDistanceMeasure(value: String): ClusteringEvaluator
87
def setPredictionCol(value: String): ClusteringEvaluator
88
def setFeaturesCol(value: String): ClusteringEvaluator
89
90
override def evaluate(dataset: Dataset[_]): Double
91
override def isLargerBetter: Boolean
92
override def copy(extra: ParamMap): ClusteringEvaluator
93
}
94
```
95
96
## Model Selection and Tuning
97
98
### CrossValidator
99
100
```scala { .api }
101
class CrossValidator(override val uid: String) extends Estimator[CrossValidatorModel]
102
with CrossValidatorParams with MLWritable with Logging {
103
104
def this() = this(Identifiable.randomUID("cv"))
105
106
def setEstimator(value: Estimator[_]): CrossValidator
107
def setEstimatorParamMaps(value: Array[ParamMap]): CrossValidator
108
def setEvaluator(value: Evaluator): CrossValidator
109
def setNumFolds(value: Int): CrossValidator
110
def setSeed(value: Long): CrossValidator
111
def setParallelism(value: Int): CrossValidator
112
def setCollectSubModels(value: Boolean): CrossValidator
113
def setFoldCol(value: String): CrossValidator
114
115
override def fit(dataset: Dataset[_]): CrossValidatorModel
116
override def copy(extra: ParamMap): CrossValidator
117
def write: MLWriter
118
}
119
120
class CrossValidatorModel(override val uid: String, val bestModel: Model[_], val avgMetrics: Array[Double])
121
extends Model[CrossValidatorModel] with CrossValidatorParams with MLWritable {
122
123
lazy val subModels: Option[Array[Array[Model[_]]]]
124
125
override def transform(dataset: Dataset[_]): DataFrame
126
override def transformSchema(schema: StructType): StructType
127
override def copy(extra: ParamMap): CrossValidatorModel
128
def write: MLWriter
129
}
130
```
131
132
### TrainValidationSplit
133
134
```scala { .api }
135
class TrainValidationSplit(override val uid: String) extends Estimator[TrainValidationSplitModel]
136
with TrainValidationSplitParams with MLWritable with Logging {
137
138
def this() = this(Identifiable.randomUID("tvs"))
139
140
def setEstimator(value: Estimator[_]): TrainValidationSplit
141
def setEstimatorParamMaps(value: Array[ParamMap]): TrainValidationSplit
142
def setEvaluator(value: Evaluator): TrainValidationSplit
143
def setTrainRatio(value: Double): TrainValidationSplit
144
def setSeed(value: Long): TrainValidationSplit
145
def setParallelism(value: Int): TrainValidationSplit
146
def setCollectSubModels(value: Boolean): TrainValidationSplit
147
148
override def fit(dataset: Dataset[_]): TrainValidationSplitModel
149
override def copy(extra: ParamMap): TrainValidationSplit
150
def write: MLWriter
151
}
152
153
class TrainValidationSplitModel(override val uid: String, val bestModel: Model[_], val validationMetrics: Array[Double])
154
extends Model[TrainValidationSplitModel] with TrainValidationSplitParams with MLWritable {
155
156
lazy val subModels: Option[Array[Model[_]]]
157
158
override def transform(dataset: Dataset[_]): DataFrame
159
override def transformSchema(schema: StructType): StructType
160
override def copy(extra: ParamMap): TrainValidationSplitModel
161
def write: MLWriter
162
}
163
```
164
165
### ParamGridBuilder
166
167
```scala { .api }
168
class ParamGridBuilder {
169
def addGrid[T](param: Param[T], values: Array[T]): ParamGridBuilder
170
def addGrid[T](param: Param[T], values: java.util.List[T]): ParamGridBuilder
171
def baseOn(paramMap: ParamMap): ParamGridBuilder
172
def baseOn(paramPairs: ParamPair[_]*): ParamGridBuilder
173
def build(): Array[ParamMap]
174
}
175
```
176
177
## Base Evaluator
178
179
```scala { .api }
180
abstract class Evaluator extends Params {
181
def evaluate(dataset: Dataset[_]): Double
182
def isLargerBetter: Boolean
183
def copy(extra: ParamMap): Evaluator
184
}
185
```
186
187
## Usage Examples
188
189
### Binary Classification Evaluation
190
191
```scala
192
import org.apache.spark.ml.classification.LogisticRegression
193
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
194
195
// Train model (assuming data is prepared)
196
val lr = new LogisticRegression()
197
val model = lr.fit(trainingData)
198
val predictions = model.transform(testData)
199
200
// Binary classification evaluation
201
val binaryEvaluator = new BinaryClassificationEvaluator()
202
.setLabelCol("label")
203
.setRawPredictionCol("rawPrediction")
204
205
// Area Under ROC Curve
206
val aucROC = binaryEvaluator.setMetricName("areaUnderROC").evaluate(predictions)
207
println(s"Area Under ROC: $aucROC")
208
209
// Area Under Precision-Recall Curve
210
val aucPR = binaryEvaluator.setMetricName("areaUnderPR").evaluate(predictions)
211
println(s"Area Under PR: $aucPR")
212
213
// Print model summary for additional metrics
214
if (model.hasSummary) {
  // binarySummary (Spark 2.3+) returns the typed summary directly,
  // avoiding a manual cast to BinaryLogisticRegressionSummary
  val summary = model.binarySummary
216
217
// ROC curve
218
summary.roc.show()
219
220
// Precision-Recall curve
221
summary.pr.show()
222
223
// Metrics by threshold
224
summary.fMeasureByThreshold.show()
225
summary.precisionByThreshold.show()
226
summary.recallByThreshold.show()
227
}
228
```
229
230
### Multiclass Classification Evaluation
231
232
```scala
233
import org.apache.spark.ml.classification.RandomForestClassifier
234
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
235
236
val rf = new RandomForestClassifier()
237
val rfModel = rf.fit(trainingData)
238
val rfPredictions = rfModel.transform(testData)
239
240
val multiclassEvaluator = new MulticlassClassificationEvaluator()
241
.setLabelCol("label")
242
.setPredictionCol("prediction")
243
244
// Overall accuracy
245
val accuracy = multiclassEvaluator.setMetricName("accuracy").evaluate(rfPredictions)
246
println(s"Accuracy: $accuracy")
247
248
// Weighted metrics (accounts for class imbalance)
249
val weightedPrecision = multiclassEvaluator.setMetricName("weightedPrecision").evaluate(rfPredictions)
250
val weightedRecall = multiclassEvaluator.setMetricName("weightedRecall").evaluate(rfPredictions)
251
val weightedF1 = multiclassEvaluator.setMetricName("f1").evaluate(rfPredictions)
252
253
println(s"Weighted Precision: $weightedPrecision")
254
println(s"Weighted Recall: $weightedRecall")
255
println(s"Weighted F1-Score: $weightedF1")
256
257
// Precision and recall for specific class (e.g., class 1.0)
258
val precisionClass1 = multiclassEvaluator
259
.setMetricName("precisionByLabel")
260
.setMetricLabel(1.0)
261
.evaluate(rfPredictions)
262
263
val recallClass1 = multiclassEvaluator
264
.setMetricName("recallByLabel")
265
.setMetricLabel(1.0)
266
.evaluate(rfPredictions)
267
268
println(s"Precision for class 1: $precisionClass1")
269
println(s"Recall for class 1: $recallClass1")
270
271
// F-measure with custom beta (beta=2 emphasizes recall over precision).
// Note: the "f1" metric always uses beta=1 and ignores setBeta;
// use "weightedFMeasure" for a beta-aware score.
val f2Score = multiclassEvaluator
  .setMetricName("weightedFMeasure")
  .setBeta(2.0)
  .evaluate(rfPredictions)

println(s"F2-Score: $f2Score")
278
```
279
280
### Regression Evaluation
281
282
```scala
283
import org.apache.spark.ml.regression.{LinearRegression, RandomForestRegressor}
284
import org.apache.spark.ml.evaluation.RegressionEvaluator
285
286
// Train models
287
val lr = new LinearRegression()
288
val lrModel = lr.fit(trainingData)
289
val lrPredictions = lrModel.transform(testData)
290
291
val rf = new RandomForestRegressor()
292
val rfModel = rf.fit(trainingData)
293
val rfPredictions = rfModel.transform(testData)
294
295
val regressionEvaluator = new RegressionEvaluator()
296
.setLabelCol("label")
297
.setPredictionCol("prediction")
298
299
// Evaluate different metrics for linear regression
300
val metrics = Array("rmse", "mse", "mae", "r2")
301
println("Linear Regression Metrics:")
302
metrics.foreach { metric =>
303
val result = regressionEvaluator.setMetricName(metric).evaluate(lrPredictions)
304
println(s"$metric: $result")
305
}
306
307
println("\nRandom Forest Regression Metrics:")
308
metrics.foreach { metric =>
309
val result = regressionEvaluator.setMetricName(metric).evaluate(rfPredictions)
310
println(s"$metric: $result")
311
}
312
313
// Explained variance
314
val explainedVariance = regressionEvaluator
315
.setMetricName("var")
316
.evaluate(lrPredictions)
317
println(s"Explained Variance: $explainedVariance")
318
319
// Compare models
320
val lrRMSE = regressionEvaluator.setMetricName("rmse").evaluate(lrPredictions)
321
val rfRMSE = regressionEvaluator.setMetricName("rmse").evaluate(rfPredictions)
322
323
println(s"\nModel Comparison (RMSE):")
324
println(s"Linear Regression: $lrRMSE")
325
println(s"Random Forest: $rfRMSE")
326
println(s"Best Model: ${if (lrRMSE < rfRMSE) "Linear Regression" else "Random Forest"}")
327
```
328
329
### Clustering Evaluation
330
331
```scala
332
import org.apache.spark.ml.clustering.KMeans
333
import org.apache.spark.ml.evaluation.ClusteringEvaluator
334
335
// Train clustering models with different k values
336
val kValues = Array(2, 3, 4, 5, 6)
337
val clusteringEvaluator = new ClusteringEvaluator()
338
.setPredictionCol("prediction")
339
.setFeaturesCol("features")
340
.setMetricName("silhouette")
341
.setDistanceMeasure("squaredEuclidean")
342
343
println("Clustering Evaluation (Silhouette Score):")
344
kValues.foreach { k =>
345
val kmeans = new KMeans().setK(k).setSeed(42)
346
val model = kmeans.fit(data)
347
val predictions = model.transform(data)
348
349
val silhouette = clusteringEvaluator.evaluate(predictions)
350
println(s"k=$k: Silhouette Score = $silhouette")
351
}
352
353
// Find optimal k using elbow method (training cost / WSSSE)
println("\nElbow Method (Within Set Sum of Squared Errors):")
kValues.foreach { k =>
  val kmeans = new KMeans().setK(k).setSeed(42)
  val model = kmeans.fit(data)
  // KMeansModel.computeCost was deprecated in Spark 2.4 and removed in 3.0;
  // read the training cost from the model summary instead.
  val cost = model.summary.trainingCost
  println(s"k=$k: WSSSE = $cost")
}
361
```
362
363
### Cross-Validation for Hyperparameter Tuning
364
365
```scala
366
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
367
import org.apache.spark.ml.classification.RandomForestClassifier
368
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
369
370
val rf = new RandomForestClassifier()
371
.setLabelCol("label")
372
.setFeaturesCol("features")
373
.setSeed(42)
374
375
// Create parameter grid
376
val paramGrid = new ParamGridBuilder()
377
.addGrid(rf.numTrees, Array(10, 20, 30))
378
.addGrid(rf.maxDepth, Array(5, 10, 15))
379
.addGrid(rf.maxBins, Array(16, 32))
380
.build()
381
382
// Create evaluator
383
val evaluator = new MulticlassClassificationEvaluator()
384
.setLabelCol("label")
385
.setPredictionCol("prediction")
386
.setMetricName("f1")
387
388
// Create cross-validator
389
val cv = new CrossValidator()
390
.setEstimator(rf)
391
.setEvaluator(evaluator)
392
.setEstimatorParamMaps(paramGrid)
393
.setNumFolds(5)
394
.setParallelism(2) // Parallel execution
395
.setSeed(42)
396
397
// Fit cross-validator
398
val cvModel = cv.fit(trainingData)
399
400
// Get results
401
println(s"Best F1 Score: ${cvModel.avgMetrics.max}")
402
println("All cross-validation scores:")
403
cvModel.avgMetrics.zipWithIndex.foreach { case (metric, idx) =>
404
println(s" Param set $idx: $metric")
405
}
406
407
// Get best model parameters
408
val bestModel = cvModel.bestModel.asInstanceOf[RandomForestClassificationModel]
409
println(s"Best parameters:")
410
println(s" Number of trees: ${bestModel.getNumTrees}")
411
println(s" Max depth: ${bestModel.getMaxDepth}")
412
println(s" Max bins: ${bestModel.getMaxBins}")
413
414
// Evaluate on test data
415
val finalPredictions = cvModel.transform(testData)
416
val testScore = evaluator.evaluate(finalPredictions)
417
println(s"Test F1 Score: $testScore")
418
```
419
420
### Train-Validation Split
421
422
```scala
423
import org.apache.spark.ml.tuning.TrainValidationSplit
424
import org.apache.spark.ml.regression.{LinearRegression, RandomForestRegressor}
425
import org.apache.spark.ml.evaluation.RegressionEvaluator
426
427
// Compare different algorithms
428
val lr = new LinearRegression()
429
val rf = new RandomForestRegressor()
430
431
// Parameter grids for each algorithm
432
val lrParamGrid = new ParamGridBuilder()
433
.addGrid(lr.regParam, Array(0.001, 0.01, 0.1))
434
.addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0))
435
.build()
436
437
val rfParamGrid = new ParamGridBuilder()
438
.addGrid(rf.numTrees, Array(10, 20, 30))
439
.addGrid(rf.maxDepth, Array(5, 10))
440
.build()
441
442
val evaluator = new RegressionEvaluator()
443
.setMetricName("rmse")
444
445
// Train-validation split for linear regression
446
val lrTvs = new TrainValidationSplit()
447
.setEstimator(lr)
448
.setEvaluator(evaluator)
449
.setEstimatorParamMaps(lrParamGrid)
450
.setTrainRatio(0.8)
451
.setSeed(42)
452
453
val lrTvsModel = lrTvs.fit(data)
454
455
// Train-validation split for random forest
456
val rfTvs = new TrainValidationSplit()
457
.setEstimator(rf)
458
.setEvaluator(evaluator)
459
.setEstimatorParamMaps(rfParamGrid)
460
.setTrainRatio(0.8)
461
.setSeed(42)
462
463
val rfTvsModel = rfTvs.fit(data)
464
465
// Compare validation metrics
466
val lrBestScore = lrTvsModel.validationMetrics.min // RMSE: lower is better
467
val rfBestScore = rfTvsModel.validationMetrics.min
468
469
println(s"Linear Regression - Best validation RMSE: $lrBestScore")
470
println(s"Random Forest - Best validation RMSE: $rfBestScore")
471
472
val bestAlgorithm = if (lrBestScore < rfBestScore) {
473
println("Linear Regression performs better")
474
lrTvsModel
475
} else {
476
println("Random Forest performs better")
477
rfTvsModel
478
}
479
480
// Final evaluation on test data
481
val testPredictions = bestAlgorithm.transform(testData)
482
val testRMSE = evaluator.evaluate(testPredictions)
483
println(s"Final test RMSE: $testRMSE")
484
```
485
486
### Custom Evaluation Metrics
487
488
```scala
489
import org.apache.spark.sql.functions._
490
import org.apache.spark.sql.types._
491
492
// Custom evaluation function for classification
493
def customClassificationMetrics(predictions: DataFrame): Unit = {
494
// Confusion matrix
495
val confusionMatrix = predictions
496
.groupBy("label", "prediction")
497
.count()
498
.orderBy("label", "prediction")
499
500
println("Confusion Matrix:")
501
confusionMatrix.show()
502
503
// Per-class metrics
504
val classMetrics = predictions
505
.groupBy("label")
506
.agg(
507
count("*").alias("total"),
508
sum(when(col("label") === col("prediction"), 1).otherwise(0)).alias("correct")
509
)
510
.withColumn("accuracy", col("correct") / col("total"))
511
512
println("Per-class Accuracy:")
513
classMetrics.show()
514
515
// Overall metrics
516
val totalPredictions = predictions.count()
517
val correctPredictions = predictions
518
.filter(col("label") === col("prediction"))
519
.count()
520
521
val overallAccuracy = correctPredictions.toDouble / totalPredictions
522
println(s"Overall Accuracy: $overallAccuracy")
523
}
524
525
// Custom evaluation for regression
526
def customRegressionMetrics(predictions: DataFrame): Unit = {
527
import org.apache.spark.sql.functions._
528
529
val metrics = predictions
530
.select(
531
mean(abs(col("label") - col("prediction"))).alias("mae"),
532
sqrt(mean(pow(col("label") - col("prediction"), 2))).alias("rmse"),
533
mean(pow(col("label") - col("prediction"), 2)).alias("mse"),
534
corr("label", "prediction").alias("correlation")
535
)
536
537
println("Custom Regression Metrics:")
538
metrics.show()
539
540
// Residual analysis
541
val residuals = predictions
542
.withColumn("residual", col("label") - col("prediction"))
543
.withColumn("abs_residual", abs(col("residual")))
544
.withColumn("squared_residual", pow(col("residual"), 2))
545
546
val residualStats = residuals
547
.select(
548
mean("residual").alias("mean_residual"),
549
stddev("residual").alias("std_residual"),
550
min("residual").alias("min_residual"),
551
max("residual").alias("max_residual")
552
)
553
554
println("Residual Statistics:")
555
residualStats.show()
556
}
557
558
// Usage with model predictions
559
val predictions = model.transform(testData)
560
customClassificationMetrics(predictions)
561
// or
562
// customRegressionMetrics(predictions)
563
```
564
565
### Model Performance Monitoring
566
567
```scala
568
import org.apache.spark.sql.functions._
569
570
def monitorModelPerformance(predictions: DataFrame,
571
evaluator: BinaryClassificationEvaluator,
572
timeCol: String = "timestamp"): Unit = {
573
574
// Performance over time
575
val performanceOverTime = predictions
576
.withColumn("date", to_date(col(timeCol)))
577
.groupBy("date")
578
.agg(
579
count("*").alias("num_predictions"),
580
mean(when(col("label") === col("prediction"), 1.0).otherwise(0.0)).alias("accuracy"),
581
mean("rawPrediction").alias("avg_confidence")
582
)
583
.orderBy("date")
584
585
println("Performance Over Time:")
586
performanceOverTime.show()
587
588
// Prediction distribution
589
val predictionDist = predictions
590
.groupBy("prediction")
591
.count()
592
.orderBy("prediction")
593
594
println("Prediction Distribution:")
595
predictionDist.show()
596
597
// Confidence analysis (for probabilistic models)
598
if (predictions.columns.contains("probability")) {
599
// Highest class probability per row. The probability column is an ML Vector,
// so convert it to an array first
// (requires: import org.apache.spark.ml.functions.vector_to_array).
val confidenceAnalysis = predictions
  .withColumn("max_probability",
    array_max(vector_to_array(col("probability"))))
602
.groupBy("prediction")
603
.agg(
604
count("*").alias("count"),
605
mean("max_probability").alias("avg_confidence"),
606
min("max_probability").alias("min_confidence"),
607
max("max_probability").alias("max_confidence")
608
)
609
610
println("Confidence Analysis by Prediction:")
611
confidenceAnalysis.show()
612
}
613
}
614
615
// Monitor model performance
616
monitorModelPerformance(predictions, binaryEvaluator)
617
```