# Extension Points

Pluggable extension system for custom scheduler services and functionality in YARN deployments. This module provides well-defined interfaces for extending Spark's YARN integration with custom behavior and services.

## Capabilities

### SchedulerExtensionService

Main extension point for implementing custom scheduler services that integrate with YARN scheduler backends.

```scala { .api }
trait SchedulerExtensionService {
  def start(binding: SchedulerExtensionServiceBinding): Unit
  def stop(): Unit
}
```

**Core Methods:**

**`start(binding: SchedulerExtensionServiceBinding): Unit`**
- Called when the scheduler backend starts
- Provides access to scheduler components through the binding
- Should initialize extension-specific resources and services

**`stop(): Unit`**
- Called when the scheduler backend stops
- Should clean up extension resources and stop background services
- Must be idempotent and handle repeated calls gracefully

### SchedulerExtensionServiceBinding

Provides access to core scheduler components for extension implementations.

```scala { .api }
trait SchedulerExtensionServiceBinding {
  def sparkContext(): SparkContext
  def applicationAttemptId(): ApplicationAttemptId
  def applicationId(): ApplicationId
}
```

**Component Access:**
- `sparkContext()`: Access to the SparkContext instance
- `applicationAttemptId()`: YARN application attempt identifier
- `applicationId()`: YARN application identifier

**Implementation Example:**

```scala
import org.apache.spark.scheduler.cluster.{SchedulerExtensionService, SchedulerExtensionServiceBinding}
import org.apache.spark.internal.Logging

class MetricsCollectionExtension extends SchedulerExtensionService with Logging {
  private var metricsCollector: Option[YarnMetricsCollector] = None
  private var binding: SchedulerExtensionServiceBinding = _

  override def start(binding: SchedulerExtensionServiceBinding): Unit = {
    this.binding = binding
    logInfo("Starting metrics collection extension")

    val sc = binding.sparkContext()
    val appId = binding.applicationId().toString

    // Initialize custom metrics collection
    val collector = new YarnMetricsCollector(sc.conf, appId)
    collector.start()

    metricsCollector = Some(collector)
    logInfo(s"Metrics collection started for application: $appId")
  }

  override def stop(): Unit = {
    logInfo("Stopping metrics collection extension")

    metricsCollector.foreach { collector =>
      try {
        collector.stop()
        logInfo("Metrics collection stopped successfully")
      } catch {
        case e: Exception => logError("Error stopping metrics collector", e)
      }
    }

    metricsCollector = None
  }
}
```

### Service Registration

Extensions are registered through Spark's configuration system using the `spark.yarn.services` property.

```scala
// Register a single extension service
val sparkConf = new SparkConf()
  .set("spark.yarn.services", "com.example.MetricsCollectionExtension")

// Register multiple extensions (comma-separated)
val sparkConfMultiple = new SparkConf()
  .set("spark.yarn.services",
    "com.example.MetricsCollectionExtension,com.example.LoggingExtension")
```

**Service Discovery:**
```scala
// Extensions are loaded using reflection during scheduler backend initialization
val serviceClasses = sparkConf.get("spark.yarn.services", "").split(",").filter(_.nonEmpty)

serviceClasses.foreach { className =>
  try {
    val serviceClass = Class.forName(className.trim)
    // Class.newInstance() is deprecated; use the no-arg constructor explicitly
    val service = serviceClass.getDeclaredConstructor().newInstance()
      .asInstanceOf[SchedulerExtensionService]
    service.start(binding)
    registeredServices += service
  } catch {
    case e: Exception => logError(s"Failed to load extension service: $className", e)
  }
}
```

## Common Extension Patterns

### Monitoring and Metrics Extension

