diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2006-05-25 14:20:23 +0000 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2006-05-25 14:20:23 +0000 |
commit | bb79e2e871d0a4585164c1a6ed626d96d0231975 (patch) | |
tree | 6d457ba6c36c408b45db24ec3c29e147fe7504ff /test/testbase.py | |
parent | 4fc3a0648699c2b441251ba4e1d37a9107bd1986 (diff) | |
download | sqlalchemy-bb79e2e871d0a4585164c1a6ed626d96d0231975.tar.gz |
merged 0.2 branch into trunk; 0.1 now in sqlalchemy/branches/rel_0_1
Diffstat (limited to 'test/testbase.py')
-rw-r--r-- | test/testbase.py | 223 |
1 files changed, 147 insertions, 76 deletions
diff --git a/test/testbase.py b/test/testbase.py index 8ef63ef27..04972779d 100644 --- a/test/testbase.py +++ b/test/testbase.py @@ -2,63 +2,96 @@ import unittest import StringIO import sqlalchemy.engine as engine import sqlalchemy.ext.proxy as proxy -import sqlalchemy.schema as schema +import sqlalchemy.pool as pool +#import sqlalchemy.schema as schema import re, sys +import sqlalchemy +import optparse + -echo = True -#echo = False -#echo = 'debug' db = None +metadata = None db_uri = None +echo = True + +# redefine sys.stdout so all those print statements go to the echo func +local_stdout = sys.stdout +class Logger(object): + def write(self, msg): + if echo: + local_stdout.write(msg) +sys.stdout = Logger() + +def echo_text(text): + print text def parse_argv(): # we are using the unittest main runner, so we are just popping out the # arguments we need instead of using our own getopt type of thing - global db, db_uri + global db, db_uri, metadata DBTYPE = 'sqlite' PROXY = False + + + parser = optparse.OptionParser(usage = "usage: %prog [options] files...") + parser.add_option("--dburi", action="store", dest="dburi", help="database uri (overrides --db)") + parser.add_option("--db", action="store", dest="db", default="sqlite", help="prefab database uri (sqlite, sqlite_file, postgres, mysql, oracle, oracle8, mssql)") + parser.add_option("--mockpool", action="store_true", dest="mockpool", help="use mock pool") + parser.add_option("--verbose", action="store_true", dest="verbose", help="full debug echoing") + parser.add_option("--quiet", action="store_true", dest="quiet", help="be totally quiet") + parser.add_option("--nothreadlocal", action="store_true", dest="nothreadlocal", help="dont use thread-local mod") + parser.add_option("--enginestrategy", action="store", default=None, dest="enginestrategy", help="engine strategy (plain or threadlocal, defaults to SA default)") + + (options, args) = parser.parse_args() + sys.argv[1:] = args - if len(sys.argv) >= 3: - if sys.argv[1] == '--dburi': - (param, db_uri) = (sys.argv.pop(1), sys.argv.pop(1)) - elif sys.argv[1] == '--db': - (param, DBTYPE) = (sys.argv.pop(1), sys.argv.pop(1)) + if options.dburi: + db_uri = param = options.dburi + elif options.db: + DBTYPE = param = options.db + opts = {} if (None == db_uri): - p = DBTYPE.split('.') - if len(p) > 1: - arg = p[0] - DBTYPE = p[1] - if arg == 'proxy': - PROXY = True if DBTYPE == 'sqlite': - db_uri = 'sqlite://filename=:memory:' + db_uri = 'sqlite:///:memory:' elif DBTYPE == 'sqlite_file': - db_uri = 'sqlite://filename=querytest.db' + db_uri = 'sqlite:///querytest.db' elif DBTYPE == 'postgres': - db_uri = 'postgres://database=test&port=5432&host=127.0.0.1&user=scott&password=tiger' + db_uri = 'postgres://scott:tiger@127.0.0.1:5432/test' elif DBTYPE == 'mysql': - db_uri = 'mysql://database=test&host=127.0.0.1&user=scott&password=tiger' + db_uri = 'mysql://scott:tiger@127.0.0.1/test' elif DBTYPE == 'oracle': - db_uri = 'oracle://user=scott&password=tiger' + db_uri = 'oracle://scott:tiger@127.0.0.1:1521' elif DBTYPE == 'oracle8': - db_uri = 'oracle://user=scott&password=tiger' + db_uri = 'oracle://scott:tiger@127.0.0.1:1521' opts = {'use_ansi':False} elif DBTYPE == 'mssql': - db_uri = 'mssql://database=test&user=scott&password=tiger' + db_uri = 'mssql://scott:tiger@/test' if not db_uri: raise "Could not create engine. specify --db <sqlite|sqlite_file|postgres|mysql|oracle|oracle8|mssql> to test runner." - if PROXY: - db = proxy.ProxyEngine(echo=echo, default_ordering=True, **opts) - db.connect(db_uri) + if not options.nothreadlocal: + __import__('sqlalchemy.mods.threadlocal') + sqlalchemy.mods.threadlocal.uninstall_plugin() + + global echo + echo = options.verbose and not options.quiet + + global quiet + quiet = options.quiet + + if options.enginestrategy is not None: + opts['strategy'] = options.enginestrategy + if options.mockpool: + db = engine.create_engine(db_uri, echo=True, default_ordering=True, poolclass=MockPool, **opts) else: - db = engine.create_engine(db_uri, echo=echo, default_ordering=True, **opts) + db = engine.create_engine(db_uri, echo=True, default_ordering=True, **opts) db = EngineAssert(db) - + metadata = sqlalchemy.BoundMetaData(db) + def unsupported(*dbs): """a decorator that marks a test as unsupported by one or more database implementations""" def decorate(func): @@ -87,21 +120,49 @@ def supported(*dbs): return lala return decorate -def echo_text(text): - print text class PersistTest(unittest.TestCase): """persist base class, provides default setUpAll, tearDownAll and echo functionality""" def __init__(self, *args, **params): unittest.TestCase.__init__(self, *args, **params) def echo(self, text): - if echo: - echo_text(text) + echo_text(text) + def install_threadlocal(self): + sqlalchemy.mods.threadlocal.install_plugin() + def uninstall_threadlocal(self): + sqlalchemy.mods.threadlocal.uninstall_plugin() def setUpAll(self): pass def tearDownAll(self): pass + def shortDescription(self): + """overridden to not return docstrings""" + return None + +class MockPool(pool.Pool): + """this pool is hardcore about only one connection being used at a time.""" + def __init__(self, creator, **params): + pool.Pool.__init__(self, **params) + self.connection = creator() + self._conn = self.connection + + def status(self): + return "MockPool" + + def do_return_conn(self, conn): + assert conn is self._conn and self.connection is None + self.connection = conn + + def do_return_invalid(self): + raise "Invalid" + def do_get(self): + if getattr(self, 'breakpoint', False): + raise "breakpoint" + assert self.connection is not None + c = self.connection + self.connection = None + return c class AssertMixin(PersistTest): """given a list-based structure of keys/properties which represent information within an object structure, and @@ -145,8 +206,10 @@ class EngineAssert(proxy.BaseProxyEngine): """decorates a SQLEngine object to match the incoming queries against a set of assertions.""" def __init__(self, engine): self._engine = engine - self.realexec = engine.post_exec - self.realexec.im_self.post_exec = self.post_exec + + self.real_execution_context = engine.dialect.create_execution_context + engine.dialect.create_execution_context = self.execution_context + self.logger = engine.logger self.set_assert_list(None, None) self.sql_count = 0 @@ -154,8 +217,6 @@ class EngineAssert(proxy.BaseProxyEngine): return self._engine def set_engine(self, e): self._engine = e -# def __getattr__(self, key): - # return getattr(self.engine, key) def set_assert_list(self, unittest, list): self.unittest = unittest self.assert_list = list @@ -164,46 +225,55 @@ class EngineAssert(proxy.BaseProxyEngine): def _set_echo(self, echo): self.engine.echo = echo echo = property(lambda s: s.engine.echo, _set_echo) - def post_exec(self, proxy, compiled, parameters, **kwargs): - self.engine.logger = self.logger - statement = str(compiled) - statement = re.sub(r'\n', '', statement) - - if self.assert_list is not None: - item = self.assert_list[-1] - if not isinstance(item, dict): - item = self.assert_list.pop() - else: - # asserting a dictionary of statements->parameters - # this is to specify query assertions where the queries can be in - # multiple orderings - if not item.has_key('_converted'): - for key in item.keys(): - ckey = self.convert_statement(key) - item[ckey] = item[key] - if ckey != key: - del item[key] - item['_converted'] = True - try: - entry = item.pop(statement) - if len(item) == 1: - self.assert_list.pop() - item = (statement, entry) - except KeyError: - self.unittest.assert_(False, "Testing for one of the following queries: %s, received '%s'" % (repr([k for k in item.keys()]), statement)) - - (query, params) = item - if callable(params): - params = params() - - query = self.convert_statement(query) - - self.unittest.assert_(statement == query and (params is None or params == parameters), "Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters))) - self.sql_count += 1 - return self.realexec(proxy, compiled, parameters, **kwargs) + + def execution_context(self): + def post_exec(engine, proxy, compiled, parameters, **kwargs): + ctx = e + self.engine.logger = self.logger + statement = str(compiled) + statement = re.sub(r'\n', '', statement) + if self.assert_list is not None: + item = self.assert_list[-1] + if not isinstance(item, dict): + item = self.assert_list.pop() + else: + # asserting a dictionary of statements->parameters + # this is to specify query assertions where the queries can be in + # multiple orderings + if not item.has_key('_converted'): + for key in item.keys(): + ckey = self.convert_statement(key) + item[ckey] = item[key] + if ckey != key: + del item[key] + item['_converted'] = True + try: + entry = item.pop(statement) + if len(item) == 1: + self.assert_list.pop() + item = (statement, entry) + except KeyError: + self.unittest.assert_(False, "Testing for one of the following queries: %s, received '%s'" % (repr([k for k in item.keys()]), statement)) + + (query, params) = item + if callable(params): + params = params(ctx) + if params is not None and isinstance(params, list) and len(params) == 1: + params = params[0] + + query = self.convert_statement(query) + self.unittest.assert_(statement == query and (params is None or params == parameters), "Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters))) + self.sql_count += 1 + return realexec(ctx, proxy, compiled, parameters, **kwargs) + + e = self.real_execution_context() + realexec = e.post_exec + realexec.im_self.post_exec = post_exec + return e + def convert_statement(self, query): - paramstyle = self.engine.paramstyle + paramstyle = self.engine.dialect.paramstyle if paramstyle == 'named': pass elif paramstyle =='pyformat': @@ -275,10 +345,11 @@ parse_argv() def runTests(suite): - runner = unittest.TextTestRunner(verbosity = 2, descriptions =1) + runner = unittest.TextTestRunner(verbosity = quiet and 1 or 2) runner.run(suite) def main(): - unittest.main() + suite = unittest.TestLoader().loadTestsFromModule(__import__('__main__')) + runTests(suite) |