
# Training Callbacks

TensorFlow.js Node provides enhanced training callbacks that improve the training experience with progress visualization and detailed logging. These callbacks integrate seamlessly with the standard TensorFlow.js training process.

## Capabilities

### Progress Bar Logger

#### ProgbarLogger Class

A terminal-based progress bar callback that automatically displays training progress.

```typescript { .api }
/**
 * Terminal-based progress bar callback for tf.Model.fit()
 * Automatically registered at verbosity level 1
 */
class ProgbarLogger extends CustomCallback {
  constructor();
}
```

The `ProgbarLogger` is automatically registered and will be used when training with `verbose: 1`:

**Usage Example:**

```typescript
import * as tf from '@tensorflow/tfjs-node';

// Create model
const model = tf.sequential({
  layers: [
    tf.layers.dense({ inputShape: [10], units: 64, activation: 'relu' }),
    tf.layers.dense({ units: 32, activation: 'relu' }),
    tf.layers.dense({ units: 1, activation: 'linear' })
  ]
});

model.compile({
  optimizer: 'adam',
  loss: 'meanSquaredError',
  metrics: ['mae']
});

// Generate training data
const xs = tf.randomNormal([1000, 10]);
const ys = tf.randomNormal([1000, 1]);

// Train with progress bar (verbose: 1 automatically uses ProgbarLogger)
await model.fit(xs, ys, {
  epochs: 50,
  batchSize: 32,
  validationSplit: 0.2,
  verbose: 1 // This enables the progress bar
});

// Output will show:
// Epoch 1/50
// ████████████████████████████████ 25/25 [==============================] - 2s 45ms/step - loss: 1.2345 - mae: 0.9876 - val_loss: 1.1234 - val_mae: 0.8765
// Epoch 2/50
// ████████████████████████████████ 25/25 [==============================] - 1s 40ms/step - loss: 1.1234 - mae: 0.8765 - val_loss: 1.0123 - val_mae: 0.7654
// ...
```

### TensorBoard Callback

#### TensorBoardCallback Class

Automatically log training metrics to TensorBoard during training.

```typescript { .api }
/**
 * TensorBoard callback for automatic logging during training
 */
class TensorBoardCallback extends CustomCallback {
  constructor(logdir?: string, updateFreq?: 'batch' | 'epoch', histogramFreq?: number);
}

/**
 * Factory function to create TensorBoard callback
 * @param logdir - Directory to write logs (default: './logs')
 * @param args - Configuration options
 * @returns TensorBoardCallback instance
 */
function tensorBoard(logdir?: string, args?: TensorBoardCallbackArgs): TensorBoardCallback;

interface TensorBoardCallbackArgs {
  /** How often to log: 'batch' for every batch, 'epoch' for every epoch */
  updateFreq?: 'batch' | 'epoch';

  /** How often to log weight histograms (in epochs, 0 = disabled) */
  histogramFreq?: number;
}
```

**Usage Example:**

```typescript
// Create TensorBoard callback
const tbCallback = tf.node.tensorBoard('./logs/training_run', {
  updateFreq: 'epoch', // Log after each epoch
  histogramFreq: 5 // Log weight histograms every 5 epochs
});

// Train with TensorBoard logging
await model.fit(xs, ys, {
  epochs: 100,
  batchSize: 64,
  validationSplit: 0.1,
  callbacks: [tbCallback],
  verbose: 1 // Also show progress bar
});

console.log('Training complete. View logs with: tensorboard --logdir ./logs');
```

## Custom Callback Creation

You can create custom callbacks by extending the `CustomCallback` class:

```typescript { .api }
// Base class for creating custom callbacks
abstract class CustomCallback {
  onTrainBegin?(logs?: Logs): void | Promise<void>;
  onTrainEnd?(logs?: Logs): void | Promise<void>;
  onEpochBegin?(epoch: number, logs?: Logs): void | Promise<void>;
  onEpochEnd?(epoch: number, logs?: Logs): void | Promise<void>;
  onBatchBegin?(batch: number, logs?: Logs): void | Promise<void>;
  onBatchEnd?(batch: number, logs?: Logs): void | Promise<void>;
}

interface Logs {
  [key: string]: number;
}
```

