tessl/pypi-lightning-fabric

Fabric is the fast and lightweight way to scale PyTorch models without boilerplate

Strategies

Distributed training strategies that define how models and data are distributed across devices and processes.

Capabilities

Base Strategy

Abstract base class defining the strategy interface for distributed training.

class Strategy:
    """
    Abstract base class for distributed training strategies.
    
    Strategies define how models, optimizers, and data are distributed
    across devices and processes for parallel training.
    """
    
    def setup_environment(self) -> None:
        """Set up the distributed training environment."""
    
    def setup_module(self, module: nn.Module) -> nn.Module:
        """Set up the module for distributed training."""
    
    def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
        """Set up the optimizer for distributed training."""
    
    def module_to_device(self, module: nn.Module) -> None:
        """Move module to appropriate device(s)."""
    
    def reduce(self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[ReduceOp] = None) -> Tensor:
        """Reduce tensor across processes."""
    
    def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
        """All-gather tensor across processes."""
    
    def broadcast(self, tensor: Tensor, src: int = 0) -> Tensor:
        """Broadcast tensor from source process."""
    
    def barrier(self, name: Optional[str] = None) -> None:
        """Synchronize all processes."""
    
    def teardown(self) -> None:
        """Clean up strategy resources."""

Single Device Strategy

Strategy for training on a single device (CPU or GPU).

class SingleDeviceStrategy(Strategy):
    """
    Strategy for single device training.
    
    Handles training on a single CPU or GPU without distributed communication.
    """
    
    def __init__(self, device: Optional[torch.device] = None):
        """
        Initialize single device strategy.
        
        Args:
            device: Target device for training
        """
    
    def setup_module(self, module: nn.Module) -> nn.Module:
        """Move module to target device."""
    
    def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
        """Return optimizer as-is (no distribution needed)."""

Data Parallel Strategy

PyTorch DataParallel strategy for single-node multi-GPU training.

class DataParallelStrategy(Strategy):
    """
    DataParallel strategy for single-node multi-GPU training.
    
    Uses PyTorch's DataParallel for simple multi-GPU training on a single node.
    Limited scalability compared to DistributedDataParallel.
    """
    
    def __init__(self, parallel_devices: Optional[list[torch.device]] = None):
        """
        Initialize DataParallel strategy.
        
        Args:
            parallel_devices: List of devices to use for parallel training
        """
    
    def setup_module(self, module: nn.Module) -> nn.Module:
        """Wrap module with DataParallel."""
    
    def reduce(self, tensor: Tensor, *args, **kwargs) -> Tensor:
        """Reduce tensor across DataParallel devices."""

Distributed Data Parallel Strategy

PyTorch DistributedDataParallel strategy for scalable multi-GPU training.

class DDPStrategy(Strategy):
    """
    DistributedDataParallel strategy for scalable multi-GPU training.
    
    Uses PyTorch's DDP for efficient distributed training across
    multiple GPUs and nodes with gradient synchronization.
    """
    
    def __init__(
        self,
        parallel_devices: Optional[list[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[Precision] = None,
        ddp_comm_state: Optional[object] = None,
        ddp_comm_hook: Optional[Callable] = None,
        ddp_comm_wrapper: Optional[Callable] = None,
        model_averaging_period: Optional[int] = None,
        process_group_backend: Optional[str] = None,
        timeout: Optional[timedelta] = None,
        **kwargs
    ):
        """
        Initialize DDP strategy.
        
        Args:
            parallel_devices: Devices for parallel training
            cluster_environment: Cluster environment plugin
            checkpoint_io: Checkpoint I/O plugin
            precision_plugin: Precision plugin
            ddp_comm_state: DDP communication state
            ddp_comm_hook: Custom DDP communication hook
            ddp_comm_wrapper: DDP communication wrapper
            model_averaging_period: Period for model averaging
            process_group_backend: Process group backend (nccl, gloo, mpi)
            timeout: Timeout for distributed operations
        """
    
    def setup_distributed(self) -> None:
        """Initialize distributed process group."""
    
    def setup_module(self, module: nn.Module) -> nn.Module:
        """Wrap module with DistributedDataParallel."""
    
    def configure_ddp(self) -> None:
        """Configure DDP-specific settings."""

DeepSpeed Strategy

Microsoft DeepSpeed integration for large-scale model training.

class DeepSpeedStrategy(Strategy):
    """
    DeepSpeed strategy for large-scale model training.
    
    Integrates with Microsoft DeepSpeed for memory-efficient training of large
    models using ZeRO partitioning of optimizer states, gradients, and (at stage 3) parameters.
    """
    
    def __init__(
        self,
        stage: int = 2,
        remote_device: Optional[str] = None,
        offload_optimizer: bool = False,
        offload_parameters: bool = False,
        offload_params_device: str = "cpu",
        nvme_path: Optional[str] = None,
        params_buffer_count: int = 5,
        params_buffer_size: int = 100_000_000,
        max_in_cpu: int = 1_000_000_000,
        offload_optimizer_device: str = "cpu",
        optimizer_buffer_count: int = 4,
        block_size: int = 1048576,
        queue_depth: int = 8,
        single_submit: bool = False,
        overlap_events: bool = True,
        thread_count: int = 1,
        config: Optional[Union[str, dict]] = None,
        logging_level: int = logging.WARN,
        parallel_devices: Optional[list[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[Precision] = None,
        process_group_backend: Optional[str] = None,
        timeout: Optional[timedelta] = None,
        **kwargs
    ):
        """
        Initialize DeepSpeed strategy.
        
        Args:
            stage: DeepSpeed ZeRO stage (1, 2, or 3)
            remote_device: Remote device for offloading
            offload_optimizer: Whether to offload optimizer states
            offload_parameters: Whether to offload parameters
            offload_params_device: Device for parameter offloading
            nvme_path: Path to NVMe storage for offloading
            config: DeepSpeed configuration dict or path to config file
            Other args: Additional DeepSpeed configuration options
        """
    
    def setup_module_and_optimizers(
        self, 
        module: nn.Module, 
        optimizers: list[Optimizer]
    ) -> tuple[nn.Module, list[Optimizer]]:
        """Setup module and optimizers with DeepSpeed engine."""
    
    def configure_deepspeed_config(self, config: dict) -> dict:
        """Configure DeepSpeed configuration dictionary."""

FSDP Strategy

Fully Sharded Data Parallel strategy for memory-efficient large model training.

class FSDPStrategy(Strategy):
    """
    Fully Sharded Data Parallel strategy for large model training.
    
    Uses PyTorch's FSDP to shard model parameters, gradients, and 
    optimizer states across devices for memory-efficient training.
    """
    
    def __init__(
        self,
        cpu_offload: Optional[bool] = None,
        mixed_precision: Optional[MixedPrecision] = None,
        auto_wrap_policy: Optional[Callable] = None,
        activation_checkpointing: Optional[bool] = None,
        activation_checkpointing_policy: Optional[Callable] = None,
        sharding_strategy: Optional[ShardingStrategy] = None,
        state_dict_type: Optional[StateDictType] = None,
        use_orig_params: bool = False,
        limit_all_gathers: bool = True,
        sync_module_states: bool = False,
        forward_prefetch: bool = False,
        parallel_devices: Optional[list[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[Precision] = None,
        process_group_backend: Optional[str] = None,
        timeout: Optional[timedelta] = None,
        **kwargs
    ):
        """
        Initialize FSDP strategy.
        
        Args:
            cpu_offload: Whether to offload parameters and gradients to CPU
            mixed_precision: Mixed precision configuration
            auto_wrap_policy: Policy for automatic module wrapping
            activation_checkpointing: Whether to use activation checkpointing
            sharding_strategy: Parameter sharding strategy
            state_dict_type: Type of state dict for checkpointing
            use_orig_params: Whether to expose the original (unflattened) parameters to the optimizer
        """
    
    def setup_module(self, module: nn.Module) -> nn.Module:
        """Wrap module with FSDP."""
    
    def configure_fsdp_auto_wrap_policy(self, module: nn.Module) -> Optional[Callable]:
        """Configure automatic wrapping policy for FSDP."""

XLA Strategy

XLA (TPU) strategy for training on Google Cloud TPUs.

class XLAStrategy(Strategy):
    """
    XLA strategy for TPU training using PyTorch XLA.
    
    Provides TPU support with XLA compilation for high-performance
    training on Google Cloud TPU pods.
    """
    
    def __init__(
        self,
        sync_module_states: bool = True,
        parallel_devices: Optional[list[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[Precision] = None,
        debug: bool = False,
        **kwargs
    ):
        """
        Initialize XLA strategy.
        
        Args:
            sync_module_states: Whether to sync module states across TPU cores
            debug: Whether to enable XLA debug mode
        """
    
    def setup_module(self, module: nn.Module) -> nn.Module:
        """Setup module for TPU training."""
    
    def reduce(self, tensor: Tensor, *args, **kwargs) -> Tensor:
        """Reduce tensor across TPU cores using XLA collectives."""
    
    def all_gather(self, tensor: Tensor, *args, **kwargs) -> Tensor:
        """All-gather tensor across TPU cores."""
    
    def mark_step(self) -> None:
        """Mark XLA step boundary for graph compilation."""

Single Device XLA Strategy

Strategy for single XLA device training (TPU, XLA on GPU).

class SingleDeviceXLAStrategy(Strategy):
    """
    Strategy for training on a single XLA device.
    
    Optimized for a single TPU core or XLA compilation on a single GPU.
    """
    
    def __init__(
        self,
        device: Optional[torch.device] = None,
        accelerator: Optional[Accelerator] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[Precision] = None
    ):
        """Initialize single XLA device strategy."""

Model Parallel Strategy

Strategy for tensor model parallelism across multiple devices.

class ModelParallelStrategy(Strategy):
    """
    Strategy for tensor model parallelism.
    
    Splits individual model layers across multiple devices for very large models
    that don't fit on a single device.
    """
    
    def __init__(
        self,
        accelerator: Optional[Accelerator] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[Precision] = None
    ):
        """Initialize model parallel strategy."""

Parallel Strategy

Base class for multi-device parallel strategies.

class ParallelStrategy(Strategy):
    """
    Base class for parallel training strategies.
    
    Provides common functionality for strategies that distribute training
    across multiple devices or processes.
    """
    
    def __init__(
        self,
        accelerator: Optional[Accelerator] = None,
        parallel_devices: Optional[list[torch.device]] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[Precision] = None
    ):
        """Initialize parallel strategy."""

XLA FSDP Strategy

Strategy combining XLA compilation with Fully Sharded Data Parallel for TPUs.

class XLAFSDPStrategy(XLAStrategy):
    """
    Strategy combining XLA with Fully Sharded Data Parallel.
    
    Provides FSDP sharding capabilities optimized for XLA devices,
    enabling training of very large models on TPU pods.
    """
    
    def __init__(
        self,
        accelerator: Optional[Accelerator] = None,
        parallel_devices: Optional[list[torch.device]] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[Precision] = None,
        auto_wrap_policy: Optional[Callable] = None,
        **kwargs
    ):
        """Initialize XLA FSDP strategy."""

Strategy Registry

Global registry for strategy plugins.

class StrategyRegistry:
    """Registry for strategy plugins."""
    
    def register(
        self,
        name: str,
        strategy_class: type[Strategy], 
        description: Optional[str] = None
    ) -> None:
        """Register strategy class."""
    
    def get(self, name: str) -> type[Strategy]:
        """Get strategy class by name."""
    
    def available_strategies(self) -> list[str]:
        """Get list of available strategy names."""
    
    def remove(self, name: str) -> None:
        """Remove strategy from registry."""

# Global registry instance  
STRATEGY_REGISTRY: StrategyRegistry
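
The names held in this registry are what the strategy="..." strings in the examples below resolve to, and it can be queried at runtime. A short sketch against the interface documented above:

from lightning.fabric.strategies import STRATEGY_REGISTRY

# Names accepted by Fabric(strategy="...")
print(STRATEGY_REGISTRY.available_strategies())

# Resolve the entry registered under a given name
ddp_entry = STRATEGY_REGISTRY.get("ddp")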

Usage Examples

Basic Strategy Selection

from lightning.fabric import Fabric

# Automatic selection (falls back to single-device training when one device is used)
fabric = Fabric(strategy="auto")

# Data parallel (single node, multiple GPUs)
fabric = Fabric(strategy="dp", devices=4)

# Distributed data parallel
fabric = Fabric(strategy="ddp", devices=4, num_nodes=2)

DeepSpeed Configuration

# DeepSpeed ZeRO Stage 2 (the default stage)
fabric = Fabric(
    strategy="deepspeed",
    devices=8,
    precision="16-mixed"
)

# DeepSpeed with custom configuration
from lightning.fabric.strategies import DeepSpeedStrategy

deepspeed_config = {
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {"device": "cpu"},
        "offload_param": {"device": "cpu"}
    },
    "train_micro_batch_size_per_gpu": 1
}

fabric = Fabric(
    strategy=DeepSpeedStrategy(config=deepspeed_config),
    devices=8
)

FSDP Configuration

# FSDP with CPU offloading
fabric = Fabric(
    strategy="fsdp",
    devices=4,
    precision="bf16-mixed"
)

# FSDP with custom configuration
import functools

import torch
from torch.distributed.fsdp import MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

from lightning.fabric.strategies import FSDPStrategy

# transformer_auto_wrap_policy must be bound to the transformer block class
# of your model; MyTransformerBlock is a placeholder for that class.
wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={MyTransformerBlock},
)

fsdp_strategy = FSDPStrategy(
    cpu_offload=True,
    mixed_precision=MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16
    ),
    auto_wrap_policy=wrap_policy,
    activation_checkpointing=True
)

fabric = Fabric(strategy=fsdp_strategy, devices=8)

TPU Training

# XLA/TPU training
fabric = Fabric(
    accelerator="tpu", 
    strategy="xla",
    devices=8,
    precision="bf16-mixed"
)

# Mark XLA steps for optimal compilation
for batch in dataloader:
    optimizer.zero_grad()
    loss = compute_loss(model, batch)
    fabric.backward(loss)
    optimizer.step()
    
    # Mark step boundary for XLA
    if hasattr(fabric.strategy, 'mark_step'):
        fabric.strategy.mark_step()

Custom Strategy

from lightning.fabric.strategies import Strategy, STRATEGY_REGISTRY

class CustomStrategy(Strategy):
    def setup_module(self, module):
        # Custom module setup
        return module
    
    def reduce(self, tensor, *args, **kwargs):
        # Custom reduction logic
        return tensor

# Register custom strategy
STRATEGY_REGISTRY.register("custom", CustomStrategy)

# Use custom strategy
fabric = Fabric(strategy="custom")

Advanced DDP Configuration

from datetime import timedelta

from lightning.fabric.strategies import DDPStrategy

# DDP with custom settings
ddp_strategy = DDPStrategy(
    process_group_backend="nccl",
    timeout=timedelta(minutes=30),
    find_unused_parameters=False,  # Set via kwargs
    gradient_as_bucket_view=True   # Set via kwargs
)

fabric = Fabric(
    strategy=ddp_strategy,
    devices=4,
    num_nodes=2
)
