or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

apps.md fabric.md index.md training.md utilities.md

docs/fabric.md

0

# Low-Level Training Control

1

2

Lightning Fabric provides fine-grained control over training loops while automatically handling device management, distributed training setup, and gradient synchronization. This enables custom training logic with minimal boilerplate code.

3

4

## Capabilities

5

6

### Fabric Class

7

8

Core abstraction that handles device management, distributed training setup, mixed precision, and gradient synchronization while giving you full control over the training loop.

9

10

```python { .api }

11

class Fabric:

12

def __init__(

13

self,

14

accelerator: str = "auto",

15

devices: Union[int, str, List[int]] = "auto",

16

num_nodes: int = 1,

17

strategy: Optional[str] = None,

18

precision: Optional[str] = None,

19

plugins: Optional[Union[str, list]] = None,

20

callbacks: Optional[Union[List, dict]] = None,

21

loggers: Optional[Union[Logger, List[Logger]]] = None

22

):

23

"""

24

Initialize Fabric for low-level training control.

25

26

Parameters:

27

- accelerator: Hardware accelerator ('cpu', 'gpu', 'tpu', 'auto')

28

- devices: Device specification (int, list, or 'auto')

29

- num_nodes: Number of nodes for distributed training

30

- strategy: Training strategy for distributed training

31

- precision: Training precision ('16-mixed', '32', '64', 'bf16-mixed')

32

- plugins: Additional plugins for custom functionality

33

- callbacks: Callback instances for training hooks

34

- loggers: Logger instances for experiment tracking

35

"""

36

37

def setup(

38

self,

39

model: nn.Module,

40

*optimizers: Optimizer

41

) -> Union[nn.Module, Tuple[nn.Module, ...]]:

42

"""

43

Set up model and optimizers for distributed training.

44

45

Parameters:

46

- model: PyTorch model to set up

47

- optimizers: Optimizer instances to set up

48

49

Returns:

50

Configured model and optimizers

51

"""

52

53

def setup_dataloaders(

54

self,

55

*dataloaders: DataLoader

56

) -> Union[DataLoader, Tuple[DataLoader, ...]]:

57

"""

58

Set up dataloaders for distributed training.

59

60

Parameters:

61

- dataloaders: DataLoader instances to set up

62

63

Returns:

64

Configured dataloaders

65

"""

66

67

def backward(self, loss: torch.Tensor) -> None:

68

"""

69

Backward pass with automatic gradient scaling.

70

71

Parameters:

72

- loss: Loss tensor to compute gradients for

73

"""

74

75

def step(self, optimizer: Optimizer, *args, **kwargs) -> None:

76

"""

77

Optimizer step with gradient unscaling and synchronization.

78

79

Parameters:

80

- optimizer: Optimizer to step

81

- args, kwargs: Additional arguments passed to optimizer.step()

82

"""

83

84

def clip_gradients(

85

self,

86

model: nn.Module,

87

optimizer: Optimizer,

88

max_norm: Union[float, int],

89

norm_type: Union[float, int] = 2.0,

90

error_if_nonfinite: bool = True

91

) -> torch.Tensor:

92

"""

93

Clip gradients by norm.

94

95

Parameters:

96

- model: Model whose gradients to clip

97

- optimizer: Associated optimizer

98

- max_norm: Maximum norm for gradients

99

- norm_type: Type of norm to compute

100

- error_if_nonfinite: Raise error for non-finite gradients

101

102

Returns:

103

Total norm of gradients

104

"""

105

106

def save(self, path: str, state: dict) -> None:

107

"""

108

Save training state to checkpoint.

109

110

Parameters:

111

- path: Path to save checkpoint

112

- state: Dictionary containing model/optimizer states

113

"""

114

115

def load(self, path: str) -> dict:

116

"""

117

Load training state from checkpoint.

118

119

Parameters:

120

- path: Path to checkpoint file

121

122

Returns:

123

Dictionary containing loaded state

124

"""

125

126

def barrier(self, name: Optional[str] = None) -> None:

127

"""

128

Synchronize all processes.

129

130

Parameters:

131

- name: Optional barrier name for debugging

132

"""

133

134

def broadcast(self, obj: Any, src: int = 0) -> Any:

135

"""

136

Broadcast object from source rank to all ranks.

137

138

Parameters:

139

- obj: Object to broadcast

140

- src: Source rank

141

142

Returns:

143

Broadcasted object

144

"""

145

146

def all_gather(self, data: Any, group: Optional[Any] = None) -> List[Any]:

147

"""

148

Gather data from all processes.

149

150

Parameters:

151

- data: Data to gather

152

- group: Process group

153

154

Returns:

155

List of gathered data from all processes

156

"""

157

158

def all_reduce(

159

self,

160

tensor: torch.Tensor,

161

op: str = "sum",

162

group: Optional[Any] = None

163

) -> torch.Tensor:

164

"""

165

Reduce tensor across all processes.

166

167

Parameters:

168

- tensor: Tensor to reduce

169

- op: Reduction operation ('sum', 'mean', 'max', 'min')

170

- group: Process group

171

172

Returns:

173

Reduced tensor

174

"""

175

176

def log(self, name: str, value: Any, step: Optional[int] = None) -> None:

177

"""

178

Log metrics to configured loggers.

179

180

Parameters:

181

- name: Metric name

182

- value: Metric value

183

- step: Training step (auto-incremented if None)

184

"""

185

186

def log_dict(self, metrics: dict, step: Optional[int] = None) -> None:

187

"""

188

Log multiple metrics at once.

189

190

Parameters:

191

- metrics: Dictionary of metric names and values

192

- step: Training step (auto-incremented if None)

193

"""

194

195

def print(self, *args, **kwargs) -> None:

196

"""

197

Print only on rank 0 in distributed training.

198

199

Parameters:

200

- args, kwargs: Arguments passed to print()

201

"""

202

203

@property

204

def device(self) -> torch.device:

205

"""Current device."""

206

207

@property

208

def global_rank(self) -> int:

209

"""Global rank of current process."""

210

211

@property

212

def local_rank(self) -> int:

213

"""Local rank of current process."""

214

215

@property

216

def node_rank(self) -> int:

217

"""Node rank of current process."""

218

219

@property

220

def world_size(self) -> int:

221

"""Total number of processes."""

222

223

@property

224

def is_global_zero(self) -> bool:

225

"""Whether current process is global rank 0."""

226

```

