XLA library for JAX providing low-level bindings and hardware acceleration support
—
Pending
Does it follow best practices?
Impact
Pending
No eval scenarios have been run
Pending
The risk profile of this skill
JaxLib is the XLA library for JAX, serving as the foundational support library that provides low-level binary components including Python bindings to XLA, the PJRT runtime, and handwritten kernels. It enables JAX's high-performance numerical computing capabilities on various hardware accelerators including CPUs, GPUs, and TPUs, supporting automatic differentiation, just-in-time compilation, vectorization, and distributed computing.
Installation: pip install jaxlib
Dependencies: scipy>=1.12, numpy>=1.26, ml_dtypes>=0.5.0
Core import: import jaxlib
For XLA client operations:
# Basic usage: build a CPU client and turn two host arrays into buffers.
import jaxlib  # required for the jaxlib.__version__ lookup at the end
from jaxlib import xla_client
import numpy as np

# Create a CPU client
client = xla_client.make_cpu_client()


# Create a simple computation
def simple_add(a, b):
    """Return ``a + b``."""
    return a + b


# Convert data to buffers
data_a = np.array([1.0, 2.0, 3.0], dtype=np.float32)
data_b = np.array([4.0, 5.0, 6.0], dtype=np.float32)
buffer_a = client.buffer_from_pyval(data_a)
buffer_b = client.buffer_from_pyval(data_b)

print("JaxLib version:", jaxlib.__version__)
print("Available devices:", client.devices())
print("Platform:", client.platform)

# JaxLib implements a layered architecture with clear separation of concerns:
The design enables JAX to transform and scale numerical programs efficiently across different computing platforms through a consistent interface while allowing low-level optimization and hardware-specific acceleration.
Core XLA client functionality including client creation, device management, compilation, and execution. Provides the main interface for interacting with XLA backends and managing computational resources.
def make_cpu_client(
    asynchronous: bool = True,
    distributed_client: DistributedRuntimeClient | None = None,
    node_id: int = 0,
    num_nodes: int = 1,
    collectives: CpuCollectives | None = None,
    num_devices: int | None = None,
    get_local_topology_timeout_minutes: int | None = None,
    get_global_topology_timeout_minutes: int | None = None,
    transfer_server_factory: TransferServerInterfaceFactory | None = None,
) -> Client:
    """Signature stub for the CPU client factory, annotated to return ``Client``.

    All nine parameters are optional. Only the signature is listed; the body
    is a placeholder.
    """
    ...
def make_gpu_client(
    distributed_client: DistributedRuntimeClient | None = None,
    node_id: int = 0,
    num_nodes: int = 1,
    platform_name: str | None = None,
    allowed_devices: set[int] | None = None,
    mock: bool | None = None,
    mock_gpu_topology: str | None = None,
) -> Client:
    """Signature stub for the GPU client factory, annotated to return ``Client``.

    All seven parameters are optional. Only the signature is listed; the body
    is a placeholder.
    """
    ...
def make_c_api_client(
    plugin_name: str,
    options: dict[str, str | int | list[int] | float | bool] | None = None,
    distributed_client: DistributedRuntimeClient | None = None,
    transfer_server_factory: TransferServerInterfaceFactory | None = None,
) -> Client:
    """Signature stub for the C-API (plugin) client factory.

    ``plugin_name`` is the only required argument; the remaining three
    default to ``None``. Annotated to return ``Client``. The body is a
    placeholder.
    """
    ...


# Device discovery, selection, and memory management across different
# hardware platforms. Handles device topology, memory spaces, and resource
# allocation for optimal performance.
class Device:
    """Stub for a device object: declared attributes plus memory accessors.

    Only annotations and method signatures are listed; there is no
    implementation in this listing.
    """

    id: int
    host_id: int
    process_index: int
    platform: str
    device_kind: str
    client: Client
    local_hardware_id: int | None

    def memory(self, kind: str) -> Memory: ...

    def default_memory(self) -> Memory: ...

    def addressable_memories(self) -> list[Memory]: ...

    # Annotated as optional: the return may be None instead of a stats dict.
    def memory_stats(self) -> dict[str, int] | None: ...
