tessl/pypi-jaxlib

XLA library for JAX providing low-level bindings and hardware acceleration support

Workspace: tessl
Visibility: Public
Describes: pkg:pypi/jaxlib@0.7.x

To install, run

npx @tessl/cli install tessl/pypi-jaxlib@0.7.0


JaxLib

JaxLib is the XLA support library for JAX. It supplies JAX's low-level binary components, including Python bindings to XLA, the PJRT runtime, and handwritten kernels. These components give JAX its high-performance numerical computing on CPUs, GPUs, and TPUs, and underpin automatic differentiation, just-in-time compilation, vectorization, and distributed computing.

Package Information

  • Package Name: jaxlib
  • Language: Python
  • Installation: pip install jaxlib
  • Dependencies: scipy>=1.12, numpy>=1.26, ml_dtypes>=0.5.0
  • Hardware Support: CPU, GPU (CUDA/ROCm), TPU

Core Imports

import jaxlib

For XLA client operations:

from jaxlib import xla_client

Basic Usage

import jaxlib
from jaxlib import xla_client
import numpy as np

# Create a CPU client
client = xla_client.make_cpu_client()

# A plain Python function; JAX (not jaxlib) traces and compiles functions
# like this into the XLA computations that jaxlib then executes.
def simple_add(a, b):
    return a + b

# Transfer host NumPy arrays to device buffers
data_a = np.array([1.0, 2.0, 3.0], dtype=np.float32)
data_b = np.array([4.0, 5.0, 6.0], dtype=np.float32)

buffer_a = client.buffer_from_pyval(data_a)
buffer_b = client.buffer_from_pyval(data_b)

# Inspect the runtime
print("JaxLib version:", jaxlib.__version__)
print("Available devices:", client.devices())
print("Platform:", client.platform)

Architecture

JaxLib implements a layered architecture with clear separation of concerns:

  • XLA Client Layer: High-level Python API for XLA operations and compilation
  • PJRT Runtime: Platform-specific runtime for executing compiled programs
  • Device Backends: Hardware-specific implementations (CPU, GPU, TPU)
  • Custom Operations: Extensible system for user-defined operations
  • Distributed Computing: Multi-node execution and communication primitives

The design enables JAX to transform and scale numerical programs efficiently across different computing platforms through a consistent interface while allowing low-level optimization and hardware-specific acceleration.

Capabilities

XLA Client Operations

Core XLA client functionality including client creation, device management, compilation, and execution. Provides the main interface for interacting with XLA backends and managing computational resources.

def make_cpu_client(
    asynchronous: bool = True,
    distributed_client: DistributedRuntimeClient | None = None,
    node_id: int = 0,
    num_nodes: int = 1,
    collectives: CpuCollectives | None = None,
    num_devices: int | None = None,
    get_local_topology_timeout_minutes: int | None = None,
    get_global_topology_timeout_minutes: int | None = None,
    transfer_server_factory: TransferServerInterfaceFactory | None = None,
) -> Client: ...

def make_gpu_client(
    distributed_client: DistributedRuntimeClient | None = None,
    node_id: int = 0,
    num_nodes: int = 1,
    platform_name: str | None = None,
    allowed_devices: set[int] | None = None,
    mock: bool | None = None,
    mock_gpu_topology: str | None = None,
) -> Client: ...

def make_c_api_client(
    plugin_name: str,
    options: dict[str, str | int | list[int] | float | bool] | None = None,
    distributed_client: DistributedRuntimeClient | None = None,
    transfer_server_factory: TransferServerInterfaceFactory | None = None,
) -> Client: ...

XLA Client
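
A minimal usage sketch of these factories, assuming a CPU-only environment; the num_devices option shown above is used to expose multiple virtual CPU devices for testing.

from jaxlib import xla_client

# Create a CPU client with two virtual devices (handy for exercising
# multi-device code paths without real accelerators).
client = xla_client.make_cpu_client(num_devices=2)

for device in client.devices():
    print(device.id, device.platform, device.device_kind)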

Device and Memory Management

Device discovery, selection, and memory management across different hardware platforms. Handles device topology, memory spaces, and resource allocation for optimal performance.

class Device:
    id: int
    host_id: int
    process_index: int
    platform: str
    device_kind: str
    client: Client
    local_hardware_id: int | None
    
    def memory(self, kind: str) -> Memory: ...
    def default_memory(self) -> Memory: ...
    def addressable_memories(self) -> list[Memory]: ...
    def memory_stats(self) -> dict[str, int] | None: ...

class DeviceList:
    def __init__(self, device_assignment: tuple[Device, ...]): ...
    def __len__(self) -> int: ...
    def __getitem__(self, index: Any) -> Any: ...
    def __iter__(self) -> Iterator[Device]: ...
    
    @property
    def is_fully_addressable(self) -> bool: ...
    @property
    def addressable_device_list(self) -> DeviceList: ...
    @property
    def process_indices(self) -> set[int]: ...
    @property
    def default_memory_kind(self) -> str | None: ...
    @property
    def memory_kinds(self) -> tuple[str, ...]: ...
    @property
    def device_kind(self) -> str: ...

Device Management
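
A short sketch of device and memory introspection with the classes above, assuming a CPU client; available memory kinds differ by backend, and memory_stats() may return None where the runtime does not report statistics.

from jaxlib import xla_client

client = xla_client.make_cpu_client()

for device in client.devices():
    print("device", device.id, device.device_kind)
    # Default memory space plus any other memories this process can address
    print("  default memory:", device.default_memory())
    print("  addressable memories:", device.addressable_memories())
    # Allocation statistics, or None on backends that do not expose them
    print("  memory stats:", device.memory_stats())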

Compilation and Execution

XLA computation compilation, loading, and execution with support for distributed computing, sharding, and various execution modes.

class Client:
    platform: str
    platform_version: str
    runtime_type: str
    
    def compile(
        self,
        computation: str | bytes,
        executable_devices: DeviceList | Sequence[Device],
        compile_options: CompileOptions = ...,
    ) -> Executable: ...
    
    def compile_and_load(
        self,
        computation: str | bytes,
        executable_devices: DeviceList | Sequence[Device],
        compile_options: CompileOptions = ...,
        host_callbacks: Sequence[Any] = ...,
    ) -> LoadedExecutable: ...

class LoadedExecutable:
    client: Client
    
    def execute(self, arguments: Sequence[ArrayImpl]) -> list[ArrayImpl]: ...
    def execute_sharded(
        self, arguments: Sequence[list[ArrayImpl]], with_tokens: bool = ...
    ) -> ExecuteResults: ...
    def hlo_modules(self) -> list[HloModule]: ...
    def get_output_memory_kinds(self) -> list[list[str]]: ...
    def get_compiled_memory_stats(self) -> CompiledMemoryStats: ...

Compilation and Execution
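
An end-to-end sketch of compile_and_load and execute based on the signatures above. It assumes the client accepts a StableHLO module in textual form and that this jaxlib build uses the executable_devices parameter shown here; treat it as illustrative rather than a version-exact recipe.

import numpy as np
from jaxlib import xla_client

client = xla_client.make_cpu_client()

# StableHLO module adding two f32[3] tensors (textual MLIR is assumed to be
# accepted; serialized bytecode is another common input format).
mlir_module = """
module @add {
  func.func public @main(%a: tensor<3xf32>, %b: tensor<3xf32>) -> tensor<3xf32> {
    %0 = stablehlo.add %a, %b : tensor<3xf32>
    return %0 : tensor<3xf32>
  }
}
"""

executable = client.compile_and_load(
    mlir_module,
    executable_devices=client.devices()[:1],
    compile_options=xla_client.CompileOptions(),
)

a = client.buffer_from_pyval(np.array([1.0, 2.0, 3.0], dtype=np.float32))
b = client.buffer_from_pyval(np.array([4.0, 5.0, 6.0], dtype=np.float32))

(result,) = executable.execute([a, b])
print(np.asarray(result))  # expected: [5. 7. 9.]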

Array and Buffer Operations

High-performance array operations including device placement, sharding, copying, and memory management optimized for different hardware backends.

def batched_device_put(
    aval: Any,
    sharding: Any,
    shards: Sequence[Any],
    devices: list[Device],
    committed: bool = ...,
    force_copy: bool = ...,
    host_buffer_semantics: Any = ...,
) -> ArrayImpl: ...

def batched_copy_array_to_devices_with_sharding(
    arrays: Sequence[ArrayImpl],
    devices: Sequence[DeviceList],
    sharding: Sequence[Any],
    array_copy_semantics: Sequence[ArrayCopySemantics],
) -> Sequence[ArrayImpl]: ...

def reorder_shards(
    x: ArrayImpl,
    dst_sharding: Any,
    array_copy_semantics: ArrayCopySemantics,
) -> ArrayImpl: ...

def batched_block_until_ready(x: Sequence[ArrayImpl]) -> None: ...

Array Operations
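
A small sketch of the synchronization helper above. It assumes JAX is installed on top of jaxlib and that jax.Array values are the ArrayImpl instances these functions accept; the device-put and copy helpers are normally invoked by JAX internals rather than called directly.

import jax.numpy as jnp
from jaxlib import xla_client

x = jnp.arange(8.0)
y = x * 2.0  # dispatched asynchronously on most backends

# Block until both arrays are fully materialized on their devices.
xla_client.batched_block_until_ready([x, y])
print(y)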

Sharding and Distribution

Sharding strategies for distributing computations across multiple devices and nodes, including SPMD, GSPMD, and custom sharding patterns.

class Sharding: ...

class NamedSharding(Sharding):
    def __init__(
        self,
        mesh: Any,
        spec: Any,
        *,
        memory_kind: str | None = None,
        _logical_device_ids: tuple[int, ...] | None = None,
    ): ...
    mesh: Any
    spec: Any

class SingleDeviceSharding(Sharding):
    def __init__(self, device: Device, *, memory_kind: str | None = None): ...

class GSPMDSharding(Sharding):
    def __init__(
        self,
        devices: Sequence[Device],
        op_sharding: OpSharding | HloSharding,
        *,
        memory_kind: str | None = None,
        _device_list: DeviceList | None = None,
    ): ...

class HloSharding:
    @staticmethod
    def from_proto(proto: OpSharding) -> HloSharding: ...
    @staticmethod
    def replicate() -> HloSharding: ...
    @staticmethod
    def manual() -> HloSharding: ...
    
    def is_replicated(self) -> bool: ...
    def is_tiled(self) -> bool: ...
    def num_devices(self) -> int: ...

Sharding
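
A brief sketch of constructing sharding objects from this API, assuming a single-host CPU client and assuming these classes are re-exported from xla_client (in some versions they live only in the lower-level extension module); it only inspects the objects rather than running a sharded computation.

from jaxlib import xla_client

client = xla_client.make_cpu_client()
device = client.devices()[0]

# Pin data to a single device
single = xla_client.SingleDeviceSharding(device)

# A fully replicated HLO sharding annotation
replicated = xla_client.HloSharding.replicate()
print(replicated.is_replicated())  # True
print(replicated.is_tiled())       # False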

Custom Operations

Extensible custom call interface for integrating user-defined operations and hardware-specific kernels into XLA computations.

class CustomCallTargetTraits(enum.IntFlag):
    DEFAULT = 0
    COMMAND_BUFFER_COMPATIBLE = 1

def register_custom_call_target(
    name: str,
    fn: Any,
    platform: str = 'cpu',
    api_version: int = 0,
    traits: CustomCallTargetTraits = CustomCallTargetTraits.DEFAULT,
) -> None: ...

def register_custom_call_handler(
    platform: str, handler: CustomCallHandler
) -> None: ...

def register_custom_call_partitioner(
    name: str,
    prop_user_sharding: Callable,
    partition: Callable,
    infer_sharding_from_operands: Callable,
    can_side_effecting_have_replicated_sharding: bool = ...,
    c_api: Any | None = ...,
) -> None: ...

def custom_call_targets(platform: str) -> dict[str, Any]: ...

Custom Operations
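
A minimal sketch that only inspects the custom-call registry; a real registration passes a PyCapsule wrapping a C/C++ function pointer to register_custom_call_target, which is outside the scope of a pure-Python example. It assumes custom_call_targets is exposed on xla_client as listed above.

from jaxlib import xla_client

# List the custom call targets currently registered for the CPU platform.
targets = xla_client.custom_call_targets("cpu")
for name in sorted(targets):
    print(name)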

Hardware-Specific Operations

Specialized operations for different hardware platforms including CPU linear algebra, GPU kernels, and sparse matrix operations.

# LAPACK operations (lapack module)
def registrations() -> dict[str, list[tuple[str, Any, int]]]: ...
def prepare_lapack_call(fn_base: str, dtype: Any) -> str: ...

# GPU linear algebra (gpu_linalg module)
def registrations() -> dict[str, list[tuple[str, Any, int]]]: ...

# GPU sparse operations (gpu_sparse module)
def registrations() -> dict[str, list[tuple[str, Any, int]]]: ...

# CPU sparse operations (cpu_sparse module)
def registrations() -> dict[str, list[tuple[str, Any, int]]]: ...

Hardware-Specific Operations
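
A hedged sketch of inspecting the LAPACK kernel registrations; the import path jaxlib.lapack is an assumption and has moved between releases.

from jaxlib import lapack  # module path is an assumption; it varies by release

# registrations() maps a platform name to (target_name, capsule, api_version)
# tuples describing the handwritten kernels XLA can call back into.
for platform, entries in lapack.registrations().items():
    for target_name, _capsule, api_version in entries:
        print(platform, target_name, api_version)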

Plugin System

Dynamic plugin loading and version management for hardware-specific extensions and third-party integrations.

def import_from_plugin(
    plugin_name: str, 
    submodule_name: str, 
    *, 
    check_version: bool = True
) -> ModuleType | None: ...

def check_plugin_version(
    plugin_name: str, 
    jaxlib_version: str, 
    plugin_version: str
) -> bool: ...

def pjrt_plugin_loaded(plugin_name: str) -> bool: ...

def load_pjrt_plugin_dynamically(
    plugin_name: str, library_path: str
) -> Any: ...

def initialize_pjrt_plugin(plugin_name: str) -> None: ...

Plugin System
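
A short sketch of the plugin helpers; the plugin name "cuda", the submodule name, and the jaxlib.plugin_support import path are illustrative assumptions, and import_from_plugin returns None when the plugin is not available.

from jaxlib import xla_client
from jaxlib.plugin_support import import_from_plugin  # import path is an assumption

# Has a PJRT plugin named "cuda" already been loaded into this process?
print(xla_client.pjrt_plugin_loaded("cuda"))

# Try to pull a submodule out of a hypothetical GPU plugin package;
# returns None if the plugin (or the submodule) is unavailable.
maybe_module = import_from_plugin("cuda", "_linalg")
print(maybe_module)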

Types

# Core types
class Shape:
    def __init__(self, s: str): ...
    @staticmethod
    def array_shape(
        type: np.dtype | PrimitiveType,
        dims_seq: Any = ...,
        layout_seq: Any = ...,
        dynamic_dimensions: list[bool] | None = ...,
    ) -> Shape: ...
    
    def dimensions(self) -> tuple[int, ...]: ...
    def rank(self) -> int: ...
    def is_array(self) -> bool: ...
    def is_tuple(self) -> bool: ...

class PrimitiveType(enum.IntEnum):
    PRED = ...
    S8 = ...
    S16 = ...
    S32 = ...
    S64 = ...
    U8 = ...
    U16 = ...
    U32 = ...
    U64 = ...
    F16 = ...
    F32 = ...
    F64 = ...
    BF16 = ...
    C64 = ...
    C128 = ...

class ArrayCopySemantics(enum.IntEnum):
    ALWAYS_COPY = ...
    REUSE_INPUT = ...
    DONATE_INPUT = ...

class HostBufferSemantics(enum.IntEnum):
    IMMUTABLE_ONLY_DURING_CALL = ...
    IMMUTABLE_UNTIL_TRANSFER_COMPLETES = ...  
    ZERO_COPY = ...

# Exception types
class XlaRuntimeError(RuntimeError): ...

class GpuLibNotLinkedError(Exception):
    """Raised when the GPU library is not linked."""