Registration and usage of user-defined functions (UDFs) and user-defined aggregate functions (UDAFs), which extend Spark SQL with custom business logic and domain-specific operations.
Interface for registering user-defined functions that can be used in SQL and DataFrame operations.
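For example, a plain Scala function can be registered directly under a name that is then callable from SQL (a minimal sketch; assumes an active SparkSession named spark):
// Register a one-argument Scala function under a SQL-callable name
spark.udf.register("plus_one", (x: Int) => x + 1)
// Call it from SQL; returns 42
spark.sql("SELECT plus_one(41)").show()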
/**
* Functions for registering user-defined functions
*/
class UDFRegistration {
  /** Register UDF with 0 arguments */
  def register[RT: TypeTag](name: String, func: () => RT): UserDefinedFunction
  /** Register UDF with 1 argument */
  def register[RT: TypeTag, A1: TypeTag](name: String, func: A1 => RT): UserDefinedFunction
  /** Register UDF with 2 arguments */
  def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](
      name: String, func: (A1, A2) => RT): UserDefinedFunction
  /** Register UDF with 3 arguments */
  def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](
      name: String, func: (A1, A2, A3) => RT): UserDefinedFunction
  /** Register UDF with 4 arguments */
  def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](
      name: String, func: (A1, A2, A3, A4) => RT): UserDefinedFunction
  /** Register UDF with 5 arguments */
  def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](
      name: String, func: (A1, A2, A3, A4, A5) => RT): UserDefinedFunction
  // ... continues up to 22 arguments
  /** Register user-defined aggregate function */
  def register(name: String, udaf: UserDefinedAggregateFunction): UserDefinedAggregateFunction
  /** Register UDF from UserDefinedFunction */
  def register(name: String, udf: UserDefinedFunction): UserDefinedFunction
}
Wrapper for user-defined functions that can be applied to columns.
/**
* User-defined function that can be called in DataFrame operations
*/
class UserDefinedFunction {
  /** Apply UDF to columns */
  def apply(exprs: Column*): Column
  /** Mark UDF as non-deterministic */
  def asNondeterministic(): UserDefinedFunction
  /** Mark UDF as non-nullable */
  def asNonNullable(): UserDefinedFunction
  /** UDF name (if registered) */
  def name: Option[String]
  /** Check if UDF is deterministic */
  def deterministic: Boolean
  /** Check if UDF is nullable */
  def nullable: Boolean
}
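The determinism and nullability flags can be inspected and toggled on any UDF value; a minimal sketch (assumes the usual functions.udf import):
import org.apache.spark.sql.functions.udf
val trimUdf = udf((s: String) => s.trim)
trimUdf.deterministic               // true by default
trimUdf.nullable                    // true by default
val strictTrim = trimUdf.asNonNullable()
strictTrim.nullable                 // false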
/**
* Factory methods for creating UDFs
*/
object functions {
  /** Create UDF with 0 arguments */
  def udf[RT: TypeTag](f: () => RT): UserDefinedFunction
  /** Create UDF with 1 argument */
  def udf[RT: TypeTag, A1: TypeTag](f: A1 => RT): UserDefinedFunction
  /** Create UDF with 2 arguments */
  def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: (A1, A2) => RT): UserDefinedFunction
  /** Create UDF with 3 arguments */
  def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](
      f: (A1, A2, A3) => RT): UserDefinedFunction
  /** Create UDF with 4 arguments */
  def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](
      f: (A1, A2, A3, A4) => RT): UserDefinedFunction
  /** Create UDF with 5 arguments */
  def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](
      f: (A1, A2, A3, A4, A5) => RT): UserDefinedFunction
  // ... continues up to 22 arguments
}
Usage Examples:
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types._
// Simple string transformation UDF
val upperCase = udf((s: String) => s.toUpperCase)
// Register for SQL usage
spark.udf.register("upper_case", upperCase)
// Use in DataFrame operations
val df = spark.table("employees")
val result = df.withColumn("upper_name", upperCase(col("name")))
// Use in SQL
val sqlResult = spark.sql("SELECT name, upper_case(name) as upper_name FROM employees")
// Complex business logic UDF
val calculateBonus = udf((salary: Double, performance: String, years: Int) => {
  val baseBonus = salary * 0.1
  val performanceMultiplier = performance match {
    case "excellent" => 2.0
    case "good" => 1.5
    case "average" => 1.0
    case _ => 0.5
  }
  val tenureBonus = if (years > 5) 1.2 else 1.0
  baseBonus * performanceMultiplier * tenureBonus
})
spark.udf.register("calculate_bonus", calculateBonus)
val bonusData = df.withColumn("bonus",
  calculateBonus(col("salary"), col("performance"), col("years_of_service"))
)
// Non-deterministic UDF (e.g., random values)
val randomId = udf(() => java.util.UUID.randomUUID().toString)
  .asNondeterministic()
val withIds = df.withColumn("random_id", randomId())
Abstract class for custom aggregate functions that work on groups of rows.
/**
* Base class for user-defined aggregate functions
*/
abstract class UserDefinedAggregateFunction extends Serializable {
  /** Input schema for the aggregation function */
  def inputSchema: StructType
  /** Schema of the aggregation buffer */
  def bufferSchema: StructType
  /** Output data type */
  def dataType: DataType
  /** Whether function is deterministic */
  def deterministic: Boolean
  /** Initialize the aggregation buffer */
  def initialize(buffer: MutableAggregationBuffer): Unit
  /** Update buffer with new input row */
  def update(buffer: MutableAggregationBuffer, input: Row): Unit
  /** Merge two aggregation buffers */
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit
  /** Calculate final result from buffer */
  def evaluate(buffer: Row): Any
}
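Note: in Spark 3.0 and later, UserDefinedAggregateFunction is deprecated in favor of the typed Aggregator API registered through functions.udaf. A minimal sketch under that assumption (class and registered names here are illustrative):
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{functions, Encoder, Encoders}
// Typed aggregator computing a simple average of Double values
case class AvgBuffer(sum: Double, count: Long)
object TypedAverage extends Aggregator[Double, AvgBuffer, Double] {
  def zero: AvgBuffer = AvgBuffer(0.0, 0L)
  def reduce(b: AvgBuffer, value: Double): AvgBuffer = AvgBuffer(b.sum + value, b.count + 1)
  def merge(b1: AvgBuffer, b2: AvgBuffer): AvgBuffer = AvgBuffer(b1.sum + b2.sum, b1.count + b2.count)
  def finish(b: AvgBuffer): Double = if (b.count == 0) Double.NaN else b.sum / b.count
  def bufferEncoder: Encoder[AvgBuffer] = Encoders.product[AvgBuffer]
  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
// Register the typed aggregator as an untyped UDAF callable from SQL and DataFrames
spark.udf.register("typed_avg", functions.udaf(TypedAverage))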
/**
* Mutable aggregation buffer for UDAF implementations
*/
trait MutableAggregationBuffer extends Row {
  /** Update value at index */
  def update(i: Int, value: Any): Unit
}
Usage Example:
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.col
// Custom aggregation function to calculate geometric mean
class GeometricMean extends UserDefinedAggregateFunction {
  // Input: one double value
  def inputSchema: StructType = StructType(StructField("value", DoubleType) :: Nil)
  // Buffer: sum of logs and count
  def bufferSchema: StructType = StructType(
    StructField("logSum", DoubleType) ::
    StructField("count", LongType) :: Nil
  )
  // Output: double
  def dataType: DataType = DoubleType
  def deterministic: Boolean = true
  // Initialize buffer
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0.0 // logSum
    buffer(1) = 0L  // count
  }
  // Update buffer with new value
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      val value = input.getDouble(0)
      if (value > 0) {
        buffer(0) = buffer.getDouble(0) + math.log(value)
        buffer(1) = buffer.getLong(1) + 1
      }
    }
  }
  // Merge two buffers
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }
  // Calculate final result
  def evaluate(buffer: Row): Any = {
    val count = buffer.getLong(1)
    if (count == 0) {
      null
    } else {
      math.exp(buffer.getDouble(0) / count)
    }
  }
}
// Register and use the UDAF
val geometricMean = new GeometricMean
spark.udf.register("geom_mean", geometricMean)
// Use in DataFrame operations
val result = spark.table("sales")
  .groupBy("region")
  .agg(geometricMean(col("amount")).alias("geom_avg_amount"))
// Use in SQL
val sqlResult = spark.sql("""
  SELECT region, geom_mean(amount) as geom_avg_amount
  FROM sales
  GROUP BY region
""")
Working with complex types:
// UDF that processes arrays
val arraySum = udf((arr: Seq[Int]) => if (arr != null) arr.sum else 0)
// UDF that processes structs
val extractField = udf((row: Row) => if (row != null) row.getString(0) else null)
// UDF that returns complex types: returning a case class yields a struct column
// whose schema (name, age, id) is inferred from the case class fields
case class PersonRecord(name: String, age: Int, id: String)
val createStruct = udf((name: String, age: Int) => PersonRecord(name, age, s"$name-$age"))
// Register for SQL usage
spark.udf.register("create_person", createStruct)
Error handling in UDFs:
val safeDiv = udf((a: Double, b: Double) => {
  try {
    if (b == 0.0) None else Some(a / b)
  } catch {
    case _: Exception => None
  }
})
// UDF with detailed error information
val validateEmail = udf((email: String) => {
  // Map values must share a single type; use strings so the result is map<string,string>
  if (email == null || email.isEmpty) {
    Map("valid" -> "false", "error" -> "Email is empty")
  } else if (!email.contains("@")) {
    Map("valid" -> "false", "error" -> "Missing @ symbol")
  } else {
    Map("valid" -> "true", "error" -> "")
  }
})
Performance optimization:
// Broadcast variables in UDFs
val lookupMap = spark.sparkContext.broadcast(Map(
  "A" -> "Alpha",
  "B" -> "Beta",
  "C" -> "Gamma"
))
val lookupUdf = udf((code: String) => {
  lookupMap.value.getOrElse(code, "Unknown")
})
// Build UDFs from factory methods so closures capture only small, serializable values
object UDFUtils {
  def multiply(factor: Double): UserDefinedFunction = {
    udf((value: Double) => value * factor)
  }
  def formatCurrency(locale: String): UserDefinedFunction = {
    udf((amount: Double) => {
      val formatter = java.text.NumberFormat.getCurrencyInstance(
        java.util.Locale.forLanguageTag(locale)
      )
      formatter.format(amount)
    })
  }
}
val doubleValue = UDFUtils.multiply(2.0)
val formatUSD = UDFUtils.formatCurrency("en-US")
Typed UDFs with case classes:
case class Person(name: String, age: Int)
case class PersonInfo(person: Person, category: String)
// Typed UDF working with case classes (assumes Spark 3.x, where case-class
// parameters are deserialized through encoders)
val categorize = udf((person: Person) => {
  val category = person.age match {
    case age if age < 18 => "Minor"
    case age if age < 65 => "Adult"
    case _ => "Senior"
  }
  PersonInfo(person, category)
})
// Use with proper encoders
import spark.implicits._
val peopleDS = Seq(
  Person("Alice", 25),
  Person("Bob", 17),
  Person("Carol", 70)
).toDS()
// Pass a struct column matching the Person fields
import org.apache.spark.sql.functions.{col, struct}
val categorizedDS = peopleDS.select(
  categorize(struct(col("name"), col("age"))).alias("info")
)
Null handling:
val nullSafeUdf = udf((value: String) => {
  Option(value).map(_.trim.toUpperCase).orNull
})
// Or mark as non-nullable if appropriate
val nonNullUdf = udf((value: String) => value.trim.toUpperCase)
  .asNonNullable()
Testing UDFs:
// Unit test UDFs outside Spark context
val testUdf = (s: String) => s.toUpperCase
assert(testUdf("hello") == "HELLO")
// Integration testing with Spark
val df = Seq("hello", "world").toDF("text")
val result = df.select(upperCase(col("text"))).collect()
assert(result.map(_.getString(0)) sameElements Array("HELLO", "WORLD"))
Documentation and type safety:
/**
* Calculates compound interest
* @param principal Initial amount
* @param rate Annual interest rate (as decimal, e.g., 0.05 for 5%)
* @param years Number of years
* @param compoundingPeriods Number of times interest is compounded per year
* @return Final amount after compound interest
*/
val compoundInterest = udf((principal: Double, rate: Double, years: Int, compoundingPeriods: Int) => {
  require(principal >= 0, "Principal must be non-negative")
  require(rate >= 0, "Interest rate must be non-negative")
  require(years >= 0, "Years must be non-negative")
  require(compoundingPeriods > 0, "Compounding periods must be positive")
  principal * math.pow(1 + rate / compoundingPeriods, compoundingPeriods * years)
})
spark.udf.register("compound_interest", compoundInterest)