Comprehensive machine learning library for Apache Flink that enables scalable ML pipelines on a distributed stream processing platform.
—
Abstract base classes for implementing custom estimators, transformers, and models. Provides common functionality and integration patterns with parameter management and operator framework integration.
Abstract base class providing common functionality for all pipeline stages.
/**
 * Base implementation for pipeline stages with parameter support.
 * The self-referential type parameter lets setters return the concrete
 * subtype, enabling fluent method chaining on subclasses.
 * @param <S> The concrete pipeline stage type for method chaining
 */
public abstract class PipelineStageBase<S extends PipelineStageBase<S>>
implements WithParams<S>, HasMLEnvironmentId<S>, Cloneable {
/** Parameter store backing all get/set operations on this stage. */
protected Params params;
/** Creates a stage with an empty parameter set. */
public PipelineStageBase();
/** Creates a stage initialized with the given parameters. */
public PipelineStageBase(Params params);
/** Returns all parameters currently set on this stage. */
public Params getParams();
/** Returns a deep copy of this stage, including its parameters. */
public S clone();
/** Utility: resolves the TableEnvironment the given table belongs to. */
public static TableEnvironment tableEnvOf(Table table);
}

Abstract base class for implementing custom estimators with automatic batch/stream handling.
/**
 * Base implementation for estimators.
 * Bridges the Table API entry points (fit on Table) to the internal
 * operator framework (fit on BatchOperator/StreamOperator).
 * @param <E> The concrete estimator type
 * @param <M> The model type produced by this estimator
 */
public abstract class EstimatorBase<E extends EstimatorBase<E, M>,
M extends ModelBase<M>>
extends PipelineStageBase<E>
implements Estimator<E, M> {
/** Creates an estimator with an empty parameter set. */
public EstimatorBase();
/** Creates an estimator initialized with the given parameters. */
public EstimatorBase(Params params);
/** Fits using an explicit table environment and input table. */
public M fit(TableEnvironment tEnv, Table input);
/** Fits using only the input table; the table's own environment is used. */
public M fit(Table input);
/** Batch training hook — subclasses must implement this. */
protected abstract M fit(BatchOperator input);
/** Stream training hook — optional; the default rejects stream input. */
protected M fit(StreamOperator input) {
throw new UnsupportedOperationException("Stream fitting not supported");
}
}

Implementation Example:
import org.apache.flink.ml.pipeline.EstimatorBase;
import org.apache.flink.ml.operator.batch.BatchOperator;
public class LinearRegressionEstimator
        extends EstimatorBase<LinearRegressionEstimator, LinearRegressionModel> {

    /** Name of the column holding feature vectors; defaults to "features". */
    public static final ParamInfo<String> FEATURES_COL = ParamInfoFactory
            .createParamInfo("featuresCol", String.class)
            .setDescription("Features column name")
            .setHasDefaultValue("features")
            .build();

    /** Name of the column holding training labels; defaults to "label". */
    public static final ParamInfo<String> LABEL_COL = ParamInfoFactory
            .createParamInfo("labelCol", String.class)
            .setDescription("Label column name")
            .setHasDefaultValue("label")
            .build();

    /** Regularization strength; defaults to 0.0 (no regularization). */
    public static final ParamInfo<Double> REG_PARAM = ParamInfoFactory
            .createParamInfo("regParam", Double.class)
            .setDescription("Regularization parameter")
            .setHasDefaultValue(0.0)
            .build();

    /** Sets the features column name; returns this estimator for chaining. */
    public LinearRegressionEstimator setFeaturesCol(String featuresCol) {
        return set(FEATURES_COL, featuresCol);
    }

    /** Returns the configured features column name. */
    public String getFeaturesCol() {
        return get(FEATURES_COL);
    }

    /** Sets the label column name; returns this estimator for chaining. */
    public LinearRegressionEstimator setLabelCol(String labelCol) {
        return set(LABEL_COL, labelCol);
    }

    /** Returns the configured label column name. */
    public String getLabelCol() {
        return get(LABEL_COL);
    }

    /** Sets the regularization parameter; returns this estimator for chaining. */
    public LinearRegressionEstimator setRegParam(double regParam) {
        return set(REG_PARAM, regParam);
    }

    /** Returns the configured regularization parameter. */
    public double getRegParam() {
        return get(REG_PARAM);
    }

    /**
     * Trains a linear regression on batch data and packages the result
     * into a {@link LinearRegressionModel}.
     */
    @Override
    protected LinearRegressionModel fit(BatchOperator input) {
        // Read the configured parameters once up front.
        final String features = getFeaturesCol();
        final String label = getLabelCol();
        final double reg = getRegParam();

        // Stage 1: extract the raw feature columns into a single column.
        BatchOperator<?> extracted = input.link(
                new FeatureExtractorBatchOp()
                        .setSelectedCols(features)
                        .setOutputCol("extractedFeatures"));

        // Stage 2: train the regression on the extracted features.
        BatchOperator<?> trained = extracted.link(
                new LinearRegressionTrainBatchOp()
                        .setFeaturesCol("extractedFeatures")
                        .setLabelCol(label)
                        .setRegParam(reg));

        // The model shares this estimator's parameters so column names
        // stay consistent between training and prediction.
        return new LinearRegressionModel(this.getParams())
                .setModelData(trained.getOutput());
    }
}
// Usage
LinearRegressionEstimator estimator = new LinearRegressionEstimator()
.setFeaturesCol("input_features")
.setLabelCol("target")
.setRegParam(0.01);
LinearRegressionModel model = estimator.fit(trainingTable);

