or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

core-engine.md · graph-processing.md · index.md · machine-learning.md · sql-dataframes.md · stream-processing.md

docs/machine-learning.md

0

# Machine Learning

1

2

Apache Spark provides scalable machine learning capabilities through two APIs: the RDD-based MLlib API (in maintenance mode) and the DataFrame-based ML API (primary API). The ML API provides high-level abstractions for building machine learning pipelines.

3

4

## Package Information

5

6

Machine learning functionality is available through:

7

8

```scala

9

// ML API (DataFrame-based, primary API)

10

import org.apache.spark.ml._

11

import org.apache.spark.ml.feature._

12

import org.apache.spark.ml.classification._

13

import org.apache.spark.ml.regression._

14

import org.apache.spark.ml.clustering._

15

import org.apache.spark.ml.recommendation._

16

import org.apache.spark.ml.evaluation._

17

18

// MLlib API (RDD-based, maintenance mode)

19

import org.apache.spark.mllib.classification._

20

import org.apache.spark.mllib.regression._

21

import org.apache.spark.mllib.clustering._

22

import org.apache.spark.mllib.linalg._

23

```

24

25

## Basic Usage

26

27

```scala

28

import org.apache.spark.sql.SparkSession

29

import org.apache.spark.ml.Pipeline

30

import org.apache.spark.ml.feature.{VectorAssembler, StringIndexer}

31

import org.apache.spark.ml.classification.LogisticRegression

32

import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator

33

34

val spark = SparkSession.builder().appName("ML Example").getOrCreate()

35

36

// Load data

37

val data = spark.read

38

.option("header", "true")

39

.option("inferSchema", "true")

40

.csv("path/to/data.csv")

41

42

// Feature engineering

43

val stringIndexer = new StringIndexer()

44

.setInputCol("category")

45

.setOutputCol("categoryIndex")

46

47

val assembler = new VectorAssembler()

48

.setInputCols(Array("feature1", "feature2", "categoryIndex"))

49

.setOutputCol("features")

50

51

// Model

52

val lr = new LogisticRegression()

53

.setFeaturesCol("features")

54

.setLabelCol("label")

55

56

// Pipeline

57

val pipeline = new Pipeline()

58

.setStages(Array(stringIndexer, assembler, lr))

59

60

// Train

61

val Array(trainData, testData) = data.randomSplit(Array(0.8, 0.2), seed = 1234)

62

val model = pipeline.fit(trainData)

63

64

// Evaluate

65

val predictions = model.transform(testData)

66

val evaluator = new BinaryClassificationEvaluator()

67

val auc = evaluator.evaluate(predictions)

68

69

println(s"AUC: $auc")

70

```

71

72

## Capabilities

73

74

### ML Pipeline API

75

76

The primary machine learning API built on DataFrames, providing high-level abstractions for creating ML workflows.

77

78

#### Pipeline Components

79

80

```scala { .api }

81

abstract class PipelineStage extends Params {

82

def copy(extra: ParamMap): PipelineStage

83

def transformSchema(schema: StructType): StructType

84

def params: Array[Param[_]]

85

}

86

87

abstract class Estimator[M <: Model[M]] extends PipelineStage {

88

def fit(dataset: Dataset[_]): M

89

def fit(dataset: Dataset[_], paramMaps: Array[ParamMap]): Seq[M]

90

def fit(dataset: Dataset[_], paramMap: ParamMap): M

91

def fit(dataset: Dataset[_], firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): M

92

}

93

94

abstract class Transformer extends PipelineStage {

95

def transform(dataset: Dataset[_]): DataFrame

96

def transform(dataset: Dataset[_], paramMap: ParamMap): DataFrame

97

def transform(dataset: Dataset[_], firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): DataFrame

98

}

99

100

abstract class Model[M <: Model[M]] extends Transformer

101

102

class Pipeline(val uid: String) extends Estimator[PipelineModel] {

103

def this() = this(Identifiable.randomUID("pipeline"))

104

105

def setStages(value: Array[PipelineStage]): Pipeline

106

def getStages: Array[PipelineStage]

107

def fit(dataset: Dataset[_]): PipelineModel

108

}

109

110

class PipelineModel(override val uid: String, val stages: Array[Transformer]) extends Model[PipelineModel] {

111

def transform(dataset: Dataset[_]): DataFrame

112

}

113

```

114

115

#### Feature Engineering

116

117

