Fabric is the fast and lightweight way to scale PyTorch models without boilerplate.

Helper functions and utilities for seeding, data movement, distributed operations, and performance monitoring.

Functions for controlling random number generation and ensuring reproducibility.
def seed_everything(
seed: Optional[int] = None,
workers: bool = False,
verbose: bool = True
) -> int:
"""
Set global random seeds for reproducible results.
Sets seeds for Python random, NumPy, PyTorch, and CUDA random number
generators to ensure reproducible training runs.
Args:
seed: Random seed value. If None, generates random seed
workers: Whether to seed DataLoader workers
verbose: Whether to print seed information
Returns:
The seed value used
Examples:
# Set specific seed
seed_everything(42)
# Generate random seed
used_seed = seed_everything()
# Seed DataLoader workers for complete reproducibility
seed_everything(42, workers=True)
"""
def reset_seed() -> None:
"""
Reset random seed to previous state.
Restores the random number generator state to what it was
before the last seed_everything() call.
"""
def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None:
"""
Initialize random seeds for DataLoader workers.
Used internally by Fabric to ensure DataLoader workers have
different random seeds for proper data shuffling.
Args:
worker_id: DataLoader worker ID
rank: Process rank for distributed training
"""Functions for moving data between devices and handling device placement.
def move_data_to_device(obj: Any, device: torch.device) -> Any:
"""
Move tensors and nested data structures to target device.
Recursively moves tensors in lists, tuples, dictionaries, and
custom objects to the specified device.
Args:
obj: Object containing tensors to move
device: Target device
Returns:
Object with tensors moved to target device
Examples:
# Move single tensor
tensor = torch.randn(10, 10)
tensor_gpu = move_data_to_device(tensor, torch.device("cuda"))
# Move nested data structure
data = {
"input": torch.randn(32, 784),
"target": torch.randint(0, 10, (32,)),
"metadata": {"batch_size": 32}
}
data_gpu = move_data_to_device(data, torch.device("cuda"))
"""
def suggested_max_num_workers(num_cpus: Optional[int] = None) -> int:
"""
Suggest optimal number of DataLoader workers.
Calculates recommended number of DataLoader workers based on
available CPU cores and system configuration.
Args:
num_cpus: Number of available CPUs (auto-detected if None)
Returns:
Recommended number of DataLoader workers
Examples:
# Auto-detect optimal workers
num_workers = suggested_max_num_workers()
dataloader = DataLoader(dataset, num_workers=num_workers)
# Use specific CPU count
num_workers = suggested_max_num_workers(num_cpus=8)
"""Functions for checking and managing Fabric-wrapped objects.
def is_wrapped(obj: Any) -> bool:
"""
Check if object is wrapped by Fabric.
Determines whether a model, optimizer, or dataloader has been
wrapped by Fabric for distributed training.
Args:
obj: Object to check
Returns:
True if object is Fabric-wrapped, False otherwise
Examples:
model = nn.Linear(10, 1)
print(is_wrapped(model)) # False
fabric = Fabric()
wrapped_model = fabric.setup_module(model)
print(is_wrapped(wrapped_model)) # True
"""
def _unwrap_objects(collection: Any) -> Any:
"""
Unwrap Fabric-wrapped objects in nested collections.
Recursively unwraps Fabric objects in lists, tuples, dicts,
returning the underlying PyTorch objects.
Args:
collection: Collection potentially containing wrapped objects
Returns:
Collection with unwrapped objects
"""Helper functions for distributed training operations.
class DistributedSamplerWrapper:
"""
Wrapper for PyTorch samplers to work with distributed training.
Automatically handles epoch setting and distributed sampling
for custom samplers in distributed environments.
"""
def __init__(self, sampler: Sampler, **kwargs):
"""
Initialize distributed sampler wrapper.
Args:
sampler: Base sampler to wrap
**kwargs: Additional arguments for DistributedSampler
"""
def set_epoch(self, epoch: int) -> None:
"""Set epoch for proper shuffling in distributed training."""
class _InfiniteBarrier:
"""
Barrier implementation that works across different process groups.
Used internally for synchronizing processes in complex distributed setups.
"""
def __call__(self) -> None:
"""Execute barrier synchronization."""Functions that only execute on the rank-0 process in distributed training.
def rank_zero_only(fn: Callable) -> Callable:
"""
Decorator to execute function only on rank 0.
Args:
fn: Function to wrap
Returns:
Decorated function that only executes on rank 0
Examples:
@rank_zero_only
def save_model(model, path):
torch.save(model.state_dict(), path)
# Only rank 0 will save the model
save_model(model, "model.pth")
"""
def rank_zero_warn(message: str, category: type[Warning] = UserWarning, stacklevel: int = 1) -> None:
"""
Issue warning only from rank 0 process.
Args:
message: Warning message
category: Warning category
stacklevel: Stack level for warning location
Examples:
rank_zero_warn("This is a warning from rank 0 only")
"""
def rank_zero_info(message: str) -> None:
"""
Log info message only from rank 0 process.
Args:
message: Info message to log
Examples:
rank_zero_info("Training started")
"""
def rank_zero_deprecation(message: str) -> None:
"""
Issue deprecation warning only from rank 0 process.
Args:
message: Deprecation message
Examples:
rank_zero_deprecation("This function is deprecated, use new_function() instead")
"""Classes and functions for monitoring training performance and throughput.
class Throughput:
"""
Throughput measurement utility.
Measures processing throughput (samples/second) during training
with automatic timing and averaging.
"""
def __init__(self, window_size: int = 100):
"""
Initialize throughput monitor.
Args:
window_size: Number of measurements to average over
"""
def update(self, batch_size: int) -> None:
"""
Update throughput measurement with new batch.
Args:
batch_size: Size of processed batch
"""
def compute(self) -> float:
"""
Compute current throughput.
Returns:
Throughput in samples per second
"""
def reset(self) -> None:
"""Reset throughput measurements."""
class ThroughputMonitor:
"""
Advanced throughput monitoring with multiple metrics.
Tracks various performance metrics including samples/second,
batches/second, and GPU utilization over time.
"""
def __init__(
self,
window_size: int = 100,
log_interval: int = 50
):
"""
Initialize throughput monitor.
Args:
window_size: Measurement window size
log_interval: Logging interval in steps
"""
def on_batch_end(
self,
batch_size: int,
num_samples: int,
step: int
) -> None:
"""Called at the end of each training batch."""
def get_metrics(self) -> dict[str, float]:
"""Get current performance metrics."""
def measure_flops(
model: nn.Module,
input_shape: tuple[int, ...],
device: Optional[torch.device] = None
) -> dict[str, Union[int, float]]:
"""
Measure FLOPs (floating point operations) for model inference.
Estimates computational complexity by measuring FLOPs required
for a forward pass with given input shape.
Args:
model: PyTorch model to analyze
input_shape: Shape of input tensor (excluding batch dimension)
device: Device to run measurement on
Returns:
Dictionary with FLOP measurements and model statistics
Examples:
# Measure FLOPs for image classification model
flops = measure_flops(model, (3, 224, 224))
print(f"Model requires {flops['flops']:,} FLOPs")
# Measure FLOPs for text model
flops = measure_flops(model, (512,)) # sequence length 512
"""Miscellaneous utility classes and functions.
class AttributeDict(dict):
"""
Dictionary that allows attribute-style access to keys.
Enables accessing dictionary values using dot notation
in addition to standard dictionary access.
Examples:
config = AttributeDict({"learning_rate": 0.001, "batch_size": 32})
print(config.learning_rate) # 0.001
print(config["batch_size"]) # 32
config.epochs = 100
print(config["epochs"]) # 100
"""
def __getattr__(self, key: str) -> Any:
"""Get attribute using dot notation."""
def __setattr__(self, key: str, value: Any) -> None:
"""Set attribute using dot notation."""
def __delattr__(self, key: str) -> None:
"""Delete attribute using dot notation."""
def is_shared_filesystem(path: Union[str, Path]) -> bool:
"""
Check if path is on a shared filesystem across nodes.
Determines whether a path is accessible from all nodes in
a distributed training setup (e.g., NFS, shared storage).
Args:
path: Path to check
Returns:
True if filesystem is shared across nodes
Examples:
if is_shared_filesystem("/shared/checkpoints"):
# Can save checkpoint from any node
fabric.save("/shared/checkpoints/model.ckpt", state)
else:
# Save checkpoint only from rank 0
if fabric.is_global_zero:
fabric.save("local_model.ckpt", state)
"""
class LightningEnum(Enum):
"""
Base enumeration class with additional utility methods.
Extended enum class that provides helper methods for
string conversion and validation.
"""
@classmethod
def from_str(cls, value: str) -> "LightningEnum":
"""Create enum from string value."""
def __str__(self) -> str:
"""String representation of enum value."""
def disable_possible_user_warnings() -> None:
"""
Disable possible user warnings from Lightning.
Suppresses warnings that may be triggered by user code
but are not critical for operation.
Examples:
# Disable warnings in production
disable_possible_user_warnings()
"""from lightning.fabric import Fabric, seed_everything
# Set seed for reproducibility
seed_everything(42, workers=True)
fabric = Fabric(accelerator="gpu", devices=2)
# DataLoader will automatically use seeded workers
dataloader = fabric.setup_dataloaders(
DataLoader(dataset, num_workers=4, shuffle=True)
)

