The PySpark Streaming module enables scalable, high-throughput, fault-tolerant stream processing of live data streams in Python. (Note: the API signatures and examples on this page are shown in their Scala form.)
—
Quality: Pending — it has not yet been reviewed against best practices.
Impact: Pending — no evaluation scenarios have been run yet.
Advanced operations for maintaining state across streaming batches, including updateStateByKey and mapWithState for building stateful streaming applications.
Maintain state across batches using update function:
// Maintain per-key state across batches; available on DStream[(K, V)].
// updateFunc receives all new values for a key in the current batch plus
// the previously stored state, and returns the new state (None removes
// the key from the state).
def updateStateByKey[S: ClassTag](
  updateFunc: (Seq[V], Option[S]) => Option[S]
): DStream[(K, S)] // On DStream[(K, V)]

With custom partitioning:

// Overload that controls how many partitions the state RDDs use.
def updateStateByKey[S: ClassTag](
  updateFunc: (Seq[V], Option[S]) => Option[S],
  numPartitions: Int
): DStream[(K, S)]

// Overload with an explicit partitioner and an initial state RDD to seed
// the state before the first batch.
def updateStateByKey[S: ClassTag](
  updateFunc: (Seq[V], Option[S]) => Option[S],
  partitioner: Partitioner,
  initialRDD: RDD[(K, S)]
): DStream[(K, S)]

Example word count with state:
// Read lines from a socket and emit (word, 1) pairs.
val lines = ssc.socketTextStream("localhost", 9999)
val words = lines.flatMap(_.split("\\s+")).map((_, 1))
// Running word count across all batches: each batch's counts are added
// to the previously stored count per word; returning Some keeps the key
// alive indefinitely (this state never expires).
val runningCounts = words.updateStateByKey[Int] { (values, state) =>
  val currentCount = values.sum
  val newCount = state.getOrElse(0) + currentCount
  Some(newCount)
}
runningCounts.print()

User session tracking:
// Per-user session state: when the session began, when the user was last
// seen, and how many pages were viewed during the session.
case class SessionInfo(loginTime: Long, lastActivity: Long, pageViews: Int)

val userEvents = ssc.socketTextStream("localhost", 9999)
  .map(parseUserEvent) // Returns (userId, event)

// FIX(review): the original used `return None` inside the foldLeft lambda
// to drop the session on logout. `return` inside a lambda is a nonlocal
// return (implemented by throwing NonLocalReturnControl) and must not be
// used there. Instead we fold over an Option so a "logout" event ends the
// session cleanly and any later events in the batch are ignored.
val userSessions = userEvents.updateStateByKey[SessionInfo] { (events, sessionOpt) =>
  val currentTime = System.currentTimeMillis()
  val session = sessionOpt.getOrElse(SessionInfo(currentTime, currentTime, 0))
  val updatedOpt = events.foldLeft(Option(session)) { (acc, event) =>
    acc.flatMap { sess =>
      event match {
        case "login"    => Some(sess.copy(loginTime = currentTime, lastActivity = currentTime))
        case "pageview" => Some(sess.copy(lastActivity = currentTime, pageViews = sess.pageViews + 1))
        case "logout"   => None // Remove session
        case _          => Some(sess.copy(lastActivity = currentTime))
      }
    }
  }
  // Expire sessions after 30 minutes of inactivity
  updatedOpt.filter(sess => currentTime - sess.lastActivity <= 30 * 60 * 1000)
}

Real-time analytics with state:
// Running aggregate for one metric: how many values were seen, their sum,
// and the observed extremes.
case class Analytics(count: Long, sum: Double, min: Double, max: Double) {
  // Average of all observed values; 0.0 before anything has been seen.
  def avg: Double = if (count > 0) sum / count else 0.0
}

val metrics = ssc.socketTextStream("localhost", 9999)
  .map { line =>
    // Split each line once instead of twice (the original called split
    // twice per record).
    val fields = line.split(",")
    (fields(0), fields(1).toDouble) // (metric_name, value)
  }

// Fold every new value for a key into the running aggregate. Extremes
// start at MaxValue/MinValue so the first real value always replaces them.
val runningAnalytics = metrics.updateStateByKey[Analytics] { (values, stateOpt) =>
  val current = stateOpt.getOrElse(Analytics(0, 0.0, Double.MaxValue, Double.MinValue))
  val updated = values.foldLeft(current) { (analytics, value) =>
    Analytics(
      count = analytics.count + 1,
      sum = analytics.sum + value,
      min = math.min(analytics.min, value),
      max = math.max(analytics.max, value)
    )
  }
  Some(updated)
}

Create StateSpec with mapping function:
// Factory methods for building a StateSpec from a mapping function.
object StateSpec {
  // Full form: receives the batch time, the key, the new value (None when
  // invoked only because a timeout fired), and the State handle.
  def function[K, V, S, T](
    mappingFunction: (Time, K, Option[V], State[S]) => Option[T]
  ): StateSpec[K, V, S, T]
  // Simplified form without the batch time.
  def function[K, V, S, T](
    mappingFunction: (K, Option[V], State[S]) => T
  ): StateSpec[K, V, S, T]
}

