
tessl/pypi-jaxlib

XLA library for JAX providing low-level bindings and hardware acceleration support


Array Operations

High-performance array operations for device placement, sharding, copying, and synchronization, together with the core ArrayImpl type and its memory-management methods.

Capabilities

Array Device Placement

Functions for placing arrays on devices with different sharding strategies and memory semantics.

def batched_device_put(
    aval: Any,
    sharding: Any,
    shards: Sequence[Any],
    devices: list[Device],
    committed: bool = False,
    force_copy: bool = False,
    host_buffer_semantics: Any = ...,
) -> ArrayImpl:
    """
    Place array shards on devices with specified sharding.
    
    Parameters:
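    - aval: Abstract value (shape and dtype) of the full array
    - sharding: Sharding that maps array indices to devices
    - shards: Per-device pieces of the array, one for each entry in devices
    - devices: Target devices
    - committed: Whether the array is committed to its devices (JAX does not implicitly move committed arrays)
    - force_copy: Force a copy of the input data even when it could otherwise be reused
    - host_buffer_semantics: How host buffers are treated during the transfer (a HostBufferSemantics value)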
    
    Returns:
    ArrayImpl distributed across devices
    """

def array_result_handler(
    aval: Any, sharding: Any, committed: bool, _skip_checks: bool = False
) -> Callable:
    """
    Create a result handler that wraps per-device output buffers into an ArrayImpl.
    
    Parameters:
    - aval: Abstract value (shape and dtype) of the result
    - sharding: Sharding specification of the result
    - committed: Whether the result is committed to its devices
    - _skip_checks: Skip validation checks
    
    Returns:
    Result handler function
    """
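In batched_device_put, the shards argument holds one host-side piece of the full array for each target device. The sketch below shows how such pieces could be computed for an even split along the first axis. It assumes a simple row-wise sharding, and `split_rows` is a hypothetical helper, not a jaxlib API. JAX's own preparation of shards may differ.

```python
import numpy as np


def split_rows(x: np.ndarray, num_devices: int) -> list[np.ndarray]:
    """Hypothetical helper: split `x` evenly along axis 0, one piece per device.

    This mimics how a simple 1-D sharding maps index ranges to devices;
    it is not part of jaxlib.
    """
    if x.shape[0] % num_devices != 0:
        raise ValueError("axis 0 must divide evenly across devices")
    rows = x.shape[0] // num_devices
    # Device i receives rows [i * rows, (i + 1) * rows).
    return [x[i * rows:(i + 1) * rows] for i in range(num_devices)]


full = np.arange(8, dtype=np.float32).reshape(4, 2)
shards = split_rows(full, num_devices=2)
for i, shard in enumerate(shards):
    print(f"device {i}: shape={shard.shape}")
```

In a real call, these pieces would be passed as shards together with a sharding and a devices list that describe the same split.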

Array Copying and Transfer

Operations that copy arrays to other devices or reorder their shards to match a destination sharding.

def batched_copy_array_to_devices_with_sharding(
    arrays: Sequence[ArrayImpl],
    devices: Sequence[DeviceList],
    sharding: Sequence[Any],
    array_copy_semantics: Sequence[ArrayCopySemantics],
) -> Sequence[ArrayImpl]:
    """
    Copy arrays to devices with specified sharding.
    
    Parameters:
    - arrays: Source arrays to copy
    - devices: Target DeviceList for each array
    - sharding: Destination sharding for each array
    - array_copy_semantics: ArrayCopySemantics value for each array
    
    Returns:
    Copied arrays on target devices
    """

def reorder_shards(
    x: ArrayImpl,
    dst_sharding: Any,
    array_copy_semantics: ArrayCopySemantics,
) -> ArrayImpl:
    """
    Reorder array shards according to destination sharding.
    
    Parameters:
    - x: Source array
    - dst_sharding: Destination sharding specification
    - array_copy_semantics: Copy semantics
    
    Returns:
    Array with reordered shards
    """
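One plausible reading of reorder_shards is that each shard stays on the device that already holds it, and only the order in which the array lists its shards changes to match the device order of dst_sharding. The toy model below illustrates that reading. It assumes source and destination use the same set of devices. The `reorder` helper and the device names are hypothetical, and the real call also applies array_copy_semantics.

```python
def reorder(shards_by_device: dict[str, str], dst_order: list[str]) -> list[str]:
    """Hypothetical helper: list each device's existing shard in destination order."""
    if set(shards_by_device) != set(dst_order) or len(dst_order) != len(shards_by_device):
        raise ValueError("source and destination must use the same devices")
    # No data moves between devices; only the listing order changes.
    return [shards_by_device[device] for device in dst_order]


src = {"dev0": "rows 0-1", "dev1": "rows 2-3", "dev2": "rows 4-5"}
print(reorder(src, ["dev2", "dev0", "dev1"]))
# ['rows 4-5', 'rows 0-1', 'rows 2-3']
```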

Synchronization

Operations for synchronizing array operations across devices.

def batched_block_until_ready(x: Sequence[ArrayImpl]) -> None:
    """
    Block until every array in the sequence is ready.
    
    Parameters:
    - x: Sequence of arrays to wait for
    """
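batched_block_until_ready has "wait for all" semantics: it returns only after every array in the sequence is ready. As an analogy only (this is not jaxlib code), the standard library expresses the same contract for ordinary futures with `concurrent.futures.wait(..., return_when=ALL_COMPLETED)`. The `fake_transfer` function is a hypothetical stand-in for asynchronous device work.

```python
import time
from concurrent.futures import ALL_COMPLETED, ThreadPoolExecutor, wait


def fake_transfer(i: int) -> int:
    # Hypothetical stand-in for an asynchronous device computation or transfer.
    time.sleep(0.01 * (3 - i))
    return i


with ThreadPoolExecutor(max_workers=3) as pool:
    futures = [pool.submit(fake_transfer, i) for i in range(3)]
    # Block once until every future has finished, like a batched wait.
    done, not_done = wait(futures, return_when=ALL_COMPLETED)

print(len(done), len(not_done))
print(all(f.done() for f in futures))
```

A single batched call lets the caller wait on many arrays at once instead of issuing one blocking call per array.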

Array Implementation

Core array implementation: the C++ type behind concrete jax.Array values.

# ArrayImpl is defined in C++ and accessed through _jax module
# Key methods available on ArrayImpl instances:

# def block_until_ready(self) -> ArrayImpl: ...
# def is_deleted(self) -> bool: ...
# def is_ready(self) -> bool: ...  
# def delete(self): ...
# def clone(self) -> ArrayImpl: ...
# def on_device_size_in_bytes(self) -> int: ...

# Properties:
# dtype: np.dtype
# shape: tuple[int, ...]
# _arrays: Any  # Underlying device arrays
# traceback: Traceback
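For a dense array without padding, a first estimate of its device footprint is the element count times the item size. This is only a back-of-the-envelope check. on_device_size_in_bytes() reports what the backend actually allocated, which can be larger on some hardware because of layout padding or tiling. The `dense_size_in_bytes` helper is hypothetical.

```python
import math

import numpy as np


def dense_size_in_bytes(shape: tuple[int, ...], dtype) -> int:
    """Hypothetical helper: bytes for a dense, unpadded array (elements * item size)."""
    return math.prod(shape) * np.dtype(dtype).itemsize


# A 2x2 float32 array: 4 elements * 4 bytes each.
print(dense_size_in_bytes((2, 2), np.float32))  # 16
# A scalar has shape () and exactly one element.
print(dense_size_in_bytes((), np.float64))  # 8
```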

Usage Examples

Basic Array Placement

from jaxlib import xla_client
import numpy as np

client = xla_client.make_cpu_client()
devices = client.local_devices()

# Create array data
data = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)

# Copy the host data to the first device; the result is a device-resident array
buffer = client.buffer_from_pyval(data, device=devices[0])

# Check array properties
print(f"Array shape: {buffer.shape}")
print(f"Array dtype: {buffer.dtype}")
print(f"On-device size: {buffer.on_device_size_in_bytes()} bytes")

# Wait for completion
buffer.block_until_ready()
print(f"Array is ready: {buffer.is_ready()}")

Batch Operations

from jaxlib import xla_client
import numpy as np

client = xla_client.make_cpu_client()
devices = client.local_devices()

# Create multiple arrays
arrays = [
    client.buffer_from_pyval(np.array([1.0, 2.0]), devices[0]),
    client.buffer_from_pyval(np.array([3.0, 4.0]), devices[0]),
    client.buffer_from_pyval(np.array([5.0, 6.0]), devices[0])
]

# Wait for all arrays to be ready
xla_client.batched_block_until_ready(arrays)

print("All arrays are ready")
for i, arr in enumerate(arrays):
    print(f"Array {i}: ready={arr.is_ready()}")

Install with Tessl CLI

npx tessl i tessl/pypi-jaxlib
