summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/topological.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2009-05-31 21:27:56 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2009-05-31 21:27:56 +0000
commitc84dd331df47a55bd0650686644da050311135f3 (patch)
treed8b3ceadad1c875d7fa73c631c5fcd810e0afb3d /lib/sqlalchemy/topological.py
parent13d4004774b8ea14e8ef1614bea7105122878748 (diff)
downloadsqlalchemy-c84dd331df47a55bd0650686644da050311135f3.tar.gz
slight cleanup i want in 0.5/0.6
Diffstat (limited to 'lib/sqlalchemy/topological.py')
-rw-r--r--lib/sqlalchemy/topological.py15
1 files changed, 6 insertions, 9 deletions
diff --git a/lib/sqlalchemy/topological.py b/lib/sqlalchemy/topological.py
index 306cdfc90..f9b9ad7b3 100644
--- a/lib/sqlalchemy/topological.py
+++ b/lib/sqlalchemy/topological.py
@@ -19,6 +19,7 @@ conditions.
"""
from sqlalchemy.exc import CircularDependencyError
+from sqlalchemy import util
__all__ = ['sort', 'sort_with_cycles', 'sort_as_tree']
@@ -93,18 +94,14 @@ class _EdgeCollection(object):
"""A collection of directed edges."""
def __init__(self):
- self.parent_to_children = {}
- self.child_to_parents = {}
+ self.parent_to_children = util.defaultdict(set)
+ self.child_to_parents = util.defaultdict(set)
def add(self, edge):
"""Add an edge to this collection."""
- (parentnode, childnode) = edge
- if parentnode not in self.parent_to_children:
- self.parent_to_children[parentnode] = set()
+ parentnode, childnode = edge
self.parent_to_children[parentnode].add(childnode)
- if childnode not in self.child_to_parents:
- self.child_to_parents[childnode] = set()
self.child_to_parents[childnode].add(parentnode)
parentnode.dependencies.add(childnode)
@@ -117,13 +114,13 @@ class _EdgeCollection(object):
(parentnode, childnode) = edge
self.parent_to_children[parentnode].remove(childnode)
self.child_to_parents[childnode].remove(parentnode)
- if len(self.child_to_parents[childnode]) == 0:
+ if not self.child_to_parents[childnode]:
return childnode
else:
return None
def has_parents(self, node):
- return node in self.child_to_parents and len(self.child_to_parents[node]) > 0
+ return node in self.child_to_parents and bool(self.child_to_parents[node])
def edges_by_parent(self, node):
if node in self.parent_to_children: