# Regression

MLlib provides comprehensive regression algorithms for supervised learning tasks with continuous target variables. All regressors follow the Estimator-Transformer pattern and support the Pipeline API.

## Linear Regression
5
6
### Estimator
7
8
```scala { .api }
9
/** Estimator whose `fit` produces a [[LinearRegressionModel]]. */
class LinearRegression(override val uid: String)
  extends Regressor[Vector, LinearRegression, LinearRegressionModel]
  with LinearRegressionParams with DefaultParamsWritable {

  /** Uses a uid generated by `Identifiable.randomUID("linReg")`. */
  def this() = this(Identifiable.randomUID("linReg"))

  // Chainable setters: each one returns LinearRegression.
  def setRegParam(value: Double): LinearRegression
  def setFitIntercept(value: Boolean): LinearRegression
  def setStandardization(value: Boolean): LinearRegression
  def setElasticNetParam(value: Double): LinearRegression
  def setMaxIter(value: Int): LinearRegression
  def setTol(value: Double): LinearRegression
  def setWeightCol(value: String): LinearRegression
  def setSolver(value: String): LinearRegression
  def setAggregationDepth(value: Int): LinearRegression
  def setLoss(value: String): LinearRegression
  def setEpsilon(value: Double): LinearRegression

  override def fit(dataset: Dataset[_]): LinearRegressionModel
  override def copy(extra: ParamMap): LinearRegression
}
29
```
30
31
### Model
32
33
```scala { .api }
34
/** Fitted linear regression model exposing its coefficients, intercept and scale. */
class LinearRegressionModel private[ml] (
    override val uid: String,
    val coefficients: Vector,
    val intercept: Double,
    val scale: Double)
  extends RegressionModel[Vector, LinearRegressionModel]
  with LinearRegressionParams with GeneralMLWritable {

  // Convenience constructor for backward compatibility (scale defaults to 1.0)
  private[ml] def this(uid: String, coefficients: Vector, intercept: Double) =
    this(uid, coefficients, intercept, 1.0)

  // Training summary, and evaluation of the model on another dataset.
  lazy val summary: LinearRegressionTrainingSummary
  def hasSummary: Boolean
  def evaluate(dataset: Dataset[_]): LinearRegressionSummary

  override def predict(features: Vector): Double
  override def copy(extra: ParamMap): LinearRegressionModel
  def write: MLWriter
}
53
54
/** Summary of a training run; extends [[LinearRegressionSummary]] with optimizer history. */
class LinearRegressionTrainingSummary(
    predictions: DataFrame,
    predictionCol: String,
    labelCol: String,
    featuresCol: String,
    val objectiveHistory: Array[Double],
    val totalIterations: Int,
    val solver: String)
  extends LinearRegressionSummary(predictions, predictionCol, labelCol, featuresCol) {

  // Per-coefficient statistics.
  val coefficientStandardErrors: Array[Double]
  val tValues: Array[Double]
  val pValues: Array[Double]
}
63
64
/** Evaluation metrics for a linear regression model on a given prediction DataFrame. */
class LinearRegressionSummary(
    predictions: DataFrame,
    predictionCol: String,
    labelCol: String,
    featuresCol: String) extends Serializable {

  val residuals: DataFrame
  val rootMeanSquaredError: Double
  val meanSquaredError: Double
  val meanAbsoluteError: Double
  val r2: Double
  val explainedVariance: Double
  val numInstances: Long
  val degreesOfFreedom: Long
  val devianceResiduals: Array[Double]
}
77
```
78
79
## Generalized Linear Regression
80
81
### Estimator
82
83
```scala { .api }
84
/** Estimator whose `fit` produces a [[GeneralizedLinearRegressionModel]]. */
class GeneralizedLinearRegression(override val uid: String)
  extends Regressor[Vector, GeneralizedLinearRegression, GeneralizedLinearRegressionModel]
  with GeneralizedLinearRegressionParams with DefaultParamsWritable {

  /** Uses a uid generated by `Identifiable.randomUID("glm")`. */
  def this() = this(Identifiable.randomUID("glm"))

  // Chainable setters: each one returns GeneralizedLinearRegression.
  def setFamily(value: String): GeneralizedLinearRegression
  // Spark exposes numeric power parameters here, not string-valued
  // "variance function" / "link function" setters.
  def setVariancePower(value: Double): GeneralizedLinearRegression
  def setLink(value: String): GeneralizedLinearRegression
  def setLinkPower(value: Double): GeneralizedLinearRegression
  def setFitIntercept(value: Boolean): GeneralizedLinearRegression
  def setMaxIter(value: Int): GeneralizedLinearRegression
  def setTol(value: Double): GeneralizedLinearRegression
  def setRegParam(value: Double): GeneralizedLinearRegression
  def setWeightCol(value: String): GeneralizedLinearRegression
  def setSolver(value: String): GeneralizedLinearRegression
  def setLinkPredictionCol(value: String): GeneralizedLinearRegression
  def setOffsetCol(value: String): GeneralizedLinearRegression
  def setAggregationDepth(value: Int): GeneralizedLinearRegression

  override def fit(dataset: Dataset[_]): GeneralizedLinearRegressionModel
  override def copy(extra: ParamMap): GeneralizedLinearRegression
}
107
```
108
109
### Model
110
111
```scala { .api }
112
/** Fitted generalized linear model exposing its coefficients and intercept. */
class GeneralizedLinearRegressionModel(
    override val uid: String,
    val coefficients: Vector,
    val intercept: Double)
  extends RegressionModel[Vector, GeneralizedLinearRegressionModel]
  with GeneralizedLinearRegressionParams with MLWritable {

  // Training summary, and evaluation of the model on another dataset.
  lazy val summary: GeneralizedLinearRegressionTrainingSummary
  def hasSummary: Boolean
  def evaluate(dataset: Dataset[_]): GeneralizedLinearRegressionSummary

  override def predict(features: Vector): Double
  override def copy(extra: ParamMap): GeneralizedLinearRegressionModel
  def write: MLWriter
}
124
125
/**
 * Summary of a GLM training run; extends [[GeneralizedLinearRegressionSummary]]
 * with optimizer history, goodness-of-fit values and per-coefficient statistics.
 */
class GeneralizedLinearRegressionTrainingSummary(
    predictions: DataFrame,
    predictionCol: String,
    labelCol: String,
    featuresCol: String,
    val objectiveHistory: Array[Double],
    val solver: String,
    val totalIterations: Int,
    val aic: Double,
    val deviance: Double,
    val nullDeviance: Double,
    val dispersionParameter: Double,
    val degreesOfFreedom: Long,
    val residualDegreeOfFreedom: Long,
    val residualDegreeOfFreedomNull: Long,
    val coefficientStandardErrors: Array[Double],
    val tValues: Array[Double],
    val pValues: Array[Double])
  extends GeneralizedLinearRegressionSummary(predictions, predictionCol, labelCol, featuresCol) {
}
135
136
/** Evaluation summary for a generalized linear model on a given prediction DataFrame. */
class GeneralizedLinearRegressionSummary(
    predictions: DataFrame,
    predictionCol: String,
    labelCol: String,
    featuresCol: String) extends Serializable {
  val residuals: DataFrame
  val rank: Long
  val numInstances: Long
}
142
```
143
144
## Decision Tree Regressor
145
146
### Estimator
147
148
```scala { .api }
149
/** Estimator whose `fit` produces a [[DecisionTreeRegressionModel]]. */
class DecisionTreeRegressor(override val uid: String)
  extends Regressor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel]
  with DecisionTreeRegressorParams with DefaultParamsWritable {

  /** Uses a uid generated by `Identifiable.randomUID("dtr")`. */
  def this() = this(Identifiable.randomUID("dtr"))

  // Chainable setters: each one returns DecisionTreeRegressor.
  def setMaxDepth(value: Int): DecisionTreeRegressor
  def setMaxBins(value: Int): DecisionTreeRegressor
  def setMinInstancesPerNode(value: Int): DecisionTreeRegressor
  def setMinInfoGain(value: Double): DecisionTreeRegressor
  def setMaxMemoryInMB(value: Int): DecisionTreeRegressor
  def setCacheNodeIds(value: Boolean): DecisionTreeRegressor
  def setCheckpointInterval(value: Int): DecisionTreeRegressor
  def setImpurity(value: String): DecisionTreeRegressor
  def setVarianceCol(value: String): DecisionTreeRegressor
  def setSeed(value: Long): DecisionTreeRegressor

  override def fit(dataset: Dataset[_]): DecisionTreeRegressionModel
  override def copy(extra: ParamMap): DecisionTreeRegressor
}
169
```
170
171
### Model
172
173
```scala { .api }
174
/** Fitted single decision tree for regression, rooted at `rootNode`. */
class DecisionTreeRegressionModel(
    override val uid: String,
    val rootNode: Node,
    val numFeatures: Int)
  extends RegressionModel[Vector, DecisionTreeRegressionModel]
  // A single tree mixes in DecisionTreeModel; TreeEnsembleModel is the trait
  // used by the multi-tree (random forest / GBT) models.
  with DecisionTreeRegressorParams with DecisionTreeModel with MLWritable {

  override def predict(features: Vector): Double
  def depth: Int
  def numNodes: Int
  def featureImportances: Vector
  def toDebugString: String
  override def copy(extra: ParamMap): DecisionTreeRegressionModel
  def write: MLWriter
}
186
```
187
188
## Random Forest Regressor
189
190
### Estimator
191
192
```scala { .api }
193
/** Estimator whose `fit` produces a [[RandomForestRegressionModel]]. */
class RandomForestRegressor(override val uid: String)
  extends Regressor[Vector, RandomForestRegressor, RandomForestRegressionModel]
  with RandomForestRegressorParams with DefaultParamsWritable {

  /** Uses a uid generated by `Identifiable.randomUID("rfr")`. */
  def this() = this(Identifiable.randomUID("rfr"))

  // Chainable setters: each one returns RandomForestRegressor.
  def setNumTrees(value: Int): RandomForestRegressor
  def setMaxDepth(value: Int): RandomForestRegressor
  def setMaxBins(value: Int): RandomForestRegressor
  def setMinInstancesPerNode(value: Int): RandomForestRegressor
  def setMinInfoGain(value: Double): RandomForestRegressor
  def setMaxMemoryInMB(value: Int): RandomForestRegressor
  def setCacheNodeIds(value: Boolean): RandomForestRegressor
  def setCheckpointInterval(value: Int): RandomForestRegressor
  def setImpurity(value: String): RandomForestRegressor
  def setSubsamplingRate(value: Double): RandomForestRegressor
  def setSeed(value: Long): RandomForestRegressor
  def setFeatureSubsetStrategy(value: String): RandomForestRegressor

  override def fit(dataset: Dataset[_]): RandomForestRegressionModel
  override def copy(extra: ParamMap): RandomForestRegressor
}
215
```
216
217
### Model
218
219
```scala { .api }
220
/** Fitted random forest: an ensemble of [[DecisionTreeRegressionModel]] trees. */
class RandomForestRegressionModel(
    override val uid: String,
    private val _trees: Array[DecisionTreeRegressionModel],
    val numFeatures: Int)
  extends RegressionModel[Vector, RandomForestRegressionModel]
  with RandomForestRegressorParams with TreeEnsembleModel with MLWritable {

  def trees: Array[DecisionTreeRegressionModel]
  def treeWeights: Array[Double]
  def featureImportances: Vector
  /** Number of trees in the forest. */
  def getNumTrees: Int

  override def predict(features: Vector): Double
  def totalNumNodes: Int
  def toDebugString: String
  override def copy(extra: ParamMap): RandomForestRegressionModel
  def write: MLWriter
}
235
```
236
237
## Gradient Boosted Tree Regressor
238
239
### Estimator
240
241
```scala { .api }
242
/** Estimator whose `fit` produces a [[GBTRegressionModel]]. */
class GBTRegressor(override val uid: String)
  extends Regressor[Vector, GBTRegressor, GBTRegressionModel]
  with GBTRegressorParams with DefaultParamsWritable {

  /** Uses a uid generated by `Identifiable.randomUID("gbtr")`. */
  def this() = this(Identifiable.randomUID("gbtr"))

  // Chainable setters: each one returns GBTRegressor.
  def setMaxDepth(value: Int): GBTRegressor
  def setMaxBins(value: Int): GBTRegressor
  def setMinInstancesPerNode(value: Int): GBTRegressor
  def setMinInfoGain(value: Double): GBTRegressor
  def setMaxMemoryInMB(value: Int): GBTRegressor
  def setCacheNodeIds(value: Boolean): GBTRegressor
  def setCheckpointInterval(value: Int): GBTRegressor
  def setLossType(value: String): GBTRegressor
  def setMaxIter(value: Int): GBTRegressor
  def setStepSize(value: Double): GBTRegressor
  def setSubsamplingRate(value: Double): GBTRegressor
  def setSeed(value: Long): GBTRegressor
  def setFeatureSubsetStrategy(value: String): GBTRegressor
  def setValidationTol(value: Double): GBTRegressor
  def setValidationIndicatorCol(value: String): GBTRegressor

  override def fit(dataset: Dataset[_]): GBTRegressionModel
  override def copy(extra: ParamMap): GBTRegressor
}
267
```
268
269
### Model
270
271
```scala { .api }
272
/** Fitted gradient-boosted ensemble of weighted [[DecisionTreeRegressionModel]] trees. */
class GBTRegressionModel(
    override val uid: String,
    private val _trees: Array[DecisionTreeRegressionModel],
    private val _treeWeights: Array[Double],
    val numFeatures: Int)
  extends RegressionModel[Vector, GBTRegressionModel]
  with GBTRegressorParams with TreeEnsembleModel with MLWritable {

  def trees: Array[DecisionTreeRegressionModel]
  def treeWeights: Array[Double]
  def featureImportances: Vector
  def totalNumNodes: Int
  def getNumTrees: Int

  override def predict(features: Vector): Double
  def toDebugString: String
  override def copy(extra: ParamMap): GBTRegressionModel
  def write: MLWriter
}
288
```
289
290
## Accelerated Failure Time Survival Regression
291
292
### Estimator
293
294
```scala { .api }
295
/** Estimator whose `fit` produces an [[AFTSurvivalRegressionModel]]. */
class AFTSurvivalRegression(override val uid: String)
  extends Regressor[Vector, AFTSurvivalRegression, AFTSurvivalRegressionModel]
  with AFTSurvivalRegressionParams with DefaultParamsWritable {

  /** Uses a uid generated by `Identifiable.randomUID("aft")`. */
  def this() = this(Identifiable.randomUID("aft"))

  // Chainable setters: each one returns AFTSurvivalRegression.
  def setCensorCol(value: String): AFTSurvivalRegression
  def setQuantileProbabilities(value: Array[Double]): AFTSurvivalRegression
  def setQuantilesCol(value: String): AFTSurvivalRegression
  def setFitIntercept(value: Boolean): AFTSurvivalRegression
  def setMaxIter(value: Int): AFTSurvivalRegression
  def setTol(value: Double): AFTSurvivalRegression
  def setAggregationDepth(value: Int): AFTSurvivalRegression

  override def fit(dataset: Dataset[_]): AFTSurvivalRegressionModel
  override def copy(extra: ParamMap): AFTSurvivalRegression
}
312
```
313
314
### Model
315
316
```scala { .api }
317
/** Fitted AFT survival model exposing its coefficients, intercept and scale. */
class AFTSurvivalRegressionModel(
    override val uid: String,
    val coefficients: Vector,
    val intercept: Double,
    val scale: Double)
  extends RegressionModel[Vector, AFTSurvivalRegressionModel]
  with AFTSurvivalRegressionParams with MLWritable {

  def predictQuantiles(features: Vector): Vector
  override def predict(features: Vector): Double
  override def copy(extra: ParamMap): AFTSurvivalRegressionModel
  def write: MLWriter
}
327
```
328
329
## Isotonic Regression
330
331
### Estimator
332
333
```scala { .api }
334
/**
 * Estimator whose `fit` produces an [[IsotonicRegressionModel]].
 *
 * Unlike the other regressors listed in this document, Spark defines this class
 * directly on `Estimator` rather than `Regressor`, so the label / features /
 * prediction column setters are declared on the class itself.
 */
class IsotonicRegression(override val uid: String)
  extends Estimator[IsotonicRegressionModel]
  with IsotonicRegressionParams with DefaultParamsWritable {

  /** Uses a uid generated by `Identifiable.randomUID("isoReg")`. */
  def this() = this(Identifiable.randomUID("isoReg"))

  // Chainable setters: each one returns IsotonicRegression.
  def setLabelCol(value: String): IsotonicRegression
  def setFeaturesCol(value: String): IsotonicRegression
  def setPredictionCol(value: String): IsotonicRegression
  def setIsotonic(value: Boolean): IsotonicRegression
  def setWeightCol(value: String): IsotonicRegression
  def setFeatureIndex(value: Int): IsotonicRegression

  override def fit(dataset: Dataset[_]): IsotonicRegressionModel
  override def copy(extra: ParamMap): IsotonicRegression
}
347
```
348
349
### Model
350
351
```scala { .api }
352
/**
 * Fitted isotonic regression model.
 *
 * In Spark this is a plain `Model` (not a `RegressionModel`) that wraps the
 * RDD-based `org.apache.spark.mllib.regression.IsotonicRegressionModel`;
 * `boundaries` and `predictions` are exposed as ML `Vector`s.
 */
class IsotonicRegressionModel private[ml] (
    override val uid: String,
    private val oldModel: org.apache.spark.mllib.regression.IsotonicRegressionModel)
  extends Model[IsotonicRegressionModel]
  with IsotonicRegressionParams with MLWritable {

  /** Boundaries in increasing order for which predictions are known. */
  def boundaries: Vector
  /** Predictions associated with the boundaries at the same index. */
  def predictions: Vector
  val numFeatures: Int

  /** Predicts from a single feature value (Spark 3.0+). */
  def predict(value: Double): Double
  override def transform(dataset: Dataset[_]): DataFrame
  override def copy(extra: ParamMap): IsotonicRegressionModel
  def write: MLWriter
}
361
```
362
363
## Usage Examples
364
365
### Basic Linear Regression
366
367
```scala
368
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.feature.VectorAssembler

// Prepare features.
// `rawData` is assumed to be an existing DataFrame with the three feature
// columns named below plus a "label" column.
val assembler = new VectorAssembler()
  .setInputCols(Array("feature1", "feature2", "feature3"))
  .setOutputCol("features")

val data = assembler.transform(rawData)

// Create and configure linear regression
val lr = new LinearRegression()
  .setLabelCol("label")
  .setFeaturesCol("features")
  .setRegParam(0.1)
  .setElasticNetParam(0.8)
  .setMaxIter(100)
  .setTol(1E-6)

// Split data (fixed seed keeps the split reproducible)
val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3), seed = 42)

// Train model
val lrModel = lr.fit(trainingData)

// Print coefficients and intercept
println(s"Coefficients: ${lrModel.coefficients}")
println(s"Intercept: ${lrModel.intercept}")

// Make predictions
val predictions = lrModel.transform(testData)
predictions.select("features", "label", "prediction").show()

// Get training summary (metrics are computed on the training data)
val trainingSummary = lrModel.summary
println(s"RMSE: ${trainingSummary.rootMeanSquaredError}")
println(s"R-squared: ${trainingSummary.r2}")
println(s"Mean Absolute Error: ${trainingSummary.meanAbsoluteError}")
406
```
407
408
### Generalized Linear Model
409
410
```scala
411
import org.apache.spark.ml.regression.GeneralizedLinearRegression

// Create GLM with Poisson family and log link
val glr = new GeneralizedLinearRegression()
  .setFamily("poisson")
  .setLink("log")
  .setMaxIter(10)
  .setRegParam(0.3)

// `trainingData` is assumed to be a DataFrame with "features" and "label" columns
val glrModel = glr.fit(trainingData)

// Print model summary
val glrSummary = glrModel.summary
println(s"Coefficients: ${glrModel.coefficients}")
println(s"Intercept: ${glrModel.intercept}")
println(s"AIC: ${glrSummary.aic}")
println(s"Deviance: ${glrSummary.deviance}")

// Statistical significance tests
println("Coefficient Standard Errors:")
glrSummary.coefficientStandardErrors.zipWithIndex.foreach {
  case (se, idx) => println(s" Coefficient $idx: $se")
}

println("T-Values:")
glrSummary.tValues.zipWithIndex.foreach {
  case (t, idx) => println(s" Coefficient $idx: $t")
}

println("P-Values:")
glrSummary.pValues.zipWithIndex.foreach {
  case (p, idx) => println(s" Coefficient $idx: $p")
}
444
```
445
446
### Random Forest Regression
447
448
```scala
449
import org.apache.spark.ml.regression.RandomForestRegressor
import org.apache.spark.ml.evaluation.RegressionEvaluator

val rf = new RandomForestRegressor()
  .setLabelCol("label")
  .setFeaturesCol("features")
  .setNumTrees(100)
  .setMaxDepth(10)
  .setMaxBins(32)
  .setMinInstancesPerNode(1)
  .setMinInfoGain(0.0)
  .setSubsamplingRate(1.0)
  .setFeatureSubsetStrategy("auto")
  .setSeed(42)

// `trainingData` / `testData` are assumed to be DataFrames with
// "features" and "label" columns
val rfModel = rf.fit(trainingData)

// Get feature importances
val featureImportances = rfModel.featureImportances
println(s"Feature importances: $featureImportances")

// Make predictions
val rfPredictions = rfModel.transform(testData)

// Evaluate model
val evaluator = new RegressionEvaluator()
  .setLabelCol("label")
  .setPredictionCol("prediction")

// The same evaluator instance is reconfigured for each metric in turn
val metrics = Array("rmse", "mse", "mae", "r2")
metrics.foreach { metric =>
  evaluator.setMetricName(metric)
  val result = evaluator.evaluate(rfPredictions)
  println(s"$metric: $result")
}

// Print tree structure information
println(s"Total number of trees: ${rfModel.getNumTrees}")
println(s"Total number of nodes: ${rfModel.totalNumNodes}")
488
```
489
490
### Gradient Boosted Trees Regression
491
492
```scala
493
import org.apache.spark.ml.regression.GBTRegressor
import org.apache.spark.sql.functions.rand

val gbt = new GBTRegressor()
  .setLabelCol("label")
  .setFeaturesCol("features")
  .setMaxIter(100)
  .setMaxDepth(5)
  .setStepSize(0.1)
  .setSubsamplingRate(1.0)
  .setFeatureSubsetStrategy("auto")
  .setSeed(42)

// Add validation for early stopping.
// Setters mutate and return the same instance, so `gbtWithValidation` and `gbt`
// refer to one estimator that now expects a boolean "isValidation" column.
val gbtWithValidation = gbt
  .setValidationTol(0.01)
  .setValidationIndicatorCol("isValidation")

// The indicator column must exist in the DataFrame passed to fit: rows where it
// is true are held out for validation, the rest are used for training.
val trainingWithIndicator = trainingData.withColumn("isValidation", rand(42) < 0.2)

val gbtModel = gbtWithValidation.fit(trainingWithIndicator)

// Print model information
println(s"Feature importances: ${gbtModel.featureImportances}")
println(s"Number of trees: ${gbtModel.getNumTrees}")
println(s"Tree weights: ${gbtModel.treeWeights.mkString(", ")}")

val gbtPredictions = gbtModel.transform(testData)

// Evaluate.
// `evaluator` is assumed to be an existing RegressionEvaluator configured with
// the "label" and "prediction" columns.
val rmse = evaluator.setMetricName("rmse").evaluate(gbtPredictions)
println(s"RMSE on test data: $rmse")
522
```
523
524
### Survival Regression
525
526
```scala
527
import org.apache.spark.ml.regression.AFTSurvivalRegression

// Data should have label (survival time), censor indicator, and features
val aft = new AFTSurvivalRegression()
  .setLabelCol("survival_time")
  .setCensorCol("censor")
  .setFeaturesCol("features")
  .setQuantileProbabilities(Array(0.1, 0.5, 0.9))
  .setQuantilesCol("quantiles")

val aftModel = aft.fit(trainingData)

// Print model parameters
println(s"Coefficients: ${aftModel.coefficients}")
println(s"Intercept: ${aftModel.intercept}")
println(s"Scale: ${aftModel.scale}")

// Make predictions including quantiles (written to the "quantiles" column set above)
val aftPredictions = aftModel.transform(testData)
aftPredictions.select("survival_time", "prediction", "quantiles").show()
547
```
548
549
### Isotonic Regression
550
551
```scala
552
import org.apache.spark.ml.regression.IsotonicRegression

// Isotonic regression expects monotonic relationship
val iso = new IsotonicRegression()
  .setLabelCol("label")
  .setFeaturesCol("features")
  .setIsotonic(true) // true for increasing, false for decreasing
  .setFeatureIndex(0) // which feature to use (for vector inputs)

val isoModel = iso.fit(trainingData)

// Print isotonic boundaries and predictions.
// Spark returns these as ML Vectors, which have no mkString; convert to arrays first.
println(s"Boundaries: ${isoModel.boundaries.toArray.mkString(", ")}")
println(s"Predictions: ${isoModel.predictions.toArray.mkString(", ")}")

val isoPredictions = isoModel.transform(testData)
isoPredictions.select("features", "label", "prediction").show()
569
```
570
571
### Cross-Validation for Hyperparameter Tuning
572
573
```scala
574
import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}
// RandomForestRegressionModel is needed for the bestModel cast below
import org.apache.spark.ml.regression.{RandomForestRegressor, RandomForestRegressionModel}
import org.apache.spark.ml.evaluation.RegressionEvaluator

val rf = new RandomForestRegressor()
  .setLabelCol("label")
  .setFeaturesCol("features")

// Build parameter grid (3 x 3 x 3 = 27 parameter combinations)
val paramGrid = new ParamGridBuilder()
  .addGrid(rf.numTrees, Array(10, 20, 50))
  .addGrid(rf.maxDepth, Array(5, 10, 15))
  .addGrid(rf.minInstancesPerNode, Array(1, 2, 5))
  .build()

// Create evaluator
val evaluator = new RegressionEvaluator()
  .setLabelCol("label")
  .setPredictionCol("prediction")
  .setMetricName("rmse")

// Create cross-validator
val cv = new CrossValidator()
  .setEstimator(rf)
  .setEvaluator(evaluator)
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(3)
  .setParallelism(2)

// Train model with cross-validation
val cvModel = cv.fit(trainingData)

// Get best model. The cast is safe here because the estimator passed to the
// cross-validator is a RandomForestRegressor.
val bestModel = cvModel.bestModel.asInstanceOf[RandomForestRegressionModel]
println(s"Best model num trees: ${bestModel.getNumTrees}")
println(s"Best model max depth: ${bestModel.getMaxDepth}")

// Evaluate on test data
val finalPredictions = cvModel.transform(testData)
val finalRMSE = evaluator.evaluate(finalPredictions)
println(s"Final RMSE on test data: $finalRMSE")
615
```