Vertex-centric programming framework for implementing custom iterative graph algorithms using the Pregel computational model. The Pregel API enables distributed graph computation through a message-passing paradigm where vertices receive messages, update their state, and send messages to neighbors.
The main Pregel computation API for implementing vertex-centric iterative algorithms.
/**
* Execute a Pregel computation on the graph.
*
* The first parameter list configures the run (initial message, iteration cap,
* active direction); the second supplies the three user functions.
* `VD` and `ED` are not declared on this method; presumably they are the
* vertex and edge attribute types of the enclosing graph (not visible here).
*
* @tparam A Message type exchanged between vertices (a ClassTag must be available)
* @param initialMsg Initial message sent to all vertices in first iteration
* @param maxIterations Maximum number of iterations (default: Int.MaxValue, i.e. effectively no limit)
* @param activeDirection Direction of edges, relative to a vertex that received a message in the
*                        previous round, on which sendMsg is run (default: Either)
* @param vprog Vertex program: (VertexId, VertexData, Message) => NewVertexData
* @param sendMsg Send message function: EdgeTriplet => Iterator[(VertexId, Message)]
* @param mergeMsg Merge messages function: (Message, Message) => Message
* @return New graph with updated vertex attributes
*/
def pregel[A: ClassTag](
initialMsg: A,
maxIterations: Int = Int.MaxValue,
activeDirection: EdgeDirection = EdgeDirection.Either
)(
vprog: (VertexId, VD, A) => VD,
sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)],
mergeMsg: (A, A) => A
): Graph[VD, ED]
object Pregel {
/**
* Execute Pregel computation (standalone object version).
*
* Takes the same arguments as the `pregel` method, with the graph passed
* explicitly as the first argument instead of being the receiver.
*
* @tparam VD Vertex attribute type
* @tparam ED Edge attribute type
* @tparam A Message type
* @param graph Input graph
* @param initialMsg Initial message for all vertices
* @param maxIterations Maximum iterations (default: Int.MaxValue)
* @param activeDirection Edge direction for message passing (default: EdgeDirection.Either)
* @param vprog Vertex program function
* @param sendMsg Message sending function
* @param mergeMsg Message merging function
* @return Updated graph
*/
def apply[VD: ClassTag, ED: ClassTag, A: ClassTag](
graph: Graph[VD, ED],
initialMsg: A,
maxIterations: Int = Int.MaxValue,
activeDirection: EdgeDirection = EdgeDirection.Either
)(
vprog: (VertexId, VD, A) => VD,
sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)],
mergeMsg: (A, A) => A
): Graph[VD, ED]
}Control which edges are used for message passing based on vertex activity.
/** The predefined edge directions, expressed relative to a vertex of interest. */
object EdgeDirection {
/** Edges where the vertex is the destination */
val In: EdgeDirection
/** Edges where the vertex is the source */
val Out: EdgeDirection
/** Edges where the vertex is either the source or the destination */
val Either: EdgeDirection
/**
* Edges originating from AND arriving at a vertex of interest. As a Pregel
* activeDirection this selects edges whose source and destination were both
* active in the previous round; it is not restricted to self-loops.
*/
val Both: EdgeDirection
}
/** A direction along an edge relative to a vertex; instances are the values listed on the companion object. */
class EdgeDirection {
/**
* Reverse the direction.
* Per the GraphX API docs: In becomes Out, Out becomes In, and Either and Both are unchanged.
*/
def reverse: EdgeDirection
}The Pregel API builds on the lower-level aggregateMessages function for message passing.
/**
* Lower-level message aggregation (used internally by Pregel).
*
* `sendMsg` returns Unit: it emits messages by calling `sendToSrc` / `sendToDst`
* on the EdgeContext it is given rather than by returning them.
*
* @tparam A Message type
* @param sendMsg Function defining messages sent along edges
* @param mergeMsg Function combining multiple messages at same vertex
* @param tripletFields Optimization hint for which triplet fields are accessed (default: TripletFields.All)
* @return VertexRDD with aggregated messages
*/
def aggregateMessages[A: ClassTag](
sendMsg: EdgeContext[VD, ED, A] => Unit,
mergeMsg: (A, A) => A,
tripletFields: TripletFields = TripletFields.All
): VertexRDD[A]
/**
* Context for sending messages in aggregateMessages.
*
* Describes a single edge (endpoint ids, endpoint attributes, edge attribute)
* and offers the two methods used to emit a message of type A to an endpoint.
* NOTE(review): `pregelIteration` later in this file calls `ctx.toEdgeTriplet`,
* which is not listed here; presumably this listing is abridged.
*/
abstract class EdgeContext[VD, ED, A] {
/** Id of the edge's source vertex */
val srcId: VertexId
/** Id of the edge's destination vertex */
val dstId: VertexId
/** Attribute of the source vertex */
val srcAttr: VD
/** Attribute of the destination vertex */
val dstAttr: VD
/** Attribute of the edge itself */
val attr: ED
/** Send message to source vertex */
def sendToSrc(msg: A): Unit
/** Send message to destination vertex */
def sendToDst(msg: A): Unit
}Classic shortest path algorithm using Pregel message passing.
import scala.reflect.ClassTag

