or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

catalog.md · data-sources.md · data-types.md · dataframe-dataset.md · index.md · session-management.md · sql-functions.md · streaming.md · udfs.md

udfs.mddocs/

0

# Apache Spark SQL - User-Defined Functions (UDFs)

1

2

## Capabilities

3

4

### User-Defined Function Registration and Management

5

- Register scalar user-defined functions (UDFs) with support for 0-22 parameters and type-safe operation

6

- Create and manage user-defined aggregate functions (UDAFs) for custom aggregation logic

7

- Support for both temporary session-scoped and persistent catalog-scoped function registration

8

- Handle function overloading and parameter type checking with comprehensive error reporting

9

10

### Type-Safe Function Development

11

- Define functions with explicit input and output types using Spark's type system and encoders

12

- Support for complex data types including arrays, maps, structs, and user-defined types in function signatures

13

- Enable null handling and error propagation with configurable behavior for edge cases

14

- Provide compile-time type checking for function parameters and return values

15

16

### Function Optimization and Performance

17

- Control function determinism flags for query optimization and caching behavior

18

- Support for code generation and vectorized execution for high-performance function evaluation

19

- Handle broadcast variables and accumulators within UDF context for distributed computation

20

- Enable function pushdown and predicate optimization where supported by data sources

21

22

### Advanced Function Features

23

- Create column-oriented functions that operate on entire columns for vectorized processing

24

- Support for higher-order functions and functional programming patterns within SQL expressions

25

- Handle state management and context passing for complex stateful function implementations

26

- Enable function composition and chaining for building complex transformation pipelines

27

28

## API Reference

29

30

### UDFRegistration Class

31

```scala { .api }

32

abstract class UDFRegistration {

33

// Scalar UDF registration (overloads exist for 0-22 parameters; 0-10 shown here)

34

def register[RT: TypeTag](name: String, func: () => RT): UserDefinedFunction

35

def register[RT: TypeTag, A1: TypeTag](name: String, func: A1 => RT): UserDefinedFunction

36

def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](name: String, func: (A1, A2) => RT): UserDefinedFunction

37

def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](name: String, func: (A1, A2, A3) => RT): UserDefinedFunction

38

def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](name: String, func: (A1, A2, A3, A4) => RT): UserDefinedFunction

39

def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](name: String, func: (A1, A2, A3, A4, A5) => RT): UserDefinedFunction

40

def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](name: String, func: (A1, A2, A3, A4, A5, A6) => RT): UserDefinedFunction

41

def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](name: String, func: (A1, A2, A3, A4, A5, A6, A7) => RT): UserDefinedFunction

42

def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](name: String, func: (A1, A2, A3, A4, A5, A6, A7, A8) => RT): UserDefinedFunction

43

def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](name: String, func: (A1, A2, A3, A4, A5, A6, A7, A8, A9) => RT): UserDefinedFunction

44

def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](name: String, func: (A1, A2, A3, A4, A5, A6, A7, A8, A9, A10) => RT): UserDefinedFunction

45

46

// Register with custom UserDefinedFunction

47

def register(name: String, udf: UserDefinedFunction): UserDefinedFunction

48

49

// UDAF registration

50

// Registers an untyped UDAF under `name` and returns it, so the same instance
// can also be applied through the DataFrame API.
def register(name: String, udaf: expressions.UserDefinedAggregateFunction): expressions.UserDefinedAggregateFunction

51

}

52

```

53

54

### UserDefinedFunction Class

55

```scala { .api }

56

case class UserDefinedFunction protected[sql] (

57

f: AnyRef,

58

dataType: DataType,

59

inputTypes: Seq[DataType],

60

name: Option[String] = None,

61

nullable: Boolean = true,

62

deterministic: Boolean = true) {

63

64

// Function application

65

def apply(exprs: Column*): Column

66

67

// Configuration methods

68

def withName(name: String): UserDefinedFunction

69

def asNonNullable(): UserDefinedFunction

70

def asNondeterministic(): UserDefinedFunction

71

72

// Type information

73

def isNullable: Boolean

74

def isDeterministic: Boolean

75

}

76

```

77

78

### functions Object UDF Creation

79

