tessl/pypi-cupy-cuda11x

CuPy: NumPy & SciPy for GPU - CUDA 11.x optimized distribution providing GPU-accelerated computing with Python


docs/jit-compilation.md

JIT Compilation

CuPy provides just-in-time (JIT) compilation capabilities through the cupyx.jit module, enabling the compilation of Python functions directly to GPU kernels. This allows developers to write GPU code in Python syntax while achieving near-native CUDA performance through automatic kernel generation and optimization.

Capabilities

JIT Function Decoration

Transform Python functions into GPU kernels using decorators for automatic compilation and execution.

def rawkernel(device=False):
    """
    Decorator to compile a Python function into a raw CUDA kernel.
    
    The decorated function is compiled to CUDA C++ and can be launched
    with explicit grid and block dimensions.
    
    Parameters:
        device: bool, optional - If True, compile as device function
    """

def kernel(grid=None, block=None, shared_mem=0):
    """
    Decorator to compile and launch a Python function as a CUDA kernel.
    
    Automatically handles kernel launch parameters and provides a more
    convenient interface for simple kernels.
    
    Parameters:
        grid: tuple, optional - Grid dimensions (blocks per grid)
        block: tuple, optional - Block dimensions (threads per block)  
        shared_mem: int, optional - Shared memory size in bytes
    """

def elementwise(signature):
    """
    Decorator to create element-wise kernels from Python functions.
    
    The decorated function is applied element-wise across input arrays
    with automatic broadcasting and type handling.
    
    Parameters:
        signature: str - Function signature describing input/output types
    """

def reduction(signature, identity=None):
    """
    Decorator to create reduction kernels from Python functions.
    
    The decorated function performs reduction operations across array
    dimensions with automatic handling of reduction strategies.
    
    Parameters:
        signature: str - Function signature for reduction operation
        identity: scalar, optional - Identity value for the reduction
    """

Thread and Block Primitives

Access CUDA thread and block indexing primitives within JIT-compiled functions.

def threadIdx():
    """Get the current thread index within a block."""

def blockIdx():
    """Get the current block index within the grid."""

def blockDim():
    """Get the dimensions of the current block."""

def gridDim():
    """Get the dimensions of the current grid."""

def thread_id():
    """Get the global thread ID."""

def warp_id():
    """Get the current warp ID within a block."""

def lane_id():
    """Get the lane ID within the current warp."""

Synchronization Primitives

Synchronization functions for coordinating between threads and blocks.

def syncthreads():
    """Synchronize all threads within a block."""

def syncwarp():
    """Synchronize threads within a warp."""

def __syncthreads():
    """CUDA __syncthreads() primitive."""

def __syncwarp(mask=0xffffffff):
    """CUDA __syncwarp() primitive with optional mask."""

Memory Operations

Memory access patterns and shared memory management within JIT kernels.

def shared_memory(shape, dtype):
    """
    Allocate shared memory within a kernel.
    
    Parameters:
        shape: tuple - Shape of the shared memory array
        dtype: data-type - Data type of elements
    """

def local_memory(shape, dtype):
    """
    Allocate local (register) memory within a kernel.
    
    Parameters:
        shape: tuple - Shape of the local memory array
        dtype: data-type - Data type of elements
    """

def atomic_add(array, index, value):
    """
    Atomic addition operation.
    
    Parameters:
        array: array_like - Target array
        index: int - Index to update
        value: scalar - Value to add
    """

def atomic_sub(array, index, value):
    """
    Atomic subtraction operation.
    
    Parameters:
        array: array_like - Target array
        index: int - Index to update  
        value: scalar - Value to subtract
    """

def atomic_max(array, index, value):
    """
    Atomic maximum operation.
    
    Parameters:
        array: array_like - Target array
        index: int - Index to update
        value: scalar - Value to compare
    """

def atomic_min(array, index, value):
    """
    Atomic minimum operation.
    
    Parameters:
        array: array_like - Target array
        index: int - Index to update
        value: scalar - Value to compare
    """

def atomic_cas(array, index, compare, value):
    """
    Atomic compare-and-swap operation.
    
    Parameters:
        array: array_like - Target array
        index: int - Index to update
        compare: scalar - Expected value
        value: scalar - New value if comparison succeeds
    """

Control Flow and Utilities

Control flow constructs and utility functions for JIT compilation.

