# JIT Compilation and Custom Kernels

Just-in-time compilation and custom CUDA kernel creation for performance-critical applications that need low-level GPU programming. CuPy provides JIT compilation through kernel templates, raw CUDA kernels, and the `cupyx.jit` module.

## Capabilities

### Kernel Templates

High-level kernel creation for common GPU computation patterns.

```python { .api }
class ElementwiseKernel:
    """Create custom element-wise operation kernel.

    Args:
        in_params: Input parameter specifications
        out_params: Output parameter specifications
        operation: CUDA C++ code for element operation
        name: Kernel name for caching
        reduce_dims: Whether to reduce dimensions
        preamble: Additional declarations
        loop_prep: Code before loop
        after_loop: Code after loop

    Example:
        kernel = ElementwiseKernel(
            'float32 x, float32 y',
            'float32 z',
            'z = x * x + y * y',
            'squared_sum'
        )
    """

    def __init__(self, in_params, out_params, operation, name='kernel',
                 reduce_dims=True, preamble='', loop_prep='', after_loop='', **kwargs):
        pass

    def __call__(self, *args, **kwargs):
        """Execute kernel with given arguments."""
        pass

class ReductionKernel:
    """Create custom reduction operation kernel.

    Args:
        in_params: Input parameter specifications
        out_params: Output parameter specifications
        map_expr: Expression for mapping phase
        reduce_expr: Expression for reduction phase
        post_map_expr: Expression after mapping
        identity: Identity value for reduction
        name: Kernel name
        reduce_type: Type for reduction variable
        reduce_dims: Whether to reduce dimensions
        preamble: Additional declarations

    Example:
        kernel = ReductionKernel(
            'float32 x',
            'float32 y',
            'x',
            'a + b',
            'y = a',
            '0'
        )
    """

    def __init__(self, in_params, out_params, map_expr, reduce_expr,
                 post_map_expr='', identity=None, name='kernel', reduce_type=None,
                 reduce_dims=True, preamble='', **kwargs):
        pass

    def __call__(self, *args, **kwargs):
        """Execute reduction kernel."""
        pass

class RawKernel:
    """Create kernel from raw CUDA C++ code.

    Args:
        code: Complete CUDA kernel source code
        name: Kernel function name
        options: Compiler options
        backend: Compilation backend ('nvcc', 'nvrtc')
        translate_cucomplex: Translate cuComplex types

    Example:
        code = '''
        extern "C" __global__ void my_kernel(float* x, float* y, int n) {
            int i = blockIdx.x * blockDim.x + threadIdx.x;
            if (i < n) {
                y[i] = x[i] * x[i];
            }
        }
        '''
        kernel = RawKernel(code, 'my_kernel')
    """

    def __init__(self, code, name, options=(), backend='nvcc',
                 translate_cucomplex=True, **kwargs):
        pass

    def __call__(self, grid, block, args, **kwargs):
        """Launch kernel with grid/block configuration."""
        pass

class RawModule:
    """Create module from raw CUDA C++ code.

    Args:
        code: CUDA source code with multiple functions
        options: Compiler options
        backend: Compilation backend
        name_expressions: Named expressions for kernel names
        log_stream: Compilation log output stream

    Example:
        code = '''
        extern "C" {
            __global__ void kernel1(float* data) { ... }
            __global__ void kernel2(float* data) { ... }
        }
        '''
        module = RawModule(code)
        kernel1 = module.get_function('kernel1')
    """

    def __init__(self, code, options=(), backend='nvcc', name_expressions=None,
                 log_stream=None, **kwargs):
        pass

    def get_function(self, name):
        """Get kernel function by name."""
        pass
```
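
`RawModule` is not revisited in the usage examples below, so the following sketch shows how several kernels can be compiled together and fetched by name. It assumes the `RawModule(code=...)` / `get_function` interface described above; the `scale` and `offset` kernels are illustrative, and scalar arguments are wrapped in 32-bit types to match the C signatures.

```python
import cupy as cp

# Two kernels compiled together in a single module.
module_code = r'''
extern "C" {
__global__ void scale(float* x, float a, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) x[i] *= a;
}
__global__ void offset(float* x, float b, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) x[i] += b;
}
}
'''

module = cp.RawModule(code=module_code)
scale = module.get_function('scale')
offset = module.get_function('offset')

x = cp.arange(1024, dtype=cp.float32)
block = 256
grid = (x.size + block - 1) // block

# Launch each kernel with explicit scalar types matching the C signatures.
scale((grid,), (block,), (x, cp.float32(2.0), cp.int32(x.size)))
offset((grid,), (block,), (x, cp.float32(1.0), cp.int32(x.size)))
```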

### JIT Decorators and Compilation

Advanced JIT compilation with Python decorators and runtime code generation.

```python { .api }
def rawkernel(mode='CUDA'):
    """Decorator for raw CUDA kernel functions.

    Args:
        mode: Compilation mode ('CUDA' or 'HIP')

    Example:
        @rawkernel()
        def my_kernel(x, y, size):
            tid = threadIdx.x + blockIdx.x * blockDim.x
            if tid < size:
                y[tid] = x[tid] * x[tid]

        # Launch kernel
        my_kernel((grid_size,), (block_size,), (x_gpu, y_gpu, n))
    """

def jit(signature=None, device=False, inline=False, cache=True):
    """JIT compile Python functions for GPU execution.

    Args:
        signature: Function signature specification
        device: Compile for device execution
        inline: Allow inlining
        cache: Enable compilation caching

    Returns:
        Compiled function object
    """

def compile_with_cache(source, name, options=(), arch=None, cache_dir=None,
                       prepend_cupy_headers=True, backend='nvcc',
                       translate_cucomplex=True, enable_cooperative_groups=False,
                       name_expressions=None, log_stream=None,
                       cache_in_memory=False, jitify=False):
    """Compile CUDA source with caching.

    Args:
        source: CUDA C++ source code
        name: Function name to extract
        options: Compiler options tuple
        arch: Target architecture
        cache_dir: Cache directory path
        prepend_cupy_headers: Include CuPy headers
        backend: Compilation backend
        translate_cucomplex: Handle cuComplex types
        enable_cooperative_groups: Enable cooperative groups
        name_expressions: Named expressions for kernels
        log_stream: Compilation log stream
        cache_in_memory: Use in-memory caching
        jitify: Use Jitify for compilation

    Returns:
        Function: Compiled CUDA function
    """
```

### CUDA Execution Context

Low-level CUDA execution primitives and thread management.

```python { .api }
# Thread and Block Indexing
threadIdx = ThreadIndex()    # Thread index within block
blockIdx = BlockIndex()      # Block index within grid
blockDim = BlockDimension()  # Block dimensions
gridDim = GridDimension()    # Grid dimensions

class ThreadIndex:
    """Thread index within block."""
    x: int  # X dimension thread index
    y: int  # Y dimension thread index
    z: int  # Z dimension thread index

class BlockIndex:
    """Block index within grid."""
    x: int  # X dimension block index
    y: int  # Y dimension block index
    z: int  # Z dimension block index

class BlockDimension:
    """Block dimensions."""
    x: int  # X dimension block size
    y: int  # Y dimension block size
    z: int  # Z dimension block size

class GridDimension:
    """Grid dimensions."""
    x: int  # X dimension grid size
    y: int  # Y dimension grid size
    z: int  # Z dimension grid size

warpsize: int = 32  # Warp size constant

def laneid():
    """Get lane ID within warp (0-31).

    Returns:
        int: Lane ID within current warp
    """

def grid(ndim):
    """Get linearized grid index.

    Args:
        ndim: Number of dimensions (1, 2, or 3)

    Returns:
        int or tuple: Grid index
    """

def gridsize(ndim):
    """Get total grid size.

    Args:
        ndim: Number of dimensions

    Returns:
        int or tuple: Grid size
    """
```
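
`grid()` and `gridsize()` are typically combined into a grid-stride loop so a fixed launch configuration can cover arrays of any length. A minimal sketch, assuming these helpers are reachable as `jit.grid` and `jit.gridsize` from `cupyx.jit` (the same namespace used in the decorator examples later in this document); the kernel name and sizes are illustrative.

```python
import cupy as cp
from cupyx import jit

@jit.rawkernel()
def scale_kernel(x, y, a, size):
    # grid(1) is the linear thread index, gridsize(1) the total thread count,
    # so the loop covers arrays larger than the launch configuration.
    tid = jit.grid(1)
    stride = jit.gridsize(1)
    for i in range(tid, size, stride):
        y[i] = a * x[i]

x = cp.arange(1 << 20, dtype=cp.float32)
y = cp.empty_like(x)
scale_kernel((256,), (256,), (x, y, cp.float32(3.0), cp.uint32(x.size)))
```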

### Synchronization Primitives

Thread and warp synchronization functions for coordinated GPU execution.

```python { .api }
def syncthreads():
    """Synchronize all threads in a block.

    Blocks until all threads in the current thread block have reached
    this point and all memory accesses are visible to all threads.
    """

def syncwarp(mask=0xffffffff):
    """Synchronize threads in a warp.

    Args:
        mask: Thread mask specifying which threads to synchronize

    Note:
        Only threads with corresponding bit set in mask participate.
    """

def barrier(scope='block'):
    """Memory barrier with specified scope.

    Args:
        scope: Barrier scope ('block', 'grid', 'device', 'system')
    """

def memfence_block():
    """Block-level memory fence."""

def memfence_grid():
    """Grid-level memory fence."""

def memfence_system():
    """System-level memory fence."""
```
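
`syncthreads()` is what makes block-level staging through shared memory safe: every thread must finish writing its slot before any thread reads a neighbour's. A small sketch of that pattern, assuming the `cupyx.jit` namespace used in the usage examples below (`jit.shared_memory`, `jit.syncthreads`); the block-reversal kernel itself is illustrative.

```python
import cupy as cp
from cupyx import jit

@jit.rawkernel()
def block_reverse(x, y, size):
    # Stage one block of data in shared memory, then write it back reversed.
    buf = jit.shared_memory(cp.float32, 256)
    tid = jit.threadIdx.x
    idx = jit.blockIdx.x * jit.blockDim.x + tid

    if idx < size:
        buf[tid] = x[idx]
    # All writes must land before any thread reads another thread's slot.
    jit.syncthreads()
    if idx < size:
        y[idx] = buf[jit.blockDim.x - 1 - tid]

x = cp.arange(1024, dtype=cp.float32)
y = cp.empty_like(x)
block_reverse((4,), (256,), (x, y, cp.uint32(x.size)))
```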

### Shared Memory Management

Shared memory allocation and access for high-performance block-local storage.

```python { .api }
def shared_memory(dtype, shape):
    """Allocate shared memory array.

    Args:
        dtype: Data type for array elements
        shape: Array shape (int or tuple)

    Returns:
        SharedArray: Shared memory array object

    Example:
        # Allocate 256 float32 values in shared memory
        shared_data = shared_memory(cp.float32, 256)

        # 2D shared memory array
        shared_matrix = shared_memory(cp.float32, (16, 16))
    """

def dynamic_shared_memory(dtype):
    """Access dynamically allocated shared memory.

    Args:
        dtype: Data type for interpreting shared memory

    Returns:
        SharedArray: Dynamic shared memory view

    Note:
        Size determined by kernel launch parameters.
    """

class SharedArray:
    """Shared memory array interface."""

    def __getitem__(self, key):
        """Access shared memory elements."""
        pass

    def __setitem__(self, key, value):
        """Set shared memory elements."""
        pass
```

### Atomic Operations

Thread-safe atomic operations for lock-free algorithms and reductions.

```python { .api }
def atomic_add(array, index, value):
    """Atomic addition.

    Args:
        array: Target array
        index: Array index
        value: Value to add

    Returns:
        Previous value at index
    """

def atomic_sub(array, index, value):
    """Atomic subtraction."""

def atomic_exch(array, index, value):
    """Atomic exchange.

    Args:
        array: Target array
        index: Array index
        value: New value

    Returns:
        Previous value at index
    """

def atomic_min(array, index, value):
    """Atomic minimum operation."""

def atomic_max(array, index, value):
    """Atomic maximum operation."""

def atomic_inc(array, index):
    """Atomic increment.

    Args:
        array: Target array
        index: Array index

    Returns:
        Previous value at index
    """

def atomic_dec(array, index):
    """Atomic decrement."""

def atomic_cas(array, index, compare, value):
    """Atomic compare-and-swap.

    Args:
        array: Target array
        index: Array index
        compare: Compare value
        value: New value if comparison succeeds

    Returns:
        Previous value at index
    """

def atomic_and(array, index, value):
    """Atomic bitwise AND."""

def atomic_or(array, index, value):
    """Atomic bitwise OR."""

def atomic_xor(array, index, value):
    """Atomic bitwise XOR."""
```
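
A common use of these primitives is folding per-thread values into a single global result without an explicit reduction pass. The sketch below finds the maximum of an `int32` array with `atomic_max`, assuming the operation is exposed as `jit.atomic_max` in `cupyx.jit`; the kernel and array names are illustrative.

```python
import cupy as cp
from cupyx import jit

@jit.rawkernel()
def max_kernel(data, out, size):
    # Every thread folds its element into a single global slot.
    tid = jit.blockIdx.x * jit.blockDim.x + jit.threadIdx.x
    if tid < size:
        jit.atomic_max(out, 0, data[tid])

data = cp.random.randint(0, 1_000_000, size=1 << 20).astype(cp.int32)
out = cp.full(1, -2**31, dtype=cp.int32)  # smallest int32 as the starting value
block = 256
grid = (data.size + block - 1) // block
max_kernel((grid,), (block,), (data, out, cp.uint32(data.size)))
```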

### Warp-Level Operations

Efficient warp-level collective operations for high-performance algorithms.

```python { .api }
def shfl_sync(mask, var, srcLane, width=32):
    """Warp shuffle operation.

    Args:
        mask: Thread participation mask
        var: Variable to shuffle
        srcLane: Source lane index
        width: Warp width (power of 2, ≤32)

    Returns:
        Value from source lane
    """

def shfl_up_sync(mask, var, delta, width=32):
    """Warp shuffle up operation.

    Args:
        mask: Thread participation mask
        var: Variable to shuffle
        delta: Offset to source lane
        width: Warp width

    Returns:
        Value from lane (current - delta)
    """

def shfl_down_sync(mask, var, delta, width=32):
    """Warp shuffle down operation.

    Args:
        mask: Thread participation mask
        var: Variable to shuffle
        delta: Offset to source lane
        width: Warp width

    Returns:
        Value from lane (current + delta)
    """

def shfl_xor_sync(mask, var, laneMask, width=32):
    """Warp shuffle XOR operation.

    Args:
        mask: Thread participation mask
        var: Variable to shuffle
        laneMask: XOR mask for lane selection
        width: Warp width

    Returns:
        Value from lane (current ^ laneMask)
    """

def vote_all_sync(mask, predicate):
    """Test if predicate is true for all threads in mask.

    Args:
        mask: Thread participation mask
        predicate: Boolean expression to test

    Returns:
        bool: True if all threads have true predicate
    """

def vote_any_sync(mask, predicate):
    """Test if predicate is true for any thread in mask."""

def vote_uni_sync(mask, predicate):
    """Test if predicate has same value for all threads."""

def ballot_sync(mask, predicate):
    """Get ballot of predicate results across warp.

    Args:
        mask: Thread participation mask
        predicate: Boolean expression

    Returns:
        int: Bitmask of predicate results
    """

def activemask():
    """Get mask of currently active threads in warp.

    Returns:
        int: Bitmask of active threads
    """
```
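
Warp shuffles allow a reduction within a warp without touching shared memory: each step folds the upper half of the warp onto the lower half. A sketch of a warp-level sum, assuming `shfl_down_sync` and `atomic_add` are available through `cupyx.jit` as described above; the kernel name and launch sizes are illustrative.

```python
import cupy as cp
from cupyx import jit

@jit.rawkernel()
def warp_sum(data, out):
    # The launch below covers the array exactly, so no bounds check is needed.
    tid = jit.blockIdx.x * jit.blockDim.x + jit.threadIdx.x
    value = data[tid]

    # Fold the upper half of the warp onto the lower half, halving the span each step.
    value += jit.shfl_down_sync(0xffffffff, value, 16)
    value += jit.shfl_down_sync(0xffffffff, value, 8)
    value += jit.shfl_down_sync(0xffffffff, value, 4)
    value += jit.shfl_down_sync(0xffffffff, value, 2)
    value += jit.shfl_down_sync(0xffffffff, value, 1)

    # Lane 0 of every warp now holds that warp's partial sum.
    if jit.threadIdx.x % 32 == 0:
        jit.atomic_add(out, 0, value)

data = cp.random.rand(1 << 20).astype(cp.float32)
out = cp.zeros(1, dtype=cp.float32)
warp_sum((4096,), (256,), (data, out))
```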

### Mathematical Functions

GPU-optimized mathematical functions for kernel development.

```python { .api }
def fma(x, y, z):
    """Fused multiply-add: x * y + z with single rounding."""

def rsqrt(x):
    """Fast reciprocal square root: 1/sqrt(x)."""

def rcp(x):
    """Fast reciprocal: 1/x."""

def sin_pi(x):
    """Compute sin(π * x) accurately."""

def cos_pi(x):
    """Compute cos(π * x) accurately."""

def sincos(x):
    """Compute sin and cos simultaneously.

    Returns:
        tuple: (sin(x), cos(x))
    """

def exp2(x):
    """Base-2 exponential: 2^x."""

def log2(x):
    """Base-2 logarithm."""

def pow(x, y):
    """Power function: x^y."""

def sqrt(x):
    """Square root."""

def cbrt(x):
    """Cube root."""

def hypot(x, y):
    """Euclidean distance: sqrt(x^2 + y^2)."""

def remainder(x, y):
    """Floating point remainder."""

def fmod(x, y):
    """Floating point modulo."""

def copysign(x, y):
    """Copy sign of y to magnitude of x."""

def ldexp(x, exp):
    """Compute x * 2^exp."""

def frexp(x):
    """Extract mantissa and exponent.

    Returns:
        tuple: (mantissa, exponent)
    """
```
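
These functions correspond to the CUDA math library, so they are most often written directly into kernel source. The sketch below normalizes 2-D vectors inside an `ElementwiseKernel` using the standard single-precision CUDA intrinsics `fmaf` and `rsqrtf`; the kernel name is illustrative.

```python
import cupy as cp

# Normalize (x, y) pairs with a fused multiply-add and a fast reciprocal square root.
normalize_kernel = cp.ElementwiseKernel(
    'float32 x, float32 y',
    'float32 nx, float32 ny',
    '''
    float inv_len = rsqrtf(fmaf(x, x, y * y));
    nx = x * inv_len;
    ny = y * inv_len;
    ''',
    'normalize_2d'
)

x = cp.random.randn(1000000).astype(cp.float32)
y = cp.random.randn(1000000).astype(cp.float32)
nx, ny = normalize_kernel(x, y)
```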

## Usage Examples

### Element-wise Kernels

```python
import cupy as cp
from cupy import ElementwiseKernel

# Simple arithmetic kernel
add_kernel = ElementwiseKernel(
    'float32 x, float32 y',
    'float32 z',
    'z = x + y',
    'elementwise_add'
)

# Complex expression kernel
norm_kernel = ElementwiseKernel(
    'float32 x, float32 y',
    'float32 norm',
    'norm = sqrt(x * x + y * y)',
    'vector_norm'
)

# Multi-output kernel
polar_kernel = ElementwiseKernel(
    'float32 x, float32 y',
    'float32 r, float32 theta',
    '''
    r = sqrt(x * x + y * y);
    theta = atan2(y, x);
    ''',
    'cartesian_to_polar'
)

# Usage
x = cp.random.randn(1000000).astype(cp.float32)
y = cp.random.randn(1000000).astype(cp.float32)

result = add_kernel(x, y)
norms = norm_kernel(x, y)
r, theta = polar_kernel(x, y)
```

### Reduction Kernels

```python
import cupy as cp
from cupy import ReductionKernel

# Sum reduction
sum_kernel = ReductionKernel(
    'float32 x',
    'float32 y',
    'x',        # map expression
    'a + b',    # reduce expression
    'y = a',    # post-map expression
    '0',        # identity value
    'sum_reduction'
)

# Maximum reduction
max_kernel = ReductionKernel(
    'float32 x',
    'float32 y',
    'x',
    'max(a, b)',
    'y = a',
    '-INFINITY',
    'max_reduction'
)

# Variance reduction
variance_kernel = ReductionKernel(
    'float32 x, float32 mean',
    'float32 var',
    '(x - mean) * (x - mean)',
    'a + b',
    'var = a / (_in_ind.size() - 1)',
    '0',
    'variance_reduction'
)

# Usage
data = cp.random.randn(1000000).astype(cp.float32)
total = sum_kernel(data)
maximum = max_kernel(data)
mean_val = cp.mean(data)
var_val = variance_kernel(data, mean_val)
```

673

674

### Raw CUDA Kernels

675

676

```python
import cupy as cp
from cupy import RawKernel

# Matrix multiplication kernel
matmul_code = '''
extern "C" __global__ void matmul(
    const float* A, const float* B, float* C,
    int M, int N, int K
) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;

    if (row < M && col < N) {
        float sum = 0.0f;
        for (int k = 0; k < K; k++) {
            sum += A[row * K + k] * B[k * N + col];
        }
        C[row * N + col] = sum;
    }
}
'''

matmul_kernel = RawKernel(matmul_code, 'matmul')

# Optimized reduction kernel
reduction_code = '''
extern "C" __global__ void block_reduce_sum(
    const float* input, float* output, int n
) {
    extern __shared__ float shared_data[];

    int tid = threadIdx.x;
    int i = blockIdx.x * blockDim.x + threadIdx.x;

    // Load data into shared memory
    shared_data[tid] = (i < n) ? input[i] : 0.0f;
    __syncthreads();

    // Reduction in shared memory
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) {
            shared_data[tid] += shared_data[tid + s];
        }
        __syncthreads();
    }

    // Write result
    if (tid == 0) {
        output[blockIdx.x] = shared_data[0];
    }
}
'''

reduce_kernel = RawKernel(reduction_code, 'block_reduce_sum')

# Usage
A = cp.random.randn(512, 256).astype(cp.float32)
B = cp.random.randn(256, 128).astype(cp.float32)
C = cp.zeros((512, 128), dtype=cp.float32)

# Launch matrix multiplication (M=512, N=128, K=256);
# scalar arguments are passed as 32-bit values to match the `int` parameters
block_size = (16, 16)
grid_size = ((128 + block_size[0] - 1) // block_size[0],
             (512 + block_size[1] - 1) // block_size[1])

matmul_kernel(grid_size, block_size,
              (A, B, C, cp.int32(512), cp.int32(128), cp.int32(256)))
```

### JIT Compilation with Decorators

```python
import cupy as cp
from cupyx import jit

@jit.rawkernel()
def saxpy_kernel(a, x, y, n):
    """SAXPY: y = a*x + y"""
    tid = jit.grid(1)
    if tid < n:
        y[tid] = a * x[tid] + y[tid]

@jit.rawkernel()
def transpose_kernel(input_mat, output_mat, width, height):
    """Matrix transpose kernel"""
    col = jit.blockIdx.x * jit.blockDim.x + jit.threadIdx.x
    row = jit.blockIdx.y * jit.blockDim.y + jit.threadIdx.y

    if col < width and row < height:
        output_mat[col * height + row] = input_mat[row * width + col]

@jit.rawkernel()
def stencil_kernel(input_arr, output_arr, width, height):
    """5-point stencil computation"""
    col = jit.blockIdx.x * jit.blockDim.x + jit.threadIdx.x
    row = jit.blockIdx.y * jit.blockDim.y + jit.threadIdx.y

    if 1 <= col < width - 1 and 1 <= row < height - 1:
        idx = row * width + col
        result = (input_arr[idx] +
                  input_arr[idx - 1] + input_arr[idx + 1] +
                  input_arr[idx - width] + input_arr[idx + width]) * 0.2
        output_arr[idx] = result

# Usage
n = 1000000
a = cp.float32(2.5)
x = cp.random.randn(n).astype(cp.float32)
y = cp.random.randn(n).astype(cp.float32)

# Launch SAXPY kernel (scalar arguments passed as 32-bit values)
block_size = 256
grid_size = (n + block_size - 1) // block_size
saxpy_kernel((grid_size,), (block_size,), (a, x, y, cp.int32(n)))
```
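
Kernels decorated with `jit.rawkernel()` also accept a Numba-style bracket launch in addition to the RawKernel-style call shown above. A brief sketch reusing `saxpy_kernel` and the arrays defined above; treat the bracket form as an assumption about the `cupyx.jit` interface rather than part of this specification.

```python
# Bracket launch form: grid and block sizes go in the subscription,
# kernel arguments in the call.
saxpy_kernel[grid_size, block_size](a, x, y, cp.int32(n))
```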

### Shared Memory and Atomics

```python
import cupy as cp
from cupyx import jit

@jit.rawkernel()
def histogram_kernel(data, hist, n, num_bins):
    """Compute histogram using shared memory and atomics"""
    # Shared memory for local histogram
    shared_hist = jit.shared_memory(cp.int32, 256)

    tid = jit.threadIdx.x
    bid = jit.blockIdx.x

    # Initialize shared memory
    if tid < num_bins:
        shared_hist[tid] = 0
    jit.syncthreads()

    # Process data elements
    idx = bid * jit.blockDim.x + tid
    while idx < n:
        bin_idx = int(data[idx] * num_bins)
        if 0 <= bin_idx < num_bins:
            jit.atomic_add(shared_hist, bin_idx, 1)
        idx += jit.gridDim.x * jit.blockDim.x

    jit.syncthreads()

    # Reduce to global histogram
    if tid < num_bins:
        jit.atomic_add(hist, tid, shared_hist[tid])

@jit.rawkernel()
def prefix_sum_kernel(data, result, n):
    """Parallel prefix sum using shared memory"""
    shared_data = jit.shared_memory(cp.float32, 512)

    tid = jit.threadIdx.x
    bid = jit.blockIdx.x
    block_size = jit.blockDim.x

    # Load data
    idx = bid * block_size + tid
    shared_data[tid] = data[idx] if idx < n else 0.0
    jit.syncthreads()

    # Up-sweep phase
    offset = 1
    while offset < block_size:
        if (tid + 1) % (2 * offset) == 0:
            shared_data[tid] += shared_data[tid - offset]
        offset *= 2
        jit.syncthreads()

    # Down-sweep phase
    if tid == block_size - 1:
        shared_data[tid] = 0.0

    offset = block_size // 2
    while offset > 0:
        jit.syncthreads()
        if (tid + 1) % (2 * offset) == 0:
            temp = shared_data[tid - offset]
            shared_data[tid - offset] = shared_data[tid]
            shared_data[tid] += temp
        offset //= 2

    jit.syncthreads()

    # Store result
    if idx < n:
        result[idx] = shared_data[tid]

# Usage examples
data = cp.random.rand(1000000).astype(cp.float32)
hist = cp.zeros(256, dtype=cp.int32)

# Compute histogram
block_size = 256
grid_size = 128
histogram_kernel((grid_size,), (block_size,),
                 (data, hist, cp.int32(len(data)), cp.int32(256)))

# Prefix sum
prefix_result = cp.zeros_like(data)
prefix_sum_kernel((grid_size,), (block_size,),
                  (data, prefix_result, cp.int32(len(data))))
```

### Performance Optimization Techniques

```python
import cupy as cp
from cupyx import jit

@jit.rawkernel()
def optimized_gemm_kernel(A, B, C, M, N, K, tile_size):
    """Optimized matrix multiplication with tiling"""
    # Shared memory tiles
    tile_A = jit.shared_memory(cp.float32, (16, 16))
    tile_B = jit.shared_memory(cp.float32, (16, 16))

    # Thread and block indices
    tx, ty = jit.threadIdx.x, jit.threadIdx.y
    bx, by = jit.blockIdx.x, jit.blockIdx.y

    # Calculate output position
    row = by * tile_size + ty
    col = bx * tile_size + tx

    result = 0.0

    # Tile across K dimension
    for tile in range((K + tile_size - 1) // tile_size):
        # Load tile into shared memory
        if row < M and tile * tile_size + tx < K:
            tile_A[ty, tx] = A[row * K + tile * tile_size + tx]
        else:
            tile_A[ty, tx] = 0.0

        if col < N and tile * tile_size + ty < K:
            tile_B[ty, tx] = B[(tile * tile_size + ty) * N + col]
        else:
            tile_B[ty, tx] = 0.0

        jit.syncthreads()

        # Compute partial result
        for k in range(tile_size):
            result += tile_A[ty, k] * tile_B[k, tx]

        jit.syncthreads()

    # Store result
    if row < M and col < N:
        C[row * N + col] = result

# Memory coalescing example
@jit.rawkernel()
def coalesced_transpose(input_mat, output_mat, width, height, tile_size):
    """Memory-coalesced matrix transpose"""
    tile = jit.shared_memory(cp.float32, (32, 33))  # +1 to avoid bank conflicts

    x = jit.blockIdx.x * tile_size + jit.threadIdx.x
    y = jit.blockIdx.y * tile_size + jit.threadIdx.y

    # Load tile with coalesced access
    if x < width and y < height:
        tile[jit.threadIdx.y, jit.threadIdx.x] = input_mat[y * width + x]

    jit.syncthreads()

    # Transpose coordinates for output
    x = jit.blockIdx.y * tile_size + jit.threadIdx.x
    y = jit.blockIdx.x * tile_size + jit.threadIdx.y

    # Store with coalesced access
    if x < height and y < width:
        output_mat[y * height + x] = tile[jit.threadIdx.x, jit.threadIdx.y]

# Usage with performance considerations
M, N, K = 2048, 2048, 2048
A = cp.random.randn(M, K).astype(cp.float32)
B = cp.random.randn(K, N).astype(cp.float32)
C = cp.zeros((M, N), dtype=cp.float32)

# Optimized GEMM launch (scalar arguments passed as 32-bit values)
tile_size = 16
grid_dim = ((N + tile_size - 1) // tile_size, (M + tile_size - 1) // tile_size)
block_dim = (tile_size, tile_size)

optimized_gemm_kernel(grid_dim, block_dim,
                      (A, B, C, cp.int32(M), cp.int32(N), cp.int32(K), cp.int32(tile_size)))
```