CtrlK
Blog · Docs · Log in · Get started
Tessl Logo

tessl/pypi-keras-nightly

Multi-backend deep learning framework providing a unified API for building and training neural networks across JAX, TensorFlow, PyTorch, and OpenVINO backends

Pending
Overview
Eval results
Files

docs/training-callbacks.md

Training and Callbacks

Training utilities, callbacks for monitoring and controlling training processes, and model persistence functionality for saving and loading models during and after training.

Capabilities

Training Control Callbacks

Callbacks that control the training process based on monitored metrics.

class EarlyStopping:
    """
    Stop training when a monitored metric stops improving.
    
    Args:
        monitor (str): Metric to monitor (e.g. 'val_loss')
        min_delta (float): Minimum change to qualify as improvement
        patience (int): Number of epochs with no improvement to wait
            before stopping
        verbose (int): Verbosity mode
        mode (str): 'auto', 'min', or 'max' — whether improvement means the
            monitored metric decreasing or increasing
        baseline (float, optional): Baseline value for monitored metric;
            training stops if the metric does not improve over it
        restore_best_weights (bool): Whether to restore the weights from the
            epoch with the best monitored value when training stops
        start_from_epoch (int): Epoch to start monitoring from
    """
    def __init__(self, monitor='val_loss', min_delta=0, patience=0, verbose=0,
                 mode='auto', baseline=None, restore_best_weights=False,
                 start_from_epoch=0, **kwargs): ...

class ReduceLROnPlateau:
    """
    Reduce learning rate when a metric stops improving.
    
    Once no improvement is seen for `patience` epochs, the learning rate is
    multiplied by `factor` (new_lr = lr * factor), never going below `min_lr`.
    
    Args:
        monitor (str): Metric to monitor
        factor (float): Multiplicative factor to reduce learning rate by
            (should be < 1)
        patience (int): Number of epochs with no improvement to wait
        verbose (int): Verbosity mode
        mode (str): 'auto', 'min', or 'max'
        min_delta (float): Minimum change to qualify as improvement
        cooldown (int): Number of epochs to wait after a reduction before
            resuming normal operation
        min_lr (float): Lower bound on learning rate
    """
    def __init__(self, monitor='val_loss', factor=0.1, patience=10, verbose=0,
                 mode='auto', min_delta=1e-4, cooldown=0, min_lr=0, **kwargs): ...

class LearningRateScheduler:
    """
    Learning rate scheduler driven by a custom schedule function.
    
    Args:
        schedule (callable): Function taking (epoch index, current learning
            rate) and returning the new learning rate as a float
        verbose (int): Verbosity mode
    """
    def __init__(self, schedule, verbose=0, **kwargs): ...

class TerminateOnNaN:
    """Terminate training immediately when the loss becomes NaN."""
    def __init__(self, **kwargs): ...

Model Persistence Callbacks

Callbacks for saving model checkpoints and handling training state.

class ModelCheckpoint:
    """
    Save model checkpoints during training.
    
    Args:
        filepath (str): Path to save model files; may contain format
            placeholders such as '{epoch:02d}' or '{val_loss:.2f}'
        monitor (str): Metric to monitor for best model
        verbose (int): Verbosity mode
        save_best_only (bool): Only save when the monitored metric improves
        save_weights_only (bool): Only save model weights
        mode (str): 'auto', 'min', or 'max'
        save_freq (str or int): 'epoch' to save once per epoch, or an
            integer number of batches between saves
        options (SaveOptions, optional): Options for saving (accepted via
            **kwargs)
        initial_value_threshold (float, optional): Initial threshold for the
            monitored metric (accepted via **kwargs)
    """
    def __init__(self, filepath, monitor='val_loss', verbose=0, save_best_only=False,
                 save_weights_only=False, mode='auto', save_freq='epoch', **kwargs): ...

class BackupAndRestore:
    """
    Backup and restore training state for fault tolerance.
    
    If training is interrupted, restarting model.fit with this callback
    resumes from the most recent backup.
    
    Args:
        backup_dir (str): Directory to store backup files
        save_freq (str or int): Frequency to save backups ('epoch' or
            integer steps)
        delete_checkpoint (bool): Whether to delete the backup once training
            completes successfully
    """
    def __init__(self, backup_dir, save_freq='epoch', delete_checkpoint=True, **kwargs): ...

