# Machine Learning

Apache Spark provides scalable machine learning capabilities through two APIs: the RDD-based MLlib API (in maintenance mode) and the DataFrame-based ML API (the primary API). The ML API provides high-level abstractions for building machine learning pipelines.

## Package Information

Machine learning functionality is available through:

```scala
// ML API (DataFrame-based, primary API)
import org.apache.spark.ml._
import org.apache.spark.ml.feature._
import org.apache.spark.ml.classification._
import org.apache.spark.ml.regression._
import org.apache.spark.ml.clustering._
import org.apache.spark.ml.recommendation._
import org.apache.spark.ml.evaluation._

// MLlib API (RDD-based, maintenance mode)
import org.apache.spark.mllib.classification._
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.clustering._
import org.apache.spark.mllib.linalg._
```

## Basic Usage

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{VectorAssembler, StringIndexer}
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator

val spark = SparkSession.builder().appName("ML Example").getOrCreate()

// Load data
val data = spark.read
  .option("header", "true")
  .option("inferSchema", "true")
  .csv("path/to/data.csv")

// Feature engineering
val stringIndexer = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("categoryIndex")

val assembler = new VectorAssembler()
  .setInputCols(Array("feature1", "feature2", "categoryIndex"))
  .setOutputCol("features")

// Model
val lr = new LogisticRegression()
  .setFeaturesCol("features")
  .setLabelCol("label")

// Pipeline
val pipeline = new Pipeline()
  .setStages(Array(stringIndexer, assembler, lr))

// Train
val Array(trainData, testData) = data.randomSplit(Array(0.8, 0.2), seed = 1234)
val model = pipeline.fit(trainData)

// Evaluate
val predictions = model.transform(testData)
val evaluator = new BinaryClassificationEvaluator()
val auc = evaluator.evaluate(predictions)

println(s"AUC: $auc")
```

## Capabilities

### ML Pipeline API

The primary machine learning API built on DataFrames, providing high-level abstractions for creating ML workflows.

#### Pipeline Components

```scala { .api }
abstract class PipelineStage extends Params {
  def copy(extra: ParamMap): PipelineStage
  def transformSchema(schema: StructType): StructType
  def params: Array[Param[_]]
}

abstract class Estimator[M <: Model[M]] extends PipelineStage {
  def fit(dataset: Dataset[_]): M
  def fit(dataset: Dataset[_], paramMaps: Array[ParamMap]): Seq[M]
  def fit(dataset: Dataset[_], paramMap: ParamMap): M
  def fit(dataset: Dataset[_], firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): M
}

abstract class Transformer extends PipelineStage {
  def transform(dataset: Dataset[_]): DataFrame
  def transform(dataset: Dataset[_], paramMap: ParamMap): DataFrame
  def transform(dataset: Dataset[_], firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): DataFrame
}

abstract class Model[M <: Model[M]] extends Transformer

class Pipeline(val uid: String) extends Estimator[PipelineModel] {
  def this() = this(Identifiable.randomUID("pipeline"))

  def setStages(value: Array[PipelineStage]): Pipeline
  def getStages: Array[PipelineStage]
  def fit(dataset: Dataset[_]): PipelineModel
}

class PipelineModel(override val uid: String, val stages: Array[Transformer]) extends Model[PipelineModel] {
  def transform(dataset: Dataset[_]): DataFrame
}
```

#### Feature Engineering

```scala { .api }
// Vector operations
class VectorAssembler(override val uid: String) extends Transformer {
  def this() = this(Identifiable.randomUID("vecAssembler"))

  def setInputCols(value: Array[String]): VectorAssembler
  def setOutputCol(value: String): VectorAssembler
  def getInputCols: Array[String]
  def getOutputCol: String
  def transform(dataset: Dataset[_]): DataFrame
}

// String indexing
class StringIndexer(override val uid: String) extends Estimator[StringIndexerModel] {
  def this() = this(Identifiable.randomUID("strIdx"))

