XLA library for JAX, providing low-level bindings and hardware-acceleration support
—
Dynamic plugin loading and version management for hardware-specific extensions and third-party integrations.
Functions for loading and managing PJRT plugins dynamically.
def pjrt_plugin_loaded(plugin_name: str) -> bool:
    """
    Check whether a PJRT plugin with the given name has been loaded.

    Loading and initialization are tracked separately; use
    pjrt_plugin_initialized to check whether a loaded plugin has also been
    initialized.

    Parameters:
    - plugin_name: Name of the plugin to check (e.g. "cuda")

    Returns:
    True if the plugin is loaded, False otherwise.
    """
def load_pjrt_plugin_dynamically(plugin_name: str, library_path: str) -> Any:
    """
    Load a PJRT plugin from a shared library and register it under a name.

    Loading does not by itself make the plugin usable for initialization
    checks; call initialize_pjrt_plugin afterwards (it requires the plugin to
    be loaded first).

    Parameters:
    - plugin_name: Name to register the plugin under; later calls such as
      pjrt_plugin_loaded / initialize_pjrt_plugin refer to it by this name
    - library_path: Filesystem path to the plugin's shared library

    Returns:
    A plugin handle or status object.
    NOTE(review): the result is annotated only as Any; confirm its exact type
    against the installed jaxlib before relying on it.
    """
def load_pjrt_plugin_with_c_api(plugin_name: str, c_api: Any) -> None:
    """
    Register a PJRT plugin from an already-obtained C API object.

    This is the alternative to load_pjrt_plugin_dynamically when no shared
    library path needs to be opened.

    Parameters:
    - plugin_name: Name to register the plugin under
    - c_api: Existing C API interface for the plugin
      (presumably the PJRT C API object exported by the plugin — TODO confirm
      the accepted Python type with jaxlib)
    """
def pjrt_plugin_initialized(plugin_name: str) -> bool:
    """
    Check whether a PJRT plugin has been initialized.

    Useful as a guard before calling initialize_pjrt_plugin, so a plugin is
    not initialized twice.

    Parameters:
    - plugin_name: Name of the plugin

    Returns:
    True if the plugin is initialized, False otherwise.
    """
def initialize_pjrt_plugin(plugin_name: str) -> None:
    """
    Initialize a loaded PJRT plugin.

    The plugin must already be loaded (via load_pjrt_plugin_dynamically or
    load_pjrt_plugin_with_c_api) before this is called; use
    pjrt_plugin_initialized to check whether initialization already happened.

    Parameters:
    - plugin_name: Name of the plugin to initialize
    """


# High-level interface for importing functionality from known plugins with
# version checking.
# From jaxlib.plugin_support module
def import_from_plugin(
    plugin_name: str,
    submodule_name: str,
    *,
    check_version: bool = True
) -> ModuleType | None:
    """
    Import a submodule from a known plugin, optionally checking its version.

    Parameters:
    - plugin_name: Plugin name ('cuda' or 'rocm')
    - submodule_name: Submodule name (e.g., '_triton', '_linalg')
    - check_version: Keyword-only; when True (the default), a plugin whose
      version is incompatible with jaxlib is rejected

    Returns:
    The imported module, or None if the submodule is not available or (when
    check_version is True) its version is incompatible.
    """
def check_plugin_version(
    plugin_name: str,
    jaxlib_version: str,
    plugin_version: str
) -> bool:
    """
    Check whether a plugin version is compatible with a jaxlib version.

    Parameters:
    - plugin_name: Name of the plugin (used for identification; presumably
      only in diagnostics — TODO confirm)
    - jaxlib_version: Version string of jaxlib (e.g. jaxlib.__version__)
    - plugin_version: Version string of the plugin

    Returns:
    True if the versions are compatible, False otherwise.
    NOTE(review): the exact compatibility rule (exact match vs. prefix match)
    is not stated here; confirm against jaxlib.plugin_support.
    """
def maybe_import_plugin_submodule(
    plugin_module_names: Sequence[str],
    submodule_name: str,
    *,
    check_version: bool = True,
) -> ModuleType | None:
    """
    Import a plugin submodule from the first candidate module that provides it.

    Parameters:
    - plugin_module_names: Candidate plugin module names, tried in order
      (e.g. [".cuda", "jax_cuda13_plugin", "jax_cuda12_plugin"])
    - submodule_name: Submodule to import from the candidate
    - check_version: Keyword-only; whether to reject candidates whose version
      is incompatible with jaxlib

    Returns:
    The first successfully imported module, or None if no candidate worked.
    """


# Example: checking, loading and initializing a PJRT plugin via xla_client.
from jaxlib import xla_client
# Check whether the CUDA plugin is already loaded.
cuda_loaded = xla_client.pjrt_plugin_loaded("cuda")
print(f"CUDA plugin loaded: {cuda_loaded}")

# Load the plugin from its shared library if it is not loaded yet.
if not cuda_loaded:
    plugin_path = "/path/to/cuda_plugin.so"  # Hypothetical path
    try:
        result = xla_client.load_pjrt_plugin_dynamically("cuda", plugin_path)
        print(f"Plugin load result: {result}")
    except Exception as e:
        print(f"Failed to load CUDA plugin: {e}")

