Catalyst's rule-based optimizer applies built-in optimization rules to rewrite logical query plans into more efficient, semantically equivalent plans.
The main optimizer groups its optimization rules into batches.
/**
* Collection of optimization rules for logical plans
*/
object Optimizer extends RuleExecutor[LogicalPlan] {
/** Sequence of optimization rule batches */
def batches: Seq[Batch]
/** Execute all optimization rules on a plan */
def execute(plan: LogicalPlan): LogicalPlan
}
/**
* Base class for transformation rules
*/
abstract class Rule[TreeType <: TreeNode[TreeType]] extends Logging {
/** Rule name for debugging and logging */
def ruleName: String = this.getClass.getSimpleName
/** Apply rule to a tree node */
def apply(plan: TreeType): TreeType
}
/**
* Executes batches of rules with different strategies
*/
abstract class RuleExecutor[TreeType <: TreeNode[TreeType]] extends Logging {
/** Execute all rule batches on input */
def execute(plan: TreeType): TreeType
/** Sequence of rule batches to execute */
def batches: Seq[Batch]
/** Maximum number of iterations per batch */
protected def maxIterations: Int = 100
}
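The execute method is only declared above; conceptually it applies each batch in order, re-running the batch's rules until the plan stops changing or the strategy's iteration limit is hit. A minimal sketch of that loop, assuming the Batch and Strategy definitions shown below (the real executor also logs a warning when a FixedPoint batch fails to converge):
def execute(plan: TreeType): TreeType = {
  var current = plan
  batches.foreach { batch =>
    // Once batches run a single pass; FixedPoint batches rerun up to maxIterations times
    val limit = batch.strategy match {
      case Once => 1
      case FixedPoint(max) => max
    }
    var iteration = 0
    var changed = true
    while (changed && iteration < limit) {
      val result = batch.rules.foldLeft(current) { (p, rule) => rule(p) }
      changed = !result.fastEquals(current) // stop once the batch no longer changes the plan
      current = result
      iteration += 1
    }
  }
  current
}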
/**
 * A group of rules executed together under a single execution strategy.
 * Batch is defined within RuleExecutor, so TreeType refers to the enclosing executor's tree type.
 */
case class Batch(name: String, strategy: Strategy, rules: Rule[TreeType]*) {
/** Rules in this batch */
def rulesIterator: Iterator[Rule[TreeType]] = rules.iterator
}
/** Execution strategies for rule batches */
sealed abstract class Strategy
case object Once extends Strategy
case class FixedPoint(maxIterations: Int) extends Strategy
Usage Examples:
import org.apache.spark.sql.catalyst.optimizer._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types._
// Apply optimizer to a logical plan
val plan = Project(
Seq(Add(Literal(1), Literal(2)).as("sum")), // 1 + 2 (can be folded)
LocalRelation(AttributeReference("dummy", IntegerType, false)())
)
val optimizedPlan = Optimizer.execute(plan)
// Result: Literal(3) replaces Add(Literal(1), Literal(2))
// Create custom optimization rule
object MyCustomRule extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
case Project(projectList, child) if projectList.isEmpty =>
child // Remove empty projections
}
}
// Create custom optimizer
object MyOptimizer extends RuleExecutor[LogicalPlan] {
def batches: Seq[Batch] = Seq(
Batch("Custom Optimizations", Once, MyCustomRule),
Batch("Constant Folding", FixedPoint(100), ConstantFolding)
)
}
val customOptimized = MyOptimizer.execute(plan)
Rules for optimizing expressions within query plans.
/**
* Fold constant expressions into literals
*/
object ConstantFolding extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = {
plan.transformAllExpressions {
case expr if expr.foldable => Literal.create(expr.eval(), expr.dataType)
}
}
}
/**
* Simplify boolean expressions
*/
object BooleanSimplification extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = {
plan.transformAllExpressions {
case And(TrueLiteral, right) => right
case And(left, TrueLiteral) => left
case And(FalseLiteral, _) => FalseLiteral
case And(_, FalseLiteral) => FalseLiteral
case Or(TrueLiteral, _) => TrueLiteral
case Or(_, TrueLiteral) => TrueLiteral
case Or(FalseLiteral, right) => right
case Or(left, FalseLiteral) => left
case Not(TrueLiteral) => FalseLiteral
case Not(FalseLiteral) => TrueLiteral
case Not(Not(expr)) => expr
}
}
}
/**
* Simplify LIKE expressions to more efficient forms
*/
object LikeSimplification extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = {
plan.transformAllExpressions {
case Like(input, Literal(pattern: UTF8String, StringType)) =>
if (pattern.toString == "%") {
// LIKE '%' is always true for non-null strings
IsNotNull(input)
} else if (!pattern.toString.contains("%") && !pattern.toString.contains("_")) {
// No wildcards - convert to equality
EqualTo(input, Literal(pattern, StringType))
} else {
// Keep original LIKE
Like(input, Literal(pattern, StringType))
}
}
}
}
Usage Examples:
import org.apache.spark.sql.catalyst.optimizer._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types._
// Constant folding example
val exprPlan = Project(
Seq(
Add(Literal(10), Literal(5)).as("sum"), // 10 + 5 -> 15
Multiply(Literal(3), Literal(4)).as("product"), // 3 * 4 -> 12
Subtract(Literal(20), Literal(8)).as("diff") // 20 - 8 -> 12
),
LocalRelation(AttributeReference("dummy", IntegerType, false)())
)
val foldedPlan = ConstantFolding.apply(exprPlan)
// All arithmetic expressions become literals
// Boolean simplification example
// (col(...) and someRelation are placeholders for resolved attribute references and an input LogicalPlan)
val boolPlan = Filter(
And(
And(Literal(true), GreaterThan(col("age"), Literal(18))), // TRUE AND (age > 18) -> (age > 18)
Or(Literal(false), EqualTo(col("active"), Literal(true))) // FALSE OR (active = TRUE) -> (active = TRUE)
),
someRelation
)
val simplifiedPlan = BooleanSimplification.apply(boolPlan)
// Results in: Filter(And(GreaterThan(col("age"), Literal(18)), EqualTo(col("active"), Literal(true))), someRelation)
// LIKE simplification example
val likePlan = Filter(
And(
Like(col("name"), Literal("%")), // Always true for non-null -> IsNotNull(name)
Like(col("code"), Literal("ABC")) // No wildcards -> EqualTo(code, "ABC")
),
someRelation
)
val likeSimplified = LikeSimplification.apply(likePlan)
Rules for optimizing the structure of query plans.
/**
* Remove unused columns from query plans
*/
object ColumnPruning extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = {
plan.transformUp {
case Project(projectList, child) =>
val usedColumns = projectList.flatMap(_.references).toSet
pruneChild(child, usedColumns) match {
case prunedChild if prunedChild.output != child.output =>
Project(projectList, prunedChild)
case _ => Project(projectList, child)
}
}
}
private def pruneChild(plan: LogicalPlan, requiredColumns: Set[Attribute]): LogicalPlan = {
// Implementation to remove unused columns
}
}
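pruneChild is declared above with its body elided; a minimal hypothetical sketch of what it could do, assuming it only needs to narrow the child's output by inserting a Project over the required attributes:
// Hypothetical body for pruneChild: keep only the attributes the parent projection uses
private def pruneChild(plan: LogicalPlan, requiredColumns: Set[Attribute]): LogicalPlan = {
  if (plan.output.forall(requiredColumns.contains)) {
    plan // every output attribute is needed, nothing to prune
  } else {
    // Preserve the child's output order while dropping unused attributes
    Project(plan.output.filter(requiredColumns.contains), plan)
  }
}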
/**
* Push filter predicates down to data sources
*/
object FilterPushdown extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = {
plan.transform {
case Filter(condition, Join(left, right, joinType, joinCondition)) =>
// Split the condition into conjuncts that reference only the left side, only the right side,
// or both (pushing through the join is safe for inner joins; outer joins need more care)
val (leftFilters, rightFilters, remainingFilters) = splitConjunctivePredicates(condition, left, right)
val newLeft = if (leftFilters.nonEmpty) Filter(leftFilters.reduce(And), left) else left
val newRight = if (rightFilters.nonEmpty) Filter(rightFilters.reduce(And), right) else right
val newJoin = Join(newLeft, newRight, joinType, joinCondition)
if (remainingFilters.nonEmpty) {
Filter(remainingFilters.reduce(And), newJoin)
} else {
newJoin
}
}
}
}
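splitConjunctivePredicates is used above but not shown, and Catalyst's real helper of that name only flattens an And tree into its conjuncts; the three-way split is assumed here. A hypothetical version, as it might appear inside FilterPushdown, consistent with the call above:
// Hypothetical helper: partition the conjuncts of a condition into those that reference
// only the left side, only the right side, or both (the last group stays above the join)
private def splitConjunctivePredicates(
    condition: Expression,
    left: LogicalPlan,
    right: LogicalPlan): (Seq[Expression], Seq[Expression], Seq[Expression]) = {
  def conjuncts(e: Expression): Seq[Expression] = e match {
    case And(l, r) => conjuncts(l) ++ conjuncts(r)
    case other => Seq(other)
  }
  val all = conjuncts(condition)
  val leftOnly = all.filter(_.references.subsetOf(left.outputSet))
  val rightOnly = all.filter(e => !leftOnly.contains(e) && e.references.subsetOf(right.outputSet))
  val remaining = all.filterNot(e => leftOnly.contains(e) || rightOnly.contains(e))
  (leftOnly, rightOnly, remaining)
}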
/**
* Collapse adjacent projections
*/
object ProjectCollapsing extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = {
plan.transformUp {
case Project(projectList1, Project(projectList2, child)) =>
// Substitute aliased expressions from the inner projection into the outer projection
val substituted = projectList1.map(_.transform {
case a: AttributeReference =>
projectList2.collectFirst { case alias: Alias if alias.exprId == a.exprId => alias.child }.getOrElse(a)
})
Project(substituted, child)
}
}
}
/**
* Combine adjacent limit operations
*/
object CombineLimits extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = {
plan.transformUp {
case Limit(expr1, Limit(expr2, child)) =>
// Take minimum of the two limits
Limit(
If(LessThan(expr1, expr2), expr1, expr2),
child
)
}
}
}
/**
* Convert small relations to local relations
*/
object ConvertToLocalRelation extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = {
plan.transformUp {
case relation if isSmallRelation(relation) =>
// Convert to LocalRelation with data materialized in memory
materializeAsLocalRelation(relation)
}
}
private def isSmallRelation(plan: LogicalPlan): Boolean = {
// Check if relation is small enough to materialize locally
}
private def materializeAsLocalRelation(plan: LogicalPlan): LocalRelation = {
// Convert to LocalRelation
}
}
Usage Examples:
import org.apache.spark.sql.catalyst.optimizer._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.types._
// Column pruning example
val relation = LocalRelation(
AttributeReference("id", IntegerType, false)(),
AttributeReference("name", StringType, true)(),
AttributeReference("age", IntegerType, true)(),
AttributeReference("unused", StringType, true)()
)
val projectPlan = Project(
Seq(relation.output(0), relation.output(1)), // Only id and name used
relation
)
val prunedPlan = ColumnPruning.apply(projectPlan)
// Unused columns (age, unused) are removed from the scan
// Filter pushdown example
val leftRelation = LocalRelation(
AttributeReference("user_id", IntegerType, false)(),
AttributeReference("name", StringType, true)()
)
val rightRelation = LocalRelation(
AttributeReference("user_id", IntegerType, false)(),
AttributeReference("order_total", DoubleType, true)()
)
val joinPlan = Join(leftRelation, rightRelation, Inner,
Some(EqualTo(leftRelation.output(0), rightRelation.output(0))))
val filterAfterJoin = Filter(
And(
GreaterThan(leftRelation.output(0), Literal(100)), // Can push to left
GreaterThan(rightRelation.output(1), Literal(50.0)) // Can push to right
),
joinPlan
)
val pushedDown = FilterPushdown.apply(filterAfterJoin)
// Filters are pushed down before the join
// Project collapsing example
val innerProject = Project(
Seq(
relation.output(0).as("user_id"),
relation.output(1).as("user_name")
),
relation
)
val outerProject = Project(
Seq(innerProject.output(1)), // Only user_name
innerProject
)
val collapsed = ProjectCollapsing.apply(outerProject)
// Results in single projection: Project(Seq(relation.output(1)), relation)
// Limit combining example
val innerLimit = Limit(Literal(100), relation)
val outerLimit = Limit(Literal(50), innerLimit)
val combinedLimit = CombineLimits.apply(outerLimit)
// After constant folding of the If condition, results in: Limit(Literal(50), relation) - the smaller limit wins
Complex optimization patterns and aggregate optimizations.
/**
* Optimize aggregate operations
*/
object AggregateOptimize extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = {
plan.transformUp {
case Aggregate(groupExpr, aggExpr, child) =>
// Optimize aggregations - remove redundant grouping, etc.
optimizeAggregate(groupExpr, aggExpr, child)
}
}
private def optimizeAggregate(
groupExpr: Seq[Expression],
aggExpr: Seq[NamedExpression],
child: LogicalPlan): LogicalPlan = {
// Implementation for aggregate optimizations
}
}
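optimizeAggregate is declared above with its body elided; a minimal hypothetical sketch limited to removing grouping expressions that cannot change the grouping (similar in spirit to Catalyst's RemoveLiteralFromGroupExpressions and RemoveRepetitionFromGroupExpressions rules):
// Hypothetical body for optimizeAggregate: drop constant and duplicate grouping expressions
private def optimizeAggregate(
    groupExpr: Seq[Expression],
    aggExpr: Seq[NamedExpression],
    child: LogicalPlan): LogicalPlan = {
  val withoutLiterals = groupExpr.filterNot(_.foldable) // grouping by a constant is a no-op
  // If every grouping expression was a literal, keep one so the aggregate stays grouped
  val simplified =
    (if (withoutLiterals.isEmpty && groupExpr.nonEmpty) groupExpr.take(1) else withoutLiterals).distinct
  Aggregate(simplified, aggExpr, child)
}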
/**
* Optimize set operations (UNION, INTERSECT, EXCEPT)
*/
object SetOperationPushDown extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = {
plan.transformUp {
case Filter(condition, Union(children)) =>
// Push the filter into every union child
// (a full implementation rewrites the condition's attribute references to each child's own output)
Union(children.map(Filter(condition, _)))
case Project(projectList, Union(children)) =>
// Push the projection into every union child, with the same attribute-rewriting caveat
Union(children.map(Project(projectList, _)))
}
}
}
/**
* Eliminate common sub-expressions
*/
object EliminateSubexpressions extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = {
plan.transformUp {
case p =>
val commonExprs = findCommonSubexpressions(p.expressions)
if (commonExprs.nonEmpty) {
// Replace common subexpressions with references
eliminateCommonSubexpressions(p, commonExprs)
} else {
p
}
}
}
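// Hypothetical signatures for the helpers used above, with bodies elided in the same
// style as the other rules in this section
private def findCommonSubexpressions(exprs: Seq[Expression]): Seq[Expression] = {
  // Collect subexpression trees that occur more than once across exprs
}
private def eliminateCommonSubexpressions(
    plan: LogicalPlan,
    commonExprs: Seq[Expression]): LogicalPlan = {
  // Rewrite plan so each common subexpression is computed once and then referenced
}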
}
Usage Examples:
import org.apache.spark.sql.catalyst.optimizer._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.types._
// Aggregate optimization example
val relation = LocalRelation(
AttributeReference("category", StringType, false)(),
AttributeReference("amount", DoubleType, false)(),
AttributeReference("quantity", IntegerType, false)()
)
val aggregate = Aggregate(
Seq(relation.output(0)), // GROUP BY category
Seq(
Sum(relation.output(1)).as("total_amount"),
Count(Literal(1)).as("count"),
Average(relation.output(1)).as("avg_amount")
),
relation
)
val optimizedAgg = AggregateOptimize.apply(aggregate)
// Set operation pushdown example
val relation1 = LocalRelation(
AttributeReference("id", IntegerType, false)(),
AttributeReference("value", StringType, true)()
)
val relation2 = LocalRelation(relation1.output) // Same schema
val unionPlan = Union(Seq(relation1, relation2))
val filteredUnion = Filter(
GreaterThan(relation1.output(0), Literal(10)),
unionPlan
)
val pushedFilter = SetOperationPushDown.apply(filteredUnion)
// Results in: Union(Seq(Filter(..., relation1), Filter(..., relation2)))
// Common subexpression elimination example
val complexExpr1 = Add(relation.output(1), relation.output(2)) // amount + quantity
val complexExpr2 = Multiply(complexExpr1, Literal(0.1)) // (amount + quantity) * 0.1
val planWithDuplicates = Project(
Seq(
complexExpr1.as("sum"),
complexExpr2.as("discounted"),
Add(complexExpr1, Literal(5)).as("sum_plus_five") // Reuses complexExpr1
),
relation
)
val optimizedPlan = EliminateSubexpressions.apply(planWithDuplicates)
// Common subexpression (amount + quantity) is computed once and reused
Complete optimization pipeline with configurable batches.
/**
* Complete optimization pipeline
*/
object DefaultOptimizer extends RuleExecutor[LogicalPlan] {
// Assumption: sessionCatalog and conf are available in scope (they are constructor
// parameters of Spark's real Optimizer class)
def batches: Seq[Batch] = Seq(
// Finish Analysis
Batch("Finish Analysis", Once,
EliminateSubqueryAliases,
ReplaceExpressions,
ComputeCurrentTime,
GetCurrentDatabase(sessionCatalog)),
// Substitution
Batch("Substitution", fixedPoint,
CTESubstitution,
WindowsSubstitution,
EliminateUnions,
new SubstituteUnresolvedOrdinals(conf)),
// Constant Folding and Strength Reduction
Batch("Constant Folding", fixedPoint,
NullPropagation,
ConstantFolding,
BooleanSimplification,
SimplifyConditionals,
RemoveDispensableExpressions,
SimplifyBinaryComparison,
LikeSimplification),
// Operator Optimizations
Batch("Operator Optimizations", fixedPoint,
SetOperationPushDown,
SamplePushDown,
PushDownPredicate,
PushDownLeftSemiAntiJoin,
LimitPushDown,
ColumnPruning,
InferFiltersFromConstraints,
CollapseRepartition,
CollapseProject,
CombineFilters,
CombineLimits,
CombineUnions,
NullPropagation,
ConstantFolding,
BooleanSimplification,
RemoveRedundantProject,
SimplifyCreateStructOps,
SimplifyCreateArrayOps,
SimplifyCreateMapOps),
// Join Reorder
Batch("Join Reorder", Once,
CostBasedJoinReorder),
// Local Relation Optimization
Batch("LocalRelation", fixedPoint,
ConvertToLocalRelation)
)
val fixedPoint = FixedPoint(100)
}
Usage Examples:
import org.apache.spark.sql.catalyst.optimizer._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.expressions._
// Use complete optimization pipeline
// (col(...) and someBaseRelation are placeholders for resolved attribute references and an input LogicalPlan)
val complexPlan = Project(
Seq(
Add(Literal(1), Literal(2)).as("const_sum"), // Constant folding
col("name").as("name") // Column pruning opportunity
),
Filter(
And(Literal(true), GreaterThan(col("age"), Literal(18))), // Boolean simplification
Project(
Seq(col("id"), col("name"), col("age"), col("unused")), // Column pruning
someBaseRelation
)
)
)
val fullyOptimized = DefaultOptimizer.execute(complexPlan)
// Applies all optimization rules in proper order:
// 1. Boolean simplification: TRUE AND (age > 18) -> (age > 18)
// 2. Constant folding: 1 + 2 -> 3
// 3. Column pruning: removes the "unused" and "id" columns (neither is referenced by the outer projection)
// 4. Project collapsing: merges adjacent projections
// 5. Other applicable optimizations
// Custom optimization pipeline
object MinimalOptimizer extends RuleExecutor[LogicalPlan] {
def batches: Seq[Batch] = Seq(
Batch("Basic", FixedPoint(50),
ConstantFolding,
BooleanSimplification,
ColumnPruning
)
)
}
val minimalOptimized = MinimalOptimizer.execute(complexPlan)
// Applies only basic optimizations