or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

docs

context-config.md, index.md, java-api.md, rdd-operations.md, resource-management.md, serialization.md, shared-variables.md, storage-caching.md, task-context.md
tile.json

docs/task-context.md

Task Context

TaskContext provides runtime information and utilities available to tasks during execution, including partition information, memory management, and lifecycle hooks.

TaskContext

The main context object available to running tasks.

/**
 * Context object available to code running inside a task (signature sketch; bodies omitted).
 *
 * Obtain the instance for the current task with `TaskContext.get()`.
 *
 * NOTE(review): several members listed here do not appear to match Spark's public
 * TaskContext API (see the per-member notes) — verify against the Spark version in use.
 */
abstract class TaskContext {
  // Task Identification
  // Id of the partition this task is processing.
  def partitionId(): Int
  def stageId(): Int  
  def stageAttemptNumber(): Int
  // NOTE(review): taskAttemptId() presumably identifies this specific attempt, whereas
  // attemptNumber() counts how many times this task has been retried — confirm.
  def taskAttemptId(): Long
  def attemptNumber(): Int
  // Metrics for this task; read from a completion listener in the examples on this page.
  def taskMetrics(): TaskMetrics
  
  // Memory Management
  // NOTE(review): in Spark this accessor appears to be private[spark]; user code may not compile.
  def taskMemoryManager(): TaskMemoryManager
  
  // Lifecycle Hooks
  // NOTE(review): Spark's TaskContext exposes addTaskCompletionListener/addTaskFailureListener;
  // these register* variants do not appear in its public API — verify before relying on them.
  def registerTaskCompletionListener(listener: TaskCompletionListener): TaskContext
  def registerTaskFailureListener(listener: TaskFailureListener): TaskContext
  // Both return a TaskContext (presumably this one, so registrations can be chained).
  def addTaskCompletionListener(listener: TaskCompletionListener): TaskContext
  def addTaskFailureListener(listener: TaskFailureListener): TaskContext
  
  // Status Information
  def isCompleted(): Boolean
  // NOTE(review): isRunningLocally() appears to have been removed in Spark 2.0 — confirm.
  def isRunningLocally(): Boolean
  // Poll this in long-running loops to notice that the task has been killed.
  def isInterrupted(): Boolean
  // NOTE(review): getKillReason/killTaskIfInterrupted look like executor internals
  // (private[spark] in Spark) — confirm visibility before calling from user code.
  def getKillReason(): Option[String]
  def killTaskIfInterrupted(): Unit
  
  // Resources and Metrics
  // Resources assigned to this task, keyed by resource name.
  def resources(): Map[String, ResourceInformation]
  // Number of CPUs available to this task.
  def cpus(): Int
  def getMetricsSources(sourceName: String): Seq[Source]
  
  // Local Properties
  // NOTE(review): presumably returns null when the property is unset — confirm.
  def getLocalProperty(key: String): String
  // Task lifecycle control
  // NOTE(review): these appear to be executor-side internals (private[spark] in Spark);
  // user code should not call them.
  def setTaskThread(thread: Thread): Unit
  def markTaskCompleted(error: Option[Throwable]): Unit
  def markTaskFailed(error: Throwable): Unit
}

/** Static accessors for the currently running task's context (signature sketch). */
object TaskContext {
  // Returns the current task's context. Per the best practices on this page, callers
  // should handle a null result (NOTE(review): null when not inside a running task, e.g. on the driver).
  def get(): TaskContext
  // Shortcut for the current task's partition id.
  // NOTE(review): presumably returns 0 when no task is running — confirm.
  def getPartitionId(): Int
}

Task Listeners

Interfaces for handling task lifecycle events.

/**
 * Callback invoked when a task completes; register it via `addTaskCompletionListener`.
 *
 * NOTE(review): in Spark this fires on success, failure and cancellation alike, which makes it
 * the place for cleanup that must always run — confirm for your version.
 */
trait TaskCompletionListener {
  def onTaskCompletion(context: TaskContext): Unit
}

/**
 * Callback invoked when a task fails; register it via `addTaskFailureListener`.
 *
 * Receives the failing task's context and the `error` that caused the failure.
 */
trait TaskFailureListener {
  def onTaskFailure(context: TaskContext, error: Throwable): Unit
}

TaskMetrics

Metrics and statistics collected during task execution.

/**
 * Metrics collected while a task runs (signature sketch; bodies omitted).
 *
 * Obtained from `TaskContext.taskMetrics()`.
 */
class TaskMetrics {
  // Execution Metrics
  // NOTE(review): the examples on this page report executorRunTime in milliseconds; the
  // *CpuTime fields are presumably nanoseconds — confirm units before mixing them.
  def executorDeserializeTime: Long
  def executorDeserializeCpuTime: Long
  def executorRunTime: Long
  def executorCpuTime: Long
  def resultSize: Long
  def jvmGCTime: Long
  def resultSerializationTime: Long
  // Bytes spilled by this task (printed as bytes in the completion-listener example).
  def memoryBytesSpilled: Long
  def diskBytesSpilled: Long
  def peakExecutionMemory: Long
  
  // I/O Metrics
  // Per-category breakdowns; see the Input/Output/Shuffle Metrics classes.
  def inputMetrics: InputMetrics
  def outputMetrics: OutputMetrics
  def shuffleReadMetrics: ShuffleReadMetrics
  def shuffleWriteMetrics: ShuffleWriteMetrics
  
  // Update Methods
  // NOTE(review): these mutators appear to be private[spark] and driven by the executor;
  // treat TaskMetrics as read-only from user code.
  def incExecutorDeserializeTime(v: Long): Unit
  def incExecutorRunTime(v: Long): Unit
  def incResultSize(v: Long): Unit
  def incJvmGCTime(v: Long): Unit
  def incResultSerializationTime(v: Long): Unit
  def incMemoryBytesSpilled(v: Long): Unit
  def incDiskBytesSpilled(v: Long): Unit
  def setPeakExecutionMemory(v: Long): Unit
}

Input/Output/Shuffle Metrics

Detailed metrics for different types of I/O operations.

/** Metrics for data read as task input (signature sketch). */
class InputMetrics {
  def bytesRead: Long
  def recordsRead: Long
  // inc* presumably add to the running total and set* overwrite it.
  // NOTE(review): likely Spark-internal mutators.
  def incBytesRead(v: Long): Unit
  def incRecordsRead(v: Long): Unit
  def setBytesRead(v: Long): Unit
  def setRecordsRead(v: Long): Unit
}

/** Metrics for data written as task output (signature sketch). */
class OutputMetrics {
  def bytesWritten: Long
  def recordsWritten: Long
  // inc* presumably add to the running total and set* overwrite it.
  // NOTE(review): likely Spark-internal mutators.
  def incBytesWritten(v: Long): Unit
  def incRecordsWritten(v: Long): Unit
  def setBytesWritten(v: Long): Unit
  def setRecordsWritten(v: Long): Unit
}

/** Metrics for shuffle data fetched by this task (signature sketch). */
class ShuffleReadMetrics {
  // Block counts, split by whether the block came from a remote or the local executor.
  def remoteBlocksFetched: Long
  def localBlocksFetched: Long
  // NOTE(review): presumably time spent blocked waiting for shuffle fetches, in ms — confirm.
  def fetchWaitTime: Long
  def remoteBytesRead: Long
  // NOTE(review): presumably the subset of remote bytes fetched straight to disk — confirm.
  def remoteBytesReadToDisk: Long
  def localBytesRead: Long
  def recordsRead: Long
  
  // Accumulating mutators (NOTE(review): likely Spark-internal).
  def incRemoteBlocksFetched(v: Long): Unit
  def incLocalBlocksFetched(v: Long): Unit
  def incFetchWaitTime(v: Long): Unit
  def incRemoteBytesRead(v: Long): Unit
  def incRemoteBytesReadToDisk(v: Long): Unit
  def incLocalBytesRead(v: Long): Unit
  def incRecordsRead(v: Long): Unit
}

/** Metrics for shuffle data written by this task (signature sketch). */
class ShuffleWriteMetrics {
  def bytesWritten: Long
  def recordsWritten: Long
  // NOTE(review): presumably time spent writing shuffle output, in nanoseconds — confirm units.
  def writeTime: Long
  
  // inc* presumably add to the running total and set* overwrite it.
  // NOTE(review): likely Spark-internal mutators.
  def incBytesWritten(v: Long): Unit
  def incRecordsWritten(v: Long): Unit
  def incWriteTime(v: Long): Unit
  def setBytesWritten(v: Long): Unit
  def setRecordsWritten(v: Long): Unit
  def setWriteTime(v: Long): Unit
}

TaskMemoryManager

Memory management interface for tasks.

/**
 * Per-task memory accounting and page allocation (signature sketch; bodies omitted).
 *
 * One instance serves a single task, identified by `taskAttemptId`, and is built on top of
 * a `MemoryManager`.
 *
 * NOTE(review): in Spark this is a Java class. Some names and return types below (e.g.
 * releaseAllExecutionMemoryForConsumer, executionMemoryUsed, getPage's return type) may
 * differ from the real API — verify against the Spark version in use.
 */
class TaskMemoryManager(memoryManager: MemoryManager, taskAttemptId: Long) {
  // Memory Acquisition
  // Requests `required` bytes for `consumer` and returns the bytes actually granted, which
  // can be less than requested (the Memory Management example checks for this).
  def acquireExecutionMemory(required: Long, consumer: MemoryConsumer): Long
  // Gives back `size` bytes previously granted to `consumer`.
  def releaseExecutionMemory(size: Long, consumer: MemoryConsumer): Unit
  // Releases everything held by `consumer`; presumably returns the number of bytes freed.
  def releaseAllExecutionMemoryForConsumer(consumer: MemoryConsumer): Long
  
  // Memory Information
  def executionMemoryUsed: Long
  // NOTE(review): presumably reports whether Tungsten pages are on-heap or off-heap — confirm.
  def getTungstenMemoryMode: MemoryMode
  
  // Page Management (for advanced users)
  def allocatePage(size: Long, consumer: MemoryConsumer): MemoryBlock
  def freePage(page: MemoryBlock, consumer: MemoryConsumer): Unit
  // The encode*/decode* methods appear to pack a page number and an in-page offset into a
  // single Long address (`pagePlusOffsetAddress`) and split it back apart.
  def getPage(pagePlusOffsetAddress: Long): MemoryBlock
  def encodePageNumberAndOffset(page: MemoryBlock, offsetInPage: Long): Long
  def encodePageNumberAndOffset(pageNumber: Int, offsetInPage: Long): Long
  def decodePageNumber(pagePlusOffsetAddress: Long): Int
  def decodeOffset(pagePlusOffsetAddress: Long): Long
  
  // Cleanup
  // NOTE(review): presumably frees all memory still held by the task and returns the byte count.
  def cleanUpAllAllocatedMemory(): Long
}

Usage Examples

Basic Task Context Usage

import org.apache.spark.TaskContext

// Assumes an existing SparkContext named `sc` (e.g. in spark-shell).
val data = sc.parallelize(1 to 100, 4)

// Tag every element with the id of the partition that processed it.
val processedData = data.mapPartitions { iter =>
  // This closure runs inside a task, so TaskContext.get() returns the task's context.
  val context = TaskContext.get()

  // Get partition and attempt information
  val partitionId = context.partitionId()
  val stageId = context.stageId()
  // taskAttemptId() identifies this particular attempt; attemptNumber() counts retries of
  // the task (0 on the first try). Printing the former as "attempt" would be misleading.
  val taskAttemptId = context.taskAttemptId()
  val attemptNumber = context.attemptNumber()

  // Runs on an executor: in cluster mode this goes to executor stdout, not the driver console.
  println(s"Processing partition $partitionId in stage $stageId " +
    s"(task attempt id $taskAttemptId, attempt number $attemptNumber)")

  // Process data with context information
  iter.map { value =>
    (partitionId, value, value * value)
  }
}

val result = processedData.collect()

Task Completion Listeners

import org.apache.spark.TaskContext
// TaskCompletionListener lives in org.apache.spark.util, not org.apache.spark.
import org.apache.spark.util.TaskCompletionListener

val data = sc.parallelize(1 to 1000, 10)

val processedData = data.mapPartitions { iter =>
  val context = TaskContext.get()
  // Start time of this partition's work, used for the duration reported below.
  val startNanos = System.nanoTime()

  // Register the cleanup listener before processing so it is in place even if processing
  // throws; completion listeners run whether the task succeeds, fails, or is killed.
  context.addTaskCompletionListener(new TaskCompletionListener {
    override def onTaskCompletion(ctx: TaskContext): Unit = {
      val metrics = ctx.taskMetrics()
      // NOTE(review): executorRunTime appears to be recorded by the executor only after the
      // task body returns, so it may still read 0 here; measure elapsed time directly instead.
      val elapsedMs = (System.nanoTime() - startNanos) / 1000000L
      println(s"Task ${ctx.taskAttemptId()} completed in $elapsedMs ms")
      println(s"Memory spilled: ${metrics.memoryBytesSpilled} bytes")

      // Cleanup resources (user-defined placeholder)
      cleanupResources()
    }
  })

  // Process data lazily; the listener fires after the iterator is consumed and the task ends.
  iter.map(expensiveProcessing)
}

processedData.count()

Task Failure Handling

import org.apache.spark.TaskContext
// The listener traits live in org.apache.spark.util, not org.apache.spark.
import org.apache.spark.util.{TaskCompletionListener, TaskFailureListener}

val unreliableData = sc.parallelize(1 to 100)

val processedData = unreliableData.mapPartitions { iter =>
  val context = TaskContext.get()

  // Failure-only handling: report the error and record work done before it.
  context.addTaskFailureListener(new TaskFailureListener {
    override def onTaskFailure(ctx: TaskContext, error: Throwable): Unit = {
      println(s"Task ${ctx.taskAttemptId()} failed: ${error.getMessage}")
      savePartialResults(ctx.partitionId())
    }
  })

  // External connections must be closed whether the task succeeds or fails. A failure
  // listener alone would leak them on success, so close them from a completion listener,
  // which runs in both cases.
  context.addTaskCompletionListener(new TaskCompletionListener {
    override def onTaskCompletion(ctx: TaskContext): Unit = {
      cleanupExternalConnections()
    }
  })

  // Potentially failing operation
  iter.map { value =>
    if (value % 50 == 0) {
      throw new RuntimeException(s"Simulated failure on value $value")
    }
    processValue(value)
  }
}

// Expected to throw: Spark retries the failing task and aborts the job once retries run out.
processedData.count()

Memory Management

import org.apache.spark.TaskContext
import org.apache.spark.memory.MemoryConsumer
import org.apache.spark.util.TaskCompletionListener

val data = sc.parallelize(1 to 1000000)

val processedData = data.mapPartitions { iter =>
  val context = TaskContext.get()
  // NOTE(review): taskMemoryManager() and the single-argument MemoryConsumer constructor are
  // Spark-internal/protected in recent releases; confirm they are accessible in your version.
  val memoryManager = context.taskMemoryManager()

  // Custom memory consumer for large operations
  val consumer = new MemoryConsumer(memoryManager) {
    override def spill(size: Long, trigger: MemoryConsumer): Long = {
      // Custom spill logic (user-defined placeholder); returns the number of bytes released.
      spillToTemp(size)
    }
  }

  // mapPartitions returns a lazy iterator that is consumed only after this closure returns.
  // A try/finally here would therefore release the memory before the work has run.
  // Release it when the task finishes (success or failure) instead. Registering before
  // acquiring also covers the case where the acquisition itself throws.
  context.addTaskCompletionListener(new TaskCompletionListener {
    override def onTaskCompletion(ctx: TaskContext): Unit = {
      memoryManager.releaseAllExecutionMemoryForConsumer(consumer)
    }
  })

  // Acquire memory for processing; the grant may be smaller than requested.
  val requiredMemory = 1024L * 1024 // 1MB
  val acquiredMemory = memoryManager.acquireExecutionMemory(requiredMemory, consumer)

  if (acquiredMemory < requiredMemory) {
    println(s"Only acquired $acquiredMemory bytes out of $requiredMemory requested")
  }

  // Size the working buffer to the granted amount (4 bytes per Int). The grant is at most
  // 1MB, so toInt cannot overflow.
  val buffer = new Array[Int](acquiredMemory.toInt / 4)
  processWithBuffer(iter, buffer)
}

processedData.count()

Monitoring Task Progress

import org.apache.spark.TaskContext
import org.apache.spark.util.{AccumulatorV2, TaskCompletionListener}

/**
 * Accumulator that sums Long increments across tasks.
 *
 * Spark's built-in `sc.longAccumulator` provides the same behavior; this class is written
 * out to illustrate the AccumulatorV2 contract.
 */
class ProgressAccumulator extends AccumulatorV2[Long, Long] {
  private var _value = 0L
  
  def isZero: Boolean = _value == 0
  def copy(): ProgressAccumulator = {
    val acc = new ProgressAccumulator
    acc._value = _value
    acc
  }
  def reset(): Unit = _value = 0
  def add(v: Long): Unit = _value += v
  def merge(other: AccumulatorV2[Long, Long]): Unit = _value += other.value
  def value: Long = _value
}

val progressTracker = new ProgressAccumulator
sc.register(progressTracker, "Progress Tracker")

val data = sc.parallelize(1 to 10000, 10)

val processedData = data.mapPartitions { iter =>
  val context = TaskContext.get()
  val partitionId = context.partitionId()
  // Per-task counter, shared with the completion listener below.
  var processed = 0L
  
  // Register completion listener to report final progress
  context.addTaskCompletionListener(new TaskCompletionListener {
    override def onTaskCompletion(context: TaskContext): Unit = {
      println(s"Partition $partitionId processed $processed records")
    }
  })
  
  iter.map { value =>
    val result = processValue(value)
    processed += 1
    // Caveat: updates made inside a transformation can be applied more than once if a task is
    // retried or the RDD is recomputed. Spark only guarantees exactly-once accumulator updates
    // inside actions, so treat this total as approximate progress.
    progressTracker.add(1) // Update global progress
    
    // Report progress periodically
    if (processed % 100 == 0) {
      println(s"Partition $partitionId: $processed records processed")
    }
    
    result
  }
}

processedData.count()
println(s"Total records processed: ${progressTracker.value}")

Resource Information

import org.apache.spark.TaskContext

val data = sc.parallelize(1 to 100)

val resourceInfo = data.mapPartitions { iter =>
  val context = TaskContext.get()

  // Get available resources.
  // NOTE(review): resources() needs Spark 3.0+ and cpus() needs Spark 3.3+ — confirm for your version.
  val resources = context.resources()
  val cpus = context.cpus()

  println(s"Task has $cpus CPUs available")
  // The binding is named `info` so it does not shadow the enclosing `resourceInfo` val.
  resources.foreach { case (resourceName, info) =>
    println(s"Resource $resourceName: ${info.addresses.mkString(", ")}")
  }

  // Adapt processing based on available resources
  val batchSize = if (cpus > 4) 1000 else 100

  // grouped() yields batches of at most batchSize elements; the last batch may be smaller.
  iter.grouped(batchSize).map(processBatch)
}.collect()

Best Practices

Task Context Usage

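  1. Always check availability: TaskContext.get() returns null when called outside a running task (for example, on the driver), so handle the null case, e.g. with Option(TaskContext.get())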
  2. Register listeners early: Add listeners at the beginning of task execution
  3. Clean up resources: Use completion listeners for guaranteed cleanup; they run whether the task succeeds, fails, or is killed
  4. Handle interruption: Check isInterrupted() in long-running operations
  5. Partition-aware processing: Use partition ID for distributed coordination

Memory Management

  1. Acquire before use: Always acquire memory before large allocations
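  2. Release promptly: Release memory as soon as processing is complete; because mapPartitions returns a lazy iterator, release from a task completion listener rather than a finally block around the closure body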
  3. Handle insufficient memory: acquireExecutionMemory may grant less than requested, so check the returned amount and degrade gracefully
  4. Implement spilling: Provide spill logic for memory consumers
  5. Monitor usage: Track memory usage through task metrics

Error Handling

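  1. Register failure listeners: Use failure listeners for failure-specific work (e.g. recording or rolling back partial output); put cleanup of external resources that must always run in a completion listener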
  2. Graceful degradation: Handle partial failures appropriately
  3. Resource cleanup: Ensure cleanup happens even on task failure
  4. Logging: Provide detailed error information for debugging
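  5. Retry logic: Spark already retries failed tasks (up to spark.task.maxFailures), so make task side effects idempotent; add your own retries only for transient errors inside a task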