```scala { .api }

118

// Vector operations

119

class VectorAssembler(override val uid: String) extends Transformer {

120

def this() = this(Identifiable.randomUID("vecAssembler"))

121

122

def setInputCols(value: Array[String]): VectorAssembler

123

def setOutputCol(value: String): VectorAssembler

124

def getInputCols: Array[String]

125

def getOutputCol: String

126

def transform(dataset: Dataset[_]): DataFrame

127

}

128

129

// String indexing

130

class StringIndexer(override val uid: String) extends Estimator[StringIndexerModel] {

131

def this() = this(Identifiable.randomUID("strIdx"))

132

133

def setInputCol(value: String): StringIndexer

134

def setOutputCol(value: String): StringIndexer

135

def setHandleInvalid(value: String): StringIndexer

136

def setStringOrderType(value: String): StringIndexer

137

def fit(dataset: Dataset[_]): StringIndexerModel

138

}

139

140

class StringIndexerModel(override val uid: String, val labels: Array[String]) extends Model[StringIndexerModel] {

141

def setInputCol(value: String): StringIndexerModel

142

def setOutputCol(value: String): StringIndexerModel

143

def setHandleInvalid(value: String): StringIndexerModel

144

def transform(dataset: Dataset[_]): DataFrame

145

}

146

147

// One-hot encoding

148

class OneHotEncoder(override val uid: String) extends Transformer {

149

def this() = this(Identifiable.randomUID("oneHot"))

150

151

def setInputCols(value: Array[String]): OneHotEncoder

152

def setOutputCols(value: Array[String]): OneHotEncoder

153

def setHandleInvalid(value: String): OneHotEncoder

154

def setDropLast(value: Boolean): OneHotEncoder

155

def transform(dataset: Dataset[_]): DataFrame

156

}

157

158

// Scaling

159

class StandardScaler(override val uid: String) extends Estimator[StandardScalerModel] {

160

def this() = this(Identifiable.randomUID("stdScal"))

161

162

def setInputCol(value: String): StandardScaler

163

def setOutputCol(value: String): StandardScaler

164

def setWithMean(value: Boolean): StandardScaler

165

def setWithStd(value: Boolean): StandardScaler

166

def fit(dataset: Dataset[_]): StandardScalerModel

167

}

168

169

class StandardScalerModel(override val uid: String, val std: Vector, val mean: Vector) extends Model[StandardScalerModel] {

170

def setInputCol(value: String): StandardScalerModel

171

def setOutputCol(value: String): StandardScalerModel

172

def transform(dataset: Dataset[_]): DataFrame

173

}

174

175

// Text processing

176

class Tokenizer(override val uid: String) extends Transformer {

177

def this() = this(Identifiable.randomUID("tok"))

178

179

def setInputCol(value: String): Tokenizer

180

def setOutputCol(value: String): Tokenizer

181

def transform(dataset: Dataset[_]): DataFrame

182

}

183

184

class HashingTF(override val uid: String) extends Transformer {

185

def this() = this(Identifiable.randomUID("hashingTF"))

186

187

def setInputCol(value: String): HashingTF

188

def setOutputCol(value: String): HashingTF

189

def setNumFeatures(value: Int): HashingTF

190

def setBinary(value: Boolean): HashingTF

191

def transform(dataset: Dataset[_]): DataFrame

192

}

193

194

class IDF(override val uid: String) extends Estimator[IDFModel] {

195

def this() = this(Identifiable.randomUID("idf"))

196

197

def setInputCol(value: String): IDF

198

def setOutputCol(value: String): IDF

199

def setMinDocFreq(value: Int): IDF

200

def fit(dataset: Dataset[_]): IDFModel

201

}

202

```

203

204

#### Classification

205

206