def if_then_else(condition, if_true, if_false):
    """
    Conditional expression for JIT compilation.
    
    Parameters:
        condition: bool expression - Condition to evaluate
        if_true: expression - Value/expression if condition is True
        if_false: expression - Value/expression if condition is False
    """

def while_loop(condition, body):
    """
    While loop construct for JIT compilation.
    
    Parameters:
        condition: callable - Function returning loop condition
        body: callable - Function containing loop body
    """

def for_loop(start, stop, step, body):
    """
    For loop construct for JIT compilation.
    
    Parameters:
        start: int - Loop start value
        stop: int - Loop end value (exclusive)
        step: int - Loop increment
        body: callable - Function containing loop body
    """

def unroll(n):
    """
    Decorator to unroll loops for performance optimization.
    
    Parameters:
        n: int - Number of iterations to unroll
    """

Mathematical Functions

Mathematical functions optimized for JIT compilation and GPU execution.

def sqrt(x):
    """Square root function for JIT kernels."""

def exp(x):
    """Exponential function for JIT kernels."""

def log(x):
    """Natural logarithm function for JIT kernels."""

def sin(x):
    """Sine function for JIT kernels."""

def cos(x):
    """Cosine function for JIT kernels."""

def tan(x):
    """Tangent function for JIT kernels."""

def pow(x, y):
    """Power function for JIT kernels."""

def abs(x):
    """Absolute value function for JIT kernels."""

def min(x, y):
    """Minimum function for JIT kernels."""

def max(x, y):
    """Maximum function for JIT kernels."""

Type System

Type specification and casting functions for JIT compilation.

def cast(value, dtype):
    """
    Cast value to specified data type.
    
    Parameters:
        value: scalar or array - Value to cast
        dtype: data-type - Target data type
    """

class float32:
    """32-bit floating point type for JIT."""

class float64:
    """64-bit floating point type for JIT."""

class int32:
    """32-bit signed integer type for JIT."""

class int64:
    """64-bit signed integer type for JIT."""

class uint32:
    """32-bit unsigned integer type for JIT."""

class uint64:
    """64-bit unsigned integer type for JIT."""

class bool:
    """Boolean type for JIT."""

Usage Examples

Basic JIT Kernel

import cupy as cp
from cupyx import jit

# Simple element-wise kernel using JIT
@jit.rawkernel()
def add_kernel(x, y, z, n):
    """Add two arrays element-wise."""
    tid = jit.thread_id()
    if tid < n:
        z[tid] = x[tid] + y[tid]

# Create input arrays
n = 1000000
x = cp.random.rand(n, dtype=cp.float32)
y = cp.random.rand(n, dtype=cp.float32)
z = cp.zeros(n, dtype=cp.float32)

# Launch kernel
threads_per_block = 256
blocks_per_grid = (n + threads_per_block - 1) // threads_per_block
add_kernel[blocks_per_grid, threads_per_block](x, y, z, n)

print("Result:", z[:10])

Element-wise JIT Function

# Element-wise function with automatic broadcasting
@jit.elementwise('T x, T y -> T')
def fused_operation(x, y):
    """Fused mathematical operation."""
    temp = x * x + y * y
    return jit.sqrt(temp) + jit.sin(x) * jit.cos(y)

# Use like a regular CuPy function
a = cp.linspace(0, 2*cp.pi, 1000000)
b = cp.linspace(0, cp.pi, 1000000)
result = fused_operation(a, b)

print("Element-wise result shape:", result.shape)
print("Sample values:", result[:5])

Reduction JIT Kernel

# Custom reduction operation
@jit.reduction('T x -> T', identity=0)
def sum_of_squares(x):
    """Compute sum of squares reduction."""
    return x * x

# Apply reduction
data = cp.array([1, 2, 3, 4, 5], dtype=cp.float32)
result = sum_of_squares(data)
print("Sum of squares:", result)

# Complex reduction with multiple operations
@jit.reduction('T x, T y -> T', identity=0)
def weighted_sum(x, y):
    """Compute weighted sum."""
    return x * y

weights = cp.array([0.1, 0.2, 0.3, 0.4, 0.5])
values = cp.array([10, 20, 30, 40, 50])
weighted_result = weighted_sum(values, weights)
print("Weighted sum:", weighted_result)

Shared Memory Example

