or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

activations.md applications.md backend-config.md core-framework.md index.md initializers.md layers.md losses-metrics.md operations.md optimizers.md preprocessing.md regularizers.md training-callbacks.md

docs/training-callbacks.md

0

# Training and Callbacks

1

2

Training utilities, callbacks for monitoring and controlling training processes, and model persistence functionality for saving and loading models during and after training.

3

4

## Capabilities

5

6

### Training Control Callbacks

7

8

Callbacks that control the training process based on monitored metrics.

9

10

```python { .api }

11

class EarlyStopping:

12

"""

13

Stop training when monitored metric stops improving.

14

15

Args:

16

monitor (str): Metric to monitor

17

min_delta (float): Minimum change to qualify as improvement

18

patience (int): Number of epochs with no improvement after which training is stopped

19

verbose (int): Verbosity mode

20

mode (str): 'auto', 'min', or 'max'

21

baseline (float, optional): Baseline value for monitored metric

22

restore_best_weights (bool): Whether to restore best weights

23

start_from_epoch (int): Epoch to start monitoring from

24

"""

25

def __init__(self, monitor='val_loss', min_delta=0, patience=0, verbose=0,

26

mode='auto', baseline=None, restore_best_weights=False,

27

start_from_epoch=0, **kwargs): ...

28

29

class ReduceLROnPlateau:

30

"""

31

Reduce learning rate when metric stops improving.

32

33

Args:

34

monitor (str): Metric to monitor

35

factor (float): Factor to reduce learning rate by

36

patience (int): Number of epochs with no improvement after which the learning rate is reduced

37

verbose (int): Verbosity mode

38

mode (str): 'auto', 'min', or 'max'

39

min_delta (float): Minimum change to qualify as improvement

40

cooldown (int): Number of epochs to wait before resuming normal operation

41

min_lr (float): Lower bound on learning rate

42

"""

43

def __init__(self, monitor='val_loss', factor=0.1, patience=10, verbose=0,

44

mode='auto', min_delta=1e-4, cooldown=0, min_lr=0, **kwargs): ...

45

46

class LearningRateScheduler:

47

"""

48

Learning rate scheduler with custom schedule function.

49

50

Args:

51

schedule (callable): Function that takes the epoch index and current learning rate and returns the new learning rate

52

verbose (int): Verbosity mode

53

"""

54

def __init__(self, schedule, verbose=0, **kwargs): ...

55

56

class TerminateOnNaN:

57

"""Terminate training when loss becomes NaN."""

58

def __init__(self, **kwargs): ...

59

```

60

61

### Model Persistence Callbacks

62

63

Callbacks for saving model checkpoints and handling training state.

64

65

```python { .api }

66

class ModelCheckpoint:

67

"""

68

Save model checkpoints during training.

69

70

Args:

71

filepath (str): Path to save model files

72

monitor (str): Metric to monitor for best model

73

verbose (int): Verbosity mode

74

save_best_only (bool): Only save when model improves

75

save_weights_only (bool): Only save model weights

76

mode (str): 'auto', 'min', or 'max'

77

save_freq (str or int): Frequency to save ('epoch' or integer steps)

78

options (SaveOptions, optional): Options for saving

79

initial_value_threshold (float, optional): Initial threshold for metric

80

"""

81

def __init__(self, filepath, monitor='val_loss', verbose=0, save_best_only=False,

82

save_weights_only=False, mode='auto', save_freq='epoch', **kwargs): ...

83

84

class BackupAndRestore:

85

"""

86

Backup and restore training state for fault tolerance.

87

88

Args:

89

backup_dir (str): Directory to store backup files

90

save_freq (str or int): Frequency to save backups

91

delete_checkpoint (bool): Whether to delete the backup checkpoint once training finishes

92

"""

93

def __init__(self, backup_dir, save_freq='epoch', delete_checkpoint=True, **kwargs): ...

94

```

95

96

### Logging and Monitoring Callbacks

97

98