```scala { .api }

207

// Logistic Regression

208

class LogisticRegression(override val uid: String) extends Classifier[Vector, LogisticRegression, LogisticRegressionModel] {

209

def this() = this(Identifiable.randomUID("logreg"))

210

211

def setRegParam(value: Double): LogisticRegression

212

def setElasticNetParam(value: Double): LogisticRegression

213

def setMaxIter(value: Int): LogisticRegression

214

def setTol(value: Double): LogisticRegression

215

def setFitIntercept(value: Boolean): LogisticRegression

216

def setStandardization(value: Boolean): LogisticRegression

217

def setThreshold(value: Double): LogisticRegression

218

def setThresholds(value: Array[Double]): LogisticRegression

219

def setWeightCol(value: String): LogisticRegression

220

def fit(dataset: Dataset[_]): LogisticRegressionModel

221

}

222

223

class LogisticRegressionModel(override val uid: String, val coefficients: Vector, val intercept: Double)

224

extends ClassificationModel[Vector, LogisticRegressionModel] {

225

def setThreshold(value: Double): LogisticRegressionModel

226

def setThresholds(value: Array[Double]): LogisticRegressionModel

227

def summary: LogisticRegressionTrainingSummary

228

def evaluate(dataset: Dataset[_]): LogisticRegressionSummary

229

}

230

231

// Random Forest

232

class RandomForestClassifier(override val uid: String) extends Classifier[Vector, RandomForestClassifier, RandomForestClassificationModel] {

233

def this() = this(Identifiable.randomUID("rfc"))

234

235

def setNumTrees(value: Int): RandomForestClassifier

236

def setMaxDepth(value: Int): RandomForestClassifier

237

def setMaxBins(value: Int): RandomForestClassifier

238

def setMinInstancesPerNode(value: Int): RandomForestClassifier

239

def setMinInfoGain(value: Double): RandomForestClassifier

240

def setSubsamplingRate(value: Double): RandomForestClassifier

241

def setFeatureSubsetStrategy(value: String): RandomForestClassifier

242

def setSeed(value: Long): RandomForestClassifier

243

def fit(dataset: Dataset[_]): RandomForestClassificationModel

244

}

245

246

class RandomForestClassificationModel(override val uid: String, val trees: Array[DecisionTreeClassificationModel], val numFeatures: Int)

247

extends ClassificationModel[Vector, RandomForestClassificationModel] {

248

def featureImportances: Vector

249

def totalNumNodes: Int

250

def toDebugString: String

251

}

252

253

// Gradient Boosted Trees

254

class GBTClassifier(override val uid: String) extends Classifier[Vector, GBTClassifier, GBTClassificationModel] {

255

def this() = this(Identifiable.randomUID("gbtc"))

256

257

def setMaxIter(value: Int): GBTClassifier

258

def setMaxDepth(value: Int): GBTClassifier

259

def setMaxBins(value: Int): GBTClassifier

260

def setMinInstancesPerNode(value: Int): GBTClassifier

261

def setMinInfoGain(value: Double): GBTClassifier

262

def setSubsamplingRate(value: Double): GBTClassifier

263

def setStepSize(value: Double): GBTClassifier

264

def setFeatureSubsetStrategy(value: String): GBTClassifier

265

def setSeed(value: Long): GBTClassifier

266

def fit(dataset: Dataset[_]): GBTClassificationModel

267

}

268

269

// Naive Bayes

270

class NaiveBayes(override val uid: String) extends Classifier[Vector, NaiveBayes, NaiveBayesModel] {

271

def this() = this(Identifiable.randomUID("nb"))

272

273

def setSmoothing(value: Double): NaiveBayes

274

def setModelType(value: String): NaiveBayes

275

def setWeightCol(value: String): NaiveBayes

276

def fit(dataset: Dataset[_]): NaiveBayesModel

277

}

278

```

279

280

#### Regression

281

282