@jit.rawkernel()
def matrix_transpose_shared(input_matrix, output_matrix, width, height):
    """Matrix transpose using shared memory."""
    # Allocate shared memory tile
    TILE_SIZE = 32
    tile = jit.shared_memory((TILE_SIZE, TILE_SIZE), cp.float32)
    
    # Calculate thread coordinates
    x = jit.blockIdx().x * jit.blockDim().x + jit.threadIdx().x
    y = jit.blockIdx().y * jit.blockDim().y + jit.threadIdx().y
    
    # Load data into shared memory
    if x < width and y < height:
        tile[jit.threadIdx().y, jit.threadIdx().x] = input_matrix[y, x]
    
    # Synchronize threads
    jit.syncthreads()
    
    # Calculate transposed coordinates
    tx = jit.blockIdx().y * jit.blockDim().y + jit.threadIdx().x
    ty = jit.blockIdx().x * jit.blockDim().x + jit.threadIdx().y
    
    # Write to output (transposed)
    if tx < height and ty < width:
        output_matrix[ty, tx] = tile[jit.threadIdx().x, jit.threadIdx().y]

# Test matrix transpose
input_mat = cp.random.rand(1024, 1024, dtype=cp.float32)
output_mat = cp.zeros((1024, 1024), dtype=cp.float32)

# Launch with 2D grid
block_size = (32, 32)
grid_size = (
    (input_mat.shape[1] + block_size[0] - 1) // block_size[0],
    (input_mat.shape[0] + block_size[1] - 1) // block_size[1]
)

matrix_transpose_shared[grid_size, block_size](
    input_mat, output_mat, input_mat.shape[1], input_mat.shape[0]
)

# Verify correctness
expected = input_mat.T
print("Transpose correct:", cp.allclose(output_mat, expected))

Atomic Operations Example

@jit.rawkernel()
def histogram_kernel(data, histogram, n_bins, data_size):
    """Compute histogram using atomic operations."""
    tid = jit.thread_id()
    
    if tid < data_size:
        # Calculate bin index
        bin_idx = jit.cast(data[tid] * n_bins, jit.int32)
        bin_idx = jit.min(bin_idx, n_bins - 1)  # Clamp to valid range
        
        # Atomic increment
        jit.atomic_add(histogram, bin_idx, 1)

# Generate random data
data = cp.random.rand(1000000, dtype=cp.float32)
histogram = cp.zeros(100, dtype=cp.int32)

# Launch histogram kernel
threads = 256
blocks = (data.size + threads - 1) // threads
histogram_kernel[blocks, threads](data, histogram, 100, data.size)

print("Histogram bins:", histogram[:10])
print("Total count:", cp.sum(histogram))

Advanced Control Flow

@jit.rawkernel()
def complex_algorithm(input_data, output_data, threshold, size):
    """Complex algorithm with control flow."""
    tid = jit.thread_id()
    
    if tid >= size:
        return
    
    value = input_data[tid]
    result = 0.0
    
    # Complex conditional logic
    if value > threshold:
        # Iterative computation
        for i in range(10):
            result += jit.sin(value * i) * jit.exp(-i * 0.1)
    else:
        # Alternative computation
        temp = jit.sqrt(jit.abs(value))
        result = temp * jit.cos(temp)
    
    output_data[tid] = result

# Test complex algorithm
input_arr = cp.random.randn(100000).astype(cp.float32)
output_arr = cp.zeros_like(input_arr)
threshold = 0.5

threads = 256
blocks = (input_arr.size + threads - 1) // threads
complex_algorithm[blocks, threads](input_arr, output_arr, threshold, input_arr.size)

print("Complex algorithm results:", output_arr[:10])

Performance Optimization Examples

# Loop unrolling for performance
@jit.rawkernel()
def optimized_convolution(input_data, kernel, output_data, width, height, kernel_size):
    """Optimized 2D convolution with loop unrolling."""
    x = jit.blockIdx().x * jit.blockDim().x + jit.threadIdx().x
    y = jit.blockIdx().y * jit.blockDim().y + jit.threadIdx().y
    
    if x >= width or y >= height:
        return
    
    result = 0.0
    half_kernel = kernel_size // 2
    
    # Manual loop unrolling for small kernels
    if kernel_size == 3:
        for dy in range(-1, 2):
            for dx in range(-1, 2):
                px = jit.max(0, jit.min(width - 1, x + dx))
                py = jit.max(0, jit.min(height - 1, y + dy))
                kernel_idx = (dy + 1) * 3 + (dx + 1)
                result += input_data[py, px] * kernel[kernel_idx]
    else:
        # General case
        for dy in range(-half_kernel, half_kernel + 1):
            for dx in range(-half_kernel, half_kernel + 1):
                px = jit.max(0, jit.min(width - 1, x + dx))
                py = jit.max(0, jit.min(height - 1, y + dy))
                kernel_idx = (dy + half_kernel) * kernel_size + (dx + half_kernel)
                result += input_data[py, px] * kernel[kernel_idx]
    
    output_data[y, x] = result

