or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

core.md exceptions.md graphx.md index.md logging.md mllib.md sql.md storage.md streaming.md utils.md

mllib.mddocs/

0

# Machine Learning (MLlib)

1

2

MLlib is Apache Spark's machine learning library, providing scalable algorithms and utilities for classification, regression, clustering, collaborative filtering, dimensionality reduction, and more. It uses DataFrame-based APIs and ML Pipelines for building machine learning workflows.

3

4

## Capabilities

5

6

### ML Pipelines

7

8

The primary API for building machine learning workflows using a pipeline of transformers and estimators.

9

10

```scala { .api }

11

/**

12

* A stage in a pipeline, either an Estimator or a Transformer.

13

*/

14

abstract class PipelineStage extends Params with Logging {

15

def transformSchema(schema: StructType): StructType

16

def copy(extra: ParamMap): PipelineStage

17

}

18

19

/**

20

* Abstract class for transformers that transform one dataset into another.

21

*/

22

abstract class Transformer extends PipelineStage {

23

def transform(dataset: Dataset[_]): DataFrame

24

def transform(dataset: Dataset[_], paramMap: ParamMap): DataFrame

25

def transform(dataset: Dataset[_], firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): DataFrame

26

}

27

28

/**

29

* Abstract class for estimators that fit models to data.

30

*/

31

abstract class Estimator[M <: Model[M]] extends PipelineStage {

32

def fit(dataset: Dataset[_]): M

33

def fit(dataset: Dataset[_], paramMap: ParamMap): M

34

def fit(dataset: Dataset[_], firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): M

35

}

36

37

/**

38

* Abstract class for fitted models produced by estimators.

39

*/

40

abstract class Model[M <: Model[M]] extends Transformer

41

42

/**

43

* A simple pipeline which acts as an estimator. A Pipeline consists of a sequence of stages,

44

* each of which is either an Estimator or a Transformer.

45

*/

46

class Pipeline(override val uid: String) extends Estimator[PipelineModel] with MLWritable {

47

def this() = this(randomUID("pipeline"))

48

49

def setStages(value: Array[PipelineStage]): Pipeline

50

def getStages: Array[PipelineStage]

51

def fit(dataset: Dataset[_]): PipelineModel

52

}

53

54

/**

55

* Represents a fitted pipeline consisting of fitted models and transformers.

56

*/

57

class PipelineModel(override val uid: String, val stages: Array[Transformer]) extends Model[PipelineModel] with MLWritable {

58

def transform(dataset: Dataset[_]): DataFrame

59

}

60

```

61

62

**Usage Examples:**

63

64

```scala

65

import org.apache.spark.ml.{Pipeline, PipelineModel}

66

import org.apache.spark.ml.classification.LogisticRegression

67

import org.apache.spark.ml.feature.{HashingTF, Tokenizer}

68

import org.apache.spark.sql.SparkSession

69

70

val spark = SparkSession.builder().appName("MLPipeline").getOrCreate()

71

72

// Sample data

73

val training = spark.createDataFrame(Seq(

74

(0L, "a b c d e spark", 1.0),

75

(1L, "b d", 0.0),

76

(2L, "spark f g h", 1.0),

77

(3L, "hadoop mapreduce", 0.0)

78

)).toDF("id", "text", "label")

79

80

// Configure ML pipeline

81

val tokenizer = new Tokenizer()

82

.setInputCol("text")

83

.setOutputCol("words")

84

85

val hashingTF = new HashingTF()

86

.setNumFeatures(1000)

87

.setInputCol(tokenizer.getOutputCol)

88

.setOutputCol("features")

89

90

val lr = new LogisticRegression()

91

.setMaxIter(10)

92

.setRegParam(0.001)

93

94

val pipeline = new Pipeline()

95

.setStages(Array(tokenizer, hashingTF, lr))

96

97

// Fit the pipeline

98

val model = pipeline.fit(training)

99

100

// Make predictions

101

val test = spark.createDataFrame(Seq(

102

(4L, "spark i j k"),

103

(5L, "l m n"),

104

(6L, "spark hadoop spark"),

105

(7L, "apache hadoop")

106

)).toDF("id", "text")

107

108

val predictions = model.transform(test)

109

predictions.select("id", "text", "probability", "prediction").show()

110

111

// Save and load model

112

model.write.overwrite().save("path/to/model")

113

val loadedModel = PipelineModel.load("path/to/model")

114

```

