"""Warp: a Python framework for high-performance simulation and graphics programming that JIT-compiles Python functions into efficient GPU/CPU kernel code.

Warp kernels are Python functions decorated with ``@wp.kernel`` that are JIT-compiled into efficient GPU/CPU code. This module covers the kernel decorators, the built-in mathematical functions, and the programming constructs used to write high-performance parallel code.
"""
# Kernel decorators: decorators that transform Python functions into compiled GPU/CPU code.
def kernel(func: Callable) -> Kernel:
    """Compile a Python function into a data-parallel Warp kernel.

    The decorated function is executed by many threads simultaneously;
    each thread usually finds the element it is responsible for with
    ``wp.tid()``.

    Args:
        func: The Python function to compile.

    Returns:
        The compiled ``Kernel``.

    Example::

        @wp.kernel
        def add_arrays(a: wp.array(dtype=float),
                       b: wp.array(dtype=float),
                       c: wp.array(dtype=float)):
            i = wp.tid()
            c[i] = a[i] + b[i]
    """
    ...
def func(func: Callable) -> Function:
    """Compile a Python function so that kernels can call it.

    A ``wp.func`` is not launched across threads on its own; it executes
    sequentially inside whichever kernel thread calls it.

    Args:
        func: The Python function to compile.

    Returns:
        The compiled ``Function``.

    Example::

        @wp.func
        def compute_distance(a: wp.vec3, b: wp.vec3) -> float:
            diff = a - b
            return wp.length(diff)
    """
    ...
def func_grad(func: Function) -> Function:
    """Register a custom gradient (adjoint) for the ``wp.func`` ``func``.

    Used as a decorator factory on a user-written adjoint, e.g.
    ``@wp.func_grad(my_func)``. Automatic differentiation then calls that
    adjoint in place of the one generated for ``func``.
    """
    ...


def func_replay(func: Function) -> Function:
    """Register a custom replay implementation for the ``wp.func`` ``func``.

    During the backward pass the forward computation is replayed. The
    registered function runs in place of ``func`` during that replay.
    """
    ...


def func_native(func: Callable) -> Function:
    """Declare a kernel-callable function implemented by native CUDA/C++ code.

    The native snippet is compiled into the generated kernel module; it is
    not executed by the Python interpreter.

    NOTE(review): upstream Warp spells this as
    ``@wp.func_native(snippet, adj_snippet=None, replay_snippet=None)``,
    decorating a Python function that only supplies the signature. Confirm
    this stub's ``func`` parameter against the Warp release in use.
    """
    ...


# Built-in functions and constructs available within kernel code.
def tid() -> int:
    """Return the index of the thread running this kernel invocation.

    Indices range from 0 up to, but not including, the launch dimension.
    """
    ...


def static(condition: bool):
    """Evaluate ``condition`` while the kernel is being compiled.

    The branch it guards is chosen at compile time, not at run time.
    """
    ...


def print(value):
    """Print ``value`` from inside a kernel; intended for debugging."""
    ...


def expect(condition: bool):
    """Assert that ``condition`` holds.

    The check is compiled out of release builds.
    """
    ...
def struct(cls):
    """Turn an annotated class into a Warp struct type.

    Each annotated attribute of ``cls`` becomes a field of the struct.

    Example::

        @wp.struct
        class Particle:
            pos: wp.vec3
            vel: wp.vec3
            mass: float
    """
    ...


def overload(func: Callable):
    """Decorator that declares overloads of ``func`` for other argument types."""
    ...


# Core mathematical operations available in kernels.
# Basic arithmetic
def min(a: Scalar, b: Scalar) -> Scalar:
    """Return the smaller of ``a`` and ``b``."""
    ...


def max(a: Scalar, b: Scalar) -> Scalar:
    """Return the larger of ``a`` and ``b``."""
    ...


def abs(x: Scalar) -> Scalar:
    """Return the absolute value of ``x``."""
    ...


def sign(x: Scalar) -> Scalar:
    """Return -1 if ``x`` < 0, and 1 otherwise.

    Unlike ``numpy.sign`` there is no zero case: ``sign(0)`` is 1.
    """
    ...


def step(edge: Float, x: Float) -> Float:
    """Step function: 0 if ``x`` < ``edge``, else 1.

    NOTE(review): upstream Warp documents ``step`` as a one-argument
    function returning 1.0 when ``x`` < 0.0 and 0.0 otherwise, which is the
    opposite polarity of this GLSL-style description. Confirm against the
    Warp release in use before relying on it.
    """
    ...


def smoothstep(edge0: Float, edge1: Float, x: Float) -> Float:
    """Smoothly interpolate between the edges ``edge0`` and ``edge1`` at ``x``."""
    ...
# Power and exponential functions
def pow(base: Float, exp: Float) -> Float:
    """Compute ``base`` raised to the power ``exp``."""
    ...