Callbacks for logging training progress and monitoring metrics.

99

100

```python { .api }

101

class History:

102

"""

103

Record training history (automatically added to model.fit).

104

105

Attributes:

106

history (dict): Dictionary containing training metrics by epoch

107

"""

108

def __init__(self, **kwargs): ...

109

110

class CSVLogger:

111

"""

112

Log training progress to CSV file.

113

114

Args:

115

filename (str): Path to CSV file

116

separator (str): Field separator

117

append (bool): Whether to append to existing file

118

"""

119

def __init__(self, filename, separator=',', append=False, **kwargs): ...

120

121

class TensorBoard:

122

"""

123

Log training metrics for TensorBoard visualization.

124

125

Args:

126

log_dir (str): Directory to save TensorBoard log files

127

histogram_freq (int): Frequency to compute activation histograms

128

write_graph (bool): Whether to visualize computation graph

129

write_images (bool): Whether to write model weights as images

130

write_steps_per_second (bool): Whether to log training speed

131

update_freq (str or int): Frequency to write logs ('batch', 'epoch', or integer)

132

profile_batch (int or tuple): Batch(es) to profile for performance

133

embeddings_freq (int): Frequency to save embeddings

134

embeddings_metadata (dict, optional): Metadata for embeddings

135

"""

136

def __init__(self, log_dir='./logs', histogram_freq=0, write_graph=True,

137

write_images=False, write_steps_per_second=False, update_freq='epoch',

138

profile_batch=0, embeddings_freq=0, **kwargs): ...

139

140

class ProgbarLogger:

141

"""

142

Display training progress bar (automatically added to model.fit).

143

144

Args:

145

count_mode (str): 'steps' or 'samples'

146

stateful_metrics (set, optional): Metrics that shouldn't be averaged

147

"""

148

def __init__(self, count_mode='samples', stateful_metrics=None, **kwargs): ...

149

150

class RemoteMonitor:

151

"""

152

Send training events to remote monitoring server.

153

154

Args:

155

root (str): Root URL of monitoring server

156

path (str): Path to send events to

157

field (str): Field name for data

158

headers (dict, optional): HTTP headers

159

send_as_json (bool): Whether to send data as JSON

160

"""

161

def __init__(self, root='http://localhost:9000', path='/publish/epoch/end/',

162

field='data', headers=None, send_as_json=False, **kwargs): ...

163

```

164

165

### Utility Callbacks

166

167

General purpose and custom callbacks for specialized training scenarios.

168

169

```python { .api }

170

class LambdaCallback:

171

"""

172

Create custom callback using lambda functions.

173

174

Args:

175

on_epoch_begin (callable, optional): Function called at epoch start

176

on_epoch_end (callable, optional): Function called at epoch end

177

on_batch_begin (callable, optional): Function called at batch start

178

on_batch_end (callable, optional): Function called at batch end

179

on_train_begin (callable, optional): Function called at training start

180

on_train_end (callable, optional): Function called at training end

181

"""

182

def __init__(self, on_epoch_begin=None, on_epoch_end=None, on_batch_begin=None,

183

on_batch_end=None, on_train_begin=None, on_train_end=None, **kwargs): ...

184

185

class SwapEMAWeights:

186

"""

187

Swap Exponential Moving Average weights for evaluation.

188

189

Args:

190

swap_on_epoch (bool): Whether to swap weights at epoch end

191

"""

192

def __init__(self, swap_on_epoch=False, **kwargs): ...

193

```

194

195

### Model Persistence Functions

196

197

Functions for saving and loading complete models or weights only.

198

199