115

116

### Classification Algorithms

117

118

Algorithms for predicting categorical labels.

119

120

```scala { .api }

121

// Logistic Regression

122

class LogisticRegression(override val uid: String) extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel] {

123

def this() = this(randomUID("logreg"))

124

125

def setRegParam(value: Double): LogisticRegression

126

def setElasticNetParam(value: Double): LogisticRegression

127

def setMaxIter(value: Int): LogisticRegression

128

def setTol(value: Double): LogisticRegression

129

def setFitIntercept(value: Boolean): LogisticRegression

130

def setStandardization(value: Boolean): LogisticRegression

131

}

132

133

// Decision Tree Classifier

134

class DecisionTreeClassifier(override val uid: String) extends ProbabilisticClassifier[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel] {

135

def this() = this(randomUID("dtc"))

136

137

def setMaxDepth(value: Int): DecisionTreeClassifier

138

def setMaxBins(value: Int): DecisionTreeClassifier

139

def setMinInstancesPerNode(value: Int): DecisionTreeClassifier

140

def setMinInfoGain(value: Double): DecisionTreeClassifier

141

def setImpurity(value: String): DecisionTreeClassifier

142

}

143

144

// Random Forest Classifier

145

class RandomForestClassifier(override val uid: String) extends ProbabilisticClassifier[Vector, RandomForestClassifier, RandomForestClassificationModel] {

146

def this() = this(randomUID("rfc"))

147

148

def setNumTrees(value: Int): RandomForestClassifier

149

def setMaxDepth(value: Int): RandomForestClassifier

150

def setFeatureSubsetStrategy(value: String): RandomForestClassifier

151

def setSubsamplingRate(value: Double): RandomForestClassifier

152

}

153

154

// Gradient Boosted Trees Classifier

155

class GBTClassifier(override val uid: String) extends ProbabilisticClassifier[Vector, GBTClassifier, GBTClassificationModel] {

156

def this() = this(randomUID("gbtc"))

157

158

def setMaxIter(value: Int): GBTClassifier

159

def setStepSize(value: Double): GBTClassifier

160

def setMaxDepth(value: Int): GBTClassifier

161

}

162

163

// Naive Bayes

164

class NaiveBayes(override val uid: String) extends ProbabilisticClassifier[Vector, NaiveBayes, NaiveBayesModel] {

165

def this() = this(randomUID("nb"))

166

167

def setSmoothing(value: Double): NaiveBayes

168

def setModelType(value: String): NaiveBayes // "multinomial", "complement", "bernoulli", "gaussian"

169

}

170

171

// Support Vector Machine

172

class LinearSVC(override val uid: String) extends Classifier[Vector, LinearSVC, LinearSVCModel] {

173

def this() = this(randomUID("linearsvc"))

174

175

def setRegParam(value: Double): LinearSVC

176

def setMaxIter(value: Int): LinearSVC

177

def setTol(value: Double): LinearSVC

178

def setFitIntercept(value: Boolean): LinearSVC

179

}

180

```

181

182

**Usage Examples:**

183

184

