summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CHANGES3
-rw-r--r--lib/sqlalchemy/mapping/topological.py9
-rw-r--r--lib/sqlalchemy/mapping/unitofwork.py1
-rw-r--r--lib/sqlalchemy/sql.py3
-rw-r--r--test/dependency.py62
-rw-r--r--test/relationships.py99
6 files changed, 153 insertions, 24 deletions
diff --git a/CHANGES b/CHANGES
index 60b81cf78..99eb0bc10 100644
--- a/CHANGES
+++ b/CHANGES
@@ -1,3 +1,6 @@
+next
+- some fixes to topological sort algorithm
+
0.1.6
- support for MS-SQL added courtesy Rick Morrison, Runar Petursson
- the latest SQLSoup from J. Ellis
diff --git a/lib/sqlalchemy/mapping/topological.py b/lib/sqlalchemy/mapping/topological.py
index 95807bf5f..779faab2d 100644
--- a/lib/sqlalchemy/mapping/topological.py
+++ b/lib/sqlalchemy/mapping/topological.py
@@ -141,15 +141,16 @@ class QueueDependencySorter(object):
#print repr(output)
head = None
node = None
+ # put the sorted list into a "tree". this is not much of a
+ # "tree" at the moment as its more of a linked list. it would be nice
+ # to group non-dependent nodes into sibling nodes, which allows better batching
+ # of SQL statements, but this algorithm has proved tricky
for o in output:
if head is None:
head = o
- node = o
else:
- for x in node.children:
- if x.dependencies.has_key(o):
- node = x
node.children.append(o)
+ node = o
return head
def _add_edge(self, edges, edge):
diff --git a/lib/sqlalchemy/mapping/unitofwork.py b/lib/sqlalchemy/mapping/unitofwork.py
index 3ef1d96ae..873bed548 100644
--- a/lib/sqlalchemy/mapping/unitofwork.py
+++ b/lib/sqlalchemy/mapping/unitofwork.py
@@ -422,7 +422,6 @@ class UOWTransaction(object):
mappers = util.HashSet()
for task in self.tasks.values():
mappers.append(task.mapper)
-
head = DependencySorter(self.dependencies, mappers).sort(allow_all_cycles=True)
#print str(head)
task = sort_hier(head)
diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py
index 7129781a7..b18b0916e 100644
--- a/lib/sqlalchemy/sql.py
+++ b/lib/sqlalchemy/sql.py
@@ -917,7 +917,6 @@ class Join(FromClause):
def __init__(self, left, right, onclause=None, isouter = False):
self.left = left
self.right = right
-
# TODO: if no onclause, do NATURAL JOIN
if onclause is None:
self.onclause = self._match_primaries(left, right)
@@ -925,6 +924,8 @@ class Join(FromClause):
self.onclause = onclause
self.isouter = isouter
+ name = property(lambda self: "Join on %s, %s" % (self.left.name, self.right.name))
+
def _locate_oid_column(self):
return self.left.oid_column
diff --git a/test/dependency.py b/test/dependency.py
index 5fd3df2fd..0aede4c7e 100644
--- a/test/dependency.py
+++ b/test/dependency.py
@@ -17,6 +17,26 @@ class thingy(object):
return repr(self)
class DependencySortTest(PersistTest):
+
+ def _assert_sort(self, tuples, allnodes, **kwargs):
+
+ head = DependencySorter(tuples, allnodes).sort(**kwargs)
+
+ print "\n" + str(head)
+ def findnode(t, n, parent=False):
+ if n.item is t[0]:
+ parent=True
+ elif n.item is t[1]:
+ if not parent and t[0] not in [c.item for c in n.cycles]:
+ self.assert_(False, "Node " + str(t[1]) + " not a child of " +str(t[0]))
+ else:
+ return
+ for c in n.children:
+ findnode(t, c, parent)
+
+ for t in tuples:
+ findnode(t, head)
+
def testsort(self):
rootnode = thingy('root')
node2 = thingy('node2')
@@ -27,6 +47,7 @@ class DependencySortTest(PersistTest):
subnode3 = thingy('subnode3')
subnode4 = thingy('subnode4')
subsubnode1 = thingy('subsubnode1')
+ allnodes = [rootnode, node2,node3,node4,subnode1,subnode2,subnode3,subnode4,subsubnode1]
tuples = [
(subnode3, subsubnode1),
(node2, subnode1),
@@ -37,8 +58,8 @@ class DependencySortTest(PersistTest):
(node4, subnode3),
(node4, subnode4)
]
- head = DependencySorter(tuples, []).sort()
- print "\n" + str(head)
+
+ self._assert_sort(tuples, allnodes)
def testsort2(self):
node1 = thingy('node1')
@@ -55,8 +76,7 @@ class DependencySortTest(PersistTest):
(node5, node6),
(node6, node2)
]
- head = DependencySorter(tuples, [node7]).sort()
- print "\n" + str(head)
+ self._assert_sort(tuples, [node1,node2,node3,node4,node5,node6,node7])
def testsort3(self):
['Mapper|Keyword|keywords,Mapper|IKAssociation|itemkeywords', 'Mapper|Item|items,Mapper|IKAssociation|itemkeywords']
@@ -68,15 +88,10 @@ class DependencySortTest(PersistTest):
(node3, node2),
(node1,node3)
]
- head1 = DependencySorter(tuples, [node1, node2, node3]).sort()
- head2 = DependencySorter(tuples, [node3, node1, node2]).sort()
- head3 = DependencySorter(tuples, [node3, node2, node1]).sort()
+ self._assert_sort(tuples, [node1, node2, node3])
+ self._assert_sort(tuples, [node3, node1, node2])
+ self._assert_sort(tuples, [node3, node2, node1])
- # TODO: figure out a "node == node2" function
- #self.assert_(str(head1) == str(head2) == str(head3))
- print "\n" + str(head1)
- print "\n" + str(head2)
- print "\n" + str(head3)
def testsort4(self):
node1 = thingy('keywords')
@@ -89,8 +104,7 @@ class DependencySortTest(PersistTest):
(node1, node3),
(node3, node2)
]
- head = DependencySorter(tuples, []).sort()
- print "\n" + str(head)
+ self._assert_sort(tuples, [node1,node2,node3,node4])
def testsort5(self):
# this one, depenending on the weather,
@@ -117,8 +131,21 @@ class DependencySortTest(PersistTest):
node3,
node4
]
- head = DependencySorter(tuples, allitems).sort()
- print "\n" + str(head)
+ self._assert_sort(tuples, allitems)
+
+ def testsort6(self):
+ #('tbl_c', 'tbl_d'), ('tbl_a', 'tbl_c'), ('tbl_b', 'tbl_d')
+ nodea = thingy('tbl_a')
+ nodeb = thingy('tbl_b')
+ nodec = thingy('tbl_c')
+ noded = thingy('tbl_d')
+ tuples = [
+ (nodec, noded),
+ (nodea, nodec),
+ (nodeb, noded)
+ ]
+ allitems = [nodea,nodeb,nodec,noded]
+ self._assert_sort(tuples, allitems)
def testcircular(self):
node1 = thingy('node1')
@@ -134,8 +161,7 @@ class DependencySortTest(PersistTest):
(node3, node1),
(node4, node1)
]
- head = DependencySorter(tuples, []).sort(allow_all_cycles=True)
- print "\n" + str(head)
+ self._assert_sort(tuples, [node1,node2,node3,node4,node5], allow_all_cycles=True)
if __name__ == "__main__":
diff --git a/test/relationships.py b/test/relationships.py
new file mode 100644
index 000000000..36f5fe3d7
--- /dev/null
+++ b/test/relationships.py
@@ -0,0 +1,99 @@
+"""Test complex relationships"""
+
+import testbase
+import unittest, sys, datetime
+
+db = testbase.db
+#db.echo_uow=True
+
+from sqlalchemy import *
+
+
+class RelationTest(testbase.PersistTest):
+ """this is essentially an extension of the "dependency.py" topological sort test. this exposes
+ a particular issue that doesnt always occur with the straight dependency tests, due to the nature
+ of the sort being different based on random conditions"""
+ def setUpAll(self):
+ testbase.db.tables.clear()
+ global tbl_a
+ global tbl_b
+ global tbl_c
+ global tbl_d
+ tbl_a = Table("tbl_a", db,
+ Column("id", Integer, primary_key=True),
+ Column("name", String),
+ )
+ tbl_b = Table("tbl_b", db,
+ Column("id", Integer, primary_key=True),
+ Column("name", String),
+ )
+ tbl_c = Table("tbl_c", db,
+ Column("id", Integer, primary_key=True),
+ Column("tbl_a_id", Integer, ForeignKey("tbl_a.id"), nullable=False),
+ Column("name", String),
+ )
+ tbl_d = Table("tbl_d", db,
+ Column("id", Integer, primary_key=True),
+ Column("tbl_c_id", Integer, ForeignKey("tbl_c.id"), nullable=False),
+ Column("tbl_b_id", Integer, ForeignKey("tbl_b.id")),
+ Column("name", String),
+ )
+ def setUp(self):
+ tbl_a.create()
+ tbl_b.create()
+ tbl_c.create()
+ tbl_d.create()
+
+ objectstore.clear()
+ clear_mappers()
+
+ class A(object):
+ pass
+ class B(object):
+ pass
+ class C(object):
+ pass
+ class D(object):
+ pass
+
+ D.mapper = mapper(D, tbl_d)
+ C.mapper = mapper(C, tbl_c, properties=dict(
+ d_rows=relation(D, private=True, backref="c_row"),
+ ))
+ B.mapper = mapper(B, tbl_b)
+ A.mapper = mapper(A, tbl_a, properties=dict(
+ c_rows=relation(C, private=True, backref="a_row"),
+ ))
+ D.mapper.add_property("b_row", relation(B))
+
+ global a
+ global c
+ a = A(); a.name = "a1"
+ b = B(); b.name = "b1"
+ c = C(); c.name = "c1"; c.a_row = a
+ # we must have more than one d row or it won't fail
+ d = D(); d.name = "d1"; d.b_row = b; d.c_row = c
+ d = D(); d.name = "d2"; d.b_row = b; d.c_row = c
+ d = D(); d.name = "d3"; d.b_row = b; d.c_row = c
+
+ def tearDown(self):
+ tbl_d.drop()
+ tbl_c.drop()
+ tbl_b.drop()
+ tbl_a.drop()
+
+ def testDeleteRootTable(self):
+ session = objectstore.get_session()
+ session.commit()
+ session.delete(a) # works as expected
+ session.commit()
+
+ def testDeleteMiddleTable(self):
+ session = objectstore.get_session()
+ session.commit()
+ session.delete(c) # fails
+ session.commit()
+
+
+if __name__ == "__main__":
+ testbase.main()