or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

array-operations.md, compilation-execution.md, custom-operations.md, device-management.md, hardware-operations.md, index.md, plugin-system.md, sharding.md, xla-client.md

docs/custom-operations.md

# Custom Operations

Extensible custom call interface for integrating user-defined operations and hardware-specific kernels into XLA computations.

## Capabilities

### Custom Call Registration

Functions for registering custom operations that can be called from XLA computations.

```python { .api }

class CustomCallTargetTraits(enum.IntFlag):
    """Traits for custom call targets.

    Members are bit flags (``enum.IntFlag``) and can be combined with ``|``.
    """

    # No special traits.
    DEFAULT = 0
    # NOTE(review): presumably marks targets whose calls may be captured into
    # a command buffer and replayed; confirm against the XLA FFI handler traits.
    COMMAND_BUFFER_COMPATIBLE = 1

def register_custom_call_target(
    name: str,
    fn: Any,
    platform: str = 'cpu',
    api_version: int = 0,
    traits: CustomCallTargetTraits = CustomCallTargetTraits.DEFAULT,
) -> None:
    """Register a custom call target function.

    Parameters:
    - name: Name of the custom call, as referenced from the XLA computation
    - fn: PyCapsule containing the function pointer (not a Python callable)
    - platform: Target platform ('cpu', 'gpu', etc.)
    - api_version: XLA FFI version (0 for untyped, 1 for typed)
    - traits: Custom call traits (a ``CustomCallTargetTraits`` flag value)
    """

def register_custom_call_handler(
    platform: str, handler: CustomCallHandler
) -> None:
    """Register a custom call handler for a platform.

    Parameters:
    - platform: Target platform
    - handler: Handler function used to register custom calls for
      ``platform``
    """

def custom_call_targets(platform: str) -> dict[str, Any]:
    """Get the registered custom call targets for a platform.

    Parameters:
    - platform: Platform name

    Returns:
        Dictionary of registered custom call targets, keyed by target name
    """

```

### Custom Call Partitioning

Advanced functions for custom operations that support sharding and partitioning.

```python { .api }

62

def register_custom_call_partitioner(

63

name: str,

64

prop_user_sharding: Callable,

65

partition: Callable,

66

infer_sharding_from_operands: Callable,

67

can_side_effecting_have_replicated_sharding: bool = False,

68

c_api: Any | None = None,

69

) -> None:

70

"""

71

Register partitioner for custom call.

72

73

Parameters:

74

- name: Custom call name

75

- prop_user_sharding: Function to propagate user sharding

76

- partition: Function to partition the operation

77

- infer_sharding_from_operands: Function to infer output sharding

78

- can_side_effecting_have_replicated_sharding: Whether side-effecting ops can be replicated

79

- c_api: C API interface (optional)

80

"""

81

82

def register_custom_call_as_batch_partitionable(

83

target_name: str,

84

c_api: Any | None = None,

85

) -> None:

86

"""

87

Register custom call as batch partitionable.

88

89

Parameters:

90

- target_name: Name of the custom call target

91

- c_api: C API interface (optional)

92

"""

93

94

def encode_inspect_sharding_callback(handler: Any) -> bytes:
    """Encode a sharding inspection callback.

    Parameters:
    - handler: Callback handler function

    Returns:
        Encoded callback as bytes
    """

```

### Custom Type System

Functions for registering custom types for use with the XLA FFI.

```python { .api }

def register_custom_type_id(
    type_name: str,
    type_id: Any,
    platform: str = 'cpu',
) -> None:
    """Register a custom type ID for the FFI.

    Parameters:
    - type_name: Unique name for the type
    - type_id: PyCapsule containing a pointer to an ``ffi::TypeId``
    - platform: Target platform
    """

def register_custom_type_id_handler(
    platform: str, handler: CustomTypeIdHandler
) -> None:
    """Register a handler for custom type IDs.

    Parameters:
    - platform: Target platform
    - handler: Handler function used to register type IDs for ``platform``
    """

```

## Usage Examples

### Basic Custom Call