```scala { .api }

283

// Linear Regression

284

class LinearRegression(override val uid: String) extends Regressor[Vector, LinearRegression, LinearRegressionModel] {

285

def this() = this(Identifiable.randomUID("linReg"))

286

287

def setRegParam(value: Double): LinearRegression

288

def setElasticNetParam(value: Double): LinearRegression

289

def setMaxIter(value: Int): LinearRegression

290

def setTol(value: Double): LinearRegression

291

def setFitIntercept(value: Boolean): LinearRegression

292

def setStandardization(value: Boolean): LinearRegression

293

def setSolver(value: String): LinearRegression

294

def setWeightCol(value: String): LinearRegression

295

def fit(dataset: Dataset[_]): LinearRegressionModel

296

}

297

298

class LinearRegressionModel(override val uid: String, val coefficients: Vector, val intercept: Double)

299

extends RegressionModel[Vector, LinearRegressionModel] {

300

def summary: LinearRegressionTrainingSummary

301

def evaluate(dataset: Dataset[_]): LinearRegressionSummary

302

}

303

304

// Random Forest Regression

305

class RandomForestRegressor(override val uid: String) extends Regressor[Vector, RandomForestRegressor, RandomForestRegressionModel] {

306

def this() = this(Identifiable.randomUID("rfr"))

307

308

def setNumTrees(value: Int): RandomForestRegressor

309

def setMaxDepth(value: Int): RandomForestRegressor

310

def setMaxBins(value: Int): RandomForestRegressor

311

def setMinInstancesPerNode(value: Int): RandomForestRegressor

312

def setMinInfoGain(value: Double): RandomForestRegressor

313

def setSubsamplingRate(value: Double): RandomForestRegressor

314

def setFeatureSubsetStrategy(value: String): RandomForestRegressor

315

def setSeed(value: Long): RandomForestRegressor

316

def fit(dataset: Dataset[_]): RandomForestRegressionModel

317

}

318

319

// Gradient Boosted Trees Regression

320

class GBTRegressor(override val uid: String) extends Regressor[Vector, GBTRegressor, GBTRegressionModel] {

321

def this() = this(Identifiable.randomUID("gbtr"))

322

323

def setMaxIter(value: Int): GBTRegressor

324

def setMaxDepth(value: Int): GBTRegressor

325

def setMaxBins(value: Int): GBTRegressor

326

def setMinInstancesPerNode(value: Int): GBTRegressor

327

def setMinInfoGain(value: Double): GBTRegressor

328

def setSubsamplingRate(value: Double): GBTRegressor

329

def setStepSize(value: Double): GBTRegressor

330

def setFeatureSubsetStrategy(value: String): GBTRegressor

331

def setSeed(value: Long): GBTRegressor

332

def fit(dataset: Dataset[_]): GBTRegressionModel

333

}

334

```

335

336

#### Clustering

337

338

```scala { .api }

339

// K-Means

340

class KMeans(override val uid: String) extends Estimator[KMeansModel] {

341

def this() = this(Identifiable.randomUID("kmeans"))

342

343

def setK(value: Int): KMeans

344

def setMaxIter(value: Int): KMeans

345

def setTol(value: Double): KMeans

346

def setInitMode(value: String): KMeans

347

def setInitSteps(value: Int): KMeans

348

def setSeed(value: Long): KMeans

349

def setFeaturesCol(value: String): KMeans

350

def setPredictionCol(value: String): KMeans

351

def fit(dataset: Dataset[_]): KMeansModel

352

}

353

354

class KMeansModel(override val uid: String, val clusterCenters: Array[Vector]) extends Model[KMeansModel] {

355

def setPredictionCol(value: String): KMeansModel

356

def setFeaturesCol(value: String): KMeansModel

357

def transform(dataset: Dataset[_]): DataFrame

358

def computeCost(dataset: Dataset[_]): Double

359

def summary: KMeansSummary

360

}

361

362

// Gaussian Mixture Model

363

class GaussianMixture(override val uid: String) extends Estimator[GaussianMixtureModel] {

364

def this() = this(Identifiable.randomUID("GaussianMixture"))

365

366

def setK(value: Int): GaussianMixture

367

def setMaxIter(value: Int): GaussianMixture

368

def setTol(value: Double): GaussianMixture

369

def setSeed(value: Long): GaussianMixture

370

def setFeaturesCol(value: String): GaussianMixture

371

def setPredictionCol(value: String): GaussianMixture

372

def setProbabilityCol(value: String): GaussianMixture

373

def fit(dataset: Dataset[_]): GaussianMixtureModel

374

}

375

376

class GaussianMixtureModel(override val uid: String, val weights: Array[Double], val gaussians: Array[MultivariateGaussian])

377

extends Model[GaussianMixtureModel] {

378

def setFeaturesCol(value: String): GaussianMixtureModel

379

def setPredictionCol(value: String): GaussianMixtureModel

380

def setProbabilityCol(value: String): GaussianMixtureModel

381

def transform(dataset: Dataset[_]): DataFrame

382

def summary: GaussianMixtureSummary

383

}

384

```

385

386

#### Recommendation

387

388

