Unified deep learning framework integrating PyTorch Lightning, Lightning Fabric, and Lightning Apps for training, deploying, and shipping AI products.
---
Lightning Fabric provides fine-grained control over training loops while automatically handling device management, distributed training setup, and gradient synchronization. This enables custom training logic with minimal boilerplate code.
Core abstraction that handles device management, distributed training setup, mixed precision, and gradient synchronization while giving you full control over the training loop.
class Fabric:
    """Core abstraction for low-level training control.

    Handles device placement, distributed process setup, mixed precision,
    and gradient synchronization while leaving the training loop itself
    entirely to the caller.
    """

    def __init__(
        self,
        accelerator: str = "auto",
        devices: Union[int, str, List[int]] = "auto",
        num_nodes: int = 1,
        strategy: Optional[str] = None,
        precision: Optional[str] = None,
        plugins: Optional[Union[str, list]] = None,
        callbacks: Optional[Union[List, dict]] = None,
        loggers: Optional[Union[Logger, List[Logger]]] = None
    ):
        """Initialize Fabric for low-level training control.

        Args:
            accelerator: Hardware accelerator ('cpu', 'gpu', 'tpu', 'auto').
            devices: Device specification (int, list of indices, or 'auto').
            num_nodes: Number of machines for multi-node distributed training.
            strategy: Distributed training strategy name.
            precision: Training precision ('16-mixed', '32', '64', 'bf16-mixed').
            plugins: Additional plugins for custom functionality.
            callbacks: Callback instances invoked at training hooks.
            loggers: Logger instances for experiment tracking.
        """

    def setup(
        self,
        model: nn.Module,
        *optimizers: Optimizer
    ) -> Union[nn.Module, Tuple[nn.Module, ...]]:
        """Set up a model and its optimizers for distributed training.

        Args:
            model: PyTorch model to set up.
            optimizers: Optimizer instances to set up alongside the model.

        Returns:
            The configured model, or a tuple of the model plus optimizers.
        """

    def setup_dataloaders(
        self,
        *dataloaders: DataLoader
    ) -> Union[DataLoader, Tuple[DataLoader, ...]]:
        """Set up dataloaders for distributed training.

        Args:
            dataloaders: DataLoader instances to set up.

        Returns:
            The configured dataloader(s).
        """

    def backward(self, loss: torch.Tensor) -> None:
        """Run the backward pass with automatic gradient scaling.

        Args:
            loss: Loss tensor to compute gradients for.
        """

    def step(self, optimizer: Optimizer, *args, **kwargs) -> None:
        """Step the optimizer with gradient unscaling and synchronization.

        Args:
            optimizer: Optimizer to step.
            args: Positional arguments forwarded to ``optimizer.step()``.
            kwargs: Keyword arguments forwarded to ``optimizer.step()``.
        """

    def clip_gradients(
        self,
        model: nn.Module,
        optimizer: Optimizer,
        max_norm: Union[float, int],
        norm_type: Union[float, int] = 2.0,
        error_if_nonfinite: bool = True
    ) -> torch.Tensor:
        """Clip gradients by norm.

        Args:
            model: Model whose gradients to clip.
            optimizer: Optimizer associated with the model.
            max_norm: Maximum allowed gradient norm.
            norm_type: Type of norm to compute (default: L2).
            error_if_nonfinite: Raise an error for non-finite gradients.

        Returns:
            The total norm of the gradients before clipping.
        """

    def save(self, path: str, state: dict) -> None:
        """Save training state to a checkpoint file.

        Args:
            path: Destination path for the checkpoint.
            state: Dictionary containing model/optimizer states.
        """

    def load(self, path: str) -> dict:
        """Load training state from a checkpoint file.

        Args:
            path: Path to the checkpoint file.

        Returns:
            Dictionary containing the loaded state.
        """

    def barrier(self, name: Optional[str] = None) -> None:
        """Synchronize all processes.

        Args:
            name: Optional barrier name for debugging.
        """

    def broadcast(self, obj: Any, src: int = 0) -> Any:
        """Broadcast an object from the source rank to all ranks.

        Args:
            obj: Object to broadcast.
            src: Source rank (default: 0).

        Returns:
            The broadcasted object on every rank.
        """

    def all_gather(self, data: Any, group: Optional[Any] = None) -> List[Any]:
        """Gather data from all processes.

        Args:
            data: Data to gather.
            group: Optional process group.

        Returns:
            List of gathered data, one entry per process.
        """

    def all_reduce(
        self,
        tensor: torch.Tensor,
        op: str = "sum",
        group: Optional[Any] = None
    ) -> torch.Tensor:
        """Reduce a tensor across all processes.

        Args:
            tensor: Tensor to reduce.
            op: Reduction operation ('sum', 'mean', 'max', 'min').
            group: Optional process group.

        Returns:
            The reduced tensor.
        """

    def log(self, name: str, value: Any, step: Optional[int] = None) -> None:
        """Log a metric to the configured loggers.

        Args:
            name: Metric name.
            value: Metric value.
            step: Training step (auto-incremented if None).
        """

    def log_dict(self, metrics: dict, step: Optional[int] = None) -> None:
        """Log multiple metrics at once.

        Args:
            metrics: Dictionary of metric names and values.
            step: Training step (auto-incremented if None).
        """

    def print(self, *args, **kwargs) -> None:
        """Print only on global rank 0 in distributed training.

        Args:
            args: Positional arguments passed to ``print()``.
            kwargs: Keyword arguments passed to ``print()``.
        """

    @property
    def device(self) -> torch.device:
        """Current device."""

    @property
    def global_rank(self) -> int:
        """Global rank of the current process."""

    @property
    def local_rank(self) -> int:
        """Local rank of the current process."""

    @property
    def node_rank(self) -> int:
        """Node rank of the current process."""

    @property
    def world_size(self) -> int:
        """Total number of processes."""

    @property
    def is_global_zero(self) -> bool:
        """Whether the current process is global rank 0."""


