or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

index.md · io-operations.md · key-value-operations.md · partitioning-shuffling.md · rdd-operations.md · shared-variables.md · spark-context.md · storage-persistence.md

docs/partitioning-shuffling.md

0

# Partitioning and Shuffling

1

2

Partitioning strategies and shuffle operations for controlling data distribution and optimizing performance across cluster nodes in Spark applications.

3

4

## Capabilities

5

6

### Partitioner

7

8

Abstract base class defining how elements are partitioned by key across cluster nodes.

9

10

```scala { .api }

11

/**
 * Strategy object that decides, by key, which partition each element of a
 * key-value pair RDD belongs to.
 */
abstract class Partitioner extends Serializable {

  /** Total number of partitions this partitioner distributes keys over. */
  def numPartitions: Int

  /** Partition id assigned to `key`. */
  def getPartition(key: Any): Int

  /** Whether `other` is equal to this partitioner. */
  override def equals(other: Any): Boolean

  /** Hash code for this partitioner. */
  override def hashCode: Int
}

27

28

// Built-in partitioners

29

/**
 * Partitioner that places a key by its `hashCode`, reduced to a non-negative
 * value modulo the partition count. A `null` key always goes to partition 0.
 *
 * Two instances are equal exactly when they have the same partition count.
 */
class HashPartitioner(partitions: Int) extends Partitioner {
  require(partitions >= 0, s"Number of partitions ($partitions) cannot be negative.")

  def numPartitions: Int = partitions

  def getPartition(key: Any): Int =
    if (key == null) 0
    else Utils.nonNegativeMod(key.hashCode, numPartitions)

  override def equals(other: Any): Boolean = other match {
    case that: HashPartitioner => that.numPartitions == numPartitions
    case _ => false
  }

  // Kept consistent with equals: the partition count is the whole identity.
  override def hashCode: Int = numPartitions
}

46

47

// Partitioner that assigns keys to partitions by ordered key range; K needs an
// implicit Ordering and ClassTag (see the context bounds below).
// NOTE(review): this listing is abbreviated. The body of getPartition is elided,
// so as written it casts the key and then yields no Int result.
// NOTE(review): `rdd` and `ascending` are accepted but not used by any code
// shown here; presumably the elided implementation uses them - confirm
// against the Spark source.
class RangePartitioner[K : Ordering : ClassTag, V](

48

partitions: Int,

49

rdd: RDD[(K, V)],

50

ascending: Boolean = true) extends Partitioner {

51

52

// As listed, simply echoes the requested partition count.
def numPartitions: Int = partitions

53

54

def getPartition(key: Any): Int = {

55

// Unchecked cast: a key of another type only fails once it is used as a K.
val k = key.asInstanceOf[K]

56

// Binary search to find the partition

57

// Implementation details... (elided in this listing)

58

}

59

}

60

```

61

62

### Partitioning Methods (RDD)

63

64

Methods for controlling RDD partitioning and data distribution.

65

66

```scala { .api }

67

// ---- Inspecting how an RDD is currently partitioned ----
def partitions: Array[Partition]
def getNumPartitions: Int
def partitioner: Option[Partitioner]

// ---- Changing the number of partitions ----
def repartition(numPartitions: Int): RDD[T]
def coalesce(numPartitions: Int, shuffle: Boolean = false): RDD[T]

// ---- Key-based placement (pair RDDs only) ----
def partitionBy(partitioner: Partitioner): RDD[(K, V)]
def repartitionAndSortWithinPartitions(partitioner: Partitioner): RDD[(K, V)]

// ---- Working on one whole partition at a time ----
def mapPartitions[U: ClassTag](
    f: Iterator[T] => Iterator[U],
    preservesPartitioning: Boolean = false): RDD[U]
def mapPartitionsWithIndex[U: ClassTag](
    f: (Int, Iterator[T]) => Iterator[U],
    preservesPartitioning: Boolean = false): RDD[U]
def foreachPartition(f: Iterator[T] => Unit): Unit
def glom(): RDD[Array[T]]

// ---- Sampling and inspection ----
def sample(
    withReplacement: Boolean,
    fraction: Double,
    seed: Long = Utils.random.nextLong): RDD[T]
def takeSample(
    withReplacement: Boolean,
    num: Int,
    seed: Long = Utils.random.nextLong): Array[T]

89

```

90

91

**Usage Examples:**

92

93

