or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

accelerators.md core-training.md distributed.md index.md precision.md strategies.md utilities.md

docs/core-training.md

0

# Core Training

1

2

The main Fabric class and associated wrapper classes that provide the foundation for distributed PyTorch training with minimal code changes.

3

4

## Capabilities

5

6

### Fabric Class

7

8

The main orchestrator class that handles all aspects of distributed training setup and execution.

9

10

```python { .api }

11

class Fabric:

12

"""

13

Main class for accelerating PyTorch training with minimal changes.

14

15

Provides automatic device placement, mixed precision, distributed training,

16

and seamless switching between hardware configurations.

17

"""

18

19

def __init__(

20

self,

21

accelerator: Union[str, Accelerator] = "auto",

22

strategy: Union[str, Strategy] = "auto",

23

devices: Union[list[int], str, int] = "auto",

24

num_nodes: int = 1,

25

precision: Optional[Union[str, int]] = None,

26

plugins: Optional[Union[Any, list[Any]]] = None,

27

callbacks: Optional[Union[list[Any], Any]] = None,

28

loggers: Optional[Union[Logger, list[Logger]]] = None

29

):

30

"""

31

Initialize Fabric with hardware and training configuration.

32

33

Args:

34

accelerator: Hardware to run on ("cpu", "cuda", "mps", "gpu", "tpu", "auto")

35

strategy: Distribution strategy ("dp", "ddp", "ddp_spawn", "deepspeed", "fsdp", "auto")

36

devices: Number of devices or specific device IDs

37

num_nodes: Number of nodes for multi-node training

38

precision: Precision mode ("64", "32", "16-mixed", "bf16-mixed", etc.)

39

plugins: Additional plugins for customization

40

callbacks: Callback functions for training events

41

loggers: Logger instances for experiment tracking

42

"""

43

```

44

45

### Setup Methods

46

47

Configure models, optimizers, and dataloaders for distributed training.

48

49

```python { .api }

50

def setup(

51

self,

52

module: nn.Module,

53

*optimizers: Optimizer,

54

move_to_device: bool = True,

55

_reapply_compile: bool = True

56

) -> Union[_FabricModule, tuple[_FabricModule, _FabricOptimizer, ...]]:

57

"""

58

Setup model and optimizers for distributed training.

59

60

Args:

61

module: PyTorch model to setup

62

*optimizers: One or more optimizers

63

move_to_device: Whether to move model to target device

64

_reapply_compile: Whether to reapply torch.compile if present

65

66

Returns:

67

Fabric-wrapped module and optimizers

68

"""

69

70

def setup_module(

71

self,

72

module: nn.Module,

73

move_to_device: bool = True,

74

_reapply_compile: bool = True

75

) -> _FabricModule:

76

"""

77

Setup only the model for distributed training.

78

79

Args:

80

module: PyTorch model to setup

81

move_to_device: Whether to move model to target device

82

_reapply_compile: Whether to reapply torch.compile if present

83

84

Returns:

85

Fabric-wrapped module

86

"""

87

88

def setup_optimizers(

89

self,

90

*optimizers: Optimizer

91

) -> Union[_FabricOptimizer, tuple[_FabricOptimizer, ...]]:

92

"""

93

Setup optimizers for distributed training.

94

95

Args:

96

*optimizers: One or more optimizers to setup

97

98

Returns:

99

Fabric-wrapped optimizer(s)

100

"""

101

102

def setup_dataloaders(

103

self,

104

*dataloaders: DataLoader,

105

use_distributed_sampler: bool = True,

106

move_to_device: bool = True

107

) -> Union[DataLoader, list[DataLoader]]:

108

"""

109

Setup dataloaders for distributed training.

110

111

Args:

112

*dataloaders: One or more dataloaders to setup

113

use_distributed_sampler: Whether to replace sampler for distributed training

114

move_to_device: Whether to move data to target device automatically

115

116

Returns:

117

Configured dataloader(s)

118

"""

119

```

120

121

### Training Operations

122

123

Core methods for training loops including backward pass, gradient clipping, and precision handling.

124

125

```python { .api }

126

def backward(

127

self,

128

tensor: Tensor,

129

*args,

130

model: Optional[_FabricModule] = None,

131

**kwargs

132

) -> None:

133

"""

134

Perform backward pass with automatic gradient scaling and accumulation.

135

136

Args:

137

tensor: Loss tensor to compute gradients for

138

*args: Additional arguments passed to tensor.backward()

139

model: Model to sync gradients for (auto-detected if None)

140

**kwargs: Additional keyword arguments passed to tensor.backward()

141

"""

142

143

def clip_gradients(

144

self,

145

module: _FabricModule,

146

optimizer: _FabricOptimizer,

147

clip_val: Optional[Union[int, float]] = None,

148

max_norm: Optional[Union[int, float]] = None,

149

norm_type: Union[int, float] = 2.0,

150

error_if_nonfinite: bool = True

151

) -> Optional[Tensor]:

152

"""

153

Clip gradients by value or norm.

154

155

Args:

156

module: Fabric-wrapped module

157

optimizer: Fabric-wrapped optimizer

158

clip_val: Maximum allowed value of gradients

159

max_norm: Maximum allowed norm of gradients

160

norm_type: Type of norm to compute (default: 2.0 for L2 norm)

161

error_if_nonfinite: Whether to error on non-finite gradients

162

163

Returns:

164

Total norm of the gradients (computed before clipping) if max_norm is specified, otherwise None

165

"""

166

167

def autocast(self) -> AbstractContextManager:

168

"""

169

Context manager for automatic mixed precision.

170

171

Returns:

172

Context manager that applies appropriate precision casting

173

"""

174

```

