or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

array-operations.mdcompilation-execution.mdcustom-operations.mddevice-management.mdhardware-operations.mdindex.mdplugin-system.mdsharding.mdxla-client.md

hardware-operations.mddocs/

0

# Hardware-Specific Operations

Specialized operations for different hardware platforms including CPU linear algebra, GPU kernels, and sparse matrix operations.

## Capabilities

### LAPACK Operations

Linear algebra operations using LAPACK for CPU computations.

```python { .api }
# From jaxlib.lapack module

class EigComputationMode(enum.Enum):
    """Eigenvalue computation modes."""

class SchurComputationMode(enum.Enum):
    """Schur decomposition computation modes."""

class SchurSort(enum.Enum):
    """Schur sorting options."""

LAPACK_DTYPE_PREFIX: dict[type, str]

def registrations() -> dict[str, list[tuple[str, Any, int]]]:
    """
    Get LAPACK operation registrations.

    Returns:
        Dictionary mapping platform to list of (name, capsule, api_version) tuples
    """

def batch_partitionable_targets() -> list[str]:
    """
    Get list of batch-partitionable LAPACK targets.

    Returns:
        List of target names that support batch partitioning
    """

def prepare_lapack_call(fn_base: str, dtype: Any) -> str:
    """
    Initialize LAPACK and return target name.

    Parameters:
    - fn_base: Base function name
    - dtype: Data type

    Returns:
        LAPACK target name for the function and dtype
    """

def build_lapack_fn_target(fn_base: str, dtype: Any) -> str:
    """
    Build LAPACK function target name.

    Parameters:
    - fn_base: Base function name (e.g., 'getrf')
    - dtype: NumPy dtype

    Returns:
        Full LAPACK target name (e.g., 'lapack_sgetrf')
    """
```

### GPU Linear Algebra

GPU-accelerated linear algebra operations using cuBLAS/cuSOLVER or ROCm equivalents.

```python { .api }
# From jaxlib.gpu_linalg module

def registrations() -> dict[str, list[tuple[str, Any, int]]]:
    """
    Get GPU linear algebra registrations.

    Returns:
        Dictionary with 'CUDA' and 'ROCM' platform registrations
    """

def batch_partitionable_targets() -> list[str]:
    """
    Get batch-partitionable GPU linalg targets.

    Returns:
        List of GPU targets supporting batch partitioning
    """
```

### GPU Sparse Operations

Sparse matrix operations optimized for GPU execution.

```python { .api }
# From jaxlib.gpu_sparse module

def registrations() -> dict[str, list[tuple[str, Any, int]]]:
    """Get GPU sparse operation registrations."""
```

### CPU Sparse Operations

Sparse matrix operations for CPU execution.

```python { .api }
# From jaxlib.cpu_sparse module

def registrations() -> dict[str, list[tuple[str, Any, int]]]:
    """
    Get CPU sparse operation registrations.

    Returns:
        Dictionary with CPU sparse operation registrations
    """
```

### GPU Utilities

Common utilities and error handling for GPU operations.

```python { .api }
# From jaxlib.gpu_common_utils module

class GpuLibNotLinkedError(Exception):
    """
    Exception raised when GPU library is not linked.

    Used when GPU-specific functionality is called but
    JAX was not built with GPU support.
    """

    error_msg: str = (
        'JAX was not built with GPU support. Please use a GPU-enabled JAX to use'
        ' this function.'
    )

    def __init__(self): ...
```

### Hardware-Specific Modules

Additional GPU-specific modules for specialized operations.

```python { .api }
# jaxlib.gpu_prng - GPU pseudo-random number generation
# jaxlib.gpu_rnn - GPU recurrent neural network operations
# jaxlib.gpu_solver - GPU linear equation solving
# jaxlib.gpu_triton - Triton kernel integration
```

## Usage Examples

### LAPACK Operations

```python
from jaxlib import lapack
import numpy as np

# Check available LAPACK operations
lapack_ops = lapack.registrations()
print(f"LAPACK operations: {len(lapack_ops['cpu'])}")

# Prepare LAPACK call for LU factorization
dtype = np.float32
target_name = lapack.prepare_lapack_call("getrf", dtype)
print(f"LAPACK target: {target_name}")

# Build target name manually
manual_target = lapack.build_lapack_fn_target("getrf", dtype)
print(f"Manual target: {manual_target}")

# Check batch-partitionable targets
batch_targets = lapack.batch_partitionable_targets()
print(f"Batch targets: {batch_targets[:5]}")  # Show first 5
```

### GPU Operations

```python
from jaxlib import gpu_linalg, gpu_sparse, gpu_common_utils

try:
    # Check GPU linear algebra availability
    gpu_linalg_ops = gpu_linalg.registrations()
    print(f"CUDA linalg ops: {len(gpu_linalg_ops.get('CUDA', []))}")
    print(f"ROCM linalg ops: {len(gpu_linalg_ops.get('ROCM', []))}")

    # Check GPU sparse operations
    gpu_sparse_ops = gpu_sparse.registrations()
    print(f"GPU sparse ops available: {len(gpu_sparse_ops)}")

    # Get batch-partitionable GPU targets
    gpu_batch_targets = gpu_linalg.batch_partitionable_targets()
    print(f"GPU batch targets: {gpu_batch_targets}")

except gpu_common_utils.GpuLibNotLinkedError as e:
    print(f"GPU not available: {e}")
```

### CPU Sparse Operations

```python
from jaxlib import cpu_sparse

# Get CPU sparse operation registrations
cpu_sparse_ops = cpu_sparse.registrations()
print(f"CPU sparse operations: {len(cpu_sparse_ops['cpu'])}")

# Show some operation names
if cpu_sparse_ops['cpu']:
    print("Some CPU sparse operations:")
    for name, _, api_version in cpu_sparse_ops['cpu'][:3]:
        print(f"  {name} (API v{api_version})")
```

### Checking Hardware Support

```python
from jaxlib import xla_client, gpu_common_utils

# Create clients to check hardware availability
try:
    cpu_client = xla_client.make_cpu_client()
    print(f"CPU devices: {len(cpu_client.local_devices())}")
except Exception as e:
    print(f"CPU client error: {e}")

try:
    gpu_client = xla_client.make_gpu_client()
    print(f"GPU devices: {len(gpu_client.local_devices())}")
    print(f"GPU platform: {gpu_client.platform}")
except Exception as e:
    print(f"GPU not available: {e}")

# Check if specific GPU functionality is available
try:
    from jaxlib import gpu_linalg
    gpu_ops = gpu_linalg.registrations()
    if any(gpu_ops.values()):
        print("GPU linear algebra operations available")
    else:
        print("No GPU linear algebra operations found")
except gpu_common_utils.GpuLibNotLinkedError:
    print("GPU library not linked")
```