Lightning-fast unified analytics engine for large-scale data processing with high-level APIs in Scala, Java, Python, and R
—
GraphX is Apache Spark's API for graphs and graph-parallel computation. It extends the Spark RDD abstraction with a Resilient Distributed Property Graph: a directed multigraph with properties attached to each vertex and edge.
The central abstraction in GraphX:
// The central GraphX abstraction: a property graph parameterized by the
// vertex attribute type VD and the edge attribute type ED.
abstract class Graph[VD: ClassTag, ED: ClassTag] extends Serializable {
// Core properties
def vertices: VertexRDD[VD] // all vertices with their attributes
def edges: EdgeRDD[ED] // all edges with their attributes
def triplets: RDD[EdgeTriplet[VD, ED]] // edges joined with both endpoint attributes
// Structural operations
def reverse: Graph[VD, ED] // flip the direction of every edge
def subgraph(epred: EdgeTriplet[VD, ED] => Boolean = x => true, vpred: (VertexId, VD) => Boolean = (a, b) => true): Graph[VD, ED] // keep only vertices/edges satisfying both predicates
def groupEdges(merge: (ED, ED) => ED): Graph[VD, ED] // merge parallel edges with a user-supplied combiner
// Transformation operations
def mapVertices[VD2: ClassTag](map: (VertexId, VD) => VD2): Graph[VD2, ED] // transform vertex attributes, structure unchanged
def mapEdges[ED2: ClassTag](map: Edge[ED] => ED2): Graph[VD, ED2] // transform edge attributes, structure unchanged
def mapTriplets[ED2: ClassTag](map: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2] // transform edge attributes with access to endpoint attributes
}

type VertexId = Long

Represents a directed edge in the graph:
case class Edge[ED](srcId: VertexId, dstId: VertexId, attr: ED) extends Serializable

Joins vertex and edge data:
// An edge together with the attributes of both of its endpoints;
// instances are produced by Graph.triplets.
class EdgeTriplet[VD, ED] extends Edge[ED] {
def srcId: VertexId // Source vertex ID
def dstId: VertexId // Destination vertex ID
def attr: ED // Edge attribute
def srcAttr: VD // Source vertex attribute
def dstAttr: VD // Destination vertex attribute
}

Graph.apply: Create graph from vertices and edges
object Graph {
// Build a graph from explicit vertex and edge RDDs. Vertices referenced
// by an edge but missing from `vertices` get defaultVertexAttr.
// NOTE: a bare `null` default does not typecheck for an unconstrained VD;
// Spark's actual signature uses null.asInstanceOf[VD].
def apply[VD: ClassTag, ED: ClassTag](vertices: RDD[(VertexId, VD)], edges: RDD[Edge[ED]], defaultVertexAttr: VD = null.asInstanceOf[VD]): Graph[VD, ED]
}

import org.apache.spark.graphx._
import org.apache.spark.rdd.RDD
// Create vertices RDD
val vertices: RDD[(VertexId, String)] = sc.parallelize(Array(
(1L, "Alice"),
(2L, "Bob"),
(3L, "Charlie"),
(4L, "David")
))
// Create edges RDD
val edges: RDD[Edge[String]] = sc.parallelize(Array(
Edge(1L, 2L, "friend"),
Edge(2L, 3L, "friend"),
Edge(3L, 4L, "colleague"),
Edge(1L, 4L, "colleague")
))
// Create graph
val graph = Graph(vertices, edges, defaultVertexAttr = "Unknown")

Graph.fromEdges: Create graph from edges only
def fromEdges[VD: ClassTag, ED: ClassTag](edges: RDD[Edge[ED]], defaultValue: VD): Graph[VD, ED]

// Create graph from edges (vertices inferred)
val relationships: RDD[Edge[String]] = sc.parallelize(Array(
Edge(1L, 2L, "follows"),
Edge(2L, 3L, "follows"),
Edge(3L, 1L, "follows")
))
val socialGraph = Graph.fromEdges(relationships, defaultValue = "user")

