Spark SQL provides powerful aggregation capabilities through both untyped DataFrame operations (RelationalGroupedDataset) and type-safe Dataset operations (KeyValueGroupedDataset). These enable complex analytical operations including grouping, pivoting, window functions, and custom aggregations.
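The examples below assume a live SparkSession and a sample DataFrame named df. A minimal setup sketch (the session settings, column names, and sample rows are illustrative assumptions, not a required schema):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

// Local session for experimentation only
val spark = SparkSession.builder()
  .appName("aggregation-examples")
  .master("local[*]")
  .getOrCreate()
import spark.implicits._

// Small sales-like DataFrame with the kinds of columns the snippets reference
val df = Seq(
  ("North", "Electronics", 1200.0, 3, "2023-01-15"),
  ("South", "Clothing", 450.0, 5, "2023-02-20"),
  ("North", "Clothing", 300.0, 2, "2023-03-05")
).toDF("region", "category", "amount", "quantity", "date")
  .withColumn("date", to_date(col("date")))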
class RelationalGroupedDataset {
def agg(expr: Column, exprs: Column*): DataFrame
def agg(exprs: Map[String, String]): DataFrame
def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame
def count(): DataFrame
def sum(colNames: String*): DataFrame
def avg(colNames: String*): DataFrame
def mean(colNames: String*): DataFrame
def max(colNames: String*): DataFrame
def min(colNames: String*): DataFrame
def pivot(pivotColumn: String): RelationalGroupedDataset
def pivot(pivotColumn: String, values: Seq[Any]): RelationalGroupedDataset
def pivot(pivotColumn: String, values: java.util.List[Any]): RelationalGroupedDataset
}

Usage Examples:
import org.apache.spark.sql.functions._
// Group by single column
val salesByRegion = df
.groupBy("region")
.agg(
sum("amount").alias("total_sales"),
avg("amount").alias("avg_sale"),
count("*").alias("num_transactions"),
max("date").alias("latest_date")
)
// Group by multiple columns
val departmentStats = df
.groupBy("department", "location")
.agg(
count("employee_id").alias("employee_count"),
avg("salary").alias("avg_salary"),
sum("salary").alias("total_payroll")
)
// Using Map for aggregations
val aggregations = Map(
"sales" -> "sum",
"quantity" -> "avg",
"price" -> "max"
)
val mapAgg = df.groupBy("category").agg(aggregations)
// Using tuple syntax
val tupleAgg = df
.groupBy("region")
.agg(("sales", "sum"), ("quantity", "avg"), ("price", "max"))// Multiple aggregations with conditions
val conditionalAgg = df
.groupBy("department")
.agg(
sum("salary").alias("total_salary"),
sum(when(col("gender") === "F", col("salary")).otherwise(0)).alias("female_salary"),
sum(when(col("gender") === "M", col("salary")).otherwise(0)).alias("male_salary"),
countDistinct("employee_id").alias("unique_employees"),
collect_list("name").alias("employee_names"),
collect_set("role").alias("unique_roles")
)
// Statistical aggregations
val statistics = df
.groupBy("category")
.agg(
count("*").alias("count"),
sum("value").alias("sum"),
avg("value").alias("mean"),
stddev("value").alias("std_dev"),
variance("value").alias("variance"),
min("value").alias("min_value"),
max("value").alias("max_value"),
skewness("value").alias("skewness"),
kurtosis("value").alias("kurtosis")
)
// Percentiles using expr()
val percentiles = df
.groupBy("region")
.agg(
expr("percentile_approx(sales, 0.5)").alias("median_sales"),
expr("percentile_approx(sales, array(0.25, 0.75))").alias("quartiles")
)

// Basic pivot
val pivoted = df
.groupBy("region")
.pivot("product_category")
.sum("sales")
// Pivot with specific values (more efficient)
val efficientPivot = df
.groupBy("year", "quarter")
.pivot("region", Seq("North", "South", "East", "West"))
.agg(sum("revenue"), avg("profit_margin"))
// Multiple aggregations with pivot
val multiAggPivot = df
.groupBy("department")
.pivot("quarter")
.agg(
sum("sales").alias("total_sales"),
count("*").alias("transaction_count")
)
// Pivot with dynamic values
val dynamicPivot = df
.groupBy("region")
.pivot("product_category")
.agg(
sum("sales"),
avg("price"),
countDistinct("customer_id")
)

class KeyValueGroupedDataset[K, V] {
def agg[U1: Encoder](column: TypedColumn[V, U1]): Dataset[(K, U1)]
def agg[U1: Encoder, U2: Encoder](col1: TypedColumn[V, U1], col2: TypedColumn[V, U2]): Dataset[(K, U1, U2)]
def agg[U1: Encoder, U2: Encoder, U3: Encoder](col1: TypedColumn[V, U1], col2: TypedColumn[V, U2], col3: TypedColumn[V, U3]): Dataset[(K, U1, U2, U3)]
def agg[U1: Encoder, U2: Encoder, U3: Encoder, U4: Encoder](col1: TypedColumn[V, U1], col2: TypedColumn[V, U2], col3: TypedColumn[V, U3], col4: TypedColumn[V, U4]): Dataset[(K, U1, U2, U3, U4)]
def count(): Dataset[(K, Long)]
def keys: Dataset[K]
def keyAs[L: Encoder]: KeyValueGroupedDataset[L, V]
def mapValues[W: Encoder](func: V => W): KeyValueGroupedDataset[K, W]
def mapGroups[U: Encoder](f: (K, Iterator[V]) => U): Dataset[U]
def flatMapGroups[U: Encoder](f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U]
def reduceGroups(f: (V, V) => V): Dataset[(K, V)]
def mapGroupsWithState[S: Encoder, U: Encoder](timeoutConf: GroupStateTimeout)(func: (K, Iterator[V], GroupState[S]) => U): Dataset[U]
def flatMapGroupsWithState[S: Encoder, U: Encoder](outputMode: OutputMode, timeoutConf: GroupStateTimeout)(func: (K, Iterator[V], GroupState[S]) => Iterator[U]): Dataset[U]
def cogroup[U, R: Encoder](other: KeyValueGroupedDataset[K, U])(f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R]
}

Usage Examples:
import org.apache.spark.sql.expressions.scalalang.typed
case class Sale(region: String, amount: Double, quantity: Int, date: java.sql.Date)
case class Employee(department: String, salary: Double, age: Int)
val salesDS: Dataset[Sale] = df.as[Sale]
val employeeDS: Dataset[Employee] = df.as[Employee]
// Type-safe grouping and aggregation
val regionSales = salesDS
.groupByKey(_.region)
.agg(
typed.sum(_.amount).name("total_amount"),
typed.avg(_.amount).name("avg_amount"),
typed.count(_.quantity).name("transaction_count")
)
// Multiple typed aggregations
val deptStats = employeeDS
.groupByKey(_.department)
.agg(
typed.avg(_.salary).name("avg_salary"),
typed.sum(_.salary).name("total_salary"),
typed.count(_.age).name("employee_count"),
max(col("age")).as[Int].name("max_age") // scalalang.typed has no max; fall back to an untyped column with .as[Int]
)
// Simple count
val counts = salesDS
.groupByKey(_.region)
.count()

// Map groups to custom results
val customSummary = salesDS
.groupByKey(_.region)
.mapGroups { (region, sales) =>
val salesList = sales.toList
val total = salesList.map(_.amount).sum
val count = salesList.size
val avgAmount = if (count > 0) total / count else 0.0
(region, total, count, avgAmount)
}
// Reduce within groups
val topSaleByRegion = salesDS
.groupByKey(_.region)
.reduceGroups((sale1, sale2) => if (sale1.amount > sale2.amount) sale1 else sale2)
// FlatMap groups for multiple results per group
val quarterlySales = salesDS
.groupByKey(_.region)
.flatMapGroups { (region, sales) =>
val salesList = sales.toList
val quarters = salesList.groupBy(s => getQuarter(s.date))
quarters.map { case (quarter, qSales) =>
(region, quarter, qSales.map(_.amount).sum)
}
}
def getQuarter(date: java.sql.Date): Int = {
val cal = java.util.Calendar.getInstance()
cal.setTime(date)
(cal.get(java.util.Calendar.MONTH) / 3) + 1
}

// Derive a new grouping key (keyAs only re-types the existing key; it does not transform it)
val byFirstLetter = salesDS
.groupByKey(_.region.substring(0, 1))
.count()
// Transform values while maintaining grouping
val withNormalizedAmounts = salesDS
.groupByKey(_.region)
.mapValues(sale => sale.copy(amount = sale.amount / 1000.0))
.agg(typed.sum(_.amount))
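The cogroup method listed above has no example in this section; a minimal sketch, assuming a second illustrative Dataset of returns keyed by the same region (the Return case class and its fields are hypothetical):

// Hypothetical companion dataset for the cogroup sketch
case class Return(region: String, refund: Double)
val returnsDS: Dataset[Return] = Seq(Return("North", 50.0), Return("South", 20.0)).toDS()

// Net revenue per region: total sales minus total refunds, computed group by group
val netByRegion = salesDS
  .groupByKey(_.region)
  .cogroup(returnsDS.groupByKey(_.region)) { (region, sales, returns) =>
    Iterator((region, sales.map(_.amount).sum - returns.map(_.refund).sum))
  }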
object Window {
def partitionBy(cols: Column*): WindowSpec
def partitionBy(colNames: String*): WindowSpec
def orderBy(cols: Column*): WindowSpec
def orderBy(colNames: String*): WindowSpec
def rowsBetween(start: Long, end: Long): WindowSpec
def rangeBetween(start: Long, end: Long): WindowSpec
val unboundedPreceding: Long
val unboundedFollowing: Long
val currentRow: Long
}
class WindowSpec {
def partitionBy(cols: Column*): WindowSpec
def orderBy(cols: Column*): WindowSpec
def rowsBetween(start: Long, end: Long): WindowSpec
def rangeBetween(start: Long, end: Long): WindowSpec
}

Usage Examples:
import org.apache.spark.sql.expressions.Window
// Basic window specifications
val partitionWindow = Window.partitionBy("department")
val orderWindow = Window.partitionBy("department").orderBy(col("salary").desc)
val frameWindow = Window
.partitionBy("department")
.orderBy("hire_date")
.rowsBetween(Window.unboundedPreceding, Window.currentRow)
// Ranking functions
val withRanking = df
.withColumn("salary_rank", rank().over(orderWindow))
.withColumn("salary_dense_rank", dense_rank().over(orderWindow))
.withColumn("row_number", row_number().over(orderWindow))
.withColumn("percentile", percent_rank().over(orderWindow))
// Running aggregations
val runningTotals = df
.withColumn("running_total", sum("salary").over(frameWindow))
.withColumn("running_avg", avg("salary").over(frameWindow))
.withColumn("running_count", count("*").over(frameWindow))
// Lead and lag for comparisons
val withOffsets = df
.withColumn("prev_salary", lag("salary", 1).over(orderWindow))
.withColumn("next_salary", lead("salary", 1).over(orderWindow))
.withColumn("salary_change",
col("salary") - lag("salary", 1).over(orderWindow))
// Moving averages
val movingAvgWindow = Window
.partitionBy("stock_symbol")
.orderBy("date")
.rowsBetween(-6, 0) // 7-day moving average
val withMovingAvg = df
.withColumn("price_7day_avg", avg("close_price").over(movingAvgWindow))
.withColumn("volume_7day_avg", avg("volume").over(movingAvgWindow))
// Cumulative operations
val cumulativeWindow = Window
.partitionBy("customer_id")
.orderBy("order_date")
.rowsBetween(Window.unboundedPreceding, Window.currentRow)
val withCumulative = df
.withColumn("cumulative_spend", sum("order_amount").over(cumulativeWindow))
.withColumn("order_sequence", row_number().over(cumulativeWindow))// N-tile for bucketing
val quartileWindow = Window.partitionBy("department").orderBy("performance_score")
val withQuartiles = df
.withColumn("performance_quartile", ntile(4).over(quartileWindow))
// Cumulative distribution
val withCumeDist = df
.withColumn("cum_dist", cume_dist().over(orderWindow))
// Range-based windows
val rangeWindow = Window
.partitionBy("category")
.orderBy("price")
.rangeBetween(-100, 100) // Within $100 of current price
val priceComparisons = df
.withColumn("similar_price_count", count("*").over(rangeWindow))
.withColumn("similar_price_avg", avg("rating").over(rangeWindow))
// Multiple window specs for complex analytics
val salesAnalytics = df
.withColumn("monthly_total", sum("sales").over(
Window.partitionBy("region", "year", "month")))
.withColumn("ytd_total", sum("sales").over(
Window.partitionBy("region", "year").orderBy("month")
.rowsBetween(Window.unboundedPreceding, Window.currentRow)))
.withColumn("region_rank", rank().over(
Window.partitionBy("year", "month").orderBy(col("sales").desc)))// Multi-level pivot
val multiPivot = df
.groupBy("region", "year")
.pivot("quarter", Seq("Q1", "Q2", "Q3", "Q4"))
.agg(
sum("revenue").alias("revenue"),
sum("profit").alias("profit"),
count("transactions").alias("tx_count")
)
// Conditional pivot
val conditionalPivot = df
.groupBy("product")
.pivot("status")
.agg(
sum(when(col("type") === "sale", col("amount")).otherwise(0)).alias("sales"),
sum(when(col("type") === "return", col("amount")).otherwise(0)).alias("returns")
)
// Cross-tab analysis
val crossTab = df.stat.crosstab("education_level", "income_bracket")

// Manual unpivot using union
val unpivoted = df.select("id", "Q1", "Q2", "Q3", "Q4")
.select(
col("id"),
lit("Q1").alias("quarter"),
col("Q1").alias("value")
)
.union(
df.select(col("id"), lit("Q2").alias("quarter"), col("Q2").alias("value"))
)
.union(
df.select(col("id"), lit("Q3").alias("quarter"), col("Q3").alias("value"))
)
.union(
df.select(col("id"), lit("Q4").alias("quarter"), col("Q4").alias("value"))
)
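On Spark 3.4 or later, the built-in Dataset.unpivot (also exposed as melt) replaces this boilerplate; a sketch under that version assumption:

// Requires Spark 3.4+; id column, value columns, and output names mirror the manual version above
val unpivotedNative = df.unpivot(
  Array(col("id")),
  Array(col("Q1"), col("Q2"), col("Q3"), col("Q4")),
  "quarter",
  "value"
)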
// Using melt pattern with stack function
val melted = df.selectExpr(
"id",
"stack(4, 'Q1', Q1, 'Q2', Q2, 'Q3', Q3, 'Q4', Q4) as (quarter, value)"
)

import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.sql.Row
// Example: Geometric Mean UDAF
class GeometricMean extends UserDefinedAggregateFunction {
// Input schema
def inputSchema: StructType = StructType(StructField("value", DoubleType) :: Nil)
// Buffer schema
def bufferSchema: StructType = StructType(Array(
StructField("logSum", DoubleType),
StructField("count", LongType)
))
// Output type
def dataType: DataType = DoubleType
// Deterministic
def deterministic: Boolean = true
// Initialize buffer
def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer(0) = 0.0 // logSum
buffer(1) = 0L // count
}
// Update buffer with input
def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
if (!input.isNullAt(0)) {
val value = input.getDouble(0)
if (value > 0) {
buffer(0) = buffer.getDouble(0) + math.log(value)
buffer(1) = buffer.getLong(1) + 1
}
}
}
// Merge two buffers
def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
}
// Compute final result
def evaluate(buffer: Row): Any = {
val count = buffer.getLong(1)
if (count > 0) {
math.exp(buffer.getDouble(0) / count)
} else {
null
}
}
}
// Register and use UDAF
val geometricMean = new GeometricMean()
spark.udf.register("geometric_mean", geometricMean)
val result = df
.groupBy("category")
.agg(geometricMean(col("value")).alias("geom_mean"))
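UserDefinedAggregateFunction is deprecated since Spark 3.0; on 3.0+ the same logic is usually written as an Aggregator and wrapped with functions.udaf. A minimal sketch of the geometric mean in that style (object and variable names are illustrative; unlike the UDAF above, null inputs are not special-cased here):

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.udaf

// Buffer is (sum of logs, count of positive values)
object GeometricMeanAggregator extends Aggregator[Double, (Double, Long), Double] {
  def zero: (Double, Long) = (0.0, 0L)
  def reduce(b: (Double, Long), value: Double): (Double, Long) =
    if (value > 0) (b._1 + math.log(value), b._2 + 1) else b
  def merge(b1: (Double, Long), b2: (Double, Long)): (Double, Long) =
    (b1._1 + b2._1, b1._2 + b2._2)
  def finish(b: (Double, Long)): Double =
    if (b._2 > 0) math.exp(b._1 / b._2) else Double.NaN // primitive result, so NaN stands in for null
  def bufferEncoder: Encoder[(Double, Long)] = Encoders.tuple(Encoders.scalaDouble, Encoders.scalaLong)
  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

val geomMean = udaf(GeometricMeanAggregator)
val resultV2 = df.groupBy("category").agg(geomMean(col("value")).alias("geom_mean"))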
// Pre-aggregate before expensive operations
val preAggregated = df
.filter(col("date") >= "2023-01-01") // Filter early
.groupBy("region", "product")
.agg(sum("sales").alias("total_sales"))
.cache() // Cache intermediate result
// Use broadcast for small dimension tables
val regionMapping = spark.table("region_mapping")
// Note: broadcast() returns a new hinted DataFrame; apply it at the join site below rather than as a standalone call
val enrichedAgg = df
.join(broadcast(regionMapping), "region_id")
.groupBy("region_name")
.agg(sum("sales"))
// Bucketing for frequent aggregations
df.write
.bucketBy(10, "category")
.sortBy("date")
.saveAsTable("bucketed_sales")
// Grouping by the bucketing column lets Spark skip the shuffle when aggregating this table
val bucketedAgg = spark.table("bucketed_sales")
.groupBy("category")
.agg(sum("amount"))
// Broadcast hint attached directly to the dimension table (equivalent to the broadcast() function)
val partialAgg = df
.join(regionMapping.hint("broadcast"), "region_id")
.groupBy("region")
.agg(sum("sales"))