or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

async.md, external-systems.md, index.md, iteration.md, joins.md, machine-learning.md, side-output.md, socket.md, utilities.md, windowing.md, wordcount.md

docs/machine-learning.md

# Machine Learning Examples

Patterns for incremental (online) machine learning on data streams: training examples are ingested continuously, model parameters are kept and updated in Flink state, and the continuously updated models serve real-time predictions.

3

4

## Capabilities

5

6

### IncrementalLearningSkeleton

7

8

Skeleton framework for incremental machine learning algorithms with online model updates.

9

10

```java { .api }

11

/**

12

* Skeleton for incremental machine learning algorithms

13

* Demonstrates online learning patterns with streaming model updates

14

* @param args Command line arguments (--input path, --output path)

15

*/

16

public class IncrementalLearningSkeleton {

17

public static void main(String[] args) throws Exception;

18

}

19

```

20

21

**Usage Example:**

22

23

```bash

24

# Run with default sample data

25

java -cp flink-examples-streaming_2.10-1.3.3.jar \

26

org.apache.flink.streaming.examples.ml.IncrementalLearningSkeleton

27

28

# Run with file input

29

java -cp flink-examples-streaming_2.10-1.3.3.jar \

30

org.apache.flink.streaming.examples.ml.IncrementalLearningSkeleton \

31

--input /path/to/training-data.txt --output /path/to/model-updates.txt

32

```

33

34

## Online Learning Patterns

35

36

### Streaming Model Updates

37

38

The incremental learning skeleton demonstrates patterns for:

39

40

1. **Online Training Data Processing**: Continuous ingestion of training examples

41

2. **Model State Management**: Maintaining and updating model parameters in stream state

42

3. **Incremental Updates**: Applying mini-batch or single-example updates to models

43

4. **Model Versioning**: Tracking model versions and update history

44

5. **Prediction Integration**: Using updated models for real-time predictions

45

46

### Key Components

47

48

#### Training Data Stream

49

```java

50

// Stream of training examples

51

DataStream<TrainingExample> trainingData;

52

53

// Training example format (implementation specific)

54

class TrainingExample {

55

public double[] features;

56

public double label;

57

public long timestamp;

58

}

59

```

60

61

#### Model State Management

62

```java

63

// Maintain model parameters in keyed state

64

ValueState<ModelParameters> modelState;

65

66

// Model parameters (implementation specific)

67

class ModelParameters {

68

public double[] weights;

69

public double bias;

70

public long version;

71

public int sampleCount;

72

}

73

```

74

75

#### Incremental Update Function

76

```java

77

public static class IncrementalLearningFunction

78

extends RichMapFunction<TrainingExample, ModelUpdate> {

79

80

private ValueState<ModelParameters> modelState;

81

private final double learningRate = 0.01;

82

83

@Override

84

public void open(Configuration parameters) throws Exception {

85

ValueStateDescriptor<ModelParameters> descriptor =

86

new ValueStateDescriptor<>("model", ModelParameters.class);

87

modelState = getRuntimeContext().getState(descriptor);

88

}

89

90

@Override

91

public ModelUpdate map(TrainingExample example) throws Exception {

92

ModelParameters currentModel = modelState.value();

93

if (currentModel == null) {

94

currentModel = initializeModel();

95

}

96

97

// Apply incremental update (e.g., SGD)

98

ModelParameters updatedModel = updateModel(currentModel, example);

99

modelState.update(updatedModel);

100

101

return new ModelUpdate(updatedModel, example.timestamp);

102

}

103

104

private ModelParameters updateModel(ModelParameters model, TrainingExample example) {

105

// Stochastic Gradient Descent update

106

double prediction = predict(model, example.features);

107

double error = example.label - prediction;

108

109

// Update weights

110

for (int i = 0; i < model.weights.length; i++) {

111

model.weights[i] += learningRate * error * example.features[i];

112

}

113

model.bias += learningRate * error;

114

model.version++;

115

model.sampleCount++;

116

117

return model;

118

}

119

120

private double predict(ModelParameters model, double[] features) {

121

double result = model.bias;

122

for (int i = 0; i < features.length; i++) {

123

result += model.weights[i] * features[i];

124

}

125

return result;

126

}

127

}

128

```

129

130

## ML Pipeline Patterns

131

132

### Feature Engineering Stream

133

134

```java

135

// Feature extraction and preprocessing

136

DataStream<TrainingExample> preprocessedData = rawDataStream

137

.map(new FeatureExtractionFunction())

138

.filter(new DataQualityFilter())

139

.keyBy(new FeatureKeySelector()); // Key by feature group or model ID

140

```

141

142

### Model Training Pipeline

143

144

```java

145

// Complete incremental learning pipeline

146

DataStream<ModelUpdate> modelUpdates = trainingData

147

.keyBy(new ModelKeySelector()) // Partition by model ID

148

.map(new IncrementalLearningFunction()) // Apply incremental updates

149

.filter(new SignificantUpdateFilter()); // Only emit significant updates

150

151

// Broadcast model updates for prediction

152

BroadcastStream<ModelUpdate> modelBroadcast = modelUpdates

153

.broadcast(MODEL_UPDATE_DESCRIPTOR);

154

```

155

156

### Real-time Prediction

157

158

```java

159

// Use updated models for predictions

160

DataStream<Prediction> predictions = predictionRequests

161

.connect(modelBroadcast)

162

.process(new ModelPredictionFunction());

163

164

public static class ModelPredictionFunction

165

extends BroadcastProcessFunction<PredictionRequest, ModelUpdate, Prediction> {

166

167

@Override

168

public void processElement(

169

PredictionRequest request,

170

ReadOnlyContext ctx,

171

Collector<Prediction> out) throws Exception {

172

173

// Get latest model from broadcast state

174

ModelParameters model = ctx.getBroadcastState(MODEL_UPDATE_DESCRIPTOR)

175

.get(request.modelId);

176

177

if (model != null) {

178

double prediction = predict(model, request.features);

179

out.collect(new Prediction(request.id, prediction, model.version));

180

}

181

}

182

183

@Override

184

public void processBroadcastElement(

185

ModelUpdate update,

186

Context ctx,

187

Collector<Prediction> out) throws Exception {

188

189

// Update broadcast state with new model

190

ctx.getBroadcastState(MODEL_UPDATE_DESCRIPTOR)

191

.put(update.modelId, update.parameters);

192

}

193

}

194

```

195

196

## Advanced ML Patterns

197

198

### Model Ensembles

199

200

```java

201

// Maintain multiple models for ensemble predictions

202

public static class EnsembleLearningFunction

203

extends KeyedProcessFunction<String, TrainingExample, EnsemblePrediction> {

204

205

private ListState<ModelParameters> ensembleModels;

206

207

@Override

208

public void processElement(

209

TrainingExample example,

210

Context ctx,

211

Collector<EnsemblePrediction> out) throws Exception {

212

213

List<ModelParameters> models = new ArrayList<>();

214

ensembleModels.get().forEach(models::add);

215

216

// Update each model in ensemble

217

for (int i = 0; i < models.size(); i++) {

218

models.set(i, updateModel(models.get(i), example));

219

}

220

221

// Make ensemble prediction

222

double[] predictions = models.stream()

223

.mapToDouble(model -> predict(model, example.features))

224

.toArray();

225

226

double ensemblePrediction = Arrays.stream(predictions).average().orElse(0.0);

227

out.collect(new EnsemblePrediction(example.id, ensemblePrediction, predictions));

228

229

// Update state

230

ensembleModels.clear();

231

ensembleModels.addAll(models);

232

}

233

}

234

```

235

236

### Concept Drift Detection

237

238

```java

239

public static class DriftDetectionFunction

240

extends KeyedProcessFunction<String, PredictionResult, DriftAlert> {

241

242

private ValueState<Double> accuracyWindow;

243

private ValueState<Long> windowStartTime;

244

245

@Override

246

public void processElement(

247

PredictionResult result,

248

Context ctx,

249

Collector<DriftAlert> out) throws Exception {

250

251

// Update accuracy statistics

252

double currentAccuracy = accuracyWindow.value();

253

long windowStart = windowStartTime.value();

254

255

// Calculate sliding window accuracy

256

double newAccuracy = updateAccuracy(currentAccuracy, result);

257

accuracyWindow.update(newAccuracy);

258

259

// Check for concept drift

260

if (newAccuracy < DRIFT_THRESHOLD) {

261

out.collect(new DriftAlert(result.modelId, newAccuracy, ctx.timestamp()));

262

263

// Reset window

264

windowStartTime.update(ctx.timestamp());

265

accuracyWindow.clear();

266

}

267

}

268

}

269

```

270

271

### Model Performance Monitoring

272

273

```java

274

public static class ModelMetricsFunction

275

extends RichMapFunction<PredictionResult, ModelMetrics> {

276

277

private AggregatingState<Double, Double> accuracyAggregator;

278

private AggregatingState<Double, Double> latencyAggregator;

279

280

@Override

281

public void open(Configuration parameters) throws Exception {

282

// Configure accuracy aggregator

283

AggregatingStateDescriptor<Double, RunningAverage, Double> accuracyDesc =

284

new AggregatingStateDescriptor<>("accuracy", new AverageAccumulator(), Double.class);

285

accuracyAggregator = getRuntimeContext().getAggregatingState(accuracyDesc);

286

287

// Configure latency aggregator

288

AggregatingStateDescriptor<Double, RunningAverage, Double> latencyDesc =

289

new AggregatingStateDescriptor<>("latency", new AverageAccumulator(), Double.class);

290

latencyAggregator = getRuntimeContext().getAggregatingState(latencyDesc);

291

}

292

293

@Override

294

public ModelMetrics map(PredictionResult result) throws Exception {

295

// Update metrics

296

accuracyAggregator.add(result.accuracy);

297

latencyAggregator.add(result.latency);

298

299

return new ModelMetrics(

300

result.modelId,

301

accuracyAggregator.get(),

302

latencyAggregator.get(),

303

System.currentTimeMillis()

304

);

305

}

306

}

307

```

308

309

## Data Structures

310

311

### Training Example

312

313

```java { .api }

314

/**

315

* Training example for incremental learning

316

*/

317

public class TrainingExample {

318

public String id;

319

public double[] features;

320

public double label;

321

public long timestamp;

322

323

public TrainingExample(String id, double[] features, double label, long timestamp);

324

}

325

```

326

327

### Model Parameters

328

329

```java { .api }

330

/**

331

* Model parameters for linear models

332

*/

333

public class ModelParameters {

334

public double[] weights;

335

public double bias;

336

public long version;

337

public int sampleCount;

338

public double learningRate;

339

340

public ModelParameters(int featureCount);

341

}

342

```

343

344

### Model Update

345

346

```java { .api }

347

/**

348

* Model update event

349

*/

350

public class ModelUpdate {

351

public String modelId;

352

public ModelParameters parameters;

353

public long timestamp;

354

public double performance;

355

356

public ModelUpdate(String modelId, ModelParameters parameters, long timestamp);

357

}

358

```

359

360

### Prediction Request/Result

361

362

```java { .api }

/**
 * A request to score one feature vector with a specific model.
 */
public class PredictionRequest {

    /** Identifier of this request, carried over into the resulting prediction. */
    public String id;

    /** ID of the model that should score this request. */
    public String modelId;

    /** Feature vector to score. */
    public double[] features;

    /** Timestamp of the request. */
    public long timestamp;
}

/**
 * Outcome of scoring a prediction request, as consumed by the monitoring functions.
 */
public class PredictionResult {

    /** Identifier of the prediction (presumably the originating request's ID). */
    public String id;

    /** ID of the model that produced the prediction. */
    public String modelId;

    /** Predicted value. */
    public double prediction;

    /** Confidence attached to the prediction. */
    public double confidence;

    /** Version of the model that produced the prediction. */
    public long modelVersion;

    /** Accuracy of the prediction; only meaningful if ground truth is available. */
    public double accuracy;

    /** Processing time of the prediction (unit not specified here). */
    public double latency;
}

```

386

387

## Sample Data Utilities

388

389

### IncrementalLearningSkeletonData

390

391

Utility class providing sample training data for ML examples.

392

393

```java { .api }

394

/**

395

* Sample data generator for incremental learning examples

396

*/

397

public class IncrementalLearningSkeletonData {

398

public static final TrainingExample[] SAMPLE_DATA;

399

400

/**

401

* Generate synthetic training data for linear regression

402

*/

403

public static Iterator<TrainingExample> createLinearRegressionData(

404

int numSamples, int numFeatures, double noise);

405

406

/**

407

* Generate synthetic classification data

408

*/

409

public static Iterator<TrainingExample> createClassificationData(

410

int numSamples, int numFeatures, int numClasses);

411

}

412

```

413

414

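## Dependencies

> **Version note:** the coordinates below match the 1.3.3 example jar used in the usage example. Some pattern snippets on this page use `KeyedProcessFunction`, `BroadcastProcessFunction` and broadcast state (`BroadcastStream`), which were added in Flink 1.5. To compile those snippets, use Flink 1.5 or later with a matching Scala suffix; Scala 2.10 builds are not published for those releases. None of the snippets import anything from `flink-ml`.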

415

416

```xml

417

<dependency>

418

<groupId>org.apache.flink</groupId>

419

<artifactId>flink-streaming-java_2.10</artifactId>

420

<version>1.3.3</version>

421

</dependency>

422

423

<dependency>

424

<groupId>org.apache.flink</groupId>

425

<artifactId>flink-ml_2.10</artifactId>

426

<version>1.3.3</version>

427

</dependency>

428

```

429

430

## Required Imports

431

432

```java

import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.state.AggregatingState;
import org.apache.flink.api.common.state.AggregatingStateDescriptor;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.datastream.BroadcastStream;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
import org.apache.flink.streaming.api.functions.co.BroadcastProcessFunction;
import org.apache.flink.util.Collector;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;

```