CtrlK
Blog · Docs · Log in · Get started
Tessl Logo

tessl/pypi-lightning-fabric

Fabric is the fast and lightweight way to scale PyTorch models without boilerplate

Pending
Overview
Eval results
Files

docs/utilities.md

Utilities

Helper functions and utilities for seeding, data movement, distributed operations, and performance monitoring.

Capabilities

Seeding Utilities

Functions for controlling random number generation and ensuring reproducibility.

def seed_everything(
    seed: Optional[int] = None, 
    workers: bool = False, 
    verbose: bool = True
) -> int:
    """
    Set global random seeds for reproducible results.
    
    Sets seeds for Python random, NumPy, PyTorch, and CUDA random number
    generators to ensure reproducible training runs.
    
    Args:
        seed: Random seed value. If None, generates random seed
        workers: Whether to seed DataLoader workers
        verbose: Whether to print seed information
        
    Returns:
        The seed value used
        
    Examples:
        # Set specific seed
        seed_everything(42)
        
        # Generate random seed
        used_seed = seed_everything()
        
        # Seed DataLoader workers for complete reproducibility
        seed_everything(42, workers=True)
    """

def reset_seed() -> None:
    """
    Reset the random seed to the value previously set.
    
    Re-applies the seed that the last seed_everything() call set
    (read back from the environment), restoring deterministic state.
    """

def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None:
    """
    Initialize random seeds for DataLoader workers.
    
    Used internally by Fabric to ensure DataLoader workers have
    different random seeds for proper data shuffling.
    
    Args:
        worker_id: DataLoader worker ID
        rank: Process rank for distributed training
    """

Data Movement Utilities

Functions for moving data between devices and handling device placement.

def move_data_to_device(obj: Any, device: torch.device) -> Any:
    """
    Move tensors and nested data structures to target device.
    
    Recursively moves tensors in lists, tuples, dictionaries, and
    custom objects to the specified device.
    
    Args:
        obj: Object containing tensors to move
        device: Target device
        
    Returns:
        Object with tensors moved to target device
        
    Examples:
        # Move single tensor
        tensor = torch.randn(10, 10)
        tensor_gpu = move_data_to_device(tensor, torch.device("cuda"))
        
        # Move nested data structure
        data = {
            "input": torch.randn(32, 784),
            "target": torch.randint(0, 10, (32,)),
            "metadata": {"batch_size": 32}
        }
        data_gpu = move_data_to_device(data, torch.device("cuda"))
    """

def suggested_max_num_workers(num_cpus: Optional[int] = None) -> int:
    """
    Suggest optimal number of DataLoader workers.
    
    Calculates recommended number of DataLoader workers based on
    available CPU cores and system configuration.
    
    Args:
        num_cpus: Number of available CPUs (auto-detected if None)
        
    Returns:
        Recommended number of DataLoader workers
        
    Examples:
        # Auto-detect optimal workers
        num_workers = suggested_max_num_workers()
        dataloader = DataLoader(dataset, num_workers=num_workers)
        
        # Use specific CPU count
        num_workers = suggested_max_num_workers(num_cpus=8)
    """

Object Wrapping Utilities

Functions for checking and managing Fabric-wrapped objects.

def is_wrapped(obj: Any) -> bool:
    """
    Check if object is wrapped by Fabric.
    
    Determines whether a model, optimizer, or dataloader has been
    wrapped by Fabric for distributed training.
    
    Args:
        obj: Object to check
        
    Returns:
        True if object is Fabric-wrapped, False otherwise
        
    Examples:
        model = nn.Linear(10, 1)
        print(is_wrapped(model))  # False
        
        fabric = Fabric()
        wrapped_model = fabric.setup_module(model)
        print(is_wrapped(wrapped_model))  # True
    """

def _unwrap_objects(collection: Any) -> Any:
    """
    Unwrap Fabric-wrapped objects in nested collections.
    
    Recursively unwraps Fabric objects in lists, tuples, dicts,
    returning the underlying PyTorch objects.
    
    Args:
        collection: Collection potentially containing wrapped objects
        
    Returns:
        Collection with unwrapped objects
    """

Distributed Utilities

Helper functions for distributed training operations.

class DistributedSamplerWrapper:
    """
    Wrapper for PyTorch samplers to work with distributed training.
    
    Automatically handles epoch setting and distributed sampling
    for custom samplers in distributed environments.
    """
    
    def __init__(self, sampler: Sampler, **kwargs):
        """
        Initialize distributed sampler wrapper.
        
        Args:
            sampler: Base sampler to wrap
            **kwargs: Additional arguments for DistributedSampler
        """
    
    def set_epoch(self, epoch: int) -> None:
        """Set epoch for proper shuffling in distributed training."""

class _InfiniteBarrier:
    """
    Barrier implementation that works across different process groups.
    Used internally for synchronizing processes in complex distributed setups.
    """
    
    def __call__(self) -> None:
        """Execute barrier synchronization."""

