Fabric is the fast and lightweight way to scale PyTorch models without boilerplate.

Precision plugins provide mixed precision training, quantization, and memory
optimization techniques. The abstract base class below defines the precision
interface that all concrete plugins implement.
class Precision:
    """
    Abstract base class for precision plugins.

    Precision plugins handle numerical precision, mixed precision training,
    quantization, and memory optimization techniques. The hooks below are
    no-op stubs (they implicitly return None); concrete subclasses override
    the ones relevant to their precision scheme.
    """

    def convert_module(self, module: nn.Module) -> nn.Module:
        """Convert module to target precision."""

    def convert_input(self, data: Any) -> Any:
        """Convert input data to target precision."""

    def convert_output(self, data: Any) -> Any:
        """Convert output data from target precision."""

    def pre_backward(self, tensor: Tensor, module: nn.Module) -> Tensor:
        """Pre-process tensor before backward pass."""

    def post_backward(self, tensor: Tensor, module: nn.Module) -> Tensor:
        """Post-process tensor after backward pass."""

    def forward_context(self) -> AbstractContextManager:
        """Context manager for forward pass precision."""

    def optimizer_step(
        self,
        optimizer: Optimizer,
        model: nn.Module,
        closure: callable,
        **kwargs
    ) -> Any:
        """Execute optimizer step with precision handling."""

    def state_dict(self) -> dict[str, Any]:
        """Get precision plugin state."""

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """Load precision plugin state."""


# 64-bit double precision for maximum numerical accuracy.
class DoublePrecision(Precision):
    """
    64-bit double precision plugin.

    Provides maximum numerical precision using 64-bit floating point
    arithmetic. Useful for research requiring high precision.
    """

    def convert_module(self, module: nn.Module) -> nn.Module:
        """Convert module parameters and buffers to float64."""

    def convert_input(self, data: Any) -> Any:
        """Convert input tensors to float64."""

    def forward_context(self) -> AbstractContextManager:
        """Context manager ensuring double precision during forward pass."""


# 16-bit half precision for memory efficiency.
class HalfPrecision(Precision):
    """
    16-bit half precision plugin.

    Uses 16-bit floating point (float16) for memory efficiency
    and faster training on supported hardware.
    """

    def convert_module(self, module: nn.Module) -> nn.Module:
        """Convert module parameters and buffers to float16."""

    def convert_input(self, data: Any) -> Any:
        """Convert input tensors to float16."""

    def forward_context(self) -> AbstractContextManager:
        """Context manager for half precision forward pass."""


# Automatic Mixed Precision using PyTorch's native AMP implementation.
class MixedPrecision(Precision):
    """
    Automatic Mixed Precision plugin using PyTorch AMP.

    Combines float16 precision for speed with float32 precision
    for numerical stability using automatic loss scaling.
    """

    def __init__(
        self,
        precision: Union[str, int] = "16-mixed",
        device: str = "cuda",
        scaler: Optional[torch.cuda.amp.GradScaler] = None
    ):
        """
        Initialize mixed precision plugin.

        Args:
            precision: Precision mode ("16-mixed", "bf16-mixed")
            device: Target device ("cuda", "cpu")
            scaler: Custom gradient scaler instance
        """

    def setup_scaler(self) -> torch.cuda.amp.GradScaler:
        """Setup gradient scaler for loss scaling."""

    def forward_context(self) -> AbstractContextManager:
        """Autocast context manager for mixed precision forward pass."""

    def optimizer_step(
        self,
        optimizer: Optimizer,
        model: nn.Module,
        closure: callable,
        **kwargs
    ) -> Any:
        """Optimizer step with gradient scaling and unscaling."""

    def pre_backward(self, tensor: Tensor, module: nn.Module) -> Tensor:
        """Scale loss before backward pass."""

    def post_backward(self, tensor: Tensor, module: nn.Module) -> Tensor:
        """Handle gradient unscaling after backward pass."""