def seed_everything(seed: int, workers: bool = False) -> int:
    """Seed all random number generators for reproducibility.

    Args:
        seed: Random seed value.
        workers: Also seed dataloader worker processes.

    Returns:
        The seed value used.
    """


import torch
# Example: custom training loop with Fabric (multi-GPU, fp16 mixed precision).
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import lightning.fabric as L

# Initialize Fabric and launch the distributed processes.
fabric = L.Fabric(accelerator="gpu", devices=2, precision="16-mixed")
fabric.launch()

# Create model, optimizer, and data.
model = nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters())
dataset = torch.randn(1000, 10), torch.randn(1000, 1)
dataloader = DataLoader(list(zip(*dataset)), batch_size=32)

# Setup for distributed training: Fabric moves the model to the right
# device and wires the dataloader with a distributed sampler.
model, optimizer = fabric.setup(model, optimizer)
dataloader = fabric.setup_dataloaders(dataloader)

# Custom training loop — Fabric replaces loss.backward()/optimizer.step()
# with calls that handle gradient scaling and synchronization.
model.train()
for epoch in range(10):
    epoch_loss = 0.0
    for batch_idx, (x, y) in enumerate(dataloader):
        # Forward pass
        output = model(x)
        loss = nn.functional.mse_loss(output, y)
        # Backward pass
        optimizer.zero_grad()
        fabric.backward(loss)
        fabric.step(optimizer)
        epoch_loss += loss.item()
        # Log metrics every 10 batches.
        if batch_idx % 10 == 0:
            fabric.log("train_loss", loss.item())
    fabric.print(f"Epoch {epoch}: Loss = {epoch_loss / len(dataloader)}")

import lightning.fabric as L
# Example: saving and resuming training state with Fabric.
import lightning.fabric as L

fabric = L.Fabric()
fabric.launch()

model = nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters())
model, optimizer = fabric.setup(model, optimizer)

# Training loop with checkpointing.
for epoch in range(100):
    # ... training code ...
    # Save a checkpoint every 10 epochs; the state dict may contain
    # modules and optimizers directly, plus plain metadata like the epoch.
    if epoch % 10 == 0:
        state = {
            "model": model,
            "optimizer": optimizer,
            "epoch": epoch,
        }
        fabric.save(f"checkpoint_epoch_{epoch}.ckpt", state)

# Resume from checkpoint: load() returns a dict of the saved state.
checkpoint = fabric.load("checkpoint_epoch_50.ckpt")
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
start_epoch = checkpoint["epoch"] + 1

import lightning.fabric as L
# Example: collective communication across 4 DDP processes.
import lightning.fabric as L

fabric = L.Fabric(devices=4, strategy="ddp")
fabric.launch()

# Broadcast configuration from rank 0 so every process uses the same values.
if fabric.global_rank == 0:
    config = {"learning_rate": 0.001, "batch_size": 32}
else:
    config = None
config = fabric.broadcast(config, src=0)

# Gather metrics from all processes (one entry per rank).
local_metrics = {"accuracy": 0.95, "loss": 0.1}
all_metrics = fabric.all_gather(local_metrics)

# Reduce a tensor across all processes (element-wise mean).
local_tensor = torch.tensor([1.0, 2.0, 3.0])
reduced_tensor = fabric.all_reduce(local_tensor, op="mean")
fabric.print(f"Reduced tensor: {reduced_tensor}")

import lightning.fabric as L
# Example: automatic mixed-precision training with Fabric.
import lightning.fabric as L

# Enable fp16 mixed precision.
fabric = L.Fabric(precision="16-mixed")
fabric.launch()

model = nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters())
model, optimizer = fabric.setup(model, optimizer)

# Training loop with automatic mixed precision.
# NOTE(review): `dataloader` is assumed to be defined as in the earlier
# example — confirm before running standalone.
for epoch in range(10):
    for batch in dataloader:
        x, y = batch
        # Forward pass (automatically runs under the autocast context).
        output = model(x)
        loss = nn.functional.mse_loss(output, y)
        # Backward pass: fabric.backward handles gradient scaling,
        # fabric.step handles unscaling before the optimizer update.
        optimizer.zero_grad()
        fabric.backward(loss)
        fabric.step(optimizer)

import lightning.fabric as L
# Example: plugging in a custom distributed strategy (DeepSpeed ZeRO stage 2).
from lightning.fabric.strategies import DeepSpeedStrategy

# Pass a strategy instance instead of a string for fine-grained configuration.
strategy = DeepSpeedStrategy(stage=2)
fabric = L.Fabric(strategy=strategy, precision="16-mixed")
fabric.launch()

model = nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters())
model, optimizer = fabric.setup(model, optimizer)

# Training proceeds normally - Fabric handles strategy-specific details.
for epoch in range(10):
    for batch in dataloader:
        # ... training code ...
        pass

# Install with Tessl CLI
npx tessl i tessl/pypi-pytorch-lightning