# User-Defined Functions

Registration and usage of custom user-defined functions (UDFs) and user-defined aggregate functions (UDAFs). Enables extending Spark SQL with custom business logic and domain-specific operations.

## Capabilities

### UDF Registration

Interface for registering user-defined functions that can be used in SQL and DataFrame operations.
```scala { .api }
/**
 * Functions for registering user-defined functions.
 */
class UDFRegistration {
  /** Register UDF with 0 arguments */
  def register[RT: TypeTag](name: String, func: () => RT): UserDefinedFunction

  /** Register UDF with 1 argument */
  def register[RT: TypeTag, A1: TypeTag](name: String, func: A1 => RT): UserDefinedFunction

  /** Register UDF with 2 arguments */
  def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](
      name: String, func: (A1, A2) => RT): UserDefinedFunction

  /** Register UDF with 3 arguments */
  def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](
      name: String, func: (A1, A2, A3) => RT): UserDefinedFunction

  /** Register UDF with 4 arguments */
  def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](
      name: String, func: (A1, A2, A3, A4) => RT): UserDefinedFunction

  /** Register UDF with 5 arguments */
  def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](
      name: String, func: (A1, A2, A3, A4, A5) => RT): UserDefinedFunction

  // ... continues up to 22 arguments

  /** Register user-defined aggregate function */
  def register(name: String, udaf: UserDefinedAggregateFunction): UserDefinedAggregateFunction

  /** Register UDF from UserDefinedFunction */
  def register(name: String, udf: UserDefinedFunction): UserDefinedFunction
}
```

### UserDefinedFunction

Wrapper for user-defined functions that can be applied to columns.

```scala { .api }
/**
 * User-defined function that can be called in DataFrame operations.
 */
class UserDefinedFunction {
  /** Apply UDF to columns */
  def apply(exprs: Column*): Column

  /** Mark UDF as non-deterministic */
  def asNondeterministic(): UserDefinedFunction

  /** Mark UDF as non-nullable */
  def asNonNullable(): UserDefinedFunction

  /** UDF name (if registered) */
  def name: Option[String]

  /** Check if UDF is deterministic */
  def deterministic: Boolean

  /** Check if UDF is nullable */
  def nullable: Boolean
}

/**
 * Factory methods for creating UDFs.
 */
object functions {
  /** Create UDF with 1 argument */
  def udf[RT: TypeTag, A1: TypeTag](f: A1 => RT): UserDefinedFunction

  /** Create UDF with 2 arguments */
  def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: (A1, A2) => RT): UserDefinedFunction

  /** Create UDF with 3 arguments */
  def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](
      f: (A1, A2, A3) => RT): UserDefinedFunction

  /** Create UDF with 4 arguments */
  def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](
      f: (A1, A2, A3, A4) => RT): UserDefinedFunction

  /** Create UDF with 5 arguments */
  def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](
      f: (A1, A2, A3, A4, A5) => RT): UserDefinedFunction

  // ... continues up to 22 arguments
}
```

**Usage Examples:**

```scala
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types._

// Simple string transformation UDF
val upperCase = udf((s: String) => s.toUpperCase)

// Register for SQL usage
spark.udf.register("upper_case", upperCase)

// Use in DataFrame operations
val df = spark.table("employees")
val result = df.withColumn("upper_name", upperCase(col("name")))

// Use in SQL
val sqlResult = spark.sql("SELECT name, upper_case(name) as upper_name FROM employees")

// Complex business logic UDF
val calculateBonus = udf((salary: Double, performance: String, years: Int) => {
  val baseBonus = salary * 0.1
  val performanceMultiplier = performance match {
    case "excellent" => 2.0
    case "good"      => 1.5
    case "average"   => 1.0
    case _           => 0.5
  }
  val tenureBonus = if (years > 5) 1.2 else 1.0
  baseBonus * performanceMultiplier * tenureBonus
})

spark.udf.register("calculate_bonus", calculateBonus)

val bonusData = df.withColumn("bonus",
  calculateBonus(col("salary"), col("performance"), col("years_of_service"))
)

// Non-deterministic UDF (e.g., random values). Keep the chained call in one
// expression: a line starting with "." is not a continuation in Scala 2.
val randomId = udf(() => java.util.UUID.randomUUID().toString).asNondeterministic()

val withIds = df.withColumn("random_id", randomId())
```