# Vectorized operations for better performance
@jit.elementwise('T x, T y, T z, T w -> T')
def vectorized_operation(x, y, z, w):
    """Vectorized computation using multiple inputs."""
    temp1 = x * y + z * w
    temp2 = jit.sqrt(temp1 * temp1 + 1.0)
    return jit.sin(temp2) * jit.exp(-temp2 * 0.1)

# Test vectorized operation
a = cp.random.rand(1000000)
b = cp.random.rand(1000000)
c = cp.random.rand(1000000)
d = cp.random.rand(1000000)

result = vectorized_operation(a, b, c, d)
print("Vectorized result sample:", result[:5])

Multi-dimensional Indexing

@jit.rawkernel()
def multi_dim_kernel(input_3d, output_3d, depth, height, width):
    """3D array processing with multi-dimensional indexing."""
    # 3D thread indexing
    x = jit.blockIdx().x * jit.blockDim().x + jit.threadIdx().x
    y = jit.blockIdx().y * jit.blockDim().y + jit.threadIdx().y
    z = jit.blockIdx().z * jit.blockDim().z + jit.threadIdx().z
    
    if x >= width or y >= height or z >= depth:
        return
    
    # Access neighboring elements in 3D
    result = 0.0
    count = 0
    
    for dz in range(-1, 2):
        for dy in range(-1, 2):
            for dx in range(-1, 2):
                nz = jit.max(0, jit.min(depth - 1, z + dz))
                ny = jit.max(0, jit.min(height - 1, y + dy))
                nx = jit.max(0, jit.min(width - 1, x + dx))
                
                result += input_3d[nz, ny, nx]
                count += 1
    
    # Average of neighborhood
    output_3d[z, y, x] = result / count

# Test 3D processing
input_3d = cp.random.rand(64, 256, 256, dtype=cp.float32)
output_3d = cp.zeros_like(input_3d)

# 3D grid launch
block_3d = (8, 16, 16)
grid_3d = (
    (input_3d.shape[2] + block_3d[0] - 1) // block_3d[0],
    (input_3d.shape[1] + block_3d[1] - 1) // block_3d[1],
    (input_3d.shape[0] + block_3d[2] - 1) // block_3d[2]
)

multi_dim_kernel[grid_3d, block_3d](
    input_3d, output_3d, 
    input_3d.shape[0], input_3d.shape[1], input_3d.shape[2]
)

print("3D processing completed, output shape:", output_3d.shape)

Error Handling and Debugging

# Debugging with conditional compilation
@jit.rawkernel()
def debug_kernel(data, output, size, debug_flag):
    """Kernel with debugging capabilities."""
    tid = jit.thread_id()
    
    # Bounds checking: out-of-range threads exit immediately (writing a
    # debug flag here would itself be an out-of-bounds access)
    if tid >= size:
        return
    
    value = data[tid]
    
    # NaN/Inf checking
    if jit.isnan(value) or jit.isinf(value):
        if debug_flag:
            output[tid] = -888.0
        else:
            output[tid] = 0.0
        return
    
    # Normal computation
    result = jit.sqrt(jit.abs(value)) + jit.sin(value)
    output[tid] = result

# Function composition and modularity
@jit.rawkernel()
def modular_computation():
    """Example of modular JIT kernel design."""
    
    def compute_step1(x, y):
        return x * y + jit.sin(x)
    
    def compute_step2(intermediate):
        return jit.sqrt(jit.abs(intermediate))
    
    def compute_step3(x, step2_result):
        return step2_result * jit.exp(-x * 0.1)
    
    # Main kernel logic using helper functions
    tid = jit.thread_id()
    # ... kernel implementation using helper functions

JIT compilation in CuPy bridges Python productivity and GPU performance: developers write complex GPU algorithms in familiar Python syntax, and cupyx.jit compiles them to CUDA kernels that run at near-native speed.

Install with Tessl CLI

npx tessl i tessl/pypi-cupy-cuda11x
