# JaxLib

JaxLib is the support library for JAX. It contains JAX's compiled, low-level binary components: the Python bindings to XLA, the PJRT runtime, and handwritten kernels. The `jax` package builds its transformations (automatic differentiation, just-in-time compilation, vectorization, and distributed computing) on top of these components, which run on CPUs, GPUs, and TPUs.

## Package Information

- **Package Name**: jaxlib
- **Language**: Python
- **Installation**: `pip install jaxlib` (it is also installed automatically as a dependency of `jax`)
- **Dependencies**: `scipy>=1.12`, `numpy>=1.26`, `ml_dtypes>=0.5.0`
- **Hardware Support**: CPU, GPU (CUDA/ROCm), TPU

## Core Imports

```python
import jaxlib
```

For XLA client operations:

```python
from jaxlib import xla_client
```

## Basic Usage

```python
import jaxlib
import numpy as np
from jaxlib import xla_client

# Create a CPU client
client = xla_client.make_cpu_client()

# A plain Python function; this snippet does not compile it with XLA
# (JAX normally lowers and compiles such functions, e.g. via jax.jit)
def simple_add(a, b):
    return a + b

# Copy host NumPy arrays into device buffers
data_a = np.array([1.0, 2.0, 3.0], dtype=np.float32)
data_b = np.array([4.0, 5.0, 6.0], dtype=np.float32)

buffer_a = client.buffer_from_pyval(data_a)
buffer_b = client.buffer_from_pyval(data_b)

print("JaxLib version:", jaxlib.__version__)
print("Available devices:", client.devices())
print("Platform:", client.platform)
```

## Architecture

JaxLib is organized in layers, each with a distinct responsibility:

- **XLA Client Layer**: High-level Python API for XLA compilation and execution
- **PJRT Runtime**: Platform-specific runtime that executes compiled programs
- **Device Backends**: Hardware-specific implementations for CPU, GPU, and TPU
- **Custom Operations**: Extensible custom call mechanism for user-defined and hardware-specific kernels
- **Distributed Computing**: Primitives for multi-node execution and communication

Every backend is reached through the same client and PJRT interfaces. As a result, JAX can run the same program on different platforms, and each backend is still free to apply its own low-level, hardware-specific optimizations.

## Capabilities

### XLA Client Operations

Core XLA client functionality: client creation, device management, compilation, and execution. The client is the main interface for talking to an XLA backend and managing its computational resources.

```python { .api }
def make_cpu_client(
    asynchronous: bool = True,
    distributed_client: DistributedRuntimeClient | None = None,
    node_id: int = 0,
    num_nodes: int = 1,
    collectives: CpuCollectives | None = None,
    num_devices: int | None = None,
    get_local_topology_timeout_minutes: int | None = None,
    get_global_topology_timeout_minutes: int | None = None,
    transfer_server_factory: TransferServerInterfaceFactory | None = None,
) -> Client: ...

def make_gpu_client(
    distributed_client: DistributedRuntimeClient | None = None,
    node_id: int = 0,
    num_nodes: int = 1,
    platform_name: str | None = None,
    allowed_devices: set[int] | None = None,
    mock: bool | None = None,
    mock_gpu_topology: str | None = None,
) -> Client: ...

def make_c_api_client(
    plugin_name: str,
    options: dict[str, str | int | list[int] | float | bool] | None = None,
    distributed_client: DistributedRuntimeClient | None = None,
    transfer_server_factory: TransferServerInterfaceFactory | None = None,
) -> Client: ...
```
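
`make_c_api_client` takes plugin `options` as a dict whose values are typed `str | int | list[int] | float | bool`. The sketch below is one way to catch malformed options before making that call. `check_plugin_options` is a hypothetical helper, not part of jaxlib, that validates a dict against that type hint. The option names in the usage lines are made up; real names depend on the plugin.

```python
from __future__ import annotations

from typing import Any


def check_plugin_options(options: dict[str, Any] | None) -> dict[str, Any]:
    """Hypothetical helper: check `options` against the type hint
    dict[str, str | int | list[int] | float | bool] used by make_c_api_client.

    Returns a shallow copy (an empty dict for None) and raises TypeError on
    the first entry that does not fit.
    """
    if options is None:
        return {}
    for key, value in options.items():
        if not isinstance(key, str):
            raise TypeError(f"option name {key!r} is not a str")
        # bool is a subclass of int, so bools pass this scalar check too.
        if isinstance(value, (str, int, float)):
            continue
        # Lists are accepted only when every element is an int.
        if isinstance(value, list) and all(isinstance(v, int) for v in value):
            continue
        raise TypeError(f"option {key!r} has unsupported value {value!r}")
    return dict(options)


# Usage with made-up option names (real names depend on the plugin):
opts = check_plugin_options({"example_flag": True, "example_ids": [0, 1]})
```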

[XLA Client](./xla-client.md)

### Device and Memory Management

Device discovery, selection, and memory management across hardware platforms, including device topology, memory spaces (memory kinds), and resource allocation.

```python { .api }
class Device:
    id: int
    host_id: int
    process_index: int
    platform: str
    device_kind: str
    client: Client
    local_hardware_id: int | None

    def memory(self, kind: str) -> Memory: ...
    def default_memory(self) -> Memory: ...
    def addressable_memories(self) -> list[Memory]: ...
    def memory_stats(self) -> dict[str, int] | None: ...

class DeviceList:
    def __init__(self, device_assignment: tuple[Device, ...]): ...
    def __len__(self) -> int: ...
    def __getitem__(self, index: Any) -> Any: ...
    def __iter__(self) -> Iterator[Device]: ...

    @property
    def is_fully_addressable(self) -> bool: ...
    @property
    def addressable_device_list(self) -> DeviceList: ...
    @property
    def process_indices(self) -> set[int]: ...
    @property
    def default_memory_kind(self) -> str | None: ...
    @property
    def memory_kinds(self) -> tuple[str, ...]: ...
    @property
    def device_kind(self) -> str: ...
```
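
The addressability properties of `DeviceList` depend on each device's `process_index`. The model below assumes that a device is addressable exactly when its `process_index` equals the index of the current process. `FakeDevice` and `FakeDeviceList` are hypothetical stand-ins that let the sketch run without jaxlib. They model only `process_indices`, `addressable_device_list`, and `is_fully_addressable`.

```python
from __future__ import annotations

from collections.abc import Iterator
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeDevice:
    """Hypothetical stand-in for jaxlib's Device (only the fields used here)."""

    id: int
    process_index: int


class FakeDeviceList:
    """Hypothetical model of DeviceList's addressability properties."""

    def __init__(self, devices: tuple[FakeDevice, ...], current_process: int):
        self._devices = devices
        self._current_process = current_process  # index of "this" process

    def __len__(self) -> int:
        return len(self._devices)

    def __iter__(self) -> Iterator[FakeDevice]:
        return iter(self._devices)

    @property
    def process_indices(self) -> set[int]:
        # Every process that owns at least one device in the list.
        return {d.process_index for d in self._devices}

    @property
    def addressable_device_list(self) -> FakeDeviceList:
        # Keep only the devices owned by the current process.
        local = tuple(d for d in self._devices if d.process_index == self._current_process)
        return FakeDeviceList(local, self._current_process)

    @property
    def is_fully_addressable(self) -> bool:
        # True when no device in the list belongs to another process.
        return all(d.process_index == self._current_process for d in self._devices)


# Four devices spread over two processes; this process has index 0.
devices = tuple(FakeDevice(id=i, process_index=i // 2) for i in range(4))
device_list = FakeDeviceList(devices, current_process=0)
print(device_list.process_indices)                          # {0, 1}
print([d.id for d in device_list.addressable_device_list])  # [0, 1]
print(device_list.is_fully_addressable)                     # False
```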

[Device Management](./device-management.md)

### Compilation and Execution

Compilation of XLA computations into executables, loading them onto devices, and executing them, either with one argument list or with per-device sharded arguments, including in distributed settings.

```python { .api }
class Client:
    platform: str
    platform_version: str
    runtime_type: str

    def compile(
        self,
        computation: str | bytes,
        executable_devices: DeviceList | Sequence[Device],
        compile_options: CompileOptions = ...,
    ) -> Executable: ...

    def compile_and_load(
        self,
        computation: str | bytes,
        executable_devices: DeviceList | Sequence[Device],
        compile_options: CompileOptions = ...,
        host_callbacks: Sequence[Any] = ...,
    ) -> LoadedExecutable: ...

class LoadedExecutable:
    client: Client

    def execute(self, arguments: Sequence[ArrayImpl]) -> list[ArrayImpl]: ...
    def execute_sharded(
        self, arguments: Sequence[list[ArrayImpl]], with_tokens: bool = ...
    ) -> ExecuteResults: ...
    def hlo_modules(self) -> list[HloModule]: ...
    def get_output_memory_kinds(self) -> list[list[str]]: ...
    def get_compiled_memory_stats(self) -> CompiledMemoryStats: ...
```

[Compilation and Execution](./compilation-execution.md)

### Array and Buffer Operations

Array operations for device placement, sharding, copying, and memory management. Batched variants handle many arrays or shards in a single call.

```python { .api }
def batched_device_put(
    aval: Any,
    sharding: Any,
    shards: Sequence[Any],
    devices: list[Device],
    committed: bool = ...,
    force_copy: bool = ...,
    host_buffer_semantics: Any = ...,
) -> ArrayImpl: ...

def batched_copy_array_to_devices_with_sharding(
    arrays: Sequence[ArrayImpl],
    devices: Sequence[DeviceList],
    sharding: Sequence[Any],
    array_copy_semantics: Sequence[ArrayCopySemantics],
) -> Sequence[ArrayImpl]: ...

def reorder_shards(
    x: ArrayImpl,
    dst_sharding: Any,
    array_copy_semantics: ArrayCopySemantics,
) -> ArrayImpl: ...

def batched_block_until_ready(x: Sequence[ArrayImpl]) -> None: ...
```
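
`batched_device_put` receives the host-side pieces of an array (`shards`) together with the target `devices`. The parameter names suggest that `shards[i]` is placed on `devices[i]`. Assuming that pairing, the sketch below builds aligned lists for an even split along the first axis. `split_rows` is a hypothetical helper. Nested lists and device-name strings stand in for NumPy arrays and jaxlib `Device` objects.

```python
from __future__ import annotations

from collections.abc import Sequence
from typing import Any


def split_rows(
    rows: Sequence[Any], devices: Sequence[Any]
) -> tuple[list[Sequence[Any]], list[Any]]:
    """Hypothetical helper: split `rows` evenly across `devices` along axis 0.

    Returns (shards, devices) with shards[i] intended for devices[i].
    """
    n = len(devices)
    if n == 0 or len(rows) % n:
        raise ValueError(f"cannot split {len(rows)} rows evenly across {n} devices")
    block = len(rows) // n
    shards = [rows[i * block:(i + 1) * block] for i in range(n)]
    return shards, list(devices)


# A 4x2 "array" as nested lists, split across two placeholder devices.
host_rows = [[0, 1], [2, 3], [4, 5], [6, 7]]
shards, devices = split_rows(host_rows, ["device-0", "device-1"])
for shard, device in zip(shards, devices):
    print(device, shard)
```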

[Array Operations](./array-operations.md)

### Sharding and Distribution

Sharding types describe how arrays and computations are distributed across devices and nodes. The types are mesh-based `NamedSharding`, `SingleDeviceSharding`, and `GSPMDSharding`. `GSPMDSharding` wraps an XLA `OpSharding` or `HloSharding` description for GSPMD-style partitioning.

```python { .api }
class Sharding: ...

class NamedSharding(Sharding):
    def __init__(
        self,
        mesh: Any,
        spec: Any,
        *,
        memory_kind: str | None = None,
        _logical_device_ids: tuple[int, ...] | None = None,
    ): ...
    mesh: Any
    spec: Any

class SingleDeviceSharding(Sharding):
    def __init__(self, device: Device, *, memory_kind: str | None = None): ...

class GSPMDSharding(Sharding):
    def __init__(
        self,
        devices: Sequence[Device],
        op_sharding: OpSharding | HloSharding,
        *,
        memory_kind: str | None = None,
        _device_list: DeviceList | None = None,
    ): ...

class HloSharding:
    @staticmethod
    def from_proto(proto: OpSharding) -> HloSharding: ...
    @staticmethod
    def replicate() -> HloSharding: ...
    @staticmethod
    def manual() -> HloSharding: ...

    def is_replicated(self) -> bool: ...
    def is_tiled(self) -> bool: ...
    def num_devices(self) -> int: ...
```
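
A `NamedSharding` combines a device mesh with a partition spec. For each array dimension, the spec names the mesh axis that dimension is split along, or `None` for no split. The pure-Python sketch below computes which block of the global array each mesh position holds under that scheme. It is an illustration, not jaxlib's implementation, and it makes several simplifying assumptions:

- `shard_slices` is hypothetical.
- A dict of axis sizes stands in for a real mesh.
- Each dimension may use at most one mesh axis.
- Dimensions must divide evenly.

Devices that differ only along an unused mesh axis hold identical (replicated) blocks.

```python
from __future__ import annotations

from itertools import product


def shard_slices(
    global_shape: tuple[int, ...],
    mesh_shape: dict[str, int],
    spec: tuple[str | None, ...],
) -> dict[tuple[int, ...], tuple[slice, ...]]:
    """Hypothetical helper: map each mesh coordinate to the block it holds.

    global_shape: shape of the full array, e.g. (8, 4)
    mesh_shape:   mesh axis name -> size, in mesh order, e.g. {"x": 2, "y": 2}
    spec:         per array dimension, a mesh axis name or None (not split);
                  missing trailing entries are treated as None
    """
    if len(spec) > len(global_shape):
        raise ValueError("spec has more entries than the array has dimensions")
    spec = tuple(spec) + (None,) * (len(global_shape) - len(spec))
    axis_names = list(mesh_shape)
    result = {}
    # Visit every mesh position, e.g. (0, 0), (0, 1), (1, 0), (1, 1).
    for coord in product(*(range(mesh_shape[a]) for a in axis_names)):
        position = dict(zip(axis_names, coord))
        slices = []
        for dim, axis in zip(global_shape, spec):
            if axis is None:
                # Not partitioned: this device holds the whole dimension.
                slices.append(slice(0, dim))
                continue
            size = mesh_shape[axis]
            if dim % size:
                raise ValueError(f"dimension {dim} is not divisible by mesh axis {axis!r} of size {size}")
            block = dim // size
            start = position[axis] * block
            slices.append(slice(start, start + block))
        result[coord] = tuple(slices)
    return result


# An (8, 4) array on a 2x2 mesh, split along "x" only: positions that differ
# only in their "y" coordinate hold the same rows (replicated along "y").
for coord, block in shard_slices((8, 4), {"x": 2, "y": 2}, ("x",)).items():
    print(coord, block)
```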

[Sharding](./sharding.md)

### Custom Operations

Extensible custom call interface for integrating user-defined operations and hardware-specific kernels into XLA computations.

```python { .api }
class CustomCallTargetTraits(enum.IntFlag):
    DEFAULT = 0
    COMMAND_BUFFER_COMPATIBLE = 1

def register_custom_call_target(
    name: str,
    fn: Any,
    platform: str = 'cpu',
    api_version: int = 0,
    traits: CustomCallTargetTraits = CustomCallTargetTraits.DEFAULT,
) -> None: ...

def register_custom_call_handler(
    platform: str, handler: CustomCallHandler
) -> None: ...

def register_custom_call_partitioner(
    name: str,
    prop_user_sharding: Callable,
    partition: Callable,
    infer_sharding_from_operands: Callable,
    can_side_effecting_have_replicated_sharding: bool = ...,
    c_api: Any | None = ...,
) -> None: ...

def custom_call_targets(platform: str) -> dict[str, Any]: ...
```
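
Custom call targets are registered per platform under a name that XLA computations refer to. `custom_call_targets(platform)` returns what is registered for one platform. The toy model below illustrates only that bookkeeping. `ToyRegistry` and `my_kernel` are hypothetical. In jaxlib the registered `fn` is a native handler, typically wrapped in a `PyCapsule`, rather than a Python function.

```python
from __future__ import annotations

import enum
from collections import defaultdict
from typing import Any


class CustomCallTargetTraits(enum.IntFlag):
    # Mirrors the members of the jaxlib enum listed above.
    DEFAULT = 0
    COMMAND_BUFFER_COMPATIBLE = 1


class ToyRegistry:
    """Hypothetical per-platform registry of custom call targets."""

    def __init__(self) -> None:
        # platform -> target name -> (handler, api_version, traits)
        self._targets: defaultdict[str, dict[str, tuple[Any, int, CustomCallTargetTraits]]] = defaultdict(dict)

    def register(
        self,
        name: str,
        fn: Any,
        platform: str = "cpu",
        api_version: int = 0,
        traits: CustomCallTargetTraits = CustomCallTargetTraits.DEFAULT,
    ) -> None:
        self._targets[platform][name] = (fn, api_version, traits)

    def targets(self, platform: str) -> dict[str, Any]:
        # Name -> handler for one platform; unknown platforms give an empty dict.
        return {name: entry[0] for name, entry in self._targets.get(platform, {}).items()}


def my_kernel() -> None:  # hypothetical placeholder for a native handler
    pass


registry = ToyRegistry()
registry.register("my_cpu_kernel", my_kernel)
registry.register(
    "my_gpu_kernel",
    my_kernel,
    platform="cuda",
    api_version=1,
    traits=CustomCallTargetTraits.COMMAND_BUFFER_COMPATIBLE,
)
print(sorted(registry.targets("cpu")))  # ['my_cpu_kernel']
print(registry.targets("rocm"))         # {}
```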

[Custom Operations](./custom-operations.md)

### Hardware-Specific Operations

Specialized operations for different hardware platforms, including CPU linear algebra, GPU kernels, and sparse matrix operations.

```python { .api }
# jaxlib.lapack: LAPACK operations (CPU)
def registrations() -> dict[str, list[tuple[str, Any, int]]]: ...
def prepare_lapack_call(fn_base: str, dtype: Any) -> str: ...

# jaxlib.gpu_linalg: GPU linear algebra operations
def registrations() -> dict[str, list[tuple[str, Any, int]]]: ...

# jaxlib.gpu_sparse: GPU sparse operations
def registrations() -> dict[str, list[tuple[str, Any, int]]]: ...

# jaxlib.cpu_sparse: CPU sparse operations
def registrations() -> dict[str, list[tuple[str, Any, int]]]: ...
```
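
Each `registrations()` function returns a `dict[str, list[tuple[str, Any, int]]]`. If the keys are platform names and each tuple is `(target name, handler, API version)`, the entries line up with the `name`, `fn`, `platform`, and `api_version` parameters of `register_custom_call_target`. That reading is an assumption taken from the type hints. The sketch below shows the wiring:

- `register_all` and `fake_register` are hypothetical.
- The target names are made up.
- The stub stands in for `xla_client.register_custom_call_target` so the example runs without jaxlib.

```python
from __future__ import annotations

from typing import Any, Callable


def register_all(
    registrations: dict[str, list[tuple[str, Any, int]]],
    register: Callable[..., None],
) -> int:
    """Hypothetical helper: pass every entry to a register-like function.

    Assumes keys are platform names and tuples are (name, handler, api_version).
    Returns the number of targets registered.
    """
    count = 0
    for platform, targets in registrations.items():
        for name, handler, api_version in targets:
            register(name, handler, platform=platform, api_version=api_version)
            count += 1
    return count


calls: list[tuple[str, str, int]] = []


def fake_register(name: str, fn: Any, platform: str = "cpu", api_version: int = 0) -> None:
    """Stub with the same leading parameters as register_custom_call_target."""
    calls.append((platform, name, api_version))


example_registrations = {
    "cpu": [("example_cpu_target", object(), 1)],
    "cuda": [("example_gpu_target_a", object(), 0), ("example_gpu_target_b", object(), 1)],
}
count = register_all(example_registrations, fake_register)
print(count, calls)
```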

[Hardware-Specific Operations](./hardware-operations.md)

### Plugin System

Dynamic plugin loading and version management for hardware-specific extensions and third-party integrations.

```python { .api }
def import_from_plugin(
    plugin_name: str,
    submodule_name: str,
    *,
    check_version: bool = True
) -> ModuleType | None: ...

def check_plugin_version(
    plugin_name: str,
    jaxlib_version: str,
    plugin_version: str
) -> bool: ...

def pjrt_plugin_loaded(plugin_name: str) -> bool: ...

def load_pjrt_plugin_dynamically(
    plugin_name: str, library_path: str
) -> Any: ...

def initialize_pjrt_plugin(plugin_name: str) -> None: ...
```
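
`check_plugin_version` decides whether a plugin build matches the installed jaxlib. The exact rule is not described here, so the sketch below assumes a simple, hypothetical one. The leading dotted numeric parts of both versions must be identical. Suffixes such as `.dev20250101` or `rc1` are ignored. `parse_version_prefix` and `versions_compatible` are hypothetical helpers, not jaxlib functions.

```python
import re

# Leading dotted numeric part of a version string, e.g. "0.5.0" in "0.5.0.dev20250101".
_VERSION_PREFIX = re.compile(r"[0-9]+(?:\.[0-9]+)*")


def parse_version_prefix(version: str) -> tuple[int, ...]:
    """Hypothetical helper: return the numeric prefix of `version` as a tuple."""
    match = _VERSION_PREFIX.match(version)
    if match is None:
        raise ValueError(f"cannot parse version {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def versions_compatible(jaxlib_version: str, plugin_version: str) -> bool:
    """Hypothetical rule: compatible when the numeric prefixes are identical."""
    return parse_version_prefix(jaxlib_version) == parse_version_prefix(plugin_version)


print(versions_compatible("0.5.0", "0.5.0.dev20250101"))  # True
print(versions_compatible("0.5.0", "0.4.38"))             # False
```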

[Plugin System](./plugin-system.md)

## Types

```python { .api }
# Core types
class Shape:
    def __init__(self, s: str): ...
    @staticmethod
    def array_shape(
        type: np.dtype | PrimitiveType,
        dims_seq: Any = ...,
        layout_seq: Any = ...,
        dynamic_dimensions: list[bool] | None = ...,
    ) -> Shape: ...

    def dimensions(self) -> tuple[int, ...]: ...
    def rank(self) -> int: ...
    def is_array(self) -> bool: ...
    def is_tuple(self) -> bool: ...

class PrimitiveType(enum.IntEnum):
    PRED = ...
    S8 = ...
    S16 = ...
    S32 = ...
    S64 = ...
    U8 = ...
    U16 = ...
    U32 = ...
    U64 = ...
    F16 = ...
    F32 = ...
    F64 = ...
    BF16 = ...
    C64 = ...
    C128 = ...

class ArrayCopySemantics(enum.IntEnum):
    ALWAYS_COPY = ...
    REUSE_INPUT = ...
    DONATE_INPUT = ...

class HostBufferSemantics(enum.IntEnum):
    IMMUTABLE_ONLY_DURING_CALL = ...
    IMMUTABLE_UNTIL_TRANSFER_COMPLETES = ...
    ZERO_COPY = ...

# Exception types
class XlaRuntimeError(RuntimeError): ...

class GpuLibNotLinkedError(Exception):
    """Raised when the GPU library is not linked."""
```
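
`Shape.array_shape` accepts either a NumPy dtype or a `PrimitiveType`. The `PrimitiveType` member names follow XLA's convention:

- `PRED` is the boolean type.
- `S`, `U`, `F`, and `C` mark signed integers, unsigned integers, floating point, and complex types, followed by the total bit width.
- `BF16` is bfloat16.

The table below spells out that correspondence for the members listed above, keyed by NumPy dtype name. `DTYPE_TO_PRIMITIVE` and `primitive_type_name` are hypothetical and do not reproduce jaxlib's internal lookup.

```python
# NumPy dtype name -> XLA PrimitiveType member name, for the members listed above.
# "bfloat16" comes from the ml_dtypes package rather than NumPy itself.
DTYPE_TO_PRIMITIVE = {
    "bool": "PRED",
    "int8": "S8",
    "int16": "S16",
    "int32": "S32",
    "int64": "S64",
    "uint8": "U8",
    "uint16": "U16",
    "uint32": "U32",
    "uint64": "U64",
    "float16": "F16",
    "float32": "F32",
    "float64": "F64",
    "bfloat16": "BF16",
    "complex64": "C64",
    "complex128": "C128",
}


def primitive_type_name(dtype_name: str) -> str:
    """Hypothetical helper: look up the PrimitiveType member for a dtype name."""
    try:
        return DTYPE_TO_PRIMITIVE[dtype_name]
    except KeyError:
        raise ValueError(f"no PrimitiveType listed for dtype {dtype_name!r}") from None


print(primitive_type_name("float32"))  # F32
```

With NumPy installed, a dtype's name can be passed directly, for example `primitive_type_name(np.dtype(np.int16).name)`.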