# JIT Compilation

CuPy provides just-in-time (JIT) compilation through the `cupyx.jit` module, which compiles Python functions directly into GPU kernels. Developers write kernel code in familiar Python syntax, and the module generates and compiles the corresponding CUDA code automatically, yielding performance close to hand-written CUDA kernels.
## Capabilities

### JIT Function Decoration

Transform Python functions into GPU kernels using decorators for automatic compilation and execution.
```python { .api }
def rawkernel(device=False):
    """
    Decorator to compile a Python function into a raw CUDA kernel.

    The decorated function is compiled to CUDA C++ and can be launched
    with explicit grid and block dimensions.

    Parameters:
        device: bool, optional - If True, compile as a device function
    """

def kernel(grid=None, block=None, shared_mem=0):
    """
    Decorator to compile and launch a Python function as a CUDA kernel.

    Automatically handles kernel launch parameters and provides a more
    convenient interface for simple kernels.

    Parameters:
        grid: tuple, optional - Grid dimensions (blocks per grid)
        block: tuple, optional - Block dimensions (threads per block)
        shared_mem: int, optional - Shared memory size in bytes
    """

def elementwise(signature):
    """
    Decorator to create element-wise kernels from Python functions.

    The decorated function is applied element-wise across input arrays
    with automatic broadcasting and type handling.

    Parameters:
        signature: str - Function signature describing input/output types
    """

def reduction(signature, identity=None):
    """
    Decorator to create reduction kernels from Python functions.

    The decorated function performs reduction operations across array
    dimensions with automatic handling of reduction strategies.

    Parameters:
        signature: str - Function signature for the reduction operation
        identity: scalar, optional - Identity value for the reduction
    """
```
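The `device=True` option is useful for helper routines that are called from other JIT kernels rather than launched on their own. The following is a minimal sketch that assumes the decorator signatures listed above and the `jit.thread_id()` indexing helper described later in this section; the kernel and helper names are illustrative, not part of the library API.

```python
import cupy as cp
from cupyx import jit

# Helper compiled as a device function: callable from JIT kernels, not launchable itself
@jit.rawkernel(device=True)
def squared_distance(a, b):
    d = a - b
    return d * d

@jit.rawkernel()
def pairwise_kernel(x, y, out, n):
    tid = jit.thread_id()  # global thread index (see "Thread and Block Primitives")
    if tid < n:
        out[tid] = squared_distance(x[tid], y[tid])

n = 1 << 20
x = cp.random.rand(n, dtype=cp.float32)
y = cp.random.rand(n, dtype=cp.float32)
out = cp.empty_like(x)
pairwise_kernel[(n + 255) // 256, 256](x, y, out, n)
```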
### Thread and Block Primitives

Access CUDA thread and block indexing primitives within JIT-compiled functions.
```python { .api }
64
def threadIdx():
65
"""Get the current thread index within a block."""
66
67
def blockIdx():
68
"""Get the current block index within the grid."""
69
70
def blockDim():
71
"""Get the dimensions of the current block."""
72
73
def gridDim():
74
"""Get the dimensions of the current grid."""
75
76
def thread_id():
77
"""Get the global thread ID."""
78
79
def warp_id():
80
"""Get the current warp ID within a block."""
81
82
def lane_id():
83
"""Get the lane ID within the current warp."""
84
```
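These primitives combine in the usual CUDA fashion to give each thread its own coordinates. The sketch below assumes the accessor style listed above (`jit.threadIdx().x` and friends) and simply records each thread's flat position in a 2D array; the kernel name is illustrative.

```python
import cupy as cp
from cupyx import jit

@jit.rawkernel()
def index_demo(out, width, height):
    # 2D coordinates derived from block and thread indices
    col = jit.blockIdx().x * jit.blockDim().x + jit.threadIdx().x
    row = jit.blockIdx().y * jit.blockDim().y + jit.threadIdx().y
    if col < width and row < height:
        # Store the flat index so the thread-to-element mapping can be inspected
        out[row, col] = row * width + col

out = cp.zeros((64, 128), dtype=cp.int32)
index_demo[(128 // 16, 64 // 16), (16, 16)](out, 128, 64)
```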
### Synchronization Primitives

Synchronization functions for coordinating threads within a warp or block.
```python { .api }
def syncthreads():
    """Synchronize all threads within a block."""

def syncwarp():
    """Synchronize threads within a warp."""

def __syncthreads():
    """CUDA __syncthreads() primitive."""

def __syncwarp(mask=0xffffffff):
    """CUDA __syncwarp() primitive with optional mask."""
```
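A typical use is a write-then-read pattern on shared memory: every thread stages a value, the block synchronizes, and one thread consumes the staged data. The sketch below assumes the `jit.syncthreads()` and `jit.shared_memory()` signatures documented in this section (shared memory is covered next) and a fixed block size of 256 threads; the kernel name is illustrative.

```python
import cupy as cp
from cupyx import jit

@jit.rawkernel()
def block_sum(x, block_sums, n):
    # Per-block staging buffer in shared memory (see "Memory Operations" below)
    buf = jit.shared_memory((256,), cp.float32)
    tid = jit.threadIdx().x
    gid = jit.blockIdx().x * jit.blockDim().x + tid

    if gid < n:
        buf[tid] = x[gid]
    else:
        buf[tid] = 0.0

    # All writes to buf must be visible before thread 0 reads the whole buffer
    jit.syncthreads()

    if tid == 0:
        total = 0.0
        for i in range(256):
            total += buf[i]
        block_sums[jit.blockIdx().x] = total

n = 1 << 20
x = cp.random.rand(n, dtype=cp.float32)
blocks = (n + 255) // 256
partial = cp.zeros(blocks, dtype=cp.float32)
block_sum[blocks, 256](x, partial, n)
print("Total:", partial.sum())
```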
### Memory Operations

Memory access patterns and shared memory management within JIT kernels.
```python { .api }
def shared_memory(shape, dtype):
    """
    Allocate shared memory within a kernel.

    Parameters:
        shape: tuple - Shape of the shared memory array
        dtype: data-type - Data type of elements
    """

def local_memory(shape, dtype):
    """
    Allocate local (register) memory within a kernel.

    Parameters:
        shape: tuple - Shape of the local memory array
        dtype: data-type - Data type of elements
    """

def atomic_add(array, index, value):
    """
    Atomic addition operation.

    Parameters:
        array: array_like - Target array
        index: int - Index to update
        value: scalar - Value to add
    """

def atomic_sub(array, index, value):
    """
    Atomic subtraction operation.

    Parameters:
        array: array_like - Target array
        index: int - Index to update
        value: scalar - Value to subtract
    """

def atomic_max(array, index, value):
    """
    Atomic maximum operation.

    Parameters:
        array: array_like - Target array
        index: int - Index to update
        value: scalar - Value to compare
    """

def atomic_min(array, index, value):
    """
    Atomic minimum operation.

    Parameters:
        array: array_like - Target array
        index: int - Index to update
        value: scalar - Value to compare
    """

def atomic_cas(array, index, compare, value):
    """
    Atomic compare-and-swap operation.

    Parameters:
        array: array_like - Target array
        index: int - Index to update
        compare: scalar - Expected value
        value: scalar - New value if comparison succeeds
    """
```
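As a small illustration of per-thread scratch space, the sketch below uses `local_memory` (with the shape/dtype argument order listed above) to hold a three-element window around each input position. The kernel name and the windowed-maximum computation are illustrative assumptions, not part of the library API.

```python
import cupy as cp
from cupyx import jit

@jit.rawkernel()
def windowed_max3(x, out, n):
    # Per-thread scratch buffer held in local (register) memory
    window = jit.local_memory((3,), cp.float32)
    tid = jit.thread_id()
    if tid >= 1 and tid < n - 1:
        window[0] = x[tid - 1]
        window[1] = x[tid]
        window[2] = x[tid + 1]
        out[tid] = jit.max(jit.max(window[0], window[1]), window[2])

x = cp.random.rand(1 << 20, dtype=cp.float32)
out = cp.zeros_like(x)
windowed_max3[(x.size + 255) // 256, 256](x, out, x.size)
```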
### Control Flow and Utilities

Control flow constructs and utility functions for JIT compilation.
```python { .api }
def if_then_else(condition, if_true, if_false):
    """
    Conditional expression for JIT compilation.

    Parameters:
        condition: bool expression - Condition to evaluate
        if_true: expression - Value/expression if condition is True
        if_false: expression - Value/expression if condition is False
    """

def while_loop(condition, body):
    """
    While loop construct for JIT compilation.

    Parameters:
        condition: callable - Function returning loop condition
        body: callable - Function containing loop body
    """

def for_loop(start, stop, step, body):
    """
    For loop construct for JIT compilation.

    Parameters:
        start: int - Loop start value
        stop: int - Loop end value (exclusive)
        step: int - Loop increment
        body: callable - Function containing loop body
    """

def unroll(n):
    """
    Decorator to unroll loops for performance optimization.

    Parameters:
        n: int - Number of iterations to unroll
    """
```
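For example, the functional conditional can replace a branch when each thread simply selects between two expressions. This sketch assumes `if_then_else` behaves as listed above and uses the illustrative name `relu_kernel`.

```python
import cupy as cp
from cupyx import jit

@jit.rawkernel()
def relu_kernel(x, out, n):
    tid = jit.thread_id()
    if tid < n:
        # Select x[tid] when it is positive, otherwise 0.0
        out[tid] = jit.if_then_else(x[tid] > 0.0, x[tid], 0.0)

x = cp.random.rand(1000, dtype=cp.float32) - 0.5
out = cp.empty_like(x)
relu_kernel[(x.size + 255) // 256, 256](x, out, x.size)
```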
### Mathematical Functions

Mathematical functions optimized for JIT compilation and GPU execution.
```python { .api }
def sqrt(x):
    """Square root function for JIT kernels."""

def exp(x):
    """Exponential function for JIT kernels."""

def log(x):
    """Natural logarithm function for JIT kernels."""

def sin(x):
    """Sine function for JIT kernels."""

def cos(x):
    """Cosine function for JIT kernels."""

def tan(x):
    """Tangent function for JIT kernels."""

def pow(x, y):
    """Power function for JIT kernels."""

def abs(x):
    """Absolute value function for JIT kernels."""

def min(x, y):
    """Minimum function for JIT kernels."""

def max(x, y):
    """Maximum function for JIT kernels."""
```
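These helpers compose inside a kernel like ordinary scalar math. The sketch below computes a numerically stable softplus, log(1 + exp(x)), written as max(x, 0) + log(1 + exp(-|x|)); it assumes the math helpers listed above and the `jit.thread_id()` indexing helper, and the kernel name is illustrative.

```python
import cupy as cp
from cupyx import jit

@jit.rawkernel()
def softplus_kernel(x, out, n):
    tid = jit.thread_id()
    if tid < n:
        # Stable softplus: avoids overflow in exp() for large positive x
        out[tid] = jit.max(x[tid], 0.0) + jit.log(1.0 + jit.exp(-jit.abs(x[tid])))

x = cp.linspace(-10, 10, 1 << 20, dtype=cp.float32)
out = cp.empty_like(x)
softplus_kernel[(x.size + 255) // 256, 256](x, out, x.size)
```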
### Type System

Type specification and casting functions for JIT compilation.
```python { .api }
def cast(value, dtype):
    """
    Cast value to specified data type.

    Parameters:
        value: scalar or array - Value to cast
        dtype: data-type - Target data type
    """

class float32:
    """32-bit floating point type for JIT."""

class float64:
    """64-bit floating point type for JIT."""

class int32:
    """32-bit signed integer type for JIT."""

class int64:
    """64-bit signed integer type for JIT."""

class uint32:
    """32-bit unsigned integer type for JIT."""

class uint64:
    """64-bit unsigned integer type for JIT."""

class bool:
    """Boolean type for JIT."""
```
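Casting is most often needed when a floating-point expression must become an integer index or code. The sketch below quantizes float values into 32-bit integer codes using `cast` with the `int32` type listed above; the kernel name and scale parameter are illustrative assumptions.

```python
import cupy as cp
from cupyx import jit

@jit.rawkernel()
def quantize_kernel(x, codes, scale, n):
    tid = jit.thread_id()
    if tid < n:
        # Scale the float value and truncate it to a 32-bit integer code
        codes[tid] = jit.cast(x[tid] * scale, jit.int32)

x = cp.random.rand(1 << 20, dtype=cp.float32)
codes = cp.zeros(x.size, dtype=cp.int32)
quantize_kernel[(x.size + 255) // 256, 256](x, codes, 255.0, x.size)
```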
## Usage Examples

### Basic JIT Kernel
```python
import cupy as cp
from cupyx import jit

# Simple element-wise kernel using JIT
@jit.rawkernel()
def add_kernel(x, y, z, n):
    """Add two arrays element-wise."""
    tid = jit.thread_id()
    if tid < n:
        z[tid] = x[tid] + y[tid]

# Create input arrays
n = 1000000
x = cp.random.rand(n, dtype=cp.float32)
y = cp.random.rand(n, dtype=cp.float32)
z = cp.zeros(n, dtype=cp.float32)

# Launch kernel
threads_per_block = 256
blocks_per_grid = (n + threads_per_block - 1) // threads_per_block
add_kernel[blocks_per_grid, threads_per_block](x, y, z, n)

print("Result:", z[:10])
```
### Element-wise JIT Function
```python
# Element-wise function with automatic broadcasting
@jit.elementwise('T x, T y -> T')
def fused_operation(x, y):
    """Fused mathematical operation."""
    temp = x * x + y * y
    return jit.sqrt(temp) + jit.sin(x) * jit.cos(y)

# Use like a regular CuPy function
a = cp.linspace(0, 2 * cp.pi, 1000000)
b = cp.linspace(0, cp.pi, 1000000)
result = fused_operation(a, b)

print("Element-wise result shape:", result.shape)
print("Sample values:", result[:5])
```
### Reduction JIT Kernel
```python
# Custom reduction operation
@jit.reduction('T x -> T', identity=0)
def sum_of_squares(x):
    """Compute sum of squares reduction."""
    return x * x

# Apply reduction
data = cp.array([1, 2, 3, 4, 5], dtype=cp.float32)
result = sum_of_squares(data)
print("Sum of squares:", result)

# Complex reduction with multiple operations
@jit.reduction('T x, T y -> T', identity=0)
def weighted_sum(x, y):
    """Compute weighted sum."""
    return x * y

weights = cp.array([0.1, 0.2, 0.3, 0.4, 0.5])
values = cp.array([10, 20, 30, 40, 50])
weighted_result = weighted_sum(values, weights)
print("Weighted sum:", weighted_result)
```
### Shared Memory Example
```python
@jit.rawkernel()
def matrix_transpose_shared(input_matrix, output_matrix, width, height):
    """Matrix transpose using shared memory."""
    # Allocate shared memory tile
    TILE_SIZE = 32
    tile = jit.shared_memory((TILE_SIZE, TILE_SIZE), cp.float32)

    # Calculate thread coordinates
    x = jit.blockIdx().x * jit.blockDim().x + jit.threadIdx().x
    y = jit.blockIdx().y * jit.blockDim().y + jit.threadIdx().y

    # Load data into shared memory
    if x < width and y < height:
        tile[jit.threadIdx().y, jit.threadIdx().x] = input_matrix[y, x]

    # Synchronize threads
    jit.syncthreads()

    # Calculate transposed coordinates
    tx = jit.blockIdx().y * jit.blockDim().y + jit.threadIdx().x
    ty = jit.blockIdx().x * jit.blockDim().x + jit.threadIdx().y

    # Write to output (transposed)
    if tx < height and ty < width:
        output_matrix[ty, tx] = tile[jit.threadIdx().x, jit.threadIdx().y]

# Test matrix transpose
input_mat = cp.random.rand(1024, 1024, dtype=cp.float32)
output_mat = cp.zeros((1024, 1024), dtype=cp.float32)

# Launch with 2D grid
block_size = (32, 32)
grid_size = (
    (input_mat.shape[1] + block_size[0] - 1) // block_size[0],
    (input_mat.shape[0] + block_size[1] - 1) // block_size[1]
)

matrix_transpose_shared[grid_size, block_size](
    input_mat, output_mat, input_mat.shape[1], input_mat.shape[0]
)

# Verify correctness
expected = input_mat.T
print("Transpose correct:", cp.allclose(output_mat, expected))
```
### Atomic Operations Example
```python
@jit.rawkernel()
def histogram_kernel(data, histogram, n_bins, data_size):
    """Compute histogram using atomic operations."""
    tid = jit.thread_id()

    if tid < data_size:
        # Calculate bin index (cast the scaled float to an integer, see "Type System")
        bin_idx = jit.cast(data[tid] * n_bins, jit.int32)
        bin_idx = jit.min(bin_idx, n_bins - 1)  # Clamp to valid range

        # Atomic increment
        jit.atomic_add(histogram, bin_idx, 1)

# Generate random data
data = cp.random.rand(1000000, dtype=cp.float32)
histogram = cp.zeros(100, dtype=cp.int32)

# Launch histogram kernel
threads = 256
blocks = (data.size + threads - 1) // threads
histogram_kernel[blocks, threads](data, histogram, 100, data.size)

print("Histogram bins:", histogram[:10])
print("Total count:", cp.sum(histogram))
```
### Advanced Control Flow
```python
@jit.rawkernel()
def complex_algorithm(input_data, output_data, threshold, size):
    """Complex algorithm with control flow."""
    tid = jit.thread_id()

    if tid >= size:
        return

    value = input_data[tid]
    result = 0.0

    # Complex conditional logic
    if value > threshold:
        # Iterative computation
        for i in range(10):
            result += jit.sin(value * i) * jit.exp(-i * 0.1)
    else:
        # Alternative computation
        temp = jit.sqrt(jit.abs(value))
        result = temp * jit.cos(temp)

    output_data[tid] = result

# Test complex algorithm
input_arr = cp.random.randn(100000).astype(cp.float32)
output_arr = cp.zeros_like(input_arr)
threshold = 0.5

threads = 256
blocks = (input_arr.size + threads - 1) // threads
complex_algorithm[blocks, threads](input_arr, output_arr, threshold, input_arr.size)

print("Complex algorithm results:", output_arr[:10])
```
### Performance Optimization Examples
```python
# Loop specialization for performance
@jit.rawkernel()
def optimized_convolution(input_data, kernel, output_data, width, height, kernel_size):
    """Optimized 2D convolution with a specialized 3x3 path."""
    x = jit.blockIdx().x * jit.blockDim().x + jit.threadIdx().x
    y = jit.blockIdx().y * jit.blockDim().y + jit.threadIdx().y

    if x >= width or y >= height:
        return

    result = 0.0
    half_kernel = kernel_size // 2

    # Specialized path for 3x3 kernels: constant loop bounds let the compiler unroll
    if kernel_size == 3:
        for dy in range(-1, 2):
            for dx in range(-1, 2):
                px = jit.max(0, jit.min(width - 1, x + dx))
                py = jit.max(0, jit.min(height - 1, y + dy))
                kernel_idx = (dy + 1) * 3 + (dx + 1)
                result += input_data[py, px] * kernel[kernel_idx]
    else:
        # General case
        for dy in range(-half_kernel, half_kernel + 1):
            for dx in range(-half_kernel, half_kernel + 1):
                px = jit.max(0, jit.min(width - 1, x + dx))
                py = jit.max(0, jit.min(height - 1, y + dy))
                kernel_idx = (dy + half_kernel) * kernel_size + (dx + half_kernel)
                result += input_data[py, px] * kernel[kernel_idx]

    output_data[y, x] = result

# Vectorized operations for better performance
@jit.elementwise('T x, T y, T z, T w -> T')
def vectorized_operation(x, y, z, w):
    """Vectorized computation using multiple inputs."""
    temp1 = x * y + z * w
    temp2 = jit.sqrt(temp1 * temp1 + 1.0)
    return jit.sin(temp2) * jit.exp(-temp2 * 0.1)

# Test vectorized operation
a = cp.random.rand(1000000)
b = cp.random.rand(1000000)
c = cp.random.rand(1000000)
d = cp.random.rand(1000000)

result = vectorized_operation(a, b, c, d)
print("Vectorized result sample:", result[:5])
```
### Multi-dimensional Indexing
```python
@jit.rawkernel()
def multi_dim_kernel(input_3d, output_3d, depth, height, width):
    """3D array processing with multi-dimensional indexing."""
    # 3D thread indexing
    x = jit.blockIdx().x * jit.blockDim().x + jit.threadIdx().x
    y = jit.blockIdx().y * jit.blockDim().y + jit.threadIdx().y
    z = jit.blockIdx().z * jit.blockDim().z + jit.threadIdx().z

    if x >= width or y >= height or z >= depth:
        return

    # Access neighboring elements in 3D
    result = 0.0
    count = 0

    for dz in range(-1, 2):
        for dy in range(-1, 2):
            for dx in range(-1, 2):
                nz = jit.max(0, jit.min(depth - 1, z + dz))
                ny = jit.max(0, jit.min(height - 1, y + dy))
                nx = jit.max(0, jit.min(width - 1, x + dx))

                result += input_3d[nz, ny, nx]
                count += 1

    # Average of neighborhood
    output_3d[z, y, x] = result / count

# Test 3D processing
input_3d = cp.random.rand(64, 256, 256, dtype=cp.float32)
output_3d = cp.zeros_like(input_3d)

# 3D grid launch (8 * 8 * 8 = 512 threads per block, within the 1024-thread limit)
block_3d = (8, 8, 8)
grid_3d = (
    (input_3d.shape[2] + block_3d[0] - 1) // block_3d[0],
    (input_3d.shape[1] + block_3d[1] - 1) // block_3d[1],
    (input_3d.shape[0] + block_3d[2] - 1) // block_3d[2]
)

multi_dim_kernel[grid_3d, block_3d](
    input_3d, output_3d,
    input_3d.shape[0], input_3d.shape[1], input_3d.shape[2]
)

print("3D processing completed, output shape:", output_3d.shape)
```
### Error Handling and Debugging
```python
# Debugging with a runtime debug flag
@jit.rawkernel()
def debug_kernel(data, output, size, debug_flag):
    """Kernel with debugging capabilities."""
    tid = jit.thread_id()

    # Bounds checking: out-of-range threads must not touch memory at all
    if tid >= size:
        return

    value = data[tid]

    # NaN/Inf checking
    if jit.isnan(value) or jit.isinf(value):
        if debug_flag:
            # In debug mode, mark the element with a sentinel value
            output[tid] = -888.0
        else:
            output[tid] = 0.0
        return

    # Normal computation
    result = jit.sqrt(jit.abs(value)) + jit.sin(value)
    output[tid] = result

# Function composition and modularity
@jit.rawkernel()
def modular_computation(x, y, out, size):
    """Example of modular JIT kernel design."""

    def compute_step1(a, b):
        return a * b + jit.sin(a)

    def compute_step2(intermediate):
        return jit.sqrt(jit.abs(intermediate))

    def compute_step3(a, step2_result):
        return step2_result * jit.exp(-a * 0.1)

    # Main kernel logic composed from the helper functions
    tid = jit.thread_id()
    if tid < size:
        out[tid] = compute_step3(x[tid], compute_step2(compute_step1(x[tid], y[tid])))
```
JIT compilation in CuPy bridges Python productivity and GPU performance: complex GPU algorithms can be written in familiar Python syntax and are compiled automatically into efficient CUDA kernels.