Graph.fromEdgeTuples: Create unweighted graph from tuples
def fromEdgeTuples[VD: ClassTag](rawEdges: RDD[(VertexId, VertexId)], defaultValue: VD, uniqueEdges: Option[PartitionStrategy] = None): Graph[VD, Int]

// Simple edge list as tuples
val edgeTuples: RDD[(VertexId, VertexId)] = sc.parallelize(Array(
(1L, 2L), (2L, 3L), (3L, 1L), (1L, 3L)
))
val simpleGraph = Graph.fromEdgeTuples(edgeTuples, defaultValue = 1)

val numVertices = graph.vertices.count()
val numEdges = graph.edges.count()
println(s"Graph has $numVertices vertices and $numEdges edges")
// Access vertices and edges
graph.vertices.collect().foreach { case (id, attr) =>
println(s"Vertex $id: $attr")
}
graph.edges.collect().foreach { edge =>
println(s"Edge ${edge.srcId} -> ${edge.dstId}: ${edge.attr}")
}
// Access triplets (vertex-edge-vertex)
graph.triplets.collect().foreach { triplet =>
println(s"${triplet.srcAttr} -${triplet.attr}-> ${triplet.dstAttr}")
}

mapVertices: Transform vertex attributes
def mapVertices[VD2: ClassTag](map: (VertexId, VD) => VD2): Graph[VD2, ED]

// Add vertex degrees to attributes
// Add vertex degrees to attributes.
// NOTE(review): the original called graph.degrees.lookup(id) inside the
// mapVertices closure. That is a nested RDD operation (an action on one
// RDD from inside a transformation of another), which Spark does not
// support — it fails at runtime. Join the degree RDD in instead:
val graphWithDegrees = graph.outerJoinVertices(graph.degrees) { (id, attr, degOpt) =>
  // Vertices with no edges have no entry in `degrees`; default to 0.
  (attr, degOpt.getOrElse(0))
}
// Convert to upper case
val upperCaseGraph = graph.mapVertices { (id, name) =>
name.toUpperCase
}

mapEdges: Transform edge attributes
def mapEdges[ED2: ClassTag](map: Edge[ED] => ED2): Graph[VD, ED2]
def mapEdges[ED2: ClassTag](map: (PartitionID, Iterator[Edge[ED]]) => Iterator[ED2]): Graph[VD, ED2]

// Add edge weights
// Add edge weights: look the relationship label up in a table, with
// unknown labels falling back to 0.1 (same mapping as a match expression).
val weightedGraph = graph.mapEdges { edge =>
  val weightByLabel = Map("friend" -> 1.0, "colleague" -> 0.5)
  weightByLabel.getOrElse(edge.attr, 0.1)
}
// Transform edge attributes using triplet info
val enhancedGraph = graph.mapTriplets { triplet =>
  s"${triplet.srcAttr}-${triplet.attr}-${triplet.dstAttr}"
}

reverse: Reverse edge directions
def reverse: Graph[VD, ED]

subgraph: Extract subgraph based on predicates
def subgraph(
epred: EdgeTriplet[VD, ED] => Boolean = (x => true),
vpred: (VertexId, VD) => Boolean = ((v, d) => true)
): Graph[VD, ED]

// Reverse all edges
val reversedGraph = graph.reverse
// Extract subgraph with only "friend" edges
val friendGraph = graph.subgraph(epred = _.attr == "friend")
// Extract subgraph with specific vertices
val aliceBobGraph = graph.subgraph(
vpred = (id, attr) => attr == "Alice" || attr == "Bob"
)
// Extract subgraph based on both vertices and edges
val specificSubgraph = graph.subgraph(
epred = triplet => triplet.srcAttr != "Charlie",
vpred = (id, attr) => attr.length > 3
)

groupEdges: Merge parallel edges
def groupEdges(merge: (ED, ED) => ED): Graph[VD, ED]