# Quantization using BitsAndBytes library for memory-efficient training.
class BitsandbytesPrecision(Precision):
    """
    BitsAndBytes precision plugin for quantized training.

    Uses BitsAndBytes library for 8-bit and 4-bit quantization
    to reduce memory usage for large model training.
    """

    def __init__(
        self,
        mode: Union[str, BitsAndBytesConfig],
        dtype: Optional[torch.dtype] = None,
        ignore_modules: Optional[set[str]] = None
    ):
        """
        Initialize BitsAndBytes precision plugin.

        Args:
            mode: Quantization mode ("nf4", "fp4", "int8") or config object
            dtype: Compute dtype for quantized weights
            ignore_modules: Set of module names to skip quantization
        """

    def convert_module(self, module: nn.Module) -> nn.Module:
        """Convert module to use quantized weights."""

    def setup_bnb_config(self) -> BitsAndBytesConfig:
        """Setup BitsAndBytes configuration."""


# Precision plugin integrated with DeepSpeed for large-scale training.
class DeepSpeedPrecision(Precision):
    """
    DeepSpeed precision plugin.

    Handles precision in conjunction with DeepSpeed strategy
    for large-scale model training with ZeRO optimizations.
    """

    def __init__(
        self,
        precision: Union[str, int] = "16-mixed",
        amp_type: str = "native",
        amp_level: Optional[str] = None
    ):
        """
        Initialize DeepSpeed precision plugin.

        Args:
            precision: Precision mode
            amp_type: AMP implementation ("native", "apex")
            amp_level: APEX AMP level if using APEX
        """

    def convert_module(self, module: nn.Module) -> nn.Module:
        """Convert module for DeepSpeed precision handling."""

    def forward_context(self) -> AbstractContextManager:
        """Context manager for DeepSpeed precision forward pass."""


# Precision plugin optimized for Fully Sharded Data Parallel training.
class FSDPPrecision(Precision):
    """
    FSDP precision plugin.

    Handles precision in conjunction with FSDP strategy,
    managing parameter and gradient precision for sharded training.
    """

    def __init__(
        self,
        precision: Union[str, int] = "32-true",
        scaler: Optional[torch.cuda.amp.GradScaler] = None
    ):
        """
        Initialize FSDP precision plugin.

        Args:
            precision: Precision mode
            scaler: Custom gradient scaler
        """

    def convert_module(self, module: nn.Module) -> nn.Module:
        """Convert module for FSDP precision handling."""

    def setup_mixed_precision_config(self) -> Optional[MixedPrecision]:
        """Setup FSDP mixed precision configuration."""


# Precision plugin for XLA/TPU training.
class XLAPrecision(Precision):
    """
    XLA precision plugin for TPU training.

    Handles precision for XLA-compiled models running on TPUs,
    with support for bfloat16 and float32 precision.
    """

    def __init__(self, precision: Union[str, int] = "32-true"):
        """
        Initialize XLA precision plugin.

        Args:
            precision: Precision mode ("32-true", "bf16-mixed")
        """

    def convert_module(self, module: nn.Module) -> nn.Module:
        """Convert module for XLA precision handling."""

    def forward_context(self) -> AbstractContextManager:
        """Context manager for XLA precision forward pass."""


# NVIDIA Transformer Engine precision for optimized transformer training.
class TransformerEnginePrecision(Precision):
    """
    Transformer Engine precision plugin.

    Uses NVIDIA Transformer Engine for optimized transformer
    model training with FP8 precision on supported hardware.
    """

    def __init__(
        self,
        precision: Union[str, int] = "16-mixed",
        replace_layers: bool = True,
        fp8_format: str = "hybrid"
    ):
        """
        Initialize Transformer Engine precision plugin.

        Args:
            precision: Base precision mode
            replace_layers: Whether to replace standard layers with TE layers
            fp8_format: FP8 format ("e4m3", "e5m2", "hybrid")
        """

    def convert_module(self, module: nn.Module) -> nn.Module:
        """Convert transformer layers to Transformer Engine layers."""

    def setup_fp8_recipe(self) -> DelayedScaling:
        """Setup FP8 recipe for Transformer Engine."""


