or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

algorithm-operators.md · environment-management.md · index.md · linear-algebra.md · parameter-system.md · pipeline-base-classes.md · pipeline-framework.md · utility-libraries.md

docs/pipeline-base-classes.md

0

# Pipeline Base Classes

1

2

Abstract base classes for implementing custom estimators, transformers, and models. They provide common functionality, built-in parameter management, and integration with the operator framework.

3

4

## Capabilities

5

6

### PipelineStageBase Class

7

8

Abstract base class providing common functionality for all pipeline stages.

9

10

```java { .api }

11

/**

12

* Base implementation for pipeline stages with parameter support

13

* @param <S> The concrete pipeline stage type for method chaining

14

*/

15

public abstract class PipelineStageBase<S extends PipelineStageBase<S>>

16

implements WithParams<S>, HasMLEnvironmentId<S>, Cloneable {

17

18

/** Stage parameters */

19

protected Params params;

20

21

/** Create stage with empty parameters */

22

public PipelineStageBase();

23

24

/** Create stage with initial parameters */

25

public PipelineStageBase(Params params);

26

27

/** Get all stage parameters */

28

public Params getParams();

29

30

/** Create deep copy of stage */

31

public S clone();

32

33

/** Get table environment from table (utility method) */

34

public static TableEnvironment tableEnvOf(Table table);

35

}

36

```

37

38

### EstimatorBase Class

39

40

Abstract base class for implementing custom estimators with automatic batch/stream handling.

41

42

```java { .api }

43

/**

44

* Base implementation for estimators

45

* Handles conversion between Table API and operator framework

46

* @param <E> The concrete estimator type

47

* @param <M> The model type produced by this estimator

48

*/

49

public abstract class EstimatorBase<E extends EstimatorBase<E, M>,

50

M extends ModelBase<M>>

51

extends PipelineStageBase<E>

52

implements Estimator<E, M> {

53

54

/** Create estimator with empty parameters */

55

public EstimatorBase();

56

57

/** Create estimator with initial parameters */

58

public EstimatorBase(Params params);

59

60

/** Fit estimator with table environment and input table */

61

public M fit(TableEnvironment tEnv, Table input);

62

63

/** Fit estimator with input table only (uses table's environment) */

64

public M fit(Table input);

65

66

/** Fit estimator using batch operator (must implement) */

67

protected abstract M fit(BatchOperator input);

68

69

/** Fit estimator using stream operator (optional, throws UnsupportedOperationException by default) */

70

protected M fit(StreamOperator input) {

71

throw new UnsupportedOperationException("Stream fitting not supported");

72

}

73

}

74

```

75

76

**Implementation Example:**

77

78

```java

79

import org.apache.flink.ml.pipeline.EstimatorBase;

80

import org.apache.flink.ml.operator.batch.BatchOperator;

81

82

public class LinearRegressionEstimator

83

extends EstimatorBase<LinearRegressionEstimator, LinearRegressionModel> {

84

85

// Parameter definitions

86

public static final ParamInfo<String> FEATURES_COL = ParamInfoFactory

87

.createParamInfo("featuresCol", String.class)

88

.setDescription("Features column name")

89

.setHasDefaultValue("features")

90

.build();

91

92

public static final ParamInfo<String> LABEL_COL = ParamInfoFactory

93

.createParamInfo("labelCol", String.class)

94

.setDescription("Label column name")

95

.setHasDefaultValue("label")

96

.build();

97

98

public static final ParamInfo<Double> REG_PARAM = ParamInfoFactory

99

.createParamInfo("regParam", Double.class)

100

.setDescription("Regularization parameter")

101

.setHasDefaultValue(0.0)

102

.build();

103

104

// Convenience methods

105

public LinearRegressionEstimator setFeaturesCol(String featuresCol) {

106

return set(FEATURES_COL, featuresCol);

107

}

108

109

public String getFeaturesCol() {

110

return get(FEATURES_COL);

111

}

112

113

public LinearRegressionEstimator setLabelCol(String labelCol) {

114

return set(LABEL_COL, labelCol);

115

}

116

117

public String getLabelCol() {

118

return get(LABEL_COL);

119

}

120

121

public LinearRegressionEstimator setRegParam(double regParam) {

122

return set(REG_PARAM, regParam);

123

}

124

125

public double getRegParam() {

126

return get(REG_PARAM);

127

}

128

129

@Override

130

protected LinearRegressionModel fit(BatchOperator input) {

131

// Extract parameters

132

String featuresCol = getFeaturesCol();

133

String labelCol = getLabelCol();

134

double regParam = getRegParam();

135

136

// Implement training logic using batch operators

137

BatchOperator<?> trainedOperator = input

138

.link(new FeatureExtractorBatchOp()

139

.setSelectedCols(featuresCol)

140

.setOutputCol("extractedFeatures"))

141

.link(new LinearRegressionTrainBatchOp()

142

.setFeaturesCol("extractedFeatures")

143

.setLabelCol(labelCol)

144

.setRegParam(regParam));

145

146

// Extract model data

147

Table modelData = trainedOperator.getOutput();

148

149

// Create and return model

150

return new LinearRegressionModel(this.getParams())

151

.setModelData(modelData);

152

}

153

}

154

155

// Usage

156

LinearRegressionEstimator estimator = new LinearRegressionEstimator()

157

.setFeaturesCol("input_features")

158

.setLabelCol("target")

159

.setRegParam(0.01);

160

161

LinearRegressionModel model = estimator.fit(trainingTable);

162

```

