CtrlK
Blog · Docs · Log in · Get started
Tessl Logo

tessl/pypi-pytorch-lightning

Unified deep learning framework integrating PyTorch Lightning, Lightning Fabric, and Lightning Apps for training, deploying, and shipping AI products.

Pending
Overview
Eval results
Files

docs/fabric.md

Low-Level Training Control

Lightning Fabric provides fine-grained control over training loops while automatically handling device management, distributed training setup, and gradient synchronization. This enables custom training logic with minimal boilerplate code.

Capabilities

Fabric Class

Core abstraction that handles device management, distributed training setup, mixed precision, and gradient synchronization while giving you full control over the training loop.

class Fabric:
    """Low-level training abstraction.

    Owns accelerator selection, distributed process management, mixed
    precision, and gradient synchronization, while the caller keeps full
    control of the training loop itself.

    NOTE(review): this is an API stub — method bodies are docstring-only
    and describe the intended contract. Also, ``strategy`` is annotated
    ``Optional[str]``, but the usage examples pass a strategy *instance*
    (``DeepSpeedStrategy``); the annotation looks too narrow — confirm.
    """

    def __init__(
        self,
        accelerator: str = "auto",
        devices: Union[int, str, List[int]] = "auto",
        num_nodes: int = 1,
        strategy: Optional[str] = None,
        precision: Optional[str] = None,
        plugins: Optional[Union[str, list]] = None,
        callbacks: Optional[Union[List, dict]] = None,
        loggers: Optional[Union[Logger, List[Logger]]] = None
    ):
        """Configure Fabric.

        Args:
            accelerator: Hardware backend: ``'cpu'``, ``'gpu'``, ``'tpu'``
                or ``'auto'``.
            devices: Which/how many devices to use (int, list of indices,
                or ``'auto'``).
            num_nodes: Number of machines for multi-node training.
            strategy: Identifier of the distributed training strategy.
            precision: ``'16-mixed'``, ``'32'``, ``'64'`` or ``'bf16-mixed'``.
            plugins: Extra plugins providing custom functionality.
            callbacks: Callback objects invoked at training hooks.
            loggers: Logger instance(s) for experiment tracking.
        """

    def setup(
        self,
        model: nn.Module,
        *optimizers: Optimizer
    ) -> Union[nn.Module, Tuple[nn.Module, ...]]:
        """Prepare a model (and its optimizers) for the configured strategy.

        Args:
            model: PyTorch model to prepare.
            optimizers: Any number of optimizers to prepare alongside it.

        Returns:
            The configured model, or the configured model plus optimizers.
        """

    def setup_dataloaders(
        self,
        *dataloaders: DataLoader
    ) -> Union[DataLoader, Tuple[DataLoader, ...]]:
        """Prepare dataloaders for distributed training.

        Args:
            dataloaders: Any number of DataLoader instances.

        Returns:
            The configured dataloader(s), matching the input arity.
        """

    def backward(self, loss: torch.Tensor) -> None:
        """Run the backward pass, applying gradient scaling when needed.

        Args:
            loss: Loss tensor to differentiate.
        """

    def step(self, optimizer: Optimizer, *args, **kwargs) -> None:
        """Step the optimizer, unscaling and synchronizing gradients first.

        Args:
            optimizer: Optimizer whose ``step()`` should be taken.
            *args: Forwarded to ``optimizer.step()``.
            **kwargs: Forwarded to ``optimizer.step()``.
        """

    def clip_gradients(
        self,
        model: nn.Module,
        optimizer: Optimizer,
        max_norm: Union[float, int],
        norm_type: Union[float, int] = 2.0,
        error_if_nonfinite: bool = True
    ) -> torch.Tensor:
        """Clip the model's gradients to a maximum norm.

        Args:
            model: Model whose gradients are clipped.
            optimizer: Optimizer associated with those gradients.
            max_norm: Upper bound on the gradient norm.
            norm_type: Order of the norm (default L2).
            error_if_nonfinite: Raise if gradients are non-finite.

        Returns:
            The total gradient norm.
        """

    def save(self, path: str, state: dict) -> None:
        """Write a training-state checkpoint.

        Args:
            path: Destination file path.
            state: Mapping holding model/optimizer states to persist.
        """

    def load(self, path: str) -> dict:
        """Read a checkpoint written by :meth:`save`.

        Args:
            path: Checkpoint file path.

        Returns:
            The loaded state dictionary.
        """

    def barrier(self, name: Optional[str] = None) -> None:
        """Block until every process reaches this point.

        Args:
            name: Optional label for debugging the barrier.
        """

    def broadcast(self, obj: Any, src: int = 0) -> Any:
        """Send an object from one rank to every other rank.

        Args:
            obj: Object to broadcast.
            src: Rank that provides the value.

        Returns:
            The broadcasted object.
        """

    def all_gather(self, data: Any, group: Optional[Any] = None) -> List[Any]:
        """Collect a value from every process.

        Args:
            data: Value contributed by this process.
            group: Optional process group to gather within.

        Returns:
            A list with one entry per process.
        """

    def all_reduce(
        self,
        tensor: torch.Tensor,
        op: str = "sum",
        group: Optional[Any] = None
    ) -> torch.Tensor:
        """Combine a tensor across processes with a reduction op.

        Args:
            tensor: Tensor to reduce.
            op: One of ``'sum'``, ``'mean'``, ``'max'``, ``'min'``.
            group: Optional process group to reduce within.

        Returns:
            The reduced tensor.
        """

    def log(self, name: str, value: Any, step: Optional[int] = None) -> None:
        """Forward a single metric to every configured logger.

        Args:
            name: Metric name.
            value: Metric value.
            step: Explicit training step; auto-incremented when ``None``.
        """

    def log_dict(self, metrics: dict, step: Optional[int] = None) -> None:
        """Forward several metrics at once to every configured logger.

        Args:
            metrics: Mapping of metric names to values.
            step: Explicit training step; auto-incremented when ``None``.
        """

    def print(self, *args, **kwargs) -> None:
        """``print()`` that emits only on rank 0 in distributed training.

        Args:
            *args: Forwarded to :func:`print`.
            **kwargs: Forwarded to :func:`print`.
        """

    @property
    def device(self) -> torch.device:
        """Device this process computes on."""

    @property
    def global_rank(self) -> int:
        """Index of this process across all nodes."""

    @property
    def local_rank(self) -> int:
        """Index of this process within its node."""

    @property
    def node_rank(self) -> int:
        """Index of this process's node."""

    @property
    def world_size(self) -> int:
        """Total number of processes across all nodes."""

    @property
    def is_global_zero(self) -> bool:
        """True when this process is global rank 0."""

