The Deep Learning framework to train, deploy, and ship AI products Lightning fast.
—
Lightweight training acceleration framework providing expert-level control over training loops, device management, and distributed strategies without high-level abstractions. Fabric gives you the flexibility of raw PyTorch with the power of Lightning's optimizations.
Main Fabric class that accelerates PyTorch training with distributed training, mixed precision, and device management while maintaining full control over the training loop.
class Fabric:
    """Accelerate PyTorch training with distributed strategies, mixed
    precision, and device management while leaving the training loop under
    the caller's full control.

    NOTE: this is an API-reference stub — method bodies are intentionally
    empty; only signatures and contracts are documented here.
    """

    def __init__(
        self,
        accelerator: str = "auto",
        strategy: str = "auto",
        devices: Union[List[int], str, int] = "auto",
        num_nodes: int = 1,
        precision: Union[str, int] = "32-true",
        plugins: Optional[Union[Plugin, List[Plugin]]] = None,
        callbacks: Optional[Union[Callback, List[Callback]]] = None,
        loggers: Optional[Union[Logger, List[Logger]]] = None,
        **kwargs,
    ):
        """Initialize Fabric for training acceleration.

        Args:
            accelerator: Hardware accelerator ('cpu', 'gpu', 'tpu', 'auto')
            strategy: Distributed strategy ('ddp', 'fsdp', 'deepspeed', etc.)
            devices: Which devices to use
            num_nodes: Number of nodes for distributed training
            precision: Precision mode ('32-true', '16-mixed', 'bf16-mixed', etc.)
            plugins: Additional plugins for customization
            callbacks: Callbacks for training lifecycle hooks
            loggers: Loggers for experiment tracking
        """

    def setup(
        self,
        model: nn.Module,
        *optimizers: Optimizer,
    ) -> Union[nn.Module, Tuple[nn.Module, ...]]:
        """Set up model and optimizers for training.

        Args:
            model: PyTorch model to accelerate
            *optimizers: Optimizers to set up

        Returns:
            Wrapped model and optimizers ready for training
        """

    def setup_dataloaders(
        self,
        *dataloaders: DataLoader,
    ) -> Union[DataLoader, List[DataLoader]]:
        """Set up data loaders for distributed training.

        Args:
            *dataloaders: Data loaders to set up

        Returns:
            Wrapped data loaders ready for distributed training
        """

    def backward(self, tensor: Tensor) -> None:
        """Perform backward pass with proper scaling and synchronization.

        Args:
            tensor: Loss tensor to compute gradients from
        """

    def clip_gradients(
        self,
        model: nn.Module,
        optimizer: Optimizer,
        max_norm: Union[float, int],
        norm_type: Union[float, int] = 2.0,
        error_if_nonfinite: bool = True,
    ) -> Tensor:
        """Clip gradients by norm.

        Args:
            model: Model whose gradients to clip
            optimizer: Optimizer being used
            max_norm: Maximum norm for gradients
            norm_type: Type of norm to use
            error_if_nonfinite: Raise error if gradients are non-finite

        Returns:
            Total norm of the gradients
        """

    def all_gather(
        self,
        tensor: Tensor,
        group: Optional[Any] = None,
        sync_grads: bool = False,
    ) -> Tensor:
        """Gather tensors from all processes.

        Args:
            tensor: Tensor to gather
            group: Process group
            sync_grads: Synchronize gradients

        Returns:
            Gathered tensor from all processes
        """

    def all_reduce(
        self,
        tensor: Tensor,
        group: Optional[Any] = None,
        reduce_op: str = "mean",
    ) -> Tensor:
        """Reduce tensor across all processes.

        Args:
            tensor: Tensor to reduce
            group: Process group
            reduce_op: Reduction operation ('mean', 'sum')

        Returns:
            Reduced tensor
        """

    def broadcast(self, tensor: Tensor, src: int = 0) -> Tensor:
        """Broadcast tensor from source process to all processes.

        Args:
            tensor: Tensor to broadcast
            src: Source rank

        Returns:
            Broadcasted tensor
        """

    def barrier(self, name: Optional[str] = None) -> None:
        """Synchronize all processes.

        Args:
            name: Optional barrier name for debugging
        """

    def is_global_zero(self) -> bool:
        """Check if current process is global rank 0.

        Returns:
            True if global rank 0
        """

    def print(self, *args, **kwargs) -> None:
        """Print only on rank 0.

        Args:
            *args: Arguments to print
            **kwargs: Keyword arguments for print
        """

    def log(self, name: str, value: Any, step: Optional[int] = None) -> None:
        """Log a metric.

        Args:
            name: Metric name
            value: Metric value
            step: Optional step number
        """

    def log_dict(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
        """Log a dictionary of metrics.

        Args:
            metrics: Dictionary of metrics
            step: Optional step number
        """

    def save(self, path: str, state: Dict[str, Any]) -> None:
        """Save checkpoint.

        Args:
            path: Path to save checkpoint
            state: State dictionary to save
        """

    def load(self, path: str) -> Dict[str, Any]:
        """Load checkpoint.

        Args:
            path: Path to load checkpoint from

        Returns:
            Loaded state dictionary
        """

    @property
    def device(self) -> torch.device:
        """Get the current device."""

    @property
    def global_rank(self) -> int:
        """Get global rank of current process."""

    @property
    def local_rank(self) -> int:
        """Get local rank of current process."""

    @property
    def node_rank(self) -> int:
        """Get node rank of current process."""

    @property
    def world_size(self) -> int:
        """Get total number of processes."""

    def to_device(self, obj: Any) -> Any:
        """Move object to device.

        Args:
            obj: Object to move to device

        Returns:
            Object on the device
        """


# Core utility functions for reproducibility, object inspection, and common
# operations in Fabric workflows.
def seed_everything(seed: int, workers: bool = False) -> int:
    """Set random seeds for reproducibility.

    NOTE: API-reference stub — no implementation here.

    Args:
        seed: Random seed to set
        workers: Also set seed for data loader workers

    Returns:
        The seed that was set
    """
def is_wrapped(obj: Any) -> bool:
    """Check if an object has been wrapped by Fabric.

    NOTE: API-reference stub — no implementation here.

    Args:
        obj: Object to check

    Returns:
        True if object is wrapped by Fabric
    """


import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from lightning import Fabric

# Initialize Fabric
fabric = Fabric(accelerator="gpu", devices=2, strategy="ddp")

# Define model and optimizer
model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Setup model and optimizer with Fabric
model, optimizer = fabric.setup(model, optimizer)

# Create sample data and dataloader
data = torch.randn(1000, 10)
targets = torch.randn(1000, 1)
dataset = TensorDataset(data, targets)
dataloader = DataLoader(dataset, batch_size=32)

# Setup dataloader
dataloader = fabric.setup_dataloaders(dataloader)

# Training loop with full control
for epoch in range(10):
    for batch_idx, (x, y) in enumerate(dataloader):
        optimizer.zero_grad()
        # Forward pass
        y_pred = model(x)
        loss = nn.functional.mse_loss(y_pred, y)
        # Backward pass - Fabric handles scaling and synchronization
        fabric.backward(loss)
        optimizer.step()
        # Log metrics
        if batch_idx % 10 == 0:
            fabric.log("train_loss", loss.item())
            fabric.print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item()}")

    # Save checkpoint once per epoch
    state = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "epoch": epoch,
    }
    fabric.save(f"checkpoint_epoch_{epoch}.ckpt", state)