```scala

185

import org.apache.spark.ml.classification._

186

import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator

187

import org.apache.spark.ml.feature.VectorAssembler

188

189

// Prepare data

190

val assembler = new VectorAssembler()

191

.setInputCols(Array("feature1", "feature2", "feature3"))

192

.setOutputCol("features")

193

194

val data = assembler.transform(rawData)

195

val Array(training, test) = data.randomSplit(Array(0.7, 0.3), seed = 42)

196

197

// Logistic Regression

198

val lr = new LogisticRegression()

199

.setMaxIter(20)

200

.setRegParam(0.3)

201

.setElasticNetParam(0.8)

202

203

val lrModel = lr.fit(training)

204

val lrPredictions = lrModel.transform(test)

205

206

// Random Forest

207

val rf = new RandomForestClassifier()

208

.setNumTrees(20)

209

.setMaxDepth(5)

210

.setFeatureSubsetStrategy("auto")

211

212

val rfModel = rf.fit(training)

213

val rfPredictions = rfModel.transform(test)

214

215

// Gradient Boosted Trees

216

val gbt = new GBTClassifier()

217

.setMaxIter(10)

218

.setStepSize(0.1)

219

.setMaxDepth(3)

220

221

val gbtModel = gbt.fit(training)

222

val gbtPredictions = gbtModel.transform(test)

223

224

// Evaluate models

225

val evaluator = new BinaryClassificationEvaluator()

226

.setRawPredictionCol("rawPrediction")

227

.setMetricName("areaUnderROC")

228

229

val lrAUC = evaluator.evaluate(lrPredictions)

230

val rfAUC = evaluator.evaluate(rfPredictions)

231

val gbtAUC = evaluator.evaluate(gbtPredictions)

232

233

println(s"Logistic Regression AUC: $lrAUC")

234

println(s"Random Forest AUC: $rfAUC")

235

println(s"GBT AUC: $gbtAUC")

236

```

237

238

### Regression Algorithms

239

240

Algorithms for predicting continuous numerical values.

241

242

```scala { .api }

243

// Linear Regression

244

class LinearRegression(override val uid: String) extends Regressor[Vector, LinearRegression, LinearRegressionModel] {

245

def this() = this(randomUID("linreg"))

246

247

def setRegParam(value: Double): LinearRegression

248

def setElasticNetParam(value: Double): LinearRegression

249

def setMaxIter(value: Int): LinearRegression

250

def setTol(value: Double): LinearRegression

251

def setFitIntercept(value: Boolean): LinearRegression

252

def setStandardization(value: Boolean): LinearRegression

253

}

254

255

// Decision Tree Regressor

256

class DecisionTreeRegressor(override val uid: String) extends Regressor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel] {

257

def this() = this(randomUID("dtr"))

258

259

def setMaxDepth(value: Int): DecisionTreeRegressor

260

def setMaxBins(value: Int): DecisionTreeRegressor

261

def setMinInstancesPerNode(value: Int): DecisionTreeRegressor

262

def setMinInfoGain(value: Double): DecisionTreeRegressor

263

def setImpurity(value: String): DecisionTreeRegressor // "variance"

264

}

265

266

// Random Forest Regressor

267

class RandomForestRegressor(override val uid: String) extends Regressor[Vector, RandomForestRegressor, RandomForestRegressionModel] {

268

def this() = this(randomUID("rfr"))

269

270

def setNumTrees(value: Int): RandomForestRegressor

271

def setMaxDepth(value: Int): RandomForestRegressor

272

def setFeatureSubsetStrategy(value: String): RandomForestRegressor

273

def setSubsamplingRate(value: Double): RandomForestRegressor

274

}

275

276

// Gradient Boosted Trees Regressor

277

class GBTRegressor(override val uid: String) extends Regressor[Vector, GBTRegressor, GBTRegressionModel] {

278

def this() = this(randomUID("gbtr"))

279

280

def setMaxIter(value: Int): GBTRegressor

281

def setStepSize(value: Double): GBTRegressor

282

def setMaxDepth(value: Int): GBTRegressor

283

}

284

285

// Generalized Linear Regression

286

class GeneralizedLinearRegression(override val uid: String) extends Regressor[Vector, GeneralizedLinearRegression, GeneralizedLinearRegressionModel] {

287

def this() = this(randomUID("glr"))

288

289

def setFamily(value: String): GeneralizedLinearRegression // "gaussian", "binomial", "poisson", "gamma"

290

def setLink(value: String): GeneralizedLinearRegression

291

def setMaxIter(value: Int): GeneralizedLinearRegression

292

def setRegParam(value: Double): GeneralizedLinearRegression

293

}

294

```

295

296

**Usage Examples:**

297

298

