or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

accelerators.md core-training.md distributed.md index.md precision.md strategies.md utilities.md

docs/utilities.md

0

# Utilities

1

2

Helper functions and utilities for seeding, data movement, distributed operations, and performance monitoring.

3

4

## Capabilities

5

6

### Seeding Utilities

7

8

Functions for controlling random number generation and ensuring reproducibility.

9

10

```python { .api }

11

def seed_everything(

12

seed: Optional[int] = None,

13

workers: bool = False,

14

verbose: bool = True

15

) -> int:

16

"""

17

Set global random seeds for reproducible results.

18

19

Sets seeds for Python random, NumPy, PyTorch, and CUDA random number

20

generators to ensure reproducible training runs.

21

22

Args:

23

seed: Random seed value. If None, generates random seed

24

workers: Whether to seed DataLoader workers

25

verbose: Whether to print seed information

26

27

Returns:

28

The seed value used

29

30

Examples:

31

# Set specific seed

32

seed_everything(42)

33

34

# Generate random seed

35

used_seed = seed_everything()

36

37

# Seed DataLoader workers for complete reproducibility

38

seed_everything(42, workers=True)

39

"""

40

41

def reset_seed() -> None:

42

"""

43

Reset the random seed to the value previously set by seed_everything().

44

45

Re-applies the seed configured by the most recent seed_everything()

46

call. Does nothing if seed_everything() has not been called.

47

"""

48

49

def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None:

50

"""

51

Initialize random seeds for DataLoader workers.

52

53

Used internally by Fabric to ensure DataLoader workers have

54

different random seeds for proper data shuffling.

55

56

Args:

57

worker_id: DataLoader worker ID

58

rank: Process rank for distributed training

59

"""

60

```

61

62

### Data Movement Utilities

63

64

Functions for moving data between devices and handling device placement.

65

66

```python { .api }

67

def move_data_to_device(obj: Any, device: torch.device) -> Any:

68

"""

69

Move tensors and nested data structures to target device.

70

71

Recursively moves tensors in lists, tuples, dictionaries, and

72

custom objects to the specified device.

73

74

Args:

75

obj: Object containing tensors to move

76

device: Target device

77

78

Returns:

79

Object with tensors moved to target device

80

81

Examples:

82

# Move single tensor

83

tensor = torch.randn(10, 10)

84

tensor_gpu = move_data_to_device(tensor, torch.device("cuda"))

85

86

# Move nested data structure

87

data = {

88

"input": torch.randn(32, 784),

89

"target": torch.randint(0, 10, (32,)),

90

"metadata": {"batch_size": 32}

91

}

92

data_gpu = move_data_to_device(data, torch.device("cuda"))

93

"""

94

95

def suggested_max_num_workers(num_cpus: Optional[int] = None) -> int:

96

"""

97

Suggest optimal number of DataLoader workers.

98

99

Calculates recommended number of DataLoader workers based on

100

available CPU cores and system configuration.

101

102

Args:

103

num_cpus: Number of available CPUs (auto-detected if None)

104

105

Returns:

106

Recommended number of DataLoader workers

107

108

Examples:

109

# Auto-detect optimal workers

110

num_workers = suggested_max_num_workers()

111

dataloader = DataLoader(dataset, num_workers=num_workers)

112

113

# Use specific CPU count

114

num_workers = suggested_max_num_workers(num_cpus=8)

115

"""

116

```

117

118

### Object Wrapping Utilities

119

120

Functions for checking and managing Fabric-wrapped objects.

121

122

```python { .api }

123

def is_wrapped(obj: Any) -> bool:

124

"""

125

Check if object is wrapped by Fabric.

126

127

Determines whether a model, optimizer, or dataloader has been

128

wrapped by Fabric for distributed training.

129

130

Args:

131

obj: Object to check

132

133

Returns:

134

True if object is Fabric-wrapped, False otherwise

135

136

Examples:

137

model = nn.Linear(10, 1)

138

print(is_wrapped(model)) # False

139

140

fabric = Fabric()

141

wrapped_model = fabric.setup_module(model)

142

print(is_wrapped(wrapped_model)) # True

143

"""

144

145

def _unwrap_objects(collection: Any) -> Any:

146

"""

147

Unwrap Fabric-wrapped objects in nested collections.

148

149

Recursively unwraps Fabric objects in lists, tuples, dicts,

150

returning the underlying PyTorch objects.

151

152

Args:

153

collection: Collection potentially containing wrapped objects

154

155

Returns:

156

Collection with unwrapped objects

157

"""

158

```

