CtrlK
Blog · Docs · Log in · Get started
Tessl Logo

tessl/maven-org-apache-spark--spark-core-2-11

Apache Spark Core - The foundational component of Apache Spark providing distributed computing capabilities including RDDs, transformations, actions, and cluster management.

Pending
Overview
Eval results
Files

task-context.md · docs/

Task Context

TaskContext provides runtime information and control for tasks executing on cluster nodes. It offers access to task metadata, partition information, and mechanisms for registering cleanup callbacks.

TaskContext Class

// Signature sketch of Spark's TaskContext: per-task runtime information,
// obtained on executors via TaskContext.get().
abstract class TaskContext {
  // Task Information
  def isCompleted(): Boolean
  def isInterrupted(): Boolean
  // NOTE(review): deprecated in later Spark releases — confirm against this build.
  def isRunningLocally(): Boolean
  def stageId(): Int
  def stageAttemptNumber(): Int
  def partitionId(): Int
  def attemptNumber(): Int
  def taskAttemptId(): Long
  // NOTE(review): executorId() is not part of the upstream public TaskContext
  // API — verify it exists in this distribution before relying on it.
  def executorId(): String
  // NOTE(review): upstream Spark 3.x names this resources(); confirm the name here.
  def resourcesAllocated(): Map[String, ResourceInformation]
  
  // Task Properties
  def getLocalProperty(key: String): String
  def getLocalProperties(): Properties
  // NOTE(review): the setTaskProperty/getTaskProperty family is not in the
  // upstream public TaskContext API — confirm before coding against it.
  def setTaskProperty(key: String, value: String): Unit
  def getTaskProperty(key: String): String
  def getTaskProperties(): Properties
  
  // Listeners and Callbacks
  // Completion listeners fire on BOTH success and failure; failure listeners
  // fire only on failure. Each registration returns this context for chaining.
  def addTaskCompletionListener(listener: TaskCompletionListener): TaskContext
  def addTaskCompletionListener(f: TaskContext => Unit): TaskContext
  def addTaskFailureListener(listener: TaskFailureListener): TaskContext
  def addTaskFailureListener(f: (TaskContext, Throwable) => Unit): TaskContext
  
  // Metrics and Monitoring
  def taskMetrics(): TaskMetrics
  def getMetricsSources(sourceName: String): Seq[Source]
}

TaskContext Object

// Static accessors for the context of the currently running task.
object TaskContext {
  // Context of the running task; upstream contract is null on the driver —
  // NOTE(review): confirm null-vs-exception behavior for this build.
  def get(): TaskContext
  // Partition id of the running task (upstream returns 0 outside a task).
  def getPartitionId(): Int
  // Internal wiring used by Spark itself when launching/finishing a task.
  def setTaskContext(tc: TaskContext): Unit
  def unset(): Unit
}

Task Information Access

Basic Task Information

import org.apache.spark.{SparkContext, SparkConf, TaskContext}

// Driver setup: local master with four slices so several tasks actually run.
val conf = new SparkConf().setAppName("TaskContext Example").setMaster("local[*]")
val sc = new SparkContext(conf)

val data = sc.parallelize(1 to 100, numSlices = 4)

// Pair every element with a description of the task that processed it.
val taskInfo = data.mapPartitionsWithIndex { (_, elems) =>
  val ctx = TaskContext.get()
  elems.map { v =>
    val details =
      s"Value: $v, " +
        s"Partition: ${ctx.partitionId()}, " +
        s"Stage: ${ctx.stageId()}, " +
        s"Attempt: ${ctx.attemptNumber()}, " +
        s"TaskID: ${ctx.taskAttemptId()}, " +
        s"Executor: ${ctx.executorId()}"
    (v, details)
  }
}

// Print the descriptions for the first ten elements.
taskInfo.take(10).foreach { case (_, details) => println(details) }

Task State Monitoring

// Attach a snapshot of the task's execution state to every element.
val monitoredData = data.mapPartitions { elems =>
  val ctx = TaskContext.get()

  elems.map { v =>
    // Bail out promptly if the driver cancelled this job/stage.
    if (ctx.isInterrupted()) {
      throw new InterruptedException("Task was interrupted")
    }

    // Snapshot of the execution environment at the moment v is processed.
    val executionInfo: Map[String, Any] = Map(
      "isCompleted" -> ctx.isCompleted(),
      "isRunningLocally" -> ctx.isRunningLocally(),
      "partitionId" -> ctx.partitionId(),
      "stageAttempt" -> ctx.stageAttemptNumber()
    )

    (v, executionInfo)
  }
}