163

164

### TransformerBase Class

165

166

Abstract base class for implementing custom transformers with batch and stream support.

167

168

```java { .api }

169

/**

170

* Base implementation for transformers

171

* Provides both batch and stream transformation capabilities

172

* @param <T> The concrete transformer type

173

*/

174

public abstract class TransformerBase<T extends TransformerBase<T>>

175

extends PipelineStageBase<T>

176

implements Transformer<T> {

177

178

/** Create transformer with empty parameters */

179

public TransformerBase();

180

181

/** Create transformer with initial parameters */

182

public TransformerBase(Params params);

183

184

/** Transform data with table environment */

185

public Table transform(TableEnvironment tEnv, Table input);

186

187

/** Transform data using input table's environment */

188

public Table transform(Table input);

189

190

/** Transform data using batch operator (must implement) */

191

protected abstract BatchOperator transform(BatchOperator input);

192

193

/** Transform data using stream operator (must implement) */

194

protected abstract StreamOperator transform(StreamOperator input);

195

}

196

```

197

198

**Implementation Example:**

199

200

```java

201

import org.apache.flink.ml.pipeline.TransformerBase;

202

203

public class StandardScaler extends TransformerBase<StandardScaler> {

204

205

// Parameter definitions

206

public static final ParamInfo<String> INPUT_COL = ParamInfoFactory

207

.createParamInfo("inputCol", String.class)

208

.setDescription("Input column to scale")

209

.setRequired()

210

.build();

211

212

public static final ParamInfo<String> OUTPUT_COL = ParamInfoFactory

213

.createParamInfo("outputCol", String.class)

214

.setDescription("Output column for scaled data")

215

.setRequired()

216

.build();

217

218

public static final ParamInfo<Boolean> WITH_MEAN = ParamInfoFactory

219

.createParamInfo("withMean", Boolean.class)

220

.setDescription("Center data to zero mean")

221

.setHasDefaultValue(true)

222

.build();

223

224

public static final ParamInfo<Boolean> WITH_STD = ParamInfoFactory

225

.createParamInfo("withStd", Boolean.class)

226

.setDescription("Scale data to unit variance")

227

.setHasDefaultValue(true)

228

.build();

229

230

// Convenience methods

231

public StandardScaler setInputCol(String inputCol) {

232

return set(INPUT_COL, inputCol);

233

}

234

235

public String getInputCol() {

236

return get(INPUT_COL);

237

}

238

239

public StandardScaler setOutputCol(String outputCol) {

240

return set(OUTPUT_COL, outputCol);

241

}

242

243

public String getOutputCol() {

244

return get(OUTPUT_COL);

245

}

246

247

public StandardScaler setWithMean(boolean withMean) {

248

return set(WITH_MEAN, withMean);

249

}

250

251

public boolean getWithMean() {

252

return get(WITH_MEAN);

253

}

254

255

public StandardScaler setWithStd(boolean withStd) {

256

return set(WITH_STD, withStd);

257

}

258

259

public boolean getWithStd() {

260

return get(WITH_STD);

261

}

262

263

@Override

264

protected BatchOperator transform(BatchOperator input) {

265

return input.link(new StandardScalerBatchOp()

266

.setSelectedCol(getInputCol())

267

.setOutputCol(getOutputCol())

268

.setWithMean(getWithMean())

269

.setWithStd(getWithStd()));

270

}

271

272

@Override

273

protected StreamOperator transform(StreamOperator input) {

274

return input.link(new StandardScalerStreamOp()

275

.setSelectedCol(getInputCol())

276

.setOutputCol(getOutputCol())

277

.setWithMean(getWithMean())

278

.setWithStd(getWithStd()));

279

}

280

}

281

282

// Usage

283

StandardScaler scaler = new StandardScaler()

284

.setInputCol("raw_features")

285

.setOutputCol("scaled_features")

286

.setWithMean(true)

287

.setWithStd(true);

288

289

// Works with both batch and stream data

290

Table scaledBatchData = scaler.transform(batchData);

291

Table scaledStreamData = scaler.transform(streamData);

292

```

