# Optimization Framework

Apache Flink ML provides a flexible optimization framework for training machine learning models. The framework includes gradient descent solvers, pluggable loss functions, and regularization components that can be combined to create custom optimization strategies.

## Core Optimization Components

### Solver

Base trait for optimization algorithms.

```scala { .api }
trait Solver extends WithParameters {
  // Base solver interface
}

object Solver {
  // Parameters common to all solvers
  case object LossFunction extends Parameter[LossFunction] {
    val defaultValue = None
  }

  case object RegularizationPenaltyValue extends Parameter[RegularizationPenalty] {
    val defaultValue = Some(NoRegularization())
  }

  case object RegularizationConstant extends Parameter[Double] {
    val defaultValue = Some(0.0)
  }
}
```

### Iterative Solver

Base trait for iterative optimization algorithms that extends `Solver`.

```scala { .api }
trait IterativeSolver extends Solver {
  // Additional parameters for iterative methods
}

object IterativeSolver {
  case object Iterations extends Parameter[Int] {
    val defaultValue = Some(10)
  }

  case object LearningRate extends Parameter[Double] {
    val defaultValue = Some(0.1)
  }

  case object ConvergenceThreshold extends Parameter[Double] {
    val defaultValue = Some(1e-6)
  }

  case object LearningRateMethodValue extends Parameter[LearningRateMethod] {
    val defaultValue = Some(LearningRateMethod.Default)
  }
}
```
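
To make the roles of `Iterations`, `LearningRate`, and `ConvergenceThreshold` concrete, here is a minimal single-machine sketch in plain Scala with no Flink dependency. The object name, the toy objective `(w - 3)^2`, and the stopping rule (absolute change in loss below the threshold) are assumptions for illustration; the solver's actual convergence criterion may differ.

```scala
object IterativeSolverSketch {
  /** Minimises the toy objective f(w) = (w - 3)^2 with plain gradient descent.
    * Returns the final weight and the number of iterations actually run. */
  def minimise(iterations: Int, learningRate: Double, convergenceThreshold: Double): (Double, Int) = {
    def loss(w: Double): Double = (w - 3.0) * (w - 3.0)
    def gradient(w: Double): Double = 2.0 * (w - 3.0)

    var w = 0.0
    var previousLoss = loss(w)
    var i = 0
    var converged = false
    // Stop at the iteration cap or as soon as the loss stops changing
    while (i < iterations && !converged) {
      w -= learningRate * gradient(w)
      val currentLoss = loss(w)
      // Assumed stopping rule: absolute change in loss below the threshold
      converged = math.abs(previousLoss - currentLoss) < convergenceThreshold
      previousLoss = currentLoss
      i += 1
    }
    (w, i)
  }
}
```

With the default values listed above (10 iterations, rate 0.1, threshold 1e-6) this toy problem hits the iteration cap before the threshold, which is why the later examples raise `Iterations`.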

## Gradient Descent Solver

Stochastic Gradient Descent implementation for distributed optimization.

```scala { .api }
class GradientDescent extends IterativeSolver {
  def optimize(
    data: DataSet[LabeledVector],
    initialWeights: Option[DataSet[WeightVector]] = None
  ): DataSet[WeightVector]
}

object GradientDescent {
  def apply(): GradientDescent
}
```

**Usage Example:**

```scala
import org.apache.flink.api.scala._
import org.apache.flink.ml.common.LabeledVector
import org.apache.flink.ml.optimization._

val trainingData: DataSet[LabeledVector] = //... your training data

// Solver.LossFunction expects a LossFunction, so pair the partial loss
// with a prediction function instead of passing SquaredLoss directly
val lossFunction = GenericLossFunction(new SquaredLoss(), new LinearPrediction())

// Configure gradient descent
val gd = GradientDescent()
  .set(IterativeSolver.Iterations, 100)
  .set(IterativeSolver.LearningRate, 0.01)
  .set(IterativeSolver.ConvergenceThreshold, 1e-8)
  .set(IterativeSolver.LearningRateMethodValue, LearningRateMethod.Inverse)
  .set(Solver.LossFunction, lossFunction)
  .set(Solver.RegularizationConstant, 0.1)

// Run optimization
val optimizedWeights = gd.optimize(trainingData)
```

## Learning Rate Methods

Different strategies for adjusting the learning rate during optimization.

```scala { .api }
sealed trait LearningRateMethod

object LearningRateMethod {
  case object Default extends LearningRateMethod
  case object Inverse extends LearningRateMethod
  case object InverseSquareRoot extends LearningRateMethod
}
```

- **Default**: Constant learning rate
- **Inverse**: Learning rate = initial_rate / iteration
- **InverseSquareRoot**: Learning rate = initial_rate / sqrt(iteration)
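
The three schedules can be sketched as a pure function of the iteration number. This is a standalone illustration in plain Scala: the `Schedule` types mirror the cases listed above but are redefined locally, `effectiveRate` is a hypothetical helper name, and the 1-based iteration counter is an assumption, since the formulas do not say where counting starts.

```scala
object LearningRateSketch {
  sealed trait Schedule
  case object Constant extends Schedule          // plays the role of "Default"
  case object Inverse extends Schedule
  case object InverseSquareRoot extends Schedule

  /** Effective step size at a 1-based iteration number (assumed convention). */
  def effectiveRate(schedule: Schedule, initialRate: Double, iteration: Int): Double = {
    require(iteration >= 1, "iteration is 1-based")
    schedule match {
      case Constant          => initialRate
      case Inverse           => initialRate / iteration
      case InverseSquareRoot => initialRate / math.sqrt(iteration.toDouble)
    }
  }
}
```

At iteration 4 with an initial rate of 0.1, `Inverse` gives 0.025 and `InverseSquareRoot` gives 0.05, so the inverse schedule shrinks steps much faster and usually needs a larger initial rate.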

## Loss Functions

### Base Loss Function Interface

```scala { .api }
trait LossFunction {
  def loss(dataPoint: LabeledVector, weights: WeightVector): Double
  def gradient(dataPoint: LabeledVector, weights: WeightVector): Vector
  def lossGradient(dataPoint: LabeledVector, weights: WeightVector): (Double, Vector)
}
```

### Generic Loss Function

Combines partial loss functions with prediction functions for flexible loss computation.

```scala { .api }
case class GenericLossFunction(
  partialLossFunction: PartialLossFunction,
  predictionFunction: PredictionFunction
) extends LossFunction {
  def loss(dataPoint: LabeledVector, weights: WeightVector): Double
  def gradient(dataPoint: LabeledVector, weights: WeightVector): Vector
  def lossGradient(dataPoint: LabeledVector, weights: WeightVector): (Double, Vector)
}
```
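
One way to read this combination is as the chain rule: the partial loss supplies the derivative with respect to the prediction, and the prediction function supplies the derivative of the prediction with respect to the weights. The sketch below works on plain arrays; `Weights`, `Example`, and the function names are local stand-ins, not the Flink ML types, and squared loss with a linear model is assumed.

```scala
object GenericLossSketch {
  // Stand-ins for the Flink ML types; names are illustrative only
  final case class Weights(weights: Array[Double], intercept: Double)
  final case class Example(label: Double, features: Array[Double])

  // Partial loss: squared error and its derivative w.r.t. the prediction
  def partialLoss(prediction: Double, label: Double): Double = {
    val diff = prediction - label
    0.5 * diff * diff
  }
  def partialDerivative(prediction: Double, label: Double): Double = prediction - label

  // Linear prediction: w . x + b
  def predict(x: Array[Double], w: Weights): Double =
    x.zip(w.weights).map { case (xi, wi) => xi * wi }.sum + w.intercept

  /** Loss and gradient (weights part and intercept part) via the chain rule. */
  def lossGradient(e: Example, w: Weights): (Double, Weights) = {
    val prediction = predict(e.features, w)
    val dLoss = partialDerivative(prediction, e.label)
    // For a linear model: d(prediction)/d(w_i) = x_i and d(prediction)/d(b) = 1
    (partialLoss(prediction, e.label), Weights(e.features.map(_ * dLoss), dLoss))
  }
}
```

Computing the loss and the gradient in one pass reuses the prediction, which is the usual motivation for a combined `lossGradient` method.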

## Partial Loss Functions

Partial loss functions define the loss computation without prediction logic.

```scala { .api }
trait PartialLossFunction {
  def loss(prediction: Double, label: Double): Double
  def derivative(prediction: Double, label: Double): Double
}
```

### Common Partial Loss Functions

**Squared Loss (for regression):**
```scala
class SquaredLoss extends PartialLossFunction {
  def loss(prediction: Double, label: Double): Double = {
    val diff = prediction - label
    0.5 * diff * diff
  }

  def derivative(prediction: Double, label: Double): Double = {
    prediction - label
  }
}
```

**Hinge Loss (for SVM):**
```scala
class HingeLoss extends PartialLossFunction {
  def loss(prediction: Double, label: Double): Double = {
    val margin = label * prediction
    if (margin >= 1.0) 0.0 else 1.0 - margin
  }

  def derivative(prediction: Double, label: Double): Double = {
    if (label * prediction >= 1.0) 0.0 else -label
  }
}
```

**Logistic Loss (for logistic regression):**
```scala
class LogisticLoss extends PartialLossFunction {
  def loss(prediction: Double, label: Double): Double = {
    math.log(1.0 + math.exp(-label * prediction))
  }

  def derivative(prediction: Double, label: Double): Double = {
    val exp = math.exp(-label * prediction)
    -label * exp / (1.0 + exp)
  }
}
```
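
When writing a partial loss of your own, a quick sanity check is to compare `derivative` against a central finite difference of `loss`. The following standalone sketch does this for the squared and logistic formulas above, re-expressed as plain functions; the object name, the step size `h`, and the tolerance are assumptions chosen for illustration. Hinge loss is left out because it is not differentiable where the margin equals 1.

```scala
object DerivativeCheckSketch {
  type Fn = (Double, Double) => Double // (prediction, label) => value

  val squaredLoss: Fn = (p, y) => 0.5 * (p - y) * (p - y)
  val squaredDerivative: Fn = (p, y) => p - y

  val logisticLoss: Fn = (p, y) => math.log(1.0 + math.exp(-y * p))
  val logisticDerivative: Fn = (p, y) => {
    val e = math.exp(-y * p)
    -y * e / (1.0 + e)
  }

  /** Central finite difference of `loss` with respect to the prediction. */
  def numericDerivative(loss: Fn, prediction: Double, label: Double, h: Double = 1e-6): Double =
    (loss(prediction + h, label) - loss(prediction - h, label)) / (2.0 * h)

  /** True when the analytic derivative matches the numeric one within 1e-5. */
  def agrees(loss: Fn, derivative: Fn, prediction: Double, label: Double): Boolean =
    math.abs(numericDerivative(loss, prediction, label) - derivative(prediction, label)) < 1e-5
}
```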

## Prediction Functions

Prediction functions define how to compute predictions from features and weights.

```scala { .api }
trait PredictionFunction {
  def predict(features: Vector, weights: WeightVector): Double
  def gradient(features: Vector, weights: WeightVector): Vector
}
```

### Linear Prediction Function

Standard linear prediction for linear models.

```scala
class LinearPrediction extends PredictionFunction {
  def predict(features: Vector, weights: WeightVector): Double = {
    features.dot(weights.weights) + weights.intercept
  }

  def gradient(features: Vector, weights: WeightVector): Vector = {
    features // Gradient w.r.t. weights is just the features
  }
}
```

## Regularization

Regularization penalties help prevent overfitting by adding penalty terms to the loss function.

```scala { .api }
trait RegularizationPenalty {
  def penalty(weights: WeightVector): Double
  def gradient(weights: WeightVector): Vector
}
```

### No Regularization

```scala { .api }
case class NoRegularization() extends RegularizationPenalty {
  def penalty(weights: WeightVector): Double = 0.0
  def gradient(weights: WeightVector): Vector = DenseVector.zeros(weights.weights.size)
}
```

### L1 Regularization (Lasso)

```scala
case class L1Regularization(lambda: Double = 0.1) extends RegularizationPenalty {
  def penalty(weights: WeightVector): Double = {
    lambda * weights.weights.toArray.map(math.abs).sum
  }

  def gradient(weights: WeightVector): Vector = {
    val gradArray = weights.weights.toArray.map(w => if (w > 0) lambda else if (w < 0) -lambda else 0.0)
    DenseVector(gradArray)
  }
}
```

### L2 Regularization (Ridge)

```scala
import org.apache.flink.ml.math.BLAS

case class L2Regularization(lambda: Double = 0.1) extends RegularizationPenalty {
  def penalty(weights: WeightVector): Double = {
    0.5 * lambda * weights.weights.dot(weights.weights)
  }

  def gradient(weights: WeightVector): Vector = {
    // BLAS.scal scales its argument in place and returns Unit,
    // so scale a copy and return it explicitly
    val grad = weights.weights.copy
    BLAS.scal(lambda, grad)
    grad
  }
}
```
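
To see how a penalty gradient enters a weight update, the sketch below applies a single step `w <- w - rate * (lossGradient + penaltyGradient)` on plain arrays. It is a simplified illustration, not the solver's actual update rule: the object and function names are hypothetical, and L1 is handled with the same subgradient as above (0 at `w = 0`) rather than a proximal step.

```scala
object RegularizedStepSketch {
  /** Subgradient of lambda * sum(|w_i|); math.signum(0.0) is 0.0. */
  def l1Gradient(w: Array[Double], lambda: Double): Array[Double] =
    w.map(wi => lambda * math.signum(wi))

  /** Gradient of 0.5 * lambda * sum(w_i^2). */
  def l2Gradient(w: Array[Double], lambda: Double): Array[Double] =
    w.map(_ * lambda)

  /** One update: w <- w - rate * (lossGradient + penaltyGradient). */
  def step(
      w: Array[Double],
      lossGradient: Array[Double],
      penaltyGradient: Array[Double],
      rate: Double): Array[Double] =
    w.indices.map(i => w(i) - rate * (lossGradient(i) + penaltyGradient(i))).toArray
}
```

With a zero loss gradient, L2 shrinks each weight in proportion to its size, while L1 moves every non-zero weight by the same fixed amount, which is why L1 tends to push small weights to zero.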

## Custom Optimization Example

Building a custom optimization setup:

```scala
import org.apache.flink.ml.optimization._

// Define custom loss function
val customLoss = GenericLossFunction(
  partialLossFunction = new SquaredLoss(),
  predictionFunction = new LinearPrediction()
)

// Configure solver with custom settings
val optimizer = GradientDescent()
  .set(Solver.LossFunction, customLoss)
  .set(Solver.RegularizationPenaltyValue, L2Regularization(0.01))
  .set(Solver.RegularizationConstant, 1.0)
  .set(IterativeSolver.Iterations, 200)
  .set(IterativeSolver.LearningRate, 0.05)
  .set(IterativeSolver.LearningRateMethodValue, LearningRateMethod.Inverse)
  .set(IterativeSolver.ConvergenceThreshold, 1e-10)

// Run optimization (trainingData is a DataSet[LabeledVector] you provide)
val finalWeights = optimizer.optimize(trainingData)
```

## Integration with ML Algorithms

The optimization framework is used internally by ML algorithms. For example, `MultipleLinearRegression` uses gradient descent:

```scala
import org.apache.flink.ml.regression.MultipleLinearRegression

val regression = MultipleLinearRegression()
  .setIterations(100)            // Maps to IterativeSolver.Iterations
  .setStepsize(0.01)             // Maps to IterativeSolver.LearningRate
  .setConvergenceThreshold(1e-8) // Maps to IterativeSolver.ConvergenceThreshold

// Internally uses optimization framework
val model = regression.fit(trainingData)
```

## Distributed Optimization Considerations

The optimization framework is designed for distributed execution on Flink:

### Data Partitioning
```scala
// Partition data for better distributed performance
val partitionedData = trainingData.partitionByHash(_.hashCode())
val weights = optimizer.optimize(partitionedData)
```

### Broadcast Variables
The current model parameters are shipped to the workers as a broadcast variable, so each partition can compute its gradients locally without shuffling the training data:

```scala
// The framework automatically handles broadcast variables for efficient
// distributed gradient computation
val distributedWeights = optimizer.optimize(largeDataset)
```
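
### Checkpointing
`enableCheckpointing` belongs to the DataStream API's `StreamExecutionEnvironment`; batch `DataSet` jobs such as these solvers do not take checkpoints. For long-running optimizations, configure a restart strategy so that a failed job is retried:

```scala
import org.apache.flink.api.common.restartstrategy.RestartStrategies

// Retry a failed job up to 3 times, waiting 10 seconds between attempts
env.setRestartStrategy(RestartStrategies.fixedDelayRestart(3, 10000L))

val robustOptimizer = GradientDescent()
  .set(IterativeSolver.Iterations, 1000) // Many iterations
  // ... other parameters

val weights = robustOptimizer.optimize(trainingData)
```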

## Monitoring Optimization Progress

While the framework doesn't provide built-in progress monitoring, you can track convergence by wrapping the solver:

```scala
// Custom convergence tracking (conceptual)
class ProgressTrackingGradientDescent extends GradientDescent {
  override def optimize(
    data: DataSet[LabeledVector],
    initialWeights: Option[DataSet[WeightVector]] = None
  ): DataSet[WeightVector] = {
    // Add custom logging/monitoring logic
    val result = super.optimize(data, initialWeights)

    // Log final loss. `computeLoss` is a hypothetical helper that you supply;
    // it is not part of the framework and is assumed to return a Double
    val finalLoss = computeLoss(data, result)
    println(s"Final loss: $finalLoss")

    result
  }
}
```

## Best Practices

1. **Learning Rate Tuning**: Start with default rates and adjust based on convergence behavior
2. **Regularization**: Use L2 regularization to prevent overfitting, L1 for feature selection
3. **Convergence Thresholds**: Set appropriate thresholds based on your precision requirements
4. **Data Preprocessing**: Normalize features for better optimization performance
5. **Monitoring**: Track loss values to ensure proper convergence
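
For the preprocessing item, one common form of normalization is standardization: shift each feature to zero mean and scale it to unit standard deviation. The sketch below does this on in-memory rows in plain Scala; the object and function names are hypothetical, the population standard deviation is an arbitrary choice, and a distributed job would compute the statistics over a `DataSet` instead.

```scala
object StandardizeSketch {
  /** Scales each column to zero mean and unit (population) standard deviation. */
  def standardize(rows: Seq[Array[Double]]): Seq[Array[Double]] = {
    require(rows.nonEmpty, "need at least one row")
    val n = rows.size.toDouble
    val dims = rows.head.length
    val means = Array.tabulate(dims)(j => rows.map(_(j)).sum / n)
    val stds = Array.tabulate(dims) { j =>
      val variance = rows.map(r => math.pow(r(j) - means(j), 2)).sum / n
      // Guard against constant columns to avoid dividing by zero
      if (variance > 0.0) math.sqrt(variance) else 1.0
    }
    rows.map(r => Array.tabulate(dims)(j => (r(j) - means(j)) / stds(j)))
  }
}
```

Features on comparable scales let a single learning rate work reasonably for every weight, which is the reason normalization helps gradient descent converge.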

```scala
// Good practice example
val optimizer = GradientDescent()
  .set(Solver.LossFunction, GenericLossFunction(new SquaredLoss(), new LinearPrediction())) // No default loss, so set one
  .set(IterativeSolver.LearningRate, 0.01) // Conservative learning rate
  .set(IterativeSolver.LearningRateMethodValue, LearningRateMethod.Inverse) // Decaying rate
  .set(Solver.RegularizationPenaltyValue, L2Regularization(0.01)) // Prevent overfitting
  .set(IterativeSolver.ConvergenceThreshold, 1e-8) // Reasonable precision
  .set(IterativeSolver.Iterations, 1000) // Sufficient iterations

val optimizedWeights = optimizer.optimize(normalizedTrainingData)
```