  def setInputCol(value: String): StringIndexer
  def setOutputCol(value: String): StringIndexer
  def setHandleInvalid(value: String): StringIndexer
  def setStringOrderType(value: String): StringIndexer
  def fit(dataset: Dataset[_]): StringIndexerModel
}

class StringIndexerModel(override val uid: String, val labels: Array[String]) extends Model[StringIndexerModel] {
  def setInputCol(value: String): StringIndexerModel
  def setOutputCol(value: String): StringIndexerModel
  def setHandleInvalid(value: String): StringIndexerModel
  def transform(dataset: Dataset[_]): DataFrame
}

// One-hot encoding
class OneHotEncoder(override val uid: String) extends Transformer {
  def this() = this(Identifiable.randomUID("oneHot"))

  def setInputCols(value: Array[String]): OneHotEncoder
  def setOutputCols(value: Array[String]): OneHotEncoder
  def setHandleInvalid(value: String): OneHotEncoder
  def setDropLast(value: Boolean): OneHotEncoder
  def transform(dataset: Dataset[_]): DataFrame
}

// Scaling
class StandardScaler(override val uid: String) extends Estimator[StandardScalerModel] {
  def this() = this(Identifiable.randomUID("stdScal"))

  def setInputCol(value: String): StandardScaler
  def setOutputCol(value: String): StandardScaler
  def setWithMean(value: Boolean): StandardScaler
  def setWithStd(value: Boolean): StandardScaler
  def fit(dataset: Dataset[_]): StandardScalerModel
}

class StandardScalerModel(override val uid: String, val std: Vector, val mean: Vector) extends Model[StandardScalerModel] {
  def setInputCol(value: String): StandardScalerModel
  def setOutputCol(value: String): StandardScalerModel
  def transform(dataset: Dataset[_]): DataFrame
}

// Text processing
class Tokenizer(override val uid: String) extends Transformer {
  def this() = this(Identifiable.randomUID("tok"))

  def setInputCol(value: String): Tokenizer
  def setOutputCol(value: String): Tokenizer
  def transform(dataset: Dataset[_]): DataFrame
}

class HashingTF(override val uid: String) extends Transformer {
  def this() = this(Identifiable.randomUID("hashingTF"))

  def setInputCol(value: String): HashingTF
  def setOutputCol(value: String): HashingTF
  def setNumFeatures(value: Int): HashingTF
  def setBinary(value: Boolean): HashingTF
  def transform(dataset: Dataset[_]): DataFrame
}

class IDF(override val uid: String) extends Estimator[IDFModel] {
  def this() = this(Identifiable.randomUID("idf"))

  def setInputCol(value: String): IDF
  def setOutputCol(value: String): IDF
  def setMinDocFreq(value: Int): IDF
  def fit(dataset: Dataset[_]): IDFModel
}
```

#### Classification

```scala { .api }
// Logistic Regression
class LogisticRegression(override val uid: String) extends Classifier[Vector, LogisticRegression, LogisticRegressionModel] {
  def this() = this(Identifiable.randomUID("logreg"))

  def setRegParam(value: Double): LogisticRegression
  def setElasticNetParam(value: Double): LogisticRegression
  def setMaxIter(value: Int): LogisticRegression
  def setTol(value: Double): LogisticRegression
  def setFitIntercept(value: Boolean): LogisticRegression
  def setStandardization(value: Boolean): LogisticRegression
  def setThreshold(value: Double): LogisticRegression
  def setThresholds(value: Array[Double]): LogisticRegression
  def setWeightCol(value: String): LogisticRegression
  def fit(dataset: Dataset[_]): LogisticRegressionModel
}

class LogisticRegressionModel(override val uid: String, val coefficients: Vector, val intercept: Double)
    extends ClassificationModel[Vector, LogisticRegressionModel] {
  def setThreshold(value: Double): LogisticRegressionModel
  def setThresholds(value: Array[Double]): LogisticRegressionModel
  def summary: LogisticRegressionTrainingSummary
  def evaluate(dataset: Dataset[_]): LogisticRegressionSummary
}

