Catalyst query optimization framework and expression evaluation engine for Apache Spark SQL
—
This section covers the Spark Catalyst framework that generates Java source code for expression evaluation and query execution. Compiling expressions to Java, instead of interpreting the expression tree row by row, is what enables high-performance query processing.
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._

// The context object that manages code generation state and utilities.
/**
 * Code-generation context: tracks naming, class-level state and helper
 * functions while Java source is being assembled.
 *
 * NOTE(review): only member signatures are listed here; the grouping notes
 * below are inferred from member names and should be confirmed against the
 * implementation.
 */
class CodegenContext {
  // --- Name generation ---
  def freshName(name: String): String
  def freshVariable(name: String, dt: DataType): String

  // --- Input row and the variables bound to its columns ---
  def INPUT_ROW: String
  def currentVars: Seq[ExprCode]
  def setCurrentVars(currentVars: Seq[ExprCode]): Unit

  // --- Members added to the generated class ---
  // Each returns a String (presumably the name actually registered — confirm).
  def addReferenceObj(objName: String, obj: Any, className: String = null): String
  def addMutableState(javaType: String, variableName: String, initFunc: String = ""): String
  def addNewFunction(funcName: String, funcCode: String, inlineToOuterClass: Boolean = false): String

  // --- Rendering the accumulated members as Java source ---
  def declareMutableStates(): String
  def initMutableStates(): String
  def declareAddedFunctions(): String

  // --- Type-name and row-access helpers (return Java source text) ---
  def javaType(dt: DataType): String
  def boxedType(dt: DataType): String
  def defaultValue(dt: DataType): String
  def getValue(input: String, dt: DataType, ordinal: String): String
  def setValue(input: String, dt: DataType, ordinal: String, value: String): String
  def isNullVar(input: String, ordinal: String): String
  def setNullVar(input: String, ordinal: String, isNull: String): String
}
import org.apache.spark.sql.catalyst.expressions.codegen._
// A fresh context that the following snippets share.
val ctx: CodegenContext = new CodegenContext()

// Reserve Java identifiers. The call order is kept unchanged in case name
// allocation depends on it.
val varName: String = ctx.freshName("value")
val nullVar: String = ctx.freshName("isNull")

// Register a class-level field used as a cache.
// NOTE(review): the init code names the field `cache` literally; if
// addMutableState returns a different (fresh) name, the init code should use
// that returned name instead — confirm against the implementation.
val cacheVar: String = ctx.addMutableState(
  javaType = "java.util.Map",
  variableName = "cache",
  initFunc = "cache = new java.util.HashMap();")

// Catalyst type -> Java type names.
val javaType: String = ctx.javaType(IntegerType)   // "int"
val boxedType: String = ctx.boxedType(IntegerType) // "java.lang.Integer"
// Represents generated code for expression evaluation.
/**
 * Generated Java code for one expression.
 *
 * @param code   Java statements that must be emitted before `isNull`/`value` are read
 * @param isNull Java boolean expression or variable name holding the null flag
 * @param value  Java expression or variable name holding the result
 */
case class ExprCode(code: String, isNull: String, value: String) {
  /** Returns a copy with `code` replaced, keeping `isNull` and `value`. */
  def copyWithCode(newCode: String): ExprCode = copy(code = newCode)
}

object ExprCode {
  /**
   * A constant-null `ExprCode`: no code, `isNull` is the literal `true`, and
   * `value` is the default value of `dataType`.
   *
   * No context is in scope in this object, so the context is a parameter.
   * It defaults to a new `CodegenContext` so the one-argument form
   * `forNullValue(dataType)` keeps working.
   */
  def forNullValue(dataType: DataType, ctx: CodegenContext = new CodegenContext()): ExprCode =
    ExprCode(code = "", isNull = "true", value = ctx.defaultValue(dataType))

