# Hardware-Specific Operations

Specialized operations for different hardware platforms, including CPU linear algebra, GPU kernels, and sparse matrix operations.

## Capabilities

### LAPACK Operations

Linear algebra operations using LAPACK for CPU computations.

```python { .api }
# From jaxlib.lapack module

class EigComputationMode(enum.Enum):
    """Eigenvalue computation modes."""

class SchurComputationMode(enum.Enum):
    """Schur decomposition computation modes."""

class SchurSort(enum.Enum):
    """Schur sorting options."""

# Maps a dtype to the type prefix used in LAPACK target names
# (presumably the 's' in 'lapack_sgetrf' for float32 -- confirm).
LAPACK_DTYPE_PREFIX: dict[type, str]

def registrations() -> dict[str, list[tuple[str, Any, int]]]:
    """
    Get LAPACK operation registrations.

    Returns:
        Dictionary mapping platform to a list of
        (name, capsule, api_version) tuples.
    """

def batch_partitionable_targets() -> list[str]:
    """
    Get the list of batch-partitionable LAPACK targets.

    Returns:
        List of target names that support batch partitioning.
    """

def prepare_lapack_call(fn_base: str, dtype: Any) -> str:
    """
    Initialize LAPACK and return the target name.

    Parameters:
        fn_base: Base function name.
        dtype: Data type.

    Returns:
        LAPACK target name for the function and dtype.
    """

def build_lapack_fn_target(fn_base: str, dtype: Any) -> str:
    """
    Build a LAPACK function target name.

    Parameters:
        fn_base: Base function name (e.g., 'getrf').
        dtype: NumPy dtype.

    Returns:
        Full LAPACK target name (e.g., 'lapack_sgetrf').
    """
```

### GPU Linear Algebra

GPU-accelerated linear algebra operations using cuBLAS/cuSOLVER or their ROCm equivalents.

```python { .api }
# From jaxlib.gpu_linalg module

def registrations() -> dict[str, list[tuple[str, Any, int]]]:
    """
    Get GPU linear algebra registrations.

    Returns:
        Dictionary with 'CUDA' and 'ROCM' platform registrations.
    """

def batch_partitionable_targets() -> list[str]:
    """
    Get batch-partitionable GPU linalg targets.

    Returns:
        List of GPU targets supporting batch partitioning.
    """
```

### GPU Sparse Operations

Sparse matrix operations optimized for GPU execution.

```python { .api }
# From jaxlib.gpu_sparse module

def registrations() -> dict[str, list[tuple[str, Any, int]]]:
    """
    Get GPU sparse operation registrations.

    Returns:
        Dictionary mapping a string key (presumably the platform name --
        confirm) to a list of registration tuples. The operations are the
        list entries, not the dictionary keys.
    """
```

### CPU Sparse Operations

Sparse matrix operations for CPU execution.

```python { .api }
# From jaxlib.cpu_sparse module

def registrations() -> dict[str, list[tuple[str, Any, int]]]:
    """
    Get CPU sparse operation registrations.

    Returns:
        Dictionary with CPU sparse operation registrations, keyed by
        platform name (e.g. 'cpu'); each value is a list of
        (name, capsule, api_version) tuples.
    """
```

### GPU Utilities

Common utilities and error handling for GPU operations.

```python { .api }
# From jaxlib.gpu_common_utils module

class GpuLibNotLinkedError(Exception):
    """
    Exception raised when the GPU library is not linked.

    Used when GPU-specific functionality is called but
    JAX was not built with GPU support.
    """

    # Message attached to every instance by __init__.
    error_msg: str = (
        'JAX was not built with GPU support. Please use a GPU-enabled JAX to use'
        ' this function.'
    )

    def __init__(self):
        # Forward error_msg to Exception so that str(err) (and f"{err}" in
        # handlers) shows it; an empty body would leave the message blank.
        super().__init__(self.error_msg)
```

### Hardware-Specific Modules

