CtrlK
Blog · Docs · Log in · Get started
Tessl Logo

tessl/pypi-jaxlib

XLA library for JAX providing low-level bindings and hardware acceleration support

Pending
Overview
Eval results
Files

docs/xla-client.md

XLA Client Operations

Core XLA client functionality providing the main interface for interacting with XLA backends, managing computational resources, and creating clients for different hardware platforms.

Capabilities

Client Creation

Factory functions for creating XLA clients targeting different hardware platforms with platform-specific configuration options.

def make_cpu_client(
    asynchronous: bool = True,
    distributed_client: DistributedRuntimeClient | None = None,
    node_id: int = 0,
    num_nodes: int = 1,
    collectives: CpuCollectives | None = None,
    num_devices: int | None = None,
    get_local_topology_timeout_minutes: int | None = None,
    get_global_topology_timeout_minutes: int | None = None,
    transfer_server_factory: TransferServerInterfaceFactory | None = None,
) -> Client:
    """Build an XLA client that targets CPU execution.

    Args:
        asynchronous: Enable asynchronous execution when True.
        distributed_client: Distributed runtime client for multi-node runs.
        node_id: Identifier of this node within the distributed setup.
        num_nodes: Total node count.
        collectives: Interface implementing CPU collective operations.
        num_devices: How many CPU devices to expose.
        get_local_topology_timeout_minutes: Timeout (minutes) for obtaining
            the local topology.
        get_global_topology_timeout_minutes: Timeout (minutes) for obtaining
            the global topology.
        transfer_server_factory: Factory used to construct transfer servers.

    Returns:
        An XLA Client configured for CPU execution.
    """

def make_gpu_client(
    distributed_client: DistributedRuntimeClient | None = None,
    node_id: int = 0,
    num_nodes: int = 1,
    platform_name: str | None = None,
    allowed_devices: set[int] | None = None,
    mock: bool | None = None,
    mock_gpu_topology: str | None = None,
) -> Client:
    """Build an XLA client that targets GPU execution.

    Args:
        distributed_client: Distributed runtime client for multi-node runs.
        node_id: Identifier of this node within the distributed setup.
        num_nodes: Total node count.
        platform_name: GPU platform to use, 'cuda' or 'rocm'.
        allowed_devices: GPU device IDs this client may use.
        mock: If True, back the client with a mock GPU for testing.
        mock_gpu_topology: Topology specification for the mock GPU.

    Returns:
        An XLA Client configured for GPU execution.
    """

def make_c_api_client(
    plugin_name: str,
    options: dict[str, str | int | list[int] | float | bool] | None = None,
    distributed_client: DistributedRuntimeClient | None = None,
    transfer_server_factory: TransferServerInterfaceFactory | None = None,
) -> Client:
    """Build an XLA client backed by a PJRT C API plugin.

    Args:
        plugin_name: Name of the PJRT plugin to use.
        options: Platform-specific configuration options.
        distributed_client: Distributed runtime client for multi-node runs.
        transfer_server_factory: Factory used to construct transfer servers.

    Returns:
        An XLA Client that delegates to the named plugin.
    """

Client Interface

The main Client class providing access to devices, compilation, and execution capabilities.

class Client:
    """Manages devices and provides compilation/execution access to an XLA backend."""

    platform: str
    platform_version: str
    runtime_type: str

    def device_count(self) -> int:
        """Return the total number of devices."""

    def local_device_count(self) -> int:
        """Return the number of devices attached to this host."""

    def devices(self) -> list[Device]:
        """Return every available device."""

    def local_devices(self) -> list[Device]:
        """Return the devices local to this host."""

    def host_id(self) -> int:
        """Return the identifier of this host."""

    def process_index(self) -> int:
        """Return this process's index within a distributed setup."""

    def buffer_from_pyval(
        self,
        argument: Any,
        device: Device | None = None,
        force_copy: bool = False,
        host_buffer_semantics: HostBufferSemantics = ...,
    ) -> ArrayImpl:
        """Transfer a Python value into a device buffer.

        Args:
            argument: Python value to convert.
            device: Destination device; None selects the default.
            force_copy: Copy the data even when a copy could be avoided.
            host_buffer_semantics: Policy for handling the host-side buffer.

        Returns:
            An array buffer resident on the chosen device.
        """

    def live_buffers(self) -> list[Any]:
        """Return the currently live buffers."""

    def live_arrays(self) -> list[ArrayImpl]:
        """Return the currently live arrays."""

    def live_executables(self) -> list[LoadedExecutable]:
        """Return the currently live loaded executables."""

    def heap_profile(self) -> bytes:
        """Return a heap profile for memory debugging."""

Execution Utilities

Thread-level execution control for managing computation streams.

def execution_stream_id(new_id: int):
    """Temporarily replace the calling thread's execution_stream_id.

    Args:
        new_id: Stream ID to install for the current thread.

    Returns:
        A context manager that restores the original execution stream ID
        when the `with` block exits.

    Example:
        with execution_stream_id(42):
            pass  # code here runs with stream ID 42
        # original stream ID is back in effect
    """

GPU Plugin Options

Utilities for configuring GPU-specific options and plugin parameters.

def generate_pjrt_gpu_plugin_options() -> dict[str, str | int | list[int] | float | bool]:
    """
    Generate PjRt GPU plugin options from environment variables.
    
    Reads configuration from environment variables:
    - XLA_PYTHON_CLIENT_ALLOCATOR: Memory allocator type
    - XLA_CLIENT_MEM_FRACTION: GPU memory fraction to use
    - XLA_PYTHON_CLIENT_PREALLOCATE: Whether to preallocate memory
    - XLA_PYTHON_CLIENT_COLLECTIVE_MEM_SIZE_MB: Collective memory size
    
    Returns:
    Dictionary of plugin options
    """