# A plugin must be loaded before it can be initialized. Guarding with
# pjrt_plugin_initialized avoids initializing twice, and also covers a plugin
# that was loaded earlier (before this script) but never initialized.
needs_init = (
    xla_client.pjrt_plugin_loaded("cuda")
    and not xla_client.pjrt_plugin_initialized("cuda")
)
if needs_init:
    try:
        xla_client.initialize_pjrt_plugin("cuda")
        print("CUDA plugin initialized")
    except Exception as e:
        # Reported separately so an initialization failure is not mistaken
        # for a load failure.
        print(f"Failed to initialize CUDA plugin: {e}")

# Check initialization status
initialized = xla_client.pjrt_plugin_initialized("cuda")
print(f"CUDA plugin initialized: {initialized}")


# Example: importing plugin submodules via plugin_support.
from jaxlib import plugin_support
# Probe the CUDA build of the linear-algebra submodule. The version check is on
# by default, and import_from_plugin yields None when nothing usable is found.
# When available, use cuda_linalg.registrations(), etc.
cuda_linalg = plugin_support.import_from_plugin("cuda", "_linalg")
print(
    "CUDA linear algebra module available"
    if cuda_linalg
    else "CUDA linear algebra not available"
)

# Same probe for the ROCm build.
rocm_linalg = plugin_support.import_from_plugin("rocm", "_linalg")
print(
    "ROCm linear algebra module available"
    if rocm_linalg
    else "ROCm linear algebra not available"
)

# check_version is keyword-only; passing False skips the compatibility check.
triton_module = plugin_support.import_from_plugin(
    "cuda", "_triton", check_version=False
)
if triton_module:
    print("Triton module imported (version check skipped)")


# Example: explicit version checks and multi-candidate imports.
from jaxlib import plugin_support
import jaxlib

# Compare an example plugin version against the running jaxlib version.
# NOTE(review): confirm the installed jaxlib exposes __version__ at the
# package top level.
jaxlib_version = jaxlib.__version__
plugin_version = "0.7.1"  # Example plugin version
compatible = plugin_support.check_plugin_version(
    "cuda", jaxlib_version, plugin_version
)
print(
    f"CUDA plugin v{plugin_version} compatible with jaxlib v{jaxlib_version}: "
    f"{compatible}"
)

# Candidates are tried in order; the first one that imports (and passes the
# version check) is returned, otherwise None.
plugin_candidates = [".cuda", "jax_cuda13_plugin", "jax_cuda12_plugin"]
cuda_module = plugin_support.maybe_import_plugin_submodule(
    plugin_candidates, "_linalg", check_version=True
)
print(
    "Successfully imported CUDA module from one of the candidates"
    if cuda_module
    else "No compatible CUDA module found"
)


# Example: creating a client from whichever accelerator plugin is available.
from jaxlib import xla_client
# Create a client from the first loaded accelerator plugin that works,
# falling back to the CPU client otherwise.
plugins_to_try = ["cuda", "rocm", "tpu"]
for plugin_name in plugins_to_try:
    if not xla_client.pjrt_plugin_loaded(plugin_name):
        continue
    try:
        # A PJRT plugin must be initialized once before other C API calls
        # such as client creation; skip if that already happened elsewhere.
        if not xla_client.pjrt_plugin_initialized(plugin_name):
            xla_client.initialize_pjrt_plugin(plugin_name)
        # Generate default options for the plugin. The GPU option generator
        # is only applied to CUDA here; other plugins use default options.
        if plugin_name == "cuda":
            options = xla_client.generate_pjrt_gpu_plugin_options()
        else:
            options = {}
        client = xla_client.make_c_api_client(
            plugin_name=plugin_name,
            options=options,
        )
        print(f"Created {plugin_name} client with {len(client.devices())} devices")
        break
    except Exception as e:
        print(f"Failed to create {plugin_name} client: {e}")
else:
    # for/else: reached only when no plugin produced a client (no `break`).
    # That includes the case where plugins were loaded but every client
    # creation failed.
    print("No GPU/TPU plugins available, using CPU")
    client = xla_client.make_cpu_client()


# Example: reporting plugin status and registered custom call targets.
from jaxlib import xla_client
# Report the load/initialization state of commonly used plugin names.
common_plugins = ["cpu", "cuda", "rocm", "tpu"]
print("Plugin Status:")
print("-" * 40)
for plugin in common_plugins:
    loaded = xla_client.pjrt_plugin_loaded(plugin)
    if not loaded:
        print(f"{plugin:8}: Not loaded")
        continue
    # Initialization state is only queried for plugins that are loaded.
    initialized = xla_client.pjrt_plugin_initialized(plugin)
    print(f"{plugin:8}: Loaded={loaded}, Initialized={initialized}")

# Count the custom call targets registered per platform and show up to three.
# NOTE(review): platform names are passed verbatim ("cpu" lower-case,
# "CUDA"/"ROCM" upper-case); confirm the casing custom_call_targets expects.
print("\nCustom Call Targets:")
print("-" * 40)
for platform in ["cpu", "CUDA", "ROCM"]:
    try:
        targets = xla_client.custom_call_targets(platform)
        print(f"{platform}: {len(targets)} targets")
        if targets:
            # Iterating the mapping yields its keys (the target names).
            print(f" Examples: {list(targets)[:3]}")
    except Exception as e:
        print(f"{platform}: Error - {e}")

# Install with Tessl CLI:
#   npx tessl i tessl/pypi-jaxlib