0
# Pipeline Base Classes
1
2
Abstract base classes for implementing custom estimators, transformers, and models. They provide common functionality for parameter management and connect custom stages to the batch/stream operator framework.
3
4
## Capabilities
5
6
### PipelineStageBase Class
7
8
Abstract base class providing common functionality for all pipeline stages.
9
10
```java { .api }
11
/**
12
* Base implementation for pipeline stages with parameter support
13
* @param <S> The concrete pipeline stage type for method chaining
14
*/
15
public abstract class PipelineStageBase<S extends PipelineStageBase<S>>
16
implements WithParams<S>, HasMLEnvironmentId<S>, Cloneable {
17
18
/** Stage parameters */
19
protected Params params;
20
21
/** Create stage with empty parameters */
22
public PipelineStageBase();
23
24
/** Create stage with initial parameters */
25
public PipelineStageBase(Params params);
26
27
/** Get all stage parameters */
28
public Params getParams();
29
30
/** Create deep copy of stage */
31
public S clone();
32
33
/** Get table environment from table (utility method) */
34
public static TableEnvironment tableEnvOf(Table table);
35
}
36
```
37
38
### EstimatorBase Class
39
40
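Abstract base class for implementing custom estimators. The public `fit` methods accept a `Table` and convert it for the operator framework. Subclasses must implement `fit(BatchOperator)` and may override `fit(StreamOperator)`, which throws `UnsupportedOperationException` by default.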
41
42
```java { .api }
43
/**
44
* Base implementation for estimators
45
* Handles conversion between Table API and operator framework
46
* @param <E> The concrete estimator type
47
* @param <M> The model type produced by this estimator
48
*/
49
public abstract class EstimatorBase<E extends EstimatorBase<E, M>,
50
M extends ModelBase<M>>
51
extends PipelineStageBase<E>
52
implements Estimator<E, M> {
53
54
/** Create estimator with empty parameters */
55
public EstimatorBase();
56
57
/** Create estimator with initial parameters */
58
public EstimatorBase(Params params);
59
60
/** Fit estimator with table environment and input table */
61
public M fit(TableEnvironment tEnv, Table input);
62
63
/** Fit estimator with input table only (uses table's environment) */
64
public M fit(Table input);
65
66
/** Fit estimator using batch operator (must implement) */
67
protected abstract M fit(BatchOperator input);
68
69
/** Fit estimator using stream operator (optional, throws UnsupportedOperationException by default) */
70
protected M fit(StreamOperator input) {
71
throw new UnsupportedOperationException("Stream fitting not supported");
72
}
73
}
74
```
75
76
**Implementation Example:**
77
78
```java
79
import org.apache.flink.ml.pipeline.EstimatorBase;
80
import org.apache.flink.ml.operator.batch.BatchOperator;
81
82
public class LinearRegressionEstimator
83
extends EstimatorBase<LinearRegressionEstimator, LinearRegressionModel> {
84
85
// Parameter definitions
86
public static final ParamInfo<String> FEATURES_COL = ParamInfoFactory
87
.createParamInfo("featuresCol", String.class)
88
.setDescription("Features column name")
89
.setHasDefaultValue("features")
90
.build();
91
92
public static final ParamInfo<String> LABEL_COL = ParamInfoFactory
93
.createParamInfo("labelCol", String.class)
94
.setDescription("Label column name")
95
.setHasDefaultValue("label")
96
.build();
97
98
public static final ParamInfo<Double> REG_PARAM = ParamInfoFactory
99
.createParamInfo("regParam", Double.class)
100
.setDescription("Regularization parameter")
101
.setHasDefaultValue(0.0)
102
.build();
103
104
// Convenience methods
105
public LinearRegressionEstimator setFeaturesCol(String featuresCol) {
106
return set(FEATURES_COL, featuresCol);
107
}
108
109
public String getFeaturesCol() {
110
return get(FEATURES_COL);
111
}
112
113
public LinearRegressionEstimator setLabelCol(String labelCol) {
114
return set(LABEL_COL, labelCol);
115
}
116
117
public String getLabelCol() {
118
return get(LABEL_COL);
119
}
120
121
public LinearRegressionEstimator setRegParam(double regParam) {
122
return set(REG_PARAM, regParam);
123
}
124
125
public double getRegParam() {
126
return get(REG_PARAM);
127
}
128
129
@Override
130
protected LinearRegressionModel fit(BatchOperator input) {
131
// Extract parameters
132
String featuresCol = getFeaturesCol();
133
String labelCol = getLabelCol();
134
double regParam = getRegParam();
135
136
// Implement training logic using batch operators
137
BatchOperator<?> trainedOperator = input
138
.link(new FeatureExtractorBatchOp()
139
.setSelectedCols(featuresCol)
140
.setOutputCol("extractedFeatures"))
141
.link(new LinearRegressionTrainBatchOp()
142
.setFeaturesCol("extractedFeatures")
143
.setLabelCol(labelCol)
144
.setRegParam(regParam));
145
146
// Extract model data
147
Table modelData = trainedOperator.getOutput();
148
149
// Create and return model
150
return new LinearRegressionModel(this.getParams())
151
.setModelData(modelData);
152
}
153
}
154
155
// Usage
156
LinearRegressionEstimator estimator = new LinearRegressionEstimator()
157
.setFeaturesCol("input_features")
158
.setLabelCol("target")
159
.setRegParam(0.01);
160
161
LinearRegressionModel model = estimator.fit(trainingTable);
162
```
163
164
### TransformerBase Class
165
166
Abstract base class for implementing custom transformers with batch and stream support. Subclasses must implement both the batch and the stream `transform` variants.
167
168
```java { .api }
169
/**
170
* Base implementation for transformers
171
* Provides both batch and stream transformation capabilities
172
* @param <T> The concrete transformer type
173
*/
174
public abstract class TransformerBase<T extends TransformerBase<T>>
175
extends PipelineStageBase<T>
176
implements Transformer<T> {
177
178
/** Create transformer with empty parameters */
179
public TransformerBase();
180
181
/** Create transformer with initial parameters */
182
public TransformerBase(Params params);
183
184
/** Transform data with table environment */
185
public Table transform(TableEnvironment tEnv, Table input);
186
187
/** Transform data using input table's environment */
188
public Table transform(Table input);
189
190
/** Transform data using batch operator (must implement) */
191
protected abstract BatchOperator transform(BatchOperator input);
192
193
/** Transform data using stream operator (must implement) */
194
protected abstract StreamOperator transform(StreamOperator input);
195
}
196
```
197
198
**Implementation Example:**
199
200
```java
201
import org.apache.flink.ml.pipeline.TransformerBase;
202
203
public class StandardScaler extends TransformerBase<StandardScaler> {
204
205
// Parameter definitions
206
public static final ParamInfo<String> INPUT_COL = ParamInfoFactory
207
.createParamInfo("inputCol", String.class)
208
.setDescription("Input column to scale")
209
.setRequired()
210
.build();
211
212
public static final ParamInfo<String> OUTPUT_COL = ParamInfoFactory
213
.createParamInfo("outputCol", String.class)
214
.setDescription("Output column for scaled data")
215
.setRequired()
216
.build();
217
218
public static final ParamInfo<Boolean> WITH_MEAN = ParamInfoFactory
219
.createParamInfo("withMean", Boolean.class)
220
.setDescription("Center data to zero mean")
221
.setHasDefaultValue(true)
222
.build();
223
224
public static final ParamInfo<Boolean> WITH_STD = ParamInfoFactory
225
.createParamInfo("withStd", Boolean.class)
226
.setDescription("Scale data to unit variance")
227
.setHasDefaultValue(true)
228
.build();
229
230
// Convenience methods
231
public StandardScaler setInputCol(String inputCol) {
232
return set(INPUT_COL, inputCol);
233
}
234
235
public String getInputCol() {
236
return get(INPUT_COL);
237
}
238
239
public StandardScaler setOutputCol(String outputCol) {
240
return set(OUTPUT_COL, outputCol);
241
}
242
243
public String getOutputCol() {
244
return get(OUTPUT_COL);
245
}
246
247
public StandardScaler setWithMean(boolean withMean) {
248
return set(WITH_MEAN, withMean);
249
}
250
251
public boolean getWithMean() {
252
return get(WITH_MEAN);
253
}
254
255
public StandardScaler setWithStd(boolean withStd) {
256
return set(WITH_STD, withStd);
257
}
258
259
public boolean getWithStd() {
260
return get(WITH_STD);
261
}
262
263
@Override
264
protected BatchOperator transform(BatchOperator input) {
265
return input.link(new StandardScalerBatchOp()
266
.setSelectedCol(getInputCol())
267
.setOutputCol(getOutputCol())
268
.setWithMean(getWithMean())
269
.setWithStd(getWithStd()));
270
}
271
272
@Override
273
protected StreamOperator transform(StreamOperator input) {
274
return input.link(new StandardScalerStreamOp()
275
.setSelectedCol(getInputCol())
276
.setOutputCol(getOutputCol())
277
.setWithMean(getWithMean())
278
.setWithStd(getWithStd()));
279
}
280
}
281
282
// Usage
283
StandardScaler scaler = new StandardScaler()
284
.setInputCol("raw_features")
285
.setOutputCol("scaled_features")
286
.setWithMean(true)
287
.setWithStd(true);
288
289
// Works with both batch and stream data
290
Table scaledBatchData = scaler.transform(batchData);
291
Table scaledStreamData = scaler.transform(streamData);
292
```
293
294
### ModelBase Class
295
296
Abstract base class for implementing trained models with serialization support.
297
298
```java { .api }
299
/**
300
* Base implementation for trained models
301
* Extends TransformerBase and adds model data management
302
* @param <M> The concrete model type
303
*/
304
public abstract class ModelBase<M extends ModelBase<M>>
305
extends TransformerBase<M>
306
implements Model<M> {
307
308
/** Model data table */
309
protected Table modelData;
310
311
/** Create model with empty parameters */
312
public ModelBase();
313
314
/** Create model with initial parameters */
315
public ModelBase(Params params);
316
317
/** Get model data table */
318
public Table getModelData();
319
320
/** Set model data table */
321
public M setModelData(Table modelData);
322
323
/** Clone model with model data */
324
public M clone();
325
}
326
```
327
328
**Implementation Example:**
329
330
```java
331
import org.apache.flink.ml.pipeline.ModelBase;
332
333
public class LinearRegressionModel extends ModelBase<LinearRegressionModel> {
334
335
// Same parameter definitions as estimator for consistency
336
public static final ParamInfo<String> FEATURES_COL = LinearRegressionEstimator.FEATURES_COL;
337
public static final ParamInfo<String> PREDICTION_COL = ParamInfoFactory
338
.createParamInfo("predictionCol", String.class)
339
.setDescription("Prediction output column")
340
.setHasDefaultValue("prediction")
341
.build();
342
343
public LinearRegressionModel() {
344
super();
345
}
346
347
public LinearRegressionModel(Params params) {
348
super(params);
349
}
350
351
// Convenience methods
352
public LinearRegressionModel setFeaturesCol(String featuresCol) {
353
return set(FEATURES_COL, featuresCol);
354
}
355
356
public String getFeaturesCol() {
357
return get(FEATURES_COL);
358
}
359
360
public LinearRegressionModel setPredictionCol(String predictionCol) {
361
return set(PREDICTION_COL, predictionCol);
362
}
363
364
public String getPredictionCol() {
365
return get(PREDICTION_COL);
366
}
367
368
@Override
369
protected BatchOperator transform(BatchOperator input) {
370
return input.link(new LinearRegressionPredictBatchOp()
371
.setModelData(this.getModelData())
372
.setFeaturesCol(getFeaturesCol())
373
.setPredictionCol(getPredictionCol()));
374
}
375
376
@Override
377
protected StreamOperator transform(StreamOperator input) {
378
return input.link(new LinearRegressionPredictStreamOp()
379
.setModelData(this.getModelData())
380
.setFeaturesCol(getFeaturesCol())
381
.setPredictionCol(getPredictionCol()));
382
}
383
}
384
385
// Usage
386
LinearRegressionModel model = // ... obtained from estimator
387
model.setPredictionCol("my_predictions");
388
389
Table predictions = model.transform(testData);
390
```
391
392
## Base Class Integration Patterns
393
394
### Estimator-Model Pairing
395
396
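The base classes are designed to work together in estimator-model pairs. The estimator typically passes its own parameters to the model it creates (`new MyModel(this.getParams())`). The model class must therefore declare a constructor that accepts `Params`, because Java constructors are not inherited from `ModelBase`: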
397
398
```java
399
// The estimator trains and produces the model
400
public class MyEstimator extends EstimatorBase<MyEstimator, MyModel> {
401
@Override
402
protected MyModel fit(BatchOperator input) {
403
// Training logic
404
Table modelData = // ... train and extract model data
405
406
return new MyModel(this.getParams()).setModelData(modelData);
407
}
408
}
409
410
// The model applies the trained logic
411
public class MyModel extends ModelBase<MyModel> {
412
@Override
413
protected BatchOperator transform(BatchOperator input) {
414
// Apply model using this.getModelData()
415
return // ... prediction logic
416
}
417
418
@Override
419
protected StreamOperator transform(StreamOperator input) {
420
// Apply model to streaming data
421
return // ... streaming prediction logic
422
}
423
}
424
```
425
426
### Parameter Consistency
427
428
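Keep parameters consistent between an estimator and its model by sharing the same `ParamInfo` definitions. Parameters copied from the estimator are then read identically by the model: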
429
430
```java
431
public class ParameterDefinitions {
432
// Shared parameter definitions
433
public static final ParamInfo<String> FEATURES_COL = ParamInfoFactory
434
.createParamInfo("featuresCol", String.class)
435
.setHasDefaultValue("features")
436
.build();
437
438
public static final ParamInfo<Integer> MAX_ITER = ParamInfoFactory
439
.createParamInfo("maxIter", Integer.class)
440
.setHasDefaultValue(100)
441
.build();
442
}
443
444
public class MyEstimator extends EstimatorBase<MyEstimator, MyModel> {
445
// Use shared parameters
446
public static final ParamInfo<String> FEATURES_COL = ParameterDefinitions.FEATURES_COL;
447
public static final ParamInfo<Integer> MAX_ITER = ParameterDefinitions.MAX_ITER;
448
449
// Estimator-specific parameters
450
public static final ParamInfo<String> LABEL_COL = ParamInfoFactory
451
.createParamInfo("labelCol", String.class)
452
.setHasDefaultValue("label")
453
.build();
454
}
455
456
public class MyModel extends ModelBase<MyModel> {
457
// Use shared parameters
458
public static final ParamInfo<String> FEATURES_COL = ParameterDefinitions.FEATURES_COL;
459
460
// Model-specific parameters
461
public static final ParamInfo<String> PREDICTION_COL = ParamInfoFactory
462
.createParamInfo("predictionCol", String.class)
463
.setHasDefaultValue("prediction")
464
.build();
465
}
466
```
467
468
### Batch and Stream Support
469
470
Implement both the batch and the stream variants, tuning each one for its execution mode:
471
472
```java
473
public class FlexibleTransformer extends TransformerBase<FlexibleTransformer> {
474
475
@Override
476
protected BatchOperator transform(BatchOperator input) {
477
// Batch-optimized implementation
478
return input.link(new MyBatchOp()
479
.setParallelism(4) // Higher parallelism for batch
480
.setBufferSize(1000)); // Larger buffers for batch
481
}
482
483
@Override
484
protected StreamOperator transform(StreamOperator input) {
485
// Stream-optimized implementation
486
return input.link(new MyStreamOp()
487
.setParallelism(1) // Lower parallelism for low latency
488
.setBufferSize(10)); // Smaller buffers for real-time
489
}
490
}
491
```
492
493
### Error Handling
494
495
Validate parameters and the input schema before training. When training fails, add context while preserving the original exception as the cause:
496
497
```java
498
public class RobustEstimator extends EstimatorBase<RobustEstimator, RobustModel> {
499
500
@Override
501
protected RobustModel fit(BatchOperator input) {
502
try {
503
// Validate input parameters
504
validateParameters();
505
506
// Check input data schema
507
validateInputSchema(input.getSchema());
508
509
// Perform training
510
BatchOperator<?> trained = input.link(new TrainingOp());
511
512
Table modelData = trained.getOutput();
513
514
return new RobustModel(this.getParams()).setModelData(modelData);
515
516
} catch (Exception e) {
517
throw new RuntimeException("Training failed: " + e.getMessage(), e);
518
}
519
}
520
521
private void validateParameters() {
522
// Parameter validation logic
523
if (get(MAX_ITER) <= 0) {
524
throw new IllegalArgumentException("maxIter must be positive");
525
}
526
}
527
528
private void validateInputSchema(TableSchema schema) {
529
// Schema validation logic
530
String featuresCol = getFeaturesCol();
531
if (!Arrays.asList(schema.getFieldNames()).contains(featuresCol)) {
532
throw new IllegalArgumentException("Features column not found: " + featuresCol);
533
}
534
}
535
}
536
```
537
538
This comprehensive base class system provides a solid foundation for implementing custom ML algorithms while maintaining consistency with the broader Flink ML ecosystem.