175

176

### Checkpoint Management

177

178

Save and load model states, optimizers, and training metadata.

179

180

```python { .api }

181

def save(

182

self,

183

path: _PATH,

184

state: dict[str, Any],

185

filter: Optional[dict[str, Any]] = None

186

) -> None:

187

"""

188

Save checkpoint with distributed training support.

189

190

Args:

191

path: Checkpoint file path

192

state: Dictionary containing model, optimizer, and other state

193

filter: Optional dict mapping keys of state to callables that return True for items to keep in the saved checkpoint

194

"""

195

196

def load(

197

self,

198

path: _PATH,

199

state: Optional[dict[str, Any]] = None,

200

strict: bool = True

201

) -> dict[str, Any]:

202

"""

203

Load checkpoint with distributed training support.

204

205

Args:

206

path: Checkpoint file path

207

state: Dictionary to load state into (if provided)

208

strict: Whether to strictly enforce state dict key matching

209

210

Returns:

211

Items of the checkpoint that were not restored into state, or the full checkpoint dictionary if state is None

212

"""

213

214

def load_raw(

215

self,

216

path: _PATH,

217

obj: Union[nn.Module, Optimizer],

218

strict: bool = True

219

) -> None:

220

"""

221

Load raw PyTorch checkpoint into object.

222

223

Args:

224

path: Checkpoint file path

225

obj: Object to load state into

226

strict: Whether to strictly enforce state dict key matching

227

"""

228

```

229

230

### Process Management

231

232

Launch and coordinate distributed processes.

233

234

```python { .api }

235

def launch(

236

self,

237

function: Callable = lambda: None,

238

*args,

239

**kwargs

240

) -> Any:

241

"""

242

Launch the distributed training processes.

243

244

Args:

245

function: Function to execute in distributed processes

246

*args: Arguments to pass to function

247

**kwargs: Keyword arguments to pass to function

248

249

Returns:

250

Result from function execution

251

"""

252

253

def run(self, *args, **kwargs) -> Any:

254

"""

255

Execute main training function with distributed setup.

256

257

Args:

258

*args: Arguments passed to training function

259

**kwargs: Keyword arguments passed to training function

260

261

Returns:

262

Result from training function

263

"""

264

```

265

266

### Properties

267

268

Access information about the distributed training setup.

269

270

```python { .api }

271

@property

272

def accelerator(self) -> Accelerator:

273

"""Current accelerator instance."""

274

275

@property

276

def strategy(self) -> Strategy:

277

"""Current strategy instance."""

278

279

@property

280

def device(self) -> torch.device:

281

"""Current device."""

282

283

@property

284

def global_rank(self) -> int:

285

"""Global rank of this process."""

286

287

@property

288

def local_rank(self) -> int:

289

"""Local rank of this process on current node."""

290

291

@property

292

def node_rank(self) -> int:

293

"""Rank of current node."""

294

295

@property

296

def world_size(self) -> int:

297

"""Total number of processes."""

298

299

@property

300

def is_global_zero(self) -> bool:

301

"""Whether this is the rank 0 process."""

302

303

@property

304

def loggers(self) -> list[Logger]:

305

"""List of all logger instances."""

306

307

@property

308

def logger(self) -> Logger:

309

"""Primary logger instance."""

310

```

311

312

### Wrapper Classes

313

314

Fabric automatically wraps PyTorch objects to provide distributed training support.

315

316

```python { .api }

317

class _FabricModule:

318

"""Wrapper for PyTorch modules with distributed training support."""

319

320

@property

321

def module(self) -> nn.Module:

322

"""Access the wrapped PyTorch module."""

323

324

def forward(self, *args, **kwargs) -> Any:

325

"""Forward pass with precision handling."""

326

327

def state_dict(self, **kwargs) -> dict[str, Any]:

328

"""Get module state dictionary."""

329

330

def load_state_dict(self, state_dict: dict, strict: bool = True) -> Any:

331

"""Load module state dictionary."""

332

333

class _FabricOptimizer:

334

"""Wrapper for PyTorch optimizers with distributed training support."""

335

336

@property

337

def optimizer(self) -> Optimizer:

338

"""Access the wrapped PyTorch optimizer."""

339

340

def step(self, closure: Optional[Callable] = None) -> Any:

341

"""Perform optimizer step."""

342

343

def zero_grad(self, set_to_none: bool = False) -> None:

344

"""Zero the gradients."""

345

346

def state_dict(self) -> dict[str, Any]:

347

"""Get optimizer state dictionary."""

348

349

def load_state_dict(self, state_dict: dict) -> None:

350

"""Load optimizer state dictionary."""

351

352

class _FabricDataLoader:

353

"""Wrapper for PyTorch DataLoaders with distributed training support."""

354

355

@property

356

def device(self) -> Optional[torch.device]:

357

"""Target device for data placement."""

358

```