```scala { .api }

80

object functions {

81

// UDF creation functions (overloads exist for 0-22 parameters; 0-10 shown here)

82

def udf[RT: TypeTag](f: () => RT): UserDefinedFunction

83

def udf[RT: TypeTag, A1: TypeTag](f: A1 => RT): UserDefinedFunction

84

def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: (A1, A2) => RT): UserDefinedFunction

85

def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](f: (A1, A2, A3) => RT): UserDefinedFunction

86

def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](f: (A1, A2, A3, A4) => RT): UserDefinedFunction

87

def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](f: (A1, A2, A3, A4, A5) => RT): UserDefinedFunction

88

def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](f: (A1, A2, A3, A4, A5, A6) => RT): UserDefinedFunction

89

def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](f: (A1, A2, A3, A4, A5, A6, A7) => RT): UserDefinedFunction

90

def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](f: (A1, A2, A3, A4, A5, A6, A7, A8) => RT): UserDefinedFunction

91

def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](f: (A1, A2, A3, A4, A5, A6, A7, A8, A9) => RT): UserDefinedFunction

92

def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](f: (A1, A2, A3, A4, A5, A6, A7, A8, A9, A10) => RT): UserDefinedFunction

93

94

// Call registered UDF

95

def callUDF(udfName: String, cols: Column*): Column

96

}

97

```

98

99

### User-Defined Aggregate Functions (UDAFs)

100

```scala { .api }

101

// Base class for typed UDAFs

102

abstract class Aggregator[-IN, BUF, OUT] extends Serializable {

103

// Buffer operations

104

def zero: BUF

105

def reduce(buffer: BUF, input: IN): BUF

106

def merge(buffer1: BUF, buffer2: BUF): BUF

107

def finish(buffer: BUF): OUT

108

109

// Encoders

110

def bufferEncoder: Encoder[BUF]

111

def outputEncoder: Encoder[OUT]

112

113

// Convert to UDAF

114

def toColumn: TypedColumn[IN, OUT]

115

}

116

117

// Untyped UDAF base class

118

// Public, Row-based UDAF contract (org.apache.spark.sql.expressions).
// Deprecated since Spark 3.0 in favour of Aggregator registered via functions.udaf.
abstract class UserDefinedAggregateFunction extends Serializable {
  // Data types
  def inputSchema: StructType   // schema of the input arguments
  def bufferSchema: StructType  // schema of the intermediate aggregation buffer
  def dataType: DataType        // type of the final result
  def deterministic: Boolean    // true if the same input always yields the same output

  // Aggregation operations
  def initialize(buffer: MutableAggregationBuffer): Unit
  def update(buffer: MutableAggregationBuffer, input: Row): Unit
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit
  def evaluate(buffer: Row): Any

  // Column construction
  def apply(exprs: Column*): Column
  def distinct(exprs: Column*): Column
}

131

```

132

133

### Column Functions for UDFs

134

```scala { .api }

135

// Column operations that support UDFs

136

class Column {

137

// Transform operations

138

def transform(f: Column => Column): Column

139

140

// Higher-order function support

141

def filter(f: Column => Column): Column

142

def exists(f: Column => Column): Column

143

def forall(f: Column => Column): Column

144

def aggregate(initialValue: Column, merge: (Column, Column) => Column): Column

145

def aggregate(initialValue: Column, merge: (Column, Column) => Column, finish: Column => Column): Column

146

}

147

148

// Higher-order functions

149

object functions {

150

def transform(column: Column, function: Column => Column): Column

151

def filter(column: Column, function: Column => Column): Column

152

def exists(column: Column, function: Column => Column): Column

153

def forall(column: Column, function: Column => Column): Column

154

def aggregate(column: Column, initialValue: Column, merge: (Column, Column) => Column): Column

155

def aggregate(column: Column, initialValue: Column, merge: (Column, Column) => Column, finish: Column => Column): Column

156

def array_sort(e: Column, comparator: (Column, Column) => Column): Column

157

}

158

```

159

160

## Usage Examples

161

162

### Basic UDF Creation and Registration

163