```python { .api }

200

def save_model(model, filepath, overwrite=True, save_format=None, **kwargs):

201

"""

202

Save complete model to file.

203

204

Args:

205

model: Keras model to save

206

filepath (str): Path to save model

207

overwrite (bool): Whether to overwrite existing file

208

save_format (str, optional): Format to save in ('tf', 'h5', or None for auto)

209

include_optimizer (bool): Whether to save optimizer state

210

save_traces (bool): Whether to save function traces

211

options (SaveOptions, optional): Platform-specific save options

212

signatures (callable or dict, optional): Model signatures for SavedModel

213

"""

214

215

def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):

216

"""

217

Load saved model from file.

218

219

Args:

220

filepath (str): Path to saved model

221

custom_objects (dict, optional): Custom objects for deserialization

222

compile (bool): Whether to compile loaded model

223

safe_mode (bool): Whether to load in safe mode

224

225

Returns:

226

Model: Loaded Keras model

227

"""

228

229

def save_weights(model, filepath, overwrite=True, save_format=None, options=None):

230

"""

231

Save model weights to file.

232

233

Args:

234

model: Keras model

235

filepath (str): Path to save weights

236

overwrite (bool): Whether to overwrite existing file

237

save_format (str, optional): Format to save in

238

options (SaveOptions, optional): Platform-specific save options

239

"""

240

241

def load_weights(model, filepath, skip_mismatch=False, by_name=False, options=None):

242

"""

243

Load model weights from file.

244

245

Args:

246

model: Keras model

247

filepath (str): Path to saved weights

248

skip_mismatch (bool): Whether to skip layers with mismatched shapes

249

by_name (bool): Whether to load weights by layer name

250

options (SaveOptions, optional): Platform-specific load options

251

"""

252

```

253

254

### Base Callback Class

255

256

Base class for creating custom callbacks.

257

258

```python { .api }

259

class Callback:

260

"""

261

Base class for callbacks.

262

263

Attributes:

264

params (dict): Training parameters

265

model (Model): Reference to training model

266

"""

267

def __init__(self, **kwargs): ...

268

269

def set_params(self, params): ...

270

def set_model(self, model): ...

271

272

def on_train_begin(self, logs=None): ...

273

def on_train_end(self, logs=None): ...

274

def on_epoch_begin(self, epoch, logs=None): ...

275

def on_epoch_end(self, epoch, logs=None): ...

276

def on_train_batch_begin(self, batch, logs=None): ...

277

def on_train_batch_end(self, batch, logs=None): ...

278

def on_test_batch_begin(self, batch, logs=None): ...

279

def on_test_batch_end(self, batch, logs=None): ...

280

def on_predict_batch_begin(self, batch, logs=None): ...

281

def on_predict_batch_end(self, batch, logs=None): ...

282

```

283

284

## Usage Examples

285

286

### Basic Training with Callbacks

287

288

```python

289

import keras

290

from keras import layers, callbacks

291

292

# Build model

293

model = keras.Sequential([

294

layers.Dense(64, activation='relu', input_shape=(784,)),

295

layers.Dropout(0.2),

296

layers.Dense(10, activation='softmax')

297

])

298

299

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

300

301

# Configure callbacks

302

callback_list = [

303

callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True),

304

callbacks.ModelCheckpoint('best_model.keras', save_best_only=True),

305

callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3),

306

callbacks.TensorBoard(log_dir='./logs')

307

]

308

309

# Train with callbacks

310

history = model.fit(

311

x_train, y_train,

312

epochs=100,

313

validation_data=(x_val, y_val),

314

callbacks=callback_list

315

)

316

```

317

318

### Custom Callback

319

320

