or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

docs

catalog-management.md · complex-event-processing.md · core-table-operations.md · datastream-integration.md · index.md · sql-processing.md · type-system.md · user-defined-functions.md · window-operations.md
tile.json

docs/user-defined-functions.md

User-Defined Functions

This document covers custom scalar, table, aggregate, and table aggregate functions in Apache Flink Table Uber Blink: how to register them, how to annotate their types, and how to access the runtime context.

Function Registration

/**
 * Function registration and removal methods of the table environment.
 *
 * <p>Registration comes in three flavours (plain, temporary, temporary system), each listed
 * with a function instance and with a function class. The plain and temporary variants take
 * a {@code path}; the temporary system variants take a simple {@code name}.
 *
 * <p>NOTE(review): the instance-taking {@code createFunction(String, UserDefinedFunction)}
 * overload does not seem to exist in the public TableEnvironment API (only the class-based
 * form does) -- confirm against the Javadoc of the Flink version in use.
 */
interface TableEnvironment {
    // Register a function under a path.
    void createFunction(String path, UserDefinedFunction function);
    void createFunction(String path, Class<? extends UserDefinedFunction> functionClass);
    // Register a temporary function under a path.
    void createTemporaryFunction(String path, UserDefinedFunction function);
    void createTemporaryFunction(String path, Class<? extends UserDefinedFunction> functionClass);
    // Register a temporary system function under a simple name.
    void createTemporarySystemFunction(String name, UserDefinedFunction function);
    void createTemporarySystemFunction(String name, Class<? extends UserDefinedFunction> functionClass);
    // Remove a previously registered function.
    // NOTE(review): the boolean presumably reports whether something was actually dropped -- confirm.
    boolean dropFunction(String path);
    boolean dropTemporaryFunction(String path);
    boolean dropTemporarySystemFunction(String name);
}

Scalar Functions

/**
 * Base class for scalar user-defined functions.
 *
 * <p>The class declares no abstract method to override. A subclass instead supplies one or
 * more public {@code eval} methods; the listing below shows their general shape.
 */
abstract class ScalarFunction extends UserDefinedFunction {
    // The implementer writes the eval method(s); the signature is free-form:
    // public ReturnType eval(InputType1 input1, InputType2 input2, ...);
}

Example:

/**
 * Scalar function that adds two operands, overloaded for {@code Integer} and {@code Double}.
 *
 * <p>Both overloads return {@code null} when either operand is {@code null}.
 */
public static class AddFunction extends ScalarFunction {
    /** Adds two integers; yields {@code null} if either one is missing. */
    public Integer eval(Integer a, Integer b) {
        if (a == null || b == null) {
            return null;
        }
        return a + b;
    }

    /** Adds two doubles; yields {@code null} if either one is missing. */
    public Double eval(Double a, Double b) {
        if (a == null || b == null) {
            return null;
        }
        return a + b;
    }
}

// Register the class under "add_func", then call it from a SQL query.
tEnv.createFunction("add_func", AddFunction.class);
String addQuery = "SELECT add_func(col1, col2) FROM table1";
Table result = tEnv.sqlQuery(addQuery);

Table Functions

/**
 * Base class for table user-defined functions producing rows of type {@code T}.
 *
 * <p>As with scalar functions, the implementer writes public {@code eval} methods. Here they
 * return {@code void} and hand each output row to {@link #collect}.
 *
 * <p>This listing is schematic: {@code collect} is shown without a body.
 */
abstract class TableFunction<T> extends UserDefinedFunction {
    // Called from eval to output one row.
    protected void collect(T row);
    // The implementer writes the eval method(s); the signature is free-form:
    // public void eval(InputType1 input1, InputType2 input2, ...);
}

Example:

/**
 * Table function that splits a comma-separated string and emits one row per piece.
 *
 * <p>Each piece is trimmed before being emitted. A {@code null} input emits nothing.
 */
public static class SplitFunction extends TableFunction<String> {
    public void eval(String str) {
        if (str == null) {
            return;
        }
        String[] pieces = str.split(",");
        for (int i = 0; i < pieces.length; i++) {
            collect(pieces[i].trim());
        }
    }
}

