summaryrefslogtreecommitdiff
path: root/test/sql/generative.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/sql/generative.py')
-rw-r--r--test/sql/generative.py275
1 files changed, 275 insertions, 0 deletions
diff --git a/test/sql/generative.py b/test/sql/generative.py
new file mode 100644
index 000000000..357a66fcd
--- /dev/null
+++ b/test/sql/generative.py
@@ -0,0 +1,275 @@
+import testbase
+from sql import select as selecttests
+from sqlalchemy import *
+from testlib import *
+
+class TraversalTest(AssertMixin):
+ """test ClauseVisitor's traversal, particularly its ability to copy and modify
+ a ClauseElement in place."""
+
+ def setUpAll(self):
+ global A, B
+
+ # establish two ficticious ClauseElements.
+ # define deep equality semantics as well as deep identity semantics.
+ class A(ClauseElement):
+ def __init__(self, expr):
+ self.expr = expr
+
+ def is_other(self, other):
+ return other is self
+
+ def __eq__(self, other):
+ return other.expr == self.expr
+
+ def __ne__(self, other):
+ return other.expr != self.expr
+
+ def __str__(self):
+ return "A(%s)" % repr(self.expr)
+
+ class B(ClauseElement):
+ def __init__(self, *items):
+ self.items = items
+
+ def is_other(self, other):
+ if other is not self:
+ return False
+ for i1, i2 in zip(self.items, other.items):
+ if i1 is not i2:
+ return False
+ return True
+
+ def __eq__(self, other):
+ for i1, i2 in zip(self.items, other.items):
+ if i1 != i2:
+ return False
+ return True
+
+ def __ne__(self, other):
+ for i1, i2 in zip(self.items, other.items):
+ if i1 != i2:
+ return True
+ return False
+
+ def _copy_internals(self):
+ self.items = [i._clone() for i in self.items]
+
+ def get_children(self, **kwargs):
+ return self.items
+
+ def __str__(self):
+ return "B(%s)" % repr([str(i) for i in self.items])
+
+ def test_test_classes(self):
+ a1 = A("expr1")
+ struct = B(a1, A("expr2"), B(A("expr1b"), A("expr2b")), A("expr3"))
+ struct2 = B(a1, A("expr2"), B(A("expr1b"), A("expr2b")), A("expr3"))
+ struct3 = B(a1, A("expr2"), B(A("expr1b"), A("expr2bmodified")), A("expr3"))
+
+ assert a1.is_other(a1)
+ assert struct.is_other(struct)
+ assert struct == struct2
+ assert struct != struct3
+ assert not struct.is_other(struct2)
+ assert not struct.is_other(struct3)
+
+ def test_clone(self):
+ struct = B(A("expr1"), A("expr2"), B(A("expr1b"), A("expr2b")), A("expr3"))
+
+ class Vis(ClauseVisitor):
+ def visit_a(self, a):
+ pass
+ def visit_b(self, b):
+ pass
+
+ vis = Vis()
+ s2 = vis.traverse(struct, clone=True)
+ assert struct == s2
+ assert not struct.is_other(s2)
+
+ def test_no_clone(self):
+ struct = B(A("expr1"), A("expr2"), B(A("expr1b"), A("expr2b")), A("expr3"))
+
+ class Vis(ClauseVisitor):
+ def visit_a(self, a):
+ pass
+ def visit_b(self, b):
+ pass
+
+ vis = Vis()
+ s2 = vis.traverse(struct, clone=False)
+ assert struct == s2
+ assert struct.is_other(s2)
+
+ def test_change_in_place(self):
+ struct = B(A("expr1"), A("expr2"), B(A("expr1b"), A("expr2b")), A("expr3"))
+ struct2 = B(A("expr1"), A("expr2modified"), B(A("expr1b"), A("expr2b")), A("expr3"))
+ struct3 = B(A("expr1"), A("expr2"), B(A("expr1b"), A("expr2bmodified")), A("expr3"))
+
+ class Vis(ClauseVisitor):
+ def visit_a(self, a):
+ if a.expr == "expr2":
+ a.expr = "expr2modified"
+ def visit_b(self, b):
+ pass
+
+ vis = Vis()
+ s2 = vis.traverse(struct, clone=True)
+ assert struct != s2
+ assert not struct.is_other(s2)
+ assert struct2 == s2
+
+ class Vis2(ClauseVisitor):
+ def visit_a(self, a):
+ if a.expr == "expr2b":
+ a.expr = "expr2bmodified"
+ def visit_b(self, b):
+ pass
+
+ vis2 = Vis2()
+ s3 = vis2.traverse(struct, clone=True)
+ assert struct != s3
+ assert struct3 == s3
+
+class ClauseTest(selecttests.SQLTest):
+ """test copy-in-place behavior of various ClauseElements."""
+
+ def setUpAll(self):
+ global t1, t2
+ t1 = table("table1",
+ column("col1"),
+ column("col2"),
+ column("col3"),
+ )
+ t2 = table("table2",
+ column("col1"),
+ column("col2"),
+ column("col3"),
+ )
+
+ def test_binary(self):
+ clause = t1.c.col2 == t2.c.col2
+ assert str(clause) == ClauseVisitor().traverse(clause, clone=True)
+
+ def test_join(self):
+ clause = t1.join(t2, t1.c.col2==t2.c.col2)
+ c1 = str(clause)
+ assert str(clause) == str(ClauseVisitor().traverse(clause, clone=True))
+
+ class Vis(ClauseVisitor):
+ def visit_binary(self, binary):
+ binary.right = t2.c.col3
+
+ clause2 = Vis().traverse(clause, clone=True)
+ assert c1 == str(clause)
+ assert str(clause2) == str(t1.join(t2, t1.c.col2==t2.c.col3))
+
+ def test_select(self):
+ s = t1.select()
+ s2 = select([s])
+ s2_assert = str(s2)
+ s3_assert = str(select([t1.select()], t1.c.col2==7))
+ class Vis(ClauseVisitor):
+ def visit_select(self, select):
+ select.append_whereclause(t1.c.col2==7)
+ s3 = Vis().traverse(s2, clone=True)
+ assert str(s3) == s3_assert
+ assert str(s2) == s2_assert
+ print str(s2)
+ print str(s3)
+ Vis().traverse(s2)
+ assert str(s2) == s3_assert
+
+ print "------------------"
+
+ s4_assert = str(select([t1.select()], and_(t1.c.col2==7, t1.c.col3==9)))
+ class Vis(ClauseVisitor):
+ def visit_select(self, select):
+ select.append_whereclause(t1.c.col3==9)
+ s4 = Vis().traverse(s3, clone=True)
+ print str(s3)
+ print str(s4)
+ assert str(s4) == s4_assert
+ assert str(s3) == s3_assert
+
+ print "------------------"
+ s5_assert = str(select([t1.select()], and_(t1.c.col2==7, t1.c.col1==9)))
+ class Vis(ClauseVisitor):
+ def visit_binary(self, binary):
+ if binary.left is t1.c.col3:
+ binary.left = t1.c.col1
+ binary.right = bindparam("table1_col1")
+ s5 = Vis().traverse(s4, clone=True)
+ print str(s4)
+ print str(s5)
+ assert str(s5) == s5_assert
+ assert str(s4) == s4_assert
+
+ def test_correlated_select(self):
+ s = select(['*'], t1.c.col1==t2.c.col1, from_obj=[t1, t2]).correlate(t2)
+ class Vis(ClauseVisitor):
+ def visit_select(self, select):
+ select.append_whereclause(t1.c.col2==7)
+
+ self.runtest(Vis().traverse(s, clone=True), "SELECT * FROM table1 WHERE table1.col1 = table2.col1 AND table1.col2 = :table1_col2")
+
+ def test_clause_adapter(self):
+ from sqlalchemy import sql_util
+
+ t1alias = t1.alias('t1alias')
+
+ vis = sql_util.ClauseAdapter(t1alias)
+ ff = vis.traverse(func.count(t1.c.col1).label('foo'), clone=True)
+ assert ff._get_from_objects() == [t1alias]
+
+ self.runtest(vis.traverse(select(['*'], from_obj=[t1]), clone=True), "SELECT * FROM table1 AS t1alias")
+ self.runtest(vis.traverse(select(['*'], t1.c.col1==t2.c.col2), clone=True), "SELECT * FROM table1 AS t1alias, table2 WHERE t1alias.col1 = table2.col2")
+ self.runtest(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]), clone=True), "SELECT * FROM table1 AS t1alias, table2 WHERE t1alias.col1 = table2.col2")
+ self.runtest(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t1), clone=True), "SELECT * FROM table2 WHERE t1alias.col1 = table2.col2")
+ self.runtest(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t2), clone=True), "SELECT * FROM table1 AS t1alias WHERE t1alias.col1 = table2.col2")
+
+ ff = vis.traverse(func.count(t1.c.col1).label('foo'), clone=True)
+ self.runtest(ff, "count(t1alias.col1) AS foo")
+ assert ff._get_from_objects() == [t1alias]
+
+# TODO:
+# self.runtest(vis.traverse(select([func.count(t1.c.col1).label('foo')]), clone=True), "SELECT count(t1alias.col1) AS foo FROM table1 AS t1alias")
+
+ t2alias = t2.alias('t2alias')
+ vis.chain(sql_util.ClauseAdapter(t2alias))
+ self.runtest(vis.traverse(select(['*'], t1.c.col1==t2.c.col2), clone=True), "SELECT * FROM table1 AS t1alias, table2 AS t2alias WHERE t1alias.col1 = t2alias.col2")
+ self.runtest(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]), clone=True), "SELECT * FROM table1 AS t1alias, table2 AS t2alias WHERE t1alias.col1 = t2alias.col2")
+ self.runtest(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t1), clone=True), "SELECT * FROM table2 AS t2alias WHERE t1alias.col1 = t2alias.col2")
+ self.runtest(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t2), clone=True), "SELECT * FROM table1 AS t1alias WHERE t1alias.col1 = t2alias.col2")
+
+
+
+class SelectTest(selecttests.SQLTest):
+ """tests the generative capability of Select"""
+
+ def setUpAll(self):
+ global t1, t2
+ t1 = table("table1",
+ column("col1"),
+ column("col2"),
+ column("col3"),
+ )
+ t2 = table("table2",
+ column("col1"),
+ column("col2"),
+ column("col3"),
+ )
+
+ def test_select(self):
+ self.runtest(t1.select().where(t1.c.col1==5).order_by(t1.c.col3), "SELECT table1.col1, table1.col2, table1.col3 FROM table1 WHERE table1.col1 = :table1_col1 ORDER BY table1.col3")
+
+ self.runtest(t1.select().select_from(select([t2], t2.c.col1==t1.c.col1)).order_by(t1.c.col3), "SELECT table1.col1, table1.col2, table1.col3 FROM table1, (SELECT table2.col1 AS col1, table2.col2 AS col2, table2.col3 AS col3 FROM table2 WHERE table2.col1 = table1.col1) ORDER BY table1.col3")
+
+ s = select([t2], t2.c.col1==t1.c.col1, correlate=False)
+ s = s.correlate(t1).order_by(t2.c.col3)
+ self.runtest(t1.select().select_from(s).order_by(t1.c.col3), "SELECT table1.col1, table1.col2, table1.col3 FROM table1, (SELECT table2.col1 AS col1, table2.col2 AS col2, table2.col3 AS col3 FROM table2 WHERE table2.col1 = table1.col1 ORDER BY table2.col3) ORDER BY table1.col3")
+
+
+if __name__ == '__main__':
+ testbase.main()