Apache Spark SQL - User-Defined Functions (UDFs)

Capabilities

User-Defined Function Registration and Management

  • Register scalar user-defined functions (UDFs) with support for 0-22 parameters and type-safe operation
  • Create and manage user-defined aggregate functions (UDAFs) for custom aggregation logic
  • Support for both temporary session-scoped and persistent catalog-scoped function registration (see the sketch after this list)
  • Handle function overloading and parameter type checking with comprehensive error reporting
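
A minimal sketch of the two registration scopes, assuming a running SparkSession named spark. The class name and JAR path in the persistent example are placeholders only; a catalog-scoped function requires a UDF class built against a supported interface (for example a Hive UDF) and packaged in a JAR:

// Temporary, session-scoped: visible only in the current SparkSession
spark.udf.register("squared", (x: Int) => x * x)

// Persistent, catalog-scoped: recorded in the metastore via SQL
spark.sql("""
  CREATE FUNCTION analytics.squared AS 'com.example.udf.Squared'
  USING JAR 'hdfs:///libs/example-udfs.jar'
""")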

Type-Safe Function Development

  • Define functions with explicit input and output types using Spark's type system and encoders
  • Support for complex data types including arrays, maps, structs, and user-defined types in function signatures
  • Enable null handling and error propagation with configurable behavior for edge cases
  • Provide compile-time type checking for function parameters and return values

Function Optimization and Performance

  • Control function determinism flags for query optimization and caching behavior
  • Support for code generation in the surrounding query plan, though UDF bodies execute as opaque, non-vectorized functions
  • Handle broadcast variables and accumulators within UDF context for distributed computation
  • Note that UDF-wrapped expressions are opaque to the optimizer, so predicate pushdown generally applies only to non-UDF predicates

Advanced Function Features

  • Create column-oriented logic that operates on whole array and map columns via higher-order functions
  • Support for higher-order functions and functional programming patterns within SQL expressions
  • Handle state management and context passing for complex stateful function implementations
  • Enable function composition and chaining for building complex transformation pipelines (see the sketch below)
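
A minimal sketch of composing UDFs by feeding the output Column of one into the next, assuming a DataFrame df with a string column name and the usual udf import plus spark.implicits._ in scope:

val normalize = udf((s: String) => if (s == null) null else s.trim.toLowerCase)
val initials  = udf((s: String) => if (s == null) null else s.split(" ").filter(_.nonEmpty).map(_.head.toUpper).mkString)

// Chained application: normalize first, then derive initials from the normalized name
val withInitials = df.select($"name", initials(normalize($"name")).as("initials"))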

API Reference

UDFRegistration Class

abstract class UDFRegistration {
  // Scalar UDF registration (0-22 parameters)
  def register[RT: TypeTag](name: String, func: () => RT): UserDefinedFunction
  def register[RT: TypeTag, A1: TypeTag](name: String, func: A1 => RT): UserDefinedFunction
  def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](name: String, func: (A1, A2) => RT): UserDefinedFunction
  def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](name: String, func: (A1, A2, A3) => RT): UserDefinedFunction
  def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](name: String, func: (A1, A2, A3, A4) => RT): UserDefinedFunction
  def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](name: String, func: (A1, A2, A3, A4, A5) => RT): UserDefinedFunction
  def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](name: String, func: (A1, A2, A3, A4, A5, A6) => RT): UserDefinedFunction
  def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](name: String, func: (A1, A2, A3, A4, A5, A6, A7) => RT): UserDefinedFunction
  def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](name: String, func: (A1, A2, A3, A4, A5, A6, A7, A8) => RT): UserDefinedFunction
  def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](name: String, func: (A1, A2, A3, A4, A5, A6, A7, A8, A9) => RT): UserDefinedFunction
  def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](name: String, func: (A1, A2, A3, A4, A5, A6, A7, A8, A9, A10) => RT): UserDefinedFunction
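  // ... overloads continue in the same pattern up to 22 parameters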
  
  // Register with custom UserDefinedFunction
  def register(name: String, udf: UserDefinedFunction): UserDefinedFunction
  
  // UDAF registration
  def register(name: String, udaf: UserDefinedAggregateFunction): UserDefinedAggregateFunction
}
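
A short usage sketch of the registration API, assuming a SparkSession named spark and org.apache.spark.sql.functions.udf in scope; the DataFrame df in the last comment is hypothetical:

// Register a lambda by name; the returned handle is also usable on columns
val doubleIt = spark.udf.register("double_it", (x: Int) => x * 2)

// Register an already-built UserDefinedFunction under a SQL-visible name
spark.udf.register("upper_it", udf((s: String) => s.toUpperCase))

spark.sql("SELECT double_it(21) AS answer").show()
// DataFrame API equivalent: df.select(doubleIt($"x"))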

UserDefinedFunction Class

case class UserDefinedFunction protected[sql] (
    f: AnyRef,
    dataType: DataType,
    inputTypes: Seq[DataType],
    name: Option[String] = None,
    nullable: Boolean = true,
    deterministic: Boolean = true) {
  
  // Function application
  def apply(exprs: Column*): Column
  
  // Configuration methods
  def withName(name: String): UserDefinedFunction
  def asNonNullable(): UserDefinedFunction
  def asNondeterministic(): UserDefinedFunction
  
  // Nullability and determinism are exposed through the case class fields
  // nullable and deterministic declared above
}
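
A minimal sketch of configuring a UserDefinedFunction after creation; the DataFrame in the final comment is hypothetical:

val normalizeName = udf((s: String) => s.trim.toLowerCase)
  .withName("normalize_name") // name shown in query plans and error messages
  .asNonNullable()            // declares the output is never null (the function body must uphold this)

// df.select(normalizeName($"name"))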

functions Object: UDF Creation

object functions {
  // UDF creation functions (0-22 parameters)
  def udf[RT: TypeTag](f: () => RT): UserDefinedFunction
  def udf[RT: TypeTag, A1: TypeTag](f: A1 => RT): UserDefinedFunction
  def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: (A1, A2) => RT): UserDefinedFunction
  def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](f: (A1, A2, A3) => RT): UserDefinedFunction
  def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](f: (A1, A2, A3, A4) => RT): UserDefinedFunction
  def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](f: (A1, A2, A3, A4, A5) => RT): UserDefinedFunction
  def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](f: (A1, A2, A3, A4, A5, A6) => RT): UserDefinedFunction
  def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](f: (A1, A2, A3, A4, A5, A6, A7) => RT): UserDefinedFunction
  def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](f: (A1, A2, A3, A4, A5, A6, A7, A8) => RT): UserDefinedFunction
  def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](f: (A1, A2, A3, A4, A5, A6, A7, A8, A9) => RT): UserDefinedFunction
  def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](f: (A1, A2, A3, A4, A5, A6, A7, A8, A9, A10) => RT): UserDefinedFunction
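  // ... overloads continue in the same pattern up to 22 parameters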
  
  // Call registered UDF
  def callUDF(udfName: String, cols: Column*): Column
}
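
A short sketch (assuming a SparkSession named spark, spark.implicits._ imported, and functions.udf / functions.callUDF in scope) of creating a UDF and then invoking a registered UDF by name:

val squareLong = udf((x: Long) => x * x)
spark.udf.register("square_long", squareLong)

// Invoke by name when only the registered name is known
spark.range(5).select(callUDF("square_long", $"id").as("id_squared")).show()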

User-Defined Aggregate Functions (UDAFs)

// Base class for typed UDAFs
abstract class Aggregator[-IN, BUF, OUT] extends Serializable {
  // Buffer operations
  def zero: BUF
  def reduce(buffer: BUF, input: IN): BUF
  def merge(buffer1: BUF, buffer2: BUF): BUF
  def finish(buffer: BUF): OUT
  
  // Encoders
  def bufferEncoder: Encoder[BUF]
  def outputEncoder: Encoder[OUT]
  
  // Convert to UDAF
  def toColumn: TypedColumn[IN, OUT]
}
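
// Minimal typed-usage sketch (illustrative): an Aggregator becomes a TypedColumn
// via toColumn and is applied with a typed Dataset's select.
// Encoder and Encoders come from org.apache.spark.sql.
object LongSum extends Aggregator[Long, Long, Long] {
  def zero: Long = 0L
  def reduce(b: Long, a: Long): Long = b + a
  def merge(b1: Long, b2: Long): Long = b1 + b2
  def finish(b: Long): Long = b
  def bufferEncoder: Encoder[Long] = Encoders.scalaLong
  def outputEncoder: Encoder[Long] = Encoders.scalaLong
}
// e.g. spark.range(100).as[Long].select(LongSum.toColumn) yields a Dataset[Long] containing the sum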

// Untyped UDAF base class (deprecated since Spark 3.0 in favor of Aggregator + functions.udaf)
abstract class UserDefinedAggregateFunction extends Serializable {
  // Schemas and data types
  def inputSchema: StructType
  def bufferSchema: StructType
  def dataType: DataType
  def deterministic: Boolean
  
  // Aggregation operations
  def initialize(buffer: MutableAggregationBuffer): Unit
  def update(buffer: MutableAggregationBuffer, input: Row): Unit
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit
  def evaluate(buffer: Row): Any
}
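
A minimal sketch of the legacy untyped API above: an arithmetic-mean UDAF. This API is deprecated since Spark 3.0 in favor of Aggregator combined with functions.udaf, but remains useful for reading older code:

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class SimpleMean extends UserDefinedAggregateFunction {
  def inputSchema: StructType = StructType(StructField("value", DoubleType) :: Nil)
  def bufferSchema: StructType = StructType(
    StructField("sum", DoubleType) :: StructField("count", LongType) :: Nil)
  def dataType: DataType = DoubleType
  def deterministic: Boolean = true

  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0.0   // running sum
    buffer(1) = 0L    // running count
  }

  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getDouble(0) + input.getDouble(0)
      buffer(1) = buffer.getLong(1) + 1L
    }
  }

  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  def evaluate(buffer: Row): Any =
    if (buffer.getLong(1) == 0L) null else buffer.getDouble(0) / buffer.getLong(1)
}

// spark.udf.register("simple_mean", new SimpleMean)  // then: SELECT simple_mean(salary) FROM employees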

Higher-Order Column Functions

Spark does not add lambda-accepting methods to the Column class itself. Higher-order functions that take Scala lambdas over Column values are exposed on the functions object and operate on array and map columns:

// Higher-order functions
object functions {
  def transform(column: Column, function: Column => Column): Column
  def filter(column: Column, function: Column => Column): Column  
  def exists(column: Column, function: Column => Column): Column
  def forall(column: Column, function: Column => Column): Column
  def aggregate(column: Column, initialValue: Column, merge: (Column, Column) => Column): Column
  def aggregate(column: Column, initialValue: Column, merge: (Column, Column) => Column, finish: Column => Column): Column
  def array_sort(e: Column, comparator: (Column, Column) => Column): Column
}
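
A small sketch of array_sort with a custom comparator, sorting an array column in descending order. The Scala comparator overload is only available in newer Spark releases; the comparator must return a negative, zero, or positive integer Column, built here with when/otherwise. The DataFrame df and its array column scores are hypothetical:

val sortedDesc = array_sort(
  $"scores",
  (left, right) => when(left < right, 1).when(left > right, -1).otherwise(0)
)
// df.select($"scores", sortedDesc.as("scores_desc"))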

Usage Examples

Basic UDF Creation and Registration

import org.apache.spark.sql.{SparkSession, functions => F}
import org.apache.spark.sql.functions.{udf, lit}
import org.apache.spark.sql.types._

val spark = SparkSession.builder()
  .appName("UDF Examples")
  .getOrCreate()

import spark.implicits._

// Sample data
val employeeData = Seq(
  ("Alice Johnson", "alice.johnson@company.com", 25, 75000.0),
  ("Bob Smith", "bob.smith@company.com", 30, 85000.0),
  ("Charlie Brown", "charlie.brown@company.com", 35, 95000.0)
).toDF("name", "email", "age", "salary")

// Simple UDF examples
val extractFirstName = udf((fullName: String) => {
  if (fullName != null && fullName.contains(" ")) {
    fullName.split(" ")(0)
  } else {
    fullName
  }
})

val calculateBonus = udf((salary: Double, age: Int) => {
  val ageMultiplier = if (age > 30) 0.15 else 0.10
  salary * ageMultiplier
})

val isValidEmail = udf((email: String) => {
  email != null && email.contains("@") && email.contains(".")
})

// Apply UDFs to DataFrame
val enrichedData = employeeData.select(
  $"name",
  extractFirstName($"name").as("first_name"),
  $"email",
  isValidEmail($"email").as("valid_email"),
  $"age",
  $"salary", 
  calculateBonus($"salary", $"age").as("bonus")
)

enrichedData.show()

// Register UDFs for SQL usage
spark.udf.register("extract_first_name", extractFirstName)
spark.udf.register("calculate_bonus", calculateBonus)
spark.udf.register("is_valid_email", isValidEmail)

// Use registered UDFs in SQL
employeeData.createOrReplaceTempView("employees")

val sqlResult = spark.sql("""
  SELECT 
    name,
    extract_first_name(name) as first_name,
    email,
    is_valid_email(email) as valid_email,
    age,
    salary,
    calculate_bonus(salary, age) as bonus
  FROM employees
""")

sqlResult.show()

Advanced UDF Features


// UDF with complex return type
case class PersonInfo(firstName: String, domain: String, salaryGrade: String)

val extractPersonInfo = udf((name: String, email: String, salary: Double) => {
  val firstName = if (name != null && name.contains(" ")) {
    name.split(" ")(0)
  } else {
    name
  }
  
  val domain = if (email != null && email.contains("@")) {
    email.split("@")(1)
  } else {
    "unknown"
  }
  
  val salaryGrade = salary match {
    case s if s < 50000 => "Entry"
    case s if s < 80000 => "Mid"
    case s if s < 100000 => "Senior"
    case _ => "Executive"
  }
  
  PersonInfo(firstName, domain, salaryGrade)
})

// Apply complex UDF
val personInfoData = employeeData.select(
  $"name",
  $"email", 
  $"salary",
  extractPersonInfo($"name", $"email", $"salary").as("person_info")
).select(
  $"name",
  $"email",
  $"salary",
  $"person_info.firstName",
  $"person_info.domain", 
  $"person_info.salaryGrade"
)

personInfoData.show()

// UDF with array input and output
val extractNameParts = udf((fullName: String) => {
  if (fullName != null) {
    fullName.split(" ").toSeq
  } else {
    Seq.empty[String]
  }
})

val joinNameParts = udf((parts: Seq[String]) => {
  if (parts != null && parts.nonEmpty) {
    parts.mkString(" ")
  } else {
    null
  }
})

// UDF with Map input/output (values kept as strings so the map has an encodable value type)
val parseEmailInfo = udf((email: String) => {
  if (email != null && email.contains("@")) {
    val parts = email.split("@")
    Map(
      "username" -> parts(0),
      "domain" -> parts(1),
      "is_company_email" -> parts(1).contains("company").toString
    )
  } else {
    Map.empty[String, String]
  }
})

val arrayMapData = employeeData.select(
  $"name",
  $"email",
  extractNameParts($"name").as("name_parts"),
  parseEmailInfo($"email").as("email_info")
)

arrayMapData.show(truncate = false)

Null Handling and Error Management

// UDF with null handling
val safeDivide = udf((numerator: Double, denominator: Double) => {
  if (denominator != 0.0) {
    Some(numerator / denominator)
  } else {
    None
  }
})

// UDF with error handling
val parseAge = udf((ageStr: String) => {
  try {
    if (ageStr != null && ageStr.nonEmpty) {
      Some(ageStr.toInt)
    } else {
      None
    }
  } catch {
    case _: NumberFormatException => None
  }
})

// UDF with validation
val validateSalary = udf((salary: Double) => {
  salary match {
    case s if s < 0 => ("invalid", "Salary cannot be negative")
    case s if s > 1000000 => ("warning", "Salary seems unusually high")
    case s if s < 20000 => ("warning", "Salary seems unusually low")
    case _ => ("valid", "Salary is within normal range")
  }
})

// Test data with nulls and errors
val testData = Seq(
  ("John", "25", 50000.0),
  ("Jane", "abc", 75000.0), // Invalid age
  ("Bob", null, -1000.0), // Null age, negative salary
  ("Alice", "30", 2000000.0) // High salary
).toDF("name", "age_str", "salary")

val validatedData = testData.select(
  $"name",
  $"age_str",
  parseAge($"age_str").as("parsed_age"),
  $"salary",
  validateSalary($"salary").as("salary_validation"),
  safeDivide($"salary", lit(12.0)).as("monthly_salary")
)

validatedData.show(truncate = false)

// Nullability and determinism hints
val strictUpperCase = udf((input: String) => {
  input.toUpperCase
}).asNonNullable() // declares the output is never null; the function body must uphold this

val hashValue = udf((input: String) => {
  input.hashCode
}) // hashCode is deterministic, so the default deterministic flag is appropriate

val nonDeterministicId = udf(() => {
  java.util.UUID.randomUUID().toString
}).asNondeterministic() // random output: mark non-deterministic so Spark does not cache or re-evaluate it freely

User-Defined Aggregate Functions (UDAFs)

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.udaf

// Simple UDAF: Calculate geometric mean
class GeometricMean extends Aggregator[Double, (Double, Long), Double] {
  def zero: (Double, Long) = (1.0, 0L)
  
  def reduce(buffer: (Double, Long), input: Double): (Double, Long) = {
    if (input > 0) {
      (buffer._1 * input, buffer._2 + 1)
    } else {
      buffer
    }
  }
  
  def merge(b1: (Double, Long), b2: (Double, Long)): (Double, Long) = {
    (b1._1 * b2._1, b1._2 + b2._2)
  }
  
  def finish(buffer: (Double, Long)): Double = {
    if (buffer._2 > 0) {
      math.pow(buffer._1, 1.0 / buffer._2)
    } else {
      0.0
    }
  }
  
  def bufferEncoder: Encoder[(Double, Long)] = Encoders.product
  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

// Complex UDAF: Calculate summary statistics
case class StatsSummary(
  count: Long,
  sum: Double, 
  sumSquares: Double,
  min: Double,
  max: Double
) {
  def mean: Double = if (count > 0) sum / count else 0.0
  def variance: Double = if (count > 1) (sumSquares - sum * sum / count) / (count - 1) else 0.0
  def stddev: Double = math.sqrt(variance)
}

class SummaryStats extends Aggregator[Double, StatsSummary, StatsSummary] {
  def zero: StatsSummary = StatsSummary(0L, 0.0, 0.0, Double.MaxValue, Double.MinValue)
  
  def reduce(buffer: StatsSummary, input: Double): StatsSummary = {
    StatsSummary(
      count = buffer.count + 1,
      sum = buffer.sum + input,
      sumSquares = buffer.sumSquares + input * input,
      min = math.min(buffer.min, input),
      max = math.max(buffer.max, input)
    )
  }
  
  def merge(b1: StatsSummary, b2: StatsSummary): StatsSummary = {
    if (b1.count == 0) b2
    else if (b2.count == 0) b1
    else StatsSummary(
      count = b1.count + b2.count,
      sum = b1.sum + b2.sum,
      sumSquares = b1.sumSquares + b2.sumSquares,
      min = math.min(b1.min, b2.min),
      max = math.max(b1.max, b2.max)
    )
  }
  
  def finish(buffer: StatsSummary): StatsSummary = {
    if (buffer.count == 0) {
      buffer.copy(min = 0.0, max = 0.0)
    } else {
      buffer
    }
  }
  
  def bufferEncoder: Encoder[StatsSummary] = Encoders.product
  def outputEncoder: Encoder[StatsSummary] = Encoders.product
}

// Wrap the typed Aggregators as untyped UDAFs so they can be applied to DataFrame columns
val geometricMean = udaf(new GeometricMean())
val summaryStats = udaf(new SummaryStats())

// Optionally register them for SQL use
spark.udf.register("geometric_mean", geometricMean)
spark.udf.register("summary_stats", summaryStats)

// Apply UDAFs
val aggregatedData = employeeData
  .agg(
    geometricMean($"salary").as("geometric_mean_salary"),
    summaryStats($"salary").as("salary_stats")
  )

aggregatedData.show()

// Extract fields from the struct returned by the UDAF. Only the case class
// fields are encoded; mean/variance/stddev are methods and must be recomputed.
val detailedStats = employeeData
  .agg(summaryStats($"salary").as("stats"))
  .select(
    $"stats.count",
    $"stats.sum",
    $"stats.sumSquares",
    $"stats.min",
    $"stats.max",
    ($"stats.sum" / $"stats.count").as("mean")
  )

detailedStats.show()

Higher-Order Functions and Functional UDFs

import org.apache.spark.sql.functions._

// Data with arrays for higher-order function examples
val arrayData = Seq(
  ("Alice", Array(85, 90, 88, 92)),
  ("Bob", Array(78, 85, 80, 88)),
  ("Charlie", Array(95, 88, 90, 94))
).toDF("name", "test_scores")

// UDF for array transformation
val scaleScore = udf((score: Int, scaleFactor: Double) => (score * scaleFactor).toInt)
val calculateGrade = udf((score: Int) => {
  score match {
    case s if s >= 90 => "A"
    case s if s >= 80 => "B" 
    case s if s >= 70 => "C"
    case s if s >= 60 => "D"
    case _ => "F"
  }
})

// Higher-order function usage with UDFs
val transformedScores = arrayData.select(
  $"name",
  $"test_scores",
  // Transform each score
  transform($"test_scores", score => scaleScore(score, lit(1.05))).as("scaled_scores"),
  // Filter scores above threshold
  filter($"test_scores", score => score > 85).as("high_scores"),
  // Check if all scores are passing
  forall($"test_scores", score => score >= 60).as("all_passing"),
  // Check if any score is excellent
  exists($"test_scores", score => score >= 95).as("has_excellent")
)

transformedScores.show(truncate = false)

// Complex aggregation with UDFs
val scoreAnalysis = arrayData.select(
  $"name",
  $"test_scores",
  // Sum all scores with aggregate (each element could be weighted before adding)
  aggregate($"test_scores", lit(0.0), 
    (acc, score) => acc + score * lit(1.0)).as("sum_scores"),
  // Find maximum score
  aggregate($"test_scores", lit(0),
    (acc, score) => greatest(acc, score)).as("max_score"),
  // Calculate grade distribution
  transform($"test_scores", score => calculateGrade(score)).as("letter_grades")
)

scoreAnalysis.show(truncate = false)

Performance Optimization and Best Practices

import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.util.AccumulatorV2

// Optimized UDF with broadcast variables
val taxRates = Map("US" -> 0.25, "UK" -> 0.20, "CA" -> 0.15, "DE" -> 0.30)
val broadcastTaxRates = spark.sparkContext.broadcast(taxRates)

val calculateNetSalary = udf((grossSalary: Double, country: String) => {
  val taxRate = broadcastTaxRates.value.getOrElse(country, 0.0)
  grossSalary * (1.0 - taxRate)
})

// Sample international employee data
val intlEmployees = Seq(
  ("Alice", 75000.0, "US"),
  ("Bob", 65000.0, "UK"),
  ("Charlie", 70000.0, "CA"),
  ("Diana", 80000.0, "DE")
).toDF("name", "gross_salary", "country")

val netSalaries = intlEmployees.select(
  $"name",
  $"gross_salary",
  $"country",
  calculateNetSalary($"gross_salary", $"country").as("net_salary")
)

netSalaries.show()

// UDF with accumulator for monitoring
class StringAccumulator extends AccumulatorV2[String, String] {
  private var _value = ""
  
  def isZero: Boolean = _value.isEmpty
  def copy(): StringAccumulator = {
    val acc = new StringAccumulator
    acc._value = _value
    acc
  }
  def reset(): Unit = _value = ""
  def add(v: String): Unit = _value += v + "\n"
  def merge(other: AccumulatorV2[String, String]): Unit = _value += other.value
  def value: String = _value
}

val errorAccumulator = new StringAccumulator
spark.sparkContext.register(errorAccumulator, "UDF Errors")

val robustProcessing = udf((data: String) => {
  try {
    // Simulate complex processing
    if (data == null || data.isEmpty) {
      errorAccumulator.add(s"Empty data encountered")
      "EMPTY"
    } else if (data.length < 3) {
      errorAccumulator.add(s"Short data: $data")
      "SHORT"
    } else {
      data.toUpperCase
    }
  } catch {
    case e: Exception =>
      errorAccumulator.add(s"Error processing '$data': ${e.getMessage}")
      "ERROR"
  }
})

// UDF that processes a whole array column per row (note: Scala UDFs are not truly vectorized)
val vectorizedStringLength = udf((strings: Seq[String]) => {
  strings.map(s => if (s != null) s.length else 0)
})

// Batch processing example
val testStrings = Seq(
  Array("hello", "world", null, "spark", "sql"),
  Array("scala", "java", "", "python"),
  Array(null, "data", "engineering")
).toDF("string_array")

val lengthResults = testStrings.select(
  $"string_array",
  vectorizedStringLength($"string_array").as("lengths")
)

lengthResults.show(truncate = false)

// UDF lifecycle management
object UDFRegistry {
  private val registeredUDFs = scala.collection.mutable.Set[String]()
  
  def registerUDF(spark: SparkSession, name: String, udf: UserDefinedFunction): Unit = {
    spark.udf.register(name, udf)
    registeredUDFs += name
    println(s"Registered UDF: $name")
  }
  
  def listRegisteredUDFs(): Set[String] = registeredUDFs.toSet
  
  def cleanup(): Unit = {
    // Spark has no built-in way to unregister a UDF, so the registry only
    // tracks the names added to the current session for documentation purposes
    registeredUDFs.foreach { name =>
      println(s"UDF registered in session: $name")
    }
  }
}

// Register UDFs through registry
UDFRegistry.registerUDF("calculate_net_salary", calculateNetSalary)
UDFRegistry.registerUDF("robust_processing", robustProcessing)

// Use registered UDFs
intlEmployees.createOrReplaceTempView("international_employees")

val sqlWithUDF = spark.sql("""
  SELECT 
    name,
    gross_salary,
    country,
    calculate_net_salary(gross_salary, country) as net_salary
  FROM international_employees
""")

sqlWithUDF.show()

// Display UDF registry status
println("Registered UDFs: " + UDFRegistry.listRegisteredUDFs().mkString(", "))
println("Error accumulator value:")
println(errorAccumulator.value)

// Cleanup
UDFRegistry.cleanup()

UDF Testing and Validation

import org.apache.spark.sql.{DataFrame, Encoder}
import org.apache.spark.sql.expressions.UserDefinedFunction

// UDF testing framework
class UDFTester(spark: SparkSession) {
  import spark.implicits._
  
  def testUDF[T: Encoder, R](udf: UserDefinedFunction, testCases: Seq[(T, R)]): Boolean = {
    val testData = testCases.map(_._1).toDF("input")
    val expectedResults = testCases.map(_._2)
    
    val results = testData.select(udf($"input").as("result")).collect().map(_.getAs[R]("result"))
    
    val passed = results.zip(expectedResults).forall { case (actual, expected) =>
      actual == expected
    }
    
    if (passed) {
      println(s"✅ All ${testCases.length} test cases passed")
    } else {
      println(s"❌ Some test cases failed")
      results.zip(expectedResults).zipWithIndex.foreach { case ((actual, expected), idx) =>
        if (actual != expected) {
          println(s"  Test case $idx: expected $expected, got $actual")
        }
      }
    }
    
    passed
  }
  
  def benchmarkUDF[T](udf: UserDefinedFunction, data: DataFrame, column: String, iterations: Int = 5): Unit = {
    val times = (1 to iterations).map { _ =>
      val start = System.nanoTime()
      data.select(udf($"$column")).collect()
      val end = System.nanoTime()
      (end - start) / 1000000.0 // Convert to milliseconds
    }
    
    val avgTime = times.sum / times.length
    val minTime = times.min
    val maxTime = times.max
    
    println(s"UDF Performance (${iterations} iterations):")
    println(s"  Average: ${avgTime}ms")
    println(s"  Min: ${minTime}ms") 
    println(s"  Max: ${maxTime}ms")
  }
}

// Test the UDFs
val tester = new UDFTester(spark)

// Test extract first name UDF
val firstNameTests = Seq(
  ("John Doe", "John"),
  ("Alice", "Alice"),
  ("Bob Smith Jr", "Bob"),
  (null, null),
  ("", "")
)

tester.testUDF(extractFirstName, firstNameTests)

// Test calculate bonus UDF (two-argument UDFs need their inputs as separate columns,
// so this check builds a small DataFrame instead of using the single-input tester)
val bonusUDF = udf((salary: Double, age: Int) => {
  val ageMultiplier = if (age > 30) 0.15 else 0.10
  salary * ageMultiplier
})

val bonusTests = Seq(
  (50000.0, 25, 5000.0),
  (80000.0, 35, 12000.0),
  (60000.0, 30, 6000.0)
).toDF("salary", "age", "expected_bonus")

val failedBonusCases = bonusTests
  .withColumn("bonus", bonusUDF($"salary", $"age"))
  .filter(F.abs($"bonus" - $"expected_bonus") > 1e-6)

println(if (failedBonusCases.count() == 0) "✅ Bonus test cases passed" else "❌ Bonus test cases failed")

// Performance benchmarking
val largeDataset = (1 to 10000).map(i => s"Name $i").toDF("name")
tester.benchmarkUDF(extractFirstName, largeDataset, "name")

// UDF documentation generator (relies on the UserDefinedFunction fields shown in
// the API reference above; field availability varies across Spark versions)
def generateUDFDocumentation(udfs: Map[String, (UserDefinedFunction, String)]): String = {
  val docs = udfs.map { case (name, (udf, description)) =>
    s"""
    |## $name
    |
    |**Description**: $description
    |
    |**Input Types**: ${udf.inputTypes.mkString(", ")}
    |**Output Type**: ${udf.dataType}
    |**Nullable**: ${udf.nullable}
    |**Deterministic**: ${udf.deterministic}
    |
    """.stripMargin
  }.mkString("\n")
  
  s"""
  |# UDF Documentation
  |
  |$docs
  """.stripMargin
}

val udfDocs = Map(
  "extract_first_name" -> (extractFirstName, "Extracts the first name from a full name string"),
  "calculate_bonus" -> (bonusUDF, "Calculates employee bonus based on salary and age"),
  "is_valid_email" -> (isValidEmail, "Validates email format")
)

println(generateUDFDocumentation(udfDocs))