293

294

### ModelBase Class

295

296

Abstract base class for implementing trained models with serialization support.

297

298

```java { .api }

299

/**

300

* Base implementation for trained models

301

* Extends TransformerBase and adds model data management

302

* @param <M> The concrete model type

303

*/

304

public abstract class ModelBase<M extends ModelBase<M>>

305

extends TransformerBase<M>

306

implements Model<M> {

307

308

/** Model data table */

309

protected Table modelData;

310

311

/** Create model with empty parameters */

312

public ModelBase();

313

314

/** Create model with initial parameters */

315

public ModelBase(Params params);

316

317

/** Get model data table */

318

public Table getModelData();

319

320

/** Set model data table */

321

public M setModelData(Table modelData);

322

323

/** Clone model with model data */

324

public M clone();

325

}

326

```

327

328

**Implementation Example:**

329

330

```java

331

import org.apache.flink.ml.pipeline.ModelBase;

332

333

public class LinearRegressionModel extends ModelBase<LinearRegressionModel> {

334

335

// Same parameter definitions as estimator for consistency

336

public static final ParamInfo<String> FEATURES_COL = LinearRegressionEstimator.FEATURES_COL;

337

public static final ParamInfo<String> PREDICTION_COL = ParamInfoFactory

338

.createParamInfo("predictionCol", String.class)

339

.setDescription("Prediction output column")

340

.setHasDefaultValue("prediction")

341

.build();

342

343

public LinearRegressionModel() {

344

super();

345

}

346

347

public LinearRegressionModel(Params params) {

348

super(params);

349

}

350

351

// Convenience methods

352

public LinearRegressionModel setFeaturesCol(String featuresCol) {

353

return set(FEATURES_COL, featuresCol);

354

}

355

356

public String getFeaturesCol() {

357

return get(FEATURES_COL);

358

}

359

360

public LinearRegressionModel setPredictionCol(String predictionCol) {

361

return set(PREDICTION_COL, predictionCol);

362

}

363

364

public String getPredictionCol() {

365

return get(PREDICTION_COL);

366

}

367

368

@Override

369

protected BatchOperator transform(BatchOperator input) {

370

return input.link(new LinearRegressionPredictBatchOp()

371

.setModelData(this.getModelData())

372

.setFeaturesCol(getFeaturesCol())

373

.setPredictionCol(getPredictionCol()));

374

}

375

376

@Override

377

protected StreamOperator transform(StreamOperator input) {

378

return input.link(new LinearRegressionPredictStreamOp()

379

.setModelData(this.getModelData())

380

.setFeaturesCol(getFeaturesCol())

381

.setPredictionCol(getPredictionCol()));

382

}

383

}

384

385

// Usage

386

LinearRegressionModel model = // ... obtained from estimator

387

model.setPredictionCol("my_predictions");

388

389

Table predictions = model.transform(testData);

390

```

391

392

## Base Class Integration Patterns

393

394

### Estimator-Model Pairing

395

396

The base classes are designed to work together in estimator-model pairs:

397

398

```java

399

// The estimator trains and produces the model

400

public class MyEstimator extends EstimatorBase<MyEstimator, MyModel> {

401

@Override

402

protected MyModel fit(BatchOperator input) {

403

// Training logic

404

Table modelData = // ... train and extract model data

405

406

return new MyModel(this.getParams()).setModelData(modelData);

407

}

408

}

409

410

// The model applies the trained logic

411

public class MyModel extends ModelBase<MyModel> {

412

@Override

413

protected BatchOperator transform(BatchOperator input) {

414

// Apply model using this.getModelData()

415

return // ... prediction logic

416

}

417

418

@Override

419

protected StreamOperator transform(StreamOperator input) {

420

// Apply model to streaming data

421

return // ... streaming prediction logic

422

}

423

}

424

```