Abstract base class for implementing custom transformers with batch and stream support.
/**
 * Base implementation for transformers.
 * Routes Table API calls to the batch or stream operator hook depending
 * on the kind of input data.
 * @param <T> The concrete transformer type
 */
public abstract class TransformerBase<T extends TransformerBase<T>>
extends PipelineStageBase<T>
implements Transformer<T> {
/** Creates a transformer with an empty parameter set. */
public TransformerBase();
/** Creates a transformer initialized with the given parameters. */
public TransformerBase(Params params);
/** Transforms using an explicit table environment and input table. */
public Table transform(TableEnvironment tEnv, Table input);
/** Transforms using only the input table; its own environment is used. */
public Table transform(Table input);
/** Batch transformation hook — subclasses must implement this. */
protected abstract BatchOperator transform(BatchOperator input);
/** Stream transformation hook — subclasses must implement this. */
protected abstract StreamOperator transform(StreamOperator input);
}

Implementation Example:
import org.apache.flink.ml.pipeline.TransformerBase;
public class StandardScaler extends TransformerBase<StandardScaler> {

    /** Column containing the raw values to scale (required). */
    public static final ParamInfo<String> INPUT_COL = ParamInfoFactory
            .createParamInfo("inputCol", String.class)
            .setDescription("Input column to scale")
            .setRequired()
            .build();

    /** Column that will receive the scaled values (required). */
    public static final ParamInfo<String> OUTPUT_COL = ParamInfoFactory
            .createParamInfo("outputCol", String.class)
            .setDescription("Output column for scaled data")
            .setRequired()
            .build();

    /** Whether to subtract the mean before scaling; defaults to true. */
    public static final ParamInfo<Boolean> WITH_MEAN = ParamInfoFactory
            .createParamInfo("withMean", Boolean.class)
            .setDescription("Center data to zero mean")
            .setHasDefaultValue(true)
            .build();

    /** Whether to divide by the standard deviation; defaults to true. */
    public static final ParamInfo<Boolean> WITH_STD = ParamInfoFactory
            .createParamInfo("withStd", Boolean.class)
            .setDescription("Scale data to unit variance")
            .setHasDefaultValue(true)
            .build();

    /** Sets the input column; returns this transformer for chaining. */
    public StandardScaler setInputCol(String inputCol) {
        return set(INPUT_COL, inputCol);
    }

    /** Returns the configured input column name. */
    public String getInputCol() {
        return get(INPUT_COL);
    }

    /** Sets the output column; returns this transformer for chaining. */
    public StandardScaler setOutputCol(String outputCol) {
        return set(OUTPUT_COL, outputCol);
    }

    /** Returns the configured output column name. */
    public String getOutputCol() {
        return get(OUTPUT_COL);
    }

    /** Enables/disables mean centering; returns this transformer for chaining. */
    public StandardScaler setWithMean(boolean withMean) {
        return set(WITH_MEAN, withMean);
    }

    /** Returns whether mean centering is enabled. */
    public boolean getWithMean() {
        return get(WITH_MEAN);
    }

    /** Enables/disables unit-variance scaling; returns this transformer for chaining. */
    public StandardScaler setWithStd(boolean withStd) {
        return set(WITH_STD, withStd);
    }

    /** Returns whether unit-variance scaling is enabled. */
    public boolean getWithStd() {
        return get(WITH_STD);
    }

    /** Builds the batch scaling pipeline from the current parameters. */
    @Override
    protected BatchOperator transform(BatchOperator input) {
        final String source = getInputCol();
        final String target = getOutputCol();
        return input.link(new StandardScalerBatchOp()
                .setSelectedCol(source)
                .setOutputCol(target)
                .setWithMean(getWithMean())
                .setWithStd(getWithStd()));
    }

