or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

activations.mdapplications.mdbackend-config.mdcore-framework.mdindex.mdinitializers.mdlayers.mdlosses-metrics.mdoperations.mdoptimizers.mdpreprocessing.mdregularizers.mdtraining-callbacks.md

docs/losses-metrics.md

0

# Loss Functions and Metrics

1

2

Comprehensive collection of loss functions for training neural networks and metrics for evaluation, covering classification, regression, and specialized tasks with both class-based and function-based APIs.

3

4

## Capabilities

5

6

### Classification Loss Functions

7

8

Loss functions designed for classification tasks including binary, multiclass, and specialized classification scenarios.

9

10

```python { .api }

11

class BinaryCrossentropy:
    """
    Binary cross-entropy loss for binary classification.

    Args:
        from_logits (bool): Whether input is logits or probabilities
        label_smoothing (float): Label smoothing factor
        axis (int): Axis along which to compute loss
        reduction (str): Type of reduction to apply (passed via **kwargs)
        name (str): Name of the loss (passed via **kwargs)
    """
    def __init__(self, from_logits=False, label_smoothing=0.0, axis=-1, **kwargs): ...


class CategoricalCrossentropy:
    """
    Categorical cross-entropy loss for multiclass classification with
    one-hot encoded labels.

    Args:
        from_logits (bool): Whether input is logits or probabilities
        label_smoothing (float): Label smoothing factor
        axis (int): Axis along which to compute loss
    """
    def __init__(self, from_logits=False, label_smoothing=0.0, axis=-1, **kwargs): ...


class SparseCategoricalCrossentropy:
    """
    Sparse categorical cross-entropy for integer labels.

    Args:
        from_logits (bool): Whether input is logits or probabilities
        ignore_class (int, optional): Class index to ignore
        axis (int): Axis along which to compute loss
    """
    def __init__(self, from_logits=False, ignore_class=None, axis=-1, **kwargs): ...


class BinaryFocalCrossentropy:
    """
    Binary focal loss for addressing class imbalance.

    Note: in Keras the first positional parameter of this class is
    ``apply_class_balancing``, so pass every argument below by keyword
    (``BinaryFocalCrossentropy(0.25, 2.0)`` would NOT set alpha/gamma).

    Args:
        alpha (float): Weight for the positive class; negatives get
            ``1 - alpha``. Only used when ``apply_class_balancing=True``,
            otherwise it has no effect.
        gamma (float): Focusing parameter
        from_logits (bool): Whether input is logits or probabilities
        label_smoothing (float): Label smoothing factor
        axis (int): Axis along which to compute loss
        apply_class_balancing (bool): Enable the ``alpha`` class weighting
    """
    def __init__(self, alpha=0.25, gamma=2.0, from_logits=False, label_smoothing=0.0,
                 axis=-1, *, apply_class_balancing=False, **kwargs): ...


class CategoricalFocalCrossentropy:
    """
    Categorical focal loss for multiclass imbalanced datasets.

    Args:
        alpha (float): Weighting factor
        gamma (float): Focusing parameter
        from_logits (bool): Whether input is logits or probabilities
        label_smoothing (float): Label smoothing factor
        axis (int): Axis along which to compute loss
    """
    def __init__(self, alpha=0.25, gamma=2.0, from_logits=False, label_smoothing=0.0,
                 axis=-1, **kwargs): ...


# Function equivalents
def binary_crossentropy(y_true, y_pred, from_logits=False, label_smoothing=0.0, axis=-1): ...
def categorical_crossentropy(y_true, y_pred, from_logits=False, label_smoothing=0.0, axis=-1): ...
def sparse_categorical_crossentropy(y_true, y_pred, from_logits=False, ignore_class=None, axis=-1): ...

73

```

74

75

### Regression Loss Functions

76

77

Loss functions for continuous value prediction tasks with various robustness properties.

78

79

```python { .api }

80

class MeanSquaredError:
    """
    Mean squared error loss for regression.

    Args:
        reduction (str): Type of reduction to apply
        name (str): Name of the loss
    """
    def __init__(self, **kwargs): ...


class MeanAbsoluteError:
    """
    Mean absolute error loss for regression.

    Args:
        reduction (str): Type of reduction to apply
        name (str): Name of the loss
    """
    def __init__(self, **kwargs): ...


class MeanAbsolutePercentageError:
    """
    Mean absolute percentage error for regression.

    Args:
        reduction (str): Type of reduction to apply
        name (str): Name of the loss
    """
    def __init__(self, **kwargs): ...


class MeanSquaredLogarithmicError:
    """
    Mean squared logarithmic error for regression.

    Args:
        reduction (str): Type of reduction to apply
        name (str): Name of the loss
    """
    def __init__(self, **kwargs): ...


class Huber:
    """
    Huber loss for robust regression.

    Args:
        delta (float): Point where loss changes from quadratic to linear
        reduction (str): Type of reduction to apply
        name (str): Name of the loss
    """
    def __init__(self, delta=1.0, **kwargs): ...


class LogCosh:
    """
    Log-cosh loss for regression.

    Args:
        reduction (str): Type of reduction to apply
        name (str): Name of the loss
    """
    def __init__(self, **kwargs): ...


# Function equivalents (one per class above)
def mean_squared_error(y_true, y_pred): ...
def mean_absolute_error(y_true, y_pred): ...
def mean_absolute_percentage_error(y_true, y_pred): ...
def mean_squared_logarithmic_error(y_true, y_pred): ...
def huber(y_true, y_pred, delta=1.0): ...
def log_cosh(y_true, y_pred): ...

146

```

147

148

### Specialized Loss Functions

149

150

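Loss functions for specialized tasks: maximum-margin classification (hinge variants), distribution matching (KL divergence), count data (Poisson), similarity learning (cosine similarity), segmentation (Dice, Tversky), and sequence labeling (CTC).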

151

152

```python { .api }

153

class Hinge:
    """
    Hinge loss, used for maximum-margin classifiers.

    Args:
        reduction (str): How per-sample losses are reduced
        name (str): Loss name
    """
    def __init__(self, **kwargs): ...


class SquaredHinge:
    """Squared variant of the hinge loss for maximum-margin classification."""
    def __init__(self, **kwargs): ...


class CategoricalHinge:
    """Hinge loss generalized to multiclass classification."""
    def __init__(self, **kwargs): ...


class KLDivergence:
    """
    Loss computing the Kullback-Leibler divergence.

    Args:
        reduction (str): How per-sample losses are reduced
        name (str): Loss name
    """
    def __init__(self, **kwargs): ...


class Poisson:
    """
    Poisson loss, suited to count-valued targets.

    Args:
        reduction (str): How per-sample losses are reduced
        name (str): Loss name
    """
    def __init__(self, **kwargs): ...


class CosineSimilarity:
    """
    Loss based on cosine similarity.

    Args:
        axis (int): Axis over which the cosine similarity is taken
        reduction (str): How per-sample losses are reduced
        name (str): Loss name
    """
    def __init__(self, axis=-1, **kwargs): ...


class Dice:
    """
    Dice loss, typically used for segmentation.

    Args:
        axis (int or tuple, optional): Axis (or axes) the Dice score is computed over
        reduction (str): How per-sample losses are reduced
        name (str): Loss name
    """
    def __init__(self, axis=None, **kwargs): ...


class Tversky:
    """
    Tversky loss for segmentation; alpha/beta trade precision against recall.

    Args:
        alpha (float): Penalty weight on false positives
        beta (float): Penalty weight on false negatives
        axis (int or tuple, optional): Axis (or axes) the score is computed over
    """
    def __init__(self, alpha=0.5, beta=0.5, axis=None, **kwargs): ...


class CTC:
    """
    Connectionist Temporal Classification (CTC) loss for sequence labeling.

    Args:
        logits_time_major (bool): True when logits are laid out time-major
        blank_index (int, optional): Index of the blank label
        reduction (str): How per-sample losses are reduced
    """
    def __init__(self, logits_time_major=False, blank_index=None, **kwargs): ...

234

```

235

236

### Classification Metrics

237

238

Metrics for evaluating classification model performance, including accuracy variants and confusion-matrix-based metrics (precision, recall, F1, AUC).

239

240

```python { .api }

241

class Accuracy:
    """
    Generic accuracy metric.

    Args:
        name (str): Name of the metric
        dtype (str): Data type for metric computation
    """
    def __init__(self, name='accuracy', dtype=None, **kwargs): ...


class BinaryAccuracy:
    """
    Binary classification accuracy.

    Args:
        threshold (float): Decision threshold applied to predictions
        name (str): Name of the metric
        dtype (str): Data type for metric computation
    """
    def __init__(self, threshold=0.5, name='binary_accuracy', dtype=None, **kwargs): ...


class CategoricalAccuracy:
    """
    Categorical accuracy for one-hot encoded labels.

    Args:
        name (str): Name of the metric
        dtype (str): Data type for metric computation
    """
    def __init__(self, name='categorical_accuracy', dtype=None, **kwargs): ...


class SparseCategoricalAccuracy:
    """
    Categorical accuracy for integer labels.

    Args:
        name (str): Name of the metric
        dtype (str): Data type for metric computation
    """
    def __init__(self, name='sparse_categorical_accuracy', dtype=None, **kwargs): ...


class TopKCategoricalAccuracy:
    """
    Top-k categorical accuracy for one-hot encoded labels.

    Use SparseTopKCategoricalAccuracy when labels are integer class ids.

    Args:
        k (int): Number of top predictions to consider
        name (str): Name of the metric
        dtype (str): Data type for metric computation
    """
    def __init__(self, k=5, name='top_k_categorical_accuracy', dtype=None, **kwargs): ...


class SparseTopKCategoricalAccuracy:
    """
    Top-k categorical accuracy for integer labels.

    Args:
        k (int): Number of top predictions to consider
        name (str): Name of the metric
        dtype (str): Data type for metric computation
    """
    def __init__(self, k=5, name='sparse_top_k_categorical_accuracy', dtype=None, **kwargs): ...


class Precision:
    """
    Precision metric.

    Args:
        thresholds (float or list, optional): Single threshold, or a list of
            thresholds for multi-threshold precision
        top_k (int, optional): Top-k precision
        class_id (int, optional): Class to compute precision for
        name (str): Name of the metric
        dtype (str): Data type for metric computation
    """
    def __init__(self, thresholds=None, top_k=None, class_id=None, name='precision', dtype=None, **kwargs): ...


class Recall:
    """
    Recall metric.

    Args:
        thresholds (float or list, optional): Single threshold, or a list of
            thresholds for multi-threshold recall
        top_k (int, optional): Top-k recall
        class_id (int, optional): Class to compute recall for
        name (str): Name of the metric
        dtype (str): Data type for metric computation
    """
    def __init__(self, thresholds=None, top_k=None, class_id=None, name='recall', dtype=None, **kwargs): ...


class F1Score:
    """
    F1 score metric.

    Expects one-hot (or multi-hot) ``y_true`` with the same
    ``(batch_size, num_classes)`` shape as ``y_pred``; it does not accept
    integer class-id labels.

    Args:
        average (str, optional): Averaging strategy ('micro', 'macro', 'weighted', None)
        threshold (float, optional): Decision threshold applied to predictions
        name (str): Name of the metric
        dtype (str): Data type for metric computation
    """
    def __init__(self, average=None, threshold=None, name='f1_score', dtype=None, **kwargs): ...


class AUC:
    """
    Area under a curve metric: the ROC curve by default, or the
    precision-recall curve when ``curve='PR'``.

    Args:
        num_thresholds (int): Number of thresholds used to discretize the curve
        curve (str): Type of curve ('ROC' or 'PR')
        summation_method (str): Method for approximating the area
            ('interpolation', 'minoring', or 'majoring')
        name (str): Name of the metric
        dtype (str): Data type for metric computation
    """
    def __init__(self, num_thresholds=200, curve='ROC', summation_method='interpolation',
                 name='auc', dtype=None, **kwargs): ...

344

```

345

346

### Regression Metrics

347

348

Metrics for evaluating regression model performance.

349

350

```python { .api }

351

class MeanSquaredError:
    """Regression metric reporting the mean squared error."""
    def __init__(self, name='mean_squared_error', dtype=None, **kwargs): ...


class RootMeanSquaredError:
    """Regression metric reporting the root of the mean squared error."""
    def __init__(self, name='root_mean_squared_error', dtype=None, **kwargs): ...


class MeanAbsoluteError:
    """Regression metric reporting the mean absolute error."""
    def __init__(self, name='mean_absolute_error', dtype=None, **kwargs): ...


class MeanAbsolutePercentageError:
    """Regression metric reporting the mean absolute percentage error."""
    def __init__(self, name='mean_absolute_percentage_error', dtype=None, **kwargs): ...


class R2Score:
    """
    Coefficient of determination (R²).

    Args:
        class_aggregation (str): Strategy for combining multiclass R² values
        num_regressors (int, optional): Regressor count, used for adjusted R²
        name (str): Metric name
        dtype (str): Data type used for computation
    """
    def __init__(self, class_aggregation='uniform_average', num_regressors=0,
                 name='r2_score', dtype=None, **kwargs): ...


class CosineSimilarity:
    """
    Metric reporting cosine similarity.

    Args:
        axis (int): Axis over which the cosine similarity is taken
        name (str): Metric name
        dtype (str): Data type used for computation
    """
    def __init__(self, axis=-1, name='cosine_similarity', dtype=None, **kwargs): ...

390

```

391

392

### Segmentation Metrics

393

394

Metrics for evaluating image segmentation and pixel-wise classification tasks.

395

396

```python { .api }

397

class IoU:
    """
    Intersection over Union (also called the Jaccard index).

    Args:
        num_classes (int): Total number of classes
        target_class_ids (list, optional): Subset of classes to report IoU for
        threshold (float, optional): Cut-off used to binarize predictions
        name (str): Metric name
        dtype (str): Data type used for computation
    """
    def __init__(self, num_classes, target_class_ids=None, threshold=None,
                 name='iou', dtype=None, **kwargs): ...


class MeanIoU:
    """
    Intersection over Union averaged across classes.

    Args:
        num_classes (int): Total number of classes
        name (str): Metric name
        dtype (str): Data type used for computation
    """
    def __init__(self, num_classes, name='mean_iou', dtype=None, **kwargs): ...


class BinaryIoU:
    """
    Intersection over Union for binary predictions.

    Args:
        target_class_ids (list, optional): Class ids to report IoU for
        threshold (float): Cut-off used to binarize predictions
        name (str): Metric name
        dtype (str): Data type used for computation
    """
    def __init__(self, target_class_ids=None, threshold=0.5, name='binary_iou', dtype=None, **kwargs): ...

433

```

434

435

### Utility Functions

436

437

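Helpers for looking up and (de)serializing losses and metrics. `keras.losses` and `keras.metrics` each provide `get`, `serialize`, and `deserialize` under the same names, so call them through their module (for example `keras.losses.get("mse")` versus `keras.metrics.get("mse")`).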

438

439

```python { .api }

440

# ---- Loss utilities (the `losses` module) ----

def get(identifier):
    """Look up a loss function by name, or return a callable identifier."""


def serialize(loss):
    """Turn a loss into a JSON-serializable config dict."""


def deserialize(config, custom_objects=None):
    """Rebuild a loss from its config dict."""


# ---- Metric utilities (the `metrics` module) ----
# Same names as the loss helpers above; they belong to a different module,
# so in real code call them module-qualified rather than importing both.

def get(identifier):
    """Look up a metric by name, or return a callable identifier."""


def serialize(metric):
    """Turn a metric into a JSON-serializable config dict."""


def deserialize(config, custom_objects=None):
    """Rebuild a metric from its config dict."""

459

```

460

461

## Usage Examples

462

463

### Using Loss Functions in Model Compilation

464

465

```python

466

import keras
from keras import layers, losses, metrics

# Declare the input with keras.Input; passing `input_shape=` to a layer is
# deprecated in Keras 3.
model = keras.Sequential([
    keras.Input(shape=(784,)),
    layers.Dense(64, activation='relu'),
    layers.Dropout(0.2),
    layers.Dense(10, activation='softmax'),
])

# Using string identifiers
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# Using class instances for more control.
# Labels are integer class ids (sparse), so every label-consuming metric must
# be a "Sparse" variant too: TopKCategoricalAccuracy expects one-hot labels.
model.compile(
    optimizer='adam',
    loss=losses.SparseCategoricalCrossentropy(from_logits=False),  # model ends in softmax
    metrics=[
        metrics.SparseCategoricalAccuracy(),
        metrics.SparseTopKCategoricalAccuracy(k=3),
    ],
)

491

```

492

493

### Multi-output Model with Different Losses

494

495

```python

496

import keras
from keras import layers, losses, metrics

# Define inputs
inputs = keras.Input(shape=(784,))
x = layers.Dense(64, activation='relu')(inputs)

# Multiple outputs
classification_output = layers.Dense(10, activation='softmax', name='classification')(x)
regression_output = layers.Dense(1, name='regression')(x)

model = keras.Model(inputs=inputs, outputs=[classification_output, regression_output])

# Different losses for different outputs.
# F1Score requires y_true with the same (batch, num_classes) shape as y_pred,
# so the classification targets here are one-hot (e.g. produced with
# keras.utils.to_categorical) and the matching non-sparse loss and accuracy
# are used. With integer labels, drop F1Score and use the Sparse* variants.
model.compile(
    optimizer='adam',
    loss={
        'classification': losses.CategoricalCrossentropy(),
        'regression': losses.MeanSquaredError(),
    },
    metrics={
        'classification': [metrics.CategoricalAccuracy(), metrics.F1Score()],
        'regression': [metrics.MeanAbsoluteError(), metrics.R2Score()],
    },
    loss_weights={'classification': 1.0, 'regression': 0.5},
)

522

```

523

524

### Custom Loss Function

525

526

```python

527

import keras
from keras import ops


def focal_loss(alpha=0.25, gamma=2.0, from_logits=True):
    """Build a binary focal loss function.

    Args:
        alpha: Weight applied to positive examples (``y_true == 1``);
            negatives are weighted by ``1 - alpha``.
        gamma: Focusing exponent on ``(1 - pt)``; larger values shrink the
            contribution of well-classified examples.
        from_logits: Whether ``y_pred`` holds raw logits (sigmoid is applied
            here) or probabilities from a sigmoid output layer. Defaults to
            True, matching the behavior of the original example.

    Returns:
        ``loss_fn(y_true, y_pred)`` returning the mean focal loss.
    """
    def loss_fn(y_true, y_pred):
        y_pred = ops.convert_to_tensor(y_pred)
        y_true = ops.cast(y_true, y_pred.dtype)

        # Only squash when the model emits logits; applying sigmoid to values
        # that are already probabilities would silently distort the loss.
        if from_logits:
            y_pred = ops.sigmoid(y_pred)

        # Keep probabilities strictly inside (0, 1) so log() stays finite.
        epsilon = keras.config.epsilon()
        y_pred = ops.clip(y_pred, epsilon, 1.0 - epsilon)

        is_positive = ops.equal(y_true, 1.0)
        pt = ops.where(is_positive, y_pred, 1.0 - y_pred)  # probability of the true class
        alpha_t = ops.where(is_positive, alpha, 1.0 - alpha)
        focal_weight = alpha_t * ops.power(1.0 - pt, gamma)

        return ops.mean(focal_weight * -ops.log(pt))

    return loss_fn


# Use custom loss. Assumes `model` has a single output unit producing logits
# for binary (0/1) labels; pass from_logits=False if it ends in a sigmoid.
model.compile(
    optimizer='adam',
    loss=focal_loss(alpha=0.25, gamma=2.0, from_logits=True),
    # With logits the decision boundary is 0, not the default 0.5.
    metrics=[keras.metrics.BinaryAccuracy(threshold=0.0)],
)

553

```

554

555

### Custom Metric

556

557

```python

558

import keras
from keras import ops


class F2Score(keras.metrics.Metric):
    """F-beta score with beta = 2, which weights recall above precision.

    Running state lives in built-in Precision and Recall metrics; the result
    combines them as (1 + 2**2) * p * r / (2**2 * p + r).
    """

    def __init__(self, name='f2_score', **kwargs):
        super().__init__(name=name, **kwargs)
        self.precision = keras.metrics.Precision()
        self.recall = keras.metrics.Recall()

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Feed the same batch to both sub-metrics, precision first.
        for tracker in (self.precision, self.recall):
            tracker.update_state(y_true, y_pred, sample_weight)

    def result(self):
        precision_value = self.precision.result()
        recall_value = self.recall.result()
        numerator = 5 * precision_value * recall_value
        # Small constant avoids division by zero before any positives are seen.
        denominator = 4 * precision_value + recall_value + 1e-8
        return numerator / denominator

    def reset_state(self):
        for tracker in (self.precision, self.recall):
            tracker.reset_state()


# Use custom metric
model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=[F2Score(), 'accuracy'],
)

586

```