// Random Forest
class RandomForestClassifier(override val uid: String) extends Classifier[Vector, RandomForestClassifier, RandomForestClassificationModel] {
  def this() = this(Identifiable.randomUID("rfc"))

  def setNumTrees(value: Int): RandomForestClassifier
  def setMaxDepth(value: Int): RandomForestClassifier
  def setMaxBins(value: Int): RandomForestClassifier
  def setMinInstancesPerNode(value: Int): RandomForestClassifier
  def setMinInfoGain(value: Double): RandomForestClassifier
  def setSubsamplingRate(value: Double): RandomForestClassifier
  def setFeatureSubsetStrategy(value: String): RandomForestClassifier
  def setSeed(value: Long): RandomForestClassifier
  def fit(dataset: Dataset[_]): RandomForestClassificationModel
}

class RandomForestClassificationModel(override val uid: String, val trees: Array[DecisionTreeClassificationModel], val numFeatures: Int)
    extends ClassificationModel[Vector, RandomForestClassificationModel] {
  def featureImportances: Vector
  def totalNumNodes: Int
  def toDebugString: String
}

// Gradient Boosted Trees
class GBTClassifier(override val uid: String) extends Classifier[Vector, GBTClassifier, GBTClassificationModel] {
  def this() = this(Identifiable.randomUID("gbtc"))

  def setMaxIter(value: Int): GBTClassifier
  def setMaxDepth(value: Int): GBTClassifier
  def setMaxBins(value: Int): GBTClassifier
  def setMinInstancesPerNode(value: Int): GBTClassifier
  def setMinInfoGain(value: Double): GBTClassifier
  def setSubsamplingRate(value: Double): GBTClassifier
  def setStepSize(value: Double): GBTClassifier
  def setFeatureSubsetStrategy(value: String): GBTClassifier
  def setSeed(value: Long): GBTClassifier
  def fit(dataset: Dataset[_]): GBTClassificationModel
}

// Naive Bayes
class NaiveBayes(override val uid: String) extends Classifier[Vector, NaiveBayes, NaiveBayesModel] {
  def this() = this(Identifiable.randomUID("nb"))

  def setSmoothing(value: Double): NaiveBayes
  def setModelType(value: String): NaiveBayes
  def setWeightCol(value: String): NaiveBayes
  def fit(dataset: Dataset[_]): NaiveBayesModel
}
```

#### Regression

```scala { .api }
// Linear Regression
class LinearRegression(override val uid: String) extends Regressor[Vector, LinearRegression, LinearRegressionModel] {
  def this() = this(Identifiable.randomUID("linReg"))

  def setRegParam(value: Double): LinearRegression
  def setElasticNetParam(value: Double): LinearRegression
  def setMaxIter(value: Int): LinearRegression
  def setTol(value: Double): LinearRegression
  def setFitIntercept(value: Boolean): LinearRegression
  def setStandardization(value: Boolean): LinearRegression
  def setSolver(value: String): LinearRegression
  def setWeightCol(value: String): LinearRegression
  def fit(dataset: Dataset[_]): LinearRegressionModel
}

class LinearRegressionModel(override val uid: String, val coefficients: Vector, val intercept: Double)
    extends RegressionModel[Vector, LinearRegressionModel] {
  def summary: LinearRegressionTrainingSummary
  def evaluate(dataset: Dataset[_]): LinearRegressionSummary
}

// Random Forest Regression
class RandomForestRegressor(override val uid: String) extends Regressor[Vector, RandomForestRegressor, RandomForestRegressionModel] {
  def this() = this(Identifiable.randomUID("rfr"))

