Apache Flink Table API for SQL-like operations on streaming and batch data
—
The Flink Table API supports user-defined functions (UDFs) to extend functionality with custom logic. Three main types are supported: scalar functions (1-to-1), table functions (1-to-N), and aggregate functions (N-to-1).
All user-defined functions extend from the base UserDefinedFunction class.
/**
 * Base class for all user-defined functions
 */
abstract class UserDefinedFunction {
/**
 * Initialization method called when function is opened
 * @param context Function context providing runtime information
 */
def open(context: FunctionContext): Unit = {}
/**
 * Cleanup method called when function is closed
 */
def close(): Unit = {}
/**
 * Indicates whether the function is deterministic
 * @return True if function always produces same output for same input
 */
def isDeterministic: Boolean = true
}Functions that take one or more input values and return a single value.
/**
 * Base class for scalar functions (1-to-1 mapping)
 *
 * NOTE(review): concrete subclasses define one or more `eval` methods that
 * are not declared here — presumably resolved reflectively by the runtime,
 * as in Flink; confirm.
 */
abstract class ScalarFunction extends UserDefinedFunction {
/**
 * Creates a function call expression for this scalar function
 * @param params Parameters for the function call
 * @return Expression representing the function call
 */
def apply(params: Expression*): Expression
/**
 * Override to specify custom result type
 * @param signature Array of parameter classes
 * @return Type information for the result
 */
def getResultType(signature: Array[Class[_]]): TypeInformation[_] = null
/**
 * Override to specify custom parameter types
 * @param signature Array of parameter classes
 * @return Array of parameter type information
 */
def getParameterTypes(signature: Array[Class[_]]): Array[TypeInformation[_]] = null
}Usage Examples:
// Simple scalar function: increments an integer by one.
class AddOne extends ScalarFunction {
  /** @return the input incremented by one */
  def eval(x: Int): Int = 1 + x
}
// Scalar function taking two parameters: joins two strings end-to-end.
class StringConcat extends ScalarFunction {
  /** @return `a` immediately followed by `b` */
  def eval(a: String, b: String): String = s"$a$b"
}
// Scalar function with variable arguments: joins strings with a separator.
class ConcatWs extends ScalarFunction {
  /** @return every element of `strings` joined with `separator` between them */
  def eval(separator: String, strings: String*): String =
    strings.mkString(separator)
}
// Complex scalar function with type override
// NOTE(review): relies on an opaque `JSON.parse` helper and assumes the
// document contains "id" and "name" fields — confirm against that helper.
class ParseJson extends ScalarFunction {
// Extracts the (id, name) pair of a JSON document into a two-field Row.
def eval(jsonStr: String): Row = {
// Parse JSON and return Row
val parsed = JSON.parse(jsonStr)
Row.of(parsed.get("id"), parsed.get("name"))
}
// Declares the named output schema (id: LONG, name: STRING) so the planner
// does not have to infer it from the generic Row return type.
override def getResultType(signature: Array[Class[_]]): TypeInformation[_] = {
Types.ROW(Array("id", "name"), Array(Types.LONG, Types.STRING))
}
}
// Register and use scalar functions
// Registration under a name makes the function callable from SQL strings;
// the Table API line below calls the functions via expression symbols.
tEnv.registerFunction("addOne", new AddOne())
tEnv.registerFunction("concat", new StringConcat())
val result = table.select('id, addOne('age), concat('firstName, 'lastName))
val sqlResult = tEnv.sqlQuery("SELECT id, addOne(age), concat(firstName, lastName) FROM Users")Functions that take one or more input values and return multiple rows (1-to-N mapping).
/**
 * Base class for table functions (1-to-N mapping)
 * @tparam T Type of output rows
 */