```scala

94

import org.apache.spark.{SparkContext, SparkConf, HashPartitioner, RangePartitioner}

95

96

// Local context for the walkthrough
val conf = new SparkConf().setAppName("Partitioning Example").setMaster("local[*]")
val sc = new SparkContext(conf)

// Sample data: 1..1000 spread over 4 slices, then keyed by the last digit
val data = sc.parallelize(1 to 1000, numSlices = 4)
val keyValueData = data.map(n => (n % 10, n))

// Report the starting layout
println(s"Number of partitions: ${keyValueData.getNumPartitions}")
println(s"Current partitioner: ${keyValueData.partitioner}")

// Place keys by hash into 8 partitions
val hashPartitioner = new HashPartitioner(8)
val hashPartitioned = keyValueData.partitionBy(hashPartitioner)
println(s"Hash partitioned: ${hashPartitioned.getNumPartitions} partitions")

// Place keys by ordered range into 6 partitions
val rangePartitioner = new RangePartitioner(6, keyValueData)
val rangePartitioned = keyValueData.partitionBy(rangePartitioner)
println(s"Range partitioned: ${rangePartitioned.getNumPartitions} partitions")

// Growing the partition count shuffles; coalesce is called with its default shuffle = false
val repartitioned = data.repartition(8)
val coalesced = data.coalesce(2)

// Pull each partition back as an array and print its first five elements
val partitionContents = keyValueData.glom().collect()
for ((partition, index) <- partitionContents.zipWithIndex) {
  println(s"Partition $index: ${partition.take(5).mkString(", ")}...")
}

// One sum per partition
val partitionSums = data.mapPartitions(values => Iterator(values.sum))

129

```

130

131

### Custom Partitioners

132

133

Creating custom partitioning strategies for specific use cases.

134

135

```scala { .api }

136

// Example custom partitioners

137

/**
 * Routes keys to partitions by the top-level-domain suffix of `key.toString`:
 * ".com" -> 0, ".org" -> 1, ".edu" -> 2, anything else (including a null key) -> 3.
 *
 * When fewer than four partitions are configured, ids are clamped to the last
 * partition so the result always lies in `[0, numPartitions)`.
 *
 * @param numPartitions number of partitions; must be positive
 */
class DomainPartitioner(override val numPartitions: Int) extends Partitioner {
  // `override val` implements the abstract member directly; a separate
  // `def numPartitions = numPartitions` would refer to itself.
  require(numPartitions > 0, s"Number of partitions ($numPartitions) must be positive.")

  override def getPartition(key: Any): Int = {
    // A null key has no suffix and therefore falls through to the catch-all bucket.
    val domain = Option(key).fold("")(_.toString)
    val bucket = domain match {
      case d if d.endsWith(".com") => 0
      case d if d.endsWith(".org") => 1
      case d if d.endsWith(".edu") => 2
      case _ => 3
    }
    math.min(bucket, numPartitions - 1)
  }

  // Value equality so two equally configured instances compare as the same partitioner.
  override def equals(other: Any): Boolean = other match {
    case that: DomainPartitioner => that.numPartitions == numPartitions
    case _ => false
  }

  override def hashCode: Int = numPartitions
}

150

151

/**
 * Routes keys by a two-character region prefix of `key.toString`.
 *
 * @param regions maps a region prefix (first two characters of the key) to a
 *                partition id; ids must be non-negative. Keys that are null,
 *                or whose prefix is not in the map, go to partition 0.
 */
class GeographicPartitioner(regions: Map[String, Int]) extends Partitioner {
  require(regions.values.forall(_ >= 0), "Region partition ids must be non-negative.")

  // Sized from the largest mapped id rather than the map size alone, so sparse
  // or duplicated ids can never point past the last partition. An empty map
  // still yields one partition because every key falls back to partition 0.
  override val numPartitions: Int =
    if (regions.isEmpty) 1 else math.max(regions.size, regions.values.max + 1)

  override def getPartition(key: Any): Int =
    Option(key)
      .map(_.toString.take(2)) // take(2) tolerates keys shorter than two characters
      .flatMap(regions.get)
      .getOrElse(0)

  override def equals(other: Any): Boolean = other match {
    case that: GeographicPartitioner => that.regionTable == regionTable
    case _ => false
  }

  override def hashCode: Int = regions.hashCode

  // Exposes the mapping to equals on another instance.
  private def regionTable: Map[String, Int] = regions
}

159

160

/**
 * Range partitioner driven by caller-supplied boundaries.
 *
 * `ranges` is read as ascending lower bounds: bucket `i` covers
 * `[ranges(i), ranges(i + 1))` and the last bucket is open-ended. Keys below
 * the first bound go to partition 0. The bucket is clamped to
 * `numPartitions - 1` in case more bounds than partitions are supplied.
 *
 * @param ranges        ascending lower bounds, one per bucket
 * @param numPartitions number of partitions; must be positive
 * @param ord           ordering used to compare keys against the bounds
 */
class CustomRangePartitioner[K](
    ranges: Array[K],
    override val numPartitions: Int)(implicit ord: Ordering[K]) extends Partitioner {

  require(numPartitions > 0, s"Number of partitions ($numPartitions) must be positive.")

  override def getPartition(key: Any): Int = {
    // Unchecked cast: keys are expected to be of type K.
    val k = key.asInstanceOf[K]
    // The LAST bound that is <= k identifies the bucket. Taking the first such
    // bound would send every key >= ranges(0) to partition 0.
    val bucket = ranges.lastIndexWhere(bound => ord.gteq(k, bound))
    math.min(math.max(bucket, 0), numPartitions - 1)
  }
}

172

```