```scala

164

import org.apache.spark.sql.{SparkSession, functions => F}

165

import org.apache.spark.sql.functions.udf

166

import org.apache.spark.sql.types._

167

168

val spark = SparkSession.builder()

169

.appName("UDF Examples")

170

.getOrCreate()

171

172

import spark.implicits._

173

174

// Sample data

175

val employeeData = Seq(

176

("Alice Johnson", "alice.johnson@company.com", 25, 75000.0),

177

("Bob Smith", "bob.smith@company.com", 30, 85000.0),

178

("Charlie Brown", "charlie.brown@company.com", 35, 95000.0)

179

).toDF("name", "email", "age", "salary")

180

181

// Simple UDF examples

182

val extractFirstName = udf((fullName: String) => {

183

if (fullName != null && fullName.contains(" ")) {

184

fullName.split(" ")(0)

185

} else {

186

fullName

187

}

188

})

189

190

val calculateBonus = udf((salary: Double, age: Int) => {

191

val ageMultiplier = if (age > 30) 0.15 else 0.10

192

salary * ageMultiplier

193

})

194

195

val isValidEmail = udf((email: String) => {

196

email != null && email.contains("@") && email.contains(".")

197

})

198

199

// Apply UDFs to DataFrame

200

val enrichedData = employeeData.select(

201

$"name",

202

extractFirstName($"name").as("first_name"),

203

$"email",

204

isValidEmail($"email").as("valid_email"),

205

$"age",

206

$"salary",

207

calculateBonus($"salary", $"age").as("bonus")

208

)

209

210

enrichedData.show()

211

212

// Register UDFs for SQL usage

213

spark.udf.register("extract_first_name", extractFirstName)

214

spark.udf.register("calculate_bonus", calculateBonus)

215

spark.udf.register("is_valid_email", isValidEmail)

216

217

// Use registered UDFs in SQL

218

employeeData.createOrReplaceTempView("employees")

219

220

val sqlResult = spark.sql("""

221

SELECT

222

name,

223

extract_first_name(name) as first_name,

224

email,

225

is_valid_email(email) as valid_email,

226

age,

227

salary,

228

calculate_bonus(salary, age) as bonus

229

FROM employees

230

""")

231

232

sqlResult.show()

233

```

234

235

### Advanced UDF Features

236

```scala

237

import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema

238

239

// UDF with complex return type

240

case class PersonInfo(firstName: String, domain: String, salaryGrade: String)

241

242

val extractPersonInfo = udf((name: String, email: String, salary: Double) => {

243

val firstName = if (name != null && name.contains(" ")) {

244

name.split(" ")(0)

245

} else {

246

name

247

}

248

249

val domain = if (email != null && email.contains("@")) {

250

email.split("@")(1)

251

} else {

252

"unknown"

253

}

254

255

val salaryGrade = salary match {

256

case s if s < 50000 => "Entry"

257

case s if s < 80000 => "Mid"

258

case s if s < 100000 => "Senior"

259

case _ => "Executive"

260

}

261

262

PersonInfo(firstName, domain, salaryGrade)

263

})

264

265

// Apply complex UDF

266

val personInfoData = employeeData.select(

267

$"name",

268

$"email",

269

$"salary",

270

extractPersonInfo($"name", $"email", $"salary").as("person_info")

271

).select(

272

$"name",

273

$"email",

274

$"salary",

275

$"person_info.firstName",

276

$"person_info.domain",

277

$"person_info.salaryGrade"

278

)

279

280

personInfoData.show()

281

282

// UDF with array input and output

283

val extractNameParts = udf((fullName: String) => {

284

if (fullName != null) {

285

fullName.split(" ").toSeq

286

} else {

287

Seq.empty[String]

288

}

289

})

290

291

val joinNameParts = udf((parts: Seq[String]) => {

292

if (parts != null && parts.nonEmpty) {

293

parts.mkString(" ")

294

} else {

295

null

296

}

297

})

298

299

// UDF with Map input/output

300

// Splits an email address into a string-to-string map with the keys
// "username", "domain" and "is_company_email".
// Every value is a String: Spark cannot derive a schema for Map[String, Any],
// so the boolean flag is stored as "true"/"false".
// Null input, or input without an "@", yields an empty map.
val parseEmailInfo = udf((email: String) => {
  // limit = -1 keeps trailing empty strings, so "user@" yields ("user", "")
  // instead of a one-element array that would make parts(1) fail.
  Option(email).map(_.split("@", -1)) match {
    case Some(Array(username, domain, _*)) =>
      Map(
        "username" -> username,
        "domain" -> domain,
        "is_company_email" -> domain.contains("company").toString
      )
    case _ =>
      Map.empty[String, String]
  }
})

312

313

val arrayMapData = employeeData.select(

314

$"name",

315

$"email",

316

extractNameParts($"name").as("name_parts"),

317

parseEmailInfo($"email").as("email_info")

318

)

319

320

arrayMapData.show(truncate = false)

321

```

