# Custom Operations

Extensible custom call interface for integrating user-defined operations and hardware-specific kernels into XLA computations.

## Capabilities

### Custom Call Registration

Functions for registering custom operations that can be called from XLA computations.

```python { .api }
class CustomCallTargetTraits(enum.IntFlag):
    """Traits for custom call targets."""
    DEFAULT = 0
    COMMAND_BUFFER_COMPATIBLE = 1

def register_custom_call_target(
    name: str,
    fn: Any,
    platform: str = 'cpu',
    api_version: int = 0,
    traits: CustomCallTargetTraits = CustomCallTargetTraits.DEFAULT,
) -> None:
    """
    Register a custom call target function.

    Parameters:
    - name: Name of the custom call
    - fn: PyCapsule containing function pointer
    - platform: Target platform ('cpu', 'gpu', etc.)
    - api_version: XLA FFI version (0 for untyped, 1 for typed)
    - traits: Custom call traits
    """

def register_custom_call_handler(
    platform: str, handler: CustomCallHandler
) -> None:
    """
    Register a custom call handler for a platform.

    Parameters:
    - platform: Target platform
    - handler: Handler function for registering custom calls
    """

def custom_call_targets(platform: str) -> dict[str, Any]:
    """
    Get registered custom call targets for a platform.

    Parameters:
    - platform: Platform name

    Returns:
    Dictionary of registered custom call targets
    """
```
### Custom Call Partitioning

Advanced functionality for custom operations that support sharding and partitioning.

```python { .api }
def register_custom_call_partitioner(
    name: str,
    prop_user_sharding: Callable,
    partition: Callable,
    infer_sharding_from_operands: Callable,
    can_side_effecting_have_replicated_sharding: bool = False,
    c_api: Any | None = None,
) -> None:
    """
    Register partitioner for custom call.

    Parameters:
    - name: Custom call name
    - prop_user_sharding: Function to propagate user sharding
    - partition: Function to partition the operation
    - infer_sharding_from_operands: Function to infer output sharding
    - can_side_effecting_have_replicated_sharding: Whether side-effecting ops can be replicated
    - c_api: C API interface (optional)
    """

def register_custom_call_as_batch_partitionable(
    target_name: str,
    c_api: Any | None = None,
) -> None:
    """
    Register custom call as batch partitionable.

    Parameters:
    - target_name: Name of the custom call target
    - c_api: C API interface (optional)
    """

def encode_inspect_sharding_callback(handler: Any) -> bytes:
    """
    Encode sharding inspection callback.

    Parameters:
    - handler: Callback handler function

    Returns:
    Encoded callback as bytes
    """
```
### Custom Type System

Support for registering custom types for use with the FFI system.

```python { .api }
def register_custom_type_id(
    type_name: str,
    type_id: Any,
    platform: str = 'cpu',
) -> None:
    """
    Register custom type ID for FFI.

    Parameters:
    - type_name: Unique name for the type
    - type_id: PyCapsule containing pointer to ffi::TypeId
    - platform: Target platform
    """

def register_custom_type_id_handler(
    platform: str, handler: CustomTypeIdHandler
) -> None:
    """
    Register handler for custom type IDs.

    Parameters:
    - platform: Target platform
    - handler: Handler function for registering type IDs
    """
```
## Usage Examples

### Basic Custom Call

```python
from jaxlib import xla_client
import ctypes

# Example: Register a simple custom function
# First, you would compile a C/C++ function and get a pointer

# Hypothetical custom function (in practice, this would be from a compiled library)
def create_custom_add_capsule():
    # This is a simplified example - in practice you'd load from a shared library
    # and create a PyCapsule with the function pointer
    pass

# Register the custom call
xla_client.register_custom_call_target(
    name="custom_add",
    fn=create_custom_add_capsule(),  # PyCapsule with function pointer
    platform="cpu",
    api_version=1,  # Use typed FFI
    traits=xla_client.CustomCallTargetTraits.DEFAULT
)

# Check if registered
cpu_targets = xla_client.custom_call_targets("cpu")
print(f"Custom targets: {list(cpu_targets.keys())}")
```
### Custom Call with Partitioning

```python
from jaxlib import xla_client

def prop_user_sharding_fn(op_sharding, operand_shardings):
    """Propagate user-specified sharding."""
    # Implementation would handle sharding propagation
    return op_sharding

def partition_fn(operands, partition_id, total_partitions):
    """Partition the custom operation."""
    # Implementation would partition operands appropriately
    return operands

def infer_sharding_fn(operand_shardings):
    """Infer output sharding from operand shardings."""
    # Implementation would infer appropriate output sharding
    return operand_shardings[0] if operand_shardings else None

# Register partitioner for custom operation
xla_client.register_custom_call_partitioner(
    name="custom_matrix_multiply",
    prop_user_sharding=prop_user_sharding_fn,
    partition=partition_fn,
    infer_sharding_from_operands=infer_sharding_fn,
    can_side_effecting_have_replicated_sharding=False
)
```
### Custom Types

```python
from jaxlib import xla_client

# Register custom type (hypothetical example)
def create_custom_type_capsule():
    # In practice, this would create a PyCapsule containing
    # a pointer to an ffi::TypeId for your custom type
    pass

xla_client.register_custom_type_id(
    type_name="MyCustomType",
    type_id=create_custom_type_capsule(),
    platform="cpu"
)

print("Registered custom type: MyCustomType")
```