CtrlK
Blog | Docs | Log in | Get started
Tessl Logo

tessl/pypi-jaxlib

XLA library for JAX providing low-level bindings and hardware acceleration support

Pending
Overview
Eval results
Files

docs/device-management.md

Device and Memory Management

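Device discovery, selection, and memory management across hardware platforms. Covers device topology, memory spaces, and resource allocation for performance on CPUs, GPUs, and TPUs. The class listings below are signature-only API excerpts (they contain no implementations); the usage examples at the end show the same objects in action through `jaxlib.xla_client`.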

Capabilities

Device Interface

Core device representation providing access to device properties, memory spaces, and hardware-specific information.

class Device:
    """Represents a computational device (CPU, GPU, TPU).

    Signature-only stub. The attributes identify the device and where it is
    placed; the methods expose the memory spaces attached to the device and
    the buffers currently live on it.
    """
    
    # Device identifier (printed as "Device {id}" in the usage examples).
    id: int
    # Printed as "Host ID" in the examples. NOTE(review): may be a legacy
    # alias of process_index — confirm before relying on the distinction.
    host_id: int
    # Index of the process the device belongs to (printed as "Process").
    process_index: int
    # Platform name string.
    platform: str
    # Device-kind string (printed next to the platform in the examples).
    device_kind: str
    # Client object this device is associated with.
    client: Client
    # Optional hardware identifier; may be None. NOTE(review): presumably
    # None when the ID is unavailable (e.g. non-local devices) — confirm.
    local_hardware_id: int | None
    
    def memory(self, kind: str) -> Memory:
        """
        Get the memory space of the specified kind.
        
        Parameters:
        - kind: Memory kind string. Valid kinds are platform-dependent
          (e.g., 'device', 'pinned_host'); list the kinds this device
          accepts with [m.kind for m in device.addressable_memories()].
        
        Returns:
        Memory object for the specified kind.

        NOTE(review): the behavior for a kind this device cannot address is
        not shown here (presumably an exception) — confirm against jaxlib.
        """
    
    def default_memory(self) -> Memory:
        """Get the default memory space for this device."""
    
    def addressable_memories(self) -> list[Memory]:
        """Get all memory spaces this device can address.

        Each element's `.kind` is a string accepted by `memory(kind)`.
        """
    
    def live_buffers(self) -> list[Any]:
        """Get the buffers currently live on this device.

        The element type is not specified by this stub (list[Any]).
        """
    
    def memory_stats(self) -> dict[str, int] | None:
        """
        Get memory usage statistics.
        
        Returns:
        Dictionary mapping statistic names to integer values, or None if
        statistics are not available for this device.
        """
    
    def get_stream_for_external_ready_events(self) -> int:
        """Get an integer stream handle for external ready events.

        NOTE(review): presumably only meaningful on platforms with device
        streams (e.g. GPU) — confirm.
        """

Memory Management

Memory space representation and management for different types of device memory.

class Memory:
    """Represents a memory space on a device.

    Instances are returned by Device.memory(), Device.default_memory() and
    Device.addressable_memories().
    """
    
    # Process index associated with this memory space.
    process_index: int
    # Platform name string.
    platform: str
    # Memory kind string (the value Device.memory(kind) selects by).
    kind: str
    
    def addressable_by_devices(self) -> list[Device]:
        """Get the devices that can address this memory space."""

def check_and_canonicalize_memory_kind(
    memory_kind: str | None, device_list: DeviceList
) -> str | None:
    """
    Check a memory kind against a device list and return its canonical form.
    
    Parameters:
    - memory_kind: Memory kind string, or None
    - device_list: DeviceList of target devices
    
    Returns:
    Canonicalized memory kind, or None.

    NOTE(review): presumably a non-None kind is validated against
    device_list.memory_kinds (raising if unsupported) and None resolves to
    device_list.default_memory_kind — confirm against jaxlib.
    """

Device Lists

Container for managing collections of devices with utilities for addressing and memory management.

class DeviceList:
    """Container for an indexed collection of devices plus derived metadata.

    Besides sequence access (len, indexing, iteration), it exposes
    addressability, process indices, and memory-kind information for the
    devices it holds. Signature-only stub.
    """
    
    # Builds the list from a tuple of devices.
    def __init__(self, device_assignment: tuple[Device, ...]): ...
    
    def __len__(self) -> int:
        """Get the number of devices in the list."""
    
    def __getitem__(self, index: Any) -> Any:
        """Get the device at the specified index.

        NOTE(review): typed Any — slice indices may return a different
        container type; confirm.
        """
    
    def __iter__(self) -> Iterator[Device]:
        """Iterate over the devices in the list."""
    
    @property
    def is_fully_addressable(self) -> bool:
        """Whether every device in the list is addressable.

        NOTE(review): presumably "addressable from the calling process" —
        confirm.
        """
    
    @property
    def addressable_device_list(self) -> DeviceList:
        """Get a DeviceList containing only the addressable devices."""
    
    @property
    def process_indices(self) -> set[int]:
        """Get the set of process indices spanned by the devices."""
    
    @property
    def default_memory_kind(self) -> str | None:
        """Get the default memory kind for the devices, or None."""
    
    @property
    def memory_kinds(self) -> tuple[str, ...]:
        """Get the tuple of memory kinds available to the devices."""
    
    @property
    def device_kind(self) -> str:
        """Get the device kind of the devices in the list.

        NOTE(review): behavior for a list mixing device kinds is not
        documented here — confirm.
        """

Device Topology

Topology information for understanding device layout and connectivity in multi-device and multi-node systems.

class DeviceTopology:
    """Represents the topology of devices in a system.

    Signature-only stub. In the usage examples an instance is obtained from
    xla_client.get_topology_for_devices(devices).
    """
    
    # Platform name string.
    platform: str
    # Platform version string.
    platform_version: str
    
    def _make_compile_only_devices(self) -> list[Device]:
        """Create compile-only devices from this topology.

        Private (leading underscore). NOTE(review): presumably these devices
        support compilation but not execution — confirm.
        """
    
    def serialize(self) -> bytes:
        """Serialize the topology to bytes (e.g. for transfer)."""

Device Assignment

Utilities for assigning devices to computations in distributed and multi-device scenarios.

class DeviceAssignment:
    """Represents an assignment of devices to computation replicas.

    Conceptually a 2D table indexed by [replica][computation] whose entries
    are device ordinals; build one with `create`.
    """
    
    @staticmethod
    def create(array: np.ndarray) -> DeviceAssignment:
        """
        Create a device assignment from an array.
        
        Parameters:
        - array: 2D numpy array of device ordinals indexed by
          [replica][computation], i.e. with shape
          (replica_count, computation_count).
        
        Returns:
        DeviceAssignment object

        Example: np.array([[0], [1]], dtype=np.int32) describes 2 replicas,
        each running 1 computation, placed on devices 0 and 1.
        """
    
    def replica_count(self) -> int:
        """Get the number of replicas (first dimension of the create() array)."""
    
    def computation_count(self) -> int:
        """Get the number of computations per replica (second dimension)."""
    
    def serialize(self) -> bytes:
        """Serialize the device assignment to bytes."""

Layout Management

Data layout specification and management for optimal memory access patterns on different hardware.

class Layout:
    """Represents the in-memory layout of array data.

    Signature-only stub.
    """
    
    # Constructs a layout from a minor-to-major dimension ordering.
    def __init__(self, minor_to_major: tuple[int, ...]): ...
    
    def minor_to_major(self) -> tuple[int, ...]:
        """Get the minor-to-major dimension ordering.

        Dimension indices are listed from most minor (fastest-varying in
        memory) to most major.
        """
    
    def tiling(self) -> Sequence[tuple[int, ...]]:
        """Get the tiling specification as a sequence of tile shapes."""
    
    def element_size_in_bits(self) -> int:
        """Get the element size in bits."""
    
    def to_string(self) -> str:
        """Get a string representation of the layout."""

class PjRtLayout:
    """PJRT-specific layout representation.

    The underlying XLA Layout is exposed through the private
    `_xla_layout()` accessor.
    """
    
    def _xla_layout(self) -> Layout:
        """Get the underlying XLA Layout."""

GPU Configuration

GPU-specific configuration and memory management options.

class GpuAllocatorConfig:
    """Configuration for the GPU memory allocator.

    Signature-only stub: member values and constructor defaults are elided
    (`...`), so the actual defaults are not documented on this page.
    """
    
    class Kind(enum.IntEnum):
        """Which allocator implementation to use."""
        # NOTE(review): the per-member meanings below are inferred from the
        # names — confirm against the XLA/JAX GPU allocator documentation.
        DEFAULT = ...  # let the backend choose its default allocator
        PLATFORM = ...  # the platform's native allocator
        BFC = ...  # best-fit-with-coalescing allocator
        CUDA_ASYNC = ...  # CUDA's asynchronous (stream-ordered) allocator
    
    # Parameters (defaults elided in this stub):
    # - kind: allocator implementation, a GpuAllocatorConfig.Kind member.
    # - memory_fraction: NOTE(review): presumably the fraction of device
    #   memory the allocator may use — confirm.
    # - preallocate: NOTE(review): presumably whether that memory is reserved
    #   up front — confirm.
    # - collective_memory_size: NOTE(review): presumably memory set aside for
    #   collective operations — confirm units (bytes?).
    def __init__(
        self,
        kind: Kind = ...,
        memory_fraction: float = ...,
        preallocate: bool = ...,
        collective_memory_size: int = ...,
    ) -> None: ...

Usage Examples

Device Discovery and Selection

from jaxlib import xla_client

# Enumerate every device known to a CPU-backed client.
client = xla_client.make_cpu_client()
devices = client.devices()

print(f"Available devices: {len(devices)}")
for device in devices:
    # Identity and placement.
    header = f"Device {device.id}: {device.platform} ({device.device_kind})"
    print(header)
    print(f"  Host ID: {device.host_id}")
    print(f"  Process: {device.process_index}")

    # Memory spaces: the default one first, then every kind the device can
    # address.
    print(f"  Default memory: {device.default_memory().kind}")
    kinds = [mem.kind for mem in device.addressable_memories()]
    print(f"  Addressable memories: {kinds}")

    # memory_stats() returns None when statistics are unavailable; skip
    # printing in that case (and for an empty dict).
    stats = device.memory_stats()
    if not stats:
        continue
    print(f"  Memory stats: {stats}")

Memory Management

import numpy as np
from jaxlib import xla_client

# Target the first device local to this process.
client = xla_client.make_cpu_client()
device = client.local_devices()[0]

# Copy a small float32 host array onto that device.
data = np.asarray([1.0, 2.0, 3.0], dtype=np.float32)
buffer = client.buffer_from_pyval(data, device=device)
print(f"Buffer on device: {buffer}")

live_count = len(device.live_buffers())
print(f"Live buffers on device: {live_count}")

# Usage statistics are optional: memory_stats() may return None.
stats = device.memory_stats()
if stats:
    print(f"Memory usage: {stats}")

Device Assignment for Multi-Device

from jaxlib import xla_client
import numpy as np

client = xla_client.make_cpu_client()
devices = client.local_devices()

# Build a [replica][computation] table: one replica per local device, one
# computation per replica. Take each entry from the device's own `id` rather
# than hard-coding 0 and 1, so the table always names real devices. The
# original `len(devices) >= 2` guard also made this example a no-op on a
# default CPU client, which typically exposes a single device.
# Tip: set XLA_FLAGS=--xla_force_host_platform_device_count=N before the
# client is created to simulate N CPU devices.
assignment_array = np.array(
    [[device.id] for device in devices], dtype=np.int32
)
device_assignment = xla_client.DeviceAssignment.create(assignment_array)

print(f"Replica count: {device_assignment.replica_count()}")
print(f"Computation count: {device_assignment.computation_count()}")

Device Topology

from jaxlib import xla_client

# Describe the topology spanned by this process's local devices.
client = xla_client.make_cpu_client()
devices = client.local_devices()
topology = xla_client.get_topology_for_devices(devices)

print(f"Topology platform: {topology.platform}")
print(f"Platform version: {topology.platform_version}")

# serialize() yields a bytes blob suitable for transfer; report its size.
topology_bytes = topology.serialize()
size = len(topology_bytes)
print(f"Serialized topology size: {size} bytes")

Install with Tessl CLI

npx tessl i tessl/pypi-jaxlib

docs

array-operations.md

compilation-execution.md

custom-operations.md

device-management.md

hardware-operations.md

index.md

plugin-system.md

sharding.md

xla-client.md

tile.json