CtrlK
BlogDocsLog inGet started
Tessl Logo

tessl/pypi-warp-lang

A Python framework for high-performance simulation and graphics programming that JIT compiles Python functions to efficient GPU/CPU kernel code.

Overview
Eval results
Files

framework-integration.mddocs/

Framework Interoperability

Warp provides seamless data exchange and integration with popular machine learning and scientific computing frameworks. This enables easy incorporation of Warp kernels into existing ML pipelines and scientific workflows.

Capabilities

PyTorch Integration

Convert between Warp arrays and PyTorch tensors with automatic device management and gradient support.

def from_torch(tensor, dtype: type = None, requires_grad: bool = None) -> array:
    """
    Create Warp array from PyTorch tensor.

    The returned array shares memory with the tensor (zero-copy), so writes
    through either object are visible to the other.
    NOTE(review): upstream Warp requires the tensor to be contiguous for
    zero-copy aliasing — confirm against the installed Warp version.

    Args:
        tensor: Source PyTorch tensor.
        dtype: Target Warp type (inferred from the tensor's dtype if None).
        requires_grad: Enable gradient tracking (inherits the tensor's
            setting if None).

    Returns:
        Warp array sharing memory with tensor.
    """

def to_torch(arr: array, requires_grad: bool = False):
    """
    Create PyTorch tensor from Warp array.

    Args:
        arr: Source Warp array.
        requires_grad: Enable gradient tracking on the returned tensor
            (defaults to False here, unlike from_torch which inherits).

    Returns:
        PyTorch tensor sharing memory with array (zero-copy alias).
    """

def dtype_from_torch(torch_dtype) -> type:
    """Convert a PyTorch dtype (e.g. torch.float32) to the equivalent Warp type."""

def dtype_to_torch(wp_dtype: type):
    """Convert a Warp type (e.g. wp.float32) to the equivalent PyTorch dtype."""

def device_from_torch(torch_device) -> Device:
    """Convert a PyTorch device to the matching Warp Device."""

def device_to_torch(wp_device: Device):
    """Convert a Warp Device to the matching PyTorch device."""

def stream_from_torch(torch_stream) -> Stream:
    """Create a Warp Stream from a PyTorch CUDA stream (CUDA devices only)."""

def stream_to_torch(wp_stream: Stream):
    """Convert a Warp Stream to a PyTorch CUDA stream (CUDA devices only)."""

JAX Integration

Interoperability with JAX for functional programming and automatic differentiation.

def from_jax(jax_array, dtype: type = None) -> array:
    """
    Create Warp array from JAX array.

    Args:
        jax_array: Source JAX array (DeviceArray on older JAX releases).
        dtype: Target Warp type (inferred if None).

    Returns:
        Warp array holding the JAX array's data.
        NOTE(review): upstream Warp documents this conversion as zero-copy
        via DLPack where possible; the original text claimed a copy —
        confirm against the installed Warp version.
    """

def to_jax(arr: array):
    """
    Create JAX array from Warp array.

    Args:
        arr: Source Warp array.

    Returns:
        JAX array holding the Warp array's data.
        NOTE(review): upstream Warp documents this conversion as zero-copy
        via DLPack where possible; the original text claimed a copy —
        confirm against the installed Warp version.
    """

def dtype_from_jax(jax_dtype) -> type:
    """Convert a JAX dtype (e.g. jnp.float32) to the equivalent Warp type."""

def dtype_to_jax(wp_dtype: type):
    """Convert a Warp type (e.g. wp.float32) to the equivalent JAX dtype."""

def device_from_jax(jax_device) -> Device:
    """Convert a JAX device object to the matching Warp Device."""

def device_to_jax(wp_device: Device):
    """Convert a Warp Device to the matching JAX device object."""

JAX Experimental

Advanced JAX integration with XLA FFI support for high-performance custom operations.

# Available in warp.jax_experimental module
def register_custom_call(name: str, kernel: Kernel) -> None:
    """Register Warp kernel as JAX custom call.

    NOTE(review): the upstream warp.jax_experimental module exposes this
    kind of functionality through jax_kernel / FFI helpers; confirm that
    this exact entry point exists in the targeted Warp version.
    """

def xla_ffi_kernel(kernel: Kernel):
    """Decorator to create XLA FFI-compatible kernel.

    NOTE(review): name not found under this spelling in upstream Warp's
    jax_experimental docs — verify against the targeted Warp version.
    """

Paddle Integration

Integration with PaddlePaddle for deep learning workflows in the Chinese deep-learning ecosystem.

def from_paddle(paddle_tensor, dtype: type = None) -> array:
    """
    Create Warp array from Paddle tensor.

    The returned array shares memory with the tensor (zero-copy), so writes
    through either object are visible to the other.

    Args:
        paddle_tensor: Source Paddle tensor.
        dtype: Target Warp type (inferred from the tensor's dtype if None).

    Returns:
        Warp array sharing memory with tensor.
    """

def to_paddle(arr: array):
    """
    Create Paddle tensor from Warp array.

    Args:
        arr: Source Warp array.

    Returns:
        Paddle tensor sharing memory with array (zero-copy alias).
    """

