Data partitioning strategies for controlling how RDD elements are distributed across cluster nodes to optimize performance and minimize data shuffling in distributed operations.
Abstract base class defining how keys are distributed across partitions in key-value RDDs.
/**
* Object that defines how keys are distributed across partitions
*/
abstract class Partitioner extends Serializable {
/** Number of partitions */
def numPartitions: Int
/** Get partition index for given key */
def getPartition(key: Any): Int
/** Whether another partitioner partitions keys the same way; Spark uses this to decide if a shuffle can be skipped */
def equals(other: Any): Boolean
/** Hash code for this partitioner */
def hashCode(): Int
}
Default partitioner using a hash function to distribute keys across partitions.
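The class below delegates the index computation to Utils.nonNegativeMod, which keeps results in [0, numPartitions) even for negative hash codes. A minimal sketch of that helper, written here for illustration (Spark's real version lives in org.apache.spark.util.Utils):
object Utils {
  /** Modulo that maps any Int into [0, mod); plain % can return negative values */
  def nonNegativeMod(x: Int, mod: Int): Int = {
    val rawMod = x % mod
    rawMod + (if (rawMod < 0) mod else 0)
  }
}
For example, nonNegativeMod(-7, 4) yields 1, whereas -7 % 4 yields -3.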
/**
* Partitioner that partitions using Java object hashCode
* @param partitions number of partitions
*/
class HashPartitioner(partitions: Int) extends Partitioner {
require(partitions >= 0, s"Number of partitions ($partitions) cannot be negative.")
override def numPartitions: Int = partitions
override def getPartition(key: Any): Int = key match {
case null => 0
case _ => Utils.nonNegativeMod(key.hashCode, numPartitions)
}
override def equals(other: Any): Boolean = other match {
case h: HashPartitioner => h.numPartitions == numPartitions
case _ => false
}
override def hashCode: Int = numPartitions
}
Partitioner that distributes keys roughly evenly across partitions based on key ranges.
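For intuition before the full definition: n range bounds define n + 1 partitions, and a key lands in the first partition whose upper bound is greater than or equal to it. A toy sketch with hypothetical bounds:
val bounds = Array(10, 20) // two bounds => three partitions
def toPartition(k: Int): Int = bounds.indexWhere(k <= _) match {
  case -1 => bounds.length // greater than every bound => last partition
  case i => i
}
assert(Seq(5, 15, 25).map(toPartition) == Seq(0, 1, 2))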
/**
* Partitioner that partitions sortable records by range into roughly equal ranges
* @param partitions number of partitions
* @param rdd RDD to sample for determining ranges
* @param ascending whether to sort keys in ascending order
*/
class RangePartitioner[K: Ordering: ClassTag, V](
partitions: Int,
rdd: RDD[_ <: Product2[K, V]],
private var ascending: Boolean = true,
val samplePointsPerPartitionHint: Int = 20
) extends Partitioner {
private var ordering = implicitly[Ordering[K]]
// Binary-search helper (from org.apache.spark.util.CollectionsUtils)
private var binarySearch: ((Array[K], K) => Int) = CollectionsUtils.makeBinarySearch[K]
// One more partition than bounds; can be fewer than requested when the sample yields few distinct keys
override def numPartitions: Int = rangeBounds.length + 1
override def getPartition(key: Any): Int = {
val k = key.asInstanceOf[K]
var partition = 0
if (rangeBounds.length <= 128) {
// With 128 or fewer range bounds, a linear scan is faster than binary search
while (partition < rangeBounds.length && ordering.gt(k, rangeBounds(partition))) {
partition += 1
}
} else {
// Use binary search for larger partition counts
partition = binarySearch(rangeBounds, k)
if (partition < 0) {
partition = -partition - 1
}
if (partition > rangeBounds.length) {
partition = rangeBounds.length
}
}
if (ascending) {
partition
} else {
rangeBounds.length - partition
}
}
override def equals(other: Any): Boolean = other match {
case r: RangePartitioner[_, _] =>
r.rangeBounds.sameElements(rangeBounds) && r.ascending == ascending
case _ =>
false
}
override def hashCode(): Int = {
val prime = 31
var result = 1
var i = 0
while (i < rangeBounds.length) {
result = prime * result + rangeBounds(i).hashCode
i += 1
}
result = prime * result + ascending.hashCode
result
}
/** Range bounds array, normally computed by reservoir-sampling the RDD (computation elided here) */
def rangeBounds: Array[K] = ???
}
Create domain-specific partitioners for specialized data distribution patterns.
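One caveat before the examples: they omit equals and hashCode for brevity, but Spark compares partitioners for equality to decide whether two RDDs are co-partitioned and a shuffle can be skipped, so a production partitioner should override both. A hedged template (FirstLetterPartitioner is a hypothetical name):
import org.apache.spark.Partitioner

class FirstLetterPartitioner(override val numPartitions: Int) extends Partitioner {
  override def getPartition(key: Any): Int = key match {
    case s: String if s.nonEmpty => math.abs(s.charAt(0).toLower - 'a') % numPartitions
    case _ => 0
  }
  // Partitioners that compare equal let Spark treat RDDs as co-partitioned
  override def equals(other: Any): Boolean = other match {
    case p: FirstLetterPartitioner => p.numPartitions == numPartitions
    case _ => false
  }
  override def hashCode(): Int = numPartitions
}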
/**
* Example: Partitioner for geographic data
*/
class GeographicPartitioner(regions: Array[String]) extends Partitioner {
private val regionToIndex = regions.zipWithIndex.toMap
override def numPartitions: Int = regions.length
override def getPartition(key: Any): Int = key match {
case location: String =>
// Extract region from location string
val region = extractRegion(location)
regionToIndex.getOrElse(region, 0)
case _ => 0
}
private def extractRegion(location: String): String = {
// Custom logic to determine region from location
if (location.contains("US")) "North America"
else if (location.contains("EU")) "Europe"
else if (location.contains("AS")) "Asia"
else "Other"
}
}
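// Hypothetical usage sketch: co-locate records from the same region before a
// per-region aggregation (the RDD element type is assumed for illustration).
def partitionByRegion(events: RDD[(String, Int)]): RDD[(String, Int)] =
  events.partitionBy(new GeographicPartitioner(Array("North America", "Europe", "Asia", "Other")))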
/**
* Example: Partitioner for time-series data
*/
class TimePartitioner(timeRanges: Array[(Long, Long)]) extends Partitioner {
override def numPartitions: Int = timeRanges.length
override def getPartition(key: Any): Int = key match {
case timestamp: Long =>
timeRanges.zipWithIndex.find { case ((start, end), _) =>
timestamp >= start && timestamp < end
}.map(_._2).getOrElse(0)
case _ => 0
}
}
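// Hypothetical setup for TimePartitioner: 24 half-open [start, end) hourly windows
// covering one day of epoch-millisecond timestamps (the dayStart value is assumed).
val hourMs = 3600L * 1000
val dayStart = 1735689600000L // 2025-01-01 00:00:00 UTC
val hourlyRanges = Array.tabulate(24)(h => (dayStart + h * hourMs, dayStart + (h + 1) * hourMs))
val byHour = new TimePartitioner(hourlyRanges)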
/**
* Example: Partitioner based on key prefix
*/
class PrefixPartitioner(prefixes: Array[String]) extends Partitioner {
override def numPartitions: Int = prefixes.length
override def getPartition(key: Any): Int = key match {
case str: String =>
prefixes.zipWithIndex
.find { case (prefix, _) => str.startsWith(prefix) }
.map(_._2)
.getOrElse(0)
case _ => 0
}
}
Methods for controlling and querying RDD partitioning.
abstract class RDD[T: ClassTag] {
/** Get the partitioner if this RDD has one */
def partitioner: Option[Partitioner]
/** Get array of partitions */
def partitions: Array[Partition]
/** Repartition RDD using hash partitioner */
def repartition(numPartitions: Int): RDD[T]
/** Coalesce to fewer partitions without shuffling */
def coalesce(numPartitions: Int, shuffle: Boolean = false): RDD[T]
/** Get preferred locations for each partition */
def preferredLocations(split: Partition): Seq[String]
}
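A note on the two resizing methods above: repartition(n) always shuffles, while coalesce(n) with shuffle = false only merges existing partitions and therefore cannot increase the count. A small sketch (assuming an active SparkContext sc):
val resized = sc.parallelize(1 to 1000, 8)
println(resized.coalesce(2).partitions.length) // 2: partitions merged without a shuffle
println(resized.coalesce(16).partitions.length) // still 8: cannot grow without shuffling
println(resized.repartition(16).partitions.length) // 16: repartition(n) is coalesce(n, shuffle = true)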
class PairRDDFunctions[K, V](self: RDD[(K, V)]) {
/** Partition RDD using specified partitioner */
def partitionBy(partitioner: Partitioner): RDD[(K, V)]
/** Group by key with custom partitioner */
def groupByKey(partitioner: Partitioner): RDD[(K, Iterable[V])]
/** Reduce by key with custom partitioner */
def reduceByKey(partitioner: Partitioner, func: (V, V) => V): RDD[(K, V)]
/** Join with custom partitioner */
def join[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, W))]
/** Sort by key with custom partitioner */
def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.length)
(implicit ord: Ordering[K], ctag: ClassTag[K]): RDD[(K, V)]
}
Classes providing metadata about RDD partitions.
/**
* Identifier for a partition in an RDD
*/
trait Partition extends Serializable {
/** Partition index within its parent RDD */
def index: Int
/** Hash code based on index */
override def hashCode(): Int = index
/** Equality based on index */
override def equals(other: Any): Boolean = other match {
case that: Partition => this.index == that.index
case _ => false
}
}
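Partition objects are mostly internal, but their indices surface through the public API. A quick way to inspect them (assuming an active SparkContext sc):
val sample = sc.parallelize(1 to 10, 4)
println(sample.partitions.map(_.index).mkString(", ")) // 0, 1, 2, 3
// Show which elements landed in each partition
sample.mapPartitionsWithIndex { (idx, iter) =>
  Iterator(s"partition $idx: ${iter.mkString(",")}")
}.collect().foreach(println)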
/**
* Partition for HadoopRDD
*/
class HadoopPartition(rddId: Int, override val index: Int, val inputSplit: InputSplit)
extends Partition {
// (Spark itself wraps the non-serializable InputSplit in a SerializableWritable)
override def hashCode(): Int = 41 * (41 + rddId) + index
}
/**
* Partition created from a range
*/
case class ParallelCollectionPartition[T: ClassTag](
override val index: Int,
start: Int,
end: Int,
values: Seq[T]
) extends Partition
Usage Examples:
import org.apache.spark.{SparkContext, SparkConf, Partitioner, HashPartitioner, RangePartitioner}
val sc = new SparkContext(new SparkConf().setAppName("Partitioning Example"))
// Create pair RDD
val data = sc.parallelize(Array(
("apple", 1), ("banana", 2), ("cherry", 3), ("apple", 4),
("banana", 5), ("date", 6), ("elderberry", 7), ("apple", 8)
))
println(s"Default partitions: ${data.partitions.length}")
println(s"Default partitioner: ${data.partitioner}")
// Hash partitioning
val hashPartitioned = data.partitionBy(new HashPartitioner(4))
println(s"Hash partitioner: ${hashPartitioned.partitioner}")
// Verify partitioning is preserved through transformations
val grouped = hashPartitioned.groupByKey() // No shuffle needed!
println(s"Grouped partitioner: ${grouped.partitioner}")
// Range partitioning for sortable keys
val numbers = sc.parallelize(Array(
(1, "one"), (5, "five"), (3, "three"), (9, "nine"),
(2, "two"), (7, "seven"), (4, "four"), (8, "eight")
))
val rangePartitioned = numbers.partitionBy(new RangePartitioner(3, numbers))
println(s"Range partitioner: ${rangePartitioned.partitioner}")
// Custom partitioner example
class EvenOddPartitioner extends Partitioner {
override def numPartitions: Int = 2
override def getPartition(key: Any): Int = key match {
case i: Int => if (i % 2 == 0) 0 else 1
case _ => 0
}
}
val evenOddPartitioned = numbers.partitionBy(new EvenOddPartitioner())
println("Even/Odd partitioning:")
evenOddPartitioned.glom().collect().zipWithIndex.foreach { case (partition, index) =>
println(s"Partition $index: ${partition.mkString(", ")}")
}
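// Tip (general Spark guidance rather than part of this example): partitionBy
// shuffles data, so cache the result when it will be reused across actions.
evenOddPartitioned.cache()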
// Coalescing partitions
val manyPartitions = sc.parallelize(1 to 100, 20)
println(s"Many partitions: ${manyPartitions.partitions.length}")
val coalesced = manyPartitions.coalesce(5)
println(s"Coalesced partitions: ${coalesced.partitions.length}")
// Repartitioning
val repartitioned = manyPartitions.repartition(8)
println(s"Repartitioned: ${repartitioned.partitions.length}")
// Partition-aware operations
val partitionedData = sc.parallelize(Array(
("user1", "data1"), ("user2", "data2"), ("user1", "data3"), ("user3", "data4")
), 2).partitionBy(new HashPartitioner(2))
// This join will not cause a shuffle since both RDDs use same partitioner
val otherData = sc.parallelize(Array(
("user1", "profile1"), ("user2", "profile2"), ("user3", "profile3")
)).partitionBy(new HashPartitioner(2))
val joined = partitionedData.join(otherData) // No shuffle!
println("Joined data:")
joined.collect().foreach(println)
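// To confirm the join avoided a shuffle, inspect the lineage; the output
// format varies across Spark versions.
println(joined.toDebugString)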
sc.stop()
Java Examples:
import org.apache.spark.HashPartitioner;
import org.apache.spark.Partitioner;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import scala.Tuple2;
import java.util.Arrays;
import java.util.List;
JavaSparkContext sc = new JavaSparkContext(
new SparkConf().setAppName("Java Partitioning Example")
);
// Create pair RDD
List<Tuple2<String, Integer>> data = Arrays.asList(
new Tuple2<>("apple", 1),
new Tuple2<>("banana", 2),
new Tuple2<>("apple", 3)
);
JavaPairRDD<String, Integer> pairRDD = sc.parallelizePairs(data);
// Hash partitioning
JavaPairRDD<String, Integer> partitioned = pairRDD.partitionBy(new HashPartitioner(2));
// Verify partitioning
System.out.println("Partitioner: " + partitioned.partitioner());
// Custom partitioner in Java
class CustomPartitioner extends Partitioner {
@Override
public int numPartitions() {
return 2;
}
@Override
public int getPartition(Object key) {
return key.toString().length() % 2;
}
}
JavaPairRDD<String, Integer> customPartitioned = pairRDD.partitionBy(new CustomPartitioner());
sc.close();
Shuffle Avoidance and Common Pitfalls:
// Without partitioning - causes shuffle
val rdd1 = sc.parallelize(data1)
val rdd2 = sc.parallelize(data2)
val joined = rdd1.join(rdd2) // Shuffle occurs
// With partitioning - no shuffle
val partitioned1 = rdd1.partitionBy(new HashPartitioner(4))
val partitioned2 = rdd2.partitionBy(new HashPartitioner(4))
val coJoined = partitioned1.join(partitioned2) // No shuffle!
// Bad: Losing partitioning
val partitioned = rdd.partitionBy(new HashPartitioner(4))
val mappedBad = partitioned.map { case (k, v) => (k, v.toUpperCase) } // map drops the partitioner, even with keys unchanged
// Good: Preserving partitioning
val mappedGood = partitioned.mapValues(_.toUpperCase) // mapValues preserves partitioning
// Bad: Uneven partitions
val skewed = rdd.partitionBy(new BadPartitioner()) // Some partitions much larger
// Good: Balanced partitions
val balanced = rdd.partitionBy(new HashPartitioner(numPartitions))
Effective partitioning is crucial for Spark performance, reducing network overhead and enabling efficient distributed operations across your cluster.