Python package for creating and manipulating graphs and networks
—
Graph layout algorithms and Matplotlib-based drawing functions for visualizing networks. NetworkX provides comprehensive visualization capabilities for exploring graph structure and properties.
Basic drawing functions that use Matplotlib for graph visualization.
def draw(G, pos=None, ax=None, **kwds):
    """
    Draw graph with matplotlib.

    Parameters:
    - G: NetworkX graph
    - pos: Dictionary mapping each node to its position; if None, a spring
      layout is computed
    - ax: Matplotlib axis object to draw into; if None, the current axis of
      the current figure is used
    - **kwds: Keyword arguments for customization (node_color, edge_color,
      with_labels, etc.), passed on to draw_networkx
    Returns:
    None (draws onto ``ax``, or onto the current matplotlib axis when
    ``ax`` is None)
    """
def draw_networkx(G, pos=None, arrows=None, with_labels=True, **kwds):
    """
    Draw the nodes and edges of G, plus node labels when ``with_labels`` is True.

    Edge labels are not drawn; add them with draw_networkx_edge_labels.

    Parameters:
    - G: NetworkX graph
    - pos: Dictionary of node positions (a spring layout is computed if None)
    - arrows: Whether to draw arrowheads; None uses the default for the
      graph type (arrowheads for directed graphs)
    - with_labels: Draw node labels on the nodes (default True)
    - **kwds: Styling options forwarded to draw_networkx_nodes,
      draw_networkx_edges and draw_networkx_labels
    Returns:
    None
    """
def draw_networkx_nodes(
    G, pos, nodelist=None, node_size=300, node_color='#1f78b4', node_shape='o',
    alpha=None, cmap=None, vmin=None, vmax=None, ax=None, linewidths=None,
    edgecolors=None, label=None, margins=None,
):
    """
    Draw only the nodes of G (no edges, no labels).

    Parameters:
    - G: NetworkX graph
    - pos: Dictionary of node positions
    - nodelist: Nodes to draw (all nodes of G if None)
    - node_size, node_color, node_shape, alpha: Marker styling
    - cmap, vmin, vmax: Colormap and its limits for numeric node colors
    - ax: Matplotlib axis object to draw into
    - linewidths, edgecolors: Width and color of the marker outlines
    - label: Legend label
    - margins: Axis margins
    """
def draw_networkx_edges(
    G, pos, edgelist=None, width=1.0, edge_color='k', style='solid', alpha=None,
    arrowstyle=None, arrowsize=10, edge_cmap=None, edge_vmin=None, edge_vmax=None,
    ax=None, arrows=None, label=None, node_size=300, nodelist=None, node_shape='o',
    connectionstyle='arc3', min_source_margin=0, min_target_margin=0,
):
    """
    Draw only the edges of G (no nodes, no labels).

    Parameters:
    - G: NetworkX graph
    - pos: Dictionary of node positions
    - edgelist: Edges to draw (all edges of G if None)
    - width, edge_color, style, alpha: Line styling
    - arrows, arrowstyle, arrowsize, connectionstyle: Arrow drawing options
    - edge_cmap, edge_vmin, edge_vmax: Colormap and limits for numeric edge colors
    - node_size, nodelist, node_shape: Node geometry, used to place edge ends
    - min_source_margin, min_target_margin: Minimum gap between edge ends and nodes
    - ax: Matplotlib axis object to draw into
    - label: Legend label
    """
def draw_networkx_labels(
    G, pos, labels=None, font_size=12, font_color='k', font_family='sans-serif',
    font_weight='normal', alpha=None, bbox=None, horizontalalignment='center',
    verticalalignment='center', ax=None, clip_on=True,
):
    """
    Write a text label on each node of G.

    Parameters:
    - G: NetworkX graph
    - pos: Dictionary of node positions
    - labels: Dictionary mapping nodes to label text
    - font_size, font_color, font_family, font_weight, alpha: Text styling
    - bbox: Matplotlib bounding-box properties drawn behind each label
    - horizontalalignment, verticalalignment: Text alignment about the node
    - ax: Matplotlib axis object to draw into
    - clip_on: Clip labels to the axis area
    """
