or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

array-operations.md compilation-execution.md custom-operations.md device-management.md hardware-operations.md index.md plugin-system.md sharding.md xla-client.md

docs/plugin-system.md

0

# Plugin System

1

2

Dynamic plugin loading and version management for hardware-specific extensions and third-party integrations.

3

4

## Capabilities

5

6

### Plugin Management

7

8

Functions for loading and managing PJRT plugins dynamically.

9

10

```python { .api }

11

def pjrt_plugin_loaded(plugin_name: str) -> bool:

12

"""

13

Check if a PJRT plugin is loaded.

14

15

Parameters:

16

- plugin_name: Name of the plugin to check

17

18

Returns:

19

True if plugin is loaded, False otherwise

20

"""

21

22

def load_pjrt_plugin_dynamically(plugin_name: str, library_path: str) -> Any:

23

"""

24

Load a PJRT plugin from a shared library.

25

26

Parameters:

27

- plugin_name: Name to assign to the plugin

28

- library_path: Path to the plugin shared library

29

30

Returns:

31

Plugin handle or status

32

"""

33

34

def load_pjrt_plugin_with_c_api(plugin_name: str, c_api: Any) -> None:

35

"""

36

Load a PJRT plugin using an existing C API.

37

38

Parameters:

39

- plugin_name: Name to assign to the plugin

40

- c_api: Existing C API interface

41

"""

42

43

def pjrt_plugin_initialized(plugin_name: str) -> bool:

44

"""

45

Check if a plugin is initialized.

46

47

Parameters:

48

- plugin_name: Name of the plugin

49

50

Returns:

51

True if initialized, False otherwise

52

"""

53

54

def initialize_pjrt_plugin(plugin_name: str) -> None:

55

"""

56

Initialize a loaded PJRT plugin.

57

58

The plugin must be loaded before calling this.

59

60

Parameters:

61

- plugin_name: Name of the plugin to initialize

62

"""

63

```

64

65

### Plugin Import System

66

67

High-level interface for importing functionality from known plugins with version checking.

68

69

```python { .api }

70

# From jaxlib.plugin_support module

71

72

def import_from_plugin(

73

plugin_name: str,

74

submodule_name: str,

75

*,

76

check_version: bool = True

77

) -> ModuleType | None:

78

"""

79

Import a submodule from a known plugin with version checking.

80

81

Parameters:

82

- plugin_name: Plugin name ('cuda' or 'rocm')

83

- submodule_name: Submodule name (e.g., '_triton', '_linalg')

84

- check_version: Whether to check version compatibility

85

86

Returns:

87

Imported module or None if not available/incompatible

88

"""

89

90

def check_plugin_version(

91

plugin_name: str,

92

jaxlib_version: str,

93

plugin_version: str

94

) -> bool:

95

"""

96

Check if plugin version is compatible with jaxlib version.

97

98

Parameters:

99

- plugin_name: Name of the plugin

100

- jaxlib_version: Version of jaxlib

101

- plugin_version: Version of the plugin

102

103

Returns:

104

True if versions are compatible, False otherwise

105

"""

106

107

def maybe_import_plugin_submodule(

108

plugin_module_names: Sequence[str],

109

submodule_name: str,

110

*,

111

check_version: bool = True,

112

) -> ModuleType | None:

113

"""

114

Try to import plugin submodule from multiple candidates.

115

116

Parameters:

117

- plugin_module_names: List of plugin module names to try

118

- submodule_name: Submodule to import

119

- check_version: Whether to check version compatibility

120

121

Returns:

122

First successfully imported module or None

123

"""

124

```

125

126

## Usage Examples

127

128

### Loading Plugins

129

130

```python

131

from jaxlib import xla_client

132

133

# Check if a plugin is already loaded

134

cuda_loaded = xla_client.pjrt_plugin_loaded("cuda")

135

print(f"CUDA plugin loaded: {cuda_loaded}")

136

137

# Load a plugin dynamically (example path)

138

if not cuda_loaded:

139

try:

140

plugin_path = "/path/to/cuda_plugin.so" # Hypothetical path

141

result = xla_client.load_pjrt_plugin_dynamically("cuda", plugin_path)

142

print(f"Plugin load result: {result}")

143

144

# Initialize the plugin

145

if xla_client.pjrt_plugin_loaded("cuda"):

146

xla_client.initialize_pjrt_plugin("cuda")

147

print("CUDA plugin initialized")

148

except Exception as e:

149

print(f"Failed to load CUDA plugin: {e}")

150

151

# Check initialization status

152

initialized = xla_client.pjrt_plugin_initialized("cuda")

153

print(f"CUDA plugin initialized: {initialized}")

154

```

