CtrlK
Blog · Docs · Log in · Get started
Tessl Logo

tessl/pypi-lightning

The Deep Learning framework to train, deploy, and ship AI products Lightning fast.

Pending
Overview
Eval results
Files

docs/strategies.md

Distributed Training Strategies

Multiple strategies for distributed and parallel training including data parallel, distributed data parallel, fully sharded data parallel, model parallel, and specialized strategies for different hardware configurations.

Capabilities

Distributed Data Parallel (DDP)

Multi-GPU and multi-node distributed training strategy that replicates the model across devices and synchronizes gradients.

class DDPStrategy:
    def __init__(
        self,
        accelerator: Optional[Accelerator] = None,
        parallel_devices: Optional[List[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
        ddp_comm_state: Optional[object] = None,
        ddp_comm_hook: Optional[Callable] = None,
        ddp_comm_wrapper: Optional[Callable] = None,
        model_averaging_period: Optional[int] = None,
        process_group_backend: Optional[str] = None,
        timeout: Optional[timedelta] = None,
        start_method: str = "popen",
        **kwargs
    ):
        """
        Initialize DDP strategy.

        Each participating process holds a full replica of the model and
        gradients are synchronized across processes during backward.

        Args:
            accelerator: Hardware accelerator to use
            parallel_devices: List of devices for parallel training
            cluster_environment: Cluster configuration
            checkpoint_io: Checkpoint I/O plugin
            precision_plugin: Precision plugin for mixed precision
            ddp_comm_state: State object passed to the DDP communication hook
            ddp_comm_hook: Custom communication hook for gradient reduction
            ddp_comm_wrapper: Wrapper applied around the communication hook
            model_averaging_period: Period (in steps) for model averaging
            process_group_backend: Backend for process group ('nccl', 'gloo')
            timeout: Timeout for distributed operations
            start_method: Method to start processes (default 'popen';
                'spawn' is shown in the advanced example below)
            **kwargs: Additional strategy-specific options forwarded to the
                underlying distributed wrapper
        """

Fully Sharded Data Parallel (FSDP)

Memory-efficient distributed training that shards model parameters, gradients, and optimizer states across devices.

class FSDPStrategy:
    def __init__(
        self,
        accelerator: Optional[Accelerator] = None,
        parallel_devices: Optional[List[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
        process_group_backend: Optional[str] = None,
        timeout: Optional[timedelta] = None,
        auto_wrap_policy: Optional[Callable] = None,
        cpu_offload: Union[bool, CPUOffload] = False,
        mixed_precision: Optional[MixedPrecision] = None,
        sharding_strategy: Union[ShardingStrategy, str] = "FULL_SHARD",
        backward_prefetch: Optional[BackwardPrefetch] = None,
        forward_prefetch: bool = False,
        limit_all_gathers: bool = True,
        use_orig_params: bool = True,
        param_init_fn: Optional[Callable] = None,
        sync_module_states: bool = False,
        **kwargs
    ):
        """
        Initialize FSDP strategy.

        Shards model parameters, gradients, and optimizer states across
        devices to reduce per-device memory usage.

        Args:
            accelerator: Hardware accelerator to use
            parallel_devices: List of devices for parallel training
            cluster_environment: Cluster configuration
            checkpoint_io: Checkpoint I/O plugin
            precision_plugin: Precision plugin
            process_group_backend: Backend for process group
            timeout: Timeout for distributed operations
            auto_wrap_policy: Policy for automatic module wrapping
            cpu_offload: Enable CPU offloading of parameters; accepts a bool
                or a ``CPUOffload`` configuration object
            mixed_precision: Mixed precision configuration
            sharding_strategy: Strategy for parameter sharding; accepts a
                ``ShardingStrategy`` member or its string name
                (default ``"FULL_SHARD"``)
            backward_prefetch: Prefetch strategy for the backward pass
            forward_prefetch: Enable forward prefetching
            limit_all_gathers: Limit concurrent all-gather operations
            use_orig_params: Use original (unflattened) parameters
            param_init_fn: Parameter initialization function
            sync_module_states: Synchronize module states at setup
            **kwargs: Additional options forwarded to the underlying FSDP
                wrapper
        """

DeepSpeed Integration

Integration with Microsoft DeepSpeed for memory-efficient training of large models with advanced optimization techniques.

class DeepSpeedStrategy:
    def __init__(
        self,
        accelerator: Optional[Accelerator] = None,
        zero_optimization: bool = True,
        stage: int = 2,
        remote_device: Optional[str] = None,
        offload_optimizer: bool = False,
        offload_parameters: bool = False,
        offload_params_device: str = "cpu",
        nvme_path: str = "/local_nvme",
        params_buffer_count: int = 5,
        params_buffer_size: int = 100_000_000,
        max_in_cpu: int = 1_000_000_000,
        offload_optimizer_device: str = "cpu",
        optimizer_buffer_count: int = 4,
        block_size: int = 1048576,
        queue_depth: int = 8,
        single_submit: bool = False,
        overlap_events: bool = True,
        thread_count: int = 1,
        pin_memory: bool = False,
        sub_group_size: int = 1_000_000_000_000,
        cpu_checkpointing: bool = False,
        contiguous_gradients: bool = True,
        overlap_comm: bool = True,
        allgather_partitions: bool = True,
        reduce_scatter: bool = True,
        allgather_bucket_size: int = 200_000_000,
        reduce_bucket_size: int = 200_000_000,
        zero_allow_untested_optimizer: bool = True,
        logging_batch_size_per_gpu: str = "auto",
        config: Optional[Union[Path, str, Dict]] = None,
        logging_level: int = logging.WARN,
        parallel_devices: Optional[List[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
        process_group_backend: Optional[str] = None,
        **kwargs
    ):
        """
        Initialize DeepSpeed strategy.

        Note: when ``config`` is given, it is expected to take precedence
        over the individual keyword settings — confirm against the
        DeepSpeed integration docs.

        Args:
            accelerator: Hardware accelerator to use
            zero_optimization: Enable ZeRO optimization
            stage: ZeRO stage (1, 2, or 3)
            remote_device: Remote device for parameter storage
            offload_optimizer: Offload optimizer to CPU
            offload_parameters: Offload parameters to CPU
            offload_params_device: Device for parameter offloading
            nvme_path: Path for NVMe offloading
            params_buffer_count: Number of parameter buffers
            params_buffer_size: Size of parameter buffers
            max_in_cpu: Maximum parameters in CPU memory
            offload_optimizer_device: Device for optimizer offloading
            optimizer_buffer_count: Number of optimizer buffers
            block_size: I/O block size for NVMe offloading
            queue_depth: I/O queue depth for NVMe offloading
            single_submit: Submit I/O requests one at a time
            overlap_events: Overlap I/O events
            thread_count: Number of I/O threads
            pin_memory: Use pinned (page-locked) CPU memory for offloading
            sub_group_size: Sub-group size for ZeRO parameter partitioning
            cpu_checkpointing: Offload activation checkpoints to CPU
            contiguous_gradients: Keep gradients in a contiguous buffer
            overlap_comm: Overlap communication with computation
            allgather_partitions: Use all-gather on partitioned parameters
            reduce_scatter: Use reduce-scatter for gradient reduction
            allgather_bucket_size: Bucket size (elements) for all-gather
            reduce_bucket_size: Bucket size (elements) for reduction
            zero_allow_untested_optimizer: Allow optimizers not validated
                with ZeRO
            logging_batch_size_per_gpu: Batch size per GPU reported in
                DeepSpeed logs ('auto' to infer)
            config: DeepSpeed configuration file or dictionary
            logging_level: Logging level for DeepSpeed
            parallel_devices: List of devices for parallel training
            cluster_environment: Cluster configuration
            checkpoint_io: Checkpoint I/O plugin
            precision_plugin: Precision plugin
            process_group_backend: Backend for process group
            **kwargs: Additional options forwarded to DeepSpeed
        """

Data Parallel Strategy

Simple data parallelism that replicates the model on multiple devices and averages gradients.

class DataParallelStrategy:
    def __init__(
        self,
        accelerator: Optional[Accelerator] = None,
        parallel_devices: Optional[List[torch.device]] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None
    ):
        """
        Initialize DataParallel strategy.

        Single-process data parallelism: the model is replicated on each
        device and gradients are averaged across replicas.

        Args:
            accelerator: Hardware accelerator to use
            parallel_devices: List of devices for parallel training
            checkpoint_io: Checkpoint I/O plugin
            precision_plugin: Precision plugin
        """

Single Device Strategy

Strategy for training on a single device (CPU or GPU).

class SingleDeviceStrategy:
    def __init__(
        self,
        device: torch.device,
        accelerator: Optional[Accelerator] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None
    ):
        """
        Initialize single device strategy (CPU or one GPU); no distributed
        communication is involved.

        Args:
            device: Device to use for training (required, positional)
            accelerator: Hardware accelerator to use
            checkpoint_io: Checkpoint I/O plugin
            precision_plugin: Precision plugin
        """

XLA Strategies

Strategies for Google TPU training using XLA compilation.

class XLAStrategy:
    def __init__(
        self,
        accelerator: Optional[Accelerator] = None,
        parallel_devices: Optional[List[torch.device]] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
        debug: bool = False,
        sync_module_states: bool = True
    ):
        """
        Initialize XLA strategy for multi-TPU training.

        Args:
            accelerator: XLA accelerator
            parallel_devices: List of TPU devices
            checkpoint_io: Checkpoint I/O plugin
            precision_plugin: Precision plugin
            debug: Enable XLA debug mode
            sync_module_states: Synchronize module states across TPU
                processes at setup (default True)
        """

class SingleDeviceXLAStrategy:
    def __init__(
        self,
        device: torch.device,
        accelerator: Optional[Accelerator] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
        debug: bool = False
    ):
        """
        Initialize single TPU device strategy.

        Args:
            device: TPU device to use (required, positional)
            accelerator: XLA accelerator
            checkpoint_io: Checkpoint I/O plugin
            precision_plugin: Precision plugin
            debug: Enable XLA debug mode
        """

class XLAFSDPStrategy:
    def __init__(
        self,
        accelerator: Optional[Accelerator] = None,
        parallel_devices: Optional[List[torch.device]] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
        **kwargs
    ):
        """
        Initialize XLA FSDP strategy combining XLA with fully sharded data parallel.

        Args:
            accelerator: XLA accelerator
            parallel_devices: List of TPU devices
            checkpoint_io: Checkpoint I/O plugin
            precision_plugin: Precision plugin
            **kwargs: Additional options forwarded to the underlying XLA
                FSDP wrapper
        """

Model Parallel Strategy

Strategy for model parallelism where different parts of the model are placed on different devices.

class ModelParallelStrategy:
    def __init__(
        self,
        accelerator: Optional[Accelerator] = None,
        parallel_devices: Optional[List[torch.device]] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None
    ):
        """
        Initialize model parallel strategy.

        Different parts of the model are placed on different devices rather
        than replicating the full model on each device.

        Args:
            accelerator: Hardware accelerator to use
            parallel_devices: List of devices for model placement
            checkpoint_io: Checkpoint I/O plugin
            precision_plugin: Precision plugin
        """

Base Strategy Classes

Base classes for creating custom training strategies.

class Strategy:
    def __init__(
        self,
        accelerator: Optional[Accelerator] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None
    ):
        """
        Base strategy class; subclass to implement a custom training
        strategy.

        Args:
            accelerator: Hardware accelerator
            checkpoint_io: Checkpoint I/O plugin
            precision_plugin: Precision plugin
        """

    def setup_environment(self) -> None:
        """Set up the training environment (called before setup)."""

    def setup(self, trainer: Trainer) -> None:
        """Set up the strategy with the given trainer."""

    def teardown(self) -> None:
        """Clean up the strategy (releases resources acquired in setup)."""

class ParallelStrategy(Strategy):
    def __init__(
        self,
        accelerator: Optional[Accelerator] = None,
        parallel_devices: Optional[List[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None
    ):
        """
        Base parallel strategy class; adds multi-process rank bookkeeping
        and collective operations on top of ``Strategy``.

        Args:
            accelerator: Hardware accelerator
            parallel_devices: List of devices for parallel training
            cluster_environment: Cluster configuration
            checkpoint_io: Checkpoint I/O plugin
            precision_plugin: Precision plugin
        """

    @property
    def global_rank(self) -> int:
        """Global rank of the current process across all nodes."""

    @property
    def local_rank(self) -> int:
        """Local rank of the current process within its node."""

    @property
    def world_size(self) -> int:
        """Total number of processes participating in training."""

    def all_gather(self, tensor: torch.Tensor, sync_grads: bool = False) -> torch.Tensor:
        """Gather the tensor from all processes; ``sync_grads`` controls
        whether gradients flow through the gather."""

    def all_reduce(self, tensor: torch.Tensor, reduce_op: str = "mean") -> torch.Tensor:
        """Reduce the tensor across all processes using ``reduce_op``
        (default 'mean')."""

    def broadcast(self, tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
        """Broadcast the tensor from the source rank (default 0) to all
        processes."""

    def barrier(self, name: Optional[str] = None) -> None:
        """Block until all processes reach this point; ``name`` optionally
        labels the barrier."""

Usage Examples

Basic Strategy Usage

from lightning import Trainer

# Use DDP strategy
# Passing the registered alias "ddp" selects DDPStrategy with defaults.
trainer = Trainer(
    accelerator="gpu",
    devices=4,
    strategy="ddp"
)

# Use FSDP strategy  
# Likewise, "fsdp" selects FSDPStrategy with defaults.
trainer = Trainer(
    accelerator="gpu", 
    devices=8,
    strategy="fsdp"
)

Advanced Strategy Configuration

from lightning import Trainer
from lightning.pytorch.strategies import DDPStrategy, FSDPStrategy
from datetime import timedelta

# Configure DDP with custom settings
# (a strategy instance replaces the string alias used in the basic example)
ddp_strategy = DDPStrategy(
    process_group_backend="nccl",
    timeout=timedelta(seconds=1800),
    start_method="spawn"
)

trainer = Trainer(
    accelerator="gpu",
    devices=4,
    strategy=ddp_strategy,
    precision="16-mixed"
)

# Configure FSDP with CPU offloading
from torch.distributed.fsdp import CPUOffload, ShardingStrategy

fsdp_strategy = FSDPStrategy(
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    cpu_offload=CPUOffload(offload_params=True),
    mixed_precision=None,  # Let Lightning handle precision
    auto_wrap_policy=None  # Use default wrapping
)

trainer = Trainer(
    accelerator="gpu",
    devices=8,
    strategy=fsdp_strategy,
    precision="bf16-mixed"
)

DeepSpeed Configuration

from lightning import Trainer
from lightning.pytorch.strategies import DeepSpeedStrategy

# DeepSpeed ZeRO Stage 3 with offloading
# (stage 3 shards parameters as well as gradients and optimizer states)
deepspeed_strategy = DeepSpeedStrategy(
    stage=3,
    offload_optimizer=True,
    offload_parameters=True,
    remote_device="nvme",
    nvme_path="/local_nvme"
)

trainer = Trainer(
    accelerator="gpu",
    devices=8,
    strategy=deepspeed_strategy,
    precision="16-mixed"
)

# DeepSpeed with custom config file
# (a JSON config file takes the place of individual keyword settings)
trainer = Trainer(
    accelerator="gpu",
    devices=8,
    strategy=DeepSpeedStrategy(config="deepspeed_config.json"),
    precision="16-mixed"
)

Install with Tessl CLI

npx tessl i tessl/pypi-lightning

docs

accelerators.md

callbacks.md

core-training.md

data.md

fabric.md

index.md

loggers.md

precision.md

profilers.md

strategies.md

tile.json