# Lightning Fabric

Lightning Fabric is a lightweight PyTorch scaling library that gives you expert-level control over your PyTorch training loop and scaling strategy. It lets you scale complex models, including foundation models, LLMs, diffusion models, transformers, and reinforcement learning models, to any device and any scale without boilerplate code.

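## Package Information

- **Package Name**: lightning-fabric
- **Language**: Python
- **Installation**: `pip install lightning-fabric`

The `lightning.fabric` import paths used on this page come from the unified `lightning` package (`pip install lightning`). The standalone `lightning-fabric` distribution provides the same modules under the top-level name `lightning_fabric` (for example, `from lightning_fabric import Fabric`).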

## Core Imports

```python
from lightning.fabric import Fabric, seed_everything, is_wrapped
```

Additional commonly used utilities:

```python
from lightning.fabric.utilities import (
    move_data_to_device,
    suggested_max_num_workers,
    rank_zero_only,
    rank_zero_warn,
    rank_zero_info,
)
```

## Basic Usage

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from lightning.fabric import Fabric

# Initialize Fabric with the desired configuration
fabric = Fabric(accelerator="auto", devices="auto", strategy="auto")
# Start the processes; required for multi-device strategies such as DDP, harmless on a single device
fabric.launch()

# Define model and optimizer
model = nn.Linear(10, 1)
optimizer = torch.optim.AdamW(model.parameters())

# Set up with Fabric (handles device placement and distributed wrapping)
model, optimizer = fabric.setup(model, optimizer)

# Set up the dataloader (replace ... with your own Dataset and DataLoader arguments)
dataloader = fabric.setup_dataloaders(DataLoader(...))

# Training loop
model.train()
for batch in dataloader:
    x, y = batch
    optimizer.zero_grad()

    # Forward pass
    y_pred = model(x)
    loss = nn.functional.mse_loss(y_pred, y)

    # Backward pass: replaces loss.backward() so Fabric can apply precision scaling
    # and strategy-specific gradient handling
    fabric.backward(loss)
    optimizer.step()

# Save checkpoint
state = {"model": model, "optimizer": optimizer}
fabric.save("checkpoint.ckpt", state)
```

## Architecture

Lightning Fabric uses a plugin-based architecture that enables flexible scaling and customization:

- **Fabric**: Main orchestrator class that coordinates all components
- **Accelerators**: Hardware abstraction (CPU, GPU, TPU, MPS)
- **Strategies**: Distribution patterns (single device, data parallel, model parallel, FSDP, DeepSpeed)
- **Precision Plugins**: Mixed precision and quantization support
- **Environment Plugins**: Cluster environment detection and configuration
- **Loggers**: Experiment tracking and metric logging
- **Wrappers**: Transparent wrapping of PyTorch objects for distributed training

This plugin system lets Fabric run on any supported hardware and distributed setup while keeping the same simple API.

## Capabilities

### Core Training Orchestration

The main `Fabric` class handles distributed training setup, model and optimizer configuration, checkpoint management, and training utilities.

```python { .api }
class Fabric:
    def __init__(
        self,
        accelerator: Union[str, Accelerator] = "auto",
        strategy: Union[str, Strategy] = "auto",
        devices: Union[list[int], str, int] = "auto",
        num_nodes: int = 1,
        precision: Optional[Union[str, int]] = None,
        plugins: Optional[Union[Any, list[Any]]] = None,
        callbacks: Optional[Union[list[Any], Any]] = None,
        loggers: Optional[Union[Logger, list[Logger]]] = None,
    ): ...

    def setup(self, module, *optimizers, move_to_device=True): ...
    def setup_module(self, module, move_to_device=True): ...
    def setup_optimizers(self, *optimizers): ...
    def setup_dataloaders(self, *dataloaders, use_distributed_sampler=True): ...
    def backward(self, tensor, *args, model=None, **kwargs): ...
    def save(self, path, state, filter=None): ...
    def load(self, path, state=None, strict=True): ...
```

[Core Training](./core-training.md)

### Distributed Operations

Collective communication operations for synchronizing data and gradients across processes in distributed training.

```python { .api }
# Methods on Fabric
def barrier(self, name=None) -> None: ...
def broadcast(self, obj, src=0): ...
def all_gather(self, data, group=None, sync_grads=False): ...
def all_reduce(self, data, group=None, reduce_op="mean"): ...
```

[Distributed Operations](./distributed.md)

### Accelerators

Hardware acceleration plugins for different compute devices, including CPUs, CUDA GPUs, Apple MPS, and TPUs.

```python { .api }
class Accelerator: ...  # Abstract base
class CPUAccelerator(Accelerator): ...
class CUDAAccelerator(Accelerator): ...
class MPSAccelerator(Accelerator): ...
class XLAAccelerator(Accelerator): ...
```

[Accelerators](./accelerators.md)

### Strategies

Distributed training strategies for scaling models across devices and nodes.

```python { .api }
class Strategy: ...  # Abstract base
class SingleDeviceStrategy(Strategy): ...
class DataParallelStrategy(Strategy): ...
class DDPStrategy(Strategy): ...
class DeepSpeedStrategy(Strategy): ...
class FSDPStrategy(Strategy): ...
class XLAStrategy(Strategy): ...
```

[Strategies](./strategies.md)

### Precision and Quantization

Precision plugins for mixed precision training, quantization, and memory optimization.

```python { .api }
class Precision: ...  # Abstract base
class DoublePrecision(Precision): ...
class HalfPrecision(Precision): ...
class MixedPrecision(Precision): ...
class BitsandbytesPrecision(Precision): ...
class DeepSpeedPrecision(Precision): ...
class FSDPPrecision(Precision): ...
```

[Precision](./precision.md)

### Utilities

Helper functions for seeding, data movement, distributed utilities, and performance monitoring.

```python { .api }
def seed_everything(seed=None, workers=False, verbose=True) -> int: ...
def is_wrapped(obj) -> bool: ...
def move_data_to_device(batch, device): ...
def suggested_max_num_workers(local_world_size) -> int: ...
```

[Utilities](./utilities.md)

## Types

```python { .api }
from collections import defaultdict
from pathlib import Path
from typing import Any, Callable, Iterator, Optional, Protocol, TypeVar, Union, runtime_checkable

import torch
from torch import Tensor

_DictKey = TypeVar("_DictKey")

# Common type aliases used throughout the API
_PATH = Union[str, Path]
_DEVICE = Union[torch.device, str, int]
_MAP_LOCATION_TYPE = Optional[Union[_DEVICE, Callable, dict[_DEVICE, _DEVICE]]]
_PARAMETERS = Iterator[torch.nn.Parameter]
ReduceOp = torch.distributed.ReduceOp
RedOpType = ReduceOp.RedOpType

# Protocols for type checking
@runtime_checkable
class _Stateful(Protocol[_DictKey]):
    def state_dict(self) -> dict[_DictKey, Any]: ...
    def load_state_dict(self, state_dict: dict[_DictKey, Any]) -> None: ...

@runtime_checkable
class Steppable(Protocol):
    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: ...

@runtime_checkable
class Optimizable(Steppable, Protocol):
    param_groups: list[dict[Any, Any]]
    defaults: dict[Any, Any]
    state: defaultdict[Tensor, Any]

    def state_dict(self) -> dict[str, dict[Any, Any]]: ...
    def load_state_dict(self, state_dict: dict[str, dict[Any, Any]]) -> None: ...

@runtime_checkable
class CollectibleGroup(Protocol):
    def size(self) -> int: ...
    def rank(self) -> int: ...
```