# Distributed Operations

Collective communication operations and utilities for coordinating processes in distributed training environments.

## Capabilities

### Communication Primitives

Core collective operations for synchronizing data and computations across distributed processes.

```python { .api }
def barrier(self, name: Optional[str] = None) -> None:
    """
    Synchronize all processes at this point.

    Blocks until all processes reach this barrier. Useful for ensuring
    all processes complete a phase before proceeding.

    Args:
        name: Optional name for the barrier (for debugging)

    Raises:
        RuntimeError: If barrier times out or fails
    """

def broadcast(self, obj: Any, src: int = 0) -> Any:
    """
    Broadcast object from source process to all other processes.

    Args:
        obj: Object to broadcast (tensor, dict, list, etc.)
        src: Source process rank (default: 0)

    Returns:
        The broadcasted object on all processes

    Examples:
        # Broadcast model parameters from rank 0
        params = fabric.broadcast(model.state_dict(), src=0)

        # Broadcast configuration dictionary
        config = fabric.broadcast({"lr": 0.001, "batch_size": 32}, src=0)
    """

def all_gather(
    self,
    data: Union[Tensor, dict, list, tuple],
    group: Optional[Any] = None,
    sync_grads: bool = False
) -> Union[Tensor, dict, list, tuple]:
    """
    Gather data from all processes and concatenate.

    Each process contributes its data, and all processes receive
    the concatenated result from all processes.

    Args:
        data: Data to gather (tensor, dict, list, or tuple)
        group: Process group (None for default group)
        sync_grads: Whether to synchronize gradients

    Returns:
        Gathered data from all processes

    Examples:
        # Gather predictions from all processes
        local_preds = model(batch)
        all_preds = fabric.all_gather(local_preds)

        # Gather metrics dictionary
        local_metrics = {"accuracy": 0.95, "loss": 0.1}
        all_metrics = fabric.all_gather(local_metrics)
    """

def all_reduce(
    self,
    data: Union[Tensor, dict, list, tuple],
    group: Optional[Any] = None,
    reduce_op: Union[str, ReduceOp] = "mean"
) -> Union[Tensor, dict, list, tuple]:
    """
    Reduce data across all processes using specified operation.

    Applies reduction operation (sum, mean, max, min) across all processes
    and returns the result to all processes.

    Args:
        data: Data to reduce (tensor, dict, list, or tuple)
        group: Process group (None for default group)
        reduce_op: Reduction operation ("sum", "mean", "max", "min")

    Returns:
        Reduced data

    Examples:
        # Average loss across all processes
        local_loss = compute_loss(batch)
        avg_loss = fabric.all_reduce(local_loss, reduce_op="mean")

        # Sum gradients across processes
        grads = fabric.all_reduce(gradients, reduce_op="sum")
    """
```
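
Because `all_reduce` and `all_gather` accept nested containers, a per-step metrics dictionary can be reduced in a single call. The following is only a sketch, assuming `fabric` has already been created and launched; the metric values are hypothetical, and `fabric.to_device` is the utility documented below:

```python
import torch

# Hypothetical per-process metrics for one step, moved to this process's device
local_metrics = fabric.to_device({
    "loss": torch.tensor(0.42),
    "num_correct": torch.tensor(27.0),
})

# Average every entry of the dict across all processes in one call
global_metrics = fabric.all_reduce(local_metrics, reduce_op="mean")

# Only rank 0 prints, so the log line appears once per step
fabric.print(f"mean loss: {global_metrics['loss'].item():.4f}")
```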

### Synchronization Utilities

Higher-level utilities for process coordination and data movement.

```python { .api }
def to_device(self, obj: Union[nn.Module, Tensor, Any]) -> Union[nn.Module, Tensor, Any]:
    """
    Move object to the appropriate device.

    Automatically handles device placement for tensors, modules,
    and nested data structures.

    Args:
        obj: Object to move to device

    Returns:
        Object moved to target device

    Examples:
        # Move tensor to device
        tensor = torch.randn(10, 10)
        tensor = fabric.to_device(tensor)

        # Move nested data structure
        data = {"input": torch.randn(32, 784), "target": torch.randint(0, 10, (32,))}
        data = fabric.to_device(data)
    """

def print(self, *args, **kwargs) -> None:
    """
    Print only from rank 0 process.

    Prevents duplicate printing in distributed training by only
    allowing the rank 0 process to print.

    Args:
        *args: Arguments to print
        **kwargs: Keyword arguments for print function

    Examples:
        fabric.print(f"Epoch {epoch}, Loss: {loss:.4f}")
        fabric.print("Training completed!", file=sys.stderr)
    """
```
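
A short sketch of how these two utilities typically appear together in a loop. It assumes `fabric`, `model`, and `dataloader` already exist and that each batch is a dict of tensors (a hypothetical layout):

```python
for step, batch in enumerate(dataloader):
    # Move the whole nested batch (a dict of tensors) to the right device at once
    batch = fabric.to_device(batch)

    output = model(batch["input"])

    # Rank-0-only logging avoids printing one line per process
    if step % 50 == 0:
        fabric.print(f"step {step}: output shape {tuple(output.shape)}")
```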

### Advanced Synchronization

Context managers and advanced coordination primitives.

```python { .api }
def rank_zero_first(self, local: bool = False) -> Generator:
    """
    Context manager ensuring rank 0 executes first.

    Useful for operations that should be performed by one process first
    (e.g., dataset preparation, model initialization).

    Args:
        local: If True, use local rank (within node), otherwise global rank

    Yields:
        None

    Examples:
        # Download dataset only on rank 0 first
        with fabric.rank_zero_first():
            dataset = download_dataset()

        # Initialize model weights on rank 0 first
        with fabric.rank_zero_first():
            if fabric.is_global_zero:
                initialize_model_weights(model)
    """

def no_backward_sync(
    self,
    module: _FabricModule,
    enabled: bool = True
) -> AbstractContextManager:
    """
    Context manager to skip gradient synchronization.

    When enabled, gradients are not synchronized across processes
    during the backward pass. Useful for gradient accumulation.

    Args:
        module: Fabric-wrapped module
        enabled: Whether to skip synchronization

    Returns:
        Context manager

    Examples:
        # Gradient accumulation without sync
        for i, batch in enumerate(batches):
            with fabric.no_backward_sync(model, enabled=(i < accumulate_steps - 1)):
                loss = compute_loss(model, batch)
                fabric.backward(loss)

        # Final step with synchronization
        optimizer.step()
    """
```
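
A sketch of the common download-then-reuse pattern with `rank_zero_first`. The `download_and_cache_dataset` helper and the cache path are hypothetical; passing `local=True` means one process per node performs the download while the others wait and then read the warm cache:

```python
# Rank 0 (per node) enters the block first and fills the on-disk cache;
# the remaining ranks run the same code afterwards against the cached copy.
with fabric.rank_zero_first(local=True):
    dataset = download_and_cache_dataset("/tmp/my-dataset")
```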

### Process Information

Properties and methods to query distributed training state.

```python { .api }
@property
def global_rank(self) -> int:
    """Global rank of current process across all nodes."""

@property
def local_rank(self) -> int:
    """Local rank of current process within the current node."""

@property
def node_rank(self) -> int:
    """Rank of the current node."""

@property
def world_size(self) -> int:
    """Total number of processes across all nodes."""

@property
def is_global_zero(self) -> bool:
    """Whether current process is global rank 0."""
```
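
These properties are often enough to shard simple work across processes without a sampler. A minimal sketch (the shard directory and the `process_file` helper are hypothetical):

```python
import glob

all_files = sorted(glob.glob("/data/shards/*.parquet"))

# Strided split: each process handles every world_size-th file
my_files = all_files[fabric.global_rank :: fabric.world_size]

for path in my_files:
    process_file(path)

# Make sure every process has finished before moving on
fabric.barrier("shard_processing")
```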

## Usage Examples

### Basic Communication

```python
from lightning.fabric import Fabric

fabric = Fabric(accelerator="gpu", devices=4, strategy="ddp")
fabric.launch()  # spawn the distributed processes when running the script directly

# Broadcast configuration from rank 0
if fabric.is_global_zero:
    config = {"learning_rate": 0.001, "batch_size": 32}
else:
    config = None

config = fabric.broadcast(config, src=0)
print(f"Rank {fabric.global_rank}: {config}")
```

### Gradient Accumulation

```python
# Accumulate gradients over multiple batches
accumulate_steps = 4
model.train()

for batch_idx, batch in enumerate(dataloader):
    # Skip gradient sync except on the last accumulation step
    with fabric.no_backward_sync(model, enabled=((batch_idx + 1) % accumulate_steps != 0)):
        loss = compute_loss(model, batch) / accumulate_steps
        fabric.backward(loss)

    # Update weights after each group of accumulation steps
    if (batch_idx + 1) % accumulate_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```

### Distributed Evaluation

```python
# Evaluate model across all processes
model.eval()
all_predictions = []
all_targets = []

for batch in eval_dataloader:
    with torch.no_grad():
        predictions = model(batch["input"])
        targets = batch["target"]

    # Gather predictions and targets from all processes
    all_preds = fabric.all_gather(predictions)
    all_targs = fabric.all_gather(targets)

    all_predictions.append(all_preds)
    all_targets.append(all_targs)

# Compute metrics on gathered data
if fabric.is_global_zero:
    predictions = torch.cat(all_predictions)
    targets = torch.cat(all_targets)
    accuracy = compute_accuracy(predictions, targets)
    fabric.print(f"Evaluation accuracy: {accuracy:.4f}")
```
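
Depending on the strategy, `all_gather` on a tensor of shape `(batch, ...)` typically returns a tensor with an extra leading `world_size` dimension. If that is the case in your setup, flatten the gathered results before concatenating; this is a hypothetical adjustment to the snippet above, not part of the original example:

```python
# Merge the leading (world_size, batch) dimensions before concatenating
predictions = torch.cat([p.flatten(0, 1) for p in all_predictions])
targets = torch.cat([t.flatten(0, 1) for t in all_targets])
```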

### Loss Synchronization

```python
# Compute and synchronize loss across processes
model.train()
total_loss = 0
num_batches = 0

for batch in dataloader:
    loss = compute_loss(model, batch)

    # Synchronize loss across processes for logging
    sync_loss = fabric.all_reduce(loss, reduce_op="mean")

    fabric.backward(loss)
    optimizer.step()
    optimizer.zero_grad()

    total_loss += sync_loss.item()
    num_batches += 1

    if num_batches % 100 == 0:
        avg_loss = total_loss / num_batches
        fabric.print(f"Step {num_batches}, Avg Loss: {avg_loss:.4f}")
```

### Barrier Synchronization

```python
# Ensure all processes complete data preparation
fabric.print("Starting data preparation...")

# Each process prepares its portion of data
prepare_local_data()

# Wait for all processes to complete
fabric.barrier("data_preparation")
fabric.print("All processes completed data preparation")

# Continue with training
start_training()
```