```scala

299

import org.apache.spark.ml.regression._

300

import org.apache.spark.ml.evaluation.RegressionEvaluator

301

302

// Linear Regression

303

val lr = new LinearRegression()

304

.setMaxIter(20)

305

.setRegParam(0.3)

306

.setElasticNetParam(0.8)

307

308

val lrModel = lr.fit(training)

309

val lrPredictions = lrModel.transform(test)

310

311

// Print coefficients and intercept

312

println(s"Coefficients: ${lrModel.coefficients}")

313

println(s"Intercept: ${lrModel.intercept}")

314

315

// Random Forest Regression

316

val rf = new RandomForestRegressor()

317

.setNumTrees(100)

318

.setMaxDepth(6)

319

.setFeatureSubsetStrategy("auto")

320

321

val rfModel = rf.fit(training)

322

val rfPredictions = rfModel.transform(test)

323

324

// Feature importance

325

println(s"Feature importances: ${rfModel.featureImportances}")

326

327

// Evaluate models

328

val evaluator = new RegressionEvaluator()

329

.setPredictionCol("prediction")

330

.setLabelCol("label")

331

.setMetricName("rmse")

332

333

val lrRMSE = evaluator.evaluate(lrPredictions)

334

val rfRMSE = evaluator.evaluate(rfPredictions)

335

336

println(s"Linear Regression RMSE: $lrRMSE")

337

println(s"Random Forest RMSE: $rfRMSE")

338

```

339

340

### Clustering Algorithms

341

342

Algorithms for discovering hidden patterns and grouping data.

343

344

```scala { .api }

345

// K-Means Clustering

346

class KMeans(override val uid: String) extends Estimator[KMeansModel] with KMeansParams with MLWritable {

347

def this() = this(randomUID("kmeans"))

348

349

def setK(value: Int): KMeans

350

def setMaxIter(value: Int): KMeans

351

def setTol(value: Double): KMeans

352

def setInitMode(value: String): KMeans // "k-means++", "random"

353

def setInitSteps(value: Int): KMeans

354

def setSeed(value: Long): KMeans

355

}

356

357

// Gaussian Mixture Model

358

class GaussianMixture(override val uid: String) extends Estimator[GaussianMixtureModel] with GaussianMixtureParams with MLWritable {

359

def this() = this(randomUID("gmm"))

360

361

def setK(value: Int): GaussianMixture

362

def setMaxIter(value: Int): GaussianMixture

363

def setTol(value: Double): GaussianMixture

364

def setSeed(value: Long): GaussianMixture

365

}

366

367

// Latent Dirichlet Allocation (Topic Modeling)

368

class LDA(override val uid: String) extends Estimator[LDAModel] with LDAParams with MLWritable {

369

def this() = this(randomUID("lda"))

370

371

def setK(value: Int): LDA // Number of topics

372

def setMaxIter(value: Int): LDA

373

def setSeed(value: Long): LDA

374

def setCheckpointInterval(value: Int): LDA

375

def setOptimizer(value: String): LDA // "online", "em"

376

}

377

378

// Bisecting K-Means

379

class BisectingKMeans(override val uid: String) extends Estimator[BisectingKMeansModel] with BisectingKMeansParams with MLWritable {

380

def this() = this(randomUID("bisecting-kmeans"))

381

382

def setK(value: Int): BisectingKMeans

383

def setMaxIter(value: Int): BisectingKMeans

384

def setSeed(value: Long): BisectingKMeans

385

def setMinDivisibleClusterSize(value: Double): BisectingKMeans

386

}

387

```

388

389

**Usage Examples:**

390

391

