or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

context-config.md, index.md, java-api.md, rdd-operations.md, resource-management.md, serialization.md, shared-variables.md, storage-caching.md, task-context.md

docs/task-context.md

0

# Task Context

1

2

TaskContext provides runtime information and utilities available to tasks during execution, including partition information, memory management, and lifecycle hooks.

3

4

## TaskContext

5

6

The main context object available to running tasks.

7

8

```scala { .api }

9

abstract class TaskContext {

10

// Task Identification

11

def partitionId(): Int

12

def stageId(): Int

13

def stageAttemptNumber(): Int

14

def taskAttemptId(): Long

15

def attemptNumber(): Int

16

def taskMetrics(): TaskMetrics

17

18

// Memory Management

19

def taskMemoryManager(): TaskMemoryManager

20

21

// Lifecycle Hooks

22

def addTaskCompletionListener[U](f: TaskContext => U): TaskContext

23

def addTaskFailureListener(f: (TaskContext, Throwable) => Unit): TaskContext

24

def addTaskCompletionListener(listener: TaskCompletionListener): TaskContext

25

def addTaskFailureListener(listener: TaskFailureListener): TaskContext

26

27

// Status Information

28

def isCompleted(): Boolean

29

def isRunningLocally(): Boolean

30

def isInterrupted(): Boolean

31

def getKillReason(): Option[String]

32

def killTaskIfInterrupted(): Unit

33

34

// Resources and Metrics

35

def resources(): Map[String, ResourceInformation]

36

def cpus(): Int

37

def getMetricsSources(sourceName: String): Seq[Source]

38

39

// Local Properties and Task State

40

def getLocalProperty(key: String): String

41

def setTaskThread(thread: Thread): Unit

42

def markTaskCompleted(error: Option[Throwable]): Unit

43

def markTaskFailed(error: Throwable): Unit

44

}

45

46

object TaskContext {

47

def get(): TaskContext

48

def getPartitionId(): Int

49

}

50

```

51

52

## Task Listeners

53

54

Interfaces for handling task lifecycle events.

55

56

```scala { .api }

57

trait TaskCompletionListener {

58

def onTaskCompletion(context: TaskContext): Unit

59

}

60

61

trait TaskFailureListener {

62

def onTaskFailure(context: TaskContext, error: Throwable): Unit

63

}

64

```

65

66

## TaskMetrics

67

68

Metrics and statistics collected during task execution.

69

70

```scala { .api }

71

class TaskMetrics {

72

// Execution Metrics

73

def executorDeserializeTime: Long

74

def executorDeserializeCpuTime: Long

75

def executorRunTime: Long

76

def executorCpuTime: Long

77

def resultSize: Long

78

def jvmGCTime: Long

79

def resultSerializationTime: Long

80

def memoryBytesSpilled: Long

81

def diskBytesSpilled: Long

82

def peakExecutionMemory: Long

83

84

// I/O Metrics

85

def inputMetrics: InputMetrics

86

def outputMetrics: OutputMetrics

87

def shuffleReadMetrics: ShuffleReadMetrics

88

def shuffleWriteMetrics: ShuffleWriteMetrics

89

90

// Update Methods

91

def incExecutorDeserializeTime(v: Long): Unit

92

def incExecutorRunTime(v: Long): Unit

93

def incResultSize(v: Long): Unit

94

def incJvmGCTime(v: Long): Unit

95

def incResultSerializationTime(v: Long): Unit

96

def incMemoryBytesSpilled(v: Long): Unit

97

def incDiskBytesSpilled(v: Long): Unit

98

def setPeakExecutionMemory(v: Long): Unit

99

}

100

```

101

102

## Input/Output/Shuffle Metrics

103

104

Detailed metrics for different types of I/O operations.

105

106

```scala { .api }

107

class InputMetrics {

108

def bytesRead: Long

109

def recordsRead: Long

110

def incBytesRead(v: Long): Unit

111

def incRecordsRead(v: Long): Unit

112

def setBytesRead(v: Long): Unit

113

def setRecordsRead(v: Long): Unit

114

}

115

116

class OutputMetrics {

117

def bytesWritten: Long

118

def recordsWritten: Long

119

def incBytesWritten(v: Long): Unit

120

def incRecordsWritten(v: Long): Unit

121

def setBytesWritten(v: Long): Unit

122

def setRecordsWritten(v: Long): Unit

123

}

124

125

class ShuffleReadMetrics {

126

def remoteBlocksFetched: Long

127

def localBlocksFetched: Long

128

def fetchWaitTime: Long

129

def remoteBytesRead: Long

130

def remoteBytesReadToDisk: Long

131

def localBytesRead: Long

132

def recordsRead: Long

133

134

def incRemoteBlocksFetched(v: Long): Unit

135

def incLocalBlocksFetched(v: Long): Unit

136

def incFetchWaitTime(v: Long): Unit

137

def incRemoteBytesRead(v: Long): Unit

138

def incRemoteBytesReadToDisk(v: Long): Unit

139

def incLocalBytesRead(v: Long): Unit

140

def incRecordsRead(v: Long): Unit

141

}

142

143

class ShuffleWriteMetrics {

144

def bytesWritten: Long

145

def recordsWritten: Long

146

def writeTime: Long

147

148

def incBytesWritten(v: Long): Unit

149

def incRecordsWritten(v: Long): Unit

150

def incWriteTime(v: Long): Unit

151

def setBytesWritten(v: Long): Unit

152

def setRecordsWritten(v: Long): Unit

153

def setWriteTime(v: Long): Unit

154

}

155

```