Local Properties

TaskContext provides access to thread-local properties set at the driver level.

// Thread-local job properties set on the driver; visible to the tasks it launches.
sc.setLocalProperty("job.group.id", "data-processing")
sc.setLocalProperty("job.description", "Processing customer data")
sc.setLocalProperty("custom.property", "custom-value")

val processedData = data.mapPartitions { elems =>
  val ctx = TaskContext.get()

  // Read the driver-side properties once per partition rather than per element;
  // the resulting immutable map is shared by every output tuple.
  val metadata = Map(
    "jobGroup" -> ctx.getLocalProperty("job.group.id"),
    "description" -> ctx.getLocalProperty("job.description"),
    "custom" -> ctx.getLocalProperty("custom.property")
  )

  elems.map(v => (v, metadata))
}

// A local property is cleared by setting it back to null.
sc.setLocalProperty("job.group.id", null)
sc.setLocalProperty("job.description", null)

Task Completion and Failure Listeners

TaskContext allows registering callbacks for task lifecycle events.

Completion Listeners

import java.io.Closeable

// Ties external resource lifetime to the task via a completion listener
// (completion listeners run on both success and failure).
val dataWithResources = data.mapPartitions { iterator =>
  val context = TaskContext.get()
  
  // Create resources that need cleanup
  // (createTempFile / createDatabaseConnection / createHttpClient are
  // placeholders assumed to be defined elsewhere in this example.)
  val tempFile = createTempFile()
  val databaseConnection = createDatabaseConnection()
  val httpClient = createHttpClient()
  
  // Register cleanup listeners
  // The listener fires after the task finishes — i.e. after the lazy iterator
  // returned below has been fully consumed — so closing here is safe.
  context.addTaskCompletionListener { taskContext =>
    println(s"Task ${taskContext.taskAttemptId()} completed on partition ${taskContext.partitionId()}")
    
    // Cleanup resources
    try {
      tempFile.delete()
      databaseConnection.close()
      httpClient.close()
      println("Resources cleaned up successfully")
    } catch {
      case e: Exception =>
        println(s"Error during cleanup: ${e.getMessage}")
    }
  }
  
  // Process data using resources
  iterator.map { value =>
    // Use database connection, http client, etc.
    processWithResources(value, databaseConnection, httpClient)
  }
}

Failure Listeners

// Failure listener reports how far processing got before the task died.
val robustProcessing = data.mapPartitions { iterator =>
  val context = TaskContext.get()
  // Per-partition counters; captured by the failure listener closure below.
  val metrics = collection.mutable.Map[String, Long]()
  
  // Register failure listener for error reporting
  context.addTaskFailureListener { (taskContext, exception) =>
    val errorReport = s"""
      |Task Failure Report:
      |  Task ID: ${taskContext.taskAttemptId()}
      |  Partition: ${taskContext.partitionId()}
      |  Stage: ${taskContext.stageId()}
      |  Attempt: ${taskContext.attemptNumber()}
      |  Exception: ${exception.getClass.getSimpleName}
      |  Message: ${exception.getMessage}
      |  Records Processed: ${metrics.getOrElse("processed", 0L)}
      |  Records Failed: ${metrics.getOrElse("failed", 0L)}
    """.stripMargin
    
    println(errorReport)
    logErrorToMonitoringSystem(errorReport)  // assumed defined elsewhere
  }
  
  iterator.map { value =>
    try {
      val result = riskyProcessing(value)  // assumed defined elsewhere
      metrics("processed") = metrics.getOrElse("processed", 0L) + 1
      result
    } catch {
      case e: Exception =>
        // Count the failure, then rethrow so the task fails and the
        // failure listener above actually fires.
        metrics("failed") = metrics.getOrElse("failed", 0L) + 1
        throw e
    }
  }
}

Combined Resource Management

class ResourceManager extends Serializable {
  import scala.util.control.NonFatal

