CtrlK
Blog · Docs · Log in · Get started
Tessl Logo

tessl/maven-org-apache-flink--flink-ml-uber-2-11

Comprehensive machine learning library for Apache Flink that enables scalable ML pipelines on distributed stream processing platform.

Pending
Overview
Eval results
Files

docs/pipeline-base-classes.md

Pipeline Base Classes

Abstract base classes for implementing custom estimators, transformers, and models. Provides common functionality and integration patterns with parameter management and operator framework integration.

Capabilities

PipelineStageBase Class

Abstract base class providing common functionality for all pipeline stages.

/**
 * Base implementation for pipeline stages with parameter support.
 *
 * <p>API reference skeleton: method bodies are intentionally omitted, so this
 * listing is not compilable source.
 *
 * <p>Uses the self-bounded generic {@code S} so that setters inherited from
 * {@code WithParams} can return the concrete stage type for fluent chaining.
 *
 * @param <S> The concrete pipeline stage type for method chaining
 */
public abstract class PipelineStageBase<S extends PipelineStageBase<S>> 
                      implements WithParams<S>, HasMLEnvironmentId<S>, Cloneable {
    
    /** Stage parameters; backs all parameter get/set operations on the stage. */
    protected Params params;
    
    /** Create stage with empty parameters. */
    public PipelineStageBase();
    
    /** Create stage with initial parameters. */
    public PipelineStageBase(Params params);
    
    /** Get all stage parameters. */
    public Params getParams();
    
    /** Create deep copy of stage (covariant override of {@code Object.clone()}). */
    public S clone();
    
    /**
     * Get table environment from table (utility method).
     * Static helper — does not depend on any stage state.
     */
    public static TableEnvironment tableEnvOf(Table table);
}

EstimatorBase Class

Abstract base class for implementing custom estimators with automatic batch/stream handling.

/**
 * Base implementation for estimators.
 *
 * <p>Handles conversion between the Table API and the operator framework:
 * the public {@code fit(Table)} overloads wrap the input and dispatch to the
 * protected batch/stream {@code fit(...)} hooks below.
 *
 * <p>API reference skeleton — bodiless declarations are not compilable.
 *
 * @param <E> The concrete estimator type
 * @param <M> The model type produced by this estimator
 */
public abstract class EstimatorBase<E extends EstimatorBase<E, M>, 
                                   M extends ModelBase<M>> 
                      extends PipelineStageBase<E> 
                      implements Estimator<E, M> {
    
    /** Create estimator with empty parameters. */
    public EstimatorBase();
    
    /** Create estimator with initial parameters. */
    public EstimatorBase(Params params);
    
    /** Fit estimator with table environment and input table. */
    public M fit(TableEnvironment tEnv, Table input);
    
    /** Fit estimator with input table only (uses table's environment). */
    public M fit(Table input);
    
    /** Fit estimator using batch operator (subclasses must implement). */
    protected abstract M fit(BatchOperator input);
    
    /**
     * Fit estimator using stream operator.
     * Optional hook: the default rejects streaming input, so subclasses only
     * override it when online/stream training is actually supported.
     */
    protected M fit(StreamOperator input) {
        throw new UnsupportedOperationException("Stream fitting not supported");
    }
}

Implementation Example:

import org.apache.flink.ml.pipeline.EstimatorBase;
import org.apache.flink.ml.operator.batch.BatchOperator;

/**
 * Example estimator: linear regression.
 *
 * <p>Demonstrates the recommended pattern: declare {@code ParamInfo} constants,
 * expose typed fluent setters/getters, and implement the batch {@code fit} hook.
 */
