Engine of OpenMMLab projects for training deep learning models based on PyTorch with large-scale training frameworks, configuration management, and monitoring capabilities
—
Multi-GPU and multi-node training support with various distribution strategies including DDP, FSDP, DeepSpeed, and ColossalAI integration with communication utilities and device management. The system provides comprehensive distributed training capabilities for scalable deep learning.
Functions for initializing distributed training environments.
def init_dist(launcher: str, backend: str = 'nccl', **kwargs):
"""
Initialize distributed training.
Parameters:
- launcher: Launcher type ('pytorch', 'mpi', 'slurm')
- backend: Communication backend ('nccl', 'gloo', 'mpi')
- **kwargs: Additional initialization arguments
"""
def init_local_group(group_size: int):
"""
Initialize local process group.
Parameters:
- group_size: Size of local group
"""
def get_backend() -> str:
"""
Get current distributed backend.
Returns:
Backend name
"""
def infer_launcher() -> str:
"""
Infer distributed launcher from environment.
Returns:
Inferred launcher type
"""Functions for getting information about distributed processes.
def get_dist_info() -> tuple:
"""
Get distributed training information.
Returns:
Tuple of (rank, world_size)
"""
def get_rank() -> int:
"""
Get current process rank.
Returns:
Process rank
"""
def get_world_size() -> int:
"""
Get total number of processes.
Returns:
World size
"""
def get_local_rank() -> int:
"""
Get local rank within node.
Returns:
Local rank
"""
def get_local_size() -> int:
"""
Get local group size.
Returns:
Local group size
"""
def get_local_group():
"""
Get local process group.
Returns:
Local process group
"""
def is_main_process() -> bool:
"""
Check if current process is main process.
Returns:
True if main process
"""
def is_distributed() -> bool:
"""
Check if in distributed mode.
Returns:
True if distributed training is enabled
"""
def get_default_group():
"""
Get default process group.
Returns:
Default process group
"""Functions for inter-process communication in distributed training.
def all_reduce(tensor, op: str = 'sum', group=None, async_op: bool = False):
"""
All-reduce operation across processes.
Parameters:
- tensor: Tensor to reduce
- op: Reduction operation ('sum', 'mean', 'max', 'min')
- group: Process group
- async_op: Whether to perform asynchronously
"""
def all_gather(tensor_list: list, tensor, group=None, async_op: bool = False):
"""
All-gather operation across processes.
Parameters:
- tensor_list: List to store gathered tensors
- tensor: Tensor to gather
- group: Process group
- async_op: Whether to perform asynchronously
"""
def all_gather_object(object_list: list, obj, group=None):
"""
All-gather Python objects across processes.
Parameters:
- object_list: List to store gathered objects
- obj: Object to gather
- group: Process group
"""
def broadcast(tensor, src: int = 0, group=None, async_op: bool = False):
"""
Broadcast tensor from source process.
Parameters:
- tensor: Tensor to broadcast
- src: Source process rank
- group: Process group
- async_op: Whether to perform asynchronously
"""
def broadcast_object_list(object_list: list, src: int = 0, group=None):
"""
Broadcast list of objects from source process.
Parameters:
- object_list: List of objects to broadcast
- src: Source process rank
- group: Process group
"""
def gather(tensor, gather_list: list = None, dst: int = 0, group=None, async_op: bool = False):
"""
Gather tensors to destination process.
Parameters:
- tensor: Tensor to gather
- gather_list: List to store gathered tensors
- dst: Destination process rank
- group: Process group
- async_op: Whether to perform asynchronously
"""
def gather_object(obj, object_gather_list: list = None, dst: int = 0, group=None):
"""
Gather Python objects to destination process.
Parameters:
- obj: Object to gather
- object_gather_list: List to store gathered objects
- dst: Destination process rank
- group: Process group
"""
def reduce(tensor, dst: int = 0, op: str = 'sum', group=None, async_op: bool = False):
"""
Reduce tensor to destination process.
Parameters:
- tensor: Tensor to reduce
- dst: Destination process rank
- op: Reduction operation
- group: Process group
- async_op: Whether to perform asynchronously
"""
def barrier(group=None, async_op: bool = False):
"""
Synchronization barrier across processes.
Parameters:
- group: Process group
- async_op: Whether to perform asynchronously
"""
def sync_random_seed(seed: int = None, device: str = 'cuda') -> int:
"""
Synchronize random seed across processes.
Parameters:
- seed: Random seed (generated if None)
- device: Device for synchronization
Returns:
Synchronized seed
"""Higher-level communication functions for complex operations.
def all_reduce_dict(py_dict: dict, op: str = 'mean', group=None, to_float: bool = True) -> dict:
"""
All-reduce dictionary of tensors.
Parameters:
- py_dict: Dictionary of tensors
- op: Reduction operation
- group: Process group
- to_float: Whether to convert to float
Returns:
Reduced dictionary
"""
def all_reduce_params(params, coalesce: bool = True, bucket_size_mb: int = -1):
"""
All-reduce model parameters.
Parameters:
- params: Model parameters
- coalesce: Whether to coalesce parameters
- bucket_size_mb: Bucket size in MB
"""
def collect_results(result_part: list, size: int, tmpdir: str = None) -> list:
"""
Collect results from all processes.
Parameters:
- result_part: Partial results from current process
- size: Total size of dataset
- tmpdir: Temporary directory for file-based collection
Returns:
Collected results from all processes
"""
def collect_results_cpu(result_part: list, size: int, tmpdir: str = None) -> list:
"""
Collect results to CPU from all processes.
Parameters:
- result_part: Partial results
- size: Total dataset size
- tmpdir: Temporary directory
Returns:
CPU results from all processes
"""
def collect_results_gpu(result_part: list, size: int) -> list:
"""
Collect results on GPU from all processes.
Parameters:
- result_part: Partial results
- size: Total dataset size
Returns:
GPU results from all processes
"""Functions for managing devices in distributed environments.
def get_device() -> str:
"""
Get current device.
Returns:
Device string ('cuda:0', 'cpu', etc.)
"""
def get_data_device(data) -> str:
"""
Get device of data.
Parameters:
- data: Input data (tensor, dict, list, etc.)
Returns:
Device string
"""
def get_comm_device(group=None) -> str:
"""
Get communication device for process group.
Parameters:
- group: Process group
Returns:
Communication device
"""
def cast_data_device(data, device: str, out=None):
"""
Cast data to specified device.
Parameters:
- data: Input data
- device: Target device
- out: Output container
Returns:
Data on target device
"""Distributed data parallel wrappers for models.
class MMDistributedDataParallel:
def __init__(self, module, device_ids: list = None, output_device: int = None, broadcast_buffers: bool = True, find_unused_parameters: bool = False, bucket_cap_mb: int = 25, gradient_as_bucket_view: bool = False):
"""
MMEngine's distributed data parallel wrapper.
Parameters:
- module: Model module to wrap
- device_ids: Device IDs for this process
- output_device: Output device ID
- broadcast_buffers: Whether to broadcast buffers
- find_unused_parameters: Whether to find unused parameters
- bucket_cap_mb: Bucket capacity in MB
- gradient_as_bucket_view: Whether to use gradient bucket view
"""
def forward(self, *inputs, **kwargs):
"""
Forward pass with gradient synchronization.
Parameters:
- *inputs: Input arguments
- **kwargs: Input keyword arguments
Returns:
Model outputs
"""
class MMSeparateDistributedDataParallel:
def __init__(self, module, device_ids: list = None, output_device: int = None, broadcast_buffers: bool = True, find_unused_parameters: bool = False):
"""
Separate distributed data parallel for different parameter groups.
Parameters:
- module: Model module
- device_ids: Device IDs
- output_device: Output device
- broadcast_buffers: Whether to broadcast buffers
- find_unused_parameters: Whether to find unused parameters
"""
class MMFullyShardedDataParallel:
def __init__(self, module, process_group=None, sharding_strategy=None, cpu_offload=None, auto_wrap_policy=None, backward_prefetch=None, mixed_precision=None, ignored_modules=None, param_init_fn=None, device_id: int = None, sync_module_states: bool = False, forward_prefetch: bool = False, limit_all_gathers: bool = True, use_orig_params: bool = False):
"""
Fully sharded data parallel wrapper (PyTorch >=2.0).
Parameters:
- module: Model module
- process_group: Process group
- sharding_strategy: Sharding strategy
- cpu_offload: CPU offload configuration
- auto_wrap_policy: Auto-wrap policy
- backward_prefetch: Backward prefetch strategy
- mixed_precision: Mixed precision policy
- ignored_modules: Modules to ignore
- param_init_fn: Parameter initialization function
- device_id: Device ID
- sync_module_states: Whether to sync module states
- forward_prefetch: Whether to prefetch in forward
- limit_all_gathers: Whether to limit all-gathers
- use_orig_params: Whether to use original parameters
"""
def is_model_wrapper(model) -> bool:
"""
Check if model is wrapped with distributed wrapper.
Parameters:
- model: Model to check
Returns:
True if model is wrapped
"""Decorators for distributed training utilities.
def master_only(func):
"""
Decorator to run function only on master process.
Parameters:
- func: Function to decorate
Returns:
Decorated function
"""import torch
from mmengine import Runner, init_dist
# Initialize distributed training
init_dist('pytorch', backend='nccl')
# Get distributed info
rank, world_size = get_dist_info()
local_rank = get_local_rank()
# Set device
torch.cuda.set_device(local_rank)
device = torch.device('cuda', local_rank)
# Create model and move to device
model = MyModel().to(device)
# Wrap with DDP
from mmengine.model import MMDistributedDataParallel
model = MMDistributedDataParallel(
model,
device_ids=[local_rank],
broadcast_buffers=False,
find_unused_parameters=False
)
# Create runner with distributed configuration
runner = Runner(
model=model,
work_dir='./work_dir',
train_dataloader=train_loader,
launcher='pytorch'
)
runner.train()