class DeviceList:
    """Stub for a sequence-like collection built from a tuple of ``Device``.

    Declares the container protocol methods (``__len__``, ``__getitem__``,
    ``__iter__``) and a set of read-only properties. Signatures only.
    """

    def __init__(self, device_assignment: tuple[Device, ...]): ...

    def __len__(self) -> int: ...

    def __getitem__(self, index: Any) -> Any: ...

    def __iter__(self) -> Iterator[Device]: ...

    @property
    def is_fully_addressable(self) -> bool: ...

    @property
    def addressable_device_list(self) -> DeviceList: ...

    @property
    def process_indices(self) -> set[int]: ...

    @property
    def default_memory_kind(self) -> str | None: ...

    @property
    def memory_kinds(self) -> tuple[str, ...]: ...

    @property
    def device_kind(self) -> str: ...


# XLA computation compilation, loading, and execution with support for
# distributed computing, sharding, and various execution modes.
class Client:
    """Stub for the client object: three string attributes and two compile entry points.

    ``compile`` is annotated to return ``Executable``; ``compile_and_load``
    additionally accepts ``host_callbacks`` and is annotated to return
    ``LoadedExecutable``. Signatures only.
    """

    platform: str
    platform_version: str
    runtime_type: str

    def compile(
        self,
        computation: str | bytes,
        executable_devices: DeviceList | Sequence[Device],
        compile_options: CompileOptions = ...,
    ) -> Executable: ...

    def compile_and_load(
        self,
        computation: str | bytes,
        executable_devices: DeviceList | Sequence[Device],
        compile_options: CompileOptions = ...,
        host_callbacks: Sequence[Any] = ...,
    ) -> LoadedExecutable: ...
class LoadedExecutable:
    """Stub for a loaded executable: execution methods and introspection accessors.

    Signatures only; no implementation is present in this listing.
    """

    client: Client

    def execute(self, arguments: Sequence[ArrayImpl]) -> list[ArrayImpl]: ...

    def execute_sharded(
        self, arguments: Sequence[list[ArrayImpl]], with_tokens: bool = ...
    ) -> ExecuteResults: ...

    def hlo_modules(self) -> list[HloModule]: ...

    def get_output_memory_kinds(self) -> list[list[str]]: ...

    def get_compiled_memory_stats(self) -> CompiledMemoryStats: ...


# High-performance array operations including device placement, sharding,
# copying, and memory management optimized for different hardware backends.
def batched_device_put(
    aval: Any,
    sharding: Any,
    shards: Sequence[Any],
    devices: list[Device],
    committed: bool = ...,
    force_copy: bool = ...,
    host_buffer_semantics: Any = ...,
) -> ArrayImpl:
    """Signature stub, annotated to return ``ArrayImpl``.

    The first four parameters are required; the last three carry ``...``
    placeholder defaults, so their real default values are not shown here.
    """
    ...
def batched_copy_array_to_devices_with_sharding(
    arrays: Sequence[ArrayImpl],
    devices: Sequence[DeviceList],
    sharding: Sequence[Any],
    array_copy_semantics: Sequence[ArrayCopySemantics],
) -> Sequence[ArrayImpl]:
    """Signature stub taking four sequences and annotated to return a sequence of ``ArrayImpl``.

    All four parameters are required. The body is a placeholder.
    """
    ...
def reorder_shards(
    x: ArrayImpl,
    dst_sharding: Any,
    array_copy_semantics: ArrayCopySemantics,
) -> ArrayImpl:
    """Signature stub: three required parameters, annotated to return ``ArrayImpl``."""
    ...
def batched_block_until_ready(x: Sequence[ArrayImpl]) -> None:
    """Signature stub taking a sequence of ``ArrayImpl``; annotated to return ``None``."""
    ...


