Catalyst query optimization framework and expression evaluation engine for Apache Spark SQL
—
Catalyst's expression system provides a comprehensive framework for representing computations, predicates, and data transformations. Expressions form the foundation of SQL operations and can be evaluated both interpretively and through code generation.
Core interface for all expressions in Catalyst with evaluation and code generation capabilities.
/**
 * Root of the Catalyst expression hierarchy.
 *
 * An expression is a tree node that computes a value from an input row. It can be
 * evaluated directly through [[eval]] (interpreted mode) or compiled to Java source
 * through [[genCode]] (code-generation mode).
 */
abstract class Expression extends TreeNode[Expression] {

  // ---- Result type and static properties ---------------------------------

  /** Type of the value this expression produces. */
  def dataType: DataType

  /** True when evaluation may produce `null`. */
  def nullable: Boolean

  /** True when the value can be computed once ahead of time (constant folding). */
  def foldable: Boolean

  /** True when equal inputs always give equal outputs; defaults to `true`. */
  def deterministic: Boolean = true

  /** Attributes this expression reads. */
  def references: AttributeSet

  // ---- Evaluation ----------------------------------------------------------

  /** Interpreted evaluation against `input` (which defaults to `null`). */
  def eval(input: InternalRow = null): Any

  /** Emits Java code that computes this expression (code-generation mode). */
  def genCode(ctx: CodegenContext): ExprCode

  // ---- Analysis and presentation -------------------------------------------

  /** Checks whether the children's data types are acceptable. */
  def checkInputDataTypes(): TypeCheckResult

  /** Short human-readable name of the expression. */
  def prettyName: String

  /** SQL text for this expression. */
  def sql: String

  /** Copy of this expression whose children are replaced by `newChildren`. */
  def withNewChildren(newChildren: Seq[Expression]): Expression
}
Usage Examples:
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
// Create a literal expression with an explicit data type
val literal = Literal(42, IntegerType)
println(literal.dataType)      // IntegerType
println(literal.nullable)      // false (the value is non-null)
println(literal.foldable)      // true
println(literal.deterministic) // true
// Evaluate the expression; no input row is needed for a literal
val result = literal.eval() // Returns 42
// Expression properties
// prettyName is the expression's display name, not its value.
// NOTE(review): prettyName is abstract on Expression and Literal does not define it
// here; confirm the exact string the concrete implementation returns.
println(literal.prettyName)
println(literal.sql) // "42"
Specialized traits for different expression patterns:
/** An expression with no children. */
trait LeafExpression extends Expression {
  def children: Seq[Expression] = Seq.empty
}

/** An expression with exactly one child, exposed as `child`. */
trait UnaryExpression extends Expression {
  def child: Expression
  def children: Seq[Expression] = List(child)
}

/** An expression with exactly two children, exposed as `left` and `right` (in that order). */
trait BinaryExpression extends Expression {
  def left: Expression
  def right: Expression
  def children: Seq[Expression] = left :: right :: Nil
}

/** An expression with exactly three children: `first`, `second`, `third` (in that order). */
trait TernaryExpression extends Expression {
  def first: Expression
  def second: Expression
  def third: Expression
  def children: Seq[Expression] = first :: second :: third :: Nil
}

/**
 * A binary expression whose operands are meant to share one data type; the result
 * takes the left operand's type.
 *
 * NOTE(review): the same-type requirement is not enforced here; only `left.dataType`
 * is consulted.
 */
trait BinaryOperator extends BinaryExpression {
  def dataType: DataType = this.left.dataType
}
/**
 * Mixin for expressions whose result may differ between evaluations on the same input.
 *
 * `deterministic` is pinned to `false` and cannot be overridden further. Implementations
 * provide a per-partition setup hook and the evaluation logic itself.
 *
 * NOTE(review): nothing in this trait calls `initializeInternal` or `evalInternal`, and
 * `eval` is still abstract; confirm which component drives these hooks.
 */
trait Nondeterministic extends Expression {
  /** Sets up per-partition state for the partition with index `partitionIndex`. */
  def initializeInternal(partitionIndex: Int): Unit

  /** Evaluation logic of the nondeterministic computation for one row. */
  def evalInternal(input: InternalRow): Any

  final override def deterministic: Boolean = false
}
/** Analysis-time-only expression: interpreted evaluation always fails. */
trait Unevaluable extends Expression {
  /** Always throws `UnsupportedOperationException` naming this expression. */
  final override def eval(input: InternalRow = null): Any = {
    val message = s"Cannot evaluate expression: $this"
    throw new UnsupportedOperationException(message)
  }
}

/** Marker: the expression relies on interpreted evaluation instead of generated code. */
trait CodegenFallback extends Expression

/** Marker: the expression has no SQL text form. */
trait NonSQLExpression extends Expression
/**
 * An expression that stands in for another one, given by `replacement`.
 * It mixes in [[Unevaluable]], so evaluating it directly always throws.
 */
trait RuntimeReplaceable extends UnaryExpression with Unevaluable {
  /** The expression that takes this one's place. */
  def replacement: Expression
}
Expressions that have names and can be referenced in query plans:
/**
 * An expression with a name and a unique id, so other parts of a plan can refer to it.
 */
trait NamedExpression extends Expression {
  // ---- Identity ----

  /** Name of the expression. */
  def name: String

  /** Id that uniquely identifies this expression instance. */
  def exprId: ExprId

  /** Qualification path, such as a table name. */
  def qualifiers: Seq[String]

  /** Metadata attached to this expression. */
  def metadata: Metadata

  // ---- Derived forms ----

  /** This expression as an attribute reference. */
  def toAttribute: Attribute

  /** Copy that keeps the name but receives a fresh id. */
  def newInstance(): NamedExpression

  /** Copy carrying `newQualifier` as its qualification path. */
  def withQualifier(newQualifier: Seq[String]): NamedExpression
}
/**
 * A named leaf expression that refers to a column.
 *
 * Attributes mix in [[Unevaluable]], so they cannot be evaluated directly. Each `withX`
 * method returns a copy in which exactly one property is replaced.
 */
abstract class Attribute extends LeafExpression with NamedExpression with Unevaluable {
  /** Copy with a different name. */
  def withName(newName: String): Attribute

  /** Copy with a different expression id. */
  def withExprId(newExprId: ExprId): Attribute

  /** Copy with a different qualification path. */
  def withQualifier(newQualifier: Seq[String]): Attribute

  /** Copy with a different data type. */
  def withDataType(newType: DataType): Attribute

  /** Copy with different nullability. */
  def withNullability(newNullability: Boolean): Attribute

  /** Copy with different metadata. */
  def withMetadata(newMetadata: Metadata): Attribute
}
/**
 * Reference to a column.
 *
 * The first parameter list holds name, data type, nullability and metadata. The
 * second holds the expression id and qualifier.
 *
 * A case class's synthesized equality covers only the first parameter list. Two
 * references with the same name but different `exprId`s would therefore compare
 * equal even though they denote different columns. `equals` and `hashCode` are
 * overridden to include `exprId` and `qualifier`.
 */
case class AttributeReference(
    name: String,
    dataType: DataType,
    nullable: Boolean = true,
    metadata: Metadata = Metadata.empty)(
    val exprId: ExprId = NamedExpression.newExprId,
    val qualifier: Seq[String] = Seq.empty[String]) extends Attribute {

  def qualifiers: Seq[String] = qualifier

  // Each withX copy keeps exprId/qualifier unless that is the field being replaced.
  def withName(newName: String): AttributeReference = copy(name = newName)(exprId, qualifier)
  def withQualifier(newQualifier: Seq[String]): AttributeReference = copy()(exprId, newQualifier)
  def withExprId(newExprId: ExprId): AttributeReference = copy()(newExprId, qualifier)
  def withDataType(newType: DataType): AttributeReference = copy(dataType = newType)(exprId, qualifier)
  def withNullability(newNullability: Boolean): AttributeReference = copy(nullable = newNullability)(exprId, qualifier)
  def withMetadata(newMetadata: Metadata): AttributeReference = copy(metadata = newMetadata)(exprId, qualifier)

  /** Same column description under a freshly allocated expression id. */
  def newInstance(): AttributeReference = copy()(NamedExpression.newExprId, qualifier)

  def toAttribute: AttributeReference = this

  /** Equal only when both parameter lists match, including `exprId` and `qualifier`. */
  override def equals(other: Any): Boolean = other match {
    case that: AttributeReference =>
      name == that.name && dataType == that.dataType && nullable == that.nullable &&
        metadata == that.metadata && exprId == that.exprId && qualifier == that.qualifier
    case _ => false
  }

  /** Consistent with `equals`: hashes every field that takes part in equality. */
  override def hashCode(): Int =
    (name, dataType, nullable, metadata, exprId, qualifier).##
}
/**
 * Gives `child` the name `name`, producing a named expression with its own id.
 *
 * Data type and nullability come from the child. Metadata is `explicitMetadata` when
 * set; otherwise it is the child's metadata if the child is itself named, else empty.
 * Equality includes the second parameter list (`exprId`, `qualifier`,
 * `explicitMetadata`). Case-class equality would otherwise ignore it.
 */
case class Alias(
    child: Expression,
    name: String)(
    val exprId: ExprId = NamedExpression.newExprId,
    val qualifier: Seq[String] = Seq.empty[String],
    val explicitMetadata: Option[Metadata] = None) extends UnaryExpression with NamedExpression {

  def qualifiers: Seq[String] = qualifier
  def dataType: DataType = child.dataType
  def nullable: Boolean = child.nullable

  // Re-aliasing a named expression keeps its metadata unless overridden explicitly.
  def metadata: Metadata = explicitMetadata.getOrElse {
    child match {
      case named: NamedExpression => named.metadata
      case _ => Metadata.empty
    }
  }

  /** Same alias under a freshly allocated expression id. */
  def newInstance(): NamedExpression = copy()(NamedExpression.newExprId, qualifier, explicitMetadata)

  /** Attribute referring to this alias's output (same id, name, type and metadata). */
  def toAttribute: Attribute = AttributeReference(name, dataType, nullable, metadata)(exprId, qualifier)

  override def equals(other: Any): Boolean = other match {
    case that: Alias =>
      child == that.child && name == that.name && exprId == that.exprId &&
        qualifier == that.qualifier && explicitMetadata == that.explicitMetadata
    case _ => false
  }

  override def hashCode(): Int = (child, name, exprId, qualifier, explicitMetadata).##
}
Usage Examples:
// Attribute references; the trailing () applies the second parameter list with its defaults
val nameAttr = AttributeReference("name", StringType, nullable = false)()
val ageAttr = AttributeReference("age", IntegerType, nullable = false)()
// A qualified attribute: explicit expression id plus a "users" qualifier
val qualifiedAttr = AttributeReference("id", LongType)(exprId = NamedExpression.newExprId, qualifier = Seq("users"))
// Aliases over a plain attribute and over a computed expression
val ageAlias = Alias(ageAttr, "user_age")()
val computedAlias = Alias(Add(ageAttr, Literal(1)), "age_plus_one")()
// Reading attribute properties
println(nameAttr.name)     // "name"
println(nameAttr.dataType) // StringType
println(nameAttr.nullable) // false
println(nameAttr.exprId)   // unique ExprId
// Copies with one property changed (the exprId is kept)
val nullableNameAttr = nameAttr.withNullability(true)
val renamedAttr = nameAttr.withName("full_name")
Constant value expressions for representing literal data:
/**
 * Expression that always yields the constant `value`, typed as `dataType`.
 *
 * It is nullable exactly when `value` is `null` and is always foldable.
 */
