# Device and Memory Management

This section covers device discovery, selection, and memory management across different hardware platforms. It handles device topology, memory spaces, and resource allocation for optimal performance across CPUs, GPUs, and TPUs.

## Capabilities

### Device Interface

Core device representation providing access to device properties, memory spaces, and hardware-specific information.

```python { .api }
class Device:
    """Represents a computational device (CPU, GPU, TPU)."""

    id: int
    host_id: int
    process_index: int
    platform: str
    device_kind: str
    client: Client
    local_hardware_id: int | None

    def memory(self, kind: str) -> Memory:
        """
        Get memory space of specified kind.

        Parameters:
        - kind: Memory kind string (e.g., 'default', 'pinned')

        Returns:
        Memory object for the specified kind
        """

    def default_memory(self) -> Memory:
        """Get the default memory space for this device."""

    def addressable_memories(self) -> list[Memory]:
        """Get all memory spaces addressable by this device."""

    def live_buffers(self) -> list[Any]:
        """Get list of live buffers on this device."""

    def memory_stats(self) -> dict[str, int] | None:
        """
        Get memory usage statistics.

        Returns:
        Dictionary with memory statistics or None if not available
        """

    def get_stream_for_external_ready_events(self) -> int:
        """Get stream handle for external ready events."""
```

### Memory Management

Memory space representation and management for different types of device memory.

```python { .api }
class Memory:
    """Represents a memory space on a device."""

    process_index: int
    platform: str
    kind: str

    def addressable_by_devices(self) -> list[Device]:
        """Get devices that can address this memory space."""


def check_and_canonicalize_memory_kind(
    memory_kind: str | None, device_list: DeviceList
) -> str | None:
    """
    Check and canonicalize memory kind specification.

    Parameters:
    - memory_kind: Memory kind string or None
    - device_list: List of target devices

    Returns:
    Canonicalized memory kind or None
    """
```

### Device Lists

Container for managing collections of devices with utilities for addressing and memory management.

```python { .api }
class DeviceList:
    """Container for a list of devices with metadata."""

    def __init__(self, device_assignment: tuple[Device, ...]): ...

    def __len__(self) -> int:
        """Get number of devices in the list."""

    def __getitem__(self, index: Any) -> Any:
        """Get device at specified index."""

    def __iter__(self) -> Iterator[Device]:
        """Iterate over devices in the list."""

    @property
    def is_fully_addressable(self) -> bool:
        """Check if all devices are fully addressable."""

    @property
    def addressable_device_list(self) -> DeviceList:
        """Get list of addressable devices."""

    @property
    def process_indices(self) -> set[int]:
        """Get set of process indices for devices."""

    @property
    def default_memory_kind(self) -> str | None:
        """Get default memory kind for devices."""

    @property
    def memory_kinds(self) -> tuple[str, ...]:
        """Get tuple of available memory kinds."""

    @property
    def device_kind(self) -> str:
        """Get device kind for all devices."""
```

### Device Topology

Topology information for understanding device layout and connectivity in multi-device and multi-node systems.

```python { .api }
class DeviceTopology:
    """Represents the topology of devices in a system."""

    platform: str
    platform_version: str

    def _make_compile_only_devices(self) -> list[Device]:
        """Create compile-only devices from topology."""

    def serialize(self) -> bytes:
        """Serialize topology to bytes."""
```

### Device Assignment

Utilities for assigning devices to computations in distributed and multi-device scenarios.

```python { .api }
class DeviceAssignment:
    """Represents assignment of devices to computation replicas."""

    @staticmethod
    def create(array: np.ndarray) -> DeviceAssignment:
        """
        Create device assignment from array.

        Parameters:
        - array: 2D numpy array of device ordinals indexed by [replica][computation]

        Returns:
        DeviceAssignment object
        """

    def replica_count(self) -> int:
        """Get number of replicas."""

    def computation_count(self) -> int:
        """Get number of computations per replica."""

    def serialize(self) -> bytes:
        """Serialize device assignment to bytes."""
```