  /** An `ExprCode` for a value known to be non-null; needs no setup code. */
  def forNonNullValue(value: String): ExprCode =
    ExprCode(code = "", isNull = "false", value = value)
}
// Generate code for literal expression
// A literal needs no setup code and is never null.
val literalCode = ExprCode(
  code = "",
  isNull = "false",
  value = "42"
)
// Column access: read an int column from the input row. INPUT_ROW is a
// member of the context, so it is reached through `ctx`.
val ordinal = 0 // column position to read (example value)
val columnCode = ExprCode(
  code = s"$javaType $varName = ${ctx.INPUT_ROW}.getInt($ordinal);",
  isNull = s"${ctx.INPUT_ROW}.isNullAt($ordinal)",
  value = varName
)
// Base trait for code generation functionality.
/**
 * Turns an input of type `InType` (e.g. expressions or sort orders) into a
 * generated evaluator of type `OutType`.
 */
trait CodeGenerator[InType <: AnyRef, OutType <: AnyRef] {
  /** Returns a canonical form of the input (purpose presumably caching/dedup — confirm). */
  def canonicalize(in: InType): InType

  /** Creates the context used while generating code. */
  def newCodeGenContext(): CodegenContext

  /** Generates an evaluator for the given input. */
  def generate(expressions: InType): OutType

  /** Builds an evaluator from `references`. */
  def create(references: Array[Any]): OutType
}
// Trait for expressions that support code generation.
/** Mixin for expressions that can emit Java source for their own evaluation. */
trait CodegenSupport extends Expression {
  /** Generates this expression's code. */
  def genCode(ctx: CodegenContext): ExprCode

  /**
   * Expression-specific code generation. Implementations may fill in `ev`
   * (e.g. `ev.copy(code = ...)`) or return a brand-new `ExprCode`.
   */
  def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode

  /**
   * Runs [[doGenCode]] and normalizes the null flag: for a non-nullable
   * expression `isNull` is forced to the literal `"false"`.
   *
   * The result keeps the `value` produced by `doGenCode`. Implementations
   * such as a literal return a fresh `ExprCode` instead of a copy of `ev`,
   * so falling back to `ev.value` would drop their result.
   */
  def genCodeWithNull(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val eval = doGenCode(ctx, ev)
    if (nullable) eval else eval.copy(isNull = "false")
  }
}
// Literal expression code generation
case class Literal(value: Any, dataType: DataType) extends LeafExpression with CodegenSupport {
  /**
   * A literal needs no setup code. `isNull` is `"true"` or `"false"`.
   * `value` is a Java source literal, or the type's default value for null.
   */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
    if (value == null) ExprCode("", "true", ctx.defaultValue(dataType))
    else ExprCode("", "false", javaLiteral(value))

  /**
   * Renders `v` as Java source text.
   *
   * - Strings are quoted and escaped, so quotes, backslashes and control
   *   characters cannot break, or inject into, the generated code.
   * - `Long` and `Float` get their Java suffixes: without `L`, a long
   *   outside the int range does not compile.
   * - Non-finite floating-point values map to the JDK constants, because
   *   `toString` yields `NaN`/`Infinity`, which are not Java literals.
   */
  private def javaLiteral(v: Any): String = v match {
    case s: String => "\"" + escapeJavaString(s) + "\""
    case l: Long => s"${l}L"
    case f: Float =>
      if (f.isNaN) "Float.NaN"
      else if (f.isPosInfinity) "Float.POSITIVE_INFINITY"
      else if (f.isNegInfinity) "Float.NEGATIVE_INFINITY"
      else s"${f}f"
    case d: Double =>
      if (d.isNaN) "Double.NaN"
      else if (d.isPosInfinity) "Double.POSITIVE_INFINITY"
      else if (d.isNegInfinity) "Double.NEGATIVE_INFINITY"
      else d.toString
    case other => other.toString
  }

  /** Escapes `s` for use between double quotes in a Java string literal. */
  private def escapeJavaString(s: String): String = {
    val sb = new StringBuilder(s.length + 8)
    s.foreach {
      case '"' => sb.append("\\\"")
      case '\\' => sb.append("\\\\")
      case '\n' => sb.append("\\n")
      case '\r' => sb.append("\\r")
      case '\t' => sb.append("\\t")
      // Octal escapes are decoded by the Java lexer itself. Java's
      // backslash-u escapes are decoded before lexing and could turn back
      // into a raw line break.
      case c if c < ' ' => sb.append('\\').append(f"${c.toInt}%03o")
      case c => sb.append(c)
    }
    sb.toString
  }
}
// Binary arithmetic expression code generation
case class Add(left: Expression, right: Expression) extends BinaryArithmetic with CodegenSupport {
  /**
   * Emits both children's code, then computes `left + right` only when
   * neither side is null. Otherwise the result keeps the type's default value.
   *
   * Java promotes `byte`/`short` operands of `+` to `int`, so for those
   * types the sum is cast back to the result type. Without the cast the
   * generated assignment does not compile.
   */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val leftGen = left.genCode(ctx)
    val rightGen = right.genCode(ctx)
    val resultType = ctx.javaType(dataType)
    val sum = dataType match {
      case ByteType | ShortType => s"($resultType) (${leftGen.value} + ${rightGen.value})"
      case _ => s"${leftGen.value} + ${rightGen.value}"
    }
    val code = s"""
      ${leftGen.code}
      ${rightGen.code}
      $resultType ${ev.value} = ${ctx.defaultValue(dataType)};
      boolean ${ev.isNull} = ${leftGen.isNull} || ${rightGen.isNull};
      if (!${ev.isNull}) {
        ${ev.value} = $sum;
      }
    """
    ev.copy(code = code)
  }
}
// Generates efficient projection code for transforming rows.
/** A generated projection whose output rows are `UnsafeRow`s. */
abstract class UnsafeProjection extends Projection {
  /** Projects `row` into an `UnsafeRow`. */
  def apply(row: InternalRow): UnsafeRow
}

