This document covers support for user-defined scalar, table, aggregate, and table aggregate functions in Apache Flink Table Uber Blink.
/**
 * Function registration surface of the table environment, as summarized in this document.
 *
 * <p>Every {@code create*} method comes in two flavors: one taking a function instance and one
 * taking a function class. The examples in this document register a class and then call the
 * function by its registered name from a SQL query.
 *
 * <p>NOTE(review): the difference between plain, temporary and temporary-system registration
 * is not spelled out here; consult the Flink Table API reference before relying on it.
 */
interface TableEnvironment {

    // Registration under a path.
    void createFunction(String path, UserDefinedFunction function);

    void createFunction(String path, Class<? extends UserDefinedFunction> functionClass);

    // Temporary registration under a path.
    void createTemporaryFunction(String path, UserDefinedFunction function);

    void createTemporaryFunction(String path, Class<? extends UserDefinedFunction> functionClass);

    // Temporary system registration under a plain name.
    void createTemporarySystemFunction(String name, UserDefinedFunction function);

    void createTemporarySystemFunction(
            String name, Class<? extends UserDefinedFunction> functionClass);

    // Removal counterparts. The boolean result presumably tells whether a function was
    // actually dropped -- TODO confirm against the Flink Javadoc.
    boolean dropFunction(String path);

    boolean dropTemporaryFunction(String path);

    boolean dropTemporarySystemFunction(String name);
}

/**
 * Base class for user-defined scalar functions.
 *
 * <p>This sketch declares no members; subclasses provide one or more public {@code eval}
 * methods shaped like the commented template below.
 */
abstract class ScalarFunction extends UserDefinedFunction {
    // User implements eval methods
    // public ReturnType eval(InputType1 input1, InputType2 input2, ...);
}

Example:
/**
 * Example scalar function that adds two values, with one {@code eval} overload for
 * {@code Integer} operands and one for {@code Double} operands.
 *
 * <p>Each overload yields {@code null} when at least one operand is {@code null}.
 */
public static class AddFunction extends ScalarFunction {

    /**
     * Adds two integers.
     *
     * @param a first operand, may be {@code null}
     * @param b second operand, may be {@code null}
     * @return {@code a + b}, or {@code null} if either operand is {@code null}
     */
    public Integer eval(Integer a, Integer b) {
        if (a == null || b == null) {
            return null;
        }
        return a + b;
    }

    /**
     * Adds two doubles.
     *
     * @param a first operand, may be {@code null}
     * @param b second operand, may be {@code null}
     * @return {@code a + b}, or {@code null} if either operand is {@code null}
     */
    public Double eval(Double a, Double b) {
        if (a == null || b == null) {
            return null;
        }
        return a + b;
    }
}
// Register AddFunction under the name "add_func", then call it from a SQL query.
tEnv.createFunction("add_func", AddFunction.class);
Table result = tEnv.sqlQuery("SELECT add_func(col1, col2) FROM table1");

/**
 * Base class for user-defined table functions producing rows of type {@code T}.
 *
 * <p>The {@code eval} template is declared {@code void}: output is passed to
 * {@code collect(row)} rather than returned. The split example in this document calls
 * {@code collect} once per produced value.
 *
 * <p>Signatures are schematic, as in the rest of this document.
 */
abstract class TableFunction<T> extends UserDefinedFunction {
    protected void collect(T row);
    // User implements eval methods
    // public void eval(InputType1 input1, InputType2 input2, ...);
}

Example:
/**
 * Example table function that breaks a comma-separated string into its parts and emits each
 * part, trimmed of surrounding whitespace, through {@code collect}.
 */
public static class SplitFunction extends TableFunction<String> {

    /**
     * Emits one value per comma-separated part of {@code str}.
     *
     * @param str comma-separated input; nothing is emitted when it is {@code null}
     */
    public void eval(String str) {
        if (str == null) {
            return;
        }
        String[] parts = str.split(",");
        for (int i = 0; i < parts.length; i++) {
            collect(parts[i].trim());
        }
    }
}
// Register SplitFunction as "split_func" and join each users row with the values the
// function emits for its tags column.
tEnv.createFunction("split_func", SplitFunction.class);
Table result = tEnv.sqlQuery(
        "SELECT name, word FROM users CROSS JOIN LATERAL TABLE(split_func(tags)) AS t(word)"
);

