Fabric is the fast and lightweight way to scale PyTorch models without boilerplate
```bash
npx @tessl/cli install tessl/pypi-lightning-fabric@2.5.0
```

Lightning Fabric is a lightweight PyTorch scaling library that gives expert-level control over PyTorch training loops and scaling strategies. It enables developers to scale complex models, including foundation models, LLMs, diffusion models, transformers, and reinforcement learning workloads, across any device and scale without boilerplate code.
```bash
pip install lightning-fabric
```

```python
from lightning.fabric import Fabric, seed_everything, is_wrapped
```

Additional commonly used utilities:

```python
from lightning.fabric.utilities import (
move_data_to_device,
suggested_max_num_workers,
rank_zero_only,
rank_zero_warn,
rank_zero_info
)
```
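A short, hedged sketch of how these utilities are commonly combined (the `fabric.device` property used here is part of Fabric's public API but is not listed in the stubs below):

```python
import torch
from lightning.fabric import Fabric
from lightning.fabric.utilities import move_data_to_device, rank_zero_info, rank_zero_only

fabric = Fabric(accelerator="auto", devices="auto")

# Only the rank-0 process logs, avoiding duplicate output in multi-process runs
rank_zero_info(f"Running on device: {fabric.device}")

@rank_zero_only
def log_hyperparameters(config: dict) -> None:
    print(config)

log_hyperparameters({"lr": 1e-3, "batch_size": 32})

# Recursively move tensors (including those nested in dicts/lists) to a device
batch = {"x": torch.randn(8, 10), "y": torch.randn(8, 1)}
batch = move_data_to_device(batch, fabric.device)
```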
A minimal end-to-end training example:

```python
from lightning.fabric import Fabric
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
# Initialize Fabric with desired configuration
fabric = Fabric(accelerator="auto", devices="auto", strategy="auto")
# Define model and optimizer
model = nn.Linear(10, 1)
optimizer = torch.optim.AdamW(model.parameters())
# Setup with Fabric (handles device placement and distributed wrapping)
model, optimizer = fabric.setup(model, optimizer)
# Setup dataloader
dataloader = fabric.setup_dataloaders(DataLoader(...))
# Training loop
model.train()
for batch in dataloader:
    x, y = batch
    optimizer.zero_grad()

    # Forward pass
    y_pred = model(x)
    loss = nn.functional.mse_loss(y_pred, y)

    # Backward pass with automatic scaling and gradient handling
    fabric.backward(loss)
    optimizer.step()
# Save checkpoint
state = {"model": model, "optimizer": optimizer}
fabric.save("checkpoint.ckpt", state)
```

Lightning Fabric uses a plugin-based architecture that enables flexible scaling and customization: accelerators, strategies, and precision plugins (documented below) can be combined freely or replaced with custom implementations.
This plugin system allows Fabric to work across any hardware configuration and distributed training setup while maintaining the same simple API.
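As an illustration, the same training code can be retargeted just by changing constructor arguments. A hedged sketch; the string values shown are common Fabric options, but verify them against your installed version:

```python
from lightning.fabric import Fabric

# Single GPU with bfloat16 mixed precision
fabric = Fabric(accelerator="cuda", devices=1, precision="bf16-mixed")

# Four GPUs with distributed data parallel
fabric = Fabric(accelerator="cuda", devices=4, strategy="ddp")

# Two nodes with eight GPUs each, fully sharded data parallel
fabric = Fabric(accelerator="cuda", devices=8, num_nodes=2, strategy="fsdp")
```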
Main Fabric class that handles distributed training setup, model and optimizer configuration, checkpoint management, and training utilities.
```python
class Fabric:
    def __init__(
        self,
        accelerator: Union[str, Accelerator] = "auto",
        strategy: Union[str, Strategy] = "auto",
        devices: Union[list[int], str, int] = "auto",
        num_nodes: int = 1,
        precision: Optional[Union[str, int]] = None,
        plugins: Optional[Union[Any, list[Any]]] = None,
        callbacks: Optional[Union[list[Any], Any]] = None,
        loggers: Optional[Union[Logger, list[Logger]]] = None,
    ): ...

    def setup(self, module, *optimizers, move_to_device=True): ...
    def setup_module(self, module, move_to_device=True): ...
    def setup_optimizers(self, *optimizers): ...
    def setup_dataloaders(self, *dataloaders, use_distributed_sampler=True): ...
    def backward(self, tensor, *args, model=None, **kwargs): ...
    def save(self, path, state, filter=None): ...
    def load(self, path, state=None, strict=True): ...
```
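A sketch of restoring training state with `save`/`load`. It assumes `fabric.load` restores the objects referenced in `state` in place and returns any remaining checkpoint entries, which is how the stub above is typically used; verify against your installed version:

```python
import torch
import torch.nn as nn
from lightning.fabric import Fabric

fabric = Fabric(accelerator="auto", devices=1)

model = nn.Linear(10, 1)
optimizer = torch.optim.AdamW(model.parameters())
model, optimizer = fabric.setup(model, optimizer)

# Model and optimizer weights are restored in place; any checkpoint entries
# not listed in `state` are returned as a plain dictionary.
state = {"model": model, "optimizer": optimizer}
remainder = fabric.load("checkpoint.ckpt", state)
```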
Collective communication operations for synchronizing data and gradients across processes in distributed training.

```python
def barrier(self, name=None) -> None: ...
def broadcast(self, obj, src=0): ...
def all_gather(self, data, group=None, sync_grads=False): ...
def all_reduce(self, data, group=None, reduce_op="mean"): ...
```
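A minimal sketch of the collectives inside a multi-process run. `fabric.launch()` and `fabric.device` are part of Fabric's API but are not listed in the stubs above, so treat their use here as an assumption:

```python
import torch
from lightning.fabric import Fabric

fabric = Fabric(accelerator="auto", devices=2, strategy="ddp")
fabric.launch()  # start/join the distributed processes (assumed API)

# Wait until every process reaches this point
fabric.barrier()

# Average a scalar metric across all processes
local_loss = torch.tensor(0.123, device=fabric.device)
global_loss = fabric.all_reduce(local_loss, reduce_op="mean")

# Share a Python object from rank 0 with every other rank
config = fabric.broadcast({"lr": 1e-3}, src=0)

# Gather a tensor from all ranks; the result gains a leading world-size dimension
all_losses = fabric.all_gather(local_loss)
```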
Hardware acceleration plugins for different compute devices including CPU, CUDA GPUs, Apple MPS, and TPUs.

```python
class Accelerator: ...  # Abstract base
class CPUAccelerator(Accelerator): ...
class CUDAAccelerator(Accelerator): ...
class MPSAccelerator(Accelerator): ...
class XLAAccelerator(Accelerator): ...
```
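A hedged sketch of selecting an accelerator at runtime. The `is_available()` checks and the `lightning.fabric.accelerators` import path are assumptions, as they are not part of the stubs above:

```python
from lightning.fabric import Fabric
from lightning.fabric.accelerators import CUDAAccelerator, MPSAccelerator  # assumed import path

# Assumption: each accelerator class exposes an is_available() check
if CUDAAccelerator.is_available():
    fabric = Fabric(accelerator="cuda", devices=1)
elif MPSAccelerator.is_available():
    fabric = Fabric(accelerator="mps", devices=1)
else:
    fabric = Fabric(accelerator="cpu")

# Accelerator instances can also be passed directly, per the constructor signature
fabric = Fabric(accelerator=CUDAAccelerator(), devices=2)
```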
Distributed training strategies for scaling models across devices and nodes.

```python
class Strategy: ...  # Abstract base
class SingleDeviceStrategy(Strategy): ...
class DataParallelStrategy(Strategy): ...
class DDPStrategy(Strategy): ...
class DeepSpeedStrategy(Strategy): ...
class FSDPStrategy(Strategy): ...
class XLAStrategy(Strategy): ...
```
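Strategies can be selected by name or passed as configured instances. A sketch; the `lightning.fabric.strategies` import path and the keyword arguments shown are illustrative assumptions:

```python
from lightning.fabric import Fabric
from lightning.fabric.strategies import DDPStrategy, FSDPStrategy  # assumed import path

# By name: Fabric instantiates the strategy with defaults
fabric = Fabric(devices=4, strategy="ddp")

# As instances: pass strategy-specific options explicitly (arguments assumed)
fabric = Fabric(devices=4, strategy=DDPStrategy(find_unused_parameters=True))
fabric = Fabric(devices=8, strategy=FSDPStrategy(state_dict_type="sharded"))
```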
Precision plugins for mixed precision training, quantization, and memory optimization.

```python
class Precision: ...  # Abstract base
class DoublePrecision(Precision): ...
class HalfPrecision(Precision): ...
class MixedPrecision(Precision): ...
class BitsandbytesPrecision(Precision): ...
class DeepSpeedPrecision(Precision): ...
class FSDPPrecision(Precision): ...
```
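Precision is usually selected with a string on the `Fabric` constructor rather than by instantiating these plugins directly. A sketch using common precision identifiers; verify the exact strings and the plugin import path against your installed version:

```python
from lightning.fabric import Fabric

fabric = Fabric(precision="32-true")     # full fp32 (default)
fabric = Fabric(precision="16-mixed")    # fp16 mixed precision with gradient scaling
fabric = Fabric(precision="bf16-mixed")  # bfloat16 mixed precision
fabric = Fabric(precision="64-true")     # double precision

# Plugin instances can also be passed via `plugins=` (import path assumed)
from lightning.fabric.plugins import HalfPrecision
fabric = Fabric(plugins=HalfPrecision())
```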
Helper functions for seeding, data movement, distributed utilities, and performance monitoring.

```python
def seed_everything(seed=None, workers=False, verbose=True) -> int: ...
def is_wrapped(obj) -> bool: ...
def move_data_to_device(obj, device): ...
def suggested_max_num_workers(num_cpus): ...
```
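A short sketch tying these helpers together. The argument passed to `suggested_max_num_workers` follows the stub above and should be treated as an assumption:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from lightning.fabric import Fabric, seed_everything, is_wrapped
from lightning.fabric.utilities import suggested_max_num_workers

# Seed Python, NumPy, and PyTorch (and, with workers=True, dataloader workers)
seed_everything(42, workers=True)

fabric = Fabric(accelerator="auto", devices="auto")
model = fabric.setup_module(nn.Linear(10, 1))
assert is_wrapped(model)  # True once a module has been set up by Fabric

# Size dataloader workers to the machine layout (argument per the stub above)
dataset = TensorDataset(torch.randn(64, 10), torch.randn(64, 1))
loader = DataLoader(dataset, batch_size=8, num_workers=suggested_max_num_workers(1))
```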
```python
# Common type aliases used throughout the API
_PATH = Union[str, Path]
_DEVICE = Union[torch.device, str, int]
_MAP_LOCATION_TYPE = Optional[Union[_DEVICE, Callable, dict[_DEVICE, _DEVICE]]]
_PARAMETERS = Iterator[torch.nn.Parameter]
ReduceOp = torch.distributed.ReduceOp
RedOpType = ReduceOp.RedOpType
# Protocols for type checking
@runtime_checkable
class _Stateful(Protocol[_DictKey]):
    def state_dict(self) -> dict[_DictKey, Any]: ...
    def load_state_dict(self, state_dict: dict[_DictKey, Any]) -> None: ...

@runtime_checkable
class Steppable(Protocol):
    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: ...

@runtime_checkable
class Optimizable(Steppable, Protocol):
    param_groups: list[dict[Any, Any]]
    defaults: dict[Any, Any]
    state: defaultdict[Tensor, Any]

    def state_dict(self) -> dict[str, dict[Any, Any]]: ...
    def load_state_dict(self, state_dict: dict[str, dict[Any, Any]]) -> None: ...

@runtime_checkable
class CollectibleGroup(Protocol):
    def size(self) -> int: ...
    def rank(self) -> int: ...
```
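As an illustration of the `_Stateful` protocol, any object that exposes `state_dict`/`load_state_dict` can be listed in the `state` dictionary passed to `fabric.save` and `fabric.load`. This behavior is implied by the protocol and the `save`/`load` signatures above, so treat the sketch as an assumption and verify it against your installed version:

```python
import torch.nn as nn
from lightning.fabric import Fabric

class EMATracker:
    """Structurally satisfies the _Stateful protocol; no inheritance required."""

    def __init__(self) -> None:
        self.value = 0.0

    def state_dict(self) -> dict:
        return {"value": self.value}

    def load_state_dict(self, state_dict: dict) -> None:
        self.value = state_dict["value"]

fabric = Fabric(accelerator="cpu")
model = fabric.setup_module(nn.Linear(10, 1))
tracker = EMATracker()

# Assumption: Fabric checkpoints any stateful object found in `state`
fabric.save("checkpoint.ckpt", {"model": model, "ema": tracker})
fabric.load("checkpoint.ckpt", {"model": model, "ema": tracker})
```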