public class LinearRegressionEstimator 
       extends EstimatorBase<LinearRegressionEstimator, LinearRegressionModel> {
    
    // Parameter definitions
    public static final ParamInfo<String> FEATURES_COL = ParamInfoFactory
        .createParamInfo("featuresCol", String.class)
        .setDescription("Features column name")
        .setHasDefaultValue("features")
        .build();
    
    public static final ParamInfo<String> LABEL_COL = ParamInfoFactory
        .createParamInfo("labelCol", String.class)
        .setDescription("Label column name")
        .setHasDefaultValue("label")
        .build();
    
    public static final ParamInfo<Double> REG_PARAM = ParamInfoFactory
        .createParamInfo("regParam", Double.class)
        .setDescription("Regularization parameter")
        .setHasDefaultValue(0.0)
        .build();
    
    /** Create estimator with default parameter values. */
    public LinearRegressionEstimator() {
        super();
    }
    
    /** Create estimator with initial parameters (mirrors LinearRegressionModel's constructors). */
    public LinearRegressionEstimator(Params params) {
        super(params);
    }
    
    // Convenience methods
    public LinearRegressionEstimator setFeaturesCol(String featuresCol) {
        return set(FEATURES_COL, featuresCol);
    }
    
    public String getFeaturesCol() {
        return get(FEATURES_COL);
    }
    
    public LinearRegressionEstimator setLabelCol(String labelCol) {
        return set(LABEL_COL, labelCol);
    }
    
    public String getLabelCol() {
        return get(LABEL_COL);
    }
    
    public LinearRegressionEstimator setRegParam(double regParam) {
        return set(REG_PARAM, regParam);
    }
    
    public double getRegParam() {
        return get(REG_PARAM);
    }
    
    /**
     * Train on batch input: extract features, run the training op, and wrap
     * the resulting model data in a LinearRegressionModel.
     */
    @Override
    protected LinearRegressionModel fit(BatchOperator input) {
        // Extract parameters
        String featuresCol = getFeaturesCol();
        String labelCol = getLabelCol();
        double regParam = getRegParam();
        
        // Implement training logic using batch operators
        BatchOperator<?> trainedOperator = input
            .link(new FeatureExtractorBatchOp()
                .setSelectedCols(featuresCol)
                .setOutputCol("extractedFeatures"))
            .link(new LinearRegressionTrainBatchOp()
                .setFeaturesCol("extractedFeatures")
                .setLabelCol(labelCol)
                .setRegParam(regParam));
        
        // Extract model data
        Table modelData = trainedOperator.getOutput();
        
        // Pass this estimator's params into the model so the pair stays consistent
        return new LinearRegressionModel(this.getParams())
            .setModelData(modelData);
    }
}

// Usage: configure via fluent setters, then fit on a Table of training data
LinearRegressionEstimator estimator = new LinearRegressionEstimator()
    .setFeaturesCol("input_features")
    .setLabelCol("target")
    .setRegParam(0.01);

// fit(Table) dispatches to the protected batch fit(...) hook
LinearRegressionModel model = estimator.fit(trainingTable);

TransformerBase Class

Abstract base class for implementing custom transformers with batch and stream support.

/**
 * Base implementation for transformers.
 *
 * <p>Provides both batch and stream transformation capabilities: the public
 * {@code transform(Table)} overloads wrap the input and dispatch to the
 * protected batch/stream hooks, both of which subclasses must implement.
 *
 * <p>API reference skeleton — bodiless declarations are not compilable.
 *
 * @param <T> The concrete transformer type
 */
public abstract class TransformerBase<T extends TransformerBase<T>> 
                      extends PipelineStageBase<T> 
                      implements Transformer<T> {
    
    /** Create transformer with empty parameters. */
    public TransformerBase();
    
    /** Create transformer with initial parameters. */
    public TransformerBase(Params params);
    
    /** Transform data with table environment. */
    public Table transform(TableEnvironment tEnv, Table input);
    
    /** Transform data using input table's environment. */
    public Table transform(Table input);
    
    /** Transform data using batch operator (subclasses must implement). */
    protected abstract BatchOperator transform(BatchOperator input);
    
    /** Transform data using stream operator (subclasses must implement). */
    protected abstract StreamOperator transform(StreamOperator input);
}

Implementation Example:

import org.apache.flink.ml.pipeline.TransformerBase;

/**
 * Example transformer: standardizes a numeric column.
 *
 * <p>Implements both the batch and the stream transform hooks, so one
 * configured instance works on bounded and unbounded data alike.
 */
public class StandardScaler extends TransformerBase<StandardScaler> {
    
    // ---- Parameter definitions ----
    public static final ParamInfo<String> INPUT_COL = ParamInfoFactory
        .createParamInfo("inputCol", String.class)
        .setDescription("Input column to scale")
        .setRequired()
        .build();
    