# Sharding strategies for distributing computations across multiple devices
# and nodes, including SPMD, GSPMD, and custom sharding patterns.
class Sharding:
    """Base class for the sharding stubs in this listing; declares no members."""


class NamedSharding(Sharding):
    """Sharding stub built from a ``mesh`` and a ``spec``.

    ``memory_kind`` and ``_logical_device_ids`` are keyword-only and default
    to ``None``.
    """

    def __init__(
        self,
        mesh: Any,
        spec: Any,
        *,
        memory_kind: str | None = None,
        _logical_device_ids: tuple[int, ...] | None = None,
    ): ...

    mesh: Any
    spec: Any


class SingleDeviceSharding(Sharding):
    """Sharding stub built from one ``Device``; ``memory_kind`` is keyword-only."""

    def __init__(self, device: Device, *, memory_kind: str | None = None): ...


class GSPMDSharding(Sharding):
    """Sharding stub built from a device sequence and an ``OpSharding`` or ``HloSharding``.

    ``memory_kind`` and ``_device_list`` are keyword-only and default to
    ``None``.
    """

    def __init__(
        self,
        devices: Sequence[Device],
        op_sharding: OpSharding | HloSharding,
        *,
        memory_kind: str | None = None,
        _device_list: DeviceList | None = None,
    ): ...
class HloSharding:
    """Stub for ``HloSharding``: three static constructors and three query methods.

    Signatures only; no implementation is present in this listing.
    """

    @staticmethod
    def from_proto(proto: OpSharding) -> HloSharding: ...

    @staticmethod
    def replicate() -> HloSharding: ...

    @staticmethod
    def manual() -> HloSharding: ...

    def is_replicated(self) -> bool: ...

    def is_tiled(self) -> bool: ...

    def num_devices(self) -> int: ...


# Extensible custom call interface for integrating user-defined operations
# and hardware-specific kernels into XLA computations.
class CustomCallTargetTraits(enum.IntFlag):
    """Integer flag set with two declared members."""

    DEFAULT = 0
    COMMAND_BUFFER_COMPATIBLE = 1
def register_custom_call_target(
    name: str,
    fn: Any,
    platform: str = 'cpu',
    api_version: int = 0,
    traits: CustomCallTargetTraits = CustomCallTargetTraits.DEFAULT,
) -> None:
    """Signature stub, annotated to return ``None``.

    ``name`` and ``fn`` are required. ``platform`` defaults to ``'cpu'``,
    ``api_version`` to ``0`` and ``traits`` to
    ``CustomCallTargetTraits.DEFAULT``.
    """
    ...
def register_custom_call_handler(
    platform: str,
    handler: CustomCallHandler,
) -> None:
    """Signature stub: takes a platform name and a handler; annotated to return ``None``."""
    ...
def register_custom_call_partitioner(
    name: str,
    prop_user_sharding: Callable,
    partition: Callable,
    infer_sharding_from_operands: Callable,
    can_side_effecting_have_replicated_sharding: bool = ...,
    c_api: Any | None = ...,
) -> None:
    """Signature stub taking a name and three callables; annotated to return ``None``.

    The last two parameters carry ``...`` placeholder defaults, so their real
    default values are not shown here.
    """
    ...
def custom_call_targets(platform: str) -> dict[str, Any]:
    """Signature stub: takes a platform name; annotated to return a ``dict`` keyed by ``str``."""
    ...


# Specialized operations for different hardware platforms including CPU
# linear algebra, GPU kernels, and sparse matrix operations.
# LAPACK operations
def registrations() -> dict[str, list[tuple[str, Any, int]]]:
    """Signature stub with no parameters; annotated to return a dict of lists of 3-tuples."""
    ...


def prepare_lapack_call(fn_base: str, dtype: Any) -> str:
    """Signature stub: takes a base function name and a dtype; annotated to return ``str``."""
    ...