case class Literal(value: Any, dataType: DataType) extends LeafExpression {
  def nullable: Boolean = null == value

  def foldable: Boolean = true

  /** Ignores `input` and returns the stored constant. */
  def eval(input: InternalRow = null): Any = value

  /** The value's own string form, or "null". */
  override def toString: String = Option(value).fold("null")(_.toString)
}
object Literal {
  /**
   * Creates a literal, inferring its Catalyst type from the runtime class of `v`.
   *
   * Values are stored in the internal form the expression implementations expect.
   * `String` becomes `UTF8String`, because string expressions cast their input to
   * `UTF8String`. `java.math.BigDecimal` becomes `Decimal`, because decimal
   * arithmetic casts its operands to `Decimal`.
   *
   * @throws RuntimeException if the class of `v` is not supported
   */
  def apply(v: Any): Literal = v match {
    case null => Literal(null, NullType)
    case b: Boolean => Literal(b, BooleanType)
    case b: Byte => Literal(b, ByteType)
    case s: Short => Literal(s, ShortType)
    case i: Int => Literal(i, IntegerType)
    case l: Long => Literal(l, LongType)
    case f: Float => Literal(f, FloatType)
    case d: Double => Literal(d, DoubleType)
    case s: String => Literal(UTF8String.fromString(s), StringType)
    case u: UTF8String => Literal(u, StringType)
    case d: java.math.BigDecimal => Literal(Decimal(d), DecimalType.fromBigDecimal(d))
    // NOTE(review): dates/timestamps are still stored as java.sql objects; confirm
    // whether the engine expects a different internal encoding for these types.
    case d: java.sql.Date => Literal(d, DateType)
    case t: java.sql.Timestamp => Literal(t, TimestampType)
    case a: Array[Byte] => Literal(a, BinaryType)
    case _ => throw new RuntimeException(s"Unsupported literal type: ${v.getClass}")
  }

  /** Creates a literal with an explicit type; `v` is stored as given, without conversion. */
  def create(v: Any, dataType: DataType): Literal = Literal(v, dataType)

  /** Creates a literal from a Scala/Java object via [[apply]] (type inferred, value converted). */
  def fromObject(obj: Any): Literal = apply(obj)

  /**
   * Default (zero / empty) literal for `dataType`. Types without an explicit default
   * yield a null literal of that type.
   */
  def default(dataType: DataType): Literal = dataType match {
    case BooleanType => Literal(false)
    case ByteType => Literal(0.toByte)
    case ShortType => Literal(0.toShort)
    case IntegerType => Literal(0)
    case LongType => Literal(0L)
    case FloatType => Literal(0.0f)
    case DoubleType => Literal(0.0)
    case StringType => Literal("")
    case _ => Literal(null, dataType)
  }
}
Usage Examples:
// Literal creation with type inference
val intLit = Literal(42)          // Literal(42, IntegerType)
val stringLit = Literal("hello")  // string literal of type StringType
val nullLit = Literal(null)       // Literal(null, NullType)
// Explicit type creation. `create` stores the value exactly as given, so pass the
// internal Decimal form that decimal arithmetic casts its operands to.
val decimalLit = Literal.create(Decimal(new java.math.BigDecimal("123.45")), DecimalType(5, 2))
// Default values
val defaultInt = Literal.default(IntegerType)   // Literal(0, IntegerType)
val defaultString = Literal.default(StringType) // empty string of type StringType
// Evaluation
val result = intLit.eval() // Returns 42
println(intLit.foldable)   // true (can be constant folded)
Boolean expressions for filtering and conditional logic:
/**
 * Mixin for boolean-valued expressions (predicates).
 *
 * This is a trait so it can be mixed in after other shape traits, for example
 * `BinaryOperator with Predicate`. A class cannot appear in a `with` clause.
 * `dataType` is marked `override` so that, when `Predicate` is mixed in after a
 * trait that also defines `dataType` (such as `BinaryOperator`), the boolean
 * result type wins by linearization.
 */
trait Predicate extends Expression {
  override def dataType: DataType = BooleanType
}
/**
 * Mixin for binary comparisons producing a boolean.
 *
 * Interpreted evaluation follows SQL null semantics. If either operand evaluates to
 * `null` the result is `null`; otherwise both values go to [[nullSafeEval]]. The
 * right operand is not evaluated when the left one is `null`. Comparisons that give
 * nulls a meaning (see `EqualNullSafe`) override `eval`.
 */
trait Comparison extends BinaryExpression with Predicate {
  /** Operator symbol, e.g. "=". */
  def symbol: String

  /** Compares two operand values (both non-null unless `eval` is overridden). */
  def nullSafeEval(input1: Any, input2: Any): Any

  /** The result can be null whenever either operand can be null. */
  def nullable: Boolean = left.nullable || right.nullable

  override def eval(input: InternalRow = null): Any = {
    val leftValue = left.eval(input)
    if (leftValue == null) {
      null
    } else {
      val rightValue = right.eval(input)
      if (rightValue == null) null else nullSafeEval(leftValue, rightValue)
    }
  }
}
/** Equality comparison (=). Byte arrays (BinaryType values) are compared by content. */
case class EqualTo(left: Expression, right: Expression) extends BinaryOperator with Comparison {
  def symbol: String = "="
  def nullSafeEval(input1: Any, input2: Any): Any = (input1, input2) match {
    // `==` on JVM arrays is reference equality, so compare binary values element-wise.
    case (a: Array[Byte], b: Array[Byte]) => java.util.Arrays.equals(a, b)
    case _ => input1 == input2
  }
}
/**
 * Null-safe equality (<=>). Never null: two nulls are equal, and null versus a
 * non-null value is unequal. Byte arrays are compared by content.
 */
case class EqualNullSafe(left: Expression, right: Expression) extends BinaryOperator with Comparison {
  def symbol: String = "<=>"
  override def nullable: Boolean = false
  // Null operands are meaningful here, so both sides are always evaluated and
  // passed to nullSafeEval as-is, bypassing the null short-circuit in Comparison.
  override def eval(input: InternalRow = null): Any =
    nullSafeEval(left.eval(input), right.eval(input))
  def nullSafeEval(input1: Any, input2: Any): Any = {
    if (input1 == null && input2 == null) true
    else if (input1 == null || input2 == null) false
    else (input1, input2) match {
      case (a: Array[Byte], b: Array[Byte]) => java.util.Arrays.equals(a, b)
      case _ => input1 == input2
    }
  }
}
/** `left > right` under the ordering of the left operand's data type. */
case class GreaterThan(left: Expression, right: Expression) extends BinaryOperator with Comparison {
  def symbol: String = ">"
  def nullSafeEval(input1: Any, input2: Any): Any =
    0 < RowOrdering.compare(input1, input2, left.dataType)
}
/** `left >= right` under the ordering of the left operand's data type. */
case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryOperator with Comparison {
  def symbol: String = ">="
  def nullSafeEval(input1: Any, input2: Any): Any =
    0 <= RowOrdering.compare(input1, input2, left.dataType)
}
/** `left < right` under the ordering of the left operand's data type. */
case class LessThan(left: Expression, right: Expression) extends BinaryOperator with Comparison {
  def symbol: String = "<"
  def nullSafeEval(input1: Any, input2: Any): Any =
    0 > RowOrdering.compare(input1, input2, left.dataType)
}
/** `left <= right` under the ordering of the left operand's data type. */
case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryOperator with Comparison {
  def symbol: String = "<="
  def nullSafeEval(input1: Any, input2: Any): Any =
    0 >= RowOrdering.compare(input1, input2, left.dataType)
}
/**
 * Logical AND with SQL three-valued logic.
 *
 * The result is false if either side is false; otherwise null if either side is null;
 * otherwise true. The right side is not evaluated when the left side is false.
 */
case class And(left: Expression, right: Expression) extends BinaryOperator with Predicate {
  def symbol: String = "&&"
  /** The result can be null whenever either operand can be null. */
  def nullable: Boolean = left.nullable || right.nullable
  def eval(input: InternalRow): Any = {
    val leftResult = left.eval(input)
    if (leftResult == false) {
      false
    } else {
      val rightResult = right.eval(input)
      if (rightResult == false) false
      else if (leftResult == null || rightResult == null) null
      else true
    }
  }
}
/**
 * Logical OR with SQL three-valued logic.
 *
 * The result is true if either side is true; otherwise null if either side is null;
 * otherwise false. The right side is not evaluated when the left side is true.
 */
case class Or(left: Expression, right: Expression) extends BinaryOperator with Predicate {
  def symbol: String = "||"
  /** The result can be null whenever either operand can be null. */
  def nullable: Boolean = left.nullable || right.nullable
  def eval(input: InternalRow): Any = {
    val leftResult = left.eval(input)
    if (leftResult == true) {
      true
    } else {
      val rightResult = right.eval(input)
      if (rightResult == true) true
      else if (leftResult == null || rightResult == null) null
      else false
    }
  }
}
/**
 * Logical negation (NOT) with SQL null semantics: NOT null is null.
 *
 * The boolean result type is inherited from `Predicate`. Redefining `dataType` here
 * without `override` would conflict with Predicate's concrete definition.
 */
case class Not(child: Expression) extends UnaryExpression with Predicate {
  /** Null exactly when the child can be null. */
  def nullable: Boolean = child.nullable
  def eval(input: InternalRow): Any = {
    val childResult = child.eval(input)
    if (childResult == null) null
    else !childResult.asInstanceOf[Boolean]
  }
}
/** `child IS NULL`: true exactly when the child evaluates to null; never null itself. */
case class IsNull(child: Expression) extends UnaryExpression with Predicate {
  override def nullable: Boolean = false
  def eval(input: InternalRow): Any = child.eval(input) match {
    case null => true
    case _ => false
  }
}
/** `child IS NOT NULL`: true exactly when the child evaluates to a non-null value; never null itself. */
case class IsNotNull(child: Expression) extends UnaryExpression with Predicate {
  override def nullable: Boolean = false
  def eval(input: InternalRow): Any = child.eval(input) match {
    case null => false
    case _ => true
  }
}
/**
 * `value IN (list...)` with SQL three-valued logic:
 *  - null if `value` is null;
 *  - true if some list element equals `value`;
 *  - otherwise null if any list element is null (the match is unknown), else false.
 *
 * List elements are evaluated left to right, stopping at the first match.
 */
