# Training Callbacks

TensorFlow.js Node provides enhanced training callbacks that improve the training experience with progress visualization and detailed logging. These callbacks integrate seamlessly with the standard TensorFlow.js training process.

## Capabilities

### Progress Bar Logger

#### ProgbarLogger Class

A terminal-based progress bar callback that automatically displays training progress.
```typescript { .api }
/**
 * Terminal-based progress bar callback for tf.Model.fit()
 * Automatically registered at verbosity level 1
 */
class ProgbarLogger extends CustomCallback {
  constructor();
}
```

The `ProgbarLogger` is automatically registered and will be used when training with `verbose: 1`:

**Usage Example:**
```typescript
import * as tf from '@tensorflow/tfjs-node';

// Create model
const model = tf.sequential({
  layers: [
    tf.layers.dense({ inputShape: [10], units: 64, activation: 'relu' }),
    tf.layers.dense({ units: 32, activation: 'relu' }),
    tf.layers.dense({ units: 1, activation: 'linear' })
  ]
});

model.compile({
  optimizer: 'adam',
  loss: 'meanSquaredError',
  metrics: ['mae']
});

// Generate training data
const xs = tf.randomNormal([1000, 10]);
const ys = tf.randomNormal([1000, 1]);

// Train with progress bar (verbose: 1 automatically uses ProgbarLogger)
await model.fit(xs, ys, {
  epochs: 50,
  batchSize: 32,
  validationSplit: 0.2,
  verbose: 1 // This enables the progress bar
});

// Output will show:
// Epoch 1/50
// ████████████████████████████████ 25/25 [==============================] - 2s 45ms/step - loss: 1.2345 - mae: 0.9876 - val_loss: 1.1234 - val_mae: 0.8765
// Epoch 2/50
// ████████████████████████████████ 25/25 [==============================] - 1s 40ms/step - loss: 1.1234 - mae: 0.8765 - val_loss: 1.0123 - val_mae: 0.7654
// ...
```
### TensorBoard Callback

#### TensorBoardCallback Class

Automatically log training metrics to TensorBoard during training.

```typescript { .api }
/**
 * TensorBoard callback for automatic logging during training
 */
class TensorBoardCallback extends CustomCallback {
  constructor(logdir?: string, updateFreq?: 'batch' | 'epoch', histogramFreq?: number);
}

/**
 * Factory function to create TensorBoard callback
 * @param logdir - Directory to write logs (default: './logs')
 * @param args - Configuration options
 * @returns TensorBoardCallback instance
 */
function tensorBoard(logdir?: string, args?: TensorBoardCallbackArgs): TensorBoardCallback;

interface TensorBoardCallbackArgs {
  /** How often to log: 'batch' for every batch, 'epoch' for every epoch */
  updateFreq?: 'batch' | 'epoch';

  /** How often to log weight histograms (in epochs, 0 = disabled) */
  histogramFreq?: number;
}
```

**Usage Example:**
```typescript
// Create TensorBoard callback
const tbCallback = tf.node.tensorBoard('./logs/training_run', {
  updateFreq: 'epoch', // Log after each epoch
  histogramFreq: 5 // Log weight histograms every 5 epochs
});

// Train with TensorBoard logging
await model.fit(xs, ys, {
  epochs: 100,
  batchSize: 64,
  validationSplit: 0.1,
  callbacks: [tbCallback],
  verbose: 1 // Also show progress bar
});

console.log('Training complete. View logs with: tensorboard --logdir ./logs');
```

## Custom Callback Creation

You can create custom callbacks by extending the `CustomCallback` class:

```typescript { .api }
// Base class for creating custom callbacks
abstract class CustomCallback {
  onTrainBegin?(logs?: Logs): void | Promise<void>;
  onTrainEnd?(logs?: Logs): void | Promise<void>;
  onEpochBegin?(epoch: number, logs?: Logs): void | Promise<void>;
  onEpochEnd?(epoch: number, logs?: Logs): void | Promise<void>;
  onBatchBegin?(batch: number, logs?: Logs): void | Promise<void>;
  onBatchEnd?(batch: number, logs?: Logs): void | Promise<void>;
}

interface Logs {
  [key: string]: number;
}
```

### Custom Callback Examples

