Lightning Fabric

Lightning Fabric is a lightweight PyTorch scaling library that provides expert-level control over PyTorch training loops and scaling strategies. It enables developers to scale complex models, including foundation models, LLMs, diffusion models, transformers, and reinforcement learning models, to any device and any number of devices without boilerplate code.

Package Information

  • Package Name: lightning-fabric
  • Language: Python
  • Installation: pip install lightning-fabric

Core Imports

from lightning.fabric import Fabric, seed_everything, is_wrapped

Additional commonly used utilities:

from lightning.fabric.utilities import (
    move_data_to_device,
    suggested_max_num_workers,
    rank_zero_only,
    rank_zero_warn,
    rank_zero_info
)
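
A brief sketch of how these utilities are typically used (the messages and the log_once helper below are illustrative, not part of the API):

from lightning.fabric import seed_everything
from lightning.fabric.utilities import rank_zero_info, rank_zero_only

# Seed Python, NumPy, and torch RNGs; workers=True also derives seeds for dataloader workers
seed_everything(42, workers=True)

# Emitted only by the global rank-0 process in a distributed run
rank_zero_info("starting training")

@rank_zero_only
def log_once(message):
    # Runs only on global rank 0; other ranks skip the call and return None
    print(message)

log_once("this prints once, even with multiple processes")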

Basic Usage

from lightning.fabric import Fabric
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# Initialize Fabric with desired configuration
fabric = Fabric(accelerator="auto", devices="auto", strategy="auto")
fabric.launch()  # launches processes for distributed strategies; harmless on a single device

# Define model and optimizer
model = nn.Linear(10, 1)
optimizer = torch.optim.AdamW(model.parameters())

# Setup with Fabric (handles device placement and distributed wrapping)
model, optimizer = fabric.setup(model, optimizer)

# Setup dataloader (a small synthetic dataset is used here for illustration)
dataset = torch.utils.data.TensorDataset(torch.randn(64, 10), torch.randn(64, 1))
dataloader = fabric.setup_dataloaders(DataLoader(dataset, batch_size=8))

# Training loop
model.train()
for batch in dataloader:
    x, y = batch
    optimizer.zero_grad()
    
    # Forward pass
    y_pred = model(x)
    loss = nn.functional.mse_loss(y_pred, y)
    
    # Backward pass with automatic scaling and gradient handling
    fabric.backward(loss)
    optimizer.step()

# Save checkpoint
state = {"model": model, "optimizer": optimizer}
fabric.save("checkpoint.ckpt", state)

Architecture

Lightning Fabric uses a plugin-based architecture that enables flexible scaling and customization:

  • Fabric: Main orchestrator class that coordinates all components
  • Accelerators: Hardware abstraction (CPU, GPU, TPU, MPS)
  • Strategies: Distribution patterns (single device, data parallel, model parallel, FSDP, DeepSpeed)
  • Precision Plugins: Mixed precision and quantization support
  • Environment Plugins: Cluster environment detection and configuration
  • Loggers: Experiment tracking and metric logging
  • Wrappers: Transparent wrapping of PyTorch objects for distributed training

This plugin system allows Fabric to work across any hardware configuration and distributed training setup while maintaining the same simple API.
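
As a hedged illustration of how these pieces compose, the Fabric constructor accepts either shorthand strings or explicit plugin objects; the concrete values below (device count, logger directory) are only examples:

from lightning.fabric import Fabric
from lightning.fabric.strategies import DDPStrategy
from lightning.fabric.loggers import CSVLogger

# Shorthand strings: Fabric resolves the accelerator, strategy, and precision plugins
fabric = Fabric(accelerator="cuda", devices=2, precision="bf16-mixed")

# A similar configuration spelled out with explicit plugin objects
fabric = Fabric(
    accelerator="cuda",
    devices=2,
    strategy=DDPStrategy(find_unused_parameters=False),
    precision="bf16-mixed",
    loggers=CSVLogger(root_dir="./logs"),
)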

Capabilities

Core Training Orchestration

Main Fabric class that handles distributed training setup, model and optimizer configuration, checkpoint management, and training utilities.

class Fabric:
    def __init__(
        self,
        accelerator: Union[str, Accelerator] = "auto",
        strategy: Union[str, Strategy] = "auto",
        devices: Union[list[int], str, int] = "auto",
        num_nodes: int = 1,
        precision: Optional[Union[str, int]] = None,
        plugins: Optional[Union[Any, list[Any]]] = None,
        callbacks: Optional[Union[list[Any], Any]] = None,
        loggers: Optional[Union[Logger, list[Logger]]] = None
    ): ...

    def setup(self, module, *optimizers, move_to_device=True): ...
    def setup_module(self, module, move_to_device=True): ...
    def setup_optimizers(self, *optimizers): ...
    def setup_dataloaders(self, *dataloaders, use_distributed_sampler=True): ...
    def backward(self, tensor, *args, model=None, **kwargs): ...
    def save(self, path, state, filter=None): ...
    def load(self, path, state=None, strict=True): ...

See docs/core-training.md for the full reference.
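
A short sketch of checkpointing with save and load; the file name and the extra "epoch" entry are illustrative:

from lightning.fabric import Fabric
import torch
import torch.nn as nn

fabric = Fabric(accelerator="cpu")
model = nn.Linear(10, 1)
optimizer = torch.optim.AdamW(model.parameters())
model, optimizer = fabric.setup(model, optimizer)

# Save model/optimizer state plus arbitrary extra values
state = {"model": model, "optimizer": optimizer, "epoch": 3}
fabric.save("checkpoint.ckpt", state)

# Restore later: stateful objects are loaded in place,
# plain values such as "epoch" are replaced in the dict
state = {"model": model, "optimizer": optimizer, "epoch": 0}
fabric.load("checkpoint.ckpt", state)
resumed_epoch = state["epoch"]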

Distributed Operations

Collective communication operations for synchronizing data and gradients across processes in distributed training.

# Methods available on a Fabric instance
def barrier(self, name=None) -> None: ...
def broadcast(self, obj, src=0): ...
def all_gather(self, data, group=None, sync_grads=False): ...
def all_reduce(self, data, group=None, reduce_op="mean"): ...

See docs/distributed.md for the full reference.
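
A brief sketch of how these collectives are typically used from a training script (values are illustrative; with a single process they simply return their inputs):

import torch
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cpu", devices=1)
fabric.launch()

# Wait until every process reaches this point
fabric.barrier()

# Send a Python object from rank 0 to all other ranks
config = fabric.broadcast({"lr": 1e-3}, src=0)

# Average a per-process metric across all ranks
local_loss = torch.tensor(0.25)
global_loss = fabric.all_reduce(local_loss, reduce_op="mean")

# Gather the tensor from every rank
all_losses = fabric.all_gather(local_loss)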

Accelerators

Hardware acceleration plugins for different compute devices including CPU, CUDA GPUs, Apple MPS, and TPUs.

class Accelerator: ...  # Abstract base
class CPUAccelerator(Accelerator): ...
class CUDAAccelerator(Accelerator): ...
class MPSAccelerator(Accelerator): ...
class XLAAccelerator(Accelerator): ...

See docs/accelerators.md for the full reference.
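
A small sketch of selecting an accelerator object explicitly; in most cases the "auto" string used in Basic Usage is enough:

from lightning.fabric import Fabric
from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator

# Pick hardware explicitly based on what is available on this machine
if CUDAAccelerator.is_available():
    accelerator = CUDAAccelerator()
elif MPSAccelerator.is_available():
    accelerator = MPSAccelerator()
else:
    accelerator = CPUAccelerator()

fabric = Fabric(accelerator=accelerator, devices=1)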

Strategies

Distributed training strategies for scaling models across devices and nodes.

class Strategy: ...  # Abstract base
class SingleDeviceStrategy(Strategy): ...
class DataParallelStrategy(Strategy): ...
class DDPStrategy(Strategy): ...
class DeepSpeedStrategy(Strategy): ...
class FSDPStrategy(Strategy): ...
class XLAStrategy(Strategy): ...

See docs/strategies.md for the full reference.
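
A hedged sketch of configuring a strategy object directly instead of passing "auto"; the device count and wrapping policy below are illustrative:

import torch.nn as nn
from lightning.fabric import Fabric
from lightning.fabric.strategies import FSDPStrategy

# Shard a large model with FSDP, wrapping each transformer block as a unit
strategy = FSDPStrategy(auto_wrap_policy={nn.TransformerEncoderLayer})

fabric = Fabric(accelerator="cuda", devices=4, strategy=strategy)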

Precision and Quantization

Precision plugins for mixed precision training, quantization, and memory optimization.

class Precision: ...  # Abstract base
class DoublePrecision(Precision): ...
class HalfPrecision(Precision): ...
class MixedPrecision(Precision): ...
class BitsandbytesPrecision(Precision): ...
class DeepSpeedPrecision(Precision): ...
class FSDPPrecision(Precision): ...

See docs/precision.md for the full reference.
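
A sketch of selecting precision; shorthand strings and plugin objects are interchangeable, and the quantization example assumes the optional bitsandbytes package is installed:

import torch
from lightning.fabric import Fabric
from lightning.fabric.plugins import BitsandbytesPrecision

# bfloat16 mixed precision via a shorthand string
fabric = Fabric(accelerator="cuda", devices=1, precision="bf16-mixed")

# True half precision keeps weights and activations in 16-bit
fabric = Fabric(accelerator="cuda", devices=1, precision="16-true")

# 4-bit weight quantization through a precision plugin
fabric = Fabric(
    accelerator="cuda",
    devices=1,
    plugins=BitsandbytesPrecision(mode="nf4", dtype=torch.bfloat16),
)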

Utilities

Helper functions for seeding, data movement, distributed utilities, and performance monitoring.

def seed_everything(seed=None, workers=False, verbose=True) -> int: ...
def is_wrapped(obj) -> bool: ...
def move_data_to_device(obj, device): ...
def suggested_max_num_workers(num_cpus): ...

See docs/utilities.md for the full reference.
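
A minimal sketch of these helpers; the tensor shapes are arbitrary:

import torch
from lightning.fabric import Fabric, seed_everything, is_wrapped
from lightning.fabric.utilities import move_data_to_device

seed_everything(1234)

fabric = Fabric(accelerator="cpu")
model = torch.nn.Linear(4, 2)
wrapped = fabric.setup(model)

# is_wrapped reports whether an object has already been set up by Fabric
assert is_wrapped(wrapped) and not is_wrapped(model)

# Recursively move tensors (including those nested in dicts or lists) to a device
batch = {"x": torch.randn(8, 4), "y": torch.randn(8, 2)}
batch = move_data_to_device(batch, fabric.device)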

Types

# Common type aliases used throughout the API
_PATH = Union[str, Path]
_DEVICE = Union[torch.device, str, int]
_MAP_LOCATION_TYPE = Optional[Union[_DEVICE, Callable, dict[_DEVICE, _DEVICE]]]
_PARAMETERS = Iterator[torch.nn.Parameter]
ReduceOp = torch.distributed.ReduceOp
RedOpType = ReduceOp.RedOpType

# Protocols for type checking
@runtime_checkable
class _Stateful(Protocol[_DictKey]):
    def state_dict(self) -> dict[_DictKey, Any]: ...
    def load_state_dict(self, state_dict: dict[_DictKey, Any]) -> None: ...

@runtime_checkable
class Steppable(Protocol):
    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: ...

@runtime_checkable
class Optimizable(Steppable, Protocol):
    param_groups: list[dict[Any, Any]]
    defaults: dict[Any, Any]
    state: defaultdict[Tensor, Any]
    
    def state_dict(self) -> dict[str, dict[Any, Any]]: ...
    def load_state_dict(self, state_dict: dict[str, dict[Any, Any]]) -> None: ...

@runtime_checkable
class CollectibleGroup(Protocol):
    def size(self) -> int: ...
    def rank(self) -> int: ...
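
Because these protocols are runtime-checkable, isinstance can be used for structural checks. A small sketch, assuming the names are importable from lightning.fabric.utilities.types (an assumption, since this page does not state the module path):

import torch
from lightning.fabric.utilities.types import Optimizable, _Stateful

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Structural (duck-typed) checks: no inheritance from the protocols is required
assert isinstance(optimizer, Optimizable)  # step, state_dict, param_groups, ...
assert isinstance(model, _Stateful)        # state_dict and load_state_dict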