173

174

**Custom Partitioner Examples:**

175

176

```scala

177

// Web logs, grouped by domain suffix; parseLogEntry yields (domain, logEntry)
val domainPartitioner = new DomainPartitioner(4)
val webLogs = sc.textFile("web-logs.txt").map(parseLogEntry).partitionBy(domainPartitioner)

// Location records, grouped by region prefix; parseLocation yields (country, locationInfo)
val regionMap = Map("US" -> 0, "EU" -> 1, "AS" -> 2, "OT" -> 3)
val geographicPartitioner = new GeographicPartitioner(regionMap)
val locationData = sc.textFile("locations.txt").map(parseLocation).partitionBy(geographicPartitioner)

// Time series, grouped by quarter boundary; parseTimeEntry yields (date, data)
val timeRanges = Array("2023-01-01", "2023-04-01", "2023-07-01", "2023-10-01")
val quarterPartitioner = new CustomRangePartitioner(timeRanges, 4)
val timeSeriesData = sc.textFile("timeseries.txt").map(parseTimeEntry).partitionBy(quarterPartitioner)

195

```

196

197

### Shuffle Operations

198

199

Operations that trigger data shuffling across the cluster.

200

201

```scala { .api }

202

// Operations that typically move data between partitions (shuffle).

// ---- Repartitioning ----
// Always shuffles.
def repartition(numPartitions: Int): RDD[T]
// Shuffles when the RDD is not already partitioned by an equal partitioner.
def partitionBy(partitioner: Partitioner): RDD[(K, V)]

// ---- Per-key aggregation (pair RDDs) ----
// Shuffles.
def groupByKey(): RDD[(K, Iterable[V])]
// Shuffles, but with pre-aggregation.
def reduceByKey(func: (V, V) => V): RDD[(K, V)]
// Shuffles.
def aggregateByKey[U](zeroValue: U)(seqOp: (U, V) => U, combOp: (U, U) => U): RDD[(K, U)]
// Shuffles.
def combineByKey[C](
    createCombiner: V => C,
    mergeValue: (C, V) => C,
    mergeCombiners: (C, C) => C): RDD[(K, C)]

// ---- Joins ----
// Shuffles both RDDs.
def join[W](other: RDD[(K, W)]): RDD[(K, (V, W))]
// Shuffles.
def leftOuterJoin[W](other: RDD[(K, W)]): RDD[(K, (V, Option[W]))]
// Shuffles.
def rightOuterJoin[W](other: RDD[(K, W)]): RDD[(K, (Option[V], W))]
// Shuffles.
def fullOuterJoin[W](other: RDD[(K, W)]): RDD[(K, (Option[V], Option[W]))]
// Shuffles.
def cogroup[W](other: RDD[(K, W)]): RDD[(K, (Iterable[V], Iterable[W]))]

// ---- Sorting ----
// Shuffles.
def sortByKey(ascending: Boolean = true): RDD[(K, V)]
// Shuffles.
def sortBy[K](f: T => K, ascending: Boolean = true): RDD[T]

// ---- Set-style operations ----
// Shuffles.
def distinct(): RDD[T]
// Shuffles.
def intersection(other: RDD[T]): RDD[T]
// Shuffles.
def subtract(other: RDD[T]): RDD[T]

229

```

230

231

## Performance Optimization Strategies

232

233

### Minimizing Shuffles

234

235