Additional GPU-specific modules for specialized operations. These are listed by name only; their APIs are not documented here.

```python { .api }
# jaxlib.gpu_prng - GPU pseudo-random number generation
# jaxlib.gpu_rnn - GPU recurrent neural network operations
# jaxlib.gpu_solver - GPU linear equation solving
# jaxlib.gpu_triton - Triton kernel integration
```

## Usage Examples

### LAPACK Operations

```python
from jaxlib import lapack
import numpy as np

# Check available LAPACK operations (registrations are grouped by platform)
lapack_ops = lapack.registrations()
print(f"LAPACK operations: {len(lapack_ops['cpu'])}")

# Prepare LAPACK call for LU factorization (also initializes LAPACK)
dtype = np.float32
target_name = lapack.prepare_lapack_call("getrf", dtype)
print(f"LAPACK target: {target_name}")

# Build target name manually
manual_target = lapack.build_lapack_fn_target("getrf", dtype)
print(f"Manual target: {manual_target}")

# Check batch-partitionable targets
batch_targets = lapack.batch_partitionable_targets()
print(f"Batch targets: {batch_targets[:5]}")  # Show first 5
```

### GPU Operations

```python
from jaxlib import gpu_linalg, gpu_sparse, gpu_common_utils

try:
    # Check GPU linear algebra availability (one registration list per platform)
    gpu_linalg_ops = gpu_linalg.registrations()
    print(f"CUDA linalg ops: {len(gpu_linalg_ops.get('CUDA', []))}")
    print(f"ROCM linalg ops: {len(gpu_linalg_ops.get('ROCM', []))}")

    # Check GPU sparse operations. registrations() returns a dict of
    # key -> list of registrations, so count the list entries; len() of the
    # dict itself would only count the keys.
    gpu_sparse_ops = gpu_sparse.registrations()
    num_sparse_ops = sum(len(ops) for ops in gpu_sparse_ops.values())
    print(f"GPU sparse ops available: {num_sparse_ops}")

    # Get batch-partitionable GPU targets
    gpu_batch_targets = gpu_linalg.batch_partitionable_targets()
    print(f"GPU batch targets: {gpu_batch_targets}")

except gpu_common_utils.GpuLibNotLinkedError as e:
    print(f"GPU not available: {e}")
```

### CPU Sparse Operations

```python
from jaxlib import cpu_sparse

# Get CPU sparse operation registrations; the CPU entries live under 'cpu'
cpu_sparse_ops = cpu_sparse.registrations()
cpu_ops = cpu_sparse_ops.get('cpu', [])
print(f"CPU sparse operations: {len(cpu_ops)}")

# Show some operation names
if cpu_ops:
    print("Some CPU sparse operations:")
    for name, _, api_version in cpu_ops[:3]:
        print(f" {name} (API v{api_version})")
```

### Checking Hardware Support

```python
from jaxlib import xla_client, gpu_common_utils

# Create clients to check hardware availability. The broad handlers are
# deliberate: this is a probe whose only job is to report what failed.
try:
    cpu_client = xla_client.make_cpu_client()
    print(f"CPU devices: {len(cpu_client.local_devices())}")
except Exception as e:
    print(f"CPU client error: {e}")

try:
    gpu_client = xla_client.make_gpu_client()
    print(f"GPU devices: {len(gpu_client.local_devices())}")
    print(f"GPU platform: {gpu_client.platform}")
except Exception as e:
    print(f"GPU not available: {e}")

# Check if specific GPU functionality is available. The import itself can
# fail in builds without the GPU module, so ImportError is handled as well.
try:
    from jaxlib import gpu_linalg
    gpu_ops = gpu_linalg.registrations()
    if any(gpu_ops.values()):
        print("GPU linear algebra operations available")
    else:
        print("No GPU linear algebra operations found")
except (ImportError, gpu_common_utils.GpuLibNotLinkedError):
    print("GPU library not linked")
```