0
# Framework Interoperability
1
2
Warp provides seamless data exchange and integration with popular machine learning and scientific computing frameworks. This enables easy incorporation of Warp kernels into existing ML pipelines and scientific workflows.
3
4
## Capabilities
5
6
### PyTorch Integration
7
8
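Convert between Warp arrays and PyTorch tensors without copying data. Device placement is preserved, and gradient buffers can be shared, so gradients computed with a Warp `wp.Tape` appear in the tensor's `.grad`.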
9
10
```python { .api }
11
def from_torch(tensor, dtype: type = None, requires_grad: bool = None) -> array:
    """
    Create a Warp array that views a PyTorch tensor's memory (no copy).

    Args:
        tensor: PyTorch tensor
        dtype: Target Warp type; inferred from ``tensor.dtype`` if None.
            A vector or matrix type (e.g. ``wp.vec3``) consumes the tensor's
            trailing dimension(s), so an ``(N, 3)`` float32 tensor converted
            with ``dtype=wp.vec3`` becomes a 1D array of N vec3 values.
        requires_grad: Whether the array wraps ``tensor.grad`` as its gradient
            buffer (``tensor.grad`` is allocated if it does not exist yet).
            Inherits ``tensor.requires_grad`` if None.

    Returns:
        Warp array sharing memory with tensor
    """
23
24
def to_torch(arr: array, requires_grad: bool = None):
    """
    Create a PyTorch tensor that views a Warp array's memory (no copy).

    Args:
        arr: Warp array
        requires_grad: Value of ``requires_grad`` on the resulting tensor.
            Inherits ``arr.requires_grad`` if None.

    Returns:
        PyTorch tensor sharing memory with array. Vector/matrix dtypes are
        expanded into trailing dimensions (e.g. N vec3 values -> shape (N, 3)).
    """
35
36
def dtype_from_torch(torch_dtype) -> type:
    """Map a PyTorch dtype (e.g. ``torch.float32``) to the matching Warp type."""

def dtype_to_torch(wp_dtype: type):
    """Map a Warp type (e.g. ``wp.float32``) to the matching PyTorch dtype."""

def device_from_torch(torch_device) -> Device:
    """Return the Warp device corresponding to a PyTorch device."""

def device_to_torch(wp_device: Device):
    """Return the PyTorch device corresponding to a Warp device."""

def stream_from_torch(torch_stream) -> Stream:
    """Return a Warp stream for an existing PyTorch CUDA stream."""

def stream_to_torch(wp_stream: Stream):
    """Return a PyTorch CUDA stream for an existing Warp stream."""
53
```
54
55
### JAX Integration
56
57
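Interoperability with JAX: arrays are exchanged without copying (via DLPack), so Warp kernels can be used alongside JAX's functional programming and automatic differentiation workflows.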
58
59
```python { .api }
60
def from_jax(jax_array, dtype: type = None) -> array:
    """
    Create a Warp array from a JAX array without copying (via DLPack).

    Args:
        jax_array: JAX array (``jax.Array``)
        dtype: Target Warp type (inferred if None)

    Returns:
        Warp array sharing memory with the JAX array. JAX treats its arrays
        as immutable, so avoid writing to the result in place.
    """
71
72
def to_jax(arr: array):
    """
    Create a JAX array from a Warp array without copying (via DLPack).

    Args:
        arr: Warp array

    Returns:
        JAX array (``jax.Array``) sharing memory with the Warp array
    """
82
83
def dtype_from_jax(jax_dtype) -> type:
    """Map a JAX dtype (e.g. ``jnp.float32``) to the matching Warp type."""

def dtype_to_jax(wp_dtype: type):
    """Map a Warp type (e.g. ``wp.float32``) to the matching JAX dtype."""

def device_from_jax(jax_device) -> Device:
    """Return the Warp device corresponding to a JAX device."""

def device_to_jax(wp_device: Device):
    """Return the JAX device corresponding to a Warp device."""
94
```
95
96
### JAX Experimental
97
98
Advanced JAX integration with XLA FFI support for high-performance custom operations.
99
100
```python { .api }
101
# Provided by the warp.jax_experimental module.
# NOTE(review): these two names could not be confirmed against the current
# warp.jax_experimental API (whose documented entry point appears to be
# jax_kernel()); verify before relying on this listing.

def register_custom_call(name: str, kernel: Kernel) -> None:
    """Expose a Warp kernel to JAX as a custom call registered under ``name``."""

def xla_ffi_kernel(kernel: Kernel):
    """Decorator that makes a Warp kernel callable through the XLA FFI."""
107
```
108
109
### Paddle Integration
110
111
Integration with PaddlePaddle for deep learning workflows, particularly in the Chinese machine learning ecosystem.
112
113
```python { .api }
114
def from_paddle(paddle_tensor, dtype: type = None, requires_grad: bool = None) -> array:
    """
    Create a Warp array that views a Paddle tensor's memory (no copy).

    Args:
        paddle_tensor: Paddle tensor
        dtype: Target Warp type (inferred if None). As with ``from_torch``, a
            vector or matrix type consumes the tensor's trailing dimension(s).
        requires_grad: Whether the array wraps the tensor's gradient buffer
            (allocated if missing). Inherits the tensor's gradient setting
            (``not paddle_tensor.stop_gradient``) if None.

    Returns:
        Warp array sharing memory with tensor
    """
125
126
def to_paddle(arr: array, requires_grad: bool = None):
    """
    Create a Paddle tensor that views a Warp array's memory (no copy).

    Args:
        arr: Warp array
        requires_grad: Whether the resulting tensor tracks gradients
            (``stop_gradient = not requires_grad``). Inherits
            ``arr.requires_grad`` if None.

    Returns:
        Paddle tensor sharing memory with array
    """
136
137
def dtype_from_paddle(paddle_dtype) -> type:
    """Map a Paddle dtype (e.g. ``paddle.float32``) to the matching Warp type."""

def dtype_to_paddle(wp_dtype: type):
    """Map a Warp type (e.g. ``wp.float32``) to the matching Paddle dtype."""

def device_from_paddle(paddle_device) -> Device:
    """Return the Warp device corresponding to a Paddle device."""

def device_to_paddle(wp_device: Device):
    """Return the Paddle device corresponding to a Warp device."""

def stream_from_paddle(paddle_stream) -> Stream:
    """Return a Warp stream for an existing Paddle CUDA stream."""
151
```
152
153
### DLPack Integration
154
155
DLPack is a framework-neutral tensor exchange format. Warp arrays can be shared without copying with any framework that supports DLPack.
156
157
```python { .api }
158
def from_dlpack(dlpack_tensor) -> array:
    """
    Wrap a DLPack tensor as a Warp array.

    Args:
        dlpack_tensor: DLPack tensor capsule to import

    Returns:
        Warp array aliasing the memory described by the capsule (no copy)
    """
168
169
def to_dlpack(arr: array):
    """
    Export a Warp array as a DLPack tensor.

    Args:
        arr: Warp array to export

    Returns:
        DLPack tensor capsule aliasing the array's memory (no copy)
    """
179
```
180
181
### NumPy Integration
182
183
Direct conversion between Warp arrays and NumPy arrays.
184
185
```python { .api }
186
def from_numpy(np_array: np.ndarray,
               dtype: type = None,
               device: Device = None) -> array:
    """
    Create a Warp array from a NumPy array.

    Args:
        np_array: NumPy array
        dtype: Target Warp type (inferred if None). Pass a vector type such as
            ``wp.vec3`` explicitly to make the interpretation of an ``(N, 3)``
            input unambiguous.
        device: Target device; the current default device (``wp.get_device()``)
            if None

    Returns:
        Warp array with data copied from NumPy array
    """

# Note: the array.numpy() method provides the reverse conversion
202
```
203
204
## Usage Examples
205
206
### PyTorch-Warp Pipeline
207
```python
208
import torch
import warp as wp

# Create PyTorch tensors: 1000 points with 3 components each
x_torch = torch.randn(1000, 3, device='cuda', requires_grad=True)
y_torch = torch.zeros(1000, 3, device='cuda')

# Convert to Warp arrays (zero-copy). dtype=wp.vec3 views each row of three
# floats as one vec3, giving 1D arrays of length 1000 that match the kernel
# signature below. x_warp inherits requires_grad from x_torch and wraps
# x_torch.grad; y_warp enables gradient tracking explicitly.
x_warp = wp.from_torch(x_torch, dtype=wp.vec3)
y_warp = wp.from_torch(y_torch, dtype=wp.vec3, requires_grad=True)

# Define Warp kernel
@wp.kernel
def process_data(x: wp.array(dtype=wp.vec3),
                 y: wp.array(dtype=wp.vec3)):
    i = wp.tid()
    # Some computation
    y[i] = x[i] * 2.0 + wp.vec3(1.0, 0.0, -1.0)

# Launch kernel, recording it on a tape so it can be differentiated later
tape = wp.Tape()
with tape:
    wp.launch(process_data, dim=x_warp.shape[0],
              inputs=[x_warp], outputs=[y_warp], device=x_warp.device)

# Convert result back to PyTorch (shares memory, shape (1000, 3))
result_torch = wp.to_torch(y_warp)

# Use in PyTorch pipeline
loss = torch.mean(result_torch)

# PyTorch autograd does not record Warp kernel launches, so loss.backward()
# alone would not reach x_torch. Seed d(loss)/dy (1/N for a mean over N
# elements) and let the Warp tape propagate it back through the kernel.
seed = torch.full_like(result_torch, 1.0 / result_torch.numel())
tape.backward(grads={y_warp: wp.from_torch(seed, dtype=wp.vec3)})

print(x_torch.grad)  # every entry is 2/N = 2/3000
236
```
237
238
### JAX Integration Example
239
```python
240
import jax
import jax.numpy as jnp
import warp as wp

# JAX array (2x2, float32)
x_jax = jnp.array([[1.0, 2.0], [3.0, 4.0]])

# Convert to Warp (zero-copy via DLPack); the result is a 2D float array
x_warp = wp.from_jax(x_jax)

# Process with Warp kernel. The arrays are 2D, so they must be declared as
# such (wp.array2d) to allow x[i, j] indexing with a 2D thread index.
@wp.kernel
def double_values(x: wp.array2d(dtype=float),
                  y: wp.array2d(dtype=float)):
    i, j = wp.tid()
    y[i, j] = x[i, j] * 2.0

y_warp = wp.zeros_like(x_warp)
wp.launch(double_values, dim=x_warp.shape,
          inputs=[x_warp], outputs=[y_warp], device=x_warp.device)

# Make sure the kernel has finished before JAX reads the shared buffer
# (Warp and JAX may issue work on different CUDA streams)
wp.synchronize_device(x_warp.device)

# Convert back to JAX (zero-copy)
y_jax = wp.to_jax(y_warp)

# Continue JAX computation
result = jnp.sum(y_jax)
265
```
266
267
### Multi-Framework Workflow
268
```python
269
import numpy as np
import torch
import warp as wp

# Start with NumPy data: 1000 3D vectors
np_data = np.random.rand(1000, 3).astype(np.float32)

# Convert to Warp (copies to the GPU); dtype=wp.vec3 views each row as a vec3
warp_array = wp.from_numpy(np_data, dtype=wp.vec3, device='cuda')

# Process with Warp kernel (normalizes each vector in place)
@wp.kernel
def normalize_vectors(vectors: wp.array(dtype=wp.vec3)):
    i = wp.tid()
    v = vectors[i]
    length = wp.length(v)
    if length > 0.0:
        vectors[i] = v / length

wp.launch(normalize_vectors, dim=warp_array.shape[0],
          inputs=[warp_array], device=warp_array.device)

# Convert to PyTorch for ML pipeline (zero-copy, shape (1000, 3))
torch_tensor = wp.to_torch(warp_array)

# Use in neural network
model = torch.nn.Linear(3, 1).cuda()
output = model(torch_tensor)

# Convert back for final processing. `output` is a non-leaf tensor in the
# autograd graph; detach it so from_torch does not try to wrap its gradient.
final_warp = wp.from_torch(output.detach())
final_np = final_warp.numpy()
300
```
301
302
### Stream Synchronization
303
```python
304
import torch
import warp as wp

# Assumes my_kernel, x, y, tensor_a and tensor_b are already defined on the GPU.

# Create PyTorch CUDA stream
torch_stream = torch.cuda.Stream()

# Convert to Warp stream (refers to the same underlying CUDA stream)
warp_stream = wp.stream_from_torch(torch_stream)

with torch.cuda.stream(torch_stream):
    # Launch Warp kernel on the shared stream
    wp.launch(my_kernel, dim=1000, inputs=[x, y], stream=warp_stream)

    # PyTorch operations on the same stream run after the Warp kernel
    result = torch.matmul(tensor_a, tensor_b)

# Work queued on a single stream is ordered automatically; an explicit
# synchronize is only needed to make the host wait for the GPU work to finish.
torch.cuda.synchronize()
322
```
323
324
### Gradient Flow Example
325
```python
326
import torch
import warp as wp

# PyTorch tensor with gradients
x = torch.randn(100, requires_grad=True, device='cuda')

# Custom Warp function; Warp generates its adjoint (derivative) automatically
@wp.func
def custom_activation(x: float) -> float:
    return wp.sin(x) * wp.exp(-x * x)

@wp.kernel
def apply_activation(x_in: wp.array(dtype=float),
                     y_out: wp.array(dtype=float)):
    i = wp.tid()
    y_out[i] = custom_activation(x_in[i])

# Convert to Warp with gradient tracking: x_warp.grad aliases x.grad
# (from_torch allocates x.grad if it does not exist yet)
x_warp = wp.from_torch(x, requires_grad=True)
y_warp = wp.zeros_like(x_warp)  # inherits requires_grad from x_warp

# Launch kernel on a tape so the launch can be replayed in reverse
tape = wp.Tape()
with tape:
    wp.launch(apply_activation, dim=x_warp.shape[0],
              inputs=[x_warp], outputs=[y_warp], device=x_warp.device)

# Convert back to PyTorch (shares memory with y_warp)
y = wp.to_torch(y_warp)

# Compute loss in PyTorch
loss = torch.sum(y)

# PyTorch autograd does not record Warp kernel launches, so loss.backward()
# would not reach x. Seed d(loss)/dy (all ones for a sum) and backpropagate
# through the kernel with the Warp tape instead.
tape.backward(grads={y_warp: wp.ones_like(y_warp)})

# Gradients available in original tensor
print(x.grad)  # Contains gradients from Warp computation
362
```
363
364
## Device Management Across Frameworks
365
366
### Cross-Framework Device Consistency
367
```python
368
import torch
import warp as wp

# Ensure consistent device usage
if torch.cuda.is_available():
    torch_device = torch.device('cuda:0')
    warp_device = wp.device_from_torch(torch_device)
    # torch.cuda.set_device only accepts CUDA devices, so it belongs here
    torch.cuda.set_device(torch_device)
else:
    torch_device = torch.device('cpu')
    warp_device = wp.get_device('cpu')

# Make Warp allocate and launch on the matching device by default
wp.set_device(warp_device)

# Create tensors/arrays on consistent devices
x_torch = torch.randn(1000, device=torch_device)
x_warp = wp.from_torch(x_torch)

assert x_warp.device == warp_device
388
```
389
390
## Types
391
392
```python { .api }
393
# Framework tensor types (external)
394
# External framework types referenced on this page (defined by the frameworks, not by Warp)

# Tensor / array types
TorchTensor = torch.Tensor    # PyTorch tensor
JaxArray = jax.Array          # JAX array
PaddleTensor = paddle.Tensor  # Paddle tensor
DLPackTensor = object         # DLPack capsule (opaque object)

# Device types accepted/returned by the device_from_* / device_to_* helpers
TorchDevice = torch.device
JaxDevice = jax.Device
PaddleDevice = paddle.device.CUDAPlace

# Stream types accepted/returned by the stream helpers
TorchStream = torch.cuda.Stream
406
```