def sqrt(x: Float) -> Float:
    """Return the square root of ``x``."""
    ...


def cbrt(x: Float) -> Float:
    """Return the cube root of ``x``."""
    ...


def exp(x: Float) -> Float:
    """Return e raised to the power ``x``."""
    ...


def exp2(x: Float) -> Float:
    """Return 2 raised to the power ``x``."""
    ...


def log(x: Float) -> Float:
    """Return the natural (base-e) logarithm of ``x``."""
    ...


def log2(x: Float) -> Float:
    """Return the base-2 logarithm of ``x``."""
    ...


def log10(x: Float) -> Float:
    """Return the base-10 logarithm of ``x``."""
    ...
# Trigonometric functions
def sin(x: Float) -> Float:
    """Return the sine of ``x``."""
    ...


def cos(x: Float) -> Float:
    """Return the cosine of ``x``."""
    ...


def tan(x: Float) -> Float:
    """Return the tangent of ``x``."""
    ...


def asin(x: Float) -> Float:
    """Return the arcsine (inverse sine) of ``x``."""
    ...


def acos(x: Float) -> Float:
    """Return the arccosine (inverse cosine) of ``x``."""
    ...


def atan(x: Float) -> Float:
    """Return the arctangent (inverse tangent) of ``x``."""
    ...


def atan2(y: Float, x: Float) -> Float:
    """Return the two-argument arctangent of ``y`` and ``x``."""
    ...


def sinh(x: Float) -> Float:
    """Return the hyperbolic sine of ``x``."""
    ...


def cosh(x: Float) -> Float:
    """Return the hyperbolic cosine of ``x``."""
    ...


def tanh(x: Float) -> Float:
    """Return the hyperbolic tangent of ``x``."""
    ...
# Rounding and utility
def floor(x: Float) -> Float:
    """Return the largest integer value not greater than ``x``."""
    ...


def ceil(x: Float) -> Float:
    """Return the smallest integer value not less than ``x``."""
    ...


def round(x: Float) -> Float:
    """Return the integer value nearest to ``x``.

    Halfway cases are rounded away from zero (2.5 -> 3, -2.5 -> -3), unlike
    Python's built-in ``round``, which rounds half to even.
    """
    ...


def trunc(x: Float) -> Float:
    """Round ``x`` toward zero to an integer value."""
    ...


def frac(x: Float) -> Float:
    """Return the fractional part of ``x``, i.e. ``x - trunc(x)``.

    The result keeps the sign of ``x``: ``frac(-1.25)`` is -0.25, whereas
    ``x - floor(x)`` would give 0.75.
    """
    ...


def fmod(x: Float, y: Float) -> Float:
    """Return the floating-point remainder of ``x`` divided by ``y``."""
    ...


def clamp(x: Scalar, a: Scalar, b: Scalar) -> Scalar:
    """Limit ``x`` to the range [``a``, ``b``]."""
    ...


def lerp(a: Float, b: Float, t: Float) -> Float:
    """Linearly interpolate between ``a`` and ``b``: ``a + t * (b - a)``."""
    ...


# Built-in functions for vector operations.
def length(v: Vector) -> Float:
    """Return the magnitude of ``v``."""
    ...


def length_sq(v: Vector) -> Float:
    """Return the squared magnitude of ``v``; cheaper than ``length``."""
    ...


def normalize(v: Vector) -> Vector:
    """Return the unit vector pointing in the direction of ``v``."""
    ...


def dot(a: Vector, b: Vector) -> Float:
    """Return the dot product of ``a`` and ``b``."""
    ...


def cross(a: vec3, b: vec3) -> vec3:
    """Return the cross product of ``a`` and ``b`` (3D vectors only)."""
    ...


def distance(a: Vector, b: Vector) -> Float:
    """Return the distance between the points ``a`` and ``b``."""
    ...


def distance_sq(a: Vector, b: Vector) -> Float:
    """Return the squared distance between ``a`` and ``b``; cheaper than ``distance``."""
    ...


def reflect(v: Vector, n: Vector) -> Vector:
    """Mirror ``v`` about the normal ``n``."""
    ...


def refract(v: Vector, n: Vector, eta: Float) -> Vector:
    """Refract ``v`` through a surface with normal ``n`` and refractive index ``eta``."""
    ...


# Built-in functions for matrix mathematics.
import typing


# Two call forms share the name ``mul``. A second plain ``def mul`` would
# silently rebind the name and discard the matrix-matrix form, so both
# signatures are declared with ``typing.overload`` over one implementation.
@typing.overload
def mul(a: Matrix, b: Matrix) -> Matrix: ...


@typing.overload
def mul(m: Matrix, v: Vector) -> Vector: ...


