or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

activations.mdapplications.mddata-utils.mdindex.mdinitializers.mdlayers.mdmodels.mdoperations.mdrandom.mdregularizers.mdsaving.mdtraining.md

training.mddocs/

# Training and Optimization

Optimizers, loss functions, metrics, and callbacks for training neural networks. These components control how models learn from data and how training progress is monitored and controlled.

## Capabilities

### Optimizers

Optimization algorithms that update model parameters during training to minimize the loss function.

```python { .api }

class Optimizer:
    def __init__(self, learning_rate=0.001, name=None, **kwargs):
        """
        Common base class from which every optimizer derives.

        Parameters:
        - learning_rate: Learning rate used at the start of training
        - name: Optional name for this optimizer instance
        """

    def apply_gradients(self, grads_and_vars):
        """
        Update variables using their gradients.

        Parameters:
        - grads_and_vars: List of (gradient, variable) pairs to apply
        """

class SGD(Optimizer):
    def __init__(self, learning_rate=0.01, momentum=0.0, nesterov=False, **kwargs):
        """
        Stochastic gradient descent, optionally with momentum.

        Parameters:
        - learning_rate: Step size applied to each update
        - momentum: Momentum factor
        - nesterov: If True, use Nesterov momentum
        """

class Adam(Optimizer):
    def __init__(
        self,
        learning_rate=0.001,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=1e-7,
        amsgrad=False,
        weight_decay=None,
        clipnorm=None,
        clipvalue=None,
        global_clipnorm=None,
        use_ema=False,
        ema_momentum=0.99,
        ema_overwrite_frequency=None,
        loss_scale_factor=None,
        gradient_accumulation_steps=None,
        **kwargs,
    ):
        """
        Adam optimizer.

        Parameters:
        - learning_rate: Learning rate
        - beta_1: Exponential decay rate for first moment estimates
        - beta_2: Exponential decay rate for second moment estimates
        - epsilon: Small constant for numerical stability
        - amsgrad: Whether to apply the AMSGrad variant
        - weight_decay: Weight decay coefficient
        - clipnorm: If set, the gradient of each variable is clipped
          individually so that its norm is no higher than this value
          (per-variable clipping; see global_clipnorm for the global variant)
        - clipvalue: If set, each gradient is clipped element-wise to the
          range [-clipvalue, clipvalue]
        - global_clipnorm: If set, the gradients of all variables are clipped
          together so that their global norm is no higher than this value
        - use_ema: Whether to maintain an exponential moving average (EMA)
          of the model weights
        - ema_momentum: Momentum used to compute the EMA (only used when
          use_ema=True)
        - ema_overwrite_frequency: Number of steps between overwrites of the
          model variables with their EMA; None means the variables are not
          overwritten during training
        - loss_scale_factor: Factor the loss is multiplied by before
          gradients are computed; the gradients are divided by the same
          factor before the variables are updated
        - gradient_accumulation_steps: If set, variables are updated only
          every this many steps, using the average of the gradients
          accumulated since the last update
        """

class AdamW(Optimizer):
    def __init__(
        self,
        learning_rate=0.001,
        weight_decay=0.004,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=1e-7,
        amsgrad=False,
        **kwargs,
    ):
        """
        Adam variant whose weight decay is decoupled from the gradient update.

        Parameters:
        - learning_rate: Learning rate
        - weight_decay: Decoupled weight decay coefficient
        - beta_1: Exponential decay rate for the first moment estimates
        - beta_2: Exponential decay rate for the second moment estimates
        - epsilon: Small constant added for numerical stability
        - amsgrad: If True, use the AMSGrad variant
        """

class RMSprop(Optimizer):
    def __init__(
        self,
        learning_rate=0.001,
        rho=0.9,
        momentum=0.0,
        epsilon=1e-7,
        centered=False,
        **kwargs,
    ):
        """
        RMSprop optimizer.

        Parameters:
        - learning_rate: Learning rate
        - rho: Discount factor for the history of past gradients
        - momentum: Momentum factor
        - epsilon: Small constant added for numerical stability
        - centered: If True, normalize gradients by their estimated variance
        """

class Adagrad(Optimizer):
    def __init__(
        self,
        learning_rate=0.001,
        initial_accumulator_value=0.1,
        epsilon=1e-7,
        **kwargs,
    ):
        """
        Adagrad optimizer.

        Parameters:
        - learning_rate: Learning rate
        - initial_accumulator_value: Starting value of the accumulators
        - epsilon: Small constant added for numerical stability
        """

class Adadelta(Optimizer):
    def __init__(self, learning_rate=0.001, rho=0.95, epsilon=1e-7, **kwargs):
        """
        Adadelta optimizer.

        Parameters:
        - learning_rate: Learning rate
        - rho: Decay rate
        - epsilon: Small constant added for numerical stability
        """

```

### Loss Functions

Functions that measure the discrepancy between predicted and target values, guiding the optimization process.

```python { .api }

class Loss:
    def __init__(self, reduction='sum_over_batch_size', name=None, **kwargs):
        """
        Common base class from which every loss function derives.

        Parameters:
        - reduction: Type of reduction applied to the computed loss
        - name: Optional name for this loss instance
        """

    def __call__(self, y_true, y_pred, sample_weight=None):
        """
        Evaluate the loss.

        Parameters:
        - y_true: Target (ground truth) values
        - y_pred: Model predictions
        - sample_weight: Optional per-sample weights

        Returns:
        The loss value
        """

class SparseCategoricalCrossentropy(Loss):
    def __init__(self, from_logits=False, ignore_class=None, **kwargs):
        """
        Categorical crossentropy for integer-encoded class labels.

        Parameters:
        - from_logits: True if predictions are raw logits, False if they are
          probabilities
        - ignore_class: Optional class index to exclude from the loss
        """

class CategoricalCrossentropy(Loss):
    def __init__(self, from_logits=False, label_smoothing=0.0, **kwargs):
        """
        Categorical crossentropy for one-hot encoded labels.

        Parameters:
        - from_logits: True if predictions are raw logits, False if they are
          probabilities
        - label_smoothing: Amount of label smoothing to apply
        """

class BinaryCrossentropy(Loss):
    def __init__(self, from_logits=False, label_smoothing=0.0, **kwargs):
        """
        Crossentropy for binary (two-class) labels.

        Parameters:
        - from_logits: True if predictions are raw logits, False if they are
          probabilities
        - label_smoothing: Amount of label smoothing to apply
        """

class MeanSquaredError(Loss):
    def __init__(self, **kwargs):
        """Loss equal to the mean of the squared prediction errors."""

class MeanAbsoluteError(Loss):
    def __init__(self, **kwargs):
        """Loss equal to the mean of the absolute prediction errors."""

class Huber(Loss):
    def __init__(self, delta=1.0, **kwargs):
        """
        Huber loss: quadratic for small errors, linear for large ones.

        Parameters:
        - delta: Error magnitude at which the loss switches from quadratic
          to linear
        """

class KLDivergence(Loss):
    def __init__(self, **kwargs):
        """Loss based on the Kullback-Leibler divergence."""

class CosineSimilarity(Loss):
    def __init__(self, axis=-1, **kwargs):
        """
        Loss based on the cosine similarity between targets and predictions.

        Parameters:
        - axis: Axis along which the cosine similarity is computed
        """

```

### Metrics

Objects for monitoring model performance during training and evaluation. Unlike losses, metrics are not used to compute gradients, so they do not affect the optimization process.

```python { .api }

class Metric:
    def __init__(self, name=None, dtype=None, **kwargs):
        """
        Common base class from which every metric derives.

        Parameters:
        - name: Optional name for this metric
        - dtype: Data type used for the metric's computations
        """

    def update_state(self, y_true, y_pred, sample_weight=None):
        """
        Accumulate a new batch of observations into the metric state.

        Parameters:
        - y_true: Target (ground truth) values
        - y_pred: Model predictions
        - sample_weight: Optional per-sample weights
        """

    def result(self):
        """
        Compute the current metric value from the accumulated state.

        Returns:
        The metric value as a tensor
        """

    def reset_state(self):
        """Clear all accumulated metric state."""

class Accuracy(Metric):
    def __init__(self, name='accuracy', dtype=None, **kwargs):
        """Classification accuracy: how often predictions equal the labels."""

class SparseCategoricalAccuracy(Metric):
    def __init__(self, name='sparse_categorical_accuracy', dtype=None, **kwargs):
        """How often predictions match integer-encoded class labels."""

class CategoricalAccuracy(Metric):
    def __init__(self, name='categorical_accuracy', dtype=None, **kwargs):
        """How often predictions match one-hot encoded labels."""

class TopKCategoricalAccuracy(Metric):
    def __init__(
        self,
        k=5,
        name='top_k_categorical_accuracy',
        dtype=None,
        **kwargs,
    ):
        """
        How often the target class is among the top-k predictions.

        Parameters:
        - k: How many of the highest-scoring predictions to consider
        """

class Precision(Metric):
    def __init__(
        self,
        thresholds=None,
        top_k=None,
        class_id=None,
        name=None,
        dtype=None,
        **kwargs,
    ):
        """
        Precision metric.

        Parameters:
        - thresholds: Optional decision threshold(s) for binary classification
        - top_k: How many of the highest-scoring predictions to consider
        - class_id: Optional single class to compute the metric for
        """

class Recall(Metric):
    def __init__(
        self,
        thresholds=None,
        top_k=None,
        class_id=None,
        name=None,
        dtype=None,
        **kwargs,
    ):
        """
        Recall metric.

        Parameters:
        - thresholds: Optional decision threshold(s) for binary classification
        - top_k: How many of the highest-scoring predictions to consider
        - class_id: Optional single class to compute the metric for
        """

class AUC(Metric):
    def __init__(
        self,
        num_thresholds=200,
        curve='ROC',
        summation_method='interpolation',
        name=None,
        dtype=None,
        **kwargs,
    ):
        """
        Area under a curve (ROC or precision-recall).

        Parameters:
        - num_thresholds: Number of thresholds used to approximate the curve
        - curve: Which curve to integrate, 'ROC' or 'PR'
        - summation_method: How the area under the curve is approximated
        """

class F1Score(Metric):
    def __init__(
        self,
        average=None,
        threshold=None,
        name='f1_score',
        dtype=None,
        **kwargs,
    ):
        """
        F1 score metric.

        Parameters:
        - average: Averaging mode: 'micro', 'macro', 'weighted', or None
        - threshold: Decision threshold for binary classification
        """

class MeanSquaredError(Metric):
    def __init__(self, name='mean_squared_error', dtype=None, **kwargs):
        """Metric tracking the mean of the squared prediction errors."""

class MeanAbsoluteError(Metric):
    def __init__(self, name='mean_absolute_error', dtype=None, **kwargs):
        """Metric tracking the mean of the absolute prediction errors."""

class RootMeanSquaredError(Metric):
    def __init__(self, name='root_mean_squared_error', dtype=None, **kwargs):
        """Metric tracking the square root of the mean squared error."""

```

### Callbacks

Objects that hook into various stages of training to perform actions such as saving the model, adjusting the learning rate, or stopping training early.

```python { .api }

class Callback:
    def __init__(self):
        """Common base class from which every callback derives."""

    # Epoch-level hooks.
    def on_epoch_begin(self, epoch, logs=None):
        """Hook invoked when an epoch starts."""

    def on_epoch_end(self, epoch, logs=None):
        """Hook invoked when an epoch finishes."""

    # Batch-level hooks.
    def on_batch_begin(self, batch, logs=None):
        """Hook invoked when a batch starts."""

    def on_batch_end(self, batch, logs=None):
        """Hook invoked when a batch finishes."""

    # Whole-run hooks.
    def on_train_begin(self, logs=None):
        """Hook invoked once, before training starts."""

    def on_train_end(self, logs=None):
        """Hook invoked once, after training finishes."""

class ModelCheckpoint(Callback):
    def __init__(
        self,
        filepath,
        monitor='val_loss',
        verbose=0,
        save_best_only=False,
        save_weights_only=False,
        mode='auto',
        save_freq='epoch',
        **kwargs,
    ):
        """
        Periodically save the model or its weights.

        Parameters:
        - filepath: Destination path for the saved model/weights
        - monitor: Name of the metric used to decide when to save
        - verbose: Verbosity mode
        - save_best_only: If True, save only when the monitored metric improves
        - save_weights_only: If True, save only the weights
        - mode: One of {'auto', 'min', 'max'}
        - save_freq: 'epoch', or an integer number of batches between saves
        """

class EarlyStopping(Callback):
    def __init__(
        self,
        monitor='val_loss',
        min_delta=0,
        patience=0,
        verbose=0,
        mode='auto',
        baseline=None,
        restore_best_weights=False,
        **kwargs,
    ):
        """
        Stop training once the monitored metric stops improving.

        Parameters:
        - monitor: Name of the metric to watch
        - min_delta: Smallest change that counts as an improvement
        - patience: Number of epochs without improvement to wait before stopping
        - verbose: Verbosity mode
        - mode: One of {'auto', 'min', 'max'}
        - baseline: Baseline value for the monitored metric
        - restore_best_weights: If True, restore the model weights from the
          best epoch
        """

class ReduceLROnPlateau(Callback):
    def __init__(
        self,
        monitor='val_loss',
        factor=0.1,
        patience=10,
        verbose=0,
        mode='auto',
        min_delta=1e-4,
        cooldown=0,
        min_lr=0,
        **kwargs,
    ):
        """
        Lower the learning rate once the monitored metric stops improving.

        Parameters:
        - monitor: Name of the metric to watch
        - factor: Multiplier applied to the learning rate on each reduction
        - patience: Number of epochs without improvement to wait
        - verbose: Verbosity mode
        - mode: One of {'auto', 'min', 'max'}
        - min_delta: Threshold for counting a new optimum
        - cooldown: Number of epochs to wait after a reduction before
          resuming normal operation
        - min_lr: Lower bound for the learning rate
        """

class LearningRateScheduler(Callback):
    def __init__(self, schedule, verbose=0, **kwargs):
        """
        Learning rate scheduler.

        At the beginning of every epoch, the learning rate returned by
        `schedule` is applied to the optimizer.

        Parameters:
        - schedule: Function that takes the epoch index (0-based) and the
          current learning rate, and returns the new learning rate,
          e.g. `def schedule(epoch, lr): ...`
        - verbose: Verbosity mode
        """

class TensorBoard(Callback):
    def __init__(
        self,
        log_dir='logs',
        histogram_freq=0,
        write_graph=True,
        write_images=False,
        write_steps_per_second=False,
        update_freq='epoch',
        **kwargs,
    ):
        """
        Write training logs for visualization in TensorBoard.

        Parameters:
        - log_dir: Directory where TensorBoard logs are written
        - histogram_freq: How often (in epochs) to write histograms
        - write_graph: If True, write the computation graph
        - write_images: If True, write model weights as images
        - write_steps_per_second: If True, log the training steps per second
        - update_freq: 'batch', 'epoch', or an integer number of batches
        """

class CSVLogger(Callback):
    def __init__(self, filename, separator=',', append=False, **kwargs):
        """
        Write per-epoch results to a CSV file.

        Parameters:
        - filename: Path of the CSV file
        - separator: String placed between fields in the CSV file
        - append: If True, append to an existing file instead of overwriting it
        """

```

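## Usage Examples

The examples below assume the training arrays `x_train` and `y_train` are already defined. Where used, the validation arrays `x_val` and `y_val` must also be defined. Inputs have shape `(num_samples, 784)` and labels are integer class indices.

### Basic Training Setup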

```python

import keras
from keras import layers, optimizers, losses, metrics

# Build model. Declare the input with keras.Input instead of passing
# `input_shape` to the first layer (Keras 3 warns about the latter).
model = keras.Sequential([
    keras.Input(shape=(784,)),
    layers.Dense(128, activation='relu'),
    layers.Dropout(0.2),
    layers.Dense(10, activation='softmax'),
])

# Compile with custom optimizer and metrics.
# Labels are integer class indices (hence the *Sparse* loss), so the top-k
# metric must be the sparse variant too: TopKCategoricalAccuracy expects
# one-hot encoded labels.
model.compile(
    optimizer=optimizers.Adam(learning_rate=0.001),
    loss=losses.SparseCategoricalCrossentropy(),
    metrics=[
        metrics.SparseCategoricalAccuracy(),
        metrics.SparseTopKCategoricalAccuracy(k=3),
    ],
)

```

### Training with Callbacks

This example reuses the `model` compiled in the previous section.

```python

from keras import callbacks

# Define callbacks.
# The model above tracks SparseCategoricalAccuracy, which is logged as
# 'sparse_categorical_accuracy' ('val_sparse_categorical_accuracy' on the
# validation data). Monitoring a name that never appears in the logs (such as
# 'val_accuracy' here) makes a save_best_only checkpoint warn and never save.
# If you compiled with metrics=['accuracy'] instead, monitor 'val_accuracy'.
checkpoint = callbacks.ModelCheckpoint(
    'best_model.keras',
    monitor='val_sparse_categorical_accuracy',
    mode='max',  # higher accuracy is better
    save_best_only=True,
    verbose=1,
)

early_stop = callbacks.EarlyStopping(
    monitor='val_loss',
    patience=5,
    restore_best_weights=True,
)

reduce_lr = callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.2,
    patience=3,
    min_lr=1e-7,
)

# Train with callbacks
history = model.fit(
    x_train, y_train,
    batch_size=32,
    epochs=100,
    validation_data=(x_val, y_val),
    callbacks=[checkpoint, early_stop, reduce_lr],
)

```

### Custom Learning Rate Schedule

```python

import math

from keras import callbacks


def lr_schedule(epoch, lr):
    # Leave the learning rate untouched for epochs 0-9; from epoch 10 on,
    # shrink it by a factor of exp(-0.1) every epoch.
    if epoch >= 10:
        return lr * math.exp(-0.1)
    return lr


lr_scheduler = callbacks.LearningRateScheduler(lr_schedule, verbose=1)

model.fit(
    x_train,
    y_train,
    epochs=50,
    callbacks=[lr_scheduler],
)

```

### Multi-GPU Training

```python

517

import keras

518

519

# Create distributed strategy

520

strategy = keras.distribute.MirroredStrategy()

521

522

with strategy.scope():

523

# Create model within strategy scope

524

model = keras.Sequential([

525

layers.Dense(128, activation='relu', input_shape=(784,)),

526

layers.Dense(10, activation='softmax')

527

])

528

529

model.compile(

530

optimizer='adam',

531

loss='sparse_categorical_crossentropy',

532

metrics=['accuracy']

533

)

534

535

# Train on multiple GPUs

536

model.fit(x_train, y_train, epochs=10)

537

```