or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

aggregations.mdcatalog.mddata-io.mddataframe-dataset.mdfunctions-expressions.mdindex.mdsession-management.mdstreaming.md

aggregations.mddocs/

0

# Aggregations and Grouping

1

2

Spark SQL provides aggregation through two APIs:
- untyped DataFrame operations (`RelationalGroupedDataset`, returned by `groupBy`);
- type-safe Dataset operations (`KeyValueGroupedDataset`, returned by `groupByKey`).

Together with window specifications (`Window`/`WindowSpec`), these APIs support grouping, pivoting, windowed analytics, and custom aggregations.

3

4

## RelationalGroupedDataset (Untyped Aggregations)

5

6

```scala { .api }

/** Result of `Dataset.groupBy`: untyped, column-based aggregation per group. */
class RelationalGroupedDataset {

  // ---- Generic aggregation entry points ------------------------------------

  /** Applies one or more aggregate `Column` expressions to every group. */
  def agg(expr: Column, exprs: Column*): DataFrame

  /** Aggregates via a column-name -> function-name map, e.g. `Map("sales" -> "sum")`. */
  def agg(exprs: Map[String, String]): DataFrame

  /** Same as the `Map` overload, but takes `(columnName, functionName)` pairs. */
  def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame

  // ---- Shorthand aggregations ----------------------------------------------

  /** Number of rows in each group. */
  def count(): DataFrame

  def sum(colNames: String*): DataFrame
  def avg(colNames: String*): DataFrame
  /** Alias of `avg`. */
  def mean(colNames: String*): DataFrame
  def max(colNames: String*): DataFrame
  def min(colNames: String*): DataFrame

  // ---- Pivoting --------------------------------------------------------------

  /** Pivots on `pivotColumn`; Spark first runs a job to find its distinct values. */
  def pivot(pivotColumn: String): RelationalGroupedDataset

  /** Pivots on `pivotColumn` using only the listed values (no distinct-values job). */
  def pivot(pivotColumn: String, values: Seq[Any]): RelationalGroupedDataset

  /** Java-friendly form of `pivot(String, Seq[Any])`. */
  def pivot(pivotColumn: String, values: java.util.List[Any]): RelationalGroupedDataset
}

```

24

25

### Basic Grouping and Aggregation

26

27

**Usage Examples:**

28

29

```scala

import org.apache.spark.sql.functions._

// One grouping column, several aggregate expressions.
val salesByRegion = df
  .groupBy("region")
  .agg(
    sum("amount").alias("total_sales"),
    avg("amount").alias("avg_sale"),
    count("*").alias("num_transactions"),
    max("date").alias("latest_date")
  )

// Several grouping columns: one output row per (department, location) pair.
val departmentStats = df
  .groupBy("department", "location")
  .agg(
    count("employee_id").alias("employee_count"),
    avg("salary").alias("avg_salary"),
    sum("salary").alias("total_payroll")
  )

// Column name -> aggregate function name. Output columns get generated names
// such as "sum(sales)", and a Map can hold only one function per column.
val aggregations = Map(
  "sales"    -> "sum",
  "quantity" -> "avg",
  "price"    -> "max"
)
val mapAgg = df.groupBy("category").agg(aggregations)

// The same idea with (column, function) pairs, written with `->`.
val tupleAgg = df
  .groupBy("region")
  .agg("sales" -> "sum", "quantity" -> "avg", "price" -> "max")

```

64

65

### Advanced Aggregations

66

67

