XLA library for JAX providing low-level bindings and hardware acceleration support.

Core XLA client functionality providing the main interface for interacting with XLA backends, managing computational resources, and creating clients for different hardware platforms.

Factory functions for creating XLA clients targeting different hardware platforms with platform-specific configuration options.
def make_cpu_client(
    asynchronous: bool = True,
    distributed_client: DistributedRuntimeClient | None = None,
    node_id: int = 0,
    num_nodes: int = 1,
    collectives: CpuCollectives | None = None,
    num_devices: int | None = None,
    get_local_topology_timeout_minutes: int | None = None,
    get_global_topology_timeout_minutes: int | None = None,
    transfer_server_factory: TransferServerInterfaceFactory | None = None,
) -> Client:
    """Create a CPU client for XLA computations.

    Parameters:
    - asynchronous: Whether to use asynchronous execution.
    - distributed_client: Client for distributed computing; None for
      single-process use.
    - node_id: Node identifier in a distributed setup.
    - num_nodes: Total number of nodes participating.
    - collectives: CPU collective-operations interface; None uses the default.
    - num_devices: Number of CPU devices to use; None selects a default.
    - get_local_topology_timeout_minutes: Timeout for obtaining the local
      topology, in minutes.
    - get_global_topology_timeout_minutes: Timeout for obtaining the global
      topology, in minutes.
    - transfer_server_factory: Factory for transfer servers, if any.

    Returns:
      XLA Client configured for CPU execution.
    """
def make_gpu_client(
    distributed_client: DistributedRuntimeClient | None = None,
    node_id: int = 0,
    num_nodes: int = 1,
    platform_name: str | None = None,
    allowed_devices: set[int] | None = None,
    mock: bool | None = None,
    mock_gpu_topology: str | None = None,
) -> Client:
    """Create a GPU client for XLA computations.

    Parameters:
    - distributed_client: Client for distributed computing; None for
      single-process use.
    - node_id: Node identifier in a distributed setup.
    - num_nodes: Total number of nodes participating.
    - platform_name: GPU platform name ('cuda' or 'rocm'); None presumably
      selects a default platform — confirm against implementation.
    - allowed_devices: Set of allowed GPU device IDs; None allows all.
    - mock: Whether to use a mock GPU for testing.
    - mock_gpu_topology: Mock topology specification used when mocking.

    Returns:
      XLA Client configured for GPU execution.
    """
def make_c_api_client(
    plugin_name: str,
    options: dict[str, str | int | list[int] | float | bool] | None = None,
    distributed_client: DistributedRuntimeClient | None = None,
    transfer_server_factory: TransferServerInterfaceFactory | None = None,
) -> Client:
    """Create a client using the PJRT C API for plugins.

    Parameters:
    - plugin_name: Name of the PJRT plugin to load.
    - options: Platform-specific options dictionary; None for defaults.
    - distributed_client: Client for distributed computing, if any.
    - transfer_server_factory: Factory for transfer servers, if any.

    Returns:
      XLA Client using the specified plugin.
    """

# The main Client class providing access to devices, compilation, and
# execution capabilities.
class Client:
    """XLA client for managing devices and executing computations."""

    # Backend identification attributes.
    platform: str
    platform_version: str
    runtime_type: str

    def device_count(self) -> int:
        """Get the total number of devices."""

    def local_device_count(self) -> int:
        """Get the number of local devices."""

    def devices(self) -> list[Device]:
        """Get all available devices."""

    def local_devices(self) -> list[Device]:
        """Get locally available devices."""

    def host_id(self) -> int:
        """Get the host identifier."""

    def process_index(self) -> int:
        """Get this process's index in a distributed setup."""

    def buffer_from_pyval(
        self,
        argument: Any,
        device: Device | None = None,
        force_copy: bool = False,
        host_buffer_semantics: HostBufferSemantics = ...,
    ) -> ArrayImpl:
        """Create a buffer from a Python value.

        Parameters:
        - argument: Python value to convert.
        - device: Target device; None selects the default.
        - force_copy: Force copying even if not necessary.
        - host_buffer_semantics: How to handle the host buffer.

        Returns:
          Array buffer on the specified device.
        """

    def live_buffers(self) -> list[Any]:
        """Get the list of live buffers."""

    def live_arrays(self) -> list[ArrayImpl]:
        """Get the list of live arrays."""

    def live_executables(self) -> list[LoadedExecutable]:
        """Get the list of live executables."""

    def heap_profile(self) -> bytes:
        """Get a heap profile for memory debugging."""

# Thread-level execution control for managing computation streams.
def execution_stream_id(new_id: int):
    """Context manager that overwrites and restores the current thread's
    execution_stream_id.

    Parameters:
    - new_id: New execution stream ID to set for the current thread.

    Returns:
      Context manager that restores the original execution stream ID on exit.

    Usage:
        with execution_stream_id(42):
            # Code executed with stream ID 42
            pass
        # Original stream ID restored
    """

