or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

advanced-optimizers.md assignment.md contrib.md index.md losses.md monte-carlo.md optimizers.md perturbations.md projections.md schedules.md second-order.md transformations.md tree-utilities.md utilities.md

losses.mddocs/

0

# Loss Functions

1

2

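Comprehensive collection of loss functions for classification, regression, and structured prediction tasks. These functions provide differentiable objectives for training neural networks and other machine learning models. Note that optax loss functions generally return element-wise or per-example values rather than a batch-averaged scalar (a "Scalar loss value" entry below refers to the value for each individual element or example), so apply a reduction such as `jnp.mean` before passing the result to `jax.grad`.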

3

4

## Capabilities

5

6

### Regression Losses

7

8

#### Mean Squared Error

9

10

```python { .api }

11

def l2_loss(predictions, targets=None):

12

"""

13

L2 loss: element-wise half squared error, `0.5 * (predictions - targets) ** 2` (no mean is taken).

14

15

Args:

16

predictions: Predicted values

17

targets: Target values (default: None, uses zeros if not provided)

18

19

Returns:

20

Element-wise loss values with the same shape as `predictions`

21

"""

22

23

def squared_error(predictions, targets):

24

"""

25

Element-wise squared error, `(predictions - targets) ** 2` (`l2_loss` is 0.5 times this value).

26

27

Args:

28

predictions: Predicted values

29

targets: Target values

30

31

Returns:

32

Scalar loss value

33

"""

34

```

35

36

#### Robust Regression Losses

37

38

```python { .api }

39

def huber_loss(predictions, targets, delta=1.0):

40

"""

41

Huber loss for robust regression.

42

43

Args:

44

predictions: Predicted values

45

targets: Target values

46

delta: Threshold for switching between squared and linear loss (default: 1.0)

47

48

Returns:

49

Scalar loss value

50

"""

51

52

def log_cosh(predictions, targets):

53

"""

54

Log-cosh loss for robust regression.

55

56

Args:

57

predictions: Predicted values

58

targets: Target values

59

60

Returns:

61

Scalar loss value

62

"""

63

```

64

65

#### Distance-Based Losses

66

67

```python { .api }

68

def cosine_distance(predictions, targets):

69

"""

70

Cosine distance loss.

71

72

Args:

73

predictions: Predicted vectors

74

targets: Target vectors

75

76

Returns:

77

Scalar loss value

78

"""

79

80

def cosine_similarity(predictions, targets):

81

"""

82

Cosine similarity (cosine distance is `1 - cosine similarity`).

83

84

Args:

85

predictions: Predicted vectors

86

targets: Target vectors

87

88

Returns:

89

Scalar similarity value

90

"""

91

```

92

93

### Classification Losses

94

95

#### Cross-Entropy Losses

96

97

```python { .api }

98

def softmax_cross_entropy(logits, labels, axis=-1):

99

"""

100

Softmax cross-entropy loss.

101

102

Args:

103

logits: Predicted logits

104

labels: One-hot encoded target labels

105

axis: Axis along which to apply softmax (default: -1)

106

107

Returns:

108

Scalar loss value

109

"""

110

111

def softmax_cross_entropy_with_integer_labels(logits, labels, axis=-1):

112

"""

113

Softmax cross-entropy loss with integer labels.

114

115

Args:

116

logits: Predicted logits

117

labels: Integer target labels

118

axis: Axis along which to apply softmax (default: -1)

119

120

Returns:

121

Per-example loss values (one per label; reduce with `jnp.mean` for a scalar)

122

"""

123

124

def safe_softmax_cross_entropy(logits, labels, axis=-1):

125

"""

126

Numerically stable softmax cross-entropy loss.

127

128

Args:

129

logits: Predicted logits

130

labels: One-hot encoded target labels

131

axis: Axis along which to apply softmax (default: -1)

132

133

Returns:

134

Scalar loss value

135

"""

136

137

def sigmoid_binary_cross_entropy(logits, labels):

138

"""

139

Sigmoid binary cross-entropy loss.

140

141

Args:

142

logits: Predicted logits

143

labels: Binary target labels

144

145

Returns:

146

Scalar loss value

147

"""

148

149

def poly_loss_cross_entropy(logits, labels, epsilon=2.0):

150

"""

151

PolyLoss cross-entropy for improved tail learning.

152

153

Args:

154

logits: Predicted logits

155

labels: One-hot encoded target labels

156

epsilon: Polynomial coefficient (default: 2.0)

157

158

Returns:

159

Scalar loss value

160

"""

161

```

162

163

#### Margin-Based Losses

164

165

```python { .api }

166

def hinge_loss(scores, labels):

167

"""

168

Hinge loss for binary classification.

169

170

Args:

171

scores: Predicted scores

172

labels: Binary labels (+1 or -1)

173

174

Returns:

175

Scalar loss value

176

"""

177

178

def multiclass_hinge_loss(scores, labels):

179

"""

180

Multiclass hinge loss.

181

182

Args:

183

scores: Predicted scores for each class

184

labels: Integer class labels

185

186

Returns:

187

Scalar loss value

188

"""

189

190

def perceptron_loss(scores, labels):

191

"""

192

Perceptron loss for binary classification.

193

194

Args:

195

scores: Predicted scores

196

labels: Binary labels (+1 or -1)

197

198

Returns:

199

Scalar loss value

200

"""

201

202

def multiclass_perceptron_loss(scores, labels):

203

"""

204

Multiclass perceptron loss.

205

206

Args:

207

scores: Predicted scores for each class

208

labels: Integer class labels

209

210

Returns:

211

Scalar loss value

212

"""

213

```

214

215

#### Focal and Sigmoid Losses

216

217

```python { .api }

218

def sigmoid_focal_loss(logits, labels, alpha=0.25, gamma=2.0):

219

"""

220

Sigmoid focal loss for addressing class imbalance.

221

222

Args:

223

logits: Predicted logits

224

labels: Binary target labels

225

alpha: Weighting factor for rare class (default: 0.25)

226

gamma: Focusing parameter (default: 2.0)

227

228

Returns:

229

Scalar loss value

230

"""

231

```

232

233

### Structured Prediction Losses

234

235

#### Sequence Losses

236

237

```python { .api }

238

def ctc_loss(logits, labels, input_lengths, label_lengths, blank=0):

239

"""

240

Connectionist Temporal Classification (CTC) loss.

241

242

Args:

243

logits: Predicted logits for each time step

244

labels: Target sequence labels

245

input_lengths: Length of each input sequence

246

label_lengths: Length of each target sequence

247

blank: Blank token index (default: 0)

248

249

Returns:

250

Scalar loss value

251

"""

252

253

def ctc_loss_with_forward_probs(logits, labels, input_lengths, label_lengths, blank=0):

254

"""

255

CTC loss that also returns forward probabilities.

256

257

Args:

258

logits: Predicted logits for each time step

259

labels: Target sequence labels

260

input_lengths: Length of each input sequence

261

label_lengths: Length of each target sequence

262

blank: Blank token index (default: 0)

263

264

Returns:

265

Tuple of (loss, forward_probs)

266

"""

267

```

268

269

#### Ranking and Contrastive Losses

270

271

```python { .api }

272

def ranking_softmax_loss(scores, labels):

273

"""

274

Ranking loss using softmax for learning to rank tasks.

275

276

Args:

277

scores: Predicted relevance scores

278

labels: Target relevance labels

279

280

Returns:

281

Scalar loss value

282

"""

283

284

def triplet_margin_loss(anchor, positive, negative, margin=1.0):

285

"""

286

Triplet margin loss for metric learning.

287

288

Args:

289

anchor: Anchor embeddings

290

positive: Positive example embeddings

291

negative: Negative example embeddings

292

margin: Margin parameter (default: 1.0)

293

294

Returns:

295

Scalar loss value

296

"""

297

298

def ntxent(query, key, temperature=1.0):

299

"""

300

Normalized temperature-scaled cross-entropy loss for contrastive learning.

301

302

Args:

303

query: Query embeddings

304

key: Key embeddings

305

temperature: Temperature scaling parameter (default: 1.0)

306

307

Returns:

308

Scalar loss value

309

"""

310

```

311

312

### Divergence and Information-Theoretic Losses

313

314

#### KL Divergence

315

316

```python { .api }

317

def kl_divergence(log_predictions, targets):

318

"""

319

Kullback-Leibler divergence.

320

321

Args:

322

log_predictions: Log probabilities of predictions

323

targets: Target probability distributions

324

325

Returns:

326

Scalar divergence value

327

"""

328

329

def kl_divergence_with_log_targets(log_predictions, log_targets):

330

"""

331

KL divergence with log-space targets for numerical stability.

332

333

Args:

334

log_predictions: Log probabilities of predictions

335

log_targets: Log probabilities of targets

336

337

Returns:

338

Scalar divergence value

339

"""

340

341

def convex_kl_divergence(log_predictions, targets):

342

"""

343

Convex variant of the KL divergence (jointly convex in `targets` and `log_predictions`).

344

345

Args:

346

log_predictions: Log probabilities of predictions

347

targets: Target probability distributions

348

349

Returns:

350

Scalar divergence value

351

"""

352

```

353

354

### Sparsemax and Specialized Losses

355

356

#### Sparsemax Losses

357

358

```python { .api }

359

def sparsemax_loss(logits, labels):

360

"""

361

Sparsemax loss for sparse probability distributions.

362

363

Args:

364

logits: Predicted logits

365

labels: Target labels

366

367

Returns:

368

Scalar loss value

369

"""

370

371

def multiclass_sparsemax_loss(logits, labels):

372

"""

373

Multiclass sparsemax loss.

374

375

Args:

376

logits: Predicted logits for each class

377

labels: Integer class labels

378

379

Returns:

380

Scalar loss value

381

"""

382

```

383

384

### Loss Utilities

385

386

#### Label Processing

387

388

```python { .api }

389

def smooth_labels(labels, alpha=0.1):

390

"""

391

Apply label smoothing to one-hot labels.

392

393

Args:

394

labels: One-hot encoded labels

395

alpha: Smoothing parameter (default: 0.1)

396

397

Returns:

398

Smoothed labels

399

"""

400

401

def make_fenchel_young_loss(regularizer):

402

"""

403

Create Fenchel-Young loss from convex regularizer.

404

405

Args:

406

regularizer: Convex regularization function

407

408

Returns:

409

Fenchel-Young loss function

410

"""

411

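```

#### Cross-Entropy Losses

```python { .api }

def softmax_cross_entropy(logits, labels, axis=-1):

"""

Softmax cross-entropy loss.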

412

413

Args:

414

logits: Unnormalized log probabilities

415

labels: One-hot encoded labels or label probabilities

416

axis: Axis along which to apply softmax (default: -1)

417

418

Returns:

419

Scalar loss value

420

"""

421

422

def softmax_cross_entropy_with_integer_labels(logits, labels, axis=-1):

423

"""

424

Softmax cross-entropy with integer labels.

425

426

Args:

427

logits: Unnormalized log probabilities

428

labels: Integer class labels

429

axis: Axis along which to apply softmax (default: -1)

430

431

Returns:

432

Scalar loss value

433

"""

434

435

def safe_softmax_cross_entropy(logits, labels, axis=-1):

436

"""

437

Numerically stable softmax cross-entropy.

438

439

Args:

440

logits: Unnormalized log probabilities

441

labels: One-hot encoded labels or label probabilities

442

axis: Axis along which to apply softmax (default: -1)

443

444

Returns:

445

Scalar loss value

446

"""

447

```

448

449

#### Binary Classification

450

451

```python { .api }

452

def sigmoid_binary_cross_entropy(logits, labels):

453

"""

454

Sigmoid binary cross-entropy loss.

455

456

Args:

457

logits: Unnormalized log probabilities

458

labels: Binary labels (0 or 1)

459

460

Returns:

461

Scalar loss value

462

"""

463

```

464

465

#### Margin-Based Losses

466

467

```python { .api }

468

def hinge_loss(scores, labels):

469

"""

470

Hinge loss for binary classification.

471

472

Args:

473

scores: Prediction scores

474

labels: Binary labels (-1 or 1)

475

476

Returns:

477

Scalar loss value

478

"""

479

```

480

481

#### Focal Loss

482

483

```python { .api }

484

def sigmoid_focal_loss(logits, labels, alpha=0.25, gamma=2.0):

485

"""

486

Sigmoid focal loss for addressing class imbalance.

487

488

Args:

489

logits: Unnormalized log probabilities

490

labels: Binary labels

491

alpha: Weighting factor for rare class (default: 0.25)

492

gamma: Focusing parameter (default: 2.0)

493

494

Returns:

495

Scalar loss value

496

"""

497

```

498

499

### Probability Divergences

500

501

```python { .api }

502

def kl_divergence(log_predictions, targets):

503

"""

504

Kullback-Leibler divergence.

505

506

Args:

507

log_predictions: Log probabilities of predictions

508

targets: Target probability distribution

509

510

Returns:

511

Scalar divergence value

512

"""

513

514

def convex_kl_divergence(log_predictions, targets):

515

"""

516

Convex KL divergence (targets * log(targets/predictions)).

517

518

Args:

519

log_predictions: Log probabilities of predictions

520

targets: Target probability distribution

521

522

Returns:

523

Scalar divergence value

524

"""

525

```

526

527

### Structured Losses

528

529

#### CTC Loss

530

531

```python { .api }

532

def ctc_loss(logits, logit_paddings, labels, label_paddings):

533

"""

534

Connectionist Temporal Classification (CTC) loss.

535

536

Args:

537

logits: Log probabilities over vocabulary

538

logit_paddings: Padding mask for logits

539

labels: Target label sequences

540

label_paddings: Padding mask for labels

541

542

Returns:

543

Scalar CTC loss value

544

"""

545

546

def ctc_loss_with_forward_probs(logits, logit_paddings, labels, label_paddings):

547

"""

548

CTC loss with forward probabilities for additional insights.

549

550

Args:

551

logits: Log probabilities over vocabulary

552

logit_paddings: Padding mask for logits

553

labels: Target label sequences

554

label_paddings: Padding mask for labels

555

556

Returns:

557

Tuple of (loss, forward_probs)

558

"""

559

```

560

561

### Self-Supervised Losses

562

563

#### Contrastive Learning

564

565

```python { .api }

566

def ntxent(query_features, key_features, temperature=1.0):

567

"""

568

Normalized Temperature-scaled Cross-Entropy (NT-Xent) loss for contrastive learning.

569

570

Args:

571

query_features: Query feature vectors

572

key_features: Key feature vectors

573

temperature: Temperature scaling parameter (default: 1.0)

574

575

Returns:

576

Scalar contrastive loss value

577

"""

578

```

579

580

### Label Processing

581

582

```python { .api }

583

def smooth_labels(labels, alpha):

584

"""

585

Apply label smoothing to one-hot labels.

586

587

Args:

588

labels: One-hot encoded labels

589

alpha: Smoothing parameter (0 = no smoothing, 1 = uniform)

590

591

Returns:

592

Smoothed label distribution

593

"""

594

```

595

596

## Usage Examples

597

598

### Basic Regression

599

600

```python

601

import optax

602

import jax.numpy as jnp

603

604

# Predictions and targets
predictions = jnp.array([1.0, 2.0, 3.0])
targets = jnp.array([1.1, 1.9, 3.2])

# Compute losses.
# optax regression losses are element-wise, so each one is averaged explicitly
# to obtain a single value. `squared_error` is used for the MSE because
# `l2_loss` carries an extra factor of 0.5.
mse_loss = jnp.mean(optax.squared_error(predictions, targets))
huber_loss_val = jnp.mean(optax.huber_loss(predictions, targets, delta=1.0))

611

```

612

613

### Classification Setup

614

615

```python

616

# Multi-class classification
# Two examples over three classes: each row of `logits` is one example, and
# `one_hot_labels` / `integer_labels` encode the same targets (classes 0 and 1).

617

logits = jnp.array([[2.0, 1.0, 0.1], [1.0, 3.0, 0.5]])

618

one_hot_labels = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])

619

integer_labels = jnp.array([0, 1])

620

621

# Cross-entropy losses
# NOTE(review): optax losses appear to return one value per example rather
# than a batch-averaged scalar -- confirm against the installed optax version
# and reduce (e.g. with jnp.mean) where a scalar is needed.

622

ce_loss = optax.softmax_cross_entropy(logits, one_hot_labels)

623

ce_int_loss = optax.softmax_cross_entropy_with_integer_labels(logits, integer_labels)

624

625

# Binary classification
# Three independent examples; labels are given as 0.0 / 1.0 floats.

626

binary_logits = jnp.array([0.5, -1.2, 2.1])

627

binary_labels = jnp.array([1.0, 0.0, 1.0])

628

binary_loss = optax.sigmoid_binary_cross_entropy(binary_logits, binary_labels)

629

```

630

631

### Training Loop Integration

632

633

```python

634

import jax

635

636

def compute_loss(params, batch_x, batch_y):
    """Compute the mean cross-entropy loss for a batch.

    Args:
        params: Model parameters, passed through to ``model_fn``.
        batch_x: Batch of inputs, passed through to ``model_fn``.
        batch_y: Integer class labels for the batch.

    Returns:
        Scalar loss averaged over the batch.
    """
    # ``model_fn`` is not defined in this snippet; its output is used as logits.
    logits = model_fn(params, batch_x)
    per_example = optax.softmax_cross_entropy_with_integer_labels(logits, batch_y)
    # The optax loss yields one value per example, while ``jax.value_and_grad``
    # (used in ``train_step``) only accepts scalar-output functions, so the
    # batch is reduced here.
    return per_example.mean()

640

641

def train_step(params, opt_state, batch_x, batch_y):
    """Run a single optimization step.

    Args:
        params: Current model parameters.
        opt_state: Current optimizer state.
        batch_x: Batch of inputs, forwarded to ``compute_loss``.
        batch_y: Batch of labels, forwarded to ``compute_loss``.

    Returns:
        Tuple ``(params, opt_state, loss_val)`` holding the updated
        parameters, the updated optimizer state, and the loss for this batch.
    """
    # Loss and gradients with respect to the first argument (params).
    loss_val, grads = jax.value_and_grad(compute_loss)(params, batch_x, batch_y)

    # ``optimizer`` is expected to be defined outside this snippet. Passing
    # ``params`` keeps the step valid for transformations that need the
    # current parameter values (e.g. weight decay); others simply ignore it.
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)

    return params, opt_state, loss_val

651

```

652

653

### Advanced Loss Combinations

654

655

```python

656

def combined_loss(predictions, targets, params):
    """Combine the task loss with an L2 penalty on the parameters.

    Args:
        predictions: Logits passed to ``optax.softmax_cross_entropy``.
        targets: Target labels passed to ``optax.softmax_cross_entropy``.
        params: Pytree of parameters to regularize.

    Returns:
        Scalar: mean task loss plus ``1e-4`` times the L2 penalty.
    """
    # Main task loss, averaged over the batch so the result is a scalar.
    task_loss = optax.softmax_cross_entropy(predictions, targets).mean()

    # Regularization loss. ``l2_loss`` with no targets compares against zeros
    # and is element-wise, so each leaf is summed to a scalar first; otherwise
    # leaves of different shapes could not be added together.
    # ``jax.tree_util.tree_leaves`` replaces the removed ``jax.tree_leaves``.
    l2_reg = sum(
        jnp.sum(optax.l2_loss(leaf)) for leaf in jax.tree_util.tree_leaves(params)
    )

    # Total loss
    return task_loss + 1e-4 * l2_reg

666

667

# With label smoothing
# Reuses `one_hot_labels` and `logits` from the Classification Setup example.

668

smoothed_labels = optax.smooth_labels(one_hot_labels, alpha=0.1)

669

smooth_loss = optax.softmax_cross_entropy(logits, smoothed_labels)

670

```