CtrlK
Blog · Docs · Log in · Get started
Tessl Logo

tessl/pypi-jaxlib

XLA library for JAX providing low-level bindings and hardware acceleration support

Pending
Overview
Eval results
Files

docs/plugin-system.md

Plugin System

Dynamic plugin loading and version management for hardware-specific extensions and third-party integrations.

Capabilities

Plugin Management

Functions for loading and managing PJRT plugins dynamically.

def pjrt_plugin_loaded(plugin_name: str) -> bool:
    """
    Check whether a PJRT plugin has been loaded.

    Loading and initialization are separate steps; a loaded plugin may still
    need initialize_pjrt_plugin() before it can be used.

    Parameters:
    - plugin_name: Name the plugin is registered under (e.g. "cuda")

    Returns:
    True if the plugin is loaded, False otherwise
    """

def load_pjrt_plugin_dynamically(plugin_name: str, library_path: str) -> Any:
    """
    Load a PJRT plugin from a shared library and register it under a name.

    Loading does not initialize the plugin; call initialize_pjrt_plugin()
    with the same name afterwards.

    Parameters:
    - plugin_name: Name to register the plugin under
    - library_path: Filesystem path to the plugin shared library (.so)

    Returns:
    A plugin handle or status object. The return type is declared as Any;
    NOTE(review): the exact object returned is not specified here — confirm.
    """

def load_pjrt_plugin_with_c_api(plugin_name: str, c_api: Any) -> None:
    """
    Register a PJRT plugin from an existing C API object.

    Use this instead of load_pjrt_plugin_dynamically() when the C API is
    already available, so no shared library has to be loaded from disk.

    Parameters:
    - plugin_name: Name to register the plugin under
    - c_api: Existing C API interface (declared as Any; presumably a PJRT
      C API handle — confirm)
    """

def pjrt_plugin_initialized(plugin_name: str) -> bool:
    """
    Check whether a loaded PJRT plugin has been initialized.

    NOTE(review): behavior for a plugin that was never loaded is not
    specified here and may raise rather than return False — check
    pjrt_plugin_loaded() first.

    Parameters:
    - plugin_name: Name the plugin is registered under

    Returns:
    True if the plugin is initialized, False otherwise
    """

def initialize_pjrt_plugin(plugin_name: str) -> None:
    """
    Initialize a loaded PJRT plugin.

    The plugin must already be loaded (see pjrt_plugin_loaded()).
    NOTE(review): behavior when the plugin is already initialized is not
    specified here; guarding the call with pjrt_plugin_initialized()
    avoids relying on it.

    Parameters:
    - plugin_name: Name of the plugin to initialize
    """

Plugin Import System

High-level interface for importing functionality from known plugins with version checking.

# From jaxlib.plugin_support module

def import_from_plugin(
    plugin_name: str, 
    submodule_name: str, 
    *, 
    check_version: bool = True
) -> ModuleType | None:
    """
    Import a submodule from a known plugin, optionally checking its version.

    Parameters:
    - plugin_name: Known plugin name ('cuda' or 'rocm')
    - submodule_name: Submodule to import (e.g. '_triton', '_linalg')
    - check_version: If True, reject a plugin whose version is incompatible
      with jaxlib

    Returns:
    The imported module, or None if the plugin or submodule is unavailable
    or fails the version check. Compare the result against None rather than
    relying on truthiness.
    """

def check_plugin_version(
    plugin_name: str, 
    jaxlib_version: str, 
    plugin_version: str
) -> bool:
    """
    Check whether a plugin version is compatible with a jaxlib version.

    NOTE(review): the exact compatibility rule (exact match vs. release
    prefix) is not specified here — confirm before depending on it.

    Parameters:
    - plugin_name: Name of the plugin (used to identify it)
    - jaxlib_version: jaxlib version string (e.g. jaxlib.__version__)
    - plugin_version: Version string of the plugin

    Returns:
    True if the versions are compatible, False otherwise
    """

def maybe_import_plugin_submodule(
    plugin_module_names: Sequence[str],
    submodule_name: str,
    *,
    check_version: bool = True,
) -> ModuleType | None:
    """
    Import a submodule from the first usable candidate plugin module.

    Candidates are tried in the given order, and the first successful import
    wins.

    Parameters:
    - plugin_module_names: Plugin module names to try, in priority order
    - submodule_name: Submodule to import from the chosen plugin
    - check_version: If True, skip candidates whose version is incompatible

    Returns:
    The first successfully imported submodule, or None if no candidate works
    """

Usage Examples

Loading Plugins

from jaxlib import xla_client

# Check if a plugin is already loaded
cuda_loaded = xla_client.pjrt_plugin_loaded("cuda")
print(f"CUDA plugin loaded: {cuda_loaded}")

# Load a plugin dynamically (example path)
if not cuda_loaded:
    try:
        plugin_path = "/path/to/cuda_plugin.so"  # Hypothetical path
        result = xla_client.load_pjrt_plugin_dynamically("cuda", plugin_path)
        print(f"Plugin load result: {result}")
    except Exception as e:
        print(f"Failed to load CUDA plugin: {e}")
    # Re-query instead of assuming success: the load above may have failed.
    cuda_loaded = xla_client.pjrt_plugin_loaded("cuda")