```scala

// Conditional aggregation. Without `otherwise`, `when` yields null for rows
// that do not match; `otherwise(0)` makes those rows contribute 0 instead.
val isFemale = col("gender") === "F"
val isMale = col("gender") === "M"

val conditionalAgg = df
  .groupBy("department")
  .agg(
    sum("salary").alias("total_salary"),
    sum(when(isFemale, col("salary")).otherwise(0)).alias("female_salary"),
    sum(when(isMale, col("salary")).otherwise(0)).alias("male_salary"),
    countDistinct("employee_id").alias("unique_employees"),
    collect_list("name").alias("employee_names"), // keeps duplicates
    collect_set("role").alias("unique_roles")     // de-duplicated
  )

// Descriptive statistics per category.
// In Spark, stddev and variance are the sample versions (stddev_samp / var_samp).
val statistics = df
  .groupBy("category")
  .agg(
    count("*").alias("count"),
    sum("value").alias("sum"),
    avg("value").alias("mean"),
    stddev("value").alias("std_dev"),
    variance("value").alias("variance"),
    min("value").alias("min_value"),
    max("value").alias("max_value"),
    skewness("value").alias("skewness"),
    kurtosis("value").alias("kurtosis")
  )

// Approximate percentiles through a SQL expression string. The array form
// returns one array column holding a value per requested percentage.
val percentiles = df
  .groupBy("region")
  .agg(
    expr("percentile_approx(sales, 0.5)").alias("median_sales"),
    expr("percentile_approx(sales, array(0.25, 0.75))").alias("quartiles")
  )

```

103

104

### Pivot Operations

105

106

```scala

// Pivot without a value list: Spark first runs a job to collect the distinct
// values of "product_category", then creates one output column per value.
val pivoted = df
  .groupBy("region")
  .pivot("product_category")
  .sum("sales")

// Pivot with an explicit value list: no distinct-values job, and the set of
// output columns is fixed up front (more efficient).
val regions = Seq("North", "South", "East", "West")
val efficientPivot = df
  .groupBy("year", "quarter")
  .pivot("region", regions)
  .agg(sum("revenue"), avg("profit_margin"))

// Several aggregates: one output column per (pivot value, aggregate) combination.
val multiAggPivot = df
  .groupBy("department")
  .pivot("quarter")
  .agg(
    sum("sales").alias("total_sales"),
    count("*").alias("transaction_count")
  )

// Pivot values discovered at runtime again (extra job), with three aggregates.
val dynamicPivot = df
  .groupBy("region")
  .pivot("product_category")
  .agg(
    sum("sales"),
    avg("price"),
    countDistinct("customer_id")
  )

```

138

139

## KeyValueGroupedDataset (Type-Safe Aggregations)

140

141

```scala { .api }

/** Result of `Dataset.groupByKey`: typed, per-key operations over values of type V. */
class KeyValueGroupedDataset[K, V] {

  // ---- Typed aggregation ---------------------------------------------------
  // The result encoders come from the TypedColumn arguments, so these
  // overloads take no implicit Encoder parameters.
  def agg[U1](col1: TypedColumn[V, U1]): Dataset[(K, U1)]
  def agg[U1, U2](col1: TypedColumn[V, U1], col2: TypedColumn[V, U2]): Dataset[(K, U1, U2)]
  def agg[U1, U2, U3](
      col1: TypedColumn[V, U1],
      col2: TypedColumn[V, U2],
      col3: TypedColumn[V, U3]): Dataset[(K, U1, U2, U3)]
  def agg[U1, U2, U3, U4](
      col1: TypedColumn[V, U1],
      col2: TypedColumn[V, U2],
      col3: TypedColumn[V, U3],
      col4: TypedColumn[V, U4]): Dataset[(K, U1, U2, U3, U4)]

  /** Number of values per key. */
  def count(): Dataset[(K, Long)]

  /** Each distinct key. */
  def keys: Dataset[K]

  /**
   * Re-types the existing key with another Encoder (same rules as `Dataset.as`).
   * It does not compute a new key; derive new keys in `groupByKey` instead.
   */
  def keyAs[L: Encoder]: KeyValueGroupedDataset[L, V]

  /** Applies `func` to every value while keeping the grouping. */
  def mapValues[W: Encoder](func: V => W): KeyValueGroupedDataset[K, W]

  // ---- Arbitrary per-group logic --------------------------------------------
  def mapGroups[U: Encoder](f: (K, Iterator[V]) => U): Dataset[U]
  def flatMapGroups[U: Encoder](f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U]
  def reduceGroups(f: (V, V) => V): Dataset[(K, V)]

  // ---- Stateful processing (Structured Streaming) ---------------------------
  def mapGroupsWithState[S: Encoder, U: Encoder](timeoutConf: GroupStateTimeout)(
      func: (K, Iterator[V], GroupState[S]) => U): Dataset[U]
  def flatMapGroupsWithState[S: Encoder, U: Encoder](
      outputMode: OutputMode,
      timeoutConf: GroupStateTimeout)(
      func: (K, Iterator[V], GroupState[S]) => Iterator[U]): Dataset[U]

  // ---- Co-grouping with another grouped Dataset sharing the key type ---------
  def cogroup[U, R: Encoder](other: KeyValueGroupedDataset[K, U])(
      f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R]
}

```