Utility Functions

def seed_everything(seed: int, workers: bool = False) -> int:
    """Seed every relevant random number generator for reproducible runs.

    Args:
        seed: Value used to seed all generators.
        workers: Also derive seeds for dataloader worker processes.

    Returns:
        The seed value that was applied.
    """

Usage Examples

Custom Training Loop

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import lightning.fabric as L

# Spin up Fabric on 2 GPUs with fp16 mixed precision.
fabric = L.Fabric(accelerator="gpu", devices=2, precision="16-mixed")
fabric.launch()

# Model, optimizer, and a synthetic regression dataset.
net = nn.Linear(10, 1)
opt = torch.optim.Adam(net.parameters())
features, targets = torch.randn(1000, 10), torch.randn(1000, 1)
loader = DataLoader(list(zip(features, targets)), batch_size=32)

# Let Fabric wrap everything for distributed execution.
net, opt = fabric.setup(net, opt)
loader = fabric.setup_dataloaders(loader)

# Hand-written training loop — Fabric only handles the plumbing.
net.train()
for epoch in range(10):
    running_loss = 0
    for step, (inputs, labels) in enumerate(loader):
        # Forward
        prediction = net(inputs)
        loss = nn.functional.mse_loss(prediction, labels)

        # Backward + optimizer step via Fabric (scaling/unscaling handled)
        opt.zero_grad()
        fabric.backward(loss)
        fabric.step(opt)

        running_loss += loss.item()

        # Periodic metric logging
        if step % 10 == 0:
            fabric.log("train_loss", loss.item())

    fabric.print(f"Epoch {epoch}: Loss = {running_loss / len(loader)}")

