TaskContext provides runtime information and utilities available to tasks during execution, including partition information, memory management, and lifecycle hooks.
The abstract class below is the main context object available to running tasks:
abstract class TaskContext {
// Task Identification
def partitionId(): Int
def stageId(): Int
def stageAttemptNumber(): Int
def taskAttemptId(): Long
def attemptNumber(): Int
def taskMetrics(): TaskMetrics
// Memory Management
def taskMemoryManager(): TaskMemoryManager
// Lifecycle Hooks
def registerTaskCompletionListener(listener: TaskCompletionListener): TaskContext
def registerTaskFailureListener(listener: TaskFailureListener): TaskContext
def addTaskCompletionListener(listener: TaskCompletionListener): TaskContext
def addTaskFailureListener(listener: TaskFailureListener): TaskContext
// Status Information
def isCompleted(): Boolean
def isRunningLocally(): Boolean
def isInterrupted(): Boolean
def getKillReason(): Option[String]
def killTaskIfInterrupted(): Unit
// Resources and Metrics
def resources(): Map[String, ResourceInformation]
def cpus(): Int
def getMetricsSources(sourceName: String): Seq[Source]
// Local Properties
def getLocalProperty(key: String): String
// Internal Bookkeeping (used by Spark)
def setTaskThread(thread: Thread): Unit
def markTaskCompleted(error: Option[Throwable]): Unit
def markTaskFailed(error: Throwable): Unit
}
object TaskContext {
def get(): TaskContext
def getPartitionId(): Int
}
Interfaces for handling task lifecycle events:
trait TaskCompletionListener {
def onTaskCompletion(context: TaskContext): Unit
}
trait TaskFailureListener {
def onTaskFailure(context: TaskContext, error: Throwable): Unit
}
Metrics and statistics collected during task execution:
class TaskMetrics {
// Execution Metrics
def executorDeserializeTime: Long
def executorDeserializeCpuTime: Long
def executorRunTime: Long
def executorCpuTime: Long
def resultSize: Long
def jvmGCTime: Long
def resultSerializationTime: Long
def memoryBytesSpilled: Long
def diskBytesSpilled: Long
def peakExecutionMemory: Long
// I/O Metrics
def inputMetrics: InputMetrics
def outputMetrics: OutputMetrics
def shuffleReadMetrics: ShuffleReadMetrics
def shuffleWriteMetrics: ShuffleWriteMetrics
// Update Methods
def incExecutorDeserializeTime(v: Long): Unit
def incExecutorRunTime(v: Long): Unit
def incResultSize(v: Long): Unit
def incJvmGCTime(v: Long): Unit
def incResultSerializationTime(v: Long): Unit
def incMemoryBytesSpilled(v: Long): Unit
def incDiskBytesSpilled(v: Long): Unit
def setPeakExecutionMemory(v: Long): Unit
}
Detailed metrics for different types of I/O operations.
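For example, a task can register a completion listener that logs a few of these counters once the task finishes. A minimal sketch, assuming a hypothetical input path; the individual metric classes are listed below:
import org.apache.spark.TaskContext
import org.apache.spark.util.TaskCompletionListener
val counts = sc.textFile("/data/words.txt") // hypothetical input path
  .mapPartitions { iter =>
    val ctx = TaskContext.get()
    ctx.addTaskCompletionListener(new TaskCompletionListener {
      override def onTaskCompletion(context: TaskContext): Unit = {
        val m = context.taskMetrics()
        // The listener fires after the shuffle write, so both counters are populated.
        println(s"Input: ${m.inputMetrics.bytesRead} bytes, ${m.inputMetrics.recordsRead} records")
        println(s"Shuffle write: ${m.shuffleWriteMetrics.bytesWritten} bytes, ${m.shuffleWriteMetrics.recordsWritten} records")
      }
    })
    iter
  }
  .flatMap(_.split("\\s+"))
  .map(word => (word, 1L))
  .reduceByKey(_ + _)
counts.count()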
class InputMetrics {
def bytesRead: Long
def recordsRead: Long
def incBytesRead(v: Long): Unit
def incRecordsRead(v: Long): Unit
def setBytesRead(v: Long): Unit
def setRecordsRead(v: Long): Unit
}
class OutputMetrics {
def bytesWritten: Long
def recordsWritten: Long
def incBytesWritten(v: Long): Unit
def incRecordsWritten(v: Long): Unit
def setBytesWritten(v: Long): Unit
def setRecordsWritten(v: Long): Unit
}
class ShuffleReadMetrics {
def remoteBlocksFetched: Long
def localBlocksFetched: Long
def fetchWaitTime: Long
def remoteBytesRead: Long
def remoteBytesReadToDisk: Long
def localBytesRead: Long
def recordsRead: Long
def incRemoteBlocksFetched(v: Long): Unit
def incLocalBlocksFetched(v: Long): Unit
def incFetchWaitTime(v: Long): Unit
def incRemoteBytesRead(v: Long): Unit
def incRemoteBytesReadToDisk(v: Long): Unit
def incLocalBytesRead(v: Long): Unit
def incRecordsRead(v: Long): Unit
}
class ShuffleWriteMetrics {
def bytesWritten: Long
def recordsWritten: Long
def writeTime: Long
def incBytesWritten(v: Long): Unit
def incRecordsWritten(v: Long): Unit
def incWriteTime(v: Long): Unit
def setBytesWritten(v: Long): Unit
def setRecordsWritten(v: Long): Unit
def setWriteTime(v: Long): Unit
}
Memory management interface for tasks:
class TaskMemoryManager(memoryManager: MemoryManager, taskAttemptId: Long) {
// Memory Acquisition
def acquireExecutionMemory(required: Long, consumer: MemoryConsumer): Long
def releaseExecutionMemory(size: Long, consumer: MemoryConsumer): Unit
def releaseAllExecutionMemoryForConsumer(consumer: MemoryConsumer): Long
// Memory Information
def executionMemoryUsed: Long
def getTungstenMemoryMode: MemoryMode
// Page Management (for advanced users)
def allocatePage(size: Long, consumer: MemoryConsumer): MemoryBlock
def freePage(page: MemoryBlock, consumer: MemoryConsumer): Unit
def getPage(pagePlusOffsetAddress: Long): MemoryBlock
def encodePageNumberAndOffset(page: MemoryBlock, offsetInPage: Long): Long
def encodePageNumberAndOffset(pageNumber: Int, offsetInPage: Long): Long
def decodePageNumber(pagePlusOffsetAddress: Long): Int
def decodeOffset(pagePlusOffsetAddress: Long): Long
// Cleanup
def cleanUpAllAllocatedMemory(): Long
}
Accessing partition and stage information from within a task:
import org.apache.spark.TaskContext
val data = sc.parallelize(1 to 100, 4)
val processedData = data.mapPartitions { iter =>
val context = TaskContext.get()
// Get partition information
val partitionId = context.partitionId()
val stageId = context.stageId()
val attemptId = context.taskAttemptId()
println(s"Processing partition $partitionId in stage $stageId (attempt $attemptId)")
// Process data with context information
iter.map { value =>
(partitionId, value, value * value)
}
}
val result = processedData.collect()
Registering a completion listener to report metrics and release resources when a task finishes:
import org.apache.spark.TaskContext
import org.apache.spark.util.TaskCompletionListener
val data = sc.parallelize(1 to 1000, 10)
val processedData = data.mapPartitions { iter =>
val context = TaskContext.get()
// Register cleanup listener
context.addTaskCompletionListener(new TaskCompletionListener {
override def onTaskCompletion(context: TaskContext): Unit = {
val metrics = context.taskMetrics()
println(s"Task ${context.taskAttemptId()} completed in ${metrics.executorRunTime} ms")
println(s"Memory spilled: ${metrics.memoryBytesSpilled} bytes")
// Cleanup resources
cleanupResources()
}
})
// Process data
iter.map(expensiveProcessing)
}
processedData.count()
Registering a failure listener to clean up external resources when a task fails:
import org.apache.spark.TaskContext
import org.apache.spark.util.TaskFailureListener
val unreliableData = sc.parallelize(1 to 100)
val processedData = unreliableData.mapPartitions { iter =>
val context = TaskContext.get()
// Register failure listener for cleanup
context.addTaskFailureListener(new TaskFailureListener {
override def onTaskFailure(context: TaskContext, error: Throwable): Unit = {
println(s"Task ${context.taskAttemptId()} failed: ${error.getMessage}")
// Cleanup external resources
cleanupExternalConnections()
savePartialResults(context.partitionId())
}
})
// Potentially failing operation
iter.map { value =>
if (value % 50 == 0) {
throw new RuntimeException(s"Simulated failure on value $value")
}
processValue(value)
}
}
processedData.count()
Acquiring execution memory through the task's TaskMemoryManager with a custom MemoryConsumer:
import org.apache.spark.TaskContext
import org.apache.spark.memory.MemoryConsumer
val data = sc.parallelize(1 to 1000000)
val processedData = data.mapPartitions { iter =>
val context = TaskContext.get()
val memoryManager = context.taskMemoryManager()
// Custom memory consumer for large operations
val consumer = new MemoryConsumer(memoryManager) {
override def spill(size: Long, trigger: MemoryConsumer): Long = {
// Custom spill logic
spillToTemp(size)
}
}
try {
// Acquire memory for processing
val requiredMemory = 1024 * 1024 // 1MB
val acquiredMemory = memoryManager.acquireExecutionMemory(requiredMemory, consumer)
if (acquiredMemory < requiredMemory) {
println(s"Only acquired $acquiredMemory bytes out of $requiredMemory requested")
}
// Process data with allocated memory
val buffer = new Array[Int](acquiredMemory.toInt / 4) // 4 bytes per int
processWithBuffer(iter, buffer)
} finally {
// Release memory
memoryManager.releaseAllExecutionMemoryForConsumer(consumer)
}
}
processedData.count()
Combining a custom accumulator with a completion listener to track per-partition progress:
import org.apache.spark.TaskContext
import org.apache.spark.util.{AccumulatorV2, TaskCompletionListener}
// Custom accumulator for tracking progress
class ProgressAccumulator extends AccumulatorV2[Long, Long] {
private var _value = 0L
def isZero: Boolean = _value == 0
def copy(): ProgressAccumulator = {
val acc = new ProgressAccumulator
acc._value = _value
acc
}
def reset(): Unit = _value = 0
def add(v: Long): Unit = _value += v
def merge(other: AccumulatorV2[Long, Long]): Unit = _value += other.value
def value: Long = _value
}
val progressTracker = new ProgressAccumulator
sc.register(progressTracker, "Progress Tracker")
val data = sc.parallelize(1 to 10000, 10)
val processedData = data.mapPartitions { iter =>
val context = TaskContext.get()
val partitionId = context.partitionId()
var processed = 0L
// Register completion listener to report final progress
context.addTaskCompletionListener(new TaskCompletionListener {
override def onTaskCompletion(context: TaskContext): Unit = {
println(s"Partition $partitionId processed $processed records")
}
})
iter.map { value =>
val result = processValue(value)
processed += 1
progressTracker.add(1) // Update global progress
// Report progress periodically
if (processed % 100 == 0) {
println(s"Partition $partitionId: $processed records processed")
}
result
}
}
processedData.count()
println(s"Total records processed: ${progressTracker.value}")val data = sc.parallelize(1 to 100)
val resourceInfo = data.mapPartitions { iter =>
val context = TaskContext.get()
// Get available resources
val resources = context.resources()
val cpus = context.cpus()
println(s"Task has $cpus CPUs available")
resources.foreach { case (resourceName, info) =>
println(s"Resource $resourceName: ${info.addresses.mkString(", ")}")
}
// Adapt processing based on available resources
val batchSize = if (cpus > 4) 1000 else 100
iter.grouped(batchSize).map(processBatch)
}.collect()
Two practices apply across all of these examples: check the result of TaskContext.get() and handle the null case, since it returns null outside of a running task (for example on the driver), and check isInterrupted() in long-running operations so that killed or cancelled tasks stop promptly.
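A minimal sketch of both practices; describeCaller and the per-record doubling are illustrative placeholders:
import org.apache.spark.TaskContext
// TaskContext.get() returns null outside of a running task, so guard before using it.
def describeCaller(): String = {
  val ctx = TaskContext.get()
  if (ctx == null) "driver" else s"task ${ctx.taskAttemptId()} (partition ${ctx.partitionId()})"
}
println(describeCaller()) // prints "driver"
val records = sc.parallelize(1 to 100000, 8)
val doubled = records.mapPartitions { iter =>
  println(describeCaller()) // prints "task ... (partition ...)" in the executor logs
  val ctx = TaskContext.get()
  iter.zipWithIndex.map { case (value, i) =>
    // Check for interruption periodically so a killed or cancelled task stops promptly.
    if (i % 1000 == 0 && ctx.isInterrupted()) {
      throw new InterruptedException(s"Task ${ctx.taskAttemptId()} was interrupted")
    }
    value * 2
  }
}
doubled.count()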