322

323

### Null Handling and Error Management

324

```scala

325

// UDF with null handling

326

val safeDivide = udf((numerator: Double, denominator: Double) => {

327

if (denominator != 0.0) {

328

Some(numerator / denominator)

329

} else {

330

None

331

}

332

})

333

334

// UDF with error handling

335

val parseAge = udf((ageStr: String) => {

336

try {

337

if (ageStr != null && ageStr.nonEmpty) {

338

Some(ageStr.toInt)

339

} else {

340

None

341

}

342

} catch {

343

case _: NumberFormatException => None

344

}

345

})

346

347

// UDF with validation

348

val validateSalary = udf((salary: Double) => {

349

salary match {

350

case s if s < 0 => ("invalid", "Salary cannot be negative")

351

case s if s > 1000000 => ("warning", "Salary seems unusually high")

352

case s if s < 20000 => ("warning", "Salary seems unusually low")

353

case _ => ("valid", "Salary is within normal range")

354

}

355

})

356

357

// Test data with nulls and errors

358

val testData = Seq(

359

("John", "25", 50000.0),

360

("Jane", "abc", 75000.0), // Invalid age

361

("Bob", null, -1000.0), // Null age, negative salary

362

("Alice", "30", 2000000.0) // High salary

363

).toDF("name", "age_str", "salary")

364

365

// Applies the three defensive UDFs to the test rows.
val validatedData = testData.select(
  $"name",
  $"age_str",
  parseAge($"age_str").as("parsed_age"),
  $"salary",
  validateSalary($"salary").as("salary_validation"),
  // `lit` is reached through the `F` alias for org.apache.spark.sql.functions;
  // a bare `lit` is not in scope because only `functions.udf` was imported by name.
  safeDivide($"salary", F.lit(12.0)).as("monthly_salary")
)

373

374

validatedData.show(truncate = false)

375

376

// Non-nullable and deterministic UDFs

377

val strictUpperCase = udf((input: String) => {

378

input.toUpperCase

379

}).asNonNullable()

380

381

// String.hashCode depends only on its input, so this UDF keeps the default
// deterministic flag. Reserve asNondeterministic() for functions that use
// randomness, the clock or other external state (see nonDeterministicId).
// NOTE: like strictUpperCase, this throws a NullPointerException on null input.
val deterministicHash = udf((input: String) => {
  input.hashCode
})

384

385

val nonDeterministicId = udf(() => {

386

java.util.UUID.randomUUID().toString

387

}).asNondeterministic()

388

```

389

390

### User-Defined Aggregate Functions (UDAFs)

391