StateSpec configuration methods:

// Builder-style configuration; each method returns this.type for chaining.
abstract class StateSpec[K, V, S, T] {
  def initialState(rdd: RDD[(K, S)]): this.type
  def initialState(javaPairRDD: JavaPairRDD[K, S]): this.type
  def numPartitions(numPartitions: Int): this.type
  def partitioner(partitioner: Partitioner): this.type
  def timeout(idleDuration: Duration): this.type // expire idle keys
}

Apply stateful mapping:

// On DStream[(K, V)]: applies the spec's mapping function to every key.
def mapWithState[S: ClassTag, T: ClassTag](
  spec: StateSpec[K, V, S, T]
): MapWithStateDStream[K, V, S, T]

State management in mapping functions:

// Mutable handle to a single key's state inside the mapping function.
abstract class State[S] {
  def exists(): Boolean
  def get(): S                  // assumes state exists — prefer getOption()
  def update(newState: S): Unit
  def remove(): Unit
  def isTimingOut(): Boolean    // true when invoked because the key timed out
  def getOption(): Option[S]    // safe accessor
}

Simple counter with mapWithState:
val words = ssc.socketTextStream("localhost", 9999)
  .flatMap(_.split("\\s+"))
  .map((_, 1))

// Add this batch's value (if any) to the stored count, persist the new
// total in the state, and emit the updated (word, count) pair downstream.
val mappingFunction = (word: String, one: Option[Int], state: State[Int]) => {
  val sum = one.getOrElse(0) + state.getOption().getOrElse(0)
  val output = (word, sum)
  state.update(sum)
  output
}
val stateDStream = words.mapWithState(StateSpec.function(mappingFunction))
stateDStream.print()

User behavior analysis:
// Aggregated per-user activity metrics kept as mapWithState state.
case class UserBehavior(
  totalSessions: Int,
  totalPageViews: Int,
  lastActivity: Long,
  avgSessionDuration: Double
)

val userActions = ssc.socketTextStream("localhost", 9999)
  .map(parseUserAction) // Returns (userId, action, timestamp)

val behaviorSpec = StateSpec.function(
  (userId: String, action: Option[(String, Long)], state: State[UserBehavior]) => {
    val currentTime = System.currentTimeMillis()
    val behavior = state.getOption().getOrElse(
      UserBehavior(0, 0, currentTime, 0.0)
    )
    action match {
      case Some(("session_start", timestamp)) =>
        val newBehavior = behavior.copy(
          totalSessions = behavior.totalSessions + 1,
          lastActivity = timestamp
        )
        state.update(newBehavior)
        Some((userId, "session_started", newBehavior))
      case Some(("page_view", timestamp)) =>
        val newBehavior = behavior.copy(
          totalPageViews = behavior.totalPageViews + 1,
          lastActivity = timestamp
        )
        state.update(newBehavior)
        Some((userId, "page_viewed", newBehavior))
      case None if state.isTimingOut() =>
        // State is timing out, emit final statistics. Note we do not call
        // state.update() here — the key is being removed.
        Some((userId, "user_summary", behavior))
      case _ =>
        None // unknown action: leave state untouched, emit nothing
    }
  }
).timeout(Minutes(30)) // Timeout inactive users after 30 minutes
val userBehaviorStream = userActions.mapWithState(behaviorSpec)

Real-time anomaly detection:
// Sliding window of recent metric values plus a running sum/count, used
// to compute the mean and standard deviation for anomaly scoring.
// Invariant maintained by callers: values.size == count, and
// sum == values.sum over the current window.
case class MetricState(
  values: Queue[Double],
  sum: Double,
  count: Int,
  windowSize: Int = 100
) {
  // Mean of the windowed values; 0.0 before any value has arrived.
  def mean: Double = if (count > 0) sum / count else 0.0

  // Population standard deviation of the windowed values; 0.0 until at
  // least two values exist (a single sample has no meaningful spread).
  // Rewritten as a single expression — the original used an early
  // `return`, which is non-idiomatic Scala.
  def stdDev: Double =
    if (count < 2) 0.0
    else {
      val meanVal = mean
      val variance = values.map(v => math.pow(v - meanVal, 2)).sum / count
      math.sqrt(variance)
    }
}
val metrics = ssc.socketTextStream("localhost", 9999)
  .map { line =>
    val fields = line.split(",")
    (fields(0), fields(1).toDouble)
  }

