Fabric is the fast and lightweight way to scale PyTorch models without boilerplate.

This module provides the main Fabric class and its associated wrapper classes, which together form the foundation for distributed PyTorch training with minimal code changes. Fabric is the main orchestrator class that handles all aspects of distributed training setup and execution.
class Fabric:
    """
    Main class for accelerating PyTorch training with minimal changes.

    Provides automatic device placement, mixed precision, distributed training,
    and seamless switching between hardware configurations.
    """

    def __init__(
        self,
        accelerator: Union[str, Accelerator] = "auto",
        strategy: Union[str, Strategy] = "auto",
        devices: Union[list[int], str, int] = "auto",
        num_nodes: int = 1,
        precision: Optional[Union[str, int]] = None,
        plugins: Optional[Union[Any, list[Any]]] = None,
        callbacks: Optional[Union[list[Any], Any]] = None,
        loggers: Optional[Union[Logger, list[Logger]]] = None
    ):
        """
        Initialize Fabric with hardware and training configuration.

        Args:
            accelerator: Hardware to run on ("cpu", "cuda", "mps", "gpu", "tpu", "auto")
            strategy: Distribution strategy ("dp", "ddp", "ddp_spawn", "deepspeed", "fsdp", "auto")
            devices: Number of devices or specific device IDs
            num_nodes: Number of nodes for multi-node training
            precision: Precision mode ("64", "32", "16-mixed", "bf16-mixed", etc.)
            plugins: Additional plugins for customization
            callbacks: Callback functions for training events
            loggers: Logger instances for experiment tracking
        """

    # -- Setup ----------------------------------------------------------------
    # Configure models, optimizers, and dataloaders for distributed training.

    def setup(
        self,
        module: nn.Module,
        *optimizers: Optimizer,
        move_to_device: bool = True,
        _reapply_compile: bool = True
    ) -> Union[_FabricModule, tuple[_FabricModule, _FabricOptimizer, ...]]:
        """
        Setup model and optimizers for distributed training.

        Args:
            module: PyTorch model to setup
            *optimizers: One or more optimizers
            move_to_device: Whether to move model to target device
            _reapply_compile: Whether to reapply torch.compile if present

        Returns:
            Fabric-wrapped module and optimizers
        """

    def setup_module(
        self,
        module: nn.Module,
        move_to_device: bool = True,
        _reapply_compile: bool = True
    ) -> _FabricModule:
        """
        Setup only the model for distributed training.

        Args:
            module: PyTorch model to setup
            move_to_device: Whether to move model to target device
            _reapply_compile: Whether to reapply torch.compile if present

        Returns:
            Fabric-wrapped module
        """

    def setup_optimizers(
        self,
        *optimizers: Optimizer
    ) -> Union[_FabricOptimizer, tuple[_FabricOptimizer, ...]]:
        """
        Setup optimizers for distributed training.

        Args:
            *optimizers: One or more optimizers to setup

        Returns:
            Fabric-wrapped optimizer(s)
        """

    def setup_dataloaders(
        self,
        *dataloaders: DataLoader,
        use_distributed_sampler: bool = True,
        move_to_device: bool = True
    ) -> Union[DataLoader, list[DataLoader]]:
        """
        Setup dataloaders for distributed training.

        Args:
            *dataloaders: One or more dataloaders to setup
            use_distributed_sampler: Whether to replace sampler for distributed training
            move_to_device: Whether to move data to target device automatically

        Returns:
            Configured dataloader(s)
        """

    # -- Training loop --------------------------------------------------------
    # Core methods for training loops including backward pass, gradient
    # clipping, and precision handling.

    def backward(
        self,
        tensor: Tensor,
        *args,
        model: Optional[_FabricModule] = None,
        **kwargs
    ) -> None:
        """
        Perform backward pass with automatic gradient scaling and accumulation.

        Args:
            tensor: Loss tensor to compute gradients for
            *args: Additional arguments passed to tensor.backward()
            model: Model to sync gradients for (auto-detected if None)
            **kwargs: Additional keyword arguments
        """

    def clip_gradients(
        self,
        module: _FabricModule,
        optimizer: _FabricOptimizer,
        clip_val: Optional[Union[int, float]] = None,
        max_norm: Optional[Union[int, float]] = None,
        norm_type: Union[int, float] = 2.0,
        error_if_nonfinite: bool = True
    ) -> Optional[Tensor]:
        """
        Clip gradients by value or norm.

        Args:
            module: Fabric-wrapped module
            optimizer: Fabric-wrapped optimizer
            clip_val: Maximum allowed value of gradients
            max_norm: Maximum allowed norm of gradients
            norm_type: Type of norm to compute (default: 2.0 for L2 norm)
            error_if_nonfinite: Whether to error on non-finite gradients

        Returns:
            Total norm of the parameters if max_norm is specified
        """

    def autocast(self) -> AbstractContextManager:
        """
        Context manager for automatic mixed precision.

        Returns:
            Context manager that applies appropriate precision casting
        """

    # -- Checkpointing --------------------------------------------------------
    # Save and load model states, optimizers, and training metadata.

    def save(
        self,
        path: _PATH,
        state: dict[str, Any],
        filter: Optional[dict[str, Any]] = None
    ) -> None:
        """
        Save checkpoint with distributed training support.

        Args:
            path: Checkpoint file path
            state: Dictionary containing model, optimizer, and other state
            filter: Optional filter for state dict keys
        """

    def load(
        self,
        path: _PATH,
        state: Optional[dict[str, Any]] = None,
        strict: bool = True
    ) -> dict[str, Any]:
        """
        Load checkpoint with distributed training support.

        Args:
            path: Checkpoint file path
            state: Dictionary to load state into (if provided)
            strict: Whether to strictly enforce state dict key matching

        Returns:
            Loaded checkpoint dictionary
        """

    def load_raw(
        self,
        path: _PATH,
        obj: Union[nn.Module, Optimizer],
        strict: bool = True
    ) -> None:
        """
        Load raw PyTorch checkpoint into object.

        Args:
            path: Checkpoint file path
            obj: Object to load state into
            strict: Whether to strictly enforce state dict key matching
        """

    # -- Process launching ----------------------------------------------------
    # Launch and coordinate distributed processes.

    def launch(
        self,
        function: Callable = lambda: None,
        *args,
        **kwargs
    ) -> Any:
        """
        Launch the distributed training processes.

        Args:
            function: Function to execute in distributed processes
            *args: Arguments to pass to function
            **kwargs: Keyword arguments to pass to function

        Returns:
            Result from function execution
        """

    def run(self, *args, **kwargs) -> Any:
        """
        Execute main training function with distributed setup.

        Args:
            *args: Arguments passed to training function
            **kwargs: Keyword arguments passed to training function

        Returns:
            Result from training function
        """

    # -- Introspection --------------------------------------------------------
    # Access information about the distributed training setup.

    @property
    def accelerator(self) -> Accelerator:
        """Current accelerator instance."""

    @property
    def strategy(self) -> Strategy:
        """Current strategy instance."""

    @property
    def device(self) -> torch.device:
        """Current device."""

    @property
    def global_rank(self) -> int:
        """Global rank of this process."""

    @property
    def local_rank(self) -> int:
        """Local rank of this process on current node."""

    @property
    def node_rank(self) -> int:
        """Rank of current node."""

    @property
    def world_size(self) -> int:
        """Total number of processes."""

    @property
    def is_global_zero(self) -> bool:
        """Whether this is the rank 0 process."""

    @property
    def loggers(self) -> list[Logger]:
        """List of all logger instances."""

    @property
    def logger(self) -> Logger:
        """Primary logger instance."""


