# Aggregations and Grouping

Spark SQL provides powerful aggregation capabilities through both untyped DataFrame operations (RelationalGroupedDataset) and type-safe Dataset operations (KeyValueGroupedDataset). These enable complex analytical operations including grouping, pivoting, window functions, and custom aggregations.
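
The two entry points differ in the handle they return: `groupBy` yields an untyped `RelationalGroupedDataset`, while `groupByKey` yields a `KeyValueGroupedDataset` keyed by a Scala value. A minimal contrast sketch, assuming a DataFrame `df` with `region` and `amount` columns and the `Sale` case class defined later on this page:

```scala
// Untyped: columns are addressed by name, results are DataFrames
val untypedTotals = df.groupBy("region").agg(sum("amount"))

// Type-safe: the key is extracted with a function and rows are typed objects
val typedCounts = df.as[Sale].groupByKey(_.region).count()
```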

## RelationalGroupedDataset (Untyped Aggregations)

```scala { .api }
class RelationalGroupedDataset {
  def agg(expr: Column, exprs: Column*): DataFrame
  def agg(exprs: Map[String, String]): DataFrame
  def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame

  def count(): DataFrame
  def sum(colNames: String*): DataFrame
  def avg(colNames: String*): DataFrame
  def mean(colNames: String*): DataFrame
  def max(colNames: String*): DataFrame
  def min(colNames: String*): DataFrame

  def pivot(pivotColumn: String): RelationalGroupedDataset
  def pivot(pivotColumn: String, values: Seq[Any]): RelationalGroupedDataset
  def pivot(pivotColumn: String, values: java.util.List[Any]): RelationalGroupedDataset
}
```

### Basic Grouping and Aggregation

**Usage Examples:**

```scala
import org.apache.spark.sql.functions._

// Group by single column
val salesByRegion = df
  .groupBy("region")
  .agg(
    sum("amount").alias("total_sales"),
    avg("amount").alias("avg_sale"),
    count("*").alias("num_transactions"),
    max("date").alias("latest_date")
  )

// Group by multiple columns
val departmentStats = df
  .groupBy("department", "location")
  .agg(
    count("employee_id").alias("employee_count"),
    avg("salary").alias("avg_salary"),
    sum("salary").alias("total_payroll")
  )

// Using Map for aggregations
val aggregations = Map(
  "sales" -> "sum",
  "quantity" -> "avg",
  "price" -> "max"
)
val mapAgg = df.groupBy("category").agg(aggregations)

// Using tuple syntax
val tupleAgg = df
  .groupBy("region")
  .agg(("sales", "sum"), ("quantity", "avg"), ("price", "max"))
```
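
The `Map` and tuple forms apply built-in functions by name and produce default column names such as `sum(sales)`. A small illustrative follow-up, assuming the `mapAgg` result above:

```scala
// Rename the generated columns if downstream code needs stable names
val renamedAgg = mapAgg
  .withColumnRenamed("sum(sales)", "total_sales")
  .withColumnRenamed("avg(quantity)", "avg_quantity")
  .withColumnRenamed("max(price)", "max_price")
```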

### Advanced Aggregations

```scala
// Multiple aggregations with conditions
val conditionalAgg = df
  .groupBy("department")
  .agg(
    sum("salary").alias("total_salary"),
    sum(when(col("gender") === "F", col("salary")).otherwise(0)).alias("female_salary"),
    sum(when(col("gender") === "M", col("salary")).otherwise(0)).alias("male_salary"),
    countDistinct("employee_id").alias("unique_employees"),
    collect_list("name").alias("employee_names"),
    collect_set("role").alias("unique_roles")
  )

// Statistical aggregations
val statistics = df
  .groupBy("category")
  .agg(
    count("*").alias("count"),
    sum("value").alias("sum"),
    avg("value").alias("mean"),
    stddev("value").alias("std_dev"),
    variance("value").alias("variance"),
    min("value").alias("min_value"),
    max("value").alias("max_value"),
    skewness("value").alias("skewness"),
    kurtosis("value").alias("kurtosis")
  )

// Percentiles using expr()
val percentiles = df
  .groupBy("region")
  .agg(
    expr("percentile_approx(sales, 0.5)").alias("median_sales"),
    expr("percentile_approx(sales, array(0.25, 0.75))").alias("quartiles")
  )
```

### Pivot Operations

```scala
// Basic pivot
val pivoted = df
  .groupBy("region")
  .pivot("product_category")
  .sum("sales")

// Pivot with specific values (more efficient)
val efficientPivot = df
  .groupBy("year", "quarter")
  .pivot("region", Seq("North", "South", "East", "West"))
  .agg(sum("revenue"), avg("profit_margin"))

// Multiple aggregations with pivot
val multiAggPivot = df
  .groupBy("department")
  .pivot("quarter")
  .agg(
    sum("sales").alias("total_sales"),
    count("*").alias("transaction_count")
  )

// Pivot with dynamic values
val dynamicPivot = df
  .groupBy("region")
  .pivot("product_category")
  .agg(
    sum("sales"),
    avg("price"),
    countDistinct("customer_id")
  )
```
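
When pivot values are not supplied, Spark first runs an extra job to discover the distinct values of the pivot column (capped by `spark.sql.pivotMaxValues`). A sketch of making that step explicit, assuming the same `df` and an implicits import in scope for `.as[String]`:

```scala
// Collect the distinct pivot values once, then pass them in to skip the discovery job
val categories = df.select("product_category").distinct().as[String].collect().toSeq
val explicitPivot = df
  .groupBy("region")
  .pivot("product_category", categories)
  .sum("sales")
```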

## KeyValueGroupedDataset (Type-Safe Aggregations)

```scala { .api }
class KeyValueGroupedDataset[K, V] {
  def agg[U1: Encoder](column: TypedColumn[V, U1]): Dataset[(K, U1)]
  def agg[U1: Encoder, U2: Encoder](col1: TypedColumn[V, U1], col2: TypedColumn[V, U2]): Dataset[(K, U1, U2)]
  def agg[U1: Encoder, U2: Encoder, U3: Encoder](col1: TypedColumn[V, U1], col2: TypedColumn[V, U2], col3: TypedColumn[V, U3]): Dataset[(K, U1, U2, U3)]
  def agg[U1: Encoder, U2: Encoder, U3: Encoder, U4: Encoder](col1: TypedColumn[V, U1], col2: TypedColumn[V, U2], col3: TypedColumn[V, U3], col4: TypedColumn[V, U4]): Dataset[(K, U1, U2, U3, U4)]

  def count(): Dataset[(K, Long)]
  def keys: Dataset[K]
  def keyAs[L: Encoder]: KeyValueGroupedDataset[L, V]
  def mapValues[W: Encoder](func: V => W): KeyValueGroupedDataset[K, W]

  def mapGroups[U: Encoder](f: (K, Iterator[V]) => U): Dataset[U]
  def flatMapGroups[U: Encoder](f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U]
  def reduceGroups(f: (V, V) => V): Dataset[(K, V)]

  def mapGroupsWithState[S: Encoder, U: Encoder](timeoutConf: GroupStateTimeout)(func: (K, Iterator[V], GroupState[S]) => U): Dataset[U]
  def flatMapGroupsWithState[S: Encoder, U: Encoder](outputMode: OutputMode, timeoutConf: GroupStateTimeout)(func: (K, Iterator[V], GroupState[S]) => Iterator[U]): Dataset[U]

  def cogroup[U, R: Encoder](other: KeyValueGroupedDataset[K, U])(f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R]
}
```

### Type-Safe Aggregations

**Usage Examples:**

```scala
import org.apache.spark.sql.expressions.scalalang.typed

case class Sale(region: String, amount: Double, quantity: Int, date: java.sql.Date)
case class Employee(department: String, salary: Double, age: Int)

val salesDS: Dataset[Sale] = df.as[Sale]
val employeeDS: Dataset[Employee] = df.as[Employee]

// Type-safe grouping and aggregation
val regionSales = salesDS
  .groupByKey(_.region)
  .agg(
    typed.sum(_.amount).name("total_amount"),
    typed.avg(_.amount).name("avg_amount"),
    typed.count(_.quantity).name("transaction_count")
  )

// Multiple typed aggregations (the typed helpers cover avg, count, sum and sumLong)
val deptStats = employeeDS
  .groupByKey(_.department)
  .agg(
    typed.avg(_.salary).name("avg_salary"),
    typed.sum(_.salary).name("total_salary"),
    typed.count(_.age).name("employee_count")
  )

// Simple count
val counts = salesDS
  .groupByKey(_.region)
  .count()
```
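
The `typed` helpers do not include `max` or `min`. Where another typed aggregation is needed, a custom `Aggregator` can be turned into a `TypedColumn` with `toColumn`; a minimal sketch for the maximum age per department (names are illustrative):

```scala
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

// Typed max over Employee.age: the buffer and the output are plain Ints
val maxAge = new Aggregator[Employee, Int, Int] {
  def zero: Int = Int.MinValue
  def reduce(b: Int, e: Employee): Int = math.max(b, e.age)
  def merge(b1: Int, b2: Int): Int = math.max(b1, b2)
  def finish(b: Int): Int = b
  def bufferEncoder: Encoder[Int] = Encoders.scalaInt
  def outputEncoder: Encoder[Int] = Encoders.scalaInt
}.toColumn.name("max_age")

val deptMaxAge = employeeDS.groupByKey(_.department).agg(maxAge)
```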

### Custom Group Operations

```scala
// Map groups to custom results
val customSummary = salesDS
  .groupByKey(_.region)
  .mapGroups { (region, sales) =>
    val salesList = sales.toList
    val total = salesList.map(_.amount).sum
    val count = salesList.size
    val avgAmount = if (count > 0) total / count else 0.0
    (region, total, count, avgAmount)
  }

// Reduce within groups
val topSaleByRegion = salesDS
  .groupByKey(_.region)
  .reduceGroups((sale1, sale2) => if (sale1.amount > sale2.amount) sale1 else sale2)

// FlatMap groups for multiple results per group
val quarterlySales = salesDS
  .groupByKey(_.region)
  .flatMapGroups { (region, sales) =>
    val salesList = sales.toList
    val quarters = salesList.groupBy(s => getQuarter(s.date))
    quarters.map { case (quarter, qSales) =>
      (region, quarter, qSales.map(_.amount).sum)
    }
  }

def getQuarter(date: java.sql.Date): Int = {
  val cal = java.util.Calendar.getInstance()
  cal.setTime(date)
  (cal.get(java.util.Calendar.MONTH) / 3) + 1
}
```
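
`cogroup` from the API above processes the matching groups of two datasets together. A minimal sketch; the `Return` case class and the source of `returnsDS` are hypothetical:

```scala
case class Return(region: String, amount: Double)
val returnsDS: Dataset[Return] = spark.read.parquet("/data/returns").as[Return] // hypothetical path

// Net sales per region = sales minus returns, computed group by group
val netSalesByRegion = salesDS
  .groupByKey(_.region)
  .cogroup(returnsDS.groupByKey(_.region)) { (region, sales, returns) =>
    Iterator((region, sales.map(_.amount).sum - returns.map(_.amount).sum))
  }
```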

### Key Transformations

```scala
// Group by a derived key (first letter of the region)
val byFirstLetter = salesDS
  .groupByKey(_.region.substring(0, 1))
  .count()

// Transform values while maintaining grouping
val withNormalizedAmounts = salesDS
  .groupByKey(_.region)
  .mapValues(sale => sale.copy(amount = sale.amount / 1000.0))
  .agg(typed.sum(_.amount))
```
254
255
## Window Functions and Analytics
256
257
```scala { .api }
258
import org.apache.spark.sql.expressions.Window
259
260
object Window {
261
def partitionBy(cols: Column*): WindowSpec
262
def partitionBy(colNames: String*): WindowSpec
263
def orderBy(cols: Column*): WindowSpec
264
def orderBy(colNames: String*): WindowSpec
265
def rowsBetween(start: Long, end: Long): WindowSpec
266
def rangeBetween(start: Long, end: Long): WindowSpec
267
268
val unboundedPreceding: Long
269
val unboundedFollowing: Long
270
val currentRow: Long
271
}
272
273
class WindowSpec {
274
def partitionBy(cols: Column*): WindowSpec
275
def orderBy(cols: Column*): WindowSpec
276
def rowsBetween(start: Long, end: Long): WindowSpec
277
def rangeBetween(start: Long, end: Long): WindowSpec
278
}
279
```

### Window Aggregations

**Usage Examples:**

```scala
import org.apache.spark.sql.expressions.Window

// Basic window specifications
val partitionWindow = Window.partitionBy("department")
val orderWindow = Window.partitionBy("department").orderBy(col("salary").desc)
val frameWindow = Window
  .partitionBy("department")
  .orderBy("hire_date")
  .rowsBetween(Window.unboundedPreceding, Window.currentRow)

// Ranking functions
val withRanking = df
  .withColumn("salary_rank", rank().over(orderWindow))
  .withColumn("salary_dense_rank", dense_rank().over(orderWindow))
  .withColumn("row_number", row_number().over(orderWindow))
  .withColumn("percentile", percent_rank().over(orderWindow))

// Running aggregations
val runningTotals = df
  .withColumn("running_total", sum("salary").over(frameWindow))
  .withColumn("running_avg", avg("salary").over(frameWindow))
  .withColumn("running_count", count("*").over(frameWindow))

// Lead and lag for comparisons
val withOffsets = df
  .withColumn("prev_salary", lag("salary", 1).over(orderWindow))
  .withColumn("next_salary", lead("salary", 1).over(orderWindow))
  .withColumn("salary_change",
    col("salary") - lag("salary", 1).over(orderWindow))

// Moving averages
val movingAvgWindow = Window
  .partitionBy("stock_symbol")
  .orderBy("date")
  .rowsBetween(-6, 0) // 7-row window: a 7-day moving average when there is one row per day

val withMovingAvg = df
  .withColumn("price_7day_avg", avg("close_price").over(movingAvgWindow))
  .withColumn("volume_7day_avg", avg("volume").over(movingAvgWindow))

// Cumulative operations
val cumulativeWindow = Window
  .partitionBy("customer_id")
  .orderBy("order_date")
  .rowsBetween(Window.unboundedPreceding, Window.currentRow)

val withCumulative = df
  .withColumn("cumulative_spend", sum("order_amount").over(cumulativeWindow))
  .withColumn("order_sequence", row_number().over(cumulativeWindow))
```

### Advanced Window Functions

```scala
// N-tile for bucketing
val quartileWindow = Window.partitionBy("department").orderBy("performance_score")
val withQuartiles = df
  .withColumn("performance_quartile", ntile(4).over(quartileWindow))

// Cumulative distribution
val withCumeDist = df
  .withColumn("cum_dist", cume_dist().over(orderWindow))

// Range-based windows
val rangeWindow = Window
  .partitionBy("category")
  .orderBy("price")
  .rangeBetween(-100, 100) // Within $100 of the current price

val priceComparisons = df
  .withColumn("similar_price_count", count("*").over(rangeWindow))
  .withColumn("similar_price_avg", avg("rating").over(rangeWindow))

// Multiple window specs for complex analytics
val salesAnalytics = df
  .withColumn("monthly_total", sum("sales").over(
    Window.partitionBy("region", "year", "month")))
  .withColumn("ytd_total", sum("sales").over(
    Window.partitionBy("region", "year").orderBy("month")
      .rowsBetween(Window.unboundedPreceding, Window.currentRow)))
  .withColumn("region_rank", rank().over(
    Window.partitionBy("year", "month").orderBy(col("sales").desc)))
```

## Pivot and Unpivot Operations

### Complex Pivot Examples

```scala
// Multi-level pivot
val multiPivot = df
  .groupBy("region", "year")
  .pivot("quarter", Seq("Q1", "Q2", "Q3", "Q4"))
  .agg(
    sum("revenue").alias("revenue"),
    sum("profit").alias("profit"),
    count("transactions").alias("tx_count")
  )

// Conditional pivot
val conditionalPivot = df
  .groupBy("product")
  .pivot("status")
  .agg(
    sum(when(col("type") === "sale", col("amount")).otherwise(0)).alias("sales"),
    sum(when(col("type") === "return", col("amount")).otherwise(0)).alias("returns")
  )

// Cross-tab analysis (crosstab lives on DataFrameStatFunctions)
val crossTab = df.stat.crosstab("education_level", "income_bracket")
```

### Unpivot Operations (Manual)

```scala
// Manual unpivot using union
val unpivoted = df.select("id", "Q1", "Q2", "Q3", "Q4")
  .select(
    col("id"),
    lit("Q1").alias("quarter"),
    col("Q1").alias("value")
  )
  .union(
    df.select(col("id"), lit("Q2").alias("quarter"), col("Q2").alias("value"))
  )
  .union(
    df.select(col("id"), lit("Q3").alias("quarter"), col("Q3").alias("value"))
  )
  .union(
    df.select(col("id"), lit("Q4").alias("quarter"), col("Q4").alias("value"))
  )

// Using the melt pattern with the stack function
val melted = df.selectExpr(
  "id",
  "stack(4, 'Q1', Q1, 'Q2', Q2, 'Q3', Q3, 'Q4', Q4) as (quarter, value)"
)
```
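
On Spark 3.4 and later, `Dataset.unpivot` (alias `melt`) does this directly; a sketch, assuming the same quarterly columns:

```scala
// ids stay as identifier columns; the listed value columns are melted into (quarter, value)
val unpivotedBuiltIn = df.unpivot(
  Array(col("id")),
  Array(col("Q1"), col("Q2"), col("Q3"), col("Q4")),
  "quarter",
  "value"
)
```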

## Custom Aggregation Functions

### User-Defined Aggregate Functions (UDAF)

```scala
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.sql.Row

// Example: Geometric Mean UDAF
class GeometricMean extends UserDefinedAggregateFunction {
  // Input schema
  def inputSchema: StructType = StructType(StructField("value", DoubleType) :: Nil)

  // Buffer schema
  def bufferSchema: StructType = StructType(Array(
    StructField("logSum", DoubleType),
    StructField("count", LongType)
  ))

  // Output type
  def dataType: DataType = DoubleType

  // Deterministic
  def deterministic: Boolean = true

  // Initialize buffer
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0.0 // logSum
    buffer(1) = 0L  // count
  }

  // Update buffer with input
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      val value = input.getDouble(0)
      if (value > 0) {
        buffer(0) = buffer.getDouble(0) + math.log(value)
        buffer(1) = buffer.getLong(1) + 1
      }
    }
  }

  // Merge two buffers
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  // Compute final result
  def evaluate(buffer: Row): Any = {
    val count = buffer.getLong(1)
    if (count > 0) {
      math.exp(buffer.getDouble(0) / count)
    } else {
      null
    }
  }
}

// Register and use UDAF
val geometricMean = new GeometricMean()
spark.udf.register("geometric_mean", geometricMean)

val result = df
  .groupBy("category")
  .agg(geometricMean(col("value")).alias("geom_mean"))
```
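
`UserDefinedAggregateFunction` is deprecated as of Spark 3.0 in favor of `Aggregator` registered through `functions.udaf`. A minimal sketch of the same geometric mean on that API (Spark 3.x only):

```scala
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.udaf

// Buffer is (sum of logs, count); only positive inputs contribute
object GeometricMeanAgg extends Aggregator[Double, (Double, Long), Double] {
  def zero: (Double, Long) = (0.0, 0L)
  def reduce(b: (Double, Long), value: Double): (Double, Long) =
    if (value > 0) (b._1 + math.log(value), b._2 + 1) else b
  def merge(b1: (Double, Long), b2: (Double, Long)): (Double, Long) =
    (b1._1 + b2._1, b1._2 + b2._2)
  def finish(b: (Double, Long)): Double =
    if (b._2 > 0) math.exp(b._1 / b._2) else Double.NaN
  def bufferEncoder: Encoder[(Double, Long)] =
    Encoders.tuple(Encoders.scalaDouble, Encoders.scalaLong)
  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

val geomMean = udaf(GeometricMeanAgg)
val result3x = df.groupBy("category").agg(geomMean(col("value")).alias("geom_mean"))
```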

## Performance Optimization for Aggregations

### Optimization Techniques

```scala
// Pre-aggregate before expensive operations
val preAggregated = df
  .filter(col("date") >= "2023-01-01") // Filter early
  .groupBy("region", "product")
  .agg(sum("sales").alias("total_sales"))
  .cache() // Cache intermediate result

// Use broadcast for small dimension tables
val regionMapping = spark.table("region_mapping")

val enrichedAgg = df
  .join(broadcast(regionMapping), "region_id")
  .groupBy("region_name")
  .agg(sum("sales"))

// Bucketing for frequent aggregations
df.write
  .bucketBy(10, "category")
  .sortBy("date")
  .saveAsTable("bucketed_sales")

// The bucketed table avoids a shuffle for category-based aggregations
val bucketedAgg = spark.table("bucketed_sales")
  .groupBy("category")
  .agg(sum("amount"))

// The same broadcast can be requested with a join hint on the small side
val hintedAgg = df
  .join(regionMapping.hint("broadcast"), "region_id")
  .groupBy("region_name")
  .agg(sum("sales"))
```
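
Aggregation shuffles are also sensitive to parallelism settings. A short illustrative sketch (the values shown are placeholders, not recommendations):

```scala
// Number of partitions used for shuffles produced by groupBy/agg
spark.conf.set("spark.sql.shuffle.partitions", "400")

// On Spark 3.x, adaptive query execution can coalesce small shuffle partitions automatically
spark.conf.set("spark.sql.adaptive.enabled", "true")
```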