
# Training and Model Organization

PyTorch Lightning components for organizing training code, managing experiments, and scaling across devices. This module provides the main training orchestrator, base classes for models and data, and the callback system for extending functionality.

## Capabilities

### Trainer

The central orchestrator that automates the training loop, handles device management, logging, checkpointing, and validation. Supports distributed training across multiple GPUs, TPUs, and nodes.

```python { .api }
class Trainer:
    def __init__(
        self,
        logger: Union[Logger, Iterable[Logger], bool] = True,
        enable_checkpointing: bool = True,
        callbacks: Optional[Union[List[Callback], Callback]] = None,
        default_root_dir: Optional[str] = None,
        gradient_clip_val: Optional[Union[int, float]] = None,
        gradient_clip_algorithm: Optional[str] = None,
        num_nodes: int = 1,
        devices: Optional[Union[List[int], str, int]] = None,
        enable_progress_bar: bool = True,
        overfit_batches: Union[int, float] = 0.0,
        track_grad_norm: Union[int, float, str] = -1,
        check_val_every_n_epoch: Optional[int] = 1,
        val_check_interval: Union[int, float] = 1.0,
        log_every_n_steps: int = 50,
        accelerator: Optional[str] = None,
        strategy: Optional[str] = None,
        sync_batchnorm: bool = False,
        precision: Optional[Union[int, str]] = None,
        enable_model_summary: bool = True,
        max_epochs: Optional[int] = None,
        min_epochs: Optional[int] = None,
        max_steps: int = -1,
        min_steps: Optional[int] = None,
        max_time: Optional[Union[str, timedelta]] = None,
        limit_train_batches: Optional[Union[int, float]] = None,
        limit_val_batches: Optional[Union[int, float]] = None,
        limit_test_batches: Optional[Union[int, float]] = None,
        limit_predict_batches: Optional[Union[int, float]] = None,
        fast_dev_run: Union[int, bool] = False,
        accumulate_grad_batches: int = 1,
        profiler: Optional[Union[str, Profiler]] = None,
        benchmark: Optional[bool] = None,
        deterministic: Optional[Union[bool, str]] = None,
        reload_dataloaders_every_n_epochs: int = 0,
        auto_lr_find: Union[bool, str] = False,
        replace_sampler_ddp: bool = True,
        detect_anomaly: bool = False,
        auto_scale_batch_size: Union[str, bool] = False,
        plugins: Optional[Union[str, list]] = None,
        move_metrics_to_cpu: bool = False,
        multiple_trainloader_mode: str = "max_size_cycle",
        inference_mode: bool = True,
        use_distributed_sampler: bool = True,
        barebones: bool = False,
        **kwargs
    ):
        """
        Lightning Trainer for automating the training process.

        Parameters:
        - logger: Logger instance or list of loggers, or True for default TensorBoard logger
        - enable_checkpointing: Enable automatic model checkpointing
        - callbacks: Callback instances to customize training behavior
        - default_root_dir: Default directory for logs and checkpoints
        - gradient_clip_val: Gradient clipping value (0 means no clipping)
        - gradient_clip_algorithm: Gradient clipping algorithm ('value' or 'norm')
        - num_nodes: Number of nodes for distributed training
        - devices: Device specification (int, list, or 'auto')
        - enable_progress_bar: Show progress bar during training
        - overfit_batches: Overfit on a subset of data for debugging
        - track_grad_norm: Track gradient norms (int for L-norm, -1 to disable)
        - check_val_every_n_epoch: Run validation every N epochs
        - val_check_interval: Validation frequency within an epoch
        - log_every_n_steps: Log metrics every N training steps
        - accelerator: Hardware accelerator ('cpu', 'gpu', 'tpu', 'auto')
        - strategy: Training strategy for distributed training
        - sync_batchnorm: Synchronize batch norm across devices
        - precision: Training precision ('16-mixed', '32', '64', 'bf16-mixed')
        - enable_model_summary: Print model summary at training start
        - max_epochs: Maximum number of epochs to train
        - min_epochs: Minimum number of epochs to train
        - max_steps: Maximum number of training steps
        - min_steps: Minimum number of training steps
        - max_time: Maximum training time
        - limit_train_batches: Limit training batches per epoch
        - limit_val_batches: Limit validation batches
        - limit_test_batches: Limit test batches
        - limit_predict_batches: Limit prediction batches
        - fast_dev_run: Quick development run with limited batches
        - accumulate_grad_batches: Accumulate gradients over N batches
        - profiler: Profiler for performance analysis
        - benchmark: Enable cuDNN benchmarking for consistent input sizes
        - deterministic: Enable deterministic training (may impact performance)
        - reload_dataloaders_every_n_epochs: Reload dataloaders periodically
        - auto_lr_find: Automatically find optimal learning rate
        - replace_sampler_ddp: Replace sampler with DistributedSampler for DDP
        - detect_anomaly: Enable anomaly detection for debugging
        - auto_scale_batch_size: Automatically scale batch size
        - plugins: Additional plugins for custom functionality
        - move_metrics_to_cpu: Move metrics to CPU to save GPU memory
        - multiple_trainloader_mode: Mode for handling multiple train dataloaders
        - inference_mode: Use inference mode during validation/test/predict
        - use_distributed_sampler: Use distributed sampler in DDP
        - barebones: Minimal trainer setup for maximum performance
        """

    def fit(
        self,
        model: LightningModule,
        train_dataloaders: Optional[DataLoader] = None,
        val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        datamodule: Optional[LightningDataModule] = None,
        ckpt_path: Optional[str] = None
    ):
        """
        Train the model.

        Parameters:
        - model: LightningModule to train
        - train_dataloaders: Training dataloader(s)
        - val_dataloaders: Validation dataloader(s)
        - datamodule: LightningDataModule containing dataloaders
        - ckpt_path: Path to checkpoint to resume training from
        """

    def validate(
        self,
        model: Optional[LightningModule] = None,
        dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        ckpt_path: Optional[str] = None,
        verbose: bool = True,
        datamodule: Optional[LightningDataModule] = None
    ):
        """
        Run validation loop.

        Parameters:
        - model: LightningModule to validate
        - dataloaders: Validation dataloader(s)
        - ckpt_path: Path to checkpoint to load
        - verbose: Print validation results
        - datamodule: LightningDataModule containing dataloaders

        Returns:
        List of validation results
        """

    def test(
        self,
        model: Optional[LightningModule] = None,
        dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        ckpt_path: Optional[str] = None,
        verbose: bool = True,
        datamodule: Optional[LightningDataModule] = None
    ):
        """
        Run test loop.

        Parameters:
        - model: LightningModule to test
        - dataloaders: Test dataloader(s)
        - ckpt_path: Path to checkpoint to load
        - verbose: Print test results
        - datamodule: LightningDataModule containing dataloaders

        Returns:
        List of test results
        """

    def predict(
        self,
        model: Optional[LightningModule] = None,
        dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        datamodule: Optional[LightningDataModule] = None,
        return_predictions: Optional[bool] = None,
        ckpt_path: Optional[str] = None
    ):
        """
        Run prediction loop.

        Parameters:
        - model: LightningModule for predictions
        - dataloaders: Prediction dataloader(s)
        - datamodule: LightningDataModule containing dataloaders
        - return_predictions: Return predictions in memory
        - ckpt_path: Path to checkpoint to load

        Returns:
        List of predictions if return_predictions=True
        """

    def tune(
        self,
        model: LightningModule,
        train_dataloaders: Optional[DataLoader] = None,
        val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        datamodule: Optional[LightningDataModule] = None,
        scale_batch_size_kwargs: Optional[dict] = None,
        lr_find_kwargs: Optional[dict] = None
    ):
        """
        Auto-tune model hyperparameters.

        Parameters:
        - model: LightningModule to tune
        - train_dataloaders: Training dataloader(s)
        - val_dataloaders: Validation dataloader(s)
        - datamodule: LightningDataModule containing dataloaders
        - scale_batch_size_kwargs: Arguments for batch size scaling
        - lr_find_kwargs: Arguments for learning rate finding

        Returns:
        Tuning results
        """
```

