Fabric is the fast and lightweight way to scale PyTorch models without boilerplate.

Hardware acceleration plugins that provide device abstraction and optimizations for different compute platforms.

Abstract base class defining the accelerator interface.
class Accelerator:
    """Abstract base class for hardware accelerators.

    Accelerators handle device detection, setup, and hardware-specific
    optimizations for training and inference.
    """

    def setup_device(self, device: torch.device) -> None:
        """Setup the device for training."""

    def teardown(self) -> None:
        """Clean up accelerator resources."""

    def parse_devices(self, devices: Any) -> Any:
        """Parse device specification into concrete device list."""

    def get_parallel_devices(self, devices: Any) -> list[torch.device]:
        """Get list of devices for parallel training."""

    def auto_device_count(self) -> int:
        """Get number of available devices."""

    def is_available(self) -> bool:
        """Check if accelerator is available on current system."""

    @staticmethod
    def register_accelerators() -> None:
        """Register accelerator in global registry."""


# CPU-based training acceleration with optimizations for CPU hardware.
class CPUAccelerator(Accelerator):
    """CPU accelerator for training on CPU hardware.

    Provides CPU-specific optimizations and multi-threading support.
    """

    def setup_device(self, device: torch.device) -> None:
        """Setup CPU device with optimal threading configuration."""

    def is_available(self) -> bool:
        """CPU is always available."""

    def auto_device_count(self) -> int:
        """Returns 1 for CPU (single logical device)."""

    def parse_devices(self, devices: Any) -> int:
        """Parse CPU device specification."""

    def get_parallel_devices(self, devices: Any) -> list[torch.device]:
        """Get CPU device list for parallel training."""


# NVIDIA GPU acceleration with CUDA support and GPU-specific optimizations.
class CUDAAccelerator(Accelerator):
    """CUDA accelerator for NVIDIA GPU training.

    Provides GPU memory management, multi-GPU support, and CUDA optimizations.
    """

    def setup_device(self, device: torch.device) -> None:
        """Setup CUDA device with memory and compute optimizations."""

    def is_available(self) -> bool:
        """Check if CUDA is available and GPUs are present."""

    def auto_device_count(self) -> int:
        """Get number of available CUDA devices."""

    def parse_devices(self, devices: Any) -> Union[int, list[int]]:
        """Parse GPU device specification (IDs, count, etc.)."""

    def get_parallel_devices(self, devices: Any) -> list[torch.device]:
        """Get list of CUDA devices for parallel training."""

    def get_device_stats(self, device: torch.device) -> dict[str, Any]:
        """Get GPU memory and utilization statistics."""

    def empty_cache(self) -> None:
        """Clear GPU memory cache."""

    def set_cuda_device(self, device: torch.device) -> None:
        """Set current CUDA device."""


# Apple Silicon GPU acceleration using Metal Performance Shaders.
class MPSAccelerator(Accelerator):
    """MPS (Metal Performance Shaders) accelerator for Apple Silicon.

    Provides GPU acceleration on Apple M1/M2/M3 chips using Metal framework.
    """

    def setup_device(self, device: torch.device) -> None:
        """Setup MPS device for Apple Silicon GPU training."""

    def is_available(self) -> bool:
        """Check if MPS backend is available on current system."""

    def auto_device_count(self) -> int:
        """Returns 1 for MPS (single logical GPU device)."""

    def parse_devices(self, devices: Any) -> int:
        """Parse MPS device specification."""

    def get_parallel_devices(self, devices: Any) -> list[torch.device]:
        """Get MPS device for training (single device)."""


# TPU acceleration using XLA (Accelerated Linear Algebra) compiler.
class XLAAccelerator(Accelerator):
    """XLA accelerator for TPU training and XLA-compiled execution.

    Provides TPU support and XLA compilation optimizations for
    high-performance training on Google Cloud TPUs.
    """

    def setup_device(self, device: torch.device) -> None:
        """Setup XLA device for TPU training."""

    def is_available(self) -> bool:
        """Check if XLA/TPU runtime is available."""

    def auto_device_count(self) -> int:
        """Get number of available TPU cores."""

    def parse_devices(self, devices: Any) -> Union[int, list[int]]:
        """Parse TPU device specification."""

    def get_parallel_devices(self, devices: Any) -> list[torch.device]:
        """Get list of TPU devices for parallel training."""

    def all_gather_object(self, obj: Any) -> list[Any]:
        """TPU-specific all-gather implementation."""

    def broadcast_object(self, obj: Any, src: int = 0) -> Any:
        """TPU-specific broadcast implementation."""


