Optimization Framework

Apache Flink ML provides a flexible optimization framework for training machine learning models. The framework includes gradient descent solvers, pluggable loss functions, and regularization components that can be combined to create custom optimization strategies.

Core Optimization Components

Solver

Base trait for optimization algorithms.

trait Solver extends WithParameters {
  // Base solver interface
}

object Solver {
  // Parameters common to all solvers
  case object LossFunction extends Parameter[LossFunction] {
    val defaultValue = None
  }
  
  case object RegularizationPenaltyValue extends Parameter[RegularizationPenalty] {
    val defaultValue = Some(NoRegularization())
  }
  
  case object RegularizationConstant extends Parameter[Double] {
    val defaultValue = Some(0.0)
  }
}
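
Within a concrete solver implementation, these parameters are read back at optimization time. A minimal sketch, assuming the WithParameters trait exposes a ParameterMap-style apply accessor that falls back to the defaults declared above:

// Illustrative only: read the configured regularization settings inside a solver,
// falling back to the declared defaults when nothing was set explicitly.
def regularizationSettings(solver: Solver): (RegularizationPenalty, Double) = {
  val penalty  = solver.parameters(Solver.RegularizationPenaltyValue)
  val constant = solver.parameters(Solver.RegularizationConstant)
  (penalty, constant)
}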

Iterative Solver

Base trait for iterative optimization algorithms that extends Solver.

trait IterativeSolver extends Solver {
  // Additional parameters for iterative methods
}

object IterativeSolver {
  case object Iterations extends Parameter[Int] {
    val defaultValue = Some(10)
  }
  
  case object LearningRate extends Parameter[Double] {
    val defaultValue = Some(0.1)
  }
  
  case object ConvergenceThreshold extends Parameter[Double] {
    val defaultValue = Some(1e-6)
  }
  
  case object LearningRateMethodValue extends Parameter[LearningRateMethod] {
    val defaultValue = Some(LearningRateMethod.Default)
  }
}

Gradient Descent Solver

Stochastic Gradient Descent implementation for distributed optimization.

class GradientDescent extends IterativeSolver {
  def optimize(
    data: DataSet[LabeledVector],
    initialWeights: Option[DataSet[WeightVector]] = None
  ): DataSet[WeightVector]
}

object GradientDescent {
  def apply(): GradientDescent
}

Usage Example:

import org.apache.flink.api.scala._
import org.apache.flink.ml.common.LabeledVector
import org.apache.flink.ml.optimization._

val trainingData: DataSet[LabeledVector] = //... your training data

// Combine a partial loss with a prediction function to obtain a full LossFunction
val lossFunction = GenericLossFunction(new SquaredLoss(), new LinearPrediction())

// Configure gradient descent
val gd = GradientDescent()
  .set(IterativeSolver.Iterations, 100)
  .set(IterativeSolver.LearningRate, 0.01)
  .set(IterativeSolver.ConvergenceThreshold, 1e-8)
  .set(IterativeSolver.LearningRateMethodValue, LearningRateMethod.Inverse)
  .set(Solver.LossFunction, lossFunction)
  .set(Solver.RegularizationConstant, 0.1)

// Run optimization
val optimizedWeights = gd.optimize(trainingData)
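
The returned DataSet[WeightVector] is evaluated lazily. For small models the learned weights can be pulled back to the client, for example via collect(), which triggers execution (shown here only as an illustrative follow-up):

// Materialize the learned weights locally (suitable for small models only)
val weightVector = optimizedWeights.collect().head
println(s"Weights: ${weightVector.weights}, intercept: ${weightVector.intercept}")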

Learning Rate Methods

Different strategies for adjusting learning rate during optimization.

sealed trait LearningRateMethod

object LearningRateMethod {
  case object Default extends LearningRateMethod
  case object Inverse extends LearningRateMethod
  case object InverseSquareRoot extends LearningRateMethod
}

  • Default: Constant learning rate
  • Inverse: Learning rate = initial_rate / iteration
  • InverseSquareRoot: Learning rate = initial_rate / sqrt(iteration)
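
As a sketch of how these schedules translate into a per-iteration step size (the helper below is purely illustrative and not part of the framework API):

// Illustrative helper: effective step size in a given (1-based) iteration,
// mirroring the schedules listed above.
def effectiveLearningRate(
    method: LearningRateMethod,
    initialRate: Double,
    iteration: Int): Double = method match {
  case LearningRateMethod.Default           => initialRate
  case LearningRateMethod.Inverse           => initialRate / iteration
  case LearningRateMethod.InverseSquareRoot => initialRate / math.sqrt(iteration)
}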

Loss Functions

Base Loss Function Interface

trait LossFunction {
  def loss(dataPoint: LabeledVector, weights: WeightVector): Double
  def gradient(dataPoint: LabeledVector, weights: WeightVector): Vector
  def lossGradient(dataPoint: LabeledVector, weights: WeightVector): (Double, Vector)
}

Generic Loss Function

Combines partial loss functions with prediction functions for flexible loss computation.

case class GenericLossFunction(
  partialLossFunction: PartialLossFunction,
  predictionFunction: PredictionFunction
) extends LossFunction {
  def loss(dataPoint: LabeledVector, weights: WeightVector): Double
  def gradient(dataPoint: LabeledVector, weights: WeightVector): Vector
  def lossGradient(dataPoint: LabeledVector, weights: WeightVector): (Double, Vector)
}
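
For example, a squared loss on a linear model can be assembled and evaluated on a single point as follows (a minimal sketch built from the classes described on this page; LabeledVector, WeightVector and DenseVector come from org.apache.flink.ml.common and org.apache.flink.ml.math):

import org.apache.flink.ml.common.{LabeledVector, WeightVector}
import org.apache.flink.ml.math.DenseVector

// Squared loss of a linear model, evaluated on one labeled point
val squaredLinearLoss = GenericLossFunction(new SquaredLoss(), new LinearPrediction())

val point   = LabeledVector(1.0, DenseVector(0.5, 2.0))
val weights = WeightVector(DenseVector(0.1, 0.2), 0.0)

val (lossValue, lossGrad) = squaredLinearLoss.lossGradient(point, weights)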

Partial Loss Functions

Partial loss functions define the loss computation without prediction logic.

trait PartialLossFunction {
  def loss(prediction: Double, label: Double): Double
  def derivative(prediction: Double, label: Double): Double
}
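
Implementing a custom partial loss only requires these two methods. For instance, a hypothetical absolute-error loss could look like this (illustrative sketch, not part of the framework):

// Hypothetical absolute-error loss: |prediction - label|
class AbsoluteLoss extends PartialLossFunction {
  def loss(prediction: Double, label: Double): Double =
    math.abs(prediction - label)

  // Subgradient of |prediction - label| with respect to the prediction
  def derivative(prediction: Double, label: Double): Double =
    math.signum(prediction - label)
}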

Common Partial Loss Functions

Squared Loss (for regression):

class SquaredLoss extends PartialLossFunction {
  def loss(prediction: Double, label: Double): Double = {
    val diff = prediction - label
    0.5 * diff * diff
  }
  
  def derivative(prediction: Double, label: Double): Double = {
    prediction - label
  }
}

Hinge Loss (for SVM):

class HingeLoss extends PartialLossFunction {
  def loss(prediction: Double, label: Double): Double = {
    val margin = label * prediction
    if (margin >= 1.0) 0.0 else 1.0 - margin
  }
  
  def derivative(prediction: Double, label: Double): Double = {
    if (label * prediction >= 1.0) 0.0 else -label
  }
}

Logistic Loss (for logistic regression):

class LogisticLoss extends PartialLossFunction {
  def loss(prediction: Double, label: Double): Double = {
    math.log(1.0 + math.exp(-label * prediction))
  }
  
  def derivative(prediction: Double, label: Double): Double = {
    val exp = math.exp(-label * prediction)
    -label * exp / (1.0 + exp)
  }
}

Prediction Functions

Prediction functions define how to compute predictions from features and weights.

trait PredictionFunction {
  def predict(features: Vector, weights: WeightVector): Double
  def gradient(features: Vector, weights: WeightVector): Vector
}

Linear Prediction Function

Standard linear prediction for linear models.

class LinearPrediction extends PredictionFunction {
  def predict(features: Vector, weights: WeightVector): Double = {
    features.dot(weights.weights) + weights.intercept
  }
  
  def gradient(features: Vector, weights: WeightVector): Vector = {
    features  // Gradient w.r.t. weights is just the features
  }
}
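
A quick numeric check (a sketch using the DenseVector and WeightVector types from org.apache.flink.ml.math and org.apache.flink.ml.common):

import org.apache.flink.ml.common.WeightVector
import org.apache.flink.ml.math.DenseVector

// predict = 0.5 * 1.0 + (-0.25) * 2.0 + 0.1 = 0.1
val prediction = new LinearPrediction().predict(
  DenseVector(1.0, 2.0),
  WeightVector(DenseVector(0.5, -0.25), 0.1))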

Regularization

Regularization penalties help prevent overfitting by adding penalty terms to the loss function.

trait RegularizationPenalty {
  def penalty(weights: WeightVector): Double
  def gradient(weights: WeightVector): Vector
}

No Regularization

case class NoRegularization() extends RegularizationPenalty {
  def penalty(weights: WeightVector): Double = 0.0
  def gradient(weights: WeightVector): Vector = DenseVector.zeros(weights.weights.size)
}

L1 Regularization (Lasso)

case class L1Regularization(lambda: Double = 0.1) extends RegularizationPenalty {
  def penalty(weights: WeightVector): Double = {
    lambda * weights.weights.toArray.map(math.abs).sum
  }
  
  def gradient(weights: WeightVector): Vector = {
    val gradArray = weights.weights.toArray.map(w => if (w > 0) lambda else if (w < 0) -lambda else 0.0)
    DenseVector(gradArray)
  }
}

L2 Regularization (Ridge)

case class L2Regularization(lambda: Double = 0.1) extends RegularizationPenalty {
  def penalty(weights: WeightVector): Double = {
    0.5 * lambda * weights.weights.dot(weights.weights)
  }
  
  def gradient(weights: WeightVector): Vector = {
    // Gradient of the L2 penalty is lambda * w
    DenseVector(weights.weights.toArray.map(_ * lambda))
  }
}
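
Conceptually, the solver adds the configured penalty, scaled by the regularization constant, to the data loss and its gradient. A sketch of that combination, assuming the toArray/DenseVector conversions used elsewhere on this page (the actual update logic lives inside GradientDescent):

import org.apache.flink.ml.common.{LabeledVector, WeightVector}
import org.apache.flink.ml.math.{DenseVector, Vector}

// Illustrative: regularized loss and gradient for a single data point.
def regularizedLossGradient(
    lossFunction: LossFunction,
    penalty: RegularizationPenalty,
    regConstant: Double,
    dataPoint: LabeledVector,
    weights: WeightVector): (Double, Vector) = {
  val (dataLoss, dataGrad) = lossFunction.lossGradient(dataPoint, weights)
  val totalLoss = dataLoss + regConstant * penalty.penalty(weights)
  // Combine gradients element-wise: dataGrad + regConstant * penaltyGrad
  val penaltyGrad = penalty.gradient(weights)
  val combined = DenseVector(
    dataGrad.toArray.zip(penaltyGrad.toArray).map { case (g, p) => g + regConstant * p })
  (totalLoss, combined)
}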

Custom Optimization Example

Building a custom optimization setup:

import org.apache.flink.ml.optimization._

// Define custom loss function
val customLoss = GenericLossFunction(
  partialLossFunction = new SquaredLoss(),
  predictionFunction = new LinearPrediction()
)

// Configure solver with custom settings
val optimizer = GradientDescent()
  .set(Solver.LossFunction, customLoss)
  .set(Solver.RegularizationPenaltyValue, L2Regularization(0.01))
  .set(Solver.RegularizationConstant, 1.0)
  .set(IterativeSolver.Iterations, 200)
  .set(IterativeSolver.LearningRate, 0.05)
  .set(IterativeSolver.LearningRateMethodValue, LearningRateMethod.Inverse)
  .set(IterativeSolver.ConvergenceThreshold, 1e-10)

// Run optimization
val finalWeights = optimizer.optimize(trainingData)

Integration with ML Algorithms

The optimization framework is used internally by ML algorithms. For example, MultipleLinearRegression uses gradient descent:

import org.apache.flink.ml.regression.MultipleLinearRegression

val regression = MultipleLinearRegression()
  .setIterations(100)           // Maps to IterativeSolver.Iterations
  .setStepsize(0.01)           // Maps to IterativeSolver.LearningRate
  .setConvergenceThreshold(1e-8) // Maps to IterativeSolver.ConvergenceThreshold

// Internally uses optimization framework
val model = regression.fit(trainingData)

Distributed Optimization Considerations

The optimization framework is designed for distributed execution on Flink:

Data Partitioning

// Hash-partition the records so they are spread evenly across parallel workers
val partitionedData = trainingData.partitionByHash(_.hashCode())
val weights = optimizer.optimize(partitionedData)

Broadcast Variables

Large model parameters can be broadcast to avoid network overhead:

// The framework automatically handles broadcast variables for efficient
// distributed gradient computation
val distributedWeights = optimizer.optimize(largeDataset)
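
For reference, the underlying DataSet pattern looks roughly like this: the current weights are attached as a broadcast set and read inside a rich function on each worker (an illustrative sketch, not the framework's internal code; trainingData, currentWeights and lossFunction are assumed to be in scope):

import org.apache.flink.api.common.functions.RichMapFunction
import org.apache.flink.api.scala._
import org.apache.flink.configuration.Configuration

// Broadcast the current weights to every worker and compute per-point gradients there.
// currentWeights: DataSet[WeightVector] and lossFunction: LossFunction are assumed here.
val gradients = trainingData
  .map(new RichMapFunction[LabeledVector, Vector] {
    private var weights: WeightVector = _

    override def open(parameters: Configuration): Unit = {
      weights = getRuntimeContext
        .getBroadcastVariable[WeightVector]("weights").get(0)
    }

    override def map(point: LabeledVector): Vector =
      lossFunction.gradient(point, weights)
  })
  .withBroadcastSet(currentWeights, "weights")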

Checkpointing

For long-running optimizations, note that env.enableCheckpointing belongs to the streaming API (StreamExecutionEnvironment); batch DataSet programs such as this optimizer rely on Flink's batch recovery instead. Keep the iteration count bounded and let the convergence threshold stop the job early:

val robustOptimizer = GradientDescent()
  .set(IterativeSolver.Iterations, 1000)  // Many iterations
  // ... other parameters

val weights = robustOptimizer.optimize(trainingData)

Monitoring Optimization Progress

While the framework doesn't provide built-in progress monitoring, you can track convergence:

// Custom convergence tracking (conceptual)
class ProgressTrackingGradientDescent extends GradientDescent {
  override def optimize(
    data: DataSet[LabeledVector],
    initialWeights: Option[DataSet[WeightVector]] = None
  ): DataSet[WeightVector] = {
    // Add custom logging/monitoring logic
    val result = super.optimize(data, initialWeights)
    
    // Log final loss (computeLoss is a user-defined helper, e.g. evaluate the
    // configured loss function on each point, sum the results, and collect)
    val finalLoss = computeLoss(data, result)
    println(s"Final loss: $finalLoss")
    
    result
  }
}

Best Practices

  1. Learning Rate Tuning: Start with default rates and adjust based on convergence behavior
  2. Regularization: Use L2 regularization to prevent overfitting, L1 for feature selection
  3. Convergence Thresholds: Set appropriate thresholds based on your precision requirements
  4. Data Preprocessing: Normalize features for better optimization performance (a StandardScaler sketch follows the example below)
  5. Monitoring: Track loss values to ensure proper convergence

// Good practice example
val optimizer = GradientDescent()
  .set(IterativeSolver.LearningRate, 0.01)                    // Conservative learning rate
  .set(IterativeSolver.LearningRateMethodValue, LearningRateMethod.Inverse) // Adaptive rate
  .set(Solver.RegularizationPenaltyValue, L2Regularization(0.01))           // Prevent overfitting
  .set(IterativeSolver.ConvergenceThreshold, 1e-8)           // Reasonable precision
  .set(IterativeSolver.Iterations, 1000)                     // Sufficient iterations

val optimizedWeights = optimizer.optimize(normalizedTrainingData)
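
Following best practice 4, the normalizedTrainingData used above can be produced with Flink ML's StandardScaler before running the optimizer (a minimal sketch; fit and transform on a DataSet[LabeledVector] standardize the feature vectors):

import org.apache.flink.ml.preprocessing.StandardScaler

// Standardize features to zero mean and unit variance before optimization
val scaler = StandardScaler()
scaler.fit(trainingData)
val normalizedTrainingData: DataSet[LabeledVector] = scaler.transform(trainingData)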