summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAdrien Di Mascio <Adrien.DiMascio@logilab.fr>2007-03-09 09:46:25 +0100
committerAdrien Di Mascio <Adrien.DiMascio@logilab.fr>2007-03-09 09:46:25 +0100
commit012497eab72481df3f9359ca6494c72f1834359a (patch)
treed22cd9d29d16cfd87ba6d72f893e810538b9ac99
parentb7e21e539ca02132aac5af459e4ade5d54836ec8 (diff)
downloadlogilab-common-012497eab72481df3f9359ca6494c72f1834359a.tar.gz
this should be enough to consider test suites in unittest_main
-rw-r--r--pytest.py2
-rw-r--r--testlib.py21
2 files changed, 18 insertions, 5 deletions
diff --git a/pytest.py b/pytest.py
index def2d4f..6457b77 100644
--- a/pytest.py
+++ b/pytest.py
@@ -131,7 +131,7 @@ def testfile(filename, batchmode=False):
modname = osp.basename(filename)[:-3]
try:
tstart, cstart = time(), clock()
- testprog = testlib.unittest_main(modname, batchmode)
+ testprog = testlib.unittest_main(modname, batchmode=batchmode)
tend, cend = time(), clock()
return testprog, (tend - tstart), (cend - cstart)
finally:
diff --git a/testlib.py b/testlib.py
index 9d405de..a835889 100644
--- a/testlib.py
+++ b/testlib.py
@@ -450,6 +450,15 @@ class NonStrictTestLoader(unittest.TestLoader):
# keep track of class (obj) for convenience
tests[classname] = (obj, methodnames)
return tests
+
+ def loadTestsFromSuite(self, module, suitename):
+ try:
+ suite = getattr(module, suitename)()
+ except AttributeError:
+ print "No such suite", suitename
+ return []
+ return suite
+
def loadTestsFromName(self, name, module=None):
parts = name.split('.')
@@ -462,6 +471,9 @@ class NonStrictTestLoader(unittest.TestLoader):
collected = []
if len(parts) == 1:
pattern = parts[0]
+ if callable(getattr(module, pattern, None)) and pattern not in tests:
+ # consider it as a suite
+ return self.loadTestsFromSuite(module, pattern)
if pattern in tests:
# case python unittest_foo.py MyTestTC
klass, methodnames = tests[pattern]
@@ -514,10 +526,11 @@ Examples:
%(progName)s MyTestCase - run all 'test*' test methods
in MyTestCase
"""
- def __init__(self, module='__main__', batchmode=False):
+ def __init__(self, module='__main__', defaultTest=None, batchmode=False):
self.batchmode = batchmode
super(SkipAwareTestProgram, self).__init__(
- module=module, testLoader=NonStrictTestLoader())
+ module=module, defaultTest=defaultTest,
+ testLoader=NonStrictTestLoader())
def parseArgs(self, argv):
@@ -644,10 +657,10 @@ def capture_stderr():
return _capture('stderr')
-def unittest_main(module='__main__', batchmode=False):
+def unittest_main(module='__main__', defaultTest=None, batchmode=False):
"""use this functon if you want to have the same functionality
as unittest.main"""
- return SkipAwareTestProgram(module, batchmode)
+ return SkipAwareTestProgram(module, defaultTest, batchmode)
class TestSkipped(Exception):
"""raised when a test is skipped"""