# User-Defined Functions

The Flink Table API supports user-defined functions (UDFs) to extend functionality with custom logic. Three main types are supported: scalar functions (1-to-1), table functions (1-to-N), and aggregate functions (N-to-1).

## Capabilities

### Base Function Framework

All user-defined functions extend from the base UserDefinedFunction class.
```scala { .api }
/**
 * Base class for all user-defined functions
 */
abstract class UserDefinedFunction {

  /**
   * Initialization method called when function is opened
   * @param context Function context providing runtime information
   */
  def open(context: FunctionContext): Unit = {}

  /**
   * Cleanup method called when function is closed
   */
  def close(): Unit = {}

  /**
   * Indicates whether the function is deterministic
   * @return True if function always produces same output for same input
   */
  def isDeterministic: Boolean = true
}
```

### Scalar Functions

Functions that take one or more input values and return a single value.

```scala { .api }
/**
 * Base class for scalar functions (1-to-1 mapping)
 */
abstract class ScalarFunction extends UserDefinedFunction {

  /**
   * Creates a function call expression for this scalar function
   * @param params Parameters for the function call
   * @return Expression representing the function call
   */
  def apply(params: Expression*): Expression

  /**
   * Override to specify custom result type
   * @param signature Array of parameter classes
   * @return Type information for the result
   */
  def getResultType(signature: Array[Class[_]]): TypeInformation[_] = null

  /**
   * Override to specify custom parameter types
   * @param signature Array of parameter classes
   * @return Array of parameter type information
   */
  def getParameterTypes(signature: Array[Class[_]]): Array[TypeInformation[_]] = null
}
```

**Usage Examples:**

```scala
// Simple scalar function
class AddOne extends ScalarFunction {
  def eval(x: Int): Int = x + 1
}

// Scalar function with multiple parameters
class StringConcat extends ScalarFunction {
  def eval(a: String, b: String): String = a + b
}

// Scalar function with variable arguments
class ConcatWs extends ScalarFunction {
  def eval(separator: String, strings: String*): String = {
    strings.mkString(separator)
  }
}

// Complex scalar function with type override
class ParseJson extends ScalarFunction {
  def eval(jsonStr: String): Row = {
    // Parse JSON and return Row
    val parsed = JSON.parse(jsonStr)
    Row.of(parsed.get("id"), parsed.get("name"))
  }

  override def getResultType(signature: Array[Class[_]]): TypeInformation[_] = {
    Types.ROW(Array("id", "name"), Array(Types.LONG, Types.STRING))
  }
}

// Register and use scalar functions.
// Keep a reference to each instance: the registered name is used in SQL,
// while the instance itself is called inline in the Table API.
val addOne = new AddOne()
val concat = new StringConcat()
tEnv.registerFunction("addOne", addOne)
tEnv.registerFunction("concat", concat)

val result = table.select('id, addOne('age), concat('firstName, 'lastName))
val sqlResult = tEnv.sqlQuery("SELECT id, addOne(age), concat(firstName, lastName) FROM Users")
```

### Table Functions

Functions that take one or more input values and return multiple rows (1-to-N mapping).

```scala { .api }
/**
 * Base class for table functions (1-to-N mapping)
 * @tparam T Type of output rows
 */
abstract class TableFunction[T] extends UserDefinedFunction {

  /**
   * Collects an output row
   * @param result Output row to emit
   */
  protected def collect(result: T): Unit

  /**
   * Override to specify custom result type
   * @param signature Array of parameter classes
   * @return Type information for the result
   */
  def getResultType(signature: Array[Class[_]]): TypeInformation[T] = null

  /**
   * Override to specify custom parameter types
   * @param signature Array of parameter classes
   * @return Array of parameter type information
   */
  def getParameterTypes(signature: Array[Class[_]]): Array[TypeInformation[_]] = null
}
```

**Usage Examples:**

```scala
// Split string into multiple rows
class SplitFunction extends TableFunction[String] {
  def eval(str: String, separator: String): Unit = {
    str.split(separator).foreach(collect)
  }
}

// Parse CSV row into structured data
class ParseCsv extends TableFunction[Row] {
  def eval(csvRow: String): Unit = {
    val fields = csvRow.split(",")
    collect(Row.of(fields(0), fields(1).toInt, fields(2).toDouble))
  }

  override def getResultType(signature: Array[Class[_]]): TypeInformation[Row] = {
    Types.ROW(Array("name", "age", "salary"), Array(Types.STRING, Types.INT, Types.DOUBLE))
  }
}

// Generate number sequence (end is exclusive)
class Range extends TableFunction[Long] {
  def eval(start: Long, end: Long): Unit = {
    (start until end).foreach(collect)
  }
}

// Register and use table functions
val split = new SplitFunction()
tEnv.registerFunction("split", split)
tEnv.registerFunction("range", new Range())

// Use in join (LATERAL TABLE): the table function must be applied to
// input columns before aliasing its output fields.
val result = table
  .join(split('text, ",") as 'word)
  .select('id, 'text, 'word)

val sqlResult = tEnv.sqlQuery("""
  SELECT u.id, u.tags, t.word
  FROM Users u, LATERAL TABLE(split(u.tags, ',')) as t(word)
""")
```

