XLA library for JAX providing low-level bindings and hardware acceleration support
npx @tessl/cli install tessl/pypi-jaxlib@0.7.0

JaxLib is the XLA library for JAX, serving as the foundational support library that provides low-level binary components including Python bindings to XLA, the PJRT runtime, and handwritten kernels. It enables JAX's high-performance numerical computing capabilities on various hardware accelerators including CPUs, GPUs, and TPUs, supporting automatic differentiation, just-in-time compilation, vectorization, and distributed computing.
pip install jaxlib

Dependencies: scipy>=1.12, numpy>=1.26, ml_dtypes>=0.5.0

import jaxlib

For XLA client operations:
from jaxlib import xla_client

import jaxlib
from jaxlib import xla_client
import numpy as np
# Create a CPU client
client = xla_client.make_cpu_client()
# Create a simple computation
def simple_add(a, b):
    return a + b
# Convert data to buffers
data_a = np.array([1.0, 2.0, 3.0], dtype=np.float32)
data_b = np.array([4.0, 5.0, 6.0], dtype=np.float32)
buffer_a = client.buffer_from_pyval(data_a)
buffer_b = client.buffer_from_pyval(data_b)
print("JaxLib version:", jaxlib.__version__)
print("Available devices:", client.devices())
print("Platform:", client.platform)JaxLib implements a layered architecture with clear separation of concerns:
The design enables JAX to transform and scale numerical programs efficiently across different computing platforms through a consistent interface while allowing low-level optimization and hardware-specific acceleration.
Core XLA client functionality including client creation, device management, compilation, and execution. Provides the main interface for interacting with XLA backends and managing computational resources.
def make_cpu_client(
    asynchronous: bool = True,
    distributed_client: DistributedRuntimeClient | None = None,
    node_id: int = 0,
    num_nodes: int = 1,
    collectives: CpuCollectives | None = None,
    num_devices: int | None = None,
    get_local_topology_timeout_minutes: int | None = None,
    get_global_topology_timeout_minutes: int | None = None,
    transfer_server_factory: TransferServerInterfaceFactory | None = None,
) -> Client: ...

def make_gpu_client(
    distributed_client: DistributedRuntimeClient | None = None,
    node_id: int = 0,
    num_nodes: int = 1,
    platform_name: str | None = None,
    allowed_devices: set[int] | None = None,
    mock: bool | None = None,
    mock_gpu_topology: str | None = None,
) -> Client: ...

def make_c_api_client(
    plugin_name: str,
    options: dict[str, str | int | list[int] | float | bool] | None = None,
    distributed_client: DistributedRuntimeClient | None = None,
    transfer_server_factory: TransferServerInterfaceFactory | None = None,
) -> Client: ...
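A minimal sketch of backend selection using the factory functions above: it attempts a GPU client first and falls back to the CPU client when no GPU backend is linked in. The broad exception handling is purely illustrative.

from jaxlib import xla_client

try:
    # Requires a jaxlib build with GPU support (see GpuLibNotLinkedError below).
    client = xla_client.make_gpu_client(node_id=0, num_nodes=1)
except Exception:
    client = xla_client.make_cpu_client()

print("platform:", client.platform)
print("devices:", [d.device_kind for d in client.devices()])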
Device discovery, selection, and memory management across different hardware platforms. Handles device topology, memory spaces, and resource allocation for optimal performance.

class Device:
    id: int
    host_id: int
    process_index: int
    platform: str
    device_kind: str
    client: Client
    local_hardware_id: int | None
    def memory(self, kind: str) -> Memory: ...
    def default_memory(self) -> Memory: ...
    def addressable_memories(self) -> list[Memory]: ...
    def memory_stats(self) -> dict[str, int] | None: ...
class DeviceList:
    def __init__(self, device_assignment: tuple[Device, ...]): ...
    def __len__(self) -> int: ...
    def __getitem__(self, index: Any) -> Any: ...
    def __iter__(self) -> Iterator[Device]: ...
    @property
    def is_fully_addressable(self) -> bool: ...
    @property
    def addressable_device_list(self) -> DeviceList: ...
    @property
    def process_indices(self) -> set[int]: ...
    @property
    def default_memory_kind(self) -> str | None: ...
    @property
    def memory_kinds(self) -> tuple[str, ...]: ...
    @property
    def device_kind(self) -> str: ...
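A small example, using only the Device attributes listed above, that enumerates the devices of a CPU client and their memory information. memory_stats() may return None on backends that do not report statistics.

from jaxlib import xla_client

client = xla_client.make_cpu_client()
for dev in client.devices():
    print(dev.id, dev.platform, dev.device_kind, "process", dev.process_index)
    print("  default memory:", dev.default_memory())
    print("  addressable memories:", dev.addressable_memories())
    print("  memory stats:", dev.memory_stats())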
XLA computation compilation, loading, and execution with support for distributed computing, sharding, and various execution modes.

class Client:
    platform: str
    platform_version: str
    runtime_type: str
    def compile(
        self,
        computation: str | bytes,
        executable_devices: DeviceList | Sequence[Device],
        compile_options: CompileOptions = ...,
    ) -> Executable: ...
    def compile_and_load(
        self,
        computation: str | bytes,
        executable_devices: DeviceList | Sequence[Device],
        compile_options: CompileOptions = ...,
        host_callbacks: Sequence[Any] = ...,
    ) -> LoadedExecutable: ...

class LoadedExecutable:
    client: Client
    def execute(self, arguments: Sequence[ArrayImpl]) -> list[ArrayImpl]: ...
    def execute_sharded(
        self, arguments: Sequence[list[ArrayImpl]], with_tokens: bool = ...
    ) -> ExecuteResults: ...
    def hlo_modules(self) -> list[HloModule]: ...
    def get_output_memory_kinds(self) -> list[list[str]]: ...
    def get_compiled_memory_stats(self) -> CompiledMemoryStats: ...
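A hedged sketch of compiling a StableHLO module and executing it through the Client and LoadedExecutable interfaces above. The MLIR text, the argument staging via buffer_from_pyval, and reading the result back with np.asarray are illustrative assumptions and may need adjustment for a particular jaxlib build.

import numpy as np
from jaxlib import xla_client

client = xla_client.make_cpu_client()

# A small StableHLO module that adds two f32[3] tensors (illustrative text).
mlir_module = """
module @jit_add {
  func.func public @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
    %0 = stablehlo.add %arg0, %arg1 : tensor<3xf32>
    return %0 : tensor<3xf32>
  }
}
"""

# Compile for the client's devices and load the executable.
loaded = client.compile_and_load(mlir_module, client.devices())

# Stage the inputs on the device and run the executable.
a = client.buffer_from_pyval(np.array([1.0, 2.0, 3.0], dtype=np.float32))
b = client.buffer_from_pyval(np.array([4.0, 5.0, 6.0], dtype=np.float32))
outputs = loaded.execute([a, b])
print(np.asarray(outputs[0]))  # expected: [5. 7. 9.]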
High-performance array operations including device placement, sharding, copying, and memory management optimized for different hardware backends.

def batched_device_put(
    aval: Any,
    sharding: Any,
    shards: Sequence[Any],
    devices: list[Device],
    committed: bool = ...,
    force_copy: bool = ...,
    host_buffer_semantics: Any = ...,
) -> ArrayImpl: ...

def batched_copy_array_to_devices_with_sharding(
    arrays: Sequence[ArrayImpl],
    devices: Sequence[DeviceList],
    sharding: Sequence[Any],
    array_copy_semantics: Sequence[ArrayCopySemantics],
) -> Sequence[ArrayImpl]: ...

def reorder_shards(
    x: ArrayImpl,
    dst_sharding: Any,
    array_copy_semantics: ArrayCopySemantics,
) -> ArrayImpl: ...

def batched_block_until_ready(x: Sequence[ArrayImpl]) -> None: ...
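A minimal sketch of waiting on several device arrays at once with batched_block_until_ready. It assumes these helpers are reachable from jaxlib.xla_client as listed above; in some builds they live on the underlying xla_extension module instead.

import numpy as np
from jaxlib import xla_client

client = xla_client.make_cpu_client()

# Stage a few host arrays on the device; buffer_from_pyval returns device arrays.
arrays = [
    client.buffer_from_pyval(np.full((4,), i, dtype=np.float32)) for i in range(3)
]

# Block until every transfer backing these arrays has completed.
xla_client.batched_block_until_ready(arrays)
print("all arrays ready")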
Sharding strategies for distributing computations across multiple devices and nodes, including SPMD, GSPMD, and custom sharding patterns.

class Sharding: ...

class NamedSharding(Sharding):
    def __init__(
        self,
        mesh: Any,
        spec: Any,
        *,
        memory_kind: str | None = None,
        _logical_device_ids: tuple[int, ...] | None = None,
    ): ...
    mesh: Any
    spec: Any

class SingleDeviceSharding(Sharding):
    def __init__(self, device: Device, *, memory_kind: str | None = None): ...

class GSPMDSharding(Sharding):
    def __init__(
        self,
        devices: Sequence[Device],
        op_sharding: OpSharding | HloSharding,
        *,
        memory_kind: str | None = None,
        _device_list: DeviceList | None = None,
    ): ...

class HloSharding:
    @staticmethod
    def from_proto(proto: OpSharding) -> HloSharding: ...
    @staticmethod
    def replicate() -> HloSharding: ...
    @staticmethod
    def manual() -> HloSharding: ...
    def is_replicated(self) -> bool: ...
    def is_tiled(self) -> bool: ...
    def num_devices(self) -> int: ...
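A short example of building sharding specs with HloSharding from the API above; the printed values assume the usual semantics of replicated versus manual shardings.

from jaxlib import xla_client

# A fully replicated sharding: every device holds the complete array.
replicated = xla_client.HloSharding.replicate()
print(replicated.is_replicated())  # True
print(replicated.is_tiled())       # False

# A manual sharding: partitioning is handled explicitly by the program (SPMD manual mode).
manual = xla_client.HloSharding.manual()
print(manual.is_replicated())      # False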
Extensible custom call interface for integrating user-defined operations and hardware-specific kernels into XLA computations.

class CustomCallTargetTraits(enum.IntFlag):
    DEFAULT = 0
    COMMAND_BUFFER_COMPATIBLE = 1

def register_custom_call_target(
    name: str,
    fn: Any,
    platform: str = 'cpu',
    api_version: int = 0,
    traits: CustomCallTargetTraits = CustomCallTargetTraits.DEFAULT,
) -> None: ...

def register_custom_call_handler(
    platform: str, handler: CustomCallHandler
) -> None: ...

def register_custom_call_partitioner(
    name: str,
    prop_user_sharding: Callable,
    partition: Callable,
    infer_sharding_from_operands: Callable,
    can_side_effecting_have_replicated_sharding: bool = ...,
    c_api: Any | None = ...,
) -> None: ...

def custom_call_targets(platform: str) -> dict[str, Any]: ...
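A hedged sketch of inspecting and registering custom-call targets with the interface above. Actual registration requires a capsule wrapping a C function pointer produced by an extension module; "my_kernel" and my_capsule below are hypothetical placeholders, so that line is left commented out.

from jaxlib import xla_client

# Inspect the custom-call targets currently registered for the CPU platform.
# The dictionary is empty until kernel libraries (e.g. LAPACK) register themselves.
targets = xla_client.custom_call_targets("cpu")
print(len(targets), "targets registered")

# Registering a target (placeholder capsule, shown for illustration only):
# xla_client.register_custom_call_target("my_kernel", my_capsule, platform="cpu", api_version=1)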
Specialized operations for different hardware platforms including CPU linear algebra, GPU kernels, and sparse matrix operations.

# LAPACK operations
def registrations() -> dict[str, list[tuple[str, Any, int]]]: ...
def prepare_lapack_call(fn_base: str, dtype: Any) -> str: ...

# GPU operations
gpu_linalg.registrations() -> dict[str, list[tuple[str, Any, int]]]
gpu_sparse.registrations() -> dict[str, list[tuple[str, Any, int]]]

# CPU sparse operations
cpu_sparse.registrations() -> dict[str, list[tuple[str, Any, int]]]
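A hedged sketch of inspecting the CPU LAPACK kernel registrations. It assumes the LAPACK entry point is importable as jaxlib.lapack (the module path may differ between jaxlib builds); each registration tuple is interpreted according to the registrations() signature above.

from jaxlib import lapack  # assumed module path for the CPU LAPACK kernels

regs = lapack.registrations()
for platform, targets in regs.items():
    print(platform, "->", len(targets), "registered targets")
    # Each entry is (target_name, capsule_or_callable, api_version)
    # per the registrations() signature above.
    name, fn, api_version = targets[0]
    print("  e.g.", name, "api_version", api_version)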
Dynamic plugin loading and version management for hardware-specific extensions and third-party integrations.

def import_from_plugin(
    plugin_name: str,
    submodule_name: str,
    *,
    check_version: bool = True
) -> ModuleType | None: ...

def check_plugin_version(
    plugin_name: str,
    jaxlib_version: str,
    plugin_version: str
) -> bool: ...

def pjrt_plugin_loaded(plugin_name: str) -> bool: ...

def load_pjrt_plugin_dynamically(
    plugin_name: str, library_path: str
) -> Any: ...

def initialize_pjrt_plugin(plugin_name: str) -> None: ...
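A hedged sketch of loading a PJRT plugin at runtime and creating a client through the C API, using the plugin functions and make_c_api_client listed above. The plugin name and library path are hypothetical placeholders; this only runs against a real PJRT plugin binary.

from jaxlib import xla_client

plugin_name = "my_accelerator"             # hypothetical plugin name
library_path = "/path/to/pjrt_plugin.so"   # hypothetical shared library path

if not xla_client.pjrt_plugin_loaded(plugin_name):
    xla_client.load_pjrt_plugin_dynamically(plugin_name, library_path)
    xla_client.initialize_pjrt_plugin(plugin_name)

# Once the plugin is registered, create a client through the PJRT C API.
client = xla_client.make_c_api_client(plugin_name)
print(client.platform, client.platform_version)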
# Core types
class Shape:
    def __init__(self, s: str): ...
    @staticmethod
    def array_shape(
        type: np.dtype | PrimitiveType,
        dims_seq: Any = ...,
        layout_seq: Any = ...,
        dynamic_dimensions: list[bool] | None = ...,
    ) -> Shape: ...
    def dimensions(self) -> tuple[int, ...]: ...
    def rank(self) -> int: ...
    def is_array(self) -> bool: ...
    def is_tuple(self) -> bool: ...

class PrimitiveType(enum.IntEnum):
    PRED = ...
    S8 = ...
    S16 = ...
    S32 = ...
    S64 = ...
    U8 = ...
    U16 = ...
    U32 = ...
    U64 = ...
    F16 = ...
    F32 = ...
    F64 = ...
    BF16 = ...
    C64 = ...
    C128 = ...

class ArrayCopySemantics(enum.IntEnum):
    ALWAYS_COPY = ...
    REUSE_INPUT = ...
    DONATE_INPUT = ...

class HostBufferSemantics(enum.IntEnum):
    IMMUTABLE_ONLY_DURING_CALL = ...
    IMMUTABLE_UNTIL_TRANSFER_COMPLETES = ...
    ZERO_COPY = ...
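A brief example of building and inspecting an array shape with the core types above; the f32[2,3] shape and printed values follow the Shape methods listed.

import numpy as np
from jaxlib import xla_client

# Describe an f32[2,3] array shape and inspect it.
shape = xla_client.Shape.array_shape(np.dtype(np.float32), (2, 3))
print(shape.dimensions())   # (2, 3)
print(shape.rank())         # 2
print(shape.is_array())     # True
print(shape.is_tuple())     # False

# Element types are also available as PrimitiveType enum values.
print(xla_client.PrimitiveType.F32)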
# Exception types
class XlaRuntimeError(RuntimeError): ...

class GpuLibNotLinkedError(Exception):
    """Raised when the GPU library is not linked."""