summaryrefslogtreecommitdiff
path: root/test/testbase.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/testbase.py')
-rw-r--r--test/testbase.py23
1 files changed, 18 insertions, 5 deletions
diff --git a/test/testbase.py b/test/testbase.py
index cdd0d6a33..2a9094f5d 100644
--- a/test/testbase.py
+++ b/test/testbase.py
@@ -1,6 +1,7 @@
import unittest
import StringIO
import sqlalchemy.engine as engine
+import sqlalchemy.ext.proxy
import re, sys
echo = True
@@ -15,14 +16,22 @@ def parse_argv():
global db, db_uri
DBTYPE = 'sqlite'
-
+ PROXY = False
+
if len(sys.argv) >= 3:
if sys.argv[1] == '--dburi':
(param, db_uri) = (sys.argv.pop(1), sys.argv.pop(1))
elif sys.argv[1] == '--db':
(param, DBTYPE) = (sys.argv.pop(1), sys.argv.pop(1))
+
if (None == db_uri):
+ p = DBTYPE.split('.')
+ if len(p) > 1:
+ arg = p[0]
+ DBTYPE = p[1]
+ if arg == 'proxy':
+ PROXY = True
if DBTYPE == 'sqlite':
db_uri = 'sqlite://filename=:memory:'
elif DBTYPE == 'sqlite_file':
@@ -37,7 +46,11 @@ def parse_argv():
if not db_uri:
raise "Could not create engine. specify --db <sqlite|sqlite_file|postgres|mysql|oracle> to test runner."
- db = engine.create_engine(db_uri, echo=echo, default_ordering=True)
+ if PROXY:
+ db = sqlalchemy.ext.proxy.ProxyEngine(echo=echo, default_ordering=True)
+ db.connect(db_uri)
+ else:
+ db = engine.create_engine(db_uri, echo=echo, default_ordering=True)
db = EngineAssert(db)
class PersistTest(unittest.TestCase):
@@ -75,7 +88,7 @@ class AssertMixin(PersistTest):
else:
self.assert_(getattr(rowobj, key) == value, "attribute %s value %s does not match %s" % (key, getattr(rowobj, key), value))
def assert_sql(self, db, callable_, list, with_sequences=None):
- if with_sequences is not None and (db.engine.__module__.endswith('postgres') or db.engine.__module__.endswith('oracle')):
+ if with_sequences is not None and (db.engine.name == 'postgres' or db.engine.name == 'oracle'):
db.set_assert_list(self, with_sequences)
else:
db.set_assert_list(self, list)
@@ -89,13 +102,13 @@ class AssertMixin(PersistTest):
callable_()
finally:
self.assert_(db.sql_count == count, "desired statement count %d does not match %d" % (count, db.sql_count))
-
+
class EngineAssert(object):
"""decorates a SQLEngine object to match the incoming queries against a set of assertions."""
def __init__(self, engine):
self.engine = engine
self.realexec = engine.post_exec
- engine.post_exec = self.post_exec
+ self.realexec.im_self.post_exec = self.post_exec
self.logger = engine.logger
self.set_assert_list(None, None)
self.sql_count = 0