0
# Compilation and Execution
1
2
Compiling, loading, and executing XLA computations, with support for distributed computing, sharding, and multiple execution modes. This is the core functionality for turning high-level computations into optimized executable code.
3
4
## Capabilities
5
6
### Compilation Options
7
8
Configuration options for controlling XLA compilation behavior and optimizations.
9
10
```python { .api }
11
class CompileOptions:
    """Configuration consumed by XLA when compiling a computation."""

    @staticmethod
    def ParseFromString(s: bytes) -> CompileOptions:
        """Build compilation options from their serialized byte form."""

    def __init__(self) -> None: ...

    def SerializeAsString(self) -> bytes:
        """Return these compilation options encoded as bytes."""

    # Argument layout and tupling
    argument_layouts: list[Shape] | None
    parameter_is_tupled_arguments: bool

    # Build configuration
    executable_build_options: ExecutableBuildOptions
    tuple_arguments: bool

    # Replica/partition counts and device assignment
    num_replicas: int
    num_partitions: int
    profile_version: int
    device_assignment: DeviceAssignment | None

    # Other settings
    compile_portable_executable: bool
    env_option_overrides: list[tuple[str, str]]
33
34
class ExecutableBuildOptions:
    """Settings that control how an executable is built."""

    def __init__(self) -> None: ...

    # Output layout and profile data
    result_layout: Shape | None
    fdo_profile: bytes | None

    # Replica/partition counts
    num_replicas: int
    num_partitions: int

    # Debug flags and device assignment
    debug_options: DebugOptions
    device_assignment: DeviceAssignment | None

    # SPMD partitioning controls
    use_spmd_partitioning: bool
    use_auto_spmd_partitioning: bool
    auto_spmd_partitioning_mesh_shape: list[int]
    auto_spmd_partitioning_mesh_ids: list[int]
    use_shardy_partitioner: bool

    def compilation_environments_from_serialized_proto(
        self, serialized_proto: bytes
    ) -> None:
        """Install compilation environments decoded from a serialized proto."""
55
56
class DebugOptions:
    """XLA debugging, dumping, and optimization flags."""

    # Numerics and optimization level
    xla_cpu_enable_fast_math: bool
    xla_gpu_enable_fast_min_max: bool
    xla_backend_optimization_level: int

    # Profiling and host platform
    xla_cpu_enable_xprof_traceme: bool
    xla_force_host_platform_device_count: int

    # HLO dumping and logging
    xla_dump_to: str
    xla_dump_hlo_module_re: str
    xla_dump_hlo_as_text: bool
    xla_dump_hlo_as_proto: bool
    xla_detailed_logging: bool
    xla_enable_dumping: bool
70
```
71
72
### Compilation Interface
73
74
Client methods for compiling XLA computations into executable forms.
75
76
```python { .api }
77
class Client:
    """XLA client interface for compiling, loading, and (de)serializing executables."""

    def compile(
        self,
        computation: str | bytes,
        executable_devices: DeviceList | Sequence[Device],
        compile_options: CompileOptions = ...,
    ) -> Executable:
        """
        Compile an XLA computation into an (unloaded) Executable.

        Parameters:
        - computation: MLIR module (e.g. StableHLO) as text, or its serialized
          bytecode; HLO text is not accepted here
        - executable_devices: Target devices for execution
        - compile_options: Compilation configuration options

        Returns:
        Compiled Executable object
        """

    def compile_and_load(
        self,
        computation: str | bytes,
        executable_devices: DeviceList | Sequence[Device],
        compile_options: CompileOptions = ...,
        host_callbacks: Sequence[Any] = ...,
    ) -> LoadedExecutable:
        """
        Compile an XLA computation and load it for execution.

        Parameters:
        - computation: MLIR module (e.g. StableHLO) as text, or its serialized
          bytecode; HLO text is not accepted here
        - executable_devices: Target devices for execution
        - compile_options: Compilation configuration options
        - host_callbacks: Host callback functions

        Returns:
        LoadedExecutable ready for execution
        """

    def serialize_executable(self, executable: LoadedExecutable) -> bytes:
        """
        Serialize a loaded executable to bytes.

        Parameters:
        - executable: LoadedExecutable to serialize

        Returns:
        Serialized executable as bytes
        """

    def deserialize_executable(
        self,
        serialized: bytes,
        executable_devices: DeviceList | Sequence[Device],
        options: CompileOptions | None,
        host_callbacks: Sequence[Any] = ...,
    ) -> LoadedExecutable:
        """
        Deserialize an executable previously produced by serialize_executable.

        Parameters:
        - serialized: Serialized executable bytes
        - executable_devices: Target devices for execution
        - options: Compilation options
        - host_callbacks: Host callback functions

        Returns:
        LoadedExecutable ready for execution
        """
148
```
149
150
### Executable Interface
151
152
Compiled executable representation with metadata and analysis capabilities.
153
154
```python { .api }
155
class Executable:
    """An XLA executable produced by compilation, with metadata queries."""

    def hlo_modules(self) -> list[HloModule]:
        """Return the HLO modules that make up this executable."""

    def get_output_memory_kinds(self) -> list[list[str]]:
        """Return the memory kinds of the outputs."""

    def get_output_shardings(self) -> list[OpSharding] | None:
        """Return the sharding specifications of the outputs."""

    def get_parameter_shardings(self) -> list[OpSharding] | None:
        """Return the sharding specifications of the parameters."""

    def get_parameter_layouts(self) -> list[Layout]:
        """Return the data layouts of the parameters."""

    def get_output_layouts(self) -> list[Layout]:
        """Return the data layouts of the outputs."""

    def get_compiled_memory_stats(self) -> CompiledMemoryStats:
        """Return memory usage statistics gathered at compile time."""

    def serialize(self) -> str:
        """Return this executable serialized as a string."""

    def cost_analysis(self) -> dict[str, Any]:
        """Return cost analysis information."""
184
```
185
186
### Execution Interface
187
188
Loaded executable with execution capabilities and resource management.
189
190
```python { .api }
191
class LoadedExecutable:
    """An executable that has been loaded and is ready to run."""

    client: Client
    traceback: Traceback
    fingerprint: bytes | None

    def local_devices(self) -> list[Device]:
        """Get local devices for this executable."""

    def size_of_generated_code_in_bytes(self) -> int:
        """Get generated code size in bytes."""

    def execute(self, arguments: Sequence[ArrayImpl]) -> list[ArrayImpl]:
        """
        Execute on a single replica with array arguments.

        Parameters:
        - arguments: Input arrays for the computation, one per argument

        Returns:
        List of output arrays
        """

    def execute_with_token(
        self, arguments: Sequence[ArrayImpl]
    ) -> tuple[list[ArrayImpl], Token]:
        """
        Execute and also return a token for ordering.

        Parameters:
        - arguments: Input arrays for the computation, one per argument

        Returns:
        Tuple of (output arrays, execution token)
        """

    def execute_sharded(
        self, arguments: Sequence[list[ArrayImpl]], with_tokens: bool = False
    ) -> ExecuteResults:
        """
        Execute across multiple devices with per-device (sharded) arguments.

        Parameters:
        - arguments: One list per computation argument; each list holds that
          argument's array for every device (argument-major, not
          device-major). For two arguments x, y on two devices:
          [[x_dev0, x_dev1], [y_dev0, y_dev1]]
        - with_tokens: Whether to return execution tokens

        Returns:
        ExecuteResults holding each output's per-device arrays
        """

    def hlo_modules(self) -> list[HloModule]:
        """Get HLO modules comprising this executable."""

    def get_output_memory_kinds(self) -> list[list[str]]:
        """Get memory kinds for outputs."""

    def get_compiled_memory_stats(self) -> CompiledMemoryStats:
        """Get compiled memory usage statistics."""

    def get_output_shardings(self) -> list[OpSharding] | None:
        """Get output sharding specifications."""

    def get_parameter_shardings(self) -> list[OpSharding] | None:
        """Get parameter sharding specifications."""

    def get_parameter_layouts(self) -> list[Layout]:
        """Get parameter data layouts."""

    def get_output_layouts(self) -> list[Layout]:
        """Get output data layouts."""

    def keep_alive(self) -> None:
        """Keep executable alive in memory."""

    def cost_analysis(self) -> dict[str, Any]:
        """Get cost analysis information."""
268
```
269
270
### Execution Results
271
272
Container for managing execution results from sharded computations.
273
274
```python { .api }
275
class ExecuteResults:
    """Results container for sharded execution (see LoadedExecutable.execute_sharded)."""

    def __len__(self) -> int:
        """Get number of result sets."""

    def disassemble_into_single_device_arrays(self) -> list[list[ArrayImpl]]:
        """
        Split the results into single-device arrays.

        Returns:
        One list per output; each inner list holds that output's array on
        every device, so result[i][d] is output i on device d
        """

    def disassemble_prefix_into_single_device_arrays(
        self, n: int
    ) -> list[list[ArrayImpl]]:
        """
        Split only the first n outputs into single-device arrays.

        Parameters:
        - n: Number of leading outputs to disassemble

        Returns:
        One list per disassembled output, each holding its per-device arrays
        """

    def consume_with_handlers(self, handlers: list[Callable]) -> list[Any]:
        """
        Consume the results by passing them to custom handlers.

        Parameters:
        - handlers: List of handler functions
          (NOTE(review): presumably one per output, each receiving that
          output's per-device arrays -- confirm)

        Returns:
        List of handler results
        """

    def consume_token(self) -> ShardedToken:
        """Consume execution token from results."""
315
```
316
317
### Execution Tokens
318
319
Token system for managing execution ordering and synchronization.
320
321
```python { .api }
322
class Token:
    """Ordering token produced by a single-device execution."""

    def block_until_ready(self):
        """Wait until this token becomes ready."""
327
328
class ShardedToken:
    """Ordering token produced by a sharded (multi-device) execution."""

    def block_until_ready(self):
        """Wait until the tokens of all shards become ready."""

    def get_token(self, device_id: int):
        """Return the token belonging to the device with the given id."""
336
```
337
338
### Memory Statistics
339
340
Detailed memory usage information for compiled executables.
341
342
```python { .api }
343
class CompiledMemoryStats:
    """Memory usage figures reported for a compiled executable."""

    # Device-side sizes, in bytes
    generated_code_size_in_bytes: int
    argument_size_in_bytes: int
    output_size_in_bytes: int
    alias_size_in_bytes: int
    temp_size_in_bytes: int

    # Host-side sizes, in bytes
    host_generated_code_size_in_bytes: int
    host_argument_size_in_bytes: int
    host_output_size_in_bytes: int
    host_alias_size_in_bytes: int
    host_temp_size_in_bytes: int

    serialized_buffer_assignment_proto: bytes

    def __str__(self) -> str:
        """Return a human-readable summary of these statistics."""
360
```
361
362
## Usage Examples
363
364
### Basic Compilation and Execution
365
366
```python
367
from jaxlib import xla_client
import numpy as np

# Create a CPU client and pick its first device.
client = xla_client.make_cpu_client()
device = client.local_devices()[0]

# compile / compile_and_load take an MLIR module (here StableHLO text) or its
# serialized bytecode -- not HLO text. This module adds two f32[3] arrays.
mlir_module = """
module @add_module {
  func.func public @main(%x: tensor<3xf32>, %y: tensor<3xf32>) -> tensor<3xf32> {
    %sum = stablehlo.add %x, %y : tensor<3xf32>
    return %sum : tensor<3xf32>
  }
}
"""

# Compile the computation and load it onto the device.
executable = client.compile_and_load(
    mlir_module,
    executable_devices=[device],
)

# Prepare input data.
a = np.array([1.0, 2.0, 3.0], dtype=np.float32)
b = np.array([4.0, 5.0, 6.0], dtype=np.float32)

# Create device buffers.
# NOTE(review): buffer_from_pyval may be missing from newer jaxlib releases,
# where jax.device_put(a, device) creates the device arrays -- confirm.
buffer_a = client.buffer_from_pyval(a, device=device)
buffer_b = client.buffer_from_pyval(b, device=device)

# Execute the computation; one output array is returned per result.
result_buffers = executable.execute([buffer_a, buffer_b])
result = np.array(result_buffers[0])

print(f"Result: {result}")  # Result: [5. 7. 9.]
404
```
405
406
### Compilation with Options
407
408
```python
409
from jaxlib import xla_client

client = xla_client.make_cpu_client()
devices = client.local_devices()

# An element-wise add of two f32[3] arrays, as a StableHLO MLIR module
# (compile_and_load expects MLIR text or bytecode, not HLO text).
mlir_module = """
module @add_module {
  func.func public @main(%x: tensor<3xf32>, %y: tensor<3xf32>) -> tensor<3xf32> {
    %sum = stablehlo.add %x, %y : tensor<3xf32>
    return %sum : tensor<3xf32>
  }
}
"""

compile_options = xla_client.CompileOptions()

# Attach the build options BEFORE setting replica/partition counts, so that
# replacing executable_build_options cannot discard them.
# NOTE(review): assumes CompileOptions.num_replicas/num_partitions are stored
# on executable_build_options -- confirm against the jaxlib bindings.
build_options = xla_client.ExecutableBuildOptions()
build_options.debug_options.xla_backend_optimization_level = 2
build_options.debug_options.xla_dump_hlo_as_text = True
compile_options.executable_build_options = build_options

compile_options.num_replicas = 1
compile_options.num_partitions = 1

# Compile with options.
executable = client.compile_and_load(
    mlir_module,
    executable_devices=devices[:1],
    compile_options=compile_options,
)

# Inspect compile-time memory statistics.
stats = executable.get_compiled_memory_stats()
print(f"Generated code size: {stats.generated_code_size_in_bytes} bytes")
print(f"Argument size: {stats.argument_size_in_bytes} bytes")
436
```
437
438
### Sharded Execution
439
440
```python
441
from jaxlib import xla_client
import numpy as np

client = xla_client.make_cpu_client()
devices = client.local_devices()

# NOTE(review): a CPU client may expose a single device by default; this
# example only runs when at least two devices are available.
if len(devices) >= 2:
    # Element-wise add, replicated on each device (StableHLO MLIR module).
    mlir_sharded = """
module @sharded_add {
  func.func public @main(%x: tensor<2xf32>, %y: tensor<2xf32>) -> tensor<2xf32> {
    %sum = stablehlo.add %x, %y : tensor<2xf32>
    return %sum : tensor<2xf32>
  }
}
"""

    # Run one replica of the program on each of the two devices.
    # NOTE(review): an explicit compile_options.device_assignment may also be
    # required on some jaxlib versions -- confirm.
    compile_options = xla_client.CompileOptions()
    compile_options.num_replicas = 2

    executable = client.compile_and_load(
        mlir_sharded,
        executable_devices=devices[:2],
        compile_options=compile_options,
    )

    # Per-device inputs: "a" is argument 0 and "b" is argument 1.
    shard1_a = client.buffer_from_pyval(np.array([1.0, 2.0], dtype=np.float32), devices[0])
    shard1_b = client.buffer_from_pyval(np.array([3.0, 4.0], dtype=np.float32), devices[0])

    shard2_a = client.buffer_from_pyval(np.array([5.0, 6.0], dtype=np.float32), devices[1])
    shard2_b = client.buffer_from_pyval(np.array([7.0, 8.0], dtype=np.float32), devices[1])

    # execute_sharded takes one list per computation ARGUMENT, each holding
    # that argument's array on every device (argument-major, not device-major).
    sharded_args = [[shard1_a, shard2_a], [shard1_b, shard2_b]]
    results = executable.execute_sharded(sharded_args)

    # Results are output-major too: output_arrays[i][d] is output i on device d.
    output_arrays = results.disassemble_into_single_device_arrays()
    for i, device_output in enumerate(output_arrays[0]):
        result = np.array(device_output)
        print(f"Device {i} result: {result}")  # [4. 6.] then [12. 14.]
481
```
482
483
### Executable Serialization
484
485
```python
486
from jaxlib import xla_client
import numpy as np

client = xla_client.make_cpu_client()
device = client.local_devices()[0]

# An element-wise add of two f32[3] arrays, as a StableHLO MLIR module
# (compile_and_load expects MLIR text or bytecode, not HLO text).
mlir_module = """
module @add_module {
  func.func public @main(%x: tensor<3xf32>, %y: tensor<3xf32>) -> tensor<3xf32> {
    %sum = stablehlo.add %x, %y : tensor<3xf32>
    return %sum : tensor<3xf32>
  }
}
"""

# Compile the executable.
executable = client.compile_and_load(mlir_module, executable_devices=[device])

# Serialize for storage or transfer.
serialized = client.serialize_executable(executable)
print(f"Serialized size: {len(serialized)} bytes")

# Deserialize the executable.
restored_executable = client.deserialize_executable(
    serialized,
    executable_devices=[device],
    options=None,
)

# Create inputs on the device and run the restored executable.
buffer_a = client.buffer_from_pyval(
    np.array([1.0, 2.0, 3.0], dtype=np.float32), device=device
)
buffer_b = client.buffer_from_pyval(
    np.array([4.0, 5.0, 6.0], dtype=np.float32), device=device
)
result = restored_executable.execute([buffer_a, buffer_b])
print(f"Result: {np.array(result[0])}")  # Result: [5. 7. 9.]
507
```