Tree Operations

Foundation framework for all Catalyst data structures, providing uniform tree traversal, transformation, and manipulation capabilities for query plans and expressions.

Capabilities

TreeNode Base Class

The TreeNode abstract class provides the foundation for all tree-based data structures in Catalyst.

/**
 * Base class for all tree node types in Catalyst.
 * Provides tree traversal and transformation capabilities.
 */
abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
  self: BaseType =>
  
  /** The origin information for this tree node */
  val origin: Origin
  
  /**
   * Returns a Seq of the children of this node.
   * Children should not change. Immutability is required for the containsChild optimization.
   */
  def children: Seq[BaseType]
  
  /** Set of children for efficient containment checks */
  lazy val containsChild: Set[TreeNode[_]]
  
  /**
   * Faster version of equality that short-circuits when two TreeNodes are the same instance.
   * We don't override Object.equals, as doing so prevents the Scala compiler from generating
   * case class equals methods.
   */
  def fastEquals(other: TreeNode[_]): Boolean
}

Usage Examples:

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.trees._
import org.apache.spark.sql.types._

// Example with expressions (which extend TreeNode)
val left = Literal(1, IntegerType)
val right = Literal(2, IntegerType)
val add = Add(left, right)

// Access children
val children = add.children  // Seq(left, right)

// Fast equality check
val same = add.fastEquals(add)  // true
val different = add.fastEquals(left)  // false

// Check if contains child
val contains = add.containsChild.contains(left)  // true
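
children is the primitive that every traversal and transformation method builds on. As a quick sketch, a manual node count written directly against it (countNodes is a hypothetical helper, not part of the API; the counts refer to the example tree above):

// Hypothetical helper: recursively count this node plus all descendants
def countNodes(e: Expression): Int = 1 + e.children.map(countNodes).sum
val size = countNodes(add)  // 3: Add, Literal(1), Literal(2)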

Origin Tracking

Origin provides source location information for tree nodes, useful for error reporting and debugging.

/**
 * Tracks source location information for tree nodes
 */
case class Origin(
  line: Option[Int] = None,
  startPosition: Option[Int] = None
)

/**
 * Thread-local context for tracking current parsing origin
 */
object CurrentOrigin {
  /** Get current origin */
  def get: Origin
  
  /** Set current origin */
  def set(o: Origin): Unit
  
  /** Reset to default origin */
  def reset(): Unit
  
  /** Set line and position information */
  def setPosition(line: Int, start: Int): Unit
  
  /**
   * Execute function with temporary origin context
   */
  def withOrigin[A](o: Origin)(f: => A): A
}

Usage Examples:

import org.apache.spark.sql.catalyst.trees._

// Create origin with location information
val origin = Origin(line = Some(42), startPosition = Some(10))

// Use origin context
val result = CurrentOrigin.withOrigin(origin) {
  // Code executed with this origin context
  // Any tree nodes created here will have this origin
  someComplexOperation()
}

// Set global origin for subsequent operations
CurrentOrigin.setPosition(100, 5)
val currentOrigin = CurrentOrigin.get
// Origin(Some(100), Some(5))

// Reset to default
CurrentOrigin.reset()
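
Because a TreeNode captures the current origin at construction time, any node built inside withOrigin carries that location automatically. A minimal sketch:

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._

// The literal is constructed inside the origin context, so it records it
val lit = CurrentOrigin.withOrigin(Origin(Some(7), Some(3))) {
  Literal(1, IntegerType)
}
lit.origin  // Origin(Some(7), Some(3))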

Tree Traversal Methods

TreeNode provides various methods for traversing and searching tree structures.

/**
 * Find the first TreeNode that satisfies the condition specified by `f`.
 * The condition is recursively applied to this node and all of its children (pre-order).
 */
def find(f: BaseType => Boolean): Option[BaseType]

/**
 * Runs the given function on this node and then recursively on children (pre-order).
 * @param f the function to be applied to each node in the tree.
 */
def foreach(f: BaseType => Unit): Unit

/**
 * Runs the given function recursively on children then on this node (post-order).
 * @param f the function to be applied to each node in the tree.
 */
def foreachUp(f: BaseType => Unit): Unit

/**
 * Returns a Seq containing the result of applying the given function to each node
 * in this tree in a preorder traversal.
 * @param f the function to be applied.
 */
def map[A](f: BaseType => A): Seq[A]

/**
 * Returns a Seq by applying a function to all nodes in this tree and using the elements of the
 * resulting collections.
 */
def flatMap[A](f: BaseType => TraversableOnce[A]): Seq[A]

/**
 * Returns a Seq containing the results of applying a partial function to all nodes in this
 * tree on which the function is defined (pre-order).
 */
def collect[B](pf: PartialFunction[BaseType, B]): Seq[B]

/**
 * Finds the first node on which the given partial function is defined (pre-order) and
 * applies the function to it.
 */
def collectFirst[B](pf: PartialFunction[BaseType, B]): Option[B]

Usage Examples:

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._

// Create a complex expression tree
val a = AttributeReference("a", IntegerType, false)()
val b = AttributeReference("b", IntegerType, false)()
val const = Literal(10, IntegerType)
val expr = Add(Multiply(a, b), const)

// Find first literal in the tree
val firstLiteral = expr.find {
  case _: Literal => true
  case _ => false
}
// Some(Literal(10, IntegerType))

// Collect all attribute references
val attributes = expr.collect {
  case attr: AttributeReference => attr.name
}
// Seq("a", "b")

// Apply function to each node
expr.foreach { node =>
  println(s"Node: ${node.getClass.getSimpleName}")
}

// Note: map returns a Seq of per-node results (pre-order), not a new tree;
// to rewrite the tree itself, use transform (next section)
val nodeNames = expr.map(_.getClass.getSimpleName)
// Seq("Add", "Multiply", "AttributeReference", "AttributeReference", "Literal")
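
The difference between foreach and foreachUp is only the visitation order, which matters whenever handling a node depends on its children having been visited first. A short sketch on the same tree:

// Pre-order: parent before children
expr.foreach(n => print(n.getClass.getSimpleName + " "))
// Add Multiply AttributeReference AttributeReference Literal

// Post-order: children before parent
expr.foreachUp(n => print(n.getClass.getSimpleName + " "))
// AttributeReference AttributeReference Multiply Literal Add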

Tree Transformation Methods

Transformation methods produce a modified copy of the tree; nodes are immutable and never changed in place, and unchanged subtrees are reused in the copy.

/**
 * Returns a copy of this node where `rule` has been recursively applied to the tree.
 * When `rule` does not apply to a given node it is left unchanged.
 * Users should not expect a specific directionality. If a specific directionality is needed,
 * transformDown or transformUp should be used.
 */
def transform(rule: PartialFunction[BaseType, BaseType]): BaseType

/**
 * Returns a copy of this node where `rule` has been recursively applied to itself and then to
 * all of its children (pre-order). When `rule` does not apply to a given node, it is left
 * unchanged.
 */
def transformDown(rule: PartialFunction[BaseType, BaseType]): BaseType

/**
 * Returns a copy of this node where `rule` has been recursively applied first to all of its
 * children and then to itself (post-order). When `rule` does not apply to a given node, it is left
 * unchanged.
 */
def transformUp(rule: PartialFunction[BaseType, BaseType]): BaseType

Usage Examples:

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._

val a = AttributeReference("a", IntegerType, false)()
val b = AttributeReference("b", IntegerType, false)()
val expr = Add(a, Multiply(b, Literal(2, IntegerType)))

// Transform all literals by doubling their values
val transformed = expr.transformDown {
  case Literal(value: Int, dataType) => 
    Literal(value * 2, dataType)
}

// Transform all attribute references to uppercase
val upperCased = expr.transformUp {
  case AttributeReference(name, dataType, nullable, metadata) =>
    // note the trailing (): AttributeReference takes a second (curried)
    // parameter list, which here assigns a fresh ExprId
    AttributeReference(name.toUpperCase, dataType, nullable, metadata)()
}

// Conditional transformation
val optimized = expr.transform {
  case Multiply(child, Literal(1, _)) => child  // x * 1 => x
  case Multiply(Literal(1, _), child) => child  // 1 * x => x
  case Add(child, Literal(0, _)) => child       // x + 0 => x
  case Add(Literal(0, _), child) => child       // 0 + x => x
}
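
Directionality matters whenever one rewrite enables another. In the sketch below, transformUp folds the nested addition completely because children are rewritten before their parent sees them, while transformDown rewrites the parent first and never revisits it:

val nested = Add(Add(Literal(1, IntegerType), Literal(2, IntegerType)), Literal(3, IntegerType))

val fold: PartialFunction[Expression, Expression] = {
  case Add(Literal(x: Int, _), Literal(y: Int, _)) => Literal(x + y, IntegerType)
}

nested.transformUp(fold)    // Literal(6): the inner Add folds first, then the outer matches
nested.transformDown(fold)  // Add(Literal(3), Literal(3)): the outer was checked before its children folded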

Advanced Tree Operations

Additional methods for introspecting nodes and rendering trees as strings.

/**
 * Iterator over the constructor arguments of this node, inherited from Scala's Product
 * trait. TreeNode uses it to build copies of a node and to compute argString.
 */
def productIterator: Iterator[Any]

/**
 * Returns a string representing the arguments to this node, minus any children
 */
def argString: String

/**
 * ONE line description of this node.
 */
def simpleString: String

/**
 * All the nodes that should be shown as inner nested trees when this node is printed.
 */
def innerChildren: Seq[TreeNode[_]]

/**
 * Appends the string representation of this node and its children to the given StringBuilder.
 */
def generateTreeString(
    depth: Int,
    lastChildren: Seq[Boolean],
    builder: StringBuilder,
    verbose: Boolean,
    prefix: String = "",
    addSuffix: Boolean = false): StringBuilder

Usage Examples:

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._

val expr = Add(
  AttributeReference("x", IntegerType, false)(),
  Literal(42, IntegerType)
)

// Get simple string representation
val simple = expr.simpleString
// "add(x#0, 42)"

// Get argument string (without children)
val args = expr.argString
// ""

// Generate tree string for visualization
val treeString = expr.generateTreeString(0, Seq.empty, new StringBuilder(), false)
println(treeString.toString())
// Output shows tree structure with proper indentation

// Access product iterator for reflection
val constructorArgs = expr.productIterator.toSeq
// Seq(AttributeReference("x", IntegerType, false), Literal(42, IntegerType))
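
In day-to-day use you rarely call generateTreeString directly; TreeNode also exposes treeString and numberedTreeString, which build on it:

println(expr.treeString)          // indented tree rendering
println(expr.numberedTreeString)  // same rendering with node numbers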

Tree Node Pattern Matching

Because Catalyst nodes are Scala case classes, standard pattern matching works directly on tree structure.

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._

// Pattern match on tree structure
def optimizeExpression(expr: Expression): Expression = expr match {
  case Add(left, Literal(0, _)) => left  // x + 0 => x
  case Add(Literal(0, _), right) => right // 0 + x => x
  case Multiply(_, Literal(0, _)) => Literal(0, expr.dataType) // x * 0 => 0
  case Multiply(Literal(0, _), _) => Literal(0, expr.dataType) // 0 * x => 0
  case Multiply(left, Literal(1, _)) => left // x * 1 => x
  case Multiply(Literal(1, _), right) => right // 1 * x => x
  case other => other
}
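
Applied to sample expressions (a small sketch; the attribute name and values are arbitrary):

val x = AttributeReference("x", IntegerType, false)()
optimizeExpression(Multiply(x, Literal(1, IntegerType)))  // returns x: the Multiply is eliminated
optimizeExpression(Add(x, Literal(0, IntegerType)))       // returns x: the Add is eliminated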

// Recursive pattern matching with transformation; transformUp folds the
// leaves first, so nested constant sub-trees collapse completely
def constantFold(expr: Expression): Expression = expr.transformUp {
  case Add(Literal(a: Int, _), Literal(b: Int, _)) => 
    Literal(a + b, IntegerType)
  case Multiply(Literal(a: Int, _), Literal(b: Int, _)) => 
    Literal(a * b, IntegerType)
}
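
Because the rule runs bottom-up, nested constant sub-trees collapse all the way to a single literal:

val folded = constantFold(
  Add(Literal(1, IntegerType), Multiply(Literal(2, IntegerType), Literal(3, IntegerType))))
// Literal(7, IntegerType): Multiply(2, 3) folds to 6 first, then Add(1, 6) matches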