CtrlK
Blog | Docs | Log in | Get started
Tessl Logo

tessl/pypi-jaxlib

XLA library for JAX providing low-level bindings and hardware acceleration support

Pending
Overview
Eval results
Files

docs/compilation-execution.md

Compilation and Execution

XLA computation compilation, loading, and execution with support for distributed computing, sharding, and various execution modes. Provides the core functionality for transforming high-level computations into optimized executable code.

Capabilities

Compilation Options

Configuration options for controlling XLA compilation behavior and optimizations.

class CompileOptions:
    """Options for XLA compilation.

    Interface listing only: method bodies are omitted, so this documents
    signatures and attribute types rather than behavior. Instances are
    accepted as the ``compile_options`` argument of ``Client.compile`` and
    ``Client.compile_and_load`` and as the ``options`` argument of
    ``Client.deserialize_executable``.
    """
    
    @staticmethod
    def ParseFromString(s: bytes) -> CompileOptions:
        """Parse compilation options from serialized bytes.

        Parameters:
        - s: Serialized options (presumably the output of
          SerializeAsString -- confirm)

        Returns:
        CompileOptions built from the serialized form
        """
    
    # Takes no arguments; the fields below are set by attribute assignment.
    def __init__(self) -> None: ...
    
    def SerializeAsString(self) -> bytes:
        """Serialize compilation options to bytes.

        Returns:
        Serialized options as bytes
        """
    
    # May be None.
    argument_layouts: list[Shape] | None
    # NOTE(review): two tuple-related flags are listed
    # (parameter_is_tupled_arguments and tuple_arguments); presumably one is
    # an alias of the other -- confirm.
    parameter_is_tupled_arguments: bool
    # Nested build options; see ExecutableBuildOptions.
    executable_build_options: ExecutableBuildOptions
    tuple_arguments: bool
    # num_replicas, num_partitions and device_assignment are also declared on
    # ExecutableBuildOptions. NOTE(review): how the two sets of fields relate
    # is not visible in this listing -- confirm.
    num_replicas: int
    num_partitions: int
    profile_version: int
    # May be None.
    device_assignment: DeviceAssignment | None
    compile_portable_executable: bool
    # List of (str, str) pairs; presumably (option name, value) -- confirm.
    env_option_overrides: list[tuple[str, str]]

class ExecutableBuildOptions:
    """Options for building executables.

    Interface listing only: method bodies are omitted. An instance is held
    by ``CompileOptions.executable_build_options``.
    """
    
    # Takes no arguments; the fields below are set by attribute assignment.
    def __init__(self) -> None: ...
    
    # May be None.
    result_layout: Shape | None
    # May be None.
    fdo_profile: bytes | None
    # num_replicas, num_partitions and device_assignment are also declared on
    # CompileOptions. NOTE(review): how the two sets of fields relate is not
    # visible in this listing -- confirm.
    num_replicas: int
    num_partitions: int
    # Nested debug flags; see DebugOptions.
    debug_options: DebugOptions
    # May be None.
    device_assignment: DeviceAssignment | None
    # Partitioner selection flags and the mesh description used with
    # auto-SPMD partitioning (grouped by name prefix).
    use_spmd_partitioning: bool
    use_auto_spmd_partitioning: bool
    auto_spmd_partitioning_mesh_shape: list[int]
    auto_spmd_partitioning_mesh_ids: list[int]
    use_shardy_partitioner: bool
    
    def compilation_environments_from_serialized_proto(
        self, serialized_proto: bytes
    ) -> None:
        """Set compilation environments from serialized proto.

        Parameters:
        - serialized_proto: Serialized proto bytes describing the
          compilation environments

        Returns:
        None
        """

class DebugOptions:
    """Debug and optimization options for XLA.

    Attribute-only listing; an instance is held by
    ``ExecutableBuildOptions.debug_options``. Fields are grouped below by
    name prefix; their exact effect is not described in this listing.
    """
    
    # Backend-specific flags (xla_cpu_* / xla_gpu_*) and optimization level.
    xla_cpu_enable_fast_math: bool
    xla_gpu_enable_fast_min_max: bool
    xla_backend_optimization_level: int
    xla_cpu_enable_xprof_traceme: bool
    xla_force_host_platform_device_count: int
    # Dump controls (xla_dump_*). NOTE(review): xla_dump_to is presumably a
    # directory path and xla_dump_hlo_module_re a regular expression --
    # confirm.
    xla_dump_to: str
    xla_dump_hlo_module_re: str
    xla_dump_hlo_as_text: bool
    xla_dump_hlo_as_proto: bool
    # Logging / dumping switches.
    xla_detailed_logging: bool
    xla_enable_dumping: bool