```scala

236

// These pipelines assume `data` is a pair RDD with numeric values, e.g. RDD[(Int, Int)].

// INEFFICIENT: two shuffles, and groupByKey moves every value before summing
val inefficientPipeline = data
  .groupByKey()        // Shuffle 1: all values cross the network
  .mapValues(_.sum)    // No shuffle
  .filter(_._2 > 100)  // No shuffle
  .sortByKey()         // Shuffle 2: global sort

// BETTER: still two shuffles, but values are combined map-side before shuffle 1
val efficientPipeline = data
  .reduceByKey(_ + _)  // Shuffle 1 with pre-aggregation
  .filter(_._2 > 100)  // No shuffle
  .sortByKey()         // Shuffle 2: global sort

// ALTERNATIVE when a per-partition sort is enough. The result is hash-partitioned
// and sorted WITHIN each partition only; it is not globally ordered like sortByKey.
// Aggregating with the same partitioner that the sort step asks for means the data
// is already laid out as required when it reaches that step.
// NOTE(review): whether the sort step then skips its own shuffle depends on the
// Spark version in use - confirm against the release you run.
val pipelinePartitioner = new HashPartitioner(8)
val mostEfficientPipeline = data
  .reduceByKey(pipelinePartitioner, _ + _)                  // Shuffle 1 with pre-aggregation
  .filter(_._2 > 100)                                       // No shuffle
  .repartitionAndSortWithinPartitions(pipelinePartitioner)  // Sort inside each partition

254

```

255

256

### Partition-Aware Operations

257

258

```scala

259

// Partition both sides with the same partitioner once, and persist the result,
// so later key-based operations can reuse that layout instead of recomputing it.
val userPartitioner = new HashPartitioner(8)

val users = sc.textFile("users.txt")
  .map(parseUser)                // (userId, userInfo)
  .partitionBy(userPartitioner)
  .cache()                       // Keep the partitioned data around

val orders = sc.textFile("orders.txt")
  .map(parseOrder)               // (userId, orderInfo)
  .partitionBy(userPartitioner)  // Same partitioner as users
  .cache()                       // Read twice below (join + stats); without this the
                                 // partitionBy lineage would be recomputed each time

// Both sides already share a partitioner, so the join itself needs no further shuffle
val userOrders = users.join(orders)

// Per-user order count and average amount, built on the already-partitioned orders
// NOTE(review): if `amount` is an integral type, total / count truncates - confirm.
val userStats = orders
  .mapValues(order => (1, order.amount))                  // (userId, (count, amount))
  .reduceByKey((a, b) => (a._1 + b._1, a._2 + b._2))      // (userId, (count, totalAmount))
  .mapValues { case (count, total) => (count, total / count) } // (count, avgAmount)

277

```

278

279

### Advanced Partitioning Patterns

280

281