def mul(*args, **kwargs):
    """Multiply a matrix by another matrix or by a vector.

    * ``mul(a, b)`` with two matrices returns their matrix product.
    * ``mul(m, v)`` with a matrix and a vector returns the matrix-vector product.
    """
    ...


def transpose(m: Matrix) -> Matrix:
    """Return the transpose of ``m``."""
    ...


def inverse(m: Matrix) -> Matrix:
    """Return the inverse of ``m``; only meaningful when ``m`` is invertible."""
    ...


def determinant(m: Matrix) -> Float:
    """Return the determinant of ``m``."""
    ...


def trace(m: Matrix) -> Float:
    """Return the trace of ``m`` (the sum of its diagonal elements)."""
    ...


# Built-in functions for quaternion mathematics.
def quat_identity() -> quat:
    """Return the quaternion that represents no rotation."""
    ...


def quat_from_axis_angle(axis: vec3, angle: Float) -> quat:
    """Build the quaternion that rotates by ``angle`` about ``axis``."""
    ...


def quat_to_axis_angle(q: quat) -> tuple[vec3, Float]:
    """Decompose ``q`` into its rotation axis and angle."""
    ...


def quat_mul(a: quat, b: quat) -> quat:
    """Multiply two quaternions, composing their rotations."""
    ...


def quat_rotate(q: quat, v: vec3) -> vec3:
    """Rotate the vector ``v`` by the quaternion ``q``."""
    ...


def quat_inverse(q: quat) -> quat:
    """Return the inverse of ``q``."""
    ...


def quat_conjugate(q: quat) -> quat:
    """Return the conjugate of ``q``."""
    ...


def quat_normalize(q: quat) -> quat:
    """Scale ``q`` to unit length."""
    ...


def quat_slerp(a: quat, b: quat, t: Float) -> quat:
    """Spherically interpolate between ``a`` and ``b`` by the factor ``t``."""
    ...


# Thread-safe operations for concurrent access to shared memory.
def atomic_add(arr: array, index: int, value: Scalar) -> Scalar:
    """Atomically add ``value`` to ``arr[index]``.

    Returns:
        The value ``arr[index]`` held before the addition.
    """
    ...


def atomic_sub(arr: array, index: int, value: Scalar) -> Scalar:
    """Atomically subtract ``value`` from ``arr[index]``.

    Returns:
        The value ``arr[index]`` held before the subtraction.
    """
    ...


def atomic_min(arr: array, index: int, value: Scalar) -> Scalar:
    """Atomically store ``min(arr[index], value)`` into ``arr[index]``.

    Returns:
        The value ``arr[index]`` held before the update.
    """
    ...


def atomic_max(arr: array, index: int, value: Scalar) -> Scalar:
    """Atomically store ``max(arr[index], value)`` into ``arr[index]``.

    Returns:
        The value ``arr[index]`` held before the update.
    """
    ...


def atomic_cas(arr: array, index: int, compare: Scalar, value: Scalar) -> Scalar:
    """Atomic compare-and-swap on ``arr[index]``.

    Writes ``value`` only if ``arr[index]`` currently equals ``compare``.

    Returns:
        The value ``arr[index]`` held before the operation; it equals
        ``compare`` exactly when the swap took place.
    """
    ...


# Functions for spatial queries and geometric computations.
def mesh_query_point(mesh: Mesh, point: vec3, max_dist: Float) -> MeshQueryPoint:
    """Find the point on ``mesh`` closest to ``point``.

    ``max_dist`` bounds how far from ``point`` the search reaches.
    """
    ...


def mesh_query_ray(mesh: Mesh, start: vec3, dir: vec3) -> MeshQueryRay:
    """Cast a ray from ``start`` along ``dir`` against ``mesh`` and find the intersection."""
    ...


def mesh_query_aabb(mesh: Mesh, lower: vec3, upper: vec3) -> MeshQueryAABB:
    """Query the triangles of ``mesh`` that overlap the box [``lower``, ``upper``]."""
    ...


def bvh_query_ray(bvh: Bvh, start: vec3, dir: vec3) -> BvhQuery:
    """Start a ray query from ``start`` along ``dir`` against a bounding volume hierarchy."""
    ...


def hash_grid_query(grid: HashGrid, point: vec3, radius: Float) -> HashGridQuery:
    """Find the neighbors of ``point`` within ``radius`` using a spatial hash grid."""
    ...


def volume_sample_f(volume: Volume, uvw: vec3, sampling_mode: int) -> Float:
    """Sample a scalar value from the 3D ``volume`` at ``uvw``.

    ``uvw`` is a point in the volume's index (local) space, not a
    normalized [0, 1] coordinate. Convert world-space points first, e.g.
    with ``wp.volume_world_to_index``. ``sampling_mode`` selects the
    interpolation, e.g. ``wp.Volume.CLOSEST`` or ``wp.Volume.LINEAR``.
    """
    ...


