Function interfaces for implementing custom transformation logic. These abstract base classes define the contracts for various processing patterns, enabling users to implement custom business logic while leveraging Flink's distributed execution capabilities.
Abstract base class for all user-defined functions providing common infrastructure.
class Function:
    """
    Abstract base class for all user-defined functions.

    Provides common functionality for function execution, including
    configuration, lifecycle management, and error handling.
    """

    def _run(self):
        """Abstract method implemented by subclasses for function execution."""

    def _configure(self, input_file, output_file, mmap_size, port, env, info, subtask_index):
        """Sets up the function execution context with runtime parameters."""

    def _close(self):
        """Cleanup method called after function execution."""

Transforms each input element to exactly one output element.
class MapFunction(Function):
    def map(self, value):
        """
        Transforms a single input element into a single output element.

        Parameters:
            value: Input element of any type

        Returns:
            Transformed element (can be a different type)
        """

    def collect(self, value):
        """
        Internal method for collecting transformed values.

        Parameters:
            value: Input value to transform and collect
        """

Transforms each input element to zero or more output elements.
class FlatMapFunction(Function):
    def flat_map(self, value, collector):
        """
        Transforms a single input to zero or more outputs.

        Parameters:
            value: Input element
            collector: Output collector - call collector.collect(output) for each result
        """

    def collect(self, value):
        """
        Internal method for collecting values using the flat_map transformation.

        Parameters:
            value: Input value to transform via flat_map
        """

Determines whether elements should be included in the result.
class FilterFunction(Function):
    def filter(self, value):
        """
        Predicate function to include or exclude elements.

        Parameters:
            value: Input element to test

        Returns:
            bool: True to include the element, False to exclude it
        """

Processes entire partitions of data rather than individual elements.
class MapPartitionFunction(Function):
    def map_partition(self, iterator, collector):
        """
        Processes an entire partition of elements.

        Allows for more efficient processing when setup/cleanup costs are high
        or when processing requires access to multiple elements.

        Parameters:
            iterator: Iterator over all elements in the partition
            collector: Output collector - call collector.collect(output) for each result
        """

Combines two elements into one of the same type.
class ReduceFunction(Function):
    def reduce(self, value1, value2):
        """
        Combines two elements into one of the same type.

        This function is applied associatively to reduce a set of elements
        down to a single element.

        Parameters:
            value1: First element to combine
            value2: Second element to combine

        Returns:
            Combined element of the same type as the inputs
        """

    def combine(self, value1, value2):
        """
        Optional combiner function for partial aggregation.

        Used for optimization - should have the same semantics as reduce().

        Parameters:
            value1: First element to combine
            value2: Second element to combine

        Returns:
            Combined element of the same type as the inputs
        """

Processes groups of elements with the same key.
class GroupReduceFunction(Function):
    def reduce(self, iterator, collector):
        """
        Processes a group of elements, emitting zero or more results.

        Called once per group (or once for the entire DataSet if not grouped).
        Can iterate over all elements in the group and emit any number of results.

        Parameters:
            iterator: Iterator over all elements in the group
            collector: Output collector - call collector.collect(output) for each result
        """

    def combine(self, iterator, collector):
        """
        Optional combiner for partial aggregation within partitions.

        Used for optimization - should produce partial results that can be
        further reduced by the reduce() method.

        Parameters:
            iterator: Iterator over elements in the partition
            collector: Output collector for partial results
        """
Combines matching elements from two DataSets.

class JoinFunction(Function):
    def join(self, value1, value2):
        """
        Combines matching elements from two DataSets.

        Called for each pair of elements with matching keys.

        Parameters:
            value1: Element from first DataSet
            value2: Element from second DataSet

        Returns:
            Combined element (can be any type)
        """

Processes groups from two DataSets with matching keys.
class CoGroupFunction(Function):
    def co_group(self, iterator1, iterator2, collector):
        """
        Processes groups from two DataSets with the same key.

        Called once per key, even if one or both groups are empty.
        Useful for implementing outer joins and complex multi-input operations.

        Parameters:
            iterator1: Iterator over elements from the first DataSet with this key
            iterator2: Iterator over elements from the second DataSet with this key
            collector: Output collector for results
        """

Combines elements in a cross product (Cartesian product).
class CrossFunction(Function):
    def cross(self, value1, value2):
        """
        Combines elements in a cross product.

        Called for every combination of elements from two DataSets.

        Parameters:
            value1: Element from first DataSet
            value2: Element from second DataSet

        Returns:
            Combined element (can be any type)
        """
Extracts keys from elements for grouping and joining operations.

class KeySelectorFunction(Function):
    def get_key(self, value):
        """
        Extracts the key from an element for grouping/joining.

        Parameters:
            value: Input element

        Returns:
            Key value used for grouping/joining
        """
Provides runtime information and services to functions.