```scala

282

// Multi-level partitioning for hierarchical data

283

/**
 * Two-level partitioner for `(String, String)` keys.
 *
 * The first tuple element selects one of `primaryPartitions` groups and the
 * second selects one of `secondaryPartitions` slots inside that group, giving
 * `primaryPartitions * secondaryPartitions` partitions in total.
 *
 * @param primaryPartitions   number of primary groups; must be positive
 * @param secondaryPartitions number of slots per group; must be positive
 * @throws IllegalArgumentException from `getPartition` if the key is not a
 *                                  `(String, String)` tuple
 */
class HierarchicalPartitioner(
    primaryPartitions: Int,
    secondaryPartitions: Int) extends Partitioner {

  require(
    primaryPartitions > 0 && secondaryPartitions > 0,
    s"Partition counts must be positive, got ($primaryPartitions, $secondaryPartitions).")

  override def numPartitions: Int = primaryPartitions * secondaryPartitions

  override def getPartition(key: Any): Int = key match {
    case (primary: String, secondary: String) =>
      bucket(primary, primaryPartitions) * secondaryPartitions +
        bucket(secondary, secondaryPartitions)
    case other =>
      throw new IllegalArgumentException(
        s"HierarchicalPartitioner expects (String, String) keys, got: $other")
  }

  // Take the remainder BEFORE the absolute value: math.abs(Int.MinValue) is
  // still negative, so abs-then-mod can produce a negative partition id.
  // For every other hash code both orders give the same bucket.
  private def bucket(part: String, buckets: Int): Int =
    math.abs(part.hashCode % buckets)

  override def equals(other: Any): Boolean = other match {
    case that: HierarchicalPartitioner => that.shape == shape
    case _ => false
  }

  override def hashCode: Int = shape.hashCode

  // (primary, secondary) counts; the whole identity of this partitioner.
  private def shape: (Int, Int) = (primaryPartitions, secondaryPartitions)
}

296

297

// Usage for multi-dimensional data

298

// 4 primary groups x 4 secondary slots = 16 partitions in total;
// parseHierarchical yields ((region, category), data)
val hierarchicalPartitioner = new HierarchicalPartitioner(4, 4)
val hierarchicalData =
  sc.textFile("hierarchical.txt").map(parseHierarchical).partitionBy(hierarchicalPartitioner)

301

302

// Skew-aware partitioning

303

/**
 * Partitioner that isolates frequently occurring ("heavy") keys from the rest.
 *
 * At construction time a sample of `rdd` is counted on the driver. A key is
 * heavy when its sampled count exceeds twice the average number of sampled
 * records per partition. If heavy keys exist, the lower half of the partitions
 * is reserved for them and all other keys are hashed over the remaining
 * partitions; if none exist, every partition is used for regular hashing.
 *
 * A given key always maps to the same partition: heavy keys are kept away from
 * normal keys, they are not split across partitions.
 *
 * @param rdd            pair RDD whose keys are sampled (this runs a Spark job)
 * @param numPartitions  number of partitions; must be positive
 * @param sampleFraction fraction of `rdd` to sample when looking for heavy keys
 */
class SkewAwarePartitioner[K](
    rdd: RDD[_ <: Product2[K, Any]],
    override val numPartitions: Int,
    sampleFraction: Double = 0.1) extends Partitioner {

  require(numPartitions > 0, s"Number of partitions ($numPartitions) must be positive.")

  // Keys are widened to Any for counting, so no ClassTag for K is required.
  private val heavyKeys: Set[Any] = {
    val sampledCounts = rdd
      .sample(withReplacement = false, sampleFraction)
      .map(pair => (pair._1: Any))
      .countByValue()
    // Floating-point threshold: integer division would round the average down first.
    val threshold = 2.0 * sampledCounts.values.sum / numPartitions
    sampledCounts.collect { case (key, count) if count > threshold => key }.toSet
  }

  // Partitions [0, heavySlots) hold heavy keys, [heavySlots, numPartitions) the rest.
  // normalSlots is always >= 1, and every partition is usable for odd counts too.
  private val heavySlots: Int = if (heavyKeys.isEmpty) 0 else numPartitions / 2
  private val normalSlots: Int = numPartitions - heavySlots

  override def getPartition(key: Any): Int = {
    // Masking the sign bit keeps the hash non-negative; null hashes as 0.
    val hash = if (key == null) 0 else key.hashCode & Integer.MAX_VALUE
    if (heavySlots > 0 && heavyKeys.contains(key)) hash % heavySlots
    else heavySlots + hash % normalSlots
  }
}

330

```

331

332

### Monitoring Shuffle Performance

333

334

```scala

335

// Monitor shuffle operations

336

/**
 * Prints a progress snapshot of every stage that is active right now.
 *
 * The status tracker's stage info exposes the stage name and task counters
 * only. It carries no shuffle byte or record counts, so those cannot be
 * printed from here; they are reported per task/stage through task metrics
 * delivered to a registered `SparkListener`.
 *
 * @param sc the context whose active stages are inspected
 */
def analyzeShuffleMetrics(sc: SparkContext): Unit = {
  val tracker = sc.statusTracker

  for (stageId <- tracker.getActiveStageIds()) {
    tracker.getStageInfo(stageId) match {
      case Some(stage) =>
        println(s"Stage $stageId:")
        println(s" Name: ${stage.name()}")
        println(s" Tasks: ${stage.numTasks()}")
        println(s" Active Tasks: ${stage.numActiveTasks()}")
        println(s" Completed Tasks: ${stage.numCompletedTasks()}")
        println(s" Failed Tasks: ${stage.numFailedTasks()}")
      case None =>
        // The stage finished or was cleaned up between the two tracker calls.
        println(s"No info available for stage $stageId")
    }
  }
}

352

353

// Custom shuffle metrics collection

354

/**
 * Materialises `rdd` and reports how long that took.
 *
 * The RDD is marked for caching and then counted, which runs its lineage.
 * The elapsed wall-clock time and the resulting partition count are printed.
 *
 * @param rdd           the RDD to evaluate
 * @param operationName label used in the printed report
 * @return the cached RDD
 */
def trackShuffleOperations[T](rdd: RDD[T], operationName: String): RDD[T] = {
  val startedAt = System.currentTimeMillis()

  val result = rdd.cache()
  result.count() // the action that actually runs the job

  val finishedAt = System.currentTimeMillis()

  println(s"Operation '$operationName' took ${finishedAt - startedAt}ms")
  println(s"Result partitions: ${result.getNumPartitions}")

  result
}

365

366

// Usage

367

// Time a groupByKey over the source data
val grouped = originalData.groupByKey()
val shuffledData = trackShuffleOperations(grouped, "groupByKey operation")

371

```

