Catalyst query optimization framework and expression evaluation engine for Apache Spark SQL
—
This section covers the query optimization engine with rule-based and cost-based optimization techniques in Spark Catalyst. The optimizer transforms logical plans into more efficient equivalent plans.
import org.apache.spark.sql.catalyst.optimizer._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.expressions._

The main optimization engine that applies rule-based optimizations to logical plans.
/** Base query optimizer: a rule executor over logical plans. Concrete
 * subclasses supply the ordered rule batches via `batches`. */
abstract class Optimizer extends RuleExecutor[LogicalPlan] {
// Ordered batches of rewrite rules, applied in sequence by RuleExecutor.execute.
def batches: Seq[Batch]
}
/** A named group of rules executed together under a single iteration strategy. */
case class Batch(name: String, strategy: Strategy, rules: Rule[LogicalPlan]*)
/** Controls how many times a batch of rules may be re-applied. */
abstract class Strategy {
// Upper bound on the number of passes over the plan for one batch.
def maxIterations: Int
}
/** Strategy for batches whose rules only need a single pass. */
case object Once extends Strategy {
def maxIterations: Int = 1
}
/** Strategy that re-runs a batch until the plan stops changing or
 * `maxIterations` passes have been made, whichever comes first.
 * (The original line had the next snippet's import fused onto it.) */
case class FixedPoint(maxIterations: Int) extends Strategy

import org.apache.spark.sql.catalyst.optimizer._
import org.apache.spark.sql.catalyst.plans.logical._
// Create optimizer instance
// NOTE(review): SimpleTestOptimizer and `relation` are assumed to be defined
// elsewhere (e.g. a test fixture) — they are not visible in this excerpt.
val optimizer = new SimpleTestOptimizer()
// Optimize a logical plan
// Filter(Literal(true), ...) is a no-op predicate the optimizer can eliminate.
val logicalPlan = Project(Seq(col("name")), Filter(Literal(true), relation))
val optimizedPlan = optimizer.execute(logicalPlan)

/** Pushes filter predicates closer to the data source so that fewer rows flow
 * through upstream operators. (The original fused the `val` above and this
 * object onto one line.) */
object PushDownPredicate extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan.transform {
    // Delegate each Filter node to the push-down helper (defined elsewhere).
    case Filter(condition, child) => pushDownPredicate(condition, child)
  }
}
/** Pushes predicates below joins where legal, shrinking the join inputs.
 * Signature only — the implementation is not shown in this excerpt. */
object PushPredicateThroughJoin extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan
}

/** Removes columns that no downstream operator consumes.
 * Signature only — the implementation is not shown in this excerpt. */
object ColumnPruning extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan
}
/** Pushes projections into both sides of a Union so each branch prunes early.
 * Signature only — the implementation is not shown in this excerpt. */
object PushProjectionThroughUnion extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan
}

/** Reorders joins into a more efficient evaluation order.
 * Signature only — the implementation is not shown in this excerpt. */
object ReorderJoin extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan
}
/** Converts outer joins to inner joins when predicates discard the padded side.
 * Signature only — the implementation is not shown in this excerpt. */
object EliminateOuterJoin extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan
}

/** Evaluates foldable (constant) expressions once at optimization time and
 * replaces them with literal values. */
object ConstantFolding extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions {
    // `foldable` means evaluation needs no input row, so EmptyRow is safe here.
    case expr if expr.foldable => Literal.create(expr.eval(EmptyRow), expr.dataType)
  }
}
/** Simplifies conditional expressions whose branches are statically known.
 * Signature only — the implementation is not shown in this excerpt. */
object SimplifyConditionals extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan
}
/** Simplifies binary comparisons whose outcome is statically determinable.
 * Signature only — the implementation is not shown in this excerpt. */
object SimplifyBinaryComparison extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan
}

/** Drives rule application: runs each batch in order, re-applying its rules
 * until the batch's strategy caps the iterations or the plan reaches a
 * fixed point. */
abstract class RuleExecutor[TreeType <: TreeNode[TreeType]] {
  /** Ordered batches supplied by concrete executors. */
  def batches: Seq[Batch]

  /** Applies every batch to `plan` and returns the transformed tree. */
  def execute(plan: TreeType): TreeType = {
    var curPlan = plan
    batches.foreach { batch =>
      var iteration = 0
      var lastPlan = curPlan
      var continue = true
      // Re-run the batch until the strategy's iteration cap is hit or the plan
      // stops changing. (The original also re-checked the iteration cap inside
      // `continue`, which the while-guard already enforces, and kept an unused
      // `batchStartPlan` local; both removed.)
      while (continue && iteration < batch.strategy.maxIterations) {
        curPlan = batch.rules.foldLeft(curPlan) { (updated, rule) =>
          rule(updated)
        }
        iteration += 1
        // Fixed point reached when one full pass leaves the plan unchanged.
        continue = !curPlan.fastEquals(lastPlan)
        lastPlan = curPlan
      }
    }
    curPlan
  }
}

// Constant propagation
case class ConstantPropagation() extends Rule[LogicalPlan]
// Expression simplification
case class SimplifyExtractValueOps() extends Rule[LogicalPlan]
// Boolean expression simplification
case class BooleanSimplification() extends Rule[LogicalPlan]
// Null propagation
case class NullPropagation() extends Rule[LogicalPlan]
// Combine filters
/** Collapses two adjacent Filter nodes into a single Filter whose predicate is
 * the conjunction of both conditions. */