Compilation Interface

Client methods for compiling XLA computations into executable forms.

class Client:
    """XLA client compilation interface.

    Interface listing only: method bodies are omitted. A default written as
    ``= ...`` means the parameter is optional and its default value is
    elided in this listing.

    ``compile`` returns an ``Executable``; ``compile_and_load`` and
    ``deserialize_executable`` return a ``LoadedExecutable``, which is the
    type that exposes the ``execute*`` methods.
    """
    
    def compile(
        self,
        computation: str | bytes,
        executable_devices: DeviceList | Sequence[Device],
        compile_options: CompileOptions = ...,
    ) -> Executable:
        """
        Compile XLA computation to executable.
        
        Parameters:
        - computation: HLO module as string or serialized bytes
        - executable_devices: Target devices for execution
        - compile_options: Compilation configuration options (optional)
        
        Returns:
        Compiled Executable object
        """
    
    def compile_and_load(
        self,
        computation: str | bytes,
        executable_devices: DeviceList | Sequence[Device],
        compile_options: CompileOptions = ...,
        host_callbacks: Sequence[Any] = ...,
    ) -> LoadedExecutable:
        """
        Compile and load XLA computation for execution.
        
        Parameters:
        - computation: HLO module as string or serialized bytes
        - executable_devices: Target devices for execution
        - compile_options: Compilation configuration options (optional)
        - host_callbacks: Host callback functions (optional)
        
        Returns:
        LoadedExecutable ready for execution
        """
    
    def serialize_executable(self, executable: LoadedExecutable) -> bytes:
        """
        Serialize loaded executable to bytes.
        
        Parameters:
        - executable: LoadedExecutable to serialize
        
        Returns:
        Serialized executable as bytes (the input type accepted by
        deserialize_executable)
        """
    
    def deserialize_executable(
        self,
        serialized: bytes,
        executable_devices: DeviceList | Sequence[Device],
        options: CompileOptions | None,
        host_callbacks: Sequence[Any] = ...,
    ) -> LoadedExecutable:
        """
        Deserialize executable from bytes.
        
        Parameters:
        - serialized: Serialized executable bytes
        - executable_devices: Target devices for execution
        - options: Compilation options, or None. Unlike compile_options on
          the compile methods this parameter has no default, so it must
          always be passed.
        - host_callbacks: Host callback functions (optional)
        
        Returns:
        LoadedExecutable ready for execution
        """

Executable Interface

Compiled executable representation with metadata and analysis capabilities.

class Executable:
    """Compiled XLA executable.

    Interface listing only: method bodies are omitted. This is the return
    type of ``Client.compile``. It exposes introspection methods only; the
    ``execute*`` methods are declared on ``LoadedExecutable``.
    """
    
    def hlo_modules(self) -> list[HloModule]:
        """Get HLO modules comprising this executable."""
    
    def get_output_memory_kinds(self) -> list[list[str]]:
        """Get memory kinds for outputs.

        NOTE(review): the meaning of the two nesting levels is not visible
        in this listing -- confirm.
        """
    
    def get_output_shardings(self) -> list[OpSharding] | None:
        """Get output sharding specifications (may be None)."""
    
    def get_parameter_shardings(self) -> list[OpSharding] | None:
        """Get parameter sharding specifications (may be None)."""
    
    def get_parameter_layouts(self) -> list[Layout]:
        """Get parameter data layouts."""
    
    def get_output_layouts(self) -> list[Layout]:
        """Get output data layouts."""
    
    def get_compiled_memory_stats(self) -> CompiledMemoryStats:
        """Get compiled memory usage statistics as a CompiledMemoryStats."""
    
    def serialize(self) -> str:
        """Serialize executable to string.

        Note that the annotated return type is ``str``, whereas
        ``Client.serialize_executable`` is annotated as returning ``bytes``.
        """
    
    def cost_analysis(self) -> dict[str, Any]:
        """Get cost analysis information as a dict keyed by string."""