def dtype_from_paddle(paddle_dtype) -> type:
    """Convert a Paddle dtype to the equivalent Warp type."""

def dtype_to_paddle(wp_dtype: type):
    """Convert a Warp type (e.g. wp.float32) to the equivalent Paddle dtype."""

def device_from_paddle(paddle_device) -> Device:
    """Convert a Paddle device/place object to the matching Warp Device."""

def device_to_paddle(wp_device: Device):
    """Convert a Warp Device to the matching Paddle device/place object."""

def stream_from_paddle(paddle_stream) -> Stream:
    """Create a Warp Stream from a Paddle CUDA stream (CUDA devices only)."""

DLPack Integration

Universal tensor exchange format for interoperability across frameworks.

def from_dlpack(dlpack_tensor) -> array:
    """
    Create Warp array from DLPack tensor.

    DLPack is the framework-neutral tensor exchange protocol, so this path
    works with any producer that can emit a DLPack capsule.

    Args:
        dlpack_tensor: DLPack tensor capsule (opaque PyCapsule).

    Returns:
        Warp array sharing memory with the DLPack tensor (zero-copy).
    """

def to_dlpack(arr: array):
    """
    Create DLPack tensor from Warp array.

    Args:
        arr: Source Warp array.

    Returns:
        DLPack tensor capsule sharing memory with the array (zero-copy);
        consumable by any DLPack-aware framework.
    """

NumPy Integration

Direct conversion between Warp arrays and NumPy arrays.

def from_numpy(np_array: np.ndarray, 
              dtype: type = None, 
              device: Device = None) -> array:
    """
    Create Warp array from NumPy array.

    Unlike the torch/paddle converters above, the result is documented as a
    copy, not an alias — NumPy data lives in host memory, so targeting a
    GPU device necessarily involves a host-to-device transfer.

    Args:
        np_array: Source NumPy array.
        dtype: Target Warp type (inferred from the NumPy dtype if None).
        device: Target device (CPU if None).

    Returns:
        Warp array with data copied from NumPy array.
    """

# Note: array.numpy() method provides reverse conversion

Usage Examples

PyTorch-Warp Pipeline

# Example: PyTorch -> Warp -> PyTorch round trip on the GPU.
import torch
import warp as wp

# Create PyTorch tensors on the CUDA device
x_torch = torch.randn(1000, 3, device='cuda', requires_grad=True)
y_torch = torch.zeros(1000, 3, device='cuda')

# Convert to Warp arrays (shares memory, preserves gradients)
x_warp = wp.from_torch(x_torch)
y_warp = wp.from_torch(y_torch, requires_grad=True)

# Define Warp kernel: y[i] = 2*x[i] + (1, 0, -1)
@wp.kernel
def process_data(x: wp.array(dtype=wp.vec3), 
                y: wp.array(dtype=wp.vec3)):
    i = wp.tid()  # one thread per vec3 element
    # Some computation
    y[i] = x[i] * 2.0 + wp.vec3(1.0, 0.0, -1.0)

# Launch kernel (dim matches the 1000-element arrays)
wp.launch(process_data, dim=1000, inputs=[x_warp, y_warp])

# Convert result back to PyTorch (shares memory)
result_torch = wp.to_torch(y_warp)

# Use in PyTorch pipeline
loss = torch.mean(result_torch)
# NOTE(review): a bare wp.launch does not register itself with torch
# autograd — gradient flow normally requires wp.Tape or a
# torch.autograd.Function bridge. Confirm before relying on x_torch.grad.
loss.backward()  # Gradients flow back through Warp computation

JAX Integration Example

# Example: JAX -> Warp -> JAX on a small 2D array.
import jax
import jax.numpy as jnp
import warp as wp

# JAX array (2x2)
x_jax = jnp.array([[1.0, 2.0], [3.0, 4.0]])

# Convert to Warp
x_warp = wp.from_jax(x_jax)

# Process with Warp kernel
# NOTE(review): arrays indexed with two subscripts are usually annotated
# wp.array2d(dtype=float) (or ndim=2) in Warp — verify this annotation.
@wp.kernel
def double_values(x: wp.array(dtype=float), 
                 y: wp.array(dtype=float)):
    i, j = wp.tid()  # 2D thread index; requires a 2D launch dim
    y[i, j] = x[i, j] * 2.0

y_warp = wp.zeros_like(x_warp)
wp.launch(double_values, dim=x_warp.shape, inputs=[x_warp, y_warp])

# Convert back to JAX
y_jax = wp.to_jax(y_warp)

# Continue JAX computation
result = jnp.sum(y_jax)

Multi-Framework Workflow

# Example: NumPy -> Warp -> PyTorch -> NumPy pipeline.
import numpy as np
import torch
import warp as wp

# Start with NumPy data: 1000 random 3-vectors, float32
np_data = np.random.rand(1000, 3).astype(np.float32)

