or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

accelerators.mdcallbacks.mdcore-training.mddata.mdfabric.mdindex.mdloggers.mdprecision.mdprofilers.mdstrategies.md

fabric.mddocs/

# Lightning Fabric

Lightweight training-acceleration framework that gives you expert-level control over training loops, device management, and distributed strategies without high-level abstractions. Fabric combines the flexibility of raw PyTorch with the power of Lightning's optimizations.

## Capabilities

### Fabric Core

The main `Fabric` class, which accelerates PyTorch training with distributed training, mixed precision, and device management while you keep full control over the training loop.

```python { .api }

class Fabric:
    """Accelerate a hand-written PyTorch training loop.

    Fabric takes care of device placement, distributed strategies, and numeric
    precision, while the training loop itself stays under your control.
    """

    def __init__(
        self,
        accelerator: str = "auto",
        strategy: str = "auto",
        devices: Union[List[int], str, int] = "auto",
        num_nodes: int = 1,
        precision: Union[str, int] = "32-true",
        plugins: Optional[Union["Plugin", List["Plugin"]]] = None,
        callbacks: Optional[Union["Callback", List["Callback"]]] = None,
        loggers: Optional[Union["Logger", List["Logger"]]] = None,
        **kwargs
    ):
        """
        Initialize Fabric for training acceleration.

        Args:
            accelerator: Hardware accelerator ('cpu', 'gpu', 'tpu', 'auto')
            strategy: Distributed strategy ('ddp', 'fsdp', 'deepspeed', etc.)
            devices: Which devices to use
            num_nodes: Number of nodes for distributed training
            precision: Precision mode ('32-true', '16-mixed', 'bf16-mixed', etc.)
            plugins: Additional plugins for customization
            callbacks: Callbacks for training lifecycle hooks
            loggers: Loggers for experiment tracking. ``log``/``log_dict`` only
                forward to these loggers, so without any they record nothing.
        """

    def launch(
        self,
        function: Optional[Callable[..., Any]] = None,
        *args: Any,
        **kwargs: Any
    ) -> Any:
        """
        Launch and initialize the processes needed for distributed execution.

        Call this right after constructing Fabric when accelerator/devices/strategy
        are configured in code and the script is started with plain ``python``.
        Multi-device setups must be launched before ``setup`` is called. Do not call
        it when the script is started through the ``fabric run`` CLI.

        Args:
            function: Optional function to run in each process for spawn/fork-based
                strategies; it receives the Fabric object as its first argument
            *args: Positional arguments forwarded to ``function``
            **kwargs: Keyword arguments forwarded to ``function``

        Returns:
            The return value of ``function`` on the rank-0 process
        """

    def setup(
        self,
        model: nn.Module,
        *optimizers: Optimizer,
        move_to_device: bool = True
    ) -> Union[nn.Module, Tuple[Union[nn.Module, Optimizer], ...]]:
        """
        Set up a model and (optionally) its optimizers for training.

        Args:
            model: PyTorch model to accelerate
            *optimizers: Optimizers to set up
            move_to_device: If True, move the model to Fabric's device; set to False
                to place it yourself (e.g. with ``to_device``)

        Returns:
            The wrapped model alone if no optimizers were given, otherwise a tuple of
            the wrapped model followed by the wrapped optimizers in the order passed
        """

    def setup_dataloaders(
        self,
        *dataloaders: DataLoader,
        use_distributed_sampler: bool = True,
        move_to_device: bool = True
    ) -> Union[DataLoader, List[DataLoader]]:
        """
        Set up data loaders for distributed training.

        Args:
            *dataloaders: Data loaders to set up
            use_distributed_sampler: If True, wrap or replace each loader's sampler
                so every process reads its own shard; set to False for custom samplers
            move_to_device: If True, batches are moved to Fabric's device automatically

        Returns:
            The wrapped data loader if a single one was passed, otherwise a list of
            wrapped data loaders in the order passed
        """

    def backward(
        self,
        tensor: Tensor,
        *args: Any,
        model: Optional[nn.Module] = None,
        **kwargs: Any
    ) -> None:
        """
        Perform backward pass with proper scaling and synchronization.

        Use this instead of ``tensor.backward()`` so precision plugins (loss
        scaling) and strategies can hook into the backward pass.

        Args:
            tensor: Loss tensor to compute gradients from
            *args: Extra positional arguments forwarded to the underlying backward
            model: Model the loss was computed from; only needed by strategies that
                require it (e.g. DeepSpeed with several models set up)
            **kwargs: Extra keyword arguments forwarded to the underlying backward
        """

    def clip_gradients(
        self,
        model: nn.Module,
        optimizer: Optimizer,
        max_norm: Union[float, int],
        norm_type: Union[float, int] = 2.0,
        error_if_nonfinite: bool = True
    ) -> Tensor:
        """
        Clip gradients by norm.

        Args:
            model: Model whose gradients to clip
            optimizer: Optimizer being used
            max_norm: Maximum norm for gradients
            norm_type: Type of norm to use
            error_if_nonfinite: Raise error if gradients are non-finite

        Returns:
            Total norm of the gradients
        """

    def all_gather(
        self,
        tensor: Tensor,
        group: Optional[Any] = None,
        sync_grads: bool = False
    ) -> Tensor:
        """
        Gather tensors from all processes.

        Args:
            tensor: Tensor to gather
            group: Process group
            sync_grads: Synchronize gradients

        Returns:
            Gathered tensor. With more than one process the per-process results are
            stacked along a new leading dimension of size ``world_size``.
        """

    def all_reduce(
        self,
        tensor: Tensor,
        group: Optional[Any] = None,
        reduce_op: str = "mean"
    ) -> Tensor:
        """
        Reduce tensor across all processes.

        This is a collective call: every process must reach it.

        Args:
            tensor: Tensor to reduce
            group: Process group
            reduce_op: Reduction operation ('mean', 'sum')

        Returns:
            Reduced tensor
        """

    def broadcast(self, tensor: Tensor, src: int = 0) -> Tensor:
        """
        Broadcast tensor from source process to all processes.

        Args:
            tensor: Tensor to broadcast
            src: Source rank

        Returns:
            Broadcasted tensor
        """

    def barrier(self, name: Optional[str] = None) -> None:
        """
        Synchronize all processes; each process waits here until all have arrived.

        Args:
            name: Optional barrier name for debugging
        """

    @property
    def is_global_zero(self) -> bool:
        """Whether this process is global rank 0.

        This is a property: use ``fabric.is_global_zero`` without parentheses.
        """

    def print(self, *args, **kwargs) -> None:
        """
        Print only from the first process on each machine (local rank 0).

        Args:
            *args: Arguments to print
            **kwargs: Keyword arguments for print
        """

    def log(self, name: str, value: Any, step: Optional[int] = None) -> None:
        """
        Log a scalar metric to all loggers passed to Fabric.

        The value is NOT reduced across processes; call ``all_reduce`` first if a
        global average is wanted. Nothing is recorded if no loggers were configured.

        Args:
            name: Metric name
            value: Metric value
            step: Optional step number
        """

    def log_dict(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
        """
        Log a dictionary of metrics to all loggers passed to Fabric.

        Values are NOT reduced across processes; call ``all_reduce`` first if
        global averages are wanted. Nothing is recorded if no loggers were configured.

        Args:
            metrics: Dictionary of metrics
            step: Optional step number
        """

    def save(self, path: str, state: Dict[str, Any]) -> None:
        """
        Save checkpoint.

        Must be called on ALL processes; Fabric decides which process(es) write to
        disk. Put the model, optimizer and other stateful objects into ``state``
        directly (not their ``state_dict()``) so Fabric can extract their state
        correctly, including for sharded strategies such as FSDP.

        Args:
            path: Path to save checkpoint
            state: Dictionary of objects and values to save
        """

    def load(
        self,
        path: str,
        state: Optional[Dict[str, Any]] = None,
        strict: bool = True
    ) -> Dict[str, Any]:
        """
        Load checkpoint.

        Args:
            path: Path to load checkpoint from
            state: Optional dictionary of objects (model, optimizer, ...) whose state
                is restored in place from the checkpoint
            strict: Whether the keys in ``state`` must match the keys in the checkpoint

        Returns:
            The checkpoint items that were not restored into ``state``; the full
            checkpoint if no ``state`` was given
        """

    @property
    def device(self) -> torch.device:
        """Get the current device."""

    @property
    def global_rank(self) -> int:
        """Get global rank of current process."""

    @property
    def local_rank(self) -> int:
        """Get local rank of current process."""

    @property
    def node_rank(self) -> int:
        """Get node rank of current process."""

    @property
    def world_size(self) -> int:
        """Get total number of processes."""

    def to_device(self, obj: Any) -> Any:
        """
        Move object to device.

        Args:
            obj: Object to move to device

        Returns:
            Object on the device
        """

```

