tessl/pypi-accelerate

HuggingFace Accelerate is a PyTorch library that simplifies distributed and mixed-precision training by abstracting away the boilerplate needed to run the same training loop on multi-GPU, TPU, and mixed-precision setups.


docs/core-training.md

Core Training

The main Accelerator class and essential training functionality that forms the foundation of Accelerate's distributed training capabilities. This includes mixed precision support, gradient accumulation, device management, and basic distributed operations.

Capabilities

Accelerator Class

The central orchestrator for distributed training that handles hardware detection, mixed precision setup, and training component preparation.

class Accelerator:
    """
    Main class for coordinating distributed training and mixed precision.
    
    Handles device placement, distributed backend setup, mixed precision
    configuration, and provides training utilities.
    """
    
    def __init__(
        self,
        device_placement: bool = True,
        split_batches: bool = False,
        mixed_precision: str | None = None,
        gradient_accumulation_steps: int = 1,
        cpu: bool = False,
        dataloader_config: DataLoaderConfiguration | None = None,
        deepspeed_plugin: DeepSpeedPlugin | dict[str, DeepSpeedPlugin] | None = None,
        fsdp_plugin: FullyShardedDataParallelPlugin | None = None,
        megatron_lm_plugin: MegatronLMPlugin | None = None,
        rng_types: list[str] | None = None,
        log_with: str | list[str] | None = None,
        project_dir: str | None = None,
        project_config: ProjectConfiguration | None = None,
        gradient_accumulation_plugin: GradientAccumulationPlugin | None = None,
        step_scheduler_with_optimizer: bool = True,
        kwargs_handlers: list[KwargsHandler] | None = None,
        dynamo_backend: str | None = None,
        dynamo_plugin: TorchDynamoPlugin | None = None,
        parallelism_config: ParallelismConfig | None = None
    ):
        """
        Initialize Accelerator with training configuration.
        
        Parameters:
        - device_placement: Whether to automatically place tensors on correct device
        - split_batches: Whether to split batches across processes
        - mixed_precision: Mixed precision mode ("no", "fp16", "bf16", "fp8")
        - gradient_accumulation_steps: Number of steps to accumulate gradients
        - cpu: Force CPU usage even if GPU available
        - dataloader_config: DataLoader behavior configuration
        - deepspeed_plugin: DeepSpeed configuration plugin (single or per-model dict)
        - fsdp_plugin: FSDP configuration plugin
        - megatron_lm_plugin: Megatron-LM configuration plugin
        - rng_types: Random number generator types to synchronize
        - log_with: Experiment tracking backends to use
        - project_dir: Directory for project outputs
        - project_config: Project and logging configuration
        - gradient_accumulation_plugin: Gradient accumulation configuration
        - step_scheduler_with_optimizer: Whether to step scheduler with optimizer
        - kwargs_handlers: Additional configuration handlers
        - dynamo_backend: Backend for torch.compile optimization
        - dynamo_plugin: Torch Dynamo configuration plugin
        - parallelism_config: Parallelism configuration object
        """

Training Preparation

Methods for preparing models, optimizers, and data loaders for distributed training.

def prepare(self, *args):
    """
    Prepare models, optimizers, dataloaders for distributed training.
    
    Automatically wraps objects for the current distributed setup and
    applies mixed precision, device placement, and other configurations.
    
    Parameters:
    - *args: Models, optimizers, dataloaders, schedulers to prepare
    
    Returns:
    Tuple of prepared objects in same order as input
    """

def prepare_model(self, model: torch.nn.Module, device_placement: bool | None = None):
    """
    Prepare a single model for distributed training.
    
    Parameters:
    - model: PyTorch model to prepare
    - device_placement: Override default device placement behavior
    
    Returns:
    Prepared model wrapped for distributed training
    """

def prepare_optimizer(self, optimizer: torch.optim.Optimizer):
    """
    Prepare optimizer for distributed training.
    
    Parameters:
    - optimizer: PyTorch optimizer to prepare
    
    Returns:
    Wrapped optimizer for distributed training
    """

def prepare_data_loader(
    self, 
    data_loader: torch.utils.data.DataLoader,
    device_placement: bool | None = None
):
    """
    Prepare DataLoader for distributed training.
    
    Parameters:
    - data_loader: PyTorch DataLoader to prepare
    - device_placement: Override default device placement
    
    Returns:
    DataLoader configured for distributed training
    """