```scala
class YarnMetricsExtension extends SchedulerExtensionService with Logging {
  private var metricsReporter: Option[YarnMetricsReporter] = None

  override def start(binding: SchedulerExtensionServiceBinding): Unit = {
    val sc = binding.sparkContext()
    val appId = binding.applicationId()

    // Create custom metrics reporter for YARN-specific metrics
    val reporter = new YarnMetricsReporter(
      sc.conf,
      appId.toString,
      binding.applicationAttemptId().toString
    )

    // Register with Spark's metric system
    sc.env.metricsSystem.registerSource(reporter)

    // Start periodic reporting
    reporter.startReporting()
    metricsReporter = Some(reporter)

    logInfo(s"YARN metrics extension started for app: $appId")
  }

  override def stop(): Unit = {
    metricsReporter.foreach { reporter =>
      reporter.stopReporting()
      logInfo("YARN metrics extension stopped")
    }
    metricsReporter = None
  }
}

class YarnMetricsReporter(conf: SparkConf, appId: String, attemptId: String)
  extends Source with Logging {

  override val sourceName: String = "yarn-metrics"
  override val metricRegistry: MetricRegistry = new MetricRegistry()

  // Register custom YARN metrics
  metricRegistry.register("yarn.containers.requested", new Gauge[Int] {
    override def getValue: Int = getRequestedContainers
  })

  metricRegistry.register("yarn.containers.allocated", new Gauge[Int] {
    override def getValue: Int = getAllocatedContainers
  })

  def startReporting(): Unit = {
    // Initialize periodic metrics collection
    logInfo("Started YARN metrics reporting")
  }

  def stopReporting(): Unit = {
    // Cleanup metrics collection
    logInfo("Stopped YARN metrics reporting")
  }
}
```

### Resource Management Extension

```scala
class DynamicResourceExtension extends SchedulerExtensionService with Logging {
  private var resourceMonitor: Option[ResourceMonitor] = None

  override def start(binding: SchedulerExtensionServiceBinding): Unit = {
    val sc = binding.sparkContext()

    // Custom resource monitoring and adjustment
    val monitor = new ResourceMonitor(sc, binding.applicationId())
    monitor.start()

    resourceMonitor = Some(monitor)
    logInfo("Dynamic resource extension started")
  }

  override def stop(): Unit = {
    resourceMonitor.foreach(_.stop())
    resourceMonitor = None
    logInfo("Dynamic resource extension stopped")
  }
}

class ResourceMonitor(sc: SparkContext, appId: ApplicationId) extends Logging {
  private val monitorThread = new Thread("resource-monitor") {
    override def run(): Unit = {
      while (!Thread.currentThread().isInterrupted) {
        try {
          monitorAndAdjustResources()
          Thread.sleep(30000) // Monitor every 30 seconds
        } catch {
          case _: InterruptedException =>
            Thread.currentThread().interrupt()
          case e: Exception =>
            logError("Error in resource monitoring", e)
        }
      }
    }
  }

  def start(): Unit = {
    monitorThread.setDaemon(true)
    monitorThread.start()
    logInfo("Resource monitor thread started")
  }

  def stop(): Unit = {
    monitorThread.interrupt()
    try {
      monitorThread.join(5000)
    } catch {
      case _: InterruptedException => Thread.currentThread().interrupt()
    }
    logInfo("Resource monitor thread stopped")
  }

  private def monitorAndAdjustResources(): Unit = {
    // Implement custom resource monitoring logic
    val statusTracker = sc.statusTracker
    val executorInfos = statusTracker.getExecutorInfos

    // Analyze executor utilization and queue state
    val utilizationMetrics = analyzeUtilization(executorInfos)

    // Make resource adjustment recommendations
    if (utilizationMetrics.shouldScale) {
      logInfo(s"Recommending resource scaling: ${utilizationMetrics.recommendation}")
      // Could integrate with dynamic allocation or external systems
    }
  }

  private def analyzeUtilization(executors: Array[SparkExecutorInfo]): UtilizationMetrics = {
    // Custom utilization analysis logic
    UtilizationMetrics(shouldScale = false, "No scaling needed")
  }
}

case class UtilizationMetrics(shouldScale: Boolean, recommendation: String)
```

### Logging and Debug Extension

