CtrlK
Blog · Docs · Log in · Get started
Tessl Logo

tessl/pypi-lightning-fabric

Fabric is the fast and lightweight way to scale PyTorch models without boilerplate

Pending
Overview
Eval results
Files

docs/precision.md

Precision

Precision plugins for mixed precision training, quantization, and memory optimization techniques.

Capabilities

Base Precision

Abstract base class defining the precision interface.

class Precision:
    """
    Abstract base class for precision plugins.

    Precision plugins handle numerical precision, mixed precision training,
    quantization, and memory optimization techniques. The methods below are
    the hooks a plugin exposes; concrete plugins override the ones they need.
    """

    def convert_module(self, module: nn.Module) -> nn.Module:
        """Convert module to target precision."""

    def convert_input(self, data: Any) -> Any:
        """Convert input data to target precision."""

    def convert_output(self, data: Any) -> Any:
        """Convert output data from target precision."""

    def pre_backward(self, tensor: Tensor, module: nn.Module) -> Tensor:
        """Pre-process tensor before backward pass."""

    def post_backward(self, tensor: Tensor, module: nn.Module) -> Tensor:
        """Post-process tensor after backward pass."""

    def forward_context(self) -> AbstractContextManager:
        """Context manager for forward pass precision."""

    def optimizer_step(
        self,
        optimizer: Optimizer,
        model: nn.Module,
        closure: Callable[[], Any],
        **kwargs: Any,
    ) -> Any:
        """
        Execute optimizer step with precision handling.

        Args:
            optimizer: Optimizer to step.
            model: Module whose parameters the optimizer updates.
            closure: Zero-argument callable that re-evaluates the model and
                returns the loss, in the form accepted by
                ``torch.optim.Optimizer.step``.
            **kwargs: Additional keyword arguments for the step.
        """

    def state_dict(self) -> dict[str, Any]:
        """Return the precision plugin's state as a dictionary (e.g. for checkpointing)."""

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """Restore precision plugin state from a dictionary produced by ``state_dict``."""

Double Precision

64-bit double precision for maximum numerical accuracy.

class DoublePrecision(Precision):
    """
    Full 64-bit (float64) precision plugin.

    Trades speed and memory for the highest floating-point accuracy, which
    suits research code that is sensitive to rounding error.
    """

    def convert_module(self, module: nn.Module) -> nn.Module:
        """Cast the module's parameters and buffers to float64."""

    def convert_input(self, data: Any) -> Any:
        """Cast incoming tensors to float64."""

    def forward_context(self) -> AbstractContextManager:
        """Return a context manager that keeps the forward pass in float64."""

Half Precision

16-bit half precision for memory efficiency.

class HalfPrecision(Precision):
    """
    16-bit "true" half precision plugin.

    Converts the model and its inputs to a 16-bit floating point dtype
    (float16 or bfloat16) for memory efficiency and faster training on
    supported hardware.
    """

    def __init__(self, precision: str = "16-true"):
        """
        Initialize half precision plugin.

        Args:
            precision: "16-true" for float16 or "bf16-true" for bfloat16.
        """

    def convert_module(self, module: nn.Module) -> nn.Module:
        """Convert module parameters and buffers to the 16-bit dtype selected by ``precision``."""

    def convert_input(self, data: Any) -> Any:
        """Convert input tensors to the 16-bit dtype selected by ``precision``."""

    def forward_context(self) -> AbstractContextManager:
        """Context manager for half precision forward pass."""

Mixed Precision (AMP)

Automatic Mixed Precision using PyTorch's native AMP implementation.

