# Broadcast Variables and Accumulators

Spark provides two types of shared variables for efficient distributed computation: broadcast variables for read-only data distribution and accumulators for aggregating information across tasks.

## Broadcast Variables

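Broadcast variables keep a read-only variable cached on each machine rather than shipping a copy of it with tasks. Spark already broadcasts the data that tasks within a single stage share (in serialized form), so an explicit broadcast variable pays off mainly when the same data is needed across several stages or jobs, or when it is worth caching the data in deserialized form.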

### Broadcast Class

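```scala { .api }
abstract class Broadcast[T](val id: Long) {
  def value: T                           // the broadcast value, readable on the driver and in tasks
  def unpersist(): Unit                  // asynchronously delete cached copies on executors
  def unpersist(blocking: Boolean): Unit // same, optionally blocking until deletion completes
  def destroy(): Unit                    // delete all data and metadata; the variable is unusable afterwards
  override def toString: String
}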
```

### Creating and Using Broadcast Variables

```scala
import org.apache.spark.{SparkConf, SparkContext}

val sc = new SparkContext(new SparkConf().setAppName("Broadcast Example").setMaster("local[*]"))

// Create a lookup table that will be used across many tasks
val lookupTable = Map(
  "US" -> "United States",
  "UK" -> "United Kingdom",
  "DE" -> "Germany",
  "FR" -> "France",
  "JP" -> "Japan"
)

// Broadcast the lookup table: executors cache it instead of receiving it with each task
val broadcastLookup = sc.broadcast(lookupTable)

// Use the broadcast variable in transformations
val countryCodes = sc.parallelize(Seq("US", "UK", "DE", "UNKNOWN"))
val countryNames = countryCodes.map { code =>
  val lookup = broadcastLookup.value // Access the broadcast value inside the task
  lookup.getOrElse(code, "Unknown Country")
}

// Collect results
val results = countryNames.collect()
// Results: Array("United States", "United Kingdom", "Germany", "Unknown Country")

// Clean up the broadcast variable when done
broadcastLookup.unpersist() // Drop cached copies on executors; re-sent if the variable is used again
broadcastLookup.destroy()   // Remove all data and metadata, driver included; unusable afterwards
```

### Large Dataset Join Optimization

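```scala
import org.apache.spark.rdd.RDD

// parseRecord, parseLookup, Record and LookupInfo stand in for your own parsing code;
// both inputs must be parsed into (key, value) pairs
val largeRDD: RDD[(String, Record)] = sc.textFile("large_dataset.txt").map(parseRecord)
val smallRDD: RDD[(String, LookupInfo)] = sc.textFile("small_lookup.txt").map(parseLookup)

// Traditional join: shuffles both RDDs by key (expensive when largeRDD is big)
val traditionalJoin = largeRDD.join(smallRDD)

// Broadcast join (when one dataset is small): collect it to the driver and broadcast it.
// collectAsMap keeps one value per key, so this assumes keys in smallRDD are unique.
val smallDataset: Map[String, LookupInfo] = smallRDD.collectAsMap().toMap
val broadcastSmall = sc.broadcast(smallDataset)

// Map-side lookup, no shuffle of largeRDD. flatMap drops unmatched keys, giving the
// same inner-join semantics as join above (use map with the Option to mimic leftOuterJoin).
val broadcastJoin = largeRDD.flatMap { case (key, value) =>
  broadcastSmall.value.get(key).map(lookupInfo => (key, (value, lookupInfo)))
}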
```
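The difference between the two strategies, including what happens when the small side has duplicate keys, can be checked on plain Scala collections. The sketch below is a local model only (no Spark); `large`, `small` and the other names are illustrative. Building a map with `toMap` keeps one value per key, much as `collectAsMap` does, so a broadcast lookup silently drops extra matches that `join` would return.

```scala
// Plain-Scala model of the two join strategies on small in-memory data (no Spark involved)
val large: Seq[(String, Int)] = Seq("a" -> 1, "b" -> 2, "a" -> 3, "z" -> 4)
val small: Seq[(String, String)] = Seq("a" -> "alpha", "b" -> "beta")

// What join returns: one output row for every pair of records with matching keys
val shuffleJoin: Seq[(String, (Int, String))] =
  for { (k1, v) <- large; (k2, info) <- small if k1 == k2 } yield (k1, (v, info))

// Broadcast-style join: build the lookup map once, then probe it per record
val lookup: Map[String, String] = small.toMap
val broadcastJoin: Seq[(String, (Int, String))] =
  large.flatMap { case (k, v) => lookup.get(k).map(info => (k, (v, info))) }

// Duplicate key on the small side: toMap keeps only the last value for "a"
val smallWithDup: Seq[(String, String)] = small :+ ("a" -> "alef")
val shuffleJoinDup: Seq[(String, (Int, String))] =
  for { (k1, v) <- large; (k2, info) <- smallWithDup if k1 == k2 } yield (k1, (v, info))
val dupLookup: Map[String, String] = smallWithDup.toMap
val broadcastJoinDup: Seq[(String, (Int, String))] =
  large.flatMap { case (k, v) => dupLookup.get(k).map(info => (k, (v, info))) }
```

With unique keys the two results match; with the duplicate, the shuffle join yields five rows and the lookup only three.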

### Configuration Broadcast Pattern

```scala
case class AppConfig(
  apiEndpoint: String,
  timeout: Int,
  retries: Int,
  features: Map[String, Boolean]
)

val config = AppConfig(
  apiEndpoint = "https://api.example.com",
  timeout = 30000,
  retries = 3,
  features = Map("feature_a" -> true, "feature_b" -> false)
)

// Case classes are Serializable, so the whole config can be broadcast
val broadcastConfig = sc.broadcast(config)

// Use configuration in transformations.
// inputRDD, ApiClient, processWithFeatureA and processStandard are placeholders for your own code.
val processedData = inputRDD.mapPartitions { partition =>
  val cfg = broadcastConfig.value
  // Build the client once per partition on the executor, rather than
  // broadcasting it (clients are usually not serializable) or creating one per record
  val apiClient = new ApiClient(cfg.apiEndpoint, cfg.timeout, cfg.retries)

  partition.map { record =>
    if (cfg.features.getOrElse("feature_a", false)) {
      processWithFeatureA(record, apiClient)
    } else {
      processStandard(record, apiClient)
    }
  }
}
```

## Accumulators

Accumulators are variables that can only be "added" to through associative and commutative operations, making them suitable for implementing counters and sums.
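
Spark combines the per-task updates of an accumulator on the driver in whatever order the tasks happen to finish, which is why the operation must give the same answer in any order. The following plain-Scala sketch (no Spark involved; all names are illustrative) contrasts summation with an order-sensitive operation, and shows that floating-point addition is only approximately associative.

```scala
// Per-"task" partial results, combined in two different orders
val partitions: Seq[Seq[Long]] = Seq(Seq(1L, 2L), Seq(3L), Seq(4L, 5L, 6L))
val partials: Seq[Long] = partitions.map(_.sum)

// Long addition is associative and commutative: every merge order gives the same total
val mergedForward = partials.foldLeft(0L)(_ + _)
val mergedReverse = partials.reverse.foldLeft(0L)(_ + _)

// String concatenation is associative but not commutative: the order changes the result
val labels: Seq[String] = Seq("a", "b", "c")
val concatForward = labels.foldLeft("")(_ + _)
val concatReverse = labels.reverse.foldLeft("")(_ + _)

// Double addition is not exactly associative, so DoubleAccumulator totals
// can differ in the last bits depending on merge order
val leftFirst = (0.1 + 0.2) + 0.3
val rightFirst = 0.1 + (0.2 + 0.3)
```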

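### Legacy Accumulator (Deprecated)

The original accumulator API below was deprecated in Spark 2.0 and removed in Spark 3.0. It is listed for reference when reading older code; new code should use `AccumulatorV2`.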

```scala { .api }
class Accumulator[T](initialValue: T, param: AccumulatorParam[T]) {
  def +=(term: T): Unit
  def add(term: T): Unit
  def value: T // Only valid on driver
  def setValue(newValue: T): Unit
}

class Accumulable[R, T](initialValue: R, param: AccumulableParam[R, T]) {
  def +=(term: T): Unit
  def add(term: T): Unit
  def value: R
  def setValue(newValue: R): Unit
}
```

### AccumulatorV2 (Current API)

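```scala { .api }
abstract class AccumulatorV2[IN, OUT] extends Serializable {
  // Implemented by subclasses
  def isZero: Boolean                            // true if this holds the zero value (e.g. 0 for a counter)
  def copy(): AccumulatorV2[IN, OUT]             // new accumulator with the same state
  def reset(): Unit                              // return to the zero value
  def add(v: IN): Unit                           // fold one input into this accumulator
  def merge(other: AccumulatorV2[IN, OUT]): Unit // combine another copy's state into this one
  def value: OUT                                 // current result

  // Provided by Spark
  final def name: Option[String]
  final def id: Long
}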
```
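A rough model of how Spark drives these methods: each task receives a zero-valued copy of the registered accumulator (Spark creates it with `copyAndReset()`), calls `add` on that copy, and the driver `merge`s the copies returned by finished tasks into the original. The sketch below imitates that flow with a small stand-in trait so it runs without Spark; `Acc` and `SumAcc` are illustrative and are not Spark classes.

```scala
// Stand-in for the AccumulatorV2 contract (illustrative, not Spark's class)
trait Acc[IN, OUT] {
  def isZero: Boolean
  def copy(): Acc[IN, OUT]
  def reset(): Unit
  def add(v: IN): Unit
  def merge(other: Acc[IN, OUT]): Unit
  def value: OUT
  // Mirrors AccumulatorV2.copyAndReset: a fresh, zero-valued copy for one task
  def copyAndReset(): Acc[IN, OUT] = { val c = copy(); c.reset(); c }
}

class SumAcc extends Acc[Long, Long] {
  private var sum = 0L
  def isZero: Boolean = sum == 0L
  def copy(): SumAcc = { val c = new SumAcc; c.sum = sum; c }
  def reset(): Unit = sum = 0L
  def add(v: Long): Unit = sum += v
  def merge(other: Acc[Long, Long]): Unit = sum += other.value
  def value: Long = sum
}

// "Driver" side: the registered accumulator
val driverAcc = new SumAcc
val taskInputs: Seq[Seq[Long]] = Seq(Seq(1L, 2L, 3L), Seq(10L), Seq(100L, 200L))

// Each simulated task updates only its own zeroed copy
val taskCopies: Seq[Acc[Long, Long]] = taskInputs.map { input =>
  val local = driverAcc.copyAndReset()
  input.foreach(local.add)
  local
}

// The driver merges each finished task's copy exactly once
taskCopies.foreach(driverAcc.merge)
```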

### Built-in Accumulator Types

```scala { .api }
class LongAccumulator extends AccumulatorV2[java.lang.Long, java.lang.Long] {
  def add(v: Long): Unit
  def add(v: java.lang.Long): Unit
  def sum: Long   // sum of all added values (also returned by value)
  def count: Long // number of add calls
  def avg: Double // sum / count
}

class DoubleAccumulator extends AccumulatorV2[java.lang.Double, java.lang.Double] {
  def add(v: Double): Unit
  def add(v: java.lang.Double): Unit
  def sum: Double // sum of all added values (also returned by value)
  def count: Long // number of add calls
  def avg: Double // sum / count
}

class CollectionAccumulator[T] extends AccumulatorV2[T, java.util.List[T]] {
  def add(v: T): Unit
  def value: java.util.List[T] // read-only list of all added elements
}
```
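For `LongAccumulator` and `DoubleAccumulator`, `count` is the number of `add` calls rather than a second counter, so `avg` is the mean of the added values. The sketch below checks this by using the class directly, assuming `spark-core` is on the classpath; no `SparkContext` is needed because an unregistered accumulator can still be updated and read locally. `latencies` is an illustrative name.

```scala
import org.apache.spark.util.LongAccumulator

// Three add calls with different values
val latencies = new LongAccumulator
Seq(120L, 80L, 100L).foreach(v => latencies.add(v))

val total = latencies.sum     // 120 + 80 + 100
val calls = latencies.count   // one per add call
val mean = latencies.avg      // total / calls
```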

### Using Built-in Accumulators

```scala
import org.apache.spark.{SparkConf, SparkContext}
import scala.collection.JavaConverters._ // for errorMessages.value.asScala

val sc = new SparkContext(new SparkConf().setAppName("Accumulator Example").setMaster("local[*]"))

// Create named accumulators (names are shown in the Spark web UI)
val errorCount = sc.longAccumulator("Error Count")
val processingTime = sc.doubleAccumulator("Total Processing Time")
val errorMessages = sc.collectionAccumulator[String]("Error Messages")

val data = sc.parallelize(1 to 1000)

// Update accumulators inside a transformation. Caution: these updates are applied again
// if the RDD is recomputed or a task is retried (see Best Practices).
// complexProcessing is a placeholder for your own Int => Int function.
val results = data.map { number =>
  val startTime = System.currentTimeMillis()

  try {
    if (number % 100 == 0) {
      throw new RuntimeException(s"Simulated error for number $number")
    }

    val result = complexProcessing(number)
    val elapsed = System.currentTimeMillis() - startTime
    processingTime.add(elapsed.toDouble)

    result
  } catch {
    case e: Exception =>
      errorCount.add(1)
      errorMessages.add(s"Error processing $number: ${e.getMessage}")
      -1 // Error sentinel value
  }
}

// Trigger computation with a single action
val processedResults = results.filter(_ != -1).collect()

// Read accumulator values (only on the driver, after the action has finished)
println(s"Successfully processed: ${processedResults.length}")
println(s"Errors encountered: ${errorCount.value}")
println(s"Average processing time: ${processingTime.avg}ms") // sum / number of add calls
println(s"Error messages: ${errorMessages.value.asScala.take(5)}")
```
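The caution about updating accumulators inside `map` is easy to reproduce: an uncached RDD is recomputed by every action, and the accumulator updates are applied again each time. The sketch below is a standalone program assuming `spark-core` on the classpath and a `local[2]` master; `countWithTwoActions` is an illustrative name. The logic sits in a method so the closure captures only the local accumulator.

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Counts records inside map(), runs two actions on the same RDD,
// and returns the accumulator's final value
def countWithTwoActions(sc: SparkContext, cacheFirst: Boolean): Long = {
  val seen = sc.longAccumulator("records seen")
  val mapped = sc.parallelize(1 to 100, 4).map { x => seen.add(1); x }
  if (cacheFirst) mapped.cache()
  mapped.count() // first action: runs the map function for every record
  mapped.count() // second action: runs it again unless the partitions are cached
  seen.sum
}

val sc = new SparkContext(new SparkConf().setAppName("accumulator-recompute").setMaster("local[2]"))
val uncachedCount = countWithTwoActions(sc, cacheFirst = false)
val cachedCount = countWithTwoActions(sc, cacheFirst = true)
sc.stop()

println(s"uncached: $uncachedCount, cached: $cachedCount")
```

Caching only avoids the second update while the cached partitions stay in memory; for counts that must be exact, update the accumulator inside an action such as `foreach`.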

### Custom Accumulators

```scala
import org.apache.spark.util.AccumulatorV2

// Custom accumulator collecting (count, sum, mean, variance, min, max) of Double values
class StatsAccumulator extends AccumulatorV2[Double, (Long, Double, Double, Double, Double, Double)] {
  private var _count: Long = 0
  private var _sum: Double = 0.0
  private var _sumSquares: Double = 0.0
  private var _min: Double = Double.MaxValue
  private var _max: Double = Double.MinValue

  def isZero: Boolean = _count == 0

  def copy(): StatsAccumulator = {
    val newAcc = new StatsAccumulator
    newAcc._count = _count
    newAcc._sum = _sum
    newAcc._sumSquares = _sumSquares
    newAcc._min = _min
    newAcc._max = _max
    newAcc
  }

  def reset(): Unit = {
    _count = 0
    _sum = 0.0
    _sumSquares = 0.0
    _min = Double.MaxValue
    _max = Double.MinValue
  }

  def add(v: Double): Unit = {
    _count += 1
    _sum += v
    _sumSquares += v * v
    _min = math.min(_min, v)
    _max = math.max(_max, v)
  }

  def merge(other: AccumulatorV2[Double, (Long, Double, Double, Double, Double, Double)]): Unit = {
    other match {
      case o: StatsAccumulator =>
        _count += o._count
        _sum += o._sum
        _sumSquares += o._sumSquares
        _min = math.min(_min, o._min)
        _max = math.max(_max, o._max)
      case _ =>
        throw new UnsupportedOperationException(
          s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
    }
  }

  // Returns (count, sum, mean, population variance, min, max)
  def value: (Long, Double, Double, Double, Double, Double) = {
    if (_count == 0) (0L, 0.0, 0.0, 0.0, 0.0, 0.0)
    else {
      val mean = _sum / _count
      val variance = (_sumSquares / _count) - (mean * mean)
      (_count, _sum, mean, variance, _min, _max)
    }
  }
}

// Register and use the custom accumulator
val statsAcc = new StatsAccumulator
sc.register(statsAcc, "Statistics")

val numbers = sc.parallelize(Seq.fill(1000)(scala.util.Random.nextGaussian() * 100 + 50))
numbers.foreach(statsAcc.add) // foreach is an action, so each task's update is applied once

val (count, sum, mean, variance, min, max) = statsAcc.value
println(f"Count: $count, Sum: $sum%.2f, Mean: $mean%.2f, Variance: $variance%.2f")
println(f"Min: $min%.2f, Max: $max%.2f")
```
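The `sumSquares / count - mean * mean` formula in `StatsAccumulator` is compact, but it cancels catastrophically when values are large relative to their spread. A commonly used alternative keeps `(n, mean, m2)` per task, updates it with Welford's method, and merges two partial states with the pairwise formula of Chan et al. The plain-Scala sketch below (names such as `Moments` are illustrative) compares the two approaches on three values near 1e9 whose population variance is 2/3; the same state could replace the running sums inside the accumulator.

```scala
// Running moments: Welford's update for single values and the
// pairwise merge formula for combining two partial results
case class Moments(n: Long, mean: Double, m2: Double) {
  def add(x: Double): Moments = {
    val count = n + 1
    val delta = x - mean
    val newMean = mean + delta / count
    Moments(count, newMean, m2 + delta * (x - newMean))
  }

  def merge(o: Moments): Moments =
    if (n == 0) o
    else if (o.n == 0) this
    else {
      val count = n + o.n
      val delta = o.mean - mean
      Moments(count, mean + delta * o.n / count, m2 + o.m2 + delta * delta * n * o.n / count)
    }

  def variance: Double = if (n == 0) 0.0 else m2 / n // population variance
}

val zeroMoments = Moments(0L, 0.0, 0.0)
val values = Seq(1e9, 1e9 + 1, 1e9 + 2)

// Two "partitions" accumulated separately, then merged
val merged = Seq(values.take(1), values.drop(1))
  .map(_.foldLeft(zeroMoments)(_ add _))
  .reduce(_ merge _)

// The sum-of-squares formula used by StatsAccumulator, for comparison
val naiveMean = values.sum / values.size
val naiveVariance = values.map(v => v * v).sum / values.size - naiveMean * naiveMean
```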

### Histogram Accumulator

```scala
import org.apache.spark.util.AccumulatorV2

// Counts values into buckets: (-inf, b0), [b0, b1), ..., [b(n-1), +inf)
class HistogramAccumulator(val buckets: Array[Double]) extends AccumulatorV2[Double, Array[Long]] {
  require(buckets.sorted.sameElements(buckets), "Buckets must be sorted")

  private val _counts = Array.fill(buckets.length + 1)(0L)

  def isZero: Boolean = _counts.forall(_ == 0)

  def copy(): HistogramAccumulator = {
    val newAcc = new HistogramAccumulator(buckets)
    System.arraycopy(_counts, 0, newAcc._counts, 0, _counts.length)
    newAcc
  }

  def reset(): Unit = {
    java.util.Arrays.fill(_counts, 0L)
  }

  def add(v: Double): Unit = {
    val pos = java.util.Arrays.binarySearch(buckets, v)
    // An exact match on boundary i starts bucket i + 1; otherwise the
    // insertion point equals the number of boundaries below v
    val index = if (pos >= 0) pos + 1 else -pos - 1
    _counts(index) += 1
  }

  def merge(other: AccumulatorV2[Double, Array[Long]]): Unit = {
    other match {
      case o: HistogramAccumulator =>
        for (i <- _counts.indices) {
          _counts(i) += o._counts(i)
        }
      case _ =>
        throw new UnsupportedOperationException(
          s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
    }
  }

  def value: Array[Long] = _counts.clone()
}

// Usage
val histogramBuckets = Array(0.0, 25.0, 50.0, 75.0, 100.0)
val histogramAcc = new HistogramAccumulator(histogramBuckets)
sc.register(histogramAcc, "Value Histogram")

val values = sc.parallelize(Seq.fill(10000)(scala.util.Random.nextDouble() * 100))
values.foreach(histogramAcc.add)

val histogram = histogramAcc.value
println("Histogram buckets and counts:")
println(s"< ${histogramBuckets(0)}: ${histogram(0)}")
for (i <- histogramBuckets.indices.init) {
  println(s"[${histogramBuckets(i)}, ${histogramBuckets(i + 1)}): ${histogram(i + 1)}")
}
println(s">= ${histogramBuckets.last}: ${histogram.last}")
```
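Boundary handling is the easiest part of a histogram to get wrong, so it helps to test the index rule on its own. The helper below repeats the rule used in `HistogramAccumulator.add` in plain Scala; `bucketIndex`, `bounds` and `samples` are illustrative names. With the boundaries `0, 25, 50, 75, 100`, index 0 holds values below 0, index `i` holds `[b(i-1), b(i))`, and the last index holds values of 100 or more.

```scala
// Standalone copy of the bucket-index rule: exact boundary matches go to the
// bucket that starts at that boundary (lower-inclusive buckets)
def bucketIndex(buckets: Array[Double], v: Double): Int = {
  val pos = java.util.Arrays.binarySearch(buckets, v)
  if (pos >= 0) pos + 1 else -pos - 1
}

val bounds = Array(0.0, 25.0, 50.0, 75.0, 100.0)
val samples = Seq(-1.0, 0.0, 24.9, 25.0, 99.9, 100.0, 250.0)
val indices = samples.map(v => bucketIndex(bounds, v))
```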

## Advanced Patterns

### Distributed Cache with Broadcast

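```scala
import java.util.concurrent.ConcurrentHashMap

// A cache that each executor fills lazily. The map is @transient, so it is not serialized
// with the broadcast: every deserialized copy starts empty and loads keys on demand.
class DistributedCache[K, V](loadFunction: K => V) extends Serializable {
  @transient private lazy val cache = new ConcurrentHashMap[K, V]()

  // Concurrent tasks on one executor can share the same deserialized instance,
  // so use a thread-safe map; computeIfAbsent loads a missing key at most once per copy
  def get(key: K): V =
    cache.computeIfAbsent(key, new java.util.function.Function[K, V] {
      override def apply(k: K): V = loadFunction(k)
    })
}

// Broadcast the cache instance. DatabaseRecord, loadFromDatabase, rdd1 and rdd2 are
// placeholders; the load function must be serializable.
val databaseCache = new DistributedCache[String, DatabaseRecord](key => loadFromDatabase(key))
val broadcastCache = sc.broadcast(databaseCache)

// Use across multiple operations. Typically each key is loaded once per executor, but a
// broadcast block that is evicted and read again yields a fresh, empty copy.
val enrichedData1 = rdd1.map { record =>
  val enrichment = broadcastCache.value.get(record.key)
  record.copy(metadata = enrichment)
}

val enrichedData2 = rdd2.map { record =>
  val enrichment = broadcastCache.value.get(record.foreignKey)
  record.copy(additionalInfo = enrichment)
}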
```
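When the loader needs no driver-side state, a Scala `object` is a common alternative to broadcasting a cache: the object is never serialized, and it is initialized once per executor JVM (per class loader) the first time a task touches it. The sketch below is plain Scala and runs without Spark; `LookupCache` and its uppercasing "loader" are stand-ins for a real lookup. Inside a job it would be used as, for example, `rdd.map(r => LookupCache.get(r))` on a hypothetical RDD of strings.

```scala
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicInteger

// One instance per JVM, created lazily on first access
object LookupCache {
  val loads = new AtomicInteger(0) // counts loader calls, for demonstration only

  private val cache = new ConcurrentHashMap[String, String]()

  // Stand-in for an expensive lookup such as a database query
  private def load(key: String): String = {
    loads.incrementAndGet()
    key.toUpperCase
  }

  def get(key: String): String =
    cache.computeIfAbsent(key, new java.util.function.Function[String, String] {
      override def apply(k: String): String = load(k)
    })
}

val first = LookupCache.get("us")
val second = LookupCache.get("us") // served from the cache, no second load
val third = LookupCache.get("de")
```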

### Multi-level Metrics Collection

```scala
import org.apache.spark.TaskContext
import org.apache.spark.util.AccumulatorV2
import scala.collection.mutable
import scala.util.control.NonFatal

// Pipeline counters (unrelated to Spark's internal org.apache.spark.executor.TaskMetrics)
case class TaskMetrics(
  processedRecords: Long = 0,
  errorRecords: Long = 0,
  processingTimeMs: Long = 0,
  cacheHits: Long = 0,
  cacheMisses: Long = 0
) {
  def +(other: TaskMetrics): TaskMetrics = TaskMetrics(
    processedRecords + other.processedRecords,
    errorRecords + other.errorRecords,
    processingTimeMs + other.processingTimeMs,
    cacheHits + other.cacheHits,
    cacheMisses + other.cacheMisses
  )
}

class TaskMetricsAccumulator extends AccumulatorV2[TaskMetrics, TaskMetrics] {
  private var _metrics = TaskMetrics()

  def isZero: Boolean = _metrics == TaskMetrics()
  def copy(): TaskMetricsAccumulator = {
    val newAcc = new TaskMetricsAccumulator
    newAcc._metrics = _metrics
    newAcc
  }
  def reset(): Unit = _metrics = TaskMetrics()
  def add(v: TaskMetrics): Unit = _metrics = _metrics + v
  def merge(other: AccumulatorV2[TaskMetrics, TaskMetrics]): Unit = {
    other match {
      case o: TaskMetricsAccumulator => _metrics = _metrics + o._metrics
      case _ =>
        throw new UnsupportedOperationException(s"Cannot merge with ${other.getClass.getName}")
    }
  }
  def value: TaskMetrics = _metrics
}

// Usage in a processing pipeline.
// inputRDD, getCacheKey, computeValue and processRecord are placeholders for your own code.
val metricsAcc = new TaskMetricsAccumulator
sc.register(metricsAcc, "Pipeline Metrics")

val results = inputRDD.mapPartitions { partition =>
  var partitionMetrics = TaskMetrics()
  val cache = mutable.Map[String, Any]()

  // partition.flatMap below is lazy, so the metrics are only complete once Spark has
  // consumed the iterator: report them when the task finishes, not right here
  TaskContext.get().addTaskCompletionListener[Unit] { _ =>
    metricsAcc.add(partitionMetrics)
  }

  partition.flatMap { record =>
    val startTime = System.currentTimeMillis()

    try {
      // Simulate cache lookup
      val cacheKey = record.getCacheKey
      if (cache.contains(cacheKey)) {
        partitionMetrics = partitionMetrics.copy(cacheHits = partitionMetrics.cacheHits + 1)
      } else {
        partitionMetrics = partitionMetrics.copy(cacheMisses = partitionMetrics.cacheMisses + 1)
        cache(cacheKey) = computeValue(record)
      }

      val processed = processRecord(record)
      val elapsed = System.currentTimeMillis() - startTime

      partitionMetrics = partitionMetrics.copy(
        processedRecords = partitionMetrics.processedRecords + 1,
        processingTimeMs = partitionMetrics.processingTimeMs + elapsed
      )

      Some(processed)
    } catch {
      case NonFatal(_) =>
        partitionMetrics = partitionMetrics.copy(errorRecords = partitionMetrics.errorRecords + 1)
        None // Drop failed records instead of emitting null
    }
  }
}

// Trigger computation and get metrics
val finalResults = results.collect()
val finalMetrics = metricsAcc.value
val lookups = finalMetrics.cacheHits + finalMetrics.cacheMisses

println("Processing Summary:")
println(s"  Processed: ${finalMetrics.processedRecords}")
println(s"  Errors: ${finalMetrics.errorRecords}")
println(s"  Total time: ${finalMetrics.processingTimeMs}ms")
println(s"  Avg time per record: ${finalMetrics.processingTimeMs.toDouble / math.max(1L, finalMetrics.processedRecords)}ms")
println(s"  Cache hit rate: ${finalMetrics.cacheHits.toDouble / math.max(1L, lookups) * 100}%")
```
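`mapPartitions` gives each task a lazy iterator, and `Iterator.map` and `flatMap` do not run their function until the result is consumed. Counters updated inside that function are therefore still at their initial values right after the call returns, so per-partition metrics can only be reported once Spark has drained the iterator. The plain-Scala snippet below (illustrative names, no Spark) shows the effect.

```scala
var processed = 0
val input = Iterator(1, 2, 3)

// Nothing runs yet: map only wraps the iterator
val mapped = input.map { x => processed += 1; x * 10 }

val countBeforeConsuming = processed // still 0
val output = mapped.toList           // consuming the iterator runs the function
val countAfterConsuming = processed  // now 3
```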

### Best Practices

#### Broadcast Variables

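- Only broadcast read-only data, and don't modify an object after broadcasting it; executors that fetch it at different times could otherwise see different values
- Broadcast small to medium-sized datasets: the driver must hold the whole value, and every executor keeps its own full copy in memory
- Clean up broadcast variables when no longer needed (`unpersist()` frees executor copies; `destroy()` releases everything)
- Use broadcast joins for small lookup tables
- Consider using broadcast for configuration objects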

#### Accumulators

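- Use accumulators only for metrics and debugging information
- Don't rely on accumulator values for program logic: updates inside transformations are not applied until an action runs, and may be applied more than once
- Spark guarantees that each task's update is applied exactly once only for accumulators updated inside actions; updates made in transformations may be re-applied when tasks or stages are re-executed
- Register accumulators with meaningful names; named accumulators are displayed in the Spark web UI
- Consider custom `AccumulatorV2` implementations for complex aggregations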

#### Memory Management

```scala
// Clean up resources properly.
// largeData, inputRDD and processWithBroadcastAndAccumulator are placeholders.
// Declare the broadcast outside the try block so that finally can still see it.
val broadcastData = sc.broadcast(largeData)
try {
  val metrics = sc.longAccumulator("Processing Count")

  // Use broadcast and accumulator
  val results = processWithBroadcastAndAccumulator(inputRDD, broadcastData, metrics)

  // Run an action; accumulator values are only complete after it finishes
  results.collect()

  println(s"Processed ${metrics.value} records")
} finally {
  // Always clean up; destroy() also removes the executor copies that unpersist() would drop
  broadcastData.destroy()
}
```
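The same cleanup can be packaged as a small loan-pattern helper so that callers cannot forget it. `withBroadcast` is a hypothetical helper, not part of Spark; the sketch is a standalone program assuming `spark-core` on the classpath and a `local[2]` master, and it also shows that a destroyed broadcast can no longer be read.

```scala
import scala.reflect.ClassTag
import org.apache.spark.{SparkConf, SparkContext, SparkException}
import org.apache.spark.broadcast.Broadcast

// Loan pattern: the broadcast exists only while `body` runs
def withBroadcast[T: ClassTag, R](sc: SparkContext, value: T)(body: Broadcast[T] => R): R = {
  val bc = sc.broadcast(value)
  try body(bc) finally bc.destroy()
}

val sc = new SparkContext(new SparkConf().setAppName("withBroadcast-demo").setMaster("local[2]"))

var leaked: Broadcast[Map[Int, String]] = null // kept only to show the handle is unusable afterwards
val names = withBroadcast(sc, Map(1 -> "one", 2 -> "two")) { bc =>
  leaked = bc
  // collect() forces evaluation while the broadcast is still alive
  sc.parallelize(Seq(1, 2, 3)).map(k => bc.value.getOrElse(k, "?")).collect().toSeq
}

// Reading a destroyed broadcast throws a SparkException
val usableAfterwards = try { leaked.value; true } catch { case _: SparkException => false }
sc.stop()
```

Because `destroy()` runs in `finally`, the broadcast is released even if `body` throws. The trade-off is that `body` must not return anything that still uses the broadcast lazily, such as an RDD that has not been computed yet.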