```scala
import java.io.FileWriter
import scala.collection.mutable

class DiagnosticExtension extends SchedulerExtensionService with Logging {
  private var diagnosticCollector: Option[DiagnosticCollector] = None

  override def start(binding: SchedulerExtensionServiceBinding): Unit = {
    val sc = binding.sparkContext()
    val appId = binding.applicationId()

    val collector = new DiagnosticCollector(sc, appId.toString)
    collector.start()

    diagnosticCollector = Some(collector)
    logInfo(s"Diagnostic extension started for application: $appId")
  }

  override def stop(): Unit = {
    diagnosticCollector.foreach { collector =>
      collector.generateReport()
      collector.stop()
    }
    diagnosticCollector = None
    logInfo("Diagnostic extension stopped")
  }
}

class DiagnosticCollector(sc: SparkContext, appId: String) extends Logging {
  private val diagnostics = mutable.ListBuffer[DiagnosticEvent]()

  def start(): Unit = {
    // Start collecting diagnostic information
    sc.addSparkListener(new DiagnosticSparkListener(this))
    logInfo("Diagnostic collection started")
  }

  def stop(): Unit = {
    logInfo("Diagnostic collection stopped")
  }

  def addEvent(event: DiagnosticEvent): Unit = {
    diagnostics.synchronized {
      diagnostics += event
    }
  }

  def generateReport(): Unit = {
    val reportPath = s"/tmp/spark-diagnostics-$appId.json"
    val report = DiagnosticReport(appId, System.currentTimeMillis(), diagnostics.toList)

    // Write diagnostic report
    import scala.util.Using
    Using(new FileWriter(reportPath)) { writer =>
      writer.write(report.toJson)
    }

    logInfo(s"Diagnostic report written to: $reportPath")
  }
}

case class DiagnosticEvent(timestamp: Long, eventType: String, details: Map[String, String]) {
  def toJson: String = {
    val detailsJson = details.map { case (k, v) => s""""$k":"$v"""" }.mkString(",")
    s"""{"timestamp":$timestamp,"eventType":"$eventType","details":{$detailsJson}}"""
  }
}

case class DiagnosticReport(appId: String, timestamp: Long, events: List[DiagnosticEvent]) {
  def toJson: String = {
    // Convert to JSON format
    s"""{"appId":"$appId","timestamp":$timestamp,"events":[${events.map(_.toJson).mkString(",")}]}"""
  }
}
```

### Integration with External Systems

```scala
class ExternalSystemIntegration extends SchedulerExtensionService with Logging {
  private var integration: Option[SystemConnector] = None

  override def start(binding: SchedulerExtensionServiceBinding): Unit = {
    val sc = binding.sparkContext()
    val appId = binding.applicationId()

    // Connect to external monitoring/management system
    val connector = new SystemConnector(sc.conf, appId.toString)
    connector.connect()

    // Register application with external system
    connector.registerApplication(appId.toString, sc.appName)

    integration = Some(connector)
    logInfo(s"External system integration started for: $appId")
  }

  override def stop(): Unit = {
    integration.foreach { connector =>
      connector.unregisterApplication()
      connector.disconnect()
    }
    integration = None
    logInfo("External system integration stopped")
  }
}

class SystemConnector(conf: SparkConf, appId: String) extends Logging {
  private val systemUrl = conf.get("spark.yarn.external.system.url", "")
  private val apiKey = conf.get("spark.yarn.external.system.apikey", "")

  def connect(): Unit = {
    // Establish connection to external system
    logInfo(s"Connecting to external system: $systemUrl")
  }

  def disconnect(): Unit = {
    // Close connection to external system
    logInfo("Disconnected from external system")
  }

  def registerApplication(appId: String, appName: String): Unit = {
    // Register Spark application with external system
    logInfo(s"Registered application $appName ($appId) with external system")
  }

  def unregisterApplication(): Unit = {
    // Unregister application from external system
    logInfo(s"Unregistered application $appId from external system")
  }
}
```

## Extension Configuration

### Configuration Properties

```scala
// Extension-specific configuration
val sparkConf = new SparkConf()
  .set("spark.yarn.services", "com.example.MyExtension")
  .set("spark.yarn.extension.myext.enabled", "true")
  .set("spark.yarn.extension.myext.interval", "30s")
  .set("spark.yarn.extension.myext.endpoint", "http://monitoring-service:8080")
```

### Configuration Access in Extensions

