XLA library for JAX providing low-level bindings and hardware acceleration support.

Device discovery, selection, and memory management across different hardware platforms. Handles device topology, memory spaces, and resource allocation for optimal performance across CPUs, GPUs, and TPUs.

Core device representation providing access to device properties, memory spaces, and hardware-specific information.
class Device:
"""Represents a computational device (CPU, GPU, TPU)."""
id: int
host_id: int
process_index: int
platform: str
device_kind: str
client: Client
local_hardware_id: int | None
def memory(self, kind: str) -> Memory:
"""
Get memory space of specified kind.
Parameters:
- kind: Memory kind string (e.g., 'default', 'pinned')
Returns:
Memory object for the specified kind
"""
def default_memory(self) -> Memory:
"""Get the default memory space for this device."""
def addressable_memories(self) -> list[Memory]:
"""Get all memory spaces addressable by this device."""
def live_buffers(self) -> list[Any]:
"""Get list of live buffers on this device."""
def memory_stats(self) -> dict[str, int] | None:
"""
Get memory usage statistics.
Returns:
Dictionary with memory statistics or None if not available
"""
def get_stream_for_external_ready_events(self) -> int:
"""Get stream handle for external ready events."""Memory space representation and management for different types of device memory.
class Memory:
    """Represents a memory space on a device."""

    # Index of the process that owns this memory space.
    process_index: int
    # Platform name (e.g. 'cpu', 'gpu', 'tpu').
    platform: str
    # Memory kind string (e.g. 'default', 'pinned').
    kind: str

    def addressable_by_devices(self) -> list[Device]:
        """Get devices that can address this memory space."""
def check_and_canonicalize_memory_kind(
memory_kind: str | None, device_list: DeviceList
) -> str | None:
"""
Check and canonicalize memory kind specification.
Parameters:
- memory_kind: Memory kind string or None
- device_list: List of target devices
Returns:
Canonicalized memory kind or None
"""Container for managing collections of devices with utilities for addressing and memory management.
class DeviceList:
"""Container for a list of devices with metadata."""
def __init__(self, device_assignment: tuple[Device, ...]): ...
def __len__(self) -> int:
"""Get number of devices in the list."""
def __getitem__(self, index: Any) -> Any:
"""Get device at specified index."""
def __iter__(self) -> Iterator[Device]:
"""Iterate over devices in the list."""
@property
def is_fully_addressable(self) -> bool:
"""Check if all devices are fully addressable."""
@property
def addressable_device_list(self) -> DeviceList:
"""Get list of addressable devices."""
@property
def process_indices(self) -> set[int]:
"""Get set of process indices for devices."""
@property
def default_memory_kind(self) -> str | None:
"""Get default memory kind for devices."""
@property
def memory_kinds(self) -> tuple[str, ...]:
"""Get tuple of available memory kinds."""
@property
def device_kind(self) -> str:
"""Get device kind for all devices."""Topology information for understanding device layout and connectivity in multi-device and multi-node systems.
class DeviceTopology:
    """Represents the topology of devices in a system."""

    # Platform name (e.g. 'cpu', 'gpu', 'tpu').
    platform: str
    # Version string of the platform runtime.
    platform_version: str

    def _make_compile_only_devices(self) -> "list[Device]":
        """Create compile-only devices from topology."""

    def serialize(self) -> bytes:
        """Serialize topology to bytes."""

# Utilities for assigning devices to computations in distributed and multi-device scenarios.
class DeviceAssignment:
    """Represents assignment of devices to computation replicas."""

    @staticmethod
    def create(array: "np.ndarray") -> "DeviceAssignment":
        """
        Create device assignment from array.

        Parameters:
        - array: 2D numpy array of device ordinals indexed by [replica][computation]

        Returns:
            DeviceAssignment object
        """

    def replica_count(self) -> int:
        """Get number of replicas."""

    def computation_count(self) -> int:
        """Get number of computations per replica."""

    def serialize(self) -> bytes:
        """Serialize device assignment to bytes."""

# Data layout specification and management for optimal memory access patterns on different hardware.
class Layout:
    """Represents data layout in memory."""

    def __init__(self, minor_to_major: tuple[int, ...]): ...

    def minor_to_major(self) -> tuple[int, ...]:
        """Get minor-to-major dimension ordering."""

    # "Sequence" is not imported in this module, so the annotation is quoted
    # to avoid a NameError at class-creation time.
    def tiling(self) -> "Sequence[tuple[int, ...]]":
        """Get tiling specification."""

    def element_size_in_bits(self) -> int:
        """Get element size in bits."""

    def to_string(self) -> str:
        """Get string representation of layout."""
class PjRtLayout:
    """PJRT-specific layout representation."""

    def _xla_layout(self) -> "Layout":
        """Get underlying XLA layout."""

# GPU-specific configuration and memory management options.
class GpuAllocatorConfig:
    """Configuration for GPU memory allocator."""

    # NOTE(review): `enum` is not imported in the visible portion of this file,
    # and IntEnum members assigned `...` cannot be created at runtime — this
    # declaration reads as a type stub; confirm the real module imports enum.
    class Kind(enum.IntEnum):
        DEFAULT = ...
        PLATFORM = ...
        BFC = ...
        CUDA_ASYNC = ...

    def __init__(
        self,
        kind: Kind = ...,
        memory_fraction: float = ...,
        preallocate: bool = ...,
        collective_memory_size: int = ...,
    ) -> None: ...

from jaxlib import xla_client
# Create client and discover devices
client = xla_client.make_cpu_client()
devices = client.devices()
print(f"Available devices: {len(devices)}")
for device in devices:
    print(f"Device {device.id}: {device.platform} ({device.device_kind})")
    print(f" Host ID: {device.host_id}")
    print(f" Process: {device.process_index}")
    # Check memory information
    default_mem = device.default_memory()
    print(f" Default memory: {default_mem.kind}")
    addressable_mems = device.addressable_memories()
    print(f" Addressable memories: {[m.kind for m in addressable_mems]}")
    # Get memory stats if available
    stats = device.memory_stats()
    if stats:
        print(f" Memory stats: {stats}")

from jaxlib import xla_client
import numpy as np

client = xla_client.make_cpu_client()
device = client.local_devices()[0]
# Create data and put on device
data = np.array([1.0, 2.0, 3.0], dtype=np.float32)
buffer = client.buffer_from_pyval(data, device=device)
print(f"Buffer on device: {buffer}")
print(f"Live buffers on device: {len(device.live_buffers())}")
# Check memory usage
stats = device.memory_stats()
if stats:
    print(f"Memory usage: {stats}")

from jaxlib import xla_client
import numpy as np

client = xla_client.make_cpu_client()
devices = client.local_devices()
if len(devices) >= 2:
    # Create device assignment for 2 replicas on 2 devices
    assignment_array = np.array([[0], [1]], dtype=np.int32)
    device_assignment = xla_client.DeviceAssignment.create(assignment_array)
    print(f"Replica count: {device_assignment.replica_count()}")
    print(f"Computation count: {device_assignment.computation_count()}")

from jaxlib import xla_client
client = xla_client.make_cpu_client()
devices = client.local_devices()
# Get topology for available devices
topology = xla_client.get_topology_for_devices(devices)
print(f"Topology platform: {topology.platform}")
print(f"Platform version: {topology.platform_version}")
# Serialize topology for transfer
topology_bytes = topology.serialize()
print(f"Serialized topology size: {len(topology_bytes)} bytes")

# Install with Tessl CLI
#   npx tessl i tessl/pypi-jaxlib