# Device and Memory Management

Device discovery, selection, and memory management across different hardware platforms. Handles device topology, memory spaces, and resource allocation for optimal performance across CPUs, GPUs, and TPUs.

## Capabilities

### Device Interface

Core device representation providing access to device properties, memory spaces, and hardware-specific information.

```python { .api }
class Device:
    """Represents a computational device (CPU, GPU, TPU)."""

    id: int
    host_id: int
    process_index: int
    platform: str
    device_kind: str
    client: Client
    local_hardware_id: int | None

    def memory(self, kind: str) -> Memory:
        """
        Get memory space of specified kind.

        Parameters:
        - kind: Memory kind string (e.g., 'default', 'pinned')

        Returns:
        Memory object for the specified kind
        """

    def default_memory(self) -> Memory:
        """Get the default memory space for this device."""

    def addressable_memories(self) -> list[Memory]:
        """Get all memory spaces addressable by this device."""

    def live_buffers(self) -> list[Any]:
        """Get list of live buffers on this device."""

    def memory_stats(self) -> dict[str, int] | None:
        """
        Get memory usage statistics.

        Returns:
        Dictionary with memory statistics or None if not available
        """

    def get_stream_for_external_ready_events(self) -> int:
        """Get stream handle for external ready events."""
```

### Memory Management

Memory space representation and management for different types of device memory.

```python { .api }
class Memory:
    """Represents a memory space on a device."""

    process_index: int
    platform: str
    kind: str

    def addressable_by_devices(self) -> list[Device]:
        """Get devices that can address this memory space."""

def check_and_canonicalize_memory_kind(
    memory_kind: str | None, device_list: DeviceList
) -> str | None:
    """
    Check and canonicalize memory kind specification.

    Parameters:
    - memory_kind: Memory kind string or None
    - device_list: List of target devices

    Returns:
    Canonicalized memory kind or None
    """
```

### Device Lists

Container for managing collections of devices with utilities for addressing and memory management.

```python { .api }
class DeviceList:
    """Container for a list of devices with metadata."""

    def __init__(self, device_assignment: tuple[Device, ...]): ...

    def __len__(self) -> int:
        """Get number of devices in the list."""

    def __getitem__(self, index: Any) -> Any:
        """Get device at specified index."""

    def __iter__(self) -> Iterator[Device]:
        """Iterate over devices in the list."""

    @property
    def is_fully_addressable(self) -> bool:
        """Check if all devices are fully addressable."""

    @property
    def addressable_device_list(self) -> DeviceList:
        """Get list of addressable devices."""

    @property
    def process_indices(self) -> set[int]:
        """Get set of process indices for devices."""

    @property
    def default_memory_kind(self) -> str | None:
        """Get default memory kind for devices."""

    @property
    def memory_kinds(self) -> tuple[str, ...]:
        """Get tuple of available memory kinds."""

    @property
    def device_kind(self) -> str:
        """Get device kind for all devices."""
```

### Device Topology

Topology information for understanding device layout and connectivity in multi-device and multi-node systems.

```python { .api }
class DeviceTopology:
    """Represents the topology of devices in a system."""

    platform: str
    platform_version: str

    def _make_compile_only_devices(self) -> list[Device]:
        """Create compile-only devices from topology."""

    def serialize(self) -> bytes:
        """Serialize topology to bytes."""
```

### Device Assignment

Utilities for assigning devices to computations in distributed and multi-device scenarios.

```python { .api }
class DeviceAssignment:
    """Represents assignment of devices to computation replicas."""

    @staticmethod
    def create(array: np.ndarray) -> DeviceAssignment:
        """
        Create device assignment from array.

        Parameters:
        - array: 2D numpy array of device ordinals indexed by [replica][computation]

        Returns:
        DeviceAssignment object
        """

    def replica_count(self) -> int:
        """Get number of replicas."""

    def computation_count(self) -> int:
        """Get number of computations per replica."""

    def serialize(self) -> bytes:
        """Serialize device assignment to bytes."""
```

### Layout Management

Data layout specification and management for optimal memory access patterns on different hardware.

```python { .api }
class Layout:
    """Represents data layout in memory."""

    def __init__(self, minor_to_major: tuple[int, ...]): ...

    def minor_to_major(self) -> tuple[int, ...]:
        """Get minor-to-major dimension ordering."""

    def tiling(self) -> Sequence[tuple[int, ...]]:
        """Get tiling specification."""

    def element_size_in_bits(self) -> int:
        """Get element size in bits."""

    def to_string(self) -> str:
        """Get string representation of layout."""

class PjRtLayout:
    """PJRT-specific layout representation."""

    def _xla_layout(self) -> Layout:
        """Get underlying XLA layout."""
```

### GPU Configuration

GPU-specific configuration and memory management options.

```python { .api }
class GpuAllocatorConfig:
    """Configuration for GPU memory allocator."""

    class Kind(enum.IntEnum):
        DEFAULT = ...
        PLATFORM = ...
        BFC = ...
        CUDA_ASYNC = ...

    def __init__(
        self,
        kind: Kind = ...,
        memory_fraction: float = ...,
        preallocate: bool = ...,
        collective_memory_size: int = ...,
    ) -> None: ...
```

## Usage Examples

### Device Discovery and Selection

```python
from jaxlib import xla_client

# Create client and discover devices
client = xla_client.make_cpu_client()
devices = client.devices()

print(f"Available devices: {len(devices)}")
for device in devices:
    print(f"Device {device.id}: {device.platform} ({device.device_kind})")
    print(f"  Host ID: {device.host_id}")
    print(f"  Process: {device.process_index}")

    # Check memory information
    default_mem = device.default_memory()
    print(f"  Default memory: {default_mem.kind}")

    addressable_mems = device.addressable_memories()
    print(f"  Addressable memories: {[m.kind for m in addressable_mems]}")

    # Get memory stats if available
    stats = device.memory_stats()
    if stats:
        print(f"  Memory stats: {stats}")
```

### Memory Management

```python
from jaxlib import xla_client
import numpy as np

client = xla_client.make_cpu_client()
device = client.local_devices()[0]

# Create data and put on device
data = np.array([1.0, 2.0, 3.0], dtype=np.float32)
buffer = client.buffer_from_pyval(data, device=device)

print(f"Buffer on device: {buffer}")
print(f"Live buffers on device: {len(device.live_buffers())}")

# Check memory usage
stats = device.memory_stats()
if stats:
    print(f"Memory usage: {stats}")
```

### Device Assignment for Multi-Device

```python
from jaxlib import xla_client
import numpy as np

client = xla_client.make_cpu_client()
devices = client.local_devices()

if len(devices) >= 2:
    # Create device assignment for 2 replicas on 2 devices
    assignment_array = np.array([[0], [1]], dtype=np.int32)
    device_assignment = xla_client.DeviceAssignment.create(assignment_array)

    print(f"Replica count: {device_assignment.replica_count()}")
    print(f"Computation count: {device_assignment.computation_count()}")
```

### Device Topology

```python
from jaxlib import xla_client

client = xla_client.make_cpu_client()
devices = client.local_devices()

# Get topology for available devices
topology = xla_client.get_topology_for_devices(devices)
print(f"Topology platform: {topology.platform}")
print(f"Platform version: {topology.platform_version}")

# Serialize topology for transfer
topology_bytes = topology.serialize()
print(f"Serialized topology size: {len(topology_bytes)} bytes")
```