/** Code generator that compiles a list of expressions into an [[UnsafeProjection]]. */
object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafeProjection] {
  /** Variant that also takes the input schema (presumably used for binding — confirm). */
  def generate(expressions: Seq[Expression], inputSchema: Seq[Attribute]): UnsafeProjection

  /** Generates a projection for the given expressions. */
  def generate(expressions: Seq[Expression]): UnsafeProjection
}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.unsafe.types.UTF8String

// Generate a projection for expressions. Generated code reads columns by
// position, and an UnresolvedAttribute cannot be code-generated, so each
// column is referenced through BoundReference(ordinal, type, nullable).
val expressions = Seq(
  BoundReference(0, StringType, nullable = true),                  // name
  Add(BoundReference(1, IntegerType, nullable = true), Literal(1)) // age + 1
)
val projection = GenerateUnsafeProjection.generate(expressions)
// Apply the projection to a row. Internal rows store strings as UTF8String,
// not java.lang.String.
val inputRow = InternalRow(UTF8String.fromString("Alice"), 25)
val outputRow = projection.apply(inputRow)
// Generates efficient predicate evaluation code.
/** A generated boolean test over input rows. */
abstract class Predicate {
  /** Returns whether `row` satisfies the condition. */
  def eval(row: InternalRow): Boolean
}

/** Code generator that compiles a boolean expression into a [[Predicate]]. */
object GeneratePredicate extends CodeGenerator[Expression, Predicate] {
  /** Variant that also takes the input schema. */
  def generate(expression: Expression, inputSchema: Seq[Attribute]): Predicate

  /** Generates a predicate for the given expression. */
  def generate(expression: Expression): Predicate
}
// Generate predicate for filter condition
import org.apache.spark.unsafe.types.UTF8String

// `age` is column 1 of the input row. It is bound by ordinal because an
// UnresolvedAttribute cannot be code-generated.
val condition = GreaterThan(BoundReference(1, IntegerType, nullable = true), Literal(18))
val predicate = GeneratePredicate.generate(condition)
// Evaluate the predicate. Internal rows store strings as UTF8String.
val row = InternalRow(UTF8String.fromString("Alice"), 25)
val result = predicate.eval(row) // true
// Generates efficient row comparison code for sorting.
/** Code generator that compiles sort orders into an `Ordering[InternalRow]`. */
object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalRow]] {
  /** Variant that also takes the input schema. */
  def generate(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]): Ordering[InternalRow]

  /** Generates an ordering for the given sort orders. */
  def generate(ordering: Seq[SortOrder]): Ordering[InternalRow]
}
// Generate ordering for sort operation
import org.apache.spark.unsafe.types.UTF8String

