CtrlK
Blog · Docs · Log in · Get started
Tessl Logo

tessl/npm-tensorflow--tfjs-node

TensorFlow backend for TensorFlow.js via Node.js - provides native TensorFlow execution in backend JavaScript applications under the Node.js runtime, accelerated by the TensorFlow C binary under the hood

Pending
Overview
Eval results
Files

docs/callbacks.md

Training Callbacks

TensorFlow.js Node provides enhanced training callbacks that improve the training experience with progress visualization and detailed logging. These callbacks integrate seamlessly with the standard TensorFlow.js training process.

Capabilities

Progress Bar Logger

ProgbarLogger Class

A terminal-based progress bar callback that automatically displays training progress.

/**
 * Terminal-based progress bar callback for tf.LayersModel.fit().
 * Automatically registered at verbosity level 1: training with `verbose: 1`
 * uses it without adding it to the `callbacks` array.
 */
class ProgbarLogger extends CustomCallback {
  // Takes no arguments.
  constructor();
}

ProgbarLogger is registered automatically when @tensorflow/tfjs-node is imported, so model.fit() uses it whenever training runs with `verbose: 1`. You do not need to add it to `callbacks`:

Usage Example:

import * as fs from 'node:fs';

import * as tf from '@tensorflow/tfjs-node';

// Create model
const model = tf.sequential({
  layers: [
    tf.layers.dense({ inputShape: [10], units: 64, activation: 'relu' }),
    tf.layers.dense({ units: 32, activation: 'relu' }),
    tf.layers.dense({ units: 1, activation: 'linear' })
  ]
});

model.compile({
  optimizer: 'adam',
  loss: 'meanSquaredError',
  metrics: ['mae']
});

// Generate training data
const xs = tf.randomNormal([1000, 10]);
const ys = tf.randomNormal([1000, 1]);

// Train with progress bar (verbose: 1 automatically uses ProgbarLogger)
await model.fit(xs, ys, {
  epochs: 50,
  batchSize: 32,
  validationSplit: 0.2,
  verbose: 1  // This enables the progress bar
});

// Output looks roughly like this (800 training samples / batch size 32 = 25 steps
// per epoch; bar width, timings and digits vary). Metrics are printed as key=value:
// Epoch 1 / 50
// eta=0.0 ===============================================>
// 1125ms 45000us/step - loss=1.23 mae=0.988 val_loss=1.12 val_mae=0.877
// Epoch 2 / 50
// eta=0.0 ===============================================>
// 1000ms 40000us/step - loss=1.12 mae=0.877 val_loss=1.01 val_mae=0.765
// ...

TensorBoard Callback

TensorBoardCallback Class

Automatically log training metrics to TensorBoard during training.

/**
 * TensorBoard callback for automatic logging during training.
 * Writes metrics (after every batch or every epoch) and, optionally, weight
 * histograms as TensorBoard logs. Usually created via `tf.node.tensorBoard()`
 * rather than constructed directly.
 */
class TensorBoardCallback extends CustomCallback {
  /**
   * @param logdir - Directory to write logs (default: './logs')
   * @param args - Configuration options (`updateFreq`, `histogramFreq`) passed as
   *   ONE object, exactly like the second argument of `tf.node.tensorBoard()`
   */
  constructor(logdir?: string, args?: TensorBoardCallbackArgs);
}

/**
 * Factory function to create a TensorBoard callback.
 * Pass the returned callback in the `callbacks` array of `model.fit()`.
 * @param logdir - Directory to write logs (default: './logs')
 * @param args - Configuration options; see `TensorBoardCallbackArgs`
 *   (`updateFreq`, `histogramFreq`)
 * @returns TensorBoardCallback instance
 */
function tensorBoard(logdir?: string, args?: TensorBoardCallbackArgs): TensorBoardCallback;

/**
 * Options accepted by `tf.node.tensorBoard()`; both fields are optional.
 * NOTE(review): the defaults used when a field is omitted are not shown here
 * (believed to be updateFreq 'epoch' and histogramFreq 0) — confirm against
 * the tfjs-node API docs.
 */
interface TensorBoardCallbackArgs {
  /** How often to log: 'batch' for every batch, 'epoch' for every epoch */
  updateFreq?: 'batch' | 'epoch';
  
  /** How often to log weight histograms (in epochs, 0 = disabled) */
  histogramFreq?: number;
}

Usage Example:

// Write TensorBoard logs under ./logs/training_run: metrics once per epoch,
// weight histograms every 5 epochs.
const tensorBoardLogger = tf.node.tensorBoard('./logs/training_run', {
  histogramFreq: 5,
  updateFreq: 'epoch'
});

// Attach the callback to fit(); verbose: 1 keeps the terminal progress bar as well.
await model.fit(xs, ys, {
  batchSize: 64,
  epochs: 100,
  validationSplit: 0.1,
  verbose: 1,
  callbacks: [tensorBoardLogger]
});

const viewHint = 'Training complete. View logs with: tensorboard --logdir ./logs';
console.log(viewHint);

Custom Callback Creation

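You can create custom callbacks by extending the CustomCallback class. Its constructor requires an object of handler functions (`onTrainBegin`, `onEpochEnd`, …). A subclass therefore passes its handlers to `super({ ... })`; calling `super()` with no argument fails: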

// Base class for creating custom callbacks (simplified view of the tfjs API).
// The constructor REQUIRES an object of handler functions, so a subclass passes
// its handlers to `super({ ... })`; `super()` with no argument is not valid.
// fit() invokes the on* methods below, which forward to the matching handler.
declare class CustomCallback {
  constructor(args: CustomCallbackArgs);
  onTrainBegin(logs?: Logs): Promise<void>;
  onTrainEnd(logs?: Logs): Promise<void>;
  onEpochBegin(epoch: number, logs?: Logs): Promise<void>;
  onEpochEnd(epoch: number, logs?: Logs): Promise<void>;
  onBatchBegin(batch: number, logs?: Logs): Promise<void>;
  onBatchEnd(batch: number, logs?: Logs): Promise<void>;
}

// Handlers accepted by the CustomCallback constructor; every one is optional.
interface CustomCallbackArgs {
  onTrainBegin?: (logs?: Logs) => void | Promise<void>;
  onTrainEnd?: (logs?: Logs) => void | Promise<void>;
  onEpochBegin?: (epoch: number, logs?: Logs) => void | Promise<void>;
  onEpochEnd?: (epoch: number, logs?: Logs) => void | Promise<void>;
  onBatchBegin?: (batch: number, logs?: Logs) => void | Promise<void>;
  onBatchEnd?: (batch: number, logs?: Logs) => void | Promise<void>;
}

// Metric values reported to callbacks, keyed by metric name
// (e.g. 'loss', 'val_loss', or a compiled metric such as 'mae').
interface Logs {
  [key: string]: number;
}

Custom Callback Examples

Early Stopping Callback

/**
 * Stops training once a monitored metric has not decreased by more than
 * `minDelta` for `patience` consecutive epochs (lower values are better).
 *
 * Handlers are passed to `super({...})`. tf.CustomCallback requires that
 * argument, and it resolves tensor-valued logs to plain numbers before calling
 * them. CustomCallback does not keep a reference to the model, so `setModel()`
 * is overridden to capture it for `stopTraining`.
 */
class EarlyStoppingCallback extends tf.CustomCallback {
  private patience: number;
  private minDelta: number;
  private monitorMetric: string;
  private bestValue: number;
  private waitCount: number;
  /** Model being trained; set by fit() through setModel(). */
  private trainedModel: tf.LayersModel | null = null;

  /**
   * @param patience - Epochs without improvement to tolerate before stopping
   * @param minDelta - Minimum decrease of the metric that counts as an improvement
   * @param monitor - Key in the epoch logs to watch
   */
  constructor(patience: number = 10, minDelta: number = 0.001, monitor: string = 'val_loss') {
    super({
      // Reset per-run state so one instance can be reused across fit() calls.
      onTrainBegin: async () => this.resetState(),
      onEpochEnd: async (epoch: number, logs?: tf.Logs) => this.handleEpochEnd(epoch, logs)
    });
    this.patience = patience;
    this.minDelta = minDelta;
    this.monitorMetric = monitor;
    this.bestValue = Infinity;
    this.waitCount = 0;
  }

  /** Called by the training loop with the model being trained. */
  setModel(model: tf.LayersModel): void {
    this.trainedModel = model;
  }

  /** Clears the best value and the patience counter. */
  private resetState(): void {
    this.bestValue = Infinity;
    this.waitCount = 0;
  }

  /** Compares this epoch's metric with the best so far and stops training when patience runs out. */
  private handleEpochEnd(epoch: number, logs?: tf.Logs): void {
    const currentValue = logs?.[this.monitorMetric];

    if (currentValue == null) {
      console.warn(`Early stopping metric '${this.monitorMetric}' not found in logs`);
      return;
    }

    if (currentValue < this.bestValue - this.minDelta) {
      this.bestValue = currentValue;
      this.waitCount = 0;
      console.log(`Epoch ${epoch + 1}: ${this.monitorMetric} improved to ${currentValue.toFixed(6)}`);
    } else {
      this.waitCount++;
      console.log(`Epoch ${epoch + 1}: ${this.monitorMetric} did not improve (${this.waitCount}/${this.patience})`);

      if (this.waitCount >= this.patience) {
        console.log(`Early stopping after ${epoch + 1} epochs`);
        if (this.trainedModel != null) {
          this.trainedModel.stopTraining = true;
        }
      }
    }
  }
}

// Usage: halt training once val_loss has failed to improve by more than 0.001
// for 15 consecutive epochs.
const earlyStopping = new EarlyStoppingCallback(15, 0.001, 'val_loss');

await model.fit(xs, ys, {
  callbacks: [earlyStopping],
  epochs: 200,
  validationSplit: 0.2,
  verbose: 1
});

Learning Rate Scheduler

/**
 * Sets the optimizer's learning rate at the start of every epoch from a
 * schedule function `(epoch) => learningRate`, where epoch is 0-based.
 *
 * The handler is passed to `super({...})` because tf.CustomCallback requires
 * it. `setModel()` is overridden because CustomCallback does not keep the
 * model, and the optimizer is reached through it.
 */
class LearningRateScheduler extends tf.CustomCallback {
  private scheduleFn: (epoch: number) => number;
  /** Model being trained; set by fit() through setModel(). */
  private trainedModel: tf.LayersModel | null = null;

  /** @param schedule - Maps a 0-based epoch index to the learning rate to use */
  constructor(schedule: (epoch: number) => number) {
    super({
      onEpochBegin: async (epoch: number) => this.handleEpochBegin(epoch)
    });
    this.scheduleFn = schedule;
  }

  /** Called by the training loop with the model being trained. */
  setModel(model: tf.LayersModel): void {
    this.trainedModel = model;
  }

  /** Computes this epoch's rate and reports whether it could be applied. */
  private handleEpochBegin(epoch: number): void {
    const newLr = this.scheduleFn(epoch);

    if (this.applyLearningRate(newLr)) {
      console.log(`Epoch ${epoch + 1}: Learning rate set to ${newLr}`);
    } else {
      // Previously the success message was printed even when nothing changed.
      console.warn(`Epoch ${epoch + 1}: could not set learning rate ${newLr} on this optimizer`);
    }
  }

  /** Writes `newLr` into the current optimizer; returns false if it has no learning rate to update. */
  private applyLearningRate(newLr: number): boolean {
    const optimizer = this.trainedModel?.optimizer;
    if (optimizer == null) {
      return false;
    }
    if (optimizer instanceof tf.SGDOptimizer) {
      // Public setter; MomentumOptimizer extends SGDOptimizer and inherits it.
      optimizer.setLearningRate(newLr);
      return true;
    }
    // Optimizers such as Adam have no public setter and keep the rate in a
    // non-public `learningRate` field. NOTE: this relies on that internal field.
    if (typeof Reflect.get(optimizer, 'learningRate') === 'number') {
      Reflect.set(optimizer, 'learningRate', newLr);
      return true;
    }
    return false;
  }
}

// Usage with exponential decay: lr(epoch) = 0.001 * 0.95^epoch
function exponentialDecay(epoch: number): number {
  const initialLr = 0.001;
  const decayRate = 0.95;
  return initialLr * Math.pow(decayRate, epoch);
}

const decayScheduler = new LearningRateScheduler(exponentialDecay);

await model.fit(xs, ys, {
  callbacks: [decayScheduler],
  epochs: 100,
  verbose: 1
});

Model Checkpointing

/**
 * Saves the model to `file://<filepath>` at the end of an epoch, optionally
 * only when the monitored metric improves. Every `{epoch}` placeholder in
 * `filepath` is replaced with the 1-based epoch number.
 *
 * Handlers are passed to `super({...})`. tf.CustomCallback requires that
 * argument and resolves tensor-valued logs to numbers. `setModel()` is
 * overridden because CustomCallback does not retain the model.
 */
class ModelCheckpoint extends tf.CustomCallback {
  private filepath: string;
  private monitor: string;
  private saveBestOnly: boolean;
  private bestValue: number;
  /** True when larger values of `monitor` are better (e.g. accuracy). */
  private maximize: boolean;
  /** Model being trained; set by fit() through setModel(). */
  private trainedModel: tf.LayersModel | null = null;

  /**
   * @param filepath - Destination directory; `{epoch}` is replaced by the epoch number
   * @param monitor - Key in the epoch logs that defines the "best" model
   * @param saveWeightsOnly - Kept for API compatibility only. tfjs LayersModel has
   *   no weights-only save (there is no `saveWeights()`), so the full model is
   *   always written and a warning is printed when this is true
   * @param saveBestOnly - When true, save only when `monitor` improves
   * @param mode - 'min' (lower is better), 'max' (higher is better) or 'auto'
   *   ('max' when the metric name contains "acc", otherwise 'min')
   */
  constructor(
    filepath: string,
    monitor: string = 'val_loss',
    saveWeightsOnly: boolean = false,
    saveBestOnly: boolean = true,
    mode: 'auto' | 'min' | 'max' = 'auto'
  ) {
    super({
      onEpochEnd: async (epoch: number, logs?: tf.Logs) => this.handleEpochEnd(epoch, logs)
    });
    this.filepath = filepath;
    this.monitor = monitor;
    this.saveBestOnly = saveBestOnly;
    this.maximize = mode === 'max' || (mode === 'auto' && monitor.includes('acc'));
    this.bestValue = this.maximize ? -Infinity : Infinity;
    if (saveWeightsOnly) {
      console.warn('ModelCheckpoint: tfjs LayersModel has no saveWeights(); the full model will be saved instead');
    }
  }

  /** Called by the training loop with the model being trained. */
  setModel(model: tf.LayersModel): void {
    this.trainedModel = model;
  }

  /** Whether `value` beats the best value so far in the configured direction. */
  private isImprovement(value: number): boolean {
    return this.maximize ? value > this.bestValue : value < this.bestValue;
  }

  /** Decides whether this epoch should be saved and writes the model if so. */
  private async handleEpochEnd(epoch: number, logs?: tf.Logs): Promise<void> {
    const currentValue = logs?.[this.monitor];

    if (currentValue == null) {
      console.warn(`Checkpoint metric '${this.monitor}' not found in logs`);
      return;
    }

    let shouldSave = !this.saveBestOnly;

    if (this.saveBestOnly && this.isImprovement(currentValue)) {
      this.bestValue = currentValue;
      shouldSave = true;
    }

    if (!shouldSave) {
      return;
    }
    if (this.trainedModel == null) {
      console.warn('ModelCheckpoint: no model attached; skipping save');
      return;
    }

    // split/join replaces every `{epoch}` occurrence, not only the first.
    const epochPath = this.filepath.split('{epoch}').join((epoch + 1).toString());

    try {
      await this.trainedModel.save(`file://${epochPath}`);
      console.log(`Epoch ${epoch + 1}: Model saved to ${epochPath}`);
    } catch (error) {
      // Catch variables are `unknown` under strict mode; narrow before reading .message.
      const message = error instanceof Error ? error.message : String(error);
      console.error(`Failed to save model: ${message}`);
    }
  }
}

// Usage: monitor 'val_loss'. The example model is compiled with metrics: ['mae'],
// so no accuracy key ever appears in its logs, while the validation loss is
// always reported once validationSplit is set.
const checkpoint = new ModelCheckpoint(
  './checkpoints/model-epoch-{epoch}',
  'val_loss',
  false,  // Save full model
  true    // Save only best model
);

await model.fit(xs, ys, {
  epochs: 50,
  validationSplit: 0.2,
  callbacks: [checkpoint],
  verbose: 1
});

Metrics Logger

/**
 * Records every epoch's logs in memory and prints them. When `logFile` is
 * given, it also appends each epoch as one JSON line to that file.
 *
 * The epoch handler is passed to `super({...})`. tf.CustomCallback requires
 * that argument and resolves tensor-valued logs to plain numbers before calling
 * it, which keeps `value.toFixed()` safe. File access uses the module-level `fs`
 * import rather than `require`, which does not exist in ES modules.
 */
class MetricsLogger extends tf.CustomCallback {
  private metrics: Array<{epoch: number, logs: tf.Logs}> = [];
  private logFile?: string;

  /** @param logFile - Optional path of a JSON Lines file to append each epoch to */
  constructor(logFile?: string) {
    super({
      onEpochEnd: async (epoch: number, logs?: tf.Logs) => this.recordEpoch(epoch, logs)
    });
    this.logFile = logFile;
  }

  /** Stores, prints and (optionally) appends one epoch's metrics; epochs are reported 1-based. */
  private recordEpoch(epoch: number, logs?: tf.Logs): void {
    if (logs == null) {
      return;
    }

    // Copy so later mutation of the logs object by other callbacks cannot alter history.
    this.metrics.push({ epoch: epoch + 1, logs: { ...logs } });

    // Log to console
    const logStr = Object.entries(logs)
      .map(([key, value]) => `${key}: ${value.toFixed(6)}`)
      .join(', ');
    console.log(`Epoch ${epoch + 1} - ${logStr}`);

    // Log to file if specified
    if (this.logFile) {
      const logEntry = JSON.stringify({ epoch: epoch + 1, ...logs }) + '\n';
      fs.appendFileSync(this.logFile, logEntry);
    }
  }

  /** Returns the recorded per-epoch metrics (the live internal array). */
  getMetrics() {
    return this.metrics;
  }

  /** Writes all recorded metrics to `filepath` as pretty-printed JSON. */
  saveMetrics(filepath: string) {
    fs.writeFileSync(filepath, JSON.stringify(this.metrics, null, 2));
  }
}

// Usage: append one JSON line per epoch to ./training_log.jsonl during training.
const epochMetricsLogger = new MetricsLogger('./training_log.jsonl');

await model.fit(xs, ys, {
  callbacks: [epochMetricsLogger],
  epochs: 30,
  validationSplit: 0.2,
  verbose: 1
});

// Afterwards, write the whole in-memory history as one JSON document.
const summaryPath = './training_summary.json';
epochMetricsLogger.saveMetrics(summaryPath);

Combining Multiple Callbacks

You can use multiple callbacks together for comprehensive training monitoring:

/**
 * Trains `model` with all of the example callbacks attached: TensorBoard
 * logging, early stopping, checkpointing, learning-rate decay and per-epoch
 * metric logging. 20% of the data is held out (validationSplit) so the
 * `val_*` metrics these callbacks monitor are reported.
 *
 * @param model - Compiled model to train
 * @param xs - Training inputs
 * @param ys - Training targets
 * @returns The History object resolved by `model.fit()`
 */
async function trainWithFullMonitoring(
  model: tf.LayersModel,
  xs: tf.Tensor,
  ys: tf.Tensor
) {
  // Create all callbacks
  const tensorboard = tf.node.tensorBoard('./logs/full_monitoring');
  const earlyStop = new EarlyStoppingCallback(20, 0.0001, 'val_loss');
  // Checkpoint on 'val_loss'. The loss is always reported, whereas an accuracy key
  // exists only if an accuracy metric was compiled (the example model tracks only 'mae').
  const checkpoint = new ModelCheckpoint('./checkpoints/best-model', 'val_loss');
  const lrScheduler = new LearningRateScheduler(epoch => 0.001 * Math.pow(0.9, epoch));
  const metricsLogger = new MetricsLogger('./training.log');
  
  // Train with all callbacks
  const history = await model.fit(xs, ys, {
    epochs: 200,
    batchSize: 64,
    validationSplit: 0.2,
    callbacks: [
      tensorboard,
      earlyStop,
      checkpoint,
      lrScheduler,
      metricsLogger
    ],
    verbose: 1  // Progress bar + all callback output
  });
  
  console.log('Training completed with full monitoring');
  console.log('Check TensorBoard: tensorboard --logdir ./logs');
  
  return history;
}

Types

// Base callback class (simplified view of the tfjs API).
// The constructor requires an object of handler functions, so subclasses call
// `super({ ... })`. The training loop calls setParams() and setModel() before
// training starts. CustomCallback does not store the model itself, so a
// subclass that needs it overrides setModel() and keeps its own reference.
declare class CustomCallback {
  params: Params;

  constructor(args: CustomCallbackArgs);

  setModel(model: LayersModel): void;
  setParams(params: Params): void;

  onTrainBegin(logs?: Logs): Promise<void>;
  onTrainEnd(logs?: Logs): Promise<void>;
  onEpochBegin(epoch: number, logs?: Logs): Promise<void>;
  onEpochEnd(epoch: number, logs?: Logs): Promise<void>;
  onBatchBegin(batch: number, logs?: Logs): Promise<void>;
  onBatchEnd(batch: number, logs?: Logs): Promise<void>;
}

// Handlers accepted by the CustomCallback constructor; every one is optional.
interface CustomCallbackArgs {
  onTrainBegin?: (logs?: Logs) => void | Promise<void>;
  onTrainEnd?: (logs?: Logs) => void | Promise<void>;
  onEpochBegin?: (epoch: number, logs?: Logs) => void | Promise<void>;
  onEpochEnd?: (epoch: number, logs?: Logs) => void | Promise<void>;
  onBatchBegin?: (batch: number, logs?: Logs) => void | Promise<void>;
  onBatchEnd?: (batch: number, logs?: Logs) => void | Promise<void>;
}

// Metric values passed to callback handlers, keyed by metric name
// (e.g. 'loss', 'mae', and 'val_'-prefixed validation metrics such as 'val_loss').
interface Logs {
  [key: string]: number;
}

// Training configuration handed to callbacks through setParams().
// NOTE(review): the per-field notes below are inferred from the field names;
// some entries (e.g. samples/steps) may be null depending on how training was
// started — confirm before relying on them.
interface Params {
  epochs: number;          // presumably the total number of epochs requested
  samples: number;         // presumably the number of training samples
  steps: number;           // presumably steps (batches) per epoch
  batchSize: number;
  verbose: number;         // verbosity level passed to fit()
  doValidation: boolean;   // presumably true when validation data/split is used
  metrics: string[];       // presumably the metric names that appear in Logs
}

// TensorBoard callback specific types
// Options for tf.node.tensorBoard(): logging frequency and histogram interval.
interface TensorBoardCallbackArgs {
  /** 'batch' logs after every batch, 'epoch' after every epoch */
  updateFreq?: 'batch' | 'epoch';
  /** Weight-histogram interval in epochs (0 = disabled) */
  histogramFreq?: number;
}

Install with Tessl CLI

npx tessl i tessl/npm-tensorflow--tfjs-node

docs

callbacks.md

image-processing.md

index.md

io.md

savedmodel.md

tensorboard.md

tile.json