Foundation framework for all Catalyst data structures, providing uniform tree traversal, transformation, and manipulation capabilities for query plans and expressions.
The TreeNode abstract class is the base of all tree-based data structures in Catalyst.
/**
* Base class for all tree node types in Catalyst.
* Provides tree traversal and transformation capabilities.
*/
abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
self: BaseType =>
/** The origin information for this tree node */
val origin: Origin
/**
* Returns a Seq of the children of this node.
* Children should not change; immutability is required for the containsChild optimization.
*/
def children: Seq[BaseType]
/** Set of children for efficient containment checks */
lazy val containsChild: Set[TreeNode[_]] = children.toSet
/**
* Faster version of equality which short-circuits when two treeNodes are the same instance.
* We don't override Object.equals, as doing so prevents the Scala compiler from generating
* case class equals methods.
*/
def fastEquals(other: TreeNode[_]): Boolean
}

Usage Examples:
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.trees._
import org.apache.spark.sql.types._
// Example with expressions (which extend TreeNode)
val left = Literal(1, IntegerType)
val right = Literal(2, IntegerType)
val add = Add(left, right)
// Access children
val children = add.children // Seq(left, right)
// Fast equality check
val same = add.fastEquals(add) // true
val different = add.fastEquals(left) // false
// Check if contains child
val contains = add.containsChild.contains(left) // true

Origin provides source location information for tree nodes, useful for error reporting and debugging.
/**
* Tracks source location information for tree nodes
*/
case class Origin(
line: Option[Int] = None,
startPosition: Option[Int] = None
)
/**
* Thread-local context for tracking current parsing origin
*/
object CurrentOrigin {
/** Get current origin */
def get: Origin
/** Set current origin */
def set(o: Origin): Unit
/** Reset to default origin */
def reset(): Unit
/** Set line and position information */
def setPosition(line: Int, start: Int): Unit
/**
* Execute function with temporary origin context
*/
def withOrigin[A](o: Origin)(f: => A): A
}

Usage Examples:
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.trees._
import org.apache.spark.sql.types._
// Create origin with location information
val origin = Origin(line = Some(42), startPosition = Some(10))
// Use origin context: any tree nodes created inside the block
// capture this origin
val result = CurrentOrigin.withOrigin(origin) {
  Literal(1, IntegerType)
}
// result.origin == Origin(Some(42), Some(10))
// Set global origin for subsequent operations
CurrentOrigin.setPosition(100, 5)
val currentOrigin = CurrentOrigin.get
// Origin(Some(100), Some(5))
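// A hypothetical helper (not part of Spark) sketching how Origin
// feeds error reporting:
def describe(o: Origin): String = (o.line, o.startPosition) match {
  case (Some(l), Some(p)) => s"line $l, position $p"
  case _ => "unknown location"
}
// describe(CurrentOrigin.get) == "line 100, position 5"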
// Reset to default
CurrentOrigin.reset()

TreeNode provides various methods for traversing and searching tree structures.
/**
* Find the first TreeNode that satisfies the condition specified by `f`.
* The condition is recursively applied to this node and all of its children (pre-order).
*/
def find(f: BaseType => Boolean): Option[BaseType]
/**
* Runs the given function on this node and then recursively on children (pre-order).
* @param f the function to be applied to each node in the tree.
*/
def foreach(f: BaseType => Unit): Unit
/**
* Runs the given function recursively on children then on this node (post-order).
* @param f the function to be applied to each node in the tree.
*/
def foreachUp(f: BaseType => Unit): Unit
/**
* Returns a Seq containing the result of applying the given function to each node
* in this tree in a preorder traversal.
* @param f the function to be applied.
*/
def map[A](f: BaseType => A): Seq[A]
/**
* Returns a Seq by applying a function to all nodes in this tree and using the elements of the
* resulting collections.
*/
def flatMap[A](f: BaseType => TraversableOnce[A]): Seq[A]
/**
 * Returns a Seq containing the result of applying the given partial function
 * to every node in this tree on which it is defined (pre-order).
 */
def collect[B](pf: PartialFunction[BaseType, B]): Seq[B]
/**
 * Finds the first TreeNode (pre-order) on which the given partial function
 * is defined, and returns the result of applying it.
 */
def collectFirst[B](pf: PartialFunction[BaseType, B]): Option[B]

Usage Examples:
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
// Create a complex expression tree
val a = AttributeReference("a", IntegerType, false)()
val b = AttributeReference("b", IntegerType, false)()
val const = Literal(10, IntegerType)
val expr = Add(Multiply(a, b), const)
// Find first literal in the tree
val firstLiteral = expr.find {
case _: Literal => true
case _ => false
}
// Some(Literal(10, IntegerType))
// Collect all attribute references
val attributes = expr.collect {
case attr: AttributeReference => attr.name
}
// Seq("a", "b")
// Apply function to each node
expr.foreach { node =>
println(s"Node: ${node.getClass.getSimpleName}")
}
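// foreachUp visits children before their parents (post-order):
expr.foreachUp { node =>
  println(node.getClass.getSimpleName)
}
// AttributeReference, AttributeReference, Multiply, Literal, Add
// flatMap and collectFirst traverse in the same pre-order as map/find:
val attrNames = expr.flatMap {
  case attr: AttributeReference => Seq(attr.name)
  case _ => Nil
}
// Seq("a", "b")
val firstAttr = expr.collectFirst {
  case attr: AttributeReference => attr.name
}
// Some("a")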
// Map over every node: returns a Seq of results in pre-order;
// it does NOT rebuild the tree (use transform for that)
val mapped = expr.map {
  case lit: Literal => Literal(lit.value.asInstanceOf[Int] * 2, lit.dataType)
  case other => other
}
// Seq[Expression] with one element per node (5 here)

Powerful transformation methods for modifying tree structures while preserving type safety.
/**
* Returns a copy of this node where `rule` has been recursively applied to the tree.
* When `rule` does not apply to a given node it is left unchanged.
* Users should not expect a specific directionality. If a specific directionality is needed,
* transformDown or transformUp should be used.
*/
def transform(rule: PartialFunction[BaseType, BaseType]): BaseType
/**
 * Returns a copy of this node where `rule` has been recursively applied first to itself and then
 * to all of its children (pre-order). When `rule` does not apply to a given node, it is left
 * unchanged.
 */
def transformDown(rule: PartialFunction[BaseType, BaseType]): BaseType
/**
 * Returns a copy of this node where `rule` has been recursively applied first to all of its
 * children and then itself (post-order). When `rule` does not apply to a given node, it is left
 * unchanged.
 */
def transformUp(rule: PartialFunction[BaseType, BaseType]): BaseType

Usage Examples:
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
val a = AttributeReference("a", IntegerType, false)()
val b = AttributeReference("b", IntegerType, false)()
val expr = Add(a, Multiply(b, Literal(2, IntegerType)))
// Transform all literals by doubling their values
val transformed = expr.transformDown {
case Literal(value: Int, dataType) =>
Literal(value * 2, dataType)
}
// Transform all attribute references to uppercase names
val upperCased = expr.transformUp {
  case AttributeReference(name, dataType, nullable, metadata) =>
    AttributeReference(name.toUpperCase, dataType, nullable, metadata)()
}
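// Traversal order matters when a rule can match at several levels;
// a sketch with a nested tree:
val nested = Add(Add(Literal(1, IntegerType), Literal(2, IntegerType)), Literal(3, IntegerType))
val folded = nested.transformUp {
  case Add(Literal(x: Int, _), Literal(y: Int, _)) => Literal(x + y, IntegerType)
}
// Literal(6, IntegerType): post-order rewrites the inner Add to Literal(3)
// first, so the rebuilt outer Add(Literal(3), Literal(3)) also matches.
// transformDown visits the root before its children, so the same rule
// would leave Add(Literal(3), Literal(3)) after one pass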
// Conditional transformation
val optimized = expr.transform {
case Multiply(child, Literal(1, _)) => child // x * 1 => x
case Multiply(Literal(1, _), child) => child // 1 * x => x
case Add(child, Literal(0, _)) => child // x + 0 => x
case Add(Literal(0, _), child) => child // 0 + x => x
}

Additional methods for advanced tree manipulation and analysis.
/**
 * Iterator over the constructor arguments of this node, inherited from scala.Product.
 * Catalyst relies on it to generically copy and compare nodes.
 */
def productIterator: Iterator[Any]
/**
* Returns a string representing the arguments to this node, minus any children
*/
def argString: String
/**
* ONE line description of this node.
*/
def simpleString: String
/**
* ALL the nodes that should be shown as a result of printing this node.
* All nodes in this seq will be shown in the tree format.
*/
def innerChildren: Seq[TreeNode[_]]
/**
* Appends the string representation of this node and its children to the given StringBuilder.
*/
def generateTreeString(
depth: Int,
lastChildren: Seq[Boolean],
builder: StringBuilder,
verbose: Boolean,
prefix: String = "",
addSuffix: Boolean = false): StringBuilder

Usage Examples:
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
val expr = Add(
AttributeReference("x", IntegerType, false)(),
Literal(42, IntegerType)
)
// Get simple string representation
val simple = expr.simpleString
// "add(x#0, 42)"
// Get argument string (without children)
val args = expr.argString
// ""
// Generate tree string for visualization
val treeString = expr.generateTreeString(0, Seq.empty, new StringBuilder(), false)
println(treeString.toString())
// Output shows tree structure with proper indentation
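// innerChildren is empty for ordinary expressions; nodes that embed
// other trees (for example, expressions wrapping query plans) override
// it so the embedded tree appears in the printed output
val inner = expr.innerChildren
// Seq.empty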
// Access product iterator for reflection
val constructorArgs = expr.productIterator.toSeq
// Seq(AttributeReference("x", IntegerType, false), Literal(42, IntegerType))

TreeNode supports pattern matching for elegant tree processing.
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
// Pattern match on tree structure
def optimizeExpression(expr: Expression): Expression = expr match {
case Add(left, Literal(0, _)) => left // x + 0 => x
case Add(Literal(0, _), right) => right // 0 + x => x
case Multiply(_, Literal(0, _)) => Literal(0, expr.dataType) // x * 0 => 0
case Multiply(Literal(0, _), _) => Literal(0, expr.dataType) // 0 * x => 0
case Multiply(left, Literal(1, _)) => left // x * 1 => x
case Multiply(Literal(1, _), right) => right // 1 * x => x
case other => other
}
// Recursive pattern matching with transformation: transformUp folds
// bottom-up, so nested constant sub-expressions fold before their parents
def constantFold(expr: Expression): Expression = expr.transformUp {
case Add(Literal(a: Int, _), Literal(b: Int, _)) =>
Literal(a + b, IntegerType)
case Multiply(Literal(a: Int, _), Literal(b: Int, _)) =>
Literal(a * b, IntegerType)
}
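// Combining the helpers (a sketch): optimizeExpression simplifies the
// top node, constantFold then folds what remains
val input = Multiply(Add(Literal(1, IntegerType), Literal(2, IntegerType)), Literal(1, IntegerType))
val simplified = optimizeExpression(input)
// Add(Literal(1), Literal(2)): the "* 1" wrapper is removed
val foldedResult = constantFold(simplified)
// Literal(3, IntegerType)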