372

373

## Advanced Optimization Techniques

374

375

### Broadcast Hash Joins

376

377

```scala

378

// Optimize small table joins using broadcast

379

/**
 * Inner-joins a large pair RDD with a small one without shuffling the large side.
 *
 * The small RDD is collected to the driver as a map and broadcast; each record
 * of the large RDD is then matched against that map.
 *
 * Because the small side is collapsed into a map, only one value per key
 * survives there; duplicate keys on the small side do not multiply rows the
 * way a regular join would. The small side must also fit in driver and
 * executor memory.
 *
 * @param largeRDD the side that stays distributed
 * @param smallRDD the side that is collected and broadcast
 * @param sc       context used to create the broadcast variable
 * @return `(key, (largeValue, smallValue))` for every large-side record whose
 *         key is present on the small side
 */
def broadcastHashJoin[K: scala.reflect.ClassTag, V, W: scala.reflect.ClassTag](
    largeRDD: RDD[(K, V)],
    smallRDD: RDD[(K, W)])(implicit sc: SparkContext): RDD[(K, (V, W))] = {

  // collectAsMap is a pair-RDD operation and needs ClassTags for K and W,
  // hence the context bounds above.
  val lookup = sc.broadcast(smallRDD.collectAsMap())

  // Map-side join: keep a record only when its key has a match.
  largeRDD.flatMap { case (key, value) =>
    lookup.value.get(key).map(matched => (key, (value, matched)))
  }
}

395

396

// Usage

397

// parseTx and parseLookup are expected to produce (key, value) pairs
val largeTx = sc.textFile("large-transactions.txt").map(parseTx)
val smallLookup = sc.textFile("small-lookup.txt").map(parseLookup)

// The context is passed explicitly: broadcastHashJoin takes it as an implicit
// parameter and no implicit SparkContext is declared in this snippet.
val joined = broadcastHashJoin(largeTx, smallLookup)(sc)

401

```

402

403

### Dynamic Partition Pruning

404

405

```scala

406

// Prune partitions based on predicates

407

/**
 * Returns a view of `rdd` that contains only the partitions whose index
 * satisfies `partitionPredicate`; tasks are launched for those partitions only.
 *
 * Delegates to Spark's own `PartitionPruningRDD`, which renumbers the kept
 * partitions and wires the narrow dependency back to the parent.
 *
 * @param rdd                the RDD to prune
 * @param partitionPredicate receives a parent partition index; `true` keeps it
 */
def prunePartitions[T](
    rdd: RDD[T],
    partitionPredicate: Int => Boolean): RDD[T] =
  org.apache.spark.rdd.PartitionPruningRDD.create(rdd, partitionPredicate)

417

418

// Custom RDD that only processes selected partitions

419

/** A kept parent partition, renumbered to its position within the pruned RDD. */
private class PrunedPartition(val index: Int, val parentSplit: Partition) extends Partition

/**
 * RDD exposing only a chosen subset of its parent's partitions.
 *
 * The kept partitions are renumbered 0..n-1, because an RDD's partitions must
 * be indexed by their position. A narrow dependency maps each new index back
 * to the parent partition it came from.
 *
 * @param parent             the RDD being pruned
 * @param selectedPartitions partitions of `parent` to keep, in output order
 */
class PartitionPrunedRDD[T: scala.reflect.ClassTag](
    parent: RDD[T],
    selectedPartitions: Array[Partition]) extends RDD[T](parent.context, Nil) {

  override protected def getPartitions: Array[Partition] =
    selectedPartitions.zipWithIndex.map { case (split, position) =>
      new PrunedPartition(position, split): Partition
    }

  override protected def getDependencies: Seq[org.apache.spark.Dependency[_]] = {
    // Copied to a local so the dependency does not capture this RDD.
    val parentIndexOf = selectedPartitions.map(_.index)
    Seq(new org.apache.spark.NarrowDependency[T](parent) {
      override def getParents(partitionId: Int): Seq[Int] = Seq(parentIndexOf(partitionId))
    })
  }

  override def compute(split: Partition, context: TaskContext): Iterator[T] = split match {
    case pruned: PrunedPartition =>
      // iterator(), unlike compute(), honours the parent's cache and checkpoint.
      parent.iterator(pruned.parentSplit, context)
    case other =>
      throw new IllegalArgumentException(
        s"Unexpected partition type: ${other.getClass.getName}")
  }
}

429

```

