or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

accelerators.mdcore-training.mddistributed.mdindex.mdprecision.mdstrategies.mdutilities.md

precision.mddocs/

# Precision

Precision plugins for mixed precision training, quantization, and memory optimization techniques.

## Capabilities

### Base Precision

Abstract base class defining the precision interface.

```python { .api }

class Precision:
    """
    Abstract base class for precision plugins.

    Precision plugins handle numerical precision, mixed precision training,
    quantization, and memory optimization techniques. The methods below
    define the interface that the concrete plugins on this page override.
    """

    def convert_module(self, module: nn.Module) -> nn.Module:
        """Convert module to target precision."""

    def convert_input(self, data: Any) -> Any:
        """Convert input data to target precision."""

    def convert_output(self, data: Any) -> Any:
        """Convert output data from target precision."""

    def pre_backward(self, tensor: Tensor, module: nn.Module) -> Tensor:
        """Pre-process tensor before backward pass."""

    def post_backward(self, tensor: Tensor, module: nn.Module) -> Tensor:
        """Post-process tensor after backward pass."""

    def forward_context(self) -> AbstractContextManager:
        """Context manager for forward pass precision."""

    def optimizer_step(
        self,
        optimizer: Optimizer,
        model: nn.Module,
        closure: Callable[[], Any],
        **kwargs: Any,
    ) -> Any:
        """
        Execute optimizer step with precision handling.

        Args:
            optimizer: Optimizer whose step is executed.
            model: Module being optimized.
            closure: Zero-argument callable following the
                ``torch.optim.Optimizer.step(closure)`` contract.
            **kwargs: Additional keyword arguments.
        """

    def state_dict(self) -> dict[str, Any]:
        """Get precision plugin state."""

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """Load precision plugin state."""

```

### Double Precision

64-bit double precision for maximum numerical accuracy.

```python { .api }

class DoublePrecision(Precision):
    """
    64-bit double precision plugin.

    Provides maximum numerical precision using 64-bit floating point
    arithmetic. Useful for research requiring high precision. Note that
    float64 values take twice the memory of float32 values.
    """

    def convert_module(self, module: nn.Module) -> nn.Module:
        """Convert module parameters and buffers to float64."""

    def convert_input(self, data: Any) -> Any:
        """Convert input tensors to float64."""

    def forward_context(self) -> AbstractContextManager:
        """Context manager ensuring double precision during forward pass."""

```

### Half Precision

16-bit half precision for memory efficiency.

```python { .api }

class HalfPrecision(Precision):
    """
    16-bit half precision plugin.

    Uses 16-bit floating point (float16) for memory efficiency
    and faster training on supported hardware. float16 has a much
    narrower dynamic range than float32, so values can overflow or
    underflow more easily.
    """

    def convert_module(self, module: nn.Module) -> nn.Module:
        """Convert module parameters and buffers to float16."""

    def convert_input(self, data: Any) -> Any:
        """Convert input tensors to float16."""

    def forward_context(self) -> AbstractContextManager:
        """Context manager for half precision forward pass."""

```

### Mixed Precision (AMP)

Automatic Mixed Precision using PyTorch's native AMP implementation.