# Fabric automatically wraps PyTorch objects to provide distributed training
# support.
class _FabricModule:
"""Wrapper for PyTorch modules with distributed training support."""
@property
def module(self) -> nn.Module:
"""Access the wrapped PyTorch module."""
def forward(self, *args, **kwargs) -> Any:
"""Forward pass with precision handling."""
def state_dict(self, **kwargs) -> dict[str, Any]:
"""Get module state dictionary."""
def load_state_dict(self, state_dict: dict, strict: bool = True) -> Any:
"""Load module state dictionary."""
class _FabricOptimizer:
"""Wrapper for PyTorch optimizers with distributed training support."""
@property
def optimizer(self) -> Optimizer:
"""Access the wrapped PyTorch optimizer."""
def step(self, closure: Optional[Callable] = None) -> Any:
"""Perform optimizer step."""
def zero_grad(self, set_to_none: bool = False) -> None:
"""Zero the gradients."""
def state_dict(self) -> dict[str, Any]:
"""Get optimizer state dictionary."""
def load_state_dict(self, state_dict: dict) -> None:
"""Load optimizer state dictionary."""
class _FabricDataLoader:
"""Wrapper for PyTorch DataLoaders with distributed training support."""
@property
def device(self) -> Optional[torch.device]:
"""Target device for data placement."""Special context managers for advanced training scenarios.
def no_backward_sync(
    self,
    module: _FabricModule,
    enabled: bool = True
) -> AbstractContextManager:
    """
    Context manager to skip gradient synchronization.

    Args:
        module: Fabric-wrapped module
        enabled: Whether to skip sync (True) or perform normal sync (False)

    Returns:
        Context manager
    """
