# Apache Spark SQL - User-Defined Functions (UDFs)

## Capabilities

### User-Defined Function Registration and Management
- Register scalar user-defined functions (UDFs) that take 0-22 arguments, with argument and result types inferred from the Scala function signature
- Create and manage user-defined aggregate functions (UDAFs) for custom aggregation logic
- Register functions for the current session with `spark.udf.register`; persistent, catalog-backed functions are created in SQL with `CREATE FUNCTION ... USING JAR`
- Validate argument counts and types when a query is analyzed; registering an existing name again replaces the earlier definition, because UDFs cannot be overloaded by name
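
### Type-Safe Function Development
- Define functions with explicit input and output types; Spark derives the SQL data types from the Scala types through reflection and encoders
- Use complex data types, including arrays, maps, structs (case classes), and user-defined types, in function signatures
- Control null handling:
  - For a primitive parameter such as `Int` or `Double`, Spark returns null for a null input without calling the function.
  - A reference-type parameter (`String`, `java.lang.Integer`) receives the null and must handle it.
  - Returning `Option[T]` produces a nullable result.
- Rely on the compiler to type-check the Scala function body; the columns it is applied to are checked when the query is analyzed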
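
Null handling depends on the Scala parameter types. The minimal sketch below assumes Spark 3.x with a local session (`local[1]`); `plusOne` and `plusOneBoxed` are illustrative names.

- The first UDF takes a primitive `Int`, so for the null row Spark returns null without calling it.
- The second takes `java.lang.Integer`, receives the null, and returns its own fallback.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.udf

// Local session for the sketch (assumption: Spark 3.x on the classpath)
val spark = SparkSession.builder().master("local[1]").appName("udf-null-sketch").getOrCreate()
import spark.implicits._

// Primitive parameter: Spark skips the call and yields null when x is null
val plusOne = udf((x: Int) => x + 1)

// Boxed parameter: the function receives null and chooses a fallback (-1 here)
val plusOneBoxed = udf((x: java.lang.Integer) => if (x == null) -1 else x + 1)

val rows = Seq(Some(1), None).toDF("x")
  .select(plusOne($"x").as("primitive"), plusOneBoxed($"x").as("boxed"))
  .collect()

rows.foreach(println)
```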

### Function Optimization and Performance
- Mark functions as nondeterministic so the optimizer does not assume that repeated calls return the same result
- Account for the cost model: Scala UDFs are opaque to the Catalyst optimizer and are invoked row by row, so equivalent built-in functions are usually faster
- Use broadcast variables and accumulators within UDFs to share lookup data and collect diagnostics across the cluster
- Keep filters on raw columns where possible: predicates that call a UDF cannot be pushed down to data sources, because the source cannot evaluate arbitrary JVM code
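
The difference between a UDF and a built-in function shows up when both are applied to the same column. This sketch assumes a local Spark 3.x session, and `upperUdf` is an illustrative name.

- Both result columns hold the same values.
- Only the built-in `upper` is an expression Catalyst can reason about, and it is null-safe without extra code.
- The UDF appears in the plan as an opaque UDF call.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{udf, upper}

val spark = SparkSession.builder().master("local[1]").appName("builtin-vs-udf-sketch").getOrCreate()
import spark.implicits._

// Opaque UDF: Catalyst cannot inspect the lambda, so it must guard null itself
val upperUdf = udf((s: String) => if (s == null) null else s.toUpperCase)

val df = Seq("spark", null, "sql").toDF("word")
val result = df.select(upper($"word").as("builtin"), upperUdf($"word").as("udf")).collect()

// Compare the two expressions in the optimized plan
df.select(upper($"word"), upperUdf($"word")).explain(true)
```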

### Advanced Function Features
- Accept and return array and map values, so one UDF call can process a whole collection stored in a row (Scala UDFs are still invoked row by row, not vectorized)
- Use higher-order functions (`transform`, `filter`, `exists`, `forall`, `aggregate`) with lambda expressions inside SQL and DataFrame expressions
- Share read-only context with UDFs through closures and broadcast variables. Keep the UDFs themselves stateless, because Spark may evaluate them on any executor, in any order, and more than once.
- Compose and chain functions to build complex transformation pipelines

## API Reference

### UDFRegistration Class
```scala { .api }
// Accessed as spark.udf; registrations last for the current SparkSession
abstract class UDFRegistration {
  // Scalar UDF registration (overloads exist for 0-22 arguments; 0-10 shown)
  def register[RT: TypeTag](name: String, func: () => RT): UserDefinedFunction
  def register[RT: TypeTag, A1: TypeTag](name: String, func: A1 => RT): UserDefinedFunction
  def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](name: String, func: (A1, A2) => RT): UserDefinedFunction
  def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](name: String, func: (A1, A2, A3) => RT): UserDefinedFunction
  def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](name: String, func: (A1, A2, A3, A4) => RT): UserDefinedFunction
  def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](name: String, func: (A1, A2, A3, A4, A5) => RT): UserDefinedFunction
  def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](name: String, func: (A1, A2, A3, A4, A5, A6) => RT): UserDefinedFunction
  def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](name: String, func: (A1, A2, A3, A4, A5, A6, A7) => RT): UserDefinedFunction
  def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](name: String, func: (A1, A2, A3, A4, A5, A6, A7, A8) => RT): UserDefinedFunction
  def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](name: String, func: (A1, A2, A3, A4, A5, A6, A7, A8, A9) => RT): UserDefinedFunction
  def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](name: String, func: (A1, A2, A3, A4, A5, A6, A7, A8, A9, A10) => RT): UserDefinedFunction
  // ... overloads for 11-22 arguments follow the same pattern

  // Register an existing UserDefinedFunction, including one created with functions.udaf
  def register(name: String, udf: UserDefinedFunction): UserDefinedFunction

  // Legacy untyped UDAF registration (deprecated since Spark 3.0; prefer functions.udaf with an Aggregator)
  def register(name: String, udaf: UserDefinedAggregateFunction): UserDefinedAggregateFunction
}
```
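
This is a sketch of the registration semantics. It assumes a local Spark 3.x session, and `times_two` is a hypothetical function name.

- `register` returns a `UserDefinedFunction` that also works in the DataFrame API.
- Registering the same name again replaces the previous definition for the rest of the session.

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[1]").appName("udf-register-sketch").getOrCreate()
import spark.implicits._

// Register a one-argument UDF for SQL; the returned value is usable from the DataFrame API too
val timesTwo = spark.udf.register("times_two", (x: Int) => x * 2)

val fromSql = spark.sql("SELECT times_two(21) AS v").head().getInt(0)
val fromApi = Seq(21).toDF("x").select(timesTwo($"x")).head().getInt(0)

// Re-registering the same name replaces the earlier definition in this session
spark.udf.register("times_two", (x: Int) => x * 2 + 1)
val afterReplace = spark.sql("SELECT times_two(21) AS v").head().getInt(0)

println(s"$fromSql $fromApi $afterReplace")
```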

### UserDefinedFunction Class
```scala { .api }
// Spark 3.x signature; returned by functions.udf and spark.udf.register
sealed abstract class UserDefinedFunction {
  // Properties
  def nullable: Boolean
  def deterministic: Boolean

  // Function application
  def apply(exprs: Column*): Column

  // Configuration methods (each returns a new, reconfigured instance)
  def withName(name: String): UserDefinedFunction
  def asNonNullable(): UserDefinedFunction
  def asNondeterministic(): UserDefinedFunction
}
```
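
The configuration methods return new instances rather than mutating the receiver, which the following sketch checks.

- It needs only `spark-sql` on the classpath: creating a `UserDefinedFunction` does not require a running session.
- The names are illustrative.
- A UDF that returns a primitive such as `Int` is already non-nullable, so the example uses a `String` result.

```scala
import org.apache.spark.sql.functions.udf

// String result: nullable and deterministic by default
val shout = udf((s: String) => if (s == null) null else s.toUpperCase + "!")

// Each call returns a new, reconfigured UserDefinedFunction; `shout` itself is unchanged
val shoutNamed = shout.withName("shout")
val shoutStrict = shout.asNonNullable()
val shoutRandomized = shout.asNondeterministic()

println(Seq(shout, shoutNamed, shoutStrict, shoutRandomized).map(u => (u.nullable, u.deterministic)))
```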

### functions Object UDF Creation
```scala { .api }
object functions {
  // UDF creation (overloads exist for 0-10 arguments)
  def udf[RT: TypeTag](f: () => RT): UserDefinedFunction
  def udf[RT: TypeTag, A1: TypeTag](f: A1 => RT): UserDefinedFunction
  def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: (A1, A2) => RT): UserDefinedFunction
  def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](f: (A1, A2, A3) => RT): UserDefinedFunction
  def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](f: (A1, A2, A3, A4) => RT): UserDefinedFunction
  def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](f: (A1, A2, A3, A4, A5) => RT): UserDefinedFunction
  def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](f: (A1, A2, A3, A4, A5, A6) => RT): UserDefinedFunction
  def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](f: (A1, A2, A3, A4, A5, A6, A7) => RT): UserDefinedFunction
  def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](f: (A1, A2, A3, A4, A5, A6, A7, A8) => RT): UserDefinedFunction
  def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](f: (A1, A2, A3, A4, A5, A6, A7, A8, A9) => RT): UserDefinedFunction
  def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](f: (A1, A2, A3, A4, A5, A6, A7, A8, A9, A10) => RT): UserDefinedFunction

  // Wrap a typed Aggregator as an untyped UDAF (Spark 3.0+)
  def udaf[IN: TypeTag, BUF, OUT](agg: Aggregator[IN, BUF, OUT]): UserDefinedFunction

  // Call a UDF registered by name (call_udf is Spark 3.2+; callUDF is its deprecated predecessor)
  def call_udf(udfName: String, cols: Column*): Column
  def callUDF(udfName: String, cols: Column*): Column
}
```
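
`call_udf` looks up a function by the name it was registered under, which helps when that name comes from configuration. This is a minimal sketch with three assumptions:

- Spark 3.2 or later (earlier versions offer `callUDF` instead).
- A local session.
- `str_len`, a hypothetical function name.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{call_udf, col, udf}

val spark = SparkSession.builder().master("local[1]").appName("call-udf-sketch").getOrCreate()
import spark.implicits._

// Create the UDF with functions.udf, then register it under a SQL name
val strLen = udf((s: String) => if (s == null) 0 else s.length)
spark.udf.register("str_len", strLen)

// Invoke it by name from the DataFrame API
val lengths = Seq("a", "abc", null).toDF("s")
  .select(call_udf("str_len", col("s")).as("len"))
  .as[Int]
  .collect()
  .toSeq

println(lengths)
```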

### User-Defined Aggregate Functions (UDAFs)
```scala { .api }
// Base class for typed UDAFs (org.apache.spark.sql.expressions.Aggregator)
abstract class Aggregator[-IN, BUF, OUT] extends Serializable {
  // Buffer operations
  def zero: BUF                               // Initial buffer; should satisfy merge(b, zero) == b
  def reduce(buffer: BUF, input: IN): BUF     // Fold one input value into a buffer
  def merge(buffer1: BUF, buffer2: BUF): BUF  // Combine two partial buffers
  def finish(buffer: BUF): OUT                // Produce the final result

  // Encoders
  def bufferEncoder: Encoder[BUF]
  def outputEncoder: Encoder[OUT]

  // Convert to a TypedColumn for typed Dataset aggregation
  // (for untyped DataFrames and SQL, wrap with functions.udaf instead)
  def toColumn: TypedColumn[IN, OUT]
}

// Legacy untyped UDAF base class (deprecated since Spark 3.0)
abstract class UserDefinedAggregateFunction extends Serializable {
  // Data types
  def inputSchema: StructType
  def bufferSchema: StructType
  def dataType: DataType
  def deterministic: Boolean

  // Aggregation operations
  def initialize(buffer: MutableAggregationBuffer): Unit
  def update(buffer: MutableAggregationBuffer, input: Row): Unit
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit
  def evaluate(buffer: Row): Any
}
```
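
An `Aggregator` is an ordinary Scala class, so its `zero`/`reduce`/`merge`/`finish` contract can be exercised without a cluster. That makes it easy to check that partial results merge correctly.

The sketch below computes an average over `Long` inputs; `AverageAgg` is an illustrative name. It simulates Spark splitting the input into partitions, reducing each partition from `zero`, and merging the partial buffers. The merged result should equal a single-pass reduction.

```scala
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

// Average of Long values; buffer is (running sum, count)
object AverageAgg extends Aggregator[Long, (Long, Long), Double] {
  def zero: (Long, Long) = (0L, 0L)
  def reduce(b: (Long, Long), x: Long): (Long, Long) = (b._1 + x, b._2 + 1)
  def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) = (b1._1 + b2._1, b1._2 + b2._2)
  def finish(b: (Long, Long)): Double = if (b._2 == 0) 0.0 else b._1.toDouble / b._2
  def bufferEncoder: Encoder[(Long, Long)] = Encoders.product[(Long, Long)]
  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

val data = (1L to 10L).toSeq

// Single pass, as if all rows were in one partition
val singlePass = AverageAgg.finish(data.foldLeft(AverageAgg.zero)(AverageAgg.reduce))

// Simulated partitions of 3 rows: reduce each from zero, then merge the partial buffers
val partials = data.grouped(3).map(_.foldLeft(AverageAgg.zero)(AverageAgg.reduce)).toSeq
val merged = AverageAgg.finish(partials.foldLeft(AverageAgg.zero)(AverageAgg.merge))

println(s"single=$singlePass merged=$merged") // single=5.5 merged=5.5
```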

### Higher-Order Column Functions
```scala { .api }
// Higher-order functions are defined in the functions object (Spark 3.0+), not on Column;
// the lambda arguments are Columns that stand for array elements
object functions {
  def transform(column: Column, f: Column => Column): Column
  def transform(column: Column, f: (Column, Column) => Column): Column  // (element, index)
  def filter(column: Column, f: Column => Column): Column
  def exists(column: Column, f: Column => Column): Column
  def forall(column: Column, f: Column => Column): Column
  def aggregate(expr: Column, initialValue: Column, merge: (Column, Column) => Column): Column
  def aggregate(expr: Column, initialValue: Column, merge: (Column, Column) => Column, finish: Column => Column): Column
  def array_sort(e: Column, comparator: (Column, Column) => Column): Column  // Spark 3.4+
}
```
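
The higher-order functions take ordinary Scala lambdas over `Column`s and turn them into SQL lambda expressions, so simple element-wise logic needs no UDF. This is a minimal sketch assuming a local Spark 3.x session and a hypothetical single-row array column `xs`:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{aggregate, exists, filter, lit, transform}

val spark = SparkSession.builder().master("local[1]").appName("hof-sketch").getOrCreate()
import spark.implicits._

val df = Seq(Seq(3, 1, 2)).toDF("xs")

val row = df.select(
  transform($"xs", x => x * 2).as("doubled"),              // each element times two
  transform($"xs", (x, i) => x + i).as("plus_index"),      // element plus its position
  filter($"xs", x => x > 1).as("gt_one"),                  // keep elements above 1
  exists($"xs", x => x === 2).as("has_two"),               // any element equal to 2
  aggregate($"xs", lit(0), (acc, x) => acc + x, acc => acc * 10).as("sum_times_ten")
).head()

println(row)
```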

## Usage Examples

### Basic UDF Creation and Registration
```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.udf

val spark = SparkSession.builder()
  .appName("UDF Examples")
  .getOrCreate()

import spark.implicits._

// Sample data
val employeeData = Seq(
  ("Alice Johnson", "alice.johnson@company.com", 25, 75000.0),
  ("Bob Smith", "bob.smith@company.com", 30, 85000.0),
  ("Charlie Brown", "charlie.brown@company.com", 35, 95000.0)
).toDF("name", "email", "age", "salary")

// Simple UDF examples
val extractFirstName = udf((fullName: String) => {
  if (fullName != null && fullName.contains(" ")) {
    fullName.split(" ")(0)
  } else {
    fullName
  }
})

val calculateBonus = udf((salary: Double, age: Int) => {
  val ageMultiplier = if (age > 30) 0.15 else 0.10
  salary * ageMultiplier
})

val isValidEmail = udf((email: String) => {
  email != null && email.contains("@") && email.contains(".")
})

// Apply UDFs to DataFrame
val enrichedData = employeeData.select(
  $"name",
  extractFirstName($"name").as("first_name"),
  $"email",
  isValidEmail($"email").as("valid_email"),
  $"age",
  $"salary",
  calculateBonus($"salary", $"age").as("bonus")
)

enrichedData.show()

// Register UDFs for SQL usage
spark.udf.register("extract_first_name", extractFirstName)
spark.udf.register("calculate_bonus", calculateBonus)
spark.udf.register("is_valid_email", isValidEmail)

// Use registered UDFs in SQL
employeeData.createOrReplaceTempView("employees")

val sqlResult = spark.sql("""
  SELECT
    name,
    extract_first_name(name) as first_name,
    email,
    is_valid_email(email) as valid_email,
    age,
    salary,
    calculate_bonus(salary, age) as bonus
  FROM employees
""")

sqlResult.show()
```
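
One practical pattern, sketched here on the assumption that you control the function body, is to keep UDF logic in a plain Scala object and wrap it with `udf` only at the boundary.

- The logic can be unit-tested without starting Spark.
- The closure shipped to executors refers to a top-level object rather than capturing an enclosing class.

`NameLogic` is an illustrative name. It applies the same rule as `extractFirstName`.

```scala
import org.apache.spark.sql.functions.udf

object NameLogic {
  // The text before the first space, or the input unchanged (including null)
  def firstName(fullName: String): String =
    if (fullName != null && fullName.contains(" ")) fullName.split(" ")(0) else fullName
}

// Wrap at the boundary; creating the UDF does not require a running SparkSession
val extractFirstNameUdf = udf(NameLogic.firstName _)

// Plain unit checks, no Spark involved
println(NameLogic.firstName("Alice Johnson"))
```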

### Advanced UDF Features
```scala
// UDF with a complex (struct) return type: case class fields become struct fields
case class PersonInfo(firstName: String, domain: String, salaryGrade: String)

val extractPersonInfo = udf((name: String, email: String, salary: Double) => {
  val firstName = if (name != null && name.contains(" ")) {
    name.split(" ")(0)
  } else {
    name
  }

  val domain = if (email != null && email.contains("@")) {
    email.split("@")(1)
  } else {
    "unknown"
  }

  val salaryGrade = salary match {
    case s if s < 50000 => "Entry"
    case s if s < 80000 => "Mid"
    case s if s < 100000 => "Senior"
    case _ => "Executive"
  }

  PersonInfo(firstName, domain, salaryGrade)
})

// Apply the UDF, then flatten the struct fields into top-level columns
val personInfoData = employeeData.select(
  $"name",
  $"email",
  $"salary",
  extractPersonInfo($"name", $"email", $"salary").as("person_info")
).select(
  $"name",
  $"email",
  $"salary",
  $"person_info.firstName",
  $"person_info.domain",
  $"person_info.salaryGrade"
)

personInfoData.show()

// UDFs with array input and output (Seq maps to ArrayType)
val extractNameParts = udf((fullName: String) => {
  if (fullName != null) {
    fullName.split(" ").toSeq
  } else {
    Seq.empty[String]
  }
})

val joinNameParts = udf((parts: Seq[String]) => {
  if (parts != null && parts.nonEmpty) {
    parts.mkString(" ")
  } else {
    null
  }
})

// UDF with Map output. All values must share one type, so the flag is stored as a string;
// a Map[String, Any] has no Spark SQL schema, so creating the UDF would fail.
val parseEmailInfo = udf((email: String) => {
  if (email != null && email.contains("@")) {
    val parts = email.split("@")
    Map(
      "username" -> parts(0),
      "domain" -> parts(1),
      "is_company_email" -> parts(1).contains("company").toString
    )
  } else {
    Map.empty[String, String]
  }
})

val arrayMapData = employeeData.select(
  $"name",
  $"email",
  extractNameParts($"name").as("name_parts"),
  joinNameParts(extractNameParts($"name")).as("rejoined_name"),
  parseEmailInfo($"email").as("email_info")
)

arrayMapData.show(truncate = false)
```
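
The struct returned by a case-class UDF has the same schema that Spark derives for that case class, so `Encoders.product` can show it ahead of time. This small sketch needs no SparkSession. It redefines the `PersonInfo` shape from the example above so that it runs on its own.

```scala
import org.apache.spark.sql.Encoders

// Same shape as the UDF's return type above, redefined so this block is self-contained
case class PersonInfo(firstName: String, domain: String, salaryGrade: String)

// The schema Spark uses for a PersonInfo value, e.g. the person_info struct column
val personSchema = Encoders.product[PersonInfo].schema
println(personSchema.treeString)
```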

### Null Handling and Error Management
```scala
import org.apache.spark.sql.functions.lit

// UDF with null handling: Option results become nullable columns (None => null).
// Both parameters are primitive Doubles, so Spark also returns null
// without calling the function when either input is null.
val safeDivide = udf((numerator: Double, denominator: Double) => {
  if (denominator != 0.0) {
    Some(numerator / denominator)
  } else {
    None
  }
})

// UDF with error handling: unparseable input becomes null instead of failing the task
val parseAge = udf((ageStr: String) => {
  try {
    if (ageStr != null && ageStr.nonEmpty) {
      Some(ageStr.toInt)
    } else {
      None
    }
  } catch {
    case _: NumberFormatException => None
  }
})

// UDF with validation: a tuple result becomes a struct with fields _1 and _2
val validateSalary = udf((salary: Double) => {
  salary match {
    case s if s < 0 => ("invalid", "Salary cannot be negative")
    case s if s > 1000000 => ("warning", "Salary seems unusually high")
    case s if s < 20000 => ("warning", "Salary seems unusually low")
    case _ => ("valid", "Salary is within normal range")
  }
})

// Test data with nulls and errors
val testData = Seq(
  ("John", "25", 50000.0),
  ("Jane", "abc", 75000.0),   // Invalid age
  ("Bob", null, -1000.0),     // Null age, negative salary
  ("Alice", "30", 2000000.0)  // High salary
).toDF("name", "age_str", "salary")

val validatedData = testData.select(
  $"name",
  $"age_str",
  parseAge($"age_str").as("parsed_age"),
  $"salary",
  validateSalary($"salary").as("salary_validation"),
  safeDivide($"salary", lit(12.0)).as("monthly_salary")
)

validatedData.show(truncate = false)

// Non-nullable UDF: asNonNullable() only declares that the result is never null,
// so the function itself must guarantee it (here, null input maps to "")
val strictUpperCase = udf((input: String) => {
  if (input == null) "" else input.toUpperCase
}).asNonNullable()

// Deterministic UDF: the same input always yields the same hash, so keep the default flag
val deterministicHash = udf((input: String) => {
  if (input == null) 0 else input.hashCode
})

// Non-deterministic UDF: each call returns a new value, so mark it explicitly
val nonDeterministicId = udf(() => {
  java.util.UUID.randomUUID().toString
}).asNondeterministic()
```
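
The `parseAge` logic can also be written with `Option` and `scala.util.Try`. That keeps the null, blank, and non-numeric cases in one expression, and the function can be tested without Spark. In this sketch, `AgeParsing.parseAgeLogic` is an illustrative name. Unlike the version above, it also trims surrounding whitespace.

```scala
import scala.util.Try
import org.apache.spark.sql.functions.udf

object AgeParsing {
  // None for null, blank, or non-numeric input; surrounding whitespace is ignored
  def parseAgeLogic(ageStr: String): Option[Int] =
    Option(ageStr).map(_.trim).filter(_.nonEmpty).flatMap(s => Try(s.toInt).toOption)
}

// Option[Int] result: None becomes null in the output column
val parseAgeUdf = udf(AgeParsing.parseAgeLogic _)

println(Seq("25", " 30 ", "abc", "", null).map(AgeParsing.parseAgeLogic))
```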

### User-Defined Aggregate Functions (UDAFs)
```scala
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.udaf

// Simple UDAF: geometric mean of positive values.
// The buffer holds (sum of logs, count); multiplying raw values instead
// would overflow to Infinity for long inputs.
class GeometricMean extends Aggregator[Double, (Double, Long), Double] {
  def zero: (Double, Long) = (0.0, 0L)

  def reduce(buffer: (Double, Long), input: Double): (Double, Long) = {
    if (input > 0) {
      (buffer._1 + math.log(input), buffer._2 + 1)
    } else {
      buffer // Skip non-positive values, which have no logarithm
    }
  }

  def merge(b1: (Double, Long), b2: (Double, Long)): (Double, Long) = {
    (b1._1 + b2._1, b1._2 + b2._2)
  }

  def finish(buffer: (Double, Long)): Double = {
    if (buffer._2 > 0) {
      math.exp(buffer._1 / buffer._2)
    } else {
      0.0
    }
  }

  def bufferEncoder: Encoder[(Double, Long)] = Encoders.product
  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

// Complex UDAF: summary statistics.
// StatsSummary is the buffer. StatsResult is the output; its derived values are
// stored as fields, because methods on a case class do not appear in the result struct.
case class StatsSummary(count: Long, sum: Double, sumSquares: Double, min: Double, max: Double)

case class StatsResult(count: Long, sum: Double, mean: Double, variance: Double, stddev: Double, min: Double, max: Double)

class SummaryStats extends Aggregator[Double, StatsSummary, StatsResult] {
  def zero: StatsSummary = StatsSummary(0L, 0.0, 0.0, Double.MaxValue, Double.MinValue)

  def reduce(buffer: StatsSummary, input: Double): StatsSummary = {
    StatsSummary(
      count = buffer.count + 1,
      sum = buffer.sum + input,
      sumSquares = buffer.sumSquares + input * input,
      min = math.min(buffer.min, input),
      max = math.max(buffer.max, input)
    )
  }

  def merge(b1: StatsSummary, b2: StatsSummary): StatsSummary = {
    if (b1.count == 0) b2
    else if (b2.count == 0) b1
    else StatsSummary(
      count = b1.count + b2.count,
      sum = b1.sum + b2.sum,
      sumSquares = b1.sumSquares + b2.sumSquares,
      min = math.min(b1.min, b2.min),
      max = math.max(b1.max, b2.max)
    )
  }

  def finish(b: StatsSummary): StatsResult = {
    if (b.count == 0) {
      StatsResult(0L, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
    } else {
      val mean = b.sum / b.count
      // Sample variance; defined as 0.0 for a single value
      val variance = if (b.count > 1) (b.sumSquares - b.sum * b.sum / b.count) / (b.count - 1) else 0.0
      StatsResult(b.count, b.sum, mean, variance, math.sqrt(variance), b.min, b.max)
    }
  }

  def bufferEncoder: Encoder[StatsSummary] = Encoders.product
  def outputEncoder: Encoder[StatsResult] = Encoders.product
}

// Wrap the Aggregators as untyped UDAFs (Spark 3.0+) so they can be applied
// to DataFrame columns and registered for SQL
val geometricMean = udaf(new GeometricMean)
val summaryStats = udaf(new SummaryStats)
spark.udf.register("geometric_mean", geometricMean)
spark.udf.register("summary_stats", summaryStats)

// Apply UDAFs
val aggregatedData = employeeData.agg(
  geometricMean($"salary").as("geometric_mean_salary"),
  summaryStats($"salary").as("salary_stats")
)

aggregatedData.show(truncate = false)

// Extract fields from the struct result
val detailedStats = employeeData
  .agg(summaryStats($"salary").as("stats"))
  .select(
    $"stats.count",
    $"stats.sum",
    $"stats.mean",
    $"stats.variance",
    $"stats.stddev",
    $"stats.min",
    $"stats.max"
  )

detailedStats.show()

// The registered UDAF is also callable from SQL
spark.sql("SELECT geometric_mean(salary) AS gm FROM employees").show()
```
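
A geometric mean computed from a running product of raw values overflows quickly. For example, 400 values of 1000 multiply to 10^1200, far beyond the largest `Double` (roughly 1.8e308). Averaging logarithms and then exponentiating gives the same mathematical result without overflow. The following is a plain-Scala check, independent of Spark:

```scala
val values = Seq.fill(400)(1000.0)

// Running product: overflows to Infinity, so its root is Infinity as well
val viaProduct = math.pow(values.product, 1.0 / values.size)

// Mean of logarithms, then exp: stays within Double range
val viaLogs = math.exp(values.map(math.log).sum / values.size)

println(s"product-based: $viaProduct, log-based: $viaLogs")
```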

### Higher-Order Functions and Functional UDFs
```scala
import org.apache.spark.sql.functions._

// Data with arrays for higher-order function examples
val arrayData = Seq(
  ("Alice", Array(85, 90, 88, 92)),
  ("Bob", Array(78, 85, 80, 88)),
  ("Charlie", Array(95, 88, 90, 94))
).toDF("name", "test_scores")

// UDFs that operate on single array elements
val scaleScore = udf((score: Int, scaleFactor: Double) => (score * scaleFactor).toInt)
val calculateGrade = udf((score: Int) => {
  score match {
    case s if s >= 90 => "A"
    case s if s >= 80 => "B"
    case s if s >= 70 => "C"
    case s if s >= 60 => "D"
    case _ => "F"
  }
})

// Higher-order functions, with UDFs applied inside the lambdas
val transformedScores = arrayData.select(
  $"name",
  $"test_scores",
  // Transform each score
  transform($"test_scores", score => scaleScore(score, lit(1.05))).as("scaled_scores"),
  // Filter scores above threshold
  filter($"test_scores", score => score > 85).as("high_scores"),
  // Check if all scores are passing
  forall($"test_scores", score => score >= 60).as("all_passing"),
  // Check if any score is excellent
  exists($"test_scores", score => score >= 95).as("has_excellent")
)

transformedScores.show(truncate = false)

// Aggregations over array elements
val scoreAnalysis = arrayData.select(
  $"name",
  $"test_scores",
  // Average score: sum the elements as a Double, then divide by the array size in the finish step
  aggregate($"test_scores", lit(0.0),
    (acc, score) => acc + score,
    acc => acc / size($"test_scores")).as("avg_score"),
  // Find maximum score
  aggregate($"test_scores", lit(0),
    (acc, score) => greatest(acc, score)).as("max_score"),
  // Map each score to a letter grade
  transform($"test_scores", score => calculateGrade(score)).as("letter_grades")
)

scoreAnalysis.show(truncate = false)
```
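
When the per-element logic is simple, a built-in expression inside the lambda avoids the UDF call altogether and stays visible to the optimizer. This sketch assumes a local Spark 3.x session. It reproduces the `calculateGrade` thresholds with `when`/`otherwise`.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{transform, when}

val spark = SparkSession.builder().master("local[1]").appName("builtin-grades-sketch").getOrCreate()
import spark.implicits._

val scores = Seq(Seq(95, 85, 72, 64, 40)).toDF("test_scores")

// Same thresholds as calculateGrade, expressed as a CASE WHEN chain
val grades = scores.select(
  transform($"test_scores", s =>
    when(s >= 90, "A").when(s >= 80, "B").when(s >= 70, "C").when(s >= 60, "D").otherwise("F")
  ).as("letter_grades")
).head().getSeq[String](0)

println(grades)
```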

### Performance Optimization and Best Practices
```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.util.AccumulatorV2

// UDF with a broadcast variable: the lookup table is shipped to each executor once
// instead of being serialized with every task
val taxRates = Map("US" -> 0.25, "UK" -> 0.20, "CA" -> 0.15, "DE" -> 0.30)
val broadcastTaxRates = spark.sparkContext.broadcast(taxRates)

val calculateNetSalary = udf((grossSalary: Double, country: String) => {
  val taxRate = broadcastTaxRates.value.getOrElse(country, 0.0)
  grossSalary * (1.0 - taxRate)
})

// Sample international employee data
val intlEmployees = Seq(
  ("Alice", 75000.0, "US"),
  ("Bob", 65000.0, "UK"),
  ("Charlie", 70000.0, "CA"),
  ("Diana", 80000.0, "DE")
).toDF("name", "gross_salary", "country")

val netSalaries = intlEmployees.select(
  $"name",
  $"gross_salary",
  $"country",
  calculateNetSalary($"gross_salary", $"country").as("net_salary")
)

netSalaries.show()

// UDF with an accumulator for monitoring.
// Updates made inside transformations (including UDFs) can be applied more than once
// if tasks are retried, so treat the collected value as diagnostic only.
class StringAccumulator extends AccumulatorV2[String, String] {
  private var _value = ""

  def isZero: Boolean = _value.isEmpty
  def copy(): StringAccumulator = {
    val acc = new StringAccumulator
    acc._value = _value
    acc
  }
  def reset(): Unit = _value = ""
  def add(v: String): Unit = _value += v + "\n"
  def merge(other: AccumulatorV2[String, String]): Unit = _value += other.value
  def value: String = _value
}

val errorAccumulator = new StringAccumulator
spark.sparkContext.register(errorAccumulator, "UDF Errors")

val robustProcessing = udf((data: String) => {
  try {
    if (data == null || data.isEmpty) {
      errorAccumulator.add("Empty data encountered")
      "EMPTY"
    } else if (data.length < 3) {
      errorAccumulator.add(s"Short data: $data")
      "SHORT"
    } else {
      data.toUpperCase
    }
  } catch {
    case e: Exception =>
      errorAccumulator.add(s"Error processing '$data': ${e.getMessage}")
      "ERROR"
  }
})

// Accumulators are updated only when an action runs, so trigger one before reading the value
val processedData = Seq("spark", "", "ab", null, "sql").toDF("data")
  .select($"data", robustProcessing($"data").as("processed"))
processedData.show()

// UDF over an array column. Scala UDFs are still invoked once per row (they are not
// vectorized); the built-in transform function could compute the same result.
val stringLengths = udf((strings: Seq[String]) => {
  strings.map(s => if (s != null) s.length else 0)
})

// Array data example
val testStrings = Seq(
  Array("hello", "world", null, "spark", "sql"),
  Array("scala", "java", "", "python"),
  Array(null, "data", "engineering")
).toDF("string_array")

val lengthResults = testStrings.select(
  $"string_array",
  stringLengths($"string_array").as("lengths")
)

lengthResults.show(truncate = false)

// UDF lifecycle management: track session-scoped registrations so they can be dropped later
object UDFRegistry {
  private val registeredUDFs = scala.collection.mutable.Set[String]()

  def registerUDF(spark: SparkSession, name: String, udf: UserDefinedFunction): Unit = {
    spark.udf.register(name, udf)
    registeredUDFs += name
    println(s"Registered UDF: $name")
  }

  def listRegisteredUDFs(): Set[String] = registeredUDFs.toSet

  def cleanup(spark: SparkSession): Unit = {
    // Functions registered with spark.udf.register are temporary functions of this session,
    // so DROP TEMPORARY FUNCTION removes them
    registeredUDFs.foreach { name =>
      spark.sql(s"DROP TEMPORARY FUNCTION IF EXISTS $name")
      println(s"Dropped UDF: $name")
    }
    registeredUDFs.clear()
  }
}

// Register UDFs through the registry
UDFRegistry.registerUDF(spark, "calculate_net_salary", calculateNetSalary)
UDFRegistry.registerUDF(spark, "robust_processing", robustProcessing)

// Use registered UDFs
intlEmployees.createOrReplaceTempView("international_employees")

val sqlWithUDF = spark.sql("""
  SELECT
    name,
    gross_salary,
    country,
    calculate_net_salary(gross_salary, country) as net_salary
  FROM international_employees
""")

sqlWithUDF.show()

// Display registry status and collected diagnostics
println("Registered UDFs: " + UDFRegistry.listRegisteredUDFs().mkString(", "))
println("Error accumulator value:")
println(errorAccumulator.value)

// Cleanup
UDFRegistry.cleanup(spark)
```
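
For simple counters, Spark's built-in `LongAccumulator` avoids writing a custom `AccumulatorV2`. This sketch assumes a local Spark 3.x session; `emptyInputs` and `countEmpty` are illustrative names.

- The value is read only after an action (`collect`) has run.
- Updates made inside transformations can be over-counted if tasks are retried.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.udf

val spark = SparkSession.builder().master("local[1]").appName("accumulator-sketch").getOrCreate()
import spark.implicits._

// Named accumulators also appear in the Spark UI
val emptyInputs = spark.sparkContext.longAccumulator("empty_inputs")

// Counts null or empty inputs while returning a trimmed, non-null string
val countEmpty = udf((s: String) => {
  if (s == null || s.isEmpty) emptyInputs.add(1L)
  if (s == null) "" else s.trim
})

// collect() is the action that runs the UDF and reports accumulator updates to the driver
Seq("a", "", null, " b ").toDF("s").select(countEmpty($"s").as("cleaned")).collect()

println(emptyInputs.value)
```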

### UDF Testing and Validation
```scala
import org.apache.spark.sql.{DataFrame, Encoder, SparkSession}
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.col

// UDF testing helper: each input value (or tuple of values) becomes one row,
// and the UDF is applied to all of that row's columns in order
class UDFTester(spark: SparkSession) {
  import spark.implicits._

  def testUDF[T: Encoder, R](udf: UserDefinedFunction, testCases: Seq[(T, R)]): Boolean = {
    // Single values yield one column named "value"; tuples yield columns _1, _2, ...
    val testData = testCases.map(_._1).toDF()
    val expectedResults = testCases.map(_._2)

    val inputCols = testData.columns.toSeq.map(c => col(c))
    val results = testData.select(udf(inputCols: _*).as("result")).collect().toSeq.map(_.getAs[R]("result"))

    val passed = results.zip(expectedResults).forall { case (actual, expected) =>
      actual == expected
    }

    if (passed) {
      println(s"✅ All ${testCases.length} test cases passed")
    } else {
      println("❌ Some test cases failed")
      results.zip(expectedResults).zipWithIndex.foreach { case ((actual, expected), idx) =>
        if (actual != expected) {
          println(s"  Test case $idx: expected $expected, got $actual")
        }
      }
    }

    passed
  }

  def benchmarkUDF(udf: UserDefinedFunction, data: DataFrame, column: String, iterations: Int = 5): Unit = {
    val times = (1 to iterations).map { _ =>
      val start = System.nanoTime()
      // collect() forces evaluation; the time also includes moving results to the driver,
      // and the first iteration includes JVM warm-up
      data.select(udf(col(column))).collect()
      val end = System.nanoTime()
      (end - start) / 1000000.0 // Convert to milliseconds
    }

    val avgTime = times.sum / times.length
    val minTime = times.min
    val maxTime = times.max

    println(s"UDF Performance ($iterations iterations):")
    println(s"  Average: ${avgTime}ms")
    println(s"  Min: ${minTime}ms")
    println(s"  Max: ${maxTime}ms")
  }
}

// Test the UDFs
val tester = new UDFTester(spark)

// Test the extract-first-name UDF
val firstNameTests = Seq(
  ("John Doe", "John"),
  ("Alice", "Alice"),
  ("Bob Smith Jr", "Bob"),
  (null, null),
  ("", "")
)

tester.testUDF(extractFirstName, firstNameTests)

// Test the calculate-bonus UDF; tuple inputs become one column per argument.
// Exact Double equality holds for these inputs, but prefer a tolerance for computed values.
val bonusTests = Seq(
  ((50000.0, 25), 5000.0),
  ((80000.0, 35), 12000.0),
  ((60000.0, 30), 6000.0)
)

tester.testUDF(calculateBonus, bonusTests)

// Performance benchmarking
val largeDataset = (1 to 10000).map(i => s"Name $i").toDF("name")
tester.benchmarkUDF(extractFirstName, largeDataset, "name")

// UDF documentation generator. The public UserDefinedFunction API (Spark 3.x) exposes
// only the nullable and deterministic flags, so the generated page reports those two.
def generateUDFDocumentation(udfs: Map[String, (UserDefinedFunction, String)]): String = {
  val docs = udfs.map { case (name, (udf, description)) =>
    s"""
       |## $name
       |
       |**Description**: $description
       |
       |**Nullable**: ${udf.nullable}
       |**Deterministic**: ${udf.deterministic}
       |""".stripMargin
  }.mkString("\n")

  s"""
     |# UDF Documentation
     |
     |$docs
     |""".stripMargin
}

val udfDocs = Map(
  "extract_first_name" -> (extractFirstName, "Extracts the first name from a full name string"),
  "calculate_bonus" -> (calculateBonus, "Calculates employee bonus based on salary and age"),
  "is_valid_email" -> (isValidEmail, "Validates email format")
)

println(generateUDFDocumentation(udfDocs))
```