import org.apache.spark.graphx._
/**
 * Single-source shortest paths computed with Pregel.
 *
 * Each vertex holds its best known distance from `sourceId`. A shorter
 * candidate found along an edge is sent to that edge's destination, and the
 * computation stops once no edge yields an improvement.
 *
 * @param graph    input graph; vertex attributes are discarded, edge attributes are the edge lengths
 * @param sourceId vertex from which distances are measured
 * @return graph whose vertex attribute is the distance from `sourceId`
 *         (`Double.PositiveInfinity` where no shorter value was ever received)
 */
def shortestPaths(graph: Graph[Long, Double], sourceId: VertexId): Graph[Double, Double] = {
  // The source starts at zero; every other vertex starts as "not reached yet".
  val seeded = graph.mapVertices { (vertexId, _) =>
    if (vertexId == sourceId) 0.0 else Double.PositiveInfinity
  }

  seeded.pregel(Double.PositiveInfinity)(
    // Vertex program: keep the smaller of the stored and the incoming distance.
    (_, current, incoming) => math.min(current, incoming),
    // Send message: only notify the destination when going through the source is shorter.
    edge => {
      val viaSource = edge.srcAttr + edge.attr
      if (viaSource < edge.dstAttr) Iterator((edge.dstId, viaSource)) else Iterator.empty
    },
    // Merge messages: several candidates for one vertex collapse to the shortest.
    (left, right) => math.min(left, right)
  )
}
// Usage: compute distances from vertex 1 and print one line per vertex
val sourceVertex = 1L
val distances = shortestPaths(graph, sourceVertex).vertices
for ((id, dist) <- distances.collect()) {
  println(s"Distance from $sourceVertex to $id: $dist")
}
Find connected components using iterative label propagation.
/**
 * Connected components by iterative label propagation.
 *
 * Every vertex starts with its own id as label and repeatedly adopts the
 * smallest label seen on any neighbour, so each component converges to the
 * smallest vertex id it contains.
 *
 * @tparam ED edge attribute type (carried through unchanged)
 * @param graph input graph; existing vertex attributes are discarded
 * @return graph whose vertex attribute is the component label
 */
