XLA library for JAX providing low-level bindings and hardware-acceleration support.

High-performance array operations — device placement, sharding, copying, and memory management — optimized for different hardware backends.

Functions for placing arrays on devices with different sharding strategies and memory semantics.
def batched_device_put(
    aval: Any,
    sharding: Any,
    shards: Sequence[Any],
    devices: list[Device],
    committed: bool = False,
    force_copy: bool = False,
    host_buffer_semantics: Any = ...,
) -> ArrayImpl:
    """Place array shards on devices according to a sharding specification.

    Stub signature only; the implementation is native and not visible here.

    Parameters:
    - aval: Array abstract value describing the overall logical array.
    - sharding: Sharding specification for the resulting array.
    - shards: Array shards to place (presumably one per entry in
      ``devices`` — confirm against the implementation).
    - devices: Target devices for the shards.
    - committed: Whether the resulting placement is committed.
    - force_copy: Force copying the input data.
    - host_buffer_semantics: Host buffer handling; the ``...`` default
      means the implementation's own default is used.

    Returns:
    ArrayImpl distributed across the given devices.
    """
def array_result_handler(
    aval: Any, sharding: Any, committed: bool, _skip_checks: bool = False
) -> Callable:
    """Create a result handler for array operations.

    Stub signature only; the implementation is native and not visible here.

    Parameters:
    - aval: Array abstract value for the result.
    - sharding: Sharding specification for the result.
    - committed: Whether the result is committed.
    - _skip_checks: Skip validation checks (leading underscore marks this
      as an internal flag).

    Returns:
    Result handler function
"""High-performance array copying operations with sharding awareness.
def batched_copy_array_to_devices_with_sharding(
    arrays: Sequence[ArrayImpl],
    devices: Sequence[DeviceList],
    sharding: Sequence[Any],
    array_copy_semantics: Sequence[ArrayCopySemantics],
) -> Sequence[ArrayImpl]:
    """Copy arrays to devices with the specified shardings.

    Stub signature only; the implementation is native and not visible here.
    The four sequences are presumably aligned index-by-index (one devices /
    sharding / semantics entry per source array) — confirm against the
    implementation.

    Parameters:
    - arrays: Source arrays to copy.
    - devices: Target device lists, one per source array.
    - sharding: Sharding specifications, one per source array.
    - array_copy_semantics: Copy semantics for each array.

    Returns:
    Copied arrays on the target devices.
    """
def reorder_shards(
    x: ArrayImpl,
    dst_sharding: Any,
    array_copy_semantics: ArrayCopySemantics,
) -> ArrayImpl:
    """Reorder the shards of ``x`` to match a destination sharding.

    Stub signature only; the implementation is native and not visible here.

    Parameters:
    - x: Source array.
    - dst_sharding: Destination sharding specification.
    - array_copy_semantics: Copy semantics governing the reorder.

    Returns:
    Array with reordered shards
"""Operations for synchronizing array operations across devices.
def batched_block_until_ready(x: Sequence[ArrayImpl]) -> None:
    """Block the caller until every array in ``x`` is ready.

    Stub signature only; the implementation is native and not visible here.

    Parameters:
    - x: Sequence of arrays to wait for.
"""Core array implementation providing the foundation for JAX arrays.
# ArrayImpl is defined in C++ and accessed through _jax module
# Key methods available on ArrayImpl instances:
# def block_until_ready(self) -> ArrayImpl: ...
# def is_deleted(self) -> bool: ...
# def is_ready(self) -> bool: ...
# def delete(self): ...
# def clone(self) -> ArrayImpl: ...
# def on_device_size_in_bytes(self) -> int: ...
# Properties:
# dtype: np.dtype
# shape: tuple[int, ...]
# _arrays: Any # Underlying device arrays
#   traceback: Traceback

from jaxlib import xla_client
import numpy as np

# Build a CPU client and enumerate its local devices.
client = xla_client.make_cpu_client()
devices = client.local_devices()

# Create array data on the host.
data = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)

# Place on device: transfers the host data onto devices[0].
buffer = client.buffer_from_pyval(data, device=devices[0])

# Check array properties exposed by the on-device buffer.
print(f"Array shape: {buffer.shape}")
print(f"Array dtype: {buffer.dtype}")
print(f"On-device size: {buffer.on_device_size_in_bytes()} bytes")

# Wait for completion of the transfer before querying readiness.
buffer.block_until_ready()
print(f"Array is ready: {buffer.is_ready()}")from jaxlib import xla_client
import numpy as np

# Build a CPU client and enumerate its local devices.
client = xla_client.make_cpu_client()
devices = client.local_devices()

# Create multiple arrays, all placed on the first device.
arrays = [
    client.buffer_from_pyval(np.array([1.0, 2.0]), devices[0]),
    client.buffer_from_pyval(np.array([3.0, 4.0]), devices[0]),
    client.buffer_from_pyval(np.array([5.0, 6.0]), devices[0])
]

# Wait for all arrays to be ready with one call instead of blocking
# on each array individually.
xla_client.batched_block_until_ready(arrays)
print("All arrays are ready")
for i, arr in enumerate(arrays):
print(f"Array {i}: ready={arr.is_ready()}")Install with Tessl CLI
npx tessl i tessl/pypi-jaxlib