Extensible expression evaluation framework supporting complex expression trees, type checking, code generation, and high-performance evaluation for SQL operations.
The abstract base class for all expressions in Catalyst.
/**
* An expression in Catalyst.
* If an expression wants to be exposed in the function registry, the concrete implementation
* must be a case class whose constructor arguments are all Expression types.
*/
abstract class Expression extends TreeNode[Expression] {
/**
* Returns true when an expression is a candidate for static evaluation before the query is executed.
* The base implementation returns false; subclasses opt in by overriding.
* The following conditions are used to determine suitability for constant folding:
* - A Coalesce is foldable if all of its children are foldable
* - A BinaryExpression is foldable if both its left and right children are foldable
* - A Not, IsNull, or IsNotNull is foldable if its child is foldable
* - A Literal is foldable
* - A Cast or UnaryMinus is foldable if its child is foldable
*/
def foldable: Boolean = false
/**
* Returns true when the current expression always returns the same result for fixed inputs from children.
* The base implementation is true exactly when every child is deterministic.
* Note that an expression should be considered non-deterministic if:
* - it relies on some mutable internal state, or
* - it relies on some implicit input that is not part of the children expression list, or
* - it has a non-deterministic child or children.
*/
def deterministic: Boolean = children.forall(_.deterministic)
/** Whether this expression can return null */
def nullable: Boolean
/** Set of attributes referenced by this expression: by default, the union of the children's references */
def references: AttributeSet = AttributeSet(children.flatMap(_.references.iterator))
/**
* Returns the result of evaluating this expression on a given input Row.
* `input` defaults to null, so `eval()` can be called with no row.
*/
def eval(input: InternalRow = null): Any
/** Data type of the result of evaluating this expression */
def dataType: DataType
/**
* Returns Java source code that can be used to generate the result of evaluating the expression.
* @param ctx a CodeGenContext
* @return GeneratedExpressionCode
*/
def gen(ctx: CodeGenContext): GeneratedExpressionCode
/**
* Returns Java source code that can be compiled to evaluate this expression.
* The default behavior is to call the eval method of the expression.
* NOTE(review): no default body is visible in this listing (the member is abstract here) - confirm
* where the described default lives.
* @param ctx a CodeGenContext
* @param ev a GeneratedExpressionCode with unique terms
* @return Java source code
*/
protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String
/**
* Returns true if this expression and all its children have been resolved to a specific schema
* and input data types checking passed, and false if it still contains any unresolved placeholders.
* Declared as a lazy val, so the result is computed on first access and then reused.
*/
lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess
/**
* Returns true if all the children of this expression have been resolved to a specific schema
* and false if any of them still contains unresolved placeholders.
*/
def childrenResolved: Boolean = children.forall(_.resolved)
/**
* Returns true when two expressions will always compute the same result, even if they differ
* cosmetically (e.g. capitalization of names in attributes may be different).
*/
def semanticEquals(other: Expression): Boolean
/**
* Returns the hash for this expression. Expressions that compute the same result, even if
* they differ cosmetically, should return the same hash.
*/
def semanticHash(): Int
/**
* Returns a user-facing string representation of this expression's name.
* This should usually match the name of the function in SQL.
* The default is the lower-cased simple class name.
*/
def prettyName: String = getClass.getSimpleName.toLowerCase
/**
* Returns a user-facing string representation of this expression, i.e. does not have developer
* centric debugging information like the expression id.
*/
def prettyString: String
/** Validate input data types and return the type check result; the default reports success. */
def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.TypeCheckSuccess
}Usage Examples:
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
// Create expressions
val literal = Literal(42, IntegerType)
val attr = AttributeReference("x", IntegerType, nullable = false)()
val add = Add(attr, literal)
// Check expression properties
val foldable = literal.foldable // true (literals are foldable)
val addFoldable = add.foldable // false (contains non-literal)
val deterministic = add.deterministic // true (always same result for same input)
val nullable = add.nullable // false (int + int is never null)
val dataType = add.dataType // IntegerType
// Evaluate expression
// AttributeReference.eval() throws UnsupportedOperationException, so evaluation uses the same
// expression with the attribute replaced by a BoundReference to ordinal 0 of the input row.
val boundAdd = Add(BoundReference(0, IntegerType, nullable = false), literal)
val row = InternalRow(10) // x = 10 at ordinal 0
val result = boundAdd.eval(row) // 52 (10 + 42)
// Get referenced attributes
val refs = add.references // AttributeSet containing "x" attribute

