tessl/pypi-lightning

The Deep Learning framework to train, deploy, and ship AI products Lightning fast.

docs/fabric.md

Lightning Fabric

Lightweight training acceleration framework providing expert-level control over training loops, device management, and distributed strategies without high-level abstractions. Fabric gives you the flexibility of raw PyTorch with the power of Lightning's optimizations.
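In practice the change from a plain PyTorch script is small. Below is a minimal single-device sketch of the pattern (toy model and random data, purely illustrative):

import torch
import torch.nn as nn
from lightning import Fabric

# Single device here; add more devices and call fabric.launch() for distributed runs
fabric = Fabric(accelerator="auto", devices=1)

model = nn.Linear(32, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = fabric.setup(model, optimizer)   # replaces model.to(device)

x = torch.randn(8, 32, device=fabric.device)        # no hard-coded device strings
loss = model(x).sum()

optimizer.zero_grad()
fabric.backward(loss)                                # replaces loss.backward()
optimizer.step()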

Capabilities

Fabric Core

The main Fabric class accelerates PyTorch training with distributed execution, mixed precision, and device management while keeping the training loop fully under your control.

class Fabric:
    def __init__(
        self,
        accelerator: str = "auto",
        strategy: str = "auto",
        devices: Union[List[int], str, int] = "auto",
        num_nodes: int = 1,
        precision: Union[str, int] = "32-true",
        plugins: Optional[Union[Plugin, List[Plugin]]] = None,
        callbacks: Optional[Union[Callback, List[Callback]]] = None,
        loggers: Optional[Union[Logger, List[Logger]]] = None,
        **kwargs
    ):
        """
        Initialize Fabric for training acceleration.
        
        Args:
            accelerator: Hardware accelerator ('cpu', 'gpu', 'tpu', 'auto')
            strategy: Distributed strategy ('ddp', 'fsdp', 'deepspeed', etc.)
            devices: Which devices to use
            num_nodes: Number of nodes for distributed training
            precision: Precision mode ('32-true', '16-mixed', 'bf16-mixed', etc.)
            plugins: Additional plugins for customization
            callbacks: Callbacks for training lifecycle hooks
            loggers: Loggers for experiment tracking
        """

    def setup(
        self,
        model: nn.Module,
        *optimizers: Optimizer
    ) -> Union[nn.Module, Tuple[nn.Module, ...]]:
        """
        Set up model and optimizers for training.
        
        Args:
            model: PyTorch model to accelerate
            *optimizers: Optimizers to set up
            
        Returns:
            Wrapped model and optimizers ready for training
        """

    def setup_dataloaders(
        self,
        *dataloaders: DataLoader
    ) -> Union[DataLoader, List[DataLoader]]:
        """
        Set up data loaders for distributed training.
        
        Args:
            *dataloaders: Data loaders to set up
            
        Returns:
            Wrapped data loaders ready for distributed training
        """

    def backward(self, tensor: Tensor) -> None:
        """
        Perform backward pass with proper scaling and synchronization.
        
        Args:
            tensor: Loss tensor to compute gradients from
        """

    def clip_gradients(
        self,
        model: nn.Module,
        optimizer: Optimizer,
        max_norm: Union[float, int],
        norm_type: Union[float, int] = 2.0,
        error_if_nonfinite: bool = True
    ) -> Tensor:
        """
        Clip gradients by norm.
        
        Args:
            model: Model whose gradients to clip
            optimizer: Optimizer being used
            max_norm: Maximum norm for gradients
            norm_type: Type of norm to use
            error_if_nonfinite: Raise error if gradients are non-finite
            
        Returns:
            Total norm of the gradients
        """

    def all_gather(
        self,
        tensor: Tensor,
        group: Optional[Any] = None,
        sync_grads: bool = False
    ) -> Tensor:
        """
        Gather tensors from all processes.
        
        Args:
            tensor: Tensor to gather
            group: Process group
            sync_grads: Synchronize gradients
            
        Returns:
            Gathered tensor from all processes
        """

    def all_reduce(
        self,
        tensor: Tensor,
        group: Optional[Any] = None,
        reduce_op: str = "mean"
    ) -> Tensor:
        """
        Reduce tensor across all processes.
        
        Args:
            tensor: Tensor to reduce
            group: Process group
            reduce_op: Reduction operation ('mean', 'sum')
            
        Returns:
            Reduced tensor
        """

    def broadcast(self, tensor: Tensor, src: int = 0) -> Tensor:
        """
        Broadcast tensor from source process to all processes.
        
        Args:
            tensor: Tensor to broadcast
            src: Source rank
            
        Returns:
            Broadcasted tensor
        """

    def barrier(self, name: Optional[str] = None) -> None:
        """
        Synchronize all processes.
        
        Args:
            name: Optional barrier name for debugging
        """

    @property
    def is_global_zero(self) -> bool:
        """Whether the current process has global rank 0."""

    def print(self, *args, **kwargs) -> None:
        """
        Print only on rank 0.
        
        Args:
            *args: Arguments to print
            **kwargs: Keyword arguments for print
        """

    def log(self, name: str, value: Any, step: Optional[int] = None) -> None:
        """
        Log a metric.
        
        Args:
            name: Metric name
            value: Metric value
            step: Optional step number
        """

    def log_dict(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
        """
        Log a dictionary of metrics.
        
        Args:
            metrics: Dictionary of metrics
            step: Optional step number
        """

    def save(self, path: str, state: Dict[str, Any]) -> None:
        """
        Save checkpoint.
        
        Args:
            path: Path to save checkpoint
            state: State dictionary to save
        """

    def load(self, path: str) -> Dict[str, Any]:
        """
        Load checkpoint.
        
        Args:
            path: Path to load checkpoint from
            
        Returns:
            Loaded state dictionary
        """

    @property
    def device(self) -> torch.device:
        """Get the current device."""

    @property
    def global_rank(self) -> int:
        """Get global rank of current process."""

    @property
    def local_rank(self) -> int:
        """Get local rank of current process."""

    @property
    def node_rank(self) -> int:
        """Get node rank of current process."""

    @property
    def world_size(self) -> int:
        """Get total number of processes."""

    def to_device(self, obj: Any) -> Any:
        """
        Move object to device.
        
        Args:
            obj: Object to move to device
            
        Returns:
            Object on the device
        """

Utility Functions

Core utility functions for reproducibility, object inspection, and common operations in Fabric workflows.

def seed_everything(seed: int, workers: bool = False) -> int:
    """
    Set random seeds for reproducibility.
    
    Args:
        seed: Random seed to set
        workers: Also set seed for data loader workers
        
    Returns:
        The seed that was set
    """

def is_wrapped(obj: Any) -> bool:
    """
    Check if an object has been wrapped by Fabric.
    
    Args:
        obj: Object to check
        
    Returns:
        True if object is wrapped by Fabric
    """

Basic Usage Example

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from lightning import Fabric

# Initialize Fabric and launch one process per device
fabric = Fabric(accelerator="gpu", devices=2, strategy="ddp")
fabric.launch()

# Define model and optimizer
model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Setup model and optimizer with Fabric
model, optimizer = fabric.setup(model, optimizer)

# Create sample data and dataloader
data = torch.randn(1000, 10)
targets = torch.randn(1000, 1)
dataset = TensorDataset(data, targets)
dataloader = DataLoader(dataset, batch_size=32)

# Setup dataloader
dataloader = fabric.setup_dataloaders(dataloader)

# Training loop with full control
for epoch in range(10):
    for batch_idx, (x, y) in enumerate(dataloader):
        optimizer.zero_grad()
        
        # Forward pass
        y_pred = model(x)
        loss = nn.functional.mse_loss(y_pred, y)
        
        # Backward pass - Fabric handles scaling and synchronization
        fabric.backward(loss)
        
        optimizer.step()
        
        # Log metrics
        if batch_idx % 10 == 0:
            fabric.log("train_loss", loss.item())
            fabric.print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item()}")

    # Save checkpoint
    state = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "epoch": epoch
    }
    fabric.save(f"checkpoint_epoch_{epoch}.ckpt", state)

Advanced Usage Example

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from lightning import Fabric

# Initialize Fabric with advanced configuration and launch worker processes
fabric = Fabric(
    accelerator="gpu",
    devices=4,
    strategy="fsdp",
    precision="16-mixed"
)
fabric.launch()

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 10)
        )
    
    def forward(self, x):
        return self.layers(x)

# Model, optimizer, and LR schedule
model = MyModel()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

# Setup with Fabric
model, optimizer = fabric.setup(model, optimizer)

# Synthetic data standing in for a real dataset; setup_dataloaders adds
# distributed sampling and moves each batch to the right device
inputs = torch.randn(10_000, 784)
labels = torch.randint(0, 10, (10_000,))
train_dataloader = DataLoader(TensorDataset(inputs, labels), batch_size=64, shuffle=True)
train_dataloader = fabric.setup_dataloaders(train_dataloader)

# Training loop with advanced features
for epoch in range(100):
    model.train()
    
    for batch_idx, (data, target) in enumerate(train_dataloader):
        optimizer.zero_grad()
        
        output = model(data)
        loss = nn.functional.cross_entropy(output, target)
        
        # Backward with automatic mixed precision
        fabric.backward(loss)
        
        # Gradient clipping
        fabric.clip_gradients(model, optimizer, max_norm=1.0)
        
        optimizer.step()
        
        # Metrics logging
        if batch_idx % 100 == 0:
            accuracy = (output.argmax(dim=1) == target).float().mean()
            
            # Forward scalar metrics to any configured loggers; values are
            # per-process, so reduce with all_reduce first if a global
            # average is needed
            fabric.log_dict({
                "train_loss": loss.item(),
                "train_acc": accuracy.item(),
                "lr": scheduler.get_last_lr()[0]
            })
            
            # Print only on rank 0
            fabric.print(f"Epoch {epoch}/100, Batch {batch_idx}, "
                         f"Loss: {loss.item():.4f}, Acc: {accuracy.item():.4f}")
    
    scheduler.step()
    
    # Synchronization barrier
    fabric.barrier()
    
    # Save checkpoint: call fabric.save on every rank so sharded strategies
    # like FSDP can gather their state; Fabric writes the file only where
    # appropriate. Passing the objects (not their state_dicts) lets Fabric
    # handle the conversion for the active strategy.
    checkpoint = {
        "model": model,
        "optimizer": optimizer,
        "scheduler": scheduler,
        "epoch": epoch,
    }
    fabric.save(f"model_epoch_{epoch}.ckpt", checkpoint)

Install with Tessl CLI

npx tessl i tessl/pypi-lightning

docs/
accelerators.md
callbacks.md
core-training.md
data.md
fabric.md
index.md
loggers.md
precision.md
profilers.md
strategies.md
tile.json