Apache Flink ML provides a pipeline framework for building complex machine learning workflows by chaining transformers and predictors. The design follows scikit-learn, built around the Estimator, Transformer, and Predictor abstractions.
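For orientation, here is a minimal sketch of how the abstractions fit together; `trainingData` and `testData` stand for previously prepared `DataSet`s, and the full imports and runnable variants appear in the examples below:

```scala
val scaler = StandardScaler()              // Transformer: reshapes features
val svm = SVM()                            // Predictor: supervised learner
val pipeline = scaler.chainPredictor(svm)  // both are Estimators and can be fitted

pipeline.fit(trainingData)                 // fits every stage in order
val predictions = pipeline.predict(testData)
```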
Base trait for components that can be fitted to data to learn parameters.
```scala
trait Estimator[Self] extends WithParameters {
  def fit[Training](
      training: DataSet[Training],
      fitParameters: ParameterMap = ParameterMap.Empty
  )(implicit fitOperation: FitOperation[Self, Training]): Unit
}
```

All machine learning algorithms in Flink ML extend Estimator to provide the fit method for training. Note that fit returns Unit: the fitted model is stored inside the estimator instance itself rather than returned.
Extends Estimator to add prediction capabilities for supervised learning algorithms.
```scala
trait Predictor[Self] extends Estimator[Self] {
  def predict[Testing, Prediction](
      testing: DataSet[Testing],
      predictParameters: ParameterMap = ParameterMap.Empty
  )(implicit predictOperation: PredictDataSetOperation[Self, Testing, Prediction]): DataSet[Prediction]

  def evaluate[Testing, Prediction](
      testing: DataSet[Testing],
      evaluateParameters: ParameterMap = ParameterMap.Empty
  )(implicit evaluateOperation: EvaluateDataSetOperation[Self, Testing, Prediction]): DataSet[(Prediction, Prediction)]
}
```

Usage Example:
```scala
import org.apache.flink.ml.classification.SVM

val svm = SVM()
  .setIterations(100)
  .setRegularization(0.01)

// Fit the predictor; the learned model is stored inside `svm`
svm.fit(trainingData)

// Make predictions
val predictions = svm.predict(testData)

// Evaluate model performance: yields (truth, prediction) pairs
val evaluationResults = svm.evaluate(testData)
```
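Since evaluate emits (truth, prediction) pairs, a simple accuracy can be computed directly on the resulting DataSet. A minimal sketch, assuming discrete class labels so that exact equality is meaningful:

```scala
// Count matching (truth, prediction) pairs; count() triggers job execution
val total = evaluationResults.count()
val correct = evaluationResults.filter(pair => pair._1 == pair._2).count()
val accuracy = correct.toDouble / total
println(s"Accuracy: $accuracy")
```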
Base trait for components that transform data, such as feature scalers. Transformers do not learn from labels, but they are themselves estimators: many of them (StandardScaler, for instance) must first be fitted to learn statistics from the data.

```scala
trait Transformer[Self <: Transformer[Self]] extends Estimator[Self] {
  def transform[Input, Output](
      input: DataSet[Input],
      transformParameters: ParameterMap = ParameterMap.Empty
  )(implicit transformOperation: TransformDataSetOperation[Self, Input, Output]): DataSet[Output]

  def chainTransformer[T <: Transformer[T]](transformer: T): ChainedTransformer[Self, T]

  def chainPredictor[P <: Predictor[P]](predictor: P): ChainedPredictor[Self, P]
}
```

Usage Example:
```scala
import org.apache.flink.ml.preprocessing.StandardScaler

val scaler = StandardScaler()
  .setMean(0.0)  // target mean of the scaled features
  .setStd(1.0)   // target standard deviation of the scaled features

// Fit the transformer (learns the mean and std of the training data)
scaler.fit(trainingData)

// Transform data
val scaledData = scaler.transform(trainingData)
```

Combines multiple transformers into a single pipeline component.
```scala
case class ChainedTransformer[L <: Transformer[L], R <: Transformer[R]](
    left: L,
    right: R
) extends Transformer[ChainedTransformer[L, R]]
// fit and transform are provided through implicit operations that
// first apply `left` and then feed its output into `right`
```

Combines transformers with a final predictor for end-to-end ML pipelines.
```scala
case class ChainedPredictor[T <: Transformer[T], P <: Predictor[P]](
    transformer: T,
    predictor: P
) extends Predictor[ChainedPredictor[T, P]]
// fit first fits the transformer and transforms the training data, then
// fits the predictor on the transformed data; predict applies the
// transformer before invoking the predictor
```

Chaining multiple transformers:
```scala
import org.apache.flink.ml.preprocessing.{StandardScaler, MinMaxScaler, PolynomialFeatures}

val scaler1 = MinMaxScaler().setMin(0.0).setMax(1.0)
val polyFeatures = PolynomialFeatures().setDegree(2)
val scaler2 = StandardScaler()

// Chain transformers
val preprocessingPipeline = scaler1
  .chainTransformer(polyFeatures)
  .chainTransformer(scaler2)

// Fit and transform
preprocessingPipeline.fit(trainingData)
val transformedData = preprocessingPipeline.transform(trainingData)
```

Chaining a transformer with a predictor:
```scala
import org.apache.flink.ml.preprocessing.StandardScaler
import org.apache.flink.ml.classification.SVM

val scaler = StandardScaler()
val svm = SVM().setIterations(100).setRegularization(0.01)

// Create end-to-end pipeline
val mlPipeline = scaler.chainPredictor(svm)

// Fit entire pipeline
mlPipeline.fit(trainingData)

// Make predictions (automatically applies scaling, then the SVM)
val predictions = mlPipeline.predict(testData)
```

A more complex, multi-stage pipeline:
```scala
import org.apache.flink.ml.preprocessing.{MinMaxScaler, StandardScaler, PolynomialFeatures}
import org.apache.flink.ml.regression.MultipleLinearRegression

// Multi-stage preprocessing
val minMaxScaler = MinMaxScaler().setMin(-1.0).setMax(1.0)
val polyFeatures = PolynomialFeatures().setDegree(3)
val standardScaler = StandardScaler()

// Regression model
val regression = MultipleLinearRegression()
  .setIterations(200)
  .setStepsize(0.01)

// Build complex pipeline
val complexPipeline = minMaxScaler
  .chainTransformer(polyFeatures)
  .chainTransformer(standardScaler)
  .chainPredictor(regression)

// Train pipeline
complexPipeline.fit(trainingData)

// Use pipeline
val predictions = complexPipeline.predict(testData)
```

The pipeline framework uses type classes to provide flexible operations for different data types: an algorithm supports a given input type exactly when a matching implicit operation is in scope.
```scala
trait FitOperation[Self, Training] {
  def fit(instance: Self, fitParameters: ParameterMap, input: DataSet[Training]): Unit
}

trait TransformOperation[Instance, Model, Input, Output] {
  def transform(instance: Instance, model: Model, input: Input): Output
}

trait TransformDataSetOperation[Instance, Input, Output] {
  def transformDataSet(
      instance: Instance,
      transformParameters: ParameterMap,
      input: DataSet[Input]
  ): DataSet[Output]
}

trait PredictOperation[Instance, Model, Testing, Prediction] {
  def predict(instance: Instance, model: Model, testing: Testing): Prediction
}

trait PredictDataSetOperation[Instance, Testing, Prediction] {
  def predictDataSet(
      instance: Instance,
      predictParameters: ParameterMap,
      testing: DataSet[Testing]
  ): DataSet[Prediction]
}

trait EvaluateDataSetOperation[Instance, Testing, Prediction] {
  def evaluateDataSet(
      instance: Instance,
      evaluateParameters: ParameterMap,
      testing: DataSet[Testing]
  ): DataSet[(Prediction, Prediction)]
}
```
Each component keeps its own ParameterMap; in addition, the pipeline object carries a ParameterMap of its own whose entries are forwarded to every stage when the pipeline is fitted.
```scala
// Configure individual components
val scaler = StandardScaler()
  .setMean(0.0)
  .setStd(1.0)

val svm = SVM()
  .setIterations(50)
  .setRegularization(0.1)

// Create pipeline - each component keeps its own parameters
val pipeline = scaler.chainPredictor(svm)

// The pipeline carries its own ParameterMap...
val pipelineParameters = pipeline.parameters

// ...whose entries are forwarded to every stage at fit time,
// overriding the component-level settings
pipeline.parameters.add(SVM.Iterations, 100)
```
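Alternatively, a ParameterMap can be passed directly to fit or predict; call-time parameters take precedence over instance parameters. A minimal sketch:

```scala
import org.apache.flink.ml.common.ParameterMap

// Call-time parameters override both pipeline- and component-level settings
val overrideParams = new ParameterMap().add(SVM.Iterations, 200)
pipeline.fit(trainingData, overrideParams)
```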
You can create custom transformers and predictors by implementing the respective traits and providing the matching implicit operations.

```scala
import org.apache.flink.api.scala._
import org.apache.flink.ml.common.ParameterMap
import org.apache.flink.ml.math.{DenseVector, Vector}
import org.apache.flink.ml.pipeline.{TransformDataSetOperation, Transformer}

class LogTransformer extends Transformer[LogTransformer]

object LogTransformer {
  def apply(): LogTransformer = new LogTransformer

  // Implicit operation that drives the generic transform method
  implicit val logTransformVectors =
    new TransformDataSetOperation[LogTransformer, Vector, Vector] {
      override def transformDataSet(
          instance: LogTransformer,
          transformParameters: ParameterMap,
          input: DataSet[Vector]): DataSet[Vector] = {
        input.map { vector =>
          // Element-wise log; non-positive entries map to 0.0
          val logData = (0 until vector.size)
            .map(i => if (vector(i) > 0) math.log(vector(i)) else 0.0)
            .toArray
          DenseVector(logData): Vector
        }
      }
    }
}

// Usage
val logTransform = LogTransformer()
val scaler = StandardScaler()
val pipeline = logTransform.chainTransformer(scaler)
```

While the core framework doesn't provide built-in persistence, you can save pipeline parameters and recreate pipelines:
```scala
// Save pipeline configuration
val scaler = StandardScaler().setMean(0.0).setStd(1.0)
val svm = SVM().setIterations(100).setRegularization(0.01)
val pipeline = scaler.chainPredictor(svm)

// Extract the parameter maps for serialization with your own mechanism
val scalerParams = scaler.parameters
val svmParams = svm.parameters

// Recreate the pipeline later from fresh components, supplying the
// saved parameters at fit time (call-time parameters take precedence)
val recreatedPipeline = StandardScaler().chainPredictor(SVM())
recreatedPipeline.fit(trainingData, scalerParams ++ svmParams)
```

Only the configuration is preserved this way; the model itself must be refitted on data. Calling code should also handle pipeline errors gracefully:
```scala
try {
  pipeline.fit(trainingData)
  val predictions = pipeline.predict(testData)
} catch {
  case e: IllegalArgumentException =>
    println(s"Invalid parameters: ${e.getMessage}")
  case e: RuntimeException =>
    println(s"Runtime error in pipeline: ${e.getMessage}")
}
```
```scala
// Good practice: fit on training data only, then apply to unseen test data
val pipeline = StandardScaler()
  .chainTransformer(PolynomialFeatures().setDegree(2))
  .chainPredictor(SVM().setIterations(100))

// Fit on training data only
pipeline.fit(trainingData)

// Apply to test data
val predictions = pipeline.predict(testData)
```