# Broadcast Variables and Accumulators

Spark provides two types of shared variables for efficient distributed computation: broadcast variables for read-only data distribution and accumulators for aggregating information across tasks.

## Broadcast Variables

Broadcast variables allow keeping a read-only variable cached on each machine rather than shipping a copy with each task.

### Broadcast Class

```scala { .api }
abstract class Broadcast[T](val id: Long) {
  def value: T
  def unpersist(): Unit
  def unpersist(blocking: Boolean): Unit
  def destroy(): Unit
  override def toString: String
}
```
### Creating and Using Broadcast Variables

```scala
import org.apache.spark.{SparkContext, SparkConf}

val sc = new SparkContext(new SparkConf().setAppName("Broadcast Example").setMaster("local[*]"))

// Create a lookup table that will be used across many tasks
val lookupTable = Map(
  "US" -> "United States",
  "UK" -> "United Kingdom",
  "DE" -> "Germany",
  "FR" -> "France",
  "JP" -> "Japan"
)

// Broadcast the lookup table
val broadcastLookup = sc.broadcast(lookupTable)

// Use the broadcast variable in transformations
val countryCodes = sc.parallelize(Seq("US", "UK", "DE", "UNKNOWN"))
val countryNames = countryCodes.map { code =>
  val lookup = broadcastLookup.value // Access broadcast value
  lookup.getOrElse(code, "Unknown Country")
}

// Collect results
val results = countryNames.collect()
// Results: Array("United States", "United Kingdom", "Germany", "Unknown Country")

// Clean up broadcast variable when done
broadcastLookup.unpersist() // Remove from executors' memory
broadcastLookup.destroy()   // Remove all data and metadata
```
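
A broadcast variable is immutable once created; to pick up new data, broadcast a fresh copy and release the old one. A minimal sketch of that pattern, assuming a hypothetical `loadLookupTable()` helper that rebuilds the map:

```scala
// Sketch: refreshing a broadcast between processing rounds.
// loadLookupTable() is a hypothetical helper that re-reads the lookup data.
var currentLookup = sc.broadcast(loadLookupTable())

def refreshLookup(): Unit = {
  val stale = currentLookup
  currentLookup = sc.broadcast(loadLookupTable()) // new broadcast with fresh data
  stale.unpersist()                               // release the stale copies on executors
}
```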

### Large Dataset Join Optimization

```scala
// Traditional join (can be expensive for large datasets)
val largeRDD = sc.textFile("large_dataset.txt").map(parseRecord)
val smallRDD = sc.textFile("small_lookup.txt").map(parseLookup)
val traditionalJoin = largeRDD.join(smallRDD) // Expensive shuffle

// Broadcast join optimization (when one dataset is small)
// Note: collectAsMap() returns scala.collection.Map, not the default immutable Map
val smallDataset: scala.collection.Map[String, LookupInfo] = smallRDD.collectAsMap()
val broadcastSmall = sc.broadcast(smallDataset)

val broadcastJoin = largeRDD.map { case (key, value) =>
  val lookupInfo = broadcastSmall.value.get(key)
  (key, (value, lookupInfo))
}
```
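
The pattern above is what the DataFrame API automates: the `org.apache.spark.sql.functions.broadcast` hint asks the optimizer to replicate the small side and perform a broadcast hash join. A sketch, assuming Parquet inputs that share a `key` column (paths and column name are placeholders):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.broadcast

val spark = SparkSession.builder().appName("Broadcast Join Hint").master("local[*]").getOrCreate()

val largeDF = spark.read.parquet("large_dataset.parquet") // placeholder path
val smallDF = spark.read.parquet("small_lookup.parquet")  // placeholder path

// The broadcast() hint replicates smallDF to every executor,
// avoiding a shuffle of largeDF.
val joined = largeDF.join(broadcast(smallDF), "key")
```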

### Configuration Broadcast Pattern

```scala
case class AppConfig(
  apiEndpoint: String,
  timeout: Int,
  retries: Int,
  features: Map[String, Boolean]
)

val config = AppConfig(
  apiEndpoint = "https://api.example.com",
  timeout = 30000,
  retries = 3,
  features = Map("feature_a" -> true, "feature_b" -> false)
)

val broadcastConfig = sc.broadcast(config)

// Use configuration in transformations
val processedData = inputRDD.mapPartitions { partition =>
  val cfg = broadcastConfig.value
  val apiClient = new ApiClient(cfg.apiEndpoint, cfg.timeout, cfg.retries)

  partition.map { record =>
    if (cfg.features("feature_a")) {
      processWithFeatureA(record, apiClient)
    } else {
      processStandard(record, apiClient)
    }
  }
}
```
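
The hard-coded values above often come from `spark-submit --conf` settings instead. A sketch of building the same `AppConfig` from `SparkConf`, where the `spark.myapp.*` keys are hypothetical:

```scala
// Sketch: reading application settings from SparkConf (the spark.myapp.* keys are hypothetical).
val conf = sc.getConf
val configFromConf = AppConfig(
  apiEndpoint = conf.get("spark.myapp.apiEndpoint", "https://api.example.com"),
  timeout     = conf.get("spark.myapp.timeout", "30000").toInt,
  retries     = conf.get("spark.myapp.retries", "3").toInt,
  features    = Map("feature_a" -> conf.get("spark.myapp.featureA", "true").toBoolean)
)
val broadcastAppConfig = sc.broadcast(configFromConf)
```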

## Accumulators

Accumulators are variables that can only be "added" to through associative and commutative operations, making them suitable for implementing counters and sums.

### Legacy Accumulator (Deprecated)

Both `Accumulator` and `Accumulable` were deprecated in Spark 2.0 in favor of `AccumulatorV2`.

```scala { .api }
class Accumulator[T](initialValue: T, param: AccumulatorParam[T]) {
  def +=(term: T): Unit
  def add(term: T): Unit
  def value: T // Only valid on the driver
  def setValue(newValue: T): Unit
}

class Accumulable[R, T](initialValue: R, param: AccumulableParam[R, T]) {
  def +=(term: T): Unit
  def add(term: T): Unit
  def value: R
  def setValue(newValue: R): Unit
}
```
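
Code written against the legacy API typically needs only a small change to move to the current one. A minimal before/after sketch, assuming an existing `SparkContext` named `sc`:

```scala
// Legacy style (deprecated since Spark 2.0):
// val counter = sc.accumulator(0, "my counter")
// counter += 1

// AccumulatorV2 style:
val counter = sc.longAccumulator("my counter")
sc.parallelize(1 to 10).foreach(_ => counter.add(1))
println(counter.value) // read the total on the driver
```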

### AccumulatorV2 (Current API)

```scala { .api }
abstract class AccumulatorV2[IN, OUT] {
  def isZero: Boolean
  def copy(): AccumulatorV2[IN, OUT]
  def reset(): Unit
  def add(v: IN): Unit
  def merge(other: AccumulatorV2[IN, OUT]): Unit
  def value: OUT
  def name: Option[String]
  def id: Long
}
```

### Built-in Accumulator Types

```scala { .api }
class LongAccumulator extends AccumulatorV2[java.lang.Long, java.lang.Long] {
  def add(v: Long): Unit
  def add(v: java.lang.Long): Unit
  def sum: Long
  def count: Long
  def avg: Double
}

class DoubleAccumulator extends AccumulatorV2[java.lang.Double, java.lang.Double] {
  def add(v: Double): Unit
  def add(v: java.lang.Double): Unit
  def sum: Double
  def count: Long
  def avg: Double
}

class CollectionAccumulator[T] extends AccumulatorV2[T, java.util.List[T]] {
  def add(v: T): Unit
  def value: java.util.List[T]
}
```

### Using Built-in Accumulators

```scala
import scala.collection.JavaConverters._ // for .asScala on the collection accumulator's value

val sc = new SparkContext(new SparkConf().setAppName("Accumulator Example").setMaster("local[*]"))

// Create accumulators
val errorCount = sc.longAccumulator("Error Count")
val processingTime = sc.doubleAccumulator("Total Processing Time")
val errorMessages = sc.collectionAccumulator[String]("Error Messages")

val data = sc.parallelize(1 to 1000)

// Use accumulators in transformations
val results = data.map { number =>
  val startTime = System.currentTimeMillis()

  try {
    if (number % 100 == 0) {
      throw new RuntimeException(s"Simulated error for number $number")
    }

    val result = complexProcessing(number) // placeholder for real work
    val elapsed = System.currentTimeMillis() - startTime
    processingTime.add(elapsed.toDouble)

    result
  } catch {
    case e: Exception =>
      errorCount.add(1)
      errorMessages.add(s"Error processing $number: ${e.getMessage}")
      -1 // Error sentinel value
  }
}

// Trigger computation
val processedResults = results.filter(_ != -1).collect()

// Access accumulator values (only on the driver)
println(s"Successfully processed: ${processedResults.length}")
println(s"Errors encountered: ${errorCount.value}")
println(s"Average processing time: ${processingTime.value / processedResults.length}ms")
println(s"Error messages: ${errorMessages.value.asScala.take(5)}")
```

### Custom Accumulators

```scala
import org.apache.spark.util.AccumulatorV2

// Custom accumulator for collecting statistics: (count, sum, mean, variance, min, max)
class StatsAccumulator extends AccumulatorV2[Double, (Long, Double, Double, Double, Double, Double)] {
  private var _count: Long = 0
  private var _sum: Double = 0.0
  private var _sumSquares: Double = 0.0
  private var _min: Double = Double.MaxValue
  private var _max: Double = Double.MinValue

  def isZero: Boolean = _count == 0

  def copy(): StatsAccumulator = {
    val newAcc = new StatsAccumulator
    newAcc._count = _count
    newAcc._sum = _sum
    newAcc._sumSquares = _sumSquares
    newAcc._min = _min
    newAcc._max = _max
    newAcc
  }

  def reset(): Unit = {
    _count = 0
    _sum = 0.0
    _sumSquares = 0.0
    _min = Double.MaxValue
    _max = Double.MinValue
  }

  def add(v: Double): Unit = {
    _count += 1
    _sum += v
    _sumSquares += v * v
    _min = math.min(_min, v)
    _max = math.max(_max, v)
  }

  def merge(other: AccumulatorV2[Double, (Long, Double, Double, Double, Double, Double)]): Unit = {
    other match {
      case o: StatsAccumulator =>
        _count += o._count
        _sum += o._sum
        _sumSquares += o._sumSquares
        _min = math.min(_min, o._min)
        _max = math.max(_max, o._max)
    }
  }

  def value: (Long, Double, Double, Double, Double, Double) = {
    if (_count == 0) (0L, 0.0, 0.0, 0.0, 0.0, 0.0)
    else {
      val mean = _sum / _count
      val variance = (_sumSquares / _count) - (mean * mean)
      (_count, _sum, mean, variance, _min, _max)
    }
  }
}

// Register and use the custom accumulator
val statsAcc = new StatsAccumulator
sc.register(statsAcc, "Statistics")

val numbers = sc.parallelize(Seq.fill(1000)(scala.util.Random.nextGaussian() * 100 + 50))
numbers.foreach(statsAcc.add)

val (count, sum, mean, variance, min, max) = statsAcc.value
println(f"Count: $count, Sum: $sum%.2f, Mean: $mean%.2f, Variance: $variance%.2f")
println(f"Min: $min%.2f, Max: $max%.2f")
```
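
Since the accumulator reports the population variance, the standard deviation follows directly:

```scala
// Standard deviation derived from the variance returned by StatsAccumulator.
val stdDev = math.sqrt(variance)
println(f"Std dev: $stdDev%.2f")
```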

### Histogram Accumulator

```scala
class HistogramAccumulator(val buckets: Array[Double]) extends AccumulatorV2[Double, Array[Long]] {
  require(buckets.sorted.sameElements(buckets), "Buckets must be sorted")

  // counts(0) holds values below buckets(0); counts(i) holds values in [buckets(i-1), buckets(i));
  // the last slot holds values >= buckets.last
  private val _counts = Array.fill(buckets.length + 1)(0L)

  def isZero: Boolean = _counts.forall(_ == 0)

  def copy(): HistogramAccumulator = {
    val newAcc = new HistogramAccumulator(buckets)
    System.arraycopy(_counts, 0, newAcc._counts, 0, _counts.length)
    newAcc
  }

  def reset(): Unit = {
    java.util.Arrays.fill(_counts, 0L)
  }

  def add(v: Double): Unit = {
    val bucketIndex = java.util.Arrays.binarySearch(buckets, v)
    // Exact matches belong to the bucket that starts at that boundary;
    // otherwise use the insertion point returned by binarySearch
    val index = if (bucketIndex >= 0) bucketIndex + 1 else -bucketIndex - 1
    _counts(index) += 1
  }

  def merge(other: AccumulatorV2[Double, Array[Long]]): Unit = {
    other match {
      case o: HistogramAccumulator =>
        for (i <- _counts.indices) {
          _counts(i) += o._counts(i)
        }
    }
  }

  def value: Array[Long] = _counts.clone()
}

// Usage
val histogramBuckets = Array(0.0, 25.0, 50.0, 75.0, 100.0)
val histogramAcc = new HistogramAccumulator(histogramBuckets)
sc.register(histogramAcc, "Value Histogram")

val values = sc.parallelize(Seq.fill(10000)(scala.util.Random.nextDouble() * 100))
values.foreach(histogramAcc.add)

val histogram = histogramAcc.value
println("Histogram buckets and counts:")
println(s"< ${histogramBuckets(0)}: ${histogram(0)}")
for (i <- histogramBuckets.indices.init) {
  println(s"${histogramBuckets(i)} - ${histogramBuckets(i + 1)}: ${histogram(i + 1)}")
}
println(s">= ${histogramBuckets.last}: ${histogram.last}")
```

## Advanced Patterns

### Distributed Cache with Broadcast

```scala
import scala.collection.mutable

class DistributedCache[K, V](loadFunction: K => V) extends Serializable {
  @transient private lazy val cache = mutable.Map[K, V]()

  def get(key: K): V = {
    cache.getOrElseUpdate(key, loadFunction(key))
  }
}

// Broadcast the cache instance
val databaseCache = new DistributedCache[String, DatabaseRecord](key =>
  // This will be called once per executor JVM per key
  loadFromDatabase(key)
)
val broadcastCache = sc.broadcast(databaseCache)

// Use across multiple operations
val enrichedData1 = rdd1.map { record =>
  val enrichment = broadcastCache.value.get(record.key)
  record.copy(metadata = enrichment)
}

val enrichedData2 = rdd2.map { record =>
  val enrichment = broadcastCache.value.get(record.foreignKey)
  record.copy(additionalInfo = enrichment)
}
```

### Multi-level Metrics Collection

```scala
case class TaskMetrics(
  processedRecords: Long = 0,
  errorRecords: Long = 0,
  processingTimeMs: Long = 0,
  cacheHits: Long = 0,
  cacheMisses: Long = 0
) {
  def +(other: TaskMetrics): TaskMetrics = TaskMetrics(
    processedRecords + other.processedRecords,
    errorRecords + other.errorRecords,
    processingTimeMs + other.processingTimeMs,
    cacheHits + other.cacheHits,
    cacheMisses + other.cacheMisses
  )
}

class TaskMetricsAccumulator extends AccumulatorV2[TaskMetrics, TaskMetrics] {
  private var _metrics = TaskMetrics()

  def isZero: Boolean = _metrics == TaskMetrics()
  def copy(): TaskMetricsAccumulator = {
    val newAcc = new TaskMetricsAccumulator
    newAcc._metrics = _metrics
    newAcc
  }
  def reset(): Unit = _metrics = TaskMetrics()
  def add(v: TaskMetrics): Unit = _metrics = _metrics + v
  def merge(other: AccumulatorV2[TaskMetrics, TaskMetrics]): Unit = {
    other match {
      case o: TaskMetricsAccumulator => _metrics = _metrics + o._metrics
    }
  }
  def value: TaskMetrics = _metrics
}

// Usage in a processing pipeline
val metricsAcc = new TaskMetricsAccumulator
sc.register(metricsAcc, "Pipeline Metrics")

val results = inputRDD.mapPartitions { partition =>
  var partitionMetrics = TaskMetrics()
  val cache = mutable.Map[String, Any]()

  // Materialize the partition (toList) before adding to the accumulator;
  // iterators are lazy, so adding before forcing the map would record empty metrics.
  val processedPartition = partition.map { record =>
    val startTime = System.currentTimeMillis()

    try {
      // Simulate cache lookup
      val cacheKey = record.getCacheKey
      val cachedValue = cache.get(cacheKey)

      if (cachedValue.isDefined) {
        partitionMetrics = partitionMetrics.copy(cacheHits = partitionMetrics.cacheHits + 1)
      } else {
        partitionMetrics = partitionMetrics.copy(cacheMisses = partitionMetrics.cacheMisses + 1)
        cache(cacheKey) = computeValue(record)
      }

      val processed = processRecord(record)
      val elapsed = System.currentTimeMillis() - startTime

      partitionMetrics = partitionMetrics.copy(
        processedRecords = partitionMetrics.processedRecords + 1,
        processingTimeMs = partitionMetrics.processingTimeMs + elapsed
      )

      processed
    } catch {
      case _: Exception =>
        partitionMetrics = partitionMetrics.copy(errorRecords = partitionMetrics.errorRecords + 1)
        null
    }
  }.filter(_ != null).toList

  // Add partition metrics to the global accumulator
  metricsAcc.add(partitionMetrics)
  processedPartition.iterator
}

// Trigger computation and get metrics
val finalResults = results.collect()
val finalMetrics = metricsAcc.value

println(s"Processing Summary:")
println(s"  Processed: ${finalMetrics.processedRecords}")
println(s"  Errors: ${finalMetrics.errorRecords}")
println(s"  Total time: ${finalMetrics.processingTimeMs}ms")
println(s"  Avg time per record: ${finalMetrics.processingTimeMs.toDouble / finalMetrics.processedRecords}ms")
println(s"  Cache hit rate: ${finalMetrics.cacheHits.toDouble / (finalMetrics.cacheHits + finalMetrics.cacheMisses) * 100}%")
```

### Best Practices

#### Broadcast Variables
- Only broadcast read-only data
- Broadcast small to medium-sized datasets (typically < 2GB)
- Clean up broadcast variables when no longer needed
- Use broadcast joins for small lookup tables
- Consider using broadcast for configuration objects

#### Accumulators
- Only use accumulators for metrics and debugging information
- Don't rely on accumulator values for program logic; updates made inside transformations can be re-applied when tasks are retried or RDDs are recomputed (see the sketch after this list)
- Updates performed inside actions are guaranteed to be applied exactly once per task; no such guarantee exists for transformations
- Register accumulators with meaningful names so they appear in the Spark UI
- Consider using custom accumulators for complex aggregations
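
A minimal sketch of the inconsistency warned about above, assuming an existing `SparkContext` named `sc`: because the RDD below is neither cached nor checkpointed, each action recomputes the lineage and re-applies the accumulator updates made in the transformation.

```scala
val seen = sc.longAccumulator("records seen")

val doubled = sc.parallelize(1 to 100).map { x =>
  seen.add(1) // update inside a transformation
  x * 2
}

doubled.count()     // first evaluation
println(seen.value) // 100
doubled.collect()   // lineage recomputed, updates applied again
println(seen.value) // 200, not 100
```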
#### Memory Management
```scala
// Clean up resources properly; declare the shared variables outside the try block
// so they remain in scope in the finally clause
val broadcastData = sc.broadcast(largeData)
val metrics = sc.longAccumulator("Processing Count")

try {
  // Use broadcast and accumulator
  val results = processWithBroadcastAndAccumulator(inputRDD, broadcastData, metrics)

  // Collect results
  results.collect()

  println(s"Processed ${metrics.value} records")

} finally {
  // Always clean up
  broadcastData.unpersist()
  broadcastData.destroy()
}
```