from lightning.fabric.utilities import suggested_max_num_workers
# Get optimal number of workers
num_workers = suggested_max_num_workers()
dataloader = DataLoader(
dataset,
batch_size=32,
num_workers=num_workers,
pin_memory=True
)

from lightning.fabric.utilities import ThroughputMonitor
# Initialize performance monitor
throughput = ThroughputMonitor(window_size=100, log_interval=50)
# Training loop with monitoring
for step, batch in enumerate(dataloader):
batch_size = batch[0].size(0)
# Training step
loss = train_step(model, batch)
# Update throughput monitoring
throughput.on_batch_end(
batch_size=batch_size,
num_samples=batch_size,
step=step
)
if step % 50 == 0:
metrics = throughput.get_metrics()
fabric.print(f"Step {step}: {metrics['samples_per_sec']:.1f} samples/sec")from lightning.fabric.utilities import move_data_to_device
# Complex nested data structure
batch = {
"input": torch.randn(32, 784),
"target": torch.randint(0, 10, (32,)),
"metadata": {
"lengths": torch.randint(10, 100, (32,)),
"mask": torch.ones(32, 100, dtype=torch.bool)
}
}
# Move entire structure to device
device = fabric.device
batch = move_data_to_device(batch, device)

from lightning.fabric.utilities import rank_zero_only, rank_zero_warn
@rank_zero_only
def save_artifacts(model, metrics, epoch):
"""Save model and log metrics only from rank 0."""
torch.save(model.state_dict(), f"model_epoch_{epoch}.pth")
with open("metrics.json", "w") as f:
json.dump(metrics, f)
# Training loop
for epoch in range(num_epochs):
train_metrics = train_epoch(model, dataloader)
# Only rank 0 saves artifacts
save_artifacts(model, train_metrics, epoch)
# Warning only from rank 0
if train_metrics["loss"] > previous_loss:
rank_zero_warn("Loss increased compared to previous epoch")from lightning.fabric.utilities import measure_flops
# Measure model complexity
model = nn.Sequential(
nn.Linear(784, 256),
nn.ReLU(),
nn.Linear(256, 10)
)
flops_info = measure_flops(model, (784,))
fabric.print(f"Model FLOPs: {flops_info['flops']:,}")
fabric.print(f"Model parameters: {flops_info['params']:,}")from lightning.fabric.utilities import AttributeDict
# Configuration with attribute access
config = AttributeDict({
"model": {
"hidden_size": 256,
"num_layers": 3
},
"training": {
"learning_rate": 0.001,
"batch_size": 32,
"epochs": 100
}
})
# Access using dot notation
model = create_model(
hidden_size=config.model.hidden_size,
num_layers=config.model.num_layers
)
optimizer = torch.optim.Adam(
model.parameters(),
lr=config.training.learning_rate
)

from lightning.fabric.utilities import is_shared_filesystem
checkpoint_path = "/shared/storage/checkpoints"
if is_shared_filesystem(checkpoint_path):
# All nodes can access this path
fabric.save(f"{checkpoint_path}/model.ckpt", state)
else:
# Use local storage with rank coordination
if fabric.is_global_zero:
fabric.save("model.ckpt", state)
# Wait for rank 0 to finish saving
fabric.barrier("checkpoint_save")Install with Tessl CLI
npx tessl i tessl/pypi-lightning-fabric