```scala

392

import org.apache.spark.ml.clustering._

393

import org.apache.spark.ml.evaluation.ClusteringEvaluator

394

395

// K-Means Clustering

396

val kmeans = new KMeans()

397

.setK(3)

398

.setMaxIter(20)

399

.setSeed(42L)

400

401

val kmeansModel = kmeans.fit(dataset)

402

val predictions = kmeansModel.transform(dataset)

403

404

// Show cluster centers

405

println("Cluster Centers:")

406

kmeansModel.clusterCenters.foreach(println)

407

408

// Evaluate clustering

409

val evaluator = new ClusteringEvaluator()

410

.setPredictionCol("prediction")

411

.setFeaturesCol("features")

412

.setMetricName("silhouette")

413

414

val silhouette = evaluator.evaluate(predictions)

415

println(s"Silhouette with squared euclidean distance = $silhouette")

416

417

// Gaussian Mixture Model

418

val gmm = new GaussianMixture()

419

.setK(3)

420

.setMaxIter(100)

421

.setSeed(42L)

422

423

val gmmModel = gmm.fit(dataset)

424

val gmmPredictions = gmmModel.transform(dataset)

425

426

// Show mixture weights and gaussians

427

println("Mixture weights:")

428

gmmModel.weights.foreach(println)

429

println("Gaussians:")

430

gmmModel.gaussians.foreach(println)

431

```

432

433

### Feature Engineering

434

435

Transformers and estimators for feature extraction, transformation, and selection.

436

437

```scala { .api }

438

// Feature Extractors

439

class Tokenizer(override val uid: String) extends UnaryTransformer[String, Seq[String], Tokenizer] {

440

def this() = this(randomUID("tok"))

441

}

442

443

class HashingTF(override val uid: String) extends Transformer with HasInputCol with HasOutputCol {

444

def this() = this(randomUID("hashingTF"))

445

def setNumFeatures(value: Int): HashingTF

446

}

447

448

class CountVectorizer(override val uid: String) extends Estimator[CountVectorizerModel] {

449

def this() = this(randomUID("cntVec"))

450

def setVocabSize(value: Int): CountVectorizer

451

def setMinDF(value: Double): CountVectorizer

452

def setMinTF(value: Double): CountVectorizer

453

}

454

455

class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] {

456

def this() = this(randomUID("w2v"))

457

def setVectorSize(value: Int): Word2Vec

458

def setMinCount(value: Int): Word2Vec

459

def setNumPartitions(value: Int): Word2Vec

460

def setStepSize(value: Double): Word2Vec

461

def setMaxIter(value: Int): Word2Vec

462

}

463

464

// Feature Transformers

465

class VectorAssembler(override val uid: String) extends Transformer {

466

def this() = this(randomUID("vecAssembler"))

467

def setInputCols(value: Array[String]): VectorAssembler

468

def setOutputCol(value: String): VectorAssembler

469

}

470

471

class StandardScaler(override val uid: String) extends Estimator[StandardScalerModel] {

472

def this() = this(randomUID("stdScal"))

473

def setWithMean(value: Boolean): StandardScaler

474

def setWithStd(value: Boolean): StandardScaler

475

}

476

477

class MinMaxScaler(override val uid: String) extends Estimator[MinMaxScalerModel] {

478

def this() = this(randomUID("minMaxScal"))

479

def setMin(value: Double): MinMaxScaler

480

def setMax(value: Double): MinMaxScaler

481

}

482

483

class PCA(override val uid: String) extends Estimator[PCAModel] {

484

def this() = this(randomUID("pca"))

485

def setK(value: Int): PCA // Number of principal components

486

}

487

488

class StringIndexer(override val uid: String) extends Estimator[StringIndexerModel] {

489

def this() = this(randomUID("strIdx"))

490

def setHandleInvalid(value: String): StringIndexer // "error", "skip", "keep"

491

def setStringOrderType(value: String): StringIndexer // "frequencyDesc", "frequencyAsc", "alphabetDesc", "alphabetAsc"

492

}

493

494

class OneHotEncoder(override val uid: String) extends Estimator[OneHotEncoderModel] {

495

def this() = this(randomUID("oneHot"))

496

def setInputCols(value: Array[String]): OneHotEncoder

497

def setOutputCols(value: Array[String]): OneHotEncoder

498

def setHandleInvalid(value: String): OneHotEncoder

499

}

500

501

// Feature Selectors

502

class ChiSqSelector(override val uid: String) extends Estimator[ChiSqSelectorModel] {

503

def this() = this(randomUID("chiSqSelector"))

504

def setNumTopFeatures(value: Int): ChiSqSelector

505

def setPercentile(value: Double): ChiSqSelector

506

def setSelectorType(value: String): ChiSqSelector // "numTopFeatures", "percentile", "fpr", "fdr", "fwe"

507

}

508

509

class UnivariateFeatureSelector(override val uid: String) extends Estimator[UnivariateFeatureSelectorModel] {

510

def this() = this(randomUID("univariateFeatureSelector"))

511

def setFeatureType(value: String): UnivariateFeatureSelector // "categorical", "continuous"

512

def setLabelType(value: String): UnivariateFeatureSelector // "categorical", "continuous"

513

def setSelectionMode(value: String): UnivariateFeatureSelector

514

}

515

```

