
# JaxLib


JaxLib is the support library for JAX. It provides the low-level binary components that JAX depends on: Python bindings to XLA, the PJRT runtime, and handwritten kernels. These components let JAX run high-performance numerical code on CPUs, GPUs, and TPUs, and they underpin its automatic differentiation, just-in-time compilation, vectorization, and distributed computing features.


## Package Information


- **Package Name**: jaxlib
- **Language**: Python
- **Installation**: `pip install jaxlib`
- **Dependencies**: `scipy>=1.12`, `numpy>=1.26`, `ml_dtypes>=0.5.0`
- **Hardware Support**: CPU, GPU (CUDA/ROCm), TPU
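To confirm what is actually installed in an environment, a minimal sketch using only the standard library is shown here. The helper name `installed_versions` is hypothetical, and the sketch only reports versions; it does not check them against the minimum versions listed above.

```python
from importlib import metadata

# Package names taken from the dependency list above.
PACKAGES = ["jaxlib", "scipy", "numpy", "ml_dtypes"]


def installed_versions(packages):
    """Map each package name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # The distribution is not installed in this environment.
            versions[name] = None
    return versions


for name, version in installed_versions(PACKAGES).items():
    print(f"{name}: {version or 'not installed'}")
```

Running this before importing `jaxlib` gives a quick picture of which dependencies are missing.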


## Core Imports

```python
import jaxlib
```

For XLA client operations:

```python
from jaxlib import xla_client
```

## Basic Usage

```python
import jaxlib
from jaxlib import xla_client
import numpy as np

# Create a CPU client
client = xla_client.make_cpu_client()

# A plain Python function, shown for reference only; this example does not
# compile or execute it through XLA
def simple_add(a, b):
    return a + b

# Prepare host data as NumPy arrays
data_a = np.array([1.0, 2.0, 3.0], dtype=np.float32)
data_b = np.array([4.0, 5.0, 6.0], dtype=np.float32)

# Transfer the host arrays to device buffers
buffer_a = client.buffer_from_pyval(data_a)
buffer_b = client.buffer_from_pyval(data_b)

print("JaxLib version:", jaxlib.__version__)
print("Available devices:", client.devices())
print("Platform:", client.platform)
```

## Architecture


JaxLib implements a layered architecture with clear separation of concerns:

- **XLA Client Layer**: High-level Python API for XLA operations and compilation
- **PJRT Runtime**: Platform-specific runtime for executing compiled programs
- **Device Backends**: Hardware-specific implementations (CPU, GPU, TPU)
- **Custom Operations**: Extensible system for user-defined operations
- **Distributed Computing**: Multi-node execution and communication primitives

The design enables JAX to transform and scale numerical programs efficiently across different computing platforms through a consistent interface while allowing low-level optimization and hardware-specific acceleration.


## Capabilities


### XLA Client Operations


Core XLA client functionality including client creation, device management, compilation, and execution. Provides the main interface for interacting with XLA backends and managing computational resources.

```python { .api }
def make_cpu_client(
    asynchronous: bool = True,
    distributed_client: DistributedRuntimeClient | None = None,
    node_id: int = 0,
    num_nodes: int = 1,
    collectives: CpuCollectives | None = None,
    num_devices: int | None = None,
    get_local_topology_timeout_minutes: int | None = None,
    get_global_topology_timeout_minutes: int | None = None,
    transfer_server_factory: TransferServerInterfaceFactory | None = None,
) -> Client: ...

def make_gpu_client(
    distributed_client: DistributedRuntimeClient | None = None,
    node_id: int = 0,
    num_nodes: int = 1,
    platform_name: str | None = None,
    allowed_devices: set[int] | None = None,
    mock: bool | None = None,
    mock_gpu_topology: str | None = None,
) -> Client: ...

def make_c_api_client(
    plugin_name: str,
    options: dict[str, str | int | list[int] | float | bool] | None = None,
    distributed_client: DistributedRuntimeClient | None = None,
    transfer_server_factory: TransferServerInterfaceFactory | None = None,
) -> Client: ...
```

[XLA Client](./xla-client.md)
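As a rough illustration of client creation, the sketch below builds a CPU client with default arguments and reads a few of its attributes. The wrapper `describe_cpu_client` is a hypothetical helper, and the import guard is only there so the snippet degrades gracefully where `jaxlib` is not installed.

```python
def describe_cpu_client():
    """Return basic facts about a CPU client, or None if jaxlib is missing."""
    try:
        from jaxlib import xla_client
    except ImportError:
        return None

    # Default arguments give a single-process CPU client.
    client = xla_client.make_cpu_client()
    return {
        "platform": client.platform,
        "platform_version": client.platform_version,
        "device_count": len(client.devices()),
    }


info = describe_cpu_client()
print(info if info is not None else "jaxlib is not installed")
```

The GPU and plugin factories take different arguments, so this pattern should be read as a starting point for the CPU case only.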


### Device and Memory Management


Device discovery, selection, and memory management across different hardware platforms. Handles device topology, memory spaces, and resource allocation for optimal performance.

```python { .api }
class Device:
    id: int
    host_id: int
    process_index: int
    platform: str
    device_kind: str
    client: Client
    local_hardware_id: int | None

    def memory(self, kind: str) -> Memory: ...
    def default_memory(self) -> Memory: ...
    def addressable_memories(self) -> list[Memory]: ...
    def memory_stats(self) -> dict[str, int] | None: ...

class DeviceList:
    def __init__(self, device_assignment: tuple[Device, ...]): ...
    def __len__(self) -> int: ...
    def __getitem__(self, index: Any) -> Any: ...
    def __iter__(self) -> Iterator[Device]: ...

    @property
    def is_fully_addressable(self) -> bool: ...
    @property
    def addressable_device_list(self) -> DeviceList: ...
    @property
    def process_indices(self) -> set[int]: ...
    @property
    def default_memory_kind(self) -> str | None: ...
    @property
    def memory_kinds(self) -> tuple[str, ...]: ...
    @property
    def device_kind(self) -> str: ...
```

[Device Management](./device-management.md)
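Device discovery can be sketched by walking the devices of a CPU client and collecting a few of the `Device` attributes listed above. The helper `list_cpu_devices` is hypothetical, and it returns an empty list when `jaxlib` cannot be imported.

```python
def list_cpu_devices():
    """Summarize each CPU device as a plain dict; empty if jaxlib is missing."""
    try:
        from jaxlib import xla_client
    except ImportError:
        return []

    client = xla_client.make_cpu_client()
    return [
        {
            "id": device.id,
            "platform": device.platform,
            "device_kind": device.device_kind,
            "process_index": device.process_index,
        }
        for device in client.devices()
    ]


devices = list_cpu_devices()
for entry in devices:
    print(entry)
```

Copying the attributes into plain dictionaries keeps the summary independent of the client's lifetime.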


### Compilation and Execution


XLA computation compilation, loading, and execution with support for distributed computing, sharding, and various execution modes.

```python { .api }
class Client:
    platform: str
    platform_version: str
    runtime_type: str

    def compile(
        self,
        computation: str | bytes,
        executable_devices: DeviceList | Sequence[Device],
        compile_options: CompileOptions = ...,
    ) -> Executable: ...

    def compile_and_load(
        self,
        computation: str | bytes,
        executable_devices: DeviceList | Sequence[Device],
        compile_options: CompileOptions = ...,
        host_callbacks: Sequence[Any] = ...,
    ) -> LoadedExecutable: ...

class LoadedExecutable:
    client: Client

    def execute(self, arguments: Sequence[ArrayImpl]) -> list[ArrayImpl]: ...
    def execute_sharded(
        self, arguments: Sequence[list[ArrayImpl]], with_tokens: bool = ...
    ) -> ExecuteResults: ...
    def hlo_modules(self) -> list[HloModule]: ...
    def get_output_memory_kinds(self) -> list[list[str]]: ...
    def get_compiled_memory_stats(self) -> CompiledMemoryStats: ...
```

[Compilation and Execution](./compilation-execution.md)


### Array and Buffer Operations


High-performance array operations including device placement, sharding, copying, and memory management optimized for different hardware backends.

```python { .api }
def batched_device_put(
    aval: Any,
    sharding: Any,
    shards: Sequence[Any],
    devices: list[Device],
    committed: bool = ...,
    force_copy: bool = ...,
    host_buffer_semantics: Any = ...,
) -> ArrayImpl: ...

def batched_copy_array_to_devices_with_sharding(
    arrays: Sequence[ArrayImpl],
    devices: Sequence[DeviceList],
    sharding: Sequence[Any],
    array_copy_semantics: Sequence[ArrayCopySemantics],
) -> Sequence[ArrayImpl]: ...

def reorder_shards(
    x: ArrayImpl,
    dst_sharding: Any,
    array_copy_semantics: ArrayCopySemantics,
) -> ArrayImpl: ...

def batched_block_until_ready(x: Sequence[ArrayImpl]) -> None: ...
```

[Array Operations](./array-operations.md)


### Sharding and Distribution


Sharding strategies for distributing computations across multiple devices and nodes, including SPMD, GSPMD, and custom sharding patterns.

```python { .api }
class Sharding: ...

class NamedSharding(Sharding):
    def __init__(
        self,
        mesh: Any,
        spec: Any,
        *,
        memory_kind: str | None = None,
        _logical_device_ids: tuple[int, ...] | None = None,
    ): ...
    mesh: Any
    spec: Any

class SingleDeviceSharding(Sharding):
    def __init__(self, device: Device, *, memory_kind: str | None = None): ...

class GSPMDSharding(Sharding):
    def __init__(
        self,
        devices: Sequence[Device],
        op_sharding: OpSharding | HloSharding,
        *,
        memory_kind: str | None = None,
        _device_list: DeviceList | None = None,
    ): ...

class HloSharding:
    @staticmethod
    def from_proto(proto: OpSharding) -> HloSharding: ...
    @staticmethod
    def replicate() -> HloSharding: ...
    @staticmethod
    def manual() -> HloSharding: ...

    def is_replicated(self) -> bool: ...
    def is_tiled(self) -> bool: ...
    def num_devices(self) -> int: ...
```

[Sharding](./sharding.md)
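The difference between a tiled and a replicated sharding can be pictured with a small, purely conceptual model that does not call jaxlib at all. The functions `tile_1d` and `replicate_1d` are hypothetical names; they only compute which index range of a one-dimensional array each device would hold, assuming the length divides evenly across devices.

```python
def tile_1d(length, num_devices):
    """Give each device one contiguous, equally sized slice of the array."""
    if length % num_devices != 0:
        raise ValueError("length must be divisible by num_devices")
    shard = length // num_devices
    return [slice(i * shard, (i + 1) * shard) for i in range(num_devices)]


def replicate_1d(length, num_devices):
    """Give every device the full array."""
    return [slice(0, length) for _ in range(num_devices)]


data = list(range(8))
for device_index, index in enumerate(tile_1d(len(data), 4)):
    print(f"tiled      device {device_index}: {data[index]}")
for device_index, index in enumerate(replicate_1d(len(data), 2)):
    print(f"replicated device {device_index}: {data[index]}")
```

Real shardings generalize this to multiple dimensions and device meshes, which is what the classes above describe.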


### Custom Operations


Extensible custom call interface for integrating user-defined operations and hardware-specific kernels into XLA computations.

```python { .api }
class CustomCallTargetTraits(enum.IntFlag):
    DEFAULT = 0
    COMMAND_BUFFER_COMPATIBLE = 1

def register_custom_call_target(
    name: str,
    fn: Any,
    platform: str = 'cpu',
    api_version: int = 0,
    traits: CustomCallTargetTraits = CustomCallTargetTraits.DEFAULT,
) -> None: ...

def register_custom_call_handler(
    platform: str, handler: CustomCallHandler
) -> None: ...

def register_custom_call_partitioner(
    name: str,
    prop_user_sharding: Callable,
    partition: Callable,
    infer_sharding_from_operands: Callable,
    can_side_effecting_have_replicated_sharding: bool = ...,
    c_api: Any | None = ...,
) -> None: ...

def custom_call_targets(platform: str) -> dict[str, Any]: ...
```

[Custom Operations](./custom-operations.md)
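Because `CustomCallTargetTraits` is declared as an `enum.IntFlag`, its values behave as bit flags. The sketch below uses a local stand-in class, `Traits`, that mirrors the two members in the listing so that it runs without jaxlib; the helper `is_command_buffer_compatible` is hypothetical.

```python
import enum


class Traits(enum.IntFlag):
    """Local stand-in mirroring the members of CustomCallTargetTraits."""

    DEFAULT = 0
    COMMAND_BUFFER_COMPATIBLE = 1


def is_command_buffer_compatible(traits):
    """Check whether the COMMAND_BUFFER_COMPATIBLE bit is set."""
    return bool(traits & Traits.COMMAND_BUFFER_COMPATIBLE)


print(is_command_buffer_compatible(Traits.DEFAULT))
print(is_command_buffer_compatible(Traits.COMMAND_BUFFER_COMPATIBLE))
# OR-ing with DEFAULT (zero) leaves the other bits unchanged.
print(Traits.DEFAULT | Traits.COMMAND_BUFFER_COMPATIBLE)
```

Testing a flag with `&` rather than `==` keeps the check valid if further trait bits are combined with it.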


### Hardware-Specific Operations


Specialized operations for different hardware platforms including CPU linear algebra, GPU kernels, and sparse matrix operations.

```python { .api }
# LAPACK operations
def registrations() -> dict[str, list[tuple[str, Any, int]]]: ...
def prepare_lapack_call(fn_base: str, dtype: Any) -> str: ...

# GPU operations
# gpu_linalg module
def registrations() -> dict[str, list[tuple[str, Any, int]]]: ...
# gpu_sparse module
def registrations() -> dict[str, list[tuple[str, Any, int]]]: ...

# CPU sparse operations
# cpu_sparse module
def registrations() -> dict[str, list[tuple[str, Any, int]]]: ...
```

[Hardware-Specific Operations](./hardware-operations.md)
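The return type shared by the `registrations()` functions suggests a mapping from a key to a list of three-element tuples. The sketch below walks such a mapping using invented stand-in data. It assumes, without confirmation from the listing, that the keys are platform names and that each tuple holds a target name, an opaque handle, and an API version; `flatten_registrations` is a hypothetical helper.

```python
# Stand-in for the mapping a registrations() call returns. The platform key,
# target names, and handles below are invented for illustration.
example_registrations = {
    "cpu": [
        ("example_target_a", object(), 0),
        ("example_target_b", object(), 1),
    ],
}


def flatten_registrations(registrations):
    """Yield (platform, name, api_version) for every entry in the mapping."""
    for platform, targets in registrations.items():
        for name, _handle, api_version in targets:
            yield platform, name, api_version


rows = list(flatten_registrations(example_registrations))
for platform, name, api_version in rows:
    print(f"{platform}: {name} (api_version={api_version})")
```

Under the same assumptions, a loop of this shape could pass each entry to `register_custom_call_target`, whose parameters include `name`, `fn`, `platform`, and `api_version`.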


### Plugin System


Dynamic plugin loading and version management for hardware-specific extensions and third-party integrations.

```python { .api }
def import_from_plugin(
    plugin_name: str,
    submodule_name: str,
    *,
    check_version: bool = True
) -> ModuleType | None: ...

def check_plugin_version(
    plugin_name: str,
    jaxlib_version: str,
    plugin_version: str
) -> bool: ...

def pjrt_plugin_loaded(plugin_name: str) -> bool: ...

def load_pjrt_plugin_dynamically(
    plugin_name: str, library_path: str
) -> Any: ...

def initialize_pjrt_plugin(plugin_name: str) -> None: ...
```

[Plugin System](./plugin-system.md)


## Types

```python { .api }
# Core types
class Shape:
    def __init__(self, s: str): ...
    @staticmethod
    def array_shape(
        type: np.dtype | PrimitiveType,
        dims_seq: Any = ...,
        layout_seq: Any = ...,
        dynamic_dimensions: list[bool] | None = ...,
    ) -> Shape: ...

    def dimensions(self) -> tuple[int, ...]: ...
    def rank(self) -> int: ...
    def is_array(self) -> bool: ...
    def is_tuple(self) -> bool: ...

class PrimitiveType(enum.IntEnum):
    PRED = ...
    S8 = ...
    S16 = ...
    S32 = ...
    S64 = ...
    U8 = ...
    U16 = ...
    U32 = ...
    U64 = ...
    F16 = ...
    F32 = ...
    F64 = ...
    BF16 = ...
    C64 = ...
    C128 = ...

class ArrayCopySemantics(enum.IntEnum):
    ALWAYS_COPY = ...
    REUSE_INPUT = ...
    DONATE_INPUT = ...

class HostBufferSemantics(enum.IntEnum):
    IMMUTABLE_ONLY_DURING_CALL = ...
    IMMUTABLE_UNTIL_TRANSFER_COMPLETES = ...
    ZERO_COPY = ...

# Exception types
class XlaRuntimeError(RuntimeError): ...

class GpuLibNotLinkedError(Exception):
    """Raised when the GPU library is not linked."""
```
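The `PrimitiveType` member names follow a compact convention: a letter prefix for the kind of number and a total bit width. The sketch below decodes the names from the listing in plain Python. The helper `describe_primitive_type` and its descriptive labels are hypothetical and are not part of jaxlib.

```python
import re

# Prefix meanings: S = signed int, U = unsigned int, F = float,
# BF = bfloat, C = complex (width is the total for both components).
KIND_BY_PREFIX = {
    "S": "signed integer",
    "U": "unsigned integer",
    "F": "floating point",
    "BF": "bfloat",
    "C": "complex",
}


def describe_primitive_type(name):
    """Return (kind, bit_width) for a PrimitiveType member name."""
    if name == "PRED":
        # PRED is the boolean predicate type; no width is encoded in the name.
        return ("predicate", None)
    match = re.fullmatch(r"(BF|S|U|F|C)(\d+)", name)
    if match is None:
        raise ValueError(f"unrecognized primitive type name: {name!r}")
    return (KIND_BY_PREFIX[match.group(1)], int(match.group(2)))


for name in ["PRED", "S32", "U8", "F64", "BF16", "C128"]:
    print(name, describe_primitive_type(name))
```

For example, `C64` denotes a complex number stored in 64 bits in total, that is, two 32-bit floating point components.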