0
# Precision
1
2
Precision plugins for mixed precision training, quantization, and memory optimization techniques.
3
4
## Capabilities
5
6
### Base Precision
7
8
Abstract base class defining the precision interface.
9
10
```python { .api }
11
class Precision:
12
"""
13
Abstract base class for precision plugins.
14
15
Precision plugins handle numerical precision, mixed precision training,
16
quantization, and memory optimization techniques.
17
"""
18
19
def convert_module(self, module: nn.Module) -> nn.Module:
20
"""Convert module to target precision."""
21
22
def convert_input(self, data: Any) -> Any:
23
"""Convert input data to target precision."""
24
25
def convert_output(self, data: Any) -> Any:
26
"""Convert output data from target precision."""
27
28
def pre_backward(self, tensor: Tensor, module: nn.Module) -> Tensor:
29
"""Pre-process tensor before backward pass."""
30
31
def post_backward(self, tensor: Tensor, module: nn.Module) -> Tensor:
32
"""Post-process tensor after backward pass."""
33
34
def forward_context(self) -> AbstractContextManager:
35
"""Context manager for forward pass precision."""
36
37
def optimizer_step(
38
self,
39
optimizer: Optimizer,
40
model: nn.Module,
41
closure: Callable[[], Any],
42
**kwargs
43
) -> Any:
44
"""Execute optimizer step with precision handling."""
45
46
def state_dict(self) -> dict[str, Any]:
47
"""Get precision plugin state."""
48
49
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
50
"""Load precision plugin state."""
51
```
52
53
### Double Precision
54
55
64-bit double precision for maximum numerical accuracy.
56
57
```python { .api }
58
class DoublePrecision(Precision):
59
"""
60
64-bit double precision plugin.
61
62
Provides maximum numerical precision using 64-bit floating point
63
arithmetic. Useful for research requiring high precision.
64
"""
65
66
def convert_module(self, module: nn.Module) -> nn.Module:
67
"""Convert module parameters and buffers to float64."""
68
69
def convert_input(self, data: Any) -> Any:
70
"""Convert input tensors to float64."""
71
72
def forward_context(self) -> AbstractContextManager:
73
"""Context manager ensuring double precision during forward pass."""
74
```
75
76
### Half Precision
77
78
16-bit half precision for memory efficiency.
79
80
```python { .api }
81
class HalfPrecision(Precision):
82
"""
83
16-bit half precision plugin.
84
85
Uses 16-bit floating point (float16) for memory efficiency
86
and faster training on supported hardware.
87
"""
88
89
def convert_module(self, module: nn.Module) -> nn.Module:
90
"""Convert module parameters and buffers to float16."""
91
92
def convert_input(self, data: Any) -> Any:
93
"""Convert input tensors to float16."""
94
95
def forward_context(self) -> AbstractContextManager:
96
"""Context manager for half precision forward pass."""
97
```
98
99
### Mixed Precision (AMP)
100
101
Automatic Mixed Precision using PyTorch's native AMP implementation.
102
103
```python { .api }
104
class MixedPrecision(Precision):
105
"""
106
Automatic Mixed Precision plugin using PyTorch AMP.
107
108
Combines 16-bit precision (float16 or bfloat16) for speed with float32 precision
109
for numerical stability; automatic loss scaling is applied for float16.
110
"""
111
112
def __init__(
113
self,
114
precision: Union[str, int] = "16-mixed",
115
device: str = "cuda",
116
scaler: Optional[torch.cuda.amp.GradScaler] = None
117
):
118
"""
119
Initialize mixed precision plugin.
120
121
Args:
122
precision: Precision mode ("16-mixed", "bf16-mixed")
123
device: Target device ("cuda", "cpu")
124
scaler: Custom gradient scaler instance
125
"""
126
127
def setup_scaler(self) -> torch.cuda.amp.GradScaler:
128
"""Setup gradient scaler for loss scaling."""
129
130
def forward_context(self) -> AbstractContextManager:
131
"""Autocast context manager for mixed precision forward pass."""
132
133
def optimizer_step(
134
self,
135
optimizer: Optimizer,
136
model: nn.Module,
137
closure: Callable[[], Any],
138
**kwargs
139
) -> Any:
140
"""Optimizer step with gradient scaling and unscaling."""
141
142
def pre_backward(self, tensor: Tensor, module: nn.Module) -> Tensor:
143
"""Scale loss before backward pass."""
144
145
def post_backward(self, tensor: Tensor, module: nn.Module) -> Tensor:
146
"""Handle gradient unscaling after backward pass."""
147
```
148
149
### BitsAndBytes Precision
150
151
Quantization using BitsAndBytes library for memory-efficient training.
152
153
```python { .api }
154
class BitsandbytesPrecision(Precision):
155
"""
156
BitsAndBytes precision plugin for quantized training.
157
158
Uses BitsAndBytes library for 8-bit and 4-bit quantization
159
to reduce memory usage for large model training.
160
"""
161
162
def __init__(
163
self,
164
mode: Union[str, BitsAndBytesConfig],
165
dtype: Optional[torch.dtype] = None,
166
ignore_modules: Optional[set[str]] = None
167
):
168
"""
169
Initialize BitsAndBytes precision plugin.
170
171
Args:
172
mode: Quantization mode ("nf4", "fp4", "int8") or config object
173
dtype: Compute dtype for quantized weights
174
ignore_modules: Set of module names to skip quantization
175
"""
176
177
def convert_module(self, module: nn.Module) -> nn.Module:
178
"""Convert module to use quantized weights."""
179
180
def setup_bnb_config(self) -> BitsAndBytesConfig:
181
"""Setup BitsAndBytes configuration."""
182
```
183
184
### DeepSpeed Precision
185
186
Precision plugin integrated with DeepSpeed for large-scale training.
187
188
```python { .api }
189
class DeepSpeedPrecision(Precision):
190
"""
191
DeepSpeed precision plugin.
192
193
Handles precision in conjunction with DeepSpeed strategy
194
for large-scale model training with ZeRO optimizations.
195
"""
196
197
def __init__(
198
self,
199
precision: Union[str, int] = "16-mixed",
200
amp_type: str = "native",
201
amp_level: Optional[str] = None
202
):
203
"""
204
Initialize DeepSpeed precision plugin.
205
206
Args:
207
precision: Precision mode
208
amp_type: AMP implementation ("native", "apex")
209
amp_level: APEX AMP level if using APEX
210
"""
211
212
def convert_module(self, module: nn.Module) -> nn.Module:
213
"""Convert module for DeepSpeed precision handling."""
214
215
def forward_context(self) -> AbstractContextManager:
216
"""Context manager for DeepSpeed precision forward pass."""
217
```
218
219
### FSDP Precision
220
221
Precision plugin optimized for Fully Sharded Data Parallel training.
222
223
```python { .api }
224
class FSDPPrecision(Precision):
225
"""
226
FSDP precision plugin.
227
228
Handles precision in conjunction with FSDP strategy,
229
managing parameter and gradient precision for sharded training.
230
"""
231
232
def __init__(
233
self,
234
precision: Union[str, int] = "32-true",
235
scaler: Optional[torch.cuda.amp.GradScaler] = None
236
):
237
"""
238
Initialize FSDP precision plugin.
239
240
Args:
241
precision: Precision mode
242
scaler: Custom gradient scaler
243
"""
244
245
def convert_module(self, module: nn.Module) -> nn.Module:
246
"""Convert module for FSDP precision handling."""
247
248
def setup_mixed_precision_config(self) -> Optional[MixedPrecision]:
249
"""Set up the FSDP mixed precision configuration (torch.distributed.fsdp.MixedPrecision, not the MixedPrecision plugin)."""
250
```
251
252
### XLA Precision
253
254
Precision plugin for XLA/TPU training.
255
256
```python { .api }
257
class XLAPrecision(Precision):
258
"""
259
XLA precision plugin for TPU training.
260
261
Handles precision for XLA-compiled models running on TPUs,
262
with support for bfloat16 and float32 precision.
263
"""
264
265
def __init__(self, precision: Union[str, int] = "32-true"):
266
"""
267
Initialize XLA precision plugin.
268
269
Args:
270
precision: Precision mode ("32-true", "bf16-mixed")
271
"""
272
273
def convert_module(self, module: nn.Module) -> nn.Module:
274
"""Convert module for XLA precision handling."""
275
276
def forward_context(self) -> AbstractContextManager:
277
"""Context manager for XLA precision forward pass."""
278
```
279
280
### Transformer Engine Precision
281
282
NVIDIA Transformer Engine precision for optimized transformer training.
283
284
```python { .api }
285
class TransformerEnginePrecision(Precision):
286
"""
287
Transformer Engine precision plugin.
288
289
Uses NVIDIA Transformer Engine for optimized transformer
290
model training with FP8 precision on supported hardware.
291
"""
292
293
def __init__(
294
self,
295
precision: Union[str, int] = "16-mixed",
296
replace_layers: bool = True,
297
fp8_format: str = "hybrid"
298
):
299
"""
300
Initialize Transformer Engine precision plugin.
301
302
Args:
303
precision: Base precision mode
304
replace_layers: Whether to replace standard layers with TE layers
305
fp8_format: FP8 format ("e4m3", "e5m2", "hybrid")
306
"""
307
308
def convert_module(self, module: nn.Module) -> nn.Module:
309
"""Convert transformer layers to Transformer Engine layers."""
310
311
def setup_fp8_recipe(self) -> DelayedScaling:
312
"""Setup FP8 recipe for Transformer Engine."""
313
```
314
315
## Usage Examples
316
317
### Basic Mixed Precision
318
319
```python
320
from lightning.fabric import Fabric
321
322
# Automatic mixed precision with float16
323
fabric = Fabric(precision="16-mixed", accelerator="gpu")
324
325
# BFloat16 mixed precision (wider dynamic range than float16, so more numerically stable)
326
fabric = Fabric(precision="bf16-mixed", accelerator="gpu")
327
```
328
329
### Custom AMP Configuration
330
331
```python
332
from lightning.fabric.plugins.precision import MixedPrecision
333
import torch
334
335
# Custom gradient scaler
336
scaler = torch.cuda.amp.GradScaler(
337
init_scale=2**16,
338
growth_factor=2.0,
339
backoff_factor=0.5,
340
growth_interval=2000
341
)
342
343
precision_plugin = MixedPrecision(
344
precision="16-mixed",
345
device="cuda",
346
scaler=scaler
347
)
348
349
fabric = Fabric(
350
precision=precision_plugin,
351
accelerator="gpu"
352
)
353
```
354
355
### BitsAndBytes Quantization
356
357
```python
358
from lightning.fabric.plugins.precision import BitsandbytesPrecision
359
360
# 8-bit quantization
361
precision_plugin = BitsandbytesPrecision(mode="int8")
362
363
# 4-bit NormalFloat quantization
364
precision_plugin = BitsandbytesPrecision(
365
mode="nf4",
366
dtype=torch.bfloat16,
367
ignore_modules={"lm_head", "embed_tokens"}
368
)
369
370
fabric = Fabric(
371
precision=precision_plugin,
372
accelerator="gpu"
373
)
374
```
375
376
### DeepSpeed Precision Integration
377
378
```python
379
from lightning.fabric.plugins.precision import DeepSpeedPrecision
380
from lightning.fabric.strategies import DeepSpeedStrategy
381
382
# DeepSpeed with mixed precision
383
precision_plugin = DeepSpeedPrecision(precision="16-mixed")
384
strategy = DeepSpeedStrategy(stage=2)
385
386
fabric = Fabric(
387
strategy=strategy,
388
precision=precision_plugin,
389
devices=8
390
)
391
```
392
393
### FSDP with Mixed Precision
394
395
```python
396
from lightning.fabric.plugins.precision import FSDPPrecision
397
from lightning.fabric.strategies import FSDPStrategy
398
from torch.distributed.fsdp import MixedPrecision as FSDPMixedPrecision
399
400
# FSDP mixed precision configuration
401
fsdp_precision = FSDPPrecision(precision="bf16-mixed")
402
fsdp_strategy = FSDPStrategy(
403
mixed_precision=FSDPMixedPrecision(
404
param_dtype=torch.bfloat16,
405
reduce_dtype=torch.bfloat16,
406
buffer_dtype=torch.bfloat16
407
)
408
)
409
410
fabric = Fabric(
411
strategy=fsdp_strategy,
412
precision=fsdp_precision,
413
devices=4
414
)
415
```
416
417
### TPU BFloat16 Training
418
419
```python
420
from lightning.fabric.plugins.precision import XLAPrecision
421
422
# TPU with bfloat16 precision
423
precision_plugin = XLAPrecision(precision="bf16-mixed")
424
425
fabric = Fabric(
426
accelerator="tpu",
427
strategy="xla",
428
precision=precision_plugin,
429
devices=8
430
)
431
```
432
433
### Manual Precision Control
434
435
```python
436
# Manual autocast usage
437
fabric = Fabric(precision="16-mixed")
438
439
model, optimizer = fabric.setup(model, optimizer)
440
441
for batch in dataloader:
442
optimizer.zero_grad()
443
444
# Manual autocast context
445
with fabric.autocast():
446
predictions = model(batch["input"])
447
loss = criterion(predictions, batch["target"])
448
449
fabric.backward(loss)
450
optimizer.step()
451
```
452
453
### Gradient Clipping with Precision
454
455
```python
456
# Gradient clipping with mixed precision
457
fabric = Fabric(precision="16-mixed")
458
459
model, optimizer = fabric.setup(model, optimizer)
460
461
for batch in dataloader:
462
optimizer.zero_grad()
463
464
with fabric.autocast():
465
loss = compute_loss(model, batch)
466
467
fabric.backward(loss)
468
469
# Clip gradients (handles unscaling automatically)
470
fabric.clip_gradients(model, optimizer, max_norm=1.0)
471
472
optimizer.step()
473
```
474
475
### Precision State Management
476
477
```python
478
# Save precision state in checkpoint
479
fabric = Fabric(precision="16-mixed")
480
481
# Add the precision plugin state to the checkpoint explicitly
482
state = {
483
"model": model,
484
"optimizer": optimizer,
485
"precision": fabric.precision_plugin.state_dict()
486
}
487
fabric.save("checkpoint.ckpt", state)
488
489
# Load precision state
490
loaded_state = fabric.load("checkpoint.ckpt")
491
fabric.precision_plugin.load_state_dict(loaded_state["precision"])
492
```