# Accelerators

Hardware acceleration plugins that provide device abstraction and optimizations for different compute platforms.

## Capabilities

### Base Accelerator

Abstract base class defining the accelerator interface.

```python { .api }
class Accelerator:
    """
    Abstract base class for hardware accelerators.

    Accelerators handle device detection, setup, and hardware-specific
    optimizations for training and inference.
    """

    def setup_device(self, device: torch.device) -> None:
        """Setup the device for training."""

    def teardown(self) -> None:
        """Clean up accelerator resources."""

    def parse_devices(self, devices: Any) -> Any:
        """Parse device specification into concrete device list."""

    def get_parallel_devices(self, devices: Any) -> list[torch.device]:
        """Get list of devices for parallel training."""

    def auto_device_count(self) -> int:
        """Get number of available devices."""

    def is_available(self) -> bool:
        """Check if accelerator is available on current system."""

    @staticmethod
    def register_accelerators() -> None:
        """Register accelerator in global registry."""
```

### CPU Accelerator

CPU-based training acceleration with optimizations for CPU hardware.

```python { .api }
class CPUAccelerator(Accelerator):
    """
    CPU accelerator for training on CPU hardware.

    Provides CPU-specific optimizations and multi-threading support.
    """

    def setup_device(self, device: torch.device) -> None:
        """Setup CPU device with optimal threading configuration."""

    def is_available(self) -> bool:
        """CPU is always available."""

    def auto_device_count(self) -> int:
        """Returns 1 for CPU (single logical device)."""

    def parse_devices(self, devices: Any) -> int:
        """Parse CPU device specification."""

    def get_parallel_devices(self, devices: Any) -> list[torch.device]:
        """Get CPU device list for parallel training."""
```

### CUDA Accelerator

NVIDIA GPU acceleration with CUDA support and GPU-specific optimizations.

```python { .api }
class CUDAAccelerator(Accelerator):
    """
    CUDA accelerator for NVIDIA GPU training.

    Provides GPU memory management, multi-GPU support, and CUDA optimizations.
    """

    def setup_device(self, device: torch.device) -> None:
        """Setup CUDA device with memory and compute optimizations."""

    def is_available(self) -> bool:
        """Check if CUDA is available and GPUs are present."""

    def auto_device_count(self) -> int:
        """Get number of available CUDA devices."""

    def parse_devices(self, devices: Any) -> Union[int, list[int]]:
        """Parse GPU device specification (IDs, count, etc.)."""

    def get_parallel_devices(self, devices: Any) -> list[torch.device]:
        """Get list of CUDA devices for parallel training."""

    def get_device_stats(self, device: torch.device) -> dict[str, Any]:
        """Get GPU memory and utilization statistics."""

    def empty_cache(self) -> None:
        """Clear GPU memory cache."""

    def set_cuda_device(self, device: torch.device) -> None:
        """Set current CUDA device."""
```

### MPS Accelerator

Apple Silicon GPU acceleration using Metal Performance Shaders.

```python { .api }
class MPSAccelerator(Accelerator):
    """
    MPS (Metal Performance Shaders) accelerator for Apple Silicon.

    Provides GPU acceleration on Apple M1/M2/M3 chips using Metal framework.
    """

    def setup_device(self, device: torch.device) -> None:
        """Setup MPS device for Apple Silicon GPU training."""

    def is_available(self) -> bool:
        """Check if MPS backend is available on current system."""

    def auto_device_count(self) -> int:
        """Returns 1 for MPS (single logical GPU device)."""

    def parse_devices(self, devices: Any) -> int:
        """Parse MPS device specification."""

    def get_parallel_devices(self, devices: Any) -> list[torch.device]:
        """Get MPS device for training (single device)."""
```

### XLA Accelerator

TPU acceleration using XLA (Accelerated Linear Algebra) compiler.

```python { .api }
class XLAAccelerator(Accelerator):
    """
    XLA accelerator for TPU training and XLA-compiled execution.

    Provides TPU support and XLA compilation optimizations for
    high-performance training on Google Cloud TPUs.
    """

    def setup_device(self, device: torch.device) -> None:
        """Setup XLA device for TPU training."""

    def is_available(self) -> bool:
        """Check if XLA/TPU runtime is available."""

    def auto_device_count(self) -> int:
        """Get number of available TPU cores."""

    def parse_devices(self, devices: Any) -> Union[int, list[int]]:
        """Parse TPU device specification."""

    def get_parallel_devices(self, devices: Any) -> list[torch.device]:
        """Get list of TPU devices for parallel training."""

    def all_gather_object(self, obj: Any) -> list[Any]:
        """TPU-specific all-gather implementation."""

    def broadcast_object(self, obj: Any, src: int = 0) -> Any:
        """TPU-specific broadcast implementation."""
```

