summaryrefslogtreecommitdiff
path: root/test/testbase.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2006-05-25 14:20:23 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2006-05-25 14:20:23 +0000
commitbb79e2e871d0a4585164c1a6ed626d96d0231975 (patch)
tree6d457ba6c36c408b45db24ec3c29e147fe7504ff /test/testbase.py
parent4fc3a0648699c2b441251ba4e1d37a9107bd1986 (diff)
downloadsqlalchemy-bb79e2e871d0a4585164c1a6ed626d96d0231975.tar.gz
merged 0.2 branch into trunk; 0.1 now in sqlalchemy/branches/rel_0_1
Diffstat (limited to 'test/testbase.py')
-rw-r--r--test/testbase.py223
1 files changed, 147 insertions, 76 deletions
diff --git a/test/testbase.py b/test/testbase.py
index 8ef63ef27..04972779d 100644
--- a/test/testbase.py
+++ b/test/testbase.py
@@ -2,63 +2,96 @@ import unittest
import StringIO
import sqlalchemy.engine as engine
import sqlalchemy.ext.proxy as proxy
-import sqlalchemy.schema as schema
+import sqlalchemy.pool as pool
+#import sqlalchemy.schema as schema
import re, sys
+import sqlalchemy
+import optparse
+
-echo = True
-#echo = False
-#echo = 'debug'
db = None
+metadata = None
db_uri = None
+echo = True
+
+# redefine sys.stdout so all those print statements go to the echo func
+local_stdout = sys.stdout
+class Logger(object):
+ def write(self, msg):
+ if echo:
+ local_stdout.write(msg)
+sys.stdout = Logger()
+
+def echo_text(text):
+ print text
def parse_argv():
# we are using the unittest main runner, so we are just popping out the
# arguments we need instead of using our own getopt type of thing
- global db, db_uri
+ global db, db_uri, metadata
DBTYPE = 'sqlite'
PROXY = False
+
+
+ parser = optparse.OptionParser(usage = "usage: %prog [options] files...")
+ parser.add_option("--dburi", action="store", dest="dburi", help="database uri (overrides --db)")
+ parser.add_option("--db", action="store", dest="db", default="sqlite", help="prefab database uri (sqlite, sqlite_file, postgres, mysql, oracle, oracle8, mssql)")
+ parser.add_option("--mockpool", action="store_true", dest="mockpool", help="use mock pool")
+ parser.add_option("--verbose", action="store_true", dest="verbose", help="full debug echoing")
+ parser.add_option("--quiet", action="store_true", dest="quiet", help="be totally quiet")
+ parser.add_option("--nothreadlocal", action="store_true", dest="nothreadlocal", help="dont use thread-local mod")
+ parser.add_option("--enginestrategy", action="store", default=None, dest="enginestrategy", help="engine strategy (plain or threadlocal, defaults to SA default)")
+
+ (options, args) = parser.parse_args()
+ sys.argv[1:] = args
- if len(sys.argv) >= 3:
- if sys.argv[1] == '--dburi':
- (param, db_uri) = (sys.argv.pop(1), sys.argv.pop(1))
- elif sys.argv[1] == '--db':
- (param, DBTYPE) = (sys.argv.pop(1), sys.argv.pop(1))
+ if options.dburi:
+ db_uri = param = options.dburi
+ elif options.db:
+ DBTYPE = param = options.db
+
opts = {}
if (None == db_uri):
- p = DBTYPE.split('.')
- if len(p) > 1:
- arg = p[0]
- DBTYPE = p[1]
- if arg == 'proxy':
- PROXY = True
if DBTYPE == 'sqlite':
- db_uri = 'sqlite://filename=:memory:'
+ db_uri = 'sqlite:///:memory:'
elif DBTYPE == 'sqlite_file':
- db_uri = 'sqlite://filename=querytest.db'
+ db_uri = 'sqlite:///querytest.db'
elif DBTYPE == 'postgres':
- db_uri = 'postgres://database=test&port=5432&host=127.0.0.1&user=scott&password=tiger'
+ db_uri = 'postgres://scott:tiger@127.0.0.1:5432/test'
elif DBTYPE == 'mysql':
- db_uri = 'mysql://database=test&host=127.0.0.1&user=scott&password=tiger'
+ db_uri = 'mysql://scott:tiger@127.0.0.1/test'
elif DBTYPE == 'oracle':
- db_uri = 'oracle://user=scott&password=tiger'
+ db_uri = 'oracle://scott:tiger@127.0.0.1:1521'
elif DBTYPE == 'oracle8':
- db_uri = 'oracle://user=scott&password=tiger'
+ db_uri = 'oracle://scott:tiger@127.0.0.1:1521'
opts = {'use_ansi':False}
elif DBTYPE == 'mssql':
- db_uri = 'mssql://database=test&user=scott&password=tiger'
+ db_uri = 'mssql://scott:tiger@/test'
if not db_uri:
raise "Could not create engine. specify --db <sqlite|sqlite_file|postgres|mysql|oracle|oracle8|mssql> to test runner."
- if PROXY:
- db = proxy.ProxyEngine(echo=echo, default_ordering=True, **opts)
- db.connect(db_uri)
+ if not options.nothreadlocal:
+ __import__('sqlalchemy.mods.threadlocal')
+ sqlalchemy.mods.threadlocal.uninstall_plugin()
+
+ global echo
+ echo = options.verbose and not options.quiet
+
+ global quiet
+ quiet = options.quiet
+
+ if options.enginestrategy is not None:
+ opts['strategy'] = options.enginestrategy
+ if options.mockpool:
+ db = engine.create_engine(db_uri, echo=True, default_ordering=True, poolclass=MockPool, **opts)
else:
- db = engine.create_engine(db_uri, echo=echo, default_ordering=True, **opts)
+ db = engine.create_engine(db_uri, echo=True, default_ordering=True, **opts)
db = EngineAssert(db)
-
+ metadata = sqlalchemy.BoundMetaData(db)
+
def unsupported(*dbs):
"""a decorator that marks a test as unsupported by one or more database implementations"""
def decorate(func):
@@ -87,21 +120,49 @@ def supported(*dbs):
return lala
return decorate
-def echo_text(text):
- print text
class PersistTest(unittest.TestCase):
"""persist base class, provides default setUpAll, tearDownAll and echo functionality"""
def __init__(self, *args, **params):
unittest.TestCase.__init__(self, *args, **params)
def echo(self, text):
- if echo:
- echo_text(text)
+ echo_text(text)
+ def install_threadlocal(self):
+ sqlalchemy.mods.threadlocal.install_plugin()
+ def uninstall_threadlocal(self):
+ sqlalchemy.mods.threadlocal.uninstall_plugin()
def setUpAll(self):
pass
def tearDownAll(self):
pass
+ def shortDescription(self):
+ """overridden to not return docstrings"""
+ return None
+
+class MockPool(pool.Pool):
+ """this pool is hardcore about only one connection being used at a time."""
+ def __init__(self, creator, **params):
+ pool.Pool.__init__(self, **params)
+ self.connection = creator()
+ self._conn = self.connection
+
+ def status(self):
+ return "MockPool"
+
+ def do_return_conn(self, conn):
+ assert conn is self._conn and self.connection is None
+ self.connection = conn
+
+ def do_return_invalid(self):
+ raise "Invalid"
+ def do_get(self):
+ if getattr(self, 'breakpoint', False):
+ raise "breakpoint"
+ assert self.connection is not None
+ c = self.connection
+ self.connection = None
+ return c
class AssertMixin(PersistTest):
"""given a list-based structure of keys/properties which represent information within an object structure, and
@@ -145,8 +206,10 @@ class EngineAssert(proxy.BaseProxyEngine):
"""decorates a SQLEngine object to match the incoming queries against a set of assertions."""
def __init__(self, engine):
self._engine = engine
- self.realexec = engine.post_exec
- self.realexec.im_self.post_exec = self.post_exec
+
+ self.real_execution_context = engine.dialect.create_execution_context
+ engine.dialect.create_execution_context = self.execution_context
+
self.logger = engine.logger
self.set_assert_list(None, None)
self.sql_count = 0
@@ -154,8 +217,6 @@ class EngineAssert(proxy.BaseProxyEngine):
return self._engine
def set_engine(self, e):
self._engine = e
-# def __getattr__(self, key):
- # return getattr(self.engine, key)
def set_assert_list(self, unittest, list):
self.unittest = unittest
self.assert_list = list
@@ -164,46 +225,55 @@ class EngineAssert(proxy.BaseProxyEngine):
def _set_echo(self, echo):
self.engine.echo = echo
echo = property(lambda s: s.engine.echo, _set_echo)
- def post_exec(self, proxy, compiled, parameters, **kwargs):
- self.engine.logger = self.logger
- statement = str(compiled)
- statement = re.sub(r'\n', '', statement)
-
- if self.assert_list is not None:
- item = self.assert_list[-1]
- if not isinstance(item, dict):
- item = self.assert_list.pop()
- else:
- # asserting a dictionary of statements->parameters
- # this is to specify query assertions where the queries can be in
- # multiple orderings
- if not item.has_key('_converted'):
- for key in item.keys():
- ckey = self.convert_statement(key)
- item[ckey] = item[key]
- if ckey != key:
- del item[key]
- item['_converted'] = True
- try:
- entry = item.pop(statement)
- if len(item) == 1:
- self.assert_list.pop()
- item = (statement, entry)
- except KeyError:
- self.unittest.assert_(False, "Testing for one of the following queries: %s, received '%s'" % (repr([k for k in item.keys()]), statement))
-
- (query, params) = item
- if callable(params):
- params = params()
-
- query = self.convert_statement(query)
-
- self.unittest.assert_(statement == query and (params is None or params == parameters), "Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters)))
- self.sql_count += 1
- return self.realexec(proxy, compiled, parameters, **kwargs)
+
+ def execution_context(self):
+ def post_exec(engine, proxy, compiled, parameters, **kwargs):
+ ctx = e
+ self.engine.logger = self.logger
+ statement = str(compiled)
+ statement = re.sub(r'\n', '', statement)
+ if self.assert_list is not None:
+ item = self.assert_list[-1]
+ if not isinstance(item, dict):
+ item = self.assert_list.pop()
+ else:
+ # asserting a dictionary of statements->parameters
+ # this is to specify query assertions where the queries can be in
+ # multiple orderings
+ if not item.has_key('_converted'):
+ for key in item.keys():
+ ckey = self.convert_statement(key)
+ item[ckey] = item[key]
+ if ckey != key:
+ del item[key]
+ item['_converted'] = True
+ try:
+ entry = item.pop(statement)
+ if len(item) == 1:
+ self.assert_list.pop()
+ item = (statement, entry)
+ except KeyError:
+ self.unittest.assert_(False, "Testing for one of the following queries: %s, received '%s'" % (repr([k for k in item.keys()]), statement))
+
+ (query, params) = item
+ if callable(params):
+ params = params(ctx)
+ if params is not None and isinstance(params, list) and len(params) == 1:
+ params = params[0]
+
+ query = self.convert_statement(query)
+ self.unittest.assert_(statement == query and (params is None or params == parameters), "Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters)))
+ self.sql_count += 1
+ return realexec(ctx, proxy, compiled, parameters, **kwargs)
+
+ e = self.real_execution_context()
+ realexec = e.post_exec
+ realexec.im_self.post_exec = post_exec
+ return e
+
def convert_statement(self, query):
- paramstyle = self.engine.paramstyle
+ paramstyle = self.engine.dialect.paramstyle
if paramstyle == 'named':
pass
elif paramstyle =='pyformat':
@@ -275,10 +345,11 @@ parse_argv()
def runTests(suite):
- runner = unittest.TextTestRunner(verbosity = 2, descriptions =1)
+ runner = unittest.TextTestRunner(verbosity = quiet and 1 or 2)
runner.run(suite)
def main():
- unittest.main()
+ suite = unittest.TestLoader().loadTestsFromModule(__import__('__main__'))
+ runTests(suite)