case class In(value: Expression, list: Seq[Expression]) extends Predicate {
  def children: Seq[Expression] = value +: list
  def nullable: Boolean = children.exists(_.nullable)
  def eval(input: InternalRow): Any = {
    val evaluatedValue = value.eval(input)
    if (evaluatedValue == null) {
      null
    } else {
      val items = list.iterator.map(_.eval(input))
      var sawNull = false
      var found = false
      while (!found && items.hasNext) {
        val itemValue = items.next()
        if (itemValue == null) sawNull = true
        else if (evaluatedValue == itemValue) found = true
      }
      // A null element means "might have matched", so the answer is unknown (null).
      if (found) true else if (sawNull) null else false
    }
  }
}
Usage Examples:
// Comparison expressions
val ageAttr = AttributeReference("age", IntegerType)()
val nameAttr = AttributeReference("name", StringType)()
val ageFilter = GreaterThan(ageAttr, Literal(25))
// String values are UTF8String internally, so the literal is built in that form.
val nameFilter = EqualTo(nameAttr, Literal.create(UTF8String.fromString("Alice"), StringType))
// Logical operations
val combinedFilter = And(ageFilter, nameFilter)
val eitherFilter = Or(ageFilter, nameFilter)
val notFilter = Not(ageFilter)
// Null checks
val nullCheck = IsNull(nameAttr)
val notNullCheck = IsNotNull(nameAttr)
// IN predicate
val statusIn = In(
  AttributeReference("status", StringType)(),
  Seq(Literal("active"), Literal("pending"), Literal("verified"))
)
// Expression evaluation with sample data; string columns hold UTF8String values.
// NOTE(review): AttributeReference is an Attribute, which mixes in Unevaluable, so
// eval on an unbound attribute throws; the attributes must be bound to row
// positions before these calls succeed. Confirm the binding step.
val sampleRow = InternalRow(30, UTF8String.fromString("Alice"), UTF8String.fromString("active"))
val result1 = ageFilter.eval(sampleRow)      // true (30 > 25)
val result2 = nameFilter.eval(sampleRow)     // true ("Alice" == "Alice")
val result3 = combinedFilter.eval(sampleRow) // true (both conditions true)
Mathematical operations and numeric computations:
/**
 * Base trait for binary arithmetic expressions.
 *
 * Interpreted evaluation propagates nulls. If either operand is null the result is
 * null, and the right operand is not evaluated when the left one is null. Otherwise
 * both values go to [[nullSafeEval]].
 */
trait BinaryArithmetic extends BinaryOperator {
  /** Symbol for the arithmetic operation. */
  def symbol: String

  /** Applies the operation to two non-null operand values. */
  def nullSafeEval(input1: Any, input2: Any): Any

  /** The result can be null whenever either operand can be null. */
  def nullable: Boolean = left.nullable || right.nullable

  def eval(input: InternalRow): Any = {
    val leftResult = left.eval(input)
    if (leftResult == null) {
      null
    } else {
      val rightResult = right.eval(input)
      if (rightResult == null) null
      else nullSafeEval(leftResult, rightResult)
    }
  }
}
/**
 * Addition (+) of two operands of the result type.
 *
 * Scala widens `Byte + Byte` and `Short + Short` to `Int`, which would not match
 * `dataType`, so byte and short sums are narrowed back with `.toByte` / `.toShort`.
 */
case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
  def symbol: String = "+"
  def nullSafeEval(input1: Any, input2: Any): Any = dataType match {
    case ByteType => (input1.asInstanceOf[Byte] + input2.asInstanceOf[Byte]).toByte
    case ShortType => (input1.asInstanceOf[Short] + input2.asInstanceOf[Short]).toShort
    case IntegerType => input1.asInstanceOf[Int] + input2.asInstanceOf[Int]
    case LongType => input1.asInstanceOf[Long] + input2.asInstanceOf[Long]
    case FloatType => input1.asInstanceOf[Float] + input2.asInstanceOf[Float]
    case DoubleType => input1.asInstanceOf[Double] + input2.asInstanceOf[Double]
    case DecimalType() => input1.asInstanceOf[Decimal] + input2.asInstanceOf[Decimal]
  }
}
/**
 * Subtraction (-) of two operands of the result type.
 * Byte and short differences are narrowed back to the operand type (see [[Add]]).
 */
case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic {
  def symbol: String = "-"
  def nullSafeEval(input1: Any, input2: Any): Any = dataType match {
    case ByteType => (input1.asInstanceOf[Byte] - input2.asInstanceOf[Byte]).toByte
    case ShortType => (input1.asInstanceOf[Short] - input2.asInstanceOf[Short]).toShort
    case IntegerType => input1.asInstanceOf[Int] - input2.asInstanceOf[Int]
    case LongType => input1.asInstanceOf[Long] - input2.asInstanceOf[Long]
    case FloatType => input1.asInstanceOf[Float] - input2.asInstanceOf[Float]
    case DoubleType => input1.asInstanceOf[Double] - input2.asInstanceOf[Double]
    case DecimalType() => input1.asInstanceOf[Decimal] - input2.asInstanceOf[Decimal]
  }
}
/**
 * Multiplication (*) of two operands of the result type.
 * Byte and short products are narrowed back to the operand type (see [[Add]]).
 */
case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic {
  def symbol: String = "*"
  def nullSafeEval(input1: Any, input2: Any): Any = dataType match {
    case ByteType => (input1.asInstanceOf[Byte] * input2.asInstanceOf[Byte]).toByte
    case ShortType => (input1.asInstanceOf[Short] * input2.asInstanceOf[Short]).toShort
    case IntegerType => input1.asInstanceOf[Int] * input2.asInstanceOf[Int]
    case LongType => input1.asInstanceOf[Long] * input2.asInstanceOf[Long]
    case FloatType => input1.asInstanceOf[Float] * input2.asInstanceOf[Float]
    case DoubleType => input1.asInstanceOf[Double] * input2.asInstanceOf[Double]
    case DecimalType() => input1.asInstanceOf[Decimal] * input2.asInstanceOf[Decimal]
  }
}
/**
 * Division (/). Only DoubleType and DecimalType results are supported; any other
 * result type raises an error.
 */
case class Divide(left: Expression, right: Expression) extends BinaryArithmetic {
  def symbol: String = "/"
  def nullSafeEval(input1: Any, input2: Any): Any = {
    val resultType = dataType
    resultType match {
      case DoubleType =>
        val numerator = input1.asInstanceOf[Double]
        val denominator = input2.asInstanceOf[Double]
        numerator / denominator
      case DecimalType() =>
        val numerator = input1.asInstanceOf[Decimal]
        val denominator = input2.asInstanceOf[Decimal]
        numerator / denominator
      case _ =>
        sys.error(s"Type $dataType not supported.")
    }
  }
}
/**
 * Remainder (%) of two operands of the result type.
 *
 * Scala widens `Byte % Byte` and `Short % Short` to `Int`, so those results are
 * narrowed back with `.toByte` / `.toShort` to match `dataType`.
 */
case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic {
  def symbol: String = "%"
  def nullSafeEval(input1: Any, input2: Any): Any = dataType match {
    case ByteType => (input1.asInstanceOf[Byte] % input2.asInstanceOf[Byte]).toByte
    case ShortType => (input1.asInstanceOf[Short] % input2.asInstanceOf[Short]).toShort
    case IntegerType => input1.asInstanceOf[Int] % input2.asInstanceOf[Int]
    case LongType => input1.asInstanceOf[Long] % input2.asInstanceOf[Long]
    case FloatType => input1.asInstanceOf[Float] % input2.asInstanceOf[Float]
    case DoubleType => input1.asInstanceOf[Double] % input2.asInstanceOf[Double]
    case DecimalType() => input1.asInstanceOf[Decimal] % input2.asInstanceOf[Decimal]
  }
}
/**
 * Arithmetic negation (-x); null in, null out.
 *
 * Negating a Byte or Short yields an `Int` in Scala, so the result is narrowed back
 * to the child's type to match `dataType`.
 */
case class UnaryMinus(child: Expression) extends UnaryExpression {
  def dataType: DataType = child.dataType
  def nullable: Boolean = child.nullable
  def eval(input: InternalRow): Any = {
    val childResult = child.eval(input)
    if (childResult == null) {
      null
    } else {
      dataType match {
        case ByteType => (-childResult.asInstanceOf[Byte]).toByte
        case ShortType => (-childResult.asInstanceOf[Short]).toShort
        case IntegerType => -childResult.asInstanceOf[Int]
        case LongType => -childResult.asInstanceOf[Long]
        case FloatType => -childResult.asInstanceOf[Float]
        case DoubleType => -childResult.asInstanceOf[Double]
        case DecimalType() => -childResult.asInstanceOf[Decimal]
      }
    }
  }
}
/** Unary plus (+x): returns the child's value unchanged, null included. */
case class UnaryPositive(child: Expression) extends UnaryExpression {
  def nullable: Boolean = child.nullable
  def dataType: DataType = child.dataType
  def eval(input: InternalRow): Any = {
    val value = child.eval(input)
    value
  }
}
/**
 * Absolute value; null in, null out.
 *
 * `math.abs` has no Byte/Short overloads, so those inputs are widened to Int. The
 * result is then narrowed back to the child's type to match `dataType`. As with the
 * JVM's integral types, the minimum value maps to itself.
 */
case class Abs(child: Expression) extends UnaryExpression {
  def dataType: DataType = child.dataType
  def nullable: Boolean = child.nullable
  def eval(input: InternalRow): Any = {
    val childResult = child.eval(input)
    if (childResult == null) {
      null
    } else {
      dataType match {
        case ByteType => math.abs(childResult.asInstanceOf[Byte].toInt).toByte
        case ShortType => math.abs(childResult.asInstanceOf[Short].toInt).toShort
        case IntegerType => math.abs(childResult.asInstanceOf[Int])
        case LongType => math.abs(childResult.asInstanceOf[Long])
        case FloatType => math.abs(childResult.asInstanceOf[Float])
        case DoubleType => math.abs(childResult.asInstanceOf[Double])
        case DecimalType() => childResult.asInstanceOf[Decimal].abs
      }
    }
  }
}
Usage Examples:
// Building arithmetic expressions over an integer column and a double column
val ageAttr = AttributeReference("age", IntegerType)()
val salaryAttr = AttributeReference("salary", DoubleType)()
val agePlus10 = Add(ageAttr, Literal(10))
val salaryMinus1000 = Subtract(salaryAttr, Literal(1000.0))
val doubleSalary = Multiply(salaryAttr, Literal(2.0))
val halfSalary = Divide(salaryAttr, Literal(2.0))
// Unary operators
val negativeAge = UnaryMinus(ageAttr)
val absoluteValue = Abs(UnaryMinus(salaryAttr))
// Evaluating against a sample row (age = 25, salary = 50000.0).
// NOTE(review): AttributeReference mixes in Unevaluable, so these calls only work
// once the attributes are bound to row positions; confirm the binding step.
val row = InternalRow(25, 50000.0)
val result1 = agePlus10.eval(row)    // 35
val result2 = doubleSalary.eval(row) // 100000.0
val result3 = negativeAge.eval(row)  // -25
String manipulation and text processing operations:
/**
 * SQL substring: `len` characters of `str` starting at 1-based position `pos`.
 *
 * The result is null if any argument is null, and empty when `pos < 1` or `len < 0`.
 * The end index is computed in Long and clamped to `Int.MaxValue`, so a very large
 * `len` (such as `Int.MaxValue`, meaning "to the end") does not overflow into an
 * empty result.
 */
