or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

core-execution.md · fem.md · framework-integration.md · index.md · kernel-programming.md · optimization.md · rendering.md · types-arrays.md · utilities.md

docs/framework-integration.md

0

# Framework Interoperability

1

2

Warp provides seamless data exchange and integration with popular machine learning and scientific computing frameworks. This enables easy incorporation of Warp kernels into existing ML pipelines and scientific workflows.

3

4

## Capabilities

5

6

### PyTorch Integration

7

8

Convert between Warp arrays and PyTorch tensors with automatic device management and gradient support.

9

10

```python { .api }
def from_torch(tensor, dtype: type = None, requires_grad: bool = None) -> array:
    """
    Create Warp array from PyTorch tensor.

    Args:
        tensor: PyTorch tensor
        dtype: Target Warp type (inferred from the tensor's scalar type if None;
            pass e.g. wp.vec3 to view the trailing dimension as vectors)
        requires_grad: Enable gradient tracking (inherits tensor.requires_grad if None).
            When enabled, the array's gradient wraps tensor.grad, allocating it if needed.

    Returns:
        Warp array sharing memory with tensor
    """

def to_torch(arr: array, requires_grad: bool = None):
    """
    Create PyTorch tensor from Warp array.

    Args:
        arr: Warp array
        requires_grad: Enable gradient tracking (inherits arr.requires_grad if None)

    Returns:
        PyTorch tensor sharing memory with array
    """

def dtype_from_torch(torch_dtype) -> type:
    """Convert PyTorch dtype to Warp type."""

def dtype_to_torch(wp_dtype: type):
    """Convert Warp type to PyTorch dtype."""

def device_from_torch(torch_device) -> Device:
    """Convert PyTorch device to Warp device."""

def device_to_torch(wp_device: Device):
    """Convert Warp device to PyTorch device."""

def stream_from_torch(torch_stream) -> Stream:
    """Create Warp stream from PyTorch CUDA stream."""

def stream_to_torch(wp_stream: Stream):
    """Convert Warp stream to PyTorch CUDA stream."""
```

54

55

### JAX Integration

56

57

Interoperability with JAX for functional programming and automatic differentiation.

58

59

```python { .api }
def from_jax(jax_array, dtype: type = None) -> array:
    """
    Create Warp array from JAX array.

    Args:
        jax_array: JAX array (jax.Array)
        dtype: Target Warp type (inferred if None)

    Returns:
        Warp array sharing memory with the JAX array (zero-copy via DLPack).
        JAX treats its arrays as immutable, so do not write to this array.
    """

def to_jax(arr: array):
    """
    Create JAX array from Warp array.

    Args:
        arr: Warp array

    Returns:
        JAX array sharing memory with the Warp array (zero-copy via DLPack)
    """

def dtype_from_jax(jax_dtype) -> type:
    """Convert JAX dtype to Warp type."""

def dtype_to_jax(wp_dtype: type):
    """Convert Warp type to JAX dtype."""

def device_from_jax(jax_device) -> Device:
    """Convert JAX device to Warp device."""

def device_to_jax(wp_device: Device):
    """Convert Warp device to JAX device."""
```

95

96

### JAX Experimental

97

98

Advanced JAX integration with XLA FFI support for high-performance custom operations.

99

100

```python { .api }
# Available in the warp.jax_experimental module.
# Names and signatures differ between Warp versions; newer releases also provide
# XLA FFI-based variants in warp.jax_experimental.ffi. Check the reference for your version.
def jax_kernel(kernel: Kernel, launch_dims=None):
    """Wrap a Warp kernel as a JAX primitive that can be called from (jitted) JAX code."""
```

108

109

### Paddle Integration

110

111

Integration with PaddlePaddle, a deep learning framework widely used in the Chinese ML ecosystem, covering tensors, data types, devices, and CUDA streams.

112

113

```python { .api }

114

def from_paddle(paddle_tensor, dtype: type = None) -> array:

115

"""

116

Create Warp array from Paddle tensor.

117

118

Args:

119

paddle_tensor: Paddle tensor

120

dtype: Target Warp type (inferred if None)

121

122

Returns:

123

Warp array sharing memory with tensor

124

"""

125

126

def to_paddle(arr: array):

127

"""

128

Create Paddle tensor from Warp array.

129

130

Args:

131

arr: Warp array

132

133

Returns:

134

Paddle tensor sharing memory with array

135

"""

136

137

def dtype_from_paddle(paddle_dtype) -> type:

138

"""Convert Paddle dtype to Warp type."""

139

140

def dtype_to_paddle(wp_dtype: type):

141

"""Convert Warp type to Paddle dtype."""

142

143

def device_from_paddle(paddle_device) -> Device:

144

"""Convert Paddle device to Warp device."""

145

146

def device_to_paddle(wp_device: Device):

147

"""Convert Warp device to Paddle device."""

148

149

def stream_from_paddle(paddle_stream) -> Stream:

150

"""Create Warp stream from Paddle CUDA stream."""

151

```

