# Low-Level Training Control

Lightning Fabric gives you fine-grained control over the training loop while automatically handling device placement, distributed training setup, and gradient synchronization. This lets you write custom training logic with minimal boilerplate.
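
## Capabilities

### Fabric Class

`Fabric` is the core abstraction: it handles device management, distributed training setup, mixed precision, and gradient synchronization, while you keep full control over the training loop. Construct it, call `fabric.launch()` once, then pass your model, optimizers, and dataloaders through `setup()` / `setup_dataloaders()` and use the returned objects.

```python { .api }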
class Fabric:
    """
    Wrap a plain PyTorch training loop with device, distributed, and precision handling.

    You keep writing the loop yourself: construct Fabric, call launch(), pass the
    model and optimizers through setup() and the dataloaders through
    setup_dataloaders(), and replace loss.backward() with fabric.backward(loss).

    There is no Fabric.step(): call optimizer.step() on the optimizer returned by
    setup(). That wrapper applies precision/strategy handling (for example,
    gradient unscaling with '16-mixed') before stepping.
    """

    def __init__(
        self,
        accelerator: Union[str, "Accelerator"] = "auto",
        devices: Union[int, str, List[int]] = "auto",
        num_nodes: int = 1,
        strategy: Union[str, "Strategy"] = "auto",
        precision: Optional[str] = None,
        plugins: Optional[Union[Any, List[Any]]] = None,
        callbacks: Optional[Union[List[Any], Any]] = None,
        loggers: Optional[Union["Logger", List["Logger"]]] = None,
    ) -> None:
        """
        Configure hardware, distribution strategy, and numerical precision.

        Pass arguments by keyword (recent Lightning versions make the
        constructor keyword-only).

        Parameters:
        - accelerator: Accelerator name ('cpu', 'cuda', 'gpu', 'mps', 'tpu', 'auto') or an Accelerator instance
        - devices: Number of devices, a list of device indices, or 'auto'
        - num_nodes: Number of nodes for multi-node training
        - strategy: Strategy name (e.g. 'auto', 'ddp', 'fsdp', 'deepspeed') or a Strategy instance such as DeepSpeedStrategy(stage=2)
        - precision: Precision setting, e.g. '32-true', '16-mixed', 'bf16-mixed', '64-true'
        - plugins: A plugin or a list of plugins for custom functionality
        - callbacks: A callback object or a list of callback objects
        - loggers: A Logger or a list of Loggers used by log() and log_dict()
        """

    def launch(self, function: Optional[Callable[..., Any]] = None, *args: Any, **kwargs: Any) -> Any:
        """
        Launch and initialize the processes required by the configured strategy.

        Call this once, after constructing Fabric and before setup() or any
        collective operation.

        Parameters:
        - function: Optional function to run in each launched process; Fabric passes
          itself as the first argument, followed by args and kwargs. Needed for
          spawn/fork-based launchers (e.g. accelerator='tpu'); if omitted, only
          process initialization is performed
        - args, kwargs: Extra arguments forwarded to `function`

        Returns:
        The return value of `function`, if one was given
        """

    def setup(
        self,
        module: nn.Module,
        *optimizers: Optimizer,
        move_to_device: bool = True,
    ) -> Union[nn.Module, Tuple[Any, ...]]:
        """
        Wrap a model and its optimizers for the configured strategy and precision.

        Parameters:
        - module: PyTorch model to set up
        - optimizers: Optimizers over the parameters of `module`
        - move_to_device: Move the model to fabric.device (set False to handle placement yourself)

        Returns:
        The wrapped model alone when no optimizer is given; otherwise the tuple
        (model, optimizer_1, ..., optimizer_n) in the order passed in. Use the
        returned objects in place of the originals.
        """

    def setup_dataloaders(
        self,
        *dataloaders: DataLoader,
        use_distributed_sampler: bool = True,
        move_to_device: bool = True,
    ) -> Union[DataLoader, List[DataLoader]]:
        """
        Prepare dataloaders for distributed training.

        Parameters:
        - dataloaders: DataLoader instances to set up
        - use_distributed_sampler: Replace the sampler with a distributed one so each
          process sees a distinct shard (set False when using a custom sampler)
        - move_to_device: Move each yielded batch to fabric.device automatically

        Returns:
        A single wrapped DataLoader when one is passed, otherwise the wrapped
        dataloaders in the order passed in
        """

    def backward(
        self,
        tensor: torch.Tensor,
        *args: Any,
        model: Optional[nn.Module] = None,
        **kwargs: Any,
    ) -> None:
        """
        Replacement for loss.backward() that applies precision/strategy handling.

        With '16-mixed' the loss is scaled before backpropagating; the matching
        unscaling happens in optimizer.step() on the optimizer returned by setup().

        Parameters:
        - tensor: The loss tensor to backpropagate
        - args, kwargs: Forwarded to the underlying backward call
        - model: The model the loss belongs to; only needed by strategies that
          require it (e.g. DeepSpeed when several models were set up)
        """

    def clip_gradients(
        self,
        module: nn.Module,
        optimizer: Optimizer,
        clip_val: Optional[Union[float, int]] = None,
        max_norm: Optional[Union[float, int]] = None,
        norm_type: Union[float, int] = 2.0,
        error_if_nonfinite: bool = True,
    ) -> Optional[torch.Tensor]:
        """
        Clip gradients by value or by norm; pass exactly one of clip_val / max_norm.

        Pass max_norm by keyword: the third positional argument is clip_val.

        Parameters:
        - module: Model whose gradients to clip
        - optimizer: Optimizer associated with the model
        - clip_val: Clip each gradient element to [-clip_val, clip_val]
        - max_norm: Rescale gradients so their total norm is at most max_norm
        - norm_type: Type of norm used with max_norm
        - error_if_nonfinite: Raise if the total norm is nan/inf (norm clipping only)

        Returns:
        The total gradient norm when clipping by norm; None when clipping by value
        """

    def save(
        self,
        path: Union[str, "Path"],
        state: Dict[str, Any],
        filter: Optional[Dict[str, Callable[[str, Any], bool]]] = None,
    ) -> None:
        """
        Save checkpoint contents to `path`.

        Models and optimizers in `state` are converted to their state dicts
        automatically; other values are saved as-is. Which process(es) write
        the file is decided by the strategy (e.g. DDP saves from global rank 0).

        Parameters:
        - path: Destination file (or directory, for strategies that save sharded checkpoints)
        - state: Dictionary of objects to save, e.g. {"model": model, "optimizer": optimizer, "epoch": epoch}
        - filter: Optional per-key callables (name, value) -> bool selecting which
          entries of that key's state dict are saved
        """

    def load(
        self,
        path: Union[str, "Path"],
        state: Optional[Dict[str, Any]] = None,
        strict: bool = True,
    ) -> Dict[str, Any]:
        """
        Load a checkpoint from `path`.

        Parameters:
        - path: Checkpoint file (or directory, for sharded checkpoints)
        - state: Objects (model, optimizer, ...) to restore in place. If None,
          nothing is restored and the full checkpoint is returned
        - strict: Require the keys in `state` to match the checkpoint exactly

        Returns:
        The checkpoint entries that were not restored into `state`
        (the full checkpoint when `state` is None)
        """

    def barrier(self, name: Optional[str] = None) -> None:
        """
        Block until every process reaches this point.

        Parameters:
        - name: Optional barrier name, useful for debugging
        """

    def broadcast(self, obj: Any, src: int = 0) -> Any:
        """
        Send an object from the source rank to all ranks.

        Parameters:
        - obj: Object to broadcast (only the value on rank `src` is used)
        - src: Source rank

        Returns:
        The object from rank `src`, on every rank
        """

    def all_gather(
        self,
        data: Union[torch.Tensor, Dict[str, Any], List[Any], Tuple[Any, ...]],
        group: Optional[Any] = None,
        sync_grads: bool = False,
    ) -> Union[torch.Tensor, Dict[str, Any], List[Any], Tuple[Any, ...]]:
        """
        Gather tensors, or (nested) collections of tensors, from all processes.

        Python ints and floats are converted to tensors first.

        Parameters:
        - data: A tensor, number, or (nested) dict/list/tuple of them
        - group: Process group to gather from (defaults to all processes)
        - sync_grads: Let gradients flow through the gather operation

        Returns:
        Data of the same structure as `data`, where each tensor gains a leading
        dimension of size world_size (no extra dimension when world_size is 1)
        """

    def all_reduce(
        self,
        data: Union[torch.Tensor, Dict[str, Any], List[Any], Tuple[Any, ...]],
        group: Optional[Any] = None,
        reduce_op: Optional[Union[str, Any]] = "mean",
    ) -> Union[torch.Tensor, Dict[str, Any], List[Any], Tuple[Any, ...]]:
        """
        Reduce tensors, or (nested) collections of tensors, across all processes.

        The keyword is `reduce_op`; there is no `op` parameter.

        Parameters:
        - data: A tensor, number, or (nested) dict/list/tuple of them
        - group: Process group to reduce over (defaults to all processes)
        - reduce_op: 'mean' (default), 'sum', or a torch.distributed.ReduceOp;
          some strategies support only a subset

        Returns:
        Reduced data with the same structure as `data`
        """

    def log(self, name: str, value: Any, step: Optional[int] = None) -> None:
        """
        Log a single metric to every logger passed to Fabric(loggers=...).

        Parameters:
        - name: Metric name
        - value: Metric value (e.g. a float or scalar tensor)
        - step: Optional step number; most loggers auto-increment it when omitted
        """

    def log_dict(self, metrics: Mapping[str, Any], step: Optional[int] = None) -> None:
        """
        Log several metrics at once to every configured logger.

        Parameters:
        - metrics: Mapping of metric names to values
        - step: Optional step number; most loggers auto-increment it when omitted
        """

    def print(self, *args: Any, **kwargs: Any) -> None:
        """
        Call the built-in print() only on the process with local rank 0 (once per node).

        Parameters:
        - args, kwargs: Arguments passed to print()
        """

    @property
    def device(self) -> torch.device:
        """The device this process runs on."""

    @property
    def global_rank(self) -> int:
        """Rank of this process among all processes on all nodes."""

    @property
    def local_rank(self) -> int:
        """Rank of this process among the processes on its node."""

    @property
    def node_rank(self) -> int:
        """Index of the node this process runs on."""

    @property
    def world_size(self) -> int:
        """Total number of processes across all nodes."""

    @property
    def is_global_zero(self) -> bool:
        """Whether this process has global rank 0."""
```

### Utility Functions

```python { .api }
def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int:
    """
    Seed the pseudo-random number generators of Python's `random`, NumPy (if installed), and PyTorch.

    Parameters:
    - seed: Seed value. If None, the seed is read from the PL_GLOBAL_SEED
      environment variable (recent versions fall back to 0 when it is unset)
    - workers: If True, dataloaders passed to Fabric.setup_dataloaders() get a
      worker_init_fn that derives a distinct, reproducible seed per worker process

    Returns:
    The seed value that was actually used
    """
```

## Usage Examples

### Custom Training Loop

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

import lightning.fabric as L

# Initialize Fabric: 2 GPUs with float16 automatic mixed precision
fabric = L.Fabric(accelerator="gpu", devices=2, precision="16-mixed")
fabric.launch()  # start the per-device processes before calling setup()

# Create model, optimizer, and data
model = nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters())
dataset = TensorDataset(torch.randn(1000, 10), torch.randn(1000, 1))
dataloader = DataLoader(dataset, batch_size=32)

# Setup for distributed training; always use the returned (wrapped) objects
model, optimizer = fabric.setup(model, optimizer)
dataloader = fabric.setup_dataloaders(dataloader)

# Custom training loop
model.train()
for epoch in range(10):
    epoch_loss = 0.0
    for batch_idx, (x, y) in enumerate(dataloader):
        # Forward pass
        output = model(x)
        loss = nn.functional.mse_loss(output, y)

        # Backward pass: fabric.backward() replaces loss.backward(), and the
        # optimizer returned by setup() is stepped directly (there is no fabric.step()).
        optimizer.zero_grad()
        fabric.backward(loss)
        optimizer.step()

        batch_loss = loss.item()
        epoch_loss += batch_loss

        # Log metrics (only recorded if loggers were passed to Fabric(loggers=...))
        if batch_idx % 10 == 0:
            fabric.log("train_loss", batch_loss)

    # len(dataloader) counts this process's batches; fabric.print prints once per node
    fabric.print(f"Epoch {epoch}: Loss = {epoch_loss / len(dataloader)}")
```

### Checkpointing and Resuming

```python
import torch
import torch.nn as nn

import lightning.fabric as L

fabric = L.Fabric()
fabric.launch()

model = nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters())
model, optimizer = fabric.setup(model, optimizer)

# Training loop with checkpointing
for epoch in range(100):
    # ... training code ...

    # Save checkpoint every 10 epochs. Fabric converts the model and optimizer
    # entries to state dicts; the strategy decides which process(es) write the file.
    if epoch % 10 == 0:
        state = {"model": model, "optimizer": optimizer, "epoch": epoch}
        fabric.save(f"checkpoint_epoch_{epoch}.ckpt", state)

# Resume from checkpoint: passing `state` makes Fabric restore the model and
# optimizer in place; checkpoint entries not in `state` (here "epoch") are returned.
state = {"model": model, "optimizer": optimizer}
remainder = fabric.load("checkpoint_epoch_50.ckpt", state)
start_epoch = remainder["epoch"] + 1
```

### Distributed Training Primitives

```python
import torch

import lightning.fabric as L

fabric = L.Fabric(devices=4, strategy="ddp")
fabric.launch()

# Broadcast configuration from rank 0
if fabric.global_rank == 0:
    config = {"learning_rate": 0.001, "batch_size": 32}
else:
    config = None

config = fabric.broadcast(config, src=0)

# Gather metrics from all processes. Python numbers are converted to tensors,
# so each value in `all_metrics` becomes a tensor of shape (world_size,).
local_metrics = {"accuracy": 0.95, "loss": 0.1}
all_metrics = fabric.all_gather(local_metrics)

# Reduce a tensor across all processes. The keyword is `reduce_op`
# ("mean" by default, or "sum"); passing `op=` raises a TypeError.
local_tensor = torch.tensor([1.0, 2.0, 3.0])
reduced_tensor = fabric.all_reduce(local_tensor, reduce_op="mean")

fabric.print(f"Reduced tensor: {reduced_tensor}")
348
```
349
350
### Mixed Precision Training
351
352
```python
353
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

import lightning.fabric as L

# Enable float16 mixed precision ("16-mixed" targets CUDA GPUs;
# "bf16-mixed" is the bfloat16 variant)
fabric = L.Fabric(precision="16-mixed")
fabric.launch()

model = nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters())
model, optimizer = fabric.setup(model, optimizer)

dataset = TensorDataset(torch.randn(1000, 10), torch.randn(1000, 1))
dataloader = fabric.setup_dataloaders(DataLoader(dataset, batch_size=32))

# Training loop with automatic mixed precision
for epoch in range(10):
    for batch in dataloader:
        x, y = batch

        # Forward pass (the wrapped model runs its forward under autocast)
        output = model(x)
        loss = nn.functional.mse_loss(output, y)

        # Backward pass: there is no fabric.step(); step the wrapped optimizer
        optimizer.zero_grad()
        fabric.backward(loss)  # scales the loss before backpropagating
        optimizer.step()  # steps through the gradient scaler, which unscales gradients first
```

### Custom Strategy Integration

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

import lightning.fabric as L
from lightning.fabric.strategies import DeepSpeedStrategy

# Use a custom strategy instance (DeepSpeedStrategy requires the `deepspeed` package)
strategy = DeepSpeedStrategy(stage=2)
fabric = L.Fabric(strategy=strategy, precision="16-mixed")
fabric.launch()

model = nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters())
model, optimizer = fabric.setup(model, optimizer)

dataset = TensorDataset(torch.randn(1000, 10), torch.randn(1000, 1))
dataloader = fabric.setup_dataloaders(DataLoader(dataset, batch_size=32))

# Training proceeds normally - Fabric handles strategy-specific details
for epoch in range(10):
    for x, y in dataloader:
        loss = nn.functional.mse_loss(model(x), y)
        optimizer.zero_grad()
        fabric.backward(loss)
        optimizer.step()
398
```