// Maintain a fixed-size sliding window per metric and flag values more
// than three standard deviations from the window mean.
val anomalySpec = StateSpec.function(
  (metric: String, value: Option[Double], state: State[MetricState]) => {
    value match {
      case Some(v) =>
        val currentState = state.getOption().getOrElse(
          MetricState(Queue.empty, 0.0, 0)
        )
        // Evict the oldest value once the window is full so the queue,
        // sum, and count all describe the same sliding window.
        val atCapacity = currentState.values.size >= currentState.windowSize
        val (evicted, kept) =
          if (atCapacity) currentState.values.dequeue
          else (0.0, currentState.values)
        // copy() preserves windowSize — the original rebuilt the state
        // from scratch and silently reset a non-default windowSize to 100.
        val newState = currentState.copy(
          values = kept.enqueue(v),
          sum = currentState.sum - evicted + v,
          count = math.min(currentState.count + 1, currentState.windowSize)
        )
        state.update(newState)
        // Detect anomaly (value is more than 3 standard deviations from mean);
        // require a minimum sample size first to avoid noisy z-scores.
        if (newState.count > 10) {
          val zScore = math.abs(v - newState.mean) / newState.stdDev
          if (zScore > 3.0) {
            Some((metric, s"ANOMALY: $v (z-score: $zScore)"))
          } else {
            Some((metric, s"NORMAL: $v"))
          }
        } else {
          Some((metric, s"LEARNING: $v"))
        }
      case None => None // timeout tick with no new value: emit nothing
    }
  }
)
val anomalies = metrics.mapWithState(anomalySpec)
anomalies.print()

Access state snapshots:
// DStream produced by mapWithState: it emits the mapping function's
// outputs, and can additionally expose a periodic snapshot of all states.
class MapWithStateDStream[K, V, S, T] extends DStream[T] {
  def stateSnapshots(): DStream[(K, S)]
}

Example state snapshots:

val stateDStream = words.mapWithState(StateSpec.function(mappingFunction))
// Get periodic snapshots of all state
val snapshots = stateDStream.stateSnapshots()
snapshots.foreachRDD { rdd =>
  println(s"Current state count: ${rdd.count()}")
  rdd.take(10).foreach { case (key, state) =>
    println(s"$key -> $state")
  }
}

Provide initial state from external source:
// Load initial state from database or file
val initialState = ssc.sparkContext.parallelize(loadInitialStateFromDB())
val stateSpec = StateSpec.function(mappingFunction)
  .initialState(initialState) // seed state before the first batch
  .numPartitions(10)
  .timeout(Minutes(60))

Control state memory usage:
// Cap the size of each key's state map so a hot key cannot grow unbounded.
val memoryEfficientSpec = StateSpec.function(
  (key: String, value: Option[String], state: State[Map[String, Int]]) => {
    val currentMap = state.getOption().getOrElse(Map.empty)
    value match {
      case Some(v) =>
        val updated = currentMap + (v -> (currentMap.getOrElse(v, 0) + 1))
        // Limit map size to prevent memory issues: once the map exceeds
        // 1000 entries, keep only the 800 entries with the highest counts.
        val trimmed = if (updated.size > 1000) {
          updated.toSeq.sortBy(_._2).takeRight(800).toMap
        } else {
          updated
        }
        state.update(trimmed)
        Some((key, trimmed.size))
      case None if state.isTimingOut() =>
        // Clean up before timeout
        Some((key, -1)) // Indicate removal
      case _ => None
    }
  }
).timeout(Hours(1))

Enable checkpointing for stateful operations:
val ssc = new StreamingContext(conf, Seconds(5))
ssc.checkpoint("hdfs://namenode:9000/checkpoint")
// Stateful operations require checkpointing
val statefulStream = inputStream.updateStateByKey(updateFunction)

Optimize stateful operations:
// Use appropriate partitioning
val optimizedState = keyValueStream.updateStateByKey(
  updateFunction,
  new HashPartitioner(20) // Match your cluster size
)
// Consider using mapWithState for better performance
val efficientState = keyValueStream.mapWithState(
  StateSpec.function(mappingFunction)
    .numPartitions(20)
    .timeout(Minutes(30))
)

Handle errors in state updates:
// Defend the state update against malformed input so one bad record
// cannot fail the batch.
import scala.util.control.NonFatal

val robustStateSpec = StateSpec.function(
  (key: String, value: Option[String], state: State[Int]) => {
    try {
      val current = state.getOption().getOrElse(0)
      val newValue = current + value.map(_.toInt).getOrElse(0)
      state.update(newValue)
      Some((key, newValue))
    } catch {
      case e: NumberFormatException =>
        // Log error and maintain previous state
        logError(s"Invalid number for key $key: ${value.getOrElse("None")}")
        Some((key, state.getOption().getOrElse(0)))
      case NonFatal(e) =>
        // Handle other recoverable errors. NonFatal (instead of the
        // original bare Exception catch) lets truly fatal errors such as
        // OutOfMemoryError and InterruptedException propagate.
        logError(s"Error updating state for $key", e)
        None
    }
  }
)

Install with Tessl CLI
npx tessl i tessl/pypi-pyspark-streaming