Topology Management

Functions for managing device topology and creating topology descriptions for different platforms.

def make_tfrt_tpu_c_api_device_topology(
    topology_name: str | None = None, **kwargs
) -> DeviceTopology:
    """Build a TPU device topology via the TFRT C API.

    Args:
        topology_name: Name of the topology to create.
        **kwargs: Extra topology options.

    Returns:
        A DeviceTopology describing the TPU devices.
    """

def make_c_api_device_topology(
    c_api: Any, topology_name: str = '', **kwargs
) -> DeviceTopology:
    """Build a device topology via the C API.

    Args:
        c_api: Handle to the C API interface.
        topology_name: Name of the topology to create.
        **kwargs: Extra topology options.

    Returns:
        A DeviceTopology for the requested platform.
    """

def get_topology_for_devices(devices: list[Device]) -> DeviceTopology:
    """Describe the topology spanned by the given devices.

    Args:
        devices: Devices to include in the description.

    Returns:
        A DeviceTopology describing how the devices are laid out.
    """

Distributed Runtime

Classes and functions for managing distributed computing across multiple nodes and processes.

class DistributedRuntimeClient:
    """Coordinates with the distributed runtime service on behalf of a node."""

    def connect(self) -> Any:
        """Establish the connection to the distributed runtime service."""

    def shutdown(self) -> Any:
        """Tear down this distributed runtime client."""

    def blocking_key_value_get(self, key: str, timeout_in_ms: int) -> Any:
        """Fetch a value from the key-value store, blocking until available."""

    def key_value_set(
        self, key: str, value: str, allow_overwrite: bool = False
    ) -> Any:
        """Store a value in the key-value store."""

    def wait_at_barrier(
        self,
        barrier_id: str,
        timeout_in_ms: int,
        process_ids: list[int] | None = None,
    ) -> Any:
        """Block at the named barrier until the participants synchronize."""

class DistributedRuntimeService:
    """Coordination service for the distributed runtime."""

    def shutdown(self) -> None:
        """Stop the distributed runtime service."""

def get_distributed_runtime_service(
    address: str,
    num_nodes: int,
    heartbeat_timeout: int | None = None,
    cluster_register_timeout: int | None = None,
    shutdown_timeout: int | None = None,
) -> DistributedRuntimeService:
    """Create a distributed runtime service.

    Args:
        address: Address the service is reachable at.
        num_nodes: Number of nodes in the cluster.
        heartbeat_timeout: Heartbeat timeout in milliseconds.
        cluster_register_timeout: Timeout for cluster registration.
        shutdown_timeout: Timeout for shutdown.

    Returns:
        A DistributedRuntimeService instance.
    """

def get_distributed_runtime_client(
    address: str,
    node_id: int,
    rpc_timeout: int | None = None,
    init_timeout: int | None = None,
    shutdown_timeout: int | None = None,
    heartbeat_timeout: int | None = None,
    missed_heartbeat_callback: Any | None = None,
    shutdown_on_destruction: bool | None = None,
    use_compression: bool | None = None,
    recoverable: bool | None = None,
) -> DistributedRuntimeClient:
    """Create a distributed runtime client.

    Args:
        address: Address of the service to connect to.
        node_id: Unique identifier of this node.
        rpc_timeout: RPC timeout in milliseconds.
        init_timeout: Timeout for initialization.
        shutdown_timeout: Timeout for shutdown.
        heartbeat_timeout: Timeout for heartbeats.
        missed_heartbeat_callback: Invoked when heartbeats are missed.
        shutdown_on_destruction: Shut down when the client is destroyed.
        use_compression: Enable compression.
        recoverable: Mark the client as recoverable.

    Returns:
        A DistributedRuntimeClient instance.
    """

Usage Examples

Basic Client Setup

from jaxlib import xla_client

# Create a CPU client; asynchronous=True enables asynchronous execution.
cpu_client = xla_client.make_cpu_client(asynchronous=True)
print(f"CPU devices: {cpu_client.local_devices()}")

# Create a GPU client (if available); platform_name selects 'cuda' or 'rocm'.
try:
    gpu_client = xla_client.make_gpu_client(platform_name='cuda')
    print(f"GPU devices: {gpu_client.local_devices()}")
except Exception as e:
    # Client creation raises when no usable GPU backend is present.
    print(f"GPU not available: {e}")

Distributed Setup

from jaxlib import xla_client

# Start distributed runtime service on coordinator
# (heartbeat_timeout is in milliseconds).
service = xla_client.get_distributed_runtime_service(
    address="localhost:1234",
    num_nodes=2,
    heartbeat_timeout=60000
)

# Connect distributed client on each node
dist_client = xla_client.get_distributed_runtime_client(
    address="localhost:1234",
    node_id=0,  # Different for each node
    init_timeout=30000
)

# Create client with distributed support; node_id/num_nodes must match
# the values used by the distributed runtime service above.
client = xla_client.make_cpu_client(
    distributed_client=dist_client,
    node_id=0,
    num_nodes=2
)

Install with Tessl CLI

npx tessl i tessl/pypi-jaxlib

docs

array-operations.md

compilation-execution.md

custom-operations.md

device-management.md

hardware-operations.md

index.md

plugin-system.md

sharding.md

xla-client.md

tile.json