Fabric is the fast and lightweight way to scale PyTorch models without boilerplate
—
Collective communication operations and utilities for coordinating processes in distributed training environments.
Core collective operations for synchronizing data and computations across distributed processes.
def barrier(self, name: Optional[str] = None) -> None:
    """Synchronize all processes at this point.

    Blocks until all processes reach this barrier. Useful for ensuring
    all processes complete a phase before proceeding.

    Args:
        name: Optional name for the barrier (for debugging)

    Raises:
        RuntimeError: If barrier times out or fails
    """
def broadcast(self, obj: Any, src: int = 0) -> Any:
    """Broadcast object from source process to all other processes.

    Args:
        obj: Object to broadcast (tensor, dict, list, etc.)
        src: Source process rank (default: 0)

    Returns:
        The broadcasted object on all processes

    Examples:
        # Broadcast model parameters from rank 0
        params = fabric.broadcast(model.state_dict(), src=0)

        # Broadcast configuration dictionary
        config = fabric.broadcast({"lr": 0.001, "batch_size": 32}, src=0)
    """
def all_gather(
    self,
    data: Union[Tensor, dict, list, tuple],
    group: Optional[Any] = None,
    sync_grads: bool = False
) -> Union[Tensor, dict, list, tuple]:
    """Gather data from all processes and concatenate.

    Each process contributes its data, and all processes receive
    the concatenated result from all processes.

    Args:
        data: Data to gather (tensor, dict, list, or tuple)
        group: Process group (None for default group)
        sync_grads: Whether to synchronize gradients

    Returns:
        Gathered data from all processes

    Examples:
        # Gather predictions from all processes
        local_preds = model(batch)
        all_preds = fabric.all_gather(local_preds)

        # Gather metrics dictionary
        local_metrics = {"accuracy": 0.95, "loss": 0.1}
        all_metrics = fabric.all_gather(local_metrics)
    """
def all_reduce(
self,
data: Union[Tensor, dict, list, tuple],
group: Optional[Any] = None,
reduce_op: Union[str, ReduceOp] = "mean"
) -> Union[Tensor, dict, list, tuple]:
"""
Reduce data across all processes using specified operation.
Applies reduction operation (sum, mean, max, min) across all processes
and returns the result to all processes.
Args:
data: Data to reduce (tensor, dict, list, or tuple)
group: Process group (None for default group)
reduce_op: Reduction operation ("sum", "mean", "max", "min")
Returns:
Reduced data
Examples:
# Average loss across all processes
local_loss = compute_loss(batch)
avg_loss = fabric.all_reduce(local_loss, reduce_op="mean")
# Sum gradients across processes
grads = fabric.all_reduce(gradients, reduce_op="sum")
"""Higher-level utilities for process coordination and data movement.
def to_device(self, obj: Union[nn.Module, Tensor, Any]) -> Union[nn.Module, Tensor, Any]:
    """Move object to the appropriate device.

    Automatically handles device placement for tensors, modules,
    and nested data structures.

    Args:
        obj: Object to move to device

    Returns:
        Object moved to target device

    Examples:
        # Move tensor to device
        tensor = torch.randn(10, 10)
        tensor = fabric.to_device(tensor)

        # Move nested data structure
        data = {"input": torch.randn(32, 784), "target": torch.randint(0, 10, (32,))}
        data = fabric.to_device(data)
    """
def print(self, *args, **kwargs) -> None:
"""
Print only from rank 0 process.
Prevents duplicate printing in distributed training by only
allowing the rank 0 process to print.
Args:
*args: Arguments to print
**kwargs: Keyword arguments for print function
Examples:
fabric.print(f"Epoch {epoch}, Loss: {loss:.4f}")
fabric.print("Training completed!", file=sys.stderr)
"""Context managers and advanced coordination primitives.
def rank_zero_first(self, local: bool = False) -> Generator:
    """Context manager ensuring rank 0 executes first.

    Useful for operations that should be performed by one process first
    (e.g., dataset preparation, model initialization).

    Args:
        local: If True, use local rank (within node), otherwise global rank

    Yields:
        None

    Examples:
        # Download dataset only on rank 0 first
        with fabric.rank_zero_first():
            dataset = download_dataset()

        # Initialize model weights on rank 0 first
        with fabric.rank_zero_first():
            if fabric.is_global_zero:
                initialize_model_weights(model)
    """
