0
# Distributed Operations
1
2
Collective communication operations and utilities for coordinating processes in distributed training environments.
3
4
## Capabilities
5
6
### Communication Primitives
7
8
Core collective operations for synchronizing data and computations across distributed processes.
9
10
```python { .api }
11
def barrier(self, name: Optional[str] = None) -> None:
12
"""
13
Synchronize all processes at this point.
14
15
Blocks until all processes reach this barrier. Useful for ensuring
16
all processes complete a phase before proceeding.
17
18
Args:
19
name: Optional name for the barrier (for debugging)
20
21
Raises:
22
RuntimeError: If barrier times out or fails
23
"""
24
25
def broadcast(self, obj: Any, src: int = 0) -> Any:
26
"""
27
Broadcast object from source process to all other processes.
28
29
Args:
30
obj: Object to broadcast (tensor, dict, list, etc.)
31
src: Source process rank (default: 0)
32
33
Returns:
34
The broadcasted object on all processes
35
36
Examples:
37
# Broadcast model parameters from rank 0
38
params = fabric.broadcast(model.state_dict(), src=0)
39
40
# Broadcast configuration dictionary
41
config = fabric.broadcast({"lr": 0.001, "batch_size": 32}, src=0)
42
"""
43
44
def all_gather(
45
self,
46
data: Union[Tensor, dict, list, tuple],
47
group: Optional[Any] = None,
48
sync_grads: bool = False
49
) -> Union[Tensor, dict, list, tuple]:
50
"""
51
Gather data from all processes and stack it along a new leading dimension.
52
53
Each process contributes its data, and all processes receive the result:
54
tensors of shape (world_size, ...); no dimension is added when world_size is 1.
55
56
Args:
57
data: Data to gather (tensor, dict, list, or tuple)
58
group: Process group (None for default group)
59
sync_grads: Whether to synchronize gradients
60
61
Returns:
62
Gathered data from all processes
63
64
Examples:
65
# Gather predictions from all processes
66
local_preds = model(batch)
67
all_preds = fabric.all_gather(local_preds)
68
69
# Gather metrics dictionary
70
local_metrics = {"accuracy": 0.95, "loss": 0.1}
71
all_metrics = fabric.all_gather(local_metrics)
72
"""
73
74
def all_reduce(
75
self,
76
data: Union[Tensor, dict, list, tuple],
77
group: Optional[Any] = None,
78
reduce_op: Union[str, ReduceOp] = "mean"
79
) -> Union[Tensor, dict, list, tuple]:
80
"""
81
Reduce data across all processes using specified operation.
82
83
Applies reduction operation (sum, mean, max, min) across all processes
84
and returns the result to all processes.
85
86
Args:
87
data: Data to reduce (tensor, dict, list, or tuple)
88
group: Process group (None for default group)
89
reduce_op: Reduction operation ("sum", "mean", "max", "min")
90
91
Returns:
92
Reduced data
93
94
Examples:
95
# Average loss across all processes
96
local_loss = compute_loss(batch)
97
avg_loss = fabric.all_reduce(local_loss, reduce_op="mean")
98
99
# Sum gradients across processes
100
grads = fabric.all_reduce(gradients, reduce_op="sum")
101
"""
102
```
103
104
### Synchronization Utilities
105
106
Higher-level utilities for process coordination and data movement.
107
108
```python { .api }
109
def to_device(self, obj: Union[nn.Module, Tensor, Any]) -> Union[nn.Module, Tensor, Any]:
110
"""
111
Move object to the appropriate device.
112
113
Automatically handles device placement for tensors, modules,
114
and nested data structures.
115
116
Args:
117
obj: Object to move to device
118
119
Returns:
120
Object moved to target device
121
122
Examples:
123
# Move tensor to device
124
tensor = torch.randn(10, 10)
125
tensor = fabric.to_device(tensor)
126
127
# Move nested data structure
128
data = {"input": torch.randn(32, 784), "target": torch.randint(0, 10, (32,))}
129
data = fabric.to_device(data)
130
"""
131
132
def print(self, *args, **kwargs) -> None:
133
"""
134
Print only from rank 0 process.
135
136
Prevents duplicate printing in distributed training by only
137
allowing the rank 0 process to print.
138
139
Args:
140
*args: Arguments to print
141
**kwargs: Keyword arguments for print function
142
143
Examples:
144
fabric.print(f"Epoch {epoch}, Loss: {loss:.4f}")
145
fabric.print("Training completed!", file=sys.stderr)
146
"""
147
```
148
149
### Advanced Synchronization
150
151
Context managers and advanced coordination primitives.
152
153
```python { .api }
154
def rank_zero_first(self, local: bool = False) -> Generator:
155
"""
156
Context manager ensuring rank 0 executes first.
157
158
Useful for operations that should be performed by one process first
159
(e.g., dataset preparation, model initialization).
160
161
Args:
162
local: If True, use local rank (within node), otherwise global rank
163
164
Yields:
165
None
166
167
Examples:
168
# Download dataset only on rank 0 first
169
with fabric.rank_zero_first():
170
dataset = download_dataset()
171
172
# Initialize model weights on rank 0 first
173
with fabric.rank_zero_first():
174
if fabric.is_global_zero:
175
initialize_model_weights(model)
176
"""
177
178
def no_backward_sync(
179
self,
180
module: _FabricModule,
181
enabled: bool = True
182
) -> AbstractContextManager:
183
"""
184
Context manager to skip gradient synchronization.
185
186
When enabled, gradients are not synchronized across processes
187
during backward pass. Useful for gradient accumulation.
188
189
Args:
190
module: Fabric-wrapped module
191
enabled: Whether to skip synchronization
192
193
Returns:
194
Context manager
195
196
Examples:
197
# Gradient accumulation without sync
198
for i, batch in enumerate(batches):
199
with fabric.no_backward_sync(model, enabled=(i < accumulate_steps-1)):
200
loss = compute_loss(model, batch)
201
fabric.backward(loss)
202
203
# Final step with synchronization
204
optimizer.step()
205
"""
206
```
207
208
### Process Information
209
210
Properties and methods to query distributed training state.
211
212
```python { .api }
213
@property
214
def global_rank(self) -> int:
215
"""Global rank of current process across all nodes."""
216
217
@property
218
def local_rank(self) -> int:
219
"""Local rank of current process within the current node."""
220
221
@property
222
def node_rank(self) -> int:
223
"""Rank of the current node."""
224
225
@property
226
def world_size(self) -> int:
227
"""Total number of processes across all nodes."""
228
229
@property
230
def is_global_zero(self) -> bool:
231
"""Whether current process is global rank 0."""
232
```
233
234
## Usage Examples
235
236
### Basic Communication
237
238
```python
239
from lightning.fabric import Fabric
240
241
fabric = Fabric(accelerator="gpu", devices=4, strategy="ddp")
fabric.launch()  # Start the worker processes before using any collective operation
242
243
# Broadcast configuration from rank 0
244
if fabric.is_global_zero:
245
config = {"learning_rate": 0.001, "batch_size": 32}
246
else:
247
config = None
248
249
config = fabric.broadcast(config, src=0)
250
print(f"Rank {fabric.global_rank}: {config}")
251
```
252
253
### Gradient Accumulation
254
255
```python
256
# Accumulate gradients over multiple batches
257
accumulate_steps = 4
258
model.train()
259
260
for batch_idx, batch in enumerate(dataloader):
261
# Skip gradient sync except on last accumulation step
262
with fabric.no_backward_sync(model, enabled=((batch_idx + 1) % accumulate_steps != 0)):
263
loss = compute_loss(model, batch) / accumulate_steps
264
fabric.backward(loss)
265
266
# Update weights after accumulation steps
267
if (batch_idx + 1) % accumulate_steps == 0:
268
optimizer.step()
269
optimizer.zero_grad()
270
```
271
272
### Distributed Evaluation
273
274
```python
275
# Evaluate model across all processes
276
model.eval()
277
all_predictions = []
278
all_targets = []
279
280
for batch in eval_dataloader:
281
with torch.no_grad():
282
predictions = model(batch["input"])
283
targets = batch["target"]
284
285
# Gather predictions and targets from all processes
286
all_preds = fabric.all_gather(predictions).reshape(-1, *predictions.shape[1:])
287
all_targs = fabric.all_gather(targets).reshape(-1, *targets.shape[1:])
288
289
all_predictions.append(all_preds)
290
all_targets.append(all_targs)
291
292
# Compute metrics on gathered data
293
if fabric.is_global_zero:
294
predictions = torch.cat(all_predictions)
295
targets = torch.cat(all_targets)
296
accuracy = compute_accuracy(predictions, targets)
297
fabric.print(f"Evaluation accuracy: {accuracy:.4f}")
298
```
299
300
### Loss Synchronization
301
302
```python
303
# Compute and synchronize loss across processes
304
model.train()
305
total_loss = 0
306
num_batches = 0
307
308
for batch in dataloader:
309
loss = compute_loss(model, batch)
310
311
# Synchronize loss across processes for logging
312
sync_loss = fabric.all_reduce(loss.detach().clone(), reduce_op="mean")
313
314
fabric.backward(loss)
315
optimizer.step()
316
optimizer.zero_grad()
317
318
total_loss += sync_loss.item()
319
num_batches += 1
320
321
if num_batches % 100 == 0:
322
avg_loss = total_loss / num_batches
323
fabric.print(f"Step {num_batches}, Avg Loss: {avg_loss:.4f}")
324
```
325
326
### Barrier Synchronization
327
328
```python
329
# Ensure all processes complete data preparation
330
fabric.print("Starting data preparation...")
331
332
# Each process prepares its portion of data
333
prepare_local_data()
334
335
# Wait for all processes to complete
336
fabric.barrier("data_preparation")
337
fabric.print("All processes completed data preparation")
338
339
# Continue with training
340
start_training()
341
```