CtrlK
BlogDocsLog inGet started
Tessl Logo

tessl/pypi-lightning-fabric

Fabric is the fast and lightweight way to scale PyTorch models without boilerplate

Pending
Overview
Eval results
Files

distributed.mddocs/

Distributed Operations

Collective communication operations and utilities for coordinating processes in distributed training environments.

Capabilities

Communication Primitives

Core collective operations for synchronizing data and computations across distributed processes.

def barrier(self, name: Optional[str] = None) -> None:
    """
    Synchronize all processes at this point.
    
    Blocks until all processes reach this barrier. Useful for ensuring
    all processes complete a phase before proceeding.
    
    Note:
        This is a collective call: every process in the group must reach
        the barrier, otherwise the processes that did arrive block
        indefinitely (or until a timeout fires).
    
    Args:
        name: Optional name for the barrier (for debugging)
        
    Raises:
        RuntimeError: If barrier times out or fails
    """

def broadcast(self, obj: Any, src: int = 0) -> Any:
    """
    Broadcast object from source process to all other processes.
    
    Only the value of ``obj`` on the source rank matters; on every other
    rank the argument is typically a placeholder (e.g. ``None``) and the
    value from ``src`` is what the call returns.
    
    Args:
        obj: Object to broadcast (tensor, dict, list, etc.)
        src: Source process rank (default: 0)
        
    Returns:
        The broadcasted object on all processes
        
    Examples:
        # Broadcast model parameters from rank 0
        params = fabric.broadcast(model.state_dict(), src=0)
        
        # Broadcast configuration dictionary
        config = fabric.broadcast({"lr": 0.001, "batch_size": 32}, src=0)
    """

def all_gather(
    self,
    data: Union[Tensor, dict, list, tuple],
    group: Optional[Any] = None,
    sync_grads: bool = False
) -> Union[Tensor, dict, list, tuple]:
    """
    Gather data from all processes and concatenate.
    
    Each process contributes its data, and all processes receive
    the concatenated result from all processes.
    
    NOTE(review): "concatenate" above may be imprecise — many all-gather
    implementations stack contributions along a new leading dimension of
    size ``world_size``. Verify the result shape before indexing into it.
    
    Args:
        data: Data to gather (tensor, dict, list, or tuple)
        group: Process group (None for default group)
        sync_grads: Whether to synchronize gradients
        
    Returns:
        Gathered data from all processes
        
    Examples:
        # Gather predictions from all processes
        local_preds = model(batch)
        all_preds = fabric.all_gather(local_preds)
        
        # Gather metrics dictionary
        local_metrics = {"accuracy": 0.95, "loss": 0.1}
        all_metrics = fabric.all_gather(local_metrics)
    """

def all_reduce(
    self,
    data: Union[Tensor, dict, list, tuple],
    group: Optional[Any] = None,
    reduce_op: Union[str, ReduceOp] = "mean"
) -> Union[Tensor, dict, list, tuple]:
    """
    Reduce data across all processes using specified operation.
    
    Applies reduction operation (sum, mean, max, min) across all processes
    and returns the result to all processes: after the call, every process
    holds the same reduced value.
    
    Args:
        data: Data to reduce (tensor, dict, list, or tuple)
        group: Process group (None for default group)
        reduce_op: Reduction operation ("sum", "mean", "max", "min")
        
    Returns:
        Reduced data
        
    Examples:
        # Average loss across all processes
        local_loss = compute_loss(batch)
        avg_loss = fabric.all_reduce(local_loss, reduce_op="mean")
        
        # Sum gradients across processes
        grads = fabric.all_reduce(gradients, reduce_op="sum")
    """

Synchronization Utilities

Higher-level utilities for process coordination and data movement.

def to_device(self, obj: Union[nn.Module, Tensor, Any]) -> Union[nn.Module, Tensor, Any]:
    """
    Move object to the appropriate device.
    
    Automatically handles device placement for tensors, modules,
    and nested data structures.
    
    NOTE(review): the target device presumably comes from the Fabric
    accelerator/devices configuration — confirm against the launcher setup.
    
    Args:
        obj: Object to move to device
        
    Returns:
        Object moved to target device
        
    Examples:
        # Move tensor to device
        tensor = torch.randn(10, 10)
        tensor = fabric.to_device(tensor)
        
        # Move nested data structure
        data = {"input": torch.randn(32, 784), "target": torch.randint(0, 10, (32,))}
        data = fabric.to_device(data)
    """

def print(self, *args, **kwargs) -> None:
    """
    Print only from rank 0 process.
    
    Prevents duplicate printing in distributed training by only
    allowing the rank 0 process to print. On every other rank this
    call is a no-op.
    
    Args:
        *args: Arguments to print
        **kwargs: Keyword arguments for print function
        
    Examples:
        fabric.print(f"Epoch {epoch}, Loss: {loss:.4f}")
        fabric.print("Training completed!", file=sys.stderr)
    """

Advanced Synchronization

