The Deep Learning framework to train, deploy, and ship AI products Lightning fast.
```bash
npx @tessl/cli install tessl/pypi-lightning@2.5.0
```

Lightning provides a unified interface combining PyTorch Lightning (for high-level model training) with Lightning Fabric (for expert-level control) and data utilities, enabling researchers and practitioners to build production-ready deep learning applications at scale.

```bash
pip install lightning
```

```python
import lightning as L
```

Main framework components:

```python
from lightning import Trainer, LightningModule, LightningDataModule, Callback
```

Lightweight acceleration:

```python
from lightning import Fabric
```

Utilities:

```python
from lightning import seed_everything
from lightning.pytorch.utilities.warnings import disable_possible_user_warnings
```
A minimal end-to-end example, training a classifier on MNIST:

```python
import lightning as L
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST
# Define a Lightning Module
class LitModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(28 * 28, 128)
        self.layer_2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = torch.relu(self.layer_1(x))
        x = self.layer_2(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())
# Define a Data Module
class MNISTDataModule(L.LightningDataModule):
    def __init__(self, data_dir: str = './'):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])

    def prepare_data(self):
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage: str):
        if stage == "fit":
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
        if stage == "test":
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=32)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=32)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=32)
# Train the model
if __name__ == "__main__":
    model = LitModel()
    datamodule = MNISTDataModule()
    trainer = L.Trainer(max_epochs=10)
    trainer.fit(model, datamodule)
```
Lightning provides a layered architecture designed for maximum flexibility and production readiness: PyTorch Lightning at the highest level for organized, boilerplate-free model training; Lightning Fabric below it for expert-level control over the training loop, devices, and distributed strategies; and data utilities for building scalable input pipelines. This design enables seamless transitions from research prototyping to production deployment while maintaining code reusability and scalability.
Essential components for structuring deep learning training: the Trainer orchestrator, LightningModule for model definition, LightningDataModule for data handling, and Callback system for training lifecycle hooks.
```python
class Trainer:
    def __init__(self, **kwargs): ...
    def fit(self, model, datamodule=None, train_dataloaders=None, val_dataloaders=None, **kwargs): ...
    def test(self, model=None, dataloaders=None, **kwargs): ...
    def predict(self, model=None, dataloaders=None, **kwargs): ...

class LightningModule:
    def __init__(self): ...
    def forward(self, *args, **kwargs): ...
    def training_step(self, batch, batch_idx): ...
    def validation_step(self, batch, batch_idx): ...
    def test_step(self, batch, batch_idx): ...
    def configure_optimizers(self): ...

class LightningDataModule:
    def __init__(self): ...
    def prepare_data(self): ...
    def setup(self, stage: str): ...
    def train_dataloader(self): ...
    def val_dataloader(self): ...
    def test_dataloader(self): ...

class Callback:
    def on_train_start(self, trainer, pl_module): ...
    def on_train_end(self, trainer, pl_module): ...
    def on_train_epoch_start(self, trainer, pl_module): ...
    def on_train_epoch_end(self, trainer, pl_module): ...
```
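The quickstart above omits validation; the sketch below (an illustration building on the MNIST example, with `val_loss` as an arbitrary metric name) shows how a `validation_step` logs metrics through `self.log` so callbacks and loggers can react to them:

```python
import torch.nn.functional as F

class LitModelWithVal(LitModel):  # extends the quickstart model above
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        # self.log makes the value visible to loggers and to callbacks such as
        # EarlyStopping / ModelCheckpoint that monitor "val_loss".
        self.log("val_loss", loss, prog_bar=True)
        return loss
```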
Lightweight training acceleration framework providing expert-level control over training loops, device management, and distributed strategies without high-level abstractions.

```python
class Fabric:
    def __init__(self, **kwargs): ...
    def setup(self, model, *optimizers): ...
    def setup_dataloaders(self, *dataloaders): ...
    def backward(self, tensor): ...
    def all_gather(self, tensor): ...
    def broadcast(self, tensor): ...

def seed_everything(seed: int): ...
def is_wrapped(obj): ...
```
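A minimal Fabric training loop might look like the following sketch; the model, dataset, and `accelerator="auto"` / `devices=1` arguments are placeholders chosen for illustration:

```python
import lightning as L
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

fabric = L.Fabric(accelerator="auto", devices=1)
fabric.launch()

model = nn.Linear(32, 2)  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
dataset = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
dataloader = DataLoader(dataset, batch_size=8)

# Fabric wraps the model/optimizer and moves batches to the right device.
model, optimizer = fabric.setup(model, optimizer)
dataloader = fabric.setup_dataloaders(dataloader)

for x, y in dataloader:
    optimizer.zero_grad()
    loss = nn.functional.cross_entropy(model(x), y)
    fabric.backward(loss)  # replaces loss.backward()
    optimizer.step()
```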
Comprehensive callback system for training lifecycle management including checkpointing, early stopping, learning rate scheduling, monitoring, and optimization callbacks.

```python
class ModelCheckpoint(Callback):
    def __init__(self, dirpath=None, filename=None, monitor=None, **kwargs): ...

class EarlyStopping(Callback):
    def __init__(self, monitor, patience=3, **kwargs): ...

class LearningRateMonitor(Callback):
    def __init__(self, logging_interval='epoch'): ...

class StochasticWeightAveraging(Callback):
    def __init__(self, swa_lrs=None, **kwargs): ...
```
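Callbacks are passed to the Trainer; this sketch assumes a `val_loss` metric is logged in `validation_step` (as in the example above), and the checkpoint directory is an arbitrary choice:

```python
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping

checkpoint_cb = ModelCheckpoint(
    dirpath="checkpoints/",              # arbitrary output directory
    filename="{epoch}-{val_loss:.2f}",
    monitor="val_loss",                  # metric logged via self.log(...)
    mode="min",
)
early_stop_cb = EarlyStopping(monitor="val_loss", patience=3, mode="min")

trainer = L.Trainer(max_epochs=20, callbacks=[checkpoint_cb, early_stop_cb])
```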
Multiple strategies for distributed and parallel training including data parallel, distributed data parallel, fully sharded data parallel, model parallel, and specialized strategies for different hardware.

```python
class DDPStrategy:
    def __init__(self, **kwargs): ...

class FSDPStrategy:
    def __init__(self, **kwargs): ...

class DeepSpeedStrategy:
    def __init__(self, **kwargs): ...

class DataParallelStrategy:
    def __init__(self): ...
```
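Strategies are typically selected through the Trainer, either by name or as an explicit object for finer control; a sketch (the device count is illustrative):

```python
import lightning as L
from lightning.pytorch.strategies import DDPStrategy

# Shorthand string form:
trainer = L.Trainer(accelerator="gpu", devices=4, strategy="ddp")

# Explicit strategy object with extra keyword arguments:
trainer = L.Trainer(
    accelerator="gpu",
    devices=4,
    strategy=DDPStrategy(find_unused_parameters=False),
)
```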
Support for various hardware accelerators including CPU, CUDA GPUs, Apple Metal Performance Shaders, and Google TPUs with automatic device detection and optimization.
```python
class CPUAccelerator:
    def setup_device(self, device): ...

class CUDAAccelerator:
    def setup_device(self, device): ...

class MPSAccelerator:
    def setup_device(self, device): ...

class XLAAccelerator:
    def setup_device(self, device): ...

def find_usable_cuda_devices(num_devices: int = -1): ...
```
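Accelerators are usually chosen via Trainer arguments rather than instantiated directly; a sketch (the device count of 2 is illustrative):

```python
import lightning as L
from lightning.pytorch.accelerators import find_usable_cuda_devices

# Let Lightning pick whatever hardware is available (CPU, CUDA, MPS, TPU).
trainer = L.Trainer(accelerator="auto", devices="auto")

# Or explicitly request two free CUDA GPUs.
trainer = L.Trainer(accelerator="cuda", devices=find_usable_cuda_devices(2))
```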
Precision plugins for mixed precision training, quantization, and various floating-point formats to optimize memory usage and training speed while maintaining model quality.

```python
class MixedPrecision:
    def __init__(self, precision='16-mixed', **kwargs): ...

class HalfPrecision:
    def __init__(self): ...

class DoublePrecision:
    def __init__(self): ...

class BitsandbytesPrecision:
    def __init__(self, mode='int8', **kwargs): ...
```
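Precision is most often set with the Trainer's `precision` flag rather than by constructing a plugin directly; a sketch:

```python
import lightning as L

# bfloat16 mixed precision (a common default on recent GPUs/TPUs).
trainer = L.Trainer(precision="bf16-mixed")

# float16 mixed precision.
trainer = L.Trainer(precision="16-mixed")

# Full double precision for numerically sensitive workloads.
trainer = L.Trainer(precision="64-true")
```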
Integration with popular experiment tracking platforms and comprehensive logging capabilities for monitoring training progress, metrics, hyperparameters, and model artifacts.

```python
class TensorBoardLogger:
    def __init__(self, save_dir, **kwargs): ...

class WandbLogger:
    def __init__(self, project=None, **kwargs): ...

class MLFlowLogger:
    def __init__(self, experiment_name=None, **kwargs): ...

class CSVLogger:
    def __init__(self, save_dir, **kwargs): ...
```
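Loggers are attached to the Trainer; metrics logged with `self.log(...)` inside the LightningModule are forwarded to every attached logger. A sketch (directory and experiment names are arbitrary):

```python
import lightning as L
from lightning.pytorch.loggers import TensorBoardLogger, CSVLogger

tb_logger = TensorBoardLogger(save_dir="logs/", name="mnist")  # arbitrary path/name
csv_logger = CSVLogger(save_dir="logs/")

# A Trainer can write to several loggers at once.
trainer = L.Trainer(max_epochs=10, logger=[tb_logger, csv_logger])
```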
Profiling tools for analyzing training performance, identifying bottlenecks, and optimizing model training efficiency across different hardware configurations.

```python
class PyTorchProfiler:
    def __init__(self, **kwargs): ...

class AdvancedProfiler:
    def __init__(self, **kwargs): ...

class SimpleProfiler:
    def __init__(self): ...
```
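Profilers plug into the Trainer either by name or as an object; a sketch (output paths are arbitrary):

```python
import lightning as L
from lightning.pytorch.profilers import PyTorchProfiler

# Quick wall-clock summary of each training hook.
trainer = L.Trainer(profiler="simple")

# Detailed operator-level profiling backed by torch.profiler.
trainer = L.Trainer(profiler=PyTorchProfiler(dirpath="profiles/", filename="perf"))
```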
Data handling utilities including streaming datasets, combined data loaders, and data processing functions for efficient data pipeline management in large-scale training.

```python
class StreamingDataset:
    def __init__(self, **kwargs): ...

class CombinedStreamingDataset:
    def __init__(self, datasets, **kwargs): ...

def optimize(data_dir, **kwargs): ...
def map(function, inputs, **kwargs): ...
```
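A hedged sketch of streaming a pre-optimized dataset from local or cloud storage; the `lightning.data` import path and the `input_dir` argument are assumptions based on the 2.x data utilities and may differ in your installed version:

```python
from torch.utils.data import DataLoader
from lightning.data import StreamingDataset  # assumed import path

# Read pre-optimized chunks from a local directory or object-store URI (assumption).
dataset = StreamingDataset(input_dir="data/optimized")
loader = DataLoader(dataset, batch_size=32, num_workers=4)

for batch in loader:
    ...  # feed batches into Trainer/Fabric as usual
```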
Common utilities for training control and configuration.

```python
def seed_everything(seed: int, workers: bool = False) -> int: ...
def disable_possible_user_warnings() -> None: ...
```
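For example, seeding every library (Python, NumPy, PyTorch) at the start of a script keeps runs reproducible; `workers=True` also seeds dataloader worker processes. The seed value 42 is arbitrary:

```python
import lightning as L
from lightning.pytorch.utilities.warnings import disable_possible_user_warnings

L.seed_everything(42, workers=True)
disable_possible_user_warnings()  # silence benign speed-hint warnings
```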
Type aliases, configuration types, and constants used across the framework:

```python
from typing import Any, Dict, List, Optional, Union
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
from torch.utils.data import DataLoader
# Core types
STEP_OUTPUT = Union[Tensor, Dict[str, Any]]
TRAIN_DATALOADERS = Union[DataLoader, List[DataLoader], Dict[str, DataLoader]]
EVAL_DATALOADERS = Union[DataLoader, List[DataLoader]]
_EVALUATE_OUTPUT = List[Dict[str, float]]
_PREDICT_OUTPUT = List[Any]
# LR Scheduler configuration
class LRSchedulerConfig:
    scheduler: Any
    interval: str = "epoch"
    frequency: int = 1
    monitor: Optional[str] = None
    strict: bool = True
    name: Optional[str] = None

# Enums
class GradClipAlgorithmType:
    NORM = "norm"
    VALUE = "value"

class LightningEnum:
    pass
# Constants
FLOAT16_EPSILON: float
FLOAT32_EPSILON: float
FLOAT64_EPSILON: float
```
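The LRSchedulerConfig fields correspond to the keys of the "lr_scheduler" dictionary a LightningModule may return from `configure_optimizers`; a sketch reusing the quickstart model (monitoring "val_loss" assumes that metric is logged during validation, as in the earlier example):

```python
import torch

class LitModelWithScheduler(LitModel):  # reuses the quickstart model
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min")
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,   # maps to LRSchedulerConfig.scheduler
                "interval": "epoch",      # step the scheduler once per epoch
                "frequency": 1,
                "monitor": "val_loss",    # required by ReduceLROnPlateau
                "strict": True,
            },
        }
```

Likewise, the GradClipAlgorithmType values map to the Trainer's gradient clipping arguments, e.g. `L.Trainer(gradient_clip_val=0.5, gradient_clip_algorithm="norm")`.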