// Sort by name ascending, then age descending. Columns are bound by ordinal
// because an UnresolvedAttribute cannot be code-generated.
val sortOrders = Seq(
  SortOrder(BoundReference(0, StringType, nullable = true), Ascending),
  SortOrder(BoundReference(1, IntegerType, nullable = true), Descending)
)
val ordering = GenerateOrdering.generate(sortOrders)
// Use for sorting. String columns are read as UTF8String, so the rows must
// hold UTF8String values rather than java.lang.String.
val rows = Seq(
  InternalRow(UTF8String.fromString("Bob"), 30),
  InternalRow(UTF8String.fromString("Alice"), 25))
val sortedRows = rows.sorted(ordering)
// Generated hash aggregation code structure
/** Sketch of the Java class emitted for a hash aggregation. */
trait HashAggregateExec {
  /**
   * Java statements that aggregate the current `agg_row`. They are spliced
   * into the loop emitted by [[genCode]], which is why this member is
   * declared here.
   */
  def generateAggregationCode(): String

  /**
   * Returns the Java source of the generated aggregation class.
   *
   * NOTE(review): `inputIterator` is not declared in this template; it is
   * assumed to come from the BufferedRowIterator superclass — confirm.
   */
  def genCode(): String = {
    s"""
      public class GeneratedHashAggregation extends org.apache.spark.sql.execution.BufferedRowIterator {
        private boolean agg_initAgg;
        private java.util.HashMap agg_hashMap;
        protected void processNext() throws java.io.IOException {
          while (inputIterator.hasNext()) {
            InternalRow agg_row = (InternalRow) inputIterator.next();
            // Generated aggregation logic
            ${generateAggregationCode()}
          }
        }
      }
    """
  }
}

/** Helpers that render Java expressions for reading and writing row fields. */
object CodeGenUtils {
  /**
   * Java expression that reads column `ordinal` of row `input` as `dataType`.
   *
   * @throws IllegalArgumentException if `dataType` has no accessor here
   */
  def genGetValue(input: String, dataType: DataType, ordinal: String): String = {
    dataType match {
      case BooleanType => s"$input.getBoolean($ordinal)"
      case ByteType => s"$input.getByte($ordinal)"
      case ShortType => s"$input.getShort($ordinal)"
      case IntegerType => s"$input.getInt($ordinal)"
      case LongType => s"$input.getLong($ordinal)"
      case FloatType => s"$input.getFloat($ordinal)"
      case DoubleType => s"$input.getDouble($ordinal)"
      case StringType => s"$input.getUTF8String($ordinal)"
      case other =>
        throw new IllegalArgumentException(s"genGetValue: unsupported data type $other")
    }
  }

  /**
   * Java statement (without trailing `;`) that writes `value` into column
   * `ordinal` of row `input`. It mirrors the types supported by [[genGetValue]].
   *
   * @throws IllegalArgumentException if `dataType` has no setter here
   */
  def genSetValue(input: String, dataType: DataType, ordinal: String, value: String): String = {
    dataType match {
      case BooleanType => s"$input.setBoolean($ordinal, $value)"
      case ByteType => s"$input.setByte($ordinal, $value)"
      case ShortType => s"$input.setShort($ordinal, $value)"
      case IntegerType => s"$input.setInt($ordinal, $value)"
      case LongType => s"$input.setLong($ordinal, $value)"
      case FloatType => s"$input.setFloat($ordinal, $value)"
      case DoubleType => s"$input.setDouble($ordinal, $value)"
      case StringType => s"$input.update($ordinal, $value)"
      case other =>
        throw new IllegalArgumentException(s"genSetValue: unsupported data type $other")
    }
  }
}
// Generate optimized loops
/**
 * Emits a Java `for` loop that runs `loopVar` from 0 (inclusive) to `bound`
 * (exclusive) around `body`.
 *
 * @param ctx     code-generation context (not used by this helper)
 * @param loopVar name of the Java `int` loop index
 * @param body    Java statements executed on each iteration
 * @param bound   Java expression for the exclusive upper bound. It defaults to
 *                `numRows`, which the surrounding generated code must declare.
 * @return the loop as Java source text
 */