/**
 * Base class for user-defined aggregate functions with result type {@code T} and accumulator
 * type {@code ACC}.
 *
 * <p>The abstract methods must be supplied by every implementation; {@code retract},
 * {@code merge} and {@code resetAccumulator} have empty bodies in this sketch and can be
 * overridden when needed. The {@code InputType1 input1, ...} parameter lists are schematic
 * placeholders for the aggregate's actual input columns.
 */
abstract class AggregateFunction<T, ACC> extends UserDefinedFunction {
    public abstract ACC createAccumulator();
    public abstract T getValue(ACC accumulator);

    // Required: fold one input row into the accumulator.
    public abstract void accumulate(ACC accumulator, InputType1 input1, ...);

    // Optional hooks.
    public void retract(ACC accumulator, InputType1 input1, ...) {}
    public void merge(ACC accumulator, Iterable<ACC> iterable) {}
    public void resetAccumulator(ACC accumulator) {}
}

Example:
/**
 * Mutable accumulator used by {@code WeightedAvgFunction}. Both fields start at zero.
 */
public static class WeightedAvgAccum {
    /** Running total of {@code value * weight} over the accumulated inputs. */
    public double sum = 0.0;

    /** Running total of the weights of the accumulated inputs. */
    public int count = 0;
}
/**
 * Example aggregate function computing a weighted average: the sum of {@code value * weight}
 * divided by the sum of the weights, both tracked in a {@code WeightedAvgAccum}.
 */
public static class WeightedAvgFunction extends AggregateFunction<Double, WeightedAvgAccum> {

    /** Creates a fresh accumulator. */
    @Override
    public WeightedAvgAccum createAccumulator() {
        return new WeightedAvgAccum();
    }

    /**
     * Produces the current weighted average.
     *
     * @return {@code sum / count}, or {@code null} while the total weight is zero
     */
    @Override
    public Double getValue(WeightedAvgAccum acc) {
        if (acc.count == 0) {
            return null;
        }
        return acc.sum / acc.count;
    }

    /** Adds one weighted value; the call is ignored if either argument is {@code null}. */
    public void accumulate(WeightedAvgAccum acc, Double value, Integer weight) {
        if (value == null || weight == null) {
            return;
        }
        acc.sum += value * weight;
        acc.count += weight;
    }

    /** Removes one weighted value; the call is ignored if either argument is {@code null}. */
    public void retract(WeightedAvgAccum acc, Double value, Integer weight) {
        if (value == null || weight == null) {
            return;
        }
        acc.sum -= value * weight;
        acc.count -= weight;
    }

    /** Adds the totals of every accumulator in {@code it} into {@code acc}. */
    public void merge(WeightedAvgAccum acc, Iterable<WeightedAvgAccum> it) {
        for (WeightedAvgAccum other : it) {
            acc.sum += other.sum;
            acc.count += other.count;
        }
    }
}
// Register WeightedAvgFunction as "weighted_avg" and aggregate per student_id.
tEnv.createFunction("weighted_avg", WeightedAvgFunction.class);
Table result = tEnv.sqlQuery("SELECT weighted_avg(score, weight) FROM tests GROUP BY student_id");

/**
 * Base class for user-defined table aggregate functions with output type {@code T} and
 * accumulator type {@code ACC}.
 *
 * <p>Results are written to a collector by {@code emitValue} instead of being returned.
 * {@code retract} and {@code emitUpdateWithRetract} have empty bodies in this sketch. The
 * {@code InputType1 input1, ...} parameter lists are schematic placeholders.
 */
abstract class TableAggregateFunction<T, ACC> extends UserDefinedFunction {
    public abstract ACC createAccumulator();
    public abstract void accumulate(ACC accumulator, InputType1 input1, ...);
    public abstract void emitValue(ACC accumulator, Collector<T> out);

