XLA library for JAX providing low-level bindings and hardware acceleration support
—
Extensible custom call interface for integrating user-defined operations and hardware-specific kernels into XLA computations.
Functions for registering custom operations that can be called from XLA computations.
class CustomCallTargetTraits(enum.IntFlag):
    """Bit-flag traits that can be attached to a custom call target.

    Being an ``IntFlag``, members compare equal to plain ints and may be
    combined with ``|``.
    """

    # No traits set.
    DEFAULT = 0
    # NOTE(review): presumably marks targets usable inside XLA command
    # buffers, as the name suggests — confirm against XLA docs.
    COMMAND_BUFFER_COMPATIBLE = 1 << 0
def register_custom_call_target(
    name: str,
    fn: Any,
    platform: str = 'cpu',
    api_version: int = 0,
    traits: CustomCallTargetTraits = CustomCallTargetTraits.DEFAULT,
) -> None:
    """Register *fn* as the implementation of the custom call *name*.

    Args:
      name: Name of the custom call.
      fn: PyCapsule wrapping the function pointer.
      platform: Platform the target is registered for ('cpu', 'gpu', etc.).
      api_version: XLA FFI version to use: 0 for the untyped API, 1 for the
        typed API.
      traits: Custom call traits for the target.
    """
    ...  # Documentation stub: the implementation is not shown here.
def register_custom_call_handler(
    platform: str, handler: CustomCallHandler
) -> None:
    """Install the handler that registers custom calls for *platform*.

    Args:
      platform: Target platform.
      handler: Handler function used to register custom calls.
    """
    ...  # Documentation stub: the implementation is not shown here.
def custom_call_targets(platform: str) -> dict[str, Any]:
    """Look up the custom call targets registered for *platform*.

    Args:
      platform: Platform name.

    Returns:
      A dictionary of the registered custom call targets.
    """
    ...  # Documentation stub: the implementation is not shown here.


# Advanced functionality for custom operations that support sharding and partitioning.
def register_custom_call_partitioner(
name: str,
prop_user_sharding: Callable,
partition: Callable,
infer_sharding_from_operands: Callable,
can_side_effecting_have_replicated_sharding: bool = False,
c_api: Any | None = None,
) -> None:
"""
Register partitioner for custom call.
Parameters:
- name: Custom call name
- prop_user_sharding: Function to propagate user sharding
- partition: Function to partition the operation
- infer_sharding_from_operands: Function to infer output sharding
- can_side_effecting_have_replicated_sharding: Whether side-effecting ops can be replicated
- c_api: C API interface (optional)
"""
def register_custom_call_as_batch_partitionable(
target_name: str,
c_api: Any | None = None,
) -> None:
"""
Register custom call as batch partitionable.
Parameters:
- target_name: Name of the custom call target
- c_api: C API interface (optional)
"""
def encode_inspect_sharding_callback(handler: Any) -> bytes:
    """Serialize a sharding-inspection callback.

    Args:
      handler: Callback handler function.

    Returns:
      The encoded callback as ``bytes``.
    """
    ...  # Documentation stub: the implementation is not shown here.


# Support for registering custom types for use with the FFI system.
def register_custom_type_id(
    type_name: str,
    type_id: Any,
    platform: str = 'cpu',
) -> None:
    """Register an FFI custom type ID under *type_name*.

    Args:
      type_name: Unique name for the type.
      type_id: PyCapsule holding a pointer to an ``ffi::TypeId``.
      platform: Target platform.
    """
    ...  # Documentation stub: the implementation is not shown here.
def register_custom_type_id_handler(
    platform: str, handler: CustomTypeIdHandler
) -> None:
    """Install the handler that registers custom type IDs for *platform*.

    Args:
      platform: Target platform.
      handler: Handler function used to register type IDs.
    """
    ...  # Documentation stub: the implementation is not shown here.


from jaxlib import xla_client
import ctypes

