diff options
author | Ross Barnowski <rossbar@berkeley.edu> | 2022-06-10 00:03:01 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-06-09 17:03:01 -0400 |
commit | 23fc7568a5179c54dccd0c6f68760877f88aa4b6 (patch) | |
tree | 6ea7dc56bdbec0182e60e2ce5b994f6af00c2216 /networkx/drawing | |
parent | 4e92d77ae1973e892712daac85392efbef9b3591 (diff) | |
download | networkx-23fc7568a5179c54dccd0c6f68760877f88aa4b6.tar.gz |
Recover order of layers in multipartite_layout when layers are sortable (#5705)
* Add test for sorted layers.
* Implement layer sorting when possible.
Co-authored-by: Dan Schult <dschult@colgate.edu>
* Add test case for non-sortable layers.
Co-authored-by: Dan Schult <dschult@colgate.edu>
Diffstat (limited to 'networkx/drawing')
-rw-r--r-- | networkx/drawing/layout.py | 9 | ||||
-rw-r--r-- | networkx/drawing/tests/test_layout.py | 21 |
2 files changed, 28 insertions, 2 deletions
diff --git a/networkx/drawing/layout.py b/networkx/drawing/layout.py index 23381a24..b6d2afe5 100644 --- a/networkx/drawing/layout.py +++ b/networkx/drawing/layout.py @@ -1083,11 +1083,16 @@ def multipartite_layout(G, subset_key="subset", align="vertical", scale=1, cente raise ValueError(msg) layers[layer] = [v] + layers.get(layer, []) + # Sort by layer, if possible + try: + layers = sorted(layers.items()) + except TypeError: + layers = list(layers.items()) + pos = None nodes = [] - width = len(layers) - for i, layer in enumerate(layers.values()): + for i, (_, layer) in enumerate(layers): height = len(layer) xs = np.repeat(i, height) ys = np.arange(0, height, dtype=float) diff --git a/networkx/drawing/tests/test_layout.py b/networkx/drawing/tests/test_layout.py index f2fd90e7..f24d0038 100644 --- a/networkx/drawing/tests/test_layout.py +++ b/networkx/drawing/tests/test_layout.py @@ -422,3 +422,24 @@ def test_multipartite_layout_nonnumeric_partition_labels(): G.add_edges_from([(0, 2), (0, 3), (1, 2)]) pos = nx.multipartite_layout(G) assert len(pos) == len(G) + + +def test_multipartite_layout_layer_order(): + """Return the layers in sorted order if the layers of the multipartite + graph are sortable. See gh-5691""" + G = nx.Graph() + for node, layer in zip(("a", "b", "c", "d", "e"), (2, 3, 1, 2, 4)): + G.add_node(node, subset=layer) + + # Horizontal alignment, therefore y-coord determines layers + pos = nx.multipartite_layout(G, align="horizontal") + + # Nodes "a" and "d" are in the same layer + assert pos["a"][-1] == pos["d"][-1] + # positions should be sorted according to layer + assert pos["c"][-1] < pos["a"][-1] < pos["b"][-1] < pos["e"][-1] + + # Make sure that multipartite_layout still works when layers are not sortable + G.nodes["a"]["subset"] = "layer_0" # Can't sort mixed strs/ints + pos_nosort = nx.multipartite_layout(G) # smoke test: this should not raise + assert pos_nosort.keys() == pos.keys() |