# GPU operations
# NOTE(review): a dotted name is not valid in a Python `def`. These lines
# presumably document `registrations()` as exposed by the `gpu_linalg` and
# `gpu_sparse` modules — confirm against the package before treating this
# listing as importable code.
def gpu_linalg.registrations() -> dict[str, list[tuple[str, Any, int]]]: ...
def gpu_sparse.registrations() -> dict[str, list[tuple[str, Any, int]]]: ...
# CPU sparse operations
# NOTE(review): same caveat as above — presumably `registrations()` on the
# `cpu_sparse` module; verify.
def cpu_sparse.registrations() -> dict[str, list[tuple[str, Any, int]]]: ...

# Dynamic plugin loading and version management for hardware-specific
# extensions and third-party integrations.
def import_from_plugin(
    plugin_name: str,
    submodule_name: str,
    *,
    check_version: bool = True,
) -> ModuleType | None:
    """Signature stub, annotated to return a module or ``None``.

    ``check_version`` is keyword-only and defaults to ``True``.
    """
    ...
def check_plugin_version(
    plugin_name: str,
    jaxlib_version: str,
    plugin_version: str,
) -> bool:
    """Signature stub: takes a plugin name and two version strings; annotated to return ``bool``."""
    ...
def pjrt_plugin_loaded(plugin_name: str) -> bool:
    """Signature stub: takes a plugin name; annotated to return ``bool``."""
    ...
def load_pjrt_plugin_dynamically(
    plugin_name: str,
    library_path: str,
) -> Any:
    """Signature stub: takes a plugin name and a library path; return type is ``Any``."""
    ...
def initialize_pjrt_plugin(plugin_name: str) -> None:
    """Signature stub: takes a plugin name; annotated to return ``None``."""
    ...


# Core types
class Shape:
    """Stub for the ``Shape`` type: a string constructor, a static factory, and queries.

    Signatures only; no implementation is present in this listing.
    """

    def __init__(self, s: str): ...

    # The three trailing parameters carry `...` placeholder defaults; their
    # real default values are not shown in this listing.
    @staticmethod
    def array_shape(
        type: np.dtype | PrimitiveType,
        dims_seq: Any = ...,
        layout_seq: Any = ...,
        dynamic_dimensions: list[bool] | None = ...,
    ) -> Shape: ...

    def dimensions(self) -> tuple[int, ...]: ...

    def rank(self) -> int: ...

    def is_array(self) -> bool: ...

    def is_tuple(self) -> bool: ...
class PrimitiveType(enum.IntEnum):
    """Integer enum of primitive element types.

    Member values are written as ``...`` placeholders (stub style); the
    concrete integers are not shown in this listing.
    """

    PRED = ...
    S8 = ...
    S16 = ...
    S32 = ...
    S64 = ...
    U8 = ...
    U16 = ...
    U32 = ...
    U64 = ...
    F16 = ...
    F32 = ...
    F64 = ...
    BF16 = ...
    C64 = ...
    C128 = ...
class ArrayCopySemantics(enum.IntEnum):
    """Integer enum with three copy-semantics members.

    Member values are written as ``...`` placeholders (stub style); the
    concrete integers are not shown in this listing.
    """

    ALWAYS_COPY = ...
    REUSE_INPUT = ...
    DONATE_INPUT = ...
class HostBufferSemantics(enum.IntEnum):
    """Integer enum with three host-buffer-semantics members.

    Member values are written as ``...`` placeholders (stub style); the
    concrete integers are not shown in this listing.
    """

    IMMUTABLE_ONLY_DURING_CALL = ...
    IMMUTABLE_UNTIL_TRANSFER_COMPLETES = ...
    ZERO_COPY = ...
# Exception types
class XlaRuntimeError(RuntimeError):
    """Exception type declared by this listing; a subclass of ``RuntimeError``."""
class GpuLibNotLinkedError(Exception):
    """Raised when the GPU library is not linked."""