Apache Flink ML provides a flexible optimization framework for training machine learning models. The framework includes gradient descent solvers, pluggable loss functions, and regularization components that can be combined to create custom optimization strategies.
Base trait for optimization algorithms.
/** Base trait for optimization algorithms; mixes in `WithParameters` and declares no members here. */
trait Solver extends WithParameters {
  // Root solver interface
}
/** Parameter keys common to all solvers. */
object Solver {

  /** Key for the loss function; it has no default (`None`). */
  case object LossFunction extends Parameter[LossFunction] {
    val defaultValue = None
  }

  /** Key for the regularization penalty; defaults to `NoRegularization()`. */
  case object RegularizationPenaltyValue extends Parameter[RegularizationPenalty] {
    val defaultValue = Some(NoRegularization())
  }

  /** Key for the regularization constant; defaults to `0.0`. */
  case object RegularizationConstant extends Parameter[Double] {
    val defaultValue = Some(0.0)
  }
}
Base trait for iterative optimization algorithms that extends Solver.
/** Base trait for iterative solvers; extends [[Solver]] and declares no members here. */
trait IterativeSolver extends Solver {
  // Extra parameters for iterative methods live in the companion object
}
/** Parameter keys specific to iterative solvers. */
object IterativeSolver {

  /** Key for the iteration count; defaults to `10`. */
  case object Iterations extends Parameter[Int] {
    val defaultValue = Some(10)
  }

  /** Key for the learning rate; defaults to `0.1`. */
  case object LearningRate extends Parameter[Double] {
    val defaultValue = Some(0.1)
  }

  /** Key for the convergence threshold; defaults to `1e-6`. */
  case object ConvergenceThreshold extends Parameter[Double] {
    val defaultValue = Some(1e-6)
  }

  /** Key for the learning-rate strategy; defaults to `LearningRateMethod.Default`. */
  case object LearningRateMethodValue extends Parameter[LearningRateMethod] {
    val defaultValue = Some(LearningRateMethod.Default)
  }
}
Stochastic Gradient Descent implementation for distributed optimization.
/**
 * Gradient descent solver built on [[IterativeSolver]].
 *
 * Only the signature is listed here; no method body is shown.
 */
class GradientDescent extends IterativeSolver {

  /**
   * @param data           labelled training vectors
   * @param initialWeights optional starting weights; defaults to `None`
   * @return the weight vector data set produced by the solver
   */
  def optimize(
      data: DataSet[LabeledVector],
      initialWeights: Option[DataSet[WeightVector]] = None
  ): DataSet[WeightVector]
}
/** Companion exposing the factory for [[GradientDescent]] (signature only). */
object GradientDescent {
  /** @return a `GradientDescent` instance */
  def apply(): GradientDescent
}
Usage Example:
import org.apache.flink.ml.optimization.{
  GenericLossFunction,
  GradientDescent,
  IterativeSolver,
  LinearPrediction,
  Solver,
  SquaredLoss
}
import org.apache.flink.ml.optimization.LearningRateMethod
val trainingData: DataSet[LabeledVector] = ??? // ... your training data
// Configure gradient descent
val gd = GradientDescent()
  .set(IterativeSolver.Iterations, 100)
  .set(IterativeSolver.LearningRate, 0.01)
  .set(IterativeSolver.ConvergenceThreshold, 1e-8)
  .set(IterativeSolver.LearningRateMethodValue, LearningRateMethod.Inverse)
  // Solver.LossFunction takes a LossFunction, while SquaredLoss is only a
  // PartialLossFunction, so pair it with a prediction function.
  .set(Solver.LossFunction, GenericLossFunction(new SquaredLoss(), new LinearPrediction()))
  .set(Solver.RegularizationConstant, 0.1)
// Run optimization
val optimizedWeights = gd.optimize(trainingData)
Different strategies for adjusting learning rate during optimization.
/** Closed set of strategies for adjusting the learning rate during optimization. */
sealed trait LearningRateMethod

/** The available [[LearningRateMethod]] strategies. */
object LearningRateMethod {
  case object Default extends LearningRateMethod
  case object Inverse extends LearningRateMethod
  case object InverseSquareRoot extends LearningRateMethod
}

/** Contract for a loss evaluated on one labelled data point under a given weight vector. */
trait LossFunction {
  /** @return the loss for `dataPoint` under `weights` */
  def loss(dataPoint: LabeledVector, weights: WeightVector): Double

  /** @return the gradient for `dataPoint` under `weights` */
  def gradient(dataPoint: LabeledVector, weights: WeightVector): Vector

  /** @return the loss and the gradient as a pair */
  def lossGradient(dataPoint: LabeledVector, weights: WeightVector): (Double, Vector)
}
Combines partial loss functions with prediction functions for flexible loss computation.
/**
 * A [[LossFunction]] assembled from a partial loss and a prediction function.
 *
 * Only the signatures are listed here; no method bodies are shown.
 *
 * @param partialLossFunction loss defined on a (prediction, label) pair
 * @param predictionFunction  maps features and weights to a prediction
 */