def draw_networkx_edge_labels(
    G, pos, edge_labels=None, label_pos=0.5, font_size=10, font_color='k',
    font_family='sans-serif', font_weight='normal', alpha=None, bbox=None,
    horizontalalignment='center', verticalalignment='center', ax=None,
    rotate=True, clip_on=True,
):
    """
    Write a text label on edges of G.

    Parameters:
    - G: NetworkX graph
    - pos: Dictionary of node positions
    - edge_labels: Dictionary keyed by edge tuples (u, v) holding label text
    - label_pos: Position of the label along the edge (0=head, 0.5=center, 1=tail)
    - font_size, font_color, font_family, font_weight, alpha: Text styling
    - bbox: Matplotlib bounding-box properties drawn behind each label
    - horizontalalignment, verticalalignment: Text alignment
    - ax: Matplotlib axis object to draw into
    - rotate: Rotate labels to run parallel to their edges
    - clip_on: Clip labels to the axis area
    """


# Algorithms for computing node positions for graph visualization.
def spring_layout(G, k=None, pos=None, fixed=None, iterations=50, threshold=1e-4,
                  weight='weight', scale=1, center=None, dim=2, seed=None):
    """
    Position nodes using Fruchterman-Reingold force-directed algorithm.

    Edges act as springs that pull their endpoints together, while all nodes
    repel each other. Positions are updated until they settle or the
    iteration limit is reached.

    Parameters:
    - G: NetworkX graph
    - k: Optimal distance between nodes (defaults to 1/sqrt(n) for n nodes)
    - pos: Initial positions dictionary (random initial positions if None)
    - fixed: Nodes to keep fixed at their initial positions in ``pos``
    - iterations: Maximum number of iterations
    - threshold: Convergence threshold on the relative change in node
      positions; iteration stops early once the change falls below it
    - weight: Edge data key used as edge weight (spring strength)
    - scale: Scale factor for positions; not applied when ``fixed`` is given
    - center: Coordinate pair for center of layout
    - dim: Dimension of the layout (number of coordinates per node)
    - seed: Random seed for reproducible layouts
    Returns:
    Dictionary of node positions
    """
def kamada_kawai_layout(G, dist=None, pos=None, weight='weight', scale=1, center=None, dim=2):
    """
    Position nodes by minimizing the Kamada-Kawai path-length cost function.

    Parameters:
    - G: NetworkX graph
    - dist: Optional precomputed distances between node pairs
    - pos: Initial positions dictionary
    - weight: Edge data key used as edge length
    - scale: Scale factor for positions
    - center: Coordinate pair for center of layout
    - dim: Dimension of the layout
    Returns:
    Dictionary of node positions
    """
def circular_layout(G, scale=1, center=None, dim=2):
    """
    Place the nodes of G evenly spaced around a circle.

    Parameters:
    - G: NetworkX graph
    - scale: Scale factor for positions
    - center: Coordinate pair for center of layout
    - dim: Dimension of the layout
    Returns:
    Dictionary of node positions
    """
def random_layout(G, center=None, dim=2, seed=None):
    """
    Assign every node a position drawn uniformly at random.

    Parameters:
    - G: NetworkX graph
    - center: Coordinate pair used to position the layout
    - dim: Dimension of the layout
    - seed: Random seed for reproducible positions
    Returns:
    Dictionary of node positions
    """
def shell_layout(G, nlist=None, rotate=None, scale=1, center=None, dim=2):
    """
    Arrange nodes on concentric circles ("shells").

    Parameters:
    - G: NetworkX graph
    - nlist: List of node lists, one list per shell
    - rotate: Angle offset applied between successive shells
    - scale: Scale factor for positions
    - center: Coordinate pair for center of layout
    - dim: Dimension of the layout
    Returns:
    Dictionary of node positions
    """