// Create graph with parallel edges
val parallelEdges: RDD[Edge[Int]] = sc.parallelize(Array(
Edge(1L, 2L, 1),
Edge(1L, 2L, 2), // Parallel edge
Edge(2L, 3L, 3)
))
val parallelGraph = Graph.fromEdges(parallelEdges, "user")
// Merge parallel edges by summing weights
val mergedGraph = parallelGraph.groupEdges(_ + _)

Specialized RDD for vertices with efficient joins:
// Specialized RDD of (VertexId, VD) pairs that supports efficient
// joins and per-vertex aggregation keyed by vertex ID.
abstract class VertexRDD[VD] extends RDD[(VertexId, VD)] {
def mapValues[VD2: ClassTag](f: VD => VD2): VertexRDD[VD2] // transform values only
def mapValues[VD2: ClassTag](f: (VertexId, VD) => VD2): VertexRDD[VD2] // value transform with access to the ID
def leftJoin[VD2: ClassTag, VD3: ClassTag](other: RDD[(VertexId, VD2)])(f: (VertexId, VD, Option[VD2]) => VD3): VertexRDD[VD3] // keep every vertex; missing matches arrive as None
def innerJoin[U: ClassTag, VD2: ClassTag](other: RDD[(VertexId, U)])(f: (VertexId, VD, U) => VD2): VertexRDD[VD2] // keep only vertices present in both datasets
def aggregateUsingIndex[VD2: ClassTag](messages: RDD[(VertexId, VD2)], reduceFunc: (VD2, VD2) => VD2): VertexRDD[VD2] // reduce messages per vertex, reusing this RDD's index
}

val degrees = graph.degrees
// Transform vertex values
val transformedVertices = graph.vertices.mapValues(_.toUpperCase)
// Join with additional data
val ages: RDD[(VertexId, Int)] = sc.parallelize(Array(
(1L, 25), (2L, 30), (3L, 35), (4L, 28)
))
val verticesWithAges = graph.vertices.leftJoin(ages) { (id, name, ageOpt) =>
(name, ageOpt.getOrElse(0))
}

Specialized RDD for edges:
// Specialized RDD of Edge[ED] values.
abstract class EdgeRDD[ED] extends RDD[Edge[ED]] {
def mapValues[ED2: ClassTag](f: Edge[ED] => ED2): EdgeRDD[ED2] // transform edge attributes, endpoints unchanged
def reverse: EdgeRDD[ED] // swap source and destination of every edge
def innerJoin[ED2: ClassTag, ED3: ClassTag](other: RDD[(VertexId, ED2)])(f: (VertexId, ED, ED2) => ED3): EdgeRDD[ED3] // combine edge attributes with data keyed by vertex ID
}

GraphOps provides additional graph algorithms and utilities through implicit conversion:
// Extra operations available on any Graph via implicit conversion
// (e.g. graph.degrees resolves through this class).
class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) {
// Degree operations
def degrees: VertexRDD[Int] // total degree (in + out) per vertex
def inDegrees: VertexRDD[Int] // number of incoming edges per vertex
def outDegrees: VertexRDD[Int] // number of outgoing edges per vertex
// Neighborhood operations
def collectNeighbors(edgeDirection: EdgeDirection): VertexRDD[Array[(VertexId, VD)]] // neighbor IDs and attributes in the given direction
def collectNeighborIds(edgeDirection: EdgeDirection): VertexRDD[Array[VertexId]] // neighbor IDs only (cheaper)
// Messaging operations
def aggregateMessages[A: ClassTag](sendMsg: EdgeContext[VD, ED, A] => Unit, mergeMsg: (A, A) => A, tripletFields: TripletFields = TripletFields.All): VertexRDD[A] // map-reduce over edges; results keyed by receiving vertex
// Pregel API
def pregel[A: ClassTag](initialMsg: A, maxIterations: Int = Int.MaxValue, activeDirection: EdgeDirection = EdgeDirection.Either)(vprog: (VertexId, VD, A) => VD, sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)], mergeMsg: (A, A) => A): Graph[VD, ED] // bulk-synchronous iterative computation; stops when no messages remain
}

