# Lightning Fabric

Lightning Fabric is a lightweight PyTorch library for scaling training code while keeping expert-level control over the training loop and the scaling strategy. It lets developers scale complex models, including foundation models, LLMs, diffusion models, transformers, and reinforcement learning agents, across devices and multi-node clusters with minimal boilerplate.

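## Package Information

- **Package Name**: lightning-fabric
- **Language**: Python
- **Installation**: `pip install lightning-fabric` (standalone package, imported as `lightning_fabric`) or `pip install lightning` (provides the `lightning.fabric` namespace used throughout this document)
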
## Core Imports

```python
from lightning.fabric import Fabric, seed_everything, is_wrapped
```

Additional commonly used utilities:

```python
from lightning.fabric.utilities import (
    move_data_to_device,
    suggested_max_num_workers,
    rank_zero_only,
    rank_zero_warn,
    rank_zero_info,
)
```

## Basic Usage

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from lightning.fabric import Fabric

# Initialize Fabric with the desired configuration
fabric = Fabric(accelerator="auto", devices="auto", strategy="auto")
# Start the worker processes (required when more than one device is selected)
fabric.launch()

# Define model and optimizer
model = nn.Linear(10, 1)
optimizer = torch.optim.AdamW(model.parameters())

# Set up with Fabric (handles device placement and distributed wrapping)
model, optimizer = fabric.setup(model, optimizer)

# Set up the dataloader (random toy data for illustration)
dataset = TensorDataset(torch.randn(64, 10), torch.randn(64, 1))
dataloader = fabric.setup_dataloaders(DataLoader(dataset, batch_size=8))

# Training loop
model.train()
for batch in dataloader:
    x, y = batch
    optimizer.zero_grad()

    # Forward pass
    y_pred = model(x)
    loss = nn.functional.mse_loss(y_pred, y)

    # Backward pass with precision-aware loss scaling and strategy-specific gradient handling
    fabric.backward(loss)
    optimizer.step()

# Save checkpoint
state = {"model": model, "optimizer": optimizer}
fabric.save("checkpoint.ckpt", state)
```

## Architecture

Lightning Fabric uses a plugin-based architecture that enables flexible scaling and customization:

- **Fabric**: Main orchestrator class that coordinates all components
- **Accelerators**: Hardware abstraction (CPU, GPU, TPU, MPS)
- **Strategies**: Distribution patterns (single device, data parallel, model parallel, FSDP, DeepSpeed)
- **Precision Plugins**: Mixed precision and quantization support
- **Environment Plugins**: Cluster environment detection and configuration
- **Loggers**: Experiment tracking and metric logging
- **Wrappers**: Transparent wrapping of PyTorch objects for distributed training

Because these components are swappable, Fabric exposes the same API across all supported hardware and distributed training setups.

## Capabilities

### Core Training Orchestration

The main `Fabric` class handles distributed training setup, model and optimizer configuration, checkpoint management, and training utilities.

```python { .api }
class Fabric:
    def __init__(
        self,
        *,
        accelerator: Union[str, Accelerator] = "auto",
        strategy: Union[str, Strategy] = "auto",
        devices: Union[list[int], str, int] = "auto",
        num_nodes: int = 1,
        precision: Optional[Union[str, int]] = None,
        plugins: Optional[Union[Any, list[Any]]] = None,
        callbacks: Optional[Union[list[Any], Any]] = None,
        loggers: Optional[Union[Logger, list[Logger]]] = None,
    ): ...

    def setup(self, module, *optimizers, move_to_device=True): ...
    def setup_module(self, module, move_to_device=True): ...
    def setup_optimizers(self, *optimizers): ...
    def setup_dataloaders(self, *dataloaders, use_distributed_sampler=True): ...
    def backward(self, tensor, *args, model=None, **kwargs): ...
    def save(self, path, state, filter=None): ...
    def load(self, path, state=None, strict=True): ...
```

[Core Training](./core-training.md)

### Distributed Operations

Collective communication methods on `Fabric` for synchronizing data and gradients across processes in distributed training.

```python { .api }
def barrier(self, name=None) -> None: ...
def broadcast(self, obj, src=0): ...
def all_gather(self, data, group=None, sync_grads=False): ...
def all_reduce(self, data, group=None, reduce_op="mean"): ...
```

[Distributed Operations](./distributed.md)

### Accelerators

Hardware accelerator plugins for CPUs, CUDA GPUs, Apple Silicon (MPS), and TPUs (through XLA).

```python { .api }
class Accelerator: ...  # Abstract base
class CPUAccelerator(Accelerator): ...
class CUDAAccelerator(Accelerator): ...
class MPSAccelerator(Accelerator): ...
class XLAAccelerator(Accelerator): ...
```

[Accelerators](./accelerators.md)

### Strategies

Distributed training strategies for scaling models across devices and nodes.

```python { .api }
class Strategy: ...  # Abstract base
class SingleDeviceStrategy(Strategy): ...
class DataParallelStrategy(Strategy): ...
class DDPStrategy(Strategy): ...
class DeepSpeedStrategy(Strategy): ...
class FSDPStrategy(Strategy): ...
class XLAStrategy(Strategy): ...
```

[Strategies](./strategies.md)

### Precision and Quantization

Precision plugins for mixed precision training, quantization, and memory optimization.

```python { .api }
class Precision: ...  # Abstract base
class DoublePrecision(Precision): ...
class HalfPrecision(Precision): ...
class MixedPrecision(Precision): ...
class BitsandbytesPrecision(Precision): ...
class DeepSpeedPrecision(Precision): ...
class FSDPPrecision(Precision): ...
```

[Precision](./precision.md)

### Utilities

Helper functions for seeding, data movement, distributed utilities, and performance monitoring.

```python { .api }
def seed_everything(seed=None, workers=False, verbose=True) -> int: ...
def is_wrapped(obj) -> bool: ...
def move_data_to_device(batch, device): ...
def suggested_max_num_workers(local_world_size) -> int: ...
```

[Utilities](./utilities.md)

## Types

```python { .api }
# Imports and TypeVar needed for the listing below to stand on its own
from collections import defaultdict
from pathlib import Path
from typing import Any, Callable, Iterator, Optional, Protocol, TypeVar, Union, runtime_checkable

import torch
from torch import Tensor

_DictKey = TypeVar("_DictKey")

# Common type aliases used throughout the API
_PATH = Union[str, Path]
_DEVICE = Union[torch.device, str, int]
_MAP_LOCATION_TYPE = Optional[Union[_DEVICE, Callable, dict[_DEVICE, _DEVICE]]]
_PARAMETERS = Iterator[torch.nn.Parameter]
ReduceOp = torch.distributed.ReduceOp
RedOpType = ReduceOp.RedOpType

# Protocols for type checking
@runtime_checkable
class _Stateful(Protocol[_DictKey]):
    def state_dict(self) -> dict[_DictKey, Any]: ...
    def load_state_dict(self, state_dict: dict[_DictKey, Any]) -> None: ...

@runtime_checkable
class Steppable(Protocol):
    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: ...

@runtime_checkable
class Optimizable(Steppable, Protocol):
    param_groups: list[dict[Any, Any]]
    defaults: dict[Any, Any]
    state: defaultdict[Tensor, Any]

    def state_dict(self) -> dict[str, dict[Any, Any]]: ...
    def load_state_dict(self, state_dict: dict[str, dict[Any, Any]]) -> None: ...

@runtime_checkable
class CollectibleGroup(Protocol):
    def size(self) -> int: ...
    def rank(self) -> int: ...
```