diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2007-07-27 04:08:53 +0000 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2007-07-27 04:08:53 +0000 |
commit | ed4fc64bb0ac61c27bc4af32962fb129e74a36bf (patch) | |
tree | c1cf2fb7b1cafced82a8898e23d2a0bf5ced8526 /test/sql | |
parent | 3a8e235af64e36b3b711df1f069d32359fe6c967 (diff) | |
download | sqlalchemy-ed4fc64bb0ac61c27bc4af32962fb129e74a36bf.tar.gz |
merging 0.4 branch to trunk. see CHANGES for details. 0.3 moves to maintenance branch in branches/rel_0_3.
Diffstat (limited to 'test/sql')
-rw-r--r-- | test/sql/alltests.py | 4 | ||||
-rw-r--r-- | test/sql/case_statement.py | 16 | ||||
-rw-r--r-- | test/sql/constraints.py | 11 | ||||
-rw-r--r-- | test/sql/defaults.py | 129 | ||||
-rw-r--r-- | test/sql/generative.py | 275 | ||||
-rw-r--r-- | test/sql/labels.py | 18 | ||||
-rw-r--r-- | test/sql/query.py | 246 | ||||
-rw-r--r-- | test/sql/quote.py | 5 | ||||
-rw-r--r-- | test/sql/rowcount.py | 6 | ||||
-rw-r--r-- | test/sql/select.py | 188 | ||||
-rwxr-xr-x | test/sql/selectable.py | 32 | ||||
-rw-r--r-- | test/sql/testtypes.py | 131 | ||||
-rw-r--r-- | test/sql/unicode.py | 56 |
13 files changed, 792 insertions, 325 deletions
diff --git a/test/sql/alltests.py b/test/sql/alltests.py index 7be1a3ffb..a669a25f2 100644 --- a/test/sql/alltests.py +++ b/test/sql/alltests.py @@ -7,6 +7,8 @@ def suite(): 'sql.testtypes', 'sql.constraints', + 'sql.generative', + # SQL syntax 'sql.select', 'sql.selectable', @@ -30,7 +32,5 @@ def suite(): alltests.addTest(unittest.findTestCases(mod, suiteClass=None)) return alltests - - if __name__ == '__main__': testbase.main(suite()) diff --git a/test/sql/case_statement.py b/test/sql/case_statement.py index 946279b9d..493545b22 100644 --- a/test/sql/case_statement.py +++ b/test/sql/case_statement.py @@ -1,13 +1,15 @@ -import sys import testbase +import sys from sqlalchemy import * +from testlib import * -class CaseTest(testbase.PersistTest): +class CaseTest(PersistTest): def setUpAll(self): + metadata = MetaData(testbase.db) global info_table - info_table = Table('infos', testbase.db, + info_table = Table('infos', metadata, Column('pk', Integer, primary_key=True), Column('info', String(30))) @@ -26,9 +28,9 @@ class CaseTest(testbase.PersistTest): def testcase(self): inner = select([case([ [info_table.c.pk < 3, - literal('lessthan3', type=String)], + literal('lessthan3', type_=String)], [and_(info_table.c.pk >= 3, info_table.c.pk < 7), - literal('gt3', type=String)]]).label('x'), + literal('gt3', type_=String)]]).label('x'), info_table.c.pk, info_table.c.info], from_obj=[info_table]).alias('q_inner') @@ -65,9 +67,9 @@ class CaseTest(testbase.PersistTest): w_else = select([case([ [info_table.c.pk < 3, - literal(3, type=Integer)], + literal(3, type_=Integer)], [and_(info_table.c.pk >= 3, info_table.c.pk < 6), - literal(6, type=Integer)]], + literal(6, type_=Integer)]], else_ = 0).label('x'), info_table.c.pk, info_table.c.info], from_obj=[info_table]).alias('q_inner') diff --git a/test/sql/constraints.py b/test/sql/constraints.py index 7e1172850..3120185d5 100644 --- a/test/sql/constraints.py +++ b/test/sql/constraints.py @@ -1,8 +1,8 @@ import testbase from sqlalchemy import * -import sys +from testlib import * -class ConstraintTest(testbase.AssertMixin): +class ConstraintTest(AssertMixin): def setUp(self): global metadata @@ -52,7 +52,7 @@ class ConstraintTest(testbase.AssertMixin): ) metadata.create_all() - @testbase.unsupported('mysql') + @testing.unsupported('mysql') def test_check_constraint(self): foo = Table('foo', metadata, Column('id', Integer, primary_key=True), @@ -172,12 +172,13 @@ class ConstraintTest(testbase.AssertMixin): capt = [] connection = testbase.db.connect() - ex = connection._execute + # TODO: hacky, put a real connection proxy in + ex = connection._Connection__execute def proxy(context): capt.append(context.statement) capt.append(repr(context.parameters)) ex(context) - connection._execute = proxy + connection._Connection__execute = proxy schemagen = testbase.db.dialect.schemagenerator(connection) schemagen.traverse(events) diff --git a/test/sql/defaults.py b/test/sql/defaults.py index 10a3610f9..6c200232f 100644 --- a/test/sql/defaults.py +++ b/test/sql/defaults.py @@ -1,52 +1,59 @@ -from testbase import PersistTest -import sqlalchemy.util as util -import unittest, sys, os -import sqlalchemy.schema as schema import testbase from sqlalchemy import * -import sqlalchemy - -db = testbase.db +import sqlalchemy.util as util +import sqlalchemy.schema as schema +from sqlalchemy.orm import mapper, create_session +from testlib import * +import datetime class DefaultTest(PersistTest): def setUpAll(self): - global t, f, f2, ts, currenttime + global t, f, f2, ts, currenttime, metadata + + db = testbase.db + metadata = MetaData(db) x = {'x':50} def mydefault(): x['x'] += 1 return x['x'] + def mydefault_with_ctx(ctx): + return ctx.compiled_parameters['col1'] + 10 + + def myupdate_with_ctx(ctx): + return len(ctx.compiled_parameters['col2']) + use_function_defaults = db.engine.name == 'postgres' or db.engine.name == 'oracle' is_oracle = db.engine.name == 'oracle' # select "count(1)" returns different results on different DBs # also correct for "current_date" compatible as column default, value differences - currenttime = func.current_date(type=Date, engine=db); + currenttime = func.current_date(type_=Date, bind=db); if is_oracle: ts = db.func.trunc(func.sysdate(), literal_column("'DAY'")).scalar() - f = select([func.count(1) + 5], engine=db).scalar() - f2 = select([func.count(1) + 14], engine=db).scalar() + f = select([func.count(1) + 5], bind=db).scalar() + f2 = select([func.count(1) + 14], bind=db).scalar() # TODO: engine propigation across nested functions not working - currenttime = func.trunc(currenttime, literal_column("'DAY'"), engine=db) + currenttime = func.trunc(currenttime, literal_column("'DAY'"), bind=db) def1 = currenttime def2 = func.trunc(text("sysdate"), literal_column("'DAY'")) deftype = Date elif use_function_defaults: - f = select([func.count(1) + 5], engine=db).scalar() - f2 = select([func.count(1) + 14], engine=db).scalar() + f = select([func.count(1) + 5], bind=db).scalar() + f2 = select([func.count(1) + 14], bind=db).scalar() def1 = currenttime def2 = text("current_date") deftype = Date ts = db.func.current_date().scalar() else: - f = select([func.count(1) + 5], engine=db).scalar() - f2 = select([func.count(1) + 14], engine=db).scalar() + f = select([func.count(1) + 5], bind=db).scalar() + f2 = select([func.count(1) + 14], bind=db).scalar() def1 = def2 = "3" ts = 3 deftype = Integer - t = Table('default_test1', db, + t = Table('default_test1', metadata, # python function Column('col1', Integer, primary_key=True, default=mydefault), @@ -66,7 +73,13 @@ class DefaultTest(PersistTest): Column('col6', Date, default=currenttime, onupdate=currenttime), Column('boolcol1', Boolean, default=True), - Column('boolcol2', Boolean, default=False) + Column('boolcol2', Boolean, default=False), + + # python function which uses ExecutionContext + Column('col7', Integer, default=mydefault_with_ctx, onupdate=myupdate_with_ctx), + + # python builtin + Column('col8', Date, default=datetime.date.today, onupdate=datetime.date.today) ) t.create() @@ -75,9 +88,18 @@ class DefaultTest(PersistTest): def tearDown(self): t.delete().execute() - + + def testargsignature(self): + def mydefault(x, y): + pass + try: + c = ColumnDefault(mydefault) + assert False + except exceptions.ArgumentError, e: + assert str(e) == "ColumnDefault Python function takes zero or one positional arguments", str(e) + def teststandalone(self): - c = db.engine.contextual_connect() + c = testbase.db.engine.contextual_connect() x = c.execute(t.c.col1.default) y = t.c.col2.default.execute() z = c.execute(t.c.col3.default) @@ -94,9 +116,10 @@ class DefaultTest(PersistTest): t.insert().execute() ctexec = currenttime.scalar() - self.echo("Currenttime "+ repr(ctexec)) + print "Currenttime "+ repr(ctexec) l = t.select().execute() - self.assert_(l.fetchall() == [(51, 'imthedefault', f, ts, ts, ctexec, True, False), (52, 'imthedefault', f, ts, ts, ctexec, True, False), (53, 'imthedefault', f, ts, ts, ctexec, True, False)]) + today = datetime.date.today() + self.assert_(l.fetchall() == [(51, 'imthedefault', f, ts, ts, ctexec, True, False, 61, today), (52, 'imthedefault', f, ts, ts, ctexec, True, False, 62, today), (53, 'imthedefault', f, ts, ts, ctexec, True, False, 63, today)]) def testinsertvalues(self): t.insert(values={'col3':50}).execute() @@ -109,10 +132,10 @@ class DefaultTest(PersistTest): pk = r.last_inserted_ids()[0] t.update(t.c.col1==pk).execute(col4=None, col5=None) ctexec = currenttime.scalar() - self.echo("Currenttime "+ repr(ctexec)) + print "Currenttime "+ repr(ctexec) l = t.select(t.c.col1==pk).execute() l = l.fetchone() - self.assert_(l == (pk, 'im the update', f2, None, None, ctexec, True, False)) + self.assert_(l == (pk, 'im the update', f2, None, None, ctexec, True, False, 13, datetime.date.today())) # mysql/other db's return 0 or 1 for count(1) self.assert_(14 <= f2 <= 15) @@ -124,8 +147,35 @@ class DefaultTest(PersistTest): l = l.fetchone() self.assert_(l['col3'] == 55) + @testing.supported('postgres') + def testpassiveoverride(self): + """primarily for postgres, tests that when we get a primary key column back + from reflecting a table which has a default value on it, we pre-execute + that PassiveDefault upon insert, even though PassiveDefault says + "let the database execute this", because in postgres we must have all the primary + key values in memory before insert; otherwise we cant locate the just inserted row.""" + + try: + meta = MetaData(testbase.db) + testbase.db.execute(""" + CREATE TABLE speedy_users + ( + speedy_user_id SERIAL PRIMARY KEY, + + user_name VARCHAR NOT NULL, + user_password VARCHAR NOT NULL + ); + """, None) + + t = Table("speedy_users", meta, autoload=True) + t.insert().execute(user_name='user', user_password='lala') + l = t.select().execute().fetchall() + self.assert_(l == [(1, 'user', 'lala')]) + finally: + testbase.db.execute("drop table speedy_users", None) + class AutoIncrementTest(PersistTest): - @testbase.supported('postgres', 'mysql') + @testing.supported('postgres', 'mysql') def testnonautoincrement(self): meta = MetaData(testbase.db) nonai_table = Table("aitest", meta, @@ -159,6 +209,9 @@ class AutoIncrementTest(PersistTest): table.drop() def testfetchid(self): + + # TODO: what does this test do that all the various ORM tests dont ? + meta = MetaData(testbase.db) table = Table("aitest", meta, Column('id', Integer, primary_key=True), @@ -186,7 +239,7 @@ class AutoIncrementTest(PersistTest): class SequenceTest(PersistTest): - @testbase.supported('postgres', 'oracle') + @testing.supported('postgres', 'oracle') def setUpAll(self): global cartitems, sometable, metadata metadata = MetaData(testbase.db) @@ -197,13 +250,13 @@ class SequenceTest(PersistTest): ) sometable = Table( 'Manager', metadata, Column( 'obj_id', Integer, Sequence('obj_id_seq'), ), - Column( 'name', type= String, ), + Column( 'name', String, ), Column( 'id', Integer, primary_key= True, ), ) metadata.create_all() - @testbase.supported('postgres', 'oracle') + @testing.supported('postgres', 'oracle') def testseqnonpk(self): """test sequences fire off as defaults on non-pk columns""" sometable.insert().execute(name="somename") @@ -213,7 +266,7 @@ class SequenceTest(PersistTest): (2, "someother", 2), ] - @testbase.supported('postgres', 'oracle') + @testing.supported('postgres', 'oracle') def testsequence(self): cartitems.insert().execute(description='hi') cartitems.insert().execute(description='there') @@ -222,8 +275,8 @@ class SequenceTest(PersistTest): cartitems.select().execute().fetchall() - @testbase.supported('postgres', 'oracle') - def teststandalone(self): + @testing.supported('postgres', 'oracle') + def test_implicit_sequence_exec(self): s = Sequence("my_sequence", metadata=MetaData(testbase.db)) s.create() try: @@ -232,7 +285,7 @@ class SequenceTest(PersistTest): finally: s.drop() - @testbase.supported('postgres', 'oracle') + @testing.supported('postgres', 'oracle') def teststandalone_explicit(self): s = Sequence("my_sequence") s.create(bind=testbase.db) @@ -242,12 +295,20 @@ class SequenceTest(PersistTest): finally: s.drop(testbase.db) - @testbase.supported('postgres', 'oracle') + @testing.supported('postgres', 'oracle') + def test_checkfirst(self): + s = Sequence("my_sequence") + s.create(testbase.db, checkfirst=False) + s.create(testbase.db, checkfirst=True) + s.drop(testbase.db, checkfirst=False) + s.drop(testbase.db, checkfirst=True) + + @testing.supported('postgres', 'oracle') def teststandalone2(self): x = cartitems.c.cart_id.sequence.execute() self.assert_(1 <= x <= 4) - @testbase.supported('postgres', 'oracle') + @testing.supported('postgres', 'oracle') def tearDownAll(self): metadata.drop_all() diff --git a/test/sql/generative.py b/test/sql/generative.py new file mode 100644 index 000000000..357a66fcd --- /dev/null +++ b/test/sql/generative.py @@ -0,0 +1,275 @@ +import testbase +from sql import select as selecttests +from sqlalchemy import * +from testlib import * + +class TraversalTest(AssertMixin): + """test ClauseVisitor's traversal, particularly its ability to copy and modify + a ClauseElement in place.""" + + def setUpAll(self): + global A, B + + # establish two ficticious ClauseElements. + # define deep equality semantics as well as deep identity semantics. + class A(ClauseElement): + def __init__(self, expr): + self.expr = expr + + def is_other(self, other): + return other is self + + def __eq__(self, other): + return other.expr == self.expr + + def __ne__(self, other): + return other.expr != self.expr + + def __str__(self): + return "A(%s)" % repr(self.expr) + + class B(ClauseElement): + def __init__(self, *items): + self.items = items + + def is_other(self, other): + if other is not self: + return False + for i1, i2 in zip(self.items, other.items): + if i1 is not i2: + return False + return True + + def __eq__(self, other): + for i1, i2 in zip(self.items, other.items): + if i1 != i2: + return False + return True + + def __ne__(self, other): + for i1, i2 in zip(self.items, other.items): + if i1 != i2: + return True + return False + + def _copy_internals(self): + self.items = [i._clone() for i in self.items] + + def get_children(self, **kwargs): + return self.items + + def __str__(self): + return "B(%s)" % repr([str(i) for i in self.items]) + + def test_test_classes(self): + a1 = A("expr1") + struct = B(a1, A("expr2"), B(A("expr1b"), A("expr2b")), A("expr3")) + struct2 = B(a1, A("expr2"), B(A("expr1b"), A("expr2b")), A("expr3")) + struct3 = B(a1, A("expr2"), B(A("expr1b"), A("expr2bmodified")), A("expr3")) + + assert a1.is_other(a1) + assert struct.is_other(struct) + assert struct == struct2 + assert struct != struct3 + assert not struct.is_other(struct2) + assert not struct.is_other(struct3) + + def test_clone(self): + struct = B(A("expr1"), A("expr2"), B(A("expr1b"), A("expr2b")), A("expr3")) + + class Vis(ClauseVisitor): + def visit_a(self, a): + pass + def visit_b(self, b): + pass + + vis = Vis() + s2 = vis.traverse(struct, clone=True) + assert struct == s2 + assert not struct.is_other(s2) + + def test_no_clone(self): + struct = B(A("expr1"), A("expr2"), B(A("expr1b"), A("expr2b")), A("expr3")) + + class Vis(ClauseVisitor): + def visit_a(self, a): + pass + def visit_b(self, b): + pass + + vis = Vis() + s2 = vis.traverse(struct, clone=False) + assert struct == s2 + assert struct.is_other(s2) + + def test_change_in_place(self): + struct = B(A("expr1"), A("expr2"), B(A("expr1b"), A("expr2b")), A("expr3")) + struct2 = B(A("expr1"), A("expr2modified"), B(A("expr1b"), A("expr2b")), A("expr3")) + struct3 = B(A("expr1"), A("expr2"), B(A("expr1b"), A("expr2bmodified")), A("expr3")) + + class Vis(ClauseVisitor): + def visit_a(self, a): + if a.expr == "expr2": + a.expr = "expr2modified" + def visit_b(self, b): + pass + + vis = Vis() + s2 = vis.traverse(struct, clone=True) + assert struct != s2 + assert not struct.is_other(s2) + assert struct2 == s2 + + class Vis2(ClauseVisitor): + def visit_a(self, a): + if a.expr == "expr2b": + a.expr = "expr2bmodified" + def visit_b(self, b): + pass + + vis2 = Vis2() + s3 = vis2.traverse(struct, clone=True) + assert struct != s3 + assert struct3 == s3 + +class ClauseTest(selecttests.SQLTest): + """test copy-in-place behavior of various ClauseElements.""" + + def setUpAll(self): + global t1, t2 + t1 = table("table1", + column("col1"), + column("col2"), + column("col3"), + ) + t2 = table("table2", + column("col1"), + column("col2"), + column("col3"), + ) + + def test_binary(self): + clause = t1.c.col2 == t2.c.col2 + assert str(clause) == ClauseVisitor().traverse(clause, clone=True) + + def test_join(self): + clause = t1.join(t2, t1.c.col2==t2.c.col2) + c1 = str(clause) + assert str(clause) == str(ClauseVisitor().traverse(clause, clone=True)) + + class Vis(ClauseVisitor): + def visit_binary(self, binary): + binary.right = t2.c.col3 + + clause2 = Vis().traverse(clause, clone=True) + assert c1 == str(clause) + assert str(clause2) == str(t1.join(t2, t1.c.col2==t2.c.col3)) + + def test_select(self): + s = t1.select() + s2 = select([s]) + s2_assert = str(s2) + s3_assert = str(select([t1.select()], t1.c.col2==7)) + class Vis(ClauseVisitor): + def visit_select(self, select): + select.append_whereclause(t1.c.col2==7) + s3 = Vis().traverse(s2, clone=True) + assert str(s3) == s3_assert + assert str(s2) == s2_assert + print str(s2) + print str(s3) + Vis().traverse(s2) + assert str(s2) == s3_assert + + print "------------------" + + s4_assert = str(select([t1.select()], and_(t1.c.col2==7, t1.c.col3==9))) + class Vis(ClauseVisitor): + def visit_select(self, select): + select.append_whereclause(t1.c.col3==9) + s4 = Vis().traverse(s3, clone=True) + print str(s3) + print str(s4) + assert str(s4) == s4_assert + assert str(s3) == s3_assert + + print "------------------" + s5_assert = str(select([t1.select()], and_(t1.c.col2==7, t1.c.col1==9))) + class Vis(ClauseVisitor): + def visit_binary(self, binary): + if binary.left is t1.c.col3: + binary.left = t1.c.col1 + binary.right = bindparam("table1_col1") + s5 = Vis().traverse(s4, clone=True) + print str(s4) + print str(s5) + assert str(s5) == s5_assert + assert str(s4) == s4_assert + + def test_correlated_select(self): + s = select(['*'], t1.c.col1==t2.c.col1, from_obj=[t1, t2]).correlate(t2) + class Vis(ClauseVisitor): + def visit_select(self, select): + select.append_whereclause(t1.c.col2==7) + + self.runtest(Vis().traverse(s, clone=True), "SELECT * FROM table1 WHERE table1.col1 = table2.col1 AND table1.col2 = :table1_col2") + + def test_clause_adapter(self): + from sqlalchemy import sql_util + + t1alias = t1.alias('t1alias') + + vis = sql_util.ClauseAdapter(t1alias) + ff = vis.traverse(func.count(t1.c.col1).label('foo'), clone=True) + assert ff._get_from_objects() == [t1alias] + + self.runtest(vis.traverse(select(['*'], from_obj=[t1]), clone=True), "SELECT * FROM table1 AS t1alias") + self.runtest(vis.traverse(select(['*'], t1.c.col1==t2.c.col2), clone=True), "SELECT * FROM table1 AS t1alias, table2 WHERE t1alias.col1 = table2.col2") + self.runtest(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]), clone=True), "SELECT * FROM table1 AS t1alias, table2 WHERE t1alias.col1 = table2.col2") + self.runtest(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t1), clone=True), "SELECT * FROM table2 WHERE t1alias.col1 = table2.col2") + self.runtest(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t2), clone=True), "SELECT * FROM table1 AS t1alias WHERE t1alias.col1 = table2.col2") + + ff = vis.traverse(func.count(t1.c.col1).label('foo'), clone=True) + self.runtest(ff, "count(t1alias.col1) AS foo") + assert ff._get_from_objects() == [t1alias] + +# TODO: +# self.runtest(vis.traverse(select([func.count(t1.c.col1).label('foo')]), clone=True), "SELECT count(t1alias.col1) AS foo FROM table1 AS t1alias") + + t2alias = t2.alias('t2alias') + vis.chain(sql_util.ClauseAdapter(t2alias)) + self.runtest(vis.traverse(select(['*'], t1.c.col1==t2.c.col2), clone=True), "SELECT * FROM table1 AS t1alias, table2 AS t2alias WHERE t1alias.col1 = t2alias.col2") + self.runtest(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]), clone=True), "SELECT * FROM table1 AS t1alias, table2 AS t2alias WHERE t1alias.col1 = t2alias.col2") + self.runtest(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t1), clone=True), "SELECT * FROM table2 AS t2alias WHERE t1alias.col1 = t2alias.col2") + self.runtest(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t2), clone=True), "SELECT * FROM table1 AS t1alias WHERE t1alias.col1 = t2alias.col2") + + + +class SelectTest(selecttests.SQLTest): + """tests the generative capability of Select""" + + def setUpAll(self): + global t1, t2 + t1 = table("table1", + column("col1"), + column("col2"), + column("col3"), + ) + t2 = table("table2", + column("col1"), + column("col2"), + column("col3"), + ) + + def test_select(self): + self.runtest(t1.select().where(t1.c.col1==5).order_by(t1.c.col3), "SELECT table1.col1, table1.col2, table1.col3 FROM table1 WHERE table1.col1 = :table1_col1 ORDER BY table1.col3") + + self.runtest(t1.select().select_from(select([t2], t2.c.col1==t1.c.col1)).order_by(t1.c.col3), "SELECT table1.col1, table1.col2, table1.col3 FROM table1, (SELECT table2.col1 AS col1, table2.col2 AS col2, table2.col3 AS col3 FROM table2 WHERE table2.col1 = table1.col1) ORDER BY table1.col3") + + s = select([t2], t2.c.col1==t1.c.col1, correlate=False) + s = s.correlate(t1).order_by(t2.c.col3) + self.runtest(t1.select().select_from(s).order_by(t1.c.col3), "SELECT table1.col1, table1.col2, table1.col3 FROM table1, (SELECT table2.col1 AS col1, table2.col2 AS col2, table2.col3 AS col3 FROM table2 WHERE table2.col1 = table1.col1 ORDER BY table2.col3) ORDER BY table1.col3") + + +if __name__ == '__main__': + testbase.main() diff --git a/test/sql/labels.py b/test/sql/labels.py index ee9fa6bc5..553a3a3bc 100644 --- a/test/sql/labels.py +++ b/test/sql/labels.py @@ -1,11 +1,12 @@ import testbase - from sqlalchemy import * +from testlib import * + # TODO: either create a mock dialect with named paramstyle and a short identifier length, # or find a way to just use sqlite dialect and make those changes -class LabelTypeTest(testbase.PersistTest): +class LabelTypeTest(PersistTest): def test_type(self): m = MetaData() t = Table('sometable', m, @@ -14,21 +15,26 @@ class LabelTypeTest(testbase.PersistTest): assert isinstance(t.c.col1.label('hi').type, Integer) assert isinstance(select([t.c.col2], scalar=True).label('lala').type, Float) -class LongLabelsTest(testbase.PersistTest): +class LongLabelsTest(PersistTest): def setUpAll(self): - global metadata, table1 - metadata = MetaData(engine=testbase.db) + global metadata, table1, maxlen + metadata = MetaData(testbase.db) table1 = Table("some_large_named_table", metadata, Column("this_is_the_primarykey_column", Integer, Sequence("this_is_some_large_seq"), primary_key=True), Column("this_is_the_data_column", String(30)) ) metadata.create_all() + + maxlen = testbase.db.dialect.max_identifier_length + testbase.db.dialect.max_identifier_length = lambda: 29 + def tearDown(self): table1.delete().execute() def tearDownAll(self): metadata.drop_all() + testbase.db.dialect.max_identifier_length = maxlen def test_result(self): table1.insert().execute(**{"this_is_the_primarykey_column":1, "this_is_the_data_column":"data1"}) @@ -88,7 +94,7 @@ class LongLabelsTest(testbase.PersistTest): x = select([tt], use_labels=True, order_by=tt.oid_column).compile(dialect=dialect) #print x # assert it doesnt end with "ORDER BY foo.some_large_named_table_this_is_the_primarykey_column" - assert str(x).endswith("""ORDER BY foo.some_large_named_table_t_1""") + assert str(x).endswith("""ORDER BY foo.some_large_named_table_t_2""") if __name__ == '__main__': testbase.main() diff --git a/test/sql/query.py b/test/sql/query.py index 8af5aafea..48a28a9a5 100644 --- a/test/sql/query.py +++ b/test/sql/query.py @@ -1,13 +1,9 @@ -from testbase import PersistTest import testbase -import unittest, sys, datetime - -import sqlalchemy.databases.sqlite as sqllite - -import tables +import datetime from sqlalchemy import * -from sqlalchemy.engine import ResultProxy, RowProxy from sqlalchemy import exceptions +from testlib import * + class QueryTest(PersistTest): @@ -24,25 +20,24 @@ class QueryTest(PersistTest): Column('address', String(30))) metadata.create_all() - def setUp(self): - self.users = users def tearDown(self): - self.users.delete().execute() + addresses.delete().execute() + users.delete().execute() def tearDownAll(self): metadata.drop_all() def testinsert(self): - self.users.insert().execute(user_id = 7, user_name = 'jack') - print repr(self.users.select().execute().fetchall()) - + users.insert().execute(user_id = 7, user_name = 'jack') + assert users.count().scalar() == 1 + def testupdate(self): - self.users.insert().execute(user_id = 7, user_name = 'jack') - print repr(self.users.select().execute().fetchall()) + users.insert().execute(user_id = 7, user_name = 'jack') + assert users.count().scalar() == 1 - self.users.update(self.users.c.user_id == 7).execute(user_name = 'fred') - print repr(self.users.select().execute().fetchall()) + users.update(users.c.user_id == 7).execute(user_name = 'fred') + assert users.select(users.c.user_id==7).execute().fetchone()['user_name'] == 'fred' def test_lastrow_accessor(self): """test the last_inserted_ids() and lastrow_has_id() functions""" @@ -63,14 +58,15 @@ class QueryTest(PersistTest): if result.lastrow_has_defaults(): criterion = and_(*[col==id for col, id in zip(table.primary_key, result.last_inserted_ids())]) row = table.select(criterion).execute().fetchone() - ret.update(row) + for c in table.c: + ret[c.key] = row[c] return ret for supported, table, values, assertvalues in [ ( {'unsupported':['sqlite']}, Table("t1", metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, Sequence('t1_id_seq', optional=True), primary_key=True), Column('foo', String(30), primary_key=True)), {'foo':'hi'}, {'id':1, 'foo':'hi'} @@ -78,7 +74,7 @@ class QueryTest(PersistTest): ( {'unsupported':['sqlite']}, Table("t2", metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, Sequence('t2_id_seq', optional=True), primary_key=True), Column('foo', String(30), primary_key=True), Column('bar', String(30), PassiveDefault('hi')) ), @@ -98,7 +94,7 @@ class QueryTest(PersistTest): ( {'unsupported':[]}, Table("t4", metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, Sequence('t4_id_seq', optional=True), primary_key=True), Column('foo', String(30), primary_key=True), Column('bar', String(30), PassiveDefault('hi')) ), @@ -124,109 +120,94 @@ class QueryTest(PersistTest): table.drop() def testrowiteration(self): - self.users.insert().execute(user_id = 7, user_name = 'jack') - self.users.insert().execute(user_id = 8, user_name = 'ed') - self.users.insert().execute(user_id = 9, user_name = 'fred') - r = self.users.select().execute() + users.insert().execute(user_id = 7, user_name = 'jack') + users.insert().execute(user_id = 8, user_name = 'ed') + users.insert().execute(user_id = 9, user_name = 'fred') + r = users.select().execute() l = [] for row in r: l.append(row) self.assert_(len(l) == 3) def test_fetchmany(self): - self.users.insert().execute(user_id = 7, user_name = 'jack') - self.users.insert().execute(user_id = 8, user_name = 'ed') - self.users.insert().execute(user_id = 9, user_name = 'fred') - r = self.users.select().execute() + users.insert().execute(user_id = 7, user_name = 'jack') + users.insert().execute(user_id = 8, user_name = 'ed') + users.insert().execute(user_id = 9, user_name = 'fred') + r = users.select().execute() l = [] for row in r.fetchmany(size=2): l.append(row) self.assert_(len(l) == 2, "fetchmany(size=2) got %s rows" % len(l)) def test_compiled_execute(self): - s = select([self.users], self.users.c.user_id==bindparam('id')).compile() + users.insert().execute(user_id = 7, user_name = 'jack') + s = select([users], users.c.user_id==bindparam('id')).compile() c = testbase.db.connect() - print repr(c.execute(s, id=7).fetchall()) - - def test_global_metadata(self): - t1 = Table('table1', Column('col1', Integer, primary_key=True), - Column('col2', String(20))) - t2 = Table('table2', Column('col1', Integer, primary_key=True), - Column('col2', String(20))) - - assert t1.c.col1 - global_connect(testbase.db) - default_metadata.create_all() - try: - assert t1.count().scalar() == 0 - finally: - default_metadata.drop_all() - default_metadata.clear() - + assert c.execute(s, id=7).fetchall()[0]['user_id'] == 7 def test_repeated_bindparams(self): """test that a BindParam can be used more than once. this should be run for dbs with both positional and named paramstyles.""" - self.users.insert().execute(user_id = 7, user_name = 'jack') - self.users.insert().execute(user_id = 8, user_name = 'fred') + users.insert().execute(user_id = 7, user_name = 'jack') + users.insert().execute(user_id = 8, user_name = 'fred') u = bindparam('userid') - s = self.users.select(or_(self.users.c.user_name==u, self.users.c.user_name==u)) + s = users.select(or_(users.c.user_name==u, users.c.user_name==u)) r = s.execute(userid='fred').fetchall() assert len(r) == 1 def test_bindparam_shortname(self): """test the 'shortname' field on BindParamClause.""" - self.users.insert().execute(user_id = 7, user_name = 'jack') - self.users.insert().execute(user_id = 8, user_name = 'fred') + users.insert().execute(user_id = 7, user_name = 'jack') + users.insert().execute(user_id = 8, user_name = 'fred') u = bindparam('userid', shortname='someshortname') - s = self.users.select(self.users.c.user_name==u) + s = users.select(users.c.user_name==u) r = s.execute(someshortname='fred').fetchall() assert len(r) == 1 def testdelete(self): - self.users.insert().execute(user_id = 7, user_name = 'jack') - self.users.insert().execute(user_id = 8, user_name = 'fred') - print repr(self.users.select().execute().fetchall()) + users.insert().execute(user_id = 7, user_name = 'jack') + users.insert().execute(user_id = 8, user_name = 'fred') + print repr(users.select().execute().fetchall()) - self.users.delete(self.users.c.user_name == 'fred').execute() + users.delete(users.c.user_name == 'fred').execute() - print repr(self.users.select().execute().fetchall()) + print repr(users.select().execute().fetchall()) def testselectlimit(self): - self.users.insert().execute(user_id=1, user_name='john') - self.users.insert().execute(user_id=2, user_name='jack') - self.users.insert().execute(user_id=3, user_name='ed') - self.users.insert().execute(user_id=4, user_name='wendy') - self.users.insert().execute(user_id=5, user_name='laura') - self.users.insert().execute(user_id=6, user_name='ralph') - self.users.insert().execute(user_id=7, user_name='fido') - r = self.users.select(limit=3, order_by=[self.users.c.user_id]).execute().fetchall() + users.insert().execute(user_id=1, user_name='john') + users.insert().execute(user_id=2, user_name='jack') + users.insert().execute(user_id=3, user_name='ed') + users.insert().execute(user_id=4, user_name='wendy') + users.insert().execute(user_id=5, user_name='laura') + users.insert().execute(user_id=6, user_name='ralph') + users.insert().execute(user_id=7, user_name='fido') + r = users.select(limit=3, order_by=[users.c.user_id]).execute().fetchall() self.assert_(r == [(1, 'john'), (2, 'jack'), (3, 'ed')], repr(r)) - @testbase.unsupported('mssql') + @testing.unsupported('mssql') def testselectlimitoffset(self): - self.users.insert().execute(user_id=1, user_name='john') - self.users.insert().execute(user_id=2, user_name='jack') - self.users.insert().execute(user_id=3, user_name='ed') - self.users.insert().execute(user_id=4, user_name='wendy') - self.users.insert().execute(user_id=5, user_name='laura') - self.users.insert().execute(user_id=6, user_name='ralph') - self.users.insert().execute(user_id=7, user_name='fido') - r = self.users.select(limit=3, offset=2, order_by=[self.users.c.user_id]).execute().fetchall() + users.insert().execute(user_id=1, user_name='john') + users.insert().execute(user_id=2, user_name='jack') + users.insert().execute(user_id=3, user_name='ed') + users.insert().execute(user_id=4, user_name='wendy') + users.insert().execute(user_id=5, user_name='laura') + users.insert().execute(user_id=6, user_name='ralph') + users.insert().execute(user_id=7, user_name='fido') + r = users.select(limit=3, offset=2, order_by=[users.c.user_id]).execute().fetchall() self.assert_(r==[(3, 'ed'), (4, 'wendy'), (5, 'laura')]) - r = self.users.select(offset=5, order_by=[self.users.c.user_id]).execute().fetchall() + r = users.select(offset=5, order_by=[users.c.user_id]).execute().fetchall() self.assert_(r==[(6, 'ralph'), (7, 'fido')]) - @testbase.supported('mssql') + @testing.supported('mssql') def testselectlimitoffset_mssql(self): try: - r = self.users.select(limit=3, offset=2, order_by=[self.users.c.user_id]).execute().fetchall() + r = users.select(limit=3, offset=2, order_by=[users.c.user_id]).execute().fetchall() assert False # InvalidRequestError should have been raised except exceptions.InvalidRequestError: pass - @testbase.unsupported('mysql') + @testing.unsupported('mysql') def test_scalar_select(self): """test that scalar subqueries with labels get their type propigated to the result set.""" # mysql and/or mysqldb has a bug here, type isnt propigated for scalar subquery. @@ -244,18 +225,26 @@ class QueryTest(PersistTest): datetable.drop() def test_column_accessor(self): - self.users.insert().execute(user_id=1, user_name='john') - self.users.insert().execute(user_id=2, user_name='jack') - r = self.users.select(self.users.c.user_id==2).execute().fetchone() - self.assert_(r.user_id == r['user_id'] == r[self.users.c.user_id] == 2) - self.assert_(r.user_name == r['user_name'] == r[self.users.c.user_name] == 'jack') - - r = text("select * from query_users where user_id=2", engine=testbase.db).execute().fetchone() - self.assert_(r.user_id == r['user_id'] == r[self.users.c.user_id] == 2) - self.assert_(r.user_name == r['user_name'] == r[self.users.c.user_name] == 'jack') + users.insert().execute(user_id=1, user_name='john') + users.insert().execute(user_id=2, user_name='jack') + addresses.insert().execute(address_id=1, user_id=2, address='foo@bar.com') + + r = users.select(users.c.user_id==2).execute().fetchone() + self.assert_(r.user_id == r['user_id'] == r[users.c.user_id] == 2) + self.assert_(r.user_name == r['user_name'] == r[users.c.user_name] == 'jack') + + r = text("select * from query_users where user_id=2", bind=testbase.db).execute().fetchone() + self.assert_(r.user_id == r['user_id'] == r[users.c.user_id] == 2) + self.assert_(r.user_name == r['user_name'] == r[users.c.user_name] == 'jack') + # test slices + r = text("select * from query_addresses", bind=testbase.db).execute().fetchone() + self.assert_(r[0:1] == (1,)) + self.assert_(r[1:] == (2, 'foo@bar.com')) + self.assert_(r[:-1] == (1, 2)) + def test_ambiguous_column(self): - self.users.insert().execute(user_id=1, user_name='john') + users.insert().execute(user_id=1, user_name='john') r = users.outerjoin(addresses).select().execute().fetchone() try: print r['user_id'] @@ -264,18 +253,18 @@ class QueryTest(PersistTest): assert str(e) == "Ambiguous column name 'user_id' in result set! try 'use_labels' option on select statement." def test_keys(self): - self.users.insert().execute(user_id=1, user_name='foo') - r = self.users.select().execute().fetchone() + users.insert().execute(user_id=1, user_name='foo') + r = users.select().execute().fetchone() self.assertEqual([x.lower() for x in r.keys()], ['user_id', 'user_name']) def test_items(self): - self.users.insert().execute(user_id=1, user_name='foo') - r = self.users.select().execute().fetchone() + users.insert().execute(user_id=1, user_name='foo') + r = users.select().execute().fetchone() self.assertEqual([(x[0].lower(), x[1]) for x in r.items()], [('user_id', 1), ('user_name', 'foo')]) def test_len(self): - self.users.insert().execute(user_id=1, user_name='foo') - r = self.users.select().execute().fetchone() + users.insert().execute(user_id=1, user_name='foo') + r = users.select().execute().fetchone() self.assertEqual(len(r), 2) r.close() r = testbase.db.execute('select user_name, user_id from query_users', {}).fetchone() @@ -295,7 +284,11 @@ class QueryTest(PersistTest): x = testbase.db.func.current_date().execute().scalar() y = testbase.db.func.current_date().select().execute().scalar() z = testbase.db.func.current_date().scalar() - assert x == y == z + assert (x == y == z) is True + + x = testbase.db.func.current_date(type_=Date) + assert isinstance(x.type, Date) + assert isinstance(x.execute().scalar(), datetime.date) def test_conn_functions(self): conn = testbase.db.connect() @@ -305,8 +298,8 @@ class QueryTest(PersistTest): z = conn.scalar(func.current_date()) finally: conn.close() - assert x == y == z - + assert (x == y == z) is True + def test_update_functions(self): """test sending functions and SQL expressions to the VALUES and SET clauses of INSERT/UPDATE instances, and that column-level defaults get overridden""" @@ -357,7 +350,7 @@ class QueryTest(PersistTest): finally: meta.drop_all() - @testbase.supported('postgres') + @testing.supported('postgres') def test_functions_with_cols(self): # TODO: shouldnt this work on oracle too ? x = testbase.db.func.current_date().execute().scalar() @@ -366,7 +359,7 @@ class QueryTest(PersistTest): w = select(['*'], from_obj=[testbase.db.func.current_date()]).scalar() # construct a column-based FROM object out of a function, like in [ticket:172] - s = select([column('date', type=DateTime)], from_obj=[testbase.db.func.current_date()]) + s = select([column('date', type_=DateTime)], from_obj=[testbase.db.func.current_date()]) q = s.execute().fetchone()[s.c.date] r = s.alias('datequery').select().scalar() @@ -374,8 +367,8 @@ class QueryTest(PersistTest): def test_column_order_with_simple_query(self): # should return values in column definition order - self.users.insert().execute(user_id=1, user_name='foo') - r = self.users.select(self.users.c.user_id==1).execute().fetchone() + users.insert().execute(user_id=1, user_name='foo') + r = users.select(users.c.user_id==1).execute().fetchone() self.assertEqual(r[0], 1) self.assertEqual(r[1], 'foo') self.assertEqual([x.lower() for x in r.keys()], ['user_id', 'user_name']) @@ -383,14 +376,14 @@ class QueryTest(PersistTest): def test_column_order_with_text_query(self): # should return values in query order - self.users.insert().execute(user_id=1, user_name='foo') + users.insert().execute(user_id=1, user_name='foo') r = testbase.db.execute('select user_name, user_id from query_users', {}).fetchone() self.assertEqual(r[0], 'foo') self.assertEqual(r[1], 1) self.assertEqual([x.lower() for x in r.keys()], ['user_name', 'user_id']) self.assertEqual(r.values(), ['foo', 1]) - @testbase.unsupported('oracle', 'firebird') + @testing.unsupported('oracle', 'firebird') def test_column_accessor_shadow(self): meta = MetaData(testbase.db) shadowed = Table('test_shadowed', meta, @@ -420,7 +413,7 @@ class QueryTest(PersistTest): finally: shadowed.drop(checkfirst=True) - @testbase.supported('mssql') + @testing.supported('mssql') def test_fetchid_trigger(self): meta = MetaData(testbase.db) t1 = Table('t1', meta, @@ -446,7 +439,7 @@ class QueryTest(PersistTest): con.execute("""drop trigger paj""") meta.drop_all() - @testbase.supported('mssql') + @testing.supported('mssql') def test_insertid_schema(self): meta = MetaData(testbase.db) con = testbase.db.connect() @@ -459,7 +452,7 @@ class QueryTest(PersistTest): tbl.drop() con.execute('drop schema paj') - @testbase.supported('mssql') + @testing.supported('mssql') def test_insertid_reserved(self): meta = MetaData(testbase.db) table = Table( @@ -476,51 +469,52 @@ class QueryTest(PersistTest): def test_in_filtering(self): - """test the 'shortname' field on BindParamClause.""" - self.users.insert().execute(user_id = 7, user_name = 'jack') - self.users.insert().execute(user_id = 8, user_name = 'fred') - self.users.insert().execute(user_id = 9, user_name = None) + """test the behavior of the in_() function.""" + + users.insert().execute(user_id = 7, user_name = 'jack') + users.insert().execute(user_id = 8, user_name = 'fred') + users.insert().execute(user_id = 9, user_name = None) - s = self.users.select(self.users.c.user_name.in_()) + s = users.select(users.c.user_name.in_()) r = s.execute().fetchall() # No username is in empty set assert len(r) == 0 - s = self.users.select(not_(self.users.c.user_name.in_())) + s = users.select(not_(users.c.user_name.in_())) r = s.execute().fetchall() # All usernames with a value are outside an empty set assert len(r) == 2 - s = self.users.select(self.users.c.user_name.in_('jack','fred')) + s = users.select(users.c.user_name.in_('jack','fred')) r = s.execute().fetchall() assert len(r) == 2 - s = self.users.select(not_(self.users.c.user_name.in_('jack','fred'))) + s = users.select(not_(users.c.user_name.in_('jack','fred'))) r = s.execute().fetchall() # Null values are not outside any set assert len(r) == 0 u = bindparam('search_key') - s = self.users.select(u.in_()) + s = users.select(u.in_()) r = s.execute(search_key='john').fetchall() assert len(r) == 0 r = s.execute(search_key=None).fetchall() assert len(r) == 0 - s = self.users.select(not_(u.in_())) + s = users.select(not_(u.in_())) r = s.execute(search_key='john').fetchall() assert len(r) == 3 r = s.execute(search_key=None).fetchall() assert len(r) == 0 - s = self.users.select(self.users.c.user_name.in_() == True) + s = users.select(users.c.user_name.in_() == True) r = s.execute().fetchall() assert len(r) == 0 - s = self.users.select(self.users.c.user_name.in_() == False) + s = users.select(users.c.user_name.in_() == False) r = s.execute().fetchall() assert len(r) == 2 - s = self.users.select(self.users.c.user_name.in_() == None) + s = users.select(users.c.user_name.in_() == None) r = s.execute().fetchall() assert len(r) == 1 @@ -577,7 +571,7 @@ class CompoundTest(PersistTest): assert u.execute().fetchall() == [('aaa', 'aaa'), ('bbb', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')] assert u.alias('bar').select().execute().fetchall() == [('aaa', 'aaa'), ('bbb', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')] - @testbase.unsupported('mysql') + @testing.unsupported('mysql') def test_intersect(self): i = intersect( select([t2.c.col3, t2.c.col4]), @@ -586,7 +580,7 @@ class CompoundTest(PersistTest): assert i.execute().fetchall() == [('aaa', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')] assert i.alias('bar').select().execute().fetchall() == [('aaa', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')] - @testbase.unsupported('mysql', 'oracle') + @testing.unsupported('mysql', 'oracle') def test_except_style1(self): e = except_(union( select([t1.c.col3, t1.c.col4]), @@ -595,7 +589,7 @@ class CompoundTest(PersistTest): ), select([t2.c.col3, t2.c.col4])) assert e.alias('bar').select().execute().fetchall() == [('aaa', 'aaa'), ('aaa', 'ccc'), ('bbb', 'aaa'), ('bbb', 'bbb'), ('ccc', 'bbb'), ('ccc', 'ccc')] - @testbase.unsupported('mysql', 'oracle') + @testing.unsupported('mysql', 'oracle') def test_except_style2(self): e = except_(union( select([t1.c.col3, t1.c.col4]), @@ -605,7 +599,7 @@ class CompoundTest(PersistTest): assert e.execute().fetchall() == [('aaa', 'aaa'), ('aaa', 'ccc'), ('bbb', 'aaa'), ('bbb', 'bbb'), ('ccc', 'bbb'), ('ccc', 'ccc')] assert e.alias('bar').select().execute().fetchall() == [('aaa', 'aaa'), ('aaa', 'ccc'), ('bbb', 'aaa'), ('bbb', 'bbb'), ('ccc', 'bbb'), ('ccc', 'ccc')] - @testbase.unsupported('sqlite', 'mysql', 'oracle') + @testing.unsupported('sqlite', 'mysql', 'oracle') def test_except_style3(self): # aaa, bbb, ccc - (aaa, bbb, ccc - (ccc)) = ccc e = except_( @@ -617,7 +611,7 @@ class CompoundTest(PersistTest): ) self.assertEquals(e.execute().fetchall(), [('ccc',)]) - @testbase.unsupported('sqlite', 'mysql', 'oracle') + @testing.unsupported('sqlite', 'mysql', 'oracle') def test_union_union_all(self): e = union_all( select([t1.c.col3]), @@ -628,7 +622,7 @@ class CompoundTest(PersistTest): ) self.assertEquals(e.execute().fetchall(), [('aaa',),('bbb',),('ccc',),('aaa',),('bbb',),('ccc',)]) - @testbase.unsupported('mysql') + @testing.unsupported('mysql') def test_composite(self): u = intersect( select([t2.c.col3, t2.c.col4]), diff --git a/test/sql/quote.py b/test/sql/quote.py index bc40d52ee..2fdf9dba0 100644 --- a/test/sql/quote.py +++ b/test/sql/quote.py @@ -1,6 +1,7 @@ -from testbase import PersistTest import testbase from sqlalchemy import * +from testlib import * + class QuoteTest(PersistTest): def setUpAll(self): @@ -78,7 +79,7 @@ class QuoteTest(PersistTest): assert t1.c.UcCol.case_sensitive is False assert t2.c.normalcol.case_sensitive is False - @testbase.unsupported('oracle') + @testing.unsupported('oracle') def testlabels(self): """test the quoting of labels. diff --git a/test/sql/rowcount.py b/test/sql/rowcount.py index df6a2a883..e0da96a81 100644 --- a/test/sql/rowcount.py +++ b/test/sql/rowcount.py @@ -1,7 +1,9 @@ -from sqlalchemy import * import testbase +from sqlalchemy import * +from testlib import * + -class FoundRowsTest(testbase.AssertMixin): +class FoundRowsTest(AssertMixin): """tests rowcount functionality""" def setUpAll(self): metadata = MetaData(testbase.db) diff --git a/test/sql/select.py b/test/sql/select.py index 4d3eb4ad7..a5cf061e2 100644 --- a/test/sql/select.py +++ b/test/sql/select.py @@ -1,8 +1,8 @@ -from testbase import PersistTest import testbase +import re, operator from sqlalchemy import * from sqlalchemy.databases import sqlite, postgres, mysql, oracle, firebird, mssql -import unittest, re, operator +from testlib import * # the select test now tests almost completely with TableClause/ColumnClause objects, @@ -10,21 +10,21 @@ import unittest, re, operator # so SQLAlchemy's SQL construction engine can be used with no database dependencies at all. table1 = table('mytable', - column('myid'), - column('name'), - column('description'), + column('myid', Integer), + column('name', String), + column('description', String), ) table2 = table( 'myothertable', - column('otherid'), - column('othername'), + column('otherid', Integer), + column('othername', String), ) table3 = table( 'thirdtable', - column('userid'), - column('otherstuff'), + column('userid', Integer), + column('otherstuff', String), ) metadata = MetaData() @@ -54,7 +54,7 @@ addresses = table('addresses', class SQLTest(PersistTest): def runtest(self, clause, result, dialect = None, params = None, checkparams = None): c = clause.compile(parameters=params, dialect=dialect) - self.echo("\nSQL String:\n" + str(c) + repr(c.get_params())) + print "\nSQL String:\n" + str(c) + repr(c.get_params()) cc = re.sub(r'\n', '', str(c)) self.assert_(cc == result, "\n'" + cc + "'\n does not match \n'" + result + "'") if checkparams is not None: @@ -130,6 +130,15 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A crit = q.c.myid == table1.c.myid self.runtest(select(['*'], crit), """SELECT * FROM (SELECT mytable.myid AS myid FROM mytable ORDER BY mytable.myid) AS foo, mytable WHERE foo.myid = mytable.myid""", dialect=sqlite.dialect()) self.runtest(select(['*'], crit), """SELECT * FROM (SELECT mytable.myid AS myid FROM mytable) AS foo, mytable WHERE foo.myid = mytable.myid""", dialect=mssql.dialect()) + + def testmssql_aliases_schemas(self): + self.runtest(table4.select(), "SELECT remotetable.rem_id, remotetable.datatype_id, remotetable.value FROM remote_owner.remotetable") + + dialect = mssql.dialect() + self.runtest(table4.select(), "SELECT remotetable_1.rem_id, remotetable_1.datatype_id, remotetable_1.value FROM remote_owner.remotetable AS remotetable_1", dialect=dialect) + + # TODO: this is probably incorrect; no "AS <foo>" is being applied to the table + self.runtest(table1.join(table4, table1.c.myid==table4.c.rem_id).select(), "SELECT mytable.myid, mytable.name, mytable.description, remotetable.rem_id, remotetable.datatype_id, remotetable.value FROM mytable JOIN remote_owner.remotetable ON remotetable.rem_id = mytable.myid") def testdontovercorrelate(self): self.runtest(select([table1], from_obj=[table1, table1.select()]), """SELECT mytable.myid, mytable.name, mytable.description FROM mytable, (SELECT mytable.myid AS myid, mytable.name AS name, mytable.description AS description FROM mytable)""") @@ -142,6 +151,11 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A self.runtest(select([table1, exists([1], from_obj=[table2]).label('foo')]), "SELECT mytable.myid, mytable.name, mytable.description, EXISTS (SELECT 1 FROM myothertable) AS foo FROM mytable", params={}) def testwheresubquery(self): + s = select([addresses.c.street], addresses.c.user_id==users.c.user_id, correlate=True).alias('s') + self.runtest( + select([users, s.c.street], from_obj=[s]), + """SELECT users.user_id, users.user_name, users.password, s.street FROM users, (SELECT addresses.street AS street FROM addresses WHERE addresses.user_id = users.user_id) AS s""") + # TODO: this tests that you dont get a "SELECT column" without a FROM but its not working yet. #self.runtest( # table1.select(table1.c.myid == select([table1.c.myid], table1.c.name=='jack')), "" @@ -194,7 +208,20 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A s = select([table1.c.myid], scalar=True) self.runtest(select([table2, s]), "SELECT myothertable.otherid, myothertable.othername, (SELECT mytable.myid FROM mytable) FROM myothertable") - + + s = select([table1.c.myid]).correlate(None).as_scalar() + self.runtest(select([table1, s]), "SELECT mytable.myid, mytable.name, mytable.description, (SELECT mytable.myid FROM mytable) FROM mytable") + + s = select([table1.c.myid]).as_scalar() + self.runtest(select([table2, s]), "SELECT myothertable.otherid, myothertable.othername, (SELECT mytable.myid FROM mytable) FROM myothertable") + + # test expressions against scalar selects + self.runtest(select([s - literal(8)]), "SELECT (SELECT mytable.myid FROM mytable) - :literal") + self.runtest(select([select([table1.c.name]).as_scalar() + literal('x')]), "SELECT (SELECT mytable.name FROM mytable) || :literal") + self.runtest(select([s > literal(8)]), "SELECT (SELECT mytable.myid FROM mytable) > :literal") + + self.runtest(select([select([table1.c.name]).label('foo')]), "SELECT (SELECT mytable.name FROM mytable) AS foo") + zips = table('zips', column('zipcode'), @@ -206,15 +233,17 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A column('nm') ) zip = '12345' - qlat = select([zips.c.latitude], zips.c.zipcode == zip, scalar=True, correlate=False) - qlng = select([zips.c.longitude], zips.c.zipcode == zip, scalar=True, correlate=False) + qlat = select([zips.c.latitude], zips.c.zipcode == zip).correlate(None).as_scalar() + qlng = select([zips.c.longitude], zips.c.zipcode == zip).correlate(None).as_scalar() q = select([places.c.id, places.c.nm, zips.c.zipcode, func.latlondist(qlat, qlng).label('dist')], zips.c.zipcode==zip, order_by = ['dist', places.c.nm] ) - self.runtest(q,"SELECT places.id, places.nm, zips.zipcode, latlondist((SELECT zips.latitude FROM zips WHERE zips.zipcode = :zips_zipcode_1), (SELECT zips.longitude FROM zips WHERE zips.zipcode = :zips_zipcode_2)) AS dist FROM places, zips WHERE zips.zipcode = :zips_zipcode ORDER BY dist, places.nm") + self.runtest(q,"SELECT places.id, places.nm, zips.zipcode, latlondist((SELECT zips.latitude FROM zips WHERE " + "zips.zipcode = :zips_zipcode), (SELECT zips.longitude FROM zips WHERE zips.zipcode = :zips_zipcode_1)) AS dist " + "FROM places, zips WHERE zips.zipcode = :zips_zipcode_2 ORDER BY dist, places.nm") zalias = zips.alias('main_zip') qlat = select([zips.c.latitude], zips.c.zipcode == zalias.c.zipcode, scalar=True) @@ -223,7 +252,7 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A order_by = ['dist', places.c.nm] ) self.runtest(q, "SELECT places.id, places.nm, main_zip.zipcode, latlondist((SELECT zips.latitude FROM zips WHERE zips.zipcode = main_zip.zipcode), (SELECT zips.longitude FROM zips WHERE zips.zipcode = main_zip.zipcode)) AS dist FROM places, zips AS main_zip ORDER BY dist, places.nm") - + a1 = table2.alias('t2alias') s1 = select([a1.c.otherid], table1.c.myid==a1.c.otherid, scalar=True) j1 = table1.join(table2, table1.c.myid==table2.c.otherid) @@ -261,28 +290,20 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A ) def testoperators(self): - self.runtest( - table1.select((table1.c.myid != 12) & ~(table1.c.name=='john')), - "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid != :mytable_myid AND NOT mytable.name = :mytable_name" - ) - - self.runtest( - literal("a") + literal("b") * literal("c"), ":literal + :literal_1 * :literal_2" - ) # exercise arithmetic operators for (py_op, sql_op) in ((operator.add, '+'), (operator.mul, '*'), (operator.sub, '-'), (operator.div, '/'), ): for (lhs, rhs, res) in ( - ('a', table1.c.myid, ':mytable_myid %s mytable.myid'), - ('a', literal('b'), ':literal %s :literal_1'), + (5, table1.c.myid, ':mytable_myid %s mytable.myid'), + (5, literal(5), ':literal %s :literal_1'), (table1.c.myid, 'b', 'mytable.myid %s :mytable_myid'), - (table1.c.myid, literal('b'), 'mytable.myid %s :literal'), + (table1.c.myid, literal(2.7), 'mytable.myid %s :literal'), (table1.c.myid, table1.c.myid, 'mytable.myid %s mytable.myid'), - (literal('a'), 'b', ':literal %s :literal_1'), - (literal('a'), table1.c.myid, ':literal %s mytable.myid'), - (literal('a'), literal('b'), ':literal %s :literal_1'), + (literal(5), 8, ':literal %s :literal_1'), + (literal(6), table1.c.myid, ':literal %s mytable.myid'), + (literal(7), literal(5.5), ':literal %s :literal_1'), ): self.runtest(py_op(lhs, rhs), res % sql_op) @@ -314,6 +335,25 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A "\n'" + compiled + "'\n does not match\n'" + fwd_sql + "'\n or\n'" + rev_sql + "'") + self.runtest( + table1.select((table1.c.myid != 12) & ~(table1.c.name=='john')), + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid != :mytable_myid AND mytable.name != :mytable_name" + ) + + self.runtest( + table1.select((table1.c.myid != 12) & ~and_(table1.c.name=='john', table1.c.name=='ed', table1.c.name=='fred')), + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid != :mytable_myid AND NOT (mytable.name = :mytable_name AND mytable.name = :mytable_name_1 AND mytable.name = :mytable_name_2)" + ) + + self.runtest( + table1.select((table1.c.myid != 12) & ~table1.c.name), + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid != :mytable_myid AND NOT mytable.name" + ) + + self.runtest( + literal("a") + literal("b") * literal("c"), ":literal || :literal_1 * :literal_2" + ) + # test the op() function, also that its results are further usable in expressions self.runtest( table1.select(table1.c.myid.op('hoho')(12)==14), @@ -374,13 +414,18 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A def testalias(self): # test the alias for a table1. column names stay the same, table name "changes" to "foo". self.runtest( - select([alias(table1, 'foo')]) + select([table1.alias('foo')]) ,"SELECT foo.myid, foo.name, foo.description FROM mytable AS foo") - + + for dialect in (firebird.dialect(), oracle.dialect()): + self.runtest( + select([table1.alias('foo')]) + ,"SELECT foo.myid, foo.name, foo.description FROM mytable foo" + ,dialect=dialect) + self.runtest( - select([alias(table1, 'foo')]) - ,"SELECT foo.myid, foo.name, foo.description FROM mytable foo" - ,dialect=firebird.dialect()) + select([table1.alias()]) + ,"SELECT mytable_1.myid, mytable_1.name, mytable_1.description FROM mytable AS mytable_1") # create a select for a join of two tables. use_labels means the column names will have # labels tablename_columnname, which become the column keys accessible off the Selectable object. @@ -401,6 +446,12 @@ myothertable.otherid AS myothertable_otherid FROM mytable, myothertable \ WHERE mytable.myid = myothertable.otherid) AS t2view WHERE t2view.mytable_myid = :t2view_mytable_myid" ) + + def test_prefixes(self): + self.runtest(table1.select().prefix_with("SQL_CALC_FOUND_ROWS").prefix_with("SQL_SOME_WEIRD_MYSQL_THING"), + "SELECT SQL_CALC_FOUND_ROWS SQL_SOME_WEIRD_MYSQL_THING mytable.myid, mytable.name, mytable.description FROM mytable" + ) + def testtext(self): self.runtest( text("select * from foo where lala = bar") , @@ -429,7 +480,7 @@ WHERE mytable.myid = myothertable.otherid) AS t2view WHERE t2view.mytable_myid = s.append_column("column2") s.append_whereclause("column1=12") s.append_whereclause("column2=19") - s.order_by("column1") + s = s.order_by("column1") s.append_from("table1") self.runtest(s, "SELECT column1, column2 FROM table1 WHERE column1=12 AND column2=19 ORDER BY column1") @@ -468,7 +519,14 @@ WHERE mytable.myid = myothertable.otherid) AS t2view WHERE t2view.mytable_myid = checkparams={'bar':4, 'whee': 7}, params={'bar':4, 'whee': 7, 'hoho':10}, ) - + + self.runtest( + text("select * from foo where clock='05:06:07'"), + "select * from foo where clock='05:06:07'", + checkparams={}, + params={}, + ) + dialect = postgres.dialect() self.runtest( text("select * from foo where lala=:bar and hoho=:whee"), @@ -477,6 +535,13 @@ WHERE mytable.myid = myothertable.otherid) AS t2view WHERE t2view.mytable_myid = params={'bar':4, 'whee': 7, 'hoho':10}, dialect=dialect ) + self.runtest( + text("select * from foo where clock='05:06:07' and mork='\:mindy'"), + "select * from foo where clock='05:06:07' and mork=':mindy'", + checkparams={}, + params={}, + dialect=dialect + ) dialect = sqlite.dialect() self.runtest( @@ -509,7 +574,7 @@ FROM mytable, myothertable WHERE foo.id = foofoo(lala) AND datetime(foo) = Today def testliteral(self): self.runtest(select([literal("foo") + literal("bar")], from_obj=[table1]), - "SELECT :literal + :literal_1 FROM mytable") + "SELECT :literal || :literal_1 FROM mytable") def testcalculatedcolumns(self): value_tbl = table('values', @@ -663,7 +728,7 @@ FROM myothertable ORDER BY myid \ WHERE mytable.name = :mytable_name GROUP BY mytable.myid, mytable.name UNION SELECT mytable.myid, mytable.name, mytable.description \ FROM mytable WHERE mytable.name = :mytable_name_1" ) - + def test_compound_select_grouping(self): self.runtest( union_all( @@ -716,6 +781,7 @@ EXISTS (select yay from foo where boo = lar)", dialect=postgres.dialect() ) + self.runtest(query, "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername \ FROM mytable, myothertable WHERE mytable.myid = myothertable.otherid(+) AND \ @@ -835,16 +901,16 @@ myothertable.othername != :myothertable_othername OR EXISTS (select yay from foo self.runtest(select([table1], table1.c.myid.in_('a', literal('b'))), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:mytable_myid, :literal)") - self.runtest(select([table1], table1.c.myid.in_(literal('a') + 'a')), + self.runtest(select([table1], table1.c.myid.in_(literal(1) + 'a')), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :literal + :literal_1") self.runtest(select([table1], table1.c.myid.in_(literal('a') +'a', 'b')), - "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:literal + :literal_1, :mytable_myid)") + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:literal || :literal_1, :mytable_myid)") self.runtest(select([table1], table1.c.myid.in_(literal('a') + literal('a'), literal('b'))), - "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:literal + :literal_1, :literal_2)") + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:literal || :literal_1, :literal_2)") - self.runtest(select([table1], table1.c.myid.in_('a', literal('b') +'b')), + self.runtest(select([table1], table1.c.myid.in_(1, literal(3) + 4)), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:mytable_myid, :literal + :literal_1)") self.runtest(select([table1], table1.c.myid.in_(literal('a') < 'b')), @@ -862,7 +928,7 @@ myothertable.othername != :myothertable_othername OR EXISTS (select yay from foo self.runtest(select([table1], table1.c.myid.in_(literal('a'), table1.c.myid +'a')), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:literal, mytable.myid + :mytable_myid)") - self.runtest(select([table1], table1.c.myid.in_(literal('a'), 'a' + table1.c.myid)), + self.runtest(select([table1], table1.c.myid.in_(literal(1), 'a' + table1.c.myid)), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:literal, :mytable_myid + mytable.myid)") self.runtest(select([table1], table1.c.myid.in_(1, 2, 3)), @@ -900,16 +966,6 @@ UNION SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE (CASE WHEN (mytable.myid IS NULL) THEN NULL ELSE 0 END = 1)") - def testlateargs(self): - """tests that a SELECT clause will have extra "WHERE" clauses added to it at compile time if extra arguments - are sent""" - - self.runtest(table1.select(), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.name = :mytable_name AND mytable.myid = :mytable_myid", params={'myid':'3', 'name':'jack'}) - - self.runtest(table1.select(table1.c.name=='jack'), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :mytable_myid AND mytable.name = :mytable_name", params={'myid':'3'}) - - self.runtest(table1.select(table1.c.name=='jack'), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :mytable_myid AND mytable.name = :mytable_name", params={'myid':'3', 'name':'fred'}) - def testcast(self): tbl = table('casttest', column('id', Integer), @@ -963,8 +1019,8 @@ UNION SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE "SELECT op.field FROM op WHERE :literal + (op.field IN (:op_field, :op_field_1))") self.runtest(table.select((5 + table.c.field).in_(5,6)), "SELECT op.field FROM op WHERE :op_field + op.field IN (:literal, :literal_1)") - self.runtest(table.select(not_(table.c.field == 5)), - "SELECT op.field FROM op WHERE NOT op.field = :op_field") + self.runtest(table.select(not_(and_(table.c.field == 5, table.c.field == 7))), + "SELECT op.field FROM op WHERE NOT (op.field = :op_field AND op.field = :op_field_1)") self.runtest(table.select(not_(table.c.field) == 5), "SELECT op.field FROM op WHERE (NOT op.field) = :literal") self.runtest(table.select((table.c.field == table.c.field).between(False, True)), @@ -1019,12 +1075,17 @@ class CRUDTest(SQLTest): values = { table1.c.name : table1.c.name + "lala", table1.c.myid : func.do_stuff(table1.c.myid, literal('hoho')) - }), "UPDATE mytable SET myid=do_stuff(mytable.myid, :literal_2), name=mytable.name + :mytable_name WHERE mytable.myid = hoho(:hoho) AND mytable.name = :literal + mytable.name + :literal_1") + }), "UPDATE mytable SET myid=do_stuff(mytable.myid, :literal), name=(mytable.name || :mytable_name) " + "WHERE mytable.myid = hoho(:hoho) AND mytable.name = :literal_1 || mytable.name || :literal_2") def testcorrelatedupdate(self): # test against a straight text subquery - u = update(table1, values = {table1.c.name : text("select name from mytable where id=mytable.id")}) + u = update(table1, values = {table1.c.name : text("(select name from mytable where id=mytable.id)")}) self.runtest(u, "UPDATE mytable SET name=(select name from mytable where id=mytable.id)") + + mt = table1.alias() + u = update(table1, values = {table1.c.name : select([mt.c.name], mt.c.myid==table1.c.myid)}) + self.runtest(u, "UPDATE mytable SET name=(SELECT mytable_1.name FROM mytable AS mytable_1 WHERE mytable_1.myid = mytable.myid)") # test against a regular constructed subquery s = select([table2], table2.c.otherid == table1.c.myid) @@ -1043,7 +1104,18 @@ class CRUDTest(SQLTest): def testdelete(self): self.runtest(delete(table1, table1.c.myid == 7), "DELETE FROM mytable WHERE mytable.myid = :mytable_myid") - + + def testcorrelateddelete(self): + # test a non-correlated WHERE clause + s = select([table2.c.othername], table2.c.otherid == 7) + u = delete(table1, table1.c.name==s) + self.runtest(u, "DELETE FROM mytable WHERE mytable.name = (SELECT myothertable.othername FROM myothertable WHERE myothertable.otherid = :myothertable_otherid)") + + # test one that is actually correlated... + s = select([table2.c.othername], table2.c.otherid == table1.c.myid) + u = table1.delete(table1.c.name==s) + self.runtest(u, "DELETE FROM mytable WHERE mytable.name = (SELECT myothertable.othername FROM myothertable WHERE myothertable.otherid = mytable.myid)") + class SchemaTest(SQLTest): def testselect(self): # these tests will fail with the MS-SQL compiler since it will alias schema-qualified tables diff --git a/test/sql/selectable.py b/test/sql/selectable.py index ecd8253b8..dcc855074 100755 --- a/test/sql/selectable.py +++ b/test/sql/selectable.py @@ -1,17 +1,13 @@ -"""tests that various From objects properly export their columns, as well as useable primary keys
-and foreign keys. Full relational algebra depends on every selectable unit behaving
-nicely with others.."""
-
+"""tests that various From objects properly export their columns, as well as
+useable primary keys and foreign keys. Full relational algebra depends on
+every selectable unit behaving nicely with others.."""
+
import testbase
-import unittest, sys, datetime
-
-
-db = testbase.db
-
from sqlalchemy import *
+from testlib import *
-
-table = Table('table1', db,
+metadata = MetaData()
+table = Table('table1', metadata,
Column('col1', Integer, primary_key=True),
Column('col2', String(20)),
Column('col3', Integer),
@@ -19,14 +15,14 @@ table = Table('table1', db, )
-table2 = Table('table2', db,
+table2 = Table('table2', metadata,
Column('col1', Integer, primary_key=True),
Column('col2', Integer, ForeignKey('table1.col1')),
Column('col3', String(20)),
Column('coly', Integer),
)
-class SelectableTest(testbase.AssertMixin):
+class SelectableTest(AssertMixin):
def testdistance(self):
s = select([table.c.col1.label('c2'), table.c.col1, table.c.col1.label('c1')])
@@ -57,7 +53,7 @@ class SelectableTest(testbase.AssertMixin): jj = select([ table.c.col1.label('bar_col1')],from_obj=[j]).alias('foo')
jjj = join(table, jj, table.c.col1==jj.c.bar_col1)
assert jjj.corresponding_column(jjj.c.table1_col1) is jjj.c.table1_col1
-
+
j2 = jjj.alias('foo')
print j2.corresponding_column(jjj.c.table1_col1)
assert j2.corresponding_column(jjj.c.table1_col1) is j2.c.table1_col1
@@ -170,8 +166,9 @@ class SelectableTest(testbase.AssertMixin): print str(criterion)
print str(j.onclause)
self.assert_(criterion.compare(j.onclause))
+
-class PrimaryKeyTest(testbase.AssertMixin):
+class PrimaryKeyTest(AssertMixin):
def test_join_pk_collapse_implicit(self):
"""test that redundant columns in a join get 'collapsed' into a minimal primary key,
which is the root column along a chain of foreign key relationships."""
@@ -224,8 +221,7 @@ class PrimaryKeyTest(testbase.AssertMixin): j.foreign_keys
assert list(j.primary_key) == [a.c.id]
-
-
+
if __name__ == "__main__":
testbase.main()
-
\ No newline at end of file +
diff --git a/test/sql/testtypes.py b/test/sql/testtypes.py index ed9de0912..659033016 100644 --- a/test/sql/testtypes.py +++ b/test/sql/testtypes.py @@ -1,14 +1,11 @@ -from testbase import PersistTest, AssertMixin import testbase import pickleable +import datetime, os from sqlalchemy import * -import string,datetime, re, sys, os import sqlalchemy.engine.url as url - -import sqlalchemy.types from sqlalchemy.databases import mssql, oracle, mysql +from testlib import * -db = testbase.db class MyType(types.TypeEngine): def get_col_spec(self): @@ -107,7 +104,7 @@ class OverrideTest(PersistTest): def setUpAll(self): global users - users = Table('type_users', db, + users = Table('type_users', MetaData(testbase.db), Column('user_id', Integer, primary_key = True), # totall custom type Column('goofy', MyType, nullable = False), @@ -138,11 +135,12 @@ class ColumnsTest(AssertMixin): 'float_column': 'float_column NUMERIC(25, 2)' } + db = testbase.db if not db.name=='sqlite' and not db.name=='oracle': expectedResults['float_column'] = 'float_column FLOAT(25)' print db.engine.__module__ - testTable = Table('testColumns', db, + testTable = Table('testColumns', MetaData(db), Column('int_column', Integer), Column('smallint_column', Smallinteger), Column('varchar_column', String(20)), @@ -157,7 +155,8 @@ class UnicodeTest(AssertMixin): """tests the Unicode type. also tests the TypeDecorator with instances in the types package.""" def setUpAll(self): global unicode_table - unicode_table = Table('unicode_table', db, + metadata = MetaData(testbase.db) + unicode_table = Table('unicode_table', metadata, Column('id', Integer, Sequence('uni_id_seq', optional=True), primary_key=True), Column('unicode_varchar', Unicode(250)), Column('unicode_text', Unicode), @@ -175,49 +174,49 @@ class UnicodeTest(AssertMixin): unicode_text=unicodedata, plain_varchar=rawdata) x = unicode_table.select().execute().fetchone() - self.echo(repr(x['unicode_varchar'])) - self.echo(repr(x['unicode_text'])) - self.echo(repr(x['plain_varchar'])) + print repr(x['unicode_varchar']) + print repr(x['unicode_text']) + print repr(x['plain_varchar']) self.assert_(isinstance(x['unicode_varchar'], unicode) and x['unicode_varchar'] == unicodedata) self.assert_(isinstance(x['unicode_text'], unicode) and x['unicode_text'] == unicodedata) if isinstance(x['plain_varchar'], unicode): # SQLLite and MSSQL return non-unicode data as unicode - self.assert_(db.name in ('sqlite', 'mssql')) + self.assert_(testbase.db.name in ('sqlite', 'mssql')) self.assert_(x['plain_varchar'] == unicodedata) - self.echo("it's %s!" % db.name) + print "it's %s!" % testbase.db.name else: self.assert_(not isinstance(x['plain_varchar'], unicode) and x['plain_varchar'] == rawdata) def testengineparam(self): """tests engine-wide unicode conversion""" - prev_unicode = db.engine.dialect.convert_unicode + prev_unicode = testbase.db.engine.dialect.convert_unicode try: - db.engine.dialect.convert_unicode = True + testbase.db.engine.dialect.convert_unicode = True rawdata = 'Alors vous imaginez ma surprise, au lever du jour, quand une dr\xc3\xb4le de petit voix m\xe2\x80\x99a r\xc3\xa9veill\xc3\xa9. Elle disait: \xc2\xab S\xe2\x80\x99il vous pla\xc3\xaet\xe2\x80\xa6 dessine-moi un mouton! \xc2\xbb\n' unicodedata = rawdata.decode('utf-8') unicode_table.insert().execute(unicode_varchar=unicodedata, unicode_text=unicodedata, plain_varchar=rawdata) x = unicode_table.select().execute().fetchone() - self.echo(repr(x['unicode_varchar'])) - self.echo(repr(x['unicode_text'])) - self.echo(repr(x['plain_varchar'])) + print repr(x['unicode_varchar']) + print repr(x['unicode_text']) + print repr(x['plain_varchar']) self.assert_(isinstance(x['unicode_varchar'], unicode) and x['unicode_varchar'] == unicodedata) self.assert_(isinstance(x['unicode_text'], unicode) and x['unicode_text'] == unicodedata) self.assert_(isinstance(x['plain_varchar'], unicode) and x['plain_varchar'] == unicodedata) finally: - db.engine.dialect.convert_unicode = prev_unicode + testbase.db.engine.dialect.convert_unicode = prev_unicode - @testbase.unsupported('oracle') + @testing.unsupported('oracle') def testlength(self): """checks the database correctly understands the length of a unicode string""" teststr = u'aaa\x1234' - self.assert_(db.func.length(teststr).scalar() == len(teststr)) + self.assert_(testbase.db.func.length(teststr).scalar() == len(teststr)) class BinaryTest(AssertMixin): def setUpAll(self): global binary_table - binary_table = Table('binary_table', db, + binary_table = Table('binary_table', MetaData(testbase.db), Column('primary_id', Integer, Sequence('binary_id_seq', optional=True), primary_key=True), Column('data', Binary), Column('data_slice', Binary(100)), @@ -244,39 +243,31 @@ class BinaryTest(AssertMixin): binary_table.insert().execute(primary_id=1, misc='binary_data_one.dat', data=stream1, data_slice=stream1[0:100], pickled=testobj1) binary_table.insert().execute(primary_id=2, misc='binary_data_two.dat', data=stream2, data_slice=stream2[0:99], pickled=testobj2) binary_table.insert().execute(primary_id=3, misc='binary_data_two.dat', data=None, data_slice=stream2[0:99], pickled=None) - l = binary_table.select(order_by=binary_table.c.primary_id).execute().fetchall() - print type(stream1), type(l[0]['data']), type(l[0]['data_slice']) - print len(stream1), len(l[0]['data']), len(l[0]['data_slice']) - self.assert_(list(stream1) == list(l[0]['data'])) - self.assert_(list(stream1[0:100]) == list(l[0]['data_slice'])) - self.assert_(list(stream2) == list(l[1]['data'])) - self.assert_(testobj1 == l[0]['pickled']) - self.assert_(testobj2 == l[1]['pickled']) + + for stmt in ( + binary_table.select(order_by=binary_table.c.primary_id), + text("select * from binary_table order by binary_table.primary_id", typemap={'pickled':PickleType}, bind=testbase.db) + ): + l = stmt.execute().fetchall() + print type(stream1), type(l[0]['data']), type(l[0]['data_slice']) + print len(stream1), len(l[0]['data']), len(l[0]['data_slice']) + self.assert_(list(stream1) == list(l[0]['data'])) + self.assert_(list(stream1[0:100]) == list(l[0]['data_slice'])) + self.assert_(list(stream2) == list(l[1]['data'])) + self.assert_(testobj1 == l[0]['pickled']) + self.assert_(testobj2 == l[1]['pickled']) def load_stream(self, name, len=12579): f = os.path.join(os.path.dirname(testbase.__file__), name) # put a number less than the typical MySQL default BLOB size return file(f).read(len) - @testbase.supported('oracle') - def test_oracle_autobinary(self): - stream1 =self.load_stream('binary_data_one.dat') - stream2 =self.load_stream('binary_data_two.dat') - binary_table.insert().execute(primary_id=1, misc='binary_data_one.dat', data=stream1, data_slice=stream1[0:100]) - binary_table.insert().execute(primary_id=2, misc='binary_data_two.dat', data=stream2, data_slice=stream2[0:99]) - binary_table.insert().execute(primary_id=3, misc='binary_data_two.dat', data=None, data_slice=stream2[0:99], pickled=None) - result = testbase.db.connect().execute("select primary_id, misc, data, data_slice from binary_table") - l = result.fetchall() - l[0]['data'] - self.assert_(list(stream1) == list(l[0]['data'])) - self.assert_(list(stream1[0:100]) == list(l[0]['data_slice'])) - self.assert_(list(stream2) == list(l[1]['data'])) - class DateTest(AssertMixin): def setUpAll(self): global users_with_date, insert_data + db = testbase.db if db.engine.name == 'oracle': import sqlalchemy.databases.oracle as oracle insert_data = [ @@ -314,13 +305,14 @@ class DateTest(AssertMixin): if db.engine.name == 'mssql': # MSSQL Datetime values have only a 3.33 milliseconds precision insert_data[2] = [9, 'foo', datetime.datetime(2005, 11, 10, 11, 52, 35, 547000), datetime.date(1970,4,1), datetime.time(23,59,59,997000)] - + fnames = ['user_id', 'user_name', 'user_datetime', 'user_date', 'user_time'] collist = [Column('user_id', INT, primary_key = True), Column('user_name', VARCHAR(20)), Column('user_datetime', DateTime(timezone=False)), Column('user_date', Date), Column('user_time', Time)] - users_with_date = Table('query_users_with_date', db, *collist) + users_with_date = Table('query_users_with_date', + MetaData(testbase.db), *collist) users_with_date.create() insert_dicts = [dict(zip(fnames, d)) for d in insert_data] @@ -338,7 +330,7 @@ class DateTest(AssertMixin): def testtextdate(self): - x = db.text("select user_datetime from query_users_with_date", typemap={'user_datetime':DateTime}).execute().fetchall() + x = testbase.db.text("select user_datetime from query_users_with_date", typemap={'user_datetime':DateTime}).execute().fetchall() print repr(x) self.assert_(isinstance(x[0][0], datetime.datetime)) @@ -347,9 +339,13 @@ class DateTest(AssertMixin): #print repr(x) def testdate2(self): - t = Table('testdate', testbase.metadata, Column('id', Integer, Sequence('datetest_id_seq', optional=True), primary_key=True), + meta = MetaData(testbase.db) + t = Table('testdate', meta, + Column('id', Integer, + Sequence('datetest_id_seq', optional=True), + primary_key=True), Column('adate', Date), Column('adatetime', DateTime)) - t.create() + t.create(checkfirst=True) try: d1 = datetime.date(2007, 10, 30) t.insert().execute(adate=d1, adatetime=d1) @@ -361,8 +357,43 @@ class DateTest(AssertMixin): self.assert_(x.adatetime.__class__ == datetime.datetime) finally: - t.drop() + t.drop(checkfirst=True) +class NumericTest(AssertMixin): + def setUpAll(self): + global numeric_table, metadata + metadata = MetaData(testbase.db) + numeric_table = Table('numeric_table', metadata, + Column('id', Integer, Sequence('numeric_id_seq', optional=True), primary_key=True), + Column('numericcol', Numeric(asdecimal=False)), + Column('floatcol', Float), + Column('ncasdec', Numeric), + Column('fcasdec', Float(asdecimal=True)) + ) + metadata.create_all() + + def tearDownAll(self): + metadata.drop_all() + + def tearDown(self): + numeric_table.delete().execute() + + def test_decimal(self): + from decimal import Decimal + numeric_table.insert().execute(numericcol=3.5, floatcol=5.6, ncasdec=12.4, fcasdec=15.78) + numeric_table.insert().execute(numericcol=Decimal("3.5"), floatcol=Decimal("5.6"), ncasdec=Decimal("12.4"), fcasdec=Decimal("15.78")) + l = numeric_table.select().execute().fetchall() + print l + rounded = [ + (l[0][0], l[0][1], round(l[0][2], 5), l[0][3], l[0][4]), + (l[1][0], l[1][1], round(l[1][2], 5), l[1][3], l[1][4]), + ] + assert rounded == [ + (1, 3.5, 5.6, Decimal("12.4"), Decimal("15.78")), + (2, 3.5, 5.6, Decimal("12.4"), Decimal("15.78")), + ] + + class IntervalTest(AssertMixin): def setUpAll(self): global interval_table, metadata diff --git a/test/sql/unicode.py b/test/sql/unicode.py index 7ce42bf4c..f882c2a5f 100644 --- a/test/sql/unicode.py +++ b/test/sql/unicode.py @@ -1,23 +1,27 @@ # coding: utf-8 -import testbase +"""verrrrry basic unicode column name testing""" +import testbase from sqlalchemy import * +from sqlalchemy.orm import mapper, relation, create_session, eagerload +from testlib import * -"""verrrrry basic unicode column name testing""" -class UnicodeSchemaTest(testbase.PersistTest): +class UnicodeSchemaTest(PersistTest): def setUpAll(self): - global metadata, t1, t2 - metadata = MetaData(engine=testbase.db) + global unicode_bind, metadata, t1, t2 + + unicode_bind = self._unicode_bind() + + metadata = MetaData(unicode_bind) t1 = Table('unitable1', metadata, Column(u'méil', Integer, primary_key=True), - Column(u'éXXm', Integer), + Column(u'\u6e2c\u8a66', Integer), ) - t2 = Table(u'unitéble2', metadata, + t2 = Table(u'Unitéble2', metadata, Column(u'méil', Integer, primary_key=True, key="a"), - Column(u'éXXm', Integer, ForeignKey(u'unitable1.méil'), key="b"), - + Column(u'\u6e2c\u8a66', Integer, ForeignKey(u'unitable1.méil'), key="b"), ) metadata.create_all() @@ -26,24 +30,46 @@ class UnicodeSchemaTest(testbase.PersistTest): t1.delete().execute() def tearDownAll(self): + global unicode_bind metadata.drop_all() + del unicode_bind + + def _unicode_bind(self): + if testbase.db.name != 'mysql': + return testbase.db + else: + # most mysql installations don't default to utf8 connections + version = testbase.db.dialect.get_version_info(testbase.db) + if version < (4, 1): + raise AssertionError("Unicode not supported on MySQL < 4.1") + + c = testbase.db.connect() + if not hasattr(c.connection.connection, 'set_character_set'): + raise AssertionError( + "Unicode not supported on this MySQL-python version") + else: + c.connection.set_character_set('utf8') + c.detach() + + return c def test_insert(self): - t1.insert().execute({u'méil':1, u'éXXm':5}) + t1.insert().execute({u'méil':1, u'\u6e2c\u8a66':5}) t2.insert().execute({'a':1, 'b':1}) assert t1.select().execute().fetchall() == [(1, 5)] assert t2.select().execute().fetchall() == [(1, 1)] def test_reflect(self): - t1.insert().execute({u'méil':2, u'éXXm':7}) + t1.insert().execute({u'méil':2, u'\u6e2c\u8a66':7}) t2.insert().execute({'a':2, 'b':2}) - meta = MetaData(testbase.db) + meta = MetaData(unicode_bind) tt1 = Table(t1.name, meta, autoload=True) tt2 = Table(t2.name, meta, autoload=True) - tt1.insert().execute({u'méil':1, u'éXXm':5}) - tt2.insert().execute({u'méil':1, u'éXXm':1}) + + tt1.insert().execute({u'méil':1, u'\u6e2c\u8a66':5}) + tt2.insert().execute({u'méil':1, u'\u6e2c\u8a66':1}) assert tt1.select(order_by=desc(u'méil')).execute().fetchall() == [(2, 7), (1, 5)] assert tt2.select(order_by=desc(u'méil')).execute().fetchall() == [(2, 2), (1, 1)] @@ -57,7 +83,7 @@ class UnicodeSchemaTest(testbase.PersistTest): mapper(A, t1, properties={ 't2s':relation(B), 'a':t1.c[u'méil'], - 'b':t1.c[u'éXXm'] + 'b':t1.c[u'\u6e2c\u8a66'] }) mapper(B, t2) sess = create_session() |