def spectral_layout(G, weight='weight', scale=1, center=None, dim=2):
    """
    Derive node coordinates from eigenvectors of the graph Laplacian.

    Parameters:
    - G: NetworkX graph
    - weight: Edge data key used as edge weight
    - scale: Scale factor for positions
    - center: Coordinate pair for center of layout
    - dim: Dimension of the layout
    Returns:
    Dictionary of node positions
    """
def planar_layout(G, scale=1, center=None, dim=2):
    """
    Position nodes so that no two edges cross.

    G must be planar; a non-planar graph raises a NetworkXException.
    Only two-dimensional layouts (dim=2) are supported.

    Parameters:
    - G: NetworkX graph (planar)
    - scale: Scale factor for positions
    - center: Coordinate pair for center of layout
    - dim: Dimension of the layout (must be 2)
    Returns:
    Dictionary of node positions
    """
def spiral_layout(G, scale=1, center=None, dim=2, resolution=0.35, equidistant=False):
    """
    Place nodes one after another along a spiral.

    Parameters:
    - G: NetworkX graph
    - scale: Scale factor for positions
    - center: Coordinate pair for center of layout
    - dim: Dimension of the layout
    - resolution: Compactness of the spiral (lower values are more compressed)
    - equidistant: Space consecutive nodes at equal distances along the spiral
    Returns:
    Dictionary of node positions
    """
def multipartite_layout(G, subset_key='subset', align='vertical', scale=1, center=None):
    """
    Place nodes in straight-line layers, one layer per subset.

    Parameters:
    - G: NetworkX graph
    - subset_key: Node attribute naming the layer each node belongs to
    - align: 'vertical' or 'horizontal' orientation of the layers
    - scale: Scale factor for positions
    - center: Coordinate pair for center of layout
    Returns:
    Dictionary of node positions
    """
def bipartite_layout(G, nodes, align='horizontal', scale=1, center=None, aspect_ratio=4/3):
    """
    Place the two node sets of a bipartite graph on two parallel lines.

    Parameters:
    - G: NetworkX graph
    - nodes: Nodes of one of the two sets; they share one line
    - align: 'horizontal' or 'vertical' orientation
    - scale: Scale factor for positions
    - center: Coordinate pair for center of layout
    - aspect_ratio: Ratio of the layout's width to its height
    Returns:
    Dictionary of node positions
    """


# Drawing functions for specific layout algorithms.
def draw_circular(G, **kwds):
    """
    Draw G with node positions taken from circular_layout.

    Convenience wrapper equivalent to ``draw(G, pos=circular_layout(G), **kwds)``.
    """
def draw_kamada_kawai(G, **kwds):
    """
    Draw G with node positions taken from kamada_kawai_layout.

    Convenience wrapper equivalent to ``draw(G, pos=kamada_kawai_layout(G), **kwds)``.
    """
def draw_random(G, **kwds):
    """
    Draw G with node positions taken from random_layout.

    Convenience wrapper equivalent to ``draw(G, pos=random_layout(G), **kwds)``.
    """
def draw_spectral(G, **kwds):
    """
    Draw G with node positions taken from spectral_layout.

    Convenience wrapper equivalent to ``draw(G, pos=spectral_layout(G), **kwds)``.
    """
def draw_spring(G, **kwds):
    """
    Draw G with node positions taken from spring_layout.

    Convenience wrapper equivalent to ``draw(G, pos=spring_layout(G), **kwds)``.
    """
def draw_planar(G, **kwds):
    """
    Draw G with node positions taken from planar_layout (G must be planar).

    Convenience wrapper equivalent to ``draw(G, pos=planar_layout(G), **kwds)``.
    """
def draw_shell(G, **kwds):
    """
    Draw G with node positions taken from shell_layout (concentric circles).
    """


# Helper functions for working with graph layouts.
def rescale_layout(pos, scale=1):
    """
    Rescale an array of layout coordinates.

    Each coordinate axis is first centered on its mean. All values are then
    scaled together so that the largest absolute coordinate equals ``scale``,
    which preserves the aspect ratio. The input array may be modified in
    place. For ``{node: position}`` dictionaries use rescale_layout_dict.

    Parameters:
    - pos: NumPy array of positions, one row per node and one column per
      dimension
    - scale: Extent of the result in every direction (coordinates end up
      within [-scale, scale])
    Returns:
    NumPy array of rescaled positions, rows in the same order as ``pos``
    """