  /**
   * Runs `block` with a per-task database connection and log writer,
   * registering TaskContext listeners so the resources are released exactly
   * once whether the task succeeds or fails.
   *
   * Fixes over the original example:
   *  - completion listeners fire on BOTH success and failure, so the old code
   *    reported `success = true` (and committed) even for failed tasks; a
   *    `failed` flag now selects commit vs rollback correctly
   *  - cleanup could previously run twice (failure listener, completion
   *    listener, and the catch block all called it); `cleanupOnce` makes it
   *    idempotent
   */
  def withResources[T](context: TaskContext)(block: (DatabaseConnection, FileWriter) => Iterator[T]): Iterator[T] = {
    var connection: DatabaseConnection = null
    var writer: FileWriter = null
    var failed = false      // set by the failure listener before completion runs
    var cleanedUp = false   // guards against double commit/rollback/close

    // Releases both resources at most once.
    def cleanupOnce(success: Boolean): Unit = {
      if (!cleanedUp) {
        cleanedUp = true
        cleanupResources(connection, writer, success)
      }
    }

    try {
      // Initialize resources
      connection = new DatabaseConnection()
      writer = new FileWriter(s"output-partition-${context.partitionId()}.log")

      // On failure: record it and log; cleanup itself is deferred to the
      // completion listener, which Spark also runs for failed tasks.
      context.addTaskFailureListener { (_, exception) =>
        failed = true
        writer.write(s"Task failed with exception: ${exception.getMessage}\n")
      }

      context.addTaskCompletionListener { _ =>
        cleanupOnce(success = !failed)
      }

      // Execute the processing block
      block(connection, writer)

    } catch {
      // NonFatal so OutOfMemoryError and friends are not swallowed here.
      case NonFatal(e) =>
        cleanupOnce(success = false)
        throw e
    }
  }

  /** Logs the outcome and closes the writer, then commits (success) or rolls
    * back (failure) and closes the connection. Cleanup errors are logged,
    * never rethrown. */
  private def cleanupResources(connection: DatabaseConnection, writer: FileWriter, success: Boolean): Unit = {
    try {
      if (writer != null) {
        writer.write(s"Task completed with success: $success\n")
        writer.close()
      }
      if (connection != null) {
        if (success) connection.commit() else connection.rollback()
        connection.close()
      }
    } catch {
      case NonFatal(e) =>
        println(s"Error during resource cleanup: ${e.getMessage}")
    }
  }
}

// Usage
val resourceManager = new ResourceManager()

// NOTE(review): this `processedData` duplicates the name used in the local
// properties example earlier on this page — rename one if the snippets are
// ever pasted into a shared scope.
val processedData = data.mapPartitions { iterator =>
  val context = TaskContext.get()
  
  // withResources registers completion/failure listeners, so connection and
  // writer stay open until the task ends, not until this block returns.
  resourceManager.withResources(context) { (connection, writer) =>
    iterator.map { value =>
      writer.write(s"Processing value: $value\n")
      val result = processWithDatabase(value, connection)  // assumed defined elsewhere
      writer.write(s"Result: $result\n")
      result
    }
  }
}

Task Metrics Access

TaskContext provides access to detailed task execution metrics.

// Sample the live TaskMetrics for every record that flows through.
val metricsCollector = data.mapPartitions { elems =>
  val ctx = TaskContext.get()

  // Renders an optional sub-metric group as text, "None" when null.
  def render[A](group: A): String =
    Option(group).map(_.toString).getOrElse("None")

  elems.map { v =>
    val m = ctx.taskMetrics()

    val currentMetrics = Map(
      "executorDeserializeTime" -> m.executorDeserializeTime,
      "executorRunTime" -> m.executorRunTime,
      "jvmGCTime" -> m.jvmGCTime,
      "inputMetrics" -> render(m.inputMetrics),
      "outputMetrics" -> render(m.outputMetrics),
      "shuffleReadMetrics" -> render(m.shuffleReadMetrics),
      "shuffleWriteMetrics" -> render(m.shuffleWriteMetrics)
    )

    (v, currentMetrics)
  }
}

Advanced Patterns

Partition-wise Resource Pooling

import java.util.concurrent.ConcurrentHashMap
import scala.collection.JavaConverters._

// JVM-wide registry of one connection pool per partition id, created lazily.
object PartitionResourcePool {
  private val pools = new ConcurrentHashMap[Int, DatabaseConnectionPool]()

  // Borrows a connection, creating this partition's pool on first use.
  def getConnection(partitionId: Int): DatabaseConnection = {
    val pool = pools.computeIfAbsent(partitionId, _ => new DatabaseConnectionPool())
    pool.getConnection()
  }

  // Returns a borrowed connection; no-op if the pool was already closed.
  def releaseConnection(partitionId: Int, connection: DatabaseConnection): Unit = {
    val pool = pools.get(partitionId)
    if (pool != null) pool.releaseConnection(connection)
  }

  // Removes and closes this partition's pool, if one exists.
  def closePool(partitionId: Int): Unit = {
    val removed = pools.remove(partitionId)
    if (removed != null) removed.close()
  }
}

