or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

classification.md · clustering.md · core-framework.md · evaluation.md · feature-processing.md · frequent-pattern-mining.md · index.md · linear-algebra.md · rdd-api.md · recommendation.md · regression.md

docs/classification.md

# Classification

MLlib provides classification algorithms for supervised learning with categorical labels. Each classifier is an `Estimator` whose `fit` method returns a fitted model (a `Transformer`), so every classifier can be used as a stage in a `Pipeline`.

## Logistic Regression

### Estimator

```scala { .api }

// Logistic regression estimator; fit() produces a LogisticRegressionModel.
class LogisticRegression(override val uid: String)
  extends Classifier[Vector, LogisticRegression, LogisticRegressionModel]
  with LogisticRegressionParams with DefaultParamsWritable {

  def this() = this(Identifiable.randomUID("logreg"))

  // Regularization: penalty strength and the elastic-net mix (0.0 = L2, 1.0 = L1)
  def setRegParam(value: Double): LogisticRegression
  def setElasticNetParam(value: Double): LogisticRegression

  // Optimizer controls
  def setMaxIter(value: Int): LogisticRegression
  def setTol(value: Double): LogisticRegression
  def setAggregationDepth(value: Int): LogisticRegression

  // Model form: intercept term, binomial vs. multinomial family, feature standardization
  def setFitIntercept(value: Boolean): LogisticRegression
  def setFamily(value: String): LogisticRegression
  def setStandardization(value: Boolean): LogisticRegression

  // Decision threshold(s) and optional per-instance weight column
  def setThreshold(value: Double): LogisticRegression
  def setThresholds(value: Array[Double]): LogisticRegression
  def setWeightCol(value: String): LogisticRegression

  // Bound-constrained optimization (expert parameters)
  def setLowerBoundsOnCoefficients(value: Matrix): LogisticRegression
  def setUpperBoundsOnCoefficients(value: Matrix): LogisticRegression
  def setLowerBoundsOnIntercepts(value: Vector): LogisticRegression
  def setUpperBoundsOnIntercepts(value: Vector): LogisticRegression

  override def fit(dataset: Dataset[_]): LogisticRegressionModel
  override def copy(extra: ParamMap): LogisticRegression
}

```

### Model

```scala { .api }

// Fitted logistic regression model. The general form stores one coefficient row and one
// intercept per class; binary models are the single-row case.
class LogisticRegressionModel private[spark] (
    override val uid: String,
    val coefficientMatrix: Matrix,
    val interceptVector: Vector,
    override val numClasses: Int,
    private val isMultinomial: Boolean)
  extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel]
  with LogisticRegressionParams with MLWritable {

  // Internal convenience constructor for a binary model: the coefficient vector becomes a
  // single-row matrix and the scalar intercept a one-element vector.
  private[spark] def this(uid: String, coefficients: Vector, intercept: Double) =
    this(uid, new DenseMatrix(1, coefficients.size, coefficients.toArray, isTransposed = true),
      Vectors.dense(intercept), 2, isMultinomial = false)

  // Binary-only views of coefficientMatrix / interceptVector (throw if the model is multinomial)
  def coefficients: Vector
  def intercept: Double

  // Training-set summary; check hasSummary before calling summary.
  // binarySummary is only valid for binary models.
  def summary: LogisticRegressionTrainingSummary
  def hasSummary: Boolean
  def binarySummary: BinaryLogisticRegressionTrainingSummary

  // Computes the same metrics on another dataset
  def evaluate(dataset: Dataset[_]): LogisticRegressionSummary

  override def predict(features: Vector): Double
  override def predictRaw(features: Vector): Vector
  override def predictProbability(features: Vector): Vector
  override def transform(dataset: Dataset[_]): DataFrame
  override def copy(extra: ParamMap): LogisticRegressionModel
  def write: MLWriter
}

// Metrics over a LogisticRegressionModel's predictions. Accuracy and by-label metrics
// (Array[Double]) are inherited from ClassificationSummary; threshold, ROC and PR curves
// exist only on the binary variant below.
sealed trait LogisticRegressionSummary extends ClassificationSummary {
  def probabilityCol: String
  def featuresCol: String
}

// Summary of the training run: adds the per-iteration objective values
sealed trait LogisticRegressionTrainingSummary extends LogisticRegressionSummary {
  def objectiveHistory: Array[Double]
  def totalIterations: Int
}

// Binary models: adds roc, areaUnderROC, pr and the *ByThreshold curves
sealed trait BinaryLogisticRegressionSummary extends LogisticRegressionSummary
  with BinaryClassificationSummary

sealed trait BinaryLogisticRegressionTrainingSummary extends BinaryLogisticRegressionSummary
  with LogisticRegressionTrainingSummary

```

## Decision Tree Classifier

### Estimator

```scala { .api }

// Single decision tree classifier; fit() produces a DecisionTreeClassificationModel.
class DecisionTreeClassifier(override val uid: String)
  extends Classifier[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel]
  with DecisionTreeClassifierParams with DefaultParamsWritable {

  def this() = this(Identifiable.randomUID("dtc"))

  // Tree growth limits and split criterion
  def setMaxDepth(value: Int): DecisionTreeClassifier
  def setMaxBins(value: Int): DecisionTreeClassifier
  def setMinInstancesPerNode(value: Int): DecisionTreeClassifier
  def setMinInfoGain(value: Double): DecisionTreeClassifier
  def setImpurity(value: String): DecisionTreeClassifier

  // Training memory, caching and checkpointing (expert parameters)
  def setMaxMemoryInMB(value: Int): DecisionTreeClassifier
  def setCacheNodeIds(value: Boolean): DecisionTreeClassifier
  def setCheckpointInterval(value: Int): DecisionTreeClassifier

  def setSeed(value: Long): DecisionTreeClassifier

  override def fit(dataset: Dataset[_]): DecisionTreeClassificationModel
  override def copy(extra: ParamMap): DecisionTreeClassifier
}

```

### Model

```scala { .api }

// A single fitted decision tree, rooted at rootNode. It produces class probabilities, so it
// extends ProbabilisticClassificationModel (like LogisticRegressionModel). It is one tree,
// not an ensemble, so it mixes in DecisionTreeModel rather than TreeEnsembleModel.
class DecisionTreeClassificationModel(override val uid: String, val rootNode: Node,
    val numFeatures: Int, val numClasses: Int)
  extends ProbabilisticClassificationModel[Vector, DecisionTreeClassificationModel]
  with DecisionTreeModel with DecisionTreeClassifierParams with MLWritable {

  override def predict(features: Vector): Double
  override def predictRaw(features: Vector): Vector
  override def predictProbability(features: Vector): Vector

  // Importance vector with one entry per feature
  def featureImportances: Vector

  // Tree inspection
  def depth: Int
  def numNodes: Int
  def toDebugString: String

  override def copy(extra: ParamMap): DecisionTreeClassificationModel
  def write: MLWriter
}

```

## Random Forest Classifier

### Estimator

```scala { .api }

// Random forest classifier; fit() produces a RandomForestClassificationModel.
class RandomForestClassifier(override val uid: String)
  extends Classifier[Vector, RandomForestClassifier, RandomForestClassificationModel]
  with RandomForestClassifierParams with DefaultParamsWritable {

  def this() = this(Identifiable.randomUID("rfc"))

  // Ensemble shape: number of trees, per-tree row sampling, features considered per split
  def setNumTrees(value: Int): RandomForestClassifier
  def setSubsamplingRate(value: Double): RandomForestClassifier
  def setFeatureSubsetStrategy(value: String): RandomForestClassifier

  // Per-tree growth limits and split criterion
  def setMaxDepth(value: Int): RandomForestClassifier
  def setMaxBins(value: Int): RandomForestClassifier
  def setMinInstancesPerNode(value: Int): RandomForestClassifier
  def setMinInfoGain(value: Double): RandomForestClassifier
  def setImpurity(value: String): RandomForestClassifier

  // Training memory, caching and checkpointing (expert parameters)
  def setMaxMemoryInMB(value: Int): RandomForestClassifier
  def setCacheNodeIds(value: Boolean): RandomForestClassifier
  def setCheckpointInterval(value: Int): RandomForestClassifier

  def setSeed(value: Long): RandomForestClassifier

  override def fit(dataset: Dataset[_]): RandomForestClassificationModel
  override def copy(extra: ParamMap): RandomForestClassifier
}

```

### Model

```scala { .api }

// Fitted random forest: an ensemble of DecisionTreeClassificationModels. It produces class
// probabilities, so it extends ProbabilisticClassificationModel.
class RandomForestClassificationModel(override val uid: String,
    private val _trees: Array[DecisionTreeClassificationModel],
    val numFeatures: Int, val numClasses: Int)
  extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel]
  with RandomForestClassifierParams
  with TreeEnsembleModel[DecisionTreeClassificationModel] with MLWritable {

  // Ensemble members and their weights
  def trees: Array[DecisionTreeClassificationModel]
  def treeWeights: Array[Double]

  // Importance vector with one entry per feature
  def featureImportances: Vector

  override def predict(features: Vector): Double
  override def predictRaw(features: Vector): Vector
  override def predictProbability(features: Vector): Vector

  // Ensemble inspection
  def totalNumNodes: Int
  def toDebugString: String

  override def copy(extra: ParamMap): RandomForestClassificationModel
  def write: MLWriter
}

```

## Gradient Boosted Tree Classifier

### Estimator

```scala { .api }

// Gradient-boosted tree classifier; fit() produces a GBTClassificationModel.
class GBTClassifier(override val uid: String)
  extends Classifier[Vector, GBTClassifier, GBTClassificationModel]
  with GBTClassifierParams with DefaultParamsWritable {

  def this() = this(Identifiable.randomUID("gbtc"))

  // Boosting: number of iterations (one tree each), learning rate, loss, row sampling,
  // features considered per split
  def setMaxIter(value: Int): GBTClassifier
  def setStepSize(value: Double): GBTClassifier
  def setLossType(value: String): GBTClassifier
  def setSubsamplingRate(value: Double): GBTClassifier
  def setFeatureSubsetStrategy(value: String): GBTClassifier

  // Early stopping: rows flagged in validationIndicatorCol are held out for validation;
  // validationTol is the stopping tolerance
  def setValidationIndicatorCol(value: String): GBTClassifier
  def setValidationTol(value: Double): GBTClassifier

  // Per-tree growth limits
  def setMaxDepth(value: Int): GBTClassifier
  def setMaxBins(value: Int): GBTClassifier
  def setMinInstancesPerNode(value: Int): GBTClassifier
  def setMinInfoGain(value: Double): GBTClassifier

  // Training memory, caching and checkpointing (expert parameters)
  def setMaxMemoryInMB(value: Int): GBTClassifier
  def setCacheNodeIds(value: Boolean): GBTClassifier
  def setCheckpointInterval(value: Int): GBTClassifier

  def setSeed(value: Long): GBTClassifier

  override def fit(dataset: Dataset[_]): GBTClassificationModel
  override def copy(extra: ParamMap): GBTClassifier
}

```

### Model

```scala { .api }

// Fitted boosted ensemble: regression trees combined with treeWeights. It outputs class
// probabilities, so it extends ProbabilisticClassificationModel.
class GBTClassificationModel(override val uid: String,
    private val _trees: Array[DecisionTreeRegressionModel],
    private val _treeWeights: Array[Double], val numFeatures: Int)
  extends ProbabilisticClassificationModel[Vector, GBTClassificationModel]
  with GBTClassifierParams
  with TreeEnsembleModel[DecisionTreeRegressionModel] with MLWritable {

  // Ensemble members (regression trees) and their weights
  def trees: Array[DecisionTreeRegressionModel]
  def treeWeights: Array[Double]
  def getNumTrees: Int

  // Importance vector with one entry per feature
  def featureImportances: Vector

  override def predict(features: Vector): Double
  override def predictRaw(features: Vector): Vector

  // Ensemble inspection
  def totalNumNodes: Int
  def toDebugString: String

  override def copy(extra: ParamMap): GBTClassificationModel
  def write: MLWriter
}

```

## Naive Bayes

### Estimator

```scala { .api }

// Naive Bayes classifier; fit() produces a NaiveBayesModel.
class NaiveBayes(override val uid: String)
  extends Classifier[Vector, NaiveBayes, NaiveBayesModel]
  with NaiveBayesParams with DefaultParamsWritable {

  def this() = this(Identifiable.randomUID("nb"))

  // Distribution family assumed for the features
  def setModelType(value: String): NaiveBayes

  // Additive smoothing applied to the estimated parameters
  def setSmoothing(value: Double): NaiveBayes

  // Per-class prediction thresholds and optional per-instance weight column
  def setThresholds(value: Array[Double]): NaiveBayes
  def setWeightCol(value: String): NaiveBayes

  override def fit(dataset: Dataset[_]): NaiveBayesModel
  override def copy(extra: ParamMap): NaiveBayes
}

```

### Model

```scala { .api }

// Fitted Naive Bayes model. pi holds one entry per class; theta and sigma hold one row per
// class and one column per feature. It produces class probabilities, so it extends
// ProbabilisticClassificationModel.
class NaiveBayesModel(override val uid: String, val pi: Vector, val theta: Matrix, val sigma: Matrix)
  extends ProbabilisticClassificationModel[Vector, NaiveBayesModel]
  with NaiveBayesParams with MLWritable {

  val numFeatures: Int = theta.numCols
  val numClasses: Int = pi.size

  override def predict(features: Vector): Double
  override def predictRaw(features: Vector): Vector
  override def predictProbability(features: Vector): Vector
  override def copy(extra: ParamMap): NaiveBayesModel
  def write: MLWriter
}

```

## Multilayer Perceptron Classifier

### Estimator

```scala { .api }

// Feed-forward neural network classifier; fit() produces a MultilayerPerceptronClassificationModel.
class MultilayerPerceptronClassifier(override val uid: String)
  extends Classifier[Vector, MultilayerPerceptronClassifier, MultilayerPerceptronClassificationModel]
  with MultilayerPerceptronParams with DefaultParamsWritable {

  def this() = this(Identifiable.randomUID("mlpc"))

  // Layer sizes from input to output: the first is the feature count, the last the class count
  def setLayers(value: Array[Int]): MultilayerPerceptronClassifier

  // Optimization: solver, iteration limit, tolerance, step size and optional starting weights
  def setSolver(value: String): MultilayerPerceptronClassifier
  def setMaxIter(value: Int): MultilayerPerceptronClassifier
  def setTol(value: Double): MultilayerPerceptronClassifier
  def setStepSize(value: Double): MultilayerPerceptronClassifier
  def setInitialWeights(value: Vector): MultilayerPerceptronClassifier

  // Block size used to stack input rows into matrices
  def setBlockSize(value: Int): MultilayerPerceptronClassifier
  def setSeed(value: Long): MultilayerPerceptronClassifier

  override def fit(dataset: Dataset[_]): MultilayerPerceptronClassificationModel
  override def copy(extra: ParamMap): MultilayerPerceptronClassifier
}

```

### Model

```scala { .api }

// Fitted perceptron: layer sizes plus the flattened weight vector. It produces class
// probabilities, so it extends ProbabilisticClassificationModel.
class MultilayerPerceptronClassificationModel(override val uid: String, val layers: Array[Int],
    val weights: Vector)
  extends ProbabilisticClassificationModel[Vector, MultilayerPerceptronClassificationModel]
  with MultilayerPerceptronParams with MLWritable {

  // Input layer width = feature count; output layer width = class count
  val numFeatures: Int = layers.head
  val numClasses: Int = layers.last

  override def predict(features: Vector): Double
  override def predictRaw(features: Vector): Vector
  override def predictProbability(features: Vector): Vector
  override def copy(extra: ParamMap): MultilayerPerceptronClassificationModel
  def write: MLWriter
}

```

## Linear Support Vector Classifier

### Estimator

```scala { .api }

// Binary linear support vector classifier; fit() produces a LinearSVCModel.
class LinearSVC(override val uid: String)
  extends Classifier[Vector, LinearSVC, LinearSVCModel]
  with LinearSVCParams with DefaultParamsWritable {

  def this() = this(Identifiable.randomUID("linearsvc"))

  // Regularization and optimizer controls
  def setRegParam(value: Double): LinearSVC
  def setMaxIter(value: Int): LinearSVC
  def setTol(value: Double): LinearSVC
  def setAggregationDepth(value: Int): LinearSVC

  // Model form
  def setFitIntercept(value: Boolean): LinearSVC
  def setStandardization(value: Boolean): LinearSVC

  // Decision threshold and optional per-instance weight column
  def setThreshold(value: Double): LinearSVC
  def setWeightCol(value: String): LinearSVC

  override def fit(dataset: Dataset[_]): LinearSVCModel
  override def copy(extra: ParamMap): LinearSVC
}

```

### Model

```scala { .api }

// Fitted linear SVM defined by coefficients and intercept. It is always binary and does not
// produce probabilities, so it extends ClassificationModel, not ProbabilisticClassificationModel.
class LinearSVCModel(override val uid: String, val coefficients: Vector, val intercept: Double)
  extends ClassificationModel[Vector, LinearSVCModel] with LinearSVCParams with MLWritable {

  val numClasses: Int = 2
  val numFeatures: Int = coefficients.size

  override def predict(features: Vector): Double
  override def predictRaw(features: Vector): Vector
  override def copy(extra: ParamMap): LinearSVCModel
  def write: MLWriter
}

```

## One-vs-Rest Meta-Classifier

### Estimator

```scala { .api }

// Meta-estimator: reduces a multiclass problem to one binary problem per class, each fitted
// with a copy of the configured base classifier.
class OneVsRest(override val uid: String)
  extends Estimator[OneVsRestModel] with OneVsRestParams with MLWritable {

  def this() = this(Identifiable.randomUID("oneVsRest"))

  // Base (binary) classifier trained once per class
  def setClassifier(value: Classifier[_, _, _]): OneVsRest

  // Column wiring
  def setLabelCol(value: String): OneVsRest
  def setFeaturesCol(value: String): OneVsRest
  def setPredictionCol(value: String): OneVsRest
  def setRawPredictionCol(value: String): OneVsRest

  // Number of per-class models fitted concurrently
  def setParallelism(value: Int): OneVsRest

  override def fit(dataset: Dataset[_]): OneVsRestModel
  override def copy(extra: ParamMap): OneVsRest
  override def transformSchema(schema: StructType): StructType
  def write: MLWriter
}

```

### Model

```scala { .api }

// Fitted one-vs-rest model: models(i) is the binary classifier for class i (class i vs. the rest).
class OneVsRestModel(override val uid: String, private val labelMetadata: Metadata,
    val models: Array[_ <: ClassificationModel[_, _]])
  extends Model[OneVsRestModel] with OneVsRestParams with MLWritable {

  // One binary model per class
  val numClasses: Int = models.length

  override def transform(dataset: Dataset[_]): DataFrame
  override def transformSchema(schema: StructType): StructType
  override def copy(extra: ParamMap): OneVsRestModel
  def write: MLWriter
}

```

## Classification Summary Classes

### Base Summary

```scala { .api }

// Metrics shared by every classification summary. Per-label arrays follow the order of
// `labels`; the weighted metrics average the per-label values.
trait ClassificationSummary extends Serializable {

  // Scored dataset and the columns the metrics are computed from
  def predictions: DataFrame
  def predictionCol: String
  def labelCol: String

  // Overall metrics
  def accuracy: Double
  def weightedPrecision: Double
  def weightedRecall: Double
  def weightedTruePositiveRate: Double = weightedRecall
  def weightedFalsePositiveRate: Double
  def weightedFMeasure(beta: Double): Double
  def weightedFMeasure: Double = weightedFMeasure(1.0)

  // Per-label metrics
  def labels: Array[Double]
  def precisionByLabel: Array[Double]
  def recallByLabel: Array[Double]
  def truePositiveRateByLabel: Array[Double] = recallByLabel
  def falsePositiveRateByLabel: Array[Double]
  def fMeasureByLabel(beta: Double): Array[Double]
  def fMeasureByLabel: Array[Double] = fMeasureByLabel(1.0)
}

```

### Binary Classification Summary

```scala { .api }

// Binary-only metrics. The curve methods return one row per point: roc has (FPR, TPR),
// pr has (recall, precision), and the *ByThreshold methods pair each threshold with its metric.
trait BinaryClassificationSummary extends ClassificationSummary {
  // Column of `predictions` used as the per-instance score
  def scoreCol: String

  def roc: DataFrame
  def areaUnderROC: Double
  def pr: DataFrame
  def fMeasureByThreshold: DataFrame
  def precisionByThreshold: DataFrame
  def recallByThreshold: DataFrame
}

```

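## Usage Examples

The snippets below build on each other: later examples reuse the `trainingData` and `testData` splits created in the first one, and assume the columns they reference (such as `features`) are present.

### Basic Classification Pipeline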

```scala

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler}
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator

// Assumes `data` is a DataFrame with numeric columns "feature1".."feature3" and a
// "label" column with two distinct values (the binary evaluator below requires that).

// Combine the raw numeric columns into a single "features" vector
val assembler = new VectorAssembler()
  .setInputCols(Array("feature1", "feature2", "feature3"))
  .setOutputCol("features")

// Index labels (required when they are strings)
val labelIndexer = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("indexedLabel")

val lr = new LogisticRegression()
  .setLabelCol("indexedLabel")
  .setFeaturesCol("features")
  .setMaxIter(100)
  .setRegParam(0.1)

// Stages run in order: index labels, assemble features, fit the classifier
val pipeline = new Pipeline()
  .setStages(Array(labelIndexer, assembler, lr))

// 70/30 split with a fixed seed so the split is reproducible
val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3), seed = 42L)

val model = pipeline.fit(trainingData)
val predictions = model.transform(testData)

val evaluator = new BinaryClassificationEvaluator()
  .setLabelCol("indexedLabel")
  .setRawPredictionCol("rawPrediction")
  .setMetricName("areaUnderROC")

val auc = evaluator.evaluate(predictions)
println(s"Area under ROC curve: $auc")

```

### Multiclass Classification with Random Forest

```scala

import org.apache.spark.ml.classification.{RandomForestClassifier, RandomForestClassificationModel}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator

// The forest is fitted directly, outside a Pipeline. `trainingData` and `testData` must
// therefore already contain the assembled "features" vector and the "indexedLabel" column,
// e.g. by running the raw splits through the assembler and labelIndexer stages above.
val rf = new RandomForestClassifier()
  .setLabelCol("indexedLabel")
  .setFeaturesCol("features")
  .setNumTrees(100)
  .setMaxDepth(10)
  .setMaxBins(32)
  .setMinInstancesPerNode(1)
  .setMinInfoGain(0.0)
  .setSubsamplingRate(1.0)
  .setFeatureSubsetStrategy("auto")
  .setSeed(42)

val rfModel: RandomForestClassificationModel = rf.fit(trainingData)

// Importance vector with one entry per feature
val featureImportances = rfModel.featureImportances
println(s"Feature importances: $featureImportances")

val predictions = rfModel.transform(testData)

val evaluator = new MulticlassClassificationEvaluator()
  .setLabelCol("indexedLabel")
  .setPredictionCol("prediction")

// Reuse one evaluator, switching the metric on each pass
val metrics = Array("accuracy", "weightedPrecision", "weightedRecall", "f1")
metrics.foreach { metric =>
  val result = evaluator.setMetricName(metric).evaluate(predictions)
  println(s"$metric: $result")
}

```

### Neural Network Classification

```scala

import org.apache.spark.ml.classification.MultilayerPerceptronClassifier

// Layer sizes: the first must equal the feature-vector length and the last the number of
// classes. Here: 4 inputs -> hidden layers of 5 and 4 nodes -> 3 output classes.
val layers = Array[Int](4, 5, 4, 3)

// Note: stepSize only affects the "gd" solver; "l-bfgs" does not use it
val mlp = new MultilayerPerceptronClassifier()
  .setLayers(layers)
  .setBlockSize(128)
  .setSeed(1234L)
  .setMaxIter(100)
  .setStepSize(0.03)
  .setSolver("l-bfgs")

// No column setters are called, so the default "features" and "label" columns are used
val mlpModel = mlp.fit(trainingData)
val mlpPredictions = mlpModel.transform(testData)

mlpPredictions.select("features", "label", "prediction", "probability").show(20)

```

### One-vs-Rest Multi-class Classification

```scala

import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest}

// Binary base classifier; OneVsRest fits one copy of it per class
val classifier = new LogisticRegression()
  .setMaxIter(10)
  .setTol(1e-6)
  .setFitIntercept(true)

val ovr = new OneVsRest()
  .setClassifier(classifier)
  .setLabelCol("label")
  .setFeaturesCol("features")

val ovrModel = ovr.fit(trainingData)
val ovrPredictions = ovrModel.transform(testData)

// One fitted binary model per class, so models.length equals the number of classes
println(s"Number of classes: ${ovrModel.models.length}")

```

### Model Summary and Metrics

```scala

import org.apache.spark.ml.classification.{BinaryLogisticRegressionSummary, LogisticRegression, LogisticRegressionModel}

val lr = new LogisticRegression()
  .setMaxIter(100)
  .setRegParam(0.01)
  .setElasticNetParam(0.0)

val lrModel: LogisticRegressionModel = lr.fit(trainingData)

// Training summary
val summary = lrModel.summary
println(s"Total iterations: ${summary.totalIterations}")
println(s"Objective history: ${summary.objectiveHistory.mkString(", ")}")

// ROC and precision-recall curves exist only for binary models; pattern-match on the
// summary type instead of isInstanceOf/asInstanceOf
summary match {
  case binarySummary: BinaryLogisticRegressionSummary =>
    println(s"Area Under ROC: ${binarySummary.areaUnderROC}")
    binarySummary.roc.show() // ROC curve points
    binarySummary.pr.show()  // precision-recall curve points
  case _ =>
    // Multinomial model: use the per-label metrics instead
}

// Evaluate on held-out data
val testSummary = lrModel.evaluate(testData)
println(s"Test Accuracy: ${testSummary.accuracy}")
println(s"Test Weighted Precision: ${testSummary.weightedPrecision}")
println(s"Test Weighted Recall: ${testSummary.weightedRecall}")

```