159

160

### Distributed Utilities

161

162

Helper functions for distributed training operations.

163

164

```python { .api }

165

class DistributedSamplerWrapper:

166

"""

167

Wrapper for PyTorch samplers to work with distributed training.

168

169

Automatically handles epoch setting and distributed sampling

170

for custom samplers in distributed environments.

171

"""

172

173

def __init__(self, sampler: Sampler, **kwargs):

174

"""

175

Initialize distributed sampler wrapper.

176

177

Args:

178

sampler: Base sampler to wrap

179

**kwargs: Additional arguments for DistributedSampler

180

"""

181

182

def set_epoch(self, epoch: int) -> None:

183

"""Set epoch for proper shuffling in distributed training."""

184

185

class _InfiniteBarrier:

186

"""

187

Barrier implementation that works across different process groups.

188

Used internally for synchronizing processes in complex distributed setups.

189

"""

190

191

def __call__(self) -> None:

192

"""Execute barrier synchronization."""

193

```

194

195

### Rank-Zero Utilities

196

197

Functions that only execute on the rank-0 process in distributed training.

198

199

```python { .api }

200

def rank_zero_only(fn: callable) -> callable:

201

"""

202

Decorator to execute function only on rank 0.

203

204

Args:

205

fn: Function to wrap

206

207

Returns:

208

Decorated function that only executes on rank 0

209

210

Examples:

211

@rank_zero_only

212

def save_model(model, path):

213

torch.save(model.state_dict(), path)

214

215

# Only rank 0 will save the model

216

save_model(model, "model.pth")

217

"""

218

219

def rank_zero_warn(message: str, category: Warning = UserWarning, stacklevel: int = 1) -> None:

220

"""

221

Issue warning only from rank 0 process.

222

223

Args:

224

message: Warning message

225

category: Warning category

226

stacklevel: Stack level for warning location

227

228

Examples:

229

rank_zero_warn("This is a warning from rank 0 only")

230

"""

231

232

def rank_zero_info(message: str) -> None:

233

"""

234

Log info message only from rank 0 process.

235

236

Args:

237

message: Info message to log

238

239

Examples:

240

rank_zero_info("Training started")

241

"""

242

243

def rank_zero_deprecation(message: str) -> None:

244

"""

245

Issue deprecation warning only from rank 0 process.

246

247

Args:

248

message: Deprecation message

249

250

Examples:

251

rank_zero_deprecation("This function is deprecated, use new_function() instead")

252

"""

253

```

254

255

### Performance Monitoring

256

257

Classes and functions for monitoring training performance and throughput.

258

259