class MixedPrecision(Precision):
    """
    Automatic Mixed Precision plugin using PyTorch AMP.

    Runs eligible operations in a 16-bit dtype under autocast while the
    weights stay in float32. With "16-mixed" (float16), a gradient scaler
    applies loss scaling to avoid float16 gradient underflow. With
    "bf16-mixed" (bfloat16), no loss scaling is needed.
    """

    def __init__(
        self,
        precision: Union[str, int] = "16-mixed",
        device: str = "cuda",
        scaler: Optional[torch.cuda.amp.GradScaler] = None
    ):
        """
        Initialize mixed precision plugin.

        Args:
            precision: Precision mode ("16-mixed", "bf16-mixed")
            device: Target device type used for autocast ("cuda", "cpu")
            scaler: Custom gradient scaler instance. Only used with
                "16-mixed"; when None, a default scaler is created for that
                mode. Leave it as None for "bf16-mixed", which does not use
                a scaler.
        """

    def setup_scaler(self) -> torch.cuda.amp.GradScaler:
        """Setup gradient scaler for loss scaling."""

    def forward_context(self) -> AbstractContextManager:
        """Autocast context manager for mixed precision forward pass."""

    def optimizer_step(
        self,
        optimizer: Optimizer,
        model: nn.Module,
        closure: Callable[[], Any],
        **kwargs: Any,
    ) -> Any:
        """
        Optimizer step with gradient scaling and unscaling.

        Args:
            optimizer: Optimizer to step.
            model: Module whose parameters the optimizer updates.
            closure: Zero-argument callable that re-evaluates the model and
                returns the loss.
            **kwargs: Additional keyword arguments for the step.
        """

    def pre_backward(self, tensor: Tensor, module: nn.Module) -> Tensor:
        """Scale loss before backward pass (when a scaler is in use)."""

    def post_backward(self, tensor: Tensor, module: nn.Module) -> Tensor:
        """Handle gradient unscaling after backward pass."""

BitsAndBytes Precision

Quantization using the bitsandbytes library for memory-efficient training.

class BitsandbytesPrecision(Precision):
    """
    BitsAndBytes precision plugin for quantized training.

    Uses the bitsandbytes library for 8-bit and 4-bit quantization of
    linear layers, reducing memory usage when training large models.
    """

    def __init__(
        self,
        mode: Union[str, BitsAndBytesConfig],
        dtype: Optional[torch.dtype] = None,
        ignore_modules: Optional[set[str]] = None
    ):
        """
        Initialize BitsAndBytes precision plugin.

        Args:
            mode: Quantization mode. 4-bit: "nf4", "fp4", or their
                double-quantized variants "nf4-dq", "fp4-dq". 8-bit: "int8"
                or "int8-training". NOTE(review): recent lightning-fabric
                releases accept only these mode strings; confirm that a
                config object is supported by the installed version.
            dtype: Compute dtype for quantized weights
            ignore_modules: Set of module names to leave unquantized
                (e.g. the output head)
        """

    def convert_module(self, module: nn.Module) -> nn.Module:
        """Convert module to use quantized weights."""

    def setup_bnb_config(self) -> BitsAndBytesConfig:
        """Setup BitsAndBytes configuration."""

DeepSpeed Precision

Precision plugin integrated with DeepSpeed for large-scale training.

class DeepSpeedPrecision(Precision):
    """
    DeepSpeed precision plugin.

    Handles precision in conjunction with the DeepSpeed strategy for
    large-scale model training with ZeRO optimizations.
    """

    def __init__(
        self,
        precision: Union[str, int] = "16-mixed",
        amp_type: str = "native",
        amp_level: Optional[str] = None
    ):
        """
        Initialize DeepSpeed precision plugin.

        Args:
            precision: Precision mode: "32-true", "16-true", "bf16-true",
                "16-mixed" or "bf16-mixed".
            amp_type: AMP implementation ("native", "apex").
            amp_level: APEX AMP level if using APEX.

        NOTE(review): APEX support was removed in Lightning 2.0, and recent
        releases accept only ``precision``; confirm ``amp_type``/``amp_level``
        against the installed version before relying on them.
        """

    def convert_module(self, module: nn.Module) -> nn.Module:
        """Convert module for DeepSpeed precision handling."""

    def forward_context(self) -> AbstractContextManager:
        """Context manager for DeepSpeed precision forward pass."""

FSDP Precision

Precision plugin optimized for Fully Sharded Data Parallel training.

class FSDPPrecision(Precision):
    """
    FSDP precision plugin.

    Handles precision in conjunction with the FSDP strategy, managing
    parameter, gradient-reduction and buffer precision for sharded training.
    """

    def __init__(
        self,
        precision: Union[str, int] = "32-true",
        scaler: Optional[torch.cuda.amp.GradScaler] = None
    ):
        """
        Initialize FSDP precision plugin.

        Args:
            precision: Precision mode: "32-true", "16-true", "bf16-true",
                "16-mixed" or "bf16-mixed".
            scaler: Custom gradient scaler; only relevant for "16-mixed".
        """

    def convert_module(self, module: nn.Module) -> nn.Module:
        """Convert module for FSDP precision handling."""

    def setup_mixed_precision_config(self) -> Optional["torch.distributed.fsdp.MixedPrecision"]:
        """
        Setup FSDP mixed precision configuration.

        Returns PyTorch's ``torch.distributed.fsdp.MixedPrecision`` config
        object, not the ``MixedPrecision`` AMP plugin documented on this page.
        """