Example: basic collective communication operations.

import torch
from mmengine.dist import all_reduce, all_gather, broadcast
# All-reduce operation
loss = torch.tensor(0.5).cuda()
all_reduce(loss, op='mean') # Average loss across all processes
# All-gather operation
local_tensor = torch.randn(4).cuda()
gathered_tensors = [torch.zeros_like(local_tensor) for _ in range(get_world_size())]
all_gather(gathered_tensors, local_tensor)
# Broadcast operation
if get_rank() == 0:
data = torch.randn(10).cuda()
else:
data = torch.zeros(10).cuda()
broadcast(data, src=0)
# Dictionary all-reduce
metrics = {'loss': torch.tensor(0.5), 'acc': torch.tensor(0.9)}
reduced_metrics = all_reduce_dict(metrics, op='mean')

Example: collecting evaluation results from all processes.

from mmengine.dist import collect_results
# Collect evaluation results from all processes
def evaluate_model(model, dataloader):
results = []
for batch in dataloader:
outputs = model(batch)
results.extend(outputs)
# Collect results from all processes
all_results = collect_results(results, len(dataloader.dataset))
# Only compute metrics on main process
if is_main_process():
metrics = compute_metrics(all_results)
return metrics
    return {}

Example: restricting execution to the master process.

from mmengine.dist import master_only, is_main_process
@master_only
def save_checkpoint(model, path):
"""Save checkpoint only on master process."""
torch.save(model.state_dict(), path)
@master_only
def log_metrics(metrics):
"""Log metrics only on master process."""
print(f"Metrics: {metrics}")
# Alternative approach
def training_step(model, data):
loss = model(data)
if is_main_process():
print(f"Loss: {loss.item()}")
    return loss

