# Plugin System

Dynamic plugin loading and version management for hardware-specific extensions and third-party integrations.

## Capabilities

### Plugin Management

Functions for dynamically loading, initializing, and querying PJRT plugins.

```python { .api }
# From jaxlib.xla_client module
from typing import Any


def pjrt_plugin_loaded(plugin_name: str) -> bool:
    """
    Check whether a PJRT plugin is loaded.

    Parameters:
    - plugin_name: Name of the plugin to check

    Returns:
    True if the plugin is loaded, False otherwise
    """


def load_pjrt_plugin_dynamically(plugin_name: str, library_path: str) -> Any:
    """
    Load a PJRT plugin from a shared library.

    A freshly loaded plugin still has to be initialized with
    initialize_pjrt_plugin() before it is used.

    Parameters:
    - plugin_name: Name to register the plugin under
    - library_path: Filesystem path to the plugin shared library

    Returns:
    Plugin handle or status
    """


def load_pjrt_plugin_with_c_api(plugin_name: str, c_api: Any) -> None:
    """
    Load a PJRT plugin from an existing C API object instead of a shared
    library path.

    Parameters:
    - plugin_name: Name to register the plugin under
    - c_api: Existing C API interface for the plugin
    """


def pjrt_plugin_initialized(plugin_name: str) -> bool:
    """
    Check whether a loaded PJRT plugin has been initialized.

    Only query plugins for which pjrt_plugin_loaded() returns True.

    Parameters:
    - plugin_name: Name of the plugin

    Returns:
    True if the plugin is initialized, False otherwise
    """


def initialize_pjrt_plugin(plugin_name: str) -> None:
    """
    Initialize a loaded PJRT plugin.

    The plugin must already be loaded (for example with
    load_pjrt_plugin_dynamically() or load_pjrt_plugin_with_c_api())
    before this is called. Use pjrt_plugin_initialized() to check whether
    initialization is still needed.

    Parameters:
    - plugin_name: Name of the plugin to initialize
    """
```

### Plugin Import System

High-level interface for importing submodules from known plugins, with optional version checking.

```python { .api }
70
# From jaxlib.plugin_support module
71
72
def import_from_plugin(
73
plugin_name: str,
74
submodule_name: str,
75
*,
76
check_version: bool = True
77
) -> ModuleType | None:
78
"""
79
Import a submodule from a known plugin with version checking.
80
81
Parameters:
82
- plugin_name: Plugin name ('cuda' or 'rocm')
83
- submodule_name: Submodule name (e.g., '_triton', '_linalg')
84
- check_version: Whether to check version compatibility
85
86
Returns:
87
Imported module or None if not available/incompatible
88
"""
89
90
def check_plugin_version(
91
plugin_name: str,
92
jaxlib_version: str,
93
plugin_version: str
94
) -> bool:
95
"""
96
Check if plugin version is compatible with jaxlib version.
97
98
Parameters:
99
- plugin_name: Name of the plugin
100
- jaxlib_version: Version of jaxlib
101
- plugin_version: Version of the plugin
102
103
Returns:
104
True if versions are compatible, False otherwise
105
"""
106
107
def maybe_import_plugin_submodule(
108
plugin_module_names: Sequence[str],
109
submodule_name: str,
110
*,
111
check_version: bool = True,
112
) -> ModuleType | None:
113
"""
114
Try to import plugin submodule from multiple candidates.
115
116
Parameters:
117
- plugin_module_names: List of plugin module names to try
118
- submodule_name: Submodule to import
119
- check_version: Whether to check version compatibility
120
121
Returns:
122
First successfully imported module or None
123
"""
124
```

## Usage Examples

### Loading Plugins

```python
from jaxlib import xla_client

# Check if a plugin is already loaded
cuda_loaded = xla_client.pjrt_plugin_loaded("cuda")
print(f"CUDA plugin loaded: {cuda_loaded}")

# Load the plugin dynamically if it is not already available
if not cuda_loaded:
    try:
        plugin_path = "/path/to/cuda_plugin.so"  # Hypothetical path
        result = xla_client.load_pjrt_plugin_dynamically("cuda", plugin_path)
        print(f"Plugin load result: {result}")
    except Exception as e:
        print(f"Failed to load CUDA plugin: {e}")

# Only initialize or query initialization status once the plugin is loaded.
# This covers both a plugin we just loaded and one that was loaded earlier
# but never initialized.
if xla_client.pjrt_plugin_loaded("cuda"):
    if not xla_client.pjrt_plugin_initialized("cuda"):
        try:
            xla_client.initialize_pjrt_plugin("cuda")
            print("CUDA plugin initialized")
        except Exception as e:
            print(f"Failed to initialize CUDA plugin: {e}")

    # Check initialization status
    initialized = xla_client.pjrt_plugin_initialized("cuda")
    print(f"CUDA plugin initialized: {initialized}")
else:
    print("CUDA plugin not loaded; skipping initialization")
```