Base traits defining the structure of expression trees.
/**
 * A terminal node of an expression tree: an expression without any child expressions.
 */
trait LeafExpression extends Expression {
  /** Always the empty sequence. */
  def children: Seq[Expression] = Seq.empty
}
/**
 * An expression that has one child expression.
 */
trait UnaryExpression extends Expression {
  /** The single child expression. */
  def child: Expression

  /** The child wrapped in a one-element sequence. */
  def children: Seq[Expression] = child :: Nil

  /**
   * A unary expression is foldable when its child is. This supplies the rule documented on
   * `Expression.foldable` for Cast, UnaryMinus, Not, IsNull and IsNotNull, none of which
   * override `foldable` themselves.
   */
  override def foldable: Boolean = child.foldable
}
/**
 * An expression that has two child expressions.
 */
trait BinaryExpression extends Expression {
  /** Left-hand child. */
  def left: Expression

  /** Right-hand child. */
  def right: Expression

  /** Both children, left first. */
  def children: Seq[Expression] = Seq(left, right)

  /** Foldable only when both children are foldable, as documented on `Expression.foldable`. */
  override def foldable: Boolean = left.foldable && right.foldable

  /**
   * Default nullability: the result may be null when either child may be null.
   * The binary operators in this file do not define `nullable` themselves and rely on this.
   */
  override def nullable: Boolean = left.nullable || right.nullable
}
/**
* A special case of BinaryExpression that requires two children to have the same output data type.
*/
trait BinaryOperator extends BinaryExpression {
/** Expected input type from both left and right child expressions */
def inputType: AbstractDataType
/**
* Type-checks the two children.
* NOTE(review): the body below is only a placeholder comment; the actual validation logic is not
* shown here, so the exact rules (presumably checking both children against `inputType`) need
* to be confirmed against the real implementation.
*/
override def checkInputDataTypes(): TypeCheckResult = {
// Validates that both children have compatible types
}
}
/**
 * An expression that has three child expressions.
 */
abstract class TernaryExpression extends Expression {
  /** Foldable only when every child is foldable. */
  override def foldable: Boolean = children.forall(_.foldable)

  /** May return null when any child may return null. */
  override def nullable: Boolean = children.exists(_.nullable)

  /**
   * Default evaluation. The three children are evaluated in order; evaluation stops and null is
   * returned as soon as one of them yields null, otherwise the three values are handed to
   * `nullSafeEval`. Subclasses with different null semantics should override this method.
   */
  override def eval(input: InternalRow): Any = {
    val first = children(0).eval(input)
    if (first == null) null
    else {
      val second = children(1).eval(input)
      if (second == null) null
      else {
        val third = children(2).eval(input)
        if (third == null) null else nullSafeEval(first, second, third)
      }
    }
  }

  /**
   * Called by the default eval implementation with three non-null inputs. Subclasses that keep
   * the default nullability can override this method to save null-check code.
   */
  protected def nullSafeEval(input1: Any, input2: Any, input3: Any): Any =
    sys.error(s"TernaryExpressions must override either eval or nullSafeEval")
}

Usage Examples:
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
// Leaf expression (no children)
val constant = Literal(100, IntegerType)
val children1 = constant.children // Nil
// Unary expression (one child)
val negation = UnaryMinus(constant)
val children2 = negation.children // Seq(constant)
val child = negation.child // constant
// Binary expression (two children)
val attr = AttributeReference("y", IntegerType, nullable = false)()
val addition = Add(attr, constant)
val children3 = addition.children // Seq(attr, constant)
val left = addition.left // attr
val right = addition.right // constant
// Binary operator with type constraints
val comparison = EqualTo(attr, constant) // Both sides must be same type
val typeCheck = comparison.checkInputDataTypes() // Validates type compatibility