155

156

### Using Plugin Import System

157

158

```python

159

from jaxlib import plugin_support

160

161

# Try to import CUDA-specific functionality

162

cuda_linalg = plugin_support.import_from_plugin("cuda", "_linalg")

163

if cuda_linalg:

164

print("CUDA linear algebra module available")

165

# Use cuda_linalg.registrations(), etc.

166

else:

167

print("CUDA linear algebra not available")

168

169

# Try to import ROCm functionality

170

rocm_linalg = plugin_support.import_from_plugin("rocm", "_linalg")

171

if rocm_linalg:

172

print("ROCm linear algebra module available")

173

else:

174

print("ROCm linear algebra not available")

175

176

# Import with version checking disabled

177

triton_module = plugin_support.import_from_plugin(

178

"cuda", "_triton", check_version=False

179

)

180

if triton_module:

181

print("Triton module imported (version check skipped)")

182

```

183

184

### Version Compatibility

185

186

```python

187

from jaxlib import plugin_support

188

import jaxlib

189

190

jaxlib_version = jaxlib.__version__

191

192

# Check version compatibility manually

193

plugin_version = "0.7.1" # Example plugin version

194

compatible = plugin_support.check_plugin_version(

195

"cuda", jaxlib_version, plugin_version

196

)

197

print(f"CUDA plugin v{plugin_version} compatible with jaxlib v{jaxlib_version}: {compatible}")

198

199

# Try multiple plugin candidates

200

plugin_candidates = [".cuda", "jax_cuda13_plugin", "jax_cuda12_plugin"]

201

cuda_module = plugin_support.maybe_import_plugin_submodule(

202

plugin_candidates, "_linalg", check_version=True

203

)

204

205

if cuda_module:

206

print("Successfully imported CUDA module from one of the candidates")

207

else:

208

print("No compatible CUDA module found")

209

```

210

211

### Creating Plugin-Based Clients

212

213

```python

214

from jaxlib import xla_client

215

216

# Check available plugins and create appropriate clients

217

plugins_to_try = ["cuda", "rocm", "tpu"]

218

219

for plugin_name in plugins_to_try:

220

if xla_client.pjrt_plugin_loaded(plugin_name):

221

try:

222

# Generate default options for the plugin

223

if plugin_name == "cuda":

224

options = xla_client.generate_pjrt_gpu_plugin_options()

225

else:

226

options = {} # Use default options

227

228

# Create client using the plugin

229

client = xla_client.make_c_api_client(

230

plugin_name=plugin_name,

231

options=options

232

)

233

234

print(f"Created {plugin_name} client with {len(client.devices())} devices")

235

break

236

237

except Exception as e:

238

print(f"Failed to create {plugin_name} client: {e}")

239

continue

240

else:

241

# Fall back to CPU

242

print("No GPU/TPU plugins available, using CPU")

243

client = xla_client.make_cpu_client()

244

```

245

246

### Plugin Information

247

248

```python

249

from jaxlib import xla_client

250

251

# List common plugin names to check

252

common_plugins = ["cpu", "cuda", "rocm", "tpu"]

253

254

print("Plugin Status:")

255

print("-" * 40)

256

for plugin in common_plugins:

257

loaded = xla_client.pjrt_plugin_loaded(plugin)

258

if loaded:

259

initialized = xla_client.pjrt_plugin_initialized(plugin)

260

print(f"{plugin:8}: Loaded={loaded}, Initialized={initialized}")

261

else:

262

print(f"{plugin:8}: Not loaded")

263

264

# Check available custom call targets per platform

265

print("\nCustom Call Targets:")

266

print("-" * 40)

267

for platform in ["cpu", "CUDA", "ROCM"]:

268

try:

269

targets = xla_client.custom_call_targets(platform)

270

print(f"{platform}: {len(targets)} targets")

271

if targets:

272

# Show a few example targets

273

example_targets = list(targets.keys())[:3]

274

print(f" Examples: {example_targets}")

275

except Exception as e:

276

print(f"{platform}: Error - {e}")

277

```