// Register the splitter, then join every user row against the rows it emits for `tags`.
tEnv.createFunction("split_func", SplitFunction.class);
String splitQuery =
    "SELECT name, word FROM users CROSS JOIN LATERAL TABLE(split_func(tags)) AS t(word)";
Table result = tEnv.sqlQuery(splitQuery);

Aggregate Functions

/**
 * Base class for aggregate user-defined functions.
 *
 * <p>{@code T} is the result type and {@code ACC} the accumulator type that carries the
 * intermediate state between calls.
 *
 * <p>This listing is schematic rather than compilable: the parameter lists of
 * {@code accumulate} and {@code retract} depend on the function's inputs and are shown with
 * placeholders.
 * NOTE(review): because of that, {@code accumulate} is presumably not a true abstract method
 * of the base class but a method the implementer must provide -- confirm.
 */
abstract class AggregateFunction<T, ACC> extends UserDefinedFunction {
    // Creates a fresh accumulator.
    public abstract ACC createAccumulator();
    // Derives the result from an accumulator.
    public abstract T getValue(ACC accumulator);
    
    // Required methods
    // Folds one input row into the accumulator.
    public abstract void accumulate(ACC accumulator, InputType1 input1, ...);
    
    // Optional methods (empty bodies in this listing)
    public void retract(ACC accumulator, InputType1 input1, ...) {}
    public void merge(ACC accumulator, Iterable<ACC> iterable) {}
    public void resetAccumulator(ACC accumulator) {}
}

Example:

/** Mutable accumulator for the weighted average example; both fields start at zero. */
public static class WeightedAvgAccum {
    /** Running total; zero until something is accumulated. */
    public double sum;
    /** Running weight total; zero until something is accumulated. */
    public int count;
}

/**
 * Aggregate function computing a weighted average.
 *
 * <p>The accumulator keeps the sum of {@code value * weight} in {@code sum} and the sum of
 * the weights in {@code count}. The result is {@code sum / count}, or {@code null} while
 * {@code count} is zero. Rows with a {@code null} value or weight are ignored.
 */
public static class WeightedAvgFunction extends AggregateFunction<Double, WeightedAvgAccum> {
    @Override
    public WeightedAvgAccum createAccumulator() {
        WeightedAvgAccum fresh = new WeightedAvgAccum();
        return fresh;
    }

    @Override
    public Double getValue(WeightedAvgAccum acc) {
        // No weight accumulated yet: there is no average to report.
        if (acc.count == 0) {
            return null;
        }
        return acc.sum / acc.count;
    }

    /** Adds one (value, weight) pair to the accumulator. */
    public void accumulate(WeightedAvgAccum acc, Double value, Integer weight) {
        if (value == null || weight == null) {
            return;
        }
        acc.sum = acc.sum + value * weight;
        acc.count = acc.count + weight;
    }

    /** Removes one previously accumulated (value, weight) pair. */
    public void retract(WeightedAvgAccum acc, Double value, Integer weight) {
        if (value == null || weight == null) {
            return;
        }
        acc.sum = acc.sum - value * weight;
        acc.count = acc.count - weight;
    }

    /** Folds every accumulator in {@code it} into {@code acc}. */
    public void merge(WeightedAvgAccum acc, Iterable<WeightedAvgAccum> it) {
        for (WeightedAvgAccum other : it) {
            acc.sum = acc.sum + other.sum;
            acc.count = acc.count + other.count;
        }
    }
}

// Register the aggregate, then compute one weighted average per student.
tEnv.createFunction("weighted_avg", WeightedAvgFunction.class);
String weightedAvgQuery = "SELECT weighted_avg(score, weight) FROM tests GROUP BY student_id";
Table result = tEnv.sqlQuery(weightedAvgQuery);

Table Aggregate Functions

/**
 * Base class for table aggregate user-defined functions.
 *
 * <p>{@code T} is the type of the emitted rows and {@code ACC} the accumulator type. Unlike
 * {@code AggregateFunction}, results are pushed to a collector instead of being returned.
 *
 * <p>This listing is schematic rather than compilable: the parameter lists of
 * {@code accumulate} and {@code retract} are placeholders.
 */