156

157

## TaskMemoryManager

158

159

Memory management interface for tasks.

160

161

```scala { .api }

162

class TaskMemoryManager(memoryManager: MemoryManager, taskAttemptId: Long) {

163

// Memory Acquisition

164

def acquireExecutionMemory(required: Long, consumer: MemoryConsumer): Long

165

def releaseExecutionMemory(size: Long, consumer: MemoryConsumer): Unit

166

def releaseAllExecutionMemoryForConsumer(consumer: MemoryConsumer): Long

167

168

// Memory Information

169

def executionMemoryUsed: Long

170

def getTungstenMemoryMode: MemoryMode

171

172

// Page Management (for advanced users)

173

def allocatePage(size: Long, consumer: MemoryConsumer): MemoryBlock

174

def freePage(page: MemoryBlock, consumer: MemoryConsumer): Unit

175

def getPage(pagePlusOffsetAddress: Long): MemoryBlock

176

def encodePageNumberAndOffset(page: MemoryBlock, offsetInPage: Long): Long

177

def encodePageNumberAndOffset(pageNumber: Int, offsetInPage: Long): Long

178

def decodePageNumber(pagePlusOffsetAddress: Long): Int

179

def decodeOffset(pagePlusOffsetAddress: Long): Long

180

181

// Cleanup

182

def cleanUpAllAllocatedMemory(): Long

183

}

184

```

185

186

## Usage Examples

187

188

### Basic Task Context Usage

189

```scala

190

import org.apache.spark.TaskContext

191

192

val data = sc.parallelize(1 to 100, 4)

193

194

val processedData = data.mapPartitions { iter =>

195

val context = TaskContext.get()

196

197

// Get partition information

198

val partitionId = context.partitionId()

199

val stageId = context.stageId()

200

val attemptId = context.taskAttemptId()

201

202

println(s"Processing partition $partitionId in stage $stageId (attempt $attemptId)")

203

204

// Process data with context information

205

iter.map { value =>

206

(partitionId, value, value * value)

207

}

208

}

209

210

val result = processedData.collect()

211

```

212

213

### Task Completion Listeners

214

```scala

215

import org.apache.spark.TaskContext
import org.apache.spark.util.TaskCompletionListener

216

217

val data = sc.parallelize(1 to 1000, 10)

218

219

val processedData = data.mapPartitions { iter =>

220

val context = TaskContext.get()

221

222

// Register cleanup listener

223

context.addTaskCompletionListener(new TaskCompletionListener {

224

override def onTaskCompletion(context: TaskContext): Unit = {

225

val metrics = context.taskMetrics()

226

println(s"Task ${context.taskAttemptId()} completed in ${metrics.executorRunTime} ms")

227

println(s"Memory spilled: ${metrics.memoryBytesSpilled} bytes")

228

229

// Cleanup resources

230

cleanupResources()

231

}

232

})

233

234

// Process data

235

iter.map(expensiveProcessing)

236

}

237

238

processedData.count()

239

```

240

241

### Task Failure Handling

242

```scala

243

import org.apache.spark.TaskContext
import org.apache.spark.util.TaskFailureListener

244

245

val unreliableData = sc.parallelize(1 to 100)

246

247

val processedData = unreliableData.mapPartitions { iter =>

248

val context = TaskContext.get()

249

250

// Register failure listener for cleanup

251

context.addTaskFailureListener(new TaskFailureListener {

252

override def onTaskFailure(context: TaskContext, error: Throwable): Unit = {

253

println(s"Task ${context.taskAttemptId()} failed: ${error.getMessage}")

254

255

// Cleanup external resources

256

cleanupExternalConnections()

257

savePartialResults(context.partitionId())

258

}

259

})

260

261

// Potentially failing operation

262

iter.map { value =>

263

if (value % 50 == 0) {

264

throw new RuntimeException(s"Simulated failure on value $value")

265

}

266

processValue(value)

267

}

268

}

269

270

processedData.count()

271

```

272

273

### Memory Management

274