```scala { .api }

389

// Alternating Least Squares (ALS)

390

class ALS(override val uid: String) extends Estimator[ALSModel] {

391

def this() = this(Identifiable.randomUID("als"))

392

393

def setRank(value: Int): ALS

394

def setNumUserBlocks(value: Int): ALS

395

def setNumItemBlocks(value: Int): ALS

396

def setMaxIter(value: Int): ALS

397

def setRegParam(value: Double): ALS

398

def setAlpha(value: Double): ALS

399

def setColdStartStrategy(value: String): ALS

400

def setUserCol(value: String): ALS

401

def setItemCol(value: String): ALS

402

def setRatingCol(value: String): ALS

403

def setPredictionCol(value: String): ALS

404

def setImplicitPrefs(value: Boolean): ALS

405

def setNonnegative(value: Boolean): ALS

406

def setSeed(value: Long): ALS

407

def fit(dataset: Dataset[_]): ALSModel

408

}

409

410

class ALSModel(override val uid: String, val rank: Int, val userFactors: DataFrame, val itemFactors: DataFrame)

411

extends Model[ALSModel] {

412

def setColdStartStrategy(value: String): ALSModel

413

def setUserCol(value: String): ALSModel

414

def setItemCol(value: String): ALSModel

415

def setPredictionCol(value: String): ALSModel

416

def transform(dataset: Dataset[_]): DataFrame

417

def recommendForAllUsers(numItems: Int): DataFrame

418

def recommendForAllItems(numUsers: Int): DataFrame

419

def recommendForUserSubset(dataset: Dataset[_], numItems: Int): DataFrame

420

def recommendForItemSubset(dataset: Dataset[_], numUsers: Int): DataFrame

421

}

422

```

423

424

#### Model Evaluation

425

426

```scala { .api }

427

// Binary Classification Evaluator

428

class BinaryClassificationEvaluator(override val uid: String) extends Evaluator {

429

def this() = this(Identifiable.randomUID("binEval"))

430

431

def setRawPredictionCol(value: String): BinaryClassificationEvaluator

432

def setLabelCol(value: String): BinaryClassificationEvaluator

433

def setMetricName(value: String): BinaryClassificationEvaluator

434

def setWeightCol(value: String): BinaryClassificationEvaluator

435

def evaluate(dataset: Dataset[_]): Double

436

}

437

438

// Multiclass Classification Evaluator

439

class MulticlassClassificationEvaluator(override val uid: String) extends Evaluator {

440

def this() = this(Identifiable.randomUID("mcEval"))

441

442

def setPredictionCol(value: String): MulticlassClassificationEvaluator

443

def setLabelCol(value: String): MulticlassClassificationEvaluator

444

def setMetricName(value: String): MulticlassClassificationEvaluator

445

def setWeightCol(value: String): MulticlassClassificationEvaluator

446

def evaluate(dataset: Dataset[_]): Double

447

}

448

449

// Regression Evaluator

450

class RegressionEvaluator(override val uid: String) extends Evaluator {

451

def this() = this(Identifiable.randomUID("regEval"))

452

453

def setPredictionCol(value: String): RegressionEvaluator

454

def setLabelCol(value: String): RegressionEvaluator

455

def setMetricName(value: String): RegressionEvaluator

456

def setWeightCol(value: String): RegressionEvaluator

457

def evaluate(dataset: Dataset[_]): Double

458

}

459

460

// Clustering Evaluator

461

class ClusteringEvaluator(override val uid: String) extends Evaluator {

462

def this() = this(Identifiable.randomUID("cluEval"))

463

464

def setPredictionCol(value: String): ClusteringEvaluator

465

def setFeaturesCol(value: String): ClusteringEvaluator

466

def setMetricName(value: String): ClusteringEvaluator

467

def setWeightCol(value: String): ClusteringEvaluator

468

def evaluate(dataset: Dataset[_]): Double

469

}

470

```

471

472

#### Model Selection and Tuning

473

474

