0
# Lightning Fabric
1
2
Lightweight training acceleration framework providing expert-level control over training loops, device management, and distributed strategies without high-level abstractions. Fabric gives you the flexibility of raw PyTorch with the power of Lightning's optimizations.
3
4
## Capabilities
5
6
### Fabric Core
7
8
Main Fabric class that accelerates PyTorch training with distributed training, mixed precision, and device management while maintaining full control over the training loop.
9
10
```python { .api }
11
class Fabric:
12
def __init__(
13
self,
14
accelerator: str = "auto",
15
strategy: str = "auto",
16
devices: Union[List[int], str, int] = "auto",
17
num_nodes: int = 1,
18
precision: Union[str, int] = "32-true",
19
plugins: Optional[Union[Plugin, List[Plugin]]] = None,
20
callbacks: Optional[Union[Callback, List[Callback]]] = None,
21
loggers: Optional[Union[Logger, List[Logger]]] = None,
22
**kwargs
23
):
24
"""
25
Initialize Fabric for training acceleration.
26
27
Args:
28
accelerator: Hardware accelerator ('cpu', 'gpu', 'tpu', 'auto')
29
strategy: Distributed strategy ('ddp', 'fsdp', 'deepspeed', etc.)
30
devices: Which devices to use
31
num_nodes: Number of nodes for distributed training
32
precision: Precision mode ('32-true', '16-mixed', 'bf16-mixed', etc.)
33
plugins: Additional plugins for customization
34
callbacks: Callbacks for training lifecycle hooks
35
loggers: Loggers for experiment tracking
36
"""
37
38
def setup(
39
self,
40
model: nn.Module,
41
*optimizers: Optimizer
42
) -> Union[nn.Module, Tuple[Union[nn.Module, Optimizer], ...]]:
43
"""
44
Set up model and optimizers for training.
45
46
Args:
47
model: PyTorch model to accelerate
48
*optimizers: Optimizers to set up
49
50
Returns:
51
The wrapped model if no optimizers are passed, otherwise a tuple of the wrapped model followed by the wrapped optimizers (in the order given)
52
"""
53
54
def setup_dataloaders(
55
self,
56
*dataloaders: DataLoader
57
) -> Union[DataLoader, List[DataLoader]]:
58
"""
59
Set up data loaders for distributed training.
60
61
Args:
62
*dataloaders: Data loaders to set up
63
64
Returns:
65
Wrapped data loaders ready for distributed training
66
"""
67
68
def backward(self, tensor: Tensor) -> None:
69
"""
70
Perform backward pass with proper scaling and synchronization.
71
72
Args:
73
tensor: Loss tensor to compute gradients from
74
"""
75
76
def clip_gradients(
77
self,
78
model: nn.Module,
79
optimizer: Optimizer,
80
max_norm: Union[float, int],
81
norm_type: Union[float, int] = 2.0,
82
error_if_nonfinite: bool = True
83
) -> Tensor:
84
"""
85
Clip gradients by norm.
86
87
Args:
88
model: Model whose gradients to clip
89
optimizer: Optimizer being used
90
max_norm: Maximum norm for gradients
91
norm_type: Type of norm to use
92
error_if_nonfinite: Raise error if gradients are non-finite
93
94
Returns:
95
Total norm of the gradients
96
"""
97
98
def all_gather(
99
self,
100
tensor: Tensor,
101
group: Optional[Any] = None,
102
sync_grads: bool = False
103
) -> Tensor:
104
"""
105
Gather tensors from all processes.
106
107
Args:
108
tensor: Tensor to gather
109
group: Process group
110
sync_grads: Synchronize gradients
111
112
Returns:
113
Gathered tensor from all processes
114
"""
115
116
def all_reduce(
117
self,
118
tensor: Tensor,
119
group: Optional[Any] = None,
120
reduce_op: str = "mean"
121
) -> Tensor:
122
"""
123
Reduce tensor across all processes.
124
125
Args:
126
tensor: Tensor to reduce
127
group: Process group
128
reduce_op: Reduction operation ('mean', 'sum')
129
130
Returns:
131
Reduced tensor
132
"""
133
134
def broadcast(self, tensor: Tensor, src: int = 0) -> Tensor:
135
"""
136
Broadcast tensor from source process to all processes.
137
138
Args:
139
tensor: Tensor to broadcast
140
src: Source rank
141
142
Returns:
143
Broadcasted tensor
144
"""
145
146
def barrier(self, name: Optional[str] = None) -> None:
147
"""
148
Synchronize all processes.
149
150
Args:
151
name: Optional barrier name for debugging
152
"""
153
154
@property
def is_global_zero(self) -> bool:
155
"""
156
Check if current process is global rank 0.
157
158
Returns:
159
True if global rank 0
160
"""
161
162
def print(self, *args, **kwargs) -> None:
163
"""
164
Print only on the first process of each node (local rank 0).
165
166
Args:
167
*args: Arguments to print
168
**kwargs: Keyword arguments for print
169
"""
170
171
def log(self, name: str, value: Any, step: Optional[int] = None) -> None:
172
"""
173
Log a metric.
174
175
Args:
176
name: Metric name
177
value: Metric value
178
step: Optional step number
179
"""
180
181
def log_dict(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
182
"""
183
Log a dictionary of metrics.
184
185
Args:
186
metrics: Dictionary of metrics
187
step: Optional step number
188
"""
189
190
def save(self, path: str, state: Dict[str, Any]) -> None:
191
"""
192
Save checkpoint.
193
194
Args:
195
path: Path to save checkpoint
196
state: State dictionary to save
197
"""
198
199
def load(self, path: str) -> Dict[str, Any]:
200
"""
201
Load checkpoint.
202
203
Args:
204
path: Path to load checkpoint from
205
206
Returns:
207
Loaded state dictionary
208
"""
209
210
@property
211
def device(self) -> torch.device:
212
"""Get the current device."""
213
214
@property
215
def global_rank(self) -> int:
216
"""Get global rank of current process."""
217
218
@property
219
def local_rank(self) -> int:
220
"""Get local rank of current process."""
221
222
@property
223
def node_rank(self) -> int:
224
"""Get node rank of current process."""
225
226
@property
227
def world_size(self) -> int:
228
"""Get total number of processes."""
229
230
def to_device(self, obj: Any) -> Any:
231
"""
232
Move object to device.
233
234
Args:
235
obj: Object to move to device
236
237
Returns:
238
Object on the device
239
"""
240
```
241
242
### Utility Functions
243
244
Core utility functions for reproducibility, object inspection, and common operations in Fabric workflows.
245
246
```python { .api }
247
def seed_everything(seed: int, workers: bool = False) -> int:
248
"""
249
Set random seeds for reproducibility.
250
251
Args:
252
seed: Random seed to set
253
workers: Also set seed for data loader workers
254
255
Returns:
256
The seed that was set
257
"""
258
259
def is_wrapped(obj: Any) -> bool:
260
"""
261
Check if an object has been wrapped by Fabric.
262
263
Args:
264
obj: Object to check
265
266
Returns:
267
True if object is wrapped by Fabric
268
"""
269
```
270
271
## Basic Usage Example
272
273
```python
274
import torch
275
import torch.nn as nn
276
from torch.utils.data import DataLoader, TensorDataset
277
from lightning import Fabric
278
279
# Initialize Fabric
280
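fabric = Fabric(accelerator="gpu", devices=2, strategy="ddp")
# Start the distributed processes; required for multi-device runs started with plain `python`
# (omit this call when the script is started through the Lightning command-line launcher)
fabric.launch()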
281
282
# Define model and optimizer
283
model = nn.Linear(10, 1)
284
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
285
286
# Setup model and optimizer with Fabric
287
model, optimizer = fabric.setup(model, optimizer)
288
289
# Create sample data and dataloader
290
data = torch.randn(1000, 10)
291
targets = torch.randn(1000, 1)
292
dataset = TensorDataset(data, targets)
293
dataloader = DataLoader(dataset, batch_size=32)
294
295
# Setup dataloader
296
dataloader = fabric.setup_dataloaders(dataloader)
297
298
# Training loop with full control
299
for epoch in range(10):
300
for batch_idx, (x, y) in enumerate(dataloader):
301
optimizer.zero_grad()
302
303
# Forward pass
304
y_pred = model(x)
305
loss = nn.functional.mse_loss(y_pred, y)
306
307
# Backward pass - Fabric handles scaling and synchronization
308
fabric.backward(loss)
309
310
optimizer.step()
311
312
# Log metrics
313
if batch_idx % 10 == 0:
314
fabric.log("train_loss", loss.item())
315
fabric.print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item()}")
316
317
# Save checkpoint
318
state = {
319
"model": model.state_dict(),
320
"optimizer": optimizer.state_dict(),
321
"epoch": epoch
322
}
323
fabric.save(f"checkpoint_epoch_{epoch}.ckpt", state)
324
```
325
326
## Advanced Usage Example
327
328
```python
329
import torch
330
import torch.nn as nn
331
from torch.utils.data import DataLoader
332
from lightning import Fabric
333
334
# Initialize Fabric with advanced configuration
335
fabric = Fabric(
336
accelerator="gpu",
337
devices=4,
338
strategy="fsdp",
339
precision="16-mixed",
340
plugins=None
341
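)

# Start the distributed processes; required for multi-device runs started with plain `python`
# (omit this call when the script is started through the Lightning command-line launcher)
fabric.launch()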
342
343
class MyModel(nn.Module):
344
def __init__(self):
345
super().__init__()
346
self.layers = nn.Sequential(
347
nn.Linear(784, 256),
348
nn.ReLU(),
349
nn.Dropout(0.2),
350
nn.Linear(256, 128),
351
nn.ReLU(),
352
nn.Dropout(0.2),
353
nn.Linear(128, 10)
354
)
355
356
def forward(self, x):
357
return self.layers(x)
358
359
# Model and optimizers
360
model = MyModel()
361
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
362
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
363
364
# Setup with Fabric
365
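model, optimizer = fabric.setup(model, optimizer)

# Data: synthetic stand-in so the example is self-contained; replace with your own dataset
train_dataset = torch.utils.data.TensorDataset(torch.randn(1024, 784), torch.randint(0, 10, (1024,)))
train_dataloader = fabric.setup_dataloaders(DataLoader(train_dataset, batch_size=64, shuffle=True))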
366
367
# Training loop with advanced features
368
for epoch in range(100):
369
model.train()
370
371
for batch_idx, (data, target) in enumerate(train_dataloader):
372
optimizer.zero_grad()
373
374
output = model(data)
375
loss = nn.functional.cross_entropy(output, target)
376
377
# Backward with automatic mixed precision
378
fabric.backward(loss)
379
380
# Gradient clipping
381
fabric.clip_gradients(model, optimizer, max_norm=1.0)
382
383
optimizer.step()
384
385
# Metrics logging
386
if batch_idx % 100 == 0:
387
accuracy = (output.argmax(dim=1) == target).float().mean()
388
389
# Log metrics to the loggers passed to Fabric (values are per-process; call fabric.all_reduce first if you need a cross-process average)
390
fabric.log_dict({
391
"train_loss": loss.item(),
392
"train_acc": accuracy.item(),
393
"lr": scheduler.get_last_lr()[0]
394
})
395
396
# Print once per node (only the process with local rank 0 prints)
397
fabric.print(f"Epoch {epoch}/{100}, Batch {batch_idx}, "
398
f"Loss: {loss.item():.4f}, Acc: {accuracy.item():.4f}")
399
400
scheduler.step()
401
402
# Synchronization barrier
403
fabric.barrier()
404
405
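# Save checkpoint: fabric.save() must be called on every rank; the strategy decides which processes write
405
# Do not guard this with `fabric.is_global_zero` (a property, not a method): with FSDP a rank-0-only save would hang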
407
checkpoint = {
408
# Pass the objects themselves: fabric.save() extracts their state dicts (the FSDP strategy requires the model object)
"model": model,
409
"optimizer": optimizer,
410
"scheduler": scheduler.state_dict(),
411
"epoch": epoch,
412
}
413
fabric.save(f"model_epoch_{epoch}.ckpt", checkpoint)
414
```