```python

321

import keras

322

from keras import callbacks

323

import numpy as np

324

325

class ValidationMetrics(callbacks.Callback):
    """Report argmax accuracy on a held-out set at the end of every epoch.

    Args:
        validation_data: ``(inputs, labels)`` pair. Labels are compared
            against ``argmax`` of the predictions, so they are presumably
            integer class ids -- confirm against the caller.
        **kwargs: Forwarded to ``callbacks.Callback``.
    """

    def __init__(self, validation_data, **kwargs):
        super().__init__(**kwargs)
        self.validation_data = validation_data

    def on_epoch_end(self, epoch, logs=None):
        """Compute the accuracy, print it, and record it in ``logs``."""
        val_x, val_y = self.validation_data
        predictions = self.model.predict(val_x, verbose=0)

        # Calculate custom metrics.
        # Flatten the labels first: labels shaped (N, 1) would otherwise
        # broadcast against the (N,) argmax result into an (N, N) comparison
        # and yield a meaningless accuracy.
        predicted_classes = np.argmax(predictions, axis=1)
        expected_classes = np.asarray(val_y).reshape(-1)
        accuracy = float(np.mean(predicted_classes == expected_classes))
        print(f'Custom validation accuracy: {accuracy:.4f}')

        # Log custom metrics.
        # Write into the dict that was passed in. `logs or {}` would swap an
        # empty (falsy) dict for a throwaway one and silently drop the metric.
        if logs is not None:
            logs['custom_val_acc'] = accuracy

341

342

# Use custom callback

343

custom_callback = ValidationMetrics((x_val, y_val))

344

model.fit(x_train, y_train, epochs=10, callbacks=[custom_callback])

345

```

346

347

### Learning Rate Scheduling

348

349

```python

350

import keras

351

from keras import callbacks

352

import math

353

354

def step_decay(epoch, lr):
    """Step decay schedule: halve the learning rate once every 10 epochs.

    Args:
        epoch: Zero-based index of the epoch about to start.
        lr: Learning rate currently in effect.

    Returns:
        The learning rate to use for this epoch.
    """
    drop_rate = 0.5
    epochs_drop = 10
    # `lr` is the current rate, which already includes every earlier drop.
    # Scaling it by drop_rate ** (epoch // epochs_drop) on each call would
    # compound the decay every epoch, so apply one drop at each boundary only.
    if epoch > 0 and epoch % epochs_drop == 0:
        return lr * drop_rate
    return lr

359

360

def cosine_decay(epoch, lr, initial_lr=0.001, max_epochs=100):
    """Cosine annealing schedule.

    Anneals from ``initial_lr`` at epoch 0 down to 0 at ``max_epochs``
    along a half cosine. The result depends only on ``epoch``.

    Args:
        epoch: Zero-based epoch index.
        lr: Current learning rate. Unused; accepted so the function matches
            the ``schedule(epoch, lr)`` signature.
        initial_lr: Learning rate at epoch 0.
        max_epochs: Epoch at which the rate reaches 0.

    Returns:
        The learning rate for ``epoch``.
    """
    # Clamp so the rate stays at 0 after max_epochs instead of climbing back
    # up the second half of the cosine.
    clamped_epoch = min(epoch, max_epochs)
    return initial_lr * 0.5 * (1 + math.cos(math.pi * clamped_epoch / max_epochs))

364

365

# Use scheduling callback

366

lr_scheduler = callbacks.LearningRateScheduler(step_decay, verbose=1)

367

368

model.fit(

369

x_train, y_train,

370

epochs=50,

371

validation_data=(x_val, y_val),

372

callbacks=[lr_scheduler]

373

)

374

```

375

376

### Model Checkpointing Strategy

377

378

```python

379

import keras

380

from keras import callbacks

381

382

# Save best model based on validation loss

383

checkpoint_best = callbacks.ModelCheckpoint(

384

filepath='models/best_model_{epoch:02d}_{val_loss:.2f}.keras',

385

monitor='val_loss',

386

save_best_only=True,

387

save_weights_only=False,

388

verbose=1

389

)

390

391

# Save model every 5 batches (an integer save_freq counts batches, not epochs)

392

checkpoint_regular = callbacks.ModelCheckpoint(

393

filepath='models/model_epoch_{epoch:02d}.keras',

394

save_freq=5,

395

verbose=1

396

)

397

398

# Backup and restore for fault tolerance

399

backup_restore = callbacks.BackupAndRestore(backup_dir='./backup')

400

401

model.fit(

402

x_train, y_train,

403

epochs=100,

404

validation_data=(x_val, y_val),

405

callbacks=[checkpoint_best, checkpoint_regular, backup_restore]

406

)

407

```