Example: configuring DDP wrappers.

from mmengine.model import MMDistributedDataParallel
# DDP with gradient bucketing and unused parameter detection
model = MMDistributedDataParallel(
model,
device_ids=[local_rank],
output_device=local_rank,
broadcast_buffers=True,
find_unused_parameters=True,
bucket_cap_mb=25,
gradient_as_bucket_view=True
)
# Separate DDP for models with different parameter update frequencies
model = MMSeparateDistributedDataParallel(
model,
device_ids=[local_rank],
find_unused_parameters=True
)

Example: configuring fully sharded data parallel (FSDP).

from mmengine.model import MMFullyShardedDataParallel
from torch.distributed.fsdp import ShardingStrategy, CPUOffload
# FSDP configuration
model = MMFullyShardedDataParallel(
model,
sharding_strategy=ShardingStrategy.FULL_SHARD,
cpu_offload=CPUOffload(offload_params=True),
mixed_precision=None,
backward_prefetch=None,
forward_prefetch=False,
limit_all_gathers=True
)

Example: synchronizing random seeds across processes.

from mmengine.dist import sync_random_seed
# Synchronize random seed across all processes
seed = sync_random_seed(42)
# Use synchronized seed
torch.manual_seed(seed)
np.random.seed(seed)

Install with Tessl CLI:
npx tessl i tessl/pypi-mmengine