# ML Pipeline Framework

An implementation of the Estimator/Transformer pattern for building machine learning workflows, with type-safe parameter management and support for both batch and stream processing. The pipeline framework provides the foundation for creating reusable ML components.

## Capabilities

### PipelineStageBase Abstract Class

Base class for all pipeline stages. It provides common functionality for parameter management, table-environment handling, and cloning.

```java { .api }
/**
12
* Base class for pipeline stages (estimators and transformers)
13
*/
14
public abstract class PipelineStageBase<S extends PipelineStageBase<S>>
15
implements WithParams<S>, HasMLEnvironmentId<S>, Cloneable {
16
17
/** Parameter storage */
18
protected Params params;
19
20
/** Default constructor */
21
public PipelineStageBase();
22
23
/** Constructor with parameters */
24
public PipelineStageBase(Params params);
25
26
/** Get parameters */
27
public Params getParams();
28
29
/** Clone pipeline stage */
30
public S clone();
31
32
/** Get table environment from table */
33
protected static TableEnvironment tableEnvOf(Table table);
34
}
```

### EstimatorBase Abstract Class

Base class for estimators, which train models from data. Subclasses implement the `fit()` operation to produce a trained model.

```java { .api }
/**
43
* Base class for estimator implementations that train models
44
*/
45
public abstract class EstimatorBase<E extends EstimatorBase<E, M>, M extends ModelBase<M>>
46
extends PipelineStageBase<E> implements Estimator<E, M> {
47
48
/** Default constructor */
49
public EstimatorBase();
50
51
/** Constructor with parameters */
52
public EstimatorBase(Params params);
53
54
/**
55
* Fit model with explicit table environment
56
* @param tEnv Table environment to use
57
* @param input Input data table
58
* @return Trained model
59
*/
60
public M fit(TableEnvironment tEnv, Table input);
61
62
/**
63
* Fit model from table (uses table's environment)
64
* @param input Input data table
65
* @return Trained model
66
*/
67
public M fit(Table input);
68
69
/**
70
* Fit model from batch operator (abstract - implemented by subclasses)
71
* @param input Input batch operator
72
* @return Trained model
73
*/
74
protected abstract M fit(BatchOperator input);
75
76
/**
77
* Fit model from stream operator (default throws UnsupportedOperationException)
78
* @param input Input stream operator
79
* @return Trained model
80
*/
81
protected M fit(StreamOperator input);
82
}
```

**Usage Examples:**

```java
import org.apache.flink.ml.pipeline.EstimatorBase;
89
import org.apache.flink.table.api.Table;
90
91
// Example custom estimator implementation
92
public class MyEstimator extends EstimatorBase<MyEstimator, MyModel> {
93
94
public MyEstimator() {
95
super();
96
}
97
98
@Override
99
protected MyModel fit(BatchOperator input) {
100
// Training logic here
101
// Access parameters: getParams().get(PARAM_NAME)
102
103
// Create and return trained model
104
MyModel model = new MyModel();
105
// Set model parameters and data
106
return model;
107
}
108
}
109
110
// Usage
111
MyEstimator estimator = new MyEstimator()
112
.setMLEnvironmentId(envId)
113
.setParam("learningRate", 0.01);
114
115
Table trainingData = getTrainingData();
116
MyModel trainedModel = estimator.fit(trainingData);
```

### TransformerBase Abstract Class

Base class for transformers, which transform data. Subclasses implement the `transform()` operation for both batch and stream input.

```java { .api }
/**
125
* Base class for transformer implementations that transform data
126
*/
127
public abstract class TransformerBase<T extends TransformerBase<T>>
128
extends PipelineStageBase<T> implements Transformer<T> {
129
130
/** Default constructor */
131
public TransformerBase();
132
133
/** Constructor with parameters */
134
public TransformerBase(Params params);
135
136
/**
137
* Transform data with explicit table environment
138
* @param tEnv Table environment to use
139
* @param input Input data table
140
* @return Transformed data table
141
*/
142
public Table transform(TableEnvironment tEnv, Table input);
143
144
/**
145
* Transform table (uses table's environment)
146
* @param input Input data table
147
* @return Transformed data table
148
*/
149
public Table transform(Table input);
150
151
/**
152
* Transform batch data (abstract - implemented by subclasses)
153
* @param input Input batch operator
154
* @return Transformed batch operator
155
*/
156
protected abstract BatchOperator transform(BatchOperator input);
157
158
/**
159
* Transform stream data (abstract - implemented by subclasses)
160
* @param input Input stream operator
161
* @return Transformed stream operator
162
*/
163
protected abstract StreamOperator transform(StreamOperator input);
164
}
```

**Usage Examples:**

```java
import org.apache.flink.ml.pipeline.TransformerBase;
171
import org.apache.flink.table.api.Table;
172
173
// Example custom transformer implementation
174
public class MyTransformer extends TransformerBase<MyTransformer> {
175
176
public MyTransformer() {
177
super();
178
}
179
180
@Override
181
protected BatchOperator transform(BatchOperator input) {
182
// Batch transformation logic
183
return input.select("*"); // Example
184
}
185
186
@Override
187
protected StreamOperator transform(StreamOperator input) {
188
// Stream transformation logic
189
return input.select("*"); // Example
190
}
191
}
192
193
// Usage
194
MyTransformer transformer = new MyTransformer()
195
.setOutputCol("transformed_col");
196
197
Table inputData = getInputData();
198
Table transformedData = transformer.transform(inputData);
```

### ModelBase Abstract Class

Base class for trained machine learning models. A model is a transformer that additionally holds model data (a `Table`), typically produced by an estimator's `fit()`.

```java { .api }
/**
207
* Base class for machine learning models
208
*/
209
public abstract class ModelBase<M extends ModelBase<M>>
210
extends TransformerBase<M> implements Model<M> {
211
212
/** Model data table */
213
protected Table modelData;
214
215
/** Default constructor */
216
public ModelBase();
217
218
/** Constructor with parameters */
219
public ModelBase(Params params);
220
221
/**
222
* Get model data table
223
* @return Model data as Table
224
*/
225
public Table getModelData();
226
227
/**
228
* Set model data table
229
* @param modelData Model data table
230
* @return This model instance for method chaining
231
*/
232
public M setModelData(Table modelData);
233
234
/**
235
* Clone model with model data
236
* @return Cloned model instance
237
*/
238
public M clone();
239
}
```

**Usage Examples:**

```java
import org.apache.flink.ml.pipeline.ModelBase;
246
import org.apache.flink.table.api.Table;
247
248
// Example custom model implementation
249
public class MyModel extends ModelBase<MyModel> {
250
251
public MyModel() {
252
super();
253
}
254
255
@Override
256
protected BatchOperator transform(BatchOperator input) {
257
// Use model data for transformation
258
Table modelData = getModelData();
259
// Apply model transformation logic
260
return input; // Example
261
}
262
263
@Override
264
protected StreamOperator transform(StreamOperator input) {
265
// Stream prediction logic using model data
266
return input; // Example
267
}
268
}
269
270
// Usage
271
MyModel model = new MyModel();
272
273
// Set model data (typically from training)
274
Table modelData = getModelData();
275
model.setModelData(modelData);
276
277
// Use model for prediction
278
Table inputData = getInputData();
279
Table predictions = model.transform(inputData);
280
281
// Clone model for different use cases
282
MyModel clonedModel = model.clone();
```

### Pipeline Construction Patterns

Common patterns for building ML pipelines with the estimator/transformer framework.

**Usage Examples:**

```java
import org.apache.flink.ml.pipeline.EstimatorBase;
293
import org.apache.flink.ml.pipeline.TransformerBase;
294
import org.apache.flink.table.api.Table;
295
296
// Sequential pipeline pattern
297
public class MLPipeline {
298
299
public static Table buildPipeline(Table data) {
300
// 1. Data preprocessing
301
MyPreprocessor preprocessor = new MyPreprocessor()
302
.setInputCol("raw_features")
303
.setOutputCol("processed_features");
304
305
Table preprocessedData = preprocessor.transform(data);
306
307
// 2. Feature extraction
308
MyFeatureExtractor extractor = new MyFeatureExtractor()
309
.setInputCol("processed_features")
310
.setOutputCol("features");
311
312
Table featuresData = extractor.transform(preprocessedData);
313
314
// 3. Model training
315
MyEstimator estimator = new MyEstimator()
316
.setFeaturesCol("features")
317
.setLabelCol("label");
318
319
MyModel model = estimator.fit(featuresData);
320
321
// 4. Prediction
322
Table predictions = model.transform(featuresData);
323
324
return predictions;
325
}
326
}
327
328
// Reusable pipeline components
329
public class PipelineBuilder {
330
private List<TransformerBase> transformers = new ArrayList<>();
331
332
public PipelineBuilder addTransformer(TransformerBase transformer) {
333
transformers.add(transformer);
334
return this;
335
}
336
337
public Table transform(Table input) {
338
Table result = input;
339
for (TransformerBase transformer : transformers) {
340
result = transformer.transform(result);
341
}
342
return result;
343
}
344
}
345
346
// Usage
347
Table result = new PipelineBuilder()
348
.addTransformer(preprocessor)
349
.addTransformer(featureExtractor)
350
.addTransformer(trainedModel)
351
.transform(inputData);
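```

### Parameter Management Integration

Pipeline stages integrate with Flink ML's parameter system for type-safe configuration. Parameters are declared as `ParamInfo` constants and exposed through typed getters and setters, either hand-written or inherited from shared parameter interfaces such as `HasSelectedCols` and `HasOutputCol`.

**Usage Examples:**

```java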
import org.apache.flink.ml.params.shared.colname.HasOutputCol;
362
import org.apache.flink.ml.params.shared.colname.HasSelectedCols;
363
364
// Example estimator with parameter interfaces
365
public class ConfigurableEstimator
366
extends EstimatorBase<ConfigurableEstimator, ConfigurableModel>
367
implements HasSelectedCols<ConfigurableEstimator>, HasOutputCol<ConfigurableEstimator> {
368
369
// Custom parameters
370
public static final ParamInfo<Double> LEARNING_RATE = ParamInfoFactory
371
.createParamInfo("learningRate", Double.class)
372
.setDescription("Learning rate for training")
373
.setHasDefaultValue(0.01)
374
.build();
375
376
public Double getLearningRate() {
377
return get(LEARNING_RATE);
378
}
379
380
public ConfigurableEstimator setLearningRate(Double value) {
381
return set(LEARNING_RATE, value);
382
}
383
384
@Override
385
protected ConfigurableModel fit(BatchOperator input) {
386
// Access parameters
387
String[] selectedCols = getSelectedCols();
388
String outputCol = getOutputCol();
389
Double learningRate = getLearningRate();
390
391
// Training logic using parameters
392
ConfigurableModel model = new ConfigurableModel()
393
.setOutputCol(outputCol);
394
395
return model;
396
}
397
}
398
399
// Usage with parameters
400
ConfigurableEstimator estimator = new ConfigurableEstimator()
401
.setSelectedCols(new String[]{"feature1", "feature2"})
402
.setOutputCol("prediction")
403
.setLearningRate(0.05)
404
.setMLEnvironmentId(envId);
405
406
ConfigurableModel model = estimator.fit(trainingData);
```