```scala { .api }

475

// Cross Validation

476

class CrossValidator(override val uid: String) extends Estimator[CrossValidatorModel] {

477

def this() = this(Identifiable.randomUID("cv"))

478

479

def setEstimator(value: Estimator[_]): CrossValidator

480

def setEstimatorParamMaps(value: Array[ParamMap]): CrossValidator

481

def setEvaluator(value: Evaluator): CrossValidator

482

def setNumFolds(value: Int): CrossValidator

483

def setSeed(value: Long): CrossValidator

484

def setParallelism(value: Int): CrossValidator

485

def setCollectSubModels(value: Boolean): CrossValidator

486

def fit(dataset: Dataset[_]): CrossValidatorModel

487

}

488

489

class CrossValidatorModel(override val uid: String, val bestModel: Model[_], val avgMetrics: Array[Double])

490

extends Model[CrossValidatorModel] {

491

def setNumFolds(value: Int): CrossValidatorModel

492

def transform(dataset: Dataset[_]): DataFrame

493

}

494

495

// Train Validation Split

496

class TrainValidationSplit(override val uid: String) extends Estimator[TrainValidationSplitModel] {

497

def this() = this(Identifiable.randomUID("tvs"))

498

499

def setEstimator(value: Estimator[_]): TrainValidationSplit

500

def setEstimatorParamMaps(value: Array[ParamMap]): TrainValidationSplit

501

def setEvaluator(value: Evaluator): TrainValidationSplit

502

def setTrainRatio(value: Double): TrainValidationSplit

503

def setSeed(value: Long): TrainValidationSplit

504

def setParallelism(value: Int): TrainValidationSplit

505

def setCollectSubModels(value: Boolean): TrainValidationSplit

506

def fit(dataset: Dataset[_]): TrainValidationSplitModel

507

}

508

509

// Parameter Grid Builder

510

class ParamGridBuilder {

511

def addGrid[T](param: Param[T], values: Array[T]): ParamGridBuilder

512

def baseOn(paramMap: ParamMap): ParamGridBuilder

513

def baseOn(paramPair: ParamPair[_], paramPairs: ParamPair[_]*): ParamGridBuilder

514

def build(): Array[ParamMap]

515

}

516

```

517

518

### MLlib RDD-based API (Legacy)

519

520

The original RDD-based machine learning library, now in maintenance mode.

521

522

#### Linear Algebra

523

524

```scala { .api }

525

// Vectors

526

trait Vector extends Serializable {

527

def size: Int

528

def toArray: Array[Double]

529

def apply(i: Int): Double

530

def copy: Vector

531

def foreachActive(f: (Int, Double) => Unit): Unit

532

def numActives: Int

533

def numNonzeros: Int

534

}

535

536

object Vectors {

537

def dense(values: Array[Double]): Vector

538

def dense(firstValue: Double, otherValues: Double*): Vector

539

def sparse(size: Int, indices: Array[Int], values: Array[Double]): Vector

540

def sparse(size: Int, elements: Seq[(Int, Double)]): Vector

541

def zeros(size: Int): Vector

542

def norm(vector: Vector, p: Double): Double

543

def sqdist(v1: Vector, v2: Vector): Double

544

}

545

546

class DenseVector(val values: Array[Double]) extends Vector

547

class SparseVector(override val size: Int, val indices: Array[Int], val values: Array[Double]) extends Vector

548

549

// Matrices

550

trait Matrix extends Serializable {

551

def numRows: Int

552

def numCols: Int

553

def toArray: Array[Double]

554

def apply(i: Int, j: Int): Double

555

def copy: Matrix

556

def transpose: Matrix

557

def foreachActive(f: (Int, Int, Double) => Unit): Unit

558

}

559

560

object Matrices {

561

def dense(numRows: Int, numCols: Int, values: Array[Double]): Matrix

562

def sparse(numRows: Int, numCols: Int, colPtrs: Array[Int], rowIndices: Array[Int], values: Array[Double]): Matrix

563

def zeros(numRows: Int, numCols: Int): Matrix

564

def eye(n: Int): Matrix

565

}

566

```

567

568

Usage example for MLlib (note: `LinearRegressionWithSGD` is deprecated since Spark 2.0 and removed in Spark 3.0; prefer the DataFrame-based `org.apache.spark.ml.regression.LinearRegression` shown above):

569

570

```scala

571

import org.apache.spark.mllib.linalg.{Vector, Vectors}

572

import org.apache.spark.mllib.regression.{LabeledPoint, LinearRegressionWithSGD}

573

import org.apache.spark.mllib.evaluation.RegressionMetrics

574

575

// Create labeled points

576

val data = sc.textFile("data.txt").map { line =>

577

val parts = line.split(',')

578

LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble)))

579

}

580

581

// Split data

582

val splits = data.randomSplit(Array(0.8, 0.2), seed = 11L)

583

val training = splits(0).cache()

584

val test = splits(1)

585

586

// Train model

587

val numIterations = 100

588

val stepSize = 0.00000001

589

val model = LinearRegressionWithSGD.train(training, numIterations, stepSize)

590

591

// Evaluate

592

val valuesAndPreds = test.map { point =>

593

val prediction = model.predict(point.features)

594

(point.label, prediction)

595

}

596

597

val metrics = new RegressionMetrics(valuesAndPreds)

598

println(s"RMSE = ${metrics.rootMeanSquaredError}")

599

```