
# Optimization Framework

Apache Flink ML provides a flexible optimization framework for training machine learning models. The framework includes gradient descent solvers, pluggable loss functions, and regularization components that can be combined to create custom optimization strategies.

## Core Optimization Components

### Solver

Base trait for all optimization algorithms. A solver takes training data as a `DataSet[LabeledVector]` and produces the learned model as a `DataSet[WeightVector]`.

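```scala { .api }
trait Solver extends WithParameters {
  // Runs the optimization and returns the learned weights
  def optimize(
    data: DataSet[LabeledVector],
    initialWeights: Option[DataSet[WeightVector]]
  ): DataSet[WeightVector]

  // Fluent setters for the parameters defined in the companion object
  def setLossFunction(lossFunction: LossFunction): this.type
  def setRegularizationPenalty(regularizationPenalty: RegularizationPenalty): this.type
  def setRegularizationConstant(regularizationConstant: Double): this.type
}

object Solver {
  // Parameters common to all solvers

  // No default: a loss function must always be set
  case object LossFunction extends Parameter[LossFunction] {
    val defaultValue = None
  }

  case object RegularizationPenaltyValue extends Parameter[RegularizationPenalty] {
    val defaultValue = Some(NoRegularization())
  }

  case object RegularizationConstant extends Parameter[Double] {
    val defaultValue = Some(0.0)
  }
}
```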
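The `Parameter`/`defaultValue` pattern behaves like a typed map with fallbacks: a value set explicitly wins, otherwise the parameter's default is used, and a parameter with `defaultValue = None` (such as `LossFunction`) must be set before use. The sketch below is a simplified, self-contained re-implementation for illustration only; `Param`, `ParamMap`, `MaxIterations` and `Loss` are hypothetical names, not Flink classes.

```scala
// Hypothetical, simplified model of typed parameters with defaults.
// Not Flink's implementation; all names are illustrative.
trait Param[T] {
  val defaultValue: Option[T]
}

case object MaxIterations extends Param[Int] {
  val defaultValue: Option[Int] = Some(10)
}

case object Loss extends Param[String] {
  val defaultValue: Option[String] = None // no default: must be set explicitly
}

final class ParamMap {
  private val values = scala.collection.mutable.Map.empty[Param[_], Any]

  def add[T](p: Param[T], value: T): this.type = {
    values(p) = value
    this
  }

  // An explicitly set value wins; otherwise fall back to the default
  def get[T](p: Param[T]): Option[T] =
    values.get(p).map(_.asInstanceOf[T]).orElse(p.defaultValue)

  // Fail loudly when a parameter without a default was never set
  def apply[T](p: Param[T]): T =
    get(p).getOrElse(throw new NoSuchElementException(s"Parameter $p is not set"))
}
```

For example, `new ParamMap()(MaxIterations)` yields the default `10`, while reading `Loss` before setting it throws.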

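### Iterative Solver

Base trait for iterative optimization algorithms. It extends `Solver` with parameters that control the iteration.

```scala { .api }
trait IterativeSolver extends Solver {
  // Fluent setters for the iteration parameters
  def setIterations(iterations: Int): this.type
  def setStepsize(stepsize: Double): this.type // sets LearningRate
  def setConvergenceThreshold(convergenceThreshold: Double): this.type
  def setLearningRateMethod(learningRateMethod: LearningRateMethod): this.type
}

object IterativeSolver {
  // Maximum number of iterations
  case object Iterations extends Parameter[Int] {
    val defaultValue = Some(10)
  }

  // Initial step size
  case object LearningRate extends Parameter[Double] {
    val defaultValue = Some(0.1)
  }

  // Tolerance on the change in loss between iterations
  case object ConvergenceThreshold extends Parameter[Double] {
    val defaultValue = Some(1e-6)
  }

  // How the step size changes over the iterations
  case object LearningRateMethodValue extends Parameter[LearningRateMethod] {
    val defaultValue = Some(LearningRateMethod.Default)
  }
}
```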
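One common way to apply a convergence threshold is to stop once the relative change in loss between two iterations falls below it, or once `Iterations` is exhausted. The sketch below assumes that rule; it is not taken from the Flink source, and `Convergence`, `converged` and `iterate` are hypothetical names.

```scala
// Hypothetical helper: relative-change convergence test.
object Convergence {
  // True when |prev - curr| / |prev| is below the threshold.
  // A previous loss of 0.0 means the optimum has already been reached.
  def converged(prevLoss: Double, currLoss: Double, threshold: Double): Boolean =
    if (prevLoss == 0.0) true
    else math.abs(prevLoss - currLoss) / math.abs(prevLoss) < threshold

  // Applies `step` (old loss => new loss) until convergence or until
  // maxIterations is reached; returns the number of iterations performed.
  def iterate(initialLoss: Double, maxIterations: Int, threshold: Double)(
      step: Double => Double): Int = {
    var prev = initialLoss
    var i = 0
    var done = false
    while (!done && i < maxIterations) {
      val curr = step(prev)
      i += 1
      done = converged(prev, curr, threshold)
      prev = curr
    }
    i
  }
}
```

A relative test keeps the threshold meaningful regardless of the absolute scale of the loss.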

## Gradient Descent Solver

Stochastic Gradient Descent implementation for distributed optimization.

```scala { .api }
class GradientDescent extends IterativeSolver {
  def optimize(
    data: DataSet[LabeledVector],
    initialWeights: Option[DataSet[WeightVector]] = None
  ): DataSet[WeightVector]
}

object GradientDescent {
  def apply(): GradientDescent
}
```

**Usage Example:**

```scala
import org.apache.flink.api.scala._
import org.apache.flink.ml.common.{LabeledVector, WeightVector}
import org.apache.flink.ml.optimization._

val trainingData: DataSet[LabeledVector] = ??? // your training data

// Configure gradient descent. The loss function has no default and must be set;
// SquaredLoss is a partial loss, so it is wrapped in a GenericLossFunction.
val gd = GradientDescent()
  .setIterations(100)
  .setStepsize(0.01)
  .setConvergenceThreshold(1e-8)
  .setLearningRateMethod(LearningRateMethod.Inverse)
  .setLossFunction(GenericLossFunction(new SquaredLoss(), new LinearPrediction()))
  .setRegularizationConstant(0.1) // no effect unless a regularization penalty is also set

// Run optimization; None lets the solver create the initial weights
val optimizedWeights: DataSet[WeightVector] = gd.optimize(trainingData, None)
```
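To see what a call like `optimize` computes, here is a single-machine sketch of full-batch gradient descent for a linear model with squared loss and a constant step size. It is plain Scala over an in-memory collection, not the Flink implementation (which runs the same kind of update as a distributed job); `Point`, `LocalGD`, `predict` and `fit` are hypothetical names.

```scala
// Single-machine sketch of batch gradient descent for linear regression
// with squared loss 0.5 * (w·x + b - y)^2. Hypothetical names, not Flink code.
final case class Point(features: Array[Double], label: Double)

object LocalGD {
  def predict(w: Array[Double], b: Double, x: Array[Double]): Double =
    w.indices.map(i => w(i) * x(i)).sum + b

  // Returns (weights, intercept) after `iterations` steps of size `stepSize`
  def fit(data: Seq[Point], iterations: Int, stepSize: Double): (Array[Double], Double) = {
    val dim = data.head.features.length
    var w = Array.fill(dim)(0.0)
    var b = 0.0
    for (_ <- 1 to iterations) {
      val gradW = Array.fill(dim)(0.0)
      var gradB = 0.0
      for (p <- data) {
        // Derivative of the squared loss w.r.t. the prediction
        val d = predict(w, b, p.features) - p.label
        for (i <- 0 until dim) gradW(i) += d * p.features(i)
        gradB += d
      }
      // Average over the batch and step against the gradient
      val n = data.size.toDouble
      w = w.indices.map(i => w(i) - stepSize * gradW(i) / n).toArray
      b -= stepSize * gradB / n
    }
    (w, b)
  }
}
```

On points from `y = 2x + 1`, for example, `LocalGD.fit(data, 2000, 0.05)` recovers a weight close to 2 and an intercept close to 1.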

## Learning Rate Methods

Different strategies for adjusting the learning rate during optimization.

```scala { .api }
sealed trait LearningRateMethod

object LearningRateMethod {
  case object Default extends LearningRateMethod
  case object Inverse extends LearningRateMethod
  case object InverseSquareRoot extends LearningRateMethod
}
```

- **Default**: Constant learning rate
- **Inverse**: Learning rate = initial_rate / iteration
- **InverseSquareRoot**: Learning rate = initial_rate / sqrt(iteration)
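The three schedules above are plain functions of the iteration number `t` (counted from 1) and the initial rate. The sketch mirrors those formulas in standalone Scala; `LocalLearningRate` and `effectiveRate` are hypothetical names, not part of the Flink API.

```scala
// Standalone mirror of the three learning-rate schedules listed above.
object LocalLearningRate {
  sealed trait Method
  case object Default extends Method
  case object Inverse extends Method
  case object InverseSquareRoot extends Method

  // Step size used in iteration t (t >= 1) for an initial rate eta0
  def effectiveRate(method: Method, eta0: Double, t: Int): Double = {
    require(t >= 1, "iterations are counted from 1")
    method match {
      case Default           => eta0
      case Inverse           => eta0 / t
      case InverseSquareRoot => eta0 / math.sqrt(t.toDouble)
    }
  }
}
```

With `eta0 = 0.1` at iteration 4, the three methods give 0.1, 0.025 and 0.05 respectively: `Inverse` decays fastest, `InverseSquareRoot` more gently.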

## Loss Functions

### Base Loss Function Interface

`loss` returns the loss of a single labeled example under the given weights, `gradient` returns its gradient with respect to the weights, and `lossGradient` returns both in one pass.

```scala { .api }
trait LossFunction {
  def loss(dataPoint: LabeledVector, weights: WeightVector): Double
  def gradient(dataPoint: LabeledVector, weights: WeightVector): Vector
  def lossGradient(dataPoint: LabeledVector, weights: WeightVector): (Double, Vector)
}
```

### Generic Loss Function

Combines a partial loss function with a prediction function for flexible loss computation.

```scala { .api }
case class GenericLossFunction(
  partialLossFunction: PartialLossFunction,
  predictionFunction: PredictionFunction
) extends LossFunction {
  def loss(dataPoint: LabeledVector, weights: WeightVector): Double
  def gradient(dataPoint: LabeledVector, weights: WeightVector): Vector
  def lossGradient(dataPoint: LabeledVector, weights: WeightVector): (Double, Vector)
}
```
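Composing the two parts follows the chain rule. Writing the partial loss as $\ell(\hat y, y)$ and the prediction function as $f(\mathbf{x}; \mathbf{w}, b)$, the standard derivation (shown here under the assumption that the composition is evaluated this way) is:

$$
L(\mathbf{w}, b) = \ell\big(f(\mathbf{x}; \mathbf{w}, b),\, y\big),
\qquad
\nabla_{\mathbf{w}} L = \ell'\big(f(\mathbf{x}; \mathbf{w}, b),\, y\big)\,\nabla_{\mathbf{w}} f(\mathbf{x}; \mathbf{w}, b)
$$

Here $\ell'$ is the derivative of the partial loss with respect to the prediction and $\nabla_{\mathbf{w}} f$ is the gradient of the prediction function. For squared loss with a linear prediction, $\ell'(\hat y, y) = \hat y - y$ and $\nabla_{\mathbf{w}} f = \mathbf{x}$, so $\nabla_{\mathbf{w}} L = (\mathbf{w}^\top \mathbf{x} + b - y)\,\mathbf{x}$ and $\partial L / \partial b = \mathbf{w}^\top \mathbf{x} + b - y$.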

## Partial Loss Functions

A partial loss function scores a single prediction against its label; computing the prediction itself is left to the prediction function. `derivative` is the derivative of `loss` with respect to the prediction.

```scala { .api }
trait PartialLossFunction {
  def loss(prediction: Double, label: Double): Double
  def derivative(prediction: Double, label: Double): Double
}
```

### Common Partial Loss Functions

**Squared Loss (for regression):**

```scala
class SquaredLoss extends PartialLossFunction {
  // 0.5 * (prediction - label)^2
  def loss(prediction: Double, label: Double): Double = {
    val diff = prediction - label
    0.5 * diff * diff
  }

  def derivative(prediction: Double, label: Double): Double = {
    prediction - label
  }
}
```

**Hinge Loss (for SVM):**

```scala
// Expects labels in {-1, +1}
class HingeLoss extends PartialLossFunction {
  def loss(prediction: Double, label: Double): Double = {
    val margin = label * prediction
    if (margin >= 1.0) 0.0 else 1.0 - margin
  }

  // Subgradient: the hinge loss is not differentiable at margin == 1
  def derivative(prediction: Double, label: Double): Double = {
    if (label * prediction >= 1.0) 0.0 else -label
  }
}
```

**Logistic Loss (for logistic regression):**

```scala
// Expects labels in {-1, +1}
class LogisticLoss extends PartialLossFunction {
  // log(1 + exp(z)) with z = -label * prediction, computed without overflowing exp
  def loss(prediction: Double, label: Double): Double = {
    val z = -label * prediction
    if (z > 0) z + math.log1p(math.exp(-z)) else math.log1p(math.exp(z))
  }

  // -label * sigmoid(z), with sigmoid evaluated in a numerically stable way
  def derivative(prediction: Double, label: Double): Double = {
    val z = -label * prediction
    val sigmoid =
      if (z >= 0) 1.0 / (1.0 + math.exp(-z))
      else { val e = math.exp(z); e / (1.0 + e) }
    -label * sigmoid
  }
}
```
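A quick way to sanity-check a `derivative` implementation is to compare it with a central finite difference of `loss`. The sketch uses a stand-in trait `PartialLoss` (hypothetical, mirroring `PartialLossFunction`) so it runs without Flink; `Squared`, `Hinge`, `Logistic` and `GradCheck` are likewise illustrative names.

```scala
// Stand-in for PartialLossFunction so this runs without Flink
trait PartialLoss {
  def loss(prediction: Double, label: Double): Double
  def derivative(prediction: Double, label: Double): Double
}

object Squared extends PartialLoss {
  def loss(p: Double, y: Double): Double = 0.5 * (p - y) * (p - y)
  def derivative(p: Double, y: Double): Double = p - y
}

object Hinge extends PartialLoss {
  def loss(p: Double, y: Double): Double = math.max(0.0, 1.0 - y * p)
  def derivative(p: Double, y: Double): Double = if (y * p >= 1.0) 0.0 else -y
}

object Logistic extends PartialLoss {
  def loss(p: Double, y: Double): Double = math.log1p(math.exp(-y * p))
  def derivative(p: Double, y: Double): Double = {
    val e = math.exp(-y * p)
    -y * e / (1.0 + e)
  }
}

object GradCheck {
  // Central finite difference of the loss w.r.t. the prediction
  def numeric(f: PartialLoss, p: Double, y: Double, h: Double = 1e-6): Double =
    (f.loss(p + h, y) - f.loss(p - h, y)) / (2 * h)

  // True when the analytic and numeric derivatives agree within tol
  def ok(f: PartialLoss, p: Double, y: Double, tol: Double = 1e-5): Boolean =
    math.abs(f.derivative(p, y) - numeric(f, p, y)) < tol
}
```

For the hinge loss, pick test points away from `label * prediction == 1`, where the loss has a kink and only a subgradient exists.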

## Prediction Functions

Prediction functions define how to compute predictions from features and weights.

```scala { .api }
trait PredictionFunction {
  def predict(features: Vector, weights: WeightVector): Double
  def gradient(features: Vector, weights: WeightVector): Vector
}
```

### Linear Prediction Function

Standard linear prediction for linear models: the dot product of features and weights, plus the intercept.

```scala
class LinearPrediction extends PredictionFunction {
  def predict(features: Vector, weights: WeightVector): Double = {
    features.dot(weights.weights) + weights.intercept
  }

  def gradient(features: Vector, weights: WeightVector): Vector = {
    // The gradient w.r.t. the weights is the feature vector itself. Return a copy
    // so callers can scale it in place without mutating the input data.
    // (The gradient w.r.t. the intercept is 1 and is not part of this vector.)
    features.copy
  }
}
```
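Putting a partial loss and a linear prediction together, the loss gradient for one example is the partial-loss derivative times the prediction gradient, with the intercept receiving the derivative itself. This standalone sketch illustrates that composition for squared loss; `Weights` and `Composed` are hypothetical names and use Scala collections rather than Flink's vector types.

```scala
// Plain-Scala sketch of composing a squared partial loss with a linear
// prediction via the chain rule. Hypothetical names; not the Flink classes.
final case class Weights(w: IndexedSeq[Double], intercept: Double)

object Composed {
  def predict(x: IndexedSeq[Double], wv: Weights): Double =
    x.zip(wv.w).map { case (xi, wi) => xi * wi }.sum + wv.intercept

  // Squared partial loss and its derivative w.r.t. the prediction
  def partialLoss(p: Double, y: Double): Double = 0.5 * (p - y) * (p - y)
  def partialDerivative(p: Double, y: Double): Double = p - y

  // Returns the loss and the gradient w.r.t. (weights, intercept):
  // dL/dw = l'(p, y) * x and dL/db = l'(p, y) * 1
  def lossGradient(x: IndexedSeq[Double], y: Double, wv: Weights): (Double, Weights) = {
    val p = predict(x, wv)
    val d = partialDerivative(p, y)
    (partialLoss(p, y), Weights(x.map(_ * d), d))
  }
}
```

With `x = (3, 4)`, `y = 10`, `w = (1, 2)` and intercept 0.5, the prediction is 11.5, the loss 1.125, the weight gradient `(4.5, 6.0)` and the intercept gradient 1.5.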

223

224

## Regularization

225

226

Regularization penalties help prevent overfitting by adding penalty terms to the loss function.

227

228

```scala { .api }

229

trait RegularizationPenalty {

230

def penalty(weights: WeightVector): Double

231

def gradient(weights: WeightVector): Vector

232

}

233

```

234

235

### No Regularization

236

237

```scala { .api }

238

case class NoRegularization() extends RegularizationPenalty {

239

def penalty(weights: WeightVector): Double = 0.0

240

def gradient(weights: WeightVector): Vector = DenseVector.zeros(weights.weights.size)

241

}

242

```

### L1 Regularization (Lasso)

```scala
case class L1Regularization(lambda: Double = 0.1) extends RegularizationPenalty {
  // lambda * sum_i |w_i|
  def penalty(weights: WeightVector): Double = {
    val w = weights.weights
    lambda * (0 until w.size).map(i => math.abs(w(i))).sum
  }

  // Subgradient lambda * sign(w_i); 0 is used where w_i == 0
  def gradient(weights: WeightVector): Vector = {
    val w = weights.weights
    DenseVector((0 until w.size).map(i => lambda * math.signum(w(i))).toArray)
  }
}
```

### L2 Regularization (Ridge)

```scala
case class L2Regularization(lambda: Double = 0.1) extends RegularizationPenalty {
  // 0.5 * lambda * ||w||^2
  def penalty(weights: WeightVector): Double = {
    0.5 * lambda * weights.weights.dot(weights.weights)
  }

  // lambda * w, built as a new vector so the model weights are not modified
  def gradient(weights: WeightVector): Vector = {
    val w = weights.weights
    DenseVector((0 until w.size).map(i => lambda * w(i)).toArray)
  }
}
```
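A penalty's gradient is folded into each update step. The sketch shows one gradient step with an L2 penalty, and for L1 the proximal (soft-thresholding) step that is a common alternative to the subgradient because it sets small weights exactly to zero. Which of the two L1 variants a given solver uses is not stated above; `RegularizedStep`, `l2Step` and `l1ProximalStep` are hypothetical names.

```scala
// Hypothetical helpers: one regularized update on plain arrays.
object RegularizedStep {
  // Gradient step on the loss plus the L2 penalty gradient lambda * w
  def l2Step(w: Array[Double], grad: Array[Double], lambda: Double, eta: Double): Array[Double] =
    w.indices.map(i => w(i) - eta * (grad(i) + lambda * w(i))).toArray

  // Gradient step on the loss, then soft-thresholding for the L1 penalty:
  // each weight moves toward zero by eta * lambda and is clipped at zero
  def l1ProximalStep(w: Array[Double], grad: Array[Double], lambda: Double, eta: Double): Array[Double] =
    w.indices.map { i =>
      val v = w(i) - eta * grad(i)
      val t = eta * lambda
      math.signum(v) * math.max(math.abs(v) - t, 0.0)
    }.toArray
}
```

With zero loss gradient, `lambda = 1` and `eta = 0.1`, the L2 step shrinks every weight by 10%, while the L1 step subtracts 0.1 from each magnitude and zeroes any weight smaller than that.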

## Custom Optimization Example

Building a custom optimization setup:

```scala
import org.apache.flink.api.scala._
import org.apache.flink.ml.common.{LabeledVector, WeightVector}
import org.apache.flink.ml.optimization._

val trainingData: DataSet[LabeledVector] = ??? // your training data

// Define custom loss function
val customLoss = GenericLossFunction(
  partialLossFunction = new SquaredLoss(),
  predictionFunction = new LinearPrediction()
)

// Configure solver with custom settings
val optimizer = GradientDescent()
  .setLossFunction(customLoss)
  .setRegularizationPenalty(L2Regularization(0.01))
  .setRegularizationConstant(1.0)
  .setIterations(200)
  .setStepsize(0.05)
  .setLearningRateMethod(LearningRateMethod.Inverse)
  .setConvergenceThreshold(1e-10)

// Run optimization
val finalWeights: DataSet[WeightVector] = optimizer.optimize(trainingData, None)
```

299

300

## Integration with ML Algorithms

301

302

The optimization framework is used internally by ML algorithms. For example, MultipleLinearRegression uses gradient descent:

303

304

```scala

305

import org.apache.flink.ml.regression.MultipleLinearRegression

306

307

val regression = MultipleLinearRegression()

308

.setIterations(100) // Maps to IterativeSolver.Iterations

309

.setStepsize(0.01) // Maps to IterativeSolver.LearningRate

310

.setConvergenceThreshold(1e-8) // Maps to IterativeSolver.ConvergenceThreshold

311

312

// Internally uses optimization framework

313

val model = regression.fit(trainingData)

314

```

## Distributed Optimization Considerations

The optimization framework is designed for distributed execution on Flink:

### Data Partitioning

```scala
// Hash-partitioning on each record's hash code spreads the records roughly
// evenly across the parallel tasks, at the cost of a full network shuffle.
// This can help when the input is skewed; rebalance() achieves a similar
// spread with round-robin distribution.
val partitionedData = trainingData.partitionByHash(_.hashCode())
val weights = optimizer.optimize(partitionedData, None)
```

### Broadcast Variables

The solver makes the current weight vector available to every parallel task as a broadcast variable, so each task can compute gradients on its own share of the data without moving the training data across the network:

```scala
// No user configuration is needed: the framework handles broadcasting the
// weights internally during distributed gradient computation
val distributedWeights = optimizer.optimize(largeDataset, None)
```
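Conceptually, each iteration gives every partition the same (broadcast) weights, computes a partial gradient sum per partition, and then adds the partial sums together. The plain-Scala sketch simulates this with in-memory partitions to show why the split does not change the result; it illustrates the pattern only, and `DataParallelGradient`, `Example`, `partitionGradient` and `fullGradient` are hypothetical names.

```scala
// Simulated data-parallel gradient computation for squared loss.
object DataParallelGradient {
  final case class Example(x: Array[Double], y: Double)

  // Partial gradient sum for one partition, given the shared weights w and
  // intercept b. The last slot of the result holds the intercept gradient.
  def partitionGradient(part: Seq[Example], w: Array[Double], b: Double): Array[Double] = {
    val g = Array.fill(w.length + 1)(0.0)
    for (e <- part) {
      val d = e.x.indices.map(i => e.x(i) * w(i)).sum + b - e.y
      for (i <- w.indices) g(i) += d * e.x(i)
      g(w.length) += d
    }
    g
  }

  // "Reduce" step: element-wise sum of the per-partition partial gradients
  def fullGradient(partitions: Seq[Seq[Example]], w: Array[Double], b: Double): Array[Double] =
    partitions
      .map(p => partitionGradient(p, w, b))
      .reduce((a, c) => a.indices.map(i => a(i) + c(i)).toArray)
}
```

Because the gradient is a sum over examples, splitting the data into any number of partitions yields the same total, up to floating-point rounding.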

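### Checkpointing

`enableCheckpointing` belongs to the DataStream API (`StreamExecutionEnvironment`); DataSet programs such as these solvers do not take checkpoints. For long-running optimizations, configure a restart strategy instead, keeping in mind that a restarted batch job typically recomputes its iterations from the beginning:

```scala
import org.apache.flink.api.common.restartstrategy.RestartStrategies
import org.apache.flink.api.scala._

val env = ExecutionEnvironment.getExecutionEnvironment
// Retry the job up to 3 times, waiting 10 seconds between attempts
env.setRestartStrategy(RestartStrategies.fixedDelayRestart(3, 10000))

val robustOptimizer = GradientDescent()
  .setIterations(1000) // Many iterations
  // ... other parameters

val weights = robustOptimizer.optimize(trainingData, None)
```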

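## Monitoring Optimization Progress

The framework does not provide built-in progress monitoring, but you can evaluate the loss of the learned weights yourself. Because `DataSet` operations are lazy, the loss has to be computed as part of the job and collected back to the client; printing inside `optimize` would not show a value. A sketch, reusing `optimizer` and `trainingData` from the examples above:

```scala
import org.apache.flink.api.scala._
import org.apache.flink.ml.common.{LabeledVector, WeightVector}
import org.apache.flink.ml.optimization._

val lossFunction = GenericLossFunction(new SquaredLoss(), new LinearPrediction())
val weights: DataSet[WeightVector] = optimizer.optimize(trainingData, None)

// Pair every example with the single learned weight vector, then average the loss
val meanLoss: DataSet[Double] = trainingData
  .crossWithTiny(weights) { (example, w) => lossFunction.loss(example, w) }
  .map(loss => (loss, 1L))
  .reduce((a, b) => (a._1 + b._1, a._2 + b._2))
  .map(sumAndCount => sumAndCount._1 / sumAndCount._2)

// collect() triggers execution and returns the result to the client
println(s"Final training loss: ${meanLoss.collect().head}")
```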
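To follow the loss over time rather than only at the end, one option is to run the solver repeatedly with a small `Iterations` value, pass the returned weights back in as `initialWeights`, and record the loss after each run. A small helper like the hypothetical `LossHistory` below (plain Scala, not hooked into the solver) can then flag stalls or divergence.

```scala
// Hypothetical helper for recording loss values between solver runs.
final class LossHistory {
  private val buf = scala.collection.mutable.ArrayBuffer.empty[Double]

  def record(loss: Double): Unit = buf += loss
  def values: Seq[Double] = buf.toList

  // Latest loss above the first one: often a sign the step size is too large
  def diverging: Boolean = buf.length >= 2 && buf.last > buf.head

  // Relative change over the last step is below `threshold`
  def stalled(threshold: Double): Boolean =
    buf.length >= 2 && {
      val prev = buf(buf.length - 2)
      if (prev == 0.0) buf.last == 0.0
      else math.abs(prev - buf.last) / math.abs(prev) < threshold
    }
}
```

Each extra run launches a separate Flink job, so this approach trades speed for visibility and suits tuning experiments rather than production training.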

## Best Practices

1. **Learning Rate Tuning**: Start with the default rate and adjust based on convergence behavior: lower it if the loss oscillates or grows, raise it if the loss decreases very slowly
2. **Regularization**: Use L2 regularization to prevent overfitting, L1 for feature selection (it drives uninformative weights to zero)
3. **Convergence Thresholds**: Set the threshold to match your precision requirements; a much smaller threshold mostly adds iterations
4. **Data Preprocessing**: Normalize features to similar scales; gradient descent converges slowly when features differ by orders of magnitude
5. **Monitoring**: Track loss values to ensure proper convergence

```scala
// Good practice example
val optimizer = GradientDescent()
  .setLossFunction(GenericLossFunction(new SquaredLoss(), new LinearPrediction())) // Required: no default
  .setStepsize(0.01) // Conservative learning rate
  .setLearningRateMethod(LearningRateMethod.Inverse) // Step size decays over the iterations
  .setRegularizationPenalty(L2Regularization(0.01)) // Prevent overfitting
  .setConvergenceThreshold(1e-8) // Reasonable precision
  .setIterations(1000) // Sufficient iterations

val optimizedWeights = optimizer.optimize(normalizedTrainingData, None)
```
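The simplest form of feature normalization (practice 4) is z-score standardization: subtract each feature's mean and divide by its standard deviation. The sketch works on in-memory rows; on a Flink `DataSet` the statistics would be computed with an aggregation instead, and FlinkML also ships a `StandardScaler` preprocessing transformer for this purpose. `Standardize` and `stats` are hypothetical names.

```scala
// Plain-Scala z-score standardization of in-memory feature rows.
object Standardize {
  // Per-feature (mean, population standard deviation) over all rows
  def stats(rows: Seq[Array[Double]]): Array[(Double, Double)] = {
    val n = rows.size.toDouble
    val dim = rows.head.length
    Array.tabulate(dim) { j =>
      val mean = rows.map(r => r(j)).sum / n
      val variance = rows.map(r => (r(j) - mean) * (r(j) - mean)).sum / n
      (mean, math.sqrt(variance))
    }
  }

  // z-score each feature; constant features (std == 0) are only centered
  def apply(rows: Seq[Array[Double]]): Seq[Array[Double]] = {
    val s = stats(rows)
    rows.map(r => r.indices.map { j =>
      val (mean, std) = s(j)
      if (std == 0.0) r(j) - mean else (r(j) - mean) / std
    }.toArray)
  }
}
```

Compute the statistics on the training data only and reuse them for test data, so both are transformed identically.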