HuggingFace Accelerate is a PyTorch library that simplifies distributed and mixed-precision training by abstracting away the boilerplate code needed for multi-GPU, TPU, and mixed-precision setups.
—
Quality
Pending
Does it follow best practices?
Impact
Pending
No eval scenarios have been run
The main Accelerator class and essential training functionality that forms the foundation of Accelerate's distributed training capabilities. This includes mixed precision support, gradient accumulation, device management, and basic distributed operations.
The central orchestrator for distributed training that handles hardware detection, mixed precision setup, and training component preparation.
class Accelerator:
"""
Main class for coordinating distributed training and mixed precision.
Handles device placement, distributed backend setup, mixed precision
configuration, and provides training utilities.
"""
def __init__(
self,
device_placement: bool = True,
split_batches: bool = False,
mixed_precision: str | None = None,
gradient_accumulation_steps: int = 1,
cpu: bool = False,
dataloader_config: DataLoaderConfiguration | None = None,
deepspeed_plugin: DeepSpeedPlugin | dict[str, DeepSpeedPlugin] | None = None,
fsdp_plugin: FullyShardedDataParallelPlugin | None = None,
megatron_lm_plugin: MegatronLMPlugin | None = None,
rng_types: list[str] | None = None,
log_with: str | list[str] | None = None,
project_dir: str | None = None,
project_config: ProjectConfiguration | None = None,
gradient_accumulation_plugin: GradientAccumulationPlugin | None = None,
step_scheduler_with_optimizer: bool = True,
kwargs_handlers: list[KwargsHandler] | None = None,
dynamo_backend: str | None = None,
dynamo_plugin: TorchDynamoPlugin | None = None,
parallelism_config: ParallelismConfig | None = None
):
"""
Initialize Accelerator with training configuration.
Parameters:
- device_placement: Whether to automatically place tensors on correct device
- split_batches: Whether to split batches across processes
- mixed_precision: Mixed precision mode ("no", "fp16", "bf16", "fp8")
- gradient_accumulation_steps: Number of steps to accumulate gradients
- cpu: Force CPU usage even if GPU available
- dataloader_config: DataLoader behavior configuration
- deepspeed_plugin: DeepSpeed configuration plugin (single or per-model dict)
- fsdp_plugin: FSDP configuration plugin
- megatron_lm_plugin: Megatron-LM configuration plugin
- rng_types: Random number generator types to synchronize
- log_with: Experiment tracking backends to use
- project_dir: Directory for project outputs
- project_config: Project and logging configuration
- gradient_accumulation_plugin: Gradient accumulation configuration
- step_scheduler_with_optimizer: Whether to step scheduler with optimizer
- kwargs_handlers: Additional configuration handlers
- dynamo_backend: Backend for torch.compile optimization
- dynamo_plugin: Torch Dynamo configuration plugin
- parallelism_config: Parallelism configuration object
"""Methods for preparing models, optimizers, and data loaders for distributed training.
def prepare(self, *args):
"""
Prepare models, optimizers, dataloaders for distributed training.
Automatically wraps objects for the current distributed setup and
applies mixed precision, device placement, and other configurations.
Parameters:
- *args: Models, optimizers, dataloaders, schedulers to prepare
Returns:
Tuple of prepared objects in same order as input
"""
def prepare_model(self, model: torch.nn.Module, device_placement: bool | None = None):
"""
Prepare a single model for distributed training.
Parameters:
- model: PyTorch model to prepare
- device_placement: Override default device placement behavior
Returns:
Prepared model wrapped for distributed training
"""
def prepare_optimizer(self, optimizer: torch.optim.Optimizer):
"""
Prepare optimizer for distributed training.
Parameters:
- optimizer: PyTorch optimizer to prepare
Returns:
Wrapped optimizer for distributed training
"""
def prepare_data_loader(
self,
data_loader: torch.utils.data.DataLoader,
device_placement: bool | None = None
):
"""
Prepare DataLoader for distributed training.
Parameters:
- data_loader: PyTorch DataLoader to prepare
- device_placement: Override default device placement
Returns:
DataLoader configured for distributed training
"""
def prepare_scheduler(self, scheduler):
"""
Prepare learning rate scheduler for distributed training.
Parameters:
- scheduler: PyTorch scheduler to prepare
Returns:
Wrapped scheduler for distributed training
"""Core training operations including backward pass, gradient clipping, and model unwrapping.
def backward(self, loss: torch.Tensor, **kwargs):
"""
Perform backward pass with automatic mixed precision scaling.
Parameters:
- loss: Loss tensor to compute gradients from
- **kwargs: Additional arguments passed to loss.backward()
"""
def clip_grad_norm_(
self,
parameters,
max_norm: float,
norm_type: float = 2.0
):
"""
Clip gradient norm across all processes.
Parameters:
- parameters: Model parameters or parameter groups
- max_norm: Maximum norm of gradients
- norm_type: Type of norm to compute (default: 2.0)
Returns:
Total norm of parameters (viewed as single vector)
"""
def clip_grad_value_(self, parameters, clip_value: float):
"""
Clip gradient values to specified range.
Parameters:
- parameters: Model parameters to clip
- clip_value: Maximum absolute value for gradients
"""
def unwrap_model(self, model: torch.nn.Module, keep_fp32_wrapper: bool = True):
"""
Extract original model from distributed training wrappers.
Parameters:
- model: Wrapped model from prepare()
- keep_fp32_wrapper: Whether to keep mixed precision wrapper
Returns:
Original unwrapped model
"""Basic distributed operations for gathering, reducing, and broadcasting tensors.
def gather(self, tensor: torch.Tensor):
"""
Gather tensor from all processes.
Parameters:
- tensor: Tensor to gather across processes
Returns:
Concatenated tensor of values from all processes (available on every process)
"""
def gather_for_metrics(self, input_data):
"""
Gather data from all processes for metrics computation.
Automatically handles padding for uneven batch sizes.
Parameters:
- input_data: Data to gather (tensors, lists, dicts)
Returns:
Gathered data from all processes
"""
def reduce(self, tensor: torch.Tensor, reduction: str = "mean"):
"""
Reduce tensor across all processes.
Parameters:
- tensor: Tensor to reduce
- reduction: Reduction operation ("mean", "sum")
Returns:
Reduced tensor
"""
def pad_across_processes(self, tensor: torch.Tensor, dim: int = 0, pad_index: int = 0):
"""
Pad tensor to same size across all processes.
Parameters:
- tensor: Tensor to pad
- dim: Dimension to pad along
- pad_index: Value to use for padding
Returns:
Padded tensor
"""Context managers for controlling training behavior and process synchronization.
def accumulate(self, *models):
"""
Context manager for gradient accumulation.
Automatically handles gradient synchronization based on
gradient_accumulation_steps configuration.
Parameters:
- *models: Models to control gradient synchronization for
"""
def no_sync(self, *models):
"""
Context manager to disable gradient synchronization.
Parameters:
- *models: Models to disable synchronization for
"""
def main_process_first(self):
"""
Context manager to run code on main process first.
Ensures main process completes before other processes continue.
Useful for dataset preprocessing, model downloading, etc.
"""
def local_main_process_first(self):
"""
Context manager to run code on local main process first.
Similar to main_process_first but per-node instead of global.
"""
def autocast(self, cache_enabled: bool | None = None):
"""
Context manager for mixed precision autocast.
Parameters:
- cache_enabled: Whether to enable autocast cache
Returns:
Autocast context manager configured for current precision
"""Methods for process management, synchronization, and training control.
def wait_for_everyone(self):
"""
Synchronization barrier - wait for all processes to reach this point.
"""
def print(self, *args, **kwargs):
"""
Print only on the main process.
Parameters:
- *args: Arguments to print
- **kwargs: Keyword arguments for print function
"""
def split_between_processes(self, inputs, apply_padding: bool = False):
"""
Split inputs between processes for distributed processing.
Parameters:
- inputs: Data to split between processes
- apply_padding: Whether to pad to equal sizes
Returns:
Portion of inputs for current process
"""
def free_memory(self):
"""
Free memory by clearing internal caches and calling garbage collection.
"""
def clear(self):
"""
Reset Accelerator to initial state and free memory.
"""
def skip_first_batches(self, dataloader, num_batches: int):
"""
Skip the first num_batches in a DataLoader.
Parameters:
- dataloader: DataLoader to skip batches from
- num_batches: Number of batches to skip
Returns:
DataLoader starting from the specified batch
"""
def verify_device_map(self, model: torch.nn.Module):
"""
Verify that the device map is valid for the given model.
Parameters:
- model: Model to verify device map for
"""
def lomo_backward(self, loss: torch.Tensor, learning_rate: float):
"""
Perform LOMO (Low-Memory Optimization) backward pass.
Parameters:
- loss: Loss tensor to compute gradients from
- learning_rate: Learning rate for LOMO optimizer
"""
def set_trigger(self):
"""
Set a trigger flag on the current process, to be detected later by
check_trigger() across all processes (used for cross-process early
stopping / breakpoint control).
"""
def check_trigger(self):
"""
Check whether the trigger flag has been set on any process, and reset it.
Returns:
bool: True if any process called set_trigger() since the last check
"""Key properties providing information about the training environment.
@property
def device(self) -> torch.device:
"""Current device for this process."""
@property
def state(self) -> PartialState:
"""Access to the underlying PartialState."""
@property
def is_main_process(self) -> bool:
"""Whether this is the main process (rank 0)."""
@property
def is_local_main_process(self) -> bool:
"""Whether this is the local main process on this node."""
@property
def process_index(self) -> int:
"""Global process index/rank."""
@property
def local_process_index(self) -> int:
"""Local process index on this node."""
@property
def num_processes(self) -> int:
"""Total number of processes."""
@property
def distributed_type(self) -> DistributedType:
"""Type of distributed training backend being used."""
@property
def mixed_precision(self) -> str:
"""Mixed precision mode being used."""
@property
def use_distributed(self) -> bool:
"""Whether distributed training is being used."""
@property
def should_save_model(self) -> bool:
"""Whether this process should save the model."""
@property
def tensor_parallel_rank(self) -> int:
"""Tensor parallelism rank for this process."""
@property
def pipeline_parallel_rank(self) -> int:
"""Pipeline parallelism rank for this process."""
@property
def context_parallel_rank(self) -> int:
"""Context parallelism rank for this process."""
@property
def data_parallel_rank(self) -> int:
"""Data parallelism rank for this process."""
@property
def fp8_backend(self) -> str | None:
"""FP8 backend being used."""
@property
def is_fsdp2(self) -> bool:
"""Whether FSDP2 is being used."""from accelerate import Accelerator
import torch
import torch.nn as nn
# Initialize with mixed precision
accelerator = Accelerator(
mixed_precision="fp16",
gradient_accumulation_steps=4
)
# Create model and optimizer
model = nn.Linear(784, 10)
optimizer = torch.optim.Adam(model.parameters())
# Prepare for distributed training
model, optimizer = accelerator.prepare(model, optimizer)
# Training loop with gradient accumulation
for batch in dataloader:
with accelerator.accumulate(model):
outputs = model(batch['input'])
loss = criterion(outputs, batch['labels'])
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
from accelerate import Accelerator, DataLoaderConfiguration, ProjectConfiguration
# Advanced configuration
dataloader_config = DataLoaderConfiguration(
split_batches=True,
dispatch_batches=False
)
project_config = ProjectConfiguration(
project_dir="./experiments",
automatic_checkpoint_naming=True,
total_limit=5
)
accelerator = Accelerator(
device_placement=True,
mixed_precision="bf16",
gradient_accumulation_steps=8,
dataloader_config=dataloader_config,
project_config=project_config
)
Install with Tessl CLI
npx tessl i tessl/pypi-accelerate