# Example: register a simple custom function.
# First, compile a C/C++ function and obtain a pointer to it.
# Hypothetical custom function (in practice this comes from a compiled
# library).


def create_custom_add_capsule():
    """Return a PyCapsule wrapping the compiled ``custom_add`` function pointer.

    This example ships no compiled library, so there is no pointer to wrap.
    In practice you would load the shared library (for instance with
    ``ctypes``) and wrap the function pointer in a PyCapsule.

    Raises:
      NotImplementedError: Always, until replaced by a real implementation.
        Silently returning ``None`` would hand an invalid ``fn`` (which must
        be a PyCapsule) to ``register_custom_call_target``.
    """
    raise NotImplementedError(
        "create_custom_add_capsule is a placeholder: load your compiled "
        "library and return a PyCapsule wrapping the function pointer."
    )


# Register the custom call.
xla_client.register_custom_call_target(
    name="custom_add",
    fn=create_custom_add_capsule(),  # PyCapsule with function pointer
    platform="cpu",
    api_version=1,  # Use typed FFI
    traits=xla_client.CustomCallTargetTraits.DEFAULT,
)

# Check that the target was registered.
cpu_targets = xla_client.custom_call_targets("cpu")
print(f"Custom targets: {list(cpu_targets.keys())}")

from jaxlib import xla_client
# The callbacks below use the argument lists XLA passes to custom call
# partitioners; the original two/three-argument signatures would fail with a
# TypeError as soon as XLA invoked them.


def prop_user_sharding_fn(user_sharding, shape, backend_config):
    """Propagate the sharding requested by the custom call's user.

    XLA invokes this as ``prop_user_sharding(sharding, shape, backend_config)``
    and uses the returned sharding for the custom call.
    """
    # Simplest policy: accept the user's sharding unchanged.
    return user_sharding


def partition_fn(arg_shapes, arg_shardings, result_shape, result_sharding,
                 backend_config):
    """Lower the custom call to a per-shard computation.

    XLA invokes this with the operand shapes and shardings, the result shape
    and sharding, and the backend config. It must return a tuple
    ``(computation, arg_shardings, result_sharding)`` where ``computation`` is
    the per-shard XLA computation.

    Raises:
      NotImplementedError: This example does not build a computation.
    """
    raise NotImplementedError(
        "partition_fn is a placeholder: build the per-shard computation and "
        "return (computation, arg_shardings, result_sharding)."
    )


def infer_sharding_fn(arg_shapes, arg_shardings, result_shape, backend_config):
    """Infer the output sharding from the operand shardings.

    Returns the first operand's sharding, or ``None`` (no inference) when the
    custom call has no operands.
    """
    return arg_shardings[0] if arg_shardings else None


# Register the partitioner for the custom operation.
xla_client.register_custom_call_partitioner(
    name="custom_matrix_multiply",
    prop_user_sharding=prop_user_sharding_fn,
    partition=partition_fn,
    infer_sharding_from_operands=infer_sharding_fn,
    can_side_effecting_have_replicated_sharding=False,
)

from jaxlib import xla_client
# Register a custom type (hypothetical example).
def create_custom_type_capsule():
    """Return a PyCapsule holding a pointer to an ``ffi::TypeId``.

    In practice this would wrap the ``ffi::TypeId`` of your custom type,
    exported from your compiled library.

    Raises:
      NotImplementedError: Always, until replaced by a real implementation.
        Silently returning ``None`` would register an invalid ``type_id``.
    """
    raise NotImplementedError(
        "create_custom_type_capsule is a placeholder: return a PyCapsule "
        "holding a pointer to your type's ffi::TypeId."
    )


xla_client.register_custom_type_id(
    type_name="MyCustomType",
    type_id=create_custom_type_capsule(),
    platform="cpu",
)
print("Registered custom type: MyCustomType")

# Install with Tessl CLI
npx tessl i tessl/pypi-jaxlib