Keras: a multi-backend deep learning framework providing a unified API for building and training neural networks on JAX, TensorFlow, and PyTorch, with OpenVINO available as an inference-only backend.
—
Training utilities: callbacks for monitoring and controlling the training process, plus model-persistence functions for saving and loading models during and after training.
Callbacks that control the training process based on monitored metrics.
class EarlyStopping:
    """
    Halt training once a monitored quantity has stopped improving.
    Args:
        monitor (str): Name of the logged quantity to watch
        min_delta (float): Smallest change in the monitored quantity that still
            counts as an improvement
        patience (int): How many epochs without improvement to tolerate before
            stopping
        verbose (int): Verbosity mode
        mode (str): One of 'auto', 'min', or 'max' -- the direction in which the
            monitored quantity is expected to improve
        baseline (float, optional): Reference value the monitored quantity must
            improve on
        restore_best_weights (bool): If True, restore the weights from the best
            epoch when training stops
        start_from_epoch (int): Number of initial (warm-up) epochs before
            improvement starts being monitored
    """
    def __init__(
        self,
        monitor='val_loss',
        min_delta=0,
        patience=0,
        verbose=0,
        mode='auto',
        baseline=None,
        restore_best_weights=False,
        start_from_epoch=0,
        **kwargs,
    ): ...
class ReduceLROnPlateau:
    """
    Lower the learning rate when a monitored quantity plateaus.
    Args:
        monitor (str): Name of the logged quantity to watch
        factor (float): Multiplier applied to the learning rate on each
            reduction (new_lr = lr * factor)
        patience (int): Epochs without improvement before the rate is reduced
        verbose (int): Verbosity mode
        mode (str): One of 'auto', 'min', or 'max'
        min_delta (float): Threshold a change must exceed to count as an
            improvement
        cooldown (int): Epochs to wait after a reduction before resuming normal
            operation
        min_lr (float): Floor below which the learning rate is never reduced
    """
    def __init__(
        self,
        monitor='val_loss',
        factor=0.1,
        patience=10,
        verbose=0,
        mode='auto',
        min_delta=1e-4,
        cooldown=0,
        min_lr=0,
        **kwargs,
    ): ...
class LearningRateScheduler:
    """
    Learning rate scheduler with custom schedule function.

    At the beginning of every epoch the callback calls
    ``schedule(epoch, lr)`` and applies the value it returns to the optimizer.

    Args:
        schedule (callable): Function ``schedule(epoch, lr) -> float`` that takes
            the (0-based) epoch index and the optimizer's *current* learning
            rate, and returns the learning rate to use for that epoch. Because
            the current rate is passed in, a schedule that multiplies ``lr`` by
            a factor on every call compounds that factor across epochs.
        verbose (int): 0 = quiet, 1 = print a message when the learning rate
            is updated
    """
    def __init__(self, schedule, verbose=0, **kwargs): ...
class TerminateOnNaN:
    """Terminate training as soon as the loss becomes NaN (checked per batch)."""
    def __init__(self, **kwargs): ...


# Callbacks for saving model checkpoints and handling training state.
class ModelCheckpoint:
    """
    Save model checkpoints during training.
    Args:
        filepath (str): Path to save model files. May contain format fields that
            are filled from the epoch number and the logged metrics, e.g.
            ``'model_{epoch:02d}_{val_loss:.2f}.keras'``. In Keras 3 it must end
            in ``.keras``, or in ``.weights.h5`` when ``save_weights_only=True``.
        monitor (str): Metric to monitor when deciding which model is "best"
        verbose (int): 0 = silent, 1 = print a message each time a checkpoint
            is saved
        save_best_only (bool): Only save when the monitored metric improves
        save_weights_only (bool): Save only the weights instead of the full model
        mode (str): 'auto', 'min', or 'max' -- whether an improvement means the
            monitored value going down ('min') or up ('max'); 'auto' infers the
            direction from the metric name
        save_freq (str or int): 'epoch' to save at the end of every epoch, or an
            integer N to save after every N *batches* (not epochs)
        initial_value_threshold (float, optional): Initial "best" value of the
            monitored metric; with ``save_best_only=True`` a model is only saved
            once it beats this value
    Note:
        tf.keras 2.x also accepted an ``options`` (SaveOptions) argument; Keras 3
        removed it.
    """
    def __init__(self, filepath, monitor='val_loss', verbose=0,
                 save_best_only=False, save_weights_only=False, mode='auto',
                 save_freq='epoch', initial_value_threshold=None, **kwargs): ...
