# Machine Learning Examples

Incremental learning patterns and online algorithm implementations for streaming ML workflows, with a focus on real-time model updates.

## Capabilities

### IncrementalLearningSkeleton

Skeleton framework for incremental machine learning algorithms with online model updates.

```java { .api }
/**
 * Skeleton for incremental machine learning algorithms.
 * Demonstrates online learning patterns with streaming model updates.
 */
public class IncrementalLearningSkeleton {
    /**
     * @param args Command line arguments (--input path, --output path)
     */
    public static void main(String[] args) throws Exception;
}
```

**Usage Example:**

```bash
# Run with default sample data
java -cp flink-examples-streaming_2.10-1.3.3.jar \
  org.apache.flink.streaming.examples.ml.IncrementalLearningSkeleton

# Run with file input
java -cp flink-examples-streaming_2.10-1.3.3.jar \
  org.apache.flink.streaming.examples.ml.IncrementalLearningSkeleton \
  --input /path/to/training-data.txt --output /path/to/model-updates.txt
```

## Online Learning Patterns

### Streaming Model Updates

The incremental learning skeleton demonstrates patterns for:

1. **Online Training Data Processing**: Continuous ingestion of training examples
2. **Model State Management**: Maintaining and updating model parameters in stream state
3. **Incremental Updates**: Applying mini-batch or single-example updates to models
4. **Model Versioning**: Tracking model versions and update history
5. **Prediction Integration**: Using updated models for real-time predictions

### Key Components

#### Training Data Stream
```java
// Stream of training examples
DataStream<TrainingExample> trainingData;

// Training example format (implementation specific)
class TrainingExample {
    public double[] features;
    public double label;
    public long timestamp;
}
```

#### Model State Management
```java
// Maintain model parameters in keyed state
ValueState<ModelParameters> modelState;

// Model parameters (implementation specific)
class ModelParameters {
    public double[] weights;
    public double bias;
    public long version;
    public int sampleCount;
}
```

#### Incremental Update Function
```java
// Uses keyed state, so apply this function to a keyed stream (after keyBy)
public static class IncrementalLearningFunction
        extends RichMapFunction<TrainingExample, ModelUpdate> {

    // Placeholder model ID (an assumption; derive it from your key in practice)
    private static final String MODEL_ID = "default";

    private ValueState<ModelParameters> modelState;
    private final double learningRate = 0.01;

    @Override
    public void open(Configuration parameters) throws Exception {
        ValueStateDescriptor<ModelParameters> descriptor =
                new ValueStateDescriptor<>("model", ModelParameters.class);
        modelState = getRuntimeContext().getState(descriptor);
    }

    @Override
    public ModelUpdate map(TrainingExample example) throws Exception {
        ModelParameters currentModel = modelState.value();
        if (currentModel == null) {
            // First example for this key: start from a fresh model
            currentModel = new ModelParameters(example.features.length);
        }

        // Apply incremental update (e.g., SGD)
        ModelParameters updatedModel = updateModel(currentModel, example);
        modelState.update(updatedModel);

        return new ModelUpdate(MODEL_ID, updatedModel, example.timestamp);
    }

    private ModelParameters updateModel(ModelParameters model, TrainingExample example) {
        // Stochastic Gradient Descent update for squared error
        double prediction = predict(model, example.features);
        double error = example.label - prediction;

        // Update weights
        for (int i = 0; i < model.weights.length; i++) {
            model.weights[i] += learningRate * error * example.features[i];
        }
        model.bias += learningRate * error;
        model.version++;
        model.sampleCount++;

        return model;
    }

    private double predict(ModelParameters model, double[] features) {
        double result = model.bias;
        for (int i = 0; i < features.length; i++) {
            result += model.weights[i] * features[i];
        }
        return result;
    }
}
```
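To make the update rule concrete, the following is a minimal plain-Java sketch of the same per-example SGD step, run outside Flink on synthetic data. The class `OnlineSgdSketch`, its nested `Model` type, and the `trainOnLine` helper are hypothetical names introduced only for illustration; the target function y = 2x + 1, the step count, and the learning rate are assumptions.

```java
import java.util.Random;

/** Plain-Java sketch of the per-example SGD update; no Flink dependency. */
final class OnlineSgdSketch {

    /** Linear model state, mirroring the weights/bias/version fields above. */
    static final class Model {
        final double[] weights;
        double bias;
        long version;

        Model(int featureCount) {
            this.weights = new double[featureCount]; // zero-initialized
        }
    }

    static double predict(Model model, double[] features) {
        double result = model.bias;
        for (int i = 0; i < model.weights.length; i++) {
            result += model.weights[i] * features[i];
        }
        return result;
    }

    /** One squared-error SGD step: w += lr * (label - prediction) * x. */
    static void update(Model model, double[] features, double label, double learningRate) {
        double error = label - predict(model, features);
        for (int i = 0; i < model.weights.length; i++) {
            model.weights[i] += learningRate * error * features[i];
        }
        model.bias += learningRate * error;
        model.version++;
    }

    /** Trains on noiseless samples of y = 2x + 1 with x drawn from [-1, 1). */
    static Model trainOnLine(int steps, double learningRate, long seed) {
        Random random = new Random(seed);
        Model model = new Model(1);
        for (int i = 0; i < steps; i++) {
            double x = random.nextDouble() * 2.0 - 1.0;
            update(model, new double[] {x}, 2.0 * x + 1.0, learningRate);
        }
        return model;
    }

    public static void main(String[] args) {
        Model model = trainOnLine(5000, 0.05, 42L);
        System.out.printf("w=%.4f b=%.4f version=%d%n",
                model.weights[0], model.bias, model.version);
    }
}
```

Because the data is noiseless, the weight and bias should approach 2 and 1 as more examples arrive; with noisy data the estimates would instead fluctuate around those values at a scale set by the learning rate.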

## ML Pipeline Patterns

### Feature Engineering Stream

```java
// Feature extraction and preprocessing
DataStream<TrainingExample> preprocessedData = rawDataStream
    .map(new FeatureExtractionFunction())
    .filter(new DataQualityFilter())
    .keyBy(new FeatureKeySelector()); // Key by feature group or model ID
```
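The pipeline above leaves `DataQualityFilter` undefined. One plausible check, sketched here in plain Java, is to reject examples with missing or non-finite values. The class `DataQualitySketch` and its `isValid` method are hypothetical; in a Flink job the same predicate would sit inside a `FilterFunction`.

```java
/** Plain-Java sketch of a data quality predicate; names are illustrative. */
final class DataQualitySketch {

    /** Returns true when the example has at least one feature and all values are finite. */
    static boolean isValid(double[] features, double label) {
        if (features == null || features.length == 0 || !Double.isFinite(label)) {
            return false;
        }
        for (double feature : features) {
            if (!Double.isFinite(feature)) {
                return false; // NaN or infinity would corrupt the SGD update
            }
        }
        return true;
    }
}
```

Filtering before training matters for SGD in particular, since a single NaN feature propagates into every weight it touches.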

### Model Training Pipeline

```java
// Descriptor for the broadcast state that maps model ID to parameters
// (the state name "model-updates" is an assumption)
MapStateDescriptor<String, ModelParameters> MODEL_UPDATE_DESCRIPTOR =
    new MapStateDescriptor<>("model-updates", String.class, ModelParameters.class);

// Complete incremental learning pipeline
DataStream<ModelUpdate> modelUpdates = trainingData
    .keyBy(new ModelKeySelector())           // Partition by model ID
    .map(new IncrementalLearningFunction())  // Apply incremental updates
    .filter(new SignificantUpdateFilter());  // Only emit significant updates

// Broadcast model updates for prediction
BroadcastStream<ModelUpdate> modelBroadcast = modelUpdates
    .broadcast(MODEL_UPDATE_DESCRIPTOR);
```
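`SignificantUpdateFilter` is not defined in this document. As a rough sketch, "significant" could mean that the weights have moved more than a threshold (in Euclidean distance) since the last emitted update. The class `SignificantUpdateSketch`, the threshold semantics, and the in-memory `lastEmitted` field are all assumptions; a real Flink filter would keep that reference in managed state instead of an instance field.

```java
/** Plain-Java sketch of a "significant update" check; names are illustrative. */
final class SignificantUpdateSketch {

    private final double threshold;
    private double[] lastEmitted; // null until the first update passes

    SignificantUpdateSketch(double threshold) {
        this.threshold = threshold;
    }

    /** Returns true when the weights moved more than the threshold since the last emit. */
    boolean filter(double[] weights) {
        if (lastEmitted == null || distance(lastEmitted, weights) > threshold) {
            lastEmitted = weights.clone(); // remember what downstream has seen
            return true;
        }
        return false;
    }

    /** Euclidean distance between two weight vectors of equal length. */
    static double distance(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double diff = a[i] - b[i];
            sum += diff * diff;
        }
        return Math.sqrt(sum);
    }
}
```

Comparing against the last emitted weights, rather than the previous update, means many small steps in the same direction eventually add up to an emitted update.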

### Real-time Prediction

```java
// Use updated models for predictions
DataStream<Prediction> predictions = predictionRequests
    .connect(modelBroadcast)
    .process(new ModelPredictionFunction());

public static class ModelPredictionFunction
        extends BroadcastProcessFunction<PredictionRequest, ModelUpdate, Prediction> {

    @Override
    public void processElement(
            PredictionRequest request,
            ReadOnlyContext ctx,
            Collector<Prediction> out) throws Exception {

        // Get latest model from broadcast state
        ModelParameters model = ctx.getBroadcastState(MODEL_UPDATE_DESCRIPTOR)
                .get(request.modelId);

        // Requests that arrive before any model has been broadcast are dropped
        if (model != null) {
            double prediction = predict(model, request.features);
            out.collect(new Prediction(request.id, prediction, model.version));
        }
    }

    @Override
    public void processBroadcastElement(
            ModelUpdate update,
            Context ctx,
            Collector<Prediction> out) throws Exception {

        // Update broadcast state with new model
        ctx.getBroadcastState(MODEL_UPDATE_DESCRIPTOR)
                .put(update.modelId, update.parameters);
    }

    // Same linear prediction as in IncrementalLearningFunction
    private static double predict(ModelParameters model, double[] features) {
        double result = model.bias;
        for (int i = 0; i < features.length; i++) {
            result += model.weights[i] * features[i];
        }
        return result;
    }
}
```

## Advanced ML Patterns

### Model Ensembles

```java
// Maintain multiple models for ensemble predictions.
// updateModel and predict are the helpers shown in IncrementalLearningFunction.
public static class EnsembleLearningFunction
        extends KeyedProcessFunction<String, TrainingExample, EnsemblePrediction> {

    // Number of ensemble members (an assumption; choose what suits the use case)
    private static final int ENSEMBLE_SIZE = 3;

    private ListState<ModelParameters> ensembleModels;

    @Override
    public void open(Configuration parameters) throws Exception {
        ensembleModels = getRuntimeContext().getListState(
                new ListStateDescriptor<>("ensemble", ModelParameters.class));
    }

    @Override
    public void processElement(
            TrainingExample example,
            Context ctx,
            Collector<EnsemblePrediction> out) throws Exception {

        List<ModelParameters> models = new ArrayList<>();
        Iterable<ModelParameters> stored = ensembleModels.get();
        if (stored != null) {
            stored.forEach(models::add);
        }

        // Seed the ensemble on the first example for this key
        while (models.size() < ENSEMBLE_SIZE) {
            models.add(new ModelParameters(example.features.length));
        }

        // Update each model in ensemble
        for (int i = 0; i < models.size(); i++) {
            models.set(i, updateModel(models.get(i), example));
        }

        // Make ensemble prediction
        double[] predictions = models.stream()
                .mapToDouble(model -> predict(model, example.features))
                .toArray();

        double ensemblePrediction = Arrays.stream(predictions).average().orElse(0.0);
        out.collect(new EnsemblePrediction(example.id, ensemblePrediction, predictions));

        // Update state
        ensembleModels.clear();
        ensembleModels.addAll(models);
    }
}
```

### Concept Drift Detection

```java
public static class DriftDetectionFunction
        extends KeyedProcessFunction<String, PredictionResult, DriftAlert> {

    private ValueState<Double> accuracyWindow;
    private ValueState<Long> windowStartTime;

    @Override
    public void open(Configuration parameters) throws Exception {
        accuracyWindow = getRuntimeContext().getState(
                new ValueStateDescriptor<>("accuracy-window", Double.class));
        windowStartTime = getRuntimeContext().getState(
                new ValueStateDescriptor<>("window-start", Long.class));
    }

    @Override
    public void processElement(
            PredictionResult result,
            Context ctx,
            Collector<DriftAlert> out) throws Exception {

        // State is null for the first element of a key and after a reset,
        // so avoid unboxing it directly
        Double storedAccuracy = accuracyWindow.value();
        double currentAccuracy = storedAccuracy != null ? storedAccuracy : result.accuracy;
        if (windowStartTime.value() == null) {
            windowStartTime.update(ctx.timestamp());
        }

        // Calculate sliding window accuracy
        double newAccuracy = updateAccuracy(currentAccuracy, result);
        accuracyWindow.update(newAccuracy);

        // Check for concept drift
        if (newAccuracy < DRIFT_THRESHOLD) {
            out.collect(new DriftAlert(result.modelId, newAccuracy, ctx.timestamp()));

            // Reset window
            windowStartTime.update(ctx.timestamp());
            accuracyWindow.clear();
        }
    }
}
```
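The `updateAccuracy` helper and `DRIFT_THRESHOLD` are left undefined above. One simple way to approximate a sliding-window accuracy in constant space is an exponential moving average, sketched below in plain Java. The class `DriftSketch`, the smoothing factor `alpha`, and the threshold value used in the example are assumptions, not values taken from the Flink example.

```java
/** Plain-Java sketch of exponentially weighted accuracy tracking; names are illustrative. */
final class DriftSketch {

    /**
     * Blends the newest observation into the running accuracy.
     * A larger alpha reacts faster to drift but is noisier.
     */
    static double updateAccuracy(double current, double observed, double alpha) {
        return (1.0 - alpha) * current + alpha * observed;
    }

    /** Drift is flagged once the smoothed accuracy falls below the threshold. */
    static boolean isDrift(double accuracy, double threshold) {
        return accuracy < threshold;
    }

    public static void main(String[] args) {
        double accuracy = 1.0;
        // Feed a run of wrong predictions (observed accuracy 0.0) and watch the average decay
        for (int step = 1; step <= 5; step++) {
            accuracy = updateAccuracy(accuracy, 0.0, 0.1);
            System.out.println(step + ": " + accuracy + " drift=" + isDrift(accuracy, 0.7));
        }
    }
}
```

With `alpha = 0.1`, a model that starts at accuracy 1.0 and then gets every prediction wrong crosses a 0.7 threshold on the fourth miss, since 0.9^3 is about 0.729 and 0.9^4 is about 0.656.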

### Model Performance Monitoring

```java
// Uses keyed state, so apply this function to a keyed stream (after keyBy).
// AverageAccumulator is assumed to implement
// AggregateFunction<Double, RunningAverage, Double>.
public static class ModelMetricsFunction
        extends RichMapFunction<PredictionResult, ModelMetrics> {

    private AggregatingState<Double, Double> accuracyAggregator;
    private AggregatingState<Double, Double> latencyAggregator;

    @Override
    public void open(Configuration parameters) throws Exception {
        // Configure accuracy aggregator (third argument is the accumulator type)
        AggregatingStateDescriptor<Double, RunningAverage, Double> accuracyDesc =
                new AggregatingStateDescriptor<>(
                        "accuracy", new AverageAccumulator(), RunningAverage.class);
        accuracyAggregator = getRuntimeContext().getAggregatingState(accuracyDesc);

        // Configure latency aggregator
        AggregatingStateDescriptor<Double, RunningAverage, Double> latencyDesc =
                new AggregatingStateDescriptor<>(
                        "latency", new AverageAccumulator(), RunningAverage.class);
        latencyAggregator = getRuntimeContext().getAggregatingState(latencyDesc);
    }

    @Override
    public ModelMetrics map(PredictionResult result) throws Exception {
        // Update metrics
        accuracyAggregator.add(result.accuracy);
        latencyAggregator.add(result.latency);

        return new ModelMetrics(
                result.modelId,
                accuracyAggregator.get(),
                latencyAggregator.get(),
                System.currentTimeMillis()
        );
    }
}
```
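`RunningAverage` and `AverageAccumulator` are referenced above but not defined. The sketch below shows the accumulator logic in plain Java. Its four static methods follow the shape of Flink's `AggregateFunction` (create, add, get result, merge) as found in recent Flink versions, but the class `RunningAverageSketch` is hypothetical and deliberately does not implement the Flink interface so that it stays self-contained.

```java
/** Plain-Java sketch of a running-average accumulator; names are illustrative. */
final class RunningAverageSketch {

    /** Accumulator: enough state to compute a mean incrementally. */
    static final class RunningAverage {
        double sum;
        long count;
    }

    static RunningAverage createAccumulator() {
        return new RunningAverage();
    }

    /** Folds one value into the accumulator and returns it. */
    static RunningAverage add(double value, RunningAverage accumulator) {
        accumulator.sum += value;
        accumulator.count++;
        return accumulator;
    }

    /** Returns the mean, or 0.0 when nothing has been added yet. */
    static double getResult(RunningAverage accumulator) {
        return accumulator.count == 0 ? 0.0 : accumulator.sum / accumulator.count;
    }

    /** Combines two partial accumulators into the first one. */
    static RunningAverage merge(RunningAverage a, RunningAverage b) {
        a.sum += b.sum;
        a.count += b.count;
        return a;
    }
}
```

Storing the sum and count rather than the mean itself is what makes `merge` exact: two partial means cannot be combined correctly without knowing how many values each one covers.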

## Data Structures

### Training Example

```java { .api }
/**
 * Training example for incremental learning
 */
public class TrainingExample {
    public String id;
    public double[] features;
    public double label;
    public long timestamp;

    public TrainingExample(String id, double[] features, double label, long timestamp);
}
```

### Model Parameters

```java { .api }
/**
 * Model parameters for linear models
 */
public class ModelParameters {
    public double[] weights;
    public double bias;
    public long version;
    public int sampleCount;
    public double learningRate;

    public ModelParameters(int featureCount);
}
```

### Model Update

```java { .api }
/**
 * Model update event
 */
public class ModelUpdate {
    public String modelId;
    public ModelParameters parameters;
    public long timestamp;
    public double performance;

    public ModelUpdate(String modelId, ModelParameters parameters, long timestamp);
}
```

### Prediction Request/Result

```java { .api }
/**
 * Prediction request
 */
public class PredictionRequest {
    public String id;
    public String modelId;
    public double[] features;
    public long timestamp;
}

/**
 * Prediction result
 */
public class PredictionResult {
    public String id;
    public String modelId;
    public double prediction;
    public double confidence;
    public long modelVersion;
    public double accuracy; // If ground truth available
    public double latency;  // Processing time
}
```

## Sample Data Utilities

### IncrementalLearningSkeletonData

Utility class providing sample training data for ML examples.

```java { .api }
/**
 * Sample data generator for incremental learning examples
 */
public class IncrementalLearningSkeletonData {
    public static final TrainingExample[] SAMPLE_DATA;

    /**
     * Generate synthetic training data for linear regression
     */
    public static Iterator<TrainingExample> createLinearRegressionData(
        int numSamples, int numFeatures, double noise);

    /**
     * Generate synthetic classification data
     */
    public static Iterator<TrainingExample> createClassificationData(
        int numSamples, int numFeatures, int numClasses);
}
```
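Only the signatures of the generators are given above. As an illustration of what `createLinearRegressionData` might do, here is a minimal plain-Java sketch that draws hidden true weights, then yields examples whose label is the weighted sum of the features plus Gaussian noise. The class `SampleDataSketch`, the nested `Example` type, the extra `seed` parameter, and the sampling distributions are all assumptions made for the sketch.

```java
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.Random;

/** Plain-Java sketch of a synthetic linear regression data source; names are illustrative. */
final class SampleDataSketch {

    /** Minimal stand-in for TrainingExample (features and label only). */
    static final class Example {
        final double[] features;
        final double label;

        Example(double[] features, double label) {
            this.features = features;
            this.label = label;
        }
    }

    /** Lazily yields numSamples examples with label = trueWeights . features + noise * N(0, 1). */
    static Iterator<Example> createLinearRegressionData(
            int numSamples, int numFeatures, double noise, long seed) {
        Random random = new Random(seed);

        // Hidden ground-truth weights that a learner should recover
        double[] trueWeights = new double[numFeatures];
        for (int i = 0; i < numFeatures; i++) {
            trueWeights[i] = random.nextGaussian();
        }

        return new Iterator<Example>() {
            private int produced = 0;

            @Override
            public boolean hasNext() {
                return produced < numSamples;
            }

            @Override
            public Example next() {
                if (!hasNext()) {
                    throw new NoSuchElementException();
                }
                produced++;
                double[] features = new double[numFeatures];
                double label = 0.0;
                for (int i = 0; i < numFeatures; i++) {
                    features[i] = random.nextDouble(); // uniform in [0, 1)
                    label += trueWeights[i] * features[i];
                }
                label += noise * random.nextGaussian();
                return new Example(features, label);
            }
        };
    }
}
```

Passing an explicit seed keeps runs reproducible, which is convenient when comparing learning rates or drift thresholds against the same stream.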
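
## Dependencies

The coordinates below match the 1.3.3 example JAR used above. The broadcast state API (`BroadcastStream`, `BroadcastProcessFunction`) was introduced in Flink 1.5, so the broadcast-based prediction examples need a correspondingly newer Flink version.

```xml
<dependency>
    <groupId>org.apache.flink</groupId>
    <artifactId>flink-streaming-java_2.10</artifactId>
    <version>1.3.3</version>
</dependency>

<dependency>
    <groupId>org.apache.flink</groupId>
    <artifactId>flink-ml_2.10</artifactId>
    <version>1.3.3</version>
</dependency>
```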

## Required Imports

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.state.AggregatingState;
import org.apache.flink.api.common.state.AggregatingStateDescriptor;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.datastream.BroadcastStream;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
import org.apache.flink.streaming.api.functions.co.BroadcastProcessFunction;
import org.apache.flink.util.Collector;
```