CtrlK
Blog · Docs · Log in · Get started
Tessl Logo

tessl/pypi-jaxlib

XLA library for JAX providing low-level bindings and hardware acceleration support

Pending
Overview
Eval results
Files

docs/hardware-operations.md

Hardware-Specific Operations

Specialized operations for different hardware platforms including CPU linear algebra, GPU kernels, and sparse matrix operations.

Capabilities

LAPACK Operations

Linear algebra operations using LAPACK for CPU computations.

# From jaxlib.lapack module

class EigComputationMode(enum.Enum):
    """Eigenvalue computation modes.

    NOTE(review): members are elided in this stub; presumably this selects
    whether eigenvectors are computed alongside eigenvalues — confirm
    against the jaxlib.lapack source.
    """

class SchurComputationMode(enum.Enum):
    """Schur decomposition computation modes.

    NOTE(review): members are elided in this stub; presumably this selects
    whether Schur vectors are computed — confirm against jaxlib.lapack.
    """

class SchurSort(enum.Enum):
    """Schur sorting options.

    NOTE(review): members are elided in this stub; presumably controls
    eigenvalue ordering in the Schur decomposition — confirm against
    jaxlib.lapack.
    """

# Maps a NumPy scalar type to its one-letter LAPACK routine prefix used by
# build_lapack_fn_target (e.g. float32 -> 's', as in 'lapack_sgetrf');
# presumably the standard s/d/c/z scheme — confirm against jaxlib.lapack.
LAPACK_DTYPE_PREFIX: dict[type, str]

def registrations() -> dict[str, list[tuple[str, Any, int]]]:
    """
    Get LAPACK operation registrations.

    Returns:
    Dictionary mapping platform to list of (name, capsule, api_version) tuples
    (the usage example below indexes it with the 'cpu' key)
    """

def batch_partitionable_targets() -> list[str]:
    """
    Get list of batch-partitionable LAPACK targets.

    Returns:
    List of target names that support batch partitioning
    """

def prepare_lapack_call(fn_base: str, dtype: Any) -> str:
    """
    Initialize LAPACK and return target name.

    Unlike build_lapack_fn_target, this also performs LAPACK
    initialization before resolving the name.

    Parameters:
    - fn_base: Base function name (e.g. 'getrf')
    - dtype: Data type (a NumPy dtype or scalar type)

    Returns:
    LAPACK target name for the function and dtype
    """

def build_lapack_fn_target(fn_base: str, dtype: Any) -> str:
    """
    Build LAPACK function target name.

    Pure name assembly: prepends the dtype's LAPACK prefix
    (see LAPACK_DTYPE_PREFIX) to the base name.

    Parameters:
    - fn_base: Base function name (e.g., 'getrf')
    - dtype: NumPy dtype

    Returns:
    Full LAPACK target name (e.g., 'lapack_sgetrf')
    """

GPU Linear Algebra

GPU-accelerated linear algebra operations using cuBLAS/cuSOLVER or ROCm equivalents.

# From jaxlib.gpu_linalg module

def registrations() -> dict[str, list[tuple[str, Any, int]]]:
    """
    Get GPU linear algebra registrations.

    Returns:
    Dictionary with 'CUDA' and 'ROCM' platform registrations, each a list
    of (name, capsule, api_version) tuples
    """

def batch_partitionable_targets() -> list[str]:
    """
    Get batch-partitionable GPU linalg targets.

    Returns:
    List of GPU targets supporting batch partitioning
    """

GPU Sparse Operations

Sparse matrix operations optimized for GPU execution.

# From jaxlib.gpu_sparse module

def registrations() -> dict[str, list[tuple[str, Any, int]]]:
    """Get GPU sparse operation registrations.

    Returns:
    Dictionary mapping platform to registration tuples — presumably the
    same (name, capsule, api_version) shape as the other registrations()
    functions; confirm against jaxlib.gpu_sparse.
    """

CPU Sparse Operations

Sparse matrix operations for CPU execution.

# From jaxlib.cpu_sparse module

def registrations() -> dict[str, list[tuple[str, Any, int]]]:
    """
    Get CPU sparse operation registrations.

    Returns:
    Dictionary with CPU sparse operation registrations
    (the usage example below indexes it with the 'cpu' key and unpacks
    (name, capsule, api_version) tuples)
    """

GPU Utilities

Common utilities and error handling for GPU operations.

# From jaxlib.gpu_common_utils module

