CtrlK
Blog · Docs · Log in · Get started
Tessl Logo

tessl/pypi-warp-lang

A Python framework for high-performance simulation and graphics programming that JIT compiles Python functions to efficient GPU/CPU kernel code.

Overview
Eval results
Files

docs/kernel-programming.md

Kernel Programming and Built-in Functions

Warp kernels are Python functions decorated with @wp.kernel that are JIT-compiled to efficient GPU/CPU code. This page covers the kernel and function decorators, the built-in mathematical functions, and the language constructs used to write high-performance parallel code.

Capabilities

Kernel and Function Decorators

Decorators that transform Python functions into compiled GPU/CPU code.

def kernel(func: Callable) -> Kernel:
    """
    Turn a Python function into a Warp kernel that runs in parallel.

    A launched kernel is executed by many threads at the same time; each
    thread normally calls ``wp.tid()`` to find which element it owns.

    Example:
        @wp.kernel
        def add_arrays(a: wp.array(dtype=float),
                       b: wp.array(dtype=float),
                       c: wp.array(dtype=float)):
            i = wp.tid()
            c[i] = a[i] + b[i]
    """

def func(func: Callable) -> Function:
    """
    Turn a Python function into a Warp function that kernels may call.

    A Warp function is not launched by itself: it executes sequentially,
    inside whichever kernel thread invokes it.

    Example:
        @wp.func
        def compute_distance(a: wp.vec3, b: wp.vec3) -> float:
            diff = a - b
            return wp.length(diff)
    """

def func_grad(func: Function) -> Callable:
    """
    Register a custom adjoint (gradient) for an existing ``@wp.func``.

    ``wp.func_grad(func)`` returns a decorator. It replaces the automatically
    generated backward code for ``func``. The decorated function has these rules:

    - It receives the original inputs of ``func``, followed by the adjoints of
      ``func``'s outputs.
    - It accumulates input gradients through ``wp.adjoint[...]``.
    - It returns nothing.

    Example:
        @wp.func
        def safe_sqrt(x: float):
            return wp.sqrt(x)

        @wp.func_grad(safe_sqrt)
        def adj_safe_sqrt(x: float, adj_ret: float):
            if x > 0.0:
                wp.adjoint[x] += 1.0 / (2.0 * wp.sqrt(x)) * adj_ret
    """

def func_replay(func: Function) -> Callable:
    """
    Register a custom replay function for an existing ``@wp.func``.

    ``wp.func_replay(func)`` returns a decorator. During the backward pass
    Warp re-runs ("replays") the forward computation, using ``func`` itself by
    default. This decorator substitutes the decorated function for that replay.
    The replacement must match the signature of one of ``func``'s overloads.
    """

def func_native(snippet: str, adj_snippet=None, replay_snippet=None) -> Callable:
    """
    Define a Warp function whose body is a native C++/CUDA code snippet.

    Args:
        snippet: Native source code for the forward body.
        adj_snippet: Optional native source code for the adjoint (backward) body.
        replay_snippet: Optional native source code used in replay mode.

    Returns a decorator. The decorated Python function only supplies the
    signature (its body is typically ``...``). The result can be called from
    kernels like any ``@wp.func``. The snippet is compiled together with the
    kernel module; it does not run as interpreted Python.
    """

Kernel Programming Constructs

Built-in functions and constructs available within kernel code.

def tid() -> int:
    """
    Return the index of the current thread within the launch.

    For a 1-D launch of ``dim`` threads this is an int in ``[0, dim)``.
    For a multi-dimensional launch, ``wp.tid()`` yields one index per
    dimension, usually unpacked directly: ``i, j = wp.tid()``.
    """

def static(condition):
    """
    Evaluate an expression at code-generation time: ``wp.static(expr)``.

    The expression is evaluated in Python when the kernel is compiled, and
    its result is substituted into the generated code.

    The most common use is a compile-time branch,
    ``if wp.static(flag): ...``. Only the taken branch is emitted, so the
    other branch is resolved away at compile time.

    Args:
        condition: Any Python expression evaluable at compile time
            (a bool in the compile-time-branch case).
    """

def print(value):
    """Print ``value`` from inside a kernel (debugging aid)."""

def expect(condition: bool):
    """Assert condition is true (removed in release builds)."""

def struct(cls):
    """
    Decorator that turns an annotated class into a Warp struct type.

    Example:
        @wp.struct
        class Particle:
            pos: wp.vec3
            vel: wp.vec3
            mass: float
    """

def overload(func: Callable):
    """Decorator to create function overloads for different argument types."""

