XLA library for JAX providing low-level bindings and hardware acceleration support.

XLA computation compilation, loading, and execution with support for distributed computing, sharding, and various execution modes. Provides the core functionality for transforming high-level computations into optimized executable code.
Configuration options for controlling XLA compilation behavior and optimizations.
class CompileOptions:
    """Options for XLA compilation.

    Serializable configuration object passed to ``Client.compile`` /
    ``Client.compile_and_load`` to control replication, partitioning,
    layouts, and executable-build behavior.
    """

    @staticmethod
    def ParseFromString(s: bytes) -> CompileOptions:
        """Parse compilation options from serialized bytes."""

    def __init__(self) -> None: ...

    def SerializeAsString(self) -> bytes:
        """Serialize compilation options to bytes."""

    # Per-argument layout constraints; None presumably lets the compiler
    # choose layouts — TODO confirm against XLA docs.
    argument_layouts: list[Shape] | None
    parameter_is_tupled_arguments: bool
    # Nested build options (debug flags, SPMD settings, etc.).
    executable_build_options: ExecutableBuildOptions
    tuple_arguments: bool
    num_replicas: int
    num_partitions: int
    profile_version: int
    # Explicit device placement; None when unset.
    device_assignment: DeviceAssignment | None
    compile_portable_executable: bool
    # (name, value) pairs overriding compilation-environment options.
    env_option_overrides: list[tuple[str, str]]
class ExecutableBuildOptions:
    """Options for building executables.

    Held on ``CompileOptions.executable_build_options``; carries layout,
    replication, debug, and SPMD-partitioning settings.
    """

    def __init__(self) -> None: ...

    # Layout constraint for the result; None when unconstrained.
    result_layout: Shape | None
    # Serialized FDO (feedback-directed optimization) profile, if any.
    fdo_profile: bytes | None
    num_replicas: int
    num_partitions: int
    # Nested bag of xla_* debug/optimization flags.
    debug_options: DebugOptions
    device_assignment: DeviceAssignment | None
    use_spmd_partitioning: bool
    use_auto_spmd_partitioning: bool
    auto_spmd_partitioning_mesh_shape: list[int]
    auto_spmd_partitioning_mesh_ids: list[int]
    use_shardy_partitioner: bool

    def compilation_environments_from_serialized_proto(
        self, serialized_proto: bytes
    ) -> None:
        """Set compilation environments from serialized proto."""