Special traits that modify expression behavior and evaluation.
/**
 * Marker for expressions whose result can differ between evaluations of the same input.
 * Such expressions are never deterministic and never foldable, and must be initialized through
 * `setInitialValues()` before `eval` is called.
 */
trait Nondeterministic extends Expression {
  /** Set by `setInitialValues()`; `eval` refuses to run while this is still false. */
  private[this] var initialized = false

  final override def deterministic: Boolean = false

  final override def foldable: Boolean = false

  /**
   * Prepares this expression for evaluation: runs the subclass hook `initInternal()` and then
   * records that initialization has happened.
   */
  final def setInitialValues(): Unit = {
    initInternal()
    initialized = true
  }

  /**
   * Checks that `setInitialValues()` has been called and delegates to `evalInternal`.
   * Fails with an IllegalArgumentException (via `require`) when called before initialization.
   */
  final override def eval(input: InternalRow = null): Any = {
    require(initialized, "nondeterministic expression should be initialized before evaluate")
    evalInternal(input)
  }

  /** Subclass hook: set up any internal state before the first evaluation. */
  protected def initInternal(): Unit

  /** Subclass hook: the actual evaluation, invoked only after initialization. */
  protected def evalInternal(input: InternalRow): Any
}
/**
 * An expression that is not supposed to be evaluated.
 * Used for expressions that are only meaningful during analysis.
 */
trait Unevaluable extends Expression {
  /**
   * Always fails: unevaluable expressions have no runtime value.
   * The message names the offending expression so the failure can be traced.
   */
  override def eval(input: InternalRow): Any =
    throw new UnsupportedOperationException(s"Cannot evaluate expression: $this")
}
/**
* An expression that does not have code gen implemented and falls back to interpreted mode.
* NOTE(review): the trait body shown here is empty; the fallback mechanism itself is not visible
* in this listing.
*/
trait CodegenFallback extends Expression {
// Falls back to interpreted evaluation when code generation is not available
}
/**
* Expressions that expect specific input types and can provide helpful error messages.
*/
trait ExpectsInputTypes extends Expression {
/** Expected input types for child expressions */
def inputTypes: Seq[AbstractDataType]
/**
* Type-checks the children against `inputTypes`.
* NOTE(review): only a placeholder comment is shown as the body; how children are paired with
* the entries of `inputTypes` is not visible here and should be confirmed.
*/
override def checkInputDataTypes(): TypeCheckResult = {
// Validates children match expected input types
}
}Usage Examples:
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.util.random.RandomSeed
// Nondeterministic expression
// Nondeterministic declares eval as final, so a concrete expression supplies the two hooks
// initInternal() and evalInternal(input) instead of defining eval itself.
case class RandomExpression(seed: Long) extends LeafExpression with Nondeterministic {
  private var random: scala.util.Random = _

  /** Re-seeds the generator for a specific partition (seed + partitionIndex). */
  def initialize(partitionIndex: Int): Unit = {
    random = new scala.util.Random(seed + partitionIndex)
  }

  /**
   * Hook run by setInitialValues(). It receives no partition index, so partition 0 is used;
   * call initialize(partitionIndex) afterwards when a per-partition seed is needed.
   */
  override protected def initInternal(): Unit = initialize(0)

  /** Hook run by eval once the expression has been initialized; the input row is not used. */
  override protected def evalInternal(input: InternalRow): Any = random.nextDouble()

  def dataType: DataType = DoubleType
  def nullable: Boolean = false
}
// Unevaluable expression (used during analysis only)
// Both dataType and nullable throw UnresolvedException: neither is available for an
// attribute that has not been resolved yet.
case class UnresolvedAttribute(name: String) extends LeafExpression with Unevaluable {
def dataType: DataType = throw new UnresolvedException(this, "dataType")
def nullable: Boolean = throw new UnresolvedException(this, "nullable")
}
// Expression with expected input types
case class Substring(str: Expression, pos: Expression, len: Expression)
  extends Expression with ExpectsInputTypes {
  def inputTypes: Seq[AbstractDataType] = Seq(StringType, IntegerType, IntegerType)
  def children: Seq[Expression] = Seq(str, pos, len)
  def dataType: DataType = StringType
  def nullable: Boolean = children.exists(_.nullable)

  /**
   * Returns `len` characters of the string starting at the 1-based position `pos`,
   * or null when any of the three arguments evaluates to null.
   */
  def eval(input: InternalRow): Any = {
    val string = str.eval(input)
    val position = pos.eval(input)
    val length = len.eval(input)
    // Check for null before casting: a null string would fail with a NullPointerException and
    // a null Int would silently unbox to 0.
    if (string == null || position == null || length == null) null
    else {
      val start = position.asInstanceOf[Int] - 1 // convert 1-based position to 0-based index
      string.asInstanceOf[UTF8String].substring(start, start + length.asInstanceOf[Int])
    }
  }
}