Mathematical Built-in Functions

Core mathematical operations available in kernels.

# Basic arithmetic
def min(a: Scalar, b: Scalar) -> Scalar:
    """Return the smaller of ``a`` and ``b``."""

def max(a: Scalar, b: Scalar) -> Scalar:
    """Return the larger of ``a`` and ``b``."""

def abs(x: Scalar) -> Scalar:
    """Return the absolute value of ``x``."""

def sign(x: Scalar) -> Scalar:
    """
    Return -1 if ``x`` < 0, otherwise 1.

    Note: ``wp.sign(0)`` is 1, not 0. This differs from ``numpy.sign``,
    which returns 0 for a zero input.
    """

def step(x: Float) -> Float:
    """
    Return 1.0 if ``x`` < 0.0, otherwise 0.0.

    This is not GLSL's two-argument ``step(edge, x)``. Warp's ``step`` takes
    a single argument and is 1.0 for negative inputs. A GLSL-style step
    with an edge can be written as ``1.0 - wp.step(x - edge)``.
    """

def smoothstep(edge0: Float, edge1: Float, x: Float) -> Float:
    """
    Smooth cubic Hermite interpolation between ``edge0`` and ``edge1``.

    ``x`` is clamped to the edges first, so the result lies in ``[0, 1]``.
    """

# Power and exponential functions
def pow(base: Float, exp: Float) -> Float:
    """Compute ``base`` raised to the power ``exp``."""

def sqrt(x: Float) -> Float:
    """Compute the square root of ``x``."""

def cbrt(x: Float) -> Float:
    """Compute the cube root of ``x``."""

def exp(x: Float) -> Float:
    """Compute e raised to the power ``x``."""

def exp2(x: Float) -> Float:
    """Compute 2 raised to the power ``x``."""

def log(x: Float) -> Float:
    """Compute the natural (base-e) logarithm of ``x``."""

def log2(x: Float) -> Float:
    """Compute the base-2 logarithm of ``x``."""

def log10(x: Float) -> Float:
    """Compute the base-10 logarithm of ``x``."""

# Trigonometric functions
def sin(x: Float) -> Float:
    """Compute the sine of ``x``."""

def cos(x: Float) -> Float:
    """Compute the cosine of ``x``."""

def tan(x: Float) -> Float:
    """Compute the tangent of ``x``."""

def asin(x: Float) -> Float:
    """Compute the inverse sine (arcsine) of ``x``."""

def acos(x: Float) -> Float:
    """Compute the inverse cosine (arccosine) of ``x``."""

def atan(x: Float) -> Float:
    """Compute the inverse tangent (arctangent) of ``x``."""

def atan2(y: Float, x: Float) -> Float:
    """Compute the arctangent of ``y / x``, using both signs to pick the quadrant."""

def sinh(x: Float) -> Float:
    """Compute the hyperbolic sine of ``x``."""

def cosh(x: Float) -> Float:
    """Compute the hyperbolic cosine of ``x``."""

def tanh(x: Float) -> Float:
    """Compute the hyperbolic tangent of ``x``."""

# Rounding and utility
def floor(x: Float) -> Float:
    """Return the largest integer value that is <= ``x``."""

def ceil(x: Float) -> Float:
    """Return the smallest integer value that is >= ``x``."""

def round(x: Float) -> Float:
    """
    Round ``x`` to the nearest integer value.

    Halfway cases round away from zero (``round(2.5) == 3.0``,
    ``round(-2.5) == -3.0``). This differs from ``numpy.round`` and ``wp.rint``.
    """

def trunc(x: Float) -> Float:
    """Truncate ``x`` to an integer value, rounding towards zero."""

def frac(x: Float) -> Float:
    """
    Return the fractional part of ``x``, i.e. ``x - trunc(x)``.

    The result keeps the sign of ``x``: ``frac(1.25) == 0.25`` but
    ``frac(-1.25) == -0.25``. The formula ``x - floor(x)`` would
    instead give 0.75 for ``-1.25``.
    """

def fmod(x: Float, y: Float) -> Float:
    """Floating point remainder of ``x / y``."""

def clamp(x: Scalar, a: Scalar, b: Scalar) -> Scalar:
    """Clamp ``x`` to the range ``[a, b]``."""

def lerp(a: Float, b: Float, t: Float) -> Float:
    """Linear interpolation: ``a + t*(b-a)`` (``t = 0`` gives ``a``, ``t = 1`` gives ``b``)."""

Vector Mathematics

Built-in functions for vector operations.

