The Deep Learning framework to train, deploy, and ship AI products Lightning fast.
Multiple strategies for distributed and parallel training, including data parallel, distributed data parallel (DDP), fully sharded data parallel (FSDP), DeepSpeed, model parallel, and specialized strategies for hardware such as TPUs.
Multi-GPU and multi-node distributed training strategy that replicates the model across devices and synchronizes gradients.
class DDPStrategy:
def __init__(
self,
accelerator: Optional[Accelerator] = None,
parallel_devices: Optional[List[torch.device]] = None,
cluster_environment: Optional[ClusterEnvironment] = None,
checkpoint_io: Optional[CheckpointIO] = None,
precision_plugin: Optional[PrecisionPlugin] = None,
ddp_comm_state: Optional[object] = None,
ddp_comm_hook: Optional[Callable] = None,
ddp_comm_wrapper: Optional[Callable] = None,
model_averaging_period: Optional[int] = None,
process_group_backend: Optional[str] = None,
timeout: Optional[timedelta] = None,
start_method: str = "popen",
**kwargs
):
"""
Initialize DDP strategy.
Args:
accelerator: Hardware accelerator to use
parallel_devices: List of devices for parallel training
cluster_environment: Cluster configuration
checkpoint_io: Checkpoint I/O plugin
precision_plugin: Precision plugin for mixed precision
ddp_comm_state: State object passed to the DDP communication hook
ddp_comm_hook: Custom communication hook for gradient reduction
ddp_comm_wrapper: Wrapper applied around the communication hook
model_averaging_period: Number of steps between parameter averaging (used with post-local SGD)
process_group_backend: Backend for process group ('nccl', 'gloo')
timeout: Timeout for distributed operations
start_method: Method used to launch processes ("popen", "spawn", "fork", or "forkserver")
"""Memory-efficient distributed training that shards model parameters, gradients, and optimizer states across devices.
class FSDPStrategy:
def __init__(
self,
accelerator: Optional[Accelerator] = None,
parallel_devices: Optional[List[torch.device]] = None,
cluster_environment: Optional[ClusterEnvironment] = None,
checkpoint_io: Optional[CheckpointIO] = None,
precision_plugin: Optional[PrecisionPlugin] = None,
process_group_backend: Optional[str] = None,
timeout: Optional[timedelta] = None,
auto_wrap_policy: Optional[Callable] = None,
cpu_offload: Union[bool, CPUOffload] = False,
mixed_precision: Optional[MixedPrecision] = None,
sharding_strategy: Union[ShardingStrategy, str] = "FULL_SHARD",
backward_prefetch: Optional[BackwardPrefetch] = None,
forward_prefetch: bool = False,
limit_all_gathers: bool = True,
use_orig_params: bool = True,
param_init_fn: Optional[Callable] = None,
sync_module_states: bool = False,
**kwargs
):
"""
Initialize FSDP strategy.
Args:
accelerator: Hardware accelerator to use
parallel_devices: List of devices for parallel training
cluster_environment: Cluster configuration
checkpoint_io: Checkpoint I/O plugin
precision_plugin: Precision plugin
process_group_backend: Backend for process group
timeout: Timeout for distributed operations
auto_wrap_policy: Policy for automatic module wrapping
cpu_offload: Enable CPU offloading of parameters
mixed_precision: Mixed precision configuration
sharding_strategy: Strategy for parameter sharding
backward_prefetch: Prefetch strategy for backward pass
forward_prefetch: Enable forward prefetching
limit_all_gathers: Rate-limit in-flight all-gather operations to reduce peak memory
use_orig_params: Expose the original (non-flattened) parameters to the optimizer
param_init_fn: Function used to initialize parameters of wrapped modules
sync_module_states: Broadcast module parameters and buffers from rank 0 at startup
"""Integration with Microsoft DeepSpeed for memory-efficient training of large models with advanced optimization techniques.
class DeepSpeedStrategy:
def __init__(
self,
accelerator: Optional[Accelerator] = None,
zero_optimization: bool = True,
stage: int = 2,
remote_device: Optional[str] = None,
offload_optimizer: bool = False,
offload_parameters: bool = False,
offload_params_device: str = "cpu",
nvme_path: str = "/local_nvme",
params_buffer_count: int = 5,
params_buffer_size: int = 100_000_000,
max_in_cpu: int = 1_000_000_000,
offload_optimizer_device: str = "cpu",
optimizer_buffer_count: int = 4,
block_size: int = 1048576,
queue_depth: int = 8,
single_submit: bool = False,
overlap_events: bool = True,
thread_count: int = 1,
pin_memory: bool = False,
sub_group_size: int = 1_000_000_000_000,
cpu_checkpointing: bool = False,
contiguous_gradients: bool = True,
overlap_comm: bool = True,
allgather_partitions: bool = True,
reduce_scatter: bool = True,
allgather_bucket_size: int = 200_000_000,
reduce_bucket_size: int = 200_000_000,
zero_allow_untested_optimizer: bool = True,
logging_batch_size_per_gpu: str = "auto",
config: Optional[Union[Path, str, Dict]] = None,
logging_level: int = logging.WARN,
parallel_devices: Optional[List[torch.device]] = None,
cluster_environment: Optional[ClusterEnvironment] = None,
checkpoint_io: Optional[CheckpointIO] = None,
precision_plugin: Optional[PrecisionPlugin] = None,
process_group_backend: Optional[str] = None,
**kwargs
):
"""
Initialize DeepSpeed strategy.
Args:
accelerator: Hardware accelerator to use
zero_optimization: Enable ZeRO optimization
stage: ZeRO stage (1, 2, or 3)
remote_device: Remote device for parameter storage
offload_optimizer: Offload optimizer to CPU
offload_parameters: Offload parameters to CPU
offload_params_device: Device for parameter offloading
nvme_path: Path for NVMe offloading
params_buffer_count: Number of parameter buffers
params_buffer_size: Size of parameter buffers
max_in_cpu: Maximum parameters in CPU memory
offload_optimizer_device: Device for optimizer offloading
optimizer_buffer_count: Number of buffers used when offloading optimizer state to NVMe
block_size: I/O block size in bytes for NVMe offloading
queue_depth: I/O queue depth for NVMe offloading
single_submit: Submit I/O requests one at a time instead of in batches
overlap_events: Submit I/O requests asynchronously so they overlap with other work
thread_count: Number of intra-request I/O threads
pin_memory: Offload to page-locked (pinned) CPU memory
sub_group_size: Tile size (in elements) for breaking up very large parameter operations in stage 3
cpu_checkpointing: Offload activation checkpoints to CPU
contiguous_gradients: Copy gradients into a contiguous buffer as they are produced
overlap_comm: Overlap gradient reduction with backward computation
allgather_partitions: Use all-gather collectives (rather than broadcast) to collect updated parameters
reduce_scatter: Use reduce-scatter instead of all-reduce to average gradients
allgather_bucket_size: Number of elements all-gathered at a time
reduce_bucket_size: Number of elements reduced at a time
zero_allow_untested_optimizer: Allow optimizers that have not been validated with ZeRO
logging_batch_size_per_gpu: Per-GPU batch size used when DeepSpeed logs throughput
config: DeepSpeed configuration file or dictionary
logging_level: Logging level for DeepSpeed
parallel_devices: List of devices for parallel training
cluster_environment: Cluster configuration
checkpoint_io: Checkpoint I/O plugin
precision_plugin: Precision plugin
process_group_backend: Backend for process group
"""Simple data parallelism that replicates the model on multiple devices and averages gradients.
class DataParallelStrategy:
def __init__(
self,
accelerator: Optional[Accelerator] = None,
parallel_devices: Optional[List[torch.device]] = None,
checkpoint_io: Optional[CheckpointIO] = None,
precision_plugin: Optional[PrecisionPlugin] = None
):
"""
Initialize DataParallel strategy.
Args:
accelerator: Hardware accelerator to use
parallel_devices: List of devices for parallel training
checkpoint_io: Checkpoint I/O plugin
precision_plugin: Precision plugin
"""Strategy for training on a single device (CPU or GPU).
class SingleDeviceStrategy:
def __init__(
self,
device: torch.device,
accelerator: Optional[Accelerator] = None,
checkpoint_io: Optional[CheckpointIO] = None,
precision_plugin: Optional[PrecisionPlugin] = None
):
"""
Initialize single device strategy.
Args:
device: Device to use for training
accelerator: Hardware accelerator to use
checkpoint_io: Checkpoint I/O plugin
precision_plugin: Precision plugin
"""Strategies for Google TPU training using XLA compilation.
class XLAStrategy:
def __init__(
self,
accelerator: Optional[Accelerator] = None,
parallel_devices: Optional[List[torch.device]] = None,
checkpoint_io: Optional[CheckpointIO] = None,
precision_plugin: Optional[PrecisionPlugin] = None,
debug: bool = False,
sync_module_states: bool = True
):
"""
Initialize XLA strategy for multi-TPU training.
Args:
accelerator: XLA accelerator
parallel_devices: List of TPU devices
checkpoint_io: Checkpoint I/O plugin
precision_plugin: Precision plugin
debug: Enable debug mode
sync_module_states: Synchronize module states
"""
class SingleDeviceXLAStrategy:
def __init__(
self,
device: torch.device,
accelerator: Optional[Accelerator] = None,
checkpoint_io: Optional[CheckpointIO] = None,
precision_plugin: Optional[PrecisionPlugin] = None,
debug: bool = False
):
"""
Initialize single TPU device strategy.
Args:
device: TPU device to use
accelerator: XLA accelerator
checkpoint_io: Checkpoint I/O plugin
precision_plugin: Precision plugin
debug: Enable debug mode
"""
class XLAFSDPStrategy:
def __init__(
self,
accelerator: Optional[Accelerator] = None,
parallel_devices: Optional[List[torch.device]] = None,
checkpoint_io: Optional[CheckpointIO] = None,
precision_plugin: Optional[PrecisionPlugin] = None,
**kwargs
):
"""
Initialize XLA FSDP strategy combining XLA with fully sharded data parallel.
Args:
accelerator: XLA accelerator
parallel_devices: List of TPU devices
checkpoint_io: Checkpoint I/O plugin
precision_plugin: Precision plugin
"""Strategy for model parallelism where different parts of the model are placed on different devices.
class ModelParallelStrategy:
def __init__(
self,
accelerator: Optional[Accelerator] = None,
parallel_devices: Optional[List[torch.device]] = None,
checkpoint_io: Optional[CheckpointIO] = None,
precision_plugin: Optional[PrecisionPlugin] = None
):
"""
Initialize model parallel strategy.
Args:
accelerator: Hardware accelerator to use
parallel_devices: List of devices for model placement
checkpoint_io: Checkpoint I/O plugin
precision_plugin: Precision plugin
"""Base classes for creating custom training strategies.
class Strategy:
def __init__(
self,
accelerator: Optional[Accelerator] = None,
checkpoint_io: Optional[CheckpointIO] = None,
precision_plugin: Optional[PrecisionPlugin] = None
):
"""
Base strategy class.
Args:
accelerator: Hardware accelerator
checkpoint_io: Checkpoint I/O plugin
precision_plugin: Precision plugin
"""
def setup_environment(self) -> None:
"""Set up the training environment."""
def setup(self, trainer: Trainer) -> None:
"""Set up the strategy with trainer."""
def teardown(self) -> None:
"""Clean up the strategy."""
class ParallelStrategy(Strategy):
def __init__(
self,
accelerator: Optional[Accelerator] = None,
parallel_devices: Optional[List[torch.device]] = None,
cluster_environment: Optional[ClusterEnvironment] = None,
checkpoint_io: Optional[CheckpointIO] = None,
precision_plugin: Optional[PrecisionPlugin] = None
):
"""
Base parallel strategy class.
Args:
accelerator: Hardware accelerator
parallel_devices: List of devices for parallel training
cluster_environment: Cluster configuration
checkpoint_io: Checkpoint I/O plugin
precision_plugin: Precision plugin
"""
@property
def global_rank(self) -> int:
"""Global rank of the current process."""
@property
def local_rank(self) -> int:
"""Local rank of the current process."""
@property
def world_size(self) -> int:
"""Total number of processes."""
def all_gather(self, tensor: torch.Tensor, sync_grads: bool = False) -> torch.Tensor:
"""Gather tensor from all processes."""
def all_reduce(self, tensor: torch.Tensor, reduce_op: str = "mean") -> torch.Tensor:
"""Reduce tensor across all processes."""
def broadcast(self, tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
"""Broadcast tensor from source to all processes."""
def barrier(self, name: Optional[str] = None) -> None:
"""Synchronize all processes."""from lightning import Trainer
Select a built-in strategy by name when constructing the Trainer:
from lightning import Trainer

# Use DDP strategy
trainer = Trainer(
accelerator="gpu",
devices=4,
strategy="ddp"
)
# Use FSDP strategy
trainer = Trainer(
accelerator="gpu",
devices=8,
strategy="fsdp"
)
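The other strategies documented above can be selected the same way; a hedged sketch continuing the snippet, assuming Lightning's standard registry aliases such as "dp" and "deepspeed_stage_3":
# DataParallel on a single node with multiple GPUs
trainer = Trainer(accelerator="gpu", devices=2, strategy="dp")

# Single-device training needs no distributed strategy at all
trainer = Trainer(accelerator="gpu", devices=1)

# DeepSpeed ZeRO stage 3 selected by alias
trainer = Trainer(accelerator="gpu", devices=8, strategy="deepspeed_stage_3")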
Pass a configured strategy instance when you need finer control:
from lightning import Trainer
from lightning.pytorch.strategies import DDPStrategy, FSDPStrategy
from datetime import timedelta
# Configure DDP with custom settings
ddp_strategy = DDPStrategy(
process_group_backend="nccl",
timeout=timedelta(seconds=1800),
start_method="spawn"
)
trainer = Trainer(
accelerator="gpu",
devices=4,
strategy=ddp_strategy,
precision="16-mixed"
)
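The ddp_comm_hook parameter accepts the hooks that ship with PyTorch; a hedged sketch, continuing the snippet above, that uses the built-in fp16 compression hook to cut gradient communication volume:
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default

# Compress gradients to fp16 during all-reduce to reduce communication bandwidth
compress_strategy = DDPStrategy(ddp_comm_hook=default.fp16_compress_hook)
trainer = Trainer(accelerator="gpu", devices=4, strategy=compress_strategy)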
# Configure FSDP with CPU offloading
from torch.distributed.fsdp import CPUOffload, ShardingStrategy
fsdp_strategy = FSDPStrategy(
sharding_strategy=ShardingStrategy.FULL_SHARD,
cpu_offload=CPUOffload(offload_params=True),
mixed_precision=None, # Let Lightning handle precision
auto_wrap_policy=None # Use default wrapping
)
trainer = Trainer(
accelerator="gpu",
devices=8,
strategy=fsdp_strategy,
precision="bf16-mixed"
)
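auto_wrap_policy controls which submodules FSDP wraps into their own shards; a hedged sketch, continuing the snippet above, that uses PyTorch's size-based policy (the 100M-parameter threshold is an arbitrary illustration):
import functools
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

# Wrap any submodule with more than ~100M parameters into its own FSDP unit
wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=100_000_000)
sharded_strategy = FSDPStrategy(auto_wrap_policy=wrap_policy)
trainer = Trainer(
    accelerator="gpu",
    devices=8,
    strategy=sharded_strategy,
    precision="bf16-mixed",
)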
Use DeepSpeed for memory-efficient training of very large models:
from lightning import Trainer
from lightning.pytorch.strategies import DeepSpeedStrategy
# DeepSpeed ZeRO Stage 3 with offloading
deepspeed_strategy = DeepSpeedStrategy(
stage=3,
offload_optimizer=True,
offload_parameters=True,
remote_device="nvme",
nvme_path="/local_nvme"
)
trainer = Trainer(
accelerator="gpu",
devices=8,
strategy=deepspeed_strategy,
precision="16-mixed"
)
# DeepSpeed with custom config file
trainer = Trainer(
accelerator="gpu",
devices=8,
strategy=DeepSpeedStrategy(config="deepspeed_config.json"),
precision="16-mixed"
)
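The config can also be supplied as a dictionary instead of a file; a hedged sketch of a minimal ZeRO stage 3 config with CPU optimizer offload (keys follow DeepSpeed's documented schema, values are illustrative):
ds_config = {
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {"device": "cpu", "pin_memory": True},
    },
}
trainer = Trainer(
    accelerator="gpu",
    devices=8,
    strategy=DeepSpeedStrategy(config=ds_config),
    precision="16-mixed",
)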
Install with Tessl CLI
npx tessl i tessl/pypi-lightning