Execution Interface

Loaded executable with execution capabilities and resource management.

class LoadedExecutable:
    """Loaded executable ready for execution.

    Interface listing only: method bodies are omitted. This is the return
    type of ``Client.compile_and_load`` and ``Client.deserialize_executable``.
    Besides the ``execute*`` methods it declares the same introspection
    methods as ``Executable`` (hlo_modules, get_output_memory_kinds,
    get_compiled_memory_stats, get_*_shardings, get_*_layouts,
    cost_analysis), but no ``serialize`` method; serialization goes through
    ``Client.serialize_executable``.
    """
    
    # NOTE(review): presumably the Client that loaded this executable --
    # confirm.
    client: Client
    traceback: Traceback
    # May be None.
    fingerprint: bytes | None
    
    def local_devices(self) -> list[Device]:
        """Get local devices for this executable."""
    
    def size_of_generated_code_in_bytes(self) -> int:
        """Get generated code size in bytes."""
    
    def execute(self, arguments: Sequence[ArrayImpl]) -> list[ArrayImpl]:
        """
        Execute on single replica with array arguments.
        
        Parameters:
        - arguments: Input arrays for computation
        
        Returns:
        List of output arrays
        """
    
    def execute_with_token(
        self, arguments: Sequence[ArrayImpl]
    ) -> tuple[list[ArrayImpl], Token]:
        """
        Execute with token for ordering.
        
        Parameters:
        - arguments: Input arrays for computation
        
        Returns:
        Tuple of (output arrays, execution token)
        """
    
    def execute_sharded(
        self, arguments: Sequence[list[ArrayImpl]], with_tokens: bool = False
    ) -> ExecuteResults:
        """
        Execute on multiple replicas with sharded arguments.
        
        Parameters:
        - arguments: Sharded input arrays, as a sequence of lists.
          NOTE(review): the signature does not show the nesting order;
          presumably the outer sequence is indexed by argument and each
          inner list holds that argument's per-device shards -- confirm.
        - with_tokens: Whether to return execution tokens (default False)
        
        Returns:
        ExecuteResults containing sharded outputs
        """
    
    def hlo_modules(self) -> list[HloModule]:
        """Get HLO modules comprising this executable."""
    
    def get_output_memory_kinds(self) -> list[list[str]]:
        """Get memory kinds for outputs."""
    
    def get_compiled_memory_stats(self) -> CompiledMemoryStats:
        """Get compiled memory usage statistics as a CompiledMemoryStats."""
    
    def get_output_shardings(self) -> list[OpSharding] | None:
        """Get output sharding specifications (may be None)."""
    
    def get_parameter_shardings(self) -> list[OpSharding] | None:
        """Get parameter sharding specifications (may be None)."""
    
    def get_parameter_layouts(self) -> list[Layout]:
        """Get parameter data layouts."""
    
    def get_output_layouts(self) -> list[Layout]:
        """Get output data layouts."""
    
    def keep_alive(self) -> None:
        """Keep executable alive in memory."""
    
    def cost_analysis(self) -> dict[str, Any]:
        """Get cost analysis information as a dict keyed by string."""

Execution Results

Container for managing execution results from sharded computations.

class ExecuteResults:
    """Results container for sharded execution.

    Interface listing only: method bodies are omitted. This is the return
    type of ``LoadedExecutable.execute_sharded``.
    """
    
    def __len__(self) -> int:
        """Get number of result sets."""
    
    def disassemble_into_single_device_arrays(self) -> list[list[ArrayImpl]]:
        """
        Disassemble results into single-device arrays.
        
        Returns:
        List of array lists, one per device.
        NOTE(review): the signature does not show whether the outer list is
        indexed by output or by device -- confirm before relying on the
        nesting order.
        """
    
    def disassemble_prefix_into_single_device_arrays(
        self, n: int
    ) -> list[list[ArrayImpl]]:
        """
        Disassemble first n results into single-device arrays.
        
        Parameters:
        - n: Number of results to disassemble
        
        Returns:
        List of array lists for first n results
        """
    
    def consume_with_handlers(self, handlers: list[Callable]) -> list[Any]:
        """
        Consume results with custom handlers.
        
        Parameters:
        - handlers: List of handler functions
          (NOTE(review): presumably one handler per result -- confirm)
        
        Returns:
        List of handler results
        """
    
    def consume_token(self) -> ShardedToken:
        """Consume execution token from results.

        Returns:
        ShardedToken for the sharded execution
        """