425

426

### Parameter Consistency

427

428

Keep parameters consistent between estimators and models by defining shared `ParamInfo` constants in one place and referencing them from both classes:

429

430

```java

431

public class ParameterDefinitions {

432

// Shared parameter definitions

433

public static final ParamInfo<String> FEATURES_COL = ParamInfoFactory

434

.createParamInfo("featuresCol", String.class)

435

.setHasDefaultValue("features")

436

.build();

437

438

public static final ParamInfo<Integer> MAX_ITER = ParamInfoFactory

439

.createParamInfo("maxIter", Integer.class)

440

.setHasDefaultValue(100)

441

.build();

442

}

443

444

public class MyEstimator extends EstimatorBase<MyEstimator, MyModel> {

445

// Use shared parameters

446

public static final ParamInfo<String> FEATURES_COL = ParameterDefinitions.FEATURES_COL;

447

public static final ParamInfo<Integer> MAX_ITER = ParameterDefinitions.MAX_ITER;

448

449

// Estimator-specific parameters

450

public static final ParamInfo<String> LABEL_COL = ParamInfoFactory

451

.createParamInfo("labelCol", String.class)

452

.setHasDefaultValue("label")

453

.build();

454

}

455

456

public class MyModel extends ModelBase<MyModel> {

457

// Use shared parameters

458

public static final ParamInfo<String> FEATURES_COL = ParameterDefinitions.FEATURES_COL;

459

460

// Model-specific parameters

461

public static final ParamInfo<String> PREDICTION_COL = ParamInfoFactory

462

.createParamInfo("predictionCol", String.class)

463

.setHasDefaultValue("prediction")

464

.build();

465

}

466

```

467

468

### Batch and Stream Support

469

470

Implement both batch and stream operations for maximum flexibility:

471

472

```java

473

public class FlexibleTransformer extends TransformerBase<FlexibleTransformer> {

474

475

@Override

476

protected BatchOperator transform(BatchOperator input) {

477

// Batch-optimized implementation

478

return input.link(new MyBatchOp()

479

.setParallelism(4) // Higher parallelism for batch

480

.setBufferSize(1000)); // Larger buffers for batch

481

}

482

483

@Override

484

protected StreamOperator transform(StreamOperator input) {

485

// Stream-optimized implementation

486

return input.link(new MyStreamOp()

487

.setParallelism(1) // Lower parallelism for low latency

488

.setBufferSize(10)); // Smaller buffers for real-time

489

}

490

}

491

```

492

493

### Error Handling

494

495

Implement robust error handling in classes that extend the base classes, validating parameters and the input schema before training:

496

497

```java

498

public class RobustEstimator extends EstimatorBase<RobustEstimator, RobustModel> {

499

500

@Override

501

protected RobustModel fit(BatchOperator input) {

502

try {

503

// Validate input parameters

504

validateParameters();

505

506

// Check input data schema

507

validateInputSchema(input.getSchema());

508

509

// Perform training

510

BatchOperator<?> trained = input.link(new TrainingOp());

511

512

Table modelData = trained.getOutput();

513

514

return new RobustModel(this.getParams()).setModelData(modelData);

515

516

} catch (Exception e) {

517

throw new RuntimeException("Training failed: " + e.getMessage(), e);

518

}

519

}

520

521

private void validateParameters() {

522

// Parameter validation logic

523

if (get(MAX_ITER) <= 0) {

524

throw new IllegalArgumentException("maxIter must be positive");

525

}

526

}

527

528

private void validateInputSchema(TableSchema schema) {

529

// Schema validation logic

530

String featuresCol = getFeaturesCol();

531

if (!Arrays.asList(schema.getFieldNames()).contains(featuresCol)) {

532

throw new IllegalArgumentException("Features column not found: " + featuresCol);

533

}

534

}

535

}

536

```

537

538

This comprehensive base class system provides a solid foundation for implementing custom ML algorithms while maintaining consistency with the broader Flink ML ecosystem.