def generateLoop(
    ctx: CodegenContext,
    loopVar: String,
    body: String,
    bound: String = "numRows"): String = {
  s"""
for (int $loopVar = 0; $loopVar < $bound; $loopVar++) {
  $body
}
"""
}
/**
 * Emits a Java conditional expression that selects `trueValue` or `falseValue`.
 *
 * Despite the name, this is an ordinary `?:`. Whether it compiles to
 * branch-free machine code is up to the JIT. Each operand and the whole
 * expression are parenthesized so the result can be embedded in a larger
 * expression. Unparenthesized, `x + <result>` would parse as
 * `(x + cond) ? ...`.
 */
def generateBranchFreeCode(condition: String, trueValue: String, falseValue: String): String =
  s"(($condition) ? ($trueValue) : ($falseValue))"

/**
 * Wraps `tryCode` in a Java try/catch that rethrows `exceptionClass` as a
 * RuntimeException, keeping the caught exception as its cause.
 *
 * The catch parameter's name comes from `ctx.freshName` (assumed unique within
 * the context — confirm). A fixed name such as `e` does not compile when the
 * enclosing generated method already declares a local `e`.
 * Variables declared in `tryCode` are visible only inside the try block.
 */
def generateTryCatch(ctx: CodegenContext, tryCode: String, exceptionClass: String): String = {
  val exVar = ctx.freshName("e")
  s"""
try {
  $tryCode
} catch ($exceptionClass $exVar) {
  throw new RuntimeException("Error in generated code", $exVar);
}
"""
}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
// Define a custom expression with code generation
/**
 * Example custom expression that doubles its numeric child. The result type
 * is the child's type, and a null child gives a null result.
 */
case class MultiplyByTwo(child: Expression) extends UnaryExpression with CodegenSupport {
  override def dataType: DataType = child.dataType

  /**
   * Generated path. Java promotes `byte`/`short` arithmetic to `int`, so the
   * product is cast back for those types to keep the assignment compilable.
   */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val childGen = child.genCode(ctx)
    val javaType = ctx.javaType(dataType)
    val product = dataType match {
      case ByteType | ShortType => s"($javaType) (${childGen.value} * 2)"
      case _ => s"${childGen.value} * 2"
    }
    val code = s"""
      ${childGen.code}
      boolean ${ev.isNull} = ${childGen.isNull};
      $javaType ${ev.value} = ${ctx.defaultValue(dataType)};
      if (!${ev.isNull}) {
        ${ev.value} = $product;
      }
    """
    ev.copy(code = code)
  }

  /**
   * Interpreted path, kept consistent with the generated code. The result has
   * the same numeric type as the input, rather than narrowing every input to
   * `Int`, which truncated longs and doubles. Integral overflow wraps, as it
   * does in Java.
   */
  override def nullSafeEval(input: Any): Any = input match {
    case b: Byte => (b * 2).toByte
    case s: Short => (s * 2).toShort
    case i: Int => i * 2
    case l: Long => l * 2L
    case f: Float => f * 2.0f
    case d: Double => d * 2.0
    // Any other value keeps the previous contract: it must be a Number,
    // otherwise a ClassCastException is thrown.
    case other => other.asInstanceOf[Number].intValue() * 2
  }
}
// Use the expression with code generation. The child is bound by ordinal:
// MultiplyByTwo.dataType reads child.dataType, which an UnresolvedAttribute
// cannot provide.
val expr = MultiplyByTwo(BoundReference(0, IntegerType, nullable = true))
val projection = GenerateUnsafeProjection.generate(Seq(expr))
// The code generation framework enables Catalyst to produce highly optimized
// Java code that rivals hand-written implementations, providing significant
// performance improvements for query execution.
Install with Tessl CLI
npx tessl i tessl/maven-org-apache-spark--spark-catalyst