  def setNumTrees(value: Int): RandomForestRegressor
  def setMaxDepth(value: Int): RandomForestRegressor
  def setMaxBins(value: Int): RandomForestRegressor
  def setMinInstancesPerNode(value: Int): RandomForestRegressor
  def setMinInfoGain(value: Double): RandomForestRegressor
  def setSubsamplingRate(value: Double): RandomForestRegressor
  def setFeatureSubsetStrategy(value: String): RandomForestRegressor
  def setSeed(value: Long): RandomForestRegressor
  def fit(dataset: Dataset[_]): RandomForestRegressionModel
}

// Gradient Boosted Trees Regression
class GBTRegressor(override val uid: String) extends Regressor[Vector, GBTRegressor, GBTRegressionModel] {
  def this() = this(Identifiable.randomUID("gbtr"))

  def setMaxIter(value: Int): GBTRegressor
  def setMaxDepth(value: Int): GBTRegressor
  def setMaxBins(value: Int): GBTRegressor
  def setMinInstancesPerNode(value: Int): GBTRegressor
  def setMinInfoGain(value: Double): GBTRegressor
  def setSubsamplingRate(value: Double): GBTRegressor
  def setStepSize(value: Double): GBTRegressor
  def setFeatureSubsetStrategy(value: String): GBTRegressor
  def setSeed(value: Long): GBTRegressor
  def fit(dataset: Dataset[_]): GBTRegressionModel
}
```

#### Clustering

```scala { .api }
// K-Means
class KMeans(override val uid: String) extends Estimator[KMeansModel] {
  def this() = this(Identifiable.randomUID("kmeans"))

  def setK(value: Int): KMeans
  def setMaxIter(value: Int): KMeans
  def setTol(value: Double): KMeans
  def setInitMode(value: String): KMeans
  def setInitSteps(value: Int): KMeans
  def setSeed(value: Long): KMeans
  def setFeaturesCol(value: String): KMeans
  def setPredictionCol(value: String): KMeans
  def fit(dataset: Dataset[_]): KMeansModel
}

class KMeansModel(override val uid: String, val clusterCenters: Array[Vector]) extends Model[KMeansModel] {
  def setPredictionCol(value: String): KMeansModel
  def setFeaturesCol(value: String): KMeansModel
  def transform(dataset: Dataset[_]): DataFrame
  def computeCost(dataset: Dataset[_]): Double
  def summary: KMeansSummary
}

// Gaussian Mixture Model
class GaussianMixture(override val uid: String) extends Estimator[GaussianMixtureModel] {
  def this() = this(Identifiable.randomUID("GaussianMixture"))

  def setK(value: Int): GaussianMixture
  def setMaxIter(value: Int): GaussianMixture
  def setTol(value: Double): GaussianMixture
  def setSeed(value: Long): GaussianMixture
  def setFeaturesCol(value: String): GaussianMixture
  def setPredictionCol(value: String): GaussianMixture
  def setProbabilityCol(value: String): GaussianMixture
  def fit(dataset: Dataset[_]): GaussianMixtureModel
}

class GaussianMixtureModel(override val uid: String, val weights: Array[Double], val gaussians: Array[MultivariateGaussian])
    extends Model[GaussianMixtureModel] {
  def setFeaturesCol(value: String): GaussianMixtureModel
  def setPredictionCol(value: String): GaussianMixtureModel
  def setProbabilityCol(value: String): GaussianMixtureModel
  def transform(dataset: Dataset[_]): DataFrame
  def summary: GaussianMixtureSummary
}
```

#### Recommendation

```scala { .api }
// Alternating Least Squares (ALS)
class ALS(override val uid: String) extends Estimator[ALSModel] {
  def this() = this(Identifiable.randomUID("als"))

  def setRank(value: Int): ALS
  def setNumUserBlocks(value: Int): ALS
  def setNumItemBlocks(value: Int): ALS
  def setMaxIter(value: Int): ALS
  def setRegParam(value: Double): ALS
  def setAlpha(value: Double): ALS
  def setColdStartStrategy(value: String): ALS
  def setUserCol(value: String): ALS
  def setItemCol(value: String): ALS
  def setRatingCol(value: String): ALS
  def setPredictionCol(value: String): ALS
  def setImplicitPrefs(value: Boolean): ALS
  def setNonnegative(value: Boolean): ALS
  def setSeed(value: Long): ALS
  def fit(dataset: Dataset[_]): ALSModel
}