def length(v: Vector) -> Float:
    """Euclidean length (magnitude) of ``v``."""

def length_sq(v: Vector) -> Float:
    """Squared length of ``v``; cheaper than ``length`` when only comparisons are needed."""

def normalize(v: Vector) -> Vector:
    """Unit-length vector pointing in the same direction as ``v``."""

def dot(a: Vector, b: Vector) -> Float:
    """Dot (inner) product of ``a`` and ``b``."""

def cross(a: vec3, b: vec3) -> vec3:
    """Cross product of ``a`` and ``b``; defined for 3D vectors only."""

def distance(a: Vector, b: Vector) -> Float:
    """Euclidean distance between points ``a`` and ``b``."""

def distance_sq(a: Vector, b: Vector) -> Float:
    """Squared distance between ``a`` and ``b``; cheaper than ``distance``."""

def reflect(v: Vector, n: Vector) -> Vector:
    """Mirror ``v`` about the surface with normal ``n``."""

def refract(v: Vector, n: Vector, eta: Float) -> Vector:
    """Bend ``v`` through a surface with normal ``n`` using refractive index ``eta``."""

Matrix Operations

Built-in functions for matrix mathematics.

def mul(a: Matrix, b: Matrix) -> Matrix:
    """Product of matrices ``a`` and ``b``."""

def mul(m: Matrix, v: Vector) -> Vector:
    """Product of matrix ``m`` with vector ``v``."""

def transpose(m: Matrix) -> Matrix:
    """Transpose of ``m`` (rows and columns swapped)."""

def inverse(m: Matrix) -> Matrix:
    """Inverse of ``m``; only meaningful when ``m`` is invertible."""

def determinant(m: Matrix) -> Float:
    """Determinant of ``m``."""

def trace(m: Matrix) -> Float:
    """Trace of ``m``: the sum of its diagonal entries."""

Quaternion Operations

Built-in functions for quaternion mathematics.

def quat_identity() -> quat:
    """Quaternion representing no rotation."""

def quat_from_axis_angle(axis: vec3, angle: Float) -> quat:
    """Build a quaternion that rotates by ``angle`` about ``axis``."""

def quat_to_axis_angle(q: quat) -> tuple[vec3, Float]:
    """
    Decompose ``q`` into its rotation axis and angle.

    NOTE(review): the Warp built-ins reference lists this as taking output
    arguments (q, axis, angle) rather than returning a tuple; confirm.
    """

def quat_mul(a: quat, b: quat) -> quat:
    """Compose two rotations by multiplying quaternions ``a`` and ``b``."""

def quat_rotate(q: quat, v: vec3) -> vec3:
    """Apply the rotation ``q`` to vector ``v``."""

def quat_inverse(q: quat) -> quat:
    """Inverse of quaternion ``q``."""

def quat_conjugate(q: quat) -> quat:
    """Conjugate of quaternion ``q``."""

def quat_normalize(q: quat) -> quat:
    """Scale ``q`` to unit length."""

def quat_slerp(a: quat, b: quat, t: Float) -> quat:
    """Spherical linear interpolation from ``a`` to ``b`` by factor ``t``."""

Atomic Operations

Thread-safe read-modify-write operations on array elements that several threads may update concurrently.

def atomic_add(arr: array, index: int, value: Scalar) -> Scalar:
    """
    Atomically add ``value`` to ``arr[index]`` and return the old value.

    ``value`` must match the array's dtype. Vector dtypes such as ``wp.vec3``
    are supported as well as scalars.
    """

def atomic_sub(arr: array, index: int, value: Scalar) -> Scalar:
    """Atomically subtract ``value`` from ``arr[index]`` and return the old value."""

def atomic_min(arr: array, index: int, value: Scalar) -> Scalar:
    """
    Atomically set ``arr[index]`` to ``min(arr[index], value)`` and return
    the old value.

    For vector and matrix dtypes the operation is atomic per component only.
    """

def atomic_max(arr: array, index: int, value: Scalar) -> Scalar:
    """
    Atomically set ``arr[index]`` to ``max(arr[index], value)`` and return
    the old value.

    For vector and matrix dtypes the operation is atomic per component only.
    """

def atomic_cas(arr: array, index: int, compare: Scalar, value: Scalar) -> Scalar:
    """
    Atomic compare-and-swap: if ``arr[index] == compare``, store ``value``.

    Always returns the value held by ``arr[index]`` before the call. The
    swap therefore succeeded exactly when the returned value equals ``compare``.
    """

Geometric Built-ins

Functions for spatial queries and geometric computations.