430

431

### Partition-Level Caching Strategy

432

433

```scala

434

// Smart caching based on partition access patterns

435

/**
 * Tracks how often each partition of an RDD is accessed and flags a partition
 * as cached once it has been accessed more than `cacheThreshold` times.
 *
 * State is held in plain fields with no synchronisation, so an instance is
 * meant to be used from a single thread.
 *
 * @param rdd            RDD whose partition count sizes the access table
 * @param cacheThreshold a partition is cached when its access count exceeds this
 */
class SmartCacheManager[T](rdd: RDD[T], cacheThreshold: Long = 10L) {
  // The array reference never changes; only its elements are updated.
  private val accessCounts = Array.fill(rdd.getNumPartitions)(0L)
  private var cachedPartitions = Set.empty[Int]

  /**
   * Records one access to `partitionId` and caches the partition when it
   * crosses the threshold for the first time.
   *
   * @throws IllegalArgumentException if `partitionId` is not a valid partition index
   */
  def accessPartition(partitionId: Int): Unit = {
    require(
      partitionId >= 0 && partitionId < accessCounts.length,
      s"Partition id $partitionId is outside [0, ${accessCounts.length}).")

    accessCounts(partitionId) += 1

    val isHot = accessCounts(partitionId) > cacheThreshold
    if (isHot && !cachedPartitions.contains(partitionId)) {
      cachePartition(partitionId)
    }
  }

  private def cachePartition(partitionId: Int): Unit = {
    // Implementation would use custom caching logic
    cachedPartitions += partitionId
    println(s"Cached partition $partitionId")
  }

  /** Snapshot of the per-partition access counts (a copy, safe to mutate). */
  def getAccessStats: Array[Long] = accessCounts.clone()
}

456

```

457

458

## Best Practices

459

460

### Choosing the Right Partitioner

461

462

```scala

463

// Guidelines for partitioner selection

464

/**
 * Picks a partitioner for `rdd` from the kind of operation and a few data hints.
 *
 * Recognised `dataCharacteristics` entries (all optional):
 *  - "partitions": Int, target partition count (default 200)
 *  - "skewed": Boolean, whether keys are heavily skewed (default false)
 *  - "sorted": Boolean, whether the data is already sorted (default false)
 *
 * Selection:
 *  - "sort" on unsorted data    -> RangePartitioner
 *  - "join" / "groupBy" on skew -> SkewAwarePartitioner
 *  - everything else            -> HashPartitioner
 *
 * @throws IllegalArgumentException if a recognised entry has the wrong type
 */
def choosePartitioner[K: Ordering: scala.reflect.ClassTag, V](
    rdd: RDD[(K, V)],
    operationType: String,
    dataCharacteristics: Map[String, Any]): Partitioner = {

  // Typed lookups: a wrongly typed entry is reported by name instead of
  // surfacing as a ClassCastException from an unchecked cast.
  def flag(name: String): Boolean = dataCharacteristics.get(name) match {
    case None => false
    case Some(value: Boolean) => value
    case Some(other) =>
      throw new IllegalArgumentException(s"'$name' must be a Boolean, got: $other")
  }

  val numPartitions = dataCharacteristics.get("partitions") match {
    case None => 200
    case Some(count: Int) => count
    case Some(other) =>
      throw new IllegalArgumentException(s"'partitions' must be an Int, got: $other")
  }
  val isSkewed = flag("skewed")
  val isSorted = flag("sorted")

  operationType match {
    // RangePartitioner needs the Ordering and ClassTag for K declared above.
    case "sort" if !isSorted => new RangePartitioner(numPartitions, rdd)
    case "join" | "groupBy" if isSkewed => new SkewAwarePartitioner(rdd, numPartitions)
    case _ => new HashPartitioner(numPartitions) // Default
  }
}

483

```

484

485

### Partition Size Guidelines

486

487

