or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs/

- array-operations.md
- compilation-execution.md
- custom-operations.md
- device-management.md
- hardware-operations.md
- index.md
- plugin-system.md
- sharding.md
- xla-client.md

Current file: docs/xla-client.md

# XLA Client Operations

Core XLA client functionality providing the main interface for interacting with XLA backends, managing computational resources, and creating clients for different hardware platforms.

## Capabilities

### Client Creation

Factory functions for creating XLA clients targeting different hardware platforms with platform-specific configuration options.

```python { .api }
def make_cpu_client(
    asynchronous: bool = True,
    distributed_client: DistributedRuntimeClient | None = None,
    node_id: int = 0,
    num_nodes: int = 1,
    collectives: CpuCollectives | None = None,
    num_devices: int | None = None,
    get_local_topology_timeout_minutes: int | None = None,
    get_global_topology_timeout_minutes: int | None = None,
    transfer_server_factory: TransferServerInterfaceFactory | None = None,
) -> Client:
    """
    Create a CPU client for XLA computations.

    Parameters:
    - asynchronous: Whether to use asynchronous execution
    - distributed_client: Client for distributed computing
    - node_id: Node identifier in distributed setup
    - num_nodes: Total number of nodes
    - collectives: CPU collective operations interface
    - num_devices: Number of CPU devices to use
    - get_local_topology_timeout_minutes: Timeout for local topology
    - get_global_topology_timeout_minutes: Timeout for global topology
    - transfer_server_factory: Factory for transfer servers

    Returns:
        XLA Client configured for CPU execution
    """

def make_gpu_client(
    distributed_client: DistributedRuntimeClient | None = None,
    node_id: int = 0,
    num_nodes: int = 1,
    platform_name: str | None = None,
    allowed_devices: set[int] | None = None,
    mock: bool | None = None,
    mock_gpu_topology: str | None = None,
) -> Client:
    """
    Create a GPU client for XLA computations.

    Parameters:
    - distributed_client: Client for distributed computing
    - node_id: Node identifier in distributed setup
    - num_nodes: Total number of nodes
    - platform_name: GPU platform name ('cuda' or 'rocm')
    - allowed_devices: Set of allowed GPU device IDs
    - mock: Whether to use mock GPU for testing
    - mock_gpu_topology: Mock topology specification

    Returns:
        XLA Client configured for GPU execution
    """

def make_c_api_client(
    plugin_name: str,
    options: dict[str, str | int | list[int] | float | bool] | None = None,
    distributed_client: DistributedRuntimeClient | None = None,
    transfer_server_factory: TransferServerInterfaceFactory | None = None,
) -> Client:
    """
    Create a client using the PJRT C API for plugins.

    Parameters:
    - plugin_name: Name of the PJRT plugin
    - options: Platform-specific options dictionary
    - distributed_client: Client for distributed computing
    - transfer_server_factory: Factory for transfer servers

    Returns:
        XLA Client using the specified plugin
    """
```

### Client Interface

The main Client class providing access to devices, compilation, and execution capabilities.

```python { .api }
class Client:
    """XLA client for managing devices and executing computations."""

    platform: str
    platform_version: str
    runtime_type: str

    def device_count(self) -> int:
        """Get total number of devices."""

    def local_device_count(self) -> int:
        """Get number of local devices."""

    def devices(self) -> list[Device]:
        """Get all available devices."""

    def local_devices(self) -> list[Device]:
        """Get locally available devices."""

    def host_id(self) -> int:
        """Get host identifier."""

    def process_index(self) -> int:
        """Get process index in distributed setup."""

    def buffer_from_pyval(
        self,
        argument: Any,
        device: Device | None = None,
        force_copy: bool = False,
        host_buffer_semantics: HostBufferSemantics = ...,
    ) -> ArrayImpl:
        """
        Create a buffer from Python value.

        Parameters:
        - argument: Python value to convert
        - device: Target device (None for default)
        - force_copy: Force copying even if not necessary
        - host_buffer_semantics: How to handle host buffer

        Returns:
            Array buffer on the specified device
        """

    def live_buffers(self) -> list[Any]:
        """Get list of live buffers."""

    def live_arrays(self) -> list[ArrayImpl]:
        """Get list of live arrays."""

    def live_executables(self) -> list[LoadedExecutable]:
        """Get list of live executables."""

    def heap_profile(self) -> bytes:
        """Get heap profile for memory debugging."""
```

### Execution Utilities

Thread-level execution control for managing computation streams.

```python { .api }
def execution_stream_id(new_id: int):
    """
    Context manager that overwrites and restores the current thread's execution_stream_id.

    Parameters:
    - new_id: New execution stream ID to set for the current thread

    Returns:
        Context manager that restores the original execution stream ID on exit

    Usage:
        with execution_stream_id(42):
            # Code executed with stream ID 42
            pass
        # Original stream ID restored
    """
```

### GPU Plugin Options

Utilities for configuring GPU-specific options and plugin parameters.

