# Training and Optimization

Optimizers, loss functions, metrics, and callbacks for training neural networks. Together, these components determine how a model learns from data and how training progress is monitored and controlled.

## Capabilities

### Optimizers

Optimization algorithms that update model parameters during training to minimize the loss function.

```python { .api }
class Optimizer:
    def __init__(self, learning_rate=0.001, name=None, **kwargs):
        """
        Common base class that every optimizer derives from.

        Parameters:
        - learning_rate: Learning rate used when training starts
        - name: Optional name for this optimizer instance
        """

    def apply_gradients(self, grads_and_vars):
        """
        Update each variable using its gradient.

        Parameters:
        - grads_and_vars: List of (gradient, variable) pairs
        """

class SGD(Optimizer):
    def __init__(self, learning_rate=0.01, momentum=0.0, nesterov=False, **kwargs):
        """
        Stochastic gradient descent, optionally with momentum.

        Parameters:
        - learning_rate: Step size applied to each update
        - momentum: Momentum factor (default 0.0, i.e. no momentum)
        - nesterov: If True, use Nesterov momentum
        """

class Adam(Optimizer):
    def __init__(self, learning_rate=0.001, beta_1=0.9, beta_2=0.999,
                 epsilon=1e-7, amsgrad=False, weight_decay=None, clipnorm=None,
                 clipvalue=None, global_clipnorm=None, use_ema=False,
                 ema_momentum=0.99, ema_overwrite_frequency=None,
                 loss_scale_factor=None, gradient_accumulation_steps=None, **kwargs):
        """
        Adam optimizer.

        Parameters:
        - learning_rate: Learning rate
        - beta_1: Exponential decay rate for first moment estimates
        - beta_2: Exponential decay rate for second moment estimates
        - epsilon: Small constant for numerical stability
        - amsgrad: Whether to apply the AMSGrad variant
        - weight_decay: Weight decay coefficient (None disables weight decay)
        - clipnorm: If set, each weight's gradient is clipped individually so
          that its norm does not exceed this value (per-variable, not global)
        - clipvalue: If set, each gradient is clipped element-wise to
          [-clipvalue, clipvalue]
        - global_clipnorm: If set, all gradients are clipped jointly so that
          their global norm does not exceed this value
        - use_ema: Whether to maintain an exponential moving average (EMA) of
          the model weights
        - ema_momentum: Momentum used when computing the EMA (only used if
          use_ema=True)
        - ema_overwrite_frequency: Overwrite the model weights with their EMA
          every this many iterations; None means no mid-training overwrite
          (only used if use_ema=True)
        - loss_scale_factor: If set, the loss is multiplied by this factor
          before computing gradients, and the gradients are divided by it
          before the update (helps prevent underflow in mixed precision)
        - gradient_accumulation_steps: If set, variables are updated only
          every this many steps, using the average of the gradients
          accumulated since the last update
        """

class AdamW(Optimizer):
    def __init__(
        self,
        learning_rate=0.001,
        weight_decay=0.004,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=1e-7,
        amsgrad=False,
        **kwargs,
    ):
        """
        Adam variant in which weight decay is decoupled from the gradient update.

        Parameters:
        - learning_rate: Learning rate
        - weight_decay: Coefficient of the decoupled weight decay
        - beta_1: Decay rate of the running first-moment estimate
        - beta_2: Decay rate of the running second-moment estimate
        - epsilon: Small constant that keeps the division numerically stable
        - amsgrad: If True, use the AMSGrad variant
        """

class RMSprop(Optimizer):
    def __init__(
        self,
        learning_rate=0.001,
        rho=0.9,
        momentum=0.0,
        epsilon=1e-7,
        centered=False,
        **kwargs,
    ):
        """
        RMSprop optimizer.

        Parameters:
        - learning_rate: Learning rate
        - rho: Discounting factor applied to the history of squared gradients
        - momentum: Momentum factor
        - epsilon: Small constant that keeps the division numerically stable
        - centered: If True, normalize gradients by their estimated variance
        """

class Adagrad(Optimizer):
    def __init__(
        self,
        learning_rate=0.001,
        initial_accumulator_value=0.1,
        epsilon=1e-7,
        **kwargs,
    ):
        """
        Adagrad optimizer.

        Parameters:
        - learning_rate: Learning rate
        - initial_accumulator_value: Starting value of the per-parameter accumulators
        - epsilon: Small constant that keeps the division numerically stable
        """

class Adadelta(Optimizer):
    def __init__(
        self,
        learning_rate=0.001,
        rho=0.95,
        epsilon=1e-7,
        **kwargs,
    ):
        """
        Adadelta optimizer.

        Parameters:
        - learning_rate: Learning rate
        - rho: Decay rate of the running averages
        - epsilon: Small constant that keeps the division numerically stable
        """
```

### Loss Functions

Functions that measure the discrepancy between predicted and target values; the optimizer minimizes this value during training.

```python { .api }
class Loss:
    def __init__(self, reduction='sum_over_batch_size', name=None, **kwargs):
        """
        Common base class that every loss function derives from.

        Parameters:
        - reduction: How per-sample losses are reduced to a single value
        - name: Optional name for this loss instance
        """

    def __call__(self, y_true, y_pred, sample_weight=None):
        """
        Evaluate the loss for a batch.

        Parameters:
        - y_true: Target values
        - y_pred: Model predictions
        - sample_weight: Optional per-sample weights

        Returns:
            The loss value
        """

class SparseCategoricalCrossentropy(Loss):
    def __init__(self, from_logits=False, ignore_class=None, **kwargs):
        """
        Categorical crossentropy for targets given as integer class indices.

        Parameters:
        - from_logits: True if y_pred holds logits, False if it holds probabilities
        - ignore_class: Optional class index whose targets are ignored
        """

class CategoricalCrossentropy(Loss):
    def __init__(self, from_logits=False, label_smoothing=0.0, **kwargs):
        """
        Crossentropy between target and predicted class distributions.

        Parameters:
        - from_logits: True if y_pred holds logits, False if it holds probabilities
        - label_smoothing: Amount of label smoothing to apply (0.0 disables it)
        """

class BinaryCrossentropy(Loss):
    def __init__(self, from_logits=False, label_smoothing=0.0, **kwargs):
        """
        Crossentropy for binary (0/1) targets.

        Parameters:
        - from_logits: True if y_pred holds logits, False if it holds probabilities
        - label_smoothing: Amount of label smoothing to apply (0.0 disables it)
        """

class MeanSquaredError(Loss):
    def __init__(self, **kwargs):
        """Loss equal to the mean of the squared errors."""

class MeanAbsoluteError(Loss):
    def __init__(self, **kwargs):
        """Loss equal to the mean of the absolute errors."""

class Huber(Loss):
    def __init__(self, delta=1.0, **kwargs):
        """
        Huber loss: quadratic for small errors, linear for large ones.

        Parameters:
        - delta: Error magnitude at which the loss switches from quadratic to linear
        """

class KLDivergence(Loss):
    def __init__(self, **kwargs):
        """Loss equal to the Kullback-Leibler divergence between targets and predictions."""

class CosineSimilarity(Loss):
    def __init__(self, axis=-1, **kwargs):
        """
        Cosine similarity loss.

        The loss is the *negative* cosine similarity between y_true and y_pred,
        so its value lies in [-1, 1]: 0 indicates orthogonality and values
        closer to -1 indicate greater similarity. Minimizing it therefore
        maximizes the similarity.

        Parameters:
        - axis: Axis along which to compute cosine similarity
        """
```

### Metrics

Functions for monitoring performance during training and evaluation. Unlike losses, metrics are only reported; they do not affect the optimization process.

```python { .api }
class Metric:
    def __init__(self, name=None, dtype=None, **kwargs):
        """
        Common base class that every metric derives from.

        Parameters:
        - name: Optional name for this metric instance
        - dtype: Data type used for the metric's computations
        """

    def update_state(self, y_true, y_pred, sample_weight=None):
        """
        Accumulate a new batch of observations into the metric state.

        Parameters:
        - y_true: Target values
        - y_pred: Model predictions
        - sample_weight: Optional per-sample weights
        """

    def result(self):
        """
        Compute the current metric value from the accumulated state.

        Returns:
            The metric value as a tensor
        """

    def reset_state(self):
        """Clear all accumulated metric state."""

class Accuracy(Metric):
    def __init__(self, name='accuracy', dtype=None, **kwargs):
        """Classification accuracy: how often predictions equal their labels."""

class SparseCategoricalAccuracy(Metric):
    def __init__(self, name='sparse_categorical_accuracy', dtype=None, **kwargs):
        """Categorical accuracy for targets given as integer class indices."""

class CategoricalAccuracy(Metric):
    def __init__(self, name='categorical_accuracy', dtype=None, **kwargs):
        """Accuracy metric for categorical (one-hot) targets."""

class TopKCategoricalAccuracy(Metric):
    def __init__(self, k=5, name='top_k_categorical_accuracy', dtype=None, **kwargs):
        """
        Top-k categorical accuracy metric.

        Measures how often the target class is among the k highest-scoring
        predictions. Targets must be one-hot encoded; for integer class
        labels use SparseTopKCategoricalAccuracy instead.

        Parameters:
        - k: Number of top predictions to consider
        """

class Precision(Metric):
    def __init__(
        self,
        thresholds=None,
        top_k=None,
        class_id=None,
        name=None,
        dtype=None,
        **kwargs,
    ):
        """
        Precision: the fraction of positive predictions that are correct.

        Parameters:
        - thresholds: Optional decision threshold(s) used to binarize predictions
        - top_k: Optional number of top predictions to treat as positive
        - class_id: Optional single class to compute the metric for
        """

class Recall(Metric):
    def __init__(self, thresholds=None, top_k=None, class_id=None,
                 name=None, dtype=None, **kwargs):
        """
        Recall metric: the fraction of actual positives that are predicted positive.

        Parameters:
        - thresholds: Optional thresholds for binary classification
        - top_k: Number of top predictions to consider
        - class_id: Specific class to compute metric for
        """

class AUC(Metric):
    def __init__(
        self,
        num_thresholds=200,
        curve='ROC',
        summation_method='interpolation',
        name=None,
        dtype=None,
        **kwargs,
    ):
        """
        Approximate area under a curve (ROC or precision-recall).

        Parameters:
        - num_thresholds: How many thresholds are used to discretize the curve
        - curve: Which curve to integrate, 'ROC' or 'PR'
        - summation_method: Rule used to approximate the area under the curve
        """

class F1Score(Metric):
    def __init__(self, average=None, threshold=None, name='f1_score', dtype=None, **kwargs):
        """
        F1 score: the harmonic mean of precision and recall.

        Parameters:
        - average: Averaging mode: 'micro', 'macro', 'weighted', or None
        - threshold: Decision threshold used to binarize predictions
        """

class MeanSquaredError(Metric):
    def __init__(self, name='mean_squared_error', dtype=None, **kwargs):
        """Metric that reports the mean of the squared errors."""

class MeanAbsoluteError(Metric):
    def __init__(self, name='mean_absolute_error', dtype=None, **kwargs):
        """Metric that reports the mean of the absolute errors."""

class RootMeanSquaredError(Metric):
    def __init__(self, name='root_mean_squared_error', dtype=None, **kwargs):
        """Metric that reports the square root of the mean squared error."""
```

### Callbacks

Utilities that perform actions at various stages of training, such as saving models, adjusting learning rates, or stopping training early.

```python { .api }
class Callback:
    def __init__(self):
        """Common base class that every callback derives from."""

    def on_train_begin(self, logs=None):
        """Hook invoked once, before training starts."""

    def on_train_end(self, logs=None):
        """Hook invoked once, after training finishes."""

    def on_epoch_begin(self, epoch, logs=None):
        """Hook invoked when an epoch starts."""

    def on_epoch_end(self, epoch, logs=None):
        """Hook invoked when an epoch finishes."""

    def on_batch_begin(self, batch, logs=None):
        """Hook invoked when a batch starts."""

    def on_batch_end(self, batch, logs=None):
        """Hook invoked when a batch finishes."""

class ModelCheckpoint(Callback):
    def __init__(
        self,
        filepath,
        monitor='val_loss',
        verbose=0,
        save_best_only=False,
        save_weights_only=False,
        mode='auto',
        save_freq='epoch',
        **kwargs,
    ):
        """
        Periodically save the model or its weights.

        Parameters:
        - filepath: Destination path for the saved model/weights
        - monitor: Name of the logged metric used to decide what counts as "best"
        - verbose: Verbosity mode
        - save_best_only: If True, save only when the monitored metric improves
        - save_weights_only: If True, save only the weights instead of the full model
        - mode: One of {'auto', 'min', 'max'}
        - save_freq: 'epoch', or an integer number of batches between saves
        """

class EarlyStopping(Callback):
    def __init__(
        self,
        monitor='val_loss',
        min_delta=0,
        patience=0,
        verbose=0,
        mode='auto',
        baseline=None,
        restore_best_weights=False,
        **kwargs,
    ):
        """
        End training once the monitored metric stops improving.

        Parameters:
        - monitor: Name of the logged metric to watch
        - min_delta: Smallest change that still counts as an improvement
        - patience: How many epochs without improvement to tolerate before stopping
        - verbose: Verbosity mode
        - mode: One of {'auto', 'min', 'max'}
        - baseline: Optional baseline value the monitored metric must reach
        - restore_best_weights: If True, restore the weights from the best epoch
        """

class ReduceLROnPlateau(Callback):
    def __init__(
        self,
        monitor='val_loss',
        factor=0.1,
        patience=10,
        verbose=0,
        mode='auto',
        min_delta=1e-4,
        cooldown=0,
        min_lr=0,
        **kwargs,
    ):
        """
        Lower the learning rate once the monitored metric stops improving.

        Parameters:
        - monitor: Name of the logged metric to watch
        - factor: Multiplier applied to the learning rate on each reduction
        - patience: How many epochs without improvement to tolerate before reducing
        - verbose: Verbosity mode
        - mode: One of {'auto', 'min', 'max'}
        - min_delta: Threshold for measuring a new optimum
        - cooldown: Epochs to wait after a reduction before resuming normal monitoring
        - min_lr: The learning rate is never reduced below this value
        """

class LearningRateScheduler(Callback):
    def __init__(self, schedule, verbose=0, **kwargs):
        """
        Learning rate scheduler.

        Parameters:
        - schedule: Function called as schedule(epoch, lr) with the epoch index
          (an integer starting at 0) and the current learning rate; it returns
          the new learning rate (a float)
        - verbose: Verbosity mode
        """

class TensorBoard(Callback):
    def __init__(
        self,
        log_dir='logs',
        histogram_freq=0,
        write_graph=True,
        write_images=False,
        write_steps_per_second=False,
        update_freq='epoch',
        **kwargs,
    ):
        """
        Write training logs for visualization in TensorBoard.

        Parameters:
        - log_dir: Directory where TensorBoard logs are written
        - histogram_freq: How often (in epochs) to write histograms
        - write_graph: If True, write the computation graph
        - write_images: If True, write model weights as images
        - write_steps_per_second: If True, log training steps per second
        - update_freq: 'batch', 'epoch', or an integer number of batches between writes
        """

class CSVLogger(Callback):
    def __init__(
        self,
        filename,
        separator=',',
        append=False,
        **kwargs,
    ):
        """
        Write per-epoch results to a CSV file.

        Parameters:
        - filename: Path of the CSV file
        - separator: Delimiter placed between fields in the CSV file
        - append: If True, append to an existing file instead of overwriting it
        """
```

## Usage Examples

### Basic Training Setup

```python
import keras
from keras import layers, optimizers, losses, metrics

# Build model. Declaring the input with keras.Input is the Keras 3 idiom;
# passing input_shape to the first layer is discouraged and emits a warning.
model = keras.Sequential([
    keras.Input(shape=(784,)),
    layers.Dense(128, activation='relu'),
    layers.Dropout(0.2),
    layers.Dense(10, activation='softmax'),
])

# Compile with an explicit optimizer, loss, and metrics.
# The labels are integer class indices (hence the *Sparse* loss), so the
# top-k metric must be the sparse variant too: TopKCategoricalAccuracy
# expects one-hot labels and would report wrong values here.
model.compile(
    optimizer=optimizers.Adam(learning_rate=0.001),
    loss=losses.SparseCategoricalCrossentropy(),
    metrics=[
        metrics.SparseCategoricalAccuracy(),
        metrics.SparseTopKCategoricalAccuracy(k=3),
    ],
)
```
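
### Training with Callbacks

The following examples assume the compiled `model` from the previous example and that `x_train`, `y_train`, `x_val`, and `y_val` are already loaded.

```python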
from keras import callbacks

# Assumes `model` is already compiled and x_train, y_train, x_val, y_val
# are loaded.

# Define callbacks.
# 'val_loss' is always logged when validation data is passed to fit(). A key
# such as 'val_accuracy' only exists if a metric named 'accuracy' was
# compiled; otherwise the checkpoint would never save.
checkpoint = callbacks.ModelCheckpoint(
    'best_model.keras',
    monitor='val_loss',
    save_best_only=True,
    verbose=1,
)

early_stop = callbacks.EarlyStopping(
    monitor='val_loss',
    patience=5,
    restore_best_weights=True,
)

reduce_lr = callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.2,
    patience=3,
    min_lr=1e-7,
)

# Train with callbacks
history = model.fit(
    x_train, y_train,
    batch_size=32,
    epochs=100,
    validation_data=(x_val, y_val),
    callbacks=[checkpoint, early_stop, reduce_lr],
)
```

### Custom Learning Rate Schedule

```python
import math

from keras import callbacks


def lr_schedule(epoch, lr):
    """Keep `lr` for the first 10 epochs; afterwards multiply the current rate by exp(-0.1)."""
    return lr if epoch < 10 else lr * math.exp(-0.1)


lr_scheduler = callbacks.LearningRateScheduler(lr_schedule, verbose=1)

model.fit(
    x_train, y_train,
    epochs=50,
    callbacks=[lr_scheduler],
)
```
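
### Multi-GPU Training

`MirroredStrategy` is part of TensorFlow's `tf.distribute` module, so this example requires the TensorFlow backend. With the JAX backend, use the `keras.distribution` API instead.

```python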
import keras
import tensorflow as tf
from keras import layers

# Create a distribution strategy. MirroredStrategy comes from tf.distribute
# (TensorFlow backend only); it is not exposed under a keras.distribute name.
strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    # Create and compile the model inside the strategy scope so that its
    # variables (weights and optimizer state) are mirrored across devices.
    model = keras.Sequential([
        keras.Input(shape=(784,)),
        layers.Dense(128, activation='relu'),
        layers.Dense(10, activation='softmax'),
    ])

    model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )

# Train on all GPUs visible to the strategy (x_train, y_train assumed loaded).
model.fit(x_train, y_train, epochs=10)
```