163

164

### Type-Safe Aggregations

165

166

**Usage Examples:**

167

168

```scala

import org.apache.spark.sql.Dataset
import org.apache.spark.sql.functions
// Import the `typed` object itself. `import ...scalalang.typed._` only brings
// its members (avg, count, sum, sumLong) into scope, so `typed.sum(...)`
// would not resolve.
// NOTE: scalalang.typed is deprecated since Spark 3.0. The replacement is the
// untyped functions combined with `.as[T]`, as used for max_age below.
import org.apache.spark.sql.expressions.scalalang.typed
// NOTE(review): assumes a SparkSession named `spark` (as in spark-shell).
// It provides the encoders for df.as[...] and .as[Int].
import spark.implicits._

case class Sale(region: String, amount: Double, quantity: Int)
case class Employee(department: String, salary: Double, age: Int)

val salesDS: Dataset[Sale] = df.as[Sale]
val employeeDS: Dataset[Employee] = df.as[Employee]

// Type-safe grouping and aggregation. The explicit [Sale] type arguments are
// required because calling .name(...) on the result means the lambda's
// parameter type cannot be inferred from agg's expected argument type.
val regionSales = salesDS
  .groupByKey(_.region)
  .agg(
    typed.sum[Sale](_.amount).name("total_amount"),
    typed.avg[Sale](_.amount).name("avg_amount"),
    typed.count[Sale](_.quantity).name("transaction_count")
  )

// Multiple typed aggregations. scalalang.typed has no `max`, so max_age uses
// the untyped max turned into a TypedColumn with .as[Int].
val deptStats = employeeDS
  .groupByKey(_.department)
  .agg(
    typed.avg[Employee](_.salary).name("avg_salary"),
    typed.sum[Employee](_.salary).name("total_salary"),
    typed.count[Employee](_.age).name("employee_count"),
    functions.max("age").as[Int].name("max_age")
  )

// Simple count: Dataset[(String, Long)]
val counts = salesDS
  .groupByKey(_.region)
  .count()

```

201

202

### Custom Group Operations

203

204

```scala

import java.sql.Date
import org.apache.spark.sql.Dataset

// Quarter (1-4) of a date. toLocalDate interprets the date in the JVM default
// time zone, as the previous Calendar-based version did.
// Defined before its first use so the snippet also runs top-to-bottom in spark-shell.
def getQuarter(date: Date): Int = (date.toLocalDate.getMonthValue - 1) / 3 + 1

// Map groups to custom results. A single pass over the group's iterator avoids
// materializing the whole group (e.g. via toList) in executor memory, which
// matters for skewed keys.
val customSummary = salesDS
  .groupByKey(_.region)
  .mapGroups { (region, sales) =>
    val (total, count) = sales.foldLeft((0.0, 0)) { case ((sum, n), sale) =>
      (sum + sale.amount, n + 1)
    }
    val avgAmount = if (count > 0) total / count else 0.0
    (region, total, count, avgAmount)
  }

// Reduce within groups: keep the largest sale per region.
val topSaleByRegion = salesDS
  .groupByKey(_.region)
  .reduceGroups((sale1, sale2) => if (sale1.amount > sale2.amount) sale1 else sale2)

// FlatMap groups for multiple results per group.
// `Sale` has no date field, so this example uses its own record type.
// NOTE(review): assumes df has a `date` column of DATE type; adjust to your schema.
case class DatedSale(region: String, amount: Double, date: Date)
val datedSalesDS: Dataset[DatedSale] = df.as[DatedSale]

val quarterlySales = datedSalesDS
  .groupByKey(_.region)
  .flatMapGroups { (region, sales) =>
    // Accumulate per-quarter totals in one pass; at most four map entries.
    val totalsByQuarter = sales.foldLeft(Map.empty[Int, Double]) { (acc, sale) =>
      val quarter = getQuarter(sale.date)
      acc.updated(quarter, acc.getOrElse(quarter, 0.0) + sale.amount)
    }
    totalsByQuarter.map { case (quarter, total) => (region, quarter, total) }
  }

```

