User-Defined Functions

Registration and usage of custom user-defined functions (UDFs) and user-defined aggregate functions (UDAFs). Enables extending Spark SQL with custom business logic and domain-specific operations.

Capabilities

UDF Registration

Interface for registering user-defined functions that can be used in SQL and DataFrame operations.

/**
 * Functions for registering user-defined functions
 */
class UDFRegistration {
  /** Register UDF with 0 arguments */
  def register[RT: TypeTag](name: String, func: () => RT): UserDefinedFunction
  
  /** Register UDF with 1 argument */
  def register[RT: TypeTag, A1: TypeTag](name: String, func: A1 => RT): UserDefinedFunction
  
  /** Register UDF with 2 arguments */
  def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](
    name: String, func: (A1, A2) => RT): UserDefinedFunction
  
  /** Register UDF with 3 arguments */
  def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](
    name: String, func: (A1, A2, A3) => RT): UserDefinedFunction
  
  /** Register UDF with 4 arguments */
  def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](
    name: String, func: (A1, A2, A3, A4) => RT): UserDefinedFunction
  
  /** Register UDF with 5 arguments */
  def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](
    name: String, func: (A1, A2, A3, A4, A5) => RT): UserDefinedFunction
  
  // ... continues up to 22 arguments
  
  /** Register user-defined aggregate function */
  def register(name: String, udaf: UserDefinedAggregateFunction): UserDefinedAggregateFunction
  
  /** Register UDF from UserDefinedFunction */
  def register(name: String, udf: UserDefinedFunction): UserDefinedFunction
}
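
Functions can also be registered directly by name, without building a UserDefinedFunction first; the typed register overloads infer the SQL types from the Scala signature. A minimal sketch, assuming an active SparkSession named spark and the employees table used in the examples below:

// Register a plain Scala function under a SQL-callable name
spark.udf.register("plus_one", (x: Int) => x + 1)

// The registered function is available from SQL and selectExpr
spark.sql("SELECT name, plus_one(years_of_service) AS next_year FROM employees")
spark.table("employees").selectExpr("name", "plus_one(years_of_service) AS next_year")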

UserDefinedFunction

Wrapper for user-defined functions that can be applied to columns.

/**
 * User-defined function that can be called in DataFrame operations
 */
class UserDefinedFunction {
  /** Apply UDF to columns */
  def apply(exprs: Column*): Column
  
  /** Mark UDF as non-deterministic */
  def asNondeterministic(): UserDefinedFunction
  
  /** Mark UDF as non-nullable */
  def asNonNullable(): UserDefinedFunction
  
  /** UDF name (if registered) */
  def name: Option[String]
  
  /** Check if UDF is deterministic */
  def deterministic: Boolean
  
  /** Check if UDF is nullable */
  def nullable: Boolean
}

/**
 * Factory methods for creating UDFs
 */
object functions {
  /** Create UDF with 0 arguments */
  def udf[RT: TypeTag](f: () => RT): UserDefinedFunction
  
  /** Create UDF with 1 argument */
  def udf[RT: TypeTag, A1: TypeTag](f: A1 => RT): UserDefinedFunction
  
  /** Create UDF with 2 arguments */
  def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: (A1, A2) => RT): UserDefinedFunction
  
  /** Create UDF with 3 arguments */
  def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](
    f: (A1, A2, A3) => RT): UserDefinedFunction
  
  /** Create UDF with 4 arguments */
  def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](
    f: (A1, A2, A3, A4) => RT): UserDefinedFunction
  
  /** Create UDF with 5 arguments */
  def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](
    f: (A1, A2, A3, A4, A5) => RT): UserDefinedFunction
  
  // ... continues up to 10 arguments
}

Usage Examples:

import org.apache.spark.sql.functions.{col, udf}

// Simple string transformation UDF
val upperCase = udf((s: String) => s.toUpperCase)

// Register for SQL usage
spark.udf.register("upper_case", upperCase)

// Use in DataFrame operations
val df = spark.table("employees")
val result = df.withColumn("upper_name", upperCase(col("name")))

// Use in SQL
val sqlResult = spark.sql("SELECT name, upper_case(name) as upper_name FROM employees")

// Complex business logic UDF
val calculateBonus = udf((salary: Double, performance: String, years: Int) => {
  val baseBonus = salary * 0.1
  val performanceMultiplier = performance match {
    case "excellent" => 2.0
    case "good" => 1.5
    case "average" => 1.0
    case _ => 0.5
  }
  val tenureBonus = if (years > 5) 1.2 else 1.0
  baseBonus * performanceMultiplier * tenureBonus
})

spark.udf.register("calculate_bonus", calculateBonus)

val bonusData = df.withColumn("bonus", 
  calculateBonus(col("salary"), col("performance"), col("years_of_service"))
)

// Non-deterministic UDF (e.g., random values)
val randomId = udf(() => java.util.UUID.randomUUID().toString)
  .asNondeterministic()

val withIds = df.withColumn("random_id", randomId())
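
After creation, a UserDefinedFunction carries metadata that can be inspected; a short sketch using the UDFs defined above:

// Inspect UDF metadata
upperCase.deterministic   // true by default
upperCase.nullable        // true by default
randomId.deterministic    // false, because asNondeterministic() was applied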

UserDefinedAggregateFunction

Abstract class for custom aggregate functions that work on groups of rows.

/**
 * Base class for user-defined aggregate functions
 */
abstract class UserDefinedAggregateFunction extends Serializable {
  /** Input schema for the aggregation function */
  def inputSchema: StructType
  
  /** Schema of the aggregation buffer */
  def bufferSchema: StructType
  
  /** Output data type */
  def dataType: DataType
  
  /** Whether function is deterministic */
  def deterministic: Boolean
  
  /** Initialize the aggregation buffer */
  def initialize(buffer: MutableAggregationBuffer): Unit
  
  /** Update buffer with new input row */
  def update(buffer: MutableAggregationBuffer, input: Row): Unit
  
  /** Merge two aggregation buffers */
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit
  
  /** Calculate final result from buffer */
  def evaluate(buffer: Row): Any
}

/**
 * Mutable aggregation buffer for UDAF implementations
 */
abstract class MutableAggregationBuffer extends Row {
  /** Update value at index */
  def update(i: Int, value: Any): Unit
}

Usage Example:

import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.sql.Row

// Custom aggregation function to calculate geometric mean
class GeometricMean extends UserDefinedAggregateFunction {
  // Input: one double value
  def inputSchema: StructType = StructType(StructField("value", DoubleType) :: Nil)
  
  // Buffer: sum of logs and count
  def bufferSchema: StructType = StructType(
    StructField("logSum", DoubleType) :: 
    StructField("count", LongType) :: Nil
  )
  
  // Output: double
  def dataType: DataType = DoubleType
  def deterministic: Boolean = true
  
  // Initialize buffer
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0.0  // logSum
    buffer(1) = 0L   // count
  }
  
  // Update buffer with new value
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      val value = input.getDouble(0)
      if (value > 0) {
        buffer(0) = buffer.getDouble(0) + math.log(value)
        buffer(1) = buffer.getLong(1) + 1
      }
    }
  }
  
  // Merge two buffers
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }
  
  // Calculate final result
  def evaluate(buffer: Row): Any = {
    val count = buffer.getLong(1)
    if (count == 0) {
      null
    } else {
      math.exp(buffer.getDouble(0) / count)
    }
  }
}

// Register and use the UDAF
val geometricMean = new GeometricMean
spark.udf.register("geom_mean", geometricMean)

// Use in DataFrame operations
val result = spark.table("sales")
  .groupBy("region")
  .agg(geometricMean(col("amount")).alias("geom_avg_amount"))

// Use in SQL
val sqlResult = spark.sql("""
  SELECT region, geom_mean(amount) as geom_avg_amount 
  FROM sales 
  GROUP BY region
""")

Advanced UDF Patterns

Working with complex types:

import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types._

// UDF that processes arrays
val arraySum = udf((arr: Seq[Int]) => if (arr != null) arr.sum else 0)

// UDF that processes structs: Row has no derivable schema, so the typed udf[...]
// factories reject it; use the variant that takes an explicit return DataType
// (on Spark 3.x this requires spark.sql.legacy.allowUntypedScalaUDF=true)
val extractField = udf((row: Row) => if (row != null) row.getString(0) else null, StringType)

// UDF that returns a complex type: returning a case class lets Spark derive
// the struct schema (name, age, id) instead of returning an untyped Row
case class PersonRecord(name: String, age: Int, id: String)
val createStruct = udf((name: String, age: Int) => PersonRecord(name, age, s"$name-$age"))

// Register for SQL usage
spark.udf.register("create_person", createStruct)
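
Applied to a DataFrame, these behave like any other column expression; the column names below (scores, first_name, age) are illustrative:

// scores is assumed to be array<int>, first_name a string, age an int
val enriched = df
  .withColumn("score_total", arraySum(col("scores")))
  .withColumn("person", createStruct(col("first_name"), col("age")))

// Nested struct fields can be selected with dot notation
enriched.select(col("person.name"), col("person.id"))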

Error handling in UDFs:

val safeDiv = udf((a: Double, b: Double) => {
  try {
    if (b == 0.0) None else Some(a / b)
  } catch {
    case _: Exception => None
  }
})

// UDF with detailed error information: a mixed-type Map (Boolean and String values)
// has no Spark schema, so return a case class instead
case class EmailValidation(valid: Boolean, error: String)

val validateEmail = udf((email: String) => {
  if (email == null || email.isEmpty) {
    EmailValidation(valid = false, error = "Email is empty")
  } else if (!email.contains("@")) {
    EmailValidation(valid = false, error = "Missing @ symbol")
  } else {
    EmailValidation(valid = true, error = "")
  }
})
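
In a pipeline these UDFs surface failures as data rather than exceptions; a sketch with illustrative column names (revenue, units, email):

// Rows where the division is undefined get NULL instead of failing the job
val ratios = df.withColumn("ratio", safeDiv(col("revenue"), col("units")))

// Keep the invalid emails together with the reason they were rejected
val checked = df.withColumn("email_check", validateEmail(col("email")))
val invalid = checked
  .filter(!col("email_check.valid"))
  .select(col("email"), col("email_check.error"))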

Performance optimization:

// Broadcast variables in UDFs
val lookupMap = spark.sparkContext.broadcast(Map(
  "A" -> "Alpha",
  "B" -> "Beta", 
  "C" -> "Gamma"
))

val lookupUdf = udf((code: String) => {
  lookupMap.value.getOrElse(code, "Unknown")
})

// Build UDFs from small, explicit parameters so the closure captures only what it needs,
// instead of dragging enclosing (possibly non-serializable) objects into the task closure
object UDFUtils {
  def multiply(factor: Double): UserDefinedFunction = {
    udf((value: Double) => value * factor)
  }
  
  def formatCurrency(locale: String): UserDefinedFunction = {
    udf((amount: Double) => {
      val formatter = java.text.NumberFormat.getCurrencyInstance(
        java.util.Locale.forLanguageTag(locale)
      )
      formatter.format(amount)
    })
  }
}

val doubleValue = UDFUtils.multiply(2.0)
val formatUSD = UDFUtils.formatCurrency("en-US")
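
These are applied like any other UDF; a sketch with illustrative column names (code, price):

// Broadcast lookup, parameterized scaling, and locale-aware formatting
val labeled = df
  .withColumn("label", lookupUdf(col("code")))
  .withColumn("doubled", doubleValue(col("price")))
  .withColumn("price_usd", formatUSD(col("price")))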

Typed UDFs with case classes:

case class Person(name: String, age: Int)
case class PersonInfo(person: Person, category: String)

// Typed UDF working with case classes
val categorize = udf((person: Person) => {
  val category = person.age match {
    case age if age < 18 => "Minor"
    case age if age < 65 => "Adult"
    case _ => "Senior"
  }
  PersonInfo(person, category)
})

// Use with proper encoders
import spark.implicits._
import org.apache.spark.sql.functions.struct

val peopleDS = Seq(
  Person("Alice", 25),
  Person("Bob", 17),
  Person("Carol", 70)
).toDS()

// Pass the row to the UDF as a struct whose fields match the Person schema
val categorizedDS = peopleDS.select(
  categorize(struct(col("name"), col("age"))).alias("info")
)
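
The result column is a struct of structs, so the nested fields can be flattened back out with dot notation:

// Flatten the PersonInfo struct into plain columns
categorizedDS.select(
  col("info.person.name").alias("name"),
  col("info.person.age").alias("age"),
  col("info.category").alias("category")
)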

UDF Best Practices

Null handling:

val nullSafeUdf = udf((value: String) => {
  Option(value).map(_.trim.toUpperCase).orNull
})

// Or declare the result non-null when inputs are guaranteed non-null;
// asNonNullable only changes the output schema, the function still fails on a null input
val nonNullUdf = udf((value: String) => value.trim.toUpperCase)
  .asNonNullable()

Testing UDFs:

// Unit test UDFs outside Spark context
val testUdf = (s: String) => s.toUpperCase
assert(testUdf("hello") == "HELLO")

// Integration testing with Spark
val df = Seq("hello", "world").toDF("text")
val result = df.select(upperCase(col("text"))).collect()
assert(result.map(_.getString(0)) sameElements Array("HELLO", "WORLD"))

Documentation and type safety:

/**
 * Calculates compound interest
 * @param principal Initial amount
 * @param rate Annual interest rate (as decimal, e.g., 0.05 for 5%)
 * @param years Number of years
 * @param compoundingPeriods Number of times interest is compounded per year
 * @return Final amount after compound interest
 */
val compoundInterest = udf((principal: Double, rate: Double, years: Int, compoundingPeriods: Int) => {
  require(principal >= 0, "Principal must be non-negative")
  require(rate >= 0, "Interest rate must be non-negative") 
  require(years >= 0, "Years must be non-negative")
  require(compoundingPeriods > 0, "Compounding periods must be positive")
  
  principal * math.pow(1 + rate / compoundingPeriods, compoundingPeriods * years)
})

spark.udf.register("compound_interest", compoundInterest)