# Helper functions for device detection and management.
def find_usable_cuda_devices(num_devices: int = -1) -> list[int]:
    """Find CUDA devices that are available and usable.

    Args:
        num_devices: Number of devices to find (-1 for all available)

    Returns:
        List of CUDA device IDs that can be used for training

    Examples:
        # Find all available GPUs
        devices = find_usable_cuda_devices()

        # Find 2 available GPUs
        devices = find_usable_cuda_devices(2)
    """
def get_nvidia_gpu_stats(device: torch.device) -> dict[str, Union[int, float]]:
    """Get NVIDIA GPU statistics and memory usage.

    Args:
        device: CUDA device to query

    Returns:
        Dictionary with GPU statistics including memory usage,
        utilization, temperature, and power consumption
    """


# Global registry system for discovering and instantiating accelerators.
class AcceleratorRegistry:
    """Registry for accelerator plugins."""

    def register(
        self,
        name: str,
        # Forward reference: Accelerator is defined earlier in this module;
        # quoted so the annotation is not evaluated at definition time.
        accelerator_class: "type[Accelerator]",
        description: Optional[str] = None,
    ) -> None:
        """Register an accelerator class."""

    def get(self, name: str) -> "type[Accelerator]":
        """Get accelerator class by name."""

    def available_accelerators(self) -> list[str]:
        """Get list of available accelerator names."""

    def remove(self, name: str) -> None:
        """Remove accelerator from registry."""


# Global registry instance
ACCELERATOR_REGISTRY: AcceleratorRegistry

# Example usage requires:
#   from lightning.fabric import Fabric
# The example snippets below require: from lightning.fabric import Fabric

# --- Auto-detect best available accelerator ---
fabric = Fabric(accelerator="auto")
print(f"Using accelerator: {fabric.accelerator.__class__.__name__}")

# --- Use specific accelerator types ---
fabric_gpu = Fabric(accelerator="cuda", devices=2)
fabric_cpu = Fabric(accelerator="cpu")
fabric_mps = Fabric(accelerator="mps")  # Apple Silicon
fabric_tpu = Fabric(accelerator="tpu", devices=8)  # TPU v3/v4

# --- Use specific GPU devices ---
fabric = Fabric(accelerator="cuda", devices=[0, 2, 3])
# Use all available GPUs
fabric = Fabric(accelerator="gpu", devices="auto")
# Use specific number of GPUs
fabric = Fabric(accelerator="gpu", devices=4)

# --- Custom accelerator registration ---
# Requires: from lightning.fabric.accelerators import Accelerator, ACCELERATOR_REGISTRY
class CustomAccelerator(Accelerator):
    def setup_device(self, device):
        # Custom device setup logic
        pass

    def is_available(self):
        # Custom availability check
        return True

# Register custom accelerator
ACCELERATOR_REGISTRY.register("custom", CustomAccelerator)
# Use custom accelerator
fabric = Fabric(accelerator="custom")

# --- Finding usable CUDA devices ---
# Requires: from lightning.fabric.accelerators.cuda import find_usable_cuda_devices
available_gpus = find_usable_cuda_devices()
print(f"Available GPUs: {available_gpus}")
# Use subset of available GPUs
if len(available_gpus) >= 2:
    fabric = Fabric(accelerator="cuda", devices=available_gpus[:2])

# --- TPU training setup ---
fabric = Fabric(
    accelerator="tpu",
    devices=8,  # TPU v3/v4 pod
    precision="bf16-mixed",  # BFloat16 for TPUs
)
# Access TPU-specific methods
# NOTE(review): local_data is assumed to be defined by the surrounding training loop.
if hasattr(fabric.accelerator, 'all_gather_object'):
    result = fabric.accelerator.all_gather_object(local_data)

# --- Monitor GPU usage during training ---
if fabric.accelerator.__class__.__name__ == 'CUDAAccelerator':
    stats = fabric.accelerator.get_device_stats(fabric.device)
    fabric.print(f"GPU Memory: {stats['memory_used']}/{stats['memory_total']} MB")
    # Clear cache if needed
    if stats['memory_used'] / stats['memory_total'] > 0.9:
        fabric.accelerator.empty_cache()

# Install with Tessl CLI:
npx tessl i tessl/pypi-lightning-fabric