Partitioning and Shuffling

Partitioning strategies and shuffle operations for controlling data distribution and optimizing performance across cluster nodes in Spark applications.

Capabilities

Partitioner

Abstract base class defining how elements are partitioned by key across cluster nodes.

/**
 * An object that defines how the elements in a key-value pair RDD are partitioned by key
 */
abstract class Partitioner extends Serializable {
  /** Return the number of partitions in this partitioner */
  def numPartitions: Int
  
  /** Return the partition id for the given key */
  def getPartition(key: Any): Int
  
  /** Test whether this partitioner is equal to another object */
  override def equals(other: Any): Boolean
  
  /** Return a hash code for this partitioner */
  override def hashCode: Int
}

// Built-in partitioners
class HashPartitioner(partitions: Int) extends Partitioner {
  require(partitions >= 0, s"Number of partitions ($partitions) cannot be negative.")
  
  def numPartitions: Int = partitions
  
  def getPartition(key: Any): Int = key match {
    case null => 0
    case _ => Utils.nonNegativeMod(key.hashCode, numPartitions)
  }
  
  override def equals(other: Any): Boolean = other match {
    case h: HashPartitioner => h.numPartitions == numPartitions
    case _ => false
  }
  
  override def hashCode: Int = numPartitions
}

class RangePartitioner[K : Ordering : ClassTag, V](
  partitions: Int,
  rdd: RDD[(K, V)],
  ascending: Boolean = true) extends Partitioner {
  
  def numPartitions: Int = partitions
  
  def getPartition(key: Any): Int = {
    val k = key.asInstanceOf[K]
    // Binary search over the sampled range bounds to find the partition
    // (implementation elided)
    ???
  }
}
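
RangePartitioner samples the input to choose split points, so each output partition covers a contiguous key range; sortByKey relies on this. A small check, assuming a live SparkContext named sc like the one constructed in the usage examples below:

val pairs = sc.parallelize(1 to 100).map(x => (x, x))
val ranged = pairs.partitionBy(new RangePartitioner(4, pairs))
// Each partition covers a contiguous key range (records within it stay unsorted)
ranged.glom().collect().zipWithIndex.foreach { case (p, i) =>
  if (p.nonEmpty) println(s"Partition $i: keys ${p.map(_._1).min}..${p.map(_._1).max}")
}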

Partitioning Methods (RDD)

Methods for controlling RDD partitioning and data distribution.

// Partitioning methods available on RDDs
def partitions: Array[Partition]
def getNumPartitions: Int
def partitioner: Option[Partitioner]

// Repartitioning methods
def repartition(numPartitions: Int): RDD[T]
def coalesce(numPartitions: Int, shuffle: Boolean = false): RDD[T]

// Key-value partitioning (available on pair RDDs)
def partitionBy(partitioner: Partitioner): RDD[(K, V)]
def repartitionAndSortWithinPartitions(partitioner: Partitioner): RDD[(K, V)]

// Partition-level operations
def mapPartitions[U: ClassTag](f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false): RDD[U]
def mapPartitionsWithIndex[U: ClassTag](f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U]
def foreachPartition(f: Iterator[T] => Unit): Unit
def glom(): RDD[Array[T]]

// Partition sampling and inspection
def sample(withReplacement: Boolean, fraction: Double, seed: Long = Utils.random.nextLong): RDD[T]
def takeSample(withReplacement: Boolean, num: Int, seed: Long = Utils.random.nextLong): Array[T]

Usage Examples:

import org.apache.spark.{SparkContext, SparkConf, HashPartitioner, RangePartitioner}

val sc = new SparkContext(new SparkConf().setAppName("Partitioning Example").setMaster("local[*]"))

// Create sample data
val data = sc.parallelize(1 to 1000, numSlices = 4) // 4 initial partitions
val keyValueData = data.map(x => (x % 10, x)) // (key, value) pairs

// Check current partitioning
println(s"Number of partitions: ${keyValueData.getNumPartitions}")
println(s"Current partitioner: ${keyValueData.partitioner}")

// Hash partitioning
val hashPartitioned = keyValueData.partitionBy(new HashPartitioner(8))
println(s"Hash partitioned: ${hashPartitioned.getNumPartitions} partitions")

// Range partitioning
val rangePartitioned = keyValueData.partitionBy(new RangePartitioner(6, keyValueData))
println(s"Range partitioned: ${rangePartitioned.getNumPartitions} partitions")

// Repartitioning vs coalescing
val repartitioned = data.repartition(8) // Triggers shuffle
val coalesced = data.coalesce(2) // May avoid shuffle
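
// coalesce(n, shuffle = true) is equivalent to repartition(n); with the
// default shuffle = false it only merges co-located partitions, which is why
// it may avoid a shuffle but also cannot increase the partition count
val rebalanced = data.coalesce(8, shuffle = true)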

// Inspect partition contents
val partitionContents = keyValueData.glom().collect()
partitionContents.zipWithIndex.foreach { case (partition, index) =>
  println(s"Partition $index: ${partition.take(5).mkString(", ")}...")
}

// Process partitions individually
val partitionSums = data.mapPartitions { iterator =>
  val sum = iterator.sum
  Iterator(sum)
}

Custom Partitioners

Creating custom partitioning strategies for specific use cases.

// Example custom partitioners
class DomainPartitioner(partitions: Int) extends Partitioner {
  // A constructor parameter named numPartitions would collide with the
  // method below, so the parameter gets a distinct name
  require(partitions >= 4, "DomainPartitioner routes keys to 4 fixed buckets")

  override def numPartitions: Int = partitions

  override def getPartition(key: Any): Int = {
    // Route keys by top-level domain; anything unrecognized lands in partition 3
    val domain = key.toString
    domain match {
      case d if d.endsWith(".com") => 0
      case d if d.endsWith(".org") => 1
      case d if d.endsWith(".edu") => 2
      case _ => 3
    }
  }
}

class GeographicPartitioner(regions: Map[String, Int]) extends Partitioner {
  override def numPartitions: Int = regions.size

  override def getPartition(key: Any): Int = {
    val location = key.toString
    // First two characters as the region code; take is safe on short strings
    regions.getOrElse(location.take(2), 0)
  }
}

class CustomRangePartitioner[K](
  ranges: Array[K],
  partitions: Int)(implicit ord: Ordering[K]) extends Partitioner {

  override def numPartitions: Int = partitions

  override def getPartition(key: Any): Int = {
    val k = key.asInstanceOf[K]
    // ranges holds ascending lower bounds: pick the last bound <= key,
    // and send keys below the first bound to partition 0
    val partition = ranges.lastIndexWhere(ord.gteq(k, _))
    if (partition < 0) 0 else partition
  }
}

Custom Partitioner Examples:

// Domain-based partitioning for web logs
val webLogs = sc.textFile("web-logs.txt")
  .map(parseLogEntry) // Returns (domain, logEntry)
  .partitionBy(new DomainPartitioner(4))

// Geographic partitioning for location data
val regionMap = Map("US" -> 0, "EU" -> 1, "AS" -> 2, "OT" -> 3)
val locationData = sc.textFile("locations.txt")
  .map(parseLocation) // Returns (country, locationInfo)
  .partitionBy(new GeographicPartitioner(regionMap))

// Custom range partitioning for time series data
val timeRanges = Array(
  "2023-01-01", "2023-04-01", "2023-07-01", "2023-10-01"
)
val timeSeriesData = sc.textFile("timeseries.txt")
  .map(parseTimeEntry) // Returns (date, data)
  .partitionBy(new CustomRangePartitioner(timeRanges, 4))

Shuffle Operations

Operations that trigger data shuffling across the cluster.

// Operations that typically trigger shuffles:

// Repartitioning operations
def repartition(numPartitions: Int): RDD[T] // Always shuffles
def partitionBy(partitioner: Partitioner): RDD[(K, V)] // Shuffles if different partitioner

// Aggregation operations (on key-value RDDs)
def groupByKey(): RDD[(K, Iterable[V])] // Shuffles
def reduceByKey(func: (V, V) => V): RDD[(K, V)] // Shuffles but with pre-aggregation
def aggregateByKey[U](zeroValue: U)(seqOp: (U, V) => U, combOp: (U, U) => U): RDD[(K, U)] // Shuffles
def combineByKey[C](createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiners: (C, C) => C): RDD[(K, C)] // Shuffles

// Join operations
def join[W](other: RDD[(K, W)]): RDD[(K, (V, W))] // Shuffles both RDDs
def leftOuterJoin[W](other: RDD[(K, W)]): RDD[(K, (V, Option[W]))] // Shuffles
def rightOuterJoin[W](other: RDD[(K, W)]): RDD[(K, (Option[V], W))] // Shuffles
def fullOuterJoin[W](other: RDD[(K, W)]): RDD[(K, (Option[V], Option[W]))] // Shuffles
def cogroup[W](other: RDD[(K, W)]): RDD[(K, (Iterable[V], Iterable[W]))] // Shuffles

// Sorting operations
def sortByKey(ascending: Boolean = true): RDD[(K, V)] // Shuffles
def sortBy[K](f: T => K, ascending: Boolean = true): RDD[T] // Shuffles

// Set operations
def distinct(): RDD[T] // Shuffles
def intersection(other: RDD[T]): RDD[T] // Shuffles
def subtract(other: RDD[T]): RDD[T] // Shuffles
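
Just as important as knowing which operations shuffle is knowing which transformations preserve an existing partitioner: mapValues and filter keep it, while map and flatMap (which may change keys) drop it, forcing the next by-key operation to shuffle again. A quick check, reusing hashPartitioned from the usage examples above:

// mapValues cannot change keys, so the HashPartitioner survives
println(hashPartitioned.mapValues(_ * 2).partitioner.isDefined) // true

// map could change keys, so Spark conservatively drops the partitioner
println(hashPartitioned.map { case (k, v) => (k, v * 2) }.partitioner.isDefined) // false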

Performance Optimization Strategies

Minimizing Shuffles

// INEFFICIENT: Multiple shuffles
val inefficientPipeline = data
  .groupByKey() // Shuffle 1
  .mapValues(_.sum) // No shuffle
  .filter(_._2 > 100) // No shuffle
  .sortByKey() // Shuffle 2

// EFFICIENT: Fewer shuffles with better operations
val efficientPipeline = data
  .reduceByKey(_ + _) // Shuffle 1 with pre-aggregation
  .filter(_._2 > 100) // No shuffle
  .sortByKey() // Shuffle 2

// ALTERNATIVE: repartitionAndSortWithinPartitions folds the sort into the
// shuffle machinery itself, which is cheaper than repartitioning and then
// sorting separately. Note that with a HashPartitioner the data is sorted
// within each partition only, not totally ordered like sortByKey's output.
val sortedWithinPartitions = data
  .reduceByKey(_ + _) // Shuffle 1 with pre-aggregation
  .filter(_._2 > 100) // No shuffle
  .repartitionAndSortWithinPartitions(new HashPartitioner(8)) // Shuffle 2: repartition + sort in one pass
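
To verify where the shuffles actually fall, RDD.toDebugString prints the lineage, with indentation steps at shuffle (stage) boundaries:

// Expect two shuffle boundaries here: one for reduceByKey, one for sortByKey
println(efficientPipeline.toDebugString)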

Partition-Aware Operations

// Leverage partitioning for efficient joins
val users = sc.textFile("users.txt")
  .map(parseUser) // (userId, userInfo)
  .partitionBy(new HashPartitioner(8))
  .cache() // Cache partitioned data

val orders = sc.textFile("orders.txt")
  .map(parseOrder) // (userId, orderInfo)
  .partitionBy(new HashPartitioner(8)) // Same partitioner as users

// Efficient join - no shuffle needed since both RDDs have same partitioner
val userOrders = users.join(orders)

// Efficient grouped operations
val userStats = orders
  .mapValues(order => (1, order.amount)) // (userId, (count, amount))
  .reduceByKey((a, b) => (a._1 + b._1, a._2 + b._2)) // Pre-aggregated
  .mapValues(stats => (stats._1, stats._2 / stats._1)) // (count, avgAmount)
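
A quick way to confirm the join stayed shuffle-free, as a sketch: both inputs should report the same partitioner, and the join's lineage should add no new ShuffledRDD.

// Both sides carry the same partitioner, so join() builds narrow
// (one-to-one) dependencies rather than a new shuffle
assert(users.partitioner == orders.partitioner)
println(userOrders.toDebugString) // no new ShuffledRDD at the join step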

Advanced Partitioning Patterns

// Multi-level partitioning for hierarchical data
class HierarchicalPartitioner(
  primaryPartitions: Int,
  secondaryPartitions: Int) extends Partitioner {
  
  override def numPartitions: Int = primaryPartitions * secondaryPartitions
  
  override def getPartition(key: Any): Int = {
    val (primary, secondary) = key.asInstanceOf[(String, String)]
    val primaryPartition = math.abs(primary.hashCode) % primaryPartitions
    val secondaryPartition = math.abs(secondary.hashCode) % secondaryPartitions
    primaryPartition * secondaryPartitions + secondaryPartition
  }
}

// Usage for multi-dimensional data
val hierarchicalData = sc.textFile("hierarchical.txt")
  .map(parseHierarchical) // Returns ((region, category), data)
  .partitionBy(new HierarchicalPartitioner(4, 4)) // 16 total partitions

// Skew-aware partitioning (K: ClassTag is needed for the map over sampled keys;
// the constructor parameter is named partitions to avoid colliding with the
// numPartitions method)
class SkewAwarePartitioner[K: ClassTag](
  rdd: RDD[(K, _)],
  partitions: Int,
  sampleFraction: Double = 0.1) extends Partitioner {

  require(partitions >= 2, "Need at least 2 partitions to separate heavy keys")

  // Sample data on the driver at construction time to identify skewed keys
  private val keyFrequencies = rdd.sample(false, sampleFraction)
    .map(_._1)
    .countByValue()

  // A key is "heavy" if it appears more than twice the per-partition average
  private val heavyKeys = keyFrequencies
    .filter(_._2 > keyFrequencies.values.sum / partitions * 2)
    .keySet

  override def numPartitions: Int = partitions

  override def getPartition(key: Any): Int = {
    val k = key.asInstanceOf[K]
    if (heavyKeys.contains(k)) {
      // Isolate heavy keys in the lower half of the partition space so they
      // never share a partition with normal keys (each heavy key still lands
      // on a single partition; splitting one key would break by-key semantics)
      (k.hashCode() & Integer.MAX_VALUE) % (partitions / 2)
    } else {
      // Normal keys hash into the upper half
      ((k.hashCode() & Integer.MAX_VALUE) % (partitions / 2)) + (partitions / 2)
    }
  }
}
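
A usage sketch, assuming the keyValueData pair RDD from the earlier examples; note that the sampling job runs once, on the driver, when the partitioner is constructed:

val skewAware = new SkewAwarePartitioner(keyValueData, 8)
val balanced = keyValueData.partitionBy(skewAware)
println(s"Balanced into ${balanced.getNumPartitions} partitions")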

Monitoring Shuffle Performance

// Monitor active stages via the status tracker. SparkStageInfo exposes task
// counts; per-stage shuffle read/write bytes and records live in the web UI
// and the REST API (/api/v1/applications/<app-id>/stages), not on this class.
def analyzeStageProgress(sc: SparkContext): Unit = {
  val statusTracker = sc.statusTracker

  statusTracker.getActiveStageIds().foreach { stageId =>
    statusTracker.getStageInfo(stageId) match {
      case Some(stageInfo) =>
        println(s"Stage $stageId (${stageInfo.name}):")
        println(s"  Tasks: ${stageInfo.numTasks}")
        println(s"  Active tasks: ${stageInfo.numActiveTasks}")
        println(s"  Completed tasks: ${stageInfo.numCompletedTasks}")
        println(s"  Failed tasks: ${stageInfo.numFailedTasks}")
      case None =>
        println(s"No info available for stage $stageId")
    }
  }
}

// Coarse wall-clock timing around a shuffle-producing transformation
def trackShuffleOperations[T](rdd: RDD[T], operationName: String): RDD[T] = {
  val startTime = System.currentTimeMillis()
  val result = rdd.cache() // Lazy: marks the RDD for caching, runs nothing yet
  result.count() // Action: materializes (and caches) the result
  val endTime = System.currentTimeMillis()
  
  println(s"Operation '$operationName' took ${endTime - startTime}ms")
  println(s"Result partitions: ${result.getNumPartitions}")
  
  result
}

// Usage
val shuffledData = trackShuffleOperations(
  originalData.groupByKey(),
  "groupByKey operation"
)

Advanced Optimization Techniques

Broadcast Hash Joins

// Optimize small table joins using broadcast
// ClassTag bounds on K and W are required by the implicit pair-RDD
// conversion behind collectAsMap()
def broadcastHashJoin[K: ClassTag, V, W: ClassTag](
  largeRDD: RDD[(K, V)],
  smallRDD: RDD[(K, W)])(implicit sc: SparkContext): RDD[(K, (V, W))] = {
  
  // Collect small RDD to driver and broadcast
  val smallMap = smallRDD.collectAsMap()
  val broadcastSmallMap = sc.broadcast(smallMap)
  
  // Perform map-side join
  largeRDD.flatMap { case (key, value) =>
    broadcastSmallMap.value.get(key) match {
      case Some(smallValue) => Some((key, (value, smallValue)))
      case None => None
    }
  }
}

// Usage
val largeTx = sc.textFile("large-transactions.txt").map(parseTx)
val smallLookup = sc.textFile("small-lookup.txt").map(parseLookup)

val joined = broadcastHashJoin(largeTx, smallLookup)

Dynamic Partition Pruning

// Prune partitions based on a predicate over partition ids. Spark ships a
// built-in PartitionPruningRDD whose PruneDependency ensures only the
// selected parent partitions are scheduled and computed.
import org.apache.spark.rdd.PartitionPruningRDD

def prunePartitions[T](
  rdd: RDD[T],
  partitionPredicate: Int => Boolean): RDD[T] =
  PartitionPruningRDD.create(rdd, partitionPredicate)

// Example: keep only even-numbered partitions
val evenPartitions = prunePartitions(data, _ % 2 == 0)

// Note: a hand-rolled RDD subclass that returns a filtered partition array
// and delegates compute() to the parent breaks the RDD contract: partition
// indices must be consecutive, and a one-to-one dependency would map pruned
// partitions to the wrong parent partitions. PruneDependency (private to
// Spark) handles that remapping, which is why the built-in class is the
// right tool here.

Partition-Level Caching Strategy

// Smart caching based on partition access patterns
class SmartCacheManager[T](rdd: RDD[T]) {
  private val accessCounts = Array.fill(rdd.getNumPartitions)(0L)
  private var cachedPartitions = Set.empty[Int]
  
  def accessPartition(partitionId: Int): Unit = {
    accessCounts(partitionId) += 1
    
    // Cache frequently accessed partitions
    if (accessCounts(partitionId) > 10 && !cachedPartitions.contains(partitionId)) {
      cachePartition(partitionId)
    }
  }
  
  private def cachePartition(partitionId: Int): Unit = {
    // Implementation would use custom caching logic
    cachedPartitions += partitionId
    println(s"Cached partition $partitionId")
  }
  
  def getAccessStats: Array[Long] = accessCounts.clone()
}
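
A hypothetical driver-side walkthrough of the bookkeeping API (the names and threshold are illustrative, not a Spark facility):

val cacheManager = new SmartCacheManager(keyValueData)
(1 to 12).foreach(_ => cacheManager.accessPartition(0)) // crosses the >10 threshold
println(cacheManager.getAccessStats.mkString(", "))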

Best Practices

Choosing the Right Partitioner

// Guidelines for partitioner selection
// Ordering and ClassTag bounds on K are required by RangePartitioner
def choosePartitioner[K: Ordering: ClassTag, V](
  rdd: RDD[(K, V)],
  operationType: String,
  dataCharacteristics: Map[String, Any]): Partitioner = {
  
  val numPartitions = dataCharacteristics.getOrElse("partitions", 200).asInstanceOf[Int]
  val isSkewed = dataCharacteristics.getOrElse("skewed", false).asInstanceOf[Boolean]
  val isSorted = dataCharacteristics.getOrElse("sorted", false).asInstanceOf[Boolean]
  
  (operationType, isSkewed, isSorted) match {
    case ("join", false, _) => new HashPartitioner(numPartitions)
    case ("join", true, _) => new SkewAwarePartitioner(rdd, numPartitions)
    case ("sort", _, false) => new RangePartitioner(numPartitions, rdd)
    case ("sort", _, true) => new HashPartitioner(numPartitions) // Already sorted
    case ("groupBy", false, _) => new HashPartitioner(numPartitions)
    case ("groupBy", true, _) => new SkewAwarePartitioner(rdd, numPartitions)
    case _ => new HashPartitioner(numPartitions) // Default
  }
}
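
A hypothetical call with the characteristics map spelled out:

val chosen = choosePartitioner(keyValueData, "join",
  Map("partitions" -> 8, "skewed" -> false, "sorted" -> false))
println(chosen.numPartitions) // 8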

Partition Size Guidelines

// Calculate optimal partition count
def calculateOptimalPartitions(
  dataSize: Long,
  targetPartitionSize: Long = 128 * 1024 * 1024, // 128MB
  maxPartitions: Int = 2000): Int = {
  
  val calculatedPartitions = (dataSize / targetPartitionSize).toInt
  // Driver-local core count as a rough stand-in; on a real cluster prefer
  // the total executor cores (e.g. sc.defaultParallelism)
  val cores = Runtime.getRuntime.availableProcessors()
  val minPartitions = cores * 2 // At least 2 partitions per core
  
  math.min(maxPartitions, math.max(minPartitions, calculatedPartitions))
}
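
// Hypothetical usage: ~10 GB of input at the default 128 MB target size
// ("big-input.txt" is an illustrative path)
val optimalParts = calculateOptimalPartitions(10L * 1024 * 1024 * 1024)
val tuned = sc.textFile("big-input.txt", minPartitions = optimalParts)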

// Monitor partition sizes
def analyzePartitionSizes[T](rdd: RDD[T]): Unit = {
  val partitionSizes = rdd.mapPartitions { iter =>
    Iterator(iter.size)
  }.collect()
  
  val avgSize = partitionSizes.sum.toDouble / partitionSizes.length
  val maxSize = partitionSizes.max
  val minSize = partitionSizes.min
  val skewRatio = maxSize.toDouble / avgSize
  
  println(s"Partition Analysis:")
  println(s"  Count: ${partitionSizes.length}")
  println(s"  Average size: $avgSize")
  println(s"  Max size: $maxSize")
  println(s"  Min size: $minSize")
  println(s"  Skew ratio: $skewRatio")
  
  if (skewRatio > 2.0) {
    println("  WARNING: High partition skew detected!")
  }
}
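
A quick check against the hash-partitioned RDD from the usage examples above:

analyzePartitionSizes(hashPartitioned)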

Memory-Efficient Partitioning

// Memory-aware partition processing
// The ClassTag bound on U is required by mapPartitions
def processPartitionsWithMemoryControl[T, U: ClassTag](
  rdd: RDD[T],
  processFunc: Iterator[T] => Iterator[U],
  maxMemoryPerPartition: Long = 512L * 1024 * 1024): RDD[U] = { // default 512 MB
  
  rdd.mapPartitions { partition =>
    val runtime = Runtime.getRuntime
    val initialMemory = runtime.totalMemory() - runtime.freeMemory()
    
    val bufferedPartition = partition.grouped(1000) // Process in batches
    
    bufferedPartition.flatMap { batch =>
      val currentMemory = runtime.totalMemory() - runtime.freeMemory()
      
      if (currentMemory - initialMemory > maxMemoryPerPartition) {
        System.gc() // Suggest garbage collection
        Thread.sleep(100) // Allow GC to run
      }
      
      processFunc(batch.iterator)
    }
  }
}
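
A usage sketch with a hypothetical input path and a trivial per-batch transform:

val lineLengths = processPartitionsWithMemoryControl[String, Int](
  sc.textFile("big-input.txt"), // illustrative path
  batch => batch.map(_.length))
println(lineLengths.take(5).mkString(", "))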