    // Optional hooks.
    public void retract(ACC accumulator, InputType1 input1, ...) {}
    public void emitUpdateWithRetract(ACC accumulator, RetractableCollector<T> out) {}
}

/**
 * Example scalar function returning the length of a string, annotated with explicit
 * input ({@code STRING}) and output ({@code INT}) type hints.
 */
@FunctionHint(
        input = @DataTypeHint("STRING"),
        output = @DataTypeHint("INT")
)
public class MyFunction extends ScalarFunction {

    /**
     * @param input string to measure, may be {@code null}
     * @return {@code input.length()}, or {@code null} for a {@code null} input
     */
    public Integer eval(String input) {
        if (input == null) {
            return null;
        }
        return input.length();
    }
}
/**
 * Example scalar function multiplying a decimal by an integer, annotated with
 * {@code DECIMAL(10, 2)} / {@code INT} input hints and a {@code DECIMAL(10, 2)} output hint.
 */
@FunctionHint(
        input = {@DataTypeHint("DECIMAL(10, 2)"), @DataTypeHint("INT")},
        output = @DataTypeHint("DECIMAL(10, 2)")
)
public class MultiplyFunction extends ScalarFunction {

    /**
     * @param value decimal operand, may be {@code null}
     * @param multiplier integer operand, may be {@code null}
     * @return {@code value * multiplier}, or {@code null} if either operand is {@code null}
     */
    public BigDecimal eval(BigDecimal value, Integer multiplier) {
        if (value == null || multiplier == null) {
            return null;
        }
        BigDecimal factor = BigDecimal.valueOf(multiplier);
        return value.multiply(factor);
    }
}

/**
 * Lifecycle part of the common base class of all user-defined functions.
 *
 * <p>Both hooks have empty bodies here and may throw any {@code Exception}. The lookup example
 * in this document overrides {@code open} to load data before {@code eval} is used.
 */
abstract class UserDefinedFunction {
    public void open(FunctionContext context) throws Exception {}
    public void close() throws Exception {}
}
/**
 * Context object passed to {@code UserDefinedFunction.open(FunctionContext)}.
 *
 * <p>{@code getCachedFile(String)} is declared exactly once, returning {@code java.io.File}.
 * Java does not allow two methods with the same name and parameter types that differ only in
 * return type, and the lookup example in this document assigns the result to a {@code File}.
 *
 * <p>NOTE(review): the set of accessors listed here should be checked against the
 * {@code FunctionContext} Javadoc of the Flink version in use.
 */
interface FunctionContext {
    MetricGroup getMetricGroup();

    /** Returns the cached file registered under {@code name}. */
    File getCachedFile(String name);

    DistributedCache.DistributedCacheEntry getCachedFileEntry(String name);
    int getIndexOfThisSubtask();
    int getNumberOfParallelSubtasks();
    ExecutionConfig getExecutionConfig();
    ClassLoader getUserCodeClassLoader();
}

Example with Context:
/**
 * Example scalar function that resolves keys against a map loaded once in {@code open}.
 */
public static class LookupFunction extends ScalarFunction {

    // Populated in open(); transient, so it is not part of the serialized function.
    private transient Map<String, String> lookupData;

    /**
     * Loads the lookup map from the cached file named {@code "lookup-data"}.
     *
     * @param context runtime context giving access to the cached file
     * @throws Exception propagated from the context or from {@code loadLookupData}
     */
    @Override
    public void open(FunctionContext context) throws Exception {
        File cached = context.getCachedFile("lookup-data");
        lookupData = loadLookupData(cached);
    }

    /**
     * @param key lookup key
     * @return the value mapped to {@code key}, or {@code null} when the map has no such entry
     */
    public String eval(String key) {
        String match = lookupData.get(key);
        return match;
    }
}

/**
 * Additional registration overload that takes a fully qualified function name as a string
 * instead of a Java class or instance; the snippet that follows passes a Python function name.
 *
 * <p>NOTE(review): confirm that this overload exists in the targeted Flink version.
 */
interface TableEnvironment {
    void createTemporarySystemFunction(String name, String fullyQualifiedName);
}