# Utilities for configuring GPU-specific options and plugin parameters.
def generate_pjrt_gpu_plugin_options() -> dict[str, str | int | list[int] | float | bool]:
"""
Generate PjRt GPU plugin options from environment variables.
Reads configuration from environment variables:
- XLA_PYTHON_CLIENT_ALLOCATOR: Memory allocator type
- XLA_CLIENT_MEM_FRACTION: GPU memory fraction to use
- XLA_PYTHON_CLIENT_PREALLOCATE: Whether to preallocate memory
- XLA_PYTHON_CLIENT_COLLECTIVE_MEM_SIZE_MB: Collective memory size
Returns:
Dictionary of plugin options
"""Functions for managing device topology and creating topology descriptions for different platforms.
def make_tfrt_tpu_c_api_device_topology(
    topology_name: str | None = None, **kwargs
) -> DeviceTopology:
    """Create a TPU device topology using the TFRT C API.

    Parameters:
    - topology_name: Name of the topology; None presumably selects a
      default — confirm against implementation.
    - **kwargs: Additional topology options.

    Returns:
      DeviceTopology for TPU devices.
    """
def make_c_api_device_topology(
    c_api: Any, topology_name: str = '', **kwargs
) -> DeviceTopology:
    """Create a device topology using the C API.

    Parameters:
    - c_api: C API interface handle.
    - topology_name: Name of the topology (empty string by default).
    - **kwargs: Additional topology options.

    Returns:
      DeviceTopology for the specified platform.
    """
def get_topology_for_devices(devices: list[Device]) -> DeviceTopology:
    """Get the topology description for a list of devices.

    Parameters:
    - devices: List of devices to describe.

    Returns:
      DeviceTopology describing the device layout.
    """

# Classes and functions for managing distributed computing across multiple
# nodes and processes.
class DistributedRuntimeClient:
    """Client for distributed runtime coordination."""

    def connect(self) -> Any:
        """Connect to the distributed runtime service."""

    def shutdown(self) -> Any:
        """Shut down the distributed runtime client."""

    def blocking_key_value_get(self, key: str, timeout_in_ms: int) -> Any:
        """Blocking get operation for the key-value store.

        Parameters:
        - key: Key to look up.
        - timeout_in_ms: How long to block waiting, in milliseconds.
        """

    def key_value_set(
        self, key: str, value: str, allow_overwrite: bool = False
    ) -> Any:
        """Set operation for the key-value store.

        Parameters:
        - key: Key to write.
        - value: Value to store.
        - allow_overwrite: Whether an existing value may be replaced.
        """

    def wait_at_barrier(
        self,
        barrier_id: str,
        timeout_in_ms: int,
        process_ids: list[int] | None = None,
    ) -> Any:
        """Wait at a named barrier for synchronization.

        Parameters:
        - barrier_id: Name of the barrier.
        - timeout_in_ms: How long to wait, in milliseconds.
        - process_ids: Participating process IDs; None presumably means all
          processes — confirm against implementation.
        """
class DistributedRuntimeService:
    """Service for distributed runtime coordination."""

    def shutdown(self) -> None:
        """Shut down the distributed runtime service."""
def get_distributed_runtime_service(
    address: str,
    num_nodes: int,
    heartbeat_timeout: int | None = None,
    cluster_register_timeout: int | None = None,
    shutdown_timeout: int | None = None,
) -> DistributedRuntimeService:
    """Create a distributed runtime service.

    Parameters:
    - address: Address the service is reachable at.
    - num_nodes: Number of nodes in the cluster.
    - heartbeat_timeout: Heartbeat timeout in milliseconds.
    - cluster_register_timeout: Cluster registration timeout.
    - shutdown_timeout: Shutdown timeout.

    Returns:
      DistributedRuntimeService instance.
    """
def get_distributed_runtime_client(
    address: str,
    node_id: int,
    rpc_timeout: int | None = None,
    init_timeout: int | None = None,
    shutdown_timeout: int | None = None,
    heartbeat_timeout: int | None = None,
    missed_heartbeat_callback: Any | None = None,
    shutdown_on_destruction: bool | None = None,
    use_compression: bool | None = None,
    recoverable: bool | None = None,
) -> DistributedRuntimeClient:
    """Create a distributed runtime client.

    Parameters:
    - address: Service address to connect to.
    - node_id: Unique node identifier.
    - rpc_timeout: RPC timeout in milliseconds.
    - init_timeout: Initialization timeout.
    - shutdown_timeout: Shutdown timeout.
    - heartbeat_timeout: Heartbeat timeout.
    - missed_heartbeat_callback: Callback invoked on missed heartbeats.
    - shutdown_on_destruction: Whether to shut down on destruction.
    - use_compression: Whether to use compression.
    - recoverable: Whether the client is recoverable.

    Returns:
      DistributedRuntimeClient instance.
    """

from jaxlib import xla_client
# Example: create CPU and GPU clients.
# Create a CPU client
cpu_client = xla_client.make_cpu_client(asynchronous=True)
print(f"CPU devices: {cpu_client.local_devices()}")

# Create a GPU client (if available)
try:
    gpu_client = xla_client.make_gpu_client(platform_name='cuda')
    print(f"GPU devices: {gpu_client.local_devices()}")
except Exception as e:
    # Best-effort: report rather than fail when no GPU backend is present.
    print(f"GPU not available: {e}")

from jaxlib import xla_client
# Example: distributed runtime setup across two nodes.
# Start distributed runtime service on coordinator
service = xla_client.get_distributed_runtime_service(
    address="localhost:1234",
    num_nodes=2,
    heartbeat_timeout=60000,
)

# Connect distributed client on each node
dist_client = xla_client.get_distributed_runtime_client(
    address="localhost:1234",
    node_id=0,  # Different for each node
    init_timeout=30000,
)

# Create client with distributed support
client = xla_client.make_cpu_client(
    distributed_client=dist_client,
    node_id=0,
    num_nodes=2,
)

# Install with Tessl CLI:
#   npx tessl i tessl/pypi-jaxlib