Stream partitioning controls how data is distributed across parallel operators in Flink streaming applications. Proper partitioning strategies are crucial for load balancing, performance, and ensuring correct results in parallel processing.
class DataStream[T] {
  def keyBy[K: TypeInformation](fun: T => K): KeyedStream[T, K]
  def keyBy(fields: Int*): KeyedStream[T, _]
  def keyBy(firstField: String, otherFields: String*): KeyedStream[T, _]
}

Partition elements by key for stateful processing:
import org.apache.flink.streaming.api.scala._
val env = StreamExecutionEnvironment.getExecutionEnvironment
case class User(id: String, name: String, age: Int)
val users = env.fromElements(
User("1", "Alice", 25),
User("2", "Bob", 30),
User("1", "Alice", 26)
)
// Key by a key-selector function
val keyedByUserId = users.keyBy(_.id)
// Key by field position (for tuples)
val salesData = env.fromElements(("ProductA", 100), ("ProductB", 200), ("ProductA", 150))
val keyedBySalesPosition = salesData.keyBy(0) // Key by product name
// Key by field name
val keyedByUserField = users.keyBy("id")
// Key by multiple fields
val keyedByMultipleFields = users.keyBy("id", "age")
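A keyed stream enables per-key state and built-in aggregations. As a minimal sketch, the salesData stream above can be summed per product with sum(1), which maintains a running total for each key:

// Per-key running aggregation on the keyed sales stream
val salesPerProduct = salesData
  .keyBy(_._1) // key by product name using a key-selector function
  .sum(1)      // running sum of the sales amount for each product
salesPerProduct.print() // ProductA's running total reaches 250 once both of its sales are seen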
class DataStream[T] {
  def broadcast: DataStream[T]
}

Send each element to every parallel instance of the downstream operator:
val env = StreamExecutionEnvironment.getExecutionEnvironment
val configData = env.fromElements("config1", "config2", "config3")
// Broadcast configuration to all parallel instances
val broadcastedConfig = configData.broadcast
broadcastedConfig
  .map(config => s"Processing with config: $config")
  .setParallelism(4) // All 4 parallel instances receive all elements
  .print()
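To make the fan-out visible, a subtask-aware map can tag each element with the parallel instance that processed it. A minimal sketch using RichMapFunction; the class name TagWithSubtask is illustrative:

import org.apache.flink.api.common.functions.RichMapFunction

// Tags each broadcast element with the index of the subtask that received it
class TagWithSubtask extends RichMapFunction[String, String] {
  override def map(config: String): String =
    s"subtask ${getRuntimeContext.getIndexOfThisSubtask} received: $config"
}

broadcastedConfig
  .map(new TagWithSubtask)
  .setParallelism(4) // each of the 4 subtasks should print all three config elements
  .print()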
class DataStream[T] {
  def global: DataStream[T]
}

Send all elements to the first downstream operator instance:
val env = StreamExecutionEnvironment.getExecutionEnvironment
val stream = env.fromElements(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
// Send all elements to the first downstream subtask; only one instance does the work
val globalStream = stream.global
.map(x => s"Processed on single instance: $x")
.print()class DataStream[T] {
def shuffle: DataStream[T]
}Randomly distribute elements across downstream operators:
val env = StreamExecutionEnvironment.getExecutionEnvironment
val stream = env.fromElements(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
// Random distribution for load balancing
val shuffledStream = stream.shuffle
.map(x => s"Randomly assigned: $x")
.setParallelism(4)
.print()class DataStream[T] {
def rebalance: DataStream[T]
}Distribute elements evenly across downstream operators using round-robin:
val env = StreamExecutionEnvironment.getExecutionEnvironment
val stream = env.fromElements(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
// Even distribution using round-robin
val rebalancedStream = stream.rebalance
.map(x => s"Evenly distributed: $x")
.setParallelism(4)
.print()class DataStream[T] {
def rescale: DataStream[T]
}Distribute elements to a subset of downstream operators:
val env = StreamExecutionEnvironment.getExecutionEnvironment
val stream = env.fromElements(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
  .map(x => x).setParallelism(2) // fromElements is a non-parallel source, so widen to parallelism 2 via a pass-through map
// Rescale to subset of downstream operators
val rescaledStream = stream.rescale
.map(x => s"Rescaled: $x")
.setParallelism(6) // Each upstream instance sends to 3 downstream instances
.print()class DataStream[T] {
def forward: DataStream[T]
}Forward elements to local downstream operators (same machine):
val env = StreamExecutionEnvironment.getExecutionEnvironment
val stream = env.fromElements(1, 2, 3, 4, 5)
// Forward to the corresponding downstream subtask (no network shuffle)
val forwardedStream = stream.forward
.map(x => s"Locally forwarded: $x")
.print()class DataStream[T] {
def partitionCustom[K: TypeInformation](partitioner: Partitioner[K], field: Int): DataStream[T]
def partitionCustom[K: TypeInformation](partitioner: Partitioner[K], field: String): DataStream[T]
def partitionCustom[K: TypeInformation](partitioner: Partitioner[K], fun: T => K): DataStream[T]
}Implement custom partitioning logic:
import org.apache.flink.api.common.functions.Partitioner
// Custom partitioner for even/odd distribution (assumes at least 2 downstream partitions)
class EvenOddPartitioner extends Partitioner[Int] {
  override def partition(key: Int, numPartitions: Int): Int = {
    if (key % 2 == 0) 0 else 1 // even numbers to partition 0, odd to partition 1
  }
}
// Range-based partitioner
class RangePartitioner extends Partitioner[Int] {
  override def partition(key: Int, numPartitions: Int): Int = {
    val range = Math.max(1, 100 / numPartitions) // assumes keys roughly in [0, 100); max(1, ...) avoids division by zero
    Math.min(key / range, numPartitions - 1)
  }
}
val env = StreamExecutionEnvironment.getExecutionEnvironment
val numbers = env.fromElements(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 50, 75, 99)
// Custom partition by function
val evenOddPartitioned = numbers
  .partitionCustom(new EvenOddPartitioner, identity)
  .map(x => s"EvenOdd partitioned: $x")
  .print()
// Custom partition by field (for tuples/case classes)
case class ValueWithCategory(value: Int, category: String)
val categorizedData = env.fromElements(
  ValueWithCategory(10, "A"),
  ValueWithCategory(20, "B"),
  ValueWithCategory(30, "A")
)
class CategoryPartitioner extends Partitioner[String] {
  override def partition(key: String, numPartitions: Int): Int = {
    Math.abs(key.hashCode) % numPartitions // abs guards against negative hash codes
  }
}
val categoryPartitioned = categorizedData
  .partitionCustom(new CategoryPartitioner, "category")
  .print()

val env = StreamExecutionEnvironment.getExecutionEnvironment
// Avoid skewed data distribution
case class Event(userId: String, data: String)
val events = env.fromElements(
Event("popular_user", "data1"), // This user might cause skew
Event("user2", "data2"),
Event("popular_user", "data3")
)
// Instead of simple keyBy which might cause skew:
// val skewed = events.keyBy(_.userId)
// Use custom partitioning for better distribution:
class SkewAwarePartitioner extends Partitioner[String] {
  override def partition(key: String, numPartitions: Int): Int = {
    if (key == "popular_user") {
      // Spread the hot key across partitions by salting it with the current time
      Math.abs((key + System.currentTimeMillis()).hashCode) % numPartitions
    } else {
      Math.abs(key.hashCode) % numPartitions // abs guards against negative hash codes
    }
  }
}
val balancedEvents = events
  .partitionCustom(new SkewAwarePartitioner, _.userId)
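Note that partitionCustom returns a plain DataStream, so keyed state is not available downstream, and the time-based salt above is non-deterministic. An alternative for skewed keyBy aggregations is two-stage aggregation over a salted key; the sketch below assumes an illustrative 4-way salt, 10-second processing-time windows, and a count-per-user aggregation:

import org.apache.flink.streaming.api.windowing.assigners.TumblingProcessingTimeWindows
import org.apache.flink.streaming.api.windowing.time.Time
import scala.util.Random

// Stage 1: pre-aggregate per (userId, salt) so the hot key is spread over several subtasks
val preAggregated = events
  .map(e => ((e.userId, Random.nextInt(4)), 1)) // 4-way salt: illustrative choice
  .keyBy(_._1)
  .window(TumblingProcessingTimeWindows.of(Time.seconds(10)))
  .sum(1)

// Stage 2: drop the salt and combine the per-salt partial counts for each user
val perUserCounts = preAggregated
  .map { case ((userId, _), count) => (userId, count) }
  .keyBy(_._1)
  .window(TumblingProcessingTimeWindows.of(Time.seconds(10)))
  .sum(1)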
// Use forward for chained operators with the same parallelism that should avoid a network shuffle
val localProcessing = stream
  .map(_.toUpperCase)   // CPU-intensive operation (assumes a DataStream[String] here)
  .forward              // keep each element on the same subtask
  .filter(_.length > 5) // another local operation
// Use rebalance() after operations that might cause imbalance
val afterFilter = stream
.filter(_.contains("important")) // Might reduce volume significantly
.rebalance // Redistribute remaining elements evenly
.map(complexProcessing)// Partitioning affects state access
val keyedStream = events.keyBy(_.userId) // State is partitioned by userId
// Changing partitioning loses access to keyed state
val repartitioned = keyedStream
  .map(processEvent) // can access keyed state here
  .rebalance         // repartitioning: no keyed state access from here on
  .map(postProcess)  // no keyed state access here
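To actually use keyed state, the stateful operator has to run on the KeyedStream itself. A minimal sketch, assuming a per-user event counter built with a RichMapFunction and ValueState; the class name CountPerUser is illustrative:

import org.apache.flink.api.common.functions.RichMapFunction
import org.apache.flink.api.common.state.{ValueState, ValueStateDescriptor}
import org.apache.flink.configuration.Configuration

// Counts events per key; keyed state like this is only available directly on a KeyedStream
class CountPerUser extends RichMapFunction[Event, (String, Long)] {
  // java.lang.Long is used so an uninitialized state can be detected as null
  private var count: ValueState[java.lang.Long] = _

  override def open(parameters: Configuration): Unit = {
    count = getRuntimeContext.getState(
      new ValueStateDescriptor[java.lang.Long]("event-count", classOf[java.lang.Long]))
  }

  override def map(event: Event): (String, Long) = {
    val current = if (count.value() == null) 0L else count.value().longValue()
    val next = current + 1
    count.update(next)
    (event.userId, next)
  }
}

val countsPerUser = events
  .keyBy(_.userId)       // the state above is scoped to this key
  .map(new CountPerUser)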
import org.apache.flink.streaming.api.scala._
import org.apache.flink.api.common.functions.Partitioner
case class LogEntry(
  timestamp: Long,
  level: String,
  service: String,
  message: String,
  userId: Option[String]
)
case class ProcessingStats(
  service: String,
  errorCount: Int,
  warningCount: Int,
  totalCount: Int
)
object MultiStageProcessingPipeline {
  // Custom partitioner for log levels (assumes at least 4 downstream partitions)
  class LogLevelPartitioner extends Partitioner[String] {
    override def partition(key: String, numPartitions: Int): Int = {
      key match {
        case "ERROR" => 0
        case "WARN"  => 1
        case "INFO"  => 2
        case _       => 3
      }
    }
  }

  // Load-balancing partitioner for services
  class ServicePartitioner extends Partitioner[String] {
    override def partition(key: String, numPartitions: Int): Int = {
      // Use hash for even distribution; abs guards against negative hash codes
      Math.abs(key.hashCode) % numPartitions
    }
  }
  def main(args: Array[String]): Unit = {
    val env = StreamExecutionEnvironment.getExecutionEnvironment
    env.setParallelism(8)

    // Sample log data
    val logs = env.fromElements(
      LogEntry(1000, "ERROR", "service-a", "Database connection failed", Some("user1")),
      LogEntry(1001, "INFO", "service-b", "Request processed", Some("user2")),
      LogEntry(1002, "WARN", "service-a", "High memory usage", None),
      LogEntry(1003, "ERROR", "service-c", "Authentication failed", Some("user3"))
    )
    // Stage 1: Broadcast configuration needed by every parallel instance
    val config = env.fromElements("config-item-1", "config-item-2")
      .broadcast

    // Stage 2: Partition by log level for specialized processing
    val logsByLevel = logs
      .partitionCustom(new LogLevelPartitioner, _.level)
      .map { log =>
        // Derive a numeric priority from the log level
        val priority = log.level match {
          case "ERROR" => 1
          case "WARN"  => 2
          case "INFO"  => 3
          case _       => 4
        }
        (log, priority)
      }
    // Stage 3: Key by service for stateful aggregation
    val serviceStats = logsByLevel
      .map(_._1)        // extract the log entry
      .keyBy(_.service) // key by service for stateful processing
      .map { log =>
        // Build per-record stats; a real job would aggregate these in keyed state or a window
        ProcessingStats(
          log.service,
          if (log.level == "ERROR") 1 else 0,
          if (log.level == "WARN") 1 else 0,
          1)
      }
    // Stage 4: Rebalance for final processing
    val finalResults = serviceStats
      .rebalance // even distribution for final processing
      .map(stats => s"Service ${stats.service}: ${stats.totalCount} total, ${stats.errorCount} errors")
    // Stage 5: Global aggregation (all data on a single instance)
    // Note: reduce is only defined on KeyedStream, so a constant key is used here;
    // it routes every element to one subtask, the same effect global would have.
    val globalSummary = logs
      .map(_ => 1)
      .keyBy(_ => "all")
      .reduce(_ + _)
      .map(count => s"Total log entries processed: $count")
    // Stage 6: User-specific processing with custom partitioning
    val userLogs = logs
      .filter(_.userId.isDefined)
      .partitionCustom(new ServicePartitioner, log => log.userId.get)
      .map(log => (log.userId.get, 1))
      .keyBy(_._1) // keyBy re-partitions by key hash so that reduce can use keyed state
      .reduce((a, b) => (a._1, a._2 + b._2))
    // Print results
    finalResults.print("Service Stats")
    globalSummary.print("Global Summary")
    userLogs.print("User Activity")

    env.execute("Multi-Stage Processing Pipeline")
  }
}