0
# Thread Utilities and Helper Classes
1
2
Thread-safe utilities with lexical scoping support for managing thread-local state across Spark operations, plus additional utility classes for log querying and event handling.
3
4
## LexicalThreadLocal
5
6
Helper trait for defining thread locals with lexical scoping: a value is visible only while code runs inside `Handle.runWith { ... }`, and a handle can be handed to another thread to make the same value visible there.
7
8
```scala { .api }
9
import org.apache.spark.util.LexicalThreadLocal
10
11
/**
 * Signature summary of a thread local whose value is only visible inside a
 * lexically scoped `Handle.runWith { ... }` call.
 *
 * @tparam T type of the value carried by the thread local
 */
trait LexicalThreadLocal[T] {

  /** Returns the value bound on the current thread, or `None` when nothing is bound. */
  def get(): Option[T]

  /**
   * Creates a handle that binds `opt` for the duration of a `runWith` call.
   *
   * NOTE(review): upstream Spark is believed to declare this member `protected`;
   * confirm before calling it from outside a subclass.
   */
  def createHandle(opt: Option[T]): Handle

  /** Scoped-execution handle; instances are obtained through `createHandle`. */
  final class Handle private[LexicalThreadLocal](private val opt: Option[T]) {

    /** Evaluates `f` with this handle's value bound and returns the result of `f`. */
    def runWith[R](f: => R): R
  }
}
23
```
24
25
### Basic Usage
26
27
```scala { .api }
28
// Define a thread local context
29
/** Example string-valued thread local exposing two handle factories. */
object MyExecutionContext extends LexicalThreadLocal[String] {

  /** Handle that binds `value` while its `runWith` block executes. */
  def create(value: String): Handle = {
    val bound: Option[String] = Some(value)
    createHandle(bound)
  }

  /** Handle that binds no value (`None`) while its `runWith` block executes. */
  def clear(): Handle = createHandle(Option.empty[String])
}
33
34
// Usage pattern
35
/**
 * Walks through the life cycle of a scoped value: absent before the scope,
 * visible inside it (including in nested calls), and absent again afterwards.
 */
def demonstrateThreadLocal(): Unit = {
  // Nothing is bound before any handle runs.
  assert(MyExecutionContext.get().isEmpty)

  val expected = "my-context-value"
  val scoped = MyExecutionContext.create(expected)

  scoped.runWith {
    // The value is visible for the whole block ...
    assert(MyExecutionContext.get().contains(expected))

    // ... including inside functions called from it ...
    nestedFunction()

    // ... and is still bound after those calls return.
    assert(MyExecutionContext.get().contains(expected))
  }

  // Leaving the block removes the binding again.
  assert(MyExecutionContext.get().isEmpty)
}
55
56
/** Prints whatever `MyExecutionContext` value is bound on the calling thread. */
def nestedFunction(): Unit = {
  // Reads the binding established by the caller's runWith block, if any.
  val seen: Option[String] = MyExecutionContext.get()
  println(s"Nested function sees: ${seen}")
}
61
```
62
63
### Advanced Usage Patterns
64
65
```scala { .api }
66
import org.apache.spark.util.LexicalThreadLocal
67
68
// Execution context with user information
69
/** Thread local carrying the `User` on whose behalf the current code runs. */
object UserContext extends LexicalThreadLocal[User] {

  /**
   * Evaluates `f` with `user` bound as the current user.
   *
   * @return the result of `f`
   */
  def withUser[T](user: User)(f: => T): T = {
    val handle = createHandle(Some(user))
    handle.runWith(f)
  }

  /** The user bound on the current thread, if any. */
  def getCurrentUser(): Option[User] = get()
}
76
77
/**
 * User identity used by the examples.
 *
 * @param id    unique identifier of the user
 * @param name  display name
 * @param roles role names granted to the user
 */
case class User(
    id: String,
    name: String,
    roles: Set[String])
78
79
// Request tracking context
80
/** Thread local carrying the identifiers and start time of the current request. */
object RequestContext extends LexicalThreadLocal[RequestInfo] {

  /**
   * Evaluates `f` with a fresh `RequestInfo` bound; its start time is the
   * wall-clock time at which this method is called.
   *
   * @return the result of `f`
   */
  def withRequest[T](requestId: String, correlationId: String)(f: => T): T = {
    val startedAt = System.currentTimeMillis()
    val info = RequestInfo(requestId, correlationId, startedAt)
    createHandle(Some(info)).runWith(f)
  }

  /** Request id of the bound request, if any. */
  def getRequestId(): Option[String] =
    for (info <- get()) yield info.requestId

  /** Correlation id of the bound request, if any. */
  def getCorrelationId(): Option[String] =
    for (info <- get()) yield info.correlationId
}
89
90
/**
 * Identifiers and start time of one request.
 *
 * @param requestId     identifier of the request
 * @param correlationId identifier used to correlate related work
 * @param startTime     start time in epoch milliseconds
 */
case class RequestInfo(
    requestId: String,
    correlationId: String,
    startTime: Long)
91
92
// Usage in application code
93
/**
 * Example service showing how `UserContext` and `RequestContext` are set once
 * at the entry point and then read anywhere further down the call chain.
 */
class UserService {

  /**
   * Runs validation, processing and auditing with both contexts bound.
   *
   * @param user user on whose behalf the data is processed
   * @param data payload to process
   * @return a fixed completion message
   * @throws SecurityException if `user` lacks the `data_processor` role
   */
  def processUserData(user: User, data: String): String =
    UserContext.withUser(user) {
      RequestContext.withRequest("req-123", "corr-456") {
        // Both contexts are available throughout the call chain
        validateAccess()
        processData(data)
        auditOperation()

        "Processing completed"
      }
    }

  /** Current user, or a descriptive failure when called outside `withUser`. */
  private def currentUser(): User =
    UserContext.getCurrentUser().getOrElse(
      throw new IllegalStateException("No user context available"))

  /** Current request, or a descriptive failure when called outside `withRequest`. */
  private def currentRequest(): RequestInfo =
    RequestContext.get().getOrElse(
      throw new IllegalStateException("No request context available"))

  private def validateAccess(): Unit = {
    val user = currentUser()
    val requestId = currentRequest().requestId

    println(s"Validating access for user ${user.id} in request $requestId")

    if (!user.roles.contains("data_processor")) {
      throw new SecurityException(s"User ${user.id} lacks required role")
    }
  }

  private def processData(data: String): Unit = {
    val correlationId = currentRequest().correlationId
    println(s"Processing data with correlation ID: $correlationId")

    // Processing logic that may spawn additional threads
    processInParallel(data)
  }

  private def processInParallel(data: String): Unit = {
    // Thread local values don't automatically propagate to new threads, so
    // capture them in handles here, on the calling thread.
    val userHandle = UserContext.createHandle(UserContext.get())
    val requestHandle = RequestContext.createHandle(RequestContext.get())
    val ec = scala.concurrent.ExecutionContext.global

    val background = scala.concurrent.Future {
      userHandle.runWith {
        requestHandle.runWith {
          // Context is available in the new thread.
          // `heavyProcessing` is not defined in this snippet; supply your own.
          heavyProcessing(data)
        }
      }
    }(ec)

    // Without a handler a failed Future is silently discarded.
    background.failed.foreach { error =>
      println(s"Background processing failed: ${error.getMessage}")
    }(ec)
  }

  private def auditOperation(): Unit = {
    val user = currentUser()
    val duration = System.currentTimeMillis() - currentRequest().startTime

    println(s"Audit: User ${user.id} completed operation in ${duration}ms")
  }
}
150
```
151
152
## Log Utilities
153
154
Developer API that exposes the schema of Spark's structured (JSON) logs, so log files can be loaded into a DataFrame and queried with Spark SQL.
155
156
```scala { .api }
157
import org.apache.spark.util.LogUtils
158
159
/** Signature summary of the log-querying helper object. */
object LogUtils {

  /**
   * Schema string for structured Spark logs, suitable for
   * `spark.read.schema(...)`.
   *
   * NOTE(review): upstream Spark 4.x is understood to define the columns
   * `ts`, `level`, `msg`, `context`, `exception` and `logger`; confirm against
   * the Spark version in use.
   */
  val SPARK_LOG_SCHEMA: String
}
163
```
164
165
### Schema and Log Analysis
166
167
```scala { .api }
168
import org.apache.spark.sql.SparkSession
169
import org.apache.spark.util.LogUtils
170
171
/**
 * Loads structured Spark logs into a temporary view and prints three summary
 * reports over it.
 *
 * The queries address MDC values through the `context` map column and the
 * event time through `ts`, which are the column names of
 * `LogUtils.SPARK_LOG_SCHEMA` (NOTE(review): confirm against your Spark
 * version). The individual MDC keys (`duration`, `memory_used`, ...) depend on
 * what the application actually logs; a key that is absent reads as NULL.
 */
class LogAnalyzer {

  val spark: SparkSession = SparkSession.builder()
    .appName("LogAnalysis")
    .getOrCreate()

  /**
   * Reads the JSON logs at `logPath` and prints the error, performance and
   * resource reports.
   *
   * @param logPath path (file or directory) of structured JSON log files
   */
  def analyzeSparkLogs(logPath: String): Unit = {
    // Read logs using the predefined schema
    val logsDF = spark.read
      .schema(LogUtils.SPARK_LOG_SCHEMA)
      .json(logPath)

    // The reports below query this view by name.
    logsDF.createOrReplaceTempView("spark_logs")

    analyzeErrorPatterns()
    analyzePerformanceMetrics()
    analyzeResourceUsage()
  }

  /** Counts ERROR/WARN entries per logger and collects their error classes. */
  private def analyzeErrorPatterns(): Unit = {
    val errorSummary = spark.sql("""
      SELECT
        level,
        logger,
        COUNT(*) AS error_count,
        COLLECT_SET(context['error_class']) AS error_classes
      FROM spark_logs
      WHERE level IN ('ERROR', 'WARN')
      GROUP BY level, logger
      ORDER BY error_count DESC
    """)

    println("Error Pattern Analysis:")
    errorSummary.show(20, truncate = false)
  }

  /** Aggregates logged durations per job and stage. */
  private def analyzePerformanceMetrics(): Unit = {
    // Map values are strings, hence the explicit CAST before aggregating.
    val performanceMetrics = spark.sql("""
      SELECT
        context['job_id'] AS job_id,
        context['stage_id'] AS stage_id,
        AVG(CAST(context['duration'] AS LONG)) AS avg_duration_ms,
        MAX(CAST(context['duration'] AS LONG)) AS max_duration_ms,
        COUNT(*) AS task_count
      FROM spark_logs
      WHERE context['duration'] IS NOT NULL
        AND level = 'INFO'
      GROUP BY context['job_id'], context['stage_id']
      ORDER BY avg_duration_ms DESC
    """)

    println("Performance Metrics:")
    performanceMetrics.show(20)
  }

  /** Lists the most recent entries that carry executor resource figures. */
  private def analyzeResourceUsage(): Unit = {
    val resourceMetrics = spark.sql("""
      SELECT
        ts,
        context['executor_id'] AS executor_id,
        context['memory_used'] AS memory_used,
        context['disk_used'] AS disk_used,
        context['cpu_usage'] AS cpu_usage
      FROM spark_logs
      WHERE context['executor_id'] IS NOT NULL
        AND (context['memory_used'] IS NOT NULL OR context['disk_used'] IS NOT NULL)
      ORDER BY ts DESC
    """)

    println("Resource Usage Trends:")
    resourceMetrics.show(20)
  }
}
248
```
249
250
### Custom Log Processing
251
252
```scala { .api }
253
/**
 * Per-application log reports built with the DataFrame API.
 *
 * Columns follow `LogUtils.SPARK_LOG_SCHEMA`: event time is `ts`, the text is
 * `msg`, and MDC values live in the `context` map (NOTE(review): confirm
 * against your Spark version).
 */
class CustomLogProcessor {
  // Imported here so the snippet compiles on its own; `col` replaces the
  // `$"..."` syntax, which would need `spark.implicits._` in scope.
  import org.apache.spark.sql.{Column, DataFrame, SparkSession}
  import org.apache.spark.sql.functions.{avg, col, count, desc, stddev}

  /** Value stored under `key` in the `context` map; NULL when the key is absent. */
  private def ctx(key: String): Column = col("context").getItem(key)

  /**
   * Prints a timeline and a bottleneck report for one application.
   *
   * @param appId   application id to keep (matched against the `app_id` MDC key)
   * @param logPath path of structured JSON log files
   */
  def processApplicationLogs(appId: String, logPath: String): Unit = {
    val spark = SparkSession.builder().getOrCreate()

    // Filter logs for specific application
    val appLogsDF = spark.read
      .schema(LogUtils.SPARK_LOG_SCHEMA)
      .json(logPath)
      .filter(ctx("app_id") === appId)

    createEventTimeline(appLogsDF)
    identifyBottlenecks(appLogsDF)

    // `generateOptimizationRecommendations` is not defined in this snippet;
    // supply your own implementation.
    generateOptimizationRecommendations(appLogsDF)
  }

  /** Prints the log entries in chronological order. */
  private def createEventTimeline(logsDF: DataFrame): Unit = {
    val timeline = logsDF
      .select(
        col("ts"),
        col("level"),
        col("msg"),
        ctx("job_id").alias("job_id"),
        ctx("stage_id").alias("stage_id"))
      .orderBy("ts")

    println("Application Timeline:")
    timeline.show(50, truncate = false)
  }

  /** Prints stages whose duration spread exceeds half of their mean duration. */
  private def identifyBottlenecks(logsDF: DataFrame): Unit = {
    // Map values are strings; cast once so the aggregates are numeric.
    val duration = ctx("duration").cast("long")

    val bottlenecks = logsDF
      .filter(duration.isNotNull)
      .groupBy(ctx("stage_id").alias("stage_id"))
      .agg(
        avg(duration).alias("avg_duration"),
        // This is a standard deviation, so it is labelled as one.
        stddev(duration).alias("duration_stddev"),
        count("*").alias("task_count")
      )
      .filter(col("duration_stddev") > col("avg_duration") * 0.5) // High-spread stages
      .orderBy(desc("duration_stddev"))

    println("Potential Bottlenecks:")
    bottlenecks.show()
  }
}
300
```
301
302
## Event System Support
303
304
### SparkListenerEvent
305
306
Base trait for Spark listener events in the developer API.
307
308
```scala { .api }
309
import org.apache.spark.scheduler.SparkListenerEvent
310
311
/** Signature summary of the base trait implemented by all listener events. */
trait SparkListenerEvent {

  /**
   * Whether to log this event. Visibility is `protected[spark]`.
   *
   * NOTE(review): upstream Spark is believed to give this member a default
   * implementation; confirm before treating it as abstract.
   */
  protected[spark] def logEvent: Boolean
}
315
```
316
317
### Custom Event Handling
318
319
```scala { .api }
320
import org.apache.spark.scheduler.{
  SparkListener,
  SparkListenerApplicationStart,
  SparkListenerEvent,
  SparkListenerJobStart,
  SparkListenerStageCompleted
}
import org.apache.spark.util.LexicalThreadLocal
322
323
// Custom event for application-specific tracking
324
/**
 * Application-specific event that can be posted alongside Spark's own
 * listener events.
 *
 * @param eventType     free-form event category
 * @param applicationId id of the application the event belongs to
 * @param timestamp     event time (NOTE(review): presumably epoch milliseconds; confirm)
 * @param metadata      additional key/value attributes
 */
case class CustomApplicationEvent(
    eventType: String,
    applicationId: String,
    timestamp: Long,
    metadata: Map[String, String])
  extends SparkListenerEvent {

  // NOTE(review): the `protected[spark]` qualifier only resolves for code that
  // lives inside a `spark` package; confirm this compiles where you declare
  // the event.
  /** Always request that this event is logged. */
  protected[spark] def logEvent: Boolean = true
}
332
333
// Event context for correlation
334
/** Thread local holding the id used to correlate output about one event. */
object EventContext extends LexicalThreadLocal[String] {

  /**
   * Evaluates `f` with `eventId` bound as the current event id.
   *
   * @return the result of `f`
   */
  def withEventId[T](eventId: String)(f: => T): T = {
    val scoped = createHandle(Some(eventId))
    scoped.runWith(f)
  }
}
339
340
// Custom listener with context tracking
341
/**
 * Listener that binds an event id in `EventContext` while each callback runs,
 * so every line it prints is prefixed with that id.
 *
 * The `SparkListenerApplicationStart`, `SparkListenerJobStart` and
 * `SparkListenerStageCompleted` parameter types come from
 * `org.apache.spark.scheduler` and must be imported together with
 * `SparkListener`.
 */
class ContextAwareSparkListener extends SparkListener {

  override def onApplicationStart(applicationStart: SparkListenerApplicationStart): Unit = {
    // Resolve the id once; it feeds both the event id and the log line.
    val appId = applicationStart.appId.getOrElse("unknown")

    EventContext.withEventId(s"app-start-$appId") {
      logApplicationEvent("APPLICATION_STARTED", appId)

      // Additional context processing.
      // `processApplicationMetadata` is not defined in this snippet; supply your own.
      processApplicationMetadata(applicationStart)
    }
  }

  override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
    val eventId = s"job-${jobStart.jobId}-start"
    EventContext.withEventId(eventId) {
      logJobEvent("JOB_STARTED", jobStart.jobId, jobStart.stageInfos.length)
    }
  }

  override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
    val stageInfo = stageCompleted.stageInfo
    val eventId = s"stage-${stageInfo.stageId}-completed"

    EventContext.withEventId(eventId) {
      logStageEvent(
        "STAGE_COMPLETED",
        stageInfo.stageId,
        stageInfo.submissionTime,
        stageInfo.completionTime)
    }
  }

  /** Event id bound by the enclosing callback, or a placeholder when none is bound. */
  private def correlationPrefix(): String =
    EventContext.get().getOrElse("no-correlation")

  private def logApplicationEvent(eventType: String, appId: String): Unit = {
    val correlationId = correlationPrefix()
    println(s"[$correlationId] Application Event: $eventType for app $appId")
  }

  private def logJobEvent(eventType: String, jobId: Int, stageCount: Int): Unit = {
    val correlationId = correlationPrefix()
    println(s"[$correlationId] Job Event: $eventType for job $jobId with $stageCount stages")
  }

  private def logStageEvent(
      eventType: String,
      stageId: Int,
      startTime: Option[Long],
      endTime: Option[Long]): Unit = {
    val correlationId = correlationPrefix()
    // A duration can only be reported when both timestamps are present.
    val duration = (for {
      start <- startTime
      end <- endTime
    } yield s"${end - start}ms").getOrElse("unknown duration")

    println(s"[$correlationId] Stage Event: $eventType for stage $stageId ($duration)")
  }
}
387
```
388
389
## Integration Patterns
390
391
### Cross-Component Context Propagation
392
393
```scala { .api }
394
import org.apache.spark.internal.{Logging, LogKeys, MDC}
395
import org.apache.spark.util.LexicalThreadLocal
396
397
// Unified context for cross-component tracking
398
/**
 * Identifiers describing where in an application the current code is running.
 *
 * `correlationId` has no default although it follows defaulted parameters, so
 * callers that rely on the defaults pass it by name.
 *
 * @param applicationId id of the application
 * @param jobId         id of the current job, if any
 * @param stageId       id of the current stage, if any
 * @param taskId        id of the current task, if any
 * @param userId        id of the acting user, if any
 * @param correlationId id used to correlate related log output
 */
case class SparkExecutionContext(
    applicationId: String,
    jobId: Option[Int] = None,
    stageId: Option[Int] = None,
    taskId: Option[Long] = None,
    userId: Option[String] = None,
    correlationId: String)
406
407
/** Thread local holding the `SparkExecutionContext` of the running code. */
object SparkExecutionContext extends LexicalThreadLocal[SparkExecutionContext] {

  /**
   * Evaluates `f` with `context` bound.
   *
   * @return the result of `f`
   */
  def withContext[T](context: SparkExecutionContext)(f: => T): T =
    createHandle(Some(context)).runWith(f)

  /**
   * Evaluates `f` with a copy of the current context whose job id is `jobId`.
   *
   * @throws IllegalStateException if no context is currently bound
   */
  def withJob[T](jobId: Int)(f: => T): T = {
    val narrowed = activeContext().copy(jobId = Some(jobId))
    createHandle(Some(narrowed)).runWith(f)
  }

  /**
   * Evaluates `f` with a copy of the current context whose stage id is `stageId`.
   *
   * @throws IllegalStateException if no context is currently bound
   */
  def withStage[T](stageId: Int)(f: => T): T = {
    val narrowed = activeContext().copy(stageId = Some(stageId))
    createHandle(Some(narrowed)).runWith(f)
  }

  /** The bound context; fails when called outside any `withContext` scope. */
  private def activeContext(): SparkExecutionContext =
    get() match {
      case Some(bound) => bound
      case None => throw new IllegalStateException("No execution context available")
    }
}
429
430
// Enhanced logging with automatic context injection
431
/**
 * Logging mix-in that appends the identifiers of the current
 * `SparkExecutionContext` to each message as structured (MDC) fields. When no
 * context is bound the message is logged unchanged.
 *
 * NOTE(review): `LogKeys.MESSAGE` and `LogKeys.CORRELATION_ID` must exist in
 * the Spark version in use; confirm, and substitute suitable keys if not.
 */
trait ContextualLogging extends Logging {

  /**
   * Logs `message` at INFO level, with context fields when a context is bound.
   *
   * @param message text to log
   */
  protected def logInfoWithContext(message: String): Unit = {
    SparkExecutionContext.get() match {
      case Some(context) =>
        // Optional ids contribute nothing (an empty fragment) when unset.
        val job = context.jobId
          .map(id => log" job=${MDC(LogKeys.JOB_ID, id)}").getOrElse(log"")
        val stage = context.stageId
          .map(id => log" stage=${MDC(LogKeys.STAGE_ID, id)}").getOrElse(log"")
        val task = context.taskId
          .map(id => log" task=${MDC(LogKeys.TASK_ID, id)}").getOrElse(log"")

        logInfo(
          // The `log` interpolator only accepts MDC arguments, so the plain
          // message string is wrapped in an MDC as well.
          log"${MDC(LogKeys.MESSAGE, message)} " +
            log"app=${MDC(LogKeys.APP_ID, context.applicationId)} " +
            log"correlation=${MDC(LogKeys.CORRELATION_ID, context.correlationId)}" +
            job + stage + task
        )
      case None =>
        logInfo(message)
    }
  }

  /**
   * Logs `message` and `throwable` at ERROR level, with context fields when a
   * context is bound.
   *
   * @param message   text to log
   * @param throwable failure to attach to the log entry
   */
  protected def logErrorWithContext(message: String, throwable: Throwable): Unit = {
    SparkExecutionContext.get() match {
      case Some(context) =>
        logError(
          log"${MDC(LogKeys.MESSAGE, message)} " +
            log"app=${MDC(LogKeys.APP_ID, context.applicationId)} " +
            log"correlation=${MDC(LogKeys.CORRELATION_ID, context.correlationId)} " +
            log"error=${MDC(LogKeys.ERROR_CLASS, throwable.getClass.getSimpleName)}",
          throwable
        )
      case None =>
        logError(message, throwable)
    }
  }
}
464
```
465
466
### Thread Pool Integration
467
468
```scala { .api }
469
import java.util.concurrent.{ExecutorService, Executors, Callable}
470
import scala.concurrent.{Future, ExecutionContext}
471
472
/**
 * `ExecutorService` decorator that captures the submitting thread's
 * `SparkExecutionContext` and rebinds it around every task on the worker
 * thread.
 *
 * Every task-accepting method is wrapped, including `execute`: that is the
 * method `ExecutionContext.fromExecutor` dispatches through, so it is the path
 * taken by Scala `Future`s.
 *
 * @param underlying executor that actually runs the tasks
 */
class ContextPropagatingExecutor(underlying: ExecutorService) extends ExecutorService {
  import java.util.concurrent.TimeUnit

  /** Wraps `task` so it runs with the context captured now, on the calling thread. */
  private def wrapCallable[T](task: Callable[T]): Callable[T] = {
    val contextHandle = SparkExecutionContext.createHandle(SparkExecutionContext.get())
    new Callable[T] {
      override def call(): T = contextHandle.runWith(task.call())
    }
  }

  /** Wraps `task` so it runs with the context captured now, on the calling thread. */
  private def wrapRunnable(task: Runnable): Runnable = {
    val contextHandle = SparkExecutionContext.createHandle(SparkExecutionContext.get())
    new Runnable {
      override def run(): Unit = contextHandle.runWith(task.run())
    }
  }

  /** Wraps each task of a batch, preserving iteration order. */
  private def wrapAll[T](
      tasks: java.util.Collection[_ <: Callable[T]]): java.util.List[Callable[T]] = {
    val wrapped = new java.util.ArrayList[Callable[T]](tasks.size())
    val it = tasks.iterator()
    while (it.hasNext) {
      wrapped.add(wrapCallable[T](it.next()))
    }
    wrapped
  }

  override def execute(command: Runnable): Unit =
    underlying.execute(wrapRunnable(command))

  override def submit[T](task: Callable[T]): java.util.concurrent.Future[T] =
    underlying.submit(wrapCallable(task))

  override def submit[T](task: Runnable, result: T): java.util.concurrent.Future[T] =
    underlying.submit(wrapRunnable(task), result)

  override def submit(task: Runnable): java.util.concurrent.Future[_] =
    underlying.submit(wrapRunnable(task))

  override def invokeAll[T](
      tasks: java.util.Collection[_ <: Callable[T]]
  ): java.util.List[java.util.concurrent.Future[T]] =
    underlying.invokeAll(wrapAll(tasks))

  override def invokeAll[T](
      tasks: java.util.Collection[_ <: Callable[T]],
      timeout: Long,
      unit: TimeUnit
  ): java.util.List[java.util.concurrent.Future[T]] =
    underlying.invokeAll(wrapAll(tasks), timeout, unit)

  override def invokeAny[T](tasks: java.util.Collection[_ <: Callable[T]]): T =
    underlying.invokeAny(wrapAll(tasks))

  override def invokeAny[T](
      tasks: java.util.Collection[_ <: Callable[T]],
      timeout: Long,
      unit: TimeUnit): T =
    underlying.invokeAny(wrapAll(tasks), timeout, unit)

  // Life-cycle methods carry no task, so they delegate unchanged.
  override def shutdown(): Unit = underlying.shutdown()
  override def shutdownNow(): java.util.List[Runnable] = underlying.shutdownNow()
  override def isShutdown: Boolean = underlying.isShutdown
  override def isTerminated: Boolean = underlying.isTerminated
  override def awaitTermination(timeout: Long, unit: TimeUnit): Boolean =
    underlying.awaitTermination(timeout, unit)
}
512
513
// Usage with Scala Futures
514
/** Implicit execution context backed by a four-thread pool wrapped for context propagation. */
implicit val contextPropagatingEC: ExecutionContext = {
  val pool = Executors.newFixedThreadPool(4)
  val propagating = new ContextPropagatingExecutor(pool)
  ExecutionContext.fromExecutor(propagating)
}
516
517
/** Example of context-aware logging from inside a `Future`. */
class AsyncProcessor extends ContextualLogging {

  /**
   * Upper-cases `data` on the implicit execution context, logging before,
   * during and after the work.
   *
   * @param data text to process; only its first 50 characters are logged
   * @return a future holding `data` in upper case
   */
  def processAsync(data: String): Future[String] = {
    val preview = data.take(50)
    logInfoWithContext(s"Starting async processing of data: ${preview}...")

    Future {
      // Relies on the implicit execution context to carry the caller's context here
      logInfoWithContext("Processing in background thread")

      // Simulate processing
      Thread.sleep(1000)

      val upperCased = data.toUpperCase
      logInfoWithContext(s"Async processing completed, result length: ${upperCased.length}")

      upperCased
    }
  }
}
536
```
537
538
## Best Practices
539
540
### Thread Local Management
541
542
```scala { .api }
543
/**
 * Recommended usage patterns for `LexicalThreadLocal`-based contexts.
 *
 * `performOperation` and `performAsyncOperation` are not defined in this
 * snippet; supply your own.
 */
class ThreadLocalBestPractices {

  // 1. Scope values with runWith instead of clearing them by hand.
  /** Runs an operation with a temporary value bound for exactly its duration. */
  def properCleanup(): Unit = {
    val handle = MyExecutionContext.create("temporary-value")

    // The thread local is cleared when runWith completes, so no surrounding
    // try/finally or manual cleanup is needed.
    handle.runWith {
      performOperation()
    }
  }

  // 2. Propagate context explicitly across thread boundaries
  /** Captures the current value and rebinds it inside a task on another thread. */
  def crossThreadPropagation(): Unit = {
    // Capture on the calling thread; the pool thread has no binding of its own.
    val contextHandle = MyExecutionContext.createHandle(MyExecutionContext.get())

    // Fire-and-forget: the resulting Future is intentionally not retained.
    Future {
      contextHandle.runWith {
        // Context is available in the new thread
        performAsyncOperation()
      }
    }(ExecutionContext.global)
    ()
  }

  // 3. Use factory methods for common patterns
  /**
   * Evaluates `f` with a deadline `timeoutMs` milliseconds from now bound in
   * `TimeoutThreadLocal`.
   *
   * @return the result of `f`
   */
  def withTimeout[T](timeoutMs: Long)(f: => T): T = {
    val timeoutContext = TimeoutContext(System.currentTimeMillis() + timeoutMs)
    TimeoutThreadLocal.withTimeout(timeoutContext)(f)
  }

  // 4. Validate context availability before use
  /**
   * Prints the bound value.
   *
   * @throws IllegalStateException if no value is bound
   */
  def safeContextAccess(): Unit = {
    val value = MyExecutionContext.get().getOrElse(
      throw new IllegalStateException("Required execution context not available"))
    println(s"Context available: $value")
  }
}
588
589
// Example timeout context implementation
590
/**
 * Absolute deadline for an operation.
 *
 * @param deadlineMs deadline in epoch milliseconds, compared against
 *                   `System.currentTimeMillis()`
 */
case class TimeoutContext(deadlineMs: Long) {

  /** True once the current time is strictly past the deadline. */
  def isExpired: Boolean = {
    val now = System.currentTimeMillis()
    now > deadlineMs
  }

  /** Milliseconds left until the deadline; never negative. */
  def remainingMs: Long = {
    val left = deadlineMs - System.currentTimeMillis()
    if (left > 0L) left else 0L
  }
}
594
595
/** Thread local carrying the deadline that applies to the running operation. */
object TimeoutThreadLocal extends LexicalThreadLocal[TimeoutContext] {

  /**
   * Evaluates `f` with `timeout` bound as the active deadline.
   *
   * @return the result of `f`
   */
  def withTimeout[T](timeout: TimeoutContext)(f: => T): T = {
    val scoped = createHandle(Some(timeout))
    scoped.runWith(f)
  }

  /**
   * Fails when a deadline is bound and has passed; otherwise does nothing.
   *
   * @throws RuntimeException if the bound deadline has expired
   */
  def checkTimeout(): Unit = {
    val expired = get().exists(_.isExpired)
    if (expired) {
      throw new RuntimeException("Operation timed out")
    }
  }
}
608
```
609
610
### Error Handling with Context
611
612
```scala { .api }
613
import org.apache.spark.SparkException
614
615
/**
 * Example of enriching failures with the correlation id of the current
 * `SparkExecutionContext`.
 */
class ContextAwareErrorHandling extends ContextualLogging {
  // Required for `.asScala` on the Java map returned by getMessageParameters().
  import scala.jdk.CollectionConverters._

  /**
   * Reverses `data`, logging progress and failures with context.
   *
   * @param data text to process; must not be empty
   * @return `data` reversed
   * @throws SparkException if `data` is empty or processing fails; rethrown
   *                        Spark exceptions gain a `correlationId` message
   *                        parameter
   */
  def processWithErrorContext(data: String): String = {
    try {
      logInfoWithContext("Starting data processing")

      // Processing that might fail
      if (data.isEmpty) {
        throw new SparkException(
          "EMPTY_INPUT_DATA",
          Map("operation" -> "data_processing"),
          null // no underlying cause
        )
      }

      val result = data.reverse
      logInfoWithContext("Data processing completed successfully")
      result

    } catch {
      case ex: SparkException =>
        // Enhanced error logging with context
        logErrorWithContext("Data processing failed with Spark exception", ex)

        // Rebuild the exception with the correlation id added to its parameters.
        val contextInfo = SparkExecutionContext.get().map(_.correlationId).getOrElse("unknown")
        val enrichedParameters =
          ex.getMessageParameters().asScala.toMap + ("correlationId" -> contextInfo)
        throw new SparkException(
          ex.getCondition(),
          enrichedParameters,
          ex.getCause
        )

      case ex: Exception =>
        logErrorWithContext("Data processing failed with unexpected exception", ex)
        throw new SparkException("UNEXPECTED_PROCESSING_ERROR", Map.empty, ex)
    }
  }
}
653
```