def rescale_layout_dict(pos, scale=1):
    """
    Rescale a ``{node: position}`` layout dictionary.

    Parameters:
    - pos: Dictionary mapping nodes to positions
    - scale: Scale factor
    Returns:
    Dictionary of rescaled positions keyed by node
    """


# Interface with Graphviz layout engines (requires pygraphviz).
def graphviz_layout(G, prog='neato', root=None, args=''):
    """
    Create layout using Graphviz layout programs (requires pygraphviz).

    Parameters:
    - G: NetworkX graph
    - prog: Graphviz layout program ('dot', 'neato', 'fdp', 'sfdp', 'twopi', 'circo')
    - root: Root node, used as the center of radial layouts such as 'twopi'
    - args: Additional command-line arguments passed to the layout program
    Returns:
    Dictionary of node positions
    """
def pygraphviz_layout(G, prog='neato', root=None, args=''):
    """
    Compute node positions with a Graphviz program via PyGraphviz.

    Same parameters and result as graphviz_layout.
    """
def to_agraph(N):
    """
    Return a PyGraphviz AGraph equivalent to the NetworkX graph N.
    """
def from_agraph(A, create_using=None):
    """
    Build a NetworkX graph from the PyGraphviz AGraph A.

    Parameters:
    - A: PyGraphviz AGraph
    - create_using: Graph type to create (inferred from A if None)
    """
def write_dot(G, path):
    """
    Save G to ``path`` in Graphviz DOT format, using PyGraphviz.
    """
def read_dot(path):
    """
    Load the DOT file at ``path`` into a NetworkX graph, using PyGraphviz.
    """
def view_pygraphviz(G, edgelabel=None, prog='dot', args='', suffix='', path=None):
    """
    Render G with PyGraphviz and open the image in the system viewer.

    Parameters:
    - G: NetworkX graph
    - edgelabel: Edge attribute (or callable) supplying edge label text
    - prog: Graphviz layout program
    - args: Additional arguments to the layout program
    - suffix: Text appended to the temporary image filename
    - path: File to save the image to (a temporary file if None)
    """


# Alternative interface with Graphviz using pydot library.
def pydot_layout(G, prog='neato', root=None, **kwds):
    """
    Compute node positions with a Graphviz program via pydot.

    Parameters:
    - G: NetworkX graph
    - prog: Graphviz layout program
    - root: Root node for the layout program
    - **kwds: Additional options
    """
def to_pydot(N):
    """
    Return a pydot graph equivalent to the NetworkX graph N.
    """
def from_pydot(P):
    """
    Build a NetworkX graph from the pydot graph P.
    """


# Example: drawing classic graphs in a 2x2 grid
import networkx as nx
import matplotlib.pyplot as plt

# Four classic small graphs.
G1 = nx.complete_graph(6)
G2 = nx.cycle_graph(8)
G3 = nx.star_graph(7)
G4 = nx.wheel_graph(8)

# One panel per graph: (graph, drawing function, grid cell, node color, title).
# nx.draw computes its own layout; nx.draw_circular places nodes on a circle.
panels = [
    (G1, nx.draw, (0, 0), 'lightblue', "Complete Graph K6"),
    (G2, nx.draw_circular, (0, 1), 'lightgreen', "Cycle Graph C8"),
    (G3, nx.draw, (1, 0), 'lightcoral', "Star Graph S7"),
    (G4, nx.draw_circular, (1, 1), 'lightyellow', "Wheel Graph W8"),
]

fig, axes = plt.subplots(2, 2, figsize=(12, 10))
for graph, draw_fn, cell, fill, title in panels:
    panel_ax = axes[cell]
    draw_fn(graph, ax=panel_ax, with_labels=True, node_color=fill,
            node_size=500, font_size=12, font_weight='bold')
    panel_ax.set_title(title)