```scala

392

import org.apache.spark.sql.{Encoder, Encoders}

393

import org.apache.spark.sql.expressions.Aggregator

394

395

// Simple UDAF: Calculate geometric mean

396

/**
 * Typed aggregator that computes the geometric mean of the strictly positive inputs.
 *
 * The buffer holds (sum of natural logs, number of values) instead of the running
 * product: exp(sum(ln x) / n) equals (x1 * ... * xn)^(1/n), but the sum does not
 * overflow to Infinity or underflow to 0.0 when many values are combined.
 * Inputs that are not greater than zero are skipped; an empty aggregation yields 0.0.
 */
class GeometricMean extends Aggregator[Double, (Double, Long), Double] {
  // ln(1) = 0, so an empty log-sum is the identity element.
  def zero: (Double, Long) = (0.0, 0L)

  def reduce(buffer: (Double, Long), input: Double): (Double, Long) = {
    val (logSum, count) = buffer
    if (input > 0) (logSum + math.log(input), count + 1) else buffer
  }

  def merge(b1: (Double, Long), b2: (Double, Long)): (Double, Long) =
    (b1._1 + b2._1, b1._2 + b2._2)

  def finish(buffer: (Double, Long)): Double = {
    val (logSum, count) = buffer
    if (count > 0) math.exp(logSum / count) else 0.0
  }

  def bufferEncoder: Encoder[(Double, Long)] = Encoders.product
  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

422

423

// Complex UDAF: Calculate summary statistics

424

/**
 * Running summary of a numeric sample: count, sum, sum of squares, minimum and maximum.
 *
 * `mean`, `variance` (sample variance, n - 1 denominator) and `stddev` are derived
 * on demand from the stored fields.
 */
case class StatsSummary(
  count: Long,
  sum: Double,
  sumSquares: Double,
  min: Double,
  max: Double
) {
  def mean: Double = if (count > 0) sum / count else 0.0

  // sumSquares - sum^2 / n can come out slightly negative through floating-point
  // cancellation; clamp at zero so stddev never becomes NaN.
  def variance: Double =
    if (count > 1) math.max(0.0, (sumSquares - sum * sum / count) / (count - 1)) else 0.0

  def stddev: Double = math.sqrt(variance)
}

435

436

class SummaryStats extends Aggregator[Double, StatsSummary, StatsSummary] {

437

def zero: StatsSummary = StatsSummary(0L, 0.0, 0.0, Double.MaxValue, Double.MinValue)

438

439

def reduce(buffer: StatsSummary, input: Double): StatsSummary = {

440

StatsSummary(

441

count = buffer.count + 1,

442

sum = buffer.sum + input,

443

sumSquares = buffer.sumSquares + input * input,

444

min = math.min(buffer.min, input),

445

max = math.max(buffer.max, input)

446

)

447

}

448

449

def merge(b1: StatsSummary, b2: StatsSummary): StatsSummary = {

450

if (b1.count == 0) b2

451

else if (b2.count == 0) b1

452

else StatsSummary(

453

count = b1.count + b2.count,

454

sum = b1.sum + b2.sum,

455

sumSquares = b1.sumSquares + b2.sumSquares,

456

min = math.min(b1.min, b2.min),

457

max = math.max(b1.max, b2.max)

458

)

459

}

460

461

def finish(buffer: StatsSummary): StatsSummary = {

462

if (buffer.count == 0) {

463

buffer.copy(min = 0.0, max = 0.0)

464

} else {

465

buffer

466

}

467

}

468

469

def bufferEncoder: Encoder[StatsSummary] = Encoders.product

470

def outputEncoder: Encoder[StatsSummary] = Encoders.product

471

}

472

473

// Convert the aggregators to typed columns (no registration is needed for Dataset usage)

474

val geometricMean = new GeometricMean().toColumn

475

val summaryStats = new SummaryStats().toColumn

476

477

// Apply UDAFs

478

// Both aggregators are declared over Double inputs, so project the salary column
// into a typed Dataset[Double] first. Aggregating the 4-column DataFrame directly
// would feed them whole rows instead of salaries.
val salaries = employeeData.select($"salary").as[Double]

val aggregatedData = salaries
  .agg(
    geometricMean.name("geometric_mean_salary"),
    summaryStats.name("salary_stats")
  )

aggregatedData.show()

// Extract fields from complex UDAF result.
// Only the case-class constructor fields (count, sum, sumSquares, min, max) are
// encoded into the result struct. mean, variance and stddev are plain methods on
// StatsSummary, so they are recomputed here with the same formulas as column
// expressions (`F` is the alias for org.apache.spark.sql.functions).
val statCount = $"stats.count"
val statSum = $"stats.sum"
val sampleVariance = F.when(
  statCount > 1,
  ($"stats.sumSquares" - statSum * statSum / statCount) / (statCount - 1)
).otherwise(0.0)

val detailedStats = salaries
  .agg(summaryStats.name("stats"))
  .select(
    statCount,
    statSum,
    F.when(statCount > 0, statSum / statCount).otherwise(0.0).as("mean"),
    sampleVariance.as("variance"),
    F.sqrt(sampleVariance).as("stddev"),
    $"stats.min",
    $"stats.max"
  )

detailedStats.show()

500

```

501

502

### Higher-Order Functions and Functional UDFs

503

