# JIT Compilation and Custom Kernels

CuPy provides just-in-time (JIT) compilation and custom CUDA kernel creation for performance-critical applications that require low-level GPU programming. Kernels can be written through high-level kernel templates, raw CUDA C++ kernels, and the `cupyx.jit` module for advanced GPU programming.

## Capabilities

### Kernel Templates

High-level kernel creation for common GPU computation patterns.

```python { .api }
class ElementwiseKernel:
    """Create a custom element-wise operation kernel.

    Args:
        in_params: Input parameter specifications
        out_params: Output parameter specifications
        operation: CUDA C++ code for the per-element operation
        name: Kernel name used for caching
        reduce_dims: Whether to reduce dimensions
        preamble: Additional declarations
        loop_prep: Code inserted before the loop
        after_loop: Code inserted after the loop

    Example:
        kernel = ElementwiseKernel(
            'float32 x, float32 y',
            'float32 z',
            'z = x * x + y * y',
            'squared_sum'
        )
    """

    def __init__(self, in_params, out_params, operation, name='kernel',
                 reduce_dims=True, preamble='', loop_prep='', after_loop='', **kwargs):
        pass

    def __call__(self, *args, **kwargs):
        """Execute the kernel with the given arguments."""
        pass


class ReductionKernel:
    """Create a custom reduction operation kernel.

    Args:
        in_params: Input parameter specifications
        out_params: Output parameter specifications
        map_expr: Expression for the mapping phase
        reduce_expr: Expression for the reduction phase
        post_map_expr: Expression applied after the reduction
        identity: Identity value for the reduction
        name: Kernel name
        reduce_type: Type of the reduction variable
        reduce_dims: Whether to reduce dimensions
        preamble: Additional declarations

    Example:
        kernel = ReductionKernel(
            'float32 x',
            'float32 y',
            'x',
            'a + b',
            'y = a',
            '0'
        )
    """

    def __init__(self, in_params, out_params, map_expr, reduce_expr,
                 post_map_expr='', identity=None, name='kernel', reduce_type=None,
                 reduce_dims=True, preamble='', **kwargs):
        pass

    def __call__(self, *args, **kwargs):
        """Execute the reduction kernel."""
        pass


class RawKernel:
    """Create a kernel from raw CUDA C++ code.

    Args:
        code: Complete CUDA kernel source code
        name: Kernel function name
        options: Compiler options
        backend: Compilation backend ('nvrtc' or 'nvcc')
        translate_cucomplex: Translate cuComplex types

    Example:
        code = '''
        extern "C" __global__ void my_kernel(float* x, float* y, int n) {
            int i = blockIdx.x * blockDim.x + threadIdx.x;
            if (i < n) {
                y[i] = x[i] * x[i];
            }
        }
        '''
        kernel = RawKernel(code, 'my_kernel')
    """

    def __init__(self, code, name, options=(), backend='nvrtc',
                 translate_cucomplex=False, **kwargs):
        pass

    def __call__(self, grid, block, args, **kwargs):
        """Launch the kernel with the given grid/block configuration."""
        pass


class RawModule:
    """Create a module from raw CUDA C++ code.

    Args:
        code: CUDA source code containing multiple functions
        options: Compiler options
        backend: Compilation backend
        name_expressions: Name expressions for (templated) kernel names
        log_stream: Compilation log output stream

    Example:
        code = '''
        extern "C" {
        __global__ void kernel1(float* data) { ... }
        __global__ void kernel2(float* data) { ... }
        }
        '''
        module = RawModule(code)
        kernel1 = module.get_function('kernel1')
    """

    def __init__(self, code, options=(), backend='nvrtc', name_expressions=None,
                 log_stream=None, **kwargs):
        pass

    def get_function(self, name):
        """Get a kernel function by name."""
        pass
```

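The usage examples later in this document cover `ElementwiseKernel`, `ReductionKernel`, and `RawKernel`. For `RawModule`, a minimal sketch follows: two kernels compiled from a single source and retrieved with `get_function`. The kernel names and bodies here are illustrative, not part of CuPy.

```python
import cupy as cp

source = r'''
extern "C" {
__global__ void scale(float* data, float factor, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) data[i] *= factor;
}
__global__ void offset(float* data, float delta, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) data[i] += delta;
}
}
'''

module = cp.RawModule(code=source)
scale_kernel = module.get_function('scale')
offset_kernel = module.get_function('offset')

data = cp.arange(1 << 20, dtype=cp.float32)
grid, block = ((data.size + 255) // 256,), (256,)
# Scalar arguments are passed as NumPy scalars to match the C parameter types
scale_kernel(grid, block, (data, cp.float32(2.0), cp.int32(data.size)))
offset_kernel(grid, block, (data, cp.float32(1.0), cp.int32(data.size)))
```
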

### JIT Decorators and Compilation

Advanced JIT compilation with Python decorators and runtime code generation.

```python { .api }
def rawkernel(mode='cuda'):
    """Decorator for raw CUDA kernel functions written in Python.

    Args:
        mode: Compilation mode ('cuda' or 'hip')

    Example:
        @rawkernel()
        def my_kernel(x, y, size):
            tid = threadIdx.x + blockIdx.x * blockDim.x
            if tid < size:
                y[tid] = x[tid] * x[tid]

        # Launch kernel
        my_kernel((grid_size,), (block_size,), (x_gpu, y_gpu, n))
    """

def jit(signature=None, device=False, inline=False, cache=True):
    """JIT compile Python functions for GPU execution.

    Args:
        signature: Function signature specification
        device: Compile for device execution
        inline: Allow inlining
        cache: Enable compilation caching

    Returns:
        Compiled function object
    """

def compile_with_cache(source, name, options=(), arch=None, cache_dir=None,
                       prepend_cupy_headers=True, backend='nvcc',
                       translate_cucomplex=True, enable_cooperative_groups=False,
                       name_expressions=None, log_stream=None,
                       cache_in_memory=False, jitify=False):
    """Compile CUDA source with caching.

    Args:
        source: CUDA C++ source code
        name: Function name to extract
        options: Compiler options tuple
        arch: Target architecture
        cache_dir: Cache directory path
        prepend_cupy_headers: Include CuPy headers
        backend: Compilation backend
        translate_cucomplex: Handle cuComplex types
        enable_cooperative_groups: Enable cooperative groups
        name_expressions: Name expressions for kernels
        log_stream: Compilation log stream
        cache_in_memory: Use in-memory caching
        jitify: Use Jitify for compilation

    Returns:
        Function: Compiled CUDA function
    """
```


### CUDA Execution Context

Low-level CUDA execution primitives and thread management.

```python { .api }
# Thread and block indexing
threadIdx = ThreadIndex()    # Thread index within the block
blockIdx = BlockIndex()      # Block index within the grid
blockDim = BlockDimension()  # Block dimensions
gridDim = GridDimension()    # Grid dimensions

class ThreadIndex:
    """Thread index within the block."""
    x: int  # Thread index in the x dimension
    y: int  # Thread index in the y dimension
    z: int  # Thread index in the z dimension

class BlockIndex:
    """Block index within the grid."""
    x: int  # Block index in the x dimension
    y: int  # Block index in the y dimension
    z: int  # Block index in the z dimension

class BlockDimension:
    """Block dimensions."""
    x: int  # Block size in the x dimension
    y: int  # Block size in the y dimension
    z: int  # Block size in the z dimension

class GridDimension:
    """Grid dimensions."""
    x: int  # Grid size in the x dimension
    y: int  # Grid size in the y dimension
    z: int  # Grid size in the z dimension

warpsize: int = 32  # Warp size constant

def laneid():
    """Get the lane ID within the current warp (0-31).

    Returns:
        int: Lane ID within the current warp
    """

def grid(ndim):
    """Get the linearized grid index.

    Args:
        ndim: Number of dimensions (1, 2, or 3)

    Returns:
        int or tuple: Grid index
    """

def gridsize(ndim):
    """Get the total grid size.

    Args:
        ndim: Number of dimensions

    Returns:
        int or tuple: Grid size
    """
```

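A minimal sketch of how these indexing helpers are typically used from `cupyx.jit`: a grid-stride copy kernel built on `jit.grid(1)` and `jit.gridsize(1)`. The kernel and array names are illustrative.

```python
import cupy as cp
from cupyx import jit

@jit.rawkernel()
def strided_copy(x, y, size):
    # Linear thread index and total number of threads in the grid
    tid = jit.grid(1)
    stride = jit.gridsize(1)
    # Grid-stride loop: each thread handles every `stride`-th element
    for i in range(tid, size, stride):
        y[i] = x[i]

size = cp.uint32(1 << 20)
x = cp.random.rand(int(size), dtype=cp.float32)
y = cp.empty_like(x)
strided_copy((128,), (256,), (x, y, size))
```
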

### Synchronization Primitives

Thread and warp synchronization functions for coordinated GPU execution.

```python { .api }
def syncthreads():
    """Synchronize all threads in a block.

    Blocks until all threads in the current thread block have reached
    this point and all memory accesses are visible to all threads.
    """

def syncwarp(mask=0xffffffff):
    """Synchronize threads in a warp.

    Args:
        mask: Thread mask specifying which threads to synchronize

    Note:
        Only threads with the corresponding bit set in mask participate.
    """

def barrier(scope='block'):
    """Memory barrier with the specified scope.

    Args:
        scope: Barrier scope ('block', 'grid', 'device', 'system')
    """

def memfence_block():
    """Block-level memory fence."""

def memfence_grid():
    """Grid-level memory fence."""

def memfence_system():
    """System-level memory fence."""
```


### Shared Memory Management

Shared memory allocation and access for high-performance block-local storage.

```python { .api }
def shared_memory(dtype, shape):
    """Allocate a shared memory array.

    Args:
        dtype: Data type of the array elements
        shape: Array shape (int or tuple)

    Returns:
        SharedArray: Shared memory array object

    Example:
        # Allocate 256 float32 values in shared memory
        shared_data = shared_memory(cp.float32, 256)

        # 2D shared memory array
        shared_matrix = shared_memory(cp.float32, (16, 16))
    """

def dynamic_shared_memory(dtype):
    """Access dynamically allocated shared memory.

    Args:
        dtype: Data type used to interpret the shared memory

    Returns:
        SharedArray: Dynamic shared memory view

    Note:
        The size is determined by the kernel launch parameters.
    """

class SharedArray:
    """Shared memory array interface."""

    def __getitem__(self, key):
        """Access shared memory elements."""
        pass

    def __setitem__(self, key, value):
        """Set shared memory elements."""
        pass
```

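In `cupyx.jit`, shared memory is requested with `jit.shared_memory(dtype, size)`, where the size must be a compile-time constant. A small sketch that stages each block's data in shared memory and writes it back reversed within the block; it assumes the array length is a multiple of the block size, and the names are illustrative.

```python
import cupy as cp
from cupyx import jit

@jit.rawkernel()
def block_reverse(x, y, n):
    # Block-local staging buffer; the size must be a compile-time constant
    buf = jit.shared_memory(cp.float32, 256)

    tid = jit.threadIdx.x
    i = jit.blockIdx.x * jit.blockDim.x + tid

    if i < n:
        buf[tid] = x[i]
    jit.syncthreads()

    # Each thread reads its mirror element within the block
    # (assumes n is a multiple of the block size, so the mirror is loaded)
    if i < n:
        y[i] = buf[jit.blockDim.x - 1 - tid]

n = 1 << 16                       # multiple of the 256-thread block size
x = cp.arange(n, dtype=cp.float32)
y = cp.empty_like(x)
block_reverse((n // 256,), (256,), (x, y, cp.uint32(n)))
```
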
349
350
### Atomic Operations
351
352
Thread-safe atomic operations for lock-free algorithms and reductions.
353
354
```python { .api }
355
def atomic_add(array, index, value):
356
"""Atomic addition.
357
358
Args:
359
array: Target array
360
index: Array index
361
value: Value to add
362
363
Returns:
364
Previous value at index
365
"""
366
367
def atomic_sub(array, index, value):
368
"""Atomic subtraction."""
369
370
def atomic_exch(array, index, value):
371
"""Atomic exchange.
372
373
Args:
374
array: Target array
375
index: Array index
376
value: New value
377
378
Returns:
379
Previous value at index
380
"""
381
382
def atomic_min(array, index, value):
383
"""Atomic minimum operation."""
384
385
def atomic_max(array, index, value):
386
"""Atomic maximum operation."""
387
388
def atomic_inc(array, index):
389
"""Atomic increment.
390
391
Args:
392
array: Target array
393
index: Array index
394
395
Returns:
396
Previous value at index
397
"""
398
399
def atomic_dec(array, index):
400
"""Atomic decrement."""
401
402
def atomic_cas(array, index, compare, value):
403
"""Atomic compare-and-swap.
404
405
Args:
406
array: Target array
407
index: Array index
408
compare: Compare value
409
value: New value if comparison succeeds
410
411
Returns:
412
Previous value at index
413
"""
414
415
def atomic_and(array, index, value):
416
"""Atomic bitwise AND."""
417
418
def atomic_or(array, index, value):
419
"""Atomic bitwise OR."""
420
421
def atomic_xor(array, index, value):
422
"""Atomic bitwise XOR."""
423
```
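A minimal sketch of `atomic_add` used from a `cupyx.jit` kernel: counting elements above a threshold into a one-element counter. Kernel and variable names are illustrative.

```python
import cupy as cp
from cupyx import jit

@jit.rawkernel()
def count_above(x, threshold, counter, n):
    i = jit.grid(1)
    if i < n:
        if x[i] > threshold:
            # Atomically bump the single-element counter at index 0
            jit.atomic_add(counter, 0, 1)

x = cp.random.rand(1 << 20, dtype=cp.float32)
counter = cp.zeros(1, dtype=cp.int32)
count_above((4096,), (256,), (x, cp.float32(0.5), counter, cp.uint32(x.size)))
```
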

### Warp-Level Operations

Efficient warp-level collective operations for high-performance algorithms.

```python { .api }
def shfl_sync(mask, var, srcLane, width=32):
    """Warp shuffle operation.

    Args:
        mask: Thread participation mask
        var: Variable to shuffle
        srcLane: Source lane index
        width: Warp width (power of 2, ≤32)

    Returns:
        Value from the source lane
    """

def shfl_up_sync(mask, var, delta, width=32):
    """Warp shuffle up operation.

    Args:
        mask: Thread participation mask
        var: Variable to shuffle
        delta: Offset to the source lane
        width: Warp width

    Returns:
        Value from lane (current - delta)
    """

def shfl_down_sync(mask, var, delta, width=32):
    """Warp shuffle down operation.

    Args:
        mask: Thread participation mask
        var: Variable to shuffle
        delta: Offset to the source lane
        width: Warp width

    Returns:
        Value from lane (current + delta)
    """

def shfl_xor_sync(mask, var, laneMask, width=32):
    """Warp shuffle XOR operation.

    Args:
        mask: Thread participation mask
        var: Variable to shuffle
        laneMask: XOR mask for lane selection
        width: Warp width

    Returns:
        Value from lane (current ^ laneMask)
    """

def vote_all_sync(mask, predicate):
    """Test whether the predicate is true for all threads in the mask.

    Args:
        mask: Thread participation mask
        predicate: Boolean expression to test

    Returns:
        bool: True if all participating threads have a true predicate
    """

def vote_any_sync(mask, predicate):
    """Test whether the predicate is true for any thread in the mask."""

def vote_uni_sync(mask, predicate):
    """Test whether the predicate has the same value for all threads."""

def ballot_sync(mask, predicate):
    """Get a ballot of predicate results across the warp.

    Args:
        mask: Thread participation mask
        predicate: Boolean expression

    Returns:
        int: Bitmask of predicate results
    """

def activemask():
    """Get the mask of currently active threads in the warp.

    Returns:
        int: Bitmask of active threads
    """
```

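A sketch of a warp-level sum built on `shfl_down_sync` as exposed by `cupyx.jit`, assuming the grid exactly covers the input (a multiple of the warp size). Kernel and variable names are illustrative.

```python
import cupy as cp
from cupyx import jit

@jit.rawkernel()
def warp_sum(x, out):
    # Assumes the grid exactly covers the array (size is a multiple of 32)
    i = jit.grid(1)
    value = x[i]

    # Tree reduction within the warp using shuffle-down
    offset = 16
    while offset > 0:
        value += jit.shfl_down_sync(0xffffffff, value, offset)
        offset //= 2

    # Lane 0 of each warp adds its partial sum to the global result
    if jit.laneid() == 0:
        jit.atomic_add(out, 0, value)

x = cp.ones(1 << 20, dtype=cp.float32)
out = cp.zeros(1, dtype=cp.float32)
warp_sum((4096,), (256,), (x, out))
# out[0] should now be close to float(x.size)
```
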

### Mathematical Functions

GPU-optimized mathematical functions for kernel development.

```python { .api }
def fma(x, y, z):
    """Fused multiply-add: x * y + z with a single rounding."""

def rsqrt(x):
    """Fast reciprocal square root: 1/sqrt(x)."""

def rcp(x):
    """Fast reciprocal: 1/x."""

def sin_pi(x):
    """Compute sin(π * x) accurately."""

def cos_pi(x):
    """Compute cos(π * x) accurately."""

def sincos(x):
    """Compute sin and cos simultaneously.

    Returns:
        tuple: (sin(x), cos(x))
    """

def exp2(x):
    """Base-2 exponential: 2^x."""

def log2(x):
    """Base-2 logarithm."""

def pow(x, y):
    """Power function: x^y."""

def sqrt(x):
    """Square root."""

def cbrt(x):
    """Cube root."""

def hypot(x, y):
    """Euclidean distance: sqrt(x^2 + y^2)."""

def remainder(x, y):
    """Floating point remainder."""

def fmod(x, y):
    """Floating point modulo."""

def copysign(x, y):
    """Copy the sign of y to the magnitude of x."""

def ldexp(x, exp):
    """Compute x * 2^exp."""

def frexp(x):
    """Extract mantissa and exponent.

    Returns:
        tuple: (mantissa, exponent)
    """
```

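These functions correspond to CUDA's device math intrinsics, so the same operations can also be written directly in the CUDA C++ passed to kernel templates. A small sketch using `fmaf` and `rsqrtf` inside an `ElementwiseKernel`; the kernel name is illustrative.

```python
import cupy as cp

# Fast inverse-norm written in CUDA C, using fused multiply-add and
# reciprocal square root intrinsics: out = 1 / sqrt(x*x + y*y)
inv_norm_kernel = cp.ElementwiseKernel(
    'float32 x, float32 y',
    'float32 out',
    'out = rsqrtf(fmaf(x, x, y * y))',
    'inv_norm'
)

x = cp.random.randn(1024).astype(cp.float32)
y = cp.random.randn(1024).astype(cp.float32)
inv = inv_norm_kernel(x, y)
```
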

## Usage Examples

### Element-wise Kernels

```python
import cupy as cp
from cupy import ElementwiseKernel

# Simple arithmetic kernel
add_kernel = ElementwiseKernel(
    'float32 x, float32 y',
    'float32 z',
    'z = x + y',
    'elementwise_add'
)

# Complex expression kernel
norm_kernel = ElementwiseKernel(
    'float32 x, float32 y',
    'float32 norm',
    'norm = sqrt(x * x + y * y)',
    'vector_norm'
)

# Multi-output kernel
polar_kernel = ElementwiseKernel(
    'float32 x, float32 y',
    'float32 r, float32 theta',
    '''
    r = sqrt(x * x + y * y);
    theta = atan2(y, x);
    ''',
    'cartesian_to_polar'
)

# Usage
x = cp.random.randn(1000000).astype(cp.float32)
y = cp.random.randn(1000000).astype(cp.float32)

result = add_kernel(x, y)
norms = norm_kernel(x, y)
r, theta = polar_kernel(x, y)
```

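The `preamble` argument described in the API section lets the element-wise operation call helper device functions. A small sketch with an illustrative `my_sigmoid` helper, reusing the `x` array defined above.

```python
# ElementwiseKernel with a `preamble` defining a device helper function
sigmoid_kernel = ElementwiseKernel(
    'float32 x',
    'float32 y',
    'y = my_sigmoid(x)',
    'sigmoid_with_preamble',
    preamble='''
    __device__ float my_sigmoid(float v) {
        return 1.0f / (1.0f + expf(-v));
    }
    '''
)

activations = sigmoid_kernel(x)   # element-wise sigmoid of the float32 array
```
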

### Reduction Kernels

```python
import cupy as cp
from cupy import ReductionKernel

# Sum reduction
sum_kernel = ReductionKernel(
    'float32 x',
    'float32 y',
    'x',          # map expression
    'a + b',      # reduce expression
    'y = a',      # post-map expression
    '0',          # identity value
    'sum_reduction'
)

# Maximum reduction
max_kernel = ReductionKernel(
    'float32 x',
    'float32 y',
    'x',
    'max(a, b)',
    'y = a',
    '-INFINITY',
    'max_reduction'
)

# Variance reduction
variance_kernel = ReductionKernel(
    'float32 x, float32 mean',
    'float32 var',
    '(x - mean) * (x - mean)',
    'a + b',
    'var = a / (_in_ind.size() - 1)',
    '0',
    'variance_reduction'
)

# Usage
data = cp.random.randn(1000000).astype(cp.float32)
total = sum_kernel(data)
maximum = max_kernel(data)
mean_val = cp.mean(data)
var_val = variance_kernel(data, mean_val)
```

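The same kernels can also reduce along a single axis of a multi-dimensional array; a short sketch, assuming `ReductionKernel.__call__` accepts the usual `axis` keyword of CuPy reductions.

```python
# Axis-wise reductions with the kernels defined above
matrix = cp.random.randn(512, 1024).astype(cp.float32)
row_sums = sum_kernel(matrix, axis=1)   # shape (512,)
col_max = max_kernel(matrix, axis=0)    # shape (1024,)
```
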

### Raw CUDA Kernels

```python
import cupy as cp
from cupy import RawKernel

# Matrix multiplication kernel
matmul_code = '''
extern "C" __global__ void matmul(
    const float* A, const float* B, float* C,
    int M, int N, int K
) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;

    if (row < M && col < N) {
        float sum = 0.0f;
        for (int k = 0; k < K; k++) {
            sum += A[row * K + k] * B[k * N + col];
        }
        C[row * N + col] = sum;
    }
}
'''

matmul_kernel = RawKernel(matmul_code, 'matmul')

# Optimized reduction kernel
reduction_code = '''
extern "C" __global__ void block_reduce_sum(
    const float* input, float* output, int n
) {
    extern __shared__ float shared_data[];

    int tid = threadIdx.x;
    int i = blockIdx.x * blockDim.x + threadIdx.x;

    // Load data into shared memory
    shared_data[tid] = (i < n) ? input[i] : 0.0f;
    __syncthreads();

    // Reduction in shared memory
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) {
            shared_data[tid] += shared_data[tid + s];
        }
        __syncthreads();
    }

    // Write result
    if (tid == 0) {
        output[blockIdx.x] = shared_data[0];
    }
}
'''

reduce_kernel = RawKernel(reduction_code, 'block_reduce_sum')

# Usage
A = cp.random.randn(512, 256).astype(cp.float32)
B = cp.random.randn(256, 128).astype(cp.float32)
C = cp.zeros((512, 128), dtype=cp.float32)

# Launch matrix multiplication; scalar arguments are passed as 32-bit ints
# to match the `int` parameters of the kernel
block_size = (16, 16)
grid_size = ((128 + block_size[0] - 1) // block_size[0],
             (512 + block_size[1] - 1) // block_size[1])

matmul_kernel(grid_size, block_size,
              (A, B, C, cp.int32(512), cp.int32(128), cp.int32(256)))
```

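The `block_reduce_sum` kernel above is declared but never launched. One possible launch is sketched below, supplying the size of the `extern __shared__` buffer (in bytes) through the `shared_mem` argument of `RawKernel.__call__`; the sizes and names are illustrative.

```python
# Launch the block reduction; each block writes one partial sum
data = cp.random.randn(1 << 20).astype(cp.float32)
threads = 256
blocks = (data.size + threads - 1) // threads
partial = cp.zeros(blocks, dtype=cp.float32)

reduce_kernel((blocks,), (threads,),
              (data, partial, cp.int32(data.size)),
              shared_mem=threads * 4)          # 4 bytes per float32

total = partial.sum()   # finish the reduction over per-block partial sums
```
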

### JIT Compilation with Decorators

```python
import cupy as cp
from cupyx import jit

@jit.rawkernel()
def saxpy_kernel(a, x, y, n):
    # SAXPY: y = a*x + y
    tid = jit.grid(1)
    if tid < n:
        y[tid] = a * x[tid] + y[tid]

@jit.rawkernel()
def transpose_kernel(input_mat, output_mat, width, height):
    # Matrix transpose (flattened row-major buffers)
    col = jit.blockIdx.x * jit.blockDim.x + jit.threadIdx.x
    row = jit.blockIdx.y * jit.blockDim.y + jit.threadIdx.y

    if col < width and row < height:
        output_mat[col * height + row] = input_mat[row * width + col]

@jit.rawkernel()
def stencil_kernel(input_arr, output_arr, width, height):
    # 5-point stencil computation (flattened row-major buffers)
    col = jit.blockIdx.x * jit.blockDim.x + jit.threadIdx.x
    row = jit.blockIdx.y * jit.blockDim.y + jit.threadIdx.y

    if 1 <= col and col < width - 1 and 1 <= row and row < height - 1:
        idx = row * width + col
        result = (input_arr[idx] +
                  input_arr[idx - 1] + input_arr[idx + 1] +
                  input_arr[idx - width] + input_arr[idx + width]) * 0.2
        output_arr[idx] = result

# Usage
n = 1000000
a = cp.float32(2.5)      # scalar arguments are passed as NumPy scalar types
x = cp.random.randn(n).astype(cp.float32)
y = cp.random.randn(n).astype(cp.float32)

# Launch SAXPY kernel
block_size = 256
grid_size = (n + block_size - 1) // block_size
saxpy_kernel((grid_size,), (block_size,), (a, x, y, cp.uint32(n)))
```

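The transpose and stencil kernels above are only defined. A possible 2-D launch of `transpose_kernel` is sketched below; flattened row-major arrays are used because the kernel indexes its buffers linearly, and the sizes here are illustrative.

```python
# 2-D launch of the transpose kernel defined above
width, height = 1024, 512
mat = cp.random.randn(height * width).astype(cp.float32)
mat_t = cp.empty(width * height, dtype=cp.float32)

block = (16, 16)
grid = ((width + block[0] - 1) // block[0],
        (height + block[1] - 1) // block[1])
transpose_kernel(grid, block,
                 (mat, mat_t, cp.uint32(width), cp.uint32(height)))
```
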

### Shared Memory and Atomics

```python
import cupy as cp
from cupyx import jit

@jit.rawkernel()
def histogram_kernel(data, hist, n, num_bins):
    # Compute a histogram using shared memory and atomics
    # Shared memory for the block-local histogram
    shared_hist = jit.shared_memory(cp.int32, 256)

    tid = jit.threadIdx.x
    bid = jit.blockIdx.x

    # Initialize shared memory
    if tid < num_bins:
        shared_hist[tid] = 0
    jit.syncthreads()

    # Process data elements with a grid-stride loop
    idx = bid * jit.blockDim.x + tid
    while idx < n:
        bin_idx = int(data[idx] * num_bins)
        if 0 <= bin_idx and bin_idx < num_bins:
            jit.atomic_add(shared_hist, bin_idx, 1)
        idx += jit.gridDim.x * jit.blockDim.x

    jit.syncthreads()

    # Merge the block-local histogram into the global histogram
    if tid < num_bins:
        jit.atomic_add(hist, tid, shared_hist[tid])

@jit.rawkernel()
def prefix_sum_kernel(data, result, n):
    # Block-wise parallel prefix sum using shared memory
    shared_data = jit.shared_memory(cp.float32, 512)

    tid = jit.threadIdx.x
    bid = jit.blockIdx.x
    block_size = jit.blockDim.x

    # Load data
    idx = bid * block_size + tid
    if idx < n:
        shared_data[tid] = data[idx]
    else:
        shared_data[tid] = 0.0
    jit.syncthreads()

    # Up-sweep phase
    offset = 1
    while offset < block_size:
        if (tid + 1) % (2 * offset) == 0:
            shared_data[tid] += shared_data[tid - offset]
        offset *= 2
        jit.syncthreads()

    # Down-sweep phase
    if tid == block_size - 1:
        shared_data[tid] = 0.0

    offset = block_size // 2
    while offset > 0:
        jit.syncthreads()
        if (tid + 1) % (2 * offset) == 0:
            temp = shared_data[tid - offset]
            shared_data[tid - offset] = shared_data[tid]
            shared_data[tid] += temp
        offset //= 2

    jit.syncthreads()

    # Store result
    if idx < n:
        result[idx] = shared_data[tid]

# Usage examples
data = cp.random.rand(1000000).astype(cp.float32)
hist = cp.zeros(256, dtype=cp.int32)

# Compute histogram
block_size = 256
grid_size = 128
histogram_kernel((grid_size,), (block_size,),
                 (data, hist, cp.uint32(data.size), cp.int32(256)))

# Block-wise prefix sum (covers the first grid_size * block_size elements)
prefix_result = cp.zeros_like(data)
prefix_sum_kernel((grid_size,), (block_size,),
                  (data, prefix_result, cp.uint32(data.size)))
```


### Performance Optimization Techniques

```python
import cupy as cp
from cupyx import jit

@jit.rawkernel()
def optimized_gemm_kernel(A, B, C, M, N, K):
    # Optimized matrix multiplication with 16x16 tiling
    # Shared memory tiles, flattened 16x16 (size must be a compile-time constant)
    tile_A = jit.shared_memory(cp.float32, 256)
    tile_B = jit.shared_memory(cp.float32, 256)

    # Thread and block indices
    tx = jit.threadIdx.x
    ty = jit.threadIdx.y
    bx = jit.blockIdx.x
    by = jit.blockIdx.y

    # Calculate output position
    row = by * 16 + ty
    col = bx * 16 + tx

    result = 0.0

    # Tile across the K dimension
    for tile in range((K + 15) // 16):
        # Load one tile of A and B into shared memory
        if row < M and tile * 16 + tx < K:
            tile_A[ty * 16 + tx] = A[row, tile * 16 + tx]
        else:
            tile_A[ty * 16 + tx] = 0.0

        if col < N and tile * 16 + ty < K:
            tile_B[ty * 16 + tx] = B[tile * 16 + ty, col]
        else:
            tile_B[ty * 16 + tx] = 0.0

        jit.syncthreads()

        # Compute the partial result for this tile
        for k in range(16):
            result += tile_A[ty * 16 + k] * tile_B[k * 16 + tx]

        jit.syncthreads()

    # Store result
    if row < M and col < N:
        C[row, col] = result

# Memory coalescing example
@jit.rawkernel()
def coalesced_transpose(input_mat, output_mat, width, height):
    # Memory-coalesced matrix transpose (32x32 tiles, flattened row-major buffers)
    # Rows padded to 33 elements to avoid shared-memory bank conflicts
    tile = jit.shared_memory(cp.float32, 32 * 33)

    x = jit.blockIdx.x * 32 + jit.threadIdx.x
    y = jit.blockIdx.y * 32 + jit.threadIdx.y

    # Load tile with coalesced reads
    if x < width and y < height:
        tile[jit.threadIdx.y * 33 + jit.threadIdx.x] = input_mat[y * width + x]

    jit.syncthreads()

    # Transpose coordinates for the output
    x = jit.blockIdx.y * 32 + jit.threadIdx.x
    y = jit.blockIdx.x * 32 + jit.threadIdx.y

    # Store with coalesced writes
    if x < height and y < width:
        output_mat[y * height + x] = tile[jit.threadIdx.x * 33 + jit.threadIdx.y]

# Usage with performance considerations
M, N, K = 2048, 2048, 2048
A = cp.random.randn(M, K).astype(cp.float32)
B = cp.random.randn(K, N).astype(cp.float32)
C = cp.zeros((M, N), dtype=cp.float32)

# Optimized GEMM launch
tile_size = 16
grid_dim = ((N + tile_size - 1) // tile_size, (M + tile_size - 1) // tile_size)
block_dim = (tile_size, tile_size)

optimized_gemm_kernel(grid_dim, block_dim,
                      (A, B, C, cp.int32(M), cp.int32(N), cp.int32(K)))
```