plt.tight_layout()
plt.show()


# Example: styling nodes by attribute under several layouts
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np

# Zachary's karate club network, bundled with NetworkX.
G = nx.karate_club_graph()

# Tag each node: group 'A' for hubs (degree > 5), group 'B' otherwise.
for node in G.nodes():
    G.nodes[node]['group'] = 'A' if G.degree(node) > 5 else 'B'

# Compute different layouts (spring is seeded so the figure is reproducible).
layouts = {
    'Spring': nx.spring_layout(G, seed=42),
    'Kamada-Kawai': nx.kamada_kawai_layout(G),
    'Circular': nx.circular_layout(G),
    'Spectral': nx.spectral_layout(G)
}

# Node styling depends only on the graph, not on the layout, so compute it
# once rather than once per subplot.
node_colors = ['red' if G.nodes[node]['group'] == 'A' else 'blue'
               for node in G.nodes()]
node_sizes = [G.degree(node) * 50 for node in G.nodes()]

fig, axes = plt.subplots(2, 2, figsize=(15, 12))
axes = axes.flatten()
for ax, (name, pos) in zip(axes, layouts.items()):
    nx.draw(G, pos=pos, ax=ax,
            node_color=node_colors,
            node_size=node_sizes,
            with_labels=True,
            font_size=8,
            font_weight='bold',
            edge_color='gray',
            alpha=0.7)
    ax.set_title(f"{name} Layout")

plt.tight_layout()
plt.show()


# Example: visualizing edge weights
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np

# A 4-node weighted graph given as (u, v, weight) triples.
G = nx.Graph()
edges = [(1, 2, 0.8), (2, 3, 0.6), (3, 4, 0.9), (4, 1, 0.7),
         (1, 3, 0.3), (2, 4, 0.4)]
G.add_weighted_edges_from(edges)

# One shared, seeded layout so all three panels place nodes identically.
pos = nx.spring_layout(G, seed=42)

# Weights in G.edges() order, which is the order nx.draw draws the edges in.
edge_weights = [w for _, _, w in G.edges(data='weight')]

# Drawing options shared by every panel.
common = dict(with_labels=True, node_size=500, font_size=12, font_weight='bold')

plt.figure(figsize=(12, 5))

# Panel 1: plain drawing.
plt.subplot(131)
nx.draw(G, pos, node_color='lightblue', **common)
plt.title("Basic Graph")

# Panel 2: edge width and color follow the weight.
plt.subplot(132)
nx.draw(G, pos,
        node_color='lightgreen',
        edge_color=edge_weights,
        width=[w * 5 for w in edge_weights],
        edge_cmap=plt.cm.Blues,
        **common)
plt.title("Edge Width by Weight")

# Panel 3: weight values written on the edges.
plt.subplot(133)
nx.draw(G, pos, node_color='lightcoral', **common)
edge_labels = {(u, v): f"{w:.1f}" for u, v, w in G.edges(data='weight')}
nx.draw_networkx_edge_labels(G, pos, edge_labels, font_size=10)
plt.title("With Edge Labels")

plt.tight_layout()
plt.show()


# Example: bipartite, hierarchical and multipartite layouts
import networkx as nx
import matplotlib.pyplot as plt

# Bipartite graph: the 'bipartite' node attribute (0 or 1) records each node's side.
B = nx.Graph()
B.add_nodes_from([1, 2, 3, 4], bipartite=0)  # Top nodes
B.add_nodes_from(['a', 'b', 'c'], bipartite=1)  # Bottom nodes
B.add_edges_from([(1, 'a'), (2, 'a'), (2, 'b'), (3, 'b'), (3, 'c'), (4, 'c')])

# Balanced binary tree (branching factor 2, height 3).
T = nx.balanced_tree(2, 3)