Register Python UDF:
// Make the Python function my_module.upper_func callable as "py_upper".
tEnv.createTemporarySystemFunction("py_upper", "my_module.upper_func");

// Call the registered function from a SQL query.
Table result = tEnv.sqlQuery("SELECT py_upper(name) FROM users");

// Override built-in function
/**
 * Example scalar function that hands substring extraction to a custom helper.
 */
public static class CustomSubstringFunction extends ScalarFunction {

    /**
     * @param str source string
     * @param start start argument, passed through unchanged
     * @param length length argument, passed through unchanged
     * @return whatever {@code customSubstring(str, start, length)} returns
     */
    public String eval(String str, Integer start, Integer length) {
        // All substring logic lives in the helper; arguments are not validated here.
        String result = customSubstring(str, start, length);
        return result;
    }
}
// Register the custom implementation as a temporary system function named "SUBSTRING".
tEnv.createTemporarySystemFunction("SUBSTRING", CustomSubstringFunction.class);

/**
 * Metadata part of the common base class of all user-defined functions.
 *
 * <p>The class is {@code Serializable}. The accessors are listed without bodies, so this is a
 * schematic summary rather than compilable source.
 */
abstract class UserDefinedFunction implements Serializable {

    /** Kind of this function; see {@code FunctionKind}. */
    FunctionKind getKind();

    /** Type inference for this function, built with the given type factory. */
    TypeInference getTypeInference(DataTypeFactory typeFactory);

    /** Requirements declared by this function. */
    Set<FunctionRequirement> getRequirements();

    /** Whether this function is deterministic. */
    boolean isDeterministic();
}
/**
 * Categories of user-defined functions, the result type of
 * {@code UserDefinedFunction.getKind()}.
 *
 * <p>NOTE(review): the per-constant notes below are inferred from the constant names and the
 * base classes shown in this document; confirm against the Flink Javadoc.
 */
enum FunctionKind {
    /** Presumably the kind of {@code ScalarFunction} implementations. */
    SCALAR,

    /** Presumably the kind of {@code TableFunction} implementations. */
    TABLE,

    /** Presumably the kind of {@code AggregateFunction} implementations. */
    AGGREGATE,

    /** Presumably the kind of {@code TableAggregateFunction} implementations. */
    TABLE_AGGREGATE,

    /** No matching base class is shown in this document. */
    ASYNC_TABLE,

    /** No matching base class is shown in this document. */
    OTHER
}
/**
 * Annotation carrying explicit input and output type hints for a function. It may be placed
 * on types and methods and is retained at runtime.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE, ElementType.METHOD})
@interface FunctionHint {

    /** Hints for the input arguments, in order; empty by default. */
    DataTypeHint[] input() default {};

    /** Hint for the result; defaults to a hint whose value is {@code DataTypeHint.NULL}. */
    DataTypeHint output() default @DataTypeHint(DataTypeHint.NULL);

    /** Var-args flag; {@code false} by default. */
    boolean isVarArgs() default false;
}
/**
 * Annotation describing a data type. The examples in this document pass a type string such as
 * {@code "STRING"}, {@code "INT"} or {@code "DECIMAL(10, 2)"} as its value. It may be placed on
 * types, methods and parameters and is retained at runtime.
 *
 * <p>NOTE(review): how the elements other than {@code value} are interpreted is not shown in
 * this document; confirm against the Flink Javadoc.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE, ElementType.METHOD, ElementType.PARAMETER})
@interface DataTypeHint {

    /** Sentinel used as the default of {@link #value()}. */
    String NULL = "__NULL__";

    /** Type string; defaults to the {@link #NULL} sentinel. */
    String value() default NULL;

    /** Class the type is bridged to; {@code Object.class} by default. */
    Class<?> bridgedTo() default Object.class;

    /** Default decimal precision; {@code -1} by default. */
    int defaultDecimalPrecision() default -1;

    /** Default decimal scale; {@code -1} by default. */
    int defaultDecimalScale() default -1;

    /** Raw-type flag; {@code true} by default. */
    boolean allowRawGlobally() default true;
}