### UserDefinedAggregateFunction

Abstract class for custom aggregate functions that work on groups of rows. (Note: deprecated since Spark 3.0 in favor of `Aggregator` registered via `functions.udaf`.)

```scala { .api }
/**
 * Base class for user-defined aggregate functions.
 */
abstract class UserDefinedAggregateFunction extends Serializable {
  /** Input schema for the aggregation function */
  def inputSchema: StructType

  /** Schema of the aggregation buffer */
  def bufferSchema: StructType

  /** Output data type */
  def dataType: DataType

  /** Whether function is deterministic */
  def deterministic: Boolean

  /** Initialize the aggregation buffer */
  def initialize(buffer: MutableAggregationBuffer): Unit

  /** Update buffer with new input row */
  def update(buffer: MutableAggregationBuffer, input: Row): Unit

  /** Merge two aggregation buffers */
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit

  /** Calculate final result from buffer */
  def evaluate(buffer: Row): Any
}

/**
 * Mutable aggregation buffer for UDAF implementations.
 */
trait MutableAggregationBuffer extends Row {
  /** Update value at index */
  def update(i: Int, value: Any): Unit
}
```

**Usage Example:**

```scala
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types._
import org.apache.spark.sql.Row

// Custom aggregation function to calculate geometric mean
class GeometricMean extends UserDefinedAggregateFunction {
  // Input: one double value
  def inputSchema: StructType = StructType(StructField("value", DoubleType) :: Nil)

  // Buffer: sum of logs and count
  def bufferSchema: StructType = StructType(
    StructField("logSum", DoubleType) ::
    StructField("count", LongType) :: Nil
  )

  // Output: double
  def dataType: DataType = DoubleType
  def deterministic: Boolean = true

  // Initialize buffer
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0.0 // logSum
    buffer(1) = 0L  // count
  }

  // Update buffer with new value (nulls and non-positive values are skipped)
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      val value = input.getDouble(0)
      if (value > 0) {
        buffer(0) = buffer.getDouble(0) + math.log(value)
        buffer(1) = buffer.getLong(1) + 1
      }
    }
  }

  // Merge two buffers
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  // Calculate final result
  def evaluate(buffer: Row): Any = {
    val count = buffer.getLong(1)
    if (count == 0) {
      null
    } else {
      math.exp(buffer.getDouble(0) / count)
    }
  }
}

// Register and use the UDAF
val geometricMean = new GeometricMean
spark.udf.register("geom_mean", geometricMean)

// Use in DataFrame operations
val result = spark.table("sales")
  .groupBy("region")
  .agg(geometricMean(col("amount")).alias("geom_avg_amount"))

// Use in SQL
val sqlResult = spark.sql("""
  SELECT region, geom_mean(amount) as geom_avg_amount
  FROM sales
  GROUP BY region
""")
```

### Advanced UDF Patterns

**Working with complex types:**

```scala
// UDF that processes arrays
val arraySum = udf((arr: Seq[Int]) => if (arr != null) arr.sum else 0)

// UDF that processes structs (struct columns arrive as Row)
val extractField = udf((row: Row) => if (row != null) row.getString(0) else null)

// A UDF that returns a Row needs an explicit schema: Row carries no type
// information, so the TypeTag-based udf overloads cannot derive one.
val personSchema = StructType(Seq(
  StructField("name", StringType),
  StructField("age", IntegerType),
  StructField("id", StringType)
))
val createStruct = udf((name: String, age: Int) => Row(name, age, s"$name-$age"), personSchema)

// Register the already-typed UDF for SQL usage
spark.udf.register("create_person", createStruct)
```

**Error handling in UDFs:**