```scala

504

import org.apache.spark.sql.functions._

505

506

// Data with arrays for higher-order function examples

507

val arrayData = Seq(

508

("Alice", Array(85, 90, 88, 92)),

509

("Bob", Array(78, 85, 80, 88)),

510

("Charlie", Array(95, 88, 90, 94))

511

).toDF("name", "test_scores")

512

513

// UDF for array transformation

514

val scaleScore = udf((score: Int, scaleFactor: Double) => (score * scaleFactor).toInt)

515

val calculateGrade = udf((score: Int) => {

516

score match {

517

case s if s >= 90 => "A"

518

case s if s >= 80 => "B"

519

case s if s >= 70 => "C"

520

case s if s >= 60 => "D"

521

case _ => "F"

522

}

523

})

524

525

// Higher-order function usage with UDFs

526

val transformedScores = arrayData.select(

527

$"name",

528

$"test_scores",

529

// Transform each score

530

transform($"test_scores", score => scaleScore(score, lit(1.05))).as("scaled_scores"),

531

// Filter scores above threshold

532

filter($"test_scores", score => score > 85).as("high_scores"),

533

// Check if all scores are passing

534

forall($"test_scores", score => score >= 60).as("all_passing"),

535

// Check if any score is excellent

536

exists($"test_scores", score => score >= 95).as("has_excellent")

537

)

538

539

transformedScores.show(truncate = false)

540

541

// Complex aggregation with UDFs

542

val scoreAnalysis = arrayData.select(

543

$"name",

544

$"test_scores",

545

// Sum all scores as a double (every score has weight 1.0)

546

aggregate($"test_scores", lit(0.0),

547

(acc, score) => acc + score * lit(1.0)).as("sum_scores"),

548

// Find maximum score

549

aggregate($"test_scores", lit(0),

550

(acc, score) => greatest(acc, score)).as("max_score"),

551

// Map each score to its letter grade

552

transform($"test_scores", score => calculateGrade(score)).as("letter_grades")

553

)

554

555

scoreAnalysis.show(truncate = false)

556

```

557

558

### Performance Optimization and Best Practices

559