### Custom Callback Examples

#### Early Stopping Callback

```typescript
class EarlyStoppingCallback extends tf.CustomCallback {
  private patience: number;
  private minDelta: number;
  private monitorMetric: string;
  private bestValue: number;
  private waitCount: number;

  constructor(patience: number = 10, minDelta: number = 0.001, monitor: string = 'val_loss') {
    super();
    this.patience = patience;
    this.minDelta = minDelta;
    this.monitorMetric = monitor;
    this.bestValue = Infinity;
    this.waitCount = 0;
  }

  async onEpochEnd(epoch: number, logs?: tf.Logs) {
    const currentValue = logs?.[this.monitorMetric];

    if (currentValue == null) {
      console.warn(`Early stopping metric '${this.monitorMetric}' not found in logs`);
      return;
    }

    if (currentValue < this.bestValue - this.minDelta) {
      this.bestValue = currentValue;
      this.waitCount = 0;
      console.log(`Epoch ${epoch + 1}: ${this.monitorMetric} improved to ${currentValue.toFixed(6)}`);
    } else {
      this.waitCount++;
      console.log(`Epoch ${epoch + 1}: ${this.monitorMetric} did not improve (${this.waitCount}/${this.patience})`);

      if (this.waitCount >= this.patience) {
        console.log(`Early stopping after ${epoch + 1} epochs`);
        this.model.stopTraining = true;
      }
    }
  }
}

// Usage
const earlyStop = new EarlyStoppingCallback(15, 0.001, 'val_loss');

await model.fit(xs, ys, {
  epochs: 200,
  validationSplit: 0.2,
  callbacks: [earlyStop],
  verbose: 1
});
```

#### Learning Rate Scheduler

```typescript
class LearningRateScheduler extends tf.CustomCallback {
  private scheduleFn: (epoch: number) => number;

  constructor(schedule: (epoch: number) => number) {
    super();
    this.scheduleFn = schedule;
  }

  async onEpochBegin(epoch: number, logs?: tf.Logs) {
    const newLr = this.scheduleFn(epoch);

    // Update optimizer learning rate
    if (this.model.optimizer instanceof tf.AdamOptimizer) {
      this.model.optimizer.learningRate = newLr;
    }

    console.log(`Epoch ${epoch + 1}: Learning rate set to ${newLr}`);
  }
}

// Usage with exponential decay
const lrScheduler = new LearningRateScheduler((epoch: number) => {
  const initialLr = 0.001;
  const decayRate = 0.95;
  return initialLr * Math.pow(decayRate, epoch);
});

await model.fit(xs, ys, {
  epochs: 100,
  callbacks: [lrScheduler],
  verbose: 1
});
```

#### Model Checkpointing

```typescript
class ModelCheckpoint extends tf.CustomCallback {
  private filepath: string;
  private monitor: string;
  private saveWeightsOnly: boolean;
  private saveBestOnly: boolean;
  private bestValue: number;

  constructor(
    filepath: string,
    monitor: string = 'val_loss',
    saveWeightsOnly: boolean = false,
    saveBestOnly: boolean = true
  ) {
    super();
    this.filepath = filepath;
    this.monitor = monitor;
    this.saveWeightsOnly = saveWeightsOnly;
    this.saveBestOnly = saveBestOnly;
    this.bestValue = Infinity;
  }

  async onEpochEnd(epoch: number, logs?: tf.Logs) {
    const currentValue = logs?.[this.monitor];

    if (currentValue == null) {
      console.warn(`Checkpoint metric '${this.monitor}' not found in logs`);
      return;
    }

    let shouldSave = !this.saveBestOnly;

    if (this.saveBestOnly && currentValue < this.bestValue) {
      this.bestValue = currentValue;
      shouldSave = true;
    }

    if (shouldSave) {
      const epochPath = this.filepath.replace('{epoch}', (epoch + 1).toString());

      try {
        if (this.saveWeightsOnly) {
          await this.model.saveWeights(`file://${epochPath}`);
        } else {
          await this.model.save(`file://${epochPath}`);
        }

        console.log(`Epoch ${epoch + 1}: Model saved to ${epochPath}`);
      } catch (error) {
        console.error(`Failed to save model: ${error.message}`);
      }
    }
  }
}

