# User-Defined Functions

This document covers support for custom scalar, table, aggregate, and table aggregate functions in Apache Flink Table Uber Blink.

## Function Registration

```java { .api }
interface TableEnvironment {
8
void createFunction(String path, UserDefinedFunction function);
9
void createFunction(String path, Class<? extends UserDefinedFunction> functionClass);
10
void createTemporaryFunction(String path, UserDefinedFunction function);
11
void createTemporaryFunction(String path, Class<? extends UserDefinedFunction> functionClass);
12
void createTemporarySystemFunction(String name, UserDefinedFunction function);
13
void createTemporarySystemFunction(String name, Class<? extends UserDefinedFunction> functionClass);
14
boolean dropFunction(String path);
15
boolean dropTemporaryFunction(String path);
16
boolean dropTemporarySystemFunction(String name);
17
}
```

## Scalar Functions

```java { .api }
abstract class ScalarFunction extends UserDefinedFunction {
24
// User implements eval methods
25
// public ReturnType eval(InputType1 input1, InputType2 input2, ...);
26
}
```

**Example:**

```java
public static class AddFunction extends ScalarFunction {
33
public Integer eval(Integer a, Integer b) {
34
return a != null && b != null ? a + b : null;
35
}
36
37
public Double eval(Double a, Double b) {
38
return a != null && b != null ? a + b : null;
39
}
40
}
// Register and use
43
tEnv.createFunction("add_func", AddFunction.class);
44
Table result = tEnv.sqlQuery("SELECT add_func(col1, col2) FROM table1");
```

## Table Functions

```java { .api }
abstract class TableFunction<T> extends UserDefinedFunction {
51
protected void collect(T row);
52
// User implements eval methods
53
// public void eval(InputType1 input1, InputType2 input2, ...);
54
}
```

**Example:**

```java
public static class SplitFunction extends TableFunction<String> {
61
public void eval(String str) {
62
if (str != null) {
63
for (String s : str.split(",")) {
64
collect(s.trim());
65
}
66
}
67
}
68
}
// Register and use
71
tEnv.createFunction("split_func", SplitFunction.class);
72
Table result = tEnv.sqlQuery(
73
"SELECT name, word FROM users CROSS JOIN LATERAL TABLE(split_func(tags)) AS t(word)"
74
);
```

## Aggregate Functions

```java { .api }
abstract class AggregateFunction<T, ACC> extends UserDefinedFunction {
81
public abstract ACC createAccumulator();
82
public abstract T getValue(ACC accumulator);
83
84
// Required methods
85
public abstract void accumulate(ACC accumulator, InputType1 input1, ...);
86
87
// Optional methods
88
public void retract(ACC accumulator, InputType1 input1, ...) {}
89
public void merge(ACC accumulator, Iterable<ACC> iterable) {}
90
public void resetAccumulator(ACC accumulator) {}
91
}
```

**Example:**

```java
public static class WeightedAvgAccum {
98
public double sum = 0;
99
public int count = 0;
100
}
public static class WeightedAvgFunction extends AggregateFunction<Double, WeightedAvgAccum> {
103
@Override
104
public WeightedAvgAccum createAccumulator() {
105
return new WeightedAvgAccum();
106
}
107
108
@Override
109
public Double getValue(WeightedAvgAccum acc) {
110
return acc.count == 0 ? null : acc.sum / acc.count;
111
}
112
113
public void accumulate(WeightedAvgAccum acc, Double value, Integer weight) {
114
if (value != null && weight != null) {
115
acc.sum += value * weight;
116
acc.count += weight;
117
}
118
}
119
120
public void retract(WeightedAvgAccum acc, Double value, Integer weight) {
121
if (value != null && weight != null) {
122
acc.sum -= value * weight;
123
acc.count -= weight;
124
}
125
}
126
127
public void merge(WeightedAvgAccum acc, Iterable<WeightedAvgAccum> it) {
128
for (WeightedAvgAccum a : it) {
129
acc.sum += a.sum;
130
acc.count += a.count;
131
}
132
}
133
}
// Register and use
136
tEnv.createFunction("weighted_avg", WeightedAvgFunction.class);
137
Table result = tEnv.sqlQuery("SELECT weighted_avg(score, weight) FROM tests GROUP BY student_id");
```

## Table Aggregate Functions

```java { .api }
abstract class TableAggregateFunction<T, ACC> extends UserDefinedFunction {
144
public abstract ACC createAccumulator();
145
public abstract void accumulate(ACC accumulator, InputType1 input1, ...);
146
public abstract void emitValue(ACC accumulator, Collector<T> out);
147
148
// Optional methods
149
public void retract(ACC accumulator, InputType1 input1, ...) {}
150
public void emitUpdateWithRetract(ACC accumulator, RetractableCollector<T> out) {}
151
}
```

