or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

catalog.md columns-functions.md data-io.md dataset-dataframe.md index.md session-management.md streaming.md types-encoders.md udfs.md

docs/udfs.md

0

# User-Defined Functions

1

2

Registration and usage of custom user-defined functions (UDFs) and user-defined aggregate functions (UDAFs). Enables extending Spark SQL with custom business logic and domain-specific operations.

3

4

## Capabilities

5

6

### UDF Registration

7

8

Interface for registering user-defined functions that can be used in SQL and DataFrame operations.

9

10

```scala { .api }

11

/**

12

* Functions for registering user-defined functions

13

*/

14

class UDFRegistration {

15

/** Register UDF with 0 arguments */

16

def register[RT: TypeTag](name: String, func: () => RT): UserDefinedFunction

17

18

/** Register UDF with 1 argument */

19

def register[RT: TypeTag, A1: TypeTag](name: String, func: A1 => RT): UserDefinedFunction

20

21

/** Register UDF with 2 arguments */

22

def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](

23

name: String, func: (A1, A2) => RT): UserDefinedFunction

24

25

/** Register UDF with 3 arguments */

26

def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](

27

name: String, func: (A1, A2, A3) => RT): UserDefinedFunction

28

29

/** Register UDF with 4 arguments */

30

def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](

31

name: String, func: (A1, A2, A3, A4) => RT): UserDefinedFunction

32

33

/** Register UDF with 5 arguments */

34

def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](

35

name: String, func: (A1, A2, A3, A4, A5) => RT): UserDefinedFunction

36

37

// ... continues up to 22 arguments

38

39

/** Register user-defined aggregate function */

40

def register(name: String, udaf: UserDefinedAggregateFunction): UserDefinedAggregateFunction

41

42

/** Register UDF from UserDefinedFunction */

43

def register(name: String, udf: UserDefinedFunction): UserDefinedFunction

44

}

45

```

46

47

### UserDefinedFunction

48

49

Wrapper for user-defined functions that can be applied to columns.

50

51

```scala { .api }

52

/**

53

* User-defined function that can be called in DataFrame operations

54

*/

55

class UserDefinedFunction {

56

/** Apply UDF to columns */

57

def apply(exprs: Column*): Column

58

59

/** Mark UDF as non-deterministic */

60

def asNondeterministic(): UserDefinedFunction

61

62

/** Mark UDF as non-nullable */

63

def asNonNullable(): UserDefinedFunction

64

65

/** UDF name (if registered) */

66

def name: Option[String]

67

68

/** Check if UDF is deterministic */

69

def deterministic: Boolean

70

71

/** Check if UDF is nullable */

72

def nullable: Boolean

73

}

74

75

/**

76

* Factory methods for creating UDFs

77

*/

78

object functions {

79

/** Create UDF with 1 argument */

80

def udf[RT: TypeTag, A1: TypeTag](f: A1 => RT): UserDefinedFunction

81

82

/** Create UDF with 2 arguments */

83

def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: (A1, A2) => RT): UserDefinedFunction

84

85

/** Create UDF with 3 arguments */

86

def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](

87

f: (A1, A2, A3) => RT): UserDefinedFunction

88

89

/** Create UDF with 4 arguments */

90

def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](

91

f: (A1, A2, A3, A4) => RT): UserDefinedFunction

92

93

/** Create UDF with 5 arguments */

94

def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](

95

f: (A1, A2, A3, A4, A5) => RT): UserDefinedFunction

96

97

// ... continues up to 22 arguments

98

}

99

```

100

101

**Usage Examples:**

102

103

```scala

104

import org.apache.spark.sql.functions.udf

105

import org.apache.spark.sql.types._

106

107

// Simple string transformation UDF

108

val upperCase = udf((s: String) => s.toUpperCase)

109

110

// Register for SQL usage

111

spark.udf.register("upper_case", upperCase)

112

113

// Use in DataFrame operations

114

val df = spark.table("employees")

115

val result = df.withColumn("upper_name", upperCase(col("name")))

116

117

// Use in SQL

118

val sqlResult = spark.sql("SELECT name, upper_case(name) as upper_name FROM employees")

119

120

// Complex business logic UDF

121

val calculateBonus = udf((salary: Double, performance: String, years: Int) => {

122

val baseBonus = salary * 0.1

123

val performanceMultiplier = performance match {

124

case "excellent" => 2.0

125

case "good" => 1.5

126

case "average" => 1.0

127

case _ => 0.5

128

}

129

val tenureBonus = if (years > 5) 1.2 else 1.0

130

baseBonus * performanceMultiplier * tenureBonus

131

})

132

133

spark.udf.register("calculate_bonus", calculateBonus)

134

135

val bonusData = df.withColumn("bonus",

136

calculateBonus(col("salary"), col("performance"), col("years_of_service"))

137

)

138

139

// Non-deterministic UDF (e.g., random values)

140

val randomId = udf(() => java.util.UUID.randomUUID().toString)

141

.asNondeterministic()

142

143

val withIds = df.withColumn("random_id", randomId())

144

```