Fundamental expression implementations for common operations.
/**
 * A constant: evaluating it yields the wrapped value regardless of the input row.
 */
case class Literal(value: Any, dataType: DataType) extends LeafExpression {
  /** The wrapped constant; the input row is ignored. */
  override def eval(input: InternalRow): Any = value

  /** Nullable exactly when the wrapped value is null. */
  override def nullable: Boolean = Option(value).isEmpty

  /** Literals are always candidates for constant folding. */
  override def foldable: Boolean = true
}
object Literal {
/** Create literal from Scala value with automatic type inference */
def apply(v: Any): Literal
/**
* Create a literal from a value together with an explicitly supplied data type.
* NOTE(review): this was previously described as creating a "null literal", but the signature
* accepts an arbitrary value `v`; confirm the intended contract against the implementation.
*/
def create(v: Any, dataType: DataType): Literal
/** Shared boolean literals (true and false, both of BooleanType) */
val TrueLiteral: Literal = Literal(true, BooleanType)
val FalseLiteral: Literal = Literal(false, BooleanType)
}
/**
* Reference to an attribute/column in a relation.
* The second parameter list (exprId, qualifiers) has defaults, which is why instances are
* created as `AttributeReference(name, dataType, nullable)()`.
*/
case class AttributeReference(
name: String,
dataType: DataType,
nullable: Boolean,
metadata: Metadata = Metadata.empty)(
val exprId: ExprId = NamedExpression.newExprId,
val qualifiers: Seq[String] = Nil) extends Attribute {
/**
* Always throws UnsupportedOperationException: an attribute reference cannot be evaluated
* directly against a row.
*/
override def eval(input: InternalRow): Any =
throw new UnsupportedOperationException("AttributeReference.eval() not supported")
/** Create a new copy with different nullability */
def withNullability(newNullability: Boolean): AttributeReference
/** Create a new copy with different name */
def withName(newName: String): AttributeReference
/** Create a new copy with different data type */
def withDataType(newType: DataType): AttributeReference
}
/**
* Reference bound to a specific input position.
*/
case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
extends LeafExpression {
/**
* Reads the value at position `ordinal` from the input row, passing `dataType` to the row's
* `get` accessor. The row is dereferenced unconditionally, so a non-null row must be supplied.
*/
override def eval(input: InternalRow): Any = input.get(ordinal, dataType)
}
/**
 * Converts the value produced by `child` to `dataType`.
 */
case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
  /** A cast can yield null whenever its child can. */
  override def nullable: Boolean = child.nullable

  /** Null inputs pass through as null; every other value goes through `cast`. */
  override def eval(input: InternalRow): Any = child.eval(input) match {
    case null => null
    case nonNull => cast(nonNull)
  }

  /** Conversion of a single non-null value (body elided in this listing). */
  private def cast(value: Any): Any = {
    // Type conversion logic based on source and target types
  }
}