516

517

**Usage Examples:**

518

519

```scala

520

import org.apache.spark.ml.feature._

521

522

// Text feature extraction pipeline

523

val tokenizer = new Tokenizer()

524

.setInputCol("text")

525

.setOutputCol("words")

526

527

val hashingTF = new HashingTF()

528

.setInputCol("words")

529

.setOutputCol("rawFeatures")

530

.setNumFeatures(1000)

531

532

val idf = new IDF()

533

.setInputCol("rawFeatures")

534

.setOutputCol("features")

535

536

// Categorical feature encoding

537

val indexer = new StringIndexer()

538

.setInputCol("category")

539

.setOutputCol("categoryIndex")

540

.setHandleInvalid("keep")

541

542

val encoder = new OneHotEncoder()

543

.setInputCols(Array("categoryIndex"))

544

.setOutputCols(Array("categoryVec"))

545

546

// Numerical feature scaling

547

val assembler = new VectorAssembler()

548

.setInputCols(Array("feature1", "feature2", "feature3"))

549

.setOutputCol("numericFeatures") // must differ from HashingTF's "rawFeatures": a stage cannot write to an existing column

550

551

val scaler = new StandardScaler()

552

.setInputCol("numericFeatures")

553

.setOutputCol("scaledFeatures")

554

.setWithStd(true)

555

.setWithMean(false)

556

557

// Dimensionality reduction

558

val pca = new PCA()

559

.setInputCol("scaledFeatures")

560

.setOutputCol("pcaFeatures")

561

.setK(10)

562

563

// Feature selection

564

val selector = new ChiSqSelector()

565

.setNumTopFeatures(50)

566

.setFeaturesCol("features")

567

.setLabelCol("label")

568

.setOutputCol("selectedFeatures")

569

570

// Combine into pipeline

571

val pipeline = new Pipeline()

572

.setStages(Array(tokenizer, hashingTF, idf, indexer, encoder, assembler, scaler, pca, selector))

573

574

val model = pipeline.fit(trainingData)

575

val transformedData = model.transform(testData)

576

```

577

578

### Model Evaluation

579

580

Evaluators for assessing model performance.

581

582

```scala { .api }

583

// Binary Classification Evaluator

584

class BinaryClassificationEvaluator(override val uid: String) extends Evaluator {

585

def this() = this(randomUID("binEval"))

586

def setMetricName(value: String): BinaryClassificationEvaluator // "areaUnderROC", "areaUnderPR"

587

def setRawPredictionCol(value: String): BinaryClassificationEvaluator

588

def setLabelCol(value: String): BinaryClassificationEvaluator

589

}

590

591

// Multiclass Classification Evaluator

592

class MulticlassClassificationEvaluator(override val uid: String) extends Evaluator {

593

def this() = this(randomUID("mcEval"))

594

def setMetricName(value: String): MulticlassClassificationEvaluator // "f1", "accuracy", "weightedPrecision", "weightedRecall"

595

def setPredictionCol(value: String): MulticlassClassificationEvaluator

596

def setLabelCol(value: String): MulticlassClassificationEvaluator

597

}

598

599

// Regression Evaluator

600

class RegressionEvaluator(override val uid: String) extends Evaluator {

601

def this() = this(randomUID("regEval"))

602

def setMetricName(value: String): RegressionEvaluator // "rmse", "mse", "r2", "mae"

603

def setPredictionCol(value: String): RegressionEvaluator

604

def setLabelCol(value: String): RegressionEvaluator

605

}

606

607

// Clustering Evaluator

608

class ClusteringEvaluator(override val uid: String) extends Evaluator {

609

def this() = this(randomUID("cluEval"))

610

def setMetricName(value: String): ClusteringEvaluator // "silhouette"

611

def setPredictionCol(value: String): ClusteringEvaluator

612

def setFeaturesCol(value: String): ClusteringEvaluator

613

}

614

```