145

146

### UserDefinedAggregateFunction

147

148

Abstract class for custom aggregate functions that work on groups of rows.

149

150

```scala { .api }

151

/**

152

* Base class for user-defined aggregate functions

153

*/

154

abstract class UserDefinedAggregateFunction extends Serializable {

155

/** Input schema for the aggregation function */

156

def inputSchema: StructType

157

158

/** Schema of the aggregation buffer */

159

def bufferSchema: StructType

160

161

/** Output data type */

162

def dataType: DataType

163

164

/** Whether function is deterministic */

165

def deterministic: Boolean

166

167

/** Initialize the aggregation buffer */

168

def initialize(buffer: MutableAggregationBuffer): Unit

169

170

/** Update buffer with new input row */

171

def update(buffer: MutableAggregationBuffer, input: Row): Unit

172

173

/** Merge two aggregation buffers */

174

def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit

175

176

/** Calculate final result from buffer */

177

def evaluate(buffer: Row): Any

178

}

179

180

/**

181

* Mutable aggregation buffer for UDAF implementations

182

*/

183

trait MutableAggregationBuffer extends Row {

184

/** Update value at index */

185

def update(i: Int, value: Any): Unit

186

}

187

```

188

189

**Usage Example:**

190

191

```scala

192

import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}

193

import org.apache.spark.sql.types._

194

import org.apache.spark.sql.Row

195

196

// Custom aggregation function to calculate geometric mean

197

class GeometricMean extends UserDefinedAggregateFunction {

198

// Input: one double value

199

def inputSchema: StructType = StructType(StructField("value", DoubleType) :: Nil)

200

201

// Buffer: sum of logs and count

202

def bufferSchema: StructType = StructType(

203

StructField("logSum", DoubleType) ::

204

StructField("count", LongType) :: Nil

205

)

206

207

// Output: double

208

def dataType: DataType = DoubleType

209

def deterministic: Boolean = true

210

211

// Initialize buffer

212

def initialize(buffer: MutableAggregationBuffer): Unit = {

213

buffer(0) = 0.0 // logSum

214

buffer(1) = 0L // count

215

}

216

217

// Update buffer with new value

218

def update(buffer: MutableAggregationBuffer, input: Row): Unit = {

219

if (!input.isNullAt(0)) {

220

val value = input.getDouble(0)

221

if (value > 0) {

222

buffer(0) = buffer.getDouble(0) + math.log(value)

223

buffer(1) = buffer.getLong(1) + 1

224

}

225

}

226

}

227

228

// Merge two buffers

229

def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {

230

buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)

231

buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)

232

}

233

234

// Calculate final result

235

def evaluate(buffer: Row): Any = {

236

val count = buffer.getLong(1)

237

if (count == 0) {

238

null

239

} else {

240

math.exp(buffer.getDouble(0) / count)

241

}

242

}

243

}

244

245

// Register and use the UDAF

246

val geometricMean = new GeometricMean

247

spark.udf.register("geom_mean", geometricMean)

248

249

// Use in DataFrame operations

250

val result = spark.table("sales")

251

.groupBy("region")

252

.agg(geometricMean(col("amount")).alias("geom_avg_amount"))

253

254

// Use in SQL

255

val sqlResult = spark.sql("""

256

SELECT region, geom_mean(amount) as geom_avg_amount

257

FROM sales

258

GROUP BY region

259

""")

260

```

261

262

### Advanced UDF Patterns

263

264

**Working with complex types:**

265

266

```scala

267

// UDF that processes arrays

268

val arraySum = udf((arr: Seq[Int]) => if (arr != null) arr.sum else 0)

269

270

// UDF that processes structs

271

val extractField = udf((row: Row) => if (row != null) row.getString(0) else null)

272

273

// UDF that returns complex types

274

// UDFs returning Row need an explicit result schema (Spark has no encoder for Row)

275

val createStruct = udf((name: String, age: Int) => Row(name, age, s"$name-$age"),

276

StructType(Seq(

277

StructField("name", StringType),

278

StructField("age", IntegerType),

279

StructField("id", StringType)

280

)))

281

282

// Register the schema-typed UDF for SQL usage

283

spark.udf.register("create_person", createStruct)

284

```