#### Early Stopping Callback
```typescript
/**
 * Stops training when a monitored metric stops improving.
 *
 * Tracks the best value seen for the monitored metric; when no improvement
 * larger than `minDelta` occurs for `patience` consecutive epochs, training
 * is halted by setting `model.stopTraining = true`.
 *
 * NOTE(review): assumes a "lower is better" metric such as a loss
 * (best value starts at Infinity and improvements are decreases).
 */
class EarlyStoppingCallback extends tf.CustomCallback {
  private readonly patience: number;      // epochs to wait without improvement
  private readonly minDelta: number;      // minimum decrease that counts as improvement
  private readonly monitorMetric: string; // key looked up in the epoch logs
  private bestValue: number;              // best (lowest) metric value seen so far
  private waitCount: number;              // consecutive epochs without improvement

  constructor(patience: number = 10, minDelta: number = 0.001, monitor: string = 'val_loss') {
    super();
    this.patience = patience;
    this.minDelta = minDelta;
    this.monitorMetric = monitor;
    this.bestValue = Infinity;
    this.waitCount = 0;
  }

  async onEpochEnd(epoch: number, logs?: tf.Logs) {
    const currentValue = logs?.[this.monitorMetric];

    // Metric absent (e.g. no validation split configured): warn and skip.
    if (currentValue == null) {
      console.warn(`Early stopping metric '${this.monitorMetric}' not found in logs`);
      return;
    }

    if (currentValue < this.bestValue - this.minDelta) {
      this.bestValue = currentValue;
      this.waitCount = 0;
      console.log(`Epoch ${epoch + 1}: ${this.monitorMetric} improved to ${currentValue.toFixed(6)}`);
    } else {
      this.waitCount++;
      console.log(`Epoch ${epoch + 1}: ${this.monitorMetric} did not improve (${this.waitCount}/${this.patience})`);

      if (this.waitCount >= this.patience) {
        console.log(`Early stopping after ${epoch + 1} epochs`);
        // Signals tf.Model.fit to stop after the current epoch.
        this.model.stopTraining = true;
      }
    }
  }
}

// Usage
const earlyStop = new EarlyStoppingCallback(15, 0.001, 'val_loss');

await model.fit(xs, ys, {
  epochs: 200,
  validationSplit: 0.2,
  callbacks: [earlyStop],
  verbose: 1
});
```

#### Learning Rate Scheduler

```typescript
/**
 * Sets the optimizer learning rate at the start of every epoch.
 *
 * @param schedule - Maps a zero-based epoch index to the learning rate
 *   to use for that epoch.
 */
class LearningRateScheduler extends tf.CustomCallback {
  private readonly scheduleFn: (epoch: number) => number;

  constructor(schedule: (epoch: number) => number) {
    super();
    this.scheduleFn = schedule;
  }

  async onEpochBegin(epoch: number, logs?: tf.Logs) {
    const newLr = this.scheduleFn(epoch);

    // Only AdamOptimizer is handled here; other optimizer types keep their
    // configured learning rate, so only log "set" when it actually happened.
    if (this.model.optimizer instanceof tf.AdamOptimizer) {
      this.model.optimizer.learningRate = newLr;
      console.log(`Epoch ${epoch + 1}: Learning rate set to ${newLr}`);
    } else {
      console.warn(`Epoch ${epoch + 1}: optimizer is not AdamOptimizer; learning rate left unchanged`);
    }
  }
}

// Usage with exponential decay
const lrScheduler = new LearningRateScheduler((epoch: number) => {
  const initialLr = 0.001;
  const decayRate = 0.95;
  return initialLr * Math.pow(decayRate, epoch);
});

await model.fit(xs, ys, {
  epochs: 100,
  callbacks: [lrScheduler],
  verbose: 1
});
```

#### Model Checkpointing

```typescript
class ModelCheckpoint extends tf.CustomCallback {
234
private filepath: string;
235
private monitor: string;
236
private saveWeightsOnly: boolean;
237
private saveBestOnly: boolean;
238
private bestValue: number;
239
240
constructor(
241
filepath: string,
242
monitor: string = 'val_loss',
243
saveWeightsOnly: boolean = false,
244
saveBestOnly: boolean = true
245
) {
246
super();
247
this.filepath = filepath;
248
this.monitor = monitor;
249
this.saveWeightsOnly = saveWeightsOnly;
250
this.saveBestOnly = saveBestOnly;
251
this.bestValue = Infinity;
252
}
253
254
async onEpochEnd(epoch: number, logs?: tf.Logs) {
255
const currentValue = logs?.[this.monitor];
256
257
if (currentValue == null) {
258
console.warn(`Checkpoint metric '${this.monitor}' not found in logs`);
259
return;
260
}
261
262
let shouldSave = !this.saveBestOnly;
263
264
if (this.saveBestOnly && currentValue < this.bestValue) {
265
this.bestValue = currentValue;
266
shouldSave = true;
267
}
268
269
if (shouldSave) {
270
const epochPath = this.filepath.replace('{epoch}', (epoch + 1).toString());
271
272
try {
273
if (this.saveWeightsOnly) {
274
await this.model.saveWeights(`file://${epochPath}`);
275
} else {
276
await this.model.save(`file://${epochPath}`);
277
}
278
279
console.log(`Epoch ${epoch + 1}: Model saved to ${epochPath}`);
280
} catch (error) {
281
console.error(`Failed to save model: ${error.message}`);
282
}
283
}
284
}
285
}