Execution Tokens

Token system for managing execution ordering and synchronization.

class Token:
    """Execution token for single-device operations.

    Interface listing only. ``LoadedExecutable.execute_with_token`` returns
    a Token as the second element of its result tuple.
    """
    
    def block_until_ready(self):
        """Block until token is ready.

        NOTE(review): no return annotation is given in this listing;
        presumably returns None -- confirm.
        """

class ShardedToken:
    """Execution token for sharded operations.

    Interface listing only. ``ExecuteResults.consume_token`` returns a
    ShardedToken.
    """
    
    def block_until_ready(self):
        """Block until all shards are ready.

        NOTE(review): no return annotation is given in this listing;
        presumably returns None -- confirm.
        """
    
    def get_token(self, device_id: int):
        """Get token for specific device.

        Parameters:
        - device_id: Integer identifying the device

        NOTE(review): no return annotation is given in this listing;
        presumably returns a Token -- confirm.
        """

Memory Statistics

Detailed memory usage information for compiled executables.

class CompiledMemoryStats:
    """Memory usage statistics for compiled executable.

    Interface listing only. This is the return type of
    ``get_compiled_memory_stats`` on both ``Executable`` and
    ``LoadedExecutable``. All size fields are integers named ``*_in_bytes``.
    """
    
    # Sizes without the host_ prefix (NOTE(review): presumably device-side
    # memory -- confirm).
    generated_code_size_in_bytes: int
    argument_size_in_bytes: int
    output_size_in_bytes: int
    alias_size_in_bytes: int
    temp_size_in_bytes: int
    # host_-prefixed counterparts of the five fields above.
    host_generated_code_size_in_bytes: int
    host_argument_size_in_bytes: int
    host_output_size_in_bytes: int
    host_alias_size_in_bytes: int
    host_temp_size_in_bytes: int
    # Serialized proto bytes for the buffer assignment.
    serialized_buffer_assignment_proto: bytes
    
    def __str__(self) -> str:
        """Get string representation of memory stats."""

Usage Examples

Basic Compilation and Execution

from jaxlib import xla_client
import numpy as np

# HLO text for the computation: element-wise sum of two f32[3] parameters.
hlo_text = """
HloModule add_module

ENTRY add_computation {
  x = f32[3] parameter(0)
  y = f32[3] parameter(1)
  ROOT add = f32[3] add(x, y)
}
"""

# Host-side inputs matching the two f32[3] parameters.
a = np.array([1.0, 2.0, 3.0], dtype=np.float32)
b = np.array([4.0, 5.0, 6.0], dtype=np.float32)

# CPU client and the first of its local devices.
client = xla_client.make_cpu_client()
device = client.local_devices()[0]

# Compile and load in one step, targeting that single device.
executable = client.compile_and_load(hlo_text, executable_devices=[device])

# Copy each host array onto the device (a first, then b).
buffer_a, buffer_b = (
    client.buffer_from_pyval(host_array, device=device) for host_array in (a, b)
)

# Run the computation and bring the first (only) output back as a numpy array.
result_buffers = executable.execute([buffer_a, buffer_b])
result = np.array(result_buffers[0])

print(f"Result: {result}")  # element-wise sums: 5, 7, 9

Compilation with Options

from jaxlib import xla_client

client = xla_client.make_cpu_client()
devices = client.local_devices()

# HLO module to compile (element-wise add of two f32[3] arrays). It is defined
# here so the example runs on its own instead of relying on a variable from
# another snippet.
hlo_text = """
HloModule add_module

ENTRY add_computation {
  x = f32[3] parameter(0)
  y = f32[3] parameter(1)
  ROOT add = f32[3] add(x, y)
}
"""

# Create compilation options
compile_options = xla_client.CompileOptions()
compile_options.num_replicas = 1
compile_options.num_partitions = 1