Logging and Monitoring Callbacks

Callbacks for logging training progress and monitoring metrics.

class History:
    """
    Record training history (automatically added to model.fit, which
    returns this object).
    
    Attributes:
        history (dict): Dictionary mapping metric names to per-epoch values
    """
    def __init__(self, **kwargs): ...

class CSVLogger:
    """
    Log per-epoch training metrics to a CSV file.
    
    Args:
        filename (str): Path to CSV file
        separator (str): Field separator character
        append (bool): Whether to append to an existing file instead of
            overwriting it
    """
    def __init__(self, filename, separator=',', append=False, **kwargs): ...

class TensorBoard:
    """
    Log training metrics for TensorBoard visualization.
    
    Args:
        log_dir (str): Directory to save TensorBoard log files
        histogram_freq (int): Frequency (in epochs) to compute activation
            histograms; 0 disables them
        write_graph (bool): Whether to visualize the computation graph
        write_images (bool): Whether to write model weights as images
        write_steps_per_second (bool): Whether to log training speed
        update_freq (str or int): Frequency to write logs ('batch', 'epoch',
            or integer number of batches)
        profile_batch (int or tuple): Batch(es) to profile for performance;
            0 disables profiling
        embeddings_freq (int): Frequency (in epochs) to save embeddings
        embeddings_metadata (dict, optional): Metadata for embeddings
            (accepted via **kwargs)
    """
    def __init__(self, log_dir='./logs', histogram_freq=0, write_graph=True,
                 write_images=False, write_steps_per_second=False, update_freq='epoch',
                 profile_batch=0, embeddings_freq=0, **kwargs): ...

class ProgbarLogger:
    """
    Display a training progress bar (automatically added to model.fit).
    
    Args:
        count_mode (str): 'steps' or 'samples' — unit the progress bar counts
        stateful_metrics (set, optional): Names of metrics that should be
            reported as-is rather than averaged over the epoch
    """
    def __init__(self, count_mode='samples', stateful_metrics=None, **kwargs): ...

class RemoteMonitor:
    """
    Send training events to a remote monitoring server over HTTP.
    
    Args:
        root (str): Root URL of monitoring server
        path (str): Path (relative to root) to send events to
        field (str): Form field name under which the data is sent
        headers (dict, optional): HTTP headers
        send_as_json (bool): Whether to send the payload as JSON instead of
            a form-encoded field
    """
    def __init__(self, root='http://localhost:9000', path='/publish/epoch/end/',
                 field='data', headers=None, send_as_json=False, **kwargs): ...

Utility Callbacks

General purpose and custom callbacks for specialized training scenarios.

class LambdaCallback:
    """
    Create a custom callback from plain functions/lambdas.
    
    Epoch hooks receive (epoch, logs), batch hooks receive (batch, logs),
    and train hooks receive (logs).
    
    Args:
        on_epoch_begin (callable, optional): Function called at epoch start
        on_epoch_end (callable, optional): Function called at epoch end
        on_batch_begin (callable, optional): Function called at batch start
        on_batch_end (callable, optional): Function called at batch end
        on_train_begin (callable, optional): Function called at training start
        on_train_end (callable, optional): Function called at training end
    """
    def __init__(self, on_epoch_begin=None, on_epoch_end=None, on_batch_begin=None,
                 on_batch_end=None, on_train_begin=None, on_train_end=None, **kwargs): ...

class SwapEMAWeights:
    """
    Swap Exponential Moving Average weights in for evaluation/saving.
    
    NOTE(review): intended for optimizers tracking EMA weights (use_ema) —
    confirm against the optimizer configuration in use.
    
    Args:
        swap_on_epoch (bool): Whether to swap weights at each epoch end
    """
    def __init__(self, swap_on_epoch=False, **kwargs): ...

Model Persistence Functions

Functions for saving and loading complete models or weights only.