class ALSModel(override val uid: String, val rank: Int, val userFactors: DataFrame, val itemFactors: DataFrame)
    extends Model[ALSModel] {
  def setColdStartStrategy(value: String): ALSModel
  def setUserCol(value: String): ALSModel
  def setItemCol(value: String): ALSModel
  def setPredictionCol(value: String): ALSModel
  def transform(dataset: Dataset[_]): DataFrame
  def recommendForAllUsers(numItems: Int): DataFrame
  def recommendForAllItems(numUsers: Int): DataFrame
  def recommendForUserSubset(dataset: Dataset[_], numItems: Int): DataFrame
  def recommendForItemSubset(dataset: Dataset[_], numUsers: Int): DataFrame
}
```

#### Model Evaluation

```scala { .api }
// Binary Classification Evaluator
class BinaryClassificationEvaluator(override val uid: String) extends Evaluator {
  def this() = this(Identifiable.randomUID("binEval"))

  def setRawPredictionCol(value: String): BinaryClassificationEvaluator
  def setLabelCol(value: String): BinaryClassificationEvaluator
  def setMetricName(value: String): BinaryClassificationEvaluator
  def setWeightCol(value: String): BinaryClassificationEvaluator
  def evaluate(dataset: Dataset[_]): Double
}

// Multiclass Classification Evaluator
class MulticlassClassificationEvaluator(override val uid: String) extends Evaluator {
  def this() = this(Identifiable.randomUID("mcEval"))

  def setPredictionCol(value: String): MulticlassClassificationEvaluator
  def setLabelCol(value: String): MulticlassClassificationEvaluator
  def setMetricName(value: String): MulticlassClassificationEvaluator
  def setWeightCol(value: String): MulticlassClassificationEvaluator
  def evaluate(dataset: Dataset[_]): Double
}

// Regression Evaluator
class RegressionEvaluator(override val uid: String) extends Evaluator {
  def this() = this(Identifiable.randomUID("regEval"))

  def setPredictionCol(value: String): RegressionEvaluator
  def setLabelCol(value: String): RegressionEvaluator
  def setMetricName(value: String): RegressionEvaluator
  def setWeightCol(value: String): RegressionEvaluator
  def evaluate(dataset: Dataset[_]): Double
}

// Clustering Evaluator
class ClusteringEvaluator(override val uid: String) extends Evaluator {
  def this() = this(Identifiable.randomUID("cluEval"))

  def setPredictionCol(value: String): ClusteringEvaluator
  def setFeaturesCol(value: String): ClusteringEvaluator
  def setMetricName(value: String): ClusteringEvaluator
  def setWeightCol(value: String): ClusteringEvaluator
  def evaluate(dataset: Dataset[_]): Double
}
```

#### Model Selection and Tuning

```scala { .api }
// Cross Validation
class CrossValidator(override val uid: String) extends Estimator[CrossValidatorModel] {
  def this() = this(Identifiable.randomUID("cv"))

  def setEstimator(value: Estimator[_]): CrossValidator
  def setEstimatorParamMaps(value: Array[ParamMap]): CrossValidator
  def setEvaluator(value: Evaluator): CrossValidator
  def setNumFolds(value: Int): CrossValidator
  def setSeed(value: Long): CrossValidator
  def setParallelism(value: Int): CrossValidator
  def setCollectSubModels(value: Boolean): CrossValidator
  def fit(dataset: Dataset[_]): CrossValidatorModel
}

class CrossValidatorModel(override val uid: String, val bestModel: Model[_], val avgMetrics: Array[Double])
    extends Model[CrossValidatorModel] {
  def setNumFolds(value: Int): CrossValidatorModel
  def transform(dataset: Dataset[_]): DataFrame
}