```python { .api }

class MixedPrecision(Precision):
    """
    Automatic Mixed Precision plugin using PyTorch AMP.

    Runs eligible operations in a 16-bit type (float16 for "16-mixed",
    bfloat16 for "bf16-mixed") while keeping numerically sensitive
    operations in float32. With float16, a gradient scaler scales the loss
    so small gradients do not underflow; bfloat16 has the same exponent
    range as float32 and normally does not need loss scaling.
    """

    def __init__(
        self,
        precision: Union[str, int] = "16-mixed",
        device: str = "cuda",
        scaler: Optional[torch.amp.GradScaler] = None,
    ):
        """
        Initialize mixed precision plugin.

        Args:
            precision: Precision mode ("16-mixed", "bf16-mixed")
            device: Target device ("cuda", "cpu")
            scaler: Custom gradient scaler instance. Instances of the
                deprecated ``torch.cuda.amp.GradScaler`` are also accepted,
                since it subclasses ``torch.amp.GradScaler``.
        """

    def setup_scaler(self) -> torch.amp.GradScaler:
        """Setup gradient scaler for loss scaling."""

    def forward_context(self) -> AbstractContextManager:
        """Autocast context manager for mixed precision forward pass."""

    def optimizer_step(
        self,
        optimizer: Optimizer,
        model: nn.Module,
        closure: Callable[[], Any],
        **kwargs: Any,
    ) -> Any:
        """
        Optimizer step with gradient scaling and unscaling.

        Args:
            optimizer: Optimizer whose step is executed.
            model: Module being optimized.
            closure: Zero-argument callable following the
                ``torch.optim.Optimizer.step(closure)`` contract.
            **kwargs: Additional keyword arguments.
        """

    def pre_backward(self, tensor: Tensor, module: nn.Module) -> Tensor:
        """Scale loss before backward pass."""

    def post_backward(self, tensor: Tensor, module: nn.Module) -> Tensor:
        """Handle gradient unscaling after backward pass."""

```

### BitsAndBytes Precision

Quantization using the bitsandbytes library for memory-efficient training.

```python { .api }

class BitsandbytesPrecision(Precision):
    """
    BitsAndBytes precision plugin for quantized training.

    Uses BitsAndBytes library for 8-bit and 4-bit quantization
    to reduce memory usage for large model training.
    """

    def __init__(
        self,
        # String annotation: the config type comes from an optional
        # dependency that may not be installed.
        mode: Union[str, "BitsAndBytesConfig"],
        dtype: Optional[torch.dtype] = None,
        ignore_modules: Optional[set[str]] = None,
    ):
        """
        Initialize BitsAndBytes precision plugin.

        Args:
            mode: Quantization mode ("nf4", "fp4", "int8") or config object
            dtype: Compute dtype for quantized weights
            ignore_modules: Set of module names to skip quantization
        """

    def convert_module(self, module: nn.Module) -> nn.Module:
        """Convert module to use quantized weights."""

    def setup_bnb_config(self) -> "BitsAndBytesConfig":
        """Setup BitsAndBytes configuration."""

```

### DeepSpeed Precision

Precision plugin integrated with DeepSpeed for large-scale training.

```python { .api }

class DeepSpeedPrecision(Precision):
    """
    DeepSpeed precision plugin.

    Handles precision in conjunction with DeepSpeed strategy
    for large-scale model training with ZeRO optimizations.
    """

    def __init__(
        self,
        precision: Union[str, int] = "16-mixed",
        amp_type: str = "native",
        amp_level: Optional[str] = None,
    ):
        """
        Initialize DeepSpeed precision plugin.

        Args:
            precision: Precision mode
            amp_type: AMP implementation ("native", "apex")
            amp_level: APEX AMP level if using APEX; ignored otherwise
        """

    def convert_module(self, module: nn.Module) -> nn.Module:
        """Convert module for DeepSpeed precision handling."""

    def forward_context(self) -> AbstractContextManager:
        """Context manager for DeepSpeed precision forward pass."""

```

### FSDP Precision

Precision plugin optimized for Fully Sharded Data Parallel (FSDP) training.