```scala

488

// Calculate optimal partition count

489

/**
 * Suggests a partition count for a data set of `dataSize` bytes.
 *
 * The count is the number of `targetPartitionSize` chunks needed to hold the
 * data (rounded up), raised to at least two partitions per available core,
 * and finally capped at `maxPartitions`.
 *
 * @param dataSize            total data size in bytes; negative values count as 0
 * @param targetPartitionSize desired bytes per partition; must be positive
 * @param maxPartitions       upper bound on the result
 * @throws IllegalArgumentException if `targetPartitionSize` is not positive
 */
def calculateOptimalPartitions(
    dataSize: Long,
    targetPartitionSize: Long = 128 * 1024 * 1024, // 128MB
    maxPartitions: Int = 2000): Int = {

  require(targetPartitionSize > 0, s"targetPartitionSize ($targetPartitionSize) must be positive.")

  // Ceiling division without the (a + b - 1) overflow risk.
  val size = math.max(dataSize, 0L)
  val chunks = size / targetPartitionSize + (if (size % targetPartitionSize != 0) 1L else 0L)

  // Cap while still a Long: converting a large quotient straight to Int wraps
  // around and can even turn negative.
  val sizeBased = math.min(chunks, maxPartitions.toLong).toInt

  val cores = Runtime.getRuntime.availableProcessors()
  val minPartitions = cores * 2 // At least 2 partitions per core

  math.min(maxPartitions, math.max(minPartitions, sizeBased))
}

500

501

// Monitor partition sizes

502

/**
 * Counts the records in every partition of `rdd` and prints count, average,
 * maximum, minimum and the skew ratio (largest partition / average).
 * A warning is printed when the skew ratio exceeds 2.0.
 *
 * This runs a Spark job that iterates over the whole RDD.
 */
def analyzePartitionSizes[T](rdd: RDD[T]): Unit = {
  // Counted as Long so a very large partition cannot overflow the counter.
  val partitionSizes: Array[Long] = rdd.mapPartitions { iter =>
    Iterator(iter.foldLeft(0L)((seen, _) => seen + 1L))
  }.collect()

  if (partitionSizes.isEmpty) {
    // max/min are undefined for an RDD without partitions.
    println("Partition Analysis: RDD has no partitions")
  } else {
    val totalRecords = partitionSizes.sum
    val avgSize = totalRecords.toDouble / partitionSizes.length
    val maxSize = partitionSizes.max
    val minSize = partitionSizes.min
    // With no records at all there is nothing to be skewed; avoid 0 / 0 = NaN.
    val skewRatio = if (totalRecords == 0L) 1.0 else maxSize.toDouble / avgSize

    println(s"Partition Analysis:")
    println(s" Count: ${partitionSizes.length}")
    println(s" Average size: $avgSize")
    println(s" Max size: $maxSize")
    println(s" Min size: $minSize")
    println(s" Skew ratio: $skewRatio")

    if (skewRatio > 2.0) {
      println(" WARNING: High partition skew detected!")
    }
  }
}

523

```

524

525

### Memory-Efficient Partitioning

526

527

```scala

528

// Memory-aware partition processing

529

/**
 * Applies `processFunc` to each partition in fixed-size batches, requesting a
 * garbage collection whenever heap usage has grown by more than
 * `maxMemoryPerPartition` since the partition started.
 *
 * Heap usage is read from the JVM `Runtime`, so it covers everything running
 * in that JVM and not only this partition; the limit is a heuristic trigger,
 * not a hard cap.
 *
 * @param rdd                   input RDD
 * @param processFunc           transformation applied to every batch
 * @param maxMemoryPerPartition heap growth in bytes that triggers a GC request
 * @param batchSize             records handed to `processFunc` per call; must be positive
 */
def processPartitionsWithMemoryControl[T, U: scala.reflect.ClassTag](
    rdd: RDD[T],
    processFunc: Iterator[T] => Iterator[U],
    maxMemoryPerPartition: Long = 512 * 1024 * 1024, // 512MB
    batchSize: Int = 1000): RDD[U] = {

  require(batchSize > 0, s"batchSize ($batchSize) must be positive.")

  rdd.mapPartitions { partition =>
    val runtime = Runtime.getRuntime
    def usedMemory: Long = runtime.totalMemory() - runtime.freeMemory()
    val baseline = usedMemory

    partition.grouped(batchSize).flatMap { batch =>
      if (usedMemory - baseline > maxMemoryPerPartition) {
        // Only a hint to the JVM. No sleep follows: pausing the task thread
        // frees no memory and merely delays the task.
        System.gc()
      }
      processFunc(batch.iterator)
    }
  }
}

552

```