class DebugOptions:
"""Debug and optimization options for XLA."""
xla_cpu_enable_fast_math: bool
xla_gpu_enable_fast_min_max: bool
xla_backend_optimization_level: int
xla_cpu_enable_xprof_traceme: bool
xla_force_host_platform_device_count: int
xla_dump_to: str
xla_dump_hlo_module_re: str
xla_dump_hlo_as_text: bool
xla_dump_hlo_as_proto: bool
xla_detailed_logging: bool
xla_enable_dumping: boolClient methods for compiling XLA computations into executable forms.
class Client:
"""XLA client compilation interface."""
def compile(
self,
computation: str | bytes,
executable_devices: DeviceList | Sequence[Device],
compile_options: CompileOptions = ...,
) -> Executable:
"""
Compile XLA computation to executable.
Parameters:
- computation: HLO module as string or serialized bytes
- executable_devices: Target devices for execution
- compile_options: Compilation configuration options
Returns:
Compiled Executable object
"""
def compile_and_load(
self,
computation: str | bytes,
executable_devices: DeviceList | Sequence[Device],
compile_options: CompileOptions = ...,
host_callbacks: Sequence[Any] = ...,
) -> LoadedExecutable:
"""
Compile and load XLA computation for execution.
Parameters:
- computation: HLO module as string or serialized bytes
- executable_devices: Target devices for execution
- compile_options: Compilation configuration options
- host_callbacks: Host callback functions
Returns:
LoadedExecutable ready for execution
"""
def serialize_executable(self, executable: LoadedExecutable) -> bytes:
"""
Serialize loaded executable to bytes.
Parameters:
- executable: LoadedExecutable to serialize
Returns:
Serialized executable as bytes
"""
def deserialize_executable(
self,
serialized: bytes,
executable_devices: DeviceList | Sequence[Device],
options: CompileOptions | None,
host_callbacks: Sequence[Any] = ...,
) -> LoadedExecutable:
"""
Deserialize executable from bytes.
Parameters:
- serialized: Serialized executable bytes
- executable_devices: Target devices for execution
- options: Compilation options
- host_callbacks: Host callback functions
Returns:
LoadedExecutable ready for execution
"""Compiled executable representation with metadata and analysis capabilities.
class Executable:
"""Compiled XLA executable."""
def hlo_modules(self) -> list[HloModule]:
"""Get HLO modules comprising this executable."""
def get_output_memory_kinds(self) -> list[list[str]]:
"""Get memory kinds for outputs."""
def get_output_shardings(self) -> list[OpSharding] | None:
"""Get output sharding specifications."""
def get_parameter_shardings(self) -> list[OpSharding] | None:
"""Get parameter sharding specifications."""
def get_parameter_layouts(self) -> list[Layout]:
"""Get parameter data layouts."""
def get_output_layouts(self) -> list[Layout]:
"""Get output data layouts."""
def get_compiled_memory_stats(self) -> CompiledMemoryStats:
"""Get compiled memory usage statistics."""
def serialize(self) -> str:
"""Serialize executable to string."""
def cost_analysis(self) -> dict[str, Any]:
"""Get cost analysis information."""Loaded executable with execution capabilities and resource management.
class LoadedExecutable:
"""Loaded executable ready for execution."""
client: Client
traceback: Traceback
fingerprint: bytes | None
def local_devices(self) -> list[Device]:
"""Get local devices for this executable."""
def size_of_generated_code_in_bytes(self) -> int:
"""Get generated code size in bytes."""
def execute(self, arguments: Sequence[ArrayImpl]) -> list[ArrayImpl]:
"""
Execute on single replica with array arguments.
Parameters:
- arguments: Input arrays for computation
Returns:
List of output arrays
"""
def execute_with_token(
self, arguments: Sequence[ArrayImpl]
) -> tuple[list[ArrayImpl], Token]:
"""
Execute with token for ordering.
Parameters:
- arguments: Input arrays for computation
Returns:
Tuple of (output arrays, execution token)
"""
def execute_sharded(
self, arguments: Sequence[list[ArrayImpl]], with_tokens: bool = False
) -> ExecuteResults:
"""
Execute on multiple replicas with sharded arguments.
Parameters:
- arguments: Sharded input arrays per replica
- with_tokens: Whether to return execution tokens
Returns:
ExecuteResults containing sharded outputs
"""
def hlo_modules(self) -> list[HloModule]:
"""Get HLO modules comprising this executable."""
def get_output_memory_kinds(self) -> list[list[str]]:
"""Get memory kinds for outputs."""
def get_compiled_memory_stats(self) -> CompiledMemoryStats:
"""Get compiled memory usage statistics."""
def get_output_shardings(self) -> list[OpSharding] | None:
"""Get output sharding specifications."""
def get_parameter_shardings(self) -> list[OpSharding] | None:
"""Get parameter sharding specifications."""
def get_parameter_layouts(self) -> list[Layout]:
"""Get parameter data layouts."""
def get_output_layouts(self) -> list[Layout]:
"""Get output data layouts."""
def keep_alive(self) -> None:
"""Keep executable alive in memory."""
def cost_analysis(self) -> dict[str, Any]:
"""Get cost analysis information."""Container for managing execution results from sharded computations.
class ExecuteResults:
"""Results container for sharded execution."""
def __len__(self) -> int:
"""Get number of result sets."""
def disassemble_into_single_device_arrays(self) -> list[list[ArrayImpl]]:
"""
Disassemble results into single-device arrays.
Returns:
List of array lists, one per device
"""
def disassemble_prefix_into_single_device_arrays(
self, n: int
) -> list[list[ArrayImpl]]:
"""
Disassemble first n results into single-device arrays.
Parameters:
- n: Number of results to disassemble
Returns:
List of array lists for first n results
"""
def consume_with_handlers(self, handlers: list[Callable]) -> list[Any]:
"""
Consume results with custom handlers.
Parameters:
- handlers: List of handler functions
Returns:
List of handler results
"""
def consume_token(self) -> ShardedToken:
"""Consume execution token from results."""Token system for managing execution ordering and synchronization.
class Token:
    """Execution token for single-device operations.

    Produced by ``LoadedExecutable.execute_with_token``; used to order
    and synchronize executions.
    """

    def block_until_ready(self):
        """Block until token is ready."""
class ShardedToken:
"""Execution token for sharded operations."""
def block_until_ready(self):
"""Block until all shards are ready."""
def get_token(self, device_id: int):
"""Get token for specific device."""Detailed memory usage information for compiled executables.
class CompiledMemoryStats:
    """Memory usage statistics for compiled executable.

    Obtained via ``get_compiled_memory_stats``; all sizes are in bytes.
    The ``host_*`` fields are presumably host-memory counterparts of the
    device-side fields above them — TODO confirm against XLA docs.
    """

    generated_code_size_in_bytes: int
    argument_size_in_bytes: int
    output_size_in_bytes: int
    alias_size_in_bytes: int
    temp_size_in_bytes: int
    host_generated_code_size_in_bytes: int
    host_argument_size_in_bytes: int
    host_output_size_in_bytes: int
    host_alias_size_in_bytes: int
    host_temp_size_in_bytes: int
    # Serialized BufferAssignment proto for the compiled program.
    serialized_buffer_assignment_proto: bytes

    def __str__(self) -> str:
        """Get string representation of memory stats."""