abstract class TableFunction[T] extends UserDefinedFunction {
/**
 * Collects an output row
 *
 * NOTE(review): declared abstract here, yet the example subclasses in this
 * document call it without implementing it — in Flink this is a concrete
 * method backed by a runtime-supplied collector; confirm.
 * @param result Output row to emit
 */
protected def collect(result: T): Unit
/**
 * Override to specify custom result type
 * @param signature Array of parameter classes
 * @return Type information for the result
 */
def getResultType(signature: Array[Class[_]]): TypeInformation[T] = null
/**
 * Override to specify custom parameter types
 * @param signature Array of parameter classes
 * @return Array of parameter type information
 */
def getParameterTypes(signature: Array[Class[_]]): Array[TypeInformation[_]] = null
}Usage Examples:
// Split string into multiple rows
class SplitFunction extends TableFunction[String] {
  /** Emits one output row per token of `str` split on `separator`. */
  def eval(str: String, separator: String): Unit = {
    for (token <- str.split(separator)) {
      collect(token)
    }
  }
}
// Parse CSV row into structured data
// NOTE(review): assumes at least three comma-separated fields with numeric
// fields 1 and 2 — malformed rows will throw; confirm input guarantees.
class ParseCsv extends TableFunction[Row] {
def eval(csvRow: String): Unit = {
val fields = csvRow.split(",")
collect(Row.of(fields(0), fields(1).toInt, fields(2).toDouble))
}
// Declares the named, typed schema of the emitted rows.
override def getResultType(signature: Array[Class[_]]): TypeInformation[Row] = {
Types.ROW(Array("name", "age", "salary"), Array(Types.STRING, Types.INT, Types.DOUBLE))
}
}
// Generate number sequence
// (Note: this class shadows scala Range within its scope.)
class Range extends TableFunction[Long] {
  /** Emits every value from `start` (inclusive) up to `end` (exclusive). */
  def eval(start: Long, end: Long): Unit = {
    var current = start
    while (current < end) {
      collect(current)
      current += 1L
    }
  }
}
// Register and use table functions
tEnv.registerFunction("split", new SplitFunction())
tEnv.registerFunction("range", new Range())
// Use in join (LATERAL TABLE)
// Each input row is joined with every row the table function emits for it.
val result = table
.join(new SplitFunction() as ('word))
.select('id, 'text, 'word)
val sqlResult = tEnv.sqlQuery("""
SELECT u.id, u.tags, t.word
FROM Users u, LATERAL TABLE(split(u.tags, ',')) as t(word)
""")Functions that take multiple input values and return a single aggregated result (N-to-1 mapping).
/**
 * Base class for aggregate functions (N-to-1 mapping)
 *
 * NOTE(review): in Flink, typed `accumulate`/`retract` overloads defined by
 * subclasses are resolved reflectively; the varargs declarations below are a
 * documentation device — subclasses in this document overload rather than
 * override them. Confirm against the runtime.
 * @tparam T Type of the final result
 * @tparam ACC Type of the accumulator
 */
abstract class AggregateFunction[T, ACC] extends UserDefinedFunction {
/**
 * Creates and initializes the accumulator
 * @return New accumulator instance
 */
def createAccumulator(): ACC
/**
 * Processes an input value and updates the accumulator
 * @param accumulator Current accumulator
 * @param input Input values (method should be overloaded for different arities)
 */
def accumulate(accumulator: ACC, input: Any*): Unit
/**
 * Extracts the final result from the accumulator
 * @param accumulator Final accumulator state
 * @return Aggregated result
 */
def getValue(accumulator: ACC): T
/**
 * Retracts an input value from the accumulator (for streaming)
 * @param accumulator Current accumulator
 * @param input Input values to retract
 */
def retract(accumulator: ACC, input: Any*): Unit = {
throw new UnsupportedOperationException("retract method is not implemented")
}
/**
 * Merges multiple accumulators into one (for distributed aggregation)
 * @param accumulator Target accumulator
 * @param otherAccumulators Other accumulators to merge
 */
def merge(accumulator: ACC, otherAccumulators: java.lang.Iterable[ACC]): Unit = {
throw new UnsupportedOperationException("merge method is not implemented")
}
/**
 * Override to specify custom result type
 * @param signature Array of parameter classes
 * @return Type information for the result
 */
def getResultType(signature: Array[Class[_]]): TypeInformation[T] = null
/**
 * Override to specify custom accumulator type
 * @return Type information for the accumulator
 */
def getAccumulatorType: TypeInformation[ACC] = null
}Usage Examples:
// Simple sum aggregate function.
//
// Fix: the original used a bare `Long` accumulator and evaluated `acc + value`
// without storing the result, so the sum never changed. Accumulators must be
// mutated in place, so the running total now lives in a mutable holder class.
class SumAccumulator {
  // Running total of all accumulated values.
  var sum: Long = 0L
}
class SumFunction extends AggregateFunction[Long, SumAccumulator] {
  /** @return a fresh accumulator starting at zero */
  override def createAccumulator(): SumAccumulator = new SumAccumulator
  /** Adds `value` to the running total in place. */
  def accumulate(acc: SumAccumulator, value: Long): Unit = {
    acc.sum += value
  }
  /** @return the accumulated sum */
  override def getValue(accumulator: SumAccumulator): Long = accumulator.sum
}
// Weighted average with complex accumulator.
//
// Fix: the accumulator's case-class parameters were immutable (`val`), so
// `acc.sum += ...` in accumulate/retract/merge could not compile. The fields
// are now declared `var` so the accumulator can be updated in place, as the
// aggregate-function contract requires.
case class WeightedAvgAccumulator(var sum: Double, var count: Long)
class WeightedAvg extends AggregateFunction[Double, WeightedAvgAccumulator] {
  /** @return a zeroed accumulator */
  override def createAccumulator(): WeightedAvgAccumulator = {
    WeightedAvgAccumulator(0.0, 0L)
  }
  /** Folds one (value, weight) observation into the accumulator. */
  def accumulate(acc: WeightedAvgAccumulator, value: Double, weight: Long): Unit = {
    acc.sum += value * weight
    acc.count += weight
  }
  /** @return the weighted mean, or 0.0 when nothing was accumulated */
  override def getValue(accumulator: WeightedAvgAccumulator): Double = {
    if (accumulator.count == 0) 0.0 else accumulator.sum / accumulator.count
  }
  /** Removes a previously accumulated observation (streaming retraction). */
  def retract(acc: WeightedAvgAccumulator, value: Double, weight: Long): Unit = {
    acc.sum -= value * weight
    acc.count -= weight
  }
  /** Merges partial accumulators produced by distributed pre-aggregation. */
  def merge(acc: WeightedAvgAccumulator, otherAccs: java.lang.Iterable[WeightedAvgAccumulator]): Unit = {
    otherAccs.forEach { other =>
      acc.sum += other.sum
      acc.count += other.count
    }
  }
}
// String concatenation aggregate
class StringAgg extends AggregateFunction[String, StringBuilder] {
  /** @return an empty builder */
  override def createAccumulator(): StringBuilder = new StringBuilder()
  /** Appends `value`, preceded by `separator` unless the builder is empty. */
  def accumulate(acc: StringBuilder, value: String, separator: String): Unit = {
    if (acc.length > 0) {
      acc.append(separator)
    }
    acc.append(value)
  }
  /** @return the joined string */
  override def getValue(accumulator: StringBuilder): String = accumulator.toString
}
// Register and use aggregate functions
tEnv.registerFunction("weightedAvg", new WeightedAvg())
tEnv.registerFunction("stringAgg", new StringAgg())
// Aggregates are applied per group after groupBy.
val result = table
.groupBy('department)
.select('department, weightedAvg('salary, 'years), stringAgg('name, ", "))
val sqlResult = tEnv.sqlQuery("""
SELECT department,
weightedAvg(salary, years) as avgSalary,
stringAgg(name, ', ') as employees
FROM Employees
GROUP BY department
""")Runtime context providing access to metrics, cached files, and other runtime information.
/**
 * Context interface providing runtime information to functions
 */
trait FunctionContext {
/**
 * Gets the metric group for registering custom metrics
 * @return Metric group for the function
 */
def getMetricGroup: MetricGroup
/**
 * Gets a cached file by name (distributed cache)
 * @param name Name of the cached file
 * @return File handle to the cached file
 */
def getCachedFile(name: String): java.io.File
/**
 * Gets the job parameter value
 * @param key Parameter key
 * @param defaultValue Default value if key not found
 * @return Parameter value
 */
def getJobParameter(key: String, defaultValue: String): String
}Usage Examples:
// Scalar function that reports a custom invocation-count metric.
class MetricsFunction extends ScalarFunction {
// Initialized in open(); `_` leaves it null until the function is opened.
private var counter: Counter = _
override def open(context: FunctionContext): Unit = {
super.open(context)
// Register custom metrics
counter = context.getMetricGroup.counter("custom_function_calls")
}
// Uppercases the input, incrementing the metric on every call.
def eval(value: String): String = {
counter.inc()
value.toUpperCase
}
}
// Scalar function whose behavior is configured via job parameters.
class ConfigurableFunction extends ScalarFunction {
// Multiplier loaded from job configuration in open().
private var multiplier: Double = _
override def open(context: FunctionContext): Unit = {
super.open(context)
// Read configuration from job parameters
// NOTE(review): toDouble throws NumberFormatException if the configured
// value is not numeric — confirm whether a safer parse is wanted.
multiplier = context.getJobParameter("multiplier", "1.0").toDouble
}
// Scales the input by the configured multiplier.
def eval(value: Double): Double = value * multiplier
}Additional features for complex function implementations.
// Generic function interface for type-safe implementations
trait TypeInference {
// Computes the result type of a call from its argument information.
def inferTypes(callContext: CallContext): TypeInference.Result
}
// Function that can be used in both scalar and table contexts
abstract class PolymorphicFunction extends UserDefinedFunction {
// Scalar form: one result per call.
def eval(args: Any*): Any
// Table form: zero or more results per call.
def evalTable(args: Any*): java.lang.Iterable[_]
}Usage Examples:
// Function with custom type inference
// NOTE(review): eval returns Any, so the runtime type is only as precise as
// inferTypes makes it; assumes argument 1 is available as a literal — confirm.
class FlexibleParseFunction extends ScalarFunction with TypeInference {
// Parses `input` according to `format`; unknown formats pass through as-is.
def eval(input: String, format: String): Any = {
format match {
case "int" => input.toInt
case "double" => input.toDouble
case "boolean" => input.toBoolean
case _ => input
}
}
// Derives the declared SQL result type from the `format` literal at plan time.
override def inferTypes(callContext: CallContext): TypeInference.Result = {
// Custom type inference logic based on parameters
val formatLiteral = callContext.getArgumentValue(1, classOf[String])
// `orElse` supplies "string" when the format argument is not a literal.
val resultType = formatLiteral.orElse("string") match {
case "int" => Types.INT
case "double" => Types.DOUBLE
case "boolean" => Types.BOOLEAN
case _ => Types.STRING
}
TypeInference.Result.success(resultType)
}
}// Registration methods on TableEnvironment
// One registration overload per user-defined function kind.
def registerFunction(name: String, function: ScalarFunction): Unit
def registerFunction(name: String, function: TableFunction[_]): Unit
def registerFunction(name: String, function: AggregateFunction[_, _]): Unit
// Usage in Table API (Scala)
table.select('field, myScalarFunction('input))
table.join(myTableFunction('field) as ('output))
table.groupBy('key).select('key, myAggregateFunction('value))
// Usage in SQL
tEnv.sqlQuery("SELECT field, myScalarFunction(input) FROM MyTable")
tEnv.sqlQuery("SELECT * FROM MyTable, LATERAL TABLE(myTableFunction(field)) as t(output)")
tEnv.sqlQuery("SELECT key, myAggregateFunction(value) FROM MyTable GROUP BY key")abstract class UserDefinedFunction
// API summary: the class hierarchy of user-defined function types.
abstract class ScalarFunction extends UserDefinedFunction
abstract class TableFunction[T] extends UserDefinedFunction
abstract class AggregateFunction[T, ACC] extends UserDefinedFunction
trait FunctionContext
trait MetricGroup
case class Counter()
// Exception types
class ValidationException(message: String) extends RuntimeException(message)Install with Tessl CLI
npx tessl i tessl/maven-org-apache-flink--flink-table-2-11