# Configuration and Plugins

Configuration classes and plugins for customizing distributed training behavior, including DeepSpeed integration, FSDP configuration, mixed precision settings, and other advanced training optimizations.

## Capabilities

### Core Configuration Classes

Base configuration objects for controlling distributed training behavior.

```python { .api }
class DataLoaderConfiguration:
    """
    Configuration for DataLoader behavior in distributed training.

    Controls how data is distributed and processed across multiple processes.
    """

    def __init__(
        self,
        split_batches: bool = False,
        dispatch_batches: bool | None = None,
        even_batches: bool = True,
        use_seedable_sampler: bool = False,
        use_configured_sampler: bool = False,
        non_blocking: bool = False,
        gradient_accumulation_kwargs: dict | None = None
    ):
        """
        Initialize DataLoader configuration.

        Parameters:
        - split_batches: Whether to split batches across processes
        - dispatch_batches: Whether to dispatch batches to processes
        - even_batches: Ensure all processes get same number of batches
        - use_seedable_sampler: Use seedable sampler for reproducibility
        - use_configured_sampler: Use custom sampler configuration
        - non_blocking: Use non-blocking data transfer
        - gradient_accumulation_kwargs: Additional gradient accumulation settings
        """

class ProjectConfiguration:
    """
    Configuration for project output directories and logging behavior.
    """

    def __init__(
        self,
        project_dir: str = ".",
        logging_dir: str | None = None,
        automatic_checkpoint_naming: bool = False,
        total_limit: int | None = None,
        iteration_checkpoints: bool = False,
        save_on_each_node: bool = False
    ):
        """
        Initialize project configuration.

        Parameters:
        - project_dir: Root directory for project outputs
        - logging_dir: Directory for log files (relative to project_dir)
        - automatic_checkpoint_naming: Auto-generate checkpoint names
        - total_limit: Maximum number of checkpoints to keep
        - iteration_checkpoints: Save checkpoints by iteration number
        - save_on_each_node: Save checkpoints on every node
        """

class GradientAccumulationPlugin:
    """
    Plugin for configuring gradient accumulation behavior.
    """

    def __init__(
        self,
        num_steps: int | None = None,
        adjust_scheduler: bool = True,
        sync_with_dataloader: bool = True
    ):
        """
        Initialize gradient accumulation plugin.

        Parameters:
        - num_steps: Number of steps to accumulate gradients
        - adjust_scheduler: Adjust scheduler for accumulation steps
        - sync_with_dataloader: Sync accumulation with dataloader length
        """
```

### DeepSpeed Plugin

Configuration for DeepSpeed distributed training integration.

```python { .api }
class DeepSpeedPlugin:
    """
    Plugin for DeepSpeed distributed training configuration.

    Provides integration with Microsoft DeepSpeed for memory-efficient
    training with ZeRO optimizer states, gradient partitioning, and
    parameter offloading.
    """

    def __init__(
        self,
        hf_ds_config: dict | str | None = None,
        gradient_accumulation_steps: int | None = None,
        gradient_clipping: float | None = None,
        zero_stage: int | None = None,
        is_train_batch_min: bool = True,
        auto_wrap_policy: bool | None = None,
        offload_optimizer_device: str | None = None,
        offload_param_device: str | None = None,
        offload_optimizer_nvme_path: str | None = None,
        offload_param_nvme_path: str | None = None,
        zero3_init_flag: bool | None = None,
        zero3_save_16bit_model: bool | None = None,
        **kwargs
    ):
        """
        Initialize DeepSpeed plugin configuration.

        Parameters:
        - hf_ds_config: DeepSpeed configuration dict or path to config file
        - gradient_accumulation_steps: Number of gradient accumulation steps
        - gradient_clipping: Gradient clipping threshold
        - zero_stage: ZeRO optimization stage (0, 1, 2, or 3)
        - is_train_batch_min: Whether train_batch_size is minimum per device
        - auto_wrap_policy: Automatic model wrapping policy
        - offload_optimizer_device: Device for optimizer state offloading
        - offload_param_device: Device for parameter offloading
        - offload_optimizer_nvme_path: NVMe path for optimizer offloading
        - offload_param_nvme_path: NVMe path for parameter offloading
        - zero3_init_flag: Enable ZeRO-3 initialization optimizations
        - zero3_save_16bit_model: Save model in 16-bit precision with ZeRO-3
        """
```

### FSDP Plugin

Configuration for PyTorch Fully Sharded Data Parallel training.

```python { .api }
class FullyShardedDataParallelPlugin:
    """
    Plugin for PyTorch FSDP (Fully Sharded Data Parallel) configuration.

    Enables memory-efficient training by sharding model parameters,
    gradients, and optimizer states across multiple GPUs.
    """

    def __init__(
        self,
        sharding_strategy: int | None = None,
        backward_prefetch: int | None = None,
        mixed_precision_policy: MixedPrecision | None = None,
        auto_wrap_policy: ModuleWrapPolicy | None = None,
        cpu_offload: CPUOffload | None = None,
        ignored_modules: list[torch.nn.Module] | None = None,
        state_dict_type: str | None = None,
        state_dict_config: dict | None = None,
        optim_state_dict_config: dict | None = None,
        limit_all_gathers: bool = True,
        use_orig_params: bool = True,
        param_init_fn: callable | None = None,
        sync_module_states: bool = True,
        forward_prefetch: bool = False,
        activation_checkpointing: bool = False
    ):
        """
        Initialize FSDP plugin configuration.

        Parameters:
        - sharding_strategy: Parameter sharding strategy
        - backward_prefetch: Backward pass prefetching strategy
        - mixed_precision_policy: Mixed precision configuration
        - auto_wrap_policy: Automatic module wrapping policy
        - cpu_offload: CPU offloading configuration
        - ignored_modules: Modules to exclude from FSDP wrapping
        - state_dict_type: Type of state dict to use
        - state_dict_config: State dict configuration
        - optim_state_dict_config: Optimizer state dict configuration
        - limit_all_gathers: Limit simultaneous all-gather operations
        - use_orig_params: Use original parameter references
        - param_init_fn: Custom parameter initialization function
        - sync_module_states: Synchronize module states across ranks
        - forward_prefetch: Enable forward pass prefetching
        - activation_checkpointing: Enable activation checkpointing
        """
```

### Mixed Precision Configuration

Configuration classes for different mixed precision training modes.

```python { .api }
class AutocastKwargs:
    """
    Configuration for PyTorch autocast mixed precision.
    """

    def __init__(
        self,
        enabled: bool = True,
        cache_enabled: bool | None = None
    ):
        """
        Initialize autocast configuration.

        Parameters:
        - enabled: Whether to enable autocast
        - cache_enabled: Whether to enable autocast caching
        """

class GradScalerKwargs:
    """
    Configuration for gradient scaling in mixed precision training.
    """

    def __init__(
        self,
        init_scale: float = 65536.0,
        growth_factor: float = 2.0,
        backoff_factor: float = 0.5,
        growth_interval: int = 2000,
        enabled: bool = True
    ):
        """
        Initialize gradient scaler configuration.

        Parameters:
        - init_scale: Initial scaling factor
        - growth_factor: Factor to multiply scale by when no overflow
        - backoff_factor: Factor to multiply scale by when overflow detected
        - growth_interval: Number of steps between scale increases
        - enabled: Whether gradient scaling is enabled
        """

class FP8RecipeKwargs:
    """
    Configuration for FP8 (8-bit floating point) training.
    """

    def __init__(
        self,
        backend: str = "TE",
        use_autocast: bool = True,
        fp8_format: str = "HYBRID",
        amax_history_len: int = 1024,
        amax_compute_algo: str = "most_recent"
    ):
        """
        Initialize FP8 training configuration.

        Parameters:
        - backend: FP8 backend to use ("TE" for Transformer Engine)
        - use_autocast: Whether to use autocast with FP8
        - fp8_format: FP8 format specification
        - amax_history_len: Length of amax history for scaling
        - amax_compute_algo: Algorithm for computing amax values
        """
```

### Torch Compilation and Optimization

Configuration for PyTorch compilation and optimization features.

```python { .api }
class TorchDynamoPlugin:
    """
    Plugin for PyTorch Dynamo compilation configuration.

    Enables torch.compile optimizations for faster training and inference.
    """

    def __init__(
        self,
        backend: str = "inductor",
        mode: str | None = None,
        fullgraph: bool = False,
        dynamic: bool | None = None,
        options: dict | None = None,
        disable: bool = False
    ):
        """
        Initialize Torch Dynamo plugin.

        Parameters:
        - backend: Compilation backend ("inductor", "aot_eager", etc.)
        - mode: Compilation mode ("default", "reduce-overhead", "max-autotune")
        - fullgraph: Whether to require full graph compilation
        - dynamic: Enable dynamic shape compilation
        - options: Additional backend-specific options
        - disable: Whether to disable compilation
        """

class TorchTensorParallelPlugin:
    """
    Plugin for PyTorch tensor parallelism configuration.
    """

    def __init__(
        self,
        tensor_parallel_degree: int = 1,
        parallelize_plan: dict | None = None
    ):
        """
        Initialize tensor parallel plugin.

        Parameters:
        - tensor_parallel_degree: Degree of tensor parallelism
        - parallelize_plan: Custom parallelization plan
        """
```
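
The Dynamo settings take effect when the plugin is handed to the `Accelerator` at construction time. The sketch below is illustrative: it assumes a recent Accelerate release whose `Accelerator` accepts a `dynamo_plugin` argument; older releases expose only the `dynamo_backend` string argument (for example `dynamo_backend="inductor"`).

```python
from accelerate import Accelerator
from accelerate.utils import TorchDynamoPlugin

# Request torch.compile through the inductor backend with a latency-oriented mode.
dynamo_plugin = TorchDynamoPlugin(
    backend="inductor",
    mode="reduce-overhead",
    fullgraph=False
)

# Assumes a release that exposes `dynamo_plugin`; older versions only accept
# `dynamo_backend="inductor"` on the Accelerator constructor.
accelerator = Accelerator(dynamo_plugin=dynamo_plugin)

# Objects passed through `prepare()` are then compiled with these settings.
# model = accelerator.prepare(model)
```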

### Quantization Configuration

Configuration classes for model quantization techniques.

```python { .api }
class BnbQuantizationConfig:
    """
    Configuration for Bitsandbytes quantization.

    Enables 4-bit and 8-bit quantization for memory-efficient training.
    """

    def __init__(
        self,
        load_in_8bit: bool = False,
        load_in_4bit: bool = False,
        llm_int8_threshold: float = 6.0,
        llm_int8_skip_modules: list[str] | None = None,
        llm_int8_enable_fp32_cpu_offload: bool = False,
        llm_int8_has_fp16_weight: bool = False,
        bnb_4bit_compute_dtype: torch.dtype | None = None,
        bnb_4bit_quant_type: str = "fp4",
        bnb_4bit_use_double_quant: bool = False,
        bnb_4bit_quant_storage: torch.dtype | None = None
    ):
        """
        Initialize Bitsandbytes quantization configuration.

        Parameters:
        - load_in_8bit: Enable 8-bit quantization
        - load_in_4bit: Enable 4-bit quantization
        - llm_int8_threshold: Threshold for int8 quantization
        - llm_int8_skip_modules: Modules to skip during quantization
        - llm_int8_enable_fp32_cpu_offload: Enable FP32 CPU offloading
        - llm_int8_has_fp16_weight: Whether model has FP16 weights
        - bnb_4bit_compute_dtype: Compute dtype for 4-bit operations
        - bnb_4bit_quant_type: 4-bit quantization type ("fp4" or "nf4")
        - bnb_4bit_use_double_quant: Enable double quantization
        - bnb_4bit_quant_storage: Storage dtype for quantized weights
        """
```

### Process Group Configuration

Configuration for distributed process group initialization.

```python { .api }
class InitProcessGroupKwargs:
    """
    Configuration for distributed process group initialization.
    """

    def __init__(
        self,
        init_method: str | None = None,
        timeout: int = 1800,
        backend: str | None = None
    ):
        """
        Initialize process group configuration.

        Parameters:
        - init_method: Method for process group initialization
        - timeout: Timeout for initialization (seconds)
        - backend: Distributed backend to use
        """

class DistributedDataParallelKwargs:
    """
    Configuration for PyTorch DistributedDataParallel wrapper.
    """

    def __init__(
        self,
        dim: int = 0,
        broadcast_buffers: bool = True,
        bucket_cap_mb: int = 25,
        find_unused_parameters: bool = False,
        check_reduction: bool = False,
        gradient_as_bucket_view: bool = False,
        static_graph: bool = False,
        comm_hook: callable | None = None,
        comm_state_option: str | None = None
    ):
        """
        Initialize DDP configuration.

        Parameters:
        - dim: Dimension for gradient reduction
        - broadcast_buffers: Whether to broadcast buffers
        - bucket_cap_mb: Bucket size for gradient communication (MB)
        - find_unused_parameters: Find unused parameters during backward
        - check_reduction: Check gradient reduction correctness
        - gradient_as_bucket_view: Use gradient as bucket view
        - static_graph: Whether computation graph is static
        - comm_hook: Custom communication hook
        - comm_state_option: Communication state configuration
        """
```
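
Both classes are keyword-argument handlers: they are passed to the `Accelerator` through `kwargs_handlers` and forwarded when the process group and the DDP wrapper are created. A minimal sketch that selects the NCCL backend and tolerates parameters that receive no gradient on some steps:

```python
from accelerate import (
    Accelerator,
    DistributedDataParallelKwargs,
    InitProcessGroupKwargs
)

# Use the NCCL backend explicitly for GPU collectives.
process_group_kwargs = InitProcessGroupKwargs(backend="nccl")

# Allow parameters that are unused in some forward passes
# (e.g. conditionally executed branches of the model).
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)

# Both handlers are applied when the Accelerator sets up distributed training.
accelerator = Accelerator(kwargs_handlers=[process_group_kwargs, ddp_kwargs])
```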

## Usage Examples

### Basic Configuration Setup

```python
from accelerate import (
    Accelerator,
    DataLoaderConfiguration,
    ProjectConfiguration,
    GradientAccumulationPlugin
)

# Configure data loading behavior
dataloader_config = DataLoaderConfiguration(
    split_batches=True,
    even_batches=True,
    use_seedable_sampler=True
)

# Configure project outputs
project_config = ProjectConfiguration(
    project_dir="./experiments",
    logging_dir="logs",
    automatic_checkpoint_naming=True,
    total_limit=5
)

# Configure gradient accumulation
grad_accumulation = GradientAccumulationPlugin(
    num_steps=4,
    adjust_scheduler=True
)

# Initialize accelerator with configurations
accelerator = Accelerator(
    mixed_precision="fp16",
    dataloader_config=dataloader_config,
    project_config=project_config,
    gradient_accumulation_plugin=grad_accumulation
)
```

### DeepSpeed Configuration

```python
from accelerate import Accelerator, DeepSpeedPlugin

# Define DeepSpeed configuration
deepspeed_config = {
    "train_batch_size": 16,
    "gradient_accumulation_steps": 4,
    "optimizer": {
        "type": "Adam",
        "params": {"lr": 1e-4}
    },
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {"device": "cpu"},
        "overlap_comm": True,
        "contiguous_gradients": True
    },
    "fp16": {"enabled": True}
}

# Create DeepSpeed plugin
deepspeed_plugin = DeepSpeedPlugin(
    hf_ds_config=deepspeed_config,
    zero_stage=2,
    gradient_accumulation_steps=4,
    gradient_clipping=1.0
)

# Initialize accelerator with DeepSpeed
accelerator = Accelerator(deepspeed_plugin=deepspeed_plugin)
```

### FSDP Configuration

```python
from accelerate import Accelerator, FullyShardedDataParallelPlugin
from torch.distributed.fsdp import ShardingStrategy, BackwardPrefetch

# Configure FSDP plugin
fsdp_plugin = FullyShardedDataParallelPlugin(
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
    cpu_offload=None,  # Keep on GPU
    mixed_precision_policy=None,  # Use default
    auto_wrap_policy=None,  # Auto-detect
    limit_all_gathers=True,
    use_orig_params=True,
    sync_module_states=True
)

# Initialize with FSDP
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
```

### Advanced Mixed Precision Setup

```python
from accelerate import (
    Accelerator,
    AutocastKwargs,
    GradScalerKwargs,
    FP8RecipeKwargs
)

# Configure autocast behavior
autocast_kwargs = AutocastKwargs(
    enabled=True,
    cache_enabled=False
)

# Configure gradient scaling
scaler_kwargs = GradScalerKwargs(
    init_scale=2**16,
    growth_factor=2.0,
    backoff_factor=0.5,
    growth_interval=2000
)

# Configure FP8 training (if supported)
fp8_kwargs = FP8RecipeKwargs(
    backend="TE",
    use_autocast=True,
    fp8_format="HYBRID"
)

# Initialize with advanced mixed precision
accelerator = Accelerator(
    mixed_precision="fp16",
    kwargs_handlers=[autocast_kwargs, scaler_kwargs]
)

# To train in FP8 instead, set mixed_precision="fp8" and include
# fp8_kwargs in kwargs_handlers so the FP8 recipe is applied.
```

### Quantization Configuration

```python
from accelerate import Accelerator, BnbQuantizationConfig
from transformers import AutoModelForCausalLM
import torch

# Configure 4-bit quantization
bnb_config = BnbQuantizationConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True
)

# Note: Quantization is typically applied during model loading
# rather than through Accelerator initialization
model = AutoModelForCausalLM.from_pretrained(
    "model_name",
    quantization_config=bnb_config,
    device_map="auto"
)
```
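
Accelerate also ships a standalone helper, `load_and_quantize_model` (in `accelerate.utils`), that applies a `BnbQuantizationConfig` directly to a model skeleton. The sketch below is illustrative only: the model id and the checkpoint path are placeholders, and it assumes `transformers` and `bitsandbytes` are installed.

```python
from accelerate import init_empty_weights
from accelerate.utils import BnbQuantizationConfig, load_and_quantize_model
from transformers import AutoConfig, AutoModelForCausalLM

# Build the model skeleton without allocating real weights.
with init_empty_weights():
    empty_model = AutoModelForCausalLM.from_config(
        AutoConfig.from_pretrained("model_name")  # placeholder model id
    )

# Configure 8-bit quantization.
bnb_config = BnbQuantizationConfig(load_in_8bit=True, llm_int8_threshold=6.0)

# Quantize while loading weights from a checkpoint (placeholder path).
quantized_model = load_and_quantize_model(
    empty_model,
    weights_location="path/to/checkpoint",
    bnb_quantization_config=bnb_config,
    device_map="auto"
)
```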