"""Unit tests for the :mod:`networkx.algorithms.tree.mst` module.""" import pytest import networkx as nx from networkx.utils import edges_equal, nodes_equal def test_unknown_algorithm(): with pytest.raises(ValueError): nx.minimum_spanning_tree(nx.Graph(), algorithm="random") class MinimumSpanningTreeTestBase: """Base class for test classes for minimum spanning tree algorithms. This class contains some common tests that will be inherited by subclasses. Each subclass must have a class attribute :data:`algorithm` that is a string representing the algorithm to run, as described under the ``algorithm`` keyword argument for the :func:`networkx.minimum_spanning_edges` function. Subclasses can then implement any algorithm-specific tests. """ def setup_method(self, method): """Creates an example graph and stores the expected minimum and maximum spanning tree edges. """ # This stores the class attribute `algorithm` in an instance attribute. self.algo = self.algorithm # This example graph comes from Wikipedia: # https://en.wikipedia.org/wiki/Kruskal's_algorithm edges = [ (0, 1, 7), (0, 3, 5), (1, 2, 8), (1, 3, 9), (1, 4, 7), (2, 4, 5), (3, 4, 15), (3, 5, 6), (4, 5, 8), (4, 6, 9), (5, 6, 11), ] self.G = nx.Graph() self.G.add_weighted_edges_from(edges) self.minimum_spanning_edgelist = [ (0, 1, {"weight": 7}), (0, 3, {"weight": 5}), (1, 4, {"weight": 7}), (2, 4, {"weight": 5}), (3, 5, {"weight": 6}), (4, 6, {"weight": 9}), ] self.maximum_spanning_edgelist = [ (0, 1, {"weight": 7}), (1, 2, {"weight": 8}), (1, 3, {"weight": 9}), (3, 4, {"weight": 15}), (4, 6, {"weight": 9}), (5, 6, {"weight": 11}), ] def test_minimum_edges(self): edges = nx.minimum_spanning_edges(self.G, algorithm=self.algo) # Edges from the spanning edges functions don't come in sorted # orientation, so we need to sort each edge individually. actual = sorted((min(u, v), max(u, v), d) for u, v, d in edges) assert edges_equal(actual, self.minimum_spanning_edgelist) def test_maximum_edges(self): edges = nx.maximum_spanning_edges(self.G, algorithm=self.algo) # Edges from the spanning edges functions don't come in sorted # orientation, so we need to sort each edge individually. actual = sorted((min(u, v), max(u, v), d) for u, v, d in edges) assert edges_equal(actual, self.maximum_spanning_edgelist) def test_without_data(self): edges = nx.minimum_spanning_edges(self.G, algorithm=self.algo, data=False) # Edges from the spanning edges functions don't come in sorted # orientation, so we need to sort each edge individually. actual = sorted((min(u, v), max(u, v)) for u, v in edges) expected = [(u, v) for u, v, d in self.minimum_spanning_edgelist] assert edges_equal(actual, expected) def test_nan_weights(self): # Edge weights NaN never appear in the spanning tree. see #2164 G = self.G G.add_edge(0, 12, weight=float("nan")) edges = nx.minimum_spanning_edges( G, algorithm=self.algo, data=False, ignore_nan=True ) actual = sorted((min(u, v), max(u, v)) for u, v in edges) expected = [(u, v) for u, v, d in self.minimum_spanning_edgelist] assert edges_equal(actual, expected) # Now test for raising exception edges = nx.minimum_spanning_edges( G, algorithm=self.algo, data=False, ignore_nan=False ) with pytest.raises(ValueError): list(edges) # test default for ignore_nan as False edges = nx.minimum_spanning_edges(G, algorithm=self.algo, data=False) with pytest.raises(ValueError): list(edges) def test_nan_weights_order(self): # now try again with a nan edge at the beginning of G.nodes edges = [ (0, 1, 7), (0, 3, 5), (1, 2, 8), (1, 3, 9), (1, 4, 7), (2, 4, 5), (3, 4, 15), (3, 5, 6), (4, 5, 8), (4, 6, 9), (5, 6, 11), ] G = nx.Graph() G.add_weighted_edges_from([(u + 1, v + 1, wt) for u, v, wt in edges]) G.add_edge(0, 7, weight=float("nan")) edges = nx.minimum_spanning_edges( G, algorithm=self.algo, data=False, ignore_nan=True ) actual = sorted((min(u, v), max(u, v)) for u, v in edges) shift = [(u + 1, v + 1) for u, v, d in self.minimum_spanning_edgelist] assert edges_equal(actual, shift) def test_isolated_node(self): # now try again with an isolated node edges = [ (0, 1, 7), (0, 3, 5), (1, 2, 8), (1, 3, 9), (1, 4, 7), (2, 4, 5), (3, 4, 15), (3, 5, 6), (4, 5, 8), (4, 6, 9), (5, 6, 11), ] G = nx.Graph() G.add_weighted_edges_from([(u + 1, v + 1, wt) for u, v, wt in edges]) G.add_node(0) edges = nx.minimum_spanning_edges( G, algorithm=self.algo, data=False, ignore_nan=True ) actual = sorted((min(u, v), max(u, v)) for u, v in edges) shift = [(u + 1, v + 1) for u, v, d in self.minimum_spanning_edgelist] assert edges_equal(actual, shift) def test_minimum_tree(self): T = nx.minimum_spanning_tree(self.G, algorithm=self.algo) actual = sorted(T.edges(data=True)) assert edges_equal(actual, self.minimum_spanning_edgelist) def test_maximum_tree(self): T = nx.maximum_spanning_tree(self.G, algorithm=self.algo) actual = sorted(T.edges(data=True)) assert edges_equal(actual, self.maximum_spanning_edgelist) def test_disconnected(self): G = nx.Graph([(0, 1, {"weight": 1}), (2, 3, {"weight": 2})]) T = nx.minimum_spanning_tree(G, algorithm=self.algo) assert nodes_equal(list(T), list(range(4))) assert edges_equal(list(T.edges()), [(0, 1), (2, 3)]) def test_empty_graph(self): G = nx.empty_graph(3) T = nx.minimum_spanning_tree(G, algorithm=self.algo) assert nodes_equal(sorted(T), list(range(3))) assert T.number_of_edges() == 0 def test_attributes(self): G = nx.Graph() G.add_edge(1, 2, weight=1, color="red", distance=7) G.add_edge(2, 3, weight=1, color="green", distance=2) G.add_edge(1, 3, weight=10, color="blue", distance=1) G.graph["foo"] = "bar" T = nx.minimum_spanning_tree(G, algorithm=self.algo) assert T.graph == G.graph assert nodes_equal(T, G) for u, v in T.edges(): assert T.adj[u][v] == G.adj[u][v] def test_weight_attribute(self): G = nx.Graph() G.add_edge(0, 1, weight=1, distance=7) G.add_edge(0, 2, weight=30, distance=1) G.add_edge(1, 2, weight=1, distance=1) G.add_node(3) T = nx.minimum_spanning_tree(G, algorithm=self.algo, weight="distance") assert nodes_equal(sorted(T), list(range(4))) assert edges_equal(sorted(T.edges()), [(0, 2), (1, 2)]) T = nx.maximum_spanning_tree(G, algorithm=self.algo, weight="distance") assert nodes_equal(sorted(T), list(range(4))) assert edges_equal(sorted(T.edges()), [(0, 1), (0, 2)]) class TestBoruvka(MinimumSpanningTreeTestBase): """Unit tests for computing a minimum (or maximum) spanning tree using Borůvka's algorithm. """ algorithm = "boruvka" def test_unicode_name(self): """Tests that using a Unicode string can correctly indicate Borůvka's algorithm. """ edges = nx.minimum_spanning_edges(self.G, algorithm="borůvka") # Edges from the spanning edges functions don't come in sorted # orientation, so we need to sort each edge individually. actual = sorted((min(u, v), max(u, v), d) for u, v, d in edges) assert edges_equal(actual, self.minimum_spanning_edgelist) class MultigraphMSTTestBase(MinimumSpanningTreeTestBase): # Abstract class def test_multigraph_keys_min(self): """Tests that the minimum spanning edges of a multigraph preserves edge keys. """ G = nx.MultiGraph() G.add_edge(0, 1, key="a", weight=2) G.add_edge(0, 1, key="b", weight=1) min_edges = nx.minimum_spanning_edges mst_edges = min_edges(G, algorithm=self.algo, data=False) assert edges_equal([(0, 1, "b")], list(mst_edges)) def test_multigraph_keys_max(self): """Tests that the maximum spanning edges of a multigraph preserves edge keys. """ G = nx.MultiGraph() G.add_edge(0, 1, key="a", weight=2) G.add_edge(0, 1, key="b", weight=1) max_edges = nx.maximum_spanning_edges mst_edges = max_edges(G, algorithm=self.algo, data=False) assert edges_equal([(0, 1, "a")], list(mst_edges)) class TestKruskal(MultigraphMSTTestBase): """Unit tests for computing a minimum (or maximum) spanning tree using Kruskal's algorithm. """ algorithm = "kruskal" def test_key_data_bool(self): """Tests that the keys and data values are included in MST edges based on whether keys and data parameters are true or false""" G = nx.MultiGraph() G.add_edge(1, 2, key=1, weight=2) G.add_edge(1, 2, key=2, weight=3) G.add_edge(3, 2, key=1, weight=2) G.add_edge(3, 1, key=1, weight=4) # keys are included and data is not included mst_edges = nx.minimum_spanning_edges( G, algorithm=self.algo, keys=True, data=False ) assert edges_equal([(1, 2, 1), (2, 3, 1)], list(mst_edges)) # keys are not included and data is included mst_edges = nx.minimum_spanning_edges( G, algorithm=self.algo, keys=False, data=True ) assert edges_equal( [(1, 2, {"weight": 2}), (2, 3, {"weight": 2})], list(mst_edges) ) # both keys and data are not included mst_edges = nx.minimum_spanning_edges( G, algorithm=self.algo, keys=False, data=False ) assert edges_equal([(1, 2), (2, 3)], list(mst_edges)) class TestPrim(MultigraphMSTTestBase): """Unit tests for computing a minimum (or maximum) spanning tree using Prim's algorithm. """ algorithm = "prim" def test_ignore_nan(self): """Tests that the edges with NaN weights are ignored or raise an Error based on ignore_nan is true or false""" H = nx.MultiGraph() H.add_edge(1, 2, key=1, weight=float("nan")) H.add_edge(1, 2, key=2, weight=3) H.add_edge(3, 2, key=1, weight=2) H.add_edge(3, 1, key=1, weight=4) # NaN weight edges are ignored when ignore_nan=True mst_edges = nx.minimum_spanning_edges(H, algorithm=self.algo, ignore_nan=True) assert edges_equal( [(1, 2, 2, {"weight": 3}), (2, 3, 1, {"weight": 2})], list(mst_edges) ) # NaN weight edges raise Error when ignore_nan=False with pytest.raises(ValueError): list(nx.minimum_spanning_edges(H, algorithm=self.algo, ignore_nan=False)) def test_multigraph_keys_tree(self): G = nx.MultiGraph() G.add_edge(0, 1, key="a", weight=2) G.add_edge(0, 1, key="b", weight=1) T = nx.minimum_spanning_tree(G, algorithm=self.algo) assert edges_equal([(0, 1, 1)], list(T.edges(data="weight"))) def test_multigraph_keys_tree_max(self): G = nx.MultiGraph() G.add_edge(0, 1, key="a", weight=2) G.add_edge(0, 1, key="b", weight=1) T = nx.maximum_spanning_tree(G, algorithm=self.algo) assert edges_equal([(0, 1, 2)], list(T.edges(data="weight"))) class TestSpanningTreeIterator: """ Tests the spanning tree iterator on the example graph in the 2005 Sörensen and Janssens paper An Algorithm to Generate all Spanning Trees of a Graph in Order of Increasing Cost """ def setup_method(self): # Original Graph edges = [(0, 1, 5), (1, 2, 4), (1, 4, 6), (2, 3, 5), (2, 4, 7), (3, 4, 3)] self.G = nx.Graph() self.G.add_weighted_edges_from(edges) # List of lists of spanning trees in increasing order self.spanning_trees = [ # 1, MST, cost = 17 [ (0, 1, {"weight": 5}), (1, 2, {"weight": 4}), (2, 3, {"weight": 5}), (3, 4, {"weight": 3}), ], # 2, cost = 18 [ (0, 1, {"weight": 5}), (1, 2, {"weight": 4}), (1, 4, {"weight": 6}), (3, 4, {"weight": 3}), ], # 3, cost = 19 [ (0, 1, {"weight": 5}), (1, 4, {"weight": 6}), (2, 3, {"weight": 5}), (3, 4, {"weight": 3}), ], # 4, cost = 19 [ (0, 1, {"weight": 5}), (1, 2, {"weight": 4}), (2, 4, {"weight": 7}), (3, 4, {"weight": 3}), ], # 5, cost = 20 [ (0, 1, {"weight": 5}), (1, 2, {"weight": 4}), (1, 4, {"weight": 6}), (2, 3, {"weight": 5}), ], # 6, cost = 21 [ (0, 1, {"weight": 5}), (1, 4, {"weight": 6}), (2, 4, {"weight": 7}), (3, 4, {"weight": 3}), ], # 7, cost = 21 [ (0, 1, {"weight": 5}), (1, 2, {"weight": 4}), (2, 3, {"weight": 5}), (2, 4, {"weight": 7}), ], # 8, cost = 23 [ (0, 1, {"weight": 5}), (1, 4, {"weight": 6}), (2, 3, {"weight": 5}), (2, 4, {"weight": 7}), ], ] def test_minimum_spanning_tree_iterator(self): """ Tests that the spanning trees are correctly returned in increasing order """ tree_index = 0 for tree in nx.SpanningTreeIterator(self.G): actual = sorted(tree.edges(data=True)) assert edges_equal(actual, self.spanning_trees[tree_index]) tree_index += 1 def test_maximum_spanning_tree_iterator(self): """ Tests that the spanning trees are correctly returned in decreasing order """ tree_index = 7 for tree in nx.SpanningTreeIterator(self.G, minimum=False): actual = sorted(tree.edges(data=True)) assert edges_equal(actual, self.spanning_trees[tree_index]) tree_index -= 1 def test_random_spanning_tree_multiplicative_small(): """ Using a fixed seed, sample one tree for repeatability. """ from math import exp pytest.importorskip("scipy") gamma = { (0, 1): -0.6383, (0, 2): -0.6827, (0, 5): 0, (1, 2): -1.0781, (1, 4): 0, (2, 3): 0, (5, 3): -0.2820, (5, 4): -0.3327, (4, 3): -0.9927, } # The undirected support of gamma G = nx.Graph() for u, v in gamma: G.add_edge(u, v, lambda_key=exp(gamma[(u, v)])) solution_edges = [(2, 3), (3, 4), (0, 5), (5, 4), (4, 1)] solution = nx.Graph() solution.add_edges_from(solution_edges) sampled_tree = nx.random_spanning_tree(G, "lambda_key", seed=42) assert nx.utils.edges_equal(solution.edges, sampled_tree.edges) @pytest.mark.slow def test_random_spanning_tree_multiplicative_large(): """ Sample many trees from the distribution created in the last test """ from math import exp from random import Random pytest.importorskip("numpy") stats = pytest.importorskip("scipy.stats") gamma = { (0, 1): -0.6383, (0, 2): -0.6827, (0, 5): 0, (1, 2): -1.0781, (1, 4): 0, (2, 3): 0, (5, 3): -0.2820, (5, 4): -0.3327, (4, 3): -0.9927, } # The undirected support of gamma G = nx.Graph() for u, v in gamma: G.add_edge(u, v, lambda_key=exp(gamma[(u, v)])) # Find the multiplicative weight for each tree. total_weight = 0 tree_expected = {} for t in nx.SpanningTreeIterator(G): # Find the multiplicative weight of the spanning tree weight = 1 for u, v, d in t.edges(data="lambda_key"): weight *= d tree_expected[t] = weight total_weight += weight # Assert that every tree has an entry in the expected distribution assert len(tree_expected) == 75 # Set the sample size and then calculate the expected number of times we # expect to see each tree. This test uses a near minimum sample size where # the most unlikely tree has an expected frequency of 5.15. # (Minimum required is 5) # # Here we also initialize the tree_actual dict so that we know the keys # match between the two. We will later take advantage of the fact that since # python 3.7 dict order is guaranteed so the expected and actual data will # have the same order. sample_size = 1200 tree_actual = {} for t in tree_expected: tree_expected[t] = (tree_expected[t] / total_weight) * sample_size tree_actual[t] = 0 # Sample the spanning trees # # Assert that they are actually trees and record which of the 75 trees we # have sampled. # # For repeatability, we want to take advantage of the decorators in NetworkX # to randomly sample the same sample each time. However, if we pass in a # constant seed to sample_spanning_tree we will get the same tree each time. # Instead, we can create our own random number generator with a fixed seed # and pass those into sample_spanning_tree. rng = Random(37) for _ in range(sample_size): sampled_tree = nx.random_spanning_tree(G, "lambda_key", seed=rng) assert nx.is_tree(sampled_tree) for t in tree_expected: if nx.utils.edges_equal(t.edges, sampled_tree.edges): tree_actual[t] += 1 break # Conduct a Chi squared test to see if the actual distribution matches the # expected one at an alpha = 0.05 significance level. # # H_0: The distribution of trees in tree_actual matches the normalized product # of the edge weights in the tree. # # H_a: The distribution of trees in tree_actual follows some other # distribution of spanning trees. _, p = stats.chisquare(list(tree_actual.values()), list(tree_expected.values())) # Assert that p is greater than the significance level so that we do not # reject the null hypothesis assert not p < 0.05 def test_random_spanning_tree_additive_small(): """ Sample a single spanning tree from the additive method. """ pytest.importorskip("scipy") edges = { (0, 1): 1, (0, 2): 1, (0, 5): 3, (1, 2): 2, (1, 4): 3, (2, 3): 3, (5, 3): 4, (5, 4): 5, (4, 3): 4, } # Build the graph G = nx.Graph() for u, v in edges: G.add_edge(u, v, weight=edges[(u, v)]) solution_edges = [(0, 2), (1, 2), (2, 3), (3, 4), (3, 5)] solution = nx.Graph() solution.add_edges_from(solution_edges) sampled_tree = nx.random_spanning_tree( G, weight="weight", multiplicative=False, seed=37 ) assert nx.utils.edges_equal(solution.edges, sampled_tree.edges) @pytest.mark.slow def test_random_spanning_tree_additive_large(): """ Sample many spanning trees from the additive method. """ from random import Random pytest.importorskip("numpy") stats = pytest.importorskip("scipy.stats") edges = { (0, 1): 1, (0, 2): 1, (0, 5): 3, (1, 2): 2, (1, 4): 3, (2, 3): 3, (5, 3): 4, (5, 4): 5, (4, 3): 4, } # Build the graph G = nx.Graph() for u, v in edges: G.add_edge(u, v, weight=edges[(u, v)]) # Find the additive weight for each tree. total_weight = 0 tree_expected = {} for t in nx.SpanningTreeIterator(G): # Find the multiplicative weight of the spanning tree weight = 0 for u, v, d in t.edges(data="weight"): weight += d tree_expected[t] = weight total_weight += weight # Assert that every tree has an entry in the expected distribution assert len(tree_expected) == 75 # Set the sample size and then calculate the expected number of times we # expect to see each tree. This test uses a near minimum sample size where # the most unlikely tree has an expected frequency of 5.07. # (Minimum required is 5) # # Here we also initialize the tree_actual dict so that we know the keys # match between the two. We will later take advantage of the fact that since # python 3.7 dict order is guaranteed so the expected and actual data will # have the same order. sample_size = 500 tree_actual = {} for t in tree_expected: tree_expected[t] = (tree_expected[t] / total_weight) * sample_size tree_actual[t] = 0 # Sample the spanning trees # # Assert that they are actually trees and record which of the 75 trees we # have sampled. # # For repeatability, we want to take advantage of the decorators in NetworkX # to randomly sample the same sample each time. However, if we pass in a # constant seed to sample_spanning_tree we will get the same tree each time. # Instead, we can create our own random number generator with a fixed seed # and pass those into sample_spanning_tree. rng = Random(37) for _ in range(sample_size): sampled_tree = nx.random_spanning_tree( G, "weight", multiplicative=False, seed=rng ) assert nx.is_tree(sampled_tree) for t in tree_expected: if nx.utils.edges_equal(t.edges, sampled_tree.edges): tree_actual[t] += 1 break # Conduct a Chi squared test to see if the actual distribution matches the # expected one at an alpha = 0.05 significance level. # # H_0: The distribution of trees in tree_actual matches the normalized product # of the edge weights in the tree. # # H_a: The distribution of trees in tree_actual follows some other # distribution of spanning trees. _, p = stats.chisquare(list(tree_actual.values()), list(tree_expected.values())) # Assert that p is greater than the significance level so that we do not # reject the null hypothesis assert not p < 0.05