### Using Plugin Import System

```python
from jaxlib import plugin_support

# import_from_plugin returns a module, or None when the plugin is not
# available or its version is incompatible, so compare against None.

# Try to import CUDA-specific functionality
cuda_linalg = plugin_support.import_from_plugin("cuda", "_linalg")
if cuda_linalg is not None:
    print("CUDA linear algebra module available")
    # Use cuda_linalg.registrations(), etc.
else:
    print("CUDA linear algebra not available")

# Try to import ROCm functionality
rocm_linalg = plugin_support.import_from_plugin("rocm", "_linalg")
if rocm_linalg is not None:
    print("ROCm linear algebra module available")
else:
    print("ROCm linear algebra not available")

# Import with version checking disabled. A version-incompatible plugin is
# not rejected in this mode, so only use it when you know the versions match.
triton_module = plugin_support.import_from_plugin(
    "cuda", "_triton", check_version=False
)
if triton_module is not None:
    print("Triton module imported (version check skipped)")
else:
    print("Triton module not available")
```

### Version Compatibility

```python
from jaxlib import plugin_support
# The jaxlib version string lives in the jaxlib.version module; the package
# __init__ is not relied upon to re-export it.
from jaxlib.version import __version__ as jaxlib_version

# Check version compatibility manually
plugin_version = "0.7.1"  # Example plugin version
compatible = plugin_support.check_plugin_version(
    "cuda", jaxlib_version, plugin_version
)
print(f"CUDA plugin v{plugin_version} compatible with jaxlib v{jaxlib_version}: {compatible}")

# Try multiple plugin candidates in order; the first one that imports
# successfully is returned.
plugin_candidates = [".cuda", "jax_cuda13_plugin", "jax_cuda12_plugin"]
cuda_module = plugin_support.maybe_import_plugin_submodule(
    plugin_candidates, "_linalg", check_version=True
)

if cuda_module is not None:
    print("Successfully imported CUDA module from one of the candidates")
else:
    print("No compatible CUDA module found")
```

### Creating Plugin-Based Clients

```python
from jaxlib import xla_client

# Create a client from the first loaded plugin that works
plugins_to_try = ["cuda", "rocm", "tpu"]

for plugin_name in plugins_to_try:
    if not xla_client.pjrt_plugin_loaded(plugin_name):
        continue
    try:
        # A loaded plugin must be initialized before a client is created
        # from it; skip this step if it was already initialized.
        if not xla_client.pjrt_plugin_initialized(plugin_name):
            xla_client.initialize_pjrt_plugin(plugin_name)

        # Generate default options for the plugin
        if plugin_name == "cuda":
            options = xla_client.generate_pjrt_gpu_plugin_options()
        else:
            options = {}  # Use default options

        # Create client using the plugin
        client = xla_client.make_c_api_client(
            plugin_name=plugin_name,
            options=options,
        )
    except Exception as e:
        print(f"Failed to create {plugin_name} client: {e}")
        continue

    print(f"Created {plugin_name} client with {len(client.devices())} devices")
    break
else:
    # Reached only when no plugin produced a client: fall back to CPU
    print("No GPU/TPU plugins available, using CPU")
    client = xla_client.make_cpu_client()
```

### Plugin Information

```python
from jaxlib import xla_client

COMMON_PLUGINS = ("cpu", "cuda", "rocm", "tpu")
SEPARATOR = "-" * 40

# Report load and initialization status for a few well-known plugin names.
print("Plugin Status:")
print(SEPARATOR)
for name in COMMON_PLUGINS:
    is_loaded = xla_client.pjrt_plugin_loaded(name)
    if not is_loaded:
        print(f"{name:8}: Not loaded")
        continue
    # Initialization status is only queried for plugins that are loaded.
    is_initialized = xla_client.pjrt_plugin_initialized(name)
    print(f"{name:8}: Loaded={is_loaded}, Initialized={is_initialized}")

# Summarize the custom call targets registered for each platform.
print("\nCustom Call Targets:")
print(SEPARATOR)
for platform_name in ("cpu", "CUDA", "ROCM"):
    try:
        registered = xla_client.custom_call_targets(platform_name)
        print(f"{platform_name}: {len(registered)} targets")
        if registered:
            # Show at most three target names as examples
            sample = list(registered.keys())[:3]
            print(f" Examples: {sample}")
    except Exception as err:
        print(f"{platform_name}: Error - {err}")
```