abstract class TableAggregateFunction<T, ACC> extends UserDefinedFunction {
    // Creates a fresh accumulator.
    public abstract ACC createAccumulator();
    // Folds one input row into the accumulator.
    public abstract void accumulate(ACC accumulator, InputType1 input1, ...);
    // Emits the current result rows to `out`.
    public abstract void emitValue(ACC accumulator, Collector<T> out);
    
    // Optional methods (empty bodies in this listing)
    public void retract(ACC accumulator, InputType1 input1, ...) {}
    public void emitUpdateWithRetract(ACC accumulator, RetractableCollector<T> out) {}
}

Function Annotations

/**
 * Scalar function returning the length of a string, or {@code null} for a {@code null} input.
 *
 * <p>The hint declares the SQL types explicitly: STRING in, INT out.
 */
@FunctionHint(input = @DataTypeHint("STRING"), output = @DataTypeHint("INT"))
public class MyFunction extends ScalarFunction {
    public Integer eval(String input) {
        if (input == null) {
            return null;
        }
        return input.length();
    }
}

/**
 * Scalar function multiplying a decimal by an integer factor.
 *
 * <p>Returns {@code null} when either argument is {@code null}. The hint declares
 * DECIMAL(10, 2) and INT as inputs and DECIMAL(10, 2) as output.
 */
@FunctionHint(
    input = {
        @DataTypeHint("DECIMAL(10, 2)"),
        @DataTypeHint("INT")
    },
    output = @DataTypeHint("DECIMAL(10, 2)")
)
public class MultiplyFunction extends ScalarFunction {
    public BigDecimal eval(BigDecimal value, Integer multiplier) {
        if (value == null || multiplier == null) {
            return null;
        }
        BigDecimal factor = BigDecimal.valueOf(multiplier);
        return value.multiply(factor);
    }
}

Function Context

/**
 * Lifecycle hooks shared by all user-defined functions.
 *
 * <p>Both hooks have empty default bodies, so a subclass overrides only the ones it needs.
 */
abstract class UserDefinedFunction {
    // Setup hook; receives the function context. Does nothing by default.
    public void open(FunctionContext context) throws Exception {}
    // Teardown hook. Does nothing by default.
    public void close() throws Exception {}
}

/**
 * Runtime information handed to {@code UserDefinedFunction.open(FunctionContext)}.
 *
 * <p>{@code getCachedFile} is declared once, returning {@link File}. A second declaration
 * with the same name and parameter list but a different return type is not legal Java, and
 * the lookup example on this page assigns the result to a {@code File}.
 *
 * <p>NOTE(review): this listing is a summary; confirm the exact set of accessors against the
 * FunctionContext Javadoc of the Flink version in use.
 */
interface FunctionContext {
    /** Metric group of the running function. */
    MetricGroup getMetricGroup();
    /** File registered under {@code name}. */
    File getCachedFile(String name);
    /** Cache entry registered under {@code name}. */
    DistributedCache.DistributedCacheEntry getCachedFileEntry(String name);
    int getIndexOfThisSubtask();
    int getNumberOfParallelSubtasks();
    ExecutionConfig getExecutionConfig();
    /** Class loader for user code. */
    ClassLoader getUserCodeClassLoader();
}

Example with Context:

/**
 * Scalar function that resolves a key against a lookup table loaded in {@code open}.
 *
 * <p>The table is read from the cached file named {@code "lookup-data"}. The field is
 * {@code transient} and is populated in {@code open} rather than at construction.
 */
public static class LookupFunction extends ScalarFunction {
    private transient Map<String, String> lookupData;

    @Override
    public void open(FunctionContext context) throws Exception {
        // Fetch the cached file first, then parse it into the in-memory map.
        File cached = context.getCachedFile("lookup-data");
        this.lookupData = loadLookupData(cached);
    }

    /** Returns the value mapped to {@code key}, or {@code null} when the map has none. */
    public String eval(String key) {
        String match = lookupData.get(key);
        return match;
    }
}