import org.apache.spark.graphx.GraphOps
// Compute vertex degrees
val degrees = graph.degrees
val inDegrees = graph.inDegrees
val outDegrees = graph.outDegrees
// Find vertices with highest in-degree
val maxInDegree = inDegrees.reduce { (a, b) =>
if (a._2 > b._2) a else b
}
println(s"Vertex ${maxInDegree._1} has highest in-degree: ${maxInDegree._2}")
// Join degrees with vertex attributes
val verticesWithDegrees = graph.vertices.leftJoin(degrees) { (id, attr, deg) =>
(attr, deg.getOrElse(0))
}

import org.apache.spark.graphx.EdgeDirection
// Collect neighbors
val neighbors = graph.collectNeighbors(EdgeDirection.Out)
neighbors.collect().foreach { case (id, neighborArray) =>
println(s"Vertex $id has neighbors: ${neighborArray.mkString(", ")}")
}
// Collect neighbor IDs only
val neighborIds = graph.collectNeighborIds(EdgeDirection.In)

def aggregateMessages[A: ClassTag](
sendMsg: EdgeContext[VD, ED, A] => Unit,
mergeMsg: (A, A) => A,
tripletFields: TripletFields = TripletFields.All
): VertexRDD[A]

import org.apache.spark.graphx.{EdgeContext, TripletFields}
// Count neighbors
val neighborCount = graph.aggregateMessages[Int](
// Send message to each vertex
sendMsg = { edgeContext =>
edgeContext.sendToSrc(1)
edgeContext.sendToDst(1)
},
// Merge messages
mergeMsg = _ + _
)
// Compute average neighbor age (assuming vertices have age attribute)
val ageGraph: Graph[Int, String] = Graph.fromEdges(edges, defaultValue = 25)
// NOTE(review): the original divided by ageGraph.degrees.lookup(id).head
// inside mapValues — a nested RDD operation, which Spark forbids (it
// fails at runtime). Instead aggregate a (sum, count) pair per vertex
// and divide locally; this also avoids .head on a possibly-empty lookup.
val avgNeighborAge = ageGraph.aggregateMessages[(Double, Int)](
  sendMsg = { ctx =>
    // Each endpoint receives the other endpoint's age plus a count of 1.
    ctx.sendToSrc((ctx.dstAttr.toDouble, 1))
    ctx.sendToDst((ctx.srcAttr.toDouble, 1))
  },
  mergeMsg = (a, b) => (a._1 + b._1, a._2 + b._2),
  tripletFields = TripletFields.All
).mapValues { (id, sumCount) =>
  // count is >= 1 for every vertex that received a message
  sumCount._1 / sumCount._2
}

The Pregel API enables iterative graph computations:
def pregel[A: ClassTag](
initialMsg: A,
maxIterations: Int = Int.MaxValue,
activeDirection: EdgeDirection = EdgeDirection.Either
)(
vprog: (VertexId, VD, A) => VD,
sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)],
mergeMsg: (A, A) => A
): Graph[VD, ED]

import org.apache.spark.graphx.EdgeDirection
// Single Source Shortest Path using Pregel
// Single Source Shortest Path using Pregel.
// Vertices carry their best-known distance from sourceId; messages carry
// candidate distances discovered by relaxing edges.
def shortestPath(graph: Graph[Double, Double], sourceId: VertexId): Graph[Double, Double] = {
  // Seed distances: 0.0 at the source, "unreached" (infinity) elsewhere.
  val seeded = graph.mapVertices { (id, _) =>
    if (id == sourceId) 0.0 else Double.PositiveInfinity
  }
  // Vertex program: keep the smaller of the stored and incoming distance.
  val updateVertex = (id: VertexId, dist: Double, incoming: Double) => math.min(dist, incoming)
  // Relaxation step: message the destination only when the path through
  // this edge improves on its current distance.
  val relaxEdge = (t: EdgeTriplet[Double, Double]) => {
    val candidate = t.srcAttr + t.attr
    if (candidate < t.dstAttr) Iterator((t.dstId, candidate)) else Iterator.empty
  }
  // Combine competing messages by taking the minimum distance.
  val pickMin = (a: Double, b: Double) => math.min(a, b)
  // Messages flow along edge direction; iterate until no vertex improves.
  seeded.pregel(Double.PositiveInfinity, Int.MaxValue, EdgeDirection.Out)(updateVertex, relaxEdge, pickMin)
}
// Usage
val sourceVertex = 1L
val distances = shortestPath(weightedGraph, sourceVertex)