def rank_zero_first(self, local: bool = False) -> Generator:
    """
    Context manager ensuring rank 0 executes first.

    Args:
        local: Whether to use local rank (node-level) or global rank

    Yields:
        None
    """
def init_tensor(self) -> AbstractContextManager:
    """
    Context manager for tensor initialization.

    Returns:
        Context manager for tensor initialization
    """
def init_module(self, empty_init: Optional[bool] = None) -> AbstractContextManager:
    """
    Context manager for module initialization.

    Args:
        empty_init: Whether to use empty initialization

    Returns:
        Context manager for module initialization
    """


# Log metrics and values to registered loggers for experiment tracking.
def log(self, name: str, value: Any, step: Optional[int] = None) -> None:
    """
    Log a scalar to all loggers that were added to Fabric.

    Args:
        name: The name of the metric to log
        value: The metric value to collect. If the value is a torch.Tensor, it gets detached automatically
        step: Optional step number. Most Logger implementations auto-increment this value
    """
def log_dict(self, metrics: Mapping[str, Any], step: Optional[int] = None) -> None:
    """
    Log multiple scalars at once to all loggers that were added to Fabric.

    Args:
        metrics: A dictionary where the key is the name of the metric and the value the scalar to be logged
        step: Optional step number. Most Logger implementations auto-increment this value
    """


# Invoke registered callback methods for training event handling.
def call(self, hook_name: str, *args: Any, **kwargs: Any) -> None:
    """
    Trigger the callback methods with the given name and arguments.

    Args:
        hook_name: The name of the callback method
        *args: Optional positional arguments that get passed down to the callback method
        **kwargs: Optional keyword arguments that get passed down to the callback method
    """


# Example usage (user code):
from lightning.fabric import Fabric
import torch
import torch.nn as nn

# Initialize Fabric
fabric = Fabric(accelerator="gpu", devices=2, strategy="ddp")

# Define model and optimizer
model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 10)
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Setup with Fabric
model, optimizer = fabric.setup(model, optimizer)

# NOTE(review): `dataloader`, `criterion`, and `targets` are assumed to be
# defined by the user before running this example.

# Training loop
for epoch in range(10):
    for batch in dataloader:
        x, y = batch
        optimizer.zero_grad()
        y_pred = model(x)
        loss = nn.functional.cross_entropy(y_pred, y)
        fabric.backward(loss)
        optimizer.step()

# Save checkpoint
state = {
    "model": model,
    "optimizer": optimizer,
    "epoch": epoch,
    "loss": loss.item()
}
fabric.save("checkpoint.ckpt", state)

# Load checkpoint
loaded_state = fabric.load("checkpoint.ckpt")
epoch = loaded_state["epoch"]
loss = loaded_state["loss"]

# Initialize with mixed precision
fabric = Fabric(precision="16-mixed")

# Use autocast context
for batch in dataloader:
    with fabric.autocast():
        y_pred = model(batch)
        loss = criterion(y_pred, targets)
    fabric.backward(loss)
    optimizer.step()

# Install with Tessl CLI
#   npx tessl i tessl/pypi-lightning-fabric