"""Fabric is the fast and lightweight way to scale PyTorch models without boilerplate.

Distributed training strategies that define how models and data are distributed
across devices and processes.
"""

# Abstract base class defining the strategy interface for distributed training.
class Strategy:
    """Abstract base class for distributed training strategies.

    Strategies define how models, optimizers, and data are distributed
    across devices and processes for parallel training. Subclasses override
    the hooks below; the base implementations are stubs.
    """

    def setup_environment(self) -> None:
        """Setup the distributed training environment."""

    def setup_module(self, module: nn.Module) -> nn.Module:
        """Setup module for distributed training."""

    def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
        """Setup optimizer for distributed training."""

    def module_to_device(self, module: nn.Module) -> None:
        """Move module to appropriate device(s)."""

    def reduce(
        self,
        tensor: Tensor,
        group: Optional[Any] = None,
        reduce_op: Optional[ReduceOp] = None,
    ) -> Tensor:
        """Reduce tensor across processes."""

    def all_gather(
        self,
        tensor: Tensor,
        group: Optional[Any] = None,
        sync_grads: bool = False,
    ) -> Tensor:
        """All-gather tensor across processes."""

    def broadcast(self, tensor: Tensor, src: int = 0) -> Tensor:
        """Broadcast tensor from source process."""

    def barrier(self, name: Optional[str] = None) -> None:
        """Synchronize all processes."""

    def teardown(self) -> None:
        """Clean up strategy resources."""


# Strategy for training on a single device (CPU or GPU).
class SingleDeviceStrategy(Strategy):
    """Strategy for single device training.

    Handles training on a single CPU or GPU without distributed communication.
    """

    def __init__(self, device: Optional[torch.device] = None):
        """Initialize single device strategy.

        Args:
            device: Target device for training
        """

    def setup_module(self, module: nn.Module) -> nn.Module:
        """Move module to target device."""

    def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
        """Return optimizer as-is (no distribution needed)."""


# PyTorch DataParallel strategy for single-node multi-GPU training.
class DataParallelStrategy(Strategy):
    """DataParallel strategy for single-node multi-GPU training.

    Uses PyTorch's DataParallel for simple multi-GPU training on a single
    node. Limited scalability compared to DistributedDataParallel.
    """

    def __init__(self, parallel_devices: Optional[list[torch.device]] = None):
        """Initialize DataParallel strategy.

        Args:
            parallel_devices: List of devices to use for parallel training
        """

    def setup_module(self, module: nn.Module) -> nn.Module:
        """Wrap module with DataParallel."""

    def reduce(self, tensor: Tensor, *args, **kwargs) -> Tensor:
        """Reduce tensor across DataParallel devices."""


