Apache Spark Core - The foundational component of Apache Spark providing distributed computing capabilities including RDDs, transformations, actions, and cluster management.
—
TaskContext provides runtime information and control for tasks executing on cluster nodes. It offers access to task metadata, partition information, and mechanisms for registering cleanup callbacks.
// API summary (signatures only): the runtime information and control surface
// available to code executing inside a Spark task. NOTE(review): extraction
// has fused some closing braces with the next declaration below; the
// signatures themselves are reproduced as-is.
abstract class TaskContext {
// Task Information
// State probes: whether the task has finished, whether it has been killed
// (e.g. job cancellation), and whether it runs locally in the driver JVM.
def isCompleted(): Boolean
def isInterrupted(): Boolean
def isRunningLocally(): Boolean
// Identity of this task within the job: stage, partition and attempt counters.
def stageId(): Int
def stageAttemptNumber(): Int
def partitionId(): Int
def attemptNumber(): Int
def taskAttemptId(): Long
def executorId(): String
// NOTE(review): stock Spark exposes this as `resources()`; confirm the
// `resourcesAllocated` name against the actual build before relying on it.
def resourcesAllocated(): Map[String, ResourceInformation]
// Task Properties
// Local properties are set on the driver via SparkContext.setLocalProperty
// and propagated to tasks; task properties are scoped to this task only.
def getLocalProperty(key: String): String
def getLocalProperties(): Properties
def setTaskProperty(key: String, value: String): Unit
def getTaskProperty(key: String): String
def getTaskProperties(): Properties
// Listeners and Callbacks
// Listeners run at end-of-task (completion fires on success AND failure);
// each registration returns this TaskContext so calls can be chained.
def addTaskCompletionListener(listener: TaskCompletionListener): TaskContext
def addTaskCompletionListener(f: TaskContext => Unit): TaskContext
def addTaskFailureListener(listener: TaskFailureListener): TaskContext
def addTaskFailureListener(f: (TaskContext, Throwable) => Unit): TaskContext
// Metrics and Monitoring
def taskMetrics(): TaskMetrics
def getMetricsSources(sourceName: String): Seq[Source]
// Companion object: thread-local access to the context of the task running
// on the current thread (returns null/undefined outside a task thread).
}object TaskContext {
def get(): TaskContext
def getPartitionId(): Int
def setTaskContext(tc: TaskContext): Unit
def unset(): Unit
// Example below: reading task identity inside mapPartitionsWithIndex.
}import org.apache.spark.{SparkContext, SparkConf, TaskContext}
// Driver-side setup: local-mode SparkContext and a 4-partition RDD of 1..100.
val sc = new SparkContext(new SparkConf().setAppName("TaskContext Example").setMaster("local[*]"))
val data = sc.parallelize(1 to 100, numSlices = 4)
// Access task information during transformation
val taskInfo = data.mapPartitionsWithIndex { (partitionIndex, iterator) =>
// TaskContext.get() is thread-local: valid only on the executor thread that
// runs this partition, not on the driver.
val context = TaskContext.get()
iterator.map { value =>
val info = s"Value: $value, " +
s"Partition: ${context.partitionId()}, " +
s"Stage: ${context.stageId()}, " +
s"Attempt: ${context.attemptNumber()}, " +
s"TaskID: ${context.taskAttemptId()}, " +
s"Executor: ${context.executorId()}"
(value, info)
}
}
// Collect and print results
taskInfo.take(10).foreach { case (value, info) =>
println(info)
// Next example: cooperative cancellation and execution-environment probes.
}val monitoredData = data.mapPartitions { iterator =>
val context = TaskContext.get()
iterator.map { value =>
// Check if task has been interrupted (e.g., due to cancellation)
// Polling isInterrupted() lets a long-running task stop promptly on job
// cancellation instead of running to completion.
if (context.isInterrupted()) {
throw new InterruptedException("Task was interrupted")
}
// Check execution environment
val executionInfo = Map(
"isCompleted" -> context.isCompleted(),
"isRunningLocally" -> context.isRunningLocally(),
"partitionId" -> context.partitionId(),
"stageAttempt" -> context.stageAttemptNumber()
)
(value, executionInfo)
}
}TaskContext provides access to thread-local properties set at the driver level.
// Set properties at driver level
// Local properties attach to jobs submitted from this driver thread and
// become readable inside tasks via TaskContext.getLocalProperty.
sc.setLocalProperty("job.group.id", "data-processing")
sc.setLocalProperty("job.description", "Processing customer data")
sc.setLocalProperty("custom.property", "custom-value")
val processedData = data.mapPartitions { iterator =>
val context = TaskContext.get()
// Access properties in task
// getLocalProperty returns null for unset keys — guard before use for real.
val jobGroup = context.getLocalProperty("job.group.id")
val jobDescription = context.getLocalProperty("job.description")
val customProperty = context.getLocalProperty("custom.property")
iterator.map { value =>
// Use properties in processing logic
val metadata = Map(
"jobGroup" -> jobGroup,
"description" -> jobDescription,
"custom" -> customProperty
)
(value, metadata)
}
}
// Clear properties when done
// Setting a local property to null removes it for subsequently submitted jobs.
sc.setLocalProperty("job.group.id", null)
sc.setLocalProperty("job.description", null)TaskContext allows registering callbacks for task lifecycle events.
import java.io.Closeable
// Per-partition external resources with guaranteed cleanup: a completion
// listener fires at end-of-task on both success and failure.
val dataWithResources = data.mapPartitions { iterator =>
val context = TaskContext.get()
// Create resources that need cleanup
val tempFile = createTempFile()
val databaseConnection = createDatabaseConnection()
val httpClient = createHttpClient()
// Register cleanup listeners
// Registering BEFORE returning the iterator matters: mapPartitions is lazy,
// so cleanup must be deferred until the task itself ends, not run here.
context.addTaskCompletionListener { taskContext =>
println(s"Task ${taskContext.taskAttemptId()} completed on partition ${taskContext.partitionId()}")
// Cleanup resources
try {
tempFile.delete()
databaseConnection.close()
httpClient.close()
println("Resources cleaned up successfully")
} catch {
case e: Exception =>
println(s"Error during cleanup: ${e.getMessage}")
}
}
// Process data using resources
iterator.map { value =>
// Use database connection, http client, etc.
processWithResources(value, databaseConnection, httpClient)
}
// Next example: failure listeners for error reporting.
}val robustProcessing = data.mapPartitions { iterator =>
val context = TaskContext.get()
// NOTE(review): failure listeners may run on a different thread than the task
// body; a plain mutable.Map is not thread-safe — confirm, or use AtomicLongs.
val metrics = collection.mutable.Map[String, Long]()
// Register failure listener for error reporting
context.addTaskFailureListener { (taskContext, exception) =>
val errorReport = s"""
|Task Failure Report:
| Task ID: ${taskContext.taskAttemptId()}
| Partition: ${taskContext.partitionId()}
| Stage: ${taskContext.stageId()}
| Attempt: ${taskContext.attemptNumber()}
| Exception: ${exception.getClass.getSimpleName}
| Message: ${exception.getMessage}
| Records Processed: ${metrics.getOrElse("processed", 0L)}
| Records Failed: ${metrics.getOrElse("failed", 0L)}
""".stripMargin
println(errorReport)
logErrorToMonitoringSystem(errorReport)
}
iterator.map { value =>
try {
val result = riskyProcessing(value)
metrics("processed") = metrics.getOrElse("processed", 0L) + 1
result
} catch {
case e: Exception =>
// Count the failure and rethrow: the failure listener reports the totals.
metrics("failed") = metrics.getOrElse("failed", 0L) + 1
throw e
}
}
}
}
// Ties external resources (a database connection and a per-partition log
// file) to the TaskContext lifecycle so they are released exactly once,
// whether the task succeeds or fails. Serializable so closures can ship an
// instance to executors.
class ResourceManager extends Serializable {
  /**
   * Initializes a DatabaseConnection and a per-partition FileWriter, registers
   * cleanup on both task completion and task failure, then runs `block` with
   * the live resources.
   *
   * Fix over the previous revision: the completion listener, the failure
   * listener, and the local catch could EACH trigger cleanup, double-closing
   * the resources and writing to an already-closed writer. Cleanup is now
   * guarded by an AtomicBoolean so it runs at most once per invocation.
   */
  def withResources[T](context: TaskContext)(block: (DatabaseConnection, FileWriter) => Iterator[T]): Iterator[T] = {
    var connection: DatabaseConnection = null
    var writer: FileWriter = null
    // At-most-once latch shared by every cleanup path for this invocation.
    val cleaned = new java.util.concurrent.atomic.AtomicBoolean(false)
    def cleanupOnce(success: Boolean): Unit = {
      if (cleaned.compareAndSet(false, true)) {
        cleanupResources(connection, writer, success)
      }
    }
    try {
      // Initialize resources
      connection = new DatabaseConnection()
      writer = new FileWriter(s"output-partition-${context.partitionId()}.log")
      // Register cleanup on both success and failure
      context.addTaskCompletionListener { _ =>
        cleanupOnce(success = true)
      }
      context.addTaskFailureListener { (_, exception) =>
        // Only log via the writer if no other path has already closed it.
        if (!cleaned.get()) {
          writer.write(s"Task failed with exception: ${exception.getMessage}\n")
        }
        cleanupOnce(success = false)
      }
      // Execute the processing block
      block(connection, writer)
    } catch {
      case e: Exception =>
        cleanupOnce(success = false)
        throw e
    }
  }
  // Closes both resources, committing the connection on success and rolling
  // back on failure. Null-safe: either resource may have failed to initialize.
  private def cleanupResources(connection: DatabaseConnection, writer: FileWriter, success: Boolean): Unit = {
    try {
      if (writer != null) {
        writer.write(s"Task completed with success: $success\n")
        writer.close()
      }
      if (connection != null) {
        if (success) connection.commit() else connection.rollback()
        connection.close()
      }
    } catch {
      case e: Exception =>
        println(s"Error during resource cleanup: ${e.getMessage}")
    }
  }
}
// Usage
val resourceManager = new ResourceManager()
val processedData = data.mapPartitions { iterator =>
val context = TaskContext.get()
resourceManager.withResources(context) { (connection, writer) =>
// Lazy: these writes and queries execute as downstream code pulls elements.
iterator.map { value =>
writer.write(s"Processing value: $value\n")
val result = processWithDatabase(value, connection)
writer.write(s"Result: $result\n")
result
}
}
}TaskContext provides access to detailed task execution metrics.
// Sampling live task metrics mid-execution; these are running totals that
// keep growing while the task runs.
val metricsCollector = data.mapPartitions { iterator =>
val context = TaskContext.get()
iterator.map { value =>
// Access task metrics during execution
val metrics = context.taskMetrics()
val currentMetrics = Map(
"executorDeserializeTime" -> metrics.executorDeserializeTime,
"executorRunTime" -> metrics.executorRunTime,
"jvmGCTime" -> metrics.jvmGCTime,
// Option(...) guards against null sub-metric groups (absent when the task
// did no input/output/shuffle work).
"inputMetrics" -> Option(metrics.inputMetrics).map(_.toString).getOrElse("None"),
"outputMetrics" -> Option(metrics.outputMetrics).map(_.toString).getOrElse("None"),
"shuffleReadMetrics" -> Option(metrics.shuffleReadMetrics).map(_.toString).getOrElse("None"),
"shuffleWriteMetrics" -> Option(metrics.shuffleWriteMetrics).map(_.toString).getOrElse("None")
)
(value, currentMetrics)
}
// Next example: JVM-wide resource pooling keyed by partition id.
}import java.util.concurrent.ConcurrentHashMap
import scala.collection.JavaConverters._
// JVM-wide pool of database connections keyed by partition id. Declared as an
// object, so it is shared by every task running in the same executor JVM.
object PartitionResourcePool {
  private val connectionPools = new ConcurrentHashMap[Int, DatabaseConnectionPool]()
  // Returns a connection from this partition's pool, creating the pool on
  // first use. Uses an explicit get/putIfAbsent race instead of a lambda to
  // computeIfAbsent: Scala 2.11 (this artifact targets spark-core_2.11) has no
  // SAM conversion to java.util.function.Function, so the lambda form in the
  // previous revision did not compile.
  def getConnection(partitionId: Int): DatabaseConnection = {
    var pool = connectionPools.get(partitionId)
    if (pool == null) {
      val fresh = new DatabaseConnectionPool()
      val existing = connectionPools.putIfAbsent(partitionId, fresh)
      // If another thread won the race, close our speculative pool (the old
      // code leaked it) and use theirs.
      pool = if (existing != null) { fresh.close(); existing } else fresh
    }
    pool.getConnection()
  }
  // Returns a connection to its partition's pool; no-op if the pool is gone.
  def releaseConnection(partitionId: Int, connection: DatabaseConnection): Unit = {
    Option(connectionPools.get(partitionId)).foreach(_.releaseConnection(connection))
  }
  // Removes and closes the pool for a partition (intended to run from a task
  // completion listener once the partition is done with it).
  def closePool(partitionId: Int): Unit = {
    Option(connectionPools.remove(partitionId)).foreach(_.close())
  }
}
val pooledProcessing = data.mapPartitions { iterator =>
val context = TaskContext.get()
val partitionId = context.partitionId()
// Register cleanup for the entire partition
context.addTaskCompletionListener { _ =>
PartitionResourcePool.closePool(partitionId)
}
// Borrow/return around each record so pooled connections are shared fairly.
iterator.map { value =>
val connection = PartitionResourcePool.getConnection(partitionId)
try {
processWithConnection(value, connection)
} finally {
PartitionResourcePool.releaseConnection(partitionId, connection)
}
}
}
}
// Lightweight per-partition progress reporter: logs and forwards throughput
// figures every 10,000 records pulled through the wrapped iterator.
class ProgressTracker extends Serializable {
  /**
   * Wraps `iterator`, counting records as they are consumed and emitting a
   * progress line every 10,000 records. Lazy: counting happens only as the
   * downstream consumer pulls elements.
   *
   * Fix over the previous revision: the records/sec figure divided by elapsed
   * seconds, which is 0.0 when 10,000 records arrive within the same
   * millisecond, producing Infinity; the rate is now guarded.
   */
  def trackProgress[T](iterator: Iterator[T], context: TaskContext): Iterator[T] = {
    val totalEstimate = 1000000 // Estimated records per partition
    var processed = 0L
    val startTime = System.currentTimeMillis()
    iterator.map { item =>
      processed += 1
      // Report progress every 10,000 records
      if (processed % 10000 == 0) {
        val elapsed = System.currentTimeMillis() - startTime
        // Guard against a zero-millisecond window on fast partitions.
        val rate = if (elapsed > 0) processed * 1000.0 / elapsed else 0.0
        val progress = (processed.toDouble / totalEstimate) * 100
        val progressInfo = s"Partition ${context.partitionId()}: " +
          f"$processed records processed ($progress%.1f%%), " +
          f"Rate: $rate%.0f records/sec"
        println(progressInfo)
        // Optional: Send to monitoring system
        sendProgressUpdate(context.stageId(), context.partitionId(), processed, progress, rate)
      }
      item
    }
  }
  // Hook for exporting progress samples to an external monitoring system.
  private def sendProgressUpdate(stageId: Int, partitionId: Int, processed: Long, progress: Double, rate: Double): Unit = {
    // Implementation to send metrics to monitoring system
  }
}
// Usage
val tracker = new ProgressTracker()
val trackedProcessing = data.mapPartitions { iterator =>
val context = TaskContext.get()
tracker.trackProgress(iterator, context).map(processItem)
// Next example: choosing a processing strategy from the execution context.
}val conditionalProcessing = data.mapPartitions { iterator =>
val context = TaskContext.get()
// Different processing based on execution context
val processingStrategy = if (context.isRunningLocally()) {
// Local execution - use simpler, debugging-friendly processing
(value: Int) => {
println(s"Local processing: $value")
value * 2
}
} else {
// Cluster execution - use optimized processing
// NOTE(review): records below the batch threshold take the simple path and
// the final partial batch is never flushed — acceptable for a demo, lossy
// for real use.
val batchSize = 1000
val batch = collection.mutable.ArrayBuffer[Int]()
(value: Int) => {
batch += value
if (batch.size >= batchSize) {
val results = batchProcess(batch.toArray)
batch.clear()
results.last // Return last result for this example
} else {
value * 2 // Simple processing for incomplete batch
}
}
}
iterator.map(processingStrategy)
}val adaptiveProcessing = data.mapPartitions { iterator =>
val context = TaskContext.get()
val partitionId = context.partitionId()
// Get available resources for this task
// NOTE(review): assumes ResourceInformation exposes a numeric `amount`; the
// stock Spark type exposes name/addresses — confirm against this build.
val resources = context.resourcesAllocated()
val availableCores = resources.get("cpus").map(_.amount.toInt).getOrElse(1)
val availableMemory = resources.get("memory").map(_.amount.toLong).getOrElse(1024L * 1024 * 1024) // 1GB default
// Adjust processing parameters based on available resources
// Tiers: >4 cores and >2GB => large batches / 4-way; >2 cores and >1GB =>
// medium; otherwise conservative single-threaded defaults.
val (batchSize, parallelism) = if (availableCores > 4 && availableMemory > 2L * 1024 * 1024 * 1024) {
(10000, 4) // High resource configuration
} else if (availableCores > 2 && availableMemory > 1024L * 1024 * 1024) {
(5000, 2) // Medium resource configuration
} else {
(1000, 1) // Low resource configuration
}
println(s"Partition $partitionId: Using batch size $batchSize with parallelism $parallelism")
// Process with adaptive parameters
iterator.grouped(batchSize).flatMap { batch =>
processInParallel(batch, parallelism)
}
}For barrier execution mode, Spark provides BarrierTaskContext with additional synchronization capabilities.
// API summary (signatures only) for barrier-mode tasks: barrier() blocks the
// calling task until every task in the barrier stage reaches the same call.
class BarrierTaskContext extends TaskContext {
def barrier(): Unit
def getTaskInfos(): Array[BarrierTaskInfo]
// NOTE(review): stock Spark names this allGather and exchanges Strings —
// confirm the generic form against this build.
def getAllGather[T](message: T): Array[T]
}
case class BarrierTaskInfo(address: String)// Note: This requires barrier mode to be enabled
val barrierRDD = sc.parallelize(1 to 100, numSlices = 4)
// .barrier() turns the stage into a barrier stage: all 4 tasks are scheduled
// together and can synchronize via context.barrier().
val synchronizedProcessing = barrierRDD.barrier().mapPartitions { iterator =>
val context = TaskContext.get().asInstanceOf[BarrierTaskContext]
// Phase 1: Local processing
// toArray forces the lazy iterator now, so each task reaches the barrier
// only after finishing its local work.
val localResults = iterator.map(process).toArray
// Synchronization point - all tasks wait here
context.barrier()
// Phase 2: Coordinate across partitions
val taskInfos = context.getTaskInfos()
println(s"Task addresses: ${taskInfos.map(_.address).mkString(", ")}")
// Share local statistics across all tasks
val localSum = localResults.sum
val allSums = context.getAllGather(localSum)
val globalSum = allSums.sum
// Phase 3: Process using global information
localResults.map(_ + globalSum).iterator
}// Comprehensive task monitoring
// Generic wrapper: adds record counting and a per-task timing report to any
// mapPartitions-style transformation without changing its results.
def monitoredMapPartitions[T, U](rdd: RDD[T])(f: Iterator[T] => Iterator[U]): RDD[U] = {
rdd.mapPartitions { iterator =>
val context = TaskContext.get()
val startTime = System.currentTimeMillis()
var recordCount = 0L
// The listener fires at end-of-task; recordCount reflects however many
// output records were actually consumed by then (may be fewer than the
// partition size under e.g. take()).
context.addTaskCompletionListener { taskContext =>
val duration = System.currentTimeMillis() - startTime
println(s"Task ${taskContext.taskAttemptId()} on partition ${taskContext.partitionId()} " +
s"processed $recordCount records in ${duration}ms")
}
// Count records as they stream through, without materializing them.
f(iterator).map { result =>
recordCount += 1
result
}
}
}Install with Tessl CLI
npx tessl i tessl/maven-org-apache-spark--spark-core-2-11