```scala

275

import org.apache.spark.TaskContext
import org.apache.spark.memory.MemoryConsumer

276

277

val data = sc.parallelize(1 to 1000000)

278

279

val processedData = data.mapPartitions { iter =>

280

val context = TaskContext.get()

281

val memoryManager = context.taskMemoryManager()

282

283

// Custom memory consumer for large operations

284

val consumer = new MemoryConsumer(memoryManager) {

285

override def spill(size: Long, trigger: MemoryConsumer): Long = {

286

// Custom spill logic

287

spillToTemp(size)

288

}

289

}

290

291

try {

292

// Acquire memory for processing

293

val requiredMemory = 1024 * 1024 // 1MB

294

val acquiredMemory = memoryManager.acquireExecutionMemory(requiredMemory, consumer)

295

296

if (acquiredMemory < requiredMemory) {

297

println(s"Only acquired $acquiredMemory bytes out of $requiredMemory requested")

298

}

299

300

// Process data with allocated memory

301

val buffer = new Array[Int](acquiredMemory.toInt / 4) // 4 bytes per int

302

processWithBuffer(iter, buffer)

303

304

} finally {

305

// Release memory

306

memoryManager.releaseAllExecutionMemoryForConsumer(consumer)

307

}

308

}

309

310

processedData.count()

311

```

312

313

### Monitoring Task Progress

314

```scala

315

import org.apache.spark.TaskContext
import org.apache.spark.util.{AccumulatorV2, TaskCompletionListener}

316

317

// Custom accumulator for tracking progress

318

class ProgressAccumulator extends AccumulatorV2[Long, Long] {

319

private var _value = 0L

320

321

def isZero: Boolean = _value == 0

322

def copy(): ProgressAccumulator = {

323

val acc = new ProgressAccumulator

324

acc._value = _value

325

acc

326

}

327

def reset(): Unit = _value = 0

328

def add(v: Long): Unit = _value += v

329

def merge(other: AccumulatorV2[Long, Long]): Unit = _value += other.value

330

def value: Long = _value

331

}

332

333

val progressTracker = new ProgressAccumulator

334

sc.register(progressTracker, "Progress Tracker")

335

336

val data = sc.parallelize(1 to 10000, 10)

337

338

val processedData = data.mapPartitions { iter =>

339

val context = TaskContext.get()

340

val partitionId = context.partitionId()

341

var processed = 0L

342

343

// Register completion listener to report final progress

344

context.addTaskCompletionListener(new TaskCompletionListener {

345

override def onTaskCompletion(context: TaskContext): Unit = {

346

println(s"Partition $partitionId processed $processed records")

347

}

348

})

349

350

iter.map { value =>

351

val result = processValue(value)

352

processed += 1

353

progressTracker.add(1) // Update global progress

354

355

// Report progress periodically

356

if (processed % 100 == 0) {

357

println(s"Partition $partitionId: $processed records processed")

358

}

359

360

result

361

}

362

}

363

364

processedData.count()

365

println(s"Total records processed: ${progressTracker.value}")

366

```

367

368

### Resource Information

369

```scala

370

import org.apache.spark.TaskContext

val data = sc.parallelize(1 to 100)

371

372

val resourceInfo = data.mapPartitions { iter =>

373

val context = TaskContext.get()

374

375

// Get available resources

376

val resources = context.resources()

377

val cpus = context.cpus()

378

379

println(s"Task has $cpus CPUs available")

380

resources.foreach { case (resourceName, resourceInfo) =>

381

println(s"Resource $resourceName: ${resourceInfo.addresses.mkString(", ")}")

382

}

383

384

// Adapt processing based on available resources

385

val batchSize = if (cpus > 4) 1000 else 100

386

387

iter.grouped(batchSize).map(processBatch)

388

}.collect()

389

```

390

391

## Best Practices

392

393

### Task Context Usage

394

1. **Always check availability**: `TaskContext.get()` returns `null` when called outside a running task (for example, on the driver), so check the result before using it

395

2. **Register listeners early**: Add listeners at the beginning of task execution

396

3. **Clean up resources**: Use completion listeners for guaranteed cleanup

397

4. **Handle interruption**: Check `isInterrupted()` in long-running operations

398

5. **Partition-aware processing**: Use partition ID for distributed coordination

399

400

### Memory Management

401

1. **Acquire before use**: Always acquire memory before large allocations

402

2. **Release promptly**: Release memory as soon as processing is complete

403

3. **Handle insufficient memory**: Gracefully handle cases where less memory is available

404

4. **Implement spilling**: Provide spill logic for memory consumers

405

5. **Monitor usage**: Track memory usage through task metrics

406

407

### Error Handling

408

1. **Register failure listeners**: Always register cleanup for external resources

409

2. **Graceful degradation**: Handle partial failures appropriately

410

3. **Resource cleanup**: Ensure cleanup happens even on task failure

411

4. **Logging**: Provide detailed error information for debugging

412

5. **Retry logic**: Consider task-level retry strategies for transient failures