```scala
class ConfigurableExtension extends SchedulerExtensionService {
  override def start(binding: SchedulerExtensionServiceBinding): Unit = {
    val sc = binding.sparkContext()
    val conf = sc.conf

    // Read extension-specific configuration
    val enabled = conf.getBoolean("spark.yarn.extension.myext.enabled", false)
    val interval = conf.getTimeAsMs("spark.yarn.extension.myext.interval", "60s")
    val endpoint = conf.get("spark.yarn.extension.myext.endpoint", "")

    if (enabled) {
      initializeExtension(interval, endpoint)
    }
  }

  override def stop(): Unit = {
    // Cleanup extension
  }

  private def initializeExtension(interval: Long, endpoint: String): Unit = {
    // Initialize with configuration parameters
  }
}
```

## Error Handling and Best Practices

### Exception Handling

```scala
import scala.util.control.NonFatal

class RobustExtension extends SchedulerExtensionService with Logging {
  override def start(binding: SchedulerExtensionServiceBinding): Unit = {
    try {
      // Extension initialization
      initializeExtension(binding)
      logInfo("Extension started successfully")
    } catch {
      case NonFatal(e) =>
        logError("Failed to start extension", e)
        // Don't propagate exceptions that would break scheduler startup
        // Instead, log error and continue with degraded functionality
    }
  }

  override def stop(): Unit = {
    try {
      // Cleanup resources
      cleanupExtension()
      logInfo("Extension stopped successfully")
    } catch {
      case NonFatal(e) =>
        logError("Error during extension cleanup", e)
        // Log but don't propagate cleanup errors
    }
  }

  private def initializeExtension(binding: SchedulerExtensionServiceBinding): Unit = {
    // Initialize extension with proper error handling
  }

  private def cleanupExtension(): Unit = {
    // Cleanup resources with proper error handling
  }
}
```

### Resource Management Best Practices

```scala
class WellManagedExtension extends SchedulerExtensionService with Logging {
  private var executorService: Option[ExecutorService] = None
  private var httpClient: Option[CloseableHttpClient] = None

  override def start(binding: SchedulerExtensionServiceBinding): Unit = {
    // Create thread pool for extension work
    val executor = Executors.newFixedThreadPool(2,
      new ThreadFactoryBuilder()
        .setNameFormat("yarn-extension-%d")
        .setDaemon(true)
        .build())
    executorService = Some(executor)

    // Create HTTP client for external communication
    val client = HttpClients.createDefault()
    httpClient = Some(client)

    logInfo("Extension resources initialized")
  }

  override def stop(): Unit = {
    // Shutdown thread pool
    executorService.foreach { executor =>
      executor.shutdown()
      try {
        if (!executor.awaitTermination(10, TimeUnit.SECONDS)) {
          executor.shutdownNow()
        }
      } catch {
        case _: InterruptedException =>
          executor.shutdownNow()
          Thread.currentThread().interrupt()
      }
    }

    // Close HTTP client
    httpClient.foreach { client =>
      try {
        client.close()
      } catch {
        case e: IOException => logWarning("Error closing HTTP client", e)
      }
    }

    logInfo("Extension resources cleaned up")
  }
}
```

## Testing Extensions

### Unit Testing

```scala
class ExtensionTest extends AnyFunSuite with MockitoSugar {
  test("extension should start and stop cleanly") {
    val mockBinding = mock[SchedulerExtensionServiceBinding]
    val mockSparkContext = mock[SparkContext]
    val mockAppId = mock[ApplicationId]

    when(mockBinding.sparkContext()).thenReturn(mockSparkContext)
    when(mockBinding.applicationId()).thenReturn(mockAppId)
    when(mockSparkContext.conf).thenReturn(new SparkConf())
    when(mockAppId.toString).thenReturn("app-123")

    val extension = new MyExtension()

    // Test start
    extension.start(mockBinding)

    // Test stop
    extension.stop()

    // Verify no exceptions thrown
  }
}
```

### Integration Testing

```scala
class ExtensionIntegrationTest extends SparkFunSuite with LocalSparkContext {
  test("extension integrates with real Spark context") {
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("ExtensionTest")
      .set("spark.yarn.services", "com.example.TestExtension")

    sc = new SparkContext(conf)

    // Verify extension was loaded and started
    // This would require access to extension registry or other verification method
  }
}
```