diff options
author | ldelille <63287336+ldelille@users.noreply.github.com> | 2021-02-02 23:48:25 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-02-02 14:48:25 -0800 |
commit | 29e3f3b94d36f0fcf95fa8054318686b19301439 (patch) | |
tree | ad229ea83abedb5a0c08c0c1c59f9762d53ac717 | |
parent | 91197ef9c48fd977eb290b28854d21957ea6fd03 (diff) | |
download | networkx-29e3f3b94d36f0fcf95fa8054318686b19301439.tar.gz |
Improve intersection function (#4588)
* enhanced intersection function for graphs with different node sets + associated tests
* updated documentation according to enhancement of intersection function
* enhanced documentation of intersection function
* Modified tests for intersection of multiple graphs
* corrected typo in documentation of intersection function
* corrected test_intersection_multigraph_attributes (made sure the two graphs had different node sets during test)
* added intersection behavior if graph types are different
-rw-r--r-- | networkx/algorithms/operators/all.py | 4 | ||||
-rw-r--r-- | networkx/algorithms/operators/binary.py | 22 | ||||
-rw-r--r-- | networkx/algorithms/operators/tests/test_all.py | 55 | ||||
-rw-r--r-- | networkx/algorithms/operators/tests/test_binary.py | 56 |
4 files changed, 118 insertions, 19 deletions
diff --git a/networkx/algorithms/operators/all.py b/networkx/algorithms/operators/all.py index a08f634d..c3706687 100644 --- a/networkx/algorithms/operators/all.py +++ b/networkx/algorithms/operators/all.py @@ -130,11 +130,9 @@ def compose_all(graphs): def intersection_all(graphs): - """Returns a new graph that contains only the edges that exist in + """Returns a new graph that contains only the nodes and the edges that exist in all graphs. - All supplied graphs must have the same node set. - Parameters ---------- graphs : list diff --git a/networkx/algorithms/operators/binary.py b/networkx/algorithms/operators/binary.py index e3a52114..545bdc8b 100644 --- a/networkx/algorithms/operators/binary.py +++ b/networkx/algorithms/operators/binary.py @@ -137,15 +137,18 @@ def disjoint_union(G, H): def intersection(G, H): - """Returns a new graph that contains only the edges that exist in + """Returns a new graph that contains only the nodes and the edges that exist in both G and H. - The node sets of H and G must be the same. - Parameters ---------- G,H : graph - A NetworkX graph. G and H must have the same node sets. + A NetworkX graph. G and H can have different node sets but must be both graphs or both multigraphs. + + Raises + ------ + NetworkXError + If one is a MultiGraph and the other one is a graph. Returns ------- @@ -162,14 +165,17 @@ def intersection(G, H): >>> H = nx.path_graph(5) >>> R = G.copy() >>> R.remove_nodes_from(n for n in G if n not in H) + >>> R.remove_edges_from(e for e in G.edges if e not in H.edges) """ - # create new graph - R = nx.create_empty_copy(G) - if not G.is_multigraph() == H.is_multigraph(): raise nx.NetworkXError("G and H must both be graphs or multigraphs.") + + # create new graph if set(G) != set(H): - raise nx.NetworkXError("Node sets of graphs are not equal") + R = G.__class__() + R.add_nodes_from(set(G.nodes).intersection(set(H.nodes))) + else: + R = nx.create_empty_copy(G) if G.number_of_edges() <= H.number_of_edges(): if G.is_multigraph(): diff --git a/networkx/algorithms/operators/tests/test_all.py b/networkx/algorithms/operators/tests/test_all.py index 8c96b081..9502bbf9 100644 --- a/networkx/algorithms/operators/tests/test_all.py +++ b/networkx/algorithms/operators/tests/test_all.py @@ -48,6 +48,26 @@ def test_intersection_all(): assert sorted(I.edges()) == [(2, 3)] +def test_intersection_all_different_node_sets(): + G = nx.Graph() + H = nx.Graph() + R = nx.Graph() + G.add_nodes_from([1, 2, 3, 4, 6, 7]) + G.add_edge(1, 2) + G.add_edge(2, 3) + G.add_edge(6, 7) + H.add_nodes_from([1, 2, 3, 4]) + H.add_edge(2, 3) + H.add_edge(3, 4) + R.add_nodes_from([1, 2, 3, 4, 8, 9]) + R.add_edge(2, 3) + R.add_edge(4, 1) + R.add_edge(8, 9) + I = nx.intersection_all([G, H, R]) + assert set(I.nodes()) == {1, 2, 3, 4} + assert sorted(I.edges()) == [(2, 3)] + + def test_intersection_all_attributes(): g = nx.Graph() g.add_node(0, x=4) @@ -65,8 +85,23 @@ def test_intersection_all_attributes(): assert set(gh.nodes()) == set(h.nodes()) assert sorted(gh.edges()) == sorted(g.edges()) - h.remove_node(0) - pytest.raises(nx.NetworkXError, nx.intersection, g, h) + +def test_intersection_all_attributes_different_node_sets(): + g = nx.Graph() + g.add_node(0, x=4) + g.add_node(1, x=5) + g.add_edge(0, 1, size=5) + g.graph["name"] = "g" + + h = g.copy() + g.add_node(2) + h.graph["name"] = "h" + h.graph["attr"] = "attr" + h.nodes[0]["x"] = 7 + + gh = nx.intersection_all([g, h]) + assert set(gh.nodes()) == set(h.nodes()) + assert sorted(gh.edges()) == sorted(g.edges()) def test_intersection_all_multigraph_attributes(): @@ -84,6 +119,22 @@ def test_intersection_all_multigraph_attributes(): assert sorted(gh.edges(keys=True)) == [(0, 1, 0)] +def test_intersection_all_multigraph_attributes_different_node_sets(): + g = nx.MultiGraph() + g.add_edge(0, 1, key=0) + g.add_edge(0, 1, key=1) + g.add_edge(0, 1, key=2) + g.add_edge(1, 2, key=1) + g.add_edge(1, 2, key=2) + h = nx.MultiGraph() + h.add_edge(0, 1, key=0) + h.add_edge(0, 1, key=3) + gh = nx.intersection_all([g, h]) + assert set(gh.nodes()) == set(h.nodes()) + assert sorted(gh.edges()) == [(0, 1)] + assert sorted(gh.edges(keys=True)) == [(0, 1, 0)] + + def test_union_all_and_compose_all(): K3 = nx.complete_graph(3) P3 = nx.path_graph(3) diff --git a/networkx/algorithms/operators/tests/test_binary.py b/networkx/algorithms/operators/tests/test_binary.py index bf885130..c6b4adf9 100644 --- a/networkx/algorithms/operators/tests/test_binary.py +++ b/networkx/algorithms/operators/tests/test_binary.py @@ -39,6 +39,21 @@ def test_intersection(): assert sorted(I.edges()) == [(2, 3)] +def test_intersection_node_sets_different(): + G = nx.Graph() + H = nx.Graph() + G.add_nodes_from([1, 2, 3, 4, 7]) + G.add_edge(1, 2) + G.add_edge(2, 3) + H.add_nodes_from([1, 2, 3, 4, 5, 6]) + H.add_edge(2, 3) + H.add_edge(3, 4) + H.add_edge(5, 6) + I = nx.intersection(G, H) + assert set(I.nodes()) == {1, 2, 3, 4} + assert sorted(I.edges()) == [(2, 3)] + + def test_intersection_attributes(): g = nx.Graph() g.add_node(0, x=4) @@ -50,14 +65,30 @@ def test_intersection_attributes(): h.graph["name"] = "h" h.graph["attr"] = "attr" h.nodes[0]["x"] = 7 - gh = nx.intersection(g, h) + assert set(gh.nodes()) == set(g.nodes()) assert set(gh.nodes()) == set(h.nodes()) assert sorted(gh.edges()) == sorted(g.edges()) - h.remove_node(0) - pytest.raises(nx.NetworkXError, nx.intersection, g, h) + +def test_intersection_attributes_node_sets_different(): + g = nx.Graph() + g.add_node(0, x=4) + g.add_node(1, x=5) + g.add_node(2, x=3) + g.add_edge(0, 1, size=5) + g.graph["name"] = "g" + + h = g.copy() + h.graph["name"] = "h" + h.graph["attr"] = "attr" + h.nodes[0]["x"] = 7 + h.remove_node(2) + + gh = nx.intersection(g, h) + assert set(gh.nodes()) == set(h.nodes()) + assert sorted(gh.edges()) == sorted(g.edges()) def test_intersection_multigraph_attributes(): @@ -75,6 +106,22 @@ def test_intersection_multigraph_attributes(): assert sorted(gh.edges(keys=True)) == [(0, 1, 0)] +def test_intersection_multigraph_attributes_node_set_different(): + g = nx.MultiGraph() + g.add_edge(0, 1, key=0) + g.add_edge(0, 1, key=1) + g.add_edge(0, 1, key=2) + g.add_edge(0, 2, key=2) + g.add_edge(0, 2, key=1) + h = nx.MultiGraph() + h.add_edge(0, 1, key=0) + h.add_edge(0, 1, key=3) + gh = nx.intersection(g, h) + assert set(gh.nodes()) == set(h.nodes()) + assert sorted(gh.edges()) == [(0, 1)] + assert sorted(gh.edges(keys=True)) == [(0, 1, 0)] + + def test_difference(): G = nx.Graph() H = nx.Graph() @@ -132,9 +179,6 @@ def test_difference_attributes(): assert set(gh.nodes()) == set(h.nodes()) assert sorted(gh.edges()) == [] - h.remove_node(0) - pytest.raises(nx.NetworkXError, nx.intersection, g, h) - def test_difference_multigraph_attributes(): g = nx.MultiGraph() |