Checkpointing and Resuming

# Checkpointing and resuming with Fabric.
# Fix: the original snippet used `torch` and `nn` without importing them.
import torch
import torch.nn as nn
import lightning.fabric as L

fabric = L.Fabric()
fabric.launch()

model = nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters())
model, optimizer = fabric.setup(model, optimizer)

# Training loop with checkpointing
for epoch in range(100):
    # ... training code ...

    # Save checkpoint every 10 epochs
    if epoch % 10 == 0:
        state = {
            "model": model,
            "optimizer": optimizer,
            "epoch": epoch
        }
        fabric.save(f"checkpoint_epoch_{epoch}.ckpt", state)

# Resume from checkpoint
checkpoint = fabric.load("checkpoint_epoch_50.ckpt")
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
start_epoch = checkpoint["epoch"] + 1

Distributed Training Primitives

# Distributed communication primitives.
# Fix: the original snippet used `torch.tensor` without importing torch.
import torch
import lightning.fabric as L

fabric = L.Fabric(devices=4, strategy="ddp")
fabric.launch()

# Broadcast configuration from rank 0
if fabric.global_rank == 0:
    config = {"learning_rate": 0.001, "batch_size": 32}
else:
    config = None

config = fabric.broadcast(config, src=0)

# Gather metrics from all processes
local_metrics = {"accuracy": 0.95, "loss": 0.1}
all_metrics = fabric.all_gather(local_metrics)

# Reduce tensor across all processes
local_tensor = torch.tensor([1.0, 2.0, 3.0])
reduced_tensor = fabric.all_reduce(local_tensor, op="mean")

fabric.print(f"Reduced tensor: {reduced_tensor}")

Mixed Precision Training

# Mixed precision training.
# Fixes: the original snippet used `torch`/`nn` without importing them and
# iterated an undefined `dataloader` — it is now defined explicitly.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import lightning.fabric as L

# Enable mixed precision
fabric = L.Fabric(precision="16-mixed")
fabric.launch()

model = nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters())
model, optimizer = fabric.setup(model, optimizer)

# Synthetic data for the demo loop
dataset = list(zip(torch.randn(256, 10), torch.randn(256, 1)))
dataloader = fabric.setup_dataloaders(DataLoader(dataset, batch_size=32))

# Training loop with automatic mixed precision
for epoch in range(10):
    for batch in dataloader:
        x, y = batch

        # Forward pass (automatically uses mixed precision)
        output = model(x)
        loss = nn.functional.mse_loss(output, y)

        # Backward pass (automatically handles gradient scaling)
        optimizer.zero_grad()
        fabric.backward(loss)  # Handles gradient scaling
        fabric.step(optimizer)  # Handles gradient unscaling

Custom Strategy Integration

# Custom strategy integration.
# Fixes: the original snippet used `torch`/`nn` without importing them and
# iterated an undefined `dataloader` — it is now defined explicitly.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import lightning.fabric as L
from lightning.fabric.strategies import DeepSpeedStrategy

# Use custom strategy
strategy = DeepSpeedStrategy(stage=2)
fabric = L.Fabric(strategy=strategy, precision="16-mixed")
fabric.launch()

model = nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters())
model, optimizer = fabric.setup(model, optimizer)

# Synthetic data for the demo loop
dataset = list(zip(torch.randn(256, 10), torch.randn(256, 1)))
dataloader = fabric.setup_dataloaders(DataLoader(dataset, batch_size=32))

# Training proceeds normally - Fabric handles strategy-specific details
for epoch in range(10):
    for batch in dataloader:
        # ... training code ...
        pass

Install with Tessl CLI

npx tessl i tessl/pypi-pytorch-lightning

docs

apps.md

fabric.md

index.md

training.md

utilities.md

tile.json