152

153

### DLPack Integration

154

155

Universal tensor exchange format for interoperability across frameworks.

156

157

```python { .api }

158

def from_dlpack(dlpack_tensor) -> array:

159

"""

160

Create Warp array from DLPack tensor.

161

162

Args:

163

dlpack_tensor: DLPack capsule, or any array object implementing the `__dlpack__` protocol

164

165

Returns:

166

Warp array sharing memory with DLPack tensor

167

"""

168

169

def to_dlpack(arr: array):

170

"""

171

Create DLPack tensor from Warp array.

172

173

Args:

174

arr: Warp array

175

176

Returns:

177

DLPack tensor capsule sharing memory

178

"""

179

```

180

181

### NumPy Integration

182

183

Direct conversion between Warp arrays and NumPy arrays.

184

185

```python { .api }

186

def from_numpy(np_array: np.ndarray,

187

dtype: type = None,

188

device: Device = None) -> array:

189

"""

190

Create Warp array from NumPy array.

191

192

Args:

193

np_array: NumPy array

194

dtype: Target Warp type (inferred if None)

195

device: Target device (Warp's current default device if None, which is a CUDA device when one is available)

196

197

Returns:

198

Warp array with data copied from NumPy array

199

"""

200

201

# Note: array.numpy() method provides reverse conversion

202

```

203

204

## Usage Examples

205

206

### PyTorch-Warp Pipeline

207

```python
import torch
import warp as wp

# Create PyTorch tensors
x_torch = torch.randn(1000, 3, device='cuda', requires_grad=True)
y_torch = torch.zeros(1000, 3, device='cuda')

# Convert to Warp arrays (shares memory). dtype=wp.vec3 views each row as a vec3;
# without it a (1000, 3) tensor becomes a 2D float32 array that the kernel rejects.
# With gradient tracking, each array's .grad wraps the tensor's .grad.
x_warp = wp.from_torch(x_torch, dtype=wp.vec3)
y_warp = wp.from_torch(y_torch, dtype=wp.vec3, requires_grad=True)

# Define Warp kernel
@wp.kernel
def process_data(x: wp.array(dtype=wp.vec3),
                 y: wp.array(dtype=wp.vec3)):
    i = wp.tid()
    # Some computation
    y[i] = x[i] * 2.0 + wp.vec3(1.0, 0.0, -1.0)

# Launch kernel, recording it on a tape: PyTorch autograd does not see Warp launches
tape = wp.Tape()
with tape:
    wp.launch(process_data, dim=1000, inputs=[x_warp], outputs=[y_warp])

# Convert result back to PyTorch (shares memory) as a tensor that tracks gradients
result_torch = wp.to_torch(y_warp, requires_grad=True)

# Use in PyTorch pipeline
loss = torch.mean(result_torch)
loss.backward()  # Fills result_torch.grad with dloss/dy; stops at the Warp boundary

# Continue backpropagation through the Warp kernel; x_torch.grad receives dloss/dx
tape.backward(grads={y_warp: wp.from_torch(result_torch.grad, dtype=wp.vec3)})
```

237

238

### JAX Integration Example

239

```python
import jax
import jax.numpy as jnp
import warp as wp

# JAX array (2D, float32 by default)
x_jax = jnp.array([[1.0, 2.0], [3.0, 4.0]])

# Convert to Warp (zero-copy; treat x_warp as read-only because JAX arrays are immutable)
x_warp = wp.from_jax(x_jax)

# Process with Warp kernel; 2D indexing requires 2D array parameters
@wp.kernel
def double_values(x: wp.array2d(dtype=float),
                  y: wp.array2d(dtype=float)):
    i, j = wp.tid()
    y[i, j] = x[i, j] * 2.0

y_warp = wp.zeros_like(x_warp)
wp.launch(double_values, dim=x_warp.shape, inputs=[x_warp, y_warp])

# Convert back to JAX
y_jax = wp.to_jax(y_warp)

# Continue JAX computation
result = jnp.sum(y_jax)
```

266

267

### Multi-Framework Workflow

268