```scala
// Returning Option yields a nullable result column
val safeDiv = udf((a: Double, b: Double) => {
  try {
    if (b == 0.0) None else Some(a / b)
  } catch {
    case _: Exception => None
  }
})

// UDF with detailed error information. Keep the map values a single type
// (Map[String, String]) — Spark has no encoder for Map[String, Any].
val validateEmail = udf((email: String) => {
  if (email == null || email.isEmpty) {
    Map("valid" -> "false", "error" -> "Email is empty")
  } else if (!email.contains("@")) {
    Map("valid" -> "false", "error" -> "Missing @ symbol")
  } else {
    Map("valid" -> "true", "error" -> "")
  }
})
```

**Performance optimization:**

```scala
// Broadcast variables in UDFs: ship the lookup table to executors once
val lookupMap = spark.sparkContext.broadcast(Map(
  "A" -> "Alpha",
  "B" -> "Beta",
  "C" -> "Gamma"
))

val lookupUdf = udf((code: String) => {
  lookupMap.value.getOrElse(code, "Unknown")
})

// Factory methods keep the captured state small and serializable — the UDF
// closes over only the primitive parameter, not an enclosing class instance
object UDFUtils {
  def multiply(factor: Double): UserDefinedFunction = {
    udf((value: Double) => value * factor)
  }

  def formatCurrency(locale: String): UserDefinedFunction = {
    udf((amount: Double) => {
      val formatter = java.text.NumberFormat.getCurrencyInstance(
        java.util.Locale.forLanguageTag(locale)
      )
      formatter.format(amount)
    })
  }
}

val doubleValue = UDFUtils.multiply(2.0)
val formatUSD = UDFUtils.formatCurrency("en-US")
```

**Typed UDFs with case classes:**

```scala
case class Person(name: String, age: Int)
case class PersonInfo(person: Person, category: String)

// Typed UDF working with case classes
val categorize = udf((person: Person) => {
  val category = person.age match {
    case age if age < 18 => "Minor"
    case age if age < 65 => "Adult"
    case _               => "Senior"
  }
  PersonInfo(person, category)
})

// Use with proper encoders
import spark.implicits._
val peopleDS = Seq(
  Person("Alice", 25),
  Person("Bob", 17),
  Person("Carol", 70)
).toDS()

// A Dataset[Person] exposes "name" and "age" columns (there is no "value"
// column); pass them as a single struct matching the Person shape
val categorizedDS = peopleDS.select(categorize(struct(col("name"), col("age"))).alias("info"))
```

### UDF Best Practices

**Null handling:**

```scala
// Guard against null input instead of letting the UDF throw an NPE
val nullSafeUdf = udf((value: String) => {
  Option(value).map(_.trim.toUpperCase).orNull
})

// Or mark as non-nullable if appropriate (keep the chain in one expression)
val nonNullUdf = udf((value: String) => value.trim.toUpperCase).asNonNullable()
```

**Testing UDFs:**

```scala
// Unit test the underlying function outside any Spark context
val testUdf = (s: String) => s.toUpperCase
assert(testUdf("hello") == "HELLO")

// Integration testing with Spark
val df = Seq("hello", "world").toDF("text")
val result = df.select(upperCase(col("text"))).collect()
assert(result.map(_.getString(0)) sameElements Array("HELLO", "WORLD"))
```

**Documentation and type safety:**

```scala
/**
 * Calculates compound interest.
 *
 * @param principal Initial amount
 * @param rate Annual interest rate (as decimal, e.g., 0.05 for 5%)
 * @param years Number of years
 * @param compoundingPeriods Number of times interest is compounded per year
 * @return Final amount after compound interest
 */
val compoundInterest = udf((principal: Double, rate: Double, years: Int, compoundingPeriods: Int) => {
  require(principal >= 0, "Principal must be non-negative")
  require(rate >= 0, "Interest rate must be non-negative")
  require(years >= 0, "Years must be non-negative")
  require(compoundingPeriods > 0, "Compounding periods must be positive")

  principal * math.pow(1 + rate / compoundingPeriods, compoundingPeriods * years)
})

spark.udf.register("compound_interest", compoundInterest)
```