### Aggregate Functions

Functions that take multiple input values and return a single aggregated result (N-to-1 mapping).

```scala { .api }
/**
 * Base class for aggregate functions (N-to-1 mapping)
 * @tparam T Type of the final result
 * @tparam ACC Type of the accumulator
 */
abstract class AggregateFunction[T, ACC] extends UserDefinedFunction {

  /**
   * Creates and initializes the accumulator
   * @return New accumulator instance
   */
  def createAccumulator(): ACC

  /**
   * Processes an input value and updates the accumulator
   * @param accumulator Current accumulator
   * @param input Input values (method should be overloaded for different arities)
   */
  def accumulate(accumulator: ACC, input: Any*): Unit

  /**
   * Extracts the final result from the accumulator
   * @param accumulator Final accumulator state
   * @return Aggregated result
   */
  def getValue(accumulator: ACC): T

  /**
   * Retracts an input value from the accumulator (for streaming)
   * @param accumulator Current accumulator
   * @param input Input values to retract
   */
  def retract(accumulator: ACC, input: Any*): Unit = {
    throw new UnsupportedOperationException("retract method is not implemented")
  }

  /**
   * Merges multiple accumulators into one (for distributed aggregation)
   * @param accumulator Target accumulator
   * @param otherAccumulators Other accumulators to merge
   */
  def merge(accumulator: ACC, otherAccumulators: java.lang.Iterable[ACC]): Unit = {
    throw new UnsupportedOperationException("merge method is not implemented")
  }

  /**
   * Override to specify custom result type
   * @param signature Array of parameter classes
   * @return Type information for the result
   */
  def getResultType(signature: Array[Class[_]]): TypeInformation[T] = null

  /**
   * Override to specify custom accumulator type
   * @return Type information for the accumulator
   */
  def getAccumulatorType: TypeInformation[ACC] = null
}
```

**Usage Examples:**

```scala
// Simple sum aggregate function.
// The accumulator must be a mutable object so accumulate() can update it
// in place — an immutable primitive accumulator would silently discard
// every update.
class SumAccumulator {
  var sum: Long = 0L
}

class SumFunction extends AggregateFunction[Long, SumAccumulator] {
  override def createAccumulator(): SumAccumulator = new SumAccumulator

  def accumulate(acc: SumAccumulator, value: Long): Unit = {
    acc.sum += value
  }

  override def getValue(accumulator: SumAccumulator): Long = accumulator.sum
}

// Weighted average with complex accumulator.
// Fields are declared as vars so accumulate/retract/merge can mutate them.
case class WeightedAvgAccumulator(var sum: Double, var count: Long)

class WeightedAvg extends AggregateFunction[Double, WeightedAvgAccumulator] {
  override def createAccumulator(): WeightedAvgAccumulator = {
    WeightedAvgAccumulator(0.0, 0L)
  }

  def accumulate(acc: WeightedAvgAccumulator, value: Double, weight: Long): Unit = {
    acc.sum += value * weight
    acc.count += weight
  }

  override def getValue(accumulator: WeightedAvgAccumulator): Double = {
    if (accumulator.count == 0) 0.0 else accumulator.sum / accumulator.count
  }

  def retract(acc: WeightedAvgAccumulator, value: Double, weight: Long): Unit = {
    acc.sum -= value * weight
    acc.count -= weight
  }

  def merge(acc: WeightedAvgAccumulator, otherAccs: java.lang.Iterable[WeightedAvgAccumulator]): Unit = {
    otherAccs.forEach { other =>
      acc.sum += other.sum
      acc.count += other.count
    }
  }
}

// String concatenation aggregate
class StringAgg extends AggregateFunction[String, StringBuilder] {
  override def createAccumulator(): StringBuilder = new StringBuilder()

  def accumulate(acc: StringBuilder, value: String, separator: String): Unit = {
    if (acc.nonEmpty) acc.append(separator)
    acc.append(value)
  }

  override def getValue(accumulator: StringBuilder): String = accumulator.toString
}

// Register and use aggregate functions.
// Keep a reference to each instance for inline use in the Table API.
val weightedAvg = new WeightedAvg()
val stringAgg = new StringAgg()
tEnv.registerFunction("weightedAvg", weightedAvg)
tEnv.registerFunction("stringAgg", stringAgg)

val result = table
  .groupBy('department)
  .select('department, weightedAvg('salary, 'years), stringAgg('name, ", "))

val sqlResult = tEnv.sqlQuery("""
  SELECT department,
         weightedAvg(salary, years) as avgSalary,
         stringAgg(name, ', ') as employees
  FROM Employees
  GROUP BY department
""")
```

