The Deep Learning framework to train, deploy, and ship AI products Lightning fast.
—
Precision plugins for mixed precision training, quantization, and various floating-point formats to optimize memory usage and training speed while maintaining model quality.
Automatic mixed precision training using 16-bit floats for forward pass and 32-bit for loss computation.
class MixedPrecision:
    """Automatic mixed precision training using 16-bit floats for the
    forward pass and 32-bit for loss computation."""

    # Supported precision modes, per the documented contract below.
    _SUPPORTED = ("16-mixed", "bf16-mixed")

    def __init__(self, precision: str = "16-mixed", device: str = "cuda"):
        """
        Initialize mixed precision plugin.

        Args:
            precision: Precision mode ('16-mixed', 'bf16-mixed')
            device: Target device

        Raises:
            ValueError: If ``precision`` is not one of the supported modes.
        """
        if precision not in self._SUPPORTED:
            raise ValueError(
                f"Unsupported precision {precision!r}; expected one of {self._SUPPORTED}"
            )
        # Keep the configuration readable by callers instead of discarding it.
        self.precision = precision
        self.device = device


# 16-bit floating point training for memory efficiency.
class HalfPrecision:
    """16-bit floating point training for memory efficiency."""

    def __init__(self):
        """Initialize half precision plugin."""


# 64-bit floating point training for maximum numerical precision.
class DoublePrecision:
    """64-bit floating point training for maximum numerical precision."""

    def __init__(self):
        """Initialize double precision plugin."""


# 8-bit and 4-bit quantization using BitsAndBytes for memory-efficient
# training of large models.
class BitsandbytesPrecision:
    """8-bit and 4-bit quantization using BitsAndBytes for memory-efficient
    training of large models."""

    # Quantization modes this plugin accepts, per the documented contract.
    _SUPPORTED_MODES = ("int8", "int4", "nf4", "fp4")

    def __init__(
        self,
        mode: str = "int8",
        dtype: Optional[torch.dtype] = None,
        ignore_modules: Optional[Set[str]] = None,
    ):
        """
        Initialize BitsAndBytes precision plugin.

        Args:
            mode: Quantization mode ('int8', 'int4', 'nf4', 'fp4')
            dtype: Data type for computation
            ignore_modules: Modules to skip quantization

        Raises:
            ValueError: If ``mode`` is not a supported quantization mode.
        """
        if mode not in self._SUPPORTED_MODES:
            raise ValueError(
                f"Unsupported mode {mode!r}; expected one of {self._SUPPORTED_MODES}"
            )
        self.mode = mode
        self.dtype = dtype
        # Copy into a fresh set so a caller-held set can't be mutated through
        # this object, and so the attribute is always a set (never None).
        self.ignore_modules = set(ignore_modules) if ignore_modules else set()


# Precision plugin for DeepSpeed optimization with ZeRO memory optimization.
class DeepSpeedPrecision:
    """Precision plugin for DeepSpeed optimization with ZeRO memory optimization."""

    def __init__(self):
        """Initialize DeepSpeed precision plugin."""


# Precision plugin optimized for Fully Sharded Data Parallel training.
class FSDPPrecision:
    """Precision plugin optimized for Fully Sharded Data Parallel training."""

    def __init__(self):
        """Initialize FSDP precision plugin."""


# NVIDIA Transformer Engine precision for optimized transformer training.
class TransformerEnginePrecision:
    """NVIDIA Transformer Engine precision for optimized transformer training."""

    def __init__(
        self,
        weights_dtype: torch.dtype = torch.float32,
        recipe: Optional[Dict[str, Any]] = None,
    ):
        """
        Initialize Transformer Engine precision plugin.

        Args:
            weights_dtype: Data type for model weights
            recipe: Transformer Engine recipe configuration
                (``None`` means use the engine's defaults)
        """
        # Keep the configuration readable by callers instead of discarding it.
        self.weights_dtype = weights_dtype
        self.recipe = recipe


# Precision plugin for TPU training with XLA compilation.
class XLAPrecision:
    """Precision plugin for TPU training with XLA compilation."""

    def __init__(self):
        """Initialize XLA precision plugin for TPU training."""


# Base class for implementing custom precision plugins.
class Precision:
    """Base class for implementing custom precision plugins."""

    def __init__(self):
        """Initialize base precision plugin."""

    def convert_module(self, module: nn.Module) -> nn.Module:
        """Convert module for precision.

        The base implementation is a no-op returning the module unchanged,
        which satisfies the annotated return type; subclasses override.
        """
        return module

    def convert_optimizer(self, optimizer: Optimizer) -> Optimizer:
        """Convert optimizer for precision.

        The base implementation returns the optimizer unchanged; subclasses
        override to wrap or rebuild it.
        """
        return optimizer

    def backward(self, tensor: Tensor, model: nn.Module) -> None:
        """Perform backward pass with precision handling.

        The base implementation does nothing; subclasses implement the
        actual (possibly scaled) backward.
        """


# Install with Tessl CLI
`npx tessl i tessl/pypi-lightning`