summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xbin/pytest52
-rw-r--r--testlib.py25
2 files changed, 57 insertions, 20 deletions
diff --git a/bin/pytest b/bin/pytest
index 6038cf6..32fd1bf 100755
--- a/bin/pytest
+++ b/bin/pytest
@@ -1,8 +1,19 @@
-#!/usr/bin/python
+#!/usr/bin/env python
import os, sys
import os.path as osp
from logilab.common import testlib
+import doctest
+import unittest
+
+# monkeypatch unittest and doctest (ouch !)
+unittest.TestCase = testlib.TestCase
+unittest.main = testlib.unittest_main
+unittest._TextTestResult = testlib.SkipAwareTestResult
+unittest.TextTestRunner = testlib.SkipAwareTextTestRunner
+unittest.TestLoader = testlib.NonStrictTestLoader
+unittest.TestProgram = testlib.SkipAwareTestProgram
+doctest.DocTestCase.__bases__ = (testlib.TestCase,)
def autopath(projdir=os.getcwd()):
@@ -16,12 +27,35 @@ def autopath(projdir=os.getcwd()):
else:
sys.path.insert(0, curdir)
-autopath()
+def testfile(filename):
+ sys.argv.remove(filename)
+ here = os.getcwd()
+ dirname = osp.dirname(filename)
+ if dirname:
+ os.chdir(dirname)
+ sys.path.insert(0, '')
+ modname = osp.basename(filename)[:-3]
+ testlib.unittest_main(modname)
+
+def testall():
+ errcode = 0
+ for dirname, dirs, files in os.walk(os.getcwd()):
+ for skipped in ('CVS', '.svn', '.hg'):
+ if skipped in dirs:
+ dirs.remove(skipped)
+ basename = osp.basename(dirname)
+ if basename in ('test', 'tests'):
+ errcode += testlib.main(dirname, exitafter=False)
+ return errcode
+
+if __name__ == '__main__':
+ autopath()
+ filenames = [arg for arg in sys.argv[1:] if arg.endswith('.py')]
+ if filenames:
+ if len(filenames) > 1:
+ print "Usage: pytest [filename]"
+ sys.exit(-1)
+ # testfile will exit directly
+ testfile(filenames[0])
+ sys.exit(testall())
-for dirname, dirs, files in os.walk(os.getcwd()):
- for skipped in ('CVS', '.svn', '.hg'):
- if skipped in dirs:
- dirs.remove(skipped)
- basename = osp.basename(dirname)
- if basename in ('test', 'tests'):
- testlib.main(dirname, exitafter=False)
diff --git a/testlib.py b/testlib.py
index fdc3891..dcfba35 100644
--- a/testlib.py
+++ b/testlib.py
@@ -160,6 +160,7 @@ def main(testdir=None, exitafter=True):
sys.exit(len(bad) + len(skipped))
else:
sys.path.pop(0)
+ return len(bad)
def run_tests(tests, quiet, verbose, runner=None, capture=0):
""" execute a list of tests
@@ -310,7 +311,8 @@ class SkipAwareTestResult(unittest._TextTestResult):
def __init__(self, stream, descriptions, verbosity,
exitfirst=False, capture=0):
- unittest._TextTestResult.__init__(self, stream, descriptions, verbosity)
+ super(SkipAwareTestResult, self).__init__(stream,
+ descriptions, verbosity)
self.skipped = []
self.debuggers = []
self.descrs = []
@@ -328,13 +330,13 @@ class SkipAwareTestResult(unittest._TextTestResult):
else:
if self.exitfirst:
self.shouldStop = True
- unittest._TextTestResult.addError(self, test, err)
+ super(SkipAwareTestResult, self).addError(test, err)
self._create_pdb(self.getDescription(test))
def addFailure(self, test, err):
if self.exitfirst:
self.shouldStop = True
- unittest._TextTestResult.addFailure(self, test, err)
+ super(SkipAwareTestResult, self).addError(test, err)
self._create_pdb(self.getDescription(test))
def addSkipped(self, test, reason):
@@ -345,7 +347,7 @@ class SkipAwareTestResult(unittest._TextTestResult):
self.stream.write('S')
def printErrors(self):
- unittest._TextTestResult.printErrors(self)
+ super(SkipAwareTestResult, self).printErrors()
self.printSkippedList()
def printSkippedList(self):
@@ -382,7 +384,8 @@ class SkipAwareTextTestRunner(unittest.TextTestRunner):
def __init__(self, stream=sys.stderr, verbosity=1,
exitfirst=False, capture=False):
- unittest.TextTestRunner.__init__(self, stream=stream, verbosity=verbosity)
+ super(SkipAwareTextTestRunner, self).__init__(stream=stream,
+ verbosity=verbosity)
self.exitfirst = exitfirst
self.capture = capture
@@ -449,7 +452,7 @@ class NonStrictTestLoader(unittest.TestLoader):
parts = name.split('.')
if module is None or len(parts) > 2:
# let the base class do its job here
- return [unittest.TestLoader.loadTestsFromName(self, name)]
+ return [super(NonStrictTestLoader, self).loadTestsFromName(name)]
tests = self._collect_tests(module)
# import pprint
# pprint.pprint(tests)
@@ -509,8 +512,8 @@ Examples:
in MyTestCase
"""
def __init__(self, module='__main__'):
- unittest.TestProgram.__init__(self, module=module,
- testLoader=NonStrictTestLoader())
+ super(SkipAwareTestProgram, self).__init__(
+ module=module, testLoader=NonStrictTestLoader())
def parseArgs(self, argv):
@@ -635,10 +638,10 @@ def capture_stderr():
return _capture('stderr')
-def unittest_main():
+def unittest_main(module='__main__'):
"""use this functon if you want to have the same functionality
as unittest.main"""
- SkipAwareTestProgram()
+ SkipAwareTestProgram(module)
class TestSkipped(Exception):
"""raised when a test is skipped"""
@@ -676,7 +679,7 @@ class TestCase(unittest.TestCase):
"""unittest.TestCase with some additional methods"""
def __init__(self, methodName='runTest'):
- unittest.TestCase.__init__(self, methodName)
+ super(TestCase, self).__init__(methodName)
# internal API changed in python2.5
if sys.version_info >= (2, 5):
self.__exc_info = self._exc_info