    /** Builds the streaming scaling pipeline from the current parameters. */
    @Override
    protected StreamOperator transform(StreamOperator input) {
        final String source = getInputCol();
        final String target = getOutputCol();
        return input.link(new StandardScalerStreamOp()
                .setSelectedCol(source)
                .setOutputCol(target)
                .setWithMean(getWithMean())
                .setWithStd(getWithStd()));
    }
}
// Usage
StandardScaler scaler = new StandardScaler()
.setInputCol("raw_features")
.setOutputCol("scaled_features")
.setWithMean(true)
.setWithStd(true);
// Works with both batch and stream data
Table scaledBatchData = scaler.transform(batchData);
Table scaledStreamData = scaler.transform(streamData);

Abstract base class for implementing trained models with serialization support.
/**
 * Base implementation for trained models.
 * A model is a transformer whose behavior is driven by learned model
 * data, stored as a Table and managed by this class.
 * @param <M> The concrete model type
 */
public abstract class ModelBase<M extends ModelBase<M>>
extends TransformerBase<M>
implements Model<M> {
/** Table holding the learned model data (e.g. coefficients). */
protected Table modelData;
/** Creates a model with an empty parameter set. */
public ModelBase();
/** Creates a model initialized with the given parameters. */
public ModelBase(Params params);
/** Returns the table holding this model's learned data. */
public Table getModelData();
/** Sets the model data table; returns this model for chaining. */
public M setModelData(Table modelData);
/** Returns a copy of this model, including its model data. */
public M clone();
}

Implementation Example:
import org.apache.flink.ml.pipeline.ModelBase;
public class LinearRegressionModel extends ModelBase<LinearRegressionModel> {

    /** Reuses the estimator's definition so both sides agree on the column name. */
    public static final ParamInfo<String> FEATURES_COL = LinearRegressionEstimator.FEATURES_COL;

    /** Column that receives predictions; defaults to "prediction". */
    public static final ParamInfo<String> PREDICTION_COL = ParamInfoFactory
            .createParamInfo("predictionCol", String.class)
            .setDescription("Prediction output column")
            .setHasDefaultValue("prediction")
            .build();

    /** Creates a model with an empty parameter set. */
    public LinearRegressionModel() {
        super();
    }

    /** Creates a model initialized with the given parameters. */
    public LinearRegressionModel(Params params) {
        super(params);
    }

    /** Sets the features column; returns this model for chaining. */
    public LinearRegressionModel setFeaturesCol(String featuresCol) {
        return set(FEATURES_COL, featuresCol);
    }

    /** Returns the configured features column name. */
    public String getFeaturesCol() {
        return get(FEATURES_COL);
    }

    /** Sets the prediction column; returns this model for chaining. */
    public LinearRegressionModel setPredictionCol(String predictionCol) {
        return set(PREDICTION_COL, predictionCol);
    }

    /** Returns the configured prediction column name. */
    public String getPredictionCol() {
        return get(PREDICTION_COL);
    }

    /** Scores batch data using the trained model data. */
    @Override
    protected BatchOperator transform(BatchOperator input) {
        LinearRegressionPredictBatchOp predictor = new LinearRegressionPredictBatchOp()
                .setModelData(this.getModelData())
                .setFeaturesCol(getFeaturesCol())
                .setPredictionCol(getPredictionCol());
        return input.link(predictor);
    }

    /** Scores streaming data using the trained model data. */
    @Override
    protected StreamOperator transform(StreamOperator input) {
        LinearRegressionPredictStreamOp predictor = new LinearRegressionPredictStreamOp()
                .setModelData(this.getModelData())
                .setFeaturesCol(getFeaturesCol())
                .setPredictionCol(getPredictionCol());
        return input.link(predictor);
    }
}
// Usage
LinearRegressionModel model = // ... obtained from estimator
model.setPredictionCol("my_predictions");
Table predictions = model.transform(testData);

The base classes are designed to work together in estimator-model pairs:
// The estimator trains and produces the model
public class MyEstimator extends EstimatorBase<MyEstimator, MyModel> {
// The batch hook is where the actual learning happens.
@Override
protected MyModel fit(BatchOperator input) {
// Training logic goes here (placeholder in this example)
Table modelData = // ... train and extract model data
// Pass the estimator's params so the model is configured consistently.
return new MyModel(this.getParams()).setModelData(modelData);
}
}
// The model applies the trained logic
public class MyModel extends ModelBase<MyModel> {
@Override
protected BatchOperator transform(BatchOperator input) {
// Apply the trained model to batch data using this.getModelData()
return // ... prediction logic
}
@Override
protected StreamOperator transform(StreamOperator input) {
// Apply the trained model to streaming data using this.getModelData()
return // ... streaming prediction logic
}
}

