# Optimization

This section covers the Spark Catalyst query optimizer. It uses rule-based and cost-based techniques to transform a logical plan into a more efficient, semantically equivalent plan.

## Core Imports

```scala
// Catalyst packages used throughout this section
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
```

## Optimizer

`Optimizer` is the main optimization engine. It is a `RuleExecutor[LogicalPlan]` that applies batches of rewrite rules to logical plans.

```scala { .api }
/** Catalyst's logical optimizer: a RuleExecutor whose batches rewrite LogicalPlans. */
abstract class Optimizer extends RuleExecutor[LogicalPlan] {
  /** The rule batches to run, in execution order. */
  def batches: Seq[Batch]
}

// NOTE(review): in Catalyst, Batch, Strategy, Once and FixedPoint are declared *inside*
// RuleExecutor[TreeType] (with rules typed Rule[TreeType]); they are shown here as seen
// from Optimizer, where TreeType = LogicalPlan.

/**
 * A named group of rules that are executed together.
 *
 * @param name     label identifying the batch
 * @param strategy bounds how many passes are made over the batch
 * @param rules    rules applied in sequence on every pass
 */
case class Batch(name: String, strategy: Strategy, rules: Rule[LogicalPlan]*)

/** Decides how many passes a batch may run. */
abstract class Strategy {
  /** Upper bound on the number of passes over the batch. */
  def maxIterations: Int
}

/** Run the batch a single time. */
case object Once extends Strategy {
  override def maxIterations: Int = 1
}

/** Re-run the batch until the plan stops changing, for at most `maxIterations` passes. */
case class FixedPoint(maxIterations: Int) extends Strategy
```

### Usage Example

```scala
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types.StringType

// The optimizer works on analyzed plans, so build the plan from resolved attributes.
// (`col` from org.apache.spark.sql.functions returns a DataFrame Column, not the
// Catalyst NamedExpression that Project requires.)
val nameAttr = AttributeReference("name", StringType)()
val relation = LocalRelation(nameAttr)

// Create optimizer instance
val optimizer = new SimpleTestOptimizer()

// Optimize a logical plan: SELECT name FROM relation WHERE true
val logicalPlan = Project(Seq(nameAttr), Filter(Literal(true), relation))
val optimizedPlan = optimizer.execute(logicalPlan)
// Expected: the always-true Filter is pruned (PruneFilters); the exact resulting plan
// depends on the default batches of the Spark version in use.
```

## Core Optimization Rules

### Predicate Pushdown

```scala { .api }
/**
 * Predicate-pushdown rule: rewrites every Filter by handing its condition and input to
 * `pushDownPredicate`, an illustrative helper that is not shown here.
 */
object PushDownPredicate extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan =
    plan transform {
      case Filter(predicate, input) => pushDownPredicate(predicate, input)
    }
}

/** Predicate-pushdown rule for filters sitting above joins (signature only). */
object PushPredicateThroughJoin extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan
}
```

### Projection Pushdown

```scala { .api }
/** Projection-pushdown rule that prunes columns no parent operator needs (signature only). */
object ColumnPruning extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan
}

/** Projection-pushdown rule that moves a projection below a Union (signature only). */
object PushProjectionThroughUnion extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan
}
```

### Join Optimization

```scala { .api }
/** Join rule that changes the order in which a chain of joins is evaluated (signature only). */
object ReorderJoin extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan
}

/** Join rule that replaces outer joins with cheaper join types when safe (signature only). */
object EliminateOuterJoin extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan
}
```

### Constant Folding and Simplification

```scala { .api }
/**
 * Replaces every foldable expression with the Literal it evaluates to.
 *
 * Literals are returned unchanged. They are already folded, and without this case every
 * Literal in the plan would be re-evaluated and re-created on each fixed-point pass.
 *
 * NOTE(review): eager evaluation can raise errors (e.g. ANSI-mode overflow) for expressions
 * that would never run at execution time, such as an untaken CASE WHEN branch. Consider
 * catching NonFatal errors in conditional branches and leaving the expression unfolded.
 */
object ConstantFolding extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions {
    // Already a constant: nothing to fold.
    case lit: Literal => lit
    // Foldable expressions need no input row, so evaluate once at planning time.
    case expr if expr.foldable => Literal.create(expr.eval(EmptyRow), expr.dataType)
  }
}

/** Simplification rule for conditional expressions (signature only). */
object SimplifyConditionals extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan
}

/** Simplification rule for binary comparisons (signature only). */
object SimplifyBinaryComparison extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan
}
```
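
## Rule Executor Framework

### RuleExecutor

`RuleExecutor` runs each batch of rules over a tree until the tree stops changing or the batch's iteration limit is reached. It then passes the result to the next batch.
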
```scala { .api }
/**
 * Runs batches of rules over a tree and returns the rewritten tree.
 *
 * Batches run in order, and the output of one batch is the input to the next. Within a
 * batch, each pass feeds the tree through every rule in sequence. Passes repeat until a pass
 * leaves the tree unchanged (a fixed point) or `strategy.maxIterations` passes have run.
 *
 * NOTE(review): as declared in this document, Batch holds Rule[LogicalPlan]. In Catalyst,
 * Batch is nested inside RuleExecutor and holds Rule[TreeType]. That nesting is what lets
 * `rule(tree)` type-check for any TreeType.
 *
 * @tparam TreeType the kind of tree being rewritten (for the optimizer, LogicalPlan)
 */
abstract class RuleExecutor[TreeType <: TreeNode[TreeType]] {
  /** The rule batches to run, in execution order. */
  def batches: Seq[Batch]

  /** Applies every batch to `plan` in order and returns the final tree. */
  def execute(plan: TreeType): TreeType =
    batches.foldLeft(plan)((input, batch) => runBatch(batch, input))

  /** Runs one batch until its tree stops changing or its iteration budget is spent. */
  private def runBatch(batch: Batch, start: TreeType): TreeType = {
    val maxIterations = batch.strategy.maxIterations

    @scala.annotation.tailrec
    def loop(current: TreeType, iteration: Int): TreeType =
      if (iteration >= maxIterations) current
      else {
        // One pass: thread the tree through every rule of the batch, in order.
        val next = batch.rules.foldLeft(current)((tree, rule) => rule(tree))
        // An unchanged tree means further passes would do nothing: stop early.
        if (next.fastEquals(current)) next else loop(next, iteration + 1)
      }

    loop(start, 0)
  }
}
```

## Optimization Techniques

### Expression Optimization

```scala { .api }
// These rules are singleton objects in Catalyst, as are the rules listed in the earlier
// sections. Being objects, they can be passed directly to Batch(...). The previous
// `case class X()` form was also a concrete class with no `apply`.

// Constant propagation: substitutes attributes that equality predicates pin to a constant
object ConstantPropagation extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan
}

// Expression simplification: simplifies field/element extraction from struct, array and
// map constructors
object SimplifyExtractValueOps extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan
}

// Boolean expression simplification
object BooleanSimplification extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan
}

// Null propagation
object NullPropagation extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan
}
```

### Predicate Optimization

```scala { .api }
/**
 * Merges two adjacent Filters into a single Filter with a conjunctive condition.
 *
 * The inner condition goes first in the And. And short-circuits, so the inner condition
 * keeps guarding the outer one. For example, `x <> 0` is still checked before
 * `10 / x > 1`, which would otherwise fail on rows the inner filter used to remove.
 *
 * Only deterministic conditions are merged. Once merged, later rules may split and reorder
 * the conjuncts. That would change which rows a nondeterministic predicate
 * (e.g. `rand() < 0.5`) is evaluated on.
 */
object CombineFilters extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan.transform {
    case Filter(outer, Filter(inner, grandChild)) if outer.deterministic && inner.deterministic =>
      Filter(And(inner, outer), grandChild)
  }
}

// Remove redundant predicates. Declared as an object (as in Catalyst) so it can be passed
// directly to a Batch, as the CostBasedOptimizer listing does.
object PruneFilters extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan
}

// Convert IN predicates to more efficient forms
object OptimizeIn extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan
}
```

### Join Reordering and Selection

```scala { .api }
// Declared as objects, matching the other rule listings and Catalyst's singleton rules.

// Eliminate cartesian products
// NOTE(review): Catalyst appears to have no rule with this name; implicit cartesian products
// are detected by CheckCartesianProducts. Confirm the intended rule.
object EliminateCartesianProduct extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan
}

// Join reordering based on statistics (in Spark, gated by spark.sql.cbo.enabled and
// spark.sql.cbo.joinReorder.enabled)
object CostBasedJoinReorder extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan
}

// Convert joins to broadcasts when appropriate
// NOTE(review): in Spark, join selection (broadcast vs. shuffle join) happens during physical
// planning in SparkStrategies.JoinSelection. That is a planner strategy producing SparkPlans,
// not a logical Rule. This signature is illustrative only.
object JoinSelection extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan
}
```
185
186
## Cost-Based Optimization
187
188
### Statistics
189
190
```scala { .api }
191
case class CostBasedOptimizer extends Optimizer {
192
override def batches: Seq[Batch] = Seq(
193
Batch("Statistics", Once,
194
ComputeCurrentTime,
195
InferFiltersFromConstraints,
196
ReorderJoin,
197
PruneFilters
198
)
199
)
200
}
201
202
// Statistics estimation
203
case class EstimateStatistics() extends Rule[LogicalPlan]
204
205
// Join cost calculation
206
case class JoinCostCalculation() extends Rule[LogicalPlan]
207
```

### Usage Example

```scala
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.PushPredicateThroughJoin
import org.apache.spark.sql.catalyst.plans.Cross
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.types.LongType

// Optimizer rules expect an analyzed plan. UnresolvedAttribute/UnresolvedRelation have no
// exprId, so rules that inspect attribute references fail on them. Use resolved
// attributes and relations instead.
val customerId = AttributeReference("customer_id", LongType)()
val id = AttributeReference("id", LongType)()
val relation1 = LocalRelation(customerId) // orders
val relation2 = LocalRelation(id)         // customers

// Original inefficient plan: cartesian product followed by an equality filter.
// (Spark 3.x Join signature; Spark 2.x Join has no JoinHint argument.)
val cartesianJoin = Join(relation1, relation2, Cross, None, JoinHint.NONE)
val filter = Filter(EqualTo(customerId, id), cartesianJoin)

// A minimal optimizer containing just the rule this example needs
object JoinPredicateOptimizer extends RuleExecutor[LogicalPlan] {
  val batches = Seq(Batch("Push predicates into joins", FixedPoint(10), PushPredicateThroughJoin))
}

// Apply optimization
val optimizedPlan = JoinPredicateOptimizer.execute(filter)
// Result: Join(relation1, relation2, Cross, Some(customer_id = id)). The predicate becomes
// the join condition, so the join can be planned as an equi-join instead of a cartesian
// product followed by a filter.
```
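
## Custom Optimization Rules

### Creating Custom Rules

A custom rule is a `Rule[LogicalPlan]` whose `apply` method rewrites matching subtrees, typically with `plan.transform { ... }`.
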
```scala
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.plans.logical._

/**
 * Custom rule to remove unnecessary DISTINCT operations.
 *
 * A DISTINCT is dropped when it cannot remove any rows:
 *  - it sits directly on another DISTINCT, which has already removed duplicates;
 *  - every output column of its child carries a "unique" metadata key. This is a
 *    convention of this example, not a Spark built-in, and is assumed to mean the column
 *    holds no duplicate values. (One such column would already suffice.)
 */
object RemoveRedundantDistinct extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan.transform {
    case Distinct(inner: Distinct) => inner
    // Require at least one column. `forall` is true on an empty output, yet DISTINCT over
    // zero columns collapses its input to at most one row, so it is not a no-op.
    case Distinct(child)
        if child.output.nonEmpty && child.output.forall(_.metadata.contains("unique")) =>
      child
  }
}

// Custom optimizer with additional rules.
// NOTE(review): `super.batches` only compiles when the parent class supplies concrete batches.
// The Optimizer listed above declares `batches` abstractly. In Spark 2.4+, Optimizer.batches is
// final: override `defaultBatches` instead, or register the rule without subclassing, via
// `spark.experimental.extraOptimizations = Seq(RemoveRedundantDistinct)` or
// `SparkSessionExtensions.injectOptimizerRule`.
class CustomOptimizer extends Optimizer {
  override def batches: Seq[Batch] = super.batches :+
    Batch("Custom Optimizations", FixedPoint(100), RemoveRedundantDistinct)
}
```
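
## Optimization Verification

### Plan Comparison

These helpers sanity-check an optimization. The optimized plan must expose the same output columns, and a cost estimate quantifies the improvement.
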
```scala
/**
 * Verify optimization correctness: the optimized plan must expose the same output columns
 * as the original, with the same names and data types, in the same order.
 */
def verifyOptimization(original: LogicalPlan, optimized: LogicalPlan): Boolean = {
  val before = original.output
  val after = optimized.output
  before.map(_.dataType) == after.map(_.dataType) && before.map(_.name) == after.map(_.name)
}

/**
 * Measure optimization benefit as the fraction of the original cost that was saved,
 * e.g. 0.25 for a 25% reduction. The result is negative when the optimized plan is
 * estimated to be more expensive.
 *
 * `estimateCost` is assumed to be defined elsewhere and to return a numeric cost.
 * Both costs are converted to Double before dividing, so an integral cost type cannot
 * truncate the ratio to 0. A zero original cost yields 0.0 instead of NaN or -Infinity.
 */
def measureOptimizationBenefit(original: LogicalPlan, optimized: LogicalPlan): Double = {
  val originalCost = estimateCost(original).toDouble
  val optimizedCost = estimateCost(optimized).toDouble
  if (originalCost == 0.0) 0.0
  else (originalCost - optimizedCost) / originalCost
}
```

## Common Optimization Patterns

### Filter Pushdown Pattern

```scala
// Assumes `usersTable` and `ordersTable` are LogicalPlans, `joinCondition` is an
// Option[Expression], and `age` is a column of `usersTable`, all defined elsewhere.
// (Spark 3.x Join also takes a JoinHint argument.)
val ageIs25 = EqualTo(UnresolvedAttribute("age"), Literal(25))

// Before optimization: the filter sits on top of the join
val inefficient = Filter(ageIs25, Join(usersTable, ordersTable, Inner, joinCondition))

// After optimization: the filter is pushed down to the users relation, so fewer rows
// reach the join
val efficient = Join(Filter(ageIs25, usersTable), ordersTable, Inner, joinCondition)
```

### Projection Elimination Pattern

```scala
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute

// Project takes Catalyst NamedExpressions. `col` (org.apache.spark.sql.functions) builds a
// DataFrame Column instead, so reference the columns as Catalyst attributes.
// Assumes `relation` is a LogicalPlan (defined elsewhere) that outputs name, age and id.
val nameCol = UnresolvedAttribute("name")
val ageCol = UnresolvedAttribute("age")
val idCol = UnresolvedAttribute("id")

// Before optimization: the inner projection carries `id` only for the outer one to drop it
val redundantProject = Project(
  Seq(nameCol, ageCol),
  Project(Seq(nameCol, ageCol, idCol), relation)
)

// After optimization (Catalyst's CollapseProject): a single projection
val efficientProject = Project(
  Seq(nameCol, ageCol),
  relation
)
```

The optimization framework provides a flexible rule-based system for transforming logical plans into more efficient forms. Each rule is expected to preserve the query's semantics.