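# Task Context

TaskContext provides runtime information and control for tasks executing on cluster nodes. It offers access to:

- task metadata (stage, partition and attempt identifiers)
- task metrics
- local properties
- mechanisms for registering completion and failure callbacks

Inside a running task, `TaskContext.get()` returns the current context. On the driver, outside any task, it returns `null`.
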
## TaskContext Class

```scala { .api }
abstract class TaskContext extends Serializable {
  // Task information
  def isCompleted(): Boolean
  def isInterrupted(): Boolean
  def stageId(): Int
  def stageAttemptNumber(): Int
  def partitionId(): Int
  def attemptNumber(): Int
  def taskAttemptId(): Long
  def cpus(): Int                                    // Spark 3.3+
  def resources(): Map[String, ResourceInformation]  // Spark 3.0+

  // Local properties (set on the driver, read-only inside the task)
  def getLocalProperty(key: String): String

  // Listeners and callbacks
  def addTaskCompletionListener(listener: TaskCompletionListener): TaskContext
  def addTaskCompletionListener[U](f: TaskContext => U): TaskContext
  def addTaskFailureListener(listener: TaskFailureListener): TaskContext
  def addTaskFailureListener(f: (TaskContext, Throwable) => Unit): TaskContext

  // Metrics and monitoring
  def taskMetrics(): TaskMetrics
  def getMetricsSources(sourceName: String): Seq[Source]
}
```

### TaskContext Object
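
```scala { .api }
object TaskContext {
  def get(): TaskContext      // null when called outside a running task
  def getPartitionId(): Int   // 0 when there is no active TaskContext

  // Used internally by Spark's executor; not callable from user code
  protected[spark] def setTaskContext(tc: TaskContext): Unit
  protected[spark] def unset(): Unit
}
```

## Task Information Access
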
### Basic Task Information

```scala
import org.apache.spark.{SparkConf, SparkContext, SparkEnv, TaskContext}

val sc = new SparkContext(new SparkConf().setAppName("TaskContext Example").setMaster("local[*]"))

val data = sc.parallelize(1 to 100, numSlices = 4)

// Access task information during a transformation
val taskInfo = data.mapPartitionsWithIndex { (partitionIndex, iterator) =>
  val context = TaskContext.get()
  // TaskContext has no executorId(); the executor ID comes from SparkEnv
  val executorId = SparkEnv.get.executorId

  iterator.map { value =>
    val info = s"Value: $value, " +
      s"Partition: ${context.partitionId()} (index $partitionIndex), " +
      s"Stage: ${context.stageId()}, " +
      s"Attempt: ${context.attemptNumber()}, " +
      s"TaskID: ${context.taskAttemptId()}, " +
      s"Executor: $executorId"

    (value, info)
  }
}

// Collect and print results on the driver
taskInfo.take(10).foreach { case (_, info) =>
  println(info)
}
```

### Task State Monitoring

```scala
val monitoredData = data.mapPartitions { iterator =>
  val context = TaskContext.get()

  iterator.map { value =>
    // Stop promptly if the task has been killed (e.g., job cancellation,
    // or a speculative copy of this task finishing first)
    if (context.isInterrupted()) {
      throw new InterruptedException("Task was interrupted")
    }

    // isCompleted() only becomes true after the task body finishes,
    // so it is not useful while records are still being processed
    val executionInfo = Map(
      "partitionId"  -> context.partitionId(),
      "attempt"      -> context.attemptNumber(),
      "stageAttempt" -> context.stageAttemptNumber()
    )

    (value, executionInfo)
  }
}
```

## Local Properties

Local properties are set on the driver with `SparkContext.setLocalProperty` and are scoped to the calling thread. When an action submits a job, Spark captures them and ships them with every task of that job. Inside the task, `TaskContext.getLocalProperty` reads them. Setting or clearing a property after the job has been submitted does not affect its tasks.

```scala
// Set properties on the driver thread. setJobGroup stores the group ID and
// description under the "spark.jobGroup.id" and "spark.job.description" keys.
sc.setJobGroup("data-processing", "Processing customer data")
sc.setLocalProperty("custom.property", "custom-value")

val processedData = data.mapPartitions { iterator =>
  val context = TaskContext.get()

  // Read properties inside the task (null if a key was never set)
  val jobGroup = context.getLocalProperty("spark.jobGroup.id")
  val jobDescription = context.getLocalProperty("spark.job.description")
  val customProperty = context.getLocalProperty("custom.property")

  iterator.map { value =>
    // Use properties in processing logic
    val metadata = Map(
      "jobGroup" -> jobGroup,
      "description" -> jobDescription,
      "custom" -> customProperty
    )

    (value, metadata)
  }
}

// Properties are captured when the job is submitted, so run the action first
val results = processedData.collect()

// Then clear them so later jobs from this thread do not inherit them
sc.clearJobGroup()
sc.setLocalProperty("custom.property", null)
```
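
The thread scoping can be modeled in plain Scala. `SparkContext` keeps local properties in an `InheritableThreadLocal`. Threads created later receive a copy, and changes a child makes do not flow back to the parent.

The sketch below models that behavior only. It is not Spark code, and `LocalPropertiesModel`, `set`, `get` and `onNewThread` are illustrative names.

```scala
import java.util.Properties

// Plain-Scala model of thread-scoped local properties (illustrative, not a Spark API)
object LocalPropertiesModel {
  private val props: InheritableThreadLocal[Properties] = new InheritableThreadLocal[Properties] {
    override def initialValue(): Properties = new Properties()
    // A child thread gets its own copy, so its changes never reach the parent
    override def childValue(parent: Properties): Properties = {
      val copy = new Properties()
      copy.putAll(parent)
      copy
    }
  }

  def set(key: String, value: String): Unit = props.get().setProperty(key, value)
  def get(key: String): String = props.get().getProperty(key)

  // Runs `body` on a newly created thread and returns its result
  def onNewThread[T](body: => T): T = {
    var result: Option[T] = None
    val thread = new Thread(new Runnable {
      def run(): Unit = result = Some(body)
    })
    thread.start()
    thread.join() // join() makes the child's write to `result` visible here
    result.get
  }
}
```

A thread inherits values only when it is created. Threads in a pool that were started before a property was set will not see it. This is one reason jobs submitted from such pools can miss local properties.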
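
## Task Completion and Failure Listeners

TaskContext lets you register callbacks for task lifecycle events:

- **Completion listeners** run when the task finishes for any reason: success, failure, or cancellation.
- **Failure listeners** run only when the task fails, and they are invoked before the completion listeners.
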
### Completion Listeners

```scala
val dataWithResources = data.mapPartitions { iterator =>
  val context = TaskContext.get()

  // Create resources that need cleanup
  // (createTempFile, createDatabaseConnection and createHttpClient are hypothetical helpers)
  val tempFile = createTempFile()
  val databaseConnection = createDatabaseConnection()
  val httpClient = createHttpClient()

  // Register cleanup. It runs after the partition's iterator has been consumed,
  // whether the task succeeded, failed, or was killed.
  context.addTaskCompletionListener[Unit] { taskContext =>
    println(s"Task ${taskContext.taskAttemptId()} completed on partition ${taskContext.partitionId()}")

    try {
      tempFile.delete()
      databaseConnection.close()
      httpClient.close()
      println("Resources cleaned up successfully")
    } catch {
      case e: Exception =>
        println(s"Error during cleanup: ${e.getMessage}")
    }
  }

  // Process data lazily using the resources
  // (processWithResources is a hypothetical helper)
  iterator.map { value =>
    processWithResources(value, databaseConnection, httpClient)
  }
}
```

### Failure Listeners

```scala
val robustProcessing = data.mapPartitions { iterator =>
  val context = TaskContext.get()
  val metrics = scala.collection.mutable.Map[String, Long]()

  // Register a failure listener for error reporting
  // (logErrorToMonitoringSystem and riskyProcessing are hypothetical helpers)
  context.addTaskFailureListener { (taskContext, exception) =>
    val errorReport =
      s"""Task Failure Report:
         |  Task ID: ${taskContext.taskAttemptId()}
         |  Partition: ${taskContext.partitionId()}
         |  Stage: ${taskContext.stageId()}
         |  Attempt: ${taskContext.attemptNumber()}
         |  Exception: ${exception.getClass.getSimpleName}
         |  Message: ${exception.getMessage}
         |  Records Processed: ${metrics.getOrElse("processed", 0L)}
         |  Records Failed: ${metrics.getOrElse("failed", 0L)}
         |""".stripMargin

    println(errorReport)
    logErrorToMonitoringSystem(errorReport)
  }

  iterator.map { value =>
    try {
      val result = riskyProcessing(value)
      metrics("processed") = metrics.getOrElse("processed", 0L) + 1
      result
    } catch {
      case e: Exception =>
        metrics("failed") = metrics.getOrElse("failed", 0L) + 1
        throw e // rethrow so the task fails and the failure listener fires
    }
  }
}
```

### Combined Resource Management

```scala
import java.io.FileWriter

// DatabaseConnection (with commit/rollback/close) and processWithDatabase are
// hypothetical placeholders for your own resource types and logic.
class ResourceManager extends Serializable {
  def withResources[T](context: TaskContext)(
      block: (DatabaseConnection, FileWriter) => Iterator[T]): Iterator[T] = {
    // Initialize resources; if the writer cannot be opened, close the connection first
    val connection = new DatabaseConnection()
    val writer =
      try new FileWriter(s"output-partition-${context.partitionId()}.log")
      catch {
        case e: Exception =>
          connection.close()
          throw e
      }

    // Failure listeners run before completion listeners, so this flag tells
    // the single cleanup listener below whether to commit or roll back.
    var failed = false
    context.addTaskFailureListener { (_, exception) =>
      failed = true
      writer.write(s"Task failed with exception: ${exception.getMessage}\n")
    }

    // Exactly one cleanup path, run after the lazy iterator is finished
    context.addTaskCompletionListener[Unit] { _ =>
      cleanupResources(connection, writer, success = !failed)
    }

    // The returned iterator is consumed after this method returns, so a
    // try/catch here would not see errors thrown during processing.
    block(connection, writer)
  }

  private def cleanupResources(connection: DatabaseConnection, writer: FileWriter, success: Boolean): Unit = {
    try {
      writer.write(s"Task completed with success: $success\n")
      writer.close()
    } catch {
      case e: Exception => println(s"Error closing writer: ${e.getMessage}")
    }
    try {
      if (success) connection.commit() else connection.rollback()
      connection.close()
    } catch {
      case e: Exception => println(s"Error closing connection: ${e.getMessage}")
    }
  }
}

// Usage
val resourceManager = new ResourceManager()

val processedData = data.mapPartitions { iterator =>
  val context = TaskContext.get()

  resourceManager.withResources(context) { (connection, writer) =>
    iterator.map { value =>
      writer.write(s"Processing value: $value\n")
      val result = processWithDatabase(value, connection)
      writer.write(s"Result: $result\n")
      result
    }
  }
}
```

## Task Metrics Access

TaskContext exposes the running task's `TaskMetrics`. Input, output and shuffle counters accumulate while the task runs, so they are most complete when read from a completion listener.

Timing fields such as `executorRunTime` and `jvmGCTime` are different. The executor records them after the task body returns, so inside the task they typically read as 0. Observe them on the driver with a `SparkListener` (`onTaskEnd`) instead.

```scala
val metricsCollector = data.mapPartitions { iterator =>
  val context = TaskContext.get()

  // Read metrics once, after the partition has been processed, instead of per record
  context.addTaskCompletionListener[Unit] { tc =>
    val metrics = tc.taskMetrics()

    val ioMetrics = Map(
      "inputRecordsRead" -> metrics.inputMetrics.recordsRead,
      "inputBytesRead" -> metrics.inputMetrics.bytesRead,
      "outputBytesWritten" -> metrics.outputMetrics.bytesWritten,
      "shuffleRecordsRead" -> metrics.shuffleReadMetrics.recordsRead,
      "shuffleBytesWritten" -> metrics.shuffleWriteMetrics.bytesWritten
    )

    println(s"Task ${tc.taskAttemptId()} on partition ${tc.partitionId()}: $ioMetrics")
  }

  iterator
}
```

## Advanced Patterns

### Partition-wise Resource Pooling

`PartitionResourcePool` is a JVM-wide singleton on each executor, so its key must be unique among the tasks that can run there concurrently. Partition IDs repeat across stages and jobs: partition 0 of two different stages can run side by side. The pool below is therefore keyed by `taskAttemptId()`, which is unique within the application.

`DatabaseConnectionPool`, `DatabaseConnection` and `processWithConnection` are hypothetical placeholders.

```scala
import java.util.concurrent.ConcurrentHashMap

object PartitionResourcePool {
  // One pool per task attempt; shared executor-wide because this is a singleton object
  private val connectionPools = new ConcurrentHashMap[java.lang.Long, DatabaseConnectionPool]()

  def getConnection(taskAttemptId: Long): DatabaseConnection =
    connectionPools
      .computeIfAbsent(taskAttemptId, _ => new DatabaseConnectionPool())
      .getConnection()

  def releaseConnection(taskAttemptId: Long, connection: DatabaseConnection): Unit =
    Option(connectionPools.get(taskAttemptId)).foreach(_.releaseConnection(connection))

  def closePool(taskAttemptId: Long): Unit =
    Option(connectionPools.remove(taskAttemptId)).foreach(_.close())
}

val pooledProcessing = data.mapPartitions { iterator =>
  val context = TaskContext.get()
  val taskAttemptId = context.taskAttemptId()

  // Close this task's pool once the partition has been processed
  context.addTaskCompletionListener[Unit] { _ =>
    PartitionResourcePool.closePool(taskAttemptId)
  }

  iterator.map { value =>
    val connection = PartitionResourcePool.getConnection(taskAttemptId)
    try {
      processWithConnection(value, connection)
    } finally {
      PartitionResourcePool.releaseConnection(taskAttemptId, connection)
    }
  }
}
```
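
Keying an executor-wide pool registry by partition ID is unsafe when two stages run on the same executor at once, because partition IDs repeat across stages. The Spark-free simulation below makes that concrete: two tasks share partition ID 0, and the sketch checks whether task B's pool survives task A's cleanup under each keying scheme. `Pool`, `Registry` and `Task` are hypothetical stand-ins, not Spark classes.

```scala
import java.util.concurrent.ConcurrentHashMap
import java.util.function.{Function => JFunction}

// Spark-free simulation; Pool, Registry and Task are hypothetical stand-ins.
object PoolKeyDemo {
  final class Pool {
    @volatile var closed = false
    def close(): Unit = closed = true
  }

  // An executor-wide registry, like a singleton object shared by all tasks in one JVM
  final class Registry {
    private val pools = new ConcurrentHashMap[String, Pool]()
    private val newPool = new JFunction[String, Pool] {
      def apply(key: String): Pool = new Pool
    }
    def acquire(key: String): Pool = pools.computeIfAbsent(key, newPool)
    def close(key: String): Unit = Option(pools.remove(key)).foreach(_.close())
  }

  // Two tasks running at the same time: partition 0 of stage 1 and partition 0 of stage 2
  final case class Task(stageId: Int, partitionId: Int, taskAttemptId: Long)
  val taskA = Task(stageId = 1, partitionId = 0, taskAttemptId = 11L)
  val taskB = Task(stageId = 2, partitionId = 0, taskAttemptId = 27L)

  // True if task B's pool is still open after task A's cleanup runs
  def poolBSurvives(key: Task => String): Boolean = {
    val registry = new Registry
    registry.acquire(key(taskA))
    val poolB = registry.acquire(key(taskB))
    registry.close(key(taskA)) // task A finishes and closes "its" pool
    !poolB.closed
  }
}
```

Results for the two keying schemes:

- `PoolKeyDemo.poolBSurvives(t => s"partition-${t.partitionId}")` is `false`. Both tasks map to the same pool, so task A's cleanup closes the pool task B is still using.
- Keying by `taskAttemptId` returns `true`.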

### Progress Tracking

```scala
// sendProgressUpdate is a stub for a hypothetical monitoring client;
// processItem is a hypothetical per-record function.
class ProgressTracker(totalEstimate: Long = 1000000L, reportEvery: Long = 10000L) extends Serializable {
  def trackProgress[T](iterator: Iterator[T], context: TaskContext): Iterator[T] = {
    var processed = 0L
    val startTime = System.currentTimeMillis()

    iterator.map { item =>
      processed += 1

      // Report progress every `reportEvery` records
      if (processed % reportEvery == 0) {
        // Guard against a zero elapsed time on very fast partitions
        val elapsedSec = math.max(System.currentTimeMillis() - startTime, 1L) / 1000.0
        val rate = processed / elapsedSec
        val progress = processed.toDouble / totalEstimate * 100

        println(s"Partition ${context.partitionId()}: $processed records processed " +
          f"($progress%.1f%%), Rate: $rate%.0f records/sec")

        sendProgressUpdate(context.stageId(), context.partitionId(), processed, progress, rate)
      }

      item
    }
  }

  private def sendProgressUpdate(stageId: Int, partitionId: Int, processed: Long, progress: Double, rate: Double): Unit = {
    // Implementation to send metrics to a monitoring system
  }
}

// Usage
val tracker = new ProgressTracker()

val trackedProcessing = data.mapPartitions { iterator =>
  val context = TaskContext.get()
  tracker.trackProgress(iterator, context).map(processItem)
}
```
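
The core of the tracker is a lazy wrapper that counts records as downstream code pulls them. A Spark-free sketch of that technique follows. `ProgressWrapper` and the `onProgress` callback are hypothetical names; the callback stands in for `println` or a monitoring client.

```scala
// Spark-free sketch: wrap an iterator and invoke a callback every `reportEvery` records
object ProgressWrapper {
  def withProgress[T](iterator: Iterator[T], reportEvery: Long)(onProgress: Long => Unit): Iterator[T] = {
    require(reportEvery > 0, "reportEvery must be positive")
    var processed = 0L
    iterator.map { item =>
      processed += 1
      if (processed % reportEvery == 0) onProgress(processed)
      item
    }
  }
}
```

Because the wrapper is lazy, the count reflects only records that were actually consumed. If a downstream `take(5)` stops early, no report fires for a threshold of 10.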

### Conditional Processing Based on Context

`TaskContext.isRunningLocally()` cannot drive this decision. It was deprecated in Spark 2.0, where it always returns `false`, and has since been removed. Decide on the driver with `SparkContext.isLocal` instead, and capture the result in the closure. `batchProcess` is a hypothetical `Seq[Int] => Seq[Int]` that returns one result per input.

```scala
// Evaluated on the driver; the Boolean is shipped to tasks with the closure
val runningLocally = sc.isLocal

val conditionalProcessing = data.mapPartitions { iterator =>
  if (runningLocally) {
    // Local master: simple, debugging-friendly per-record processing
    iterator.map { value =>
      println(s"Local processing: $value (partition ${TaskContext.get().partitionId()})")
      value * 2
    }
  } else {
    // Cluster execution: process in batches of up to 1000 records.
    // grouped() also emits the final, partially filled batch, so no record is dropped.
    val batchSize = 1000
    iterator.grouped(batchSize).flatMap(batch => batchProcess(batch))
  }
}
```
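
Fixed-size batching with `Iterator.grouped` can be checked without Spark. In this sketch, `BatchingDemo`, `batched` and `batchSizes` are illustrative names. Every input element reaches the batch function exactly once, including those in the final partial batch.

```scala
// Spark-free sketch of fixed-size batching over a lazy iterator
object BatchingDemo {
  // Applies a batch function to groups of `batchSize`; the last group may be smaller
  def batched[A, B](input: Iterator[A], batchSize: Int)(process: Seq[A] => Seq[B]): Iterator[B] =
    input.grouped(batchSize).flatMap(batch => process(batch))

  // Records the size of each batch handed to the batch function
  def batchSizes(n: Int, batchSize: Int): List[Int] = {
    val sizes = scala.collection.mutable.ListBuffer[Int]()
    batched(Iterator.range(0, n), batchSize) { batch => sizes += batch.size; batch }.foreach(_ => ())
    sizes.toList
  }
}
```

For 2,500 records and a batch size of 1,000, the batch function sees three batches of 1,000, 1,000 and 500 records.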

### Adapting to Allocated Resources

A task can inspect what the scheduler gave it:

- `TaskContext.cpus()` (Spark 3.3+) returns the number of CPU cores allocated to the task (by default `spark.task.cpus`).
- `resources()` maps custom resource names such as `"gpu"` to a `ResourceInformation`, whose `addresses` identify the assigned devices.

Memory is not reported per task. The JVM-wide `Runtime.getRuntime.maxMemory()` is used below only as a rough proxy. `processInParallel` is a hypothetical helper.

```scala
val adaptiveProcessing = data.mapPartitions { iterator =>
  val context = TaskContext.get()
  val partitionId = context.partitionId()

  // Resources reserved for this task
  val availableCores = context.cpus()
  val gpuCount = context.resources().get("gpu").map(_.addresses.length).getOrElse(0)
  // Executor heap size, shared by all tasks on the executor (not a per-task quota)
  val executorMaxMemory = Runtime.getRuntime.maxMemory()

  // Adjust processing parameters based on the available resources
  val (batchSize, parallelism) =
    if (availableCores > 4 && executorMaxMemory > 2L * 1024 * 1024 * 1024) {
      (10000, 4) // High resource configuration
    } else if (availableCores > 2 && executorMaxMemory > 1024L * 1024 * 1024) {
      (5000, 2) // Medium resource configuration
    } else {
      (1000, 1) // Low resource configuration
    }

  println(s"Partition $partitionId: batch size $batchSize, parallelism $parallelism, GPUs $gpuCount")

  // Process with adaptive parameters; parallelism stays within cpus() so the
  // task does not use more cores than the scheduler reserved for it
  iterator.grouped(batchSize).flatMap { batch =>
    processInParallel(batch, parallelism)
  }
}
```
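
The example leaves `processInParallel` undefined. One possible shape is sketched below; it is an assumption, not a library API. It runs a per-record function over the batch on a fixed-size thread pool and returns the results in input order. Unlike the two-argument call in the example, it takes the per-record function as an extra parameter list.

```scala
import java.util.concurrent.Executors
import scala.concurrent.duration.Duration
import scala.concurrent.{Await, ExecutionContext, Future}

// Hypothetical helper: runs `f` over `batch` with up to `parallelism` threads and
// returns the results in input order. A new pool per call keeps the sketch short.
object ParallelBatch {
  def processInParallel[A, B](batch: Seq[A], parallelism: Int)(f: A => B): Seq[B] =
    if (parallelism <= 1) batch.map(f)
    else {
      val pool = Executors.newFixedThreadPool(parallelism)
      implicit val ec: ExecutionContext = ExecutionContext.fromExecutorService(pool)
      try {
        val futures = batch.map(a => Future(f(a)))
        // Future.sequence keeps results in the order of the input futures
        Await.result(Future.sequence(futures), Duration.Inf)
      } finally {
        pool.shutdown()
      }
    }
}
```

Inside a Spark task, size the pool from `cpus()`. For long-running partitions, create the pool once per task and shut it down from a completion listener rather than once per batch.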
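
## BarrierTaskContext

In barrier execution mode (`rdd.barrier()`), all tasks of a stage are launched together. Each task receives a `BarrierTaskContext` with additional synchronization methods. The scheduler must be able to run every task of the barrier stage at the same time, so the cluster needs at least as many task slots as the stage has partitions.

```scala { .api }
class BarrierTaskContext extends TaskContext {
  def barrier(): Unit                             // Block until all tasks in the stage reach this point
  def allGather(message: String): Array[String]   // Spark 3.0+: barrier plus one message exchanged per task
  def getTaskInfos(): Array[BarrierTaskInfo]
}

object BarrierTaskContext {
  def get(): BarrierTaskContext
}

class BarrierTaskInfo(val address: String)
```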

### Barrier Execution Example

```scala
import org.apache.spark.BarrierTaskContext

// Needs at least 4 concurrent task slots, e.g. master "local[4]" or higher
val barrierRDD = sc.parallelize(1 to 100, numSlices = 4)

val synchronizedProcessing = barrierRDD.barrier().mapPartitions { iterator =>
  val context = BarrierTaskContext.get()

  // Phase 1: local processing (doubling stands in for real per-record work)
  val localResults = iterator.map(_ * 2).toArray

  // Synchronization point: every task in the stage waits here
  context.barrier()

  // Phase 2: coordinate across partitions
  val taskInfos = context.getTaskInfos()
  println(s"Task addresses: ${taskInfos.map(_.address).mkString(", ")}")

  // Share local statistics across all tasks; allGather exchanges Strings
  val localSum = localResults.sum
  val allSums = context.allGather(localSum.toString).map(_.toInt)
  val globalSum = allSums.sum

  // Phase 3: process using global information
  localResults.map(_ + globalSum).iterator
}
```
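
In Spark 3.0+, `BarrierTaskContext.allGather` takes and returns `String`s, so anything richer than a single number has to be encoded. The Spark-free sketch below uses a hypothetical `PartialStats` type. Each task would send `stats.encode` through `allGather`, and every task then decodes the gathered array to compute the same global mean.

```scala
// Spark-free sketch of encoding structured statistics for a String-based all-gather.
// GatherCodec and PartialStats are hypothetical names.
object GatherCodec {
  final case class PartialStats(sum: Long, count: Long) {
    def encode: String = s"$sum,$count"
  }

  def decode(message: String): PartialStats = message.split(',') match {
    case Array(sum, count) => PartialStats(sum.toLong, count.toLong)
    case _ => throw new IllegalArgumentException(s"Malformed message: $message")
  }

  // Combines what every task receives back from the all-gather
  def globalMean(messages: Array[String]): Double = {
    val parts = messages.map(decode)
    val total = parts.map(_.sum).sum
    val count = parts.map(_.count).sum
    if (count == 0) 0.0 else total.toDouble / count
  }
}
```

Gathering `"10,4"` and `"30,6"` gives a global mean of 40 / 10 = 4.0 on every task.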
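
## Best Practices

### Resource Management
- Register a completion listener for every external resource opened inside a task; it runs on success, failure and cancellation
- Use `try`/`finally` only for work that finishes before the function returns; `mapPartitions` returns a lazy iterator, so resources used while that iterator is consumed must be released from a completion listener
- Consider connection pools for database connections, keyed so that concurrently running tasks cannot close each other's pools
- Monitor resource usage through task metrics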
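
The `try`/`finally` pitfall can be reproduced without Spark. In this sketch, `FakeConnection` is a hypothetical stand-in for a real connection:

- `withTryFinally` closes the connection as soon as the iterator is built, before any record is read.
- `closeWhenExhausted` defers the close until the last element has been pulled, which is roughly the effect a completion listener provides.

```scala
// Spark-free sketch of why try/finally does not cover a lazily consumed iterator.
object LazyCleanupDemo {
  // Hypothetical resource that fails if used after close()
  final class FakeConnection {
    var closed = false
    def lookup(x: Int): Int = {
      require(!closed, "connection used after close")
      x * 10
    }
    def close(): Unit = closed = true
  }

  // Anti-pattern: `finally` runs when the iterator is created, not when it is consumed
  def withTryFinally(input: Iterator[Int], conn: FakeConnection): Iterator[Int] =
    try input.map(conn.lookup) finally conn.close()

  // Deferred cleanup: close only after the last element has been pulled
  def closeWhenExhausted(input: Iterator[Int], conn: FakeConnection): Iterator[Int] = {
    val mapped = input.map(conn.lookup)
    new Iterator[Int] {
      def hasNext: Boolean = {
        val more = mapped.hasNext
        if (!more && !conn.closed) conn.close()
        more
      }
      def next(): Int = mapped.next()
    }
  }
}
```

`closeWhenExhausted` has a limitation: it never closes the resource if the consumer stops early, for example after `take`. A completion listener does not have this problem, which is why it remains the safer choice inside Spark tasks.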

### Error Handling
- Register failure listeners for error reporting
- Log enough context (stage, partition, attempt number, task attempt ID) to locate the failing task
- Implement graceful degradation where possible
- Use task metrics to identify performance bottlenecks

### Performance Optimization
- `TaskContext.get()` is a cheap thread-local lookup; call it once at the start of `mapPartitions` and reuse the reference rather than calling it per record
- Use local properties for small, job-level string metadata; ship larger configuration through closures or broadcast variables
- Monitor task execution time and optimize accordingly

### Monitoring and Debugging
```scala
import scala.reflect.ClassTag
import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD

// Comprehensive task monitoring: wraps mapPartitions and logs per-task record counts
def monitoredMapPartitions[T, U: ClassTag](rdd: RDD[T])(f: Iterator[T] => Iterator[U]): RDD[U] = {
  rdd.mapPartitions { iterator =>
    val context = TaskContext.get()
    val startTime = System.currentTimeMillis()
    var recordCount = 0L

    context.addTaskCompletionListener[Unit] { taskContext =>
      val duration = System.currentTimeMillis() - startTime
      println(s"Task ${taskContext.taskAttemptId()} on partition ${taskContext.partitionId()} " +
        s"processed $recordCount records in ${duration}ms")
    }

    f(iterator).map { result =>
      recordCount += 1
      result
    }
  }
}
```