// Usage
const checkpoint = new ModelCheckpoint(
  './checkpoints/model-epoch-{epoch}',
  'val_accuracy',
  false, // Save full model
  true // Save only best model
);

await model.fit(xs, ys, {
  epochs: 50,
  validationSplit: 0.2,
  callbacks: [checkpoint],
  verbose: 1
});
```

#### Metrics Logger

```typescript
class MetricsLogger extends tf.CustomCallback {
  private metrics: Array<{epoch: number, logs: tf.Logs}> = [];
  private logFile?: string;

  constructor(logFile?: string) {
    super();
    this.logFile = logFile;
  }

  async onEpochEnd(epoch: number, logs?: tf.Logs) {
    if (logs) {
      this.metrics.push({ epoch: epoch + 1, logs: { ...logs } });

      // Log to console
      const logStr = Object.entries(logs)
        .map(([key, value]) => `${key}: ${value.toFixed(6)}`)
        .join(', ');
      console.log(`Epoch ${epoch + 1} - ${logStr}`);

      // Log to file if specified
      if (this.logFile) {
        const fs = require('fs');
        const logEntry = JSON.stringify({ epoch: epoch + 1, ...logs }) + '\n';
        fs.appendFileSync(this.logFile, logEntry);
      }
    }
  }

  getMetrics() {
    return this.metrics;
  }

  saveMetrics(filepath: string) {
    const fs = require('fs');
    fs.writeFileSync(filepath, JSON.stringify(this.metrics, null, 2));
  }
}

// Usage
const metricsLogger = new MetricsLogger('./training_log.jsonl');

await model.fit(xs, ys, {
  epochs: 30,
  validationSplit: 0.2,
  callbacks: [metricsLogger],
  verbose: 1
});

// Save metrics summary
metricsLogger.saveMetrics('./training_summary.json');
```

## Combining Multiple Callbacks

You can use multiple callbacks together for comprehensive training monitoring:

```typescript
async function trainWithFullMonitoring(
  model: tf.LayersModel,
  xs: tf.Tensor,
  ys: tf.Tensor
) {
  // Create all callbacks
  const tensorboard = tf.node.tensorBoard('./logs/full_monitoring');
  const earlyStop = new EarlyStoppingCallback(20, 0.0001, 'val_loss');
  const checkpoint = new ModelCheckpoint('./checkpoints/best-model', 'val_accuracy');
  const lrScheduler = new LearningRateScheduler(epoch => 0.001 * Math.pow(0.9, epoch));
  const metricsLogger = new MetricsLogger('./training.log');

  // Train with all callbacks
  const history = await model.fit(xs, ys, {
    epochs: 200,
    batchSize: 64,
    validationSplit: 0.2,
    callbacks: [
      tensorboard,
      earlyStop,
      checkpoint,
      lrScheduler,
      metricsLogger
    ],
    verbose: 1 // Progress bar + all callback output
  });

  console.log('Training completed with full monitoring');
  console.log('Check TensorBoard: tensorboard --logdir ./logs');

  return history;
}
```

## Types

```typescript { .api }
// Base callback interface
abstract class CustomCallback {
  protected model?: LayersModel;
  protected params?: Params;

  setModel(model: LayersModel): void;
  setParams(params: Params): void;

  onTrainBegin?(logs?: Logs): void | Promise<void>;
  onTrainEnd?(logs?: Logs): void | Promise<void>;
  onEpochBegin?(epoch: number, logs?: Logs): void | Promise<void>;
  onEpochEnd?(epoch: number, logs?: Logs): void | Promise<void>;
  onBatchBegin?(batch: number, logs?: Logs): void | Promise<void>;
  onBatchEnd?(batch: number, logs?: Logs): void | Promise<void>;
}

interface Logs {
  [key: string]: number;
}

interface Params {
  epochs: number;
  samples: number;
  steps: number;
  batchSize: number;
  verbose: number;
  doValidation: boolean;
  metrics: string[];
}

// TensorBoard callback specific types
interface TensorBoardCallbackArgs {
  updateFreq?: 'batch' | 'epoch';
  histogramFreq?: number;
}
```