Aggregations and Grouping

Spark SQL provides powerful aggregation capabilities through both untyped DataFrame operations (RelationalGroupedDataset) and type-safe Dataset operations (KeyValueGroupedDataset). These enable complex analytical operations including grouping, pivoting, window functions, and custom aggregations.
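
As a quick orientation, the two styles look like this. This is a minimal sketch, assuming a DataFrame df with region and amount columns and access to spark.implicits._:

import org.apache.spark.sql.functions._
import spark.implicits._

// Untyped: groupBy returns a RelationalGroupedDataset
val untypedTotals = df
  .groupBy("region")
  .agg(sum("amount").alias("total_amount"))

// Type-safe: groupByKey returns a KeyValueGroupedDataset
val typedTotals = df
  .groupByKey(row => row.getAs[String]("region"))
  .mapGroups((region, rows) => (region, rows.map(_.getAs[Double]("amount")).sum))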

RelationalGroupedDataset (Untyped Aggregations)

class RelationalGroupedDataset {
  def agg(expr: Column, exprs: Column*): DataFrame
  def agg(exprs: Map[String, String]): DataFrame
  def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame
  
  def count(): DataFrame
  def sum(colNames: String*): DataFrame
  def avg(colNames: String*): DataFrame
  def mean(colNames: String*): DataFrame
  def max(colNames: String*): DataFrame
  def min(colNames: String*): DataFrame
  
  def pivot(pivotColumn: String): RelationalGroupedDataset
  def pivot(pivotColumn: String, values: Seq[Any]): RelationalGroupedDataset
  def pivot(pivotColumn: String, values: java.util.List[Any]): RelationalGroupedDataset
}

Basic Grouping and Aggregation

Usage Examples:

import org.apache.spark.sql.functions._

// Group by single column
val salesByRegion = df
  .groupBy("region")
  .agg(
    sum("amount").alias("total_sales"),
    avg("amount").alias("avg_sale"),
    count("*").alias("num_transactions"),
    max("date").alias("latest_date")
  )

// Group by multiple columns
val departmentStats = df
  .groupBy("department", "location")
  .agg(
    count("employee_id").alias("employee_count"),
    avg("salary").alias("avg_salary"),
    sum("salary").alias("total_payroll")
  )

// Using Map for aggregations
val aggregations = Map(
  "sales" -> "sum",
  "quantity" -> "avg",
  "price" -> "max"
)
val mapAgg = df.groupBy("category").agg(aggregations)

// Using tuple syntax
val tupleAgg = df
  .groupBy("region")
  .agg(("sales", "sum"), ("quantity", "avg"), ("price", "max"))

Advanced Aggregations

// Multiple aggregations with conditions
val conditionalAgg = df
  .groupBy("department")
  .agg(
    sum("salary").alias("total_salary"),
    sum(when(col("gender") === "F", col("salary")).otherwise(0)).alias("female_salary"),
    sum(when(col("gender") === "M", col("salary")).otherwise(0)).alias("male_salary"),
    countDistinct("employee_id").alias("unique_employees"),
    collect_list("name").alias("employee_names"),
    collect_set("role").alias("unique_roles")
  )

// Statistical aggregations
val statistics = df
  .groupBy("category")
  .agg(
    count("*").alias("count"),
    sum("value").alias("sum"),
    avg("value").alias("mean"),
    stddev("value").alias("std_dev"),
    variance("value").alias("variance"),
    min("value").alias("min_value"),
    max("value").alias("max_value"),
    skewness("value").alias("skewness"),
    kurtosis("value").alias("kurtosis")
  )

// Percentiles using expr()
val percentiles = df
  .groupBy("region")
  .agg(
    expr("percentile_approx(sales, 0.5)").alias("median_sales"),
    expr("percentile_approx(sales, array(0.25, 0.75))").alias("quartiles")
  )

Pivot Operations

// Basic pivot
val pivoted = df
  .groupBy("region")
  .pivot("product_category")
  .sum("sales")

// Pivot with specific values (more efficient)
val efficientPivot = df
  .groupBy("year", "quarter")
  .pivot("region", Seq("North", "South", "East", "West"))
  .agg(sum("revenue"), avg("profit_margin"))

// Multiple aggregations with pivot
val multiAggPivot = df
  .groupBy("department")
  .pivot("quarter")
  .agg(
    sum("sales").alias("total_sales"),
    count("*").alias("transaction_count")
  )

// Pivot with dynamic values
val dynamicPivot = df
  .groupBy("region")
  .pivot("product_category")
  .agg(
    sum("sales"),
    avg("price"),
    countDistinct("customer_id")
  )

KeyValueGroupedDataset (Type-Safe Aggregations)

class KeyValueGroupedDataset[K, V] {
  def agg[U1: Encoder](column: TypedColumn[V, U1]): Dataset[(K, U1)]
  def agg[U1: Encoder, U2: Encoder](col1: TypedColumn[V, U1], col2: TypedColumn[V, U2]): Dataset[(K, U1, U2)]
  def agg[U1: Encoder, U2: Encoder, U3: Encoder](col1: TypedColumn[V, U1], col2: TypedColumn[V, U2], col3: TypedColumn[V, U3]): Dataset[(K, U1, U2, U3)]
  def agg[U1: Encoder, U2: Encoder, U3: Encoder, U4: Encoder](col1: TypedColumn[V, U1], col2: TypedColumn[V, U2], col3: TypedColumn[V, U3], col4: TypedColumn[V, U4]): Dataset[(K, U1, U2, U3, U4)]
  
  def count(): Dataset[(K, Long)]
  def keys: Dataset[K]
  def keyAs[L: Encoder]: KeyValueGroupedDataset[L, V]
  def mapValues[W: Encoder](func: V => W): KeyValueGroupedDataset[K, W]
  
  def mapGroups[U: Encoder](f: (K, Iterator[V]) => U): Dataset[U]
  def flatMapGroups[U: Encoder](f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U]
  def reduceGroups(f: (V, V) => V): Dataset[(K, V)]
  
  def mapGroupsWithState[S: Encoder, U: Encoder](timeoutConf: GroupStateTimeout)(func: (K, Iterator[V], GroupState[S]) => U): Dataset[U]
  def flatMapGroupsWithState[S: Encoder, U: Encoder](outputMode: OutputMode, timeoutConf: GroupStateTimeout)(func: (K, Iterator[V], GroupState[S]) => Iterator[U]): Dataset[U]
  
  def cogroup[U, R: Encoder](other: KeyValueGroupedDataset[K, U])(f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R]
}

Type-Safe Aggregations

Usage Examples:

import org.apache.spark.sql.expressions.scalalang.typed

case class Sale(region: String, amount: Double, quantity: Int, date: java.sql.Date)
case class Employee(department: String, salary: Double, age: Int)

val salesDS: Dataset[Sale] = df.as[Sale]
val employeeDS: Dataset[Employee] = df.as[Employee]

// Type-safe grouping and aggregation
val regionSales = salesDS
  .groupByKey(_.region)
  .agg(
    typed.sum(_.amount).name("total_amount"),
    typed.avg(_.amount).name("avg_amount"),
    typed.count(_.quantity).name("transaction_count")
  )

// Multiple typed aggregations (the typed object provides avg, count, sum, and sumLong)
val deptStats = employeeDS
  .groupByKey(_.department)
  .agg(
    typed.avg(_.salary).name("avg_salary"),
    typed.sum(_.salary).name("total_salary"),
    typed.count(_.age).name("employee_count")
  )

// Simple count
val counts = salesDS
  .groupByKey(_.region)
  .count()

Custom Group Operations

// Map groups to custom results
val customSummary = salesDS
  .groupByKey(_.region)
  .mapGroups { (region, sales) =>
    val salesList = sales.toList
    val total = salesList.map(_.amount).sum
    val count = salesList.size
    val avgAmount = if (count > 0) total / count else 0.0
    (region, total, count, avgAmount)
  }

// Reduce within groups
val topSaleByRegion = salesDS
  .groupByKey(_.region)
  .reduceGroups((sale1, sale2) => if (sale1.amount > sale2.amount) sale1 else sale2)

// FlatMap groups for multiple results per group
val quarterlySales = salesDS
  .groupByKey(_.region)
  .flatMapGroups { (region, sales) =>
    val salesList = sales.toList
    val quarters = salesList.groupBy(s => getQuarter(s.date))
    quarters.map { case (quarter, qSales) =>
      (region, quarter, qSales.map(_.amount).sum)
    }
  }

def getQuarter(date: java.sql.Date): Int = {
  val cal = java.util.Calendar.getInstance()
  cal.setTime(date)
  (cal.get(java.util.Calendar.MONTH) / 3) + 1
}
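
cogroup (listed in the KeyValueGroupedDataset API above) processes two grouped datasets side by side per key. A minimal sketch, assuming a hypothetical Return case class and returnsDF source:

case class Return(region: String, amount: Double)
val returnsDS: Dataset[Return] = returnsDF.as[Return]  // hypothetical returns data

// Net sales per region = total sales minus total returns
val netByRegion = salesDS
  .groupByKey(_.region)
  .cogroup(returnsDS.groupByKey(_.region)) { (region, sales, returns) =>
    Iterator((region, sales.map(_.amount).sum - returns.map(_.amount).sum))
  }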

Key Transformations

// Group by a derived key (note: keyAs only reinterprets the key's type; it takes no function)
val byFirstLetter = salesDS
  .groupByKey(_.region.substring(0, 1))
  .count()

// Transform values while maintaining grouping
val withNormalizedAmounts = salesDS
  .groupByKey(_.region)
  .mapValues(sale => sale.copy(amount = sale.amount / 1000.0))
  .agg(typed.sum(_.amount))

Window Functions and Analytics

import org.apache.spark.sql.expressions.Window

object Window {
  def partitionBy(cols: Column*): WindowSpec
  def partitionBy(colNames: String*): WindowSpec
  def orderBy(cols: Column*): WindowSpec
  def orderBy(colNames: String*): WindowSpec
  def rowsBetween(start: Long, end: Long): WindowSpec
  def rangeBetween(start: Long, end: Long): WindowSpec
  
  val unboundedPreceding: Long
  val unboundedFollowing: Long
  val currentRow: Long
}

class WindowSpec {
  def partitionBy(cols: Column*): WindowSpec
  def orderBy(cols: Column*): WindowSpec
  def rowsBetween(start: Long, end: Long): WindowSpec
  def rangeBetween(start: Long, end: Long): WindowSpec
}

Window Aggregations

Usage Examples:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

// Basic window specifications
val partitionWindow = Window.partitionBy("department")
val orderWindow = Window.partitionBy("department").orderBy(col("salary").desc)
val frameWindow = Window
  .partitionBy("department")
  .orderBy("hire_date")
  .rowsBetween(Window.unboundedPreceding, Window.currentRow)

// Ranking functions
val withRanking = df
  .withColumn("salary_rank", rank().over(orderWindow))
  .withColumn("salary_dense_rank", dense_rank().over(orderWindow))
  .withColumn("row_number", row_number().over(orderWindow))
  .withColumn("percentile", percent_rank().over(orderWindow))

// Running aggregations
val runningTotals = df
  .withColumn("running_total", sum("salary").over(frameWindow))
  .withColumn("running_avg", avg("salary").over(frameWindow))
  .withColumn("running_count", count("*").over(frameWindow))

// Lead and lag for comparisons
val withOffsets = df
  .withColumn("prev_salary", lag("salary", 1).over(orderWindow))
  .withColumn("next_salary", lead("salary", 1).over(orderWindow))
  .withColumn("salary_change", 
    col("salary") - lag("salary", 1).over(orderWindow))

// Moving averages
val movingAvgWindow = Window
  .partitionBy("stock_symbol")
  .orderBy("date")
  .rowsBetween(-6, 0)  // 7-row window (a 7-day moving average when there is one row per day)

val withMovingAvg = df
  .withColumn("price_7day_avg", avg("close_price").over(movingAvgWindow))
  .withColumn("volume_7day_avg", avg("volume").over(movingAvgWindow))

// Cumulative operations
val cumulativeWindow = Window
  .partitionBy("customer_id")
  .orderBy("order_date")
  .rowsBetween(Window.unboundedPreceding, Window.currentRow)

val withCumulative = df
  .withColumn("cumulative_spend", sum("order_amount").over(cumulativeWindow))
  .withColumn("order_sequence", row_number().over(cumulativeWindow))

Advanced Window Functions

// N-tile for bucketing
val quartileWindow = Window.partitionBy("department").orderBy("performance_score")
val withQuartiles = df
  .withColumn("performance_quartile", ntile(4).over(quartileWindow))

// Cumulative distribution
val withCumeDist = df
  .withColumn("cum_dist", cume_dist().over(orderWindow))

// Range-based windows
val rangeWindow = Window
  .partitionBy("category")
  .orderBy("price")
  .rangeBetween(-100, 100)  // Within $100 of current price

val priceComparisons = df
  .withColumn("similar_price_count", count("*").over(rangeWindow))
  .withColumn("similar_price_avg", avg("rating").over(rangeWindow))

// Multiple window specs for complex analytics
val salesAnalytics = df
  .withColumn("monthly_total", sum("sales").over(
    Window.partitionBy("region", "year", "month")))
  .withColumn("ytd_total", sum("sales").over(
    Window.partitionBy("region", "year").orderBy("month")
      .rowsBetween(Window.unboundedPreceding, Window.currentRow)))
  .withColumn("region_rank", rank().over(
    Window.partitionBy("year", "month").orderBy(col("sales").desc)))

Pivot and Unpivot Operations

Complex Pivot Examples

// Multi-level pivot
val multiPivot = df
  .groupBy("region", "year")
  .pivot("quarter", Seq("Q1", "Q2", "Q3", "Q4"))
  .agg(
    sum("revenue").alias("revenue"),
    sum("profit").alias("profit"),
    count("transactions").alias("tx_count")
  )

// Conditional pivot
val conditionalPivot = df
  .groupBy("product")
  .pivot("status")
  .agg(
    sum(when(col("type") === "sale", col("amount")).otherwise(0)).alias("sales"),
    sum(when(col("type") === "return", col("amount")).otherwise(0)).alias("returns")
  )

// Cross-tab analysis
val crossTab = df.stat.crosstab("education_level", "income_bracket")

Unpivot Operations (Manual)

// Manual unpivot using union
val unpivoted = df
  .select(col("id"), lit("Q1").alias("quarter"), col("Q1").alias("value"))
  .union(df.select(col("id"), lit("Q2").alias("quarter"), col("Q2").alias("value")))
  .union(df.select(col("id"), lit("Q3").alias("quarter"), col("Q3").alias("value")))
  .union(df.select(col("id"), lit("Q4").alias("quarter"), col("Q4").alias("value")))

// Using melt pattern with stack function
val melted = df.selectExpr(
  "id",
  "stack(4, 'Q1', Q1, 'Q2', Q2, 'Q3', Q3, 'Q4', Q4) as (quarter, value)"
)
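
On Spark 3.4 and later, Dataset.unpivot (also exposed as melt) does this directly. A minimal sketch, assuming the same id/Q1-Q4 layout:

val unpivotedBuiltin = df.unpivot(
  Array(col("id")),                                   // identifier columns
  Array(col("Q1"), col("Q2"), col("Q3"), col("Q4")),  // columns to unpivot
  "quarter",                                          // variable column name
  "value"                                             // value column name
)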

Custom Aggregation Functions

User-Defined Aggregate Functions (UDAF)

import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.sql.Row

// Example: Geometric Mean UDAF
class GeometricMean extends UserDefinedAggregateFunction {
  // Input schema
  def inputSchema: StructType = StructType(StructField("value", DoubleType) :: Nil)
  
  // Buffer schema
  def bufferSchema: StructType = StructType(Array(
    StructField("logSum", DoubleType),
    StructField("count", LongType)
  ))
  
  // Output type
  def dataType: DataType = DoubleType
  
  // Deterministic
  def deterministic: Boolean = true
  
  // Initialize buffer
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0.0  // logSum
    buffer(1) = 0L   // count
  }
  
  // Update buffer with input
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      val value = input.getDouble(0)
      if (value > 0) {
        buffer(0) = buffer.getDouble(0) + math.log(value)
        buffer(1) = buffer.getLong(1) + 1
      }
    }
  }
  
  // Merge two buffers
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }
  
  // Compute final result
  def evaluate(buffer: Row): Any = {
    val count = buffer.getLong(1)
    if (count > 0) {
      math.exp(buffer.getDouble(0) / count)
    } else {
      null
    }
  }
}

// Register and use UDAF
val geometricMean = new GeometricMean()
spark.udf.register("geometric_mean", geometricMean)

val result = df
  .groupBy("category")
  .agg(geometricMean(col("value")).alias("geom_mean"))
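
UserDefinedAggregateFunction is deprecated as of Spark 3.0; the recommended replacement is a typed Aggregator registered through functions.udaf. A minimal sketch of the same geometric mean (names are illustrative):

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.udaf

case class GeomMeanBuffer(logSum: Double, count: Long)

object GeometricMeanAgg extends Aggregator[Double, GeomMeanBuffer, Double] {
  def zero: GeomMeanBuffer = GeomMeanBuffer(0.0, 0L)
  def reduce(b: GeomMeanBuffer, value: Double): GeomMeanBuffer =
    if (value > 0) GeomMeanBuffer(b.logSum + math.log(value), b.count + 1) else b
  def merge(b1: GeomMeanBuffer, b2: GeomMeanBuffer): GeomMeanBuffer =
    GeomMeanBuffer(b1.logSum + b2.logSum, b1.count + b2.count)
  def finish(b: GeomMeanBuffer): Double =
    if (b.count > 0) math.exp(b.logSum / b.count) else Double.NaN  // NaN instead of null for a primitive result
  def bufferEncoder: Encoder[GeomMeanBuffer] = Encoders.product[GeomMeanBuffer]
  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

// Register for use in untyped aggregations
val geomMean = udaf(GeometricMeanAgg)
spark.udf.register("geometric_mean", geomMean)

val aggregatorResult = df
  .groupBy("category")
  .agg(geomMean(col("value")).alias("geom_mean"))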

Performance Optimization for Aggregations

Optimization Techniques

// Pre-aggregate before expensive operations
val preAggregated = df
  .filter(col("date") >= "2023-01-01")  // Filter early
  .groupBy("region", "product")
  .agg(sum("sales").alias("total_sales"))
  .cache()  // Cache intermediate result

// Use broadcast for small dimension tables
val regionMapping = spark.table("region_mapping")

val enrichedAgg = df
  .join(broadcast(regionMapping), "region_id")
  .groupBy("region_name")
  .agg(sum("sales"))

// Bucketing for frequent aggregations
df.write
  .bucketBy(10, "category")
  .sortBy("date")
  .saveAsTable("bucketed_sales")

// The bucketed table avoids a shuffle for category-based aggregations
val bucketedAgg = spark.table("bucketed_sales")
  .groupBy("category")
  .agg(sum("amount"))

// Broadcast join hint applied to the dimension table before aggregating
val hintedAgg = df
  .join(regionMapping.hint("broadcast"), "region_id")
  .groupBy("region")
  .agg(sum("sales"))
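
To verify that these optimizations take effect, inspect the physical plan. A brief sketch (the configuration value shown is only an example):

// Two HashAggregate nodes (partial + final) show that Spark pre-aggregates before the shuffle
preAggregated.explain()

// The shuffle partition count often needs tuning for high-cardinality group-by keys
spark.conf.set("spark.sql.shuffle.partitions", "400")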