# --- Usage examples -------------------------------------------------------
from lightning.fabric import Fabric
# NOTE(review): these are documentation examples; names such as `model`,
# `optimizer`, `dataloader`, `criterion`, and `compute_loss` are assumed to
# be defined by the reader's own training code.
from lightning.fabric import Fabric

# Automatic mixed precision with 16-bit
fabric = Fabric(precision="16-mixed", accelerator="gpu")

# BFloat16 mixed precision (better numerical stability)
fabric = Fabric(precision="bf16-mixed", accelerator="gpu")

# Custom gradient scaler with an explicit MixedPrecision plugin
from lightning.fabric.plugins.precision import MixedPrecision
import torch

scaler = torch.cuda.amp.GradScaler(
    init_scale=2**16,
    growth_factor=2.0,
    backoff_factor=0.5,
    growth_interval=2000,
)
precision_plugin = MixedPrecision(
    precision="16-mixed",
    device="cuda",
    scaler=scaler,
)
fabric = Fabric(
    precision=precision_plugin,
    accelerator="gpu",
)

# Quantized training with BitsAndBytes
from lightning.fabric.plugins.precision import BitsandbytesPrecision

# 8-bit quantization
precision_plugin = BitsandbytesPrecision(mode="int8")

# 4-bit NormalFloat quantization
precision_plugin = BitsandbytesPrecision(
    mode="nf4",
    dtype=torch.bfloat16,
    ignore_modules={"lm_head", "embed_tokens"},
)
fabric = Fabric(
    precision=precision_plugin,
    accelerator="gpu",
)

# DeepSpeed with mixed precision
from lightning.fabric.plugins.precision import DeepSpeedPrecision
from lightning.fabric.strategies import DeepSpeedStrategy

precision_plugin = DeepSpeedPrecision(precision="16-mixed")
strategy = DeepSpeedStrategy(stage=2)
fabric = Fabric(
    strategy=strategy,
    precision=precision_plugin,
    devices=8,
)

# FSDP mixed precision configuration
from lightning.fabric.plugins.precision import FSDPPrecision
from lightning.fabric.strategies import FSDPStrategy
from torch.distributed.fsdp import MixedPrecision as FSDPMixedPrecision

fsdp_precision = FSDPPrecision(precision="bf16-mixed")
fsdp_strategy = FSDPStrategy(
    mixed_precision=FSDPMixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
    )
)
fabric = Fabric(
    strategy=fsdp_strategy,
    precision=fsdp_precision,
    devices=4,
)

# TPU with bfloat16 precision
from lightning.fabric.plugins.precision import XLAPrecision

precision_plugin = XLAPrecision(precision="bf16-mixed")
fabric = Fabric(
    accelerator="tpu",
    strategy="xla",
    precision=precision_plugin,
    devices=8,
)

# Manual autocast usage
fabric = Fabric(precision="16-mixed")
model, optimizer = fabric.setup(model, optimizer)
for batch in dataloader:
    optimizer.zero_grad()
    # Manual autocast context
    with fabric.autocast():
        predictions = model(batch["input"])
        loss = criterion(predictions, batch["target"])
    fabric.backward(loss)
    optimizer.step()

# Gradient clipping with mixed precision
fabric = Fabric(precision="16-mixed")
model, optimizer = fabric.setup(model, optimizer)
for batch in dataloader:
    optimizer.zero_grad()
    with fabric.autocast():
        loss = compute_loss(model, batch)
    fabric.backward(loss)
    # Clip gradients (handles unscaling automatically)
    fabric.clip_gradients(model, optimizer, max_norm=1.0)
    optimizer.step()

# Save precision state in checkpoint
fabric = Fabric(precision="16-mixed")
# Precision state is automatically included in Fabric checkpoints
state = {
    "model": model,
    "optimizer": optimizer,
    "precision": fabric.precision_plugin.state_dict(),
}
fabric.save("checkpoint.ckpt", state)
# Load precision state
loaded_state = fabric.load("checkpoint.ckpt")
fabric.precision_plugin.load_state_dict(loaded_state["precision"])

# Install with Tessl CLI:
#   npx tessl i tessl/pypi-lightning-fabric