Rank-Zero Utilities

Functions that only execute on the rank-0 process in distributed training.

def rank_zero_only(fn: callable) -> callable:
    """
    Decorator to execute function only on rank 0.
    
    Args:
        fn: Function to wrap
        
    Returns:
        Decorated function that only executes on rank 0
        
    Examples:
        @rank_zero_only
        def save_model(model, path):
            torch.save(model.state_dict(), path)
        
        # Only rank 0 will save the model
        save_model(model, "model.pth")
    """

def rank_zero_warn(message: str, category: type[Warning] = UserWarning, stacklevel: int = 1) -> None:
    """
    Issue warning only from rank 0 process.
    
    Args:
        message: Warning message
        category: Warning category
        stacklevel: Stack level for warning location
        
    Examples:
        rank_zero_warn("This is a warning from rank 0 only")
    """

def rank_zero_info(message: str) -> None:
    """
    Log info message only from rank 0 process.
    
    Args:
        message: Info message to log
        
    Examples:
        rank_zero_info("Training started")
    """

def rank_zero_deprecation(message: str) -> None:
    """
    Issue deprecation warning only from rank 0 process.
    
    Args:
        message: Deprecation message
        
    Examples:
        rank_zero_deprecation("This function is deprecated, use new_function() instead")
    """

Performance Monitoring

Classes and functions for monitoring training performance and throughput.

class Throughput:
    """
    Throughput measurement utility.
    
    Measures processing throughput (samples/second) during training
    with automatic timing and averaging.
    """
    
    def __init__(self, window_size: int = 100):
        """
        Initialize throughput monitor.
        
        Args:
            window_size: Number of measurements to average over
        """
    
    def update(self, batch_size: int) -> None:
        """
        Update throughput measurement with new batch.
        
        Args:
            batch_size: Size of processed batch
        """
    
    def compute(self) -> float:
        """
        Compute current throughput.
        
        Returns:
            Throughput in samples per second
        """
    
    def reset(self) -> None:
        """Reset throughput measurements."""

class ThroughputMonitor:
    """
    Advanced throughput monitoring with multiple metrics.
    
    Tracks various performance metrics including samples/second,
    batches/second, and GPU utilization over time.
    """
    
    def __init__(
        self,
        window_size: int = 100,
        log_interval: int = 50
    ):
        """
        Initialize throughput monitor.
        
        Args:
            window_size: Measurement window size
            log_interval: Logging interval in steps
        """
    
    def on_batch_end(
        self,
        batch_size: int,
        num_samples: int,
        step: int
    ) -> None:
        """Called at the end of each training batch."""
    
    def get_metrics(self) -> dict[str, float]:
        """Get current performance metrics."""

def measure_flops(
    model: nn.Module,
    input_shape: tuple[int, ...],
    device: Optional[torch.device] = None
) -> dict[str, Union[int, float]]:
    """
    Measure FLOPs (floating point operations) for model inference.
    
    Estimates computational complexity by measuring FLOPs required
    for a forward pass with given input shape.
    
    Args:
        model: PyTorch model to analyze
        input_shape: Shape of input tensor (excluding batch dimension)
        device: Device to run measurement on
        
    Returns:
        Dictionary with FLOP measurements and model statistics
        
    Examples:
        # Measure FLOPs for image classification model
        flops = measure_flops(model, (3, 224, 224))
        print(f"Model requires {flops['flops']:,} FLOPs")
        
        # Measure FLOPs for text model
        flops = measure_flops(model, (512,))  # sequence length 512
    """

General Utilities

Miscellaneous utility classes and functions.

class AttributeDict(dict):
    """
    Dictionary that allows attribute-style access to keys.
    
    Enables accessing dictionary values using dot notation
    in addition to standard dictionary access.
    
    Examples:
        config = AttributeDict({"learning_rate": 0.001, "batch_size": 32})
        print(config.learning_rate)  # 0.001
        print(config["batch_size"])  # 32
        
        config.epochs = 100
        print(config["epochs"])  # 100
    """
    
    def __getattr__(self, key: str) -> Any:
        """Get attribute using dot notation."""
    
    def __setattr__(self, key: str, value: Any) -> None:
        """Set attribute using dot notation."""
    
    def __delattr__(self, key: str) -> None:
        """Delete attribute using dot notation."""

def is_shared_filesystem(path: Union[str, Path]) -> bool:
    """
    Check if path is on a shared filesystem across nodes.
    
    Determines whether a path is accessible from all nodes in
    a distributed training setup (e.g., NFS, shared storage).
    
    Args:
        path: Path to check
        
    Returns:
        True if filesystem is shared across nodes
        
    Examples:
        if is_shared_filesystem("/shared/checkpoints"):
            # Can save checkpoint from any node
            fabric.save("/shared/checkpoints/model.ckpt", state)
        else:
            # Save checkpoint only from rank 0
            if fabric.is_global_zero:
                fabric.save("local_model.ckpt", state)
    """

