diff options
-rw-r--r-- | networkx/convert.py | 30 | ||||
-rw-r--r-- | networkx/tests/test_convert.py | 22 |
2 files changed, 37 insertions, 15 deletions
diff --git a/networkx/convert.py b/networkx/convert.py index fee8cf4a..eb7b524a 100644 --- a/networkx/convert.py +++ b/networkx/convert.py @@ -256,7 +256,6 @@ def from_dict_of_lists(d,create_using=None): """ G=_prep_create_using(create_using) G.add_nodes_from(d) - if G.is_multigraph() and not G.is_directed(): # a dict_of_lists can't show multiedges. BUT for undirected graphs, # each edge shows up twice in the dict_of_lists. @@ -340,7 +339,6 @@ def from_dict_of_dicts(d,create_using=None,multigraph_input=False): """ G=_prep_create_using(create_using) G.add_nodes_from(d) - # is dict a MultiGraph or MultiDiGraph? if multigraph_input: # make a copy of the list of edge data (but not the edge data) @@ -362,33 +360,35 @@ def from_dict_of_dicts(d,create_using=None,multigraph_input=False): seen=set() # don't add both directions of undirected graph for u,nbrs in d.items(): for v,datadict in nbrs.items(): - if v not in seen: + if (u,v) not in seen: G.add_edges_from( (u,v,key,data) for key,data in datadict.items() - ) - seen.add(u) + ) + seen.add((v,u)) else: seen=set() # don't add both directions of undirected graph for u,nbrs in d.items(): for v,datadict in nbrs.items(): - if v not in seen: + if (u,v) not in seen: G.add_edges_from( (u,v,data) for key,data in datadict.items() ) - seen.add(u) + seen.add((v,u)) else: # not a multigraph to multigraph transfer - if G.is_directed(): - G.add_edges_from( ( (u,v,data) - for u,nbrs in d.items() - for v,data in nbrs.items()) ) - # need this if G is multigraph and slightly faster if not multigraph - else: + if G.is_multigraph() and not G.is_directed(): + # d can have both representations u-v, v-u in dict. Only add one. + # We don't need this check for digraphs since we add both directions, + # or for Graph() since it is done implicitly (parallel edges not allowed) seen=set() for u,nbrs in d.items(): for v,data in nbrs.items(): - if v not in seen: + if (u,v) not in seen: G.add_edge(u,v,attr_dict=data) - seen.add(u) + seen.add((v,u)) + else: + G.add_edges_from( ( (u,v,data) + for u,nbrs in d.items() + for v,data in nbrs.items()) ) return G def to_edgelist(G,nodelist=None): diff --git a/networkx/tests/test_convert.py b/networkx/tests/test_convert.py index 16d516da..968fce47 100644 --- a/networkx/tests/test_convert.py +++ b/networkx/tests/test_convert.py @@ -11,6 +11,11 @@ from networkx.algorithms.operators import * from networkx.generators.classic import barbell_graph,cycle_graph class TestConvert(): + def edgelists_equal(self,e1,e2): + return sorted(sorted(e) for e in e1)==sorted(sorted(e) for e in e2) + + + def test_simple_graphs(self): for dest, source in [(to_dict_of_dicts, from_dict_of_dicts), (to_dict_of_lists, from_dict_of_lists)]: @@ -164,7 +169,9 @@ class TestConvert(): GI=MultiGraph(XGM) assert_equal(sorted(XGM.nodes()), sorted(GI.nodes())) assert_equal(sorted(XGM.edges()), sorted(GI.edges())) + print G.edges() GM=MultiGraph(G) + print GM.edges() assert_equal(sorted(GM.nodes()), sorted(G.nodes())) assert_equal(sorted(GM.edges()), sorted(G.edges())) @@ -202,3 +209,18 @@ class TestConvert(): assert_equal(sorted(G.edges()), sorted(P.edges())) assert_equal(sorted(G.edges(data=True)), sorted(P.edges(data=True))) + def test_directed_to_undirected(self): + edges1 = [(0, 1), (1, 2), (2, 0)] + edges2 = [(0, 1), (1, 2), (0, 2)] + assert_true(self.edgelists_equal(nx.Graph(nx.DiGraph(edges1)).edges(),edges1)) + assert_true(self.edgelists_equal(nx.Graph(nx.DiGraph(edges2)).edges(),edges1)) + assert_true(self.edgelists_equal(nx.MultiGraph(nx.DiGraph(edges1)).edges(),edges1)) + assert_true(self.edgelists_equal(nx.MultiGraph(nx.DiGraph(edges2)).edges(),edges1)) + + assert_true(self.edgelists_equal(nx.MultiGraph(nx.MultiDiGraph(edges1)).edges(), + edges1)) + assert_true(self.edgelists_equal(nx.MultiGraph(nx.MultiDiGraph(edges2)).edges(), + edges1)) + + assert_true(self.edgelists_equal(nx.Graph(nx.MultiDiGraph(edges1)).edges(),edges1)) + assert_true(self.edgelists_equal(nx.Graph(nx.MultiDiGraph(edges2)).edges(),edges1)) |