def save_model(model, filepath, overwrite=True, save_format=None, **kwargs):
    """
    Save a complete model (architecture, weights, optimizer state) to file.
    
    Args:
        model: Keras model to save
        filepath (str): Path to save model
        overwrite (bool): Whether to overwrite an existing file
        save_format (str, optional): Format to save in ('tf', 'h5', or None
            to infer from the filepath extension)
        include_optimizer (bool): Whether to save optimizer state
            (accepted via **kwargs)
        save_traces (bool): Whether to save function traces
            (accepted via **kwargs)
        options (SaveOptions, optional): Platform-specific save options
            (accepted via **kwargs)
        signatures (callable or dict, optional): Model signatures for
            SavedModel (accepted via **kwargs)
    """

def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
    """
    Load a saved model from file.
    
    Args:
        filepath (str): Path to saved model
        custom_objects (dict, optional): Mapping of names to custom classes
            or functions needed for deserialization
        compile (bool): Whether to compile the loaded model with its saved
            training configuration
        safe_mode (bool): When True, disallow unsafe deserialization of
            arbitrary code (e.g. saved lambdas)
        
    Returns:
        Model: Loaded Keras model
    """

def save_weights(model, filepath, overwrite=True, save_format=None, options=None):
    """
    Save only the model's weights to file (no architecture or optimizer).
    
    Args:
        model: Keras model
        filepath (str): Path to save weights
        overwrite (bool): Whether to overwrite an existing file
        save_format (str, optional): Format to save in (inferred from the
            filepath extension when None)
        options (SaveOptions, optional): Platform-specific save options
    """

def load_weights(model, filepath, skip_mismatch=False, by_name=False, options=None):
    """
    Load model weights from file into an existing model.
    
    Args:
        model: Keras model (architecture must match the saved weights unless
            skip_mismatch/by_name is used)
        filepath (str): Path to saved weights
        skip_mismatch (bool): Whether to skip layers with mismatched
            weight shapes instead of raising
        by_name (bool): Whether to match weights to layers by layer name
            rather than by topological order
        options (SaveOptions, optional): Platform-specific load options
    """

Base Callback Class

Base class for creating custom callbacks.

class Callback:
    """
    Base class for callbacks.
    
    Subclass and override any of the hook methods below; Keras invokes them
    at the corresponding points of fit/evaluate/predict. `logs` is a dict of
    metric results (may be None).
    
    Attributes:
        params (dict): Training parameters
        model (Model): Reference to training model
    """
    def __init__(self, **kwargs): ...
    
    def set_params(self, params): ...  # called by Keras to attach training params
    def set_model(self, model): ...  # called by Keras to attach the model
    
    def on_train_begin(self, logs=None): ...  # start of fit()
    def on_train_end(self, logs=None): ...  # end of fit()
    def on_epoch_begin(self, epoch, logs=None): ...  # start of each epoch
    def on_epoch_end(self, epoch, logs=None): ...  # end of each epoch
    def on_train_batch_begin(self, batch, logs=None): ...  # before each training batch
    def on_train_batch_end(self, batch, logs=None): ...  # after each training batch
    def on_test_batch_begin(self, batch, logs=None): ...  # before each evaluation batch
    def on_test_batch_end(self, batch, logs=None): ...  # after each evaluation batch
    def on_predict_batch_begin(self, batch, logs=None): ...  # before each prediction batch
    def on_predict_batch_end(self, batch, logs=None): ...  # after each prediction batch

Usage Examples

Basic Training with Callbacks

import keras
from keras import layers, callbacks

# Build a small classifier
# (x_train, y_train, x_val, y_val are assumed to be defined elsewhere)
model = keras.Sequential([
    layers.Dense(64, activation='relu', input_shape=(784,)),
    layers.Dropout(0.2),
    layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Configure callbacks: stop early, keep the best checkpoint, lower the
# learning rate on plateaus, and log to TensorBoard
callback_list = [
    callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True),
    callbacks.ModelCheckpoint('best_model.keras', save_best_only=True),
    callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3),
    callbacks.TensorBoard(log_dir='./logs')
]