```python { .api }

260

class Throughput:

261

"""

262

Throughput measurement utility.

263

264

Measures processing throughput (samples/second) during training

265

with automatic timing and averaging.

266

"""

267

268

def __init__(self, window_size: int = 100):

269

"""

270

Initialize throughput monitor.

271

272

Args:

273

window_size: Number of measurements to average over

274

"""

275

276

def update(self, batch_size: int) -> None:

277

"""

278

Update throughput measurement with new batch.

279

280

Args:

281

batch_size: Size of processed batch

282

"""

283

284

def compute(self) -> float:

285

"""

286

Compute current throughput.

287

288

Returns:

289

Throughput in samples per second

290

"""

291

292

def reset(self) -> None:

293

"""Reset throughput measurements."""

294

295

class ThroughputMonitor:

296

"""

297

Advanced throughput monitoring with multiple metrics.

298

299

Tracks various performance metrics including samples/second,

300

batches/second, and GPU utilization over time.

301

"""

302

303

def __init__(

304

self,

305

window_size: int = 100,

306

log_interval: int = 50

307

):

308

"""

309

Initialize throughput monitor.

310

311

Args:

312

window_size: Measurement window size

313

log_interval: Logging interval in steps

314

"""

315

316

def on_batch_end(

317

self,

318

batch_size: int,

319

num_samples: int,

320

step: int

321

) -> None:

322

"""Called at the end of each training batch."""

323

324

def get_metrics(self) -> dict[str, float]:

325

"""Get current performance metrics."""

326

327

def measure_flops(

328

model: nn.Module,

329

input_shape: tuple[int, ...],

330

device: Optional[torch.device] = None

331

) -> dict[str, Union[int, float]]:

332

"""

333

Measure FLOPs (floating point operations) for model inference.

334

335

Estimates computational complexity by measuring FLOPs required

336

for a forward pass with given input shape.

337

338

Args:

339

model: PyTorch model to analyze

340

input_shape: Shape of input tensor (excluding batch dimension)

341

device: Device to run measurement on

342

343

Returns:

344

Dictionary with FLOP measurements and model statistics

345

346

Examples:

347

# Measure FLOPs for image classification model

348

flops = measure_flops(model, (3, 224, 224))

349

print(f"Model requires {flops['flops']:,} FLOPs")

350

351

# Measure FLOPs for text model

352

flops = measure_flops(model, (512,)) # sequence length 512

353

"""

354

```

355

356

### General Utilities

357

358

Miscellaneous utility classes and functions.

359

360

```python { .api }

361

class AttributeDict(dict):

362

"""

363

Dictionary that allows attribute-style access to keys.

364

365

Enables accessing dictionary values using dot notation

366

in addition to standard dictionary access.

367

368

Examples:

369

config = AttributeDict({"learning_rate": 0.001, "batch_size": 32})

370

print(config.learning_rate) # 0.001

371

print(config["batch_size"]) # 32

372

373

config.epochs = 100

374

print(config["epochs"]) # 100

375

"""

376

377

def __getattr__(self, key: str) -> Any:

378

"""Get attribute using dot notation."""

379

380

def __setattr__(self, key: str, value: Any) -> None:

381

"""Set attribute using dot notation."""

382

383

def __delattr__(self, key: str) -> None:

384

"""Delete attribute using dot notation."""

385

386

def is_shared_filesystem(path: Union[str, Path]) -> bool:

387

"""

388

Check if path is on a shared filesystem across nodes.

389

390

Determines whether a path is accessible from all nodes in

391

a distributed training setup (e.g., NFS, shared storage).

392

393

Args:

394

path: Path to check

395

396

Returns:

397

True if filesystem is shared across nodes

398

399

Examples:

400

if is_shared_filesystem("/shared/checkpoints"):

401

# Can save checkpoint from any node

402

fabric.save("/shared/checkpoints/model.ckpt", state)

403

else:

404

# Save checkpoint only from rank 0

405

if fabric.is_global_zero:

406

fabric.save("local_model.ckpt", state)

407

"""

408

409

class LightningEnum(Enum):

410

"""

411

Base enumeration class with additional utility methods.

412

413

Extended enum class that provides helper methods for

414

string conversion and validation.

415

"""

416

417

@classmethod

418

def from_str(cls, value: str) -> "LightningEnum":

419

"""Create enum from string value."""

420

421

def __str__(self) -> str:

422

"""String representation of enum value."""

423

424

def disable_possible_user_warnings() -> None:

425

"""

426

Disable possible user warnings from Lightning.

427

428

Suppresses warnings that may be triggered by user code

429

but are not critical for operation.

430

431

Examples:

432

# Disable warnings in production

433

disable_possible_user_warnings()

434

"""

435

```

436

437

## Usage Examples

438

439

### Reproducible Training Setup

440

441

```python

442

from lightning.fabric import Fabric, seed_everything

443

444

# Set seed for reproducibility

445

seed_everything(42, workers=True)

446

447

fabric = Fabric(accelerator="gpu", devices=2)

448

449

# DataLoader will automatically use seeded workers

450

dataloader = fabric.setup_dataloaders(

451

DataLoader(dataset, num_workers=4, shuffle=True)

452

)

453

```

454

455

### Optimal DataLoader Configuration

456

457