    public static final ParamInfo<String> OUTPUT_COL = ParamInfoFactory
        .createParamInfo("outputCol", String.class)
        .setDescription("Output column for scaled data")
        .setRequired()
        .build();
    
    public static final ParamInfo<Boolean> WITH_MEAN = ParamInfoFactory
        .createParamInfo("withMean", Boolean.class)
        .setDescription("Center data to zero mean")
        .setHasDefaultValue(true)
        .build();
    
    public static final ParamInfo<Boolean> WITH_STD = ParamInfoFactory
        .createParamInfo("withStd", Boolean.class)
        .setDescription("Scale data to unit variance")
        .setHasDefaultValue(true)
        .build();
    
    // ---- Typed accessors ----
    
    /** Fluent setter for the column to be scaled. */
    public StandardScaler setInputCol(String inputCol) {
        set(INPUT_COL, inputCol);
        return this;
    }
    
    public String getInputCol() {
        String col = get(INPUT_COL);
        return col;
    }
    
    /** Fluent setter for the scaled output column. */
    public StandardScaler setOutputCol(String outputCol) {
        set(OUTPUT_COL, outputCol);
        return this;
    }
    
    public String getOutputCol() {
        String col = get(OUTPUT_COL);
        return col;
    }
    
    /** Fluent setter: center data to zero mean. */
    public StandardScaler setWithMean(boolean withMean) {
        set(WITH_MEAN, withMean);
        return this;
    }
    
    public boolean getWithMean() {
        boolean withMean = get(WITH_MEAN);
        return withMean;
    }
    
    /** Fluent setter: scale data to unit variance. */
    public StandardScaler setWithStd(boolean withStd) {
        set(WITH_STD, withStd);
        return this;
    }
    
    public boolean getWithStd() {
        boolean withStd = get(WITH_STD);
        return withStd;
    }
    
    /** Batch path: link a scaler op configured from this stage's parameters. */
    @Override
    protected BatchOperator transform(BatchOperator input) {
        StandardScalerBatchOp scalerOp = new StandardScalerBatchOp()
            .setSelectedCol(getInputCol())
            .setOutputCol(getOutputCol())
            .setWithMean(getWithMean())
            .setWithStd(getWithStd());
        return input.link(scalerOp);
    }
    
    /** Stream path: identical configuration, streaming operator. */
    @Override
    protected StreamOperator transform(StreamOperator input) {
        StandardScalerStreamOp scalerOp = new StandardScalerStreamOp()
            .setSelectedCol(getInputCol())
            .setOutputCol(getOutputCol())
            .setWithMean(getWithMean())
            .setWithStd(getWithStd());
        return input.link(scalerOp);
    }
}

// Usage: one configured instance serves both batch and stream inputs
StandardScaler scaler = new StandardScaler()
    .setInputCol("raw_features")
    .setOutputCol("scaled_features")
    .setWithMean(true)
    .setWithStd(true);

// Works with both batch and stream data — dispatch happens inside transform()
Table scaledBatchData = scaler.transform(batchData);
Table scaledStreamData = scaler.transform(streamData);

ModelBase Class

Abstract base class for implementing trained models with serialization support.

/**
 * Base implementation for trained models.
 *
 * <p>Extends TransformerBase and adds model data management: a model is a
 * transformer whose behavior is driven by a trained model-data {@code Table}.
 *
 * <p>API reference skeleton — bodiless declarations are not compilable.
 *
 * @param <M> The concrete model type
 */
public abstract class ModelBase<M extends ModelBase<M>> 
                      extends TransformerBase<M> 
                      implements Model<M> {
    
    /** Model data table — holds the trained model's state. */
    protected Table modelData;
    
    /** Create model with empty parameters. */
    public ModelBase();
    
    /** Create model with initial parameters. */
    public ModelBase(Params params);
    
    /** Get model data table. */
    public Table getModelData();
    
    /** Set model data table (returns the concrete model type for chaining). */
    public M setModelData(Table modelData);
    
    /** Clone model with model data (covariant override). */
    public M clone();
}

Implementation Example:

import org.apache.flink.ml.pipeline.ModelBase;

