abstract class Dataset[T] extends Serializable {
// Schema and metadata
def schema: StructType
def dtypes: Array[(String, String)]
def columns: Array[String]
def printSchema(): Unit
def explain(): Unit
def explain(extended: Boolean): Unit
def explain(mode: String): Unit
// Basic transformations
def select(cols: Column*): DataFrame
def select[U1](c1: TypedColumn[T, U1]): Dataset[U1]
def select[U1, U2](c1: TypedColumn[T, U1], c2: TypedColumn[T, U2]): Dataset[(U1, U2)]
def selectExpr(exprs: String*): DataFrame
def filter(condition: Column): Dataset[T]
def filter(conditionExpr: String): Dataset[T]
def where(condition: Column): Dataset[T]
def where(conditionExpr: String): Dataset[T]
// Functional transformations
def map[U : Encoder](func: T => U): Dataset[U]
def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U]
def flatMap[U : Encoder](func: T => TraversableOnce[U]): Dataset[U]
def foreach(f: T => Unit): Unit
def foreachPartition(f: Iterator[T] => Unit): Unit
// Sorting and limiting
def sort(sortExprs: Column*): Dataset[T]
def sort(sortCol: String, sortCols: String*): Dataset[T]
def orderBy(sortExprs: Column*): Dataset[T]
def orderBy(sortCol: String, sortCols: String*): Dataset[T]
def limit(n: Int): Dataset[T]
// Sampling and partitioning
def sample(fraction: Double): Dataset[T]
def sample(fraction: Double, seed: Long): Dataset[T]
def sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T]
def randomSplit(weights: Array[Double]): Array[Dataset[T]]
def randomSplit(weights: Array[Double], seed: Long): Array[Dataset[T]]
def repartition(numPartitions: Int): Dataset[T]
def repartition(partitionExprs: Column*): Dataset[T]
def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T]
def coalesce(numPartitions: Int): Dataset[T]
// Set operations
def union(other: Dataset[T]): Dataset[T]
def unionAll(other: Dataset[T]): Dataset[T]
def unionByName(other: Dataset[T]): Dataset[T]
def unionByName(other: Dataset[T], allowMissingColumns: Boolean): Dataset[T]
def intersect(other: Dataset[T]): Dataset[T]
def intersectAll(other: Dataset[T]): Dataset[T]
def except(other: Dataset[T]): Dataset[T]
def exceptAll(other: Dataset[T]): Dataset[T]
// Joins
def join(right: Dataset[_]): DataFrame
def join(right: Dataset[_], joinExprs: Column): DataFrame
def join(right: Dataset[_], joinExprs: Column, joinType: String): DataFrame
def join(right: Dataset[_], usingColumns: Seq[String]): DataFrame
def join(right: Dataset[_], usingColumns: Seq[String], joinType: String): DataFrame
def crossJoin(right: Dataset[_]): DataFrame
// Grouping and aggregation
def groupBy(cols: Column*): RelationalGroupedDataset
def groupBy(col1: String, cols: String*): RelationalGroupedDataset
def groupByKey[K : Encoder](func: T => K): KeyValueGroupedDataset[K, T]
def rollup(cols: Column*): RelationalGroupedDataset
def rollup(col1: String, cols: String*): RelationalGroupedDataset
def cube(cols: Column*): RelationalGroupedDataset
def cube(col1: String, cols: String*): RelationalGroupedDataset
def agg(expr: Column, exprs: Column*): DataFrame
def agg(exprs: Map[String, String]): DataFrame
// Actions
def show(): Unit
def show(numRows: Int): Unit
def show(numRows: Int, truncate: Boolean): Unit
def show(numRows: Int, truncate: Int): Unit
def show(numRows: Int, truncate: Int, vertical: Boolean): Unit
def collect(): Array[T]
def collectAsList(): java.util.List[T]
def count(): Long
def reduce(func: (T, T) => T): T
def first(): T
def head(): T
def head(n: Int): Array[T]
def take(n: Int): Array[T]
def takeAsList(n: Int): java.util.List[T]
def tail(n: Int): Array[T]
def isEmpty: Boolean
def localCheckpoint(): Dataset[T]
def localCheckpoint(eager: Boolean): Dataset[T]
def checkpoint(): Dataset[T]
def checkpoint(eager: Boolean): Dataset[T]
// Column operations
def apply(colName: String): Column
def col(colName: String): Column
def withColumn(colName: String, col: Column): DataFrame
def withColumnRenamed(existingName: String, newName: String): DataFrame
def withColumns(colsMap: Map[String, Column]): DataFrame
def withColumnsRenamed(colsMap: Map[String, String]): DataFrame
def drop(colName: String): DataFrame
def drop(colNames: String*): DataFrame
def drop(col: Column): DataFrame
def dropDuplicates(): Dataset[T]
def dropDuplicates(colNames: Array[String]): Dataset[T]
def dropDuplicates(colNames: Seq[String]): Dataset[T]
def dropDuplicates(col1: String, cols: String*): Dataset[T]
def distinct(): Dataset[T]
// Type conversions
def as[U : Encoder]: Dataset[U]
def to(schema: StructType): DataFrame
def toDF(): DataFrame
def toDF(colNames: String*): DataFrame
def transform[U](t: Dataset[T] => Dataset[U]): Dataset[U]
// Persistence and caching
def cache(): Dataset[T]
def persist(): Dataset[T]
def persist(newLevel: StorageLevel): Dataset[T]
def unpersist(): Dataset[T]
def unpersist(blocking: Boolean): Dataset[T]
def storageLevel: StorageLevel
// I/O operations
def write: DataFrameWriter[T]
def writeStream: DataStreamWriter[T]
// Statistics and analysis
def describe(cols: String*): DataFrame
def summary(statistics: String*): DataFrame
def stat: DataFrameStatFunctions
def na: DataFrameNaFunctions
}

type DataFrame = Dataset[Row]

class RelationalGroupedDataset protected[sql](
df: DataFrame,
groupingExprs: Seq[Expression],
groupType: RelationalGroupedDataset.GroupType) {
// Basic aggregations
def agg(expr: Column, exprs: Column*): DataFrame
def agg(exprs: Map[String, String]): DataFrame
def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame
def count(): DataFrame
def sum(colNames: String*): DataFrame
def avg(colNames: String*): DataFrame
def max(colNames: String*): DataFrame
def min(colNames: String*): DataFrame
def mean(colNames: String*): DataFrame
// Pivot operations
def pivot(pivotColumn: String): RelationalGroupedDataset
def pivot(pivotColumn: String, values: Seq[Any]): RelationalGroupedDataset
def pivot(pivotColumn: String, values: java.util.List[Any]): RelationalGroupedDataset
}

class KeyValueGroupedDataset[K, V] private[sql](
kEncoder: Encoder[K],
vEncoder: Encoder[V]) {
// Typed aggregations
def agg[U1: Encoder](col1: TypedColumn[V, U1]): Dataset[(K, U1)]
def agg[U1: Encoder, U2: Encoder](
col1: TypedColumn[V, U1],
col2: TypedColumn[V, U2]): Dataset[(K, U1, U2)]
def agg[U1: Encoder, U2: Encoder, U3: Encoder](
col1: TypedColumn[V, U1],
col2: TypedColumn[V, U2],
col3: TypedColumn[V, U3]): Dataset[(K, U1, U2, U3)]
def agg[U1: Encoder, U2: Encoder, U3: Encoder, U4: Encoder](
col1: TypedColumn[V, U1],
col2: TypedColumn[V, U2],
col3: TypedColumn[V, U3],
col4: TypedColumn[V, U4]): Dataset[(K, U1, U2, U3, U4)]
// Group transformations
def mapGroups[U: Encoder](f: (K, Iterator[V]) => U): Dataset[U]
def flatMapGroups[U: Encoder](f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U]
def mapGroupsWithState[S: Encoder, U: Encoder](
timeoutConf: GroupStateTimeout)(
func: (K, Iterator[V], GroupState[S]) => U): Dataset[U]
def flatMapGroupsWithState[S: Encoder, U: Encoder](
outputMode: OutputMode,
timeoutConf: GroupStateTimeout)(
func: (K, Iterator[V], GroupState[S]) => Iterator[U]): Dataset[U]
// Group reductions and cogrouping
def reduceGroups(f: (V, V) => V): Dataset[(K, V)]
def cogroup[U, R: Encoder](
other: KeyValueGroupedDataset[K, U])(
f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R]
}

class Column(val expr: Expression) extends Logging {
// Arithmetic operations
def +(other: Any): Column
def -(other: Any): Column
def *(other: Any): Column
def /(other: Any): Column
def %(other: Any): Column
// Comparison operations
def ===(other: Any): Column
def =!=(other: Any): Column
def >(other: Any): Column
def <(other: Any): Column
def >=(other: Any): Column
def <=(other: Any): Column
def <=>(other: Any): Column
// Boolean operations
def &&(other: Column): Column
def ||(other: Column): Column
def unary_! : Column
// Null operations
def isNull: Column
def isNotNull: Column
def isNaN: Column
// String operations
def contains(other: Any): Column
def startsWith(other: Column): Column
def startsWith(literal: String): Column
def endsWith(other: Column): Column
def endsWith(literal: String): Column
def like(literal: String): Column
def rlike(literal: String): Column
def substr(startPos: Column, len: Column): Column
def substr(startPos: Int, len: Int): Column
// Type operations
def cast(to: DataType): Column
def cast(to: String): Column
def as(alias: String): Column
def as(alias: Symbol): Column
def as(aliases: Array[String]): Column
def name(alias: String): Column
// Sorting
def asc: Column
def asc_nulls_first: Column
def asc_nulls_last: Column
def desc: Column
def desc_nulls_first: Column
def desc_nulls_last: Column
// Collection operations
def getItem(key: Any): Column
def getField(fieldName: String): Column
// Window operations
def over(): Column
def over(window: WindowSpec): Column
}

trait Row extends Serializable {
// Size and schema
def size: Int
def length: Int
def schema: StructType
// Generic access
def get(i: Int): Any
def getAs[T](i: Int): T
def getAs[T](fieldName: String): T
def fieldIndex(name: String): Int
def getValuesMap[T](fieldNames: Seq[String]): Map[String, T]
// Type-specific getters
def isNullAt(i: Int): Boolean
def getString(i: Int): String
def getBoolean(i: Int): Boolean
def getByte(i: Int): Byte
def getShort(i: Int): Short
def getInt(i: Int): Int
def getLong(i: Int): Long
def getFloat(i: Int): Float
def getDouble(i: Int): Double
def getDecimal(i: Int): java.math.BigDecimal
def getDate(i: Int): java.sql.Date
def getTimestamp(i: Int): java.sql.Timestamp
def getInstant(i: Int): java.time.Instant
def getLocalDate(i: Int): java.time.LocalDate
def getSeq[T](i: Int): Seq[T]
def getList[T](i: Int): java.util.List[T]
def getMap[K, V](i: Int): Map[K, V]
def getJavaMap[K, V](i: Int): java.util.Map[K, V]
def getStruct(i: Int): Row
// Conversion operations
def copy(): Row
def toSeq: Seq[Any]
def mkString: String
def mkString(sep: String): String
def mkString(start: String, sep: String, end: String): String
}

import org.apache.spark.sql.{SparkSession, Dataset, DataFrame, Row}
import org.apache.spark.sql.types._
// Create SparkSession
val spark = SparkSession.builder().appName("Dataset Examples").getOrCreate()
import spark.implicits._
// Create Dataset from case class
case class Person(name: String, age: Int, salary: Double)
val people = Seq(
Person("Alice", 25, 50000.0),
Person("Bob", 30, 60000.0),
Person("Charlie", 35, 70000.0)
)
val ds: Dataset[Person] = spark.createDataset(people)
// Create DataFrame from Row data
val schema = StructType(Array(
StructField("name", StringType, nullable = false),
StructField("age", IntegerType, nullable = false),
StructField("salary", DoubleType, nullable = false)
))
val rows = Seq(
Row("Alice", 25, 50000.0),
Row("Bob", 30, 60000.0),
Row("Charlie", 35, 70000.0)
)
val df: DataFrame = spark.createDataFrame(spark.sparkContext.parallelize(rows), schema)
// Convert between Dataset and DataFrame
val dataFrame = ds.toDF()
val dataset = df.as[Person]

import org.apache.spark.sql.functions._
// Select operations
val names = ds.select($"name")
val nameAge = ds.select($"name", $"age")
val computed = ds.select($"name", ($"salary" * 1.1).as("adjusted_salary"))
// Filter operations
val youngPeople = ds.filter($"age" < 30)
val highEarners = ds.where($"salary" > 55000)
val complexFilter = ds.filter($"age" > 25 && $"salary" < 65000)
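// Column expressions from the Column API (like, startsWith, cast, substr,
// null checks) compose the same way inside filter and select; an illustrative
// sketch using the sample data above
val namesWithA = ds.filter($"name".like("A%"))
val namesStartingAl = ds.filter($"name".startsWith("Al"))
val ageAsString = ds.select($"name", $"age".cast("string").as("age_str"))
val nonNullNames = ds.filter($"name".isNotNull)
val initials = ds.select($"name".substr(1, 1).as("initial"))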
// Functional transformations
val upperNames = ds.map(person => person.copy(name = person.name.toUpperCase))
val salaryCategories = ds.map { person =>
val category = if (person.salary > 55000) "High" else "Low"
(person.name, category)
}
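// flatMap and mapPartitions (illustrative sketch): flatMap emits zero or more
// output records per input row, while mapPartitions receives one iterator per
// partition and returns a new iterator
val nameChars = ds.flatMap(person => person.name.toSeq.map(_.toString))
val adjustedSalaries = ds.mapPartitions { people =>
  // apply a per-partition transformation without collecting the partition
  people.map(p => (p.name, p.salary * 1.05))
}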
// Column operations
val withBonus = ds.withColumn("bonus", $"salary" * 0.1)
val renamed = ds.withColumnRenamed("salary", "annual_salary")
val dropped = ds.drop("salary")import org.apache.spark.sql.functions._
// Simple aggregations
val totalCount = ds.count()
val avgSalary = ds.agg(avg($"salary")).collect()(0)(0)
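// Reading the same aggregate through the typed Row getters instead of apply()
// (illustrative; assumes the default output column name "avg(salary)")
val avgSalaryRow = ds.agg(avg($"salary")).first()
val avgByIndex = avgSalaryRow.getDouble(0)
val avgByName = avgSalaryRow.getAs[Double]("avg(salary)")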
val salaryStats = ds.agg(
min($"salary").as("min_salary"),
max($"salary").as("max_salary"),
avg($"salary").as("avg_salary"),
stddev($"salary").as("stddev_salary")
)
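// agg also accepts a Map of column name -> aggregate function name (illustrative)
val mapAgg = ds.agg(Map("salary" -> "avg", "age" -> "max"))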
// Grouping operations
val ageGroups = ds.groupBy($"age").agg(
count($"name").as("count"),
avg($"salary").as("avg_salary")
)
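// rollup and cube add subtotal and grand-total grouping sets on top of the
// plain groupBy combinations (illustrative sketch)
val rolledUp = ds.rollup($"age").agg(sum($"salary").as("total_salary"))
val cubed = ds.cube($"age", $"name").agg(count("*").as("count"))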
// Typed grouping
val salaryByAge = ds.groupByKey(_.age).agg(
avg($"salary").as[Double],
count($"name").as[Long]
)
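// mapGroups and reduceGroups on the typed grouping (illustrative sketch):
// mapGroups sees each key with an iterator of its rows; reduceGroups merges
// rows pairwise within each key
val topEarnerPerAge = ds.groupByKey(_.age).mapGroups { (age, people) =>
  val best = people.maxBy(_.salary)
  (age, best.name, best.salary)
}
val highestPaidPerAge = ds.groupByKey(_.age).reduceGroups { (a, b) =>
  if (a.salary >= b.salary) a else b
}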
// Window functions
import org.apache.spark.sql.expressions.Window
val windowSpec = Window.partitionBy($"age").orderBy($"salary".desc)
val rankedSalaries = ds.withColumn("rank", row_number().over(windowSpec))
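// Other window functions reuse the same WindowSpec (illustrative)
val withRankAndLag = ds
  .withColumn("dense_rank", dense_rank().over(windowSpec))
  .withColumn("prev_salary", lag($"salary", 1).over(windowSpec))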
// Pivot operations
val pivotData = ds.groupBy($"age").pivot("name").agg(sum($"salary"))

// Sample datasets for joins
case class Department(id: Int, name: String)
val departments = Seq(
Department(1, "Engineering"),
Department(2, "Sales"),
Department(3, "Marketing")
)
val deptDS = spark.createDataset(departments)
case class Employee(name: String, age: Int, deptId: Int)
val employees = Seq(
Employee("Alice", 25, 1),
Employee("Bob", 30, 2),
Employee("Charlie", 35, 1)
)
val empDS = spark.createDataset(employees)
// Join operations
val innerJoin = empDS.join(deptDS, empDS("deptId") === deptDS("id"))
val leftJoin = empDS.join(deptDS, empDS("deptId") === deptDS("id"), "left")
val rightJoin = empDS.join(deptDS, empDS("deptId") === deptDS("id"), "right")
val outerJoin = empDS.join(deptDS, empDS("deptId") === deptDS("id"), "outer")
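// Joining on a shared column name and filtering joins (illustrative sketch;
// deptDS("id") is renamed so both sides expose the column "deptId")
val deptRenamed = deptDS.withColumnRenamed("id", "deptId")
val usingJoin = empDS.join(deptRenamed, Seq("deptId"))
val semiJoin = empDS.join(deptRenamed, Seq("deptId"), "left_semi")
val antiJoin = empDS.join(deptRenamed, Seq("deptId"), "left_anti")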
// Set operations
val moreEmployees = Seq(Employee("David", 28, 3), Employee("Eve", 32, 2))
val moreEmpDS = spark.createDataset(moreEmployees)
val combined = empDS.union(moreEmpDS)
val duplicatesRemoved = combined.distinct()
val intersection = empDS.intersect(moreEmpDS)
val difference = empDS.except(moreEmpDS)

import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window
// Complex transformations
val enrichedData = ds.transform { df =>
df.withColumn("age_group",
when($"age" < 30, "Young")
.when($"age" < 40, "Middle")
.otherwise("Senior"))
.withColumn("salary_percentile",
percent_rank().over(Window.orderBy($"salary")))
}
// Sampling and partitioning
val sample = ds.sample(0.5) // 50% sample
val stratified = enrichedData.stat.sampleBy("age_group", Map("Young" -> 0.3, "Middle" -> 0.5, "Senior" -> 0.7), 123L)
val repartitioned = ds.repartition(4, $"age")
val coalesced = ds.coalesce(2)
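// randomSplit divides a Dataset by relative weights, e.g. into train/test sets (illustrative)
val Array(trainSet, testSet) = ds.randomSplit(Array(0.8, 0.2), seed = 42L)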
// Null handling
val noNulls = ds.na.drop() // Drop rows that contain any nulls
val filled = ds.na.fill(Map("salary" -> 0.0, "name" -> "Unknown"))
val replaced = ds.na.replace("name", Map("Alice" -> "Alicia"))
// Statistical operations
val correlation = ds.stat.corr("age", "salary")
val covariance = ds.stat.cov("age", "salary")
val crosstab = ds.stat.crosstab("age", "salary")
val freqItems = ds.stat.freqItems(Seq("name"), 0.4)
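// describe and summary return per-column statistics as a DataFrame (illustrative)
val basicStats = ds.describe("age", "salary")
val extendedStats = ds.summary("count", "mean", "stddev", "min", "25%", "75%", "max")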
// Caching and persistence
import org.apache.spark.storage.StorageLevel
val cached = ds.cache()
val persisted = ds.persist(StorageLevel.MEMORY_AND_DISK_SER)
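// Checkpointing truncates the lineage; localCheckpoint stores blocks on the
// executors only (illustrative; checkpoint() needs a checkpoint directory set
// first, and "/tmp/spark-checkpoints" below is a hypothetical path)
spark.sparkContext.setCheckpointDir("/tmp/spark-checkpoints")
val checkpointed = ds.checkpoint()
val locallyCheckpointed = ds.localCheckpoint()
cached.unpersist()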
// Explain execution plan
ds.explain(true)
ds.explain("cost")import org.apache.spark.sql.functions._
// Array operations
case class PersonWithHobbies(name: String, age: Int, hobbies: Seq[String])
val peopleWithHobbies = Seq(
PersonWithHobbies("Alice", 25, Seq("reading", "swimming")),
PersonWithHobbies("Bob", 30, Seq("cooking", "hiking", "reading"))
)
val hobbiesDS = spark.createDataset(peopleWithHobbies)
val exploded = hobbiesDS.select($"name", explode($"hobbies").as("hobby"))
val arraySize = hobbiesDS.select($"name", size($"hobbies").as("hobby_count"))
val contains = hobbiesDS.filter(array_contains($"hobbies", "reading"))
// Struct operations
case class Address(street: String, city: String, zipCode: String)
case class PersonWithAddress(name: String, age: Int, address: Address)
val peopleWithAddr = Seq(
PersonWithAddress("Alice", 25, Address("123 Main St", "Seattle", "98101")),
PersonWithAddress("Bob", 30, Address("456 Oak Ave", "Portland", "97201"))
)
val addrDS = spark.createDataset(peopleWithAddr)
val cityOnly = addrDS.select($"name", $"address.city")
val flatStruct = addrDS.select($"name", $"age", $"address.*")
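// Building a struct column from flat fields with struct() (illustrative)
val nested = addrDS.select($"name", struct($"age", $"address.city".as("city")).as("info"))
val backOut = nested.select($"name", $"info.city")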
// Map operations
val mapData = spark.range(3).select(
map(lit("key1"), $"id", lit("key2"), $"id" * 2).as("map_col")
)
val keys = mapData.select(map_keys($"map_col"))
val values = mapData.select(map_values($"map_col"))
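// Writing results out with the DataFrameWriter (illustrative sketch;
// "/tmp/people_parquet" is a hypothetical output path)
ds.write.mode("overwrite").partitionBy("age").parquet("/tmp/people_parquet")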