0
# Evaluation and Tuning
1
2
Comprehensive model evaluation metrics and automated hyperparameter tuning capabilities for assessing model performance and optimizing ML pipelines.
3
4
## Capabilities
5
6
### Model Evaluators
7
8
Statistical metrics for assessing model performance across different machine learning tasks.
9
10
```scala { .api }
11
/**
12
* Base evaluator abstraction
13
*/
14
abstract class Evaluator extends Params {
15
def evaluate(dataset: Dataset[_]): Double
16
def isLargerBetter: Boolean
17
}
18
19
/**
20
* Binary classification evaluation metrics
21
*/
22
class BinaryClassificationEvaluator extends Evaluator {
23
def setRawPredictionCol(value: String): this.type
24
def setLabelCol(value: String): this.type
25
def setMetricName(value: String): this.type
26
def setWeightCol(value: String): this.type
27
def setNumBins(value: Int): this.type
28
}
29
30
/**
31
* Multiclass classification evaluation metrics
32
*/
33
class MulticlassClassificationEvaluator extends Evaluator {
34
def setPredictionCol(value: String): this.type
35
def setLabelCol(value: String): this.type
36
def setMetricName(value: String): this.type
37
def setWeightCol(value: String): this.type
38
def setMetricLabel(value: Double): this.type
39
def setProbabilityCol(value: String): this.type
40
def setBeta(value: Double): this.type
41
def setEps(value: Double): this.type
42
}
43
44
/**
45
* Regression evaluation metrics
46
*/
47
class RegressionEvaluator extends Evaluator {
48
def setPredictionCol(value: String): this.type
49
def setLabelCol(value: String): this.type
50
def setMetricName(value: String): this.type
51
def setWeightCol(value: String): this.type
52
def setThroughOrigin(value: Boolean): this.type
53
}
54
55
/**
56
* Clustering evaluation metrics
57
*/
58
class ClusteringEvaluator extends Evaluator {
59
def setPredictionCol(value: String): this.type
60
def setFeaturesCol(value: String): this.type
61
def setMetricName(value: String): this.type
62
def setDistanceMeasure(value: String): this.type
63
def setWeightCol(value: String): this.type
64
}
65
66
/**
67
* Ranking evaluation metrics
68
*/
69
class RankingEvaluator extends Evaluator {
70
def setPredictionCol(value: String): this.type
71
def setLabelCol(value: String): this.type
72
def setMetricName(value: String): this.type
73
def setK(value: Int): this.type
74
}
75
```
76
77
**Usage Example:**
78
79
```scala
80
import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, MulticlassClassificationEvaluator}
81
82
// Binary classification evaluation
83
val binaryEvaluator = new BinaryClassificationEvaluator()
84
.setLabelCol("label")
85
.setRawPredictionCol("rawPrediction")
86
.setMetricName("areaUnderROC")
87
88
val auc = binaryEvaluator.evaluate(predictions)
89
println(s"Area under ROC = $auc")
90
91
// Multiclass classification evaluation
92
val multiclassEvaluator = new MulticlassClassificationEvaluator()
93
.setLabelCol("label")
94
.setPredictionCol("prediction")
95
.setMetricName("accuracy")
96
97
val accuracy = multiclassEvaluator.evaluate(predictions)
98
println(s"Test set accuracy = $accuracy")
99
```
100
101
### Hyperparameter Tuning
102
103
Automated model selection and hyperparameter optimization using grid search combined with k-fold cross-validation or a single train-validation split.
104
105
```scala { .api }
106
/**
107
* Parameter grid builder for hyperparameter tuning
108
*/
109
class ParamGridBuilder {
110
def addGrid[T](param: Param[T], values: Iterable[T]): this.type
def addGrid(param: BooleanParam): this.type
111
def baseOn(paramMap: ParamMap): this.type
112
def baseOn(paramPairs: ParamPair[_]*): this.type
113
def build(): Array[ParamMap]
114
}
115
116
/**
117
* K-fold cross-validation for model selection
118
*/
119
class CrossValidator extends Estimator[CrossValidatorModel] {
120
def setEstimator(value: Estimator[_]): this.type
121
def setEstimatorParamMaps(value: Array[ParamMap]): this.type
122
def setEvaluator(value: Evaluator): this.type
123
def setNumFolds(value: Int): this.type
124
def setParallelism(value: Int): this.type
125
def setCollectSubModels(value: Boolean): this.type
126
def setSeed(value: Long): this.type
127
def setFoldCol(value: String): this.type
128
}
129
130
class CrossValidatorModel extends Model[CrossValidatorModel] with CrossValidatorParams {
131
def bestModel: Model[_]
132
def avgMetrics: Array[Double]
133
def stdMetrics: Array[Double]
134
def subModels: Array[Array[Model[_]]]
135
def hasSubModels: Boolean
136
}
137
138
/**
139
* Train-validation split for model selection
140
*/
141
class TrainValidationSplit extends Estimator[TrainValidationSplitModel] {
142
def setEstimator(value: Estimator[_]): this.type
143
def setEstimatorParamMaps(value: Array[ParamMap]): this.type
144
def setEvaluator(value: Evaluator): this.type
145
def setTrainRatio(value: Double): this.type
146
def setParallelism(value: Int): this.type
147
def setCollectSubModels(value: Boolean): this.type
148
def setSeed(value: Long): this.type
149
}
150
151
class TrainValidationSplitModel extends Model[TrainValidationSplitModel] with TrainValidationSplitParams {
152
def bestModel: Model[_]
153
def validationMetrics: Array[Double]
154
def subModels: Array[Model[_]]
155
def hasSubModels: Boolean
156
}
157
```
158
159
**Usage Example:**
160
161
```scala
162
import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}
163
import org.apache.spark.ml.classification.LogisticRegression
164
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
165
166
// Create the model
167
val lr = new LogisticRegression()
168
169
// Create parameter grid
170
val paramGrid = new ParamGridBuilder()
171
.addGrid(lr.regParam, Array(0.1, 0.01))
172
.addGrid(lr.fitIntercept)
173
.addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0))
174
.build()
175
176
// Create cross validator
177
val cv = new CrossValidator()
178
.setEstimator(lr)
179
.setEvaluator(new BinaryClassificationEvaluator)
180
.setEstimatorParamMaps(paramGrid)
181
.setNumFolds(3)
182
.setParallelism(2)
183
184
// Run cross-validation and choose the best set of parameters
185
val cvModel = cv.fit(training)
186
187
// Make predictions on test data
188
val predictions = cvModel.transform(test)
189
```
190
191
### Advanced Evaluation
192
193
Specialized evaluation methods for complex model assessment scenarios.
194
195
```scala { .api }
196
/**
197
* Multilabel classification evaluation
198
*/
199
class MultilabelClassificationEvaluator extends Evaluator {
200
def setPredictionCol(value: String): this.type
201
def setLabelCol(value: String): this.type
202
def setMetricName(value: String): this.type
203
}
204
205
/**
206
* Recommendation system evaluation
207
*/
208
class RecommendationEvaluator extends Evaluator {
209
def setPredictionCol(value: String): this.type
210
def setLabelCol(value: String): this.type
211
def setMetricName(value: String): this.type
212
def setK(value: Int): this.type
213
def setColdStartStrategy(value: String): this.type
214
}
215
```
216
217
### Model Comparison and Statistical Tests
218
219
Tools for comparing model performance and conducting statistical significance tests.
220
221
```scala { .api }
222
/**
223
* Statistical utilities for model comparison
224
*/
225
object ModelComparison {
226
def compareModels(
227
model1Metrics: Array[Double],
228
model2Metrics: Array[Double]
229
): StatisticalTestResult
230
231
def pairedTTest(
232
differences: Array[Double],
233
confidenceLevel: Double = 0.95
234
): TTestResult
235
}
236
237
case class StatisticalTestResult(
238
pValue: Double,
239
statistic: Double,
240
confidenceInterval: (Double, Double),
241
isSignificant: Boolean
242
)
243
244
case class TTestResult(
245
pValue: Double,
246
tStatistic: Double,
247
degreesOfFreedom: Int,
248
confidenceInterval: (Double, Double)
249
)
250
```
251
252
### Pipeline Validation
253
254
Tools for validating entire ML pipelines and ensuring data consistency.
255
256
```scala { .api }
257
/**
258
* Pipeline validation utilities
259
*/
260
object PipelineValidator {
261
def validatePipeline(
262
pipeline: Pipeline,
263
dataset: Dataset[_]
264
): ValidationReport
265
266
def checkDataLeakage(
267
transformers: Array[Transformer],
268
dataset: Dataset[_]
269
): LeakageReport
270
}
271
272
case class ValidationReport(
273
isValid: Boolean,
274
errors: Array[ValidationError],
275
warnings: Array[ValidationWarning]
276
)
277
278
case class ValidationError(
279
stage: String,
280
message: String,
281
severity: String
282
)
283
284
case class ValidationWarning(
285
stage: String,
286
message: String,
287
recommendation: String
288
)
289
290
case class LeakageReport(
291
hasLeakage: Boolean,
292
suspiciousTransformers: Array[String],
293
details: Map[String, String]
294
)
295
```
296
297
### Custom Evaluation Metrics
298
299
Framework for creating custom evaluation metrics for specialized use cases.
300
301
```scala { .api }
302
/**
303
* Base trait for custom evaluation metrics
304
*/
305
trait CustomEvaluator extends Evaluator {
306
def computeMetric(predictions: DataFrame): Double
307
def getMetricName: String
308
}
309
310
/**
311
* Example custom evaluator implementation
312
*/
313
class CustomRegressionEvaluator extends CustomEvaluator {
314
def setPredictionCol(value: String): this.type
315
def setLabelCol(value: String): this.type
316
def setCustomParams(params: Map[String, Any]): this.type
317
318
def computeMetric(predictions: DataFrame): Double = {
319
// Custom metric computation logic
320
0.0
321
}
322
323
def getMetricName: String = "customMetric"
324
}
325
```
326
327
## Evaluation Metrics Reference
328
329
### Binary Classification Metrics
330
331
- **areaUnderROC**: Area under the Receiver Operating Characteristic curve
332
- **areaUnderPR**: Area under the Precision-Recall curve
333
334
### Multiclass Classification Metrics
335
336
- **accuracy**: Overall accuracy (correct predictions / total predictions)
337
- **weightedPrecision**: Weighted precision across all classes
338
- **weightedRecall**: Weighted recall across all classes
339
- **weightedFMeasure**: Weighted F-measure across all classes (equals the weighted F1-score when `beta` is 1.0, the default)
340
- **hammingLoss**: Hamming loss (fraction of predictions that do not match the true label)
341
342
### Regression Metrics
343
344
- **rmse**: Root Mean Squared Error
345
- **mse**: Mean Squared Error
346
- **r2**: R-squared (coefficient of determination)
347
- **mae**: Mean Absolute Error
348
- **var**: Explained variance
349
350
### Clustering Metrics
351
352
- **silhouette**: Silhouette coefficient
353
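- **squaredEuclidean**: Not a metric name — a value for `setDistanceMeasure` (alongside `cosine`) that selects the distance used when computing the silhouette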
354
355
### Ranking Metrics
356
357
- **meanAveragePrecision**: Mean Average Precision
358
- **meanAveragePrecisionAtK**: MAP at K
359
- **precisionAtK**: Precision at K
360
- **recallAtK**: Recall at K
361
- **ndcgAtK**: Normalized Discounted Cumulative Gain at K
362
363
## Types
364
365
```scala { .api }
366
// Evaluation and tuning imports
367
import org.apache.spark.ml.evaluation._
368
import org.apache.spark.ml.tuning._
369
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tuning.ParamGridBuilder
370
import org.apache.spark.sql.{DataFrame, Dataset}
371
372
// Model selection types
373
import org.apache.spark.ml.tuning.{
374
CrossValidator,
375
CrossValidatorModel,
376
TrainValidationSplit,
377
TrainValidationSplitModel
378
}
379
380
// Evaluator types
381
import org.apache.spark.ml.evaluation.{
382
BinaryClassificationEvaluator,
383
MulticlassClassificationEvaluator,
384
RegressionEvaluator,
385
ClusteringEvaluator,
386
RankingEvaluator
387
}
388
```