```python

import ctypes

from jaxlib import xla_client


# Example: register a custom function compiled from C/C++.
def create_custom_add_capsule(
    library_path="./libcustom_add.so", symbol="CustomAdd"
):
    """Wrap a function exported by a compiled shared library in a PyCapsule.

    ``register_custom_call_target`` takes a PyCapsule holding a raw function
    pointer, not a Python callable. Look the symbol up with ctypes and wrap it
    with the CPython ``PyCapsule_New`` API.

    Parameters:
    - library_path: Path to the shared library built from your C/C++ code
      (hypothetical in this example)
    - symbol: Exported name of the handler. With ``api_version=1`` this must
      be an XLA FFI handler symbol.

    Returns:
        PyCapsule containing the function pointer
    """
    lib = ctypes.cdll.LoadLibrary(library_path)
    fn_ptr = ctypes.cast(getattr(lib, symbol), ctypes.c_void_p).value
    # Build a private prototype instead of mutating the shared
    # ctypes.pythonapi.PyCapsule_New restype/argtypes.
    pycapsule_new = ctypes.PYFUNCTYPE(
        ctypes.py_object, ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p
    )(("PyCapsule_New", ctypes.pythonapi))
    # No capsule name and no destructor: the library owns the function.
    return pycapsule_new(fn_ptr, None, None)


# Register the custom call
xla_client.register_custom_call_target(
    name="custom_add",
    fn=create_custom_add_capsule(),  # PyCapsule with function pointer
    platform="cpu",
    api_version=1,  # Use typed FFI
    traits=xla_client.CustomCallTargetTraits.DEFAULT,
)

# Check if registered
cpu_targets = xla_client.custom_call_targets("cpu")
print(f"Custom targets: {list(cpu_targets.keys())}")

```

### Custom Call with Partitioning

```python

from jaxlib import xla_client


def prop_user_sharding_fn(user_sharding, shape, backend_config):
    """Propagate a user-specified sharding onto the custom call.

    Called with the user's sharding, the op's shape and the custom call's
    backend_config string; returns the sharding to use.
    """
    # Simplest policy: accept the user's sharding unchanged.
    return user_sharding


def partition_fn(
    arg_shapes, arg_shardings, result_shape, result_sharding, backend_config
):
    """Partition the custom operation.

    Must return a ``(hlo_module, arg_shardings, result_sharding)`` triple.
    This is a serialized HLO module computing one shard, together with the
    shardings it expects for its operands and produces for its result.
    """
    # Building the per-shard computation is specific to the operation.
    raise NotImplementedError("build and return the per-shard HLO module")


def infer_sharding_fn(arg_shapes, arg_shardings, result_shape, backend_config):
    """Infer the output sharding from the operand shardings."""
    # Follow the first operand's sharding when there is one.
    # NOTE(review): confirm that returning None (no inference) is accepted
    # when there are no operand shardings.
    return arg_shardings[0] if arg_shardings else None


# Register partitioner for custom operation
xla_client.register_custom_call_partitioner(
    name="custom_matrix_multiply",
    prop_user_sharding=prop_user_sharding_fn,
    partition=partition_fn,
    infer_sharding_from_operands=infer_sharding_fn,
    can_side_effecting_have_replicated_sharding=False,
)

```

### Custom Types

```python

import ctypes

from jaxlib import xla_client


# Register a custom type (hypothetical library and symbol names).
def create_custom_type_capsule(
    library_path="./libmy_custom_type.so", symbol="my_custom_type_id"
):
    """Wrap the address of an exported ``ffi::TypeId`` in a PyCapsule.

    ``register_custom_type_id`` takes a PyCapsule containing a pointer to an
    ``ffi::TypeId``. Look the exported object up with ctypes and wrap its
    address with the CPython ``PyCapsule_New`` API.

    Parameters:
    - library_path: Path to the shared library that defines the TypeId
    - symbol: Exported name of the TypeId object

    Returns:
        PyCapsule containing a pointer to the TypeId
    """
    lib = ctypes.cdll.LoadLibrary(library_path)
    # NOTE(review): assumes the exported TypeId is a single int64-sized object.
    # Only its address is used here, never its value.
    type_id_ptr = ctypes.addressof(ctypes.c_int64.in_dll(lib, symbol))
    pycapsule_new = ctypes.PYFUNCTYPE(
        ctypes.py_object, ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p
    )(("PyCapsule_New", ctypes.pythonapi))
    # No capsule name and no destructor: the library owns the TypeId.
    return pycapsule_new(type_id_ptr, None, None)


xla_client.register_custom_type_id(
    type_name="MyCustomType",
    type_id=create_custom_type_capsule(),
    platform="cpu",
)

print("Registered custom type: MyCustomType")

```