238

239

### Key Transformations

240

241

```scala

import org.apache.spark.sql.expressions.scalalang.typed

// Group by a derived key. `keyAs` only re-types an existing key through an
// Encoder and takes no function, so compute the new key in groupByKey instead.
// take(1), unlike substring(0, 1), also tolerates an empty region string.
val byFirstLetter = salesDS
  .groupByKey(_.region.take(1))
  .count()

// Transform values while keeping the grouping, then aggregate the new values.
val withNormalizedAmounts = salesDS
  .groupByKey(_.region)
  .mapValues(sale => sale.copy(amount = sale.amount / 1000.0))
  .agg(typed.sum[Sale](_.amount))

```

254

255

## Window Functions and Analytics

256

257

```scala { .api }

import org.apache.spark.sql.expressions.Window

/** Factory for window specifications. */
object Window {
  // The String overloads take a leading non-vararg name so that they do not
  // clash with the Column* overloads after erasure.
  def partitionBy(colName: String, colNames: String*): WindowSpec
  def partitionBy(cols: Column*): WindowSpec
  def orderBy(colName: String, colNames: String*): WindowSpec
  def orderBy(cols: Column*): WindowSpec

  /** Frame bounds counted in rows relative to the current row. */
  def rowsBetween(start: Long, end: Long): WindowSpec
  /** Frame bounds expressed in units of the ORDER BY value. */
  def rangeBetween(start: Long, end: Long): WindowSpec

  // Frame boundary sentinels: Long.MinValue, Long.MaxValue and 0, respectively.
  def unboundedPreceding: Long
  def unboundedFollowing: Long
  def currentRow: Long
}

/** A window definition that can be refined further and passed to `Column.over`. */
class WindowSpec {
  def partitionBy(colName: String, colNames: String*): WindowSpec
  def partitionBy(cols: Column*): WindowSpec
  def orderBy(colName: String, colNames: String*): WindowSpec
  def orderBy(cols: Column*): WindowSpec
  def rowsBetween(start: Long, end: Long): WindowSpec
  def rangeBetween(start: Long, end: Long): WindowSpec
}

```

280

281

### Window Aggregations

282

283

**Usage Examples:**

284

285

```scala

import org.apache.spark.sql.expressions.Window
// col, rank, lag, sum, ... below come from functions._
import org.apache.spark.sql.functions._

// Basic window specifications
val partitionWindow = Window.partitionBy("department")
val orderWindow = Window.partitionBy("department").orderBy(col("salary").desc)
val frameWindow = Window
  .partitionBy("department")
  .orderBy("hire_date")
  .rowsBetween(Window.unboundedPreceding, Window.currentRow)

// Ranking functions over the salary-descending window
val withRanking = df
  .withColumn("salary_rank", rank().over(orderWindow))             // ties share a rank; gaps follow ties
  .withColumn("salary_dense_rank", dense_rank().over(orderWindow)) // ties share a rank; no gaps
  .withColumn("row_number", row_number().over(orderWindow))        // 1..n; order among ties is not deterministic
  .withColumn("percentile", percent_rank().over(orderWindow))      // (rank - 1) / (rows in partition - 1)

// Running aggregations from the first hire up to the current row
val runningTotals = df
  .withColumn("running_total", sum("salary").over(frameWindow))
  .withColumn("running_avg", avg("salary").over(frameWindow))
  .withColumn("running_count", count("*").over(frameWindow))

// Lead and lag for comparisons. Because the window orders salaries in
// descending order, "prev" is the next-higher salary in the department.
// lag yields null on the first row of each partition.
val prevSalary = lag("salary", 1).over(orderWindow)
val withOffsets = df
  .withColumn("prev_salary", prevSalary)
  .withColumn("next_salary", lead("salary", 1).over(orderWindow))
  .withColumn("salary_change", col("salary") - prevSalary)

// Moving averages. A ROWS frame counts rows, not days: this is a 7-day
// average only when there is exactly one row per symbol per date. If dates
// can be missing or repeated, use a RANGE frame over a numeric day value.
val movingAvgWindow = Window
  .partitionBy("stock_symbol")
  .orderBy("date")
  .rowsBetween(-6, Window.currentRow) // current row plus the 6 preceding rows

val withMovingAvg = df
  .withColumn("price_7day_avg", avg("close_price").over(movingAvgWindow))
  .withColumn("volume_7day_avg", avg("volume").over(movingAvgWindow))

// Cumulative operations
val cumulativeWindow = Window
  .partitionBy("customer_id")
  .orderBy("order_date")
  .rowsBetween(Window.unboundedPreceding, Window.currentRow)

val withCumulative = df
  .withColumn("cumulative_spend", sum("order_amount").over(cumulativeWindow))
  .withColumn("order_sequence", row_number().over(cumulativeWindow))

```

