0
# Tree Operations
1
2
Foundation framework for all Catalyst data structures, providing uniform tree traversal, transformation, and manipulation capabilities for query plans and expressions.
3
4
## Capabilities
5
6
### TreeNode Base Class
7
8
The TreeNode abstract class provides the foundation for all tree-based data structures in Catalyst.
9
10
```scala { .api }
11
/**
12
* Base class for all tree node types in Catalyst.
13
* Provides tree traversal and transformation capabilities.
14
*/
15
abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
16
self: BaseType =>
17
18
/** The origin information for this tree node */
19
val origin: Origin
20
21
/**
22
* Returns a Seq of the children of this node.
23
* Children should not change. Immutability is required for the containsChild optimization.
24
*/
25
def children: Seq[BaseType]
26
27
/** Set of children for efficient containment checks */
28
lazy val containsChild: Set[TreeNode[_]] = children.toSet
29
30
/**
31
* Faster version of equality which short-circuits when two treeNodes are the same instance.
32
* We don't override Object.equals as doing so prevents scala compiler from generating case class equals methods
33
*/
34
def fastEquals(other: TreeNode[_]): Boolean
35
}
36
```
37
38
**Usage Examples:**
39
40
```scala
41
import org.apache.spark.sql.catalyst.expressions._
42
import org.apache.spark.sql.catalyst.trees._
43
import org.apache.spark.sql.types._
44
45
// Example with expressions (which extend TreeNode)
46
val left = Literal(1, IntegerType)
47
val right = Literal(2, IntegerType)
48
val add = Add(left, right)
49
50
// Access children
51
val children = add.children // Seq(left, right)
52
53
// Fast equality check
54
val same = add.fastEquals(add) // true
55
val different = add.fastEquals(left) // false
56
57
// Check if contains child
58
val contains = add.containsChild.contains(left) // true
59
```
60
61
### Origin Tracking
62
63
Origin provides source location information for tree nodes, useful for error reporting and debugging.
64
65
```scala { .api }
66
/**
67
* Tracks source location information for tree nodes
68
*/
69
case class Origin(
70
line: Option[Int] = None,
71
startPosition: Option[Int] = None
72
)
73
74
/**
75
* Thread-local context for tracking current parsing origin
76
*/
77
object CurrentOrigin {
78
/** Get current origin */
79
def get: Origin
80
81
/** Set current origin */
82
def set(o: Origin): Unit
83
84
/** Reset to default origin */
85
def reset(): Unit
86
87
/** Set line and position information */
88
def setPosition(line: Int, start: Int): Unit
89
90
/**
91
* Execute function with temporary origin context
92
*/
93
def withOrigin[A](o: Origin)(f: => A): A
94
}
95
```
96
97
**Usage Examples:**
98
99
```scala
100
import org.apache.spark.sql.catalyst.trees._
101
102
// Create origin with location information
103
val origin = Origin(line = Some(42), startPosition = Some(10))
104
105
// Use origin context
106
val result = CurrentOrigin.withOrigin(origin) {
107
// Code executed with this origin context
108
// Any tree nodes created here will have this origin
109
someComplexOperation()
110
}
111
112
// Set global origin for subsequent operations
113
CurrentOrigin.setPosition(100, 5)
114
val currentOrigin = CurrentOrigin.get
115
// Origin(Some(100), Some(5))
116
117
// Reset to default
118
CurrentOrigin.reset()
119
```
120
121
### Tree Traversal Methods
122
123
TreeNode provides various methods for traversing and searching tree structures.
124
125
```scala { .api }
126
/**
127
* Find the first TreeNode that satisfies the condition specified by `f`.
128
* The condition is recursively applied to this node and all of its children (pre-order).
129
*/
130
def find(f: BaseType => Boolean): Option[BaseType]
131
132
/**
133
* Runs the given function on this node and then recursively on children (pre-order).
134
* @param f the function to be applied to each node in the tree.
135
*/
136
def foreach(f: BaseType => Unit): Unit
137
138
/**
139
* Runs the given function recursively on children then on this node (post-order).
140
* @param f the function to be applied to each node in the tree.
141
*/
142
def foreachUp(f: BaseType => Unit): Unit
143
144
/**
145
* Returns a Seq containing the result of applying the given function to each node
146
* in this tree in a preorder traversal.
147
* @param f the function to be applied.
148
*/
149
def map[A](f: BaseType => A): Seq[A]
150
151
/**
152
* Returns a Seq by applying a function to all nodes in this tree and using the elements of the
153
* resulting collections.
154
*/
155
def flatMap[A](f: BaseType => TraversableOnce[A]): Seq[A]
156
157
/**
158
* Returns a Seq containing the result of applying the given partial function to every node
* in this tree on which the function is defined (pre-order traversal).
159
*/
160
def collect[B](pf: PartialFunction[BaseType, B]): Seq[B]
161
162
/**
163
* Finds the first node on which the given partial function is defined (pre-order) and
* applies the function to it, returning the result as an Option.
164
*/
165
def collectFirst[B](pf: PartialFunction[BaseType, B]): Option[B]
166
```
167
168
**Usage Examples:**
169
170
```scala
171
import org.apache.spark.sql.catalyst.expressions._
172
import org.apache.spark.sql.types._
173
174
// Create a complex expression tree
175
val a = AttributeReference("a", IntegerType, false)()
176
val b = AttributeReference("b", IntegerType, false)()
177
val const = Literal(10, IntegerType)
178
val expr = Add(Multiply(a, b), const)
179
180
// Find first literal in the tree
181
val firstLiteral = expr.find {
182
case _: Literal => true
183
case _ => false
184
}
185
// Some(Literal(10, IntegerType))
186
187
// Collect all attribute references
188
val attributes = expr.collect {
189
case attr: AttributeReference => attr.name
190
}
191
// Seq("a", "b")
192
193
// Apply function to each node
194
expr.foreach { node =>
195
println(s"Node: ${node.getClass.getSimpleName}")
196
}
197
198
// Map over every node (note: map returns a Seq of results, it does not rebuild the tree)
199
val doubled = expr.map {
200
case lit: Literal => Literal(lit.value.asInstanceOf[Int] * 2, lit.dataType)
201
case other => other
202
}
203
```
204
205
### Tree Transformation Methods
206
207
Powerful transformation methods for modifying tree structures while preserving type safety.
208
209
```scala { .api }
210
/**
211
* Returns a copy of this node where `rule` has been recursively applied to the tree.
212
* When `rule` does not apply to a given node it is left unchanged.
213
* Users should not expect a specific directionality. If a specific directionality is needed,
214
* transformDown or transformUp should be used.
215
*/
216
def transform(rule: PartialFunction[BaseType, BaseType]): BaseType
217
218
/**
219
* Returns a copy of this node where `rule` has been recursively applied first to itself and
220
* then to all of its children (pre-order). When `rule` does not apply to a given node, it is left
221
* unchanged.
222
*/
223
def transformDown(rule: PartialFunction[BaseType, BaseType]): BaseType
224
225
/**
226
* Returns a copy of this node where `rule` has been recursively applied first to itself and then
227
* to all of its children (pre-order). When `rule` does not apply to a given node, it is left
228
* unchanged.
229
*/
230
def transformUp(rule: PartialFunction[BaseType, BaseType]): BaseType
231
```
232
233
**Usage Examples:**
234
235
```scala
236
import org.apache.spark.sql.catalyst.expressions._
237
import org.apache.spark.sql.types._
238
239
val a = AttributeReference("a", IntegerType, false)()
240
val b = AttributeReference("b", IntegerType, false)()
241
val expr = Add(a, Multiply(b, Literal(2, IntegerType)))
242
243
// Transform all literals by doubling their values
244
val transformed = expr.transformDown {
245
case Literal(value: Int, dataType) =>
246
Literal(value * 2, dataType)
247
}
248
249
// Transform all attribute references to uppercase
250
val upperCased = expr.transformUp {
251
case AttributeReference(name, dataType, nullable, metadata) =>
252
AttributeReference(name.toUpperCase, dataType, nullable, metadata)
253
}
254
255
// Conditional transformation
256
val optimized = expr.transform {
257
case Multiply(child, Literal(1, _)) => child // x * 1 => x
258
case Multiply(Literal(1, _), child) => child // 1 * x => x
259
case Add(child, Literal(0, _)) => child // x + 0 => x
260
case Add(Literal(0, _), child) => child // 0 + x => x
261
}
262
```
263
264
### Advanced Tree Operations
265
266
Additional methods for advanced tree manipulation and analysis.
267
268
```scala { .api }
269
/**
270
* Iterator over the constructor arguments of this node, inherited from Product.
272
* Case-class subclasses provide this automatically; it yields the args used to construct them.
272
*/
273
def productIterator: Iterator[Any]
274
275
/**
276
* Returns a string representing the arguments to this node, minus any children
277
*/
278
def argString: String
279
280
/**
281
* ONE line description of this node.
282
*/
283
def simpleString: String
284
285
/**
286
* ALL the nodes that should be shown as a result of printing this node.
287
* All nodes in this seq will be shown in the tree format.
288
*/
289
def innerChildren: Seq[TreeNode[_]]
290
291
/**
292
* Appends the string representation of this node and its children to the given StringBuilder.
293
*/
294
def generateTreeString(
295
depth: Int,
296
lastChildren: Seq[Boolean],
297
builder: StringBuilder,
298
verbose: Boolean,
299
prefix: String = "",
300
addSuffix: Boolean = false): StringBuilder
301
```
302
303
**Usage Examples:**
304
305
```scala
306
import org.apache.spark.sql.catalyst.expressions._
307
import org.apache.spark.sql.types._
308
309
val expr = Add(
310
AttributeReference("x", IntegerType, false)(),
311
Literal(42, IntegerType)
312
)
313
314
// Get simple string representation
315
val simple = expr.simpleString
316
// "add(x#0, 42)"
317
318
// Get argument string (without children)
319
val args = expr.argString
320
// ""
321
322
// Generate tree string for visualization
323
val treeString = expr.generateTreeString(0, Seq.empty, new StringBuilder(), false)
324
println(treeString.toString())
325
// Output shows tree structure with proper indentation
326
327
// Access product iterator for reflection
328
val constructorArgs = expr.productIterator.toSeq
329
// Seq(AttributeReference("x", IntegerType, false), Literal(42, IntegerType))
330
```
331
332
### Tree Node Pattern Matching
333
334
TreeNode supports pattern matching for elegant tree processing.
335
336
```scala
337
import org.apache.spark.sql.catalyst.expressions._
338
339
// Pattern match on tree structure
340
def optimizeExpression(expr: Expression): Expression = expr match {
341
case Add(left, Literal(0, _)) => left // x + 0 => x
342
case Add(Literal(0, _), right) => right // 0 + x => x
343
case Multiply(_, Literal(0, _)) => Literal(0, expr.dataType) // x * 0 => 0
344
case Multiply(Literal(0, _), _) => Literal(0, expr.dataType) // 0 * x => 0
345
case Multiply(left, Literal(1, _)) => left // x * 1 => x
346
case Multiply(Literal(1, _), right) => right // 1 * x => x
347
case other => other
348
}
349
350
// Recursive pattern matching with transformation
351
def constantFold(expr: Expression): Expression = expr.transformDown {
352
case Add(Literal(a: Int, _), Literal(b: Int, _)) =>
353
Literal(a + b, IntegerType)
354
case Multiply(Literal(a: Int, _), Literal(b: Int, _)) =>
355
Literal(a * b, IntegerType)
356
}
357
```