### LightningModule

Base class for organizing PyTorch model code with standardized hooks for training, validation, testing, and prediction. Handles optimizer configuration and provides extensive customization points.

```python { .api }
class LightningModule:
    def __init__(self):
        """Base class for organizing PyTorch model logic."""

    def forward(self, *args, **kwargs):
        """
        Define the forward pass of the model.

        Returns:
        Model predictions
        """

    def training_step(self, batch, batch_idx: int):
        """
        Define training step logic.

        Parameters:
        - batch: Training batch data
        - batch_idx: Index of the current batch

        Returns:
        Training loss (torch.Tensor) or dict with 'loss' key
        """

    def validation_step(self, batch, batch_idx: int):
        """
        Define validation step logic.

        Parameters:
        - batch: Validation batch data
        - batch_idx: Index of the current batch

        Returns:
        Validation outputs (optional)
        """

    def test_step(self, batch, batch_idx: int):
        """
        Define test step logic.

        Parameters:
        - batch: Test batch data
        - batch_idx: Index of the current batch

        Returns:
        Test outputs (optional)
        """

    def predict_step(self, batch, batch_idx: int, dataloader_idx: int = 0):
        """
        Define prediction step logic.

        Parameters:
        - batch: Prediction batch data
        - batch_idx: Index of the current batch
        - dataloader_idx: Index of the current dataloader

        Returns:
        Predictions
        """

    def configure_optimizers(self):
        """
        Configure optimizers and learning rate schedulers.

        Returns:
        Optimizer, list of optimizers, or dict with optimizer/scheduler config
        """

    def configure_callbacks(self):
        """
        Configure model-specific callbacks.

        Returns:
        List of callback instances
        """

    def log(self, name: str, value, prog_bar: bool = False, logger: bool = True,
            on_step: bool = None, on_epoch: bool = None, reduce_fx: str = "mean",
            enable_graph: bool = False, sync_dist: bool = False, sync_dist_group: str = None,
            add_dataloader_idx: bool = True, batch_size: int = None, metric_attribute: str = None,
            rank_zero_only: bool = False):
        """
        Log metrics during training.

        Parameters:
        - name: Metric name
        - value: Metric value
        - prog_bar: Show in progress bar
        - logger: Send to logger
        - on_step: Log at current step
        - on_epoch: Log at epoch end
        - reduce_fx: Reduction function for distributed training
        - enable_graph: Keep computation graph
        - sync_dist: Synchronize across distributed processes
        - sync_dist_group: Process group for synchronization
        - add_dataloader_idx: Add dataloader index to metric name
        - batch_size: Batch size for proper averaging
        - metric_attribute: Attribute name for storing metric
        - rank_zero_only: Log only on rank 0
        """
```