case class Substring(str: Expression, pos: Expression, len: Expression) extends TernaryExpression {
  def dataType: DataType = StringType
  def nullable: Boolean = str.nullable || pos.nullable || len.nullable
  def first: Expression = str
  def second: Expression = pos
  def third: Expression = len
  def eval(input: InternalRow): Any = {
    val string = str.eval(input)
    if (string == null) {
      null
    } else {
      val position = pos.eval(input)
      if (position == null) {
        null
      } else {
        val length = len.eval(input)
        if (length == null) null
        else slice(string.asInstanceOf[UTF8String], position.asInstanceOf[Int], length.asInstanceOf[Int])
      }
    }
  }

  /** Takes `count` characters of `s` starting at 1-based `position`, clamping the end index. */
  private def slice(s: UTF8String, position: Int, count: Int): UTF8String = {
    val start = position - 1 // 1-based to 0-based
    if (start < 0 || count < 0) {
      UTF8String.EMPTY_UTF8
    } else {
      // `start + count` can exceed Int.MaxValue, so add in Long and clamp.
      val end = math.min(start.toLong + count, Int.MaxValue.toLong).toInt
      s.substring(start, end)
    }
  }
}
/** Character count of a string, via `UTF8String.numChars`; null in, null out. */
case class Length(child: Expression) extends UnaryExpression {
  def dataType: DataType = IntegerType
  def nullable: Boolean = child.nullable
  def eval(input: InternalRow): Any = child.eval(input) match {
    case null => null
    case value => value.asInstanceOf[UTF8String].numChars()
  }
}
/** Upper-cases a string; null in, null out. */
case class Upper(child: Expression) extends UnaryExpression {
  def dataType: DataType = StringType
  def nullable: Boolean = child.nullable
  def eval(input: InternalRow): Any =
    Option(child.eval(input)).map(_.asInstanceOf[UTF8String].toUpperCase).orNull
}
/** Lower-cases a string; null in, null out. */
case class Lower(child: Expression) extends UnaryExpression {
  def dataType: DataType = StringType
  def nullable: Boolean = child.nullable
  def eval(input: InternalRow): Any =
    Option(child.eval(input)).map(_.asInstanceOf[UTF8String].toLowerCase).orNull
}
/**
 * Concatenates the string values of all children.
 * Every child is evaluated; the result is null if any of them is null.
 */
case class Concat(children: Seq[Expression]) extends Expression {
  def dataType: DataType = StringType
  def nullable: Boolean = children.exists(_.nullable)
  def eval(input: InternalRow): Any = {
    val values = children.map(child => child.eval(input))
    if (values.exists(_ == null)) null
    else UTF8String.concat(values.map(_.asInstanceOf[UTF8String]): _*)
  }
}
/** Strips whitespace from both ends (`UTF8String.trim`); null in, null out. */
case class StringTrim(child: Expression) extends UnaryExpression {
  def dataType: DataType = StringType
  def nullable: Boolean = child.nullable
  def eval(input: InternalRow): Any = child.eval(input) match {
    case null => null
    case value => value.asInstanceOf[UTF8String].trim()
  }
}
/** Strips leading whitespace (`UTF8String.trimLeft`); null in, null out. */
case class StringLTrim(child: Expression) extends UnaryExpression {
  def dataType: DataType = StringType
  def nullable: Boolean = child.nullable
  def eval(input: InternalRow): Any = child.eval(input) match {
    case null => null
    case value => value.asInstanceOf[UTF8String].trimLeft()
  }
}
/** Strips trailing whitespace (`UTF8String.trimRight`); null in, null out. */
case class StringRTrim(child: Expression) extends UnaryExpression {
  def dataType: DataType = StringType
  def nullable: Boolean = child.nullable
  def eval(input: InternalRow): Any = child.eval(input) match {
    case null => null
    case value => value.asInstanceOf[UTF8String].trimRight()
  }
}
Usage Examples:
// String expression creation
val nameAttr = AttributeReference("name", StringType)()
val messageAttr = AttributeReference("message", StringType)()
val upperName = Upper(nameAttr)
val lowerName = Lower(nameAttr)
val nameLength = Length(nameAttr)
val substring = Substring(nameAttr, Literal(1), Literal(3)) // First 3 characters (positions are 1-based)
// String concatenation
val greeting = Concat(Seq(Literal("Hello "), nameAttr, Literal("!")))
// String trimming
val trimmed = StringTrim(messageAttr)
// Evaluation with sample data; string columns hold UTF8String values.
// NOTE(review): AttributeReference mixes in Unevaluable, so these calls need the
// attributes bound to row positions first; confirm the binding step.
val row = InternalRow(UTF8String.fromString(" Alice "), UTF8String.fromString(" Hello World "))
val result1 = upperName.eval(row)  // " ALICE "
val result2 = nameLength.eval(row) // 7 (" Alice ": 5 letters plus a leading and a trailing space)
val result3 = trimmed.eval(row)    // "Hello World" (spaces removed)
Mathematical functions and operations:
/**
 * Base trait for unary math functions computed on Double.
 *
 * Numeric inputs (Double, Float, Int, Long, Short, Byte, Decimal) are widened to
 * Double before `mathFunction` is applied, and the result type is always DoubleType.
 * A null input yields null. Any other input type raises an error naming the child's
 * data type.
 */
trait UnaryMathExpression extends UnaryExpression {
  def dataType: DataType = DoubleType
  def nullable: Boolean = child.nullable
  /** The mathematical function applied to the widened input. */
  def mathFunction(input: Double): Double
  def eval(input: InternalRow): Any = {
    val childResult = child.eval(input)
    if (childResult == null) {
      null
    } else {
      val doubleValue = childResult match {
        case d: Double => d
        case f: Float => f.toDouble
        case i: Int => i.toDouble
        case l: Long => l.toDouble
        case s: Short => s.toDouble
        case b: Byte => b.toDouble
        case dec: Decimal => dec.toDouble
        case _ => sys.error(s"Unsupported data type: ${child.dataType}")
      }
      mathFunction(doubleValue)
    }
  }
}
/** Sine of the input in radians (`Math.sin`). */
case class Sin(child: Expression) extends UnaryMathExpression {
  override def prettyName: String = "sin"
  def mathFunction(input: Double): Double = Math.sin(input)
}
/** Cosine of the input in radians (`Math.cos`). */
case class Cos(child: Expression) extends UnaryMathExpression {
  override def prettyName: String = "cos"
  def mathFunction(input: Double): Double = Math.cos(input)
}
/** Tangent of the input in radians (`Math.tan`). */
case class Tan(child: Expression) extends UnaryMathExpression {
  override def prettyName: String = "tan"
  def mathFunction(input: Double): Double = Math.tan(input)
}
/** Square root (`Math.sqrt`); negative inputs give NaN. */
case class Sqrt(child: Expression) extends UnaryMathExpression {
  override def prettyName: String = "sqrt"
  def mathFunction(input: Double): Double = Math.sqrt(input)
}
/** Natural logarithm (`Math.log`), displayed as "ln". */
case class Log(child: Expression) extends UnaryMathExpression {
  override def prettyName: String = "ln"
  def mathFunction(input: Double): Double = Math.log(input)
}
/** e raised to the input (`Math.exp`). */
case class Exp(child: Expression) extends UnaryMathExpression {
  override def prettyName: String = "exp"
  def mathFunction(input: Double): Double = Math.exp(input)
}
/** Largest integral value not above the input, as a Double (`Math.floor`). */
case class Floor(child: Expression) extends UnaryMathExpression {
  override def prettyName: String = "floor"
  def mathFunction(input: Double): Double = Math.floor(input)
}
/** Smallest integral value not below the input, as a Double (`Math.ceil`). */
case class Ceil(child: Expression) extends UnaryMathExpression {
  override def prettyName: String = "ceil"
  def mathFunction(input: Double): Double = Math.ceil(input)
}
/**
 * Base trait for binary math functions computed on Double.
 *
 * Each numeric operand (Double, Float, Int, Long, Short, Byte, Decimal) is widened to
 * Double before `mathFunction` is applied. A bare `asInstanceOf[Double]` would throw
 * ClassCastException for, say, an Int literal exponent. The result is null if either
 * operand is null, and the right operand is not evaluated when the left one is null.
 */
trait BinaryMathExpression extends BinaryExpression {
  def dataType: DataType = DoubleType
  def nullable: Boolean = left.nullable || right.nullable
  /** The mathematical function applied to the widened operands. */
  def mathFunction(left: Double, right: Double): Double
  def eval(input: InternalRow): Any = {
    val leftResult = left.eval(input)
    if (leftResult == null) {
      null
    } else {
      val rightResult = right.eval(input)
      if (rightResult == null) null
      else mathFunction(widen(leftResult, left), widen(rightResult, right))
    }
  }

  /** Widens a boxed numeric operand value to Double; errors name the operand's data type. */
  private def widen(value: Any, operand: Expression): Double = value match {
    case d: Double => d
    case f: Float => f.toDouble
    case i: Int => i.toDouble
    case l: Long => l.toDouble
    case s: Short => s.toDouble
    case b: Byte => b.toDouble
    case dec: Decimal => dec.toDouble
    case _ => sys.error(s"Unsupported data type: ${operand.dataType}")
  }
}
/**
 * Rounds `child` to `scale` decimal places using HALF_UP.
 *
 * Double and Float values are rounded through `scala.math.BigDecimal`. NaN and
 * infinities have no decimal form and would make that conversion throw, so they are
 * returned unchanged. Other value types pass through as-is. The result is null if
 * either input is null.
 */
case class Round(child: Expression, scale: Expression) extends BinaryExpression {
  def left: Expression = child
  def right: Expression = scale
  def dataType: DataType = child.dataType
  def nullable: Boolean = child.nullable || scale.nullable
  def eval(input: InternalRow): Any = {
    val childResult = child.eval(input)
    if (childResult == null) {
      null
    } else {
      val scaleResult = scale.eval(input)
      if (scaleResult == null) null
      else roundValue(childResult, scaleResult.asInstanceOf[Int])
    }
  }

  /** HALF_UP rounding for finite floating-point values; everything else is returned unchanged. */
  private def roundValue(value: Any, scaleValue: Int): Any = value match {
    case d: Double if d.isNaN || d.isInfinite => d
    case d: Double => BigDecimal(d).setScale(scaleValue, BigDecimal.RoundingMode.HALF_UP).doubleValue
    case f: Float if f.isNaN || f.isInfinite => f
    case f: Float => BigDecimal(f.toDouble).setScale(scaleValue, BigDecimal.RoundingMode.HALF_UP).floatValue
    case other => other
  }
}
/** `left` raised to the power `right`, computed with `Math.pow`; displayed as "power". */
case class Pow(left: Expression, right: Expression) extends BinaryMathExpression {
  override def prettyName: String = "power"
  def mathFunction(left: Double, right: Double): Double = Math.pow(left, right)
}
Usage Examples:
// Building math expressions over two double columns
val angleAttr = AttributeReference("angle", DoubleType)()
val valueAttr = AttributeReference("value", DoubleType)()
val sineValue = Sin(angleAttr)
val sqrtValue = Sqrt(valueAttr)
val logValue = Log(valueAttr)
val powerValue = Pow(valueAttr, Literal(2.0)) // value squared
// Round to 2 decimal places (HALF_UP)
val rounded = Round(valueAttr, Literal(2))
// Evaluating against a sample row: angle = π/2 (90 degrees), value = 16.0.
// NOTE(review): AttributeReference mixes in Unevaluable, so these calls only work
// once the attributes are bound to row positions; confirm the binding step.
val row = InternalRow(math.Pi / 2, 16.0)
val result1 = sineValue.eval(row)  // ~1.0 (sin(π/2))
val result2 = sqrtValue.eval(row)  // 4.0 (√16)
val result3 = powerValue.eval(row) // 256.0 (16²)
Operations for working with arrays, maps, and structs:
/**
 * Element of array `left` at 0-based index `right`.
 *
 * The result is null when the array or index is null, the index is out of range, or
 * the element itself is null.
 */