XLA Precision

Precision plugin for XLA/TPU training.

class XLAPrecision(Precision):
    """
    XLA precision plugin for TPU training.

    Handles precision for XLA-compiled models running on TPUs, with
    support for float32, float16 and bfloat16 "true" precision. XLA
    precision does not use autocast-style mixed precision.
    """

    def __init__(self, precision: Union[str, int] = "32-true"):
        """
        Initialize XLA precision plugin.

        Args:
            precision: Precision mode: "32-true", "16-true" or "bf16-true".
                Mixed-precision strings such as "bf16-mixed" are rejected
                with a ``ValueError``.
        """

    def convert_module(self, module: nn.Module) -> nn.Module:
        """Convert module for XLA precision handling."""

    def forward_context(self) -> AbstractContextManager:
        """Context manager for XLA precision forward pass."""

Transformer Engine Precision

NVIDIA Transformer Engine precision for optimized transformer training.

class TransformerEnginePrecision(Precision):
    """
    Transformer Engine precision plugin.

    Uses NVIDIA Transformer Engine for optimized transformer model training
    with FP8 precision on supported hardware.

    NOTE(review): recent lightning-fabric releases construct this plugin with
    keyword-only ``weights_dtype``, ``recipe``, ``replace_layers`` and
    ``fallback_compute_dtype`` arguments, and select it through
    ``Fabric(precision="transformer-engine")``. Confirm the signature below
    against the installed version.
    """

    def __init__(
        self,
        precision: Union[str, int] = "16-mixed",
        replace_layers: bool = True,
        fp8_format: str = "hybrid"
    ):
        """
        Initialize Transformer Engine precision plugin.

        Args:
            precision: Base precision mode
            replace_layers: Whether to replace standard layers (e.g.
                ``nn.Linear``) with their Transformer Engine counterparts
            fp8_format: FP8 format ("e4m3", "e5m2", "hybrid")
        """

    def convert_module(self, module: nn.Module) -> nn.Module:
        """Convert transformer layers to Transformer Engine layers."""

    def setup_fp8_recipe(self) -> DelayedScaling:
        """Setup FP8 recipe for Transformer Engine."""

Usage Examples

Basic Mixed Precision

from lightning.fabric import Fabric

# float16 automatic mixed precision (autocast plus gradient/loss scaling)
fabric = Fabric(accelerator="gpu", precision="16-mixed")

# bfloat16 mixed precision: wider exponent range than float16, so it is
# numerically more robust and needs no loss scaling
fabric = Fabric(accelerator="gpu", precision="bf16-mixed")

Custom AMP Configuration

import torch

from lightning.fabric import Fabric
from lightning.fabric.plugins.precision import MixedPrecision

# Custom gradient scaler. The values below are PyTorch's defaults; tune them
# as needed. A scaler only applies to "16-mixed" ("bf16-mixed" uses none).
scaler = torch.cuda.amp.GradScaler(
    init_scale=2**16,
    growth_factor=2.0,
    backoff_factor=0.5,
    growth_interval=2000,
)

precision_plugin = MixedPrecision(
    precision="16-mixed",
    device="cuda",
    scaler=scaler,
)

fabric = Fabric(
    precision=precision_plugin,
    accelerator="gpu",
)

BitsAndBytes Quantization

import torch

from lightning.fabric import Fabric
from lightning.fabric.plugins.precision import BitsandbytesPrecision

# 8-bit quantization
precision_plugin = BitsandbytesPrecision(mode="int8")

# ...or 4-bit NormalFloat quantization with bfloat16 compute, leaving the
# output head and token embeddings unquantized
precision_plugin = BitsandbytesPrecision(
    mode="nf4",
    dtype=torch.bfloat16,
    ignore_modules={"lm_head", "embed_tokens"},
)