```python { .api }
def generate_pjrt_gpu_plugin_options() -> dict[str, str | int | list[int] | float | bool]:
    """
    Generate PjRt GPU plugin options from environment variables.

    Reads configuration from environment variables:
    - XLA_PYTHON_CLIENT_ALLOCATOR: Memory allocator type
    - XLA_CLIENT_MEM_FRACTION: GPU memory fraction to use
    - XLA_PYTHON_CLIENT_PREALLOCATE: Whether to preallocate memory
    - XLA_PYTHON_CLIENT_COLLECTIVE_MEM_SIZE_MB: Collective memory size

    Returns:
        Dictionary of plugin options
    """
```

### Topology Management

Functions for managing device topology and creating topology descriptions for different platforms.

```python { .api }
def make_tfrt_tpu_c_api_device_topology(
    topology_name: str | None = None, **kwargs
) -> DeviceTopology:
    """
    Create TPU device topology using TFRT C API.

    Parameters:
    - topology_name: Name of the topology
    - **kwargs: Additional topology options

    Returns:
        DeviceTopology for TPU devices
    """

def make_c_api_device_topology(
    c_api: Any, topology_name: str = '', **kwargs
) -> DeviceTopology:
    """
    Create device topology using C API.

    Parameters:
    - c_api: C API interface
    - topology_name: Name of the topology
    - **kwargs: Additional topology options

    Returns:
        DeviceTopology for the specified platform
    """

def get_topology_for_devices(devices: list[Device]) -> DeviceTopology:
    """
    Get topology description for a list of devices.

    Parameters:
    - devices: List of devices

    Returns:
        DeviceTopology describing the device layout
    """
```

### Distributed Runtime

Classes and functions for managing distributed computing across multiple nodes and processes.

```python { .api }
class DistributedRuntimeClient:
    """Client for distributed runtime coordination."""

    def connect(self) -> Any:
        """Connect to distributed runtime service."""

    def shutdown(self) -> Any:
        """Shutdown the distributed runtime client."""

    def blocking_key_value_get(self, key: str, timeout_in_ms: int) -> Any:
        """Blocking get operation for key-value store."""

    def key_value_set(
        self, key: str, value: str, allow_overwrite: bool = False
    ) -> Any:
        """Set operation for key-value store."""

    def wait_at_barrier(
        self,
        barrier_id: str,
        timeout_in_ms: int,
        process_ids: list[int] | None = None,
    ) -> Any:
        """Wait at a named barrier for synchronization."""

class DistributedRuntimeService:
    """Service for distributed runtime coordination."""

    def shutdown(self) -> None:
        """Shutdown the distributed runtime service."""

def get_distributed_runtime_service(
    address: str,
    num_nodes: int,
    heartbeat_timeout: int | None = None,
    cluster_register_timeout: int | None = None,
    shutdown_timeout: int | None = None,
) -> DistributedRuntimeService:
    """
    Create a distributed runtime service.

    Parameters:
    - address: Service address
    - num_nodes: Number of nodes in cluster
    - heartbeat_timeout: Heartbeat timeout in milliseconds
    - cluster_register_timeout: Cluster registration timeout
    - shutdown_timeout: Shutdown timeout

    Returns:
        DistributedRuntimeService instance
    """

def get_distributed_runtime_client(
    address: str,
    node_id: int,
    rpc_timeout: int | None = None,
    init_timeout: int | None = None,
    shutdown_timeout: int | None = None,
    heartbeat_timeout: int | None = None,
    missed_heartbeat_callback: Any | None = None,
    shutdown_on_destruction: bool | None = None,
    use_compression: bool | None = None,
    recoverable: bool | None = None,
) -> DistributedRuntimeClient:
    """
    Create a distributed runtime client.

    Parameters:
    - address: Service address to connect to
    - node_id: Unique node identifier
    - rpc_timeout: RPC timeout in milliseconds
    - init_timeout: Initialization timeout
    - shutdown_timeout: Shutdown timeout
    - heartbeat_timeout: Heartbeat timeout
    - missed_heartbeat_callback: Callback for missed heartbeats
    - shutdown_on_destruction: Whether to shutdown on destruction
    - use_compression: Whether to use compression
    - recoverable: Whether the client is recoverable

    Returns:
        DistributedRuntimeClient instance
    """
```

## Usage Examples

### Basic Client Setup

```python
from jaxlib import xla_client

# Create a CPU client
cpu_client = xla_client.make_cpu_client(asynchronous=True)
print(f"CPU devices: {cpu_client.local_devices()}")

# Create a GPU client (if available)
try:
    gpu_client = xla_client.make_gpu_client(platform_name='cuda')
    print(f"GPU devices: {gpu_client.local_devices()}")
except Exception as e:
    print(f"GPU not available: {e}")
```

### Distributed Setup

```python
from jaxlib import xla_client

# Start distributed runtime service on coordinator
service = xla_client.get_distributed_runtime_service(
    address="localhost:1234",
    num_nodes=2,
    heartbeat_timeout=60000
)

# Connect distributed client on each node
dist_client = xla_client.get_distributed_runtime_client(
    address="localhost:1234",
    node_id=0,  # Different for each node
    init_timeout=30000
)

# Create client with distributed support
client = xla_client.make_cpu_client(
    distributed_client=dist_client,
    node_id=0,
    num_nodes=2
)
```