case class GetArrayItem(left: Expression, right: Expression) extends BinaryExpression {
  def dataType: DataType = left.dataType.asInstanceOf[ArrayType].elementType
  def nullable: Boolean = true // out-of-range or null elements produce null
  def eval(input: InternalRow): Any = {
    val arrayResult = left.eval(input)
    if (arrayResult == null) {
      null
    } else {
      val indexResult = right.eval(input)
      if (indexResult == null) {
        null
      } else {
        val array = arrayResult.asInstanceOf[ArrayData]
        val index = indexResult.asInstanceOf[Int]
        val inBounds = index >= 0 && index < array.numElements()
        if (inBounds && !array.isNullAt(index)) array.get(index, dataType) else null
      }
    }
  }
}
/**
 * Value stored under key `right` in map `left`.
 *
 * Keys are scanned linearly and the first non-null key equal (`==`) to the lookup
 * key wins. The result is null when the map or key is null, the key is absent, or
 * the stored value is null.
 */
case class GetMapValue(left: Expression, right: Expression) extends BinaryExpression {
  def dataType: DataType = left.dataType.asInstanceOf[MapType].valueType
  def nullable: Boolean = true
  def eval(input: InternalRow): Any = {
    val mapResult = left.eval(input)
    if (mapResult == null) {
      null
    } else {
      val keyResult = right.eval(input)
      if (keyResult == null) null
      else lookup(mapResult.asInstanceOf[MapData], keyResult)
    }
  }

  /**
   * Finds the value for `key` without a nonlocal `return` from a closure. That form
   * is implemented by throwing a control exception and is unsupported in Scala 3.
   */
  private def lookup(map: MapData, key: Any): Any = {
    val keys = map.keyArray()
    val values = map.valueArray()
    val keyType = left.dataType.asInstanceOf[MapType].keyType
    val hit = (0 until keys.numElements()).find { i =>
      !keys.isNullAt(i) && keys.get(i, keyType) == key
    }
    hit match {
      case Some(i) if !values.isNullAt(i) => values.get(i, dataType)
      case _ => null // key not found, or stored value is null
    }
  }
}
/**
 * Field number `ordinal` of struct `child`; `name` is an optional label.
 * The result is null when the struct or the field value is null.
 */
case class GetStructField(child: Expression, ordinal: Int, name: Option[String] = None) extends UnaryExpression {
  /** Schema entry of the selected field. */
  private def selectedField: StructField = child.dataType.asInstanceOf[StructType].fields(ordinal)

  def dataType: DataType = selectedField.dataType
  def nullable: Boolean = child.nullable || selectedField.nullable
  def eval(input: InternalRow): Any = child.eval(input) match {
    case null => null
    case value =>
      val struct = value.asInstanceOf[InternalRow]
      if (struct.isNullAt(ordinal)) null else struct.get(ordinal, dataType)
  }
}
/**
 * Builds an array from the children's values, in order.
 *
 * The element type is taken from the first child (NullType when there are none), and
 * `containsNull` reflects whether any child can be null. The array itself is never null.
 */
case class CreateArray(children: Seq[Expression]) extends Expression {
  def dataType: DataType = {
    val elementType = children.headOption.map(_.dataType).getOrElse(NullType)
    ArrayType(elementType, containsNull = children.exists(_.nullable))
  }
  def nullable: Boolean = false
  def eval(input: InternalRow): Any = new GenericArrayData(children.map(child => child.eval(input)))
}
/**
 * Build a map from `children` laid out as alternating keys and values: k0, v0, k1, v1, ...
 *
 * The key and value types come from the first pair (StringType when absent), and the
 * map's values may be null iff any value expression is nullable. The map itself is never null.
 */
case class CreateMap(children: Seq[Expression]) extends Expression {
  require(children.size % 2 == 0, "CreateMap should have an even number of arguments")

  /**
   * Split `items` into (even positions, odd positions) in one linear pass. Indexed access
   * (`items(i)`) would be O(i) per lookup on a List-backed Seq, making the split O(n^2).
   */
  private def splitPairs[A](items: Seq[A]): (Seq[A], Seq[A]) = {
    val (evens, odds) = items.zipWithIndex.partition { case (_, i) => i % 2 == 0 }
    (evens.map(_._1), odds.map(_._1))
  }

  def dataType: DataType = {
    val keyType = if (children.nonEmpty) children.head.dataType else StringType
    val valueType = if (children.size > 1) children(1).dataType else StringType
    val valueContainsNull = splitPairs(children)._2.exists(_.nullable)
    MapType(keyType, valueType, valueContainsNull)
  }
  def nullable: Boolean = false
  def eval(input: InternalRow): Any = {
    val (keys, values) = splitPairs(children.map(_.eval(input)))
    ArrayBasedMapData(new GenericArrayData(keys), new GenericArrayData(values))
  }
}
/**
 * Build a struct from `children`. Field i is named `col{i+1}` and takes the type and
 * nullability of child i. The struct itself is never null.
 */
case class CreateStruct(children: Seq[Expression]) extends Expression {
  def dataType: DataType = StructType(
    for ((child, position) <- children.zipWithIndex)
      yield StructField(s"col${position + 1}", child.dataType, child.nullable))
  def nullable: Boolean = false
  def eval(input: InternalRow): Any = InternalRow.fromSeq(children.map(_.eval(input)))
}
Usage Examples:
// Array operations
val arrayAttr = AttributeReference("items", ArrayType(StringType))()
val arrayAccess = GetArrayItem(arrayAttr, Literal(0)) // First element
val newArray = CreateArray(Seq(
  Literal("apple"),
  Literal("banana"),
  Literal("orange")
))
// Map operations
val mapAttr = AttributeReference("properties", MapType(StringType, IntegerType))()
val mapAccess = GetMapValue(mapAttr, Literal("age"))
// All keys must share one type and all values must share one type: CreateMap takes the
// value type from the first value, so mixing a string value with an integer value
// would produce a mistyped map.
val newMap = CreateMap(Seq(
  Literal("name"), Literal("Alice"),
  Literal("city"), Literal("Paris")
))
// Struct operations
val structAttr = AttributeReference("person", StructType(Seq(
  StructField("name", StringType),
  StructField("age", IntegerType)
)))()
val nameField = GetStructField(structAttr, ordinal = 0, name = Some("name"))
val ageField = GetStructField(structAttr, ordinal = 1, name = Some("age"))
val newStruct = CreateStruct(Seq(
  Literal("Bob"),
  Literal(30)
))
Type conversion operations for data transformation:
/**
 * Type casting expression.
 *
 * Converts the value of `child` from `child.dataType` to `dataType`. Only a few conversions
 * are spelled out here. An unparsable string-to-int cast yields null, and so does any
 * unsupported pair of distinct types. Because of that, `nullable` accounts for the
 * conversion itself and not only for the child.
 */
case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String] = None)
  extends UnaryExpression {
  def nullable: Boolean = child.nullable || canProduceNull(child.dataType, dataType)

  def eval(input: InternalRow): Any = {
    val childResult = child.eval(input)
    // Cast implementation depends on source and target types
    if (childResult == null) null else cast(childResult, child.dataType, dataType)
  }

  /** True when `cast(value, from, to)` may return null for a non-null `value`; mirrors `cast`. */
  private def canProduceNull(from: DataType, to: DataType): Boolean = (from, to) match {
    case (StringType, IntegerType) => true // NumberFormatException is mapped to null
    case (IntegerType, StringType) | (DoubleType, IntegerType) | (IntegerType, DoubleType) => false
    case _ if from == to => false
    case _ => true // unsupported casts yield null
  }

  private def cast(value: Any, from: DataType, to: DataType): Any = {
    (from, to) match {
      case (StringType, IntegerType) =>
        try { value.asInstanceOf[UTF8String].toString.toInt }
        catch { case _: NumberFormatException => null }
      case (IntegerType, StringType) =>
        UTF8String.fromString(value.asInstanceOf[Int].toString)
      case (DoubleType, IntegerType) =>
        value.asInstanceOf[Double].toInt
      case (IntegerType, DoubleType) =>
        value.asInstanceOf[Int].toDouble
      // Many more casting rules...
      case _ if from == to => value // No cast needed
      case _ => null // Unsupported cast
    }
  }
}
/**
 * Safe cast that returns null instead of failing when the conversion throws.
 *
 * Only non-fatal exceptions are mapped to null. Fatal errors and interruptions
 * (InterruptedException) still propagate so the task can be cancelled.
 */
case class TryCast(child: Expression, dataType: DataType, timeZoneId: Option[String] = None)
  extends UnaryExpression {
  import scala.util.control.NonFatal

  // Built once and reused for every row instead of allocating a new Cast per eval call.
  @transient private lazy val delegate = Cast(child, dataType, timeZoneId)

  def nullable: Boolean = true // Always nullable since cast can fail
  def eval(input: InternalRow): Any =
    try {
      delegate.eval(input)
    } catch {
      case e: Exception if NonFatal(e) => null // Return null on cast failure
    }
}
Usage Examples:
// Type casting
val stringAttr = AttributeReference("stringValue", StringType)()
val numberAttr = AttributeReference("numberValue", DoubleType)()
val stringToInt = Cast(stringAttr, IntegerType)
val doubleToInt = Cast(numberAttr, IntegerType)
val intToString = Cast(Literal(42), StringType)
// Safe casting (returns null on failure)
val safeCast = TryCast(stringAttr, IntegerType)
// Before evaluating, bind the attribute references to their ordinals in the input row:
// stringValue -> 0, numberValue -> 1.
val inputSchema = Seq(stringAttr, numberAttr)
val boundStringToInt = BindReferences.bindReference(stringToInt, inputSchema)
val boundDoubleToInt = BindReferences.bindReference(doubleToInt, inputSchema)
val boundSafeCast = BindReferences.bindReference(safeCast, inputSchema)
// Evaluation
val row1 = InternalRow(UTF8String.fromString("123"), 45.67)
val result1 = boundStringToInt.eval(row1) // 123
val result2 = boundDoubleToInt.eval(row1) // 45
val result3 = intToString.eval(row1) // "42" (no attributes, so no binding needed)
val row2 = InternalRow(UTF8String.fromString("invalid"), 45.67)
val result4 = boundStringToInt.eval(row2) // null (cast failure)
val result5 = boundSafeCast.eval(row2) // null (safe cast)
Date and time operations for temporal data processing:
/**
 * Extract the year from a date or timestamp.
 *
 * A DateType value is read as an Int day count and a TimestampType value as a Long of
 * microseconds; each is passed to `DateTimeUtils.getYear`. Null input yields null.
 * NOTE(review): other child types are not matched and would raise a MatchError.
 */
