summaryrefslogtreecommitdiff
path: root/test/testlib/testing.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/testlib/testing.py')
-rw-r--r--test/testlib/testing.py41
1 files changed, 16 insertions, 25 deletions
diff --git a/test/testlib/testing.py b/test/testlib/testing.py
index 213772e9e..9b92f46d5 100644
--- a/test/testlib/testing.py
+++ b/test/testlib/testing.py
@@ -2,11 +2,12 @@
# monkeypatches unittest.TestLoader.suiteClass at import time
+import testbase
import unittest, re, sys, os
from cStringIO import StringIO
-from sqlalchemy import MetaData, sql
-from sqlalchemy.orm import clear_mappers
import testlib.config as config
+sql, Metadata, clear_mappers = None, None, None
+
__all__ = 'PersistTest', 'AssertMixin', 'ORMTest'
@@ -70,6 +71,10 @@ class ExecutionContextWrapper(object):
can be tracked."""
def __init__(self, ctx):
+ global sql
+ if sql is None:
+ from sqlalchemy import sql
+
self.__dict__['ctx'] = ctx
def __getattr__(self, key):
return getattr(self.ctx, key)
@@ -228,7 +233,11 @@ class ORMTest(AssertMixin):
keep_data = False
def setUpAll(self):
- global _otest_metadata
+ global MetaData, _otest_metadata
+
+ if MetaData is None:
+ from sqlalchemy import MetaData
+
_otest_metadata = MetaData(config.db)
self.define_tables(_otest_metadata)
_otest_metadata.create_all()
@@ -248,6 +257,9 @@ class ORMTest(AssertMixin):
_otest_metadata.drop_all()
def tearDown(self):
+ if clear_mappers is None:
+ from sqlalchemy.orm import clear_mappers
+
if not self.keep_mappers:
clear_mappers()
if not self.keep_data:
@@ -305,30 +317,9 @@ class TTestSuite(unittest.TestSuite):
return (exctype, excvalue, tb)
return (exctype, excvalue, tb)
+# monkeypatch
unittest.TestLoader.suiteClass = TTestSuite
-def _iter_covered_files():
- import sqlalchemy
- for rec in os.walk(os.path.dirname(sqlalchemy.__file__)):
- for x in rec[2]:
- if x.endswith('.py'):
- yield os.path.join(rec[0], x)
-
-def cover(callable_, file_=None):
- from testlib import coverage
- coverage_client = coverage.the_coverage
- coverage_client.get_ready()
- coverage_client.exclude('#pragma[: ]+[nN][oO] [cC][oO][vV][eE][rR]')
- coverage_client.erase()
- coverage_client.start()
- try:
- return callable_()
- finally:
- coverage_client.stop()
- coverage_client.save()
- coverage_client.report(list(_iter_covered_files()),
- show_missing=False, ignore_errors=False,
- file=file_)
class DevNullWriter(object):
def write(self, msg):