# Regression Algorithms

Supervised learning algorithms for predicting continuous values. MLlib provides comprehensive regression capabilities, including linear models, tree-based methods, generalized linear models, isotonic regression, factorization machines, and accelerated failure time (AFT) models for survival analysis.

## Capabilities

### Linear Regression

Linear regression with L1, L2, and elastic net regularization for modeling linear relationships, with optional Huber loss for robustness to outliers.

```scala { .api }
/**
 * Linear regression with support for L1, L2, and elastic net regularization.
 * Supports both a normal-equation solver and iterative (L-BFGS / OWL-QN)
 * optimization, with either squared-error or Huber loss.
 */
class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegressionModel] {
  /** Regularization strength (>= 0). */
  def setRegParam(value: Double): this.type
  /** Elastic net mixing parameter in [0, 1]: 0 = pure L2 (ridge), 1 = pure L1 (lasso). */
  def setElasticNetParam(value: Double): this.type
  def setMaxIter(value: Int): this.type
  /** Convergence tolerance for the iterative solvers. */
  def setTol(value: Double): this.type
  def setFitIntercept(value: Boolean): this.type
  /** Standardize features before fitting; coefficients are always reported on the original scale. */
  def setStandardization(value: Boolean): this.type
  def setWeightCol(value: String): this.type
  /** "auto" (default), "normal" or "l-bfgs". */
  def setSolver(value: String): this.type
  def setAggregationDepth(value: Int): this.type
  /** "squaredError" (default) or "huber". */
  def setLoss(value: String): this.type
  /** Huber shape parameter (must be > 1.0); only used when loss is "huber". */
  def setEpsilon(value: Double): this.type
  def setMaxBlockSizeInMB(value: Double): this.type
}

class LinearRegressionModel extends RegressionModel[Vector, LinearRegressionModel]
  with LinearRegressionParams with MLWritable {
  def coefficients: Vector
  def intercept: Double
  def numFeatures: Int
  /** Huber scale parameter; 1.0 when loss is "squaredError". */
  def scale: Double
  /**
   * Summary of the fit on the training data. Only available when `hasSummary`
   * is true (a model loaded from disk carries no training summary).
   */
  def summary: LinearRegressionTrainingSummary
  def hasSummary: Boolean
  /** Computes the same metrics on another dataset (e.g. a held-out test set). */
  def evaluate(dataset: Dataset[_]): LinearRegressionSummary
}

// Model summary classes
class LinearRegressionSummary extends Serializable {
  def predictions: DataFrame
  def predictionCol: String
  def labelCol: String
  def featuresCol: String
  def explainedVariance: Double
  def meanAbsoluteError: Double
  def meanSquaredError: Double
  def rootMeanSquaredError: Double
  def r2: Double
  def r2adj: Double
  def residuals: DataFrame
  def numInstances: Long
  def degreesOfFreedom: Long
  def devianceResiduals: Array[Double]

  // Coefficient statistics are only available when the "normal" solver was used.
  // When fitIntercept is true, the last element corresponds to the intercept.
  def coefficientStandardErrors: Array[Double]
  def tValues: Array[Double]
  def pValues: Array[Double]
}

/** Summary produced by `fit`; adds optimizer diagnostics to the evaluation metrics. */
class LinearRegressionTrainingSummary extends LinearRegressionSummary {
  def totalIterations: Int
  def objectiveHistory: Array[Double]
}
```

**Usage Example:**

```scala
import org.apache.spark.ml.regression.LinearRegression

// Assumes trainingData and testData are DataFrames with a Double "label" column
// and a Vector "features" column.
val lr = new LinearRegression()
  .setLabelCol("label")
  .setFeaturesCol("features")
  .setRegParam(0.01)
  .setElasticNetParam(0.5)
  .setMaxIter(100)

val model = lr.fit(trainingData)
val predictions = model.transform(testData)

// Access model information
println(s"Coefficients: ${model.coefficients}")
println(s"Intercept: ${model.intercept}")

// Access training summary (metrics computed on the training data)
val summary = model.summary
println(s"RMSE: ${summary.rootMeanSquaredError}")
println(s"R-squared: ${summary.r2}")
println(s"Mean Absolute Error: ${summary.meanAbsoluteError}")

// Evaluate on held-out data; training metrics alone overstate model quality
val testSummary = model.evaluate(testData)
println(s"Test RMSE: ${testSummary.rootMeanSquaredError}")
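```

### Decision Tree Regressor

Decision tree algorithm for regression that handles both continuous and categorical features. Categorical features are recognized from the features column metadata (for example, as produced by `VectorIndexer`) rather than detected automatically.

```scala { .api }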
/**
 * Decision tree regressor supporting continuous and categorical features,
 * with configurable tree depth and splitting criteria.
 * Categorical features are identified from the features column metadata.
 */
class DecisionTreeRegressor extends Regressor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel] {
  def setMaxDepth(value: Int): this.type
  /** Bins for discretizing continuous features; must be >= 2 and >= the arity of every categorical feature. */
  def setMaxBins(value: Int): this.type
  def setMinInstancesPerNode(value: Int): this.type
  def setMinWeightFractionPerNode(value: Double): this.type
  def setMinInfoGain(value: Double): this.type
  def setMaxMemoryInMB(value: Int): this.type
  def setCacheNodeIds(value: Boolean): this.type
  def setCheckpointInterval(value: Int): this.type
  /** Only "variance" is supported for regression. */
  def setImpurity(value: String): this.type
  /** Optional output column for the (biased) sample variance of each prediction. */
  def setVarianceCol(value: String): this.type
  def setSeed(value: Long): this.type
  def setWeightCol(value: String): this.type
}

class DecisionTreeRegressionModel extends RegressionModel[Vector, DecisionTreeRegressionModel]
  with DecisionTreeModel with MLWritable {
  def rootNode: Node
  def numNodes: Int
  def depth: Int
  def numFeatures: Int
  def toDebugString: String
  def featureImportances: Vector
}
```

### Random Forest Regressor

Ensemble of decision trees trained with bootstrap aggregating (bagging) and random feature subsets. Averaging the trees' predictions reduces variance and improves accuracy over a single tree.

```scala { .api }
/**
 * Random Forest regressor - ensemble of decision trees with bootstrap sampling
 * and random feature selection for robust regression modeling.
 * The prediction is the average of the individual tree predictions.
 */
class RandomForestRegressor extends Regressor[Vector, RandomForestRegressor, RandomForestRegressionModel] {
  def setNumTrees(value: Int): this.type
  def setMaxDepth(value: Int): this.type
  def setMaxBins(value: Int): this.type
  def setMinInstancesPerNode(value: Int): this.type
  def setMinWeightFractionPerNode(value: Double): this.type
  def setMinInfoGain(value: Double): this.type
  def setMaxMemoryInMB(value: Int): this.type
  def setCacheNodeIds(value: Boolean): this.type
  def setCheckpointInterval(value: Int): this.type
  /** Only "variance" is supported for regression. */
  def setImpurity(value: String): this.type
  /** Fraction of the training data sampled for each tree, in (0, 1]. */
  def setSubsamplingRate(value: Double): this.type
  /** "auto", "all", "onethird", "sqrt", "log2", a fraction in (0, 1], or an integer count. */
  def setFeatureSubsetStrategy(value: String): this.type
  /** Whether each tree is trained on a bootstrap sample (default true). */
  def setBootstrap(value: Boolean): this.type
  def setSeed(value: Long): this.type
  def setWeightCol(value: String): this.type
}

class RandomForestRegressionModel extends RegressionModel[Vector, RandomForestRegressionModel]
  with TreeEnsembleModel[DecisionTreeRegressionModel] with MLWritable {
  def trees: Array[DecisionTreeRegressionModel]
  /** Number of trees in the ensemble. */
  def getNumTrees: Int
  /** Total number of nodes, summed over all trees. */
  def totalNumNodes: Int
  def treeWeights: Array[Double]
  def featureImportances: Vector
  def toDebugString: String
}
```

### Gradient-Boosted Tree Regressor

Iterative ensemble method that trains decision trees sequentially. Each new tree is fit to the errors (the loss gradient) of the ensemble built so far.

```scala { .api }
/**
 * Gradient-Boosted Tree regressor using iterative boosting
 * to build an ensemble of weak decision tree learners.
 */
class GBTRegressor extends Regressor[Vector, GBTRegressor, GBTRegressionModel] {
  /** Number of boosting iterations, i.e. the number of trees in the ensemble. */
  def setMaxIter(value: Int): this.type
  /** Learning rate in (0, 1] that shrinks the contribution of each tree. */
  def setStepSize(value: Double): this.type
  def setMaxDepth(value: Int): this.type
  def setMaxBins(value: Int): this.type
  def setMinInstancesPerNode(value: Int): this.type
  def setMinWeightFractionPerNode(value: Double): this.type
  def setMinInfoGain(value: Double): this.type
  def setMaxMemoryInMB(value: Int): this.type
  def setCacheNodeIds(value: Boolean): this.type
  def setSubsamplingRate(value: Double): this.type
  def setCheckpointInterval(value: Int): this.type
  /** "squared" (L2, default) or "absolute" (L1). */
  def setLossType(value: String): this.type
  def setFeatureSubsetStrategy(value: String): this.type
  /** Early-stopping tolerance; only used when validationIndicatorCol is set. */
  def setValidationTol(value: Double): this.type
  /** Boolean column marking validation rows; setting it enables early stopping. */
  def setValidationIndicatorCol(value: String): this.type
  def setSeed(value: Long): this.type
  def setWeightCol(value: String): this.type
}

class GBTRegressionModel extends RegressionModel[Vector, GBTRegressionModel]
  with TreeEnsembleModel[DecisionTreeRegressionModel] with MLWritable {
  def trees: Array[DecisionTreeRegressionModel]
  /** Weight applied to each tree's prediction when summing the ensemble. */
  def treeWeights: Array[Double]
  /** Number of trees in the ensemble. */
  def getNumTrees: Int
  /** Total number of nodes, summed over all trees. */
  def totalNumNodes: Int
  def featureImportances: Vector
  def toDebugString: String
  /** Loss of the partial ensemble after each boosting iteration on `dataset`. */
  def evaluateEachIteration(dataset: Dataset[_], loss: String): Array[Double]
}
```

### Generalized Linear Regression

Generalized linear models supporting various error distributions (families) and link functions.

```scala { .api }
/**
 * Generalized Linear Regression supporting multiple distributions and link functions
 * including Gaussian, Binomial, Poisson, Gamma, and Tweedie distributions.
 */
class GeneralizedLinearRegression extends Regressor[Vector, GeneralizedLinearRegression, GeneralizedLinearRegressionModel] {
  /** "gaussian" (default), "binomial", "poisson", "gamma" or "tweedie". */
  def setFamily(value: String): this.type
  /** Link function; when unset, the default link of the chosen family is used. */
  def setLink(value: String): this.type
  def setFitIntercept(value: Boolean): this.type
  def setMaxIter(value: Int): this.type
  def setTol(value: Double): this.type
  /** L2 regularization strength. */
  def setRegParam(value: Double): this.type
  def setWeightCol(value: String): this.type
  /** Only "irls" (iteratively reweighted least squares) is supported. */
  def setSolver(value: String): this.type
  /** Variance power; only used with the "tweedie" family. */
  def setVariancePower(value: Double): this.type
  /** Link power; only used with the "tweedie" family. */
  def setLinkPower(value: Double): this.type
  /** Column of offsets added to the linear predictor with a fixed coefficient of 1.0. */
  def setOffsetCol(value: String): this.type
  /** Optional output column for the linear predictor (link-scale prediction). */
  def setLinkPredictionCol(value: String): this.type
}

class GeneralizedLinearRegressionModel extends RegressionModel[Vector, GeneralizedLinearRegressionModel]
  with GeneralizedLinearRegressionParams with MLWritable {
  def coefficients: Vector
  def intercept: Double
  def numFeatures: Int
  /** Summary of the fit on the training data; only available when `hasSummary` is true. */
  def summary: GeneralizedLinearRegressionTrainingSummary
  def hasSummary: Boolean
  /** Computes the summary statistics on another dataset. */
  def evaluate(dataset: Dataset[_]): GeneralizedLinearRegressionSummary
}

// GLM summary classes
class GeneralizedLinearRegressionSummary extends Serializable {
  def predictions: DataFrame
  def predictionCol: String
  def numInstances: Long
  def degreesOfFreedom: Long
  def residualDegreeOfFreedom: Long
  def residualDegreeOfFreedomNull: Long
  def aic: Double
  def deviance: Double
  def nullDeviance: Double
  /** Dispersion; fixed at 1.0 for the "binomial" and "poisson" families, estimated otherwise. */
  def dispersion: Double
  /** Deviance residuals. */
  def residuals(): DataFrame
  /** residualsType: "deviance", "pearson", "working" or "response". */
  def residuals(residualsType: String): DataFrame
  def rank: Long
}

/** Summary produced by `fit`; adds solver information and coefficient statistics. */
class GeneralizedLinearRegressionTrainingSummary extends GeneralizedLinearRegressionSummary {
  def numIterations: Int
  def solver: String
  // When fitIntercept is true, the last element corresponds to the intercept.
  def coefficientStandardErrors: Array[Double]
  def tValues: Array[Double]
  def pValues: Array[Double]
}
```

### Isotonic Regression

Non-parametric regression that fits a monotonic (non-decreasing or non-increasing) piecewise-linear function of a single feature.

```scala { .api }
/**
 * Isotonic regression for non-decreasing (isotonic) or non-increasing (antitonic)
 * relationships between variables using the pool-adjacent-violators algorithm.
 *
 * Unlike the other regressors on this page, this is a plain Estimator rather than a
 * Regressor. The features column may be a Double column, or a Vector column from
 * which `featureIndex` selects the single feature used.
 */
class IsotonicRegression extends Estimator[IsotonicRegressionModel]
  with IsotonicRegressionBase with DefaultParamsWritable {
  def setLabelCol(value: String): this.type
  def setFeaturesCol(value: String): this.type
  def setPredictionCol(value: String): this.type
  /** true (default) fits a non-decreasing function, false a non-increasing one. */
  def setIsotonic(value: Boolean): this.type
  /** Index of the feature to use when the features column is a Vector (default 0). */
  def setFeatureIndex(value: Int): this.type
  def setWeightCol(value: String): this.type
}

class IsotonicRegressionModel extends Model[IsotonicRegressionModel]
  with IsotonicRegressionBase with MLWritable {
  /** Sorted feature values at which the fitted function changes. */
  def boundaries: Vector
  /** Fitted values at each boundary. */
  def predictions: Vector
  def numFeatures: Int = 1

  // Predictions are produced through transform(): inputs between two boundaries are
  // linearly interpolated, inputs outside the range take the nearest endpoint value.
  override def transform(dataset: Dataset[_]): DataFrame
}
```

### AFT Survival Regression

Accelerated Failure Time model for survival analysis of time-to-event data, including right-censored observations.

```scala { .api }
/**
 * Accelerated Failure Time (AFT) survival regression model
 * for modeling time-to-event data with censoring support.
 * Assumes a Weibull distribution of survival times; labels must be positive.
 */
class AFTSurvivalRegression extends Regressor[Vector, AFTSurvivalRegression, AFTSurvivalRegressionModel] {
  /** Censor column: 1.0 = event occurred (uncensored), 0.0 = censored. */
  def setCensorCol(value: String): this.type
  /** Non-empty array of probabilities in (0, 1) used for the quantiles output. */
  def setQuantileProbabilities(value: Array[Double]): this.type
  /** Output column for predicted quantiles; quantiles are only produced when this is set. */
  def setQuantilesCol(value: String): this.type
  def setFitIntercept(value: Boolean): this.type
  def setMaxIter(value: Int): this.type
  def setTol(value: Double): this.type
  def setAggregationDepth(value: Int): this.type
  def setMaxBlockSizeInMB(value: Double): this.type
}

class AFTSurvivalRegressionModel extends RegressionModel[Vector, AFTSurvivalRegressionModel]
  with AFTSurvivalRegressionParams with MLWritable {
  def coefficients: Vector
  def intercept: Double
  /** Fitted scale parameter of the log-time distribution. */
  def scale: Double
  def numFeatures: Int
  def setQuantileProbabilities(value: Array[Double]): this.type
  def setQuantilesCol(value: String): this.type

  /** Predicted survival-time quantiles at the configured quantile probabilities. */
  def predictQuantiles(features: Vector): Vector
  /** Predicted survival time: exp(coefficients . features + intercept). */
  override def predict(features: Vector): Double
}
```

### Factorization Machine Regressor

Factorization machines for regression, which model pairwise feature interactions through learned low-rank factors and are well suited to sparse, high-dimensional data.

```scala { .api }
/**
 * Factorization Machine regressor for modeling feature interactions
 * through low-rank matrix factorization, effective for sparse data.
 */
class FMRegressor extends Regressor[Vector, FMRegressor, FMRegressionModel] {
  /** Dimensionality of the factor vectors used for pairwise interactions. */
  def setFactorSize(value: Int): this.type
  def setFitIntercept(value: Boolean): this.type
  /** Whether to fit the linear (1-way) term. */
  def setFitLinear(value: Boolean): this.type
  /** L2 regularization strength. */
  def setRegParam(value: Double): this.type
  /** Fraction of the data sampled per mini-batch, in (0, 1]. */
  def setMiniBatchFraction(value: Double): this.type
  /** Standard deviation used to initialize the factor coefficients. */
  def setInitStd(value: Double): this.type
  def setMaxIter(value: Int): this.type
  /** Initial learning rate of the optimizer. */
  def setStepSize(value: Double): this.type
  def setTol(value: Double): this.type
  /** "adamW" (default) or "gd". */
  def setSolver(value: String): this.type
  def setSeed(value: Long): this.type
}

class FMRegressionModel extends RegressionModel[Vector, FMRegressionModel]
  with FMRegressorParams with MLWritable {
  /** Linear (1-way) coefficients. */
  def linear: Vector
  /** Pairwise-interaction factors, one row per feature. */
  def factors: Matrix
  def intercept: Double
  def numFeatures: Int
}
```

## Base Classes and Traits

```scala { .api }
// Core regression abstractions
abstract class Regressor[
    FeaturesType,
    Learner <: Regressor[FeaturesType, Learner, M],
    M <: RegressionModel[FeaturesType, M]]
  extends Predictor[FeaturesType, Learner, M] with RegressorParams

abstract class RegressionModel[FeaturesType, M <: RegressionModel[FeaturesType, M]]
  extends PredictionModel[FeaturesType, M] with RegressorParams

// Parameter traits
trait RegressorParams extends PredictorParams

// Tree-specific parameters (shared with classification)
trait DecisionTreeParams extends Params {
  final val maxDepth: IntParam
  final val maxBins: IntParam
  final val minInstancesPerNode: IntParam
  final val minWeightFractionPerNode: DoubleParam
  final val minInfoGain: DoubleParam
  final val impurity: Param[String]
}

// Parameters common to all tree ensembles (random forests and GBTs)
trait TreeEnsembleParams extends DecisionTreeParams {
  final val subsamplingRate: DoubleParam
  final val featureSubsetStrategy: Param[String]
}

// Random-forest-only parameters; GBT sets its ensemble size with maxIter instead
trait RandomForestParams extends TreeEnsembleParams {
  final val numTrees: IntParam
  final val bootstrap: BooleanParam
}

// GLM-specific parameters
trait GeneralizedLinearRegressionParams extends Params {
  final val family: Param[String]
  final val link: Param[String]
  final val solver: Param[String]
  final val variancePower: DoubleParam
  final val linkPower: DoubleParam
  final val offsetCol: Param[String]
  final val linkPredictionCol: Param[String]
}
```

## Usage Examples

### Basic Regression Workflow

```scala
import org.apache.spark.ml.regression.RandomForestRegressor
import org.apache.spark.ml.evaluation.RegressionEvaluator

// Assumes trainingData and testData are DataFrames with a Double "label" column
// and a Vector "features" column.

// Create regressor
val rf = new RandomForestRegressor()
  .setLabelCol("label")
  .setFeaturesCol("features")
  .setNumTrees(20)
  .setMaxDepth(10)

// Train model
val model = rf.fit(trainingData)

// Make predictions
val predictions = model.transform(testData)

// Evaluate (other supported metrics: "mse", "r2", "mae", "var")
val evaluator = new RegressionEvaluator()
  .setLabelCol("label")
  .setPredictionCol("prediction")
  .setMetricName("rmse")

val rmse = evaluator.evaluate(predictions)
println(s"Root Mean Squared Error (RMSE) = $rmse")
```

### Generalized Linear Regression Example

```scala
import org.apache.spark.ml.regression.GeneralizedLinearRegression

// Poisson regression for count data (labels must be non-negative counts)
val glr = new GeneralizedLinearRegression()
  .setFamily("poisson")
  .setLink("log")
  .setMaxIter(10)
  .setRegParam(0.3)

val model = glr.fit(trainingData)
val predictions = model.transform(testData)

println(s"Coefficients: ${model.coefficients}")
println(s"Intercept: ${model.intercept}")

// Access the training summary; coefficient standard errors, t-values and p-values
// exist only on the training summary, not on summaries returned by evaluate()
val summary = model.summary
println(s"Coefficient Standard Errors: ${summary.coefficientStandardErrors.mkString(",")}")
println(s"T Values: ${summary.tValues.mkString(",")}")
println(s"P Values: ${summary.pValues.mkString(",")}")
println(s"Dispersion: ${summary.dispersion}")
println(s"Deviance: ${summary.deviance}")
println(s"AIC: ${summary.aic}")
```

### Survival Analysis Example

```scala
import org.apache.spark.ml.regression.AFTSurvivalRegression

// Assumes a positive survival-time column "time", a "censor" column
// (1.0 = event observed, 0.0 = censored) and a Vector "features" column.
val aft = new AFTSurvivalRegression()
  .setLabelCol("time")
  .setCensorCol("censor")
  .setFeaturesCol("features")
  .setQuantileProbabilities(Array(0.1, 0.5, 0.9))
  .setQuantilesCol("quantiles")

val model = aft.fit(trainingData)
val predictions = model.transform(testData)

println(s"Coefficients: ${model.coefficients}")
println(s"Intercept: ${model.intercept}")
println(s"Scale: ${model.scale}")

// Show quantile predictions
predictions.select("time", "censor", "prediction", "quantiles").show()
```

### Model Comparison

```scala
import org.apache.spark.ml.regression.{LinearRegression, RandomForestRegressor, GBTRegressor}
import org.apache.spark.ml.evaluation.RegressionEvaluator

// All three models use the default "label", "features" and "prediction" columns,
// so one evaluator scores each of them on the same test set.
val evaluator = new RegressionEvaluator()
  .setLabelCol("label")
  .setPredictionCol("prediction")
  .setMetricName("rmse")

// Linear Regression
val lr = new LinearRegression()
val lrModel = lr.fit(trainingData)
val lrPredictions = lrModel.transform(testData)
val lrRmse = evaluator.evaluate(lrPredictions)

// Random Forest
val rf = new RandomForestRegressor().setNumTrees(20)
val rfModel = rf.fit(trainingData)
val rfPredictions = rfModel.transform(testData)
val rfRmse = evaluator.evaluate(rfPredictions)

// Gradient Boosted Trees (maxIter = number of trees)
val gbt = new GBTRegressor().setMaxIter(20)
val gbtModel = gbt.fit(trainingData)
val gbtPredictions = gbtModel.transform(testData)
val gbtRmse = evaluator.evaluate(gbtPredictions)

println(s"Linear Regression RMSE: $lrRmse")
println(s"Random Forest RMSE: $rfRmse")
println(s"GBT RMSE: $gbtRmse")
```

### Feature Importance Analysis

```scala
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, GBTRegressionModel, RandomForestRegressionModel}

// Tree-based regression models expose featureImportances. Match on the concrete
// type instead of casting, so an unsupported model fails with a clear message.
def featureImportancesOf(m: Transformer): Vector = m match {
  case rf: RandomForestRegressionModel => rf.featureImportances
  case gbt: GBTRegressionModel         => gbt.featureImportances
  case dt: DecisionTreeRegressionModel => dt.featureImportances
  case other => throw new IllegalArgumentException(
    s"${other.getClass.getSimpleName} does not provide feature importances")
}

val importances = featureImportancesOf(model)

// Names in the same order as the columns assembled into "features"
val featureNames = Array("feature1", "feature2", "feature3") // Your feature names
// zip would silently drop unmatched entries, so insist the counts agree
require(featureNames.length == importances.size,
  s"Expected ${importances.size} feature names but got ${featureNames.length}")

// Pair each name with its importance, most important first
val importanceData = featureNames.zip(importances.toArray)
  .sortBy { case (_, importance) => -importance }

importanceData.foreach { case (name, importance) =>
  println(s"$name: $importance")
}
```