Context managers and advanced coordination primitives.

def rank_zero_first(self, local: bool = False) -> Generator:
    """
    Context manager ensuring rank 0 executes first.
    
    All processes execute the body; rank 0 simply runs it before the
    other ranks are released. Useful for operations that should be
    performed by one process first (e.g., dataset preparation, model
    initialization) so later ranks can reuse the produced artifact,
    such as a cache written to disk.
    
    Args:
        local: If True, use local rank (within node), otherwise global rank
        
    Yields:
        None
        
    Examples:
        # Download dataset on rank 0 first; other ranks then hit the cache
        with fabric.rank_zero_first():
            dataset = download_dataset()
        
        # Build a shared on-disk artifact exactly once: rank 0 creates it,
        # later ranks find it already present and skip the work
        with fabric.rank_zero_first():
            if not cache_exists():
                build_cache()
    """

def no_backward_sync(
    self,
    module: _FabricModule,
    enabled: bool = True
) -> AbstractContextManager:
    """
    Context manager to skip gradient synchronization.
    
    When enabled, gradients are not synchronized across processes
    during backward pass. Useful for gradient accumulation: gradients
    still accumulate locally on each process, and are presumably
    exchanged on the next backward pass executed outside this context
    — confirm with the strategy in use.
    
    Args:
        module: Fabric-wrapped module
        enabled: Whether to skip synchronization
        
    Returns:
        Context manager
        
    Examples:
        # Gradient accumulation without sync
        for i, batch in enumerate(batches):
            with fabric.no_backward_sync(model, enabled=(i < accumulate_steps-1)):
                loss = compute_loss(model, batch)
                fabric.backward(loss)
        
        # Final step with synchronization
        optimizer.step()
    """

Process Information

Properties and methods to query distributed training state.

@property
def global_rank(self) -> int:
    """Global rank of current process across all nodes (in [0, world_size))."""

@property
def local_rank(self) -> int:
    """Local rank of current process within the current node."""

@property
def node_rank(self) -> int:
    """Rank of the current node."""

@property
def world_size(self) -> int:
    """Total number of processes across all nodes."""

@property
def is_global_zero(self) -> bool:
    """Whether current process is global rank 0."""

Usage Examples

Basic Communication

from lightning.fabric import Fabric

fabric = Fabric(accelerator="gpu", devices=4, strategy="ddp")

# Broadcast configuration from rank 0
if fabric.is_global_zero:
    config = {"learning_rate": 0.001, "batch_size": 32}
else:
    config = None

config = fabric.broadcast(config, src=0)
print(f"Rank {fabric.global_rank}: {config}")

Gradient Accumulation

# Accumulate gradients over multiple batches
accumulate_steps = 4
model.train()

for batch_idx, batch in enumerate(dataloader):
    # Skip gradient sync except on last accumulation step
    with fabric.no_backward_sync(model, enabled=((batch_idx + 1) % accumulate_steps != 0)):
        loss = compute_loss(model, batch) / accumulate_steps
        fabric.backward(loss)
    
    # Update weights after accumulation steps
    if (batch_idx + 1) % accumulate_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

Distributed Evaluation

# Evaluate model across all processes
model.eval()
all_predictions = []
all_targets = []

for batch in eval_dataloader:
    with torch.no_grad():
        predictions = model(batch["input"])
        targets = batch["target"]
    
    # Gather predictions and targets from all processes
    all_preds = fabric.all_gather(predictions)
    all_targs = fabric.all_gather(targets)
    
    all_predictions.append(all_preds)
    all_targets.append(all_targs)

# Compute metrics on gathered data
if fabric.is_global_zero:
    predictions = torch.cat(all_predictions)
    targets = torch.cat(all_targets)
    accuracy = compute_accuracy(predictions, targets)
    fabric.print(f"Evaluation accuracy: {accuracy:.4f}")

Loss Synchronization

# Compute and synchronize loss across processes
model.train()
total_loss = 0
num_batches = 0

for batch in dataloader:
    loss = compute_loss(model, batch)
    
    # Synchronize loss across processes for logging
    sync_loss = fabric.all_reduce(loss, reduce_op="mean")
    
    fabric.backward(loss)
    optimizer.step()
    optimizer.zero_grad()
    
    total_loss += sync_loss.item()
    num_batches += 1
    
    if num_batches % 100 == 0:
        avg_loss = total_loss / num_batches
        fabric.print(f"Step {num_batches}, Avg Loss: {avg_loss:.4f}")

Barrier Synchronization

# Ensure all processes complete data preparation
fabric.print("Starting data preparation...")

# Each process prepares its portion of data
prepare_local_data()

# Wait for all processes to complete
fabric.barrier("data_preparation")
fabric.print("All processes completed data preparation")

# Continue with training
start_training()

Install with Tessl CLI

npx tessl i tessl/pypi-lightning-fabric

docs

accelerators.md

core-training.md

distributed.md

index.md

precision.md

strategies.md

utilities.md

tile.json