Broadcast variables and accumulators provide efficient data sharing and distributed counting across cluster nodes in Spark applications.
Broadcast variables are read-only variables cached on each machine rather than shipped with each task, providing efficient sharing of large datasets across all nodes.
/**
* A broadcast variable created with SparkContext.broadcast()
*/
abstract class Broadcast[T](val id: Long) extends Serializable {
/** Get the broadcasted value */
def value: T
/** Asynchronously delete cached copies of this broadcast on the executors */
def unpersist(): Unit
/** Delete cached copies of this broadcast on the executors, optionally blocking until deletion completes */
def unpersist(blocking: Boolean): Unit
/** Destroy all data and metadata related to this broadcast variable; it cannot be used again afterwards */
def destroy(): Unit
/** Whether this broadcast is valid */
def isValid: Boolean
override def toString: String = "Broadcast(" + id + ")"
}
// SparkContext methods for creating broadcast variables
def broadcast[T: ClassTag](value: T): Broadcast[T]
Usage Examples:
import org.apache.spark.{SparkContext, SparkConf}
val sc = new SparkContext(new SparkConf().setAppName("Broadcast Example"))
// Create a large lookup table
val lookupTable = Map(
"US" -> "United States",
"UK" -> "United Kingdom",
"DE" -> "Germany",
"FR" -> "France"
// ... thousands more entries
)
// Broadcast the lookup table to all nodes
val broadcastLookup = sc.broadcast(lookupTable)
// Use broadcast variable in transformations
val countryData = sc.textFile("hdfs://country_codes.txt")
val countryNames = countryData.map { code =>
val lookup = broadcastLookup.value // Access broadcast value
lookup.getOrElse(code, "Unknown")
}
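// For contrast: referencing lookupTable directly in the closure would also work, but Spark
// would then serialize the entire map into every task it ships; the broadcast version above
// sends it to each executor once and reuses it across all of that executor's tasks.
// (Illustrative sketch; reuses countryData and lookupTable from this example.)
val countryNamesNoBroadcast = countryData.map(code => lookupTable.getOrElse(code, "Unknown"))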
// Clean up when done
broadcastLookup.unpersist()
// broadcastLookup.destroy() // Only if completely done with variable
// Example with configuration
val config = Map(
"apiUrl" -> "https://api.example.com",
"timeout" -> "30000",
"retries" -> "3"
)
val broadcastConfig = sc.broadcast(config)
val processedData = inputRDD.mapPartitions { partition =>
val conf = broadcastConfig.value
// Use configuration for processing each partition
partition.map(processRecord(_, conf))
}
Accumulators are shared variables that can be "added" to through associative and commutative operations, providing efficient distributed counters and collectors.
/**
* A shared variable that can be accumulated (i.e., has an associative and commutative "add" operation)
*/
class Accumulator[T](initialValue: T, param: AccumulatorParam[T], name: Option[String] = None) extends Serializable {
/** Access the accumulator's current value; can only be read on the driver */
def value: T
/** Set the accumulator's value; only the driver can call this */
def setValue(newValue: T): Unit
/** Add a value to this accumulator */
def add(term: T): Unit
/** The += operator; can be used to add to the accumulator */
def +=(term: T): Unit = add(term)
/** Merge another accumulated value into this accumulator (normally called internally by Spark) */
def merge(term: T): Unit
/** Get the current local value of this accumulator from within a task */
def localValue: T
override def toString: String = if (name.isDefined) name.get else localValue.toString
}
/**
* A more general version of Accumulator where the result type differs from the element type
*/
class Accumulable[R, T](initialValue: R, param: AccumulableParam[R, T], name: Option[String] = None) extends Serializable {
/** Access the current value; can only be read on the driver */
def value: R
/** Set the value; only the driver can call this */
def setValue(newValue: R): Unit
/** Add a term to this accumulable */
def add(term: T): Unit
/** The += operator */
def +=(term: T): Unit = add(term)
/** The ++= operator; merge a partial result of type R into this accumulable */
def ++=(term: R): Unit
/** Merge another accumulated value into this accumulable (normally called internally by Spark) */
def merge(term: R): Unit
/** Get the current local value from within a task */
def localValue: R
override def toString: String = if (name.isDefined) name.get else localValue.toString
}
// SparkContext methods for creating accumulators
def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]): Accumulator[T]
def accumulator[T](initialValue: T, name: String)(implicit param: AccumulatorParam[T]): Accumulator[T]
def accumulable[R, T](initialValue: R)(implicit param: AccumulableParam[R, T]): Accumulable[R, T]
def accumulable[R, T](initialValue: R, name: String)(implicit param: AccumulableParam[R, T]): Accumulable[R, T]
Usage Examples:
import org.apache.spark.{SparkContext, SparkConf}
val sc = new SparkContext(new SparkConf().setAppName("Accumulator Example"))
// Basic numeric accumulator
val errorCount = sc.accumulator(0, "Error Count")
val sumAccum = sc.accumulator(0.0, "Sum Accumulator")
val data = sc.parallelize(1 to 1000)
// Use accumulators in transformations (note: updates made in transformations may be applied more than once if tasks are retried)
val processedData = data.map { value =>
try {
if (value % 100 == 0) {
throw new RuntimeException(s"Error processing $value")
}
sumAccum += value.toDouble
value * 2
} catch {
case e: Exception =>
errorCount += 1
-1 // Error marker
}
}
// Trigger action to execute transformations
val results = processedData.filter(_ != -1).collect()
// Read accumulator values (only on driver)
println(s"Processed ${results.length} items")
println(s"Encountered ${errorCount.value} errors")
println(s"Sum of successful values: ${sumAccum.value}")
// Collection accumulator example (requires an implicit AccumulableParam[Set[String], String] in scope, such as the SetAccumulableParam defined below)
val uniqueWords = sc.accumulable(Set.empty[String])
val text = sc.textFile("hdfs://input.txt")
text.flatMap(_.split(" ")).foreach { word =>
uniqueWords += word
}
println(s"Unique words found: ${uniqueWords.value.size}")Interfaces defining how accumulator values are combined.
/**
* A trait that defines how to accumulate values of type T
*/
trait AccumulatorParam[T] extends Serializable {
/** Add two values together and return a new value */
def addInPlace(r1: T, r2: T): T
/** Return the "zero" value for this type */
def zero(initialValue: T): T
}
/**
* A trait that defines how to accumulate values of type T into type R
*/
trait AccumulableParam[R, T] extends Serializable {
/** Add a T value to an R accumulator and return a new R */
def addAccumulator(r: R, t: T): R
/** Add two R values together and return a new R */
def addInPlace(r1: R, r2: R): R
/** Return the "zero" value for type R */
def zero(initialValue: R): R
}
// Built-in accumulator parameters
object AccumulatorParam {
implicit object IntAccumulatorParam extends AccumulatorParam[Int] {
def addInPlace(t1: Int, t2: Int): Int = t1 + t2
def zero(initialValue: Int): Int = 0
}
implicit object LongAccumulatorParam extends AccumulatorParam[Long] {
def addInPlace(t1: Long, t2: Long): Long = t1 + t2
def zero(initialValue: Long): Long = 0L
}
implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] {
def addInPlace(t1: Double, t2: Double): Double = t1 + t2
def zero(initialValue: Double): Double = 0.0
}
implicit object FloatAccumulatorParam extends AccumulatorParam[Float] {
def addInPlace(t1: Float, t2: Float): Float = t1 + t2
def zero(initialValue: Float): Float = 0.0f
}
}
Custom Accumulator Examples:
// Custom accumulator for collecting statistics
case class Stats(count: Long, sum: Double, min: Double, max: Double)
implicit object StatsAccumulatorParam extends AccumulatorParam[Stats] {
def zero(initialValue: Stats): Stats = Stats(0, 0.0, Double.MaxValue, Double.MinValue)
def addInPlace(s1: Stats, s2: Stats): Stats = {
if (s1.count == 0) s2
else if (s2.count == 0) s1
else Stats(
s1.count + s2.count,
s1.sum + s2.sum,
math.min(s1.min, s2.min),
math.max(s1.max, s2.max)
)
}
}
// Custom accumulable for collecting unique items
implicit object SetAccumulableParam extends AccumulableParam[Set[String], String] {
def zero(initialValue: Set[String]): Set[String] = Set.empty
def addAccumulator(set: Set[String], item: String): Set[String] = set + item
def addInPlace(set1: Set[String], set2: Set[String]): Set[String] = set1 ++ set2
}
// Usage
val stats = sc.accumulator(Stats(0, 0.0, Double.MaxValue, Double.MinValue))
val uniqueItems = sc.accumulable(Set.empty[String])
val data = sc.parallelize(Array(1.5, 2.7, 3.1, 4.9, 2.7))
data.foreach { value =>
stats += Stats(1, value, value, value)
uniqueItems += value.toString
}
val finalStats = stats.value
println(s"Count: ${finalStats.count}, Avg: ${finalStats.sum / finalStats.count}")
println(s"Unique values: ${uniqueItems.value}")// Configuration and monitoring pattern
val config = Map("threshold" -> 100, "maxRetries" -> 3)
val broadcastConfig = sc.broadcast(config)
val successCount = sc.accumulator(0, "Successful Operations")
val failureCount = sc.accumulator(0, "Failed Operations")
val retryCount = sc.accumulator(0, "Retry Count")
val results = inputData.mapPartitions { partition =>
val conf = broadcastConfig.value
val threshold = conf("threshold")
val maxRetries = conf("maxRetries")
partition.map { record =>
var attempts = 0
var success = false
var result: Option[String] = None
while (attempts < maxRetries && !success) {
try {
if (record.value > threshold) {
result = Some(processRecord(record))
success = true
successCount += 1
} else {
failureCount += 1
success = true // Don't retry for threshold failures
}
} catch {
case _: Exception =>
attempts += 1
retryCount += 1
if (attempts >= maxRetries) {
failureCount += 1
}
}
}
result
}.filter(_.isDefined).map(_.get)
}
// Trigger execution and report metrics
val finalResults = results.collect()
println(s"Processed: ${finalResults.length}")
println(s"Successful: ${successCount.value}")
println(s"Failed: ${failureCount.value}")
println(s"Retries: ${retryCount.value}")// Performance monitoring accumulators
val processingTimeAccum = sc.accumulator(0L, "Total Processing Time")
val recordsProcessedAccum = sc.accumulator(0L, "Records Processed")
// Assumes an implicit AccumulableParam[Map[Int, (Long, Long)], Map[Int, (Long, Long)]] is in scope
// (one whose addAccumulator/addInPlace merge the maps), analogous to SetAccumulableParam above
val partitionStatsAccum = sc.accumulable(Map.empty[Int, (Long, Long)])
val monitoredData = inputRDD.mapPartitionsWithIndex { (partitionId, partition) =>
val startTime = System.currentTimeMillis()
var recordCount = 0L
val processedPartition = partition.map { record =>
recordCount += 1
recordsProcessedAccum += 1
// Process record
processRecord(record)
}.toList
val endTime = System.currentTimeMillis()
val processingTime = endTime - startTime
processingTimeAccum += processingTime
partitionStatsAccum += Map(partitionId -> (recordCount, processingTime))
processedPartition.iterator
}
// Trigger execution
val results = monitoredData.collect()
// Analyze performance
val totalTime = processingTimeAccum.value
val totalRecords = recordsProcessedAccum.value
val avgTimePerRecord = totalTime.toDouble / totalRecords
println(s"Total processing time: ${totalTime}ms")
println(s"Average time per record: ${avgTimePerRecord}ms")
println(s"Partition stats: ${partitionStatsAccum.value}")case class ProcessingError(partitionId: Int, recordId: String, error: String, timestamp: Long)
// Custom accumulator for collecting errors
implicit object ErrorAccumulableParam extends AccumulableParam[List[ProcessingError], ProcessingError] {
def zero(initialValue: List[ProcessingError]): List[ProcessingError] = List.empty
def addAccumulator(list: List[ProcessingError], error: ProcessingError): List[ProcessingError] = error :: list
def addInPlace(list1: List[ProcessingError], list2: List[ProcessingError]): List[ProcessingError] = list1 ++ list2
}
val errorCollector = sc.accumulable(List.empty[ProcessingError])
val processedData = inputRDD.mapPartitionsWithIndex { (partitionId, partition) =>
partition.map { record =>
try {
Some(processRecord(record)) // Wrap in Some so that failures (None) can be filtered out below
} catch {
case e: Exception =>
val error = ProcessingError(
partitionId = partitionId,
recordId = record.id,
error = e.getMessage,
timestamp = System.currentTimeMillis()
)
errorCollector += error
None
}
}.filter(_.isDefined).map(_.get)
}
// Execute and collect errors
val results = processedData.collect()
val errors = errorCollector.value
println(s"Successfully processed: ${results.length}")
println(s"Errors encountered: ${errors.length}")
errors.foreach(error => println(s"Error in partition ${error.partitionId}: ${error.error}"))
// 1. Broadcast large read-only data structures
val largeLookupTable = loadLookupTable() // Assume this is large
val broadcastTable = sc.broadcast(largeLookupTable)
// 2. Reuse broadcast variables across multiple operations
val enrichedData1 = data1.map(enrichWithLookup(_, broadcastTable.value))
val enrichedData2 = data2.map(enrichWithLookup(_, broadcastTable.value))
// 3. Clean up when done
broadcastTable.unpersist() // Remove cached copies from executor memory (re-sent if used again)
broadcastTable.destroy() // Complete cleanup; the broadcast cannot be used again afterwards
// 1. Accumulator updates only take effect when an action runs; updates made inside transformations may be applied more than once if tasks are retried or stages recomputed
val counter = sc.accumulator(0)
// Update in a transformation that leads to an action
val results = data.map { value =>
if (someCondition(value)) counter += 1
processValue(value)
}.collect() // Action triggers execution
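// For counts that must be exact, update the accumulator inside an action such as foreach:
// Spark guarantees that accumulator updates made inside actions are applied exactly once per
// task even if tasks are retried, whereas updates made in transformations may be re-applied.
// (Sketch reusing data, counter, and the hypothetical someCondition from above.)
data.foreach { value =>
  if (someCondition(value)) counter += 1
}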
// 2. Be aware of lazy evaluation and recomputation
val lazyRDD = data.map { value =>
counter += 1 // Runs again each time the RDD is recomputed unless it is cached
value * 2
}
lazyRDD.cache() // Cache so later actions reuse the computed data instead of recomputing it
val result1 = lazyRDD.count() // Computes the RDD; the counter is incremented here
val result2 = lazyRDD.sum() // Served from the cache, so the counter is not incremented again
// 3. Use meaningful names for debugging; named accumulators are shown in the Spark web UI
val errorCounter = sc.accumulator(0, "Processing Errors")
val warningCounter = sc.accumulator(0, "Processing Warnings")
// Monitor memory usage with accumulators
val memoryUsageAccum = sc.accumulator(0L, "Memory Usage")
val processedData = largeDataRDD.mapPartitions { partition =>
val runtime = Runtime.getRuntime
val initialMemory = runtime.totalMemory() - runtime.freeMemory()
val results = partition.map(processLargeRecord).toList
val finalMemory = runtime.totalMemory() - runtime.freeMemory()
memoryUsageAccum += (finalMemory - initialMemory) // Rough estimate; GC activity can skew heap deltas
results.iterator
}
processedData.count() // Trigger execution
println(s"Total memory used: ${memoryUsageAccum.value / (1024 * 1024)} MB")