case class CombineFilters() extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan.transform {
    // Filter directly over Filter: merge via And and keep the grandchild.
    case Filter(outerCond, Filter(innerCond, grandChild)) =>
      Filter(And(outerCond, innerCond), grandChild)
  }
}
// Remove redundant predicates
case class PruneFilters() extends Rule[LogicalPlan]
// Convert IN predicates to more efficient forms
case class OptimizeIn() extends Rule[LogicalPlan]
// Eliminate cartesian products
case class EliminateCartesianProduct() extends Rule[LogicalPlan]
// Join reordering based on statistics
case class CostBasedJoinReorder() extends Rule[LogicalPlan]
// Convert joins to broadcasts when appropriate
case class JoinSelection() extends Rule[LogicalPlan]

/** Optimizer variant that applies statistics-driven (cost-based) rules.
 * Fixed: a case class requires an explicit parameter list — the original
 * `case class CostBasedOptimizer extends Optimizer` is not valid Scala. */
case class CostBasedOptimizer() extends Optimizer {
  override def batches: Seq[Batch] = Seq(
    Batch("Statistics", Once,
      ComputeCurrentTime,
      InferFiltersFromConstraints,
      ReorderJoin,
      // PruneFilters is declared above as a case class, so an instance must be
      // passed here — the bare name would reference its companion object.
      PruneFilters()
    )
  )
}
// Statistics estimation
case class EstimateStatistics() extends Rule[LogicalPlan]
// Join cost calculation
// (The original fused the next snippet's import onto this declaration line.)
case class JoinCostCalculation() extends Rule[LogicalPlan]

import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.expressions._

// Original inefficient plan: a cross join followed by an equality filter.
val relation1 = UnresolvedRelation(TableIdentifier("orders"))
val relation2 = UnresolvedRelation(TableIdentifier("customers"))
val cartesianJoin = Join(relation1, relation2, Cross, None)
val filter = Filter(
  EqualTo(
    UnresolvedAttribute("orders.customer_id"),
    UnresolvedAttribute("customers.id")
  ),
  cartesianJoin
)
// Apply optimization
// NOTE(review): `optimizer` must already be in scope from an earlier snippet.
val optimizedPlan = optimizer.execute(filter)
// Result: Join with proper join condition instead of cartesian product + filter

import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.plans.logical._
// Custom rule to remove unnecessary DISTINCT operations
/** Removes DISTINCT operators that cannot change the query result. */
object RemoveRedundantDistinct extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan.transform {
    // Collapse Distinct(Distinct(x)) using a nested pattern instead of the
    // original isInstanceOf guard — same behavior, idiomatic pattern match.
    case Distinct(child @ Distinct(_)) => child
    // Drop Distinct entirely when every output column is flagged unique.
    case Distinct(child) if child.output.forall(_.metadata.contains("unique")) => child
  }
}
// Custom optimizer with additional rules
class CustomOptimizer extends Optimizer {
  // NOTE(review): the original called `super.batches :+ ...`, but `batches` is
  // abstract on the Optimizer defined above, so `super.batches` does not
  // compile. Define the batches directly; when extending a concrete optimizer,
  // prepend its batches here instead.
  override def batches: Seq[Batch] = Seq(
    Batch("Custom Optimizations", FixedPoint(100), RemoveRedundantDistinct)
  )
}

// Verify optimization correctness
/** Checks that optimization preserved the plan's output schema: the same
 * column names and data types, in the same order. */
def verifyOptimization(original: LogicalPlan, optimized: LogicalPlan): Boolean = {
  val before = original.output
  val after = optimized.output
  // Comparing (name, dataType) pairs is equivalent to comparing the name
  // lists and dataType lists separately, as the original did.
  before.map(a => (a.name, a.dataType)) == after.map(a => (a.name, a.dataType))
}
// Measure optimization benefit
/** Fractional cost reduction achieved by optimization (e.g. 0.25 = 25% cheaper).
 * Returns 0.0 when the original plan's cost is zero — the original code would
 * divide by zero in that case.
 * NOTE(review): assumes `estimateCost` (defined elsewhere) returns a numeric
 * cost; if it returns an integer type, the division below should be checked
 * for truncation. */
def measureOptimizationBenefit(original: LogicalPlan, optimized: LogicalPlan): Double = {
  val originalCost = estimateCost(original)
  val optimizedCost = estimateCost(optimized)
  if (originalCost == 0) 0.0
  else (originalCost - optimizedCost) / originalCost
}

// Before optimization: Filter on top of Join
// NOTE(review): usersTable, ordersTable, and joinCondition are assumed to be
// defined in an enclosing scope — they are not visible in this excerpt.
val inefficient = Filter(
EqualTo(UnresolvedAttribute("age"), Literal(25)),
Join(usersTable, ordersTable, Inner, joinCondition)
)
// After optimization: Filter pushed down to relation
val efficient = Join(
Filter(EqualTo(UnresolvedAttribute("age"), Literal(25)), usersTable),
ordersTable,
Inner,
joinCondition
)
// Before optimization: Unnecessary projection
// The outer projection selects a subset of the inner one's columns,
// so the inner projection is redundant and can be pruned.
val redundantProject = Project(
Seq(col("name"), col("age")),
Project(Seq(col("name"), col("age"), col("id")), relation)
)
// After optimization: Single projection
val efficientProject = Project(
Seq(col("name"), col("age")),
relation
)

The optimization framework provides a flexible rule-based system for transforming logical plans into more efficient forms while preserving semantic correctness.
Install with Tessl CLI
npx tessl i tessl/maven-org-apache-spark--spark-catalyst