fabric = Fabric(
    precision=precision_plugin,
    accelerator="gpu",
)

DeepSpeed Precision Integration

from lightning.fabric import Fabric
from lightning.fabric.plugins.precision import DeepSpeedPrecision
from lightning.fabric.strategies import DeepSpeedStrategy

# DeepSpeed ZeRO stage 2 with float16 mixed precision
precision_plugin = DeepSpeedPrecision(precision="16-mixed")
strategy = DeepSpeedStrategy(stage=2)

fabric = Fabric(
    strategy=strategy,
    precision=precision_plugin,
    devices=8,
)

FSDP with Mixed Precision

import torch
from torch.distributed.fsdp import MixedPrecision as FSDPMixedPrecision

from lightning.fabric import Fabric
from lightning.fabric.plugins.precision import FSDPPrecision
from lightning.fabric.strategies import FSDPStrategy

# FSDP mixed precision configuration. PyTorch's FSDP MixedPrecision config is
# imported under an alias so it is not confused with Fabric's AMP plugin.
fsdp_precision = FSDPPrecision(precision="bf16-mixed")
fsdp_strategy = FSDPStrategy(
    mixed_precision=FSDPMixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
    )
)

fabric = Fabric(
    strategy=fsdp_strategy,
    precision=fsdp_precision,
    devices=4,
)

TPU BFloat16 Training

from lightning.fabric import Fabric
from lightning.fabric.plugins.precision import XLAPrecision

# TPU with bfloat16 precision. XLAPrecision accepts "32-true", "16-true" and
# "bf16-true"; mixed-precision strings such as "bf16-mixed" are rejected.
precision_plugin = XLAPrecision(precision="bf16-true")

fabric = Fabric(
    accelerator="tpu",
    strategy="xla",
    precision=precision_plugin,
    devices=8,
)

Manual Precision Control

# Manual autocast usage
# (`model`, `optimizer`, `dataloader` and `criterion` are assumed to be defined.)
from lightning.fabric import Fabric

fabric = Fabric(precision="16-mixed")

# After `fabric.setup`, the model's forward pass already runs under the
# precision plugin's autocast context.
model, optimizer = fabric.setup(model, optimizer)
# Move batches to the correct device automatically.
dataloader = fabric.setup_dataloaders(dataloader)

for batch in dataloader:
    optimizer.zero_grad()

    predictions = model(batch["input"])

    # Use `fabric.autocast()` for computations outside the model's forward,
    # such as the loss, so they also run in mixed precision.
    with fabric.autocast():
        loss = criterion(predictions, batch["target"])

    fabric.backward(loss)
    optimizer.step()

Gradient Clipping with Precision

# Gradient clipping with mixed precision
# (`model`, `optimizer`, `dataloader` and `compute_loss` are assumed to be defined.)
from lightning.fabric import Fabric

fabric = Fabric(precision="16-mixed")

model, optimizer = fabric.setup(model, optimizer)
dataloader = fabric.setup_dataloaders(dataloader)

for batch in dataloader:
    optimizer.zero_grad()

    with fabric.autocast():
        loss = compute_loss(model, batch)

    fabric.backward(loss)

    # Clip gradients. Fabric unscales AMP-scaled gradients before clipping,
    # so `max_norm` refers to the true gradient norm.
    fabric.clip_gradients(model, optimizer, max_norm=1.0)

    optimizer.step()

Precision State Management

# Save precision state in checkpoint
# (`model` and `optimizer` are assumed to have been set up with this Fabric.)
from lightning.fabric import Fabric

fabric = Fabric(precision="16-mixed")

# Store the precision plugin's state (e.g. the AMP gradient scaler, if one is
# used) explicitly, next to the model and optimizer.
state = {
    "model": model,
    "optimizer": optimizer,
    "precision": fabric.precision_plugin.state_dict(),
}
fabric.save("checkpoint.ckpt", state)

# Load: passing `state` restores the model and optimizer in place and replaces
# plain values such as state["precision"] with the loaded ones.
fabric.load("checkpoint.ckpt", state)
fabric.precision_plugin.load_state_dict(state["precision"])

Install with Tessl CLI

npx tessl i tessl/pypi-lightning-fabric

docs

accelerators.md

core-training.md

distributed.md

index.md

precision.md

strategies.md

utilities.md

tile.json