```python { .api }

class FSDPPrecision(Precision):
    """
    FSDP precision plugin.

    Handles precision in conjunction with FSDP strategy,
    managing parameter and gradient precision for sharded training.
    """

    def __init__(
        self,
        precision: Union[str, int] = "32-true",
        scaler: Optional[torch.amp.GradScaler] = None,
    ):
        """
        Initialize FSDP precision plugin.

        Args:
            precision: Precision mode
            scaler: Custom gradient scaler
        """

    def convert_module(self, module: nn.Module) -> nn.Module:
        """Convert module for FSDP precision handling."""

    # Returns PyTorch's FSDP config object, not the Lightning
    # ``MixedPrecision`` plugin of the same name.
    def setup_mixed_precision_config(
        self,
    ) -> Optional["torch.distributed.fsdp.MixedPrecision"]:
        """
        Setup FSDP mixed precision configuration.

        Returns:
            A ``torch.distributed.fsdp.MixedPrecision`` config, or ``None``.
        """

```

### XLA Precision

Precision plugin for XLA/TPU training.

```python { .api }

class XLAPrecision(Precision):
    """
    XLA precision plugin for TPU training.

    Handles precision for XLA-compiled models running on TPUs,
    with support for bfloat16 and float32 precision.
    """

    def __init__(self, precision: Union[str, int] = "32-true"):
        """
        Initialize XLA precision plugin.

        Args:
            precision: Precision mode ("32-true", "bf16-mixed")
        """

    def convert_module(self, module: nn.Module) -> nn.Module:
        """Convert module for XLA precision handling."""

    def forward_context(self) -> AbstractContextManager:
        """Context manager for XLA precision forward pass."""

```

### Transformer Engine Precision

NVIDIA Transformer Engine precision for optimized transformer training.

```python { .api }

class TransformerEnginePrecision(Precision):
    """
    Transformer Engine precision plugin.

    Uses NVIDIA Transformer Engine for optimized transformer
    model training with FP8 precision on supported hardware.
    """

    def __init__(
        self,
        precision: Union[str, int] = "16-mixed",
        replace_layers: bool = True,
        fp8_format: str = "hybrid",
    ):
        """
        Initialize Transformer Engine precision plugin.

        Args:
            precision: Base precision mode
            replace_layers: Whether to replace standard layers with TE layers
            fp8_format: FP8 format ("e4m3", "e5m2", "hybrid"); "hybrid" uses
                E4M3 in the forward pass and E5M2 in the backward pass
        """

    def convert_module(self, module: nn.Module) -> nn.Module:
        """Convert transformer layers to Transformer Engine layers."""

    # String annotation: DelayedScaling lives in the optional
    # transformer_engine package (transformer_engine.common.recipe).
    def setup_fp8_recipe(self) -> "DelayedScaling":
        """Setup FP8 recipe for Transformer Engine."""

```

## Usage Examples

### Basic Mixed Precision

```python

from lightning.fabric import Fabric

# 16-bit automatic mixed precision (float16)
fabric = Fabric(accelerator="gpu", precision="16-mixed")

# bfloat16 mixed precision: shares float32's exponent range, which makes
# it numerically more stable than float16
fabric = Fabric(accelerator="gpu", precision="bf16-mixed")

```

### Custom AMP Configuration

```python

import torch

from lightning.fabric import Fabric
from lightning.fabric.plugins.precision import MixedPrecision

# Custom gradient scaler. torch.cuda.amp.GradScaler(...) is deprecated in
# favor of torch.amp.GradScaler("cuda", ...). The values below equal
# PyTorch's defaults; tune them as needed.
scaler = torch.amp.GradScaler(
    "cuda",
    init_scale=2**16,
    growth_factor=2.0,
    backoff_factor=0.5,
    growth_interval=2000,
)

precision_plugin = MixedPrecision(
    precision="16-mixed",
    device="cuda",
    scaler=scaler,
)

# Pass the configured plugin object instead of a precision string
fabric = Fabric(
    precision=precision_plugin,
    accelerator="gpu",
)

```

### BitsAndBytes Quantization

