XLA library for JAX providing low-level bindings and hardware acceleration support
—
Specialized operations for different hardware platforms including CPU linear algebra, GPU kernels, and sparse matrix operations.
Linear algebra operations using LAPACK for CPU computations.
# From jaxlib.lapack module
class EigComputationMode(enum.Enum):
"""Eigenvalue computation modes."""
class SchurComputationMode(enum.Enum):
"""Schur decomposition computation modes."""
class SchurSort(enum.Enum):
"""Schur sorting options."""
LAPACK_DTYPE_PREFIX: dict[type, str]
def registrations() -> dict[str, list[tuple[str, Any, int]]]:
"""
Get LAPACK operation registrations.
Returns:
Dictionary mapping platform to list of (name, capsule, api_version) tuples
"""
def batch_partitionable_targets() -> list[str]:
"""
Get list of batch-partitionable LAPACK targets.
Returns:
List of target names that support batch partitioning
"""
def prepare_lapack_call(fn_base: str, dtype: Any) -> str:
"""
Initialize LAPACK and return target name.
Parameters:
- fn_base: Base function name
- dtype: Data type
Returns:
LAPACK target name for the function and dtype
"""
def build_lapack_fn_target(fn_base: str, dtype: Any) -> str:
"""
Build LAPACK function target name.
Parameters:
- fn_base: Base function name (e.g., 'getrf')
- dtype: NumPy dtype
Returns:
Full LAPACK target name (e.g., 'lapack_sgetrf')
"""

GPU-accelerated linear algebra operations using cuBLAS/cuSOLVER or ROCm equivalents.
# From jaxlib.gpu_linalg module
def registrations() -> dict[str, list[tuple[str, Any, int]]]:
"""
Get GPU linear algebra registrations.
Returns:
Dictionary with 'CUDA' and 'ROCM' platform registrations
"""
def batch_partitionable_targets() -> list[str]:
"""
Get batch-partitionable GPU linalg targets.
Returns:
List of GPU targets supporting batch partitioning
"""

Sparse matrix operations optimized for GPU execution.
# From jaxlib.gpu_sparse module
def registrations() -> dict[str, list[tuple[str, Any, int]]]:
"""Get GPU sparse operation registrations."""

Sparse matrix operations for CPU execution.
# From jaxlib.cpu_sparse module
def registrations() -> dict[str, list[tuple[str, Any, int]]]:
"""
Get CPU sparse operation registrations.
Returns:
Dictionary with CPU sparse operation registrations
"""

Common utilities and error handling for GPU operations.
# From jaxlib.gpu_common_utils module
class GpuLibNotLinkedError(Exception):
"""
Exception raised when GPU library is not linked.
Used when GPU-specific functionality is called but
JAX was not built with GPU support.
"""
error_msg: str = (
'JAX was not built with GPU support. Please use a GPU-enabled JAX to use'
' this function.'
)
def __init__(self): ...

Additional GPU-specific modules for specialized operations.
# jaxlib.gpu_prng - GPU pseudo-random number generation
# jaxlib.gpu_rnn - GPU recurrent neural network operations
# jaxlib.gpu_solver - GPU linear equation solving
# jaxlib.gpu_triton - Triton kernel integration

from jaxlib import lapack
import numpy as np
# Check available LAPACK operations
lapack_ops = lapack.registrations()
print(f"LAPACK operations: {len(lapack_ops['cpu'])}")
# Prepare LAPACK call for LU factorization
dtype = np.float32
target_name = lapack.prepare_lapack_call("getrf", dtype)
print(f"LAPACK target: {target_name}")
# Build target name manually
manual_target = lapack.build_lapack_fn_target("getrf", dtype)
print(f"Manual target: {manual_target}")
# Check batch-partitionable targets
batch_targets = lapack.batch_partitionable_targets()
print(f"Batch targets: {batch_targets[:5]}") # Show first 5

from jaxlib import gpu_linalg, gpu_sparse, gpu_common_utils
try:
# Check GPU linear algebra availability
gpu_linalg_ops = gpu_linalg.registrations()
print(f"CUDA linalg ops: {len(gpu_linalg_ops.get('CUDA', []))}")
print(f"ROCM linalg ops: {len(gpu_linalg_ops.get('ROCM', []))}")
# Check GPU sparse operations
gpu_sparse_ops = gpu_sparse.registrations()
print(f"GPU sparse ops available: {len(gpu_sparse_ops)}")
# Get batch-partitionable GPU targets
gpu_batch_targets = gpu_linalg.batch_partitionable_targets()
print(f"GPU batch targets: {gpu_batch_targets}")
except gpu_common_utils.GpuLibNotLinkedError as e:
print(f"GPU not available: {e}")

from jaxlib import cpu_sparse
# Get CPU sparse operation registrations
cpu_sparse_ops = cpu_sparse.registrations()
print(f"CPU sparse operations: {len(cpu_sparse_ops['cpu'])}")
# Show some operation names
if cpu_sparse_ops['cpu']:
print("Some CPU sparse operations:")
for name, _, api_version in cpu_sparse_ops['cpu'][:3]:
print(f" {name} (API v{api_version})")

from jaxlib import xla_client, gpu_common_utils
# Create clients to check hardware availability
try:
cpu_client = xla_client.make_cpu_client()
print(f"CPU devices: {len(cpu_client.local_devices())}")
except Exception as e:
print(f"CPU client error: {e}")
try:
gpu_client = xla_client.make_gpu_client()
print(f"GPU devices: {len(gpu_client.local_devices())}")
print(f"GPU platform: {gpu_client.platform}")
except Exception as e:
print(f"GPU not available: {e}")
# Check if specific GPU functionality is available
try:
from jaxlib import gpu_linalg
gpu_ops = gpu_linalg.registrations()
if any(gpu_ops.values()):
print("GPU linear algebra operations available")
else:
print("No GPU linear algebra operations found")
except gpu_common_utils.GpuLibNotLinkedError:
print("GPU library not linked")

Install with Tessl CLI:
npx tessl i tessl/pypi-jaxlib