GraphX includes implementations of common graph algorithms:
object PageRank {
def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED], numIter: Int, resetProb: Double = 0.15): Graph[Double, Double]
def runUntilConvergence[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED], tol: Double, resetProb: Double = 0.15): Graph[Double, Double]
}

import org.apache.spark.graphx.lib.PageRank
// Run PageRank for fixed iterations
val pageRanks = PageRank.run(graph, numIter = 10)
// Run PageRank until convergence
val convergedRanks = PageRank.runUntilConvergence(graph, tol = 0.0001)
// Get vertices with highest PageRank
val topVertices = pageRanks.vertices.top(3)(Ordering.by(_._2))
topVertices.foreach { case (id, rank) =>
println(s"Vertex $id: PageRank = $rank")
}

object ConnectedComponents {
def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]): Graph[VertexId, ED]
}

import org.apache.spark.graphx.lib.ConnectedComponents
val ccGraph = ConnectedComponents.run(graph)
// Group vertices by connected component
val componentSizes = ccGraph.vertices
.map(_._2) // Extract component ID
.countByValue() // Count vertices per component
componentSizes.foreach { case (componentId, size) =>
println(s"Component $componentId has $size vertices")
}

object TriangleCount {
def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]): Graph[Int, ED]
}

import org.apache.spark.graphx.lib.TriangleCount
// Count triangles (graph must be canonical - lower vertex ID as source)
val canonicalGraph = graph.convertToCanonicalEdges()
val triangleCounts = TriangleCount.run(canonicalGraph)
// Find vertices involved in most triangles
val maxTriangles = triangleCounts.vertices.reduce { (a, b) =>
if (a._2 > b._2) a else b
}
println(s"Vertex ${maxTriangles._1} is in ${maxTriangles._2} triangles")

object StronglyConnectedComponents {
def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED], numIter: Int): Graph[VertexId, ED]
}

import org.apache.spark.graphx.lib.StronglyConnectedComponents
val sccGraph = StronglyConnectedComponents.run(graph, numIter = 10)
// Find strongly connected components
val sccSizes = sccGraph.vertices
.map(_._2)
.countByValue()
println(s"Found ${sccSizes.size} strongly connected components")

Control how graphs are partitioned across the cluster:
object PartitionStrategy {
val EdgePartition1D: PartitionStrategy
val EdgePartition2D: PartitionStrategy
val RandomVertexCut: PartitionStrategy
val CanonicalRandomVertexCut: PartitionStrategy
}

import org.apache.spark.graphx.{PartitionStrategy, Graph}
// Create graph with specific partitioning strategy
val partitionedGraph = Graph(vertices, edges)
.partitionBy(PartitionStrategy.EdgePartition2D, 4)
// Repartition existing graph
val repartitionedGraph = graph.partitionBy(PartitionStrategy.RandomVertexCut, 8)

// Cache graph for iterative algorithms
val cachedGraph = graph.cache()
// Unpersist when done
cachedGraph.unpersist()

// For large graphs, construct more efficiently
// NOTE(review): Graph.fromEdges takes a parameter named `defaultValue`
// (see its signature earlier in this guide); the original call used
// `defaultVertexAttr` — the Graph.apply name — which would not compile.
val efficientGraph = Graph.fromEdges(edges, defaultValue = "default")
  .partitionBy(PartitionStrategy.EdgePartition2D, numPartitions = 4)
  .cache()
// Force evaluation so the cached graph is materialized up front
efficientGraph.vertices.count()
efficientGraph.edges.count()

This comprehensive guide covers the complete GraphX API for building scalable graph processing applications in Apache Spark.
Install with Tessl CLI
npx tessl i tessl/maven-apache-spark