class BackupAndRestore:
    """
    Backup and restore training state for fault tolerance.

    If training is interrupted, a later call to ``fit`` with the same callback
    resumes from the last backup instead of starting over.

    Args:
        backup_dir (str): Directory to store backup files; it should be
            dedicated to this callback and not shared with other runs or
            callbacks
        save_freq (str or int): 'epoch' to back up at the end of every epoch, or
            an integer N to back up every N batches
        delete_checkpoint (bool): Whether to delete the backup once training
            finishes; set to False to keep it after training
    """
    def __init__(self, backup_dir, save_freq='epoch', delete_checkpoint=True, **kwargs): ...


# Callbacks for logging training progress and monitoring metrics.
class History:
    """
    Collect per-epoch training metrics; model.fit attaches one automatically.
    Attributes:
        history (dict): Maps each metric name to its values, one per epoch
    """
    def __init__(
        self,
        **kwargs,
    ): ...
class CSVLogger:
    """
    Stream per-epoch training results to a CSV file.
    Args:
        filename (str): Destination CSV path
        separator (str): String placed between fields in each row
        append (bool): If True, add to an existing file instead of
            overwriting it
    """
    def __init__(
        self, filename, separator=',', append=False, **kwargs
    ): ...
class TensorBoard:
    """
    Log training metrics for TensorBoard visualization.
    Args:
        log_dir (str): Directory to save TensorBoard log files
        histogram_freq (int): Frequency (in epochs) at which to compute weight
            histograms; 0 disables them. Histograms require validation data.
        write_graph (bool): Whether to visualize the computation graph
        write_images (bool): Whether to write model weights as images
        write_steps_per_second (bool): Whether to log training speed
        update_freq (str or int): When to write metrics: 'epoch', 'batch', or an
            integer N to write every N batches
        profile_batch (int or tuple): Batch, or (start, stop) range of batches,
            to profile; 0 disables profiling
        embeddings_freq (int): Frequency (in epochs) at which to save
            embeddings; 0 disables it
        embeddings_metadata (dict or str, optional): Maps embedding layer names
            to metadata file names, or a single file name used for all
            embedding layers
    """
    def __init__(self, log_dir='./logs', histogram_freq=0, write_graph=True,
                 write_images=False, write_steps_per_second=False,
                 update_freq='epoch', profile_batch=0, embeddings_freq=0,
                 embeddings_metadata=None, **kwargs): ...
class ProgbarLogger:
    """
    Print a progress bar while training; model.fit adds one automatically.
    Args:
        count_mode (str): Whether progress is counted in 'steps' or 'samples'
        stateful_metrics (set, optional): Names of metrics that are reported
            as-is instead of being averaged
    """
    def __init__(
        self,
        count_mode='samples',
        stateful_metrics=None,
        **kwargs,
    ): ...
class RemoteMonitor:
    """
    Send training events to a remote monitoring server.
    Args:
        root (str): Root URL of the monitoring server
        path (str): Path, relative to ``root``, that events are sent to
        field (str): Form field under which the data is stored (used when
            ``send_as_json`` is False)
        headers (dict, optional): Extra HTTP headers
        send_as_json (bool): Whether to send the payload as JSON instead of
            form data
    """
    def __init__(self, root='http://localhost:9000', path='/publish/epoch/end/',
                 field='data', headers=None, send_as_json=False, **kwargs): ...


# General-purpose and custom callbacks for specialized training scenarios.
class LambdaCallback:
    """
    Build an ad-hoc callback from plain functions instead of a subclass.
    Args:
        on_epoch_begin (callable, optional): Invoked when an epoch starts
        on_epoch_end (callable, optional): Invoked when an epoch finishes
        on_batch_begin (callable, optional): Invoked when a batch starts
        on_batch_end (callable, optional): Invoked when a batch finishes
        on_train_begin (callable, optional): Invoked once before training
        on_train_end (callable, optional): Invoked once after training
    """
    def __init__(
        self,
        on_epoch_begin=None,
        on_epoch_end=None,
        on_batch_begin=None,
        on_batch_end=None,
        on_train_begin=None,
        on_train_end=None,
        **kwargs,
    ): ...
class SwapEMAWeights:
    """
    Swap Exponential Moving Average weights for evaluation.

    Requires an optimizer created with ``use_ema=True``.

    Args:
        swap_on_epoch (bool): Whether to also swap at epoch begin/end, so that
            other callbacks (e.g. ModelCheckpoint) see the EMA weights
    """
    def __init__(self, swap_on_epoch=False, **kwargs): ...


# Functions for saving and loading complete models or weights only.
def save_model(model, filepath, overwrite=True, save_format=None, **kwargs):
    """
    Save a complete model (architecture, weights and, by default, optimizer
    state) to file.
    Args:
        model: Keras model to save
        filepath (str): Path to save the model. In Keras 3 the format is inferred
            from the extension: ``.keras`` for the native format (recommended)
            or ``.h5`` / ``.hdf5`` for legacy HDF5.
        overwrite (bool): Whether to overwrite an existing file
        save_format (str, optional): Deprecated in Keras 3 -- omit it and rely on
            the file extension. The TensorFlow SavedModel format ('tf') is not
            produced by this function; use ``model.export()`` for SavedModel
            artifacts.
        **kwargs: ``include_optimizer`` (bool) controls whether optimizer state
            is saved. ``save_traces``, ``options`` and ``signatures`` are legacy
            tf.keras SavedModel arguments that the native ``.keras`` format does
            not accept.
    """