Usage Examples:
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
// Create literals
val intLit = Literal(42, IntegerType)
val stringLit = Literal("hello", StringType)
val nullLit = Literal(null, StringType)
val autoLit = Literal("world") // Type inferred as StringType
// Evaluate literals
val value1 = intLit.eval() // 42
val value2 = stringLit.eval() // "hello"
val value3 = nullLit.eval() // null
// Create attribute references
val userId = AttributeReference("user_id", IntegerType, nullable = false)()
val userName = AttributeReference("user_name", StringType, nullable = true)()
// Create bound references (for compiled expressions)
val boundId = BoundReference(0, IntegerType, nullable = false) // First column
val boundName = BoundReference(1, StringType, nullable = true) // Second column
// Cast expressions
val stringToInt = Cast(stringLit, IntegerType)
val intToString = Cast(intLit, StringType)
// Build complex expressions
val userIdPlusOne = Add(userId, Literal(1, IntegerType))
val userGreeting = Concat(Seq(Literal("Hello "), userName))

Mathematical and comparison operations.
// Arithmetic operations
/**
 * Addition of two numeric children; the result type is that of the left child.
 */
case class Add(left: Expression, right: Expression) extends BinaryOperator {
  def inputType: AbstractDataType = NumericType

  override def dataType: DataType = left.dataType

  /**
   * Evaluates both children, then returns null if either result is null and otherwise their
   * sum as computed by `numeric.plus` (`numeric` is defined outside this listing).
   */
  override def eval(input: InternalRow): Any =
    (left.eval(input), right.eval(input)) match {
      case (null, _) | (_, null) => null
      case (l, r) => numeric.plus(l, r)
    }
}
/**
* Subtraction operator stub: expects NumericType inputs and reports the left child's data type
* as its result type. No eval is shown in this listing.
*/
case class Subtract(left: Expression, right: Expression) extends BinaryOperator {
def inputType: AbstractDataType = NumericType
override def dataType: DataType = left.dataType
}
/**
* Multiplication operator stub: expects NumericType inputs and reports the left child's data
* type as its result type. No eval is shown in this listing.
*/
case class Multiply(left: Expression, right: Expression) extends BinaryOperator {
def inputType: AbstractDataType = NumericType
override def dataType: DataType = left.dataType
}
/**
* Division operator stub: expects NumericType inputs and reports the left child's data type as
* its result type (so integer operands give an integer-typed result). No eval is shown in this
* listing.
*/
case class Divide(left: Expression, right: Expression) extends BinaryOperator {
def inputType: AbstractDataType = NumericType
override def dataType: DataType = left.dataType
}
/**
* Unary minus stub: data type and nullability are both taken from the child.
* No eval is shown in this listing.
*/
case class UnaryMinus(child: Expression) extends UnaryExpression {
override def dataType: DataType = child.dataType
override def nullable: Boolean = child.nullable
}
// Comparison operations
/**
* Equality comparison stub: accepts any data type (AnyDataType) and produces a BooleanType
* result. No eval is shown in this listing.
*/
case class EqualTo(left: Expression, right: Expression) extends BinaryOperator {
def inputType: AbstractDataType = AnyDataType
override def dataType: DataType = BooleanType
}
/**
* Less-than comparison stub: inputs must belong to TypeCollection.Ordered; the result is
* BooleanType. No eval is shown in this listing.
*/
case class LessThan(left: Expression, right: Expression) extends BinaryOperator {
def inputType: AbstractDataType = TypeCollection.Ordered
override def dataType: DataType = BooleanType
}
/**
* Greater-than comparison stub: inputs must belong to TypeCollection.Ordered; the result is
* BooleanType. No eval is shown in this listing.
*/
case class GreaterThan(left: Expression, right: Expression) extends BinaryOperator {
def inputType: AbstractDataType = TypeCollection.Ordered
override def dataType: DataType = BooleanType
}
/**
* Less-than-or-equal comparison stub: inputs must belong to TypeCollection.Ordered; the result
* is BooleanType. No eval is shown in this listing.
*/
case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryOperator {
def inputType: AbstractDataType = TypeCollection.Ordered
override def dataType: DataType = BooleanType
}
/**
* Greater-than-or-equal comparison stub: inputs must belong to TypeCollection.Ordered; the
* result is BooleanType. No eval is shown in this listing.
*/
case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryOperator {
def inputType: AbstractDataType = TypeCollection.Ordered
override def dataType: DataType = BooleanType
}Usage Examples:
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
// Arithmetic expressions
// The operands are BoundReferences (row ordinals 0 and 1) rather than AttributeReferences,
// because AttributeReference.eval() throws UnsupportedOperationException.
val a = BoundReference(0, IntegerType, nullable = false)
val b = BoundReference(1, IntegerType, nullable = false)
val sum = Add(a, b)
val diff = Subtract(a, b)
val product = Multiply(a, b)
val quotient = Divide(a, b)
val negation = UnaryMinus(a)
// Evaluate arithmetic
val row = InternalRow(10, 3) // a = 10, b = 3
val sumResult = sum.eval(row) // 13
val diffResult = diff.eval(row) // 7
val productResult = product.eval(row) // 30
// Divide.dataType is left.dataType (IntegerType here), so the result is an integer, not 3.33...
val quotientResult = quotient.eval(row) // 3
val negResult = negation.eval(row) // -10
// Comparison expressions
val equal = EqualTo(a, b)
val less = LessThan(a, b)
val greater = GreaterThan(a, b)
val lessEqual = LessThanOrEqual(a, b)
val greaterEqual = GreaterThanOrEqual(a, b)
// Evaluate comparisons
val equalResult = equal.eval(row) // false (10 != 3)
val lessResult = less.eval(row) // false (10 > 3)
val greaterResult = greater.eval(row) // true (10 > 3)
// Complex arithmetic expressions
val formula = Add(Multiply(a, Literal(2)), Subtract(b, Literal(1)))
// Equivalent to: (a * 2) + (b - 1) = (10 * 2) + (3 - 1) = 20 + 2 = 22
val formulaResult = formula.eval(row) // 22