// Usage
const checkpoint = new ModelCheckpoint(
  './checkpoints/model-epoch-{epoch}',
  'val_accuracy',
  false, // Save full model
  true // Save only best model
);

await model.fit(xs, ys, {
  epochs: 50,
  validationSplit: 0.2,
  callbacks: [checkpoint],
  verbose: 1
});
```

#### Metrics Logger

```typescript
class MetricsLogger extends tf.CustomCallback {
307
private metrics: Array<{epoch: number, logs: tf.Logs}> = [];
308
private logFile?: string;
309
310
constructor(logFile?: string) {
311
super();
312
this.logFile = logFile;
313
}
314
315
async onEpochEnd(epoch: number, logs?: tf.Logs) {
316
if (logs) {
317
this.metrics.push({ epoch: epoch + 1, logs: { ...logs } });
318
319
// Log to console
320
const logStr = Object.entries(logs)
321
.map(([key, value]) => `${key}: ${value.toFixed(6)}`)
322
.join(', ');
323
console.log(`Epoch ${epoch + 1} - ${logStr}`);
324
325
// Log to file if specified
326
if (this.logFile) {
327
const fs = require('fs');
328
const logEntry = JSON.stringify({ epoch: epoch + 1, ...logs }) + '\n';
329
fs.appendFileSync(this.logFile, logEntry);
330
}
331
}
332
}
333
334
getMetrics() {
335
return this.metrics;
336
}
337
338
saveMetrics(filepath: string) {
339
const fs = require('fs');
340
fs.writeFileSync(filepath, JSON.stringify(this.metrics, null, 2));
341
}
342
}

// Usage
const metricsLogger = new MetricsLogger('./training_log.jsonl');

await model.fit(xs, ys, {
  epochs: 30,
  validationSplit: 0.2,
  callbacks: [metricsLogger],
  verbose: 1
});

// Save metrics summary
metricsLogger.saveMetrics('./training_summary.json');
```

## Combining Multiple Callbacks

You can use multiple callbacks together for comprehensive training monitoring:

```typescript
async function trainWithFullMonitoring(
364
model: tf.LayersModel,
365
xs: tf.Tensor,
366
ys: tf.Tensor
367
) {
368
// Create all callbacks
369
const tensorboard = tf.node.tensorBoard('./logs/full_monitoring');
370
const earlyStop = new EarlyStoppingCallback(20, 0.0001, 'val_loss');
371
const checkpoint = new ModelCheckpoint('./checkpoints/best-model', 'val_accuracy');
372
const lrScheduler = new LearningRateScheduler(epoch => 0.001 * Math.pow(0.9, epoch));
373
const metricsLogger = new MetricsLogger('./training.log');
374
375
// Train with all callbacks
376
const history = await model.fit(xs, ys, {
377
epochs: 200,
378
batchSize: 64,
379
validationSplit: 0.2,
380
callbacks: [
381
tensorboard,
382
earlyStop,
383
checkpoint,
384
lrScheduler,
385
metricsLogger
386
],
387
verbose: 1 // Progress bar + all callback output
388
});
389
390
console.log('Training completed with full monitoring');
391
console.log('Check TensorBoard: tensorboard --logdir ./logs');
392
393
return history;
394
}
```

## Types

```typescript { .api }
// Base callback interface
abstract class CustomCallback {
  protected model?: LayersModel;
  protected params?: Params;

  setModel(model: LayersModel): void;
  setParams(params: Params): void;

  onTrainBegin?(logs?: Logs): void | Promise<void>;
  onTrainEnd?(logs?: Logs): void | Promise<void>;
  onEpochBegin?(epoch: number, logs?: Logs): void | Promise<void>;
  onEpochEnd?(epoch: number, logs?: Logs): void | Promise<void>;
  onBatchBegin?(batch: number, logs?: Logs): void | Promise<void>;
  onBatchEnd?(batch: number, logs?: Logs): void | Promise<void>;
}

interface Logs {
  [key: string]: number;
}

interface Params {
  epochs: number;
  samples: number;
  steps: number;
  batchSize: number;
  verbose: number;
  doValidation: boolean;
  metrics: string[];
}

// TensorBoard callback specific types
interface TensorBoardCallbackArgs {
  updateFreq?: 'batch' | 'epoch';
  histogramFreq?: number;
}
```