0
# Loss Functions and Metrics
1
2
Comprehensive collection of loss functions for training neural networks and metrics for evaluation, covering classification, regression, and specialized tasks with both class-based and function-based APIs.
3
4
## Capabilities
5
6
### Classification Loss Functions
7
8
Loss functions designed for classification tasks including binary, multiclass, and specialized classification scenarios.
9
10
```python { .api }
11
class BinaryCrossentropy:
    """
    Binary cross-entropy loss for binary (0/1) and multi-label classification.

    Args:
        from_logits (bool): If True, `y_pred` holds raw logits; otherwise
            probabilities in [0, 1]
        label_smoothing (float): Factor in [0, 1]; targets are moved towards
            0.5 by this amount (0.0 disables smoothing)
        axis (int): Axis along which to compute loss
        reduction (str): Type of reduction to apply
            (default "sum_over_batch_size")
        name (str): Name of the loss
    """
    def __init__(self, from_logits=False, label_smoothing=0.0, axis=-1, **kwargs): ...

class CategoricalCrossentropy:
    """
    Categorical cross-entropy loss for multiclass classification with
    one-hot labels (use SparseCategoricalCrossentropy for integer labels).

    Args:
        from_logits (bool): Whether input is logits or probabilities
        label_smoothing (float): Factor in [0, 1]; one-hot targets become
            `y_true * (1 - label_smoothing) + label_smoothing / num_classes`
        axis (int): Class axis along which to compute loss
        reduction (str): Type of reduction to apply
        name (str): Name of the loss
    """
    def __init__(self, from_logits=False, label_smoothing=0.0, axis=-1, **kwargs): ...

class SparseCategoricalCrossentropy:
    """
    Sparse categorical cross-entropy for integer class labels
    (e.g. `y_true = [1, 2]` instead of one-hot vectors).

    Args:
        from_logits (bool): Whether input is logits or probabilities
        ignore_class (int, optional): Label value whose positions contribute
            no loss (e.g. a "void" class in segmentation masks)
        axis (int): Axis along which to compute loss
        reduction (str): Type of reduction to apply
        name (str): Name of the loss
    """
    def __init__(self, from_logits=False, ignore_class=None, axis=-1, **kwargs): ...

class BinaryFocalCrossentropy:
    """
    Binary focal loss: down-weights well-classified examples so training
    focuses on hard ones; useful for class imbalance.

    Args:
        apply_class_balancing (bool): If True, weight positive examples by
            `alpha` and negative examples by `1 - alpha`
        alpha (float): Class-balancing weight in [0, 1]; only used when
            `apply_class_balancing=True`
        gamma (float): Focusing parameter; larger values down-weight easy
            examples more (0.0 without class balancing reduces to binary
            cross-entropy)
        from_logits (bool): Whether input is logits or probabilities
        label_smoothing (float): Label smoothing factor in [0, 1]
        axis (int): Axis along which to compute loss
        reduction (str): Type of reduction to apply
        name (str): Name of the loss
    """
    def __init__(self, apply_class_balancing=False, alpha=0.25, gamma=2.0,
                 from_logits=False, label_smoothing=0.0, axis=-1, **kwargs): ...

class CategoricalFocalCrossentropy:
    """
    Categorical focal loss for imbalanced multiclass problems with one-hot labels.

    Args:
        alpha (float): Balancing weight applied to the class terms
        gamma (float): Focusing parameter; larger values down-weight easy
            examples more
        from_logits (bool): Whether input is logits or probabilities
        label_smoothing (float): Label smoothing factor in [0, 1]
        axis (int): Class axis along which to compute loss
        reduction (str): Type of reduction to apply
        name (str): Name of the loss
    """
    def __init__(self, alpha=0.25, gamma=2.0, from_logits=False,
                 label_smoothing=0.0, axis=-1, **kwargs): ...

# Function equivalents (return per-sample losses; no batch reduction applied)
def binary_crossentropy(y_true, y_pred, from_logits=False, label_smoothing=0.0, axis=-1): ...
def categorical_crossentropy(y_true, y_pred, from_logits=False, label_smoothing=0.0, axis=-1): ...
def sparse_categorical_crossentropy(y_true, y_pred, from_logits=False, ignore_class=None, axis=-1): ...
73
```
74
75
### Regression Loss Functions
76
77
Loss functions for continuous value prediction tasks with various robustness properties.
78
79
```python { .api }
80
class MeanSquaredError:
    """
    Mean squared error loss for regression: `mean(square(y_true - y_pred))`.

    Args:
        reduction (str): Type of reduction to apply
            (default "sum_over_batch_size")
        name (str): Name of the loss
    """
    def __init__(self, **kwargs): ...

class MeanAbsoluteError:
    """
    Mean absolute error loss for regression: `mean(abs(y_true - y_pred))`.
    Less sensitive to outliers than mean squared error.

    Args:
        reduction (str): Type of reduction to apply
        name (str): Name of the loss
    """
    def __init__(self, **kwargs): ...

class MeanAbsolutePercentageError:
    """
    Mean absolute percentage error for regression:
    `100 * mean(abs((y_true - y_pred) / y_true))`. The denominator is clipped
    to a small epsilon, so targets near zero yield very large values.

    Args:
        reduction (str): Type of reduction to apply
        name (str): Name of the loss
    """
    def __init__(self, **kwargs): ...

class MeanSquaredLogarithmicError:
    """
    Mean squared logarithmic error for regression:
    `mean(square(log(y_true + 1) - log(y_pred + 1)))`. Penalizes relative
    rather than absolute error; suited to non-negative targets.

    Args:
        reduction (str): Type of reduction to apply
        name (str): Name of the loss
    """
    def __init__(self, **kwargs): ...

class Huber:
    """
    Huber loss for robust regression. With `x = y_true - y_pred`:
    `0.5 * x**2` if `abs(x) <= delta`, else `delta * (abs(x) - 0.5 * delta)`.

    Args:
        delta (float): Point where loss changes from quadratic to linear
        reduction (str): Type of reduction to apply
        name (str): Name of the loss
    """
    def __init__(self, delta=1.0, **kwargs): ...

class LogCosh:
    """
    Log-cosh loss for regression: `log(cosh(y_pred - y_true))`. Behaves like
    `x**2 / 2` for small errors and like `abs(x) - log(2)` for large ones.

    Args:
        reduction (str): Type of reduction to apply
        name (str): Name of the loss
    """
    def __init__(self, **kwargs): ...

# Function equivalents (return per-sample losses; no batch reduction applied)
def mean_squared_error(y_true, y_pred): ...
def mean_absolute_error(y_true, y_pred): ...
def mean_absolute_percentage_error(y_true, y_pred): ...
def mean_squared_logarithmic_error(y_true, y_pred): ...
def huber(y_true, y_pred, delta=1.0): ...
def log_cosh(y_true, y_pred): ...
146
```
147
148
### Specialized Loss Functions
149
150
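Loss functions for specialized tasks: maximum-margin classification (hinge losses), distribution matching (KL divergence), count data (Poisson), embedding similarity, sequence labeling (CTC), and segmentation (Dice, Tversky).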
151
152
```python { .api }
153
class Hinge:
    """
    Hinge loss for maximum-margin classification:
    `mean(maximum(1 - y_true * y_pred, 0))`.

    `y_true` is expected to be -1 or 1; binary 0/1 labels are converted
    to -1/1 automatically.

    Args:
        reduction (str): Type of reduction to apply
            (default "sum_over_batch_size")
        name (str): Name of the loss
    """
    def __init__(self, **kwargs): ...

class SquaredHinge:
    """
    Squared hinge loss for maximum-margin classification:
    `mean(square(maximum(1 - y_true * y_pred, 0)))`. Uses the same -1/1
    label convention as Hinge.

    Args:
        reduction (str): Type of reduction to apply
        name (str): Name of the loss
    """
    def __init__(self, **kwargs): ...

class CategoricalHinge:
    """
    Categorical hinge loss for multiclass classification with one-hot labels:
    `maximum(neg - pos + 1, 0)` where `pos = sum(y_true * y_pred)` and
    `neg = max((1 - y_true) * y_pred)`.

    Args:
        reduction (str): Type of reduction to apply
        name (str): Name of the loss
    """
    def __init__(self, **kwargs): ...

class KLDivergence:
    """
    Kullback-Leibler divergence loss between probability distributions:
    `sum(y_true * log(y_true / y_pred))` over the last axis.

    Args:
        reduction (str): Type of reduction to apply
        name (str): Name of the loss
    """
    def __init__(self, **kwargs): ...

class Poisson:
    """
    Poisson loss for count data: `mean(y_pred - y_true * log(y_pred))`.
    Predictions are expected to be positive rates.

    Args:
        reduction (str): Type of reduction to apply
        name (str): Name of the loss
    """
    def __init__(self, **kwargs): ...

class CosineSimilarity:
    """
    Negative cosine similarity loss:
    `-sum(l2_normalize(y_true) * l2_normalize(y_pred))`.

    Values lie in [-1, 1]; 0 means orthogonal and values closer to -1 mean
    more similar, so minimizing the loss maximizes similarity.

    Args:
        axis (int): Axis along which to compute cosine similarity
        reduction (str): Type of reduction to apply
        name (str): Name of the loss
    """
    def __init__(self, axis=-1, **kwargs): ...

class Dice:
    """
    Dice loss for segmentation tasks:
    `1 - 2 * sum(y_true * y_pred) / (sum(y_true) + sum(y_pred))`.

    Args:
        axis (int or tuple, optional): Axis (or axes) over which the overlap
            is summed; None sums over all axes
        reduction (str): Type of reduction to apply
        name (str): Name of the loss
    """
    def __init__(self, axis=None, **kwargs): ...

class Tversky:
    """
    Tversky loss for segmentation with adjustable precision/recall balance:
    `1 - TP / (TP + alpha * FP + beta * FN)`. With `alpha = beta = 0.5` it
    equals the Dice loss; raising `beta` penalizes missed positives more.

    Args:
        alpha (float): Weight for false positives
        beta (float): Weight for false negatives
        axis (int or tuple, optional): Axis to compute over
        reduction (str): Type of reduction to apply
        name (str): Name of the loss
    """
    def __init__(self, alpha=0.5, beta=0.5, axis=None, **kwargs): ...

class CTC:
    """
    Connectionist Temporal Classification loss for sequence labeling where
    input and target lengths differ (e.g. speech or handwriting recognition).

    Args:
        logits_time_major (bool): Whether logits are time-major
        blank_index (int, optional): Index of blank label
        reduction (str): Type of reduction to apply

    NOTE(review): recent Keras 3 releases appear to expose only `reduction`,
    `name` and `dtype` on `keras.losses.CTC` and to reserve label index 0 for
    the blank token; confirm `logits_time_major` / `blank_index` against the
    installed version before relying on them.
    """
    def __init__(self, logits_time_major=False, blank_index=None, **kwargs): ...
234
```
235
236
### Classification Metrics
237
238
Metrics for evaluating classification model performance, including accuracy variants and confusion-matrix-based metrics (precision, recall, F1, AUC).
239
240
```python { .api }
241
class Accuracy:
    """
    Generic accuracy: the fraction of predictions exactly equal to the labels.

    Args:
        name (str): Name of the metric
        dtype (str): Data type for metric computation
    """
    def __init__(self, name='accuracy', dtype=None, **kwargs): ...

class BinaryAccuracy:
    """
    Binary classification accuracy; predictions above `threshold` count as 1.

    Args:
        threshold (float): Decision threshold (use 0.0 for logit outputs)
        name (str): Name of the metric
        dtype (str): Data type for metric computation
    """
    def __init__(self, threshold=0.5, name='binary_accuracy', dtype=None, **kwargs): ...

class CategoricalAccuracy:
    """
    Categorical accuracy for one-hot encoded labels: compares
    `argmax(y_true)` with `argmax(y_pred)`.

    Args:
        name (str): Name of the metric
        dtype (str): Data type for metric computation
    """
    def __init__(self, name='categorical_accuracy', dtype=None, **kwargs): ...

class SparseCategoricalAccuracy:
    """
    Categorical accuracy for integer labels: compares `y_true` with
    `argmax(y_pred)`.

    Args:
        name (str): Name of the metric
        dtype (str): Data type for metric computation
    """
    def __init__(self, name='sparse_categorical_accuracy', dtype=None, **kwargs): ...

class TopKCategoricalAccuracy:
    """
    Top-k categorical accuracy for one-hot labels: a sample is correct when
    its true class is among the `k` highest-scoring predictions. Use
    SparseTopKCategoricalAccuracy when labels are integer class ids.

    Args:
        k (int): Number of top predictions to consider
        name (str): Name of the metric
        dtype (str): Data type for metric computation
    """
    def __init__(self, k=5, name='top_k_categorical_accuracy', dtype=None, **kwargs): ...

class SparseTopKCategoricalAccuracy:
    """
    Top-k categorical accuracy for integer labels.

    Args:
        k (int): Number of top predictions to consider
        name (str): Name of the metric
        dtype (str): Data type for metric computation
    """
    def __init__(self, k=5, name='sparse_top_k_categorical_accuracy', dtype=None, **kwargs): ...

class Precision:
    """
    Precision metric: `TP / (TP + FP)`, accumulated across batches.

    Args:
        thresholds (float or list, optional): Threshold(s) applied to
            `y_pred`; a list yields one value per threshold. If neither
            `thresholds` nor `top_k` is set, a threshold of 0.5 is used.
        top_k (int, optional): Only the top-k predictions count as positive
        class_id (int, optional): Class to compute precision for
        name (str): Name of the metric
        dtype (str): Data type for metric computation
    """
    def __init__(self, thresholds=None, top_k=None, class_id=None, name='precision', dtype=None, **kwargs): ...

class Recall:
    """
    Recall metric: `TP / (TP + FN)`, accumulated across batches.

    Args:
        thresholds (float or list, optional): Threshold(s) applied to
            `y_pred`; a list yields one value per threshold. If neither
            `thresholds` nor `top_k` is set, a threshold of 0.5 is used.
        top_k (int, optional): Only the top-k predictions count as positive
        class_id (int, optional): Class to compute recall for
        name (str): Name of the metric
        dtype (str): Data type for metric computation
    """
    def __init__(self, thresholds=None, top_k=None, class_id=None, name='recall', dtype=None, **kwargs): ...

class F1Score:
    """
    F1 score, the harmonic mean of precision and recall.

    Expects `y_true` with the same shape as `y_pred` (one-hot or multi-hot),
    not integer class ids.

    Args:
        average (str, optional): Averaging strategy ('micro', 'macro',
            'weighted', or None to return one score per class)
        threshold (float, optional): Predictions above it become 1; if None,
            the argmax of each prediction is set to 1
        name (str): Name of the metric
        dtype (str): Data type for metric computation
    """
    def __init__(self, average=None, threshold=None, name='f1_score', dtype=None, **kwargs): ...

class AUC:
    """
    Area under the ROC or precision-recall curve, approximated with a
    Riemann sum over a set of thresholds.

    Args:
        num_thresholds (int): Number of thresholds used to discretize the curve
        curve (str): Type of curve ('ROC' or 'PR')
        summation_method (str): Method for approximating AUC
            ('interpolation', 'minoring' or 'majoring')
        name (str): Name of the metric
        dtype (str): Data type for metric computation
        thresholds (list, optional): Explicit thresholds; overrides
            `num_thresholds`
        multi_label (bool): If True, compute AUC per label and average;
            otherwise the data is flattened into a single label
        num_labels (int, optional): Number of labels when `multi_label=True`
        label_weights (list, optional): Per-label weights used in averaging
        from_logits (bool): Whether predictions are logits
    """
    def __init__(self, num_thresholds=200, curve='ROC', summation_method='interpolation',
                 name='auc', dtype=None, thresholds=None, multi_label=False,
                 num_labels=None, label_weights=None, from_logits=False, **kwargs): ...
344
```
345
346
### Regression Metrics
347
348
Metrics for evaluating regression model performance.
349
350
```python { .api }
351
class MeanSquaredError:
    """
    Running mean of `square(y_true - y_pred)`.

    Args:
        name (str): Name of the metric
        dtype (str): Data type for metric computation
    """
    def __init__(
        self,
        name='mean_squared_error',
        dtype=None,
        **kwargs,
    ): ...

class RootMeanSquaredError:
    """
    Square root of the running mean squared error, in the units of the target.

    Args:
        name (str): Name of the metric
        dtype (str): Data type for metric computation
    """
    def __init__(
        self,
        name='root_mean_squared_error',
        dtype=None,
        **kwargs,
    ): ...

class MeanAbsoluteError:
    """
    Running mean of `abs(y_true - y_pred)`.

    Args:
        name (str): Name of the metric
        dtype (str): Data type for metric computation
    """
    def __init__(
        self,
        name='mean_absolute_error',
        dtype=None,
        **kwargs,
    ): ...

class MeanAbsolutePercentageError:
    """
    Running mean absolute percentage error (scaled by 100).

    Args:
        name (str): Name of the metric
        dtype (str): Data type for metric computation
    """
    def __init__(
        self,
        name='mean_absolute_percentage_error',
        dtype=None,
        **kwargs,
    ): ...

class R2Score:
    """
    Coefficient of determination (R²) for regression outputs.

    Args:
        class_aggregation (str or None): How per-output scores are combined
            for multi-output targets: 'uniform_average',
            'variance_weighted_average', or None for one score per output
        num_regressors (int): Number of independent regressors; a value
            above 0 yields the adjusted R², 0 gives the standard R²
        name (str): Name of the metric
        dtype (str): Data type for metric computation
    """
    def __init__(
        self,
        class_aggregation='uniform_average',
        num_regressors=0,
        name='r2_score',
        dtype=None,
        **kwargs,
    ): ...

class CosineSimilarity:
    """
    Running mean cosine similarity between labels and predictions. Unlike
    the CosineSimilarity loss it is not negated: 1 means identical direction.

    Args:
        axis (int): Axis along which to compute cosine similarity
        name (str): Name of the metric
        dtype (str): Data type for metric computation
    """
    def __init__(
        self,
        axis=-1,
        name='cosine_similarity',
        dtype=None,
        **kwargs,
    ): ...
390
```
391
392
### Segmentation Metrics
393
394
Metrics for evaluating image segmentation and pixel-wise classification tasks.
395
396
```python { .api }
397
class IoU:
    """
    Intersection over Union (Jaccard index) for selected classes:
    `TP / (TP + FP + FN)` per class, averaged over `target_class_ids`.

    By default labels and predictions are both integer class ids; pass
    `sparse_y_pred=False` (and/or `sparse_y_true=False`) for per-class
    scores, which are argmax-ed along `axis`.

    Args:
        num_classes (int): Number of possible classes in the task
        target_class_ids (list or tuple): Class ids whose IoU is computed
            and averaged (required)
        name (str): Name of the metric
        dtype (str): Data type for metric computation
        ignore_class (int, optional): Label value excluded from the metric
        sparse_y_true (bool): Whether labels are integer class ids
        sparse_y_pred (bool): Whether predictions are integer class ids
        axis (int): Class axis used for argmax on non-sparse inputs
    """
    def __init__(self, num_classes, target_class_ids, name='iou', dtype=None,
                 ignore_class=None, sparse_y_true=True, sparse_y_pred=True,
                 axis=-1, **kwargs): ...

class MeanIoU:
    """
    Mean Intersection over Union, averaged over all `num_classes` classes.

    Args:
        num_classes (int): Number of classes
        name (str): Name of the metric
        dtype (str): Data type for metric computation
        ignore_class (int, optional): Label value excluded from the metric
        sparse_y_true (bool): Whether labels are integer class ids
        sparse_y_pred (bool): Whether predictions are integer class ids
        axis (int): Class axis used for argmax on non-sparse inputs
    """
    def __init__(self, num_classes, name='mean_iou', dtype=None,
                 ignore_class=None, sparse_y_true=True, sparse_y_pred=True,
                 axis=-1, **kwargs): ...

class BinaryIoU:
    """
    IoU for binary tasks: predictions are thresholded into class 0 or 1 and
    the IoU of the classes listed in `target_class_ids` is averaged.

    Args:
        target_class_ids (list or tuple): Classes (0 and/or 1) to average
            over; the default (0, 1) averages both
        threshold (float): Decision threshold separating predicted class 0
            from class 1
        name (str): Name of the metric
        dtype (str): Data type for metric computation
    """
    def __init__(self, target_class_ids=(0, 1), threshold=0.5, name='binary_iou', dtype=None, **kwargs): ...
433
```
434
435
### Utility Functions
436
437
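Functions for retrieving and (de)serializing losses and metrics. `keras.losses` and `keras.metrics` each provide their own `get`, `serialize`, and `deserialize` under the same names, so always call them through their module (e.g. `keras.losses.get('mse')`).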
438
439
```python { .api }
440
# ---- keras.losses utilities ----
def get(identifier):
    """
    Look up a loss (keras.losses.get).

    Accepts a built-in loss name, a serialized config dict, or a callable,
    and returns the corresponding loss.
    """


def serialize(loss):
    """
    Convert a loss (keras.losses.serialize) into a JSON-serializable config dict.
    """


def deserialize(config, custom_objects=None):
    """
    Rebuild a loss (keras.losses.deserialize) from its config dict.

    `custom_objects` maps names to user-defined classes or functions that
    the config refers to.
    """


# ---- keras.metrics utilities ----
# NOTE: these share their names with the loss helpers above; in real code
# call them through their module (e.g. `keras.metrics.get`).
def get(identifier):
    """
    Look up a metric (keras.metrics.get) by name, config dict, or callable.
    """


def serialize(metric):
    """
    Convert a metric (keras.metrics.serialize) into a JSON-serializable config dict.
    """


def deserialize(config, custom_objects=None):
    """
    Rebuild a metric (keras.metrics.deserialize) from its config dict.
    """
459
```
460
461
## Usage Examples
462
463
### Using Loss Functions in Model Compilation
464
465
```python
466
import keras
467
from keras import layers, losses, metrics
468
469
# Declare the input with keras.Input; Keras 3 discourages passing
# `input_shape=` to individual layers.
model = keras.Sequential([
    keras.Input(shape=(784,)),
    layers.Dense(64, activation='relu'),
    layers.Dropout(0.2),
    layers.Dense(10, activation='softmax')
])

# Using string identifiers
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# Using class instances for more control.
# The labels are integer class ids, so the *Sparse* metric variants are
# required: TopKCategoricalAccuracy expects one-hot labels.
model.compile(
    optimizer='adam',
    # The softmax layer already outputs probabilities, hence from_logits=False.
    loss=losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=[
        metrics.SparseCategoricalAccuracy(),
        metrics.SparseTopKCategoricalAccuracy(k=3)
    ]
)
491
```
492
493
### Multi-output Model with Different Losses
494
495
```python
496
import keras
497
from keras import layers, losses, metrics
498
499
# Define inputs
inputs = keras.Input(shape=(784,))
x = layers.Dense(64, activation='relu')(inputs)

# Multiple outputs
classification_output = layers.Dense(10, activation='softmax', name='classification')(x)
regression_output = layers.Dense(1, name='regression')(x)

model = keras.Model(inputs=inputs, outputs=[classification_output, regression_output])

# Different losses for different outputs.
# F1Score needs targets with the same shape as the predictions, so the
# classification targets are one-hot vectors (e.g. built with
# keras.utils.to_categorical) and the non-sparse loss/accuracy are used.
model.compile(
    optimizer='adam',
    loss={
        'classification': losses.CategoricalCrossentropy(),
        'regression': losses.MeanSquaredError()
    },
    metrics={
        # average='macro' reports a single F1 value instead of one per class.
        'classification': [metrics.CategoricalAccuracy(), metrics.F1Score(average='macro')],
        'regression': [metrics.MeanAbsoluteError(), metrics.R2Score()]
    },
    loss_weights={'classification': 1.0, 'regression': 0.5}
)
)
522
```
523
524
### Custom Loss Function
525
526
```python
527
import keras
528
from keras import ops
529
530
def focal_loss(alpha=0.25, gamma=2.0, from_logits=True, epsilon=1e-7):
    """
    Build a binary focal loss function.

    Args:
        alpha (float): Weight for positive examples (`1 - alpha` for negatives)
        gamma (float): Focusing parameter; 0.0 reduces to alpha-weighted
            binary cross-entropy
        from_logits (bool): If True (the default, matching the original
            behavior), `y_pred` is passed through a sigmoid first; set False
            when the model already outputs probabilities
        epsilon (float): Clipping bound that keeps `log` finite

    Returns:
        A `loss_fn(y_true, y_pred)` that returns one loss value per sample,
        so Keras can apply `sample_weight` and its own batch reduction.
    """
    def loss_fn(y_true, y_pred):
        y_true = ops.cast(y_true, y_pred.dtype)
        if from_logits:
            y_pred = ops.sigmoid(y_pred)
        # Clip so log() never sees 0 and 1 - p never goes negative.
        y_pred = ops.clip(y_pred, epsilon, 1.0 - epsilon)

        # p_t is the predicted probability of the true class.
        pt = ops.where(y_true == 1, y_pred, 1 - y_pred)
        alpha_t = ops.where(y_true == 1, alpha, 1 - alpha)
        focal_weight = alpha_t * ops.power(1 - pt, gamma)

        focal = focal_weight * -ops.log(pt)
        # Average over the last axis only; Keras reduces over the batch.
        return ops.mean(focal, axis=-1)

    return loss_fn


# Use the custom loss on a binary classifier that outputs a single raw logit
# (no activation), which matches the default from_logits=True.
model = keras.Sequential([
    keras.Input(shape=(784,)),
    keras.layers.Dense(1)
])
model.compile(
    optimizer='adam',
    loss=focal_loss(alpha=0.25, gamma=2.0),
    # Outputs are logits, so threshold accuracy at 0 (= probability 0.5).
    metrics=[keras.metrics.BinaryAccuracy(threshold=0.0)]
)
553
```
554
555
### Custom Metric
556
557
```python
558
import keras
559
from keras import ops
560
561
class F2Score(keras.metrics.Metric):
    """
    F-beta score with beta = 2, treating recall as twice as important as
    precision.

    Computed from running precision P and recall R as
    `(1 + beta**2) * P * R / (beta**2 * P + R)`. A small epsilon in the
    denominator avoids division by zero before any positives are seen.
    """

    def __init__(self, name='f2_score', **kwargs):
        super().__init__(name=name, **kwargs)
        # Built-in sub-metrics hold the running TP/FP/FN counts across batches.
        self.precision = keras.metrics.Precision()
        self.recall = keras.metrics.Recall()

    def update_state(self, y_true, y_pred, sample_weight=None):
        for tracker in (self.precision, self.recall):
            tracker.update_state(y_true, y_pred, sample_weight)

    def result(self):
        beta_sq = 4  # beta = 2
        p = self.precision.result()
        r = self.recall.result()
        return (1 + beta_sq) * p * r / (beta_sq * p + r + 1e-8)

    def reset_state(self):
        for tracker in (self.precision, self.recall):
            tracker.reset_state()


# Use the custom metric on a binary classifier with a sigmoid output, so the
# default 0.5 threshold of Precision/Recall is applied to probabilities.
model = keras.Sequential([
    keras.Input(shape=(784,)),
    keras.layers.Dense(1, activation='sigmoid')
])
model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=[F2Score(), 'accuracy']
)
586
```