Boolean logic and string manipulation operations.
// Logical operations
/**
 * Logical conjunction with SQL three-valued semantics.
 */
case class And(left: Expression, right: Expression) extends BinaryOperator {
  def inputType: AbstractDataType = BooleanType

  override def dataType: DataType = BooleanType

  /**
   * A false operand decides the result, and a false left operand short-circuits so the right
   * side is not evaluated. Otherwise the result is null if either side is null, and true only
   * when both sides are true.
   */
  override def eval(input: InternalRow): Any = left.eval(input) match {
    case false => false
    case lhs =>
      right.eval(input) match {
        case false => false
        case rhs => if (lhs == null || rhs == null) null else true
      }
  }
}
/**
* Logical disjunction stub: expects BooleanType inputs and produces a BooleanType result.
* No eval is shown in this listing.
*/
case class Or(left: Expression, right: Expression) extends BinaryOperator {
def inputType: AbstractDataType = BooleanType
override def dataType: DataType = BooleanType
}
/**
 * Logical negation; a null input stays null.
 */
case class Not(child: Expression) extends UnaryExpression {
  override def dataType: DataType = BooleanType

  override def nullable: Boolean = child.nullable

  /** Returns null for a null child value, otherwise the negated Boolean. */
  override def eval(input: InternalRow): Any = child.eval(input) match {
    case null => null
    case operand => !operand.asInstanceOf[Boolean]
  }
}
// Null checking
/**
 * True when the child evaluates to null. The result itself is never null.
 */
case class IsNull(child: Expression) extends UnaryExpression {
  override def dataType: DataType = BooleanType

  override def nullable: Boolean = false

  override def eval(input: InternalRow): Any = {
    val childValue = child.eval(input)
    childValue == null
  }
}
/**
 * True when the child evaluates to a non-null value. The result itself is never null.
 */
case class IsNotNull(child: Expression) extends UnaryExpression {
  override def dataType: DataType = BooleanType

  override def nullable: Boolean = false

  override def eval(input: InternalRow): Any = {
    val childValue = child.eval(input)
    !(childValue == null)
  }
}
// String operations
/**
 * Concatenates the string values of all children; any null child makes the result null.
 */
