A Python framework for high-performance simulation and graphics programming that JIT compiles Python functions to efficient GPU/CPU kernel code.
Warp provides seamless data exchange and integration with popular machine learning and scientific computing frameworks. This enables easy incorporation of Warp kernels into existing ML pipelines and scientific workflows.
Convert between Warp arrays and PyTorch tensors with automatic device management and gradient support.
def from_torch(tensor, dtype: type | None = None, requires_grad: bool | None = None) -> array:
    """Create a Warp array from a PyTorch tensor.

    Args:
        tensor: PyTorch tensor.
        dtype: Target Warp type (inferred from the tensor if None).
        requires_grad: Enable gradient tracking (inherited from the tensor if None).

    Returns:
        Warp array sharing memory with the tensor (zero-copy).
    """
def to_torch(arr: array, requires_grad: bool = False):
    """Create a PyTorch tensor from a Warp array.

    Args:
        arr: Warp array.
        requires_grad: Enable gradient tracking on the resulting tensor.

    Returns:
        PyTorch tensor sharing memory with the array (zero-copy).
    """
def dtype_from_torch(torch_dtype) -> type:
    """Convert a PyTorch dtype to the corresponding Warp type."""

def dtype_to_torch(wp_dtype: type):
    """Convert a Warp type to the corresponding PyTorch dtype."""

def device_from_torch(torch_device) -> Device:
    """Convert a PyTorch device to the corresponding Warp device."""

def device_to_torch(wp_device: Device):
    """Convert a Warp device to the corresponding PyTorch device."""

def stream_from_torch(torch_stream) -> Stream:
    """Create a Warp stream from a PyTorch CUDA stream."""

def stream_to_torch(wp_stream: Stream):
    """Convert a Warp stream to a PyTorch CUDA stream."""

# Interoperability with JAX for functional programming and automatic differentiation.
def from_jax(jax_array, dtype: type | None = None) -> array:
    """Create a Warp array from a JAX array.

    Args:
        jax_array: JAX DeviceArray.
        dtype: Target Warp type (inferred from the array if None).

    Returns:
        Warp array with data copied from the JAX array (unlike the
        PyTorch converters, this is a copy, not a shared view).
    """
def to_jax(arr: array):
    """Create a JAX array from a Warp array.

    Args:
        arr: Warp array.

    Returns:
        JAX DeviceArray with data copied from the Warp array.
    """
def dtype_from_jax(jax_dtype) -> type:
    """Convert a JAX dtype to the corresponding Warp type."""

def dtype_to_jax(wp_dtype: type):
    """Convert a Warp type to the corresponding JAX dtype."""

def device_from_jax(jax_device) -> Device:
    """Convert a JAX device to the corresponding Warp device."""

def device_to_jax(wp_device: Device):
    """Convert a Warp device to the corresponding JAX device."""

# Advanced JAX integration with XLA FFI support for high-performance custom operations.
# Available in warp.jax_experimental module
def register_custom_call(name: str, kernel: Kernel) -> None:
    """Register a Warp kernel as a JAX custom call."""

def xla_ffi_kernel(kernel: Kernel):
    """Decorator to create an XLA FFI-compatible kernel."""

# Integration with PaddlePaddle for deep learning workflows in the Chinese ecosystem.
def from_paddle(paddle_tensor, dtype: type | None = None) -> array:
    """Create a Warp array from a Paddle tensor.

    Args:
        paddle_tensor: Paddle tensor.
        dtype: Target Warp type (inferred from the tensor if None).

    Returns:
        Warp array sharing memory with the tensor (zero-copy).
    """
def to_paddle(arr: array):
    """Create a Paddle tensor from a Warp array.

    Args:
        arr: Warp array.

    Returns:
        Paddle tensor sharing memory with the array (zero-copy).
    """
def dtype_from_paddle(paddle_dtype) -> type:
    """Convert a Paddle dtype to the corresponding Warp type."""

def dtype_to_paddle(wp_dtype: type):
    """Convert a Warp type to the corresponding Paddle dtype."""

def device_from_paddle(paddle_device) -> Device:
    """Convert a Paddle device to the corresponding Warp device."""

def device_to_paddle(wp_device: Device):
    """Convert a Warp device to the corresponding Paddle device."""

def stream_from_paddle(paddle_stream) -> Stream:
    """Create a Warp stream from a Paddle CUDA stream."""

# Universal tensor exchange format (DLPack) for interoperability across frameworks.
def from_dlpack(dlpack_tensor) -> array:
    """Create a Warp array from a DLPack tensor.

    Args:
        dlpack_tensor: DLPack tensor capsule.

    Returns:
        Warp array sharing memory with the DLPack tensor (zero-copy).
    """
def to_dlpack(arr: array):
    """Create a DLPack tensor from a Warp array.

    Args:
        arr: Warp array.

    Returns:
        DLPack tensor capsule sharing memory with the array (zero-copy).
    """

# Direct conversion between Warp arrays and NumPy arrays.
def from_numpy(np_array: np.ndarray,
               dtype: type | None = None,
               device: Device | None = None) -> array:
    """Create a Warp array from a NumPy array.

    Args:
        np_array: NumPy array.
        dtype: Target Warp type (inferred from the array if None).
        device: Target device (CPU if None).

    Returns:
        Warp array with data copied from the NumPy array.
    """
# Note: the array.numpy() method provides the reverse conversion.

import torch
import warp as wp

# Create PyTorch tensors
x_torch = torch.randn(1000, 3, device='cuda', requires_grad=True)
y_torch = torch.zeros(1000, 3, device='cuda')

# Convert to Warp arrays (shares memory, preserves gradients)
x_warp = wp.from_torch(x_torch)
y_warp = wp.from_torch(y_torch, requires_grad=True)

