KeyedStream represents a DataStream that has been partitioned by key, enabling stateful operations, per-key aggregations, and key-specific processing. Keyed state participates in Flink's checkpointing, which provides exactly-once state consistency.
class DataStream[T] {
def keyBy[K: TypeInformation](fun: T => K): KeyedStream[T, K]
def keyBy(fields: Int*): KeyedStream[T, _]
def keyBy(firstField: String, otherFields: String*): KeyedStream[T, _]
}

Partition streams by key:
import org.apache.flink.streaming.api.scala._
val env = StreamExecutionEnvironment.getExecutionEnvironment
// Key by function
case class Event(userId: String, eventType: String, timestamp: Long)
val events = env.fromElements(
Event("user1", "login", 1000),
Event("user2", "click", 1001),
Event("user1", "logout", 1002)
)
val keyedByUser = events.keyBy(_.userId)
// Key by field position (for tuples; deprecated in newer Flink releases in favor of key selectors)
val tupleStream = env.fromElements(("user1", 100), ("user2", 200), ("user1", 150))
val keyedByPosition = tupleStream.keyBy(0) // Key by first element
// Key by field name (for case classes; also deprecated in favor of key selectors)
val keyedByField = events.keyBy("userId")
// Key by multiple fields
val multiFieldKey = events.keyBy("userId", "eventType")

// Complex key selector
val keyedByComplex = events.keyBy(event => (event.userId, event.eventType))
// Key by computed value
val keyedByHour = events.keyBy(event => event.timestamp / 3600000) // Hour bucket
// Key by hash for load balancing
val keyedByHash = events.keyBy(event => math.abs(event.userId.hashCode % 10)) // Bucket into 10 groups; abs guards against negative hash codes
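
When a composite key is reused in several places, a small case class can be clearer than a tuple; case classes come with the structural equals/hashCode that keys require. A minimal sketch (UserEventKey is illustrative, not part of the API):

case class UserEventKey(userId: String, eventType: String)
val keyedByCaseClass = events.keyBy(e => UserEventKey(e.userId, e.eventType))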

class KeyedStream[T, K] {
def mapWithState[R: TypeInformation, S: TypeInformation](
fun: (T, Option[S]) => (R, Option[S])
): DataStream[R]
}

Stateful mapping operations:
case class Counter(count: Int)
val env = StreamExecutionEnvironment.getExecutionEnvironment
val keyedStream = env.fromElements(("a", 1), ("b", 2), ("a", 3), ("b", 4))
.keyBy(0)
// Running sum with state
val runningSums = keyedStream.mapWithState { (value, state) =>
val currentSum = state.getOrElse(0) + value._2
((value._1, currentSum), Some(currentSum))
}
// Stateful transformation with complex state
case class UserSession(sessionId: String, eventCount: Int, lastActivity: Long)
val sessionUpdates = events.keyBy(_.userId)
.mapWithState { (event, sessionState) =>
val currentTime = System.currentTimeMillis()
val session = sessionState match {
case Some(s) if currentTime - s.lastActivity < 300000 => // 5 minutes
s.copy(eventCount = s.eventCount + 1, lastActivity = currentTime)
case _ =>
UserSession(java.util.UUID.randomUUID().toString, 1, currentTime)
}
((event.userId, session), Some(session))
}
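
mapWithState is a convenience of the Scala API; the same running sum can be written with an explicit ValueState in a RichMapFunction, which is the portable form. A minimal sketch, reusing keyedStream from above:

import org.apache.flink.api.common.functions.RichMapFunction
import org.apache.flink.api.common.state.{ValueState, ValueStateDescriptor}
import org.apache.flink.configuration.Configuration
class RunningSum extends RichMapFunction[(String, Int), (String, Int)] {
  private var sum: ValueState[Int] = _
  override def open(parameters: Configuration): Unit = {
    sum = getRuntimeContext.getState(
      new ValueStateDescriptor[Int]("running-sum", classOf[Int])
    )
  }
  override def map(value: (String, Int)): (String, Int) = {
    val newSum = sum.value() + value._2 // unset Int state unboxes to 0 in Scala
    sum.update(newSum)
    (value._1, newSum)
  }
}
val explicitSums = keyedStream.map(new RunningSum)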

class KeyedStream[T, K] {
def filterWithState[S: TypeInformation](
fun: (T, Option[S]) => (Boolean, Option[S])
): DataStream[T]
}

Stateful filtering:
val env = StreamExecutionEnvironment.getExecutionEnvironment
val events = env.fromElements(1, 2, 3, 2, 4, 3, 5, 6, 1)
.keyBy(identity)
// Filter out duplicates using state (note: one state entry is kept per distinct key forever; see the state TTL section below for cleanup)
val deduplicatedEvents = events.filterWithState { (value, seen) =>
val alreadySeen = seen.contains(value)
(!alreadySeen, Some(value))
}
// Rate limiting with state
case class RateLimitState(count: Int, windowStart: Long)
val rateLimitedEvents = events.filterWithState { (value, state) =>
val currentTime = System.currentTimeMillis()
val windowSize = 60000 // 1 minute window
val currentState = state match {
case Some(s) if currentTime - s.windowStart < windowSize =>
s.copy(count = s.count + 1)
case _ =>
RateLimitState(1, currentTime)
}
val allowed = currentState.count <= 10 // Max 10 events per minute
(allowed, Some(currentState))
}

class KeyedStream[T, K] {
def flatMapWithState[R: TypeInformation, S: TypeInformation](
fun: (T, Option[S]) => (TraversableOnce[R], Option[S])
): DataStream[R]
}

Stateful one-to-many transformations:
case class BufferState(buffer: List[Int])
val env = StreamExecutionEnvironment.getExecutionEnvironment
val numbers = env.fromElements(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
.keyBy(_ % 3) // Group by remainder
// Batch processing with state
val batched = numbers.flatMapWithState { (value, state) =>
val currentBuffer = state.map(_.buffer).getOrElse(List.empty)
val newBuffer = currentBuffer :+ value
if (newBuffer.size >= 3) {
(List(newBuffer.sum), Some(BufferState(List.empty)))
} else {
(List.empty, Some(BufferState(newBuffer)))
}
}

Note that elements still sitting in a buffer when the stream ends are never emitted; a production job would flush remaining buffers, for example from a timer in a process function.

class KeyedStream[T, K] {
def reduce(fun: (T, T) => T): DataStream[T]
def reduce(reducer: ReduceFunction[T]): DataStream[T]
}

Incremental aggregations:
import org.apache.flink.api.common.functions.ReduceFunction
val env = StreamExecutionEnvironment.getExecutionEnvironment
// Reduce with lambda
val numbers = env.fromElements((1, 10), (2, 20), (1, 15), (2, 25))
.keyBy(0)
val sums = numbers.reduce((a, b) => (a._1, a._2 + b._2))
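
A keyed reduce emits an updated aggregate for every incoming element rather than one final result; for the input above, the emitted stream is (1, 10), (2, 20), (1, 25), (2, 45) (per-key order is guaranteed, interleaving across keys is not).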
// Reduce with ReduceFunction
case class Metrics(userId: String, totalClicks: Int, totalRevenue: Double)
class MetricsReducer extends ReduceFunction[Metrics] {
override def reduce(m1: Metrics, m2: Metrics): Metrics = {
Metrics(m1.userId, m1.totalClicks + m2.totalClicks, m1.totalRevenue + m2.totalRevenue)
}
}
val userMetrics = env.fromElements(
Metrics("user1", 5, 12.50),
Metrics("user2", 3, 8.75),
Metrics("user1", 2, 5.25)
)
val aggregatedMetrics = userMetrics
.keyBy(_.userId)
.reduce(new MetricsReducer)

class KeyedStream[T, K] {
def fold[R: TypeInformation](initialValue: R, folder: FoldFunction[T, R]): DataStream[R]
def fold[R: TypeInformation](initialValue: R)(fun: (R, T) => R): DataStream[R]
}

Fold with initial value (deprecated, use process functions instead):
val env = StreamExecutionEnvironment.getExecutionEnvironment
val words = env.fromElements("hello", "world", "hello", "flink")
.keyBy(identity)
// Fold to count occurrences (deprecated pattern)
val wordCounts = words.fold(0)((count, _) => count + 1)
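
The non-deprecated replacement keeps the running value in keyed state inside a process function. A sketch that also emits the word next to its count (WordCounter is illustrative):

import org.apache.flink.streaming.api.functions.KeyedProcessFunction
import org.apache.flink.api.common.state.{ValueState, ValueStateDescriptor}
import org.apache.flink.util.Collector
class WordCounter extends KeyedProcessFunction[String, String, (String, Int)] {
  private lazy val count: ValueState[Int] = getRuntimeContext.getState(
    new ValueStateDescriptor[Int]("count", classOf[Int])
  )
  override def processElement(
    word: String,
    ctx: KeyedProcessFunction[String, String, (String, Int)]#Context,
    out: Collector[(String, Int)]
  ): Unit = {
    val updated = count.value() + 1 // unset Int state unboxes to 0
    count.update(updated)
    out.collect((word, updated))
  }
}
val wordCountsViaProcess = words.process(new WordCounter)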

class KeyedStream[T, K] {
def sum(position: Int): DataStream[T]
def sum(field: String): DataStream[T]
def min(position: Int): DataStream[T]
def min(field: String): DataStream[T]
def max(position: Int): DataStream[T]
def max(field: String): DataStream[T]
def minBy(position: Int): DataStream[T]
def minBy(field: String): DataStream[T]
def maxBy(position: Int): DataStream[T]
def maxBy(field: String): DataStream[T]
}

Built-in aggregation functions:
val env = StreamExecutionEnvironment.getExecutionEnvironment
// Tuple aggregations
val sales = env.fromElements(
("product1", 100.0),
("product2", 150.0),
("product1", 75.0),
("product2", 200.0)
)
val salesByProduct = sales.keyBy(0)
val totalSales = salesByProduct.sum(1) // Sum of field 1 per key
val maxSale = salesByProduct.max(1) // Max of field 1; unlike maxBy, the other fields are not guaranteed to come from the same element
val bestSale = salesByProduct.maxBy(1) // The whole element that carries the maximum field 1
// Case class aggregations
case class Sale(product: String, amount: Double, quantity: Int)
val salesData = env.fromElements(
Sale("A", 100.0, 2),
Sale("B", 150.0, 3),
Sale("A", 75.0, 1)
)
val salesByProductName = salesData.keyBy(_.product)
val totalsByProduct = salesByProductName.sum("amount") // Sum by field name
val maxQuantity = salesByProductName.max("quantity") // Max quantity
val topSale = salesByProductName.maxBy("amount") // Best sale by amount

class KeyedStream[T, K] {
def process[R: TypeInformation](processFunction: KeyedProcessFunction[K, T, R]): DataStream[R]
}

Low-level stateful processing with timers:
import org.apache.flink.streaming.api.functions.KeyedProcessFunction
import org.apache.flink.api.common.state.{ValueState, ValueStateDescriptor}
import org.apache.flink.util.Collector
case class LoginEvent(userId: String, timestamp: Long)
case class LoginAlert(userId: String, count: Int, window: String)
class LoginMonitor extends KeyedProcessFunction[String, LoginEvent, LoginAlert] {
private var loginCount: ValueState[Int] = _
private var timer: ValueState[Long] = _
override def open(parameters: org.apache.flink.configuration.Configuration): Unit = {
loginCount = getRuntimeContext.getState(
new ValueStateDescriptor[Int]("login-count", classOf[Int])
)
timer = getRuntimeContext.getState(
new ValueStateDescriptor[Long]("timer", classOf[Long])
)
}
override def processElement(
event: LoginEvent,
ctx: KeyedProcessFunction[String, LoginEvent, LoginAlert]#Context,
out: Collector[LoginAlert]
): Unit = {
val currentCount = Option(loginCount.value()).getOrElse(0) + 1
loginCount.update(currentCount)
// Set timer for 1 minute window if not already set
if (timer.value() == 0) { // an unset Long state unboxes to 0 in Scala
val timerTime = event.timestamp + 60000 // 1 minute
ctx.timerService().registerEventTimeTimer(timerTime)
timer.update(timerTime)
}
// Alert if too many logins
if (currentCount >= 5) {
out.collect(LoginAlert(event.userId, currentCount, "suspicious"))
}
}
override def onTimer(
timestamp: Long,
ctx: KeyedProcessFunction[String, LoginEvent, LoginAlert]#OnTimerContext,
out: Collector[LoginAlert]
): Unit = {
val count = Option(loginCount.value()).getOrElse(0)
if (count > 0) {
out.collect(LoginAlert(ctx.getCurrentKey, count, "window-summary"))
}
// Clear state
loginCount.clear()
timer.clear()
}
}
val env = StreamExecutionEnvironment.getExecutionEnvironment
val loginEvents = env.fromElements(
LoginEvent("user1", 1000),
LoginEvent("user1", 1010),
LoginEvent("user1", 1020)
)
val alerts = loginEvents
.keyBy(_.userId)
.process(new LoginMonitor)
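
One caveat: the event-time timer above only fires as the watermark advances. The bounded fromElements source emits a final watermark that flushes pending timers at end of input, but an unbounded stream needs timestamps and watermarks assigned first. A sketch using the WatermarkStrategy API (the 5-second out-of-orderness bound is an arbitrary choice):

import java.time.Duration
import org.apache.flink.api.common.eventtime.{SerializableTimestampAssigner, WatermarkStrategy}
val withWatermarks = loginEvents.assignTimestampsAndWatermarks(
  WatermarkStrategy
    .forBoundedOutOfOrderness[LoginEvent](Duration.ofSeconds(5))
    .withTimestampAssigner(new SerializableTimestampAssigner[LoginEvent] {
      override def extractTimestamp(event: LoginEvent, recordTimestamp: Long): Long =
        event.timestamp
    })
)
val alertsWithWatermarks = withWatermarks.keyBy(_.userId).process(new LoginMonitor)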

class KeyedStream[T, K] {
def asQueryableState(queryableStateName: String): QueryableStateStream[K, T]
def asQueryableState(queryableStateName: String, stateDescriptor: ValueStateDescriptor[T]): QueryableStateStream[K, T]
def asQueryableState(queryableStateName: String, stateDescriptor: ReducingStateDescriptor[T]): QueryableStateStream[K, T]
}

Expose keyed state for external queries:
import org.apache.flink.api.common.state.ValueStateDescriptor
val env = StreamExecutionEnvironment.getExecutionEnvironment
// Simple queryable state
val userCounts = env.fromElements(("user1", 1), ("user2", 1), ("user1", 1))
.keyBy(0)
.sum(1)
.asQueryableState("user-counts")
// Custom state descriptor
val stateDescriptor = new ValueStateDescriptor[(String, Int)](
"user-counts",
classOf[(String, Int)]
)
val queryableUserCounts = env.fromElements(("user1", 1), ("user2", 1))
.keyBy(0)
.sum(1)
.asQueryableState("custom-user-counts", stateDescriptor)import org.apache.flink.api.common.state.StateTtlConfig

import org.apache.flink.api.common.state.StateTtlConfig
import org.apache.flink.api.common.time.Time
// Configure state TTL for automatic cleanup
val ttlConfig = StateTtlConfig
.newBuilder(Time.hours(1)) // TTL of 1 hour
.setUpdateType(StateTtlConfig.UpdateType.OnCreateAndWrite)
.setStateVisibility(StateTtlConfig.StateVisibility.NeverReturnExpired)
.build()
val stateDescriptor = new ValueStateDescriptor[String]("my-state", classOf[String])
stateDescriptor.enableTimeToLive(ttlConfig)
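
The TTL-enabled descriptor takes effect once it is registered with the runtime context, typically in open(). A minimal sketch (TtlCounter is illustrative), reusing ttlConfig from above:

import org.apache.flink.api.common.functions.RichMapFunction
import org.apache.flink.api.common.state.{ValueState, ValueStateDescriptor}
import org.apache.flink.configuration.Configuration
class TtlCounter extends RichMapFunction[String, (String, Int)] {
  private var count: ValueState[Int] = _
  override def open(parameters: Configuration): Unit = {
    val descriptor = new ValueStateDescriptor[Int]("ttl-count", classOf[Int])
    descriptor.enableTimeToLive(ttlConfig)
    count = getRuntimeContext.getState(descriptor)
  }
  override def map(word: String): (String, Int) = {
    val updated = count.value() + 1 // expired or unset state unboxes to 0
    count.update(updated)
    (word, updated)
  }
}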

// Use efficient data structures for state
case class CompactCounter(count: Int, lastUpdate: Long)
// Prefer primitive types when possible
val intStateDescriptor = new ValueStateDescriptor[Int]("counter", classOf[Int])
// Clear state when no longer needed
class StatefulProcessor extends KeyedProcessFunction[String, String, String] {
  private var state: ValueState[String] = _
  override def open(parameters: org.apache.flink.configuration.Configuration): Unit = {
    state = getRuntimeContext.getState(
      new ValueStateDescriptor[String]("processor-state", classOf[String])
    )
  }
  override def processElement(
    value: String,
    ctx: KeyedProcessFunction[String, String, String]#Context,
    out: Collector[String]
  ): Unit = {
    if (shouldClearState(value)) {
      state.clear() // Explicitly clear state
    }
  }
  private def shouldClearState(value: String): Boolean = {
    value == "RESET"
  }
}

import org.apache.flink.streaming.api.scala._
import org.apache.flink.streaming.api.functions.KeyedProcessFunction
import org.apache.flink.api.common.state.{ValueState, ValueStateDescriptor}
import org.apache.flink.util.Collector
case class UserEvent(userId: String, action: String, timestamp: Long)
case class UserSession(
userId: String,
sessionId: String,
startTime: Long,
endTime: Long,
eventCount: Int,
actions: List[String]
)
class SessionTracker extends KeyedProcessFunction[String, UserEvent, UserSession] {
private var currentSession: ValueState[UserSession] = _
private var sessionTimer: ValueState[Long] = _
private val SESSION_TIMEOUT = 30 * 60 * 1000L // 30 minutes
override def open(parameters: org.apache.flink.configuration.Configuration): Unit = {
val sessionDescriptor = new ValueStateDescriptor[UserSession](
"current-session", classOf[UserSession]
)
currentSession = getRuntimeContext.getState(sessionDescriptor)
val timerDescriptor = new ValueStateDescriptor[Long](
"session-timer", classOf[Long]
)
sessionTimer = getRuntimeContext.getState(timerDescriptor)
}
override def processElement(
event: UserEvent,
ctx: KeyedProcessFunction[String, UserEvent, UserSession]#Context,
out: Collector[UserSession]
): Unit = {
val session = Option(currentSession.value())
val currentTime = event.timestamp
session match {
case Some(s) if currentTime - s.endTime <= SESSION_TIMEOUT =>
// Continue existing session
val updatedSession = s.copy(
endTime = currentTime,
eventCount = s.eventCount + 1,
actions = s.actions :+ event.action
)
currentSession.update(updatedSession)
// Reset timer
val oldTimer = sessionTimer.value()
if (oldTimer != 0) {
ctx.timerService().deleteEventTimeTimer(oldTimer)
}
val newTimer = currentTime + SESSION_TIMEOUT
ctx.timerService().registerEventTimeTimer(newTimer)
sessionTimer.update(newTimer)
case _ =>
// Start new session
val newSession = UserSession(
userId = event.userId,
sessionId = java.util.UUID.randomUUID().toString,
startTime = currentTime,
endTime = currentTime,
eventCount = 1,
actions = List(event.action)
)
currentSession.update(newSession)
// Set session timeout timer
val timerTime = currentTime + SESSION_TIMEOUT
ctx.timerService().registerEventTimeTimer(timerTime)
sessionTimer.update(timerTime)
}
}
override def onTimer(
timestamp: Long,
ctx: KeyedProcessFunction[String, UserEvent, UserSession]#OnTimerContext,
out: Collector[UserSession]
): Unit = {
// Session timeout - emit final session and clear state
val session = currentSession.value()
if (session != null) {
out.collect(session)
currentSession.clear()
}
sessionTimer.clear()
}
}
object UserSessionTracking {
def main(args: Array[String]): Unit = {
val env = StreamExecutionEnvironment.getExecutionEnvironment
val userEvents = env.fromElements(
UserEvent("user1", "login", 1000),
UserEvent("user1", "page_view", 1010),
UserEvent("user1", "click", 1020),
UserEvent("user2", "login", 1030),
UserEvent("user1", "logout", 1040)
)
val sessions = userEvents
.keyBy(_.userId)
.process(new SessionTracker)
sessions.print()
env.execute("User Session Tracking")
}
}