# XLA Client Operations

Core XLA client functionality providing the main interface for interacting with XLA backends, managing computational resources, and creating clients for different hardware platforms.

## Capabilities

### Client Creation

Factory functions for creating XLA clients targeting different hardware platforms, with platform-specific configuration options.

```python { .api }
def make_cpu_client(
    asynchronous: bool = True,
    distributed_client: DistributedRuntimeClient | None = None,
    node_id: int = 0,
    num_nodes: int = 1,
    collectives: CpuCollectives | None = None,
    num_devices: int | None = None,
    get_local_topology_timeout_minutes: int | None = None,
    get_global_topology_timeout_minutes: int | None = None,
    transfer_server_factory: TransferServerInterfaceFactory | None = None,
) -> Client:
    """
    Create a CPU client for XLA computations.

    Parameters:
    - asynchronous: Whether to use asynchronous execution
    - distributed_client: Client for distributed computing
    - node_id: Node identifier in distributed setup
    - num_nodes: Total number of nodes
    - collectives: CPU collective operations interface
    - num_devices: Number of CPU devices to use
    - get_local_topology_timeout_minutes: Timeout (in minutes) for retrieving the local topology
    - get_global_topology_timeout_minutes: Timeout (in minutes) for retrieving the global topology
    - transfer_server_factory: Factory for transfer servers

    Returns:
    XLA Client configured for CPU execution
    """

def make_gpu_client(
    distributed_client: DistributedRuntimeClient | None = None,
    node_id: int = 0,
    num_nodes: int = 1,
    platform_name: str | None = None,
    allowed_devices: set[int] | None = None,
    mock: bool | None = None,
    mock_gpu_topology: str | None = None,
) -> Client:
    """
    Create a GPU client for XLA computations.

    Parameters:
    - distributed_client: Client for distributed computing
    - node_id: Node identifier in distributed setup
    - num_nodes: Total number of nodes
    - platform_name: GPU platform name ('cuda' or 'rocm')
    - allowed_devices: Set of allowed GPU device IDs
    - mock: Whether to use a mock GPU for testing
    - mock_gpu_topology: Mock topology specification

    Returns:
    XLA Client configured for GPU execution
    """

def make_c_api_client(
    plugin_name: str,
    options: dict[str, str | int | list[int] | float | bool] | None = None,
    distributed_client: DistributedRuntimeClient | None = None,
    transfer_server_factory: TransferServerInterfaceFactory | None = None,
) -> Client:
    """
    Create a client using the PJRT C API for plugins.

    Parameters:
    - plugin_name: Name of the PJRT plugin
    - options: Platform-specific options dictionary
    - distributed_client: Client for distributed computing
    - transfer_server_factory: Factory for transfer servers

    Returns:
    XLA Client using the specified plugin
    """
```

### Client Interface

The main Client class providing access to devices, compilation, and execution capabilities.

```python { .api }
class Client:
    """XLA client for managing devices and executing computations."""

    platform: str
    platform_version: str
    runtime_type: str

    def device_count(self) -> int:
        """Get the total number of devices."""

    def local_device_count(self) -> int:
        """Get the number of local devices."""

    def devices(self) -> list[Device]:
        """Get all available devices."""

    def local_devices(self) -> list[Device]:
        """Get locally available devices."""

    def host_id(self) -> int:
        """Get the host identifier."""

    def process_index(self) -> int:
        """Get the process index in a distributed setup."""

    def buffer_from_pyval(
        self,
        argument: Any,
        device: Device | None = None,
        force_copy: bool = False,
        host_buffer_semantics: HostBufferSemantics = ...,
    ) -> ArrayImpl:
        """
        Create a buffer from a Python value.

        Parameters:
        - argument: Python value to convert
        - device: Target device (None for default)
        - force_copy: Force copying even if not necessary
        - host_buffer_semantics: How to handle the host buffer

        Returns:
        Array buffer on the specified device
        """

    def live_buffers(self) -> list[Any]:
        """Get the list of live buffers."""

    def live_arrays(self) -> list[ArrayImpl]:
        """Get the list of live arrays."""

    def live_executables(self) -> list[LoadedExecutable]:
        """Get the list of live executables."""

    def heap_profile(self) -> bytes:
        """Get a heap profile for memory debugging."""
```
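
The `live_buffers`/`live_arrays`/`live_executables` accessors report objects that are still alive on the client. Conceptually this resembles a weak-reference registry, sketched below; this is an analogy for illustration only, not the client's actual bookkeeping, and the `Registry`/`Buffer` names are hypothetical:

```python
import weakref

class Registry:
    """Weak-reference registry: tracks objects without keeping them alive."""

    def __init__(self):
        self._refs: weakref.WeakSet = weakref.WeakSet()

    def register(self, obj) -> None:
        self._refs.add(obj)

    def live(self) -> list:
        # Only objects still referenced elsewhere appear here.
        return list(self._refs)

class Buffer:
    """Placeholder for a device buffer."""

registry = Registry()
buf = Buffer()
registry.register(buf)
print(len(registry.live()))  # 1 while the buffer is referenced
del buf
# Once the last reference is dropped, the registry no longer reports it.
```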

### Execution Utilities

Thread-level execution control for managing computation streams.

```python { .api }
def execution_stream_id(new_id: int):
    """
    Context manager that overwrites and restores the current thread's execution_stream_id.

    Parameters:
    - new_id: New execution stream ID to set for the current thread

    Returns:
    Context manager that restores the original execution stream ID on exit

    Usage:
    with execution_stream_id(42):
        # Code executed with stream ID 42
        pass
    # Original stream ID restored
    """
```
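
The save/override/restore pattern behind this context manager can be approximated with a thread-local variable. The following is a minimal sketch, not the actual jaxlib implementation; the `_thread_local` holder and `current_execution_stream_id` helper are hypothetical names introduced for illustration:

```python
import contextlib
import threading

# Hypothetical thread-local holder for the current execution stream ID.
_thread_local = threading.local()

def current_execution_stream_id() -> int:
    """Return this thread's stream ID (0 if never set)."""
    return getattr(_thread_local, "execution_stream_id", 0)

@contextlib.contextmanager
def execution_stream_id(new_id: int):
    """Override the current thread's stream ID, restoring it on exit."""
    old_id = current_execution_stream_id()
    _thread_local.execution_stream_id = new_id
    try:
        yield
    finally:
        # Restore even if the body raises.
        _thread_local.execution_stream_id = old_id

with execution_stream_id(42):
    print(current_execution_stream_id())  # 42 inside the context
print(current_execution_stream_id())      # 0: original value restored
```

Because the state is thread-local, overriding the stream ID on one thread does not affect computations launched from other threads.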

### GPU Plugin Options

Utilities for configuring GPU-specific options and plugin parameters.

```python { .api }
def generate_pjrt_gpu_plugin_options() -> dict[str, str | int | list[int] | float | bool]:
    """
    Generate PJRT GPU plugin options from environment variables.

    Reads configuration from environment variables:
    - XLA_PYTHON_CLIENT_ALLOCATOR: Memory allocator type
    - XLA_CLIENT_MEM_FRACTION: GPU memory fraction to use
    - XLA_PYTHON_CLIENT_PREALLOCATE: Whether to preallocate memory
    - XLA_PYTHON_CLIENT_COLLECTIVE_MEM_SIZE_MB: Collective memory size in MB

    Returns:
    Dictionary of plugin options
    """
```

### Topology Management

Functions for managing device topology and creating topology descriptions for different platforms.

```python { .api }
def make_tfrt_tpu_c_api_device_topology(
    topology_name: str | None = None, **kwargs
) -> DeviceTopology:
    """
    Create a TPU device topology using the TFRT C API.

    Parameters:
    - topology_name: Name of the topology
    - **kwargs: Additional topology options

    Returns:
    DeviceTopology for TPU devices
    """

def make_c_api_device_topology(
    c_api: Any, topology_name: str = '', **kwargs
) -> DeviceTopology:
    """
    Create a device topology using the C API.

    Parameters:
    - c_api: C API interface
    - topology_name: Name of the topology
    - **kwargs: Additional topology options

    Returns:
    DeviceTopology for the specified platform
    """

def get_topology_for_devices(devices: list[Device]) -> DeviceTopology:
    """
    Get the topology description for a list of devices.

    Parameters:
    - devices: List of devices

    Returns:
    DeviceTopology describing the device layout
    """
```
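
`get_topology_for_devices` derives a topology from a set of concrete devices. One ingredient of such a description, grouping devices by owning process, can be sketched with placeholder device records; the `FakeDevice` fields below are assumptions standing in for the real `Device` attributes:

```python
from collections import defaultdict
from dataclasses import dataclass

@dataclass(frozen=True)
class FakeDevice:
    """Placeholder standing in for xla_client.Device (fields are assumed)."""
    id: int
    process_index: int
    platform: str

def devices_by_process(devices: list[FakeDevice]) -> dict[int, list[int]]:
    """Group device IDs by owning process index."""
    groups: dict[int, list[int]] = defaultdict(list)
    for d in devices:
        groups[d.process_index].append(d.id)
    return dict(groups)

# Two processes, two devices each.
devices = [FakeDevice(i, i // 2, "cpu") for i in range(4)]
print(devices_by_process(devices))  # {0: [0, 1], 1: [2, 3]}
```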

### Distributed Runtime

Classes and functions for managing distributed computing across multiple nodes and processes.

```python { .api }
class DistributedRuntimeClient:
    """Client for distributed runtime coordination."""

    def connect(self) -> Any:
        """Connect to the distributed runtime service."""

    def shutdown(self) -> Any:
        """Shut down the distributed runtime client."""

    def blocking_key_value_get(self, key: str, timeout_in_ms: int) -> Any:
        """Blocking get operation on the key-value store."""

    def key_value_set(
        self, key: str, value: str, allow_overwrite: bool = False
    ) -> Any:
        """Set operation on the key-value store."""

    def wait_at_barrier(
        self,
        barrier_id: str,
        timeout_in_ms: int,
        process_ids: list[int] | None = None,
    ) -> Any:
        """Wait at a named barrier for synchronization."""

class DistributedRuntimeService:
    """Service for distributed runtime coordination."""

    def shutdown(self) -> None:
        """Shut down the distributed runtime service."""

def get_distributed_runtime_service(
    address: str,
    num_nodes: int,
    heartbeat_timeout: int | None = None,
    cluster_register_timeout: int | None = None,
    shutdown_timeout: int | None = None,
) -> DistributedRuntimeService:
    """
    Create a distributed runtime service.

    Parameters:
    - address: Service address
    - num_nodes: Number of nodes in the cluster
    - heartbeat_timeout: Heartbeat timeout in milliseconds
    - cluster_register_timeout: Cluster registration timeout
    - shutdown_timeout: Shutdown timeout

    Returns:
    DistributedRuntimeService instance
    """

def get_distributed_runtime_client(
    address: str,
    node_id: int,
    rpc_timeout: int | None = None,
    init_timeout: int | None = None,
    shutdown_timeout: int | None = None,
    heartbeat_timeout: int | None = None,
    missed_heartbeat_callback: Any | None = None,
    shutdown_on_destruction: bool | None = None,
    use_compression: bool | None = None,
    recoverable: bool | None = None,
) -> DistributedRuntimeClient:
    """
    Create a distributed runtime client.

    Parameters:
    - address: Service address to connect to
    - node_id: Unique node identifier
    - rpc_timeout: RPC timeout in milliseconds
    - init_timeout: Initialization timeout
    - shutdown_timeout: Shutdown timeout
    - heartbeat_timeout: Heartbeat timeout
    - missed_heartbeat_callback: Callback invoked on missed heartbeats
    - shutdown_on_destruction: Whether to shut down on destruction
    - use_compression: Whether to use compression
    - recoverable: Whether the client is recoverable

    Returns:
    DistributedRuntimeClient instance
    """
```
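
The key-value store and barrier semantics can be mimicked in-process to show how nodes coordinate: one node publishes a value, others block until it appears, and all nodes rendezvous at a barrier. This is an illustrative single-process sketch using threads, not the real RPC-backed client; the `InProcessCoordinator` class is hypothetical:

```python
import threading

class InProcessCoordinator:
    """Sketch of the coordination primitives: a blocking KV store plus a barrier."""

    def __init__(self, num_nodes: int):
        self._cond = threading.Condition()
        self._store: dict[str, str] = {}
        self._barrier = threading.Barrier(num_nodes)

    def key_value_set(self, key: str, value: str, allow_overwrite: bool = False) -> None:
        with self._cond:
            if key in self._store and not allow_overwrite:
                raise KeyError(f"key {key!r} already set")
            self._store[key] = value
            self._cond.notify_all()  # wake any blocked getters

    def blocking_key_value_get(self, key: str, timeout_in_ms: int) -> str:
        with self._cond:
            if not self._cond.wait_for(lambda: key in self._store,
                                       timeout=timeout_in_ms / 1000):
                raise TimeoutError(f"key {key!r} not set in time")
            return self._store[key]

    def wait_at_barrier(self, barrier_id: str, timeout_in_ms: int) -> None:
        # All participants block here until num_nodes threads have arrived.
        self._barrier.wait(timeout=timeout_in_ms / 1000)

coord = InProcessCoordinator(num_nodes=2)

def worker():
    # "Node 1" publishes its address and waits for the rendezvous.
    coord.key_value_set("addr", "localhost:1234")
    coord.wait_at_barrier("init", 1000)

t = threading.Thread(target=worker)
t.start()
# "Node 0" blocks until the key is published, then joins the barrier.
print(coord.blocking_key_value_get("addr", 1000))  # prints localhost:1234
coord.wait_at_barrier("init", 1000)
t.join()
```

The real client follows the same pattern across processes: `blocking_key_value_get` blocks until some node calls `key_value_set`, and `wait_at_barrier` releases only once the expected participants have arrived.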

## Usage Examples

### Basic Client Setup

```python
from jaxlib import xla_client

# Create a CPU client
cpu_client = xla_client.make_cpu_client(asynchronous=True)
print(f"CPU devices: {cpu_client.local_devices()}")

# Create a GPU client (if available)
try:
    gpu_client = xla_client.make_gpu_client(platform_name='cuda')
    print(f"GPU devices: {gpu_client.local_devices()}")
except Exception as e:
    print(f"GPU not available: {e}")
```

### Distributed Setup

```python
from jaxlib import xla_client

# Start the distributed runtime service on the coordinator
service = xla_client.get_distributed_runtime_service(
    address="localhost:1234",
    num_nodes=2,
    heartbeat_timeout=60000,
)

# Connect a distributed client on each node
dist_client = xla_client.get_distributed_runtime_client(
    address="localhost:1234",
    node_id=0,  # Different for each node
    init_timeout=30000,
)

# Create a client with distributed support
client = xla_client.make_cpu_client(
    distributed_client=dist_client,
    node_id=0,
    num_nodes=2,
)
```