0
# Training and Callbacks
1
2
Training utilities, callbacks for monitoring and controlling training processes, and model persistence functionality for saving and loading models during and after training.
3
4
## Capabilities
5
6
### Training Control Callbacks
7
8
Callbacks that control the training process based on monitored metrics.
9
10
```python { .api }
11
class EarlyStopping:
12
"""
13
Stop training when monitored metric stops improving.
14
15
Args:
16
monitor (str): Metric to monitor
17
min_delta (float): Minimum change to qualify as improvement
18
patience (int): Number of epochs with no improvement to wait
19
verbose (int): Verbosity mode
20
mode (str): 'auto', 'min', or 'max'
21
baseline (float, optional): Baseline value for the monitored metric; training stops if the model shows no improvement over it
22
restore_best_weights (bool): Whether to restore model weights from the epoch with the best value of the monitored metric
23
start_from_epoch (int): Epoch to start monitoring from
24
"""
25
def __init__(self, monitor='val_loss', min_delta=0, patience=0, verbose=0,
26
mode='auto', baseline=None, restore_best_weights=False,
27
start_from_epoch=0, **kwargs): ...
28
29
class ReduceLROnPlateau:
30
"""
31
Reduce learning rate when metric stops improving.
32
33
Args:
34
monitor (str): Metric to monitor
35
factor (float): Factor to reduce learning rate by
36
patience (int): Number of epochs with no improvement to wait
37
verbose (int): Verbosity mode
38
mode (str): 'auto', 'min', or 'max'
39
min_delta (float): Minimum change to qualify as improvement
40
cooldown (int): Number of epochs to wait after a learning rate reduction before resuming normal operation
41
min_lr (float): Lower bound on learning rate
42
"""
43
def __init__(self, monitor='val_loss', factor=0.1, patience=10, verbose=0,
44
mode='auto', min_delta=1e-4, cooldown=0, min_lr=0, **kwargs): ...
45
46
class LearningRateScheduler:
47
"""
48
Learning rate scheduler with custom schedule function.
49
50
Args:
51
schedule (callable): Function that takes the epoch index (int, starting at 0) and the current learning rate (float) and returns the new learning rate (float)
52
verbose (int): Verbosity mode
53
"""
54
def __init__(self, schedule, verbose=0, **kwargs): ...
55
56
class TerminateOnNaN:
57
"""Terminate training when loss becomes NaN."""
58
def __init__(self, **kwargs): ...
59
```
60
61
### Model Persistence Callbacks
62
63
Callbacks for saving model checkpoints and handling training state.
64
65
```python { .api }
66
class ModelCheckpoint:
67
"""
68
Save model checkpoints during training.
69
70
Args:
71
filepath (str): Path to save model files
72
monitor (str): Metric to monitor for best model
73
verbose (int): Verbosity mode
74
save_best_only (bool): Only save when model improves
75
save_weights_only (bool): Only save model weights
76
mode (str): 'auto', 'min', or 'max'
77
save_freq (str or int): Frequency to save ('epoch' to save after every epoch, or an integer N to save every N batches)
78
options (SaveOptions, optional): Options for saving
79
initial_value_threshold (float, optional): Initial "best" value of the monitored metric; only applies when save_best_only=True
80
"""
81
def __init__(self, filepath, monitor='val_loss', verbose=0, save_best_only=False,
82
save_weights_only=False, mode='auto', save_freq='epoch', **kwargs): ...
83
84
class BackupAndRestore:
85
"""
86
Backup and restore training state for fault tolerance.
87
88
Args:
89
backup_dir (str): Directory to store backup files
90
save_freq (str or int): Frequency to save backups
91
delete_checkpoint (bool): Whether to delete the backup checkpoint once training has finished
92
"""
93
def __init__(self, backup_dir, save_freq='epoch', delete_checkpoint=True, **kwargs): ...
94
```
95
96
### Logging and Monitoring Callbacks
97
98
Callbacks for logging training progress and monitoring metrics.
99
100
```python { .api }
101
class History:
102
"""
103
Record training history (automatically added to model.fit).
104
105
Attributes:
106
history (dict): Dictionary containing training metrics by epoch
107
"""
108
def __init__(self, **kwargs): ...
109
110
class CSVLogger:
111
"""
112
Log training progress to CSV file.
113
114
Args:
115
filename (str): Path to CSV file
116
separator (str): Field separator
117
append (bool): Whether to append to existing file
118
"""
119
def __init__(self, filename, separator=',', append=False, **kwargs): ...
120
121
class TensorBoard:
122
"""
123
Log training metrics for TensorBoard visualization.
124
125
Args:
126
log_dir (str): Directory to save TensorBoard log files
127
histogram_freq (int): Frequency (in epochs) to compute weight histograms for the model's layers; 0 disables histograms
128
write_graph (bool): Whether to visualize computation graph
129
write_images (bool): Whether to write model weights as images
130
write_steps_per_second (bool): Whether to log training speed
131
update_freq (str or int): Frequency to write logs ('batch', 'epoch', or integer)
132
profile_batch (int or tuple): Batch(es) to profile for performance
133
embeddings_freq (int): Frequency to save embeddings
134
embeddings_metadata (dict, optional): Metadata for embeddings
135
"""
136
def __init__(self, log_dir='./logs', histogram_freq=0, write_graph=True,
137
write_images=False, write_steps_per_second=False, update_freq='epoch',
138
profile_batch=0, embeddings_freq=0, **kwargs): ...
139
140
class ProgbarLogger:
141
"""
142
Display training progress bar (automatically added to model.fit).
143
144
Args:
145
count_mode (str): 'steps' or 'samples'
146
stateful_metrics (set, optional): Metrics that shouldn't be averaged
147
"""
148
def __init__(self, count_mode='samples', stateful_metrics=None, **kwargs): ...
149
150
class RemoteMonitor:
151
"""
152
Send training events to remote monitoring server.
153
154
Args:
155
root (str): Root URL of monitoring server
156
path (str): Path to send events to
157
field (str): Field name for data
158
headers (dict, optional): HTTP headers
159
send_as_json (bool): Whether to send data as JSON
160
"""
161
def __init__(self, root='http://localhost:9000', path='/publish/epoch/end/',
162
field='data', headers=None, send_as_json=False, **kwargs): ...
163
```
164
165
### Utility Callbacks
166
167
General purpose and custom callbacks for specialized training scenarios.
168
169
```python { .api }
170
class LambdaCallback:
171
"""
172
Create custom callback using lambda functions.
173
174
Args:
175
on_epoch_begin (callable, optional): Function called at epoch start
176
on_epoch_end (callable, optional): Function called at epoch end
177
on_batch_begin (callable, optional): Function called at batch start
178
on_batch_end (callable, optional): Function called at batch end
179
on_train_begin (callable, optional): Function called at training start
180
on_train_end (callable, optional): Function called at training end
181
"""
182
def __init__(self, on_epoch_begin=None, on_epoch_end=None, on_batch_begin=None,
183
on_batch_end=None, on_train_begin=None, on_train_end=None, **kwargs): ...
184
185
class SwapEMAWeights:
186
"""
187
Swap Exponential Moving Average weights for evaluation.
188
189
Args:
190
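swap_on_epoch (bool): Whether to swap weights at the beginning and end of each epoch (useful when other callbacks such as ModelCheckpoint should see the EMA weights)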
191
"""
192
def __init__(self, swap_on_epoch=False, **kwargs): ...
193
```
194
195
### Model Persistence Functions
196
197
Functions for saving and loading complete models or weights only.
198
199
```python { .api }
200
def save_model(model, filepath, overwrite=True, save_format=None, **kwargs):
201
"""
202
Save complete model to file.
203
204
Args:
205
model: Keras model to save
206
filepath (str): Path to save model
207
overwrite (bool): Whether to overwrite existing file
208
save_format (str, optional): Format to save in ('tf', 'h5', or None for auto)
209
include_optimizer (bool): Whether to save optimizer state
210
save_traces (bool): Whether to save function traces
211
options (SaveOptions, optional): Platform-specific save options
212
signatures (callable or dict, optional): Model signatures for SavedModel
213
"""
214
215
def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
216
"""
217
Load saved model from file.
218
219
Args:
220
filepath (str): Path to saved model
221
custom_objects (dict, optional): Custom objects for deserialization
222
compile (bool): Whether to compile loaded model
223
safe_mode (bool): Whether to load in safe mode
224
225
Returns:
226
Model: Loaded Keras model
227
"""
228
229
def save_weights(model, filepath, overwrite=True, save_format=None, options=None):
230
"""
231
Save model weights to file.
232
233
Args:
234
model: Keras model
235
filepath (str): Path to save weights
236
overwrite (bool): Whether to overwrite existing file
237
save_format (str, optional): Format to save in
238
options (SaveOptions, optional): Platform-specific save options
239
"""
240
241
def load_weights(model, filepath, skip_mismatch=False, by_name=False, options=None):
242
"""
243
Load model weights from file.
244
245
Args:
246
model: Keras model
247
filepath (str): Path to saved weights
248
skip_mismatch (bool): Whether to skip layers with mismatched shapes
249
by_name (bool): Whether to load weights by layer name
250
options (SaveOptions, optional): Platform-specific load options
251
"""
252
```
253
254
### Base Callback Class
255
256
Base class for creating custom callbacks.
257
258
```python { .api }
259
class Callback:
260
"""
261
Base class for callbacks.
262
263
Attributes:
264
params (dict): Training parameters
265
model (Model): Reference to training model
266
"""
267
def __init__(self, **kwargs): ...
268
269
def set_params(self, params): ...
270
def set_model(self, model): ...
271
272
def on_train_begin(self, logs=None): ...
273
def on_train_end(self, logs=None): ...
274
def on_epoch_begin(self, epoch, logs=None): ...
275
def on_epoch_end(self, epoch, logs=None): ...
276
def on_train_batch_begin(self, batch, logs=None): ...
277
def on_train_batch_end(self, batch, logs=None): ...
278
def on_test_batch_begin(self, batch, logs=None): ...
279
def on_test_batch_end(self, batch, logs=None): ...
280
def on_predict_batch_begin(self, batch, logs=None): ...
281
def on_predict_batch_end(self, batch, logs=None): ...
282
```
283
284
## Usage Examples
285
286
### Basic Training with Callbacks
287
288
```python
289
import keras
290
from keras import layers, callbacks
291
292
# Build model
293
model = keras.Sequential([
294
layers.Dense(64, activation='relu', input_shape=(784,)),
295
layers.Dropout(0.2),
296
layers.Dense(10, activation='softmax')
297
])
298
299
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
300
301
# Configure callbacks
302
callback_list = [
303
callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True),
304
callbacks.ModelCheckpoint('best_model.keras', save_best_only=True),
305
callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3),
306
callbacks.TensorBoard(log_dir='./logs')
307
]
308
309
# Train with callbacks
310
history = model.fit(
311
x_train, y_train,
312
epochs=100,
313
validation_data=(x_val, y_val),
314
callbacks=callback_list
315
)
316
```
317
318
### Custom Callback
319
320
```python
321
import keras
322
from keras import callbacks
323
import numpy as np
324
325
class ValidationMetrics(callbacks.Callback):
326
def __init__(self, validation_data, **kwargs):
327
super().__init__(**kwargs)
328
self.validation_data = validation_data
329
330
def on_epoch_end(self, epoch, logs=None):
331
val_x, val_y = self.validation_data
332
predictions = self.model.predict(val_x, verbose=0)
333
334
# Calculate custom metrics
335
accuracy = np.mean(np.argmax(predictions, axis=1) == val_y)
336
print(f'Custom validation accuracy: {accuracy:.4f}')
337
338
# Log custom metrics
339
logs = logs or {}
340
logs['custom_val_acc'] = accuracy
341
342
# Use custom callback
343
custom_callback = ValidationMetrics((x_val, y_val))
344
model.fit(x_train, y_train, epochs=10, callbacks=[custom_callback])
345
```
346
347
### Learning Rate Scheduling
348
349
```python
350
import keras
351
from keras import callbacks
352
import math
353
354
def step_decay(epoch, lr):
355
"""Step decay schedule."""
356
drop_rate = 0.5
357
epochs_drop = 10
358
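# `lr` is the current (already decayed) rate, so scaling it would compound the
# decay every epoch; scale a fixed base rate instead.
initial_lr = 0.001
return initial_lr * math.pow(drop_rate, math.floor(epoch / epochs_drop))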
359
360
def cosine_decay(epoch, lr):
361
"""Cosine annealing schedule."""
362
max_epochs = 100
363
return 0.001 * 0.5 * (1 + math.cos(math.pi * epoch / max_epochs))
364
365
# Use scheduling callback
366
lr_scheduler = callbacks.LearningRateScheduler(step_decay, verbose=1)
367
368
model.fit(
369
x_train, y_train,
370
epochs=50,
371
validation_data=(x_val, y_val),
372
callbacks=[lr_scheduler]
373
)
374
```
375
376
### Model Checkpointing Strategy
377
378
```python
379
import keras
380
from keras import callbacks
381
382
# Save best model based on validation loss
383
checkpoint_best = callbacks.ModelCheckpoint(
384
filepath='models/best_model_{epoch:02d}_{val_loss:.2f}.keras',
385
monitor='val_loss',
386
save_best_only=True,
387
save_weights_only=False,
388
verbose=1
389
)
390
391
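# Save model every 5 batches (an integer save_freq counts batches, not epochs;
# use save_freq=5 * steps_per_epoch to save every 5 epochs)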
392
checkpoint_regular = callbacks.ModelCheckpoint(
393
filepath='models/model_epoch_{epoch:02d}.keras',
394
save_freq=5,
395
verbose=1
396
)
397
398
# Backup and restore for fault tolerance
399
backup_restore = callbacks.BackupAndRestore(backup_dir='./backup')
400
401
model.fit(
402
x_train, y_train,
403
epochs=100,
404
validation_data=(x_val, y_val),
405
callbacks=[checkpoint_best, checkpoint_regular, backup_restore]
406
)
407
```