summaryrefslogtreecommitdiff
path: root/Lib/unittest
diff options
context:
space:
mode:
authorWalter Doerwald <walter@livinglogic.de>2013-12-02 11:43:20 +0100
committerWalter Doerwald <walter@livinglogic.de>2013-12-02 11:43:20 +0100
commit1a7a32b4875516d530b78f524e83aefac9cad1ab (patch)
tree49024aa50be776f7233b7a7b3083549f258822dd /Lib/unittest
parent97261caa90b6e16b6866de47febd059f3b0cd4a1 (diff)
parentd2fb1d098e418d099d7a2b7893a1201f606122fe (diff)
downloadcpython-1a7a32b4875516d530b78f524e83aefac9cad1ab.tar.gz
Fix #19834: merge with 3.3.
Diffstat (limited to 'Lib/unittest')
-rw-r--r--Lib/unittest/__main__.py3
-rw-r--r--Lib/unittest/case.py379
-rw-r--r--Lib/unittest/loader.py76
-rw-r--r--Lib/unittest/main.py247
-rw-r--r--Lib/unittest/mock.py229
-rw-r--r--Lib/unittest/result.py16
-rw-r--r--Lib/unittest/suite.py19
-rw-r--r--Lib/unittest/test/__main__.py18
-rw-r--r--Lib/unittest/test/support.py38
-rw-r--r--Lib/unittest/test/test_assertions.py4
-rw-r--r--Lib/unittest/test/test_break.py4
-rw-r--r--Lib/unittest/test/test_case.py231
-rw-r--r--Lib/unittest/test/test_discovery.py169
-rw-r--r--Lib/unittest/test/test_functiontestcase.py4
-rw-r--r--Lib/unittest/test/test_loader.py4
-rw-r--r--Lib/unittest/test/test_program.py72
-rw-r--r--Lib/unittest/test/test_result.py92
-rw-r--r--Lib/unittest/test/test_runner.py16
-rw-r--r--Lib/unittest/test/test_setups.py7
-rw-r--r--Lib/unittest/test/test_skipping.py78
-rw-r--r--Lib/unittest/test/test_suite.py50
-rw-r--r--Lib/unittest/test/testmock/__main__.py18
-rw-r--r--Lib/unittest/test/testmock/testcallable.py4
-rw-r--r--Lib/unittest/test/testmock/testhelpers.py59
-rw-r--r--Lib/unittest/test/testmock/testmock.py143
-rw-r--r--Lib/unittest/test/testmock/testpatch.py1
-rw-r--r--Lib/unittest/test/testmock/testwith.py83
-rw-r--r--Lib/unittest/util.py37
28 files changed, 1671 insertions, 430 deletions
diff --git a/Lib/unittest/__main__.py b/Lib/unittest/__main__.py
index 798ebc0f53..2663178d3f 100644
--- a/Lib/unittest/__main__.py
+++ b/Lib/unittest/__main__.py
@@ -13,7 +13,6 @@ if sys.argv[0].endswith("__main__.py"):
__unittest = True
-from .main import main, TestProgram, USAGE_AS_MAIN
-TestProgram.USAGE = USAGE_AS_MAIN
+from .main import main, TestProgram
main(module=None)
diff --git a/Lib/unittest/case.py b/Lib/unittest/case.py
index f56af5574b..7ed932fafd 100644
--- a/Lib/unittest/case.py
+++ b/Lib/unittest/case.py
@@ -3,14 +3,16 @@
import sys
import functools
import difflib
+import logging
import pprint
import re
import warnings
import collections
+import contextlib
from . import result
from .util import (strclass, safe_repr, _count_diff_all_purpose,
- _count_diff_hashable)
+ _count_diff_hashable, _common_shorten_repr)
__unittest = True
@@ -26,17 +28,11 @@ class SkipTest(Exception):
instead of raising this directly.
"""
-class _ExpectedFailure(Exception):
+class _ShouldStop(Exception):
"""
- Raise this when a test is expected to fail.
-
- This is an implementation detail.
+ The test should stop.
"""
- def __init__(self, exc_info):
- super(_ExpectedFailure, self).__init__()
- self.exc_info = exc_info
-
class _UnexpectedSuccess(Exception):
"""
The test was supposed to fail, but it didn't!
@@ -44,13 +40,40 @@ class _UnexpectedSuccess(Exception):
class _Outcome(object):
- def __init__(self):
+ def __init__(self, result=None):
+ self.expecting_failure = False
+ self.result = result
+ self.result_supports_subtests = hasattr(result, "addSubTest")
self.success = True
- self.skipped = None
- self.unexpectedSuccess = None
+ self.skipped = []
self.expectedFailure = None
self.errors = []
- self.failures = []
+
+ @contextlib.contextmanager
+ def testPartExecutor(self, test_case, isTest=False):
+ old_success = self.success
+ self.success = True
+ try:
+ yield
+ except KeyboardInterrupt:
+ raise
+ except SkipTest as e:
+ self.success = False
+ self.skipped.append((test_case, str(e)))
+ except _ShouldStop:
+ pass
+ except:
+ exc_info = sys.exc_info()
+ if self.expecting_failure:
+ self.expectedFailure = exc_info
+ else:
+ self.success = False
+ self.errors.append((test_case, exc_info))
+ else:
+ if self.result_supports_subtests and self.success:
+ self.errors.append((test_case, None))
+ finally:
+ self.success = self.success and old_success
def _id(obj):
@@ -88,22 +111,26 @@ def skipUnless(condition, reason):
return skip(reason)
return _id
+def expectedFailure(test_item):
+ test_item.__unittest_expecting_failure__ = True
+ return test_item
-def expectedFailure(func):
- @functools.wraps(func)
- def wrapper(*args, **kwargs):
- try:
- func(*args, **kwargs)
- except Exception:
- raise _ExpectedFailure(sys.exc_info())
- raise _UnexpectedSuccess
- return wrapper
+class _BaseTestCaseContext:
-class _AssertRaisesBaseContext(object):
+ def __init__(self, test_case):
+ self.test_case = test_case
+
+ def _raiseFailure(self, standardMsg):
+ msg = self.test_case._formatMessage(self.msg, standardMsg)
+ raise self.test_case.failureException(msg)
+
+
+class _AssertRaisesBaseContext(_BaseTestCaseContext):
def __init__(self, expected, test_case, callable_obj=None,
expected_regex=None):
+ _BaseTestCaseContext.__init__(self, test_case)
self.expected = expected
self.test_case = test_case
if callable_obj is not None:
@@ -118,10 +145,6 @@ class _AssertRaisesBaseContext(object):
self.expected_regex = expected_regex
self.msg = None
- def _raiseFailure(self, standardMsg):
- msg = self.test_case._formatMessage(self.msg, standardMsg)
- raise self.test_case.failureException(msg)
-
def handle(self, name, callable_obj, args, kwargs):
"""
If callable_obj is None, assertRaises/Warns is being used as a
@@ -135,7 +158,6 @@ class _AssertRaisesBaseContext(object):
callable_obj(*args, **kwargs)
-
class _AssertRaisesContext(_AssertRaisesBaseContext):
"""A context manager used to implement TestCase.assertRaises* methods."""
@@ -217,6 +239,74 @@ class _AssertWarnsContext(_AssertRaisesBaseContext):
self._raiseFailure("{} not triggered".format(exc_name))
+
+_LoggingWatcher = collections.namedtuple("_LoggingWatcher",
+ ["records", "output"])
+
+
+class _CapturingHandler(logging.Handler):
+ """
+ A logging handler capturing all (raw and formatted) logging output.
+ """
+
+ def __init__(self):
+ logging.Handler.__init__(self)
+ self.watcher = _LoggingWatcher([], [])
+
+ def flush(self):
+ pass
+
+ def emit(self, record):
+ self.watcher.records.append(record)
+ msg = self.format(record)
+ self.watcher.output.append(msg)
+
+
+
+class _AssertLogsContext(_BaseTestCaseContext):
+ """A context manager used to implement TestCase.assertLogs()."""
+
+ LOGGING_FORMAT = "%(levelname)s:%(name)s:%(message)s"
+
+ def __init__(self, test_case, logger_name, level):
+ _BaseTestCaseContext.__init__(self, test_case)
+ self.logger_name = logger_name
+ if level:
+ self.level = logging._nameToLevel.get(level, level)
+ else:
+ self.level = logging.INFO
+ self.msg = None
+
+ def __enter__(self):
+ if isinstance(self.logger_name, logging.Logger):
+ logger = self.logger = self.logger_name
+ else:
+ logger = self.logger = logging.getLogger(self.logger_name)
+ formatter = logging.Formatter(self.LOGGING_FORMAT)
+ handler = _CapturingHandler()
+ handler.setFormatter(formatter)
+ self.watcher = handler.watcher
+ self.old_handlers = logger.handlers[:]
+ self.old_level = logger.level
+ self.old_propagate = logger.propagate
+ logger.handlers = [handler]
+ logger.setLevel(self.level)
+ logger.propagate = False
+ return handler.watcher
+
+ def __exit__(self, exc_type, exc_value, tb):
+ self.logger.handlers = self.old_handlers
+ self.logger.propagate = self.old_propagate
+ self.logger.setLevel(self.old_level)
+ if exc_type is not None:
+ # let unexpected exceptions pass through
+ return False
+ if len(self.watcher.records) == 0:
+ self._raiseFailure(
+ "no logs of level {} or higher triggered on {}"
+ .format(logging.getLevelName(self.level), self.logger.name))
+
+
class TestCase(object):
"""A class whose instances are single test cases.
@@ -270,7 +360,7 @@ class TestCase(object):
not have a method with the specified name.
"""
self._testMethodName = methodName
- self._outcomeForDoCleanups = None
+ self._outcome = None
self._testMethodDoc = 'No test'
try:
testMethod = getattr(self, methodName)
@@ -283,6 +373,7 @@ class TestCase(object):
else:
self._testMethodDoc = testMethod.__doc__
self._cleanups = []
+ self._subtest = None
# Map types to custom assertEqual functions that will compare
# instances of said type in more detail to generate a more useful
@@ -370,44 +461,80 @@ class TestCase(object):
return "<%s testMethod=%s>" % \
(strclass(self.__class__), self._testMethodName)
- def _addSkip(self, result, reason):
+ def _addSkip(self, result, test_case, reason):
addSkip = getattr(result, 'addSkip', None)
if addSkip is not None:
- addSkip(self, reason)
+ addSkip(test_case, reason)
else:
warnings.warn("TestResult has no addSkip method, skips not reported",
RuntimeWarning, 2)
+ result.addSuccess(test_case)
+
+ @contextlib.contextmanager
+ def subTest(self, msg=None, **params):
+ """Return a context manager that will return the enclosed block
+ of code in a subtest identified by the optional message and
+ keyword parameters. A failure in the subtest marks the test
+ case as failed but resumes execution at the end of the enclosed
+ block, allowing further test code to be executed.
+ """
+ if not self._outcome.result_supports_subtests:
+ yield
+ return
+ parent = self._subtest
+ if parent is None:
+ params_map = collections.ChainMap(params)
+ else:
+ params_map = parent.params.new_child(params)
+ self._subtest = _SubTest(self, msg, params_map)
+ try:
+ with self._outcome.testPartExecutor(self._subtest, isTest=True):
+ yield
+ if not self._outcome.success:
+ result = self._outcome.result
+ if result is not None and result.failfast:
+ raise _ShouldStop
+ elif self._outcome.expectedFailure:
+ # If the test is expecting a failure, we really want to
+ # stop now and register the expected failure.
+ raise _ShouldStop
+ finally:
+ self._subtest = parent
+
+ def _feedErrorsToResult(self, result, errors):
+ for test, exc_info in errors:
+ if isinstance(test, _SubTest):
+ result.addSubTest(test.test_case, test, exc_info)
+ elif exc_info is not None:
+ if issubclass(exc_info[0], self.failureException):
+ result.addFailure(test, exc_info)
+ else:
+ result.addError(test, exc_info)
+
+ def _addExpectedFailure(self, result, exc_info):
+ try:
+ addExpectedFailure = result.addExpectedFailure
+ except AttributeError:
+ warnings.warn("TestResult has no addExpectedFailure method, reporting as passes",
+ RuntimeWarning)
result.addSuccess(self)
+ else:
+ addExpectedFailure(self, exc_info)
- def _executeTestPart(self, function, outcome, isTest=False):
+ def _addUnexpectedSuccess(self, result):
try:
- function()
- except KeyboardInterrupt:
- raise
- except SkipTest as e:
- outcome.success = False
- outcome.skipped = str(e)
- except _UnexpectedSuccess:
- exc_info = sys.exc_info()
- outcome.success = False
- if isTest:
- outcome.unexpectedSuccess = exc_info
- else:
- outcome.errors.append(exc_info)
- except _ExpectedFailure:
- outcome.success = False
- exc_info = sys.exc_info()
- if isTest:
- outcome.expectedFailure = exc_info
- else:
- outcome.errors.append(exc_info)
- except self.failureException:
- outcome.success = False
- outcome.failures.append(sys.exc_info())
- exc_info = sys.exc_info()
- except:
- outcome.success = False
- outcome.errors.append(sys.exc_info())
+ addUnexpectedSuccess = result.addUnexpectedSuccess
+ except AttributeError:
+ warnings.warn("TestResult has no addUnexpectedSuccess method, reporting as failure",
+ RuntimeWarning)
+ # We need to pass an actual exception and traceback to addFailure,
+ # otherwise the legacy result can choke.
+ try:
+ raise _UnexpectedSuccess from None
+ except _UnexpectedSuccess:
+ result.addFailure(self, sys.exc_info())
+ else:
+ addUnexpectedSuccess(self)
def run(self, result=None):
orig_result = result
@@ -426,46 +553,38 @@ class TestCase(object):
try:
skip_why = (getattr(self.__class__, '__unittest_skip_why__', '')
or getattr(testMethod, '__unittest_skip_why__', ''))
- self._addSkip(result, skip_why)
+ self._addSkip(result, self, skip_why)
finally:
result.stopTest(self)
return
+ expecting_failure = getattr(testMethod,
+ "__unittest_expecting_failure__", False)
try:
- outcome = _Outcome()
- self._outcomeForDoCleanups = outcome
+ outcome = _Outcome(result)
+ self._outcome = outcome
- self._executeTestPart(self.setUp, outcome)
+ with outcome.testPartExecutor(self):
+ self.setUp()
if outcome.success:
- self._executeTestPart(testMethod, outcome, isTest=True)
- self._executeTestPart(self.tearDown, outcome)
+ outcome.expecting_failure = expecting_failure
+ with outcome.testPartExecutor(self, isTest=True):
+ testMethod()
+ outcome.expecting_failure = False
+ with outcome.testPartExecutor(self):
+ self.tearDown()
self.doCleanups()
+ for test, reason in outcome.skipped:
+ self._addSkip(result, test, reason)
+ self._feedErrorsToResult(result, outcome.errors)
if outcome.success:
- result.addSuccess(self)
- else:
- if outcome.skipped is not None:
- self._addSkip(result, outcome.skipped)
- for exc_info in outcome.errors:
- result.addError(self, exc_info)
- for exc_info in outcome.failures:
- result.addFailure(self, exc_info)
- if outcome.unexpectedSuccess is not None:
- addUnexpectedSuccess = getattr(result, 'addUnexpectedSuccess', None)
- if addUnexpectedSuccess is not None:
- addUnexpectedSuccess(self)
+ if expecting_failure:
+ if outcome.expectedFailure:
+ self._addExpectedFailure(result, outcome.expectedFailure)
else:
- warnings.warn("TestResult has no addUnexpectedSuccess method, reporting as failures",
- RuntimeWarning)
- result.addFailure(self, outcome.unexpectedSuccess)
-
- if outcome.expectedFailure is not None:
- addExpectedFailure = getattr(result, 'addExpectedFailure', None)
- if addExpectedFailure is not None:
- addExpectedFailure(self, outcome.expectedFailure)
- else:
- warnings.warn("TestResult has no addExpectedFailure method, reporting as passes",
- RuntimeWarning)
- result.addSuccess(self)
+ self._addUnexpectedSuccess(result)
+ else:
+ result.addSuccess(self)
return result
finally:
result.stopTest(self)
@@ -477,11 +596,11 @@ class TestCase(object):
def doCleanups(self):
"""Execute all cleanup functions. Normally called for you after
tearDown."""
- outcome = self._outcomeForDoCleanups or _Outcome()
+ outcome = self._outcome or _Outcome()
while self._cleanups:
function, args, kwargs = self._cleanups.pop()
- part = lambda: function(*args, **kwargs)
- self._executeTestPart(part, outcome)
+ with outcome.testPartExecutor(self):
+ function(*args, **kwargs)
# return this for backwards compatibility
# even though we no longer us it internally
@@ -600,6 +719,28 @@ class TestCase(object):
context = _AssertWarnsContext(expected_warning, self, callable_obj)
return context.handle('assertWarns', callable_obj, args, kwargs)
+ def assertLogs(self, logger=None, level=None):
+ """Fail unless a log message of level *level* or higher is emitted
+ on *logger_name* or its children. If omitted, *level* defaults to
+ INFO and *logger* defaults to the root logger.
+
+ This method must be used as a context manager, and will yield
+ a recording object with two attributes: `output` and `records`.
+ At the end of the context manager, the `output` attribute will
+ be a list of the matching formatted log messages and the
+ `records` attribute will be a list of the corresponding LogRecord
+ objects.
+
+ Example::
+
+ with self.assertLogs('foo', level='INFO') as cm:
+ logging.getLogger('foo').info('first message')
+ logging.getLogger('foo.bar').error('second message')
+ self.assertEqual(cm.output, ['INFO:foo:first message',
+ 'ERROR:foo.bar:second message'])
+ """
+ return _AssertLogsContext(self, logger, level)
+
def _getAssertEqualityFunc(self, first, second):
"""Get a detailed comparison function for the types of the two args.
@@ -629,7 +770,7 @@ class TestCase(object):
def _baseAssertEqual(self, first, second, msg=None):
"""The default assertEqual implementation, not type specific."""
if not first == second:
- standardMsg = '%s != %s' % (safe_repr(first), safe_repr(second))
+ standardMsg = '%s != %s' % _common_shorten_repr(first, second)
msg = self._formatMessage(msg, standardMsg)
raise self.failureException(msg)
@@ -764,14 +905,9 @@ class TestCase(object):
if seq1 == seq2:
return
- seq1_repr = safe_repr(seq1)
- seq2_repr = safe_repr(seq2)
- if len(seq1_repr) > 30:
- seq1_repr = seq1_repr[:30] + '...'
- if len(seq2_repr) > 30:
- seq2_repr = seq2_repr[:30] + '...'
- elements = (seq_type_name.capitalize(), seq1_repr, seq2_repr)
- differing = '%ss differ: %s != %s\n' % elements
+ differing = '%ss differ: %s != %s\n' % (
+ (seq_type_name.capitalize(),) +
+ _common_shorten_repr(seq1, seq2))
for i in range(min(len1, len2)):
try:
@@ -929,7 +1065,7 @@ class TestCase(object):
self.assertIsInstance(d2, dict, 'Second argument is not a dictionary')
if d1 != d2:
- standardMsg = '%s != %s' % (safe_repr(d1, True), safe_repr(d2, True))
+ standardMsg = '%s != %s' % _common_shorten_repr(d1, d2)
diff = ('\n' + '\n'.join(difflib.ndiff(
pprint.pformat(d1).splitlines(),
pprint.pformat(d2).splitlines())))
@@ -1013,8 +1149,7 @@ class TestCase(object):
if len(firstlines) == 1 and first.strip('\r\n') == first:
firstlines = [first + '\n']
secondlines = [second + '\n']
- standardMsg = '%s != %s' % (safe_repr(first, True),
- safe_repr(second, True))
+ standardMsg = '%s != %s' % _common_shorten_repr(first, second)
diff = '\n' + ''.join(difflib.ndiff(firstlines, secondlines))
standardMsg = self._truncateMessage(standardMsg, diff)
self.fail(self._formatMessage(msg, standardMsg))
@@ -1212,3 +1347,39 @@ class FunctionTestCase(TestCase):
return self._description
doc = self._testFunc.__doc__
return doc and doc.split("\n")[0].strip() or None
+
+
+class _SubTest(TestCase):
+
+ def __init__(self, test_case, message, params):
+ super().__init__()
+ self._message = message
+ self.test_case = test_case
+ self.params = params
+ self.failureException = test_case.failureException
+
+ def runTest(self):
+ raise NotImplementedError("subtests cannot be run directly")
+
+ def _subDescription(self):
+ parts = []
+ if self._message:
+ parts.append("[{}]".format(self._message))
+ if self.params:
+ params_desc = ', '.join(
+ "{}={!r}".format(k, v)
+ for (k, v) in sorted(self.params.items()))
+ parts.append("({})".format(params_desc))
+ return " ".join(parts) or '(<subtest>)'
+
+ def id(self):
+ return "{} {}".format(self.test_case.id(), self._subDescription())
+
+ def shortDescription(self):
+ """Returns a one-line description of the subtest, or None if no
+ description has been provided.
+ """
+ return self.test_case.shortDescription()
+
+ def __str__(self):
+ return "{} {}".format(self.test_case, self._subDescription())
diff --git a/Lib/unittest/loader.py b/Lib/unittest/loader.py
index 9ab26c1efb..808c50eb66 100644
--- a/Lib/unittest/loader.py
+++ b/Lib/unittest/loader.py
@@ -34,6 +34,14 @@ def _make_failed_test(classname, methodname, exception, suiteClass):
TestClass = type(classname, (case.TestCase,), attrs)
return suiteClass((TestClass(methodname),))
+def _make_skipped_test(methodname, exception, suiteClass):
+ @case.skip(str(exception))
+ def testSkipped(self):
+ pass
+ attrs = {methodname: testSkipped}
+ TestClass = type("ModuleSkipped", (case.TestCase,), attrs)
+ return suiteClass((TestClass(methodname),))
+
def _jython_aware_splitext(path):
if path.lower().endswith('$py.class'):
return path[:-9]
@@ -53,8 +61,9 @@ class TestLoader(object):
def loadTestsFromTestCase(self, testCaseClass):
"""Return a suite of all tests cases contained in testCaseClass"""
if issubclass(testCaseClass, suite.TestSuite):
- raise TypeError("Test cases should not be derived from TestSuite." \
- " Maybe you meant to derive from TestCase?")
+ raise TypeError("Test cases should not be derived from "
+ "TestSuite. Maybe you meant to derive from "
+ "TestCase?")
testCaseNames = self.getTestCaseNames(testCaseClass)
if not testCaseNames and hasattr(testCaseClass, 'runTest'):
testCaseNames = ['runTest']
@@ -169,6 +178,9 @@ class TestLoader(object):
The pattern is deliberately not stored as a loader attribute so that
packages can continue discovery themselves. top_level_dir is stored so
load_tests does not need to pass this argument in to loader.discover().
+
+ Paths are sorted before being imported to ensure reproducible execution
+ order even on filesystems with non-alphabetical ordering like ext3/4.
"""
set_implicit_top = False
if top_level_dir is None and self._top_level_dir is not None:
@@ -189,6 +201,8 @@ class TestLoader(object):
self._top_level_dir = top_level_dir
is_not_importable = False
+ is_namespace = False
+ tests = []
if os.path.isdir(os.path.abspath(start_dir)):
start_dir = os.path.abspath(start_dir)
if start_dir != top_level_dir:
@@ -202,15 +216,52 @@ class TestLoader(object):
else:
the_module = sys.modules[start_dir]
top_part = start_dir.split('.')[0]
- start_dir = os.path.abspath(os.path.dirname((the_module.__file__)))
+ try:
+ start_dir = os.path.abspath(
+ os.path.dirname((the_module.__file__)))
+ except AttributeError:
+ # look for namespace packages
+ try:
+ spec = the_module.__spec__
+ except AttributeError:
+ spec = None
+
+ if spec and spec.loader is None:
+ if spec.submodule_search_locations is not None:
+ is_namespace = True
+
+ for path in the_module.__path__:
+ if (not set_implicit_top and
+ not path.startswith(top_level_dir)):
+ continue
+ self._top_level_dir = \
+ (path.split(the_module.__name__
+ .replace(".", os.path.sep))[0])
+ tests.extend(self._find_tests(path,
+ pattern,
+ namespace=True))
+ elif the_module.__name__ in sys.builtin_module_names:
+ # builtin module
+ raise TypeError('Can not use builtin modules '
+ 'as dotted module names') from None
+ else:
+ raise TypeError(
+ 'don\'t know how to discover from {!r}'
+ .format(the_module)) from None
+
if set_implicit_top:
- self._top_level_dir = self._get_directory_containing_module(top_part)
- sys.path.remove(top_level_dir)
+ if not is_namespace:
+ self._top_level_dir = \
+ self._get_directory_containing_module(top_part)
+ sys.path.remove(top_level_dir)
+ else:
+ sys.path.remove(top_level_dir)
if is_not_importable:
raise ImportError('Start directory is not importable: %r' % start_dir)
- tests = list(self._find_tests(start_dir, pattern))
+ if not is_namespace:
+ tests = list(self._find_tests(start_dir, pattern))
return self.suiteClass(tests)
def _get_directory_containing_module(self, module_name):
@@ -243,9 +294,9 @@ class TestLoader(object):
# override this method to use alternative matching strategy
return fnmatch(path, pattern)
- def _find_tests(self, start_dir, pattern):
+ def _find_tests(self, start_dir, pattern, namespace=False):
"""Used by discovery. Yields test suites it loads."""
- paths = os.listdir(start_dir)
+ paths = sorted(os.listdir(start_dir))
for path in paths:
full_path = os.path.join(start_dir, path)
@@ -259,6 +310,8 @@ class TestLoader(object):
name = self._get_name_from_path(full_path)
try:
module = self._get_module_from_name(name)
+ except case.SkipTest as e:
+ yield _make_skipped_test(name, e, self.suiteClass)
except:
yield _make_failed_import_test(name, self.suiteClass)
else:
@@ -274,7 +327,8 @@ class TestLoader(object):
raise ImportError(msg % (mod_name, module_dir, expected_dir))
yield self.loadTestsFromModule(module)
elif os.path.isdir(full_path):
- if not os.path.isfile(os.path.join(full_path, '__init__.py')):
+ if (not namespace and
+ not os.path.isfile(os.path.join(full_path, '__init__.py'))):
continue
load_tests = None
@@ -291,8 +345,8 @@ class TestLoader(object):
# tests loaded from package file
yield tests
# recurse into the package
- for test in self._find_tests(full_path, pattern):
- yield test
+ yield from self._find_tests(full_path, pattern,
+ namespace=namespace)
else:
try:
yield load_tests(self, tests, pattern)
diff --git a/Lib/unittest/main.py b/Lib/unittest/main.py
index ead64936a0..180df8676e 100644
--- a/Lib/unittest/main.py
+++ b/Lib/unittest/main.py
@@ -1,7 +1,7 @@
"""Unittest main program"""
import sys
-import optparse
+import argparse
import os
from . import loader, runner
@@ -9,53 +9,20 @@ from .signals import installHandler
__unittest = True
-FAILFAST = " -f, --failfast Stop on first failure\n"
-CATCHBREAK = " -c, --catch Catch control-C and display results\n"
-BUFFEROUTPUT = " -b, --buffer Buffer stdout and stderr during test runs\n"
-
-USAGE_AS_MAIN = """\
-Usage: %(progName)s [options] [tests]
-
-Options:
- -h, --help Show this message
- -v, --verbose Verbose output
- -q, --quiet Minimal output
-%(failfast)s%(catchbreak)s%(buffer)s
+MAIN_EXAMPLES = """\
Examples:
- %(progName)s test_module - run tests from test_module
- %(progName)s module.TestClass - run tests from module.TestClass
- %(progName)s module.Class.test_method - run specified test method
-
-[tests] can be a list of any number of test modules, classes and test
-methods.
-
-Alternative Usage: %(progName)s discover [options]
-
-Options:
- -v, --verbose Verbose output
-%(failfast)s%(catchbreak)s%(buffer)s -s directory Directory to start discovery ('.' default)
- -p pattern Pattern to match test files ('test*.py' default)
- -t directory Top level directory of project (default to
- start directory)
-
-For test discovery all test modules must be importable from the top
-level directory of the project.
+ %(prog)s test_module - run tests from test_module
+ %(prog)s module.TestClass - run tests from module.TestClass
+ %(prog)s module.Class.test_method - run specified test method
"""
-USAGE_FROM_MODULE = """\
-Usage: %(progName)s [options] [test] [...]
-
-Options:
- -h, --help Show this message
- -v, --verbose Verbose output
- -q, --quiet Minimal output
-%(failfast)s%(catchbreak)s%(buffer)s
+MODULE_EXAMPLES = """\
Examples:
- %(progName)s - run default set of tests
- %(progName)s MyTestSuite - run suite 'MyTestSuite'
- %(progName)s MyTestCase.testSomething - run MyTestCase.testSomething
- %(progName)s MyTestCase - run all 'test*' test methods
- in MyTestCase
+ %(prog)s - run default set of tests
+ %(prog)s MyTestSuite - run suite 'MyTestSuite'
+ %(prog)s MyTestCase.testSomething - run MyTestCase.testSomething
+ %(prog)s MyTestCase - run all 'test*' test methods
+ in MyTestCase
"""
def _convert_name(name):
@@ -82,10 +49,11 @@ class TestProgram(object):
"""A command-line program that runs a set of tests; this is primarily
for making test modules conveniently executable.
"""
- USAGE = USAGE_FROM_MODULE
-
# defaults for testing
+ module=None
+ verbosity = 1
failfast = catchbreak = buffer = progName = warnings = None
+ _discovery_parser = None
def __init__(self, module='__main__', defaultTest=None, argv=None,
testRunner=None, testLoader=loader.defaultTestLoader,
@@ -127,44 +95,47 @@ class TestProgram(object):
def usageExit(self, msg=None):
if msg:
print(msg)
- usage = {'progName': self.progName, 'catchbreak': '', 'failfast': '',
- 'buffer': ''}
- if self.failfast != False:
- usage['failfast'] = FAILFAST
- if self.catchbreak != False:
- usage['catchbreak'] = CATCHBREAK
- if self.buffer != False:
- usage['buffer'] = BUFFEROUTPUT
- print(self.USAGE % usage)
+ if self._discovery_parser is None:
+ self._initArgParsers()
+ self._print_help()
sys.exit(2)
- def parseArgs(self, argv):
- if ((len(argv) > 1 and argv[1].lower() == 'discover') or
- (len(argv) == 1 and self.module is None)):
- self._do_discovery(argv[2:])
- return
-
- parser = self._getOptParser()
- options, args = parser.parse_args(argv[1:])
- self._setAttributesFromOptions(options)
+ def _print_help(self, *args, **kwargs):
+ if self.module is None:
+ print(self._main_parser.format_help())
+ print(MAIN_EXAMPLES % {'prog': self.progName})
+ self._discovery_parser.print_help()
+ else:
+ print(self._main_parser.format_help())
+ print(MODULE_EXAMPLES % {'prog': self.progName})
- if len(args) == 0 and self.module is None:
- # this allows "python -m unittest -v" to still work for
- # test discovery. This means -c / -b / -v / -f options will
- # be handled twice, which is harmless but not ideal.
- self._do_discovery(argv[1:])
- return
+ def parseArgs(self, argv):
+ self._initArgParsers()
+ if self.module is None:
+ if len(argv) > 1 and argv[1].lower() == 'discover':
+ self._do_discovery(argv[2:])
+ return
+ self._main_parser.parse_args(argv[1:], self)
+ if not self.tests:
+ # this allows "python -m unittest -v" to still work for
+ # test discovery.
+ self._do_discovery([])
+ return
+ else:
+ self._main_parser.parse_args(argv[1:], self)
- if len(args) == 0 and self.defaultTest is None:
- # createTests will load tests from self.module
- self.testNames = None
- elif len(args) > 0:
- self.testNames = _convert_names(args)
+ if self.tests:
+ self.testNames = _convert_names(self.tests)
if __name__ == '__main__':
# to support python -m unittest ...
self.module = None
- else:
+ elif self.defaultTest is None:
+ # createTests will load tests from self.module
+ self.testNames = None
+ elif isinstance(self.defaultTest, str):
self.testNames = (self.defaultTest,)
+ else:
+ self.testNames = list(self.defaultTest)
self.createTests()
def createTests(self):
@@ -174,76 +145,84 @@ class TestProgram(object):
self.test = self.testLoader.loadTestsFromNames(self.testNames,
self.module)
- def _getOptParser(self):
- import optparse
- parser = optparse.OptionParser()
- parser.prog = self.progName
- parser.add_option('-v', '--verbose', dest='verbose', default=False,
- help='Verbose output', action='store_true')
- parser.add_option('-q', '--quiet', dest='quiet', default=False,
- help='Quiet output', action='store_true')
+ def _initArgParsers(self):
+ parent_parser = self._getParentArgParser()
+ self._main_parser = self._getMainArgParser(parent_parser)
+ self._discovery_parser = self._getDiscoveryArgParser(parent_parser)
- if self.failfast != False:
- parser.add_option('-f', '--failfast', dest='failfast', default=False,
- help='Stop on first fail or error',
- action='store_true')
- if self.catchbreak != False:
- parser.add_option('-c', '--catch', dest='catchbreak', default=False,
- help='Catch ctrl-C and display results so far',
- action='store_true')
- if self.buffer != False:
- parser.add_option('-b', '--buffer', dest='buffer', default=False,
- help='Buffer stdout and stderr during tests',
- action='store_true')
- return parser
+ def _getParentArgParser(self):
+ parser = argparse.ArgumentParser(add_help=False)
+
+ parser.add_argument('-v', '--verbose', dest='verbosity',
+ action='store_const', const=2,
+ help='Verbose output')
+ parser.add_argument('-q', '--quiet', dest='verbosity',
+ action='store_const', const=0,
+ help='Quiet output')
- def _setAttributesFromOptions(self, options):
- # only set options from the parsing here
- # if they weren't set explicitly in the constructor
if self.failfast is None:
- self.failfast = options.failfast
+ parser.add_argument('-f', '--failfast', dest='failfast',
+ action='store_true',
+ help='Stop on first fail or error')
+ self.failfast = False
if self.catchbreak is None:
- self.catchbreak = options.catchbreak
+ parser.add_argument('-c', '--catch', dest='catchbreak',
+ action='store_true',
+ help='Catch ctrl-C and display results so far')
+ self.catchbreak = False
if self.buffer is None:
- self.buffer = options.buffer
-
- if options.verbose:
- self.verbosity = 2
- elif options.quiet:
- self.verbosity = 0
+ parser.add_argument('-b', '--buffer', dest='buffer',
+ action='store_true',
+ help='Buffer stdout and stderr during tests')
+ self.buffer = False
- def _addDiscoveryOptions(self, parser):
- parser.add_option('-s', '--start-directory', dest='start', default='.',
- help="Directory to start discovery ('.' default)")
- parser.add_option('-p', '--pattern', dest='pattern', default='test*.py',
- help="Pattern to match tests ('test*.py' default)")
- parser.add_option('-t', '--top-level-directory', dest='top', default=None,
- help='Top level directory of project (defaults to start directory)')
-
- def _do_discovery(self, argv, Loader=None):
- if Loader is None:
- Loader = lambda: self.testLoader
+ return parser
- # handle command line args for test discovery
- self.progName = '%s discover' % self.progName
- parser = self._getOptParser()
- self._addDiscoveryOptions(parser)
+ def _getMainArgParser(self, parent):
+ parser = argparse.ArgumentParser(parents=[parent])
+ parser.prog = self.progName
+ parser.print_help = self._print_help
- options, args = parser.parse_args(argv)
- if len(args) > 3:
- self.usageExit()
+ parser.add_argument('tests', nargs='*',
+ help='a list of any number of test modules, '
+ 'classes and test methods.')
- for name, value in zip(('start', 'pattern', 'top'), args):
- setattr(options, name, value)
+ return parser
- self._setAttributesFromOptions(options)
+ def _getDiscoveryArgParser(self, parent):
+ parser = argparse.ArgumentParser(parents=[parent])
+ parser.prog = '%s discover' % self.progName
+ parser.epilog = ('For test discovery all test modules must be '
+ 'importable from the top level directory of the '
+ 'project.')
+
+ parser.add_argument('-s', '--start-directory', dest='start',
+ help="Directory to start discovery ('.' default)")
+ parser.add_argument('-p', '--pattern', dest='pattern',
+ help="Pattern to match tests ('test*.py' default)")
+ parser.add_argument('-t', '--top-level-directory', dest='top',
+ help='Top level directory of project (defaults to '
+ 'start directory)')
+ for arg in ('start', 'pattern', 'top'):
+ parser.add_argument(arg, nargs='?',
+ default=argparse.SUPPRESS,
+ help=argparse.SUPPRESS)
- start_dir = options.start
- pattern = options.pattern
- top_level_dir = options.top
+ return parser
- loader = Loader()
- self.test = loader.discover(start_dir, pattern, top_level_dir)
+ def _do_discovery(self, argv, Loader=None):
+ self.start = '.'
+ self.pattern = 'test*.py'
+ self.top = None
+ if argv is not None:
+ # handle command line args for test discovery
+ if self._discovery_parser is None:
+ # for testing
+ self._initArgParsers()
+ self._discovery_parser.parse_args(argv, self)
+
+ loader = self.testLoader if Loader is None else Loader()
+ self.test = loader.discover(self.start, self.pattern, self.top)
def runTests(self):
if self.catchbreak:
diff --git a/Lib/unittest/mock.py b/Lib/unittest/mock.py
index 073869a1f1..dc5c033739 100644
--- a/Lib/unittest/mock.py
+++ b/Lib/unittest/mock.py
@@ -27,7 +27,7 @@ __version__ = '1.0'
import inspect
import pprint
import sys
-from functools import wraps
+from functools import wraps, partial
BaseExceptions = (BaseException,)
@@ -66,55 +66,45 @@ DescriptorTypes = (
)
-def _getsignature(func, skipfirst, instance=False):
- if isinstance(func, type) and not instance:
+def _get_signature_object(func, as_instance, eat_self):
+ """
+ Given an arbitrary, possibly callable object, try to create a suitable
+ signature object.
+ Return a (reduced func, signature) tuple, or None.
+ """
+ if isinstance(func, type) and not as_instance:
+ # If it's a type and should be modelled as a type, use __init__.
try:
func = func.__init__
except AttributeError:
- return
- skipfirst = True
+ return None
+ # Skip the `self` argument in __init__
+ eat_self = True
elif not isinstance(func, FunctionTypes):
- # for classes where instance is True we end up here too
+ # If we really want to model an instance of the passed type,
+ # __call__ should be looked up, not __init__.
try:
func = func.__call__
except AttributeError:
- return
-
+ return None
+ if eat_self:
+ sig_func = partial(func, None)
+ else:
+ sig_func = func
try:
- argspec = inspect.getfullargspec(func)
- except TypeError:
- # C function / method, possibly inherited object().__init__
- return
-
- regargs, varargs, varkw, defaults, kwonly, kwonlydef, ann = argspec
-
-
- # instance methods and classmethods need to lose the self argument
- if getattr(func, '__self__', None) is not None:
- regargs = regargs[1:]
- if skipfirst:
- # this condition and the above one are never both True - why?
- regargs = regargs[1:]
-
- signature = inspect.formatargspec(
- regargs, varargs, varkw, defaults,
- kwonly, kwonlydef, ann, formatvalue=lambda value: "")
- return signature[1:-1], func
+ return func, inspect.signature(sig_func)
+ except ValueError:
+ # Certain callable types are not supported by inspect.signature()
+ return None
def _check_signature(func, mock, skipfirst, instance=False):
- if not _callable(func):
+ sig = _get_signature_object(func, instance, skipfirst)
+ if sig is None:
return
-
- result = _getsignature(func, skipfirst, instance)
- if result is None:
- return
- signature, func = result
-
- # can't use self because "self" is common as an argument name
- # unfortunately even not in the first place
- src = "lambda _mock_self, %s: None" % signature
- checksig = eval(src, {})
+ func, sig = sig
+ def checksig(_mock_self, *args, **kwargs):
+ sig.bind(*args, **kwargs)
_copy_func_details(func, checksig)
type(mock)._mock_check_sig = checksig
@@ -166,15 +156,12 @@ def _set_signature(mock, original, instance=False):
return
skipfirst = isinstance(original, type)
- result = _getsignature(original, skipfirst, instance)
+ result = _get_signature_object(original, instance, skipfirst)
if result is None:
- # was a C function (e.g. object().__init__ ) that can't be mocked
return
-
- signature, func = result
-
- src = "lambda %s: None" % signature
- checksig = eval(src, {})
+ func, sig = result
+ def checksig(*args, **kwargs):
+ sig.bind(*args, **kwargs)
_copy_func_details(func, checksig)
name = original.__name__
@@ -368,7 +355,7 @@ class NonCallableMock(Base):
def __init__(
self, spec=None, wraps=None, name=None, spec_set=None,
parent=None, _spec_state=None, _new_name='', _new_parent=None,
- **kwargs
+ _spec_as_instance=False, _eat_self=None, **kwargs
):
if _new_parent is None:
_new_parent = parent
@@ -382,8 +369,10 @@ class NonCallableMock(Base):
if spec_set is not None:
spec = spec_set
spec_set = True
+ if _eat_self is None:
+ _eat_self = parent is not None
- self._mock_add_spec(spec, spec_set)
+ self._mock_add_spec(spec, spec_set, _spec_as_instance, _eat_self)
__dict__['_mock_children'] = {}
__dict__['_mock_wraps'] = wraps
@@ -428,20 +417,26 @@ class NonCallableMock(Base):
self._mock_add_spec(spec, spec_set)
- def _mock_add_spec(self, spec, spec_set):
+ def _mock_add_spec(self, spec, spec_set, _spec_as_instance=False,
+ _eat_self=False):
_spec_class = None
+ _spec_signature = None
if spec is not None and not _is_list(spec):
if isinstance(spec, type):
_spec_class = spec
else:
_spec_class = _get_class(spec)
+ res = _get_signature_object(spec,
+ _spec_as_instance, _eat_self)
+ _spec_signature = res and res[1]
spec = dir(spec)
__dict__ = self.__dict__
__dict__['_spec_class'] = _spec_class
__dict__['_spec_set'] = spec_set
+ __dict__['_spec_signature'] = _spec_signature
__dict__['_mock_methods'] = spec
@@ -695,7 +690,6 @@ class NonCallableMock(Base):
self._mock_children[name] = _deleted
-
def _format_mock_call_signature(self, args, kwargs):
name = self._mock_name or 'mock'
return _format_call_signature(name, args, kwargs)
@@ -711,6 +705,28 @@ class NonCallableMock(Base):
return message % (expected_string, actual_string)
+ def _call_matcher(self, _call):
+ """
+ Given a call (or simply a (args, kwargs) tuple), return a
+ comparison key suitable for matching with other calls.
+ This is a best effort method which relies on the spec's signature,
+ if available, or falls back on the arguments themselves.
+ """
+ sig = self._spec_signature
+ if sig is not None:
+ if len(_call) == 2:
+ name = ''
+ args, kwargs = _call
+ else:
+ name, args, kwargs = _call
+ try:
+ return name, sig.bind(*args, **kwargs)
+ except TypeError as e:
+ return e.with_traceback(None)
+ else:
+ return _call
+
+
def assert_called_with(_mock_self, *args, **kwargs):
"""assert that the mock was called with the specified arguments.
@@ -721,9 +737,14 @@ class NonCallableMock(Base):
expected = self._format_mock_call_signature(args, kwargs)
raise AssertionError('Expected call: %s\nNot called' % (expected,))
- if self.call_args != (args, kwargs):
+ def _error_message():
msg = self._format_mock_failure_message(args, kwargs)
- raise AssertionError(msg)
+ return msg
+ expected = self._call_matcher((args, kwargs))
+ actual = self._call_matcher(self.call_args)
+ if expected != actual:
+ cause = expected if isinstance(expected, Exception) else None
+ raise AssertionError(_error_message()) from cause
def assert_called_once_with(_mock_self, *args, **kwargs):
@@ -747,18 +768,21 @@ class NonCallableMock(Base):
If `any_order` is True then the calls can be in any order, but
they must all appear in `mock_calls`."""
+ expected = [self._call_matcher(c) for c in calls]
+ cause = expected if isinstance(expected, Exception) else None
+ all_calls = _CallList(self._call_matcher(c) for c in self.mock_calls)
if not any_order:
- if calls not in self.mock_calls:
+ if expected not in all_calls:
raise AssertionError(
'Calls not found.\nExpected: %r\n'
'Actual: %r' % (calls, self.mock_calls)
- )
+ ) from cause
return
- all_calls = list(self.mock_calls)
+ all_calls = list(all_calls)
not_found = []
- for kall in calls:
+ for kall in expected:
try:
all_calls.remove(kall)
except ValueError:
@@ -766,7 +790,7 @@ class NonCallableMock(Base):
if not_found:
raise AssertionError(
'%r not all found in call list' % (tuple(not_found),)
- )
+ ) from cause
def assert_any_call(self, *args, **kwargs):
@@ -775,12 +799,14 @@ class NonCallableMock(Base):
The assert passes if the mock has *ever* been called, unlike
`assert_called_with` and `assert_called_once_with` that only pass if
the call is the most recent one."""
- kall = call(*args, **kwargs)
- if kall not in self.call_args_list:
+ expected = self._call_matcher((args, kwargs))
+ actual = [self._call_matcher(c) for c in self.call_args_list]
+ if expected not in actual:
+ cause = expected if isinstance(expected, Exception) else None
expected_string = self._format_mock_call_signature(args, kwargs)
raise AssertionError(
'%s call not found' % expected_string
- )
+ ) from cause
def _get_child_mock(self, **kw):
@@ -850,11 +876,12 @@ class CallableMixin(Base):
self = _mock_self
self.called = True
self.call_count += 1
- self.call_args = _Call((args, kwargs), two=True)
- self.call_args_list.append(_Call((args, kwargs), two=True))
-
_new_name = self._mock_new_name
_new_parent = self._mock_new_parent
+
+ _call = _Call((args, kwargs), two=True)
+ self.call_args = _call
+ self.call_args_list.append(_call)
self.mock_calls.append(_Call(('', args, kwargs)))
seen = set()
@@ -909,8 +936,6 @@ class CallableMixin(Base):
return result
ret_val = effect(*args, **kwargs)
- if ret_val is DEFAULT:
- ret_val = self.return_value
if (self._mock_wraps is not None and
self._mock_return_value is DEFAULT):
@@ -2030,6 +2055,8 @@ def create_autospec(spec, spec_set=False, instance=False, _parent=None,
elif spec is None:
# None we mock with a normal mock without a spec
_kwargs = {}
+ if _kwargs and instance:
+ _kwargs['_spec_as_instance'] = True
_kwargs.update(kwargs)
@@ -2096,10 +2123,12 @@ def create_autospec(spec, spec_set=False, instance=False, _parent=None,
if isinstance(spec, FunctionTypes):
parent = mock.mock
+ skipfirst = _must_skip(spec, entry, is_type)
+ kwargs['_eat_self'] = skipfirst
new = MagicMock(parent=parent, name=entry, _new_name=entry,
- _new_parent=parent, **kwargs)
+ _new_parent=parent,
+ **kwargs)
mock._mock_children[entry] = new
- skipfirst = _must_skip(spec, entry, is_type)
_check_signature(original, new, skipfirst=skipfirst)
# so functions created with _set_signature become instance attributes,
@@ -2113,6 +2142,10 @@ def create_autospec(spec, spec_set=False, instance=False, _parent=None,
def _must_skip(spec, entry, is_type):
+ """
+ Return whether we should skip the first argument on spec's `entry`
+ attribute.
+ """
if not isinstance(spec, type):
if entry in getattr(spec, '__dict__', {}):
# instance attribute - shouldn't skip
@@ -2125,7 +2158,12 @@ def _must_skip(spec, entry, is_type):
continue
if isinstance(result, (staticmethod, classmethod)):
return False
- return is_type
+ elif isinstance(getattr(result, '__get__', None), MethodWrapperTypes):
+ # Normal method => skip if looked up on type
+ # (if looked up on instance, self is already skipped)
+ return is_type
+ else:
+ return False
# shouldn't get here unless function is a dynamically provided attribute
# XXXX untested behaviour
@@ -2159,9 +2197,31 @@ FunctionTypes = (
type(ANY.__eq__),
)
+MethodWrapperTypes = (
+ type(ANY.__eq__.__get__),
+)
+
file_spec = None
+def _iterate_read_data(read_data):
+ # Helper for mock_open:
+ # Retrieve lines from read_data via a generator so that separate calls to
+ # readline, read, and readlines are properly interleaved
+ data_as_list = ['{}\n'.format(l) for l in read_data.split('\n')]
+
+ if data_as_list[-1] == '\n':
+ # If the last line ended in a newline, the list comprehension will have an
+ # extra entry that's just a newline. Remove this.
+ data_as_list = data_as_list[:-1]
+ else:
+ # If there wasn't an extra newline by itself, then the file being
+ # emulated doesn't have a newline to end the last line remove the
+ # newline that our naive format() added
+ data_as_list[-1] = data_as_list[-1][:-1]
+
+ for line in data_as_list:
+ yield line
def mock_open(mock=None, read_data=''):
"""
@@ -2172,9 +2232,27 @@ def mock_open(mock=None, read_data=''):
default) then a `MagicMock` will be created for you, with the API limited
to methods or attributes available on standard file handles.
- `read_data` is a string for the `read` method of the file handle to return.
- This is an empty string by default.
+ `read_data` is a string for the `read` methoddline`, and `readlines` of the
+ file handle to return. This is an empty string by default.
"""
+ def _readlines_side_effect(*args, **kwargs):
+ if handle.readlines.return_value is not None:
+ return handle.readlines.return_value
+ return list(_data)
+
+ def _read_side_effect(*args, **kwargs):
+ if handle.read.return_value is not None:
+ return handle.read.return_value
+ return ''.join(_data)
+
+ def _readline_side_effect():
+ if handle.readline.return_value is not None:
+ while True:
+ yield handle.readline.return_value
+ for line in _data:
+ yield line
+
+
global file_spec
if file_spec is None:
import _io
@@ -2184,9 +2262,18 @@ def mock_open(mock=None, read_data=''):
mock = MagicMock(name='open', spec=open)
handle = MagicMock(spec=file_spec)
- handle.write.return_value = None
handle.__enter__.return_value = handle
- handle.read.return_value = read_data
+
+ _data = _iterate_read_data(read_data)
+
+ handle.write.return_value = None
+ handle.read.return_value = None
+ handle.readline.return_value = None
+ handle.readlines.return_value = None
+
+ handle.read.side_effect = _read_side_effect
+ handle.readline.side_effect = _readline_side_effect()
+ handle.readlines.side_effect = _readlines_side_effect
mock.return_value = handle
return mock
diff --git a/Lib/unittest/result.py b/Lib/unittest/result.py
index 97e5426927..f3f4b676a3 100644
--- a/Lib/unittest/result.py
+++ b/Lib/unittest/result.py
@@ -121,6 +121,22 @@ class TestResult(object):
self.failures.append((test, self._exc_info_to_string(err, test)))
self._mirrorOutput = True
+ @failfast
+ def addSubTest(self, test, subtest, err):
+ """Called at the end of a subtest.
+ 'err' is None if the subtest ended successfully, otherwise it's a
+ tuple of values as returned by sys.exc_info().
+ """
+ # By default, we don't do anything with successful subtests, but
+ # more sophisticated test results might want to record them.
+ if err is not None:
+ if issubclass(err[0], test.failureException):
+ errors = self.failures
+ else:
+ errors = self.errors
+ errors.append((subtest, self._exc_info_to_string(err, test)))
+ self._mirrorOutput = True
+
def addSuccess(self, test):
"Called when a test has completed successfully"
pass
diff --git a/Lib/unittest/suite.py b/Lib/unittest/suite.py
index cde5d385ed..ca82765b9c 100644
--- a/Lib/unittest/suite.py
+++ b/Lib/unittest/suite.py
@@ -16,6 +16,8 @@ def _call_if_exists(parent, attr):
class BaseTestSuite(object):
"""A simple test suite that doesn't provide class or module shared fixtures.
"""
+ _cleanup = True
+
def __init__(self, tests=()):
self._tests = []
self.addTests(tests)
@@ -57,12 +59,22 @@ class BaseTestSuite(object):
self.addTest(test)
def run(self, result):
- for test in self:
+ for index, test in enumerate(self):
if result.shouldStop:
break
test(result)
+ if self._cleanup:
+ self._removeTestAtIndex(index)
return result
+ def _removeTestAtIndex(self, index):
+ """Stop holding a reference to the TestCase at index."""
+ try:
+ self._tests[index] = None
+ except TypeError:
+ # support for suite implementations that have overriden self._test
+ pass
+
def __call__(self, *args, **kwds):
return self.run(*args, **kwds)
@@ -87,7 +99,7 @@ class TestSuite(BaseTestSuite):
if getattr(result, '_testRunEntered', False) is False:
result._testRunEntered = topLevel = True
- for test in self:
+ for index, test in enumerate(self):
if result.shouldStop:
break
@@ -106,6 +118,9 @@ class TestSuite(BaseTestSuite):
else:
test.debug()
+ if self._cleanup:
+ self._removeTestAtIndex(index)
+
if topLevel:
self._tearDownPreviousClass(None, result)
self._handleModuleTearDown(result)
diff --git a/Lib/unittest/test/__main__.py b/Lib/unittest/test/__main__.py
new file mode 100644
index 0000000000..44d0591e84
--- /dev/null
+++ b/Lib/unittest/test/__main__.py
@@ -0,0 +1,18 @@
+import os
+import unittest
+
+
+def load_tests(loader, standard_tests, pattern):
+ # top level directory cached on loader instance
+ this_dir = os.path.dirname(__file__)
+ pattern = pattern or "test_*.py"
+ # We are inside unittest.test, so the top-level is two notches up
+ top_level_dir = os.path.dirname(os.path.dirname(this_dir))
+ package_tests = loader.discover(start_dir=this_dir, pattern=pattern,
+ top_level_dir=top_level_dir)
+ standard_tests.addTests(package_tests)
+ return standard_tests
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/unittest/test/support.py b/Lib/unittest/test/support.py
index dbe4ddcd0e..02e8f3a00b 100644
--- a/Lib/unittest/test/support.py
+++ b/Lib/unittest/test/support.py
@@ -41,7 +41,7 @@ class TestHashing(object):
self.fail("Problem hashing %s and %s: %s" % (obj_1, obj_2, e))
-class LoggingResult(unittest.TestResult):
+class _BaseLoggingResult(unittest.TestResult):
def __init__(self, log):
self._events = log
super().__init__()
@@ -52,7 +52,7 @@ class LoggingResult(unittest.TestResult):
def startTestRun(self):
self._events.append('startTestRun')
- super(LoggingResult, self).startTestRun()
+ super().startTestRun()
def stopTest(self, test):
self._events.append('stopTest')
@@ -60,7 +60,7 @@ class LoggingResult(unittest.TestResult):
def stopTestRun(self):
self._events.append('stopTestRun')
- super(LoggingResult, self).stopTestRun()
+ super().stopTestRun()
def addFailure(self, *args):
self._events.append('addFailure')
@@ -68,7 +68,7 @@ class LoggingResult(unittest.TestResult):
def addSuccess(self, *args):
self._events.append('addSuccess')
- super(LoggingResult, self).addSuccess(*args)
+ super().addSuccess(*args)
def addError(self, *args):
self._events.append('addError')
@@ -76,15 +76,39 @@ class LoggingResult(unittest.TestResult):
def addSkip(self, *args):
self._events.append('addSkip')
- super(LoggingResult, self).addSkip(*args)
+ super().addSkip(*args)
def addExpectedFailure(self, *args):
self._events.append('addExpectedFailure')
- super(LoggingResult, self).addExpectedFailure(*args)
+ super().addExpectedFailure(*args)
def addUnexpectedSuccess(self, *args):
self._events.append('addUnexpectedSuccess')
- super(LoggingResult, self).addUnexpectedSuccess(*args)
+ super().addUnexpectedSuccess(*args)
+
+
+class LegacyLoggingResult(_BaseLoggingResult):
+ """
+ A legacy TestResult implementation, without an addSubTest method,
+ which records its method calls.
+ """
+
+ @property
+ def addSubTest(self):
+ raise AttributeError
+
+
+class LoggingResult(_BaseLoggingResult):
+ """
+ A TestResult implementation which records its method calls.
+ """
+
+ def addSubTest(self, test, subtest, err):
+ if err is None:
+ self._events.append('addSubTestSuccess')
+ else:
+ self._events.append('addSubTestFailure')
+ super().addSubTest(test, subtest, err)
class ResultWithNoStartTestRunStopTestRun(object):
diff --git a/Lib/unittest/test/test_assertions.py b/Lib/unittest/test/test_assertions.py
index 7931cadaf3..af08d5ad65 100644
--- a/Lib/unittest/test/test_assertions.py
+++ b/Lib/unittest/test/test_assertions.py
@@ -361,3 +361,7 @@ class TestLongMessage(unittest.TestCase):
['^"regex" does not match "foo"$', '^oops$',
'^"regex" does not match "foo"$',
'^"regex" does not match "foo" : oops$'])
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/Lib/unittest/test/test_break.py b/Lib/unittest/test/test_break.py
index 75532f42d3..0bf1a229b8 100644
--- a/Lib/unittest/test/test_break.py
+++ b/Lib/unittest/test/test_break.py
@@ -282,3 +282,7 @@ class TestBreakSignalIgnored(TestBreak):
"if threads have been used")
class TestBreakSignalDefault(TestBreak):
int_handler = signal.SIG_DFL
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/Lib/unittest/test/test_case.py b/Lib/unittest/test/test_case.py
index b8cb0c76b7..4b93179335 100644
--- a/Lib/unittest/test/test_case.py
+++ b/Lib/unittest/test/test_case.py
@@ -1,8 +1,10 @@
+import contextlib
import difflib
import pprint
import pickle
import re
import sys
+import logging
import warnings
import weakref
import inspect
@@ -13,9 +15,15 @@ from test import support
import unittest
from .support import (
- TestEquality, TestHashing, LoggingResult,
+ TestEquality, TestHashing, LoggingResult, LegacyLoggingResult,
ResultWithNoStartTestRunStopTestRun
)
+from test.support import captured_stderr
+
+
+log_foo = logging.getLogger('foo')
+log_foobar = logging.getLogger('foo.bar')
+log_quux = logging.getLogger('quux')
class Test(object):
@@ -297,6 +305,98 @@ class Test_TestCase(unittest.TestCase, TestEquality, TestHashing):
Foo('test').run()
+ def _check_call_order__subtests(self, result, events, expected_events):
+ class Foo(Test.LoggingTestCase):
+ def test(self):
+ super(Foo, self).test()
+ for i in [1, 2, 3]:
+ with self.subTest(i=i):
+ if i == 1:
+ self.fail('failure')
+ for j in [2, 3]:
+ with self.subTest(j=j):
+ if i * j == 6:
+ raise RuntimeError('raised by Foo.test')
+ 1 / 0
+
+ # Order is the following:
+ # i=1 => subtest failure
+ # i=2, j=2 => subtest success
+ # i=2, j=3 => subtest error
+ # i=3, j=2 => subtest error
+ # i=3, j=3 => subtest success
+ # toplevel => error
+ Foo(events).run(result)
+ self.assertEqual(events, expected_events)
+
+ def test_run_call_order__subtests(self):
+ events = []
+ result = LoggingResult(events)
+ expected = ['startTest', 'setUp', 'test', 'tearDown',
+ 'addSubTestFailure', 'addSubTestSuccess',
+ 'addSubTestFailure', 'addSubTestFailure',
+ 'addSubTestSuccess', 'addError', 'stopTest']
+ self._check_call_order__subtests(result, events, expected)
+
+ def test_run_call_order__subtests_legacy(self):
+ # With a legacy result object (without a addSubTest method),
+ # text execution stops after the first subtest failure.
+ events = []
+ result = LegacyLoggingResult(events)
+ expected = ['startTest', 'setUp', 'test', 'tearDown',
+ 'addFailure', 'stopTest']
+ self._check_call_order__subtests(result, events, expected)
+
+ def _check_call_order__subtests_success(self, result, events, expected_events):
+ class Foo(Test.LoggingTestCase):
+ def test(self):
+ super(Foo, self).test()
+ for i in [1, 2]:
+ with self.subTest(i=i):
+ for j in [2, 3]:
+ with self.subTest(j=j):
+ pass
+
+ Foo(events).run(result)
+ self.assertEqual(events, expected_events)
+
+ def test_run_call_order__subtests_success(self):
+ events = []
+ result = LoggingResult(events)
+ # The 6 subtest successes are individually recorded, in addition
+ # to the whole test success.
+ expected = (['startTest', 'setUp', 'test', 'tearDown']
+ + 6 * ['addSubTestSuccess']
+ + ['addSuccess', 'stopTest'])
+ self._check_call_order__subtests_success(result, events, expected)
+
+ def test_run_call_order__subtests_success_legacy(self):
+ # With a legacy result, only the whole test success is recorded.
+ events = []
+ result = LegacyLoggingResult(events)
+ expected = ['startTest', 'setUp', 'test', 'tearDown',
+ 'addSuccess', 'stopTest']
+ self._check_call_order__subtests_success(result, events, expected)
+
+ def test_run_call_order__subtests_failfast(self):
+ events = []
+ result = LoggingResult(events)
+ result.failfast = True
+
+ class Foo(Test.LoggingTestCase):
+ def test(self):
+ super(Foo, self).test()
+ with self.subTest(i=1):
+ self.fail('failure')
+ with self.subTest(i=2):
+ self.fail('failure')
+ self.fail('failure')
+
+ expected = ['startTest', 'setUp', 'test', 'tearDown',
+ 'addSubTestFailure', 'stopTest']
+ Foo(events).run(result)
+ self.assertEqual(events, expected)
+
# "This class attribute gives the exception raised by the test() method.
# If a test framework needs to use a specialized exception, possibly to
# carry additional information, it must subclass this exception in
@@ -729,18 +829,18 @@ class Test_TestCase(unittest.TestCase, TestEquality, TestHashing):
# set a lower threshold value and add a cleanup to restore it
old_threshold = self._diffThreshold
- self._diffThreshold = 2**8
+ self._diffThreshold = 2**5
self.addCleanup(lambda: setattr(self, '_diffThreshold', old_threshold))
# under the threshold: diff marker (^) in error message
- s = 'x' * (2**7)
+ s = 'x' * (2**4)
with self.assertRaises(self.failureException) as cm:
self.assertEqual(s + 'a', s + 'b')
self.assertIn('^', str(cm.exception))
self.assertEqual(s + 'a', s + 'a')
# over the threshold: diff not used and marker (^) not in error message
- s = 'x' * (2**9)
+ s = 'x' * (2**6)
# if the path that uses difflib is taken, _truncateMessage will be
# called -- replace it with explodingTruncation to verify that this
# doesn't happen
@@ -757,6 +857,37 @@ class Test_TestCase(unittest.TestCase, TestEquality, TestHashing):
self.assertEqual(str(cm.exception), '%r != %r' % (s1, s2))
self.assertEqual(s + 'a', s + 'a')
+ def testAssertEqual_shorten(self):
+ # set a lower threshold value and add a cleanup to restore it
+ old_threshold = self._diffThreshold
+ self._diffThreshold = 0
+ self.addCleanup(lambda: setattr(self, '_diffThreshold', old_threshold))
+
+ s = 'x' * 100
+ s1, s2 = s + 'a', s + 'b'
+ with self.assertRaises(self.failureException) as cm:
+ self.assertEqual(s1, s2)
+ c = 'xxxx[35 chars]' + 'x' * 61
+ self.assertEqual(str(cm.exception), "'%sa' != '%sb'" % (c, c))
+ self.assertEqual(s + 'a', s + 'a')
+
+ p = 'y' * 50
+ s1, s2 = s + 'a' + p, s + 'b' + p
+ with self.assertRaises(self.failureException) as cm:
+ self.assertEqual(s1, s2)
+ c = 'xxxx[85 chars]xxxxxxxxxxx'
+ #print()
+ #print(str(cm.exception))
+ self.assertEqual(str(cm.exception), "'%sa%s' != '%sb%s'" % (c, p, c, p))
+
+ p = 'y' * 100
+ s1, s2 = s + 'a' + p, s + 'b' + p
+ with self.assertRaises(self.failureException) as cm:
+ self.assertEqual(s1, s2)
+ c = 'xxxx[91 chars]xxxxx'
+ d = 'y' * 40 + '[56 chars]yyyy'
+ self.assertEqual(str(cm.exception), "'%sa%s' != '%sb%s'" % (c, d, c, d))
+
def testAssertCountEqual(self):
a = object()
self.assertCountEqual([1, 2, 3], [3, 2, 1])
@@ -1159,6 +1290,94 @@ test case
with self.assertWarnsRegex(RuntimeWarning, "o+"):
_runtime_warn("barz")
+ @contextlib.contextmanager
+ def assertNoStderr(self):
+ with captured_stderr() as buf:
+ yield
+ self.assertEqual(buf.getvalue(), "")
+
+ def assertLogRecords(self, records, matches):
+ self.assertEqual(len(records), len(matches))
+ for rec, match in zip(records, matches):
+ self.assertIsInstance(rec, logging.LogRecord)
+ for k, v in match.items():
+ self.assertEqual(getattr(rec, k), v)
+
+ def testAssertLogsDefaults(self):
+ # defaults: root logger, level INFO
+ with self.assertNoStderr():
+ with self.assertLogs() as cm:
+ log_foo.info("1")
+ log_foobar.debug("2")
+ self.assertEqual(cm.output, ["INFO:foo:1"])
+ self.assertLogRecords(cm.records, [{'name': 'foo'}])
+
+ def testAssertLogsTwoMatchingMessages(self):
+ # Same, but with two matching log messages
+ with self.assertNoStderr():
+ with self.assertLogs() as cm:
+ log_foo.info("1")
+ log_foobar.debug("2")
+ log_quux.warning("3")
+ self.assertEqual(cm.output, ["INFO:foo:1", "WARNING:quux:3"])
+ self.assertLogRecords(cm.records,
+ [{'name': 'foo'}, {'name': 'quux'}])
+
+ def checkAssertLogsPerLevel(self, level):
+ # Check level filtering
+ with self.assertNoStderr():
+ with self.assertLogs(level=level) as cm:
+ log_foo.warning("1")
+ log_foobar.error("2")
+ log_quux.critical("3")
+ self.assertEqual(cm.output, ["ERROR:foo.bar:2", "CRITICAL:quux:3"])
+ self.assertLogRecords(cm.records,
+ [{'name': 'foo.bar'}, {'name': 'quux'}])
+
+ def testAssertLogsPerLevel(self):
+ self.checkAssertLogsPerLevel(logging.ERROR)
+ self.checkAssertLogsPerLevel('ERROR')
+
+ def checkAssertLogsPerLogger(self, logger):
+ # Check per-logger fitering
+ with self.assertNoStderr():
+ with self.assertLogs(level='DEBUG') as outer_cm:
+ with self.assertLogs(logger, level='DEBUG') as cm:
+ log_foo.info("1")
+ log_foobar.debug("2")
+ log_quux.warning("3")
+ self.assertEqual(cm.output, ["INFO:foo:1", "DEBUG:foo.bar:2"])
+ self.assertLogRecords(cm.records,
+ [{'name': 'foo'}, {'name': 'foo.bar'}])
+ # The outer catchall caught the quux log
+ self.assertEqual(outer_cm.output, ["WARNING:quux:3"])
+
+ def testAssertLogsPerLogger(self):
+ self.checkAssertLogsPerLogger(logging.getLogger('foo'))
+ self.checkAssertLogsPerLogger('foo')
+
+ def testAssertLogsFailureNoLogs(self):
+ # Failure due to no logs
+ with self.assertNoStderr():
+ with self.assertRaises(self.failureException):
+ with self.assertLogs():
+ pass
+
+ def testAssertLogsFailureLevelTooHigh(self):
+ # Failure due to level too high
+ with self.assertNoStderr():
+ with self.assertRaises(self.failureException):
+ with self.assertLogs(level='WARNING'):
+ log_foo.info("1")
+
+ def testAssertLogsFailureMismatchingLogger(self):
+ # Failure due to mismatching logger (and the logged message is
+ # passed through)
+ with self.assertLogs('quux', level='ERROR'):
+ with self.assertRaises(self.failureException):
+ with self.assertLogs('foo'):
+ log_quux.error("1")
+
def testDeprecatedMethodNames(self):
"""
Test that the deprecated methods raise a DeprecationWarning. See #9424.
@@ -1313,3 +1532,7 @@ test case
with support.disable_gc():
del case
self.assertFalse(wr())
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/Lib/unittest/test/test_discovery.py b/Lib/unittest/test/test_discovery.py
index ccc7db249e..6b7b1280f8 100644
--- a/Lib/unittest/test/test_discovery.py
+++ b/Lib/unittest/test/test_discovery.py
@@ -1,12 +1,15 @@
import os
import re
import sys
+import types
+import builtins
+from test import support
import unittest
class TestableTestProgram(unittest.TestProgram):
- module = '__main__'
+ module = None
exit = True
defaultTest = failfast = catchbreak = buffer = None
verbosity = 1
@@ -46,9 +49,9 @@ class TestDiscovery(unittest.TestCase):
def restore_isdir():
os.path.isdir = original_isdir
- path_lists = [['test1.py', 'test2.py', 'not_a_test.py', 'test_dir',
+ path_lists = [['test2.py', 'test1.py', 'not_a_test.py', 'test_dir',
'test.foo', 'test-not-a-module.py', 'another_dir'],
- ['test3.py', 'test4.py', ]]
+ ['test4.py', 'test3.py', ]]
os.listdir = lambda path: path_lists.pop(0)
self.addCleanup(restore_listdir)
@@ -70,6 +73,8 @@ class TestDiscovery(unittest.TestCase):
loader._top_level_dir = top_level
suite = list(loader._find_tests(top_level, 'test*.py'))
+ # The test suites found should be sorted alphabetically for reliable
+ # execution order.
expected = [name + ' module tests' for name in
('test1', 'test2')]
expected.extend([('test_dir.%s' % name) + ' module tests' for name in
@@ -132,6 +137,7 @@ class TestDiscovery(unittest.TestCase):
# and directly from the test_directory2 package
self.assertEqual(suite,
['load_tests', 'test_directory2' + ' module tests'])
+ # The test module paths should be sorted for reliable execution order
self.assertEqual(Module.paths, ['test_directory', 'test_directory2'])
# load_tests should have been called once with loader, tests and pattern
@@ -169,7 +175,7 @@ class TestDiscovery(unittest.TestCase):
self.addCleanup(restore_isdir)
_find_tests_args = []
- def _find_tests(start_dir, pattern):
+ def _find_tests(start_dir, pattern, namespace=None):
_find_tests_args.append((start_dir, pattern))
return ['tests']
loader._find_tests = _find_tests
@@ -184,11 +190,9 @@ class TestDiscovery(unittest.TestCase):
self.assertEqual(_find_tests_args, [(start_dir, 'pattern')])
self.assertIn(top_level_dir, sys.path)
- def test_discover_with_modules_that_fail_to_import(self):
- loader = unittest.TestLoader()
-
+ def setup_import_issue_tests(self, fakefile):
listdir = os.listdir
- os.listdir = lambda _: ['test_this_does_not_exist.py']
+ os.listdir = lambda _: [fakefile]
isfile = os.path.isfile
os.path.isfile = lambda _: True
orig_sys_path = sys.path[:]
@@ -198,6 +202,11 @@ class TestDiscovery(unittest.TestCase):
sys.path[:] = orig_sys_path
self.addCleanup(restore)
+ def test_discover_with_modules_that_fail_to_import(self):
+ loader = unittest.TestLoader()
+
+ self.setup_import_issue_tests('test_this_does_not_exist.py')
+
suite = loader.discover('.')
self.assertIn(os.getcwd(), sys.path)
self.assertEqual(suite.countTestCases(), 1)
@@ -206,62 +215,74 @@ class TestDiscovery(unittest.TestCase):
with self.assertRaises(ImportError):
test.test_this_does_not_exist()
+ def test_discover_with_module_that_raises_SkipTest_on_import(self):
+ loader = unittest.TestLoader()
+
+ def _get_module_from_name(name):
+ raise unittest.SkipTest('skipperoo')
+ loader._get_module_from_name = _get_module_from_name
+
+ self.setup_import_issue_tests('test_skip_dummy.py')
+
+ suite = loader.discover('.')
+ self.assertEqual(suite.countTestCases(), 1)
+
+ result = unittest.TestResult()
+ suite.run(result)
+ self.assertEqual(len(result.skipped), 1)
+
def test_command_line_handling_parseArgs(self):
program = TestableTestProgram()
args = []
- def do_discovery(argv):
- args.extend(argv)
- program._do_discovery = do_discovery
+ program._do_discovery = args.append
program.parseArgs(['something', 'discover'])
- self.assertEqual(args, [])
+ self.assertEqual(args, [[]])
+ args[:] = []
program.parseArgs(['something', 'discover', 'foo', 'bar'])
- self.assertEqual(args, ['foo', 'bar'])
+ self.assertEqual(args, [['foo', 'bar']])
def test_command_line_handling_discover_by_default(self):
program = TestableTestProgram()
- program.module = None
- self.called = False
- def do_discovery(argv):
- self.called = True
- self.assertEqual(argv, [])
- program._do_discovery = do_discovery
+ args = []
+ program._do_discovery = args.append
program.parseArgs(['something'])
- self.assertTrue(self.called)
+ self.assertEqual(args, [[]])
+ self.assertEqual(program.verbosity, 1)
+ self.assertIs(program.buffer, False)
+ self.assertIs(program.catchbreak, False)
+ self.assertIs(program.failfast, False)
def test_command_line_handling_discover_by_default_with_options(self):
program = TestableTestProgram()
- program.module = None
- args = ['something', '-v', '-b', '-v', '-c', '-f']
- self.called = False
- def do_discovery(argv):
- self.called = True
- self.assertEqual(argv, args[1:])
- program._do_discovery = do_discovery
- program.parseArgs(args)
- self.assertTrue(self.called)
+ args = []
+ program._do_discovery = args.append
+ program.parseArgs(['something', '-v', '-b', '-v', '-c', '-f'])
+ self.assertEqual(args, [[]])
+ self.assertEqual(program.verbosity, 2)
+ self.assertIs(program.buffer, True)
+ self.assertIs(program.catchbreak, True)
+ self.assertIs(program.failfast, True)
def test_command_line_handling_do_discovery_too_many_arguments(self):
- class Stop(Exception):
- pass
- def usageExit():
- raise Stop
-
program = TestableTestProgram()
- program.usageExit = usageExit
program.testLoader = None
- with self.assertRaises(Stop):
+ with support.captured_stderr() as stderr, \
+ self.assertRaises(SystemExit) as cm:
# too many args
program._do_discovery(['one', 'two', 'three', 'four'])
+ self.assertEqual(cm.exception.args, (2,))
+ self.assertIn('usage:', stderr.getvalue())
def test_command_line_handling_do_discovery_uses_default_loader(self):
program = object.__new__(unittest.TestProgram)
+ program._initArgParsers()
class Loader(object):
args = []
@@ -417,7 +438,7 @@ class TestDiscovery(unittest.TestCase):
expectedPath = os.path.abspath(os.path.dirname(unittest.test.__file__))
self.wasRun = False
- def _find_tests(start_dir, pattern):
+ def _find_tests(start_dir, pattern, namespace=None):
self.wasRun = True
self.assertEqual(start_dir, expectedPath)
return tests
@@ -427,5 +448,79 @@ class TestDiscovery(unittest.TestCase):
self.assertEqual(suite._tests, tests)
+ def test_discovery_from_dotted_path_builtin_modules(self):
+
+ loader = unittest.TestLoader()
+
+ listdir = os.listdir
+ os.listdir = lambda _: ['test_this_does_not_exist.py']
+ isfile = os.path.isfile
+ isdir = os.path.isdir
+ os.path.isdir = lambda _: False
+ orig_sys_path = sys.path[:]
+ def restore():
+ os.path.isfile = isfile
+ os.path.isdir = isdir
+ os.listdir = listdir
+ sys.path[:] = orig_sys_path
+ self.addCleanup(restore)
+
+ with self.assertRaises(TypeError) as cm:
+ loader.discover('sys')
+ self.assertEqual(str(cm.exception),
+ 'Can not use builtin modules '
+ 'as dotted module names')
+
+ def test_discovery_from_dotted_namespace_packages(self):
+ loader = unittest.TestLoader()
+
+ orig_import = __import__
+ package = types.ModuleType('package')
+ package.__path__ = ['/a', '/b']
+ package.__spec__ = types.SimpleNamespace(
+ loader=None,
+ submodule_search_locations=['/a', '/b']
+ )
+
+ def _import(packagename, *args, **kwargs):
+ sys.modules[packagename] = package
+ return package
+
+ def cleanup():
+ builtins.__import__ = orig_import
+ self.addCleanup(cleanup)
+ builtins.__import__ = _import
+
+ _find_tests_args = []
+ def _find_tests(start_dir, pattern, namespace=None):
+ _find_tests_args.append((start_dir, pattern))
+ return ['%s/tests' % start_dir]
+
+ loader._find_tests = _find_tests
+ loader.suiteClass = list
+ suite = loader.discover('package')
+ self.assertEqual(suite, ['/a/tests', '/b/tests'])
+
+ def test_discovery_failed_discovery(self):
+ loader = unittest.TestLoader()
+ package = types.ModuleType('package')
+ orig_import = __import__
+
+ def _import(packagename, *args, **kwargs):
+ sys.modules[packagename] = package
+ return package
+
+ def cleanup():
+ builtins.__import__ = orig_import
+ self.addCleanup(cleanup)
+ builtins.__import__ = _import
+
+ with self.assertRaises(TypeError) as cm:
+ loader.discover('package')
+ self.assertEqual(str(cm.exception),
+ 'don\'t know how to discover from {!r}'
+ .format(package))
+
+
if __name__ == '__main__':
unittest.main()
diff --git a/Lib/unittest/test/test_functiontestcase.py b/Lib/unittest/test/test_functiontestcase.py
index 9ce5ee3556..d7fe07a53b 100644
--- a/Lib/unittest/test/test_functiontestcase.py
+++ b/Lib/unittest/test/test_functiontestcase.py
@@ -142,3 +142,7 @@ class Test_FunctionTestCase(unittest.TestCase):
test = unittest.FunctionTestCase(lambda: None, description=desc)
self.assertEqual(test.shortDescription(), "this tests foo")
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/Lib/unittest/test/test_loader.py b/Lib/unittest/test/test_loader.py
index fcd2e07624..b62a1b5c54 100644
--- a/Lib/unittest/test/test_loader.py
+++ b/Lib/unittest/test/test_loader.py
@@ -1306,3 +1306,7 @@ class Test_TestLoader(unittest.TestCase):
def test_suiteClass__default_value(self):
loader = unittest.TestLoader()
self.assertIs(loader.suiteClass, unittest.TestSuite)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/Lib/unittest/test/test_program.py b/Lib/unittest/test/test_program.py
index 8a4b3fad58..3294298030 100644
--- a/Lib/unittest/test/test_program.py
+++ b/Lib/unittest/test/test_program.py
@@ -2,6 +2,7 @@ import io
import os
import sys
+from test import support
import unittest
@@ -64,6 +65,41 @@ class Test_TestProgram(unittest.TestCase):
return self.suiteClass(
[self.loadTestsFromTestCase(Test_TestProgram.FooBar)])
+ def loadTestsFromNames(self, names, module):
+ return self.suiteClass(
+ [self.loadTestsFromTestCase(Test_TestProgram.FooBar)])
+
+ def test_defaultTest_with_string(self):
+ class FakeRunner(object):
+ def run(self, test):
+ self.test = test
+ return True
+
+ old_argv = sys.argv
+ sys.argv = ['faketest']
+ runner = FakeRunner()
+ program = unittest.TestProgram(testRunner=runner, exit=False,
+ defaultTest='unittest.test',
+ testLoader=self.FooBarLoader())
+ sys.argv = old_argv
+ self.assertEqual(('unittest.test',), program.testNames)
+
+ def test_defaultTest_with_iterable(self):
+ class FakeRunner(object):
+ def run(self, test):
+ self.test = test
+ return True
+
+ old_argv = sys.argv
+ sys.argv = ['faketest']
+ runner = FakeRunner()
+ program = unittest.TestProgram(
+ testRunner=runner, exit=False,
+ defaultTest=['unittest.test', 'unittest.test2'],
+ testLoader=self.FooBarLoader())
+ sys.argv = old_argv
+ self.assertEqual(['unittest.test', 'unittest.test2'],
+ program.testNames)
def test_NonExit(self):
program = unittest.main(exit=False,
@@ -151,20 +187,38 @@ class TestCommandLineArgs(unittest.TestCase):
if attr == 'catch' and not hasInstallHandler:
continue
+ setattr(program, attr, None)
+ program.parseArgs([None])
+ self.assertIs(getattr(program, attr), False)
+
+ false = []
+ setattr(program, attr, false)
+ program.parseArgs([None])
+ self.assertIs(getattr(program, attr), false)
+
+ true = [42]
+ setattr(program, attr, true)
+ program.parseArgs([None])
+ self.assertIs(getattr(program, attr), true)
+
short_opt = '-%s' % arg[0]
long_opt = '--%s' % arg
for opt in short_opt, long_opt:
setattr(program, attr, None)
-
- program.parseArgs([None, opt])
- self.assertTrue(getattr(program, attr))
-
- for opt in short_opt, long_opt:
- not_none = object()
- setattr(program, attr, not_none)
-
program.parseArgs([None, opt])
- self.assertEqual(getattr(program, attr), not_none)
+ self.assertIs(getattr(program, attr), True)
+
+ setattr(program, attr, False)
+ with support.captured_stderr() as stderr, \
+ self.assertRaises(SystemExit) as cm:
+ program.parseArgs([None, opt])
+ self.assertEqual(cm.exception.args, (2,))
+
+ setattr(program, attr, True)
+ with support.captured_stderr() as stderr, \
+ self.assertRaises(SystemExit) as cm:
+ program.parseArgs([None, opt])
+ self.assertEqual(cm.exception.args, (2,))
def testWarning(self):
"""Test the warnings argument"""
diff --git a/Lib/unittest/test/test_result.py b/Lib/unittest/test/test_result.py
index 7d40725cf2..489fe17754 100644
--- a/Lib/unittest/test/test_result.py
+++ b/Lib/unittest/test/test_result.py
@@ -227,6 +227,40 @@ class Test_TestResult(unittest.TestCase):
self.assertIs(test_case, test)
self.assertIsInstance(formatted_exc, str)
+ def test_addSubTest(self):
+ class Foo(unittest.TestCase):
+ def test_1(self):
+ nonlocal subtest
+ with self.subTest(foo=1):
+ subtest = self._subtest
+ try:
+ 1/0
+ except ZeroDivisionError:
+ exc_info_tuple = sys.exc_info()
+ # Register an error by hand (to check the API)
+ result.addSubTest(test, subtest, exc_info_tuple)
+ # Now trigger a failure
+ self.fail("some recognizable failure")
+
+ subtest = None
+ test = Foo('test_1')
+ result = unittest.TestResult()
+
+ test.run(result)
+
+ self.assertFalse(result.wasSuccessful())
+ self.assertEqual(len(result.errors), 1)
+ self.assertEqual(len(result.failures), 1)
+ self.assertEqual(result.testsRun, 1)
+ self.assertEqual(result.shouldStop, False)
+
+ test_case, formatted_exc = result.errors[0]
+ self.assertIs(test_case, subtest)
+ self.assertIn("ZeroDivisionError", formatted_exc)
+ test_case, formatted_exc = result.failures[0]
+ self.assertIs(test_case, subtest)
+ self.assertIn("some recognizable failure", formatted_exc)
+
def testGetDescriptionWithoutDocstring(self):
result = unittest.TextTestResult(None, True, 1)
self.assertEqual(
@@ -234,6 +268,37 @@ class Test_TestResult(unittest.TestCase):
'testGetDescriptionWithoutDocstring (' + __name__ +
'.Test_TestResult)')
+ def testGetSubTestDescriptionWithoutDocstring(self):
+ with self.subTest(foo=1, bar=2):
+ result = unittest.TextTestResult(None, True, 1)
+ self.assertEqual(
+ result.getDescription(self._subtest),
+ 'testGetSubTestDescriptionWithoutDocstring (' + __name__ +
+ '.Test_TestResult) (bar=2, foo=1)')
+ with self.subTest('some message'):
+ result = unittest.TextTestResult(None, True, 1)
+ self.assertEqual(
+ result.getDescription(self._subtest),
+ 'testGetSubTestDescriptionWithoutDocstring (' + __name__ +
+ '.Test_TestResult) [some message]')
+
+ def testGetSubTestDescriptionWithoutDocstringAndParams(self):
+ with self.subTest():
+ result = unittest.TextTestResult(None, True, 1)
+ self.assertEqual(
+ result.getDescription(self._subtest),
+ 'testGetSubTestDescriptionWithoutDocstringAndParams '
+ '(' + __name__ + '.Test_TestResult) (<subtest>)')
+
+ def testGetNestedSubTestDescriptionWithoutDocstring(self):
+ with self.subTest(foo=1):
+ with self.subTest(bar=2):
+ result = unittest.TextTestResult(None, True, 1)
+ self.assertEqual(
+ result.getDescription(self._subtest),
+ 'testGetNestedSubTestDescriptionWithoutDocstring '
+ '(' + __name__ + '.Test_TestResult) (bar=2, foo=1)')
+
@unittest.skipIf(sys.flags.optimize >= 2,
"Docstrings are omitted with -O2 and above")
def testGetDescriptionWithOneLineDocstring(self):
@@ -247,6 +312,18 @@ class Test_TestResult(unittest.TestCase):
@unittest.skipIf(sys.flags.optimize >= 2,
"Docstrings are omitted with -O2 and above")
+ def testGetSubTestDescriptionWithOneLineDocstring(self):
+ """Tests getDescription() for a method with a docstring."""
+ result = unittest.TextTestResult(None, True, 1)
+ with self.subTest(foo=1, bar=2):
+ self.assertEqual(
+ result.getDescription(self._subtest),
+ ('testGetSubTestDescriptionWithOneLineDocstring '
+ '(' + __name__ + '.Test_TestResult) (bar=2, foo=1)\n'
+ 'Tests getDescription() for a method with a docstring.'))
+
+ @unittest.skipIf(sys.flags.optimize >= 2,
+ "Docstrings are omitted with -O2 and above")
def testGetDescriptionWithMultiLineDocstring(self):
"""Tests getDescription() for a method with a longer docstring.
The second line of the docstring.
@@ -259,6 +336,21 @@ class Test_TestResult(unittest.TestCase):
'Tests getDescription() for a method with a longer '
'docstring.'))
+ @unittest.skipIf(sys.flags.optimize >= 2,
+ "Docstrings are omitted with -O2 and above")
+ def testGetSubTestDescriptionWithMultiLineDocstring(self):
+ """Tests getDescription() for a method with a longer docstring.
+ The second line of the docstring.
+ """
+ result = unittest.TextTestResult(None, True, 1)
+ with self.subTest(foo=1, bar=2):
+ self.assertEqual(
+ result.getDescription(self._subtest),
+ ('testGetSubTestDescriptionWithMultiLineDocstring '
+ '(' + __name__ + '.Test_TestResult) (bar=2, foo=1)\n'
+ 'Tests getDescription() for a method with a longer '
+ 'docstring.'))
+
def testStackFrameTrimming(self):
class Frame(object):
class tb_frame(object):
diff --git a/Lib/unittest/test/test_runner.py b/Lib/unittest/test/test_runner.py
index e22e6bc279..ef1c1af9f1 100644
--- a/Lib/unittest/test/test_runner.py
+++ b/Lib/unittest/test/test_runner.py
@@ -5,6 +5,7 @@ import pickle
import subprocess
import unittest
+from unittest.case import _Outcome
from .support import LoggingResult, ResultWithNoStartTestRunStopTestRun
@@ -42,12 +43,8 @@ class TestCleanUp(unittest.TestCase):
def testNothing(self):
pass
- class MockOutcome(object):
- success = True
- errors = []
-
test = TestableTest('testNothing')
- test._outcomeForDoCleanups = MockOutcome
+ outcome = test._outcome = _Outcome()
exc1 = Exception('foo')
exc2 = Exception('bar')
@@ -61,9 +58,10 @@ class TestCleanUp(unittest.TestCase):
test.addCleanup(cleanup2)
self.assertFalse(test.doCleanups())
- self.assertFalse(MockOutcome.success)
+ self.assertFalse(outcome.success)
- (Type1, instance1, _), (Type2, instance2, _) = reversed(MockOutcome.errors)
+ ((_, (Type1, instance1, _)),
+ (_, (Type2, instance2, _))) = reversed(outcome.errors)
self.assertEqual((Type1, instance1), (Exception, exc1))
self.assertEqual((Type2, instance2), (Exception, exc2))
@@ -341,3 +339,7 @@ class Test_TextTestRunner(unittest.TestCase):
f = io.StringIO()
runner = unittest.TextTestRunner(f)
self.assertTrue(runner.stream.stream is f)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/Lib/unittest/test/test_setups.py b/Lib/unittest/test/test_setups.py
index b8d5aa41e9..392f95efc0 100644
--- a/Lib/unittest/test/test_setups.py
+++ b/Lib/unittest/test/test_setups.py
@@ -494,14 +494,13 @@ class TestSetups(unittest.TestCase):
Test.__module__ = 'Module'
sys.modules['Module'] = Module
- _suite = unittest.defaultTestLoader.loadTestsFromTestCase(Test)
- suite = unittest.TestSuite()
- suite.addTest(_suite)
-
messages = ('setUpModule', 'tearDownModule', 'setUpClass', 'tearDownClass', 'test_something')
for phase, msg in enumerate(messages):
+ _suite = unittest.defaultTestLoader.loadTestsFromTestCase(Test)
+ suite = unittest.TestSuite([_suite])
with self.assertRaisesRegex(Exception, msg):
suite.debug()
+
if __name__ == '__main__':
unittest.main()
diff --git a/Lib/unittest/test/test_skipping.py b/Lib/unittest/test/test_skipping.py
index 952240eeed..e18caa48ba 100644
--- a/Lib/unittest/test/test_skipping.py
+++ b/Lib/unittest/test/test_skipping.py
@@ -29,6 +29,31 @@ class Test_TestSkipping(unittest.TestCase):
self.assertEqual(result.skipped, [(test, "testing")])
self.assertEqual(result.testsRun, 1)
+ def test_skipping_subtests(self):
+ class Foo(unittest.TestCase):
+ def test_skip_me(self):
+ with self.subTest(a=1):
+ with self.subTest(b=2):
+ self.skipTest("skip 1")
+ self.skipTest("skip 2")
+ self.skipTest("skip 3")
+ events = []
+ result = LoggingResult(events)
+ test = Foo("test_skip_me")
+ test.run(result)
+ self.assertEqual(events, ['startTest', 'addSkip', 'addSkip',
+ 'addSkip', 'stopTest'])
+ self.assertEqual(len(result.skipped), 3)
+ subtest, msg = result.skipped[0]
+ self.assertEqual(msg, "skip 1")
+ self.assertIsInstance(subtest, unittest.TestCase)
+ self.assertIsNot(subtest, test)
+ subtest, msg = result.skipped[1]
+ self.assertEqual(msg, "skip 2")
+ self.assertIsInstance(subtest, unittest.TestCase)
+ self.assertIsNot(subtest, test)
+ self.assertEqual(result.skipped[2], (test, "skip 3"))
+
def test_skipping_decorators(self):
op_table = ((unittest.skipUnless, False, True),
(unittest.skipIf, True, False))
@@ -95,6 +120,31 @@ class Test_TestSkipping(unittest.TestCase):
self.assertEqual(result.expectedFailures[0][0], test)
self.assertTrue(result.wasSuccessful())
+ def test_expected_failure_subtests(self):
+ # A failure in any subtest counts as the expected failure of the
+ # whole test.
+ class Foo(unittest.TestCase):
+ @unittest.expectedFailure
+ def test_die(self):
+ with self.subTest():
+ # This one succeeds
+ pass
+ with self.subTest():
+ self.fail("help me!")
+ with self.subTest():
+ # This one doesn't get executed
+ self.fail("shouldn't come here")
+ events = []
+ result = LoggingResult(events)
+ test = Foo("test_die")
+ test.run(result)
+ self.assertEqual(events,
+ ['startTest', 'addSubTestSuccess',
+ 'addExpectedFailure', 'stopTest'])
+ self.assertEqual(len(result.expectedFailures), 1)
+ self.assertIs(result.expectedFailures[0][0], test)
+ self.assertTrue(result.wasSuccessful())
+
def test_unexpected_success(self):
class Foo(unittest.TestCase):
@unittest.expectedFailure
@@ -110,6 +160,30 @@ class Test_TestSkipping(unittest.TestCase):
self.assertEqual(result.unexpectedSuccesses, [test])
self.assertTrue(result.wasSuccessful())
+ def test_unexpected_success_subtests(self):
+ # Success in all subtests counts as the unexpected success of
+ # the whole test.
+ class Foo(unittest.TestCase):
+ @unittest.expectedFailure
+ def test_die(self):
+ with self.subTest():
+ # This one succeeds
+ pass
+ with self.subTest():
+ # So does this one
+ pass
+ events = []
+ result = LoggingResult(events)
+ test = Foo("test_die")
+ test.run(result)
+ self.assertEqual(events,
+ ['startTest',
+ 'addSubTestSuccess', 'addSubTestSuccess',
+ 'addUnexpectedSuccess', 'stopTest'])
+ self.assertFalse(result.failures)
+ self.assertEqual(result.unexpectedSuccesses, [test])
+ self.assertTrue(result.wasSuccessful())
+
def test_skip_doesnt_run_setup(self):
class Foo(unittest.TestCase):
wasSetUp = False
@@ -147,3 +221,7 @@ class Test_TestSkipping(unittest.TestCase):
suite = unittest.TestSuite([test])
suite.run(result)
self.assertEqual(result.skipped, [(test, "testing")])
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/Lib/unittest/test/test_suite.py b/Lib/unittest/test/test_suite.py
index 2db978ddb8..54cec6ecba 100644
--- a/Lib/unittest/test/test_suite.py
+++ b/Lib/unittest/test/test_suite.py
@@ -1,6 +1,8 @@
import unittest
+import gc
import sys
+import weakref
from .support import LoggingResult, TestEquality
@@ -300,7 +302,54 @@ class Test_TestSuite(unittest.TestCase, TestEquality):
# when the bug is fixed this line will not crash
suite.run(unittest.TestResult())
+ def test_remove_test_at_index(self):
+ if not unittest.BaseTestSuite._cleanup:
+ raise unittest.SkipTest("Suite cleanup is disabled")
+ suite = unittest.TestSuite()
+
+ suite._tests = [1, 2, 3]
+ suite._removeTestAtIndex(1)
+
+ self.assertEqual([1, None, 3], suite._tests)
+
+ def test_remove_test_at_index_not_indexable(self):
+ if not unittest.BaseTestSuite._cleanup:
+ raise unittest.SkipTest("Suite cleanup is disabled")
+
+ suite = unittest.TestSuite()
+ suite._tests = None
+
+ # if _removeAtIndex raises for noniterables this next line will break
+ suite._removeTestAtIndex(2)
+
+ def assert_garbage_collect_test_after_run(self, TestSuiteClass):
+ if not unittest.BaseTestSuite._cleanup:
+ raise unittest.SkipTest("Suite cleanup is disabled")
+
+ class Foo(unittest.TestCase):
+ def test_nothing(self):
+ pass
+
+ test = Foo('test_nothing')
+ wref = weakref.ref(test)
+
+ suite = TestSuiteClass([wref()])
+ suite.run(unittest.TestResult())
+
+ del test
+
+ # for the benefit of non-reference counting implementations
+ gc.collect()
+
+ self.assertEqual(suite._tests, [None])
+ self.assertIsNone(wref())
+
+ def test_garbage_collect_test_after_run_BaseTestSuite(self):
+ self.assert_garbage_collect_test_after_run(unittest.BaseTestSuite)
+
+ def test_garbage_collect_test_after_run_TestSuite(self):
+ self.assert_garbage_collect_test_after_run(unittest.TestSuite)
def test_basetestsuite(self):
class Test(unittest.TestCase):
@@ -363,6 +412,5 @@ class Test_TestSuite(unittest.TestCase, TestEquality):
self.assertFalse(result._testRunEntered)
-
if __name__ == '__main__':
unittest.main()
diff --git a/Lib/unittest/test/testmock/__main__.py b/Lib/unittest/test/testmock/__main__.py
new file mode 100644
index 0000000000..45c633a4ee
--- /dev/null
+++ b/Lib/unittest/test/testmock/__main__.py
@@ -0,0 +1,18 @@
+import os
+import unittest
+
+
+def load_tests(loader, standard_tests, pattern):
+ # top level directory cached on loader instance
+ this_dir = os.path.dirname(__file__)
+ pattern = pattern or "test*.py"
+ # We are inside unittest.test.testmock, so the top-level is three notches up
+ top_level_dir = os.path.dirname(os.path.dirname(os.path.dirname(this_dir)))
+ package_tests = loader.discover(start_dir=this_dir, pattern=pattern,
+ top_level_dir=top_level_dir)
+ standard_tests.addTests(package_tests)
+ return standard_tests
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/unittest/test/testmock/testcallable.py b/Lib/unittest/test/testmock/testcallable.py
index 7b2dd003ea..5390a4e10f 100644
--- a/Lib/unittest/test/testmock/testcallable.py
+++ b/Lib/unittest/test/testmock/testcallable.py
@@ -145,3 +145,7 @@ class TestCallable(unittest.TestCase):
mock.wibble.assert_called_once_with()
self.assertRaises(TypeError, mock.wibble, 'some', 'args')
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/Lib/unittest/test/testmock/testhelpers.py b/Lib/unittest/test/testmock/testhelpers.py
index a362a2f784..1dbc0b64ba 100644
--- a/Lib/unittest/test/testmock/testhelpers.py
+++ b/Lib/unittest/test/testmock/testhelpers.py
@@ -337,9 +337,10 @@ class SpecSignatureTest(unittest.TestCase):
def test_basic(self):
- for spec in (SomeClass, SomeClass()):
- mock = create_autospec(spec)
- self._check_someclass_mock(mock)
+ mock = create_autospec(SomeClass)
+ self._check_someclass_mock(mock)
+ mock = create_autospec(SomeClass())
+ self._check_someclass_mock(mock)
def test_create_autospec_return_value(self):
@@ -576,10 +577,10 @@ class SpecSignatureTest(unittest.TestCase):
def test_spec_inheritance_for_classes(self):
class Foo(object):
- def a(self):
+ def a(self, x):
pass
class Bar(object):
- def f(self):
+ def f(self, y):
pass
class_mock = create_autospec(Foo)
@@ -587,26 +588,30 @@ class SpecSignatureTest(unittest.TestCase):
self.assertIsNot(class_mock, class_mock())
for this_mock in class_mock, class_mock():
- this_mock.a()
- this_mock.a.assert_called_with()
- self.assertRaises(TypeError, this_mock.a, 'foo')
+ this_mock.a(x=5)
+ this_mock.a.assert_called_with(x=5)
+ this_mock.a.assert_called_with(5)
+ self.assertRaises(TypeError, this_mock.a, 'foo', 'bar')
self.assertRaises(AttributeError, getattr, this_mock, 'b')
instance_mock = create_autospec(Foo())
- instance_mock.a()
- instance_mock.a.assert_called_with()
- self.assertRaises(TypeError, instance_mock.a, 'foo')
+ instance_mock.a(5)
+ instance_mock.a.assert_called_with(5)
+ instance_mock.a.assert_called_with(x=5)
+ self.assertRaises(TypeError, instance_mock.a, 'foo', 'bar')
self.assertRaises(AttributeError, getattr, instance_mock, 'b')
# The return value isn't isn't callable
self.assertRaises(TypeError, instance_mock)
- instance_mock.Bar.f()
- instance_mock.Bar.f.assert_called_with()
+ instance_mock.Bar.f(6)
+ instance_mock.Bar.f.assert_called_with(6)
+ instance_mock.Bar.f.assert_called_with(y=6)
self.assertRaises(AttributeError, getattr, instance_mock.Bar, 'g')
- instance_mock.Bar().f()
- instance_mock.Bar().f.assert_called_with()
+ instance_mock.Bar().f(6)
+ instance_mock.Bar().f.assert_called_with(6)
+ instance_mock.Bar().f.assert_called_with(y=6)
self.assertRaises(AttributeError, getattr, instance_mock.Bar(), 'g')
@@ -663,12 +668,15 @@ class SpecSignatureTest(unittest.TestCase):
self.assertRaises(TypeError, mock)
mock(1, 2)
mock.assert_called_with(1, 2)
+ mock.assert_called_with(1, b=2)
+ mock.assert_called_with(a=1, b=2)
f.f = f
mock = create_autospec(f)
self.assertRaises(TypeError, mock.f)
mock.f(3, 4)
mock.f.assert_called_with(3, 4)
+ mock.f.assert_called_with(a=3, b=4)
def test_skip_attributeerrors(self):
@@ -704,9 +712,13 @@ class SpecSignatureTest(unittest.TestCase):
self.assertRaises(TypeError, mock)
mock(1)
mock.assert_called_once_with(1)
+ mock.assert_called_once_with(a=1)
+ self.assertRaises(AssertionError, mock.assert_called_once_with, 2)
mock(4, 5)
mock.assert_called_with(4, 5)
+ mock.assert_called_with(a=4, b=5)
+ self.assertRaises(AssertionError, mock.assert_called_with, a=5, b=4)
def test_class_with_no_init(self):
@@ -719,24 +731,27 @@ class SpecSignatureTest(unittest.TestCase):
def test_signature_callable(self):
class Callable(object):
- def __init__(self):
+ def __init__(self, x, y):
pass
def __call__(self, a):
pass
mock = create_autospec(Callable)
- mock()
- mock.assert_called_once_with()
+ mock(1, 2)
+ mock.assert_called_once_with(1, 2)
+ mock.assert_called_once_with(x=1, y=2)
self.assertRaises(TypeError, mock, 'a')
- instance = mock()
+ instance = mock(1, 2)
self.assertRaises(TypeError, instance)
instance(a='a')
+ instance.assert_called_once_with('a')
instance.assert_called_once_with(a='a')
instance('a')
instance.assert_called_with('a')
+ instance.assert_called_with(a='a')
- mock = create_autospec(Callable())
+ mock = create_autospec(Callable(1, 2))
mock(a='a')
mock.assert_called_once_with(a='a')
self.assertRaises(TypeError, mock)
@@ -779,7 +794,11 @@ class SpecSignatureTest(unittest.TestCase):
pass
a = create_autospec(Foo)
+ a.f(10)
+ a.f.assert_called_with(10)
+ a.f.assert_called_with(self=10)
a.f(self=10)
+ a.f.assert_called_with(10)
a.f.assert_called_with(self=10)
diff --git a/Lib/unittest/test/testmock/testmock.py b/Lib/unittest/test/testmock/testmock.py
index cef5405fe9..20cc6541e6 100644
--- a/Lib/unittest/test/testmock/testmock.py
+++ b/Lib/unittest/test/testmock/testmock.py
@@ -25,6 +25,18 @@ class Iter(object):
__next__ = next
+class Something(object):
+ def meth(self, a, b, c, d=None):
+ pass
+
+ @classmethod
+ def cmeth(cls, a, b, c, d=None):
+ pass
+
+ @staticmethod
+ def smeth(a, b, c, d=None):
+ pass
+
class MockTest(unittest.TestCase):
@@ -273,6 +285,43 @@ class MockTest(unittest.TestCase):
mock.assert_called_with(1, 2, 3, a='fish', b='nothing')
+ def test_assert_called_with_function_spec(self):
+ def f(a, b, c, d=None):
+ pass
+
+ mock = Mock(spec=f)
+
+ mock(1, b=2, c=3)
+ mock.assert_called_with(1, 2, 3)
+ mock.assert_called_with(a=1, b=2, c=3)
+ self.assertRaises(AssertionError, mock.assert_called_with,
+ 1, b=3, c=2)
+ # Expected call doesn't match the spec's signature
+ with self.assertRaises(AssertionError) as cm:
+ mock.assert_called_with(e=8)
+ self.assertIsInstance(cm.exception.__cause__, TypeError)
+
+
+ def test_assert_called_with_method_spec(self):
+ def _check(mock):
+ mock(1, b=2, c=3)
+ mock.assert_called_with(1, 2, 3)
+ mock.assert_called_with(a=1, b=2, c=3)
+ self.assertRaises(AssertionError, mock.assert_called_with,
+ 1, b=3, c=2)
+
+ mock = Mock(spec=Something().meth)
+ _check(mock)
+ mock = Mock(spec=Something.cmeth)
+ _check(mock)
+ mock = Mock(spec=Something().cmeth)
+ _check(mock)
+ mock = Mock(spec=Something.smeth)
+ _check(mock)
+ mock = Mock(spec=Something().smeth)
+ _check(mock)
+
+
def test_assert_called_once_with(self):
mock = Mock()
mock()
@@ -297,6 +346,29 @@ class MockTest(unittest.TestCase):
)
+ def test_assert_called_once_with_function_spec(self):
+ def f(a, b, c, d=None):
+ pass
+
+ mock = Mock(spec=f)
+
+ mock(1, b=2, c=3)
+ mock.assert_called_once_with(1, 2, 3)
+ mock.assert_called_once_with(a=1, b=2, c=3)
+ self.assertRaises(AssertionError, mock.assert_called_once_with,
+ 1, b=3, c=2)
+ # Expected call doesn't match the spec's signature
+ with self.assertRaises(AssertionError) as cm:
+ mock.assert_called_once_with(e=8)
+ self.assertIsInstance(cm.exception.__cause__, TypeError)
+ # Mock called more than once => always fails
+ mock(4, 5, 6)
+ self.assertRaises(AssertionError, mock.assert_called_once_with,
+ 1, 2, 3)
+ self.assertRaises(AssertionError, mock.assert_called_once_with,
+ 4, 5, 6)
+
+
def test_attribute_access_returns_mocks(self):
mock = Mock()
something = mock.something
@@ -995,6 +1067,39 @@ class MockTest(unittest.TestCase):
)
+ def test_assert_has_calls_with_function_spec(self):
+ def f(a, b, c, d=None):
+ pass
+
+ mock = Mock(spec=f)
+
+ mock(1, b=2, c=3)
+ mock(4, 5, c=6, d=7)
+ mock(10, 11, c=12)
+ calls = [
+ ('', (1, 2, 3), {}),
+ ('', (4, 5, 6), {'d': 7}),
+ ((10, 11, 12), {}),
+ ]
+ mock.assert_has_calls(calls)
+ mock.assert_has_calls(calls, any_order=True)
+ mock.assert_has_calls(calls[1:])
+ mock.assert_has_calls(calls[1:], any_order=True)
+ mock.assert_has_calls(calls[:-1])
+ mock.assert_has_calls(calls[:-1], any_order=True)
+ # Reversed order
+ calls = list(reversed(calls))
+ with self.assertRaises(AssertionError):
+ mock.assert_has_calls(calls)
+ mock.assert_has_calls(calls, any_order=True)
+ with self.assertRaises(AssertionError):
+ mock.assert_has_calls(calls[1:])
+ mock.assert_has_calls(calls[1:], any_order=True)
+ with self.assertRaises(AssertionError):
+ mock.assert_has_calls(calls[:-1])
+ mock.assert_has_calls(calls[:-1], any_order=True)
+
+
def test_assert_any_call(self):
mock = Mock()
mock(1, 2)
@@ -1021,6 +1126,26 @@ class MockTest(unittest.TestCase):
)
+ def test_assert_any_call_with_function_spec(self):
+ def f(a, b, c, d=None):
+ pass
+
+ mock = Mock(spec=f)
+
+ mock(1, b=2, c=3)
+ mock(4, 5, c=6, d=7)
+ mock.assert_any_call(1, 2, 3)
+ mock.assert_any_call(a=1, b=2, c=3)
+ mock.assert_any_call(4, 5, 6, 7)
+ mock.assert_any_call(a=4, b=5, c=6, d=7)
+ self.assertRaises(AssertionError, mock.assert_any_call,
+ 1, b=3, c=2)
+ # Expected call doesn't match the spec's signature
+ with self.assertRaises(AssertionError) as cm:
+ mock.assert_any_call(e=8)
+ self.assertIsInstance(cm.exception.__cause__, TypeError)
+
+
def test_mock_calls_create_autospec(self):
def f(a, b):
pass
@@ -1177,20 +1302,6 @@ class MockTest(unittest.TestCase):
self.assertEqual(m.method_calls, [])
- def test_attribute_deletion(self):
- # this behaviour isn't *useful*, but at least it's now tested...
- for Klass in Mock, MagicMock, NonCallableMagicMock, NonCallableMock:
- m = Klass()
- original = m.foo
- m.foo = 3
- del m.foo
- self.assertEqual(m.foo, original)
-
- new = m.foo = Mock()
- del m.foo
- self.assertEqual(m.foo, new)
-
-
def test_mock_parents(self):
for Klass in Mock, MagicMock:
m = Klass()
@@ -1254,7 +1365,8 @@ class MockTest(unittest.TestCase):
def test_attribute_deletion(self):
- for mock in Mock(), MagicMock():
+ for mock in (Mock(), MagicMock(), NonCallableMagicMock(),
+ NonCallableMock()):
self.assertTrue(hasattr(mock, 'm'))
del mock.m
@@ -1274,6 +1386,5 @@ class MockTest(unittest.TestCase):
mock.foo
-
if __name__ == '__main__':
unittest.main()
diff --git a/Lib/unittest/test/testmock/testpatch.py b/Lib/unittest/test/testmock/testpatch.py
index c1091b4e9b..c1bc34fa8c 100644
--- a/Lib/unittest/test/testmock/testpatch.py
+++ b/Lib/unittest/test/testmock/testpatch.py
@@ -1780,6 +1780,5 @@ class PatchTest(unittest.TestCase):
self.assertIs(os.path, path)
-
if __name__ == '__main__':
unittest.main()
diff --git a/Lib/unittest/test/testmock/testwith.py b/Lib/unittest/test/testmock/testwith.py
index 0a0cfad120..f54e051e94 100644
--- a/Lib/unittest/test/testmock/testwith.py
+++ b/Lib/unittest/test/testmock/testwith.py
@@ -172,5 +172,88 @@ class TestMockOpen(unittest.TestCase):
self.assertEqual(result, 'foo')
+ def test_readline_data(self):
+ # Check that readline will return all the lines from the fake file
+ mock = mock_open(read_data='foo\nbar\nbaz\n')
+ with patch('%s.open' % __name__, mock, create=True):
+ h = open('bar')
+ line1 = h.readline()
+ line2 = h.readline()
+ line3 = h.readline()
+ self.assertEqual(line1, 'foo\n')
+ self.assertEqual(line2, 'bar\n')
+ self.assertEqual(line3, 'baz\n')
+
+ # Check that we properly emulate a file that doesn't end in a newline
+ mock = mock_open(read_data='foo')
+ with patch('%s.open' % __name__, mock, create=True):
+ h = open('bar')
+ result = h.readline()
+ self.assertEqual(result, 'foo')
+
+
+ def test_readlines_data(self):
+ # Test that emulating a file that ends in a newline character works
+ mock = mock_open(read_data='foo\nbar\nbaz\n')
+ with patch('%s.open' % __name__, mock, create=True):
+ h = open('bar')
+ result = h.readlines()
+ self.assertEqual(result, ['foo\n', 'bar\n', 'baz\n'])
+
+ # Test that files without a final newline will also be correctly
+ # emulated
+ mock = mock_open(read_data='foo\nbar\nbaz')
+ with patch('%s.open' % __name__, mock, create=True):
+ h = open('bar')
+ result = h.readlines()
+
+ self.assertEqual(result, ['foo\n', 'bar\n', 'baz'])
+
+
+ def test_mock_open_read_with_argument(self):
+ # At one point calling read with an argument was broken
+ # for mocks returned by mock_open
+ some_data = 'foo\nbar\nbaz'
+ mock = mock_open(read_data=some_data)
+ self.assertEqual(mock().read(10), some_data)
+
+
+ def test_interleaved_reads(self):
+ # Test that calling read, readline, and readlines pulls data
+ # sequentially from the data we preload with
+ mock = mock_open(read_data='foo\nbar\nbaz\n')
+ with patch('%s.open' % __name__, mock, create=True):
+ h = open('bar')
+ line1 = h.readline()
+ rest = h.readlines()
+ self.assertEqual(line1, 'foo\n')
+ self.assertEqual(rest, ['bar\n', 'baz\n'])
+
+ mock = mock_open(read_data='foo\nbar\nbaz\n')
+ with patch('%s.open' % __name__, mock, create=True):
+ h = open('bar')
+ line1 = h.readline()
+ rest = h.read()
+ self.assertEqual(line1, 'foo\n')
+ self.assertEqual(rest, 'bar\nbaz\n')
+
+
+ def test_overriding_return_values(self):
+ mock = mock_open(read_data='foo')
+ handle = mock()
+
+ handle.read.return_value = 'bar'
+ handle.readline.return_value = 'bar'
+ handle.readlines.return_value = ['bar']
+
+ self.assertEqual(handle.read(), 'bar')
+ self.assertEqual(handle.readline(), 'bar')
+ self.assertEqual(handle.readlines(), ['bar'])
+
+ # call repeatedly to check that a StopIteration is not propagated
+ self.assertEqual(handle.readline(), 'bar')
+ self.assertEqual(handle.readline(), 'bar')
+
+
if __name__ == '__main__':
unittest.main()
diff --git a/Lib/unittest/util.py b/Lib/unittest/util.py
index ccdf0b81fa..aee498fd0b 100644
--- a/Lib/unittest/util.py
+++ b/Lib/unittest/util.py
@@ -1,10 +1,47 @@
"""Various utility functions."""
from collections import namedtuple, OrderedDict
+from os.path import commonprefix
__unittest = True
_MAX_LENGTH = 80
+_PLACEHOLDER_LEN = 12
+_MIN_BEGIN_LEN = 5
+_MIN_END_LEN = 5
+_MIN_COMMON_LEN = 5
+_MIN_DIFF_LEN = _MAX_LENGTH - \
+ (_MIN_BEGIN_LEN + _PLACEHOLDER_LEN + _MIN_COMMON_LEN +
+ _PLACEHOLDER_LEN + _MIN_END_LEN)
+assert _MIN_DIFF_LEN >= 0
+
+def _shorten(s, prefixlen, suffixlen):
+ skip = len(s) - prefixlen - suffixlen
+ if skip > _PLACEHOLDER_LEN:
+ s = '%s[%d chars]%s' % (s[:prefixlen], skip, s[len(s) - suffixlen:])
+ return s
+
+def _common_shorten_repr(*args):
+ args = tuple(map(safe_repr, args))
+ maxlen = max(map(len, args))
+ if maxlen <= _MAX_LENGTH:
+ return args
+
+ prefix = commonprefix(args)
+ prefixlen = len(prefix)
+
+ common_len = _MAX_LENGTH - \
+ (maxlen - prefixlen + _MIN_BEGIN_LEN + _PLACEHOLDER_LEN)
+ if common_len > _MIN_COMMON_LEN:
+ assert _MIN_BEGIN_LEN + _PLACEHOLDER_LEN + _MIN_COMMON_LEN + \
+ (maxlen - prefixlen) < _MAX_LENGTH
+ prefix = _shorten(prefix, _MIN_BEGIN_LEN, common_len)
+ return tuple(prefix + s[prefixlen:] for s in args)
+
+ prefix = _shorten(prefix, _MIN_BEGIN_LEN, _MIN_COMMON_LEN)
+ return tuple(prefix + _shorten(s[prefixlen:], _MIN_DIFF_LEN, _MIN_END_LEN)
+ for s in args)
+
def safe_repr(obj, short=False):
try:
result = repr(obj)