# Utilities

Helper functions and utilities for seeding, data movement, distributed operations, and performance monitoring.

## Capabilities

### Seeding Utilities

Functions for controlling random number generation and ensuring reproducibility.

```python { .api }
def seed_everything(
    seed: Optional[int] = None,
    workers: bool = False,
    verbose: bool = True
) -> int:
    """
    Set global random seeds for reproducible results.

    Seeds Python's `random`, NumPy, and PyTorch (`torch.manual_seed` also
    seeds all CUDA devices) and stores the seed in the `PL_GLOBAL_SEED`
    environment variable so that spawned subprocesses can pick it up.

    Args:
        seed: Random seed value. If None, the seed is read from `PL_GLOBAL_SEED`;
            if that is unset as well, recent Lightning versions fall back to 0
            (older versions selected a random seed)
        workers: Whether to configure DataLoader workers with a seeding
            `worker_init_fn` (see `pl_worker_init_function`)
        verbose: Whether to log the seed being set

    Returns:
        The seed value used

    Examples:
        # Set specific seed
        seed_everything(42)

        # Let Lightning choose the seed (PL_GLOBAL_SEED or the default)
        used_seed = seed_everything()

        # Also seed DataLoader workers for complete reproducibility
        seed_everything(42, workers=True)
    """

def reset_seed() -> None:
    """
    Re-apply the seed previously set by seed_everything().

    Reads the seed stored in `PL_GLOBAL_SEED` and seeds all generators
    again. Does nothing if seed_everything() has not been called.
    """

def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None:
    """
    Initialize random seeds for DataLoader workers.

    Fabric attaches this as the `worker_init_fn` when seeding was requested
    with `seed_everything(..., workers=True)`. It derives a distinct but
    reproducible seed for each worker from the global seed, the worker ID,
    and the process rank, so random operations inside workers (e.g. data
    augmentation) do not repeat across workers.

    Args:
        worker_id: DataLoader worker ID
        rank: Process rank for distributed training
    """
```
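The interplay between `seed_everything()` and `reset_seed()` can be illustrated with the standard library alone. This is a simplified sketch, not Lightning's code: `sketch_seed_everything` and `sketch_reset_seed` are hypothetical names, only Python's `random` module is seeded (the real functions also seed NumPy and PyTorch), and it assumes the seed is persisted in the `PL_GLOBAL_SEED` environment variable.

```python
import os
import random
from typing import Optional

# Hypothetical, simplified stand-ins: only Python's `random` is seeded here.
ENV_KEY = "PL_GLOBAL_SEED"


def sketch_seed_everything(seed: Optional[int] = None) -> int:
    """Seed the RNG and remember the seed in an environment variable."""
    if seed is None:
        # Fall back to a previously stored seed, else 0
        seed = int(os.environ.get(ENV_KEY, 0))
    os.environ[ENV_KEY] = str(seed)
    random.seed(seed)
    return seed


def sketch_reset_seed() -> None:
    """Re-apply the stored seed; do nothing if no seed was ever set."""
    stored = os.environ.get(ENV_KEY)
    if stored is not None:
        sketch_seed_everything(int(stored))


sketch_seed_everything(42)
first_run = [random.random() for _ in range(3)]

sketch_reset_seed()  # e.g. at the start of a freshly spawned process
second_run = [random.random() for _ in range(3)]

print(first_run == second_run)  # True: the same seed was re-applied
```

Persisting the seed in the environment is what lets a freshly spawned process restore the same RNG state without the seed being passed around explicitly.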

### Data Movement Utilities

Functions for moving data between devices and configuring data loading.

```python { .api }
def move_data_to_device(obj: Any, device: torch.device) -> Any:
    """
    Move tensors and nested data structures to target device.

    Recursively moves tensors in lists, tuples, and dictionaries, as well
    as any object that implements a `.to(device)` method, to the specified
    device. Other values are returned unchanged.

    Args:
        obj: Object containing tensors to move
        device: Target device

    Returns:
        Object with tensors moved to target device

    Examples:
        # Move single tensor
        tensor = torch.randn(10, 10)
        tensor_gpu = move_data_to_device(tensor, torch.device("cuda"))

        # Move nested data structure ("metadata" is left as-is)
        data = {
            "input": torch.randn(32, 784),
            "target": torch.randint(0, 10, (32,)),
            "metadata": {"batch_size": 32}
        }
        data_gpu = move_data_to_device(data, torch.device("cuda"))
    """

def suggested_max_num_workers(local_world_size: int) -> int:
    """
    Suggest an upper bound for the number of DataLoader workers.

    Divides the CPU cores available to this process among the distributed
    processes running on the same machine.

    Args:
        local_world_size: Number of distributed processes on this machine
            (typically the number of devices configured in Fabric); must be >= 1

    Returns:
        Suggested maximum number of workers per DataLoader (at least 1)

    Examples:
        # Single process on this machine
        num_workers = suggested_max_num_workers(1)
        dataloader = DataLoader(dataset, num_workers=num_workers)

        # Two processes per machine, e.g. Fabric(devices=2)
        num_workers = suggested_max_num_workers(local_world_size=2)
    """
```
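`move_data_to_device` walks nested containers and transfers only the leaves that support it. That recursion can be sketched without PyTorch: `apply_to_leaves` and `move_to` are hypothetical helpers (Lightning's own implementation builds on `apply_to_collection` from `lightning_utilities`), and `FakeTensor` is a stand-in class with a `.to()` method, following the rule that any object implementing `.to(device)` is moved and everything else is returned unchanged.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class FakeTensor:
    """Stand-in for a tensor: remembers which device it lives on."""
    value: float
    device: str = "cpu"

    def to(self, device: str) -> "FakeTensor":
        return FakeTensor(self.value, device)


def apply_to_leaves(data: Any, fn: Callable[[Any], Any]) -> Any:
    """Rebuild dicts, lists, and tuples recursively, applying `fn` to every other value."""
    if isinstance(data, dict):
        return {key: apply_to_leaves(value, fn) for key, value in data.items()}
    if isinstance(data, tuple) and hasattr(data, "_fields"):  # namedtuple
        return type(data)(*(apply_to_leaves(item, fn) for item in data))
    if isinstance(data, (list, tuple)):
        return type(data)(apply_to_leaves(item, fn) for item in data)
    return fn(data)


def move_to(data: Any, device: str) -> Any:
    """Call `.to(device)` on leaves that support it; leave the rest unchanged."""
    return apply_to_leaves(data, lambda x: x.to(device) if hasattr(x, "to") else x)


batch = {
    "input": FakeTensor(1.0),
    "pair": (FakeTensor(2.0), FakeTensor(3.0)),
    "metadata": {"batch_size": 32},
}
moved = move_to(batch, "cuda:0")
print(moved["pair"][1].device)          # cuda:0
print(moved["metadata"]["batch_size"])  # 32
```

Containers are rebuilt rather than mutated, so the original `batch` still refers to the CPU objects.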

### Object Wrapping Utilities

Functions for checking and managing Fabric-wrapped objects.

```python { .api }
def is_wrapped(obj: Any) -> bool:
    """
    Check if object is wrapped by Fabric.

    Determines whether a model, optimizer, or dataloader has been
    set up (wrapped) by Fabric.

    Args:
        obj: Object to check

    Returns:
        True if object is Fabric-wrapped, False otherwise

    Examples:
        model = nn.Linear(10, 1)
        print(is_wrapped(model))  # False

        fabric = Fabric()
        wrapped_model = fabric.setup_module(model)
        print(is_wrapped(wrapped_model))  # True
    """

def _unwrap_objects(collection: Any) -> Any:
    """
    Unwrap Fabric-wrapped objects in nested collections.

    Recursively unwraps Fabric objects in lists, tuples, and dicts,
    returning the underlying PyTorch objects.

    Args:
        collection: Collection potentially containing wrapped objects

    Returns:
        Collection with unwrapped objects
    """
```
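The wrapping pattern that `is_wrapped()` and `_unwrap_objects()` rely on can be sketched in plain Python: a wrapper keeps a reference to the original object, and unwrapping walks a nested collection replacing each wrapper with that reference. `Wrapper`, `is_wrapped_sketch`, and `unwrap_sketch` are hypothetical stand-ins; Fabric's real wrapper classes differ in their internals.

```python
from typing import Any


class Wrapper:
    """Hypothetical wrapper that keeps a reference to the object it wraps."""

    def __init__(self, original: Any) -> None:
        self._original = original

    def __getattr__(self, name: str) -> Any:
        # Only called for attributes not found on the wrapper itself
        return getattr(self._original, name)


def is_wrapped_sketch(obj: Any) -> bool:
    return isinstance(obj, Wrapper)


def unwrap_sketch(collection: Any) -> Any:
    """Replace every Wrapper inside dicts, lists, and tuples by its original object."""
    if is_wrapped_sketch(collection):
        return collection._original
    if isinstance(collection, dict):
        return {k: unwrap_sketch(v) for k, v in collection.items()}
    if isinstance(collection, (list, tuple)):
        return type(collection)(unwrap_sketch(v) for v in collection)
    return collection


class Model:
    name = "linear"


model = Model()
state = {"model": Wrapper(model), "step": 10, "extras": [Wrapper("opt")]}
print(is_wrapped_sketch(state["model"]))  # True
# Attribute access is delegated to the wrapped object
print(state["model"].name)                # linear
plain = unwrap_sketch(state)
print(plain["model"] is model)            # True
```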

### Distributed Utilities

Helper functions for distributed training operations.

```python { .api }
class DistributedSamplerWrapper(DistributedSampler):
    """
    Wrapper that makes any sampler usable in distributed training.

    Each process receives a distinct shard of the indices produced by the
    wrapped sampler. Fabric applies this wrapper automatically when it
    replaces a custom (non-random, non-sequential) sampler in a distributed
    setting.
    """

    def __init__(self, sampler: Sampler, *args, **kwargs):
        """
        Initialize distributed sampler wrapper.

        Args:
            sampler: Base sampler to wrap
            *args, **kwargs: Forwarded to torch.utils.data.DistributedSampler
                (e.g. num_replicas, rank, shuffle)
        """

    def set_epoch(self, epoch: int) -> None:
        """Set the epoch so that shuffling differs between epochs."""

class _InfiniteBarrier:
    """
    Barrier with an effectively infinite timeout.

    Creates a separate Gloo process group with a very large timeout, so that
    processes can wait for a long-running operation (for example, one that
    only rank 0 performs) without hitting the regular collective timeout.
    Used internally by Lightning.
    """

    def __call__(self) -> None:
        """Block until all processes reach the barrier."""
```
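How a sampler's indices are split across processes can be sketched with the standard library. The sketch follows the scheme of `torch.utils.data.DistributedSampler` with `drop_last=False`: shuffle with a seed that depends on the epoch, pad by repeating indices until the total is divisible by the number of processes, then let each rank take every `num_replicas`-th index starting at its own rank. `shard_indices` is a hypothetical helper, and `random.Random` stands in for the `torch.Generator` the real sampler uses, so the concrete orderings differ.

```python
import math
import random
from typing import List


def shard_indices(
    sampler_indices: List[int],
    num_replicas: int,
    rank: int,
    shuffle: bool = True,
    seed: int = 0,
    epoch: int = 0,
) -> List[int]:
    """Return the indices that process `rank` would see (DistributedSampler-style)."""
    if not sampler_indices:
        return []
    order = list(range(len(sampler_indices)))
    if shuffle:
        # Same seed + epoch on every rank -> identical permutation everywhere
        random.Random(seed + epoch).shuffle(order)
    num_samples = math.ceil(len(order) / num_replicas)
    total_size = num_samples * num_replicas
    # Pad by repeating positions so that every rank gets the same count
    order = (order * math.ceil(total_size / len(order)))[:total_size]
    # Strided slice: rank r takes positions r, r + num_replicas, ...
    return [sampler_indices[i] for i in order[rank:total_size:num_replicas]]


custom_order = [9, 3, 7, 1, 5, 0, 8, 2, 6, 4]  # output of some custom sampler
shards = [shard_indices(custom_order, num_replicas=3, rank=r, epoch=1) for r in range(3)]
print([len(s) for s in shards])  # [4, 4, 4]: 10 indices padded to 12
```

Because every rank computes the same permutation from `seed + epoch`, the shards are disjoint apart from the padding, which is why calling `set_epoch()` on all ranks before each epoch matters.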

### Rank-Zero Utilities

Functions that only execute on the rank-0 process in distributed training.

```python { .api }
def rank_zero_only(fn: Callable[..., Any]) -> Callable[..., Any]:
    """
    Decorator to execute function only on rank 0.

    On all other ranks the call is skipped and returns None.

    Args:
        fn: Function to wrap

    Returns:
        Decorated function that only executes on rank 0

    Examples:
        @rank_zero_only
        def save_model(model, path):
            torch.save(model.state_dict(), path)

        # Only rank 0 will save the model
        save_model(model, "model.pth")
    """

def rank_zero_warn(message: str, category: type[Warning] = UserWarning, stacklevel: int = 1) -> None:
    """
    Issue a warning only from the rank-0 process.

    Args:
        message: Warning message
        category: Warning category
        stacklevel: Stack level for warning location

    Examples:
        rank_zero_warn("This is a warning from rank 0 only")
    """

def rank_zero_info(message: str) -> None:
    """
    Log an info message only from the rank-0 process.

    Args:
        message: Info message to log

    Examples:
        rank_zero_info("Training started")
    """

def rank_zero_deprecation(message: str) -> None:
    """
    Issue a deprecation warning only from the rank-0 process.

    Args:
        message: Deprecation message

    Examples:
        rank_zero_deprecation("This function is deprecated, use new_function() instead")
    """
```
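The rank-zero helpers boil down to a rank check around the call. The sketch below is a hypothetical, simplified version (`rank_zero_only_sketch`, `warn_once`, and `describe` are illustrative names): it reads the global rank from a `RANK` environment variable, which is an assumption about how your launcher exposes it; Lightning sets the rank for you when Fabric launches processes.

```python
import functools
import os
import warnings
from typing import Any, Callable, Optional, TypeVar

T = TypeVar("T")


def rank_zero_only_sketch(fn: Callable[..., T]) -> Callable[..., Optional[T]]:
    """Run `fn` only when the global rank is 0; return None elsewhere."""

    @functools.wraps(fn)
    def wrapped(*args: Any, **kwargs: Any) -> Optional[T]:
        # The rank is looked up at call time, so it can be set after decoration
        if rank_zero_only_sketch.rank == 0:
            return fn(*args, **kwargs)
        return None

    return wrapped


# Assumption: the launcher exports the global rank as RANK (0 if unset)
rank_zero_only_sketch.rank = int(os.environ.get("RANK", 0))  # type: ignore[attr-defined]


@rank_zero_only_sketch
def warn_once(message: str) -> None:
    warnings.warn(message, UserWarning, stacklevel=2)


@rank_zero_only_sketch
def describe() -> str:
    return "logged from rank 0"
```

Storing the rank on the decorator itself means it only has to be set once per process, after which every decorated function respects it.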

### Performance Monitoring

Classes and functions for monitoring training performance and throughput.

```python { .api }
class Throughput:
    """
    Throughput measurement utility.

    Computes throughput metrics (e.g. batches and samples per second) as
    rolling averages over the most recent updates. You report cumulative
    totals; the class derives the rates.
    """

    def __init__(
        self,
        available_flops: Optional[float] = None,
        world_size: int = 1,
        window_size: int = 100,
        separator: str = "/"
    ):
        """
        Initialize throughput meter.

        Args:
            available_flops: Theoretical FLOPs available on a single device
            world_size: Number of devices across all hosts; global metrics
                are only added when this is greater than 1
            window_size: Number of updates to average over
            separator: Separator used in per-device metric keys
                (e.g. "device/samples_per_sec")
        """

    def update(
        self,
        *,
        time: float,
        batches: int,
        samples: int,
        lengths: Optional[int] = None,
        flops: Optional[int] = None
    ) -> None:
        """
        Record a new measurement.

        Args:
            time: Total elapsed time in seconds; must increase with each call
            batches: Total batches seen per device so far
            samples: Total samples seen per device so far
            lengths: Total length of the samples seen so far (e.g. tokens)
            flops: FLOPs performed per device since the last update() call
        """

    def compute(self) -> dict[str, Union[int, float]]:
        """
        Compute current throughput.

        Returns:
            Dictionary of metrics such as "device/samples_per_sec"; rate
            metrics are only included once `window_size` updates have
            been recorded
        """

    def reset(self) -> None:
        """Reset throughput measurements."""

class ThroughputMonitor(Throughput):
    """
    Throughput meter that logs its metrics through Fabric.

    The world size is taken from the Fabric object, and metrics are sent to
    Fabric's loggers with fabric.log_dict(). A step counter is kept
    internally and can be overridden per call.
    """

    def __init__(self, fabric: Fabric, **kwargs):
        """
        Initialize throughput monitor.

        Args:
            fabric: Fabric instance whose loggers receive the metrics
            **kwargs: Forwarded to Throughput (e.g. window_size)
        """

    def compute_and_log(self, step: Optional[int] = None, **kwargs) -> dict[str, Union[int, float]]:
        """Compute the metrics, log them via Fabric, and return them."""

def measure_flops(
    model: nn.Module,
    forward_fn: Callable[[], torch.Tensor],
    loss_fn: Optional[Callable[[torch.Tensor], torch.Tensor]] = None
) -> int:
    """
    Measure FLOPs (floating point operations) of a forward pass, and
    optionally of the backward pass.

    Counts FLOPs with torch.utils.flop_counter.FlopCounterMode while
    running `forward_fn`. Creating the model and inputs on the "meta"
    device avoids allocating real memory.

    Args:
        model: PyTorch model to analyze
        forward_fn: Zero-argument callable that runs the forward pass and
            returns its output
        loss_fn: Optional callable mapping that output to a loss; if given,
            the backward pass is counted as well

    Returns:
        Total number of FLOPs

    Examples:
        with torch.device("meta"):
            model = MyModel()
            x = torch.randn(2, 32)

        # Forward pass only
        fwd_flops = measure_flops(model, lambda: model(x))

        # Forward and backward pass
        fwd_and_bwd_flops = measure_flops(model, lambda: model(x), lambda y: y.sum())
    """
```
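The rolling-window computation behind a throughput meter can be sketched with `collections.deque`: store the last `window_size` cumulative (time, samples) pairs and divide the differences between the newest and oldest entries. `RollingThroughput` is a hypothetical class, not Lightning's implementation, and it only tracks samples per second.

```python
from collections import deque
from typing import Optional


class RollingThroughput:
    """Hypothetical sketch: samples/sec averaged over the last `window_size` updates."""

    def __init__(self, window_size: int = 100) -> None:
        if window_size < 2:
            raise ValueError("window_size must be at least 2")
        self._times: deque = deque(maxlen=window_size)
        self._samples: deque = deque(maxlen=window_size)

    def update(self, *, time: float, samples: int) -> None:
        # Both arguments are cumulative totals and must keep increasing
        if self._times and time <= self._times[-1]:
            raise ValueError("time must increase monotonically")
        self._times.append(time)
        self._samples.append(samples)

    def compute(self) -> Optional[float]:
        if len(self._times) < self._times.maxlen:
            return None  # not enough history yet
        elapsed = self._times[-1] - self._times[0]
        return (self._samples[-1] - self._samples[0]) / elapsed


meter = RollingThroughput(window_size=3)
for step in range(1, 6):
    # Simulated run: each batch of 32 samples takes 0.5 s
    meter.update(time=0.5 * step, samples=32 * step)
    print(step, meter.compute())
```

Feeding cumulative totals rather than per-batch values makes the reported rate a true average over the window, even when individual batch times vary.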

### General Utilities

Miscellaneous utility classes and functions.

```python { .api }
class AttributeDict(dict):
    """
    Dictionary that allows attribute-style access to keys.

    Enables accessing dictionary values using dot notation
    in addition to standard dictionary access. Only the top level
    gets dot access: nested plain dicts remain ordinary dicts.

    Examples:
        config = AttributeDict({"learning_rate": 0.001, "batch_size": 32})
        print(config.learning_rate)  # 0.001
        print(config["batch_size"])  # 32

        config.epochs = 100
        print(config["epochs"])  # 100
    """

    def __getattr__(self, key: str) -> Any:
        """Get attribute using dot notation (raises AttributeError for missing keys)."""

    def __setattr__(self, key: str, value: Any) -> None:
        """Set attribute using dot notation."""

    def __delattr__(self, key: str) -> None:
        """Delete attribute using dot notation."""

def is_shared_filesystem(
    strategy: Strategy,
    path: Optional[Union[str, Path]] = None,
    timeout: int = 3
) -> bool:
    """
    Check whether the filesystem under a path is shared by all processes.

    Determines whether a path is accessible from all nodes in a
    distributed training setup (e.g., NFS, shared storage). Every process
    must call this function, after the distributed environment has been
    initialized.

    Args:
        strategy: Strategy in use, e.g. fabric.strategy
        path: Path to check; defaults to the current working directory
        timeout: Seconds to wait for the other processes to see a test
            file before concluding that the filesystem is not shared

    Returns:
        True if the filesystem is shared across all processes

    Examples:
        # Called on every process
        if is_shared_filesystem(fabric.strategy, "/shared/checkpoints"):
            # Checkpoints saved here can be loaded from any node
            ckpt_dir = "/shared/checkpoints"
        else:
            ckpt_dir = "local_checkpoints"
        # fabric.save() itself must also be called on every process
        fabric.save(f"{ckpt_dir}/model.ckpt", state)
    """

class LightningEnum(Enum):
    """
    Base enumeration class with additional utility methods.

    Extended enum class that provides helper methods for
    string conversion and validation.
    """

    @classmethod
    def from_str(cls, value: str) -> "LightningEnum":
        """Create enum from string value."""

    def __str__(self) -> str:
        """String representation of enum value."""

def disable_possible_user_warnings() -> None:
    """
    Disable possible user warnings from Lightning.

    Ignores warnings of the category `PossibleUserWarning`, which Lightning
    uses for warnings that may be false positives. For finer-grained
    control, use warnings.filterwarnings() directly.

    Examples:
        # Silence these warnings, e.g. in production
        disable_possible_user_warnings()
    """
```
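`AttributeDict` is a plain `dict` subclass, so dict values nested inside it stay ordinary dicts without dot access. The sketch below shows the pitfall and one way around it. `AttrDictSketch` is a minimal hypothetical stand-in for an attribute-access dict (so the example runs without Lightning installed), and `to_attr_dict` is a hypothetical helper that converts nested mappings recursively.

```python
from typing import Any, Mapping


class AttrDictSketch(dict):
    """Minimal stand-in for an attribute-access dict (hypothetical)."""

    def __getattr__(self, key: str) -> Any:
        try:
            return self[key]
        except KeyError as err:
            raise AttributeError(key) from err

    def __setattr__(self, key: str, value: Any) -> None:
        self[key] = value


def to_attr_dict(mapping: Mapping[str, Any]) -> AttrDictSketch:
    """Recursively convert nested mappings so every level supports dot access."""
    return AttrDictSketch(
        {k: to_attr_dict(v) if isinstance(v, Mapping) else v for k, v in mapping.items()}
    )


raw = {"model": {"hidden_size": 256}, "training": {"learning_rate": 0.001}}

flat = AttrDictSketch(raw)
print(type(flat.model).__name__)  # dict: nested values are not converted

config = to_attr_dict(raw)
print(config.model.hidden_size)   # 256
```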
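
## Usage Examples

### Reproducible Training Setup

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from lightning.fabric import Fabric, seed_everything

# Set seed for reproducibility; workers=True also seeds DataLoader workers
seed_everything(42, workers=True)

fabric = Fabric(accelerator="gpu", devices=2)
fabric.launch()  # required before setup with multi-device strategies

# Toy dataset standing in for your own
dataset = TensorDataset(torch.randn(1000, 32), torch.randint(0, 2, (1000,)))

# Fabric attaches a seeding worker_init_fn (unless you provide your own)
dataloader = fabric.setup_dataloaders(
    DataLoader(dataset, num_workers=4, shuffle=True)
)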
```
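Per-worker seeding matters because workers that start from the same RNG state would repeat the same random augmentations. One hypothetical way to derive distinct yet reproducible seeds from the global seed, the worker ID, and the rank is to hash them together; `worker_seed` and `first_draws` are illustrative helpers, while Lightning's `pl_worker_init_function` uses its own derivation and also seeds NumPy and PyTorch.

```python
import hashlib
import random
from typing import List


def worker_seed(base_seed: int, worker_id: int, rank: int) -> int:
    """Derive a 32-bit seed from the (base_seed, worker_id, rank) triple."""
    key = f"{base_seed}-{worker_id}-{rank}".encode()
    # Different triples give different seeds with overwhelming probability
    return int.from_bytes(hashlib.sha256(key).digest()[:4], "little")


def first_draws(base_seed: int, worker_id: int, rank: int, n: int = 3) -> List[int]:
    """Random numbers a given worker would draw first."""
    rng = random.Random(worker_seed(base_seed, worker_id, rank))
    return [rng.randint(0, 99) for _ in range(n)]


# Two processes with two DataLoader workers each
seeds = {(rank, wid): worker_seed(42, wid, rank) for rank in range(2) for wid in range(2)}
print(seeds)
```

Because the seed depends only on the triple, rerunning with the same global seed reproduces every worker's random stream.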

### DataLoader Worker Configuration

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from lightning.fabric.utilities import suggested_max_num_workers

# Toy dataset standing in for your own
dataset = TensorDataset(torch.randn(1000, 32))

# Upper bound on workers per process, e.g. for Fabric(devices=2) on this machine
num_workers = suggested_max_num_workers(local_world_size=2)

dataloader = DataLoader(
    dataset,
    batch_size=32,
    num_workers=num_workers,
    pin_memory=True
)
```
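The kind of heuristic behind a worker-count suggestion can be sketched as: divide the CPUs this process may use among the processes on the machine, and keep one core per process free for the main training loop. `suggest_workers` and `available_cpus` are hypothetical helpers, and this formula is an assumption; Lightning's exact computation may differ between versions.

```python
import os
from typing import Optional


def available_cpus() -> int:
    """CPUs this process may run on (respects CPU affinity where supported)."""
    if hasattr(os, "sched_getaffinity"):
        return len(os.sched_getaffinity(0))
    return os.cpu_count() or 1


def suggest_workers(local_world_size: int, num_cpus: Optional[int] = None) -> int:
    if local_world_size < 1:
        raise ValueError(f"local_world_size must be >= 1, got {local_world_size}")
    cpus = available_cpus() if num_cpus is None else num_cpus
    # Share the cores among local processes, keep one per process for the main loop
    return max(1, cpus // local_world_size - 1)


print(suggest_workers(local_world_size=2, num_cpus=16))  # 7
print(suggest_workers(local_world_size=8, num_cpus=4))   # 1
```

Treat the result as an upper bound: I/O-bound datasets may saturate with fewer workers, so it is worth benchmarking a few values below it.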

### Performance Monitoring

```python
import time

import torch
from lightning.fabric import Fabric
from lightning.fabric.utilities import ThroughputMonitor

fabric = Fabric()
fabric.launch()

# Rolling averages over the last 50 updates; metrics are logged via fabric.log_dict()
throughput = ThroughputMonitor(fabric, window_size=50)

t0 = time.perf_counter()
total_samples = 0
for step, batch in enumerate(dataloader, start=1):
    loss = train_step(model, batch)  # your own training step
    total_samples += batch[0].size(0)

    # Synchronize so that the wall-clock time includes pending GPU work
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    # Report cumulative totals, not per-batch values
    throughput.update(time=time.perf_counter() - t0, batches=step, samples=total_samples)

    if step % 50 == 0:
        metrics = throughput.compute_and_log(step=step)
        # Rate metrics appear once the window is full; "device/" marks per-device values
        if "device/samples_per_sec" in metrics:
            fabric.print(f"Step {step}: {metrics['device/samples_per_sec']:.1f} samples/sec")
```

### Device-Agnostic Data Movement

```python
import torch
from lightning.fabric import Fabric
from lightning.fabric.utilities import move_data_to_device

fabric = Fabric()

# Complex nested data structure
batch = {
    "input": torch.randn(32, 784),
    "target": torch.randint(0, 10, (32,)),
    "metadata": {
        "lengths": torch.randint(10, 100, (32,)),
        "mask": torch.ones(32, 100, dtype=torch.bool)
    }
}

# Move entire structure to the device Fabric selected
batch = move_data_to_device(batch, fabric.device)
```

### Rank-Zero Operations

```python
import json

import torch
from lightning.fabric.utilities import rank_zero_only, rank_zero_warn

@rank_zero_only
def save_artifacts(model, metrics, epoch):
    """Save model and log metrics only from rank 0."""
    torch.save(model.state_dict(), f"model_epoch_{epoch}.pth")
    with open("metrics.json", "w") as f:
        json.dump(metrics, f)

previous_loss = float("inf")

# Training loop (model, dataloader, num_epochs, and train_epoch are your own)
for epoch in range(num_epochs):
    train_metrics = train_epoch(model, dataloader)

    # Only rank 0 saves artifacts
    save_artifacts(model, train_metrics, epoch)

    # Warning only from rank 0
    if train_metrics["loss"] > previous_loss:
        rank_zero_warn("Loss increased compared to previous epoch")
    previous_loss = train_metrics["loss"]
```
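
### FLOP Measurement

```python
import torch
import torch.nn as nn
from lightning.fabric.utilities import measure_flops

# Build the model and a dummy batch on the meta device: no memory is allocated
with torch.device("meta"):
    model = nn.Sequential(
        nn.Linear(784, 256),
        nn.ReLU(),
        nn.Linear(256, 10)
    )
    x = torch.randn(32, 784)

fwd_flops = measure_flops(model, lambda: model(x))
fwd_and_bwd_flops = measure_flops(model, lambda: model(x), lambda y: y.sum())
num_params = sum(p.numel() for p in model.parameters())

print(f"Forward FLOPs per batch of 32: {fwd_flops:,}")
print(f"Forward + backward FLOPs: {fwd_and_bwd_flops:,}")
print(f"Model parameters: {num_params:,}")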
```
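As a sanity check for FLOP counts, the cost of the linear layers in the 784 → 256 → 10 model can be worked out by hand. A linear layer applied to a batch of size B multiplies a B × in matrix by an in × out matrix, which takes B·in·out multiply-adds, conventionally counted as 2·B·in·out FLOPs. `mlp_linear_flops` is a hypothetical helper that applies this formula for an assumed batch size of 32; it ignores bias additions and activations, and whether a particular counter includes those is something to verify.

```python
from typing import Sequence


def mlp_linear_flops(layer_sizes: Sequence[int], batch_size: int) -> int:
    """FLOPs of the matrix multiplies in an MLP forward pass (2 per multiply-add)."""
    return sum(
        2 * batch_size * fan_in * fan_out
        for fan_in, fan_out in zip(layer_sizes[:-1], layer_sizes[1:])
    )


flops = mlp_linear_flops([784, 256, 10], batch_size=32)
print(f"{flops:,}")  # 13,008,896
```

The first layer dominates: 2·32·784·256 = 12,845,056 of the 13,008,896 FLOPs.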

### Configuration Management

```python
import torch
from lightning.fabric.utilities import AttributeDict

# Nested sections must be AttributeDicts too: plain nested dicts have no dot access
config = AttributeDict({
    "model": AttributeDict({
        "hidden_size": 256,
        "num_layers": 3
    }),
    "training": AttributeDict({
        "learning_rate": 0.001,
        "batch_size": 32,
        "epochs": 100
    })
})

# Access using dot notation (create_model is your own model factory)
model = create_model(
    hidden_size=config.model.hidden_size,
    num_layers=config.model.num_layers
)

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=config.training.learning_rate
)
```
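
### Filesystem Utilities

```python
from lightning.fabric import Fabric
from lightning.fabric.utilities import is_shared_filesystem, rank_zero_warn

fabric = Fabric(accelerator="gpu", devices=2, num_nodes=2)
fabric.launch()

checkpoint_dir = "/shared/storage/checkpoints"
state = {"model": model, "optimizer": optimizer}  # objects set up with Fabric

# Collective call: every process must run it
if not is_shared_filesystem(fabric.strategy, checkpoint_dir):
    # With DDP only rank 0 writes, so processes on other nodes won't see the file
    rank_zero_warn(f"{checkpoint_dir} is not shared; other nodes cannot load checkpoints from it")

# fabric.save() must also be called on every process; it coordinates which ranks write
fabric.save(f"{checkpoint_dir}/model.ckpt", state)

# Wait until the checkpoint is written before anyone reads it
fabric.barrier("checkpoint_save")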
```