### Utility Functions

Core utility functions for reproducibility (`seed_everything`) and for checking whether objects have been set up by Fabric (`is_wrapped`).

```python { .api }

def seed_everything(seed: Optional[int] = None, workers: bool = False, verbose: bool = True) -> int:
    """
    Set random seeds for reproducibility.

    Seeds the pseudo-random number generators of PyTorch, NumPy and Python's
    ``random`` module.

    Args:
        seed: Random seed to set. If None, the seed is read from the
            ``PL_GLOBAL_SEED`` environment variable, falling back to a default
            (0 in recent Lightning releases) when that variable is unset
        workers: Also derive per-worker seeds for data loader workers of loaders
            set up by Fabric
        verbose: Whether to print a message on each rank with the seed being set

    Returns:
        The seed that was set
    """

def is_wrapped(obj: Any) -> bool:
    """
    Report whether ``obj`` has already been set up (wrapped) by Fabric.

    Args:
        obj: The object to inspect, e.g. a model, optimizer or data loader that
            may have gone through ``Fabric.setup`` / ``Fabric.setup_dataloaders``

    Returns:
        ``True`` when Fabric wrapped the object, ``False`` otherwise
    """
    ...

```

## Basic Usage Example

```python

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from lightning import Fabric
from lightning.fabric import seed_everything
from lightning.fabric.loggers import CSVLogger

# Initialize Fabric. fabric.log() only forwards metrics to configured loggers,
# so pass one here; without loggers every log call is silently ignored.
fabric = Fabric(accelerator="gpu", devices=2, strategy="ddp", loggers=CSVLogger("logs"))

# Start the worker processes. Required when devices/strategy are set in code and
# the script is run with plain `python` (omit it when using the `fabric run` CLI).
fabric.launch()

# Same seed on every process, so all ranks build the identical random dataset below
seed_everything(42)

# Define model and optimizer
model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Setup model and optimizer with Fabric
model, optimizer = fabric.setup(model, optimizer)

# Create sample data and dataloader
data = torch.randn(1000, 10)
targets = torch.randn(1000, 1)
dataset = TensorDataset(data, targets)
dataloader = DataLoader(dataset, batch_size=32)

# Setup dataloader: shards the data across processes and moves batches to the device
dataloader = fabric.setup_dataloaders(dataloader)

# Training loop with full control
for epoch in range(10):
    for batch_idx, (x, y) in enumerate(dataloader):
        optimizer.zero_grad()

        # Forward pass
        y_pred = model(x)
        loss = nn.functional.mse_loss(y_pred, y)

        # Backward pass - Fabric handles scaling and synchronization
        fabric.backward(loss)

        optimizer.step()

        # Log metrics (values are per-process; they are not averaged across ranks)
        if batch_idx % 10 == 0:
            fabric.log("train_loss", loss.item(), step=epoch * len(dataloader) + batch_idx)
            fabric.print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item()}")

    # Save checkpoint. Call fabric.save on every process (Fabric writes from rank 0)
    # and pass the objects themselves; Fabric extracts their state dicts.
    state = {
        "model": model,
        "optimizer": optimizer,
        "epoch": epoch,
    }
    fabric.save(f"checkpoint_epoch_{epoch}.ckpt", state)

```