359

360

### Context Managers

361

362

Special context managers for advanced training scenarios.

363

364

```python { .api }

365

def no_backward_sync(

366

self,

367

module: _FabricModule,

368

enabled: bool = True

369

) -> AbstractContextManager:

370

"""

371

Context manager to skip gradient synchronization.

372

373

Args:

374

module: Fabric-wrapped module

375

enabled: Whether to skip sync (True) or perform normal sync (False)

376

377

Returns:

378

Context manager

379

"""

380

381

def rank_zero_first(self, local: bool = False) -> Generator:

382

"""

383

Context manager ensuring rank 0 executes first.

384

385

Args:

386

local: Whether to use local rank (node-level) or global rank

387

388

Yields:

389

None

390

"""

391

392

def init_tensor(self) -> AbstractContextManager:

393

"""

394

Context manager under which new tensors are created directly on the target device with the dtype implied by the precision setting.

395

396

Returns:

397

Context manager for tensor initialization

398

"""

399

400

def init_module(self, empty_init: Optional[bool] = None) -> AbstractContextManager:

401

"""

402

Context manager under which a module's parameters are created directly on the target device with the dtype implied by the precision setting.

403

404

Args:

405

empty_init: Whether to allocate parameters as uninitialized memory instead of running weight initialization (None lets the strategy decide)

406

407

Returns:

408

Context manager for module initialization

409

"""

410

```

411

412

### Logging Methods

413

414

Log metrics and values to registered loggers for experiment tracking.

415

416

```python { .api }

417

def log(self, name: str, value: Any, step: Optional[int] = None) -> None:

418

"""

419

Log a scalar to all loggers that were added to Fabric.

420

421

Args:

422

name: The name of the metric to log

423

value: The metric value to collect. If the value is a torch.Tensor, it gets detached automatically

424

step: Optional step number. Most Logger implementations auto-increment this value

425

"""

426

427

def log_dict(self, metrics: Mapping[str, Any], step: Optional[int] = None) -> None:

428

"""

429

Log multiple scalars at once to all loggers that were added to Fabric.

430

431

Args:

432

metrics: A dictionary where the key is the name of the metric and the value the scalar to be logged

433

step: Optional step number. Most Logger implementations auto-increment this value

434

"""

435

```

436

437

### Callback Management

438

439

Invoke registered callback methods for training event handling.

440

441

```python { .api }

442

def call(self, hook_name: str, *args: Any, **kwargs: Any) -> None:

443

"""

444

Trigger the callback methods with the given name and arguments.

445

446

Args:

447

hook_name: The name of the callback method

448

*args: Optional positional arguments that get passed down to the callback method

449

**kwargs: Optional keyword arguments that get passed down to the callback method

450

"""

451

```

452

453

## Usage Examples

454

455

### Basic Training Setup

456

457

```python

458

from lightning.fabric import Fabric

459

import torch

460

import torch.nn as nn

461

462

# Initialize Fabric

463

fabric = Fabric(accelerator="gpu", devices=2, strategy="ddp")

464

465

# Define model and optimizer

466

model = nn.Sequential(

467

nn.Linear(784, 256),

468

nn.ReLU(),

469

nn.Linear(256, 10)

470

)

471

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

472

473

# Setup with Fabric

474

model, optimizer = fabric.setup(model, optimizer)

475

476

# Training loop

477

for epoch in range(10):

478

for batch in dataloader:

479

x, y = batch

480

optimizer.zero_grad()

481

482

y_pred = model(x)

483

loss = nn.functional.cross_entropy(y_pred, y)

484

485

fabric.backward(loss)

486

optimizer.step()

487

```

488

489

### Checkpoint Management

490

491

```python

492

# Save checkpoint

493

state = {

494

"model": model,

495

"optimizer": optimizer,

496

"epoch": epoch,

497

"loss": loss.item()

498

}

499

fabric.save("checkpoint.ckpt", state)

500

501

# Load checkpoint

502

loaded_state = fabric.load("checkpoint.ckpt")

503

epoch = loaded_state["epoch"]

504

loss = loaded_state["loss"]

505

```

506

507

### Mixed Precision Training

508

509

```python

510

# Initialize with mixed precision

511

fabric = Fabric(precision="16-mixed")

512

513

# Use autocast context

514

for batch in dataloader:

515

with fabric.autocast():

516

y_pred = model(batch)

517

loss = criterion(y_pred, targets)

518

519

fabric.backward(loss)

520

optimizer.step()

521

```