def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
    """
    Restore a Keras model from a saved file.
    Args:
        filepath (str): Location of the saved model
        custom_objects (dict, optional): Mapping from names to the custom
            classes or functions needed during deserialization
        compile (bool): If True, compile the model after it is loaded
        safe_mode (bool): If True, refuse unsafe deserialization (such as
            lambdas); only disable it for files from a trusted source
    Returns:
        Model: The restored Keras model
    """
def save_weights(model, filepath, overwrite=True, save_format=None, options=None):
    """
    Write only a model's weights (not the full model) to file.
    Args:
        model: Keras model whose weights are written
        filepath (str): Destination path for the weights
        overwrite (bool): If True, replace a file that already exists
        save_format (str, optional): Format to write the weights in
        options (SaveOptions, optional): Platform-specific save options
    """
def load_weights(model, filepath, skip_mismatch=False, by_name=False, options=None):
    """
    Load model weights from file into ``model``.
    Args:
        model: Keras model that receives the weights
        filepath (str): Path to saved weights
        skip_mismatch (bool): Whether to skip layers whose weight shapes do not
            match the saved ones
        by_name (bool): Whether to match saved weights to layers by layer name
        options (SaveOptions, optional): Platform-specific load options
    """


# Base class for creating custom callbacks.
class Callback:
    """
    Base class for callbacks.

    Subclass it and override any of the ``on_*`` hooks; in this base class every
    hook is a no-op.

    Attributes:
        params (dict): Training parameters
        model (Model): Reference to training model
    """
    def __init__(self, **kwargs): ...
    def set_params(self, params): ...
    def set_model(self, model): ...

    # Whole-run hooks.
    def on_train_begin(self, logs=None): ...
    def on_train_end(self, logs=None): ...

    # Per-epoch hooks.
    def on_epoch_begin(self, epoch, logs=None): ...
    def on_epoch_end(self, epoch, logs=None): ...

    # Per-batch hooks for training, evaluation (including validation during
    # fit), and prediction.
    def on_train_batch_begin(self, batch, logs=None): ...
    def on_train_batch_end(self, batch, logs=None): ...
    def on_test_batch_begin(self, batch, logs=None): ...
    def on_test_batch_end(self, batch, logs=None): ...
    def on_predict_batch_begin(self, batch, logs=None): ...
    def on_predict_batch_end(self, batch, logs=None): ...


import keras
from keras import layers, callbacks
# Build model. Keras 3 prefers an explicit ``keras.Input`` as the first entry of
# a Sequential model; passing ``input_shape=`` to a layer is deprecated and
# emits a warning.
model = keras.Sequential([
    keras.Input(shape=(784,)),
    layers.Dense(64, activation='relu'),
    layers.Dropout(0.2),
    layers.Dense(10, activation='softmax'),
])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Configure callbacks. Every callback below watches 'val_loss', so fit() must
# be given validation data.
callback_list = [
    # Stop after 5 stagnant epochs and roll back to the best weights seen.
    callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True),
    # Keep only the best model (ModelCheckpoint monitors 'val_loss' by default).
    callbacks.ModelCheckpoint('best_model.keras', save_best_only=True),
    # Halve the learning rate after 3 stagnant epochs. This patience is shorter
    # than EarlyStopping's, so the rate can drop before training is stopped.
    callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3),
    callbacks.TensorBoard(log_dir='./logs'),
]

# Train with callbacks. x_train / y_train / x_val / y_val are assumed to be
# defined elsewhere (784-feature inputs, integer class labels).
history = model.fit(
    x_train, y_train,
    epochs=100,
    validation_data=(x_val, y_val),
    callbacks=callback_list,
)