Maintain parameter consistency between estimators and models:
/**
 * Central registry of parameter definitions shared between estimators and
 * their models, guaranteeing both sides use identical names and defaults.
 * Declared final with a private constructor: this is a pure constant
 * holder and must not be instantiated or subclassed.
 */
public final class ParameterDefinitions {

    /** Name of the input features column; defaults to "features". */
    public static final ParamInfo<String> FEATURES_COL = ParamInfoFactory
            .createParamInfo("featuresCol", String.class)
            .setHasDefaultValue("features")
            .build();

    /** Maximum number of training iterations; defaults to 100. */
    public static final ParamInfo<Integer> MAX_ITER = ParamInfoFactory
            .createParamInfo("maxIter", Integer.class)
            .setHasDefaultValue(100)
            .build();

    /** Constant holder — not instantiable. */
    private ParameterDefinitions() {
    }
}
public class MyEstimator extends EstimatorBase<MyEstimator, MyModel> {

    /** Shared definition — kept identical to the model's features column. */
    public static final ParamInfo<String> FEATURES_COL = ParameterDefinitions.FEATURES_COL;

    /** Shared definition for the iteration cap. */
    public static final ParamInfo<Integer> MAX_ITER = ParameterDefinitions.MAX_ITER;

    /** Label column only matters at training time, so it is defined here. */
    public static final ParamInfo<String> LABEL_COL = ParamInfoFactory
            .createParamInfo("labelCol", String.class)
            .setHasDefaultValue("label")
            .build();
}
public class MyModel extends ModelBase<MyModel> {

    /** Shared definition — kept identical to the estimator's features column. */
    public static final ParamInfo<String> FEATURES_COL = ParameterDefinitions.FEATURES_COL;

    /** Prediction column only matters at inference time, so it is defined here. */
    public static final ParamInfo<String> PREDICTION_COL = ParamInfoFactory
            .createParamInfo("predictionCol", String.class)
            .setHasDefaultValue("prediction")
            .build();
}

Implement both batch and stream operations for maximum flexibility:
public class FlexibleTransformer extends TransformerBase<FlexibleTransformer> {

    /** Batch path: tuned for throughput. */
    @Override
    protected BatchOperator transform(BatchOperator input) {
        MyBatchOp batchOp = new MyBatchOp()
                .setParallelism(4)     // batch favors throughput: more parallel workers
                .setBufferSize(1000);  // and larger buffers
        return input.link(batchOp);
    }

    /** Stream path: tuned for low latency. */
    @Override
    protected StreamOperator transform(StreamOperator input) {
        MyStreamOp streamOp = new MyStreamOp()
                .setParallelism(1)   // streaming favors latency: a single worker
                .setBufferSize(10);  // and small buffers
        return input.link(streamOp);
    }
}

Implement robust error handling in base class implementations:
public class RobustEstimator extends EstimatorBase<RobustEstimator, RobustModel> {

    /**
     * Trains the model, validating configuration and input schema first.
     * Any failure is rethrown as a RuntimeException that preserves the
     * original exception as its cause.
     */
    @Override
    protected RobustModel fit(BatchOperator input) {
        try {
            // Fail fast on bad configuration or incompatible input.
            validateParameters();
            validateInputSchema(input.getSchema());

            // Run training and capture the learned model data.
            BatchOperator<?> trainedOp = input.link(new TrainingOp());
            Table learnedData = trainedOp.getOutput();

            return new RobustModel(this.getParams()).setModelData(learnedData);
        } catch (Exception e) {
            throw new RuntimeException("Training failed: " + e.getMessage(), e);
        }
    }

    /** Rejects non-positive iteration counts before training starts. */
    private void validateParameters() {
        if (get(MAX_ITER) <= 0) {
            throw new IllegalArgumentException("maxIter must be positive");
        }
    }

    /** Ensures the configured features column exists in the input schema. */
    private void validateInputSchema(TableSchema schema) {
        String featuresCol = getFeaturesCol();
        String[] fieldNames = schema.getFieldNames();
        if (!Arrays.asList(fieldNames).contains(featuresCol)) {
            throw new IllegalArgumentException("Features column not found: " + featuresCol);
        }
    }
}

This comprehensive base class system provides a solid foundation for implementing custom ML algorithms while maintaining consistency with the broader Flink ML ecosystem.
Install with Tessl CLI
npx tessl i tessl/maven-org-apache-flink--flink-ml-uber-2-11