227

228

### Utility Functions

229

230

```python { .api }

231

def seed_everything(seed: int, workers: bool = False) -> int:

232

"""

233

Seed all random number generators for reproducibility.

234

235

Parameters:

236

- seed: Random seed value

237

- workers: Seed dataloader worker processes

238

239

Returns:

240

The seed value used

241

"""

242

```

243

244

## Usage Examples

245

246

### Custom Training Loop

247

248

```python

249

import torch

250

import torch.nn as nn

251

from torch.utils.data import DataLoader

252

import lightning.fabric as L

253

254

# Initialize Fabric

255

fabric = L.Fabric(accelerator="gpu", devices=2, precision="16-mixed")

256

fabric.launch()

257

258

# Create model, optimizer, and data

259

model = nn.Linear(10, 1)

260

optimizer = torch.optim.Adam(model.parameters())

261

dataset = torch.randn(1000, 10), torch.randn(1000, 1)

262

dataloader = DataLoader(list(zip(*dataset)), batch_size=32)

263

264

# Set up for distributed training

265

model, optimizer = fabric.setup(model, optimizer)

266

dataloader = fabric.setup_dataloaders(dataloader)

267

268

# Custom training loop

269

model.train()

270

for epoch in range(10):

271

epoch_loss = 0

272

for batch_idx, (x, y) in enumerate(dataloader):

273

# Forward pass

274

output = model(x)

275

loss = nn.functional.mse_loss(output, y)

276

277

# Backward pass

278

optimizer.zero_grad()

279

fabric.backward(loss)

280

fabric.step(optimizer)

281

282

epoch_loss += loss.item()

283

284

# Log metrics

285

if batch_idx % 10 == 0:

286

fabric.log("train_loss", loss.item())

287

288

fabric.print(f"Epoch {epoch}: Loss = {epoch_loss / len(dataloader)}")

289

```