case class Year(child: Expression) extends UnaryExpression {
  def dataType: DataType = IntegerType
  def nullable: Boolean = child.nullable
  def eval(input: InternalRow): Any = child.eval(input) match {
    case null => null
    case value =>
      child.dataType match {
        case DateType => DateTimeUtils.getYear(value.asInstanceOf[Int])
        case TimestampType => DateTimeUtils.getYear(value.asInstanceOf[Long])
      }
  }
}
/**
 * Extract the month from a date or timestamp.
 *
 * A DateType value is read as an Int day count and a TimestampType value as a Long of
 * microseconds; each is passed to `DateTimeUtils.getMonth`. Null input yields null.
 * NOTE(review): other child types are not matched and would raise a MatchError.
 */
case class Month(child: Expression) extends UnaryExpression {
  def dataType: DataType = IntegerType
  def nullable: Boolean = child.nullable
  def eval(input: InternalRow): Any = child.eval(input) match {
    case null => null
    case value =>
      child.dataType match {
        case DateType => DateTimeUtils.getMonth(value.asInstanceOf[Int])
        case TimestampType => DateTimeUtils.getMonth(value.asInstanceOf[Long])
      }
  }
}
/**
 * Extract the day of the month from a date or timestamp.
 *
 * A DateType value is read as an Int day count and a TimestampType value as a Long of
 * microseconds; each is passed to `DateTimeUtils.getDayOfMonth`. Null input yields null.
 * NOTE(review): other child types are not matched and would raise a MatchError.
 */
case class DayOfMonth(child: Expression) extends UnaryExpression {
  def dataType: DataType = IntegerType
  def nullable: Boolean = child.nullable
  def eval(input: InternalRow): Any = child.eval(input) match {
    case null => null
    case value =>
      child.dataType match {
        case DateType => DateTimeUtils.getDayOfMonth(value.asInstanceOf[Int])
        case TimestampType => DateTimeUtils.getDayOfMonth(value.asInstanceOf[Long])
      }
  }
}
/**
 * Extract the hour from a timestamp, whose value is read as a Long of microseconds.
 * Null input yields null.
 */
case class Hour(child: Expression) extends UnaryExpression {
  def dataType: DataType = IntegerType
  def nullable: Boolean = child.nullable
  def eval(input: InternalRow): Any = {
    val raw = child.eval(input)
    if (raw == null) null else DateTimeUtils.getHour(raw.asInstanceOf[Long])
  }
}
/**
 * Extract the minute from a timestamp, whose value is read as a Long of microseconds.
 * Null input yields null.
 */
case class Minute(child: Expression) extends UnaryExpression {
  def dataType: DataType = IntegerType
  def nullable: Boolean = child.nullable
  def eval(input: InternalRow): Any = {
    val raw = child.eval(input)
    if (raw == null) null else DateTimeUtils.getMinute(raw.asInstanceOf[Long])
  }
}
/**
 * Extract the second from a timestamp, whose value is read as a Long of microseconds.
 * Null input yields null.
 */
case class Second(child: Expression) extends UnaryExpression {
  def dataType: DataType = IntegerType
  def nullable: Boolean = child.nullable
  def eval(input: InternalRow): Any = {
    val raw = child.eval(input)
    if (raw == null) null else DateTimeUtils.getSecond(raw.asInstanceOf[Long])
  }
}
/**
 * Current timestamp function.
 *
 * The clock is read once in `initializeInternal` (task start) and converted to microseconds.
 * Every row evaluated afterwards sees that same value, instead of a timestamp that keeps
 * advancing from row to row.
 */
case class CurrentTimestamp() extends LeafExpression with Nondeterministic {
  // Timestamp in microseconds captured by initializeInternal.
  @transient private[this] var capturedMicros: Long = 0L

  def dataType: DataType = TimestampType
  def nullable: Boolean = false
  def initializeInternal(partitionIndex: Int): Unit = {
    capturedMicros = System.currentTimeMillis() * 1000L // Convert to microseconds
  }
  def evalInternal(input: InternalRow): Any = capturedMicros
}
/**
 * Current date function.
 *
 * The date is computed once in `initializeInternal` (task start) via
 * `DateTimeUtils.millisToDays` and returned unchanged for every row evaluated afterwards.
 * NOTE(review): `timeZoneId` is not consulted here; confirm which time zone
 * `millisToDays(millis)` uses.
 */
case class CurrentDate(timeZoneId: Option[String] = None) extends LeafExpression with Nondeterministic {
  // Day count captured by initializeInternal.
  @transient private[this] var capturedDays: Int = 0

  def dataType: DataType = DateType
  def nullable: Boolean = false
  def initializeInternal(partitionIndex: Int): Unit = {
    capturedDays = DateTimeUtils.millisToDays(System.currentTimeMillis())
  }
  def evalInternal(input: InternalRow): Any = capturedDays
}
/**
 * Add a number of days to a date.
 *
 * @param startDate DateType expression, evaluated to an Int day count
 * @param days      expression evaluated to an Int number of days to add
 * @return the shifted date as an Int day count, or null if either input is null
 *         (`days` is not evaluated when `startDate` is null). Int addition wraps on overflow.
 */
case class DateAdd(startDate: Expression, days: Expression) extends BinaryExpression {
  def left: Expression = startDate
  def right: Expression = days
  def dataType: DataType = DateType
  def nullable: Boolean = startDate.nullable || days.nullable
  def eval(input: InternalRow): Any = {
    val startResult = startDate.eval(input)
    if (startResult == null) {
      null
    } else {
      val daysResult = days.eval(input)
      if (daysResult == null) null
      else startResult.asInstanceOf[Int] + daysResult.asInstanceOf[Int]
    }
  }
}
/**
 * Subtract a number of days from a date.
 *
 * @param startDate DateType expression, evaluated to an Int day count
 * @param days      expression evaluated to an Int number of days to subtract
 * @return the shifted date as an Int day count, or null if either input is null
 *         (`days` is not evaluated when `startDate` is null). Int subtraction wraps on overflow.
 */
case class DateSub(startDate: Expression, days: Expression) extends BinaryExpression {
  def left: Expression = startDate
  def right: Expression = days
  def dataType: DataType = DateType
  def nullable: Boolean = startDate.nullable || days.nullable
  def eval(input: InternalRow): Any = {
    val startResult = startDate.eval(input)
    if (startResult == null) {
      null
    } else {
      val daysResult = days.eval(input)
      if (daysResult == null) null
      else startResult.asInstanceOf[Int] - daysResult.asInstanceOf[Int]
    }
  }
}
Usage Examples:
// Date/time expression creation
val timestampAttr = AttributeReference("created_at", TimestampType)()
val dateAttr = AttributeReference("birth_date", DateType)()
val yearExtracted = Year(timestampAttr)
val monthExtracted = Month(dateAttr)
val hourExtracted = Hour(timestampAttr)
// Current date/time
val now = CurrentTimestamp()
val today = CurrentDate()
// Date arithmetic
val futureDate = DateAdd(dateAttr, Literal(30)) // Add 30 days
val pastDate = DateSub(dateAttr, Literal(7)) // Subtract 7 days
// Evaluation with sample data. The row layout is (created_at, birth_date), so bind the
// attribute references to those ordinals before calling eval.
val inputSchema = Seq(timestampAttr, dateAttr)
val currentTime = System.currentTimeMillis() * 1000L
val sampleRow = InternalRow(currentTime, DateTimeUtils.millisToDays(System.currentTimeMillis()))
val yearResult = BindReferences.bindReference(yearExtracted, inputSchema).eval(sampleRow) // Current year
val monthResult = BindReferences.bindReference(monthExtracted, inputSchema).eval(sampleRow) // Current month
val hourResult = BindReferences.bindReference(hourExtracted, inputSchema).eval(sampleRow) // Current hour
Conditional logic and branching expressions:
/**
 * Conditional if-then-else expression.
 *
 * Evaluates `trueValue` when `predicate` is TRUE and `falseValue` otherwise. A null
 * predicate is treated like false, as in SQL, so the result is null only if the selected
 * branch is null; this is what `nullable` assumes. Only the selected branch is evaluated.
 */
case class If(predicate: Expression, trueValue: Expression, falseValue: Expression) extends TernaryExpression {
  def first: Expression = predicate
  def second: Expression = trueValue
  def third: Expression = falseValue
  def dataType: DataType = trueValue.dataType
  def nullable: Boolean = trueValue.nullable || falseValue.nullable
  def eval(input: InternalRow): Any =
    if (java.lang.Boolean.TRUE.equals(predicate.eval(input))) trueValue.eval(input)
    else falseValue.eval(input)
}
/**
 * Case-when conditional expression.
 *
 * Conditions are evaluated in order. The value of the first branch whose condition is
 * TRUE is returned; null and false conditions are skipped. If no branch matches, the
 * result is `elseValue`, or null when there is no ELSE.
 */
case class CaseWhen(
    branches: Seq[(Expression, Expression)],
    elseValue: Option[Expression] = None) extends Expression {
  def children: Seq[Expression] =
    branches.flatMap { case (condition, value) => Seq(condition, value) } ++ elseValue.toSeq
  def dataType: DataType = branches.head._2.dataType
  // Without an ELSE branch, a row matching no condition evaluates to null.
  def nullable: Boolean = branches.exists(_._2.nullable) || elseValue.forall(_.nullable)
  def eval(input: InternalRow): Any = {
    // `find` stops at the first TRUE condition; no non-local `return` from a closure.
    val matched = branches.find { case (condition, _) =>
      java.lang.Boolean.TRUE.equals(condition.eval(input))
    }
    matched match {
      case Some((_, value)) => value.eval(input)
      case None => elseValue.map(_.eval(input)).orNull
    }
  }
}
/**
 * Return the value of the first child that evaluates to non-null, or null if all do.
 *
 * Children are evaluated lazily in order; evaluation stops at the first non-null value.
 */
case class Coalesce(children: Seq[Expression]) extends Expression {
  require(children.nonEmpty, "Coalesce must have at least one child")
  def dataType: DataType = children.head.dataType
  def nullable: Boolean = children.forall(_.nullable)
  // The iterator keeps evaluation lazy, without a non-local `return` from a for-closure.
  def eval(input: InternalRow): Any =
    children.iterator.map(_.eval(input)).find(_ != null).orNull
}
/**
 * Return null if `left` and `right` evaluate to equal values, otherwise the value of `left`.
 *
 * Both sides are always evaluated. Binary values (Array[Byte]) are compared by content.
 */
