# Apache Spark SQL - Dataset and DataFrame Operations

## Capabilities

### Dataset and DataFrame Core Operations
- Create strongly typed Datasets and untyped DataFrames from various data sources
- Perform functional transformations using map, filter, flatMap with type safety
- Execute relational operations like select, where, groupBy, join with SQL-like syntax
- Handle schema inference and validation, with compile-time type checking for typed Dataset operations

### Data Transformations and Actions
- Apply lazy transformations that build execution plans without immediate computation
- Execute actions that trigger computation and return results to the driver or external systems
- Chain transformations efficiently through the Catalyst query optimizer
- Handle distributed computation with automatic parallelization and fault tolerance
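
Laziness is worth seeing concretely. A minimal sketch, assuming the `Person` Dataset `ds` built in the usage examples below and `import spark.implicits._`:

```scala
// Transformations only assemble a logical plan; no Spark job runs yet
val plan = ds.filter($"age" > 25).select($"name", $"salary")
plan.explain() // inspect the plan Catalyst produced

// Actions force execution and pull results back to the driver
val matching = plan.count()
val firstTwo = plan.take(2)
```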

### Aggregations and Grouping
- Perform complex aggregations with built-in and custom aggregation functions
- Group data by single or multiple columns with typed and untyped grouping operations
- Execute window functions for analytics and ranking operations
- Handle pivot and unpivot operations for data reshaping and analysis

### Joins and Set Operations
- Join datasets using various join types (inner, outer, left, right, semi, anti)
- Optimize joins with broadcast hints and bucketing strategies
- Perform set operations like union, intersect, except with schema alignment
- Handle complex join conditions with multiple predicates and expressions

## API Reference

### Dataset[T] Class
```scala { .api }
abstract class Dataset[T] extends Serializable {
  // Schema and metadata
  def schema: StructType
  def dtypes: Array[(String, String)]
  def columns: Array[String]
  def printSchema(): Unit
  def explain(): Unit
  def explain(extended: Boolean): Unit
  def explain(mode: String): Unit

  // Basic transformations
  def select(cols: Column*): DataFrame
  def select[U1](c1: TypedColumn[T, U1]): Dataset[U1]
  def select[U1, U2](c1: TypedColumn[T, U1], c2: TypedColumn[T, U2]): Dataset[(U1, U2)]
  def selectExpr(exprs: String*): DataFrame
  def filter(condition: Column): Dataset[T]
  def filter(conditionExpr: String): Dataset[T]
  def where(condition: Column): Dataset[T]
  def where(conditionExpr: String): Dataset[T]

  // Functional transformations
  def map[U : Encoder](func: T => U): Dataset[U]
  def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U]
  def flatMap[U : Encoder](func: T => TraversableOnce[U]): Dataset[U]
  def foreach(f: T => Unit): Unit
  def foreachPartition(f: Iterator[T] => Unit): Unit

  // Sorting and limiting
  def sort(sortExprs: Column*): Dataset[T]
  def sort(sortCol: String, sortCols: String*): Dataset[T]
  def orderBy(sortExprs: Column*): Dataset[T]
  def orderBy(sortCol: String, sortCols: String*): Dataset[T]
  def limit(n: Int): Dataset[T]

  // Sampling and partitioning
  def sample(fraction: Double): Dataset[T]
  def sample(fraction: Double, seed: Long): Dataset[T]
  def sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T]
  def randomSplit(weights: Array[Double]): Array[Dataset[T]]
  def randomSplit(weights: Array[Double], seed: Long): Array[Dataset[T]]
  def repartition(numPartitions: Int): Dataset[T]
  def repartition(partitionExprs: Column*): Dataset[T]
  def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T]
  def coalesce(numPartitions: Int): Dataset[T]

  // Set operations
  def union(other: Dataset[T]): Dataset[T]
  def unionAll(other: Dataset[T]): Dataset[T]
  def unionByName(other: Dataset[T]): Dataset[T]
  def unionByName(other: Dataset[T], allowMissingColumns: Boolean): Dataset[T]
  def intersect(other: Dataset[T]): Dataset[T]
  def intersectAll(other: Dataset[T]): Dataset[T]
  def except(other: Dataset[T]): Dataset[T]
  def exceptAll(other: Dataset[T]): Dataset[T]

  // Joins
  def join(right: Dataset[_]): DataFrame
  def join(right: Dataset[_], joinExprs: Column): DataFrame
  def join(right: Dataset[_], joinExprs: Column, joinType: String): DataFrame
  def join(right: Dataset[_], usingColumns: Seq[String]): DataFrame
  def join(right: Dataset[_], usingColumns: Seq[String], joinType: String): DataFrame
  def crossJoin(right: Dataset[_]): DataFrame

  // Grouping and aggregation
  def groupBy(cols: Column*): RelationalGroupedDataset
  def groupBy(col1: String, cols: String*): RelationalGroupedDataset
  def groupByKey[K : Encoder](func: T => K): KeyValueGroupedDataset[K, T]
  def rollup(cols: Column*): RelationalGroupedDataset
  def rollup(col1: String, cols: String*): RelationalGroupedDataset
  def cube(cols: Column*): RelationalGroupedDataset
  def cube(col1: String, cols: String*): RelationalGroupedDataset
  def agg(expr: Column, exprs: Column*): DataFrame
  def agg(exprs: Map[String, String]): DataFrame

  // Actions
  def show(): Unit
  def show(numRows: Int): Unit
  def show(numRows: Int, truncate: Boolean): Unit
  def show(numRows: Int, truncate: Int): Unit
  def show(numRows: Int, truncate: Int, vertical: Boolean): Unit
  def collect(): Array[T]
  def collectAsList(): java.util.List[T]
  def count(): Long
  def reduce(func: (T, T) => T): T
  def first(): T
  def head(): T
  def head(n: Int): Array[T]
  def take(n: Int): Array[T]
  def takeAsList(n: Int): java.util.List[T]
  def tail(n: Int): Array[T]
  def isEmpty: Boolean
  def localCheckpoint(): Dataset[T]
  def localCheckpoint(eager: Boolean): Dataset[T]
  def checkpoint(): Dataset[T]
  def checkpoint(eager: Boolean): Dataset[T]

  // Column operations
  def apply(colName: String): Column
  def col(colName: String): Column
  def withColumn(colName: String, col: Column): DataFrame
  def withColumnRenamed(existingName: String, newName: String): DataFrame
  def withColumns(colsMap: Map[String, Column]): DataFrame
  def withColumnsRenamed(colsMap: Map[String, String]): DataFrame
  def drop(colName: String): DataFrame
  def drop(colNames: String*): DataFrame
  def drop(col: Column): DataFrame
  def dropDuplicates(): Dataset[T]
  def dropDuplicates(colNames: Array[String]): Dataset[T]
  def dropDuplicates(colNames: Seq[String]): Dataset[T]
  def dropDuplicates(col1: String, cols: String*): Dataset[T]
  def distinct(): Dataset[T]

  // Type conversions
  def as[U : Encoder]: Dataset[U]
  def to(schema: StructType): DataFrame
  def toDF(): DataFrame
  def toDF(colNames: String*): DataFrame
  def transform[U](t: Dataset[T] => Dataset[U]): Dataset[U]

  // Persistence and caching
  def cache(): Dataset[T]
  def persist(): Dataset[T]
  def persist(newLevel: StorageLevel): Dataset[T]
  def unpersist(): Dataset[T]
  def unpersist(blocking: Boolean): Dataset[T]
  def storageLevel: StorageLevel

  // I/O operations
  def write: DataFrameWriter[T]
  def writeStream: DataStreamWriter[T]

  // Statistics and analysis
  def describe(cols: String*): DataFrame
  def summary(statistics: String*): DataFrame
  def stat: DataFrameStatFunctions
  def na: DataFrameNaFunctions
}
```

### DataFrame Type Alias
```scala { .api }
type DataFrame = Dataset[Row]
```

### RelationalGroupedDataset Class
```scala { .api }
class RelationalGroupedDataset protected[sql](
    df: DataFrame,
    groupingExprs: Seq[Expression],
    groupType: RelationalGroupedDataset.GroupType) {

  // Basic aggregations
  def agg(expr: Column, exprs: Column*): DataFrame
  def agg(exprs: Map[String, String]): DataFrame
  def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame
  def count(): DataFrame
  def sum(colNames: String*): DataFrame
  def avg(colNames: String*): DataFrame
  def max(colNames: String*): DataFrame
  def min(colNames: String*): DataFrame
  def mean(colNames: String*): DataFrame

  // Pivot operations
  def pivot(pivotColumn: String): RelationalGroupedDataset
  def pivot(pivotColumn: String, values: Seq[Any]): RelationalGroupedDataset
  def pivot(pivotColumn: String, values: java.util.List[Any]): RelationalGroupedDataset
}
```
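
Both aggregation styles in one sketch, assuming a hypothetical DataFrame `salesDF` with `dept`, `year`, and `salary` columns:

```scala
// Tuple form: (column name, aggregate function name); duplicate columns are allowed
val byDept = salesDF.groupBy("dept").agg("salary" -> "avg", "salary" -> "max")

// Supplying pivot values explicitly skips the extra job that would discover them
val pivoted = salesDF.groupBy("dept").pivot("year", Seq(2023, 2024)).sum("salary")
```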

### KeyValueGroupedDataset[K, V] Class
```scala { .api }
class KeyValueGroupedDataset[K, V] private[sql](
    kEncoder: Encoder[K],
    vEncoder: Encoder[V]) {

  // Typed aggregations
  def agg[U1: Encoder](col1: TypedColumn[V, U1]): Dataset[(K, U1)]
  def agg[U1: Encoder, U2: Encoder](
      col1: TypedColumn[V, U1],
      col2: TypedColumn[V, U2]): Dataset[(K, U1, U2)]
  def agg[U1: Encoder, U2: Encoder, U3: Encoder](
      col1: TypedColumn[V, U1],
      col2: TypedColumn[V, U2],
      col3: TypedColumn[V, U3]): Dataset[(K, U1, U2, U3)]
  def agg[U1: Encoder, U2: Encoder, U3: Encoder, U4: Encoder](
      col1: TypedColumn[V, U1],
      col2: TypedColumn[V, U2],
      col3: TypedColumn[V, U3],
      col4: TypedColumn[V, U4]): Dataset[(K, U1, U2, U3, U4)]

  // Group transformations
  def mapGroups[U: Encoder](f: (K, Iterator[V]) => U): Dataset[U]
  def flatMapGroups[U: Encoder](f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U]
  def mapGroupsWithState[S: Encoder, U: Encoder](
      timeoutConf: GroupStateTimeout)(
      func: (K, Iterator[V], GroupState[S]) => U): Dataset[U]
  def flatMapGroupsWithState[S: Encoder, U: Encoder](
      outputMode: OutputMode,
      timeoutConf: GroupStateTimeout)(
      func: (K, Iterator[V], GroupState[S]) => Iterator[U]): Dataset[U]

  // Group reductions
  def reduceGroups(f: (V, V) => V): Dataset[(K, V)]
  def cogroup[U, R: Encoder](
      other: KeyValueGroupedDataset[K, U])(
      f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R]
}
```
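
A sketch of the typed-grouping methods, assuming the `Person` case class and Dataset `ds` from the usage examples below, plus `import spark.implicits._` for the encoders:

```scala
val byAge = ds.groupByKey(_.age)

// reduceGroups: keep the highest-paid person per age
val topEarners = byAge.reduceGroups((a, b) => if (a.salary >= b.salary) a else b)

// mapGroups: produce one typed result per key (here, a summary string)
val summaries = byAge.mapGroups { (age, people) =>
  s"age=$age count=${people.size}"
}
```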

### Column Class
```scala { .api }
class Column(val expr: Expression) extends Logging {
  // Arithmetic operations
  def +(other: Any): Column
  def -(other: Any): Column
  def *(other: Any): Column
  def /(other: Any): Column
  def %(other: Any): Column

  // Comparison operations
  def ===(other: Any): Column
  def =!=(other: Any): Column
  def >(other: Any): Column
  def <(other: Any): Column
  def >=(other: Any): Column
  def <=(other: Any): Column
  def <=>(other: Any): Column  // null-safe equality

  // Boolean operations
  def &&(other: Column): Column
  def ||(other: Column): Column
  def unary_! : Column

  // Null operations
  def isNull: Column
  def isNotNull: Column
  def isNaN: Column

  // String operations
  def contains(other: Any): Column
  def startsWith(other: Column): Column
  def startsWith(literal: String): Column
  def endsWith(other: Column): Column
  def endsWith(literal: String): Column
  def like(literal: String): Column
  def rlike(literal: String): Column
  def substr(startPos: Column, len: Column): Column
  def substr(startPos: Int, len: Int): Column

  // Type operations
  def cast(to: DataType): Column
  def cast(to: String): Column
  def as(alias: String): Column
  def as(alias: Symbol): Column
  def as(aliases: Array[String]): Column
  def name(alias: String): Column

  // Sorting
  def asc: Column
  def asc_nulls_first: Column
  def asc_nulls_last: Column
  def desc: Column
  def desc_nulls_first: Column
  def desc_nulls_last: Column

  // Collection operations
  def getItem(key: Any): Column
  def getField(fieldName: String): Column

  // Window operations
  def over(): Column
  def over(window: WindowSpec): Column
}
```
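
A small sketch of composing Column expressions, assuming a hypothetical DataFrame `df` with `name`, `city`, and `salary` columns and `import spark.implicits._`:

```scala
import org.apache.spark.sql.functions.lit

// <=> is null-safe equality: it returns false rather than null when one side is null
val seattleOnly = df.filter($"city" <=> lit("Seattle"))

// Chain cast, arithmetic, aliasing, and a null-aware sort order
val adjusted = ($"salary".cast("double") * 1.1).as("adjusted_salary")
val ranked = df.select($"name", adjusted).orderBy($"adjusted_salary".desc_nulls_last)
```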

### Row Trait
```scala { .api }
trait Row extends Serializable {
  // Size and schema
  def size: Int
  def length: Int
  def schema: StructType

  // Generic access
  def get(i: Int): Any
  def getAs[T](i: Int): T
  def getAs[T](fieldName: String): T
  def fieldIndex(name: String): Int
  def getValuesMap[T](fieldNames: Seq[String]): Map[String, T]

  // Type-specific getters
  def isNullAt(i: Int): Boolean
  def getString(i: Int): String
  def getBoolean(i: Int): Boolean
  def getByte(i: Int): Byte
  def getShort(i: Int): Short
  def getInt(i: Int): Int
  def getLong(i: Int): Long
  def getFloat(i: Int): Float
  def getDouble(i: Int): Double
  def getDecimal(i: Int): java.math.BigDecimal
  def getDate(i: Int): java.sql.Date
  def getTimestamp(i: Int): java.sql.Timestamp
  def getInstant(i: Int): java.time.Instant
  def getLocalDate(i: Int): java.time.LocalDate
  def getSeq[T](i: Int): Seq[T]
  def getList[T](i: Int): java.util.List[T]
  def getMap[K, V](i: Int): Map[K, V]
  def getJavaMap[K, V](i: Int): java.util.Map[K, V]
  def getStruct(i: Int): Row

  // Conversion operations
  def copy(): Row
  def toSeq: Seq[Any]
  def mkString: String
  def mkString(sep: String): String
  def mkString(start: String, sep: String, end: String): String
}
```
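
Collected results come back as `Row` values; a minimal sketch of reading them, assuming a DataFrame `df` with `name` and `age` columns (hypothetical):

```scala
import org.apache.spark.sql.Row

val rows: Array[Row] = df.collect()
rows.foreach { row =>
  val name = row.getAs[String]("name")   // access by field name
  val ageIdx = row.fieldIndex("age")
  val age = if (row.isNullAt(ageIdx)) None else Some(row.getInt(ageIdx))
  println(s"$name -> ${age.getOrElse("unknown")}")
}

// Pull several fields at once into a Map
val firstAsMap = rows.head.getValuesMap[Any](Seq("name", "age"))
```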

## Usage Examples

### Creating Datasets and DataFrames
```scala
import org.apache.spark.sql.{SparkSession, Dataset, DataFrame, Row}
import org.apache.spark.sql.types._

// Create SparkSession
val spark = SparkSession.builder().appName("Dataset Examples").getOrCreate()
import spark.implicits._

// Create Dataset from case class
case class Person(name: String, age: Int, salary: Double)
val people = Seq(
  Person("Alice", 25, 50000.0),
  Person("Bob", 30, 60000.0),
  Person("Charlie", 35, 70000.0)
)
val ds: Dataset[Person] = spark.createDataset(people)

// Create DataFrame from Row data
val schema = StructType(Array(
  StructField("name", StringType, nullable = false),
  StructField("age", IntegerType, nullable = false),
  StructField("salary", DoubleType, nullable = false)
))
val rows = Seq(
  Row("Alice", 25, 50000.0),
  Row("Bob", 30, 60000.0),
  Row("Charlie", 35, 70000.0)
)
val df: DataFrame = spark.createDataFrame(spark.sparkContext.parallelize(rows), schema)

// Convert between Dataset and DataFrame
val dataFrame = ds.toDF()
val dataset = df.as[Person]
```
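
Datasets can also come straight from files. A sketch with hypothetical paths; `header` and `inferSchema` are standard CSV reader options:

```scala
// JSON: schema inferred from the data
val jsonDF = spark.read.json("data/people.json")

// CSV: read the header row and infer column types
val csvDF = spark.read
  .option("header", "true")
  .option("inferSchema", "true")
  .csv("data/people.csv")

// Works when the column names and types line up with the case class
val fromFile = csvDF.as[Person]
```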

### Basic Transformations
```scala
import org.apache.spark.sql.functions._

// Select operations
val names = ds.select($"name")
val nameAge = ds.select($"name", $"age")
val computed = ds.select($"name", ($"salary" * 1.1).as("adjusted_salary"))

// Filter operations
val youngPeople = ds.filter($"age" < 30)
val highEarners = ds.where($"salary" > 55000)
val complexFilter = ds.filter($"age" > 25 && $"salary" < 65000)

// Functional transformations
val upperNames = ds.map(person => person.copy(name = person.name.toUpperCase))
val salaryCategories = ds.map { person =>
  val category = if (person.salary > 55000) "High" else "Low"
  (person.name, category)
}

// Column operations
val withBonus = ds.withColumn("bonus", $"salary" * 0.1)
val renamed = ds.withColumnRenamed("salary", "annual_salary")
val dropped = ds.drop("salary")
```
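
Two operations listed in the API above but not shown in this block, `flatMap` and `selectExpr`, sketched on the same `ds`:

```scala
// flatMap: zero or more output records per input record
val tags = ds.flatMap { p =>
  if (p.salary > 55000) Seq(s"${p.name}:high", s"${p.name}:bonus-eligible")
  else Seq(s"${p.name}:standard")
}

// selectExpr: SQL expression strings instead of Column objects
val viaSql = ds.selectExpr("name", "salary * 1.1 AS adjusted_salary")
```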

### Aggregations and Grouping
```scala
import org.apache.spark.sql.functions._

// Simple aggregations
val totalCount = ds.count()
val avgSalary = ds.agg(avg($"salary")).collect()(0)(0)
val salaryStats = ds.agg(
  min($"salary").as("min_salary"),
  max($"salary").as("max_salary"),
  avg($"salary").as("avg_salary"),
  stddev($"salary").as("stddev_salary")
)

// Grouping operations
val ageGroups = ds.groupBy($"age").agg(
  count($"name").as("count"),
  avg($"salary").as("avg_salary")
)

// Typed grouping (TypedColumn aggregations built with .as[...])
val salaryByAge = ds.groupByKey(_.age).agg(
  avg($"salary").as[Double],
  count($"salary").as[Long]
)

// Window functions
import org.apache.spark.sql.expressions.Window
val windowSpec = Window.partitionBy($"age").orderBy($"salary".desc)
val rankedSalaries = ds.withColumn("rank", row_number().over(windowSpec))

// Pivot operations
val pivotData = ds.groupBy($"age").pivot("name").agg(sum($"salary"))
```
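
`rollup` and `cube` from the API above follow the same pattern, sketched on a hypothetical DataFrame `salesDF` with `dept`, `role`, and `salary` columns:

```scala
import org.apache.spark.sql.functions.sum

// rollup: subtotals for (dept, role), (dept), plus the grand total
val rolled = salesDF.rollup($"dept", $"role").agg(sum($"salary").as("total"))

// cube: subtotals for every combination of the grouping columns
val cubed = salesDF.cube($"dept", $"role").agg(sum($"salary").as("total"))

// Grouping columns are null in the subtotal rows
rolled.orderBy($"dept".asc_nulls_last, $"role".asc_nulls_last).show()
```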

### Joins and Set Operations
```scala
// Sample datasets for joins
case class Department(id: Int, name: String)
val departments = Seq(
  Department(1, "Engineering"),
  Department(2, "Sales"),
  Department(3, "Marketing")
)
val deptDS = spark.createDataset(departments)

case class Employee(name: String, age: Int, deptId: Int)
val employees = Seq(
  Employee("Alice", 25, 1),
  Employee("Bob", 30, 2),
  Employee("Charlie", 35, 1)
)
val empDS = spark.createDataset(employees)

// Join operations
val innerJoin = empDS.join(deptDS, empDS("deptId") === deptDS("id"))
val leftJoin = empDS.join(deptDS, empDS("deptId") === deptDS("id"), "left")
val rightJoin = empDS.join(deptDS, empDS("deptId") === deptDS("id"), "right")
val outerJoin = empDS.join(deptDS, empDS("deptId") === deptDS("id"), "outer")

// Set operations
val moreEmployees = Seq(Employee("David", 28, 3), Employee("Eve", 32, 2))
val moreEmpDS = spark.createDataset(moreEmployees)

val combined = empDS.union(moreEmpDS)
val duplicatesRemoved = combined.distinct()
val intersection = empDS.intersect(moreEmpDS)
val difference = empDS.except(moreEmpDS)
```
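
The capabilities above also mention broadcast hints and semi/anti joins, which this block does not cover. A sketch reusing the same `empDS` and `deptDS`:

```scala
import org.apache.spark.sql.functions.broadcast

// Hint that the small dimension table should be shipped to every executor
val broadcastJoin = empDS.join(broadcast(deptDS), empDS("deptId") === deptDS("id"))

// left_semi keeps employees with a matching department (employee columns only)
val withDept = empDS.join(deptDS, empDS("deptId") === deptDS("id"), "left_semi")

// left_anti keeps employees with no matching department
val withoutDept = empDS.join(deptDS, empDS("deptId") === deptDS("id"), "left_anti")
```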

### Advanced Operations
```scala
import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window

// Complex transformations
val enrichedData = ds.transform { df =>
  df.withColumn("age_group",
    when($"age" < 30, "Young")
      .when($"age" < 40, "Middle")
      .otherwise("Senior"))
    .withColumn("salary_percentile",
      percent_rank().over(Window.orderBy($"salary")))
}

// Sampling and partitioning
val sample = ds.sample(0.5) // 50% sample
val stratified = enrichedData.stat.sampleBy("age_group", Map("Young" -> 0.3, "Middle" -> 0.5, "Senior" -> 0.7), 123L)
val repartitioned = ds.repartition(4, $"age")
val coalesced = ds.coalesce(2)

// Null handling
val noNulls = ds.na.drop() // Drop rows with any nulls
val filled = ds.na.fill(Map("salary" -> 0.0, "name" -> "Unknown"))
val replaced = ds.na.replace("name", Map("Alice" -> "Alicia"))

// Statistical operations
val correlation = ds.stat.corr("age", "salary")
val covariance = ds.stat.cov("age", "salary")
val crosstab = ds.stat.crosstab("age", "salary")
val freqItems = ds.stat.freqItems(Seq("name"), 0.4)

// Caching and persistence
import org.apache.spark.storage.StorageLevel
val cached = ds.cache()
val persisted = ds.persist(StorageLevel.MEMORY_AND_DISK_SER)

// Explain execution plan
ds.explain(true)
ds.explain("cost")
```

### Working with Complex Data Types
```scala
import org.apache.spark.sql.functions._

// Array operations
case class PersonWithHobbies(name: String, age: Int, hobbies: Seq[String])
val peopleWithHobbies = Seq(
  PersonWithHobbies("Alice", 25, Seq("reading", "swimming")),
  PersonWithHobbies("Bob", 30, Seq("cooking", "hiking", "reading"))
)
val hobbiesDS = spark.createDataset(peopleWithHobbies)

val exploded = hobbiesDS.select($"name", explode($"hobbies").as("hobby"))
val arraySize = hobbiesDS.select($"name", size($"hobbies").as("hobby_count"))
val contains = hobbiesDS.filter(array_contains($"hobbies", "reading"))

// Struct operations
case class Address(street: String, city: String, zipCode: String)
case class PersonWithAddress(name: String, age: Int, address: Address)

val peopleWithAddr = Seq(
  PersonWithAddress("Alice", 25, Address("123 Main St", "Seattle", "98101")),
  PersonWithAddress("Bob", 30, Address("456 Oak Ave", "Portland", "97201"))
)
val addrDS = spark.createDataset(peopleWithAddr)

val cityOnly = addrDS.select($"name", $"address.city")
val flatStruct = addrDS.select($"name", $"age", $"address.*")

// Map operations
val mapData = spark.range(3).select(
  map(lit("key1"), $"id", lit("key2"), $"id" * 2).as("map_col")
)
val keys = mapData.select(map_keys($"map_col"))
val values = mapData.select(map_values($"map_col"))
```