
# User-Defined Functions

The Flink Table API supports user-defined functions (UDFs) to extend functionality with custom logic. Three main types are supported: scalar functions (1-to-1), table functions (1-to-N), and aggregate functions (N-to-1).

## Capabilities

### Base Function Framework

All user-defined functions extend from the base UserDefinedFunction class.

```scala { .api }
/**
 * Base class for all user-defined functions
 */
abstract class UserDefinedFunction {
  /**
   * Initialization method called when function is opened
   * @param context Function context providing runtime information
   */
  def open(context: FunctionContext): Unit = {}

  /**
   * Cleanup method called when function is closed
   */
  def close(): Unit = {}

  /**
   * Indicates whether the function is deterministic
   * @return True if function always produces same output for same input
   */
  def isDeterministic: Boolean = true
}
```
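
The lifecycle hooks are typically used to set up per-task state once, before any rows are processed. Below is a minimal sketch of that pattern; the `timestamp.pattern` job parameter key is just an example name, not part of any API.

```scala
// Minimal sketch of the UserDefinedFunction lifecycle: build reusable state in
// open() from a job parameter ("timestamp.pattern" is an example key) and keep
// the default deterministic behaviour explicit.
class FormatTimestamp extends ScalarFunction {
  @transient private var formatter: java.time.format.DateTimeFormatter = _

  override def open(context: FunctionContext): Unit = {
    // Build the formatter once per task instead of once per record
    formatter = java.time.format.DateTimeFormatter.ofPattern(
      context.getJobParameter("timestamp.pattern", "yyyy-MM-dd HH:mm:ss"))
  }

  def eval(epochMillis: Long): String =
    formatter.format(
      java.time.Instant.ofEpochMilli(epochMillis).atZone(java.time.ZoneOffset.UTC))

  // Same input always yields the same output, so the default (true) is kept
  override def isDeterministic: Boolean = true
}
```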

### Scalar Functions

Functions that take one or more input values and return a single value.

```scala { .api }
/**
 * Base class for scalar functions (1-to-1 mapping)
 */
abstract class ScalarFunction extends UserDefinedFunction {
  /**
   * Creates a function call expression for this scalar function
   * @param params Parameters for the function call
   * @return Expression representing the function call
   */
  def apply(params: Expression*): Expression

  /**
   * Override to specify custom result type
   * @param signature Array of parameter classes
   * @return Type information for the result
   */
  def getResultType(signature: Array[Class[_]]): TypeInformation[_] = null

  /**
   * Override to specify custom parameter types
   * @param signature Array of parameter classes
   * @return Array of parameter type information
   */
  def getParameterTypes(signature: Array[Class[_]]): Array[TypeInformation[_]] = null
}
```

**Usage Examples:**

```scala
// Simple scalar function
class AddOne extends ScalarFunction {
  def eval(x: Int): Int = x + 1
}

// Scalar function with multiple parameters
class StringConcat extends ScalarFunction {
  def eval(a: String, b: String): String = a + b
}

// Scalar function with variable arguments
class ConcatWs extends ScalarFunction {
  def eval(separator: String, strings: String*): String = {
    strings.mkString(separator)
  }
}

// Complex scalar function with type override
class ParseJson extends ScalarFunction {
  def eval(jsonStr: String): Row = {
    // Parse the JSON with a JSON library of your choice (pseudo-code here)
    val parsed = JSON.parse(jsonStr)
    Row.of(parsed.get("id"), parsed.get("name"))
  }

  override def getResultType(signature: Array[Class[_]]): TypeInformation[_] = {
    Types.ROW(Array("id", "name"), Array(Types.LONG, Types.STRING))
  }
}

// Register the functions for SQL and keep instances for inline Table API calls
val addOne = new AddOne()
val concat = new StringConcat()
tEnv.registerFunction("addOne", addOne)
tEnv.registerFunction("concat", concat)

// Table API: call the function instances directly
val result = table.select('id, addOne('age), concat('firstName, 'lastName))
// SQL: refer to the registered names
val sqlResult = tEnv.sqlQuery("SELECT id, addOne(age), concat(firstName, lastName) FROM Users")
```

### Table Functions

Functions that take one or more input values and return multiple rows (1-to-N mapping).

```scala { .api }
/**
 * Base class for table functions (1-to-N mapping)
 * @tparam T Type of output rows
 */
abstract class TableFunction[T] extends UserDefinedFunction {
  /**
   * Collects an output row
   * @param result Output row to emit
   */
  protected def collect(result: T): Unit

  /**
   * Override to specify custom result type
   * @param signature Array of parameter classes
   * @return Type information for the result
   */
  def getResultType(signature: Array[Class[_]]): TypeInformation[T] = null

  /**
   * Override to specify custom parameter types
   * @param signature Array of parameter classes
   * @return Array of parameter type information
   */
  def getParameterTypes(signature: Array[Class[_]]): Array[TypeInformation[_]] = null
}
```

**Usage Examples:**

```scala
// Split string into multiple rows
class SplitFunction extends TableFunction[String] {
  def eval(str: String, separator: String): Unit = {
    str.split(separator).foreach(collect)
  }
}

// Parse CSV row into structured data
class ParseCsv extends TableFunction[Row] {
  def eval(csvRow: String): Unit = {
    val fields = csvRow.split(",")
    // Box the numeric fields so they fit Row.of's Object varargs
    collect(Row.of(fields(0), Int.box(fields(1).toInt), Double.box(fields(2).toDouble)))
  }

  override def getResultType(signature: Array[Class[_]]): TypeInformation[Row] = {
    Types.ROW(Array("name", "age", "salary"), Array(Types.STRING, Types.INT, Types.DOUBLE))
  }
}

// Generate number sequence
class Range extends TableFunction[Long] {
  def eval(start: Long, end: Long): Unit = {
    (start until end).foreach(collect)
  }
}

// Register the functions for SQL and keep an instance for inline Table API calls
val split = new SplitFunction()
tEnv.registerFunction("split", split)
tEnv.registerFunction("range", new Range())

// Table API: apply the table function to columns and join against its output
val result = table
  .join(split('tags, ",") as 'word)
  .select('id, 'tags, 'word)

// SQL: cross join via LATERAL TABLE
val sqlResult = tEnv.sqlQuery("""
  SELECT u.id, u.tags, t.word
  FROM Users u, LATERAL TABLE(split(u.tags, ',')) as t(word)
""")
```

### Aggregate Functions

Functions that take multiple input values and return a single aggregated result (N-to-1 mapping).

```scala { .api }
/**
 * Base class for aggregate functions (N-to-1 mapping)
 * @tparam T Type of the final result
 * @tparam ACC Type of the accumulator
 */
abstract class AggregateFunction[T, ACC] extends UserDefinedFunction {
  /**
   * Creates and initializes the accumulator
   * @return New accumulator instance
   */
  def createAccumulator(): ACC

  /**
   * Processes an input value and updates the accumulator
   * @param accumulator Current accumulator
   * @param input Input values (method should be overloaded for different arities)
   */
  def accumulate(accumulator: ACC, input: Any*): Unit

  /**
   * Extracts the final result from the accumulator
   * @param accumulator Final accumulator state
   * @return Aggregated result
   */
  def getValue(accumulator: ACC): T

  /**
   * Retracts an input value from the accumulator (for streaming)
   * @param accumulator Current accumulator
   * @param input Input values to retract
   */
  def retract(accumulator: ACC, input: Any*): Unit = {
    throw new UnsupportedOperationException("retract method is not implemented")
  }

  /**
   * Merges multiple accumulators into one (for distributed aggregation)
   * @param accumulator Target accumulator
   * @param otherAccumulators Other accumulators to merge
   */
  def merge(accumulator: ACC, otherAccumulators: java.lang.Iterable[ACC]): Unit = {
    throw new UnsupportedOperationException("merge method is not implemented")
  }

  /**
   * Override to specify custom result type
   * @param signature Array of parameter classes
   * @return Type information for the result
   */
  def getResultType(signature: Array[Class[_]]): TypeInformation[T] = null

  /**
   * Override to specify custom accumulator type
   * @return Type information for the accumulator
   */
  def getAccumulatorType: TypeInformation[ACC] = null
}
```

**Usage Examples:**

```scala
// Simple sum aggregate function (accumulators must be mutable, so a small
// wrapper class is used instead of a bare Long)
class SumAccumulator {
  var sum: Long = 0L
}

class SumFunction extends AggregateFunction[Long, SumAccumulator] {
  override def createAccumulator(): SumAccumulator = new SumAccumulator

  def accumulate(acc: SumAccumulator, value: Long): Unit = {
    acc.sum += value
  }

  override def getValue(accumulator: SumAccumulator): Long = accumulator.sum
}

// Weighted average with a complex accumulator (fields are vars so the
// accumulator can be updated in place)
case class WeightedAvgAccumulator(var sum: Double, var count: Long)

class WeightedAvg extends AggregateFunction[Double, WeightedAvgAccumulator] {
  override def createAccumulator(): WeightedAvgAccumulator = {
    WeightedAvgAccumulator(0.0, 0L)
  }

  def accumulate(acc: WeightedAvgAccumulator, value: Double, weight: Long): Unit = {
    acc.sum += value * weight
    acc.count += weight
  }

  override def getValue(accumulator: WeightedAvgAccumulator): Double = {
    if (accumulator.count == 0) 0.0 else accumulator.sum / accumulator.count
  }

  def retract(acc: WeightedAvgAccumulator, value: Double, weight: Long): Unit = {
    acc.sum -= value * weight
    acc.count -= weight
  }

  def merge(acc: WeightedAvgAccumulator, otherAccs: java.lang.Iterable[WeightedAvgAccumulator]): Unit = {
    val it = otherAccs.iterator()
    while (it.hasNext) {
      val other = it.next()
      acc.sum += other.sum
      acc.count += other.count
    }
  }
}

// String concatenation aggregate
class StringAgg extends AggregateFunction[String, StringBuilder] {
  override def createAccumulator(): StringBuilder = new StringBuilder()

  def accumulate(acc: StringBuilder, value: String, separator: String): Unit = {
    if (acc.nonEmpty) acc.append(separator)
    acc.append(value)
  }

  override def getValue(accumulator: StringBuilder): String = accumulator.toString
}

// Register the functions for SQL and keep instances for inline Table API calls
val weightedAvg = new WeightedAvg()
val stringAgg = new StringAgg()
tEnv.registerFunction("weightedAvg", weightedAvg)
tEnv.registerFunction("stringAgg", stringAgg)

val result = table
  .groupBy('department)
  .select('department, weightedAvg('salary, 'years), stringAgg('name, ", "))

val sqlResult = tEnv.sqlQuery("""
  SELECT department,
         weightedAvg(salary, years) as avgSalary,
         stringAgg(name, ', ') as employees
  FROM Employees
  GROUP BY department
""")
```

### Function Context

Runtime context providing access to metrics, cached files, and other runtime information.

```scala { .api }
/**
 * Context interface providing runtime information to functions
 */
trait FunctionContext {
  /**
   * Gets the metric group for registering custom metrics
   * @return Metric group for the function
   */
  def getMetricGroup: MetricGroup

  /**
   * Gets a cached file by name (distributed cache)
   * @param name Name of the cached file
   * @return File handle to the cached file
   */
  def getCachedFile(name: String): java.io.File

  /**
   * Gets the job parameter value
   * @param key Parameter key
   * @param defaultValue Default value if key not found
   * @return Parameter value
   */
  def getJobParameter(key: String, defaultValue: String): String
}
```

**Usage Examples:**

```scala
class MetricsFunction extends ScalarFunction {
  private var counter: Counter = _

  override def open(context: FunctionContext): Unit = {
    super.open(context)
    // Register custom metrics
    counter = context.getMetricGroup.counter("custom_function_calls")
  }

  def eval(value: String): String = {
    counter.inc()
    value.toUpperCase
  }
}

class ConfigurableFunction extends ScalarFunction {
  private var multiplier: Double = _

  override def open(context: FunctionContext): Unit = {
    super.open(context)
    // Read configuration from job parameters
    multiplier = context.getJobParameter("multiplier", "1.0").toDouble
  }

  def eval(value: Double): Double = value * multiplier
}
```
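
The distributed cache hook (`getCachedFile`) is not exercised above. A minimal sketch follows; it assumes a file has been registered under the name `stopwords` with the execution environment's distributed cache (e.g. via `registerCachedFile`), which is not shown here.

```scala
// Sketch only: load a distributed-cache file once per task in open().
// Assumes a file was registered under the name "stopwords" on the environment.
class FilterStopwords extends ScalarFunction {
  private var stopwords: Set[String] = Set.empty

  override def open(context: FunctionContext): Unit = {
    val file = context.getCachedFile("stopwords")
    val source = scala.io.Source.fromFile(file)
    try stopwords = source.getLines().map(_.trim.toLowerCase).toSet
    finally source.close()
  }

  // Returns true for words that are not in the cached stopword list
  def eval(word: String): Boolean = !stopwords.contains(word.toLowerCase)
}
```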

### Advanced Function Features

Additional features for complex function implementations.

```scala { .api }
// Generic function interface for type-safe implementations
trait TypeInference {
  def inferTypes(callContext: CallContext): TypeInference.Result
}

// Function that can be used in both scalar and table contexts
abstract class PolymorphicFunction extends UserDefinedFunction {
  def eval(args: Any*): Any
  def evalTable(args: Any*): java.lang.Iterable[_]
}
```

**Usage Examples:**

```scala
// Function with custom type inference
class FlexibleParseFunction extends ScalarFunction with TypeInference {
  def eval(input: String, format: String): Any = {
    format match {
      case "int" => input.toInt
      case "double" => input.toDouble
      case "boolean" => input.toBoolean
      case _ => input
    }
  }

  override def inferTypes(callContext: CallContext): TypeInference.Result = {
    // Custom type inference logic based on the literal value of the format argument
    val formatLiteral = callContext.getArgumentValue(1, classOf[String])
    val resultType = formatLiteral.orElse("string") match {
      case "int" => Types.INT
      case "double" => Types.DOUBLE
      case "boolean" => Types.BOOLEAN
      case _ => Types.STRING
    }
    TypeInference.Result.success(resultType)
  }
}
```

## Function Registration and Usage Patterns

```scala { .api }
// Registration methods on TableEnvironment
def registerFunction(name: String, function: ScalarFunction): Unit
def registerFunction(name: String, function: TableFunction[_]): Unit
def registerFunction(name: String, function: AggregateFunction[_, _]): Unit

// Usage in Table API (Scala)
table.select('field, myScalarFunction('input))
table.join(myTableFunction('field) as ('output))
table.groupBy('key).select('key, myAggregateFunction('value))

// Usage in SQL
tEnv.sqlQuery("SELECT field, myScalarFunction(input) FROM MyTable")
tEnv.sqlQuery("SELECT * FROM MyTable, LATERAL TABLE(myTableFunction(field)) as t(output)")
tEnv.sqlQuery("SELECT key, myAggregateFunction(value) FROM MyTable GROUP BY key")
```
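
Putting the pieces together, the sketch below shows one possible end-to-end flow. It assumes the classic flink-table Scala API (`org.apache.flink.table.api.scala`, `TableEnvironment.getTableEnvironment`, `registerDataStream`) and reuses the `SplitFunction` defined earlier; adapt the environment setup to the Flink version actually in use.

```scala
// End-to-end sketch under the assumptions stated above.
import org.apache.flink.streaming.api.scala._
import org.apache.flink.table.api.TableEnvironment
import org.apache.flink.table.api.scala._

val env = StreamExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env)

// Register an input stream as the "Users" table with columns id and tags
val users = env.fromElements((1L, "flink,table"), (2L, "sql,udf"))
tEnv.registerDataStream("Users", users, 'id, 'tags)

// Register the UDF under a name so SQL can call it
tEnv.registerFunction("split", new SplitFunction())

val words = tEnv.sqlQuery(
  "SELECT u.id, t.word FROM Users u, LATERAL TABLE(split(u.tags, ',')) AS t(word)")

words.toAppendStream[(Long, String)].print()
env.execute("udf-example")
```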

## Types

```scala { .api }
abstract class UserDefinedFunction
abstract class ScalarFunction extends UserDefinedFunction
abstract class TableFunction[T] extends UserDefinedFunction
abstract class AggregateFunction[T, ACC] extends UserDefinedFunction

trait FunctionContext
trait MetricGroup
case class Counter()

// Exception types
class ValidationException(message: String) extends RuntimeException(message)
```
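
`ValidationException` signals an invalid function definition or call. Purely as an illustration, a function can also raise it from its own configuration checks; the sketch below does so in `open()` (the `precision` parameter name is made up for this example).

```scala
// Illustrative only: reject an invalid job parameter with ValidationException.
class RoundTo extends ScalarFunction {
  private var precision: Int = 0

  override def open(context: FunctionContext): Unit = {
    precision = context.getJobParameter("precision", "2").toInt
    if (precision < 0) {
      throw new ValidationException(s"precision must be non-negative, got $precision")
    }
  }

  // Round the value to the configured number of decimal places
  def eval(value: Double): Double =
    BigDecimal(value).setScale(precision, BigDecimal.RoundingMode.HALF_UP).toDouble
}
```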