285

286

**Error handling in UDFs:**

287

288

```scala

289

val safeDiv = udf((a: Double, b: Double) => {

290

try {

291

if (b == 0.0) None else Some(a / b)

292

} catch {

293

case _: Exception => None

294

}

295

})

296

297

// UDF with detailed error information

298

// Return Map[String, String] — a mixed-value Map[String, Any] has no Spark encoder
val validateEmail = udf((email: String) => {

299

if (email == null || email.isEmpty) {

300

Map("valid" -> "false", "error" -> "Email is empty")

301

} else if (!email.contains("@")) {

302

Map("valid" -> "false", "error" -> "Missing @ symbol")

303

} else {

304

Map("valid" -> "true", "error" -> "")

305

}

306

})

307

```

308

309

**Performance optimization:**

310

311

```scala

312

// Broadcast variables in UDFs

313

val lookupMap = spark.sparkContext.broadcast(Map(

314

"A" -> "Alpha",

315

"B" -> "Beta",

316

"C" -> "Gamma"

317

))

318

319

val lookupUdf = udf((code: String) => {

320

lookupMap.value.getOrElse(code, "Unknown")

321

})

322

323

// Use closure-free UDFs for better serialization

324

object UDFUtils {

325

def multiply(factor: Double): UserDefinedFunction = {

326

udf((value: Double) => value * factor)

327

}

328

329

def formatCurrency(locale: String): UserDefinedFunction = {

330

udf((amount: Double) => {

331

val formatter = java.text.NumberFormat.getCurrencyInstance(

332

java.util.Locale.forLanguageTag(locale)

333

)

334

formatter.format(amount)

335

})

336

}

337

}

338

339

val doubleValue = UDFUtils.multiply(2.0)

340

val formatUSD = UDFUtils.formatCurrency("en-US")

341

```

342

343

**Typed UDFs with case classes:**

344

345

```scala

346

case class Person(name: String, age: Int)

347

case class PersonInfo(person: Person, category: String)

348

349

// Typed UDF working with case classes

350

val categorize = udf((person: Person) => {

351

val category = person.age match {

352

case age if age < 18 => "Minor"

353

case age if age < 65 => "Adult"

354

case _ => "Senior"

355

}

356

PersonInfo(person, category)

357

})

358

359

// Use with proper encoders

360

import spark.implicits._

361

val peopleDS = Seq(

362

Person("Alice", 25),

363

Person("Bob", 17),

364

Person("Carol", 70)

365

).toDS()

366

367

val categorizedDS = peopleDS.select(categorize(struct(col("name"), col("age"))).alias("info"))

368

```

369

370

### UDF Best Practices

371

372

**Null handling:**

373

374

```scala

375

val nullSafeUdf = udf((value: String) => {

376

Option(value).map(_.trim.toUpperCase).orNull

377

})

378

379

// Or mark as non-nullable if appropriate

380

val nonNullUdf = udf((value: String) => value.trim.toUpperCase)

381

.asNonNullable()

382

```

383

384

**Testing UDFs:**

385

386

```scala

387

// Unit test UDFs outside Spark context

388

val testUdf = (s: String) => s.toUpperCase

389

assert(testUdf("hello") == "HELLO")

390

391

// Integration testing with Spark

392

val df = Seq("hello", "world").toDF("text")

393

val result = df.select(upperCase(col("text"))).collect()

394

assert(result.map(_.getString(0)) sameElements Array("HELLO", "WORLD"))

395

```

396

397

**Documentation and type safety:**

398

399

```scala

400

/**

401

* Calculates compound interest

402

* @param principal Initial amount

403

* @param rate Annual interest rate (as decimal, e.g., 0.05 for 5%)

404

* @param years Number of years

405

* @param compoundingPeriods Number of times interest is compounded per year

406

* @return Final amount after compound interest

407

*/

408

val compoundInterest = udf((principal: Double, rate: Double, years: Int, compoundingPeriods: Int) => {

409

require(principal >= 0, "Principal must be non-negative")

410

require(rate >= 0, "Interest rate must be non-negative")

411

require(years >= 0, "Years must be non-negative")

412

require(compoundingPeriods > 0, "Compounding periods must be positive")

413

414

principal * math.pow(1 + rate / compoundingPeriods, compoundingPeriods * years)

415

})

416

417

spark.udf.register("compound_interest", compoundInterest)

418

```