# PyTorch DistributedDataParallel strategy for scalable multi-GPU training.
class DDPStrategy(Strategy):
    """DistributedDataParallel strategy for scalable multi-GPU training.

    Uses PyTorch's DDP for efficient distributed training across
    multiple GPUs and nodes with gradient synchronization.
    """

    def __init__(
        self,
        parallel_devices: Optional[list[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[Precision] = None,
        ddp_comm_state: Optional[object] = None,
        # NOTE: `Callable` (not the `callable` builtin) for consistency with
        # the annotations used elsewhere in this module (e.g. XLAFSDPStrategy).
        ddp_comm_hook: Optional[Callable] = None,
        ddp_comm_wrapper: Optional[Callable] = None,
        model_averaging_period: Optional[int] = None,
        process_group_backend: Optional[str] = None,
        timeout: Optional[timedelta] = None,
        **kwargs,
    ):
        """Initialize DDP strategy.

        Args:
            parallel_devices: Devices for parallel training
            cluster_environment: Cluster environment plugin
            checkpoint_io: Checkpoint I/O plugin
            precision_plugin: Precision plugin
            ddp_comm_state: DDP communication state
            ddp_comm_hook: Custom DDP communication hook
            ddp_comm_wrapper: DDP communication wrapper
            model_averaging_period: Period for model averaging
            process_group_backend: Process group backend (nccl, gloo, mpi)
            timeout: Timeout for distributed operations
            **kwargs: Extra options forwarded to the DDP wrapper
                (e.g. ``find_unused_parameters``, ``gradient_as_bucket_view``)
        """

    def setup_distributed(self) -> None:
        """Initialize distributed process group."""

    def setup_module(self, module: nn.Module) -> nn.Module:
        """Wrap module with DistributedDataParallel."""

    def configure_ddp(self) -> None:
        """Configure DDP-specific settings."""


# Microsoft DeepSpeed integration for large-scale model training.
class DeepSpeedStrategy(Strategy):
    """DeepSpeed strategy for large-scale model training.

    Integrates with Microsoft DeepSpeed for memory-efficient training
    of large models using ZeRO optimizer states and gradients partitioning.
    """

    def __init__(
        self,
        stage: int = 2,
        remote_device: Optional[str] = None,
        offload_optimizer: bool = False,
        offload_parameters: bool = False,
        offload_params_device: str = "cpu",
        nvme_path: Optional[str] = None,
        params_buffer_count: int = 5,
        params_buffer_size: int = 100_000_000,
        max_in_cpu: int = 1_000_000_000,
        offload_optimizer_device: str = "cpu",
        optimizer_buffer_count: int = 4,
        block_size: int = 1048576,
        queue_depth: int = 8,
        single_submit: bool = False,
        overlap_events: bool = True,
        thread_count: int = 1,
        config: Optional[Union[str, dict]] = None,
        # logging.WARNING is the documented name; WARN is an undocumented
        # alias with the same value (30).
        logging_level: int = logging.WARNING,
        parallel_devices: Optional[list[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[Precision] = None,
        process_group_backend: Optional[str] = None,
        timeout: Optional[timedelta] = None,
        **kwargs,
    ):
        """Initialize DeepSpeed strategy.

        Args:
            stage: DeepSpeed ZeRO stage (1, 2, or 3)
            remote_device: Remote device for offloading
            offload_optimizer: Whether to offload optimizer states
            offload_parameters: Whether to offload parameters
            offload_params_device: Device for parameter offloading
            nvme_path: Path to NVMe storage for offloading
            config: DeepSpeed configuration dict or path to config file
            logging_level: Logging level for DeepSpeed output
            Other args: Additional DeepSpeed configuration options
                (buffer sizing, NVMe I/O tuning, process-group setup)
        """

    def setup_module_and_optimizers(
        self,
        module: nn.Module,
        optimizers: list[Optimizer],
    ) -> tuple[nn.Module, list[Optimizer]]:
        """Setup module and optimizers with DeepSpeed engine."""

    def configure_deepspeed_config(self, config: dict) -> dict:
        """Configure DeepSpeed configuration dictionary."""


# Fully Sharded Data Parallel strategy for memory-efficient large model training.
class FSDPStrategy(Strategy):
    """Fully Sharded Data Parallel strategy for large model training.

    Uses PyTorch's FSDP to shard model parameters, gradients, and
    optimizer states across devices for memory-efficient training.
    """

    def __init__(
        self,
        cpu_offload: Optional[bool] = None,
        mixed_precision: Optional[MixedPrecision] = None,
        # NOTE: `Callable` (not the `callable` builtin) for consistency with
        # the annotations used elsewhere in this module.
        auto_wrap_policy: Optional[Callable] = None,
        activation_checkpointing: Optional[bool] = None,
        activation_checkpointing_policy: Optional[Callable] = None,
        sharding_strategy: Optional[ShardingStrategy] = None,
        state_dict_type: Optional[StateDictType] = None,
        use_orig_params: bool = False,
        limit_all_gathers: bool = True,
        sync_module_states: bool = False,
        forward_prefetch: bool = False,
        parallel_devices: Optional[list[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[Precision] = None,
        process_group_backend: Optional[str] = None,
        timeout: Optional[timedelta] = None,
        **kwargs,
    ):
        """Initialize FSDP strategy.

        Args:
            cpu_offload: Whether to offload parameters and gradients to CPU
            mixed_precision: Mixed precision configuration
            auto_wrap_policy: Policy for automatic module wrapping
            activation_checkpointing: Whether to use activation checkpointing
            sharding_strategy: Parameter sharding strategy
            state_dict_type: Type of state dict for checkpointing
            use_orig_params: Whether to use original parameter names
        """

    def setup_module(self, module: nn.Module) -> nn.Module:
        """Wrap module with FSDP."""

    def configure_fsdp_auto_wrap_policy(self, module: nn.Module) -> Optional[Callable]:
        """Configure automatic wrapping policy for FSDP."""


# XLA (TPU) strategy for training on Google Cloud TPUs.
class XLAStrategy(Strategy):
    """XLA strategy for TPU training using PyTorch XLA.

    Provides TPU support with XLA compilation for high-performance
    training on Google Cloud TPU pods.
    """

    def __init__(
        self,
        sync_module_states: bool = True,
        parallel_devices: Optional[list[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[Precision] = None,
        debug: bool = False,
        **kwargs,
    ):
        """Initialize XLA strategy.

        Args:
            sync_module_states: Whether to sync module states across TPU cores
            debug: Whether to enable XLA debug mode
        """

    def setup_module(self, module: nn.Module) -> nn.Module:
        """Setup module for TPU training."""

    def reduce(self, tensor: Tensor, *args, **kwargs) -> Tensor:
        """Reduce tensor across TPU cores using XLA collectives."""

    def all_gather(self, tensor: Tensor, *args, **kwargs) -> Tensor:
        """All-gather tensor across TPU cores."""

    def mark_step(self) -> None:
        """Mark XLA step boundary for graph compilation."""


# Strategy for single XLA device training (TPU, XLA on GPU).
class SingleDeviceXLAStrategy(Strategy):
    """Strategy for training on a single XLA device.

    Optimized for single TPU core or XLA compilation on single GPU.
    """

    def __init__(
        self,
        device: Optional[torch.device] = None,
        accelerator: Optional[Accelerator] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[Precision] = None,
    ):
        """Initialize single XLA device strategy.

        Args:
            device: Target XLA device for training
            accelerator: Accelerator plugin
            checkpoint_io: Checkpoint I/O plugin
            precision_plugin: Precision plugin
        """


# Strategy for tensor model parallelism across multiple devices.
class ModelParallelStrategy(Strategy):
    """Strategy for tensor model parallelism.

    Splits individual model layers across multiple devices for very large
    models that don't fit on a single device.
    """

    def __init__(
        self,
        accelerator: Optional[Accelerator] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[Precision] = None,
    ):
        """Initialize model parallel strategy.

        Args:
            accelerator: Accelerator plugin
            checkpoint_io: Checkpoint I/O plugin
            precision_plugin: Precision plugin
        """


# Base class for multi-device parallel strategies.
class ParallelStrategy(Strategy):
    """Base class for parallel training strategies.

    Provides common functionality for strategies that distribute training
    across multiple devices or processes.
    """

    def __init__(
        self,
        accelerator: Optional[Accelerator] = None,
        parallel_devices: Optional[list[torch.device]] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[Precision] = None,
    ):
        """Initialize parallel strategy.

        Args:
            accelerator: Accelerator plugin
            parallel_devices: Devices to distribute training across
            checkpoint_io: Checkpoint I/O plugin
            precision_plugin: Precision plugin
        """


# Strategy combining XLA compilation with Fully Sharded Data Parallel for TPUs.
class XLAFSDPStrategy(XLAStrategy):
    """Strategy combining XLA with Fully Sharded Data Parallel.

    Provides FSDP sharding capabilities optimized for XLA devices,
    enabling training of very large models on TPU pods.
    """

    def __init__(
        self,
        accelerator: Optional[Accelerator] = None,
        parallel_devices: Optional[list[torch.device]] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[Precision] = None,
        auto_wrap_policy: Optional[Callable] = None,
        **kwargs,
    ):
        """Initialize XLA FSDP strategy.

        Args:
            accelerator: Accelerator plugin
            parallel_devices: Devices to distribute training across
            checkpoint_io: Checkpoint I/O plugin
            precision_plugin: Precision plugin
            auto_wrap_policy: Policy for automatic module wrapping
        """


# Global registry for strategy plugins.
class StrategyRegistry:
    """Registry for strategy plugins.

    Maps string names (as accepted by ``Fabric(strategy=...)``) to
    Strategy classes.
    """

    def register(
        self,
        name: str,
        strategy_class: type[Strategy],
        description: Optional[str] = None,
    ) -> None:
        """Register strategy class."""

    def get(self, name: str) -> type[Strategy]:
        """Get strategy class by name."""

    def available_strategies(self) -> list[str]:
        """Get list of available strategy names."""

    def remove(self, name: str) -> None:
        """Remove strategy from registry."""


# Global registry instance
STRATEGY_REGISTRY: StrategyRegistry

# Usage examples
from lightning.fabric import Fabric
# Single device training
fabric = Fabric(strategy="auto")  # Auto-selects single device

# Data parallel (single node, multiple GPUs)
fabric = Fabric(strategy="dp", devices=4)

# Distributed data parallel
fabric = Fabric(strategy="ddp", devices=4, num_nodes=2)

# DeepSpeed ZeRO Stage 2
fabric = Fabric(
    strategy="deepspeed",
    devices=8,
    precision="16-mixed",
)

# DeepSpeed with custom configuration
deepspeed_config = {
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {"device": "cpu"},
        "offload_param": {"device": "cpu"},
    },
    "train_micro_batch_size_per_gpu": 1,
}
fabric = Fabric(
    strategy=DeepSpeedStrategy(config=deepspeed_config),
    devices=8,
)

# FSDP with CPU offloading
fabric = Fabric(
    strategy="fsdp",
    devices=4,
    precision="bf16-mixed",
)

# FSDP with custom configuration
from torch.distributed.fsdp import MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

fsdp_strategy = FSDPStrategy(
    cpu_offload=True,
    mixed_precision=MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
    ),
    auto_wrap_policy=transformer_auto_wrap_policy,
    activation_checkpointing=True,
)
fabric = Fabric(strategy=fsdp_strategy, devices=8)

# XLA/TPU training
fabric = Fabric(
    accelerator="tpu",
    strategy="xla",
    devices=8,
    precision="bf16-mixed",
)

# Mark XLA steps for optimal compilation
for batch in dataloader:
    loss = compute_loss(model, batch)
    fabric.backward(loss)
    optimizer.step()
    # Mark step boundary for XLA
    if hasattr(fabric.strategy, 'mark_step'):
        fabric.strategy.mark_step()

from lightning.fabric.strategies import Strategy, STRATEGY_REGISTRY
class CustomStrategy(Strategy):
    """Minimal example of a user-defined strategy."""

    def setup_module(self, module):
        # Custom module setup
        return module

    def reduce(self, tensor, *args, **kwargs):
        # Custom reduction logic
        return tensor


# Register custom strategy
STRATEGY_REGISTRY.register("custom", CustomStrategy)

# Use custom strategy
fabric = Fabric(strategy="custom")

from datetime import timedelta
# DDP with custom settings
ddp_strategy = DDPStrategy(
    process_group_backend="nccl",
    timeout=timedelta(minutes=30),
    find_unused_parameters=False,  # Set via kwargs
    gradient_as_bucket_view=True,  # Set via kwargs
)
fabric = Fabric(
    strategy=ddp_strategy,
    devices=4,
    num_nodes=2,
)

# Install with Tessl CLI
# npx tessl i tessl/pypi-lightning-fabric