def no_backward_sync(
self,
module: _FabricModule,
enabled: bool = True
) -> AbstractContextManager:
"""
Context manager to skip gradient synchronization.
When enabled, gradients are not synchronized across processes
during backward pass. Useful for gradient accumulation.
Args:
module: Fabric-wrapped module
enabled: Whether to skip synchronization
Returns:
Context manager
Examples:
# Gradient accumulation without sync
for i, batch in enumerate(batches):
with fabric.no_backward_sync(model, enabled=(i < accumulate_steps-1)):
loss = compute_loss(model, batch)
fabric.backward(loss)
# Final step with synchronization
optimizer.step()
"""Properties and methods to query distributed training state.
@property
def global_rank(self) -> int:
    """Global rank of current process across all nodes."""

@property
def local_rank(self) -> int:
    """Local rank of current process within the current node."""

@property
def node_rank(self) -> int:
    """Rank of the current node."""

@property
def world_size(self) -> int:
    """Total number of processes across all nodes."""

@property
def is_global_zero(self) -> bool:
    """Whether current process is global rank 0."""

from lightning.fabric import Fabric
fabric = Fabric(accelerator="gpu", devices=4, strategy="ddp")

# Broadcast configuration from rank 0: only rank 0 builds the dict,
# every other rank passes None and receives rank 0's copy back.
if fabric.is_global_zero:
    config = {"learning_rate": 0.001, "batch_size": 32}
else:
    config = None
config = fabric.broadcast(config, src=0)
print(f"Rank {fabric.global_rank}: {config}")

# Accumulate gradients over multiple batches
accumulate_steps = 4
model.train()
for batch_idx, batch in enumerate(dataloader):
    # Skip gradient sync except on the LAST accumulation step.
    # NOTE(fix): the condition must mirror the optimizer.step() condition
    # below — sync only when (batch_idx + 1) % accumulate_steps == 0.
    # The original `batch_idx % accumulate_steps != 0` skipped sync on the
    # very batch where the optimizer steps, leaving gradients unsynchronized.
    with fabric.no_backward_sync(model, enabled=((batch_idx + 1) % accumulate_steps != 0)):
        # Scale loss so the accumulated gradient matches a full-batch average
        loss = compute_loss(model, batch) / accumulate_steps
        fabric.backward(loss)
    # Update weights after accumulation steps
    if (batch_idx + 1) % accumulate_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

# Evaluate model across all processes
model.eval()
all_predictions = []
all_targets = []
for batch in eval_dataloader:
    with torch.no_grad():
        predictions = model(batch["input"])
        targets = batch["target"]
    # Gather predictions and targets from all processes
    all_preds = fabric.all_gather(predictions)
    all_targs = fabric.all_gather(targets)
    all_predictions.append(all_preds)
    all_targets.append(all_targs)

# Compute metrics on gathered data (every rank holds the full set after
# all_gather; compute/report only on rank 0 to avoid duplicate output)
if fabric.is_global_zero:
    predictions = torch.cat(all_predictions)
    targets = torch.cat(all_targets)
    accuracy = compute_accuracy(predictions, targets)
    fabric.print(f"Evaluation accuracy: {accuracy:.4f}")

# Compute and synchronize loss across processes
model.train()
total_loss = 0
num_batches = 0
for batch in dataloader:
    loss = compute_loss(model, batch)
    # Synchronize loss across processes for logging only; the backward
    # pass still uses the local (unsynced) loss.
    sync_loss = fabric.all_reduce(loss, reduce_op="mean")
    fabric.backward(loss)
    optimizer.step()
    optimizer.zero_grad()
    total_loss += sync_loss.item()
    num_batches += 1
    if num_batches % 100 == 0:
        avg_loss = total_loss / num_batches
        fabric.print(f"Step {num_batches}, Avg Loss: {avg_loss:.4f}")

# Ensure all processes complete data preparation
fabric.print("Starting data preparation...")
# Each process prepares its portion of data
prepare_local_data()
# Wait for all processes to complete
fabric.barrier("data_preparation")
fabric.print("All processes completed data preparation")
# Continue with training
start_training()Install with Tessl CLI
npx tessl i tessl/pypi-lightning-fabric