from jaxlib import xla_client
import numpy as np

# Create a CPU client and pick the first local device.
client = xla_client.make_cpu_client()
device = client.local_devices()[0]

# Simple HLO computation: elementwise add of two f32[3] arrays.
hlo_text = """
HloModule add_module
ENTRY add_computation {
x = f32[3] parameter(0)
y = f32[3] parameter(1)
ROOT add = f32[3] add(x, y)
}
"""

# Compile and load the computation for the chosen device.
executable = client.compile_and_load(
    hlo_text,
    executable_devices=[device],
)

# Prepare host-side input data.
a = np.array([1.0, 2.0, 3.0], dtype=np.float32)
b = np.array([4.0, 5.0, 6.0], dtype=np.float32)

# Transfer inputs into device buffers.
buffer_a = client.buffer_from_pyval(a, device=device)
buffer_b = client.buffer_from_pyval(b, device=device)

# Execute and convert the first output buffer back to NumPy.
result_buffers = executable.execute([buffer_a, buffer_b])
result = np.array(result_buffers[0])
print(f"Result: {result}")  # [5. 7. 9.]

from jaxlib import xla_client
client = xla_client.make_cpu_client()
devices = client.local_devices()

# Create compilation options for a single-replica, single-partition run.
compile_options = xla_client.CompileOptions()
compile_options.num_replicas = 1
compile_options.num_partitions = 1

# Build options with debug settings (optimization level + HLO text dumps).
build_options = xla_client.ExecutableBuildOptions()
build_options.debug_options.xla_backend_optimization_level = 2
build_options.debug_options.xla_dump_hlo_as_text = True
compile_options.executable_build_options = build_options

# Compile with the explicit options.
# NOTE(review): hlo_text is defined in the previous example — this snippet
# assumes it is still in scope.
executable = client.compile_and_load(
    hlo_text,
    executable_devices=devices[:1],
    compile_options=compile_options,
)

# Inspect compilation memory statistics.
stats = executable.get_compiled_memory_stats()
print(f"Generated code size: {stats.generated_code_size_in_bytes} bytes")
print(f"Argument size: {stats.argument_size_in_bytes} bytes")

from jaxlib import xla_client
import numpy as np

client = xla_client.make_cpu_client()
devices = client.local_devices()

if len(devices) >= 2:
    # HLO for an element-wise operation run across two devices.
    hlo_sharded = """
HloModule sharded_add
ENTRY computation {
x = f32[2] parameter(0)
y = f32[2] parameter(1)
ROOT add = f32[2] add(x, y)
}
"""

    # Compile for the first two devices.
    executable = client.compile_and_load(
        hlo_sharded,
        executable_devices=devices[:2],
    )

    # Prepare sharded inputs (one shard per device).
    shard1_a = client.buffer_from_pyval(np.array([1.0, 2.0], dtype=np.float32), devices[0])
    shard1_b = client.buffer_from_pyval(np.array([3.0, 4.0], dtype=np.float32), devices[0])
    shard2_a = client.buffer_from_pyval(np.array([5.0, 6.0], dtype=np.float32), devices[1])
    shard2_b = client.buffer_from_pyval(np.array([7.0, 8.0], dtype=np.float32), devices[1])

    # Execute with sharded inputs; arguments are grouped per device here —
    # verify this grouping matches execute_sharded's expected layout.
    sharded_args = [[shard1_a, shard1_b], [shard2_a, shard2_b]]
    results = executable.execute_sharded(sharded_args)

    # Collect results from each device.
    output_arrays = results.disassemble_into_single_device_arrays()
    for i, device_output in enumerate(output_arrays):
        result = np.array(device_output[0])
        print(f"Device {i} result: {result}")

from jaxlib import xla_client
client = xla_client.make_cpu_client()
device = client.local_devices()[0]
# Compile executable
executable = client.compile_and_load(hlo_text, [device])
# Serialize for storage/transfer
serialized = client.serialize_executable(executable)
print(f"Serialized size: {len(serialized)} bytes")
# Deserialize executable
restored_executable = client.deserialize_executable(
serialized,
executable_devices=[device],
options=None
)
# Use restored executable
result = restored_executable.execute([buffer_a, buffer_b])Install with Tessl CLI
npx tessl i tessl/pypi-jaxlib