case class Concat(children: Seq[Expression]) extends Expression {
  override def dataType: DataType = StringType

  override def nullable: Boolean = children.exists(_.nullable)

  /** Evaluates every child first, then either returns null or joins the UTF8String parts. */
  override def eval(input: InternalRow): Any = {
    val parts = children.map(_.eval(input))
    val hasNull = parts.exists(_ == null)
    if (hasNull) {
      null
    } else {
      val strings = parts.map(_.asInstanceOf[UTF8String])
      UTF8String.concat(strings: _*)
    }
  }
}
/**
 * Number of characters in the child's string value; null for a null input.
 */
case class Length(child: Expression) extends UnaryExpression {
  override def dataType: DataType = IntegerType

  override def nullable: Boolean = child.nullable

  /** Returns null for a null child value, otherwise `numChars()` of the UTF8String. */
  override def eval(input: InternalRow): Any = child.eval(input) match {
    case null => null
    case text => text.asInstanceOf[UTF8String].numChars()
  }
}
/**
* Upper-casing stub: produces a StringType result whose nullability follows the child.
* No eval is shown in this listing.
*/
case class Upper(child: Expression) extends UnaryExpression {
override def dataType: DataType = StringType
override def nullable: Boolean = child.nullable
}
/**
* Lower-casing stub: produces a StringType result whose nullability follows the child.
* No eval is shown in this listing.
*/
case class Lower(child: Expression) extends UnaryExpression {
override def dataType: DataType = StringType
override def nullable: Boolean = child.nullable
}Usage Examples:
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
// Logical expressions
// Inputs are BoundReferences (row ordinals) rather than AttributeReferences, because
// AttributeReference.eval() throws UnsupportedOperationException.
val active = BoundReference(0, BooleanType, nullable = false)
val verified = BoundReference(1, BooleanType, nullable = false)
val bothTrue = And(active, verified)
val eitherTrue = Or(active, verified)
val notActive = Not(active)
// Evaluate logical expressions
val row = InternalRow(true, false) // active = true, verified = false
val andResult = bothTrue.eval(row) // false (true AND false)
val orResult = eitherTrue.eval(row) // true (true OR false)
val notResult = notActive.eval(row) // false (NOT true)
// Null checking
val name = BoundReference(0, StringType, nullable = true)
val nameIsNull = IsNull(name)
val nameIsNotNull = IsNotNull(name)
val nullRow = InternalRow(UTF8String.fromString(null))
val nullCheck = nameIsNull.eval(nullRow) // true
val notNullCheck = nameIsNotNull.eval(nullRow) // false
// String operations
val firstName = BoundReference(0, StringType, nullable = false)
val lastName = BoundReference(1, StringType, nullable = false)
val fullName = Concat(Seq(firstName, Literal(" "), lastName))
val nameLength = Length(firstName)
val upperName = Upper(firstName)
val lowerName = Lower(firstName)
val nameRow = InternalRow(UTF8String.fromString("John"), UTF8String.fromString("Doe"))
val concatResult = fullName.eval(nameRow) // "John Doe"
val lengthResult = nameLength.eval(nameRow) // 4
val upperResult = upperName.eval(nameRow) // "JOHN"
val lowerResult = lowerName.eval(nameRow) // "john"

Complex expression capabilities including user-defined functions and sort ordering.
/**
* User-defined function wrapper.
* `inputTypes` defaults to Nil and `udfName` to None. The expression is always reported as
* nullable.
*/
case class ScalaUDF(
function: AnyRef,
dataType: DataType,
children: Seq[Expression],
inputTypes: Seq[DataType] = Nil,
udfName: Option[String] = None) extends Expression {
override def nullable: Boolean = true
/**
* Evaluates every child against the input row.
* NOTE(review): the call into `function` is only indicated by a placeholder comment; the
* invocation and result conversion are not shown in this listing.
*/
override def eval(input: InternalRow): Any = {
val evaluatedChildren = children.map(_.eval(input))
// Invoke user function with evaluated arguments
}
}
/**
* Sort ordering specification for ORDER BY and sorting operations.
* When `nullOrdering` is not given it is derived from `direction` via `NullOrdering(direction)`.
* Mixes in Unevaluable, so it is not meant to be evaluated against a row.
*/
case class SortOrder(
child: Expression,
direction: SortDirection,
nullOrdering: NullOrdering = NullOrdering(direction)) extends Expression with Unevaluable {
/** Same data type as the expression being sorted. */
override def dataType: DataType = child.dataType
/** Same nullability as the expression being sorted. */
override def nullable: Boolean = child.nullable
/** The sorted expression is the only child. */
override val children: Seq[Expression] = Seq(child)
}
/** Direction of a sort: either Ascending or Descending. */
sealed abstract class SortDirection