### Device Utilities

Helper functions for device detection and management.

```python { .api }
def find_usable_cuda_devices(num_devices: int = -1) -> list[int]:
    """
    Find CUDA devices that are available and usable.

    Args:
        num_devices: Number of devices to find (-1 for all available)

    Returns:
        List of CUDA device IDs that can be used for training

    Examples:
        # Find all available GPUs
        devices = find_usable_cuda_devices()

        # Find 2 available GPUs
        devices = find_usable_cuda_devices(2)
    """

def get_nvidia_gpu_stats(device: torch.device) -> dict[str, Union[int, float]]:
    """
    Get NVIDIA GPU statistics and memory usage.

    Args:
        device: CUDA device to query

    Returns:
        Dictionary with GPU statistics including memory usage,
        utilization, temperature, and power consumption
    """
```

### Accelerator Registry

Global registry system for discovering and instantiating accelerators.

```python { .api }
class AcceleratorRegistry:
    """Registry for accelerator plugins."""

    def register(
        self,
        name: str,
        accelerator_class: type[Accelerator],
        description: Optional[str] = None
    ) -> None:
        """Register an accelerator class."""

    def get(self, name: str) -> type[Accelerator]:
        """Get accelerator class by name."""

    def available_accelerators(self) -> list[str]:
        """Get list of available accelerator names."""

    def remove(self, name: str) -> None:
        """Remove accelerator from registry."""

# Global registry instance
ACCELERATOR_REGISTRY: AcceleratorRegistry
```

## Usage Examples

### Automatic Accelerator Selection

```python
from lightning.fabric import Fabric

# Auto-detect best available accelerator
fabric = Fabric(accelerator="auto")
print(f"Using accelerator: {fabric.accelerator.__class__.__name__}")
```

### Specific Accelerator Configuration

```python
# Use specific accelerator types
fabric_gpu = Fabric(accelerator="cuda", devices=2)
fabric_cpu = Fabric(accelerator="cpu")
fabric_mps = Fabric(accelerator="mps")  # Apple Silicon
fabric_tpu = Fabric(accelerator="tpu", devices=8)  # TPU v3/v4
```

### Multi-GPU Setup

```python
# Use specific GPU devices
fabric = Fabric(accelerator="cuda", devices=[0, 2, 3])

# Use all available GPUs
fabric = Fabric(accelerator="gpu", devices="auto")

# Use specific number of GPUs
fabric = Fabric(accelerator="gpu", devices=4)
```

### Custom Accelerator

```python
from lightning.fabric.accelerators import Accelerator, ACCELERATOR_REGISTRY

class CustomAccelerator(Accelerator):
    def setup_device(self, device):
        # Custom device setup logic
        pass

    def is_available(self):
        # Custom availability check
        return True

# Register custom accelerator
ACCELERATOR_REGISTRY.register("custom", CustomAccelerator)

# Use custom accelerator
fabric = Fabric(accelerator="custom")
```

### Device Management

```python
from lightning.fabric.accelerators.cuda import find_usable_cuda_devices

# Find available GPUs
available_gpus = find_usable_cuda_devices()
print(f"Available GPUs: {available_gpus}")

# Use subset of available GPUs
if len(available_gpus) >= 2:
    fabric = Fabric(accelerator="cuda", devices=available_gpus[:2])
```

### TPU Configuration

```python
# TPU training setup
fabric = Fabric(
    accelerator="tpu",
    devices=8,  # TPU v3/v4 pod
    precision="bf16-mixed"  # BFloat16 for TPUs
)

# Access TPU-specific methods
if hasattr(fabric.accelerator, 'all_gather_object'):
    result = fabric.accelerator.all_gather_object(local_data)
```

### Device Statistics

```python
# Monitor GPU usage during training
if fabric.accelerator.__class__.__name__ == 'CUDAAccelerator':
    stats = fabric.accelerator.get_device_stats(fabric.device)
    fabric.print(f"GPU Memory: {stats['memory_used']}/{stats['memory_total']} MB")

    # Clear cache if needed
    if stats['memory_used'] / stats['memory_total'] > 0.9:
        fabric.accelerator.empty_cache()
```