# Callbacks and Lifecycle Hooks

Comprehensive callback system for training lifecycle management, including checkpointing, early stopping, learning rate monitoring, weight averaging, progress display, and hyperparameter-tuning callbacks. Callbacks add functionality without modifying the core training loop.

## Capabilities

### Model Checkpointing

Automatically save model checkpoints during training based on monitored metrics, with support for saving top-k models and automatic cleanup.

```python { .api }
class ModelCheckpoint(Callback):
    def __init__(
        self,
        dirpath: Optional[str] = None,
        filename: Optional[str] = None,
        monitor: Optional[str] = None,
        verbose: bool = False,
        save_last: Optional[bool] = None,
        save_top_k: int = 1,
        save_weights_only: bool = False,
        mode: str = "min",
        auto_insert_metric_name: bool = True,
        every_n_train_steps: Optional[int] = None,
        train_time_interval: Optional[timedelta] = None,
        every_n_epochs: Optional[int] = None,
        save_on_train_epoch_end: Optional[bool] = None,
        enable_version_counter: bool = True
    ):
        """
        Initialize ModelCheckpoint callback.

        Args:
            dirpath: Directory to save checkpoints
            filename: Checkpoint filename pattern
            monitor: Metric to monitor for saving best models
            verbose: Print checkpoint saving messages
            save_last: Always save the last checkpoint
            save_top_k: Number of best models to save
            save_weights_only: Save only model weights
            mode: 'min' or 'max' for monitored metric
            auto_insert_metric_name: Insert metric name in filename
            every_n_train_steps: Save every N training steps
            train_time_interval: Save every time interval
            every_n_epochs: Save every N epochs
            save_on_train_epoch_end: Save at end of training epoch
            enable_version_counter: Enable version counter in filename
        """

    @property
    def best_model_path(self) -> str:
        """Path to the best saved model."""

    @property
    def best_model_score(self) -> Optional[float]:
        """Score of the best saved model."""

    @property
    def last_model_path(self) -> str:
        """Path to the last saved model."""
```
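
The interaction between `save_top_k` and `mode` can be pictured with a small standalone sketch. It does not use Lightning and is not the library's implementation; `keep_top_k` and the checkpoint names are hypothetical, chosen only to show which checkpoints would survive pruning.

```python
import heapq


def keep_top_k(scores, k, mode="min"):
    """Return the checkpoint names that survive top-k pruning.

    `scores` maps a (hypothetical) checkpoint name to its monitored value.
    """
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")
    # 'min' keeps the k smallest values, 'max' keeps the k largest.
    pick = heapq.nsmallest if mode == "min" else heapq.nlargest
    best = pick(k, scores.items(), key=lambda item: item[1])
    return [name for name, _ in best]


val_loss = {"epoch=0": 0.91, "epoch=1": 0.64, "epoch=2": 0.70, "epoch=3": 0.58}
kept = keep_top_k(val_loss, k=2, mode="min")
print(kept)  # ['epoch=3', 'epoch=1']
```

With `mode="max"` the same call would keep the highest-scoring entries instead, which is the setting to use for metrics such as accuracy.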

### Early Stopping

Stop training when a monitored metric stops improving, with configurable patience and thresholds to prevent overfitting.

```python { .api }
class EarlyStopping(Callback):
    def __init__(
        self,
        monitor: str,
        min_delta: float = 0.0,
        patience: int = 3,
        verbose: bool = False,
        mode: str = "min",
        strict: bool = True,
        check_finite: bool = True,
        stopping_threshold: Optional[float] = None,
        divergence_threshold: Optional[float] = None,
        check_on_train_epoch_end: Optional[bool] = None,
        log_rank_zero_only: bool = False
    ):
        """
        Initialize EarlyStopping callback.

        Args:
            monitor: Metric to monitor
            min_delta: Minimum change to qualify as improvement
            patience: Number of checks with no improvement to wait before stopping
            verbose: Print early stopping messages
            mode: 'min' or 'max' for monitored metric
            strict: Raise error if monitored metric is not found
            check_finite: Stop if monitored metric is not finite
            stopping_threshold: Stop when metric reaches this threshold
            divergence_threshold: Stop if metric diverges beyond this
            check_on_train_epoch_end: Check metric at end of training epoch
            log_rank_zero_only: Log only on rank 0
        """

    @property
    def wait_count(self) -> int:
        """Number of checks since the last improvement."""

    @property
    def best_score(self) -> Optional[float]:
        """Best score achieved."""

    @property
    def stopped_epoch(self) -> int:
        """Epoch when training was stopped."""
```
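
How `patience`, `min_delta`, and `mode` combine is easier to see in isolation. The following is a minimal sketch of the patience rule only, written in plain Python; `stop_index` is a hypothetical helper, not part of Lightning, and it ignores the threshold and finiteness checks.

```python
def stop_index(values, patience=3, min_delta=0.0, mode="min"):
    """Return the index of the check that triggers stopping, or None."""
    best = None
    wait = 0
    for i, value in enumerate(values):
        if best is None:
            improved = True
        elif mode == "min":
            # An improvement must beat the best value by more than min_delta.
            improved = value < best - min_delta
        else:
            improved = value > best + min_delta
        if improved:
            best, wait = value, 0
        else:
            wait += 1
            if wait >= patience:
                return i
    return None


losses = [0.9, 0.7, 0.71, 0.69, 0.72, 0.70, 0.68]
print(stop_index(losses, patience=3, min_delta=0.05))  # 4
print(stop_index(losses, patience=3))                  # None
```

With `min_delta=0.05` the small gains after the second check do not count as improvements, so the counter reaches `patience` at index 4. With the default `min_delta=0.0` the same series keeps resetting the counter and never stops.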

### Learning Rate Monitoring

Monitor and log learning rate changes during training, supporting multiple optimizers and schedulers.

```python { .api }
class LearningRateMonitor(Callback):
    def __init__(
        self,
        logging_interval: Optional[str] = None,
        log_momentum: bool = False,
        log_weight_decay: bool = False
    ):
        """
        Initialize LearningRateMonitor callback.

        Args:
            logging_interval: 'step' or 'epoch' for logging frequency;
                None follows the interval of each scheduler
            log_momentum: Also log momentum values
            log_weight_decay: Also log weight decay values
        """
```

### Stochastic Weight Averaging

Implement stochastic weight averaging to improve model generalization by averaging weights from multiple epochs.

```python { .api }
class StochasticWeightAveraging(Callback):
    def __init__(
        self,
        swa_lrs: Union[float, List[float]],
        swa_epoch_start: Union[int, float] = 0.8,
        annealing_epochs: int = 10,
        annealing_strategy: str = "cos",
        avg_fn: Optional[Callable] = None,
        device: Optional[Union[torch.device, str]] = torch.device("cpu")
    ):
        """
        Initialize StochasticWeightAveraging callback.

        Args:
            swa_lrs: Learning rate(s) for SWA
            swa_epoch_start: Epoch to start SWA (int or fraction of max epochs)
            annealing_epochs: Number of epochs for annealing
            annealing_strategy: 'linear' or 'cos' annealing
            avg_fn: Custom averaging function
            device: Device for the averaged model; None infers it from the module
        """
```
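
The averaging step itself is simple arithmetic. The sketch below assumes the default behaves like an equal-weight running mean over weight snapshots; `update_average` is a hypothetical stand-in for a custom `avg_fn`, and plain lists of floats stand in for parameter tensors.

```python
def update_average(avg, new, num_averaged):
    """Fold `new` into an equal-weight mean of `num_averaged` earlier snapshots."""
    return [a + (w - a) / (num_averaged + 1) for a, w in zip(avg, new)]


# Three snapshots of a two-parameter "model" (illustrative values).
snapshots = [[1.0, 4.0], [3.0, 0.0], [5.0, 2.0]]

avg = snapshots[0]
for n, weights in enumerate(snapshots[1:], start=1):
    avg = update_average(avg, weights, n)

print(avg)  # [3.0, 2.0]
```

The incremental form gives the same result as averaging all snapshots at once, but only ever holds one extra copy of the weights in memory.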

### Progress Bars

Visual progress indicators during training with customizable display options and rich formatting support.

```python { .api }
class TQDMProgressBar(Callback):
    def __init__(
        self,
        refresh_rate: int = 1,
        process_position: int = 0
    ):
        """
        Initialize TQDM progress bar.

        Args:
            refresh_rate: Progress bar refresh rate
            process_position: Position for multiple progress bars
        """

class RichProgressBar(Callback):
    def __init__(
        self,
        refresh_rate: int = 1,
        leave: bool = False,
        theme: RichProgressBarTheme = RichProgressBarTheme(),
        console_kwargs: Optional[Dict[str, Any]] = None
    ):
        """
        Initialize Rich progress bar with enhanced formatting.

        Args:
            refresh_rate: Progress bar refresh rate
            leave: Keep progress bar after completion
            theme: Rich theme configuration
            console_kwargs: Additional console arguments
        """

class ProgressBar(Callback):
    def __init__(self):
        """Base progress bar callback."""

    def disable(self) -> None:
        """Disable the progress bar."""

    def enable(self) -> None:
        """Enable the progress bar."""
```

### Model Summary Display

Display detailed model architecture information including layer types, parameters, and memory usage.

```python { .api }
class ModelSummary(Callback):
    def __init__(self, max_depth: int = 1):
        """
        Initialize ModelSummary callback.

        Args:
            max_depth: Maximum depth for nested modules
        """

class RichModelSummary(Callback):
    def __init__(self, max_depth: int = 1):
        """
        Initialize RichModelSummary with enhanced formatting.

        Args:
            max_depth: Maximum depth for nested modules
        """
```

### Hyperparameter Optimization

Callbacks for automated hyperparameter tuning including batch size finding and learning rate finding.

```python { .api }
class BatchSizeFinder(Callback):
    def __init__(
        self,
        mode: str = "power",
        steps_per_trial: int = 3,
        init_val: int = 2,
        max_trials: int = 25,
        batch_arg_name: str = "batch_size"
    ):
        """
        Initialize BatchSizeFinder callback.

        Args:
            mode: 'power' or 'binsearch' for search strategy
            steps_per_trial: Steps per batch size trial
            init_val: Initial batch size
            max_trials: Maximum number of trials
            batch_arg_name: Attribute name that stores the batch size
        """

class LearningRateFinder(Callback):
    def __init__(
        self,
        min_lr: float = 1e-8,
        max_lr: float = 1.0,
        num_training_steps: int = 100,
        mode: str = "exponential",
        early_stop_threshold: float = 4.0,
        update_attr: bool = True
    ):
        """
        Initialize LearningRateFinder callback.

        Args:
            min_lr: Minimum learning rate
            max_lr: Maximum learning rate
            num_training_steps: Number of training steps used for the search
            mode: 'exponential' or 'linear' search
            early_stop_threshold: Stop the search when the loss exceeds
                this multiple of the best loss
            update_attr: Update model's learning rate attribute
        """
```
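
The difference between the `'power'` and `'binsearch'` strategies can be sketched without a GPU. This is an illustration under stated assumptions, not Lightning's implementation: `find_batch_size` is hypothetical, and `fits` stands in for "a few training steps ran at this batch size without running out of memory".

```python
def find_batch_size(fits, mode="power", init_val=2, max_trials=25):
    """Return the largest batch size found for which `fits(size)` is True."""
    size, best, trials = init_val, None, 0

    # Phase 1 (both modes): keep doubling until a trial fails.
    while trials < max_trials:
        trials += 1
        if not fits(size):
            break
        best = size
        size *= 2
    else:
        return best  # trial budget exhausted without a failure

    if mode == "power" or best is None:
        return best

    # Phase 2 ('binsearch' only): bisect between last success and first failure.
    low, high = best, size
    while high - low > 1 and trials < max_trials:
        trials += 1
        mid = (low + high) // 2
        if fits(mid):
            low = mid
        else:
            high = mid
    return low


def fits(size):
    # Pretend anything above 100 samples per batch runs out of memory.
    return size <= 100


print(find_batch_size(fits, mode="power"))      # 64
print(find_batch_size(fits, mode="binsearch"))  # 100
```

The `'power'` strategy is cheaper but can leave up to half of the usable capacity on the table; `'binsearch'` spends extra trials to close that gap.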

### Fine-tuning Callbacks

Specialized callbacks for transfer learning and progressive fine-tuning strategies.

```python { .api }
class BaseFinetuning(Callback):
    def __init__(self):
        """
        Base class for fine-tuning callbacks.

        Subclasses override freeze_before_training and finetune_function.
        """

    def freeze_before_training(self, pl_module: LightningModule) -> None:
        """Freeze parameters before training starts."""

    def finetune_function(
        self,
        pl_module: LightningModule,
        current_epoch: int,
        optimizer: Optimizer
    ) -> None:
        """Called at the start of each training epoch to adjust what is trained."""

class BackboneFinetuning(BaseFinetuning):
    def __init__(
        self,
        unfreeze_backbone_at_epoch: int = 10,
        lambda_func: Optional[Callable] = None,
        backbone_initial_ratio_lr: float = 0.1,
        backbone_initial_lr: Optional[float] = None,
        should_align: bool = True,
        initial_denom_lr: float = 10.0,
        train_bn: bool = True
    ):
        """
        Fine-tuning callback for backbone networks.

        Args:
            unfreeze_backbone_at_epoch: Epoch to unfreeze backbone
            lambda_func: Learning rate scheduling function
            backbone_initial_ratio_lr: Initial backbone LR ratio
            backbone_initial_lr: Initial backbone learning rate
            should_align: Align learning rates
            initial_denom_lr: Initial denominator for LR calculation
            train_bn: Train batch normalization layers
        """
```

### Performance Monitoring

Callbacks for monitoring training performance, throughput, and resource utilization.

```python { .api }
class ThroughputMonitor(Callback):
    def __init__(
        self,
        batch_size_fn: Callable[[Any], int],
        length_fn: Optional[Callable[[Any], int]] = None,
        **kwargs: Any
    ):
        """
        Initialize ThroughputMonitor callback.

        Args:
            batch_size_fn: Function returning the batch size of a batch
            length_fn: Function returning the sequence length of a batch
            **kwargs: Forwarded to the throughput tracker (e.g. window_size)
        """

class DeviceStatsMonitor(Callback):
    def __init__(self, cpu_stats: Optional[bool] = None):
        """
        Initialize DeviceStatsMonitor callback.

        Args:
            cpu_stats: Monitor CPU statistics
        """

class Timer(Callback):
    def __init__(self, duration: Optional[Union[str, timedelta]] = None, interval: str = "step"):
        """
        Initialize Timer callback for training duration control.

        Args:
            duration: Maximum training duration
            interval: 'step' or 'epoch' for timing
        """
```
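
A windowed throughput figure is computed from the oldest and newest records in a fixed-size buffer. The sketch below shows that idea in plain Python; `WindowedThroughput` and its method names are hypothetical and only approximate what a throughput monitor does internally.

```python
from collections import deque


class WindowedThroughput:
    """Samples per second over the most recent `window_size` updates."""

    def __init__(self, window_size=100):
        # Each record is (elapsed_seconds, cumulative_samples); old records fall off.
        self._records = deque(maxlen=window_size)

    def update(self, elapsed, samples):
        self._records.append((elapsed, samples))

    def samples_per_sec(self):
        if len(self._records) < 2:
            return None  # not enough data for a rate yet
        (t0, s0), (t1, s1) = self._records[0], self._records[-1]
        if t1 <= t0:
            return None
        return (s1 - s0) / (t1 - t0)


meter = WindowedThroughput(window_size=3)
for elapsed, samples in [(1.0, 32), (2.0, 64), (3.0, 96), (4.0, 160)]:
    meter.update(elapsed, samples)

print(meter.samples_per_sec())  # 48.0
```

A small window reacts quickly to changes such as a slow dataloader; a large window gives a smoother number that is easier to compare across runs.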

### Custom Callback Creation

```python { .api }
class LambdaCallback(Callback):
    def __init__(self, **kwargs):
        """
        Create callback from lambda functions.

        Args:
            **kwargs: Mapping of hook names to functions
        """
```

## Usage Examples

### Basic Callback Usage

```python
from lightning import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping

# Configure callbacks
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath='./checkpoints',
    filename='model-{epoch:02d}-{val_loss:.2f}',
    save_top_k=3,
    mode='min'
)

early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=5,
    mode='min'
)

# Use callbacks in trainer
trainer = Trainer(
    callbacks=[checkpoint_callback, early_stopping],
    max_epochs=100
)
```
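
The `filename` pattern in the example above is a Python format string filled in from the logged metrics. The sketch below approximates how such a pattern might expand when `auto_insert_metric_name` is enabled; `expand_filename` is a hypothetical helper written for illustration, and the real callback also appends a file extension and may add a version suffix.

```python
import re


def expand_filename(template, metrics, auto_insert_metric_name=True):
    """Fill a checkpoint filename pattern from a dict of metric values."""
    if auto_insert_metric_name:
        # Turn "{val_loss:.2f}" into "val_loss={val_loss:.2f}".
        template = re.sub(r"\{(\w+)", r"\1={\1", template)
    return template.format(**metrics)


metrics = {"epoch": 4, "val_loss": 0.3219}
name = expand_filename("model-{epoch:02d}-{val_loss:.2f}", metrics)
print(name)  # model-epoch=04-val_loss=0.32

plain = expand_filename("model-{epoch:02d}-{val_loss:.2f}", metrics,
                        auto_insert_metric_name=False)
print(plain)  # model-04-0.32
```

Keeping the metric names in the filename makes a directory listing self-describing, at the cost of longer paths.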

### Custom Callback Example

```python
import lightning as L

class MetricLoggingCallback(L.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # Log custom metrics at end of each epoch
        metrics = trainer.callback_metrics
        epoch = trainer.current_epoch

        # Custom logging logic
        if 'train_loss' in metrics:
            print(f"Epoch {epoch}: Train Loss = {metrics['train_loss']:.4f}")

        # Save metrics to file (appends one line per epoch)
        with open('metrics.log', 'a') as f:
            f.write(f"Epoch {epoch}: {dict(metrics)}\n")

# Use custom callback
trainer = L.Trainer(callbacks=[MetricLoggingCallback()])
```