336

337

### Advanced Window Functions

338

339

```scala

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

// N-tile for bucketing: splits each department into 4 roughly equal groups.
val quartileWindow = Window.partitionBy("department").orderBy("performance_score")
val withQuartiles = df
  .withColumn("performance_quartile", ntile(4).over(quartileWindow))

// Cumulative distribution. With descending salary order, this is the
// fraction of the department's rows whose salary is >= the current one.
val salaryDescWindow = Window.partitionBy("department").orderBy(col("salary").desc)
val withCumeDist = df
  .withColumn("cum_dist", cume_dist().over(salaryDescWindow))

// Range-based windows: offsets are in units of the ORDER BY value ("price"
// must be numeric), so the frame covers prices within 100 of the current price.
val rangeWindow = Window
  .partitionBy("category")
  .orderBy("price")
  .rangeBetween(-100, 100)

val priceComparisons = df
  .withColumn("similar_price_count", count("*").over(rangeWindow))
  .withColumn("similar_price_avg", avg("rating").over(rangeWindow))

// Multiple window specs for complex analytics
val salesAnalytics = df
  .withColumn("monthly_total", sum("sales").over(
    Window.partitionBy("region", "year", "month")))
  // RANGE (not ROWS) frame: every row of a month gets the full year-to-date
  // total through that month. With ROWS, rows that tie on "month" would get
  // different, order-dependent partial totals.
  .withColumn("ytd_total", sum("sales").over(
    Window.partitionBy("region", "year").orderBy("month")
      .rangeBetween(Window.unboundedPreceding, Window.currentRow)))
  // Rank regions by their monthly total. All rows of one region-month share
  // monthly_total, and dense_rank keeps region ranks consecutive (1, 2, 3, ...).
  .withColumn("region_rank", dense_rank().over(
    Window.partitionBy("year", "month").orderBy(col("monthly_total").desc)))

```

369

370

## Pivot and Unpivot Operations

371

372

### Complex Pivot Examples

373

374

```scala

// Several grouping columns, plus several aggregates per pivot value
val multiPivot = df
  .groupBy("region", "year")
  .pivot("quarter", Seq("Q1", "Q2", "Q3", "Q4"))
  .agg(
    sum("revenue").alias("revenue"),
    sum("profit").alias("profit"),
    count("transactions").alias("tx_count") // counts non-null "transactions" values
  )

// Conditional pivot: split each status column into sale and return amounts
val conditionalPivot = df
  .groupBy("product")
  .pivot("status")
  .agg(
    sum(when(col("type") === "sale", col("amount")).otherwise(0)).alias("sales"),
    sum(when(col("type") === "return", col("amount")).otherwise(0)).alias("returns")
  )

// Cross-tab analysis: crosstab is defined on DataFrameStatFunctions (df.stat),
// not on the DataFrame itself.
val crossTab = df.stat.crosstab("education_level", "income_bracket")

```

397

398

### Unpivot Operations (Manual; Spark 3.4+ also provides built-in `Dataset.unpivot` / `melt`)

399

400