def volume_sample_v(volume: Volume, uvw: vec3, sampling_mode: int) -> vec3:
    """Sample a vector value from the 3D ``volume`` at ``uvw``.

    ``uvw`` is a point in the volume's index (local) space, not a
    normalized [0, 1] coordinate. ``sampling_mode`` selects the
    interpolation, e.g. ``wp.Volume.CLOSEST`` or ``wp.Volume.LINEAR``.
    """
    ...


# Usage examples.
import warp as wp
@wp.kernel
def compute_forces(positions: wp.array(dtype=wp.vec3),
                   velocities: wp.array(dtype=wp.vec3),
                   forces: wp.array(dtype=wp.vec3),
                   mass: float,
                   dt: float):
    """Store a damping force per particle and advance its position.

    One thread handles one particle. It writes a force of
    ``-0.1 * velocity * mass``, then moves the position by ``velocity * dt``.
    """
    idx = wp.tid()

    x = positions[idx]
    v = velocities[idx]

    # Simple damping that opposes the current velocity.
    forces[idx] = -0.1 * v * mass

    # Advance with the velocity read above; the force is not applied here.
    positions[idx] = x + v * dt


@wp.func
def spring_force(pos_a: wp.vec3, pos_b: wp.vec3,
                 rest_length: float, stiffness: float) -> wp.vec3:
    """Spring force on ``pos_a`` from a spring attached to ``pos_b``.

    The force is proportional to how far the separation deviates from
    ``rest_length``. Returns the zero vector when the points coincide, which
    avoids dividing by a zero length.
    """
    delta = pos_a - pos_b
    sep = wp.length(delta)
    if sep > 0.0:
        stretch = sep - rest_length
        return -(stiffness * stretch) * (delta / sep)
    else:
        return wp.vec3(0.0, 0.0, 0.0)
@wp.kernel
def update_springs(positions: wp.array(dtype=wp.vec3),
                   forces: wp.array(dtype=wp.vec3),
                   spring_indices: wp.array(dtype=wp.int32),
                   rest_lengths: wp.array(dtype=float),
                   stiffness: float):
    """Accumulate each spring's force onto both of its endpoints.

    One thread per spring. ``spring_indices`` stores endpoint pairs
    back-to-back: spring ``s`` joins ``spring_indices[2*s]`` and
    ``spring_indices[2*s + 1]``.
    """
    s = wp.tid()

    a = spring_indices[s * 2]
    b = spring_indices[s * 2 + 1]

    f = spring_force(positions[a], positions[b], rest_lengths[s], stiffness)

    # Different threads may target the same particle, hence atomic accumulation.
    wp.atomic_add(forces, a, f)
    wp.atomic_add(forces, b, -f)


@wp.struct
class Particle:
    # Per-particle simulation state.
    position: wp.vec3
    velocity: wp.vec3
    mass: float
    radius: float


@wp.kernel
def update_particles(particles: wp.array(dtype=Particle), dt: float):
    """Advance every particle by one step under constant gravity.

    The position is moved using the velocity from the start of the step;
    gravity is then added to the velocity.
    """
    idx = wp.tid()
    state = particles[idx]

    state.position = state.position + state.velocity * dt
    state.velocity = state.velocity + wp.vec3(0.0, -9.81, 0.0) * dt

    # Store the modified particle back into the array.
    particles[idx] = state


class Kernel:
    """A compiled kernel, as returned by ``@wp.kernel``."""

    def __call__(self, *args, **kwargs):
        """Launch the kernel; equivalent to ``wp.launch``."""


class Function:
    """A compiled function that kernels can call, as returned by ``@wp.func``."""

    def __call__(self, *args, **kwargs):
        """Invoke the function."""
# Query result types returned by the spatial query built-ins.
class MeshQueryPoint:
    """Result of ``mesh_query_point``."""

    face: int  # Index of the triangle found (-1 if no hit)
    u: float  # Barycentric coordinate
    v: float  # Barycentric coordinate
    sign: float  # Inside/outside indicator


class MeshQueryRay:
    """Result of ``mesh_query_ray``."""

    face: int  # Index of the triangle hit (-1 if no hit)
    u: float  # Barycentric coordinate
    v: float  # Barycentric coordinate
    t: float  # Ray parameter at the intersection


class MeshQueryAABB:
    """Iteration state for ``mesh_query_aabb``."""

    face: int  # Next overlapping triangle (-1 once exhausted)


class BvhQuery:
    """State of a BVH ray query started by ``bvh_query_ray``."""


class HashGridQuery:
    """State of a neighbor query started by ``hash_grid_query``."""


# Install with the Tessl CLI:
#   npx tessl i tessl/pypi-warp-lang