# Distributed Operations

Low-level distributed communication primitives for gathering, broadcasting, reducing, and synchronizing data across processes. These functions provide the building blocks for distributed training and inference operations.

## Capabilities

### Basic Communication Primitives

Core distributed operations for communicating tensors and data between processes.

```python { .api }
def broadcast(tensor: torch.Tensor, from_process: int = 0):
    """
    Broadcast tensor from one process to all other processes.

    Parameters:
    - tensor: Tensor to broadcast (modified in-place on receiving processes)
    - from_process: Source process rank (default: 0)
    """

def gather(tensor: torch.Tensor):
    """
    Gather tensors from all processes to the main process.

    Parameters:
    - tensor: Tensor to gather from current process

    Returns:
    Concatenated tensor from all processes (only on main process, None elsewhere)
    """

def reduce(tensor: torch.Tensor, reduction: str = "mean"):
    """
    Reduce tensor across all processes using specified operation.

    Parameters:
    - tensor: Tensor to reduce (modified in-place)
    - reduction: Reduction operation ("mean", "sum")

    Returns:
    Reduced tensor (same shape as input)
    """

def pad_across_processes(
    tensor: torch.Tensor,
    dim: int = 0,
    pad_index: int = 0,
    pad_first: bool = False
):
    """
    Pad tensor to the same size across all processes.

    Useful for gathering tensors of different sizes by padding
    smaller tensors to match the largest tensor size.

    Parameters:
    - tensor: Tensor to pad
    - dim: Dimension along which to pad
    - pad_index: Value to use for padding
    - pad_first: Whether to add padding at the beginning (True) or the end (False) of the dimension

    Returns:
    Padded tensor with the same size across all processes
    """
```

### Object Communication

Functions for communicating arbitrary Python objects between processes.

```python { .api }
def broadcast_object_list(
    objects: list,
    from_process: int = 0
):
    """
    Broadcast list of Python objects from one process to all others.

    Parameters:
    - objects: List of objects to broadcast (modified in-place on receiving processes)
    - from_process: Source process rank
    """

def gather_object(obj):
    """
    Gather Python objects from all processes.

    Parameters:
    - obj: Object to gather from current process

    Returns:
    List of objects from all processes (only on main process, None elsewhere)
    """
```

### Advanced Tensor Operations

Higher-level operations for tensor manipulation in distributed settings.

```python { .api }
def concatenate(data, dim: int = 0):
    """
    Concatenate tensors or nested data structures along specified dimension.

    Handles complex nested structures including lists, tuples, and dictionaries
    containing tensors or other concatenatable objects.

    Parameters:
    - data: Data structure containing tensors to concatenate
    - dim: Dimension along which to concatenate

    Returns:
    Concatenated data structure with same nesting as input
    """

def slice_tensors(data, tensor_slice: slice | int):
    """
    Slice tensors in nested data structures.

    Applies the same slice operation to all tensors found in nested
    lists, tuples, and dictionaries.

    Parameters:
    - data: Nested data structure containing tensors
    - tensor_slice: Slice object or integer index to apply

    Returns:
    Sliced data structure maintaining original nesting
    """

def send_to_device(
    tensor: torch.Tensor,
    device: torch.device | str,
    non_blocking: bool = False,
    skip_keys: list[str] | str | None = None
):
    """
    Move tensor or nested data structure to specified device.

    Recursively moves all tensors in nested structures while preserving
    the original data organization.

    Parameters:
    - tensor: Tensor or nested structure to move
    - device: Target device
    - non_blocking: Whether to use non-blocking transfer
    - skip_keys: Keys to skip when moving nested dictionaries

    Returns:
    Data moved to target device
    """
```

### Data Structure Utilities

Functions for analyzing and manipulating tensor data structures.

```python { .api }
def find_batch_size(data):
    """
    Find batch size from tensor or nested data structure.

    Searches through nested structures to find the first tensor
    and returns its size along dimension 0 (batch dimension).

    Parameters:
    - data: Tensor or nested structure containing tensors

    Returns:
    Batch size (int) or None if no tensors found
    """

def find_device(*args):
    """
    Find device from tensor arguments.

    Searches through arguments to find the first tensor and
    returns its device.

    Parameters:
    - *args: Arguments that may contain tensors

    Returns:
    torch.device of first tensor found, or None
    """

def get_data_structure(data):
    """
    Analyze nested data structure containing tensors.

    Returns metadata about the structure including tensor shapes,
    devices, and nesting patterns.

    Parameters:
    - data: Nested data structure to analyze

    Returns:
    DataStructure object describing the input
    """

def is_torch_tensor(data):
    """
    Check if data is a PyTorch tensor.

    Parameters:
    - data: Object to check

    Returns:
    Boolean indicating if data is a torch.Tensor
    """

def is_tensor_information(data):
    """
    Check if data contains tensor metadata information.

    Parameters:
    - data: Object to check

    Returns:
    Boolean indicating if data is TensorInformation
    """
```
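These helpers make it easy to write code that is agnostic to how a batch is nested. Below is a small illustrative sketch: the batch contents are invented, and the import path assumes the helpers are exposed under `accelerate.utils`, as in the upstream library.

```python
# Illustrative sketch; the batch is made up and the import path is an assumption.
from accelerate.utils import find_batch_size, find_device, is_torch_tensor, send_to_device
import torch

batch = {
    "input_ids": torch.randint(0, 1000, (8, 16)),
    "attention_mask": torch.ones(8, 16, dtype=torch.long),
    "metadata": {"source": "synthetic"},   # non-tensor payload
}

print(find_batch_size(batch))              # 8: dim-0 size of the first tensor found
print(find_device(batch["input_ids"]))     # device of the first tensor argument
print(is_torch_tensor(batch["metadata"]))  # False: a plain dict, not a tensor

# Move every tensor in the structure to the target device, skipping "metadata"
device = "cuda" if torch.cuda.is_available() else "cpu"
batch = send_to_device(batch, device, skip_keys="metadata")
```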

### Process Synchronization

Functions for coordinating execution across distributed processes.

```python { .api }
def wait_for_everyone():
    """
    Synchronization barrier - all processes wait until everyone reaches this point.

    Ensures all processes are synchronized before continuing execution.
    Essential for coordinating distributed operations.
    """

def synchronize_rng_states(rng_types: list[str] | None = None):
    """
    Synchronize random number generator states across all processes.

    Ensures reproducible results in distributed training by making
    all processes use the same random state.

    Parameters:
    - rng_types: Types of RNG to synchronize ("torch", "cuda", "xla").
      If None, synchronizes all available types.
    """

def set_seed(seed: int, device_specific: bool = False):
    """
    Set random seed across all processes and libraries.

    Sets seeds for PyTorch, NumPy, Python random, and other libraries
    to ensure reproducible results.

    Parameters:
    - seed: Random seed value
    - device_specific: If True, offsets the seed by the process index so each
      process uses a slightly different seed
    """
```

### Context Managers

Context managers for controlling distributed behavior during specific operations.

```python { .api }
class GatheredParameters:
    """
    Context manager for gathering distributed parameters.

    Temporarily gathers sharded parameters from all processes,
    enabling operations that require the full parameter tensor.
    """

    def __init__(self, *models, modifier_rank: int | None = None):
        """
        Initialize parameter gathering context.

        Parameters:
        - *models: Models with parameters to gather
        - modifier_rank: Process rank that can modify parameters
        """
```
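As an illustration of the context manager above, the sketch below gathers a model's sharded parameters so that a single process can modify a weight. `model`, `accelerator`, and the `lm_head` attribute are placeholders, and the behavior described in the comments follows the docstring above rather than a verified implementation.

```python
import torch

# Sketch only: `model` is a prepared (possibly sharded) module and
# `model.lm_head` is a hypothetical submodule used for illustration.
with GatheredParameters(model, modifier_rank=0):
    # Inside the context the full, un-sharded parameters are available.
    # Per the API above, only the modifier_rank process should modify them.
    if accelerator.is_main_process:
        torch.nn.init.normal_(model.lm_head.weight, mean=0.0, std=0.02)
```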

### Precision Conversion

Functions for converting tensor precision in distributed settings.

```python { .api }
def convert_to_fp32(tensor: torch.Tensor):
    """
    Convert tensor to FP32 precision.

    Parameters:
    - tensor: Tensor to convert

    Returns:
    Tensor converted to torch.float32
    """

def convert_outputs_to_fp32(data):
    """
    Convert nested data structure outputs to FP32.

    Recursively converts all tensors in nested structures to FP32,
    useful for metric computation and logging.

    Parameters:
    - data: Nested structure containing tensors

    Returns:
    Data structure with all tensors converted to FP32
    """

def honor_type(obj, generator):
    """
    Ensure generated object maintains same type hierarchy as original.

    Parameters:
    - obj: Original object to match type of
    - generator: Generator producing new values

    Returns:
    Object of same type as obj with values from generator
    """
```
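For example, outputs produced under mixed precision can be converted back to FP32 before computing a loss or metrics, where half-precision accumulation may lose accuracy. This is a minimal sketch; `model` and `batch` are placeholders, and only `convert_to_fp32` from the block above is used.

```python
import torch

# Placeholder mixed-precision forward pass (`model` and `batch` are not defined here)
with torch.autocast(device_type="cuda", dtype=torch.float16):
    logits = model(**batch)

# Convert to FP32 before the metric/loss computation
logits_fp32 = convert_to_fp32(logits)
loss = torch.nn.functional.cross_entropy(logits_fp32, batch["labels"])
```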

## Usage Examples

### Basic Distributed Communication

```python
from accelerate import Accelerator
from accelerate import broadcast, gather, reduce
import torch

# Initialize distributed training first
accelerator = Accelerator()

# Broadcast tensor from main process to all processes
if accelerator.is_main_process:
    data = torch.randn(10, 20)
else:
    data = torch.zeros(10, 20)

broadcast(data, from_process=0)  # Now all processes have the same data

# Gather results from all processes (model and local_batch come from your training loop)
local_result = model(local_batch)
all_results = gather(local_result)  # Only main process gets concatenated results

# Reduce loss across processes
loss = compute_loss(outputs, targets)
average_loss = reduce(loss, reduction="mean")
```

### Handling Variable-Size Batches

```python
from accelerate import pad_across_processes, gather

# When batch sizes differ across processes
predictions = model(batch)  # Different sizes on each process

# Pad to the same size before gathering
padded_predictions = pad_across_processes(predictions, dim=0, pad_index=-100)
all_predictions = gather(padded_predictions)

# Remove padding after gathering (on main process)
if accelerator.is_main_process:
    # Remove padded values (boolean indexing flattens the result to 1-D)
    valid_predictions = all_predictions[all_predictions != -100]
```

### Complex Data Structure Communication

```python
from accelerate import broadcast_object_list, gather_object

# Broadcast complex configuration
if accelerator.is_main_process:
    config = {
        "model_settings": {"layers": 12, "hidden_size": 768},
        "training_params": [0.001, 0.9, 0.999],
        "metadata": {"experiment_name": "test_run", "version": "1.0"}
    }
else:
    config = None

config_list = [config]
broadcast_object_list(config_list)  # Fills in the entry on non-main processes
config = config_list[0]  # Extract from list

# Gather evaluation results
eval_metrics = {"accuracy": 0.95, "f1": 0.93}
all_metrics = gather_object(eval_metrics)

if accelerator.is_main_process:
    # all_metrics is a list of metrics from each process
    avg_accuracy = sum(m["accuracy"] for m in all_metrics) / len(all_metrics)
```

### Advanced Tensor Manipulation

```python
from accelerate import concatenate, slice_tensors, send_to_device
import torch

# Work with nested data structures
batch = {
    "input_ids": torch.tensor([[1, 2, 3], [4, 5, 6]]),
    "attention_mask": torch.tensor([[1, 1, 1], [1, 1, 0]]),
    "labels": torch.tensor([0, 1])
}

# Move the entire structure to the GPU
batch_gpu = send_to_device(batch, "cuda:0")

# Slice the first sample from the nested structure
first_sample = slice_tensors(batch, 0)

# Concatenate batches from multiple sources
batches = [batch1, batch2, batch3]
combined_batch = concatenate(batches, dim=0)
```

### Process Synchronization and Reproducibility

```python
from accelerate import wait_for_everyone, set_seed, synchronize_rng_states
from torch.utils.data import DataLoader

# Set reproducible seeds
set_seed(42, device_specific=True)

# Synchronize RNG states across processes
synchronize_rng_states(["torch", "cuda"])

# Coordinate processes for sequential operations
if accelerator.is_main_process:
    # Download and prepare the dataset once (populates the local cache)
    download_and_preprocess()

wait_for_everyone()  # Wait for the main process to finish

# Every process can now safely load the prepared dataset from the local cache
dataset = download_and_preprocess()
dataloader = DataLoader(dataset, batch_size=32)
```