## Function Annotations

```java { .api }
@FunctionHint(
158
input = @DataTypeHint("STRING"),
159
output = @DataTypeHint("INT")
160
)
161
public class MyFunction extends ScalarFunction {
162
public Integer eval(String input) {
163
return input != null ? input.length() : null;
164
}
165
}
@FunctionHint(
168
input = {@DataTypeHint("DECIMAL(10, 2)"), @DataTypeHint("INT")},
169
output = @DataTypeHint("DECIMAL(10, 2)")
170
)
171
public class MultiplyFunction extends ScalarFunction {
172
public BigDecimal eval(BigDecimal value, Integer multiplier) {
173
return value != null && multiplier != null ?
174
value.multiply(BigDecimal.valueOf(multiplier)) : null;
175
}
176
}
```

## Function Context

```java { .api }
abstract class UserDefinedFunction {
183
public void open(FunctionContext context) throws Exception {}
184
public void close() throws Exception {}
185
}
interface FunctionContext {
188
MetricGroup getMetricGroup();
189
CachedFile getCachedFile(String name);
190
File getCachedFile(String name);
191
DistributedCache.DistributedCacheEntry getCachedFileEntry(String name);
192
int getIndexOfThisSubtask();
193
int getNumberOfParallelSubtasks();
194
ExecutionConfig getExecutionConfig();
195
ClassLoader getUserCodeClassLoader();
196
}
```

**Example with Context:**

```java
public static class LookupFunction extends ScalarFunction {
203
private transient Map<String, String> lookupData;
204
205
@Override
206
public void open(FunctionContext context) throws Exception {
207
// Initialize lookup data from cached file
208
File file = context.getCachedFile("lookup-data");
209
lookupData = loadLookupData(file);
210
}
211
212
public String eval(String key) {
213
return lookupData.get(key);
214
}
215
}
```

## Python Functions

```java { .api }
/**
 * Registration of a function identified by a fully qualified name string.
 *
 * <p>NOTE(review): confirm that this {@code (String, String)} overload exists in
 * the targeted Flink version. Python functions are commonly registered through
 * SQL DDL ({@code CREATE TEMPORARY SYSTEM FUNCTION ... LANGUAGE PYTHON}) instead.
 */
interface TableEnvironment {
    void createTemporarySystemFunction(String name, String fullyQualifiedName);
}
```

Register a Python UDF:

```java
// Register Python function
230
tEnv.createTemporarySystemFunction("py_upper", "my_module.upper_func");
231
232
// Use in SQL
233
Table result = tEnv.sqlQuery("SELECT py_upper(name) FROM users");
```

## Built-in Function Override

```java
// Override built-in function
240
public static class CustomSubstringFunction extends ScalarFunction {
241
public String eval(String str, Integer start, Integer length) {
242
// Custom substring logic
243
return customSubstring(str, start, length);
244
}
245
}
// Register the custom class as a temporary system function named "SUBSTRING",
// the built-in name this section sets out to override.
tEnv.createTemporarySystemFunction("SUBSTRING", CustomSubstringFunction.class);
```

## Types

```java { .api }
abstract class UserDefinedFunction implements Serializable {
254
FunctionKind getKind();
255
TypeInference getTypeInference(DataTypeFactory typeFactory);
256
Set<FunctionRequirement> getRequirements();
257
boolean isDeterministic();
258
}
/** Categories of function, as returned by {@code UserDefinedFunction.getKind()}. */
enum FunctionKind {
    SCALAR,
    TABLE,
    AGGREGATE,
    TABLE_AGGREGATE,
    ASYNC_TABLE,
    OTHER
}
@Retention(RetentionPolicy.RUNTIME)
270
@Target({ElementType.TYPE, ElementType.METHOD})
271
@interface FunctionHint {
272
DataTypeHint[] input() default {};
273
DataTypeHint output() default @DataTypeHint(value = DataTypeHint.NULL);
274
boolean isVarArgs() default false;
275
}
/**
 * Annotation describing a data type.
 * Retained at runtime; applicable to types, methods and parameters.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE, ElementType.METHOD, ElementType.PARAMETER})
@interface DataTypeHint {
    /** Marker string used as the default of {@link #value()}. */
    String NULL = "__NULL__";

    /** Type string such as {@code "STRING"} or {@code "DECIMAL(10, 2)"}; defaults to {@link #NULL}. */
    String value() default NULL;

    /** Class the type is bridged to; {@code Object.class} by default. */
    Class<?> bridgedTo() default Object.class;

    /** Default decimal precision; {@code -1} by default. */
    int defaultDecimalPrecision() default -1;

    /** Default decimal scale; {@code -1} by default. */
    int defaultDecimalScale() default -1;

    /** Global RAW-type allowance flag; {@code true} by default. */
    boolean allowRawGlobally() default true;
}
```