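# Task Context

`TaskContext` gives code running inside a task access to runtime information about that task (partition, stage, and attempt identifiers), its metrics, memory management, and lifecycle hooks for cleanup.

## TaskContext

The context object for the currently running task. Inside a task (for example, within `mapPartitions`), obtain it with `TaskContext.get()`. Outside a running task, such as on the driver, `TaskContext.get()` returns `null`.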
```scala { .api }
abstract class TaskContext {
  // Task Identification
  def partitionId(): Int
  def stageId(): Int
  def stageAttemptNumber(): Int
  def taskAttemptId(): Long    // unique among all task attempts in the same SparkContext
  def attemptNumber(): Int     // 0 for the first attempt, incremented on each retry
  def taskMetrics(): TaskMetrics

  // Memory Management (private[spark]: not accessible from user code)
  def taskMemoryManager(): TaskMemoryManager

  // Lifecycle Hooks
  def addTaskCompletionListener(listener: TaskCompletionListener): TaskContext
  def addTaskFailureListener(listener: TaskFailureListener): TaskContext

  // Status Information
  def isCompleted(): Boolean
  def isRunningLocally(): Boolean      // deprecated since Spark 2.0; always returns false
  def isInterrupted(): Boolean
  def getKillReason(): Option[String]  // private[spark]
  def killTaskIfInterrupted(): Unit    // private[spark]

  // Resources and Metrics
  def resources(): Map[String, ResourceInformation]
  def cpus(): Int
  def getMetricsSources(sourceName: String): Seq[Source]

  // Local Properties
  def getLocalProperty(key: String): String

  // Internal Lifecycle (markTaskCompleted and markTaskFailed are private[spark])
  def setTaskThread(thread: Thread): Unit
  def markTaskCompleted(error: Option[Throwable]): Unit
  def markTaskFailed(error: Throwable): Unit
}

object TaskContext {
  def get(): TaskContext     // null when called outside a running task
  def getPartitionId(): Int  // 0 when called outside a running task
}
```
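These identifiers behave differently when a task is retried: `partitionId()` stays the same, `attemptNumber()` is 0 for the first attempt and increases with each retry, and `taskAttemptId()` is unique among all task attempts in the same SparkContext. The sketch below shows one way to use that difference when naming output. `TaskIds` and `TaskNaming` are hypothetical stand-ins for values read from `TaskContext`, and the `_temporary/...` layout is an illustrative assumption, not Spark's output committer protocol.

```scala
// Hypothetical stand-in for the identifiers a task reads from TaskContext.
final case class TaskIds(stageId: Int, partitionId: Int, attemptNumber: Int, taskAttemptId: Long)

object TaskNaming {
  // Stable across retries: depends only on the partition.
  def finalName(ids: TaskIds): String = f"part-${ids.partitionId}%05d"

  // Unique per attempt: a retried task never overwrites an earlier attempt's temporary file.
  def tempName(ids: TaskIds): String =
    s"_temporary/stage-${ids.stageId}/attempt-${ids.taskAttemptId}/${finalName(ids)}"
}

val firstTry = TaskIds(stageId = 3, partitionId = 7, attemptNumber = 0, taskAttemptId = 41L)
val retry = firstTry.copy(attemptNumber = 1, taskAttemptId = 58L)

println(TaskNaming.finalName(retry)) // part-00007
println(TaskNaming.tempName(retry))  // _temporary/stage-3/attempt-58/part-00007
```

Inside a task, the values would come from the context, for example `TaskIds(ctx.stageId(), ctx.partitionId(), ctx.attemptNumber(), ctx.taskAttemptId())` with `val ctx = TaskContext.get()`.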
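## Task Listeners

Interfaces for handling task lifecycle events, both defined in `org.apache.spark.util`. A completion listener runs whenever the task finishes, whether it succeeded, failed, or was cancelled; a failure listener runs only when the task fails. Register them with `TaskContext.addTaskCompletionListener` and `TaskContext.addTaskFailureListener`.

```scala { .api }
trait TaskCompletionListener {
  def onTaskCompletion(context: TaskContext): Unit
}

trait TaskFailureListener {
  def onTaskFailure(context: TaskContext, error: Throwable): Unit
}
```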
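The relative order of the two listener kinds matters when both touch the same resource. Completion listeners run in every case and failure listeners only on failure; in Spark's implementation, failure listeners run before completion listeners, and each kind is invoked in reverse order of registration. The self-contained model below reproduces that ordering so it can be checked without a cluster. `ModelTaskContext` and `runTask` are hypothetical and omit details such as exceptions thrown by listeners and cancellation.

```scala
import scala.collection.mutable.ListBuffer

// Simplified model of the listener semantics described above; not Spark's implementation.
final class ModelTaskContext {
  private val completionListeners = ListBuffer.empty[ModelTaskContext => Unit]
  private val failureListeners = ListBuffer.empty[(ModelTaskContext, Throwable) => Unit]

  def addTaskCompletionListener(f: ModelTaskContext => Unit): this.type = {
    completionListeners += f
    this
  }

  def addTaskFailureListener(f: (ModelTaskContext, Throwable) => Unit): this.type = {
    failureListeners += f
    this
  }

  // Runs the task body, then failure listeners (only on error), then completion
  // listeners (always). Each kind runs in reverse order of registration.
  def runTask[T](body: => T): Either[Throwable, T] = {
    val outcome: Either[Throwable, T] =
      try Right(body)
      catch { case e: Exception => Left(e) }
    outcome match {
      case Left(e)  => failureListeners.reverse.foreach(listener => listener(this, e))
      case Right(_) => ()
    }
    completionListeners.reverse.foreach(listener => listener(this))
    outcome
  }
}

val events = ListBuffer.empty[String]
val ctx = new ModelTaskContext
ctx.addTaskCompletionListener(_ => events += "close connection")
  .addTaskCompletionListener(_ => events += "flush buffer")
  .addTaskFailureListener((_, e) => events += s"log failure: ${e.getMessage}")

val outcome = ctx.runTask[Int](throw new RuntimeException("boom"))
println(events.mkString(" -> ")) // log failure: boom -> flush buffer -> close connection
```

Registering the connection before the buffer that writes to it means the buffer is flushed before the connection is closed, which is why reverse-order invocation is convenient for nested resources.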
## TaskMetrics

Metrics and statistics collected during task execution. Time values are in milliseconds, except `executorDeserializeCpuTime` and `executorCpuTime`, which are in nanoseconds.

```scala { .api }
class TaskMetrics {
  // Execution Metrics
  def executorDeserializeTime: Long     // ms
  def executorDeserializeCpuTime: Long  // ns
  def executorRunTime: Long             // ms
  def executorCpuTime: Long             // ns
  def resultSize: Long                  // bytes
  def jvmGCTime: Long                   // ms
  def resultSerializationTime: Long     // ms
  def memoryBytesSpilled: Long          // bytes
  def diskBytesSpilled: Long            // bytes
  def peakExecutionMemory: Long         // bytes

  // I/O Metrics
  def inputMetrics: InputMetrics
  def outputMetrics: OutputMetrics
  def shuffleReadMetrics: ShuffleReadMetrics
  def shuffleWriteMetrics: ShuffleWriteMetrics

  // Update Methods (private[spark]: updated by Spark while the task runs, not by user code)
  def incExecutorDeserializeTime(v: Long): Unit
  def incExecutorRunTime(v: Long): Unit
  def incResultSize(v: Long): Unit
  def incJvmGCTime(v: Long): Unit
  def incResultSerializationTime(v: Long): Unit
  def incMemoryBytesSpilled(v: Long): Unit
  def incDiskBytesSpilled(v: Long): Unit
  def setPeakExecutionMemory(v: Long): Unit
}
```
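Because the time metrics mix units, ratios need a conversion first. A minimal sketch, assuming a hypothetical `MetricsSnapshot` case class holding values copied from `TaskMetrics`, converts `executorCpuTime` from nanoseconds before comparing it with `executorRunTime`, and computes the share of run time spent in garbage collection:

```scala
import java.util.concurrent.TimeUnit

// Hypothetical snapshot of values copied from TaskMetrics (units as noted).
final case class MetricsSnapshot(
    executorRunTimeMs: Long,  // TaskMetrics.executorRunTime (milliseconds)
    executorCpuTimeNs: Long,  // TaskMetrics.executorCpuTime (nanoseconds)
    jvmGCTimeMs: Long         // TaskMetrics.jvmGCTime (milliseconds)
)

object MetricsRatios {
  // Fraction of run time spent on CPU; low values suggest the task mostly waited on I/O.
  def cpuShare(m: MetricsSnapshot): Double =
    if (m.executorRunTimeMs == 0) 0.0
    else TimeUnit.NANOSECONDS.toMillis(m.executorCpuTimeNs).toDouble / m.executorRunTimeMs

  // Fraction of run time spent in JVM garbage collection.
  def gcShare(m: MetricsSnapshot): Double =
    if (m.executorRunTimeMs == 0) 0.0
    else m.jvmGCTimeMs.toDouble / m.executorRunTimeMs
}

val snapshot = MetricsSnapshot(executorRunTimeMs = 2000, executorCpuTimeNs = 1500000000L, jvmGCTimeMs = 100)
println(MetricsRatios.cpuShare(snapshot)) // 0.75
println(MetricsRatios.gcShare(snapshot))  // 0.05
```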
## Input/Output/Shuffle Metrics

Detailed metrics for each kind of I/O, available through `TaskMetrics.inputMetrics`, `outputMetrics`, `shuffleReadMetrics`, and `shuffleWriteMetrics`. Sizes are in bytes; `fetchWaitTime` is in milliseconds and `writeTime` is in nanoseconds.

```scala { .api }
class InputMetrics {
  def bytesRead: Long
  def recordsRead: Long
  def incBytesRead(v: Long): Unit
  def incRecordsRead(v: Long): Unit
  def setBytesRead(v: Long): Unit
  def setRecordsRead(v: Long): Unit
}

class OutputMetrics {
  def bytesWritten: Long
  def recordsWritten: Long
  def incBytesWritten(v: Long): Unit
  def incRecordsWritten(v: Long): Unit
  def setBytesWritten(v: Long): Unit
  def setRecordsWritten(v: Long): Unit
}

class ShuffleReadMetrics {
  def remoteBlocksFetched: Long
  def localBlocksFetched: Long
  def fetchWaitTime: Long  // ms
  def remoteBytesRead: Long
  def remoteBytesReadToDisk: Long
  def localBytesRead: Long
  def recordsRead: Long

  def incRemoteBlocksFetched(v: Long): Unit
  def incLocalBlocksFetched(v: Long): Unit
  def incFetchWaitTime(v: Long): Unit
  def incRemoteBytesRead(v: Long): Unit
  def incRemoteBytesReadToDisk(v: Long): Unit
  def incLocalBytesRead(v: Long): Unit
  def incRecordsRead(v: Long): Unit
}

class ShuffleWriteMetrics {
  def bytesWritten: Long
  def recordsWritten: Long
  def writeTime: Long  // ns

  def incBytesWritten(v: Long): Unit
  def incRecordsWritten(v: Long): Unit
  def incWriteTime(v: Long): Unit
  def setBytesWritten(v: Long): Unit
  def setRecordsWritten(v: Long): Unit
  def setWriteTime(v: Long): Unit
}
```
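Shuffle reads are split into local and remote parts, so totals and ratios have to be assembled from both. A minimal sketch, assuming a hypothetical `ShuffleRead` case class that mirrors a few `ShuffleReadMetrics` fields (Spark's class exposes a similar `totalBytesRead`), rolls per-task values up to a stage-level total and computes the share of bytes fetched over the network:

```scala
// Hypothetical mirror of a few ShuffleReadMetrics fields for one task.
final case class ShuffleRead(remoteBytesRead: Long, localBytesRead: Long,
                             remoteBlocksFetched: Long, localBlocksFetched: Long,
                             fetchWaitTimeMs: Long) {
  def totalBytesRead: Long = remoteBytesRead + localBytesRead
  def totalBlocksFetched: Long = remoteBlocksFetched + localBlocksFetched

  // Field-by-field sum, used to roll task-level metrics up to a stage-level view.
  def +(other: ShuffleRead): ShuffleRead = ShuffleRead(
    remoteBytesRead + other.remoteBytesRead,
    localBytesRead + other.localBytesRead,
    remoteBlocksFetched + other.remoteBlocksFetched,
    localBlocksFetched + other.localBlocksFetched,
    fetchWaitTimeMs + other.fetchWaitTimeMs)
}

object ShuffleSummary {
  def total(tasks: Seq[ShuffleRead]): ShuffleRead =
    tasks.foldLeft(ShuffleRead(0, 0, 0, 0, 0))(_ + _)

  // Share of shuffle bytes that crossed the network (0.0 when nothing was read).
  def remoteFraction(r: ShuffleRead): Double =
    if (r.totalBytesRead == 0) 0.0 else r.remoteBytesRead.toDouble / r.totalBytesRead
}

val perTask = Seq(
  ShuffleRead(remoteBytesRead = 600, localBytesRead = 400, remoteBlocksFetched = 3, localBlocksFetched = 2, fetchWaitTimeMs = 12),
  ShuffleRead(remoteBytesRead = 200, localBytesRead = 800, remoteBlocksFetched = 1, localBlocksFetched = 4, fetchWaitTimeMs = 3))
val stage = ShuffleSummary.total(perTask)
println(stage.totalBytesRead)                 // 2000
println(ShuffleSummary.remoteFraction(stage)) // 0.4
```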
## TaskMemoryManager

Per-task manager for execution memory and memory pages. It is an internal Spark component: `TaskContext.taskMemoryManager()` is `private[spark]`, and Spark's own operators use it through `MemoryConsumer` subclasses.

```scala { .api }
class TaskMemoryManager(memoryManager: MemoryManager, taskAttemptId: Long) {
  // Memory Acquisition
  def acquireExecutionMemory(required: Long, consumer: MemoryConsumer): Long  // may grant less than required
  def releaseExecutionMemory(size: Long, consumer: MemoryConsumer): Unit
  def releaseAllExecutionMemoryForConsumer(consumer: MemoryConsumer): Long

  // Memory Information
  def executionMemoryUsed: Long
  def getTungstenMemoryMode: MemoryMode

  // Page Management (for advanced users)
  def allocatePage(size: Long, consumer: MemoryConsumer): MemoryBlock
  def freePage(page: MemoryBlock, consumer: MemoryConsumer): Unit
  def getPage(pagePlusOffsetAddress: Long): MemoryBlock
  def encodePageNumberAndOffset(page: MemoryBlock, offsetInPage: Long): Long
  def encodePageNumberAndOffset(pageNumber: Int, offsetInPage: Long): Long
  def decodePageNumber(pagePlusOffsetAddress: Long): Int
  def decodeOffset(pagePlusOffsetAddress: Long): Long

  // Cleanup
  def cleanUpAllAllocatedMemory(): Long
}
```
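`encodePageNumberAndOffset` packs a page number and an offset within that page into a single `Long`. Spark reserves the upper 13 bits for the page number and the lower 51 bits for the offset. The sketch below reproduces that bit layout in plain Scala so encoding and decoding can be checked by hand; `PageAddress` is illustrative, not Spark's class.

```scala
// Illustrative re-implementation of the address bit layout; not Spark's class.
object PageAddress {
  val PageNumberBits = 13
  val OffsetBits = 64 - PageNumberBits              // 51
  val MaxPages = 1 << PageNumberBits                // 8192 addressable pages
  private val OffsetMask = (1L << OffsetBits) - 1   // lower 51 bits set

  def encode(pageNumber: Int, offsetInPage: Long): Long = {
    require(pageNumber >= 0 && pageNumber < MaxPages, s"page number $pageNumber out of range")
    (pageNumber.toLong << OffsetBits) | (offsetInPage & OffsetMask)
  }

  // Unsigned shift, so page numbers that set the sign bit still decode correctly.
  def decodePageNumber(address: Long): Int = (address >>> OffsetBits).toInt

  def decodeOffset(address: Long): Long = address & OffsetMask
}

val address = PageAddress.encode(pageNumber = 5, offsetInPage = 1024L)
println(PageAddress.decodePageNumber(address)) // 5
println(PageAddress.decodeOffset(address))     // 1024
```

The 13-bit page number is what limits a task to 8192 pages allocated at the same time.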
## Usage Examples

The examples below assume an existing `SparkContext` named `sc`, as in `spark-shell`. Output from `println` inside a task goes to the executor's standard output (its logs), not to the driver console, except in local mode.

### Basic Task Context Usage

```scala
import org.apache.spark.TaskContext

val data = sc.parallelize(1 to 100, 4)

val processedData = data.mapPartitions { iter =>
  val context = TaskContext.get()

  // Get partition information
  val partitionId = context.partitionId()
  val stageId = context.stageId()
  val attemptId = context.taskAttemptId()

  println(s"Processing partition $partitionId in stage $stageId (task attempt ID $attemptId)")

  // Tag each value with its partition ID and its square
  iter.map { value =>
    (partitionId, value, value * value)
  }
}

val result = processedData.collect()
```
### Task Completion Listeners

```scala
import org.apache.spark.TaskContext
import org.apache.spark.util.TaskCompletionListener

val data = sc.parallelize(1 to 1000, 10)

val processedData = data.mapPartitions { iter =>
  val context = TaskContext.get()
  val startNanos = System.nanoTime()

  // Register the cleanup listener before processing; it runs on success, failure, or cancellation
  context.addTaskCompletionListener(new TaskCompletionListener {
    override def onTaskCompletion(context: TaskContext): Unit = {
      val metrics = context.taskMetrics()
      // executorRunTime may not be populated yet while listeners run, so measure elapsed time directly
      val elapsedMs = (System.nanoTime() - startNanos) / 1000000L
      println(s"Task ${context.taskAttemptId()} completed in $elapsedMs ms")
      println(s"Memory spilled: ${metrics.memoryBytesSpilled} bytes")

      // Clean up resources (cleanupResources is a hypothetical helper)
      cleanupResources()
    }
  })

  // Process data (expensiveProcessing is a hypothetical per-record function)
  iter.map(expensiveProcessing)
}

processedData.count()
```
### Task Failure Handling

```scala
import org.apache.spark.TaskContext
import org.apache.spark.util.TaskFailureListener

val unreliableData = sc.parallelize(1 to 100)

val processedData = unreliableData.mapPartitions { iter =>
  val context = TaskContext.get()

  // Register failure listener for cleanup; it runs only if the task fails
  context.addTaskFailureListener(new TaskFailureListener {
    override def onTaskFailure(context: TaskContext, error: Throwable): Unit = {
      println(s"Task ${context.taskAttemptId()} failed: ${error.getMessage}")

      // Clean up external resources (hypothetical helpers)
      cleanupExternalConnections()
      savePartialResults(context.partitionId())
    }
  })

  // Potentially failing operation (processValue is a hypothetical helper)
  iter.map { value =>
    if (value % 50 == 0) {
      throw new RuntimeException(s"Simulated failure on value $value")
    }
    processValue(value)
  }
}

// The failure is deterministic, so every retry fails as well and count()
// throws an org.apache.spark.SparkException once the job is aborted
processedData.count()
```
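### Memory Management

`TaskContext.taskMemoryManager()` is `private[spark]`, so this pattern only compiles for code in the `org.apache.spark` package, such as Spark's own operators.

```scala
import org.apache.spark.TaskContext
import org.apache.spark.memory.{MemoryConsumer, MemoryMode}
import org.apache.spark.util.TaskCompletionListener

val data = sc.parallelize(1 to 1000000)

val processedData = data.mapPartitions { iter =>
  val context = TaskContext.get()
  val memoryManager = context.taskMemoryManager()

  // Custom memory consumer for large operations
  val consumer = new MemoryConsumer(memoryManager, memoryManager.pageSizeBytes(), MemoryMode.ON_HEAP) {
    override def spill(size: Long, trigger: MemoryConsumer): Long = {
      // spillToTemp is a hypothetical helper that returns the number of bytes freed
      spillToTemp(size)
    }
  }

  // Acquire memory for processing; the grant may be smaller than requested
  val requiredMemory = 1024L * 1024 // 1 MB
  val acquiredMemory = memoryManager.acquireExecutionMemory(requiredMemory, consumer)
  if (acquiredMemory < requiredMemory) {
    println(s"Only acquired $acquiredMemory bytes out of $requiredMemory requested")
  }

  // Release the memory when the task ends. The iterator returned below is lazy, so
  // releasing it in a finally block here would free the memory before it is used.
  context.addTaskCompletionListener(new TaskCompletionListener {
    override def onTaskCompletion(context: TaskContext): Unit =
      memoryManager.releaseExecutionMemory(acquiredMemory, consumer)
  })

  // Process data with the allocated memory (4 bytes per Int);
  // processWithBuffer is a hypothetical helper that returns an Iterator
  val buffer = new Array[Int]((acquiredMemory / 4).toInt)
  processWithBuffer(iter, buffer)
}

processedData.count()
```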
### Monitoring Task Progress

```scala
import org.apache.spark.TaskContext
import org.apache.spark.util.{AccumulatorV2, TaskCompletionListener}

// Custom accumulator for tracking progress
class ProgressAccumulator extends AccumulatorV2[Long, Long] {
  private var _value = 0L

  def isZero: Boolean = _value == 0
  def copy(): ProgressAccumulator = {
    val acc = new ProgressAccumulator
    acc._value = _value
    acc
  }
  def reset(): Unit = _value = 0
  def add(v: Long): Unit = _value += v
  def merge(other: AccumulatorV2[Long, Long]): Unit = _value += other.value
  def value: Long = _value
}

val progressTracker = new ProgressAccumulator
sc.register(progressTracker, "Progress Tracker")

val data = sc.parallelize(1 to 10000, 10)

val processedData = data.mapPartitions { iter =>
  val context = TaskContext.get()
  val partitionId = context.partitionId()
  var processed = 0L

  // Register completion listener to report final progress
  context.addTaskCompletionListener(new TaskCompletionListener {
    override def onTaskCompletion(context: TaskContext): Unit = {
      println(s"Partition $partitionId processed $processed records")
    }
  })

  iter.map { value =>
    val result = processValue(value) // hypothetical per-record function
    processed += 1
    // Update global progress. Spark applies accumulator updates exactly once only
    // inside actions; here, in a transformation, a re-executed task can count twice.
    progressTracker.add(1)

    // Report progress periodically
    if (processed % 100 == 0) {
      println(s"Partition $partitionId: $processed records processed")
    }

    result
  }
}

processedData.count()
println(s"Total records processed: ${progressTracker.value}")
```
### Resource Information

```scala
import org.apache.spark.TaskContext

val data = sc.parallelize(1 to 100)

val batchResults = data.mapPartitions { iter =>
  val context = TaskContext.get()

  // Get the resources (e.g. GPUs) and CPU count allocated to this task
  val resources = context.resources()
  val cpus = context.cpus()

  println(s"Task has $cpus CPUs available")
  resources.foreach { case (name, info) =>
    println(s"Resource $name: ${info.addresses.mkString(", ")}")
  }

  // Adapt batch size to the available CPUs (processBatch is a hypothetical helper)
  val batchSize = if (cpus > 4) 1000 else 100

  iter.grouped(batchSize).map(processBatch)
}.collect()
```
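## Best Practices

### Task Context Usage

1. **Always check availability**: `TaskContext.get()` returns `null` outside a running task (for example, on the driver), so guard the call, for instance with `Option(TaskContext.get())`
2. **Register listeners early**: Add listeners at the start of the task, before any work that could fail
3. **Clean up resources**: Use completion listeners for cleanup; they run whether the task succeeds, fails, or is cancelled
4. **Handle interruption**: Check `isInterrupted()` periodically in long-running operations so cancelled tasks stop promptly
5. **Partition-aware processing**: Use the partition ID, which stays the same across retries, for per-partition coordination such as naming output files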
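Checking `isInterrupted()` on every record adds overhead, so long-running loops often poll it every N records. Spark's own `InterruptibleIterator` wraps an iterator for a similar purpose. The standalone sketch below shows the idea; `CancellableIterator` and its `checkEvery` parameter are hypothetical, and inside a real task the flag function would be `() => ctx.isInterrupted()` with `val ctx = TaskContext.get()`. It throws `InterruptedException` for simplicity, whereas Spark signals a killed task with its own exception type.

```scala
import scala.collection.mutable.ListBuffer

// Hypothetical wrapper: polls a cancellation flag every `checkEvery` records.
final class CancellableIterator[T](underlying: Iterator[T],
                                   isInterrupted: () => Boolean,
                                   checkEvery: Int = 1000) extends Iterator[T] {
  require(checkEvery > 0, "checkEvery must be positive")
  private var count = 0L

  override def hasNext: Boolean = {
    if (count % checkEvery == 0 && isInterrupted())
      throw new InterruptedException(s"Task interrupted after $count records")
    underlying.hasNext
  }

  override def next(): T = {
    count += 1
    underlying.next()
  }
}

var cancelled = false
val it = new CancellableIterator(Iterator.from(1), () => cancelled, checkEvery = 10)
val consumed = ListBuffer.empty[Int]
while (consumed.size < 20 && it.hasNext) consumed += it.next()
cancelled = true // e.g. the job was cancelled from the driver
// The next hasNext call (after 20 records, a multiple of 10) throws InterruptedException.
```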
### Memory Management

1. **Acquire before use**: Request execution memory before making large allocations
2. **Release promptly**: Release memory as soon as processing is complete
3. **Handle insufficient memory**: Handle grants smaller than requested gracefully, for example by processing in smaller batches
4. **Implement spilling**: Provide spill logic in memory consumers so Spark can reclaim memory under pressure
5. **Monitor usage**: Track spills and peak usage through task metrics (`memoryBytesSpilled`, `diskBytesSpilled`, `peakExecutionMemory`)
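The acquire, check, and release steps above can be illustrated with a toy pool. `MemoryPool` and `withMemory` are hypothetical and much simpler than Spark's `TaskMemoryManager`, which can also ask other consumers to spill to satisfy a request. The point is the caller's contract: use only what was actually granted, adapt when the grant is partial, and release in a `finally` so memory is returned even if processing throws (for lazily consumed iterators, release in a completion listener instead).

```scala
// Toy model of an execution-memory pool; not Spark's TaskMemoryManager.
final class MemoryPool(capacity: Long) {
  private var used = 0L

  // Grants up to `required` bytes; may return less (even 0) when the pool is nearly full.
  def acquire(required: Long): Long = synchronized {
    val granted = math.min(required, capacity - used)
    used += granted
    granted
  }

  def release(bytes: Long): Unit = synchronized {
    require(bytes <= used, s"releasing $bytes bytes but only $used are in use")
    used -= bytes
  }

  def inUse: Long = synchronized(used)
}

object MemoryPool {
  // Acquire, run `body` with the granted amount, and always release afterwards.
  def withMemory[T](pool: MemoryPool, required: Long)(body: Long => T): T = {
    val granted = pool.acquire(required)
    try body(granted)
    finally pool.release(granted)
  }
}

val pool = new MemoryPool(capacity = 1000)
val held = pool.acquire(700) // another consumer holds 700 bytes
val batch = MemoryPool.withMemory(pool, required = 500) { granted =>
  // Fall back to smaller batches when the grant is partial (here 300 of 500 bytes).
  if (granted < 500) "small batches" else "full batch"
}
println(batch)      // small batches
println(pool.inUse) // 700: the 300 bytes granted above were released
```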
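### Error Handling

1. **Register failure listeners**: Use them for failure-specific work such as logging the error; put cleanup of external resources in completion listeners, which also run on failure
2. **Graceful degradation**: Handle partial failures appropriately
3. **Resource cleanup**: Ensure cleanup happens even when a task fails or is cancelled
4. **Logging**: Log detailed error information for debugging
5. **Retry logic**: Spark already retries failed tasks (up to `spark.task.maxFailures` attempts, 4 by default), so keep task side effects idempotent and add in-task retries only for transient errors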
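Because Spark already retries whole tasks, in-task retries are best limited to operations known to fail transiently, such as calls to an external service, so that deterministic bugs still fail fast. A minimal sketch of such a helper with exponential backoff follows; `Retry.retry` and the `isTransient` predicate are hypothetical names, and what counts as transient is an assumption the caller supplies.

```scala
import scala.annotation.tailrec
import scala.util.{Failure, Success, Try}

object Retry {
  // Retries `op` up to `maxAttempts` times, but only for errors `isTransient` accepts.
  // Other errors, and the last transient one, propagate so Spark can fail or retry the task.
  def retry[T](maxAttempts: Int, initialBackoffMs: Long, isTransient: Throwable => Boolean)(op: => T): T = {
    require(maxAttempts >= 1, "maxAttempts must be at least 1")

    @tailrec
    def loop(attempt: Int, backoffMs: Long): T = Try(op) match {
      case Success(value) => value
      case Failure(e) if attempt < maxAttempts && isTransient(e) =>
        Thread.sleep(backoffMs)
        loop(attempt + 1, backoffMs * 2)
      case Failure(e) => throw e
    }

    loop(1, initialBackoffMs)
  }
}

var calls = 0
val result = Retry.retry[String](maxAttempts = 3, initialBackoffMs = 1, isTransient = _.isInstanceOf[java.io.IOException]) {
  calls += 1
  if (calls < 3) throw new java.io.IOException("connection reset") else "ok"
}
println(s"$result after $calls calls") // ok after 3 calls
```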