def connectedComponents[ED: ClassTag](graph: Graph[Long, ED]): Graph[VertexId, ED] = {
  // Initialize each vertex with its own ID as component label
  val initialGraph = graph.mapVertices((id, _) => id)
  initialGraph.pregel(Long.MaxValue)(
    // Vertex program: adopt smaller component ID
    vprog = (_, oldLabel, newLabel) => math.min(oldLabel, newLabel),
    // Send message: push the smaller label towards the endpoint holding the
    // larger one. The two comparisons are mutually exclusive, so at most one
    // message is produced per edge and no mutable buffer is needed.
    sendMsg = triplet =>
      if (triplet.srcAttr < triplet.dstAttr) Iterator((triplet.dstId, triplet.srcAttr))
      else if (triplet.dstAttr < triplet.srcAttr) Iterator((triplet.srcId, triplet.dstAttr))
      else Iterator.empty,
    // Merge messages: take minimum component ID
    mergeMsg = (a, b) => math.min(a, b)
  )
}
Implement PageRank algorithm using the Pregel framework.
/**
 * Fixed-iteration PageRank expressed with Pregel.
 *
 * Vertices carry `(rank, outDegree)` while the computation runs; messages are
 * plain `Double` rank contributions, so the initial message must also be a
 * `Double`.
 *
 * @param graph     input graph; its vertex and edge attributes are not used by the ranking
 * @param numIter   maximum number of Pregel iterations
 * @param resetProb reset probability (default 0.15)
 * @return graph whose vertex attribute is the computed rank
 */
def pageRank(graph: Graph[Double, Double], numIter: Int, resetProb: Double = 0.15): Graph[Double, Double] = {
  // Initialize all vertices with rank 1.0
  val initialGraph = graph.mapVertices((_, _) => 1.0)
  // Attach each vertex's out-degree; vertices without outgoing edges get 0
  val graphWithDegrees = initialGraph.outerJoinVertices(graph.outDegrees)((_, rank, degOpt) =>
    (rank, degOpt.getOrElse(0))
  )
  // The message type is Double (a summed rank contribution). The initial
  // message is therefore 0.0, which makes the first vertex-program call set
  // every rank to resetProb before any contribution arrives.
  graphWithDegrees.pregel(0.0, numIter)(
    // Vertex program: new rank from the summed contributions; out-degree is carried along
    vprog = (_, attr, msgSum) => (resetProb + (1.0 - resetProb) * msgSum, attr._2),
    // Send message: split the source's rank evenly over its outgoing edges
    sendMsg = triplet => {
      val (srcRank, srcOutDegree) = triplet.srcAttr
      if (srcOutDegree > 0) Iterator((triplet.dstId, srcRank / srcOutDegree))
      else Iterator.empty
    },
    // Merge messages: sum all incoming rank contributions
    mergeMsg = (a, b) => a + b
  ).mapVertices((_, attr) => attr._1) // Extract just the rank
}
Matrix factorization using alternating least squares implemented with Pregel.
/**
* Latent-factor record stored on each vertex by `alternatingLeastSquares`.
* Because `features` is an Array, the generated `equals`/`hashCode` compare it
* by reference, not by content.
*
* @param features factor vector
* @param bias bias term
*/
case class Factor(features: Array[Double], bias: Double)
/**
 * Skeleton of alternating least squares on a bipartite ratings graph.
 *
 * Vertices with an id below `userIdBound` are treated as users, all others as
 * items. Each outer round runs exactly one Pregel superstep in which user
 * vertices receive a factor from the item side. The merge and update steps are
 * placeholders, not a real least-squares solve.
 *
 * @param graph       ratings graph; vertex attributes are replaced by random factors
 * @param rank        number of latent features per vertex
 * @param numIter     number of outer rounds
 * @param userIdBound exclusive upper bound of user vertex ids (default 1000000)
 * @return graph whose vertex attribute is the current [[Factor]]
 */