// Train Validation Split
class TrainValidationSplit(override val uid: String) extends Estimator[TrainValidationSplitModel] {
  def this() = this(Identifiable.randomUID("tvs"))

  def setEstimator(value: Estimator[_]): TrainValidationSplit
  def setEstimatorParamMaps(value: Array[ParamMap]): TrainValidationSplit
  def setEvaluator(value: Evaluator): TrainValidationSplit
  def setTrainRatio(value: Double): TrainValidationSplit
  def setSeed(value: Long): TrainValidationSplit
  def setParallelism(value: Int): TrainValidationSplit
  def setCollectSubModels(value: Boolean): TrainValidationSplit
  def fit(dataset: Dataset[_]): TrainValidationSplitModel
}

// Parameter Grid Builder
class ParamGridBuilder {
  def addGrid[T](param: Param[T], values: Array[T]): ParamGridBuilder
  def baseOn(paramMap: ParamMap): ParamGridBuilder
  def baseOn(paramPair: ParamPair[_], paramPairs: ParamPair[_]*): ParamGridBuilder
  def build(): Array[ParamMap]
}
```

### MLlib RDD-based API (Legacy)

The original RDD-based machine learning library, now in maintenance mode.

#### Linear Algebra

```scala { .api }
// Vectors
trait Vector extends Serializable {
  def size: Int
  def toArray: Array[Double]
  def apply(i: Int): Double
  def copy: Vector
  def foreachActive(f: (Int, Double) => Unit): Unit
  def numActives: Int
  def numNonzeros: Int
}

object Vectors {
  def dense(values: Array[Double]): Vector
  def dense(firstValue: Double, otherValues: Double*): Vector
  def sparse(size: Int, indices: Array[Int], values: Array[Double]): Vector
  def sparse(size: Int, elements: Seq[(Int, Double)]): Vector
  def zeros(size: Int): Vector
  def norm(vector: Vector, p: Double): Double
  def sqdist(v1: Vector, v2: Vector): Double
}

class DenseVector(val values: Array[Double]) extends Vector
class SparseVector(override val size: Int, val indices: Array[Int], val values: Array[Double]) extends Vector

// Matrices
trait Matrix extends Serializable {
  def numRows: Int
  def numCols: Int
  def toArray: Array[Double]
  def apply(i: Int, j: Int): Double
  def copy: Matrix
  def transpose: Matrix
  def foreachActive(f: (Int, Int, Double) => Unit): Unit
}

object Matrices {
  def dense(numRows: Int, numCols: Int, values: Array[Double]): Matrix
  def sparse(numRows: Int, numCols: Int, colPtrs: Array[Int], rowIndices: Array[Int], values: Array[Double]): Matrix
  def zeros(numRows: Int, numCols: Int): Matrix
  def eye(n: Int): Matrix
}
```

Usage example for MLlib:

```scala
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.{LabeledPoint, LinearRegressionWithSGD}
import org.apache.spark.mllib.evaluation.RegressionMetrics

// Create labeled points
val data = sc.textFile("data.txt").map { line =>
  val parts = line.split(',')
  LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble)))
}

// Split data
val splits = data.randomSplit(Array(0.8, 0.2), seed = 11L)
val training = splits(0).cache()
val test = splits(1)

// Train model
val numIterations = 100
val stepSize = 0.00000001
val model = LinearRegressionWithSGD.train(training, numIterations, stepSize)

// Evaluate
val valuesAndPreds = test.map { point =>
  val prediction = model.predict(point.features)
  (point.label, prediction)
}

val metrics = new RegressionMetrics(valuesAndPreds)
println(s"RMSE = ${metrics.rootMeanSquaredError}")
```

Note: `LinearRegressionWithSGD` was deprecated in Spark 2.0 and removed in Spark 3.0; this example applies to Spark 2.x only. For current code, prefer the DataFrame-based `LinearRegression` shown earlier.