Pluggable extension system for running custom scheduler services in YARN deployments. This module defines the interfaces through which Spark's YARN integration can be extended with custom behavior and services.

SchedulerExtensionService is the main extension point for implementing custom scheduler services that integrate with YARN scheduler backends.
trait SchedulerExtensionService {
  def start(binding: SchedulerExtensionServiceBinding): Unit
  def stop(): Unit
}

Core Methods:

start(binding): Invoked once when the YARN scheduler backend starts, passing a binding to the running application
stop(): Invoked during shutdown to release any resources the service holds
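A minimal implementation only needs to honor this lifecycle. The sketch below (the NoOpExtension name is ours, purely illustrative) is already loadable; because services are instantiated reflectively, implementations also need a public no-argument constructor.

import org.apache.spark.internal.Logging
import org.apache.spark.scheduler.cluster.{SchedulerExtensionService, SchedulerExtensionServiceBinding}

// Smallest useful extension: log the lifecycle and do nothing else.
class NoOpExtension extends SchedulerExtensionService with Logging {
  override def start(binding: SchedulerExtensionServiceBinding): Unit =
    logInfo(s"Started against application ${binding.applicationId()}")

  override def stop(): Unit =
    logInfo("Extension stopped")
}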
SchedulerExtensionServiceBinding provides extension implementations with access to core scheduler components.
trait SchedulerExtensionServiceBinding {
  def sparkContext(): SparkContext
  def applicationAttemptId(): ApplicationAttemptId
  def applicationId(): ApplicationId
}

Component Access:

sparkContext(): Access to the SparkContext instance
applicationAttemptId(): YARN application attempt identifier
applicationId(): YARN application identifier

Implementation Example:
import org.apache.spark.scheduler.cluster.{SchedulerExtensionService, SchedulerExtensionServiceBinding}
import org.apache.spark.internal.Logging

class MetricsCollectionExtension extends SchedulerExtensionService with Logging {
  private var metricsCollector: Option[YarnMetricsCollector] = None
  private var binding: SchedulerExtensionServiceBinding = _

  override def start(binding: SchedulerExtensionServiceBinding): Unit = {
    this.binding = binding
    logInfo("Starting metrics collection extension")

    val sc = binding.sparkContext()
    val appId = binding.applicationId().toString

    // Initialize custom metrics collection
    // (YarnMetricsCollector is a user-supplied class, not part of Spark)
    val collector = new YarnMetricsCollector(sc.getConf, appId)
    collector.start()
    metricsCollector = Some(collector)

    logInfo(s"Metrics collection started for application: $appId")
  }

  override def stop(): Unit = {
    logInfo("Stopping metrics collection extension")
    metricsCollector.foreach { collector =>
      try {
        collector.stop()
        logInfo("Metrics collection stopped successfully")
      } catch {
        case e: Exception => logError("Error stopping metrics collector", e)
      }
    }
    metricsCollector = None
  }
}

Extensions are registered through Spark's configuration system using the spark.yarn.services property.
// Register a single extension service
val sparkConf = new SparkConf()
  .set("spark.yarn.services", "com.example.MetricsCollectionExtension")

// Register multiple extensions (comma-separated)
val multiServiceConf = new SparkConf()
  .set("spark.yarn.services",
    "com.example.MetricsCollectionExtension,com.example.LoggingExtension")

Service Discovery:
// Extensions are loaded via reflection during scheduler backend initialization
val serviceClasses = sparkConf.get("spark.yarn.services", "")
  .split(",").map(_.trim).filter(_.nonEmpty)
serviceClasses.foreach { className =>
  try {
    val serviceClass = Class.forName(className)
    val service = serviceClass.getDeclaredConstructor().newInstance()
      .asInstanceOf[SchedulerExtensionService]
    service.start(binding)
    registeredServices += service
  } catch {
    case e: Exception => logError(s"Failed to load extension service: $className", e)
  }
}

Metrics Integration Example:

class YarnMetricsExtension extends SchedulerExtensionService with Logging {
  private var metricsReporter: Option[YarnMetricsReporter] = None

  override def start(binding: SchedulerExtensionServiceBinding): Unit = {
    val sc = binding.sparkContext()
    val appId = binding.applicationId()

    // Create custom metrics reporter for YARN-specific metrics
    val reporter = new YarnMetricsReporter(
      sc.getConf,
      appId.toString,
      binding.applicationAttemptId().toString
    )

    // Register with Spark's metrics system
    // NOTE: SparkEnv and the Source trait are private[spark]; code like
    // this typically has to live in an org.apache.spark package (or use a shim)
    sc.env.metricsSystem.registerSource(reporter)

    // Start periodic reporting
    reporter.startReporting()
    metricsReporter = Some(reporter)

    logInfo(s"YARN metrics extension started for app: $appId")
  }

  override def stop(): Unit = {
    metricsReporter.foreach { reporter =>
      reporter.stopReporting()
      logInfo("YARN metrics extension stopped")
    }
    metricsReporter = None
  }
}
The reporter itself implements Spark's metrics Source:

import com.codahale.metrics.{Gauge, MetricRegistry}
import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.metrics.source.Source

class YarnMetricsReporter(conf: SparkConf, appId: String, attemptId: String)
  extends Source with Logging {

  override val sourceName: String = "yarn-metrics"
  override val metricRegistry: MetricRegistry = new MetricRegistry()

  // Register custom YARN metrics
  metricRegistry.register("yarn.containers.requested", new Gauge[Int] {
    override def getValue: Int = getRequestedContainers
  })
  metricRegistry.register("yarn.containers.allocated", new Gauge[Int] {
    override def getValue: Int = getAllocatedContainers
  })

  def startReporting(): Unit = {
    // Initialize periodic metrics collection
    logInfo("Started YARN metrics reporting")
  }

  def stopReporting(): Unit = {
    // Cleanup metrics collection
    logInfo("Stopped YARN metrics reporting")
  }
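  // Hypothetical accessors referenced by the gauges above; a real
  // extension would track these from YARN allocator state or the
  // ResourceManager API. Stubbed here so the example compiles.
  private def getRequestedContainers: Int = 0
  private def getAllocatedContainers: Int = 0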
}

Dynamic Resource Monitoring Example:

class DynamicResourceExtension extends SchedulerExtensionService with Logging {
  private var resourceMonitor: Option[ResourceMonitor] = None

  override def start(binding: SchedulerExtensionServiceBinding): Unit = {
    val sc = binding.sparkContext()

    // Custom resource monitoring and adjustment
    val monitor = new ResourceMonitor(sc, binding.applicationId())
    monitor.start()
    resourceMonitor = Some(monitor)
    logInfo("Dynamic resource extension started")
  }

  override def stop(): Unit = {
    resourceMonitor.foreach(_.stop())
    resourceMonitor = None
    logInfo("Dynamic resource extension stopped")
  }
}
import org.apache.hadoop.yarn.api.records.ApplicationId
import org.apache.spark.{SparkContext, SparkExecutorInfo}
import org.apache.spark.internal.Logging

class ResourceMonitor(sc: SparkContext, appId: ApplicationId) extends Logging {
  private val monitorThread = new Thread("resource-monitor") {
    override def run(): Unit = {
      while (!Thread.currentThread().isInterrupted) {
        try {
          monitorAndAdjustResources()
          Thread.sleep(30000) // Monitor every 30 seconds
        } catch {
          case _: InterruptedException =>
            Thread.currentThread().interrupt()
          case e: Exception =>
            logError("Error in resource monitoring", e)
        }
      }
    }
  }

  def start(): Unit = {
    monitorThread.setDaemon(true)
    monitorThread.start()
    logInfo("Resource monitor thread started")
  }

  def stop(): Unit = {
    monitorThread.interrupt()
    try {
      monitorThread.join(5000)
    } catch {
      case _: InterruptedException => Thread.currentThread().interrupt()
    }
    logInfo("Resource monitor thread stopped")
  }

  private def monitorAndAdjustResources(): Unit = {
    // Implement custom resource monitoring logic
    val statusTracker = sc.statusTracker
    val executorInfos = statusTracker.getExecutorInfos

    // Analyze executor utilization and queue state
    val utilizationMetrics = analyzeUtilization(executorInfos)

    // Make resource adjustment recommendations
    if (utilizationMetrics.shouldScale) {
      logInfo(s"Recommending resource scaling: ${utilizationMetrics.recommendation}")
      // Could integrate with dynamic allocation or external systems
    }
  }

  private def analyzeUtilization(executors: Array[SparkExecutorInfo]): UtilizationMetrics = {
    // Custom utilization analysis logic
    UtilizationMetrics(shouldScale = false, "No scaling needed")
  }
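  // One illustrative heuristic that analyzeUtilization could apply
  // (an assumption for illustration, not a recommended policy): flag for
  // scale-down when a majority of executors report no running tasks.
  private def idleMajority(executors: Array[SparkExecutorInfo]): UtilizationMetrics = {
    val idle = executors.count(_.numRunningTasks() == 0)
    if (executors.nonEmpty && idle > executors.length / 2) {
      UtilizationMetrics(shouldScale = true, s"$idle of ${executors.length} executors idle")
    } else {
      UtilizationMetrics(shouldScale = false, "No scaling needed")
    }
  }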
}
case class UtilizationMetrics(shouldScale: Boolean, recommendation: String)

Diagnostics Example:

class DiagnosticExtension extends SchedulerExtensionService with Logging {
  private var diagnosticCollector: Option[DiagnosticCollector] = None

  override def start(binding: SchedulerExtensionServiceBinding): Unit = {
    val sc = binding.sparkContext()
    val appId = binding.applicationId()

    val collector = new DiagnosticCollector(sc, appId.toString)
    collector.start()
    diagnosticCollector = Some(collector)
    logInfo(s"Diagnostic extension started for application: $appId")
  }

  override def stop(): Unit = {
    diagnosticCollector.foreach { collector =>
      collector.generateReport()
      collector.stop()
    }
    diagnosticCollector = None
    logInfo("Diagnostic extension stopped")
  }
}
import java.io.FileWriter
import scala.collection.mutable
import scala.util.Using

class DiagnosticCollector(sc: SparkContext, appId: String) extends Logging {
  private val diagnostics = mutable.ListBuffer[DiagnosticEvent]()
  private var listener: Option[DiagnosticSparkListener] = None

  def start(): Unit = {
    // Start collecting diagnostic information
    val diagnosticListener = new DiagnosticSparkListener(this)
    sc.addSparkListener(diagnosticListener)
    listener = Some(diagnosticListener)
    logInfo("Diagnostic collection started")
  }

  def stop(): Unit = {
    // Detach the listener so no further events arrive after shutdown
    listener.foreach(sc.removeSparkListener)
    listener = None
    logInfo("Diagnostic collection stopped")
  }

  def addEvent(event: DiagnosticEvent): Unit = {
    diagnostics.synchronized {
      diagnostics += event
    }
  }

  def generateReport(): Unit = {
    val reportPath = s"/tmp/spark-diagnostics-$appId.json"
    val report = DiagnosticReport(appId, System.currentTimeMillis(),
      diagnostics.synchronized(diagnostics.toList))

    // Write diagnostic report
    Using(new FileWriter(reportPath)) { writer =>
      writer.write(report.toJson)
    }
    logInfo(s"Diagnostic report written to: $reportPath")
  }
}
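DiagnosticSparkListener is user code referenced above; a minimal sketch that forwards two scheduler events to the collector (the choice of events is just an example):

import org.apache.spark.scheduler.{SparkListener, SparkListenerExecutorRemoved, SparkListenerJobEnd}

class DiagnosticSparkListener(collector: DiagnosticCollector) extends SparkListener {
  override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
    collector.addEvent(DiagnosticEvent(
      System.currentTimeMillis(), "jobEnd", Map("jobId" -> jobEnd.jobId.toString)))
  }

  override def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved): Unit = {
    collector.addEvent(DiagnosticEvent(
      System.currentTimeMillis(), "executorRemoved",
      Map("executorId" -> executorRemoved.executorId, "reason" -> executorRemoved.reason)))
  }
}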
case class DiagnosticEvent(timestamp: Long, eventType: String, details: Map[String, String]) {
  def toJson: String = {
    val detailsJson = details.map { case (k, v) => s""""$k":"$v"""" }.mkString(",")
    s"""{"timestamp":$timestamp,"eventType":"$eventType","details":{$detailsJson}}"""
  }
}
case class DiagnosticReport(appId: String, timestamp: Long, events: List[DiagnosticEvent]) {
  def toJson: String = {
    // Convert to JSON format
    s"""{"appId":"$appId","timestamp":$timestamp,"events":[${events.map(_.toJson).mkString(",")}]}"""
  }
}

External System Integration Example:

class ExternalSystemIntegration extends SchedulerExtensionService with Logging {
  private var integration: Option[SystemConnector] = None

  override def start(binding: SchedulerExtensionServiceBinding): Unit = {
    val sc = binding.sparkContext()
    val appId = binding.applicationId()

    // Connect to external monitoring/management system
    val connector = new SystemConnector(sc.getConf, appId.toString)
    connector.connect()

    // Register application with external system
    connector.registerApplication(appId.toString, sc.appName)
    integration = Some(connector)
    logInfo(s"External system integration started for: $appId")
  }

  override def stop(): Unit = {
    integration.foreach { connector =>
      connector.unregisterApplication()
      connector.disconnect()
    }
    integration = None
    logInfo("External system integration stopped")
  }
}
class SystemConnector(conf: SparkConf, appId: String) extends Logging {
  private val systemUrl = conf.get("spark.yarn.external.system.url", "")
  private val apiKey = conf.get("spark.yarn.external.system.apikey", "")

  def connect(): Unit = {
    // Establish connection to external system
    logInfo(s"Connecting to external system: $systemUrl")
  }

  def disconnect(): Unit = {
    // Close connection to external system
    logInfo("Disconnected from external system")
  }

  def registerApplication(appId: String, appName: String): Unit = {
    // Register Spark application with external system
    logInfo(s"Registered application $appName ($appId) with external system")
  }

  def unregisterApplication(): Unit = {
    // Unregister application from external system
    logInfo(s"Unregistered application $appId from external system")
  }
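  // Illustrative only: a minimal JSON POST helper the stubs above could
  // call, using the JDK's HttpURLConnection to avoid extra dependencies.
  // The bearer-token header is an assumption; a production connector
  // would also add retries, timeouts, and TLS configuration.
  private def postJson(path: String, body: String): Unit = {
    import java.net.{HttpURLConnection, URL}
    import java.nio.charset.StandardCharsets
    val conn = new URL(s"$systemUrl$path").openConnection().asInstanceOf[HttpURLConnection]
    try {
      conn.setRequestMethod("POST")
      conn.setRequestProperty("Content-Type", "application/json")
      conn.setRequestProperty("Authorization", s"Bearer $apiKey")
      conn.setDoOutput(true)
      val out = conn.getOutputStream
      try out.write(body.getBytes(StandardCharsets.UTF_8)) finally out.close()
      logInfo(s"External system responded: ${conn.getResponseCode}")
    } finally {
      conn.disconnect()
    }
  }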
}

Extension Configuration:

// Extension-specific configuration
val sparkConf = new SparkConf()
  .set("spark.yarn.services", "com.example.MyExtension")
  .set("spark.yarn.extension.myext.enabled", "true")
  .set("spark.yarn.extension.myext.interval", "30s")
  .set("spark.yarn.extension.myext.endpoint", "http://monitoring-service:8080")

Extensions read these values back through the SparkContext exposed by the binding:

class ConfigurableExtension extends SchedulerExtensionService {
  override def start(binding: SchedulerExtensionServiceBinding): Unit = {
    val sc = binding.sparkContext()
    val conf = sc.getConf

    // Read extension-specific configuration
    val enabled = conf.getBoolean("spark.yarn.extension.myext.enabled", false)
    val interval = conf.getTimeAsMs("spark.yarn.extension.myext.interval", "60s")
    val endpoint = conf.get("spark.yarn.extension.myext.endpoint", "")

    if (enabled) {
      initializeExtension(interval, endpoint)
    }
  }

  override def stop(): Unit = {
    // Cleanup extension
  }

  private def initializeExtension(interval: Long, endpoint: String): Unit = {
    // Initialize with configuration parameters
  }
}

Error Handling:

Extensions should contain their own failures so that a misbehaving service cannot break scheduler startup or shutdown:

class RobustExtension extends SchedulerExtensionService with Logging {
  override def start(binding: SchedulerExtensionServiceBinding): Unit = {
    try {
      // Extension initialization
      initializeExtension(binding)
      logInfo("Extension started successfully")
    } catch {
      case e: Exception =>
        logError("Failed to start extension", e)
        // Don't propagate exceptions that would break scheduler startup.
        // Instead, log the error and continue with degraded functionality.
    }
  }

  override def stop(): Unit = {
    try {
      // Cleanup resources
      cleanupExtension()
      logInfo("Extension stopped successfully")
    } catch {
      case e: Exception =>
        logError("Error during extension cleanup", e)
        // Log but don't propagate cleanup errors
    }
  }

  private def initializeExtension(binding: SchedulerExtensionServiceBinding): Unit = {
    // Initialize extension with proper error handling
  }

  private def cleanupExtension(): Unit = {
    // Cleanup resources with proper error handling
  }
}

Resource Management:

Long-lived resources (thread pools, clients) should be created in start() and released deterministically in stop():

import java.io.IOException
import java.util.concurrent.{ExecutorService, Executors, TimeUnit}
import com.google.common.util.concurrent.ThreadFactoryBuilder
import org.apache.http.impl.client.{CloseableHttpClient, HttpClients}

class WellManagedExtension extends SchedulerExtensionService with Logging {
  private var executorService: Option[ExecutorService] = None
  private var httpClient: Option[CloseableHttpClient] = None

  override def start(binding: SchedulerExtensionServiceBinding): Unit = {
    // Create thread pool for extension work
    val executor = Executors.newFixedThreadPool(2,
      new ThreadFactoryBuilder()
        .setNameFormat("yarn-extension-%d")
        .setDaemon(true)
        .build())
    executorService = Some(executor)

    // Create HTTP client for external communication
    val client = HttpClients.createDefault()
    httpClient = Some(client)

    logInfo("Extension resources initialized")
  }

  override def stop(): Unit = {
    // Shutdown thread pool
    executorService.foreach { executor =>
      executor.shutdown()
      try {
        if (!executor.awaitTermination(10, TimeUnit.SECONDS)) {
          executor.shutdownNow()
        }
      } catch {
        case _: InterruptedException =>
          executor.shutdownNow()
          Thread.currentThread().interrupt()
      }
    }

    // Close HTTP client
    httpClient.foreach { client =>
      try {
        client.close()
      } catch {
        case e: IOException => logWarning("Error closing HTTP client", e)
      }
    }
    logInfo("Extension resources cleaned up")
  }
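  // Hypothetical helper showing how the pool would be used between
  // start() and stop(): run work asynchronously only while started.
  private def submitWork(task: Runnable): Unit = {
    executorService match {
      case Some(executor) if !executor.isShutdown => executor.execute(task)
      case _ => logWarning("Extension not started; dropping task")
    }
  }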
}

Unit Testing:

import org.apache.hadoop.yarn.api.records.ApplicationId
import org.mockito.Mockito.when
import org.scalatest.funsuite.AnyFunSuite
import org.scalatestplus.mockito.MockitoSugar
import org.apache.spark.{SparkConf, SparkContext}

class ExtensionTest extends AnyFunSuite with MockitoSugar {
test("extension should start and stop cleanly") {
val mockBinding = mock[SchedulerExtensionServiceBinding]
val mockSparkContext = mock[SparkContext]
val mockAppId = mock[ApplicationId]
when(mockBinding.sparkContext()).thenReturn(mockSparkContext)
when(mockBinding.applicationId()).thenReturn(mockAppId)
when(mockSparkContext.conf).thenReturn(new SparkConf())
when(mockAppId.toString).thenReturn("app-123")
val extension = new MyExtension()
// Test start
extension.start(mockBinding)
// Test stop
extension.stop()
// Verify no exceptions thrown
}
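  test("stop should be safe to call repeatedly") {
    // Assumption for illustration: extensions are expected to tolerate
    // redundant or out-of-order lifecycle calls without throwing.
    val extension = new MyExtension()
    extension.stop()
    extension.stop()
  }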
}

Integration Testing:

class ExtensionIntegrationTest extends SparkFunSuite with LocalSparkContext {
test("extension integrates with real Spark context") {
val conf = new SparkConf()
.setMaster("local[2]")
.setAppName("ExtensionTest")
.set("spark.yarn.services", "com.example.TestExtension")
sc = new SparkContext(conf)
// Verify extension was loaded and started
// This would require access to extension registry or other verification method
}
}