/**
 * Example model produced by LinearRegressionEstimator.
 *
 * <p>Reuses the estimator's features-column ParamInfo so the two stages stay
 * parameter-consistent, and scores both batch and stream data.
 */
public class LinearRegressionModel extends ModelBase<LinearRegressionModel> {
    
    // Same parameter definitions as estimator for consistency
    public static final ParamInfo<String> FEATURES_COL = LinearRegressionEstimator.FEATURES_COL;
    public static final ParamInfo<String> PREDICTION_COL = ParamInfoFactory
        .createParamInfo("predictionCol", String.class)
        .setDescription("Prediction output column")
        .setHasDefaultValue("prediction")
        .build();
    
    /** Model with default parameter values. */
    public LinearRegressionModel() {
        super();
    }
    
    /** Model initialized from existing parameters (e.g. the estimator's). */
    public LinearRegressionModel(Params params) {
        super(params);
    }
    
    // ---- Convenience accessors ----
    
    public LinearRegressionModel setFeaturesCol(String featuresCol) {
        set(FEATURES_COL, featuresCol);
        return this;
    }
    
    public String getFeaturesCol() {
        String col = get(FEATURES_COL);
        return col;
    }
    
    public LinearRegressionModel setPredictionCol(String predictionCol) {
        set(PREDICTION_COL, predictionCol);
        return this;
    }
    
    public String getPredictionCol() {
        String col = get(PREDICTION_COL);
        return col;
    }
    
    /** Batch scoring: link a predict op configured from this model's state. */
    @Override
    protected BatchOperator transform(BatchOperator input) {
        LinearRegressionPredictBatchOp predictOp = new LinearRegressionPredictBatchOp()
            .setModelData(this.getModelData())
            .setFeaturesCol(getFeaturesCol())
            .setPredictionCol(getPredictionCol());
        return input.link(predictOp);
    }
    
    /** Stream scoring: identical configuration, streaming operator. */
    @Override
    protected StreamOperator transform(StreamOperator input) {
        LinearRegressionPredictStreamOp predictOp = new LinearRegressionPredictStreamOp()
            .setModelData(this.getModelData())
            .setFeaturesCol(getFeaturesCol())
            .setPredictionCol(getPredictionCol());
        return input.link(predictOp);
    }
}

// Usage
LinearRegressionModel model = // ... obtained from estimator
model.setPredictionCol("my_predictions");

Table predictions = model.transform(testData);

Base Class Integration Patterns

Estimator-Model Pairing

The base classes are designed to work together in estimator-model pairs:

// The estimator trains and produces the model
public class MyEstimator extends EstimatorBase<MyEstimator, MyModel> {
    @Override
    protected MyModel fit(BatchOperator input) {
        // Training logic
        Table modelData = // ... train and extract model data
        
        return new MyModel(this.getParams()).setModelData(modelData);
    }
}

// The model applies the trained logic
public class MyModel extends ModelBase<MyModel> {
    @Override
    protected BatchOperator transform(BatchOperator input) {
        // Apply model using this.getModelData()
        return // ... prediction logic
    }
    
    @Override
    protected StreamOperator transform(StreamOperator input) {
        // Apply model to streaming data
        return // ... streaming prediction logic
    }
}

Parameter Consistency

Maintain parameter consistency between estimators and models:

/**
 * Shared parameter definitions reused by estimator/model pairs.
 *
 * <p>Constants-only holder: the private constructor enforces
 * noninstantiability (Effective Java, Item 4).
 */
public class ParameterDefinitions {
    
    /** Noninstantiable utility holder. */
    private ParameterDefinitions() {
    }
    
    // Shared parameter definitions
    public static final ParamInfo<String> FEATURES_COL = ParamInfoFactory
        .createParamInfo("featuresCol", String.class)
        .setHasDefaultValue("features")
        .build();
    
    public static final ParamInfo<Integer> MAX_ITER = ParamInfoFactory
        .createParamInfo("maxIter", Integer.class)
        .setHasDefaultValue(100)
        .build();
}

public class MyEstimator extends EstimatorBase<MyEstimator, MyModel> {
    // Use shared parameters: aliasing the same ParamInfo instance keeps the
    // parameter name, type, and default identical across estimator and model.
    public static final ParamInfo<String> FEATURES_COL = ParameterDefinitions.FEATURES_COL;
    public static final ParamInfo<Integer> MAX_ITER = ParameterDefinitions.MAX_ITER;
    