# Convert to Warp (host -> device copy)
# NOTE(review): the kernel below expects dtype=wp.vec3; from_numpy on a
# (1000, 3) float32 array may produce a scalar 2D array unless
# dtype=wp.vec3 is passed explicitly — verify.
warp_array = wp.from_numpy(np_data, device='cuda')

# Process with Warp kernel: normalize each vector in place
@wp.kernel
def normalize_vectors(vectors: wp.array(dtype=wp.vec3)):
    i = wp.tid()
    v = vectors[i] 
    length = wp.length(v)
    if length > 0.0:  # guard: leave zero-length vectors untouched
        vectors[i] = v / length

wp.launch(normalize_vectors, dim=1000, inputs=[warp_array])

# Convert to PyTorch for ML pipeline (shares memory)
torch_tensor = wp.to_torch(warp_array)

# Use in neural network
model = torch.nn.Linear(3, 1).cuda()
output = model(torch_tensor)

# Convert back for final processing (.numpy() copies to host)
final_warp = wp.from_torch(output)
final_np = final_warp.numpy()

Stream Synchronization

# Example: sharing a CUDA stream between PyTorch and Warp.
# NOTE(review): my_kernel, x, y, tensor_a and tensor_b are placeholders
# assumed to be defined elsewhere.
import torch
import warp as wp

# Create PyTorch CUDA stream
torch_stream = torch.cuda.Stream()

# Convert to Warp stream
warp_stream = wp.stream_from_torch(torch_stream)

# Launch Warp kernel on stream
with torch.cuda.stream(torch_stream):
    wp.launch(my_kernel, dim=1000, inputs=[x, y], stream=warp_stream)
    
    # PyTorch operations issued on the same stream, so they are ordered
    # after the Warp launch without extra synchronization
    result = torch.matmul(tensor_a, tensor_b)

# Explicitly block the host until all queued CUDA work completes
torch.cuda.synchronize()

Gradient Flow Example

# Example: differentiating through a Warp kernel from PyTorch.
import torch
import warp as wp

# Enable gradient tracking (on by default; shown here for explicitness)
torch.autograd.set_grad_enabled(True)

# PyTorch tensor with gradients
x = torch.randn(100, requires_grad=True, device='cuda')

# Custom Warp function with gradient support
@wp.func
def custom_activation(x: float) -> float:
    # sin(x) * exp(-x^2): smooth, differentiable element-wise activation
    return wp.sin(x) * wp.exp(-x * x)

@wp.kernel  
def apply_activation(input: wp.array(dtype=float),
                    output: wp.array(dtype=float)):
    i = wp.tid()
    output[i] = custom_activation(input[i])

# Convert to Warp with gradient tracking
x_warp = wp.from_torch(x, requires_grad=True)
y_warp = wp.zeros_like(x_warp)

# Launch kernel
# NOTE(review): for gradients to propagate, the launch normally has to be
# recorded on a wp.Tape (or wrapped in a torch.autograd.Function) and the
# tape replayed in backward; a bare wp.launch does not hook into torch
# autograd. Confirm against the installed Warp version.
wp.launch(apply_activation, dim=100, inputs=[x_warp, y_warp])

# Convert back with gradient preservation
y = wp.to_torch(y_warp, requires_grad=True)

# Compute loss and backpropagate
loss = torch.sum(y)
loss.backward()

# Gradients available in original tensor
print(x.grad)  # Contains gradients from Warp computation

Device Management Across Frameworks

Cross-Framework Device Consistency

# Example: keep PyTorch and Warp pinned to the same compute device.
import torch
import warp as wp

# Ensure consistent device usage across both frameworks
if torch.cuda.is_available():
    torch_device = torch.device('cuda:0')
    warp_device = wp.device_from_torch(torch_device)
    # Select the CUDA device only when one exists: calling
    # torch.cuda.set_device() with a CPU device raises an error, so this
    # call must not run on the CPU fallback path below.
    torch.cuda.set_device(torch_device)
else:
    torch_device = torch.device('cpu')
    warp_device = wp.get_device('cpu')

# Make Warp's default device match PyTorch's
wp.set_device(warp_device)

# Create tensors/arrays on consistent devices
x_torch = torch.randn(1000, device=torch_device)
x_warp = wp.from_torch(x_torch)

assert x_warp.device == warp_device

Types

# Framework tensor types (external; each requires its framework installed)
TorchTensor = torch.Tensor  # PyTorch tensor
JaxArray = jax.Array        # JAX array (jax.Array in JAX >= 0.4; older releases used DeviceArray)
PaddleTensor = paddle.Tensor # Paddle tensor
DLPackTensor = object       # DLPack capsule (opaque PyCapsule; no Python-level type)

# Device conversion types
TorchDevice = torch.device
JaxDevice = jax.Device
PaddleDevice = paddle.device.CUDAPlace  # NOTE(review): CPUPlace also exists for CPU tensors

# Stream types (CUDA only)
TorchStream = torch.cuda.Stream

Install with Tessl CLI

npx tessl i tessl/pypi-warp-lang

docs

core-execution.md

fem.md

framework-integration.md

index.md

kernel-programming.md

optimization.md

rendering.md

types-arrays.md

utilities.md

tile.json