def alternatingLeastSquares(
    graph: Graph[Double, Double], // ratings graph
    rank: Int,
    numIter: Int,
    userIdBound: VertexId = 1000000L
): Graph[Factor, Double] = {
  import scala.util.Random
  val random = new Random(42)
  // Initialize vertex features randomly.
  // NOTE(review): `random` is captured by the mapVertices closure; presumably
  // each task works on its own copy with the same seed, so different
  // partitions may draw identical sequences. Confirm whether that matters.
  val initialGraph = graph.mapVertices { (_, _) =>
    Factor(Array.fill(rank)(random.nextGaussian() * 0.1), 0.0)
  }
  // Initial message that the vertex program recognises as "nothing to apply".
  val noUpdate = Factor(Array.empty[Double], 0.0)

  // Alternate between updating user and item factors
  (0 until numIter).foldLeft(initialGraph) { (current, _) =>
    // Update user factors. sendMsg emits a message on every edge in every
    // superstep, so Pregel's "no more messages" stop condition can never be
    // met. Cap the run at a single superstep; the default cap is Int.MaxValue.
    current.pregel(noUpdate, maxIterations = 1)(
      vprog = (id, oldFactor, newFactor) =>
        if (id < userIdBound && newFactor.features.nonEmpty) newFactor else oldFactor,
      // Send item factors to users, user factors to items
      sendMsg = triplet =>
        if (triplet.srcId < userIdBound) Iterator((triplet.srcId, triplet.dstAttr)) // item factor to user
        else Iterator((triplet.dstId, triplet.srcAttr)), // user factor to item
      mergeMsg = (f1, _) => f1 // Simple merge (would need proper ALS update)
    )
    // Update item factors (similar pattern)
    // ... item factor update iteration ...
  }
}
Some algorithms require multiple Pregel phases with different logic.
/**
 * Runs two Pregel computations back to back: the graph produced by the first
 * (forward) phase is the input of the second (backward) phase.
 *
 * @param graph input graph
 * @return graph produced by the second phase
 */
def twoPhaseAlgorithm[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]): Graph[VD, ED] = {
  // Phase 1: forward pass
  val afterForwardPass = graph.pregel(initialMsg1)(vprog1, sendMsg1, mergeMsg1)
  // Phase 2: backward pass with different logic, starting from phase 1's output
  afterForwardPass.pregel(initialMsg2)(vprog2, sendMsg2, mergeMsg2)
}
Implement custom convergence checking within Pregel iterations.
/**
 * Repeats a Pregel computation until the largest per-vertex change between two
 * consecutive rounds drops below `tolerance`, or `maxRounds` is reached.
 *
 * @param graph     input graph; it is never unpersisted by this function
 * @param tolerance convergence threshold compared against the largest change
 * @param maxRounds upper bound on the number of outer rounds (default 100)
 * @return graph produced by the last round
 */
def convergedPregelAlgorithm[VD: ClassTag, ED: ClassTag](
    graph: Graph[VD, ED],
    tolerance: Double,
    maxRounds: Int = 100
): Graph[VD, ED] = {
  var currentGraph = graph
  var converged = false
  var iteration = 0
  while (!converged && iteration < maxRounds) {
    val previousGraph = currentGraph
    currentGraph = previousGraph.pregel(initialMsg)(
      vprog = (_, oldAttr, msg) => updateFunction(oldAttr, msg),
      sendMsg = sendFunction,
      mergeMsg = mergeFunction
    )
    // Check convergence by comparing vertex attributes.
    // NOTE(review): max() fails on an empty RDD, so a graph without vertices
    // is not handled here.
    val maxChange = previousGraph.vertices
      .join(currentGraph.vertices)
      .map { case (_, (oldAttr, newAttr)) => computeChange(oldAttr, newAttr) }
      .max()
    converged = maxChange < tolerance
    // Release only graphs created by this loop. In the first round
    // previousGraph is the caller's input, whose caching is not ours to undo.
    if (previousGraph ne graph) previousGraph.unpersist(blocking = false)
    iteration += 1
  }
  currentGraph
}
// Optimize Pregel with proper caching and partitioning
/**
 * Runs a bounded Pregel computation on a repartitioned, cached copy of the
 * input graph and returns the cached result.
 *
 * @param graph input graph
 * @return result of the Pregel run, marked for caching
 */