# Train with callbacks
history = model.fit(
    x_train, y_train,
    epochs=100,
    validation_data=(x_val, y_val),
    callbacks=callback_list
)

Custom Callback

import keras
from keras import callbacks
import numpy as np

class ValidationMetrics(callbacks.Callback):
    """Compute and report a custom validation accuracy at each epoch end.

    Args:
        validation_data: Tuple (val_x, val_y) of validation inputs and
            integer class labels.
    """
    def __init__(self, validation_data, **kwargs):
        super().__init__(**kwargs)
        self.validation_data = validation_data

    def on_epoch_end(self, epoch, logs=None):
        """Predict on the validation set and record accuracy in `logs`."""
        val_x, val_y = self.validation_data
        predictions = self.model.predict(val_x, verbose=0)

        # Fraction of samples whose argmax prediction matches the label.
        accuracy = np.mean(np.argmax(predictions, axis=1) == val_y)
        print(f'Custom validation accuracy: {accuracy:.4f}')

        # Mutate the logs dict in place so the metric is visible to other
        # callbacks and to History. (The previous `logs = logs or {}` rebound
        # the local name to a fresh dict, silently discarding the metric
        # whenever Keras passed None or an empty dict.)
        if logs is not None:
            logs['custom_val_acc'] = accuracy

# Use the custom callback during training
# (model, x_train, y_train, x_val, y_val are assumed to be defined elsewhere)
custom_callback = ValidationMetrics((x_val, y_val))
model.fit(x_train, y_train, epochs=10, callbacks=[custom_callback])

Learning Rate Scheduling

import keras
from keras import callbacks
import math

def step_decay(epoch, lr):
    """Halve the learning rate once every 10 epochs (step decay).

    Args:
        epoch: Zero-based epoch index.
        lr: Current learning rate.

    Returns:
        The decayed learning rate: lr * 0.5 ** (epoch // 10).
    """
    # Integer division counts how many complete 10-epoch drops have elapsed.
    completed_drops = epoch // 10
    return lr * (0.5 ** completed_drops)

def cosine_decay(epoch, lr):
    """Cosine annealing from 0.001 down to 0 over 100 epochs.

    Note: the current learning rate ``lr`` is intentionally ignored — the
    schedule is a pure function of the epoch index.
    """
    initial_rate = 0.001
    total_epochs = 100
    progress = epoch / total_epochs
    return initial_rate * 0.5 * (1 + math.cos(math.pi * progress))

# Attach the step-decay schedule (swap in cosine_decay to try annealing)
lr_scheduler = callbacks.LearningRateScheduler(step_decay, verbose=1)

model.fit(
    x_train, y_train,
    epochs=50,
    validation_data=(x_val, y_val),
    callbacks=[lr_scheduler]
)

Model Checkpointing Strategy

import keras
from keras import callbacks

# Save the best model (by validation loss); the filepath template embeds
# the epoch number and that epoch's val_loss
checkpoint_best = callbacks.ModelCheckpoint(
    filepath='models/best_model_{epoch:02d}_{val_loss:.2f}.keras',
    monitor='val_loss',
    save_best_only=True,
    save_weights_only=False,
    verbose=1
)

# Save a checkpoint every 5 batches — an integer save_freq counts batches,
# not epochs; use save_freq='epoch' for per-epoch saving
checkpoint_regular = callbacks.ModelCheckpoint(
    filepath='models/model_epoch_{epoch:02d}.keras',
    save_freq=5,
    verbose=1
)

# Backup and restore for fault tolerance
backup_restore = callbacks.BackupAndRestore(backup_dir='./backup')

model.fit(
    x_train, y_train,
    epochs=100,
    validation_data=(x_val, y_val),
    callbacks=[checkpoint_best, checkpoint_regular, backup_restore]
)

Install with Tessl CLI

npx tessl i tessl/pypi-keras-nightly

docs

activations.md

applications.md

backend-config.md

core-framework.md

index.md

initializers.md

layers.md

losses-metrics.md

operations.md

optimizers.md

preprocessing.md

regularizers.md

training-callbacks.md

tile.json