case class NullIf(left: Expression, right: Expression) extends BinaryExpression {
  def dataType: DataType = left.dataType
  def nullable: Boolean = true // Always nullable since it can return null
  def eval(input: InternalRow): Any = {
    val leftResult = left.eval(input)
    val rightResult = right.eval(input)
    if (valuesEqual(leftResult, rightResult)) null else leftResult
  }

  /** Value equality; `==` on arrays compares references, so byte arrays compare contents. */
  private def valuesEqual(a: Any, b: Any): Boolean = (a, b) match {
    case (x: Array[Byte], y: Array[Byte]) => java.util.Arrays.equals(x, y)
    case _ => a == b
  }
}
/**
 * If `left` evaluates to null, return the value of `right`; otherwise return `left`'s value.
 *
 * `right` is only evaluated when `left` is null. The result can be null only when both
 * sides can be null (the same rule Coalesce uses).
 */
case class Nvl(left: Expression, right: Expression) extends BinaryExpression {
  def dataType: DataType = left.dataType
  def nullable: Boolean = left.nullable && right.nullable
  def eval(input: InternalRow): Any = Option(left.eval(input)).getOrElse(right.eval(input))
}
/**
 * Three-argument null handling: return the value of `expr2` when `expr1` is non-null,
 * otherwise the value of `expr3`. Only the selected branch is evaluated.
 */
case class Nvl2(expr1: Expression, expr2: Expression, expr3: Expression) extends TernaryExpression {
  def first: Expression = expr1
  def second: Expression = expr2
  def third: Expression = expr3
  def dataType: DataType = expr2.dataType
  def nullable: Boolean = expr2.nullable || expr3.nullable
  def eval(input: InternalRow): Any =
    if (expr1.eval(input) != null) expr2.eval(input) else expr3.eval(input)
}
Usage Examples:
// Conditional expressions
val ageAttr = AttributeReference("age", IntegerType)()
val nameAttr = AttributeReference("name", StringType)()
val statusAttr = AttributeReference("status", StringType)()
// If-then-else
val ageCategory = If(
  GreaterThan(ageAttr, Literal(18)),
  Literal("Adult"),
  Literal("Minor")
)
// Case-when with multiple branches
val riskLevel = CaseWhen(
  branches = Seq(
    (GreaterThan(ageAttr, Literal(65)), Literal("High Risk")),
    (GreaterThan(ageAttr, Literal(40)), Literal("Medium Risk")),
    (GreaterThan(ageAttr, Literal(18)), Literal("Low Risk"))
  ),
  elseValue = Some(Literal("No Risk"))
)
// Coalesce - return first non-null
val preferredName = Coalesce(Seq(
  AttributeReference("nickname", StringType)(),
  AttributeReference("first_name", StringType)(),
  Literal("Unknown")
))
// Null handling
val statusOrDefault = Nvl(statusAttr, Literal("Active"))
// Evaluation with sample data. The row layout is (age, name, status), so bind the
// attribute references to those ordinals before calling eval.
val inputSchema = Seq(ageAttr, nameAttr, statusAttr)
val sampleRow = InternalRow(25, UTF8String.fromString("Alice"), null)
val result1 = BindReferences.bindReference(ageCategory, inputSchema).eval(sampleRow) // "Adult" (25 > 18)
val result2 = BindReferences.bindReference(riskLevel, inputSchema).eval(sampleRow) // "Low Risk" (25 > 18 but not > 40)
val result3 = BindReferences.bindReference(statusOrDefault, inputSchema).eval(sampleRow) // "Active" (status is null)
Aggregation operations for computing statistics across multiple rows:
/**
 * Base class for all aggregate expressions.
 *
 * Wraps an [[AggregateFunction]] and reports that function's result type, nullability
 * and children as its own.
 */
abstract class AggregateExpression extends Expression {
  /** The aggregate function implementation being applied. */
  def aggregateFunction: AggregateFunction
  /** Whether this aggregate is distinct. */
  def isDistinct: Boolean

  def children: Seq[Expression] = aggregateFunction.children
  def nullable: Boolean = aggregateFunction.nullable
  def dataType: DataType = aggregateFunction.dataType
}
/**
 * Base class for aggregate functions.
 *
 * Intermediate state lives in a buffer described by `aggBufferAttributes`. The function is
 * defined by four expression sets: seed the buffer, fold in one input row, merge two
 * partial buffers, and extract the final result.
 */
abstract class AggregateFunction extends Expression {
  /** Attributes describing the slots of the aggregation buffer. */
  def aggBufferAttributes: Seq[AttributeReference]

  /** One initial value per buffer slot. */
  def initialValues: Seq[Expression]

  /** New buffer values after consuming one input row. */
  def updateExpressions: Seq[Expression]

  /** Buffer values obtained by combining two partial buffers. */
  def mergeExpressions: Seq[Expression]

  /** Final result computed from the buffer. */
  def evaluateExpression: Expression
}
/**
 * Aggregate functions whose initialization, update, merge and evaluation steps are all
 * written as Catalyst expressions.
 */
abstract class DeclarativeAggregate extends AggregateFunction {
  // Adds no members of its own; the contract is exactly AggregateFunction's.
}
/**
 * Count aggregate function.
 *
 * With no children this is COUNT(*), which counts every row. Otherwise it counts the rows
 * in which every child is non-null. The result is a non-null Long, 0 for empty input.
 */
case class Count(children: Seq[Expression]) extends DeclarativeAggregate {
  def this(child: Expression) = this(Seq(child))

  def dataType: DataType = LongType
  def nullable: Boolean = false

  /** Single non-nullable Long buffer slot holding the running count. */
  lazy val count = AttributeReference("count", LongType, nullable = false)()

  def aggBufferAttributes: Seq[AttributeReference] = Seq(count)
  def initialValues: Seq[Expression] = Seq(Literal(0L))

  def updateExpressions: Seq[Expression] = {
    val perRow: Expression =
      if (children.isEmpty) {
        Add(count, Literal(1L)) // COUNT(*): every row counts
      } else {
        // COUNT(expr): only rows where all children are non-null count
        If(children.map(IsNotNull).reduce(And), Add(count, Literal(1L)), count)
      }
    Seq(perRow)
  }

  // Partial counts combine by addition; the final result is the count itself.
  def mergeExpressions: Seq[Expression] = Seq(Add(count.left, count.right))
  def evaluateExpression: Expression = count
}
/**
 * Sum aggregate function.
 *
 * Non-null inputs are added together and nulls are ignored. The result is null when no
 * non-null input was seen. The buffer holds the running `sum` plus an `isEmpty` flag that
 * stays true until the first non-null input arrives.
 */
case class Sum(child: Expression) extends DeclarativeAggregate {
  def children: Seq[Expression] = Seq(child)
  def dataType: DataType = child.dataType
  def nullable: Boolean = true
  // Buffer has sum and isEmpty flag
  lazy val sum = AttributeReference("sum", child.dataType, nullable = true)()
  lazy val isEmpty = AttributeReference("isEmpty", BooleanType, nullable = false)()
  def aggBufferAttributes: Seq[AttributeReference] = Seq(sum, isEmpty)
  // Initialize sum to null, isEmpty to true
  def initialValues: Seq[Expression] = Seq(Literal(null, child.dataType), Literal(true))
  // Add non-null values to sum
  def updateExpressions: Seq[Expression] = Seq(
    If(IsNull(child), sum,
      If(isEmpty, child, Add(sum, child))),
    If(IsNull(child), isEmpty, Literal(false))
  )
  // Merge sums: each side's emptiness comes from its `isEmpty` buffer slot, not from `sum`.
  def mergeExpressions: Seq[Expression] = Seq(
    If(isEmpty.left, sum.right,
      If(isEmpty.right, sum.left, Add(sum.left, sum.right))),
    And(isEmpty.left, isEmpty.right)
  )
  // Final result is sum (null if no non-null values)
  def evaluateExpression: Expression = If(isEmpty, Literal(null, dataType), sum)
}
/**
 * Maximum aggregate function.
 *
 * Keeps the greatest non-null input seen so far. The result is null if every input was null.
 */
case class Max(child: Expression) extends DeclarativeAggregate {
  def children: Seq[Expression] = Seq(child)
  def dataType: DataType = child.dataType
  def nullable: Boolean = true

  /** Running maximum; null until the first non-null input. */
  lazy val max = AttributeReference("max", child.dataType, nullable = true)()

  def aggBufferAttributes: Seq[AttributeReference] = Seq(max)
  def initialValues: Seq[Expression] = Seq(Literal(null, child.dataType))

  /** Null-aware "pick the larger": a null side is ignored, otherwise the greater wins. */
  private def larger(a: Expression, b: Expression): Expression =
    If(IsNull(a), b,
      If(IsNull(b), a,
        If(GreaterThan(a, b), a, b)))

  def updateExpressions: Seq[Expression] = Seq(larger(child, max))
  def mergeExpressions: Seq[Expression] = Seq(larger(max.left, max.right))
  def evaluateExpression: Expression = max
}
/**
 * Minimum aggregate function.
 *
 * Keeps the smallest non-null input seen so far. The result is null if every input was null.
 */
case class Min(child: Expression) extends DeclarativeAggregate {
  def children: Seq[Expression] = Seq(child)
  def dataType: DataType = child.dataType
  def nullable: Boolean = true

  /** Running minimum; null until the first non-null input. */
  lazy val min = AttributeReference("min", child.dataType, nullable = true)()

  def aggBufferAttributes: Seq[AttributeReference] = Seq(min)
  def initialValues: Seq[Expression] = Seq(Literal(null, child.dataType))

  /** Null-aware "pick the smaller": a null side is ignored, otherwise the lesser wins. */
  private def smaller(a: Expression, b: Expression): Expression =
    If(IsNull(a), b,
      If(IsNull(b), a,
        If(LessThan(a, b), a, b)))

  def updateExpressions: Seq[Expression] = Seq(smaller(child, min))
  def mergeExpressions: Seq[Expression] = Seq(smaller(min.left, min.right))
  def evaluateExpression: Expression = min
}
/**
 * Average aggregate function.
 *
 * Sums non-null inputs as Double and counts them. The result is sum / count, or null when
 * the count is 0.
 */
case class Average(child: Expression) extends DeclarativeAggregate {
  def children: Seq[Expression] = Seq(child)
  def dataType: DataType = DoubleType
  def nullable: Boolean = true

  /** Running Double sum of non-null inputs; null until the first one arrives. */
  lazy val sum = AttributeReference("sum", DoubleType, nullable = true)()
  /** Number of non-null inputs seen so far. */
  lazy val count = AttributeReference("count", LongType, nullable = false)()

  def aggBufferAttributes: Seq[AttributeReference] = Seq(sum, count)
  def initialValues: Seq[Expression] = Seq(Literal(null, DoubleType), Literal(0L))

  def updateExpressions: Seq[Expression] = {
    val asDouble = Cast(child, DoubleType)
    val newSum = If(IsNull(child), sum, If(IsNull(sum), asDouble, Add(sum, asDouble)))
    val newCount = If(IsNull(child), count, Add(count, Literal(1L)))
    Seq(newSum, newCount)
  }

  def mergeExpressions: Seq[Expression] = {
    val mergedSum = If(IsNull(sum.left), sum.right,
      If(IsNull(sum.right), sum.left, Add(sum.left, sum.right)))
    Seq(mergedSum, Add(count.left, count.right))
  }

  def evaluateExpression: Expression = {
    val noInput = EqualTo(count, Literal(0L))
    If(noInput, Literal(null, DoubleType), Divide(sum, Cast(count, DoubleType)))
  }
}
Usage Examples:
// Aggregate expression creation
val ageAttr = AttributeReference("age", IntegerType)()
val salaryAttr = AttributeReference("salary", DoubleType)()