def optimizedPregelAlgorithm[VD: ClassTag, ED: ClassTag](
    graph: Graph[VD, ED]
): Graph[VD, ED] = {
  // Better partitioning (2D edge partitioning), cached for the iterations
  val repartitioned = graph.partitionBy(PartitionStrategy.EdgePartition2D)
  val workingGraph = repartitioned.cache()

  val pregelResult = workingGraph.pregel(
    initialMsg = initialMessage,
    maxIterations = 50 // Prevent infinite loops
  )(
    vprog = vertexProgram,
    sendMsg = messageFunction,
    mergeMsg = mergeFunction
  )

  // Clean up the intermediate copy, then cache the result for reuse
  workingGraph.unpersist(blocking = false)
  pregelResult.cache()
}
// Use TripletFields for better performance
/**
 * Sums, for every vertex, the source attributes of its incoming edges.
 *
 * The message type is `Double`, so the vertex attribute is converted through
 * `Numeric[VD]`; sending `ctx.srcAttr` directly does not type-check for a
 * generic `VD`.
 *
 * @tparam VD numeric vertex attribute type
 * @tparam ED edge attribute type
 * @param graph input graph
 * @return per-vertex sum of the attributes of in-neighbours
 */
def efficientMessagePassing[VD: ClassTag: Numeric, ED: ClassTag](
    graph: Graph[VD, ED]
): VertexRDD[Double] = {
  val numeric = implicitly[Numeric[VD]]
  graph.aggregateMessages[Double](
    // Only access needed fields
    sendMsg = ctx => ctx.sendToDst(numeric.toDouble(ctx.srcAttr)),
    mergeMsg = (a, b) => a + b,
    tripletFields = TripletFields.Src // Only need source attributes
  )
}
// Use aggregateMessages for single-pass message aggregation
// Every edge contributes 1 to both of its endpoints
val degrees = graph.aggregateMessages[Int](
  sendMsg = ctx => {
    ctx.sendToSrc(1)
    ctx.sendToDst(1)
  },
  mergeMsg = _ + _
)
// Use mapVertices/mapTriplets for simple transformations
val normalizedGraph = graph.mapVertices((_, attr) => attr / maxValue)
// Use GraphOps methods for common operations
val components = graph.connectedComponents() // More efficient than custom Pregel
val pageRanks = graph.pageRank(0.001) // Optimized implementation
// Pregel execution phases in each iteration:
// 1. Vertex Program: Update vertex state based on received messages
// 2. Send Messages: Generate messages for next iteration
// 3. Message Aggregation: Combine multiple messages to same vertex
// 4. Check Convergence: Determine if more iterations needed
/**
 * One Pregel superstep: applies the vertex program to the vertices that
 * received a message, then computes the messages for the next superstep.
 * Checking whether any messages remain (step 4) is left to the caller.
 *
 * @param graph    graph at the start of the superstep
 * @param messages merged messages addressed to vertices of `graph`
 * @param vprog    vertex program applied to each vertex that has a message
 * @param sendMsg  produces messages for an edge; every target must be one of the edge's endpoints
 * @param mergeMsg combines two messages addressed to the same vertex
 * @return the updated graph and the messages for the next superstep
 */
def pregelIteration[VD: ClassTag, ED: ClassTag, A: ClassTag](
    graph: Graph[VD, ED],
    messages: VertexRDD[A],
    vprog: (VertexId, VD, A) => VD,
    sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)],
    mergeMsg: (A, A) => A
): (Graph[VD, ED], VertexRDD[A]) = {
  // Phase 1: Apply vertex program. joinVertices only touches vertices that
  // have a message; the others keep their attribute.
  val newGraph = graph.joinVertices(messages)(vprog)
  // Phase 2: Send messages for next iteration
  val newMessages = newGraph.aggregateMessages[A](
    ctx =>
      sendMsg(ctx.toEdgeTriplet).foreach { case (vid, msg) =>
        if (vid == ctx.srcId) ctx.sendToSrc(msg)
        else {
          // A message can only travel to an endpoint of the current edge;
          // anything else used to be delivered to the destination silently.
          require(vid == ctx.dstId, s"message target $vid is not an endpoint of the current edge")
          ctx.sendToDst(msg)
        }
      },
    mergeMsg
  )
  (newGraph, newMessages)
}