Spark provides two types of shared variables for distributed computations: broadcast variables, which share read-only data efficiently across the cluster, and accumulators, which tasks can only add to and whose aggregated result is read back on the driver.
Broadcast variables allow efficient sharing of large read-only datasets across all nodes in a cluster.
abstract class Broadcast[T](id: Long) {
  def value: T
  def unpersist(): Unit
  def unpersist(blocking: Boolean): Unit
  def destroy(): Unit
  def id: Long
  def toString: String
}

// From SparkContext
def broadcast[T: ClassTag](value: T): Broadcast[T]
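A minimal sketch of the broadcast lifecycle, assuming an existing SparkContext `sc`; the `config` variable and its contents are illustrative only:

import org.apache.spark.broadcast.Broadcast

val config: Broadcast[Map[String, Int]] = sc.broadcast(Map("maxRetries" -> 3))

val retries = config.value("maxRetries") // read the value on the driver or inside a task

config.unpersist() // drop cached copies on the executors; the value is re-sent if used again
config.destroy()   // release all resources; the variable cannot be used after this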
The modern accumulator API, AccumulatorV2, provides type-safe, efficient aggregation across distributed computations.

abstract class AccumulatorV2[IN, OUT] {
  // Core Operations
  def isZero: Boolean
  def copy(): AccumulatorV2[IN, OUT]
  def reset(): Unit
  def add(v: IN): Unit
  def merge(other: AccumulatorV2[IN, OUT]): Unit
  def value: OUT

  // Metadata
  def name: Option[String]
  def id: Long
  def isRegistered: Boolean
}
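As a rough, hand-driven illustration of how this contract fits together (Spark performs these steps for you while a job runs): each task works on a copy(), calls add() locally, and the driver merge()s the per-task results.

import org.apache.spark.util.LongAccumulator

val driverAcc = new LongAccumulator   // the authoritative accumulator on the driver
val taskCopy1 = driverAcc.copy()      // each task gets its own working copy
val taskCopy2 = driverAcc.copy()
taskCopy1.add(3L)                     // tasks accumulate locally
taskCopy2.add(4L)
driverAcc.merge(taskCopy1)            // the driver merges completed task results
driverAcc.merge(taskCopy2)
println(driverAcc.value)              // 7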
class LongAccumulator extends AccumulatorV2[java.lang.Long, java.lang.Long] {
  def add(v: Long): Unit
  def add(v: java.lang.Long): Unit
  def count: Long
  def sum: Long
  def avg: Double
  def value: java.lang.Long

  // AccumulatorV2 implementation
  def isZero: Boolean
  def copy(): LongAccumulator
  def reset(): Unit
  def merge(other: AccumulatorV2[java.lang.Long, java.lang.Long]): Unit
}
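A quick sketch of the built-in statistics, assuming an existing SparkContext `sc`; the accumulator name and sample values are illustrative:

val amounts = sc.longAccumulator("Amounts")

sc.parallelize(Seq(10L, 20L, 30L)).foreach(v => amounts.add(v))

println(amounts.count) // 3  -- number of values added
println(amounts.sum)   // 60
println(amounts.avg)   // 20.0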
class DoubleAccumulator extends AccumulatorV2[java.lang.Double, java.lang.Double] {
  def add(v: Double): Unit
  def add(v: java.lang.Double): Unit
  def count: Long
  def sum: Double
  def avg: Double
  def value: java.lang.Double

  // AccumulatorV2 implementation
  def isZero: Boolean
  def copy(): DoubleAccumulator
  def reset(): Unit
  def merge(other: AccumulatorV2[java.lang.Double, java.lang.Double]): Unit
}

class CollectionAccumulator[T] extends AccumulatorV2[T, java.util.List[T]] {
  def add(v: T): Unit
  def value: java.util.List[T]

  // AccumulatorV2 implementation
  def isZero: Boolean
  def copy(): CollectionAccumulator[T]
  def reset(): Unit
  def merge(other: AccumulatorV2[T, java.util.List[T]]): Unit
}

// Long accumulators
def longAccumulator(): LongAccumulator
def longAccumulator(name: String): LongAccumulator
// Double accumulators
def doubleAccumulator(): DoubleAccumulator
def doubleAccumulator(name: String): DoubleAccumulator
// Collection accumulators
def collectionAccumulator[T](): CollectionAccumulator[T]
def collectionAccumulator[T](name: String): CollectionAccumulator[T]
// Custom accumulators
def register(acc: AccumulatorV2[_, _]): Unit
def register(acc: AccumulatorV2[_, _], name: String): Unit

Custom accumulator types can be created by extending AccumulatorV2.
// Example: Set accumulator for collecting unique values
import scala.collection.JavaConverters._
import scala.collection.mutable

import org.apache.spark.util.AccumulatorV2

class SetAccumulator[T] extends AccumulatorV2[T, java.util.Set[T]] {
  private val _set = mutable.Set.empty[T]

  def isZero: Boolean = _set.isEmpty

  def copy(): SetAccumulator[T] = {
    val newAcc = new SetAccumulator[T]
    newAcc._set ++= _set
    newAcc
  }

  def reset(): Unit = _set.clear()

  def add(v: T): Unit = _set += v

  def merge(other: AccumulatorV2[T, java.util.Set[T]]): Unit = {
    other match {
      case set: SetAccumulator[T] => _set ++= set._set
      case _ => throw new UnsupportedOperationException("Cannot merge different accumulator types")
    }
  }

  def value: java.util.Set[T] = _set.asJava
}

import org.apache.spark.broadcast.Broadcast
// Large lookup table that will be used across many tasks
val lookupTable = Map(
  "user1" -> "John Doe",
  "user2" -> "Jane Smith"
  // ... thousands more entries
)

// Broadcast the lookup table
val broadcastLookup: Broadcast[Map[String, String]] = sc.broadcast(lookupTable)

// Use in transformations
val userIds = sc.parallelize(Array("user1", "user2", "user1", "user3"))
val enrichedData = userIds.map { userId =>
  val lookup = broadcastLookup.value // Access broadcast value
  val userName = lookup.getOrElse(userId, "Unknown")
  (userId, userName)
}

val result = enrichedData.collect()
// Result: Array((user1,John Doe), (user2,Jane Smith), (user1,John Doe), (user3,Unknown))

// Clean up when done
broadcastLookup.unpersist()
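For comparison, a hypothetical version of the same lookup without the broadcast still works, but `lookupTable` is then serialized into every task's closure instead of being cached once per executor:

val enrichedWithoutBroadcast = userIds.map { userId =>
  // lookupTable is captured by the closure and shipped with every task
  (userId, lookupTable.getOrElse(userId, "Unknown"))
}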
val data = sc.parallelize(1 to 1000)

// Create accumulator for counting even numbers
val evenCount = sc.longAccumulator("Even Numbers")

// Use in transformation
val processed = data.map { num =>
  if (num % 2 == 0) {
    evenCount.add(1) // Accumulate even numbers
  }
  num * num
}

// Trigger action to execute transformations
val result = processed.collect()

// Access accumulator value
println(s"Found ${evenCount.value} even numbers")

val textData = sc.parallelize(Array("error: failed", "info: success", "error: timeout", "debug: trace"))
// Accumulator to collect all error messages
val errorMessages = sc.collectionAccumulator[String]("Error Messages")
// Process data and collect errors
val processedData = textData.map { line =>
  if (line.startsWith("error:")) {
    errorMessages.add(line) // Collect error messages
  }
  line.toUpperCase
}

// Trigger action
processedData.count()

// Access collected errors
val errors = errorMessages.value
println(s"Found ${errors.size()} errors: ${errors}")

// Register custom accumulator
val uniqueWords = new SetAccumulator[String]
sc.register(uniqueWords, "Unique Words")
val sentences = sc.parallelize(Array(
  "hello world",
  "world of spark",
  "hello spark"
))

// Use custom accumulator
val wordCounts = sentences.flatMap(_.split(" ")).map { word =>
  uniqueWords.add(word) // Collect unique words
  (word, 1)
}.reduceByKey(_ + _)

// Trigger action
val counts = wordCounts.collect()

// Access unique words
val unique = uniqueWords.value
println(s"Found ${unique.size()} unique words: ${unique}")
val malformedRecords = sc.longAccumulator("Malformed Records")
val validRecords = sc.longAccumulator("Valid Records")
val errorDetails = sc.collectionAccumulator[String]("Error Details")

val processedData = rawData.map { record =>
  try {
    val parsed = parseRecord(record)
    validRecords.add(1)
    parsed
  } catch {
    case e: Exception =>
      malformedRecords.add(1)
      errorDetails.add(s"Error parsing '$record': ${e.getMessage}")
      null
  }
}.filter(_ != null)

processedData.count()

println(s"Valid: ${validRecords.value}, Malformed: ${malformedRecords.value}")
errorDetails.value.forEach(msg => println(msg)) // value is a java.util.List, so use forEach
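The timing example below maps over `data` (any RDD of records, such as the one created earlier) and calls an `expensiveProcessing` function that is not defined above; a hypothetical placeholder:

// Hypothetical stand-in for the per-record work being timed below.
def expensiveProcessing(record: Int): Int = {
  Thread.sleep(5) // simulate a slow computation
  record * record
}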
val processingTime = sc.doubleAccumulator("Processing Time (ms)")
val recordsProcessed = sc.longAccumulator("Records Processed")

val result = data.map { record =>
  val start = System.currentTimeMillis()
  val processed = expensiveProcessing(record)
  val elapsed = System.currentTimeMillis() - start
  processingTime.add(elapsed.toDouble)
  recordsProcessed.add(1)
  processed
}

result.count()

println(f"Average processing time: ${processingTime.value / recordsProcessed.value}%.2f ms per record")

Call unpersist() on broadcast variables when they are no longer needed.