import keras
from keras import callbacks
import numpy as np
class ValidationMetrics(callbacks.Callback):
    """Print and log a custom classification accuracy after every epoch.

    Args:
        validation_data (tuple): ``(val_x, val_y)``. ``val_y`` holds integer
            class labels (they are compared with ``argmax`` of the predictions)
            and may be shaped ``(N,)`` or ``(N, 1)``.
        **kwargs: Forwarded to ``callbacks.Callback``.
    """

    def __init__(self, validation_data, **kwargs):
        super().__init__(**kwargs)
        self.validation_data = validation_data

    def on_epoch_end(self, epoch, logs=None):
        val_x, val_y = self.validation_data
        predictions = self.model.predict(val_x, verbose=0)

        # Calculate custom metrics. Flatten the labels first: comparing an (N,)
        # array with (N, 1) labels would broadcast to an (N, N) matrix and
        # silently produce a wrong accuracy.
        predicted_classes = np.argmax(predictions, axis=1)
        true_classes = np.asarray(val_y).reshape(-1)
        accuracy = float(np.mean(predicted_classes == true_classes))
        print(f'Custom validation accuracy: {accuracy:.4f}')

        # Log custom metrics by mutating the dict Keras passed in, so callbacks
        # that read ``logs`` afterwards see the new entry. ``logs or {}`` would
        # swap an *empty* dict for a fresh one and the value would be lost.
        if logs is not None:
            logs['custom_val_acc'] = accuracy
# Use custom callback. It evaluates on its own copy of the validation data, so
# fit() does not need validation_data here.
custom_callback = ValidationMetrics((x_val, y_val))
model.fit(x_train, y_train, epochs=10, callbacks=[custom_callback])

import keras
from keras import callbacks
import math
def step_decay(epoch, lr, drop_rate=0.5, epochs_drop=10):
    """Step decay schedule.

    Multiplies the learning rate by ``drop_rate`` once every ``epochs_drop``
    epochs, so epoch ``e`` trains at ``lr0 * drop_rate ** (e // epochs_drop)``.

    ``LearningRateScheduler`` passes the *current*, already-decayed rate as
    ``lr``. The drop is therefore applied only on the first epoch of each new
    step; rescaling ``lr`` by the cumulative factor on every epoch would
    compound the decay.

    This relies on being called once per epoch, in order, starting from epoch 0.
    It does not reconstruct drops for skipped epochs (e.g. when resuming with
    ``initial_epoch > 0`` and a freshly initialized optimizer).

    Args:
        epoch (int): Zero-based epoch index.
        lr (float): Current learning rate.
        drop_rate (float): Factor applied at each step boundary.
        epochs_drop (int): Number of epochs per step.

    Returns:
        float: Learning rate for ``epoch``.
    """
    starts_new_step = epoch > 0 and epoch % epochs_drop == 0
    return lr * (drop_rate if starts_new_step else 1.0)
def cosine_decay(epoch, lr, initial_lr=0.001, max_epochs=100):
    """Cosine annealing schedule.

    Follows half a cosine period from ``initial_lr`` at epoch 0 down to 0 at
    ``max_epochs``, then holds at 0. The value depends only on ``epoch``; the
    current ``lr`` supplied by ``LearningRateScheduler`` is ignored.

    Args:
        epoch (int): Zero-based epoch index.
        lr (float): Current learning rate (unused; required by the callback's
            ``schedule(epoch, lr)`` signature).
        initial_lr (float): Learning rate at epoch 0.
        max_epochs (int): Epoch at which the learning rate reaches 0.

    Returns:
        float: Learning rate for ``epoch``.
    """
    import math

    # Clamp the epoch: cos is periodic, so without this the rate would climb
    # back up once training runs past max_epochs.
    progress = min(epoch, max_epochs) / max_epochs
    return initial_lr * 0.5 * (1 + math.cos(math.pi * progress))
# Use scheduling callback (verbose=1 prints each learning-rate update).
lr_scheduler = callbacks.LearningRateScheduler(step_decay, verbose=1)
model.fit(
    x_train, y_train,
    epochs=50,
    validation_data=(x_val, y_val),
    callbacks=[lr_scheduler],
)

import keras
from keras import callbacks
import math

# Save best model based on validation loss. The filename template is filled in
# from the epoch number and the logged 'val_loss' at save time.
checkpoint_best = callbacks.ModelCheckpoint(
    filepath='models/best_model_{epoch:02d}_{val_loss:.2f}.keras',
    monitor='val_loss',
    save_best_only=True,
    save_weights_only=False,
    verbose=1,
)

# Save the model every 5 epochs. An integer save_freq counts *batches*, not
# epochs, so convert epochs to steps using the batch size passed to fit().
batch_size = 32  # same as fit()'s default batch size
steps_per_epoch = math.ceil(len(x_train) / batch_size)
checkpoint_regular = callbacks.ModelCheckpoint(
    filepath='models/model_epoch_{epoch:02d}.keras',
    save_freq=5 * steps_per_epoch,
    verbose=1,
)

# Backup and restore for fault tolerance: an interrupted fit() resumes from
# the last backup.
backup_restore = callbacks.BackupAndRestore(backup_dir='./backup')

model.fit(
    x_train, y_train,
    epochs=100,
    batch_size=batch_size,
    validation_data=(x_val, y_val),
    callbacks=[checkpoint_best, checkpoint_regular, backup_restore],
)

# Install with Tessl CLI
npx tessl i tessl/pypi-keras-nightly