615

616

### Model Selection and Tuning

617

618

Tools for hyperparameter tuning and model selection.

619

620

```scala { .api }

621

// Cross Validator

622

class CrossValidator(override val uid: String) extends Estimator[CrossValidatorModel] {

623

def this() = this(randomUID("cv"))

624

def setEstimator(value: Estimator[_]): CrossValidator

625

def setEstimatorParamMaps(value: Array[ParamMap]): CrossValidator

626

def setEvaluator(value: Evaluator): CrossValidator

627

def setNumFolds(value: Int): CrossValidator

628

def setSeed(value: Long): CrossValidator

629

}

630

631

// Train Validation Split

632

class TrainValidationSplit(override val uid: String) extends Estimator[TrainValidationSplitModel] {

633

def this() = this(randomUID("tvs"))

634

def setEstimator(value: Estimator[_]): TrainValidationSplit

635

def setEstimatorParamMaps(value: Array[ParamMap]): TrainValidationSplit

636

def setEvaluator(value: Evaluator): TrainValidationSplit

637

def setTrainRatio(value: Double): TrainValidationSplit

638

def setSeed(value: Long): TrainValidationSplit

639

}

640

641

// Parameter Grid Builder

642

class ParamGridBuilder {

643

def addGrid[T](param: Param[T], values: Array[T]): ParamGridBuilder

644

def baseOn(paramMap: ParamMap): ParamGridBuilder

645

def build(): Array[ParamMap]

646

}

647

```

648

649

**Usage Examples:**

650

651

```scala

652

import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

653

import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator

654

655

// Create parameter grid

656

val paramGrid = new ParamGridBuilder()

657

.addGrid(lr.regParam, Array(0.1, 0.01))

658

.addGrid(lr.fitIntercept, Array(false, true))

659

.addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0))

660

.build()

661

662

// Cross validation

663

val cv = new CrossValidator()

664

.setEstimator(lr)

665

.setEvaluator(new BinaryClassificationEvaluator())

666

.setEstimatorParamMaps(paramGrid)

667

.setNumFolds(3)

668

.setSeed(42)

669

670

val cvModel = cv.fit(trainingData)

671

672

// Best model and parameters

673

val bestModel = cvModel.bestModel.asInstanceOf[LogisticRegressionModel]

674

println(s"Best parameters: ${cvModel.bestModel.extractParamMap()}")

675

println(s"Best CV performance: ${cvModel.avgMetrics.max}")

676

677

// Make predictions with best model

678

val predictions = cvModel.transform(testData)

679

```

680

681

## Performance and Scalability

682

683

### Best Practices

684

685

1. **Data Preprocessing**: Use DataFrame operations for data cleaning and feature engineering

686

2. **Feature Engineering**: Leverage built-in transformers for common operations

687

3. **Pipeline Usage**: Use ML Pipelines for reproducible workflows

688

4. **Model Persistence**: Save and load models for production deployment

689

5. **Hyperparameter Tuning**: Use CrossValidator or TrainValidationSplit for model selection

690

6. **Resource Management**: Configure appropriate executor memory and cores for ML workloads

691

692

### Distributed Training

693

694

MLlib algorithms are designed to scale horizontally:

695

- **Tree-based algorithms**: Naturally parallelizable across features and data

696

- **Linear methods**: Use distributed optimization algorithms like L-BFGS

697

- **Clustering**: Distributed implementations with efficient convergence

698

- **Deep Learning**: Integration with external libraries for neural networks

699

700

MLlib provides a comprehensive, scalable machine learning toolkit that integrates seamlessly with Spark's distributed computing capabilities.