class LightningEnum(Enum):
    """
    Base enumeration class with additional utility methods.
    
    Extended enum class that provides helper methods for
    string conversion and validation.
    """
    
    @classmethod
    def from_str(cls, value: str) -> "LightningEnum":
        """Create enum from string value."""
    
    def __str__(self) -> str:
        """String representation of enum value."""

def disable_possible_user_warnings() -> None:
    """
    Disable possible user warnings from Lightning.
    
    Suppresses warnings that may be triggered by user code
    but are not critical for operation.
    
    Examples:
        # Disable warnings in production
        disable_possible_user_warnings()
    """

Usage Examples

Reproducible Training Setup

from lightning.fabric import Fabric, seed_everything

# Set seed for reproducibility
seed_everything(42, workers=True)

fabric = Fabric(accelerator="gpu", devices=2)

# DataLoader will automatically use seeded workers
dataloader = fabric.setup_dataloaders(
    DataLoader(dataset, num_workers=4, shuffle=True)
)

Optimal DataLoader Configuration

from lightning.fabric.utilities import suggested_max_num_workers

# Get optimal number of workers
num_workers = suggested_max_num_workers()

dataloader = DataLoader(
    dataset,
    batch_size=32,
    num_workers=num_workers,
    pin_memory=True
)

Performance Monitoring

from lightning.fabric.utilities import ThroughputMonitor

# Initialize performance monitor
throughput = ThroughputMonitor(window_size=100, log_interval=50)

# Training loop with monitoring
for step, batch in enumerate(dataloader):
    batch_size = batch[0].size(0)
    
    # Training step
    loss = train_step(model, batch)
    
    # Update throughput monitoring
    throughput.on_batch_end(
        batch_size=batch_size,
        num_samples=batch_size,
        step=step
    )
    
    if step % 50 == 0:
        metrics = throughput.get_metrics()
        fabric.print(f"Step {step}: {metrics['samples_per_sec']:.1f} samples/sec")

Device-Agnostic Data Movement

from lightning.fabric.utilities import move_data_to_device

# Complex nested data structure
batch = {
    "input": torch.randn(32, 784),
    "target": torch.randint(0, 10, (32,)),
    "metadata": {
        "lengths": torch.randint(10, 100, (32,)),
        "mask": torch.ones(32, 100, dtype=torch.bool)
    }
}

# Move entire structure to device
device = fabric.device
batch = move_data_to_device(batch, device)

Rank-Zero Operations

from lightning.fabric.utilities import rank_zero_only, rank_zero_warn

@rank_zero_only
def save_artifacts(model, metrics, epoch):
    """Save model and log metrics only from rank 0."""
    torch.save(model.state_dict(), f"model_epoch_{epoch}.pth")
    with open("metrics.json", "w") as f:
        json.dump(metrics, f)

# Training loop
for epoch in range(num_epochs):
    train_metrics = train_epoch(model, dataloader)
    
    # Only rank 0 saves artifacts
    save_artifacts(model, train_metrics, epoch)
    
    # Warning only from rank 0
    if train_metrics["loss"] > previous_loss:
        rank_zero_warn("Loss increased compared to previous epoch")

FLOP Measurement

from lightning.fabric.utilities import measure_flops

# Measure model complexity
model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 10)
)

flops_info = measure_flops(model, (784,))
fabric.print(f"Model FLOPs: {flops_info['flops']:,}")
fabric.print(f"Model parameters: {flops_info['params']:,}")

Configuration Management

from lightning.fabric.utilities import AttributeDict

# Configuration with attribute access
config = AttributeDict({
    "model": {
        "hidden_size": 256,
        "num_layers": 3
    },
    "training": {
        "learning_rate": 0.001,
        "batch_size": 32,
        "epochs": 100
    }
})

# Access using dot notation
model = create_model(
    hidden_size=config.model.hidden_size,
    num_layers=config.model.num_layers
)

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=config.training.learning_rate
)

Filesystem Utilities

from lightning.fabric.utilities import is_shared_filesystem

checkpoint_path = "/shared/storage/checkpoints"

if is_shared_filesystem(checkpoint_path):
    # All nodes can access this path
    fabric.save(f"{checkpoint_path}/model.ckpt", state)
else:
    # Use local storage with rank coordination
    if fabric.is_global_zero:
        fabric.save("model.ckpt", state)
    
    # Wait for rank 0 to finish saving
    fabric.barrier("checkpoint_save")

Install with Tessl CLI

npx tessl i tessl/pypi-lightning-fabric

docs

accelerators.md

core-training.md

distributed.md

index.md

precision.md

strategies.md

utilities.md

tile.json