```scala

560

import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

561

import org.apache.spark.util.AccumulatorV2

562

563

// Optimized UDF with broadcast variables

564

val taxRates = Map("US" -> 0.25, "UK" -> 0.20, "CA" -> 0.15, "DE" -> 0.30)

565

val broadcastTaxRates = spark.sparkContext.broadcast(taxRates)

566

567

val calculateNetSalary = udf((grossSalary: Double, country: String) => {

568

val taxRate = broadcastTaxRates.value.getOrElse(country, 0.0)

569

grossSalary * (1.0 - taxRate)

570

})

571

572

// Sample international employee data

573

val intlEmployees = Seq(

574

("Alice", 75000.0, "US"),

575

("Bob", 65000.0, "UK"),

576

("Charlie", 70000.0, "CA"),

577

("Diana", 80000.0, "DE")

578

).toDF("name", "gross_salary", "country")

579

580

val netSalaries = intlEmployees.select(

581

$"name",

582

$"gross_salary",

583

$"country",

584

calculateNetSalary($"gross_salary", $"country").as("net_salary")

585

)

586

587

netSalaries.show()

588

589

// UDF with accumulator for monitoring

590

class StringAccumulator extends AccumulatorV2[String, String] {

591

private var _value = ""

592

593

def isZero: Boolean = _value.isEmpty

594

def copy(): StringAccumulator = {

595

val acc = new StringAccumulator

596

acc._value = _value

597

acc

598

}

599

def reset(): Unit = _value = ""

600

def add(v: String): Unit = _value += v + "\n"

601

def merge(other: AccumulatorV2[String, String]): Unit = _value += other.value

602

def value: String = _value

603

}

604

605

val errorAccumulator = new StringAccumulator

606

spark.sparkContext.register(errorAccumulator, "UDF Errors")

607

608

val robustProcessing = udf((data: String) => {

609

try {

610

// Simulate complex processing

611

if (data == null || data.isEmpty) {

612

errorAccumulator.add(s"Empty data encountered")

613

"EMPTY"

614

} else if (data.length < 3) {

615

errorAccumulator.add(s"Short data: $data")

616

"SHORT"

617

} else {

618

data.toUpperCase

619

}

620

} catch {

621

case e: Exception =>

622

errorAccumulator.add(s"Error processing '$data': ${e.getMessage}")

623

"ERROR"

624

}

625

})

626

627

// Array-level UDF: handles one whole array value per call (still invoked row by row, not truly vectorized)

628

val vectorizedStringLength = udf((strings: Seq[String]) => {

629

strings.map(s => if (s != null) s.length else 0)

630

})

631

632

// Batch processing example

633

val testStrings = Seq(

634

Array("hello", "world", null, "spark", "sql"),

635

Array("scala", "java", "", "python"),

636

Array(null, "data", "engineering")

637

).toDF("string_array")

638

639

val lengthResults = testStrings.select(

640

$"string_array",

641

vectorizedStringLength($"string_array").as("lengths")

642

)

643

644

lengthResults.show(truncate = false)

645

646

// UDF lifecycle management

647

/**
 * Keeps track of the UDF names registered through it so they can be listed
 * and dropped again at the end of a job.
 *
 * The SparkSession is taken as an implicit parameter; callers need an
 * `implicit val` of type SparkSession in scope.
 */
object UDFRegistry {
  import org.apache.spark.sql.expressions.UserDefinedFunction
  import scala.util.control.NonFatal

  // Names registered via registerUDF and not yet dropped by cleanup.
  private val registeredUDFs = scala.collection.mutable.Set[String]()

  /** Registers `udf` under `name` for SQL use and records the name. */
  def registerUDF[T](name: String, udf: UserDefinedFunction)(implicit spark: SparkSession): Unit = {
    spark.udf.register(name, udf)
    registeredUDFs += name
    println(s"Registered UDF: $name")
  }

  /** Immutable snapshot of the currently tracked UDF names. */
  def listRegisteredUDFs(): Set[String] = registeredUDFs.toSet

  /**
   * Drops every tracked UDF from the session and forgets its name.
   * Functions added with spark.udf.register are temporary functions, so
   * DROP TEMPORARY FUNCTION removes them. A failure for one name is reported
   * and does not stop the remaining names from being processed.
   */
  def cleanup()(implicit spark: SparkSession): Unit = {
    // Iterate over a copy because the set is modified inside the loop.
    registeredUDFs.toList.foreach { name =>
      try {
        spark.sql(s"DROP TEMPORARY FUNCTION IF EXISTS `$name`")
        registeredUDFs -= name
        println(s"Dropped UDF: $name")
      } catch {
        case NonFatal(e) =>
          println(s"Note: UDF $name may need manual cleanup: ${e.getMessage}")
      }
    }
  }
}

671

672

// Register UDFs through registry.
// UDFRegistry receives the session as an implicit parameter, and a plain
// `val spark` is not eligible for implicit resolution, so expose it explicitly.
implicit val activeSession: SparkSession = spark

UDFRegistry.registerUDF("calculate_net_salary", calculateNetSalary)
UDFRegistry.registerUDF("robust_processing", robustProcessing)

675

676

// Use registered UDFs

677

intlEmployees.createOrReplaceTempView("international_employees")

678

679

val sqlWithUDF = spark.sql("""

680

SELECT

681

name,

682

gross_salary,

683

country,

684

calculate_net_salary(gross_salary, country) as net_salary

685

FROM international_employees

686

""")

687

688

sqlWithUDF.show()

689

690

// Display UDF registry status

691

println("Registered UDFs: " + UDFRegistry.listRegisteredUDFs().mkString(", "))

692

println("Error accumulator value:")

693

println(errorAccumulator.value)

694

695

// Cleanup

696

UDFRegistry.cleanup()

697

```

698

699

### UDF Testing and Validation

700