case class GenericLossFunction(
    partialLossFunction: PartialLossFunction,
    predictionFunction: PredictionFunction
) extends LossFunction {
  /** @return the loss for `dataPoint` under `weights` */
  def loss(dataPoint: LabeledVector, weights: WeightVector): Double

  /** @return the gradient for `dataPoint` under `weights` */
  def gradient(dataPoint: LabeledVector, weights: WeightVector): Vector

  /** @return the loss and the gradient as a pair */
  def lossGradient(dataPoint: LabeledVector, weights: WeightVector): (Double, Vector)
}
Partial loss functions define the loss computation without prediction logic.
/** Loss defined directly on a (prediction, label) pair, independent of how the prediction is made. */
trait PartialLossFunction {
  /** @return the loss incurred by `prediction` against `label` */
  def loss(prediction: Double, label: Double): Double

  /** @return the derivative of the loss for `prediction` against `label` */
  def derivative(prediction: Double, label: Double): Double
}
Squared Loss (for regression):
/** Squared loss: half the squared difference between prediction and label. */
class SquaredLoss extends PartialLossFunction {

  /** @return `0.5 * (prediction - label)^2` */
  def loss(prediction: Double, label: Double): Double = {
    val residual = prediction - label
    0.5 * residual * residual
  }

  /** @return `prediction - label` */
  def derivative(prediction: Double, label: Double): Double =
    prediction - label
}
Hinge Loss (for SVM):
/** Hinge loss on the margin `label * prediction`. */
class HingeLoss extends PartialLossFunction {

  /** @return `0.0` once the margin reaches 1, otherwise `1 - margin` */
  def loss(prediction: Double, label: Double): Double =
    // math.max propagates NaN, matching the comparison-based form.
    math.max(0.0, 1.0 - label * prediction)

  /** @return `0.0` once the margin reaches 1, otherwise `-label` */
  def derivative(prediction: Double, label: Double): Double =
    label * prediction match {
      case margin if margin >= 1.0 => 0.0
      case _                       => -label
    }
}
Logistic Loss (for logistic regression):
/**
 * Logistic loss `log(1 + exp(-label * prediction))`.
 *
 * Both methods are written so that `exp` is never asked to produce a value
 * that overflows into an `Infinity / Infinity` (NaN) result.
 */
class LogisticLoss extends PartialLossFunction {

  /**
   * @param prediction raw model output
   * @param label      target value
   * @return `log(1 + exp(-label * prediction))`
   */
  def loss(prediction: Double, label: Double): Double = {
    val margin = label * prediction
    // log(1 + e^-m) == -m + log(1 + e^m): always exponentiate a non-positive
    // number so exp stays within [0, 1] instead of overflowing to Infinity.
    if (margin >= 0.0) math.log1p(math.exp(-margin))
    else -margin + math.log1p(math.exp(margin))
  }

  /**
   * @param prediction raw model output
   * @param label      target value
   * @return derivative of the loss with respect to `prediction`
   */
  def derivative(prediction: Double, label: Double): Double = {
    // -label * e^-m / (1 + e^-m) == -label / (1 + e^m). The second form avoids
    // Infinity / Infinity (NaN) when the margin is very negative.
    -label / (1.0 + math.exp(label * prediction))
  }
}
Prediction functions define how to compute predictions from features and weights.
/** Contract for turning a feature vector and a weight vector into a prediction. */
trait PredictionFunction {
  /** @return the prediction for `features` under `weights` */
  def predict(features: Vector, weights: WeightVector): Double

  /** @return the gradient of the prediction for `features` under `weights` */
  def gradient(features: Vector, weights: WeightVector): Vector
}
Standard linear prediction for linear models.
/** Linear prediction: dot product of features and weights plus the intercept. */
class LinearPrediction extends PredictionFunction {

  /** @return `features . weights.weights + weights.intercept` */
  def predict(features: Vector, weights: WeightVector): Double = {
    val weightedSum = features.dot(weights.weights)
    weightedSum + weights.intercept
  }

  /** @return `features` unchanged, which is the gradient with respect to the weights */
  def gradient(features: Vector, weights: WeightVector): Vector =
    features
}
Regularization penalties help prevent overfitting by adding penalty terms to the loss function.
/** Contract for a penalty term computed from the weight vector. */
trait RegularizationPenalty {
  /** @return the penalty value for `weights` */
  def penalty(weights: WeightVector): Double

  /** @return the gradient of the penalty for `weights` */
  def gradient(weights: WeightVector): Vector
}

/** Penalty that contributes nothing: zero value and an all-zero gradient. */
case class NoRegularization() extends RegularizationPenalty {
  def penalty(weights: WeightVector): Double = 0.0

  def gradient(weights: WeightVector): Vector =
    DenseVector.zeros(weights.weights.size)
}

/**
 * L1 penalty: `lambda` times the sum of absolute weights.
 *
 * @param lambda penalty strength; defaults to `0.1`
 */