# A plugin must be loaded before it can be initialized. Initialize it only
# when it is loaded and not yet initialized. This also covers a plugin that
# was already loaded before this script ran, and never initializes twice.
# NOTE(review): pjrt_plugin_initialized() on a plugin that is not loaded may
# raise instead of returning False, so it is only queried for a loaded plugin.
initialized = False
if cuda_loaded:
    try:
        if not xla_client.pjrt_plugin_initialized("cuda"):
            xla_client.initialize_pjrt_plugin("cuda")
            print("CUDA plugin initialized")
        initialized = xla_client.pjrt_plugin_initialized("cuda")
    except Exception as e:
        print(f"Failed to initialize CUDA plugin: {e}")

# Check initialization status (an unloaded plugin counts as not initialized)
print(f"CUDA plugin initialized: {initialized}")

Using Plugin Import System

from jaxlib import plugin_support

# import_from_plugin returns ModuleType | None: None means the submodule is
# unavailable or failed the version check, so compare explicitly with None.

# Try to import CUDA-specific functionality
cuda_linalg = plugin_support.import_from_plugin("cuda", "_linalg")
if cuda_linalg is not None:
    print("CUDA linear algebra module available")
    # Use cuda_linalg.registrations(), etc.
else:
    print("CUDA linear algebra not available")

# Try to import ROCm functionality
rocm_linalg = plugin_support.import_from_plugin("rocm", "_linalg")
if rocm_linalg is not None:
    print("ROCm linear algebra module available")
else:
    print("ROCm linear algebra not available")

# Import with version checking disabled
triton_module = plugin_support.import_from_plugin(
    "cuda", "_triton", check_version=False
)
if triton_module is not None:
    print("Triton module imported (version check skipped)")

Version Compatibility

from jaxlib import plugin_support
import jaxlib

jaxlib_version = jaxlib.__version__

# Check version compatibility manually
plugin_version = "0.7.1"  # Example plugin version
compatible = plugin_support.check_plugin_version(
    "cuda", jaxlib_version, plugin_version
)
print(f"CUDA plugin v{plugin_version} compatible with jaxlib v{jaxlib_version}: {compatible}")

# Try multiple plugin candidates, in priority order; the first candidate that
# imports (and passes the version check) is returned, otherwise None.
plugin_candidates = [".cuda", "jax_cuda13_plugin", "jax_cuda12_plugin"]
cuda_module = plugin_support.maybe_import_plugin_submodule(
    plugin_candidates, "_linalg", check_version=True
)

# The result is ModuleType | None, so test against None explicitly.
if cuda_module is not None:
    print("Successfully imported CUDA module from one of the candidates")
else:
    print("No compatible CUDA module found")

Creating Plugin-Based Clients

from jaxlib import xla_client

# Check available plugins and create a client from the first one that works
plugins_to_try = ["cuda", "rocm", "tpu"]

for plugin_name in plugins_to_try:
    if not xla_client.pjrt_plugin_loaded(plugin_name):
        continue
    try:
        # A loaded plugin must also be initialized before a client can be
        # built from it; skip the call if it is already initialized.
        if not xla_client.pjrt_plugin_initialized(plugin_name):
            xla_client.initialize_pjrt_plugin(plugin_name)

        # Generate default options for the plugin
        if plugin_name == "cuda":
            options = xla_client.generate_pjrt_gpu_plugin_options()
        else:
            options = {}  # Use default options

        # Create client using the plugin
        client = xla_client.make_c_api_client(
            plugin_name=plugin_name,
            options=options
        )

        print(f"Created {plugin_name} client with {len(client.devices())} devices")
        break

    except Exception as e:
        # Best effort: report the failure and try the next candidate plugin.
        print(f"Failed to create {plugin_name} client: {e}")
else:
    # for/else: runs only if the loop never hit `break`, i.e. no candidate
    # produced a client (none loaded, or every loaded one failed).
    # Fall back to CPU.
    print("No usable GPU/TPU plugin client, using CPU")
    client = xla_client.make_cpu_client()

Plugin Information

from jaxlib import xla_client

# Plugin names whose load/initialization state is reported below.
common_plugins = ["cpu", "cuda", "rocm", "tpu"]
separator = "-" * 40

print("Plugin Status:")
print(separator)
for name in common_plugins:
    is_loaded = xla_client.pjrt_plugin_loaded(name)
    if not is_loaded:
        print(f"{name:8}: Not loaded")
        continue
    # Initialization state is only queried for plugins that are loaded.
    is_initialized = xla_client.pjrt_plugin_initialized(name)
    print(f"{name:8}: Loaded={is_loaded}, Initialized={is_initialized}")

# Report how many custom call targets each platform has registered,
# with up to three example target names.
print("\nCustom Call Targets:")
print(separator)
for platform_name in ["cpu", "CUDA", "ROCM"]:
    try:
        targets = xla_client.custom_call_targets(platform_name)
        print(f"{platform_name}: {len(targets)} targets")
        if targets:
            print(f"  Examples: {list(targets.keys())[:3]}")
    except Exception as err:
        print(f"{platform_name}: Error - {err}")

Install with Tessl CLI

npx tessl i tessl/pypi-jaxlib

docs

array-operations.md

compilation-execution.md

custom-operations.md

device-management.md

hardware-operations.md

index.md

plugin-system.md

sharding.md

xla-client.md

tile.json