# Define Warp kernel
@wp.kernel
def process_data(x: wp.array(dtype=wp.vec3),
                 y: wp.array(dtype=wp.vec3)):
    i = wp.tid()
    # Some computation
    y[i] = x[i] * 2.0 + wp.vec3(1.0, 0.0, -1.0)

# Launch kernel (one thread per element)
wp.launch(process_data, dim=1000, inputs=[x_warp, y_warp])

# Convert result back to PyTorch (shares memory)
result_torch = wp.to_torch(y_warp)

# Use in PyTorch pipeline
loss = torch.mean(result_torch)
loss.backward()  # Gradients flow back through Warp computation

# Next example: JAX interoperability.
import jax
import jax.numpy as jnp
import warp as wp

# JAX array
x_jax = jnp.array([[1.0, 2.0], [3.0, 4.0]])

# Convert to Warp (copies data; see from_jax)
x_warp = wp.from_jax(x_jax)

# Process with a 2D Warp kernel
@wp.kernel
def double_values(x: wp.array(dtype=float),
                  y: wp.array(dtype=float)):
    i, j = wp.tid()  # 2D thread index matching the launch dim
    y[i, j] = x[i, j] * 2.0

y_warp = wp.zeros_like(x_warp)
wp.launch(double_values, dim=x_warp.shape, inputs=[x_warp, y_warp])

# Convert back to JAX
y_jax = wp.to_jax(y_warp)

# Continue JAX computation
result = jnp.sum(y_jax)

# Next example: NumPy -> Warp -> PyTorch pipeline.
import numpy as np
import torch
import warp as wp

# Start with NumPy data
np_data = np.random.rand(1000, 3).astype(np.float32)

# Convert to Warp (copies data onto the requested device)
warp_array = wp.from_numpy(np_data, device='cuda')

# Process with Warp kernel
@wp.kernel
def normalize_vectors(vectors: wp.array(dtype=wp.vec3)):
    i = wp.tid()
    v = vectors[i]
    length = wp.length(v)
    if length > 0.0:  # guard against division by zero for zero-length vectors
        vectors[i] = v / length

wp.launch(normalize_vectors, dim=1000, inputs=[warp_array])

# Convert to PyTorch for ML pipeline (shares memory)
torch_tensor = wp.to_torch(warp_array)

# Use in neural network
model = torch.nn.Linear(3, 1).cuda()
output = model(torch_tensor)

# Convert back for final processing
final_warp = wp.from_torch(output)
final_np = final_warp.numpy()

# Next example: CUDA stream interoperability.
import torch
import warp as wp

# NOTE(review): my_kernel, x, y, tensor_a, tensor_b are assumed to be
# defined elsewhere — this snippet only illustrates stream sharing.

# Create PyTorch CUDA stream
torch_stream = torch.cuda.Stream()

# Convert to Warp stream
warp_stream = wp.stream_from_torch(torch_stream)

# Launch Warp kernel on stream
with torch.cuda.stream(torch_stream):
    wp.launch(my_kernel, dim=1000, inputs=[x, y], stream=warp_stream)
    # PyTorch operations on same stream
    result = torch.matmul(tensor_a, tensor_b)

# Synchronization happens automatically; explicit sync shown for clarity
torch.cuda.synchronize()

# Next example: end-to-end gradient flow through a Warp kernel.
import torch
import warp as wp

# Enable gradient tracking
torch.autograd.set_grad_enabled(True)

# PyTorch tensor with gradients
x = torch.randn(100, requires_grad=True, device='cuda')

# Custom Warp function with gradient support
@wp.func
def custom_activation(x: float) -> float:
    return wp.sin(x) * wp.exp(-x * x)

@wp.kernel
def apply_activation(input: wp.array(dtype=float),
                     output: wp.array(dtype=float)):
    i = wp.tid()
    output[i] = custom_activation(input[i])

# Convert to Warp with gradient tracking
x_warp = wp.from_torch(x, requires_grad=True)
y_warp = wp.zeros_like(x_warp)

# Launch kernel
wp.launch(apply_activation, dim=100, inputs=[x_warp, y_warp])

# Convert back with gradient preservation
y = wp.to_torch(y_warp, requires_grad=True)

# Compute loss and backpropagate
loss = torch.sum(y)
loss.backward()

# Gradients available in original tensor
print(x.grad)  # Contains gradients from Warp computation

# Next example: keeping PyTorch and Warp devices consistent.
import torch
import warp as wp

# Ensure consistent device usage
if torch.cuda.is_available():
    torch_device = torch.device('cuda:0')
    warp_device = wp.device_from_torch(torch_device)
else:
    torch_device = torch.device('cpu')
    warp_device = wp.get_device('cpu')

# Set devices
# NOTE(review): torch.cuda.set_device is called unconditionally and would
# fail on the CPU-only path above — confirm intended behavior.
torch.cuda.set_device(torch_device)
wp.set_device(warp_device)

# Create tensors/arrays on consistent devices
x_torch = torch.randn(1000, device=torch_device)
x_warp = wp.from_torch(x_torch)
assert x_warp.device == warp_device

# Framework tensor types (external)
# Tensor types from the interoperating frameworks (defined externally).
TorchTensor = torch.Tensor              # PyTorch tensor
JaxArray = jax.Array                    # JAX array
PaddleTensor = paddle.Tensor            # Paddle tensor
DLPackTensor = object                   # DLPack capsule (opaque PyCapsule)
# Device conversion types
TorchDevice = torch.device
JaxDevice = jax.Device
PaddleDevice = paddle.device.CUDAPlace
# Stream types
TorchStream = torch.cuda.Stream

Install with Tessl CLI:

npx tessl i tessl/pypi-warp-lang