# Array Operations

High-performance array operations, including device placement, sharding, copying, and memory management, optimized for different hardware backends.

## Capabilities

### Array Device Placement

Functions for placing arrays on devices with different sharding strategies and memory semantics.
```python { .api }
def batched_device_put(
    aval: Any,
    sharding: Any,
    shards: Sequence[Any],
    devices: list[Device],
    committed: bool = False,
    force_copy: bool = False,
    host_buffer_semantics: Any = ...,
) -> ArrayImpl:
    """
    Place array shards on devices with the specified sharding.

    Parameters:
    - aval: Array abstract value
    - sharding: Sharding specification
    - shards: Array shards to place
    - devices: Target devices
    - committed: Whether placement is committed
    - force_copy: Force copying data
    - host_buffer_semantics: Host buffer handling

    Returns:
    ArrayImpl distributed across devices
    """

def array_result_handler(
    aval: Any, sharding: Any, committed: bool, _skip_checks: bool = False
) -> Callable:
    """
    Create a result handler for array operations.

    Parameters:
    - aval: Array abstract value
    - sharding: Sharding specification
    - committed: Whether result is committed
    - _skip_checks: Skip validation checks

    Returns:
    Result handler function
    """
```
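
These bindings are rarely called directly. The usual entry point is the higher-level `jax.device_put`, which in current JAX releases ultimately routes through this batched device-put machinery. A minimal sketch of explicit placement, assuming `jax` is installed:

```python
import jax
import jax.numpy as jnp
from jax.sharding import SingleDeviceSharding

# Build a sharding that pins data to one local device.
device = jax.devices()[0]
sharding = SingleDeviceSharding(device)

x = jnp.arange(8.0)

# Explicit placement: the result is committed to the requested sharding.
y = jax.device_put(x, sharding)

print(y.sharding)  # a SingleDeviceSharding on the chosen device
```

Passing a `Sharding` (rather than a bare device) is what makes the placement committed, which is the same distinction the `committed` parameter expresses at this layer.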

### Array Copying and Transfer

High-performance array copying operations with sharding awareness.
```python { .api }
def batched_copy_array_to_devices_with_sharding(
    arrays: Sequence[ArrayImpl],
    devices: Sequence[DeviceList],
    sharding: Sequence[Any],
    array_copy_semantics: Sequence[ArrayCopySemantics],
) -> Sequence[ArrayImpl]:
    """
    Copy arrays to devices with the specified shardings.

    Parameters:
    - arrays: Source arrays to copy
    - devices: Target device lists
    - sharding: Sharding specifications
    - array_copy_semantics: Copy semantics for each array

    Returns:
    Copied arrays on target devices
    """

def reorder_shards(
    x: ArrayImpl,
    dst_sharding: Any,
    array_copy_semantics: ArrayCopySemantics,
) -> ArrayImpl:
    """
    Reorder array shards according to the destination sharding.

    Parameters:
    - x: Source array
    - dst_sharding: Destination sharding specification
    - array_copy_semantics: Copy semantics

    Returns:
    Array with reordered shards
    """
```
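
At the user level, a cross-device copy is typically expressed as a second `jax.device_put` with a new destination; on multi-device backends this exercises the batched copy-with-sharding path. A minimal sketch, assuming `jax` is installed:

```python
import jax
import jax.numpy as jnp

devices = jax.devices()

# Place an array on the first device, then copy it to the last one.
# With a single available device this degenerates to a cheap (often
# zero-copy) transfer; with several, it is a real cross-device copy.
x = jax.device_put(jnp.ones((4, 4)), devices[0])
y = jax.device_put(x, devices[-1])

print(y.shape)  # (4, 4)
```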

### Synchronization

Operations for synchronizing array operations across devices.
```python { .api }
def batched_block_until_ready(x: Sequence[ArrayImpl]) -> None:
    """
    Block until all arrays in the sequence are ready.

    Parameters:
    - x: Sequence of arrays to wait for
    """
```
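
Because JAX dispatch is asynchronous, waiting on many arrays in one call avoids a sequence of per-array round trips. The public counterpart is `jax.block_until_ready`, which waits on every array leaf of a pytree; a minimal sketch, assuming `jax` is installed:

```python
import jax
import jax.numpy as jnp

# Dispatch several computations; control returns before the device
# work necessarily finishes.
xs = [jnp.sin(jnp.arange(1024.0) + i) for i in range(3)]

# Wait for all of them at once rather than one at a time.
xs = jax.block_until_ready(xs)

print("all ready:", len(xs), "arrays")
```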

### Array Implementation

Core array implementation providing the foundation for JAX arrays.
```python { .api }
# ArrayImpl is defined in C++ and accessed through the _jax module.
# Key methods available on ArrayImpl instances:

# def block_until_ready(self) -> ArrayImpl: ...
# def is_deleted(self) -> bool: ...
# def is_ready(self) -> bool: ...
# def delete(self): ...
# def clone(self) -> ArrayImpl: ...
# def on_device_size_in_bytes(self) -> int: ...

# Properties:
# dtype: np.dtype
# shape: tuple[int, ...]
# _arrays: Any  # Underlying device arrays
# traceback: Traceback
```
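
Arrays produced by `jax.numpy` are `ArrayImpl` instances, so the methods listed above can be exercised directly on them; a small sketch, assuming `jax` is installed:

```python
import jax.numpy as jnp

x = jnp.arange(16, dtype=jnp.float32).reshape(4, 4)
x.block_until_ready()

print(x.shape, x.dtype)             # (4, 4) float32
print(x.on_device_size_in_bytes())  # typically 64 for 16 float32 values

# delete() frees the device buffer; using the array afterwards raises.
assert not x.is_deleted()
x.delete()
assert x.is_deleted()
```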

## Usage Examples

### Basic Array Placement
```python
from jaxlib import xla_client
import numpy as np

client = xla_client.make_cpu_client()
devices = client.local_devices()

# Create array data
data = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)

# Place on device
buffer = client.buffer_from_pyval(data, device=devices[0])

# Check array properties
print(f"Array shape: {buffer.shape}")
print(f"Array dtype: {buffer.dtype}")
print(f"On-device size: {buffer.on_device_size_in_bytes()} bytes")

# Wait for completion
buffer.block_until_ready()
print(f"Array is ready: {buffer.is_ready()}")
```

### Batch Operations
```python
from jaxlib import xla_client
import numpy as np

client = xla_client.make_cpu_client()
devices = client.local_devices()

# Create multiple arrays
arrays = [
    client.buffer_from_pyval(np.array([1.0, 2.0]), devices[0]),
    client.buffer_from_pyval(np.array([3.0, 4.0]), devices[0]),
    client.buffer_from_pyval(np.array([5.0, 6.0]), devices[0]),
]

# Wait for all arrays to be ready
xla_client.batched_block_until_ready(arrays)

print("All arrays are ready")
for i, arr in enumerate(arrays):
    print(f"Array {i}: ready={arr.is_ready()}")
```