# Multipartite graph: the 'subset' node attribute gives each node's layer.
M = nx.Graph()
M.add_nodes_from([1, 2, 3], subset=0)
M.add_nodes_from([4, 5, 6, 7], subset=1)
M.add_nodes_from([8, 9], subset=2)
M.add_edges_from([(1, 4), (2, 5), (3, 6), (4, 8), (5, 8), (6, 9), (7, 9)])

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Bipartite layout: top nodes on one line, the remaining nodes on a parallel line.
top_nodes = {n for n, d in B.nodes(data=True) if d['bipartite'] == 0}
pos_bipartite = nx.bipartite_layout(B, top_nodes)
nx.draw(B, pos=pos_bipartite, ax=axes[0], with_labels=True,
        node_color=['lightblue' if n in top_nodes else 'lightcoral' for n in B.nodes()],
        node_size=500, font_size=12, font_weight='bold')
axes[0].set_title("Bipartite Layout")

# Tree layout (hierarchical). Graphviz's 'dot' program draws a top-down tree
# but needs the optional pygraphviz package. The nx.nx_agraph module is
# always present even without pygraphviz, so testing for it proves nothing.
# graphviz_layout raises ImportError when pygraphviz is missing; fall back
# to a seeded spring layout in that case.
try:
    pos_tree = nx.nx_agraph.graphviz_layout(T, prog='dot')
except ImportError:
    pos_tree = nx.spring_layout(T, seed=42)
nx.draw(T, pos=pos_tree, ax=axes[1], with_labels=True,
        node_color='lightgreen', node_size=500, font_size=12, font_weight='bold')
axes[1].set_title("Hierarchical Tree")

# Multipartite layout: one straight line of nodes per 'subset' value.
pos_multi = nx.multipartite_layout(M, subset_key='subset')
colors = ['red' if d['subset'] == 0 else 'blue' if d['subset'] == 1 else 'green'
          for n, d in M.nodes(data=True)]
nx.draw(M, pos=pos_multi, ax=axes[2], with_labels=True,
        node_color=colors, node_size=500, font_size=12, font_weight='bold')
axes[2].set_title("Multipartite Layout")

plt.tight_layout()
plt.show()


# Example: interactive temporal graph with a slider
import networkx as nx
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider
import numpy as np
# Create graph that changes over time
def create_temporal_graph(t):
    """
    Build a random 10-node graph whose edge density varies with ``t``.

    Each pair (i, j) with i < j becomes an edge with probability
    0.3 * (1 + sin(t + 0.5*i + 0.3*j)), i.e. between 0 and 0.6. Random
    draws come from NumPy's global random state, so seed it beforehand
    for a reproducible graph.
    """
    num_nodes = 10
    graph = nx.Graph()
    graph.add_nodes_from(range(num_nodes))
    # Pairs are visited in the same (i, j) order as a nested loop, so the
    # sequence of random draws is unchanged.
    graph.add_edges_from(
        (i, j)
        for i in range(num_nodes)
        for j in range(i + 1, num_nodes)
        if np.random.random() < 0.3 * (1 + np.sin(t + i*0.5 + j*0.3))
    )
    return graph
# Set up the figure and axis, leaving room at the bottom for the slider.
fig, ax = plt.subplots(figsize=(10, 8))
plt.subplots_adjust(bottom=0.15)

# Initial graph. Seed with 42, as update_graph does before every redraw, so
# the graph shown at start-up is the same one shown when the slider returns
# to t = 0.
t_init = 0
np.random.seed(42)
G = create_temporal_graph(t_init)
pos = nx.circular_layout(G)  # Keep positions fixed; only the edges change

# Initial drawing
nx.draw(G, pos=pos, ax=ax, with_labels=True,
        node_color='lightblue', node_size=500,
        font_size=12, font_weight='bold')
ax.set_title(f"Temporal Graph (t = {t_init:.2f})")

