# Distributed Operations

Low-level distributed communication primitives for gathering, broadcasting, reducing, and synchronizing data across processes. These functions provide the building blocks for distributed training and inference operations.

## Capabilities

### Basic Communication Primitives

Core distributed operations for communicating tensors and data between processes.

```python { .api }
def broadcast(tensor: torch.Tensor, from_process: int = 0):
    """
    Broadcast tensor from one process to all other processes.

    Parameters:
    - tensor: Tensor to broadcast (modified in-place on receiving processes)
    - from_process: Source process rank (default: 0)
    """

def gather(tensor: torch.Tensor):
    """
    Gather tensors from all processes to the main process.

    Parameters:
    - tensor: Tensor to gather from current process

    Returns:
    Concatenated tensor from all processes (only on main process, None elsewhere)
    """

def reduce(tensor: torch.Tensor, reduction: str = "mean"):
    """
    Reduce tensor across all processes using specified operation.

    Parameters:
    - tensor: Tensor to reduce (modified in-place)
    - reduction: Reduction operation ("mean", "sum")

    Returns:
    Reduced tensor (same shape as input)
    """

def pad_across_processes(
    tensor: torch.Tensor,
    dim: int = 0,
    pad_index: int = 0,
    pad_first: bool = False
):
    """
    Pad tensor to same size across all processes.

    Useful for gathering tensors of different sizes by padding
    smaller tensors to match the largest tensor size.

    Parameters:
    - tensor: Tensor to pad
    - dim: Dimension along which to pad
    - pad_index: Value to use for padding
    - pad_first: Whether to add the padding at the beginning (True) or at the end (False)

    Returns:
    Padded tensor with same size across all processes
    """
```

### Object Communication

Functions for communicating arbitrary Python objects between processes.

```python { .api }
def broadcast_object_list(
    objects: list,
    from_process: int = 0
):
    """
    Broadcast list of Python objects from one process to all others.

    Parameters:
    - objects: List of objects to broadcast (modified in-place on receiving processes)
    - from_process: Source process rank
    """

def gather_object(obj):
    """
    Gather Python objects from all processes.

    Parameters:
    - obj: Object to gather from current process

    Returns:
    List of objects from all processes (only on main process, None elsewhere)
    """
```

### Advanced Tensor Operations

Higher-level operations for tensor manipulation in distributed settings.

```python { .api }
def concatenate(data, dim: int = 0):
    """
    Concatenate tensors or nested data structures along specified dimension.

    Handles complex nested structures including lists, tuples, and dictionaries
    containing tensors or other concatenatable objects.

    Parameters:
    - data: Data structure containing tensors to concatenate
    - dim: Dimension along which to concatenate

    Returns:
    Concatenated data structure with same nesting as input
    """

def slice_tensors(data, tensor_slice: slice | int):
    """
    Slice tensors in nested data structures.

    Applies the same slice operation to all tensors found in nested
    lists, tuples, and dictionaries.

    Parameters:
    - data: Nested data structure containing tensors
    - tensor_slice: Slice object or integer index to apply

    Returns:
    Sliced data structure maintaining original nesting
    """

def send_to_device(
    tensor: torch.Tensor,
    device: torch.device | str,
    non_blocking: bool = False,
    skip_keys: list[str] | str | None = None
):
    """
    Move tensor or nested data structure to specified device.

    Recursively moves all tensors in nested structures while preserving
    the original data organization.

    Parameters:
    - tensor: Tensor or nested structure to move
    - device: Target device
    - non_blocking: Whether to use non-blocking transfer
    - skip_keys: Keys to skip when moving nested dictionaries

    Returns:
    Data moved to target device
    """
```

### Data Structure Utilities

Functions for analyzing and manipulating tensor data structures.

```python { .api }
def find_batch_size(data):
    """
    Find batch size from tensor or nested data structure.

    Searches through nested structures to find the first tensor
    and returns its size along dimension 0 (batch dimension).

    Parameters:
    - data: Tensor or nested structure containing tensors

    Returns:
    Batch size (int) or None if no tensors found
    """

def find_device(*args):
    """
    Find device from tensor arguments.

    Searches through arguments to find the first tensor and
    returns its device.

    Parameters:
    - *args: Arguments that may contain tensors

    Returns:
    torch.device of first tensor found, or None
    """

def get_data_structure(data):
    """
    Analyze nested data structure containing tensors.

    Returns metadata about the structure including tensor shapes,
    devices, and nesting patterns.

    Parameters:
    - data: Nested data structure to analyze

    Returns:
    DataStructure object describing the input
    """

def is_torch_tensor(data):
    """
    Check if data is a PyTorch tensor.

    Parameters:
    - data: Object to check

    Returns:
    Boolean indicating if data is a torch.Tensor
    """

def is_tensor_information(data):
    """
    Check if data contains tensor metadata information.

    Parameters:
    - data: Object to check

    Returns:
    Boolean indicating if data is TensorInformation
    """
```

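These inspection helpers are typically combined when writing generic code that accepts arbitrary nested batches. The following is a minimal sketch, assuming the same `from accelerate import ...` style used in the usage examples below; the sample `batch` dictionary is an illustrative placeholder, not part of the API above.

```python
# Minimal sketch (import path and sample batch are assumptions for illustration)
import torch
from accelerate import find_batch_size, find_device, is_torch_tensor

batch = {
    "input_ids": torch.randint(0, 1000, (8, 128)),
    "attention_mask": torch.ones(8, 128, dtype=torch.long),
}

batch_size = find_batch_size(batch)  # 8: size of dim 0 of the first tensor found
device = find_device(batch)          # device of the first tensor found (CPU here)

# Guard per-value logic so it only touches actual tensors
for key, value in batch.items():
    if is_torch_tensor(value):
        print(f"{key}: shape={tuple(value.shape)} dtype={value.dtype} on {device}")
```
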
### Process Synchronization

Functions for coordinating execution across distributed processes.

```python { .api }
def wait_for_everyone():
    """
    Synchronization barrier - all processes wait until everyone reaches this point.

    Ensures all processes are synchronized before continuing execution.
    Essential for coordinating distributed operations.
    """

def synchronize_rng_states(rng_types: list[str] | None = None):
    """
    Synchronize random number generator states across all processes.

    Ensures reproducible results in distributed training by making
    all processes use the same random state.

    Parameters:
    - rng_types: Types of RNG to synchronize ("torch", "cuda", "xla").
      If None, synchronizes all available types.
    """

def set_seed(seed: int, device_specific: bool = False):
    """
    Set random seed across all processes and libraries.

    Sets seeds for PyTorch, NumPy, Python random, and other libraries
    to ensure reproducible results.

    Parameters:
    - seed: Random seed value
    - device_specific: Whether to use device-specific seeding
    """
```

### Context Managers

Context managers for controlling distributed behavior during specific operations.

```python { .api }
class GatheredParameters:
    """
    Context manager for gathering distributed parameters.

    Temporarily gathers sharded parameters from all processes,
    enabling operations that require the full parameter tensor.
    """

    def __init__(self, *models, modifier_rank: int | None = None):
        """
        Initialize parameter gathering context.

        Parameters:
        - *models: Models with parameters to gather
        - modifier_rank: Process rank that can modify parameters
        """
```

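A minimal usage sketch based on the signature above: re-initializing a weight of a sharded model from one rank. The `model`, `accelerator`, and `lm_head` names are placeholders created elsewhere, and the import path mirrors the style of the usage examples below; treat this as an illustration of the documented signature rather than a definitive recipe.

```python
# Hedged sketch: temporarily materialize sharded parameters so one rank can
# inspect or re-initialize them. `model` and `accelerator` are placeholders
# (a sharded model and an Accelerator set up elsewhere).
import torch
from accelerate import GatheredParameters

with GatheredParameters(model, modifier_rank=0):
    # Inside the context the full parameters are available;
    # only the modifier rank should write to them.
    if accelerator.is_main_process:
        torch.nn.init.normal_(model.lm_head.weight, std=0.02)
# On exit the parameters are re-sharded across processes.
```
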
### Precision Conversion

Functions for converting tensor precision and preserving container types in distributed settings.

```python { .api }
def convert_to_fp32(tensor: torch.Tensor):
    """
    Convert tensor to FP32 precision.

    Parameters:
    - tensor: Tensor to convert

    Returns:
    Tensor converted to torch.float32
    """

def convert_outputs_to_fp32(data):
    """
    Convert nested data structure outputs to FP32.

    Recursively converts all tensors in nested structures to FP32,
    useful for metric computation and logging.

    Parameters:
    - data: Nested structure containing tensors

    Returns:
    Data structure with all tensors converted to FP32
    """

def honor_type(obj, generator):
    """
    Ensure generated object maintains same type hierarchy as original.

    Parameters:
    - obj: Original object to match type of
    - generator: Generator producing new values

    Returns:
    Object of same type as obj with values from generator
    """
```

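As a rough illustration of how `convert_to_fp32` and `honor_type` fit together when post-processing mixed-precision outputs, here is a short sketch; the half-precision tensors are placeholders and the import path mirrors the usage examples below.

```python
# Hedged sketch: cast a mixed-precision output to FP32 before metric math,
# and rebuild a transformed container without losing its original type.
import torch
from accelerate import convert_to_fp32, honor_type

logits = torch.randn(4, 10, dtype=torch.float16)  # placeholder mixed-precision output
fp32_logits = convert_to_fp32(logits)             # float32 copy, safe for metric math

outputs = (torch.randn(2, 3), torch.randn(2, 3))  # placeholder tuple of tensors
scaled = honor_type(outputs, (t * 0.5 for t in outputs))  # still a tuple after mapping
```
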
## Usage Examples

### Basic Distributed Communication

```python
from accelerate import Accelerator, broadcast, gather, reduce
import torch

# Initialize distributed training first
accelerator = Accelerator()

# Broadcast tensor from main process to all processes
if accelerator.is_main_process:
    data = torch.randn(10, 20)
else:
    data = torch.zeros(10, 20)

broadcast(data, from_process=0)  # Now all processes have the same data

# Gather results from all processes (model and local_batch come from your own setup)
local_result = model(local_batch)
all_results = gather(local_result)  # Only main process gets concatenated results

# Reduce loss across processes (compute_loss, outputs, targets are placeholders)
loss = compute_loss(outputs, targets)
average_loss = reduce(loss, reduction="mean")
```

### Handling Variable-Size Batches

```python
from accelerate import pad_across_processes, gather

# When batch sizes differ across processes
predictions = model(batch)  # Different sizes on each process

# Pad to same size before gathering
padded_predictions = pad_across_processes(predictions, dim=0, pad_index=-100)
all_predictions = gather(padded_predictions)

# Remove padding after gathering (on main process)
if accelerator.is_main_process:
    # Remove padded values (note: boolean indexing flattens the result)
    valid_predictions = all_predictions[all_predictions != -100]
```

### Complex Data Structure Communication

```python
from accelerate import broadcast_object_list, gather_object

# Broadcast complex configuration
if accelerator.is_main_process:
    config = {
        "model_settings": {"layers": 12, "hidden_size": 768},
        "training_params": [0.001, 0.9, 0.999],
        "metadata": {"experiment_name": "test_run", "version": "1.0"}
    }
else:
    config = None

config_list = [config]
broadcast_object_list(config_list)  # List contents are filled in-place on receiving processes
config = config_list[0]             # Extract the broadcast config from the list

# Gather evaluation results
eval_metrics = {"accuracy": 0.95, "f1": 0.93}
all_metrics = gather_object(eval_metrics)

if accelerator.is_main_process:
    # all_metrics is a list of metrics from each process
    avg_accuracy = sum(m["accuracy"] for m in all_metrics) / len(all_metrics)
```

### Advanced Tensor Manipulation

```python
import torch
from accelerate import concatenate, slice_tensors, send_to_device

# Work with nested data structures
batch = {
    "input_ids": torch.tensor([[1, 2, 3], [4, 5, 6]]),
    "attention_mask": torch.tensor([[1, 1, 1], [1, 1, 0]]),
    "labels": torch.tensor([0, 1])
}

# Move entire structure to GPU
batch_gpu = send_to_device(batch, "cuda:0")

# Slice first sample from nested structure
first_sample = slice_tensors(batch, 0)

# Concatenate batches from multiple sources
# (batch1, batch2, batch3 are dicts with the same structure as `batch`)
batches = [batch1, batch2, batch3]
combined_batch = concatenate(batches, dim=0)
```

### Process Synchronization and Reproducibility

```python
from accelerate import wait_for_everyone, set_seed, synchronize_rng_states
from torch.utils.data import DataLoader

# Set reproducible seeds
set_seed(42, device_specific=True)

# Synchronize RNG states across processes
synchronize_rng_states(["torch", "cuda"])

# Coordinate processes for sequential operations: the main process prepares
# the dataset first, everyone else waits, then loads from the shared cache.
if accelerator.is_main_process:
    dataset = download_and_preprocess()  # downloads and writes the cache

wait_for_everyone()  # Wait for the main process to finish

if not accelerator.is_main_process:
    dataset = download_and_preprocess()  # served from the cache prepared above

# Now all processes can safely use the dataset
dataloader = DataLoader(dataset, batch_size=32)
```