def mesh_query_point(mesh: uint64, point: vec3, max_dist: Float) -> MeshQueryPoint:
    """
    Find the closest point on a mesh to ``point``, searching up to ``max_dist``.

    Inside a kernel the mesh is passed by id: declare the argument as
    ``wp.uint64`` and launch with ``mesh.id``. Check ``result`` on the
    returned query before using ``face``/``u``/``v``. ``sign`` is negative
    when ``point`` lies inside the mesh.
    """

def mesh_query_ray(mesh: uint64, start: vec3, dir: vec3, max_t: Float) -> MeshQueryRay:
    """
    Cast a ray from ``start`` along ``dir`` and find the closest hit whose
    ray parameter ``t`` does not exceed ``max_t``.

    ``mesh`` is the mesh id (``Mesh.id``). Check ``result`` on the returned
    query to see whether anything was hit.
    """

def mesh_query_aabb(mesh: uint64, lower: vec3, upper: vec3) -> MeshQueryAABB:
    """
    Start a query for mesh triangles overlapping the box ``[lower, upper]``.

    Iterate the overlapping faces with ``wp.mesh_query_aabb_next``::

        query = wp.mesh_query_aabb(mesh_id, lower, upper)
        face = int(0)
        while wp.mesh_query_aabb_next(query, face):
            ...
    """

def bvh_query_ray(bvh: uint64, start: vec3, dir: vec3) -> BvhQuery:
    """
    Start a ray query against a bounding volume hierarchy (``Bvh.id``).

    Iterate the candidate bounds with ``wp.bvh_query_next(query, index)``.
    """

def hash_grid_query(grid: uint64, point: vec3, radius: Float) -> HashGridQuery:
    """
    Start a neighbor query around ``point`` in a spatial hash grid (``HashGrid.id``).

    The grid must be built (``grid.build(points, radius)``) before launch.
    Iterate candidates with ``wp.hash_grid_query_next(query, index)``.
    Candidates come from nearby cells and may lie farther away than
    ``radius``, so compare the actual distance inside the loop.
    """

def volume_sample_f(volume: uint64, uvw: vec3, sampling_mode: int) -> Float:
    """
    Sample a scalar value from a volume (``Volume.id``).

    ``uvw`` is in volume-local (index-space) coordinates. Convert world
    positions with ``wp.volume_world_to_index``. ``sampling_mode`` is
    ``wp.Volume.CLOSEST`` or ``wp.Volume.LINEAR``.
    """

def volume_sample_v(volume: uint64, uvw: vec3, sampling_mode: int) -> vec3:
    """Sample a vector value from a volume; arguments as for ``volume_sample_f``."""

Usage Examples

Basic Kernel Structure

import warp as wp

# Example kernel: one thread per particle. Each thread writes a damping force
# into forces[i] and advances positions[i] by vel * dt. Velocities are only
# read here; the computed force is not applied to them in this kernel.
# Presumably launched with dim == number of particles -- confirm at call site.
@wp.kernel
def compute_forces(positions: wp.array(dtype=wp.vec3),
                  velocities: wp.array(dtype=wp.vec3), 
                  forces: wp.array(dtype=wp.vec3),
                  mass: float,
                  dt: float):
    # Get thread ID (the particle index this thread owns)
    i = wp.tid()
    
    # Each thread processes one particle. It reads and writes only index i,
    # so plain (non-atomic) stores are sufficient here.
    pos = positions[i]
    vel = velocities[i]
    
    # Compute force (simple damping): opposite to velocity, scaled by mass and
    # a hard-coded coefficient of 0.1. forces[i] is overwritten, not accumulated.
    force = -0.1 * vel * mass
    forces[i] = force
    
    # Update position using current velocity
    positions[i] = pos + vel * dt

Function Calls Within Kernels

@wp.func
def spring_force(pos_a: wp.vec3, pos_b: wp.vec3, 
                rest_length: float, stiffness: float) -> wp.vec3:
    """Compute the linear spring force acting on ``pos_a``.

    The magnitude is ``stiffness * (dist - rest_length)``. It is applied along
    the negated unit vector from ``pos_b`` to ``pos_a``. With
    ``stiffness > 0``, a stretched spring pulls ``pos_a`` toward ``pos_b``
    and a compressed one pushes it away. The force on ``pos_b`` is the
    negation of the returned value.

    Returns the zero vector when the points coincide, which avoids
    dividing by a zero length.
    """
    diff = pos_a - pos_b
    dist = wp.length(diff)
    
    if dist > 0.0:
        force_mag = stiffness * (dist - rest_length)  # signed: > 0 when stretched
        force_dir = diff / dist  # unit vector from pos_b toward pos_a
        return -force_mag * force_dir
    else:
        return wp.vec3(0.0, 0.0, 0.0)

