diff options
Diffstat (limited to 'test/testbase.py')
-rw-r--r-- | test/testbase.py | 23 |
1 files changed, 18 insertions, 5 deletions
diff --git a/test/testbase.py b/test/testbase.py index cdd0d6a33..2a9094f5d 100644 --- a/test/testbase.py +++ b/test/testbase.py @@ -1,6 +1,7 @@ import unittest import StringIO import sqlalchemy.engine as engine +import sqlalchemy.ext.proxy import re, sys echo = True @@ -15,14 +16,22 @@ def parse_argv(): global db, db_uri DBTYPE = 'sqlite' - + PROXY = False + if len(sys.argv) >= 3: if sys.argv[1] == '--dburi': (param, db_uri) = (sys.argv.pop(1), sys.argv.pop(1)) elif sys.argv[1] == '--db': (param, DBTYPE) = (sys.argv.pop(1), sys.argv.pop(1)) + if (None == db_uri): + p = DBTYPE.split('.') + if len(p) > 1: + arg = p[0] + DBTYPE = p[1] + if arg == 'proxy': + PROXY = True if DBTYPE == 'sqlite': db_uri = 'sqlite://filename=:memory:' elif DBTYPE == 'sqlite_file': @@ -37,7 +46,11 @@ def parse_argv(): if not db_uri: raise "Could not create engine. specify --db <sqlite|sqlite_file|postgres|mysql|oracle> to test runner." - db = engine.create_engine(db_uri, echo=echo, default_ordering=True) + if PROXY: + db = sqlalchemy.ext.proxy.ProxyEngine(echo=echo, default_ordering=True) + db.connect(db_uri) + else: + db = engine.create_engine(db_uri, echo=echo, default_ordering=True) db = EngineAssert(db) class PersistTest(unittest.TestCase): @@ -75,7 +88,7 @@ class AssertMixin(PersistTest): else: self.assert_(getattr(rowobj, key) == value, "attribute %s value %s does not match %s" % (key, getattr(rowobj, key), value)) def assert_sql(self, db, callable_, list, with_sequences=None): - if with_sequences is not None and (db.engine.__module__.endswith('postgres') or db.engine.__module__.endswith('oracle')): + if with_sequences is not None and (db.engine.name == 'postgres' or db.engine.name == 'oracle'): db.set_assert_list(self, with_sequences) else: db.set_assert_list(self, list) @@ -89,13 +102,13 @@ class AssertMixin(PersistTest): callable_() finally: self.assert_(db.sql_count == count, "desired statement count %d does not match %d" % (count, db.sql_count)) - + class EngineAssert(object): """decorates a SQLEngine object to match the incoming queries against a set of assertions.""" def __init__(self, engine): self.engine = engine self.realexec = engine.post_exec - engine.post_exec = self.post_exec + self.realexec.im_self.post_exec = self.post_exec self.logger = engine.logger self.set_assert_list(None, None) self.sql_count = 0 |