Apache Spark SQL - Dataset and DataFrame Operations

Capabilities

Dataset and DataFrame Core Operations

  • Create strongly typed Datasets and untyped DataFrames from various data sources
  • Perform functional transformations using map, filter, flatMap with type safety
  • Execute relational operations like select, where, groupBy, join with SQL-like syntax
  • Handle schema inference, validation, and evolution with compile-time type checking (see the sketch below)
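
A minimal sketch of the typed/untyped contrast, using the same Person case class that the Usage Examples below define:

import org.apache.spark.sql.SparkSession

case class Person(name: String, age: Int, salary: Double)

val spark = SparkSession.builder().appName("sketch").getOrCreate()
import spark.implicits._

val ds = Seq(Person("Alice", 25, 50000.0)).toDS() // Dataset[Person]: typed
val df = ds.toDF()                                // DataFrame = Dataset[Row]: untyped

ds.map(_.age + 1)     // checked at compile time: `age` must exist on Person
df.select($"age" + 1) // resolved at analysis time: a typo fails only at runtime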

Data Transformations and Actions

  • Apply lazy transformations that build execution plans without immediate computation (see the sketch after this list)
  • Execute actions that trigger computation and return results to the driver or external systems
  • Chain transformations efficiently using the Catalyst query optimizer
  • Handle distributed computation with automatic parallelization and fault tolerance
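
A minimal sketch of the lazy/eager split, reusing `ds` from the sketch above: transformations only assemble a plan, and the first action runs it.

val pipeline = ds.filter(_.age > 21).map(_.name) // transformations: plan only, no job yet
pipeline.explain()                               // inspect the optimized plan
val n = pipeline.count()                         // action: triggers the distributed job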

Aggregations and Grouping

  • Perform complex aggregations with built-in and custom aggregation functions (see the Aggregator sketch after this list)
  • Group data by single or multiple columns with typed and untyped grouping operations
  • Execute window functions for analytics and ranking operations
  • Handle pivot and unpivot operations for data reshaping and analysis
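
A sketch of a custom typed aggregation via the Aggregator base class, reusing the Person type from the sketch above; AvgSalary is a hypothetical name:

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders}

object AvgSalary extends Aggregator[Person, (Double, Long), Double] {
  def zero: (Double, Long) = (0.0, 0L)                     // empty accumulator
  def reduce(acc: (Double, Long), p: Person): (Double, Long) =
    (acc._1 + p.salary, acc._2 + 1)                        // fold one row in
  def merge(a: (Double, Long), b: (Double, Long)): (Double, Long) =
    (a._1 + b._1, a._2 + b._2)                             // combine partial results
  def finish(acc: (Double, Long)): Double = acc._1 / acc._2
  def bufferEncoder: Encoder[(Double, Long)] =
    Encoders.tuple(Encoders.scalaDouble, Encoders.scalaLong)
  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

val overall = ds.select(AvgSalary.toColumn)              // Dataset[Double]
val byAge = ds.groupByKey(_.age).agg(AvgSalary.toColumn) // Dataset[(Int, Double)]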

Joins and Set Operations

  • Join datasets using various join types (inner, outer, left, right, semi, anti)
  • Optimize joins with broadcast hints and bucketing strategies (see the sketch after this list)
  • Perform set operations like union, intersect, except with schema alignment
  • Handle complex join conditions with multiple predicates and expressions
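
A sketch of the broadcast hint, using the empDS/deptDS datasets built in the join examples below:

import org.apache.spark.sql.functions.broadcast

// replicate the small side to every executor instead of shuffling both sides
val hinted = empDS.join(broadcast(deptDS), empDS("deptId") === deptDS("id"))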

API Reference

Dataset[T] Class

class Dataset[T] extends Serializable {
  // Schema and metadata
  def schema: StructType
  def dtypes: Array[(String, String)]
  def columns: Array[String]
  def printSchema(): Unit
  def explain(): Unit
  def explain(extended: Boolean): Unit
  def explain(mode: String): Unit
  
  // Basic transformations
  def select(cols: Column*): DataFrame
  def select[U1](c1: TypedColumn[T, U1]): Dataset[U1]
  def select[U1, U2](c1: TypedColumn[T, U1], c2: TypedColumn[T, U2]): Dataset[(U1, U2)]
  def selectExpr(exprs: String*): DataFrame
  def filter(condition: Column): Dataset[T]
  def filter(conditionExpr: String): Dataset[T]
  def where(condition: Column): Dataset[T]
  def where(conditionExpr: String): Dataset[T]
  
  // Functional transformations
  def map[U : Encoder](func: T => U): Dataset[U]
  def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U]
  def flatMap[U : Encoder](func: T => TraversableOnce[U]): Dataset[U]
  def foreach(f: T => Unit): Unit
  def foreachPartition(f: Iterator[T] => Unit): Unit
  
  // Sorting and limiting
  def sort(sortExprs: Column*): Dataset[T]
  def sort(sortCol: String, sortCols: String*): Dataset[T]
  def orderBy(sortExprs: Column*): Dataset[T]
  def orderBy(sortCol: String, sortCols: String*): Dataset[T]
  def limit(n: Int): Dataset[T]
  
  // Sampling and partitioning
  def sample(fraction: Double): Dataset[T]
  def sample(fraction: Double, seed: Long): Dataset[T]
  def sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T]
  def randomSplit(weights: Array[Double]): Array[Dataset[T]]
  def randomSplit(weights: Array[Double], seed: Long): Array[Dataset[T]]
  def repartition(numPartitions: Int): Dataset[T]
  def repartition(partitionExprs: Column*): Dataset[T]
  def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T]
  def coalesce(numPartitions: Int): Dataset[T]
  
  // Set operations
  def union(other: Dataset[T]): Dataset[T]
  def unionAll(other: Dataset[T]): Dataset[T]
  def unionByName(other: Dataset[T]): Dataset[T]
  def unionByName(other: Dataset[T], allowMissingColumns: Boolean): Dataset[T]
  def intersect(other: Dataset[T]): Dataset[T]
  def intersectAll(other: Dataset[T]): Dataset[T]
  def except(other: Dataset[T]): Dataset[T]
  def exceptAll(other: Dataset[T]): Dataset[T]
  
  // Joins
  def join(right: Dataset[_]): DataFrame
  def join(right: Dataset[_], joinExprs: Column): DataFrame  
  def join(right: Dataset[_], joinExprs: Column, joinType: String): DataFrame
  def join(right: Dataset[_], usingColumns: Seq[String]): DataFrame
  def join(right: Dataset[_], usingColumns: Seq[String], joinType: String): DataFrame
  def crossJoin(right: Dataset[_]): DataFrame
  
  // Grouping and aggregation
  def groupBy(cols: Column*): RelationalGroupedDataset
  def groupBy(col1: String, cols: String*): RelationalGroupedDataset  
  def groupByKey[K : Encoder](func: T => K): KeyValueGroupedDataset[K, T]
  def rollup(cols: Column*): RelationalGroupedDataset
  def rollup(col1: String, cols: String*): RelationalGroupedDataset
  def cube(cols: Column*): RelationalGroupedDataset
  def cube(col1: String, cols: String*): RelationalGroupedDataset
  def agg(expr: Column, exprs: Column*): DataFrame
  def agg(exprs: Map[String, String]): DataFrame
  
  // Actions
  def show(): Unit
  def show(numRows: Int): Unit
  def show(numRows: Int, truncate: Boolean): Unit
  def show(numRows: Int, truncate: Int): Unit
  def show(numRows: Int, truncate: Int, vertical: Boolean): Unit
  def collect(): Array[T]
  def collectAsList(): java.util.List[T]
  def count(): Long
  def reduce(func: (T, T) => T): T
  def first(): T
  def head(): T
  def head(n: Int): Array[T]
  def take(n: Int): Array[T]
  def takeAsList(n: Int): java.util.List[T]
  def tail(n: Int): Array[T]
  def isEmpty: Boolean
  def localCheckpoint(): Dataset[T]
  def localCheckpoint(eager: Boolean): Dataset[T]
  def checkpoint(): Dataset[T]
  def checkpoint(eager: Boolean): Dataset[T]
  
  // Column operations  
  def apply(colName: String): Column
  def col(colName: String): Column
  def withColumn(colName: String, col: Column): DataFrame
  def withColumnRenamed(existingName: String, newName: String): DataFrame
  def withColumns(colsMap: Map[String, Column]): DataFrame
  def withColumnsRenamed(colsMap: Map[String, String]): DataFrame
  def drop(colName: String): DataFrame
  def drop(colNames: String*): DataFrame  
  def drop(col: Column): DataFrame
  def dropDuplicates(): Dataset[T]
  def dropDuplicates(colNames: Array[String]): Dataset[T]
  def dropDuplicates(colNames: Seq[String]): Dataset[T]
  def dropDuplicates(col1: String, cols: String*): Dataset[T]
  def distinct(): Dataset[T]
  
  // Type conversions
  def as[U : Encoder]: Dataset[U]
  def to(schema: StructType): DataFrame
  def toDF(): DataFrame
  def toDF(colNames: String*): DataFrame
  def transform[U](t: Dataset[T] => Dataset[U]): Dataset[U]
  
  // Persistence and caching
  def cache(): Dataset[T]
  def persist(): Dataset[T] 
  def persist(newLevel: StorageLevel): Dataset[T]
  def unpersist(): Dataset[T]
  def unpersist(blocking: Boolean): Dataset[T]
  def storageLevel: StorageLevel
  
  // I/O operations
  def write: DataFrameWriter[T]
  def writeStream: DataStreamWriter[T]
  
  // Statistics and analysis
  def describe(cols: String*): DataFrame
  def summary(statistics: String*): DataFrame
  def stat: DataFrameStatFunctions
  def na: DataFrameNaFunctions
}

DataFrame Type Alias

type DataFrame = Dataset[Row]
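
Since DataFrame is only an alias, untyped operations return DataFrame and as[T] restores the typed view; a small sketch using the Person examples below:

val untyped: DataFrame = ds.select($"name", $"age", $"salary") // columns, not fields
val typedAgain: Dataset[Person] = untyped.as[Person]           // re-attach the type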

RelationalGroupedDataset Class

class RelationalGroupedDataset protected[sql](
    df: DataFrame,
    groupingExprs: Seq[Expression],
    groupType: RelationalGroupedDataset.GroupType) {
  
  // Basic aggregations
  def agg(expr: Column, exprs: Column*): DataFrame
  def agg(exprs: Map[String, String]): DataFrame
  def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame
  def count(): DataFrame
  def sum(colNames: String*): DataFrame
  def avg(colNames: String*): DataFrame
  def max(colNames: String*): DataFrame
  def min(colNames: String*): DataFrame
  def mean(colNames: String*): DataFrame
  
  // Pivot operations
  def pivot(pivotColumn: String): RelationalGroupedDataset
  def pivot(pivotColumn: String, values: Seq[Any]): RelationalGroupedDataset
  def pivot(pivotColumn: String, values: java.util.List[Any]): RelationalGroupedDataset
}
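
Calling pivot without values makes Spark run an extra job to collect the distinct pivot values first; passing them explicitly avoids it. A sketch using the Person dataset from the Usage Examples (assumes import org.apache.spark.sql.functions._):

val byName = ds.groupBy($"age").pivot("name", Seq("Alice", "Bob")).agg(sum($"salary"))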

KeyValueGroupedDataset[K, V] Class

class KeyValueGroupedDataset[K, V] private[sql](
    kEncoder: Encoder[K],
    vEncoder: Encoder[V]) {
  
  // Typed aggregations
  def agg[U1: Encoder](col1: TypedColumn[V, U1]): Dataset[(K, U1)]
  def agg[U1: Encoder, U2: Encoder](
    col1: TypedColumn[V, U1], 
    col2: TypedColumn[V, U2]): Dataset[(K, U1, U2)]
  def agg[U1: Encoder, U2: Encoder, U3: Encoder](
    col1: TypedColumn[V, U1], 
    col2: TypedColumn[V, U2], 
    col3: TypedColumn[V, U3]): Dataset[(K, U1, U2, U3)]
  def agg[U1: Encoder, U2: Encoder, U3: Encoder, U4: Encoder](
    col1: TypedColumn[V, U1], 
    col2: TypedColumn[V, U2], 
    col3: TypedColumn[V, U3], 
    col4: TypedColumn[V, U4]): Dataset[(K, U1, U2, U3, U4)]
  
  // Group transformations
  def mapGroups[U: Encoder](f: (K, Iterator[V]) => U): Dataset[U]
  def flatMapGroups[U: Encoder](f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U]
  def mapGroupsWithState[S: Encoder, U: Encoder](
    timeoutConf: GroupStateTimeout)(
    func: (K, Iterator[V], GroupState[S]) => U): Dataset[U]
  def flatMapGroupsWithState[S: Encoder, U: Encoder](
    outputMode: OutputMode, 
    timeoutConf: GroupStateTimeout)(
    func: (K, Iterator[V], GroupState[S]) => Iterator[U]): Dataset[U]
  
  // Group reductions
  def reduceGroups(f: (V, V) => V): Dataset[(K, V)]
  def cogroup[U, R: Encoder](
    other: KeyValueGroupedDataset[K, U])(
    f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R]
}
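
A sketch of cogroup, pairing two keyed datasets; Employee and Department are the case classes defined in the join examples below:

val byDept = empDS.groupByKey(_.deptId)
val byId   = deptDS.groupByKey(_.id)
val named = byDept.cogroup(byId) { (deptId, emps, depts) =>
  // both sides arrive as iterators over the rows sharing this key
  val deptName = depts.map(_.name).toSeq.headOption.getOrElse("unknown")
  emps.map(e => (e.name, deptName))
}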

Column Class

class Column(val expr: Expression) extends Logging {
  // Arithmetic operations
  def +(other: Any): Column
  def -(other: Any): Column  
  def *(other: Any): Column
  def /(other: Any): Column
  def %(other: Any): Column
  
  // Comparison operations
  def ===(other: Any): Column
  def =!=(other: Any): Column
  def >(other: Any): Column
  def <(other: Any): Column
  def >=(other: Any): Column
  def <=(other: Any): Column
  def <=>(other: Any): Column // null-safe equality
  
  // Boolean operations
  def &&(other: Column): Column
  def ||(other: Column): Column
  def unary_! : Column
  
  // Null operations
  def isNull: Column
  def isNotNull: Column
  def isNaN: Column
  
  // String operations
  def contains(other: Any): Column
  def startsWith(other: Column): Column
  def startsWith(literal: String): Column  
  def endsWith(other: Column): Column
  def endsWith(literal: String): Column
  def like(literal: String): Column
  def rlike(literal: String): Column
  def substr(startPos: Column, len: Column): Column
  def substr(startPos: Int, len: Int): Column
  
  // Type operations
  def cast(to: DataType): Column
  def cast(to: String): Column
  def as(alias: String): Column
  def as(alias: Symbol): Column
  def as(aliases: Array[String]): Column
  def name(alias: String): Column
  
  // Sorting
  def asc: Column
  def asc_nulls_first: Column
  def asc_nulls_last: Column
  def desc: Column
  def desc_nulls_first: Column
  def desc_nulls_last: Column
  
  // Collection operations
  def getItem(key: Any): Column
  def getField(fieldName: String): Column
  
  // Window operations
  def over(): Column
  def over(window: WindowSpec): Column
}
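
A sketch of null-safe equality (the <=> operator above): === yields null when either side is null, while <=> treats two nulls as equal. Assumes spark.implicits._ is in scope.

val pairs = Seq((Some(1), Some(1)), (None, Some(1)), (None, None)).toDF("a", "b")
pairs.select($"a" === $"b", $"a" <=> $"b").show()
// === : true, null,  null
// <=> : true, false, true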

Row Trait

trait Row extends Serializable {
  // Size and schema
  def size: Int
  def length: Int
  def schema: StructType
  
  // Generic access
  def get(i: Int): Any
  def getAs[T](i: Int): T
  def getAs[T](fieldName: String): T
  def fieldIndex(name: String): Int
  def getValuesMap[T](fieldNames: Seq[String]): Map[String, T]
  
  // Type-specific getters
  def isNullAt(i: Int): Boolean
  def getString(i: Int): String
  def getBoolean(i: Int): Boolean
  def getByte(i: Int): Byte  
  def getShort(i: Int): Short
  def getInt(i: Int): Int
  def getLong(i: Int): Long
  def getFloat(i: Int): Float
  def getDouble(i: Int): Double
  def getDecimal(i: Int): java.math.BigDecimal
  def getDate(i: Int): java.sql.Date
  def getTimestamp(i: Int): java.sql.Timestamp
  def getInstant(i: Int): java.time.Instant
  def getLocalDate(i: Int): java.time.LocalDate
  def getSeq[T](i: Int): Seq[T]
  def getList[T](i: Int): java.util.List[T]
  def getMap[K, V](i: Int): Map[K, V]
  def getJavaMap[K, V](i: Int): java.util.Map[K, V]
  def getStruct(i: Int): Row
  
  // Conversion operations
  def copy(): Row
  def toSeq: Seq[Any]
  def mkString: String
  def mkString(sep: String): String
  def mkString(start: String, sep: String, end: String): String
}
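
A short sketch of Row access patterns, using the df built in the Usage Examples below:

val row = df.first()
val name = row.getString(0)      // positional, type-specific getter
val age  = row.getAs[Int]("age") // lookup by field name
val salary = if (row.isNullAt(2)) 0.0 else row.getDouble(2) // null-safe read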

Usage Examples

Creating Datasets and DataFrames

import org.apache.spark.sql.{SparkSession, Dataset, DataFrame}
import org.apache.spark.sql.types._

// Create SparkSession
val spark = SparkSession.builder().appName("Dataset Examples").getOrCreate()
import spark.implicits._

// Create Dataset from case class
case class Person(name: String, age: Int, salary: Double)
val people = Seq(
  Person("Alice", 25, 50000.0),
  Person("Bob", 30, 60000.0),
  Person("Charlie", 35, 70000.0)
)
val ds: Dataset[Person] = spark.createDataset(people)

// Create DataFrame from Row data
val schema = StructType(Array(
  StructField("name", StringType, nullable = false),
  StructField("age", IntegerType, nullable = false),
  StructField("salary", DoubleType, nullable = false)
))
val rows = Seq(
  Row("Alice", 25, 50000.0),
  Row("Bob", 30, 60000.0),
  Row("Charlie", 35, 70000.0)
)
val df: DataFrame = spark.createDataFrame(spark.sparkContext.parallelize(rows), schema)

// Convert between Dataset and DataFrame
val dataFrame = ds.toDF()
val dataset = df.as[Person]

Basic Transformations

import org.apache.spark.sql.functions._

// Select operations
val names = ds.select($"name")
val nameAge = ds.select($"name", $"age")
val computed = ds.select($"name", ($"salary" * 1.1).as("adjusted_salary"))

// Filter operations
val youngPeople = ds.filter($"age" < 30)
val highEarners = ds.where($"salary" > 55000)
val complexFilter = ds.filter($"age" > 25 && $"salary" < 65000)

// Functional transformations
val upperNames = ds.map(person => person.copy(name = person.name.toUpperCase))
val salaryCategories = ds.map { person =>
  val category = if (person.salary > 55000) "High" else "Low"
  (person.name, category)
}

// Column operations
val withBonus = ds.withColumn("bonus", $"salary" * 0.1)
val renamed = ds.withColumnRenamed("salary", "annual_salary")
val dropped = ds.drop("salary")

Aggregations and Grouping

import org.apache.spark.sql.functions._

// Simple aggregations
val totalCount = ds.count()
val avgSalary = ds.agg(avg($"salary")).first().getDouble(0)
val salaryStats = ds.agg(
  min($"salary").as("min_salary"),
  max($"salary").as("max_salary"),
  avg($"salary").as("avg_salary"),
  stddev($"salary").as("stddev_salary")
)

// Grouping operations
val ageGroups = ds.groupBy($"age").agg(
  count($"name").as("count"),
  avg($"salary").as("avg_salary")
)

// Typed grouping (the scalalang `typed` aggregators were deprecated in Spark 2.3
// and removed in 3.0; mapGroups expresses the same aggregation in plain Scala)
val salaryByAge = ds.groupByKey(_.age).mapGroups { (age, people) =>
  val salaries = people.map(_.salary).toSeq
  (age, salaries.sum / salaries.size, salaries.size.toLong)
}

// Window functions
import org.apache.spark.sql.expressions.Window
val windowSpec = Window.partitionBy($"age").orderBy($"salary".desc)
val rankedSalaries = ds.withColumn("rank", row_number().over(windowSpec))

// Pivot operations
val pivotData = ds.groupBy($"age").pivot("name").agg(sum($"salary"))

Joins and Set Operations

// Sample datasets for joins
case class Department(id: Int, name: String)
val departments = Seq(
  Department(1, "Engineering"),
  Department(2, "Sales"),
  Department(3, "Marketing")
)
val deptDS = spark.createDataset(departments)

case class Employee(name: String, age: Int, deptId: Int)
val employees = Seq(
  Employee("Alice", 25, 1),
  Employee("Bob", 30, 2), 
  Employee("Charlie", 35, 1)
)
val empDS = spark.createDataset(employees)

// Join operations
val innerJoin = empDS.join(deptDS, empDS("deptId") === deptDS("id"))
val leftJoin = empDS.join(deptDS, empDS("deptId") === deptDS("id"), "left")
val rightJoin = empDS.join(deptDS, empDS("deptId") === deptDS("id"), "right")
val outerJoin = empDS.join(deptDS, empDS("deptId") === deptDS("id"), "outer")

// Set operations
val moreEmployees = Seq(Employee("David", 28, 3), Employee("Eve", 32, 2))
val moreEmpDS = spark.createDataset(moreEmployees)

val combined = empDS.union(moreEmpDS)
val duplicatesRemoved = combined.distinct()
val intersection = empDS.intersect(moreEmpDS)
val difference = empDS.except(moreEmpDS)
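
unionByName (see the API reference above) aligns columns by name rather than position, which matters once column order diverges; a sketch using the same datasets:

// reorder columns, then union safely by name; a positional union would
// fail analysis because the column types no longer line up
val reordered = moreEmpDS.select($"deptId", $"age", $"name").as[Employee]
val alignedByName = empDS.unionByName(reordered)
// unionByName(other, allowMissingColumns = true) additionally null-fills
// columns missing from one side (Spark 3.1+)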

Advanced Operations

import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window

// Complex transformations
val enrichedData = ds.transform { df =>
  df.withColumn("age_group", 
    when($"age" < 30, "Young")
    .when($"age" < 40, "Middle")
    .otherwise("Senior"))
    .withColumn("salary_percentile", 
      percent_rank().over(Window.orderBy($"salary")))
}

// Sampling and partitioning
val sample = ds.sample(0.5) // 50% sample
val stratified = enrichedData.stat.sampleBy("age_group", Map("Young" -> 0.3, "Middle" -> 0.5, "Senior" -> 0.7), 123L) // ds itself has no age_group column
val repartitioned = ds.repartition(4, $"age")
val coalesced = ds.coalesce(2)

// Null handling
val noNulls = ds.na.drop() // drop rows containing any nulls
val filled = ds.na.fill(Map("salary" -> 0.0, "name" -> "Unknown"))
val replaced = ds.na.replace("name", Map("Alice" -> "Alicia"))

// Statistical operations
val correlation = ds.stat.corr("age", "salary")
val covariance = ds.stat.cov("age", "salary") 
val crosstab = ds.stat.crosstab("age", "salary")
val freqItems = ds.stat.freqItems(Seq("name"), 0.4)

// Caching and persistence
import org.apache.spark.storage.StorageLevel
val cached = ds.cache()
val persisted = ds.persist(StorageLevel.MEMORY_AND_DISK_SER)

// Explain execution plan
ds.explain(true)
ds.explain("cost")

Working with Complex Data Types

import org.apache.spark.sql.functions._

// Array operations
case class PersonWithHobbies(name: String, age: Int, hobbies: Seq[String])
val peopleWithHobbies = Seq(
  PersonWithHobbies("Alice", 25, Seq("reading", "swimming")),
  PersonWithHobbies("Bob", 30, Seq("cooking", "hiking", "reading"))
)
val hobbiesDS = spark.createDataset(peopleWithHobbies)

val exploded = hobbiesDS.select($"name", explode($"hobbies").as("hobby"))
val arraySize = hobbiesDS.select($"name", size($"hobbies").as("hobby_count"))
val contains = hobbiesDS.filter(array_contains($"hobbies", "reading"))

// Struct operations
case class Address(street: String, city: String, zipCode: String)
case class PersonWithAddress(name: String, age: Int, address: Address)

val peopleWithAddr = Seq(
  PersonWithAddress("Alice", 25, Address("123 Main St", "Seattle", "98101")),
  PersonWithAddress("Bob", 30, Address("456 Oak Ave", "Portland", "97201"))
)
val addrDS = spark.createDataset(peopleWithAddr)

val cityOnly = addrDS.select($"name", $"address.city")
val flatStruct = addrDS.select($"name", $"age", $"address.*")

// Map operations  
val mapData = spark.range(3).select(
  map(lit("key1"), $"id", lit("key2"), $"id" * 2).as("map_col")
)
val keys = mapData.select(map_keys($"map_col"))
val values = mapData.select(map_values($"map_col"))
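
Single values can be pulled out of the map column above with getItem or element_at (the latter available since Spark 2.4):

val k1 = mapData.select($"map_col".getItem("key1").as("k1"))
val k2 = mapData.select(element_at($"map_col", "key2").as("k2"))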