diff options
Diffstat (limited to 'networkx/algorithms/lowest_common_ancestors.py')
-rw-r--r-- | networkx/algorithms/lowest_common_ancestors.py | 369 |
1 files changed, 47 insertions, 322 deletions
diff --git a/networkx/algorithms/lowest_common_ancestors.py b/networkx/algorithms/lowest_common_ancestors.py index ec515a90..68aaf7d3 100644 --- a/networkx/algorithms/lowest_common_ancestors.py +++ b/networkx/algorithms/lowest_common_ancestors.py @@ -1,7 +1,7 @@ """Algorithms for finding the lowest common ancestor of trees and DAGs.""" from collections import defaultdict from collections.abc import Mapping, Set -from itertools import chain, combinations_with_replacement, count +from itertools import combinations_with_replacement import networkx as nx from networkx.utils import UnionFind, arbitrary_element, not_implemented_for @@ -10,14 +10,12 @@ __all__ = [ "all_pairs_lowest_common_ancestor", "tree_all_pairs_lowest_common_ancestor", "lowest_common_ancestor", - "naive_lowest_common_ancestor", - "naive_all_pairs_lowest_common_ancestor", ] @not_implemented_for("undirected") @not_implemented_for("multigraph") -def naive_all_pairs_lowest_common_ancestor(G, pairs=None): +def all_pairs_lowest_common_ancestor(G, pairs=None): """Return the lowest common ancestor of all pairs or the provided pairs Parameters @@ -48,13 +46,13 @@ def naive_all_pairs_lowest_common_ancestor(G, pairs=None): possible combinations of nodes in `G`, including self-pairings: >>> G = nx.DiGraph([(0, 1), (0, 3), (1, 2)]) - >>> dict(nx.naive_all_pairs_lowest_common_ancestor(G)) + >>> dict(nx.all_pairs_lowest_common_ancestor(G)) {(0, 0): 0, (0, 1): 0, (0, 3): 0, (0, 2): 0, (1, 1): 1, (1, 3): 0, (1, 2): 1, (3, 3): 3, (3, 2): 0, (2, 2): 2} The pairs argument can be used to limit the output to only the specified node pairings: - >>> dict(nx.naive_all_pairs_lowest_common_ancestor(G, pairs=[(1, 2), (2, 3)])) + >>> dict(nx.all_pairs_lowest_common_ancestor(G, pairs=[(1, 2), (2, 3)])) {(1, 2): 1, (2, 3): 0} Notes @@ -63,46 +61,59 @@ def naive_all_pairs_lowest_common_ancestor(G, pairs=None): See Also -------- - naive_lowest_common_ancestor + lowest_common_ancestor """ if not nx.is_directed_acyclic_graph(G): raise nx.NetworkXError("LCA only defined on directed acyclic graphs.") if len(G) == 0: raise nx.NetworkXPointlessConcept("LCA meaningless on null graphs.") - ancestor_cache = {} - if pairs is None: - pairs = combinations_with_replacement(G, 2) - - for v, w in pairs: - if v not in ancestor_cache: - ancestor_cache[v] = nx.ancestors(G, v) - ancestor_cache[v].add(v) - if w not in ancestor_cache: - ancestor_cache[w] = nx.ancestors(G, w) - ancestor_cache[w].add(w) - - common_ancestors = ancestor_cache[v] & ancestor_cache[w] - - if common_ancestors: - common_ancestor = next(iter(common_ancestors)) - while True: - successor = None - for lower_ancestor in G.successors(common_ancestor): - if lower_ancestor in common_ancestors: - successor = lower_ancestor + else: + # Convert iterator to iterable, if necessary. Trim duplicates. + pairs = dict.fromkeys(pairs) + # Verify that each of the nodes in the provided pairs is in G + nodeset = set(G) + for pair in pairs: + if set(pair) - nodeset: + raise nx.NodeNotFound( + f"Node(s) {set(pair) - nodeset} from pair {pair} not in G." + ) + + # Once input validation is done, construct the generator + def generate_lca_from_pairs(G, pairs): + ancestor_cache = {} + + for v, w in pairs: + if v not in ancestor_cache: + ancestor_cache[v] = nx.ancestors(G, v) + ancestor_cache[v].add(v) + if w not in ancestor_cache: + ancestor_cache[w] = nx.ancestors(G, w) + ancestor_cache[w].add(w) + + common_ancestors = ancestor_cache[v] & ancestor_cache[w] + + if common_ancestors: + common_ancestor = next(iter(common_ancestors)) + while True: + successor = None + for lower_ancestor in G.successors(common_ancestor): + if lower_ancestor in common_ancestors: + successor = lower_ancestor + break + if successor is None: break - if successor is None: - break - common_ancestor = successor - yield ((v, w), common_ancestor) + common_ancestor = successor + yield ((v, w), common_ancestor) + + return generate_lca_from_pairs(G, pairs) @not_implemented_for("undirected") @not_implemented_for("multigraph") -def naive_lowest_common_ancestor(G, node1, node2, default=None): +def lowest_common_ancestor(G, node1, node2, default=None): """Compute the lowest common ancestor of the given pair of nodes. Parameters @@ -124,14 +135,14 @@ def naive_lowest_common_ancestor(G, node1, node2, default=None): >>> G = nx.DiGraph() >>> nx.add_path(G, (0, 1, 2, 3)) >>> nx.add_path(G, (0, 4, 3)) - >>> nx.naive_lowest_common_ancestor(G, 2, 4) + >>> nx.lowest_common_ancestor(G, 2, 4) 0 See Also -------- - naive_all_pairs_lowest_common_ancestor""" + all_pairs_lowest_common_ancestor""" - ans = list(naive_all_pairs_lowest_common_ancestor(G, pairs=[(node1, node2)])) + ans = list(all_pairs_lowest_common_ancestor(G, pairs=[(node1, node2)])) if ans: assert len(ans) == 1 return ans[0][1] @@ -254,289 +265,3 @@ def tree_all_pairs_lowest_common_ancestor(G, root=None, pairs=None): parent = arbitrary_element(G.pred[node]) uf.union(parent, node) ancestors[uf[parent]] = parent - - -@not_implemented_for("undirected") -@not_implemented_for("multigraph") -def lowest_common_ancestor(G, node1, node2, default=None): - """Compute the lowest common ancestor of the given pair of nodes. - - Parameters - ---------- - G : NetworkX directed graph - - node1, node2 : nodes in the graph. - - default : object - Returned if no common ancestor between `node1` and `node2` - - Returns - ------- - The lowest common ancestor of node1 and node2, - or default if they have no common ancestors. - - Examples - -------- - >>> G = nx.DiGraph([(0, 1), (0, 2), (2, 3), (2, 4), (1, 6), (4, 5)]) - >>> nx.lowest_common_ancestor(G, 3, 5) - 2 - - We can also set `default` argument as below. The value of default is returned - if there are no common ancestors of given two nodes. - - >>> G = nx.DiGraph([(4, 5), (12, 13)]) - >>> nx.lowest_common_ancestor(G, 12, 5, default="No common ancestors!") - 'No common ancestors!' - - Notes - ----- - Only defined on non-null directed acyclic graphs. - Takes n log(n) time in the size of the graph. - See `all_pairs_lowest_common_ancestor` when you have - more than one pair of nodes of interest. - - See Also - -------- - tree_all_pairs_lowest_common_ancestor - all_pairs_lowest_common_ancestor - """ - ans = list(all_pairs_lowest_common_ancestor(G, pairs=[(node1, node2)])) - if ans: - assert len(ans) == 1 - return ans[0][1] - return default - - -@not_implemented_for("undirected") -@not_implemented_for("multigraph") -def all_pairs_lowest_common_ancestor(G, pairs=None): - """Compute the lowest common ancestor for pairs of nodes. - - Parameters - ---------- - G : NetworkX directed graph - - pairs : iterable of pairs of nodes, optional (default: all pairs) - The pairs of nodes of interest. - If None, will find the LCA of all pairs of nodes. - - Yields - ------ - ((node1, node2), lca) : 2-tuple - Where lca is least common ancestor of node1 and node2. - Note that for the default case, the order of the node pair is not considered, - e.g. you will not get both ``(a, b)`` and ``(b, a)`` - - Raises - ------ - NetworkXPointlessConcept - If `G` is null. - NetworkXError - If `G` is not a DAG. - - Examples - -------- - The default behavior is to yield the lowest common ancestor for all - possible combinations of nodes in `G`, including self-pairings: - - >>> G = nx.DiGraph([(0, 1), (0, 3), (1, 2)]) - >>> dict(nx.all_pairs_lowest_common_ancestor(G)) - {(2, 2): 2, (1, 1): 1, (2, 1): 1, (1, 3): 0, (2, 3): 0, (3, 3): 3, (0, 0): 0, (1, 0): 0, (2, 0): 0, (3, 0): 0} - - The `pairs` argument can be used to limit the output to only the - specified node pairings: - - >>> dict(nx.all_pairs_lowest_common_ancestor(G, pairs=[(1, 2), (2, 3)])) - {(2, 3): 0, (1, 2): 1} - - Notes - ----- - Only defined on non-null directed acyclic graphs. - - Uses the $O(n^3)$ ancestor-list algorithm from: - M. A. Bender, M. Farach-Colton, G. Pemmasani, S. Skiena, P. Sumazin. - "Lowest common ancestors in trees and directed acyclic graphs." - Journal of Algorithms, 57(2): 75-94, 2005. - - See Also - -------- - tree_all_pairs_lowest_common_ancestor - lowest_common_ancestor - """ - if not nx.is_directed_acyclic_graph(G): - raise nx.NetworkXError("LCA only defined on directed acyclic graphs.") - if len(G) == 0: - raise nx.NetworkXPointlessConcept("LCA meaningless on null graphs.") - - # The copy isn't ideal, neither is the switch-on-type, but without it users - # passing an iterable will encounter confusing errors, and itertools.tee - # does not appear to handle builtin types efficiently (IE, it materializes - # another buffer rather than just creating listoperators at the same - # offset). The Python documentation notes use of tee is unadvised when one - # is consumed before the other. - # - # This will always produce correct results and avoid unnecessary - # copies in many common cases. - # - if not isinstance(pairs, (Mapping, Set)) and pairs is not None: - pairs = set(pairs) - - # Convert G into a dag with a single root by adding a node with edges to - # all sources iff necessary. - sources = [n for n, deg in G.in_degree if deg == 0] - if len(sources) == 1: - root = sources[0] - super_root = None - else: - G = G.copy() - # find unused node - root = -1 - while root in G: - root -= 1 - # use that as the super_root below all sources - super_root = root - for source in sources: - G.add_edge(root, source) - - # Start by computing a spanning tree, and the DAG of all edges not in it. - # We will then use the tree lca algorithm on the spanning tree, and use - # the DAG to figure out the set of tree queries necessary. - spanning_tree = nx.dfs_tree(G, root) - dag = nx.DiGraph( - (u, v) - for u, v in G.edges - if u not in spanning_tree or v not in spanning_tree[u] - ) - - # Ensure that both the dag and the spanning tree contains all nodes in G, - # even nodes that are disconnected in the dag. - spanning_tree.add_nodes_from(G) - dag.add_nodes_from(G) - - counter = count() - - # Necessary to handle graphs consisting of a single node and no edges. - root_distance = {root: next(counter)} - - for edge in nx.bfs_edges(spanning_tree, root): - for node in edge: - if node not in root_distance: - root_distance[node] = next(counter) - - # Index the position of all nodes in the Euler tour so we can efficiently - # sort lists and merge in tour order. - euler_tour_pos = {} - for node in nx.depth_first_search.dfs_preorder_nodes(G, root): - if node not in euler_tour_pos: - euler_tour_pos[node] = next(counter) - - # Generate the set of all nodes of interest in the pairs. - pairset = set() - if pairs is not None: - pairset = set(chain.from_iterable(pairs)) - - for n in pairset: - if n not in G: - msg = f"The node {str(n)} is not in the digraph." - raise nx.NodeNotFound(msg) - - # Generate the transitive closure over the dag (not G) of all nodes, and - # sort each node's closure set by order of first appearance in the Euler - # tour. - ancestors = {} - for v in dag: - if pairs is None or v in pairset: - my_ancestors = nx.ancestors(G, v) - my_ancestors.add(v) - ancestors[v] = sorted(my_ancestors, key=euler_tour_pos.get) - - def _compute_dag_lca_from_tree_values(tree_lca, dry_run): - """Iterate through the in-order merge for each pair of interest. - - We do this to answer the user's query, but it is also used to - avoid generating unnecessary tree entries when the user only - needs some pairs. - """ - for (node1, node2) in pairs if pairs is not None else tree_lca: - best_root_distance = None - best = None - - indices = [0, 0] - ancestors_by_index = [ancestors[node1], ancestors[node2]] - - def get_next_in_merged_lists(indices): - """Returns index of the list containing the next item - - Next order refers to the merged order. - Index can be 0 or 1 (or None if exhausted). - """ - index1, index2 = indices - if index1 >= len(ancestors[node1]) and index2 >= len(ancestors[node2]): - return None - elif index1 >= len(ancestors[node1]): - return 1 - elif index2 >= len(ancestors[node2]): - return 0 - elif ( - euler_tour_pos[ancestors[node1][index1]] - < euler_tour_pos[ancestors[node2][index2]] - ): - return 0 - else: - return 1 - - # Find the LCA by iterating through the in-order merge of the two - # nodes of interests' ancestor sets. In principle, we need to - # consider all pairs in the Cartesian product of the ancestor sets, - # but by the restricted min range query reduction we are guaranteed - # that one of the pairs of interest is adjacent in the merged list - # iff one came from each list. - i = get_next_in_merged_lists(indices) - cur = ancestors_by_index[i][indices[i]], i - while i is not None: - prev = cur - indices[i] += 1 - i = get_next_in_merged_lists(indices) - if i is not None: - cur = ancestors_by_index[i][indices[i]], i - - # Two adjacent entries must not be from the same list - # in order for their tree LCA to be considered. - if cur[1] != prev[1]: - tree_node1, tree_node2 = prev[0], cur[0] - if (tree_node1, tree_node2) in tree_lca: - ans = tree_lca[tree_node1, tree_node2] - else: - ans = tree_lca[tree_node2, tree_node1] - if not dry_run and ( - best is None or root_distance[ans] > best_root_distance - ): - best_root_distance = root_distance[ans] - best = ans - - # If the LCA is super_root, there is no LCA in the user's graph. - if not dry_run and (super_root is None or best != super_root): - yield (node1, node2), best - - # Generate the spanning tree lca for all pairs. This doesn't make sense to - # do incrementally since we are using a linear time offline algorithm for - # tree lca. - if pairs is None: - # We want all pairs so we'll need the entire tree. - tree_lca = dict(tree_all_pairs_lowest_common_ancestor(spanning_tree, root)) - else: - # We only need the merged adjacent pairs by seeing which queries the - # algorithm needs then generating them in a single pass. - tree_lca = defaultdict(int) - for _ in _compute_dag_lca_from_tree_values(tree_lca, True): - pass - - # Replace the bogus default tree values with the real ones. - for (pair, lca) in tree_all_pairs_lowest_common_ancestor( - spanning_tree, root, tree_lca - ): - tree_lca[pair] = lca - - # All precomputations complete. Now we just need to give the user the pairs - # they asked for, or all pairs if they want them all. - return _compute_dag_lca_from_tree_values(tree_lca, False) |