```python

import torch

from lightning.fabric import Fabric
from lightning.fabric.plugins.precision import BitsandbytesPrecision

# Option 1: 8-bit quantization
precision_plugin = BitsandbytesPrecision(mode="int8")

# Option 2: 4-bit NormalFloat quantization with bfloat16 compute dtype,
# leaving the listed modules unquantized
precision_plugin = BitsandbytesPrecision(
    mode="nf4",
    dtype=torch.bfloat16,
    ignore_modules={"lm_head", "embed_tokens"},
)

fabric = Fabric(
    precision=precision_plugin,
    accelerator="gpu",
)

```

### DeepSpeed Precision Integration

```python

from lightning.fabric import Fabric
from lightning.fabric.plugins.precision import DeepSpeedPrecision
from lightning.fabric.strategies import DeepSpeedStrategy

# DeepSpeed (ZeRO stage 2) with 16-bit mixed precision
precision_plugin = DeepSpeedPrecision(precision="16-mixed")
strategy = DeepSpeedStrategy(stage=2)

fabric = Fabric(
    strategy=strategy,
    precision=precision_plugin,
    devices=8,
)

```

### FSDP with Mixed Precision

```python

import torch
from torch.distributed.fsdp import MixedPrecision as FSDPMixedPrecision

from lightning.fabric import Fabric
from lightning.fabric.plugins.precision import FSDPPrecision
from lightning.fabric.strategies import FSDPStrategy

# FSDP mixed precision configuration: parameters, gradient reduction and
# buffers all use bfloat16
fsdp_precision = FSDPPrecision(precision="bf16-mixed")
fsdp_strategy = FSDPStrategy(
    mixed_precision=FSDPMixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
    )
)

fabric = Fabric(
    strategy=fsdp_strategy,
    precision=fsdp_precision,
    devices=4,
)

```

### TPU BFloat16 Training

```python

from lightning.fabric import Fabric
from lightning.fabric.plugins.precision import XLAPrecision

# TPU with bfloat16 precision
precision_plugin = XLAPrecision(precision="bf16-mixed")

fabric = Fabric(
    accelerator="tpu",
    strategy="xla",
    precision=precision_plugin,
    devices=8,
)

```

### Manual Precision Control

```python

from lightning.fabric import Fabric

# Manual autocast usage
# (model, optimizer, dataloader and criterion are defined elsewhere)
fabric = Fabric(precision="16-mixed")

model, optimizer = fabric.setup(model, optimizer)

for batch in dataloader:
    optimizer.zero_grad()

    # Manual autocast context: both the forward pass and the loss
    # computation run under the precision plugin's autocast
    with fabric.autocast():
        predictions = model(batch["input"])
        loss = criterion(predictions, batch["target"])

    # Use fabric.backward(loss) instead of loss.backward() so the
    # precision plugin can handle the backward pass
    fabric.backward(loss)
    optimizer.step()

```

### Gradient Clipping with Precision

```python

from lightning.fabric import Fabric

# Gradient clipping with mixed precision
# (model, optimizer, dataloader and compute_loss are defined elsewhere)
fabric = Fabric(precision="16-mixed")

model, optimizer = fabric.setup(model, optimizer)

for batch in dataloader:
    optimizer.zero_grad()

    with fabric.autocast():
        loss = compute_loss(model, batch)

    fabric.backward(loss)

    # Clip gradients (handles unscaling automatically)
    fabric.clip_gradients(model, optimizer, max_norm=1.0)

    optimizer.step()

```

### Precision State Management

```python

from lightning.fabric import Fabric

# Save precision state in checkpoint
# (model and optimizer are set up with this Fabric instance elsewhere)
fabric = Fabric(precision="16-mixed")

# Store the precision plugin's state explicitly, next to the model and
# optimizer
state = {
    "model": model,
    "optimizer": optimizer,
    "precision": fabric.precision_plugin.state_dict(),
}
fabric.save("checkpoint.ckpt", state)

# Load: passing `state` restores the model and optimizer in place and
# replaces state["precision"] with the saved precision state
fabric.load("checkpoint.ckpt", state)
fabric.precision_plugin.load_state_dict(state["precision"])

```