# Machine Learning (MLlib)

MLlib is Apache Spark's machine learning library, providing scalable algorithms and utilities for classification, regression, clustering, collaborative filtering, dimensionality reduction, and more. It uses DataFrame-based APIs and ML Pipelines for building machine learning workflows.

## Capabilities

### ML Pipelines

The primary API for building machine learning workflows using a pipeline of transformers and estimators.

```scala { .api }
/**
 * A stage in a pipeline, either an Estimator or a Transformer.
 */
abstract class PipelineStage extends Params with Logging {
  def transformSchema(schema: StructType): StructType
  def copy(extra: ParamMap): PipelineStage
}

/**
 * Abstract class for transformers that transform one dataset into another.
 */
abstract class Transformer extends PipelineStage {
  def transform(dataset: Dataset[_]): DataFrame
  def transform(dataset: Dataset[_], paramMap: ParamMap): DataFrame
  def transform(dataset: Dataset[_], firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): DataFrame
}

/**
 * Abstract class for estimators that fit models to data.
 */
abstract class Estimator[M <: Model[M]] extends PipelineStage {
  def fit(dataset: Dataset[_]): M
  def fit(dataset: Dataset[_], paramMap: ParamMap): M
  def fit(dataset: Dataset[_], firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): M
}

/**
 * Abstract class for fitted models produced by estimators.
 */
abstract class Model[M <: Model[M]] extends Transformer

/**
 * A simple pipeline which acts as an estimator. A Pipeline consists of a sequence of stages,
 * each of which is either an Estimator or a Transformer.
 */
class Pipeline(override val uid: String) extends Estimator[PipelineModel] with MLWritable {
  def this() = this(randomUID("pipeline"))

  def setStages(value: Array[PipelineStage]): Pipeline
  def getStages: Array[PipelineStage]
  def fit(dataset: Dataset[_]): PipelineModel
}

/**
 * Represents a fitted pipeline consisting of fitted models and transformers.
 */
class PipelineModel(override val uid: String, val stages: Array[Transformer]) extends Model[PipelineModel] with MLWritable {
  def transform(dataset: Dataset[_]): DataFrame
}
```

**Usage Examples:**

```scala
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("MLPipeline").getOrCreate()

// Sample data
val training = spark.createDataFrame(Seq(
  (0L, "a b c d e spark", 1.0),
  (1L, "b d", 0.0),
  (2L, "spark f g h", 1.0),
  (3L, "hadoop mapreduce", 0.0)
)).toDF("id", "text", "label")

// Configure ML pipeline
val tokenizer = new Tokenizer()
  .setInputCol("text")
  .setOutputCol("words")

val hashingTF = new HashingTF()
  .setNumFeatures(1000)
  .setInputCol(tokenizer.getOutputCol)
  .setOutputCol("features")

val lr = new LogisticRegression()
  .setMaxIter(10)
  .setRegParam(0.001)

val pipeline = new Pipeline()
  .setStages(Array(tokenizer, hashingTF, lr))

// Fit the pipeline
val model = pipeline.fit(training)

// Make predictions
val test = spark.createDataFrame(Seq(
  (4L, "spark i j k"),
  (5L, "l m n"),
  (6L, "spark hadoop spark"),
  (7L, "apache hadoop")
)).toDF("id", "text")

val predictions = model.transform(test)
predictions.select("id", "text", "probability", "prediction").show()

// Save and load model
model.write.overwrite().save("path/to/model")
val loadedModel = PipelineModel.load("path/to/model")
```

### Classification Algorithms

Algorithms for predicting categorical labels.

```scala { .api }
// Logistic Regression
class LogisticRegression(override val uid: String) extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel] {
  def this() = this(randomUID("logreg"))

  def setRegParam(value: Double): LogisticRegression
  def setElasticNetParam(value: Double): LogisticRegression
  def setMaxIter(value: Int): LogisticRegression
  def setTol(value: Double): LogisticRegression
  def setFitIntercept(value: Boolean): LogisticRegression
  def setStandardization(value: Boolean): LogisticRegression
}

// Decision Tree Classifier
class DecisionTreeClassifier(override val uid: String) extends ProbabilisticClassifier[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel] {
  def this() = this(randomUID("dtc"))

  def setMaxDepth(value: Int): DecisionTreeClassifier
  def setMaxBins(value: Int): DecisionTreeClassifier
  def setMinInstancesPerNode(value: Int): DecisionTreeClassifier
  def setMinInfoGain(value: Double): DecisionTreeClassifier
  def setImpurity(value: String): DecisionTreeClassifier
}

// Random Forest Classifier
class RandomForestClassifier(override val uid: String) extends ProbabilisticClassifier[Vector, RandomForestClassifier, RandomForestClassificationModel] {
  def this() = this(randomUID("rfc"))

  def setNumTrees(value: Int): RandomForestClassifier
  def setMaxDepth(value: Int): RandomForestClassifier
  def setFeatureSubsetStrategy(value: String): RandomForestClassifier
  def setSubsamplingRate(value: Double): RandomForestClassifier
}

// Gradient Boosted Trees Classifier
class GBTClassifier(override val uid: String) extends ProbabilisticClassifier[Vector, GBTClassifier, GBTClassificationModel] {
  def this() = this(randomUID("gbtc"))

  def setMaxIter(value: Int): GBTClassifier
  def setStepSize(value: Double): GBTClassifier
  def setMaxDepth(value: Int): GBTClassifier
}

// Naive Bayes
class NaiveBayes(override val uid: String) extends ProbabilisticClassifier[Vector, NaiveBayes, NaiveBayesModel] {
  def this() = this(randomUID("nb"))

  def setSmoothing(value: Double): NaiveBayes
  def setModelType(value: String): NaiveBayes // "multinomial", "complement", "bernoulli", "gaussian"
}

// Support Vector Machine
class LinearSVC(override val uid: String) extends Classifier[Vector, LinearSVC, LinearSVCModel] {
  def this() = this(randomUID("linearsvc"))

  def setRegParam(value: Double): LinearSVC
  def setMaxIter(value: Int): LinearSVC
  def setTol(value: Double): LinearSVC
  def setFitIntercept(value: Boolean): LinearSVC
}
```

**Usage Examples:**

```scala
import org.apache.spark.ml.classification._
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.VectorAssembler

// Prepare data
// rawData is assumed to be an existing DataFrame with numeric columns
// "feature1", "feature2", "feature3" and a binary "label" column.
val assembler = new VectorAssembler()
  .setInputCols(Array("feature1", "feature2", "feature3"))
  .setOutputCol("features")

val data = assembler.transform(rawData)
val Array(training, test) = data.randomSplit(Array(0.7, 0.3), seed = 42)

// Logistic Regression
val lr = new LogisticRegression()
  .setMaxIter(20)
  .setRegParam(0.3)
  .setElasticNetParam(0.8)

val lrModel = lr.fit(training)
val lrPredictions = lrModel.transform(test)

// Random Forest
val rf = new RandomForestClassifier()
  .setNumTrees(20)
  .setMaxDepth(5)
  .setFeatureSubsetStrategy("auto")

val rfModel = rf.fit(training)
val rfPredictions = rfModel.transform(test)

// Gradient Boosted Trees
val gbt = new GBTClassifier()
  .setMaxIter(10)
  .setStepSize(0.1)
  .setMaxDepth(3)

val gbtModel = gbt.fit(training)
val gbtPredictions = gbtModel.transform(test)

// Evaluate models (area under ROC: higher is better)
val evaluator = new BinaryClassificationEvaluator()
  .setRawPredictionCol("rawPrediction")
  .setMetricName("areaUnderROC")

val lrAUC = evaluator.evaluate(lrPredictions)
val rfAUC = evaluator.evaluate(rfPredictions)
val gbtAUC = evaluator.evaluate(gbtPredictions)

println(s"Logistic Regression AUC: $lrAUC")
println(s"Random Forest AUC: $rfAUC")
println(s"GBT AUC: $gbtAUC")
```

### Regression Algorithms

Algorithms for predicting continuous numerical values.

```scala { .api }
// Linear Regression
class LinearRegression(override val uid: String) extends Regressor[Vector, LinearRegression, LinearRegressionModel] {
  def this() = this(randomUID("linreg"))

  def setRegParam(value: Double): LinearRegression
  def setElasticNetParam(value: Double): LinearRegression
  def setMaxIter(value: Int): LinearRegression
  def setTol(value: Double): LinearRegression
  def setFitIntercept(value: Boolean): LinearRegression
  def setStandardization(value: Boolean): LinearRegression
}

// Decision Tree Regressor
class DecisionTreeRegressor(override val uid: String) extends Regressor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel] {
  def this() = this(randomUID("dtr"))

  def setMaxDepth(value: Int): DecisionTreeRegressor
  def setMaxBins(value: Int): DecisionTreeRegressor
  def setMinInstancesPerNode(value: Int): DecisionTreeRegressor
  def setMinInfoGain(value: Double): DecisionTreeRegressor
  def setImpurity(value: String): DecisionTreeRegressor // "variance"
}

// Random Forest Regressor
class RandomForestRegressor(override val uid: String) extends Regressor[Vector, RandomForestRegressor, RandomForestRegressionModel] {
  def this() = this(randomUID("rfr"))

  def setNumTrees(value: Int): RandomForestRegressor
  def setMaxDepth(value: Int): RandomForestRegressor
  def setFeatureSubsetStrategy(value: String): RandomForestRegressor
  def setSubsamplingRate(value: Double): RandomForestRegressor
}

// Gradient Boosted Trees Regressor
class GBTRegressor(override val uid: String) extends Regressor[Vector, GBTRegressor, GBTRegressionModel] {
  def this() = this(randomUID("gbtr"))

  def setMaxIter(value: Int): GBTRegressor
  def setStepSize(value: Double): GBTRegressor
  def setMaxDepth(value: Int): GBTRegressor
}

// Generalized Linear Regression
class GeneralizedLinearRegression(override val uid: String) extends Regressor[Vector, GeneralizedLinearRegression, GeneralizedLinearRegressionModel] {
  def this() = this(randomUID("glr"))

  def setFamily(value: String): GeneralizedLinearRegression // "gaussian", "binomial", "poisson", "gamma", "tweedie"
  def setLink(value: String): GeneralizedLinearRegression
  def setMaxIter(value: Int): GeneralizedLinearRegression
  def setRegParam(value: Double): GeneralizedLinearRegression
}
```

**Usage Examples:**

```scala
import org.apache.spark.ml.regression._
import org.apache.spark.ml.evaluation.RegressionEvaluator

// training and test are assumed to be DataFrames with a "features" vector
// column and a numeric "label" column.

// Linear Regression
val lr = new LinearRegression()
  .setMaxIter(20)
  .setRegParam(0.3)
  .setElasticNetParam(0.8)

val lrModel = lr.fit(training)
val lrPredictions = lrModel.transform(test)

// Print coefficients and intercept
println(s"Coefficients: ${lrModel.coefficients}")
println(s"Intercept: ${lrModel.intercept}")

// Random Forest Regression
val rf = new RandomForestRegressor()
  .setNumTrees(100)
  .setMaxDepth(6)
  .setFeatureSubsetStrategy("auto")

val rfModel = rf.fit(training)
val rfPredictions = rfModel.transform(test)

// Feature importance
println(s"Feature importances: ${rfModel.featureImportances}")

// Evaluate models (RMSE: lower is better)
val evaluator = new RegressionEvaluator()
  .setPredictionCol("prediction")
  .setLabelCol("label")
  .setMetricName("rmse")

val lrRMSE = evaluator.evaluate(lrPredictions)
val rfRMSE = evaluator.evaluate(rfPredictions)

println(s"Linear Regression RMSE: $lrRMSE")
println(s"Random Forest RMSE: $rfRMSE")
```
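
`GeneralizedLinearRegression` is listed in the API above but not exercised in the examples. The following is a minimal sketch, assuming count-valued labels, of a Poisson regression with a log link. The data, the `local[*]` master, and the parameter values are made up for illustration and are not tuned recommendations.

```scala
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.GeneralizedLinearRegression
import org.apache.spark.sql.SparkSession

// local[*] is an assumption so the sketch can run on a single machine
val spark = SparkSession.builder().appName("GLRSketch").master("local[*]").getOrCreate()

// Hypothetical count data: "label" is a non-negative count
val counts = spark.createDataFrame(Seq(
  (2.0, Vectors.dense(0.0, 1.0)),
  (5.0, Vectors.dense(1.0, 2.0)),
  (9.0, Vectors.dense(2.0, 2.5)),
  (20.0, Vectors.dense(3.0, 3.5)),
  (1.0, Vectors.dense(0.0, 0.5)),
  (12.0, Vectors.dense(2.5, 3.0))
)).toDF("label", "features")

// Poisson family with a log link, a common choice for count targets
val glr = new GeneralizedLinearRegression()
  .setFamily("poisson")
  .setLink("log")
  .setMaxIter(10)
  .setRegParam(0.1)

val glrModel = glr.fit(counts)

println(s"Coefficients: ${glrModel.coefficients}")
println(s"Intercept: ${glrModel.intercept}")
println(s"Deviance: ${glrModel.summary.deviance}")
```

The valid links depend on the chosen family, so an unsupported family and link pair is rejected when the estimator is fitted.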

### Clustering Algorithms

Algorithms for discovering hidden patterns and grouping data.

```scala { .api }
// K-Means Clustering
class KMeans(override val uid: String) extends Estimator[KMeansModel] with KMeansParams with MLWritable {
  def this() = this(randomUID("kmeans"))

  def setK(value: Int): KMeans
  def setMaxIter(value: Int): KMeans
  def setTol(value: Double): KMeans
  def setInitMode(value: String): KMeans // "k-means||" (parallel variant of k-means++), "random"
  def setInitSteps(value: Int): KMeans
  def setSeed(value: Long): KMeans
}

// Gaussian Mixture Model
class GaussianMixture(override val uid: String) extends Estimator[GaussianMixtureModel] with GaussianMixtureParams with MLWritable {
  def this() = this(randomUID("gmm"))

  def setK(value: Int): GaussianMixture
  def setMaxIter(value: Int): GaussianMixture
  def setTol(value: Double): GaussianMixture
  def setSeed(value: Long): GaussianMixture
}

// Latent Dirichlet Allocation (Topic Modeling)
class LDA(override val uid: String) extends Estimator[LDAModel] with LDAParams with MLWritable {
  def this() = this(randomUID("lda"))

  def setK(value: Int): LDA // Number of topics
  def setMaxIter(value: Int): LDA
  def setSeed(value: Long): LDA
  def setCheckpointInterval(value: Int): LDA
  def setOptimizer(value: String): LDA // "online", "em"
}

// Bisecting K-Means
class BisectingKMeans(override val uid: String) extends Estimator[BisectingKMeansModel] with BisectingKMeansParams with MLWritable {
  def this() = this(randomUID("bisecting-kmeans"))

  def setK(value: Int): BisectingKMeans
  def setMaxIter(value: Int): BisectingKMeans
  def setSeed(value: Long): BisectingKMeans
  def setMinDivisibleClusterSize(value: Double): BisectingKMeans
}
```

**Usage Examples:**

```scala
import org.apache.spark.ml.clustering._
import org.apache.spark.ml.evaluation.ClusteringEvaluator

// dataset is assumed to be a DataFrame with a "features" vector column.

// K-Means Clustering
val kmeans = new KMeans()
  .setK(3)
  .setMaxIter(20)
  .setSeed(42L)

val kmeansModel = kmeans.fit(dataset)
val predictions = kmeansModel.transform(dataset)

// Show cluster centers
println("Cluster Centers:")
kmeansModel.clusterCenters.foreach(println)

// Evaluate clustering
val evaluator = new ClusteringEvaluator()
  .setPredictionCol("prediction")
  .setFeaturesCol("features")
  .setMetricName("silhouette")

val silhouette = evaluator.evaluate(predictions)
println(s"Silhouette with squared euclidean distance = $silhouette")

// Gaussian Mixture Model
val gmm = new GaussianMixture()
  .setK(3)
  .setMaxIter(100)
  .setSeed(42L)

val gmmModel = gmm.fit(dataset)
val gmmPredictions = gmmModel.transform(dataset)

// Show mixture weights and gaussians
println("Mixture weights:")
gmmModel.weights.foreach(println)
println("Gaussians:")
gmmModel.gaussians.foreach(println)
```
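
`BisectingKMeans` follows the same fit/transform pattern as `KMeans`. A minimal sketch, assuming a tiny hand-made dataset with two well-separated groups (the points, seed, and `local[*]` master are illustrative only):

```scala
import org.apache.spark.ml.clustering.BisectingKMeans
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

// local[*] is an assumption so the sketch can run on a single machine
val spark = SparkSession.builder().appName("BisectingKMeansSketch").master("local[*]").getOrCreate()

// Hypothetical points: three near the origin, three near (9, 9)
val points = spark.createDataFrame(Seq(
  Tuple1(Vectors.dense(0.0, 0.0)),
  Tuple1(Vectors.dense(0.1, 0.1)),
  Tuple1(Vectors.dense(0.2, 0.0)),
  Tuple1(Vectors.dense(9.0, 9.0)),
  Tuple1(Vectors.dense(9.1, 9.2)),
  Tuple1(Vectors.dense(9.2, 9.0))
)).toDF("features")

val bkm = new BisectingKMeans()
  .setK(2)
  .setMaxIter(20)
  .setSeed(1L)

val bkmModel = bkm.fit(points)

// One center per leaf cluster
bkmModel.clusterCenters.foreach(println)

// Adds a "prediction" column holding the assigned cluster index
val clustered = bkmModel.transform(points)
clustered.show()
```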

### Feature Engineering

Transformers for feature extraction, transformation, and selection.

```scala { .api }
// Feature Extractors
class Tokenizer(override val uid: String) extends UnaryTransformer[String, Seq[String], Tokenizer] {
  def this() = this(randomUID("tok"))
}

class HashingTF(override val uid: String) extends Transformer with HasInputCol with HasOutputCol {
  def this() = this(randomUID("hashingTF"))
  def setNumFeatures(value: Int): HashingTF
}

class CountVectorizer(override val uid: String) extends Estimator[CountVectorizerModel] {
  def this() = this(randomUID("cntVec"))
  def setVocabSize(value: Int): CountVectorizer
  def setMinDF(value: Double): CountVectorizer
  def setMinTF(value: Double): CountVectorizer
}

class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] {
  def this() = this(randomUID("w2v"))
  def setVectorSize(value: Int): Word2Vec
  def setMinCount(value: Int): Word2Vec
  def setNumPartitions(value: Int): Word2Vec
  def setStepSize(value: Double): Word2Vec
  def setMaxIter(value: Int): Word2Vec
}

// Feature Transformers
class VectorAssembler(override val uid: String) extends Transformer {
  def this() = this(randomUID("vecAssembler"))
  def setInputCols(value: Array[String]): VectorAssembler
  def setOutputCol(value: String): VectorAssembler
}

class StandardScaler(override val uid: String) extends Estimator[StandardScalerModel] {
  def this() = this(randomUID("stdScal"))
  def setWithMean(value: Boolean): StandardScaler
  def setWithStd(value: Boolean): StandardScaler
}

class MinMaxScaler(override val uid: String) extends Estimator[MinMaxScalerModel] {
  def this() = this(randomUID("minMaxScal"))
  def setMin(value: Double): MinMaxScaler
  def setMax(value: Double): MinMaxScaler
}

class PCA(override val uid: String) extends Estimator[PCAModel] {
  def this() = this(randomUID("pca"))
  def setK(value: Int): PCA // Number of principal components
}

class StringIndexer(override val uid: String) extends Estimator[StringIndexerModel] {
  def this() = this(randomUID("strIdx"))
  def setHandleInvalid(value: String): StringIndexer // "error", "skip", "keep"
  def setStringOrderType(value: String): StringIndexer // "frequencyDesc", "frequencyAsc", "alphabetDesc", "alphabetAsc"
}

// In Spark 3.x, OneHotEncoder is an Estimator: it must be fitted before it can transform
class OneHotEncoder(override val uid: String) extends Estimator[OneHotEncoderModel] {
  def this() = this(randomUID("oneHot"))
  def setInputCols(value: Array[String]): OneHotEncoder
  def setOutputCols(value: Array[String]): OneHotEncoder
  def setHandleInvalid(value: String): OneHotEncoder
}

// Feature Selectors
class ChiSqSelector(override val uid: String) extends Estimator[ChiSqSelectorModel] {
  def this() = this(randomUID("chiSqSelector"))
  def setNumTopFeatures(value: Int): ChiSqSelector
  def setPercentile(value: Double): ChiSqSelector
  def setSelectorType(value: String): ChiSqSelector // "numTopFeatures", "percentile", "fpr", "fdr", "fwe"
}

class UnivariateFeatureSelector(override val uid: String) extends Estimator[UnivariateFeatureSelectorModel] {
  def this() = this(randomUID("univariateFeatureSelector"))
  def setFeatureType(value: String): UnivariateFeatureSelector // "categorical", "continuous"
  def setLabelType(value: String): UnivariateFeatureSelector // "categorical", "continuous"
  def setSelectionMode(value: String): UnivariateFeatureSelector
}
```

**Usage Examples:**

```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature._

// trainingData and testData are assumed to be DataFrames with columns
// "text", "category", "feature1", "feature2", "feature3" and "label".

// Text feature extraction pipeline
val tokenizer = new Tokenizer()
  .setInputCol("text")
  .setOutputCol("words")

val hashingTF = new HashingTF()
  .setInputCol("words")
  .setOutputCol("rawFeatures")
  .setNumFeatures(1000)

val idf = new IDF()
  .setInputCol("rawFeatures")
  .setOutputCol("features")

// Categorical feature encoding
val indexer = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("categoryIndex")
  .setHandleInvalid("keep")

val encoder = new OneHotEncoder()
  .setInputCols(Array("categoryIndex"))
  .setOutputCols(Array("categoryVec"))

// Numerical feature scaling
// The output column must not reuse "rawFeatures", which HashingTF already produces
val assembler = new VectorAssembler()
  .setInputCols(Array("feature1", "feature2", "feature3"))
  .setOutputCol("numericFeatures")

val scaler = new StandardScaler()
  .setInputCol("numericFeatures")
  .setOutputCol("scaledFeatures")
  .setWithStd(true)
  .setWithMean(false)

// Dimensionality reduction
// k must not exceed the input dimension (3 numeric features here)
val pca = new PCA()
  .setInputCol("scaledFeatures")
  .setOutputCol("pcaFeatures")
  .setK(2)

// Feature selection on the TF-IDF vector
val selector = new ChiSqSelector()
  .setNumTopFeatures(50)
  .setFeaturesCol("features")
  .setLabelCol("label")
  .setOutputCol("selectedFeatures")

// Combine into pipeline
val pipeline = new Pipeline()
  .setStages(Array(tokenizer, hashingTF, idf, indexer, encoder, assembler, scaler, pca, selector))

val model = pipeline.fit(trainingData)
val transformedData = model.transform(testData)
```
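
Outside a `Pipeline`, estimator-style feature stages have to be fitted explicitly before they can transform data. The sketch below assumes Spark 3.0 or later, where `OneHotEncoder` with `setInputCols`/`setOutputCols` is an estimator. The sample rows and the `local[*]` master are hypothetical.

```scala
import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer}
import org.apache.spark.sql.SparkSession

// local[*] is an assumption so the sketch can run on a single machine
val spark = SparkSession.builder().appName("EncodingSketch").master("local[*]").getOrCreate()

// Hypothetical categorical data
val df = spark.createDataFrame(Seq(
  (0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")
)).toDF("id", "category")

// StringIndexer learns the label-to-index mapping from the data
// (default ordering is by descending frequency)
val indexer = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("categoryIndex")
val indexed = indexer.fit(df).transform(df)

// OneHotEncoder learns the number of categories, then emits sparse vectors
val encoder = new OneHotEncoder()
  .setInputCols(Array("categoryIndex"))
  .setOutputCols(Array("categoryVec"))
val encoded = encoder.fit(indexed).transform(indexed)

encoded.show()
```

Inside a `Pipeline`, the same `fit` calls happen implicitly when `pipeline.fit` runs, which is why the stages can be listed without fitting them first.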

### Model Evaluation

Evaluators for assessing model performance.

```scala { .api }
// Binary Classification Evaluator
class BinaryClassificationEvaluator(override val uid: String) extends Evaluator {
  def this() = this(randomUID("binEval"))
  def setMetricName(value: String): BinaryClassificationEvaluator // "areaUnderROC", "areaUnderPR"
  def setRawPredictionCol(value: String): BinaryClassificationEvaluator
  def setLabelCol(value: String): BinaryClassificationEvaluator
}

// Multiclass Classification Evaluator
class MulticlassClassificationEvaluator(override val uid: String) extends Evaluator {
  def this() = this(randomUID("mcEval"))
  def setMetricName(value: String): MulticlassClassificationEvaluator // "f1", "accuracy", "weightedPrecision", "weightedRecall"
  def setPredictionCol(value: String): MulticlassClassificationEvaluator
  def setLabelCol(value: String): MulticlassClassificationEvaluator
}

// Regression Evaluator
class RegressionEvaluator(override val uid: String) extends Evaluator {
  def this() = this(randomUID("regEval"))
  def setMetricName(value: String): RegressionEvaluator // "rmse", "mse", "r2", "mae"
  def setPredictionCol(value: String): RegressionEvaluator
  def setLabelCol(value: String): RegressionEvaluator
}

// Clustering Evaluator
class ClusteringEvaluator(override val uid: String) extends Evaluator {
  def this() = this(randomUID("cluEval"))
  def setMetricName(value: String): ClusteringEvaluator // "silhouette"
  def setPredictionCol(value: String): ClusteringEvaluator
  def setFeaturesCol(value: String): ClusteringEvaluator
}
```
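
Evaluators only need a DataFrame with the configured label and prediction columns, so they can be tried without training a model. The following sketch uses a hand-written, hypothetical set of scored rows to show `MulticlassClassificationEvaluator` and `RegressionEvaluator`. The `local[*]` master is an assumption for local runs.

```scala
import org.apache.spark.ml.evaluation.{MulticlassClassificationEvaluator, RegressionEvaluator}
import org.apache.spark.sql.SparkSession

// local[*] is an assumption so the sketch can run on a single machine
val spark = SparkSession.builder().appName("EvaluatorSketch").master("local[*]").getOrCreate()

// Hypothetical scored output: three of four predictions match the label
val scored = spark.createDataFrame(Seq(
  (0.0, 0.0),
  (1.0, 1.0),
  (2.0, 1.0),
  (2.0, 2.0)
)).toDF("label", "prediction")

// Fraction of rows where prediction == label: 3 / 4 = 0.75
val accuracy = new MulticlassClassificationEvaluator()
  .setLabelCol("label")
  .setPredictionCol("prediction")
  .setMetricName("accuracy")
  .evaluate(scored)

// Mean absolute error over the same rows: (0 + 0 + 1 + 0) / 4 = 0.25
val maeEvaluator = new RegressionEvaluator()
  .setLabelCol("label")
  .setPredictionCol("prediction")
  .setMetricName("mae")
val mae = maeEvaluator.evaluate(scored)

println(s"Accuracy: $accuracy")
println(s"MAE: $mae")

// Error metrics such as "mae" and "rmse" are minimized, not maximized
println(s"MAE is larger-better: ${maeEvaluator.isLargerBetter}")
```

`isLargerBetter` is what the tuning classes consult when picking the best parameter map, so it is worth checking before reading metrics such as `avgMetrics` by hand.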

### Model Selection and Tuning

Tools for hyperparameter tuning and model selection.

```scala { .api }
// Cross Validator
class CrossValidator(override val uid: String) extends Estimator[CrossValidatorModel] {
  def this() = this(randomUID("cv"))
  def setEstimator(value: Estimator[_]): CrossValidator
  def setEstimatorParamMaps(value: Array[ParamMap]): CrossValidator
  def setEvaluator(value: Evaluator): CrossValidator
  def setNumFolds(value: Int): CrossValidator
  def setSeed(value: Long): CrossValidator
}

// Train Validation Split
class TrainValidationSplit(override val uid: String) extends Estimator[TrainValidationSplitModel] {
  def this() = this(randomUID("tvs"))
  def setEstimator(value: Estimator[_]): TrainValidationSplit
  def setEstimatorParamMaps(value: Array[ParamMap]): TrainValidationSplit
  def setEvaluator(value: Evaluator): TrainValidationSplit
  def setTrainRatio(value: Double): TrainValidationSplit
  def setSeed(value: Long): TrainValidationSplit
}

// Parameter Grid Builder
class ParamGridBuilder {
  def addGrid[T](param: Param[T], values: Iterable[T]): ParamGridBuilder
  def baseOn(paramMap: ParamMap): ParamGridBuilder
  def build(): Array[ParamMap]
}
```

**Usage Examples:**

```scala
import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

// Estimator to tune. trainingData and testData are assumed to be DataFrames
// with a "features" vector column and a binary "label" column.
val lr = new LogisticRegression().setMaxIter(10)

// Create parameter grid (2 x 2 x 3 = 12 combinations)
val paramGrid = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.1, 0.01))
  .addGrid(lr.fitIntercept, Array(false, true))
  .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0))
  .build()

// Cross validation
val cv = new CrossValidator()
  .setEstimator(lr)
  .setEvaluator(new BinaryClassificationEvaluator())
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(3)
  .setSeed(42L)

val cvModel = cv.fit(trainingData)

// Best model and parameters
val bestModel = cvModel.bestModel.asInstanceOf[LogisticRegressionModel]
println(s"Best parameters: ${bestModel.extractParamMap()}")
// areaUnderROC (the evaluator default) is larger-is-better, so max is the best score
println(s"Best CV performance: ${cvModel.avgMetrics.max}")

// Make predictions with best model
val predictions = cvModel.transform(testData)
```
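
`CrossValidator` fits every parameter map once per fold, so a 12-entry grid with 3 folds means 36 fits before the final refit. `TrainValidationSplit` evaluates each parameter map only once against a single held-out split, which is cheaper but gives a noisier estimate on small datasets. A minimal sketch, assuming synthetic linear data (the data, grid values, and `local[*]` master are illustrative):

```scala
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}
import org.apache.spark.sql.SparkSession

// local[*] is an assumption so the sketch can run on a single machine
val spark = SparkSession.builder().appName("TVSSketch").master("local[*]").getOrCreate()

// Hypothetical data: label is roughly 2 * x + 1 with a small deterministic offset
val data = spark.createDataFrame(
  (1 to 40).map { i =>
    val x = i.toDouble
    (2.0 * x + 1.0 + (i % 3) * 0.1, Vectors.dense(x))
  }
).toDF("label", "features")

val linReg = new LinearRegression().setMaxIter(10)

// 2 x 2 = 4 parameter combinations
val grid = new ParamGridBuilder()
  .addGrid(linReg.regParam, Array(0.01, 0.1))
  .addGrid(linReg.elasticNetParam, Array(0.0, 1.0))
  .build()

// 80% of the rows train each candidate, 20% validate it
val tvs = new TrainValidationSplit()
  .setEstimator(linReg)
  .setEvaluator(new RegressionEvaluator()) // default metric is RMSE
  .setEstimatorParamMaps(grid)
  .setTrainRatio(0.8)
  .setSeed(42L)

val tvsModel = tvs.fit(data)

// One validation metric per parameter map, in grid order
grid.zip(tvsModel.validationMetrics).foreach { case (params, rmse) =>
  println(s"RMSE $rmse for $params")
}
```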

## Performance and Scalability

### Best Practices

1. **Data Preprocessing**: Use DataFrame operations for data cleaning and feature engineering
2. **Feature Engineering**: Leverage built-in transformers for common operations
3. **Pipeline Usage**: Use ML Pipelines for reproducible workflows
4. **Model Persistence**: Save and load models for production deployment
5. **Hyperparameter Tuning**: Use CrossValidator or TrainValidationSplit for model selection
6. **Resource Management**: Configure appropriate executor memory and cores for ML workloads
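
Two of the practices above, resource configuration and persistence, can be sketched briefly. The memory and core values below are placeholders rather than recommendations, the temporary directory is hypothetical, and on a real cluster executor settings are usually passed to `spark-submit` instead of being set in code.

```scala
import java.nio.file.Files

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
import org.apache.spark.sql.SparkSession

// Placeholder resource settings; they have no effect under local[*]
val spark = SparkSession.builder()
  .appName("BestPracticesSketch")
  .master("local[*]")
  .config("spark.executor.memory", "4g")
  .config("spark.executor.cores", "2")
  .getOrCreate()

val tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
val hashingTF = new HashingTF().setInputCol("words").setOutputCol("features")

val pipeline = new Pipeline().setStages(Array(tokenizer, hashingTF))

// Persist the unfitted pipeline definition so the same workflow can be rebuilt later
val path = Files.createTempDirectory("mllib-sketch").resolve("pipeline").toString
pipeline.write.overwrite().save(path)

val restored = Pipeline.load(path)
println(s"Restored ${restored.getStages.length} stages")
```

Saving the unfitted `Pipeline` captures the workflow definition, while saving a `PipelineModel` captures the fitted parameters as well.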

### Distributed Training

MLlib algorithms are designed to scale horizontally:
- **Tree-based algorithms**: Naturally parallelizable across features and data
- **Linear methods**: Use distributed optimization algorithms like L-BFGS
- **Clustering**: Distributed implementations, such as K-means with the parallel `k-means||` initialization
- **Deep Learning**: Integration with external libraries for neural networks

MLlib provides a scalable machine learning toolkit built on Spark's distributed computing capabilities.