### Function Context

Runtime context providing access to metrics, cached files, and other runtime information.

```scala { .api }
/**
 * Context interface providing runtime information to functions
 */
trait FunctionContext {

  /**
   * Gets the metric group for registering custom metrics
   * @return Metric group for the function
   */
  def getMetricGroup: MetricGroup

  /**
   * Gets a cached file by name (distributed cache)
   * @param name Name of the cached file
   * @return File handle to the cached file
   */
  def getCachedFile(name: String): java.io.File

  /**
   * Gets the job parameter value
   * @param key Parameter key
   * @param defaultValue Default value if key not found
   * @return Parameter value
   */
  def getJobParameter(key: String, defaultValue: String): String
}
```

**Usage Examples:**

```scala
class MetricsFunction extends ScalarFunction {
  private var counter: Counter = _

  override def open(context: FunctionContext): Unit = {
    super.open(context)
    // Register custom metrics
    counter = context.getMetricGroup.counter("custom_function_calls")
  }

  def eval(value: String): String = {
    counter.inc()
    value.toUpperCase
  }
}

class ConfigurableFunction extends ScalarFunction {
  private var multiplier: Double = _

  override def open(context: FunctionContext): Unit = {
    super.open(context)
    // Read configuration from job parameters
    multiplier = context.getJobParameter("multiplier", "1.0").toDouble
  }

  def eval(value: Double): Double = value * multiplier
}
```

### Advanced Function Features

Additional features for complex function implementations.

```scala { .api }
// Generic function interface for type-safe implementations
trait TypeInference {
  def inferTypes(callContext: CallContext): TypeInference.Result
}

// Function that can be used in both scalar and table contexts
abstract class PolymorphicFunction extends UserDefinedFunction {
  def eval(args: Any*): Any
  def evalTable(args: Any*): java.lang.Iterable[_]
}
```

**Usage Examples:**

```scala
// Function with custom type inference
class FlexibleParseFunction extends ScalarFunction with TypeInference {
  def eval(input: String, format: String): Any = {
    format match {
      case "int" => input.toInt
      case "double" => input.toDouble
      case "boolean" => input.toBoolean
      case _ => input
    }
  }

  override def inferTypes(callContext: CallContext): TypeInference.Result = {
    // Custom type inference logic based on the format argument, when it
    // is available as a literal at planning time
    val formatLiteral = callContext.getArgumentValue(1, classOf[String])
    val resultType = formatLiteral.orElse("string") match {
      case "int" => Types.INT
      case "double" => Types.DOUBLE
      case "boolean" => Types.BOOLEAN
      case _ => Types.STRING
    }
    TypeInference.Result.success(resultType)
  }
}
```

## Function Registration and Usage Patterns

```scala { .api }
// Registration methods on TableEnvironment
def registerFunction(name: String, function: ScalarFunction): Unit
def registerFunction(name: String, function: TableFunction[_]): Unit
def registerFunction(name: String, function: AggregateFunction[_, _]): Unit

// Usage in Table API (Scala)
table.select('field, myScalarFunction('input))
table.join(myTableFunction('field) as ('output))
table.groupBy('key).select('key, myAggregateFunction('value))

// Usage in SQL
tEnv.sqlQuery("SELECT field, myScalarFunction(input) FROM MyTable")
tEnv.sqlQuery("SELECT * FROM MyTable, LATERAL TABLE(myTableFunction(field)) as t(output)")
tEnv.sqlQuery("SELECT key, myAggregateFunction(value) FROM MyTable GROUP BY key")
```

## Types

```scala { .api }
abstract class UserDefinedFunction
abstract class ScalarFunction extends UserDefinedFunction
abstract class TableFunction[T] extends UserDefinedFunction
abstract class AggregateFunction[T, ACC] extends UserDefinedFunction

trait FunctionContext
trait MetricGroup
case class Counter()

// Exception types
class ValidationException(message: String) extends RuntimeException(message)
```