or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

- broadcast-accumulators.md
- context-configuration.md
- index.md
- java-api.md
- key-value-operations.md
- rdd-operations.md
- status-monitoring.md
- storage-persistence.md
- task-context.md

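docs/task-context.md

# Task Context

`TaskContext` provides runtime information about, and control over, the task currently executing on an executor. Inside a task, `TaskContext.get()` returns the context for that task; on the driver it returns `null`. The context exposes task metadata such as stage, partition, and attempt identifiers. It also lets you register callbacks that run when the task completes or fails, which is the standard way to clean up resources opened inside a task.

## TaskContext Class

```scala { .api }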

/**
 * Per-task handle available to user code running inside a Spark task.
 * Obtain it with `TaskContext.get()` inside the task; it is not meaningful on the driver.
 */
abstract class TaskContext {
  // Task information
  /** True once the task has completed. */
  def isCompleted(): Boolean
  /** True once the task has been killed; long-running task code should check this and stop. */
  def isInterrupted(): Boolean
  // NOTE(review): deprecated in Spark 2.0 (always returns false) and removed in Spark 3.0.
  def isRunningLocally(): Boolean
  def stageId(): Int
  /** How many times this task's stage has been attempted (0 for the first attempt). */
  def stageAttemptNumber(): Int
  /** Index of the RDD partition this task computes. */
  def partitionId(): Int
  /** How many times this task has been attempted (0 for the first attempt). */
  def attemptNumber(): Int
  /** Id unique to this task attempt within the SparkContext. */
  def taskAttemptId(): Long
  // NOTE(review): not found in the public TaskContext API; SparkEnv.get.executorId is the
  // usual way to read the executor id — verify against the target Spark version.
  def executorId(): String
  // NOTE(review): the Spark 3.x method is `resources(): Map[String, ResourceInformation]`
  // (custom resources such as GPUs) — verify the name against the target version.
  def resourcesAllocated(): Map[String, ResourceInformation]

  // Task properties
  /** Local property set on the driver thread that submitted this job, or null if unset. */
  def getLocalProperty(key: String): String
  // NOTE(review): the four methods below do not appear in the public TaskContext API
  // (getLocalProperties is Spark-internal) — verify before relying on them.
  def getLocalProperties(): Properties
  def setTaskProperty(key: String, value: String): Unit
  def getTaskProperty(key: String): String
  def getTaskProperties(): Properties

  // Listeners and callbacks
  /** Runs when the task ends for any reason: success, failure, or cancellation. */
  def addTaskCompletionListener(listener: TaskCompletionListener): TaskContext
  def addTaskCompletionListener(f: TaskContext => Unit): TaskContext
  /** Runs only when the task fails, before the completion listeners are invoked. */
  def addTaskFailureListener(listener: TaskFailureListener): TaskContext
  def addTaskFailureListener(f: (TaskContext, Throwable) => Unit): TaskContext

  // Metrics and monitoring
  /** Metrics object for this task; it is updated while the task runs. */
  def taskMetrics(): TaskMetrics
  def getMetricsSources(sourceName: String): Seq[Source]
}

```

### TaskContext Object

```scala { .api }

object TaskContext {
  /** Context of the task running on the current thread, or null when not inside a task (e.g. on the driver). */
  def get(): TaskContext
  /** Partition id of the current task, or 0 when not inside a task. */
  def getPartitionId(): Int
  // NOTE(review): setTaskContext and unset are Spark-internal (protected[spark]);
  // user code cannot call them — verify against the target Spark version.
  def setTaskContext(tc: TaskContext): Unit
  def unset(): Unit
}

```

## Task Information Access

### Basic Task Information

```scala

import org.apache.spark.{SparkConf, SparkContext, SparkEnv, TaskContext}

val sc = new SparkContext(new SparkConf().setAppName("TaskContext Example").setMaster("local[*]"))

val data = sc.parallelize(1 to 100, numSlices = 4)

// Access task information during a transformation. TaskContext.get() is only
// meaningful inside a running task; on the driver it returns null.
val taskInfo = data.mapPartitionsWithIndex { (partitionIndex, iterator) =>
  val context = TaskContext.get()
  // The executor id comes from SparkEnv rather than from TaskContext.
  val executorId = SparkEnv.get.executorId

  iterator.map { value =>
    // context.partitionId() equals partitionIndex here; both are shown for illustration
    val info = s"Value: $value, " +
      s"Partition: ${context.partitionId()}, " +
      s"Stage: ${context.stageId()}, " +
      s"Attempt: ${context.attemptNumber()}, " +
      s"TaskID: ${context.taskAttemptId()}, " +
      s"Executor: $executorId"

    (value, info)
  }
}

// Collect and print results
taskInfo.take(10).foreach { case (_, info) =>
  println(info)
}

```

### Task State Monitoring

```scala

val monitoredData = data.mapPartitions { iterator =>
  val context = TaskContext.get()

  iterator.map { value =>
    // Cooperative cancellation: stop promptly once the task has been killed
    if (context.isInterrupted()) {
      throw new InterruptedException("Task was interrupted")
    }

    // isRunningLocally() is intentionally not used: it always returned false
    // since Spark 2.0 and was removed in Spark 3.0.
    // isCompleted() is still false here because the task body is still running.
    val executionInfo: Map[String, AnyVal] = Map(
      "isCompleted" -> context.isCompleted(),
      "partitionId" -> context.partitionId(),
      "stageAttempt" -> context.stageAttemptNumber()
    )

    (value, executionInfo)
  }
}

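```

## Local Properties

Local properties are set on the driver thread with `SparkContext.setLocalProperty`. They are captured when an action submits a job, and every task of that job can read them with `TaskContext.getLocalProperty`. Changing or clearing a property after the job has been submitted does not affect that job, but clearing it before the action runs means the tasks see `null`.

```scala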

// Set properties on the driver thread. They are captured when an action submits
// a job, NOT when a transformation is defined.
// These keys are user-defined; Spark's own job group/description are set with
// sc.setJobGroup / sc.setJobDescription.
sc.setLocalProperty("job.group.id", "data-processing")
sc.setLocalProperty("job.description", "Processing customer data")
sc.setLocalProperty("custom.property", "custom-value")

val processedData = data.mapPartitions { iterator =>
  val context = TaskContext.get()

  // Read the properties once per partition; getLocalProperty returns null if unset
  val jobGroup = context.getLocalProperty("job.group.id")
  val jobDescription = context.getLocalProperty("job.description")
  val customProperty = context.getLocalProperty("custom.property")

  iterator.map { value =>
    // Use properties in processing logic
    val metadata = Map(
      "jobGroup" -> jobGroup,
      "description" -> jobDescription,
      "custom" -> customProperty
    )

    (value, metadata)
  }
}

try {
  // Run the action while the properties are still set, so this job's tasks see them
  processedData.take(5).foreach(println)
} finally {
  // Clear every property that was set (a null value removes it)
  sc.setLocalProperty("job.group.id", null)
  sc.setLocalProperty("job.description", null)
  sc.setLocalProperty("custom.property", null)
}

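```

## Task Completion and Failure Listeners

TaskContext lets you register callbacks for task lifecycle events. Completion listeners run when the task ends for any reason: success, failure, or cancellation. Failure listeners run only when the task fails, and they run before the completion listeners.

### Completion Listeners

```scala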

import scala.util.control.NonFatal

val dataWithResources = data.mapPartitions { iterator =>
  val context = TaskContext.get()

  // Runs one cleanup action, logging (not rethrowing) failures so the remaining
  // resources are still released.
  def cleanup(resource: String)(action: => Unit): Unit =
    try action
    catch { case NonFatal(e) => println(s"Error during cleanup of $resource: ${e.getMessage}") }

  // Register each cleanup right after its resource is created: if a later create
  // call throws, the earlier resources are still released when the task ends
  // (completion listeners also run on failure).
  val tempFile = createTempFile()
  context.addTaskCompletionListener { _ => cleanup("temp file") { tempFile.delete() } }

  val databaseConnection = createDatabaseConnection()
  context.addTaskCompletionListener { _ => cleanup("database connection") { databaseConnection.close() } }

  val httpClient = createHttpClient()
  context.addTaskCompletionListener { _ => cleanup("HTTP client") { httpClient.close() } }

  context.addTaskCompletionListener { taskContext =>
    println(s"Task ${taskContext.taskAttemptId()} completed on partition ${taskContext.partitionId()}")
  }

  // Process data using resources
  iterator.map { value =>
    processWithResources(value, databaseConnection, httpClient)
  }
}

```

### Failure Listeners

```scala

val robustProcessing = data.mapPartitions { iterator =>
  val context = TaskContext.get()
  // Per-task counters, updated as records are consumed and read by the failure listener
  var processed = 0L
  var failed = 0L

  // Register failure listener for error reporting
  context.addTaskFailureListener { (taskContext, exception) =>
    val errorReport =
      s"""Task Failure Report:
         |  Task ID: ${taskContext.taskAttemptId()}
         |  Partition: ${taskContext.partitionId()}
         |  Stage: ${taskContext.stageId()}
         |  Attempt: ${taskContext.attemptNumber()}
         |  Exception: ${exception.getClass.getSimpleName}
         |  Message: ${Option(exception.getMessage).getOrElse("<none>")}
         |  Records Processed: $processed
         |  Records Failed: $failed
         |""".stripMargin

    println(errorReport)
    logErrorToMonitoringSystem(errorReport)
  }

  iterator.map { value =>
    try {
      val result = riskyProcessing(value)
      processed += 1
      result
    } catch {
      case e: Exception =>
        failed += 1
        throw e // rethrow so the task fails and the failure listener fires
    }
  }
}

```

### Combined Resource Management

```scala

import java.io.FileWriter
import java.util.concurrent.atomic.AtomicBoolean
import scala.util.control.NonFatal

class ResourceManager extends Serializable {
  /**
   * Opens a database connection and a per-partition log writer, runs `block` with
   * them, and releases both exactly once when the task ends.
   *
   * `block` usually returns a lazy iterator, so cleanup cannot happen when this
   * method returns. It is deferred to a completion listener, which Spark runs on
   * success, failure and cancellation (after any failure listeners). The failure
   * listener only records the failure, so commit vs rollback is decided in one place.
   */
  def withResources[T](context: TaskContext)(block: (DatabaseConnection, FileWriter) => Iterator[T]): Iterator[T] = {
    var connection: DatabaseConnection = null
    var writer: FileWriter = null
    val failed = new AtomicBoolean(false)
    // Guards against double cleanup: an eager failure inside `block` reaches both
    // the catch below and the completion listener.
    val released = new AtomicBoolean(false)

    def releaseOnce(success: Boolean): Unit =
      if (released.compareAndSet(false, true)) cleanupResources(connection, writer, success)

    try {
      // Initialize resources
      connection = new DatabaseConnection()
      writer = new FileWriter(s"output-partition-${context.partitionId()}.log")

      context.addTaskFailureListener { (_, exception) =>
        failed.set(true)
        quietly("logging task failure") {
          writer.write(s"Task failed with exception: ${exception.getMessage}\n")
        }
      }
      context.addTaskCompletionListener { _ =>
        releaseOnce(success = !failed.get())
      }

      // Execute the processing block
      block(connection, writer)
    } catch {
      case NonFatal(e) =>
        releaseOnce(success = false)
        throw e
    }
  }

  /** Writes the final status, commits or rolls back, and closes; each step is isolated. */
  private def cleanupResources(connection: DatabaseConnection, writer: FileWriter, success: Boolean): Unit = {
    if (writer != null) {
      quietly("writing final status")(writer.write(s"Task completed with success: $success\n"))
      quietly("closing writer")(writer.close())
    }
    if (connection != null) {
      quietly(if (success) "committing" else "rolling back") {
        if (success) connection.commit() else connection.rollback()
      }
      quietly("closing connection")(connection.close())
    }
  }

  /** Runs `action`, logging instead of propagating non-fatal errors. */
  private def quietly(step: String)(action: => Unit): Unit =
    try action
    catch { case NonFatal(e) => println(s"Error during resource cleanup ($step): ${e.getMessage}") }
}

// Usage
val resourceManager = new ResourceManager()

val processedData = data.mapPartitions { iterator =>
  val context = TaskContext.get()

  resourceManager.withResources(context) { (connection, writer) =>
    iterator.map { value =>
      writer.write(s"Processing value: $value\n")
      val result = processWithDatabase(value, connection)
      writer.write(s"Result: $result\n")
      result
    }
  }
}

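```

## Task Metrics Access

`TaskContext.taskMetrics()` returns the metrics object for the running task. It is updated as the task runs, so values read from inside the task are a partial snapshot. Some timing fields are only filled in after the task finishes.

```scala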

val metricsCollector = data.mapPartitions { iterator =>
  val context = TaskContext.get()
  // The TaskMetrics object is the same for the whole task; fetch it once.
  val metrics = context.taskMetrics()

  iterator.map { value =>
    // Snapshot read while the task is still running.
    // NOTE(review): executor timing fields (deserialize/run/GC time) are normally set
    // after the task finishes, so expect 0 here — confirm for your Spark version.
    // The *Metrics objects have no readable toString, so individual fields are read.
    val currentMetrics: Map[String, Long] = Map(
      "executorDeserializeTime" -> metrics.executorDeserializeTime,
      "executorRunTime" -> metrics.executorRunTime,
      "jvmGCTime" -> metrics.jvmGCTime,
      "inputBytesRead" -> metrics.inputMetrics.bytesRead,
      "inputRecordsRead" -> metrics.inputMetrics.recordsRead,
      "outputBytesWritten" -> metrics.outputMetrics.bytesWritten,
      "shuffleBytesRead" -> metrics.shuffleReadMetrics.totalBytesRead,
      "shuffleBytesWritten" -> metrics.shuffleWriteMetrics.bytesWritten
    )

    (value, currentMetrics)
  }
}

```

## Advanced Patterns

### Partition-wise Resource Pooling

```scala

import java.util.concurrent.ConcurrentHashMap
import scala.collection.JavaConverters._

/**
 * Executor-JVM-wide registry of connection pools, one per running task attempt.
 *
 * Keyed by taskAttemptId (unique within a SparkContext) rather than partitionId.
 * Partition ids repeat across stages, jobs and retries, so two tasks computing
 * "partition 0" concurrently in one executor would otherwise share a pool, and
 * one would close it underneath the other. Int arguments still compile because
 * they widen to Long.
 */
object PartitionResourcePool {
  private val connectionPools = new ConcurrentHashMap[Long, DatabaseConnectionPool]()

  /** Borrows a connection from this attempt's pool, creating the pool on first use. */
  def getConnection(taskAttemptId: Long): DatabaseConnection =
    connectionPools.computeIfAbsent(taskAttemptId, _ => new DatabaseConnectionPool()).getConnection()

  /** Returns a connection to this attempt's pool; a no-op if the pool was already closed. */
  def releaseConnection(taskAttemptId: Long, connection: DatabaseConnection): Unit =
    Option(connectionPools.get(taskAttemptId)).foreach(_.releaseConnection(connection))

  /** Removes and closes this attempt's pool, if any. */
  def closePool(taskAttemptId: Long): Unit =
    Option(connectionPools.remove(taskAttemptId)).foreach(_.close())
}

val pooledProcessing = data.mapPartitions { iterator =>
  val context = TaskContext.get()
  val taskAttemptId = context.taskAttemptId()

  // Close this attempt's pool when the task ends (success, failure or kill)
  context.addTaskCompletionListener { _ =>
    PartitionResourcePool.closePool(taskAttemptId)
  }

  iterator.map { value =>
    val connection = PartitionResourcePool.getConnection(taskAttemptId)
    try processWithConnection(value, connection)
    finally PartitionResourcePool.releaseConnection(taskAttemptId, connection)
  }
}

```

### Progress Tracking

```scala

class ProgressTracker extends Serializable {
  /**
   * Passes `iterator` through unchanged, printing and reporting progress every
   * `reportEvery` records as they are consumed.
   *
   * @param totalEstimate expected records per partition; only used for the percentage
   * @param reportEvery   report interval in records (must be positive)
   */
  def trackProgress[T](
      iterator: Iterator[T],
      context: TaskContext,
      totalEstimate: Long = 1000000L,
      reportEvery: Long = 10000L): Iterator[T] = {
    require(totalEstimate > 0, "totalEstimate must be positive")
    require(reportEvery > 0, "reportEvery must be positive")
    var processed = 0L
    val startTime = System.currentTimeMillis()

    iterator.map { item =>
      processed += 1

      if (processed % reportEvery == 0) {
        // Clamp to 1 ms so a very fast start cannot yield an infinite rate
        val elapsedMs = math.max(System.currentTimeMillis() - startTime, 1L)
        val rate = processed * 1000.0 / elapsedMs
        val progress = processed * 100.0 / totalEstimate

        // f-interpolation replaces the deprecated `formatted`
        println(f"Partition ${context.partitionId()}: $processed records processed ($progress%.1f%%), " +
          f"Rate: $rate%.0f records/sec")

        // Optional: Send to monitoring system
        sendProgressUpdate(context.stageId(), context.partitionId(), processed, progress, rate)
      }

      item
    }
  }

  private def sendProgressUpdate(stageId: Int, partitionId: Int, processed: Long, progress: Double, rate: Double): Unit = {
    // Implementation to send metrics to monitoring system
  }
}

// Usage
val tracker = new ProgressTracker()

val trackedProcessing = data.mapPartitions { iterator =>
  val context = TaskContext.get()
  tracker.trackProgress(iterator, context).map(processItem)
}

```

### Conditional Processing Based on Context

```scala

// TaskContext.isRunningLocally() always returned false since Spark 2.0 and was
// removed in Spark 3.0, so decide on the driver and capture the flag in the closure.
val runningLocally = sc.isLocal

val conditionalProcessing = data.mapPartitions { iterator =>
  val context = TaskContext.get()

  if (runningLocally) {
    // Local execution - use simpler, debugging-friendly processing
    iterator.map { value =>
      println(s"Local processing (partition ${context.partitionId()}): $value")
      value * 2
    }
  } else {
    // Cluster execution - process in batches; the final, partial batch is processed too.
    // Assumes batchProcess returns one result per input value — TODO confirm.
    val batchSize = 1000
    iterator.grouped(batchSize).flatMap(batch => batchProcess(batch.toArray))
  }
}

```

### Dynamic Resource Allocation

```scala

val adaptiveProcessing = data.mapPartitions { iterator =>
  val context = TaskContext.get()
  val partitionId = context.partitionId()

  // NOTE(review): TaskContext.cpus() (cores reserved for this task, spark.task.cpus)
  // needs Spark 3.3+; resources() needs Spark 3.0+ — confirm for the target version.
  val availableCores = context.cpus()
  // resources() only lists custom resources (GPUs, FPGAs, ...); each entry carries the
  // assigned addresses. There is no "cpus" or "memory" entry.
  val gpuCount = context.resources().get("gpu").map(_.addresses.length).getOrElse(0)
  // Max heap of the whole executor JVM, shared by all tasks running in it
  val availableMemory = Runtime.getRuntime.maxMemory()

  // Adjust processing parameters based on available resources
  val (batchSize, parallelism) =
    if (availableCores > 4 && availableMemory > 2L * 1024 * 1024 * 1024) (10000, 4) // High resource configuration
    else if (availableCores > 2 && availableMemory > 1024L * 1024 * 1024) (5000, 2) // Medium resource configuration
    else (1000, 1) // Low resource configuration

  println(s"Partition $partitionId: cores=$availableCores, gpus=$gpuCount; " +
    s"using batch size $batchSize with parallelism $parallelism")

  // Process with adaptive parameters
  iterator.grouped(batchSize).flatMap { batch =>
    processInParallel(batch, parallelism)
  }
}

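```

## BarrierTaskContext

For barrier execution mode, which applies to stages created with `RDD.barrier()`, Spark provides `BarrierTaskContext`. It adds global synchronization across all tasks of the stage. All tasks of a barrier stage are launched together, so the cluster must have enough free task slots for every partition at once.

```scala { .api }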

/** TaskContext for tasks of a barrier stage (see RDD.barrier()); obtain it with BarrierTaskContext.get(). */
class BarrierTaskContext extends TaskContext {
  /** Blocks until every task in the same barrier stage has reached this call. */
  def barrier(): Unit
  /** Info for all tasks in this barrier stage, ordered by partition id. */
  def getTaskInfos(): Array[BarrierTaskInfo]
  /** Barrier that also exchanges messages: each task passes one String and receives every task's message. */
  def allGather(message: String): Array[String]
}

object BarrierTaskContext {
  /** Context of the barrier task running on the current thread. */
  def get(): BarrierTaskContext
}

/** Information about one task of a barrier stage. */
class BarrierTaskInfo(val address: String)

```

### Barrier Execution Example

```scala

import org.apache.spark.BarrierTaskContext

// Barrier mode is enabled per stage by RDD.barrier(). All 4 tasks must be able to run
// at the same time, so the cluster (or local[N]) needs at least 4 free task slots.
val barrierRDD = sc.parallelize(1 to 100, numSlices = 4)

val synchronizedProcessing = barrierRDD.barrier().mapPartitions { iterator =>
  val context = BarrierTaskContext.get()

  // Phase 1: Local processing (assumes the placeholder `process` is Int => Int)
  val localResults: Array[Int] = iterator.map(process).toArray

  // Synchronization point - all tasks wait here
  context.barrier()

  // Phase 2: Coordinate across partitions
  val taskInfos = context.getTaskInfos()
  println(s"Task addresses: ${taskInfos.map(_.address).mkString(", ")}")

  // Share local statistics across all tasks. allGather exchanges Strings, so the sums
  // are encoded/decoded; Long avoids overflow when adding across partitions.
  val localSum = localResults.map(_.toLong).sum
  val globalSum = context.allGather(localSum.toString).map(_.toLong).sum

  // Phase 3: Process using global information
  localResults.map(_ + globalSum).iterator
}

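```

## Best Practices

### Resource Management

- Always register cleanup listeners for external resources opened inside a task.
- Register each listener right after its resource is created, so that a failure while creating a later resource still releases the earlier ones.
- Use try/finally for cleanup that completes within a single call, such as per-record work. Do not wrap a returned lazy iterator in try/finally: the `finally` runs before any record is consumed. Use a completion listener instead.
- Consider using connection pools for database connections. Key them so that concurrently running tasks do not share, and close, each other's pools.
- Monitor resource usage through task metrics.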

### Error Handling

- Register failure listeners for comprehensive error reporting.
- Log enough context (stage, partition, attempt, and task attempt IDs) to debug failures.
- Implement graceful degradation when possible.
- Use task metrics to identify performance bottlenecks.

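### Performance Optimization

- Access TaskContext only when needed; `TaskContext.get()` is a thread-local lookup, cheap but not free.
- If you use the context several times in the same task, cache the reference, for example once per partition at the top of `mapPartitions`.
- Use local properties for small per-job string tags, such as a job group or scheduler pool. Pass larger or typed configuration through closures or broadcast variables.
- Monitor task execution time and optimize accordingly.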

### Monitoring and Debugging

```scala

import scala.reflect.ClassTag
import org.apache.spark.rdd.RDD

// Comprehensive task monitoring
/**
 * Like `rdd.mapPartitions(f)`, but logs how many output records each task produced
 * and how long it ran, measured from when the partition function starts.
 *
 * `U: ClassTag` is required because RDD.mapPartitions needs a ClassTag for its result type.
 * The count covers records actually consumed downstream (e.g. `take` may stop early).
 */
def monitoredMapPartitions[T, U: ClassTag](rdd: RDD[T])(f: Iterator[T] => Iterator[U]): RDD[U] = {
  rdd.mapPartitions { iterator =>
    val context = TaskContext.get()
    val startTime = System.currentTimeMillis()
    var recordCount = 0L

    // Runs on success, failure and cancellation
    context.addTaskCompletionListener { taskContext =>
      val duration = System.currentTimeMillis() - startTime
      println(s"Task ${taskContext.taskAttemptId()} on partition ${taskContext.partitionId()} " +
        s"processed $recordCount records in ${duration}ms")
    }

    f(iterator).map { result =>
      recordCount += 1
      result
    }
  }
}

```