0
# Loss Functions
1
2
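Comprehensive collection of loss functions for classification, regression, and structured prediction tasks. These functions provide differentiable objectives for training neural networks and other machine learning models. Most of them return unreduced losses (one value per element or per example) rather than a scalar, so reduce them explicitly — typically with `.mean()` — before differentiating.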
3
4
## Capabilities
5
6
### Regression Losses
7
8
#### Squared-Error Losses
9
10
```python { .api }
11
def l2_loss(predictions, targets=None):
12
"""
13
L2 loss (mean squared error).
14
15
Args:
16
predictions: Predicted values
17
targets: Target values (default: None, uses zeros if not provided)
18
19
Returns:
20
Scalar loss value
21
"""
22
23
def squared_error(predictions, targets):
24
"""
25
Squared error loss (alias for l2_loss).
26
27
Args:
28
predictions: Predicted values
29
targets: Target values
30
31
Returns:
32
Scalar loss value
33
"""
34
```
35
36
#### Robust Regression Losses
37
38
```python { .api }
39
def huber_loss(predictions, targets=None, delta=1.0):
    """
    Huber loss for robust regression, computed elementwise.

    With ``x = predictions - targets``, the loss is ``0.5 * x ** 2`` when
    ``|x| <= delta`` and ``delta * (|x| - 0.5 * delta)`` otherwise, i.e.
    quadratic near zero and linear for large errors.

    Args:
        predictions: Predicted values
        targets: Target values; if None, zeros are used (default: None)
        delta: Threshold at which the loss switches from quadratic to linear
            (default: 1.0)

    Returns:
        Array of losses with the same shape as ``predictions``
    """


def log_cosh(predictions, targets=None):
    """
    Log-cosh loss for robust regression, computed elementwise.

    ``log(cosh(predictions - targets))`` behaves like ``0.5 * x ** 2`` for
    small errors and like ``|x| - log(2)`` for large ones.

    Args:
        predictions: Predicted values
        targets: Target values; if None, zeros are used (default: None)

    Returns:
        Array of losses with the same shape as ``predictions``
    """
63
```
64
65
#### Distance-Based Losses
66
67
```python { .api }
68
def cosine_distance(predictions, targets, epsilon=0.0):
    """
    Cosine distance: ``1 - cosine_similarity(predictions, targets)``.

    Args:
        predictions: Predicted vectors; the last axis is the vector axis
        targets: Target vectors, same shape as ``predictions``
        epsilon: Lower bound applied to each vector norm before dividing,
            guarding against division by zero (default: 0.0)

    Returns:
        Array of distances in [0, 2] with shape ``predictions.shape[:-1]``
    """


def cosine_similarity(predictions, targets, epsilon=0.0):
    """
    Cosine similarity between vectors along the last axis.

    This equals ``1 - cosine_distance`` (it is not the negated distance).

    Args:
        predictions: Predicted vectors; the last axis is the vector axis
        targets: Target vectors, same shape as ``predictions``
        epsilon: Lower bound applied to each vector norm before dividing,
            guarding against division by zero (default: 0.0)

    Returns:
        Array of similarities in [-1, 1] with shape ``predictions.shape[:-1]``
    """
91
```
92
93
### Classification Losses
94
95
#### Cross-Entropy Losses
96
97
```python { .api }
98
def softmax_cross_entropy(logits, labels, axis=-1):
    """
    Softmax cross-entropy between logits and label distributions.

    Computes ``-sum(labels * log_softmax(logits, axis), axis)`` per example.

    Args:
        logits: Unnormalized log probabilities
        labels: Valid probability distributions over classes (non-negative,
            summing to 1 along ``axis``), e.g. one-hot or smoothed labels
        axis: Class axis along which to apply the softmax (default: -1)

    Returns:
        Per-example losses, shaped like ``logits`` with ``axis`` removed
    """


def softmax_cross_entropy_with_integer_labels(logits, labels, axis=-1):
    """
    Softmax cross-entropy with integer class labels.

    Equivalent to ``softmax_cross_entropy`` with one-hot labels, without
    materializing the one-hot array.

    Args:
        logits: Unnormalized log probabilities
        labels: Integer class indices in ``[0, num_classes)``, shaped like
            ``logits`` with ``axis`` removed
        axis: Class axis along which to apply the softmax (default: -1)

    Returns:
        Per-example losses, shaped like ``labels``
    """


def safe_softmax_cross_entropy(logits, labels, axis=-1):
    """
    Softmax cross-entropy that tolerates ``-inf`` logits.

    Unlike ``softmax_cross_entropy``, a class whose logit is ``-inf`` and whose
    label is 0 contributes 0 (convention ``0 * log 0 = 0``) instead of NaN.

    Args:
        logits: Unnormalized log probabilities; entries may be ``-inf``
        labels: Valid probability distributions over classes
        axis: Class axis along which to apply the softmax (default: -1)

    Returns:
        Per-example losses, shaped like ``logits`` with ``axis`` removed
    """


def sigmoid_binary_cross_entropy(logits, labels):
    """
    Sigmoid binary cross-entropy, computed elementwise.

    Args:
        logits: Unnormalized log-odds
        labels: Target probabilities in [0, 1] (hard 0/1 labels or soft
            targets), same shape as ``logits``

    Returns:
        Array of losses with the same shape as ``logits``
    """


def poly_loss_cross_entropy(logits, labels, epsilon=2.0):
    """
    PolyLoss: cross-entropy plus a first-order polynomial correction.

    Computes ``softmax_cross_entropy(logits, labels) + epsilon * (1 - p_t)``
    per example, where ``p_t = sum(labels * softmax(logits))``.

    Args:
        logits: Unnormalized log probabilities
        labels: Valid probability distributions (non-negative, sum to 1)
        epsilon: Coefficient of the first polynomial term (default: 2.0)

    Returns:
        Per-example losses, shaped like ``logits`` without the last axis
    """
161
```
162
163
#### Margin-Based Losses
164
165
```python { .api }
166
def hinge_loss(scores, labels):
    """
    Hinge loss for binary classification, computed elementwise.

    Computes ``max(0, 1 - scores * labels)``. In optax these parameters are
    named ``predictor_outputs`` and ``targets``.

    Args:
        scores: Predicted decision-function outputs
        labels: Binary labels, strictly in {-1, +1}

    Returns:
        Array of losses with the same shape as ``scores``
    """


def multiclass_hinge_loss(scores, labels):
    """
    Multiclass (Crammer-Singer style) hinge loss.

    Computes ``max_j(scores[j] + 1 - [j == y]) - scores[y]`` per example,
    where ``y`` is the true class.

    Args:
        scores: Predicted scores, classes on the last axis
        labels: Integer class labels

    Returns:
        Per-example losses, shaped like ``scores`` without the last axis
    """


def perceptron_loss(scores, labels):
    """
    Perceptron loss for binary classification, computed elementwise.

    Computes ``max(0, -scores * labels)``. In optax these parameters are
    named ``predictor_outputs`` and ``targets``.

    Args:
        scores: Predicted decision-function outputs
        labels: Binary labels, strictly in {-1, +1}

    Returns:
        Array of losses with the same shape as ``scores``
    """


def multiclass_perceptron_loss(scores, labels):
    """
    Multiclass perceptron loss.

    Computes ``max_j(scores[j]) - scores[y]`` per example, where ``y`` is the
    true class.

    Args:
        scores: Predicted scores, classes on the last axis
        labels: Integer class labels

    Returns:
        Per-example losses, shaped like ``scores`` without the last axis
    """
213
```
214
215
#### Focal and Sigmoid Losses
216
217
```python { .api }
218
def sigmoid_focal_loss(logits, labels, alpha=None, gamma=2.0):
    """
    Sigmoid focal loss for addressing class imbalance, computed elementwise.

    Scales the sigmoid binary cross-entropy by ``(1 - p_t) ** gamma`` so that
    well-classified examples contribute less. When ``alpha`` is given,
    positives are additionally weighted by ``alpha`` and negatives by
    ``1 - alpha``.

    Args:
        logits: Unnormalized log-odds
        labels: Binary target labels (or probabilities in [0, 1])
        alpha: Optional weight in (0, 1) balancing positive vs. negative
            examples; None disables the weighting (default: None). The focal
            loss paper uses 0.25.
        gamma: Focusing parameter; 0 recovers plain binary cross-entropy
            (default: 2.0)

    Returns:
        Array of losses with the same shape as ``logits``
    """
231
```
232
233
### Structured Prediction Losses
234
235
#### Sequence Losses
236
237
```python { .api }
238
def ctc_loss(
    logits, logit_paddings, labels, label_paddings, blank_id=0, log_epsilon=-1e5
):
    """
    Connectionist Temporal Classification (CTC) loss.

    Sequence lengths are expressed as padding masks, not integer lengths.

    Args:
        logits: (B, T, K) array of unnormalized per-frame class scores; K
            includes the blank class
        logit_paddings: (B, T) array; 1.0 marks padded frames, 0.0 valid ones
        labels: (B, N) array of integer target labels
        label_paddings: (B, N) array; 1.0 marks padded label positions
        blank_id: Index of the blank class (default: 0)
        log_epsilon: Finite stand-in for log(0), for numerical stability
            (default: -1e5)

    Returns:
        (B,) array of per-sequence negative log-likelihoods
    """


def ctc_loss_with_forward_probs(
    logits, logit_paddings, labels, label_paddings, blank_id=0, log_epsilon=-1e5
):
    """
    CTC loss that also returns the forward (alpha) log-probabilities.

    Args:
        logits: (B, T, K) array of unnormalized per-frame class scores; K
            includes the blank class
        logit_paddings: (B, T) array; 1.0 marks padded frames, 0.0 valid ones
        labels: (B, N) array of integer target labels
        label_paddings: (B, N) array; 1.0 marks padded label positions
        blank_id: Index of the blank class (default: 0)
        log_epsilon: Finite stand-in for log(0), for numerical stability
            (default: -1e5)

    Returns:
        Tuple ``(per_seq_loss, logalpha_phi, logalpha_emit)``: the (B,)
        per-sequence losses followed by the forward log-probabilities of the
        blank states, shape (T, B, N + 1), and emission states, shape (T, B, N)
    """
267
```
268
269
#### Ranking and Contrastive Losses
270
271
```python { .api }
272
def ranking_softmax_loss(scores, labels):
    """
    Softmax cross-entropy ranking loss for learning-to-rank.

    Treats the scores of each list as logits and scores them against the
    relevance labels of the same list.

    Args:
        scores: Predicted relevance scores, list items on the last axis
        labels: Target relevance labels, same shape as ``scores``

    Returns:
        Scalar loss, averaged over lists
    """


def triplet_margin_loss(anchor, positive, negative, margin=1.0):
    """
    Triplet margin loss for metric learning.

    Computes ``max(d(anchor, positive) - d(anchor, negative) + margin, 0)``
    per triplet, with ``d`` a distance taken along the last (feature) axis.

    Args:
        anchor: Anchor embeddings
        positive: Positive example embeddings
        negative: Negative example embeddings
        margin: Required gap between positive and negative distances
            (default: 1.0)

    Returns:
        Per-triplet losses, shaped like ``anchor`` without the last axis
    """


def ntxent(embeddings, labels, temperature=0.07):
    """
    Normalized temperature-scaled cross-entropy (NT-Xent) contrastive loss.

    Args:
        embeddings: (batch, features) array of embeddings
        labels: (batch,) integer group ids; embeddings sharing a label form
            positive pairs, e.g. ``[0, 0, 1, 1]`` for two pairs. Labels must
            not all be identical, since anchors would then have no negatives.
        temperature: Temperature scaling parameter (default: 0.07)

    Returns:
        Scalar loss averaged over all positive pairs
    """
310
```
311
312
### Divergence and Information-Theoretic Losses
313
314
#### KL Divergence
315
316
```python { .api }
317
def kl_divergence(log_predictions, targets):
    """
    Kullback-Leibler divergence ``KL(targets || predictions)``.

    Computes ``sum(targets * (log(targets) - log_predictions), axis=-1)``,
    treating ``0 * log(0)`` as 0.

    Args:
        log_predictions: Log probabilities of the predicted distribution
        targets: Target probability distributions (not in log space)

    Returns:
        Per-example divergences, shaped like ``targets`` without the last axis
    """


def kl_divergence_with_log_targets(log_predictions, log_targets):
    """
    KL divergence with targets given in log space.

    Computes ``sum(exp(log_targets) * (log_targets - log_predictions), axis=-1)``,
    avoiding an explicit ``log`` of (possibly tiny) target probabilities.

    Args:
        log_predictions: Log probabilities of the predicted distribution
        log_targets: Log probabilities of the target distribution

    Returns:
        Per-example divergences, shaped like ``log_targets`` without the
        last axis
    """


def convex_kl_divergence(log_predictions, targets):
    """
    Convex generalization of the KL divergence (not the reverse KL).

    Computes ``kl_divergence(log_predictions, targets) +
    sum(exp(log_predictions) - targets, axis=-1)``. The extra term vanishes
    when both arguments are normalized distributions and keeps the loss
    meaningful for unnormalized non-negative measures.

    Args:
        log_predictions: Log of the predicted (possibly unnormalized) measure
        targets: Target (possibly unnormalized) non-negative measure

    Returns:
        Per-example divergences, shaped like ``targets`` without the last axis
    """
352
```
353
354
### Sparsemax and Specialized Losses
355
356
#### Sparsemax Losses
357
358
```python { .api }
359
def sparsemax_loss(logits, labels):
    """
    Binary (two-class) sparsemax loss, computed elementwise.

    For problems with more than two classes use ``multiclass_sparsemax_loss``.

    Args:
        logits: Predicted scores
        labels: Binary target labels (0 or 1), same shape as ``logits``

    Returns:
        Array of losses with the same shape as ``logits``
    """


def multiclass_sparsemax_loss(logits, labels):
    """
    Multiclass sparsemax loss.

    The Fenchel-Young loss induced by sparsemax. It is zero when the sparsemax
    of ``logits`` puts all probability mass on the true class.

    Args:
        logits: Predicted scores, classes on the last axis
        labels: Integer class labels

    Returns:
        Per-example losses, shaped like ``labels``
    """
382
```
383
384
### Loss Utilities
385
386
#### Label Processing
387
388
```python { .api }
389
def smooth_labels(labels, alpha):
    """
    Apply label smoothing to one-hot labels.

    Returns ``(1 - alpha) * labels + alpha / num_classes``, where
    ``num_classes = labels.shape[-1]``.

    Args:
        labels: One-hot encoded labels, classes on the last axis
        alpha: Smoothing strength in [0, 1] (0 = no smoothing, 1 = uniform);
            required, there is no default

    Returns:
        Smoothed labels with the same shape as ``labels``
    """


def make_fenchel_young_loss(max_fun):
    """
    Create a Fenchel-Young loss from a "max" (convex conjugate) function.

    The returned loss evaluates ``max_fun(scores) - <targets, scores>`` over
    the last axis. For example, ``max_fun=jax.scipy.special.logsumexp`` yields
    the softmax cross-entropy.

    Args:
        max_fun: Convex function of the scores, i.e. the conjugate of the
            regularizer that defines the loss

    Returns:
        Fenchel-Young loss function of scores and targets
    """


def softmax_cross_entropy(logits, labels, axis=-1):
    """
    Softmax cross-entropy loss.

    Args:
        logits: Unnormalized log probabilities
        labels: One-hot encoded labels or label probabilities
        axis: Axis along which to apply softmax (default: -1)

    Returns:
        Per-example losses, shaped like ``logits`` with ``axis`` removed
    """


def softmax_cross_entropy_with_integer_labels(logits, labels, axis=-1):
    """
    Softmax cross-entropy with integer labels.

    Args:
        logits: Unnormalized log probabilities
        labels: Integer class labels
        axis: Axis along which to apply softmax (default: -1)

    Returns:
        Per-example losses, shaped like ``labels``
    """


def safe_softmax_cross_entropy(logits, labels, axis=-1):
    """
    Softmax cross-entropy that tolerates ``-inf`` logits.

    A class with a ``-inf`` logit and a zero label contributes 0 instead of NaN.

    Args:
        logits: Unnormalized log probabilities; entries may be ``-inf``
        labels: One-hot encoded labels or label probabilities
        axis: Axis along which to apply softmax (default: -1)

    Returns:
        Per-example losses, shaped like ``logits`` with ``axis`` removed
    """
447
```
448
449
#### Binary Classification
450
451
```python { .api }
452
def sigmoid_binary_cross_entropy(logits, labels):
    """
    Sigmoid binary cross-entropy loss, computed elementwise.

    Args:
        logits: Unnormalized log-odds
        labels: Target probabilities in [0, 1] (hard 0/1 labels or soft
            targets), same shape as ``logits``

    Returns:
        Array of losses with the same shape as ``logits``
    """
463
```
464
465
#### Margin-Based Losses
466
467
```python { .api }
468
def hinge_loss(scores, labels):
    """
    Hinge loss for binary classification, computed elementwise.

    Computes ``max(0, 1 - scores * labels)``.

    Args:
        scores: Prediction scores
        labels: Binary labels, strictly in {-1, 1}

    Returns:
        Array of losses with the same shape as ``scores``
    """
479
```
480
481
#### Focal Loss
482
483
```python { .api }
484
def sigmoid_focal_loss(logits, labels, alpha=None, gamma=2.0):
    """
    Sigmoid focal loss for addressing class imbalance, computed elementwise.

    Down-weights well-classified examples by ``(1 - p_t) ** gamma``; when
    ``alpha`` is set, positives are weighted by ``alpha`` and negatives by
    ``1 - alpha``.

    Args:
        logits: Unnormalized log-odds
        labels: Binary labels (or probabilities in [0, 1])
        alpha: Optional weight in (0, 1) for the positive class; None disables
            the weighting (default: None). The focal loss paper uses 0.25.
        gamma: Focusing parameter (default: 2.0)

    Returns:
        Array of losses with the same shape as ``logits``
    """
497
```
498
499
### Probability Divergences
500
501
```python { .api }
502
def kl_divergence(log_predictions, targets):
    """
    Kullback-Leibler divergence ``KL(targets || predictions)``.

    Computes ``sum(targets * (log(targets) - log_predictions), axis=-1)``,
    treating ``0 * log(0)`` as 0.

    Args:
        log_predictions: Log probabilities of predictions
        targets: Target probability distribution

    Returns:
        Per-example divergences, shaped like ``targets`` without the last axis
    """


def convex_kl_divergence(log_predictions, targets):
    """
    Convex KL divergence.

    Computes ``sum(targets * (log(targets) - log_predictions)
    + exp(log_predictions) - targets, axis=-1)``. The last two terms cancel
    when both distributions are normalized, so it then equals the KL divergence.

    Args:
        log_predictions: Log probabilities of predictions
        targets: Target probability distribution (may be unnormalized)

    Returns:
        Per-example divergences, shaped like ``targets`` without the last axis
    """
525
```
526
527
### Structured Losses
528
529
#### CTC Loss
530
531
```python { .api }
532
def ctc_loss(
    logits, logit_paddings, labels, label_paddings, blank_id=0, log_epsilon=-1e5
):
    """
    Connectionist Temporal Classification (CTC) loss.

    Args:
        logits: (B, T, K) unnormalized scores over the vocabulary, including
            the blank class (log-softmax is applied internally)
        logit_paddings: (B, T) padding mask for logits; 1.0 marks padding
        labels: (B, N) integer target label sequences
        label_paddings: (B, N) padding mask for labels; 1.0 marks padding
        blank_id: Index of the blank class (default: 0)
        log_epsilon: Finite stand-in for log(0) (default: -1e5)

    Returns:
        (B,) array of per-sequence CTC losses
    """


def ctc_loss_with_forward_probs(
    logits, logit_paddings, labels, label_paddings, blank_id=0, log_epsilon=-1e5
):
    """
    CTC loss that also returns the forward log-probabilities.

    Args:
        logits: (B, T, K) unnormalized scores over the vocabulary, including
            the blank class (log-softmax is applied internally)
        logit_paddings: (B, T) padding mask for logits; 1.0 marks padding
        labels: (B, N) integer target label sequences
        label_paddings: (B, N) padding mask for labels; 1.0 marks padding
        blank_id: Index of the blank class (default: 0)
        log_epsilon: Finite stand-in for log(0) (default: -1e5)

    Returns:
        Tuple ``(per_seq_loss, logalpha_phi, logalpha_emit)``: the (B,)
        losses and the forward log-probabilities of the blank states,
        shape (T, B, N + 1), and emission states, shape (T, B, N)
    """
559
```
560
561
### Self-Supervised Losses
562
563
#### Contrastive Learning
564
565
```python { .api }
566
def ntxent(embeddings, labels, temperature=0.07):
    """
    Normalized Temperature-scaled Cross-Entropy (NT-Xent) contrastive loss.

    Args:
        embeddings: (batch, features) array of embeddings
        labels: (batch,) integer group ids; embeddings with equal labels are
            positive pairs, e.g. ``[0, 0, 1, 1]``. Labels must not all be
            identical, since anchors would then have no negatives.
        temperature: Temperature scaling parameter (default: 0.07)

    Returns:
        Scalar contrastive loss averaged over all positive pairs
    """
578
```
579
580
### Label Processing
581
582
```python { .api }
583
def smooth_labels(labels, alpha):
    """
    Apply label smoothing to one-hot labels.

    Returns ``(1 - alpha) * labels + alpha / num_classes``, where
    ``num_classes = labels.shape[-1]``.

    Args:
        labels: One-hot encoded labels, classes on the last axis
        alpha: Smoothing parameter (0 = no smoothing, 1 = uniform)

    Returns:
        Smoothed label distribution with the same shape as ``labels``
    """
594
```
595
596
## Usage Examples
597
598
### Basic Regression
599
600
```python
601
import optax
import jax.numpy as jnp

# Predictions and targets
predictions = jnp.array([1.0, 2.0, 3.0])
targets = jnp.array([1.1, 1.9, 3.2])

# Regression losses are returned per element (shape (3,)); reduce with .mean().
# l2_loss is 0.5 * squared error, so the mean squared error uses squared_error.
mse_loss = optax.squared_error(predictions, targets).mean()
huber_loss_val = optax.huber_loss(predictions, targets, delta=1.0).mean()
```
612
613
### Classification Setup
614
615
```python
616
import optax
import jax.numpy as jnp

# Multi-class classification: 2 examples, 3 classes
logits = jnp.array([[2.0, 1.0, 0.1], [1.0, 3.0, 0.5]])
one_hot_labels = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
integer_labels = jnp.array([0, 1])

# Cross-entropy losses: one value per example (shape (2,)). The two calls
# agree because integer_labels encodes the same classes as one_hot_labels.
ce_loss = optax.softmax_cross_entropy(logits, one_hot_labels)
ce_int_loss = optax.softmax_cross_entropy_with_integer_labels(logits, integer_labels)

# Binary classification: elementwise losses (shape (3,))
binary_logits = jnp.array([0.5, -1.2, 2.1])
binary_labels = jnp.array([1.0, 0.0, 1.0])
binary_loss = optax.sigmoid_binary_cross_entropy(binary_logits, binary_labels)
629
```
630
631
### Training Loop Integration
632
633
```python
634
import jax
import optax

# Assumed to be defined elsewhere: `model_fn(params, x) -> logits` and
# `optimizer`, an optax.GradientTransformation such as optax.adam(1e-3).


def compute_loss(params, batch_x, batch_y):
    """Compute the mean cross-entropy loss for a batch.

    jax.value_and_grad requires a scalar output, so the per-example losses
    returned by optax are averaged.
    """
    logits = model_fn(params, batch_x)
    per_example = optax.softmax_cross_entropy_with_integer_labels(logits, batch_y)
    return per_example.mean()


def train_step(params, opt_state, batch_x, batch_y):
    """Run a single training step; return new params, state, and the loss."""
    # Compute loss and gradients
    loss_val, grads = jax.value_and_grad(compute_loss)(params, batch_x, batch_y)

    # Update parameters. Passing params lets transformations that need them
    # (e.g. weight decay in optax.adamw) work; others simply ignore it.
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)

    return params, opt_state, loss_val
651
```
652
653
### Advanced Loss Combinations
654
655
```python
656
def combined_loss(predictions, targets, params):
    """Combine the task loss with L2 weight regularization into a scalar."""
    # Main task loss: average the per-example cross-entropies
    task_loss = optax.softmax_cross_entropy(predictions, targets).mean()

    # Regularization: 0.5 * sum of squared weights over all parameter leaves.
    # Each leaf is reduced to a scalar first because leaves differ in shape.
    # (jax.tree_leaves was removed from JAX; use jax.tree_util.tree_leaves.)
    l2_reg = sum(optax.l2_loss(p).sum() for p in jax.tree_util.tree_leaves(params))

    # Total loss
    return task_loss + 1e-4 * l2_reg


# With label smoothing (per-example losses averaged to a scalar)
smoothed_labels = optax.smooth_labels(one_hot_labels, alpha=0.1)
smooth_loss = optax.softmax_cross_entropy(logits, smoothed_labels).mean()
670
```