# Core Training

The `Fabric` class and its associated wrapper classes provide the foundation for distributed PyTorch training with minimal code changes.

## Capabilities

### Fabric Class

The main orchestrator class, which handles distributed training setup and execution.

```python { .api }
class Fabric:
    """
    Main class for accelerating PyTorch training with minimal changes.

    Provides automatic device placement, mixed precision, distributed training,
    and seamless switching between hardware configurations.
    """

    def __init__(
        self,
        *,
        accelerator: Union[str, Accelerator] = "auto",
        strategy: Union[str, Strategy] = "auto",
        devices: Union[list[int], str, int] = "auto",
        num_nodes: int = 1,
        precision: Optional[Union[str, int]] = None,
        plugins: Optional[Union[Any, list[Any]]] = None,
        callbacks: Optional[Union[list[Any], Any]] = None,
        loggers: Optional[Union[Logger, list[Logger]]] = None,
    ) -> None:
        """
        Initialize Fabric with hardware and training configuration.

        All arguments are keyword-only.

        Args:
            accelerator: Hardware to run on ("cpu", "cuda", "mps", "gpu", "tpu", "auto")
            strategy: Distribution strategy ("dp", "ddp", "ddp_spawn", "deepspeed", "fsdp", "auto")
            devices: Number of devices (int), specific device IDs (list or str), or "auto".
                The value applies per node.
            num_nodes: Number of nodes for multi-node training
            precision: Precision mode ("64", "32", "16-mixed", "bf16-mixed", etc.).
                Full 32-bit precision is used if not specified.
            plugins: A single custom plugin or a list of plugins
            callbacks: A single callback or a list of callbacks. Their methods can be
                triggered with `Fabric.call()`.
            loggers: A single logger or a list of loggers, used by `Fabric.log()`
                and `Fabric.log_dict()`
        """
```

### Setup Methods

Configure models, optimizers, and dataloaders for distributed training.

```python { .api }
def setup(
    self,
    module: nn.Module,
    *optimizers: Optimizer,
    move_to_device: bool = True,
    _reapply_compile: bool = True
) -> Union[_FabricModule, tuple[Union[_FabricModule, _FabricOptimizer], ...]]:
    """
    Set up a model and its optimizers for distributed training.

    Args:
        module: PyTorch model to set up
        *optimizers: Zero or more optimizers to set up
        move_to_device: Whether to move the model to the target device. If False,
            you are responsible for placing the model on the right device.
        _reapply_compile: If the model was compiled with `torch.compile`, whether to
            reapply compilation after the strategy has wrapped the model

    Returns:
        The Fabric-wrapped module if no optimizers were passed. Otherwise, a tuple of the
        wrapped module followed by the wrapped optimizers, in the order they were passed.
    """

def setup_module(
    self,
    module: nn.Module,
    move_to_device: bool = True,
    _reapply_compile: bool = True
) -> _FabricModule:
    """
    Set up only the model for distributed training.

    Args:
        module: PyTorch model to set up
        move_to_device: Whether to move the model to the target device
        _reapply_compile: If the model was compiled with `torch.compile`, whether to
            reapply compilation after the strategy has wrapped the model

    Returns:
        Fabric-wrapped module
    """

def setup_optimizers(
    self,
    *optimizers: Optimizer
) -> Union[_FabricOptimizer, tuple[_FabricOptimizer, ...]]:
    """
    Set up optimizers for distributed training.

    Args:
        *optimizers: One or more optimizers to set up

    Returns:
        A single wrapped optimizer if one was passed. Otherwise, a tuple of wrapped
        optimizers in the order they were passed.
    """

def setup_dataloaders(
    self,
    *dataloaders: DataLoader,
    use_distributed_sampler: bool = True,
    move_to_device: bool = True
) -> Union[DataLoader, list[DataLoader]]:
    """
    Set up dataloaders for distributed training.

    Args:
        *dataloaders: One or more dataloaders to set up
        use_distributed_sampler: Whether to automatically replace or wrap the sampler
            for distributed training. Set this to False if you use a custom sampler.
        move_to_device: Whether batches returned by the dataloaders are moved to the
            target device automatically

    Returns:
        A single dataloader if one was passed. Otherwise, a list of dataloaders in the
        order they were passed.
    """
```

### Training Operations

Core methods for training loops: backward pass, gradient clipping, and precision handling.

```python { .api }
def backward(
    self,
    tensor: Tensor,
    *args,
    model: Optional[_FabricModule] = None,
    **kwargs
) -> None:
    """
    Replace `loss.backward()` in the training loop. Precision handling
    (e.g., loss scaling for mixed precision) is applied automatically.

    Args:
        tensor: Loss tensor to back-propagate gradients from
        *args: Additional positional arguments passed to the underlying backward call
        model: Optional model for strategies or plugins that need it for the backward pass.
            Required with strategy="deepspeed" when multiple models were set up.
        **kwargs: Additional keyword arguments passed to the underlying backward call
    """

def clip_gradients(
    self,
    module: _FabricModule,
    optimizer: _FabricOptimizer,
    clip_val: Optional[Union[int, float]] = None,
    max_norm: Optional[Union[int, float]] = None,
    norm_type: Union[int, float] = 2.0,
    error_if_nonfinite: bool = True
) -> Optional[Tensor]:
    """
    Clip gradients by value or by norm. Exactly one of `clip_val` and `max_norm`
    must be given.

    Args:
        module: Fabric-wrapped module whose gradients are clipped
        optimizer: Fabric-wrapped optimizer referencing the parameters to clip
        clip_val: Maximum allowed absolute value of each gradient
        max_norm: Maximum allowed total norm of the gradients
        norm_type: Type of norm used with `max_norm` (default 2.0, the L2 norm;
            "inf" for the infinity norm)
        error_if_nonfinite: Whether to raise an error if the total gradient norm is NaN or infinite

    Returns:
        The total gradient norm (before clipping) as a scalar tensor if `max_norm`
        is given, otherwise None

    Raises:
        ValueError: If both or neither of `clip_val` and `max_norm` are given
    """

def autocast(self) -> AbstractContextManager:
    """
    Context manager that runs the enclosed operations in the chosen precision.

    The forward pass of a model set up with Fabric already runs in this precision.
    Use this context only for operations outside the model's forward method.

    Returns:
        Context manager that applies the appropriate precision casting
    """
```

### Checkpoint Management

Save and load model states, optimizers, and training metadata.

```python { .api }
def save(
    self,
    path: _PATH,
    state: dict[str, Any],
    filter: Optional[dict[str, Callable[[str, Any], bool]]] = None
) -> None:
    """
    Save a checkpoint. The strategy decides which processes write files.
    This method must be called on all processes.

    Args:
        path: Checkpoint file path
        state: Dictionary with the contents to save. Modules and optimizers in it are
            converted to their state dicts automatically.
        filter: Optional mapping from a key of `state` to a callable `(name, value) -> bool`
            that decides whether each entry of that object's state dict is saved
            (True) or dropped (False)
    """

def load(
    self,
    path: _PATH,
    state: Optional[dict[str, Any]] = None,
    strict: bool = True
) -> dict[str, Any]:
    """
    Load a checkpoint and restore the state of the given objects.
    This method must be called on all processes.

    Args:
        path: Checkpoint file path
        state: Dictionary of objects (modules, optimizers, plain values) to restore
            in place from the checkpoint. If None, the full checkpoint is returned
            and nothing is restored.
        strict: Whether to require that the keys in `state` match the keys in the checkpoint

    Returns:
        The checkpoint items that were not restored into `state`, or the full
        checkpoint if `state` is None
    """

def load_raw(
    self,
    path: _PATH,
    obj: Union[nn.Module, Optimizer],
    strict: bool = True
) -> None:
    """
    Load the state of a single module or optimizer from a raw PyTorch
    state-dict file, such as one created without Fabric.

    Args:
        path: Checkpoint file path
        obj: Module or optimizer to load the state into
        strict: Whether to require that the module's state-dict keys match the
            checkpoint. Does not apply to optimizers.
    """
```

### Process Management

Launch and coordinate distributed processes.

```python { .api }
def launch(
    self,
    function: Callable = lambda *_: None,
    *args,
    **kwargs
) -> Any:
    """
    Launch and initialize all processes needed for distributed execution.
    Not needed when the script is started with the `fabric run` CLI.

    Args:
        function: Optional function to run in the launched processes (needed for
            spawn/fork-based strategies). The Fabric object is passed as its first
            argument, followed by `*args` and `**kwargs`.
        *args: Additional positional arguments passed to `function`
        **kwargs: Keyword arguments passed to `function`

    Returns:
        The output of `function` from the process with rank 0
    """

def run(self, *args, **kwargs) -> Any:
    """
    Entry point for code accelerated by Fabric, meant to be overridden in a
    subclass of `Fabric`.

    Args:
        *args: Arbitrary positional arguments for the overriding implementation
        **kwargs: Arbitrary keyword arguments for the overriding implementation

    Returns:
        Whatever the overriding implementation returns
    """
```

### Properties

Access information about the distributed training setup.

```python { .api }
@property
def accelerator(self) -> Accelerator:
    """Current accelerator instance."""

@property
def strategy(self) -> Strategy:
    """Current strategy instance."""

@property
def device(self) -> torch.device:
    """Device this process runs on."""

@property
def global_rank(self) -> int:
    """Global rank of this process across all nodes."""

@property
def local_rank(self) -> int:
    """Rank of this process on the current node."""

@property
def node_rank(self) -> int:
    """Rank of the current node."""

@property
def world_size(self) -> int:
    """Total number of processes across all nodes."""

@property
def is_global_zero(self) -> bool:
    """Whether this is the process with global rank 0."""

@property
def loggers(self) -> list[Logger]:
    """List of all logger instances."""

@property
def logger(self) -> Logger:
    """Primary logger (the first logger passed to Fabric)."""
```

### Wrapper Classes

Fabric automatically wraps PyTorch objects to add distributed training support.

```python { .api }
class _FabricModule:
    """Wrapper for PyTorch modules with distributed training support."""

    @property
    def module(self) -> nn.Module:
        """Access the wrapped PyTorch module."""

    def forward(self, *args, **kwargs) -> Any:
        """Forward pass that runs in the precision configured in Fabric."""

    def state_dict(self, **kwargs) -> dict[str, Any]:
        """Get the module state dictionary."""

    def load_state_dict(self, state_dict: dict, strict: bool = True) -> Any:
        """Load the module state dictionary."""

class _FabricOptimizer:
    """Wrapper for PyTorch optimizers with distributed training support."""

    @property
    def optimizer(self) -> Optimizer:
        """Access the wrapped PyTorch optimizer."""

    def step(self, closure: Optional[Callable] = None) -> Any:
        """Perform an optimizer step."""

    def zero_grad(self, set_to_none: bool = True) -> None:
        """Zero the gradients. Behaves like the wrapped optimizer's `zero_grad`;
        the default `set_to_none=True` follows PyTorch 2.0 and later."""

    def state_dict(self) -> dict[str, Any]:
        """Get the optimizer state dictionary."""

    def load_state_dict(self, state_dict: dict) -> None:
        """Load the optimizer state dictionary."""

class _FabricDataLoader:
    """Wrapper for PyTorch DataLoaders with distributed training support."""

    @property
    def device(self) -> Optional[torch.device]:
        """Target device for batches, or None if batches are not moved automatically."""
```

### Context Managers

Special context managers for advanced training scenarios.

```python { .api }
def no_backward_sync(
    self,
    module: _FabricModule,
    enabled: bool = True
) -> AbstractContextManager:
    """
    Context manager that skips gradient synchronization during backward.
    Use it, for example, to avoid redundant communication when accumulating gradients.
    Both the model's forward pass and `fabric.backward()` must run inside this context.

    Args:
        module: Fabric-wrapped module
        enabled: Whether to skip the sync (True) or perform the normal sync (False)

    Returns:
        Context manager
    """

def rank_zero_first(self, local: bool = False) -> Generator:
    """
    Context manager under which rank 0 executes the enclosed code first. The other
    processes run it only after rank 0 has finished.

    Args:
        local: If True, the process with local rank 0 on each node goes first.
            This is useful when nodes don't share a filesystem. If False, only
            global rank 0 goes first.

    Yields:
        None
    """

def init_tensor(self) -> AbstractContextManager:
    """
    Context manager for tensor creation. Tensors created under it are placed
    directly on the target device, with the dtype given by the precision setting.

    Returns:
        Context manager for tensor initialization
    """

def init_module(self, empty_init: Optional[bool] = None) -> AbstractContextManager:
    """
    Context manager for model instantiation. Parameters are created directly on the
    target device with the right dtype, which reduces peak memory usage.

    Args:
        empty_init: Whether to skip weight initialization (leaving the memory
            uninitialized). If None, the strategy decides. Useful when a checkpoint
            is loaded into a large model right afterwards.

    Returns:
        Context manager for module initialization
    """
```

### Logging Methods

Log metrics and values to registered loggers for experiment tracking.

```python { .api }
def log(self, name: str, value: Any, step: Optional[int] = None) -> None:
    """
    Log a scalar to all loggers that were added to Fabric.

    Args:
        name: The name of the metric to log
        value: The metric value to collect. A torch.Tensor is detached from the graph automatically.
        step: Optional step number. Most Logger implementations auto-increment the step if it is not given.
    """

def log_dict(self, metrics: Mapping[str, Any], step: Optional[int] = None) -> None:
    """
    Log multiple scalars at once to all loggers that were added to Fabric.

    Args:
        metrics: A dictionary mapping each metric name to the scalar value to log
        step: Optional step number. Most Logger implementations auto-increment the step if it is not given.
    """
```

### Callback Management

Invoke registered callback methods to handle training events.

```python { .api }
def call(self, hook_name: str, *args: Any, **kwargs: Any) -> None:
    """
    Trigger the callback methods with the given name and arguments.
    Only callbacks that implement a method named `hook_name` are called;
    the others are skipped.

    Args:
        hook_name: The name of the callback method
        *args: Optional positional arguments passed to the callback method
        **kwargs: Optional keyword arguments passed to the callback method
    """
```

## Usage Examples

### Basic Training Setup

```python
from lightning.fabric import Fabric
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Initialize Fabric and launch the processes
# (required for multi-device runs unless the script is started with `fabric run`)
fabric = Fabric(accelerator="gpu", devices=2, strategy="ddp")
fabric.launch()

# Define model and optimizer
model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 10)
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Setup with Fabric
model, optimizer = fabric.setup(model, optimizer)

# Example data (replace with your dataset); Fabric adds a distributed sampler
# and moves batches to the right device
dataset = TensorDataset(torch.randn(1024, 784), torch.randint(0, 10, (1024,)))
dataloader = fabric.setup_dataloaders(DataLoader(dataset, batch_size=64))

# Training loop
for epoch in range(10):
    for batch in dataloader:
        x, y = batch
        optimizer.zero_grad()

        y_pred = model(x)
        loss = nn.functional.cross_entropy(y_pred, y)

        fabric.backward(loss)  # replaces loss.backward()
        optimizer.step()
```

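### Checkpoint Management

```python
# Save checkpoint (call on all processes)
state = {
    "model": model,
    "optimizer": optimizer,
    "epoch": epoch,
    "loss": loss.item()
}
fabric.save("checkpoint.ckpt", state)

# Load checkpoint: model and optimizer states are restored in place, and plain
# values in `state` are overwritten with the saved ones. The return value holds
# only checkpoint items that were not in `state`.
remainder = fabric.load("checkpoint.ckpt", state)
epoch = state["epoch"]
loss = state["loss"]

# Without `state`, fabric.load() returns the full checkpoint as a dictionary
# and restores nothing.
```
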
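### Mixed Precision Training

```python
# Initialize with mixed precision
fabric = Fabric(precision="16-mixed")

# model, optimizer, criterion, and dataloader are plain PyTorch objects here;
# set them up with this Fabric instance so its precision setting applies
model, optimizer = fabric.setup(model, optimizer)
dataloader = fabric.setup_dataloaders(dataloader)

for batch in dataloader:
    x, targets = batch
    optimizer.zero_grad()

    # The forward pass of a Fabric-wrapped model already runs in the chosen precision.
    # Use fabric.autocast() for extra operations outside the model, such as the loss.
    y_pred = model(x)
    with fabric.autocast():
        loss = criterion(y_pred, targets)

    fabric.backward(loss)
    optimizer.step()
```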