CtrlK
Blog · Docs · Log in · Get started
Tessl Logo

tessl/pypi-jaxlib

XLA library for JAX providing low-level bindings and hardware acceleration support

Pending
Overview
Eval results
Files

docs/custom-operations.md

Custom Operations

Extensible custom call interface for integrating user-defined operations and hardware-specific kernels into XLA computations.

Capabilities

Custom Call Registration

Functions for registering custom operations that can be called from XLA computations.

class CustomCallTargetTraits(enum.IntFlag):
    """Bit flags describing properties of a registered custom call target."""
    # No special properties.
    DEFAULT = 0
    # Target is safe to record and replay inside an XLA command buffer.
    COMMAND_BUFFER_COMPATIBLE = 1

def register_custom_call_target(
    name: str,
    fn: Any,
    platform: str = 'cpu',
    api_version: int = 0,
    traits: CustomCallTargetTraits = CustomCallTargetTraits.DEFAULT,
) -> None:
    """
    Make a custom call target callable from XLA computations.

    Parameters:
    - name: Symbol under which the target is looked up from a computation
    - fn: PyCapsule wrapping the target's function pointer
    - platform: Platform the target runs on ('cpu', 'gpu', etc.)
    - api_version: XLA FFI calling convention (0 = untyped, 1 = typed)
    - traits: Flags describing properties of the target
    """

def register_custom_call_handler(
    platform: str, handler: CustomCallHandler
) -> None:
    """
    Install the handler that performs custom call registration for a platform.

    Parameters:
    - platform: Platform the handler is responsible for
    - handler: Callable invoked to register custom calls on that platform
    """

def custom_call_targets(platform: str) -> dict[str, Any]:
    """
    Look up every custom call target registered for a platform.

    Parameters:
    - platform: Platform name (e.g. 'cpu', 'gpu')

    Returns:
    Mapping from target name to the registered target object
    """

Custom Call Partitioning

Advanced functionality for custom operations that support sharding and partitioning.

def register_custom_call_partitioner(
    name: str,
    prop_user_sharding: Callable,
    partition: Callable,
    infer_sharding_from_operands: Callable,
    can_side_effecting_have_replicated_sharding: bool = False,
    c_api: Any | None = None,
) -> None:
    """
    Register partitioner for custom call.
    
    Parameters:
    - name: Custom call name
    - prop_user_sharding: Function to propagate user sharding
    - partition: Function to partition the operation
    - infer_sharding_from_operands: Function to infer output sharding
    - can_side_effecting_have_replicated_sharding: Whether side-effecting ops can be replicated
    - c_api: C API interface (optional)
    """

def register_custom_call_as_batch_partitionable(
    target_name: str,
    c_api: Any | None = None,
) -> None:
    """
    Register custom call as batch partitionable.
    
    Parameters:
    - target_name: Name of the custom call target
    - c_api: C API interface (optional)
    """

def encode_inspect_sharding_callback(handler: Any) -> bytes:
    """
    Serialize a sharding-inspection callback for embedding in a computation.

    Parameters:
    - handler: Callback invoked to inspect shardings

    Returns:
    The callback encoded as a bytes payload
    """

Custom Type System

Support for registering custom types for use with the FFI system.

def register_custom_type_id(
    type_name: str,
    type_id: Any,
    platform: str = 'cpu',
) -> None:
    """
    Associate a unique name with an FFI type ID.

    Parameters:
    - type_name: Unique name identifying the type
    - type_id: PyCapsule holding a pointer to the ffi::TypeId
    - platform: Platform the type is registered for
    """

def register_custom_type_id_handler(
    platform: str, handler: CustomTypeIdHandler
) -> None:
    """
    Install the handler that performs custom type ID registration for a platform.

    Parameters:
    - platform: Platform the handler is responsible for
    - handler: Callable invoked to register type IDs on that platform
    """

Usage Examples

Basic Custom Call

from jaxlib import xla_client
import ctypes

# Example: Register a simple custom function
# First, you would compile a C/C++ function and get a pointer

# Hypothetical custom function (in practice, this would be from a compiled library)
def create_custom_add_capsule():
    """Placeholder for building a PyCapsule around a compiled kernel."""
    # In a real setup you would load a shared library (e.g. via ctypes)
    # and wrap the exported function pointer in a PyCapsule.
    pass

# Register the custom call under the name "custom_add" for the CPU platform.
# NOTE(review): create_custom_add_capsule() above is a stub returning None —
# in real code fn must be a PyCapsule wrapping the compiled function pointer.
xla_client.register_custom_call_target(
    name="custom_add",
    fn=create_custom_add_capsule(),  # PyCapsule with function pointer
    platform="cpu",
    api_version=1,  # api_version=1 selects the typed XLA FFI calling convention
    traits=xla_client.CustomCallTargetTraits.DEFAULT
)

# Verify registration by listing all CPU custom call targets.
cpu_targets = xla_client.custom_call_targets("cpu")
print(f"Custom targets: {list(cpu_targets.keys())}")

Custom Call with Partitioning

from jaxlib import xla_client

def prop_user_sharding_fn(op_sharding, operand_shardings):
    """Propagate the user-specified sharding unchanged."""
    # Placeholder: a real implementation would reconcile the user's
    # sharding with the operand shardings before returning it.
    return op_sharding

def partition_fn(operands, partition_id, total_partitions):
    """Partition the custom operation."""
    # Placeholder: a real implementation would slice each operand for
    # this partition_id out of total_partitions.
    return operands

def infer_sharding_fn(operand_shardings):
    """Infer output sharding from operand shardings."""
    # Placeholder heuristic: mirror the first operand's sharding when one
    # exists; with no operands there is nothing to infer.
    if not operand_shardings:
        return None
    return operand_shardings[0]

# Register the three sharding callbacks for the "custom_matrix_multiply"
# custom call so XLA's SPMD partitioner can consult them.
xla_client.register_custom_call_partitioner(
    name="custom_matrix_multiply",
    prop_user_sharding=prop_user_sharding_fn,
    partition=partition_fn,
    infer_sharding_from_operands=infer_sharding_fn,
    can_side_effecting_have_replicated_sharding=False  # disallow replicated sharding for side-effecting ops
)

Custom Types

from jaxlib import xla_client

# Register custom type (hypothetical example)
def create_custom_type_capsule():
    """Placeholder for building a PyCapsule around an ffi::TypeId pointer."""
    # A real implementation would expose the TypeId from native code and
    # wrap its address in a PyCapsule here.
    pass

# Register the type under a unique name for the CPU platform.
# NOTE(review): create_custom_type_capsule() above is a stub returning None —
# a real registration needs a PyCapsule holding a pointer to an ffi::TypeId.
xla_client.register_custom_type_id(
    type_name="MyCustomType",
    type_id=create_custom_type_capsule(),
    platform="cpu"
)

print("Registered custom type: MyCustomType")

Install with Tessl CLI

npx tessl i tessl/pypi-jaxlib

docs

array-operations.md

compilation-execution.md

custom-operations.md

device-management.md

hardware-operations.md

index.md

plugin-system.md

sharding.md

xla-client.md

tile.json