```python

458

from lightning.fabric.utilities import suggested_max_num_workers

459

460

# Get optimal number of workers

461

num_workers = suggested_max_num_workers()

462

463

dataloader = DataLoader(

464

dataset,

465

batch_size=32,

466

num_workers=num_workers,

467

pin_memory=True

468

)

469

```

470

471

### Performance Monitoring

472

473

```python

474

from lightning.fabric.utilities import ThroughputMonitor

475

476

# Initialize performance monitor

477

throughput = ThroughputMonitor(window_size=100, log_interval=50)

478

479

# Training loop with monitoring

480

for step, batch in enumerate(dataloader):

481

batch_size = batch[0].size(0)

482

483

# Training step

484

loss = train_step(model, batch)

485

486

# Update throughput monitoring

487

throughput.on_batch_end(

488

batch_size=batch_size,

489

num_samples=batch_size,

490

step=step

491

)

492

493

if step % 50 == 0:

494

metrics = throughput.get_metrics()

495

fabric.print(f"Step {step}: {metrics['samples_per_sec']:.1f} samples/sec")

496

```

497

498

### Device-Agnostic Data Movement

499

500

```python

501

from lightning.fabric.utilities import move_data_to_device

502

503

# Complex nested data structure

504

batch = {

505

"input": torch.randn(32, 784),

506

"target": torch.randint(0, 10, (32,)),

507

"metadata": {

508

"lengths": torch.randint(10, 100, (32,)),

509

"mask": torch.ones(32, 100, dtype=torch.bool)

510

}

511

}

512

513

# Move entire structure to device

514

device = fabric.device

515

batch = move_data_to_device(batch, device)

516

```

517

518

### Rank-Zero Operations

519

520

```python

521

from lightning.fabric.utilities import rank_zero_only, rank_zero_warn

522

523

@rank_zero_only

524

def save_artifacts(model, metrics, epoch):

525

"""Save model and log metrics only from rank 0."""

526

torch.save(model.state_dict(), f"model_epoch_{epoch}.pth")

527

with open("metrics.json", "w") as f:

528

json.dump(metrics, f)

529

530

# Training loop

531

for epoch in range(num_epochs):

532

train_metrics = train_epoch(model, dataloader)

533

534

# Only rank 0 saves artifacts

535

save_artifacts(model, train_metrics, epoch)

536

537

# Warning only from rank 0

538

if train_metrics["loss"] > previous_loss:

539

rank_zero_warn("Loss increased compared to previous epoch")

540

```

541

542

### FLOP Measurement

543

544

```python

545

from lightning.fabric.utilities import measure_flops

546

547

# Measure model complexity

548

model = nn.Sequential(

549

nn.Linear(784, 256),

550

nn.ReLU(),

551

nn.Linear(256, 10)

552

)

553

554

flops_info = measure_flops(model, (784,))

555

fabric.print(f"Model FLOPs: {flops_info['flops']:,}")

556

fabric.print(f"Model parameters: {flops_info['params']:,}")

557

```

558

559

### Configuration Management

560

561

```python

562

from lightning.fabric.utilities import AttributeDict

563

564

# Configuration with attribute access

565

config = AttributeDict({

566

"model": {

567

"hidden_size": 256,

568

"num_layers": 3

569

},

570

"training": {

571

"learning_rate": 0.001,

572

"batch_size": 32,

573

"epochs": 100

574

}

575

})

576

577

# Access using dot notation

578

model = create_model(

579

hidden_size=config.model.hidden_size,

580

num_layers=config.model.num_layers

581

)

582

583

optimizer = torch.optim.Adam(

584

model.parameters(),

585

lr=config.training.learning_rate

586

)

587

```

588

589

### Filesystem Utilities

590

591

```python

592

from lightning.fabric.utilities import is_shared_filesystem

593

594

checkpoint_path = "/shared/storage/checkpoints"

595

596

if is_shared_filesystem(checkpoint_path):

597

# All nodes can access this path

598

fabric.save(f"{checkpoint_path}/model.ckpt", state)

599

else:

600

# Use local storage with rank coordination

601

if fabric.is_global_zero:

602

fabric.save("model.ckpt", state)

603

604

# Wait for rank 0 to finish saving

605

fabric.barrier("checkpoint_save")

606

```