or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

accelerators.md callbacks.md core-training.md data.md fabric.md index.md loggers.md precision.md profilers.md strategies.md

docs/core-training.md

0

# Core Training Components

1

2

Essential components for structuring deep learning training workflows in Lightning. These components provide the foundation for organized, scalable, and reproducible machine learning training.

3

4

## Capabilities

5

6

### Trainer

7

8

The main entry point for Lightning training that orchestrates the entire training process, handling distributed training, logging, checkpointing, and validation automatically.

9

10

```python { .api }

11

class Trainer:

12

def __init__(

13

self,

14

accelerator: str = "auto",

15

strategy: str = "auto",

16

devices: Union[List[int], str, int] = "auto",

17

num_nodes: int = 1,

18

precision: Union[str, int] = "32-true",

19

logger: Union[Logger, bool] = True,

20

callbacks: Optional[List[Callback]] = None,

21

fast_dev_run: Union[bool, int] = False,

22

max_epochs: Optional[int] = None,

23

min_epochs: Optional[int] = None,

24

max_steps: int = -1,

25

min_steps: Optional[int] = None,

26

max_time: Optional[Union[str, timedelta]] = None,

27

limit_train_batches: Union[int, float] = 1.0,

28

limit_val_batches: Union[int, float] = 1.0,

29

limit_test_batches: Union[int, float] = 1.0,

30

limit_predict_batches: Union[int, float] = 1.0,

31

overfit_batches: Union[int, float] = 0.0,

32

val_check_interval: Union[int, float] = 1.0,

33

check_val_every_n_epoch: Optional[int] = 1,

34

num_sanity_val_steps: int = 2,

35

log_every_n_steps: int = 50,

36

enable_checkpointing: bool = True,

37

enable_progress_bar: bool = True,

38

enable_model_summary: bool = True,

39

accumulate_grad_batches: int = 1,

40

gradient_clip_val: Optional[float] = None,

41

gradient_clip_algorithm: Optional[str] = None,

42

deterministic: Optional[bool] = None,

43

benchmark: Optional[bool] = None,

44

inference_mode: bool = True,

45

use_distributed_sampler: bool = True,

46

profiler: Optional[Profiler] = None,

47

detect_anomaly: bool = False,

48

barebones: bool = False,

49

plugins: Optional[List[Any]] = None,

50

sync_batchnorm: bool = False,

51

reload_dataloaders_every_n_epochs: int = 0,

52

default_root_dir: Optional[str] = None,

53

**kwargs

54

):

55

"""

56

Initialize the Lightning Trainer.

57

58

Args:

59

accelerator: Hardware accelerator type ('cpu', 'gpu', 'tpu', 'auto')

60

strategy: Distributed training strategy ('ddp', 'fsdp', 'deepspeed', etc.)

61

devices: Which devices to use for training

62

num_nodes: Number of nodes for distributed training

63

precision: Precision mode ('32-true', '16-mixed', 'bf16-mixed', etc.)

64

logger: Logger instance or True/False to enable/disable default logger

65

callbacks: List of callbacks to use during training

66

fast_dev_run: Run a quick debugging pass over a few batches (1 batch if True, n batches if set to an int n)

67

max_epochs: Maximum number of epochs to train

68

min_epochs: Minimum number of epochs to train

69

max_steps: Maximum number of training steps

70

min_steps: Minimum number of training steps

71

max_time: Maximum training time, given as a "DD:HH:MM:SS" string or a datetime.timedelta (None means no time limit)

72

limit_train_batches: Limit training batches per epoch

73

limit_val_batches: Limit validation batches per epoch

74

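val_check_interval: How often to run validation during training (a float is a fraction of the training epoch, an int is a number of training batches)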

75

enable_checkpointing: Enable automatic checkpointing

76

enable_progress_bar: Show progress bar during training

77

accumulate_grad_batches: Gradient accumulation steps

78

gradient_clip_val: Gradient clipping value

79

deterministic: Make training deterministic

80

profiler: Profiler for performance analysis

81

"""

82

83

def fit(

84

self,

85

model: LightningModule,

86

train_dataloaders: Optional[TRAIN_DATALOADERS] = None,

87

val_dataloaders: Optional[EVAL_DATALOADERS] = None,

88

datamodule: Optional[LightningDataModule] = None,

89

ckpt_path: Optional[str] = None

90

) -> None:

91

"""

92

Fit the model with training and validation data.

93

94

Args:

95

model: LightningModule to train

96

train_dataloaders: Training data loaders

97

val_dataloaders: Validation data loaders

98

datamodule: LightningDataModule containing data loaders

99

ckpt_path: Path to checkpoint to resume from

100

"""

101

102

def validate(

103

self,

104

model: Optional[LightningModule] = None,

105

dataloaders: Optional[EVAL_DATALOADERS] = None,

106

ckpt_path: Optional[str] = None,

107

verbose: bool = True,

108

datamodule: Optional[LightningDataModule] = None

109

) -> List[Dict[str, float]]:

110

"""

111

Run validation on the model.

112

113

Args:

114

model: LightningModule to validate

115

dataloaders: Validation data loaders

116

ckpt_path: Path to checkpoint to load

117

verbose: Print validation results

118

datamodule: LightningDataModule containing data loaders

119

120

Returns:

121

List of validation metrics dictionaries

122

"""

123

124

def test(

125

self,

126

model: Optional[LightningModule] = None,

127

dataloaders: Optional[EVAL_DATALOADERS] = None,

128

ckpt_path: Optional[str] = None,

129

verbose: bool = True,

130

datamodule: Optional[LightningDataModule] = None

131

) -> List[Dict[str, float]]:

132

"""

133

Run testing on the model.

134

135

Args:

136

model: LightningModule to test

137

dataloaders: Test data loaders

138

ckpt_path: Path to checkpoint to load

139

verbose: Print test results

140

datamodule: LightningDataModule containing data loaders

141

142

Returns:

143

List of test metrics dictionaries

144

"""

145

146

def predict(

147

self,

148

model: Optional[LightningModule] = None,

149

dataloaders: Optional[EVAL_DATALOADERS] = None,

150

datamodule: Optional[LightningDataModule] = None,

151

return_predictions: Optional[bool] = None,

152

ckpt_path: Optional[str] = None

153

) -> Optional[List[Any]]:

154

"""

155

Run prediction on the model.

156

157

Args:

158

model: LightningModule to use for prediction

159

dataloaders: Prediction data loaders

160

datamodule: LightningDataModule containing data loaders

161

return_predictions: Whether to return predictions

162

ckpt_path: Path to checkpoint to load

163

164

Returns:

165

List of predictions if return_predictions=True

166

"""

167

168

def tune(

169

self,

170

model: LightningModule,

171

train_dataloaders: Optional[TRAIN_DATALOADERS] = None,

172

val_dataloaders: Optional[EVAL_DATALOADERS] = None,

173

datamodule: Optional[LightningDataModule] = None,

174

scale_batch_size_kwargs: Optional[Dict[str, Any]] = None,

175

lr_find_kwargs: Optional[Dict[str, Any]] = None

176

) -> Dict[str, Any]:

177

"""

178

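Tune hyperparameters for the model. Note: Trainer.tune is a Lightning 1.x API; in Lightning 2.x use lightning.pytorch.tuner.Tuner (scale_batch_size / lr_find) instead.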

179

180

Args:

181

model: LightningModule to tune

182

train_dataloaders: Training data loaders

183

val_dataloaders: Validation data loaders

184

datamodule: LightningDataModule containing data loaders

185

scale_batch_size_kwargs: Arguments for batch size scaling

186

lr_find_kwargs: Arguments for learning rate finding

187

188

Returns:

189

Dictionary with tuning results

190

"""

191

```

192

193

### LightningModule

194

195

Base class for organizing PyTorch code in Lightning. Defines model architecture, training logic, optimization, and provides hooks for the training lifecycle.

196

197

```python { .api }

198

class LightningModule(nn.Module):

199

def __init__(self):

200

"""Initialize the LightningModule."""

201

super().__init__()

202

203

def forward(self, *args, **kwargs) -> Any:

204

"""

205

Define the forward pass of the model.

206

207

Args:

208

*args: Positional arguments

209

**kwargs: Keyword arguments

210

211

Returns:

212

Model output

213

"""

214

215

def training_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT:

216

"""

217

Define a single training step.

218

219

Args:

220

batch: Batch of training data

221

batch_idx: Index of the current batch

222

223

Returns:

224

Loss tensor or dictionary with 'loss' key

225

"""

226

227

def validation_step(self, batch: Any, batch_idx: int) -> Optional[STEP_OUTPUT]:

228

"""

229

Define a single validation step.

230

231

Args:

232

batch: Batch of validation data

233

batch_idx: Index of the current batch

234

235

Returns:

236

Optional loss tensor or metrics dictionary

237

"""

238

239

def test_step(self, batch: Any, batch_idx: int) -> Optional[STEP_OUTPUT]:

240

"""

241

Define a single test step.

242

243

Args:

244

batch: Batch of test data

245

batch_idx: Index of the current batch

246

247

Returns:

248

Optional loss tensor or metrics dictionary

249

"""

250

251

def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:

252

"""

253

Define a single prediction step.

254

255

Args:

256

batch: Batch of prediction data

257

batch_idx: Index of the current batch

258

dataloader_idx: Index of the dataloader

259

260

Returns:

261

Model predictions

262

"""

263

264

def configure_optimizers(self) -> Union[Optimizer, Dict[str, Any]]:

265

"""

266

Configure optimizers and learning rate schedulers.

267

268

Returns:

269

Optimizer or dictionary with optimizer/scheduler configuration

270

"""

271

272

def configure_callbacks(self) -> Union[List[Callback], Callback]:

273

"""

274

Configure callbacks for this model.

275

276

Returns:

277

List of callbacks or single callback

278

"""

279

280

def log(

281

self,

282

name: str,

283

value: Any,

284

prog_bar: bool = False,

285

logger: bool = True,

286

on_step: Optional[bool] = None,

287

on_epoch: Optional[bool] = None,

288

reduce_fx: str = "mean",

289

enable_graph: bool = False,

290

sync_dist: bool = False,

291

sync_dist_group: Optional[Any] = None,

292

add_dataloader_idx: bool = True,

293

batch_size: Optional[int] = None,

294

metric_attribute: Optional[str] = None,

295

rank_zero_only: bool = False

296

) -> None:

297

"""

298

Log a key-value pair.

299

300

Args:

301

name: Name of the metric

302

value: Value to log

303

prog_bar: Show in progress bar

304

logger: Send to logger

305

on_step: Log at each step

306

on_epoch: Log at each epoch

307

reduce_fx: Reduction function applied to the logged step values at the end of the epoch ("mean" by default)

308

sync_dist: Synchronize across distributed processes

309

batch_size: Current batch size for proper reduction

310

"""

311

312

def log_dict(

313

self,

314

dictionary: Dict[str, Any],

315

prog_bar: bool = False,

316

logger: bool = True,

317

on_step: Optional[bool] = None,

318

on_epoch: Optional[bool] = None,

319

reduce_fx: str = "mean",

320

enable_graph: bool = False,

321

sync_dist: bool = False,

322

sync_dist_group: Optional[Any] = None,

323

add_dataloader_idx: bool = True,

324

batch_size: Optional[int] = None,

325

rank_zero_only: bool = False

326

) -> None:

327

"""

328

Log a dictionary of key-value pairs.

329

330

Args:

331

dictionary: Dictionary of metrics to log

332

prog_bar: Show in progress bar

333

logger: Send to logger

334

on_step: Log at each step

335

on_epoch: Log at each epoch

336

reduce_fx: Reduction function applied to the logged step values at the end of the epoch ("mean" by default)

337

sync_dist: Synchronize across distributed processes

338

batch_size: Current batch size for proper reduction

339

"""

340

```

341

342

### LightningDataModule

343

344

Encapsulates data loading logic including data downloading, preparation, splitting, and data loader creation. Provides a clean interface for data handling across train/val/test splits.

345

346

```python { .api }

347

class LightningDataModule:

348

def __init__(self):

349

"""Initialize the LightningDataModule."""

350

351

def prepare_data(self) -> None:

352

"""

353

Download and prepare data. Called from a single process (by default, local rank 0 on each node) rather than on every device.

354

Use this for downloading data and for any preprocessing that should not be repeated on every device.

355

"""

356

357

def setup(self, stage: str) -> None:

358

"""

359

Set up datasets for each stage.

360

361

Args:

362

stage: 'fit', 'validate', 'test', or 'predict'

363

"""

364

365

def train_dataloader(self) -> TRAIN_DATALOADERS:

366

"""

367

Create training data loader.

368

369

Returns:

370

Training data loader(s)

371

"""

372

373

def val_dataloader(self) -> EVAL_DATALOADERS:

374

"""

375

Create validation data loader.

376

377

Returns:

378

Validation data loader(s)

379

"""

380

381

def test_dataloader(self) -> EVAL_DATALOADERS:

382

"""

383

Create test data loader.

384

385

Returns:

386

Test data loader(s)

387

"""

388

389

def predict_dataloader(self) -> EVAL_DATALOADERS:

390

"""

391

Create prediction data loader.

392

393

Returns:

394

Prediction data loader(s)

395

"""

396

397

def teardown(self, stage: str) -> None:

398

"""

399

Clean up after training/testing.

400

401

Args:

402

stage: 'fit', 'validate', 'test', or 'predict'

403

"""

404

405

def state_dict(self) -> Dict[str, Any]:

406

"""

407

Called when saving a checkpoint.

408

409

Returns:

410

Dictionary of state to save

411

"""

412

413

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:

414

"""

415

Called when loading a checkpoint.

416

417

Args:

418

state_dict: Dictionary of saved state

419

"""

420

```

421

422

### Callback

423

424

Base class for creating custom callbacks to hook into the training lifecycle. Callbacks provide a way to add functionality at specific points during training.

425

426

```python { .api }

427

class Callback:

428

def __init__(self):

429

"""Initialize the callback."""

430

431

def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:

432

"""Called when training begins."""

433

434

def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None:

435

"""Called when training ends."""

436

437

def on_validation_start(self, trainer: Trainer, pl_module: LightningModule) -> None:

438

"""Called when validation begins."""

439

440

def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None:

441

"""Called when validation ends."""

442

443

def on_test_start(self, trainer: Trainer, pl_module: LightningModule) -> None:

444

"""Called when testing begins."""

445

446

def on_test_end(self, trainer: Trainer, pl_module: LightningModule) -> None:

447

"""Called when testing ends."""

448

449

def on_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None:

450

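"""Called when an epoch begins. Note: legacy hook, not part of the Lightning 2.x Callback API; prefer on_train_epoch_start / on_validation_epoch_start."""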

451

452

def on_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:

453

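"""Called when an epoch ends. Note: legacy hook, not part of the Lightning 2.x Callback API; prefer on_train_epoch_end / on_validation_epoch_end."""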

454

455

def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None:

456

"""Called when a training epoch begins."""

457

458

def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:

459

"""Called when a training epoch ends."""

460

461

def on_validation_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None:

462

"""Called when a validation epoch begins."""

463

464

def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:

465

"""Called when a validation epoch ends."""

466

467

def on_train_batch_start(

468

self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int

469

) -> None:

470

"""Called when a training batch begins."""

471

472

def on_train_batch_end(

473

self, trainer: Trainer, pl_module: LightningModule, outputs: STEP_OUTPUT, batch: Any, batch_idx: int

474

) -> None:

475

"""Called when a training batch ends."""

476

477

def on_validation_batch_start(

478

self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0

479

) -> None:

480

"""Called when a validation batch begins."""

481

482

def on_validation_batch_end(

483

self, trainer: Trainer, pl_module: LightningModule, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int = 0

484

) -> None:

485

"""Called when a validation batch ends."""

486

487

def on_before_optimizer_step(

488

self, trainer: Trainer, pl_module: LightningModule, optimizer: Optimizer, opt_idx: int

489

) -> None:

490

"""Called before optimizer step."""

491

492

def on_before_zero_grad(

493

self, trainer: Trainer, pl_module: LightningModule, optimizer: Optimizer

494

) -> None:

495

"""Called before gradients are zeroed."""

496

497

def on_save_checkpoint(

498

self, trainer: Trainer, pl_module: LightningModule, checkpoint: Dict[str, Any]

499

) -> None:

500

"""Called when saving a checkpoint."""

501

502

def on_load_checkpoint(

503

self, trainer: Trainer, pl_module: LightningModule, checkpoint: Dict[str, Any]

504

) -> None:

505

"""Called when loading a checkpoint."""

506

507

def state_dict(self) -> Dict[str, Any]:

508

"""

509

Called when saving a checkpoint.

510

511

Returns:

512

Dictionary of callback state to save

513

"""

514

515

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:

516

"""

517

Called when loading a checkpoint.

518

519

Args:

520

state_dict: Dictionary of saved callback state

521

"""

522

```