```python
import numpy as np
import torch
import warp as wp

# Start with NumPy data
np_data = np.random.rand(1000, 3).astype(np.float32)

# Convert to Warp (copies the data to the GPU); each row becomes one vec3
warp_array = wp.from_numpy(np_data, dtype=wp.vec3, device='cuda')

# Process with Warp kernel
@wp.kernel
def normalize_vectors(vectors: wp.array(dtype=wp.vec3)):
    i = wp.tid()
    v = vectors[i]
    length = wp.length(v)
    if length > 0.0:
        vectors[i] = v / length

wp.launch(normalize_vectors, dim=1000, inputs=[warp_array])

# Convert to PyTorch for ML pipeline (shares memory; shape (1000, 3))
torch_tensor = wp.to_torch(warp_array)

# Use in neural network
model = torch.nn.Linear(3, 1).cuda()
output = model(torch_tensor)

# Convert back for final processing; detach first so Warp wraps a plain tensor
# rather than an intermediate result that is part of the autograd graph
final_warp = wp.from_torch(output.detach())
final_np = final_warp.numpy()
```

301

302

### Stream Synchronization

303

```python
import torch
import warp as wp

# Create PyTorch CUDA stream
torch_stream = torch.cuda.Stream()

# Convert to Warp stream (both wrap the same CUDA stream)
warp_stream = wp.stream_from_torch(torch_stream)

# my_kernel, x, y, tensor_a and tensor_b are placeholders for your own kernel and data
with torch.cuda.stream(torch_stream):
    # Launch Warp kernel on stream
    wp.launch(my_kernel, dim=1000, inputs=[x, y], stream=warp_stream)

    # PyTorch operations on the same stream run after the Warp kernel (stream order)
    result = torch.matmul(tensor_a, tensor_b)

# Stream ordering serializes the GPU work; the host must still wait before reading results
torch.cuda.synchronize()
```

323

324

### Gradient Flow Example

325

```python
import torch
import warp as wp

# Enable gradient tracking
torch.autograd.set_grad_enabled(True)

# PyTorch tensor with gradients
x = torch.randn(100, requires_grad=True, device='cuda')

# Custom Warp function with gradient support
@wp.func
def custom_activation(x: float) -> float:
    return wp.sin(x) * wp.exp(-x * x)

@wp.kernel
def apply_activation(input: wp.array(dtype=float),
                     output: wp.array(dtype=float)):
    i = wp.tid()
    output[i] = custom_activation(input[i])

# Convert to Warp with gradient tracking: x_warp.grad wraps x.grad (allocated if needed)
x_warp = wp.from_torch(x, requires_grad=True)
y_warp = wp.zeros_like(x_warp)

# Launch kernel under a tape; PyTorch autograd does not record Warp launches,
# so the tape is what enables backpropagation through the kernel
tape = wp.Tape()
with tape:
    wp.launch(apply_activation, dim=100, inputs=[x_warp], outputs=[y_warp])

# Convert back (shares memory) as a tensor that tracks gradients
y = wp.to_torch(y_warp, requires_grad=True)

# Compute loss and backpropagate in PyTorch (fills y.grad; does not reach x)
loss = torch.sum(y)
loss.backward()

# Propagate dloss/dy through the Warp kernel into x_warp.grad, i.e. x.grad
tape.backward(grads={y_warp: wp.from_torch(y.grad)})

# Gradients available in original tensor
print(x.grad)  # Contains gradients from Warp computation
```

363

364

## Device Management Across Frameworks

365

366

### Cross-Framework Device Consistency

367

```python
import torch
import warp as wp

# Ensure consistent device usage
if torch.cuda.is_available():
    torch_device = torch.device('cuda:0')
    warp_device = wp.device_from_torch(torch_device)
    # torch.cuda.set_device only accepts CUDA devices, so call it in this branch only
    torch.cuda.set_device(torch_device)
else:
    torch_device = torch.device('cpu')
    warp_device = wp.get_device('cpu')

# Make the chosen device Warp's default
wp.set_device(warp_device)

# Create tensors/arrays on consistent devices
x_torch = torch.randn(1000, device=torch_device)
x_warp = wp.from_torch(x_torch)

assert x_warp.device == warp_device
```

389

390

## Types

391

392

```python { .api }

393

# Framework tensor types (external)

394

TorchTensor = torch.Tensor # PyTorch tensor

395

JaxArray = jax.Array # JAX array

396

PaddleTensor = paddle.Tensor # Paddle tensor

397

DLPackTensor = object # DLPack capsule

398

399

# Device conversion types

400

TorchDevice = torch.device

401

JaxDevice = jax.Device

402

PaddleDevice = paddle.device.CUDAPlace

403

404

# Stream types

405

TorchStream = torch.cuda.Stream

406

```