summaryrefslogtreecommitdiff
path: root/examples/algorithms/plot_snap.py
blob: 85ea71de5a86b0fe93fa4c0f418b64727f2fc52b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
"""
==================
SNAP Graph Summary
==================
An example of summarizing a graph based on node attributes and edge attributes
using the Summarization by Grouping Nodes on Attributes and Pairwise
edges (SNAP) algorithm (not to be confused with the Stanford Network
Analysis Project).  The algorithm groups nodes by their unique
combinations of node attribute values and edge types with other groups
of nodes to produce a summary graph.  The summary graph can then be used to
infer how nodes with different attribute values relate to other nodes in the
graph.
"""
import networkx as nx
import matplotlib.pyplot as plt


# Node table: maps each node label to the attribute dict stored on that node.
# Labels A-D are red, E-H are blue, and I-L are yellow.
nodes = {
    "A": {"color": "Red"},
    "B": {"color": "Red"},
    "C": {"color": "Red"},
    "D": {"color": "Red"},
    "E": {"color": "Blue"},
    "F": {"color": "Blue"},
    "G": {"color": "Blue"},
    "H": {"color": "Blue"},
    "I": {"color": "Yellow"},
    "J": {"color": "Yellow"},
    "K": {"color": "Yellow"},
    "L": {"color": "Yellow"},
}

# Edge table: (endpoint, endpoint, strength label) triples.
edges = [
    ("A", "B", "Strong"),
    ("A", "C", "Weak"),
    ("A", "E", "Strong"),
    ("A", "I", "Weak"),
    ("B", "D", "Weak"),
    ("B", "J", "Weak"),
    ("B", "F", "Strong"),
    ("C", "G", "Weak"),
    ("D", "H", "Weak"),
    ("I", "J", "Strong"),
    ("J", "K", "Strong"),
    ("I", "L", "Strong"),
]

# Build the undirected graph that will be summarized.
original_graph = nx.Graph()
# Each (label, attribute-dict) pair becomes a node carrying its "color".
original_graph.add_nodes_from(nodes.items())
# The strength label of every edge is stored under the "type" edge attribute.
original_graph.add_edges_from(
    (source, target, {"type": strength}) for source, target, strength in edges
)


# Figure-level heading shared by both panels.
plt.suptitle("SNAP Summarization")

# Drawing keywords applied to both the original and the summary plot.
base_options = {"with_labels": True, "edgecolors": "black", "node_size": 500}

# Left panel: the original graph.
ax1 = plt.subplot(1, 2, 1)
original_counts = (original_graph.number_of_nodes(), original_graph.number_of_edges())
plt.title("Original (%s nodes, %s edges)" % original_counts)

# A fixed seed keeps the layout reproducible between runs.
pos = nx.spring_layout(original_graph, seed=7482934)

# One colour per node, read from the node's "color" attribute.
node_colors = [color for _, color in original_graph.nodes(data="color")]

# Line width used to draw each edge type.
edge_type_visual_weight_lookup = {"Weak": 1.0, "Strong": 3.0}
edge_weights = [
    edge_type_visual_weight_lookup[edge_type]
    for _, _, edge_type in original_graph.edges(data="type")
]

nx.draw_networkx(
    original_graph,
    pos=pos,
    node_color=node_colors,
    width=edge_weights,
    **base_options
)

# Attribute names handed to the SNAP aggregation: nodes are grouped on
# "color" and edges are distinguished by "type".
node_attributes = ("color",)
edge_attributes = ("type",)
summary_graph = nx.snap_aggregation(
    original_graph,
    node_attributes=node_attributes,
    edge_attributes=edge_attributes,
    prefix="S-",
)

# Right panel: the summary graph.
plt.subplot(1, 2, 2)
summary_counts = (summary_graph.number_of_nodes(), summary_graph.number_of_edges())
plt.title("SNAP Aggregation (%s nodes, %s edges)" % summary_counts)

# A fixed seed keeps the layout reproducible between runs.
summary_pos = nx.spring_layout(summary_graph, seed=8375428)

# One colour per summary node, read from its "color" attribute.
node_colors = [summary_graph.nodes[node]["color"] for node in summary_graph]

# Each summary edge carries a "types" list of dicts keyed by "type"; its drawn
# width is the sum of the visual weights of all the types it contains.
edge_weights = [
    sum((edge_type_visual_weight_lookup[entry["type"]] for entry in types), 0.0)
    for _, _, types in summary_graph.edges(data="types")
]

nx.draw_networkx(
    summary_graph,
    pos=summary_pos,
    node_color=node_colors,
    width=edge_weights,
    **base_options
)

plt.tight_layout()
plt.show()