or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

array-operations.md, compilation-execution.md, custom-operations.md, device-management.md, hardware-operations.md, index.md, plugin-system.md, sharding.md, xla-client.md

docs/compilation-execution.md

0

# Compilation and Execution

1

2

This document covers compiling, loading, and executing XLA computations, with support for distributed computing, sharding, and several execution modes. These interfaces provide the core functionality for transforming high-level computations into optimized executable code.

3

4

## Capabilities

5

6

### Compilation Options

7

8

Configuration options for controlling XLA compilation behavior and optimizations.

9

10

```python { .api }

11

class CompileOptions:

12

"""Options for XLA compilation."""

13

14

@staticmethod

15

def ParseFromString(s: bytes) -> CompileOptions:

16

"""Parse compilation options from serialized bytes."""

17

18

def __init__(self) -> None: ...

19

20

def SerializeAsString(self) -> bytes:

21

"""Serialize compilation options to bytes."""

22

23

argument_layouts: list[Shape] | None

24

parameter_is_tupled_arguments: bool

25

executable_build_options: ExecutableBuildOptions

26

tuple_arguments: bool

27

num_replicas: int

28

num_partitions: int

29

profile_version: int

30

device_assignment: DeviceAssignment | None

31

compile_portable_executable: bool

32

env_option_overrides: list[tuple[str, str]]

33

34

class ExecutableBuildOptions:

35

"""Options for building executables."""

36

37

def __init__(self) -> None: ...

38

39

result_layout: Shape | None

40

fdo_profile: bytes | None

41

num_replicas: int

42

num_partitions: int

43

debug_options: DebugOptions

44

device_assignment: DeviceAssignment | None

45

use_spmd_partitioning: bool

46

use_auto_spmd_partitioning: bool

47

auto_spmd_partitioning_mesh_shape: list[int]

48

auto_spmd_partitioning_mesh_ids: list[int]

49

use_shardy_partitioner: bool

50

51

def compilation_environments_from_serialized_proto(

52

self, serialized_proto: bytes

53

) -> None:

54

"""Set compilation environments from serialized proto."""

55

56

class DebugOptions:

57

"""Debug and optimization options for XLA."""

58

59

xla_cpu_enable_fast_math: bool

60

xla_gpu_enable_fast_min_max: bool

61

xla_backend_optimization_level: int

62

xla_cpu_enable_xprof_traceme: bool

63

xla_force_host_platform_device_count: int

64

xla_dump_to: str

65

xla_dump_hlo_module_re: str

66

xla_dump_hlo_as_text: bool

67

xla_dump_hlo_as_proto: bool

68

xla_detailed_logging: bool

69

xla_enable_dumping: bool

70

```

71

72

### Compilation Interface

73

74

Client methods for compiling XLA computations into executable forms.

75

76

```python { .api }

77

class Client:

78

"""XLA client compilation interface."""

79

80

def compile(

81

self,

82

computation: str | bytes,

83

executable_devices: DeviceList | Sequence[Device],

84

compile_options: CompileOptions = ...,

85

) -> Executable:

86

"""

87

Compile XLA computation to executable.

88

89

Parameters:

90

- computation: HLO module as string or serialized bytes

91

- executable_devices: Target devices for execution

92

- compile_options: Compilation configuration options

93

94

Returns:

95

Compiled Executable object

96

"""

97

98

def compile_and_load(

99

self,

100

computation: str | bytes,

101

executable_devices: DeviceList | Sequence[Device],

102

compile_options: CompileOptions = ...,

103

host_callbacks: Sequence[Any] = ...,

104

) -> LoadedExecutable:

105

"""

106

Compile and load XLA computation for execution.

107

108

Parameters:

109

- computation: HLO module as string or serialized bytes

110

- executable_devices: Target devices for execution

111

- compile_options: Compilation configuration options

112

- host_callbacks: Host callback functions

113

114

Returns:

115

LoadedExecutable ready for execution

116

"""

117

118

def serialize_executable(self, executable: LoadedExecutable) -> bytes:

119

"""

120

Serialize loaded executable to bytes.

121

122

Parameters:

123

- executable: LoadedExecutable to serialize

124

125

Returns:

126

Serialized executable as bytes

127

"""

128

129

def deserialize_executable(

130

self,

131

serialized: bytes,

132

executable_devices: DeviceList | Sequence[Device],

133

options: CompileOptions | None,

134

host_callbacks: Sequence[Any] = ...,

135

) -> LoadedExecutable:

136

"""

137

Deserialize executable from bytes.

138

139

Parameters:

140

- serialized: Serialized executable bytes

141

- executable_devices: Target devices for execution

142

- options: Compilation options

143

- host_callbacks: Host callback functions

144

145

Returns:

146

LoadedExecutable ready for execution

147

"""

148

```

149

150

### Executable Interface

151

152

Compiled executable representation with metadata and analysis capabilities.

153

154

```python { .api }

155

class Executable:

156

"""Compiled XLA executable."""

157

158

def hlo_modules(self) -> list[HloModule]:

159

"""Get HLO modules comprising this executable."""

160

161

def get_output_memory_kinds(self) -> list[list[str]]:

162

"""Get memory kinds for outputs."""

163

164

def get_output_shardings(self) -> list[OpSharding] | None:

165

"""Get output sharding specifications."""

166

167

def get_parameter_shardings(self) -> list[OpSharding] | None:

168

"""Get parameter sharding specifications."""

169

170

def get_parameter_layouts(self) -> list[Layout]:

171

"""Get parameter data layouts."""

172

173

def get_output_layouts(self) -> list[Layout]:

174

"""Get output data layouts."""

175

176

def get_compiled_memory_stats(self) -> CompiledMemoryStats:

177

"""Get compiled memory usage statistics."""

178

179

def serialize(self) -> str:

180

"""Serialize executable to string."""

181

182

def cost_analysis(self) -> dict[str, Any]:

183

"""Get cost analysis information."""

184

```

185

186

### Execution Interface

187

188

Loaded executable with execution capabilities and resource management.

189

190

```python { .api }

191

class LoadedExecutable:

192

"""Loaded executable ready for execution."""

193

194

client: Client

195

traceback: Traceback

196

fingerprint: bytes | None

197

198

def local_devices(self) -> list[Device]:

199

"""Get local devices for this executable."""

200

201

def size_of_generated_code_in_bytes(self) -> int:

202

"""Get generated code size in bytes."""

203

204

def execute(self, arguments: Sequence[ArrayImpl]) -> list[ArrayImpl]:

205

"""

206

Execute on single replica with array arguments.

207

208

Parameters:

209

- arguments: Input arrays for computation

210

211

Returns:

212

List of output arrays

213

"""

214

215

def execute_with_token(

216

self, arguments: Sequence[ArrayImpl]

217

) -> tuple[list[ArrayImpl], Token]:

218

"""

219

Execute with token for ordering.

220

221

Parameters:

222

- arguments: Input arrays for computation

223

224

Returns:

225

Tuple of (output arrays, execution token)

226

"""

227

228

def execute_sharded(

229

self, arguments: Sequence[list[ArrayImpl]], with_tokens: bool = False

230

) -> ExecuteResults:

231

"""

232

Execute on multiple replicas with sharded arguments.

233

234

Parameters:

235

- arguments: Sharded input arrays per replica

236

- with_tokens: Whether to return execution tokens

237

238

Returns:

239

ExecuteResults containing sharded outputs

240

"""

241

242

def hlo_modules(self) -> list[HloModule]:

243

"""Get HLO modules comprising this executable."""

244

245

def get_output_memory_kinds(self) -> list[list[str]]:

246

"""Get memory kinds for outputs."""

247

248

def get_compiled_memory_stats(self) -> CompiledMemoryStats:

249

"""Get compiled memory usage statistics."""

250

251

def get_output_shardings(self) -> list[OpSharding] | None:

252

"""Get output sharding specifications."""

253

254

def get_parameter_shardings(self) -> list[OpSharding] | None:

255

"""Get parameter sharding specifications."""

256

257

def get_parameter_layouts(self) -> list[Layout]:

258

"""Get parameter data layouts."""

259

260

def get_output_layouts(self) -> list[Layout]:

261

"""Get output data layouts."""

262

263

def keep_alive(self) -> None:

264

"""Keep executable alive in memory."""

265

266

def cost_analysis(self) -> dict[str, Any]:

267

"""Get cost analysis information."""

268

```

269

270

### Execution Results

271

272

Container for managing execution results from sharded computations.

273

274

```python { .api }

275

class ExecuteResults:

276

"""Results container for sharded execution."""

277

278

def __len__(self) -> int:

279

"""Get number of result sets."""

280

281

def disassemble_into_single_device_arrays(self) -> list[list[ArrayImpl]]:

282

"""

283

Disassemble results into single-device arrays.

284

285

Returns:

286

List of array lists, one per device

287

"""

288

289

def disassemble_prefix_into_single_device_arrays(

290

self, n: int

291

) -> list[list[ArrayImpl]]:

292

"""

293

Disassemble first n results into single-device arrays.

294

295

Parameters:

296

- n: Number of results to disassemble

297

298

Returns:

299

List of array lists for first n results

300

"""

301

302

def consume_with_handlers(self, handlers: list[Callable]) -> list[Any]:

303

"""

304

Consume results with custom handlers.

305

306

Parameters:

307

- handlers: List of handler functions

308

309

Returns:

310

List of handler results

311

"""

312

313

def consume_token(self) -> ShardedToken:

314

"""Consume execution token from results."""

315

```

316

317

### Execution Tokens

318

319

Token system for managing execution ordering and synchronization.

320

321

```python { .api }

322

class Token:

323

"""Execution token for single-device operations."""

324

325

def block_until_ready(self):

326

"""Block until token is ready."""

327

328

class ShardedToken:

329

"""Execution token for sharded operations."""

330

331

def block_until_ready(self):

332

"""Block until all shards are ready."""

333

334

def get_token(self, device_id: int):

335

"""Get token for specific device."""

336

```

337

338

### Memory Statistics

339

340

Detailed memory usage information for compiled executables.

341

342

```python { .api }

343

class CompiledMemoryStats:

344

"""Memory usage statistics for compiled executable."""

345

346

generated_code_size_in_bytes: int

347

argument_size_in_bytes: int

348

output_size_in_bytes: int

349

alias_size_in_bytes: int

350

temp_size_in_bytes: int

351

host_generated_code_size_in_bytes: int

352

host_argument_size_in_bytes: int

353

host_output_size_in_bytes: int

354

host_alias_size_in_bytes: int

355

host_temp_size_in_bytes: int

356

serialized_buffer_assignment_proto: bytes

357

358

def __str__(self) -> str:

359

"""Get string representation of memory stats."""

360

```

361

362

## Usage Examples

363

364

### Basic Compilation and Execution

365

366

```python

367

from jaxlib import xla_client

368

import numpy as np

369

370

# Create client and get device

371

client = xla_client.make_cpu_client()

372

device = client.local_devices()[0]

373

374

# Simple HLO computation (add two arrays)

375

hlo_text = """

376

HloModule add_module

377

378

ENTRY add_computation {

379

x = f32[3] parameter(0)

380

y = f32[3] parameter(1)

381

ROOT add = f32[3] add(x, y)

382

}

383

"""

384

385

# Compile the computation

386

executable = client.compile_and_load(

387

hlo_text,

388

executable_devices=[device]

389

)

390

391

# Prepare input data

392

a = np.array([1.0, 2.0, 3.0], dtype=np.float32)

393

b = np.array([4.0, 5.0, 6.0], dtype=np.float32)

394

395

# Create device buffers

396

buffer_a = client.buffer_from_pyval(a, device=device)

397

buffer_b = client.buffer_from_pyval(b, device=device)

398

399

# Execute the computation

400

result_buffers = executable.execute([buffer_a, buffer_b])

401

result = np.array(result_buffers[0])

402

403

print(f"Result: {result}") # [5.0, 7.0, 9.0]

404

```

405

406

### Compilation with Options

407

408

```python

409

from jaxlib import xla_client

410

411

client = xla_client.make_cpu_client()

412

devices = client.local_devices()

413

414

# Create compilation options

415

compile_options = xla_client.CompileOptions()

416

compile_options.num_replicas = 1

417

compile_options.num_partitions = 1

418

419

# Build options with debug settings

420

build_options = xla_client.ExecutableBuildOptions()

421

build_options.debug_options.xla_backend_optimization_level = 2

422

build_options.debug_options.xla_dump_hlo_as_text = True

423

compile_options.executable_build_options = build_options

424

425

# Compile with options (hlo_text is the HLO module text defined in the "Basic Compilation and Execution" example)

426

executable = client.compile_and_load(

427

hlo_text,

428

executable_devices=devices[:1],

429

compile_options=compile_options

430

)

431

432

# Get compilation info

433

stats = executable.get_compiled_memory_stats()

434

print(f"Generated code size: {stats.generated_code_size_in_bytes} bytes")

435

print(f"Argument size: {stats.argument_size_in_bytes} bytes")

436

```

437

438

### Sharded Execution

439

440

```python

441

from jaxlib import xla_client

442

import numpy as np

443

444

client = xla_client.make_cpu_client()

445

devices = client.local_devices()

446

447

if len(devices) >= 2:

448

# HLO for element-wise operation across devices

449

hlo_sharded = """

450

HloModule sharded_add

451

452

ENTRY computation {

453

x = f32[2] parameter(0)

454

y = f32[2] parameter(1)

455

ROOT add = f32[2] add(x, y)

456

}

457

"""

458

459

# Compile for multiple devices

460

executable = client.compile_and_load(

461

hlo_sharded,

462

executable_devices=devices[:2]

463

)

464

465

# Prepare sharded inputs (one shard per device)

466

shard1_a = client.buffer_from_pyval(np.array([1.0, 2.0], dtype=np.float32), devices[0])

467

shard1_b = client.buffer_from_pyval(np.array([3.0, 4.0], dtype=np.float32), devices[0])

468

469

shard2_a = client.buffer_from_pyval(np.array([5.0, 6.0], dtype=np.float32), devices[1])

470

shard2_b = client.buffer_from_pyval(np.array([7.0, 8.0], dtype=np.float32), devices[1])

471

472

# Execute with sharded inputs

473

sharded_args = [[shard1_a, shard1_b], [shard2_a, shard2_b]]

474

results = executable.execute_sharded(sharded_args)

475

476

# Get results from each device

477

output_arrays = results.disassemble_into_single_device_arrays()

478

for i, device_output in enumerate(output_arrays):

479

result = np.array(device_output[0])

480

print(f"Device {i} result: {result}")

481

```

482

483

### Executable Serialization

484

485

```python

486

from jaxlib import xla_client

487

488

client = xla_client.make_cpu_client()

489

device = client.local_devices()[0]

490

491

# Compile executable (hlo_text is the HLO module text defined in the "Basic Compilation and Execution" example)

492

executable = client.compile_and_load(hlo_text, [device])

493

494

# Serialize for storage/transfer

495

serialized = client.serialize_executable(executable)

496

print(f"Serialized size: {len(serialized)} bytes")

497

498

# Deserialize executable

499

restored_executable = client.deserialize_executable(

500

serialized,

501

executable_devices=[device],

502

options=None

503

)

504

505

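# Use restored executable (buffer_a and buffer_b are the device buffers created in the "Basic Compilation and Execution" example)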

506

result = restored_executable.execute([buffer_a, buffer_b])

507

```