290

291

### Checkpointing and Resuming

292

293

```python

294

import lightning.fabric as L

295

296

fabric = L.Fabric()

297

fabric.launch()

298

299

model = nn.Linear(10, 1)

300

optimizer = torch.optim.Adam(model.parameters())

301

model, optimizer = fabric.setup(model, optimizer)

302

303

# Training loop with checkpointing

304

for epoch in range(100):

305

# ... training code ...

306

307

# Save checkpoint every 10 epochs

308

if epoch % 10 == 0:

309

state = {

310

"model": model,

311

"optimizer": optimizer,

312

"epoch": epoch

313

}

314

fabric.save(f"checkpoint_epoch_{epoch}.ckpt", state)

315

316

# Resume from checkpoint

317

checkpoint = fabric.load("checkpoint_epoch_50.ckpt")

318

model.load_state_dict(checkpoint["model"])

319

optimizer.load_state_dict(checkpoint["optimizer"])

320

start_epoch = checkpoint["epoch"] + 1

321

```

322

323

### Distributed Training Primitives

324

325

```python

326

import lightning.fabric as L

327

328

fabric = L.Fabric(devices=4, strategy="ddp")

329

fabric.launch()

330

331

# Broadcast configuration from rank 0

332

if fabric.global_rank == 0:

333

config = {"learning_rate": 0.001, "batch_size": 32}

334

else:

335

config = None

336

337

config = fabric.broadcast(config, src=0)

338

339

# Gather metrics from all processes

340

local_metrics = {"accuracy": 0.95, "loss": 0.1}

341

all_metrics = fabric.all_gather(local_metrics)

342

343

# Reduce tensor across all processes

344

local_tensor = torch.tensor([1.0, 2.0, 3.0])

345

reduced_tensor = fabric.all_reduce(local_tensor, op="mean")

346

347

fabric.print(f"Reduced tensor: {reduced_tensor}")

348

```

349

350

### Mixed Precision Training

351

352

```python

353

import lightning.fabric as L

354

355

# Enable mixed precision

356

fabric = L.Fabric(precision="16-mixed")

357

fabric.launch()

358

359

model = nn.Linear(10, 1)

360

optimizer = torch.optim.Adam(model.parameters())

361

model, optimizer = fabric.setup(model, optimizer)

362

363

# Training loop with automatic mixed precision

364

for epoch in range(10):

365

for batch in dataloader:

366

x, y = batch

367

368

# Forward pass (automatically uses mixed precision)

369

output = model(x)

370

loss = nn.functional.mse_loss(output, y)

371

372

# Backward pass (automatically handles gradient scaling)

373

optimizer.zero_grad()

374

fabric.backward(loss) # Handles gradient scaling

375

fabric.step(optimizer) # Handles gradient unscaling

376

```

377

378

### Custom Strategy Integration

379

380

```python

381

import lightning.fabric as L

382

from lightning.fabric.strategies import DeepSpeedStrategy

383

384

# Use custom strategy

385

strategy = DeepSpeedStrategy(stage=2)

386

fabric = L.Fabric(strategy=strategy, precision="16-mixed")

387

fabric.launch()

388

389

model = nn.Linear(10, 1)

390

optimizer = torch.optim.Adam(model.parameters())

391

model, optimizer = fabric.setup(model, optimizer)

392

393

# Training proceeds normally - Fabric handles strategy-specific details

394

for epoch in range(10):

395

for batch in dataloader:

396

# ... training code ...

397

pass

398

```