```scala

701

import org.apache.spark.sql.test.SharedSparkSession // In test environment

702

703

// UDF testing framework

704

/**
 * Small helper for checking and timing UDFs against a live SparkSession.
 *
 * @param spark session used to build the test Datasets
 */
class UDFTester(spark: SparkSession) {
  import org.apache.spark.sql.{DataFrame, Encoder}
  import org.apache.spark.sql.expressions.UserDefinedFunction
  import spark.implicits._

  /**
   * Runs `udf` over the inputs of `testCases` and compares each result with the
   * expected value using `==`.
   *
   * The inputs become a Dataset[T] and every column of that Dataset is passed
   * to the UDF in order. A single-value input therefore feeds a one-argument
   * UDF, and a tuple input such as (Double, Int) feeds a two-argument UDF.
   *
   * @tparam T input type; an Encoder must be available at the call site
   *           (for example through `import spark.implicits._`)
   * @tparam R expected result type
   * @return true when every test case matched
   */
  def testUDF[T: Encoder, R](udf: UserDefinedFunction, testCases: Seq[(T, R)]): Boolean = {
    val inputs = testCases.map(_._1).toDS()
    val arguments = inputs.columns.map(inputs.col)
    val expectedResults = testCases.map(_._2)

    // toSeq first: mapping an Array to a generic R would need a ClassTag.
    val results = inputs
      .select(udf(arguments: _*).as("result"))
      .collect()
      .toSeq
      .map(_.getAs[R]("result"))

    val comparisons = results.zip(expectedResults)
    val passed = comparisons.forall { case (actual, expected) => actual == expected }

    if (passed) {
      println(s"✅ All ${testCases.length} test cases passed")
    } else {
      println(s"❌ Some test cases failed")
      comparisons.zipWithIndex.foreach { case ((actual, expected), idx) =>
        if (actual != expected) {
          println(s" Test case $idx: expected $expected, got $actual")
        }
      }
    }

    passed
  }

  /**
   * Times `iterations` full evaluations of `udf` over `data(column)` and prints
   * the average, minimum and maximum wall-clock time in milliseconds. Each run
   * collects the result to the driver, so the timing includes that transfer.
   *
   * @throws IllegalArgumentException if `iterations` is not positive
   */
  def benchmarkUDF[T](udf: UserDefinedFunction, data: DataFrame, column: String, iterations: Int = 5): Unit = {
    require(iterations > 0, "iterations must be positive")

    val times = (1 to iterations).map { _ =>
      val start = System.nanoTime()
      data.select(udf(data.col(column))).collect()
      (System.nanoTime() - start) / 1000000.0 // Convert to milliseconds
    }

    val avgTime = times.sum / times.length
    val minTime = times.min
    val maxTime = times.max

    println(s"UDF Performance (${iterations} iterations):")
    println(s" Average: ${avgTime}ms")
    println(s" Min: ${minTime}ms")
    println(s" Max: ${maxTime}ms")
  }
}

749

750

// Test the UDFs

751

val tester = new UDFTester(spark)

752

753

// Test extract first name UDF

754

val firstNameTests = Seq(

755

("John Doe", "John"),

756

("Alice", "Alice"),

757

("Bob Smith Jr", "Bob"),

758

(null, null),

759

("", "")

760

)

761

762

tester.testUDF(extractFirstName, firstNameTests)

763

764

// Test calculate bonus UDF

765

val bonusTests = Seq(

766

((50000.0, 25), 5000.0),

767

((80000.0, 35), 12000.0),

768

((60000.0, 30), 6000.0)

769

)

770

771

val bonusUDF = udf((salary: Double, age: Int) => {

772

val ageMultiplier = if (age > 30) 0.15 else 0.10

773

salary * ageMultiplier

774

})

775

776

tester.testUDF(bonusUDF, bonusTests)

777

778

// Performance benchmarking

779

val largeDataset = (1 to 10000).map(i => s"Name $i").toDF("name")

780

tester.benchmarkUDF(extractFirstName, largeDataset, "name")

781

782

// UDF documentation generator

783

def generateUDFDocumentation(udfs: Map[String, (UserDefinedFunction, String)]): String = {

784

val docs = udfs.map { case (name, (udf, description)) =>

785

s"""

786

|## $name

787

|

788

|**Description**: $description

789

|

790

|**Input Types**: ${udf.inputTypes.mkString(", ")}

791

|**Output Type**: ${udf.dataType}

792

|**Nullable**: ${udf.nullable}

793

|**Deterministic**: ${udf.deterministic}

794

|

795

""".stripMargin

796

}.mkString("\n")

797

798

s"""

799

|# UDF Documentation

800

|

801

|$docs

802

""".stripMargin

803

}

804

805

val udfDocs = Map(

806

"extract_first_name" -> (extractFirstName, "Extracts the first name from a full name string"),

807

"calculate_bonus" -> (bonusUDF, "Calculates employee bonus based on salary and age"),

808

"is_valid_email" -> (isValidEmail, "Validates email format")

809

)

810

811

println(generateUDFDocumentation(udfDocs))

812

```