0
# Code Generation
1
2
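This section covers the framework for generating efficient Java code for expression evaluation and query execution in Spark Catalyst. Generating and compiling specialized Java code at runtime avoids the overhead of interpreting expression trees row by row, which enables high-performance query processing.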
3
4
## Core Imports
5
6
```scala
7
import org.apache.spark.sql.catalyst.expressions.codegen._
8
import org.apache.spark.sql.catalyst.expressions._
9
import org.apache.spark.sql.types._
10
```
11
12
## CodegenContext
13
14
The context object that manages code generation state and utilities.
15
16
```scala { .api }
17
class CodegenContext {
18
def freshName(name: String): String
19
def addReferenceObj(objName: String, obj: Any, className: String = null): String
20
def addMutableState(javaType: String, variableName: String, initFunc: String = ""): String
21
def addNewFunction(funcName: String, funcCode: String, inlineToOuterClass: Boolean = false): String
22
def declareMutableStates(): String
23
def initMutableStates(): String
24
def declareAddedFunctions(): String
25
26
// Variable management
27
def freshVariable(name: String, dt: DataType): String
28
def INPUT_ROW: String
29
def currentVars: Seq[ExprCode]
30
def setCurrentVars(currentVars: Seq[ExprCode]): Unit
31
32
// Utility methods
33
def javaType(dt: DataType): String
34
def defaultValue(dt: DataType): String
35
def boxedType(dt: DataType): String
36
def getValue(input: String, dt: DataType, ordinal: String): String
37
def setValue(input: String, dt: DataType, ordinal: String, value: String): String
38
def isNullVar(input: String, ordinal: String): String
39
def setNullVar(input: String, ordinal: String, isNull: String): String
40
}
41
```
42
43
### Usage Example
44
45
```scala
46
import org.apache.spark.sql.catalyst.expressions.codegen._
47
48
// Create code generation context
val ctx = new CodegenContext()

// Generate fresh variable names
val varName = ctx.freshName("value")
val nullVar = ctx.freshName("isNull")

// Add mutable state for caching; the last argument is the Java statement
// passed as `initFunc`
val cacheVar = ctx.addMutableState("java.util.Map", "cache",
  "cache = new java.util.HashMap();")

// Generate Java type names
val javaType = ctx.javaType(IntegerType)   // "int"
val boxedType = ctx.boxedType(IntegerType) // "Integer" (simple class name, not fully qualified)
62
```
63
64
## ExprCode
65
66
Represents generated code for expression evaluation.
67
68
```scala { .api }
69
/**
 * Generated Java code for evaluating one expression.
 *
 * @param code   Java statements that compute the expression (may be empty)
 * @param isNull Java expression or variable name for the result's null flag
 * @param value  Java expression or variable name holding the result
 */
case class ExprCode(code: String, isNull: String, value: String) {
  /** Returns a copy with `code` replaced by `newCode`; `isNull` and `value` are kept. */
  def copyWithCode(newCode: String): ExprCode = copy(code = newCode)
}

object ExprCode {
  /**
   * ExprCode for a value that is always null: no statements, the null flag
   * fixed to "true", and the type's default literal as a placeholder value.
   *
   * No CodegenContext is in scope inside this companion object, so the default
   * literal comes from the context-free `CodeGenerator.defaultValue` helper.
   */
  def forNullValue(dataType: DataType): ExprCode =
    ExprCode("", "true", CodeGenerator.defaultValue(dataType))

  /** ExprCode for a value that is never null: no statements, null flag fixed to "false". */
  def forNonNullValue(value: String): ExprCode = ExprCode("", "false", value)
}
77
```
78
79
### Usage Example
80
81
```scala
82
// Generate code for literal expression
83
val literalCode = ExprCode(
84
code = "",
85
isNull = "false",
86
value = "42"
87
)
88
89
// Generate code for column access.
// Every name interpolated into the Java source below is defined in this
// snippet, so it can be read (and compiled) on its own.
val ctx = new CodegenContext()
val ordinal = 0                           // index of the column to read
val varName = ctx.freshName("value")
val javaType = ctx.javaType(IntegerType)
val inputRow = ctx.INPUT_ROW              // INPUT_ROW is a member of the context, not a free variable

val columnCode = ExprCode(
  code = s"$javaType $varName = $inputRow.getInt($ordinal);",
  isNull = s"$inputRow.isNullAt($ordinal)",
  value = varName
)
95
```
96
97
## CodeGenerator
98
99
Base trait for code generation functionality.
100
101
```scala { .api }
102
trait CodeGenerator[InType <: AnyRef, OutType <: AnyRef] {
103
def generate(expressions: InType): OutType
104
def create(references: Array[Any]): OutType
105
def newCodeGenContext(): CodegenContext
106
def canonicalize(in: InType): InType
107
}
108
```
109
110
## Expression Code Generation
111
112
### CodegenSupport
113
114
Trait for expressions that support code generation.
115
116
```scala { .api }
117
/** Mixin for expressions that can emit Java code for their own evaluation. */
trait CodegenSupport extends Expression {
  def genCode(ctx: CodegenContext): ExprCode
  def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode

  /**
   * Runs `doGenCode` and returns `ev` carrying the generated statements.
   * The null flag produced by `doGenCode` is kept only when this expression
   * is nullable; otherwise it is pinned to the literal "false".
   */
  def genCodeWithNull(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val generated = doGenCode(ctx, ev)
    val nullFlag = if (nullable) generated.isNull else "false"
    ev.copy(code = generated.code, isNull = nullFlag)
  }
}
131
```
132
133
### Expression Code Generation Examples
134
135
```scala
136
// Literal expression code generation
case class Literal(value: Any, dataType: DataType) extends LeafExpression with CodegenSupport {

  /**
   * Emits the literal inline: no statements, a constant null flag, and the
   * value rendered as Java source text.
   *
   * - null    -> null flag "true" plus the type's default value
   * - String  -> quoted Java string literal with special characters escaped
   * - Long    -> decimal digits with an `L` suffix, so values outside the int
   *              range are still valid Java literals
   * - other   -> `toString` of the value
   */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = value match {
    case null =>
      ExprCode("", "true", ctx.defaultValue(dataType))
    case s: String =>
      ExprCode("", "false", "\"" + escapeJava(s) + "\"")
    case l: Long =>
      ExprCode("", "false", s"${l}L")
    case other =>
      ExprCode("", "false", other.toString)
  }

  /**
   * Escapes the characters that would otherwise terminate or corrupt a Java
   * string literal when the text is pasted into generated source.
   */
  private def escapeJava(s: String): String = s.flatMap {
    case '\\' => "\\\\"
    case '"'  => "\\\""
    case '\n' => "\\n"
    case '\r' => "\\r"
    case '\t' => "\\t"
    case c    => c.toString
  }
}
151
152
// Binary arithmetic expression code generation
case class Add(left: Expression, right: Expression) extends BinaryArithmetic with CodegenSupport {

  /**
   * Emits the code of both operands, declares the result variable with the
   * type's default value, and performs the addition only when neither
   * operand is null.
   */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val lhs = left.genCode(ctx)
    val rhs = right.genCode(ctx)
    val declaredType = ctx.javaType(dataType)
    val initial = ctx.defaultValue(dataType)

    ev.copy(code = s"""
${lhs.code}
${rhs.code}
$declaredType ${ev.value} = $initial;
boolean ${ev.isNull} = ${lhs.isNull} || ${rhs.isNull};
if (!${ev.isNull}) {
${ev.value} = ${lhs.value} + ${rhs.value};
}
""")
  }
}
171
```
172
173
## Projection Code Generation
174
175
### GenerateUnsafeProjection
176
177
Generates efficient projection code for transforming rows.
178
179
```scala { .api }
180
object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafeProjection] {
181
def generate(expressions: Seq[Expression]): UnsafeProjection
182
def generate(expressions: Seq[Expression], inputSchema: Seq[Attribute]): UnsafeProjection
183
}
184
185
abstract class UnsafeProjection extends Projection {
186
def apply(row: InternalRow): UnsafeRow
187
}
188
```
189
190
### Usage Example
191
192
```scala
193
import org.apache.spark.sql.catalyst.expressions.codegen._
194
195
// Generate projection for expressions.
// Generated code reads input columns by ordinal, so the expressions must be
// bound (BoundReference); an UnresolvedAttribute has no data type or ordinal
// and cannot be code-generated.
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Add, BoundReference, Literal}
import org.apache.spark.sql.types.{IntegerType, StringType}
import org.apache.spark.unsafe.types.UTF8String

val expressions = Seq(
  BoundReference(0, StringType, nullable = true),                   // column 0: name
  Add(BoundReference(1, IntegerType, nullable = true), Literal(1))  // column 1: age + 1
)

val projection = GenerateUnsafeProjection.generate(expressions)

// Apply projection to row; string columns in an InternalRow are UTF8String,
// not java.lang.String
val inputRow = InternalRow(UTF8String.fromString("Alice"), 25)
val outputRow = projection.apply(inputRow)
206
```
207
208
## Predicate Code Generation
209
210
### GeneratePredicate
211
212
Generates efficient predicate evaluation code.
213
214
```scala { .api }
215
object GeneratePredicate extends CodeGenerator[Expression, Predicate] {
216
def generate(expression: Expression): Predicate
217
def generate(expression: Expression, inputSchema: Seq[Attribute]): Predicate
218
}
219
220
abstract class Predicate {
221
def eval(row: InternalRow): Boolean
222
}
223
```
224
225
### Usage Example
226
227
```scala
228
// Generate predicate for filter condition.
// The column is referenced by ordinal (BoundReference) because an
// UnresolvedAttribute cannot be code-generated.
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{BoundReference, GreaterThan, Literal}
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.unsafe.types.UTF8String

val age = BoundReference(1, IntegerType, nullable = true) // column 1: age
val condition = GreaterThan(age, Literal(18))
val predicate = GeneratePredicate.generate(condition)

// Evaluate predicate; string columns in an InternalRow are UTF8String
val row = InternalRow(UTF8String.fromString("Alice"), 25)
val result = predicate.eval(row) // true
235
```
236
237
## Ordering Code Generation
238
239
### GenerateOrdering
240
241
Generates efficient row comparison code for sorting.
242
243
```scala { .api }
244
object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalRow]] {
245
def generate(ordering: Seq[SortOrder]): Ordering[InternalRow]
246
def generate(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]): Ordering[InternalRow]
247
}
248
```
249
250
### Usage Example
251
252
```scala
253
// Generate ordering for sort operation.
// Sort keys are bound by ordinal (BoundReference) because an
// UnresolvedAttribute cannot be code-generated.
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Ascending, BoundReference, Descending, SortOrder}
import org.apache.spark.sql.types.{IntegerType, StringType}
import org.apache.spark.unsafe.types.UTF8String

val sortOrders = Seq(
  SortOrder(BoundReference(0, StringType, nullable = true), Ascending),   // column 0: name
  SortOrder(BoundReference(1, IntegerType, nullable = true), Descending)  // column 1: age
)

val ordering = GenerateOrdering.generate(sortOrders)

// Use for sorting; string columns in an InternalRow are UTF8String
val rows = Seq(
  InternalRow(UTF8String.fromString("Bob"), 30),
  InternalRow(UTF8String.fromString("Alice"), 25)
)
val sortedRows = rows.sorted(ordering)
264
```
265
266
## Aggregation Code Generation
267
268
### HashAggregateExec Code Generation
269
270
```scala { .api }
271
// Generated hash aggregation code structure
trait HashAggregateExec {

  /**
   * Java statements spliced into the body of the generated `while` loop in
   * `genCode`, where the current input row is available as `agg_row`.
   * Declared here because `genCode` calls it; implementations supply the
   * aggregation logic.
   */
  protected def generateAggregationCode(): String

  /**
   * Returns the Java source of a `BufferedRowIterator` subclass whose
   * `processNext` loops over `inputIterator` and runs the aggregation code
   * for each row.
   */
  def genCode(): String = {
    s"""
public class GeneratedHashAggregation extends org.apache.spark.sql.execution.BufferedRowIterator {
private boolean agg_initAgg;
private java.util.HashMap agg_hashMap;

protected void processNext() throws java.io.IOException {
while (inputIterator.hasNext()) {
InternalRow agg_row = (InternalRow) inputIterator.next();
// Generated aggregation logic
${generateAggregationCode()}
}
}
}
"""
  }
}
290
```
291
292
## Code Generation Utilities
293
294
### Java Code Templates
295
296
```scala { .api }
297
/** String templates for reading and writing a column of a row in generated Java code. */
object CodeGenUtils {

  /**
   * Java expression that reads column `ordinal` of the row named `input`.
   *
   * @throws IllegalArgumentException for a data type with no template here
   */
  def genGetValue(input: String, dataType: DataType, ordinal: String): String = {
    dataType match {
      case IntegerType => s"$input.getInt($ordinal)"
      case StringType => s"$input.getUTF8String($ordinal)"
      case BooleanType => s"$input.getBoolean($ordinal)"
      // ... other types
      case other =>
        // Fail with a descriptive message instead of an opaque MatchError.
        throw new IllegalArgumentException(s"No generated getter for data type: $other")
    }
  }

  /**
   * Java expression that writes `value` into column `ordinal` of the row
   * named `input`.
   *
   * @throws IllegalArgumentException for a data type with no template here
   */
  def genSetValue(input: String, dataType: DataType, ordinal: String, value: String): String = {
    dataType match {
      case IntegerType => s"$input.setInt($ordinal, $value)"
      case StringType => s"$input.update($ordinal, $value)"
      // ... other types
      case other =>
        // Fail with a descriptive message instead of an opaque MatchError.
        throw new IllegalArgumentException(s"No generated setter for data type: $other")
    }
  }
}
315
```
316
317
### Performance Optimizations
318
319
```scala
320
// Generate optimized loops
/**
 * Java `for` loop that counts `loopVar` from 0 up to (excluding) `bound` and
 * runs `body` on each iteration.
 *
 * @param ctx     code generation context (not used by this template)
 * @param loopVar name of the int loop variable declared by the loop
 * @param body    Java statements placed inside the loop
 * @param bound   Java expression for the exclusive upper bound; defaults to
 *                "numRows", which must then be defined by the surrounding
 *                generated code
 */
def generateLoop(
    ctx: CodegenContext,
    loopVar: String,
    body: String,
    bound: String = "numRows"): String = {
  s"""
for (int $loopVar = 0; $loopVar < $bound; $loopVar++) {
$body
}
"""
}
328
329
// Generate a Java conditional (ternary) expression.
// NOTE: despite the function name, a ternary is still a conditional; whether
// it executes without a branch depends on how the JIT compiles it. The
// benefit shown here is that the result is a single expression that can be
// embedded inside other generated expressions.
/**
 * @param condition  Java boolean expression; it is wrapped in parentheses
 * @param trueValue  Java expression used when the condition holds
 * @param falseValue Java expression used otherwise
 * @return the text `(condition) ? trueValue : falseValue`
 */
def generateBranchFreeCode(condition: String, trueValue: String, falseValue: String): String = {
  s"($condition) ? $trueValue : $falseValue"
}
333
```
334
335
## Error Handling in Generated Code
336
337
```scala
338
/**
 * Wraps `tryCode` in a Java try/catch that catches `exceptionClass` and
 * rethrows it as a RuntimeException carrying the original as its cause.
 *
 * @param ctx            code generation context (not used by this template)
 * @param tryCode        Java statements to guard
 * @param exceptionClass Java exception type named in the catch clause
 */
def generateTryCatch(ctx: CodegenContext, tryCode: String, exceptionClass: String): String = {
  // Leading and trailing empty entries reproduce the blank first and last
  // lines of the template.
  val javaLines = Seq(
    "",
    "try {",
    tryCode,
    s"} catch ($exceptionClass e) {",
    "throw new RuntimeException(\"Error in generated code\", e);",
    "}",
    ""
  )
  javaLines.mkString("\n")
}
347
```
348
349
## Complete Example
350
351
```scala
352
import org.apache.spark.sql.catalyst.expressions._
353
import org.apache.spark.sql.catalyst.expressions.codegen._
354
355
// Define a custom expression with code generation
/** Doubles its numeric child; the result has the child's data type. */
case class MultiplyByTwo(child: Expression) extends UnaryExpression with CodegenSupport {
  override def dataType: DataType = child.dataType

  /**
   * Emits the child's code, copies its null flag, and doubles the child's
   * value when it is not null.
   */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val childGen = child.genCode(ctx)
    val javaType = ctx.javaType(dataType)

    // The explicit cast keeps the assignment valid for byte/short results:
    // in Java, `x * 2` is promoted to int and cannot be assigned back to a
    // narrower type without a cast.
    val code = s"""
${childGen.code}
boolean ${ev.isNull} = ${childGen.isNull};
$javaType ${ev.value} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
${ev.value} = ($javaType) (${childGen.value} * 2);
}
"""
    ev.copy(code = code)
  }

  /**
   * Interpreted evaluation. The result keeps the input's numeric type so it
   * agrees with `dataType` and with the generated code; previously every
   * input was truncated to Int.
   */
  override def nullSafeEval(input: Any): Any = input match {
    case b: Byte   => (b * 2).toByte
    case s: Short  => (s * 2).toShort
    case i: Int    => i * 2
    case l: Long   => l * 2
    case f: Float  => f * 2
    case d: Double => d * 2
    // Any other input keeps the original behavior (int arithmetic, or a
    // ClassCastException for non-numeric values).
    case other     => other.asInstanceOf[Number].intValue() * 2
  }
}
379
380
// Use the expression with code generation.
// The child is bound to input column 0 (BoundReference): an
// UnresolvedAttribute has no data type or ordinal and cannot be code-generated.
import org.apache.spark.sql.types.IntegerType

val expr = MultiplyByTwo(BoundReference(0, IntegerType, nullable = true))
val projection = GenerateUnsafeProjection.generate(Seq(expr))
383
```
384
385
The code generation framework enables Catalyst to produce highly optimized Java code that rivals hand-written implementations, providing significant performance improvements for query execution.