or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

array-operations.md, compilation-execution.md, custom-operations.md, device-management.md, hardware-operations.md, index.md, plugin-system.md, sharding.md, xla-client.md

docs/array-operations.md

# Array Operations

High-performance array operations, including device placement, sharding, copying, and memory management, optimized for different hardware backends.

## Capabilities

### Array Device Placement

Functions for placing arrays on devices with different sharding strategies and host-buffer semantics.

```python { .api }

def batched_device_put(
    aval: Any,
    sharding: Any,
    shards: Sequence[Any],
    devices: list[Device],
    committed: bool = False,
    force_copy: bool = False,
    host_buffer_semantics: Any = ...,
) -> ArrayImpl:
    """
    Place array shards on devices with the specified sharding.

    Parameters:
    - aval: Abstract value of the resulting array
    - sharding: Sharding specification of the resulting array
    - shards: Array shards to place
    - devices: Target devices for the shards
    - committed: Whether the placement is committed (default False)
    - force_copy: Force copying the shard data (default False)
    - host_buffer_semantics: How host buffers are handled during placement
      (the default is elided as ``...`` in this listing)

    Returns:
    ArrayImpl distributed across ``devices``
    """

def array_result_handler(
    aval: Any, sharding: Any, committed: bool, _skip_checks: bool = False
) -> Callable:
    """
    Create a result handler for array operations.

    Parameters:
    - aval: Array abstract value
    - sharding: Sharding specification
    - committed: Whether the result is committed
    - _skip_checks: Skip validation checks (private flag; default False)

    Returns:
    Result handler function
    """

```

### Array Copying and Transfer

High-performance, sharding-aware operations for copying arrays to devices and reordering their shards.

```python { .api }

def batched_copy_array_to_devices_with_sharding(
    arrays: Sequence[ArrayImpl],
    devices: Sequence[DeviceList],
    sharding: Sequence[Any],
    array_copy_semantics: Sequence[ArrayCopySemantics],
) -> Sequence[ArrayImpl]:
    """
    Copy arrays to devices with the specified sharding.

    NOTE(review): the four sequences presumably run in parallel, so entry
    ``i`` of ``devices``, ``sharding`` and ``array_copy_semantics`` applies
    to ``arrays[i]``. Confirm against the implementation.

    Parameters:
    - arrays: Source arrays to copy
    - devices: Target device lists
    - sharding: Sharding specifications
    - array_copy_semantics: Copy semantics for each array

    Returns:
    Copied arrays on the target devices
    """

def reorder_shards(
    x: ArrayImpl,
    dst_sharding: Any,
    array_copy_semantics: ArrayCopySemantics,
) -> ArrayImpl:
    """
    Reorder array shards according to the destination sharding.

    Parameters:
    - x: Source array
    - dst_sharding: Destination sharding specification
    - array_copy_semantics: Copy semantics

    Returns:
    Array with reordered shards
    """

```

### Synchronization

Operations that block the caller until arrays are ready on their devices.

```python { .api }

def batched_block_until_ready(x: Sequence[ArrayImpl]) -> None:
    """
    Block until all arrays in the sequence are ready.

    Parameters:
    - x: Sequence of arrays to wait for

    Returns:
    None
    """

```

### Array Implementation

The core array implementation (`ArrayImpl`) that underlies JAX arrays.

```python { .api }

# ArrayImpl is defined in C++ and accessed through the _jax module.
# Key methods available on ArrayImpl instances:

# def block_until_ready(self) -> ArrayImpl: ...
# def is_deleted(self) -> bool: ...
# def is_ready(self) -> bool: ...
# def delete(self): ...
# def clone(self) -> ArrayImpl: ...
# def on_device_size_in_bytes(self) -> int: ...

# Properties:
# dtype: np.dtype
# shape: tuple[int, ...]
# _arrays: Any  # Underlying device arrays
# traceback: Traceback

```

## Usage Examples

### Basic Array Placement

```python

from jaxlib import xla_client
import numpy as np

# Create a CPU client and list the devices it can place data on.
client = xla_client.make_cpu_client()
devices = client.local_devices()

# Create host array data.
data = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)

# Place the host data on the first local device.
buffer = client.buffer_from_pyval(data, device=devices[0])

# Check array properties.
print(f"Array shape: {buffer.shape}")
print(f"Array dtype: {buffer.dtype}")
print(f"On-device size: {buffer.on_device_size_in_bytes()} bytes")

# Block until the array is ready, then confirm it.
buffer.block_until_ready()
print(f"Array is ready: {buffer.is_ready()}")

```

### Batch Operations

Wait for several arrays at once with `xla_client.batched_block_until_ready`.

```python

from jaxlib import xla_client
import numpy as np

client = xla_client.make_cpu_client()
devices = client.local_devices()

# Create multiple arrays on the first local device.
arrays = [
    client.buffer_from_pyval(np.array([1.0, 2.0]), device=devices[0]),
    client.buffer_from_pyval(np.array([3.0, 4.0]), device=devices[0]),
    client.buffer_from_pyval(np.array([5.0, 6.0]), device=devices[0]),
]

# Wait for all arrays to be ready with a single call.
xla_client.batched_block_until_ready(arrays)

print("All arrays are ready")
for i, arr in enumerate(arrays):
    # The loop body must be indented; the scraped original lost this.
    print(f"Array {i}: ready={arr.is_ready()}")

```