# GraphX - Graph Processing

GraphX is Apache Spark's API for graphs and graph-parallel computation. It extends the Spark RDD abstraction with a Resilient Distributed Property Graph: a directed multigraph with properties attached to each vertex and edge.

## Core Graph Abstractions

### Graph Class

The central abstraction in GraphX:

```scala { .api }
abstract class Graph[VD: ClassTag, ED: ClassTag] extends Serializable {
  // Core properties
  def vertices: VertexRDD[VD]
  def edges: EdgeRDD[ED]
  def triplets: RDD[EdgeTriplet[VD, ED]]

  // Structural operations
  def reverse: Graph[VD, ED]
  def subgraph(epred: EdgeTriplet[VD, ED] => Boolean = x => true, vpred: (VertexId, VD) => Boolean = (a, b) => true): Graph[VD, ED]
  def groupEdges(merge: (ED, ED) => ED): Graph[VD, ED]

  // Transformation operations
  def mapVertices[VD2: ClassTag](map: (VertexId, VD) => VD2): Graph[VD2, ED]
  def mapEdges[ED2: ClassTag](map: Edge[ED] => ED2): Graph[VD, ED2]
  def mapTriplets[ED2: ClassTag](map: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2]
}
```

### VertexId Type

```scala { .api }
type VertexId = Long
```

### Edge Class

Represents a directed edge in the graph:

```scala { .api }
case class Edge[ED](srcId: VertexId, dstId: VertexId, attr: ED) extends Serializable
```

### EdgeTriplet Class

Joins vertex and edge data:

```scala { .api }
class EdgeTriplet[VD, ED] extends Edge[ED] {
  def srcId: VertexId   // Source vertex ID
  def dstId: VertexId   // Destination vertex ID
  def attr: ED          // Edge attribute
  def srcAttr: VD       // Source vertex attribute
  def dstAttr: VD       // Destination vertex attribute
}
```

## Creating Graphs

### Graph Construction

**Graph.apply**: Create a graph from vertices and edges
```scala { .api }
object Graph {
  def apply[VD: ClassTag, ED: ClassTag](vertices: RDD[(VertexId, VD)], edges: RDD[Edge[ED]], defaultVertexAttr: VD = null.asInstanceOf[VD]): Graph[VD, ED]
}
```

```scala
import org.apache.spark.graphx._
import org.apache.spark.rdd.RDD

// Create vertices RDD
val vertices: RDD[(VertexId, String)] = sc.parallelize(Array(
  (1L, "Alice"),
  (2L, "Bob"),
  (3L, "Charlie"),
  (4L, "David")
))

// Create edges RDD
val edges: RDD[Edge[String]] = sc.parallelize(Array(
  Edge(1L, 2L, "friend"),
  Edge(2L, 3L, "friend"),
  Edge(3L, 4L, "colleague"),
  Edge(1L, 4L, "colleague")
))

// Create graph
val graph = Graph(vertices, edges, defaultVertexAttr = "Unknown")
```

**Graph.fromEdges**: Create a graph from edges only
```scala { .api }
def fromEdges[VD: ClassTag, ED: ClassTag](edges: RDD[Edge[ED]], defaultValue: VD): Graph[VD, ED]
```

```scala
// Create graph from edges (vertices inferred)
val relationships: RDD[Edge[String]] = sc.parallelize(Array(
  Edge(1L, 2L, "follows"),
  Edge(2L, 3L, "follows"),
  Edge(3L, 1L, "follows")
))

val socialGraph = Graph.fromEdges(relationships, defaultValue = "user")
```

**Graph.fromEdgeTuples**: Create an unweighted graph from tuples
```scala { .api }
def fromEdgeTuples[VD: ClassTag](rawEdges: RDD[(VertexId, VertexId)], defaultValue: VD, uniqueEdges: Option[PartitionStrategy] = None): Graph[VD, Int]
```

```scala
// Simple edge list as tuples
val edgeTuples: RDD[(VertexId, VertexId)] = sc.parallelize(Array(
  (1L, 2L), (2L, 3L), (3L, 1L), (1L, 3L)
))

val simpleGraph = Graph.fromEdgeTuples(edgeTuples, defaultValue = 1)
```
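
The `uniqueEdges` parameter in the signature above is not exercised by the example. The following is a minimal sketch of what it does, written for `spark-shell` or a Scala script; the local `SparkContext` setup, the app name, and the sample data are assumptions for illustration. When a strategy is supplied, duplicate tuples are merged and the edge attribute holds the number of duplicates.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.graphx._
import org.apache.spark.rdd.RDD

// Reuses the running context in spark-shell; otherwise starts a local one (assumed setup).
val sc = SparkContext.getOrCreate(
  new SparkConf().setMaster("local[*]").setAppName("from-edge-tuples-sketch"))

// Illustrative data: the pair (1, 2) appears three times, (2, 3) once.
val rawEdges: RDD[(VertexId, VertexId)] = sc.parallelize(Seq(
  (1L, 2L), (1L, 2L), (1L, 2L), (2L, 3L)
))

// Without uniqueEdges, every tuple becomes its own edge with attribute 1.
val multigraph = Graph.fromEdgeTuples(rawEdges, defaultValue = 0)

// With uniqueEdges, the graph is partitioned with the given strategy and
// duplicate edges are merged; the attribute is the duplicate count.
val deduplicated = Graph.fromEdgeTuples(
  rawEdges,
  defaultValue = 0,
  uniqueEdges = Some(PartitionStrategy.RandomVertexCut))

val edgeCounts = deduplicated.edges
  .map(e => ((e.srcId, e.dstId), e.attr))
  .collect()
  .toMap
```

Here `multigraph` keeps all four edges, while `edgeCounts` maps each distinct pair to how often it occurred in the input.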

## Graph Properties and Operations

### Basic Properties

```scala
val numVertices = graph.vertices.count()
val numEdges = graph.edges.count()

println(s"Graph has $numVertices vertices and $numEdges edges")

// Access vertices and edges
graph.vertices.collect().foreach { case (id, attr) =>
  println(s"Vertex $id: $attr")
}

graph.edges.collect().foreach { edge =>
  println(s"Edge ${edge.srcId} -> ${edge.dstId}: ${edge.attr}")
}

// Access triplets (vertex-edge-vertex)
graph.triplets.collect().foreach { triplet =>
  println(s"${triplet.srcAttr} -${triplet.attr}-> ${triplet.dstAttr}")
}
```

### Graph Transformations

**mapVertices**: Transform vertex attributes
```scala { .api }
def mapVertices[VD2: ClassTag](map: (VertexId, VD) => VD2): Graph[VD2, ED]
```

```scala
// Add vertex degrees to attributes. An RDD such as graph.degrees cannot be
// queried from inside a mapVertices closure, so join the degrees in instead.
val graphWithDegrees = graph.outerJoinVertices(graph.degrees) { (id, attr, degOpt) =>
  (attr, degOpt.getOrElse(0))
}

// Convert to upper case
val upperCaseGraph = graph.mapVertices { (id, name) =>
  name.toUpperCase
}
```

**mapEdges**: Transform edge attributes
```scala { .api }
def mapEdges[ED2: ClassTag](map: Edge[ED] => ED2): Graph[VD, ED2]
def mapEdges[ED2: ClassTag](map: (PartitionID, Iterator[Edge[ED]]) => Iterator[ED2]): Graph[VD, ED2]
```

```scala
// Add edge weights
val weightedGraph = graph.mapEdges { edge =>
  edge.attr match {
    case "friend" => 1.0
    case "colleague" => 0.5
    case _ => 0.1
  }
}

// Transform edge attributes using triplet info
val enhancedGraph = graph.mapTriplets { triplet =>
  s"${triplet.srcAttr}-${triplet.attr}-${triplet.dstAttr}"
}
```

### Structural Operations

**reverse**: Reverse edge directions
```scala { .api }
def reverse: Graph[VD, ED]
```

**subgraph**: Extract a subgraph based on predicates
```scala { .api }
def subgraph(
  epred: EdgeTriplet[VD, ED] => Boolean = (x => true),
  vpred: (VertexId, VD) => Boolean = ((v, d) => true)
): Graph[VD, ED]
```

```scala
// Reverse all edges
val reversedGraph = graph.reverse

// Extract subgraph with only "friend" edges
val friendGraph = graph.subgraph(epred = _.attr == "friend")

// Extract subgraph with specific vertices
val aliceBobGraph = graph.subgraph(
  vpred = (id, attr) => attr == "Alice" || attr == "Bob"
)

// Extract subgraph based on both vertices and edges
val specificSubgraph = graph.subgraph(
  epred = triplet => triplet.srcAttr != "Charlie",
  vpred = (id, attr) => attr.length > 3
)
```

**groupEdges**: Merge parallel edges
```scala { .api }
def groupEdges(merge: (ED, ED) => ED): Graph[VD, ED]
```

```scala
// Create graph with parallel edges
val parallelEdges: RDD[Edge[Int]] = sc.parallelize(Array(
  Edge(1L, 2L, 1),
  Edge(1L, 2L, 2), // Parallel edge
  Edge(2L, 3L, 3)
))

val parallelGraph = Graph.fromEdges(parallelEdges, "user")

// Merge parallel edges by summing weights. groupEdges only merges edges
// that share a partition, so the graph must be partitioned first.
val mergedGraph = parallelGraph
  .partitionBy(PartitionStrategy.RandomVertexCut)
  .groupEdges(_ + _)
```

## VertexRDD and EdgeRDD

### VertexRDD

Specialized RDD for vertices with efficient joins:

```scala { .api }
abstract class VertexRDD[VD] extends RDD[(VertexId, VD)] {
  def mapValues[VD2: ClassTag](f: VD => VD2): VertexRDD[VD2]
  def mapValues[VD2: ClassTag](f: (VertexId, VD) => VD2): VertexRDD[VD2]
  def leftJoin[VD2: ClassTag, VD3: ClassTag](other: RDD[(VertexId, VD2)])(f: (VertexId, VD, Option[VD2]) => VD3): VertexRDD[VD3]
  def innerJoin[U: ClassTag, VD2: ClassTag](other: RDD[(VertexId, U)])(f: (VertexId, VD, U) => VD2): VertexRDD[VD2]
  def aggregateUsingIndex[VD2: ClassTag](messages: RDD[(VertexId, VD2)], reduceFunc: (VD2, VD2) => VD2): VertexRDD[VD2]
}
```

```scala
val degrees = graph.degrees

// Transform vertex values
val transformedVertices = graph.vertices.mapValues(_.toUpperCase)

// Join with additional data
val ages: RDD[(VertexId, Int)] = sc.parallelize(Array(
  (1L, 25), (2L, 30), (3L, 35), (4L, 28)
))

val verticesWithAges = graph.vertices.leftJoin(ages) { (id, name, ageOpt) =>
  (name, ageOpt.getOrElse(0))
}
```
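
The example above covers `mapValues` and `leftJoin` but not `aggregateUsingIndex` or `innerJoin`. The sketch below shows one way the two might be combined; the local `SparkContext` setup and the names `people`, `knows`, and `scores` are hypothetical sample data, not part of the API.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.graphx._
import org.apache.spark.rdd.RDD

// Reuses the running context in spark-shell; otherwise starts a local one (assumed setup).
val sc = SparkContext.getOrCreate(
  new SparkConf().setMaster("local[*]").setAppName("vertex-rdd-sketch"))

val people: RDD[(VertexId, String)] = sc.parallelize(Seq(
  (1L, "Alice"), (2L, "Bob"), (3L, "Charlie")
))
val knows: RDD[Edge[String]] = sc.parallelize(Seq(
  Edge(1L, 2L, "friend"), Edge(2L, 3L, "friend")
))
val g = Graph(people, knows)

// Messages keyed by vertex ID: vertex 1 receives two, vertex 3 one, vertex 2 none.
val scores: RDD[(VertexId, Int)] = sc.parallelize(Seq((1L, 10), (1L, 5), (3L, 7)))

// Reduce the messages per vertex, reusing the index of g.vertices.
val totals: VertexRDD[Int] = g.vertices.aggregateUsingIndex[Int](scores, _ + _)

// innerJoin keeps only the vertices that appear on both sides.
val named = g.vertices.innerJoin(totals) { (id, name, total) => (name, total) }

val result = named.collect().toMap
```

Vertex 2 received no messages, so it is absent from `totals` and therefore from `result`; a `leftJoin` would keep it with a `None` on the right-hand side.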

### EdgeRDD

Specialized RDD for edges:

```scala { .api }
abstract class EdgeRDD[ED] extends RDD[Edge[ED]] {
  def mapValues[ED2: ClassTag](f: Edge[ED] => ED2): EdgeRDD[ED2]
  def reverse: EdgeRDD[ED]
  def innerJoin[ED2: ClassTag, ED3: ClassTag](other: EdgeRDD[ED2])(f: (VertexId, VertexId, ED, ED2) => ED3): EdgeRDD[ED3]
}
```
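
No usage example accompanies the `EdgeRDD` signatures above, so here is a small sketch of `mapValues` and `reverse`. The local `SparkContext` setup and the two sample edges are assumptions for illustration.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.graphx._
import org.apache.spark.rdd.RDD

// Reuses the running context in spark-shell; otherwise starts a local one (assumed setup).
val sc = SparkContext.getOrCreate(
  new SparkConf().setMaster("local[*]").setAppName("edge-rdd-sketch"))

val sampleEdges: RDD[Edge[String]] = sc.parallelize(Seq(
  Edge(1L, 2L, "friend"),
  Edge(2L, 3L, "colleague")
))
val g = Graph.fromEdges(sampleEdges, defaultValue = "user")

// mapValues replaces each edge attribute and keeps the edge structure.
val labelLengths: EdgeRDD[Int] = g.edges.mapValues(e => e.attr.length)

// reverse swaps source and destination of every edge.
val reversed: EdgeRDD[String] = g.edges.reverse

val lengthByEdge = labelLengths
  .map(e => ((e.srcId, e.dstId), e.attr))
  .collect()
  .toMap
val reversedPairs = reversed.map(e => (e.srcId, e.dstId)).collect().toSet
```

Both operations return an `EdgeRDD` rather than a plain `RDD`, so the result can be fed into further edge-level operations.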

## GraphOps - Advanced Operations

GraphOps provides additional graph algorithms and utilities. Its methods are available on any `Graph` through an implicit conversion:

```scala { .api }
class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) {
  // Degree operations
  def degrees: VertexRDD[Int]
  def inDegrees: VertexRDD[Int]
  def outDegrees: VertexRDD[Int]

  // Neighborhood operations
  def collectNeighbors(edgeDirection: EdgeDirection): VertexRDD[Array[(VertexId, VD)]]
  def collectNeighborIds(edgeDirection: EdgeDirection): VertexRDD[Array[VertexId]]

  // Messaging operations (aggregateMessages is declared on Graph itself; listed here for convenience)
  def aggregateMessages[A: ClassTag](sendMsg: EdgeContext[VD, ED, A] => Unit, mergeMsg: (A, A) => A, tripletFields: TripletFields = TripletFields.All): VertexRDD[A]

  // Pregel API
  def pregel[A: ClassTag](initialMsg: A, maxIterations: Int = Int.MaxValue, activeDirection: EdgeDirection = EdgeDirection.Either)(vprog: (VertexId, VD, A) => VD, sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)], mergeMsg: (A, A) => A): Graph[VD, ED]
}
```

### Degree Operations

```scala
import org.apache.spark.graphx.GraphOps

// Compute vertex degrees
val degrees = graph.degrees
val inDegrees = graph.inDegrees
val outDegrees = graph.outDegrees

// Find the vertex with the highest in-degree
val maxInDegree = inDegrees.reduce { (a, b) =>
  if (a._2 > b._2) a else b
}
println(s"Vertex ${maxInDegree._1} has highest in-degree: ${maxInDegree._2}")

// Join degrees with vertex attributes
val verticesWithDegrees = graph.vertices.leftJoin(degrees) { (id, attr, deg) =>
  (attr, deg.getOrElse(0))
}
```

### Neighborhood Operations

```scala
import org.apache.spark.graphx.EdgeDirection

// Collect neighbors
val neighbors = graph.collectNeighbors(EdgeDirection.Out)
neighbors.collect().foreach { case (id, neighborArray) =>
  println(s"Vertex $id has neighbors: ${neighborArray.mkString(", ")}")
}

// Collect neighbor IDs only
val neighborIds = graph.collectNeighborIds(EdgeDirection.In)
```

### Message Passing with aggregateMessages

```scala { .api }
def aggregateMessages[A: ClassTag](
  sendMsg: EdgeContext[VD, ED, A] => Unit,
  mergeMsg: (A, A) => A,
  tripletFields: TripletFields = TripletFields.All
): VertexRDD[A]
```

```scala
import org.apache.spark.graphx.{EdgeContext, TripletFields}

// Count incident edges (the degree): each edge sends 1 to both endpoints
val neighborCount = graph.aggregateMessages[Int](
  sendMsg = { edgeContext =>
    edgeContext.sendToSrc(1)
    edgeContext.sendToDst(1)
  },
  // Merge messages by summing
  mergeMsg = _ + _
)

// Compute average neighbor age (here every vertex gets the default age 25)
val ageGraph: Graph[Int, String] = Graph.fromEdges(edges, defaultValue = 25)

// Send (count, ageSum) pairs so the average needs no separate degree lookup;
// RDDs such as ageGraph.degrees cannot be queried inside a mapValues closure
val avgNeighborAge = ageGraph.aggregateMessages[(Int, Double)](
  sendMsg = { ctx =>
    ctx.sendToSrc((1, ctx.dstAttr.toDouble))
    ctx.sendToDst((1, ctx.srcAttr.toDouble))
  },
  mergeMsg = (a, b) => (a._1 + b._1, a._2 + b._2),
  tripletFields = TripletFields.All
).mapValues { (id, value) =>
  value match { case (count, totalAge) => totalAge / count }
}
```

### Pregel API

The Pregel API enables iterative graph computations:

```scala { .api }
def pregel[A: ClassTag](
  initialMsg: A,
  maxIterations: Int = Int.MaxValue,
  activeDirection: EdgeDirection = EdgeDirection.Either
)(
  vprog: (VertexId, VD, A) => VD,
  sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)],
  mergeMsg: (A, A) => A
): Graph[VD, ED]
```

```scala
import org.apache.spark.graphx.EdgeDirection

// Single Source Shortest Path using Pregel
def shortestPath(graph: Graph[Double, Double], sourceId: VertexId): Graph[Double, Double] = {
  // Initialize distances (source = 0.0, others = Double.PositiveInfinity)
  val initialGraph = graph.mapVertices { (id, _) =>
    if (id == sourceId) 0.0 else Double.PositiveInfinity
  }

  initialGraph.pregel(
    initialMsg = Double.PositiveInfinity,
    maxIterations = Int.MaxValue,
    activeDirection = EdgeDirection.Out
  )(
    // Vertex program: update distance if received shorter path
    vprog = { (id, dist, newDist) => math.min(dist, newDist) },

    // Send message: notify the destination only if the path through the
    // source vertex is shorter than the destination's current distance
    sendMsg = { triplet =>
      if (triplet.srcAttr + triplet.attr < triplet.dstAttr) {
        Iterator((triplet.dstId, triplet.srcAttr + triplet.attr))
      } else {
        Iterator.empty
      }
    },

    // Merge messages: take minimum distance
    mergeMsg = (a, b) => math.min(a, b)
  )
}

// Usage: weightedGraph has String vertex attributes, so replace them with
// placeholder distances to match the Graph[Double, Double] parameter
val sourceVertex = 1L
val distances = shortestPath(weightedGraph.mapVertices((_, _) => 0.0), sourceVertex)
```

## Built-in Graph Algorithms

GraphX includes implementations of common graph algorithms:

### PageRank

```scala { .api }
object PageRank {
  def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED], numIter: Int, resetProb: Double = 0.15): Graph[Double, Double]
  def runUntilConvergence[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED], tol: Double, resetProb: Double = 0.15): Graph[Double, Double]
}
```

```scala
import org.apache.spark.graphx.lib.PageRank

// Run PageRank for fixed iterations
val pageRanks = PageRank.run(graph, numIter = 10)

// Run PageRank until convergence
val convergedRanks = PageRank.runUntilConvergence(graph, tol = 0.0001)

// Get vertices with highest PageRank
val topVertices = pageRanks.vertices.top(3)(Ordering.by(_._2))
topVertices.foreach { case (id, rank) =>
  println(s"Vertex $id: PageRank = $rank")
}
```
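
`PageRank.run` returns a `Graph[Double, Double]`, so the original vertex attributes are gone from the result. One way to get them back is to join the ranks with the original vertex RDD, sketched below. The local `SparkContext` setup and the `users` and `follows` data are hypothetical.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.graphx._
import org.apache.spark.graphx.lib.PageRank
import org.apache.spark.rdd.RDD

// Reuses the running context in spark-shell; otherwise starts a local one (assumed setup).
val sc = SparkContext.getOrCreate(
  new SparkConf().setMaster("local[*]").setAppName("pagerank-join-sketch"))

val users: RDD[(VertexId, String)] = sc.parallelize(Seq(
  (1L, "Alice"), (2L, "Bob"), (3L, "Charlie")
))
// Alice and Bob both follow Charlie; nobody follows Alice or Bob.
val follows: RDD[Edge[String]] = sc.parallelize(Seq(
  Edge(1L, 3L, "follows"),
  Edge(2L, 3L, "follows")
))
val g = Graph(users, follows)

// The result carries ranks as vertex attributes instead of names.
val ranks: VertexRDD[Double] = PageRank.run(g, numIter = 10).vertices

// Join on vertex ID to pair each name with its rank.
val ranksByName: RDD[(String, Double)] = users.join(ranks).map {
  case (_, (name, rank)) => (name, rank)
}

val top = ranksByName.sortBy(_._2, ascending = false).first()
println(s"Highest rank: ${top._1}")
```

Because Charlie is the only vertex with incoming edges, the highest-ranked name in this sample is Charlie.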

### Connected Components

```scala { .api }
object ConnectedComponents {
  def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]): Graph[VertexId, ED]
}
```

```scala
import org.apache.spark.graphx.lib.ConnectedComponents

val ccGraph = ConnectedComponents.run(graph)

// Group vertices by connected component
val componentSizes = ccGraph.vertices
  .map(_._2)       // Extract component ID
  .countByValue()  // Count vertices per component

componentSizes.foreach { case (componentId, size) =>
  println(s"Component $componentId has $size vertices")
}
```

### Triangle Counting

```scala { .api }
object TriangleCount {
  def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]): Graph[Int, ED]
}
```

```scala
import org.apache.spark.graphx.lib.TriangleCount

// Count triangles (graph must be canonical - lower vertex ID as source)
val canonicalGraph = graph.convertToCanonicalEdges()
val triangleCounts = TriangleCount.run(canonicalGraph)

// Find the vertex involved in the most triangles
val maxTriangles = triangleCounts.vertices.reduce { (a, b) =>
  if (a._2 > b._2) a else b
}
println(s"Vertex ${maxTriangles._1} is in ${maxTriangles._2} triangles")
```

### Strongly Connected Components

```scala { .api }
object StronglyConnectedComponents {
  def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED], numIter: Int): Graph[VertexId, ED]
}
```

```scala
import org.apache.spark.graphx.lib.StronglyConnectedComponents

val sccGraph = StronglyConnectedComponents.run(graph, numIter = 10)

// Find strongly connected components
val sccSizes = sccGraph.vertices
  .map(_._2)
  .countByValue()

println(s"Found ${sccSizes.size} strongly connected components")
```

## Graph Partitioning

Control how graphs are partitioned across the cluster:

```scala { .api }
object PartitionStrategy {
  case object EdgePartition1D extends PartitionStrategy
  case object EdgePartition2D extends PartitionStrategy
  case object RandomVertexCut extends PartitionStrategy
  case object CanonicalRandomVertexCut extends PartitionStrategy
}
```

```scala
import org.apache.spark.graphx.{PartitionStrategy, Graph}

// Create graph with specific partitioning strategy
val partitionedGraph = Graph(vertices, edges)
  .partitionBy(PartitionStrategy.EdgePartition2D, 4)

// Repartition existing graph
val repartitionedGraph = graph.partitionBy(PartitionStrategy.RandomVertexCut, 8)
```

## Performance Optimization

### Graph Caching

```scala
// Cache graph for iterative algorithms
val cachedGraph = graph.cache()

// Unpersist when done
cachedGraph.unpersist()
```
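
In a hand-written iterative loop, each step produces a new graph, so caching has to be managed per iteration. The sketch below shows one common pattern under stated assumptions: cache the new graph, materialize it, then release the previous one. The local `SparkContext` setup, the sample edges, and the "add one per step" update are placeholders for a real algorithm.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.graphx._
import org.apache.spark.rdd.RDD

// Reuses the running context in spark-shell; otherwise starts a local one (assumed setup).
val sc = SparkContext.getOrCreate(
  new SparkConf().setMaster("local[*]").setAppName("iterative-caching-sketch"))

val ring: RDD[(VertexId, VertexId)] = sc.parallelize(Seq(
  (1L, 2L), (2L, 3L), (3L, 1L)
))

// Every vertex starts at 0.
var current: Graph[Int, Int] = Graph.fromEdgeTuples(ring, defaultValue = 0).cache()

for (_ <- 1 to 3) {
  // Placeholder update step: increment every vertex attribute.
  val next = current.mapVertices((_, value) => value + 1).cache()

  // Materialize the new graph before releasing the one it was derived from.
  next.vertices.count()
  next.edges.count()

  current.unpersist(blocking = false)
  current = next
}

val finalValues = current.vertices.map(_._2).collect()
```

Materializing `next` before calling `unpersist` on `current` avoids recomputing the earlier iterations from scratch once the cached parent is gone.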

### Efficient Graph Construction

```scala
// For large graphs, construct more efficiently
val efficientGraph = Graph.fromEdges(edges, defaultValue = "default")
  .partitionBy(PartitionStrategy.EdgePartition2D, numPartitions = 4)
  .cache()

// Materialize the graph
efficientGraph.vertices.count()
efficientGraph.edges.count()
```

This guide covers the core GraphX API for building scalable graph processing applications in Apache Spark.