### Layout Management

Data layout specification and management for optimal memory access patterns on different hardware.

```python { .api }
class Layout:
    """Represents data layout in memory."""

    def __init__(self, minor_to_major: tuple[int, ...]): ...

    def minor_to_major(self) -> tuple[int, ...]:
        """Get minor-to-major dimension ordering."""

    def tiling(self) -> Sequence[tuple[int, ...]]:
        """Get tiling specification."""

    def element_size_in_bits(self) -> int:
        """Get element size in bits."""

    def to_string(self) -> str:
        """Get string representation of layout."""


class PjRtLayout:
    """PJRT-specific layout representation."""

    def _xla_layout(self) -> Layout:
        """Get underlying XLA layout."""
```

### GPU Configuration

GPU-specific configuration and memory management options.

```python { .api }
class GpuAllocatorConfig:
    """Configuration for GPU memory allocator."""

    class Kind(enum.IntEnum):
        # Member values are elided with `...` in stub style. This listing
        # documents names only: executing it as-is would fail because
        # IntEnum members must be int-convertible.
        DEFAULT = ...
        PLATFORM = ...
        BFC = ...
        CUDA_ASYNC = ...

    def __init__(
        self,
        kind: Kind = ...,
        memory_fraction: float = ...,
        preallocate: bool = ...,
        collective_memory_size: int = ...,
    ) -> None: ...
```

## Usage Examples

### Device Discovery and Selection

```python
from jaxlib import xla_client

# Create a CPU client and discover the devices it exposes.
client = xla_client.make_cpu_client()
devices = client.devices()

print(f"Available devices: {len(devices)}")
for device in devices:
    print(f"Device {device.id}: {device.platform} ({device.device_kind})")
    print(f" Host ID: {device.host_id}")
    print(f" Process: {device.process_index}")

    # Check memory information
    default_mem = device.default_memory()
    print(f" Default memory: {default_mem.kind}")

    addressable_mems = device.addressable_memories()
    print(f" Addressable memories: {[m.kind for m in addressable_mems]}")

    # memory_stats() returns None when statistics are not available.
    stats = device.memory_stats()
    if stats:
        print(f" Memory stats: {stats}")
```

### Memory Management

```python
from jaxlib import xla_client
import numpy as np

client = xla_client.make_cpu_client()
device = client.local_devices()[0]

# Create data and put on device
data = np.array([1.0, 2.0, 3.0], dtype=np.float32)
buffer = client.buffer_from_pyval(data, device=device)

print(f"Buffer on device: {buffer}")
print(f"Live buffers on device: {len(device.live_buffers())}")

# Check memory usage; memory_stats() returns None if unavailable.
stats = device.memory_stats()
if stats:
    print(f"Memory usage: {stats}")
```

### Device Assignment for Multi-Device

```python
from jaxlib import xla_client
import numpy as np

client = xla_client.make_cpu_client()
devices = client.local_devices()

if len(devices) >= 2:
    # Create device assignment for 2 replicas on 2 devices. Rows index
    # replicas and columns index computations, so this (2, 1) array maps
    # replica 0 to device ordinal 0 and replica 1 to device ordinal 1.
    assignment_array = np.array([[0], [1]], dtype=np.int32)
    device_assignment = xla_client.DeviceAssignment.create(assignment_array)

    print(f"Replica count: {device_assignment.replica_count()}")
    print(f"Computation count: {device_assignment.computation_count()}")
else:
    # Without this branch the example does nothing on single-device hosts.
    print(f"This example needs at least 2 devices; found {len(devices)}.")
```

### Device Topology

```python
from jaxlib import xla_client

client = xla_client.make_cpu_client()
devices = client.local_devices()

# Get topology for available devices
topology = xla_client.get_topology_for_devices(devices)
print(f"Topology platform: {topology.platform}")
print(f"Platform version: {topology.platform_version}")

# Serialize topology to bytes (e.g. for transfer)
topology_bytes = topology.serialize()
print(f"Serialized topology size: {len(topology_bytes)} bytes")
```