val pooledProcessing = data.mapPartitions { iterator =>
  val context = TaskContext.get()
  val partitionId = context.partitionId()
  
  // Register cleanup for the entire partition
  // NOTE(review): the pool lives in a JVM-wide singleton keyed by partition
  // id; a retried task for the same partition on the same executor could race
  // with this close — verify retry semantics before production use.
  context.addTaskCompletionListener { _ =>
    PartitionResourcePool.closePool(partitionId)
  }
  
  // Borrow/return a pooled connection around each record.
  iterator.map { value =>
    val connection = PartitionResourcePool.getConnection(partitionId)
    try {
      processWithConnection(value, connection)  // assumed defined elsewhere
    } finally {
      PartitionResourcePool.releaseConnection(partitionId, connection)
    }
  }
}

Progress Tracking

class ProgressTracker extends Serializable {
  /**
   * Wraps `iterator`, passing every element through untouched while logging
   * throughput every 10,000 records.
   *
   * @param iterator      records to pass through unchanged
   * @param context       task context used for partition/stage ids in the log line
   * @param totalEstimate estimated records per partition, used only for the
   *                      percentage figure; default matches the old hard-coded
   *                      constant so existing callers are unaffected
   */
  def trackProgress[T](iterator: Iterator[T], context: TaskContext, totalEstimate: Long = 1000000L): Iterator[T] = {
    var processed = 0L
    val startTime = System.currentTimeMillis()
    
    iterator.map { item =>
      processed += 1
      
      // Report progress every 10,000 records
      if (processed % 10000 == 0) {
        val elapsed = System.currentTimeMillis() - startTime
        val rate = processed.toDouble / (elapsed / 1000.0)
        val progress = (processed.toDouble / totalEstimate) * 100
        
        // f-interpolator replaces the deprecated `.formatted(...)` calls;
        // output text is unchanged.
        val progressInfo = s"Partition ${context.partitionId()}: " +
          f"$processed records processed ($progress%.1f%%), " +
          f"Rate: $rate%.0f records/sec"
        
        println(progressInfo)
        
        // Optional: Send to monitoring system
        sendProgressUpdate(context.stageId(), context.partitionId(), processed, progress, rate)
      }
      
      item
    }
  }
  
  // Hook for forwarding metrics to an external monitoring system (no-op here).
  private def sendProgressUpdate(stageId: Int, partitionId: Int, processed: Long, progress: Double, rate: Double): Unit = {
    // Implementation to send metrics to monitoring system
  }
}

// Usage
val tracker = new ProgressTracker()

// Wrap the partition iterator so throughput is logged while processItem
// (assumed defined elsewhere) transforms each record.
val trackedProcessing = data.mapPartitions { iterator =>
  val context = TaskContext.get()
  tracker.trackProgress(iterator, context).map(processItem)
}

Conditional Processing Based on Context

val conditionalProcessing = data.mapPartitions { iterator =>
  val context = TaskContext.get()
  
  // Different processing based on execution context
  val processingStrategy = if (context.isRunningLocally()) {
    // Local execution - use simpler, debugging-friendly processing
    (value: Int) => {
      println(s"Local processing: $value")
      value * 2
    }
  } else {
    // Cluster execution - use optimized processing
    // NOTE(review): illustrative only — any records still buffered in `batch`
    // when the partition ends are never flushed, and every batchProcess result
    // except `.last` is discarded. A real implementation would flush the
    // remainder and emit all results (e.g. via iterator.grouped).
    val batchSize = 1000
    val batch = collection.mutable.ArrayBuffer[Int]()
    
    (value: Int) => {
      batch += value
      if (batch.size >= batchSize) {
        val results = batchProcess(batch.toArray)
        batch.clear()
        results.last // Return last result for this example
      } else {
        value * 2 // Simple processing for incomplete batch
      }
    }
  }
  
  iterator.map(processingStrategy)
}

Dynamic Resource Allocation

val adaptiveProcessing = data.mapPartitions { iterator =>
  val context = TaskContext.get()
  val partitionId = context.partitionId()
  
  // Get available resources for this task
  // NOTE(review): upstream Spark exposes this as resources(), and its
  // ResourceInformation carries `addresses: Array[String]`, not a numeric
  // `amount` — verify both the method name and the field against this build.
  val resources = context.resourcesAllocated()
  val availableCores = resources.get("cpus").map(_.amount.toInt).getOrElse(1)
  val availableMemory = resources.get("memory").map(_.amount.toLong).getOrElse(1024L * 1024 * 1024) // 1GB default
  
  // Adjust processing parameters based on available resources
  val (batchSize, parallelism) = if (availableCores > 4 && availableMemory > 2L * 1024 * 1024 * 1024) {
    (10000, 4) // High resource configuration
  } else if (availableCores > 2 && availableMemory > 1024L * 1024 * 1024) {
    (5000, 2)  // Medium resource configuration
  } else {
    (1000, 1)  // Low resource configuration
  }
  
  println(s"Partition $partitionId: Using batch size $batchSize with parallelism $parallelism")
  
  // Process with adaptive parameters
  // grouped() buffers batchSize records per chunk; processInParallel is
  // assumed to be defined elsewhere.
  iterator.grouped(batchSize).flatMap { batch =>
    processInParallel(batch, parallelism)
  }
}

