diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2010-03-31 15:12:29 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2010-03-31 15:12:29 -0400 |
commit | 03fecba81969015afce99ebb8209a5e8a7988b34 (patch) | |
tree | 7c447ae94802fd4a30e9dfca6c15c1a62bd874cd | |
parent | 97ed8d47951d7777f2dd72a7f960d46bf833c0d3 (diff) | |
download | sqlalchemy-03fecba81969015afce99ebb8209a5e8a7988b34.tar.gz |
really got topological going. now that we aren't putting fricking mapped objects into
it all that id() stuff can go
-rw-r--r-- | lib/sqlalchemy/sql/util.py | 7 | ||||
-rw-r--r-- | lib/sqlalchemy/topological.py | 123 | ||||
-rw-r--r-- | test/base/test_dependency.py | 136 |
3 files changed, 173 insertions, 93 deletions
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index d5575e0e7..4c59f50d5 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -15,10 +15,13 @@ def sort_tables(tables): parent_table = fkey.column.table if parent_table in tables: child_table = fkey.parent.table - tuples.append( ( parent_table, child_table ) ) + if parent_table is not child_table: + tuples.append((parent_table, child_table)) for table in tables: - visitors.traverse(table, {'schema_visitor':True}, {'foreign_key':visit_foreign_key}) + visitors.traverse(table, + {'schema_visitor':True}, + {'foreign_key':visit_foreign_key}) return topological.sort(tuples, tables) def find_join_source(clauses, join_to): diff --git a/lib/sqlalchemy/topological.py b/lib/sqlalchemy/topological.py index 324995889..15739ae82 100644 --- a/lib/sqlalchemy/topological.py +++ b/lib/sqlalchemy/topological.py @@ -21,23 +21,8 @@ conditions. from sqlalchemy.exc import CircularDependencyError from sqlalchemy import util -__all__ = ['sort', 'sort_with_cycles', 'sort_as_tree'] +__all__ = ['sort'] -# TODO: obviate the need for a _Node class. -# a straight tuple should be used. -class _Node(tuple): - """Represent each item in the sort.""" - - def __new__(cls, item): - children = [] - t = tuple.__new__(cls, [item, children]) - t.item = item - t.children = children - return t - - def __hash__(self): - return id(self) - class _EdgeCollection(object): """A collection of directed edges.""" @@ -52,20 +37,6 @@ class _EdgeCollection(object): self.parent_to_children[parentnode].add(childnode) self.child_to_parents[childnode].add(parentnode) - def remove(self, edge): - """Remove an edge from this collection. - - Return the childnode if it has no other parents. - """ - - (parentnode, childnode) = edge - self.parent_to_children[parentnode].remove(childnode) - self.child_to_parents[childnode].remove(parentnode) - if not self.child_to_parents[childnode]: - return childnode - else: - return None - def has_parents(self, node): return node in self.child_to_parents and bool(self.child_to_parents[node]) @@ -74,7 +45,12 @@ class _EdgeCollection(object): return [(node, child) for child in self.parent_to_children[node]] else: return [] - + + def outgoing(self, node): + """an iterable returning all nodes reached via node's outgoing edges""" + + return self.parent_to_children[node] + def get_parents(self): return self.parent_to_children.keys() @@ -92,9 +68,6 @@ class _EdgeCollection(object): if not self.child_to_parents[child]: yield child - def __len__(self): - return sum(len(x) for x in self.parent_to_children.values()) - def __iter__(self): for parent, children in self.parent_to_children.iteritems(): for child in children: @@ -108,69 +81,57 @@ def sort(tuples, allitems): 'tuples' is a list of tuples representing a partial ordering. """ - nodes = {} - edges = _EdgeCollection() - for item in list(allitems) + [t[0] for t in tuples] + [t[1] for t in tuples]: - item_id = id(item) - if item_id not in nodes: - nodes[item_id] = _Node(item) + edges = _EdgeCollection() + nodes = set(allitems) for t in tuples: - id0, id1 = id(t[0]), id(t[1]) - if t[0] is t[1]: - continue - childnode = nodes[id1] - parentnode = nodes[id0] - edges.add((parentnode, childnode)) + nodes.update(t) + edges.add(t) queue = [] - for n in nodes.values(): + for n in nodes: if not edges.has_parents(n): queue.append(n) output = [] while nodes: if not queue: - raise CircularDependencyError("Circular dependency detected " + - repr(edges) + repr(queue)) + raise CircularDependencyError("Circular dependency detected: %r" % edges) node = queue.pop() - output.append(node.item) - del nodes[id(node.item)] + output.append(node) + nodes.remove(node) for childnode in edges.pop_node(node): queue.append(childnode) return output +def find_cycles(tuples, allitems): + # straight from gvr with some mods + todo = set(allitems) + edges = _EdgeCollection() -def _find_cycles(edges): - cycles = {} - - def traverse(node, cycle, goal): - for (n, key) in edges.edges_by_parent(node): - if key in cycle: - continue - cycle.add(key) - if key is goal: - cycset = set(cycle) - for x in cycle: - if x in cycles: - existing_set = cycles[x] - existing_set.update(cycset) - for y in existing_set: - cycles[y] = existing_set - cycset = existing_set - else: - cycles[x] = cycset + for t in tuples: + todo.update(t) + edges.add(t) + + output = set() + + while todo: + node = todo.pop() + stack = [node] + while stack: + top = stack[-1] + for node in edges.outgoing(top): + if node in stack: + cyc = stack[stack.index(node):] + todo.difference_update(cyc) + output.update(cyc) + + if node in todo: + stack.append(node) + todo.remove(node) + break else: - traverse(key, cycle, goal) - cycle.pop() - - for parent in edges.get_parents(): - traverse(parent, set(), parent) - - unique_cycles = set(tuple(s) for s in cycles.values()) + node = stack.pop() + return output - for cycle in unique_cycles: - edgecollection = [edge for edge in edges - if edge[0] in cycle and edge[1] in cycle] - yield edgecollection diff --git a/test/base/test_dependency.py b/test/base/test_dependency.py index 7dc55ea99..8c38a98b0 100644 --- a/test/base/test_dependency.py +++ b/test/base/test_dependency.py @@ -1,6 +1,6 @@ import sqlalchemy.topological as topological from sqlalchemy.test import TestBase -from sqlalchemy.test.testing import assert_raises +from sqlalchemy.test.testing import assert_raises, eq_ from sqlalchemy import exc import collections @@ -16,7 +16,7 @@ class DependencySortTest(TestBase): for n in result[i:]: assert node not in deps[n] - def testsort(self): + def test_sort_one(self): rootnode = 'root' node2 = 'node2' node3 = 'node3' @@ -38,7 +38,7 @@ class DependencySortTest(TestBase): ] self.assert_sort(tuples, topological.sort(tuples, [])) - def testsort2(self): + def test_sort_two(self): node1 = 'node1' node2 = 'node2' node3 = 'node3' @@ -55,7 +55,7 @@ class DependencySortTest(TestBase): ] self.assert_sort(tuples, topological.sort(tuples, [node7])) - def testsort4(self): + def test_sort_three(self): node1 = 'keywords' node2 = 'itemkeyowrds' node3 = 'items' @@ -68,7 +68,7 @@ class DependencySortTest(TestBase): ] self.assert_sort(tuples, topological.sort(tuples, [])) - def testcircular(self): + def test_raise_on_cycle_one(self): node1 = 'node1' node2 = 'node2' node3 = 'node3' @@ -87,7 +87,7 @@ class DependencySortTest(TestBase): # TODO: test find_cycles - def testcircular2(self): + def test_raise_on_cycle_two(self): # this condition was arising from ticket:362 # and was not treated properly by topological sort node1 = 'node1' @@ -105,7 +105,7 @@ class DependencySortTest(TestBase): # TODO: test find_cycles - def testcircular3(self): + def test_raise_on_cycle_three(self): question, issue, providerservice, answer, provider = "Question", "Issue", "ProviderService", "Answer", "Provider" tuples = [(question, issue), (providerservice, issue), (provider, question), @@ -116,15 +116,14 @@ class DependencySortTest(TestBase): # TODO: test find_cycles - def testbigsort(self): + def test_large_sort(self): tuples = [(i, i + 1) for i in range(0, 1500, 2)] self.assert_sort( tuples, topological.sort(tuples, []) ) - - def testids(self): + def test_ticket_1380(self): # ticket:1380 regression: would raise a KeyError tuples = [(id(i), i) for i in range(3)] self.assert_sort( @@ -132,5 +131,122 @@ class DependencySortTest(TestBase): topological.sort(tuples, []) ) + def test_find_cycle_one(self): + node1 = 'node1' + node2 = 'node2' + node3 = 'node3' + node4 = 'node4' + tuples = [ + (node1, node2), + (node3, node1), + (node2, node4), + (node3, node2), + (node2, node3) + ] + + eq_( + topological.find_cycles(tuples), + set([node1, node2, node3]) + ) + + def test_find_multiple_cycles_one(self): + node1 = 'node1' + node2 = 'node2' + node3 = 'node3' + node4 = 'node4' + node5 = 'node5' + node6 = 'node6' + node7 = 'node7' + node8 = 'node8' + node9 = 'node9' + tuples = [ + # cycle 1 + (node1, node2), + (node2, node4), + (node4, node1), + + # cycle 2 + (node9, node9), + + # cycle 3 + (node7, node5), + (node5, node7), + + # cycle 4, but only if cycle 1 nodes are present + (node1, node6), + (node6, node8), + (node8, node4), + + (node3, node1), + (node3, node2), + ] + + allnodes = set([node1, node2, node3, node4, node5, node6, node7, node8, node9]) + eq_( + topological.find_cycles(tuples, allnodes), + set(['node8', 'node1', 'node2', 'node5', 'node4', 'node7', 'node6', 'node9']) + ) + + def test_find_multiple_cycles_two(self): + node1 = 'node1' + node2 = 'node2' + node3 = 'node3' + node4 = 'node4' + node5 = 'node5' + node6 = 'node6' + tuples = [ + # cycle 1 + (node1, node2), + (node2, node4), + (node4, node1), + + # cycle 2 + (node1, node6), + (node6, node2), + (node2, node4), + (node4, node1), + ] + + allnodes = set([node1, node2, node3, node4, node5, node6]) + eq_( + topological.find_cycles(tuples, allnodes), + set(['node1', 'node2', 'node4']) + ) + + def test_find_multiple_cycles_three(self): + node1 = 'node1' + node2 = 'node2' + node3 = 'node3' + node4 = 'node4' + node5 = 'node5' + node6 = 'node6' + tuples = [ + + # cycle 1 + (node1, node2), + (node2, node1), + + # cycle 2 + (node2, node3), + (node3, node2), + + # cycle3 + (node2, node4), + (node4, node2), + + # cycle4 + (node2, node5), + (node5, node6), + (node6, node2) + ] + + allnodes = set([node1, node2, node3, node4, node5, node6]) + eq_( + topological.find_cycles(tuples, allnodes), + allnodes + ) + + + |