val totalCount = Count(Seq.empty)   // COUNT(*)
val ageCount = Count(Seq(ageAttr))  // COUNT(age)
val totalSalary = Sum(salaryAttr)   // SUM(salary)
val maxAge = Max(ageAttr)           // MAX(age)
val minSalary = Min(salaryAttr)     // MIN(salary)
val avgSalary = Average(salaryAttr) // AVG(salary)

// Usage in aggregation context (conceptual - normally used in query plans)
val aggregates = Seq(
  totalCount -> "total_employees",
  avgSalary -> "avg_salary",
  maxAge -> "max_age"
).map { case (aggregate, outputName) => Alias(aggregate, outputName)() }
Catalyst's code generation system for high-performance expression evaluation:
/**
 * Context for Java code generation.
 *
 * Accumulates what a generated class needs besides the expression code itself. That is:
 * objects referenced from generated code, private member variables and their
 * initialization code, extra helper functions, and per-prefix counters for unique names.
 */
class CodegenContext {
  // Internal state
  private val references = mutable.ArrayBuffer[AnyRef]()
  private val mutableStateInitCode = mutable.ArrayBuffer[String]()
  private val mutableStateVars = mutable.ArrayBuffer[String]()
  private val freshNameIds = mutable.HashMap[String, Int]()
  private val addedFunctions = mutable.ArrayBuffer[String]()

  /**
   * Register an object for use by the generated code.
   *
   * @return its variable name, `obj0`, `obj1`, ... taken from its position in the
   *         reference list, so every reference gets a distinct name
   */
  def addReferenceObj(obj: AnyRef): String = {
    val objName = s"obj${references.size}"
    references += obj
    objName
  }

  /**
   * Add a private mutable member variable to the generated class.
   *
   * @param javaType     Java type of the variable
   * @param variableName requested name; when empty, a fresh `mutableState_N` name is used
   * @param initCode     initialization statements; empty code is not recorded
   * @return the variable's name
   */
  def addMutableState(
      javaType: String,
      variableName: String,
      initCode: String = ""): String = {
    val name = if (variableName.nonEmpty) variableName else freshName("mutableState")
    if (initCode.nonEmpty) mutableStateInitCode += initCode
    mutableStateVars += s"private $javaType $name;"
    name
  }

  /** Generate a name unique within this context for the given prefix: `name_0`, `name_1`, ... */
  def freshName(name: String): String = {
    val count = freshNameIds.getOrElse(name, 0)
    freshNameIds(name) = count + 1
    s"${name}_$count"
  }

  /** Add a function's source to the generated class and return `funcName` for calling it. */
  def addNewFunction(funcName: String, funcCode: String): String = {
    addedFunctions += funcCode
    funcName
  }

  /**
   * Generate code for many expressions in groups of at most `splitSize`, one sub-function
   * per group (named `functionName_0`, `functionName_1`, ...), to avoid Java method size limits.
   *
   * @param splitSize maximum number of expressions per sub-function (default 1000)
   * @return the per-group code, one group per line
   */
  def splitExpressionsByRows(
      expressions: Seq[Expression],
      functionName: String,
      splitSize: Int = 1000): String = {
    val splitExprs = expressions.grouped(splitSize).zipWithIndex.map { case (exprs, index) =>
      val subFuncName = s"${functionName}_$index"
      val codes = exprs.map(_.genCode(this))
      // NOTE(review): generateSubFunction is not defined in this class as shown; confirm its source.
      generateSubFunction(subFuncName, codes)
    }
    // A real line break: the escaped form "\\n" would emit a literal backslash-n into the Java source.
    splitExprs.mkString("\n")
  }
}
/**
 * Generated code for an expression.
 *
 * The compiler-generated `copy` (all three parameters defaulting to the current values)
 * is used as-is, which is why no hand-written copy method is defined here.
 *
 * @param code   Java statements that compute the expression
 * @param isNull name of the Java boolean that is true when the result is null
 * @param value  name of the Java variable holding the result
 */
case class ExprCode(code: String, isNull: String, value: String)
/**
 * Utility functions for code generation.
 */
object CodeGenerator {
  /**
   * Generate Java code for evaluating `expressions`.
   *
   * @return (evaluation statements, one expression's code per line;
   *          the null-flag names joined by ", ";
   *          the value names joined by ", ")
   */
  def generateCode(ctx: CodegenContext, expressions: Seq[Expression]): (String, String, String) = {
    val codes = expressions.map(_.genCode(ctx))
    // Join with a real newline; "\\n" would put a literal backslash-n into the Java source.
    val evalCodes = codes.map(_.code).mkString("\n")
    val nullChecks = codes.map(_.isNull).mkString(", ")
    val values = codes.map(_.value).mkString(", ")
    (evalCodes, nullChecks, values)
  }

  /** Java type mapping from Catalyst DataType */
  def javaType(dataType: DataType): String = dataType match {
    case BooleanType => "boolean"
    case ByteType => "byte"
    case ShortType => "short"
    case IntegerType => "int"
    case LongType => "long"
    case FloatType => "float"
    case DoubleType => "double"
    case StringType => "UTF8String"
    case BinaryType => "byte[]"
    case DateType => "int"
    case TimestampType => "long"
    case _: DecimalType => "Decimal"
    case _: ArrayType => "ArrayData"
    case _: MapType => "MapData"
    case _: StructType => "InternalRow"
    case _ => "Object"
  }

  /**
   * Default Java literal for a variable of `javaType(dataType)`.
   *
   * Dates and timestamps are primitive `int`/`long` in generated code, so they need numeric
   * defaults; `null` would not compile for a primitive.
   */
  def defaultValue(dataType: DataType): String = dataType match {
    case BooleanType => "false"
    case ByteType | ShortType | IntegerType | DateType => "0"
    case LongType | TimestampType => "0L"
    case FloatType => "0.0f"
    case DoubleType => "0.0"
    case _ => "null"
  }
}
/**
 * Code generation support for expressions.
 *
 * Provides `genCode`, which declares a null flag (initially false) and a value variable
 * (initially the type's default value), then appends the subclass-specific code that
 * fills them in.
 */
trait CodegenSupport extends Expression {
  def genCode(ctx: CodegenContext): ExprCode = {
    val isNullName = ctx.freshName("isNull")
    val valueName = ctx.freshName("value")
    val javaTypeName = CodeGenerator.javaType(dataType)
    val initialValue = CodeGenerator.defaultValue(dataType)
    val generated =
      s"""
         |boolean $isNullName = false;
         |$javaTypeName $valueName = $initialValue;
         |${doGenCode(ctx, isNullName, valueName)}
         |""".stripMargin
    ExprCode(generated, isNullName, valueName)
  }

  /**
   * Expression-specific code that assigns `nullVar` and/or `valueVar`; both variables are
   * already declared when this code runs.
   */
  protected def doGenCode(ctx: CodegenContext, nullVar: String, valueVar: String): String
}
/**
 * Example: Add expression with code generation.
 *
 * Null if either operand is null; otherwise the sum of the two operands in `dataType`.
 * Byte and short sums are narrowed back to their type, which wraps on overflow.
 */
case class Add(left: Expression, right: Expression) extends BinaryArithmetic with CodegenSupport {
  def symbol: String = "+"

  protected def doGenCode(ctx: CodegenContext, nullVar: String, valueVar: String): String = {
    val leftGen = left.genCode(ctx)
    val rightGen = right.genCode(ctx)
    val plainSum = s"${leftGen.value} + ${rightGen.value}"
    // Java promotes byte/short operands to int, so the sum must be cast back to be assignable.
    val sum = dataType match {
      case ByteType | ShortType => s"(${CodeGenerator.javaType(dataType)})($plainSum)"
      case _ => plainSum
    }
    s"""
       |${leftGen.code}
       |${rightGen.code}
       |
       |if (${leftGen.isNull} || ${rightGen.isNull}) {
       | $nullVar = true;
       |} else {
       | $valueVar = $sum;
       |}
       |""".stripMargin
  }

  // Fallback interpreted evaluation
  def nullSafeEval(input1: Any, input2: Any): Any = dataType match {
    case ByteType => (input1.asInstanceOf[Byte] + input2.asInstanceOf[Byte]).toByte
    case ShortType => (input1.asInstanceOf[Short] + input2.asInstanceOf[Short]).toShort
    case IntegerType => input1.asInstanceOf[Int] + input2.asInstanceOf[Int]
    case LongType => input1.asInstanceOf[Long] + input2.asInstanceOf[Long]
    case FloatType => input1.asInstanceOf[Float] + input2.asInstanceOf[Float]
    case DoubleType => input1.asInstanceOf[Double] + input2.asInstanceOf[Double]
    // Decimal and other types are not covered by this example
  }
}
Usage Examples:
// Code generation context
val ctx = new CodegenContext()
// Expression for code generation
val ageAttr = AttributeReference("age", IntegerType)()
val expr = Add(ageAttr, Literal(10))
// Bind the attribute reference to its ordinal in the input row (age -> column 0).
// That ordinal is what the generated code reads via input.getInt(0) below.
val boundExpr = BindReferences.bindReference(expr, Seq(ageAttr))
// Generate code
val generated = boundExpr.genCode(ctx)
println(generated.code)
// Output: Java code for evaluating age + 10
// Generated code structure:
/*
boolean isNull_0 = false;
int value_0 = 0;
// Load age from input row
boolean isNull_1 = input.isNullAt(0);
int value_1 = isNull_1 ? 0 : input.getInt(0);
// Literal 10
boolean isNull_2 = false;
int value_2 = 10;
// Add operation
if (isNull_1 || isNull_2) {
  isNull_0 = true;
} else {
  value_0 = value_1 + value_2;
}
*/
// Performance: Generated code is much faster than interpretation
// - No virtual method calls
// - No boxing/unboxing
// - Direct memory access
// - JIT-friendly code patterns
The Catalyst expression system provides a powerful and extensible framework for representing and evaluating computations in Spark SQL. Key capabilities include:
Expression Hierarchy: Base Expression class with specialized traits for different patterns (Unary, Binary, Ternary, etc.)
Type Safety: Strong typing with DataType system integration and comprehensive type checking
Evaluation Modes: Both interpreted evaluation (eval) and high-performance code generation (genCode)
Rich Function Library: Comprehensive set of built-in operations for arithmetic, comparisons, string manipulation, date/time operations, complex types, aggregation, and more
Named Expressions: Support for aliasing and attribute references enabling query planning and optimization
Conditional Logic: Full support for if-then-else, case-when, and null handling operations
Code Generation: Advanced code generation framework that produces optimized Java code for high-performance execution
Extensibility: Clean interfaces for adding custom expressions and functions
This expression system forms the foundation for all SQL operations in Spark, enabling both correctness through strong typing and performance through code generation optimization.
Install with Tessl CLI
npx tessl i tessl/maven-org-apache-spark--spark-catalyst