    // Estimator-specific parameters (the label column is only needed at training time)
    public static final ParamInfo<String> LABEL_COL = ParamInfoFactory
        .createParamInfo("labelCol", String.class)
        .setHasDefaultValue("label")
        .build();
}

public class MyModel extends ModelBase<MyModel> {
    // Use shared parameters — same ParamInfo instance as the paired estimator
    public static final ParamInfo<String> FEATURES_COL = ParameterDefinitions.FEATURES_COL;
    
    // Model-specific parameters (prediction output is only relevant at scoring time)
    public static final ParamInfo<String> PREDICTION_COL = ParamInfoFactory
        .createParamInfo("predictionCol", String.class)
        .setHasDefaultValue("prediction")
        .build();
}

Batch and Stream Support

Implement both batch and stream operations for maximum flexibility:

/**
 * Demonstrates tuning the batch and stream paths independently while keeping
 * a single transformer interface.
 */
public class FlexibleTransformer extends TransformerBase<FlexibleTransformer> {
    
    /** Batch path, tuned for throughput. */
    @Override
    protected BatchOperator transform(BatchOperator input) {
        MyBatchOp batchOp = new MyBatchOp()
            .setParallelism(4)     // higher parallelism for batch throughput
            .setBufferSize(1000);  // larger buffers amortize per-record cost
        return input.link(batchOp);
    }
    
    /** Stream path, tuned for latency. */
    @Override
    protected StreamOperator transform(StreamOperator input) {
        MyStreamOp streamOp = new MyStreamOp()
            .setParallelism(1)   // lower parallelism for low latency
            .setBufferSize(10);  // smaller buffers for real-time delivery
        return input.link(streamOp);
    }
}

Error Handling

Implement robust error handling in base class implementations:

/**
 * Demonstrates robust error handling in an estimator implementation.
 *
 * <p>Configuration and schema problems surface as IllegalArgumentException
 * before any training work starts; only failures of the training run itself
 * are wrapped.
 */
public class RobustEstimator extends EstimatorBase<RobustEstimator, RobustModel> {
    
    @Override
    protected RobustModel fit(BatchOperator input) {
        // Validate BEFORE the try-block so programming errors (bad params,
        // wrong schema) propagate unwrapped instead of being masked by the
        // generic "Training failed" wrapper.
        validateParameters();
        validateInputSchema(input.getSchema());
        
        try {
            // Perform training
            BatchOperator<?> trained = input.link(new TrainingOp());
            
            Table modelData = trained.getOutput();
            
            return new RobustModel(this.getParams()).setModelData(modelData);
            
        } catch (Exception e) {
            // Keep the original exception as the cause; do not flatten it into
            // the message string — the cause already carries message and trace.
            throw new RuntimeException("Training failed", e);
        }
    }
    
    /** Parameter validation logic: reject non-positive iteration counts. */
    private void validateParameters() {
        // NOTE(review): assumes a MAX_ITER ParamInfo is declared for this
        // estimator (see the shared ParameterDefinitions pattern) — confirm.
        if (get(MAX_ITER) <= 0) {
            throw new IllegalArgumentException("maxIter must be positive");
        }
    }
    
    /** Schema validation logic: the configured features column must exist. */
    private void validateInputSchema(TableSchema schema) {
        // NOTE(review): assumes getFeaturesCol() is provided elsewhere on this
        // class — confirm against the full example.
        String featuresCol = getFeaturesCol();
        if (!Arrays.asList(schema.getFieldNames()).contains(featuresCol)) {
            throw new IllegalArgumentException("Features column not found: " + featuresCol);
        }
    }
}

This comprehensive base class system provides a solid foundation for implementing custom ML algorithms while maintaining consistency with the broader Flink ML ecosystem.

Install with Tessl CLI

npx tessl i tessl/maven-org-apache-flink--flink-ml-uber-2-11

docs

algorithm-operators.md

environment-management.md

index.md

linear-algebra.md

parameter-system.md

pipeline-base-classes.md

pipeline-framework.md

utility-libraries.md

tile.json