# Utilities

Memory management, checkpointing, model utilities, and other helper functions for training workflows, covering efficient memory use, model and state management, and system-level checks.

## Capabilities

### Memory Management

Functions for optimizing memory usage during training and inference.

```python { .api }
def find_executable_batch_size(
    function: callable,
    starting_batch_size: int = 128
):
    """
    Automatically find the largest executable batch size for a function.

    Searches downward from the starting batch size for the largest batch
    size that doesn't cause out-of-memory errors, useful for maximizing
    hardware utilization.

    Parameters:
    - function: Function to test with different batch sizes
    - starting_batch_size: Initial batch size to try

    Returns:
    Largest batch size that executes successfully
    """

def release_memory(*objects):
    """
    Release memory from specified objects and trigger garbage collection.

    Parameters:
    - *objects: Objects to delete and release memory from
    """
```
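
Not shown in the usage examples further down: a minimal sketch of freeing accelerator memory between pipeline stages with `release_memory`. `MyModel` is a placeholder and the `accelerate.utils` import path is an assumption here.

```python
import torch
from accelerate.utils import release_memory  # assumed import path

model = MyModel().to("cuda")  # placeholder model
optimizer = torch.optim.Adam(model.parameters())

# ... run a training or evaluation stage ...

# Drop the references, run garbage collection, and empty the device cache;
# the variables come back as None so stale handles can't keep memory alive.
model, optimizer = release_memory(model, optimizer)
```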

### Model Utilities

Functions for model introspection, manipulation, and memory analysis.

```python { .api }
def infer_auto_device_map(
    model: torch.nn.Module,
    max_memory: dict[int | str, int | str] | None = None,
    no_split_module_classes: list[str] | None = None,
    dtype: torch.dtype | str | None = None,
    special_dtypes: dict[str, torch.dtype | str] | None = None,
    verbose: bool = False
):
    """
    Automatically infer optimal device mapping for a model.

    Analyzes model size and available memory to determine the best
    placement of layers across devices.

    Parameters:
    - model: Model to analyze
    - max_memory: Maximum memory per device
    - no_split_module_classes: Module classes that shouldn't be split
    - dtype: Data type for memory calculations
    - special_dtypes: Special dtypes for specific parameters
    - verbose: Print detailed mapping information

    Returns:
    Dictionary mapping layer names to devices
    """

def get_balanced_memory(
    model: torch.nn.Module,
    max_memory: dict[int | str, int | str] | None = None,
    no_split_module_classes: list[str] | None = None,
    dtype: torch.dtype | None = None,
    low_zero_memory: bool = False
):
    """
    Calculate balanced memory distribution for model across devices.

    Parameters:
    - model: Model to analyze
    - max_memory: Memory constraints per device
    - no_split_module_classes: Modules to keep together
    - dtype: Data type for calculations
    - low_zero_memory: Use minimal memory for device 0

    Returns:
    Balanced memory allocation across devices
    """

def compute_module_sizes(
    model: torch.nn.Module,
    dtype: torch.dtype | None = None
):
    """
    Compute memory size of each module in the model.

    Parameters:
    - model: Model to analyze
    - dtype: Data type for size calculations

    Returns:
    Dictionary mapping module names to memory sizes in bytes
    """

def get_max_memory(max_memory: dict[int | str, int | str] | None = None):
    """
    Get maximum available memory per device.

    Parameters:
    - max_memory: User-specified memory limits

    Returns:
    Dictionary of available memory per device
    """

def has_offloaded_params(model: torch.nn.Module):
    """
    Check if model has any offloaded parameters.

    Parameters:
    - model: Model to check

    Returns:
    Boolean indicating presence of offloaded parameters
    """
```
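
The usage examples below exercise the device-map helpers; as a small supplement, this sketch queries per-device memory and checks for offloaded weights. It assumes `model` has already been dispatched and that both helpers are importable from `accelerate.utils`.

```python
from accelerate.utils import get_max_memory, has_offloaded_params  # assumed import path

# With no limits supplied, report what every visible device can offer
available = get_max_memory()
for device, capacity in available.items():
    print(f"{device}: {capacity} bytes available")

# True if any parameters were moved off their execution device (e.g. to CPU or disk)
if has_offloaded_params(model):
    print("Some parameters are offloaded and will be streamed in during forward passes.")
```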

### Checkpointing and State Management

Functions for saving and loading training state and model checkpoints.

```python { .api }
def save_accelerator_state(
    output_dir: str | os.PathLike,
    safe_serialization: bool = True
):
    """
    Save complete Accelerator training state.

    Saves model, optimizer, scheduler, and RNG states for complete
    training resumption.

    Parameters:
    - output_dir: Directory to save state files
    - safe_serialization: Use safetensors format when possible
    """

def load_accelerator_state(input_dir: str | os.PathLike):
    """
    Load complete Accelerator training state.

    Parameters:
    - input_dir: Directory containing saved state files
    """

def save_custom_state(
    obj,
    path: str | os.PathLike,
    process_index: int = 0,
    scaler: callable | None = None
):
    """
    Save custom object state with process coordination.

    Parameters:
    - obj: Object to save
    - path: Path to save object
    - process_index: Process responsible for saving
    - scaler: Optional scaling function
    """

def load_custom_state(
    path: str | os.PathLike,
    process_index: int = 0,
    scaler: callable | None = None
):
    """
    Load custom object state.

    Parameters:
    - path: Path to load object from
    - process_index: Process responsible for loading
    - scaler: Optional scaling function

    Returns:
    Loaded object
    """

def load_checkpoint_in_model(
    model: torch.nn.Module,
    checkpoint: str | os.PathLike,
    device_map: dict[str, torch.device | str | int] | None = None,
    offload_folder: str | os.PathLike | None = None,
    dtype: torch.dtype | None = None,
    offload_state_dict: bool = False,
    offload_buffers: bool = False,
    keep_in_fp32_modules: list[str] | None = None,
    strict: bool = False
):
    """
    Load checkpoint into model with advanced options.

    Parameters:
    - model: Model to load checkpoint into
    - checkpoint: Path to checkpoint file
    - device_map: Device placement mapping
    - offload_folder: Directory for offloaded weights
    - dtype: Target data type
    - offload_state_dict: Offload entire state dict
    - offload_buffers: Offload buffer tensors
    - keep_in_fp32_modules: Modules to keep in FP32
    - strict: Strict checkpoint loading

    Returns:
    Tuple of (missing_keys, unexpected_keys)
    """
```
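
The checkpointing example below covers accelerator and custom state, but not `load_checkpoint_in_model`. Here is a minimal sketch of streaming a checkpoint into a model created without allocated weights; `MyModel`, the checkpoint path, and the `accelerate.utils` import path are assumptions.

```python
import torch
from accelerate import init_empty_weights, infer_auto_device_map
from accelerate.utils import load_checkpoint_in_model  # assumed import path

# Build the model skeleton without allocating real weights
with init_empty_weights():
    model = MyModel()  # placeholder model class

# Decide where each layer should live, then load the weights accordingly
device_map = infer_auto_device_map(model, dtype=torch.float16)
load_checkpoint_in_model(
    model,
    "./checkpoint",              # placeholder: checkpoint file or directory of shards
    device_map=device_map,
    dtype=torch.float16,
    offload_folder="./offload"   # where weights mapped to disk are written
)
```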

### Random State Management

Functions for managing random number generation across distributed processes.

```python { .api }
def set_seed(seed: int, device_specific: bool = False):
    """
    Set random seed across all processes and libraries.

    Sets seeds for PyTorch, NumPy, Python random, and other libraries
    to ensure reproducible results across distributed training.

    Parameters:
    - seed: Random seed value
    - device_specific: Use device-specific seeding for different results per device
    """

def synchronize_rng_states(
    rng_types: list[str] | None = None,
    generator: torch.Generator | None = None
):
    """
    Synchronize random number generator states across processes.

    Parameters:
    - rng_types: Types of RNG to sync ("torch", "cuda", "xla")
    - generator: Specific generator to synchronize
    """

def synchronize_rng_state(
    rng_type: str | None = None,
    generator: torch.Generator | None = None
):
    """
    Synchronize single RNG state across processes.

    Parameters:
    - rng_type: Type of RNG to synchronize
    - generator: Specific generator to use
    """
```
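
The random-state example below uses one shared seed. When each process should instead draw a different stream (e.g. for per-rank data augmentation) while staying reproducible, `device_specific=True` offsets the seed by the process index; a one-line sketch:

```python
from accelerate import set_seed

# Each process seeds with 42 + its process index: reproducible per rank, different across ranks
set_seed(42, device_specific=True)
```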

### Model Parameter Management

Functions for managing model parameters, tied weights, and device placement.

```python { .api }
def find_tied_parameters(model: torch.nn.Module):
    """
    Find tied (shared) parameters in model.

    Parameters:
    - model: Model to analyze

    Returns:
    List of parameter groups that share the same tensor
    """

def check_tied_parameters_on_same_device(model: torch.nn.Module):
    """
    Verify that tied parameters are on the same device.

    Parameters:
    - model: Model to check

    Returns:
    Boolean indicating if all tied parameters are properly placed
    """

def retie_parameters(
    model: torch.nn.Module,
    tied_params: list[list[str]]
):
    """
    Re-establish parameter tying after model loading.

    Parameters:
    - model: Model with parameters to retie
    - tied_params: List of parameter groups to tie together
    """

def set_module_tensor_to_device(
    module: torch.nn.Module,
    tensor_name: str,
    device: torch.device | str | int,
    value: torch.Tensor | None = None,
    dtype: torch.dtype | None = None
):
    """
    Set specific tensor in module to device with optional value/dtype.

    Parameters:
    - module: Module containing the tensor
    - tensor_name: Name of tensor to modify
    - device: Target device
    - value: Optional new tensor value
    - dtype: Optional target dtype
    """

def align_module_device(
    module: torch.nn.Module,
    execution_device: torch.device | str | int
):
    """
    Align module device with execution device.

    Parameters:
    - module: Module to align
    - execution_device: Target execution device
    """
```
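
The parameter-management example below covers tied weights; this extra sketch materializes a single tensor on a device with `set_module_tensor_to_device`. The module, dotted tensor name, tensor shape, and the `accelerate.utils` import path are hypothetical.

```python
import torch
from accelerate.utils import set_module_tensor_to_device  # assumed import path

# Replace an empty/meta weight with a real tensor on the target device,
# casting it to half precision in the same step.
set_module_tensor_to_device(
    model,                           # placeholder module
    "decoder.layers.0.fc1.weight",   # hypothetical dotted tensor name
    device="cuda:0",
    value=torch.randn(4096, 1024),   # hypothetical value matching the layer shape
    dtype=torch.float16,
)
```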

### File I/O and Serialization

General-purpose functions for saving and loading objects with device awareness.

```python { .api }
def save(
    obj,
    path: str | os.PathLike,
    save_on_each_node: bool = False,
    safe_serialization: bool = False
):
    """
    Save object with distributed training awareness.

    Parameters:
    - obj: Object to save
    - path: Save path
    - save_on_each_node: Save on each node instead of just main process
    - safe_serialization: Use safetensors format when possible
    """

def load(
    path: str | os.PathLike,
    map_location: str | torch.device | None = None,
    **kwargs
):
    """
    Load object with device mapping support.

    Parameters:
    - path: Path to load from
    - map_location: Device mapping for tensors
    - **kwargs: Additional arguments for loading

    Returns:
    Loaded object
    """

def clean_state_dict_for_safetensors(state_dict: dict):
    """
    Clean state dict for safetensors serialization.

    Removes incompatible elements and prepares dict for safetensors format.

    Parameters:
    - state_dict: State dictionary to clean

    Returns:
    Cleaned state dictionary
    """
```
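
None of the usage examples below cover these helpers, so here is a minimal sketch of cleaning a state dict and writing it with `save`. The `model` variable and the `accelerate.utils` import path are assumptions.

```python
from accelerate.utils import save, clean_state_dict_for_safetensors  # assumed import path

state_dict = model.state_dict()  # placeholder model
state_dict = clean_state_dict_for_safetensors(state_dict)

# Only the main process writes by default; set save_on_each_node=True when
# nodes do not share a filesystem.
save(state_dict, "model.safetensors", safe_serialization=True)
```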

### Environment and Import Detection

Functions for detecting available libraries and hardware capabilities.

```python { .api }
def is_cuda_available():
    """Check if CUDA is available."""

def is_mps_available():
    """Check if Apple MPS is available."""

def is_xpu_available():
    """Check if Intel XPU is available."""

def is_hpu_available():
    """Check if Habana HPU is available."""

def is_npu_available():
    """Check if NPU is available."""

def is_deepspeed_available():
    """Check if DeepSpeed is available."""

def is_transformers_available():
    """Check if Transformers library is available."""

def is_datasets_available():
    """Check if Datasets library is available."""

def is_wandb_available():
    """Check if Weights & Biases is available."""

def is_tensorboard_available():
    """Check if TensorBoard is available."""

def is_comet_ml_available():
    """Check if Comet ML is available."""

def is_mlflow_available():
    """Check if MLflow is available."""

def is_bnb_available():
    """Check if bitsandbytes is available."""

def is_4bit_bnb_available():
    """Check if 4-bit bitsandbytes quantization is available."""

def is_8bit_bnb_available():
    """Check if 8-bit bitsandbytes quantization is available."""

def is_torch_xla_available():
    """Check if Torch XLA is available."""

def is_rich_available():
    """Check if Rich formatting library is available."""
```
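
A small sketch of using the detection helpers to pick an experiment tracker at runtime. It assumes the `Accelerator(log_with=...)` argument from the core training API and the `accelerate.utils` import path for the checks.

```python
from accelerate import Accelerator
from accelerate.utils import is_wandb_available, is_tensorboard_available  # assumed import path

# Prefer Weights & Biases when installed, otherwise fall back to TensorBoard
if is_wandb_available():
    accelerator = Accelerator(log_with="wandb")
elif is_tensorboard_available():
    accelerator = Accelerator(log_with="tensorboard")
else:
    accelerator = Accelerator()
```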

### System Utilities

General system and process management utilities.

```python { .api }
def wait_for_everyone():
    """
    Global synchronization barrier across all processes.
    """

def extract_model_from_parallel(
    model: torch.nn.Module,
    keep_fp32_wrapper: bool = True
):
    """
    Extract original model from parallel training wrappers.

    Parameters:
    - model: Wrapped model
    - keep_fp32_wrapper: Whether to keep mixed precision wrapper

    Returns:
    Unwrapped model
    """

def merge_dicts(dict1: dict, dict2: dict):
    """
    Merge two dictionaries recursively.

    Parameters:
    - dict1: First dictionary
    - dict2: Second dictionary

    Returns:
    Merged dictionary
    """

def get_pretty_name(obj):
    """
    Get human-readable name for object.

    Parameters:
    - obj: Object to get name for

    Returns:
    Pretty string representation
    """

def write_basic_config(
    mixed_precision: str = "no",
    save_location: str = "default"
):
    """
    Write basic Accelerate configuration file.

    Parameters:
    - mixed_precision: Mixed precision mode
    - save_location: Where to save config ("default" or custom path)
    """

def convert_bytes(size_bytes: int):
    """
    Convert bytes to human-readable format.

    Parameters:
    - size_bytes: Size in bytes

    Returns:
    Human-readable size string (e.g., "1.5 GB")
    """
```
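
The system-integration example below focuses on capability checks and configuration; this sketch covers the two process-level helpers it skips, synchronizing before a save and unwrapping a DDP-wrapped model. `model`, the output path, and the `accelerate.utils` import path are placeholders.

```python
import torch
from accelerate.utils import wait_for_everyone, extract_model_from_parallel  # assumed import path

# Make sure every process has finished its work before the main process saves
wait_for_everyone()

# Strip DistributedDataParallel (and similar) wrappers to get the bare module
unwrapped = extract_model_from_parallel(model)
torch.save(unwrapped.state_dict(), "pytorch_model.bin")
```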

## Usage Examples

### Automatic Batch Size Finding

```python
from accelerate import find_executable_batch_size
import torch

def training_function(batch_size):
    # Your training code here
    model = MyModel()
    optimizer = torch.optim.Adam(model.parameters())

    # Simulate training step
    for _ in range(10):
        batch = torch.randn(batch_size, 784)
        loss = model(batch).sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

# Find optimal batch size automatically
optimal_batch_size = find_executable_batch_size(training_function)
print(f"Optimal batch size: {optimal_batch_size}")
```

### Model Memory Analysis

```python
import torch
from accelerate import (
    compute_module_sizes,
    get_balanced_memory,
    infer_auto_device_map
)

# Analyze memory usage of an already-instantiated model
module_sizes = compute_module_sizes(model, dtype=torch.float16)
print("Memory usage per module:")
for name, size in module_sizes.items():
    print(f"{name}: {size / 1024**3:.2f} GB")

# Get balanced memory allocation (GPU keys are integer device indices)
max_memory = {0: "10GB", 1: "10GB", "cpu": "30GB"}
balanced_memory = get_balanced_memory(
    model,
    max_memory=max_memory,
    no_split_module_classes=["LlamaDecoderLayer"]
)

# Infer optimal device mapping
device_map = infer_auto_device_map(
    model,
    max_memory=balanced_memory,
    no_split_module_classes=["LlamaDecoderLayer"],
    verbose=True
)
```

### Advanced Checkpointing

```python
from accelerate import (
    Accelerator,
    save_accelerator_state,
    load_accelerator_state,
    save_custom_state,
    load_custom_state
)

# Save complete training state
accelerator = Accelerator()
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

# After some training...
save_accelerator_state("./checkpoint-1000", safe_serialization=True)

# Save custom objects
training_metadata = {
    "epoch": 5,
    "best_loss": 0.1234,
    "learning_rates": [0.001, 0.0005, 0.0001]
}
save_custom_state(training_metadata, "./checkpoint-1000/metadata.pkl")

# Later, load everything back
load_accelerator_state("./checkpoint-1000")
metadata = load_custom_state("./checkpoint-1000/metadata.pkl")
```

### Random State Management

```python
from accelerate import set_seed, synchronize_rng_states

# Set reproducible seed across all processes
set_seed(42, device_specific=False)

# Synchronize RNG states for consistency
synchronize_rng_states(["torch", "cuda"])

# Training with consistent randomness
for epoch in range(num_epochs):
    # All processes will generate the same random augmentations
    for batch in dataloader:
        augmented_batch = apply_random_augmentation(batch)
        # ... training code
```

### Parameter Management

```python
from accelerate import (
    find_tied_parameters,
    check_tied_parameters_on_same_device,
    retie_parameters
)

# Find tied parameters in model
tied_params = find_tied_parameters(model)
print("Tied parameter groups:", tied_params)

# Check if tied parameters are properly placed
if not check_tied_parameters_on_same_device(model):
    print("Warning: Tied parameters are not on the same device!")

# Re-tie parameters after loading from checkpoint
retie_parameters(model, tied_params)
```

### System Integration

```python
from accelerate import (
    is_cuda_available,
    is_deepspeed_available,
    write_basic_config,
    convert_bytes
)

# Check system capabilities
print(f"CUDA available: {is_cuda_available()}")
print(f"DeepSpeed available: {is_deepspeed_available()}")

# Create basic configuration
if is_cuda_available():
    write_basic_config(mixed_precision="fp16")
else:
    write_basic_config(mixed_precision="no")

# Memory usage reporting
model_size = sum(p.numel() * p.element_size() for p in model.parameters())
print(f"Model size: {convert_bytes(model_size)}")
```