class GpuLibNotLinkedError(Exception):
    """
    Exception raised when GPU library is not linked.

    Used when GPU-specific functionality is called but
    JAX was not built with GPU support.
    """

    # Fixed message shared by all instances; kept as a class attribute so
    # callers can inspect it without constructing the exception.
    error_msg: str = (
        'JAX was not built with GPU support. Please use a GPU-enabled JAX to use'
        ' this function.'
    )

    def __init__(self):
        # Forward the canonical message to Exception so that str(e) and
        # e.args carry it. The previous stub left the body empty, which made
        # str(e) == '' and rendered messages like
        # print(f"GPU not available: {e}") blank.
        super().__init__(self.error_msg)

Hardware-Specific Modules

Additional GPU-specific modules for specialized operations.

# jaxlib.gpu_prng - GPU pseudo-random number generation
# jaxlib.gpu_rnn - GPU recurrent neural network operations  
# jaxlib.gpu_solver - GPU linear equation solving
# jaxlib.gpu_triton - Triton kernel integration

Usage Examples

LAPACK Operations

from jaxlib import lapack
import numpy as np

# Enumerate the LAPACK kernels registered for the CPU platform.
registered = lapack.registrations()
cpu_count = len(registered['cpu'])
print(f"LAPACK operations: {cpu_count}")

# Resolve the LU-factorization (getrf) target for 32-bit floats; this
# also initializes LAPACK.
dt = np.float32
resolved = lapack.prepare_lapack_call("getrf", dt)
print(f"LAPACK target: {resolved}")

# The same target name can be assembled directly, without initialization.
assembled = lapack.build_lapack_fn_target("getrf", dt)
print(f"Manual target: {assembled}")

# Peek at the targets that support batch partitioning.
partitionable = lapack.batch_partitionable_targets()
print(f"Batch targets: {partitionable[:5]}")  # Show first 5

GPU Operations

from jaxlib import gpu_linalg, gpu_sparse, gpu_common_utils

try:
    # Report how many linear-algebra kernels each GPU backend registers.
    linalg_regs = gpu_linalg.registrations()
    print(f"CUDA linalg ops: {len(linalg_regs.get('CUDA', []))}")
    print(f"ROCM linalg ops: {len(linalg_regs.get('ROCM', []))}")

    # Sparse-matrix kernels registered for the GPU.
    sparse_regs = gpu_sparse.registrations()
    print(f"GPU sparse ops available: {len(sparse_regs)}")

    # Targets that can be partitioned across a batch dimension.
    partitionable = gpu_linalg.batch_partitionable_targets()
    print(f"GPU batch targets: {partitionable}")

except gpu_common_utils.GpuLibNotLinkedError as e:
    # Raised when this jaxlib build has no GPU support linked in.
    print(f"GPU not available: {e}")

CPU Sparse Operations

from jaxlib import cpu_sparse

# Fetch every sparse kernel registered for the CPU platform.
registry = cpu_sparse.registrations()
print(f"CPU sparse operations: {len(registry['cpu'])}")

# Preview up to three registered kernels with their FFI API versions.
if registry['cpu']:
    print("Some CPU sparse operations:")
    for op_name, _, version in registry['cpu'][:3]:
        print(f"  {op_name} (API v{version})")

Checking Hardware Support

from jaxlib import xla_client, gpu_common_utils

# Probe for CPU devices by constructing an XLA CPU client.
try:
    cpu = xla_client.make_cpu_client()
    print(f"CPU devices: {len(cpu.local_devices())}")
except Exception as e:
    print(f"CPU client error: {e}")

# Probe for GPU devices; client construction fails when no GPU stack exists.
try:
    gpu = xla_client.make_gpu_client()
    print(f"GPU devices: {len(gpu.local_devices())}")
    print(f"GPU platform: {gpu.platform}")
except Exception as e:
    print(f"GPU not available: {e}")

# Finally, check whether any GPU linear-algebra kernels are registered.
try:
    from jaxlib import gpu_linalg
    registered = gpu_linalg.registrations()
    if any(registered.values()):
        print("GPU linear algebra operations available")
    else:
        print("No GPU linear algebra operations found")
except gpu_common_utils.GpuLibNotLinkedError:
    print("GPU library not linked")

Install with Tessl CLI

npx tessl i tessl/pypi-jaxlib

docs

array-operations.md

compilation-execution.md

custom-operations.md

device-management.md

hardware-operations.md

index.md

plugin-system.md

sharding.md

xla-client.md

tile.json