## Advanced Usage Example

```python

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from lightning import Fabric
from lightning.fabric import seed_everything
from lightning.fabric.loggers import CSVLogger

num_epochs = 100

# Initialize Fabric with advanced configuration
fabric = Fabric(
    accelerator="gpu",
    devices=4,
    strategy="fsdp",
    precision="16-mixed",
    plugins=None,
    # log()/log_dict() only forward to loggers; without one they record nothing
    loggers=CSVLogger("logs"),
)

# Start the worker processes; must happen before setup() when the script is run
# with plain `python` (omit it when launching through the `fabric run` CLI).
fabric.launch()

# Same seed on every rank so the synthetic dataset below is identical everywhere
seed_everything(42)


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 10),
        )

    def forward(self, x):
        return self.layers(x)


# Synthetic MNIST-shaped data (784 features, 10 classes); replace with a real dataset
train_dataset = TensorDataset(torch.randn(6400, 784), torch.randint(0, 10, (6400,)))
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
# Shard the data across processes and move each batch to the right device
train_dataloader = fabric.setup_dataloaders(train_dataloader)

# Model and optimizers
model = MyModel()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

# Setup with Fabric
model, optimizer = fabric.setup(model, optimizer)

# Training loop with advanced features
for epoch in range(num_epochs):
    model.train()

    for batch_idx, (data, target) in enumerate(train_dataloader):
        optimizer.zero_grad()

        output = model(data)
        loss = nn.functional.cross_entropy(output, target)

        # Backward with automatic mixed precision (loss scaling handled by Fabric)
        fabric.backward(loss)

        # Gradient clipping
        fabric.clip_gradients(model, optimizer, max_norm=1.0)

        optimizer.step()

        # Metrics logging
        if batch_idx % 100 == 0:
            accuracy = (output.argmax(dim=1) == target).float().mean()

            # log_dict does not reduce across processes, so average explicitly.
            # all_reduce is collective; every rank reaches this branch at the same
            # batch_idx because the distributed sampler gives all ranks equal-length shards.
            avg_loss = fabric.all_reduce(loss.detach(), reduce_op="mean")
            avg_acc = fabric.all_reduce(accuracy, reduce_op="mean")

            fabric.log_dict(
                {
                    "train_loss": avg_loss.item(),
                    "train_acc": avg_acc.item(),
                    "lr": scheduler.get_last_lr()[0],
                },
                step=epoch * len(train_dataloader) + batch_idx,
            )

            # Print only on local rank 0
            fabric.print(
                f"Epoch {epoch}/{num_epochs}, Batch {batch_idx}, "
                f"Loss: {avg_loss.item():.4f}, Acc: {avg_acc.item():.4f}"
            )

    scheduler.step()

    # Synchronization barrier
    fabric.barrier()

    # Save checkpoint on ALL ranks: with FSDP, saving is a collective operation, and
    # Fabric needs the model/optimizer objects themselves (not their state_dict()) to
    # gather or shard their state. Fabric decides which process(es) write to disk.
    checkpoint = {
        "model": model,
        "optimizer": optimizer,
        "scheduler": scheduler,
        "epoch": epoch,
    }
    fabric.save(f"model_epoch_{epoch}.ckpt", checkpoint)

```