Apache Spark Core provides distributed computing capabilities including RDD abstractions, task scheduling, memory management, and fault recovery.
TaskContext provides runtime information and utilities available to tasks during execution, including partition information, memory management, and lifecycle hooks.
TaskContext is the main context object available to running tasks.
abstract class TaskContext {
// Task Identification
def partitionId(): Int
def stageId(): Int
def stageAttemptNumber(): Int
def taskAttemptId(): Long
def attemptNumber(): Int
def taskMetrics(): TaskMetrics
// Memory Management
def taskMemoryManager(): TaskMemoryManager
// Lifecycle Hooks
def addTaskCompletionListener(listener: TaskCompletionListener): TaskContext
def addTaskFailureListener(listener: TaskFailureListener): TaskContext
// Status Information
def isCompleted(): Boolean
def isRunningLocally(): Boolean // deprecated; always false in modern Spark
def isInterrupted(): Boolean
def getKillReason(): Option[String]
def killTaskIfInterrupted(): Unit
// Resources and Metrics
def resources(): Map[String, ResourceInformation]
def cpus(): Int
def getMetricsSources(sourceName: String): Seq[Source]
// Local Properties
def getLocalProperty(key: String): String
// Internal Lifecycle (private[spark]; invoked by Spark, not by user code)
def setTaskThread(thread: Thread): Unit
def markTaskCompleted(error: Option[Throwable]): Unit
def markTaskFailed(error: Throwable): Unit
}
object TaskContext {
def get(): TaskContext
def getPartitionId(): Int
}
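Because Spark may run a task more than once (retries, speculation), the stage, partition, and attempt identifiers above are useful for keeping side effects from colliding across attempts. A minimal sketch, where writePartition and the path scheme are hypothetical:

import org.apache.spark.TaskContext

val data = sc.parallelize(1 to 100, 4)
data.foreachPartition { iter =>
  val ctx = TaskContext.get()
  // attemptNumber() starts at 0 and increases on each retry of this partition,
  // so every attempt writes to its own temporary path
  val tempPath = s"/tmp/stage-${ctx.stageId()}-part-${ctx.partitionId()}-attempt-${ctx.attemptNumber()}"
  writePartition(tempPath, iter) // hypothetical user-defined writer
}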
Interfaces for handling task lifecycle events.

trait TaskCompletionListener {
def onTaskCompletion(context: TaskContext): Unit
}
trait TaskFailureListener {
def onTaskFailure(context: TaskContext, error: Throwable): Unit
}
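In recent Spark versions, TaskContext also accepts plain function literals for these hooks (an addTaskCompletionListener[U](f: TaskContext => U) overload and its failure-side counterpart), avoiding the anonymous-class boilerplate used in the examples below. A minimal sketch, assuming an existing RDD named data:

import org.apache.spark.TaskContext

data.mapPartitions { iter =>
  val ctx = TaskContext.get()
  // Completion hook as a function literal
  ctx.addTaskCompletionListener[Unit] { c =>
    println(s"Task ${c.taskAttemptId()} finished on partition ${c.partitionId()}")
  }
  // Failure hook as a function literal
  ctx.addTaskFailureListener { (c, error) =>
    println(s"Task ${c.taskAttemptId()} failed: ${error.getMessage}")
  }
  iter
}.count()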
Metrics and statistics collected during task execution.

class TaskMetrics {
// Execution Metrics
def executorDeserializeTime: Long
def executorDeserializeCpuTime: Long
def executorRunTime: Long
def executorCpuTime: Long
def resultSize: Long
def jvmGCTime: Long
def resultSerializationTime: Long
def memoryBytesSpilled: Long
def diskBytesSpilled: Long
def peakExecutionMemory: Long
// I/O Metrics
def inputMetrics: InputMetrics
def outputMetrics: OutputMetrics
def shuffleReadMetrics: ShuffleReadMetrics
def shuffleWriteMetrics: ShuffleWriteMetrics
// Update Methods
def incExecutorDeserializeTime(v: Long): Unit
def incExecutorRunTime(v: Long): Unit
def incResultSize(v: Long): Unit
def incJvmGCTime(v: Long): Unit
def incResultSerializationTime(v: Long): Unit
def incMemoryBytesSpilled(v: Long): Unit
def incDiskBytesSpilled(v: Long): Unit
def setPeakExecutionMemory(v: Long): Unit
}

Detailed metrics for different types of I/O operations.
class InputMetrics {
def bytesRead: Long
def recordsRead: Long
def incBytesRead(v: Long): Unit
def incRecordsRead(v: Long): Unit
def setBytesRead(v: Long): Unit
def setRecordsRead(v: Long): Unit
}
class OutputMetrics {
def bytesWritten: Long
def recordsWritten: Long
def incBytesWritten(v: Long): Unit
def incRecordsWritten(v: Long): Unit
def setBytesWritten(v: Long): Unit
def setRecordsWritten(v: Long): Unit
}
class ShuffleReadMetrics {
def remoteBlocksFetched: Long
def localBlocksFetched: Long
def fetchWaitTime: Long
def remoteBytesRead: Long
def remoteBytesReadToDisk: Long
def localBytesRead: Long
def recordsRead: Long
def incRemoteBlocksFetched(v: Long): Unit
def incLocalBlocksFetched(v: Long): Unit
def incFetchWaitTime(v: Long): Unit
def incRemoteBytesRead(v: Long): Unit
def incRemoteBytesReadToDisk(v: Long): Unit
def incLocalBytesRead(v: Long): Unit
def incRecordsRead(v: Long): Unit
}
class ShuffleWriteMetrics {
def bytesWritten: Long
def recordsWritten: Long
def writeTime: Long
def incBytesWritten(v: Long): Unit
def incRecordsWritten(v: Long): Unit
def incWriteTime(v: Long): Unit
def setBytesWritten(v: Long): Unit
def setRecordsWritten(v: Long): Unit
def setWriteTime(v: Long): Unit
}
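User code normally treats these metrics classes as read-only (the inc*/set* update methods are intended for Spark's own internals) and reads them from a completion listener once the task has finished. A minimal sketch, assuming an existing RDD named data:

import org.apache.spark.TaskContext

data.mapPartitions { iter =>
  val ctx = TaskContext.get()
  ctx.addTaskCompletionListener[Unit] { c =>
    val m = c.taskMetrics()
    // Input and shuffle-read sides of this task, plus peak memory
    println(s"input: ${m.inputMetrics.bytesRead} bytes, ${m.inputMetrics.recordsRead} records")
    println(s"shuffle read: ${m.shuffleReadMetrics.remoteBytesRead} remote bytes, " +
      s"${m.shuffleReadMetrics.localBytesRead} local bytes")
    println(s"peak execution memory: ${m.peakExecutionMemory} bytes")
  }
  iter
}.count()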
Memory management interface for tasks.

class TaskMemoryManager(memoryManager: MemoryManager, taskAttemptId: Long) {
// Memory Acquisition
def acquireExecutionMemory(required: Long, consumer: MemoryConsumer): Long
def releaseExecutionMemory(size: Long, consumer: MemoryConsumer): Unit
def releaseAllExecutionMemoryForConsumer(consumer: MemoryConsumer): Long
// Memory Information
def executionMemoryUsed: Long
def getTungstenMemoryMode: MemoryMode
// Page Management (for advanced users)
def allocatePage(size: Long, consumer: MemoryConsumer): MemoryBlock
def freePage(page: MemoryBlock, consumer: MemoryConsumer): Unit
def getPage(pagePlusOffsetAddress: Long): MemoryBlock
def encodePageNumberAndOffset(page: MemoryBlock, offsetInPage: Long): Long
def encodePageNumberAndOffset(pageNumber: Int, offsetInPage: Long): Long
def decodePageNumber(pagePlusOffsetAddress: Long): Int
def decodeOffset(pagePlusOffsetAddress: Long): Long
// Cleanup
def cleanUpAllAllocatedMemory(): Long
}
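The page APIs are low-level Tungsten plumbing and are rarely called from user code. For illustration only, a sketch that assumes a memoryManager obtained from TaskContext.get().taskMemoryManager() and a consumer built as in the MemoryConsumer example further below:

// Illustrative only: pages back Tungsten's binary data structures
val page = memoryManager.allocatePage(2L * 1024 * 1024, consumer) // request a ~2 MB page
try {
  // Tungsten passes around a single Long encoding (page number, offset)
  val addr = memoryManager.encodePageNumberAndOffset(page, 0L)
  println(s"page of ${page.size()} bytes, decoded offset ${memoryManager.decodeOffset(addr)}")
} finally {
  memoryManager.freePage(page, consumer)
}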
Basic usage: reading partition, stage, and attempt information from inside a task.

import org.apache.spark.TaskContext
val data = sc.parallelize(1 to 100, 4)
val processedData = data.mapPartitions { iter =>
val context = TaskContext.get()
// Get partition information
val partitionId = context.partitionId()
val stageId = context.stageId()
val attemptId = context.taskAttemptId()
println(s"Processing partition $partitionId in stage $stageId (attempt $attemptId)")
// Process data with context information
iter.map { value =>
(partitionId, value, value * value)
}
}
val result = processedData.collect()

Registering a completion listener for per-task cleanup and metrics reporting:

import org.apache.spark.TaskContext
import org.apache.spark.util.TaskCompletionListener
val data = sc.parallelize(1 to 1000, 10)
val processedData = data.mapPartitions { iter =>
val context = TaskContext.get()
// Register cleanup listener
context.addTaskCompletionListener(new TaskCompletionListener {
override def onTaskCompletion(context: TaskContext): Unit = {
val metrics = context.taskMetrics()
println(s"Task ${context.taskAttemptId()} completed in ${metrics.executorRunTime} ms")
println(s"Memory spilled: ${metrics.memoryBytesSpilled} bytes")
// Cleanup resources
cleanupResources()
}
})
// Process data
iter.map(expensiveProcessing)
}
processedData.count()

Registering a failure listener to clean up external resources when a task fails:

import org.apache.spark.TaskContext
import org.apache.spark.util.TaskFailureListener
val unreliableData = sc.parallelize(1 to 100)
val processedData = unreliableData.mapPartitions { iter =>
val context = TaskContext.get()
// Register failure listener for cleanup
context.addTaskFailureListener(new TaskFailureListener {
override def onTaskFailure(context: TaskContext, error: Throwable): Unit = {
println(s"Task ${context.taskAttemptId()} failed: ${error.getMessage}")
// Cleanup external resources
cleanupExternalConnections()
savePartialResults(context.partitionId())
}
})
// Potentially failing operation
iter.map { value =>
if (value % 50 == 0) {
throw new RuntimeException(s"Simulated failure on value $value")
}
processValue(value)
}
}
processedData.count()

Acquiring and releasing execution memory through the task's TaskMemoryManager with a custom MemoryConsumer:

import org.apache.spark.TaskContext
import org.apache.spark.memory.MemoryConsumer
val data = sc.parallelize(1 to 1000000)
val processedData = data.mapPartitions { iter =>
val context = TaskContext.get()
val memoryManager = context.taskMemoryManager()
// Custom memory consumer for large operations
val consumer = new MemoryConsumer(memoryManager) {
override def spill(size: Long, trigger: MemoryConsumer): Long = {
// Custom spill logic
spillToTemp(size)
}
}
try {
// Acquire memory for processing
val requiredMemory = 1024 * 1024 // 1MB
val acquiredMemory = memoryManager.acquireExecutionMemory(requiredMemory, consumer)
if (acquiredMemory < requiredMemory) {
println(s"Only acquired $acquiredMemory bytes out of $requiredMemory requested")
}
// Process data with allocated memory
val buffer = new Array[Int](acquiredMemory.toInt / 4) // 4 bytes per int
processWithBuffer(iter, buffer)
} finally {
// Release memory
memoryManager.releaseAllExecutionMemoryForConsumer(consumer)
}
}
processedData.count()

Tracking global progress with a custom AccumulatorV2, reported per partition from a completion listener:

import org.apache.spark.TaskContext
import org.apache.spark.util.{AccumulatorV2, TaskCompletionListener}
// Custom accumulator for tracking progress
class ProgressAccumulator extends AccumulatorV2[Long, Long] {
private var _value = 0L
def isZero: Boolean = _value == 0
def copy(): ProgressAccumulator = {
val acc = new ProgressAccumulator
acc._value = _value
acc
}
def reset(): Unit = _value = 0
def add(v: Long): Unit = _value += v
def merge(other: AccumulatorV2[Long, Long]): Unit = _value += other.value
def value: Long = _value
}
val progressTracker = new ProgressAccumulator
sc.register(progressTracker, "Progress Tracker")
val data = sc.parallelize(1 to 10000, 10)
val processedData = data.mapPartitions { iter =>
val context = TaskContext.get()
val partitionId = context.partitionId()
var processed = 0L
// Register completion listener to report final progress
context.addTaskCompletionListener(new TaskCompletionListener {
override def onTaskCompletion(context: TaskContext): Unit = {
println(s"Partition $partitionId processed $processed records")
}
})
iter.map { value =>
val result = processValue(value)
processed += 1
progressTracker.add(1) // Update global progress
// Report progress periodically
if (processed % 100 == 0) {
println(s"Partition $partitionId: $processed records processed")
}
result
}
}
processedData.count()
println(s"Total records processed: ${progressTracker.value}")val data = sc.parallelize(1 to 100)
val resourceInfo = data.mapPartitions { iter =>
val context = TaskContext.get()
// Get available resources
val resources = context.resources()
val cpus = context.cpus()
println(s"Task has $cpus CPUs available")
resources.foreach { case (resourceName, resourceInfo) =>
println(s"Resource $resourceName: ${resourceInfo.addresses.mkString(", ")}")
}
// Adapt processing based on available resources
val batchSize = if (cpus > 4) 1000 else 100
iter.grouped(batchSize).map(processBatch)
}.collect()

Best practices: always null-check the result of TaskContext.get(), which returns null when called outside of a running task (for example, on the driver), and poll isInterrupted() in long-running operations so killed or speculative tasks stop promptly.
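A minimal sketch of both practices, assuming an existing RDD named data, where doWork stands in for a hypothetical long-running per-record step:

import org.apache.spark.TaskContext

data.mapPartitions { iter =>
  val context = TaskContext.get()
  // get() returns null when called outside of a running task (e.g. on the driver)
  require(context != null, "not running inside a Spark task")
  iter.map { value =>
    // Bail out cooperatively if the task was killed (e.g. by speculation)
    if (context.isInterrupted()) {
      context.killTaskIfInterrupted()
    }
    doWork(value) // hypothetical per-record processing
  }
}.count()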
Install with Tessl CLI

npx tessl i tessl/maven-org-apache-spark--spark-core-2-12