# Build options with debug settings
build_options = xla_client.ExecutableBuildOptions()
build_options.debug_options.xla_backend_optimization_level = 2
build_options.debug_options.xla_dump_hlo_as_text = True
compile_options.executable_build_options = build_options

# Compile with options, targeting only the first local device
executable = client.compile_and_load(
    hlo_text,
    executable_devices=devices[:1],
    compile_options=compile_options
)

# Get compilation info
stats = executable.get_compiled_memory_stats()
print(f"Generated code size: {stats.generated_code_size_in_bytes} bytes")
print(f"Argument size: {stats.argument_size_in_bytes} bytes")

Sharded Execution

from jaxlib import xla_client
import numpy as np

client = xla_client.make_cpu_client()
devices = client.local_devices()

# The example needs two devices; on a single-device host it does nothing.
if len(devices) >= 2:
    target_devices = devices[:2]

    # HLO for element-wise operation across devices
    hlo_sharded = """
    HloModule sharded_add
    
    ENTRY computation {
      x = f32[2] parameter(0)
      y = f32[2] parameter(1)
      ROOT add = f32[2] add(x, y)
    }
    """
    
    # Run one replica of the computation per target device.
    compile_options = xla_client.CompileOptions()
    compile_options.num_replicas = len(target_devices)

    # Compile for multiple devices
    executable = client.compile_and_load(
        hlo_sharded,
        executable_devices=target_devices,
        compile_options=compile_options
    )
    
    # Prepare sharded inputs (one shard of each argument per device)
    shard1_a = client.buffer_from_pyval(np.array([1.0, 2.0], dtype=np.float32), target_devices[0])
    shard1_b = client.buffer_from_pyval(np.array([3.0, 4.0], dtype=np.float32), target_devices[0])
    
    shard2_a = client.buffer_from_pyval(np.array([5.0, 6.0], dtype=np.float32), target_devices[1])
    shard2_b = client.buffer_from_pyval(np.array([7.0, 8.0], dtype=np.float32), target_devices[1])
    
    # execute_sharded takes one entry per *argument*, each holding that
    # argument's shards in device order: [[x_dev0, x_dev1], [y_dev0, y_dev1]].
    sharded_args = [[shard1_a, shard2_a], [shard1_b, shard2_b]]
    results = executable.execute_sharded(sharded_args)
    
    # Results are grouped per output, then per device. The computation has a
    # single output, so its per-device shards are in output_arrays[0].
    output_arrays = results.disassemble_into_single_device_arrays()
    for i, device_shard in enumerate(output_arrays[0]):
        result = np.array(device_shard)
        print(f"Device {i} result: {result}")

Executable Serialization

from jaxlib import xla_client
import numpy as np

client = xla_client.make_cpu_client()
device = client.local_devices()[0]

# HLO module to compile (element-wise add of two f32[3] arrays). It is defined
# here so the example runs on its own instead of relying on a variable from
# another snippet.
hlo_text = """
HloModule add_module

ENTRY add_computation {
  x = f32[3] parameter(0)
  y = f32[3] parameter(1)
  ROOT add = f32[3] add(x, y)
}
"""

# Compile executable
executable = client.compile_and_load(hlo_text, [device])

# Serialize for storage/transfer
serialized = client.serialize_executable(executable)
print(f"Serialized size: {len(serialized)} bytes")

# Deserialize executable. `options` has no default and must be passed
# explicitly, even when it is None.
restored_executable = client.deserialize_executable(
    serialized,
    executable_devices=[device],
    options=None
)

# Input buffers for the two f32[3] parameters, placed on the target device
buffer_a = client.buffer_from_pyval(
    np.array([1.0, 2.0, 3.0], dtype=np.float32), device=device
)
buffer_b = client.buffer_from_pyval(
    np.array([4.0, 5.0, 6.0], dtype=np.float32), device=device
)

# Use restored executable
result = restored_executable.execute([buffer_a, buffer_b])

Install with Tessl CLI

npx tessl i tessl/pypi-jaxlib

docs

array-operations.md

compilation-execution.md

custom-operations.md

device-management.md

hardware-operations.md

index.md

plugin-system.md

sharding.md

xla-client.md

tile.json