BarrierTaskContext

For barrier execution mode, Spark provides BarrierTaskContext with additional synchronization capabilities.

// Signature sketch for barrier execution mode.
// NOTE(review): upstream BarrierTaskContext exposes
// allGather(message: String): Array[String], not a generic getAllGather[T] —
// confirm against this build before coding to this signature.
class BarrierTaskContext extends TaskContext {
  // Blocks until every task in the barrier stage has reached this call.
  def barrier(): Unit
  // One entry per task in the barrier stage.
  def getTaskInfos(): Array[BarrierTaskInfo]
  def getAllGather[T](message: T): Array[T]
}

case class BarrierTaskInfo(address: String)

Barrier Execution Example

// Note: This requires barrier mode to be enabled
val barrierRDD = sc.parallelize(1 to 100, numSlices = 4)

// Tasks in a barrier stage all launch together and can synchronize mid-task.
val synchronizedProcessing = barrierRDD.barrier().mapPartitions { iterator =>
  // NOTE(review): upstream code obtains this via BarrierTaskContext.get()
  // rather than a cast — confirm which accessor this build provides.
  val context = TaskContext.get().asInstanceOf[BarrierTaskContext]
  
  // Phase 1: Local processing
  // toArray forces the (lazy) iterator BEFORE the barrier, so no work is
  // still pending at the synchronization point. `process` is assumed to be
  // defined elsewhere.
  val localResults = iterator.map(process).toArray
  
  // Synchronization point - all tasks wait here
  context.barrier()
  
  // Phase 2: Coordinate across partitions
  val taskInfos = context.getTaskInfos()
  println(s"Task addresses: ${taskInfos.map(_.address).mkString(", ")}")
  
  // Share local statistics across all tasks
  val localSum = localResults.sum
  val allSums = context.getAllGather(localSum)
  val globalSum = allSums.sum
  
  // Phase 3: Process using global information
  localResults.map(_ + globalSum).iterator
}

Best Practices

Resource Management

  • Always register cleanup listeners for external resources
  • Use try-finally blocks for critical resource cleanup
  • Consider using connection pools for database connections
  • Monitor resource usage through task metrics

Error Handling

  • Register failure listeners for comprehensive error reporting
  • Log sufficient context information for debugging
  • Implement graceful degradation when possible
  • Use task metrics to identify performance bottlenecks

Performance Optimization

  • Access TaskContext only when needed (slight overhead)
  • Cache TaskContext reference if used multiple times in same task
  • Use local properties for configuration instead of closures when possible
  • Monitor task execution time and optimize accordingly

Monitoring and Debugging

// Comprehensive task monitoring
// Comprehensive task monitoring: wraps mapPartitions so each task logs its
// record count and wall-clock duration when it finishes.
// NOTE: RDD.mapPartitions requires a ClassTag for the output element type, so
// one is taken implicitly here — without it the original signature does not
// compile against real Spark. Call sites are unchanged.
def monitoredMapPartitions[T, U](rdd: RDD[T])(f: Iterator[T] => Iterator[U])(implicit ct: scala.reflect.ClassTag[U]): RDD[U] = {
  rdd.mapPartitions { iterator =>
    val context = TaskContext.get()
    val startTime = System.currentTimeMillis()
    // Incremented as results stream through; read by the completion listener,
    // which fires only after the iterator has been fully consumed.
    var recordCount = 0L
    
    context.addTaskCompletionListener { taskContext =>
      val duration = System.currentTimeMillis() - startTime
      println(s"Task ${taskContext.taskAttemptId()} on partition ${taskContext.partitionId()} " +
        s"processed $recordCount records in ${duration}ms")
    }
    
    f(iterator).map { result =>
      recordCount += 1
      result
    }
  }
}

Install with Tessl CLI

npx tessl i tessl/maven-org-apache-spark--spark-core-2-11

docs

broadcast-accumulators.md

context-configuration.md

index.md

java-api.md

key-value-operations.md

rdd-operations.md

status-monitoring.md

storage-persistence.md

task-context.md

tile.json