/** Smallest values first. */
case object Ascending extends SortDirection

/** Largest values first. */
case object Descending extends SortDirection

/**
 * Placement of nulls implied by a sort direction: nulls come first for Descending and last
 * for Ascending.
 */
case class NullOrdering(direction: SortDirection) {
  /** True exactly when the direction is Descending. */
  val nullsFirst: Boolean = Descending == direction

  /** True exactly when the direction is Ascending. */
  val nullsLast: Boolean = Ascending == direction
}
/**
* Attribute set for efficient attribute operations.
* The constructor is private and the set is backed by a Set[ExprId]; instances are obtained
* through the companion object. Only signatures are shown in this listing.
*/
class AttributeSet private (private val baseSet: Set[ExprId]) {
/** Check if an attribute is contained in this set */
def contains(a: Attribute): Boolean
/** Add an attribute to this set, returning the resulting AttributeSet */
def +(a: Attribute): AttributeSet
/** Remove an attribute from this set, returning the resulting AttributeSet */
def -(a: Attribute): AttributeSet
/** Union with another AttributeSet */
def ++(other: AttributeSet): AttributeSet
/** Difference with another AttributeSet */
def --(other: AttributeSet): AttributeSet
/** Intersection with another AttributeSet */
def intersect(other: AttributeSet): AttributeSet
/** Check if this set is a subset of another */
def subsetOf(other: AttributeSet): Boolean
}
// Factory methods for AttributeSet (its constructor is private). Signatures only.
object AttributeSet {
/** Create AttributeSet from sequence of attributes */
def apply(attrs: Seq[Attribute]): AttributeSet
/** Create AttributeSet from individual attributes (varargs) */
def apply(attrs: Attribute*): AttributeSet
/** Empty AttributeSet */
val empty: AttributeSet
}Usage Examples:
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
// User-defined function
val upperFunc = (s: String) => if (s == null) null else s.toUpperCase
val name = AttributeReference("name", StringType, nullable = true)()
val upperUDF = ScalaUDF(upperFunc, StringType, Seq(name), Seq(StringType), Some("upper"))
// Sort ordering
val id = AttributeReference("id", IntegerType, nullable = false)()
val score = AttributeReference("score", DoubleType, nullable = true)()
val idAsc = SortOrder(id, Ascending) // id ASC (nulls last)
val scoreDesc = SortOrder(score, Descending) // score DESC (nulls first)
// Custom null ordering
val scoreDescNullsLast = SortOrder(score, Descending, NullOrdering(Ascending))
// Attribute sets
val attr1 = AttributeReference("a", IntegerType, false)()
val attr2 = AttributeReference("b", StringType, true)()
val attr3 = AttributeReference("c", DoubleType, false)()
val set1 = AttributeSet(attr1, attr2)
val set2 = AttributeSet(attr2, attr3)
val contains = set1.contains(attr1) // true
val union = set1 ++ set2 // {attr1, attr2, attr3}
val intersection = set1.intersect(set2) // {attr2}
val difference = set1 -- set2 // {attr1}
val subset = set1.subsetOf(union) // true
// Complex expression with multiple operations
val complexExpr = And(
GreaterThan(score, Literal(80.0)),
IsNotNull(name)
)
// Expression referencing multiple attributes
val referencedAttrs = complexExpr.references // AttributeSet containing score and name