Python Functions

/**
 * Registration method this page lists for Python functions: a function name plus the fully
 * qualified name of the Python function.
 *
 * <p>NOTE(review): a two-String {@code createTemporarySystemFunction} overload does not seem
 * to exist in the public TableEnvironment API. Python functions are normally registered with
 * {@code CREATE TEMPORARY SYSTEM FUNCTION ... LANGUAGE PYTHON} through {@code executeSql} --
 * confirm against the Javadoc of the Flink version in use.
 */
interface TableEnvironment {
    void createTemporarySystemFunction(String name, String fullyQualifiedName);
}

Register Python UDF:

// Register the Python function through SQL DDL. TableEnvironment offers no
// createTemporarySystemFunction(String, String) overload; LANGUAGE PYTHON marks
// 'my_module.upper_func' as a Python object path rather than a JVM class name.
// NOTE(review): the Python module must also be shipped with the job (for example through
// the `python.files` option) -- confirm the deployment setup.
tEnv.executeSql(
    "CREATE TEMPORARY SYSTEM FUNCTION py_upper AS 'my_module.upper_func' LANGUAGE PYTHON"
);

// Use in SQL
Table result = tEnv.sqlQuery("SELECT py_upper(name) FROM users");

Built-in Function Override

// Replacement implementation for a built-in function.
/** Scalar function that delegates to {@code customSubstring} for its result. */
public static class CustomSubstringFunction extends ScalarFunction {
    public String eval(String str, Integer start, Integer length) {
        // All substring behavior lives in the helper.
        String piece = customSubstring(str, start, length);
        return piece;
    }
}

tEnv.createTemporarySystemFunction("SUBSTRING", CustomSubstringFunction.class);

Types

/**
 * Type-level view of the common base class of all user-defined functions.
 *
 * <p>The class is {@code Serializable}. This listing is schematic: the methods are shown
 * without bodies.
 */
abstract class UserDefinedFunction implements Serializable {
    // Which FunctionKind this function is.
    FunctionKind getKind();
    // Type inference for this function, built with the given type factory.
    TypeInference getTypeInference(DataTypeFactory typeFactory);
    // Requirements this function declares.
    Set<FunctionRequirement> getRequirements();
    // Whether the function is deterministic.
    boolean isDeterministic();
}

/** The kinds of function that {@code UserDefinedFunction.getKind()} can report. */
enum FunctionKind {
    // Declaration order is kept as-is: it determines the ordinals.
    SCALAR, TABLE,
    AGGREGATE, TABLE_AGGREGATE,
    ASYNC_TABLE,
    OTHER;
}

/**
 * Annotation declaring the argument and result types of a function.
 *
 * <p>Applicable to types and methods and retained at runtime. All members are optional.
 */
@Target({ElementType.TYPE, ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@interface FunctionHint {
    /** Argument type hints; empty by default. */
    DataTypeHint[] input() default {};

    /** Result type hint; defaults to a hint whose value is {@code DataTypeHint.NULL}. */
    DataTypeHint output() default @DataTypeHint(DataTypeHint.NULL);

    /** Defaults to {@code false}. */
    boolean isVarArgs() default false;
}

/**
 * Annotation describing a single data type, given as a type string in {@link #value()}.
 *
 * <p>Applicable to types, methods and parameters and retained at runtime. All members are
 * optional.
 */
@Target({ElementType.TYPE, ElementType.METHOD, ElementType.PARAMETER})
@Retention(RetentionPolicy.RUNTIME)
@interface DataTypeHint {
    /** Sentinel string used as the default of {@link #value()}. */
    String NULL = "__NULL__";

    /** The type string, e.g. {@code "STRING"} or {@code "DECIMAL(10, 2)"}. */
    String value() default NULL;

    /** Defaults to {@code Object.class}. */
    Class<?> bridgedTo() default Object.class;

    /** Defaults to {@code -1}. */
    int defaultDecimalPrecision() default -1;

    /** Defaults to {@code -1}. */
    int defaultDecimalScale() default -1;

    /** Defaults to {@code true}. */
    boolean allowRawGlobally() default true;
}