case class L1Regularization(lambda: Double = 0.1) extends RegularizationPenalty {

  def penalty(weights: WeightVector): Double = {
    val absoluteSum = weights.weights.toArray.map(math.abs).sum
    lambda * absoluteSum
  }

  /** @return per-weight `lambda`, `-lambda` or `0.0` according to the weight's sign */
  def gradient(weights: WeightVector): Vector = {
    val signed = weights.weights.toArray.map {
      case w if w > 0 => lambda
      case w if w < 0 => -lambda
      case _          => 0.0
    }
    DenseVector(signed)
  }
}

/**
 * L2 penalty: `0.5 * lambda` times the squared norm of the weights.
 *
 * @param lambda penalty strength; defaults to `0.1`
 */
case class L2Regularization(lambda: Double = 0.1) extends RegularizationPenalty {

  def penalty(weights: WeightVector): Double = {
    val w = weights.weights
    val squaredNorm = w.dot(w)
    0.5 * lambda * squaredNorm
  }

  def gradient(weights: WeightVector): Vector = {
    // Scale a copy so the caller's weights are left untouched.
    // NOTE(review): assumes `scal` returns the scaled Vector; confirm against the Vector API.
    val scaled = weights.weights.copy
    scaled.scal(lambda)
  }
}
Building a custom optimization setup:
import org.apache.flink.ml.optimization._
// Squared loss evaluated on a linear prediction
val customLoss = GenericLossFunction(new SquaredLoss(), new LinearPrediction())
// Solver configured with the custom loss and explicit settings
val optimizer = GradientDescent()
  .set(Solver.LossFunction, customLoss)
  .set(Solver.RegularizationPenaltyValue, L2Regularization(0.01))
  .set(Solver.RegularizationConstant, 1.0)
  .set(IterativeSolver.Iterations, 200)
  .set(IterativeSolver.LearningRate, 0.05)
  .set(IterativeSolver.LearningRateMethodValue, LearningRateMethod.Inverse)
  .set(IterativeSolver.ConvergenceThreshold, 1e-10)
// Run the optimization
val finalWeights = optimizer.optimize(trainingData)
The optimization framework is used internally by ML algorithms. For example, MultipleLinearRegression uses gradient descent:
import org.apache.flink.ml.regression.MultipleLinearRegression
val regression = MultipleLinearRegression()
  .setIterations(100)            // Maps to IterativeSolver.Iterations
  .setStepsize(0.01)             // Maps to IterativeSolver.LearningRate
  .setConvergenceThreshold(1e-8) // Maps to IterativeSolver.ConvergenceThreshold
// The optimization framework is used internally by fit.
// NOTE(review): confirm that `fit` returns a model rather than Unit.
val model = regression.fit(trainingData)
The optimization framework is designed for distributed execution on Flink:
// Hash-partition the data by each element's hash code before optimizing
val partitionedData = trainingData.partitionByHash(point => point.hashCode())
val weights = optimizer.optimize(partitionedData)
Large model parameters can be broadcast to avoid network overhead:
// Broadcast variables for distributed gradient computation are handled
// by the framework automatically
val distributedWeights = optimizer.optimize(largeDataset)
For long-running optimizations, consider checkpointing:
// NOTE(review): confirm that `env` is an environment type that offers enableCheckpointing.
env.enableCheckpointing(10000) // Checkpoint interval of 10000 ms (10 seconds)
val robustOptimizer = GradientDescent()
  .set(IterativeSolver.Iterations, 1000) // Many iterations
// ... other parameters
val weights = robustOptimizer.optimize(trainingData)
While the framework doesn't provide built-in progress monitoring, you can track convergence:
// Custom convergence tracking (conceptual)
/** Gradient descent variant that prints the final loss after delegating to the parent solver. */
class ProgressTrackingGradientDescent extends GradientDescent {

  /**
   * Runs the parent optimization, prints the final loss and returns the parent's result.
   *
   * `computeLoss` is not defined in this snippet and must be supplied by the reader.
   */
  override def optimize(
      data: DataSet[LabeledVector],
      initialWeights: Option[DataSet[WeightVector]] = None
  ): DataSet[WeightVector] = {
    // Custom logging/monitoring logic can be added around this call
    val optimized = super.optimize(data, initialWeights)
    // Report the loss of the optimized weights
    val finalLoss = computeLoss(data, optimized)
    println(s"Final loss: $finalLoss")
    optimized
  }
}
// Good practice example
val optimizer = GradientDescent()
  // Conservative learning rate
  .set(IterativeSolver.LearningRate, 0.01)
  // Adaptive rate
  .set(IterativeSolver.LearningRateMethodValue, LearningRateMethod.Inverse)
  // Prevent overfitting
  .set(Solver.RegularizationPenaltyValue, L2Regularization(0.01))
  // Reasonable precision
  .set(IterativeSolver.ConvergenceThreshold, 1e-8)
  // Sufficient iterations
  .set(IterativeSolver.Iterations, 1000)
val optimizedWeights = optimizer.optimize(normalizedTrainingData)