### LightningDataModule

Base class for organizing data loading logic, providing a clean interface for data preparation, dataset setup, and dataloader creation across different stages of training.

```python { .api }
class LightningDataModule:
    def __init__(self, *args, **kwargs):
        """Base class for organizing data loading logic."""

    def setup(self, stage: str = None):
        """
        Setup datasets for different stages.

        Parameters:
        - stage: Current stage ('fit', 'validate', 'test', 'predict')
        """

    def prepare_data(self):
        """
        Download and prepare data (called once per node).
        Use this for data downloading, tokenization, etc.
        """

    def train_dataloader(self):
        """
        Create training dataloader.

        Returns:
        DataLoader for training
        """

    def val_dataloader(self):
        """
        Create validation dataloader.

        Returns:
        DataLoader or list of DataLoaders for validation
        """

    def test_dataloader(self):
        """
        Create test dataloader.

        Returns:
        DataLoader or list of DataLoaders for testing
        """

    def predict_dataloader(self):
        """
        Create prediction dataloader.

        Returns:
        DataLoader or list of DataLoaders for prediction
        """

    def teardown(self, stage: str = None):
        """
        Clean up after training/testing.

        Parameters:
        - stage: Current stage ('fit', 'validate', 'test', 'predict')
        """
```

### Callback System

Base class for creating custom training callbacks that can hook into different stages of the training process to extend functionality.

```python { .api }
class Callback:
    def __init__(self):
        """Base class for creating training callbacks."""

    def on_train_start(self, trainer: Trainer, pl_module: LightningModule):
        """Called when training begins."""

    def on_train_end(self, trainer: Trainer, pl_module: LightningModule):
        """Called when training ends."""

    def on_epoch_start(self, trainer: Trainer, pl_module: LightningModule):
        """Called at the beginning of each epoch."""

    def on_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
        """Called at the end of each epoch."""

    def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule):
        """Called at the beginning of each training epoch."""

    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
        """Called at the end of each training epoch."""

    def on_validation_epoch_start(self, trainer: Trainer, pl_module: LightningModule):
        """Called at the beginning of each validation epoch."""

    def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
        """Called at the end of each validation epoch."""

    def on_test_epoch_start(self, trainer: Trainer, pl_module: LightningModule):
        """Called at the beginning of each test epoch."""

    def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
        """Called at the end of each test epoch."""

    def on_train_batch_start(self, trainer: Trainer, pl_module: LightningModule, batch, batch_idx: int):
        """Called before each training batch."""

    def on_train_batch_end(self, trainer: Trainer, pl_module: LightningModule, outputs, batch, batch_idx: int):
        """Called after each training batch."""

    def on_validation_batch_start(self, trainer: Trainer, pl_module: LightningModule, batch, batch_idx: int):
        """Called before each validation batch."""

    def on_validation_batch_end(self, trainer: Trainer, pl_module: LightningModule, outputs, batch, batch_idx: int):
        """Called after each validation batch."""

    def on_test_batch_start(self, trainer: Trainer, pl_module: LightningModule, batch, batch_idx: int):
        """Called before each test batch."""

    def on_test_batch_end(self, trainer: Trainer, pl_module: LightningModule, outputs, batch, batch_idx: int):
        """Called after each test batch."""
```

## Usage Examples

### Basic Training Setup

```python
import lightning as L
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

class MyDataset(Dataset):
    def __init__(self, size=1000):
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        return torch.randn(10), torch.randn(1)

class MyModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(10, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.functional.mse_loss(y_hat, y)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)

# Training
model = MyModel()
trainer = L.Trainer(max_epochs=3)
train_loader = DataLoader(MyDataset(), batch_size=32)
trainer.fit(model, train_loader)
```

### Using DataModule

```python
class MyDataModule(L.LightningDataModule):
    def __init__(self, batch_size=32):
        super().__init__()
        self.batch_size = batch_size

    def setup(self, stage=None):
        if stage == 'fit':
            self.train_dataset = MyDataset(size=800)
            self.val_dataset = MyDataset(size=200)
        elif stage == 'test':
            self.test_dataset = MyDataset(size=100)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size)

# Training with DataModule
model = MyModel()
datamodule = MyDataModule(batch_size=64)
trainer = L.Trainer(max_epochs=3)
trainer.fit(model, datamodule=datamodule)
trainer.test(model, datamodule=datamodule)
```