def prepare_scheduler(self, scheduler):
    """
    Prepare learning rate scheduler for distributed training.
    
    Parameters:
    - scheduler: PyTorch scheduler to prepare
    
    Returns:
    Wrapped scheduler for distributed training
    """

Training Operations

Core training operations including backward pass, gradient clipping, and model unwrapping.

def backward(self, loss: torch.Tensor, **kwargs):
    """
    Perform backward pass with automatic mixed precision scaling.
    
    Parameters:
    - loss: Loss tensor to compute gradients from
    - **kwargs: Additional arguments passed to loss.backward()
    """

def clip_grad_norm_(
    self,
    parameters,
    max_norm: float,
    norm_type: float = 2.0
):
    """
    Clip gradient norm across all processes.
    
    Parameters:
    - parameters: Model parameters or parameter groups
    - max_norm: Maximum norm of gradients
    - norm_type: Type of norm to compute (default: 2.0)
    
    Returns:
    Total norm of parameters (viewed as single vector)
    """

def clip_grad_value_(self, parameters, clip_value: float):
    """
    Clip gradient values to specified range.
    
    Parameters:
    - parameters: Model parameters to clip
    - clip_value: Maximum absolute value for gradients
    """

def unwrap_model(self, model: torch.nn.Module, keep_fp32_wrapper: bool = True):
    """
    Extract original model from distributed training wrappers.
    
    Parameters:
    - model: Wrapped model from prepare()
    - keep_fp32_wrapper: Whether to keep mixed precision wrapper
    
    Returns:
    Original unwrapped model
    """

Distributed Communication

Basic distributed operations for gathering, reducing, and broadcasting tensors.

def gather(self, tensor: torch.Tensor):
    """
    Gather tensor from all processes.
    
    Parameters:
    - tensor: Tensor to gather across processes
    
    Returns:
    Concatenated tensor from all processes (identical on every process)
    """

def gather_for_metrics(self, input_data):
    """
    Gather data from all processes for metrics computation.
    
    Automatically removes the duplicated samples that pad uneven final
    batches, so metrics are computed over exactly the original dataset.
    
    Parameters:
    - input_data: Data to gather (tensors, lists, dicts)
    
    Returns:
    Gathered data from all processes
    """

def reduce(self, tensor: torch.Tensor, reduction: str = "mean"):
    """
    Reduce tensor across all processes.
    
    Parameters:
    - tensor: Tensor to reduce
    - reduction: Reduction operation ("mean", "sum")
    
    Returns:
    Reduced tensor
    """

def pad_across_processes(self, tensor: torch.Tensor, dim: int = 0, pad_index: int = 0):
    """
    Pad tensor to same size across all processes.
    
    Parameters:
    - tensor: Tensor to pad
    - dim: Dimension to pad along
    - pad_index: Value to use for padding
    
    Returns:
    Padded tensor
    """

Context Managers

Context managers for controlling training behavior and process synchronization.

def accumulate(self, *models):
    """
    Context manager for gradient accumulation.
    
    Automatically handles gradient synchronization based on 
    gradient_accumulation_steps configuration.
    
    Parameters:
    - *models: Models to control gradient synchronization for
    """

def no_sync(self, *models):
    """
    Context manager to disable gradient synchronization.
    
    Parameters:
    - *models: Models to disable synchronization for
    """

def main_process_first(self):
    """
    Context manager to run code on main process first.
    
    Ensures main process completes before other processes continue.
    Useful for dataset preprocessing, model downloading, etc.
    """

def local_main_process_first(self):
    """
    Context manager to run code on local main process first.
    
    Similar to main_process_first but per-node instead of global.
    """

def autocast(self, cache_enabled: bool | None = None):
    """
    Context manager for mixed precision autocast.
    
    Parameters:
    - cache_enabled: Whether to enable autocast cache
    
    Returns:
    Autocast context manager configured for current precision
    """

Process Control and Utilities

Methods for process management, synchronization, and training control.

def wait_for_everyone(self):
    """
    Synchronization barrier - wait for all processes to reach this point.
    """

def print(self, *args, **kwargs):
    """
    Print only on the main process.
    
    Parameters:
    - *args: Arguments to print
    - **kwargs: Keyword arguments for print function
    """