```scala

// Manual unpivot: one narrow select per quarter, stacked with union.
// union matches columns by position, so every select lists them in the same order.
val quarterColumns = Seq("Q1", "Q2", "Q3", "Q4")
val unpivoted = quarterColumns
  .map(q => df.select(col("id"), lit(q).alias("quarter"), col(q).alias("value")))
  .reduce(_ union _)

// Melt pattern with the SQL stack function: a single pass producing 4 rows
// per input row. Q1..Q4 should share one data type.
val melted = df.selectExpr(
  "id",
  "stack(4, 'Q1', Q1, 'Q2', Q2, 'Q3', Q3, 'Q4', Q4) as (quarter, value)"
)

```

424

425

## Custom Aggregation Functions

426

427

### User-Defined Aggregate Functions (UDAF)

428

429

```scala

import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.sql.Row

/**
 * Geometric mean of a DOUBLE input: exp(mean(log(x))).
 *
 * Nulls and non-positive values are skipped. A group without any positive
 * value evaluates to null.
 *
 * NOTE: UserDefinedAggregateFunction is deprecated since Spark 3.0 in favour
 * of an `Aggregator` registered through `functions.udaf`.
 */
class GeometricMean extends UserDefinedAggregateFunction {
  // Slot positions inside the aggregation buffer
  private val LogSum = 0
  private val Count = 1

  override def inputSchema: StructType = StructType(Seq(StructField("value", DoubleType)))

  override def bufferSchema: StructType = StructType(Seq(
    StructField("logSum", DoubleType),
    StructField("count", LongType)
  ))

  override def dataType: DataType = DoubleType

  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(LogSum) = 0.0
    buffer(Count) = 0L
  }

  // Adds log(x) for each non-null, strictly positive input; other inputs are ignored.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val usable = !input.isNullAt(0) && input.getDouble(0) > 0
    if (usable) {
      buffer(LogSum) = buffer.getDouble(LogSum) + math.log(input.getDouble(0))
      buffer(Count) = buffer.getLong(Count) + 1
    }
  }

  // Combines partial buffers by summing both slots.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(LogSum) = buffer1.getDouble(LogSum) + buffer2.getDouble(LogSum)
    buffer1(Count) = buffer1.getLong(Count) + buffer2.getLong(Count)
  }

  override def evaluate(buffer: Row): Any = {
    val n = buffer.getLong(Count)
    if (n <= 0) null else math.exp(buffer.getDouble(LogSum) / n)
  }
}

// Register for SQL use as geometric_mean(...), and call it directly from the DataFrame API.
val geometricMean = new GeometricMean()
spark.udf.register("geometric_mean", geometricMean)

val result = df
  .groupBy("category")
  .agg(geometricMean(col("value")).alias("geom_mean"))

```

493

494

## Performance Optimization for Aggregations

495

496

### Optimization Techniques

497

498

```scala

// Filter early, then aggregate. Cache only when the result is reused; cache()
// is lazy and materializes on the first action.
val preAggregated = df
  .filter(col("date") >= "2023-01-01")
  .groupBy("region", "product")
  .agg(sum("sales").alias("total_sales"))
  .cache()

// Broadcast a small dimension table at the join site. broadcast(df) only
// returns a marked copy, so calling it on a line by itself has no effect.
val regionMapping = spark.table("region_mapping")

val enrichedAgg = df
  .join(broadcast(regionMapping), "region_id")
  .groupBy("region_name")
  .agg(sum("sales"))

// Bucketing for frequent aggregations (bucketBy requires saveAsTable)
df.write
  .bucketBy(10, "category")
  .sortBy("date")
  .saveAsTable("bucketed_sales")

// Grouping a bucketed table by its bucket column can avoid a shuffle.
val bucketedAgg = spark.table("bucketed_sales")
  .groupBy("category")
  .agg(sum("amount"))

// Broadcast via Dataset.hint: put the hint on the small side of the join.
// Hint parameters are names, not DataFrames. Partial (map-side) aggregation
// of sum happens automatically, so no hint is needed for it.
val hintedAgg = df
  .join(regionMapping.hint("broadcast"), "region_id")
  .groupBy("region_name")
  .agg(sum("sales"))

```