# Slider below the plot; rectangle is [left, bottom, width, height] in figure coordinates.
ax_slider = plt.axes([0.2, 0.05, 0.6, 0.03])
slider = Slider(ax_slider, 'Time', 0, 2*np.pi, valinit=t_init)
def update_graph(val):
    """
    Slider callback: redraw the temporal graph for the new time value.

    Parameters:
    - val: New slider value, used as the time parameter t
    """
    ax.clear()
    t = val
    # Re-seed on every redraw so a given t always yields the same graph.
    np.random.seed(42)
    G = create_temporal_graph(t)
    # Node positions (module-level pos) stay fixed so only the edges change.
    nx.draw(G, pos=pos, ax=ax, with_labels=True,
            node_color='lightblue', node_size=500,
            font_size=12, font_weight='bold')
    ax.set_title(f"Temporal Graph (t = {t:.2f})")
    # draw_idle schedules one redraw when the GUI is idle instead of forcing
    # a synchronous full redraw on every slider event.
    fig.canvas.draw_idle()
# Redraw whenever the slider moves, then hand control to the GUI event loop.
slider_cid = slider.on_changed(update_graph)
plt.show()


# Example: comparing random graph models side by side
import networkx as nx
import matplotlib.pyplot as plt
def _layout_accepts_seed(layout_func):
    """Return True if ``layout_func`` declares a ``seed`` parameter."""
    import inspect

    try:
        params = inspect.signature(layout_func).parameters
    except (TypeError, ValueError):
        # Signature cannot be introspected: keep the original behavior of
        # always passing seed.
        return True
    return 'seed' in params


def compare_graphs(graphs, titles, layout_func=nx.spring_layout):
    """
    Compare multiple graphs side by side, coloring nodes by degree centrality.

    Parameters:
    - graphs: Sequence of NetworkX graphs, one subplot each
    - titles: Titles for the graphs, in the same order (extra titles are ignored)
    - layout_func: Layout function called as layout_func(G); seed=42 is
      passed as well when the function accepts a ``seed`` parameter
    Returns:
    The matplotlib Figure holding the subplots
    Raises:
    - ValueError: if no graphs are given or there are fewer titles than graphs
    """
    graphs = list(graphs)
    titles = list(titles)
    n_graphs = len(graphs)
    if n_graphs == 0:
        raise ValueError("compare_graphs() needs at least one graph")
    if len(titles) < n_graphs:
        # zip() would silently truncate and leave empty panels.
        raise ValueError(f"got {n_graphs} graphs but only {len(titles)} titles")

    # Seeded layouts (spring, random, ...) become reproducible with seed=42.
    # Layouts without a seed parameter (circular, kamada_kawai, spectral, ...)
    # would raise TypeError if seed were passed unconditionally.
    layout_kwargs = {'seed': 42} if _layout_accepts_seed(layout_func) else {}

    # squeeze=False always returns a 2-D array of axes, so a single graph
    # needs no special case.
    fig, axes = plt.subplots(1, n_graphs, figsize=(5*n_graphs, 5), squeeze=False)
    for ax, G, title in zip(axes[0], graphs, titles):
        pos = layout_func(G, **layout_kwargs)
        # Compute node colors based on degree centrality.
        centrality = nx.degree_centrality(G)
        node_colors = [centrality[node] for node in G.nodes()]
        nx.draw(G, pos=pos, ax=ax,
                node_color=node_colors,
                cmap=plt.cm.viridis,
                node_size=500,
                with_labels=True,
                font_size=10,
                font_weight='bold',
                edge_color='gray',
                alpha=0.8)
        ax.set_title(f"{title}\n{G.number_of_nodes()} nodes, {G.number_of_edges()} edges")
    plt.tight_layout()
    return fig
# Graph families to compare: three random models on 20 nodes (seeded for
# reproducibility) plus a complete graph on 8 nodes, keyed by panel title.
examples = {
    'Erdős-Rényi': nx.erdos_renyi_graph(20, 0.15, seed=42),
    'Barabási-Albert': nx.barabasi_albert_graph(20, 2, seed=42),
    'Watts-Strogatz': nx.watts_strogatz_graph(20, 4, 0.3, seed=42),
    'Complete': nx.complete_graph(8),
}
graphs = list(examples.values())
titles = list(examples)
fig = compare_graphs(graphs, titles)
plt.show()


# Install with Tessl CLI
npx tessl i tessl/pypi-networkx