summaryrefslogtreecommitdiff
path: root/networkx/drawing
diff options
context:
space:
mode:
authorRoss Barnowski <rossbar@berkeley.edu>2022-06-10 00:03:01 +0300
committerGitHub <noreply@github.com>2022-06-09 17:03:01 -0400
commit23fc7568a5179c54dccd0c6f68760877f88aa4b6 (patch)
tree6ea7dc56bdbec0182e60e2ce5b994f6af00c2216 /networkx/drawing
parent4e92d77ae1973e892712daac85392efbef9b3591 (diff)
downloadnetworkx-23fc7568a5179c54dccd0c6f68760877f88aa4b6.tar.gz
Recover order of layers in multipartite_layout when layers are sortable (#5705)
* Add test for sorted layers. * Implement layer sorting when possible. Co-authored-by: Dan Schult <dschult@colgate.edu> * Add test case for non-sortable layers. Co-authored-by: Dan Schult <dschult@colgate.edu>
Diffstat (limited to 'networkx/drawing')
-rw-r--r--networkx/drawing/layout.py9
-rw-r--r--networkx/drawing/tests/test_layout.py21
2 files changed, 28 insertions, 2 deletions
diff --git a/networkx/drawing/layout.py b/networkx/drawing/layout.py
index 23381a24..b6d2afe5 100644
--- a/networkx/drawing/layout.py
+++ b/networkx/drawing/layout.py
@@ -1083,11 +1083,16 @@ def multipartite_layout(G, subset_key="subset", align="vertical", scale=1, cente
raise ValueError(msg)
layers[layer] = [v] + layers.get(layer, [])
+ # Sort by layer, if possible
+ try:
+ layers = sorted(layers.items())
+ except TypeError:
+ layers = list(layers.items())
+
pos = None
nodes = []
-
width = len(layers)
- for i, layer in enumerate(layers.values()):
+ for i, (_, layer) in enumerate(layers):
height = len(layer)
xs = np.repeat(i, height)
ys = np.arange(0, height, dtype=float)
diff --git a/networkx/drawing/tests/test_layout.py b/networkx/drawing/tests/test_layout.py
index f2fd90e7..f24d0038 100644
--- a/networkx/drawing/tests/test_layout.py
+++ b/networkx/drawing/tests/test_layout.py
@@ -422,3 +422,24 @@ def test_multipartite_layout_nonnumeric_partition_labels():
G.add_edges_from([(0, 2), (0, 3), (1, 2)])
pos = nx.multipartite_layout(G)
assert len(pos) == len(G)
+
+
+def test_multipartite_layout_layer_order():
+ """Return the layers in sorted order if the layers of the multipartite
+ graph are sortable. See gh-5691"""
+ G = nx.Graph()
+ for node, layer in zip(("a", "b", "c", "d", "e"), (2, 3, 1, 2, 4)):
+ G.add_node(node, subset=layer)
+
+ # Horizontal alignment, therefore y-coord determines layers
+ pos = nx.multipartite_layout(G, align="horizontal")
+
+ # Nodes "a" and "d" are in the same layer
+ assert pos["a"][-1] == pos["d"][-1]
+ # positions should be sorted according to layer
+ assert pos["c"][-1] < pos["a"][-1] < pos["b"][-1] < pos["e"][-1]
+
+ # Make sure that multipartite_layout still works when layers are not sortable
+ G.nodes["a"]["subset"] = "layer_0" # Can't sort mixed strs/ints
+ pos_nosort = nx.multipartite_layout(G) # smoke test: this should not raise
+ assert pos_nosort.keys() == pos.keys()