import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from lightning import Fabric

# Initialize Fabric with advanced configuration:
# 4 GPUs, fully-sharded data parallel, mixed 16-bit precision.
fabric = Fabric(
    accelerator="gpu",
    devices=4,
    strategy="fsdp",
    precision="16-mixed",
    plugins=None,
)
class MyModel(nn.Module):
    """Simple MLP classifier: 784 -> 256 -> 128 -> 10 with ReLU and dropout."""

    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 10),
        )

    def forward(self, x):
        # x is expected to be (batch, 784), e.g. flattened 28x28 images.
        return self.layers(x)
# Model and optimizers
model = MyModel()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

# Setup with Fabric
model, optimizer = fabric.setup(model, optimizer)

# NOTE(review): `train_dataloader` is never defined in this example — it should
# be created and passed through `fabric.setup_dataloaders(...)` before the loop.

# Training loop with advanced features
for epoch in range(100):
    model.train()
    for batch_idx, (data, target) in enumerate(train_dataloader):
        optimizer.zero_grad()
        output = model(data)
        loss = nn.functional.cross_entropy(output, target)
        # Backward with automatic mixed precision
        fabric.backward(loss)
        # Gradient clipping
        fabric.clip_gradients(model, optimizer, max_norm=1.0)
        optimizer.step()
        # Metrics logging
        if batch_idx % 100 == 0:
            accuracy = (output.argmax(dim=1) == target).float().mean()
            # Log metrics - automatically handles distributed averaging
            fabric.log_dict({
                "train_loss": loss.item(),
                "train_acc": accuracy.item(),
                "lr": scheduler.get_last_lr()[0],
            })
            # Print only on rank 0
            fabric.print(f"Epoch {epoch}/{100}, Batch {batch_idx}, "
                         f"Loss: {loss.item():.4f}, Acc: {accuracy.item():.4f}")

    scheduler.step()
    # Synchronization barrier
    fabric.barrier()
    # Save checkpoint (only on rank 0)
    if fabric.is_global_zero():
        checkpoint = {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
            "epoch": epoch,
        }
        fabric.save(f"model_epoch_{epoch}.ckpt", checkpoint)


# Install with Tessl CLI
npx tessl i tessl/pypi-lightning