class RuntimeContext:
    def get_broadcast_variable(self, name):
        """
        Accesses a broadcast variable by name.

        Parameters:
            name (str): Name of the broadcast variable

        Returns:
            Broadcast variable value
        """

    def get_index_of_this_subtask(self):
        """
        Gets the index of the current parallel subtask.

        Returns:
            int: Zero-based subtask index
        """
"""Aggregation for summing numeric values."""
class Min:
"""Aggregation for finding minimum values."""
class Max:
"""Aggregation for finding maximum values."""class AggregationFunction:
"""Combines multiple aggregations on different fields."""
def add_aggregation(self, aggregation, field):
"""
Adds additional aggregation to different field.
Parameters:
aggregation (Aggregation): Aggregation type (Sum, Min, Max)
field (int): Field index to aggregate
Returns:
AggregationFunction: Self for method chaining
"""from flink.functions.MapFunction import MapFunction
from flink.functions.MapFunction import MapFunction

class DoubleValue(MapFunction):
    def map(self, value):
        return value * 2

# Usage
data = env.from_elements(1, 2, 3, 4, 5)
doubled = data.map(DoubleValue())

# Or using a lambda
doubled = data.map(lambda x: x * 2)
from flink.functions.FlatMapFunction import FlatMapFunction

class Tokenizer(FlatMapFunction):
    def flat_map(self, line, collector):
        words = line.lower().split()
        for word in words:
            collector.collect(word)

# Usage
text = env.from_elements("hello world", "flink python", "data processing")
words = text.flat_map(Tokenizer())
from flink.functions.FilterFunction import FilterFunction

class EvenNumberFilter(FilterFunction):
    def filter(self, value):
        return value % 2 == 0

# Usage
numbers = env.from_elements(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
evens = numbers.filter(EvenNumberFilter())
from flink.functions.GroupReduceFunction import GroupReduceFunction

class WordCounter(GroupReduceFunction):
    def reduce(self, iterator, collector):
        # Each element is a (word, 1) pair; sum the counts for this group
        word = None
        count = 0
        for element in iterator:
            if word is None:
                word = element[0]
            count += element[1]
        collector.collect((word, count))

# Usage: wrap each word in a (word, 1) pair so it can be grouped by field 0
words = text.flat_map(Tokenizer()).map(lambda w: (w, 1))
word_counts = words.group_by(0).reduce_group(WordCounter())
from flink.functions.ReduceFunction import ReduceFunction

class SumReduce(ReduceFunction):
    def reduce(self, value1, value2):
        return value1 + value2

    def combine(self, value1, value2):
        # Same implementation for this simple case
        return value1 + value2

# Usage
numbers = env.from_elements(1, 2, 3, 4, 5)
total = numbers.reduce(SumReduce())
from flink.functions.MapPartitionFunction import MapPartitionFunction

class BatchProcessor(MapPartitionFunction):
    def map_partition(self, iterator, collector):
        # Set up expensive resources once per partition
        # (ExpensiveProcessor is a placeholder for user-defined setup)
        processor = ExpensiveProcessor()
        batch = []
        for element in iterator:
            batch.append(element)
            # Process in batches of 100
            if len(batch) >= 100:
                results = processor.process_batch(batch)
                for result in results:
                    collector.collect(result)
                batch = []
        # Process remaining elements
        if batch:
            results = processor.process_batch(batch)
            for result in results:
                collector.collect(result)
        # Cleanup
        processor.close()

# Usage
large_dataset = env.read_csv("large_file.csv", [str, int, float])
processed = large_dataset.map_partition(BatchProcessor())
from flink.functions.JoinFunction import JoinFunction

class CustomerOrderJoin(JoinFunction):
    def join(self, customer, order):
        return {
            'customer_id': customer[0],
            'customer_name': customer[1],
            'order_id': order[0],
            'order_amount': order[2]
        }

# Usage: match the customer id (field 0 of customers) against
# the customer id stored in each order (field 1 of orders)
customers = env.read_csv("customers.csv", [str, str])
orders = env.read_csv("orders.csv", [int, str, float])
result = customers.join(orders) \
    .where(0) \
    .equal_to(1) \
    .using(CustomerOrderJoin())
from flink.functions.CoGroupFunction import CoGroupFunction

class LeftOuterJoin(CoGroupFunction):
    def co_group(self, iterator1, iterator2, collector):
        left_items = list(iterator1)
        right_items = list(iterator2)
        if not left_items:
            return  # No items in the left dataset for this key
        if not right_items:
            # Left outer join - emit left items with a null right side
            for left_item in left_items:
                collector.collect((left_item, None))
        else:
            # Inner join - emit all combinations
            for left_item in left_items:
                for right_item in right_items:
                    collector.collect((left_item, right_item))

# Usage
result = dataset1.co_group(dataset2) \
    .where(0) \
    .equal_to(0) \
    .using(LeftOuterJoin())