def split_between_processes(self, inputs, apply_padding: bool = False):
    """
    Split inputs between processes for distributed processing.
    
    Parameters:
    - inputs: Data to split between processes
    - apply_padding: Whether to pad to equal sizes
    
    Returns:
    Portion of inputs for current process
    """

def free_memory(self):
    """
    Free memory by clearing internal caches and calling garbage collection.
    """

def clear(self):
    """
    Reset Accelerator to initial state and free memory.
    """

def skip_first_batches(self, dataloader, num_batches: int):
    """
    Skip the first num_batches in a DataLoader.
    
    Parameters:
    - dataloader: DataLoader to skip batches from
    - num_batches: Number of batches to skip
    
    Returns:
    DataLoader starting from the specified batch
    """

def verify_device_map(self, model: torch.nn.Module):
    """
    Verify that the device map is valid for the given model.
    
    Parameters:
    - model: Model to verify device map for
    """

def lomo_backward(self, loss: torch.Tensor, learning_rate: float):
    """
    Perform LOMO (Low-Memory Optimization) backward pass.
    
    Parameters:
    - loss: Loss tensor to compute gradients from
    - learning_rate: Learning rate for LOMO optimizer
    """

def set_trigger(self):
    """
    Set an internal trigger tensor on this process, typically to signal
    a condition (such as early stopping) to all other processes.
    """

def check_trigger(self):
    """
    Check whether any process has set the trigger, then reset it.
    
    Returns:
    bool: True if the trigger was set on at least one process
    """

Properties

Key properties providing information about the training environment.

@property
def device(self) -> torch.device:
    """Current device for this process."""

@property
def state(self) -> PartialState:
    """Access to the underlying PartialState."""

@property
def is_main_process(self) -> bool:
    """Whether this is the main process (rank 0)."""

@property
def is_local_main_process(self) -> bool:
    """Whether this is the local main process on this node."""

@property
def process_index(self) -> int:
    """Global process index/rank."""

@property
def local_process_index(self) -> int:
    """Local process index on this node."""

@property
def num_processes(self) -> int:
    """Total number of processes."""

@property
def distributed_type(self) -> DistributedType:
    """Type of distributed training backend being used."""

@property
def mixed_precision(self) -> str:
    """Mixed precision mode being used."""

@property
def use_distributed(self) -> bool:
    """Whether distributed training is being used."""

@property
def should_save_model(self) -> bool:
    """Whether this process should save the model."""

@property
def tensor_parallel_rank(self) -> int:
    """Tensor parallelism rank for this process."""

@property
def pipeline_parallel_rank(self) -> int:
    """Pipeline parallelism rank for this process."""

@property
def context_parallel_rank(self) -> int:
    """Context parallelism rank for this process."""

@property
def data_parallel_rank(self) -> int:
    """Data parallelism rank for this process."""

@property
def fp8_backend(self) -> str | None:
    """FP8 backend being used."""

@property
def is_fsdp2(self) -> bool:
    """Whether FSDP2 is being used."""

Usage Examples

Basic Training Setup

from accelerate import Accelerator
import torch
import torch.nn as nn

# Initialize with mixed precision
accelerator = Accelerator(
    mixed_precision="fp16",
    gradient_accumulation_steps=4
)

# Create model, optimizer, loss, and data
model = nn.Linear(784, 10)
optimizer = torch.optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()
dataset = [{"input": torch.randn(784), "labels": torch.tensor(3)} for _ in range(64)]
dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)

# Prepare for distributed training (prepare the dataloader too,
# so each process sees its own shard of the data)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

# Training loop with gradient accumulation
for batch in dataloader:
    with accelerator.accumulate(model):
        outputs = model(batch['input'])
        loss = criterion(outputs, batch['labels'])
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()

Advanced Configuration

from accelerate import Accelerator, DataLoaderConfiguration, ProjectConfiguration

# Advanced configuration
dataloader_config = DataLoaderConfiguration(
    split_batches=True,
    dispatch_batches=False
)

project_config = ProjectConfiguration(
    project_dir="./experiments",
    automatic_checkpoint_naming=True,
    total_limit=5
)

accelerator = Accelerator(
    device_placement=True,
    mixed_precision="bf16",
    gradient_accumulation_steps=8,
    dataloader_config=dataloader_config,
    project_config=project_config
)

Install with Tessl CLI

npx tessl i tessl/pypi-accelerate