# Example kernel: one thread per spring i.
# - spring_indices stores endpoint pairs flattened as [a0, b0, a1, b1, ...].
# - rest_lengths[i] is the rest length of spring i.
# - forces is accumulated into, not overwritten, so it presumably needs to be
#   zeroed before launch -- confirm at the call site.
@wp.kernel  
def update_springs(positions: wp.array(dtype=wp.vec3),
                  forces: wp.array(dtype=wp.vec3),
                  spring_indices: wp.array(dtype=wp.int32),
                  rest_lengths: wp.array(dtype=float),
                  stiffness: float):
    i = wp.tid()
    
    # Get spring endpoints
    idx_a = spring_indices[i * 2]
    idx_b = spring_indices[i * 2 + 1]
    
    pos_a = positions[idx_a] 
    pos_b = positions[idx_b]
    rest_len = rest_lengths[i]
    
    # Compute spring force using helper function (force acting on endpoint a)
    force = spring_force(pos_a, pos_b, rest_len, stiffness)
    
    # Apply forces (use atomic operations for thread safety): different springs
    # can reference the same particle index, so several threads may add into
    # the same forces[idx] concurrently. Endpoint b receives the opposite force.
    wp.atomic_add(forces, idx_a, force)
    wp.atomic_add(forces, idx_b, -force)

Struct Types

# Per-particle state record, usable as an array dtype (wp.array(dtype=Particle)).
@wp.struct
class Particle:
    position: wp.vec3  # current position
    velocity: wp.vec3  # current velocity
    mass: float        # particle mass (not read by the update_particles example)
    radius: float      # particle radius (not read by the update_particles example)

# Example kernel: one thread per particle, advancing it by one explicit Euler
# step under constant gravity. The particle is read into p, modified, and
# stored back explicitly.
@wp.kernel
def update_particles(particles: wp.array(dtype=Particle), dt: float):
    i = wp.tid()
    p = particles[i]
    
    # Update position (uses the velocity from before this step's gravity update)
    p.position = p.position + p.velocity * dt
    
    # Apply gravity: constant acceleration along -y (presumably m/s^2 -- confirm units)
    gravity = wp.vec3(0.0, -9.81, 0.0)
    p.velocity = p.velocity + gravity * dt
    
    # Write back modified particle
    particles[i] = p

Types

class Kernel:
    """Kernel object returned by the ``@wp.kernel`` decorator."""

    def __call__(self, *args, **kwargs):
        """Launch this kernel; behaves the same as ``wp.launch``."""

class Function:
    """Function object returned by ``@wp.func``; callable from inside kernels."""

    def __call__(self, *args, **kwargs):
        """Invoke the function."""

# Query result types
class MeshQueryPoint:
    """Result of ``mesh_query_point``; check ``result`` before using the other fields."""
    result: bool     # True if a closest point was found within max_dist
    face: int        # Index of the closest triangle
    u: float         # Barycentric coordinate of the closest point
    v: float         # Barycentric coordinate of the closest point
    sign: float      # Inside/outside indicator: negative when the query point is inside

class MeshQueryRay:
    """Result of ``mesh_query_ray``; check ``result`` before using the other fields."""
    result: bool     # True if the ray hit the mesh within max_t
    face: int        # Index of the hit triangle
    u: float         # Barycentric coordinate of the hit point
    v: float         # Barycentric coordinate of the hit point
    t: float         # Ray parameter at intersection (hit = start + t * dir)
    sign: float      # Positive for a front-face hit, negative for a back-face hit
    normal: "vec3"   # Normal of the hit triangle

class MeshQueryAABB:
    """
    Iteration state for ``mesh_query_aabb``.

    Advance with ``wp.mesh_query_aabb_next(query, face_index)``, which returns
    False once no overlapping triangles remain.
    """
    face: int        # Next overlapping triangle (-1 when done)

class BvhQuery:
    """BVH ray query state; advance with ``wp.bvh_query_next(query, index)``."""

class HashGridQuery:
    """Hash grid neighbor query state; advance with ``wp.hash_grid_query_next(query, index)``."""

Install with Tessl CLI

npx tessl i tessl/pypi-warp-lang

docs

core-execution.md

fem.md

framework-integration.md

index.md

kernel-programming.md

optimization.md

rendering.md

types-arrays.md

utilities.md

tile.json