diff options
author | Serhiy Storchaka <storchaka@gmail.com> | 2014-02-26 21:03:19 +0200 |
---|---|---|
committer | Serhiy Storchaka <storchaka@gmail.com> | 2014-02-26 21:03:19 +0200 |
commit | 3e7e44e7b9948036633445e39d82d522a4d337ac (patch) | |
tree | 00effc898cea8ad8e64e3390c08b4eea8baa5e83 /Lib/test | |
parent | a4ddfd958a170badda1f7236d5bf9f24ef5728aa (diff) | |
parent | b5a1fdc60bc80f2ce10230a1953ded4a189433f7 (diff) | |
download | cpython-3e7e44e7b9948036633445e39d82d522a4d337ac.tar.gz |
Added tests for issue #20501.
Diffstat (limited to 'Lib/test')
327 files changed, 41932 insertions, 6621 deletions
diff --git a/Lib/test/__main__.py b/Lib/test/__main__.py index ce5615b889..d5fbe159d7 100644 --- a/Lib/test/__main__.py +++ b/Lib/test/__main__.py @@ -1,13 +1,3 @@ -from test import regrtest, support +from test import regrtest - -TEMPDIR, TESTCWD = regrtest._make_temp_dir_for_build(regrtest.TEMPDIR) -regrtest.TEMPDIR = TEMPDIR -regrtest.TESTCWD = TESTCWD - -# Run the tests in a context manager that temporary changes the CWD to a -# temporary and writable directory. If it's not possible to create or -# change the CWD, the original CWD will be used. The original CWD is -# available from support.SAVEDCWD. -with support.temp_cwd(TESTCWD, quiet=True): - regrtest.main() +regrtest.main_in_temp_cwd() diff --git a/Lib/test/test_multiprocessing.py b/Lib/test/_test_multiprocessing.py index aa66db4b7e..8eb57fe87e 100644 --- a/Lib/test/test_multiprocessing.py +++ b/Lib/test/_test_multiprocessing.py @@ -41,7 +41,7 @@ from multiprocessing import util try: from multiprocessing import reduction - HAS_REDUCTION = True + HAS_REDUCTION = reduction.HAVE_SEND_HANDLE except ImportError: HAS_REDUCTION = False @@ -97,6 +97,9 @@ try: except: MAXFD = 256 +# To speed up tests when using the forkserver, we can preload these: +PRELOAD = ['__main__', 'test.test_multiprocessing_forkserver'] + # # Some tests require ctypes # @@ -292,17 +295,22 @@ class _TestProcess(BaseTestCase): self.assertTimingAlmostEqual(join.elapsed, 0.0) self.assertEqual(p.is_alive(), True) + # XXX maybe terminating too soon causes the problems on Gentoo... + time.sleep(1) + p.terminate() if hasattr(signal, 'alarm'): + # On the Gentoo buildbot waitpid() often seems to block forever. + # We use alarm() to interrupt it if it blocks for too long. def handler(*args): raise RuntimeError('join took too long: %s' % p) old_handler = signal.signal(signal.SIGALRM, handler) try: signal.alarm(10) self.assertEqual(join(), None) - signal.alarm(0) finally: + signal.alarm(0) signal.signal(signal.SIGALRM, old_handler) else: self.assertEqual(join(), None) @@ -340,7 +348,6 @@ class _TestProcess(BaseTestCase): @classmethod def _test_recursion(cls, wconn, id): - from multiprocessing import forking wconn.send(id) if len(id) < 2: for i in range(2): @@ -388,7 +395,7 @@ class _TestProcess(BaseTestCase): self.assertFalse(wait_for_handle(sentinel, timeout=0.0)) event.set() p.join() - self.assertTrue(wait_for_handle(sentinel, timeout=DELTA)) + self.assertTrue(wait_for_handle(sentinel, timeout=1)) # # @@ -688,9 +695,6 @@ class _TestQueue(BaseTestCase): def test_task_done(self): queue = self.JoinableQueue() - if sys.version_info < (2, 5) and not hasattr(queue, 'task_done'): - self.skipTest("requires 'queue.task_done()' method") - workers = [self.Process(target=self._test_task_done, args=(queue,)) for i in range(4)] @@ -1777,6 +1781,35 @@ class _TestPool(BaseTestCase): self.assertEqual(r.get(), expected) self.assertRaises(ValueError, p.map_async, sqr, L) + @classmethod + def _test_traceback(cls): + raise RuntimeError(123) # some comment + + def test_traceback(self): + # We want ensure that the traceback from the child process is + # contained in the traceback raised in the main process. + if self.TYPE == 'processes': + with self.Pool(1) as p: + try: + p.apply(self._test_traceback) + except Exception as e: + exc = e + else: + raise AssertionError('expected RuntimeError') + self.assertIs(type(exc), RuntimeError) + self.assertEqual(exc.args, (123,)) + cause = exc.__cause__ + self.assertIs(type(cause), multiprocessing.pool.RemoteTraceback) + self.assertIn('raise RuntimeError(123) # some comment', cause.tb) + + with test.support.captured_stderr() as f1: + try: + raise exc + except RuntimeError: + sys.excepthook(*sys.exc_info()) + self.assertIn('raise RuntimeError(123) # some comment', + f1.getvalue()) + def raising(): raise KeyError("key") @@ -2484,7 +2517,7 @@ class _TestPicklingConnections(BaseTestCase): @classmethod def tearDownClass(cls): - from multiprocessing.reduction import resource_sharer + from multiprocessing import resource_sharer resource_sharer.stop(timeout=5) @classmethod @@ -2798,30 +2831,40 @@ class _TestFinalize(BaseTestCase): # Test that from ... import * works for each module # -class _TestImportStar(BaseTestCase): +class _TestImportStar(unittest.TestCase): - ALLOWED_TYPES = ('processes',) + def get_module_names(self): + import glob + folder = os.path.dirname(multiprocessing.__file__) + pattern = os.path.join(folder, '*.py') + files = glob.glob(pattern) + modules = [os.path.splitext(os.path.split(f)[1])[0] for f in files] + modules = ['multiprocessing.' + m for m in modules] + modules.remove('multiprocessing.__init__') + modules.append('multiprocessing') + return modules def test_import(self): - modules = [ - 'multiprocessing', 'multiprocessing.connection', - 'multiprocessing.heap', 'multiprocessing.managers', - 'multiprocessing.pool', 'multiprocessing.process', - 'multiprocessing.synchronize', 'multiprocessing.util' - ] - - if HAS_REDUCTION: - modules.append('multiprocessing.reduction') + modules = self.get_module_names() + if sys.platform == 'win32': + modules.remove('multiprocessing.popen_fork') + modules.remove('multiprocessing.popen_forkserver') + modules.remove('multiprocessing.popen_spawn_posix') + else: + modules.remove('multiprocessing.popen_spawn_win32') + if not HAS_REDUCTION: + modules.remove('multiprocessing.popen_forkserver') - if c_int is not None: + if c_int is None: # This module requires _ctypes - modules.append('multiprocessing.sharedctypes') + modules.remove('multiprocessing.sharedctypes') for name in modules: __import__(name) mod = sys.modules[name] + self.assertTrue(hasattr(mod, '__all__'), name) - for attr in getattr(mod, '__all__', ()): + for attr in mod.__all__: self.assertTrue( hasattr(mod, attr), '%r does not have attribute %r' % (mod, attr) @@ -2904,7 +2947,7 @@ class _TestPollEintr(BaseTestCase): @classmethod def _killer(cls, pid): - time.sleep(0.5) + time.sleep(0.1) os.kill(pid, signal.SIGUSR1) @unittest.skipUnless(hasattr(signal, 'SIGUSR1'), 'requires SIGUSR1') @@ -2917,12 +2960,14 @@ class _TestPollEintr(BaseTestCase): try: killer = self.Process(target=self._killer, args=(pid,)) killer.start() - p = self.Process(target=time.sleep, args=(1,)) - p.start() - p.join() + try: + p = self.Process(target=time.sleep, args=(2,)) + p.start() + p.join() + finally: + killer.join() self.assertTrue(got_signal[0]) self.assertEqual(p.exitcode, 0) - killer.join() finally: signal.signal(signal.SIGUSR1, oldhandler) @@ -2935,8 +2980,11 @@ class TestInvalidHandle(unittest.TestCase): @unittest.skipIf(WIN32, "skipped on Windows") def test_invalid_handles(self): conn = multiprocessing.connection.Connection(44977608) + # check that poll() doesn't crash try: - self.assertRaises((ValueError, OSError), conn.poll) + conn.poll() + except (ValueError, OSError): + pass finally: # Hack private attribute _handle to avoid printing an error # in conn.__del__ @@ -2944,131 +2992,6 @@ class TestInvalidHandle(unittest.TestCase): self.assertRaises((ValueError, OSError), multiprocessing.connection.Connection, -1) -# -# Functions used to create test cases from the base ones in this module -# - -def create_test_cases(Mixin, type): - result = {} - glob = globals() - Type = type.capitalize() - ALL_TYPES = {'processes', 'threads', 'manager'} - - for name in list(glob.keys()): - if name.startswith('_Test'): - base = glob[name] - assert set(base.ALLOWED_TYPES) <= ALL_TYPES, set(base.ALLOWED_TYPES) - if type in base.ALLOWED_TYPES: - newname = 'With' + Type + name[1:] - class Temp(base, Mixin, unittest.TestCase): - pass - result[newname] = Temp - Temp.__name__ = Temp.__qualname__ = newname - Temp.__module__ = Mixin.__module__ - return result - -# -# Create test cases -# - -class ProcessesMixin(object): - TYPE = 'processes' - Process = multiprocessing.Process - connection = multiprocessing.connection - current_process = staticmethod(multiprocessing.current_process) - active_children = staticmethod(multiprocessing.active_children) - Pool = staticmethod(multiprocessing.Pool) - Pipe = staticmethod(multiprocessing.Pipe) - Queue = staticmethod(multiprocessing.Queue) - JoinableQueue = staticmethod(multiprocessing.JoinableQueue) - Lock = staticmethod(multiprocessing.Lock) - RLock = staticmethod(multiprocessing.RLock) - Semaphore = staticmethod(multiprocessing.Semaphore) - BoundedSemaphore = staticmethod(multiprocessing.BoundedSemaphore) - Condition = staticmethod(multiprocessing.Condition) - Event = staticmethod(multiprocessing.Event) - Barrier = staticmethod(multiprocessing.Barrier) - Value = staticmethod(multiprocessing.Value) - Array = staticmethod(multiprocessing.Array) - RawValue = staticmethod(multiprocessing.RawValue) - RawArray = staticmethod(multiprocessing.RawArray) - -testcases_processes = create_test_cases(ProcessesMixin, type='processes') -globals().update(testcases_processes) - - -class ManagerMixin(object): - TYPE = 'manager' - Process = multiprocessing.Process - Queue = property(operator.attrgetter('manager.Queue')) - JoinableQueue = property(operator.attrgetter('manager.JoinableQueue')) - Lock = property(operator.attrgetter('manager.Lock')) - RLock = property(operator.attrgetter('manager.RLock')) - Semaphore = property(operator.attrgetter('manager.Semaphore')) - BoundedSemaphore = property(operator.attrgetter('manager.BoundedSemaphore')) - Condition = property(operator.attrgetter('manager.Condition')) - Event = property(operator.attrgetter('manager.Event')) - Barrier = property(operator.attrgetter('manager.Barrier')) - Value = property(operator.attrgetter('manager.Value')) - Array = property(operator.attrgetter('manager.Array')) - list = property(operator.attrgetter('manager.list')) - dict = property(operator.attrgetter('manager.dict')) - Namespace = property(operator.attrgetter('manager.Namespace')) - - @classmethod - def Pool(cls, *args, **kwds): - return cls.manager.Pool(*args, **kwds) - - @classmethod - def setUpClass(cls): - cls.manager = multiprocessing.Manager() - - @classmethod - def tearDownClass(cls): - # only the manager process should be returned by active_children() - # but this can take a bit on slow machines, so wait a few seconds - # if there are other children too (see #17395) - t = 0.01 - while len(multiprocessing.active_children()) > 1 and t < 5: - time.sleep(t) - t *= 2 - gc.collect() # do garbage collection - if cls.manager._number_of_objects() != 0: - # This is not really an error since some tests do not - # ensure that all processes which hold a reference to a - # managed object have been joined. - print('Shared objects which still exist at manager shutdown:') - print(cls.manager._debug_info()) - cls.manager.shutdown() - cls.manager.join() - cls.manager = None - -testcases_manager = create_test_cases(ManagerMixin, type='manager') -globals().update(testcases_manager) - - -class ThreadsMixin(object): - TYPE = 'threads' - Process = multiprocessing.dummy.Process - connection = multiprocessing.dummy.connection - current_process = staticmethod(multiprocessing.dummy.current_process) - active_children = staticmethod(multiprocessing.dummy.active_children) - Pool = staticmethod(multiprocessing.Pool) - Pipe = staticmethod(multiprocessing.dummy.Pipe) - Queue = staticmethod(multiprocessing.dummy.Queue) - JoinableQueue = staticmethod(multiprocessing.dummy.JoinableQueue) - Lock = staticmethod(multiprocessing.dummy.Lock) - RLock = staticmethod(multiprocessing.dummy.RLock) - Semaphore = staticmethod(multiprocessing.dummy.Semaphore) - BoundedSemaphore = staticmethod(multiprocessing.dummy.BoundedSemaphore) - Condition = staticmethod(multiprocessing.dummy.Condition) - Event = staticmethod(multiprocessing.dummy.Event) - Barrier = staticmethod(multiprocessing.dummy.Barrier) - Value = staticmethod(multiprocessing.dummy.Value) - Array = staticmethod(multiprocessing.dummy.Array) - -testcases_threads = create_test_cases(ThreadsMixin, type='threads') -globals().update(testcases_threads) class OtherTest(unittest.TestCase): @@ -3418,7 +3341,7 @@ class TestFlags(unittest.TestCase): def test_flags(self): import json, subprocess # start child process using unusual flags - prog = ('from test.test_multiprocessing import TestFlags; ' + + prog = ('from test._test_multiprocessing import TestFlags; ' + 'TestFlags.run_in_child()') data = subprocess.check_output( [sys.executable, '-E', '-S', '-O', '-c', prog]) @@ -3465,13 +3388,14 @@ class TestTimeouts(unittest.TestCase): class TestNoForkBomb(unittest.TestCase): def test_noforkbomb(self): + sm = multiprocessing.get_start_method() name = os.path.join(os.path.dirname(__file__), 'mp_fork_bomb.py') - if WIN32: - rc, out, err = test.script_helper.assert_python_failure(name) + if sm != 'fork': + rc, out, err = test.script_helper.assert_python_failure(name, sm) self.assertEqual('', out.decode('ascii')) self.assertIn('RuntimeError', err.decode('ascii')) else: - rc, out, err = test.script_helper.assert_python_ok(name) + rc, out, err = test.script_helper.assert_python_ok(name, sm) self.assertEqual('123', out.decode('ascii').rstrip()) self.assertEqual('', err.decode('ascii')) @@ -3489,7 +3413,8 @@ class TestForkAwareThreadLock(unittest.TestCase): if n > 1: p = multiprocessing.Process(target=cls.child, args=(n-1, conn)) p.start() - p.join() + conn.close() + p.join(timeout=5) else: conn.send(len(util._afterfork_registry)) conn.close() @@ -3500,11 +3425,78 @@ class TestForkAwareThreadLock(unittest.TestCase): old_size = len(util._afterfork_registry) p = multiprocessing.Process(target=self.child, args=(5, w)) p.start() + w.close() new_size = r.recv() - p.join() + p.join(timeout=5) self.assertLessEqual(new_size, old_size) # +# Check that non-forked child processes do not inherit unneeded fds/handles +# + +class TestCloseFds(unittest.TestCase): + + def get_high_socket_fd(self): + if WIN32: + # The child process will not have any socket handles, so + # calling socket.fromfd() should produce WSAENOTSOCK even + # if there is a handle of the same number. + return socket.socket().detach() + else: + # We want to produce a socket with an fd high enough that a + # freshly created child process will not have any fds as high. + fd = socket.socket().detach() + to_close = [] + while fd < 50: + to_close.append(fd) + fd = os.dup(fd) + for x in to_close: + os.close(x) + return fd + + def close(self, fd): + if WIN32: + socket.socket(fileno=fd).close() + else: + os.close(fd) + + @classmethod + def _test_closefds(cls, conn, fd): + try: + s = socket.fromfd(fd, socket.AF_INET, socket.SOCK_STREAM) + except Exception as e: + conn.send(e) + else: + s.close() + conn.send(None) + + def test_closefd(self): + if not HAS_REDUCTION: + raise unittest.SkipTest('requires fd pickling') + + reader, writer = multiprocessing.Pipe() + fd = self.get_high_socket_fd() + try: + p = multiprocessing.Process(target=self._test_closefds, + args=(writer, fd)) + p.start() + writer.close() + e = reader.recv() + p.join(timeout=5) + finally: + self.close(fd) + writer.close() + reader.close() + + if multiprocessing.get_start_method() == 'fork': + self.assertIs(e, None) + else: + WSAENOTSOCK = 10038 + self.assertIsInstance(e, OSError) + self.assertTrue(e.errno == errno.EBADF or + e.winerror == WSAENOTSOCK, e) + +# # Issue #17097: EINTR should be ignored by recv(), send(), accept() etc # @@ -3548,10 +3540,10 @@ class TestIgnoreEINTR(unittest.TestCase): def handler(signum, frame): pass signal.signal(signal.SIGUSR1, handler) - l = multiprocessing.connection.Listener() - conn.send(l.address) - a = l.accept() - a.send('welcome') + with multiprocessing.connection.Listener() as l: + conn.send(l.address) + a = l.accept() + a.send('welcome') @unittest.skipUnless(hasattr(signal, 'SIGUSR1'), 'requires SIGUSR1') def test_ignore_listener(self): @@ -3572,26 +3564,267 @@ class TestIgnoreEINTR(unittest.TestCase): finally: conn.close() +class TestStartMethod(unittest.TestCase): + @classmethod + def _check_context(cls, conn): + conn.send(multiprocessing.get_start_method()) + + def check_context(self, ctx): + r, w = ctx.Pipe(duplex=False) + p = ctx.Process(target=self._check_context, args=(w,)) + p.start() + w.close() + child_method = r.recv() + r.close() + p.join() + self.assertEqual(child_method, ctx.get_start_method()) + + def test_context(self): + for method in ('fork', 'spawn', 'forkserver'): + try: + ctx = multiprocessing.get_context(method) + except ValueError: + continue + self.assertEqual(ctx.get_start_method(), method) + self.assertIs(ctx.get_context(), ctx) + self.assertRaises(ValueError, ctx.set_start_method, 'spawn') + self.assertRaises(ValueError, ctx.set_start_method, None) + self.check_context(ctx) + + def test_set_get(self): + multiprocessing.set_forkserver_preload(PRELOAD) + count = 0 + old_method = multiprocessing.get_start_method() + try: + for method in ('fork', 'spawn', 'forkserver'): + try: + multiprocessing.set_start_method(method, force=True) + except ValueError: + continue + self.assertEqual(multiprocessing.get_start_method(), method) + ctx = multiprocessing.get_context() + self.assertEqual(ctx.get_start_method(), method) + self.assertTrue(type(ctx).__name__.lower().startswith(method)) + self.assertTrue( + ctx.Process.__name__.lower().startswith(method)) + self.check_context(multiprocessing) + count += 1 + finally: + multiprocessing.set_start_method(old_method, force=True) + self.assertGreaterEqual(count, 1) + + def test_get_all(self): + methods = multiprocessing.get_all_start_methods() + if sys.platform == 'win32': + self.assertEqual(methods, ['spawn']) + else: + self.assertTrue(methods == ['fork', 'spawn'] or + methods == ['fork', 'spawn', 'forkserver']) + +# +# Check that killing process does not leak named semaphores +# + +@unittest.skipIf(sys.platform == "win32", + "test semantics don't make sense on Windows") +class TestSemaphoreTracker(unittest.TestCase): + def test_semaphore_tracker(self): + import subprocess + cmd = '''if 1: + import multiprocessing as mp, time, os + mp.set_start_method("spawn") + lock1 = mp.Lock() + lock2 = mp.Lock() + os.write(%d, lock1._semlock.name.encode("ascii") + b"\\n") + os.write(%d, lock2._semlock.name.encode("ascii") + b"\\n") + time.sleep(10) + ''' + r, w = os.pipe() + p = subprocess.Popen([sys.executable, + '-c', cmd % (w, w)], + pass_fds=[w], + stderr=subprocess.PIPE) + os.close(w) + with open(r, 'rb', closefd=True) as f: + name1 = f.readline().rstrip().decode('ascii') + name2 = f.readline().rstrip().decode('ascii') + _multiprocessing.sem_unlink(name1) + p.terminate() + p.wait() + time.sleep(2.0) + with self.assertRaises(OSError) as ctx: + _multiprocessing.sem_unlink(name2) + # docs say it should be ENOENT, but OSX seems to give EINVAL + self.assertIn(ctx.exception.errno, (errno.ENOENT, errno.EINVAL)) + err = p.stderr.read().decode('utf-8') + p.stderr.close() + expected = 'semaphore_tracker: There appear to be 2 leaked semaphores' + self.assertRegex(err, expected) + self.assertRegex(err, 'semaphore_tracker: %r: \[Errno' % name1) + # -# +# Mixins # -def setUpModule(): - if sys.platform.startswith("linux"): - try: - lock = multiprocessing.RLock() - except OSError: - raise unittest.SkipTest("OSError raises on RLock creation, " - "see issue 3111!") - check_enough_semaphores() - util.get_temp_dir() # creates temp directory for use by all processes - multiprocessing.get_logger().setLevel(LOG_LEVEL) +class ProcessesMixin(object): + TYPE = 'processes' + Process = multiprocessing.Process + connection = multiprocessing.connection + current_process = staticmethod(multiprocessing.current_process) + active_children = staticmethod(multiprocessing.active_children) + Pool = staticmethod(multiprocessing.Pool) + Pipe = staticmethod(multiprocessing.Pipe) + Queue = staticmethod(multiprocessing.Queue) + JoinableQueue = staticmethod(multiprocessing.JoinableQueue) + Lock = staticmethod(multiprocessing.Lock) + RLock = staticmethod(multiprocessing.RLock) + Semaphore = staticmethod(multiprocessing.Semaphore) + BoundedSemaphore = staticmethod(multiprocessing.BoundedSemaphore) + Condition = staticmethod(multiprocessing.Condition) + Event = staticmethod(multiprocessing.Event) + Barrier = staticmethod(multiprocessing.Barrier) + Value = staticmethod(multiprocessing.Value) + Array = staticmethod(multiprocessing.Array) + RawValue = staticmethod(multiprocessing.RawValue) + RawArray = staticmethod(multiprocessing.RawArray) + + +class ManagerMixin(object): + TYPE = 'manager' + Process = multiprocessing.Process + Queue = property(operator.attrgetter('manager.Queue')) + JoinableQueue = property(operator.attrgetter('manager.JoinableQueue')) + Lock = property(operator.attrgetter('manager.Lock')) + RLock = property(operator.attrgetter('manager.RLock')) + Semaphore = property(operator.attrgetter('manager.Semaphore')) + BoundedSemaphore = property(operator.attrgetter('manager.BoundedSemaphore')) + Condition = property(operator.attrgetter('manager.Condition')) + Event = property(operator.attrgetter('manager.Event')) + Barrier = property(operator.attrgetter('manager.Barrier')) + Value = property(operator.attrgetter('manager.Value')) + Array = property(operator.attrgetter('manager.Array')) + list = property(operator.attrgetter('manager.list')) + dict = property(operator.attrgetter('manager.dict')) + Namespace = property(operator.attrgetter('manager.Namespace')) + @classmethod + def Pool(cls, *args, **kwds): + return cls.manager.Pool(*args, **kwds) -def tearDownModule(): - # pause a bit so we don't get warning about dangling threads/processes - time.sleep(0.5) + @classmethod + def setUpClass(cls): + cls.manager = multiprocessing.Manager() + + @classmethod + def tearDownClass(cls): + # only the manager process should be returned by active_children() + # but this can take a bit on slow machines, so wait a few seconds + # if there are other children too (see #17395) + t = 0.01 + while len(multiprocessing.active_children()) > 1 and t < 5: + time.sleep(t) + t *= 2 + gc.collect() # do garbage collection + if cls.manager._number_of_objects() != 0: + # This is not really an error since some tests do not + # ensure that all processes which hold a reference to a + # managed object have been joined. + print('Shared objects which still exist at manager shutdown:') + print(cls.manager._debug_info()) + cls.manager.shutdown() + cls.manager.join() + cls.manager = None -if __name__ == '__main__': - unittest.main() +class ThreadsMixin(object): + TYPE = 'threads' + Process = multiprocessing.dummy.Process + connection = multiprocessing.dummy.connection + current_process = staticmethod(multiprocessing.dummy.current_process) + active_children = staticmethod(multiprocessing.dummy.active_children) + Pool = staticmethod(multiprocessing.Pool) + Pipe = staticmethod(multiprocessing.dummy.Pipe) + Queue = staticmethod(multiprocessing.dummy.Queue) + JoinableQueue = staticmethod(multiprocessing.dummy.JoinableQueue) + Lock = staticmethod(multiprocessing.dummy.Lock) + RLock = staticmethod(multiprocessing.dummy.RLock) + Semaphore = staticmethod(multiprocessing.dummy.Semaphore) + BoundedSemaphore = staticmethod(multiprocessing.dummy.BoundedSemaphore) + Condition = staticmethod(multiprocessing.dummy.Condition) + Event = staticmethod(multiprocessing.dummy.Event) + Barrier = staticmethod(multiprocessing.dummy.Barrier) + Value = staticmethod(multiprocessing.dummy.Value) + Array = staticmethod(multiprocessing.dummy.Array) + +# +# Functions used to create test cases from the base ones in this module +# + +def install_tests_in_module_dict(remote_globs, start_method): + __module__ = remote_globs['__name__'] + local_globs = globals() + ALL_TYPES = {'processes', 'threads', 'manager'} + + for name, base in local_globs.items(): + if not isinstance(base, type): + continue + if issubclass(base, BaseTestCase): + if base is BaseTestCase: + continue + assert set(base.ALLOWED_TYPES) <= ALL_TYPES, base.ALLOWED_TYPES + for type_ in base.ALLOWED_TYPES: + newname = 'With' + type_.capitalize() + name[1:] + Mixin = local_globs[type_.capitalize() + 'Mixin'] + class Temp(base, Mixin, unittest.TestCase): + pass + Temp.__name__ = Temp.__qualname__ = newname + Temp.__module__ = __module__ + remote_globs[newname] = Temp + elif issubclass(base, unittest.TestCase): + class Temp(base, object): + pass + Temp.__name__ = Temp.__qualname__ = name + Temp.__module__ = __module__ + remote_globs[name] = Temp + + dangling = [None, None] + old_start_method = [None] + + def setUpModule(): + multiprocessing.set_forkserver_preload(PRELOAD) + multiprocessing.process._cleanup() + dangling[0] = multiprocessing.process._dangling.copy() + dangling[1] = threading._dangling.copy() + old_start_method[0] = multiprocessing.get_start_method(allow_none=True) + try: + multiprocessing.set_start_method(start_method, force=True) + except ValueError: + raise unittest.SkipTest(start_method + + ' start method not supported') + + if sys.platform.startswith("linux"): + try: + lock = multiprocessing.RLock() + except OSError: + raise unittest.SkipTest("OSError raises on RLock creation, " + "see issue 3111!") + check_enough_semaphores() + util.get_temp_dir() # creates temp directory + multiprocessing.get_logger().setLevel(LOG_LEVEL) + + def tearDownModule(): + multiprocessing.set_start_method(old_start_method[0], force=True) + # pause a bit so we don't get warning about dangling threads/processes + time.sleep(0.5) + multiprocessing.process._cleanup() + gc.collect() + tmp = set(multiprocessing.process._dangling) - set(dangling[0]) + if tmp: + print('Dangling processes:', tmp, file=sys.stderr) + del tmp + tmp = set(threading._dangling) - set(dangling[1]) + if tmp: + print('Dangling threads:', tmp, file=sys.stderr) + + remote_globs['setUpModule'] = setUpModule + remote_globs['tearDownModule'] = tearDownModule diff --git a/Lib/test/audiodata/pluck-pcm24.au b/Lib/test/audiodata/pluck-pcm24.au Binary files differnew file mode 100644 index 0000000000..0bb230418a --- /dev/null +++ b/Lib/test/audiodata/pluck-pcm24.au diff --git a/Lib/test/audiotests.py b/Lib/test/audiotests.py index 6ff5bb81bb..b7497ce5b0 100644 --- a/Lib/test/audiotests.py +++ b/Lib/test/audiotests.py @@ -12,24 +12,6 @@ class UnseekableIO(io.FileIO): def seek(self, *args, **kwargs): raise io.UnsupportedOperation -def byteswap2(data): - a = array.array('h') - a.frombytes(data) - a.byteswap() - return a.tobytes() - -def byteswap3(data): - ba = bytearray(data) - ba[::3] = data[2::3] - ba[2::3] = data[::3] - return bytes(ba) - -def byteswap4(data): - a = array.array('i') - a.frombytes(data) - a.byteswap() - return a.tobytes() - class AudioTests: close_fd = False @@ -56,6 +38,12 @@ class AudioTests: params = f.getparams() self.assertEqual(params, (nchannels, sampwidth, framerate, nframes, comptype, compname)) + self.assertEqual(params.nchannels, nchannels) + self.assertEqual(params.sampwidth, sampwidth) + self.assertEqual(params.framerate, framerate) + self.assertEqual(params.nframes, nframes) + self.assertEqual(params.comptype, comptype) + self.assertEqual(params.compname, compname) dump = pickle.dumps(params) self.assertEqual(pickle.loads(dump), params) @@ -72,15 +60,12 @@ class AudioWriteTests(AudioTests): return f def check_file(self, testfile, nframes, frames): - f = self.module.open(testfile, 'rb') - try: + with self.module.open(testfile, 'rb') as f: self.assertEqual(f.getnchannels(), self.nchannels) self.assertEqual(f.getsampwidth(), self.sampwidth) self.assertEqual(f.getframerate(), self.framerate) self.assertEqual(f.getnframes(), nframes) self.assertEqual(f.readframes(nframes), frames) - finally: - f.close() def test_write_params(self): f = self.create_file(TESTFN) @@ -90,6 +75,53 @@ class AudioWriteTests(AudioTests): self.nframes, self.comptype, self.compname) f.close() + def test_write_context_manager_calls_close(self): + # Close checks for a minimum header and will raise an error + # if it is not set, so this proves that close is called. + with self.assertRaises(self.module.Error): + with self.module.open(TESTFN, 'wb'): + pass + with self.assertRaises(self.module.Error): + with open(TESTFN, 'wb') as testfile: + with self.module.open(testfile): + pass + + def test_context_manager_with_open_file(self): + with open(TESTFN, 'wb') as testfile: + with self.module.open(testfile) as f: + f.setnchannels(self.nchannels) + f.setsampwidth(self.sampwidth) + f.setframerate(self.framerate) + f.setcomptype(self.comptype, self.compname) + self.assertEqual(testfile.closed, self.close_fd) + with open(TESTFN, 'rb') as testfile: + with self.module.open(testfile) as f: + self.assertFalse(f.getfp().closed) + params = f.getparams() + self.assertEqual(params.nchannels, self.nchannels) + self.assertEqual(params.sampwidth, self.sampwidth) + self.assertEqual(params.framerate, self.framerate) + if not self.close_fd: + self.assertIsNone(f.getfp()) + self.assertEqual(testfile.closed, self.close_fd) + + def test_context_manager_with_filename(self): + # If the file doesn't get closed, this test won't fail, but it will + # produce a resource leak warning. + with self.module.open(TESTFN, 'wb') as f: + f.setnchannels(self.nchannels) + f.setsampwidth(self.sampwidth) + f.setframerate(self.framerate) + f.setcomptype(self.comptype, self.compname) + with self.module.open(TESTFN) as f: + self.assertFalse(f.getfp().closed) + params = f.getparams() + self.assertEqual(params.nchannels, self.nchannels) + self.assertEqual(params.sampwidth, self.sampwidth) + self.assertEqual(params.framerate, self.framerate) + if not self.close_fd: + self.assertIsNone(f.getfp()) + def test_write(self): f = self.create_file(TESTFN) f.setnframes(self.nframes) @@ -98,6 +130,30 @@ class AudioWriteTests(AudioTests): self.check_file(TESTFN, self.nframes, self.frames) + def test_write_bytearray(self): + f = self.create_file(TESTFN) + f.setnframes(self.nframes) + f.writeframes(bytearray(self.frames)) + f.close() + + self.check_file(TESTFN, self.nframes, self.frames) + + def test_write_array(self): + f = self.create_file(TESTFN) + f.setnframes(self.nframes) + f.writeframes(array.array('h', self.frames)) + f.close() + + self.check_file(TESTFN, self.nframes, self.frames) + + def test_write_memoryview(self): + f = self.create_file(TESTFN) + f.setnframes(self.nframes) + f.writeframes(memoryview(self.frames)) + f.close() + + self.check_file(TESTFN, self.nframes, self.frames) + def test_incompleted_write(self): with open(TESTFN, 'wb') as testfile: testfile.write(b'ababagalamaga') @@ -137,20 +193,18 @@ class AudioWriteTests(AudioTests): self.check_file(testfile, self.nframes, self.frames) def test_unseekable_read(self): - f = self.create_file(TESTFN) - f.setnframes(self.nframes) - f.writeframes(self.frames) - f.close() + with self.create_file(TESTFN) as f: + f.setnframes(self.nframes) + f.writeframes(self.frames) with UnseekableIO(TESTFN, 'rb') as testfile: self.check_file(testfile, self.nframes, self.frames) def test_unseekable_write(self): with UnseekableIO(TESTFN, 'wb') as testfile: - f = self.create_file(testfile) - f.setnframes(self.nframes) - f.writeframes(self.frames) - f.close() + with self.create_file(testfile) as f: + f.setnframes(self.nframes) + f.writeframes(self.frames) self.check_file(TESTFN, self.nframes, self.frames) @@ -267,12 +321,9 @@ class AudioTestsWithSourceFile(AudioTests): with open(TESTFN, 'rb') as testfile: self.assertEqual(testfile.read(13), b'ababagalamaga') - f = self.module.open(testfile, 'rb') - try: + with self.module.open(testfile, 'rb') as f: self.assertEqual(f.getnchannels(), self.nchannels) self.assertEqual(f.getsampwidth(), self.sampwidth) self.assertEqual(f.getframerate(), self.framerate) self.assertEqual(f.getnframes(), self.sndfilenframes) self.assertEqual(f.readframes(self.nframes), self.frames) - finally: - f.close() diff --git a/Lib/test/badsyntax_future10.py b/Lib/test/badsyntax_future10.py new file mode 100644 index 0000000000..fa5ab67a98 --- /dev/null +++ b/Lib/test/badsyntax_future10.py @@ -0,0 +1,3 @@ +from __future__ import absolute_import +"spam, bar, blah" +from __future__ import print_function diff --git a/Lib/test/buffer_tests.py b/Lib/test/buffer_tests.py index cf54c28759..0a629400fd 100644 --- a/Lib/test/buffer_tests.py +++ b/Lib/test/buffer_tests.py @@ -160,14 +160,20 @@ class MixinBytesBufferCommonTests(object): self.marshal(b'abc\rab\tdef\ng\thi').expandtabs(8)) self.assertEqual(b'abc\rab def\ng hi', self.marshal(b'abc\rab\tdef\ng\thi').expandtabs(4)) + self.assertEqual(b'abc\r\nab def\ng hi', + self.marshal(b'abc\r\nab\tdef\ng\thi').expandtabs()) + self.assertEqual(b'abc\r\nab def\ng hi', + self.marshal(b'abc\r\nab\tdef\ng\thi').expandtabs(8)) self.assertEqual(b'abc\r\nab def\ng hi', self.marshal(b'abc\r\nab\tdef\ng\thi').expandtabs(4)) - self.assertEqual(b'abc\rab def\ng hi', - self.marshal(b'abc\rab\tdef\ng\thi').expandtabs()) - self.assertEqual(b'abc\rab def\ng hi', - self.marshal(b'abc\rab\tdef\ng\thi').expandtabs(8)) self.assertEqual(b'abc\r\nab\r\ndef\ng\r\nhi', - self.marshal(b'abc\r\nab\r\ndef\ng\r\nhi').expandtabs(4)) + self.marshal(b'abc\r\nab\r\ndef\ng\r\nhi').expandtabs(4)) + # check keyword args + self.assertEqual(b'abc\rab def\ng hi', + self.marshal(b'abc\rab\tdef\ng\thi').expandtabs(tabsize=8)) + self.assertEqual(b'abc\rab def\ng hi', + self.marshal(b'abc\rab\tdef\ng\thi').expandtabs(tabsize=4)) + self.assertEqual(b' a\n b', self.marshal(b' \ta\n\tb').expandtabs(1)) self.assertRaises(TypeError, self.marshal(b'hello').expandtabs, 42, 42) diff --git a/Lib/test/bytecode_helper.py b/Lib/test/bytecode_helper.py new file mode 100644 index 0000000000..58b4209f55 --- /dev/null +++ b/Lib/test/bytecode_helper.py @@ -0,0 +1,41 @@ +"""bytecode_helper - support tools for testing correct bytecode generation""" + +import unittest +import dis +import io + +_UNSPECIFIED = object() + +class BytecodeTestCase(unittest.TestCase): + """Custom assertion methods for inspecting bytecode.""" + + def get_disassembly_as_string(self, co): + s = io.StringIO() + dis.dis(co, file=s) + return s.getvalue() + + def assertInBytecode(self, x, opname, argval=_UNSPECIFIED): + """Returns instr if op is found, otherwise throws AssertionError""" + for instr in dis.get_instructions(x): + if instr.opname == opname: + if argval is _UNSPECIFIED or instr.argval == argval: + return instr + disassembly = self.get_disassembly_as_string(x) + if argval is _UNSPECIFIED: + msg = '%s not found in bytecode:\n%s' % (opname, disassembly) + else: + msg = '(%s,%r) not found in bytecode:\n%s' + msg = msg % (opname, argval, disassembly) + self.fail(msg) + + def assertNotInBytecode(self, x, opname, argval=_UNSPECIFIED): + """Throws AssertionError if op is found""" + for instr in dis.get_instructions(x): + if instr.opname == opname: + disassembly = self.get_disassembly_as_string(co) + if opargval is _UNSPECIFIED: + msg = '%s occurs in bytecode:\n%s' % (opname, disassembly) + elif instr.argval == argval: + msg = '(%s,%r) occurs in bytecode:\n%s' + msg = msg % (opname, argval, disassembly) + self.fail(msg) diff --git a/Lib/test/datetimetester.py b/Lib/test/datetimetester.py index ab55476226..3f3c60acea 100644 --- a/Lib/test/datetimetester.py +++ b/Lib/test/datetimetester.py @@ -619,6 +619,10 @@ class TestTimeDelta(HarmlessMixedComparison, unittest.TestCase): eq(td(hours=-.2/us_per_hour), td(0)) eq(td(days=-.4/us_per_day, hours=-.2/us_per_hour), td(microseconds=-1)) + # Test for a patch in Issue 8860 + eq(td(microseconds=0.5), 0.5*td(microseconds=1.0)) + eq(td(microseconds=0.5)//td.resolution, 0.5*td.resolution//td.resolution) + def test_massive_normalization(self): td = timedelta(microseconds=-1) self.assertEqual((td.days, td.seconds, td.microseconds), diff --git a/Lib/test/final_a.py b/Lib/test/final_a.py new file mode 100644 index 0000000000..390ee8895a --- /dev/null +++ b/Lib/test/final_a.py @@ -0,0 +1,19 @@ +""" +Fodder for module finalization tests in test_module. +""" + +import shutil +import test.final_b + +x = 'a' + +class C: + def __del__(self): + # Inspect module globals and builtins + print("x =", x) + print("final_b.x =", test.final_b.x) + print("shutil.rmtree =", getattr(shutil.rmtree, '__name__', None)) + print("len =", getattr(len, '__name__', None)) + +c = C() +_underscored = C() diff --git a/Lib/test/final_b.py b/Lib/test/final_b.py new file mode 100644 index 0000000000..7228d82b88 --- /dev/null +++ b/Lib/test/final_b.py @@ -0,0 +1,19 @@ +""" +Fodder for module finalization tests in test_module. +""" + +import shutil +import test.final_a + +x = 'b' + +class C: + def __del__(self): + # Inspect module globals and builtins + print("x =", x) + print("final_a.x =", test.final_a.x) + print("shutil.rmtree =", getattr(shutil.rmtree, '__name__', None)) + print("len =", getattr(len, '__name__', None)) + +c = C() +_underscored = C() diff --git a/Lib/test/fork_wait.py b/Lib/test/fork_wait.py index 88527df25e..19b54ec736 100644 --- a/Lib/test/fork_wait.py +++ b/Lib/test/fork_wait.py @@ -28,7 +28,7 @@ class ForkWait(unittest.TestCase): self.alive[id] = os.getpid() try: time.sleep(SHORTSLEEP) - except IOError: + except OSError: pass def wait_impl(self, cpid): diff --git a/Lib/test/keycert3.pem b/Lib/test/keycert3.pem new file mode 100644 index 0000000000..5bfa62c4ca --- /dev/null +++ b/Lib/test/keycert3.pem @@ -0,0 +1,73 @@ +-----BEGIN PRIVATE KEY----- +MIICdgIBADANBgkqhkiG9w0BAQEFAASCAmAwggJcAgEAAoGBAMLgD0kAKDb5cFyP +jbwNfR5CtewdXC+kMXAWD8DLxiTTvhMW7qVnlwOm36mZlszHKvsRf05lT4pegiFM +9z2j1OlaN+ci/X7NU22TNN6crYSiN77FjYJP464j876ndSxyD+rzys386T+1r1aZ +aggEdkj1TsSsv1zWIYKlPIjlvhuxAgMBAAECgYA0aH+T2Vf3WOPv8KdkcJg6gCRe +yJKXOWgWRcicx/CUzOEsTxmFIDPLxqAWA3k7v0B+3vjGw5Y9lycV/5XqXNoQI14j +y09iNsumds13u5AKkGdTJnZhQ7UKdoVHfuP44ZdOv/rJ5/VD6F4zWywpe90pcbK+ +AWDVtusgGQBSieEl1QJBAOyVrUG5l2yoUBtd2zr/kiGm/DYyXlIthQO/A3/LngDW +5/ydGxVsT7lAVOgCsoT+0L4efTh90PjzW8LPQrPBWVMCQQDS3h/FtYYd5lfz+FNL +9CEe1F1w9l8P749uNUD0g317zv1tatIqVCsQWHfVHNdVvfQ+vSFw38OORO00Xqs9 +1GJrAkBkoXXEkxCZoy4PteheO/8IWWLGGr6L7di6MzFl1lIqwT6D8L9oaV2vynFT +DnKop0pa09Unhjyw57KMNmSE2SUJAkEArloTEzpgRmCq4IK2/NpCeGdHS5uqRlbh +1VIa/xGps7EWQl5Mn8swQDel/YP3WGHTjfx7pgSegQfkyaRtGpZ9OQJAa9Vumj8m +JAAtI0Bnga8hgQx7BhTQY4CadDxyiRGOGYhwUzYVCqkb2sbVRH9HnwUaJT7cWBY3 +RnJdHOMXWem7/w== +-----END PRIVATE KEY----- +Certificate: + Data: + Version: 1 (0x0) + Serial Number: 12723342612721443281 (0xb09264b1f2da21d1) + Signature Algorithm: sha1WithRSAEncryption + Issuer: C=XY, O=Python Software Foundation CA, CN=our-ca-server + Validity + Not Before: Jan 4 19:47:07 2013 GMT + Not After : Nov 13 19:47:07 2022 GMT + Subject: C=XY, L=Castle Anthrax, O=Python Software Foundation, CN=localhost + Subject Public Key Info: + Public Key Algorithm: rsaEncryption + Public-Key: (1024 bit) + Modulus: + 00:c2:e0:0f:49:00:28:36:f9:70:5c:8f:8d:bc:0d: + 7d:1e:42:b5:ec:1d:5c:2f:a4:31:70:16:0f:c0:cb: + c6:24:d3:be:13:16:ee:a5:67:97:03:a6:df:a9:99: + 96:cc:c7:2a:fb:11:7f:4e:65:4f:8a:5e:82:21:4c: + f7:3d:a3:d4:e9:5a:37:e7:22:fd:7e:cd:53:6d:93: + 34:de:9c:ad:84:a2:37:be:c5:8d:82:4f:e3:ae:23: + f3:be:a7:75:2c:72:0f:ea:f3:ca:cd:fc:e9:3f:b5: + af:56:99:6a:08:04:76:48:f5:4e:c4:ac:bf:5c:d6: + 21:82:a5:3c:88:e5:be:1b:b1 + Exponent: 65537 (0x10001) + Signature Algorithm: sha1WithRSAEncryption + 2f:42:5f:a3:09:2c:fa:51:88:c7:37:7f:ea:0e:63:f0:a2:9a: + e5:5a:e2:c8:20:f0:3f:60:bc:c8:0f:b6:c6:76:ce:db:83:93: + f5:a3:33:67:01:8e:04:cd:00:9a:73:fd:f3:35:86:fa:d7:13: + e2:46:c6:9d:c0:29:53:d4:a9:90:b8:77:4b:e6:83:76:e4:92: + d6:9c:50:cf:43:d0:c6:01:77:61:9a:de:9b:70:f7:72:cd:59: + 00:31:69:d9:b4:ca:06:9c:6d:c3:c7:80:8c:68:e6:b5:a2:f8: + ef:1d:bb:16:9f:77:77:ef:87:62:22:9b:4d:69:a4:3a:1a:f1: + 21:5e:8c:32:ac:92:fd:15:6b:18:c2:7f:15:0d:98:30:ca:75: + 8f:1a:71:df:da:1d:b2:ef:9a:e8:2d:2e:02:fd:4a:3c:aa:96: + 0b:06:5d:35:b3:3d:24:87:4b:e0:b0:58:60:2f:45:ac:2e:48: + 8a:b0:99:10:65:27:ff:cc:b1:d8:fd:bd:26:6b:b9:0c:05:2a: + f4:45:63:35:51:07:ed:83:85:fe:6f:69:cb:bb:40:a8:ae:b6: + 3b:56:4a:2d:a4:ed:6d:11:2c:4d:ed:17:24:fd:47:bc:d3:41: + a2:d3:06:fe:0c:90:d8:d8:94:26:c4:ff:cc:a1:d8:42:77:eb: + fc:a9:94:71 +-----BEGIN CERTIFICATE----- +MIICpDCCAYwCCQCwkmSx8toh0TANBgkqhkiG9w0BAQUFADBNMQswCQYDVQQGEwJY +WTEmMCQGA1UECgwdUHl0aG9uIFNvZnR3YXJlIEZvdW5kYXRpb24gQ0ExFjAUBgNV +BAMMDW91ci1jYS1zZXJ2ZXIwHhcNMTMwMTA0MTk0NzA3WhcNMjIxMTEzMTk0NzA3 +WjBfMQswCQYDVQQGEwJYWTEXMBUGA1UEBxMOQ2FzdGxlIEFudGhyYXgxIzAhBgNV +BAoTGlB5dGhvbiBTb2Z0d2FyZSBGb3VuZGF0aW9uMRIwEAYDVQQDEwlsb2NhbGhv +c3QwgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBAMLgD0kAKDb5cFyPjbwNfR5C +tewdXC+kMXAWD8DLxiTTvhMW7qVnlwOm36mZlszHKvsRf05lT4pegiFM9z2j1Ola +N+ci/X7NU22TNN6crYSiN77FjYJP464j876ndSxyD+rzys386T+1r1aZaggEdkj1 +TsSsv1zWIYKlPIjlvhuxAgMBAAEwDQYJKoZIhvcNAQEFBQADggEBAC9CX6MJLPpR +iMc3f+oOY/CimuVa4sgg8D9gvMgPtsZ2ztuDk/WjM2cBjgTNAJpz/fM1hvrXE+JG +xp3AKVPUqZC4d0vmg3bkktacUM9D0MYBd2Ga3ptw93LNWQAxadm0ygacbcPHgIxo +5rWi+O8duxafd3fvh2Iim01ppDoa8SFejDKskv0VaxjCfxUNmDDKdY8acd/aHbLv +mugtLgL9SjyqlgsGXTWzPSSHS+CwWGAvRawuSIqwmRBlJ//Msdj9vSZruQwFKvRF +YzVRB+2Dhf5vacu7QKiutjtWSi2k7W0RLE3tFyT9R7zTQaLTBv4MkNjYlCbE/8yh +2EJ36/yplHE= +-----END CERTIFICATE----- diff --git a/Lib/test/keycert4.pem b/Lib/test/keycert4.pem new file mode 100644 index 0000000000..53355c8a50 --- /dev/null +++ b/Lib/test/keycert4.pem @@ -0,0 +1,73 @@ +-----BEGIN PRIVATE KEY----- +MIICdgIBADANBgkqhkiG9w0BAQEFAASCAmAwggJcAgEAAoGBAK5UQiMI5VkNs2Qv +L7gUaiDdFevNUXRjU4DHAe3ZzzYLZNE69h9gO9VCSS16tJ5fT5VEu0EZyGr0e3V2 +NkX0ZoU0Hc/UaY4qx7LHmn5SYZpIxhJnkf7SyHJK1zUaGlU0/LxYqIuGCtF5dqx1 +L2OQhEx1GM6RydHdgX69G64LXcY5AgMBAAECgYAhsRMfJkb9ERLMl/oG/5sLQu9L +pWDKt6+ZwdxzlZbggQ85CMYshjLKIod2DLL/sLf2x1PRXyRG131M1E3k8zkkz6de +R1uDrIN/x91iuYzfLQZGh8bMY7Yjd2eoroa6R/7DjpElGejLxOAaDWO0ST2IFQy9 +myTGS2jSM97wcXfsSQJBANP3jelJoS5X6BRjTSneY21wcocxVuQh8pXpErALVNsT +drrFTeaBuZp7KvbtnIM5g2WRNvaxLZlAY/hXPJvi6ncCQQDSix1cebml6EmPlEZS +Mm8gwI2F9ufUunwJmBJcz826Do0ZNGByWDAM/JQZH4FX4GfAFNuj8PUb+GQfadkx +i1DPAkEA0lVsNHojvuDsIo8HGuzarNZQT2beWjJ1jdxh9t7HrTx7LIps6rb/fhOK +Zs0R6gVAJaEbcWAPZ2tFyECInAdnsQJAUjaeXXjuxFkjOFym5PvqpvhpivEx78Bu +JPTr3rAKXmfGMxxfuOa0xK1wSyshP6ZR/RBn/+lcXPKubhHQDOegwwJAJF1DBQnN ++/tLmOPULtDwfP4Zixn+/8GmGOahFoRcu6VIGHmRilJTn6MOButw7Glv2YdeC6l/ +e83Gq6ffLVfKNQ== +-----END PRIVATE KEY----- +Certificate: + Data: + Version: 1 (0x0) + Serial Number: 12723342612721443282 (0xb09264b1f2da21d2) + Signature Algorithm: sha1WithRSAEncryption + Issuer: C=XY, O=Python Software Foundation CA, CN=our-ca-server + Validity + Not Before: Jan 4 19:47:07 2013 GMT + Not After : Nov 13 19:47:07 2022 GMT + Subject: C=XY, L=Castle Anthrax, O=Python Software Foundation, CN=fakehostname + Subject Public Key Info: + Public Key Algorithm: rsaEncryption + Public-Key: (1024 bit) + Modulus: + 00:ae:54:42:23:08:e5:59:0d:b3:64:2f:2f:b8:14: + 6a:20:dd:15:eb:cd:51:74:63:53:80:c7:01:ed:d9: + cf:36:0b:64:d1:3a:f6:1f:60:3b:d5:42:49:2d:7a: + b4:9e:5f:4f:95:44:bb:41:19:c8:6a:f4:7b:75:76: + 36:45:f4:66:85:34:1d:cf:d4:69:8e:2a:c7:b2:c7: + 9a:7e:52:61:9a:48:c6:12:67:91:fe:d2:c8:72:4a: + d7:35:1a:1a:55:34:fc:bc:58:a8:8b:86:0a:d1:79: + 76:ac:75:2f:63:90:84:4c:75:18:ce:91:c9:d1:dd: + 81:7e:bd:1b:ae:0b:5d:c6:39 + Exponent: 65537 (0x10001) + Signature Algorithm: sha1WithRSAEncryption + ad:45:8a:8e:ef:c6:ef:04:41:5c:2c:4a:84:dc:02:76:0c:d0: + 66:0f:f0:16:04:58:4d:fd:68:b7:b8:d3:a8:41:a5:5c:3c:6f: + 65:3c:d1:f8:ce:43:35:e7:41:5f:53:3d:c9:2c:c3:7d:fc:56: + 4a:fa:47:77:38:9d:bb:97:28:0a:3b:91:19:7f:bc:74:ae:15: + 6b:bd:20:36:67:45:a5:1e:79:d7:75:e6:89:5c:6d:54:84:d1: + 95:d7:a7:b4:33:3c:af:37:c4:79:8f:5e:75:dc:75:c2:18:fb: + 61:6f:2d:dc:38:65:5b:ba:67:28:d0:88:d7:8d:b9:23:5a:8e: + e8:c6:bb:db:ce:d5:b8:41:2a:ce:93:08:b6:95:ad:34:20:18: + d5:3b:37:52:74:50:0b:07:2c:b0:6d:a4:4c:7b:f4:e0:fd:d1: + af:17:aa:20:cd:62:e3:f0:9d:37:69:db:41:bd:d4:1c:fb:53: + 20:da:88:9d:76:26:67:ce:01:90:a7:80:1d:a9:5b:39:73:68: + 54:0a:d1:2a:03:1b:8f:3c:43:5d:5d:c4:51:f1:a7:e7:11:da: + 31:2c:49:06:af:04:f4:b8:3c:99:c4:20:b9:06:36:a2:00:92: + 61:1d:0c:6d:24:05:e2:82:e1:47:db:a0:5f:ba:b9:fb:ba:fa: + 49:12:1e:ce +-----BEGIN CERTIFICATE----- +MIICpzCCAY8CCQCwkmSx8toh0jANBgkqhkiG9w0BAQUFADBNMQswCQYDVQQGEwJY +WTEmMCQGA1UECgwdUHl0aG9uIFNvZnR3YXJlIEZvdW5kYXRpb24gQ0ExFjAUBgNV +BAMMDW91ci1jYS1zZXJ2ZXIwHhcNMTMwMTA0MTk0NzA3WhcNMjIxMTEzMTk0NzA3 +WjBiMQswCQYDVQQGEwJYWTEXMBUGA1UEBxMOQ2FzdGxlIEFudGhyYXgxIzAhBgNV +BAoTGlB5dGhvbiBTb2Z0d2FyZSBGb3VuZGF0aW9uMRUwEwYDVQQDEwxmYWtlaG9z +dG5hbWUwgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBAK5UQiMI5VkNs2QvL7gU +aiDdFevNUXRjU4DHAe3ZzzYLZNE69h9gO9VCSS16tJ5fT5VEu0EZyGr0e3V2NkX0 +ZoU0Hc/UaY4qx7LHmn5SYZpIxhJnkf7SyHJK1zUaGlU0/LxYqIuGCtF5dqx1L2OQ +hEx1GM6RydHdgX69G64LXcY5AgMBAAEwDQYJKoZIhvcNAQEFBQADggEBAK1Fio7v +xu8EQVwsSoTcAnYM0GYP8BYEWE39aLe406hBpVw8b2U80fjOQzXnQV9TPcksw338 +Vkr6R3c4nbuXKAo7kRl/vHSuFWu9IDZnRaUeedd15olcbVSE0ZXXp7QzPK83xHmP +XnXcdcIY+2FvLdw4ZVu6ZyjQiNeNuSNajujGu9vO1bhBKs6TCLaVrTQgGNU7N1J0 +UAsHLLBtpEx79OD90a8XqiDNYuPwnTdp20G91Bz7UyDaiJ12JmfOAZCngB2pWzlz +aFQK0SoDG488Q11dxFHxp+cR2jEsSQavBPS4PJnEILkGNqIAkmEdDG0kBeKC4Ufb +oF+6ufu6+kkSHs4= +-----END CERTIFICATE----- diff --git a/Lib/test/leakers/test_gestalt.py b/Lib/test/leakers/test_gestalt.py deleted file mode 100644 index e0081c1fc5..0000000000 --- a/Lib/test/leakers/test_gestalt.py +++ /dev/null @@ -1,14 +0,0 @@ -import sys - -if sys.platform != 'darwin': - raise ValueError("This test only leaks on Mac OS X") - -def leak(): - # taken from platform._mac_ver_lookup() - from gestalt import gestalt - import MacOS - - try: - gestalt('sysu') - except MacOS.Error: - pass diff --git a/Lib/test/lock_tests.py b/Lib/test/lock_tests.py index bfbf44e0db..1cbcea2343 100644 --- a/Lib/test/lock_tests.py +++ b/Lib/test/lock_tests.py @@ -80,6 +80,11 @@ class BaseLockTests(BaseTestCase): lock = self.locktype() del lock + def test_repr(self): + lock = self.locktype() + repr(lock) + del lock + def test_acquire_destroy(self): lock = self.locktype() lock.acquire() @@ -413,6 +418,17 @@ class ConditionTests(BaseTestCase): self.assertRaises(RuntimeError, cond.notify) def _check_notify(self, cond): + # Note that this test is sensitive to timing. If the worker threads + # don't execute in a timely fashion, the main thread may think they + # are further along then they are. The main thread therefore issues + # _wait() statements to try to make sure that it doesn't race ahead + # of the workers. + # Secondly, this test assumes that condition variables are not subject + # to spurious wakeups. The absence of spurious wakeups is an implementation + # detail of Condition Cariables in current CPython, but in general, not + # a guaranteed property of condition variables as a programming + # construct. In particular, it is possible that this can no longer + # be conveniently guaranteed should their implementation ever change. N = 5 results1 = [] results2 = [] @@ -440,6 +456,9 @@ class ConditionTests(BaseTestCase): _wait() self.assertEqual(results1, [(True, 1)] * 3) self.assertEqual(results2, []) + # first wait, to ensure all workers settle into cond.wait() before + # we continue. See issue #8799 + _wait() # Notify 5 threads: they might be in their first or second wait cond.acquire() cond.notify(5) @@ -450,6 +469,7 @@ class ConditionTests(BaseTestCase): _wait() self.assertEqual(results1, [(True, 1)] * 3 + [(True, 2)] * 2) self.assertEqual(results2, [(True, 2)] * 3) + _wait() # make sure all workers settle into cond.wait() # Notify all threads: they are all in their second wait cond.acquire() cond.notify_all() diff --git a/Lib/test/make_ssl_certs.py b/Lib/test/make_ssl_certs.py index 48d2e57f4b..4251d5524f 100644 --- a/Lib/test/make_ssl_certs.py +++ b/Lib/test/make_ssl_certs.py @@ -2,6 +2,7 @@ and friends.""" import os +import shutil import sys import tempfile from subprocess import * @@ -20,11 +21,54 @@ req_template = """ [req_x509_extensions] subjectAltName = DNS:{hostname} + + [ ca ] + default_ca = CA_default + + [ CA_default ] + dir = cadir + database = $dir/index.txt + crlnumber = $dir/crl.txt + default_md = sha1 + default_days = 3600 + default_crl_days = 3600 + certificate = pycacert.pem + private_key = pycakey.pem + serial = $dir/serial + RANDFILE = $dir/.rand + + policy = policy_match + + [ policy_match ] + countryName = match + stateOrProvinceName = optional + organizationName = match + organizationalUnitName = optional + commonName = supplied + emailAddress = optional + + [ policy_anything ] + countryName = optional + stateOrProvinceName = optional + localityName = optional + organizationName = optional + organizationalUnitName = optional + commonName = supplied + emailAddress = optional + + + [ v3_ca ] + + subjectKeyIdentifier=hash + authorityKeyIdentifier=keyid:always,issuer + basicConstraints = CA:true + """ here = os.path.abspath(os.path.dirname(__file__)) -def make_cert_key(hostname): +def make_cert_key(hostname, sign=False): + print("creating cert for " + hostname) tempnames = [] for i in range(3): with tempfile.NamedTemporaryFile(delete=False) as f: @@ -33,10 +77,25 @@ def make_cert_key(hostname): try: with open(req_file, 'w') as f: f.write(req_template.format(hostname=hostname)) - args = ['req', '-new', '-days', '3650', '-nodes', '-x509', + args = ['req', '-new', '-days', '3650', '-nodes', '-newkey', 'rsa:1024', '-keyout', key_file, - '-out', cert_file, '-config', req_file] + '-config', req_file] + if sign: + with tempfile.NamedTemporaryFile(delete=False) as f: + tempnames.append(f.name) + reqfile = f.name + args += ['-out', reqfile ] + + else: + args += ['-x509', '-out', cert_file ] check_call(['openssl'] + args) + + if sign: + args = ['ca', '-config', req_file, '-out', cert_file, '-outdir', 'cadir', + '-policy', 'policy_anything', '-batch', '-infiles', reqfile ] + check_call(['openssl'] + args) + + with open(cert_file, 'r') as f: cert = f.read() with open(key_file, 'r') as f: @@ -46,6 +105,36 @@ def make_cert_key(hostname): for name in tempnames: os.remove(name) +TMP_CADIR = 'cadir' + +def unmake_ca(): + shutil.rmtree(TMP_CADIR) + +def make_ca(): + os.mkdir(TMP_CADIR) + with open(os.path.join('cadir','index.txt'),'a+') as f: + pass # empty file + with open(os.path.join('cadir','crl.txt'),'a+') as f: + r.write("00") + with open(os.path.join('cadir','index.txt.attr'),'w+') as f: + f.write('unique_subject = no') + + with tempfile.NamedTemporaryFile("w") as t: + t.write(req_template.format(hostname='our-ca-server')) + t.flush() + with tempfile.NamedTemporaryFile() as f: + args = ['req', '-new', '-days', '3650', '-extensions', 'v3_ca', '-nodes', + '-newkey', 'rsa:2048', '-keyout', 'pycakey.pem', + '-out', f.name, + '-subj', '/C=XY/L=Castle Anthrax/O=Python Software Foundation CA/CN=our-ca-server'] + check_call(['openssl'] + args) + args = ['ca', '-config', t.name, '-create_serial', + '-out', 'pycacert.pem', '-batch', '-outdir', TMP_CADIR, + '-keyfile', 'pycakey.pem', '-days', '3650', + '-selfsign', '-extensions', 'v3_ca', '-infiles', f.name ] + check_call(['openssl'] + args) + args = ['ca', '-config', t.name, '-gencrl', '-out', 'revocation.crl'] + check_call(['openssl'] + args) if __name__ == '__main__': os.chdir(here) @@ -54,11 +143,34 @@ if __name__ == '__main__': f.write(cert) with open('ssl_key.pem', 'w') as f: f.write(key) + print("password protecting ssl_key.pem in ssl_key.passwd.pem") + check_call(['openssl','rsa','-in','ssl_key.pem','-out','ssl_key.passwd.pem','-des3','-passout','pass:somepass']) + check_call(['openssl','rsa','-in','ssl_key.pem','-out','keycert.passwd.pem','-des3','-passout','pass:somepass']) + with open('keycert.pem', 'w') as f: f.write(key) f.write(cert) + + with open('keycert.passwd.pem', 'a+') as f: + f.write(cert) + # For certificate matching tests + make_ca() cert, key = make_cert_key('fakehostname') with open('keycert2.pem', 'w') as f: f.write(key) f.write(cert) + + cert, key = make_cert_key('localhost', True) + with open('keycert3.pem', 'w') as f: + f.write(key) + f.write(cert) + + cert, key = make_cert_key('fakehostname', True) + with open('keycert4.pem', 'w') as f: + f.write(key) + f.write(cert) + + unmake_ca() + print("\n\nPlease change the values in test_ssl.py, test_parse_cert function related to notAfter,notBefore and serialNumber") + check_call(['openssl','x509','-in','keycert.pem','-dates','-serial','-noout']) diff --git a/Lib/test/mock_socket.py b/Lib/test/mock_socket.py index 861bfb2b56..e36724f54b 100644 --- a/Lib/test/mock_socket.py +++ b/Lib/test/mock_socket.py @@ -145,12 +145,8 @@ def gethostbyname(name): return "" -class gaierror(Exception): - pass - - -class error(Exception): - pass +gaierror = socket_module.gaierror +error = socket_module.error # Constants diff --git a/Lib/test/mp_fork_bomb.py b/Lib/test/mp_fork_bomb.py index 908afe3045..017e010ba0 100644 --- a/Lib/test/mp_fork_bomb.py +++ b/Lib/test/mp_fork_bomb.py @@ -7,6 +7,11 @@ def foo(): # correctly on Windows. However, we should get a RuntimeError rather # than the Windows equivalent of a fork bomb. +if len(sys.argv) > 1: + multiprocessing.set_start_method(sys.argv[1]) +else: + multiprocessing.set_start_method('spawn') + p = multiprocessing.Process(target=foo) p.start() p.join() diff --git a/Lib/test/multibytecodec_support.py b/Lib/test/multibytecodec_support.py index 14fea3ef7d..51ca3bd64f 100644 --- a/Lib/test/multibytecodec_support.py +++ b/Lib/test/multibytecodec_support.py @@ -281,7 +281,7 @@ class TestBase_Mapping(unittest.TestCase): unittest.TestCase.__init__(self, *args, **kw) try: self.open_mapping_file().close() # test it to report the error early - except (IOError, HTTPException): + except (OSError, HTTPException): self.skipTest("Could not retrieve "+self.mapfileurl) def open_mapping_file(self): diff --git a/Lib/test/pickletester.py b/Lib/test/pickletester.py index 4d59bde851..a2bea600ed 100644 --- a/Lib/test/pickletester.py +++ b/Lib/test/pickletester.py @@ -1,9 +1,11 @@ +import copyreg import io -import unittest import pickle import pickletools +import random +import struct import sys -import copyreg +import unittest import weakref from http.cookies import SimpleCookie @@ -93,6 +95,9 @@ class E(C): def __getinitargs__(self): return () +class H(object): + pass + import __main__ __main__.C = C C.__module__ = "__main__" @@ -100,6 +105,8 @@ __main__.D = D D.__module__ = "__main__" __main__.E = E E.__module__ = "__main__" +__main__.H = H +H.__module__ = "__main__" class myint(int): def __init__(self, x): @@ -491,31 +498,52 @@ def create_data(): x.append(5) return x + class AbstractPickleTests(unittest.TestCase): # Subclass must define self.dumps, self.loads. + optimized = False + _testdata = create_data() def setUp(self): pass + def assert_is_copy(self, obj, objcopy, msg=None): + """Utility method to verify if two objects are copies of each others. + """ + if msg is None: + msg = "{!r} is not a copy of {!r}".format(obj, objcopy) + self.assertEqual(obj, objcopy, msg=msg) + self.assertIs(type(obj), type(objcopy), msg=msg) + if hasattr(obj, '__dict__'): + self.assertDictEqual(obj.__dict__, objcopy.__dict__, msg=msg) + self.assertIsNot(obj.__dict__, objcopy.__dict__, msg=msg) + if hasattr(obj, '__slots__'): + self.assertListEqual(obj.__slots__, objcopy.__slots__, msg=msg) + for slot in obj.__slots__: + self.assertEqual( + hasattr(obj, slot), hasattr(objcopy, slot), msg=msg) + self.assertEqual(getattr(obj, slot, None), + getattr(objcopy, slot, None), msg=msg) + def test_misc(self): # test various datatypes not tested by testdata for proto in protocols: x = myint(4) s = self.dumps(x, proto) y = self.loads(s) - self.assertEqual(x, y) + self.assert_is_copy(x, y) x = (1, ()) s = self.dumps(x, proto) y = self.loads(s) - self.assertEqual(x, y) + self.assert_is_copy(x, y) x = initarg(1, x) s = self.dumps(x, proto) y = self.loads(s) - self.assertEqual(x, y) + self.assert_is_copy(x, y) # XXX test __reduce__ protocol? @@ -524,16 +552,16 @@ class AbstractPickleTests(unittest.TestCase): for proto in protocols: s = self.dumps(expected, proto) got = self.loads(s) - self.assertEqual(expected, got) + self.assert_is_copy(expected, got) def test_load_from_data0(self): - self.assertEqual(self._testdata, self.loads(DATA0)) + self.assert_is_copy(self._testdata, self.loads(DATA0)) def test_load_from_data1(self): - self.assertEqual(self._testdata, self.loads(DATA1)) + self.assert_is_copy(self._testdata, self.loads(DATA1)) def test_load_from_data2(self): - self.assertEqual(self._testdata, self.loads(DATA2)) + self.assert_is_copy(self._testdata, self.loads(DATA2)) def test_load_classic_instance(self): # See issue5180. Test loading 2.x pickles that @@ -555,7 +583,7 @@ class AbstractPickleTests(unittest.TestCase): b"X\n" b"p0\n" b"(dp1\nb.").replace(b'X', xname) - self.assertEqual(X(*args), self.loads(pickle0)) + self.assert_is_copy(X(*args), self.loads(pickle0)) # Protocol 1 (binary mode pickle) """ @@ -572,7 +600,7 @@ class AbstractPickleTests(unittest.TestCase): pickle1 = (b'(c__main__\n' b'X\n' b'q\x00oq\x01}q\x02b.').replace(b'X', xname) - self.assertEqual(X(*args), self.loads(pickle1)) + self.assert_is_copy(X(*args), self.loads(pickle1)) # Protocol 2 (pickle2 = b'\x80\x02' + pickle1) """ @@ -590,7 +618,7 @@ class AbstractPickleTests(unittest.TestCase): pickle2 = (b'\x80\x02(c__main__\n' b'X\n' b'q\x00oq\x01}q\x02b.').replace(b'X', xname) - self.assertEqual(X(*args), self.loads(pickle2)) + self.assert_is_copy(X(*args), self.loads(pickle2)) # There are gratuitous differences between pickles produced by # pickle and cPickle, largely because cPickle starts PUT indices at @@ -615,6 +643,7 @@ class AbstractPickleTests(unittest.TestCase): for proto in protocols: s = self.dumps(l, proto) x = self.loads(s) + self.assertIsInstance(x, list) self.assertEqual(len(x), 1) self.assertTrue(x is x[0]) @@ -624,6 +653,7 @@ class AbstractPickleTests(unittest.TestCase): for proto in protocols: s = self.dumps(t, proto) x = self.loads(s) + self.assertIsInstance(x, tuple) self.assertEqual(len(x), 1) self.assertEqual(len(x[0]), 1) self.assertTrue(x is x[0][0]) @@ -634,15 +664,39 @@ class AbstractPickleTests(unittest.TestCase): for proto in protocols: s = self.dumps(d, proto) x = self.loads(s) + self.assertIsInstance(x, dict) self.assertEqual(list(x.keys()), [1]) self.assertTrue(x[1] is x) + def test_recursive_set(self): + h = H() + y = set({h}) + h.attr = y + for proto in protocols: + s = self.dumps(y, proto) + x = self.loads(s) + self.assertIsInstance(x, set) + self.assertIs(list(x)[0].attr, x) + self.assertEqual(len(x), 1) + + def test_recursive_frozenset(self): + h = H() + y = frozenset({h}) + h.attr = y + for proto in protocols: + s = self.dumps(y, proto) + x = self.loads(s) + self.assertIsInstance(x, frozenset) + self.assertIs(list(x)[0].attr, x) + self.assertEqual(len(x), 1) + def test_recursive_inst(self): i = C() i.attr = i for proto in protocols: s = self.dumps(i, proto) x = self.loads(s) + self.assertIsInstance(x, C) self.assertEqual(dir(x), dir(i)) self.assertIs(x.attr, x) @@ -655,6 +709,7 @@ class AbstractPickleTests(unittest.TestCase): for proto in protocols: s = self.dumps(l, proto) x = self.loads(s) + self.assertIsInstance(x, list) self.assertEqual(len(x), 1) self.assertEqual(dir(x[0]), dir(i)) self.assertEqual(list(x[0].attr.keys()), [1]) @@ -662,31 +717,8 @@ class AbstractPickleTests(unittest.TestCase): def test_get(self): self.assertRaises(KeyError, self.loads, b'g0\np0') - self.assertEqual(self.loads(b'((Kdtp0\nh\x00l.))'), [(100,), (100,)]) - - def test_insecure_strings(self): - # XXX Some of these tests are temporarily disabled - insecure = [b"abc", b"2 + 2", # not quoted - ## b"'abc' + 'def'", # not a single quoted string - b"'abc", # quote is not closed - b"'abc\"", # open quote and close quote don't match - b"'abc' ?", # junk after close quote - b"'\\'", # trailing backslash - # Variations on issue #17710 - b"'", - b'"', - b"' ", - b"' ", - b"' ", - b"' ", - b'" ', - # some tests of the quoting rules - ## b"'abc\"\''", - ## b"'\\\\a\'\'\'\\\'\\\\\''", - ] - for b in insecure: - buf = b"S" + b + b"\012p0\012." - self.assertRaises(ValueError, self.loads, buf) + self.assert_is_copy([(100,), (100,)], + self.loads(b'((Kdtp0\nh\x00l.))')) def test_unicode(self): endcases = ['', '<\\u>', '<\\\u1234>', '<\n>', @@ -697,26 +729,26 @@ class AbstractPickleTests(unittest.TestCase): for u in endcases: p = self.dumps(u, proto) u2 = self.loads(p) - self.assertEqual(u2, u) + self.assert_is_copy(u, u2) def test_unicode_high_plane(self): t = '\U00012345' for proto in protocols: p = self.dumps(t, proto) t2 = self.loads(p) - self.assertEqual(t2, t) + self.assert_is_copy(t, t2) def test_bytes(self): for proto in protocols: for s in b'', b'xyz', b'xyz'*100: p = self.dumps(s, proto) - self.assertEqual(self.loads(p), s) + self.assert_is_copy(s, self.loads(p)) for s in [bytes([i]) for i in range(256)]: p = self.dumps(s, proto) - self.assertEqual(self.loads(p), s) + self.assert_is_copy(s, self.loads(p)) for s in [bytes([i, i]) for i in range(256)]: p = self.dumps(s, proto) - self.assertEqual(self.loads(p), s) + self.assert_is_copy(s, self.loads(p)) def test_ints(self): import sys @@ -726,14 +758,14 @@ class AbstractPickleTests(unittest.TestCase): for expected in (-n, n): s = self.dumps(expected, proto) n2 = self.loads(s) - self.assertEqual(expected, n2) + self.assert_is_copy(expected, n2) n = n >> 1 def test_maxint64(self): maxint64 = (1 << 63) - 1 data = b'I' + str(maxint64).encode("ascii") + b'\n.' got = self.loads(data) - self.assertEqual(got, maxint64) + self.assert_is_copy(maxint64, got) # Try too with a bogus literal. data = b'I' + str(maxint64).encode("ascii") + b'JUNK\n.' @@ -748,7 +780,7 @@ class AbstractPickleTests(unittest.TestCase): for n in npos, -npos: pickle = self.dumps(n, proto) got = self.loads(pickle) - self.assertEqual(n, got) + self.assert_is_copy(n, got) # Try a monster. This is quadratic-time in protos 0 & 1, so don't # bother with those. nbase = int("deadbeeffeedface", 16) @@ -756,6 +788,10 @@ class AbstractPickleTests(unittest.TestCase): for n in nbase, -nbase: p = self.dumps(n, 2) got = self.loads(p) + # assert_is_copy is very expensive here as it precomputes + # a failure message by computing the repr() of n and got, + # we just do the check ourselves. + self.assertIs(type(got), int) self.assertEqual(n, got) def test_float(self): @@ -766,7 +802,7 @@ class AbstractPickleTests(unittest.TestCase): for value in test_values: pickle = self.dumps(value, proto) got = self.loads(pickle) - self.assertEqual(value, got) + self.assert_is_copy(value, got) @run_with_locale('LC_ALL', 'de_DE', 'fr_FR') def test_float_format(self): @@ -774,10 +810,18 @@ class AbstractPickleTests(unittest.TestCase): self.assertEqual(self.dumps(1.2, 0)[0:3], b'F1.') def test_reduce(self): - pass + for proto in protocols: + inst = AAA() + dumped = self.dumps(inst, proto) + loaded = self.loads(dumped) + self.assertEqual(loaded, REDUCE_A) def test_getinitargs(self): - pass + for proto in protocols: + inst = initarg(1, 2) + dumped = self.dumps(inst, proto) + loaded = self.loads(dumped) + self.assert_is_copy(inst, loaded) def test_pop_empty_stack(self): # Test issue7455 @@ -798,6 +842,7 @@ class AbstractPickleTests(unittest.TestCase): s = self.dumps(a, proto) b = self.loads(s) self.assertEqual(a, b) + self.assertIs(type(a), type(b)) def test_structseq(self): import time @@ -807,29 +852,29 @@ class AbstractPickleTests(unittest.TestCase): for proto in protocols: s = self.dumps(t, proto) u = self.loads(s) - self.assertEqual(t, u) + self.assert_is_copy(t, u) if hasattr(os, "stat"): t = os.stat(os.curdir) s = self.dumps(t, proto) u = self.loads(s) - self.assertEqual(t, u) + self.assert_is_copy(t, u) if hasattr(os, "statvfs"): t = os.statvfs(os.curdir) s = self.dumps(t, proto) u = self.loads(s) - self.assertEqual(t, u) + self.assert_is_copy(t, u) def test_ellipsis(self): for proto in protocols: s = self.dumps(..., proto) u = self.loads(s) - self.assertEqual(..., u) + self.assertIs(..., u) def test_notimplemented(self): for proto in protocols: s = self.dumps(NotImplemented, proto) u = self.loads(s) - self.assertEqual(NotImplemented, u) + self.assertIs(NotImplemented, u) def test_singleton_types(self): # Issue #6477: Test that types of built-in singletons can be pickled. @@ -843,21 +888,21 @@ class AbstractPickleTests(unittest.TestCase): # Tests for protocol 2 def test_proto(self): - build_none = pickle.NONE + pickle.STOP for proto in protocols: - expected = build_none + pickled = self.dumps(None, proto) if proto >= 2: - expected = pickle.PROTO + bytes([proto]) + expected - p = self.dumps(None, proto) - self.assertEqual(p, expected) + proto_header = pickle.PROTO + bytes([proto]) + self.assertTrue(pickled.startswith(proto_header)) + else: + self.assertEqual(count_opcode(pickle.PROTO, pickled), 0) oob = protocols[-1] + 1 # a future protocol + build_none = pickle.NONE + pickle.STOP badpickle = pickle.PROTO + bytes([oob]) + build_none try: self.loads(badpickle) - except ValueError as detail: - self.assertTrue(str(detail).startswith( - "unsupported pickle protocol")) + except ValueError as err: + self.assertIn("unsupported pickle protocol", str(err)) else: self.fail("expected bad protocol number to raise ValueError") @@ -866,7 +911,7 @@ class AbstractPickleTests(unittest.TestCase): for proto in protocols: s = self.dumps(x, proto) y = self.loads(s) - self.assertEqual(x, y) + self.assert_is_copy(x, y) self.assertEqual(opcode_in_pickle(pickle.LONG1, s), proto >= 2) def test_long4(self): @@ -874,7 +919,7 @@ class AbstractPickleTests(unittest.TestCase): for proto in protocols: s = self.dumps(x, proto) y = self.loads(s) - self.assertEqual(x, y) + self.assert_is_copy(x, y) self.assertEqual(opcode_in_pickle(pickle.LONG4, s), proto >= 2) def test_short_tuples(self): @@ -912,9 +957,9 @@ class AbstractPickleTests(unittest.TestCase): for x in a, b, c, d, e: s = self.dumps(x, proto) y = self.loads(s) - self.assertEqual(x, y, (proto, x, s, y)) - expected = expected_opcode[proto, len(x)] - self.assertEqual(opcode_in_pickle(expected, s), True) + self.assert_is_copy(x, y) + expected = expected_opcode[min(proto, 3), len(x)] + self.assertTrue(opcode_in_pickle(expected, s)) def test_singletons(self): # Map (proto, singleton) to expected opcode. @@ -938,8 +983,8 @@ class AbstractPickleTests(unittest.TestCase): s = self.dumps(x, proto) y = self.loads(s) self.assertTrue(x is y, (proto, x, s, y)) - expected = expected_opcode[proto, x] - self.assertEqual(opcode_in_pickle(expected, s), True) + expected = expected_opcode[min(proto, 3), x] + self.assertTrue(opcode_in_pickle(expected, s)) def test_newobj_tuple(self): x = MyTuple([1, 2, 3]) @@ -948,8 +993,7 @@ class AbstractPickleTests(unittest.TestCase): for proto in protocols: s = self.dumps(x, proto) y = self.loads(s) - self.assertEqual(tuple(x), tuple(y)) - self.assertEqual(x.__dict__, y.__dict__) + self.assert_is_copy(x, y) def test_newobj_list(self): x = MyList([1, 2, 3]) @@ -958,8 +1002,7 @@ class AbstractPickleTests(unittest.TestCase): for proto in protocols: s = self.dumps(x, proto) y = self.loads(s) - self.assertEqual(list(x), list(y)) - self.assertEqual(x.__dict__, y.__dict__) + self.assert_is_copy(x, y) def test_newobj_generic(self): for proto in protocols: @@ -970,6 +1013,7 @@ class AbstractPickleTests(unittest.TestCase): s = self.dumps(x, proto) y = self.loads(s) detail = (proto, C, B, x, y, type(y)) + self.assert_is_copy(x, y) # XXX revisit self.assertEqual(B(x), B(y), detail) self.assertEqual(x.__dict__, y.__dict__, detail) @@ -1008,11 +1052,10 @@ class AbstractPickleTests(unittest.TestCase): s1 = self.dumps(x, 1) self.assertIn(__name__.encode("utf-8"), s1) self.assertIn(b"MyList", s1) - self.assertEqual(opcode_in_pickle(opcode, s1), False) + self.assertFalse(opcode_in_pickle(opcode, s1)) y = self.loads(s1) - self.assertEqual(list(x), list(y)) - self.assertEqual(x.__dict__, y.__dict__) + self.assert_is_copy(x, y) # Dump using protocol 2 for test. s2 = self.dumps(x, 2) @@ -1021,9 +1064,7 @@ class AbstractPickleTests(unittest.TestCase): self.assertEqual(opcode_in_pickle(opcode, s2), True, repr(s2)) y = self.loads(s2) - self.assertEqual(list(x), list(y)) - self.assertEqual(x.__dict__, y.__dict__) - + self.assert_is_copy(x, y) finally: e.restore() @@ -1047,7 +1088,7 @@ class AbstractPickleTests(unittest.TestCase): for proto in protocols: s = self.dumps(x, proto) y = self.loads(s) - self.assertEqual(x, y) + self.assert_is_copy(x, y) num_appends = count_opcode(pickle.APPENDS, s) self.assertEqual(num_appends, proto > 0) @@ -1056,7 +1097,7 @@ class AbstractPickleTests(unittest.TestCase): for proto in protocols: s = self.dumps(x, proto) y = self.loads(s) - self.assertEqual(x, y) + self.assert_is_copy(x, y) num_appends = count_opcode(pickle.APPENDS, s) if proto == 0: self.assertEqual(num_appends, 0) @@ -1070,7 +1111,7 @@ class AbstractPickleTests(unittest.TestCase): s = self.dumps(x, proto) self.assertIsInstance(s, bytes_types) y = self.loads(s) - self.assertEqual(x, y) + self.assert_is_copy(x, y) num_setitems = count_opcode(pickle.SETITEMS, s) self.assertEqual(num_setitems, proto > 0) @@ -1079,22 +1120,49 @@ class AbstractPickleTests(unittest.TestCase): for proto in protocols: s = self.dumps(x, proto) y = self.loads(s) - self.assertEqual(x, y) + self.assert_is_copy(x, y) num_setitems = count_opcode(pickle.SETITEMS, s) if proto == 0: self.assertEqual(num_setitems, 0) else: self.assertTrue(num_setitems >= 2) + def test_set_chunking(self): + n = 10 # too small to chunk + x = set(range(n)) + for proto in protocols: + s = self.dumps(x, proto) + y = self.loads(s) + self.assert_is_copy(x, y) + num_additems = count_opcode(pickle.ADDITEMS, s) + if proto < 4: + self.assertEqual(num_additems, 0) + else: + self.assertEqual(num_additems, 1) + + n = 2500 # expect at least two chunks when proto >= 4 + x = set(range(n)) + for proto in protocols: + s = self.dumps(x, proto) + y = self.loads(s) + self.assert_is_copy(x, y) + num_additems = count_opcode(pickle.ADDITEMS, s) + if proto < 4: + self.assertEqual(num_additems, 0) + else: + self.assertGreaterEqual(num_additems, 2) + def test_simple_newobj(self): x = object.__new__(SimpleNewObj) # avoid __init__ x.abc = 666 for proto in protocols: s = self.dumps(x, proto) - self.assertEqual(opcode_in_pickle(pickle.NEWOBJ, s), proto >= 2) + self.assertEqual(opcode_in_pickle(pickle.NEWOBJ, s), + 2 <= proto < 4) + self.assertEqual(opcode_in_pickle(pickle.NEWOBJ_EX, s), + proto >= 4) y = self.loads(s) # will raise TypeError if __init__ called - self.assertEqual(y.abc, 666) - self.assertEqual(x.__dict__, y.__dict__) + self.assert_is_copy(x, y) def test_newobj_list_slots(self): x = SlotList([1, 2, 3]) @@ -1102,10 +1170,7 @@ class AbstractPickleTests(unittest.TestCase): x.bar = "hello" s = self.dumps(x, 2) y = self.loads(s) - self.assertEqual(list(x), list(y)) - self.assertEqual(x.__dict__, y.__dict__) - self.assertEqual(x.foo, y.foo) - self.assertEqual(x.bar, y.bar) + self.assert_is_copy(x, y) def test_reduce_overrides_default_reduce_ex(self): for proto in protocols: @@ -1154,11 +1219,10 @@ class AbstractPickleTests(unittest.TestCase): @no_tracing def test_bad_getattr(self): + # Issue #3514: crash when there is an infinite loop in __getattr__ x = BadGetattr() - for proto in 0, 1: + for proto in protocols: self.assertRaises(RuntimeError, self.dumps, x, proto) - # protocol 2 don't raise a RuntimeError. - d = self.dumps(x, 2) def test_reduce_bad_iterator(self): # Issue4176: crash when 4th and 5th items of __reduce__() @@ -1191,11 +1255,10 @@ class AbstractPickleTests(unittest.TestCase): obj = [dict(large_dict), dict(large_dict), dict(large_dict)] for proto in protocols: - dumped = self.dumps(obj, proto) - loaded = self.loads(dumped) - self.assertEqual(loaded, obj, - "Failed protocol %d: %r != %r" - % (proto, obj, loaded)) + with self.subTest(proto=proto): + dumped = self.dumps(obj, proto) + loaded = self.loads(dumped) + self.assert_is_copy(obj, loaded) def test_attribute_name_interning(self): # Test that attribute names of pickled objects are interned when @@ -1248,6 +1311,35 @@ class AbstractPickleTests(unittest.TestCase): dumped = self.dumps(set([3]), 2) self.assertEqual(dumped, DATA6) + def test_load_python2_str_as_bytes(self): + # From Python 2: pickle.dumps('a\x00\xa0', protocol=0) + self.assertEqual(self.loads(b"S'a\\x00\\xa0'\n.", + encoding="bytes"), b'a\x00\xa0') + # From Python 2: pickle.dumps('a\x00\xa0', protocol=1) + self.assertEqual(self.loads(b'U\x03a\x00\xa0.', + encoding="bytes"), b'a\x00\xa0') + # From Python 2: pickle.dumps('a\x00\xa0', protocol=2) + self.assertEqual(self.loads(b'\x80\x02U\x03a\x00\xa0.', + encoding="bytes"), b'a\x00\xa0') + + def test_load_python2_unicode_as_str(self): + # From Python 2: pickle.dumps(u'π', protocol=0) + self.assertEqual(self.loads(b'V\\u03c0\n.', + encoding='bytes'), 'π') + # From Python 2: pickle.dumps(u'π', protocol=1) + self.assertEqual(self.loads(b'X\x02\x00\x00\x00\xcf\x80.', + encoding="bytes"), 'π') + # From Python 2: pickle.dumps(u'π', protocol=2) + self.assertEqual(self.loads(b'\x80\x02X\x02\x00\x00\x00\xcf\x80.', + encoding="bytes"), 'π') + + def test_load_long_python2_str_as_bytes(self): + # From Python 2: pickle.dumps('x' * 300, protocol=1) + self.assertEqual(self.loads(pickle.BINSTRING + + struct.pack("<I", 300) + + b'x' * 300 + pickle.STOP, + encoding='bytes'), b'x' * 300) + def test_large_pickles(self): # Test the correctness of internal buffering routines when handling # large data. @@ -1266,11 +1358,14 @@ class AbstractPickleTests(unittest.TestCase): def test_int_pickling_efficiency(self): # Test compacity of int representation (see issue #12744) for proto in protocols: - sizes = [len(self.dumps(2**n, proto)) for n in range(70)] - # the size function is monotonic - self.assertEqual(sorted(sizes), sizes) - if proto >= 2: - self.assertLessEqual(sizes[-1], 14) + with self.subTest(proto=proto): + pickles = [self.dumps(2**n, proto) for n in range(70)] + sizes = list(map(len, pickles)) + # the size function is monotonic + self.assertEqual(sorted(sizes), sizes) + if proto >= 2: + for p in pickles: + self.assertFalse(opcode_in_pickle(pickle.LONG, p)) def check_negative_32b_binXXX(self, dumped): if sys.maxsize > 2**32: @@ -1301,6 +1396,35 @@ class AbstractPickleTests(unittest.TestCase): dumped = b'\x80\x03X\x01\x00\x00\x00ar\xff\xff\xff\xff.' self.assertRaises(ValueError, self.loads, dumped) + def test_badly_escaped_string(self): + self.assertRaises(ValueError, self.loads, b"S'\\'\n.") + + def test_badly_quoted_string(self): + # Issue #17710 + badpickles = [b"S'\n.", + b'S"\n.', + b'S\' \n.', + b'S" \n.', + b'S\'"\n.', + b'S"\'\n.', + b"S' ' \n.", + b'S" " \n.', + b"S ''\n.", + b'S ""\n.', + b'S \n.', + b'S\n.', + b'S.'] + for p in badpickles: + self.assertRaises(pickle.UnpicklingError, self.loads, p) + + def test_correctly_quoted_string(self): + goodpickles = [(b"S''\n.", ''), + (b'S""\n.', ''), + (b'S"\\n"\n.', '\n'), + (b"S'\\n'\n.", '\n')] + for p, expected in goodpickles: + self.assertEqual(self.loads(p), expected) + def _check_pickling_with_opcode(self, obj, opcode, proto): pickled = self.dumps(obj, proto) self.assertTrue(opcode_in_pickle(opcode, pickled)) @@ -1324,6 +1448,194 @@ class AbstractPickleTests(unittest.TestCase): else: self._check_pickling_with_opcode(obj, pickle.SETITEMS, proto) + # Exercise framing (proto >= 4) for significant workloads + + FRAME_SIZE_TARGET = 64 * 1024 + + def check_frame_opcodes(self, pickled): + """ + Check the arguments of FRAME opcodes in a protocol 4+ pickle. + """ + frame_opcode_size = 9 + last_arg = last_pos = None + for op, arg, pos in pickletools.genops(pickled): + if op.name != 'FRAME': + continue + if last_pos is not None: + # The previous frame's size should be equal to the number + # of bytes up to the current frame. + frame_size = pos - last_pos - frame_opcode_size + self.assertEqual(frame_size, last_arg) + last_arg, last_pos = arg, pos + # The last frame's size should be equal to the number of bytes up + # to the pickle's end. + frame_size = len(pickled) - last_pos - frame_opcode_size + self.assertEqual(frame_size, last_arg) + + def test_framing_many_objects(self): + obj = list(range(10**5)) + for proto in range(4, pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(proto=proto): + pickled = self.dumps(obj, proto) + unpickled = self.loads(pickled) + self.assertEqual(obj, unpickled) + bytes_per_frame = (len(pickled) / + count_opcode(pickle.FRAME, pickled)) + self.assertGreater(bytes_per_frame, + self.FRAME_SIZE_TARGET / 2) + self.assertLessEqual(bytes_per_frame, + self.FRAME_SIZE_TARGET * 1) + self.check_frame_opcodes(pickled) + + def test_framing_large_objects(self): + N = 1024 * 1024 + obj = [b'x' * N, b'y' * N, b'z' * N] + for proto in range(4, pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(proto=proto): + pickled = self.dumps(obj, proto) + unpickled = self.loads(pickled) + self.assertEqual(obj, unpickled) + n_frames = count_opcode(pickle.FRAME, pickled) + self.assertGreaterEqual(n_frames, len(obj)) + self.check_frame_opcodes(pickled) + + def test_optional_frames(self): + if pickle.HIGHEST_PROTOCOL < 4: + return + + def remove_frames(pickled, keep_frame=None): + """Remove frame opcodes from the given pickle.""" + frame_starts = [] + # 1 byte for the opcode and 8 for the argument + frame_opcode_size = 9 + for opcode, _, pos in pickletools.genops(pickled): + if opcode.name == 'FRAME': + frame_starts.append(pos) + + newpickle = bytearray() + last_frame_end = 0 + for i, pos in enumerate(frame_starts): + if keep_frame and keep_frame(i): + continue + newpickle += pickled[last_frame_end:pos] + last_frame_end = pos + frame_opcode_size + newpickle += pickled[last_frame_end:] + return newpickle + + frame_size = self.FRAME_SIZE_TARGET + num_frames = 20 + obj = [bytes([i]) * frame_size for i in range(num_frames)] + + for proto in range(4, pickle.HIGHEST_PROTOCOL + 1): + pickled = self.dumps(obj, proto) + + frameless_pickle = remove_frames(pickled) + self.assertEqual(count_opcode(pickle.FRAME, frameless_pickle), 0) + self.assertEqual(obj, self.loads(frameless_pickle)) + + some_frames_pickle = remove_frames(pickled, lambda i: i % 2) + self.assertLess(count_opcode(pickle.FRAME, some_frames_pickle), + count_opcode(pickle.FRAME, pickled)) + self.assertEqual(obj, self.loads(some_frames_pickle)) + + def test_nested_names(self): + global Nested + class Nested: + class A: + class B: + class C: + pass + + for proto in range(4, pickle.HIGHEST_PROTOCOL + 1): + for obj in [Nested.A, Nested.A.B, Nested.A.B.C]: + with self.subTest(proto=proto, obj=obj): + unpickled = self.loads(self.dumps(obj, proto)) + self.assertIs(obj, unpickled) + + def test_py_methods(self): + global PyMethodsTest + class PyMethodsTest: + @staticmethod + def cheese(): + return "cheese" + @classmethod + def wine(cls): + assert cls is PyMethodsTest + return "wine" + def biscuits(self): + assert isinstance(self, PyMethodsTest) + return "biscuits" + class Nested: + "Nested class" + @staticmethod + def ketchup(): + return "ketchup" + @classmethod + def maple(cls): + assert cls is PyMethodsTest.Nested + return "maple" + def pie(self): + assert isinstance(self, PyMethodsTest.Nested) + return "pie" + + py_methods = ( + PyMethodsTest.cheese, + PyMethodsTest.wine, + PyMethodsTest().biscuits, + PyMethodsTest.Nested.ketchup, + PyMethodsTest.Nested.maple, + PyMethodsTest.Nested().pie + ) + py_unbound_methods = ( + (PyMethodsTest.biscuits, PyMethodsTest), + (PyMethodsTest.Nested.pie, PyMethodsTest.Nested) + ) + for proto in range(4, pickle.HIGHEST_PROTOCOL + 1): + for method in py_methods: + with self.subTest(proto=proto, method=method): + unpickled = self.loads(self.dumps(method, proto)) + self.assertEqual(method(), unpickled()) + for method, cls in py_unbound_methods: + obj = cls() + with self.subTest(proto=proto, method=method): + unpickled = self.loads(self.dumps(method, proto)) + self.assertEqual(method(obj), unpickled(obj)) + + def test_c_methods(self): + global Subclass + class Subclass(tuple): + class Nested(str): + pass + + c_methods = ( + # bound built-in method + ("abcd".index, ("c",)), + # unbound built-in method + (str.index, ("abcd", "c")), + # bound "slot" method + ([1, 2, 3].__len__, ()), + # unbound "slot" method + (list.__len__, ([1, 2, 3],)), + # bound "coexist" method + ({1, 2}.__contains__, (2,)), + # unbound "coexist" method + (set.__contains__, ({1, 2}, 2)), + # built-in class method + (dict.fromkeys, (("a", 1), ("b", 2))), + # built-in static method + (bytearray.maketrans, (b"abc", b"xyz")), + # subclass methods + (Subclass([1,2,2]).count, (2,)), + (Subclass.count, (Subclass([1,2,2]), 2)), + (Subclass.Nested("sweet").count, ("e",)), + (Subclass.Nested.count, (Subclass.Nested("sweet"), "e")), + ) + for proto in range(4, pickle.HIGHEST_PROTOCOL + 1): + for method, args in c_methods: + with self.subTest(proto=proto, method=method): + unpickled = self.loads(self.dumps(method, proto)) + self.assertEqual(method(*args), unpickled(*args)) + class BigmemPickleTests(unittest.TestCase): @@ -1336,8 +1648,9 @@ class BigmemPickleTests(unittest.TestCase): for proto in protocols: if proto < 2: continue - with self.assertRaises((ValueError, OverflowError)): - self.dumps(data, protocol=proto) + with self.subTest(proto=proto): + with self.assertRaises((ValueError, OverflowError)): + self.dumps(data, protocol=proto) finally: data = None @@ -1352,24 +1665,44 @@ class BigmemPickleTests(unittest.TestCase): for proto in protocols: if proto < 3: continue - try: - pickled = self.dumps(data, protocol=proto) - self.assertTrue(b"abcd" in pickled[:15]) - self.assertTrue(b"abcd" in pickled[-15:]) - finally: - pickled = None + with self.subTest(proto=proto): + try: + pickled = self.dumps(data, protocol=proto) + header = (pickle.BINBYTES + + struct.pack("<I", len(data))) + data_start = pickled.index(data) + self.assertEqual( + header, + pickled[data_start-len(header):data_start]) + finally: + pickled = None finally: data = None @bigmemtest(size=_4G, memuse=2.5, dry_run=False) def test_huge_bytes_64b(self, size): - data = b"a" * size + data = b"acbd" * (size // 4) try: for proto in protocols: if proto < 3: continue - with self.assertRaises((ValueError, OverflowError)): - self.dumps(data, protocol=proto) + with self.subTest(proto=proto): + if proto == 3: + # Protocol 3 does not support large bytes objects. + # Verify that we do not crash when processing one. + with self.assertRaises((ValueError, OverflowError)): + self.dumps(data, protocol=proto) + continue + try: + pickled = self.dumps(data, protocol=proto) + header = (pickle.BINBYTES8 + + struct.pack("<Q", len(data))) + data_start = pickled.index(data) + self.assertEqual( + header, + pickled[data_start-len(header):data_start]) + finally: + pickled = None finally: data = None @@ -1381,27 +1714,52 @@ class BigmemPickleTests(unittest.TestCase): data = "abcd" * (size // 4) try: for proto in protocols: - try: - pickled = self.dumps(data, protocol=proto) - self.assertTrue(b"abcd" in pickled[:15]) - self.assertTrue(b"abcd" in pickled[-15:]) - finally: - pickled = None + if proto == 0: + continue + with self.subTest(proto=proto): + try: + pickled = self.dumps(data, protocol=proto) + header = (pickle.BINUNICODE + + struct.pack("<I", len(data))) + data_start = pickled.index(b'abcd') + self.assertEqual( + header, + pickled[data_start-len(header):data_start]) + self.assertEqual((pickled.rindex(b"abcd") + len(b"abcd") - + pickled.index(b"abcd")), len(data)) + finally: + pickled = None finally: data = None - # BINUNICODE (protocols 1, 2 and 3) cannot carry more than - # 2**32 - 1 bytes of utf-8 encoded unicode. + # BINUNICODE (protocols 1, 2 and 3) cannot carry more than 2**32 - 1 bytes + # of utf-8 encoded unicode. BINUNICODE8 (protocol 4) supports these huge + # unicode strings however. @bigmemtest(size=_4G, memuse=8, dry_run=False) def test_huge_str_64b(self, size): - data = "a" * size + data = "abcd" * (size // 4) try: for proto in protocols: if proto == 0: continue - with self.assertRaises((ValueError, OverflowError)): - self.dumps(data, protocol=proto) + with self.subTest(proto=proto): + if proto < 4: + with self.assertRaises((ValueError, OverflowError)): + self.dumps(data, protocol=proto) + continue + try: + pickled = self.dumps(data, protocol=proto) + header = (pickle.BINUNICODE8 + + struct.pack("<Q", len(data))) + data_start = pickled.index(b'abcd') + self.assertEqual( + header, + pickled[data_start-len(header):data_start]) + self.assertEqual((pickled.rindex(b"abcd") + len(b"abcd") - + pickled.index(b"abcd")), len(data)) + finally: + pickled = None finally: data = None @@ -1445,8 +1803,8 @@ class REX_five(object): return object.__reduce__(self) class REX_six(object): - """This class is used to check the 4th argument (list iterator) of the reduce - protocol. + """This class is used to check the 4th argument (list iterator) of + the reduce protocol. """ def __init__(self, items=None): self.items = items if items is not None else [] @@ -1458,8 +1816,8 @@ class REX_six(object): return type(self), (), None, iter(self.items), None class REX_seven(object): - """This class is used to check the 5th argument (dict iterator) of the reduce - protocol. + """This class is used to check the 5th argument (dict iterator) of + the reduce protocol. """ def __init__(self, table=None): self.table = table if table is not None else {} @@ -1497,10 +1855,16 @@ class MyList(list): class MyDict(dict): sample = {"a": 1, "b": 2} +class MySet(set): + sample = {"a", "b"} + +class MyFrozenSet(frozenset): + sample = frozenset({"a", "b"}) + myclasses = [MyInt, MyFloat, MyComplex, MyStr, MyUnicode, - MyTuple, MyList, MyDict] + MyTuple, MyList, MyDict, MySet, MyFrozenSet] class SlotList(MyList): @@ -1510,6 +1874,8 @@ class SimpleNewObj(object): def __init__(self, a, b, c): # raise an error, to make sure this isn't called raise TypeError("SimpleNewObj.__init__() didn't expect to get called") + def __eq__(self, other): + return self.__dict__ == other.__dict__ class BadGetattr: def __getattr__(self, key): @@ -1546,7 +1912,7 @@ class AbstractPickleModuleTests(unittest.TestCase): def test_highest_protocol(self): # Of course this needs to be changed when HIGHEST_PROTOCOL changes. - self.assertEqual(pickle.HIGHEST_PROTOCOL, 3) + self.assertEqual(pickle.HIGHEST_PROTOCOL, 4) def test_callapi(self): f = io.BytesIO() @@ -1731,22 +2097,23 @@ class AbstractPicklerUnpicklerObjectTests(unittest.TestCase): def _check_multiple_unpicklings(self, ioclass): for proto in protocols: - data1 = [(x, str(x)) for x in range(2000)] + [b"abcde", len] - f = ioclass() - pickler = self.pickler_class(f, protocol=proto) - pickler.dump(data1) - pickled = f.getvalue() - - N = 5 - f = ioclass(pickled * N) - unpickler = self.unpickler_class(f) - for i in range(N): - if f.seekable(): - pos = f.tell() - self.assertEqual(unpickler.load(), data1) - if f.seekable(): - self.assertEqual(f.tell(), pos + len(pickled)) - self.assertRaises(EOFError, unpickler.load) + with self.subTest(proto=proto): + data1 = [(x, str(x)) for x in range(2000)] + [b"abcde", len] + f = ioclass() + pickler = self.pickler_class(f, protocol=proto) + pickler.dump(data1) + pickled = f.getvalue() + + N = 5 + f = ioclass(pickled * N) + unpickler = self.unpickler_class(f) + for i in range(N): + if f.seekable(): + pos = f.tell() + self.assertEqual(unpickler.load(), data1) + if f.seekable(): + self.assertEqual(f.tell(), pos + len(pickled)) + self.assertRaises(EOFError, unpickler.load) def test_multiple_unpicklings_seekable(self): self._check_multiple_unpicklings(io.BytesIO) diff --git a/Lib/test/pycacert.pem b/Lib/test/pycacert.pem new file mode 100644 index 0000000000..09b1f3e08a --- /dev/null +++ b/Lib/test/pycacert.pem @@ -0,0 +1,78 @@ +Certificate: + Data: + Version: 3 (0x2) + Serial Number: 12723342612721443280 (0xb09264b1f2da21d0) + Signature Algorithm: sha1WithRSAEncryption + Issuer: C=XY, O=Python Software Foundation CA, CN=our-ca-server + Validity + Not Before: Jan 4 19:47:07 2013 GMT + Not After : Jan 2 19:47:07 2023 GMT + Subject: C=XY, O=Python Software Foundation CA, CN=our-ca-server + Subject Public Key Info: + Public Key Algorithm: rsaEncryption + Public-Key: (2048 bit) + Modulus: + 00:e7:de:e9:e3:0c:9f:00:b6:a1:fd:2b:5b:96:d2: + 6f:cc:e0:be:86:b9:20:5e:ec:03:7a:55:ab:ea:a4: + e9:f9:49:85:d2:66:d5:ed:c7:7a:ea:56:8e:2d:8f: + e7:42:e2:62:28:a9:9f:d6:1b:8e:eb:b5:b4:9c:9f: + 14:ab:df:e6:94:8b:76:1d:3e:6d:24:61:ed:0c:bf: + 00:8a:61:0c:df:5c:c8:36:73:16:00:cd:47:ba:6d: + a4:a4:74:88:83:23:0a:19:fc:09:a7:3c:4a:4b:d3: + e7:1d:2d:e4:ea:4c:54:21:f3:26:db:89:37:18:d4: + 02:bb:40:32:5f:a4:ff:2d:1c:f7:d4:bb:ec:8e:cf: + 5c:82:ac:e6:7c:08:6c:48:85:61:07:7f:25:e0:5c: + e0:bc:34:5f:e0:b9:04:47:75:c8:47:0b:8d:bc:d6: + c8:68:5f:33:83:62:d2:20:44:35:b1:ad:81:1a:8a: + cd:bc:35:b0:5c:8b:47:d6:18:e9:9c:18:97:cc:01: + 3c:29:cc:e8:1e:e4:e4:c1:b8:de:e7:c2:11:18:87: + 5a:93:34:d8:a6:25:f7:14:71:eb:e4:21:a2:d2:0f: + 2e:2e:d4:62:00:35:d3:d6:ef:5c:60:4b:4c:a9:14: + e2:dd:15:58:46:37:33:26:b7:e7:2e:5d:ed:42:e4: + c5:4d + Exponent: 65537 (0x10001) + X509v3 extensions: + X509v3 Subject Key Identifier: + BC:DD:62:D9:76:DA:1B:D2:54:6B:CF:E0:66:9B:1E:1E:7B:56:0C:0B + X509v3 Authority Key Identifier: + keyid:BC:DD:62:D9:76:DA:1B:D2:54:6B:CF:E0:66:9B:1E:1E:7B:56:0C:0B + + X509v3 Basic Constraints: + CA:TRUE + Signature Algorithm: sha1WithRSAEncryption + 7d:0a:f5:cb:8d:d3:5d:bd:99:8e:f8:2b:0f:ba:eb:c2:d9:a6: + 27:4f:2e:7b:2f:0e:64:d8:1c:35:50:4e:ee:fc:90:b9:8d:6d: + a8:c5:c6:06:b0:af:f3:2d:bf:3b:b8:42:07:dd:18:7d:6d:95: + 54:57:85:18:60:47:2f:eb:78:1b:f9:e8:17:fd:5a:0d:87:17: + 28:ac:4c:6a:e6:bc:29:f4:f4:55:70:29:42:de:85:ea:ab:6c: + 23:06:64:30:75:02:8e:53:bc:5e:01:33:37:cc:1e:cd:b8:a4: + fd:ca:e4:5f:65:3b:83:1c:86:f1:55:02:a0:3a:8f:db:91:b7: + 40:14:b4:e7:8d:d2:ee:73:ba:e3:e5:34:2d:bc:94:6f:4e:24: + 06:f7:5f:8b:0e:a7:8e:6b:de:5e:75:f4:32:9a:50:b1:44:33: + 9a:d0:05:e2:78:82:ff:db:da:8a:63:eb:a9:dd:d1:bf:a0:61: + ad:e3:9e:8a:24:5d:62:0e:e7:4c:91:7f:ef:df:34:36:3b:2f: + 5d:f5:84:b2:2f:c4:6d:93:96:1a:6f:30:28:f1:da:12:9a:64: + b4:40:33:1d:bd:de:2b:53:a8:ea:be:d6:bc:4e:96:f5:44:fb: + 32:18:ae:d5:1f:f6:69:af:b6:4e:7b:1d:58:ec:3b:a9:53:a3: + 5e:58:c8:9e +-----BEGIN CERTIFICATE----- +MIIDbTCCAlWgAwIBAgIJALCSZLHy2iHQMA0GCSqGSIb3DQEBBQUAME0xCzAJBgNV +BAYTAlhZMSYwJAYDVQQKDB1QeXRob24gU29mdHdhcmUgRm91bmRhdGlvbiBDQTEW +MBQGA1UEAwwNb3VyLWNhLXNlcnZlcjAeFw0xMzAxMDQxOTQ3MDdaFw0yMzAxMDIx +OTQ3MDdaME0xCzAJBgNVBAYTAlhZMSYwJAYDVQQKDB1QeXRob24gU29mdHdhcmUg +Rm91bmRhdGlvbiBDQTEWMBQGA1UEAwwNb3VyLWNhLXNlcnZlcjCCASIwDQYJKoZI +hvcNAQEBBQADggEPADCCAQoCggEBAOfe6eMMnwC2of0rW5bSb8zgvoa5IF7sA3pV +q+qk6flJhdJm1e3HeupWji2P50LiYiipn9Ybjuu1tJyfFKvf5pSLdh0+bSRh7Qy/ +AIphDN9cyDZzFgDNR7ptpKR0iIMjChn8Cac8SkvT5x0t5OpMVCHzJtuJNxjUArtA +Ml+k/y0c99S77I7PXIKs5nwIbEiFYQd/JeBc4Lw0X+C5BEd1yEcLjbzWyGhfM4Ni +0iBENbGtgRqKzbw1sFyLR9YY6ZwYl8wBPCnM6B7k5MG43ufCERiHWpM02KYl9xRx +6+QhotIPLi7UYgA109bvXGBLTKkU4t0VWEY3Mya35y5d7ULkxU0CAwEAAaNQME4w +HQYDVR0OBBYEFLzdYtl22hvSVGvP4GabHh57VgwLMB8GA1UdIwQYMBaAFLzdYtl2 +2hvSVGvP4GabHh57VgwLMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEFBQADggEB +AH0K9cuN0129mY74Kw+668LZpidPLnsvDmTYHDVQTu78kLmNbajFxgawr/Mtvzu4 +QgfdGH1tlVRXhRhgRy/reBv56Bf9Wg2HFyisTGrmvCn09FVwKULeheqrbCMGZDB1 +Ao5TvF4BMzfMHs24pP3K5F9lO4MchvFVAqA6j9uRt0AUtOeN0u5zuuPlNC28lG9O +JAb3X4sOp45r3l519DKaULFEM5rQBeJ4gv/b2opj66nd0b+gYa3jnookXWIO50yR +f+/fNDY7L131hLIvxG2TlhpvMCjx2hKaZLRAMx293itTqOq+1rxOlvVE+zIYrtUf +9mmvtk57HVjsO6lTo15YyJ4= +-----END CERTIFICATE----- diff --git a/Lib/test/pycakey.pem b/Lib/test/pycakey.pem new file mode 100644 index 0000000000..fc6effefb2 --- /dev/null +++ b/Lib/test/pycakey.pem @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDn3unjDJ8AtqH9 +K1uW0m/M4L6GuSBe7AN6VavqpOn5SYXSZtXtx3rqVo4tj+dC4mIoqZ/WG47rtbSc +nxSr3+aUi3YdPm0kYe0MvwCKYQzfXMg2cxYAzUe6baSkdIiDIwoZ/AmnPEpL0+cd +LeTqTFQh8ybbiTcY1AK7QDJfpP8tHPfUu+yOz1yCrOZ8CGxIhWEHfyXgXOC8NF/g +uQRHdchHC4281shoXzODYtIgRDWxrYEais28NbBci0fWGOmcGJfMATwpzOge5OTB +uN7nwhEYh1qTNNimJfcUcevkIaLSDy4u1GIANdPW71xgS0ypFOLdFVhGNzMmt+cu +Xe1C5MVNAgMBAAECggEBAJPM7QuUrPn4cLN/Ysd15lwTWn9oHDFFgkYFvCs66gXE +ju/6Kx2BjWE4wTJby09AHM/MqB0DvguT7Mf1Q2j3tPQ1HZowg8OwRDleuwp6KIls +jBbhL0Jdl/5HC67ktWvZ9wNvO/wFG1rQfT6FVajf9LUbWEaSZbOG2SLhHfsHorzu +xjTJaI3bQ/0+79B1exwk5ruwhzFRd/XpY8hls7D/RfPIuHDlBghkW3N59KFWrf5h +6bNEh2THm0+IyGcGqs0FD+QCOXyvsjwSUswqrr2ctLREOeDcd5ReUjSxYgjcJRrm +J7ceIY/+uwDJxw/OlnmBvF6pQMkKwYW2gFztu+g2t4UCgYEA/9yo01Exz4crxXsy +tAlnDJM++nZcm07rtFjTKHUfKY/cCgNTa8udM0svnfwlid/dpgLsI38gx04HHC1i +EZ4acz+ToIWedLxM0nq73//xeRWEazOvCz1mMTZaMldahTWAyzN8qVK2B/625Yy4 +wNYWyweBBwEB8MzaCs73spksXOsCgYEA5/7wvhiofYGFAfMuANeJIwDL2OtBnoOv +mVNfCmi3GC38fzwyi5ZpskWDiS2woJ+LQfs9Qu4EcZbUFLd7gbeOvb5gmFUtYope +LitUUKunIR18MkQ+mQDBpQPQPhk4QJP5reCbWkrfTu7b5o/iS41s6fBTFmuzhLcT +C71vFdCyeKcCgYAiCCqYeOtELDmBOeLDmaCQRqGQ1N96dOPbCBmF/xYXBCCDYG/f +HaUaJnz96YTgstsbcrYP/p/Qgqtlbw/lQf9IpwMuzbcG1ejt8g89OyDWNyt2ytgU +iaUnFJCos3/Byh0Iah/BsdOueo2/OJl2ZMOBW80orlSgv86cs2y037TL4wKBgQDm +OOyW+MlbowhnIvfoBfwlLEkefnej4nKD6WRLZBcue5Qyf355X06Mhsc9foXlH+6G +D9h/bswiHNdhp6N82rdgPGiHQx/CxiUoE/+b/nvgNO5mw6qLE2EXbG1e8pAMJcyE +bHw+YkawggDfELI036fRj5gki8SeUz8nS1nNgElbyQKBgCRDX9Jh+MwSLu4QBWdt +/fi+lv3K6kun/fI7EOV1vCV/j871tICu7pu5BrOLxAHqoVfU9AUX299/2KjCb5pv +kjogiUK6qWCWBlfuqDNWGCoUGt1rhznUva0nNjSMy5rinBhhjpROZC2pw48lOluP +UuvXsaPph7GTqPuy4Kab12YC +-----END PRIVATE KEY----- diff --git a/Lib/test/regrtest.py b/Lib/test/regrtest.py index ae62c6e7a0..c1c831ff86 100755 --- a/Lib/test/regrtest.py +++ b/Lib/test/regrtest.py @@ -1,11 +1,18 @@ #! /usr/bin/env python3 """ -Usage: +Script to run Python regression tests. +Run this script with -h or --help for documentation. +""" + +USAGE = """\ python -m test [options] [test_name1 [test_name2 ...]] python path/to/Lib/test/regrtest.py [options] [test_name1 [test_name2 ...]] +""" +DESCRIPTION = """\ +Run Python regression tests. If no arguments or options are provided, finds all files matching the pattern "test_*" in the Lib/test subdirectory and runs @@ -15,63 +22,10 @@ For more rigorous testing, it is useful to use the following command line: python -E -Wd -m test [options] [test_name1 ...] +""" - -Options: - --h/--help -- print this text and exit ---timeout TIMEOUT - -- dump the traceback and exit if a test takes more - than TIMEOUT seconds; disabled if TIMEOUT is negative - or equals to zero ---wait -- wait for user input, e.g., allow a debugger to be attached - -Verbosity - --v/--verbose -- run tests in verbose mode with output to stdout --w/--verbose2 -- re-run failed tests in verbose mode --W/--verbose3 -- display test output on failure --d/--debug -- print traceback for failed tests --q/--quiet -- no output unless one or more tests fail --o/--slow -- print the slowest 10 tests - --header -- print header with interpreter info - -Selecting tests - --r/--randomize -- randomize test execution order (see below) - --randseed -- pass a random seed to reproduce a previous random run --f/--fromfile -- read names of tests to run from a file (see below) --x/--exclude -- arguments are tests to *exclude* --s/--single -- single step through a set of tests (see below) --m/--match PAT -- match test cases and methods with glob pattern PAT --G/--failfast -- fail as soon as a test fails (only with -v or -W) --u/--use RES1,RES2,... - -- specify which special resource intensive tests to run --M/--memlimit LIMIT - -- run very large memory-consuming tests - --testdir DIR - -- execute test files in the specified directory (instead - of the Python stdlib test suite) - -Special runs - --l/--findleaks -- if GC is available detect tests that leak memory --L/--runleaks -- run the leaks(1) command just before exit --R/--huntrleaks RUNCOUNTS - -- search for reference leaks (needs debug build, v. slow) --j/--multiprocess PROCESSES - -- run PROCESSES processes at once --T/--coverage -- turn on code coverage tracing using the trace module --D/--coverdir DIRECTORY - -- Directory where coverage files are put --N/--nocoverdir -- Put coverage files alongside modules --t/--threshold THRESHOLD - -- call gc.set_threshold(THRESHOLD) --n/--nowindows -- suppress error message boxes on Windows --F/--forever -- run the specified tests in a loop, until an error happens - - -Additional Option Details: +EPILOG = """\ +Additional option details: -r randomizes test execution order. You can use --randseed=int to provide a int seed value for the randomizer; this is useful for reproducing troublesome @@ -168,11 +122,12 @@ option '-uall,-gui'. # We import importlib *ASAP* in order to test #15386 import importlib +import argparse import builtins import faulthandler -import getopt import io import json +import locale import logging import os import platform @@ -194,7 +149,7 @@ try: except ImportError: threading = None try: - import multiprocessing.process + import _multiprocessing, multiprocessing.process except ImportError: multiprocessing = None @@ -246,20 +201,262 @@ from test import support RESOURCE_NAMES = ('audio', 'curses', 'largefile', 'network', 'decimal', 'cpu', 'subprocess', 'urlfetch', 'gui') -TEMPDIR = os.path.abspath(tempfile.gettempdir()) - -def usage(msg): - print(msg, file=sys.stderr) - print("Use --help for usage", file=sys.stderr) - sys.exit(2) - +# When tests are run from the Python build directory, it is best practice +# to keep the test files in a subfolder. This eases the cleanup of leftover +# files using the "make distclean" command. +if sysconfig.is_python_build(): + TEMPDIR = os.path.join(sysconfig.get_config_var('srcdir'), 'build') +else: + TEMPDIR = tempfile.gettempdir() +TEMPDIR = os.path.abspath(TEMPDIR) + +class _ArgParser(argparse.ArgumentParser): + + def error(self, message): + super().error(message + "\nPass -h or --help for complete help.") + +def _create_parser(): + # Set prog to prevent the uninformative "__main__.py" from displaying in + # error messages when using "python -m test ...". + parser = _ArgParser(prog='regrtest.py', + usage=USAGE, + description=DESCRIPTION, + epilog=EPILOG, + add_help=False, + formatter_class=argparse.RawDescriptionHelpFormatter) + + # Arguments with this clause added to its help are described further in + # the epilog's "Additional option details" section. + more_details = ' See the section at bottom for more details.' + + group = parser.add_argument_group('General options') + # We add help explicitly to control what argument group it renders under. + group.add_argument('-h', '--help', action='help', + help='show this help message and exit') + group.add_argument('--timeout', metavar='TIMEOUT', type=float, + help='dump the traceback and exit if a test takes ' + 'more than TIMEOUT seconds; disabled if TIMEOUT ' + 'is negative or equals to zero') + group.add_argument('--wait', action='store_true', + help='wait for user input, e.g., allow a debugger ' + 'to be attached') + group.add_argument('--slaveargs', metavar='ARGS') + group.add_argument('-S', '--start', metavar='START', + help='the name of the test at which to start.' + + more_details) + + group = parser.add_argument_group('Verbosity') + group.add_argument('-v', '--verbose', action='count', + help='run tests in verbose mode with output to stdout') + group.add_argument('-w', '--verbose2', action='store_true', + help='re-run failed tests in verbose mode') + group.add_argument('-W', '--verbose3', action='store_true', + help='display test output on failure') + group.add_argument('-q', '--quiet', action='store_true', + help='no output unless one or more tests fail') + group.add_argument('-o', '--slow', action='store_true', dest='print_slow', + help='print the slowest 10 tests') + group.add_argument('--header', action='store_true', + help='print header with interpreter info') + + group = parser.add_argument_group('Selecting tests') + group.add_argument('-r', '--randomize', action='store_true', + help='randomize test execution order.' + more_details) + group.add_argument('--randseed', metavar='SEED', + dest='random_seed', type=int, + help='pass a random seed to reproduce a previous ' + 'random run') + group.add_argument('-f', '--fromfile', metavar='FILE', + help='read names of tests to run from a file.' + + more_details) + group.add_argument('-x', '--exclude', action='store_true', + help='arguments are tests to *exclude*') + group.add_argument('-s', '--single', action='store_true', + help='single step through a set of tests.' + + more_details) + group.add_argument('-m', '--match', metavar='PAT', + dest='match_tests', + help='match test cases and methods with glob pattern PAT') + group.add_argument('-G', '--failfast', action='store_true', + help='fail as soon as a test fails (only with -v or -W)') + group.add_argument('-u', '--use', metavar='RES1,RES2,...', + action='append', type=resources_list, + help='specify which special resource intensive tests ' + 'to run.' + more_details) + group.add_argument('-M', '--memlimit', metavar='LIMIT', + help='run very large memory-consuming tests.' + + more_details) + group.add_argument('--testdir', metavar='DIR', + type=relative_filename, + help='execute test files in the specified directory ' + '(instead of the Python stdlib test suite)') + + group = parser.add_argument_group('Special runs') + group.add_argument('-l', '--findleaks', action='store_true', + help='if GC is available detect tests that leak memory') + group.add_argument('-L', '--runleaks', action='store_true', + help='run the leaks(1) command just before exit.' + + more_details) + group.add_argument('-R', '--huntrleaks', metavar='RUNCOUNTS', + type=huntrleaks, + help='search for reference leaks (needs debug build, ' + 'very slow).' + more_details) + group.add_argument('-j', '--multiprocess', metavar='PROCESSES', + dest='use_mp', type=int, + help='run PROCESSES processes at once') + group.add_argument('-T', '--coverage', action='store_true', + dest='trace', + help='turn on code coverage tracing using the trace ' + 'module') + group.add_argument('-D', '--coverdir', metavar='DIR', + type=relative_filename, + help='directory where coverage files are put') + group.add_argument('-N', '--nocoverdir', + action='store_const', const=None, dest='coverdir', + help='put coverage files alongside modules') + group.add_argument('-t', '--threshold', metavar='THRESHOLD', + type=int, + help='call gc.set_threshold(THRESHOLD)') + group.add_argument('-n', '--nowindows', action='store_true', + help='suppress error message boxes on Windows') + group.add_argument('-F', '--forever', action='store_true', + help='run the specified tests in a loop, until an ' + 'error happens') + + parser.add_argument('args', nargs=argparse.REMAINDER, + help=argparse.SUPPRESS) + + return parser + +def relative_filename(string): + # CWD is replaced with a temporary dir before calling main(), so we + # join it with the saved CWD so it ends up where the user expects. + return os.path.join(support.SAVEDCWD, string) + +def huntrleaks(string): + args = string.split(':') + if len(args) not in (2, 3): + raise argparse.ArgumentTypeError( + 'needs 2 or 3 colon-separated arguments') + nwarmup = int(args[0]) if args[0] else 5 + ntracked = int(args[1]) if args[1] else 4 + fname = args[2] if len(args) > 2 and args[2] else 'reflog.txt' + return nwarmup, ntracked, fname + +def resources_list(string): + u = [x.lower() for x in string.split(',')] + for r in u: + if r == 'all' or r == 'none': + continue + if r[0] == '-': + r = r[1:] + if r not in RESOURCE_NAMES: + raise argparse.ArgumentTypeError('invalid resource: ' + r) + return u -def main(tests=None, testdir=None, verbose=0, quiet=False, +def _parse_args(args, **kwargs): + # Defaults + ns = argparse.Namespace(testdir=None, verbose=0, quiet=False, exclude=False, single=False, randomize=False, fromfile=None, findleaks=False, use_resources=None, trace=False, coverdir='coverage', runleaks=False, huntrleaks=False, verbose2=False, print_slow=False, random_seed=None, use_mp=None, verbose3=False, forever=False, - header=False, failfast=False, match_tests=None): + header=False, failfast=False, match_tests=None) + for k, v in kwargs.items(): + if not hasattr(ns, k): + raise TypeError('%r is an invalid keyword argument ' + 'for this function' % k) + setattr(ns, k, v) + if ns.use_resources is None: + ns.use_resources = [] + + parser = _create_parser() + parser.parse_args(args=args, namespace=ns) + + if ns.single and ns.fromfile: + parser.error("-s and -f don't go together!") + if ns.use_mp and ns.trace: + parser.error("-T and -j don't go together!") + if ns.use_mp and ns.findleaks: + parser.error("-l and -j don't go together!") + if ns.use_mp and ns.memlimit: + parser.error("-M and -j don't go together!") + if ns.failfast and not (ns.verbose or ns.verbose3): + parser.error("-G/--failfast needs either -v or -W") + + if ns.quiet: + ns.verbose = 0 + if ns.timeout is not None: + if hasattr(faulthandler, 'dump_traceback_later'): + if ns.timeout <= 0: + ns.timeout = None + else: + print("Warning: The timeout option requires " + "faulthandler.dump_traceback_later") + ns.timeout = None + if ns.use_mp is not None: + if ns.use_mp <= 0: + # Use all cores + extras for tests that like to sleep + ns.use_mp = 2 + (os.cpu_count() or 1) + if ns.use_mp == 1: + ns.use_mp = None + if ns.use: + for a in ns.use: + for r in a: + if r == 'all': + ns.use_resources[:] = RESOURCE_NAMES + continue + if r == 'none': + del ns.use_resources[:] + continue + remove = False + if r[0] == '-': + remove = True + r = r[1:] + if remove: + if r in ns.use_resources: + ns.use_resources.remove(r) + elif r not in ns.use_resources: + ns.use_resources.append(r) + if ns.random_seed is not None: + ns.randomize = True + + return ns + + +def run_test_in_subprocess(testname, ns): + """Run the given test in a subprocess with --slaveargs. + + ns is the option Namespace parsed from command-line arguments. regrtest + is invoked in a subprocess with the --slaveargs argument; when the + subprocess exits, its return code, stdout and stderr are returned as a + 3-tuple. + """ + from subprocess import Popen, PIPE + base_cmd = ([sys.executable] + support.args_from_interpreter_flags() + + ['-X', 'faulthandler', '-m', 'test.regrtest']) + + slaveargs = ( + (testname, ns.verbose, ns.quiet), + dict(huntrleaks=ns.huntrleaks, + use_resources=ns.use_resources, + output_on_failure=ns.verbose3, + timeout=ns.timeout, failfast=ns.failfast, + match_tests=ns.match_tests)) + # Running the child from the same working directory as regrtest's original + # invocation ensures that TEMPDIR for the child is the same when + # sysconfig.is_python_build() is true. See issue 15300. + popen = Popen(base_cmd + ['--slaveargs', json.dumps(slaveargs)], + stdout=PIPE, stderr=PIPE, + universal_newlines=True, + close_fds=(os.name != 'nt'), + cwd=support.SAVEDCWD) + stdout, stderr = popen.communicate() + retcode = popen.wait() + return retcode, stdout, stderr + + +def main(tests=None, **kwargs): """Execute a test suite. This also parses command-line options and modifies its behavior @@ -282,7 +479,6 @@ def main(tests=None, testdir=None, verbose=0, quiet=False, directly to set the values that would normally be set by flags on the command line. """ - # Display the Python traceback on fatal errors (e.g. segfault) faulthandler.enable(all_threads=True) @@ -298,187 +494,51 @@ def main(tests=None, testdir=None, verbose=0, quiet=False, replace_stdout() support.record_original_stdout(sys.stdout) - try: - opts, args = getopt.getopt(sys.argv[1:], 'hvqxsoS:rf:lu:t:TD:NLR:FdwWM:nj:Gm:', - ['help', 'verbose', 'verbose2', 'verbose3', 'quiet', - 'exclude', 'single', 'slow', 'randomize', 'fromfile=', 'findleaks', - 'use=', 'threshold=', 'coverdir=', 'nocoverdir', - 'runleaks', 'huntrleaks=', 'memlimit=', 'randseed=', - 'multiprocess=', 'coverage', 'slaveargs=', 'forever', 'debug', - 'start=', 'nowindows', 'header', 'testdir=', 'timeout=', 'wait', - 'failfast', 'match=']) - except getopt.error as msg: - usage(msg) - # Defaults - if random_seed is None: - random_seed = random.randrange(10000000) - if use_resources is None: - use_resources = [] - debug = False - start = None - timeout = None - for o, a in opts: - if o in ('-h', '--help'): - print(__doc__) - return - elif o in ('-v', '--verbose'): - verbose += 1 - elif o in ('-w', '--verbose2'): - verbose2 = True - elif o in ('-d', '--debug'): - debug = True - elif o in ('-W', '--verbose3'): - verbose3 = True - elif o in ('-G', '--failfast'): - failfast = True - elif o in ('-q', '--quiet'): - quiet = True; - verbose = 0 - elif o in ('-x', '--exclude'): - exclude = True - elif o in ('-S', '--start'): - start = a - elif o in ('-s', '--single'): - single = True - elif o in ('-o', '--slow'): - print_slow = True - elif o in ('-r', '--randomize'): - randomize = True - elif o == '--randseed': - randomize = True - random_seed = int(a) - elif o in ('-f', '--fromfile'): - fromfile = a - elif o in ('-m', '--match'): - match_tests = a - elif o in ('-l', '--findleaks'): - findleaks = True - elif o in ('-L', '--runleaks'): - runleaks = True - elif o in ('-t', '--threshold'): - import gc - gc.set_threshold(int(a)) - elif o in ('-T', '--coverage'): - trace = True - elif o in ('-D', '--coverdir'): - # CWD is replaced with a temporary dir before calling main(), so we - # need join it with the saved CWD so it goes where the user expects. - coverdir = os.path.join(support.SAVEDCWD, a) - elif o in ('-N', '--nocoverdir'): - coverdir = None - elif o in ('-R', '--huntrleaks'): - huntrleaks = a.split(':') - if len(huntrleaks) not in (2, 3): - print(a, huntrleaks) - usage('-R takes 2 or 3 colon-separated arguments') - if not huntrleaks[0]: - huntrleaks[0] = 5 - else: - huntrleaks[0] = int(huntrleaks[0]) - if not huntrleaks[1]: - huntrleaks[1] = 4 - else: - huntrleaks[1] = int(huntrleaks[1]) - if len(huntrleaks) == 2 or not huntrleaks[2]: - huntrleaks[2:] = ["reflog.txt"] - # Avoid false positives due to various caches - # filling slowly with random data: - warm_caches() - elif o in ('-M', '--memlimit'): - support.set_memlimit(a) - elif o in ('-u', '--use'): - u = [x.lower() for x in a.split(',')] - for r in u: - if r == 'all': - use_resources[:] = RESOURCE_NAMES - continue - if r == 'none': - del use_resources[:] - continue - remove = False - if r[0] == '-': - remove = True - r = r[1:] - if r not in RESOURCE_NAMES: - usage('Invalid -u/--use option: ' + a) - if remove: - if r in use_resources: - use_resources.remove(r) - elif r not in use_resources: - use_resources.append(r) - elif o in ('-n', '--nowindows'): - import msvcrt - msvcrt.SetErrorMode(msvcrt.SEM_FAILCRITICALERRORS| - msvcrt.SEM_NOALIGNMENTFAULTEXCEPT| - msvcrt.SEM_NOGPFAULTERRORBOX| - msvcrt.SEM_NOOPENFILEERRORBOX) - try: - msvcrt.CrtSetReportMode - except AttributeError: - # release build - pass - else: - for m in [msvcrt.CRT_WARN, msvcrt.CRT_ERROR, msvcrt.CRT_ASSERT]: - msvcrt.CrtSetReportMode(m, msvcrt.CRTDBG_MODE_FILE) - msvcrt.CrtSetReportFile(m, msvcrt.CRTDBG_FILE_STDERR) - elif o in ('-F', '--forever'): - forever = True - elif o in ('-j', '--multiprocess'): - use_mp = int(a) - if use_mp <= 0: - try: - import multiprocessing - # Use all cores + extras for tests that like to sleep - use_mp = 2 + multiprocessing.cpu_count() - except (ImportError, NotImplementedError): - use_mp = 3 - if use_mp == 1: - use_mp = None - elif o == '--header': - header = True - elif o == '--slaveargs': - args, kwargs = json.loads(a) - try: - result = runtest(*args, **kwargs) - except KeyboardInterrupt: - result = INTERRUPTED, '' - except BaseException as e: - traceback.print_exc() - result = CHILD_ERROR, str(e) - sys.stdout.flush() - print() # Force a newline (just in case) - print(json.dumps(result)) - sys.exit(0) - elif o == '--testdir': - # CWD is replaced with a temporary dir before calling main(), so we - # join it with the saved CWD so it ends up where the user expects. - testdir = os.path.join(support.SAVEDCWD, a) - elif o == '--timeout': - if hasattr(faulthandler, 'dump_traceback_later'): - timeout = float(a) - if timeout <= 0: - timeout = None - else: - print("Warning: The timeout option requires " - "faulthandler.dump_traceback_later") - timeout = None - elif o == '--wait': - input("Press any key to continue...") + ns = _parse_args(sys.argv[1:], **kwargs) + + if ns.huntrleaks: + # Avoid false positives due to various caches + # filling slowly with random data: + warm_caches() + if ns.memlimit is not None: + support.set_memlimit(ns.memlimit) + if ns.threshold is not None: + import gc + gc.set_threshold(ns.threshold) + if ns.nowindows: + import msvcrt + msvcrt.SetErrorMode(msvcrt.SEM_FAILCRITICALERRORS| + msvcrt.SEM_NOALIGNMENTFAULTEXCEPT| + msvcrt.SEM_NOGPFAULTERRORBOX| + msvcrt.SEM_NOOPENFILEERRORBOX) + try: + msvcrt.CrtSetReportMode + except AttributeError: + # release build + pass else: - print(("No handler for option {}. Please report this as a bug " - "at http://bugs.python.org.").format(o), file=sys.stderr) - sys.exit(1) - if single and fromfile: - usage("-s and -f don't go together!") - if use_mp and trace: - usage("-T and -j don't go together!") - if use_mp and findleaks: - usage("-l and -j don't go together!") - if use_mp and support.max_memuse: - usage("-M and -j don't go together!") - if failfast and not (verbose or verbose3): - usage("-G/--failfast needs either -v or -W") + for m in [msvcrt.CRT_WARN, msvcrt.CRT_ERROR, msvcrt.CRT_ASSERT]: + msvcrt.CrtSetReportMode(m, msvcrt.CRTDBG_MODE_FILE) + msvcrt.CrtSetReportFile(m, msvcrt.CRTDBG_FILE_STDERR) + if ns.wait: + input("Press any key to continue...") + + if ns.slaveargs is not None: + args, kwargs = json.loads(ns.slaveargs) + if kwargs.get('huntrleaks'): + unittest.BaseTestSuite._cleanup = False + try: + result = runtest(*args, **kwargs) + except KeyboardInterrupt: + result = INTERRUPTED, '' + except BaseException as e: + traceback.print_exc() + result = CHILD_ERROR, str(e) + sys.stdout.flush() + print() # Force a newline (just in case) + print(json.dumps(result)) + sys.exit(0) good = [] bad = [] @@ -487,12 +547,12 @@ def main(tests=None, testdir=None, verbose=0, quiet=False, environment_changed = [] interrupted = False - if findleaks: + if ns.findleaks: try: import gc except ImportError: print('No GC available, disabling findleaks.') - findleaks = False + ns.findleaks = False else: # Uncomment the line below to report garbage that is not # freeable by reference counting alone. By default only @@ -500,82 +560,87 @@ def main(tests=None, testdir=None, verbose=0, quiet=False, #gc.set_debug(gc.DEBUG_SAVEALL) found_garbage = [] - if single: + if ns.huntrleaks: + unittest.BaseTestSuite._cleanup = False + + if ns.single: filename = os.path.join(TEMPDIR, 'pynexttest') try: - fp = open(filename, 'r') - next_test = fp.read().strip() - tests = [next_test] - fp.close() - except IOError: + with open(filename, 'r') as fp: + next_test = fp.read().strip() + tests = [next_test] + except OSError: pass - if fromfile: + if ns.fromfile: tests = [] - fp = open(os.path.join(support.SAVEDCWD, fromfile)) - count_pat = re.compile(r'\[\s*\d+/\s*\d+\]') - for line in fp: - line = count_pat.sub('', line) - guts = line.split() # assuming no test has whitespace in its name - if guts and not guts[0].startswith('#'): - tests.extend(guts) - fp.close() + with open(os.path.join(support.SAVEDCWD, ns.fromfile)) as fp: + count_pat = re.compile(r'\[\s*\d+/\s*\d+\]') + for line in fp: + line = count_pat.sub('', line) + guts = line.split() # assuming no test has whitespace in its name + if guts and not guts[0].startswith('#'): + tests.extend(guts) # Strip .py extensions. - removepy(args) + removepy(ns.args) removepy(tests) stdtests = STDTESTS[:] nottests = NOTTESTS.copy() - if exclude: - for arg in args: + if ns.exclude: + for arg in ns.args: if arg in stdtests: stdtests.remove(arg) nottests.add(arg) - args = [] + ns.args = [] # For a partial run, we do not need to clutter the output. - if verbose or header or not (quiet or single or tests or args): + if ns.verbose or ns.header or not (ns.quiet or ns.single or tests or ns.args): # Print basic platform information print("==", platform.python_implementation(), *sys.version.split()) print("== ", platform.platform(aliased=True), "%s-endian" % sys.byteorder) + print("== ", "hash algorithm:", sys.hash_info.algorithm, + "64bit" if sys.maxsize > 2**32 else "32bit") print("== ", os.getcwd()) print("Testing with flags:", sys.flags) # if testdir is set, then we are not running the python tests suite, so # don't add default tests to be executed or skipped (pass empty values) - if testdir: - alltests = findtests(testdir, list(), set()) + if ns.testdir: + alltests = findtests(ns.testdir, list(), set()) else: - alltests = findtests(testdir, stdtests, nottests) + alltests = findtests(ns.testdir, stdtests, nottests) - selected = tests or args or alltests - if single: + selected = tests or ns.args or alltests + if ns.single: selected = selected[:1] try: next_single_test = alltests[alltests.index(selected[0])+1] except IndexError: next_single_test = None # Remove all the selected tests that precede start if it's set. - if start: + if ns.start: try: - del selected[:selected.index(start)] + del selected[:selected.index(ns.start)] except ValueError: - print("Couldn't find starting test (%s), using all tests" % start) - if randomize: - random.seed(random_seed) - print("Using random seed", random_seed) + print("Couldn't find starting test (%s), using all tests" % ns.start) + if ns.randomize: + if ns.random_seed is None: + ns.random_seed = random.randrange(10000000) + random.seed(ns.random_seed) + print("Using random seed", ns.random_seed) random.shuffle(selected) - if trace: + if ns.trace: import trace, tempfile tracer = trace.Trace(ignoredirs=[sys.base_prefix, sys.base_exec_prefix, tempfile.gettempdir()], trace=False, count=True) test_times = [] - support.verbose = verbose # Tell tests to be moderately quiet - support.use_resources = use_resources + support.verbose = ns.verbose # Tell tests to be moderately quiet + support.use_resources = ns.use_resources save_modules = sys.modules.keys() def accumulate_result(test, result): @@ -593,7 +658,7 @@ def main(tests=None, testdir=None, verbose=0, quiet=False, skipped.append(test) resource_denieds.append(test) - if forever: + if ns.forever: def test_forever(tests=list(selected)): while True: for test in tests: @@ -608,19 +673,16 @@ def main(tests=None, testdir=None, verbose=0, quiet=False, test_count = '/{}'.format(len(selected)) test_count_width = len(test_count) - 1 - if use_mp: + if ns.use_mp: try: from threading import Thread except ImportError: print("Multiprocess option requires thread support") sys.exit(2) from queue import Queue - from subprocess import Popen, PIPE - debug_output_pat = re.compile(r"\[\d+ refs\]$") + debug_output_pat = re.compile(r"\[\d+ refs, \d+ blocks\]$") output = Queue() pending = MultiprocessTests(tests) - opt_args = support.args_from_interpreter_flags() - base_cmd = [sys.executable] + opt_args + ['-m', 'test.regrtest'] def work(): # A worker thread. try: @@ -630,24 +692,7 @@ def main(tests=None, testdir=None, verbose=0, quiet=False, except StopIteration: output.put((None, None, None, None)) return - args_tuple = ( - (test, verbose, quiet), - dict(huntrleaks=huntrleaks, use_resources=use_resources, - debug=debug, output_on_failure=verbose3, - timeout=timeout, failfast=failfast, - match_tests=match_tests) - ) - # -E is needed by some tests, e.g. test_import - # Running the child from the same working directory ensures - # that TEMPDIR for the child is the same when - # sysconfig.is_python_build() is true. See issue 15300. - popen = Popen(base_cmd + ['--slaveargs', json.dumps(args_tuple)], - stdout=PIPE, stderr=PIPE, - universal_newlines=True, - close_fds=(os.name != 'nt'), - cwd=support.SAVEDCWD) - stdout, stderr = popen.communicate() - retcode = popen.wait() + retcode, stdout, stderr = run_test_in_subprocess(test, ns) # Strip last refcount output line if it exists, since it # comes from the shutdown of the interpreter in the subcommand. stderr = debug_output_pat.sub("", stderr) @@ -664,19 +709,19 @@ def main(tests=None, testdir=None, verbose=0, quiet=False, except BaseException: output.put((None, None, None, None)) raise - workers = [Thread(target=work) for i in range(use_mp)] + workers = [Thread(target=work) for i in range(ns.use_mp)] for worker in workers: worker.start() finished = 0 test_index = 1 try: - while finished < use_mp: + while finished < ns.use_mp: test, stdout, stderr, result = output.get() if test is None: finished += 1 continue accumulate_result(test, result) - if not quiet: + if not ns.quiet: fmt = "[{1:{0}}{2}/{3}] {4}" if bad else "[{1:{0}}{2}] {4}" print(fmt.format( test_count_width, test_index, test_count, @@ -699,29 +744,30 @@ def main(tests=None, testdir=None, verbose=0, quiet=False, worker.join() else: for test_index, test in enumerate(tests, 1): - if not quiet: + if not ns.quiet: fmt = "[{1:{0}}{2}/{3}] {4}" if bad else "[{1:{0}}{2}] {4}" print(fmt.format( test_count_width, test_index, test_count, len(bad), test)) sys.stdout.flush() - if trace: + if ns.trace: # If we're tracing code coverage, then we don't exit with status # if on a false return value from main. - tracer.runctx('runtest(test, verbose, quiet, timeout=timeout)', + tracer.runctx('runtest(test, ns.verbose, ns.quiet, timeout=ns.timeout)', globals=globals(), locals=vars()) else: try: - result = runtest(test, verbose, quiet, huntrleaks, debug, - output_on_failure=verbose3, - timeout=timeout, failfast=failfast, - match_tests=match_tests) + result = runtest(test, ns.verbose, ns.quiet, + ns.huntrleaks, + output_on_failure=ns.verbose3, + timeout=ns.timeout, failfast=ns.failfast, + match_tests=ns.match_tests) accumulate_result(test, result) except KeyboardInterrupt: interrupted = True break except: raise - if findleaks: + if ns.findleaks: gc.collect() if gc.garbage: print("Warning: test created", len(gc.garbage), end=' ') @@ -742,11 +788,11 @@ def main(tests=None, testdir=None, verbose=0, quiet=False, omitted = set(selected) - set(good) - set(bad) - set(skipped) print(count(len(omitted), "test"), "omitted:") printlist(omitted) - if good and not quiet: + if good and not ns.quiet: if not bad and not skipped and not interrupted and len(good) > 1: print("All", end=' ') print(count(len(good), "test"), "OK.") - if print_slow: + if ns.print_slow: test_times.sort(reverse=True) print("10 slowest tests:") for time, test in test_times[:10]: @@ -760,32 +806,19 @@ def main(tests=None, testdir=None, verbose=0, quiet=False, print("{} altered the execution environment:".format( count(len(environment_changed), "test"))) printlist(environment_changed) - if skipped and not quiet: + if skipped and not ns.quiet: print(count(len(skipped), "test"), "skipped:") printlist(skipped) - e = _ExpectedSkips() - plat = sys.platform - if e.isvalid(): - surprise = set(skipped) - e.getexpected() - set(resource_denieds) - if surprise: - print(count(len(surprise), "skip"), \ - "unexpected on", plat + ":") - printlist(surprise) - else: - print("Those skips are all expected on", plat + ".") - else: - print("Ask someone to teach regrtest.py about which tests are") - print("expected to get skipped on", plat + ".") - - if verbose2 and bad: + if ns.verbose2 and bad: print("Re-running failed tests in verbose mode") for test in bad: print("Re-running test %r in verbose mode" % test) sys.stdout.flush() try: - verbose = True - ok = runtest(test, True, quiet, huntrleaks, debug, timeout=timeout) + ns.verbose = True + ok = runtest(test, True, ns.quiet, ns.huntrleaks, + timeout=ns.timeout) except KeyboardInterrupt: # print a newline separate from the ^C print() @@ -793,18 +826,18 @@ def main(tests=None, testdir=None, verbose=0, quiet=False, except: raise - if single: + if ns.single: if next_single_test: with open(filename, 'w') as fp: fp.write(next_single_test + '\n') else: os.unlink(filename) - if trace: + if ns.trace: r = tracer.results() - r.write_results(show_missing=True, summary=True, coverdir=coverdir) + r.write_results(show_missing=True, summary=True, coverdir=ns.coverdir) - if runleaks: + if ns.runleaks: os.system("leaks %d" % os.getpid()) sys.exit(len(bad) > 0 or interrupted) @@ -877,7 +910,7 @@ def replace_stdout(): atexit.register(restore_stdout) def runtest(test, verbose, quiet, - huntrleaks=False, debug=False, use_resources=None, + huntrleaks=False, use_resources=None, output_on_failure=False, failfast=False, match_tests=None, timeout=None): """Run a single test. @@ -885,14 +918,15 @@ def runtest(test, verbose, quiet, test -- the name of the test verbose -- if true, print more messages quiet -- if true, don't print 'skipped' messages (probably redundant) - test_times -- a list of (time, test_name) pairs huntrleaks -- run multiple times to test for leaks; requires a debug build; a triple corresponding to -R's three arguments + use_resources -- list of extra resources to use output_on_failure -- if true, display test output on failure timeout -- dump the traceback and exit if a test takes more than timeout seconds + failfast, match_tests -- See regrtest command-line flags for these. - Returns one of the test result constants: + Returns the tuple result, test_time, where result is one of the constants: INTERRUPTED KeyboardInterrupt when run under -j RESOURCE_DENIED test skipped because resource denied SKIPPED test skipped for some other reason @@ -930,7 +964,7 @@ def runtest(test, verbose, quiet, sys.stdout = stream sys.stderr = stream result = runtest_inner(test, verbose, quiet, huntrleaks, - debug, display_failure=False) + display_failure=False) if result[0] == FAILED: output = stream.getvalue() orig_stderr.write(output) @@ -940,7 +974,7 @@ def runtest(test, verbose, quiet, sys.stderr = orig_stderr else: support.verbose = verbose # Tell tests to be moderately quiet - result = runtest_inner(test, verbose, quiet, huntrleaks, debug, + result = runtest_inner(test, verbose, quiet, huntrleaks, display_failure=not verbose) return result finally: @@ -992,10 +1026,12 @@ class saved_test_environment: 'os.environ', 'sys.path', 'sys.path_hooks', '__import__', 'warnings.filters', 'asyncore.socket_map', 'logging._handlers', 'logging._handlerList', 'sys.gettrace', - 'sys.warnoptions', 'threading._dangling', - 'multiprocessing.process._dangling', + 'sys.warnoptions', + # multiprocessing.process._cleanup() may release ref + # to a thread, so check processes first. + 'multiprocessing.process._dangling', 'threading._dangling', 'sysconfig._CONFIG_VARS', 'sysconfig._INSTALL_SCHEMES', - 'support.TESTFN', + 'support.TESTFN', 'locale', 'warnings.showwarning', ) def get_sys_argv(self): @@ -1123,6 +1159,8 @@ class saved_test_environment: def get_multiprocessing_process__dangling(self): if not multiprocessing: return None + # Unjoined process objects can survive after process exits + multiprocessing.process._cleanup() # This copies the weakrefs without making any strong reference return multiprocessing.process._dangling.copy() def restore_multiprocessing_process__dangling(self, saved): @@ -1164,6 +1202,25 @@ class saved_test_environment: elif os.path.isdir(support.TESTFN): shutil.rmtree(support.TESTFN) + _lc = [getattr(locale, lc) for lc in dir(locale) + if lc.startswith('LC_')] + def get_locale(self): + pairings = [] + for lc in self._lc: + try: + pairings.append((lc, locale.setlocale(lc, None))) + except (TypeError, ValueError): + continue + return pairings + def restore_locale(self, saved): + for lc, setting in saved: + locale.setlocale(lc, setting) + + def get_warnings_showwarning(self): + return warnings.showwarning + def restore_warnings_showwarning(self, fxn): + warnings.showwarning = fxn + def resource_info(self): for name in self.resources: method_suffix = name.replace('.', '_') @@ -1198,7 +1255,7 @@ class saved_test_environment: def runtest_inner(test, verbose, quiet, - huntrleaks=False, debug=False, display_failure=True): + huntrleaks=False, display_failure=True): support.unload(test) test_time = 0.0 @@ -1211,8 +1268,7 @@ def runtest_inner(test, verbose, quiet, abstest = 'test.' + test with saved_test_environment(test, verbose, quiet) as environment: start_time = time.time() - the_package = __import__(abstest, globals(), locals(), []) - the_module = getattr(the_package, test) + the_module = importlib.import_module(abstest) # If the test has a test_main, that will run the appropriate # tests. If not, use normal unittest test loading. test_runner = getattr(the_module, "test_main", None) @@ -1221,8 +1277,7 @@ def runtest_inner(test, verbose, quiet, test_runner = lambda: support.run_unittest(tests) test_runner() if huntrleaks: - refleak = dash_R(the_module, test, test_runner, - huntrleaks) + refleak = dash_R(the_module, test, test_runner, huntrleaks) test_time = time.time() - start_time except support.ResourceDenied as msg: if not quiet: @@ -1328,41 +1383,50 @@ def dash_R(the_module, test, indirect_test, huntrleaks): for obj in abc.__subclasses__() + [abc]: abcs[obj] = obj._abc_registry.copy() - if indirect_test: - def run_the_test(): - indirect_test() - else: - def run_the_test(): - del sys.modules[the_module.__name__] - exec('import ' + the_module.__name__) - - deltas = [] nwarmup, ntracked, fname = huntrleaks fname = os.path.join(support.SAVEDCWD, fname) repcount = nwarmup + ntracked + rc_deltas = [0] * repcount + alloc_deltas = [0] * repcount + print("beginning", repcount, "repetitions", file=sys.stderr) print(("1234567890"*(repcount//10 + 1))[:repcount], file=sys.stderr) sys.stderr.flush() - dash_R_cleanup(fs, ps, pic, zdc, abcs) for i in range(repcount): - rc_before = sys.gettotalrefcount() - run_the_test() + indirect_test() + alloc_after, rc_after = dash_R_cleanup(fs, ps, pic, zdc, abcs) sys.stderr.write('.') sys.stderr.flush() - dash_R_cleanup(fs, ps, pic, zdc, abcs) - rc_after = sys.gettotalrefcount() if i >= nwarmup: - deltas.append(rc_after - rc_before) + rc_deltas[i] = rc_after - rc_before + alloc_deltas[i] = alloc_after - alloc_before + alloc_before, rc_before = alloc_after, rc_after print(file=sys.stderr) - if any(deltas): - msg = '%s leaked %s references, sum=%s' % (test, deltas, sum(deltas)) - print(msg, file=sys.stderr) - sys.stderr.flush() - with open(fname, "a") as refrep: - print(msg, file=refrep) - refrep.flush() - return True - return False + # These checkers return False on success, True on failure + def check_rc_deltas(deltas): + return any(deltas) + def check_alloc_deltas(deltas): + # At least 1/3rd of 0s + if 3 * deltas.count(0) < len(deltas): + return True + # Nothing else than 1s, 0s and -1s + if not set(deltas) <= {1,0,-1}: + return True + return False + failed = False + for deltas, item_name, checker in [ + (rc_deltas, 'references', check_rc_deltas), + (alloc_deltas, 'memory blocks', check_alloc_deltas)]: + if checker(deltas): + msg = '%s leaked %s %s, sum=%s' % ( + test, deltas[nwarmup:], item_name, sum(deltas)) + print(msg, file=sys.stderr) + sys.stderr.flush() + with open(fname, "a") as refrep: + print(msg, file=refrep) + refrep.flush() + failed = True + return failed def dash_R_cleanup(fs, ps, pic, zdc, abcs): import gc, copyreg @@ -1428,8 +1492,11 @@ def dash_R_cleanup(fs, ps, pic, zdc, abcs): else: ctypes._reset_cache() - # Collect cyclic trash. + # Collect cyclic trash and read memory statistics immediately after. + func1 = sys.getallocatedblocks + func2 = sys.gettotalrefcount gc.collect() + return func1(), func2() def warm_caches(): # char cache @@ -1472,307 +1539,10 @@ def printlist(x, width=70, indent=4): print(fill(' '.join(str(elt) for elt in sorted(x)), width, initial_indent=blanks, subsequent_indent=blanks)) -# Map sys.platform to a string containing the basenames of tests -# expected to be skipped on that platform. -# -# Special cases: -# test_pep277 -# The _ExpectedSkips constructor adds this to the set of expected -# skips if not os.path.supports_unicode_filenames. -# test_timeout -# Controlled by test_timeout.skip_expected. Requires the network -# resource and a socket module. -# -# Tests that are expected to be skipped everywhere except on one platform -# are also handled separately. - -_expectations = ( - ('win32', - """ - test__locale - test_crypt - test_curses - test_dbm - test_devpoll - test_fcntl - test_fork1 - test_epoll - test_dbm_gnu - test_dbm_ndbm - test_grp - test_ioctl - test_largefile - test_kqueue - test_openpty - test_ossaudiodev - test_pipes - test_poll - test_posix - test_pty - test_pwd - test_resource - test_signal - test_syslog - test_threadsignals - test_wait3 - test_wait4 - """), - ('linux', - """ - test_curses - test_devpoll - test_largefile - test_kqueue - test_ossaudiodev - """), - ('unixware', - """ - test_epoll - test_largefile - test_kqueue - test_minidom - test_openpty - test_pyexpat - test_sax - test_sundry - """), - ('openunix', - """ - test_epoll - test_largefile - test_kqueue - test_minidom - test_openpty - test_pyexpat - test_sax - test_sundry - """), - ('sco_sv', - """ - test_asynchat - test_fork1 - test_epoll - test_gettext - test_largefile - test_locale - test_kqueue - test_minidom - test_openpty - test_pyexpat - test_queue - test_sax - test_sundry - test_thread - test_threaded_import - test_threadedtempfile - test_threading - """), - ('darwin', - """ - test__locale - test_curses - test_devpoll - test_epoll - test_dbm_gnu - test_gdb - test_largefile - test_locale - test_minidom - test_ossaudiodev - test_poll - """), - ('sunos', - """ - test_curses - test_dbm - test_epoll - test_kqueue - test_dbm_gnu - test_gzip - test_openpty - test_zipfile - test_zlib - """), - ('hp-ux', - """ - test_curses - test_epoll - test_dbm_gnu - test_gzip - test_largefile - test_locale - test_kqueue - test_minidom - test_openpty - test_pyexpat - test_sax - test_zipfile - test_zlib - """), - ('cygwin', - """ - test_curses - test_dbm - test_devpoll - test_epoll - test_ioctl - test_kqueue - test_largefile - test_locale - test_ossaudiodev - test_socketserver - """), - ('os2emx', - """ - test_audioop - test_curses - test_epoll - test_kqueue - test_largefile - test_mmap - test_openpty - test_ossaudiodev - test_pty - test_resource - test_signal - """), - ('freebsd', - """ - test_devpoll - test_epoll - test_dbm_gnu - test_locale - test_ossaudiodev - test_pep277 - test_pty - test_socketserver - test_tcl - test_tk - test_ttk_guionly - test_ttk_textonly - test_timeout - test_urllibnet - test_multiprocessing - """), - ('aix', - """ - test_bz2 - test_epoll - test_dbm_gnu - test_gzip - test_kqueue - test_ossaudiodev - test_tcl - test_tk - test_ttk_guionly - test_ttk_textonly - test_zipimport - test_zlib - """), - ('openbsd', - """ - test_ctypes - test_devpoll - test_epoll - test_dbm_gnu - test_locale - test_normalization - test_ossaudiodev - test_pep277 - test_tcl - test_tk - test_ttk_guionly - test_ttk_textonly - test_multiprocessing - """), - ('netbsd', - """ - test_ctypes - test_curses - test_devpoll - test_epoll - test_dbm_gnu - test_locale - test_ossaudiodev - test_pep277 - test_tcl - test_tk - test_ttk_guionly - test_ttk_textonly - test_multiprocessing - """), -) - -class _ExpectedSkips: - def __init__(self): - import os.path - from test import test_timeout - - self.valid = False - expected = None - for item in _expectations: - if sys.platform.startswith(item[0]): - expected = item[1] - break - if expected is not None: - self.expected = set(expected.split()) - - # These are broken tests, for now skipped on every platform. - # XXX Fix these! - self.expected.add('test_nis') - - # expected to be skipped on every platform, even Linux - if not os.path.supports_unicode_filenames: - self.expected.add('test_pep277') - - # doctest, profile and cProfile tests fail when the codec for the - # fs encoding isn't built in because PyUnicode_Decode() adds two - # calls into Python. - encs = ("utf-8", "latin-1", "ascii", "mbcs", "utf-16", "utf-32") - if sys.getfilesystemencoding().lower() not in encs: - self.expected.add('test_profile') - self.expected.add('test_cProfile') - self.expected.add('test_doctest') - - if test_timeout.skip_expected: - self.expected.add('test_timeout') - - if sys.platform != "win32": - # test_sqlite is only reliable on Windows where the library - # is distributed with Python - WIN_ONLY = {"test_unicode_file", "test_winreg", - "test_winsound", "test_startfile", - "test_sqlite", "test_msilib"} - self.expected |= WIN_ONLY - - if sys.platform != 'sunos5': - self.expected.add('test_nis') - - if support.python_is_optimized(): - self.expected.add("test_gdb") - - self.valid = True - - def isvalid(self): - "Return true iff _ExpectedSkips knows about the current platform." - return self.valid - - def getexpected(self): - """Return set of test names we expect to skip on current platform. - - self.isvalid() must be true. - """ - - assert self.isvalid() - return self.expected - -def _make_temp_dir_for_build(TEMPDIR): - # When tests are run from the Python build directory, it is best practice - # to keep the test files in a subfolder. It eases the cleanup of leftover - # files using command "make distclean". + +def main_in_temp_cwd(): + """Run main() in a temporary working directory.""" if sysconfig.is_python_build(): - TEMPDIR = os.path.join(sysconfig.get_config_var('srcdir'), 'build') - TEMPDIR = os.path.abspath(TEMPDIR) try: os.mkdir(TEMPDIR) except FileExistsError: @@ -1781,10 +1551,16 @@ def _make_temp_dir_for_build(TEMPDIR): # Define a writable temp dir that will be used as cwd while running # the tests. The name of the dir includes the pid to allow parallel # testing (see the -j option). - TESTCWD = 'test_python_{}'.format(os.getpid()) + test_cwd = 'test_python_{}'.format(os.getpid()) + test_cwd = os.path.join(TEMPDIR, test_cwd) + + # Run the tests in a context manager that temporarily changes the CWD to a + # temporary and writable directory. If it's not possible to create or + # change the CWD, the original CWD will be used. The original CWD is + # available from support.SAVEDCWD. + with support.temp_cwd(test_cwd, quiet=True): + main() - TESTCWD = os.path.join(TEMPDIR, TESTCWD) - return TEMPDIR, TESTCWD if __name__ == '__main__': # Remove regrtest.py's own directory from the module search path. Despite @@ -1808,11 +1584,4 @@ if __name__ == '__main__': # sanity check assert __file__ == os.path.abspath(sys.argv[0]) - TEMPDIR, TESTCWD = _make_temp_dir_for_build(TEMPDIR) - - # Run the tests in a context manager that temporary changes the CWD to a - # temporary and writable directory. If it's not possible to create or - # change the CWD, the original CWD will be used. The original CWD is - # available from support.SAVEDCWD. - with support.temp_cwd(TESTCWD, quiet=True): - main() + main_in_temp_cwd() diff --git a/Lib/test/revocation.crl b/Lib/test/revocation.crl new file mode 100644 index 0000000000..6d89b08ebe --- /dev/null +++ b/Lib/test/revocation.crl @@ -0,0 +1,11 @@ +-----BEGIN X509 CRL----- +MIIBpjCBjwIBATANBgkqhkiG9w0BAQUFADBNMQswCQYDVQQGEwJYWTEmMCQGA1UE +CgwdUHl0aG9uIFNvZnR3YXJlIEZvdW5kYXRpb24gQ0ExFjAUBgNVBAMMDW91ci1j +YS1zZXJ2ZXIXDTEzMTEyMTE3MDg0N1oXDTIzMDkzMDE3MDg0N1qgDjAMMAoGA1Ud +FAQDAgEAMA0GCSqGSIb3DQEBBQUAA4IBAQCNJXC2mVKauEeN3LlQ3ZtM5gkH3ExH ++i4bmJjtJn497WwvvoIeUdrmVXgJQR93RtV37hZwN0SXMLlNmUZPH4rHhihayw4m +unCzVj/OhCCY7/TPjKuJ1O/0XhaLBpBVjQN7R/1ujoRKbSia/CD3vcn7Fqxzw7LK +fSRCKRGTj1CZiuxrphtFchwALXSiFDy9mr2ZKhImcyq1PydfgEzU78APpOkMQsIC +UNJ/cf3c9emzf+dUtcMEcejQ3mynBo4eIGg1EW42bz4q4hSjzQlKcBV0muw5qXhc +HOxH2iTFhQ7SrvVuK/dM14rYM4B5mSX3nRC1kNmXpS9j3wJDhuwmjHed +-----END X509 CRL----- diff --git a/Lib/test/script_helper.py b/Lib/test/script_helper.py index ab201649e5..af0545bac2 100644 --- a/Lib/test/script_helper.py +++ b/Lib/test/script_helper.py @@ -12,13 +12,22 @@ import contextlib import shutil import zipfile -from imp import source_from_cache +from importlib.util import source_from_cache from test.support import make_legacy_pyc, strip_python_stderr, temp_dir # Executing the interpreter in a subprocess def _assert_python(expected_success, *args, **env_vars): - cmd_line = [sys.executable] - if not env_vars: + if '__isolated' in env_vars: + isolated = env_vars.pop('__isolated') + else: + isolated = not env_vars + cmd_line = [sys.executable, '-X', 'faulthandler'] + if isolated: + # isolated mode: ignore Python environment variables, ignore user + # site-packages, and don't add the current directory to sys.path + cmd_line.append('-I') + elif not env_vars: + # ignore Python environment variables cmd_line.append('-E') # Need to preserve the original environment, for in-place testing of # shared library builds. @@ -39,7 +48,7 @@ def _assert_python(expected_success, *args, **env_vars): p.stdout.close() p.stderr.close() rc = p.returncode - err = strip_python_stderr(err) + err = strip_python_stderr(err) if (rc and expected_success) or (not rc and not expected_success): raise AssertionError( "Process return code is %d, " @@ -49,18 +58,32 @@ def _assert_python(expected_success, *args, **env_vars): def assert_python_ok(*args, **env_vars): """ Assert that running the interpreter with `args` and optional environment - variables `env_vars` is ok and return a (return code, stdout, stderr) tuple. + variables `env_vars` succeeds (rc == 0) and return a (return code, stdout, + stderr) tuple. + + If the __cleanenv keyword is set, env_vars is used a fresh environment. + + Python is started in isolated mode (command line option -I), + except if the __isolated keyword is set to False. """ return _assert_python(True, *args, **env_vars) def assert_python_failure(*args, **env_vars): """ Assert that running the interpreter with `args` and optional environment - variables `env_vars` fails and return a (return code, stdout, stderr) tuple. + variables `env_vars` fails (rc != 0) and return a (return code, stdout, + stderr) tuple. + + See assert_python_ok() for more options. """ return _assert_python(False, *args, **env_vars) def spawn_python(*args, **kw): + """Run a Python subprocess with the given arguments. + + kw is extra keyword args to pass to subprocess.Popen. Returns a Popen + object. + """ cmd_line = [sys.executable, '-E'] cmd_line.extend(args) return subprocess.Popen(cmd_line, stdin=subprocess.PIPE, @@ -68,6 +91,7 @@ def spawn_python(*args, **kw): **kw) def kill_python(p): + """Run the given Popen process until completion and return stdout.""" p.stdin.close() data = p.stdout.read() p.stdout.close() @@ -77,8 +101,10 @@ def kill_python(p): subprocess._cleanup() return data -def make_script(script_dir, script_basename, source): - script_filename = script_basename+os.extsep+'py' +def make_script(script_dir, script_basename, source, omit_suffix=False): + script_filename = script_basename + if not omit_suffix: + script_filename += os.extsep + 'py' script_name = os.path.join(script_dir, script_filename) # The script should be encoded to UTF-8, the default string encoding script_file = open(script_name, 'w', encoding='utf-8') diff --git a/Lib/test/sortperf.py b/Lib/test/sortperf.py index af7c0b4bd1..90722f7abc 100644 --- a/Lib/test/sortperf.py +++ b/Lib/test/sortperf.py @@ -22,7 +22,7 @@ def randfloats(n): fn = os.path.join(td, "rr%06d" % n) try: fp = open(fn, "rb") - except IOError: + except OSError: r = random.random result = [r() for i in range(n)] try: @@ -35,9 +35,9 @@ def randfloats(n): if fp: try: os.unlink(fn) - except os.error: + except OSError: pass - except IOError as msg: + except OSError as msg: print("can't write", fn, ":", msg) else: result = marshal.load(fp) diff --git a/Lib/test/ssl_servers.py b/Lib/test/ssl_servers.py index 8686153a17..759b3f487e 100644 --- a/Lib/test/ssl_servers.py +++ b/Lib/test/ssl_servers.py @@ -35,7 +35,7 @@ class HTTPSServer(_HTTPServer): try: sock, addr = self.socket.accept() sslconn = self.context.wrap_socket(sock, server_side=True) - except socket.error as e: + except OSError as e: # socket errors are silenced by the caller, print them here if support.verbose: sys.stderr.write("Got an error:\n%s\n" % e) @@ -147,9 +147,11 @@ class HTTPSServerThread(threading.Thread): self.server.shutdown() -def make_https_server(case, certfile=CERTFILE, host=HOST, handler_class=None): - # we assume the certfile contains both private key and certificate - context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) +def make_https_server(case, *, context=None, certfile=CERTFILE, + host=HOST, handler_class=None): + if context is None: + context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + # We assume the certfile contains both private key and certificate context.load_cert_chain(certfile) server = HTTPSServerThread(context, host, handler_class) flag = threading.Event() diff --git a/Lib/test/ssltests.py b/Lib/test/ssltests.py new file mode 100644 index 0000000000..9b0ed22b95 --- /dev/null +++ b/Lib/test/ssltests.py @@ -0,0 +1,22 @@ +# Convenience test module to run all of the SSL-related tests in the +# standard library. + +import sys +import subprocess + +TESTS = ['test_asyncio', 'test_ftplib', 'test_hashlib', 'test_httplib', + 'test_imaplib', 'test_nntplib', 'test_poplib', 'test_smtplib', + 'test_smtpnet', 'test_urllib2_localnet', 'test_venv'] + +def run_regrtests(*extra_args): + args = [sys.executable, "-m", "test"] + if not extra_args: + args.append("-unetwork") + else: + args.extend(extra_args) + args.extend(TESTS) + result = subprocess.call(args) + sys.exit(result) + +if __name__ == '__main__': + run_regrtests(*sys.argv[1:]) diff --git a/Lib/test/string_tests.py b/Lib/test/string_tests.py index 4800d6d7f2..5ed01f2979 100644 --- a/Lib/test/string_tests.py +++ b/Lib/test/string_tests.py @@ -327,13 +327,26 @@ class BaseTest: self.checkraises(TypeError, 'hello', 'upper', 42) def test_expandtabs(self): - self.checkequal('abc\rab def\ng hi', 'abc\rab\tdef\ng\thi', 'expandtabs') - self.checkequal('abc\rab def\ng hi', 'abc\rab\tdef\ng\thi', 'expandtabs', 8) - self.checkequal('abc\rab def\ng hi', 'abc\rab\tdef\ng\thi', 'expandtabs', 4) - self.checkequal('abc\r\nab def\ng hi', 'abc\r\nab\tdef\ng\thi', 'expandtabs', 4) - self.checkequal('abc\rab def\ng hi', 'abc\rab\tdef\ng\thi', 'expandtabs') - self.checkequal('abc\rab def\ng hi', 'abc\rab\tdef\ng\thi', 'expandtabs', 8) - self.checkequal('abc\r\nab\r\ndef\ng\r\nhi', 'abc\r\nab\r\ndef\ng\r\nhi', 'expandtabs', 4) + self.checkequal('abc\rab def\ng hi', 'abc\rab\tdef\ng\thi', + 'expandtabs') + self.checkequal('abc\rab def\ng hi', 'abc\rab\tdef\ng\thi', + 'expandtabs', 8) + self.checkequal('abc\rab def\ng hi', 'abc\rab\tdef\ng\thi', + 'expandtabs', 4) + self.checkequal('abc\r\nab def\ng hi', 'abc\r\nab\tdef\ng\thi', + 'expandtabs') + self.checkequal('abc\r\nab def\ng hi', 'abc\r\nab\tdef\ng\thi', + 'expandtabs', 8) + self.checkequal('abc\r\nab def\ng hi', 'abc\r\nab\tdef\ng\thi', + 'expandtabs', 4) + self.checkequal('abc\r\nab\r\ndef\ng\r\nhi', 'abc\r\nab\r\ndef\ng\r\nhi', + 'expandtabs', 4) + # check keyword args + self.checkequal('abc\rab def\ng hi', 'abc\rab\tdef\ng\thi', + 'expandtabs', tabsize=8) + self.checkequal('abc\rab def\ng hi', 'abc\rab\tdef\ng\thi', + 'expandtabs', tabsize=4) + self.checkequal(' a\n b', ' \ta\n\tb', 'expandtabs', 1) self.checkraises(TypeError, 'hello', 'expandtabs', 42, 42) @@ -1165,7 +1178,8 @@ class MixinStrUnicodeUserStringTest: self.checkraises(TypeError, 'abc', '__mod__') self.checkraises(TypeError, '%(foo)s', '__mod__', 42) self.checkraises(TypeError, '%s%s', '__mod__', (42,)) - self.checkraises(TypeError, '%c', '__mod__', (None,)) + with self.assertWarns(DeprecationWarning): + self.checkraises(TypeError, '%c', '__mod__', (None,)) self.checkraises(ValueError, '%(foo', '__mod__', {}) self.checkraises(TypeError, '%(foo)s %(bar)s', '__mod__', ('foo', 42)) self.checkraises(TypeError, '%d', '__mod__', "42") # not numeric diff --git a/Lib/test/subprocessdata/fd_status.py b/Lib/test/subprocessdata/fd_status.py index 1f61e13a34..d12bd95abe 100644 --- a/Lib/test/subprocessdata/fd_status.py +++ b/Lib/test/subprocessdata/fd_status.py @@ -1,17 +1,27 @@ """When called as a script, print a comma-separated list of the open -file descriptors on stdout.""" +file descriptors on stdout. + +Usage: +fd_stats.py: check all file descriptors +fd_status.py fd1 fd2 ...: check only specified file descriptors +""" import errno import os - -try: - _MAXFD = os.sysconf("SC_OPEN_MAX") -except: - _MAXFD = 256 +import stat +import sys if __name__ == "__main__": fds = [] - for fd in range(0, _MAXFD): + if len(sys.argv) == 1: + try: + _MAXFD = os.sysconf("SC_OPEN_MAX") + except: + _MAXFD = 256 + test_fds = range(0, _MAXFD) + else: + test_fds = map(int, sys.argv[1:]) + for fd in test_fds: try: st = os.fstat(fd) except OSError as e: @@ -19,6 +29,6 @@ if __name__ == "__main__": continue raise # Ignore Solaris door files - if st.st_mode & 0xF000 != 0xd000: + if not stat.S_ISDOOR(st.st_mode): fds.append(fd) print(','.join(map(str, fds))) diff --git a/Lib/test/support/__init__.py b/Lib/test/support/__init__.py index 5c03f54d5c..9857f9dca4 100644 --- a/Lib/test/support/__init__.py +++ b/Lib/test/support/__init__.py @@ -15,10 +15,10 @@ import shutil import warnings import unittest import importlib +import importlib.util import collections.abc import re import subprocess -import imp import time import sysconfig import fnmatch @@ -56,26 +56,49 @@ try: except ImportError: lzma = None +try: + import resource +except ImportError: + resource = None + __all__ = [ - "Error", "TestFailed", "ResourceDenied", "import_module", "verbose", - "use_resources", "max_memuse", "record_original_stdout", - "get_original_stdout", "unload", "unlink", "rmtree", "forget", + # globals + "PIPE_MAX_SIZE", "verbose", "max_memuse", "use_resources", "failfast", + # exceptions + "Error", "TestFailed", "ResourceDenied", + # imports + "import_module", "import_fresh_module", "CleanImport", + # modules + "unload", "forget", + # io + "record_original_stdout", "get_original_stdout", "captured_stdout", + "captured_stdin", "captured_stderr", + # filesystem + "TESTFN", "SAVEDCWD", "unlink", "rmtree", "temp_cwd", "findfile", + "create_empty_file", "can_symlink", "fs_is_case_insensitive", + # unittest "is_resource_enabled", "requires", "requires_freebsd_version", - "requires_linux_version", "requires_mac_ver", "find_unused_port", - "bind_port", "IPV6_ENABLED", "is_jython", "TESTFN", "HOST", "SAVEDCWD", - "temp_cwd", "findfile", "create_empty_file", "sortdict", - "check_syntax_error", "open_urlresource", "check_warnings", "CleanImport", - "EnvironmentVarGuard", "TransientResource", "captured_stdout", - "captured_stdin", "captured_stderr", "time_out", "socket_peer_reset", - "ioerror_peer_reset", "run_with_locale", 'temp_umask', - "transient_internet", "set_memlimit", "bigmemtest", "bigaddrspacetest", - "BasicTestRunner", "run_unittest", "run_doctest", "threading_setup", - "threading_cleanup", "reap_children", "cpython_only", "check_impl_detail", - "get_attribute", "swap_item", "swap_attr", "requires_IEEE_754", - "TestHandler", "Matcher", "can_symlink", "skip_unless_symlink", - "skip_unless_xattr", "import_fresh_module", "requires_zlib", - "PIPE_MAX_SIZE", "failfast", "anticipate_failure", "run_with_tz", - "requires_gzip", "requires_bz2", "requires_lzma", "suppress_crash_popup", + "requires_linux_version", "requires_mac_ver", "check_syntax_error", + "TransientResource", "time_out", "socket_peer_reset", "ioerror_peer_reset", + "transient_internet", "BasicTestRunner", "run_unittest", "run_doctest", + "skip_unless_symlink", "requires_gzip", "requires_bz2", "requires_lzma", + "bigmemtest", "bigaddrspacetest", "cpython_only", "get_attribute", + "requires_IEEE_754", "skip_unless_xattr", "requires_zlib", + "anticipate_failure", + # sys + "is_jython", "check_impl_detail", + # network + "HOST", "IPV6_ENABLED", "find_unused_port", "bind_port", "open_urlresource", + # processes + 'temp_umask', "reap_children", + # logging + "TestHandler", + # threads + "threading_setup", "threading_cleanup", + # miscellaneous + "check_warnings", "EnvironmentVarGuard", "run_with_locale", "swap_item", + "swap_attr", "Matcher", "set_memlimit", "SuppressCrashReport", "sortdict", + "run_with_tz", ] class Error(Exception): @@ -97,7 +120,8 @@ def _ignore_deprecated_imports(ignore=True): """Context manager to suppress package and module deprecation warnings when importing them. - If ignore is False, this context manager has no effect.""" + If ignore is False, this context manager has no effect. + """ if ignore: with warnings.catch_warnings(): warnings.filterwarnings("ignore", ".+ (module|package)", @@ -107,16 +131,21 @@ def _ignore_deprecated_imports(ignore=True): yield -def import_module(name, deprecated=False): +def import_module(name, deprecated=False, *, required_on=()): """Import and return the module to be tested, raising SkipTest if it is not available. If deprecated is True, any module or package deprecation messages - will be suppressed.""" + will be suppressed. If a module is required on a platform but optional for + others, set required_on to an iterable of platform prefixes which will be + compared against sys.platform. + """ with _ignore_deprecated_imports(deprecated): try: return importlib.import_module(name) except ImportError as msg: + if sys.platform.startswith(tuple(required_on)): + raise raise unittest.SkipTest(str(msg)) @@ -302,25 +331,20 @@ else: def unlink(filename): try: _unlink(filename) - except OSError as error: - # The filename need not exist. - if error.errno not in (errno.ENOENT, errno.ENOTDIR): - raise + except (FileNotFoundError, NotADirectoryError): + pass def rmdir(dirname): try: _rmdir(dirname) - except OSError as error: - # The directory need not exist. - if error.errno != errno.ENOENT: - raise + except FileNotFoundError: + pass def rmtree(path): try: _rmtree(path) - except OSError as error: - if error.errno != errno.ENOENT: - raise + except FileNotFoundError: + pass def make_legacy_pyc(source): """Move a PEP 3147 pyc/pyo file to its legacy pyc/pyo location. @@ -332,7 +356,7 @@ def make_legacy_pyc(source): does not need to exist, however the PEP 3147 pyc file must exist. :return: The file system path to the legacy pyc file. """ - pyc_file = imp.cache_from_source(source) + pyc_file = importlib.util.cache_from_source(source) up_one = os.path.dirname(os.path.abspath(source)) legacy_pyc = os.path.join(up_one, source + ('c' if __debug__ else 'o')) os.rename(pyc_file, legacy_pyc) @@ -351,8 +375,8 @@ def forget(modname): # combinations of PEP 3147 and legacy pyc and pyo files. unlink(source + 'c') unlink(source + 'o') - unlink(imp.cache_from_source(source, debug_override=True)) - unlink(imp.cache_from_source(source, debug_override=False)) + unlink(importlib.util.cache_from_source(source, debug_override=True)) + unlink(importlib.util.cache_from_source(source, debug_override=False)) # On some platforms, should not run gui test even if it is allowed # in `use_resources'. @@ -513,7 +537,7 @@ def find_unused_port(family=socket.AF_INET, socktype=socket.SOCK_STREAM): the SO_REUSEADDR socket option having different semantics on Windows versus Unix/Linux. On Unix, you can't have two AF_INET SOCK_STREAM sockets bind, listen and then accept connections on identical host/ports. An EADDRINUSE - socket.error will be raised at some point (depending on the platform and + OSError will be raised at some point (depending on the platform and the order bind and listen were called on each socket). However, on Windows, if SO_REUSEADDR is set on the sockets, no EADDRINUSE @@ -591,9 +615,9 @@ def _is_ipv6_enabled(): sock = None try: sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) - sock.bind(('::1', 0)) + sock.bind((HOSTv6, 0)) return True - except (socket.error, socket.gaierror): + except OSError: pass finally: if sock: @@ -1180,9 +1204,9 @@ class TransientResource(object): # Context managers that raise ResourceDenied when various issues # with the Internet connection manifest themselves as exceptions. # XXX deprecate these and use transient_internet() instead -time_out = TransientResource(IOError, errno=errno.ETIMEDOUT) -socket_peer_reset = TransientResource(socket.error, errno=errno.ECONNRESET) -ioerror_peer_reset = TransientResource(IOError, errno=errno.ECONNRESET) +time_out = TransientResource(OSError, errno=errno.ETIMEDOUT) +socket_peer_reset = TransientResource(OSError, errno=errno.ECONNRESET) +ioerror_peer_reset = TransientResource(OSError, errno=errno.ECONNRESET) @contextlib.contextmanager @@ -1228,17 +1252,17 @@ def transient_internet(resource_name, *, timeout=30.0, errnos=()): if timeout is not None: socket.setdefaulttimeout(timeout) yield - except IOError as err: + except OSError as err: # urllib can wrap original socket errors multiple times (!), we must # unwrap to get at the original error. while True: a = err.args - if len(a) >= 1 and isinstance(a[0], IOError): + if len(a) >= 1 and isinstance(a[0], OSError): err = a[0] # The error can also be wrapped as args[1]: # except socket.error as msg: - # raise IOError('socket error', msg).with_traceback(sys.exc_info()[2]) - elif len(a) >= 2 and isinstance(a[1], IOError): + # raise OSError('socket error', msg).with_traceback(sys.exc_info()[2]) + elif len(a) >= 2 and isinstance(a[1], OSError): err = a[1] else: break @@ -1697,9 +1721,18 @@ def run_unittest(*classes): #======================================================================= # Check for the presence of docstrings. -HAVE_DOCSTRINGS = (check_impl_detail(cpython=False) or - sys.platform == 'win32' or - sysconfig.get_config_var('WITH_DOC_STRINGS')) +# Rather than trying to enumerate all the cases where docstrings may be +# disabled, we just check for that directly + +def _check_docstrings(): + """Just used to check if docstrings are enabled""" + +MISSING_C_DOCSTRINGS = (check_impl_detail() and + sys.platform != 'win32' and + not sysconfig.get_config_var('WITH_DOC_STRINGS')) + +HAVE_DOCSTRINGS = (_check_docstrings.__doc__ is not None and + not MISSING_C_DOCSTRINGS) requires_docstrings = unittest.skipUnless(HAVE_DOCSTRINGS, "test requires docstrings") @@ -1774,12 +1807,12 @@ def threading_setup(): def threading_cleanup(*original_values): if not _thread: return - _MAX_COUNT = 10 + _MAX_COUNT = 100 for count in range(_MAX_COUNT): values = _thread._count(), threading._dangling if values == original_values: break - time.sleep(0.1) + time.sleep(0.01) gc_collect() # XXX print a warning in case of failure? @@ -1881,7 +1914,7 @@ def strip_python_stderr(stderr): This will typically be run on the result of the communicate() method of a subprocess.Popen object. """ - stderr = re.sub(br"\[\d+ refs\]\r?\n?", b"", stderr).strip() + stderr = re.sub(br"\[\d+ refs, \d+ blocks\]\r?\n?", b"", stderr).strip() return stderr def args_from_interpreter_flags(): @@ -2012,27 +2045,81 @@ def skip_unless_xattr(test): return test if ok else unittest.skip(msg)(test) -if sys.platform.startswith('win'): - @contextlib.contextmanager - def suppress_crash_popup(): - """Disable Windows Error Reporting dialogs using SetErrorMode.""" - # see http://msdn.microsoft.com/en-us/library/windows/desktop/ms680621%28v=vs.85%29.aspx - # GetErrorMode is not available on Windows XP and Windows Server 2003, - # but SetErrorMode returns the previous value, so we can use that - import ctypes - k32 = ctypes.windll.kernel32 - SEM_NOGPFAULTERRORBOX = 0x02 - old_error_mode = k32.SetErrorMode(SEM_NOGPFAULTERRORBOX) - k32.SetErrorMode(old_error_mode | SEM_NOGPFAULTERRORBOX) - try: - yield - finally: - k32.SetErrorMode(old_error_mode) -else: - # this is a no-op for other platforms - @contextlib.contextmanager - def suppress_crash_popup(): - yield +def fs_is_case_insensitive(directory): + """Detects if the file system for the specified directory is case-insensitive.""" + base_fp, base_path = tempfile.mkstemp(dir=directory) + case_path = base_path.upper() + if case_path == base_path: + case_path = base_path.lower() + try: + return os.path.samefile(base_path, case_path) + except FileNotFoundError: + return False + finally: + os.unlink(base_path) + + +class SuppressCrashReport: + """Try to prevent a crash report from popping up. + + On Windows, don't display the Windows Error Reporting dialog. On UNIX, + disable the creation of coredump file. + """ + old_value = None + + def __enter__(self): + """On Windows, disable Windows Error Reporting dialogs using + SetErrorMode. + + On UNIX, try to save the previous core file size limit, then set + soft limit to 0. + """ + if sys.platform.startswith('win'): + # see http://msdn.microsoft.com/en-us/library/windows/desktop/ms680621.aspx + # GetErrorMode is not available on Windows XP and Windows Server 2003, + # but SetErrorMode returns the previous value, so we can use that + import ctypes + self._k32 = ctypes.windll.kernel32 + SEM_NOGPFAULTERRORBOX = 0x02 + self.old_value = self._k32.SetErrorMode(SEM_NOGPFAULTERRORBOX) + self._k32.SetErrorMode(self.old_value | SEM_NOGPFAULTERRORBOX) + else: + if resource is not None: + try: + self.old_value = resource.getrlimit(resource.RLIMIT_CORE) + resource.setrlimit(resource.RLIMIT_CORE, + (0, self.old_value[1])) + except (ValueError, OSError): + pass + if sys.platform == 'darwin': + # Check if the 'Crash Reporter' on OSX was configured + # in 'Developer' mode and warn that it will get triggered + # when it is. + # + # This assumes that this context manager is used in tests + # that might trigger the next manager. + value = subprocess.Popen(['/usr/bin/defaults', 'read', + 'com.apple.CrashReporter', 'DialogType'], + stdout=subprocess.PIPE).communicate()[0] + if value.strip() == b'developer': + print("this test triggers the Crash Reporter, " + "that is intentional", end='', flush=True) + + return self + + def __exit__(self, *ignore_exc): + """Restore Windows ErrorMode or core file behavior to initial value.""" + if self.old_value is None: + return + + if sys.platform.startswith('win'): + self._k32.SetErrorMode(self.old_value) + else: + if resource is not None: + try: + resource.setrlimit(resource.RLIMIT_CORE, self.old_value) + except (ValueError, OSError): + pass def patch(test_instance, object_to_patch, attr_name, new_value): @@ -2067,3 +2154,23 @@ def patch(test_instance, object_to_patch, attr_name, new_value): # actually override the attribute setattr(object_to_patch, attr_name, new_value) + + +def run_in_subinterp(code): + """ + Run code in a subinterpreter. Raise unittest.SkipTest if the tracemalloc + module is enabled. + """ + # Issue #10915, #15751: PyGILState_*() functions don't work with + # sub-interpreters, the tracemalloc module uses these functions internally + try: + import tracemalloc + except ImportError: + pass + else: + if tracemalloc.is_tracing(): + raise unittest.SkipTest("run_in_subinterp() cannot be used " + "if tracemalloc module is tracing " + "memory allocations") + import _testcapi + return _testcapi.run_in_subinterp(code) diff --git a/Lib/test/test___all__.py b/Lib/test/test___all__.py index 608ec01f14..8cc285f70a 100644 --- a/Lib/test/test___all__.py +++ b/Lib/test/test___all__.py @@ -29,17 +29,20 @@ class AllTest(unittest.TestCase): if not hasattr(sys.modules[modname], "__all__"): raise NoAll(modname) names = {} - try: - exec("from %s import *" % modname, names) - except Exception as e: - # Include the module name in the exception string - self.fail("__all__ failure in {}: {}: {}".format( - modname, e.__class__.__name__, e)) - if "__builtins__" in names: - del names["__builtins__"] - keys = set(names) - all = set(sys.modules[modname].__all__) - self.assertEqual(keys, all) + with self.subTest(module=modname): + try: + exec("from %s import *" % modname, names) + except Exception as e: + # Include the module name in the exception string + self.fail("__all__ failure in {}: {}: {}".format( + modname, e.__class__.__name__, e)) + if "__builtins__" in names: + del names["__builtins__"] + keys = set(names) + all_list = sys.modules[modname].__all__ + all_set = set(all_list) + self.assertCountEqual(all_set, all_list, "in module {}".format(modname)) + self.assertEqual(keys, all_set, "in module {}".format(modname)) def walk_modules(self, basedir, modpath): for fn in sorted(os.listdir(basedir)): @@ -110,8 +113,5 @@ class AllTest(unittest.TestCase): print('Following modules failed to be imported:', failed_imports) -def test_main(): - support.run_unittest(AllTest) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test__opcode.py b/Lib/test/test__opcode.py new file mode 100644 index 0000000000..0152e9d94d --- /dev/null +++ b/Lib/test/test__opcode.py @@ -0,0 +1,23 @@ +import dis +from test.support import run_unittest, import_module +import unittest + +_opcode = import_module("_opcode") + +class OpcodeTests(unittest.TestCase): + + def test_stack_effect(self): + self.assertEqual(_opcode.stack_effect(dis.opmap['POP_TOP']), -1) + self.assertEqual(_opcode.stack_effect(dis.opmap['DUP_TOP_TWO']), 2) + self.assertEqual(_opcode.stack_effect(dis.opmap['BUILD_SLICE'], 0), -1) + self.assertEqual(_opcode.stack_effect(dis.opmap['BUILD_SLICE'], 1), -1) + self.assertEqual(_opcode.stack_effect(dis.opmap['BUILD_SLICE'], 3), -2) + self.assertRaises(ValueError, _opcode.stack_effect, 30000) + self.assertRaises(ValueError, _opcode.stack_effect, dis.opmap['BUILD_SLICE']) + self.assertRaises(ValueError, _opcode.stack_effect, dis.opmap['POP_TOP'], 0) + +def test_main(): + run_unittest(OpcodeTests) + +if __name__ == "__main__": + test_main() diff --git a/Lib/test/test_abc.py b/Lib/test/test_abc.py index d4d7556518..93f9dae215 100644 --- a/Lib/test/test_abc.py +++ b/Lib/test/test_abc.py @@ -68,6 +68,19 @@ class TestLegacyAPI(unittest.TestCase): class TestABC(unittest.TestCase): + def test_ABC_helper(self): + # create an ABC using the helper class and perform basic checks + class C(abc.ABC): + @classmethod + @abc.abstractmethod + def foo(cls): return cls.__name__ + self.assertEqual(type(C), abc.ABCMeta) + self.assertRaises(TypeError, C) + class D(C): + @classmethod + def foo(cls): return super().foo() + self.assertEqual(D.foo(), 'D') + def test_abstractmethod_basics(self): @abc.abstractmethod def foo(self): pass @@ -288,7 +301,10 @@ class TestABC(unittest.TestCase): b = B() self.assertFalse(isinstance(b, A)) self.assertFalse(isinstance(b, (A,))) + token_old = abc.get_cache_token() A.register(B) + token_new = abc.get_cache_token() + self.assertNotEqual(token_old, token_new) self.assertTrue(isinstance(b, A)) self.assertTrue(isinstance(b, (A,))) diff --git a/Lib/test/test_aifc.py b/Lib/test/test_aifc.py index 10260e3269..ab5143787b 100644 --- a/Lib/test/test_aifc.py +++ b/Lib/test/test_aifc.py @@ -1,6 +1,7 @@ from test.support import findfile, TESTFN, unlink import unittest from test import audiotests +from audioop import byteswap import os import io import sys @@ -119,7 +120,7 @@ class AifcULAWTest(AifcTest, unittest.TestCase): E5040CBC 617C0A3C 08BC0A3C 2C7C0B3C 517C0E3C 8A8410FC B6840EBC 457C0A3C \ """) if sys.byteorder != 'big': - frames = audiotests.byteswap2(frames) + frames = byteswap(frames, 2) class AifcALAWTest(AifcTest, unittest.TestCase): @@ -140,7 +141,7 @@ class AifcALAWTest(AifcTest, unittest.TestCase): E4800CC0 62000A40 08C00A40 2B000B40 52000E40 8A001180 B6000EC0 46000A40 \ """) if sys.byteorder != 'big': - frames = audiotests.byteswap2(frames) + frames = byteswap(frames, 2) class AifcMiscTest(audiotests.AudioTests, unittest.TestCase): @@ -149,6 +150,21 @@ class AifcMiscTest(audiotests.AudioTests, unittest.TestCase): #This file contains chunk types aifc doesn't recognize. self.f = aifc.open(findfile('Sine-1000Hz-300ms.aif')) + def test_params_added(self): + f = self.f = aifc.open(TESTFN, 'wb') + f.aiff() + f.setparams((1, 1, 1, 1, b'NONE', b'')) + f.close() + + f = self.f = aifc.open(TESTFN, 'rb') + params = f.getparams() + self.assertEqual(params.nchannels, f.getnchannels()) + self.assertEqual(params.sampwidth, f.getsampwidth()) + self.assertEqual(params.framerate, f.getframerate()) + self.assertEqual(params.nframes, f.getnframes()) + self.assertEqual(params.comptype, f.getcomptype()) + self.assertEqual(params.compname, f.getcompname()) + def test_write_header_comptype_sampwidth(self): for comptype in (b'ULAW', b'ulaw', b'ALAW', b'alaw', b'G722'): fout = aifc.open(io.BytesIO(), 'wb') @@ -352,12 +368,14 @@ class AIFCLowLevelTest(unittest.TestCase): def test_write_aiff_by_extension(self): sampwidth = 2 - fout = self.fout = aifc.open(TESTFN + '.aiff', 'wb') + filename = TESTFN + '.aiff' + fout = self.fout = aifc.open(filename, 'wb') + self.addCleanup(unlink, filename) fout.setparams((1, sampwidth, 1, 1, b'ULAW', b'')) frames = b'\x00' * fout.getnchannels() * sampwidth fout.writeframes(frames) fout.close() - f = self.f = aifc.open(TESTFN + '.aiff', 'rb') + f = self.f = aifc.open(filename, 'rb') self.assertEqual(f.getcomptype(), b'NONE') f.close() diff --git a/Lib/test/test_argparse.py b/Lib/test/test_argparse.py index 9cb15c78bd..25ba9e62d1 100644 --- a/Lib/test/test_argparse.py +++ b/Lib/test/test_argparse.py @@ -14,6 +14,7 @@ import argparse from io import StringIO from test import support +from unittest import mock class StdIOBuffer(StringIO): pass @@ -1421,6 +1422,19 @@ class TestFileTypeRepr(TestCase): type = argparse.FileType('wb', 1) self.assertEqual("FileType('wb', 1)", repr(type)) + def test_r_latin(self): + type = argparse.FileType('r', encoding='latin_1') + self.assertEqual("FileType('r', encoding='latin_1')", repr(type)) + + def test_w_big5_ignore(self): + type = argparse.FileType('w', encoding='big5', errors='ignore') + self.assertEqual("FileType('w', encoding='big5', errors='ignore')", + repr(type)) + + def test_r_1_replace(self): + type = argparse.FileType('r', 1, errors='replace') + self.assertEqual("FileType('r', 1, errors='replace')", repr(type)) + class RFile(object): seen = {} @@ -1557,6 +1571,24 @@ class TestFileTypeWB(TempDirMixin, ParserTestCase): ] +class TestFileTypeOpenArgs(TestCase): + """Test that open (the builtin) is correctly called""" + + def test_open_args(self): + FT = argparse.FileType + cases = [ + (FT('rb'), ('rb', -1, None, None)), + (FT('w', 1), ('w', 1, None, None)), + (FT('w', errors='replace'), ('w', -1, None, 'replace')), + (FT('wb', encoding='big5'), ('wb', -1, 'big5', None)), + (FT('w', 0, 'l1', 'strict'), ('w', 0, 'l1', 'strict')), + ] + with mock.patch('builtins.open') as m: + for type, args in cases: + type('foo') + m.assert_called_with('foo', *args) + + class TestTypeCallable(ParserTestCase): """Test some callables as option/argument types""" @@ -4381,7 +4413,7 @@ class TestOptionalsHelpVersionActions(TestCase): def test_version_format(self): parser = ErrorRaisingArgumentParser(prog='PPP') parser.add_argument('-v', '--version', action='version', version='%(prog)s 3.5') - msg = self._get_error(parser.parse_args, ['-v']).stderr + msg = self._get_error(parser.parse_args, ['-v']).stdout self.assertEqual('PPP 3.5\n', msg) def test_version_no_help(self): @@ -4394,7 +4426,7 @@ class TestOptionalsHelpVersionActions(TestCase): def test_version_action(self): parser = ErrorRaisingArgumentParser(prog='XXX') parser.add_argument('-V', action='version', version='%(prog)s 3.7') - msg = self._get_error(parser.parse_args, ['-V']).stderr + msg = self._get_error(parser.parse_args, ['-V']).stdout self.assertEqual('XXX 3.7\n', msg) def test_no_help(self): diff --git a/Lib/test/test_array.py b/Lib/test/test_array.py index d68284f919..f8dbf0695c 100644 --- a/Lib/test/test_array.py +++ b/Lib/test/test_array.py @@ -353,12 +353,12 @@ class BaseTest: support.unlink(support.TESTFN) def test_fromfile_ioerror(self): - # Issue #5395: Check if fromfile raises a proper IOError + # Issue #5395: Check if fromfile raises a proper OSError # instead of EOFError. a = array.array(self.typecode) f = open(support.TESTFN, 'wb') try: - self.assertRaises(IOError, a.fromfile, f, len(self.example)) + self.assertRaises(OSError, a.fromfile, f, len(self.example)) finally: f.close() support.unlink(support.TESTFN) @@ -1026,6 +1026,18 @@ class BaseTest: basesize = support.calcvobjsize('Pn2Pi') support.check_sizeof(self, a, basesize) + def test_initialize_with_unicode(self): + if self.typecode != 'u': + with self.assertRaises(TypeError) as cm: + a = array.array(self.typecode, 'foo') + self.assertIn("cannot use a str", str(cm.exception)) + with self.assertRaises(TypeError) as cm: + a = array.array(self.typecode, array.array('u', 'foo')) + self.assertIn("cannot use a unicode array", str(cm.exception)) + else: + a = array.array(self.typecode, "foo") + a = array.array(self.typecode, array.array('u', 'foo')) + class StringTest(BaseTest): diff --git a/Lib/test/test_ast.py b/Lib/test/test_ast.py index 63528885e9..f17c93fc68 100644 --- a/Lib/test/test_ast.py +++ b/Lib/test/test_ast.py @@ -37,7 +37,7 @@ exec_tests = [ # FunctionDef with kwargs "def f(**kwargs): pass", # FunctionDef with all kind of args - "def f(a, b=1, c=None, d=[], e={}, *args, **kwargs): pass", + "def f(a, b=1, c=None, d=[], e={}, *args, f=42, **kwargs): pass", # ClassDef "class C:pass", # ClassDef, new style class @@ -180,20 +180,36 @@ eval_tests = [ class AST_Tests(unittest.TestCase): - def _assertTrueorder(self, ast_node, parent_pos): + def _assertTrueorder(self, ast_node, parent_pos, reverse_check = False): + def should_reverse_check(parent, child): + # In some situations, the children of nodes occur before + # their parents, for example in a.b.c, a occurs before b + # but a is a child of b. + if isinstance(parent, ast.Call): + if parent.func == child: + return True + if isinstance(parent, (ast.Attribute, ast.Subscript)): + return True + return False + if not isinstance(ast_node, ast.AST) or ast_node._fields is None: return if isinstance(ast_node, (ast.expr, ast.stmt, ast.excepthandler)): node_pos = (ast_node.lineno, ast_node.col_offset) - self.assertTrue(node_pos >= parent_pos) + if reverse_check: + self.assertTrue(node_pos <= parent_pos) + else: + self.assertTrue(node_pos >= parent_pos) parent_pos = (ast_node.lineno, ast_node.col_offset) for name in ast_node._fields: value = getattr(ast_node, name) if isinstance(value, list): for child in value: - self._assertTrueorder(child, parent_pos) + self._assertTrueorder(child, parent_pos, + should_reverse_check(ast_node, child)) elif value is not None: - self._assertTrueorder(value, parent_pos) + self._assertTrueorder(value, parent_pos, + should_reverse_check(ast_node, value)) def test_AST_objects(self): x = ast.AST() @@ -262,14 +278,14 @@ class AST_Tests(unittest.TestCase): def test_arguments(self): x = ast.arguments() - self.assertEqual(x._fields, ('args', 'vararg', 'varargannotation', - 'kwonlyargs', 'kwarg', 'kwargannotation', - 'defaults', 'kw_defaults')) + self.assertEqual(x._fields, ('args', 'vararg', + 'kwonlyargs', 'kw_defaults', + 'kwarg', 'defaults')) with self.assertRaises(AttributeError): x.vararg - x = ast.arguments(*range(1, 9)) + x = ast.arguments(*range(1, 7)) self.assertEqual(x.vararg, 2) def test_field_attr_writable(self): @@ -439,7 +455,7 @@ class ASTHelpers_Test(unittest.TestCase): "lineno=1, col_offset=0), args=[Name(id='eggs', ctx=Load(), " "lineno=1, col_offset=5), Str(s='and cheese', lineno=1, " "col_offset=11)], keywords=[], starargs=None, kwargs=None, " - "lineno=1, col_offset=0), lineno=1, col_offset=0)])" + "lineno=1, col_offset=4), lineno=1, col_offset=0)])" ) def test_copy_location(self): @@ -460,7 +476,7 @@ class ASTHelpers_Test(unittest.TestCase): "Module(body=[Expr(value=Call(func=Name(id='write', ctx=Load(), " "lineno=1, col_offset=0), args=[Str(s='spam', lineno=1, " "col_offset=6)], keywords=[], starargs=None, kwargs=None, " - "lineno=1, col_offset=0), lineno=1, col_offset=0), " + "lineno=1, col_offset=5), lineno=1, col_offset=0), " "Expr(value=Call(func=Name(id='spam', ctx=Load(), lineno=1, " "col_offset=0), args=[Str(s='eggs', lineno=1, col_offset=0)], " "keywords=[], starargs=None, kwargs=None, lineno=1, " @@ -560,8 +576,8 @@ class ASTValidatorTests(unittest.TestCase): self.mod(m, "must have Load context", "eval") def _check_arguments(self, fac, check): - def arguments(args=None, vararg=None, varargannotation=None, - kwonlyargs=None, kwarg=None, kwargannotation=None, + def arguments(args=None, vararg=None, + kwonlyargs=None, kwarg=None, defaults=None, kw_defaults=None): if args is None: args = [] @@ -571,20 +587,12 @@ class ASTValidatorTests(unittest.TestCase): defaults = [] if kw_defaults is None: kw_defaults = [] - args = ast.arguments(args, vararg, varargannotation, kwonlyargs, - kwarg, kwargannotation, defaults, kw_defaults) + args = ast.arguments(args, vararg, kwonlyargs, kw_defaults, + kwarg, defaults) return fac(args) args = [ast.arg("x", ast.Name("x", ast.Store()))] check(arguments(args=args), "must have Load context") - check(arguments(varargannotation=ast.Num(3)), - "varargannotation but no vararg") - check(arguments(varargannotation=ast.Name("x", ast.Store()), vararg="x"), - "must have Load context") check(arguments(kwonlyargs=args), "must have Load context") - check(arguments(kwargannotation=ast.Num(42)), - "kwargannotation but no kwarg") - check(arguments(kwargannotation=ast.Name("x", ast.Store()), - kwarg="x"), "must have Load context") check(arguments(defaults=[ast.Num(3)]), "more positional defaults than args") check(arguments(kw_defaults=[ast.Num(4)]), @@ -599,7 +607,7 @@ class ASTValidatorTests(unittest.TestCase): "must have Load context") def test_funcdef(self): - a = ast.arguments([], None, None, [], None, None, [], []) + a = ast.arguments([], None, [], [], None, []) f = ast.FunctionDef("x", a, [], [], None) self.stmt(f, "empty body on FunctionDef") f = ast.FunctionDef("x", a, [ast.Pass()], [ast.Name("x", ast.Store())], @@ -770,7 +778,7 @@ class ASTValidatorTests(unittest.TestCase): self.expr(u, "must have Load context") def test_lambda(self): - a = ast.arguments([], None, None, [], None, None, [], []) + a = ast.arguments([], None, [], [], None, []) self.expr(ast.Lambda(a, ast.Name("x", ast.Store())), "must have Load context") def fac(args): @@ -928,6 +936,9 @@ class ASTValidatorTests(unittest.TestCase): def test_tuple(self): self._sequence(ast.Tuple) + def test_nameconstant(self): + self.expr(ast.NameConstant(4), "singleton must be True, False, or None") + def test_stdlib_validates(self): stdlib = os.path.dirname(ast.__file__) tests = [fn for fn in os.listdir(stdlib) if fn.endswith(".py")] @@ -936,13 +947,10 @@ class ASTValidatorTests(unittest.TestCase): fn = os.path.join(stdlib, module) with open(fn, "r", encoding="utf-8") as fp: source = fp.read() - mod = ast.parse(source) + mod = ast.parse(source, fn) compile(mod, fn, "exec") -def test_main(): - support.run_unittest(AST_Tests, ASTHelpers_Test, ASTValidatorTests) - def main(): if __name__ != '__main__': return @@ -955,20 +963,20 @@ def main(): print("]") print("main()") raise SystemExit - test_main() + unittest.main() #### EVERYTHING BELOW IS GENERATED ##### exec_results = [ -('Module', [('Expr', (1, 0), ('Name', (1, 0), 'None', ('Load',)))]), -('Module', [('FunctionDef', (1, 0), 'f', ('arguments', [], None, None, [], None, None, [], []), [('Pass', (1, 9))], [], None)]), -('Module', [('FunctionDef', (1, 0), 'f', ('arguments', [('arg', 'a', None)], None, None, [], None, None, [], []), [('Pass', (1, 10))], [], None)]), -('Module', [('FunctionDef', (1, 0), 'f', ('arguments', [('arg', 'a', None)], None, None, [], None, None, [('Num', (1, 8), 0)], []), [('Pass', (1, 12))], [], None)]), -('Module', [('FunctionDef', (1, 0), 'f', ('arguments', [], 'args', None, [], None, None, [], []), [('Pass', (1, 14))], [], None)]), -('Module', [('FunctionDef', (1, 0), 'f', ('arguments', [], None, None, [], 'kwargs', None, [], []), [('Pass', (1, 17))], [], None)]), -('Module', [('FunctionDef', (1, 0), 'f', ('arguments', [('arg', 'a', None), ('arg', 'b', None), ('arg', 'c', None), ('arg', 'd', None), ('arg', 'e', None)], 'args', None, [], 'kwargs', None, [('Num', (1, 11), 1), ('Name', (1, 16), 'None', ('Load',)), ('List', (1, 24), [], ('Load',)), ('Dict', (1, 30), [], [])], []), [('Pass', (1, 52))], [], None)]), +('Module', [('Expr', (1, 0), ('NameConstant', (1, 0), None))]), +('Module', [('FunctionDef', (1, 0), 'f', ('arguments', [], None, [], [], None, []), [('Pass', (1, 9))], [], None)]), +('Module', [('FunctionDef', (1, 0), 'f', ('arguments', [('arg', (1, 6), 'a', None)], None, [], [], None, []), [('Pass', (1, 10))], [], None)]), +('Module', [('FunctionDef', (1, 0), 'f', ('arguments', [('arg', (1, 6), 'a', None)], None, [], [], None, [('Num', (1, 8), 0)]), [('Pass', (1, 12))], [], None)]), +('Module', [('FunctionDef', (1, 0), 'f', ('arguments', [], ('arg', (1, 7), 'args', None), [], [], None, []), [('Pass', (1, 14))], [], None)]), +('Module', [('FunctionDef', (1, 0), 'f', ('arguments', [], None, [], [], ('arg', (1, 8), 'kwargs', None), []), [('Pass', (1, 17))], [], None)]), + ('Module', [('FunctionDef', (1, 0), 'f', ('arguments', [('arg', (1, 6), 'a', None), ('arg', (1, 9), 'b', None), ('arg', (1, 14), 'c', None), ('arg', (1, 22), 'd', None), ('arg', (1, 28), 'e', None)], ('arg', (1, 35), 'args', None), [('arg', (1, 41), 'f', None)], [('Num', (1, 43), 42)], ('arg', (1, 49), 'kwargs', None), [('Num', (1, 11), 1), ('NameConstant', (1, 16), None), ('List', (1, 24), [], ('Load',)), ('Dict', (1, 30), [], [])]), [('Pass', (1, 58))], [], None)]), ('Module', [('ClassDef', (1, 0), 'C', [], [], None, None, [('Pass', (1, 8))], [])]), ('Module', [('ClassDef', (1, 0), 'C', [('Name', (1, 8), 'object', ('Load',))], [], None, None, [('Pass', (1, 17))], [])]), -('Module', [('FunctionDef', (1, 0), 'f', ('arguments', [], None, None, [], None, None, [], []), [('Return', (1, 8), ('Num', (1, 15), 1))], [], None)]), +('Module', [('FunctionDef', (1, 0), 'f', ('arguments', [], None, [], [], None, []), [('Return', (1, 8), ('Num', (1, 15), 1))], [], None)]), ('Module', [('Delete', (1, 0), [('Name', (1, 4), 'v', ('Del',))])]), ('Module', [('Assign', (1, 0), [('Name', (1, 0), 'v', ('Store',))], ('Num', (1, 4), 1))]), ('Module', [('AugAssign', (1, 0), ('Name', (1, 0), 'v', ('Store',)), ('Add',), ('Num', (1, 5), 1))]), @@ -977,7 +985,7 @@ exec_results = [ ('Module', [('If', (1, 0), ('Name', (1, 3), 'v', ('Load',)), [('Pass', (1, 5))], [])]), ('Module', [('With', (1, 0), [('withitem', ('Name', (1, 5), 'x', ('Load',)), ('Name', (1, 10), 'y', ('Store',)))], [('Pass', (1, 13))])]), ('Module', [('With', (1, 0), [('withitem', ('Name', (1, 5), 'x', ('Load',)), ('Name', (1, 10), 'y', ('Store',))), ('withitem', ('Name', (1, 13), 'z', ('Load',)), ('Name', (1, 18), 'q', ('Store',)))], [('Pass', (1, 21))])]), -('Module', [('Raise', (1, 0), ('Call', (1, 6), ('Name', (1, 6), 'Exception', ('Load',)), [('Str', (1, 16), 'string')], [], None, None), None)]), +('Module', [('Raise', (1, 0), ('Call', (1, 15), ('Name', (1, 6), 'Exception', ('Load',)), [('Str', (1, 16), 'string')], [], None, None), None)]), ('Module', [('Try', (1, 0), [('Pass', (2, 2))], [('ExceptHandler', (3, 0), ('Name', (3, 7), 'Exception', ('Load',)), None, [('Pass', (4, 2))])], [], [])]), ('Module', [('Try', (1, 0), [('Pass', (2, 2))], [], [], [('Pass', (4, 2))])]), ('Module', [('Assert', (1, 0), ('Name', (1, 7), 'v', ('Load',)), None)]), @@ -1002,29 +1010,29 @@ single_results = [ ('Interactive', [('Expr', (1, 0), ('BinOp', (1, 0), ('Num', (1, 0), 1), ('Add',), ('Num', (1, 2), 2)))]), ] eval_results = [ -('Expression', ('Name', (1, 0), 'None', ('Load',))), +('Expression', ('NameConstant', (1, 0), None)), ('Expression', ('BoolOp', (1, 0), ('And',), [('Name', (1, 0), 'a', ('Load',)), ('Name', (1, 6), 'b', ('Load',))])), ('Expression', ('BinOp', (1, 0), ('Name', (1, 0), 'a', ('Load',)), ('Add',), ('Name', (1, 4), 'b', ('Load',)))), ('Expression', ('UnaryOp', (1, 0), ('Not',), ('Name', (1, 4), 'v', ('Load',)))), -('Expression', ('Lambda', (1, 0), ('arguments', [], None, None, [], None, None, [], []), ('Name', (1, 7), 'None', ('Load',)))), +('Expression', ('Lambda', (1, 0), ('arguments', [], None, [], [], None, []), ('NameConstant', (1, 7), None))), ('Expression', ('Dict', (1, 0), [('Num', (1, 2), 1)], [('Num', (1, 4), 2)])), ('Expression', ('Dict', (1, 0), [], [])), -('Expression', ('Set', (1, 0), [('Name', (1, 1), 'None', ('Load',))])), +('Expression', ('Set', (1, 0), [('NameConstant', (1, 1), None)])), ('Expression', ('Dict', (1, 0), [('Num', (2, 6), 1)], [('Num', (4, 10), 2)])), ('Expression', ('ListComp', (1, 1), ('Name', (1, 1), 'a', ('Load',)), [('comprehension', ('Name', (1, 7), 'b', ('Store',)), ('Name', (1, 12), 'c', ('Load',)), [('Name', (1, 17), 'd', ('Load',))])])), ('Expression', ('GeneratorExp', (1, 1), ('Name', (1, 1), 'a', ('Load',)), [('comprehension', ('Name', (1, 7), 'b', ('Store',)), ('Name', (1, 12), 'c', ('Load',)), [('Name', (1, 17), 'd', ('Load',))])])), ('Expression', ('Compare', (1, 0), ('Num', (1, 0), 1), [('Lt',), ('Lt',)], [('Num', (1, 4), 2), ('Num', (1, 8), 3)])), -('Expression', ('Call', (1, 0), ('Name', (1, 0), 'f', ('Load',)), [('Num', (1, 2), 1), ('Num', (1, 4), 2)], [('keyword', 'c', ('Num', (1, 8), 3))], ('Name', (1, 11), 'd', ('Load',)), ('Name', (1, 15), 'e', ('Load',)))), +('Expression', ('Call', (1, 1), ('Name', (1, 0), 'f', ('Load',)), [('Num', (1, 2), 1), ('Num', (1, 4), 2)], [('keyword', 'c', ('Num', (1, 8), 3))], ('Name', (1, 11), 'd', ('Load',)), ('Name', (1, 15), 'e', ('Load',)))), ('Expression', ('Num', (1, 0), 10)), ('Expression', ('Str', (1, 0), 'string')), -('Expression', ('Attribute', (1, 0), ('Name', (1, 0), 'a', ('Load',)), 'b', ('Load',))), -('Expression', ('Subscript', (1, 0), ('Name', (1, 0), 'a', ('Load',)), ('Slice', ('Name', (1, 2), 'b', ('Load',)), ('Name', (1, 4), 'c', ('Load',)), None), ('Load',))), +('Expression', ('Attribute', (1, 2), ('Name', (1, 0), 'a', ('Load',)), 'b', ('Load',))), +('Expression', ('Subscript', (1, 2), ('Name', (1, 0), 'a', ('Load',)), ('Slice', ('Name', (1, 2), 'b', ('Load',)), ('Name', (1, 4), 'c', ('Load',)), None), ('Load',))), ('Expression', ('Name', (1, 0), 'v', ('Load',))), ('Expression', ('List', (1, 0), [('Num', (1, 1), 1), ('Num', (1, 3), 2), ('Num', (1, 5), 3)], ('Load',))), ('Expression', ('List', (1, 0), [], ('Load',))), ('Expression', ('Tuple', (1, 0), [('Num', (1, 0), 1), ('Num', (1, 2), 2), ('Num', (1, 4), 3)], ('Load',))), ('Expression', ('Tuple', (1, 1), [('Num', (1, 1), 1), ('Num', (1, 3), 2), ('Num', (1, 5), 3)], ('Load',))), ('Expression', ('Tuple', (1, 0), [], ('Load',))), -('Expression', ('Call', (1, 0), ('Attribute', (1, 0), ('Attribute', (1, 0), ('Attribute', (1, 0), ('Name', (1, 0), 'a', ('Load',)), 'b', ('Load',)), 'c', ('Load',)), 'd', ('Load',)), [('Subscript', (1, 8), ('Attribute', (1, 8), ('Name', (1, 8), 'a', ('Load',)), 'b', ('Load',)), ('Slice', ('Num', (1, 12), 1), ('Num', (1, 14), 2), None), ('Load',))], [], None, None)), +('Expression', ('Call', (1, 7), ('Attribute', (1, 6), ('Attribute', (1, 4), ('Attribute', (1, 2), ('Name', (1, 0), 'a', ('Load',)), 'b', ('Load',)), 'c', ('Load',)), 'd', ('Load',)), [('Subscript', (1, 12), ('Attribute', (1, 10), ('Name', (1, 8), 'a', ('Load',)), 'b', ('Load',)), ('Slice', ('Num', (1, 12), 1), ('Num', (1, 14), 2), None), ('Load',))], [], None, None)), ] main() diff --git a/Lib/test/test_asynchat.py b/Lib/test/test_asynchat.py index c79fe6f613..f93a52d8c9 100644 --- a/Lib/test/test_asynchat.py +++ b/Lib/test/test_asynchat.py @@ -15,6 +15,7 @@ except ImportError: HOST = support.HOST SERVER_QUIT = b'QUIT\n' +TIMEOUT = 3.0 if threading: class echo_server(threading.Thread): @@ -123,7 +124,9 @@ class TestAsynchat(unittest.TestCase): c.push(b"I'm not dead yet!" + term) c.push(SERVER_QUIT) asyncore.loop(use_poll=self.usepoll, count=300, timeout=.01) - s.join() + s.join(timeout=TIMEOUT) + if s.is_alive(): + self.fail("join() timed out") self.assertEqual(c.contents, [b"hello world", b"I'm not dead yet!"]) @@ -154,7 +157,9 @@ class TestAsynchat(unittest.TestCase): c.push(data) c.push(SERVER_QUIT) asyncore.loop(use_poll=self.usepoll, count=300, timeout=.01) - s.join() + s.join(timeout=TIMEOUT) + if s.is_alive(): + self.fail("join() timed out") self.assertEqual(c.contents, [data[:termlen]]) @@ -174,7 +179,9 @@ class TestAsynchat(unittest.TestCase): c.push(data) c.push(SERVER_QUIT) asyncore.loop(use_poll=self.usepoll, count=300, timeout=.01) - s.join() + s.join(timeout=TIMEOUT) + if s.is_alive(): + self.fail("join() timed out") self.assertEqual(c.contents, []) self.assertEqual(c.buffer, data) @@ -186,7 +193,9 @@ class TestAsynchat(unittest.TestCase): p = asynchat.simple_producer(data+SERVER_QUIT, buffer_size=8) c.push_with_producer(p) asyncore.loop(use_poll=self.usepoll, count=300, timeout=.01) - s.join() + s.join(timeout=TIMEOUT) + if s.is_alive(): + self.fail("join() timed out") self.assertEqual(c.contents, [b"hello world", b"I'm not dead yet!"]) @@ -196,7 +205,9 @@ class TestAsynchat(unittest.TestCase): data = b"hello world\nI'm not dead yet!\n" c.push_with_producer(data+SERVER_QUIT) asyncore.loop(use_poll=self.usepoll, count=300, timeout=.01) - s.join() + s.join(timeout=TIMEOUT) + if s.is_alive(): + self.fail("join() timed out") self.assertEqual(c.contents, [b"hello world", b"I'm not dead yet!"]) @@ -207,7 +218,9 @@ class TestAsynchat(unittest.TestCase): c.push(b"hello world\n\nI'm not dead yet!\n") c.push(SERVER_QUIT) asyncore.loop(use_poll=self.usepoll, count=300, timeout=.01) - s.join() + s.join(timeout=TIMEOUT) + if s.is_alive(): + self.fail("join() timed out") self.assertEqual(c.contents, [b"hello world", b"", b"I'm not dead yet!"]) @@ -226,7 +239,9 @@ class TestAsynchat(unittest.TestCase): # where the server echoes all of its data before we can check that it # got any down below. s.start_resend_event.set() - s.join() + s.join(timeout=TIMEOUT) + if s.is_alive(): + self.fail("join() timed out") self.assertEqual(c.contents, []) # the server might have been able to send a byte or two back, but this @@ -268,9 +283,5 @@ class TestFifo(unittest.TestCase): self.assertEqual(f.pop(), (0, None)) -def test_main(verbose=None): - support.run_unittest(TestAsynchat, TestAsynchat_WithPoll, - TestHelperFunctions, TestFifo) - if __name__ == "__main__": - test_main(verbose=True) + unittest.main() diff --git a/Lib/test/test_asyncio/__init__.py b/Lib/test/test_asyncio/__init__.py new file mode 100644 index 0000000000..23ce5e8059 --- /dev/null +++ b/Lib/test/test_asyncio/__init__.py @@ -0,0 +1,31 @@ +import os +import sys +import unittest +from test.support import run_unittest, import_module + +# Skip tests if we don't have threading. +import_module('threading') +# Skip tests if we don't have concurrent.futures. +import_module('concurrent.futures') + + +def suite(): + tests_file = os.path.join(os.path.dirname(__file__), 'tests.txt') + with open(tests_file) as fp: + test_names = fp.read().splitlines() + tests = unittest.TestSuite() + loader = unittest.TestLoader() + for test_name in test_names: + mod_name = 'test.' + test_name + try: + __import__(mod_name) + except unittest.SkipTest: + pass + else: + mod = sys.modules[mod_name] + tests.addTests(loader.loadTestsFromModule(mod)) + return tests + + +def test_main(): + run_unittest(suite()) diff --git a/Lib/test/test_asyncio/__main__.py b/Lib/test/test_asyncio/__main__.py new file mode 100644 index 0000000000..b549492038 --- /dev/null +++ b/Lib/test/test_asyncio/__main__.py @@ -0,0 +1,5 @@ +from . import test_main + + +if __name__ == '__main__': + test_main() diff --git a/Lib/test/test_asyncio/echo.py b/Lib/test/test_asyncio/echo.py new file mode 100644 index 0000000000..006364bb00 --- /dev/null +++ b/Lib/test/test_asyncio/echo.py @@ -0,0 +1,8 @@ +import os + +if __name__ == '__main__': + while True: + buf = os.read(0, 1024) + if not buf: + break + os.write(1, buf) diff --git a/Lib/test/test_asyncio/echo2.py b/Lib/test/test_asyncio/echo2.py new file mode 100644 index 0000000000..e83ca09fb7 --- /dev/null +++ b/Lib/test/test_asyncio/echo2.py @@ -0,0 +1,6 @@ +import os + +if __name__ == '__main__': + buf = os.read(0, 1024) + os.write(1, b'OUT:'+buf) + os.write(2, b'ERR:'+buf) diff --git a/Lib/test/test_asyncio/echo3.py b/Lib/test/test_asyncio/echo3.py new file mode 100644 index 0000000000..064496736b --- /dev/null +++ b/Lib/test/test_asyncio/echo3.py @@ -0,0 +1,11 @@ +import os + +if __name__ == '__main__': + while True: + buf = os.read(0, 1024) + if not buf: + break + try: + os.write(1, b'OUT:'+buf) + except OSError as ex: + os.write(2, b'ERR:' + ex.__class__.__name__.encode('ascii')) diff --git a/Lib/test/test_asyncio/keycert3.pem b/Lib/test/test_asyncio/keycert3.pem new file mode 100644 index 0000000000..5bfa62c4ca --- /dev/null +++ b/Lib/test/test_asyncio/keycert3.pem @@ -0,0 +1,73 @@ +-----BEGIN PRIVATE KEY----- +MIICdgIBADANBgkqhkiG9w0BAQEFAASCAmAwggJcAgEAAoGBAMLgD0kAKDb5cFyP +jbwNfR5CtewdXC+kMXAWD8DLxiTTvhMW7qVnlwOm36mZlszHKvsRf05lT4pegiFM +9z2j1OlaN+ci/X7NU22TNN6crYSiN77FjYJP464j876ndSxyD+rzys386T+1r1aZ +aggEdkj1TsSsv1zWIYKlPIjlvhuxAgMBAAECgYA0aH+T2Vf3WOPv8KdkcJg6gCRe +yJKXOWgWRcicx/CUzOEsTxmFIDPLxqAWA3k7v0B+3vjGw5Y9lycV/5XqXNoQI14j +y09iNsumds13u5AKkGdTJnZhQ7UKdoVHfuP44ZdOv/rJ5/VD6F4zWywpe90pcbK+ +AWDVtusgGQBSieEl1QJBAOyVrUG5l2yoUBtd2zr/kiGm/DYyXlIthQO/A3/LngDW +5/ydGxVsT7lAVOgCsoT+0L4efTh90PjzW8LPQrPBWVMCQQDS3h/FtYYd5lfz+FNL +9CEe1F1w9l8P749uNUD0g317zv1tatIqVCsQWHfVHNdVvfQ+vSFw38OORO00Xqs9 +1GJrAkBkoXXEkxCZoy4PteheO/8IWWLGGr6L7di6MzFl1lIqwT6D8L9oaV2vynFT +DnKop0pa09Unhjyw57KMNmSE2SUJAkEArloTEzpgRmCq4IK2/NpCeGdHS5uqRlbh +1VIa/xGps7EWQl5Mn8swQDel/YP3WGHTjfx7pgSegQfkyaRtGpZ9OQJAa9Vumj8m +JAAtI0Bnga8hgQx7BhTQY4CadDxyiRGOGYhwUzYVCqkb2sbVRH9HnwUaJT7cWBY3 +RnJdHOMXWem7/w== +-----END PRIVATE KEY----- +Certificate: + Data: + Version: 1 (0x0) + Serial Number: 12723342612721443281 (0xb09264b1f2da21d1) + Signature Algorithm: sha1WithRSAEncryption + Issuer: C=XY, O=Python Software Foundation CA, CN=our-ca-server + Validity + Not Before: Jan 4 19:47:07 2013 GMT + Not After : Nov 13 19:47:07 2022 GMT + Subject: C=XY, L=Castle Anthrax, O=Python Software Foundation, CN=localhost + Subject Public Key Info: + Public Key Algorithm: rsaEncryption + Public-Key: (1024 bit) + Modulus: + 00:c2:e0:0f:49:00:28:36:f9:70:5c:8f:8d:bc:0d: + 7d:1e:42:b5:ec:1d:5c:2f:a4:31:70:16:0f:c0:cb: + c6:24:d3:be:13:16:ee:a5:67:97:03:a6:df:a9:99: + 96:cc:c7:2a:fb:11:7f:4e:65:4f:8a:5e:82:21:4c: + f7:3d:a3:d4:e9:5a:37:e7:22:fd:7e:cd:53:6d:93: + 34:de:9c:ad:84:a2:37:be:c5:8d:82:4f:e3:ae:23: + f3:be:a7:75:2c:72:0f:ea:f3:ca:cd:fc:e9:3f:b5: + af:56:99:6a:08:04:76:48:f5:4e:c4:ac:bf:5c:d6: + 21:82:a5:3c:88:e5:be:1b:b1 + Exponent: 65537 (0x10001) + Signature Algorithm: sha1WithRSAEncryption + 2f:42:5f:a3:09:2c:fa:51:88:c7:37:7f:ea:0e:63:f0:a2:9a: + e5:5a:e2:c8:20:f0:3f:60:bc:c8:0f:b6:c6:76:ce:db:83:93: + f5:a3:33:67:01:8e:04:cd:00:9a:73:fd:f3:35:86:fa:d7:13: + e2:46:c6:9d:c0:29:53:d4:a9:90:b8:77:4b:e6:83:76:e4:92: + d6:9c:50:cf:43:d0:c6:01:77:61:9a:de:9b:70:f7:72:cd:59: + 00:31:69:d9:b4:ca:06:9c:6d:c3:c7:80:8c:68:e6:b5:a2:f8: + ef:1d:bb:16:9f:77:77:ef:87:62:22:9b:4d:69:a4:3a:1a:f1: + 21:5e:8c:32:ac:92:fd:15:6b:18:c2:7f:15:0d:98:30:ca:75: + 8f:1a:71:df:da:1d:b2:ef:9a:e8:2d:2e:02:fd:4a:3c:aa:96: + 0b:06:5d:35:b3:3d:24:87:4b:e0:b0:58:60:2f:45:ac:2e:48: + 8a:b0:99:10:65:27:ff:cc:b1:d8:fd:bd:26:6b:b9:0c:05:2a: + f4:45:63:35:51:07:ed:83:85:fe:6f:69:cb:bb:40:a8:ae:b6: + 3b:56:4a:2d:a4:ed:6d:11:2c:4d:ed:17:24:fd:47:bc:d3:41: + a2:d3:06:fe:0c:90:d8:d8:94:26:c4:ff:cc:a1:d8:42:77:eb: + fc:a9:94:71 +-----BEGIN CERTIFICATE----- +MIICpDCCAYwCCQCwkmSx8toh0TANBgkqhkiG9w0BAQUFADBNMQswCQYDVQQGEwJY +WTEmMCQGA1UECgwdUHl0aG9uIFNvZnR3YXJlIEZvdW5kYXRpb24gQ0ExFjAUBgNV +BAMMDW91ci1jYS1zZXJ2ZXIwHhcNMTMwMTA0MTk0NzA3WhcNMjIxMTEzMTk0NzA3 +WjBfMQswCQYDVQQGEwJYWTEXMBUGA1UEBxMOQ2FzdGxlIEFudGhyYXgxIzAhBgNV +BAoTGlB5dGhvbiBTb2Z0d2FyZSBGb3VuZGF0aW9uMRIwEAYDVQQDEwlsb2NhbGhv +c3QwgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBAMLgD0kAKDb5cFyPjbwNfR5C +tewdXC+kMXAWD8DLxiTTvhMW7qVnlwOm36mZlszHKvsRf05lT4pegiFM9z2j1Ola +N+ci/X7NU22TNN6crYSiN77FjYJP464j876ndSxyD+rzys386T+1r1aZaggEdkj1 +TsSsv1zWIYKlPIjlvhuxAgMBAAEwDQYJKoZIhvcNAQEFBQADggEBAC9CX6MJLPpR +iMc3f+oOY/CimuVa4sgg8D9gvMgPtsZ2ztuDk/WjM2cBjgTNAJpz/fM1hvrXE+JG +xp3AKVPUqZC4d0vmg3bkktacUM9D0MYBd2Ga3ptw93LNWQAxadm0ygacbcPHgIxo +5rWi+O8duxafd3fvh2Iim01ppDoa8SFejDKskv0VaxjCfxUNmDDKdY8acd/aHbLv +mugtLgL9SjyqlgsGXTWzPSSHS+CwWGAvRawuSIqwmRBlJ//Msdj9vSZruQwFKvRF +YzVRB+2Dhf5vacu7QKiutjtWSi2k7W0RLE3tFyT9R7zTQaLTBv4MkNjYlCbE/8yh +2EJ36/yplHE= +-----END CERTIFICATE----- diff --git a/Lib/test/test_asyncio/pycacert.pem b/Lib/test/test_asyncio/pycacert.pem new file mode 100644 index 0000000000..09b1f3e08a --- /dev/null +++ b/Lib/test/test_asyncio/pycacert.pem @@ -0,0 +1,78 @@ +Certificate: + Data: + Version: 3 (0x2) + Serial Number: 12723342612721443280 (0xb09264b1f2da21d0) + Signature Algorithm: sha1WithRSAEncryption + Issuer: C=XY, O=Python Software Foundation CA, CN=our-ca-server + Validity + Not Before: Jan 4 19:47:07 2013 GMT + Not After : Jan 2 19:47:07 2023 GMT + Subject: C=XY, O=Python Software Foundation CA, CN=our-ca-server + Subject Public Key Info: + Public Key Algorithm: rsaEncryption + Public-Key: (2048 bit) + Modulus: + 00:e7:de:e9:e3:0c:9f:00:b6:a1:fd:2b:5b:96:d2: + 6f:cc:e0:be:86:b9:20:5e:ec:03:7a:55:ab:ea:a4: + e9:f9:49:85:d2:66:d5:ed:c7:7a:ea:56:8e:2d:8f: + e7:42:e2:62:28:a9:9f:d6:1b:8e:eb:b5:b4:9c:9f: + 14:ab:df:e6:94:8b:76:1d:3e:6d:24:61:ed:0c:bf: + 00:8a:61:0c:df:5c:c8:36:73:16:00:cd:47:ba:6d: + a4:a4:74:88:83:23:0a:19:fc:09:a7:3c:4a:4b:d3: + e7:1d:2d:e4:ea:4c:54:21:f3:26:db:89:37:18:d4: + 02:bb:40:32:5f:a4:ff:2d:1c:f7:d4:bb:ec:8e:cf: + 5c:82:ac:e6:7c:08:6c:48:85:61:07:7f:25:e0:5c: + e0:bc:34:5f:e0:b9:04:47:75:c8:47:0b:8d:bc:d6: + c8:68:5f:33:83:62:d2:20:44:35:b1:ad:81:1a:8a: + cd:bc:35:b0:5c:8b:47:d6:18:e9:9c:18:97:cc:01: + 3c:29:cc:e8:1e:e4:e4:c1:b8:de:e7:c2:11:18:87: + 5a:93:34:d8:a6:25:f7:14:71:eb:e4:21:a2:d2:0f: + 2e:2e:d4:62:00:35:d3:d6:ef:5c:60:4b:4c:a9:14: + e2:dd:15:58:46:37:33:26:b7:e7:2e:5d:ed:42:e4: + c5:4d + Exponent: 65537 (0x10001) + X509v3 extensions: + X509v3 Subject Key Identifier: + BC:DD:62:D9:76:DA:1B:D2:54:6B:CF:E0:66:9B:1E:1E:7B:56:0C:0B + X509v3 Authority Key Identifier: + keyid:BC:DD:62:D9:76:DA:1B:D2:54:6B:CF:E0:66:9B:1E:1E:7B:56:0C:0B + + X509v3 Basic Constraints: + CA:TRUE + Signature Algorithm: sha1WithRSAEncryption + 7d:0a:f5:cb:8d:d3:5d:bd:99:8e:f8:2b:0f:ba:eb:c2:d9:a6: + 27:4f:2e:7b:2f:0e:64:d8:1c:35:50:4e:ee:fc:90:b9:8d:6d: + a8:c5:c6:06:b0:af:f3:2d:bf:3b:b8:42:07:dd:18:7d:6d:95: + 54:57:85:18:60:47:2f:eb:78:1b:f9:e8:17:fd:5a:0d:87:17: + 28:ac:4c:6a:e6:bc:29:f4:f4:55:70:29:42:de:85:ea:ab:6c: + 23:06:64:30:75:02:8e:53:bc:5e:01:33:37:cc:1e:cd:b8:a4: + fd:ca:e4:5f:65:3b:83:1c:86:f1:55:02:a0:3a:8f:db:91:b7: + 40:14:b4:e7:8d:d2:ee:73:ba:e3:e5:34:2d:bc:94:6f:4e:24: + 06:f7:5f:8b:0e:a7:8e:6b:de:5e:75:f4:32:9a:50:b1:44:33: + 9a:d0:05:e2:78:82:ff:db:da:8a:63:eb:a9:dd:d1:bf:a0:61: + ad:e3:9e:8a:24:5d:62:0e:e7:4c:91:7f:ef:df:34:36:3b:2f: + 5d:f5:84:b2:2f:c4:6d:93:96:1a:6f:30:28:f1:da:12:9a:64: + b4:40:33:1d:bd:de:2b:53:a8:ea:be:d6:bc:4e:96:f5:44:fb: + 32:18:ae:d5:1f:f6:69:af:b6:4e:7b:1d:58:ec:3b:a9:53:a3: + 5e:58:c8:9e +-----BEGIN CERTIFICATE----- +MIIDbTCCAlWgAwIBAgIJALCSZLHy2iHQMA0GCSqGSIb3DQEBBQUAME0xCzAJBgNV +BAYTAlhZMSYwJAYDVQQKDB1QeXRob24gU29mdHdhcmUgRm91bmRhdGlvbiBDQTEW +MBQGA1UEAwwNb3VyLWNhLXNlcnZlcjAeFw0xMzAxMDQxOTQ3MDdaFw0yMzAxMDIx +OTQ3MDdaME0xCzAJBgNVBAYTAlhZMSYwJAYDVQQKDB1QeXRob24gU29mdHdhcmUg +Rm91bmRhdGlvbiBDQTEWMBQGA1UEAwwNb3VyLWNhLXNlcnZlcjCCASIwDQYJKoZI +hvcNAQEBBQADggEPADCCAQoCggEBAOfe6eMMnwC2of0rW5bSb8zgvoa5IF7sA3pV +q+qk6flJhdJm1e3HeupWji2P50LiYiipn9Ybjuu1tJyfFKvf5pSLdh0+bSRh7Qy/ +AIphDN9cyDZzFgDNR7ptpKR0iIMjChn8Cac8SkvT5x0t5OpMVCHzJtuJNxjUArtA +Ml+k/y0c99S77I7PXIKs5nwIbEiFYQd/JeBc4Lw0X+C5BEd1yEcLjbzWyGhfM4Ni +0iBENbGtgRqKzbw1sFyLR9YY6ZwYl8wBPCnM6B7k5MG43ufCERiHWpM02KYl9xRx +6+QhotIPLi7UYgA109bvXGBLTKkU4t0VWEY3Mya35y5d7ULkxU0CAwEAAaNQME4w +HQYDVR0OBBYEFLzdYtl22hvSVGvP4GabHh57VgwLMB8GA1UdIwQYMBaAFLzdYtl2 +2hvSVGvP4GabHh57VgwLMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEFBQADggEB +AH0K9cuN0129mY74Kw+668LZpidPLnsvDmTYHDVQTu78kLmNbajFxgawr/Mtvzu4 +QgfdGH1tlVRXhRhgRy/reBv56Bf9Wg2HFyisTGrmvCn09FVwKULeheqrbCMGZDB1 +Ao5TvF4BMzfMHs24pP3K5F9lO4MchvFVAqA6j9uRt0AUtOeN0u5zuuPlNC28lG9O +JAb3X4sOp45r3l519DKaULFEM5rQBeJ4gv/b2opj66nd0b+gYa3jnookXWIO50yR +f+/fNDY7L131hLIvxG2TlhpvMCjx2hKaZLRAMx293itTqOq+1rxOlvVE+zIYrtUf +9mmvtk57HVjsO6lTo15YyJ4= +-----END CERTIFICATE----- diff --git a/Lib/test/test_asyncio/ssl_cert.pem b/Lib/test/test_asyncio/ssl_cert.pem new file mode 100644 index 0000000000..47a7d7e37e --- /dev/null +++ b/Lib/test/test_asyncio/ssl_cert.pem @@ -0,0 +1,15 @@ +-----BEGIN CERTIFICATE----- +MIICVDCCAb2gAwIBAgIJANfHOBkZr8JOMA0GCSqGSIb3DQEBBQUAMF8xCzAJBgNV +BAYTAlhZMRcwFQYDVQQHEw5DYXN0bGUgQW50aHJheDEjMCEGA1UEChMaUHl0aG9u +IFNvZnR3YXJlIEZvdW5kYXRpb24xEjAQBgNVBAMTCWxvY2FsaG9zdDAeFw0xMDEw +MDgyMzAxNTZaFw0yMDEwMDUyMzAxNTZaMF8xCzAJBgNVBAYTAlhZMRcwFQYDVQQH +Ew5DYXN0bGUgQW50aHJheDEjMCEGA1UEChMaUHl0aG9uIFNvZnR3YXJlIEZvdW5k +YXRpb24xEjAQBgNVBAMTCWxvY2FsaG9zdDCBnzANBgkqhkiG9w0BAQEFAAOBjQAw +gYkCgYEA21vT5isq7F68amYuuNpSFlKDPrMUCa4YWYqZRt2OZ+/3NKaZ2xAiSwr7 +6MrQF70t5nLbSPpqE5+5VrS58SY+g/sXLiFd6AplH1wJZwh78DofbFYXUggktFMt +pTyiX8jtP66bkcPkDADA089RI1TQR6Ca+n7HFa7c1fabVV6i3zkCAwEAAaMYMBYw +FAYDVR0RBA0wC4IJbG9jYWxob3N0MA0GCSqGSIb3DQEBBQUAA4GBAHPctQBEQ4wd +BJ6+JcpIraopLn8BGhbjNWj40mmRqWB/NAWF6M5ne7KpGAu7tLeG4hb1zLaldK8G +lxy2GPSRF6LFS48dpEj2HbMv2nvv6xxalDMJ9+DicWgAKTQ6bcX2j3GUkCR0g/T1 +CRlNBAAlvhKzO7Clpf9l0YKBEfraJByX +-----END CERTIFICATE----- diff --git a/Lib/test/test_asyncio/ssl_key.pem b/Lib/test/test_asyncio/ssl_key.pem new file mode 100644 index 0000000000..3fd3bbd54a --- /dev/null +++ b/Lib/test/test_asyncio/ssl_key.pem @@ -0,0 +1,16 @@ +-----BEGIN PRIVATE KEY----- +MIICdwIBADANBgkqhkiG9w0BAQEFAASCAmEwggJdAgEAAoGBANtb0+YrKuxevGpm +LrjaUhZSgz6zFAmuGFmKmUbdjmfv9zSmmdsQIksK++jK0Be9LeZy20j6ahOfuVa0 +ufEmPoP7Fy4hXegKZR9cCWcIe/A6H2xWF1IIJLRTLaU8ol/I7T+um5HD5AwAwNPP +USNU0Eegmvp+xxWu3NX2m1Veot85AgMBAAECgYA3ZdZ673X0oexFlq7AAmrutkHt +CL7LvwrpOiaBjhyTxTeSNWzvtQBkIU8DOI0bIazA4UreAFffwtvEuPmonDb3F+Iq +SMAu42XcGyVZEl+gHlTPU9XRX7nTOXVt+MlRRRxL6t9GkGfUAXI3XxJDXW3c0vBK +UL9xqD8cORXOfE06rQJBAP8mEX1ERkR64Ptsoe4281vjTlNfIbs7NMPkUnrn9N/Y +BLhjNIfQ3HFZG8BTMLfX7kCS9D593DW5tV4Z9BP/c6cCQQDcFzCcVArNh2JSywOQ +ZfTfRbJg/Z5Lt9Fkngv1meeGNPgIMLN8Sg679pAOOWmzdMO3V706rNPzSVMME7E5 +oPIfAkEA8pDddarP5tCvTTgUpmTFbakm0KoTZm2+FzHcnA4jRh+XNTjTOv98Y6Ik +eO5d1ZnKXseWvkZncQgxfdnMqqpj5wJAcNq/RVne1DbYlwWchT2Si65MYmmJ8t+F +0mcsULqjOnEMwf5e+ptq5LzwbyrHZYq5FNk7ocufPv/ZQrcSSC+cFwJBAKvOJByS +x56qyGeZLOQlWS2JS3KJo59XuLFGqcbgN9Om9xFa41Yb4N9NvplFivsvZdw3m1Q/ +SPIXQuT8RMPDVNQ= +-----END PRIVATE KEY----- diff --git a/Lib/test/test_asyncio/test_base_events.py b/Lib/test/test_asyncio/test_base_events.py new file mode 100644 index 0000000000..340ca67d64 --- /dev/null +++ b/Lib/test/test_asyncio/test_base_events.py @@ -0,0 +1,926 @@ +"""Tests for base_events.py""" + +import errno +import logging +import socket +import sys +import time +import unittest +from unittest import mock +from test.support import IPV6_ENABLED + +import asyncio +from asyncio import base_events +from asyncio import constants +from asyncio import test_utils + + +MOCK_ANY = mock.ANY +PY34 = sys.version_info >= (3, 4) + + +class BaseEventLoopTests(unittest.TestCase): + + def setUp(self): + self.loop = base_events.BaseEventLoop() + self.loop._selector = mock.Mock() + asyncio.set_event_loop(None) + + def test_not_implemented(self): + m = mock.Mock() + self.assertRaises( + NotImplementedError, + self.loop._make_socket_transport, m, m) + self.assertRaises( + NotImplementedError, + self.loop._make_ssl_transport, m, m, m, m) + self.assertRaises( + NotImplementedError, + self.loop._make_datagram_transport, m, m) + self.assertRaises( + NotImplementedError, self.loop._process_events, []) + self.assertRaises( + NotImplementedError, self.loop._write_to_self) + self.assertRaises( + NotImplementedError, self.loop._read_from_self) + self.assertRaises( + NotImplementedError, + self.loop._make_read_pipe_transport, m, m) + self.assertRaises( + NotImplementedError, + self.loop._make_write_pipe_transport, m, m) + gen = self.loop._make_subprocess_transport(m, m, m, m, m, m, m) + self.assertRaises(NotImplementedError, next, iter(gen)) + + def test__add_callback_handle(self): + h = asyncio.Handle(lambda: False, (), self.loop) + + self.loop._add_callback(h) + self.assertFalse(self.loop._scheduled) + self.assertIn(h, self.loop._ready) + + def test__add_callback_timer(self): + h = asyncio.TimerHandle(time.monotonic()+10, lambda: False, (), + self.loop) + + self.loop._add_callback(h) + self.assertIn(h, self.loop._scheduled) + + def test__add_callback_cancelled_handle(self): + h = asyncio.Handle(lambda: False, (), self.loop) + h.cancel() + + self.loop._add_callback(h) + self.assertFalse(self.loop._scheduled) + self.assertFalse(self.loop._ready) + + def test_set_default_executor(self): + executor = mock.Mock() + self.loop.set_default_executor(executor) + self.assertIs(executor, self.loop._default_executor) + + def test_getnameinfo(self): + sockaddr = mock.Mock() + self.loop.run_in_executor = mock.Mock() + self.loop.getnameinfo(sockaddr) + self.assertEqual( + (None, socket.getnameinfo, sockaddr, 0), + self.loop.run_in_executor.call_args[0]) + + def test_call_soon(self): + def cb(): + pass + + h = self.loop.call_soon(cb) + self.assertEqual(h._callback, cb) + self.assertIsInstance(h, asyncio.Handle) + self.assertIn(h, self.loop._ready) + + def test_call_later(self): + def cb(): + pass + + h = self.loop.call_later(10.0, cb) + self.assertIsInstance(h, asyncio.TimerHandle) + self.assertIn(h, self.loop._scheduled) + self.assertNotIn(h, self.loop._ready) + + def test_call_later_negative_delays(self): + calls = [] + + def cb(arg): + calls.append(arg) + + self.loop._process_events = mock.Mock() + self.loop.call_later(-1, cb, 'a') + self.loop.call_later(-2, cb, 'b') + test_utils.run_briefly(self.loop) + self.assertEqual(calls, ['b', 'a']) + + def test_time_and_call_at(self): + def cb(): + self.loop.stop() + + self.loop._process_events = mock.Mock() + delay = 0.1 + + when = self.loop.time() + delay + self.loop.call_at(when, cb) + t0 = self.loop.time() + self.loop.run_forever() + dt = self.loop.time() - t0 + + # 50 ms: maximum granularity of the event loop + self.assertGreaterEqual(dt, delay - 0.050, dt) + # tolerate a difference of +800 ms because some Python buildbots + # are really slow + self.assertLessEqual(dt, 0.9, dt) + + def test_run_once_in_executor_handle(self): + def cb(): + pass + + self.assertRaises( + AssertionError, self.loop.run_in_executor, + None, asyncio.Handle(cb, (), self.loop), ('',)) + self.assertRaises( + AssertionError, self.loop.run_in_executor, + None, asyncio.TimerHandle(10, cb, (), self.loop)) + + def test_run_once_in_executor_cancelled(self): + def cb(): + pass + h = asyncio.Handle(cb, (), self.loop) + h.cancel() + + f = self.loop.run_in_executor(None, h) + self.assertIsInstance(f, asyncio.Future) + self.assertTrue(f.done()) + self.assertIsNone(f.result()) + + def test_run_once_in_executor_plain(self): + def cb(): + pass + h = asyncio.Handle(cb, (), self.loop) + f = asyncio.Future(loop=self.loop) + executor = mock.Mock() + executor.submit.return_value = f + + self.loop.set_default_executor(executor) + + res = self.loop.run_in_executor(None, h) + self.assertIs(f, res) + + executor = mock.Mock() + executor.submit.return_value = f + res = self.loop.run_in_executor(executor, h) + self.assertIs(f, res) + self.assertTrue(executor.submit.called) + + f.cancel() # Don't complain about abandoned Future. + + def test__run_once(self): + h1 = asyncio.TimerHandle(time.monotonic() + 5.0, lambda: True, (), + self.loop) + h2 = asyncio.TimerHandle(time.monotonic() + 10.0, lambda: True, (), + self.loop) + + h1.cancel() + + self.loop._process_events = mock.Mock() + self.loop._scheduled.append(h1) + self.loop._scheduled.append(h2) + self.loop._run_once() + + t = self.loop._selector.select.call_args[0][0] + self.assertTrue(9.5 < t < 10.5, t) + self.assertEqual([h2], self.loop._scheduled) + self.assertTrue(self.loop._process_events.called) + + def test_set_debug(self): + self.loop.set_debug(True) + self.assertTrue(self.loop.get_debug()) + self.loop.set_debug(False) + self.assertFalse(self.loop.get_debug()) + + @mock.patch('asyncio.base_events.time') + @mock.patch('asyncio.base_events.logger') + def test__run_once_logging(self, m_logger, m_time): + # Log to INFO level if timeout > 1.0 sec. + idx = -1 + data = [10.0, 10.0, 12.0, 13.0] + + def monotonic(): + nonlocal data, idx + idx += 1 + return data[idx] + + m_time.monotonic = monotonic + + self.loop._scheduled.append( + asyncio.TimerHandle(11.0, lambda: True, (), self.loop)) + self.loop._process_events = mock.Mock() + self.loop._run_once() + self.assertEqual(logging.INFO, m_logger.log.call_args[0][0]) + + idx = -1 + data = [10.0, 10.0, 10.3, 13.0] + self.loop._scheduled = [asyncio.TimerHandle(11.0, lambda: True, (), + self.loop)] + self.loop._run_once() + self.assertEqual(logging.DEBUG, m_logger.log.call_args[0][0]) + + def test__run_once_schedule_handle(self): + handle = None + processed = False + + def cb(loop): + nonlocal processed, handle + processed = True + handle = loop.call_soon(lambda: True) + + h = asyncio.TimerHandle(time.monotonic() - 1, cb, (self.loop,), + self.loop) + + self.loop._process_events = mock.Mock() + self.loop._scheduled.append(h) + self.loop._run_once() + + self.assertTrue(processed) + self.assertEqual([handle], list(self.loop._ready)) + + def test_run_until_complete_type_error(self): + self.assertRaises(TypeError, + self.loop.run_until_complete, 'blah') + + def test_subprocess_exec_invalid_args(self): + args = [sys.executable, '-c', 'pass'] + + # missing program parameter (empty args) + self.assertRaises(TypeError, + self.loop.run_until_complete, self.loop.subprocess_exec, + asyncio.SubprocessProtocol) + + # exepected multiple arguments, not a list + self.assertRaises(TypeError, + self.loop.run_until_complete, self.loop.subprocess_exec, + asyncio.SubprocessProtocol, args) + + # program arguments must be strings, not int + self.assertRaises(TypeError, + self.loop.run_until_complete, self.loop.subprocess_exec, + asyncio.SubprocessProtocol, sys.executable, 123) + + # universal_newlines, shell, bufsize must not be set + self.assertRaises(TypeError, + self.loop.run_until_complete, self.loop.subprocess_exec, + asyncio.SubprocessProtocol, *args, universal_newlines=True) + self.assertRaises(TypeError, + self.loop.run_until_complete, self.loop.subprocess_exec, + asyncio.SubprocessProtocol, *args, shell=True) + self.assertRaises(TypeError, + self.loop.run_until_complete, self.loop.subprocess_exec, + asyncio.SubprocessProtocol, *args, bufsize=4096) + + def test_subprocess_shell_invalid_args(self): + # expected a string, not an int or a list + self.assertRaises(TypeError, + self.loop.run_until_complete, self.loop.subprocess_shell, + asyncio.SubprocessProtocol, 123) + self.assertRaises(TypeError, + self.loop.run_until_complete, self.loop.subprocess_shell, + asyncio.SubprocessProtocol, [sys.executable, '-c', 'pass']) + + # universal_newlines, shell, bufsize must not be set + self.assertRaises(TypeError, + self.loop.run_until_complete, self.loop.subprocess_shell, + asyncio.SubprocessProtocol, 'exit 0', universal_newlines=True) + self.assertRaises(TypeError, + self.loop.run_until_complete, self.loop.subprocess_shell, + asyncio.SubprocessProtocol, 'exit 0', shell=True) + self.assertRaises(TypeError, + self.loop.run_until_complete, self.loop.subprocess_shell, + asyncio.SubprocessProtocol, 'exit 0', bufsize=4096) + + def test_default_exc_handler_callback(self): + self.loop._process_events = mock.Mock() + + def zero_error(fut): + fut.set_result(True) + 1/0 + + # Test call_soon (events.Handle) + with mock.patch('asyncio.base_events.logger') as log: + fut = asyncio.Future(loop=self.loop) + self.loop.call_soon(zero_error, fut) + fut.add_done_callback(lambda fut: self.loop.stop()) + self.loop.run_forever() + log.error.assert_called_with( + test_utils.MockPattern('Exception in callback.*zero'), + exc_info=(ZeroDivisionError, MOCK_ANY, MOCK_ANY)) + + # Test call_later (events.TimerHandle) + with mock.patch('asyncio.base_events.logger') as log: + fut = asyncio.Future(loop=self.loop) + self.loop.call_later(0.01, zero_error, fut) + fut.add_done_callback(lambda fut: self.loop.stop()) + self.loop.run_forever() + log.error.assert_called_with( + test_utils.MockPattern('Exception in callback.*zero'), + exc_info=(ZeroDivisionError, MOCK_ANY, MOCK_ANY)) + + def test_default_exc_handler_coro(self): + self.loop._process_events = mock.Mock() + + @asyncio.coroutine + def zero_error_coro(): + yield from asyncio.sleep(0.01, loop=self.loop) + 1/0 + + # Test Future.__del__ + with mock.patch('asyncio.base_events.logger') as log: + fut = asyncio.async(zero_error_coro(), loop=self.loop) + fut.add_done_callback(lambda *args: self.loop.stop()) + self.loop.run_forever() + fut = None # Trigger Future.__del__ or futures._TracebackLogger + if PY34: + # Future.__del__ in Python 3.4 logs error with + # an actual exception context + log.error.assert_called_with( + test_utils.MockPattern('.*exception was never retrieved'), + exc_info=(ZeroDivisionError, MOCK_ANY, MOCK_ANY)) + else: + # futures._TracebackLogger logs only textual traceback + log.error.assert_called_with( + test_utils.MockPattern( + '.*exception was never retrieved.*ZeroDiv'), + exc_info=False) + + def test_set_exc_handler_invalid(self): + with self.assertRaisesRegex(TypeError, 'A callable object or None'): + self.loop.set_exception_handler('spam') + + def test_set_exc_handler_custom(self): + def zero_error(): + 1/0 + + def run_loop(): + self.loop.call_soon(zero_error) + self.loop._run_once() + + self.loop._process_events = mock.Mock() + + mock_handler = mock.Mock() + self.loop.set_exception_handler(mock_handler) + run_loop() + mock_handler.assert_called_with(self.loop, { + 'exception': MOCK_ANY, + 'message': test_utils.MockPattern( + 'Exception in callback.*zero_error'), + 'handle': MOCK_ANY, + }) + mock_handler.reset_mock() + + self.loop.set_exception_handler(None) + with mock.patch('asyncio.base_events.logger') as log: + run_loop() + log.error.assert_called_with( + test_utils.MockPattern( + 'Exception in callback.*zero'), + exc_info=(ZeroDivisionError, MOCK_ANY, MOCK_ANY)) + + assert not mock_handler.called + + def test_set_exc_handler_broken(self): + def run_loop(): + def zero_error(): + 1/0 + self.loop.call_soon(zero_error) + self.loop._run_once() + + def handler(loop, context): + raise AttributeError('spam') + + self.loop._process_events = mock.Mock() + + self.loop.set_exception_handler(handler) + + with mock.patch('asyncio.base_events.logger') as log: + run_loop() + log.error.assert_called_with( + test_utils.MockPattern( + 'Unhandled error in exception handler'), + exc_info=(AttributeError, MOCK_ANY, MOCK_ANY)) + + def test_default_exc_handler_broken(self): + _context = None + + class Loop(base_events.BaseEventLoop): + + _selector = mock.Mock() + _process_events = mock.Mock() + + def default_exception_handler(self, context): + nonlocal _context + _context = context + # Simulates custom buggy "default_exception_handler" + raise ValueError('spam') + + loop = Loop() + asyncio.set_event_loop(loop) + + def run_loop(): + def zero_error(): + 1/0 + loop.call_soon(zero_error) + loop._run_once() + + with mock.patch('asyncio.base_events.logger') as log: + run_loop() + log.error.assert_called_with( + 'Exception in default exception handler', + exc_info=True) + + def custom_handler(loop, context): + raise ValueError('ham') + + _context = None + loop.set_exception_handler(custom_handler) + with mock.patch('asyncio.base_events.logger') as log: + run_loop() + log.error.assert_called_with( + test_utils.MockPattern('Exception in default exception.*' + 'while handling.*in custom'), + exc_info=True) + + # Check that original context was passed to default + # exception handler. + self.assertIn('context', _context) + self.assertIs(type(_context['context']['exception']), + ZeroDivisionError) + + +class MyProto(asyncio.Protocol): + done = None + + def __init__(self, create_future=False): + self.state = 'INITIAL' + self.nbytes = 0 + if create_future: + self.done = asyncio.Future() + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n') + + def data_received(self, data): + assert self.state == 'CONNECTED', self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == 'CONNECTED', self.state + self.state = 'EOF' + + def connection_lost(self, exc): + assert self.state in ('CONNECTED', 'EOF'), self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class MyDatagramProto(asyncio.DatagramProtocol): + done = None + + def __init__(self, create_future=False): + self.state = 'INITIAL' + self.nbytes = 0 + if create_future: + self.done = asyncio.Future() + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'INITIALIZED' + + def datagram_received(self, data, addr): + assert self.state == 'INITIALIZED', self.state + self.nbytes += len(data) + + def error_received(self, exc): + assert self.state == 'INITIALIZED', self.state + + def connection_lost(self, exc): + assert self.state == 'INITIALIZED', self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class BaseEventLoopWithSelectorTests(unittest.TestCase): + + def setUp(self): + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + @mock.patch('asyncio.base_events.socket') + def test_create_connection_multiple_errors(self, m_socket): + + class MyProto(asyncio.Protocol): + pass + + @asyncio.coroutine + def getaddrinfo(*args, **kw): + yield from [] + return [(2, 1, 6, '', ('107.6.106.82', 80)), + (2, 1, 6, '', ('107.6.106.82', 80))] + + def getaddrinfo_task(*args, **kwds): + return asyncio.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + idx = -1 + errors = ['err1', 'err2'] + + def _socket(*args, **kw): + nonlocal idx, errors + idx += 1 + raise OSError(errors[idx]) + + m_socket.socket = _socket + + self.loop.getaddrinfo = getaddrinfo_task + + coro = self.loop.create_connection(MyProto, 'example.com', 80) + with self.assertRaises(OSError) as cm: + self.loop.run_until_complete(coro) + + self.assertEqual(str(cm.exception), 'Multiple exceptions: err1, err2') + + def test_create_connection_host_port_sock(self): + coro = self.loop.create_connection( + MyProto, 'example.com', 80, sock=object()) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + + def test_create_connection_no_host_port_sock(self): + coro = self.loop.create_connection(MyProto) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + + def test_create_connection_no_getaddrinfo(self): + @asyncio.coroutine + def getaddrinfo(*args, **kw): + yield from [] + + def getaddrinfo_task(*args, **kwds): + return asyncio.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task + coro = self.loop.create_connection(MyProto, 'example.com', 80) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + def test_create_connection_connect_err(self): + @asyncio.coroutine + def getaddrinfo(*args, **kw): + yield from [] + return [(2, 1, 6, '', ('107.6.106.82', 80))] + + def getaddrinfo_task(*args, **kwds): + return asyncio.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task + self.loop.sock_connect = mock.Mock() + self.loop.sock_connect.side_effect = OSError + + coro = self.loop.create_connection(MyProto, 'example.com', 80) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + def test_create_connection_multiple(self): + @asyncio.coroutine + def getaddrinfo(*args, **kw): + return [(2, 1, 6, '', ('0.0.0.1', 80)), + (2, 1, 6, '', ('0.0.0.2', 80))] + + def getaddrinfo_task(*args, **kwds): + return asyncio.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task + self.loop.sock_connect = mock.Mock() + self.loop.sock_connect.side_effect = OSError + + coro = self.loop.create_connection( + MyProto, 'example.com', 80, family=socket.AF_INET) + with self.assertRaises(OSError): + self.loop.run_until_complete(coro) + + @mock.patch('asyncio.base_events.socket') + def test_create_connection_multiple_errors_local_addr(self, m_socket): + + def bind(addr): + if addr[0] == '0.0.0.1': + err = OSError('Err') + err.strerror = 'Err' + raise err + + m_socket.socket.return_value.bind = bind + + @asyncio.coroutine + def getaddrinfo(*args, **kw): + return [(2, 1, 6, '', ('0.0.0.1', 80)), + (2, 1, 6, '', ('0.0.0.2', 80))] + + def getaddrinfo_task(*args, **kwds): + return asyncio.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task + self.loop.sock_connect = mock.Mock() + self.loop.sock_connect.side_effect = OSError('Err2') + + coro = self.loop.create_connection( + MyProto, 'example.com', 80, family=socket.AF_INET, + local_addr=(None, 8080)) + with self.assertRaises(OSError) as cm: + self.loop.run_until_complete(coro) + + self.assertTrue(str(cm.exception).startswith('Multiple exceptions: ')) + self.assertTrue(m_socket.socket.return_value.close.called) + + def test_create_connection_no_local_addr(self): + @asyncio.coroutine + def getaddrinfo(host, *args, **kw): + if host == 'example.com': + return [(2, 1, 6, '', ('107.6.106.82', 80)), + (2, 1, 6, '', ('107.6.106.82', 80))] + else: + return [] + + def getaddrinfo_task(*args, **kwds): + return asyncio.Task(getaddrinfo(*args, **kwds), loop=self.loop) + self.loop.getaddrinfo = getaddrinfo_task + + coro = self.loop.create_connection( + MyProto, 'example.com', 80, family=socket.AF_INET, + local_addr=(None, 8080)) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + def test_create_connection_ssl_server_hostname_default(self): + self.loop.getaddrinfo = mock.Mock() + + def mock_getaddrinfo(*args, **kwds): + f = asyncio.Future(loop=self.loop) + f.set_result([(socket.AF_INET, socket.SOCK_STREAM, + socket.SOL_TCP, '', ('1.2.3.4', 80))]) + return f + + self.loop.getaddrinfo.side_effect = mock_getaddrinfo + self.loop.sock_connect = mock.Mock() + self.loop.sock_connect.return_value = () + self.loop._make_ssl_transport = mock.Mock() + + class _SelectorTransportMock: + _sock = None + + def close(self): + self._sock.close() + + def mock_make_ssl_transport(sock, protocol, sslcontext, waiter, + **kwds): + waiter.set_result(None) + transport = _SelectorTransportMock() + transport._sock = sock + return transport + + self.loop._make_ssl_transport.side_effect = mock_make_ssl_transport + ANY = mock.ANY + # First try the default server_hostname. + self.loop._make_ssl_transport.reset_mock() + coro = self.loop.create_connection(MyProto, 'python.org', 80, ssl=True) + transport, _ = self.loop.run_until_complete(coro) + transport.close() + self.loop._make_ssl_transport.assert_called_with( + ANY, ANY, ANY, ANY, + server_side=False, + server_hostname='python.org') + # Next try an explicit server_hostname. + self.loop._make_ssl_transport.reset_mock() + coro = self.loop.create_connection(MyProto, 'python.org', 80, ssl=True, + server_hostname='perl.com') + transport, _ = self.loop.run_until_complete(coro) + transport.close() + self.loop._make_ssl_transport.assert_called_with( + ANY, ANY, ANY, ANY, + server_side=False, + server_hostname='perl.com') + # Finally try an explicit empty server_hostname. + self.loop._make_ssl_transport.reset_mock() + coro = self.loop.create_connection(MyProto, 'python.org', 80, ssl=True, + server_hostname='') + transport, _ = self.loop.run_until_complete(coro) + transport.close() + self.loop._make_ssl_transport.assert_called_with(ANY, ANY, ANY, ANY, + server_side=False, + server_hostname='') + + def test_create_connection_no_ssl_server_hostname_errors(self): + # When not using ssl, server_hostname must be None. + coro = self.loop.create_connection(MyProto, 'python.org', 80, + server_hostname='') + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + coro = self.loop.create_connection(MyProto, 'python.org', 80, + server_hostname='python.org') + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + + def test_create_connection_ssl_server_hostname_errors(self): + # When using ssl, server_hostname may be None if host is non-empty. + coro = self.loop.create_connection(MyProto, '', 80, ssl=True) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + coro = self.loop.create_connection(MyProto, None, 80, ssl=True) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + sock = socket.socket() + coro = self.loop.create_connection(MyProto, None, None, + ssl=True, sock=sock) + self.addCleanup(sock.close) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + + def test_create_server_empty_host(self): + # if host is empty string use None instead + host = object() + + @asyncio.coroutine + def getaddrinfo(*args, **kw): + nonlocal host + host = args[0] + yield from [] + + def getaddrinfo_task(*args, **kwds): + return asyncio.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task + fut = self.loop.create_server(MyProto, '', 0) + self.assertRaises(OSError, self.loop.run_until_complete, fut) + self.assertIsNone(host) + + def test_create_server_host_port_sock(self): + fut = self.loop.create_server( + MyProto, '0.0.0.0', 0, sock=object()) + self.assertRaises(ValueError, self.loop.run_until_complete, fut) + + def test_create_server_no_host_port_sock(self): + fut = self.loop.create_server(MyProto) + self.assertRaises(ValueError, self.loop.run_until_complete, fut) + + def test_create_server_no_getaddrinfo(self): + getaddrinfo = self.loop.getaddrinfo = mock.Mock() + getaddrinfo.return_value = [] + + f = self.loop.create_server(MyProto, '0.0.0.0', 0) + self.assertRaises(OSError, self.loop.run_until_complete, f) + + @mock.patch('asyncio.base_events.socket') + def test_create_server_cant_bind(self, m_socket): + + class Err(OSError): + strerror = 'error' + + m_socket.getaddrinfo.return_value = [ + (2, 1, 6, '', ('127.0.0.1', 10100))] + m_socket.getaddrinfo._is_coroutine = False + m_sock = m_socket.socket.return_value = mock.Mock() + m_sock.bind.side_effect = Err + + fut = self.loop.create_server(MyProto, '0.0.0.0', 0) + self.assertRaises(OSError, self.loop.run_until_complete, fut) + self.assertTrue(m_sock.close.called) + + @mock.patch('asyncio.base_events.socket') + def test_create_datagram_endpoint_no_addrinfo(self, m_socket): + m_socket.getaddrinfo.return_value = [] + m_socket.getaddrinfo._is_coroutine = False + + coro = self.loop.create_datagram_endpoint( + MyDatagramProto, local_addr=('localhost', 0)) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + def test_create_datagram_endpoint_addr_error(self): + coro = self.loop.create_datagram_endpoint( + MyDatagramProto, local_addr='localhost') + self.assertRaises( + AssertionError, self.loop.run_until_complete, coro) + coro = self.loop.create_datagram_endpoint( + MyDatagramProto, local_addr=('localhost', 1, 2, 3)) + self.assertRaises( + AssertionError, self.loop.run_until_complete, coro) + + def test_create_datagram_endpoint_connect_err(self): + self.loop.sock_connect = mock.Mock() + self.loop.sock_connect.side_effect = OSError + + coro = self.loop.create_datagram_endpoint( + asyncio.DatagramProtocol, remote_addr=('127.0.0.1', 0)) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + @mock.patch('asyncio.base_events.socket') + def test_create_datagram_endpoint_socket_err(self, m_socket): + m_socket.getaddrinfo = socket.getaddrinfo + m_socket.socket.side_effect = OSError + + coro = self.loop.create_datagram_endpoint( + asyncio.DatagramProtocol, family=socket.AF_INET) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + coro = self.loop.create_datagram_endpoint( + asyncio.DatagramProtocol, local_addr=('127.0.0.1', 0)) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + @unittest.skipUnless(IPV6_ENABLED, 'IPv6 not supported or enabled') + def test_create_datagram_endpoint_no_matching_family(self): + coro = self.loop.create_datagram_endpoint( + asyncio.DatagramProtocol, + remote_addr=('127.0.0.1', 0), local_addr=('::1', 0)) + self.assertRaises( + ValueError, self.loop.run_until_complete, coro) + + @mock.patch('asyncio.base_events.socket') + def test_create_datagram_endpoint_setblk_err(self, m_socket): + m_socket.socket.return_value.setblocking.side_effect = OSError + + coro = self.loop.create_datagram_endpoint( + asyncio.DatagramProtocol, family=socket.AF_INET) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + self.assertTrue( + m_socket.socket.return_value.close.called) + + def test_create_datagram_endpoint_noaddr_nofamily(self): + coro = self.loop.create_datagram_endpoint( + asyncio.DatagramProtocol) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + + @mock.patch('asyncio.base_events.socket') + def test_create_datagram_endpoint_cant_bind(self, m_socket): + class Err(OSError): + pass + + m_socket.AF_INET6 = socket.AF_INET6 + m_socket.getaddrinfo = socket.getaddrinfo + m_sock = m_socket.socket.return_value = mock.Mock() + m_sock.bind.side_effect = Err + + fut = self.loop.create_datagram_endpoint( + MyDatagramProto, + local_addr=('127.0.0.1', 0), family=socket.AF_INET) + self.assertRaises(Err, self.loop.run_until_complete, fut) + self.assertTrue(m_sock.close.called) + + def test_accept_connection_retry(self): + sock = mock.Mock() + sock.accept.side_effect = BlockingIOError() + + self.loop._accept_connection(MyProto, sock) + self.assertFalse(sock.close.called) + + @mock.patch('asyncio.base_events.logger') + def test_accept_connection_exception(self, m_log): + sock = mock.Mock() + sock.fileno.return_value = 10 + sock.accept.side_effect = OSError(errno.EMFILE, 'Too many open files') + self.loop.remove_reader = mock.Mock() + self.loop.call_later = mock.Mock() + + self.loop._accept_connection(MyProto, sock) + self.assertTrue(m_log.error.called) + self.assertFalse(sock.close.called) + self.loop.remove_reader.assert_called_with(10) + self.loop.call_later.assert_called_with(constants.ACCEPT_RETRY_DELAY, + # self.loop._start_serving + mock.ANY, + MyProto, sock, None, None) + + def test_call_coroutine(self): + @asyncio.coroutine + def coroutine_function(): + pass + + with self.assertRaises(TypeError): + self.loop.call_soon(coroutine_function) + with self.assertRaises(TypeError): + self.loop.call_soon_threadsafe(coroutine_function) + with self.assertRaises(TypeError): + self.loop.call_later(60, coroutine_function) + with self.assertRaises(TypeError): + self.loop.call_at(self.loop.time() + 60, coroutine_function) + with self.assertRaises(TypeError): + self.loop.run_in_executor(None, coroutine_function) + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_asyncio/test_events.py b/Lib/test/test_asyncio/test_events.py new file mode 100644 index 0000000000..f01d1f38a8 --- /dev/null +++ b/Lib/test/test_asyncio/test_events.py @@ -0,0 +1,2016 @@ +"""Tests for events.py.""" + +import functools +import gc +import io +import os +import platform +import signal +import socket +try: + import ssl +except ImportError: + ssl = None + HAS_SNI = False +else: + from ssl import HAS_SNI +import subprocess +import sys +import threading +import time +import errno +import unittest +from unittest import mock +from test import support # find_unused_port, IPV6_ENABLED, TEST_HOME_DIR + + +import asyncio +from asyncio import selector_events +from asyncio import test_utils + + +def data_file(filename): + if hasattr(support, 'TEST_HOME_DIR'): + fullname = os.path.join(support.TEST_HOME_DIR, filename) + if os.path.isfile(fullname): + return fullname + fullname = os.path.join(os.path.dirname(__file__), filename) + if os.path.isfile(fullname): + return fullname + raise FileNotFoundError(filename) + + +def osx_tiger(): + """Return True if the platform is Mac OS 10.4 or older.""" + if sys.platform != 'darwin': + return False + version = platform.mac_ver()[0] + version = tuple(map(int, version.split('.'))) + return version < (10, 5) + + +ONLYCERT = data_file('ssl_cert.pem') +ONLYKEY = data_file('ssl_key.pem') +SIGNED_CERTFILE = data_file('keycert3.pem') +SIGNING_CA = data_file('pycacert.pem') + + +class MyBaseProto(asyncio.Protocol): + done = None + + def __init__(self, loop=None): + self.transport = None + self.state = 'INITIAL' + self.nbytes = 0 + if loop is not None: + self.done = asyncio.Future(loop=loop) + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + + def data_received(self, data): + assert self.state == 'CONNECTED', self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == 'CONNECTED', self.state + self.state = 'EOF' + + def connection_lost(self, exc): + assert self.state in ('CONNECTED', 'EOF'), self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class MyProto(MyBaseProto): + def connection_made(self, transport): + super().connection_made(transport) + transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n') + + +class MyDatagramProto(asyncio.DatagramProtocol): + done = None + + def __init__(self, loop=None): + self.state = 'INITIAL' + self.nbytes = 0 + if loop is not None: + self.done = asyncio.Future(loop=loop) + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'INITIALIZED' + + def datagram_received(self, data, addr): + assert self.state == 'INITIALIZED', self.state + self.nbytes += len(data) + + def error_received(self, exc): + assert self.state == 'INITIALIZED', self.state + + def connection_lost(self, exc): + assert self.state == 'INITIALIZED', self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class MyReadPipeProto(asyncio.Protocol): + done = None + + def __init__(self, loop=None): + self.state = ['INITIAL'] + self.nbytes = 0 + self.transport = None + if loop is not None: + self.done = asyncio.Future(loop=loop) + + def connection_made(self, transport): + self.transport = transport + assert self.state == ['INITIAL'], self.state + self.state.append('CONNECTED') + + def data_received(self, data): + assert self.state == ['INITIAL', 'CONNECTED'], self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == ['INITIAL', 'CONNECTED'], self.state + self.state.append('EOF') + + def connection_lost(self, exc): + if 'EOF' not in self.state: + self.state.append('EOF') # It is okay if EOF is missed. + assert self.state == ['INITIAL', 'CONNECTED', 'EOF'], self.state + self.state.append('CLOSED') + if self.done: + self.done.set_result(None) + + +class MyWritePipeProto(asyncio.BaseProtocol): + done = None + + def __init__(self, loop=None): + self.state = 'INITIAL' + self.transport = None + if loop is not None: + self.done = asyncio.Future(loop=loop) + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + + def connection_lost(self, exc): + assert self.state == 'CONNECTED', self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class MySubprocessProtocol(asyncio.SubprocessProtocol): + + def __init__(self, loop): + self.state = 'INITIAL' + self.transport = None + self.connected = asyncio.Future(loop=loop) + self.completed = asyncio.Future(loop=loop) + self.disconnects = {fd: asyncio.Future(loop=loop) for fd in range(3)} + self.data = {1: b'', 2: b''} + self.returncode = None + self.got_data = {1: asyncio.Event(loop=loop), + 2: asyncio.Event(loop=loop)} + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + self.connected.set_result(None) + + def connection_lost(self, exc): + assert self.state == 'CONNECTED', self.state + self.state = 'CLOSED' + self.completed.set_result(None) + + def pipe_data_received(self, fd, data): + assert self.state == 'CONNECTED', self.state + self.data[fd] += data + self.got_data[fd].set() + + def pipe_connection_lost(self, fd, exc): + assert self.state == 'CONNECTED', self.state + if exc: + self.disconnects[fd].set_exception(exc) + else: + self.disconnects[fd].set_result(exc) + + def process_exited(self): + assert self.state == 'CONNECTED', self.state + self.returncode = self.transport.get_returncode() + + +class EventLoopTestsMixin: + + def setUp(self): + super().setUp() + self.loop = self.create_event_loop() + asyncio.set_event_loop(None) + + def tearDown(self): + # just in case if we have transport close callbacks + test_utils.run_briefly(self.loop) + + self.loop.close() + gc.collect() + super().tearDown() + + def test_run_until_complete_nesting(self): + @asyncio.coroutine + def coro1(): + yield + + @asyncio.coroutine + def coro2(): + self.assertTrue(self.loop.is_running()) + self.loop.run_until_complete(coro1()) + + self.assertRaises( + RuntimeError, self.loop.run_until_complete, coro2()) + + # Note: because of the default Windows timing granularity of + # 15.6 msec, we use fairly long sleep times here (~100 msec). + + def test_run_until_complete(self): + t0 = self.loop.time() + self.loop.run_until_complete(asyncio.sleep(0.1, loop=self.loop)) + t1 = self.loop.time() + self.assertTrue(0.08 <= t1-t0 <= 0.8, t1-t0) + + def test_run_until_complete_stopped(self): + @asyncio.coroutine + def cb(): + self.loop.stop() + yield from asyncio.sleep(0.1, loop=self.loop) + task = cb() + self.assertRaises(RuntimeError, + self.loop.run_until_complete, task) + + def test_call_later(self): + results = [] + + def callback(arg): + results.append(arg) + self.loop.stop() + + self.loop.call_later(0.1, callback, 'hello world') + t0 = time.monotonic() + self.loop.run_forever() + t1 = time.monotonic() + self.assertEqual(results, ['hello world']) + self.assertTrue(0.08 <= t1-t0 <= 0.8, t1-t0) + + def test_call_soon(self): + results = [] + + def callback(arg1, arg2): + results.append((arg1, arg2)) + self.loop.stop() + + self.loop.call_soon(callback, 'hello', 'world') + self.loop.run_forever() + self.assertEqual(results, [('hello', 'world')]) + + def test_call_soon_threadsafe(self): + results = [] + lock = threading.Lock() + + def callback(arg): + results.append(arg) + if len(results) >= 2: + self.loop.stop() + + def run_in_thread(): + self.loop.call_soon_threadsafe(callback, 'hello') + lock.release() + + lock.acquire() + t = threading.Thread(target=run_in_thread) + t.start() + + with lock: + self.loop.call_soon(callback, 'world') + self.loop.run_forever() + t.join() + self.assertEqual(results, ['hello', 'world']) + + def test_call_soon_threadsafe_same_thread(self): + results = [] + + def callback(arg): + results.append(arg) + if len(results) >= 2: + self.loop.stop() + + self.loop.call_soon_threadsafe(callback, 'hello') + self.loop.call_soon(callback, 'world') + self.loop.run_forever() + self.assertEqual(results, ['hello', 'world']) + + def test_run_in_executor(self): + def run(arg): + return (arg, threading.get_ident()) + f2 = self.loop.run_in_executor(None, run, 'yo') + res, thread_id = self.loop.run_until_complete(f2) + self.assertEqual(res, 'yo') + self.assertNotEqual(thread_id, threading.get_ident()) + + def test_reader_callback(self): + r, w = test_utils.socketpair() + bytes_read = [] + + def reader(): + try: + data = r.recv(1024) + except BlockingIOError: + # Spurious readiness notifications are possible + # at least on Linux -- see man select. + return + if data: + bytes_read.append(data) + else: + self.assertTrue(self.loop.remove_reader(r.fileno())) + r.close() + + self.loop.add_reader(r.fileno(), reader) + self.loop.call_soon(w.send, b'abc') + test_utils.run_briefly(self.loop) + self.loop.call_soon(w.send, b'def') + test_utils.run_briefly(self.loop) + self.loop.call_soon(w.close) + self.loop.call_soon(self.loop.stop) + self.loop.run_forever() + self.assertEqual(b''.join(bytes_read), b'abcdef') + + def test_writer_callback(self): + r, w = test_utils.socketpair() + w.setblocking(False) + self.loop.add_writer(w.fileno(), w.send, b'x'*(256*1024)) + test_utils.run_briefly(self.loop) + + def remove_writer(): + self.assertTrue(self.loop.remove_writer(w.fileno())) + + self.loop.call_soon(remove_writer) + self.loop.call_soon(self.loop.stop) + self.loop.run_forever() + w.close() + data = r.recv(256*1024) + r.close() + self.assertGreaterEqual(len(data), 200) + + def _basetest_sock_client_ops(self, httpd, sock): + sock.setblocking(False) + self.loop.run_until_complete( + self.loop.sock_connect(sock, httpd.address)) + self.loop.run_until_complete( + self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) + data = self.loop.run_until_complete( + self.loop.sock_recv(sock, 1024)) + # consume data + self.loop.run_until_complete( + self.loop.sock_recv(sock, 1024)) + sock.close() + self.assertTrue(data.startswith(b'HTTP/1.0 200 OK')) + + def test_sock_client_ops(self): + with test_utils.run_test_server() as httpd: + sock = socket.socket() + self._basetest_sock_client_ops(httpd, sock) + + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_unix_sock_client_ops(self): + with test_utils.run_test_unix_server() as httpd: + sock = socket.socket(socket.AF_UNIX) + self._basetest_sock_client_ops(httpd, sock) + + def test_sock_client_fail(self): + # Make sure that we will get an unused port + address = None + try: + s = socket.socket() + s.bind(('127.0.0.1', 0)) + address = s.getsockname() + finally: + s.close() + + sock = socket.socket() + sock.setblocking(False) + with self.assertRaises(ConnectionRefusedError): + self.loop.run_until_complete( + self.loop.sock_connect(sock, address)) + sock.close() + + def test_sock_accept(self): + listener = socket.socket() + listener.setblocking(False) + listener.bind(('127.0.0.1', 0)) + listener.listen(1) + client = socket.socket() + client.connect(listener.getsockname()) + + f = self.loop.sock_accept(listener) + conn, addr = self.loop.run_until_complete(f) + self.assertEqual(conn.gettimeout(), 0) + self.assertEqual(addr, client.getsockname()) + self.assertEqual(client.getpeername(), listener.getsockname()) + client.close() + conn.close() + listener.close() + + @unittest.skipUnless(hasattr(signal, 'SIGKILL'), 'No SIGKILL') + def test_add_signal_handler(self): + caught = 0 + + def my_handler(): + nonlocal caught + caught += 1 + + # Check error behavior first. + self.assertRaises( + TypeError, self.loop.add_signal_handler, 'boom', my_handler) + self.assertRaises( + TypeError, self.loop.remove_signal_handler, 'boom') + self.assertRaises( + ValueError, self.loop.add_signal_handler, signal.NSIG+1, + my_handler) + self.assertRaises( + ValueError, self.loop.remove_signal_handler, signal.NSIG+1) + self.assertRaises( + ValueError, self.loop.add_signal_handler, 0, my_handler) + self.assertRaises( + ValueError, self.loop.remove_signal_handler, 0) + self.assertRaises( + ValueError, self.loop.add_signal_handler, -1, my_handler) + self.assertRaises( + ValueError, self.loop.remove_signal_handler, -1) + self.assertRaises( + RuntimeError, self.loop.add_signal_handler, signal.SIGKILL, + my_handler) + # Removing SIGKILL doesn't raise, since we don't call signal(). + self.assertFalse(self.loop.remove_signal_handler(signal.SIGKILL)) + # Now set a handler and handle it. + self.loop.add_signal_handler(signal.SIGINT, my_handler) + test_utils.run_briefly(self.loop) + os.kill(os.getpid(), signal.SIGINT) + test_utils.run_briefly(self.loop) + self.assertEqual(caught, 1) + # Removing it should restore the default handler. + self.assertTrue(self.loop.remove_signal_handler(signal.SIGINT)) + self.assertEqual(signal.getsignal(signal.SIGINT), + signal.default_int_handler) + # Removing again returns False. + self.assertFalse(self.loop.remove_signal_handler(signal.SIGINT)) + + @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') + def test_signal_handling_while_selecting(self): + # Test with a signal actually arriving during a select() call. + caught = 0 + + def my_handler(): + nonlocal caught + caught += 1 + self.loop.stop() + + self.loop.add_signal_handler(signal.SIGALRM, my_handler) + + signal.setitimer(signal.ITIMER_REAL, 0.01, 0) # Send SIGALRM once. + self.loop.run_forever() + self.assertEqual(caught, 1) + + @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') + def test_signal_handling_args(self): + some_args = (42,) + caught = 0 + + def my_handler(*args): + nonlocal caught + caught += 1 + self.assertEqual(args, some_args) + + self.loop.add_signal_handler(signal.SIGALRM, my_handler, *some_args) + + signal.setitimer(signal.ITIMER_REAL, 0.1, 0) # Send SIGALRM once. + self.loop.call_later(0.5, self.loop.stop) + self.loop.run_forever() + self.assertEqual(caught, 1) + + def _basetest_create_connection(self, connection_fut, check_sockname=True): + tr, pr = self.loop.run_until_complete(connection_fut) + self.assertIsInstance(tr, asyncio.Transport) + self.assertIsInstance(pr, asyncio.Protocol) + if check_sockname: + self.assertIsNotNone(tr.get_extra_info('sockname')) + self.loop.run_until_complete(pr.done) + self.assertGreater(pr.nbytes, 0) + tr.close() + + def test_create_connection(self): + with test_utils.run_test_server() as httpd: + conn_fut = self.loop.create_connection( + lambda: MyProto(loop=self.loop), *httpd.address) + self._basetest_create_connection(conn_fut) + + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_create_unix_connection(self): + # Issue #20682: On Mac OS X Tiger, getsockname() returns a + # zero-length address for UNIX socket. + check_sockname = not osx_tiger() + + with test_utils.run_test_unix_server() as httpd: + conn_fut = self.loop.create_unix_connection( + lambda: MyProto(loop=self.loop), httpd.address) + self._basetest_create_connection(conn_fut, check_sockname) + + def test_create_connection_sock(self): + with test_utils.run_test_server() as httpd: + sock = None + infos = self.loop.run_until_complete( + self.loop.getaddrinfo( + *httpd.address, type=socket.SOCK_STREAM)) + for family, type, proto, cname, address in infos: + try: + sock = socket.socket(family=family, type=type, proto=proto) + sock.setblocking(False) + self.loop.run_until_complete( + self.loop.sock_connect(sock, address)) + except: + pass + else: + break + else: + assert False, 'Can not create socket.' + + f = self.loop.create_connection( + lambda: MyProto(loop=self.loop), sock=sock) + tr, pr = self.loop.run_until_complete(f) + self.assertIsInstance(tr, asyncio.Transport) + self.assertIsInstance(pr, asyncio.Protocol) + self.loop.run_until_complete(pr.done) + self.assertGreater(pr.nbytes, 0) + tr.close() + + def _basetest_create_ssl_connection(self, connection_fut, + check_sockname=True): + tr, pr = self.loop.run_until_complete(connection_fut) + self.assertIsInstance(tr, asyncio.Transport) + self.assertIsInstance(pr, asyncio.Protocol) + self.assertTrue('ssl' in tr.__class__.__name__.lower()) + if check_sockname: + self.assertIsNotNone(tr.get_extra_info('sockname')) + self.loop.run_until_complete(pr.done) + self.assertGreater(pr.nbytes, 0) + tr.close() + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_create_ssl_connection(self): + with test_utils.run_test_server(use_ssl=True) as httpd: + conn_fut = self.loop.create_connection( + lambda: MyProto(loop=self.loop), + *httpd.address, + ssl=test_utils.dummy_ssl_context()) + + self._basetest_create_ssl_connection(conn_fut) + + @unittest.skipIf(ssl is None, 'No ssl module') + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_create_ssl_unix_connection(self): + # Issue #20682: On Mac OS X Tiger, getsockname() returns a + # zero-length address for UNIX socket. + check_sockname = not osx_tiger() + + with test_utils.run_test_unix_server(use_ssl=True) as httpd: + conn_fut = self.loop.create_unix_connection( + lambda: MyProto(loop=self.loop), + httpd.address, + ssl=test_utils.dummy_ssl_context(), + server_hostname='127.0.0.1') + + self._basetest_create_ssl_connection(conn_fut, check_sockname) + + def test_create_connection_local_addr(self): + with test_utils.run_test_server() as httpd: + port = support.find_unused_port() + f = self.loop.create_connection( + lambda: MyProto(loop=self.loop), + *httpd.address, local_addr=(httpd.address[0], port)) + tr, pr = self.loop.run_until_complete(f) + expected = pr.transport.get_extra_info('sockname')[1] + self.assertEqual(port, expected) + tr.close() + + def test_create_connection_local_addr_in_use(self): + with test_utils.run_test_server() as httpd: + f = self.loop.create_connection( + lambda: MyProto(loop=self.loop), + *httpd.address, local_addr=httpd.address) + with self.assertRaises(OSError) as cm: + self.loop.run_until_complete(f) + self.assertEqual(cm.exception.errno, errno.EADDRINUSE) + self.assertIn(str(httpd.address), cm.exception.strerror) + + def test_create_server(self): + proto = MyProto() + f = self.loop.create_server(lambda: proto, '0.0.0.0', 0) + server = self.loop.run_until_complete(f) + self.assertEqual(len(server.sockets), 1) + sock = server.sockets[0] + host, port = sock.getsockname() + self.assertEqual(host, '0.0.0.0') + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.sendall(b'xxx') + test_utils.run_briefly(self.loop) + test_utils.run_until(self.loop, lambda: proto is not None, 10) + self.assertIsInstance(proto, MyProto) + self.assertEqual('INITIAL', proto.state) + test_utils.run_briefly(self.loop) + self.assertEqual('CONNECTED', proto.state) + test_utils.run_until(self.loop, lambda: proto.nbytes > 0, + timeout=10) + self.assertEqual(3, proto.nbytes) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('sockname')) + self.assertEqual('127.0.0.1', + proto.transport.get_extra_info('peername')[0]) + + # close connection + proto.transport.close() + test_utils.run_briefly(self.loop) # windows iocp + + self.assertEqual('CLOSED', proto.state) + + # the client socket must be closed after to avoid ECONNRESET upon + # recv()/send() on the serving socket + client.close() + + # close server + server.close() + + def _make_unix_server(self, factory, **kwargs): + path = test_utils.gen_unix_socket_path() + self.addCleanup(lambda: os.path.exists(path) and os.unlink(path)) + + f = self.loop.create_unix_server(factory, path, **kwargs) + server = self.loop.run_until_complete(f) + + return server, path + + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_create_unix_server(self): + proto = MyProto() + server, path = self._make_unix_server(lambda: proto) + self.assertEqual(len(server.sockets), 1) + + client = socket.socket(socket.AF_UNIX) + client.connect(path) + client.sendall(b'xxx') + test_utils.run_briefly(self.loop) + test_utils.run_until(self.loop, lambda: proto is not None, 10) + + self.assertIsInstance(proto, MyProto) + self.assertEqual('INITIAL', proto.state) + test_utils.run_briefly(self.loop) + self.assertEqual('CONNECTED', proto.state) + test_utils.run_until(self.loop, lambda: proto.nbytes > 0, + timeout=10) + self.assertEqual(3, proto.nbytes) + + # close connection + proto.transport.close() + test_utils.run_briefly(self.loop) # windows iocp + + self.assertEqual('CLOSED', proto.state) + + # the client socket must be closed after to avoid ECONNRESET upon + # recv()/send() on the serving socket + client.close() + + # close server + server.close() + + def _create_ssl_context(self, certfile, keyfile=None): + sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext.options |= ssl.OP_NO_SSLv2 + sslcontext.load_cert_chain(certfile, keyfile) + return sslcontext + + def _make_ssl_server(self, factory, certfile, keyfile=None): + sslcontext = self._create_ssl_context(certfile, keyfile) + + f = self.loop.create_server(factory, '127.0.0.1', 0, ssl=sslcontext) + server = self.loop.run_until_complete(f) + + sock = server.sockets[0] + host, port = sock.getsockname() + self.assertEqual(host, '127.0.0.1') + return server, host, port + + def _make_ssl_unix_server(self, factory, certfile, keyfile=None): + sslcontext = self._create_ssl_context(certfile, keyfile) + return self._make_unix_server(factory, ssl=sslcontext) + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_create_server_ssl(self): + proto = MyProto(loop=self.loop) + server, host, port = self._make_ssl_server( + lambda: proto, ONLYCERT, ONLYKEY) + + f_c = self.loop.create_connection(MyBaseProto, host, port, + ssl=test_utils.dummy_ssl_context()) + client, pr = self.loop.run_until_complete(f_c) + + client.write(b'xxx') + test_utils.run_briefly(self.loop) + self.assertIsInstance(proto, MyProto) + test_utils.run_briefly(self.loop) + self.assertEqual('CONNECTED', proto.state) + test_utils.run_until(self.loop, lambda: proto.nbytes > 0, + timeout=10) + self.assertEqual(3, proto.nbytes) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('sockname')) + self.assertEqual('127.0.0.1', + proto.transport.get_extra_info('peername')[0]) + + # close connection + proto.transport.close() + self.loop.run_until_complete(proto.done) + self.assertEqual('CLOSED', proto.state) + + # the client socket must be closed after to avoid ECONNRESET upon + # recv()/send() on the serving socket + client.close() + + # stop serving + server.close() + + @unittest.skipIf(ssl is None, 'No ssl module') + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_create_unix_server_ssl(self): + proto = MyProto(loop=self.loop) + server, path = self._make_ssl_unix_server( + lambda: proto, ONLYCERT, ONLYKEY) + + f_c = self.loop.create_unix_connection( + MyBaseProto, path, ssl=test_utils.dummy_ssl_context(), + server_hostname='') + + client, pr = self.loop.run_until_complete(f_c) + + client.write(b'xxx') + test_utils.run_briefly(self.loop) + self.assertIsInstance(proto, MyProto) + test_utils.run_briefly(self.loop) + self.assertEqual('CONNECTED', proto.state) + test_utils.run_until(self.loop, lambda: proto.nbytes > 0, + timeout=10) + self.assertEqual(3, proto.nbytes) + + # close connection + proto.transport.close() + self.loop.run_until_complete(proto.done) + self.assertEqual('CLOSED', proto.state) + + # the client socket must be closed after to avoid ECONNRESET upon + # recv()/send() on the serving socket + client.close() + + # stop serving + server.close() + + @unittest.skipIf(ssl is None, 'No ssl module') + @unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module') + def test_create_server_ssl_verify_failed(self): + proto = MyProto(loop=self.loop) + server, host, port = self._make_ssl_server( + lambda: proto, SIGNED_CERTFILE) + + sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext_client.options |= ssl.OP_NO_SSLv2 + sslcontext_client.verify_mode = ssl.CERT_REQUIRED + if hasattr(sslcontext_client, 'check_hostname'): + sslcontext_client.check_hostname = True + + # no CA loaded + f_c = self.loop.create_connection(MyProto, host, port, + ssl=sslcontext_client) + with self.assertRaisesRegex(ssl.SSLError, + 'certificate verify failed '): + self.loop.run_until_complete(f_c) + + # close connection + self.assertIsNone(proto.transport) + server.close() + + @unittest.skipIf(ssl is None, 'No ssl module') + @unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module') + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_create_unix_server_ssl_verify_failed(self): + proto = MyProto(loop=self.loop) + server, path = self._make_ssl_unix_server( + lambda: proto, SIGNED_CERTFILE) + + sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext_client.options |= ssl.OP_NO_SSLv2 + sslcontext_client.verify_mode = ssl.CERT_REQUIRED + if hasattr(sslcontext_client, 'check_hostname'): + sslcontext_client.check_hostname = True + + # no CA loaded + f_c = self.loop.create_unix_connection(MyProto, path, + ssl=sslcontext_client, + server_hostname='invalid') + with self.assertRaisesRegex(ssl.SSLError, + 'certificate verify failed '): + self.loop.run_until_complete(f_c) + + # close connection + self.assertIsNone(proto.transport) + server.close() + + @unittest.skipIf(ssl is None, 'No ssl module') + @unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module') + def test_create_server_ssl_match_failed(self): + proto = MyProto(loop=self.loop) + server, host, port = self._make_ssl_server( + lambda: proto, SIGNED_CERTFILE) + + sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext_client.options |= ssl.OP_NO_SSLv2 + sslcontext_client.verify_mode = ssl.CERT_REQUIRED + sslcontext_client.load_verify_locations( + cafile=SIGNING_CA) + if hasattr(sslcontext_client, 'check_hostname'): + sslcontext_client.check_hostname = True + + # incorrect server_hostname + f_c = self.loop.create_connection(MyProto, host, port, + ssl=sslcontext_client) + with self.assertRaisesRegex( + ssl.CertificateError, + "hostname '127.0.0.1' doesn't match 'localhost'"): + self.loop.run_until_complete(f_c) + + # close connection + proto.transport.close() + server.close() + + @unittest.skipIf(ssl is None, 'No ssl module') + @unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module') + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_create_unix_server_ssl_verified(self): + proto = MyProto(loop=self.loop) + server, path = self._make_ssl_unix_server( + lambda: proto, SIGNED_CERTFILE) + + sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext_client.options |= ssl.OP_NO_SSLv2 + sslcontext_client.verify_mode = ssl.CERT_REQUIRED + sslcontext_client.load_verify_locations(cafile=SIGNING_CA) + if hasattr(sslcontext_client, 'check_hostname'): + sslcontext_client.check_hostname = True + + # Connection succeeds with correct CA and server hostname. + f_c = self.loop.create_unix_connection(MyProto, path, + ssl=sslcontext_client, + server_hostname='localhost') + client, pr = self.loop.run_until_complete(f_c) + + # close connection + proto.transport.close() + client.close() + server.close() + + @unittest.skipIf(ssl is None, 'No ssl module') + @unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module') + def test_create_server_ssl_verified(self): + proto = MyProto(loop=self.loop) + server, host, port = self._make_ssl_server( + lambda: proto, SIGNED_CERTFILE) + + sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext_client.options |= ssl.OP_NO_SSLv2 + sslcontext_client.verify_mode = ssl.CERT_REQUIRED + sslcontext_client.load_verify_locations(cafile=SIGNING_CA) + if hasattr(sslcontext_client, 'check_hostname'): + sslcontext_client.check_hostname = True + + # Connection succeeds with correct CA and server hostname. + f_c = self.loop.create_connection(MyProto, host, port, + ssl=sslcontext_client, + server_hostname='localhost') + client, pr = self.loop.run_until_complete(f_c) + + # close connection + proto.transport.close() + client.close() + server.close() + + def test_create_server_sock(self): + proto = asyncio.Future(loop=self.loop) + + class TestMyProto(MyProto): + def connection_made(self, transport): + super().connection_made(transport) + proto.set_result(self) + + sock_ob = socket.socket(type=socket.SOCK_STREAM) + sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock_ob.bind(('0.0.0.0', 0)) + + f = self.loop.create_server(TestMyProto, sock=sock_ob) + server = self.loop.run_until_complete(f) + sock = server.sockets[0] + self.assertIs(sock, sock_ob) + + host, port = sock.getsockname() + self.assertEqual(host, '0.0.0.0') + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + client.close() + server.close() + + def test_create_server_addr_in_use(self): + sock_ob = socket.socket(type=socket.SOCK_STREAM) + sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock_ob.bind(('0.0.0.0', 0)) + + f = self.loop.create_server(MyProto, sock=sock_ob) + server = self.loop.run_until_complete(f) + sock = server.sockets[0] + host, port = sock.getsockname() + + f = self.loop.create_server(MyProto, host=host, port=port) + with self.assertRaises(OSError) as cm: + self.loop.run_until_complete(f) + self.assertEqual(cm.exception.errno, errno.EADDRINUSE) + + server.close() + + @unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 not supported or enabled') + def test_create_server_dual_stack(self): + f_proto = asyncio.Future(loop=self.loop) + + class TestMyProto(MyProto): + def connection_made(self, transport): + super().connection_made(transport) + f_proto.set_result(self) + + try_count = 0 + while True: + try: + port = support.find_unused_port() + f = self.loop.create_server(TestMyProto, host=None, port=port) + server = self.loop.run_until_complete(f) + except OSError as ex: + if ex.errno == errno.EADDRINUSE: + try_count += 1 + self.assertGreaterEqual(5, try_count) + continue + else: + raise + else: + break + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + proto = self.loop.run_until_complete(f_proto) + proto.transport.close() + client.close() + + f_proto = asyncio.Future(loop=self.loop) + client = socket.socket(socket.AF_INET6) + client.connect(('::1', port)) + client.send(b'xxx') + proto = self.loop.run_until_complete(f_proto) + proto.transport.close() + client.close() + + server.close() + + def test_server_close(self): + f = self.loop.create_server(MyProto, '0.0.0.0', 0) + server = self.loop.run_until_complete(f) + sock = server.sockets[0] + host, port = sock.getsockname() + + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + client.close() + + server.close() + + client = socket.socket() + self.assertRaises( + ConnectionRefusedError, client.connect, ('127.0.0.1', port)) + client.close() + + def test_create_datagram_endpoint(self): + class TestMyDatagramProto(MyDatagramProto): + def __init__(inner_self): + super().__init__(loop=self.loop) + + def datagram_received(self, data, addr): + super().datagram_received(data, addr) + self.transport.sendto(b'resp:'+data, addr) + + coro = self.loop.create_datagram_endpoint( + TestMyDatagramProto, local_addr=('127.0.0.1', 0)) + s_transport, server = self.loop.run_until_complete(coro) + host, port = s_transport.get_extra_info('sockname') + + coro = self.loop.create_datagram_endpoint( + lambda: MyDatagramProto(loop=self.loop), + remote_addr=(host, port)) + transport, client = self.loop.run_until_complete(coro) + + self.assertEqual('INITIALIZED', client.state) + transport.sendto(b'xxx') + for _ in range(1000): + if server.nbytes: + break + test_utils.run_briefly(self.loop) + self.assertEqual(3, server.nbytes) + for _ in range(1000): + if client.nbytes: + break + test_utils.run_briefly(self.loop) + + # received + self.assertEqual(8, client.nbytes) + + # extra info is available + self.assertIsNotNone(transport.get_extra_info('sockname')) + + # close connection + transport.close() + self.loop.run_until_complete(client.done) + self.assertEqual('CLOSED', client.state) + server.transport.close() + + def test_internal_fds(self): + loop = self.create_event_loop() + if not isinstance(loop, selector_events.BaseSelectorEventLoop): + self.skipTest('loop is not a BaseSelectorEventLoop') + + self.assertEqual(1, loop._internal_fds) + loop.close() + self.assertEqual(0, loop._internal_fds) + self.assertIsNone(loop._csock) + self.assertIsNone(loop._ssock) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_read_pipe(self): + proto = MyReadPipeProto(loop=self.loop) + + rpipe, wpipe = os.pipe() + pipeobj = io.open(rpipe, 'rb', 1024) + + @asyncio.coroutine + def connect(): + t, p = yield from self.loop.connect_read_pipe( + lambda: proto, pipeobj) + self.assertIs(p, proto) + self.assertIs(t, proto.transport) + self.assertEqual(['INITIAL', 'CONNECTED'], proto.state) + self.assertEqual(0, proto.nbytes) + + self.loop.run_until_complete(connect()) + + os.write(wpipe, b'1') + test_utils.run_briefly(self.loop) + self.assertEqual(1, proto.nbytes) + + os.write(wpipe, b'2345') + test_utils.run_briefly(self.loop) + self.assertEqual(['INITIAL', 'CONNECTED'], proto.state) + self.assertEqual(5, proto.nbytes) + + os.close(wpipe) + self.loop.run_until_complete(proto.done) + self.assertEqual( + ['INITIAL', 'CONNECTED', 'EOF', 'CLOSED'], proto.state) + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('pipe')) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + # select, poll and kqueue don't support character devices (PTY) on Mac OS X + # older than 10.6 (Snow Leopard) + @support.requires_mac_ver(10, 6) + # Issue #20495: The test hangs on FreeBSD 7.2 but pass on FreeBSD 9 + @support.requires_freebsd_version(8) + def test_read_pty_output(self): + proto = MyReadPipeProto(loop=self.loop) + + master, slave = os.openpty() + master_read_obj = io.open(master, 'rb', 0) + + @asyncio.coroutine + def connect(): + t, p = yield from self.loop.connect_read_pipe(lambda: proto, + master_read_obj) + self.assertIs(p, proto) + self.assertIs(t, proto.transport) + self.assertEqual(['INITIAL', 'CONNECTED'], proto.state) + self.assertEqual(0, proto.nbytes) + + self.loop.run_until_complete(connect()) + + os.write(slave, b'1') + test_utils.run_until(self.loop, lambda: proto.nbytes) + self.assertEqual(1, proto.nbytes) + + os.write(slave, b'2345') + test_utils.run_until(self.loop, lambda: proto.nbytes >= 5) + self.assertEqual(['INITIAL', 'CONNECTED'], proto.state) + self.assertEqual(5, proto.nbytes) + + os.close(slave) + self.loop.run_until_complete(proto.done) + self.assertEqual( + ['INITIAL', 'CONNECTED', 'EOF', 'CLOSED'], proto.state) + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('pipe')) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_write_pipe(self): + rpipe, wpipe = os.pipe() + pipeobj = io.open(wpipe, 'wb', 1024) + + proto = MyWritePipeProto(loop=self.loop) + connect = self.loop.connect_write_pipe(lambda: proto, pipeobj) + transport, p = self.loop.run_until_complete(connect) + self.assertIs(p, proto) + self.assertIs(transport, proto.transport) + self.assertEqual('CONNECTED', proto.state) + + transport.write(b'1') + test_utils.run_briefly(self.loop) + data = os.read(rpipe, 1024) + self.assertEqual(b'1', data) + + transport.write(b'2345') + test_utils.run_briefly(self.loop) + data = os.read(rpipe, 1024) + self.assertEqual(b'2345', data) + self.assertEqual('CONNECTED', proto.state) + + os.close(rpipe) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('pipe')) + + # close connection + proto.transport.close() + self.loop.run_until_complete(proto.done) + self.assertEqual('CLOSED', proto.state) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_write_pipe_disconnect_on_close(self): + rsock, wsock = test_utils.socketpair() + pipeobj = io.open(wsock.detach(), 'wb', 1024) + + proto = MyWritePipeProto(loop=self.loop) + connect = self.loop.connect_write_pipe(lambda: proto, pipeobj) + transport, p = self.loop.run_until_complete(connect) + self.assertIs(p, proto) + self.assertIs(transport, proto.transport) + self.assertEqual('CONNECTED', proto.state) + + transport.write(b'1') + data = self.loop.run_until_complete(self.loop.sock_recv(rsock, 1024)) + self.assertEqual(b'1', data) + + rsock.close() + + self.loop.run_until_complete(proto.done) + self.assertEqual('CLOSED', proto.state) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + # select, poll and kqueue don't support character devices (PTY) on Mac OS X + # older than 10.6 (Snow Leopard) + @support.requires_mac_ver(10, 6) + def test_write_pty(self): + master, slave = os.openpty() + slave_write_obj = io.open(slave, 'wb', 0) + + proto = MyWritePipeProto(loop=self.loop) + connect = self.loop.connect_write_pipe(lambda: proto, slave_write_obj) + transport, p = self.loop.run_until_complete(connect) + self.assertIs(p, proto) + self.assertIs(transport, proto.transport) + self.assertEqual('CONNECTED', proto.state) + + transport.write(b'1') + test_utils.run_briefly(self.loop) + data = os.read(master, 1024) + self.assertEqual(b'1', data) + + transport.write(b'2345') + test_utils.run_briefly(self.loop) + data = os.read(master, 1024) + self.assertEqual(b'2345', data) + self.assertEqual('CONNECTED', proto.state) + + os.close(master) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('pipe')) + + # close connection + proto.transport.close() + self.loop.run_until_complete(proto.done) + self.assertEqual('CLOSED', proto.state) + + def test_prompt_cancellation(self): + r, w = test_utils.socketpair() + r.setblocking(False) + f = self.loop.sock_recv(r, 1) + ov = getattr(f, 'ov', None) + if ov is not None: + self.assertTrue(ov.pending) + + @asyncio.coroutine + def main(): + try: + self.loop.call_soon(f.cancel) + yield from f + except asyncio.CancelledError: + res = 'cancelled' + else: + res = None + finally: + self.loop.stop() + return res + + start = time.monotonic() + t = asyncio.Task(main(), loop=self.loop) + self.loop.run_forever() + elapsed = time.monotonic() - start + + self.assertLess(elapsed, 0.1) + self.assertEqual(t.result(), 'cancelled') + self.assertRaises(asyncio.CancelledError, f.result) + if ov is not None: + self.assertFalse(ov.pending) + self.loop._stop_serving(r) + + r.close() + w.close() + + def test_timeout_rounding(self): + def _run_once(): + self.loop._run_once_counter += 1 + orig_run_once() + + orig_run_once = self.loop._run_once + self.loop._run_once_counter = 0 + self.loop._run_once = _run_once + + @asyncio.coroutine + def wait(): + loop = self.loop + yield from asyncio.sleep(1e-2, loop=loop) + yield from asyncio.sleep(1e-4, loop=loop) + yield from asyncio.sleep(1e-6, loop=loop) + yield from asyncio.sleep(1e-8, loop=loop) + yield from asyncio.sleep(1e-10, loop=loop) + + self.loop.run_until_complete(wait()) + # The ideal number of call is 12, but on some platforms, the selector + # may sleep at little bit less than timeout depending on the resolution + # of the clock used by the kernel. Tolerate a few useless calls on + # these platforms. + self.assertLessEqual(self.loop._run_once_counter, 20, + {'clock_resolution': self.loop._clock_resolution, + 'selector': self.loop._selector.__class__.__name__}) + + def test_sock_connect_address(self): + addresses = [(socket.AF_INET, ('www.python.org', 80))] + if support.IPV6_ENABLED: + addresses.extend(( + (socket.AF_INET6, ('www.python.org', 80)), + (socket.AF_INET6, ('www.python.org', 80, 0, 0)), + )) + + for family, address in addresses: + for sock_type in (socket.SOCK_STREAM, socket.SOCK_DGRAM): + sock = socket.socket(family, sock_type) + with sock: + connect = self.loop.sock_connect(sock, address) + with self.assertRaises(ValueError) as cm: + self.loop.run_until_complete(connect) + self.assertIn('address must be resolved', + str(cm.exception)) + + +class SubprocessTestsMixin: + + def check_terminated(self, returncode): + if sys.platform == 'win32': + self.assertIsInstance(returncode, int) + # expect 1 but sometimes get 0 + else: + self.assertEqual(-signal.SIGTERM, returncode) + + def check_killed(self, returncode): + if sys.platform == 'win32': + self.assertIsInstance(returncode, int) + # expect 1 but sometimes get 0 + else: + self.assertEqual(-signal.SIGKILL, returncode) + + def test_subprocess_exec(self): + prog = os.path.join(os.path.dirname(__file__), 'echo.py') + + connect = self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + transp, proto = self.loop.run_until_complete(connect) + self.assertIsInstance(proto, MySubprocessProtocol) + self.loop.run_until_complete(proto.connected) + self.assertEqual('CONNECTED', proto.state) + + stdin = transp.get_pipe_transport(0) + stdin.write(b'Python The Winner') + self.loop.run_until_complete(proto.got_data[1].wait()) + transp.close() + self.loop.run_until_complete(proto.completed) + self.check_terminated(proto.returncode) + self.assertEqual(b'Python The Winner', proto.data[1]) + + def test_subprocess_interactive(self): + prog = os.path.join(os.path.dirname(__file__), 'echo.py') + + connect = self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + transp, proto = self.loop.run_until_complete(connect) + self.assertIsInstance(proto, MySubprocessProtocol) + self.loop.run_until_complete(proto.connected) + self.assertEqual('CONNECTED', proto.state) + + try: + stdin = transp.get_pipe_transport(0) + stdin.write(b'Python ') + self.loop.run_until_complete(proto.got_data[1].wait()) + proto.got_data[1].clear() + self.assertEqual(b'Python ', proto.data[1]) + + stdin.write(b'The Winner') + self.loop.run_until_complete(proto.got_data[1].wait()) + self.assertEqual(b'Python The Winner', proto.data[1]) + finally: + transp.close() + + self.loop.run_until_complete(proto.completed) + self.check_terminated(proto.returncode) + + def test_subprocess_shell(self): + connect = self.loop.subprocess_shell( + functools.partial(MySubprocessProtocol, self.loop), + 'echo Python') + transp, proto = self.loop.run_until_complete(connect) + self.assertIsInstance(proto, MySubprocessProtocol) + self.loop.run_until_complete(proto.connected) + + transp.get_pipe_transport(0).close() + self.loop.run_until_complete(proto.completed) + self.assertEqual(0, proto.returncode) + self.assertTrue(all(f.done() for f in proto.disconnects.values())) + self.assertEqual(proto.data[1].rstrip(b'\r\n'), b'Python') + self.assertEqual(proto.data[2], b'') + + def test_subprocess_exitcode(self): + connect = self.loop.subprocess_shell( + functools.partial(MySubprocessProtocol, self.loop), + 'exit 7', stdin=None, stdout=None, stderr=None) + transp, proto = self.loop.run_until_complete(connect) + self.assertIsInstance(proto, MySubprocessProtocol) + self.loop.run_until_complete(proto.completed) + self.assertEqual(7, proto.returncode) + + def test_subprocess_close_after_finish(self): + connect = self.loop.subprocess_shell( + functools.partial(MySubprocessProtocol, self.loop), + 'exit 7', stdin=None, stdout=None, stderr=None) + transp, proto = self.loop.run_until_complete(connect) + self.assertIsInstance(proto, MySubprocessProtocol) + self.assertIsNone(transp.get_pipe_transport(0)) + self.assertIsNone(transp.get_pipe_transport(1)) + self.assertIsNone(transp.get_pipe_transport(2)) + self.loop.run_until_complete(proto.completed) + self.assertEqual(7, proto.returncode) + self.assertIsNone(transp.close()) + + def test_subprocess_kill(self): + prog = os.path.join(os.path.dirname(__file__), 'echo.py') + + connect = self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + transp, proto = self.loop.run_until_complete(connect) + self.assertIsInstance(proto, MySubprocessProtocol) + self.loop.run_until_complete(proto.connected) + + transp.kill() + self.loop.run_until_complete(proto.completed) + self.check_killed(proto.returncode) + + def test_subprocess_terminate(self): + prog = os.path.join(os.path.dirname(__file__), 'echo.py') + + connect = self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + transp, proto = self.loop.run_until_complete(connect) + self.assertIsInstance(proto, MySubprocessProtocol) + self.loop.run_until_complete(proto.connected) + + transp.terminate() + self.loop.run_until_complete(proto.completed) + self.check_terminated(proto.returncode) + + @unittest.skipIf(sys.platform == 'win32', "Don't have SIGHUP") + def test_subprocess_send_signal(self): + prog = os.path.join(os.path.dirname(__file__), 'echo.py') + + connect = self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + transp, proto = self.loop.run_until_complete(connect) + self.assertIsInstance(proto, MySubprocessProtocol) + self.loop.run_until_complete(proto.connected) + + transp.send_signal(signal.SIGHUP) + self.loop.run_until_complete(proto.completed) + self.assertEqual(-signal.SIGHUP, proto.returncode) + + def test_subprocess_stderr(self): + prog = os.path.join(os.path.dirname(__file__), 'echo2.py') + + connect = self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + transp, proto = self.loop.run_until_complete(connect) + self.assertIsInstance(proto, MySubprocessProtocol) + self.loop.run_until_complete(proto.connected) + + stdin = transp.get_pipe_transport(0) + stdin.write(b'test') + + self.loop.run_until_complete(proto.completed) + + transp.close() + self.assertEqual(b'OUT:test', proto.data[1]) + self.assertTrue(proto.data[2].startswith(b'ERR:test'), proto.data[2]) + self.assertEqual(0, proto.returncode) + + def test_subprocess_stderr_redirect_to_stdout(self): + prog = os.path.join(os.path.dirname(__file__), 'echo2.py') + + connect = self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog, stderr=subprocess.STDOUT) + transp, proto = self.loop.run_until_complete(connect) + self.assertIsInstance(proto, MySubprocessProtocol) + self.loop.run_until_complete(proto.connected) + + stdin = transp.get_pipe_transport(0) + self.assertIsNotNone(transp.get_pipe_transport(1)) + self.assertIsNone(transp.get_pipe_transport(2)) + + stdin.write(b'test') + self.loop.run_until_complete(proto.completed) + self.assertTrue(proto.data[1].startswith(b'OUT:testERR:test'), + proto.data[1]) + self.assertEqual(b'', proto.data[2]) + + transp.close() + self.assertEqual(0, proto.returncode) + + def test_subprocess_close_client_stream(self): + prog = os.path.join(os.path.dirname(__file__), 'echo3.py') + + connect = self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + transp, proto = self.loop.run_until_complete(connect) + self.assertIsInstance(proto, MySubprocessProtocol) + self.loop.run_until_complete(proto.connected) + + stdin = transp.get_pipe_transport(0) + stdout = transp.get_pipe_transport(1) + stdin.write(b'test') + self.loop.run_until_complete(proto.got_data[1].wait()) + self.assertEqual(b'OUT:test', proto.data[1]) + + stdout.close() + self.loop.run_until_complete(proto.disconnects[1]) + stdin.write(b'xxx') + self.loop.run_until_complete(proto.got_data[2].wait()) + if sys.platform != 'win32': + self.assertEqual(b'ERR:BrokenPipeError', proto.data[2]) + else: + # After closing the read-end of a pipe, writing to the + # write-end using os.write() fails with errno==EINVAL and + # GetLastError()==ERROR_INVALID_NAME on Windows!?! (Using + # WriteFile() we get ERROR_BROKEN_PIPE as expected.) + self.assertEqual(b'ERR:OSError', proto.data[2]) + transp.close() + self.loop.run_until_complete(proto.completed) + self.check_terminated(proto.returncode) + + def test_subprocess_wait_no_same_group(self): + # start the new process in a new session + connect = self.loop.subprocess_shell( + functools.partial(MySubprocessProtocol, self.loop), + 'exit 7', stdin=None, stdout=None, stderr=None, + start_new_session=True) + _, proto = yield self.loop.run_until_complete(connect) + self.assertIsInstance(proto, MySubprocessProtocol) + self.loop.run_until_complete(proto.completed) + self.assertEqual(7, proto.returncode) + + def test_subprocess_exec_invalid_args(self): + @asyncio.coroutine + def connect(**kwds): + yield from self.loop.subprocess_exec( + asyncio.SubprocessProtocol, + 'pwd', **kwds) + + with self.assertRaises(ValueError): + self.loop.run_until_complete(connect(universal_newlines=True)) + with self.assertRaises(ValueError): + self.loop.run_until_complete(connect(bufsize=4096)) + with self.assertRaises(ValueError): + self.loop.run_until_complete(connect(shell=True)) + + def test_subprocess_shell_invalid_args(self): + @asyncio.coroutine + def connect(cmd=None, **kwds): + if not cmd: + cmd = 'pwd' + yield from self.loop.subprocess_shell( + asyncio.SubprocessProtocol, + cmd, **kwds) + + with self.assertRaises(ValueError): + self.loop.run_until_complete(connect(['ls', '-l'])) + with self.assertRaises(ValueError): + self.loop.run_until_complete(connect(universal_newlines=True)) + with self.assertRaises(ValueError): + self.loop.run_until_complete(connect(bufsize=4096)) + with self.assertRaises(ValueError): + self.loop.run_until_complete(connect(shell=False)) + + +if sys.platform == 'win32': + + class SelectEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + def create_event_loop(self): + return asyncio.SelectorEventLoop() + + class ProactorEventLoopTests(EventLoopTestsMixin, + SubprocessTestsMixin, + unittest.TestCase): + + def create_event_loop(self): + return asyncio.ProactorEventLoop() + + def test_create_ssl_connection(self): + raise unittest.SkipTest("IocpEventLoop incompatible with SSL") + + def test_create_server_ssl(self): + raise unittest.SkipTest("IocpEventLoop incompatible with SSL") + + def test_create_server_ssl_verify_failed(self): + raise unittest.SkipTest("IocpEventLoop incompatible with SSL") + + def test_create_server_ssl_match_failed(self): + raise unittest.SkipTest("IocpEventLoop incompatible with SSL") + + def test_create_server_ssl_verified(self): + raise unittest.SkipTest("IocpEventLoop incompatible with SSL") + + def test_reader_callback(self): + raise unittest.SkipTest("IocpEventLoop does not have add_reader()") + + def test_reader_callback_cancel(self): + raise unittest.SkipTest("IocpEventLoop does not have add_reader()") + + def test_writer_callback(self): + raise unittest.SkipTest("IocpEventLoop does not have add_writer()") + + def test_writer_callback_cancel(self): + raise unittest.SkipTest("IocpEventLoop does not have add_writer()") + + def test_create_datagram_endpoint(self): + raise unittest.SkipTest( + "IocpEventLoop does not have create_datagram_endpoint()") +else: + from asyncio import selectors + + class UnixEventLoopTestsMixin(EventLoopTestsMixin): + def setUp(self): + super().setUp() + watcher = asyncio.SafeChildWatcher() + watcher.attach_loop(self.loop) + asyncio.set_child_watcher(watcher) + + def tearDown(self): + asyncio.set_child_watcher(None) + super().tearDown() + + if hasattr(selectors, 'KqueueSelector'): + class KqueueEventLoopTests(UnixEventLoopTestsMixin, + SubprocessTestsMixin, + unittest.TestCase): + + def create_event_loop(self): + return asyncio.SelectorEventLoop( + selectors.KqueueSelector()) + + # kqueue doesn't support character devices (PTY) on Mac OS X older + # than 10.9 (Maverick) + @support.requires_mac_ver(10, 9) + # Issue #20667: KqueueEventLoopTests.test_read_pty_output() + # hangs on OpenBSD 5.5 + @unittest.skipIf(sys.platform.startswith('openbsd'), + 'test hangs on OpenBSD') + def test_read_pty_output(self): + super().test_read_pty_output() + + # kqueue doesn't support character devices (PTY) on Mac OS X older + # than 10.9 (Maverick) + @support.requires_mac_ver(10, 9) + def test_write_pty(self): + super().test_write_pty() + + if hasattr(selectors, 'EpollSelector'): + class EPollEventLoopTests(UnixEventLoopTestsMixin, + SubprocessTestsMixin, + unittest.TestCase): + + def create_event_loop(self): + return asyncio.SelectorEventLoop(selectors.EpollSelector()) + + if hasattr(selectors, 'PollSelector'): + class PollEventLoopTests(UnixEventLoopTestsMixin, + SubprocessTestsMixin, + unittest.TestCase): + + def create_event_loop(self): + return asyncio.SelectorEventLoop(selectors.PollSelector()) + + # Should always exist. + class SelectEventLoopTests(UnixEventLoopTestsMixin, + SubprocessTestsMixin, + unittest.TestCase): + + def create_event_loop(self): + return asyncio.SelectorEventLoop(selectors.SelectSelector()) + + +class HandleTests(unittest.TestCase): + + def test_handle(self): + def callback(*args): + return args + + args = () + h = asyncio.Handle(callback, args, mock.Mock()) + self.assertIs(h._callback, callback) + self.assertIs(h._args, args) + self.assertFalse(h._cancelled) + + r = repr(h) + self.assertTrue(r.startswith( + 'Handle(' + '<function HandleTests.test_handle.<locals>.callback')) + self.assertTrue(r.endswith('())')) + + h.cancel() + self.assertTrue(h._cancelled) + + r = repr(h) + self.assertTrue(r.startswith( + 'Handle(' + '<function HandleTests.test_handle.<locals>.callback')) + self.assertTrue(r.endswith('())<cancelled>'), r) + + def test_handle_from_handle(self): + def callback(*args): + return args + m_loop = object() + h1 = asyncio.Handle(callback, (), loop=m_loop) + self.assertRaises( + AssertionError, asyncio.Handle, h1, (), m_loop) + + def test_callback_with_exception(self): + def callback(): + raise ValueError() + + m_loop = mock.Mock() + m_loop.call_exception_handler = mock.Mock() + + h = asyncio.Handle(callback, (), m_loop) + h._run() + + m_loop.call_exception_handler.assert_called_with({ + 'message': test_utils.MockPattern('Exception in callback.*'), + 'exception': mock.ANY, + 'handle': h + }) + + +class TimerTests(unittest.TestCase): + + def test_hash(self): + when = time.monotonic() + h = asyncio.TimerHandle(when, lambda: False, (), + mock.Mock()) + self.assertEqual(hash(h), hash(when)) + + def test_timer(self): + def callback(*args): + return args + + args = () + when = time.monotonic() + h = asyncio.TimerHandle(when, callback, args, mock.Mock()) + self.assertIs(h._callback, callback) + self.assertIs(h._args, args) + self.assertFalse(h._cancelled) + + r = repr(h) + self.assertTrue(r.endswith('())')) + + h.cancel() + self.assertTrue(h._cancelled) + + r = repr(h) + self.assertTrue(r.endswith('())<cancelled>'), r) + + self.assertRaises(AssertionError, + asyncio.TimerHandle, None, callback, args, + mock.Mock()) + + def test_timer_comparison(self): + loop = mock.Mock() + + def callback(*args): + return args + + when = time.monotonic() + + h1 = asyncio.TimerHandle(when, callback, (), loop) + h2 = asyncio.TimerHandle(when, callback, (), loop) + # TODO: Use assertLess etc. + self.assertFalse(h1 < h2) + self.assertFalse(h2 < h1) + self.assertTrue(h1 <= h2) + self.assertTrue(h2 <= h1) + self.assertFalse(h1 > h2) + self.assertFalse(h2 > h1) + self.assertTrue(h1 >= h2) + self.assertTrue(h2 >= h1) + self.assertTrue(h1 == h2) + self.assertFalse(h1 != h2) + + h2.cancel() + self.assertFalse(h1 == h2) + + h1 = asyncio.TimerHandle(when, callback, (), loop) + h2 = asyncio.TimerHandle(when + 10.0, callback, (), loop) + self.assertTrue(h1 < h2) + self.assertFalse(h2 < h1) + self.assertTrue(h1 <= h2) + self.assertFalse(h2 <= h1) + self.assertFalse(h1 > h2) + self.assertTrue(h2 > h1) + self.assertFalse(h1 >= h2) + self.assertTrue(h2 >= h1) + self.assertFalse(h1 == h2) + self.assertTrue(h1 != h2) + + h3 = asyncio.Handle(callback, (), loop) + self.assertIs(NotImplemented, h1.__eq__(h3)) + self.assertIs(NotImplemented, h1.__ne__(h3)) + + +class AbstractEventLoopTests(unittest.TestCase): + + def test_not_implemented(self): + f = mock.Mock() + loop = asyncio.AbstractEventLoop() + self.assertRaises( + NotImplementedError, loop.run_forever) + self.assertRaises( + NotImplementedError, loop.run_until_complete, None) + self.assertRaises( + NotImplementedError, loop.stop) + self.assertRaises( + NotImplementedError, loop.is_running) + self.assertRaises( + NotImplementedError, loop.close) + self.assertRaises( + NotImplementedError, loop.call_later, None, None) + self.assertRaises( + NotImplementedError, loop.call_at, f, f) + self.assertRaises( + NotImplementedError, loop.call_soon, None) + self.assertRaises( + NotImplementedError, loop.time) + self.assertRaises( + NotImplementedError, loop.call_soon_threadsafe, None) + self.assertRaises( + NotImplementedError, loop.run_in_executor, f, f) + self.assertRaises( + NotImplementedError, loop.set_default_executor, f) + self.assertRaises( + NotImplementedError, loop.getaddrinfo, 'localhost', 8080) + self.assertRaises( + NotImplementedError, loop.getnameinfo, ('localhost', 8080)) + self.assertRaises( + NotImplementedError, loop.create_connection, f) + self.assertRaises( + NotImplementedError, loop.create_server, f) + self.assertRaises( + NotImplementedError, loop.create_datagram_endpoint, f) + self.assertRaises( + NotImplementedError, loop.add_reader, 1, f) + self.assertRaises( + NotImplementedError, loop.remove_reader, 1) + self.assertRaises( + NotImplementedError, loop.add_writer, 1, f) + self.assertRaises( + NotImplementedError, loop.remove_writer, 1) + self.assertRaises( + NotImplementedError, loop.sock_recv, f, 10) + self.assertRaises( + NotImplementedError, loop.sock_sendall, f, 10) + self.assertRaises( + NotImplementedError, loop.sock_connect, f, f) + self.assertRaises( + NotImplementedError, loop.sock_accept, f) + self.assertRaises( + NotImplementedError, loop.add_signal_handler, 1, f) + self.assertRaises( + NotImplementedError, loop.remove_signal_handler, 1) + self.assertRaises( + NotImplementedError, loop.remove_signal_handler, 1) + self.assertRaises( + NotImplementedError, loop.connect_read_pipe, f, + mock.sentinel.pipe) + self.assertRaises( + NotImplementedError, loop.connect_write_pipe, f, + mock.sentinel.pipe) + self.assertRaises( + NotImplementedError, loop.subprocess_shell, f, + mock.sentinel) + self.assertRaises( + NotImplementedError, loop.subprocess_exec, f) + + +class ProtocolsAbsTests(unittest.TestCase): + + def test_empty(self): + f = mock.Mock() + p = asyncio.Protocol() + self.assertIsNone(p.connection_made(f)) + self.assertIsNone(p.connection_lost(f)) + self.assertIsNone(p.data_received(f)) + self.assertIsNone(p.eof_received()) + + dp = asyncio.DatagramProtocol() + self.assertIsNone(dp.connection_made(f)) + self.assertIsNone(dp.connection_lost(f)) + self.assertIsNone(dp.error_received(f)) + self.assertIsNone(dp.datagram_received(f, f)) + + sp = asyncio.SubprocessProtocol() + self.assertIsNone(sp.connection_made(f)) + self.assertIsNone(sp.connection_lost(f)) + self.assertIsNone(sp.pipe_data_received(1, f)) + self.assertIsNone(sp.pipe_connection_lost(1, f)) + self.assertIsNone(sp.process_exited()) + + +class PolicyTests(unittest.TestCase): + + def test_event_loop_policy(self): + policy = asyncio.AbstractEventLoopPolicy() + self.assertRaises(NotImplementedError, policy.get_event_loop) + self.assertRaises(NotImplementedError, policy.set_event_loop, object()) + self.assertRaises(NotImplementedError, policy.new_event_loop) + self.assertRaises(NotImplementedError, policy.get_child_watcher) + self.assertRaises(NotImplementedError, policy.set_child_watcher, + object()) + + def test_get_event_loop(self): + policy = asyncio.DefaultEventLoopPolicy() + self.assertIsNone(policy._local._loop) + + loop = policy.get_event_loop() + self.assertIsInstance(loop, asyncio.AbstractEventLoop) + + self.assertIs(policy._local._loop, loop) + self.assertIs(loop, policy.get_event_loop()) + loop.close() + + def test_get_event_loop_calls_set_event_loop(self): + policy = asyncio.DefaultEventLoopPolicy() + + with mock.patch.object( + policy, "set_event_loop", + wraps=policy.set_event_loop) as m_set_event_loop: + + loop = policy.get_event_loop() + + # policy._local._loop must be set through .set_event_loop() + # (the unix DefaultEventLoopPolicy needs this call to attach + # the child watcher correctly) + m_set_event_loop.assert_called_with(loop) + + loop.close() + + def test_get_event_loop_after_set_none(self): + policy = asyncio.DefaultEventLoopPolicy() + policy.set_event_loop(None) + self.assertRaises(AssertionError, policy.get_event_loop) + + @mock.patch('asyncio.events.threading.current_thread') + def test_get_event_loop_thread(self, m_current_thread): + + def f(): + policy = asyncio.DefaultEventLoopPolicy() + self.assertRaises(AssertionError, policy.get_event_loop) + + th = threading.Thread(target=f) + th.start() + th.join() + + def test_new_event_loop(self): + policy = asyncio.DefaultEventLoopPolicy() + + loop = policy.new_event_loop() + self.assertIsInstance(loop, asyncio.AbstractEventLoop) + loop.close() + + def test_set_event_loop(self): + policy = asyncio.DefaultEventLoopPolicy() + old_loop = policy.get_event_loop() + + self.assertRaises(AssertionError, policy.set_event_loop, object()) + + loop = policy.new_event_loop() + policy.set_event_loop(loop) + self.assertIs(loop, policy.get_event_loop()) + self.assertIsNot(old_loop, policy.get_event_loop()) + loop.close() + old_loop.close() + + def test_get_event_loop_policy(self): + policy = asyncio.get_event_loop_policy() + self.assertIsInstance(policy, asyncio.AbstractEventLoopPolicy) + self.assertIs(policy, asyncio.get_event_loop_policy()) + + def test_set_event_loop_policy(self): + self.assertRaises( + AssertionError, asyncio.set_event_loop_policy, object()) + + old_policy = asyncio.get_event_loop_policy() + + policy = asyncio.DefaultEventLoopPolicy() + asyncio.set_event_loop_policy(policy) + self.assertIs(policy, asyncio.get_event_loop_policy()) + self.assertIsNot(policy, old_policy) + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_asyncio/test_futures.py b/Lib/test/test_asyncio/test_futures.py new file mode 100644 index 0000000000..399e8f438e --- /dev/null +++ b/Lib/test/test_asyncio/test_futures.py @@ -0,0 +1,351 @@ +"""Tests for futures.py.""" + +import concurrent.futures +import threading +import unittest +from unittest import mock + +import asyncio +from asyncio import test_utils + + +def _fakefunc(f): + return f + + +class FutureTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + asyncio.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_initial_state(self): + f = asyncio.Future(loop=self.loop) + self.assertFalse(f.cancelled()) + self.assertFalse(f.done()) + f.cancel() + self.assertTrue(f.cancelled()) + + def test_init_constructor_default_loop(self): + try: + asyncio.set_event_loop(self.loop) + f = asyncio.Future() + self.assertIs(f._loop, self.loop) + finally: + asyncio.set_event_loop(None) + + def test_constructor_positional(self): + # Make sure Future doesn't accept a positional argument + self.assertRaises(TypeError, asyncio.Future, 42) + + def test_cancel(self): + f = asyncio.Future(loop=self.loop) + self.assertTrue(f.cancel()) + self.assertTrue(f.cancelled()) + self.assertTrue(f.done()) + self.assertRaises(asyncio.CancelledError, f.result) + self.assertRaises(asyncio.CancelledError, f.exception) + self.assertRaises(asyncio.InvalidStateError, f.set_result, None) + self.assertRaises(asyncio.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_result(self): + f = asyncio.Future(loop=self.loop) + self.assertRaises(asyncio.InvalidStateError, f.result) + + f.set_result(42) + self.assertFalse(f.cancelled()) + self.assertTrue(f.done()) + self.assertEqual(f.result(), 42) + self.assertEqual(f.exception(), None) + self.assertRaises(asyncio.InvalidStateError, f.set_result, None) + self.assertRaises(asyncio.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_exception(self): + exc = RuntimeError() + f = asyncio.Future(loop=self.loop) + self.assertRaises(asyncio.InvalidStateError, f.exception) + + f.set_exception(exc) + self.assertFalse(f.cancelled()) + self.assertTrue(f.done()) + self.assertRaises(RuntimeError, f.result) + self.assertEqual(f.exception(), exc) + self.assertRaises(asyncio.InvalidStateError, f.set_result, None) + self.assertRaises(asyncio.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_exception_class(self): + f = asyncio.Future(loop=self.loop) + f.set_exception(RuntimeError) + self.assertIsInstance(f.exception(), RuntimeError) + + def test_yield_from_twice(self): + f = asyncio.Future(loop=self.loop) + + def fixture(): + yield 'A' + x = yield from f + yield 'B', x + y = yield from f + yield 'C', y + + g = fixture() + self.assertEqual(next(g), 'A') # yield 'A'. + self.assertEqual(next(g), f) # First yield from f. + f.set_result(42) + self.assertEqual(next(g), ('B', 42)) # yield 'B', x. + # The second "yield from f" does not yield f. + self.assertEqual(next(g), ('C', 42)) # yield 'C', y. + + def test_repr(self): + f_pending = asyncio.Future(loop=self.loop) + self.assertEqual(repr(f_pending), 'Future<PENDING>') + f_pending.cancel() + + f_cancelled = asyncio.Future(loop=self.loop) + f_cancelled.cancel() + self.assertEqual(repr(f_cancelled), 'Future<CANCELLED>') + + f_result = asyncio.Future(loop=self.loop) + f_result.set_result(4) + self.assertEqual(repr(f_result), 'Future<result=4>') + self.assertEqual(f_result.result(), 4) + + exc = RuntimeError() + f_exception = asyncio.Future(loop=self.loop) + f_exception.set_exception(exc) + self.assertEqual(repr(f_exception), 'Future<exception=RuntimeError()>') + self.assertIs(f_exception.exception(), exc) + + f_few_callbacks = asyncio.Future(loop=self.loop) + f_few_callbacks.add_done_callback(_fakefunc) + self.assertIn('Future<PENDING, [<function _fakefunc', + repr(f_few_callbacks)) + f_few_callbacks.cancel() + + f_many_callbacks = asyncio.Future(loop=self.loop) + for i in range(20): + f_many_callbacks.add_done_callback(_fakefunc) + r = repr(f_many_callbacks) + self.assertIn('Future<PENDING, [<function _fakefunc', r) + self.assertIn('<18 more>', r) + f_many_callbacks.cancel() + + def test_copy_state(self): + # Test the internal _copy_state method since it's being directly + # invoked in other modules. + f = asyncio.Future(loop=self.loop) + f.set_result(10) + + newf = asyncio.Future(loop=self.loop) + newf._copy_state(f) + self.assertTrue(newf.done()) + self.assertEqual(newf.result(), 10) + + f_exception = asyncio.Future(loop=self.loop) + f_exception.set_exception(RuntimeError()) + + newf_exception = asyncio.Future(loop=self.loop) + newf_exception._copy_state(f_exception) + self.assertTrue(newf_exception.done()) + self.assertRaises(RuntimeError, newf_exception.result) + + f_cancelled = asyncio.Future(loop=self.loop) + f_cancelled.cancel() + + newf_cancelled = asyncio.Future(loop=self.loop) + newf_cancelled._copy_state(f_cancelled) + self.assertTrue(newf_cancelled.cancelled()) + + def test_iter(self): + fut = asyncio.Future(loop=self.loop) + + def coro(): + yield from fut + + def test(): + arg1, arg2 = coro() + + self.assertRaises(AssertionError, test) + fut.cancel() + + @mock.patch('asyncio.base_events.logger') + def test_tb_logger_abandoned(self, m_log): + fut = asyncio.Future(loop=self.loop) + del fut + self.assertFalse(m_log.error.called) + + @mock.patch('asyncio.base_events.logger') + def test_tb_logger_result_unretrieved(self, m_log): + fut = asyncio.Future(loop=self.loop) + fut.set_result(42) + del fut + self.assertFalse(m_log.error.called) + + @mock.patch('asyncio.base_events.logger') + def test_tb_logger_result_retrieved(self, m_log): + fut = asyncio.Future(loop=self.loop) + fut.set_result(42) + fut.result() + del fut + self.assertFalse(m_log.error.called) + + @mock.patch('asyncio.base_events.logger') + def test_tb_logger_exception_unretrieved(self, m_log): + fut = asyncio.Future(loop=self.loop) + fut.set_exception(RuntimeError('boom')) + del fut + test_utils.run_briefly(self.loop) + self.assertTrue(m_log.error.called) + + @mock.patch('asyncio.base_events.logger') + def test_tb_logger_exception_retrieved(self, m_log): + fut = asyncio.Future(loop=self.loop) + fut.set_exception(RuntimeError('boom')) + fut.exception() + del fut + self.assertFalse(m_log.error.called) + + @mock.patch('asyncio.base_events.logger') + def test_tb_logger_exception_result_retrieved(self, m_log): + fut = asyncio.Future(loop=self.loop) + fut.set_exception(RuntimeError('boom')) + self.assertRaises(RuntimeError, fut.result) + del fut + self.assertFalse(m_log.error.called) + + def test_wrap_future(self): + + def run(arg): + return (arg, threading.get_ident()) + ex = concurrent.futures.ThreadPoolExecutor(1) + f1 = ex.submit(run, 'oi') + f2 = asyncio.wrap_future(f1, loop=self.loop) + res, ident = self.loop.run_until_complete(f2) + self.assertIsInstance(f2, asyncio.Future) + self.assertEqual(res, 'oi') + self.assertNotEqual(ident, threading.get_ident()) + + def test_wrap_future_future(self): + f1 = asyncio.Future(loop=self.loop) + f2 = asyncio.wrap_future(f1) + self.assertIs(f1, f2) + + @mock.patch('asyncio.futures.events') + def test_wrap_future_use_global_loop(self, m_events): + def run(arg): + return (arg, threading.get_ident()) + ex = concurrent.futures.ThreadPoolExecutor(1) + f1 = ex.submit(run, 'oi') + f2 = asyncio.wrap_future(f1) + self.assertIs(m_events.get_event_loop.return_value, f2._loop) + + def test_wrap_future_cancel(self): + f1 = concurrent.futures.Future() + f2 = asyncio.wrap_future(f1, loop=self.loop) + f2.cancel() + test_utils.run_briefly(self.loop) + self.assertTrue(f1.cancelled()) + self.assertTrue(f2.cancelled()) + + def test_wrap_future_cancel2(self): + f1 = concurrent.futures.Future() + f2 = asyncio.wrap_future(f1, loop=self.loop) + f1.set_result(42) + f2.cancel() + test_utils.run_briefly(self.loop) + self.assertFalse(f1.cancelled()) + self.assertEqual(f1.result(), 42) + self.assertTrue(f2.cancelled()) + + +class FutureDoneCallbackTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + asyncio.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def run_briefly(self): + test_utils.run_briefly(self.loop) + + def _make_callback(self, bag, thing): + # Create a callback function that appends thing to bag. + def bag_appender(future): + bag.append(thing) + return bag_appender + + def _new_future(self): + return asyncio.Future(loop=self.loop) + + def test_callbacks_invoked_on_set_result(self): + bag = [] + f = self._new_future() + f.add_done_callback(self._make_callback(bag, 42)) + f.add_done_callback(self._make_callback(bag, 17)) + + self.assertEqual(bag, []) + f.set_result('foo') + + self.run_briefly() + + self.assertEqual(bag, [42, 17]) + self.assertEqual(f.result(), 'foo') + + def test_callbacks_invoked_on_set_exception(self): + bag = [] + f = self._new_future() + f.add_done_callback(self._make_callback(bag, 100)) + + self.assertEqual(bag, []) + exc = RuntimeError() + f.set_exception(exc) + + self.run_briefly() + + self.assertEqual(bag, [100]) + self.assertEqual(f.exception(), exc) + + def test_remove_done_callback(self): + bag = [] + f = self._new_future() + cb1 = self._make_callback(bag, 1) + cb2 = self._make_callback(bag, 2) + cb3 = self._make_callback(bag, 3) + + # Add one cb1 and one cb2. + f.add_done_callback(cb1) + f.add_done_callback(cb2) + + # One instance of cb2 removed. Now there's only one cb1. + self.assertEqual(f.remove_done_callback(cb2), 1) + + # Never had any cb3 in there. + self.assertEqual(f.remove_done_callback(cb3), 0) + + # After this there will be 6 instances of cb1 and one of cb2. + f.add_done_callback(cb2) + for i in range(5): + f.add_done_callback(cb1) + + # Remove all instances of cb1. One cb2 remains. + self.assertEqual(f.remove_done_callback(cb1), 6) + + self.assertEqual(bag, []) + f.set_result('foo') + + self.run_briefly() + + self.assertEqual(bag, [2]) + self.assertEqual(f.result(), 'foo') + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_asyncio/test_locks.py b/Lib/test/test_asyncio/test_locks.py new file mode 100644 index 0000000000..f542463ad2 --- /dev/null +++ b/Lib/test/test_asyncio/test_locks.py @@ -0,0 +1,870 @@ +"""Tests for lock.py""" + +import unittest +from unittest import mock +import re + +import asyncio +from asyncio import test_utils + + +STR_RGX_REPR = ( + r'^<(?P<class>.*?) object at (?P<address>.*?)' + r'\[(?P<extras>' + r'(set|unset|locked|unlocked)(,value:\d)?(,waiters:\d+)?' + r')\]>\Z' +) +RGX_REPR = re.compile(STR_RGX_REPR) + + +class LockTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + asyncio.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_ctor_loop(self): + loop = mock.Mock() + lock = asyncio.Lock(loop=loop) + self.assertIs(lock._loop, loop) + + lock = asyncio.Lock(loop=self.loop) + self.assertIs(lock._loop, self.loop) + + def test_ctor_noloop(self): + try: + asyncio.set_event_loop(self.loop) + lock = asyncio.Lock() + self.assertIs(lock._loop, self.loop) + finally: + asyncio.set_event_loop(None) + + def test_repr(self): + lock = asyncio.Lock(loop=self.loop) + self.assertTrue(repr(lock).endswith('[unlocked]>')) + self.assertTrue(RGX_REPR.match(repr(lock))) + + @asyncio.coroutine + def acquire_lock(): + yield from lock + + self.loop.run_until_complete(acquire_lock()) + self.assertTrue(repr(lock).endswith('[locked]>')) + self.assertTrue(RGX_REPR.match(repr(lock))) + + def test_lock(self): + lock = asyncio.Lock(loop=self.loop) + + @asyncio.coroutine + def acquire_lock(): + return (yield from lock) + + res = self.loop.run_until_complete(acquire_lock()) + + self.assertTrue(res) + self.assertTrue(lock.locked()) + + lock.release() + self.assertFalse(lock.locked()) + + def test_acquire(self): + lock = asyncio.Lock(loop=self.loop) + result = [] + + self.assertTrue(self.loop.run_until_complete(lock.acquire())) + + @asyncio.coroutine + def c1(result): + if (yield from lock.acquire()): + result.append(1) + return True + + @asyncio.coroutine + def c2(result): + if (yield from lock.acquire()): + result.append(2) + return True + + @asyncio.coroutine + def c3(result): + if (yield from lock.acquire()): + result.append(3) + return True + + t1 = asyncio.Task(c1(result), loop=self.loop) + t2 = asyncio.Task(c2(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + lock.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + + t3 = asyncio.Task(c3(result), loop=self.loop) + + lock.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2], result) + + lock.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2, 3], result) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + + def test_acquire_cancel(self): + lock = asyncio.Lock(loop=self.loop) + self.assertTrue(self.loop.run_until_complete(lock.acquire())) + + task = asyncio.Task(lock.acquire(), loop=self.loop) + self.loop.call_soon(task.cancel) + self.assertRaises( + asyncio.CancelledError, + self.loop.run_until_complete, task) + self.assertFalse(lock._waiters) + + def test_cancel_race(self): + # Several tasks: + # - A acquires the lock + # - B is blocked in aqcuire() + # - C is blocked in aqcuire() + # + # Now, concurrently: + # - B is cancelled + # - A releases the lock + # + # If B's waiter is marked cancelled but not yet removed from + # _waiters, A's release() call will crash when trying to set + # B's waiter; instead, it should move on to C's waiter. + + # Setup: A has the lock, b and c are waiting. + lock = asyncio.Lock(loop=self.loop) + + @asyncio.coroutine + def lockit(name, blocker): + yield from lock.acquire() + try: + if blocker is not None: + yield from blocker + finally: + lock.release() + + fa = asyncio.Future(loop=self.loop) + ta = asyncio.Task(lockit('A', fa), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertTrue(lock.locked()) + tb = asyncio.Task(lockit('B', None), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertEqual(len(lock._waiters), 1) + tc = asyncio.Task(lockit('C', None), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertEqual(len(lock._waiters), 2) + + # Create the race and check. + # Without the fix this failed at the last assert. + fa.set_result(None) + tb.cancel() + self.assertTrue(lock._waiters[0].cancelled()) + test_utils.run_briefly(self.loop) + self.assertFalse(lock.locked()) + self.assertTrue(ta.done()) + self.assertTrue(tb.cancelled()) + self.assertTrue(tc.done()) + + def test_release_not_acquired(self): + lock = asyncio.Lock(loop=self.loop) + + self.assertRaises(RuntimeError, lock.release) + + def test_release_no_waiters(self): + lock = asyncio.Lock(loop=self.loop) + self.loop.run_until_complete(lock.acquire()) + self.assertTrue(lock.locked()) + + lock.release() + self.assertFalse(lock.locked()) + + def test_context_manager(self): + lock = asyncio.Lock(loop=self.loop) + + @asyncio.coroutine + def acquire_lock(): + return (yield from lock) + + with self.loop.run_until_complete(acquire_lock()): + self.assertTrue(lock.locked()) + + self.assertFalse(lock.locked()) + + def test_context_manager_cant_reuse(self): + lock = asyncio.Lock(loop=self.loop) + + @asyncio.coroutine + def acquire_lock(): + return (yield from lock) + + # This spells "yield from lock" outside a generator. + cm = self.loop.run_until_complete(acquire_lock()) + with cm: + self.assertTrue(lock.locked()) + + self.assertFalse(lock.locked()) + + with self.assertRaises(AttributeError): + with cm: + pass + + def test_context_manager_no_yield(self): + lock = asyncio.Lock(loop=self.loop) + + try: + with lock: + self.fail('RuntimeError is not raised in with expression') + except RuntimeError as err: + self.assertEqual( + str(err), + '"yield from" should be used as context manager expression') + + self.assertFalse(lock.locked()) + + +class EventTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + asyncio.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_ctor_loop(self): + loop = mock.Mock() + ev = asyncio.Event(loop=loop) + self.assertIs(ev._loop, loop) + + ev = asyncio.Event(loop=self.loop) + self.assertIs(ev._loop, self.loop) + + def test_ctor_noloop(self): + try: + asyncio.set_event_loop(self.loop) + ev = asyncio.Event() + self.assertIs(ev._loop, self.loop) + finally: + asyncio.set_event_loop(None) + + def test_repr(self): + ev = asyncio.Event(loop=self.loop) + self.assertTrue(repr(ev).endswith('[unset]>')) + match = RGX_REPR.match(repr(ev)) + self.assertEqual(match.group('extras'), 'unset') + + ev.set() + self.assertTrue(repr(ev).endswith('[set]>')) + self.assertTrue(RGX_REPR.match(repr(ev))) + + ev._waiters.append(mock.Mock()) + self.assertTrue('waiters:1' in repr(ev)) + self.assertTrue(RGX_REPR.match(repr(ev))) + + def test_wait(self): + ev = asyncio.Event(loop=self.loop) + self.assertFalse(ev.is_set()) + + result = [] + + @asyncio.coroutine + def c1(result): + if (yield from ev.wait()): + result.append(1) + + @asyncio.coroutine + def c2(result): + if (yield from ev.wait()): + result.append(2) + + @asyncio.coroutine + def c3(result): + if (yield from ev.wait()): + result.append(3) + + t1 = asyncio.Task(c1(result), loop=self.loop) + t2 = asyncio.Task(c2(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + t3 = asyncio.Task(c3(result), loop=self.loop) + + ev.set() + test_utils.run_briefly(self.loop) + self.assertEqual([3, 1, 2], result) + + self.assertTrue(t1.done()) + self.assertIsNone(t1.result()) + self.assertTrue(t2.done()) + self.assertIsNone(t2.result()) + self.assertTrue(t3.done()) + self.assertIsNone(t3.result()) + + def test_wait_on_set(self): + ev = asyncio.Event(loop=self.loop) + ev.set() + + res = self.loop.run_until_complete(ev.wait()) + self.assertTrue(res) + + def test_wait_cancel(self): + ev = asyncio.Event(loop=self.loop) + + wait = asyncio.Task(ev.wait(), loop=self.loop) + self.loop.call_soon(wait.cancel) + self.assertRaises( + asyncio.CancelledError, + self.loop.run_until_complete, wait) + self.assertFalse(ev._waiters) + + def test_clear(self): + ev = asyncio.Event(loop=self.loop) + self.assertFalse(ev.is_set()) + + ev.set() + self.assertTrue(ev.is_set()) + + ev.clear() + self.assertFalse(ev.is_set()) + + def test_clear_with_waiters(self): + ev = asyncio.Event(loop=self.loop) + result = [] + + @asyncio.coroutine + def c1(result): + if (yield from ev.wait()): + result.append(1) + return True + + t = asyncio.Task(c1(result), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + ev.set() + ev.clear() + self.assertFalse(ev.is_set()) + + ev.set() + ev.set() + self.assertEqual(1, len(ev._waiters)) + + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + self.assertEqual(0, len(ev._waiters)) + + self.assertTrue(t.done()) + self.assertTrue(t.result()) + + +class ConditionTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + asyncio.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_ctor_loop(self): + loop = mock.Mock() + cond = asyncio.Condition(loop=loop) + self.assertIs(cond._loop, loop) + + cond = asyncio.Condition(loop=self.loop) + self.assertIs(cond._loop, self.loop) + + def test_ctor_noloop(self): + try: + asyncio.set_event_loop(self.loop) + cond = asyncio.Condition() + self.assertIs(cond._loop, self.loop) + finally: + asyncio.set_event_loop(None) + + def test_wait(self): + cond = asyncio.Condition(loop=self.loop) + result = [] + + @asyncio.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + return True + + @asyncio.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + return True + + @asyncio.coroutine + def c3(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(3) + return True + + t1 = asyncio.Task(c1(result), loop=self.loop) + t2 = asyncio.Task(c2(result), loop=self.loop) + t3 = asyncio.Task(c3(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + self.assertFalse(cond.locked()) + + self.assertTrue(self.loop.run_until_complete(cond.acquire())) + cond.notify() + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + self.assertTrue(cond.locked()) + + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.notify(2) + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2], result) + self.assertTrue(cond.locked()) + + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2, 3], result) + self.assertTrue(cond.locked()) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + + def test_wait_cancel(self): + cond = asyncio.Condition(loop=self.loop) + self.loop.run_until_complete(cond.acquire()) + + wait = asyncio.Task(cond.wait(), loop=self.loop) + self.loop.call_soon(wait.cancel) + self.assertRaises( + asyncio.CancelledError, + self.loop.run_until_complete, wait) + self.assertFalse(cond._waiters) + self.assertTrue(cond.locked()) + + def test_wait_unacquired(self): + cond = asyncio.Condition(loop=self.loop) + self.assertRaises( + RuntimeError, + self.loop.run_until_complete, cond.wait()) + + def test_wait_for(self): + cond = asyncio.Condition(loop=self.loop) + presult = False + + def predicate(): + return presult + + result = [] + + @asyncio.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait_for(predicate)): + result.append(1) + cond.release() + return True + + t = asyncio.Task(c1(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + self.loop.run_until_complete(cond.acquire()) + cond.notify() + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + presult = True + self.loop.run_until_complete(cond.acquire()) + cond.notify() + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + + self.assertTrue(t.done()) + self.assertTrue(t.result()) + + def test_wait_for_unacquired(self): + cond = asyncio.Condition(loop=self.loop) + + # predicate can return true immediately + res = self.loop.run_until_complete(cond.wait_for(lambda: [1, 2, 3])) + self.assertEqual([1, 2, 3], res) + + self.assertRaises( + RuntimeError, + self.loop.run_until_complete, + cond.wait_for(lambda: False)) + + def test_notify(self): + cond = asyncio.Condition(loop=self.loop) + result = [] + + @asyncio.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + cond.release() + return True + + @asyncio.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + cond.release() + return True + + @asyncio.coroutine + def c3(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(3) + cond.release() + return True + + t1 = asyncio.Task(c1(result), loop=self.loop) + t2 = asyncio.Task(c2(result), loop=self.loop) + t3 = asyncio.Task(c3(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + self.loop.run_until_complete(cond.acquire()) + cond.notify(1) + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + + self.loop.run_until_complete(cond.acquire()) + cond.notify(1) + cond.notify(2048) + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2, 3], result) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + + def test_notify_all(self): + cond = asyncio.Condition(loop=self.loop) + + result = [] + + @asyncio.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + cond.release() + return True + + @asyncio.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + cond.release() + return True + + t1 = asyncio.Task(c1(result), loop=self.loop) + t2 = asyncio.Task(c2(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + self.loop.run_until_complete(cond.acquire()) + cond.notify_all() + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2], result) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + + def test_notify_unacquired(self): + cond = asyncio.Condition(loop=self.loop) + self.assertRaises(RuntimeError, cond.notify) + + def test_notify_all_unacquired(self): + cond = asyncio.Condition(loop=self.loop) + self.assertRaises(RuntimeError, cond.notify_all) + + def test_repr(self): + cond = asyncio.Condition(loop=self.loop) + self.assertTrue('unlocked' in repr(cond)) + self.assertTrue(RGX_REPR.match(repr(cond))) + + self.loop.run_until_complete(cond.acquire()) + self.assertTrue('locked' in repr(cond)) + + cond._waiters.append(mock.Mock()) + self.assertTrue('waiters:1' in repr(cond)) + self.assertTrue(RGX_REPR.match(repr(cond))) + + cond._waiters.append(mock.Mock()) + self.assertTrue('waiters:2' in repr(cond)) + self.assertTrue(RGX_REPR.match(repr(cond))) + + def test_context_manager(self): + cond = asyncio.Condition(loop=self.loop) + + @asyncio.coroutine + def acquire_cond(): + return (yield from cond) + + with self.loop.run_until_complete(acquire_cond()): + self.assertTrue(cond.locked()) + + self.assertFalse(cond.locked()) + + def test_context_manager_no_yield(self): + cond = asyncio.Condition(loop=self.loop) + + try: + with cond: + self.fail('RuntimeError is not raised in with expression') + except RuntimeError as err: + self.assertEqual( + str(err), + '"yield from" should be used as context manager expression') + + self.assertFalse(cond.locked()) + + +class SemaphoreTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + asyncio.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_ctor_loop(self): + loop = mock.Mock() + sem = asyncio.Semaphore(loop=loop) + self.assertIs(sem._loop, loop) + + sem = asyncio.Semaphore(loop=self.loop) + self.assertIs(sem._loop, self.loop) + + def test_ctor_noloop(self): + try: + asyncio.set_event_loop(self.loop) + sem = asyncio.Semaphore() + self.assertIs(sem._loop, self.loop) + finally: + asyncio.set_event_loop(None) + + def test_initial_value_zero(self): + sem = asyncio.Semaphore(0, loop=self.loop) + self.assertTrue(sem.locked()) + + def test_repr(self): + sem = asyncio.Semaphore(loop=self.loop) + self.assertTrue(repr(sem).endswith('[unlocked,value:1]>')) + self.assertTrue(RGX_REPR.match(repr(sem))) + + self.loop.run_until_complete(sem.acquire()) + self.assertTrue(repr(sem).endswith('[locked]>')) + self.assertTrue('waiters' not in repr(sem)) + self.assertTrue(RGX_REPR.match(repr(sem))) + + sem._waiters.append(mock.Mock()) + self.assertTrue('waiters:1' in repr(sem)) + self.assertTrue(RGX_REPR.match(repr(sem))) + + sem._waiters.append(mock.Mock()) + self.assertTrue('waiters:2' in repr(sem)) + self.assertTrue(RGX_REPR.match(repr(sem))) + + def test_semaphore(self): + sem = asyncio.Semaphore(loop=self.loop) + self.assertEqual(1, sem._value) + + @asyncio.coroutine + def acquire_lock(): + return (yield from sem) + + res = self.loop.run_until_complete(acquire_lock()) + + self.assertTrue(res) + self.assertTrue(sem.locked()) + self.assertEqual(0, sem._value) + + sem.release() + self.assertFalse(sem.locked()) + self.assertEqual(1, sem._value) + + def test_semaphore_value(self): + self.assertRaises(ValueError, asyncio.Semaphore, -1) + + def test_acquire(self): + sem = asyncio.Semaphore(3, loop=self.loop) + result = [] + + self.assertTrue(self.loop.run_until_complete(sem.acquire())) + self.assertTrue(self.loop.run_until_complete(sem.acquire())) + self.assertFalse(sem.locked()) + + @asyncio.coroutine + def c1(result): + yield from sem.acquire() + result.append(1) + return True + + @asyncio.coroutine + def c2(result): + yield from sem.acquire() + result.append(2) + return True + + @asyncio.coroutine + def c3(result): + yield from sem.acquire() + result.append(3) + return True + + @asyncio.coroutine + def c4(result): + yield from sem.acquire() + result.append(4) + return True + + t1 = asyncio.Task(c1(result), loop=self.loop) + t2 = asyncio.Task(c2(result), loop=self.loop) + t3 = asyncio.Task(c3(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + self.assertTrue(sem.locked()) + self.assertEqual(2, len(sem._waiters)) + self.assertEqual(0, sem._value) + + t4 = asyncio.Task(c4(result), loop=self.loop) + + sem.release() + sem.release() + self.assertEqual(2, sem._value) + + test_utils.run_briefly(self.loop) + self.assertEqual(0, sem._value) + self.assertEqual([1, 2, 3], result) + self.assertTrue(sem.locked()) + self.assertEqual(1, len(sem._waiters)) + self.assertEqual(0, sem._value) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + self.assertFalse(t4.done()) + + # cleanup locked semaphore + sem.release() + + def test_acquire_cancel(self): + sem = asyncio.Semaphore(loop=self.loop) + self.loop.run_until_complete(sem.acquire()) + + acquire = asyncio.Task(sem.acquire(), loop=self.loop) + self.loop.call_soon(acquire.cancel) + self.assertRaises( + asyncio.CancelledError, + self.loop.run_until_complete, acquire) + self.assertFalse(sem._waiters) + + def test_release_not_acquired(self): + sem = asyncio.BoundedSemaphore(loop=self.loop) + + self.assertRaises(ValueError, sem.release) + + def test_release_no_waiters(self): + sem = asyncio.Semaphore(loop=self.loop) + self.loop.run_until_complete(sem.acquire()) + self.assertTrue(sem.locked()) + + sem.release() + self.assertFalse(sem.locked()) + + def test_context_manager(self): + sem = asyncio.Semaphore(2, loop=self.loop) + + @asyncio.coroutine + def acquire_lock(): + return (yield from sem) + + with self.loop.run_until_complete(acquire_lock()): + self.assertFalse(sem.locked()) + self.assertEqual(1, sem._value) + + with self.loop.run_until_complete(acquire_lock()): + self.assertTrue(sem.locked()) + + self.assertEqual(2, sem._value) + + def test_context_manager_no_yield(self): + sem = asyncio.Semaphore(2, loop=self.loop) + + try: + with sem: + self.fail('RuntimeError is not raised in with expression') + except RuntimeError as err: + self.assertEqual( + str(err), + '"yield from" should be used as context manager expression') + + self.assertEqual(2, sem._value) + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_asyncio/test_proactor_events.py b/Lib/test/test_asyncio/test_proactor_events.py new file mode 100644 index 0000000000..5bf24a4503 --- /dev/null +++ b/Lib/test/test_asyncio/test_proactor_events.py @@ -0,0 +1,490 @@ +"""Tests for proactor_events.py""" + +import socket +import unittest +from unittest import mock + +import asyncio +from asyncio.proactor_events import BaseProactorEventLoop +from asyncio.proactor_events import _ProactorSocketTransport +from asyncio.proactor_events import _ProactorWritePipeTransport +from asyncio.proactor_events import _ProactorDuplexPipeTransport +from asyncio import test_utils + + +class ProactorSocketTransportTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + self.proactor = mock.Mock() + self.loop._proactor = self.proactor + self.protocol = test_utils.make_test_protocol(asyncio.Protocol) + self.sock = mock.Mock(socket.socket) + + def test_ctor(self): + fut = asyncio.Future(loop=self.loop) + tr = _ProactorSocketTransport( + self.loop, self.sock, self.protocol, fut) + test_utils.run_briefly(self.loop) + self.assertIsNone(fut.result()) + self.protocol.connection_made(tr) + self.proactor.recv.assert_called_with(self.sock, 4096) + + def test_loop_reading(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._loop_reading() + self.loop._proactor.recv.assert_called_with(self.sock, 4096) + self.assertFalse(self.protocol.data_received.called) + self.assertFalse(self.protocol.eof_received.called) + + def test_loop_reading_data(self): + res = asyncio.Future(loop=self.loop) + res.set_result(b'data') + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + + tr._read_fut = res + tr._loop_reading(res) + self.loop._proactor.recv.assert_called_with(self.sock, 4096) + self.protocol.data_received.assert_called_with(b'data') + + def test_loop_reading_no_data(self): + res = asyncio.Future(loop=self.loop) + res.set_result(b'') + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + + self.assertRaises(AssertionError, tr._loop_reading, res) + + tr.close = mock.Mock() + tr._read_fut = res + tr._loop_reading(res) + self.assertFalse(self.loop._proactor.recv.called) + self.assertTrue(self.protocol.eof_received.called) + self.assertTrue(tr.close.called) + + def test_loop_reading_aborted(self): + err = self.loop._proactor.recv.side_effect = ConnectionAbortedError() + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._fatal_error = mock.Mock() + tr._loop_reading() + tr._fatal_error.assert_called_with( + err, + 'Fatal read error on pipe transport') + + def test_loop_reading_aborted_closing(self): + self.loop._proactor.recv.side_effect = ConnectionAbortedError() + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._closing = True + tr._fatal_error = mock.Mock() + tr._loop_reading() + self.assertFalse(tr._fatal_error.called) + + def test_loop_reading_aborted_is_fatal(self): + self.loop._proactor.recv.side_effect = ConnectionAbortedError() + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._closing = False + tr._fatal_error = mock.Mock() + tr._loop_reading() + self.assertTrue(tr._fatal_error.called) + + def test_loop_reading_conn_reset_lost(self): + err = self.loop._proactor.recv.side_effect = ConnectionResetError() + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._closing = False + tr._fatal_error = mock.Mock() + tr._force_close = mock.Mock() + tr._loop_reading() + self.assertFalse(tr._fatal_error.called) + tr._force_close.assert_called_with(err) + + def test_loop_reading_exception(self): + err = self.loop._proactor.recv.side_effect = (OSError()) + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._fatal_error = mock.Mock() + tr._loop_reading() + tr._fatal_error.assert_called_with( + err, + 'Fatal read error on pipe transport') + + def test_write(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._loop_writing = mock.Mock() + tr.write(b'data') + self.assertEqual(tr._buffer, None) + tr._loop_writing.assert_called_with(data=b'data') + + def test_write_no_data(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr.write(b'') + self.assertFalse(tr._buffer) + + def test_write_more(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._write_fut = mock.Mock() + tr._loop_writing = mock.Mock() + tr.write(b'data') + self.assertEqual(tr._buffer, b'data') + self.assertFalse(tr._loop_writing.called) + + def test_loop_writing(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._buffer = bytearray(b'data') + tr._loop_writing() + self.loop._proactor.send.assert_called_with(self.sock, b'data') + self.loop._proactor.send.return_value.add_done_callback.\ + assert_called_with(tr._loop_writing) + + @mock.patch('asyncio.proactor_events.logger') + def test_loop_writing_err(self, m_log): + err = self.loop._proactor.send.side_effect = OSError() + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._fatal_error = mock.Mock() + tr._buffer = [b'da', b'ta'] + tr._loop_writing() + tr._fatal_error.assert_called_with( + err, + 'Fatal write error on pipe transport') + tr._conn_lost = 1 + + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + self.assertEqual(tr._buffer, None) + m_log.warning.assert_called_with('socket.send() raised exception.') + + def test_loop_writing_stop(self): + fut = asyncio.Future(loop=self.loop) + fut.set_result(b'data') + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._write_fut = fut + tr._loop_writing(fut) + self.assertIsNone(tr._write_fut) + + def test_loop_writing_closing(self): + fut = asyncio.Future(loop=self.loop) + fut.set_result(1) + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._write_fut = fut + tr.close() + tr._loop_writing(fut) + self.assertIsNone(tr._write_fut) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + + def test_abort(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._force_close = mock.Mock() + tr.abort() + tr._force_close.assert_called_with(None) + + def test_close(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr.close() + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + self.assertTrue(tr._closing) + self.assertEqual(tr._conn_lost, 1) + + self.protocol.connection_lost.reset_mock() + tr.close() + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) + + def test_close_write_fut(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._write_fut = mock.Mock() + tr.close() + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) + + def test_close_buffer(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._buffer = [b'data'] + tr.close() + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) + + @mock.patch('asyncio.base_events.logger') + def test_fatal_error(self, m_logging): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._force_close = mock.Mock() + tr._fatal_error(None) + self.assertTrue(tr._force_close.called) + self.assertTrue(m_logging.error.called) + + def test_force_close(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._buffer = [b'data'] + read_fut = tr._read_fut = mock.Mock() + write_fut = tr._write_fut = mock.Mock() + tr._force_close(None) + + read_fut.cancel.assert_called_with() + write_fut.cancel.assert_called_with() + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + self.assertEqual(None, tr._buffer) + self.assertEqual(tr._conn_lost, 1) + + def test_force_close_idempotent(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._closing = True + tr._force_close(None) + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) + + def test_fatal_error_2(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._buffer = [b'data'] + tr._force_close(None) + + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + self.assertEqual(None, tr._buffer) + + def test_call_connection_lost(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._call_connection_lost(None) + self.assertTrue(self.protocol.connection_lost.called) + self.assertTrue(self.sock.close.called) + + def test_write_eof(self): + tr = _ProactorSocketTransport( + self.loop, self.sock, self.protocol) + self.assertTrue(tr.can_write_eof()) + tr.write_eof() + self.sock.shutdown.assert_called_with(socket.SHUT_WR) + tr.write_eof() + self.assertEqual(self.sock.shutdown.call_count, 1) + tr.close() + + def test_write_eof_buffer(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + f = asyncio.Future(loop=self.loop) + tr._loop._proactor.send.return_value = f + tr.write(b'data') + tr.write_eof() + self.assertTrue(tr._eof_written) + self.assertFalse(self.sock.shutdown.called) + tr._loop._proactor.send.assert_called_with(self.sock, b'data') + f.set_result(4) + self.loop._run_once() + self.sock.shutdown.assert_called_with(socket.SHUT_WR) + tr.close() + + def test_write_eof_write_pipe(self): + tr = _ProactorWritePipeTransport( + self.loop, self.sock, self.protocol) + self.assertTrue(tr.can_write_eof()) + tr.write_eof() + self.assertTrue(tr._closing) + self.loop._run_once() + self.assertTrue(self.sock.close.called) + tr.close() + + def test_write_eof_buffer_write_pipe(self): + tr = _ProactorWritePipeTransport(self.loop, self.sock, self.protocol) + f = asyncio.Future(loop=self.loop) + tr._loop._proactor.send.return_value = f + tr.write(b'data') + tr.write_eof() + self.assertTrue(tr._closing) + self.assertFalse(self.sock.shutdown.called) + tr._loop._proactor.send.assert_called_with(self.sock, b'data') + f.set_result(4) + self.loop._run_once() + self.loop._run_once() + self.assertTrue(self.sock.close.called) + tr.close() + + def test_write_eof_duplex_pipe(self): + tr = _ProactorDuplexPipeTransport( + self.loop, self.sock, self.protocol) + self.assertFalse(tr.can_write_eof()) + with self.assertRaises(NotImplementedError): + tr.write_eof() + tr.close() + + def test_pause_resume_reading(self): + tr = _ProactorSocketTransport( + self.loop, self.sock, self.protocol) + futures = [] + for msg in [b'data1', b'data2', b'data3', b'data4', b'']: + f = asyncio.Future(loop=self.loop) + f.set_result(msg) + futures.append(f) + self.loop._proactor.recv.side_effect = futures + self.loop._run_once() + self.assertFalse(tr._paused) + self.loop._run_once() + self.protocol.data_received.assert_called_with(b'data1') + self.loop._run_once() + self.protocol.data_received.assert_called_with(b'data2') + tr.pause_reading() + self.assertTrue(tr._paused) + for i in range(10): + self.loop._run_once() + self.protocol.data_received.assert_called_with(b'data2') + tr.resume_reading() + self.assertFalse(tr._paused) + self.loop._run_once() + self.protocol.data_received.assert_called_with(b'data3') + self.loop._run_once() + self.protocol.data_received.assert_called_with(b'data4') + tr.close() + + +class BaseProactorEventLoopTests(unittest.TestCase): + + def setUp(self): + self.sock = mock.Mock(socket.socket) + self.proactor = mock.Mock() + + self.ssock, self.csock = mock.Mock(), mock.Mock() + + class EventLoop(BaseProactorEventLoop): + def _socketpair(s): + return (self.ssock, self.csock) + + self.loop = EventLoop(self.proactor) + + @mock.patch.object(BaseProactorEventLoop, 'call_soon') + @mock.patch.object(BaseProactorEventLoop, '_socketpair') + def test_ctor(self, socketpair, call_soon): + ssock, csock = socketpair.return_value = ( + mock.Mock(), mock.Mock()) + loop = BaseProactorEventLoop(self.proactor) + self.assertIs(loop._ssock, ssock) + self.assertIs(loop._csock, csock) + self.assertEqual(loop._internal_fds, 1) + call_soon.assert_called_with(loop._loop_self_reading) + + def test_close_self_pipe(self): + self.loop._close_self_pipe() + self.assertEqual(self.loop._internal_fds, 0) + self.assertTrue(self.ssock.close.called) + self.assertTrue(self.csock.close.called) + self.assertIsNone(self.loop._ssock) + self.assertIsNone(self.loop._csock) + + def test_close(self): + self.loop._close_self_pipe = mock.Mock() + self.loop.close() + self.assertTrue(self.loop._close_self_pipe.called) + self.assertTrue(self.proactor.close.called) + self.assertIsNone(self.loop._proactor) + + self.loop._close_self_pipe.reset_mock() + self.loop.close() + self.assertFalse(self.loop._close_self_pipe.called) + + def test_sock_recv(self): + self.loop.sock_recv(self.sock, 1024) + self.proactor.recv.assert_called_with(self.sock, 1024) + + def test_sock_sendall(self): + self.loop.sock_sendall(self.sock, b'data') + self.proactor.send.assert_called_with(self.sock, b'data') + + def test_sock_connect(self): + self.loop.sock_connect(self.sock, 123) + self.proactor.connect.assert_called_with(self.sock, 123) + + def test_sock_accept(self): + self.loop.sock_accept(self.sock) + self.proactor.accept.assert_called_with(self.sock) + + def test_socketpair(self): + self.assertRaises( + NotImplementedError, BaseProactorEventLoop, self.proactor) + + def test_make_socket_transport(self): + tr = self.loop._make_socket_transport(self.sock, asyncio.Protocol()) + self.assertIsInstance(tr, _ProactorSocketTransport) + + def test_loop_self_reading(self): + self.loop._loop_self_reading() + self.proactor.recv.assert_called_with(self.ssock, 4096) + self.proactor.recv.return_value.add_done_callback.assert_called_with( + self.loop._loop_self_reading) + + def test_loop_self_reading_fut(self): + fut = mock.Mock() + self.loop._loop_self_reading(fut) + self.assertTrue(fut.result.called) + self.proactor.recv.assert_called_with(self.ssock, 4096) + self.proactor.recv.return_value.add_done_callback.assert_called_with( + self.loop._loop_self_reading) + + def test_loop_self_reading_exception(self): + self.loop.close = mock.Mock() + self.proactor.recv.side_effect = OSError() + self.assertRaises(OSError, self.loop._loop_self_reading) + self.assertTrue(self.loop.close.called) + + def test_write_to_self(self): + self.loop._write_to_self() + self.csock.send.assert_called_with(b'x') + + def test_process_events(self): + self.loop._process_events([]) + + @mock.patch('asyncio.base_events.logger') + def test_create_server(self, m_log): + pf = mock.Mock() + call_soon = self.loop.call_soon = mock.Mock() + + self.loop._start_serving(pf, self.sock) + self.assertTrue(call_soon.called) + + # callback + loop = call_soon.call_args[0][0] + loop() + self.proactor.accept.assert_called_with(self.sock) + + # conn + fut = mock.Mock() + fut.result.return_value = (mock.Mock(), mock.Mock()) + + make_tr = self.loop._make_socket_transport = mock.Mock() + loop(fut) + self.assertTrue(fut.result.called) + self.assertTrue(make_tr.called) + + # exception + fut.result.side_effect = OSError() + loop(fut) + self.assertTrue(self.sock.close.called) + self.assertTrue(m_log.error.called) + + def test_create_server_cancel(self): + pf = mock.Mock() + call_soon = self.loop.call_soon = mock.Mock() + + self.loop._start_serving(pf, self.sock) + loop = call_soon.call_args[0][0] + + # cancelled + fut = asyncio.Future(loop=self.loop) + fut.cancel() + loop(fut) + self.assertTrue(self.sock.close.called) + + def test_stop_serving(self): + sock = mock.Mock() + self.loop._stop_serving(sock) + self.assertTrue(sock.close.called) + self.proactor._stop_serving.assert_called_with(sock) + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_asyncio/test_queues.py b/Lib/test/test_asyncio/test_queues.py new file mode 100644 index 0000000000..f79fee21d6 --- /dev/null +++ b/Lib/test/test_asyncio/test_queues.py @@ -0,0 +1,466 @@ +"""Tests for queues.py""" + +import unittest +from unittest import mock + +import asyncio +from asyncio import test_utils + + +class _QueueTestBase(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + asyncio.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + +class QueueBasicTests(_QueueTestBase): + + def _test_repr_or_str(self, fn, expect_id): + """Test Queue's repr or str. + + fn is repr or str. expect_id is True if we expect the Queue's id to + appear in fn(Queue()). + """ + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0.1 + self.assertAlmostEqual(0.2, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + q = asyncio.Queue(loop=loop) + self.assertTrue(fn(q).startswith('<Queue'), fn(q)) + id_is_present = hex(id(q)) in fn(q) + self.assertEqual(expect_id, id_is_present) + + @asyncio.coroutine + def add_getter(): + q = asyncio.Queue(loop=loop) + # Start a task that waits to get. + asyncio.Task(q.get(), loop=loop) + # Let it start waiting. + yield from asyncio.sleep(0.1, loop=loop) + self.assertTrue('_getters[1]' in fn(q)) + # resume q.get coroutine to finish generator + q.put_nowait(0) + + loop.run_until_complete(add_getter()) + + @asyncio.coroutine + def add_putter(): + q = asyncio.Queue(maxsize=1, loop=loop) + q.put_nowait(1) + # Start a task that waits to put. + asyncio.Task(q.put(2), loop=loop) + # Let it start waiting. + yield from asyncio.sleep(0.1, loop=loop) + self.assertTrue('_putters[1]' in fn(q)) + # resume q.put coroutine to finish generator + q.get_nowait() + + loop.run_until_complete(add_putter()) + + q = asyncio.Queue(loop=loop) + q.put_nowait(1) + self.assertTrue('_queue=[1]' in fn(q)) + + def test_ctor_loop(self): + loop = mock.Mock() + q = asyncio.Queue(loop=loop) + self.assertIs(q._loop, loop) + + q = asyncio.Queue(loop=self.loop) + self.assertIs(q._loop, self.loop) + + def test_ctor_noloop(self): + try: + asyncio.set_event_loop(self.loop) + q = asyncio.Queue() + self.assertIs(q._loop, self.loop) + finally: + asyncio.set_event_loop(None) + + def test_repr(self): + self._test_repr_or_str(repr, True) + + def test_str(self): + self._test_repr_or_str(str, False) + + def test_empty(self): + q = asyncio.Queue(loop=self.loop) + self.assertTrue(q.empty()) + q.put_nowait(1) + self.assertFalse(q.empty()) + self.assertEqual(1, q.get_nowait()) + self.assertTrue(q.empty()) + + def test_full(self): + q = asyncio.Queue(loop=self.loop) + self.assertFalse(q.full()) + + q = asyncio.Queue(maxsize=1, loop=self.loop) + q.put_nowait(1) + self.assertTrue(q.full()) + + def test_order(self): + q = asyncio.Queue(loop=self.loop) + for i in [1, 3, 2]: + q.put_nowait(i) + + items = [q.get_nowait() for _ in range(3)] + self.assertEqual([1, 3, 2], items) + + def test_maxsize(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.01, when) + when = yield 0.01 + self.assertAlmostEqual(0.02, when) + yield 0.01 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + q = asyncio.Queue(maxsize=2, loop=loop) + self.assertEqual(2, q.maxsize) + have_been_put = [] + + @asyncio.coroutine + def putter(): + for i in range(3): + yield from q.put(i) + have_been_put.append(i) + return True + + @asyncio.coroutine + def test(): + t = asyncio.Task(putter(), loop=loop) + yield from asyncio.sleep(0.01, loop=loop) + + # The putter is blocked after putting two items. + self.assertEqual([0, 1], have_been_put) + self.assertEqual(0, q.get_nowait()) + + # Let the putter resume and put last item. + yield from asyncio.sleep(0.01, loop=loop) + self.assertEqual([0, 1, 2], have_been_put) + self.assertEqual(1, q.get_nowait()) + self.assertEqual(2, q.get_nowait()) + + self.assertTrue(t.done()) + self.assertTrue(t.result()) + + loop.run_until_complete(test()) + self.assertAlmostEqual(0.02, loop.time()) + + +class QueueGetTests(_QueueTestBase): + + def test_blocking_get(self): + q = asyncio.Queue(loop=self.loop) + q.put_nowait(1) + + @asyncio.coroutine + def queue_get(): + return (yield from q.get()) + + res = self.loop.run_until_complete(queue_get()) + self.assertEqual(1, res) + + def test_get_with_putters(self): + q = asyncio.Queue(1, loop=self.loop) + q.put_nowait(1) + + waiter = asyncio.Future(loop=self.loop) + q._putters.append((2, waiter)) + + res = self.loop.run_until_complete(q.get()) + self.assertEqual(1, res) + self.assertTrue(waiter.done()) + self.assertIsNone(waiter.result()) + + def test_blocking_get_wait(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.01, when) + yield 0.01 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + q = asyncio.Queue(loop=loop) + started = asyncio.Event(loop=loop) + finished = False + + @asyncio.coroutine + def queue_get(): + nonlocal finished + started.set() + res = yield from q.get() + finished = True + return res + + @asyncio.coroutine + def queue_put(): + loop.call_later(0.01, q.put_nowait, 1) + queue_get_task = asyncio.Task(queue_get(), loop=loop) + yield from started.wait() + self.assertFalse(finished) + res = yield from queue_get_task + self.assertTrue(finished) + return res + + res = loop.run_until_complete(queue_put()) + self.assertEqual(1, res) + self.assertAlmostEqual(0.01, loop.time()) + + def test_nonblocking_get(self): + q = asyncio.Queue(loop=self.loop) + q.put_nowait(1) + self.assertEqual(1, q.get_nowait()) + + def test_nonblocking_get_exception(self): + q = asyncio.Queue(loop=self.loop) + self.assertRaises(asyncio.QueueEmpty, q.get_nowait) + + def test_get_cancelled(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.01, when) + when = yield 0.01 + self.assertAlmostEqual(0.061, when) + yield 0.05 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + q = asyncio.Queue(loop=loop) + + @asyncio.coroutine + def queue_get(): + return (yield from asyncio.wait_for(q.get(), 0.051, loop=loop)) + + @asyncio.coroutine + def test(): + get_task = asyncio.Task(queue_get(), loop=loop) + yield from asyncio.sleep(0.01, loop=loop) # let the task start + q.put_nowait(1) + return (yield from get_task) + + self.assertEqual(1, loop.run_until_complete(test())) + self.assertAlmostEqual(0.06, loop.time()) + + def test_get_cancelled_race(self): + q = asyncio.Queue(loop=self.loop) + + t1 = asyncio.Task(q.get(), loop=self.loop) + t2 = asyncio.Task(q.get(), loop=self.loop) + + test_utils.run_briefly(self.loop) + t1.cancel() + test_utils.run_briefly(self.loop) + self.assertTrue(t1.done()) + q.put_nowait('a') + test_utils.run_briefly(self.loop) + self.assertEqual(t2.result(), 'a') + + def test_get_with_waiting_putters(self): + q = asyncio.Queue(loop=self.loop, maxsize=1) + asyncio.Task(q.put('a'), loop=self.loop) + asyncio.Task(q.put('b'), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertEqual(self.loop.run_until_complete(q.get()), 'a') + self.assertEqual(self.loop.run_until_complete(q.get()), 'b') + + +class QueuePutTests(_QueueTestBase): + + def test_blocking_put(self): + q = asyncio.Queue(loop=self.loop) + + @asyncio.coroutine + def queue_put(): + # No maxsize, won't block. + yield from q.put(1) + + self.loop.run_until_complete(queue_put()) + + def test_blocking_put_wait(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.01, when) + yield 0.01 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + q = asyncio.Queue(maxsize=1, loop=loop) + started = asyncio.Event(loop=loop) + finished = False + + @asyncio.coroutine + def queue_put(): + nonlocal finished + started.set() + yield from q.put(1) + yield from q.put(2) + finished = True + + @asyncio.coroutine + def queue_get(): + loop.call_later(0.01, q.get_nowait) + queue_put_task = asyncio.Task(queue_put(), loop=loop) + yield from started.wait() + self.assertFalse(finished) + yield from queue_put_task + self.assertTrue(finished) + + loop.run_until_complete(queue_get()) + self.assertAlmostEqual(0.01, loop.time()) + + def test_nonblocking_put(self): + q = asyncio.Queue(loop=self.loop) + q.put_nowait(1) + self.assertEqual(1, q.get_nowait()) + + def test_nonblocking_put_exception(self): + q = asyncio.Queue(maxsize=1, loop=self.loop) + q.put_nowait(1) + self.assertRaises(asyncio.QueueFull, q.put_nowait, 2) + + def test_put_cancelled(self): + q = asyncio.Queue(loop=self.loop) + + @asyncio.coroutine + def queue_put(): + yield from q.put(1) + return True + + @asyncio.coroutine + def test(): + return (yield from q.get()) + + t = asyncio.Task(queue_put(), loop=self.loop) + self.assertEqual(1, self.loop.run_until_complete(test())) + self.assertTrue(t.done()) + self.assertTrue(t.result()) + + def test_put_cancelled_race(self): + q = asyncio.Queue(loop=self.loop, maxsize=1) + + asyncio.Task(q.put('a'), loop=self.loop) + asyncio.Task(q.put('c'), loop=self.loop) + t = asyncio.Task(q.put('b'), loop=self.loop) + + test_utils.run_briefly(self.loop) + t.cancel() + test_utils.run_briefly(self.loop) + self.assertTrue(t.done()) + self.assertEqual(q.get_nowait(), 'a') + self.assertEqual(q.get_nowait(), 'c') + + def test_put_with_waiting_getters(self): + q = asyncio.Queue(loop=self.loop) + t = asyncio.Task(q.get(), loop=self.loop) + test_utils.run_briefly(self.loop) + self.loop.run_until_complete(q.put('a')) + self.assertEqual(self.loop.run_until_complete(t), 'a') + + +class LifoQueueTests(_QueueTestBase): + + def test_order(self): + q = asyncio.LifoQueue(loop=self.loop) + for i in [1, 3, 2]: + q.put_nowait(i) + + items = [q.get_nowait() for _ in range(3)] + self.assertEqual([2, 3, 1], items) + + +class PriorityQueueTests(_QueueTestBase): + + def test_order(self): + q = asyncio.PriorityQueue(loop=self.loop) + for i in [1, 3, 2]: + q.put_nowait(i) + + items = [q.get_nowait() for _ in range(3)] + self.assertEqual([1, 2, 3], items) + + +class JoinableQueueTests(_QueueTestBase): + + def test_task_done_underflow(self): + q = asyncio.JoinableQueue(loop=self.loop) + self.assertRaises(ValueError, q.task_done) + + def test_task_done(self): + q = asyncio.JoinableQueue(loop=self.loop) + for i in range(100): + q.put_nowait(i) + + accumulator = 0 + + # Two workers get items from the queue and call task_done after each. + # Join the queue and assert all items have been processed. + running = True + + @asyncio.coroutine + def worker(): + nonlocal accumulator + + while running: + item = yield from q.get() + accumulator += item + q.task_done() + + @asyncio.coroutine + def test(): + for _ in range(2): + asyncio.Task(worker(), loop=self.loop) + + yield from q.join() + + self.loop.run_until_complete(test()) + self.assertEqual(sum(range(100)), accumulator) + + # close running generators + running = False + for i in range(2): + q.put_nowait(0) + + def test_join_empty_queue(self): + q = asyncio.JoinableQueue(loop=self.loop) + + # Test that a queue join()s successfully, and before anything else + # (done twice for insurance). + + @asyncio.coroutine + def join(): + yield from q.join() + yield from q.join() + + self.loop.run_until_complete(join()) + + def test_format(self): + q = asyncio.JoinableQueue(loop=self.loop) + self.assertEqual(q._format(), 'maxsize=0') + + q._unfinished_tasks = 2 + self.assertEqual(q._format(), 'maxsize=0 tasks=2') + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_asyncio/test_selector_events.py b/Lib/test/test_asyncio/test_selector_events.py new file mode 100644 index 0000000000..964b2e8ec8 --- /dev/null +++ b/Lib/test/test_asyncio/test_selector_events.py @@ -0,0 +1,1690 @@ +"""Tests for selector_events.py""" + +import errno +import gc +import pprint +import socket +import sys +import unittest +from unittest import mock +try: + import ssl +except ImportError: + ssl = None + +import asyncio +from asyncio import selectors +from asyncio import test_utils +from asyncio.selector_events import BaseSelectorEventLoop +from asyncio.selector_events import _SelectorTransport +from asyncio.selector_events import _SelectorSslTransport +from asyncio.selector_events import _SelectorSocketTransport +from asyncio.selector_events import _SelectorDatagramTransport + + +MOCK_ANY = mock.ANY + + +class TestBaseSelectorEventLoop(BaseSelectorEventLoop): + + def _make_self_pipe(self): + self._ssock = mock.Mock() + self._csock = mock.Mock() + self._internal_fds += 1 + + +def list_to_buffer(l=()): + return bytearray().join(l) + + +class BaseSelectorEventLoopTests(unittest.TestCase): + + def setUp(self): + selector = mock.Mock() + self.loop = TestBaseSelectorEventLoop(selector) + + def test_make_socket_transport(self): + m = mock.Mock() + self.loop.add_reader = mock.Mock() + transport = self.loop._make_socket_transport(m, asyncio.Protocol()) + self.assertIsInstance(transport, _SelectorSocketTransport) + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_make_ssl_transport(self): + m = mock.Mock() + self.loop.add_reader = mock.Mock() + self.loop.add_writer = mock.Mock() + self.loop.remove_reader = mock.Mock() + self.loop.remove_writer = mock.Mock() + waiter = asyncio.Future(loop=self.loop) + transport = self.loop._make_ssl_transport( + m, asyncio.Protocol(), m, waiter) + self.assertIsInstance(transport, _SelectorSslTransport) + + @mock.patch('asyncio.selector_events.ssl', None) + def test_make_ssl_transport_without_ssl_error(self): + m = mock.Mock() + self.loop.add_reader = mock.Mock() + self.loop.add_writer = mock.Mock() + self.loop.remove_reader = mock.Mock() + self.loop.remove_writer = mock.Mock() + with self.assertRaises(RuntimeError): + self.loop._make_ssl_transport(m, m, m, m) + + def test_close(self): + ssock = self.loop._ssock + ssock.fileno.return_value = 7 + csock = self.loop._csock + csock.fileno.return_value = 1 + remove_reader = self.loop.remove_reader = mock.Mock() + + self.loop._selector.close() + self.loop._selector = selector = mock.Mock() + self.loop.close() + self.assertIsNone(self.loop._selector) + self.assertIsNone(self.loop._csock) + self.assertIsNone(self.loop._ssock) + selector.close.assert_called_with() + ssock.close.assert_called_with() + csock.close.assert_called_with() + remove_reader.assert_called_with(7) + + self.loop.close() + self.loop.close() + + def test_close_no_selector(self): + ssock = self.loop._ssock + csock = self.loop._csock + remove_reader = self.loop.remove_reader = mock.Mock() + + self.loop._selector.close() + self.loop._selector = None + self.loop.close() + self.assertIsNone(self.loop._selector) + self.assertFalse(ssock.close.called) + self.assertFalse(csock.close.called) + self.assertFalse(remove_reader.called) + + def test_socketpair(self): + self.assertRaises(NotImplementedError, self.loop._socketpair) + + def test_read_from_self_tryagain(self): + self.loop._ssock.recv.side_effect = BlockingIOError + self.assertIsNone(self.loop._read_from_self()) + + def test_read_from_self_exception(self): + self.loop._ssock.recv.side_effect = OSError + self.assertRaises(OSError, self.loop._read_from_self) + + def test_write_to_self_tryagain(self): + self.loop._csock.send.side_effect = BlockingIOError + self.assertIsNone(self.loop._write_to_self()) + + def test_write_to_self_exception(self): + self.loop._csock.send.side_effect = OSError() + self.assertRaises(OSError, self.loop._write_to_self) + + def test_sock_recv(self): + sock = mock.Mock() + self.loop._sock_recv = mock.Mock() + + f = self.loop.sock_recv(sock, 1024) + self.assertIsInstance(f, asyncio.Future) + self.loop._sock_recv.assert_called_with(f, False, sock, 1024) + + def test__sock_recv_canceled_fut(self): + sock = mock.Mock() + + f = asyncio.Future(loop=self.loop) + f.cancel() + + self.loop._sock_recv(f, False, sock, 1024) + self.assertFalse(sock.recv.called) + + def test__sock_recv_unregister(self): + sock = mock.Mock() + sock.fileno.return_value = 10 + + f = asyncio.Future(loop=self.loop) + f.cancel() + + self.loop.remove_reader = mock.Mock() + self.loop._sock_recv(f, True, sock, 1024) + self.assertEqual((10,), self.loop.remove_reader.call_args[0]) + + def test__sock_recv_tryagain(self): + f = asyncio.Future(loop=self.loop) + sock = mock.Mock() + sock.fileno.return_value = 10 + sock.recv.side_effect = BlockingIOError + + self.loop.add_reader = mock.Mock() + self.loop._sock_recv(f, False, sock, 1024) + self.assertEqual((10, self.loop._sock_recv, f, True, sock, 1024), + self.loop.add_reader.call_args[0]) + + def test__sock_recv_exception(self): + f = asyncio.Future(loop=self.loop) + sock = mock.Mock() + sock.fileno.return_value = 10 + err = sock.recv.side_effect = OSError() + + self.loop._sock_recv(f, False, sock, 1024) + self.assertIs(err, f.exception()) + + def test_sock_sendall(self): + sock = mock.Mock() + self.loop._sock_sendall = mock.Mock() + + f = self.loop.sock_sendall(sock, b'data') + self.assertIsInstance(f, asyncio.Future) + self.assertEqual( + (f, False, sock, b'data'), + self.loop._sock_sendall.call_args[0]) + + def test_sock_sendall_nodata(self): + sock = mock.Mock() + self.loop._sock_sendall = mock.Mock() + + f = self.loop.sock_sendall(sock, b'') + self.assertIsInstance(f, asyncio.Future) + self.assertTrue(f.done()) + self.assertIsNone(f.result()) + self.assertFalse(self.loop._sock_sendall.called) + + def test__sock_sendall_canceled_fut(self): + sock = mock.Mock() + + f = asyncio.Future(loop=self.loop) + f.cancel() + + self.loop._sock_sendall(f, False, sock, b'data') + self.assertFalse(sock.send.called) + + def test__sock_sendall_unregister(self): + sock = mock.Mock() + sock.fileno.return_value = 10 + + f = asyncio.Future(loop=self.loop) + f.cancel() + + self.loop.remove_writer = mock.Mock() + self.loop._sock_sendall(f, True, sock, b'data') + self.assertEqual((10,), self.loop.remove_writer.call_args[0]) + + def test__sock_sendall_tryagain(self): + f = asyncio.Future(loop=self.loop) + sock = mock.Mock() + sock.fileno.return_value = 10 + sock.send.side_effect = BlockingIOError + + self.loop.add_writer = mock.Mock() + self.loop._sock_sendall(f, False, sock, b'data') + self.assertEqual( + (10, self.loop._sock_sendall, f, True, sock, b'data'), + self.loop.add_writer.call_args[0]) + + def test__sock_sendall_interrupted(self): + f = asyncio.Future(loop=self.loop) + sock = mock.Mock() + sock.fileno.return_value = 10 + sock.send.side_effect = InterruptedError + + self.loop.add_writer = mock.Mock() + self.loop._sock_sendall(f, False, sock, b'data') + self.assertEqual( + (10, self.loop._sock_sendall, f, True, sock, b'data'), + self.loop.add_writer.call_args[0]) + + def test__sock_sendall_exception(self): + f = asyncio.Future(loop=self.loop) + sock = mock.Mock() + sock.fileno.return_value = 10 + err = sock.send.side_effect = OSError() + + self.loop._sock_sendall(f, False, sock, b'data') + self.assertIs(f.exception(), err) + + def test__sock_sendall(self): + sock = mock.Mock() + + f = asyncio.Future(loop=self.loop) + sock.fileno.return_value = 10 + sock.send.return_value = 4 + + self.loop._sock_sendall(f, False, sock, b'data') + self.assertTrue(f.done()) + self.assertIsNone(f.result()) + + def test__sock_sendall_partial(self): + sock = mock.Mock() + + f = asyncio.Future(loop=self.loop) + sock.fileno.return_value = 10 + sock.send.return_value = 2 + + self.loop.add_writer = mock.Mock() + self.loop._sock_sendall(f, False, sock, b'data') + self.assertFalse(f.done()) + self.assertEqual( + (10, self.loop._sock_sendall, f, True, sock, b'ta'), + self.loop.add_writer.call_args[0]) + + def test__sock_sendall_none(self): + sock = mock.Mock() + + f = asyncio.Future(loop=self.loop) + sock.fileno.return_value = 10 + sock.send.return_value = 0 + + self.loop.add_writer = mock.Mock() + self.loop._sock_sendall(f, False, sock, b'data') + self.assertFalse(f.done()) + self.assertEqual( + (10, self.loop._sock_sendall, f, True, sock, b'data'), + self.loop.add_writer.call_args[0]) + + def test_sock_connect(self): + sock = mock.Mock() + self.loop._sock_connect = mock.Mock() + + f = self.loop.sock_connect(sock, ('127.0.0.1', 8080)) + self.assertIsInstance(f, asyncio.Future) + self.assertEqual( + (f, False, sock, ('127.0.0.1', 8080)), + self.loop._sock_connect.call_args[0]) + + def test__sock_connect(self): + f = asyncio.Future(loop=self.loop) + + sock = mock.Mock() + sock.fileno.return_value = 10 + + self.loop._sock_connect(f, False, sock, ('127.0.0.1', 8080)) + self.assertTrue(f.done()) + self.assertIsNone(f.result()) + self.assertTrue(sock.connect.called) + + def test__sock_connect_canceled_fut(self): + sock = mock.Mock() + + f = asyncio.Future(loop=self.loop) + f.cancel() + + self.loop._sock_connect(f, False, sock, ('127.0.0.1', 8080)) + self.assertFalse(sock.connect.called) + + def test__sock_connect_unregister(self): + sock = mock.Mock() + sock.fileno.return_value = 10 + + f = asyncio.Future(loop=self.loop) + f.cancel() + + self.loop.remove_writer = mock.Mock() + self.loop._sock_connect(f, True, sock, ('127.0.0.1', 8080)) + self.assertEqual((10,), self.loop.remove_writer.call_args[0]) + + def test__sock_connect_tryagain(self): + f = asyncio.Future(loop=self.loop) + sock = mock.Mock() + sock.fileno.return_value = 10 + sock.getsockopt.return_value = errno.EAGAIN + + self.loop.add_writer = mock.Mock() + self.loop.remove_writer = mock.Mock() + + self.loop._sock_connect(f, True, sock, ('127.0.0.1', 8080)) + self.assertEqual( + (10, self.loop._sock_connect, f, + True, sock, ('127.0.0.1', 8080)), + self.loop.add_writer.call_args[0]) + + def test__sock_connect_exception(self): + f = asyncio.Future(loop=self.loop) + sock = mock.Mock() + sock.fileno.return_value = 10 + sock.getsockopt.return_value = errno.ENOTCONN + + self.loop.remove_writer = mock.Mock() + self.loop._sock_connect(f, True, sock, ('127.0.0.1', 8080)) + self.assertIsInstance(f.exception(), OSError) + + def test_sock_accept(self): + sock = mock.Mock() + self.loop._sock_accept = mock.Mock() + + f = self.loop.sock_accept(sock) + self.assertIsInstance(f, asyncio.Future) + self.assertEqual( + (f, False, sock), self.loop._sock_accept.call_args[0]) + + def test__sock_accept(self): + f = asyncio.Future(loop=self.loop) + + conn = mock.Mock() + + sock = mock.Mock() + sock.fileno.return_value = 10 + sock.accept.return_value = conn, ('127.0.0.1', 1000) + + self.loop._sock_accept(f, False, sock) + self.assertTrue(f.done()) + self.assertEqual((conn, ('127.0.0.1', 1000)), f.result()) + self.assertEqual((False,), conn.setblocking.call_args[0]) + + def test__sock_accept_canceled_fut(self): + sock = mock.Mock() + + f = asyncio.Future(loop=self.loop) + f.cancel() + + self.loop._sock_accept(f, False, sock) + self.assertFalse(sock.accept.called) + + def test__sock_accept_unregister(self): + sock = mock.Mock() + sock.fileno.return_value = 10 + + f = asyncio.Future(loop=self.loop) + f.cancel() + + self.loop.remove_reader = mock.Mock() + self.loop._sock_accept(f, True, sock) + self.assertEqual((10,), self.loop.remove_reader.call_args[0]) + + def test__sock_accept_tryagain(self): + f = asyncio.Future(loop=self.loop) + sock = mock.Mock() + sock.fileno.return_value = 10 + sock.accept.side_effect = BlockingIOError + + self.loop.add_reader = mock.Mock() + self.loop._sock_accept(f, False, sock) + self.assertEqual( + (10, self.loop._sock_accept, f, True, sock), + self.loop.add_reader.call_args[0]) + + def test__sock_accept_exception(self): + f = asyncio.Future(loop=self.loop) + sock = mock.Mock() + sock.fileno.return_value = 10 + err = sock.accept.side_effect = OSError() + + self.loop._sock_accept(f, False, sock) + self.assertIs(err, f.exception()) + + def test_add_reader(self): + self.loop._selector.get_key.side_effect = KeyError + cb = lambda: True + self.loop.add_reader(1, cb) + + self.assertTrue(self.loop._selector.register.called) + fd, mask, (r, w) = self.loop._selector.register.call_args[0] + self.assertEqual(1, fd) + self.assertEqual(selectors.EVENT_READ, mask) + self.assertEqual(cb, r._callback) + self.assertIsNone(w) + + def test_add_reader_existing(self): + reader = mock.Mock() + writer = mock.Mock() + self.loop._selector.get_key.return_value = selectors.SelectorKey( + 1, 1, selectors.EVENT_WRITE, (reader, writer)) + cb = lambda: True + self.loop.add_reader(1, cb) + + self.assertTrue(reader.cancel.called) + self.assertFalse(self.loop._selector.register.called) + self.assertTrue(self.loop._selector.modify.called) + fd, mask, (r, w) = self.loop._selector.modify.call_args[0] + self.assertEqual(1, fd) + self.assertEqual(selectors.EVENT_WRITE | selectors.EVENT_READ, mask) + self.assertEqual(cb, r._callback) + self.assertEqual(writer, w) + + def test_add_reader_existing_writer(self): + writer = mock.Mock() + self.loop._selector.get_key.return_value = selectors.SelectorKey( + 1, 1, selectors.EVENT_WRITE, (None, writer)) + cb = lambda: True + self.loop.add_reader(1, cb) + + self.assertFalse(self.loop._selector.register.called) + self.assertTrue(self.loop._selector.modify.called) + fd, mask, (r, w) = self.loop._selector.modify.call_args[0] + self.assertEqual(1, fd) + self.assertEqual(selectors.EVENT_WRITE | selectors.EVENT_READ, mask) + self.assertEqual(cb, r._callback) + self.assertEqual(writer, w) + + def test_remove_reader(self): + self.loop._selector.get_key.return_value = selectors.SelectorKey( + 1, 1, selectors.EVENT_READ, (None, None)) + self.assertFalse(self.loop.remove_reader(1)) + + self.assertTrue(self.loop._selector.unregister.called) + + def test_remove_reader_read_write(self): + reader = mock.Mock() + writer = mock.Mock() + self.loop._selector.get_key.return_value = selectors.SelectorKey( + 1, 1, selectors.EVENT_READ | selectors.EVENT_WRITE, + (reader, writer)) + self.assertTrue( + self.loop.remove_reader(1)) + + self.assertFalse(self.loop._selector.unregister.called) + self.assertEqual( + (1, selectors.EVENT_WRITE, (None, writer)), + self.loop._selector.modify.call_args[0]) + + def test_remove_reader_unknown(self): + self.loop._selector.get_key.side_effect = KeyError + self.assertFalse( + self.loop.remove_reader(1)) + + def test_add_writer(self): + self.loop._selector.get_key.side_effect = KeyError + cb = lambda: True + self.loop.add_writer(1, cb) + + self.assertTrue(self.loop._selector.register.called) + fd, mask, (r, w) = self.loop._selector.register.call_args[0] + self.assertEqual(1, fd) + self.assertEqual(selectors.EVENT_WRITE, mask) + self.assertIsNone(r) + self.assertEqual(cb, w._callback) + + def test_add_writer_existing(self): + reader = mock.Mock() + writer = mock.Mock() + self.loop._selector.get_key.return_value = selectors.SelectorKey( + 1, 1, selectors.EVENT_READ, (reader, writer)) + cb = lambda: True + self.loop.add_writer(1, cb) + + self.assertTrue(writer.cancel.called) + self.assertFalse(self.loop._selector.register.called) + self.assertTrue(self.loop._selector.modify.called) + fd, mask, (r, w) = self.loop._selector.modify.call_args[0] + self.assertEqual(1, fd) + self.assertEqual(selectors.EVENT_WRITE | selectors.EVENT_READ, mask) + self.assertEqual(reader, r) + self.assertEqual(cb, w._callback) + + def test_remove_writer(self): + self.loop._selector.get_key.return_value = selectors.SelectorKey( + 1, 1, selectors.EVENT_WRITE, (None, None)) + self.assertFalse(self.loop.remove_writer(1)) + + self.assertTrue(self.loop._selector.unregister.called) + + def test_remove_writer_read_write(self): + reader = mock.Mock() + writer = mock.Mock() + self.loop._selector.get_key.return_value = selectors.SelectorKey( + 1, 1, selectors.EVENT_READ | selectors.EVENT_WRITE, + (reader, writer)) + self.assertTrue( + self.loop.remove_writer(1)) + + self.assertFalse(self.loop._selector.unregister.called) + self.assertEqual( + (1, selectors.EVENT_READ, (reader, None)), + self.loop._selector.modify.call_args[0]) + + def test_remove_writer_unknown(self): + self.loop._selector.get_key.side_effect = KeyError + self.assertFalse( + self.loop.remove_writer(1)) + + def test_process_events_read(self): + reader = mock.Mock() + reader._cancelled = False + + self.loop._add_callback = mock.Mock() + self.loop._process_events( + [(selectors.SelectorKey( + 1, 1, selectors.EVENT_READ, (reader, None)), + selectors.EVENT_READ)]) + self.assertTrue(self.loop._add_callback.called) + self.loop._add_callback.assert_called_with(reader) + + def test_process_events_read_cancelled(self): + reader = mock.Mock() + reader.cancelled = True + + self.loop.remove_reader = mock.Mock() + self.loop._process_events( + [(selectors.SelectorKey( + 1, 1, selectors.EVENT_READ, (reader, None)), + selectors.EVENT_READ)]) + self.loop.remove_reader.assert_called_with(1) + + def test_process_events_write(self): + writer = mock.Mock() + writer._cancelled = False + + self.loop._add_callback = mock.Mock() + self.loop._process_events( + [(selectors.SelectorKey(1, 1, selectors.EVENT_WRITE, + (None, writer)), + selectors.EVENT_WRITE)]) + self.loop._add_callback.assert_called_with(writer) + + def test_process_events_write_cancelled(self): + writer = mock.Mock() + writer.cancelled = True + self.loop.remove_writer = mock.Mock() + + self.loop._process_events( + [(selectors.SelectorKey(1, 1, selectors.EVENT_WRITE, + (None, writer)), + selectors.EVENT_WRITE)]) + self.loop.remove_writer.assert_called_with(1) + + +class SelectorTransportTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + self.protocol = test_utils.make_test_protocol(asyncio.Protocol) + self.sock = mock.Mock(socket.socket) + self.sock.fileno.return_value = 7 + + def test_ctor(self): + tr = _SelectorTransport(self.loop, self.sock, self.protocol, None) + self.assertIs(tr._loop, self.loop) + self.assertIs(tr._sock, self.sock) + self.assertIs(tr._sock_fd, 7) + + def test_abort(self): + tr = _SelectorTransport(self.loop, self.sock, self.protocol, None) + tr._force_close = mock.Mock() + + tr.abort() + tr._force_close.assert_called_with(None) + + def test_close(self): + tr = _SelectorTransport(self.loop, self.sock, self.protocol, None) + tr.close() + + self.assertTrue(tr._closing) + self.assertEqual(1, self.loop.remove_reader_count[7]) + self.protocol.connection_lost(None) + self.assertEqual(tr._conn_lost, 1) + + tr.close() + self.assertEqual(tr._conn_lost, 1) + self.assertEqual(1, self.loop.remove_reader_count[7]) + + def test_close_write_buffer(self): + tr = _SelectorTransport(self.loop, self.sock, self.protocol, None) + tr._buffer.extend(b'data') + tr.close() + + self.assertFalse(self.loop.readers) + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) + + def test_force_close(self): + tr = _SelectorTransport(self.loop, self.sock, self.protocol, None) + tr._buffer.extend(b'1') + self.loop.add_reader(7, mock.sentinel) + self.loop.add_writer(7, mock.sentinel) + tr._force_close(None) + + self.assertTrue(tr._closing) + self.assertEqual(tr._buffer, list_to_buffer()) + self.assertFalse(self.loop.readers) + self.assertFalse(self.loop.writers) + + # second close should not remove reader + tr._force_close(None) + self.assertFalse(self.loop.readers) + self.assertEqual(1, self.loop.remove_reader_count[7]) + + @mock.patch('asyncio.log.logger.error') + def test_fatal_error(self, m_exc): + exc = OSError() + tr = _SelectorTransport(self.loop, self.sock, self.protocol, None) + tr._force_close = mock.Mock() + tr._fatal_error(exc) + + m_exc.assert_called_with( + test_utils.MockPattern( + 'Fatal error on transport\nprotocol:.*\ntransport:.*'), + exc_info=(OSError, MOCK_ANY, MOCK_ANY)) + + tr._force_close.assert_called_with(exc) + + def test_connection_lost(self): + exc = OSError() + tr = _SelectorTransport(self.loop, self.sock, self.protocol, None) + tr._call_connection_lost(exc) + + self.protocol.connection_lost.assert_called_with(exc) + self.sock.close.assert_called_with() + self.assertIsNone(tr._sock) + + self.assertIsNone(tr._protocol) + self.assertEqual(2, sys.getrefcount(self.protocol), + pprint.pformat(gc.get_referrers(self.protocol))) + self.assertIsNone(tr._loop) + self.assertEqual(2, sys.getrefcount(self.loop), + pprint.pformat(gc.get_referrers(self.loop))) + + +class SelectorSocketTransportTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + self.protocol = test_utils.make_test_protocol(asyncio.Protocol) + self.sock = mock.Mock(socket.socket) + self.sock_fd = self.sock.fileno.return_value = 7 + + def test_ctor(self): + tr = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + self.loop.assert_reader(7, tr._read_ready) + test_utils.run_briefly(self.loop) + self.protocol.connection_made.assert_called_with(tr) + + def test_ctor_with_waiter(self): + fut = asyncio.Future(loop=self.loop) + + _SelectorSocketTransport( + self.loop, self.sock, self.protocol, fut) + test_utils.run_briefly(self.loop) + self.assertIsNone(fut.result()) + + def test_pause_resume_reading(self): + tr = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + self.assertFalse(tr._paused) + self.loop.assert_reader(7, tr._read_ready) + tr.pause_reading() + self.assertTrue(tr._paused) + self.assertFalse(7 in self.loop.readers) + tr.resume_reading() + self.assertFalse(tr._paused) + self.loop.assert_reader(7, tr._read_ready) + + def test_read_ready(self): + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + + self.sock.recv.return_value = b'data' + transport._read_ready() + + self.protocol.data_received.assert_called_with(b'data') + + def test_read_ready_eof(self): + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport.close = mock.Mock() + + self.sock.recv.return_value = b'' + transport._read_ready() + + self.protocol.eof_received.assert_called_with() + transport.close.assert_called_with() + + def test_read_ready_eof_keep_open(self): + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport.close = mock.Mock() + + self.sock.recv.return_value = b'' + self.protocol.eof_received.return_value = True + transport._read_ready() + + self.protocol.eof_received.assert_called_with() + self.assertFalse(transport.close.called) + + @mock.patch('logging.exception') + def test_read_ready_tryagain(self, m_exc): + self.sock.recv.side_effect = BlockingIOError + + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport._fatal_error = mock.Mock() + transport._read_ready() + + self.assertFalse(transport._fatal_error.called) + + @mock.patch('logging.exception') + def test_read_ready_tryagain_interrupted(self, m_exc): + self.sock.recv.side_effect = InterruptedError + + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport._fatal_error = mock.Mock() + transport._read_ready() + + self.assertFalse(transport._fatal_error.called) + + @mock.patch('logging.exception') + def test_read_ready_conn_reset(self, m_exc): + err = self.sock.recv.side_effect = ConnectionResetError() + + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport._force_close = mock.Mock() + transport._read_ready() + transport._force_close.assert_called_with(err) + + @mock.patch('logging.exception') + def test_read_ready_err(self, m_exc): + err = self.sock.recv.side_effect = OSError() + + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport._fatal_error = mock.Mock() + transport._read_ready() + + transport._fatal_error.assert_called_with( + err, + 'Fatal read error on socket transport') + + def test_write(self): + data = b'data' + self.sock.send.return_value = len(data) + + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport.write(data) + self.sock.send.assert_called_with(data) + + def test_write_bytearray(self): + data = bytearray(b'data') + self.sock.send.return_value = len(data) + + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport.write(data) + self.sock.send.assert_called_with(data) + self.assertEqual(data, bytearray(b'data')) # Hasn't been mutated. + + def test_write_memoryview(self): + data = memoryview(b'data') + self.sock.send.return_value = len(data) + + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport.write(data) + self.sock.send.assert_called_with(data) + + def test_write_no_data(self): + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport._buffer.extend(b'data') + transport.write(b'') + self.assertFalse(self.sock.send.called) + self.assertEqual(list_to_buffer([b'data']), transport._buffer) + + def test_write_buffer(self): + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport._buffer.extend(b'data1') + transport.write(b'data2') + self.assertFalse(self.sock.send.called) + self.assertEqual(list_to_buffer([b'data1', b'data2']), + transport._buffer) + + def test_write_partial(self): + data = b'data' + self.sock.send.return_value = 2 + + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport.write(data) + + self.loop.assert_writer(7, transport._write_ready) + self.assertEqual(list_to_buffer([b'ta']), transport._buffer) + + def test_write_partial_bytearray(self): + data = bytearray(b'data') + self.sock.send.return_value = 2 + + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport.write(data) + + self.loop.assert_writer(7, transport._write_ready) + self.assertEqual(list_to_buffer([b'ta']), transport._buffer) + self.assertEqual(data, bytearray(b'data')) # Hasn't been mutated. + + def test_write_partial_memoryview(self): + data = memoryview(b'data') + self.sock.send.return_value = 2 + + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport.write(data) + + self.loop.assert_writer(7, transport._write_ready) + self.assertEqual(list_to_buffer([b'ta']), transport._buffer) + + def test_write_partial_none(self): + data = b'data' + self.sock.send.return_value = 0 + self.sock.fileno.return_value = 7 + + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport.write(data) + + self.loop.assert_writer(7, transport._write_ready) + self.assertEqual(list_to_buffer([b'data']), transport._buffer) + + def test_write_tryagain(self): + self.sock.send.side_effect = BlockingIOError + + data = b'data' + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport.write(data) + + self.loop.assert_writer(7, transport._write_ready) + self.assertEqual(list_to_buffer([b'data']), transport._buffer) + + @mock.patch('asyncio.selector_events.logger') + def test_write_exception(self, m_log): + err = self.sock.send.side_effect = OSError() + + data = b'data' + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport._fatal_error = mock.Mock() + transport.write(data) + transport._fatal_error.assert_called_with( + err, + 'Fatal write error on socket transport') + transport._conn_lost = 1 + + self.sock.reset_mock() + transport.write(data) + self.assertFalse(self.sock.send.called) + self.assertEqual(transport._conn_lost, 2) + transport.write(data) + transport.write(data) + transport.write(data) + transport.write(data) + m_log.warning.assert_called_with('socket.send() raised exception.') + + def test_write_str(self): + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + self.assertRaises(TypeError, transport.write, 'str') + + def test_write_closing(self): + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport.close() + self.assertEqual(transport._conn_lost, 1) + transport.write(b'data') + self.assertEqual(transport._conn_lost, 2) + + def test_write_ready(self): + data = b'data' + self.sock.send.return_value = len(data) + + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport._buffer.extend(data) + self.loop.add_writer(7, transport._write_ready) + transport._write_ready() + self.assertTrue(self.sock.send.called) + self.assertFalse(self.loop.writers) + + def test_write_ready_closing(self): + data = b'data' + self.sock.send.return_value = len(data) + + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport._closing = True + transport._buffer.extend(data) + self.loop.add_writer(7, transport._write_ready) + transport._write_ready() + self.assertTrue(self.sock.send.called) + self.assertFalse(self.loop.writers) + self.sock.close.assert_called_with() + self.protocol.connection_lost.assert_called_with(None) + + def test_write_ready_no_data(self): + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + # This is an internal error. + self.assertRaises(AssertionError, transport._write_ready) + + def test_write_ready_partial(self): + data = b'data' + self.sock.send.return_value = 2 + + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport._buffer.extend(data) + self.loop.add_writer(7, transport._write_ready) + transport._write_ready() + self.loop.assert_writer(7, transport._write_ready) + self.assertEqual(list_to_buffer([b'ta']), transport._buffer) + + def test_write_ready_partial_none(self): + data = b'data' + self.sock.send.return_value = 0 + + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport._buffer.extend(data) + self.loop.add_writer(7, transport._write_ready) + transport._write_ready() + self.loop.assert_writer(7, transport._write_ready) + self.assertEqual(list_to_buffer([b'data']), transport._buffer) + + def test_write_ready_tryagain(self): + self.sock.send.side_effect = BlockingIOError + + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport._buffer = list_to_buffer([b'data1', b'data2']) + self.loop.add_writer(7, transport._write_ready) + transport._write_ready() + + self.loop.assert_writer(7, transport._write_ready) + self.assertEqual(list_to_buffer([b'data1data2']), transport._buffer) + + def test_write_ready_exception(self): + err = self.sock.send.side_effect = OSError() + + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport._fatal_error = mock.Mock() + transport._buffer.extend(b'data') + transport._write_ready() + transport._fatal_error.assert_called_with( + err, + 'Fatal write error on socket transport') + + @mock.patch('asyncio.base_events.logger') + def test_write_ready_exception_and_close(self, m_log): + self.sock.send.side_effect = OSError() + remove_writer = self.loop.remove_writer = mock.Mock() + + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport.close() + transport._buffer.extend(b'data') + transport._write_ready() + remove_writer.assert_called_with(self.sock_fd) + + def test_write_eof(self): + tr = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + self.assertTrue(tr.can_write_eof()) + tr.write_eof() + self.sock.shutdown.assert_called_with(socket.SHUT_WR) + tr.write_eof() + self.assertEqual(self.sock.shutdown.call_count, 1) + tr.close() + + def test_write_eof_buffer(self): + tr = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + self.sock.send.side_effect = BlockingIOError + tr.write(b'data') + tr.write_eof() + self.assertEqual(tr._buffer, list_to_buffer([b'data'])) + self.assertTrue(tr._eof) + self.assertFalse(self.sock.shutdown.called) + self.sock.send.side_effect = lambda _: 4 + tr._write_ready() + self.assertTrue(self.sock.send.called) + self.sock.shutdown.assert_called_with(socket.SHUT_WR) + tr.close() + + +@unittest.skipIf(ssl is None, 'No ssl module') +class SelectorSslTransportTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + self.protocol = test_utils.make_test_protocol(asyncio.Protocol) + self.sock = mock.Mock(socket.socket) + self.sock.fileno.return_value = 7 + self.sslsock = mock.Mock() + self.sslsock.fileno.return_value = 1 + self.sslcontext = mock.Mock() + self.sslcontext.wrap_socket.return_value = self.sslsock + + def _make_one(self, create_waiter=None): + transport = _SelectorSslTransport( + self.loop, self.sock, self.protocol, self.sslcontext) + self.sock.reset_mock() + self.sslsock.reset_mock() + self.sslcontext.reset_mock() + self.loop.reset_counters() + return transport + + def test_on_handshake(self): + waiter = asyncio.Future(loop=self.loop) + tr = _SelectorSslTransport( + self.loop, self.sock, self.protocol, self.sslcontext, + waiter=waiter) + self.assertTrue(self.sslsock.do_handshake.called) + self.loop.assert_reader(1, tr._read_ready) + test_utils.run_briefly(self.loop) + self.assertIsNone(waiter.result()) + + def test_on_handshake_reader_retry(self): + self.sslsock.do_handshake.side_effect = ssl.SSLWantReadError + transport = _SelectorSslTransport( + self.loop, self.sock, self.protocol, self.sslcontext) + transport._on_handshake() + self.loop.assert_reader(1, transport._on_handshake) + + def test_on_handshake_writer_retry(self): + self.sslsock.do_handshake.side_effect = ssl.SSLWantWriteError + transport = _SelectorSslTransport( + self.loop, self.sock, self.protocol, self.sslcontext) + transport._on_handshake() + self.loop.assert_writer(1, transport._on_handshake) + + def test_on_handshake_exc(self): + exc = ValueError() + self.sslsock.do_handshake.side_effect = exc + transport = _SelectorSslTransport( + self.loop, self.sock, self.protocol, self.sslcontext) + transport._waiter = asyncio.Future(loop=self.loop) + transport._on_handshake() + self.assertTrue(self.sslsock.close.called) + self.assertTrue(transport._waiter.done()) + self.assertIs(exc, transport._waiter.exception()) + + def test_on_handshake_base_exc(self): + transport = _SelectorSslTransport( + self.loop, self.sock, self.protocol, self.sslcontext) + transport._waiter = asyncio.Future(loop=self.loop) + exc = BaseException() + self.sslsock.do_handshake.side_effect = exc + self.assertRaises(BaseException, transport._on_handshake) + self.assertTrue(self.sslsock.close.called) + self.assertTrue(transport._waiter.done()) + self.assertIs(exc, transport._waiter.exception()) + + def test_pause_resume_reading(self): + tr = self._make_one() + self.assertFalse(tr._paused) + self.loop.assert_reader(1, tr._read_ready) + tr.pause_reading() + self.assertTrue(tr._paused) + self.assertFalse(1 in self.loop.readers) + tr.resume_reading() + self.assertFalse(tr._paused) + self.loop.assert_reader(1, tr._read_ready) + + def test_write(self): + transport = self._make_one() + transport.write(b'data') + self.assertEqual(list_to_buffer([b'data']), transport._buffer) + + def test_write_bytearray(self): + transport = self._make_one() + data = bytearray(b'data') + transport.write(data) + self.assertEqual(list_to_buffer([b'data']), transport._buffer) + self.assertEqual(data, bytearray(b'data')) # Hasn't been mutated. + self.assertIsNot(data, transport._buffer) # Hasn't been incorporated. + + def test_write_memoryview(self): + transport = self._make_one() + data = memoryview(b'data') + transport.write(data) + self.assertEqual(list_to_buffer([b'data']), transport._buffer) + + def test_write_no_data(self): + transport = self._make_one() + transport._buffer.extend(b'data') + transport.write(b'') + self.assertEqual(list_to_buffer([b'data']), transport._buffer) + + def test_write_str(self): + transport = self._make_one() + self.assertRaises(TypeError, transport.write, 'str') + + def test_write_closing(self): + transport = self._make_one() + transport.close() + self.assertEqual(transport._conn_lost, 1) + transport.write(b'data') + self.assertEqual(transport._conn_lost, 2) + + @mock.patch('asyncio.selector_events.logger') + def test_write_exception(self, m_log): + transport = self._make_one() + transport._conn_lost = 1 + transport.write(b'data') + self.assertEqual(transport._buffer, list_to_buffer()) + transport.write(b'data') + transport.write(b'data') + transport.write(b'data') + transport.write(b'data') + m_log.warning.assert_called_with('socket.send() raised exception.') + + def test_read_ready_recv(self): + self.sslsock.recv.return_value = b'data' + transport = self._make_one() + transport._read_ready() + self.assertTrue(self.sslsock.recv.called) + self.assertEqual((b'data',), self.protocol.data_received.call_args[0]) + + def test_read_ready_write_wants_read(self): + self.loop.add_writer = mock.Mock() + self.sslsock.recv.side_effect = BlockingIOError + transport = self._make_one() + transport._write_wants_read = True + transport._write_ready = mock.Mock() + transport._buffer.extend(b'data') + transport._read_ready() + + self.assertFalse(transport._write_wants_read) + transport._write_ready.assert_called_with() + self.loop.add_writer.assert_called_with( + transport._sock_fd, transport._write_ready) + + def test_read_ready_recv_eof(self): + self.sslsock.recv.return_value = b'' + transport = self._make_one() + transport.close = mock.Mock() + transport._read_ready() + transport.close.assert_called_with() + self.protocol.eof_received.assert_called_with() + + def test_read_ready_recv_conn_reset(self): + err = self.sslsock.recv.side_effect = ConnectionResetError() + transport = self._make_one() + transport._force_close = mock.Mock() + transport._read_ready() + transport._force_close.assert_called_with(err) + + def test_read_ready_recv_retry(self): + self.sslsock.recv.side_effect = ssl.SSLWantReadError + transport = self._make_one() + transport._read_ready() + self.assertTrue(self.sslsock.recv.called) + self.assertFalse(self.protocol.data_received.called) + + self.sslsock.recv.side_effect = BlockingIOError + transport._read_ready() + self.assertFalse(self.protocol.data_received.called) + + self.sslsock.recv.side_effect = InterruptedError + transport._read_ready() + self.assertFalse(self.protocol.data_received.called) + + def test_read_ready_recv_write(self): + self.loop.remove_reader = mock.Mock() + self.loop.add_writer = mock.Mock() + self.sslsock.recv.side_effect = ssl.SSLWantWriteError + transport = self._make_one() + transport._read_ready() + self.assertFalse(self.protocol.data_received.called) + self.assertTrue(transport._read_wants_write) + + self.loop.remove_reader.assert_called_with(transport._sock_fd) + self.loop.add_writer.assert_called_with( + transport._sock_fd, transport._write_ready) + + def test_read_ready_recv_exc(self): + err = self.sslsock.recv.side_effect = OSError() + transport = self._make_one() + transport._fatal_error = mock.Mock() + transport._read_ready() + transport._fatal_error.assert_called_with( + err, + 'Fatal read error on SSL transport') + + def test_write_ready_send(self): + self.sslsock.send.return_value = 4 + transport = self._make_one() + transport._buffer = list_to_buffer([b'data']) + transport._write_ready() + self.assertEqual(list_to_buffer(), transport._buffer) + self.assertTrue(self.sslsock.send.called) + + def test_write_ready_send_none(self): + self.sslsock.send.return_value = 0 + transport = self._make_one() + transport._buffer = list_to_buffer([b'data1', b'data2']) + transport._write_ready() + self.assertTrue(self.sslsock.send.called) + self.assertEqual(list_to_buffer([b'data1data2']), transport._buffer) + + def test_write_ready_send_partial(self): + self.sslsock.send.return_value = 2 + transport = self._make_one() + transport._buffer = list_to_buffer([b'data1', b'data2']) + transport._write_ready() + self.assertTrue(self.sslsock.send.called) + self.assertEqual(list_to_buffer([b'ta1data2']), transport._buffer) + + def test_write_ready_send_closing_partial(self): + self.sslsock.send.return_value = 2 + transport = self._make_one() + transport._buffer = list_to_buffer([b'data1', b'data2']) + transport._write_ready() + self.assertTrue(self.sslsock.send.called) + self.assertFalse(self.sslsock.close.called) + + def test_write_ready_send_closing(self): + self.sslsock.send.return_value = 4 + transport = self._make_one() + transport.close() + transport._buffer = list_to_buffer([b'data']) + transport._write_ready() + self.assertFalse(self.loop.writers) + self.protocol.connection_lost.assert_called_with(None) + + def test_write_ready_send_closing_empty_buffer(self): + self.sslsock.send.return_value = 4 + transport = self._make_one() + transport.close() + transport._buffer = list_to_buffer() + transport._write_ready() + self.assertFalse(self.loop.writers) + self.protocol.connection_lost.assert_called_with(None) + + def test_write_ready_send_retry(self): + transport = self._make_one() + transport._buffer = list_to_buffer([b'data']) + + self.sslsock.send.side_effect = ssl.SSLWantWriteError + transport._write_ready() + self.assertEqual(list_to_buffer([b'data']), transport._buffer) + + self.sslsock.send.side_effect = BlockingIOError() + transport._write_ready() + self.assertEqual(list_to_buffer([b'data']), transport._buffer) + + def test_write_ready_send_read(self): + transport = self._make_one() + transport._buffer = list_to_buffer([b'data']) + + self.loop.remove_writer = mock.Mock() + self.sslsock.send.side_effect = ssl.SSLWantReadError + transport._write_ready() + self.assertFalse(self.protocol.data_received.called) + self.assertTrue(transport._write_wants_read) + self.loop.remove_writer.assert_called_with(transport._sock_fd) + + def test_write_ready_send_exc(self): + err = self.sslsock.send.side_effect = OSError() + + transport = self._make_one() + transport._buffer = list_to_buffer([b'data']) + transport._fatal_error = mock.Mock() + transport._write_ready() + transport._fatal_error.assert_called_with( + err, + 'Fatal write error on SSL transport') + self.assertEqual(list_to_buffer(), transport._buffer) + + def test_write_ready_read_wants_write(self): + self.loop.add_reader = mock.Mock() + self.sslsock.send.side_effect = BlockingIOError + transport = self._make_one() + transport._read_wants_write = True + transport._read_ready = mock.Mock() + transport._write_ready() + + self.assertFalse(transport._read_wants_write) + transport._read_ready.assert_called_with() + self.loop.add_reader.assert_called_with( + transport._sock_fd, transport._read_ready) + + def test_write_eof(self): + tr = self._make_one() + self.assertFalse(tr.can_write_eof()) + self.assertRaises(NotImplementedError, tr.write_eof) + + def test_close(self): + tr = self._make_one() + tr.close() + + self.assertTrue(tr._closing) + self.assertEqual(1, self.loop.remove_reader_count[1]) + self.assertEqual(tr._conn_lost, 1) + + tr.close() + self.assertEqual(tr._conn_lost, 1) + self.assertEqual(1, self.loop.remove_reader_count[1]) + + @unittest.skipIf(ssl is None or not ssl.HAS_SNI, 'No SNI support') + def test_server_hostname(self): + _SelectorSslTransport( + self.loop, self.sock, self.protocol, self.sslcontext, + server_hostname='localhost') + self.sslcontext.wrap_socket.assert_called_with( + self.sock, do_handshake_on_connect=False, server_side=False, + server_hostname='localhost') + + +class SelectorSslWithoutSslTransportTests(unittest.TestCase): + + @mock.patch('asyncio.selector_events.ssl', None) + def test_ssl_transport_requires_ssl_module(self): + Mock = mock.Mock + with self.assertRaises(RuntimeError): + _SelectorSslTransport(Mock(), Mock(), Mock(), Mock()) + + +class SelectorDatagramTransportTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + self.protocol = test_utils.make_test_protocol(asyncio.DatagramProtocol) + self.sock = mock.Mock(spec_set=socket.socket) + self.sock.fileno.return_value = 7 + + def test_read_ready(self): + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + + self.sock.recvfrom.return_value = (b'data', ('0.0.0.0', 1234)) + transport._read_ready() + + self.protocol.datagram_received.assert_called_with( + b'data', ('0.0.0.0', 1234)) + + def test_read_ready_tryagain(self): + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + + self.sock.recvfrom.side_effect = BlockingIOError + transport._fatal_error = mock.Mock() + transport._read_ready() + + self.assertFalse(transport._fatal_error.called) + + def test_read_ready_err(self): + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + + err = self.sock.recvfrom.side_effect = RuntimeError() + transport._fatal_error = mock.Mock() + transport._read_ready() + + transport._fatal_error.assert_called_with( + err, + 'Fatal read error on datagram transport') + + def test_read_ready_oserr(self): + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + + err = self.sock.recvfrom.side_effect = OSError() + transport._fatal_error = mock.Mock() + transport._read_ready() + + self.assertFalse(transport._fatal_error.called) + self.protocol.error_received.assert_called_with(err) + + def test_sendto(self): + data = b'data' + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + transport.sendto(data, ('0.0.0.0', 1234)) + self.assertTrue(self.sock.sendto.called) + self.assertEqual( + self.sock.sendto.call_args[0], (data, ('0.0.0.0', 1234))) + + def test_sendto_bytearray(self): + data = bytearray(b'data') + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + transport.sendto(data, ('0.0.0.0', 1234)) + self.assertTrue(self.sock.sendto.called) + self.assertEqual( + self.sock.sendto.call_args[0], (data, ('0.0.0.0', 1234))) + + def test_sendto_memoryview(self): + data = memoryview(b'data') + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + transport.sendto(data, ('0.0.0.0', 1234)) + self.assertTrue(self.sock.sendto.called) + self.assertEqual( + self.sock.sendto.call_args[0], (data, ('0.0.0.0', 1234))) + + def test_sendto_no_data(self): + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + transport._buffer.append((b'data', ('0.0.0.0', 12345))) + transport.sendto(b'', ()) + self.assertFalse(self.sock.sendto.called) + self.assertEqual( + [(b'data', ('0.0.0.0', 12345))], list(transport._buffer)) + + def test_sendto_buffer(self): + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + transport._buffer.append((b'data1', ('0.0.0.0', 12345))) + transport.sendto(b'data2', ('0.0.0.0', 12345)) + self.assertFalse(self.sock.sendto.called) + self.assertEqual( + [(b'data1', ('0.0.0.0', 12345)), + (b'data2', ('0.0.0.0', 12345))], + list(transport._buffer)) + + def test_sendto_buffer_bytearray(self): + data2 = bytearray(b'data2') + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + transport._buffer.append((b'data1', ('0.0.0.0', 12345))) + transport.sendto(data2, ('0.0.0.0', 12345)) + self.assertFalse(self.sock.sendto.called) + self.assertEqual( + [(b'data1', ('0.0.0.0', 12345)), + (b'data2', ('0.0.0.0', 12345))], + list(transport._buffer)) + self.assertIsInstance(transport._buffer[1][0], bytes) + + def test_sendto_buffer_memoryview(self): + data2 = memoryview(b'data2') + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + transport._buffer.append((b'data1', ('0.0.0.0', 12345))) + transport.sendto(data2, ('0.0.0.0', 12345)) + self.assertFalse(self.sock.sendto.called) + self.assertEqual( + [(b'data1', ('0.0.0.0', 12345)), + (b'data2', ('0.0.0.0', 12345))], + list(transport._buffer)) + self.assertIsInstance(transport._buffer[1][0], bytes) + + def test_sendto_tryagain(self): + data = b'data' + + self.sock.sendto.side_effect = BlockingIOError + + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + transport.sendto(data, ('0.0.0.0', 12345)) + + self.loop.assert_writer(7, transport._sendto_ready) + self.assertEqual( + [(b'data', ('0.0.0.0', 12345))], list(transport._buffer)) + + @mock.patch('asyncio.selector_events.logger') + def test_sendto_exception(self, m_log): + data = b'data' + err = self.sock.sendto.side_effect = RuntimeError() + + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + transport._fatal_error = mock.Mock() + transport.sendto(data, ()) + + self.assertTrue(transport._fatal_error.called) + transport._fatal_error.assert_called_with( + err, + 'Fatal write error on datagram transport') + transport._conn_lost = 1 + + transport._address = ('123',) + transport.sendto(data) + transport.sendto(data) + transport.sendto(data) + transport.sendto(data) + transport.sendto(data) + m_log.warning.assert_called_with('socket.send() raised exception.') + + def test_sendto_error_received(self): + data = b'data' + + self.sock.sendto.side_effect = ConnectionRefusedError + + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + transport._fatal_error = mock.Mock() + transport.sendto(data, ()) + + self.assertEqual(transport._conn_lost, 0) + self.assertFalse(transport._fatal_error.called) + + def test_sendto_error_received_connected(self): + data = b'data' + + self.sock.send.side_effect = ConnectionRefusedError + + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol, ('0.0.0.0', 1)) + transport._fatal_error = mock.Mock() + transport.sendto(data) + + self.assertFalse(transport._fatal_error.called) + self.assertTrue(self.protocol.error_received.called) + + def test_sendto_str(self): + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + self.assertRaises(TypeError, transport.sendto, 'str', ()) + + def test_sendto_connected_addr(self): + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol, ('0.0.0.0', 1)) + self.assertRaises( + ValueError, transport.sendto, b'str', ('0.0.0.0', 2)) + + def test_sendto_closing(self): + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol, address=(1,)) + transport.close() + self.assertEqual(transport._conn_lost, 1) + transport.sendto(b'data', (1,)) + self.assertEqual(transport._conn_lost, 2) + + def test_sendto_ready(self): + data = b'data' + self.sock.sendto.return_value = len(data) + + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + transport._buffer.append((data, ('0.0.0.0', 12345))) + self.loop.add_writer(7, transport._sendto_ready) + transport._sendto_ready() + self.assertTrue(self.sock.sendto.called) + self.assertEqual( + self.sock.sendto.call_args[0], (data, ('0.0.0.0', 12345))) + self.assertFalse(self.loop.writers) + + def test_sendto_ready_closing(self): + data = b'data' + self.sock.send.return_value = len(data) + + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + transport._closing = True + transport._buffer.append((data, ())) + self.loop.add_writer(7, transport._sendto_ready) + transport._sendto_ready() + self.sock.sendto.assert_called_with(data, ()) + self.assertFalse(self.loop.writers) + self.sock.close.assert_called_with() + self.protocol.connection_lost.assert_called_with(None) + + def test_sendto_ready_no_data(self): + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + self.loop.add_writer(7, transport._sendto_ready) + transport._sendto_ready() + self.assertFalse(self.sock.sendto.called) + self.assertFalse(self.loop.writers) + + def test_sendto_ready_tryagain(self): + self.sock.sendto.side_effect = BlockingIOError + + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + transport._buffer.extend([(b'data1', ()), (b'data2', ())]) + self.loop.add_writer(7, transport._sendto_ready) + transport._sendto_ready() + + self.loop.assert_writer(7, transport._sendto_ready) + self.assertEqual( + [(b'data1', ()), (b'data2', ())], + list(transport._buffer)) + + def test_sendto_ready_exception(self): + err = self.sock.sendto.side_effect = RuntimeError() + + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + transport._fatal_error = mock.Mock() + transport._buffer.append((b'data', ())) + transport._sendto_ready() + + transport._fatal_error.assert_called_with( + err, + 'Fatal write error on datagram transport') + + def test_sendto_ready_error_received(self): + self.sock.sendto.side_effect = ConnectionRefusedError + + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + transport._fatal_error = mock.Mock() + transport._buffer.append((b'data', ())) + transport._sendto_ready() + + self.assertFalse(transport._fatal_error.called) + + def test_sendto_ready_error_received_connection(self): + self.sock.send.side_effect = ConnectionRefusedError + + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol, ('0.0.0.0', 1)) + transport._fatal_error = mock.Mock() + transport._buffer.append((b'data', ())) + transport._sendto_ready() + + self.assertFalse(transport._fatal_error.called) + self.assertTrue(self.protocol.error_received.called) + + @mock.patch('asyncio.base_events.logger.error') + def test_fatal_error_connected(self, m_exc): + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol, ('0.0.0.0', 1)) + err = ConnectionRefusedError() + transport._fatal_error(err) + self.assertFalse(self.protocol.error_received.called) + m_exc.assert_called_with( + test_utils.MockPattern( + 'Fatal error on transport\nprotocol:.*\ntransport:.*'), + exc_info=(ConnectionRefusedError, MOCK_ANY, MOCK_ANY)) + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py new file mode 100644 index 0000000000..031499e814 --- /dev/null +++ b/Lib/test/test_asyncio/test_streams.py @@ -0,0 +1,588 @@ +"""Tests for streams.py.""" + +import gc +import socket +import unittest +from unittest import mock +try: + import ssl +except ImportError: + ssl = None + +import asyncio +from asyncio import test_utils + + +class StreamReaderTests(unittest.TestCase): + + DATA = b'line1\nline2\nline3\n' + + def setUp(self): + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(None) + + def tearDown(self): + # just in case if we have transport close callbacks + test_utils.run_briefly(self.loop) + + self.loop.close() + gc.collect() + + @mock.patch('asyncio.streams.events') + def test_ctor_global_loop(self, m_events): + stream = asyncio.StreamReader() + self.assertIs(stream._loop, m_events.get_event_loop.return_value) + + def _basetest_open_connection(self, open_connection_fut): + reader, writer = self.loop.run_until_complete(open_connection_fut) + writer.write(b'GET / HTTP/1.0\r\n\r\n') + f = reader.readline() + data = self.loop.run_until_complete(f) + self.assertEqual(data, b'HTTP/1.0 200 OK\r\n') + f = reader.read() + data = self.loop.run_until_complete(f) + self.assertTrue(data.endswith(b'\r\n\r\nTest message')) + writer.close() + + def test_open_connection(self): + with test_utils.run_test_server() as httpd: + conn_fut = asyncio.open_connection(*httpd.address, + loop=self.loop) + self._basetest_open_connection(conn_fut) + + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_open_unix_connection(self): + with test_utils.run_test_unix_server() as httpd: + conn_fut = asyncio.open_unix_connection(httpd.address, + loop=self.loop) + self._basetest_open_connection(conn_fut) + + def _basetest_open_connection_no_loop_ssl(self, open_connection_fut): + try: + reader, writer = self.loop.run_until_complete(open_connection_fut) + finally: + asyncio.set_event_loop(None) + writer.write(b'GET / HTTP/1.0\r\n\r\n') + f = reader.read() + data = self.loop.run_until_complete(f) + self.assertTrue(data.endswith(b'\r\n\r\nTest message')) + + writer.close() + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_open_connection_no_loop_ssl(self): + with test_utils.run_test_server(use_ssl=True) as httpd: + conn_fut = asyncio.open_connection( + *httpd.address, + ssl=test_utils.dummy_ssl_context(), + loop=self.loop) + + self._basetest_open_connection_no_loop_ssl(conn_fut) + + @unittest.skipIf(ssl is None, 'No ssl module') + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_open_unix_connection_no_loop_ssl(self): + with test_utils.run_test_unix_server(use_ssl=True) as httpd: + conn_fut = asyncio.open_unix_connection( + httpd.address, + ssl=test_utils.dummy_ssl_context(), + server_hostname='', + loop=self.loop) + + self._basetest_open_connection_no_loop_ssl(conn_fut) + + def _basetest_open_connection_error(self, open_connection_fut): + reader, writer = self.loop.run_until_complete(open_connection_fut) + writer._protocol.connection_lost(ZeroDivisionError()) + f = reader.read() + with self.assertRaises(ZeroDivisionError): + self.loop.run_until_complete(f) + writer.close() + test_utils.run_briefly(self.loop) + + def test_open_connection_error(self): + with test_utils.run_test_server() as httpd: + conn_fut = asyncio.open_connection(*httpd.address, + loop=self.loop) + self._basetest_open_connection_error(conn_fut) + + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_open_unix_connection_error(self): + with test_utils.run_test_unix_server() as httpd: + conn_fut = asyncio.open_unix_connection(httpd.address, + loop=self.loop) + self._basetest_open_connection_error(conn_fut) + + def test_feed_empty_data(self): + stream = asyncio.StreamReader(loop=self.loop) + + stream.feed_data(b'') + self.assertEqual(b'', stream._buffer) + + def test_feed_nonempty_data(self): + stream = asyncio.StreamReader(loop=self.loop) + + stream.feed_data(self.DATA) + self.assertEqual(self.DATA, stream._buffer) + + def test_read_zero(self): + # Read zero bytes. + stream = asyncio.StreamReader(loop=self.loop) + stream.feed_data(self.DATA) + + data = self.loop.run_until_complete(stream.read(0)) + self.assertEqual(b'', data) + self.assertEqual(self.DATA, stream._buffer) + + def test_read(self): + # Read bytes. + stream = asyncio.StreamReader(loop=self.loop) + read_task = asyncio.Task(stream.read(30), loop=self.loop) + + def cb(): + stream.feed_data(self.DATA) + self.loop.call_soon(cb) + + data = self.loop.run_until_complete(read_task) + self.assertEqual(self.DATA, data) + self.assertEqual(b'', stream._buffer) + + def test_read_line_breaks(self): + # Read bytes without line breaks. + stream = asyncio.StreamReader(loop=self.loop) + stream.feed_data(b'line1') + stream.feed_data(b'line2') + + data = self.loop.run_until_complete(stream.read(5)) + + self.assertEqual(b'line1', data) + self.assertEqual(b'line2', stream._buffer) + + def test_read_eof(self): + # Read bytes, stop at eof. + stream = asyncio.StreamReader(loop=self.loop) + read_task = asyncio.Task(stream.read(1024), loop=self.loop) + + def cb(): + stream.feed_eof() + self.loop.call_soon(cb) + + data = self.loop.run_until_complete(read_task) + self.assertEqual(b'', data) + self.assertEqual(b'', stream._buffer) + + def test_read_until_eof(self): + # Read all bytes until eof. + stream = asyncio.StreamReader(loop=self.loop) + read_task = asyncio.Task(stream.read(-1), loop=self.loop) + + def cb(): + stream.feed_data(b'chunk1\n') + stream.feed_data(b'chunk2') + stream.feed_eof() + self.loop.call_soon(cb) + + data = self.loop.run_until_complete(read_task) + + self.assertEqual(b'chunk1\nchunk2', data) + self.assertEqual(b'', stream._buffer) + + def test_read_exception(self): + stream = asyncio.StreamReader(loop=self.loop) + stream.feed_data(b'line\n') + + data = self.loop.run_until_complete(stream.read(2)) + self.assertEqual(b'li', data) + + stream.set_exception(ValueError()) + self.assertRaises( + ValueError, self.loop.run_until_complete, stream.read(2)) + + def test_readline(self): + # Read one line. 'readline' will need to wait for the data + # to come from 'cb' + stream = asyncio.StreamReader(loop=self.loop) + stream.feed_data(b'chunk1 ') + read_task = asyncio.Task(stream.readline(), loop=self.loop) + + def cb(): + stream.feed_data(b'chunk2 ') + stream.feed_data(b'chunk3 ') + stream.feed_data(b'\n chunk4') + self.loop.call_soon(cb) + + line = self.loop.run_until_complete(read_task) + self.assertEqual(b'chunk1 chunk2 chunk3 \n', line) + self.assertEqual(b' chunk4', stream._buffer) + + def test_readline_limit_with_existing_data(self): + # Read one line. The data is in StreamReader's buffer + # before the event loop is run. + + stream = asyncio.StreamReader(limit=3, loop=self.loop) + stream.feed_data(b'li') + stream.feed_data(b'ne1\nline2\n') + + self.assertRaises( + ValueError, self.loop.run_until_complete, stream.readline()) + # The buffer should contain the remaining data after exception + self.assertEqual(b'line2\n', stream._buffer) + + stream = asyncio.StreamReader(limit=3, loop=self.loop) + stream.feed_data(b'li') + stream.feed_data(b'ne1') + stream.feed_data(b'li') + + self.assertRaises( + ValueError, self.loop.run_until_complete, stream.readline()) + # No b'\n' at the end. The 'limit' is set to 3. So before + # waiting for the new data in buffer, 'readline' will consume + # the entire buffer, and since the length of the consumed data + # is more than 3, it will raise a ValueError. The buffer is + # expected to be empty now. + self.assertEqual(b'', stream._buffer) + + def test_at_eof(self): + stream = asyncio.StreamReader(loop=self.loop) + self.assertFalse(stream.at_eof()) + + stream.feed_data(b'some data\n') + self.assertFalse(stream.at_eof()) + + self.loop.run_until_complete(stream.readline()) + self.assertFalse(stream.at_eof()) + + stream.feed_data(b'some data\n') + stream.feed_eof() + self.loop.run_until_complete(stream.readline()) + self.assertTrue(stream.at_eof()) + + def test_readline_limit(self): + # Read one line. StreamReaders are fed with data after + # their 'readline' methods are called. + + stream = asyncio.StreamReader(limit=7, loop=self.loop) + def cb(): + stream.feed_data(b'chunk1') + stream.feed_data(b'chunk2') + stream.feed_data(b'chunk3\n') + stream.feed_eof() + self.loop.call_soon(cb) + + self.assertRaises( + ValueError, self.loop.run_until_complete, stream.readline()) + # The buffer had just one line of data, and after raising + # a ValueError it should be empty. + self.assertEqual(b'', stream._buffer) + + stream = asyncio.StreamReader(limit=7, loop=self.loop) + def cb(): + stream.feed_data(b'chunk1') + stream.feed_data(b'chunk2\n') + stream.feed_data(b'chunk3\n') + stream.feed_eof() + self.loop.call_soon(cb) + + self.assertRaises( + ValueError, self.loop.run_until_complete, stream.readline()) + self.assertEqual(b'chunk3\n', stream._buffer) + + def test_readline_nolimit_nowait(self): + # All needed data for the first 'readline' call will be + # in the buffer. + stream = asyncio.StreamReader(loop=self.loop) + stream.feed_data(self.DATA[:6]) + stream.feed_data(self.DATA[6:]) + + line = self.loop.run_until_complete(stream.readline()) + + self.assertEqual(b'line1\n', line) + self.assertEqual(b'line2\nline3\n', stream._buffer) + + def test_readline_eof(self): + stream = asyncio.StreamReader(loop=self.loop) + stream.feed_data(b'some data') + stream.feed_eof() + + line = self.loop.run_until_complete(stream.readline()) + self.assertEqual(b'some data', line) + + def test_readline_empty_eof(self): + stream = asyncio.StreamReader(loop=self.loop) + stream.feed_eof() + + line = self.loop.run_until_complete(stream.readline()) + self.assertEqual(b'', line) + + def test_readline_read_byte_count(self): + stream = asyncio.StreamReader(loop=self.loop) + stream.feed_data(self.DATA) + + self.loop.run_until_complete(stream.readline()) + + data = self.loop.run_until_complete(stream.read(7)) + + self.assertEqual(b'line2\nl', data) + self.assertEqual(b'ine3\n', stream._buffer) + + def test_readline_exception(self): + stream = asyncio.StreamReader(loop=self.loop) + stream.feed_data(b'line\n') + + data = self.loop.run_until_complete(stream.readline()) + self.assertEqual(b'line\n', data) + + stream.set_exception(ValueError()) + self.assertRaises( + ValueError, self.loop.run_until_complete, stream.readline()) + self.assertEqual(b'', stream._buffer) + + def test_readexactly_zero_or_less(self): + # Read exact number of bytes (zero or less). + stream = asyncio.StreamReader(loop=self.loop) + stream.feed_data(self.DATA) + + data = self.loop.run_until_complete(stream.readexactly(0)) + self.assertEqual(b'', data) + self.assertEqual(self.DATA, stream._buffer) + + data = self.loop.run_until_complete(stream.readexactly(-1)) + self.assertEqual(b'', data) + self.assertEqual(self.DATA, stream._buffer) + + def test_readexactly(self): + # Read exact number of bytes. + stream = asyncio.StreamReader(loop=self.loop) + + n = 2 * len(self.DATA) + read_task = asyncio.Task(stream.readexactly(n), loop=self.loop) + + def cb(): + stream.feed_data(self.DATA) + stream.feed_data(self.DATA) + stream.feed_data(self.DATA) + self.loop.call_soon(cb) + + data = self.loop.run_until_complete(read_task) + self.assertEqual(self.DATA + self.DATA, data) + self.assertEqual(self.DATA, stream._buffer) + + def test_readexactly_eof(self): + # Read exact number of bytes (eof). + stream = asyncio.StreamReader(loop=self.loop) + n = 2 * len(self.DATA) + read_task = asyncio.Task(stream.readexactly(n), loop=self.loop) + + def cb(): + stream.feed_data(self.DATA) + stream.feed_eof() + self.loop.call_soon(cb) + + with self.assertRaises(asyncio.IncompleteReadError) as cm: + self.loop.run_until_complete(read_task) + self.assertEqual(cm.exception.partial, self.DATA) + self.assertEqual(cm.exception.expected, n) + self.assertEqual(str(cm.exception), + '18 bytes read on a total of 36 expected bytes') + self.assertEqual(b'', stream._buffer) + + def test_readexactly_exception(self): + stream = asyncio.StreamReader(loop=self.loop) + stream.feed_data(b'line\n') + + data = self.loop.run_until_complete(stream.readexactly(2)) + self.assertEqual(b'li', data) + + stream.set_exception(ValueError()) + self.assertRaises( + ValueError, self.loop.run_until_complete, stream.readexactly(2)) + + def test_exception(self): + stream = asyncio.StreamReader(loop=self.loop) + self.assertIsNone(stream.exception()) + + exc = ValueError() + stream.set_exception(exc) + self.assertIs(stream.exception(), exc) + + def test_exception_waiter(self): + stream = asyncio.StreamReader(loop=self.loop) + + @asyncio.coroutine + def set_err(): + stream.set_exception(ValueError()) + + @asyncio.coroutine + def readline(): + yield from stream.readline() + + t1 = asyncio.Task(stream.readline(), loop=self.loop) + t2 = asyncio.Task(set_err(), loop=self.loop) + + self.loop.run_until_complete(asyncio.wait([t1, t2], loop=self.loop)) + + self.assertRaises(ValueError, t1.result) + + def test_exception_cancel(self): + stream = asyncio.StreamReader(loop=self.loop) + + @asyncio.coroutine + def read_a_line(): + yield from stream.readline() + + t = asyncio.Task(read_a_line(), loop=self.loop) + test_utils.run_briefly(self.loop) + t.cancel() + test_utils.run_briefly(self.loop) + # The following line fails if set_exception() isn't careful. + stream.set_exception(RuntimeError('message')) + test_utils.run_briefly(self.loop) + self.assertIs(stream._waiter, None) + + def test_start_server(self): + + class MyServer: + + def __init__(self, loop): + self.server = None + self.loop = loop + + @asyncio.coroutine + def handle_client(self, client_reader, client_writer): + data = yield from client_reader.readline() + client_writer.write(data) + + def start(self): + sock = socket.socket() + sock.bind(('127.0.0.1', 0)) + self.server = self.loop.run_until_complete( + asyncio.start_server(self.handle_client, + sock=sock, + loop=self.loop)) + return sock.getsockname() + + def handle_client_callback(self, client_reader, client_writer): + task = asyncio.Task(client_reader.readline(), loop=self.loop) + + def done(task): + client_writer.write(task.result()) + + task.add_done_callback(done) + + def start_callback(self): + sock = socket.socket() + sock.bind(('127.0.0.1', 0)) + addr = sock.getsockname() + sock.close() + self.server = self.loop.run_until_complete( + asyncio.start_server(self.handle_client_callback, + host=addr[0], port=addr[1], + loop=self.loop)) + return addr + + def stop(self): + if self.server is not None: + self.server.close() + self.loop.run_until_complete(self.server.wait_closed()) + self.server = None + + @asyncio.coroutine + def client(addr): + reader, writer = yield from asyncio.open_connection( + *addr, loop=self.loop) + # send a line + writer.write(b"hello world!\n") + # read it back + msgback = yield from reader.readline() + writer.close() + return msgback + + # test the server variant with a coroutine as client handler + server = MyServer(self.loop) + addr = server.start() + msg = self.loop.run_until_complete(asyncio.Task(client(addr), + loop=self.loop)) + server.stop() + self.assertEqual(msg, b"hello world!\n") + + # test the server variant with a callback as client handler + server = MyServer(self.loop) + addr = server.start_callback() + msg = self.loop.run_until_complete(asyncio.Task(client(addr), + loop=self.loop)) + server.stop() + self.assertEqual(msg, b"hello world!\n") + + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_start_unix_server(self): + + class MyServer: + + def __init__(self, loop, path): + self.server = None + self.loop = loop + self.path = path + + @asyncio.coroutine + def handle_client(self, client_reader, client_writer): + data = yield from client_reader.readline() + client_writer.write(data) + + def start(self): + self.server = self.loop.run_until_complete( + asyncio.start_unix_server(self.handle_client, + path=self.path, + loop=self.loop)) + + def handle_client_callback(self, client_reader, client_writer): + task = asyncio.Task(client_reader.readline(), loop=self.loop) + + def done(task): + client_writer.write(task.result()) + + task.add_done_callback(done) + + def start_callback(self): + self.server = self.loop.run_until_complete( + asyncio.start_unix_server(self.handle_client_callback, + path=self.path, + loop=self.loop)) + + def stop(self): + if self.server is not None: + self.server.close() + self.loop.run_until_complete(self.server.wait_closed()) + self.server = None + + @asyncio.coroutine + def client(path): + reader, writer = yield from asyncio.open_unix_connection( + path, loop=self.loop) + # send a line + writer.write(b"hello world!\n") + # read it back + msgback = yield from reader.readline() + writer.close() + return msgback + + # test the server variant with a coroutine as client handler + with test_utils.unix_socket_path() as path: + server = MyServer(self.loop, path) + server.start() + msg = self.loop.run_until_complete(asyncio.Task(client(path), + loop=self.loop)) + server.stop() + self.assertEqual(msg, b"hello world!\n") + + # test the server variant with a callback as client handler + with test_utils.unix_socket_path() as path: + server = MyServer(self.loop, path) + server.start_callback() + msg = self.loop.run_until_complete(asyncio.Task(client(path), + loop=self.loop)) + server.stop() + self.assertEqual(msg, b"hello world!\n") + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_asyncio/test_subprocess.py b/Lib/test/test_asyncio/test_subprocess.py new file mode 100644 index 0000000000..14fd17e61a --- /dev/null +++ b/Lib/test/test_asyncio/test_subprocess.py @@ -0,0 +1,184 @@ +from asyncio import subprocess +import asyncio +import signal +import sys +import unittest +from test import support +if sys.platform != 'win32': + from asyncio import unix_events + +# Program blocking +PROGRAM_BLOCKED = [sys.executable, '-c', 'import time; time.sleep(3600)'] + +# Program sleeping during 1 second +PROGRAM_SLEEP_1SEC = [sys.executable, '-c', 'import time; time.sleep(1)'] + +# Program copying input to output +PROGRAM_CAT = [ + sys.executable, '-c', + ';'.join(('import sys', + 'data = sys.stdin.buffer.read()', + 'sys.stdout.buffer.write(data)'))] + +class SubprocessMixin: + + def test_stdin_stdout(self): + args = PROGRAM_CAT + + @asyncio.coroutine + def run(data): + proc = yield from asyncio.create_subprocess_exec( + *args, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + loop=self.loop) + + # feed data + proc.stdin.write(data) + yield from proc.stdin.drain() + proc.stdin.close() + + # get output and exitcode + data = yield from proc.stdout.read() + exitcode = yield from proc.wait() + return (exitcode, data) + + task = run(b'some data') + task = asyncio.wait_for(task, 10.0, loop=self.loop) + exitcode, stdout = self.loop.run_until_complete(task) + self.assertEqual(exitcode, 0) + self.assertEqual(stdout, b'some data') + + def test_communicate(self): + args = PROGRAM_CAT + + @asyncio.coroutine + def run(data): + proc = yield from asyncio.create_subprocess_exec( + *args, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + loop=self.loop) + stdout, stderr = yield from proc.communicate(data) + return proc.returncode, stdout + + task = run(b'some data') + task = asyncio.wait_for(task, 10.0, loop=self.loop) + exitcode, stdout = self.loop.run_until_complete(task) + self.assertEqual(exitcode, 0) + self.assertEqual(stdout, b'some data') + + def test_shell(self): + create = asyncio.create_subprocess_shell('exit 7', + loop=self.loop) + proc = self.loop.run_until_complete(create) + exitcode = self.loop.run_until_complete(proc.wait()) + self.assertEqual(exitcode, 7) + + def test_start_new_session(self): + # start the new process in a new session + create = asyncio.create_subprocess_shell('exit 8', + start_new_session=True, + loop=self.loop) + proc = self.loop.run_until_complete(create) + exitcode = self.loop.run_until_complete(proc.wait()) + self.assertEqual(exitcode, 8) + + def test_kill(self): + args = PROGRAM_BLOCKED + create = asyncio.create_subprocess_exec(*args, loop=self.loop) + proc = self.loop.run_until_complete(create) + proc.kill() + returncode = self.loop.run_until_complete(proc.wait()) + if sys.platform == 'win32': + self.assertIsInstance(returncode, int) + # expect 1 but sometimes get 0 + else: + self.assertEqual(-signal.SIGKILL, returncode) + + def test_terminate(self): + args = PROGRAM_BLOCKED + create = asyncio.create_subprocess_exec(*args, loop=self.loop) + proc = self.loop.run_until_complete(create) + proc.terminate() + returncode = self.loop.run_until_complete(proc.wait()) + if sys.platform == 'win32': + self.assertIsInstance(returncode, int) + # expect 1 but sometimes get 0 + else: + self.assertEqual(-signal.SIGTERM, returncode) + + @unittest.skipIf(sys.platform == 'win32', "Don't have SIGHUP") + def test_send_signal(self): + args = PROGRAM_BLOCKED + create = asyncio.create_subprocess_exec(*args, loop=self.loop) + proc = self.loop.run_until_complete(create) + proc.send_signal(signal.SIGHUP) + returncode = self.loop.run_until_complete(proc.wait()) + self.assertEqual(-signal.SIGHUP, returncode) + + def test_broken_pipe(self): + large_data = b'x' * support.PIPE_MAX_SIZE + + create = asyncio.create_subprocess_exec( + *PROGRAM_SLEEP_1SEC, + stdin=subprocess.PIPE, + loop=self.loop) + proc = self.loop.run_until_complete(create) + with self.assertRaises(BrokenPipeError): + self.loop.run_until_complete(proc.communicate(large_data)) + self.loop.run_until_complete(proc.wait()) + + +if sys.platform != 'win32': + # Unix + class SubprocessWatcherMixin(SubprocessMixin): + + Watcher = None + + def setUp(self): + policy = asyncio.get_event_loop_policy() + self.loop = policy.new_event_loop() + + # ensure that the event loop is passed explicitly in the code + policy.set_event_loop(None) + + watcher = self.Watcher() + watcher.attach_loop(self.loop) + policy.set_child_watcher(watcher) + + def tearDown(self): + policy = asyncio.get_event_loop_policy() + policy.set_child_watcher(None) + self.loop.close() + policy.set_event_loop(None) + + class SubprocessSafeWatcherTests(SubprocessWatcherMixin, + unittest.TestCase): + + Watcher = unix_events.SafeChildWatcher + + class SubprocessFastWatcherTests(SubprocessWatcherMixin, + unittest.TestCase): + + Watcher = unix_events.FastChildWatcher + +else: + # Windows + class SubprocessProactorTests(SubprocessMixin, unittest.TestCase): + + def setUp(self): + policy = asyncio.get_event_loop_policy() + self.loop = asyncio.ProactorEventLoop() + + # ensure that the event loop is passed explicitly in the code + policy.set_event_loop(None) + + def tearDown(self): + policy = asyncio.get_event_loop_policy() + self.loop.close() + policy.set_event_loop(None) + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_asyncio/test_tasks.py b/Lib/test/test_asyncio/test_tasks.py new file mode 100644 index 0000000000..ced34312f7 --- /dev/null +++ b/Lib/test/test_asyncio/test_tasks.py @@ -0,0 +1,1677 @@ +"""Tests for tasks.py.""" + +import gc +import os.path +import unittest +from test.script_helper import assert_python_ok + +import asyncio +from asyncio import test_utils + + +@asyncio.coroutine +def coroutine_function(): + pass + + +class Dummy: + + def __repr__(self): + return 'Dummy()' + + def __call__(self, *args): + pass + + +class TaskTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + asyncio.set_event_loop(None) + + def tearDown(self): + self.loop.close() + gc.collect() + + def test_task_class(self): + @asyncio.coroutine + def notmuch(): + return 'ok' + t = asyncio.Task(notmuch(), loop=self.loop) + self.loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ok') + self.assertIs(t._loop, self.loop) + + loop = asyncio.new_event_loop() + t = asyncio.Task(notmuch(), loop=loop) + self.assertIs(t._loop, loop) + loop.close() + + def test_async_coroutine(self): + @asyncio.coroutine + def notmuch(): + return 'ok' + t = asyncio.async(notmuch(), loop=self.loop) + self.loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ok') + self.assertIs(t._loop, self.loop) + + loop = asyncio.new_event_loop() + t = asyncio.async(notmuch(), loop=loop) + self.assertIs(t._loop, loop) + loop.close() + + def test_async_future(self): + f_orig = asyncio.Future(loop=self.loop) + f_orig.set_result('ko') + + f = asyncio.async(f_orig) + self.loop.run_until_complete(f) + self.assertTrue(f.done()) + self.assertEqual(f.result(), 'ko') + self.assertIs(f, f_orig) + + loop = asyncio.new_event_loop() + + with self.assertRaises(ValueError): + f = asyncio.async(f_orig, loop=loop) + + loop.close() + + f = asyncio.async(f_orig, loop=self.loop) + self.assertIs(f, f_orig) + + def test_async_task(self): + @asyncio.coroutine + def notmuch(): + return 'ok' + t_orig = asyncio.Task(notmuch(), loop=self.loop) + t = asyncio.async(t_orig) + self.loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ok') + self.assertIs(t, t_orig) + + loop = asyncio.new_event_loop() + + with self.assertRaises(ValueError): + t = asyncio.async(t_orig, loop=loop) + + loop.close() + + t = asyncio.async(t_orig, loop=self.loop) + self.assertIs(t, t_orig) + + def test_async_neither(self): + with self.assertRaises(TypeError): + asyncio.async('ok') + + def test_task_repr(self): + @asyncio.coroutine + def notmuch(): + yield from [] + return 'abc' + + t = asyncio.Task(notmuch(), loop=self.loop) + t.add_done_callback(Dummy()) + self.assertEqual(repr(t), 'Task(<notmuch>)<PENDING, [Dummy()]>') + t.cancel() # Does not take immediate effect! + self.assertEqual(repr(t), 'Task(<notmuch>)<CANCELLING, [Dummy()]>') + self.assertRaises(asyncio.CancelledError, + self.loop.run_until_complete, t) + self.assertEqual(repr(t), 'Task(<notmuch>)<CANCELLED>') + t = asyncio.Task(notmuch(), loop=self.loop) + self.loop.run_until_complete(t) + self.assertEqual(repr(t), "Task(<notmuch>)<result='abc'>") + + def test_task_repr_custom(self): + @asyncio.coroutine + def coro(): + pass + + class T(asyncio.Future): + def __repr__(self): + return 'T[]' + + class MyTask(asyncio.Task, T): + def __repr__(self): + return super().__repr__() + + gen = coro() + t = MyTask(gen, loop=self.loop) + self.assertEqual(repr(t), 'T[](<coro>)') + gen.close() + + def test_task_basics(self): + @asyncio.coroutine + def outer(): + a = yield from inner1() + b = yield from inner2() + return a+b + + @asyncio.coroutine + def inner1(): + return 42 + + @asyncio.coroutine + def inner2(): + return 1000 + + t = outer() + self.assertEqual(self.loop.run_until_complete(t), 1042) + + def test_cancel(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + @asyncio.coroutine + def task(): + yield from asyncio.sleep(10.0, loop=loop) + return 12 + + t = asyncio.Task(task(), loop=loop) + loop.call_soon(t.cancel) + with self.assertRaises(asyncio.CancelledError): + loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertTrue(t.cancelled()) + self.assertFalse(t.cancel()) + + def test_cancel_yield(self): + @asyncio.coroutine + def task(): + yield + yield + return 12 + + t = asyncio.Task(task(), loop=self.loop) + test_utils.run_briefly(self.loop) # start coro + t.cancel() + self.assertRaises( + asyncio.CancelledError, self.loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertTrue(t.cancelled()) + self.assertFalse(t.cancel()) + + def test_cancel_inner_future(self): + f = asyncio.Future(loop=self.loop) + + @asyncio.coroutine + def task(): + yield from f + return 12 + + t = asyncio.Task(task(), loop=self.loop) + test_utils.run_briefly(self.loop) # start task + f.cancel() + with self.assertRaises(asyncio.CancelledError): + self.loop.run_until_complete(t) + self.assertTrue(f.cancelled()) + self.assertTrue(t.cancelled()) + + def test_cancel_both_task_and_inner_future(self): + f = asyncio.Future(loop=self.loop) + + @asyncio.coroutine + def task(): + yield from f + return 12 + + t = asyncio.Task(task(), loop=self.loop) + test_utils.run_briefly(self.loop) + + f.cancel() + t.cancel() + + with self.assertRaises(asyncio.CancelledError): + self.loop.run_until_complete(t) + + self.assertTrue(t.done()) + self.assertTrue(f.cancelled()) + self.assertTrue(t.cancelled()) + + def test_cancel_task_catching(self): + fut1 = asyncio.Future(loop=self.loop) + fut2 = asyncio.Future(loop=self.loop) + + @asyncio.coroutine + def task(): + yield from fut1 + try: + yield from fut2 + except asyncio.CancelledError: + return 42 + + t = asyncio.Task(task(), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertIs(t._fut_waiter, fut1) # White-box test. + fut1.set_result(None) + test_utils.run_briefly(self.loop) + self.assertIs(t._fut_waiter, fut2) # White-box test. + t.cancel() + self.assertTrue(fut2.cancelled()) + res = self.loop.run_until_complete(t) + self.assertEqual(res, 42) + self.assertFalse(t.cancelled()) + + def test_cancel_task_ignoring(self): + fut1 = asyncio.Future(loop=self.loop) + fut2 = asyncio.Future(loop=self.loop) + fut3 = asyncio.Future(loop=self.loop) + + @asyncio.coroutine + def task(): + yield from fut1 + try: + yield from fut2 + except asyncio.CancelledError: + pass + res = yield from fut3 + return res + + t = asyncio.Task(task(), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertIs(t._fut_waiter, fut1) # White-box test. + fut1.set_result(None) + test_utils.run_briefly(self.loop) + self.assertIs(t._fut_waiter, fut2) # White-box test. + t.cancel() + self.assertTrue(fut2.cancelled()) + test_utils.run_briefly(self.loop) + self.assertIs(t._fut_waiter, fut3) # White-box test. + fut3.set_result(42) + res = self.loop.run_until_complete(t) + self.assertEqual(res, 42) + self.assertFalse(fut3.cancelled()) + self.assertFalse(t.cancelled()) + + def test_cancel_current_task(self): + loop = asyncio.new_event_loop() + self.addCleanup(loop.close) + + @asyncio.coroutine + def task(): + t.cancel() + self.assertTrue(t._must_cancel) # White-box test. + # The sleep should be cancelled immediately. + yield from asyncio.sleep(100, loop=loop) + return 12 + + t = asyncio.Task(task(), loop=loop) + self.assertRaises( + asyncio.CancelledError, loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertFalse(t._must_cancel) # White-box test. + self.assertFalse(t.cancel()) + + def test_stop_while_run_in_complete(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0.1 + self.assertAlmostEqual(0.2, when) + when = yield 0.1 + self.assertAlmostEqual(0.3, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + x = 0 + waiters = [] + + @asyncio.coroutine + def task(): + nonlocal x + while x < 10: + waiters.append(asyncio.sleep(0.1, loop=loop)) + yield from waiters[-1] + x += 1 + if x == 2: + loop.stop() + + t = asyncio.Task(task(), loop=loop) + self.assertRaises( + RuntimeError, loop.run_until_complete, t) + self.assertFalse(t.done()) + self.assertEqual(x, 2) + self.assertAlmostEqual(0.3, loop.time()) + + # close generators + for w in waiters: + w.close() + + def test_wait_for(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.2, when) + when = yield 0 + self.assertAlmostEqual(0.1, when) + when = yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + foo_running = None + + @asyncio.coroutine + def foo(): + nonlocal foo_running + foo_running = True + try: + yield from asyncio.sleep(0.2, loop=loop) + finally: + foo_running = False + return 'done' + + fut = asyncio.Task(foo(), loop=loop) + + with self.assertRaises(asyncio.TimeoutError): + loop.run_until_complete(asyncio.wait_for(fut, 0.1, loop=loop)) + self.assertTrue(fut.done()) + # it should have been cancelled due to the timeout + self.assertTrue(fut.cancelled()) + self.assertAlmostEqual(0.1, loop.time()) + self.assertEqual(foo_running, False) + + def test_wait_for_blocking(self): + loop = test_utils.TestLoop() + self.addCleanup(loop.close) + + @asyncio.coroutine + def coro(): + return 'done' + + res = loop.run_until_complete(asyncio.wait_for(coro(), + timeout=None, + loop=loop)) + self.assertEqual(res, 'done') + + def test_wait_for_with_global_loop(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.2, when) + when = yield 0 + self.assertAlmostEqual(0.01, when) + yield 0.01 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + @asyncio.coroutine + def foo(): + yield from asyncio.sleep(0.2, loop=loop) + return 'done' + + asyncio.set_event_loop(loop) + try: + fut = asyncio.Task(foo(), loop=loop) + with self.assertRaises(asyncio.TimeoutError): + loop.run_until_complete(asyncio.wait_for(fut, 0.01)) + finally: + asyncio.set_event_loop(None) + + self.assertAlmostEqual(0.01, loop.time()) + self.assertTrue(fut.done()) + self.assertTrue(fut.cancelled()) + + def test_wait(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + yield 0.15 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = asyncio.Task(asyncio.sleep(0.1, loop=loop), loop=loop) + b = asyncio.Task(asyncio.sleep(0.15, loop=loop), loop=loop) + + @asyncio.coroutine + def foo(): + done, pending = yield from asyncio.wait([b, a], loop=loop) + self.assertEqual(done, set([a, b])) + self.assertEqual(pending, set()) + return 42 + + res = loop.run_until_complete(asyncio.Task(foo(), loop=loop)) + self.assertEqual(res, 42) + self.assertAlmostEqual(0.15, loop.time()) + + # Doing it again should take no time and exercise a different path. + res = loop.run_until_complete(asyncio.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + self.assertEqual(res, 42) + + def test_wait_with_global_loop(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.01, when) + when = yield 0 + self.assertAlmostEqual(0.015, when) + yield 0.015 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = asyncio.Task(asyncio.sleep(0.01, loop=loop), loop=loop) + b = asyncio.Task(asyncio.sleep(0.015, loop=loop), loop=loop) + + @asyncio.coroutine + def foo(): + done, pending = yield from asyncio.wait([b, a]) + self.assertEqual(done, set([a, b])) + self.assertEqual(pending, set()) + return 42 + + asyncio.set_event_loop(loop) + try: + res = loop.run_until_complete( + asyncio.Task(foo(), loop=loop)) + finally: + asyncio.set_event_loop(None) + + self.assertEqual(res, 42) + + def test_wait_duplicate_coroutines(self): + @asyncio.coroutine + def coro(s): + return s + c = coro('test') + + task = asyncio.Task( + asyncio.wait([c, c, coro('spam')], loop=self.loop), + loop=self.loop) + + done, pending = self.loop.run_until_complete(task) + + self.assertFalse(pending) + self.assertEqual(set(f.result() for f in done), {'test', 'spam'}) + + def test_wait_errors(self): + self.assertRaises( + ValueError, self.loop.run_until_complete, + asyncio.wait(set(), loop=self.loop)) + + self.assertRaises( + ValueError, self.loop.run_until_complete, + asyncio.wait([asyncio.sleep(10.0, loop=self.loop)], + return_when=-1, loop=self.loop)) + + def test_wait_first_completed(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + when = yield 0 + self.assertAlmostEqual(0.1, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = asyncio.Task(asyncio.sleep(10.0, loop=loop), loop=loop) + b = asyncio.Task(asyncio.sleep(0.1, loop=loop), loop=loop) + task = asyncio.Task( + asyncio.wait([b, a], return_when=asyncio.FIRST_COMPLETED, + loop=loop), + loop=loop) + + done, pending = loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + self.assertFalse(a.done()) + self.assertTrue(b.done()) + self.assertIsNone(b.result()) + self.assertAlmostEqual(0.1, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(asyncio.wait([a, b], loop=loop)) + + def test_wait_really_done(self): + # there is possibility that some tasks in the pending list + # became done but their callbacks haven't all been called yet + + @asyncio.coroutine + def coro1(): + yield + + @asyncio.coroutine + def coro2(): + yield + yield + + a = asyncio.Task(coro1(), loop=self.loop) + b = asyncio.Task(coro2(), loop=self.loop) + task = asyncio.Task( + asyncio.wait([b, a], return_when=asyncio.FIRST_COMPLETED, + loop=self.loop), + loop=self.loop) + + done, pending = self.loop.run_until_complete(task) + self.assertEqual({a, b}, done) + self.assertTrue(a.done()) + self.assertIsNone(a.result()) + self.assertTrue(b.done()) + self.assertIsNone(b.result()) + + def test_wait_first_exception(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + # first_exception, task already has exception + a = asyncio.Task(asyncio.sleep(10.0, loop=loop), loop=loop) + + @asyncio.coroutine + def exc(): + raise ZeroDivisionError('err') + + b = asyncio.Task(exc(), loop=loop) + task = asyncio.Task( + asyncio.wait([b, a], return_when=asyncio.FIRST_EXCEPTION, + loop=loop), + loop=loop) + + done, pending = loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + self.assertAlmostEqual(0, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(asyncio.wait([a, b], loop=loop)) + + def test_wait_first_exception_in_wait(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + when = yield 0 + self.assertAlmostEqual(0.01, when) + yield 0.01 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + # first_exception, exception during waiting + a = asyncio.Task(asyncio.sleep(10.0, loop=loop), loop=loop) + + @asyncio.coroutine + def exc(): + yield from asyncio.sleep(0.01, loop=loop) + raise ZeroDivisionError('err') + + b = asyncio.Task(exc(), loop=loop) + task = asyncio.wait([b, a], return_when=asyncio.FIRST_EXCEPTION, + loop=loop) + + done, pending = loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + self.assertAlmostEqual(0.01, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(asyncio.wait([a, b], loop=loop)) + + def test_wait_with_exception(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + yield 0.15 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = asyncio.Task(asyncio.sleep(0.1, loop=loop), loop=loop) + + @asyncio.coroutine + def sleeper(): + yield from asyncio.sleep(0.15, loop=loop) + raise ZeroDivisionError('really') + + b = asyncio.Task(sleeper(), loop=loop) + + @asyncio.coroutine + def foo(): + done, pending = yield from asyncio.wait([b, a], loop=loop) + self.assertEqual(len(done), 2) + self.assertEqual(pending, set()) + errors = set(f for f in done if f.exception() is not None) + self.assertEqual(len(errors), 1) + + loop.run_until_complete(asyncio.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + + loop.run_until_complete(asyncio.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + + def test_wait_with_timeout(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + when = yield 0 + self.assertAlmostEqual(0.11, when) + yield 0.11 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = asyncio.Task(asyncio.sleep(0.1, loop=loop), loop=loop) + b = asyncio.Task(asyncio.sleep(0.15, loop=loop), loop=loop) + + @asyncio.coroutine + def foo(): + done, pending = yield from asyncio.wait([b, a], timeout=0.11, + loop=loop) + self.assertEqual(done, set([a])) + self.assertEqual(pending, set([b])) + + loop.run_until_complete(asyncio.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.11, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(asyncio.wait([a, b], loop=loop)) + + def test_wait_concurrent_complete(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + when = yield 0 + self.assertAlmostEqual(0.1, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = asyncio.Task(asyncio.sleep(0.1, loop=loop), loop=loop) + b = asyncio.Task(asyncio.sleep(0.15, loop=loop), loop=loop) + + done, pending = loop.run_until_complete( + asyncio.wait([b, a], timeout=0.1, loop=loop)) + + self.assertEqual(done, set([a])) + self.assertEqual(pending, set([b])) + self.assertAlmostEqual(0.1, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(asyncio.wait([a, b], loop=loop)) + + def test_as_completed(self): + + def gen(): + yield 0 + yield 0 + yield 0.01 + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + completed = set() + time_shifted = False + + @asyncio.coroutine + def sleeper(dt, x): + nonlocal time_shifted + yield from asyncio.sleep(dt, loop=loop) + completed.add(x) + if not time_shifted and 'a' in completed and 'b' in completed: + time_shifted = True + loop.advance_time(0.14) + return x + + a = sleeper(0.01, 'a') + b = sleeper(0.01, 'b') + c = sleeper(0.15, 'c') + + @asyncio.coroutine + def foo(): + values = [] + for f in asyncio.as_completed([b, c, a], loop=loop): + values.append((yield from f)) + return values + + res = loop.run_until_complete(asyncio.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + self.assertTrue('a' in res[:2]) + self.assertTrue('b' in res[:2]) + self.assertEqual(res[2], 'c') + + # Doing it again should take no time and exercise a different path. + res = loop.run_until_complete(asyncio.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + + def test_as_completed_with_timeout(self): + + def gen(): + yield + yield 0 + yield 0 + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = asyncio.sleep(0.1, 'a', loop=loop) + b = asyncio.sleep(0.15, 'b', loop=loop) + + @asyncio.coroutine + def foo(): + values = [] + for f in asyncio.as_completed([a, b], timeout=0.12, loop=loop): + if values: + loop.advance_time(0.02) + try: + v = yield from f + values.append((1, v)) + except asyncio.TimeoutError as exc: + values.append((2, exc)) + return values + + res = loop.run_until_complete(asyncio.Task(foo(), loop=loop)) + self.assertEqual(len(res), 2, res) + self.assertEqual(res[0], (1, 'a')) + self.assertEqual(res[1][0], 2) + self.assertIsInstance(res[1][1], asyncio.TimeoutError) + self.assertAlmostEqual(0.12, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(asyncio.wait([a, b], loop=loop)) + + def test_as_completed_with_unused_timeout(self): + + def gen(): + yield + yield 0 + yield 0.01 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = asyncio.sleep(0.01, 'a', loop=loop) + + @asyncio.coroutine + def foo(): + for f in asyncio.as_completed([a], timeout=1, loop=loop): + v = yield from f + self.assertEqual(v, 'a') + + loop.run_until_complete(asyncio.Task(foo(), loop=loop)) + + def test_as_completed_reverse_wait(self): + + def gen(): + yield 0 + yield 0.05 + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = asyncio.sleep(0.05, 'a', loop=loop) + b = asyncio.sleep(0.10, 'b', loop=loop) + fs = {a, b} + futs = list(asyncio.as_completed(fs, loop=loop)) + self.assertEqual(len(futs), 2) + + x = loop.run_until_complete(futs[1]) + self.assertEqual(x, 'a') + self.assertAlmostEqual(0.05, loop.time()) + loop.advance_time(0.05) + y = loop.run_until_complete(futs[0]) + self.assertEqual(y, 'b') + self.assertAlmostEqual(0.10, loop.time()) + + def test_as_completed_concurrent(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.05, when) + when = yield 0 + self.assertAlmostEqual(0.05, when) + yield 0.05 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = asyncio.sleep(0.05, 'a', loop=loop) + b = asyncio.sleep(0.05, 'b', loop=loop) + fs = {a, b} + futs = list(asyncio.as_completed(fs, loop=loop)) + self.assertEqual(len(futs), 2) + waiter = asyncio.wait(futs, loop=loop) + done, pending = loop.run_until_complete(waiter) + self.assertEqual(set(f.result() for f in done), {'a', 'b'}) + + def test_as_completed_duplicate_coroutines(self): + + @asyncio.coroutine + def coro(s): + return s + + @asyncio.coroutine + def runner(): + result = [] + c = coro('ham') + for f in asyncio.as_completed([c, c, coro('spam')], + loop=self.loop): + result.append((yield from f)) + return result + + fut = asyncio.Task(runner(), loop=self.loop) + self.loop.run_until_complete(fut) + result = fut.result() + self.assertEqual(set(result), {'ham', 'spam'}) + self.assertEqual(len(result), 2) + + def test_sleep(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.05, when) + when = yield 0.05 + self.assertAlmostEqual(0.1, when) + yield 0.05 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + @asyncio.coroutine + def sleeper(dt, arg): + yield from asyncio.sleep(dt/2, loop=loop) + res = yield from asyncio.sleep(dt/2, arg, loop=loop) + return res + + t = asyncio.Task(sleeper(0.1, 'yeah'), loop=loop) + loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'yeah') + self.assertAlmostEqual(0.1, loop.time()) + + def test_sleep_cancel(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + t = asyncio.Task(asyncio.sleep(10.0, 'yeah', loop=loop), + loop=loop) + + handle = None + orig_call_later = loop.call_later + + def call_later(self, delay, callback, *args): + nonlocal handle + handle = orig_call_later(self, delay, callback, *args) + return handle + + loop.call_later = call_later + test_utils.run_briefly(loop) + + self.assertFalse(handle._cancelled) + + t.cancel() + test_utils.run_briefly(loop) + self.assertTrue(handle._cancelled) + + def test_task_cancel_sleeping_task(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(5000, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + @asyncio.coroutine + def sleep(dt): + yield from asyncio.sleep(dt, loop=loop) + + @asyncio.coroutine + def doit(): + sleeper = asyncio.Task(sleep(5000), loop=loop) + loop.call_later(0.1, sleeper.cancel) + try: + yield from sleeper + except asyncio.CancelledError: + return 'cancelled' + else: + return 'slept in' + + doer = doit() + self.assertEqual(loop.run_until_complete(doer), 'cancelled') + self.assertAlmostEqual(0.1, loop.time()) + + def test_task_cancel_waiter_future(self): + fut = asyncio.Future(loop=self.loop) + + @asyncio.coroutine + def coro(): + yield from fut + + task = asyncio.Task(coro(), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertIs(task._fut_waiter, fut) + + task.cancel() + test_utils.run_briefly(self.loop) + self.assertRaises( + asyncio.CancelledError, self.loop.run_until_complete, task) + self.assertIsNone(task._fut_waiter) + self.assertTrue(fut.cancelled()) + + def test_step_in_completed_task(self): + @asyncio.coroutine + def notmuch(): + return 'ko' + + gen = notmuch() + task = asyncio.Task(gen, loop=self.loop) + task.set_result('ok') + + self.assertRaises(AssertionError, task._step) + gen.close() + + def test_step_result(self): + @asyncio.coroutine + def notmuch(): + yield None + yield 1 + return 'ko' + + self.assertRaises( + RuntimeError, self.loop.run_until_complete, notmuch()) + + def test_step_result_future(self): + # If coroutine returns future, task waits on this future. + + class Fut(asyncio.Future): + def __init__(self, *args, **kwds): + self.cb_added = False + super().__init__(*args, **kwds) + + def add_done_callback(self, fn): + self.cb_added = True + super().add_done_callback(fn) + + fut = Fut(loop=self.loop) + result = None + + @asyncio.coroutine + def wait_for_future(): + nonlocal result + result = yield from fut + + t = asyncio.Task(wait_for_future(), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertTrue(fut.cb_added) + + res = object() + fut.set_result(res) + test_utils.run_briefly(self.loop) + self.assertIs(res, result) + self.assertTrue(t.done()) + self.assertIsNone(t.result()) + + def test_step_with_baseexception(self): + @asyncio.coroutine + def notmutch(): + raise BaseException() + + task = asyncio.Task(notmutch(), loop=self.loop) + self.assertRaises(BaseException, task._step) + + self.assertTrue(task.done()) + self.assertIsInstance(task.exception(), BaseException) + + def test_baseexception_during_cancel(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + @asyncio.coroutine + def sleeper(): + yield from asyncio.sleep(10, loop=loop) + + base_exc = BaseException() + + @asyncio.coroutine + def notmutch(): + try: + yield from sleeper() + except asyncio.CancelledError: + raise base_exc + + task = asyncio.Task(notmutch(), loop=loop) + test_utils.run_briefly(loop) + + task.cancel() + self.assertFalse(task.done()) + + self.assertRaises(BaseException, test_utils.run_briefly, loop) + + self.assertTrue(task.done()) + self.assertFalse(task.cancelled()) + self.assertIs(task.exception(), base_exc) + + def test_iscoroutinefunction(self): + def fn(): + pass + + self.assertFalse(asyncio.iscoroutinefunction(fn)) + + def fn1(): + yield + self.assertFalse(asyncio.iscoroutinefunction(fn1)) + + @asyncio.coroutine + def fn2(): + yield + self.assertTrue(asyncio.iscoroutinefunction(fn2)) + + def test_yield_vs_yield_from(self): + fut = asyncio.Future(loop=self.loop) + + @asyncio.coroutine + def wait_for_future(): + yield fut + + task = wait_for_future() + with self.assertRaises(RuntimeError): + self.loop.run_until_complete(task) + + self.assertFalse(fut.done()) + + def test_yield_vs_yield_from_generator(self): + @asyncio.coroutine + def coro(): + yield + + @asyncio.coroutine + def wait_for_future(): + gen = coro() + try: + yield gen + finally: + gen.close() + + task = wait_for_future() + self.assertRaises( + RuntimeError, + self.loop.run_until_complete, task) + + def test_coroutine_non_gen_function(self): + @asyncio.coroutine + def func(): + return 'test' + + self.assertTrue(asyncio.iscoroutinefunction(func)) + + coro = func() + self.assertTrue(asyncio.iscoroutine(coro)) + + res = self.loop.run_until_complete(coro) + self.assertEqual(res, 'test') + + def test_coroutine_non_gen_function_return_future(self): + fut = asyncio.Future(loop=self.loop) + + @asyncio.coroutine + def func(): + return fut + + @asyncio.coroutine + def coro(): + fut.set_result('test') + + t1 = asyncio.Task(func(), loop=self.loop) + t2 = asyncio.Task(coro(), loop=self.loop) + res = self.loop.run_until_complete(t1) + self.assertEqual(res, 'test') + self.assertIsNone(t2.result()) + + def test_current_task(self): + self.assertIsNone(asyncio.Task.current_task(loop=self.loop)) + + @asyncio.coroutine + def coro(loop): + self.assertTrue(asyncio.Task.current_task(loop=loop) is task) + + task = asyncio.Task(coro(self.loop), loop=self.loop) + self.loop.run_until_complete(task) + self.assertIsNone(asyncio.Task.current_task(loop=self.loop)) + + def test_current_task_with_interleaving_tasks(self): + self.assertIsNone(asyncio.Task.current_task(loop=self.loop)) + + fut1 = asyncio.Future(loop=self.loop) + fut2 = asyncio.Future(loop=self.loop) + + @asyncio.coroutine + def coro1(loop): + self.assertTrue(asyncio.Task.current_task(loop=loop) is task1) + yield from fut1 + self.assertTrue(asyncio.Task.current_task(loop=loop) is task1) + fut2.set_result(True) + + @asyncio.coroutine + def coro2(loop): + self.assertTrue(asyncio.Task.current_task(loop=loop) is task2) + fut1.set_result(True) + yield from fut2 + self.assertTrue(asyncio.Task.current_task(loop=loop) is task2) + + task1 = asyncio.Task(coro1(self.loop), loop=self.loop) + task2 = asyncio.Task(coro2(self.loop), loop=self.loop) + + self.loop.run_until_complete(asyncio.wait((task1, task2), + loop=self.loop)) + self.assertIsNone(asyncio.Task.current_task(loop=self.loop)) + + # Some thorough tests for cancellation propagation through + # coroutines, tasks and wait(). + + def test_yield_future_passes_cancel(self): + # Cancelling outer() cancels inner() cancels waiter. + proof = 0 + waiter = asyncio.Future(loop=self.loop) + + @asyncio.coroutine + def inner(): + nonlocal proof + try: + yield from waiter + except asyncio.CancelledError: + proof += 1 + raise + else: + self.fail('got past sleep() in inner()') + + @asyncio.coroutine + def outer(): + nonlocal proof + try: + yield from inner() + except asyncio.CancelledError: + proof += 100 # Expect this path. + else: + proof += 10 + + f = asyncio.async(outer(), loop=self.loop) + test_utils.run_briefly(self.loop) + f.cancel() + self.loop.run_until_complete(f) + self.assertEqual(proof, 101) + self.assertTrue(waiter.cancelled()) + + def test_yield_wait_does_not_shield_cancel(self): + # Cancelling outer() makes wait() return early, leaves inner() + # running. + proof = 0 + waiter = asyncio.Future(loop=self.loop) + + @asyncio.coroutine + def inner(): + nonlocal proof + yield from waiter + proof += 1 + + @asyncio.coroutine + def outer(): + nonlocal proof + d, p = yield from asyncio.wait([inner()], loop=self.loop) + proof += 100 + + f = asyncio.async(outer(), loop=self.loop) + test_utils.run_briefly(self.loop) + f.cancel() + self.assertRaises( + asyncio.CancelledError, self.loop.run_until_complete, f) + waiter.set_result(None) + test_utils.run_briefly(self.loop) + self.assertEqual(proof, 1) + + def test_shield_result(self): + inner = asyncio.Future(loop=self.loop) + outer = asyncio.shield(inner) + inner.set_result(42) + res = self.loop.run_until_complete(outer) + self.assertEqual(res, 42) + + def test_shield_exception(self): + inner = asyncio.Future(loop=self.loop) + outer = asyncio.shield(inner) + test_utils.run_briefly(self.loop) + exc = RuntimeError('expected') + inner.set_exception(exc) + test_utils.run_briefly(self.loop) + self.assertIs(outer.exception(), exc) + + def test_shield_cancel(self): + inner = asyncio.Future(loop=self.loop) + outer = asyncio.shield(inner) + test_utils.run_briefly(self.loop) + inner.cancel() + test_utils.run_briefly(self.loop) + self.assertTrue(outer.cancelled()) + + def test_shield_shortcut(self): + fut = asyncio.Future(loop=self.loop) + fut.set_result(42) + res = self.loop.run_until_complete(asyncio.shield(fut)) + self.assertEqual(res, 42) + + def test_shield_effect(self): + # Cancelling outer() does not affect inner(). + proof = 0 + waiter = asyncio.Future(loop=self.loop) + + @asyncio.coroutine + def inner(): + nonlocal proof + yield from waiter + proof += 1 + + @asyncio.coroutine + def outer(): + nonlocal proof + yield from asyncio.shield(inner(), loop=self.loop) + proof += 100 + + f = asyncio.async(outer(), loop=self.loop) + test_utils.run_briefly(self.loop) + f.cancel() + with self.assertRaises(asyncio.CancelledError): + self.loop.run_until_complete(f) + waiter.set_result(None) + test_utils.run_briefly(self.loop) + self.assertEqual(proof, 1) + + def test_shield_gather(self): + child1 = asyncio.Future(loop=self.loop) + child2 = asyncio.Future(loop=self.loop) + parent = asyncio.gather(child1, child2, loop=self.loop) + outer = asyncio.shield(parent, loop=self.loop) + test_utils.run_briefly(self.loop) + outer.cancel() + test_utils.run_briefly(self.loop) + self.assertTrue(outer.cancelled()) + child1.set_result(1) + child2.set_result(2) + test_utils.run_briefly(self.loop) + self.assertEqual(parent.result(), [1, 2]) + + def test_gather_shield(self): + child1 = asyncio.Future(loop=self.loop) + child2 = asyncio.Future(loop=self.loop) + inner1 = asyncio.shield(child1, loop=self.loop) + inner2 = asyncio.shield(child2, loop=self.loop) + parent = asyncio.gather(inner1, inner2, loop=self.loop) + test_utils.run_briefly(self.loop) + parent.cancel() + # This should cancel inner1 and inner2 but bot child1 and child2. + test_utils.run_briefly(self.loop) + self.assertIsInstance(parent.exception(), asyncio.CancelledError) + self.assertTrue(inner1.cancelled()) + self.assertTrue(inner2.cancelled()) + child1.set_result(1) + child2.set_result(2) + test_utils.run_briefly(self.loop) + + def test_as_completed_invalid_args(self): + fut = asyncio.Future(loop=self.loop) + + # as_completed() expects a list of futures, not a future instance + self.assertRaises(TypeError, self.loop.run_until_complete, + asyncio.as_completed(fut, loop=self.loop)) + self.assertRaises(TypeError, self.loop.run_until_complete, + asyncio.as_completed(coroutine_function(), loop=self.loop)) + + def test_wait_invalid_args(self): + fut = asyncio.Future(loop=self.loop) + + # wait() expects a list of futures, not a future instance + self.assertRaises(TypeError, self.loop.run_until_complete, + asyncio.wait(fut, loop=self.loop)) + self.assertRaises(TypeError, self.loop.run_until_complete, + asyncio.wait(coroutine_function(), loop=self.loop)) + + # wait() expects at least a future + self.assertRaises(ValueError, self.loop.run_until_complete, + asyncio.wait([], loop=self.loop)) + +class GatherTestsBase: + + def setUp(self): + self.one_loop = test_utils.TestLoop() + self.other_loop = test_utils.TestLoop() + + def tearDown(self): + self.one_loop.close() + self.other_loop.close() + + def _run_loop(self, loop): + while loop._ready: + test_utils.run_briefly(loop) + + def _check_success(self, **kwargs): + a, b, c = [asyncio.Future(loop=self.one_loop) for i in range(3)] + fut = asyncio.gather(*self.wrap_futures(a, b, c), **kwargs) + cb = test_utils.MockCallback() + fut.add_done_callback(cb) + b.set_result(1) + a.set_result(2) + self._run_loop(self.one_loop) + self.assertEqual(cb.called, False) + self.assertFalse(fut.done()) + c.set_result(3) + self._run_loop(self.one_loop) + cb.assert_called_once_with(fut) + self.assertEqual(fut.result(), [2, 1, 3]) + + def test_success(self): + self._check_success() + self._check_success(return_exceptions=False) + + def test_result_exception_success(self): + self._check_success(return_exceptions=True) + + def test_one_exception(self): + a, b, c, d, e = [asyncio.Future(loop=self.one_loop) for i in range(5)] + fut = asyncio.gather(*self.wrap_futures(a, b, c, d, e)) + cb = test_utils.MockCallback() + fut.add_done_callback(cb) + exc = ZeroDivisionError() + a.set_result(1) + b.set_exception(exc) + self._run_loop(self.one_loop) + self.assertTrue(fut.done()) + cb.assert_called_once_with(fut) + self.assertIs(fut.exception(), exc) + # Does nothing + c.set_result(3) + d.cancel() + e.set_exception(RuntimeError()) + e.exception() + + def test_return_exceptions(self): + a, b, c, d = [asyncio.Future(loop=self.one_loop) for i in range(4)] + fut = asyncio.gather(*self.wrap_futures(a, b, c, d), + return_exceptions=True) + cb = test_utils.MockCallback() + fut.add_done_callback(cb) + exc = ZeroDivisionError() + exc2 = RuntimeError() + b.set_result(1) + c.set_exception(exc) + a.set_result(3) + self._run_loop(self.one_loop) + self.assertFalse(fut.done()) + d.set_exception(exc2) + self._run_loop(self.one_loop) + self.assertTrue(fut.done()) + cb.assert_called_once_with(fut) + self.assertEqual(fut.result(), [3, 1, exc, exc2]) + + def test_env_var_debug(self): + path = os.path.dirname(asyncio.__file__) + path = os.path.normpath(os.path.join(path, '..')) + code = '\n'.join(( + 'import sys', + 'sys.path.insert(0, %r)' % path, + 'import asyncio.tasks', + 'print(asyncio.tasks._DEBUG)')) + + # Test with -E to not fail if the unit test was run with + # PYTHONASYNCIODEBUG set to a non-empty string + sts, stdout, stderr = assert_python_ok('-E', '-c', code) + self.assertEqual(stdout.rstrip(), b'False') + + sts, stdout, stderr = assert_python_ok('-c', code, + PYTHONASYNCIODEBUG='') + self.assertEqual(stdout.rstrip(), b'False') + + sts, stdout, stderr = assert_python_ok('-c', code, + PYTHONASYNCIODEBUG='1') + self.assertEqual(stdout.rstrip(), b'True') + + sts, stdout, stderr = assert_python_ok('-E', '-c', code, + PYTHONASYNCIODEBUG='1') + self.assertEqual(stdout.rstrip(), b'False') + + +class FutureGatherTests(GatherTestsBase, unittest.TestCase): + + def wrap_futures(self, *futures): + return futures + + def _check_empty_sequence(self, seq_or_iter): + asyncio.set_event_loop(self.one_loop) + self.addCleanup(asyncio.set_event_loop, None) + fut = asyncio.gather(*seq_or_iter) + self.assertIsInstance(fut, asyncio.Future) + self.assertIs(fut._loop, self.one_loop) + self._run_loop(self.one_loop) + self.assertTrue(fut.done()) + self.assertEqual(fut.result(), []) + fut = asyncio.gather(*seq_or_iter, loop=self.other_loop) + self.assertIs(fut._loop, self.other_loop) + + def test_constructor_empty_sequence(self): + self._check_empty_sequence([]) + self._check_empty_sequence(()) + self._check_empty_sequence(set()) + self._check_empty_sequence(iter("")) + + def test_constructor_heterogenous_futures(self): + fut1 = asyncio.Future(loop=self.one_loop) + fut2 = asyncio.Future(loop=self.other_loop) + with self.assertRaises(ValueError): + asyncio.gather(fut1, fut2) + with self.assertRaises(ValueError): + asyncio.gather(fut1, loop=self.other_loop) + + def test_constructor_homogenous_futures(self): + children = [asyncio.Future(loop=self.other_loop) for i in range(3)] + fut = asyncio.gather(*children) + self.assertIs(fut._loop, self.other_loop) + self._run_loop(self.other_loop) + self.assertFalse(fut.done()) + fut = asyncio.gather(*children, loop=self.other_loop) + self.assertIs(fut._loop, self.other_loop) + self._run_loop(self.other_loop) + self.assertFalse(fut.done()) + + def test_one_cancellation(self): + a, b, c, d, e = [asyncio.Future(loop=self.one_loop) for i in range(5)] + fut = asyncio.gather(a, b, c, d, e) + cb = test_utils.MockCallback() + fut.add_done_callback(cb) + a.set_result(1) + b.cancel() + self._run_loop(self.one_loop) + self.assertTrue(fut.done()) + cb.assert_called_once_with(fut) + self.assertFalse(fut.cancelled()) + self.assertIsInstance(fut.exception(), asyncio.CancelledError) + # Does nothing + c.set_result(3) + d.cancel() + e.set_exception(RuntimeError()) + e.exception() + + def test_result_exception_one_cancellation(self): + a, b, c, d, e, f = [asyncio.Future(loop=self.one_loop) + for i in range(6)] + fut = asyncio.gather(a, b, c, d, e, f, return_exceptions=True) + cb = test_utils.MockCallback() + fut.add_done_callback(cb) + a.set_result(1) + zde = ZeroDivisionError() + b.set_exception(zde) + c.cancel() + self._run_loop(self.one_loop) + self.assertFalse(fut.done()) + d.set_result(3) + e.cancel() + rte = RuntimeError() + f.set_exception(rte) + res = self.one_loop.run_until_complete(fut) + self.assertIsInstance(res[2], asyncio.CancelledError) + self.assertIsInstance(res[4], asyncio.CancelledError) + res[2] = res[4] = None + self.assertEqual(res, [1, zde, None, 3, None, rte]) + cb.assert_called_once_with(fut) + + +class CoroutineGatherTests(GatherTestsBase, unittest.TestCase): + + def setUp(self): + super().setUp() + asyncio.set_event_loop(self.one_loop) + + def tearDown(self): + asyncio.set_event_loop(None) + super().tearDown() + + def wrap_futures(self, *futures): + coros = [] + for fut in futures: + @asyncio.coroutine + def coro(fut=fut): + return (yield from fut) + coros.append(coro()) + return coros + + def test_constructor_loop_selection(self): + @asyncio.coroutine + def coro(): + return 'abc' + gen1 = coro() + gen2 = coro() + fut = asyncio.gather(gen1, gen2) + self.assertIs(fut._loop, self.one_loop) + gen1.close() + gen2.close() + gen3 = coro() + gen4 = coro() + fut = asyncio.gather(gen3, gen4, loop=self.other_loop) + self.assertIs(fut._loop, self.other_loop) + gen3.close() + gen4.close() + + def test_duplicate_coroutines(self): + @asyncio.coroutine + def coro(s): + return s + c = coro('abc') + fut = asyncio.gather(c, c, coro('def'), c, loop=self.one_loop) + self._run_loop(self.one_loop) + self.assertEqual(fut.result(), ['abc', 'abc', 'def', 'abc']) + + def test_cancellation_broadcast(self): + # Cancelling outer() cancels all children. + proof = 0 + waiter = asyncio.Future(loop=self.one_loop) + + @asyncio.coroutine + def inner(): + nonlocal proof + yield from waiter + proof += 1 + + child1 = asyncio.async(inner(), loop=self.one_loop) + child2 = asyncio.async(inner(), loop=self.one_loop) + gatherer = None + + @asyncio.coroutine + def outer(): + nonlocal proof, gatherer + gatherer = asyncio.gather(child1, child2, loop=self.one_loop) + yield from gatherer + proof += 100 + + f = asyncio.async(outer(), loop=self.one_loop) + test_utils.run_briefly(self.one_loop) + self.assertTrue(f.cancel()) + with self.assertRaises(asyncio.CancelledError): + self.one_loop.run_until_complete(f) + self.assertFalse(gatherer.cancel()) + self.assertTrue(waiter.cancelled()) + self.assertTrue(child1.cancelled()) + self.assertTrue(child2.cancelled()) + test_utils.run_briefly(self.one_loop) + self.assertEqual(proof, 0) + + def test_exception_marking(self): + # Test for the first line marked "Mark exception retrieved." + + @asyncio.coroutine + def inner(f): + yield from f + raise RuntimeError('should not be ignored') + + a = asyncio.Future(loop=self.one_loop) + b = asyncio.Future(loop=self.one_loop) + + @asyncio.coroutine + def outer(): + yield from asyncio.gather(inner(a), inner(b), loop=self.one_loop) + + f = asyncio.async(outer(), loop=self.one_loop) + test_utils.run_briefly(self.one_loop) + a.set_result(None) + test_utils.run_briefly(self.one_loop) + b.set_result(None) + test_utils.run_briefly(self.one_loop) + self.assertIsInstance(f.exception(), RuntimeError) + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_asyncio/test_transports.py b/Lib/test/test_asyncio/test_transports.py new file mode 100644 index 0000000000..cfbdf3e9bc --- /dev/null +++ b/Lib/test/test_asyncio/test_transports.py @@ -0,0 +1,88 @@ +"""Tests for transports.py.""" + +import unittest +from unittest import mock + +import asyncio +from asyncio import transports + + +class TransportTests(unittest.TestCase): + + def test_ctor_extra_is_none(self): + transport = asyncio.Transport() + self.assertEqual(transport._extra, {}) + + def test_get_extra_info(self): + transport = asyncio.Transport({'extra': 'info'}) + self.assertEqual('info', transport.get_extra_info('extra')) + self.assertIsNone(transport.get_extra_info('unknown')) + + default = object() + self.assertIs(default, transport.get_extra_info('unknown', default)) + + def test_writelines(self): + transport = asyncio.Transport() + transport.write = mock.Mock() + + transport.writelines([b'line1', + bytearray(b'line2'), + memoryview(b'line3')]) + self.assertEqual(1, transport.write.call_count) + transport.write.assert_called_with(b'line1line2line3') + + def test_not_implemented(self): + transport = asyncio.Transport() + + self.assertRaises(NotImplementedError, + transport.set_write_buffer_limits) + self.assertRaises(NotImplementedError, transport.get_write_buffer_size) + self.assertRaises(NotImplementedError, transport.write, 'data') + self.assertRaises(NotImplementedError, transport.write_eof) + self.assertRaises(NotImplementedError, transport.can_write_eof) + self.assertRaises(NotImplementedError, transport.pause_reading) + self.assertRaises(NotImplementedError, transport.resume_reading) + self.assertRaises(NotImplementedError, transport.close) + self.assertRaises(NotImplementedError, transport.abort) + + def test_dgram_not_implemented(self): + transport = asyncio.DatagramTransport() + + self.assertRaises(NotImplementedError, transport.sendto, 'data') + self.assertRaises(NotImplementedError, transport.abort) + + def test_subprocess_transport_not_implemented(self): + transport = asyncio.SubprocessTransport() + + self.assertRaises(NotImplementedError, transport.get_pid) + self.assertRaises(NotImplementedError, transport.get_returncode) + self.assertRaises(NotImplementedError, transport.get_pipe_transport, 1) + self.assertRaises(NotImplementedError, transport.send_signal, 1) + self.assertRaises(NotImplementedError, transport.terminate) + self.assertRaises(NotImplementedError, transport.kill) + + def test_flowcontrol_mixin_set_write_limits(self): + + class MyTransport(transports._FlowControlMixin, + transports.Transport): + + def get_write_buffer_size(self): + return 512 + + transport = MyTransport() + transport._protocol = mock.Mock() + + self.assertFalse(transport._protocol_paused) + + with self.assertRaisesRegex(ValueError, 'high.*must be >= low'): + transport.set_write_buffer_limits(high=0, low=1) + + transport.set_write_buffer_limits(high=1024, low=128) + self.assertFalse(transport._protocol_paused) + + transport.set_write_buffer_limits(high=256, low=128) + self.assertTrue(transport._protocol_paused) + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_asyncio/test_unix_events.py b/Lib/test/test_asyncio/test_unix_events.py new file mode 100644 index 0000000000..cc74383952 --- /dev/null +++ b/Lib/test/test_asyncio/test_unix_events.py @@ -0,0 +1,1580 @@ +"""Tests for unix_events.py.""" + +import collections +import gc +import errno +import io +import os +import pprint +import signal +import socket +import stat +import sys +import tempfile +import threading +import unittest +from unittest import mock + +if sys.platform == 'win32': + raise unittest.SkipTest('UNIX only') + + +import asyncio +from asyncio import log +from asyncio import test_utils +from asyncio import unix_events + + +MOCK_ANY = mock.ANY + + +@unittest.skipUnless(signal, 'Signals are not supported') +class SelectorEventLoopSignalTests(unittest.TestCase): + + def setUp(self): + self.loop = asyncio.SelectorEventLoop() + asyncio.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_check_signal(self): + self.assertRaises( + TypeError, self.loop._check_signal, '1') + self.assertRaises( + ValueError, self.loop._check_signal, signal.NSIG + 1) + + def test_handle_signal_no_handler(self): + self.loop._handle_signal(signal.NSIG + 1, ()) + + def test_handle_signal_cancelled_handler(self): + h = asyncio.Handle(mock.Mock(), (), + loop=mock.Mock()) + h.cancel() + self.loop._signal_handlers[signal.NSIG + 1] = h + self.loop.remove_signal_handler = mock.Mock() + self.loop._handle_signal(signal.NSIG + 1, ()) + self.loop.remove_signal_handler.assert_called_with(signal.NSIG + 1) + + @mock.patch('asyncio.unix_events.signal') + def test_add_signal_handler_setup_error(self, m_signal): + m_signal.NSIG = signal.NSIG + m_signal.set_wakeup_fd.side_effect = ValueError + + self.assertRaises( + RuntimeError, + self.loop.add_signal_handler, + signal.SIGINT, lambda: True) + + @mock.patch('asyncio.unix_events.signal') + def test_add_signal_handler(self, m_signal): + m_signal.NSIG = signal.NSIG + + cb = lambda: True + self.loop.add_signal_handler(signal.SIGHUP, cb) + h = self.loop._signal_handlers.get(signal.SIGHUP) + self.assertIsInstance(h, asyncio.Handle) + self.assertEqual(h._callback, cb) + + @mock.patch('asyncio.unix_events.signal') + def test_add_signal_handler_install_error(self, m_signal): + m_signal.NSIG = signal.NSIG + + def set_wakeup_fd(fd): + if fd == -1: + raise ValueError() + m_signal.set_wakeup_fd = set_wakeup_fd + + class Err(OSError): + errno = errno.EFAULT + m_signal.signal.side_effect = Err + + self.assertRaises( + Err, + self.loop.add_signal_handler, + signal.SIGINT, lambda: True) + + @mock.patch('asyncio.unix_events.signal') + @mock.patch('asyncio.base_events.logger') + def test_add_signal_handler_install_error2(self, m_logging, m_signal): + m_signal.NSIG = signal.NSIG + + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + + self.loop._signal_handlers[signal.SIGHUP] = lambda: True + self.assertRaises( + RuntimeError, + self.loop.add_signal_handler, + signal.SIGINT, lambda: True) + self.assertFalse(m_logging.info.called) + self.assertEqual(1, m_signal.set_wakeup_fd.call_count) + + @mock.patch('asyncio.unix_events.signal') + @mock.patch('asyncio.base_events.logger') + def test_add_signal_handler_install_error3(self, m_logging, m_signal): + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + m_signal.NSIG = signal.NSIG + + self.assertRaises( + RuntimeError, + self.loop.add_signal_handler, + signal.SIGINT, lambda: True) + self.assertFalse(m_logging.info.called) + self.assertEqual(2, m_signal.set_wakeup_fd.call_count) + + @mock.patch('asyncio.unix_events.signal') + def test_remove_signal_handler(self, m_signal): + m_signal.NSIG = signal.NSIG + + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) + + self.assertTrue( + self.loop.remove_signal_handler(signal.SIGHUP)) + self.assertTrue(m_signal.set_wakeup_fd.called) + self.assertTrue(m_signal.signal.called) + self.assertEqual( + (signal.SIGHUP, m_signal.SIG_DFL), m_signal.signal.call_args[0]) + + @mock.patch('asyncio.unix_events.signal') + def test_remove_signal_handler_2(self, m_signal): + m_signal.NSIG = signal.NSIG + m_signal.SIGINT = signal.SIGINT + + self.loop.add_signal_handler(signal.SIGINT, lambda: True) + self.loop._signal_handlers[signal.SIGHUP] = object() + m_signal.set_wakeup_fd.reset_mock() + + self.assertTrue( + self.loop.remove_signal_handler(signal.SIGINT)) + self.assertFalse(m_signal.set_wakeup_fd.called) + self.assertTrue(m_signal.signal.called) + self.assertEqual( + (signal.SIGINT, m_signal.default_int_handler), + m_signal.signal.call_args[0]) + + @mock.patch('asyncio.unix_events.signal') + @mock.patch('asyncio.base_events.logger') + def test_remove_signal_handler_cleanup_error(self, m_logging, m_signal): + m_signal.NSIG = signal.NSIG + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) + + m_signal.set_wakeup_fd.side_effect = ValueError + + self.loop.remove_signal_handler(signal.SIGHUP) + self.assertTrue(m_logging.info) + + @mock.patch('asyncio.unix_events.signal') + def test_remove_signal_handler_error(self, m_signal): + m_signal.NSIG = signal.NSIG + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) + + m_signal.signal.side_effect = OSError + + self.assertRaises( + OSError, self.loop.remove_signal_handler, signal.SIGHUP) + + @mock.patch('asyncio.unix_events.signal') + def test_remove_signal_handler_error2(self, m_signal): + m_signal.NSIG = signal.NSIG + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) + + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + + self.assertRaises( + RuntimeError, self.loop.remove_signal_handler, signal.SIGHUP) + + @mock.patch('asyncio.unix_events.signal') + def test_close(self, m_signal): + m_signal.NSIG = signal.NSIG + + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) + self.loop.add_signal_handler(signal.SIGCHLD, lambda: True) + + self.assertEqual(len(self.loop._signal_handlers), 2) + + m_signal.set_wakeup_fd.reset_mock() + + self.loop.close() + + self.assertEqual(len(self.loop._signal_handlers), 0) + m_signal.set_wakeup_fd.assert_called_once_with(-1) + + +@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), + 'UNIX Sockets are not supported') +class SelectorEventLoopUnixSocketTests(unittest.TestCase): + + def setUp(self): + self.loop = asyncio.SelectorEventLoop() + asyncio.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_create_unix_server_existing_path_sock(self): + with test_utils.unix_socket_path() as path: + sock = socket.socket(socket.AF_UNIX) + sock.bind(path) + with sock: + coro = self.loop.create_unix_server(lambda: None, path) + with self.assertRaisesRegex(OSError, + 'Address.*is already in use'): + self.loop.run_until_complete(coro) + + def test_create_unix_server_existing_path_nonsock(self): + with tempfile.NamedTemporaryFile() as file: + coro = self.loop.create_unix_server(lambda: None, file.name) + with self.assertRaisesRegex(OSError, + 'Address.*is already in use'): + self.loop.run_until_complete(coro) + + def test_create_unix_server_ssl_bool(self): + coro = self.loop.create_unix_server(lambda: None, path='spam', + ssl=True) + with self.assertRaisesRegex(TypeError, + 'ssl argument must be an SSLContext'): + self.loop.run_until_complete(coro) + + def test_create_unix_server_nopath_nosock(self): + coro = self.loop.create_unix_server(lambda: None, path=None) + with self.assertRaisesRegex(ValueError, + 'path was not specified, and no sock'): + self.loop.run_until_complete(coro) + + def test_create_unix_server_path_inetsock(self): + sock = socket.socket() + with sock: + coro = self.loop.create_unix_server(lambda: None, path=None, + sock=sock) + with self.assertRaisesRegex(ValueError, + 'A UNIX Domain Socket was expected'): + self.loop.run_until_complete(coro) + + def test_create_unix_connection_path_sock(self): + coro = self.loop.create_unix_connection( + lambda: None, '/dev/null', sock=object()) + with self.assertRaisesRegex(ValueError, 'path and sock can not be'): + self.loop.run_until_complete(coro) + + def test_create_unix_connection_nopath_nosock(self): + coro = self.loop.create_unix_connection( + lambda: None, None) + with self.assertRaisesRegex(ValueError, + 'no path and sock were specified'): + self.loop.run_until_complete(coro) + + def test_create_unix_connection_nossl_serverhost(self): + coro = self.loop.create_unix_connection( + lambda: None, '/dev/null', server_hostname='spam') + with self.assertRaisesRegex(ValueError, + 'server_hostname is only meaningful'): + self.loop.run_until_complete(coro) + + def test_create_unix_connection_ssl_noserverhost(self): + coro = self.loop.create_unix_connection( + lambda: None, '/dev/null', ssl=True) + + with self.assertRaisesRegex( + ValueError, 'you have to pass server_hostname when using ssl'): + + self.loop.run_until_complete(coro) + + +class UnixReadPipeTransportTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + self.protocol = test_utils.make_test_protocol(asyncio.Protocol) + self.pipe = mock.Mock(spec_set=io.RawIOBase) + self.pipe.fileno.return_value = 5 + + fcntl_patcher = mock.patch('fcntl.fcntl') + fcntl_patcher.start() + self.addCleanup(fcntl_patcher.stop) + + fstat_patcher = mock.patch('os.fstat') + m_fstat = fstat_patcher.start() + st = mock.Mock() + st.st_mode = stat.S_IFIFO + m_fstat.return_value = st + self.addCleanup(fstat_patcher.stop) + + def test_ctor(self): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + self.loop.assert_reader(5, tr._read_ready) + test_utils.run_briefly(self.loop) + self.protocol.connection_made.assert_called_with(tr) + + def test_ctor_with_waiter(self): + fut = asyncio.Future(loop=self.loop) + unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol, fut) + test_utils.run_briefly(self.loop) + self.assertIsNone(fut.result()) + + @mock.patch('os.read') + def test__read_ready(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + m_read.return_value = b'data' + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + self.protocol.data_received.assert_called_with(b'data') + + @mock.patch('os.read') + def test__read_ready_eof(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + m_read.return_value = b'' + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + self.assertFalse(self.loop.readers) + test_utils.run_briefly(self.loop) + self.protocol.eof_received.assert_called_with() + self.protocol.connection_lost.assert_called_with(None) + + @mock.patch('os.read') + def test__read_ready_blocked(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + m_read.side_effect = BlockingIOError + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.data_received.called) + + @mock.patch('asyncio.log.logger.error') + @mock.patch('os.read') + def test__read_ready_error(self, m_read, m_logexc): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + err = OSError() + m_read.side_effect = err + tr._close = mock.Mock() + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + tr._close.assert_called_with(err) + m_logexc.assert_called_with( + test_utils.MockPattern( + 'Fatal read error on pipe transport' + '\nprotocol:.*\ntransport:.*'), + exc_info=(OSError, MOCK_ANY, MOCK_ANY)) + + @mock.patch('os.read') + def test_pause_reading(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + m = mock.Mock() + self.loop.add_reader(5, m) + tr.pause_reading() + self.assertFalse(self.loop.readers) + + @mock.patch('os.read') + def test_resume_reading(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + tr.resume_reading() + self.loop.assert_reader(5, tr._read_ready) + + @mock.patch('os.read') + def test_close(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + tr._close = mock.Mock() + tr.close() + tr._close.assert_called_with(None) + + @mock.patch('os.read') + def test_close_already_closing(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + tr._closing = True + tr._close = mock.Mock() + tr.close() + self.assertFalse(tr._close.called) + + @mock.patch('os.read') + def test__close(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + err = object() + tr._close(err) + self.assertTrue(tr._closing) + self.assertFalse(self.loop.readers) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(err) + + def test__call_connection_lost(self): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + err = None + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + self.assertIsNone(tr._protocol) + self.assertEqual(2, sys.getrefcount(self.protocol), + pprint.pformat(gc.get_referrers(self.protocol))) + self.assertIsNone(tr._loop) + self.assertEqual(4, sys.getrefcount(self.loop), + pprint.pformat(gc.get_referrers(self.loop))) + + def test__call_connection_lost_with_err(self): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + err = OSError() + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + self.assertIsNone(tr._protocol) + + self.assertEqual(2, sys.getrefcount(self.protocol), + pprint.pformat(gc.get_referrers(self.protocol))) + self.assertIsNone(tr._loop) + self.assertEqual(4, sys.getrefcount(self.loop), + pprint.pformat(gc.get_referrers(self.loop))) + + +class UnixWritePipeTransportTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + self.protocol = test_utils.make_test_protocol(asyncio.BaseProtocol) + self.pipe = mock.Mock(spec_set=io.RawIOBase) + self.pipe.fileno.return_value = 5 + + fcntl_patcher = mock.patch('fcntl.fcntl') + fcntl_patcher.start() + self.addCleanup(fcntl_patcher.stop) + + fstat_patcher = mock.patch('os.fstat') + m_fstat = fstat_patcher.start() + st = mock.Mock() + st.st_mode = stat.S_IFSOCK + m_fstat.return_value = st + self.addCleanup(fstat_patcher.stop) + + def test_ctor(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + self.loop.assert_reader(5, tr._read_ready) + test_utils.run_briefly(self.loop) + self.protocol.connection_made.assert_called_with(tr) + + def test_ctor_with_waiter(self): + fut = asyncio.Future(loop=self.loop) + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol, fut) + self.loop.assert_reader(5, tr._read_ready) + test_utils.run_briefly(self.loop) + self.assertEqual(None, fut.result()) + + def test_can_write_eof(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + self.assertTrue(tr.can_write_eof()) + + @mock.patch('os.write') + def test_write(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + m_write.return_value = 4 + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + + @mock.patch('os.write') + def test_write_no_data(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + tr.write(b'') + self.assertFalse(m_write.called) + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + + @mock.patch('os.write') + def test_write_partial(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + m_write.return_value = 2 + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'ta'], tr._buffer) + + @mock.patch('os.write') + def test_write_buffer(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'previous'] + tr.write(b'data') + self.assertFalse(m_write.called) + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'previous', b'data'], tr._buffer) + + @mock.patch('os.write') + def test_write_again(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + m_write.side_effect = BlockingIOError() + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'data'], tr._buffer) + + @mock.patch('asyncio.unix_events.logger') + @mock.patch('os.write') + def test_write_err(self, m_write, m_log): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + err = OSError() + m_write.side_effect = err + tr._fatal_error = mock.Mock() + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + tr._fatal_error.assert_called_with( + err, + 'Fatal write error on pipe transport') + self.assertEqual(1, tr._conn_lost) + + tr.write(b'data') + self.assertEqual(2, tr._conn_lost) + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + # This is a bit overspecified. :-( + m_log.warning.assert_called_with( + 'pipe closed by peer or os.write(pipe, data) raised exception.') + + @mock.patch('os.write') + def test_write_close(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + tr._read_ready() # pipe was closed by peer + + tr.write(b'data') + self.assertEqual(tr._conn_lost, 1) + tr.write(b'data') + self.assertEqual(tr._conn_lost, 2) + + def test__read_ready(self): + tr = unix_events._UnixWritePipeTransport(self.loop, self.pipe, + self.protocol) + tr._read_ready() + self.assertFalse(self.loop.readers) + self.assertFalse(self.loop.writers) + self.assertTrue(tr._closing) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + + @mock.patch('os.write') + def test__write_ready(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + m_write.return_value = 4 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + + @mock.patch('os.write') + def test__write_ready_partial(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + m_write.return_value = 3 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'a'], tr._buffer) + + @mock.patch('os.write') + def test__write_ready_again(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + m_write.side_effect = BlockingIOError() + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'data'], tr._buffer) + + @mock.patch('os.write') + def test__write_ready_empty(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + m_write.return_value = 0 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'data'], tr._buffer) + + @mock.patch('asyncio.log.logger.error') + @mock.patch('os.write') + def test__write_ready_err(self, m_write, m_logexc): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + m_write.side_effect = err = OSError() + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertFalse(self.loop.readers) + self.assertEqual([], tr._buffer) + self.assertTrue(tr._closing) + m_logexc.assert_called_with( + test_utils.MockPattern( + 'Fatal write error on pipe transport' + '\nprotocol:.*\ntransport:.*'), + exc_info=(OSError, MOCK_ANY, MOCK_ANY)) + self.assertEqual(1, tr._conn_lost) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(err) + + @mock.patch('os.write') + def test__write_ready_closing(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._closing = True + tr._buffer = [b'da', b'ta'] + m_write.return_value = 4 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertFalse(self.loop.readers) + self.assertEqual([], tr._buffer) + self.protocol.connection_lost.assert_called_with(None) + self.pipe.close.assert_called_with() + + @mock.patch('os.write') + def test_abort(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + self.loop.add_reader(5, tr._read_ready) + tr._buffer = [b'da', b'ta'] + tr.abort() + self.assertFalse(m_write.called) + self.assertFalse(self.loop.readers) + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + self.assertTrue(tr._closing) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + + def test__call_connection_lost(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + err = None + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + self.assertIsNone(tr._protocol) + self.assertEqual(2, sys.getrefcount(self.protocol), + pprint.pformat(gc.get_referrers(self.protocol))) + self.assertIsNone(tr._loop) + self.assertEqual(4, sys.getrefcount(self.loop), + pprint.pformat(gc.get_referrers(self.loop))) + + def test__call_connection_lost_with_err(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + err = OSError() + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + self.assertIsNone(tr._protocol) + self.assertEqual(2, sys.getrefcount(self.protocol), + pprint.pformat(gc.get_referrers(self.protocol))) + self.assertIsNone(tr._loop) + self.assertEqual(4, sys.getrefcount(self.loop), + pprint.pformat(gc.get_referrers(self.loop))) + + def test_close(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + tr.write_eof = mock.Mock() + tr.close() + tr.write_eof.assert_called_with() + + def test_close_closing(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + tr.write_eof = mock.Mock() + tr._closing = True + tr.close() + self.assertFalse(tr.write_eof.called) + + def test_write_eof(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + tr.write_eof() + self.assertTrue(tr._closing) + self.assertFalse(self.loop.readers) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + + def test_write_eof_pending(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + tr._buffer = [b'data'] + tr.write_eof() + self.assertTrue(tr._closing) + self.assertFalse(self.protocol.connection_lost.called) + + +class AbstractChildWatcherTests(unittest.TestCase): + + def test_not_implemented(self): + f = mock.Mock() + watcher = asyncio.AbstractChildWatcher() + self.assertRaises( + NotImplementedError, watcher.add_child_handler, f, f) + self.assertRaises( + NotImplementedError, watcher.remove_child_handler, f) + self.assertRaises( + NotImplementedError, watcher.attach_loop, f) + self.assertRaises( + NotImplementedError, watcher.close) + self.assertRaises( + NotImplementedError, watcher.__enter__) + self.assertRaises( + NotImplementedError, watcher.__exit__, f, f, f) + + +class BaseChildWatcherTests(unittest.TestCase): + + def test_not_implemented(self): + f = mock.Mock() + watcher = unix_events.BaseChildWatcher() + self.assertRaises( + NotImplementedError, watcher._do_waitpid, f) + + +WaitPidMocks = collections.namedtuple("WaitPidMocks", + ("waitpid", + "WIFEXITED", + "WIFSIGNALED", + "WEXITSTATUS", + "WTERMSIG", + )) + + +class ChildWatcherTestsMixin: + + ignore_warnings = mock.patch.object(log.logger, "warning") + + def setUp(self): + self.loop = test_utils.TestLoop() + self.running = False + self.zombies = {} + + with mock.patch.object( + self.loop, "add_signal_handler") as self.m_add_signal_handler: + self.watcher = self.create_watcher() + self.watcher.attach_loop(self.loop) + + def waitpid(self, pid, flags): + if isinstance(self.watcher, asyncio.SafeChildWatcher) or pid != -1: + self.assertGreater(pid, 0) + try: + if pid < 0: + return self.zombies.popitem() + else: + return pid, self.zombies.pop(pid) + except KeyError: + pass + if self.running: + return 0, 0 + else: + raise ChildProcessError() + + def add_zombie(self, pid, returncode): + self.zombies[pid] = returncode + 32768 + + def WIFEXITED(self, status): + return status >= 32768 + + def WIFSIGNALED(self, status): + return 32700 < status < 32768 + + def WEXITSTATUS(self, status): + self.assertTrue(self.WIFEXITED(status)) + return status - 32768 + + def WTERMSIG(self, status): + self.assertTrue(self.WIFSIGNALED(status)) + return 32768 - status + + def test_create_watcher(self): + self.m_add_signal_handler.assert_called_once_with( + signal.SIGCHLD, self.watcher._sig_chld) + + def waitpid_mocks(func): + def wrapped_func(self): + def patch(target, wrapper): + return mock.patch(target, wraps=wrapper, + new_callable=mock.Mock) + + with patch('os.WTERMSIG', self.WTERMSIG) as m_WTERMSIG, \ + patch('os.WEXITSTATUS', self.WEXITSTATUS) as m_WEXITSTATUS, \ + patch('os.WIFSIGNALED', self.WIFSIGNALED) as m_WIFSIGNALED, \ + patch('os.WIFEXITED', self.WIFEXITED) as m_WIFEXITED, \ + patch('os.waitpid', self.waitpid) as m_waitpid: + func(self, WaitPidMocks(m_waitpid, + m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG, + )) + return wrapped_func + + @waitpid_mocks + def test_sigchld(self, m): + # register a child + callback = mock.Mock() + + with self.watcher: + self.running = True + self.watcher.add_child_handler(42, callback, 9, 10, 14) + + self.assertFalse(callback.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + # child is running + self.watcher._sig_chld() + + self.assertFalse(callback.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + # child terminates (returncode 12) + self.running = False + self.add_zombie(42, 12) + self.watcher._sig_chld() + + self.assertTrue(m.WIFEXITED.called) + self.assertTrue(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + callback.assert_called_once_with(42, 12, 9, 10, 14) + + m.WIFSIGNALED.reset_mock() + m.WIFEXITED.reset_mock() + m.WEXITSTATUS.reset_mock() + callback.reset_mock() + + # ensure that the child is effectively reaped + self.add_zombie(42, 13) + with self.ignore_warnings: + self.watcher._sig_chld() + + self.assertFalse(callback.called) + self.assertFalse(m.WTERMSIG.called) + + m.WIFSIGNALED.reset_mock() + m.WIFEXITED.reset_mock() + m.WEXITSTATUS.reset_mock() + + # sigchld called again + self.zombies.clear() + self.watcher._sig_chld() + + self.assertFalse(callback.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + @waitpid_mocks + def test_sigchld_two_children(self, m): + callback1 = mock.Mock() + callback2 = mock.Mock() + + # register child 1 + with self.watcher: + self.running = True + self.watcher.add_child_handler(43, callback1, 7, 8) + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + # register child 2 + with self.watcher: + self.watcher.add_child_handler(44, callback2, 147, 18) + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + # children are running + self.watcher._sig_chld() + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + # child 1 terminates (signal 3) + self.add_zombie(43, -3) + self.watcher._sig_chld() + + callback1.assert_called_once_with(43, -3, 7, 8) + self.assertFalse(callback2.called) + self.assertTrue(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertTrue(m.WTERMSIG.called) + + m.WIFSIGNALED.reset_mock() + m.WIFEXITED.reset_mock() + m.WTERMSIG.reset_mock() + callback1.reset_mock() + + # child 2 still running + self.watcher._sig_chld() + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + # child 2 terminates (code 108) + self.add_zombie(44, 108) + self.running = False + self.watcher._sig_chld() + + callback2.assert_called_once_with(44, 108, 147, 18) + self.assertFalse(callback1.called) + self.assertTrue(m.WIFEXITED.called) + self.assertTrue(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + m.WIFSIGNALED.reset_mock() + m.WIFEXITED.reset_mock() + m.WEXITSTATUS.reset_mock() + callback2.reset_mock() + + # ensure that the children are effectively reaped + self.add_zombie(43, 14) + self.add_zombie(44, 15) + with self.ignore_warnings: + self.watcher._sig_chld() + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(m.WTERMSIG.called) + + m.WIFSIGNALED.reset_mock() + m.WIFEXITED.reset_mock() + m.WEXITSTATUS.reset_mock() + + # sigchld called again + self.zombies.clear() + self.watcher._sig_chld() + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + @waitpid_mocks + def test_sigchld_two_children_terminating_together(self, m): + callback1 = mock.Mock() + callback2 = mock.Mock() + + # register child 1 + with self.watcher: + self.running = True + self.watcher.add_child_handler(45, callback1, 17, 8) + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + # register child 2 + with self.watcher: + self.watcher.add_child_handler(46, callback2, 1147, 18) + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + # children are running + self.watcher._sig_chld() + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + # child 1 terminates (code 78) + # child 2 terminates (signal 5) + self.add_zombie(45, 78) + self.add_zombie(46, -5) + self.running = False + self.watcher._sig_chld() + + callback1.assert_called_once_with(45, 78, 17, 8) + callback2.assert_called_once_with(46, -5, 1147, 18) + self.assertTrue(m.WIFSIGNALED.called) + self.assertTrue(m.WIFEXITED.called) + self.assertTrue(m.WEXITSTATUS.called) + self.assertTrue(m.WTERMSIG.called) + + m.WIFSIGNALED.reset_mock() + m.WIFEXITED.reset_mock() + m.WTERMSIG.reset_mock() + m.WEXITSTATUS.reset_mock() + callback1.reset_mock() + callback2.reset_mock() + + # ensure that the children are effectively reaped + self.add_zombie(45, 14) + self.add_zombie(46, 15) + with self.ignore_warnings: + self.watcher._sig_chld() + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(m.WTERMSIG.called) + + @waitpid_mocks + def test_sigchld_race_condition(self, m): + # register a child + callback = mock.Mock() + + with self.watcher: + # child terminates before being registered + self.add_zombie(50, 4) + self.watcher._sig_chld() + + self.watcher.add_child_handler(50, callback, 1, 12) + + callback.assert_called_once_with(50, 4, 1, 12) + callback.reset_mock() + + # ensure that the child is effectively reaped + self.add_zombie(50, -1) + with self.ignore_warnings: + self.watcher._sig_chld() + + self.assertFalse(callback.called) + + @waitpid_mocks + def test_sigchld_replace_handler(self, m): + callback1 = mock.Mock() + callback2 = mock.Mock() + + # register a child + with self.watcher: + self.running = True + self.watcher.add_child_handler(51, callback1, 19) + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + # register the same child again + with self.watcher: + self.watcher.add_child_handler(51, callback2, 21) + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + # child terminates (signal 8) + self.running = False + self.add_zombie(51, -8) + self.watcher._sig_chld() + + callback2.assert_called_once_with(51, -8, 21) + self.assertFalse(callback1.called) + self.assertTrue(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertTrue(m.WTERMSIG.called) + + m.WIFSIGNALED.reset_mock() + m.WIFEXITED.reset_mock() + m.WTERMSIG.reset_mock() + callback2.reset_mock() + + # ensure that the child is effectively reaped + self.add_zombie(51, 13) + with self.ignore_warnings: + self.watcher._sig_chld() + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(m.WTERMSIG.called) + + @waitpid_mocks + def test_sigchld_remove_handler(self, m): + callback = mock.Mock() + + # register a child + with self.watcher: + self.running = True + self.watcher.add_child_handler(52, callback, 1984) + + self.assertFalse(callback.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + # unregister the child + self.watcher.remove_child_handler(52) + + self.assertFalse(callback.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + # child terminates (code 99) + self.running = False + self.add_zombie(52, 99) + with self.ignore_warnings: + self.watcher._sig_chld() + + self.assertFalse(callback.called) + + @waitpid_mocks + def test_sigchld_unknown_status(self, m): + callback = mock.Mock() + + # register a child + with self.watcher: + self.running = True + self.watcher.add_child_handler(53, callback, -19) + + self.assertFalse(callback.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + # terminate with unknown status + self.zombies[53] = 1178 + self.running = False + self.watcher._sig_chld() + + callback.assert_called_once_with(53, 1178, -19) + self.assertTrue(m.WIFEXITED.called) + self.assertTrue(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + callback.reset_mock() + m.WIFEXITED.reset_mock() + m.WIFSIGNALED.reset_mock() + + # ensure that the child is effectively reaped + self.add_zombie(53, 101) + with self.ignore_warnings: + self.watcher._sig_chld() + + self.assertFalse(callback.called) + + @waitpid_mocks + def test_remove_child_handler(self, m): + callback1 = mock.Mock() + callback2 = mock.Mock() + callback3 = mock.Mock() + + # register children + with self.watcher: + self.running = True + self.watcher.add_child_handler(54, callback1, 1) + self.watcher.add_child_handler(55, callback2, 2) + self.watcher.add_child_handler(56, callback3, 3) + + # remove child handler 1 + self.assertTrue(self.watcher.remove_child_handler(54)) + + # remove child handler 2 multiple times + self.assertTrue(self.watcher.remove_child_handler(55)) + self.assertFalse(self.watcher.remove_child_handler(55)) + self.assertFalse(self.watcher.remove_child_handler(55)) + + # all children terminate + self.add_zombie(54, 0) + self.add_zombie(55, 1) + self.add_zombie(56, 2) + self.running = False + with self.ignore_warnings: + self.watcher._sig_chld() + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + callback3.assert_called_once_with(56, 2, 3) + + @waitpid_mocks + def test_sigchld_unhandled_exception(self, m): + callback = mock.Mock() + + # register a child + with self.watcher: + self.running = True + self.watcher.add_child_handler(57, callback) + + # raise an exception + m.waitpid.side_effect = ValueError + + with mock.patch.object(log.logger, + 'error') as m_error: + + self.assertEqual(self.watcher._sig_chld(), None) + self.assertTrue(m_error.called) + + @waitpid_mocks + def test_sigchld_child_reaped_elsewhere(self, m): + # register a child + callback = mock.Mock() + + with self.watcher: + self.running = True + self.watcher.add_child_handler(58, callback) + + self.assertFalse(callback.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + # child terminates + self.running = False + self.add_zombie(58, 4) + + # waitpid is called elsewhere + os.waitpid(58, os.WNOHANG) + + m.waitpid.reset_mock() + + # sigchld + with self.ignore_warnings: + self.watcher._sig_chld() + + callback.assert_called(m.waitpid) + if isinstance(self.watcher, asyncio.FastChildWatcher): + # here the FastChildWatche enters a deadlock + # (there is no way to prevent it) + self.assertFalse(callback.called) + else: + callback.assert_called_once_with(58, 255) + + @waitpid_mocks + def test_sigchld_unknown_pid_during_registration(self, m): + # register two children + callback1 = mock.Mock() + callback2 = mock.Mock() + + with self.ignore_warnings, self.watcher: + self.running = True + # child 1 terminates + self.add_zombie(591, 7) + # an unknown child terminates + self.add_zombie(593, 17) + + self.watcher._sig_chld() + + self.watcher.add_child_handler(591, callback1) + self.watcher.add_child_handler(592, callback2) + + callback1.assert_called_once_with(591, 7) + self.assertFalse(callback2.called) + + @waitpid_mocks + def test_set_loop(self, m): + # register a child + callback = mock.Mock() + + with self.watcher: + self.running = True + self.watcher.add_child_handler(60, callback) + + # attach a new loop + old_loop = self.loop + self.loop = test_utils.TestLoop() + patch = mock.patch.object + + with patch(old_loop, "remove_signal_handler") as m_old_remove, \ + patch(self.loop, "add_signal_handler") as m_new_add: + + self.watcher.attach_loop(self.loop) + + m_old_remove.assert_called_once_with( + signal.SIGCHLD) + m_new_add.assert_called_once_with( + signal.SIGCHLD, self.watcher._sig_chld) + + # child terminates + self.running = False + self.add_zombie(60, 9) + self.watcher._sig_chld() + + callback.assert_called_once_with(60, 9) + + @waitpid_mocks + def test_set_loop_race_condition(self, m): + # register 3 children + callback1 = mock.Mock() + callback2 = mock.Mock() + callback3 = mock.Mock() + + with self.watcher: + self.running = True + self.watcher.add_child_handler(61, callback1) + self.watcher.add_child_handler(62, callback2) + self.watcher.add_child_handler(622, callback3) + + # detach the loop + old_loop = self.loop + self.loop = None + + with mock.patch.object( + old_loop, "remove_signal_handler") as m_remove_signal_handler: + + self.watcher.attach_loop(None) + + m_remove_signal_handler.assert_called_once_with( + signal.SIGCHLD) + + # child 1 & 2 terminate + self.add_zombie(61, 11) + self.add_zombie(62, -5) + + # SIGCHLD was not caught + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(callback3.called) + + # attach a new loop + self.loop = test_utils.TestLoop() + + with mock.patch.object( + self.loop, "add_signal_handler") as m_add_signal_handler: + + self.watcher.attach_loop(self.loop) + + m_add_signal_handler.assert_called_once_with( + signal.SIGCHLD, self.watcher._sig_chld) + callback1.assert_called_once_with(61, 11) # race condition! + callback2.assert_called_once_with(62, -5) # race condition! + self.assertFalse(callback3.called) + + callback1.reset_mock() + callback2.reset_mock() + + # child 3 terminates + self.running = False + self.add_zombie(622, 19) + self.watcher._sig_chld() + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + callback3.assert_called_once_with(622, 19) + + @waitpid_mocks + def test_close(self, m): + # register two children + callback1 = mock.Mock() + + with self.watcher: + self.running = True + # child 1 terminates + self.add_zombie(63, 9) + # other child terminates + self.add_zombie(65, 18) + self.watcher._sig_chld() + + self.watcher.add_child_handler(63, callback1) + self.watcher.add_child_handler(64, callback1) + + self.assertEqual(len(self.watcher._callbacks), 1) + if isinstance(self.watcher, asyncio.FastChildWatcher): + self.assertEqual(len(self.watcher._zombies), 1) + + with mock.patch.object( + self.loop, + "remove_signal_handler") as m_remove_signal_handler: + + self.watcher.close() + + m_remove_signal_handler.assert_called_once_with( + signal.SIGCHLD) + self.assertFalse(self.watcher._callbacks) + if isinstance(self.watcher, asyncio.FastChildWatcher): + self.assertFalse(self.watcher._zombies) + + +class SafeChildWatcherTests (ChildWatcherTestsMixin, unittest.TestCase): + def create_watcher(self): + return asyncio.SafeChildWatcher() + + +class FastChildWatcherTests (ChildWatcherTestsMixin, unittest.TestCase): + def create_watcher(self): + return asyncio.FastChildWatcher() + + +class PolicyTests(unittest.TestCase): + + def create_policy(self): + return asyncio.DefaultEventLoopPolicy() + + def test_get_child_watcher(self): + policy = self.create_policy() + self.assertIsNone(policy._watcher) + + watcher = policy.get_child_watcher() + self.assertIsInstance(watcher, asyncio.SafeChildWatcher) + + self.assertIs(policy._watcher, watcher) + + self.assertIs(watcher, policy.get_child_watcher()) + self.assertIsNone(watcher._loop) + + def test_get_child_watcher_after_set(self): + policy = self.create_policy() + watcher = asyncio.FastChildWatcher() + + policy.set_child_watcher(watcher) + self.assertIs(policy._watcher, watcher) + self.assertIs(watcher, policy.get_child_watcher()) + + def test_get_child_watcher_with_mainloop_existing(self): + policy = self.create_policy() + loop = policy.get_event_loop() + + self.assertIsNone(policy._watcher) + watcher = policy.get_child_watcher() + + self.assertIsInstance(watcher, asyncio.SafeChildWatcher) + self.assertIs(watcher._loop, loop) + + loop.close() + + def test_get_child_watcher_thread(self): + + def f(): + policy.set_event_loop(policy.new_event_loop()) + + self.assertIsInstance(policy.get_event_loop(), + asyncio.AbstractEventLoop) + watcher = policy.get_child_watcher() + + self.assertIsInstance(watcher, asyncio.SafeChildWatcher) + self.assertIsNone(watcher._loop) + + policy.get_event_loop().close() + + policy = self.create_policy() + + th = threading.Thread(target=f) + th.start() + th.join() + + def test_child_watcher_replace_mainloop_existing(self): + policy = self.create_policy() + loop = policy.get_event_loop() + + watcher = policy.get_child_watcher() + + self.assertIs(watcher._loop, loop) + + new_loop = policy.new_event_loop() + policy.set_event_loop(new_loop) + + self.assertIs(watcher._loop, new_loop) + + policy.set_event_loop(None) + + self.assertIs(watcher._loop, None) + + loop.close() + new_loop.close() + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_asyncio/test_windows_events.py b/Lib/test/test_asyncio/test_windows_events.py new file mode 100644 index 0000000000..f652258639 --- /dev/null +++ b/Lib/test/test_asyncio/test_windows_events.py @@ -0,0 +1,134 @@ +import os +import sys +import unittest + +if sys.platform != 'win32': + raise unittest.SkipTest('Windows only') + +import _winapi + +import asyncio +from asyncio import _overlapped +from asyncio import windows_events + + +class UpperProto(asyncio.Protocol): + def __init__(self): + self.buf = [] + + def connection_made(self, trans): + self.trans = trans + + def data_received(self, data): + self.buf.append(data) + if b'\n' in data: + self.trans.write(b''.join(self.buf).upper()) + self.trans.close() + + +class ProactorTests(unittest.TestCase): + + def setUp(self): + self.loop = asyncio.ProactorEventLoop() + asyncio.set_event_loop(None) + + def tearDown(self): + self.loop.close() + self.loop = None + + def test_close(self): + a, b = self.loop._socketpair() + trans = self.loop._make_socket_transport(a, asyncio.Protocol()) + f = asyncio.async(self.loop.sock_recv(b, 100)) + trans.close() + self.loop.run_until_complete(f) + self.assertEqual(f.result(), b'') + b.close() + + def test_double_bind(self): + ADDRESS = r'\\.\pipe\test_double_bind-%s' % os.getpid() + server1 = windows_events.PipeServer(ADDRESS) + with self.assertRaises(PermissionError): + windows_events.PipeServer(ADDRESS) + server1.close() + + def test_pipe(self): + res = self.loop.run_until_complete(self._test_pipe()) + self.assertEqual(res, 'done') + + def _test_pipe(self): + ADDRESS = r'\\.\pipe\_test_pipe-%s' % os.getpid() + + with self.assertRaises(FileNotFoundError): + yield from self.loop.create_pipe_connection( + asyncio.Protocol, ADDRESS) + + [server] = yield from self.loop.start_serving_pipe( + UpperProto, ADDRESS) + self.assertIsInstance(server, windows_events.PipeServer) + + clients = [] + for i in range(5): + stream_reader = asyncio.StreamReader(loop=self.loop) + protocol = asyncio.StreamReaderProtocol(stream_reader) + trans, proto = yield from self.loop.create_pipe_connection( + lambda: protocol, ADDRESS) + self.assertIsInstance(trans, asyncio.Transport) + self.assertEqual(protocol, proto) + clients.append((stream_reader, trans)) + + for i, (r, w) in enumerate(clients): + w.write('lower-{}\n'.format(i).encode()) + + for i, (r, w) in enumerate(clients): + response = yield from r.readline() + self.assertEqual(response, 'LOWER-{}\n'.format(i).encode()) + w.close() + + server.close() + + with self.assertRaises(FileNotFoundError): + yield from self.loop.create_pipe_connection( + asyncio.Protocol, ADDRESS) + + return 'done' + + def test_wait_for_handle(self): + event = _overlapped.CreateEvent(None, True, False, None) + self.addCleanup(_winapi.CloseHandle, event) + + # Wait for unset event with 0.2s timeout; + # result should be False at timeout + f = self.loop._proactor.wait_for_handle(event, 0.2) + start = self.loop.time() + self.loop.run_until_complete(f) + elapsed = self.loop.time() - start + self.assertFalse(f.result()) + self.assertTrue(0.18 < elapsed < 0.9, elapsed) + + _overlapped.SetEvent(event) + + # Wait for for set event; + # result should be True immediately + f = self.loop._proactor.wait_for_handle(event, 10) + start = self.loop.time() + self.loop.run_until_complete(f) + elapsed = self.loop.time() - start + self.assertTrue(f.result()) + self.assertTrue(0 <= elapsed < 0.1, elapsed) + + _overlapped.ResetEvent(event) + + # Wait for unset event with a cancelled future; + # CancelledError should be raised immediately + f = self.loop._proactor.wait_for_handle(event, 10) + f.cancel() + start = self.loop.time() + with self.assertRaises(asyncio.CancelledError): + self.loop.run_until_complete(f) + elapsed = self.loop.time() - start + self.assertTrue(0 <= elapsed < 0.1, elapsed) + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_asyncio/test_windows_utils.py b/Lib/test/test_asyncio/test_windows_utils.py new file mode 100644 index 0000000000..7616c73e45 --- /dev/null +++ b/Lib/test/test_asyncio/test_windows_utils.py @@ -0,0 +1,141 @@ +"""Tests for window_utils""" + +import sys +import test.support +import unittest +from unittest import mock + +if sys.platform != 'win32': + raise unittest.SkipTest('Windows only') + +import _winapi + +from asyncio import windows_utils +from asyncio import _overlapped + + +class WinsocketpairTests(unittest.TestCase): + + def test_winsocketpair(self): + ssock, csock = windows_utils.socketpair() + + csock.send(b'xxx') + self.assertEqual(b'xxx', ssock.recv(1024)) + + csock.close() + ssock.close() + + @mock.patch('asyncio.windows_utils.socket') + def test_winsocketpair_exc(self, m_socket): + m_socket.socket.return_value.getsockname.return_value = ('', 12345) + m_socket.socket.return_value.accept.return_value = object(), object() + m_socket.socket.return_value.connect.side_effect = OSError() + + self.assertRaises(OSError, windows_utils.socketpair) + + +class PipeTests(unittest.TestCase): + + def test_pipe_overlapped(self): + h1, h2 = windows_utils.pipe(overlapped=(True, True)) + try: + ov1 = _overlapped.Overlapped() + self.assertFalse(ov1.pending) + self.assertEqual(ov1.error, 0) + + ov1.ReadFile(h1, 100) + self.assertTrue(ov1.pending) + self.assertEqual(ov1.error, _winapi.ERROR_IO_PENDING) + ERROR_IO_INCOMPLETE = 996 + try: + ov1.getresult() + except OSError as e: + self.assertEqual(e.winerror, ERROR_IO_INCOMPLETE) + else: + raise RuntimeError('expected ERROR_IO_INCOMPLETE') + + ov2 = _overlapped.Overlapped() + self.assertFalse(ov2.pending) + self.assertEqual(ov2.error, 0) + + ov2.WriteFile(h2, b"hello") + self.assertIn(ov2.error, {0, _winapi.ERROR_IO_PENDING}) + + res = _winapi.WaitForMultipleObjects([ov2.event], False, 100) + self.assertEqual(res, _winapi.WAIT_OBJECT_0) + + self.assertFalse(ov1.pending) + self.assertEqual(ov1.error, ERROR_IO_INCOMPLETE) + self.assertFalse(ov2.pending) + self.assertIn(ov2.error, {0, _winapi.ERROR_IO_PENDING}) + self.assertEqual(ov1.getresult(), b"hello") + finally: + _winapi.CloseHandle(h1) + _winapi.CloseHandle(h2) + + def test_pipe_handle(self): + h, _ = windows_utils.pipe(overlapped=(True, True)) + _winapi.CloseHandle(_) + p = windows_utils.PipeHandle(h) + self.assertEqual(p.fileno(), h) + self.assertEqual(p.handle, h) + + # check garbage collection of p closes handle + del p + test.support.gc_collect() + try: + _winapi.CloseHandle(h) + except OSError as e: + self.assertEqual(e.winerror, 6) # ERROR_INVALID_HANDLE + else: + raise RuntimeError('expected ERROR_INVALID_HANDLE') + + +class PopenTests(unittest.TestCase): + + def test_popen(self): + command = r"""if 1: + import sys + s = sys.stdin.readline() + sys.stdout.write(s.upper()) + sys.stderr.write('stderr') + """ + msg = b"blah\n" + + p = windows_utils.Popen([sys.executable, '-c', command], + stdin=windows_utils.PIPE, + stdout=windows_utils.PIPE, + stderr=windows_utils.PIPE) + + for f in [p.stdin, p.stdout, p.stderr]: + self.assertIsInstance(f, windows_utils.PipeHandle) + + ovin = _overlapped.Overlapped() + ovout = _overlapped.Overlapped() + overr = _overlapped.Overlapped() + + ovin.WriteFile(p.stdin.handle, msg) + ovout.ReadFile(p.stdout.handle, 100) + overr.ReadFile(p.stderr.handle, 100) + + events = [ovin.event, ovout.event, overr.event] + # Super-long timeout for slow buildbots. + res = _winapi.WaitForMultipleObjects(events, True, 10000) + self.assertEqual(res, _winapi.WAIT_OBJECT_0) + self.assertFalse(ovout.pending) + self.assertFalse(overr.pending) + self.assertFalse(ovin.pending) + + self.assertEqual(ovin.getresult(), len(msg)) + out = ovout.getresult().rstrip() + err = overr.getresult().rstrip() + + self.assertGreater(len(out), 0) + self.assertGreater(len(err), 0) + # allow for partial reads... + self.assertTrue(msg.upper().rstrip().startswith(out)) + self.assertTrue(b"stderr".startswith(err)) + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_asyncio/tests.txt b/Lib/test/test_asyncio/tests.txt new file mode 100644 index 0000000000..c69582bc86 --- /dev/null +++ b/Lib/test/test_asyncio/tests.txt @@ -0,0 +1,14 @@ +test_asyncio.test_base_events +test_asyncio.test_events +test_asyncio.test_futures +test_asyncio.test_locks +test_asyncio.test_proactor_events +test_asyncio.test_queues +test_asyncio.test_selector_events +test_asyncio.test_streams +test_asyncio.test_subprocess +test_asyncio.test_tasks +test_asyncio.test_transports +test_asyncio.test_unix_events +test_asyncio.test_windows_events +test_asyncio.test_windows_utils diff --git a/Lib/test/test_asyncore.py b/Lib/test/test_asyncore.py index 5d0632ef26..084d247295 100644 --- a/Lib/test/test_asyncore.py +++ b/Lib/test/test_asyncore.py @@ -19,6 +19,7 @@ try: except ImportError: threading = None +TIMEOUT = 3 HAS_UNIX_SOCKETS = hasattr(socket, 'AF_UNIX') class dummysocket: @@ -395,7 +396,10 @@ class DispatcherWithSendTests(unittest.TestCase): self.assertEqual(cap.getvalue(), data*2) finally: - t.join() + t.join(timeout=TIMEOUT) + if t.is_alive(): + self.fail("join() timed out") + class DispatcherWithSendTests_UsePoll(DispatcherWithSendTests): @@ -740,7 +744,12 @@ class BaseTestAPI: s.create_socket(self.family) self.assertEqual(s.socket.family, self.family) SOCK_NONBLOCK = getattr(socket, 'SOCK_NONBLOCK', 0) - self.assertEqual(s.socket.type, socket.SOCK_STREAM | SOCK_NONBLOCK) + sock_type = socket.SOCK_STREAM | SOCK_NONBLOCK + if hasattr(socket, 'SOCK_CLOEXEC'): + self.assertIn(s.socket.type, + (sock_type | socket.SOCK_CLOEXEC, sock_type)) + else: + self.assertEqual(s.socket.type, sock_type) def test_bind(self): if HAS_UNIX_SOCKETS and self.family == socket.AF_UNIX: @@ -754,7 +763,7 @@ class BaseTestAPI: s2 = asyncore.dispatcher() s2.create_socket(self.family) # EADDRINUSE indicates the socket was correctly bound - self.assertRaises(socket.error, s2.bind, (self.addr[0], port)) + self.assertRaises(OSError, s2.bind, (self.addr[0], port)) def test_set_reuse_addr(self): if HAS_UNIX_SOCKETS and self.family == socket.AF_UNIX: @@ -762,7 +771,7 @@ class BaseTestAPI: sock = socket.socket(self.family) try: sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - except socket.error: + except OSError: unittest.skip("SO_REUSEADDR not supported on this platform") else: # if SO_REUSEADDR succeeded for sock we expect asyncore @@ -787,7 +796,11 @@ class BaseTestAPI: t = threading.Thread(target=lambda: asyncore.loop(timeout=0.1, count=500)) t.start() - self.addCleanup(t.join) + def cleanup(): + t.join(timeout=TIMEOUT) + if t.is_alive(): + self.fail("join() timed out") + self.addCleanup(cleanup) s = socket.socket(self.family, socket.SOCK_STREAM) s.settimeout(.2) @@ -795,7 +808,7 @@ class BaseTestAPI: struct.pack('ii', 1, 0)) try: s.connect(server.address) - except socket.error: + except OSError: pass finally: s.close() diff --git a/Lib/test/test_atexit.py b/Lib/test/test_atexit.py index 30c3b4a07b..70d2f1cf54 100644 --- a/Lib/test/test_atexit.py +++ b/Lib/test/test_atexit.py @@ -23,7 +23,9 @@ def raise1(): def raise2(): raise SystemError -class TestCase(unittest.TestCase): + +class GeneralTest(unittest.TestCase): + def setUp(self): self.save_stdout = sys.stdout self.save_stderr = sys.stderr @@ -141,8 +143,43 @@ class TestCase(unittest.TestCase): self.assertEqual(l, [5]) +class SubinterpreterTest(unittest.TestCase): + + def test_callbacks_leak(self): + # This test shows a leak in refleak mode if atexit doesn't + # take care to free callbacks in its per-subinterpreter module + # state. + n = atexit._ncallbacks() + code = r"""if 1: + import atexit + def f(): + pass + atexit.register(f) + del atexit + """ + ret = support.run_in_subinterp(code) + self.assertEqual(ret, 0) + self.assertEqual(atexit._ncallbacks(), n) + + def test_callbacks_leak_refcycle(self): + # Similar to the above, but with a refcycle through the atexit + # module. + n = atexit._ncallbacks() + code = r"""if 1: + import atexit + def f(): + pass + atexit.register(f) + atexit.__atexit = atexit + """ + ret = support.run_in_subinterp(code) + self.assertEqual(ret, 0) + self.assertEqual(atexit._ncallbacks(), n) + + def test_main(): - support.run_unittest(TestCase) + support.run_unittest(__name__) + if __name__ == "__main__": test_main() diff --git a/Lib/test/test_audioop.py b/Lib/test/test_audioop.py index a92cf874bd..879adea2b8 100644 --- a/Lib/test/test_audioop.py +++ b/Lib/test/test_audioop.py @@ -5,13 +5,18 @@ import unittest def pack(width, data): return b''.join(v.to_bytes(width, sys.byteorder, signed=True) for v in data) -packs = {w: (lambda *data, width=w: pack(width, data)) for w in (1, 2, 4)} -maxvalues = {w: (1 << (8 * w - 1)) - 1 for w in (1, 2, 4)} -minvalues = {w: -1 << (8 * w - 1) for w in (1, 2, 4)} +def unpack(width, data): + return [int.from_bytes(data[i: i + width], sys.byteorder, signed=True) + for i in range(0, len(data), width)] + +packs = {w: (lambda *data, width=w: pack(width, data)) for w in (1, 2, 3, 4)} +maxvalues = {w: (1 << (8 * w - 1)) - 1 for w in (1, 2, 3, 4)} +minvalues = {w: -1 << (8 * w - 1) for w in (1, 2, 3, 4)} datas = { 1: b'\x00\x12\x45\xbb\x7f\x80\xff', 2: packs[2](0, 0x1234, 0x4567, -0x4567, 0x7fff, -0x8000, -1), + 3: packs[3](0, 0x123456, 0x456789, -0x456789, 0x7fffff, -0x800000, -1), 4: packs[4](0, 0x12345678, 0x456789ab, -0x456789ab, 0x7fffffff, -0x80000000, -1), } @@ -19,6 +24,7 @@ datas = { INVALID_DATA = [ (b'abc', 0), (b'abc', 2), + (b'ab', 3), (b'abc', 4), ] @@ -26,8 +32,10 @@ INVALID_DATA = [ class TestAudioop(unittest.TestCase): def test_max(self): - for w in 1, 2, 4: + for w in 1, 2, 3, 4: self.assertEqual(audioop.max(b'', w), 0) + self.assertEqual(audioop.max(bytearray(), w), 0) + self.assertEqual(audioop.max(memoryview(b''), w), 0) p = packs[w] self.assertEqual(audioop.max(p(5), w), 5) self.assertEqual(audioop.max(p(5, -8, -1), w), 8) @@ -36,9 +44,13 @@ class TestAudioop(unittest.TestCase): self.assertEqual(audioop.max(datas[w], w), -minvalues[w]) def test_minmax(self): - for w in 1, 2, 4: + for w in 1, 2, 3, 4: self.assertEqual(audioop.minmax(b'', w), (0x7fffffff, -0x80000000)) + self.assertEqual(audioop.minmax(bytearray(), w), + (0x7fffffff, -0x80000000)) + self.assertEqual(audioop.minmax(memoryview(b''), w), + (0x7fffffff, -0x80000000)) p = packs[w] self.assertEqual(audioop.minmax(p(5), w), (5, 5)) self.assertEqual(audioop.minmax(p(5, -8, -1), w), (-8, 5)) @@ -50,16 +62,20 @@ class TestAudioop(unittest.TestCase): (minvalues[w], maxvalues[w])) def test_maxpp(self): - for w in 1, 2, 4: + for w in 1, 2, 3, 4: self.assertEqual(audioop.maxpp(b'', w), 0) + self.assertEqual(audioop.maxpp(bytearray(), w), 0) + self.assertEqual(audioop.maxpp(memoryview(b''), w), 0) self.assertEqual(audioop.maxpp(packs[w](*range(100)), w), 0) self.assertEqual(audioop.maxpp(packs[w](9, 10, 5, 5, 0, 1), w), 10) self.assertEqual(audioop.maxpp(datas[w], w), maxvalues[w] - minvalues[w]) def test_avg(self): - for w in 1, 2, 4: + for w in 1, 2, 3, 4: self.assertEqual(audioop.avg(b'', w), 0) + self.assertEqual(audioop.avg(bytearray(), w), 0) + self.assertEqual(audioop.avg(memoryview(b''), w), 0) p = packs[w] self.assertEqual(audioop.avg(p(5), w), 5) self .assertEqual(audioop.avg(p(5, 8), w), 6) @@ -74,17 +90,22 @@ class TestAudioop(unittest.TestCase): -0x60000000) def test_avgpp(self): - for w in 1, 2, 4: + for w in 1, 2, 3, 4: self.assertEqual(audioop.avgpp(b'', w), 0) + self.assertEqual(audioop.avgpp(bytearray(), w), 0) + self.assertEqual(audioop.avgpp(memoryview(b''), w), 0) self.assertEqual(audioop.avgpp(packs[w](*range(100)), w), 0) self.assertEqual(audioop.avgpp(packs[w](9, 10, 5, 5, 0, 1), w), 10) self.assertEqual(audioop.avgpp(datas[1], 1), 196) self.assertEqual(audioop.avgpp(datas[2], 2), 50534) + self.assertEqual(audioop.avgpp(datas[3], 3), 12937096) self.assertEqual(audioop.avgpp(datas[4], 4), 3311897002) def test_rms(self): - for w in 1, 2, 4: + for w in 1, 2, 3, 4: self.assertEqual(audioop.rms(b'', w), 0) + self.assertEqual(audioop.rms(bytearray(), w), 0) + self.assertEqual(audioop.rms(memoryview(b''), w), 0) p = packs[w] self.assertEqual(audioop.rms(p(*range(100)), w), 57) self.assertAlmostEqual(audioop.rms(p(maxvalues[w]) * 5, w), @@ -93,11 +114,14 @@ class TestAudioop(unittest.TestCase): -minvalues[w], delta=1) self.assertEqual(audioop.rms(datas[1], 1), 77) self.assertEqual(audioop.rms(datas[2], 2), 20001) + self.assertEqual(audioop.rms(datas[3], 3), 5120523) self.assertEqual(audioop.rms(datas[4], 4), 1310854152) def test_cross(self): - for w in 1, 2, 4: + for w in 1, 2, 3, 4: self.assertEqual(audioop.cross(b'', w), -1) + self.assertEqual(audioop.cross(bytearray(), w), -1) + self.assertEqual(audioop.cross(memoryview(b''), w), -1) p = packs[w] self.assertEqual(audioop.cross(p(0, 1, 2), w), 0) self.assertEqual(audioop.cross(p(1, 2, -3, -4), w), 1) @@ -106,22 +130,29 @@ class TestAudioop(unittest.TestCase): self.assertEqual(audioop.cross(p(minvalues[w], maxvalues[w]), w), 1) def test_add(self): - for w in 1, 2, 4: + for w in 1, 2, 3, 4: self.assertEqual(audioop.add(b'', b'', w), b'') + self.assertEqual(audioop.add(bytearray(), bytearray(), w), b'') + self.assertEqual(audioop.add(memoryview(b''), memoryview(b''), w), b'') self.assertEqual(audioop.add(datas[w], b'\0' * len(datas[w]), w), datas[w]) self.assertEqual(audioop.add(datas[1], datas[1], 1), b'\x00\x24\x7f\x80\x7f\x80\xfe') self.assertEqual(audioop.add(datas[2], datas[2], 2), packs[2](0, 0x2468, 0x7fff, -0x8000, 0x7fff, -0x8000, -2)) + self.assertEqual(audioop.add(datas[3], datas[3], 3), + packs[3](0, 0x2468ac, 0x7fffff, -0x800000, + 0x7fffff, -0x800000, -2)) self.assertEqual(audioop.add(datas[4], datas[4], 4), packs[4](0, 0x2468acf0, 0x7fffffff, -0x80000000, 0x7fffffff, -0x80000000, -2)) def test_bias(self): - for w in 1, 2, 4: + for w in 1, 2, 3, 4: for bias in 0, 1, -1, 127, -128, 0x7fffffff, -0x80000000: self.assertEqual(audioop.bias(b'', w, bias), b'') + self.assertEqual(audioop.bias(bytearray(), w, bias), b'') + self.assertEqual(audioop.bias(memoryview(b''), w, bias), b'') self.assertEqual(audioop.bias(datas[1], 1, 1), b'\x01\x13\x46\xbc\x80\x81\x00') self.assertEqual(audioop.bias(datas[1], 1, -1), @@ -138,6 +169,17 @@ class TestAudioop(unittest.TestCase): packs[2](-1, 0x1233, 0x4566, -0x4568, 0x7ffe, 0x7fff, -2)) self.assertEqual(audioop.bias(datas[2], 2, -0x80000000), datas[2]) + self.assertEqual(audioop.bias(datas[3], 3, 1), + packs[3](1, 0x123457, 0x45678a, -0x456788, + -0x800000, -0x7fffff, 0)) + self.assertEqual(audioop.bias(datas[3], 3, -1), + packs[3](-1, 0x123455, 0x456788, -0x45678a, + 0x7ffffe, 0x7fffff, -2)) + self.assertEqual(audioop.bias(datas[3], 3, 0x7fffffff), + packs[3](-1, 0x123455, 0x456788, -0x45678a, + 0x7ffffe, 0x7fffff, -2)) + self.assertEqual(audioop.bias(datas[3], 3, -0x80000000), + datas[3]) self.assertEqual(audioop.bias(datas[4], 4, 1), packs[4](1, 0x12345679, 0x456789ac, -0x456789aa, -0x80000000, -0x7fffffff, 0)) @@ -152,99 +194,146 @@ class TestAudioop(unittest.TestCase): -1, 0, 0x7fffffff)) def test_lin2lin(self): - for w in 1, 2, 4: + for w in 1, 2, 3, 4: self.assertEqual(audioop.lin2lin(datas[w], w, w), datas[w]) + self.assertEqual(audioop.lin2lin(bytearray(datas[w]), w, w), + datas[w]) + self.assertEqual(audioop.lin2lin(memoryview(datas[w]), w, w), + datas[w]) self.assertEqual(audioop.lin2lin(datas[1], 1, 2), packs[2](0, 0x1200, 0x4500, -0x4500, 0x7f00, -0x8000, -0x100)) + self.assertEqual(audioop.lin2lin(datas[1], 1, 3), + packs[3](0, 0x120000, 0x450000, -0x450000, + 0x7f0000, -0x800000, -0x10000)) self.assertEqual(audioop.lin2lin(datas[1], 1, 4), packs[4](0, 0x12000000, 0x45000000, -0x45000000, 0x7f000000, -0x80000000, -0x1000000)) self.assertEqual(audioop.lin2lin(datas[2], 2, 1), b'\x00\x12\x45\xba\x7f\x80\xff') + self.assertEqual(audioop.lin2lin(datas[2], 2, 3), + packs[3](0, 0x123400, 0x456700, -0x456700, + 0x7fff00, -0x800000, -0x100)) self.assertEqual(audioop.lin2lin(datas[2], 2, 4), packs[4](0, 0x12340000, 0x45670000, -0x45670000, 0x7fff0000, -0x80000000, -0x10000)) + self.assertEqual(audioop.lin2lin(datas[3], 3, 1), + b'\x00\x12\x45\xba\x7f\x80\xff') + self.assertEqual(audioop.lin2lin(datas[3], 3, 2), + packs[2](0, 0x1234, 0x4567, -0x4568, 0x7fff, -0x8000, -1)) + self.assertEqual(audioop.lin2lin(datas[3], 3, 4), + packs[4](0, 0x12345600, 0x45678900, -0x45678900, + 0x7fffff00, -0x80000000, -0x100)) self.assertEqual(audioop.lin2lin(datas[4], 4, 1), b'\x00\x12\x45\xba\x7f\x80\xff') self.assertEqual(audioop.lin2lin(datas[4], 4, 2), packs[2](0, 0x1234, 0x4567, -0x4568, 0x7fff, -0x8000, -1)) + self.assertEqual(audioop.lin2lin(datas[4], 4, 3), + packs[3](0, 0x123456, 0x456789, -0x45678a, + 0x7fffff, -0x800000, -1)) def test_adpcm2lin(self): self.assertEqual(audioop.adpcm2lin(b'\x07\x7f\x7f', 1, None), (b'\x00\x00\x00\xff\x00\xff', (-179, 40))) + self.assertEqual(audioop.adpcm2lin(bytearray(b'\x07\x7f\x7f'), 1, None), + (b'\x00\x00\x00\xff\x00\xff', (-179, 40))) + self.assertEqual(audioop.adpcm2lin(memoryview(b'\x07\x7f\x7f'), 1, None), + (b'\x00\x00\x00\xff\x00\xff', (-179, 40))) self.assertEqual(audioop.adpcm2lin(b'\x07\x7f\x7f', 2, None), (packs[2](0, 0xb, 0x29, -0x16, 0x72, -0xb3), (-179, 40))) + self.assertEqual(audioop.adpcm2lin(b'\x07\x7f\x7f', 3, None), + (packs[3](0, 0xb00, 0x2900, -0x1600, 0x7200, + -0xb300), (-179, 40))) self.assertEqual(audioop.adpcm2lin(b'\x07\x7f\x7f', 4, None), (packs[4](0, 0xb0000, 0x290000, -0x160000, 0x720000, -0xb30000), (-179, 40))) # Very cursory test - for w in 1, 2, 4: + for w in 1, 2, 3, 4: self.assertEqual(audioop.adpcm2lin(b'\0' * 5, w, None), (b'\0' * w * 10, (0, 0))) def test_lin2adpcm(self): self.assertEqual(audioop.lin2adpcm(datas[1], 1, None), (b'\x07\x7f\x7f', (-221, 39))) - self.assertEqual(audioop.lin2adpcm(datas[2], 2, None), - (b'\x07\x7f\x7f', (31, 39))) - self.assertEqual(audioop.lin2adpcm(datas[4], 4, None), - (b'\x07\x7f\x7f', (31, 39))) + self.assertEqual(audioop.lin2adpcm(bytearray(datas[1]), 1, None), + (b'\x07\x7f\x7f', (-221, 39))) + self.assertEqual(audioop.lin2adpcm(memoryview(datas[1]), 1, None), + (b'\x07\x7f\x7f', (-221, 39))) + for w in 2, 3, 4: + self.assertEqual(audioop.lin2adpcm(datas[w], w, None), + (b'\x07\x7f\x7f', (31, 39))) # Very cursory test - for w in 1, 2, 4: + for w in 1, 2, 3, 4: self.assertEqual(audioop.lin2adpcm(b'\0' * w * 10, w, None), (b'\0' * 5, (0, 0))) + def test_invalid_adpcm_state(self): + # state must be a tuple or None, not an integer + self.assertRaises(TypeError, audioop.adpcm2lin, b'\0', 1, 555) + self.assertRaises(TypeError, audioop.lin2adpcm, b'\0', 1, 555) + def test_lin2alaw(self): self.assertEqual(audioop.lin2alaw(datas[1], 1), b'\xd5\x87\xa4\x24\xaa\x2a\x5a') - self.assertEqual(audioop.lin2alaw(datas[2], 2), - b'\xd5\x87\xa4\x24\xaa\x2a\x55') - self.assertEqual(audioop.lin2alaw(datas[4], 4), - b'\xd5\x87\xa4\x24\xaa\x2a\x55') + self.assertEqual(audioop.lin2alaw(bytearray(datas[1]), 1), + b'\xd5\x87\xa4\x24\xaa\x2a\x5a') + self.assertEqual(audioop.lin2alaw(memoryview(datas[1]), 1), + b'\xd5\x87\xa4\x24\xaa\x2a\x5a') + for w in 2, 3, 4: + self.assertEqual(audioop.lin2alaw(datas[w], w), + b'\xd5\x87\xa4\x24\xaa\x2a\x55') def test_alaw2lin(self): encoded = b'\x00\x03\x24\x2a\x51\x54\x55\x58\x6b\x71\x7f'\ b'\x80\x83\xa4\xaa\xd1\xd4\xd5\xd8\xeb\xf1\xff' src = [-688, -720, -2240, -4032, -9, -3, -1, -27, -244, -82, -106, 688, 720, 2240, 4032, 9, 3, 1, 27, 244, 82, 106] - for w in 1, 2, 4: - self.assertEqual(audioop.alaw2lin(encoded, w), - packs[w](*(x << (w * 8) >> 13 for x in src))) + for w in 1, 2, 3, 4: + decoded = packs[w](*(x << (w * 8) >> 13 for x in src)) + self.assertEqual(audioop.alaw2lin(encoded, w), decoded) + self.assertEqual(audioop.alaw2lin(bytearray(encoded), w), decoded) + self.assertEqual(audioop.alaw2lin(memoryview(encoded), w), decoded) encoded = bytes(range(256)) - for w in 2, 4: + for w in 2, 3, 4: decoded = audioop.alaw2lin(encoded, w) self.assertEqual(audioop.lin2alaw(decoded, w), encoded) def test_lin2ulaw(self): self.assertEqual(audioop.lin2ulaw(datas[1], 1), b'\xff\xad\x8e\x0e\x80\x00\x67') - self.assertEqual(audioop.lin2ulaw(datas[2], 2), - b'\xff\xad\x8e\x0e\x80\x00\x7e') - self.assertEqual(audioop.lin2ulaw(datas[4], 4), - b'\xff\xad\x8e\x0e\x80\x00\x7e') + self.assertEqual(audioop.lin2ulaw(bytearray(datas[1]), 1), + b'\xff\xad\x8e\x0e\x80\x00\x67') + self.assertEqual(audioop.lin2ulaw(memoryview(datas[1]), 1), + b'\xff\xad\x8e\x0e\x80\x00\x67') + for w in 2, 3, 4: + self.assertEqual(audioop.lin2ulaw(datas[w], w), + b'\xff\xad\x8e\x0e\x80\x00\x7e') def test_ulaw2lin(self): encoded = b'\x00\x0e\x28\x3f\x57\x6a\x76\x7c\x7e\x7f'\ b'\x80\x8e\xa8\xbf\xd7\xea\xf6\xfc\xfe\xff' src = [-8031, -4447, -1471, -495, -163, -53, -18, -6, -2, 0, 8031, 4447, 1471, 495, 163, 53, 18, 6, 2, 0] - for w in 1, 2, 4: - self.assertEqual(audioop.ulaw2lin(encoded, w), - packs[w](*(x << (w * 8) >> 14 for x in src))) + for w in 1, 2, 3, 4: + decoded = packs[w](*(x << (w * 8) >> 14 for x in src)) + self.assertEqual(audioop.ulaw2lin(encoded, w), decoded) + self.assertEqual(audioop.ulaw2lin(bytearray(encoded), w), decoded) + self.assertEqual(audioop.ulaw2lin(memoryview(encoded), w), decoded) # Current u-law implementation has two codes fo 0: 0x7f and 0xff. encoded = bytes(range(127)) + bytes(range(128, 256)) - for w in 2, 4: + for w in 2, 3, 4: decoded = audioop.ulaw2lin(encoded, w) self.assertEqual(audioop.lin2ulaw(decoded, w), encoded) def test_mul(self): - for w in 1, 2, 4: + for w in 1, 2, 3, 4: self.assertEqual(audioop.mul(b'', w, 2), b'') + self.assertEqual(audioop.mul(bytearray(), w, 2), b'') + self.assertEqual(audioop.mul(memoryview(b''), w, 2), b'') self.assertEqual(audioop.mul(datas[w], w, 0), b'\0' * len(datas[w])) self.assertEqual(audioop.mul(datas[w], w, 1), @@ -253,14 +342,21 @@ class TestAudioop(unittest.TestCase): b'\x00\x24\x7f\x80\x7f\x80\xfe') self.assertEqual(audioop.mul(datas[2], 2, 2), packs[2](0, 0x2468, 0x7fff, -0x8000, 0x7fff, -0x8000, -2)) + self.assertEqual(audioop.mul(datas[3], 3, 2), + packs[3](0, 0x2468ac, 0x7fffff, -0x800000, + 0x7fffff, -0x800000, -2)) self.assertEqual(audioop.mul(datas[4], 4, 2), packs[4](0, 0x2468acf0, 0x7fffffff, -0x80000000, 0x7fffffff, -0x80000000, -2)) def test_ratecv(self): - for w in 1, 2, 4: + for w in 1, 2, 3, 4: self.assertEqual(audioop.ratecv(b'', w, 1, 8000, 8000, None), (b'', (-1, ((0, 0),)))) + self.assertEqual(audioop.ratecv(bytearray(), w, 1, 8000, 8000, None), + (b'', (-1, ((0, 0),)))) + self.assertEqual(audioop.ratecv(memoryview(b''), w, 1, 8000, 8000, None), + (b'', (-1, ((0, 0),)))) self.assertEqual(audioop.ratecv(b'', w, 5, 8000, 8000, None), (b'', (-1, ((0, 0),) * 5))) self.assertEqual(audioop.ratecv(b'', w, 1, 8000, 16000, None), @@ -272,7 +368,7 @@ class TestAudioop(unittest.TestCase): d2, state = audioop.ratecv(b'\x00\x01\x02', 1, 1, 8000, 16000, state) self.assertEqual(d1 + d2, b'\000\000\001\001\002\001\000\000\001\001\002') - for w in 1, 2, 4: + for w in 1, 2, 3, 4: d0, state0 = audioop.ratecv(datas[w], w, 1, 8000, 16000, None) d, state = b'', None for i in range(0, len(datas[w]), w): @@ -283,13 +379,15 @@ class TestAudioop(unittest.TestCase): self.assertEqual(state, state0) def test_reverse(self): - for w in 1, 2, 4: + for w in 1, 2, 3, 4: self.assertEqual(audioop.reverse(b'', w), b'') + self.assertEqual(audioop.reverse(bytearray(), w), b'') + self.assertEqual(audioop.reverse(memoryview(b''), w), b'') self.assertEqual(audioop.reverse(packs[w](0, 1, 2), w), packs[w](2, 1, 0)) def test_tomono(self): - for w in 1, 2, 4: + for w in 1, 2, 3, 4: data1 = datas[w] data2 = bytearray(2 * len(data1)) for k in range(w): @@ -299,9 +397,13 @@ class TestAudioop(unittest.TestCase): for k in range(w): data2[k+w::2*w] = data1[k::w] self.assertEqual(audioop.tomono(data2, w, 0.5, 0.5), data1) + self.assertEqual(audioop.tomono(bytearray(data2), w, 0.5, 0.5), + data1) + self.assertEqual(audioop.tomono(memoryview(data2), w, 0.5, 0.5), + data1) def test_tostereo(self): - for w in 1, 2, 4: + for w in 1, 2, 3, 4: data1 = datas[w] data2 = bytearray(2 * len(data1)) for k in range(w): @@ -311,14 +413,25 @@ class TestAudioop(unittest.TestCase): for k in range(w): data2[k+w::2*w] = data1[k::w] self.assertEqual(audioop.tostereo(data1, w, 1, 1), data2) + self.assertEqual(audioop.tostereo(bytearray(data1), w, 1, 1), data2) + self.assertEqual(audioop.tostereo(memoryview(data1), w, 1, 1), + data2) def test_findfactor(self): self.assertEqual(audioop.findfactor(datas[2], datas[2]), 1.0) + self.assertEqual(audioop.findfactor(bytearray(datas[2]), + bytearray(datas[2])), 1.0) + self.assertEqual(audioop.findfactor(memoryview(datas[2]), + memoryview(datas[2])), 1.0) self.assertEqual(audioop.findfactor(b'\0' * len(datas[2]), datas[2]), 0.0) def test_findfit(self): self.assertEqual(audioop.findfit(datas[2], datas[2]), (0, 1.0)) + self.assertEqual(audioop.findfit(bytearray(datas[2]), + bytearray(datas[2])), (0, 1.0)) + self.assertEqual(audioop.findfit(memoryview(datas[2]), + memoryview(datas[2])), (0, 1.0)) self.assertEqual(audioop.findfit(datas[2], packs[2](1, 2, 0)), (1, 8038.8)) self.assertEqual(audioop.findfit(datas[2][:-2] * 5 + datas[2], datas[2]), @@ -326,16 +439,37 @@ class TestAudioop(unittest.TestCase): def test_findmax(self): self.assertEqual(audioop.findmax(datas[2], 1), 5) + self.assertEqual(audioop.findmax(bytearray(datas[2]), 1), 5) + self.assertEqual(audioop.findmax(memoryview(datas[2]), 1), 5) def test_getsample(self): - for w in 1, 2, 4: + for w in 1, 2, 3, 4: data = packs[w](0, 1, -1, maxvalues[w], minvalues[w]) self.assertEqual(audioop.getsample(data, w, 0), 0) + self.assertEqual(audioop.getsample(bytearray(data), w, 0), 0) + self.assertEqual(audioop.getsample(memoryview(data), w, 0), 0) self.assertEqual(audioop.getsample(data, w, 1), 1) self.assertEqual(audioop.getsample(data, w, 2), -1) self.assertEqual(audioop.getsample(data, w, 3), maxvalues[w]) self.assertEqual(audioop.getsample(data, w, 4), minvalues[w]) + def test_byteswap(self): + swapped_datas = { + 1: datas[1], + 2: packs[2](0, 0x3412, 0x6745, -0x6646, -0x81, 0x80, -1), + 3: packs[3](0, 0x563412, -0x7698bb, 0x7798ba, -0x81, 0x80, -1), + 4: packs[4](0, 0x78563412, -0x547698bb, 0x557698ba, + -0x81, 0x80, -1), + } + for w in 1, 2, 3, 4: + self.assertEqual(audioop.byteswap(b'', w), b'') + self.assertEqual(audioop.byteswap(datas[w], w), swapped_datas[w]) + self.assertEqual(audioop.byteswap(swapped_datas[w], w), datas[w]) + self.assertEqual(audioop.byteswap(bytearray(datas[w]), w), + swapped_datas[w]) + self.assertEqual(audioop.byteswap(memoryview(datas[w]), w), + swapped_datas[w]) + def test_negativelen(self): # from issue 3306, previously it segfaulted self.assertRaises(audioop.error, @@ -365,10 +499,33 @@ class TestAudioop(unittest.TestCase): self.assertRaises(audioop.error, audioop.lin2alaw, data, size) self.assertRaises(audioop.error, audioop.lin2adpcm, data, size, state) + def test_string(self): + data = 'abcd' + size = 2 + self.assertRaises(TypeError, audioop.getsample, data, size, 0) + self.assertRaises(TypeError, audioop.max, data, size) + self.assertRaises(TypeError, audioop.minmax, data, size) + self.assertRaises(TypeError, audioop.avg, data, size) + self.assertRaises(TypeError, audioop.rms, data, size) + self.assertRaises(TypeError, audioop.avgpp, data, size) + self.assertRaises(TypeError, audioop.maxpp, data, size) + self.assertRaises(TypeError, audioop.cross, data, size) + self.assertRaises(TypeError, audioop.mul, data, size, 1.0) + self.assertRaises(TypeError, audioop.tomono, data, size, 0.5, 0.5) + self.assertRaises(TypeError, audioop.tostereo, data, size, 0.5, 0.5) + self.assertRaises(TypeError, audioop.add, data, data, size) + self.assertRaises(TypeError, audioop.bias, data, size, 0) + self.assertRaises(TypeError, audioop.reverse, data, size) + self.assertRaises(TypeError, audioop.lin2lin, data, size, size) + self.assertRaises(TypeError, audioop.ratecv, data, size, 1, 1, 1, None) + self.assertRaises(TypeError, audioop.lin2ulaw, data, size) + self.assertRaises(TypeError, audioop.lin2alaw, data, size) + self.assertRaises(TypeError, audioop.lin2adpcm, data, size, None) + def test_wrongsize(self): data = b'abcdefgh' state = None - for size in (-1, 0, 3, 5, 1024): + for size in (-1, 0, 5, 1024): self.assertRaises(audioop.error, audioop.ulaw2lin, data, size) self.assertRaises(audioop.error, audioop.alaw2lin, data, size) self.assertRaises(audioop.error, audioop.adpcm2lin, data, size, state) diff --git a/Lib/test/test_base64.py b/Lib/test/test_base64.py index 13695de67e..b9738ddaf5 100644 --- a/Lib/test/test_base64.py +++ b/Lib/test/test_base64.py @@ -5,10 +5,21 @@ import binascii import os import sys import subprocess - +import struct +from array import array class LegacyBase64TestCase(unittest.TestCase): + + # Legacy API is not as permissive as the modern API + def check_type_errors(self, f): + self.assertRaises(TypeError, f, "") + self.assertRaises(TypeError, f, []) + multidimensional = memoryview(b"1234").cast('B', (2, 2)) + self.assertRaises(TypeError, f, multidimensional) + int_data = memoryview(b"1234").cast('I') + self.assertRaises(TypeError, f, int_data) + def test_encodebytes(self): eq = self.assertEqual eq(base64.encodebytes(b"www.python.org"), b"d3d3LnB5dGhvbi5vcmc=\n") @@ -24,7 +35,9 @@ class LegacyBase64TestCase(unittest.TestCase): b"Y3ODkhQCMwXiYqKCk7Ojw+LC4gW117fQ==\n") # Non-bytes eq(base64.encodebytes(bytearray(b'abc')), b'YWJj\n') - self.assertRaises(TypeError, base64.encodebytes, "") + eq(base64.encodebytes(memoryview(b'abc')), b'YWJj\n') + eq(base64.encodebytes(array('B', b'abc')), b'YWJj\n') + self.check_type_errors(base64.encodebytes) def test_decodebytes(self): eq = self.assertEqual @@ -41,7 +54,9 @@ class LegacyBase64TestCase(unittest.TestCase): eq(base64.decodebytes(b''), b'') # Non-bytes eq(base64.decodebytes(bytearray(b'YWJj\n')), b'abc') - self.assertRaises(TypeError, base64.decodebytes, "") + eq(base64.decodebytes(memoryview(b'YWJj\n')), b'abc') + eq(base64.decodebytes(array('B', b'YWJj\n')), b'abc') + self.check_type_errors(base64.decodebytes) def test_encode(self): eq = self.assertEqual @@ -73,6 +88,42 @@ class LegacyBase64TestCase(unittest.TestCase): class BaseXYTestCase(unittest.TestCase): + + # Modern API completely ignores exported dimension and format data and + # treats any buffer as a stream of bytes + def check_encode_type_errors(self, f): + self.assertRaises(TypeError, f, "") + self.assertRaises(TypeError, f, []) + + def check_decode_type_errors(self, f): + self.assertRaises(TypeError, f, []) + + def check_other_types(self, f, bytes_data, expected): + eq = self.assertEqual + b = bytearray(bytes_data) + eq(f(b), expected) + # The bytearray wasn't mutated + eq(b, bytes_data) + eq(f(memoryview(bytes_data)), expected) + eq(f(array('B', bytes_data)), expected) + # XXX why is b64encode hardcoded here? + self.check_nonbyte_element_format(base64.b64encode, bytes_data) + self.check_multidimensional(base64.b64encode, bytes_data) + + def check_multidimensional(self, f, data): + padding = b"\x00" if len(data) % 2 else b"" + bytes_data = data + padding # Make sure cast works + shape = (len(bytes_data) // 2, 2) + multidimensional = memoryview(bytes_data).cast('B', shape) + self.assertEqual(f(multidimensional), f(bytes_data)) + + def check_nonbyte_element_format(self, f, data): + padding = b"\x00" * ((4 - len(data)) % 4) + bytes_data = data + padding # Make sure cast works + int_data = memoryview(bytes_data).cast('I') + self.assertEqual(f(int_data), f(bytes_data)) + + def test_b64encode(self): eq = self.assertEqual # Test default alphabet @@ -90,13 +141,16 @@ class BaseXYTestCase(unittest.TestCase): b"Y3ODkhQCMwXiYqKCk7Ojw+LC4gW117fQ==") # Test with arbitrary alternative characters eq(base64.b64encode(b'\xd3V\xbeo\xf7\x1d', altchars=b'*$'), b'01a*b$cd') - # Non-bytes - eq(base64.b64encode(bytearray(b'abcd')), b'YWJjZA==') eq(base64.b64encode(b'\xd3V\xbeo\xf7\x1d', altchars=bytearray(b'*$')), b'01a*b$cd') - # Check if passing a str object raises an error - self.assertRaises(TypeError, base64.b64encode, "") - self.assertRaises(TypeError, base64.b64encode, b"", altchars="") + eq(base64.b64encode(b'\xd3V\xbeo\xf7\x1d', altchars=memoryview(b'*$')), + b'01a*b$cd') + eq(base64.b64encode(b'\xd3V\xbeo\xf7\x1d', altchars=array('B', b'*$')), + b'01a*b$cd') + # Non-bytes + self.check_other_types(base64.b64encode, b'abcd', b'YWJjZA==') + self.check_encode_type_errors(base64.b64encode) + self.assertRaises(TypeError, base64.b64encode, b"", altchars="*$") # Test standard alphabet eq(base64.standard_b64encode(b"www.python.org"), b"d3d3LnB5dGhvbi5vcmc=") eq(base64.standard_b64encode(b"a"), b"YQ==") @@ -110,15 +164,15 @@ class BaseXYTestCase(unittest.TestCase): b"RUZHSElKS0xNTk9QUVJTVFVWV1hZWjAxMjM0NT" b"Y3ODkhQCMwXiYqKCk7Ojw+LC4gW117fQ==") # Non-bytes - eq(base64.standard_b64encode(bytearray(b'abcd')), b'YWJjZA==') - # Check if passing a str object raises an error - self.assertRaises(TypeError, base64.standard_b64encode, "") + self.check_other_types(base64.standard_b64encode, + b'abcd', b'YWJjZA==') + self.check_encode_type_errors(base64.standard_b64encode) # Test with 'URL safe' alternative characters eq(base64.urlsafe_b64encode(b'\xd3V\xbeo\xf7\x1d'), b'01a-b_cd') # Non-bytes - eq(base64.urlsafe_b64encode(bytearray(b'\xd3V\xbeo\xf7\x1d')), b'01a-b_cd') - # Check if passing a str object raises an error - self.assertRaises(TypeError, base64.urlsafe_b64encode, "") + self.check_other_types(base64.urlsafe_b64encode, + b'\xd3V\xbeo\xf7\x1d', b'01a-b_cd') + self.check_encode_type_errors(base64.urlsafe_b64encode) def test_b64decode(self): eq = self.assertEqual @@ -141,7 +195,8 @@ class BaseXYTestCase(unittest.TestCase): eq(base64.b64decode(data), res) eq(base64.b64decode(data.decode('ascii')), res) # Non-bytes - eq(base64.b64decode(bytearray(b"YWJj")), b"abc") + self.check_other_types(base64.b64decode, b"YWJj", b"abc") + self.check_decode_type_errors(base64.b64decode) # Test with arbitrary alternative characters tests_altchars = {(b'01a*b$cd', b'*$'): b'\xd3V\xbeo\xf7\x1d', @@ -160,7 +215,8 @@ class BaseXYTestCase(unittest.TestCase): eq(base64.standard_b64decode(data), res) eq(base64.standard_b64decode(data.decode('ascii')), res) # Non-bytes - eq(base64.standard_b64decode(bytearray(b"YWJj")), b"abc") + self.check_other_types(base64.standard_b64decode, b"YWJj", b"abc") + self.check_decode_type_errors(base64.standard_b64decode) # Test with 'URL safe' alternative characters tests_urlsafe = {b'01a-b_cd': b'\xd3V\xbeo\xf7\x1d', @@ -170,7 +226,9 @@ class BaseXYTestCase(unittest.TestCase): eq(base64.urlsafe_b64decode(data), res) eq(base64.urlsafe_b64decode(data.decode('ascii')), res) # Non-bytes - eq(base64.urlsafe_b64decode(bytearray(b'01a-b_cd')), b'\xd3V\xbeo\xf7\x1d') + self.check_other_types(base64.urlsafe_b64decode, b'01a-b_cd', + b'\xd3V\xbeo\xf7\x1d') + self.check_decode_type_errors(base64.urlsafe_b64decode) def test_b64decode_padding_error(self): self.assertRaises(binascii.Error, base64.b64decode, b'abc') @@ -205,8 +263,8 @@ class BaseXYTestCase(unittest.TestCase): eq(base64.b32encode(b'abcd'), b'MFRGGZA=') eq(base64.b32encode(b'abcde'), b'MFRGGZDF') # Non-bytes - eq(base64.b32encode(bytearray(b'abcd')), b'MFRGGZA=') - self.assertRaises(TypeError, base64.b32encode, "") + self.check_other_types(base64.b32encode, b'abcd', b'MFRGGZA=') + self.check_encode_type_errors(base64.b32encode) def test_b32decode(self): eq = self.assertEqual @@ -222,7 +280,8 @@ class BaseXYTestCase(unittest.TestCase): eq(base64.b32decode(data), res) eq(base64.b32decode(data.decode('ascii')), res) # Non-bytes - eq(base64.b32decode(bytearray(b'MFRGG===')), b'abc') + self.check_other_types(base64.b32decode, b'MFRGG===', b"abc") + self.check_decode_type_errors(base64.b32decode) def test_b32decode_casefold(self): eq = self.assertEqual @@ -277,8 +336,9 @@ class BaseXYTestCase(unittest.TestCase): eq(base64.b16encode(b'\x01\x02\xab\xcd\xef'), b'0102ABCDEF') eq(base64.b16encode(b'\x00'), b'00') # Non-bytes - eq(base64.b16encode(bytearray(b'\x01\x02\xab\xcd\xef')), b'0102ABCDEF') - self.assertRaises(TypeError, base64.b16encode, "") + self.check_other_types(base64.b16encode, b'\x01\x02\xab\xcd\xef', + b'0102ABCDEF') + self.check_encode_type_errors(base64.b16encode) def test_b16decode(self): eq = self.assertEqual @@ -293,14 +353,268 @@ class BaseXYTestCase(unittest.TestCase): eq(base64.b16decode(b'0102abcdef', True), b'\x01\x02\xab\xcd\xef') eq(base64.b16decode('0102abcdef', True), b'\x01\x02\xab\xcd\xef') # Non-bytes - eq(base64.b16decode(bytearray(b"0102ABCDEF")), b'\x01\x02\xab\xcd\xef') + self.check_other_types(base64.b16decode, b"0102ABCDEF", + b'\x01\x02\xab\xcd\xef') + self.check_decode_type_errors(base64.b16decode) + eq(base64.b16decode(bytearray(b"0102abcdef"), True), + b'\x01\x02\xab\xcd\xef') + eq(base64.b16decode(memoryview(b"0102abcdef"), True), + b'\x01\x02\xab\xcd\xef') + eq(base64.b16decode(array('B', b"0102abcdef"), True), + b'\x01\x02\xab\xcd\xef') + + def test_a85encode(self): + eq = self.assertEqual + + tests = { + b'': b'', + b"www.python.org": b'GB\\6`E-ZP=Df.1GEb>', + bytes(range(255)): b"""!!*-'"9eu7#RLhG$k3[W&.oNg'GVB"(`=52*$$""" + b"""(B+<_pR,UFcb-n-Vr/1iJ-0JP==1c70M3&s#]4?Ykm5X@_(6q'R884cE""" + b"""H9MJ8X:f1+h<)lt#=BSg3>[:ZC?t!MSA7]@cBPD3sCi+'.E,fo>FEMbN""" + b"""G^4U^I!pHnJ:W<)KS>/9Ll%"IN/`jYOHG]iPa.Q$R$jD4S=Q7DTV8*TU""" + b"""nsrdW2ZetXKAY/Yd(L?['d?O\\@K2_]Y2%o^qmn*`5Ta:aN;TJbg"GZd""" + b"""*^:jeCE.%f\\,!5gtgiEi8N\\UjQ5OekiqBum-X60nF?)@o_%qPq"ad`""" + b"""r;HT""", + b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + b"0123456789!@#0^&*();:<>,. []{}": + b'@:E_WAS,RgBkhF"D/O92EH6,BF`qtRH$VbC6UX@47n?3D92&&T' + b":Jand;cHat='/U/0JP==1c70M3&r-I,;<FN.OZ`-3]oSW/g+A(H[P", + b"no padding..": b'DJpY:@:Wn_DJ(RS', + b"zero compression\0\0\0\0": b'H=_,8+Cf>,E,oN2F(oQ1z', + b"zero compression\0\0\0": b'H=_,8+Cf>,E,oN2F(oQ1!!!!', + b"Boundary:\0\0\0\0": b'6>q!aA79M(3WK-[!!', + b"Space compr: ": b';fH/TAKYK$D/aMV+<VdL', + b'\xff': b'rr', + b'\xff'*2: b's8N', + b'\xff'*3: b's8W*', + b'\xff'*4: b's8W-!', + } + + for data, res in tests.items(): + eq(base64.a85encode(data), res, data) + eq(base64.a85encode(data, adobe=False), res, data) + eq(base64.a85encode(data, adobe=True), b'<~' + res + b'~>', data) + + self.check_other_types(base64.a85encode, b"www.python.org", + b'GB\\6`E-ZP=Df.1GEb>') + + self.assertRaises(TypeError, base64.a85encode, "") + + eq(base64.a85encode(b"www.python.org", wrapcol=7, adobe=False), + b'GB\\6`E-\nZP=Df.1\nGEb>') + eq(base64.a85encode(b"\0\0\0\0www.python.org", wrapcol=7, adobe=False), + b'zGB\\6`E\n-ZP=Df.\n1GEb>') + eq(base64.a85encode(b"www.python.org", wrapcol=7, adobe=True), + b'<~GB\\6`\nE-ZP=Df\n.1GEb>\n~>') + + eq(base64.a85encode(b' '*8, foldspaces=True, adobe=False), b'yy') + eq(base64.a85encode(b' '*7, foldspaces=True, adobe=False), b'y+<Vd') + eq(base64.a85encode(b' '*6, foldspaces=True, adobe=False), b'y+<U') + eq(base64.a85encode(b' '*5, foldspaces=True, adobe=False), b'y+9') + + def test_b85encode(self): + eq = self.assertEqual + + tests = { + b'': b'', + b'www.python.org': b'cXxL#aCvlSZ*DGca%T', + bytes(range(255)): b"""009C61O)~M2nh-c3=Iws5D^j+6crX17#SKH9337X""" + b"""AR!_nBqb&%C@Cr{EG;fCFflSSG&MFiI5|2yJUu=?KtV!7L`6nNNJ&ad""" + b"""OifNtP*GA-R8>}2SXo+ITwPvYU}0ioWMyV&XlZI|Y;A6DaB*^Tbai%j""" + b"""czJqze0_d@fPsR8goTEOh>41ejE#<ukdcy;l$Dm3n3<ZJoSmMZprN9p""" + b"""q@|{(sHv)}tgWuEu(7hUw6(UkxVgH!yuH4^z`?@9#Kp$P$jQpf%+1cv""" + b"""(9zP<)YaD4*xB0K+}+;a;Njxq<mKk)=;`X~?CtLF@bU8V^!4`l`1$(#""" + b"""{Qdp""", + b"""abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ""" + b"""0123456789!@#0^&*();:<>,. []{}""": + b"""VPa!sWoBn+X=-b1ZEkOHadLBXb#`}nd3r%YLqtVJM@UIZOH55pPf$@(""" + b"""Q&d$}S6EqEFflSSG&MFiI5{CeBQRbjDkv#CIy^osE+AW7dwl""", + b'no padding..': b'Zf_uPVPs@!Zf7no', + b'zero compression\x00\x00\x00\x00': b'dS!BNAY*TBaB^jHb7^mG00000', + b'zero compression\x00\x00\x00': b'dS!BNAY*TBaB^jHb7^mG0000', + b"""Boundary:\x00\x00\x00\x00""": b"""LT`0$WMOi7IsgCw00""", + b'Space compr: ': b'Q*dEpWgug3ZE$irARr(h', + b'\xff': b'{{', + b'\xff'*2: b'|Nj', + b'\xff'*3: b'|Ns9', + b'\xff'*4: b'|NsC0', + } + + for data, res in tests.items(): + eq(base64.b85encode(data), res) + + self.check_other_types(base64.b85encode, b"www.python.org", + b'cXxL#aCvlSZ*DGca%T') + + def test_a85decode(self): + eq = self.assertEqual + + tests = { + b'': b'', + b'GB\\6`E-ZP=Df.1GEb>': b'www.python.org', + b"""! ! * -'"\n\t\t9eu\r\n7# RL\vhG$k3[W&.oNg'GVB"(`=52*$$""" + b"""(B+<_pR,UFcb-n-Vr/1iJ-0JP==1c70M3&s#]4?Ykm5X@_(6q'R884cE""" + b"""H9MJ8X:f1+h<)lt#=BSg3>[:ZC?t!MSA7]@cBPD3sCi+'.E,fo>FEMbN""" + b"""G^4U^I!pHnJ:W<)KS>/9Ll%"IN/`jYOHG]iPa.Q$R$jD4S=Q7DTV8*TU""" + b"""nsrdW2ZetXKAY/Yd(L?['d?O\\@K2_]Y2%o^qmn*`5Ta:aN;TJbg"GZd""" + b"""*^:jeCE.%f\\,!5gtgiEi8N\\UjQ5OekiqBum-X60nF?)@o_%qPq"ad`""" + b"""r;HT""": bytes(range(255)), + b"""@:E_WAS,RgBkhF"D/O92EH6,BF`qtRH$VbC6UX@47n?3D92&&T:Jand;c""" + b"""Hat='/U/0JP==1c70M3&r-I,;<FN.OZ`-3]oSW/g+A(H[P""": + b'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ01234' + b'56789!@#0^&*();:<>,. []{}', + b'DJpY:@:Wn_DJ(RS': b'no padding..', + b'H=_,8+Cf>,E,oN2F(oQ1z': b'zero compression\x00\x00\x00\x00', + b'H=_,8+Cf>,E,oN2F(oQ1!!!!': b'zero compression\x00\x00\x00', + b'6>q!aA79M(3WK-[!!': b"Boundary:\x00\x00\x00\x00", + b';fH/TAKYK$D/aMV+<VdL': b'Space compr: ', + b'rr': b'\xff', + b's8N': b'\xff'*2, + b's8W*': b'\xff'*3, + b's8W-!': b'\xff'*4, + } + + for data, res in tests.items(): + eq(base64.a85decode(data), res, data) + eq(base64.a85decode(data, adobe=False), res, data) + eq(base64.a85decode(data.decode("ascii"), adobe=False), res, data) + eq(base64.a85decode(b'<~' + data + b'~>', adobe=True), res, data) + eq(base64.a85decode('<~%s~>' % data.decode("ascii"), adobe=True), + res, data) + + eq(base64.a85decode(b'yy', foldspaces=True, adobe=False), b' '*8) + eq(base64.a85decode(b'y+<Vd', foldspaces=True, adobe=False), b' '*7) + eq(base64.a85decode(b'y+<U', foldspaces=True, adobe=False), b' '*6) + eq(base64.a85decode(b'y+9', foldspaces=True, adobe=False), b' '*5) + + self.check_other_types(base64.a85decode, b'GB\\6`E-ZP=Df.1GEb>', + b"www.python.org") + + def test_b85decode(self): + eq = self.assertEqual + + tests = { + b'': b'', + b'cXxL#aCvlSZ*DGca%T': b'www.python.org', + b"""009C61O)~M2nh-c3=Iws5D^j+6crX17#SKH9337X""" + b"""AR!_nBqb&%C@Cr{EG;fCFflSSG&MFiI5|2yJUu=?KtV!7L`6nNNJ&ad""" + b"""OifNtP*GA-R8>}2SXo+ITwPvYU}0ioWMyV&XlZI|Y;A6DaB*^Tbai%j""" + b"""czJqze0_d@fPsR8goTEOh>41ejE#<ukdcy;l$Dm3n3<ZJoSmMZprN9p""" + b"""q@|{(sHv)}tgWuEu(7hUw6(UkxVgH!yuH4^z`?@9#Kp$P$jQpf%+1cv""" + b"""(9zP<)YaD4*xB0K+}+;a;Njxq<mKk)=;`X~?CtLF@bU8V^!4`l`1$(#""" + b"""{Qdp""": bytes(range(255)), + b"""VPa!sWoBn+X=-b1ZEkOHadLBXb#`}nd3r%YLqtVJM@UIZOH55pPf$@(""" + b"""Q&d$}S6EqEFflSSG&MFiI5{CeBQRbjDkv#CIy^osE+AW7dwl""": + b"""abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ""" + b"""0123456789!@#0^&*();:<>,. []{}""", + b'Zf_uPVPs@!Zf7no': b'no padding..', + b'dS!BNAY*TBaB^jHb7^mG00000': b'zero compression\x00\x00\x00\x00', + b'dS!BNAY*TBaB^jHb7^mG0000': b'zero compression\x00\x00\x00', + b"""LT`0$WMOi7IsgCw00""": b"""Boundary:\x00\x00\x00\x00""", + b'Q*dEpWgug3ZE$irARr(h': b'Space compr: ', + b'{{': b'\xff', + b'|Nj': b'\xff'*2, + b'|Ns9': b'\xff'*3, + b'|NsC0': b'\xff'*4, + } + + for data, res in tests.items(): + eq(base64.b85decode(data), res) + eq(base64.b85decode(data.decode("ascii")), res) + + self.check_other_types(base64.b85decode, b'cXxL#aCvlSZ*DGca%T', + b"www.python.org") + + def test_a85_padding(self): + eq = self.assertEqual + + eq(base64.a85encode(b"x", pad=True), b'GQ7^D') + eq(base64.a85encode(b"xx", pad=True), b"G^'2g") + eq(base64.a85encode(b"xxx", pad=True), b'G^+H5') + eq(base64.a85encode(b"xxxx", pad=True), b'G^+IX') + eq(base64.a85encode(b"xxxxx", pad=True), b'G^+IXGQ7^D') + + eq(base64.a85decode(b'GQ7^D'), b"x\x00\x00\x00") + eq(base64.a85decode(b"G^'2g"), b"xx\x00\x00") + eq(base64.a85decode(b'G^+H5'), b"xxx\x00") + eq(base64.a85decode(b'G^+IX'), b"xxxx") + eq(base64.a85decode(b'G^+IXGQ7^D'), b"xxxxx\x00\x00\x00") + + def test_b85_padding(self): + eq = self.assertEqual + + eq(base64.b85encode(b"x", pad=True), b'cmMzZ') + eq(base64.b85encode(b"xx", pad=True), b'cz6H+') + eq(base64.b85encode(b"xxx", pad=True), b'czAdK') + eq(base64.b85encode(b"xxxx", pad=True), b'czAet') + eq(base64.b85encode(b"xxxxx", pad=True), b'czAetcmMzZ') + + eq(base64.b85decode(b'cmMzZ'), b"x\x00\x00\x00") + eq(base64.b85decode(b'cz6H+'), b"xx\x00\x00") + eq(base64.b85decode(b'czAdK'), b"xxx\x00") + eq(base64.b85decode(b'czAet'), b"xxxx") + eq(base64.b85decode(b'czAetcmMzZ'), b"xxxxx\x00\x00\x00") + + def test_a85decode_errors(self): + illegal = (set(range(32)) | set(range(118, 256))) - set(b' \t\n\r\v') + for c in illegal: + with self.assertRaises(ValueError, msg=bytes([c])): + base64.a85decode(b'!!!!' + bytes([c])) + with self.assertRaises(ValueError, msg=bytes([c])): + base64.a85decode(b'!!!!' + bytes([c]), adobe=False) + with self.assertRaises(ValueError, msg=bytes([c])): + base64.a85decode(b'<~!!!!' + bytes([c]) + b'~>', adobe=True) + + self.assertRaises(ValueError, base64.a85decode, + b"malformed", adobe=True) + self.assertRaises(ValueError, base64.a85decode, + b"<~still malformed", adobe=True) + self.assertRaises(ValueError, base64.a85decode, + b"also malformed~>", adobe=True) + + # With adobe=False (the default), Adobe framing markers are disallowed + self.assertRaises(ValueError, base64.a85decode, + b"<~~>") + self.assertRaises(ValueError, base64.a85decode, + b"<~~>", adobe=False) + base64.a85decode(b"<~~>", adobe=True) # sanity check + + self.assertRaises(ValueError, base64.a85decode, + b"abcx", adobe=False) + self.assertRaises(ValueError, base64.a85decode, + b"abcdey", adobe=False) + self.assertRaises(ValueError, base64.a85decode, + b"a b\nc", adobe=False, ignorechars=b"") + + self.assertRaises(ValueError, base64.a85decode, b's', adobe=False) + self.assertRaises(ValueError, base64.a85decode, b's8', adobe=False) + self.assertRaises(ValueError, base64.a85decode, b's8W', adobe=False) + self.assertRaises(ValueError, base64.a85decode, b's8W-', adobe=False) + self.assertRaises(ValueError, base64.a85decode, b's8W-"', adobe=False) + + def test_b85decode_errors(self): + illegal = list(range(33)) + \ + list(b'"\',./:[\\]') + \ + list(range(128, 256)) + for c in illegal: + with self.assertRaises(ValueError, msg=bytes([c])): + base64.b85decode(b'0000' + bytes([c])) + + self.assertRaises(ValueError, base64.b85decode, b'|') + self.assertRaises(ValueError, base64.b85decode, b'|N') + self.assertRaises(ValueError, base64.b85decode, b'|Ns') + self.assertRaises(ValueError, base64.b85decode, b'|NsC') + self.assertRaises(ValueError, base64.b85decode, b'|NsC1') def test_decode_nonascii_str(self): decode_funcs = (base64.b64decode, base64.standard_b64decode, base64.urlsafe_b64decode, base64.b32decode, - base64.b16decode) + base64.b16decode, + base64.b85decode, + base64.a85decode) for f in decode_funcs: self.assertRaises(ValueError, f, 'with non-ascii \xcb') diff --git a/Lib/test/test_bisect.py b/Lib/test/test_bisect.py index 7b9bd19fcb..580a963f62 100644 --- a/Lib/test/test_bisect.py +++ b/Lib/test/test_bisect.py @@ -7,7 +7,7 @@ py_bisect = support.import_fresh_module('bisect', blocked=['_bisect']) c_bisect = support.import_fresh_module('bisect', fresh=['_bisect']) class Range(object): - """A trivial range()-like object without any integer width limitations.""" + """A trivial range()-like object that has an insert() method.""" def __init__(self, start, stop): self.start = start self.stop = stop @@ -120,10 +120,10 @@ class TestBisect: def test_negative_lo(self): # Issue 3301 mod = self.module - self.assertRaises(ValueError, mod.bisect_left, [1, 2, 3], 5, -1, 3), - self.assertRaises(ValueError, mod.bisect_right, [1, 2, 3], 5, -1, 3), - self.assertRaises(ValueError, mod.insort_left, [1, 2, 3], 5, -1, 3), - self.assertRaises(ValueError, mod.insort_right, [1, 2, 3], 5, -1, 3), + self.assertRaises(ValueError, mod.bisect_left, [1, 2, 3], 5, -1, 3) + self.assertRaises(ValueError, mod.bisect_right, [1, 2, 3], 5, -1, 3) + self.assertRaises(ValueError, mod.insort_left, [1, 2, 3], 5, -1, 3) + self.assertRaises(ValueError, mod.insort_right, [1, 2, 3], 5, -1, 3) def test_large_range(self): # Issue 13496 diff --git a/Lib/test/test_buffer.py b/Lib/test/test_buffer.py index 04e3f61e89..1667847a9d 100644 --- a/Lib/test/test_buffer.py +++ b/Lib/test/test_buffer.py @@ -4290,9 +4290,5 @@ class TestBufferProtocol(unittest.TestCase): self.assertRaises(BufferError, memoryview, x) -def test_main(): - support.run_unittest(TestBufferProtocol) - - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_builtin.py b/Lib/test/test_builtin.py index c342a4329f..b561a6f73a 100644 --- a/Lib/test/test_builtin.py +++ b/Lib/test/test_builtin.py @@ -16,6 +16,7 @@ import unittest import warnings from operator import neg from test.support import TESTFN, unlink, run_unittest, check_warnings +from test.script_helper import assert_python_ok try: import pty, signal except ImportError: @@ -464,6 +465,11 @@ class BuiltinTest(unittest.TestCase): self.assertRaises(TypeError, eval, ()) self.assertRaises(SyntaxError, eval, bom[:2] + b'a') + class X: + def __getitem__(self, key): + raise ValueError + self.assertRaises(ValueError, eval, "foo", {}, X()) + def test_general_eval(self): # Tests that general mappings can be used for the locals argument @@ -579,7 +585,10 @@ class BuiltinTest(unittest.TestCase): raise frozendict_error("frozendict is readonly") # read-only builtins - frozen_builtins = frozendict(__builtins__) + if isinstance(__builtins__, types.ModuleType): + frozen_builtins = frozendict(__builtins__.__dict__) + else: + frozen_builtins = frozendict(__builtins__) code = compile("__builtins__['superglobal']=2; print(superglobal)", "test", "exec") self.assertRaises(frozendict_error, exec, code, {'__builtins__': frozen_builtins}) @@ -839,8 +848,19 @@ class BuiltinTest(unittest.TestCase): self.assertEqual(max(1, 2.0, 3), 3) self.assertEqual(max(1.0, 2, 3), 3) + self.assertRaises(TypeError, max) + self.assertRaises(TypeError, max, 42) + self.assertRaises(ValueError, max, ()) + class BadSeq: + def __getitem__(self, index): + raise ValueError + self.assertRaises(ValueError, max, BadSeq()) + for stmt in ( "max(key=int)", # no args + "max(default=None)", + "max(1, 2, default=None)", # require container for default + "max(default=None, key=int)", "max(1, key=int)", # single arg not iterable "max(1, 2, keystone=int)", # wrong keyword "max(1, 2, key=int, abc=int)", # two many keywords @@ -857,6 +877,13 @@ class BuiltinTest(unittest.TestCase): self.assertEqual(max((1,2), key=neg), 1) # two elem iterable self.assertEqual(max(1, 2, key=neg), 1) # two elems + self.assertEqual(max((), default=None), None) # zero elem iterable + self.assertEqual(max((1,), default=None), 1) # one elem iterable + self.assertEqual(max((1,2), default=None), 2) # two elem iterable + + self.assertEqual(max((), default=1, key=neg), 1) + self.assertEqual(max((1, 2), default=3, key=neg), 1) + data = [random.randrange(200) for i in range(100)] keys = dict((elem, random.randrange(50)) for elem in data) f = keys.__getitem__ @@ -883,6 +910,9 @@ class BuiltinTest(unittest.TestCase): for stmt in ( "min(key=int)", # no args + "min(default=None)", + "min(1, 2, default=None)", # require container for default + "min(default=None, key=int)", "min(1, key=int)", # single arg not iterable "min(1, 2, keystone=int)", # wrong keyword "min(1, 2, key=int, abc=int)", # two many keywords @@ -899,6 +929,13 @@ class BuiltinTest(unittest.TestCase): self.assertEqual(min((1,2), key=neg), 2) # two elem iterable self.assertEqual(min(1, 2, key=neg), 2) # two elems + self.assertEqual(min((), default=None), None) # zero elem iterable + self.assertEqual(min((1,), default=None), 1) # one elem iterable + self.assertEqual(min((1,2), default=None), 1) # two elem iterable + + self.assertEqual(min((), default=1, key=neg), 1) + self.assertEqual(min((1, 2), default=1, key=neg), 2) + data = [random.randrange(200) for i in range(100)] keys = dict((elem, random.randrange(50)) for elem in data) f = keys.__getitem__ @@ -940,29 +977,25 @@ class BuiltinTest(unittest.TestCase): def write_testfile(self): # NB the first 4 lines are also used to test input, below fp = open(TESTFN, 'w') - try: + self.addCleanup(unlink, TESTFN) + with fp: fp.write('1+1\n') fp.write('The quick brown fox jumps over the lazy dog') fp.write('.\n') fp.write('Dear John\n') fp.write('XXX'*100) fp.write('YYY'*100) - finally: - fp.close() def test_open(self): self.write_testfile() fp = open(TESTFN, 'r') - try: + with fp: self.assertEqual(fp.readline(4), '1+1\n') self.assertEqual(fp.readline(), 'The quick brown fox jumps over the lazy dog.\n') self.assertEqual(fp.readline(4), 'Dear') self.assertEqual(fp.readline(100), ' John\n') self.assertEqual(fp.read(300), 'XXX'*100) self.assertEqual(fp.read(1000), 'YYY'*100) - finally: - fp.close() - unlink(TESTFN) def test_open_default_encoding(self): old_environ = dict(os.environ) @@ -977,15 +1010,17 @@ class BuiltinTest(unittest.TestCase): self.write_testfile() current_locale_encoding = locale.getpreferredencoding(False) fp = open(TESTFN, 'w') - try: + with fp: self.assertEqual(fp.encoding, current_locale_encoding) - finally: - fp.close() - unlink(TESTFN) finally: os.environ.clear() os.environ.update(old_environ) + def test_open_non_inheritable(self): + fileobj = open(__file__) + with fileobj: + self.assertFalse(os.get_inheritable(fileobj.fileno())) + def test_ord(self): self.assertEqual(ord(' '), 32) self.assertEqual(ord('A'), 65) @@ -1096,7 +1131,6 @@ class BuiltinTest(unittest.TestCase): sys.stdin = savestdin sys.stdout = savestdout fp.close() - unlink(TESTFN) @unittest.skipUnless(pty, "the pty and signal modules must be available") def check_input_tty(self, prompt, terminal_input, stdio_encoding=None): @@ -1476,17 +1510,11 @@ class BuiltinTest(unittest.TestCase): # -------------------------------------------------------------------- # Issue #7994: object.__format__ with a non-empty format string is # deprecated - def test_deprecated_format_string(obj, fmt_str, should_raise_warning): - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always", DeprecationWarning) - format(obj, fmt_str) - if should_raise_warning: - self.assertEqual(len(w), 1) - self.assertIsInstance(w[0].message, DeprecationWarning) - self.assertIn('object.__format__ with a non-empty format ' - 'string', str(w[0].message)) + def test_deprecated_format_string(obj, fmt_str, should_raise): + if should_raise: + self.assertRaises(TypeError, format, obj, fmt_str) else: - self.assertEqual(len(w), 0) + format(obj, fmt_str) fmt_strs = ['', 's'] @@ -1565,21 +1593,45 @@ class TestSorted(unittest.TestCase): data = 'The quick Brown fox Jumped over The lazy Dog'.split() self.assertRaises(TypeError, sorted, data, None, lambda x,y: 0) -def test_main(verbose=None): - test_classes = (BuiltinTest, TestSorted) - - run_unittest(*test_classes) - - # verify reference counting - if verbose and hasattr(sys, "gettotalrefcount"): - import gc - counts = [None] * 5 - for i in range(len(counts)): - run_unittest(*test_classes) - gc.collect() - counts[i] = sys.gettotalrefcount() - print(counts) +class ShutdownTest(unittest.TestCase): + + def test_cleanup(self): + # Issue #19255: builtins are still available at shutdown + code = """if 1: + import builtins + import sys + + class C: + def __del__(self): + print("before") + # Check that builtins still exist + len(()) + print("after") + + c = C() + # Make this module survive until builtins and sys are cleaned + builtins.here = sys.modules[__name__] + sys.here = sys.modules[__name__] + # Create a reference loop so that this module needs to go + # through a GC phase. + here = sys.modules[__name__] + """ + # Issue #20599: Force ASCII encoding to get a codec implemented in C, + # otherwise the codec may be unloaded before C.__del__() is called, and + # so print("before") fails because the codec cannot be used to encode + # "before" to sys.stdout.encoding. For example, on Windows, + # sys.stdout.encoding is the OEM code page and these code pages are + # implemented in Python + rc, out, err = assert_python_ok("-c", code, + PYTHONIOENCODING="ascii") + self.assertEqual(["before", "after"], out.decode().splitlines()) + + +def load_tests(loader, tests, pattern): + from doctest import DocTestSuite + tests.addTest(DocTestSuite(builtins)) + return tests if __name__ == "__main__": - test_main(verbose=True) + unittest.main() diff --git a/Lib/test/test_bytes.py b/Lib/test/test_bytes.py index 3c09141184..f350211f71 100644 --- a/Lib/test/test_bytes.py +++ b/Lib/test/test_bytes.py @@ -288,8 +288,22 @@ class BaseBytesTest: self.assertEqual(self.type2test(b"").join(lst), b"abc") self.assertEqual(self.type2test(b"").join(tuple(lst)), b"abc") self.assertEqual(self.type2test(b"").join(iter(lst)), b"abc") - self.assertEqual(self.type2test(b".").join([b"ab", b"cd"]), b"ab.cd") - # XXX more... + dot_join = self.type2test(b".:").join + self.assertEqual(dot_join([b"ab", b"cd"]), b"ab.:cd") + self.assertEqual(dot_join([memoryview(b"ab"), b"cd"]), b"ab.:cd") + self.assertEqual(dot_join([b"ab", memoryview(b"cd")]), b"ab.:cd") + self.assertEqual(dot_join([bytearray(b"ab"), b"cd"]), b"ab.:cd") + self.assertEqual(dot_join([b"ab", bytearray(b"cd")]), b"ab.:cd") + # Stress it with many items + seq = [b"abc"] * 1000 + expected = b"abc" + b".:abc" * 999 + self.assertEqual(dot_join(seq), expected) + # Error handling and cleanup when some item in the middle of the + # sequence has the wrong type. + with self.assertRaises(TypeError): + dot_join([bytearray(b"ab"), "cd", b"ef"]) + with self.assertRaises(TypeError): + dot_join([memoryview(b"ab"), "cd", b"ef"]) def test_count(self): b = self.type2test(b'mississippi') @@ -765,7 +779,7 @@ class ByteArrayTest(BaseBytesTest, unittest.TestCase): finally: try: os.remove(tfn) - except os.error: + except OSError: pass def test_reverse(self): @@ -901,6 +915,15 @@ class ByteArrayTest(BaseBytesTest, unittest.TestCase): with self.assertRaises(ValueError): b[3:4] = elem + def test_setslice_extend(self): + # Exercise the resizing logic (see issue #19087) + b = bytearray(range(100)) + self.assertEqual(list(b), list(range(100))) + del b[:10] + self.assertEqual(list(b), list(range(10, 100))) + b.extend(range(100, 110)) + self.assertEqual(list(b), list(range(10, 110))) + def test_extended_set_del_slice(self): indices = (0, None, 1, 3, 19, 300, 1<<333, -1, -2, -31, -300) for start in indices: @@ -1280,6 +1303,11 @@ class BytearrayPEP3137Test(unittest.TestCase, self.assertEqual(val, newval) self.assertTrue(val is not newval, expr+' returned val on a mutable object') + sep = self.marshal(b'') + newval = sep.join([val]) + self.assertEqual(val, newval) + self.assertIsNot(val, newval) + class FixedStringTest(test.string_tests.BaseTest): diff --git a/Lib/test/test_bz2.py b/Lib/test/test_bz2.py index c8c93513d8..40d8425eee 100644 --- a/Lib/test/test_bz2.py +++ b/Lib/test/test_bz2.py @@ -1,5 +1,5 @@ from test import support -from test.support import TESTFN, bigmemtest, _4G +from test.support import bigmemtest, _4G import unittest from io import BytesIO @@ -8,6 +8,7 @@ import pickle import random import subprocess import sys +from test.support import unlink try: import threading @@ -18,10 +19,10 @@ except ImportError: bz2 = support.import_module('bz2') from bz2 import BZ2File, BZ2Compressor, BZ2Decompressor -has_cmdline_bunzip2 = sys.platform not in ("win32", "os2emx") class BaseTest(unittest.TestCase): "Base for other testcases." + TEXT_LINES = [ b'root:x:0:0:root:/root:/bin/bash\n', b'bin:x:1:1:bin:/bin:\n', @@ -51,13 +52,17 @@ class BaseTest(unittest.TestCase): BAD_DATA = b'this is not a valid bzip2 file' def setUp(self): - self.filename = TESTFN + self.filename = support.TESTFN def tearDown(self): if os.path.isfile(self.filename): os.unlink(self.filename) - if has_cmdline_bunzip2: + if sys.platform == "win32": + # bunzip2 isn't available to run on Windows. + def decompress(self, data): + return bz2.decompress(data) + else: def decompress(self, data): pop = subprocess.Popen("bunzip2", shell=True, stdin=subprocess.PIPE, @@ -71,13 +76,9 @@ class BaseTest(unittest.TestCase): ret = bz2.decompress(data) return ret - else: - # bunzip2 isn't available to run on Windows. - def decompress(self, data): - return bz2.decompress(data) class BZ2FileTest(BaseTest): - "Test BZ2File type miscellaneous methods." + "Test the BZ2File class." def createTempFile(self, streams=1, suffix=b""): with open(self.filename, "wb") as f: @@ -85,18 +86,12 @@ class BZ2FileTest(BaseTest): f.write(suffix) def testBadArgs(self): - with self.assertRaises(TypeError): - BZ2File(123.456) - with self.assertRaises(ValueError): - BZ2File("/dev/null", "z") - with self.assertRaises(ValueError): - BZ2File("/dev/null", "rx") - with self.assertRaises(ValueError): - BZ2File("/dev/null", "rbt") - with self.assertRaises(ValueError): - BZ2File("/dev/null", compresslevel=0) - with self.assertRaises(ValueError): - BZ2File("/dev/null", compresslevel=10) + self.assertRaises(TypeError, BZ2File, 123.456) + self.assertRaises(ValueError, BZ2File, "/dev/null", "z") + self.assertRaises(ValueError, BZ2File, "/dev/null", "rx") + self.assertRaises(ValueError, BZ2File, "/dev/null", "rbt") + self.assertRaises(ValueError, BZ2File, "/dev/null", compresslevel=0) + self.assertRaises(ValueError, BZ2File, "/dev/null", compresslevel=10) def testRead(self): self.createTempFile() @@ -232,9 +227,8 @@ class BZ2FileTest(BaseTest): self.createTempFile() bz2f = BZ2File(self.filename) bz2f.close() - self.assertRaises(ValueError, bz2f.__next__) - # This call will deadlock if the above .__next__ call failed to - # release the lock. + self.assertRaises(ValueError, next, bz2f) + # This call will deadlock if the above call failed to release the lock. self.assertRaises(ValueError, bz2f.readlines) def testWrite(self): @@ -278,8 +272,8 @@ class BZ2FileTest(BaseTest): bz2f.write(b"abc") with BZ2File(self.filename, "r") as bz2f: - self.assertRaises(IOError, bz2f.write, b"a") - self.assertRaises(IOError, bz2f.writelines, [b"a"]) + self.assertRaises(OSError, bz2f.write, b"a") + self.assertRaises(OSError, bz2f.writelines, [b"a"]) def testAppend(self): with BZ2File(self.filename, "w") as bz2f: @@ -397,7 +391,7 @@ class BZ2FileTest(BaseTest): bz2f.close() self.assertRaises(ValueError, bz2f.seekable) - bz2f = BZ2File(BytesIO(), mode="w") + bz2f = BZ2File(BytesIO(), "w") try: self.assertFalse(bz2f.seekable()) finally: @@ -423,7 +417,7 @@ class BZ2FileTest(BaseTest): bz2f.close() self.assertRaises(ValueError, bz2f.readable) - bz2f = BZ2File(BytesIO(), mode="w") + bz2f = BZ2File(BytesIO(), "w") try: self.assertFalse(bz2f.readable()) finally: @@ -440,7 +434,7 @@ class BZ2FileTest(BaseTest): bz2f.close() self.assertRaises(ValueError, bz2f.writable) - bz2f = BZ2File(BytesIO(), mode="w") + bz2f = BZ2File(BytesIO(), "w") try: self.assertTrue(bz2f.writable()) finally: @@ -454,7 +448,7 @@ class BZ2FileTest(BaseTest): del o def testOpenNonexistent(self): - self.assertRaises(IOError, BZ2File, "/non/existent") + self.assertRaises(OSError, BZ2File, "/non/existent") def testReadlinesNoNewline(self): # Issue #1191043: readlines() fails on a file containing no newline. @@ -494,7 +488,7 @@ class BZ2FileTest(BaseTest): # Issue #7205: Using a BZ2File from several threads shouldn't deadlock. data = b"1" * 2**20 nthreads = 10 - with bz2.BZ2File(self.filename, 'wb') as f: + with BZ2File(self.filename, 'wb') as f: def comp(): for i in range(5): f.write(data) @@ -505,28 +499,27 @@ class BZ2FileTest(BaseTest): t.join() def testWithoutThreading(self): - bz2 = support.import_fresh_module("bz2", blocked=("threading",)) - with bz2.BZ2File(self.filename, "wb") as f: + module = support.import_fresh_module("bz2", blocked=("threading",)) + with module.BZ2File(self.filename, "wb") as f: f.write(b"abc") - with bz2.BZ2File(self.filename, "rb") as f: + with module.BZ2File(self.filename, "rb") as f: self.assertEqual(f.read(), b"abc") def testMixedIterationAndReads(self): self.createTempFile() linelen = len(self.TEXT_LINES[0]) halflen = linelen // 2 - with bz2.BZ2File(self.filename) as bz2f: + with BZ2File(self.filename) as bz2f: bz2f.read(halflen) self.assertEqual(next(bz2f), self.TEXT_LINES[0][halflen:]) self.assertEqual(bz2f.read(), self.TEXT[linelen:]) - with bz2.BZ2File(self.filename) as bz2f: + with BZ2File(self.filename) as bz2f: bz2f.readline() self.assertEqual(next(bz2f), self.TEXT_LINES[1]) self.assertEqual(bz2f.readline(), self.TEXT_LINES[2]) - with bz2.BZ2File(self.filename) as bz2f: + with BZ2File(self.filename) as bz2f: bz2f.readlines() - with self.assertRaises(StopIteration): - next(bz2f) + self.assertRaises(StopIteration, next, bz2f) self.assertEqual(bz2f.readlines(), []) def testMultiStreamOrdering(self): @@ -594,6 +587,20 @@ class BZ2FileTest(BaseTest): bz2f.seek(-150, 1) self.assertEqual(bz2f.read(), self.TEXT[500-150:]) + def test_read_truncated(self): + # Drop the eos_magic field (6 bytes) and CRC (4 bytes). + truncated = self.DATA[:-10] + with BZ2File(BytesIO(truncated)) as f: + self.assertRaises(EOFError, f.read) + with BZ2File(BytesIO(truncated)) as f: + self.assertEqual(f.read(len(self.TEXT)), self.TEXT) + self.assertRaises(EOFError, f.read, 1) + # Incomplete 4-byte file header, and block header of at least 146 bits. + for i in range(22): + with BZ2File(BytesIO(truncated[:i])) as f: + self.assertRaises(EOFError, f.read, 1) + + class BZ2CompressorTest(BaseTest): def testCompress(self): bz2c = BZ2Compressor() @@ -740,97 +747,122 @@ class CompressDecompressTest(BaseTest): class OpenTest(BaseTest): + "Test the open function." + + def open(self, *args, **kwargs): + return bz2.open(*args, **kwargs) + def test_binary_modes(self): - with bz2.open(self.filename, "wb") as f: - f.write(self.TEXT) - with open(self.filename, "rb") as f: - file_data = bz2.decompress(f.read()) - self.assertEqual(file_data, self.TEXT) - with bz2.open(self.filename, "rb") as f: - self.assertEqual(f.read(), self.TEXT) - with bz2.open(self.filename, "ab") as f: - f.write(self.TEXT) - with open(self.filename, "rb") as f: - file_data = bz2.decompress(f.read()) - self.assertEqual(file_data, self.TEXT * 2) + for mode in ("wb", "xb"): + if mode == "xb": + unlink(self.filename) + with self.open(self.filename, mode) as f: + f.write(self.TEXT) + with open(self.filename, "rb") as f: + file_data = self.decompress(f.read()) + self.assertEqual(file_data, self.TEXT) + with self.open(self.filename, "rb") as f: + self.assertEqual(f.read(), self.TEXT) + with self.open(self.filename, "ab") as f: + f.write(self.TEXT) + with open(self.filename, "rb") as f: + file_data = self.decompress(f.read()) + self.assertEqual(file_data, self.TEXT * 2) def test_implicit_binary_modes(self): # Test implicit binary modes (no "b" or "t" in mode string). - with bz2.open(self.filename, "w") as f: - f.write(self.TEXT) - with open(self.filename, "rb") as f: - file_data = bz2.decompress(f.read()) - self.assertEqual(file_data, self.TEXT) - with bz2.open(self.filename, "r") as f: - self.assertEqual(f.read(), self.TEXT) - with bz2.open(self.filename, "a") as f: - f.write(self.TEXT) - with open(self.filename, "rb") as f: - file_data = bz2.decompress(f.read()) - self.assertEqual(file_data, self.TEXT * 2) + for mode in ("w", "x"): + if mode == "x": + unlink(self.filename) + with self.open(self.filename, mode) as f: + f.write(self.TEXT) + with open(self.filename, "rb") as f: + file_data = self.decompress(f.read()) + self.assertEqual(file_data, self.TEXT) + with self.open(self.filename, "r") as f: + self.assertEqual(f.read(), self.TEXT) + with self.open(self.filename, "a") as f: + f.write(self.TEXT) + with open(self.filename, "rb") as f: + file_data = self.decompress(f.read()) + self.assertEqual(file_data, self.TEXT * 2) def test_text_modes(self): text = self.TEXT.decode("ascii") text_native_eol = text.replace("\n", os.linesep) - with bz2.open(self.filename, "wt") as f: - f.write(text) - with open(self.filename, "rb") as f: - file_data = bz2.decompress(f.read()).decode("ascii") - self.assertEqual(file_data, text_native_eol) - with bz2.open(self.filename, "rt") as f: - self.assertEqual(f.read(), text) - with bz2.open(self.filename, "at") as f: - f.write(text) - with open(self.filename, "rb") as f: - file_data = bz2.decompress(f.read()).decode("ascii") - self.assertEqual(file_data, text_native_eol * 2) + for mode in ("wt", "xt"): + if mode == "xt": + unlink(self.filename) + with self.open(self.filename, mode) as f: + f.write(text) + with open(self.filename, "rb") as f: + file_data = self.decompress(f.read()).decode("ascii") + self.assertEqual(file_data, text_native_eol) + with self.open(self.filename, "rt") as f: + self.assertEqual(f.read(), text) + with self.open(self.filename, "at") as f: + f.write(text) + with open(self.filename, "rb") as f: + file_data = self.decompress(f.read()).decode("ascii") + self.assertEqual(file_data, text_native_eol * 2) + + def test_x_mode(self): + for mode in ("x", "xb", "xt"): + unlink(self.filename) + with self.open(self.filename, mode) as f: + pass + with self.assertRaises(FileExistsError): + with self.open(self.filename, mode) as f: + pass def test_fileobj(self): - with bz2.open(BytesIO(self.DATA), "r") as f: + with self.open(BytesIO(self.DATA), "r") as f: self.assertEqual(f.read(), self.TEXT) - with bz2.open(BytesIO(self.DATA), "rb") as f: + with self.open(BytesIO(self.DATA), "rb") as f: self.assertEqual(f.read(), self.TEXT) text = self.TEXT.decode("ascii") - with bz2.open(BytesIO(self.DATA), "rt") as f: + with self.open(BytesIO(self.DATA), "rt") as f: self.assertEqual(f.read(), text) def test_bad_params(self): # Test invalid parameter combinations. - with self.assertRaises(ValueError): - bz2.open(self.filename, "wbt") - with self.assertRaises(ValueError): - bz2.open(self.filename, "rb", encoding="utf-8") - with self.assertRaises(ValueError): - bz2.open(self.filename, "rb", errors="ignore") - with self.assertRaises(ValueError): - bz2.open(self.filename, "rb", newline="\n") + self.assertRaises(ValueError, + self.open, self.filename, "wbt") + self.assertRaises(ValueError, + self.open, self.filename, "xbt") + self.assertRaises(ValueError, + self.open, self.filename, "rb", encoding="utf-8") + self.assertRaises(ValueError, + self.open, self.filename, "rb", errors="ignore") + self.assertRaises(ValueError, + self.open, self.filename, "rb", newline="\n") def test_encoding(self): # Test non-default encoding. text = self.TEXT.decode("ascii") text_native_eol = text.replace("\n", os.linesep) - with bz2.open(self.filename, "wt", encoding="utf-16-le") as f: + with self.open(self.filename, "wt", encoding="utf-16-le") as f: f.write(text) with open(self.filename, "rb") as f: - file_data = bz2.decompress(f.read()).decode("utf-16-le") + file_data = self.decompress(f.read()).decode("utf-16-le") self.assertEqual(file_data, text_native_eol) - with bz2.open(self.filename, "rt", encoding="utf-16-le") as f: + with self.open(self.filename, "rt", encoding="utf-16-le") as f: self.assertEqual(f.read(), text) def test_encoding_error_handler(self): # Test with non-default encoding error handler. - with bz2.open(self.filename, "wb") as f: + with self.open(self.filename, "wb") as f: f.write(b"foo\xffbar") - with bz2.open(self.filename, "rt", encoding="ascii", errors="ignore") \ + with self.open(self.filename, "rt", encoding="ascii", errors="ignore") \ as f: self.assertEqual(f.read(), "foobar") def test_newline(self): # Test with explicit newline (universal newline mode disabled). text = self.TEXT.decode("ascii") - with bz2.open(self.filename, "wt", newline="\n") as f: + with self.open(self.filename, "wt", newline="\n") as f: f.write(text) - with bz2.open(self.filename, "rt", newline="\r") as f: + with self.open(self.filename, "rt", newline="\r") as f: self.assertEqual(f.readlines(), [text]) diff --git a/Lib/test/test_capi.py b/Lib/test/test_capi.py index 50c4bba24f..ba7c38db27 100644 --- a/Lib/test/test_capi.py +++ b/Lib/test/test_capi.py @@ -9,6 +9,7 @@ import sys import time import unittest from test import support +from test.support import MISSING_C_DOCSTRINGS try: import _posixsubprocess except ImportError: @@ -44,7 +45,7 @@ class CAPITest(unittest.TestCase): @unittest.skipUnless(threading, 'Threading required for this test.') def test_no_FatalError_infinite_loop(self): - with support.suppress_crash_popup(): + with support.SuppressCrashReport(): p = subprocess.Popen([sys.executable, "-c", 'import _testcapi;' '_testcapi.crash_no_current_thread()'], @@ -110,6 +111,46 @@ class CAPITest(unittest.TestCase): self.assertRaises(TypeError, _posixsubprocess.fork_exec, Z(),[b'1'],3,[1, 2],5,6,7,8,9,10,11,12,13,14,15,16,17) + @unittest.skipIf(MISSING_C_DOCSTRINGS, + "Signature information for builtins requires docstrings") + def test_docstring_signature_parsing(self): + + self.assertEqual(_testcapi.no_docstring.__doc__, None) + self.assertEqual(_testcapi.no_docstring.__text_signature__, None) + + self.assertEqual(_testcapi.docstring_empty.__doc__, "") + self.assertEqual(_testcapi.docstring_empty.__text_signature__, None) + + self.assertEqual(_testcapi.docstring_no_signature.__doc__, + "This docstring has no signature.") + self.assertEqual(_testcapi.docstring_no_signature.__text_signature__, None) + + self.assertEqual(_testcapi.docstring_with_invalid_signature.__doc__, + "docstring_with_invalid_signature($module, /, boo)\n" + "\n" + "This docstring has an invalid signature." + ) + self.assertEqual(_testcapi.docstring_with_invalid_signature.__text_signature__, None) + + self.assertEqual(_testcapi.docstring_with_invalid_signature2.__doc__, + "docstring_with_invalid_signature2($module, /, boo)\n" + "\n" + "--\n" + "\n" + "This docstring also has an invalid signature." + ) + self.assertEqual(_testcapi.docstring_with_invalid_signature2.__text_signature__, None) + + self.assertEqual(_testcapi.docstring_with_signature.__doc__, + "This docstring has a valid signature.") + self.assertEqual(_testcapi.docstring_with_signature.__text_signature__, "($module, /, sig)") + + self.assertEqual(_testcapi.docstring_with_signature_and_extra_newlines.__doc__, + "\nThis docstring has a valid signature and some extra newlines.") + self.assertEqual(_testcapi.docstring_with_signature_and_extra_newlines.__text_signature__, + "($module, /, parameter)") + + @unittest.skipUnless(threading, 'Threading required for this test.') class TestPendingCalls(unittest.TestCase): @@ -193,6 +234,9 @@ class TestPendingCalls(unittest.TestCase): self.pendingcalls_submit(l, n) self.pendingcalls_wait(l, n) + +class SubinterpreterTest(unittest.TestCase): + def test_subinterps(self): import builtins r, w = os.pipe() @@ -203,47 +247,109 @@ class TestPendingCalls(unittest.TestCase): pickle.dump(id(builtins), f) """.format(w) with open(r, "rb") as f: - ret = _testcapi.run_in_subinterp(code) + ret = support.run_in_subinterp(code) self.assertEqual(ret, 0) self.assertNotEqual(pickle.load(f), id(sys.modules)) self.assertNotEqual(pickle.load(f), id(builtins)) + # Bug #6012 class Test6012(unittest.TestCase): def test(self): self.assertEqual(_testcapi.argparsing("Hello", "World"), 1) -class EmbeddingTest(unittest.TestCase): - - @unittest.skipIf( - sys.platform.startswith('win'), - "test doesn't work under Windows") - def test_subinterps(self): - # XXX only tested under Unix checkouts +class EmbeddingTests(unittest.TestCase): + def setUp(self): basepath = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) - oldcwd = os.getcwd() + exename = "_testembed" + if sys.platform.startswith("win"): + ext = ("_d" if "_d" in sys.executable else "") + ".exe" + exename += ext + exepath = os.path.dirname(sys.executable) + else: + exepath = os.path.join(basepath, "Modules") + self.test_exe = exe = os.path.join(exepath, exename) + if not os.path.exists(exe): + self.skipTest("%r doesn't exist" % exe) # This is needed otherwise we get a fatal error: # "Py_Initialize: Unable to get the locale encoding # LookupError: no codec search functions registered: can't find encoding" + self.oldcwd = os.getcwd() os.chdir(basepath) + + def tearDown(self): + os.chdir(self.oldcwd) + + def run_embedded_interpreter(self, *args): + """Runs a test in the embedded interpreter""" + cmd = [self.test_exe] + cmd.extend(args) + p = subprocess.Popen(cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + (out, err) = p.communicate() + self.assertEqual(p.returncode, 0, + "bad returncode %d, stderr is %r" % + (p.returncode, err)) + return out.decode("latin1"), err.decode("latin1") + + def test_subinterps(self): + # This is just a "don't crash" test + out, err = self.run_embedded_interpreter() + if support.verbose: + print() + print(out) + print(err) + + @staticmethod + def _get_default_pipe_encoding(): + rp, wp = os.pipe() try: - exe = os.path.join(basepath, "Modules", "_testembed") - if not os.path.exists(exe): - self.skipTest("%r doesn't exist" % exe) - p = subprocess.Popen([exe], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) - (out, err) = p.communicate() - self.assertEqual(p.returncode, 0, - "bad returncode %d, stderr is %r" % - (p.returncode, err)) - if support.verbose: - print() - print(out.decode('latin1')) - print(err.decode('latin1')) + with os.fdopen(wp, 'w') as w: + default_pipe_encoding = w.encoding finally: - os.chdir(oldcwd) + os.close(rp) + return default_pipe_encoding + + def test_forced_io_encoding(self): + # Checks forced configuration of embedded interpreter IO streams + out, err = self.run_embedded_interpreter("forced_io_encoding") + if support.verbose: + print() + print(out) + print(err) + expected_stdin_encoding = sys.__stdin__.encoding + expected_pipe_encoding = self._get_default_pipe_encoding() + expected_output = os.linesep.join([ + "--- Use defaults ---", + "Expected encoding: default", + "Expected errors: default", + "stdin: {0}:strict", + "stdout: {1}:strict", + "stderr: {1}:backslashreplace", + "--- Set errors only ---", + "Expected encoding: default", + "Expected errors: surrogateescape", + "stdin: {0}:surrogateescape", + "stdout: {1}:surrogateescape", + "stderr: {1}:backslashreplace", + "--- Set encoding only ---", + "Expected encoding: latin-1", + "Expected errors: default", + "stdin: latin-1:strict", + "stdout: latin-1:strict", + "stderr: latin-1:backslashreplace", + "--- Set encoding and errors ---", + "Expected encoding: latin-1", + "Expected errors: surrogateescape", + "stdin: latin-1:surrogateescape", + "stdout: latin-1:surrogateescape", + "stderr: latin-1:backslashreplace"]).format(expected_stdin_encoding, + expected_pipe_encoding) + # This is useful if we ever trip over odd platform behaviour + self.maxDiff = None + self.assertEqual(out.strip(), expected_output) class SkipitemTest(unittest.TestCase): @@ -355,8 +461,9 @@ class Test_testcapi(unittest.TestCase): def test__testcapi(self): for name in dir(_testcapi): if name.startswith('test_'): - test = getattr(_testcapi, name) - test() + with self.subTest("internal", name=name): + test = getattr(_testcapi, name) + test() if __name__ == "__main__": unittest.main() diff --git a/Lib/test/test_cmd_line.py b/Lib/test/test_cmd_line.py index a89d7e4f10..327c1455fc 100644 --- a/Lib/test/test_cmd_line.py +++ b/Lib/test/test_cmd_line.py @@ -4,6 +4,7 @@ import test.support, unittest import os +import shutil import sys import subprocess import tempfile @@ -41,8 +42,10 @@ class CmdLineTest(unittest.TestCase): def test_version(self): version = ('Python %d.%d' % sys.version_info[:2]).encode("ascii") - rc, out, err = assert_python_ok('-V') - self.assertTrue(err.startswith(version)) + for switch in '-V', '--version': + rc, out, err = assert_python_ok(switch) + self.assertFalse(err.startswith(version)) + self.assertTrue(out.startswith(version)) def test_verbose(self): # -v causes imports to write to stderr. If the write to @@ -54,14 +57,49 @@ class CmdLineTest(unittest.TestCase): self.assertNotIn(b'stack overflow', err) def test_xoptions(self): - rc, out, err = assert_python_ok('-c', 'import sys; print(sys._xoptions)') - opts = eval(out.splitlines()[0]) + def get_xoptions(*args): + # use subprocess module directly because test.script_helper adds + # "-X faulthandler" to the command line + args = (sys.executable, '-E') + args + args += ('-c', 'import sys; print(sys._xoptions)') + out = subprocess.check_output(args) + opts = eval(out.splitlines()[0]) + return opts + + opts = get_xoptions() self.assertEqual(opts, {}) - rc, out, err = assert_python_ok( - '-Xa', '-Xb=c,d=e', '-c', 'import sys; print(sys._xoptions)') - opts = eval(out.splitlines()[0]) + + opts = get_xoptions('-Xa', '-Xb=c,d=e') self.assertEqual(opts, {'a': True, 'b': 'c,d=e'}) + def test_showrefcount(self): + def run_python(*args): + # this is similar to assert_python_ok but doesn't strip + # the refcount from stderr. It can be replaced once + # assert_python_ok stops doing that. + cmd = [sys.executable] + cmd.extend(args) + PIPE = subprocess.PIPE + p = subprocess.Popen(cmd, stdout=PIPE, stderr=PIPE) + out, err = p.communicate() + p.stdout.close() + p.stderr.close() + rc = p.returncode + self.assertEqual(rc, 0) + return rc, out, err + code = 'import sys; print(sys._xoptions)' + # normally the refcount is hidden + rc, out, err = run_python('-c', code) + self.assertEqual(out.rstrip(), b'{}') + self.assertEqual(err, b'') + # "-X showrefcount" shows the refcount, but only in debug builds + rc, out, err = run_python('-X', 'showrefcount', '-c', code) + self.assertEqual(out.rstrip(), b"{'showrefcount': True}") + if hasattr(sys, 'gettotalrefcount'): # debug build + self.assertRegex(err, br'^\[\d+ refs, \d+ blocks\]') + else: + self.assertEqual(err, b'') + def test_run_module(self): # Test expected operation of the '-m' switch # Switch needs an argument @@ -70,9 +108,9 @@ class CmdLineTest(unittest.TestCase): assert_python_failure('-m', 'fnord43520xyz') # Check the runpy module also gives an error for # a nonexistent module - assert_python_failure('-m', 'runpy', 'fnord43520xyz'), + assert_python_failure('-m', 'runpy', 'fnord43520xyz') # All good if module is located and run successfully - assert_python_ok('-m', 'timeit', '-n', '1'), + assert_python_ok('-m', 'timeit', '-n', '1') def test_run_module_bug1764407(self): # -m and -i need to play well together @@ -213,6 +251,23 @@ class CmdLineTest(unittest.TestCase): self.assertIn(path1.encode('ascii'), out) self.assertIn(path2.encode('ascii'), out) + def test_empty_PYTHONPATH_issue16309(self): + # On Posix, it is documented that setting PATH to the + # empty string is equivalent to not setting PATH at all, + # which is an exception to the rule that in a string like + # "/bin::/usr/bin" the empty string in the middle gets + # interpreted as '.' + code = """if 1: + import sys + path = ":".join(sys.path) + path = path.encode("ascii", "backslashreplace") + sys.stdout.buffer.write(path)""" + rc1, out1, err1 = assert_python_ok('-c', code, PYTHONPATH="") + rc2, out2, err2 = assert_python_ok('-c', code, __isolated=False) + # regarding to Posix specification, outputs should be equal + # for empty and unset PYTHONPATH + self.assertEqual(out1, out2) + def test_displayhook_unencodable(self): for encoding in ('ascii', 'latin-1', 'utf-8'): env = os.environ.copy() @@ -290,7 +345,7 @@ class CmdLineTest(unittest.TestCase): rc, out, err = assert_python_ok('-c', code) self.assertEqual(b'', out) self.assertRegex(err.decode('ascii', 'ignore'), - 'Exception OSError: .* ignored') + 'Exception ignored in.*\nOSError: .*') def test_closed_stdout(self): # Issue #13444: if stdout has been explicitly closed, we should @@ -385,6 +440,31 @@ class CmdLineTest(unittest.TestCase): self.assertEqual(b'', out) + def test_isolatedmode(self): + self.verify_valid_flag('-I') + self.verify_valid_flag('-IEs') + rc, out, err = assert_python_ok('-I', '-c', + 'from sys import flags as f; ' + 'print(f.no_user_site, f.ignore_environment, f.isolated)', + # dummyvar to prevent extranous -E + dummyvar="") + self.assertEqual(out.strip(), b'1 1 1') + with test.support.temp_cwd() as tmpdir: + fake = os.path.join(tmpdir, "uuid.py") + main = os.path.join(tmpdir, "main.py") + with open(fake, "w") as f: + f.write("raise RuntimeError('isolated mode test')\n") + with open(main, "w") as f: + f.write("import uuid\n") + f.write("print('ok')\n") + self.assertRaises(subprocess.CalledProcessError, + subprocess.check_output, + [sys.executable, main], cwd=tmpdir, + stderr=subprocess.DEVNULL) + out = subprocess.check_output([sys.executable, "-I", main], + cwd=tmpdir) + self.assertEqual(out.strip(), b"ok") + def test_main(): test.support.run_unittest(CmdLineTest) test.support.reap_children() diff --git a/Lib/test/test_cmd_line_script.py b/Lib/test/test_cmd_line_script.py index b4269b994e..1e6746d434 100644 --- a/Lib/test/test_cmd_line_script.py +++ b/Lib/test/test_cmd_line_script.py @@ -41,11 +41,28 @@ from importlib.machinery import BuiltinImporter _loader = __loader__ if __loader__ is BuiltinImporter else type(__loader__) print('__loader__==%a' % _loader) print('__file__==%a' % __file__) -assertEqual(__cached__, None) +print('__cached__==%a' % __cached__) print('__package__==%r' % __package__) +# Check PEP 451 details +import os.path +if __package__ is not None: + print('__main__ was located through the import system') + assertIdentical(__spec__.loader, __loader__) + expected_spec_name = os.path.splitext(os.path.basename(__file__))[0] + if __package__: + expected_spec_name = __package__ + "." + expected_spec_name + assertEqual(__spec__.name, expected_spec_name) + assertEqual(__spec__.parent, __package__) + assertIdentical(__spec__.submodule_search_locations, None) + assertEqual(__spec__.origin, __file__) + if __spec__.cached is not None: + assertEqual(__spec__.cached, __cached__) # Check the sys module import sys assertIdentical(globals(), sys.modules[__name__].__dict__) +if __spec__ is not None: + # XXX: We're not currently making __main__ available under its real name + pass # assertIdentical(globals(), sys.modules[__spec__.name].__dict__) from test import test_cmd_line_script example_args_list = test_cmd_line_script.example_args assertEqual(sys.argv[1:], example_args_list) @@ -123,7 +140,7 @@ class CmdLineTest(unittest.TestCase): if not __debug__: cmd_line_switches += ('-' + 'O' * sys.flags.optimize,) run_args = cmd_line_switches + (script_name,) + tuple(example_args) - rc, out, err = assert_python_ok(*run_args) + rc, out, err = assert_python_ok(*run_args, __isolated=False) self._check_output(script_name, rc, out + err, expected_file, expected_argv0, expected_path0, expected_package, expected_loader) @@ -294,7 +311,7 @@ class CmdLineTest(unittest.TestCase): pkg_dir = os.path.join(script_dir, 'test_pkg') make_pkg(pkg_dir, "import sys; print('init_argv0==%r' % sys.argv[0])") script_name = _make_test_script(pkg_dir, 'script') - rc, out, err = assert_python_ok('-m', 'test_pkg.script', *example_args) + rc, out, err = assert_python_ok('-m', 'test_pkg.script', *example_args, __isolated=False) if verbose > 1: print(out) expected = "init_argv0==%r" % '-m' @@ -311,7 +328,8 @@ class CmdLineTest(unittest.TestCase): with open("-c", "w") as f: f.write("data") rc, out, err = assert_python_ok('-c', - 'import sys; print("sys.path[0]==%r" % sys.path[0])') + 'import sys; print("sys.path[0]==%r" % sys.path[0])', + __isolated=False) if verbose > 1: print(out) expected = "sys.path[0]==%r" % '' @@ -325,7 +343,8 @@ class CmdLineTest(unittest.TestCase): with support.change_cwd(path=script_dir): with open("-m", "w") as f: f.write("data") - rc, out, err = assert_python_ok('-m', 'other', *example_args) + rc, out, err = assert_python_ok('-m', 'other', *example_args, + __isolated=False) self._check_output(script_name, rc, out, script_name, script_name, '', '', importlib.machinery.SourceFileLoader) @@ -386,6 +405,24 @@ class CmdLineTest(unittest.TestCase): 'stdout=%r stderr=%r' % (stdout, stderr)) self.assertEqual(0, rc) + def test_issue20500_exit_with_exception_value(self): + script = textwrap.dedent("""\ + import sys + error = None + try: + raise ValueError('some text') + except ValueError as err: + error = err + + if error: + sys.exit(error) + """) + with temp_dir() as script_dir: + script_name = _make_test_script(script_dir, 'script', script) + exitcode, stdout, stderr = assert_python_failure(script_name) + text = stderr.decode('ascii') + self.assertEqual(text, "some text") + def test_main(): support.run_unittest(CmdLineTest) diff --git a/Lib/test/test_code_module.py b/Lib/test/test_code_module.py index adef1701d4..5fd21dc32c 100644 --- a/Lib/test/test_code_module.py +++ b/Lib/test/test_code_module.py @@ -64,6 +64,20 @@ class TestInteractiveConsole(unittest.TestCase): self.console.interact() self.assertTrue(hook.called) + def test_banner(self): + # with banner + self.infunc.side_effect = EOFError('Finished') + self.console.interact(banner='Foo') + self.assertEqual(len(self.stderr.method_calls), 2) + banner_call = self.stderr.method_calls[0] + self.assertEqual(banner_call, ['write', ('Foo\n',), {}]) + + # no banner + self.stderr.reset_mock() + self.infunc.side_effect = EOFError('Finished') + self.console.interact(banner='') + self.assertEqual(len(self.stderr.method_calls), 1) + def test_main(): support.run_unittest(TestInteractiveConsole) diff --git a/Lib/test/test_codeccallbacks.py b/Lib/test/test_codeccallbacks.py index fd88505081..84804bb0da 100644 --- a/Lib/test/test_codeccallbacks.py +++ b/Lib/test/test_codeccallbacks.py @@ -875,8 +875,6 @@ class CodecCallbackTest(unittest.TestCase): with self.assertRaises(TypeError): data.decode(encoding, "test.replacing") -def test_main(): - test.support.run_unittest(CodecCallbackTest) if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_codecs.py b/Lib/test/test_codecs.py index a8b3da0f37..9b62d5b12f 100644 --- a/Lib/test/test_codecs.py +++ b/Lib/test/test_codecs.py @@ -1,4 +1,5 @@ import codecs +import contextlib import io import locale import sys @@ -342,8 +343,46 @@ class ReadTest(MixInCheckStateHandling): self.assertEqual(reader.readline(), s5) self.assertEqual(reader.readline(), "") + ill_formed_sequence_replace = "\ufffd" + + def test_lone_surrogates(self): + self.assertRaises(UnicodeEncodeError, "\ud800".encode, self.encoding) + self.assertEqual("[\uDC80]".encode(self.encoding, "backslashreplace"), + "[\\udc80]".encode(self.encoding)) + self.assertEqual("[\uDC80]".encode(self.encoding, "xmlcharrefreplace"), + "[�]".encode(self.encoding)) + self.assertEqual("[\uDC80]".encode(self.encoding, "ignore"), + "[]".encode(self.encoding)) + self.assertEqual("[\uDC80]".encode(self.encoding, "replace"), + "[?]".encode(self.encoding)) + + bom = "".encode(self.encoding) + for before, after in [("\U00010fff", "A"), ("[", "]"), + ("A", "\U00010fff")]: + before_sequence = before.encode(self.encoding)[len(bom):] + after_sequence = after.encode(self.encoding)[len(bom):] + test_string = before + "\uDC80" + after + test_sequence = (bom + before_sequence + + self.ill_formed_sequence + after_sequence) + self.assertRaises(UnicodeDecodeError, test_sequence.decode, + self.encoding) + self.assertEqual(test_string.encode(self.encoding, + "surrogatepass"), + test_sequence) + self.assertEqual(test_sequence.decode(self.encoding, + "surrogatepass"), + test_string) + self.assertEqual(test_sequence.decode(self.encoding, "ignore"), + before + after) + self.assertEqual(test_sequence.decode(self.encoding, "replace"), + before + self.ill_formed_sequence_replace + after) + class UTF32Test(ReadTest, unittest.TestCase): encoding = "utf-32" + if sys.byteorder == 'little': + ill_formed_sequence = b"\x80\xdc\x00\x00" + else: + ill_formed_sequence = b"\x00\x00\xdc\x80" spamle = (b'\xff\xfe\x00\x00' b's\x00\x00\x00p\x00\x00\x00a\x00\x00\x00m\x00\x00\x00' @@ -435,6 +474,7 @@ class UTF32Test(ReadTest, unittest.TestCase): class UTF32LETest(ReadTest, unittest.TestCase): encoding = "utf-32-le" + ill_formed_sequence = b"\x80\xdc\x00\x00" def test_partial(self): self.check_partial( @@ -479,6 +519,7 @@ class UTF32LETest(ReadTest, unittest.TestCase): class UTF32BETest(ReadTest, unittest.TestCase): encoding = "utf-32-be" + ill_formed_sequence = b"\x00\x00\xdc\x80" def test_partial(self): self.check_partial( @@ -524,6 +565,10 @@ class UTF32BETest(ReadTest, unittest.TestCase): class UTF16Test(ReadTest, unittest.TestCase): encoding = "utf-16" + if sys.byteorder == 'little': + ill_formed_sequence = b"\x80\xdc" + else: + ill_formed_sequence = b"\xdc\x80" spamle = b'\xff\xfes\x00p\x00a\x00m\x00s\x00p\x00a\x00m\x00' spambe = b'\xfe\xff\x00s\x00p\x00a\x00m\x00s\x00p\x00a\x00m' @@ -599,11 +644,14 @@ class UTF16Test(ReadTest, unittest.TestCase): self.addCleanup(support.unlink, support.TESTFN) with open(support.TESTFN, 'wb') as fp: fp.write(s) - with codecs.open(support.TESTFN, 'U', encoding=self.encoding) as reader: + with support.check_warnings(('', DeprecationWarning)): + reader = codecs.open(support.TESTFN, 'U', encoding=self.encoding) + with reader: self.assertEqual(reader.read(), s1) class UTF16LETest(ReadTest, unittest.TestCase): encoding = "utf-16-le" + ill_formed_sequence = b"\x80\xdc" def test_partial(self): self.check_partial( @@ -647,6 +695,7 @@ class UTF16LETest(ReadTest, unittest.TestCase): class UTF16BETest(ReadTest, unittest.TestCase): encoding = "utf-16-be" + ill_formed_sequence = b"\xdc\x80" def test_partial(self): self.check_partial( @@ -690,6 +739,8 @@ class UTF16BETest(ReadTest, unittest.TestCase): class UTF8Test(ReadTest, unittest.TestCase): encoding = "utf-8" + ill_formed_sequence = b"\xed\xb2\x80" + ill_formed_sequence_replace = "\ufffd" * 3 def test_partial(self): self.check_partial( @@ -719,18 +770,11 @@ class UTF8Test(ReadTest, unittest.TestCase): u, u.encode(self.encoding)) def test_lone_surrogates(self): - self.assertRaises(UnicodeEncodeError, "\ud800".encode, "utf-8") - self.assertRaises(UnicodeDecodeError, b"\xed\xa0\x80".decode, "utf-8") - self.assertEqual("[\uDC80]".encode("utf-8", "backslashreplace"), - b'[\\udc80]') - self.assertEqual("[\uDC80]".encode("utf-8", "xmlcharrefreplace"), - b'[�]') - self.assertEqual("[\uDC80]".encode("utf-8", "surrogateescape"), + super().test_lone_surrogates() + # not sure if this is making sense for + # UTF-16 and UTF-32 + self.assertEqual("[\uDC80]".encode('utf-8', "surrogateescape"), b'[\x80]') - self.assertEqual("[\uDC80]".encode("utf-8", "ignore"), - b'[]') - self.assertEqual("[\uDC80]".encode("utf-8", "replace"), - b'[?]') def test_surrogatepass_handler(self): self.assertEqual("abc\ud800def".encode("utf-8", "surrogatepass"), @@ -913,15 +957,19 @@ class UTF7Test(ReadTest, unittest.TestCase): (b'a+////,+IKw-b', 'a\uffff\ufffd\u20acb'), ] for raw, expected in tests: - self.assertRaises(UnicodeDecodeError, codecs.utf_7_decode, - raw, 'strict', True) - self.assertEqual(raw.decode('utf-7', 'replace'), expected) + with self.subTest(raw=raw): + self.assertRaises(UnicodeDecodeError, codecs.utf_7_decode, + raw, 'strict', True) + self.assertEqual(raw.decode('utf-7', 'replace'), expected) def test_nonbmp(self): self.assertEqual('\U000104A0'.encode(self.encoding), b'+2AHcoA-') self.assertEqual('\ud801\udca0'.encode(self.encoding), b'+2AHcoA-') self.assertEqual(b'+2AHcoA-'.decode(self.encoding), '\U000104A0') + test_lone_surrogates = None + + class UTF16ExTest(unittest.TestCase): def test_errors(self): @@ -946,7 +994,7 @@ class ReadBufferTest(unittest.TestCase): self.assertRaises(TypeError, codecs.readbuffer_encode) self.assertRaises(TypeError, codecs.readbuffer_encode, 42) -class UTF8SigTest(ReadTest, unittest.TestCase): +class UTF8SigTest(UTF8Test, unittest.TestCase): encoding = "utf-8-sig" def test_partial(self): @@ -1628,6 +1676,7 @@ all_unicode_encodings = [ "cp037", "cp1006", "cp1026", + "cp1125", "cp1140", "cp1250", "cp1251", @@ -2370,60 +2419,93 @@ bytes_transform_encodings = [ "quopri_codec", "hex_codec", ] + +transform_aliases = { + "base64_codec": ["base64", "base_64"], + "uu_codec": ["uu"], + "quopri_codec": ["quopri", "quoted_printable", "quotedprintable"], + "hex_codec": ["hex"], + "rot_13": ["rot13"], +} + try: import zlib except ImportError: - pass + zlib = None else: bytes_transform_encodings.append("zlib_codec") + transform_aliases["zlib_codec"] = ["zip", "zlib"] try: import bz2 except ImportError: pass else: bytes_transform_encodings.append("bz2_codec") + transform_aliases["bz2_codec"] = ["bz2"] class TransformCodecTest(unittest.TestCase): def test_basics(self): binput = bytes(range(256)) for encoding in bytes_transform_encodings: - # generic codecs interface - (o, size) = codecs.getencoder(encoding)(binput) - self.assertEqual(size, len(binput)) - (i, size) = codecs.getdecoder(encoding)(o) - self.assertEqual(size, len(o)) - self.assertEqual(i, binput) + with self.subTest(encoding=encoding): + # generic codecs interface + (o, size) = codecs.getencoder(encoding)(binput) + self.assertEqual(size, len(binput)) + (i, size) = codecs.getdecoder(encoding)(o) + self.assertEqual(size, len(o)) + self.assertEqual(i, binput) def test_read(self): for encoding in bytes_transform_encodings: - sin = codecs.encode(b"\x80", encoding) - reader = codecs.getreader(encoding)(io.BytesIO(sin)) - sout = reader.read() - self.assertEqual(sout, b"\x80") + with self.subTest(encoding=encoding): + sin = codecs.encode(b"\x80", encoding) + reader = codecs.getreader(encoding)(io.BytesIO(sin)) + sout = reader.read() + self.assertEqual(sout, b"\x80") def test_readline(self): for encoding in bytes_transform_encodings: - sin = codecs.encode(b"\x80", encoding) - reader = codecs.getreader(encoding)(io.BytesIO(sin)) - sout = reader.readline() - self.assertEqual(sout, b"\x80") + with self.subTest(encoding=encoding): + sin = codecs.encode(b"\x80", encoding) + reader = codecs.getreader(encoding)(io.BytesIO(sin)) + sout = reader.readline() + self.assertEqual(sout, b"\x80") + + def test_buffer_api_usage(self): + # We check all the transform codecs accept memoryview input + # for encoding and decoding + # and also that they roundtrip correctly + original = b"12345\x80" + for encoding in bytes_transform_encodings: + with self.subTest(encoding=encoding): + data = original + view = memoryview(data) + data = codecs.encode(data, encoding) + view_encoded = codecs.encode(view, encoding) + self.assertEqual(view_encoded, data) + view = memoryview(data) + data = codecs.decode(data, encoding) + self.assertEqual(data, original) + view_decoded = codecs.decode(view, encoding) + self.assertEqual(view_decoded, data) def test_text_to_binary_blacklists_binary_transforms(self): # Check binary -> binary codecs give a good error for str input bad_input = "bad input type" for encoding in bytes_transform_encodings: - fmt = (r"{!r} is not a text encoding; " - r"use codecs.encode\(\) to handle arbitrary codecs") - msg = fmt.format(encoding) - with self.assertRaisesRegex(LookupError, msg) as failure: - bad_input.encode(encoding) - self.assertIsNone(failure.exception.__cause__) + with self.subTest(encoding=encoding): + fmt = ( "{!r} is not a text encoding; " + "use codecs.encode\(\) to handle arbitrary codecs") + msg = fmt.format(encoding) + with self.assertRaisesRegex(LookupError, msg) as failure: + bad_input.encode(encoding) + self.assertIsNone(failure.exception.__cause__) def test_text_to_binary_blacklists_text_transforms(self): # Check str.encode gives a good error message for str -> str codecs msg = (r"^'rot_13' is not a text encoding; " - r"use codecs.encode\(\) to handle arbitrary codecs") + "use codecs.encode\(\) to handle arbitrary codecs") with self.assertRaisesRegex(LookupError, msg): "just an example message".encode("rot_13") @@ -2432,23 +2514,224 @@ class TransformCodecTest(unittest.TestCase): # message for binary -> binary codecs data = b"encode first to ensure we meet any format restrictions" for encoding in bytes_transform_encodings: - encoded_data = codecs.encode(data, encoding) - fmt = (r"{!r} is not a text encoding; " - r"use codecs.decode\(\) to handle arbitrary codecs") - msg = fmt.format(encoding) - with self.assertRaisesRegex(LookupError, msg): - encoded_data.decode(encoding) - with self.assertRaisesRegex(LookupError, msg): - bytearray(encoded_data).decode(encoding) + with self.subTest(encoding=encoding): + encoded_data = codecs.encode(data, encoding) + fmt = (r"{!r} is not a text encoding; " + "use codecs.decode\(\) to handle arbitrary codecs") + msg = fmt.format(encoding) + with self.assertRaisesRegex(LookupError, msg): + encoded_data.decode(encoding) + with self.assertRaisesRegex(LookupError, msg): + bytearray(encoded_data).decode(encoding) def test_binary_to_text_blacklists_text_transforms(self): # Check str -> str codec gives a good error for binary input for bad_input in (b"immutable", bytearray(b"mutable")): - msg = (r"^'rot_13' is not a text encoding; " - r"use codecs.decode\(\) to handle arbitrary codecs") - with self.assertRaisesRegex(LookupError, msg) as failure: - bad_input.decode("rot_13") - self.assertIsNone(failure.exception.__cause__) + with self.subTest(bad_input=bad_input): + msg = (r"^'rot_13' is not a text encoding; " + "use codecs.decode\(\) to handle arbitrary codecs") + with self.assertRaisesRegex(LookupError, msg) as failure: + bad_input.decode("rot_13") + self.assertIsNone(failure.exception.__cause__) + + @unittest.skipUnless(zlib, "Requires zlib support") + def test_custom_zlib_error_is_wrapped(self): + # Check zlib codec gives a good error for malformed input + msg = "^decoding with 'zlib_codec' codec failed" + with self.assertRaisesRegex(Exception, msg) as failure: + codecs.decode(b"hello", "zlib_codec") + self.assertIsInstance(failure.exception.__cause__, + type(failure.exception)) + + def test_custom_hex_error_is_wrapped(self): + # Check hex codec gives a good error for malformed input + msg = "^decoding with 'hex_codec' codec failed" + with self.assertRaisesRegex(Exception, msg) as failure: + codecs.decode(b"hello", "hex_codec") + self.assertIsInstance(failure.exception.__cause__, + type(failure.exception)) + + # Unfortunately, the bz2 module throws OSError, which the codec + # machinery currently can't wrap :( + + # Ensure codec aliases from http://bugs.python.org/issue7475 work + def test_aliases(self): + for codec_name, aliases in transform_aliases.items(): + expected_name = codecs.lookup(codec_name).name + for alias in aliases: + with self.subTest(alias=alias): + info = codecs.lookup(alias) + self.assertEqual(info.name, expected_name) + + +# The codec system tries to wrap exceptions in order to ensure the error +# mentions the operation being performed and the codec involved. We +# currently *only* want this to happen for relatively stateless +# exceptions, where the only significant information they contain is their +# type and a single str argument. + +# Use a local codec registry to avoid appearing to leak objects when +# registering multiple seach functions +_TEST_CODECS = {} + +def _get_test_codec(codec_name): + return _TEST_CODECS.get(codec_name) +codecs.register(_get_test_codec) # Returns None, not usable as a decorator + +class ExceptionChainingTest(unittest.TestCase): + + def setUp(self): + # There's no way to unregister a codec search function, so we just + # ensure we render this one fairly harmless after the test + # case finishes by using the test case repr as the codec name + # The codecs module normalizes codec names, although this doesn't + # appear to be formally documented... + # We also make sure we use a truly unique id for the custom codec + # to avoid issues with the codec cache when running these tests + # multiple times (e.g. when hunting for refleaks) + unique_id = repr(self) + str(id(self)) + self.codec_name = encodings.normalize_encoding(unique_id).lower() + + # We store the object to raise on the instance because of a bad + # interaction between the codec caching (which means we can't + # recreate the codec entry) and regrtest refleak hunting (which + # runs the same test instance multiple times). This means we + # need to ensure the codecs call back in to the instance to find + # out which exception to raise rather than binding them in a + # closure to an object that may change on the next run + self.obj_to_raise = RuntimeError + + def tearDown(self): + _TEST_CODECS.pop(self.codec_name, None) + + def set_codec(self, encode, decode): + codec_info = codecs.CodecInfo(encode, decode, + name=self.codec_name) + _TEST_CODECS[self.codec_name] = codec_info + + @contextlib.contextmanager + def assertWrapped(self, operation, exc_type, msg): + full_msg = r"{} with {!r} codec failed \({}: {}\)".format( + operation, self.codec_name, exc_type.__name__, msg) + with self.assertRaisesRegex(exc_type, full_msg) as caught: + yield caught + self.assertIsInstance(caught.exception.__cause__, exc_type) + self.assertIsNotNone(caught.exception.__cause__.__traceback__) + + def raise_obj(self, *args, **kwds): + # Helper to dynamically change the object raised by a test codec + raise self.obj_to_raise + + def check_wrapped(self, obj_to_raise, msg, exc_type=RuntimeError): + self.obj_to_raise = obj_to_raise + self.set_codec(self.raise_obj, self.raise_obj) + with self.assertWrapped("encoding", exc_type, msg): + "str_input".encode(self.codec_name) + with self.assertWrapped("encoding", exc_type, msg): + codecs.encode("str_input", self.codec_name) + with self.assertWrapped("decoding", exc_type, msg): + b"bytes input".decode(self.codec_name) + with self.assertWrapped("decoding", exc_type, msg): + codecs.decode(b"bytes input", self.codec_name) + + def test_raise_by_type(self): + self.check_wrapped(RuntimeError, "") + + def test_raise_by_value(self): + msg = "This should be wrapped" + self.check_wrapped(RuntimeError(msg), msg) + + def test_raise_grandchild_subclass_exact_size(self): + msg = "This should be wrapped" + class MyRuntimeError(RuntimeError): + __slots__ = () + self.check_wrapped(MyRuntimeError(msg), msg, MyRuntimeError) + + def test_raise_subclass_with_weakref_support(self): + msg = "This should be wrapped" + class MyRuntimeError(RuntimeError): + pass + self.check_wrapped(MyRuntimeError(msg), msg, MyRuntimeError) + + def check_not_wrapped(self, obj_to_raise, msg): + def raise_obj(*args, **kwds): + raise obj_to_raise + self.set_codec(raise_obj, raise_obj) + with self.assertRaisesRegex(RuntimeError, msg): + "str input".encode(self.codec_name) + with self.assertRaisesRegex(RuntimeError, msg): + codecs.encode("str input", self.codec_name) + with self.assertRaisesRegex(RuntimeError, msg): + b"bytes input".decode(self.codec_name) + with self.assertRaisesRegex(RuntimeError, msg): + codecs.decode(b"bytes input", self.codec_name) + + def test_init_override_is_not_wrapped(self): + class CustomInit(RuntimeError): + def __init__(self): + pass + self.check_not_wrapped(CustomInit, "") + + def test_new_override_is_not_wrapped(self): + class CustomNew(RuntimeError): + def __new__(cls): + return super().__new__(cls) + self.check_not_wrapped(CustomNew, "") + + def test_instance_attribute_is_not_wrapped(self): + msg = "This should NOT be wrapped" + exc = RuntimeError(msg) + exc.attr = 1 + self.check_not_wrapped(exc, "^{}$".format(msg)) + + def test_non_str_arg_is_not_wrapped(self): + self.check_not_wrapped(RuntimeError(1), "1") + + def test_multiple_args_is_not_wrapped(self): + msg_re = r"^\('a', 'b', 'c'\)$" + self.check_not_wrapped(RuntimeError('a', 'b', 'c'), msg_re) + + # http://bugs.python.org/issue19609 + def test_codec_lookup_failure_not_wrapped(self): + msg = "^unknown encoding: {}$".format(self.codec_name) + # The initial codec lookup should not be wrapped + with self.assertRaisesRegex(LookupError, msg): + "str input".encode(self.codec_name) + with self.assertRaisesRegex(LookupError, msg): + codecs.encode("str input", self.codec_name) + with self.assertRaisesRegex(LookupError, msg): + b"bytes input".decode(self.codec_name) + with self.assertRaisesRegex(LookupError, msg): + codecs.decode(b"bytes input", self.codec_name) + + def test_unflagged_non_text_codec_handling(self): + # The stdlib non-text codecs are now marked so they're + # pre-emptively skipped by the text model related methods + # However, third party codecs won't be flagged, so we still make + # sure the case where an inappropriate output type is produced is + # handled appropriately + def encode_to_str(*args, **kwds): + return "not bytes!", 0 + def decode_to_bytes(*args, **kwds): + return b"not str!", 0 + self.set_codec(encode_to_str, decode_to_bytes) + # No input or output type checks on the codecs module functions + encoded = codecs.encode(None, self.codec_name) + self.assertEqual(encoded, "not bytes!") + decoded = codecs.decode(None, self.codec_name) + self.assertEqual(decoded, b"not str!") + # Text model methods should complain + fmt = (r"^{!r} encoder returned 'str' instead of 'bytes'; " + "use codecs.encode\(\) to encode to arbitrary types$") + msg = fmt.format(self.codec_name) + with self.assertRaisesRegex(TypeError, msg): + "str_input".encode(self.codec_name) + fmt = (r"^{!r} decoder returned 'bytes' instead of 'str'; " + "use codecs.decode\(\) to decode to arbitrary types$") + msg = fmt.format(self.codec_name) + with self.assertRaisesRegex(TypeError, msg): + b"bytes input".decode(self.codec_name) + @unittest.skipUnless(sys.platform == 'win32', @@ -2460,8 +2743,8 @@ class CodePageTest(unittest.TestCase): def test_invalid_code_page(self): self.assertRaises(ValueError, codecs.code_page_encode, -1, 'a') self.assertRaises(ValueError, codecs.code_page_decode, -1, b'a') - self.assertRaises(WindowsError, codecs.code_page_encode, 123, 'a') - self.assertRaises(WindowsError, codecs.code_page_decode, 123, b'a') + self.assertRaises(OSError, codecs.code_page_encode, 123, 'a') + self.assertRaises(OSError, codecs.code_page_decode, 123, b'a') def test_code_page_name(self): self.assertRaisesRegex(UnicodeEncodeError, 'cp932', diff --git a/Lib/test/test_coding.py b/Lib/test/test_coding.py deleted file mode 100644 index 989c7a85d2..0000000000 --- a/Lib/test/test_coding.py +++ /dev/null @@ -1,63 +0,0 @@ -import unittest -from test.support import TESTFN, unlink, unload -import importlib, os, sys - -class CodingTest(unittest.TestCase): - def test_bad_coding(self): - module_name = 'bad_coding' - self.verify_bad_module(module_name) - - def test_bad_coding2(self): - module_name = 'bad_coding2' - self.verify_bad_module(module_name) - - def verify_bad_module(self, module_name): - self.assertRaises(SyntaxError, __import__, 'test.' + module_name) - - path = os.path.dirname(__file__) - filename = os.path.join(path, module_name + '.py') - with open(filename, "rb") as fp: - bytes = fp.read() - self.assertRaises(SyntaxError, compile, bytes, filename, 'exec') - - def test_exec_valid_coding(self): - d = {} - exec(b'# coding: cp949\na = "\xaa\xa7"\n', d) - self.assertEqual(d['a'], '\u3047') - - def test_file_parse(self): - # issue1134: all encodings outside latin-1 and utf-8 fail on - # multiline strings and long lines (>512 columns) - unload(TESTFN) - filename = TESTFN + ".py" - f = open(filename, "w", encoding="cp1252") - sys.path.insert(0, os.curdir) - try: - with f: - f.write("# -*- coding: cp1252 -*-\n") - f.write("'''A short string\n") - f.write("'''\n") - f.write("'A very long string %s'\n" % ("X" * 1000)) - - importlib.invalidate_caches() - __import__(TESTFN) - finally: - del sys.path[0] - unlink(filename) - unlink(filename + "c") - unlink(filename + "o") - unload(TESTFN) - - def test_error_from_string(self): - # See http://bugs.python.org/issue6289 - input = "# coding: ascii\n\N{SNOWMAN}".encode('utf-8') - with self.assertRaises(SyntaxError) as c: - compile(input, "<string>", "exec") - expected = "'ascii' codec can't decode byte 0xe2 in position 16: " \ - "ordinal not in range(128)" - self.assertTrue(c.exception.args[0].startswith(expected), - msg=c.exception.args[0]) - - -if __name__ == "__main__": - unittest.main() diff --git a/Lib/test/test_collections.py b/Lib/test/test_collections.py index ff52755354..ee28a6c0b3 100644 --- a/Lib/test/test_collections.py +++ b/Lib/test/test_collections.py @@ -112,6 +112,38 @@ class TestChainMap(unittest.TestCase): self.assertEqual(dict(d), dict(a=1, b=2, c=30)) self.assertEqual(dict(d.items()), dict(a=1, b=2, c=30)) + def test_new_child(self): + 'Tests for changes for issue #16613.' + c = ChainMap() + c['a'] = 1 + c['b'] = 2 + m = {'b':20, 'c': 30} + d = c.new_child(m) + self.assertEqual(d.maps, [{'b':20, 'c':30}, {'a':1, 'b':2}]) # check internal state + self.assertIs(m, d.maps[0]) + + # Use a different map than a dict + class lowerdict(dict): + def __getitem__(self, key): + if isinstance(key, str): + key = key.lower() + return dict.__getitem__(self, key) + def __contains__(self, key): + if isinstance(key, str): + key = key.lower() + return dict.__contains__(self, key) + + c = ChainMap() + c['a'] = 1 + c['b'] = 2 + m = lowerdict(b=20, c=30) + d = c.new_child(m) + self.assertIs(m, d.maps[0]) + for key in 'abc': # check contains + self.assertIn(key, d) + for k, v in dict(a=1, B=20, C=30, z=100).items(): # check get + self.assertEqual(d.get(k, 100), v) + ################################################################################ ### Named Tuples @@ -269,7 +301,7 @@ class TestNamedTuple(unittest.TestCase): for module in (pickle,): loads = getattr(module, 'loads') dumps = getattr(module, 'dumps') - for protocol in -1, 0, 1, 2: + for protocol in range(-1, module.HIGHEST_PROTOCOL + 1): q = loads(dumps(p, protocol)) self.assertEqual(p, q) self.assertEqual(p._fields, q._fields) @@ -750,6 +782,8 @@ class TestCollectionABCs(ABCTestCase): self.assertTrue(issubclass(sample, Sequence)) self.assertIsInstance(range(10), Sequence) self.assertTrue(issubclass(range, Sequence)) + self.assertIsInstance(memoryview(b""), Sequence) + self.assertTrue(issubclass(memoryview, Sequence)) self.assertTrue(issubclass(str, Sequence)) self.validate_abstract_methods(Sequence, '__contains__', '__iter__', '__len__', '__getitem__') @@ -904,23 +938,26 @@ class TestCounter(unittest.TestCase): words = Counter('which witch had which witches wrist watch'.split()) update_test = Counter() update_test.update(words) - for i, dup in enumerate([ - words.copy(), - copy.copy(words), - copy.deepcopy(words), - pickle.loads(pickle.dumps(words, 0)), - pickle.loads(pickle.dumps(words, 1)), - pickle.loads(pickle.dumps(words, 2)), - pickle.loads(pickle.dumps(words, -1)), - eval(repr(words)), - update_test, - Counter(words), - ]): - msg = (i, dup, words) - self.assertTrue(dup is not words) - self.assertEqual(dup, words) - self.assertEqual(len(dup), len(words)) - self.assertEqual(type(dup), type(words)) + for label, dup in [ + ('words.copy()', words.copy()), + ('copy.copy(words)', copy.copy(words)), + ('copy.deepcopy(words)', copy.deepcopy(words)), + ('pickle.loads(pickle.dumps(words, 0))', + pickle.loads(pickle.dumps(words, 0))), + ('pickle.loads(pickle.dumps(words, 1))', + pickle.loads(pickle.dumps(words, 1))), + ('pickle.loads(pickle.dumps(words, 2))', + pickle.loads(pickle.dumps(words, 2))), + ('pickle.loads(pickle.dumps(words, -1))', + pickle.loads(pickle.dumps(words, -1))), + ('eval(repr(words))', eval(repr(words))), + ('update_test', update_test), + ('Counter(words)', Counter(words)), + ]: + with self.subTest(label=label): + msg = "\ncopy: %s\nwords: %s" % (dup, words) + self.assertIsNot(dup, words, msg) + self.assertEqual(dup, words) def test_copy_subclass(self): class MyCounter(Counter): @@ -1043,8 +1080,10 @@ class TestCounter(unittest.TestCase): # test fidelity to the pure python version c = CounterSubclassWithSetItem('abracadabra') self.assertTrue(c.called) + self.assertEqual(dict(c), {'a': 5, 'b': 2, 'c': 1, 'd': 1, 'r':2 }) c = CounterSubclassWithGet('abracadabra') self.assertTrue(c.called) + self.assertEqual(dict(c), {'a': 5, 'b': 2, 'c': 1, 'd': 1, 'r':2 }) ################################################################################ @@ -1205,24 +1244,28 @@ class TestOrderedDict(unittest.TestCase): od = OrderedDict(pairs) update_test = OrderedDict() update_test.update(od) - for i, dup in enumerate([ - od.copy(), - copy.copy(od), - copy.deepcopy(od), - pickle.loads(pickle.dumps(od, 0)), - pickle.loads(pickle.dumps(od, 1)), - pickle.loads(pickle.dumps(od, 2)), - pickle.loads(pickle.dumps(od, 3)), - pickle.loads(pickle.dumps(od, -1)), - eval(repr(od)), - update_test, - OrderedDict(od), - ]): - self.assertTrue(dup is not od) - self.assertEqual(dup, od) - self.assertEqual(list(dup.items()), list(od.items())) - self.assertEqual(len(dup), len(od)) - self.assertEqual(type(dup), type(od)) + for label, dup in [ + ('od.copy()', od.copy()), + ('copy.copy(od)', copy.copy(od)), + ('copy.deepcopy(od)', copy.deepcopy(od)), + ('pickle.loads(pickle.dumps(od, 0))', + pickle.loads(pickle.dumps(od, 0))), + ('pickle.loads(pickle.dumps(od, 1))', + pickle.loads(pickle.dumps(od, 1))), + ('pickle.loads(pickle.dumps(od, 2))', + pickle.loads(pickle.dumps(od, 2))), + ('pickle.loads(pickle.dumps(od, 3))', + pickle.loads(pickle.dumps(od, 3))), + ('pickle.loads(pickle.dumps(od, -1))', + pickle.loads(pickle.dumps(od, -1))), + ('eval(repr(od))', eval(repr(od))), + ('update_test', update_test), + ('OrderedDict(od)', OrderedDict(od)), + ]: + with self.subTest(label=label): + msg = "\ncopy: %s\nod: %s" % (dup, od) + self.assertIsNot(dup, od, msg) + self.assertEqual(dup, od) def test_yaml_linkage(self): # Verify that __reduce__ is setup in a way that supports PyYAML's dump() feature. @@ -1237,9 +1280,18 @@ class TestOrderedDict(unittest.TestCase): # do not save instance dictionary if not needed pairs = [('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)] od = OrderedDict(pairs) - self.assertEqual(len(od.__reduce__()), 2) + self.assertIsNone(od.__reduce__()[2]) od.x = 10 - self.assertEqual(len(od.__reduce__()), 3) + self.assertIsNotNone(od.__reduce__()[2]) + + def test_pickle_recursive(self): + od = OrderedDict() + od[1] = od + for proto in range(-1, pickle.HIGHEST_PROTOCOL + 1): + dup = pickle.loads(pickle.dumps(od, proto)) + self.assertIsNot(dup, od) + self.assertEqual(list(dup.keys()), [1]) + self.assertIs(dup[1], dup) def test_repr(self): od = OrderedDict([('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)]) diff --git a/Lib/test/test_colorsys.py b/Lib/test/test_colorsys.py index e405b8a4d0..a24e3adcb4 100644 --- a/Lib/test/test_colorsys.py +++ b/Lib/test/test_colorsys.py @@ -1,4 +1,4 @@ -import unittest, test.support +import unittest import colorsys def frange(start, stop, step): @@ -69,8 +69,32 @@ class ColorsysTest(unittest.TestCase): self.assertTripleEqual(hls, colorsys.rgb_to_hls(*rgb)) self.assertTripleEqual(rgb, colorsys.hls_to_rgb(*hls)) -def test_main(): - test.support.run_unittest(ColorsysTest) + def test_yiq_roundtrip(self): + for r in frange(0.0, 1.0, 0.2): + for g in frange(0.0, 1.0, 0.2): + for b in frange(0.0, 1.0, 0.2): + rgb = (r, g, b) + self.assertTripleEqual( + rgb, + colorsys.yiq_to_rgb(*colorsys.rgb_to_yiq(*rgb)) + ) + + def test_yiq_values(self): + values = [ + # rgb, yiq + ((0.0, 0.0, 0.0), (0.0, 0.0, 0.0)), # black + ((0.0, 0.0, 1.0), (0.11, -0.3217, 0.3121)), # blue + ((0.0, 1.0, 0.0), (0.59, -0.2773, -0.5251)), # green + ((0.0, 1.0, 1.0), (0.7, -0.599, -0.213)), # cyan + ((1.0, 0.0, 0.0), (0.3, 0.599, 0.213)), # red + ((1.0, 0.0, 1.0), (0.41, 0.2773, 0.5251)), # purple + ((1.0, 1.0, 0.0), (0.89, 0.3217, -0.3121)), # yellow + ((1.0, 1.0, 1.0), (1.0, 0.0, 0.0)), # white + ((0.5, 0.5, 0.5), (0.5, 0.0, 0.0)), # grey + ] + for (rgb, yiq) in values: + self.assertTripleEqual(yiq, colorsys.rgb_to_yiq(*rgb)) + self.assertTripleEqual(rgb, colorsys.yiq_to_rgb(*yiq)) if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_compileall.py b/Lib/test/test_compileall.py index 0505c521d1..2a42238755 100644 --- a/Lib/test/test_compileall.py +++ b/Lib/test/test_compileall.py @@ -1,11 +1,10 @@ import sys import compileall -import imp +import importlib.util import os import py_compile import shutil import struct -import subprocess import tempfile import time import unittest @@ -18,11 +17,11 @@ class CompileallTests(unittest.TestCase): def setUp(self): self.directory = tempfile.mkdtemp() self.source_path = os.path.join(self.directory, '_test.py') - self.bc_path = imp.cache_from_source(self.source_path) + self.bc_path = importlib.util.cache_from_source(self.source_path) with open(self.source_path, 'w') as file: file.write('x = 123\n') self.source_path2 = os.path.join(self.directory, '_test2.py') - self.bc_path2 = imp.cache_from_source(self.source_path2) + self.bc_path2 = importlib.util.cache_from_source(self.source_path2) shutil.copyfile(self.source_path, self.source_path2) self.subdirectory = os.path.join(self.directory, '_subdir') os.mkdir(self.subdirectory) @@ -36,7 +35,7 @@ class CompileallTests(unittest.TestCase): with open(self.bc_path, 'rb') as file: data = file.read(8) mtime = int(os.stat(self.source_path).st_mtime) - compare = struct.pack('<4sl', imp.get_magic(), mtime) + compare = struct.pack('<4sl', importlib.util.MAGIC_NUMBER, mtime) return data, compare @unittest.skipUnless(hasattr(os, 'stat'), 'test needs os.stat()') @@ -56,7 +55,8 @@ class CompileallTests(unittest.TestCase): def test_mtime(self): # Test a change in mtime leads to a new .pyc. - self.recreation_check(struct.pack('<4sl', imp.get_magic(), 1)) + self.recreation_check(struct.pack('<4sl', importlib.util.MAGIC_NUMBER, + 1)) def test_magic_number(self): # Test a change in mtime leads to a new .pyc. @@ -96,14 +96,14 @@ class CompileallTests(unittest.TestCase): # interpreter's creates the correct file names optimize = 1 if __debug__ else 0 compileall.compile_dir(self.directory, quiet=True, optimize=optimize) - cached = imp.cache_from_source(self.source_path, - debug_override=not optimize) + cached = importlib.util.cache_from_source(self.source_path, + debug_override=not optimize) self.assertTrue(os.path.isfile(cached)) - cached2 = imp.cache_from_source(self.source_path2, - debug_override=not optimize) + cached2 = importlib.util.cache_from_source(self.source_path2, + debug_override=not optimize) self.assertTrue(os.path.isfile(cached2)) - cached3 = imp.cache_from_source(self.source_path3, - debug_override=not optimize) + cached3 = importlib.util.cache_from_source(self.source_path3, + debug_override=not optimize) self.assertTrue(os.path.isfile(cached3)) @@ -151,10 +151,12 @@ class CommandLineTests(unittest.TestCase): return rc, out, err def assertCompiled(self, fn): - self.assertTrue(os.path.exists(imp.cache_from_source(fn))) + path = importlib.util.cache_from_source(fn) + self.assertTrue(os.path.exists(path)) def assertNotCompiled(self, fn): - self.assertFalse(os.path.exists(imp.cache_from_source(fn))) + path = importlib.util.cache_from_source(fn) + self.assertFalse(os.path.exists(path)) def setUp(self): self.addCleanup(self._cleanup) @@ -180,7 +182,7 @@ class CommandLineTests(unittest.TestCase): def test_no_args_respects_force_flag(self): bazfn = script_helper.make_script(self.directory, 'baz', '') self.assertRunOK(PYTHONPATH=self.directory) - pycpath = imp.cache_from_source(bazfn) + pycpath = importlib.util.cache_from_source(bazfn) # Set atime/mtime backward to avoid file timestamp resolution issues os.utime(pycpath, (time.time()-60,)*2) mtime = os.stat(pycpath).st_mtime @@ -212,8 +214,8 @@ class CommandLineTests(unittest.TestCase): ['-m', 'compileall', '-q', self.pkgdir])) # Verify the __pycache__ directory contents. self.assertTrue(os.path.exists(self.pkgdir_cachedir)) - expected = sorted(base.format(imp.get_tag(), ext) for base in - ('__init__.{}.{}', 'bar.{}.{}')) + expected = sorted(base.format(sys.implementation.cache_tag, ext) + for base in ('__init__.{}.{}', 'bar.{}.{}')) self.assertEqual(sorted(os.listdir(self.pkgdir_cachedir)), expected) # Make sure there are no .pyc files in the source directory. self.assertFalse([fn for fn in os.listdir(self.pkgdir) @@ -246,7 +248,7 @@ class CommandLineTests(unittest.TestCase): def test_force(self): self.assertRunOK('-q', self.pkgdir) - pycpath = imp.cache_from_source(self.barfn) + pycpath = importlib.util.cache_from_source(self.barfn) # set atime/mtime backward to avoid file timestamp resolution issues os.utime(pycpath, (time.time()-60,)*2) mtime = os.stat(pycpath).st_mtime @@ -310,10 +312,10 @@ class CommandLineTests(unittest.TestCase): bazfn = script_helper.make_script(self.pkgdir, 'baz', 'raise Exception') self.assertRunOK('-q', '-d', 'dinsdale', self.pkgdir) fn = script_helper.make_script(self.pkgdir, 'bing', 'import baz') - pyc = imp.cache_from_source(bazfn) + pyc = importlib.util.cache_from_source(bazfn) os.rename(pyc, os.path.join(self.pkgdir, 'baz.pyc')) os.remove(bazfn) - rc, out, err = script_helper.assert_python_failure(fn) + rc, out, err = script_helper.assert_python_failure(fn, __isolated=False) self.assertRegex(err, b'File "dinsdale') def test_include_bad_file(self): @@ -321,7 +323,7 @@ class CommandLineTests(unittest.TestCase): '-i', os.path.join(self.directory, 'nosuchfile'), self.pkgdir) self.assertRegex(out, b'rror.*nosuchfile') self.assertNotRegex(err, b'Traceback') - self.assertFalse(os.path.exists(imp.cache_from_source( + self.assertFalse(os.path.exists(importlib.util.cache_from_source( self.pkgdir_cachedir))) def test_include_file_with_arg(self): @@ -378,13 +380,5 @@ class CommandLineTests(unittest.TestCase): self.assertRegex(out, b"Can't list 'badfilename'") -def test_main(): - support.run_unittest( - CommandLineTests, - CompileallTests, - EncodingTest, - ) - - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_complex.py b/Lib/test/test_complex.py index 1bf409798b..cd55375bdb 100644 --- a/Lib/test/test_complex.py +++ b/Lib/test/test_complex.py @@ -220,6 +220,8 @@ class ComplexTest(unittest.TestCase): self.assertRaises(TypeError, complex, OS(None)) self.assertRaises(TypeError, complex, NS(None)) self.assertRaises(TypeError, complex, {}) + self.assertRaises(TypeError, complex, NS(1.5)) + self.assertRaises(TypeError, complex, NS(1)) self.assertAlmostEqual(complex("1+10j"), 1+10j) self.assertAlmostEqual(complex(10), 10+0j) @@ -301,6 +303,7 @@ class ComplexTest(unittest.TestCase): self.assertRaises(TypeError, float, 5+3j) self.assertRaises(ValueError, complex, "") self.assertRaises(TypeError, complex, None) + self.assertRaisesRegex(TypeError, "not 'NoneType'", complex, None) self.assertRaises(ValueError, complex, "\0") self.assertRaises(ValueError, complex, "3\09") self.assertRaises(TypeError, complex, "1", "2") diff --git a/Lib/test/test_concurrent_futures.py b/Lib/test/test_concurrent_futures.py index 04c4c3717e..c74b2ca6ed 100644 --- a/Lib/test/test_concurrent_futures.py +++ b/Lib/test/test_concurrent_futures.py @@ -15,6 +15,7 @@ import sys import threading import time import unittest +import weakref from concurrent import futures from concurrent.futures._base import ( @@ -52,6 +53,11 @@ def sleep_and_print(t, msg): sys.stdout.flush() +class MyObject(object): + def my_method(self): + pass + + class ExecutorMixin: worker_count = 5 @@ -403,6 +409,22 @@ class ExecutorTest: self.executor.map(str, [2] * (self.worker_count + 1)) self.executor.shutdown() + @test.support.cpython_only + def test_no_stale_references(self): + # Issue #16284: check that the executors don't unnecessarily hang onto + # references. + my_object = MyObject() + my_object_collected = threading.Event() + my_object_callback = weakref.ref( + my_object, lambda obj: my_object_collected.set()) + # Deliberately discarding the future. + self.executor.submit(my_object.my_method) + del my_object + + collected = my_object_collected.wait(timeout=5.0) + self.assertTrue(collected, + "Stale reference not collected within timeout.") + class ThreadPoolExecutorTest(ThreadPoolMixin, ExecutorTest, unittest.TestCase): def test_map_submits_without_iteration(self): diff --git a/Lib/test/test_contextlib.py b/Lib/test/test_contextlib.py index e38d043a5b..39cc776dbc 100644 --- a/Lib/test/test_contextlib.py +++ b/Lib/test/test_contextlib.py @@ -1,5 +1,6 @@ """Unit tests for contextlib.py, and other context managers.""" +import io import sys import tempfile import unittest @@ -100,15 +101,25 @@ class ContextManagerTestCase(unittest.TestCase): self.assertEqual(baz.__name__,'baz') self.assertEqual(baz.foo, 'bar') - @unittest.skipIf(sys.flags.optimize >= 2, - "Docstrings are omitted with -O2 and above") + @support.requires_docstrings def test_contextmanager_doc_attrib(self): baz = self._create_contextmanager_attribs() self.assertEqual(baz.__doc__, "Whee!") + @support.requires_docstrings + def test_instance_docstring_given_cm_docstring(self): + baz = self._create_contextmanager_attribs()(None) + self.assertEqual(baz.__doc__, "Whee!") + + class ClosingTestCase(unittest.TestCase): - # XXX This needs more work + @support.requires_docstrings + def test_instance_docs(self): + # Issue 19330: ensure context manager instances have good docstrings + cm_docstring = closing.__doc__ + obj = closing(None) + self.assertEqual(obj.__doc__, cm_docstring) def test_closing(self): state = [] @@ -204,6 +215,7 @@ class LockContextTestCase(unittest.TestCase): class mycontext(ContextDecorator): + """Example decoration-compatible context manager for testing""" started = False exc = None catch = False @@ -219,6 +231,13 @@ class mycontext(ContextDecorator): class TestContextDecorator(unittest.TestCase): + @support.requires_docstrings + def test_instance_docs(self): + # Issue 19330: ensure context manager instances have good docstrings + cm_docstring = mycontext.__doc__ + obj = mycontext() + self.assertEqual(obj.__doc__, cm_docstring) + def test_contextdecorator(self): context = mycontext() with context as result: @@ -372,6 +391,13 @@ class TestContextDecorator(unittest.TestCase): class TestExitStack(unittest.TestCase): + @support.requires_docstrings + def test_instance_docs(self): + # Issue 19330: ensure context manager instances have good docstrings + cm_docstring = ExitStack.__doc__ + obj = ExitStack() + self.assertEqual(obj.__doc__, cm_docstring) + def test_no_resources(self): with ExitStack(): pass @@ -692,10 +718,111 @@ class TestExitStack(unittest.TestCase): stack.push(cm) self.assertIs(stack._exit_callbacks[-1], cm) +class TestRedirectStdout(unittest.TestCase): + + @support.requires_docstrings + def test_instance_docs(self): + # Issue 19330: ensure context manager instances have good docstrings + cm_docstring = redirect_stdout.__doc__ + obj = redirect_stdout(None) + self.assertEqual(obj.__doc__, cm_docstring) + + def test_no_redirect_in_init(self): + orig_stdout = sys.stdout + redirect_stdout(None) + self.assertIs(sys.stdout, orig_stdout) + + def test_redirect_to_string_io(self): + f = io.StringIO() + msg = "Consider an API like help(), which prints directly to stdout" + orig_stdout = sys.stdout + with redirect_stdout(f): + print(msg) + self.assertIs(sys.stdout, orig_stdout) + s = f.getvalue().strip() + self.assertEqual(s, msg) + + def test_enter_result_is_target(self): + f = io.StringIO() + with redirect_stdout(f) as enter_result: + self.assertIs(enter_result, f) + + def test_cm_is_reusable(self): + f = io.StringIO() + write_to_f = redirect_stdout(f) + orig_stdout = sys.stdout + with write_to_f: + print("Hello", end=" ") + with write_to_f: + print("World!") + self.assertIs(sys.stdout, orig_stdout) + s = f.getvalue() + self.assertEqual(s, "Hello World!\n") + + def test_cm_is_reentrant(self): + f = io.StringIO() + write_to_f = redirect_stdout(f) + orig_stdout = sys.stdout + with write_to_f: + print("Hello", end=" ") + with write_to_f: + print("World!") + self.assertIs(sys.stdout, orig_stdout) + s = f.getvalue() + self.assertEqual(s, "Hello World!\n") + + +class TestSuppress(unittest.TestCase): + + @support.requires_docstrings + def test_instance_docs(self): + # Issue 19330: ensure context manager instances have good docstrings + cm_docstring = suppress.__doc__ + obj = suppress() + self.assertEqual(obj.__doc__, cm_docstring) + + def test_no_result_from_enter(self): + with suppress(ValueError) as enter_result: + self.assertIsNone(enter_result) + + def test_no_exception(self): + with suppress(ValueError): + self.assertEqual(pow(2, 5), 32) + + def test_exact_exception(self): + with suppress(TypeError): + len(5) + + def test_exception_hierarchy(self): + with suppress(LookupError): + 'Hello'[50] + + def test_other_exception(self): + with self.assertRaises(ZeroDivisionError): + with suppress(TypeError): + 1/0 + + def test_no_args(self): + with self.assertRaises(ZeroDivisionError): + with suppress(): + 1/0 -# This is needed to make the test actually run under regrtest.py! -def test_main(): - support.run_unittest(__name__) + def test_multiple_exception_args(self): + with suppress(ZeroDivisionError, TypeError): + 1/0 + with suppress(ZeroDivisionError, TypeError): + len(5) + + def test_cm_is_reentrant(self): + ignore_exceptions = suppress(Exception) + with ignore_exceptions: + pass + with ignore_exceptions: + len(5) + with ignore_exceptions: + 1/0 + with ignore_exceptions: # Check nested usage + len(5) if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_cprofile.py b/Lib/test/test_cprofile.py index 56766682b3..ce5d27edd6 100644 --- a/Lib/test/test_cprofile.py +++ b/Lib/test/test_cprofile.py @@ -6,9 +6,11 @@ from test.support import run_unittest, TESTFN, unlink # rip off all interesting stuff from test_profile import cProfile from test.test_profile import ProfileTest, regenerate_expected_output +from test.profilee import testfunc class CProfileTest(ProfileTest): profilerclass = cProfile.Profile + profilermodule = cProfile expected_max_output = "{built-in method max}" def get_expected_output(self): @@ -27,6 +29,7 @@ class CProfileTest(ProfileTest): obj.enable() obj = _lsprof.Profiler(1) obj.disable() + obj.clear() finally: sys.stderr = orig_stderr finally: diff --git a/Lib/test/test_crypt.py b/Lib/test/test_crypt.py index cfb7341e20..624d702f99 100644 --- a/Lib/test/test_crypt.py +++ b/Lib/test/test_crypt.py @@ -1,11 +1,7 @@ from test import support import unittest -def setUpModule(): - # this import will raise unittest.SkipTest if _crypt doesn't exist, - # so it has to be done in setUpModule for test discovery to work - global crypt - crypt = support.import_module('crypt') +crypt = support.import_module('crypt') class CryptTestCase(unittest.TestCase): diff --git a/Lib/test/test_csv.py b/Lib/test/test_csv.py index 83f8cb3cb7..7e2485f954 100644 --- a/Lib/test/test_csv.py +++ b/Lib/test/test_csv.py @@ -136,12 +136,12 @@ class Test_Csv(unittest.TestCase): return 10; def __getitem__(self, i): if i > 2: - raise IOError - self.assertRaises(IOError, self._write_test, BadList(), '') + raise OSError + self.assertRaises(OSError, self._write_test, BadList(), '') class BadItem: def __str__(self): - raise IOError - self.assertRaises(IOError, self._write_test, [BadItem()], '') + raise OSError + self.assertRaises(OSError, self._write_test, [BadItem()], '') def test_write_bigfield(self): # This exercises the buffer realloc functionality @@ -186,9 +186,9 @@ class Test_Csv(unittest.TestCase): def test_writerows(self): class BrokenFile: def write(self, buf): - raise IOError + raise OSError writer = csv.writer(BrokenFile()) - self.assertRaises(IOError, writer.writerows, [['a']]) + self.assertRaises(OSError, writer.writerows, [['a']]) with TemporaryFile("w+", newline='') as fileobj: writer = csv.writer(fileobj) @@ -308,6 +308,15 @@ class Test_Csv(unittest.TestCase): for i, row in enumerate(csv.reader(fileobj)): self.assertEqual(row, rows[i]) + def test_roundtrip_escaped_unquoted_newlines(self): + with TemporaryFile("w+", newline='') as fileobj: + writer = csv.writer(fileobj,quoting=csv.QUOTE_NONE,escapechar="\\") + rows = [['a\nb','b'],['c','x\r\nd']] + writer.writerows(rows) + fileobj.seek(0) + for i, row in enumerate(csv.reader(fileobj,quoting=csv.QUOTE_NONE,escapechar="\\")): + self.assertEqual(row,rows[i]) + class TestDialectRegistry(unittest.TestCase): def test_registry_badargs(self): self.assertRaises(TypeError, csv.list_dialects, None) @@ -836,10 +845,11 @@ class TestDialectValidity(unittest.TestCase): d = mydialect() for field_name in ("delimiter", "escapechar", "quotechar"): - self.assertRaises(csv.Error, create_invalid, field_name, "") - self.assertRaises(csv.Error, create_invalid, field_name, "abc") - self.assertRaises(csv.Error, create_invalid, field_name, b'x') - self.assertRaises(csv.Error, create_invalid, field_name, 5) + with self.subTest(field_name=field_name): + self.assertRaises(csv.Error, create_invalid, field_name, "") + self.assertRaises(csv.Error, create_invalid, field_name, "abc") + self.assertRaises(csv.Error, create_invalid, field_name, b'x') + self.assertRaises(csv.Error, create_invalid, field_name, 5) class TestSniffer(unittest.TestCase): diff --git a/Lib/test/test_curses.py b/Lib/test/test_curses.py index 7310afc995..ce8f254bbb 100644 --- a/Lib/test/test_curses.py +++ b/Lib/test/test_curses.py @@ -115,8 +115,8 @@ def window_funcs(stdscr): stdscr.notimeout(1) win2.overlay(win) win2.overwrite(win) - win2.overlay(win, 1, 2, 3, 3, 2, 1) - win2.overwrite(win, 1, 2, 3, 3, 2, 1) + win2.overlay(win, 1, 2, 2, 1, 3, 3) + win2.overwrite(win, 1, 2, 2, 1, 3, 3) stdscr.redrawln(1,2) stdscr.scrollok(1) diff --git a/Lib/test/test_dbm.py b/Lib/test/test_dbm.py index d11d924b30..623d9929d5 100644 --- a/Lib/test/test_dbm.py +++ b/Lib/test/test_dbm.py @@ -61,7 +61,7 @@ class AnyDBMTestCase: return keys def test_error(self): - self.assertTrue(issubclass(self.module.error, IOError)) + self.assertTrue(issubclass(self.module.error, OSError)) def test_anydbm_not_existing(self): self.assertRaises(dbm.error, dbm.open, _fname) diff --git a/Lib/test/test_dbm_dumb.py b/Lib/test/test_dbm_dumb.py index 4f03ae7e58..7601e64638 100644 --- a/Lib/test/test_dbm_dumb.py +++ b/Lib/test/test_dbm_dumb.py @@ -183,6 +183,19 @@ class DumbDBMTestCase(unittest.TestCase): self.assertEqual(expected, got) f.close() + def test_context_manager(self): + with dumbdbm.open(_fname, 'c') as db: + db["dumbdbm context manager"] = "context manager" + + with dumbdbm.open(_fname, 'r') as db: + self.assertEqual(list(db.keys()), [b"dumbdbm context manager"]) + + # This currently just raises AttributeError rather than a specific + # exception like the GNU or NDBM based implementations. See + # http://bugs.python.org/issue19385 for details. + with self.assertRaises(Exception): + db.keys() + def tearDown(self): _delete_files() diff --git a/Lib/test/test_dbm_gnu.py b/Lib/test/test_dbm_gnu.py index 4fb66c54b8..a7808f51c7 100644 --- a/Lib/test/test_dbm_gnu.py +++ b/Lib/test/test_dbm_gnu.py @@ -81,6 +81,17 @@ class TestGdbm(unittest.TestCase): size2 = os.path.getsize(filename) self.assertTrue(size1 > size2 >= size0) + def test_context_manager(self): + with gdbm.open(filename, 'c') as db: + db["gdbm context manager"] = "context manager" + + with gdbm.open(filename, 'r') as db: + self.assertEqual(list(db.keys()), [b"gdbm context manager"]) + + with self.assertRaises(gdbm.error) as cm: + db.keys() + self.assertEqual(str(cm.exception), + "GDBM object has already been closed") if __name__ == '__main__': unittest.main() diff --git a/Lib/test/test_dbm_ndbm.py b/Lib/test/test_dbm_ndbm.py index b57e1f0653..2291561def 100644 --- a/Lib/test/test_dbm_ndbm.py +++ b/Lib/test/test_dbm_ndbm.py @@ -37,5 +37,18 @@ class DbmTestCase(unittest.TestCase): except error: self.fail() + def test_context_manager(self): + with dbm.ndbm.open(self.filename, 'c') as db: + db["ndbm context manager"] = "context manager" + + with dbm.ndbm.open(self.filename, 'r') as db: + self.assertEqual(list(db.keys()), [b"ndbm context manager"]) + + with self.assertRaises(dbm.ndbm.error) as cm: + db.keys() + self.assertEqual(str(cm.exception), + "DBM object has already been closed") + + if __name__ == '__main__': unittest.main() diff --git a/Lib/test/test_deque.py b/Lib/test/test_deque.py index a8487d2d95..7bff1d2798 100644 --- a/Lib/test/test_deque.py +++ b/Lib/test/test_deque.py @@ -543,8 +543,8 @@ class TestBasic(unittest.TestCase): check = self.check_sizeof check(deque(), basesize + blocksize) check(deque('a'), basesize + blocksize) - check(deque('a' * (BLOCKLEN // 2)), basesize + blocksize) - check(deque('a' * (BLOCKLEN // 2 + 1)), basesize + 2 * blocksize) + check(deque('a' * (BLOCKLEN - 1)), basesize + blocksize) + check(deque('a' * BLOCKLEN), basesize + 2 * blocksize) check(deque('a' * (42 * BLOCKLEN)), basesize + 43 * blocksize) class TestVariousIteratorArgs(unittest.TestCase): diff --git a/Lib/test/test_descr.py b/Lib/test/test_descr.py index 1a891a8063..2a9e329633 100644 --- a/Lib/test/test_descr.py +++ b/Lib/test/test_descr.py @@ -1,8 +1,11 @@ import builtins +import copyreg import gc +import itertools +import math +import pickle import sys import types -import math import unittest import weakref @@ -1797,6 +1800,7 @@ order (MRO) for bases """ ("__trunc__", int, zero, set(), {}), ("__ceil__", math.ceil, zero, set(), {}), ("__dir__", dir, empty_seq, set(), {}), + ("__round__", round, zero, set(), {}), ] class Checker(object): @@ -2258,7 +2262,9 @@ order (MRO) for bases """ minstance = M("m") minstance.b = 2 minstance.a = 1 - names = [x for x in dir(minstance) if x not in ["__name__", "__doc__"]] + default_attributes = ['__name__', '__doc__', '__package__', + '__loader__', '__spec__'] + names = [x for x in dir(minstance) if x not in default_attributes] self.assertEqual(names, ['a', 'b']) class M2(M): @@ -3151,176 +3157,6 @@ order (MRO) for bases """ self.assertEqual(e.a, 1) self.assertEqual(can_delete_dict(e), can_delete_dict(ValueError())) - def test_pickles(self): - # Testing pickling and copying new-style classes and objects... - import pickle - - def sorteditems(d): - L = list(d.items()) - L.sort() - return L - - global C - class C(object): - def __init__(self, a, b): - super(C, self).__init__() - self.a = a - self.b = b - def __repr__(self): - return "C(%r, %r)" % (self.a, self.b) - - global C1 - class C1(list): - def __new__(cls, a, b): - return super(C1, cls).__new__(cls) - def __getnewargs__(self): - return (self.a, self.b) - def __init__(self, a, b): - self.a = a - self.b = b - def __repr__(self): - return "C1(%r, %r)<%r>" % (self.a, self.b, list(self)) - - global C2 - class C2(int): - def __new__(cls, a, b, val=0): - return super(C2, cls).__new__(cls, val) - def __getnewargs__(self): - return (self.a, self.b, int(self)) - def __init__(self, a, b, val=0): - self.a = a - self.b = b - def __repr__(self): - return "C2(%r, %r)<%r>" % (self.a, self.b, int(self)) - - global C3 - class C3(object): - def __init__(self, foo): - self.foo = foo - def __getstate__(self): - return self.foo - def __setstate__(self, foo): - self.foo = foo - - global C4classic, C4 - class C4classic: # classic - pass - class C4(C4classic, object): # mixed inheritance - pass - - for bin in 0, 1: - for cls in C, C1, C2: - s = pickle.dumps(cls, bin) - cls2 = pickle.loads(s) - self.assertIs(cls2, cls) - - a = C1(1, 2); a.append(42); a.append(24) - b = C2("hello", "world", 42) - s = pickle.dumps((a, b), bin) - x, y = pickle.loads(s) - self.assertEqual(x.__class__, a.__class__) - self.assertEqual(sorteditems(x.__dict__), sorteditems(a.__dict__)) - self.assertEqual(y.__class__, b.__class__) - self.assertEqual(sorteditems(y.__dict__), sorteditems(b.__dict__)) - self.assertEqual(repr(x), repr(a)) - self.assertEqual(repr(y), repr(b)) - # Test for __getstate__ and __setstate__ on new style class - u = C3(42) - s = pickle.dumps(u, bin) - v = pickle.loads(s) - self.assertEqual(u.__class__, v.__class__) - self.assertEqual(u.foo, v.foo) - # Test for picklability of hybrid class - u = C4() - u.foo = 42 - s = pickle.dumps(u, bin) - v = pickle.loads(s) - self.assertEqual(u.__class__, v.__class__) - self.assertEqual(u.foo, v.foo) - - # Testing copy.deepcopy() - import copy - for cls in C, C1, C2: - cls2 = copy.deepcopy(cls) - self.assertIs(cls2, cls) - - a = C1(1, 2); a.append(42); a.append(24) - b = C2("hello", "world", 42) - x, y = copy.deepcopy((a, b)) - self.assertEqual(x.__class__, a.__class__) - self.assertEqual(sorteditems(x.__dict__), sorteditems(a.__dict__)) - self.assertEqual(y.__class__, b.__class__) - self.assertEqual(sorteditems(y.__dict__), sorteditems(b.__dict__)) - self.assertEqual(repr(x), repr(a)) - self.assertEqual(repr(y), repr(b)) - - def test_pickle_slots(self): - # Testing pickling of classes with __slots__ ... - import pickle - # Pickling of classes with __slots__ but without __getstate__ should fail - # (if using protocol 0 or 1) - global B, C, D, E - class B(object): - pass - for base in [object, B]: - class C(base): - __slots__ = ['a'] - class D(C): - pass - try: - pickle.dumps(C(), 0) - except TypeError: - pass - else: - self.fail("should fail: pickle C instance - %s" % base) - try: - pickle.dumps(C(), 0) - except TypeError: - pass - else: - self.fail("should fail: pickle D instance - %s" % base) - # Give C a nice generic __getstate__ and __setstate__ - class C(base): - __slots__ = ['a'] - def __getstate__(self): - try: - d = self.__dict__.copy() - except AttributeError: - d = {} - for cls in self.__class__.__mro__: - for sn in cls.__dict__.get('__slots__', ()): - try: - d[sn] = getattr(self, sn) - except AttributeError: - pass - return d - def __setstate__(self, d): - for k, v in list(d.items()): - setattr(self, k, v) - class D(C): - pass - # Now it should work - x = C() - y = pickle.loads(pickle.dumps(x)) - self.assertNotHasAttr(y, 'a') - x.a = 42 - y = pickle.loads(pickle.dumps(x)) - self.assertEqual(y.a, 42) - x = D() - x.a = 42 - x.b = 100 - y = pickle.loads(pickle.dumps(x)) - self.assertEqual(y.a + y.b, 142) - # A subclass that adds a slot should also work - class E(C): - __slots__ = ['b'] - x = E() - x.a = 42 - x.b = "foo" - y = pickle.loads(pickle.dumps(x)) - self.assertEqual(y.a, x.a) - self.assertEqual(y.b, x.b) - def test_binary_operator_override(self): # Testing overrides of binary operations... class I(int): @@ -3745,18 +3581,8 @@ order (MRO) for bases """ # bug). del c - # If that didn't blow up, it's also interesting to see whether clearing - # the last container slot works: that will attempt to delete c again, - # which will cause c to get appended back to the container again - # "during" the del. (On non-CPython implementations, however, __del__ - # is typically not called again.) support.gc_collect() self.assertEqual(len(C.container), 1) - del C.container[-1] - if support.check_impl_detail(): - support.gc_collect() - self.assertEqual(len(C.container), 1) - self.assertEqual(C.container[-1].attr, 42) # Make c mortal again, so that the test framework with -l doesn't report # it as a leak. @@ -4536,6 +4362,13 @@ order (MRO) for bases """ self.assertRaises(TypeError, type.__dict__['__qualname__'].__set__, str, 'Oink') + global Y + class Y: + class Inside: + pass + self.assertEqual(Y.__qualname__, 'Y') + self.assertEqual(Y.Inside.__qualname__, 'Y.Inside') + def test_qualname_dict(self): ns = {'__qualname__': 'some.name'} tp = type('Foo', (), ns) @@ -4691,11 +4524,453 @@ class MiscTests(unittest.TestCase): self.assertEqual(X.mykey2, 'from Base2') +class PicklingTests(unittest.TestCase): + + def _check_reduce(self, proto, obj, args=(), kwargs={}, state=None, + listitems=None, dictitems=None): + if proto >= 4: + reduce_value = obj.__reduce_ex__(proto) + self.assertEqual(reduce_value[:3], + (copyreg.__newobj_ex__, + (type(obj), args, kwargs), + state)) + if listitems is not None: + self.assertListEqual(list(reduce_value[3]), listitems) + else: + self.assertIsNone(reduce_value[3]) + if dictitems is not None: + self.assertDictEqual(dict(reduce_value[4]), dictitems) + else: + self.assertIsNone(reduce_value[4]) + elif proto >= 2: + reduce_value = obj.__reduce_ex__(proto) + self.assertEqual(reduce_value[:3], + (copyreg.__newobj__, + (type(obj),) + args, + state)) + if listitems is not None: + self.assertListEqual(list(reduce_value[3]), listitems) + else: + self.assertIsNone(reduce_value[3]) + if dictitems is not None: + self.assertDictEqual(dict(reduce_value[4]), dictitems) + else: + self.assertIsNone(reduce_value[4]) + else: + base_type = type(obj).__base__ + reduce_value = (copyreg._reconstructor, + (type(obj), + base_type, + None if base_type is object else base_type(obj))) + if state is not None: + reduce_value += (state,) + self.assertEqual(obj.__reduce_ex__(proto), reduce_value) + self.assertEqual(obj.__reduce__(), reduce_value) + + def test_reduce(self): + protocols = range(pickle.HIGHEST_PROTOCOL + 1) + args = (-101, "spam") + kwargs = {'bacon': -201, 'fish': -301} + state = {'cheese': -401} + + class C1: + def __getnewargs__(self): + return args + obj = C1() + for proto in protocols: + self._check_reduce(proto, obj, args) + + for name, value in state.items(): + setattr(obj, name, value) + for proto in protocols: + self._check_reduce(proto, obj, args, state=state) + + class C2: + def __getnewargs__(self): + return "bad args" + obj = C2() + for proto in protocols: + if proto >= 2: + with self.assertRaises(TypeError): + obj.__reduce_ex__(proto) + + class C3: + def __getnewargs_ex__(self): + return (args, kwargs) + obj = C3() + for proto in protocols: + if proto >= 4: + self._check_reduce(proto, obj, args, kwargs) + elif proto >= 2: + with self.assertRaises(ValueError): + obj.__reduce_ex__(proto) + + class C4: + def __getnewargs_ex__(self): + return (args, "bad dict") + class C5: + def __getnewargs_ex__(self): + return ("bad tuple", kwargs) + class C6: + def __getnewargs_ex__(self): + return () + class C7: + def __getnewargs_ex__(self): + return "bad args" + for proto in protocols: + for cls in C4, C5, C6, C7: + obj = cls() + if proto >= 2: + with self.assertRaises((TypeError, ValueError)): + obj.__reduce_ex__(proto) + + class C8: + def __getnewargs_ex__(self): + return (args, kwargs) + obj = C8() + for proto in protocols: + if 2 <= proto < 4: + with self.assertRaises(ValueError): + obj.__reduce_ex__(proto) + class C9: + def __getnewargs_ex__(self): + return (args, {}) + obj = C9() + for proto in protocols: + self._check_reduce(proto, obj, args) + + class C10: + def __getnewargs_ex__(self): + raise IndexError + obj = C10() + for proto in protocols: + if proto >= 2: + with self.assertRaises(IndexError): + obj.__reduce_ex__(proto) + + class C11: + def __getstate__(self): + return state + obj = C11() + for proto in protocols: + self._check_reduce(proto, obj, state=state) + + class C12: + def __getstate__(self): + return "not dict" + obj = C12() + for proto in protocols: + self._check_reduce(proto, obj, state="not dict") + + class C13: + def __getstate__(self): + raise IndexError + obj = C13() + for proto in protocols: + with self.assertRaises(IndexError): + obj.__reduce_ex__(proto) + if proto < 2: + with self.assertRaises(IndexError): + obj.__reduce__() + + class C14: + __slots__ = tuple(state) + def __init__(self): + for name, value in state.items(): + setattr(self, name, value) + + obj = C14() + for proto in protocols: + if proto >= 2: + self._check_reduce(proto, obj, state=(None, state)) + else: + with self.assertRaises(TypeError): + obj.__reduce_ex__(proto) + with self.assertRaises(TypeError): + obj.__reduce__() + + class C15(dict): + pass + obj = C15({"quebec": -601}) + for proto in protocols: + self._check_reduce(proto, obj, dictitems=dict(obj)) + + class C16(list): + pass + obj = C16(["yukon"]) + for proto in protocols: + self._check_reduce(proto, obj, listitems=list(obj)) + + def test_special_method_lookup(self): + protocols = range(pickle.HIGHEST_PROTOCOL + 1) + class Picky: + def __getstate__(self): + return {} + + def __getattr__(self, attr): + if attr in ("__getnewargs__", "__getnewargs_ex__"): + raise AssertionError(attr) + return None + for protocol in protocols: + state = {} if protocol >= 2 else None + self._check_reduce(protocol, Picky(), state=state) + + def _assert_is_copy(self, obj, objcopy, msg=None): + """Utility method to verify if two objects are copies of each others. + """ + if msg is None: + msg = "{!r} is not a copy of {!r}".format(obj, objcopy) + if type(obj).__repr__ is object.__repr__: + # We have this limitation for now because we use the object's repr + # to help us verify that the two objects are copies. This allows + # us to delegate the non-generic verification logic to the objects + # themselves. + raise ValueError("object passed to _assert_is_copy must " + + "override the __repr__ method.") + self.assertIsNot(obj, objcopy, msg=msg) + self.assertIs(type(obj), type(objcopy), msg=msg) + if hasattr(obj, '__dict__'): + self.assertDictEqual(obj.__dict__, objcopy.__dict__, msg=msg) + self.assertIsNot(obj.__dict__, objcopy.__dict__, msg=msg) + if hasattr(obj, '__slots__'): + self.assertListEqual(obj.__slots__, objcopy.__slots__, msg=msg) + for slot in obj.__slots__: + self.assertEqual( + hasattr(obj, slot), hasattr(objcopy, slot), msg=msg) + self.assertEqual(getattr(obj, slot, None), + getattr(objcopy, slot, None), msg=msg) + self.assertEqual(repr(obj), repr(objcopy), msg=msg) + + @staticmethod + def _generate_pickle_copiers(): + """Utility method to generate the many possible pickle configurations. + """ + class PickleCopier: + "This class copies object using pickle." + def __init__(self, proto, dumps, loads): + self.proto = proto + self.dumps = dumps + self.loads = loads + def copy(self, obj): + return self.loads(self.dumps(obj, self.proto)) + def __repr__(self): + # We try to be as descriptive as possible here since this is + # the string which we will allow us to tell the pickle + # configuration we are using during debugging. + return ("PickleCopier(proto={}, dumps={}.{}, loads={}.{})" + .format(self.proto, + self.dumps.__module__, self.dumps.__qualname__, + self.loads.__module__, self.loads.__qualname__)) + return (PickleCopier(*args) for args in + itertools.product(range(pickle.HIGHEST_PROTOCOL + 1), + {pickle.dumps, pickle._dumps}, + {pickle.loads, pickle._loads})) + + def test_pickle_slots(self): + # Tests pickling of classes with __slots__. + + # Pickling of classes with __slots__ but without __getstate__ should + # fail (if using protocol 0 or 1) + global C + class C: + __slots__ = ['a'] + with self.assertRaises(TypeError): + pickle.dumps(C(), 0) + + global D + class D(C): + pass + with self.assertRaises(TypeError): + pickle.dumps(D(), 0) + + class C: + "A class with __getstate__ and __setstate__ implemented." + __slots__ = ['a'] + def __getstate__(self): + state = getattr(self, '__dict__', {}).copy() + for cls in type(self).__mro__: + for slot in cls.__dict__.get('__slots__', ()): + try: + state[slot] = getattr(self, slot) + except AttributeError: + pass + return state + def __setstate__(self, state): + for k, v in state.items(): + setattr(self, k, v) + def __repr__(self): + return "%s()<%r>" % (type(self).__name__, self.__getstate__()) + + class D(C): + "A subclass of a class with slots." + pass + + global E + class E(C): + "A subclass with an extra slot." + __slots__ = ['b'] + + # Now it should work + for pickle_copier in self._generate_pickle_copiers(): + with self.subTest(pickle_copier=pickle_copier): + x = C() + y = pickle_copier.copy(x) + self._assert_is_copy(x, y) + + x.a = 42 + y = pickle_copier.copy(x) + self._assert_is_copy(x, y) + + x = D() + x.a = 42 + x.b = 100 + y = pickle_copier.copy(x) + self._assert_is_copy(x, y) + + x = E() + x.a = 42 + x.b = "foo" + y = pickle_copier.copy(x) + self._assert_is_copy(x, y) + + def test_reduce_copying(self): + # Tests pickling and copying new-style classes and objects. + global C1 + class C1: + "The state of this class is copyable via its instance dict." + ARGS = (1, 2) + NEED_DICT_COPYING = True + def __init__(self, a, b): + super().__init__() + self.a = a + self.b = b + def __repr__(self): + return "C1(%r, %r)" % (self.a, self.b) + + global C2 + class C2(list): + "A list subclass copyable via __getnewargs__." + ARGS = (1, 2) + NEED_DICT_COPYING = False + def __new__(cls, a, b): + self = super().__new__(cls) + self.a = a + self.b = b + return self + def __init__(self, *args): + super().__init__() + # This helps testing that __init__ is not called during the + # unpickling process, which would cause extra appends. + self.append("cheese") + @classmethod + def __getnewargs__(cls): + return cls.ARGS + def __repr__(self): + return "C2(%r, %r)<%r>" % (self.a, self.b, list(self)) + + global C3 + class C3(list): + "A list subclass copyable via __getstate__." + ARGS = (1, 2) + NEED_DICT_COPYING = False + def __init__(self, a, b): + self.a = a + self.b = b + # This helps testing that __init__ is not called during the + # unpickling process, which would cause extra appends. + self.append("cheese") + @classmethod + def __getstate__(cls): + return cls.ARGS + def __setstate__(self, state): + a, b = state + self.a = a + self.b = b + def __repr__(self): + return "C3(%r, %r)<%r>" % (self.a, self.b, list(self)) + + global C4 + class C4(int): + "An int subclass copyable via __getnewargs__." + ARGS = ("hello", "world", 1) + NEED_DICT_COPYING = False + def __new__(cls, a, b, value): + self = super().__new__(cls, value) + self.a = a + self.b = b + return self + @classmethod + def __getnewargs__(cls): + return cls.ARGS + def __repr__(self): + return "C4(%r, %r)<%r>" % (self.a, self.b, int(self)) + + global C5 + class C5(int): + "An int subclass copyable via __getnewargs_ex__." + ARGS = (1, 2) + KWARGS = {'value': 3} + NEED_DICT_COPYING = False + def __new__(cls, a, b, *, value=0): + self = super().__new__(cls, value) + self.a = a + self.b = b + return self + @classmethod + def __getnewargs_ex__(cls): + return (cls.ARGS, cls.KWARGS) + def __repr__(self): + return "C5(%r, %r)<%r>" % (self.a, self.b, int(self)) + + test_classes = (C1, C2, C3, C4, C5) + # Testing copying through pickle + pickle_copiers = self._generate_pickle_copiers() + for cls, pickle_copier in itertools.product(test_classes, pickle_copiers): + with self.subTest(cls=cls, pickle_copier=pickle_copier): + kwargs = getattr(cls, 'KWARGS', {}) + obj = cls(*cls.ARGS, **kwargs) + proto = pickle_copier.proto + if 2 <= proto < 4 and hasattr(cls, '__getnewargs_ex__'): + with self.assertRaises(ValueError): + pickle_copier.dumps(obj, proto) + continue + objcopy = pickle_copier.copy(obj) + self._assert_is_copy(obj, objcopy) + # For test classes that supports this, make sure we didn't go + # around the reduce protocol by simply copying the attribute + # dictionary. We clear attributes using the previous copy to + # not mutate the original argument. + if proto >= 2 and not cls.NEED_DICT_COPYING: + objcopy.__dict__.clear() + objcopy2 = pickle_copier.copy(objcopy) + self._assert_is_copy(obj, objcopy2) + + # Testing copying through copy.deepcopy() + for cls in test_classes: + with self.subTest(cls=cls): + kwargs = getattr(cls, 'KWARGS', {}) + obj = cls(*cls.ARGS, **kwargs) + # XXX: We need to modify the copy module to support PEP 3154's + # reduce protocol 4. + if hasattr(cls, '__getnewargs_ex__'): + continue + objcopy = deepcopy(obj) + self._assert_is_copy(obj, objcopy) + # For test classes that supports this, make sure we didn't go + # around the reduce protocol by simply copying the attribute + # dictionary. We clear attributes using the previous copy to + # not mutate the original argument. + if not cls.NEED_DICT_COPYING: + objcopy.__dict__.clear() + objcopy2 = deepcopy(objcopy) + self._assert_is_copy(obj, objcopy2) + + def test_main(): # Run all local test cases, with PTypesLongInitTest first. support.run_unittest(PTypesLongInitTest, OperatorsTest, ClassPropertiesAndMethods, DictProxyTests, - MiscTests) + MiscTests, PicklingTests) if __name__ == "__main__": test_main() diff --git a/Lib/test/test_devpoll.py b/Lib/test/test_devpoll.py index 9129ac0cb1..e634d3a373 100644 --- a/Lib/test/test_devpoll.py +++ b/Lib/test/test_devpoll.py @@ -2,13 +2,17 @@ # Initial tests are copied as is from "test_poll.py" -import os, select, random, unittest, sys +import os +import random +import select +import sys +import unittest from test.support import TESTFN, run_unittest, cpython_only try: select.devpoll except AttributeError: - raise unittest.SkipTest("select.devpoll not defined -- skipping test_devpoll") + raise unittest.SkipTest("select.devpoll not defined") def find_ready_matching(ready, flag): @@ -87,6 +91,35 @@ class DevPollTests(unittest.TestCase): self.assertRaises(OverflowError, pollster.poll, 1 << 63) self.assertRaises(OverflowError, pollster.poll, 1 << 64) + def test_close(self): + open_file = open(__file__, "rb") + self.addCleanup(open_file.close) + fd = open_file.fileno() + devpoll = select.devpoll() + + # test fileno() method and closed attribute + self.assertIsInstance(devpoll.fileno(), int) + self.assertFalse(devpoll.closed) + + # test close() + devpoll.close() + self.assertTrue(devpoll.closed) + self.assertRaises(ValueError, devpoll.fileno) + + # close() can be called more than once + devpoll.close() + + # operations must fail with ValueError("I/O operation on closed ...") + self.assertRaises(ValueError, devpoll.modify, fd, select.POLLIN) + self.assertRaises(ValueError, devpoll.poll) + self.assertRaises(ValueError, devpoll.register, fd, fd, select.POLLIN) + self.assertRaises(ValueError, devpoll.unregister, fd) + + def test_fd_non_inheritable(self): + devpoll = select.devpoll() + self.addCleanup(devpoll.close) + self.assertEqual(os.get_inheritable(devpoll.fileno()), False) + def test_events_mask_overflow(self): pollster = select.devpoll() w, r = os.pipe() diff --git a/Lib/test/test_dis.py b/Lib/test/test_dis.py index b86cc86856..d1229fba6d 100644 --- a/Lib/test/test_dis.py +++ b/Lib/test/test_dis.py @@ -1,12 +1,30 @@ # Minimal tests for dis module from test.support import run_unittest, captured_stdout +from test.bytecode_helper import BytecodeTestCase import difflib import unittest import sys import dis import io import re +import types +import contextlib + +def get_tb(): + def _error(): + try: + 1 / 0 + except Exception as e: + tb = e.__traceback__ + return tb + + tb = _error() + while tb.tb_next: + tb = tb.tb_next + return tb + +TRACEBACK_CODE = get_tb().tb_frame.f_code class _C: def __init__(self, x): @@ -23,12 +41,12 @@ dis_c_instance_method = """\ """ % (_C.__init__.__code__.co_firstlineno + 1,) dis_c_instance_method_bytes = """\ - 0 LOAD_FAST 1 (1) - 3 LOAD_CONST 1 (1) - 6 COMPARE_OP 2 (==) - 9 LOAD_FAST 0 (0) - 12 STORE_ATTR 0 (0) - 15 LOAD_CONST 0 (0) + 0 LOAD_FAST 1 (1) + 3 LOAD_CONST 1 (1) + 6 COMPARE_OP 2 (==) + 9 LOAD_FAST 0 (0) + 12 STORE_ATTR 0 (0) + 15 LOAD_CONST 0 (0) 18 RETURN_VALUE """ @@ -49,11 +67,11 @@ dis_f = """\ dis_f_co_code = """\ - 0 LOAD_GLOBAL 0 (0) - 3 LOAD_FAST 0 (0) - 6 CALL_FUNCTION 1 (1 positional, 0 keyword pair) + 0 LOAD_GLOBAL 0 (0) + 3 LOAD_FAST 0 (0) + 6 CALL_FUNCTION 1 (1 positional, 0 keyword pair) 9 POP_TOP - 10 LOAD_CONST 1 (1) + 10 LOAD_CONST 1 (1) 13 RETURN_VALUE """ @@ -89,7 +107,7 @@ def bug1333982(x=[]): pass dis_bug1333982 = """\ - %-4d 0 LOAD_CONST 1 (0) +%3d 0 LOAD_CONST 1 (0) 3 POP_JUMP_IF_TRUE 35 6 LOAD_GLOBAL 0 (AssertionError) 9 LOAD_CONST 2 (<code object <listcomp> at 0x..., file "%s", line %d>) @@ -99,12 +117,12 @@ dis_bug1333982 = """\ 21 GET_ITER 22 CALL_FUNCTION 1 (1 positional, 0 keyword pair) - %-4d 25 LOAD_CONST 4 (1) +%3d 25 LOAD_CONST 4 (1) 28 BINARY_ADD 29 CALL_FUNCTION 1 (1 positional, 0 keyword pair) 32 RAISE_VARARGS 1 - %-4d >> 35 LOAD_CONST 0 (None) +%3d >> 35 LOAD_CONST 0 (None) 38 RETURN_VALUE """ % (bug1333982.__code__.co_firstlineno + 1, __file__, @@ -171,38 +189,69 @@ dis_compound_stmt_str = """\ 25 RETURN_VALUE """ +dis_traceback = """\ + %-4d 0 SETUP_EXCEPT 12 (to 15) + + %-4d 3 LOAD_CONST 1 (1) + 6 LOAD_CONST 2 (0) + --> 9 BINARY_TRUE_DIVIDE + 10 POP_TOP + 11 POP_BLOCK + 12 JUMP_FORWARD 46 (to 61) + + %-4d >> 15 DUP_TOP + 16 LOAD_GLOBAL 0 (Exception) + 19 COMPARE_OP 10 (exception match) + 22 POP_JUMP_IF_FALSE 60 + 25 POP_TOP + 26 STORE_FAST 0 (e) + 29 POP_TOP + 30 SETUP_FINALLY 14 (to 47) + + %-4d 33 LOAD_FAST 0 (e) + 36 LOAD_ATTR 1 (__traceback__) + 39 STORE_FAST 1 (tb) + 42 POP_BLOCK + 43 POP_EXCEPT + 44 LOAD_CONST 0 (None) + >> 47 LOAD_CONST 0 (None) + 50 STORE_FAST 0 (e) + 53 DELETE_FAST 0 (e) + 56 END_FINALLY + 57 JUMP_FORWARD 1 (to 61) + >> 60 END_FINALLY + + %-4d >> 61 LOAD_FAST 1 (tb) + 64 RETURN_VALUE +""" % (TRACEBACK_CODE.co_firstlineno + 1, + TRACEBACK_CODE.co_firstlineno + 2, + TRACEBACK_CODE.co_firstlineno + 3, + TRACEBACK_CODE.co_firstlineno + 4, + TRACEBACK_CODE.co_firstlineno + 5) + class DisTests(unittest.TestCase): def get_disassembly(self, func, lasti=-1, wrapper=True): - s = io.StringIO() - save_stdout = sys.stdout - sys.stdout = s - try: + # We want to test the default printing behaviour, not the file arg + output = io.StringIO() + with contextlib.redirect_stdout(output): if wrapper: dis.dis(func) else: dis.disassemble(func, lasti) - finally: - sys.stdout = save_stdout - # Trim trailing blanks (if any). - return [line.rstrip() for line in s.getvalue().splitlines()] + return output.getvalue() def get_disassemble_as_string(self, func, lasti=-1): - return '\n'.join(self.get_disassembly(func, lasti, False)) + return self.get_disassembly(func, lasti, False) + + def strip_addresses(self, text): + return re.sub(r'\b0x[0-9A-Fa-f]+\b', '0x...', text) def do_disassembly_test(self, func, expected): - lines = self.get_disassembly(func) - expected = expected.splitlines() - if expected == lines: - return - else: - lines = [re.sub('0x[0-9A-Fa-f]+', '0x...', l) for l in lines] - if expected == lines: - return - self.fail( - "events did not match expectation:\n" + - "\n".join(difflib.ndiff(expected, - lines))) + got = self.get_disassembly(func) + if got != expected: + got = self.strip_addresses(got) + self.assertEqual(got, expected) def test_opmap(self): self.assertEqual(dis.opmap["NOP"], 9) @@ -290,35 +339,35 @@ class DisTests(unittest.TestCase): def test_dis_object(self): self.assertRaises(TypeError, dis.dis, object()) +class DisWithFileTests(DisTests): + + # Run the tests again, using the file arg instead of print + def get_disassembly(self, func, lasti=-1, wrapper=True): + output = io.StringIO() + if wrapper: + dis.dis(func, file=output) + else: + dis.disassemble(func, lasti, file=output) + return output.getvalue() + + + code_info_code_info = """\ Name: code_info Filename: (.*) Argument count: 1 Kw-only arguments: 0 Number of locals: 1 -Stack size: 4 +Stack size: 3 Flags: OPTIMIZED, NEWLOCALS, NOFREE Constants: 0: %r - 1: '__func__' - 2: '__code__' - 3: '<code_info>' - 4: 'co_code' - 5: "don't know how to disassemble %%s objects" -%sNames: - 0: hasattr - 1: __func__ - 2: __code__ - 3: isinstance - 4: str - 5: _try_compile - 6: _format_code_info - 7: TypeError - 8: type - 9: __name__ +Names: + 0: _format_code_info + 1: _get_code_object Variable names: - 0: x""" % (('Formatted details of methods, functions, or code.', ' 6: None\n') - if sys.flags.optimize < 2 else (None, '')) + 0: x""" % (('Formatted details of methods, functions, or code.',) + if sys.flags.optimize < 2 else (None,)) @staticmethod def tricky(x, y, z=True, *args, c, d, e=[], **kwds): @@ -382,7 +431,7 @@ Free variables: code_info_expr_str = """\ Name: <module> -Filename: <code_info> +Filename: <disassembly> Argument count: 0 Kw-only arguments: 0 Number of locals: 0 @@ -395,7 +444,7 @@ Names: code_info_simple_stmt_str = """\ Name: <module> -Filename: <code_info> +Filename: <disassembly> Argument count: 0 Kw-only arguments: 0 Number of locals: 0 @@ -409,7 +458,7 @@ Names: code_info_compound_stmt_str = """\ Name: <module> -Filename: <code_info> +Filename: <disassembly> Argument count: 0 Kw-only arguments: 0 Number of locals: 0 @@ -443,6 +492,9 @@ class CodeInfoTests(unittest.TestCase): with captured_stdout() as output: dis.show_code(x) self.assertRegex(output.getvalue(), expected+"\n") + output = io.StringIO() + dis.show_code(x, file=output) + self.assertRegex(output.getvalue(), expected) def test_code_info_object(self): self.assertRaises(TypeError, dis.code_info, object()) @@ -451,5 +503,330 @@ class CodeInfoTests(unittest.TestCase): self.assertEqual(dis.pretty_flags(0), '0x0') +# Fodder for instruction introspection tests +# Editing any of these may require recalculating the expected output +def outer(a=1, b=2): + def f(c=3, d=4): + def inner(e=5, f=6): + print(a, b, c, d, e, f) + print(a, b, c, d) + return inner + print(a, b, '', 1, [], {}, "Hello world!") + return f + +def jumpy(): + # This won't actually run (but that's OK, we only disassemble it) + for i in range(10): + print(i) + if i < 4: + continue + if i > 6: + break + else: + print("I can haz else clause?") + while i: + print(i) + i -= 1 + if i > 6: + continue + if i < 4: + break + else: + print("Who let lolcatz into this test suite?") + try: + 1 / 0 + except ZeroDivisionError: + print("Here we go, here we go, here we go...") + else: + with i as dodgy: + print("Never reach this") + finally: + print("OK, now we're done") + +# End fodder for opinfo generation tests +expected_outer_line = 1 +_line_offset = outer.__code__.co_firstlineno - 1 +code_object_f = outer.__code__.co_consts[3] +expected_f_line = code_object_f.co_firstlineno - _line_offset +code_object_inner = code_object_f.co_consts[3] +expected_inner_line = code_object_inner.co_firstlineno - _line_offset +expected_jumpy_line = 1 + +# The following lines are useful to regenerate the expected results after +# either the fodder is modified or the bytecode generation changes +# After regeneration, update the references to code_object_f and +# code_object_inner before rerunning the tests + +#_instructions = dis.get_instructions(outer, first_line=expected_outer_line) +#print('expected_opinfo_outer = [\n ', + #',\n '.join(map(str, _instructions)), ',\n]', sep='') +#_instructions = dis.get_instructions(outer(), first_line=expected_outer_line) +#print('expected_opinfo_f = [\n ', + #',\n '.join(map(str, _instructions)), ',\n]', sep='') +#_instructions = dis.get_instructions(outer()(), first_line=expected_outer_line) +#print('expected_opinfo_inner = [\n ', + #',\n '.join(map(str, _instructions)), ',\n]', sep='') +#_instructions = dis.get_instructions(jumpy, first_line=expected_jumpy_line) +#print('expected_opinfo_jumpy = [\n ', + #',\n '.join(map(str, _instructions)), ',\n]', sep='') + + +Instruction = dis.Instruction +expected_opinfo_outer = [ + Instruction(opname='LOAD_CONST', opcode=100, arg=1, argval=3, argrepr='3', offset=0, starts_line=2, is_jump_target=False), + Instruction(opname='LOAD_CONST', opcode=100, arg=2, argval=4, argrepr='4', offset=3, starts_line=None, is_jump_target=False), + Instruction(opname='LOAD_CLOSURE', opcode=135, arg=0, argval='a', argrepr='a', offset=6, starts_line=None, is_jump_target=False), + Instruction(opname='LOAD_CLOSURE', opcode=135, arg=1, argval='b', argrepr='b', offset=9, starts_line=None, is_jump_target=False), + Instruction(opname='BUILD_TUPLE', opcode=102, arg=2, argval=2, argrepr='', offset=12, starts_line=None, is_jump_target=False), + Instruction(opname='LOAD_CONST', opcode=100, arg=3, argval=code_object_f, argrepr=repr(code_object_f), offset=15, starts_line=None, is_jump_target=False), + Instruction(opname='LOAD_CONST', opcode=100, arg=4, argval='outer.<locals>.f', argrepr="'outer.<locals>.f'", offset=18, starts_line=None, is_jump_target=False), + Instruction(opname='MAKE_CLOSURE', opcode=134, arg=2, argval=2, argrepr='', offset=21, starts_line=None, is_jump_target=False), + Instruction(opname='STORE_FAST', opcode=125, arg=2, argval='f', argrepr='f', offset=24, starts_line=None, is_jump_target=False), + Instruction(opname='LOAD_GLOBAL', opcode=116, arg=0, argval='print', argrepr='print', offset=27, starts_line=7, is_jump_target=False), + Instruction(opname='LOAD_DEREF', opcode=136, arg=0, argval='a', argrepr='a', offset=30, starts_line=None, is_jump_target=False), + Instruction(opname='LOAD_DEREF', opcode=136, arg=1, argval='b', argrepr='b', offset=33, starts_line=None, is_jump_target=False), + Instruction(opname='LOAD_CONST', opcode=100, arg=5, argval='', argrepr="''", offset=36, starts_line=None, is_jump_target=False), + Instruction(opname='LOAD_CONST', opcode=100, arg=6, argval=1, argrepr='1', offset=39, starts_line=None, is_jump_target=False), + Instruction(opname='BUILD_LIST', opcode=103, arg=0, argval=0, argrepr='', offset=42, starts_line=None, is_jump_target=False), + Instruction(opname='BUILD_MAP', opcode=105, arg=0, argval=0, argrepr='', offset=45, starts_line=None, is_jump_target=False), + Instruction(opname='LOAD_CONST', opcode=100, arg=7, argval='Hello world!', argrepr="'Hello world!'", offset=48, starts_line=None, is_jump_target=False), + Instruction(opname='CALL_FUNCTION', opcode=131, arg=7, argval=7, argrepr='7 positional, 0 keyword pair', offset=51, starts_line=None, is_jump_target=False), + Instruction(opname='POP_TOP', opcode=1, arg=None, argval=None, argrepr='', offset=54, starts_line=None, is_jump_target=False), + Instruction(opname='LOAD_FAST', opcode=124, arg=2, argval='f', argrepr='f', offset=55, starts_line=8, is_jump_target=False), + Instruction(opname='RETURN_VALUE', opcode=83, arg=None, argval=None, argrepr='', offset=58, starts_line=None, is_jump_target=False), +] + +expected_opinfo_f = [ + Instruction(opname='LOAD_CONST', opcode=100, arg=1, argval=5, argrepr='5', offset=0, starts_line=3, is_jump_target=False), + Instruction(opname='LOAD_CONST', opcode=100, arg=2, argval=6, argrepr='6', offset=3, starts_line=None, is_jump_target=False), + Instruction(opname='LOAD_CLOSURE', opcode=135, arg=2, argval='a', argrepr='a', offset=6, starts_line=None, is_jump_target=False), + Instruction(opname='LOAD_CLOSURE', opcode=135, arg=3, argval='b', argrepr='b', offset=9, starts_line=None, is_jump_target=False), + Instruction(opname='LOAD_CLOSURE', opcode=135, arg=0, argval='c', argrepr='c', offset=12, starts_line=None, is_jump_target=False), + Instruction(opname='LOAD_CLOSURE', opcode=135, arg=1, argval='d', argrepr='d', offset=15, starts_line=None, is_jump_target=False), + Instruction(opname='BUILD_TUPLE', opcode=102, arg=4, argval=4, argrepr='', offset=18, starts_line=None, is_jump_target=False), + Instruction(opname='LOAD_CONST', opcode=100, arg=3, argval=code_object_inner, argrepr=repr(code_object_inner), offset=21, starts_line=None, is_jump_target=False), + Instruction(opname='LOAD_CONST', opcode=100, arg=4, argval='outer.<locals>.f.<locals>.inner', argrepr="'outer.<locals>.f.<locals>.inner'", offset=24, starts_line=None, is_jump_target=False), + Instruction(opname='MAKE_CLOSURE', opcode=134, arg=2, argval=2, argrepr='', offset=27, starts_line=None, is_jump_target=False), + Instruction(opname='STORE_FAST', opcode=125, arg=2, argval='inner', argrepr='inner', offset=30, starts_line=None, is_jump_target=False), + Instruction(opname='LOAD_GLOBAL', opcode=116, arg=0, argval='print', argrepr='print', offset=33, starts_line=5, is_jump_target=False), + Instruction(opname='LOAD_DEREF', opcode=136, arg=2, argval='a', argrepr='a', offset=36, starts_line=None, is_jump_target=False), + Instruction(opname='LOAD_DEREF', opcode=136, arg=3, argval='b', argrepr='b', offset=39, starts_line=None, is_jump_target=False), + Instruction(opname='LOAD_DEREF', opcode=136, arg=0, argval='c', argrepr='c', offset=42, starts_line=None, is_jump_target=False), + Instruction(opname='LOAD_DEREF', opcode=136, arg=1, argval='d', argrepr='d', offset=45, starts_line=None, is_jump_target=False), + Instruction(opname='CALL_FUNCTION', opcode=131, arg=4, argval=4, argrepr='4 positional, 0 keyword pair', offset=48, starts_line=None, is_jump_target=False), + Instruction(opname='POP_TOP', opcode=1, arg=None, argval=None, argrepr='', offset=51, starts_line=None, is_jump_target=False), + Instruction(opname='LOAD_FAST', opcode=124, arg=2, argval='inner', argrepr='inner', offset=52, starts_line=6, is_jump_target=False), + Instruction(opname='RETURN_VALUE', opcode=83, arg=None, argval=None, argrepr='', offset=55, starts_line=None, is_jump_target=False), +] + +expected_opinfo_inner = [ + Instruction(opname='LOAD_GLOBAL', opcode=116, arg=0, argval='print', argrepr='print', offset=0, starts_line=4, is_jump_target=False), + Instruction(opname='LOAD_DEREF', opcode=136, arg=0, argval='a', argrepr='a', offset=3, starts_line=None, is_jump_target=False), + Instruction(opname='LOAD_DEREF', opcode=136, arg=1, argval='b', argrepr='b', offset=6, starts_line=None, is_jump_target=False), + Instruction(opname='LOAD_DEREF', opcode=136, arg=2, argval='c', argrepr='c', offset=9, starts_line=None, is_jump_target=False), + Instruction(opname='LOAD_DEREF', opcode=136, arg=3, argval='d', argrepr='d', offset=12, starts_line=None, is_jump_target=False), + Instruction(opname='LOAD_FAST', opcode=124, arg=0, argval='e', argrepr='e', offset=15, starts_line=None, is_jump_target=False), + Instruction(opname='LOAD_FAST', opcode=124, arg=1, argval='f', argrepr='f', offset=18, starts_line=None, is_jump_target=False), + Instruction(opname='CALL_FUNCTION', opcode=131, arg=6, argval=6, argrepr='6 positional, 0 keyword pair', offset=21, starts_line=None, is_jump_target=False), + Instruction(opname='POP_TOP', opcode=1, arg=None, argval=None, argrepr='', offset=24, starts_line=None, is_jump_target=False), + Instruction(opname='LOAD_CONST', opcode=100, arg=0, argval=None, argrepr='None', offset=25, starts_line=None, is_jump_target=False), + Instruction(opname='RETURN_VALUE', opcode=83, arg=None, argval=None, argrepr='', offset=28, starts_line=None, is_jump_target=False), +] + +expected_opinfo_jumpy = [ + Instruction(opname='SETUP_LOOP', opcode=120, arg=74, argval=77, argrepr='to 77', offset=0, starts_line=3, is_jump_target=False), + Instruction(opname='LOAD_GLOBAL', opcode=116, arg=0, argval='range', argrepr='range', offset=3, starts_line=None, is_jump_target=False), + Instruction(opname='LOAD_CONST', opcode=100, arg=1, argval=10, argrepr='10', offset=6, starts_line=None, is_jump_target=False), + Instruction(opname='CALL_FUNCTION', opcode=131, arg=1, argval=1, argrepr='1 positional, 0 keyword pair', offset=9, starts_line=None, is_jump_target=False), + Instruction(opname='GET_ITER', opcode=68, arg=None, argval=None, argrepr='', offset=12, starts_line=None, is_jump_target=False), + Instruction(opname='FOR_ITER', opcode=93, arg=50, argval=66, argrepr='to 66', offset=13, starts_line=None, is_jump_target=True), + Instruction(opname='STORE_FAST', opcode=125, arg=0, argval='i', argrepr='i', offset=16, starts_line=None, is_jump_target=False), + Instruction(opname='LOAD_GLOBAL', opcode=116, arg=1, argval='print', argrepr='print', offset=19, starts_line=4, is_jump_target=False), + Instruction(opname='LOAD_FAST', opcode=124, arg=0, argval='i', argrepr='i', offset=22, starts_line=None, is_jump_target=False), + Instruction(opname='CALL_FUNCTION', opcode=131, arg=1, argval=1, argrepr='1 positional, 0 keyword pair', offset=25, starts_line=None, is_jump_target=False), + Instruction(opname='POP_TOP', opcode=1, arg=None, argval=None, argrepr='', offset=28, starts_line=None, is_jump_target=False), + Instruction(opname='LOAD_FAST', opcode=124, arg=0, argval='i', argrepr='i', offset=29, starts_line=5, is_jump_target=False), + Instruction(opname='LOAD_CONST', opcode=100, arg=2, argval=4, argrepr='4', offset=32, starts_line=None, is_jump_target=False), + Instruction(opname='COMPARE_OP', opcode=107, arg=0, argval='<', argrepr='<', offset=35, starts_line=None, is_jump_target=False), + Instruction(opname='POP_JUMP_IF_FALSE', opcode=114, arg=47, argval=47, argrepr='', offset=38, starts_line=None, is_jump_target=False), + Instruction(opname='JUMP_ABSOLUTE', opcode=113, arg=13, argval=13, argrepr='', offset=41, starts_line=6, is_jump_target=False), + Instruction(opname='JUMP_FORWARD', opcode=110, arg=0, argval=47, argrepr='to 47', offset=44, starts_line=None, is_jump_target=False), + Instruction(opname='LOAD_FAST', opcode=124, arg=0, argval='i', argrepr='i', offset=47, starts_line=7, is_jump_target=True), + Instruction(opname='LOAD_CONST', opcode=100, arg=3, argval=6, argrepr='6', offset=50, starts_line=None, is_jump_target=False), + Instruction(opname='COMPARE_OP', opcode=107, arg=4, argval='>', argrepr='>', offset=53, starts_line=None, is_jump_target=False), + Instruction(opname='POP_JUMP_IF_FALSE', opcode=114, arg=13, argval=13, argrepr='', offset=56, starts_line=None, is_jump_target=False), + Instruction(opname='BREAK_LOOP', opcode=80, arg=None, argval=None, argrepr='', offset=59, starts_line=8, is_jump_target=False), + Instruction(opname='JUMP_ABSOLUTE', opcode=113, arg=13, argval=13, argrepr='', offset=60, starts_line=None, is_jump_target=False), + Instruction(opname='JUMP_ABSOLUTE', opcode=113, arg=13, argval=13, argrepr='', offset=63, starts_line=None, is_jump_target=False), + Instruction(opname='POP_BLOCK', opcode=87, arg=None, argval=None, argrepr='', offset=66, starts_line=None, is_jump_target=True), + Instruction(opname='LOAD_GLOBAL', opcode=116, arg=1, argval='print', argrepr='print', offset=67, starts_line=10, is_jump_target=False), + Instruction(opname='LOAD_CONST', opcode=100, arg=4, argval='I can haz else clause?', argrepr="'I can haz else clause?'", offset=70, starts_line=None, is_jump_target=False), + Instruction(opname='CALL_FUNCTION', opcode=131, arg=1, argval=1, argrepr='1 positional, 0 keyword pair', offset=73, starts_line=None, is_jump_target=False), + Instruction(opname='POP_TOP', opcode=1, arg=None, argval=None, argrepr='', offset=76, starts_line=None, is_jump_target=False), + Instruction(opname='SETUP_LOOP', opcode=120, arg=74, argval=154, argrepr='to 154', offset=77, starts_line=11, is_jump_target=True), + Instruction(opname='LOAD_FAST', opcode=124, arg=0, argval='i', argrepr='i', offset=80, starts_line=None, is_jump_target=True), + Instruction(opname='POP_JUMP_IF_FALSE', opcode=114, arg=143, argval=143, argrepr='', offset=83, starts_line=None, is_jump_target=False), + Instruction(opname='LOAD_GLOBAL', opcode=116, arg=1, argval='print', argrepr='print', offset=86, starts_line=12, is_jump_target=False), + Instruction(opname='LOAD_FAST', opcode=124, arg=0, argval='i', argrepr='i', offset=89, starts_line=None, is_jump_target=False), + Instruction(opname='CALL_FUNCTION', opcode=131, arg=1, argval=1, argrepr='1 positional, 0 keyword pair', offset=92, starts_line=None, is_jump_target=False), + Instruction(opname='POP_TOP', opcode=1, arg=None, argval=None, argrepr='', offset=95, starts_line=None, is_jump_target=False), + Instruction(opname='LOAD_FAST', opcode=124, arg=0, argval='i', argrepr='i', offset=96, starts_line=13, is_jump_target=False), + Instruction(opname='LOAD_CONST', opcode=100, arg=5, argval=1, argrepr='1', offset=99, starts_line=None, is_jump_target=False), + Instruction(opname='INPLACE_SUBTRACT', opcode=56, arg=None, argval=None, argrepr='', offset=102, starts_line=None, is_jump_target=False), + Instruction(opname='STORE_FAST', opcode=125, arg=0, argval='i', argrepr='i', offset=103, starts_line=None, is_jump_target=False), + Instruction(opname='LOAD_FAST', opcode=124, arg=0, argval='i', argrepr='i', offset=106, starts_line=14, is_jump_target=False), + Instruction(opname='LOAD_CONST', opcode=100, arg=3, argval=6, argrepr='6', offset=109, starts_line=None, is_jump_target=False), + Instruction(opname='COMPARE_OP', opcode=107, arg=4, argval='>', argrepr='>', offset=112, starts_line=None, is_jump_target=False), + Instruction(opname='POP_JUMP_IF_FALSE', opcode=114, arg=124, argval=124, argrepr='', offset=115, starts_line=None, is_jump_target=False), + Instruction(opname='JUMP_ABSOLUTE', opcode=113, arg=80, argval=80, argrepr='', offset=118, starts_line=15, is_jump_target=False), + Instruction(opname='JUMP_FORWARD', opcode=110, arg=0, argval=124, argrepr='to 124', offset=121, starts_line=None, is_jump_target=False), + Instruction(opname='LOAD_FAST', opcode=124, arg=0, argval='i', argrepr='i', offset=124, starts_line=16, is_jump_target=True), + Instruction(opname='LOAD_CONST', opcode=100, arg=2, argval=4, argrepr='4', offset=127, starts_line=None, is_jump_target=False), + Instruction(opname='COMPARE_OP', opcode=107, arg=0, argval='<', argrepr='<', offset=130, starts_line=None, is_jump_target=False), + Instruction(opname='POP_JUMP_IF_FALSE', opcode=114, arg=80, argval=80, argrepr='', offset=133, starts_line=None, is_jump_target=False), + Instruction(opname='BREAK_LOOP', opcode=80, arg=None, argval=None, argrepr='', offset=136, starts_line=17, is_jump_target=False), + Instruction(opname='JUMP_ABSOLUTE', opcode=113, arg=80, argval=80, argrepr='', offset=137, starts_line=None, is_jump_target=False), + Instruction(opname='JUMP_ABSOLUTE', opcode=113, arg=80, argval=80, argrepr='', offset=140, starts_line=None, is_jump_target=False), + Instruction(opname='POP_BLOCK', opcode=87, arg=None, argval=None, argrepr='', offset=143, starts_line=None, is_jump_target=True), + Instruction(opname='LOAD_GLOBAL', opcode=116, arg=1, argval='print', argrepr='print', offset=144, starts_line=19, is_jump_target=False), + Instruction(opname='LOAD_CONST', opcode=100, arg=6, argval='Who let lolcatz into this test suite?', argrepr="'Who let lolcatz into this test suite?'", offset=147, starts_line=None, is_jump_target=False), + Instruction(opname='CALL_FUNCTION', opcode=131, arg=1, argval=1, argrepr='1 positional, 0 keyword pair', offset=150, starts_line=None, is_jump_target=False), + Instruction(opname='POP_TOP', opcode=1, arg=None, argval=None, argrepr='', offset=153, starts_line=None, is_jump_target=False), + Instruction(opname='SETUP_FINALLY', opcode=122, arg=72, argval=229, argrepr='to 229', offset=154, starts_line=20, is_jump_target=True), + Instruction(opname='SETUP_EXCEPT', opcode=121, arg=12, argval=172, argrepr='to 172', offset=157, starts_line=None, is_jump_target=False), + Instruction(opname='LOAD_CONST', opcode=100, arg=5, argval=1, argrepr='1', offset=160, starts_line=21, is_jump_target=False), + Instruction(opname='LOAD_CONST', opcode=100, arg=7, argval=0, argrepr='0', offset=163, starts_line=None, is_jump_target=False), + Instruction(opname='BINARY_TRUE_DIVIDE', opcode=27, arg=None, argval=None, argrepr='', offset=166, starts_line=None, is_jump_target=False), + Instruction(opname='POP_TOP', opcode=1, arg=None, argval=None, argrepr='', offset=167, starts_line=None, is_jump_target=False), + Instruction(opname='POP_BLOCK', opcode=87, arg=None, argval=None, argrepr='', offset=168, starts_line=None, is_jump_target=False), + Instruction(opname='JUMP_FORWARD', opcode=110, arg=28, argval=200, argrepr='to 200', offset=169, starts_line=None, is_jump_target=False), + Instruction(opname='DUP_TOP', opcode=4, arg=None, argval=None, argrepr='', offset=172, starts_line=22, is_jump_target=True), + Instruction(opname='LOAD_GLOBAL', opcode=116, arg=2, argval='ZeroDivisionError', argrepr='ZeroDivisionError', offset=173, starts_line=None, is_jump_target=False), + Instruction(opname='COMPARE_OP', opcode=107, arg=10, argval='exception match', argrepr='exception match', offset=176, starts_line=None, is_jump_target=False), + Instruction(opname='POP_JUMP_IF_FALSE', opcode=114, arg=199, argval=199, argrepr='', offset=179, starts_line=None, is_jump_target=False), + Instruction(opname='POP_TOP', opcode=1, arg=None, argval=None, argrepr='', offset=182, starts_line=None, is_jump_target=False), + Instruction(opname='POP_TOP', opcode=1, arg=None, argval=None, argrepr='', offset=183, starts_line=None, is_jump_target=False), + Instruction(opname='POP_TOP', opcode=1, arg=None, argval=None, argrepr='', offset=184, starts_line=None, is_jump_target=False), + Instruction(opname='LOAD_GLOBAL', opcode=116, arg=1, argval='print', argrepr='print', offset=185, starts_line=23, is_jump_target=False), + Instruction(opname='LOAD_CONST', opcode=100, arg=8, argval='Here we go, here we go, here we go...', argrepr="'Here we go, here we go, here we go...'", offset=188, starts_line=None, is_jump_target=False), + Instruction(opname='CALL_FUNCTION', opcode=131, arg=1, argval=1, argrepr='1 positional, 0 keyword pair', offset=191, starts_line=None, is_jump_target=False), + Instruction(opname='POP_TOP', opcode=1, arg=None, argval=None, argrepr='', offset=194, starts_line=None, is_jump_target=False), + Instruction(opname='POP_EXCEPT', opcode=89, arg=None, argval=None, argrepr='', offset=195, starts_line=None, is_jump_target=False), + Instruction(opname='JUMP_FORWARD', opcode=110, arg=26, argval=225, argrepr='to 225', offset=196, starts_line=None, is_jump_target=False), + Instruction(opname='END_FINALLY', opcode=88, arg=None, argval=None, argrepr='', offset=199, starts_line=None, is_jump_target=True), + Instruction(opname='LOAD_FAST', opcode=124, arg=0, argval='i', argrepr='i', offset=200, starts_line=25, is_jump_target=True), + Instruction(opname='SETUP_WITH', opcode=143, arg=17, argval=223, argrepr='to 223', offset=203, starts_line=None, is_jump_target=False), + Instruction(opname='STORE_FAST', opcode=125, arg=1, argval='dodgy', argrepr='dodgy', offset=206, starts_line=None, is_jump_target=False), + Instruction(opname='LOAD_GLOBAL', opcode=116, arg=1, argval='print', argrepr='print', offset=209, starts_line=26, is_jump_target=False), + Instruction(opname='LOAD_CONST', opcode=100, arg=9, argval='Never reach this', argrepr="'Never reach this'", offset=212, starts_line=None, is_jump_target=False), + Instruction(opname='CALL_FUNCTION', opcode=131, arg=1, argval=1, argrepr='1 positional, 0 keyword pair', offset=215, starts_line=None, is_jump_target=False), + Instruction(opname='POP_TOP', opcode=1, arg=None, argval=None, argrepr='', offset=218, starts_line=None, is_jump_target=False), + Instruction(opname='POP_BLOCK', opcode=87, arg=None, argval=None, argrepr='', offset=219, starts_line=None, is_jump_target=False), + Instruction(opname='LOAD_CONST', opcode=100, arg=0, argval=None, argrepr='None', offset=220, starts_line=None, is_jump_target=False), + Instruction(opname='WITH_CLEANUP', opcode=81, arg=None, argval=None, argrepr='', offset=223, starts_line=None, is_jump_target=True), + Instruction(opname='END_FINALLY', opcode=88, arg=None, argval=None, argrepr='', offset=224, starts_line=None, is_jump_target=False), + Instruction(opname='POP_BLOCK', opcode=87, arg=None, argval=None, argrepr='', offset=225, starts_line=None, is_jump_target=True), + Instruction(opname='LOAD_CONST', opcode=100, arg=0, argval=None, argrepr='None', offset=226, starts_line=None, is_jump_target=False), + Instruction(opname='LOAD_GLOBAL', opcode=116, arg=1, argval='print', argrepr='print', offset=229, starts_line=28, is_jump_target=True), + Instruction(opname='LOAD_CONST', opcode=100, arg=10, argval="OK, now we're done", argrepr='"OK, now we\'re done"', offset=232, starts_line=None, is_jump_target=False), + Instruction(opname='CALL_FUNCTION', opcode=131, arg=1, argval=1, argrepr='1 positional, 0 keyword pair', offset=235, starts_line=None, is_jump_target=False), + Instruction(opname='POP_TOP', opcode=1, arg=None, argval=None, argrepr='', offset=238, starts_line=None, is_jump_target=False), + Instruction(opname='END_FINALLY', opcode=88, arg=None, argval=None, argrepr='', offset=239, starts_line=None, is_jump_target=False), + Instruction(opname='LOAD_CONST', opcode=100, arg=0, argval=None, argrepr='None', offset=240, starts_line=None, is_jump_target=False), + Instruction(opname='RETURN_VALUE', opcode=83, arg=None, argval=None, argrepr='', offset=243, starts_line=None, is_jump_target=False), +] + +# One last piece of inspect fodder to check the default line number handling +def simple(): pass +expected_opinfo_simple = [ + Instruction(opname='LOAD_CONST', opcode=100, arg=0, argval=None, argrepr='None', offset=0, starts_line=simple.__code__.co_firstlineno, is_jump_target=False), + Instruction(opname='RETURN_VALUE', opcode=83, arg=None, argval=None, argrepr='', offset=3, starts_line=None, is_jump_target=False) +] + + +class InstructionTests(BytecodeTestCase): + + def test_default_first_line(self): + actual = dis.get_instructions(simple) + self.assertEqual(list(actual), expected_opinfo_simple) + + def test_first_line_set_to_None(self): + actual = dis.get_instructions(simple, first_line=None) + self.assertEqual(list(actual), expected_opinfo_simple) + + def test_outer(self): + actual = dis.get_instructions(outer, first_line=expected_outer_line) + self.assertEqual(list(actual), expected_opinfo_outer) + + def test_nested(self): + with captured_stdout(): + f = outer() + actual = dis.get_instructions(f, first_line=expected_f_line) + self.assertEqual(list(actual), expected_opinfo_f) + + def test_doubly_nested(self): + with captured_stdout(): + inner = outer()() + actual = dis.get_instructions(inner, first_line=expected_inner_line) + self.assertEqual(list(actual), expected_opinfo_inner) + + def test_jumpy(self): + actual = dis.get_instructions(jumpy, first_line=expected_jumpy_line) + self.assertEqual(list(actual), expected_opinfo_jumpy) + +# get_instructions has its own tests above, so can rely on it to validate +# the object oriented API +class BytecodeTests(unittest.TestCase): + def test_instantiation(self): + # Test with function, method, code string and code object + for obj in [_f, _C(1).__init__, "a=1", _f.__code__]: + with self.subTest(obj=obj): + b = dis.Bytecode(obj) + self.assertIsInstance(b.codeobj, types.CodeType) + + self.assertRaises(TypeError, dis.Bytecode, object()) + + def test_iteration(self): + for obj in [_f, _C(1).__init__, "a=1", _f.__code__]: + with self.subTest(obj=obj): + via_object = list(dis.Bytecode(obj)) + via_generator = list(dis.get_instructions(obj)) + self.assertEqual(via_object, via_generator) + + def test_explicit_first_line(self): + actual = dis.Bytecode(outer, first_line=expected_outer_line) + self.assertEqual(list(actual), expected_opinfo_outer) + + def test_source_line_in_disassembly(self): + # Use the line in the source code + actual = dis.Bytecode(simple).dis()[:3] + expected = "{:>3}".format(simple.__code__.co_firstlineno) + self.assertEqual(actual, expected) + # Use an explicit first line number + actual = dis.Bytecode(simple, first_line=350).dis()[:3] + self.assertEqual(actual, "350") + + def test_info(self): + self.maxDiff = 1000 + for x, expected in CodeInfoTests.test_pairs: + b = dis.Bytecode(x) + self.assertRegex(b.info(), expected) + + def test_disassembled(self): + actual = dis.Bytecode(_f).dis() + self.assertEqual(actual, dis_f) + + def test_from_traceback(self): + tb = get_tb() + b = dis.Bytecode.from_traceback(tb) + while tb.tb_next: tb = tb.tb_next + + self.assertEqual(b.current_offset, tb.tb_lasti) + + def test_from_traceback_dis(self): + tb = get_tb() + b = dis.Bytecode.from_traceback(tb) + self.assertEqual(b.dis(), dis_traceback) + if __name__ == "__main__": unittest.main() diff --git a/Lib/test/test_doctest.py b/Lib/test/test_doctest.py index 86259c3c56..56193e87b1 100644 --- a/Lib/test/test_doctest.py +++ b/Lib/test/test_doctest.py @@ -409,7 +409,8 @@ Compare `DocTestCase`: """ -def test_DocTestFinder(): r""" +class test_DocTestFinder: + def basics(): r""" Unit tests for the `DocTestFinder` class. DocTestFinder is used to extract DocTests from an object's docstring @@ -646,6 +647,39 @@ DocTestFinder finds the line number of each example: [1, 9, 12] """ + if int.__doc__: # simple check for --without-doc-strings, skip if lacking + def non_Python_modules(): r""" + +Finding Doctests in Modules Not Written in Python +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +DocTestFinder can also find doctests in most modules not written in Python. +We'll use builtins as an example, since it almost certainly isn't written in +plain ol' Python and is guaranteed to be available. + + >>> import builtins + >>> tests = doctest.DocTestFinder().find(builtins) + >>> 790 < len(tests) < 800 # approximate number of objects with docstrings + True + >>> real_tests = [t for t in tests if len(t.examples) > 0] + >>> len(real_tests) # objects that actually have doctests + 8 + >>> for t in real_tests: + ... print('{} {}'.format(len(t.examples), t.name)) + ... + 1 builtins.bin + 3 builtins.float.as_integer_ratio + 2 builtins.float.fromhex + 2 builtins.float.hex + 1 builtins.hex + 1 builtins.int + 2 builtins.int.bit_length + 1 builtins.oct + +Note here that 'bin', 'oct', and 'hex' are functions; 'float.as_integer_ratio', +'float.hex', and 'int.bit_length' are methods; 'float.fromhex' is a classmethod, +and 'int' is a type. +""" + def test_DocTestParser(): r""" Unit tests for the `DocTestParser` class. @@ -1436,8 +1470,40 @@ However, output from `report_start` is not suppressed: 2 TestResults(failed=3, attempted=5) -For the purposes of REPORT_ONLY_FIRST_FAILURE, unexpected exceptions -count as failures: +The FAIL_FAST flag causes the runner to exit after the first failing example, +so subsequent examples are not even attempted: + + >>> flags = doctest.FAIL_FAST + >>> doctest.DocTestRunner(verbose=False, optionflags=flags).run(test) + ... # doctest: +ELLIPSIS + ********************************************************************** + File ..., line 5, in f + Failed example: + print(2) # first failure + Expected: + 200 + Got: + 2 + TestResults(failed=1, attempted=2) + +Specifying both FAIL_FAST and REPORT_ONLY_FIRST_FAILURE is equivalent to +FAIL_FAST only: + + >>> flags = doctest.FAIL_FAST | doctest.REPORT_ONLY_FIRST_FAILURE + >>> doctest.DocTestRunner(verbose=False, optionflags=flags).run(test) + ... # doctest: +ELLIPSIS + ********************************************************************** + File ..., line 5, in f + Failed example: + print(2) # first failure + Expected: + 200 + Got: + 2 + TestResults(failed=1, attempted=2) + +For the purposes of both REPORT_ONLY_FIRST_FAILURE and FAIL_FAST, unexpected +exceptions count as failures: >>> def f(x): ... r''' @@ -1464,6 +1530,17 @@ count as failures: ... ValueError: 2 TestResults(failed=3, attempted=5) + >>> flags = doctest.FAIL_FAST + >>> doctest.DocTestRunner(verbose=False, optionflags=flags).run(test) + ... # doctest: +ELLIPSIS + ********************************************************************** + File ..., line 5, in f + Failed example: + raise ValueError(2) # first failure + Exception raised: + ... + ValueError: 2 + TestResults(failed=1, attempted=2) New option flags can also be registered, via register_optionflag(). Here we reach into doctest's internals a bit. @@ -2580,6 +2657,240 @@ Check doctest with a non-ascii filename: TestResults(failed=1, attempted=1) """ +def test_CLI(): r""" +The doctest module can be used to run doctests against an arbitrary file. +These tests test this CLI functionality. + +We'll use the support module's script_helpers for this, and write a test files +to a temp dir to run the command against. Due to a current limitation in +script_helpers, though, we need a little utility function to turn the returned +output into something we can doctest against: + + >>> def normalize(s): + ... return '\n'.join(s.decode().splitlines()) + +Note: we also pass TERM='' to all the assert_python calls to avoid a bug +in the readline library that is triggered in these tests because we are +running them in a new python process. See: + + http://lists.gnu.org/archive/html/bug-readline/2013-06/msg00000.html + +With those preliminaries out of the way, we'll start with a file with two +simple tests and no errors. We'll run both the unadorned doctest command, and +the verbose version, and then check the output: + + >>> from test import script_helper + >>> with script_helper.temp_dir() as tmpdir: + ... fn = os.path.join(tmpdir, 'myfile.doc') + ... with open(fn, 'w') as f: + ... _ = f.write('This is a very simple test file.\n') + ... _ = f.write(' >>> 1 + 1\n') + ... _ = f.write(' 2\n') + ... _ = f.write(' >>> "a"\n') + ... _ = f.write(" 'a'\n") + ... _ = f.write('\n') + ... _ = f.write('And that is it.\n') + ... rc1, out1, err1 = script_helper.assert_python_ok( + ... '-m', 'doctest', fn, TERM='') + ... rc2, out2, err2 = script_helper.assert_python_ok( + ... '-m', 'doctest', '-v', fn, TERM='') + +With no arguments and passing tests, we should get no output: + + >>> rc1, out1, err1 + (0, b'', b'') + +With the verbose flag, we should see the test output, but no error output: + + >>> rc2, err2 + (0, b'') + >>> print(normalize(out2)) + Trying: + 1 + 1 + Expecting: + 2 + ok + Trying: + "a" + Expecting: + 'a' + ok + 1 items passed all tests: + 2 tests in myfile.doc + 2 tests in 1 items. + 2 passed and 0 failed. + Test passed. + +Now we'll write a couple files, one with three tests, the other a python module +with two tests, both of the files having "errors" in the tests that can be made +non-errors by applying the appropriate doctest options to the run (ELLIPSIS in +the first file, NORMALIZE_WHITESPACE in the second). This combination will +allow to thoroughly test the -f and -o flags, as well as the doctest command's +ability to process more than one file on the command line and, since the second +file ends in '.py', its handling of python module files (as opposed to straight +text files). + + >>> from test import script_helper + >>> with script_helper.temp_dir() as tmpdir: + ... fn = os.path.join(tmpdir, 'myfile.doc') + ... with open(fn, 'w') as f: + ... _ = f.write('This is another simple test file.\n') + ... _ = f.write(' >>> 1 + 1\n') + ... _ = f.write(' 2\n') + ... _ = f.write(' >>> "abcdef"\n') + ... _ = f.write(" 'a...f'\n") + ... _ = f.write(' >>> "ajkml"\n') + ... _ = f.write(" 'a...l'\n") + ... _ = f.write('\n') + ... _ = f.write('And that is it.\n') + ... fn2 = os.path.join(tmpdir, 'myfile2.py') + ... with open(fn2, 'w') as f: + ... _ = f.write('def test_func():\n') + ... _ = f.write(' \"\"\"\n') + ... _ = f.write(' This is simple python test function.\n') + ... _ = f.write(' >>> 1 + 1\n') + ... _ = f.write(' 2\n') + ... _ = f.write(' >>> "abc def"\n') + ... _ = f.write(" 'abc def'\n") + ... _ = f.write("\n") + ... _ = f.write(' \"\"\"\n') + ... import shutil + ... rc1, out1, err1 = script_helper.assert_python_failure( + ... '-m', 'doctest', fn, fn2, TERM='') + ... rc2, out2, err2 = script_helper.assert_python_ok( + ... '-m', 'doctest', '-o', 'ELLIPSIS', fn, TERM='') + ... rc3, out3, err3 = script_helper.assert_python_ok( + ... '-m', 'doctest', '-o', 'ELLIPSIS', + ... '-o', 'NORMALIZE_WHITESPACE', fn, fn2, TERM='') + ... rc4, out4, err4 = script_helper.assert_python_failure( + ... '-m', 'doctest', '-f', fn, fn2, TERM='') + ... rc5, out5, err5 = script_helper.assert_python_ok( + ... '-m', 'doctest', '-v', '-o', 'ELLIPSIS', + ... '-o', 'NORMALIZE_WHITESPACE', fn, fn2, TERM='') + +Our first test run will show the errors from the first file (doctest stops if a +file has errors). Note that doctest test-run error output appears on stdout, +not stderr: + + >>> rc1, err1 + (1, b'') + >>> print(normalize(out1)) # doctest: +ELLIPSIS + ********************************************************************** + File "...myfile.doc", line 4, in myfile.doc + Failed example: + "abcdef" + Expected: + 'a...f' + Got: + 'abcdef' + ********************************************************************** + File "...myfile.doc", line 6, in myfile.doc + Failed example: + "ajkml" + Expected: + 'a...l' + Got: + 'ajkml' + ********************************************************************** + 1 items had failures: + 2 of 3 in myfile.doc + ***Test Failed*** 2 failures. + +With -o ELLIPSIS specified, the second run, against just the first file, should +produce no errors, and with -o NORMALIZE_WHITESPACE also specified, neither +should the third, which ran against both files: + + >>> rc2, out2, err2 + (0, b'', b'') + >>> rc3, out3, err3 + (0, b'', b'') + +The fourth run uses FAIL_FAST, so we should see only one error: + + >>> rc4, err4 + (1, b'') + >>> print(normalize(out4)) # doctest: +ELLIPSIS + ********************************************************************** + File "...myfile.doc", line 4, in myfile.doc + Failed example: + "abcdef" + Expected: + 'a...f' + Got: + 'abcdef' + ********************************************************************** + 1 items had failures: + 1 of 2 in myfile.doc + ***Test Failed*** 1 failures. + +The fifth test uses verbose with the two options, so we should get verbose +success output for the tests in both files: + + >>> rc5, err5 + (0, b'') + >>> print(normalize(out5)) + Trying: + 1 + 1 + Expecting: + 2 + ok + Trying: + "abcdef" + Expecting: + 'a...f' + ok + Trying: + "ajkml" + Expecting: + 'a...l' + ok + 1 items passed all tests: + 3 tests in myfile.doc + 3 tests in 1 items. + 3 passed and 0 failed. + Test passed. + Trying: + 1 + 1 + Expecting: + 2 + ok + Trying: + "abc def" + Expecting: + 'abc def' + ok + 1 items had no tests: + myfile2 + 1 items passed all tests: + 2 tests in myfile2.test_func + 2 tests in 2 items. + 2 passed and 0 failed. + Test passed. + +We should also check some typical error cases. + +Invalid file name: + + >>> rc, out, err = script_helper.assert_python_failure( + ... '-m', 'doctest', 'nosuchfile', TERM='') + >>> rc, out + (1, b'') + >>> print(normalize(err)) # doctest: +ELLIPSIS + Traceback (most recent call last): + ... + FileNotFoundError: [Errno ...] No such file or directory: 'nosuchfile' + +Invalid doctest option: + + >>> rc, out, err = script_helper.assert_python_failure( + ... '-m', 'doctest', '-o', 'nosuchoption', TERM='') + >>> rc, out + (2, b'') + >>> print(normalize(err)) # doctest: +ELLIPSIS + usage...invalid...nosuchoption... + +""" + ###################################################################### ## Main ###################################################################### diff --git a/Lib/test/test_dynamicclassattribute.py b/Lib/test/test_dynamicclassattribute.py new file mode 100644 index 0000000000..079d6c3224 --- /dev/null +++ b/Lib/test/test_dynamicclassattribute.py @@ -0,0 +1,304 @@ +# Test case for DynamicClassAttribute +# more tests are in test_descr + +import abc +import sys +import unittest +from test.support import run_unittest +from types import DynamicClassAttribute + +class PropertyBase(Exception): + pass + +class PropertyGet(PropertyBase): + pass + +class PropertySet(PropertyBase): + pass + +class PropertyDel(PropertyBase): + pass + +class BaseClass(object): + def __init__(self): + self._spam = 5 + + @DynamicClassAttribute + def spam(self): + """BaseClass.getter""" + return self._spam + + @spam.setter + def spam(self, value): + self._spam = value + + @spam.deleter + def spam(self): + del self._spam + +class SubClass(BaseClass): + + spam = BaseClass.__dict__['spam'] + + @spam.getter + def spam(self): + """SubClass.getter""" + raise PropertyGet(self._spam) + + @spam.setter + def spam(self, value): + raise PropertySet(self._spam) + + @spam.deleter + def spam(self): + raise PropertyDel(self._spam) + +class PropertyDocBase(object): + _spam = 1 + def _get_spam(self): + return self._spam + spam = DynamicClassAttribute(_get_spam, doc="spam spam spam") + +class PropertyDocSub(PropertyDocBase): + spam = PropertyDocBase.__dict__['spam'] + @spam.getter + def spam(self): + """The decorator does not use this doc string""" + return self._spam + +class PropertySubNewGetter(BaseClass): + spam = BaseClass.__dict__['spam'] + @spam.getter + def spam(self): + """new docstring""" + return 5 + +class PropertyNewGetter(object): + @DynamicClassAttribute + def spam(self): + """original docstring""" + return 1 + @spam.getter + def spam(self): + """new docstring""" + return 8 + +class ClassWithAbstractVirtualProperty(metaclass=abc.ABCMeta): + @DynamicClassAttribute + @abc.abstractmethod + def color(): + pass + +class ClassWithPropertyAbstractVirtual(metaclass=abc.ABCMeta): + @abc.abstractmethod + @DynamicClassAttribute + def color(): + pass + +class PropertyTests(unittest.TestCase): + def test_property_decorator_baseclass(self): + # see #1620 + base = BaseClass() + self.assertEqual(base.spam, 5) + self.assertEqual(base._spam, 5) + base.spam = 10 + self.assertEqual(base.spam, 10) + self.assertEqual(base._spam, 10) + delattr(base, "spam") + self.assertTrue(not hasattr(base, "spam")) + self.assertTrue(not hasattr(base, "_spam")) + base.spam = 20 + self.assertEqual(base.spam, 20) + self.assertEqual(base._spam, 20) + + def test_property_decorator_subclass(self): + # see #1620 + sub = SubClass() + self.assertRaises(PropertyGet, getattr, sub, "spam") + self.assertRaises(PropertySet, setattr, sub, "spam", None) + self.assertRaises(PropertyDel, delattr, sub, "spam") + + @unittest.skipIf(sys.flags.optimize >= 2, + "Docstrings are omitted with -O2 and above") + def test_property_decorator_subclass_doc(self): + sub = SubClass() + self.assertEqual(sub.__class__.__dict__['spam'].__doc__, "SubClass.getter") + + @unittest.skipIf(sys.flags.optimize >= 2, + "Docstrings are omitted with -O2 and above") + def test_property_decorator_baseclass_doc(self): + base = BaseClass() + self.assertEqual(base.__class__.__dict__['spam'].__doc__, "BaseClass.getter") + + def test_property_decorator_doc(self): + base = PropertyDocBase() + sub = PropertyDocSub() + self.assertEqual(base.__class__.__dict__['spam'].__doc__, "spam spam spam") + self.assertEqual(sub.__class__.__dict__['spam'].__doc__, "spam spam spam") + + @unittest.skipIf(sys.flags.optimize >= 2, + "Docstrings are omitted with -O2 and above") + def test_property_getter_doc_override(self): + newgettersub = PropertySubNewGetter() + self.assertEqual(newgettersub.spam, 5) + self.assertEqual(newgettersub.__class__.__dict__['spam'].__doc__, "new docstring") + newgetter = PropertyNewGetter() + self.assertEqual(newgetter.spam, 8) + self.assertEqual(newgetter.__class__.__dict__['spam'].__doc__, "new docstring") + + def test_property___isabstractmethod__descriptor(self): + for val in (True, False, [], [1], '', '1'): + class C(object): + def foo(self): + pass + foo.__isabstractmethod__ = val + foo = DynamicClassAttribute(foo) + self.assertIs(C.__dict__['foo'].__isabstractmethod__, bool(val)) + + # check that the DynamicClassAttribute's __isabstractmethod__ descriptor does the + # right thing when presented with a value that fails truth testing: + class NotBool(object): + def __nonzero__(self): + raise ValueError() + __len__ = __nonzero__ + with self.assertRaises(ValueError): + class C(object): + def foo(self): + pass + foo.__isabstractmethod__ = NotBool() + foo = DynamicClassAttribute(foo) + + def test_abstract_virtual(self): + self.assertRaises(TypeError, ClassWithAbstractVirtualProperty) + self.assertRaises(TypeError, ClassWithPropertyAbstractVirtual) + class APV(ClassWithPropertyAbstractVirtual): + pass + self.assertRaises(TypeError, APV) + class AVP(ClassWithAbstractVirtualProperty): + pass + self.assertRaises(TypeError, AVP) + class Okay1(ClassWithAbstractVirtualProperty): + @DynamicClassAttribute + def color(self): + return self._color + def __init__(self): + self._color = 'cyan' + with self.assertRaises(AttributeError): + Okay1.color + self.assertEqual(Okay1().color, 'cyan') + class Okay2(ClassWithAbstractVirtualProperty): + @DynamicClassAttribute + def color(self): + return self._color + def __init__(self): + self._color = 'magenta' + with self.assertRaises(AttributeError): + Okay2.color + self.assertEqual(Okay2().color, 'magenta') + + +# Issue 5890: subclasses of DynamicClassAttribute do not preserve method __doc__ strings +class PropertySub(DynamicClassAttribute): + """This is a subclass of DynamicClassAttribute""" + +class PropertySubSlots(DynamicClassAttribute): + """This is a subclass of DynamicClassAttribute that defines __slots__""" + __slots__ = () + +class PropertySubclassTests(unittest.TestCase): + + @unittest.skipIf(hasattr(PropertySubSlots, '__doc__'), + "__doc__ is already present, __slots__ will have no effect") + def test_slots_docstring_copy_exception(self): + try: + class Foo(object): + @PropertySubSlots + def spam(self): + """Trying to copy this docstring will raise an exception""" + return 1 + print('\n',spam.__doc__) + except AttributeError: + pass + else: + raise Exception("AttributeError not raised") + + @unittest.skipIf(sys.flags.optimize >= 2, + "Docstrings are omitted with -O2 and above") + def test_docstring_copy(self): + class Foo(object): + @PropertySub + def spam(self): + """spam wrapped in DynamicClassAttribute subclass""" + return 1 + self.assertEqual( + Foo.__dict__['spam'].__doc__, + "spam wrapped in DynamicClassAttribute subclass") + + @unittest.skipIf(sys.flags.optimize >= 2, + "Docstrings are omitted with -O2 and above") + def test_property_setter_copies_getter_docstring(self): + class Foo(object): + def __init__(self): self._spam = 1 + @PropertySub + def spam(self): + """spam wrapped in DynamicClassAttribute subclass""" + return self._spam + @spam.setter + def spam(self, value): + """this docstring is ignored""" + self._spam = value + foo = Foo() + self.assertEqual(foo.spam, 1) + foo.spam = 2 + self.assertEqual(foo.spam, 2) + self.assertEqual( + Foo.__dict__['spam'].__doc__, + "spam wrapped in DynamicClassAttribute subclass") + class FooSub(Foo): + spam = Foo.__dict__['spam'] + @spam.setter + def spam(self, value): + """another ignored docstring""" + self._spam = 'eggs' + foosub = FooSub() + self.assertEqual(foosub.spam, 1) + foosub.spam = 7 + self.assertEqual(foosub.spam, 'eggs') + self.assertEqual( + FooSub.__dict__['spam'].__doc__, + "spam wrapped in DynamicClassAttribute subclass") + + @unittest.skipIf(sys.flags.optimize >= 2, + "Docstrings are omitted with -O2 and above") + def test_property_new_getter_new_docstring(self): + + class Foo(object): + @PropertySub + def spam(self): + """a docstring""" + return 1 + @spam.getter + def spam(self): + """a new docstring""" + return 2 + self.assertEqual(Foo.__dict__['spam'].__doc__, "a new docstring") + class FooBase(object): + @PropertySub + def spam(self): + """a docstring""" + return 1 + class Foo2(FooBase): + spam = FooBase.__dict__['spam'] + @spam.getter + def spam(self): + """a new docstring""" + return 2 + self.assertEqual(Foo.__dict__['spam'].__doc__, "a new docstring") + + + +def test_main(): + run_unittest(PropertyTests, PropertySubclassTests) + +if __name__ == '__main__': + test_main() diff --git a/Lib/test/test_email/__init__.py b/Lib/test/test_email/__init__.py index f206ace10c..d8896eed46 100644 --- a/Lib/test/test_email/__init__.py +++ b/Lib/test/test_email/__init__.py @@ -2,6 +2,7 @@ import os import sys import unittest import test.support +import collections import email from email.message import Message from email._policybase import compat32 @@ -42,6 +43,8 @@ class TestEmailBase(unittest.TestCase): # here we make minimal changes in the test_email tests compared to their # pre-3.3 state. policy = compat32 + # Likewise, the default message object is Message. + message = Message def __init__(self, *args, **kw): super().__init__(*args, **kw) @@ -54,11 +57,23 @@ class TestEmailBase(unittest.TestCase): with openfile(filename) as fp: return email.message_from_file(fp, policy=self.policy) - def _str_msg(self, string, message=Message, policy=None): + def _str_msg(self, string, message=None, policy=None): if policy is None: policy = self.policy + if message is None: + message = self.message return email.message_from_string(string, message, policy=policy) + def _bytes_msg(self, bytestring, message=None, policy=None): + if policy is None: + policy = self.policy + if message is None: + message = self.message + return email.message_from_bytes(bytestring, message, policy=policy) + + def _make_message(self): + return self.message(policy=self.policy) + def _bytes_repr(self, b): return [repr(x) for x in b.splitlines(keepends=True)] @@ -123,6 +138,7 @@ def parameterize(cls): """ paramdicts = {} + testers = collections.defaultdict(list) for name, attr in cls.__dict__.items(): if name.endswith('_params'): if not hasattr(attr, 'keys'): @@ -134,7 +150,15 @@ def parameterize(cls): d[n] = x attr = d paramdicts[name[:-7] + '_as_'] = attr + if '_as_' in name: + testers[name.split('_as_')[0] + '_as_'].append(name) testfuncs = {} + for name in paramdicts: + if name not in testers: + raise ValueError("No tester found for {}".format(name)) + for name in testers: + if name not in paramdicts: + raise ValueError("No params found for {}".format(name)) for name, attr in cls.__dict__.items(): for paramsname, paramsdict in paramdicts.items(): if name.startswith(paramsname): diff --git a/Lib/test/test_email/test_contentmanager.py b/Lib/test/test_email/test_contentmanager.py new file mode 100644 index 0000000000..cdb04e45ed --- /dev/null +++ b/Lib/test/test_email/test_contentmanager.py @@ -0,0 +1,796 @@ +import unittest +from test.test_email import TestEmailBase, parameterize +import textwrap +from email import policy +from email.message import EmailMessage +from email.contentmanager import ContentManager, raw_data_manager + + +@parameterize +class TestContentManager(TestEmailBase): + + policy = policy.default + message = EmailMessage + + get_key_params = { + 'full_type': (1, 'text/plain',), + 'maintype_only': (2, 'text',), + 'null_key': (3, '',), + } + + def get_key_as_get_content_key(self, order, key): + def foo_getter(msg, foo=None): + bar = msg['X-Bar-Header'] + return foo, bar + cm = ContentManager() + cm.add_get_handler(key, foo_getter) + m = self._make_message() + m['Content-Type'] = 'text/plain' + m['X-Bar-Header'] = 'foo' + self.assertEqual(cm.get_content(m, foo='bar'), ('bar', 'foo')) + + def get_key_as_get_content_key_order(self, order, key): + def bar_getter(msg): + return msg['X-Bar-Header'] + def foo_getter(msg): + return msg['X-Foo-Header'] + cm = ContentManager() + cm.add_get_handler(key, foo_getter) + for precedence, key in self.get_key_params.values(): + if precedence > order: + cm.add_get_handler(key, bar_getter) + m = self._make_message() + m['Content-Type'] = 'text/plain' + m['X-Bar-Header'] = 'bar' + m['X-Foo-Header'] = 'foo' + self.assertEqual(cm.get_content(m), ('foo')) + + def test_get_content_raises_if_unknown_mimetype_and_no_default(self): + cm = ContentManager() + m = self._make_message() + m['Content-Type'] = 'text/plain' + with self.assertRaisesRegex(KeyError, 'text/plain'): + cm.get_content(m) + + class BaseThing(str): + pass + baseobject_full_path = __name__ + '.' + 'TestContentManager.BaseThing' + class Thing(BaseThing): + pass + testobject_full_path = __name__ + '.' + 'TestContentManager.Thing' + + set_key_params = { + 'type': (0, Thing,), + 'full_path': (1, testobject_full_path,), + 'qualname': (2, 'TestContentManager.Thing',), + 'name': (3, 'Thing',), + 'base_type': (4, BaseThing,), + 'base_full_path': (5, baseobject_full_path,), + 'base_qualname': (6, 'TestContentManager.BaseThing',), + 'base_name': (7, 'BaseThing',), + 'str_type': (8, str,), + 'str_full_path': (9, 'builtins.str',), + 'str_name': (10, 'str',), # str name and qualname are the same + 'null_key': (11, None,), + } + + def set_key_as_set_content_key(self, order, key): + def foo_setter(msg, obj, foo=None): + msg['X-Foo-Header'] = foo + msg.set_payload(obj) + cm = ContentManager() + cm.add_set_handler(key, foo_setter) + m = self._make_message() + msg_obj = self.Thing() + cm.set_content(m, msg_obj, foo='bar') + self.assertEqual(m['X-Foo-Header'], 'bar') + self.assertEqual(m.get_payload(), msg_obj) + + def set_key_as_set_content_key_order(self, order, key): + def foo_setter(msg, obj): + msg['X-FooBar-Header'] = 'foo' + msg.set_payload(obj) + def bar_setter(msg, obj): + msg['X-FooBar-Header'] = 'bar' + cm = ContentManager() + cm.add_set_handler(key, foo_setter) + for precedence, key in self.get_key_params.values(): + if precedence > order: + cm.add_set_handler(key, bar_setter) + m = self._make_message() + msg_obj = self.Thing() + cm.set_content(m, msg_obj) + self.assertEqual(m['X-FooBar-Header'], 'foo') + self.assertEqual(m.get_payload(), msg_obj) + + def test_set_content_raises_if_unknown_type_and_no_default(self): + cm = ContentManager() + m = self._make_message() + msg_obj = self.Thing() + with self.assertRaisesRegex(KeyError, self.testobject_full_path): + cm.set_content(m, msg_obj) + + def test_set_content_raises_if_called_on_multipart(self): + cm = ContentManager() + m = self._make_message() + m['Content-Type'] = 'multipart/foo' + with self.assertRaises(TypeError): + cm.set_content(m, 'test') + + def test_set_content_calls_clear_content(self): + m = self._make_message() + m['Content-Foo'] = 'bar' + m['Content-Type'] = 'text/html' + m['To'] = 'test' + m.set_payload('abc') + cm = ContentManager() + cm.add_set_handler(str, lambda *args, **kw: None) + m.set_content('xyz', content_manager=cm) + self.assertIsNone(m['Content-Foo']) + self.assertIsNone(m['Content-Type']) + self.assertEqual(m['To'], 'test') + self.assertIsNone(m.get_payload()) + + +@parameterize +class TestRawDataManager(TestEmailBase): + # Note: these tests are dependent on the order in which headers are added + # to the message objects by the code. There's no defined ordering in + # RFC5322/MIME, so this makes the tests more fragile than the standards + # require. However, if the header order changes it is best to understand + # *why*, and make sure it isn't a subtle bug in whatever change was + # applied. + + policy = policy.default.clone(max_line_length=60, + content_manager=raw_data_manager) + message = EmailMessage + + def test_get_text_plain(self): + m = self._str_msg(textwrap.dedent("""\ + Content-Type: text/plain + + Basic text. + """)) + self.assertEqual(raw_data_manager.get_content(m), "Basic text.\n") + + def test_get_text_html(self): + m = self._str_msg(textwrap.dedent("""\ + Content-Type: text/html + + <p>Basic text.</p> + """)) + self.assertEqual(raw_data_manager.get_content(m), + "<p>Basic text.</p>\n") + + def test_get_text_plain_latin1(self): + m = self._bytes_msg(textwrap.dedent("""\ + Content-Type: text/plain; charset=latin1 + + Basìc tëxt. + """).encode('latin1')) + self.assertEqual(raw_data_manager.get_content(m), "Basìc tëxt.\n") + + def test_get_text_plain_latin1_quoted_printable(self): + m = self._str_msg(textwrap.dedent("""\ + Content-Type: text/plain; charset="latin-1" + Content-Transfer-Encoding: quoted-printable + + Bas=ECc t=EBxt. + """)) + self.assertEqual(raw_data_manager.get_content(m), "Basìc tëxt.\n") + + def test_get_text_plain_utf8_base64(self): + m = self._str_msg(textwrap.dedent("""\ + Content-Type: text/plain; charset="utf8" + Content-Transfer-Encoding: base64 + + QmFzw6xjIHTDq3h0Lgo= + """)) + self.assertEqual(raw_data_manager.get_content(m), "Basìc tëxt.\n") + + def test_get_text_plain_bad_utf8_quoted_printable(self): + m = self._str_msg(textwrap.dedent("""\ + Content-Type: text/plain; charset="utf8" + Content-Transfer-Encoding: quoted-printable + + Bas=c3=acc t=c3=abxt=fd. + """)) + self.assertEqual(raw_data_manager.get_content(m), "Basìc tëxt�.\n") + + def test_get_text_plain_bad_utf8_quoted_printable_ignore_errors(self): + m = self._str_msg(textwrap.dedent("""\ + Content-Type: text/plain; charset="utf8" + Content-Transfer-Encoding: quoted-printable + + Bas=c3=acc t=c3=abxt=fd. + """)) + self.assertEqual(raw_data_manager.get_content(m, errors='ignore'), + "Basìc tëxt.\n") + + def test_get_text_plain_utf8_base64_recoverable_bad_CTE_data(self): + m = self._str_msg(textwrap.dedent("""\ + Content-Type: text/plain; charset="utf8" + Content-Transfer-Encoding: base64 + + QmFzw6xjIHTDq3h0Lgo\xFF= + """)) + self.assertEqual(raw_data_manager.get_content(m, errors='ignore'), + "Basìc tëxt.\n") + + def test_get_text_invalid_keyword(self): + m = self._str_msg(textwrap.dedent("""\ + Content-Type: text/plain + + Basic text. + """)) + with self.assertRaises(TypeError): + raw_data_manager.get_content(m, foo='ignore') + + def test_get_non_text(self): + template = textwrap.dedent("""\ + Content-Type: {} + Content-Transfer-Encoding: base64 + + Ym9ndXMgZGF0YQ== + """) + for maintype in 'audio image video application'.split(): + with self.subTest(maintype=maintype): + m = self._str_msg(template.format(maintype+'/foo')) + self.assertEqual(raw_data_manager.get_content(m), b"bogus data") + + def test_get_non_text_invalid_keyword(self): + m = self._str_msg(textwrap.dedent("""\ + Content-Type: image/jpg + Content-Transfer-Encoding: base64 + + Ym9ndXMgZGF0YQ== + """)) + with self.assertRaises(TypeError): + raw_data_manager.get_content(m, errors='ignore') + + def test_get_raises_on_multipart(self): + m = self._str_msg(textwrap.dedent("""\ + Content-Type: multipart/mixed; boundary="===" + + --=== + --===-- + """)) + with self.assertRaises(KeyError): + raw_data_manager.get_content(m) + + def test_get_message_rfc822_and_external_body(self): + template = textwrap.dedent("""\ + Content-Type: message/{} + + To: foo@example.com + From: bar@example.com + Subject: example + + an example message + """) + for subtype in 'rfc822 external-body'.split(): + with self.subTest(subtype=subtype): + m = self._str_msg(template.format(subtype)) + sub_msg = raw_data_manager.get_content(m) + self.assertIsInstance(sub_msg, self.message) + self.assertEqual(raw_data_manager.get_content(sub_msg), + "an example message\n") + self.assertEqual(sub_msg['to'], 'foo@example.com') + self.assertEqual(sub_msg['from'].addresses[0].username, 'bar') + + def test_get_message_non_rfc822_or_external_body_yields_bytes(self): + m = self._str_msg(textwrap.dedent("""\ + Content-Type: message/partial + + To: foo@example.com + From: bar@example.com + Subject: example + + The real body is in another message. + """)) + self.assertEqual(raw_data_manager.get_content(m)[:10], b'To: foo@ex') + + def test_set_text_plain(self): + m = self._make_message() + content = "Simple message.\n" + raw_data_manager.set_content(m, content) + self.assertEqual(str(m), textwrap.dedent("""\ + Content-Type: text/plain; charset="utf-8" + Content-Transfer-Encoding: 7bit + + Simple message. + """)) + self.assertEqual(m.get_payload(decode=True).decode('utf-8'), content) + self.assertEqual(m.get_content(), content) + + def test_set_text_html(self): + m = self._make_message() + content = "<p>Simple message.</p>\n" + raw_data_manager.set_content(m, content, subtype='html') + self.assertEqual(str(m), textwrap.dedent("""\ + Content-Type: text/html; charset="utf-8" + Content-Transfer-Encoding: 7bit + + <p>Simple message.</p> + """)) + self.assertEqual(m.get_payload(decode=True).decode('utf-8'), content) + self.assertEqual(m.get_content(), content) + + def test_set_text_charset_latin_1(self): + m = self._make_message() + content = "Simple message.\n" + raw_data_manager.set_content(m, content, charset='latin-1') + self.assertEqual(str(m), textwrap.dedent("""\ + Content-Type: text/plain; charset="iso-8859-1" + Content-Transfer-Encoding: 7bit + + Simple message. + """)) + self.assertEqual(m.get_payload(decode=True).decode('utf-8'), content) + self.assertEqual(m.get_content(), content) + + def test_set_text_short_line_minimal_non_ascii_heuristics(self): + m = self._make_message() + content = "et là il est monté sur moi et il commence à m'éto.\n" + raw_data_manager.set_content(m, content) + self.assertEqual(bytes(m), textwrap.dedent("""\ + Content-Type: text/plain; charset="utf-8" + Content-Transfer-Encoding: 8bit + + et là il est monté sur moi et il commence à m'éto. + """).encode('utf-8')) + self.assertEqual(m.get_payload(decode=True).decode('utf-8'), content) + self.assertEqual(m.get_content(), content) + + def test_set_text_long_line_minimal_non_ascii_heuristics(self): + m = self._make_message() + content = ("j'ai un problème de python. il est sorti de son" + " vivarium. et là il est monté sur moi et il commence" + " à m'éto.\n") + raw_data_manager.set_content(m, content) + self.assertEqual(bytes(m), textwrap.dedent("""\ + Content-Type: text/plain; charset="utf-8" + Content-Transfer-Encoding: quoted-printable + + j'ai un probl=C3=A8me de python. il est sorti de son vivari= + um. et l=C3=A0 il est mont=C3=A9 sur moi et il commence = + =C3=A0 m'=C3=A9to. + """).encode('utf-8')) + self.assertEqual(m.get_payload(decode=True).decode('utf-8'), content) + self.assertEqual(m.get_content(), content) + + def test_set_text_11_lines_long_line_minimal_non_ascii_heuristics(self): + m = self._make_message() + content = '\n'*10 + ( + "j'ai un problème de python. il est sorti de son" + " vivarium. et là il est monté sur moi et il commence" + " à m'éto.\n") + raw_data_manager.set_content(m, content) + self.assertEqual(bytes(m), textwrap.dedent("""\ + Content-Type: text/plain; charset="utf-8" + Content-Transfer-Encoding: quoted-printable + """ + '\n'*10 + """ + j'ai un probl=C3=A8me de python. il est sorti de son vivari= + um. et l=C3=A0 il est mont=C3=A9 sur moi et il commence = + =C3=A0 m'=C3=A9to. + """).encode('utf-8')) + self.assertEqual(m.get_payload(decode=True).decode('utf-8'), content) + self.assertEqual(m.get_content(), content) + + def test_set_text_maximal_non_ascii_heuristics(self): + m = self._make_message() + content = "áàäéèęöő.\n" + raw_data_manager.set_content(m, content) + self.assertEqual(bytes(m), textwrap.dedent("""\ + Content-Type: text/plain; charset="utf-8" + Content-Transfer-Encoding: 8bit + + áàäéèęöő. + """).encode('utf-8')) + self.assertEqual(m.get_payload(decode=True).decode('utf-8'), content) + self.assertEqual(m.get_content(), content) + + def test_set_text_11_lines_maximal_non_ascii_heuristics(self): + m = self._make_message() + content = '\n'*10 + "áàäéèęöő.\n" + raw_data_manager.set_content(m, content) + self.assertEqual(bytes(m), textwrap.dedent("""\ + Content-Type: text/plain; charset="utf-8" + Content-Transfer-Encoding: 8bit + """ + '\n'*10 + """ + áàäéèęöő. + """).encode('utf-8')) + self.assertEqual(m.get_payload(decode=True).decode('utf-8'), content) + self.assertEqual(m.get_content(), content) + + def test_set_text_long_line_maximal_non_ascii_heuristics(self): + m = self._make_message() + content = ("áàäéèęöőáàäéèęöőáàäéèęöőáàäéèęöő" + "áàäéèęöőáàäéèęöőáàäéèęöőáàäéèęöő" + "áàäéèęöőáàäéèęöőáàäéèęöőáàäéèęöő.\n") + raw_data_manager.set_content(m, content) + self.assertEqual(bytes(m), textwrap.dedent("""\ + Content-Type: text/plain; charset="utf-8" + Content-Transfer-Encoding: base64 + + w6HDoMOkw6nDqMSZw7bFkcOhw6DDpMOpw6jEmcO2xZHDocOgw6TDqcOoxJnD + tsWRw6HDoMOkw6nDqMSZw7bFkcOhw6DDpMOpw6jEmcO2xZHDocOgw6TDqcOo + xJnDtsWRw6HDoMOkw6nDqMSZw7bFkcOhw6DDpMOpw6jEmcO2xZHDocOgw6TD + qcOoxJnDtsWRw6HDoMOkw6nDqMSZw7bFkcOhw6DDpMOpw6jEmcO2xZHDocOg + w6TDqcOoxJnDtsWRLgo= + """).encode('utf-8')) + self.assertEqual(m.get_payload(decode=True).decode('utf-8'), content) + self.assertEqual(m.get_content(), content) + + def test_set_text_11_lines_long_line_maximal_non_ascii_heuristics(self): + # Yes, it chooses "wrong" here. It's a heuristic. So this result + # could change if we come up with a better heuristic. + m = self._make_message() + content = ('\n'*10 + + "áàäéèęöőáàäéèęöőáàäéèęöőáàäéèęöő" + "áàäéèęöőáàäéèęöőáàäéèęöőáàäéèęöő" + "áàäéèęöőáàäéèęöőáàäéèęöőáàäéèęöő.\n") + raw_data_manager.set_content(m, "\n"*10 + + "áàäéèęöőáàäéèęöőáàäéèęöőáàäéèęöő" + "áàäéèęöőáàäéèęöőáàäéèęöőáàäéèęöő" + "áàäéèęöőáàäéèęöőáàäéèęöőáàäéèęöő.\n") + self.assertEqual(bytes(m), textwrap.dedent("""\ + Content-Type: text/plain; charset="utf-8" + Content-Transfer-Encoding: quoted-printable + """ + '\n'*10 + """ + =C3=A1=C3=A0=C3=A4=C3=A9=C3=A8=C4=99=C3=B6=C5=91=C3=A1=C3= + =A0=C3=A4=C3=A9=C3=A8=C4=99=C3=B6=C5=91=C3=A1=C3=A0=C3=A4= + =C3=A9=C3=A8=C4=99=C3=B6=C5=91=C3=A1=C3=A0=C3=A4=C3=A9=C3= + =A8=C4=99=C3=B6=C5=91=C3=A1=C3=A0=C3=A4=C3=A9=C3=A8=C4=99= + =C3=B6=C5=91=C3=A1=C3=A0=C3=A4=C3=A9=C3=A8=C4=99=C3=B6=C5= + =91=C3=A1=C3=A0=C3=A4=C3=A9=C3=A8=C4=99=C3=B6=C5=91=C3=A1= + =C3=A0=C3=A4=C3=A9=C3=A8=C4=99=C3=B6=C5=91=C3=A1=C3=A0=C3= + =A4=C3=A9=C3=A8=C4=99=C3=B6=C5=91=C3=A1=C3=A0=C3=A4=C3=A9= + =C3=A8=C4=99=C3=B6=C5=91=C3=A1=C3=A0=C3=A4=C3=A9=C3=A8=C4= + =99=C3=B6=C5=91=C3=A1=C3=A0=C3=A4=C3=A9=C3=A8=C4=99=C3=B6= + =C5=91. + """).encode('utf-8')) + self.assertEqual(m.get_payload(decode=True).decode('utf-8'), content) + self.assertEqual(m.get_content(), content) + + def test_set_text_non_ascii_with_cte_7bit_raises(self): + m = self._make_message() + with self.assertRaises(UnicodeError): + raw_data_manager.set_content(m,"áàäéèęöő.\n", cte='7bit') + + def test_set_text_non_ascii_with_charset_ascii_raises(self): + m = self._make_message() + with self.assertRaises(UnicodeError): + raw_data_manager.set_content(m,"áàäéèęöő.\n", charset='ascii') + + def test_set_text_non_ascii_with_cte_7bit_and_charset_ascii_raises(self): + m = self._make_message() + with self.assertRaises(UnicodeError): + raw_data_manager.set_content(m,"áàäéèęöő.\n", cte='7bit', charset='ascii') + + def test_set_message(self): + m = self._make_message() + m['Subject'] = "Forwarded message" + content = self._make_message() + content['To'] = 'python@vivarium.org' + content['From'] = 'police@monty.org' + content['Subject'] = "get back in your box" + content.set_content("Or face the comfy chair.") + raw_data_manager.set_content(m, content) + self.assertEqual(str(m), textwrap.dedent("""\ + Subject: Forwarded message + Content-Type: message/rfc822 + Content-Transfer-Encoding: 8bit + + To: python@vivarium.org + From: police@monty.org + Subject: get back in your box + Content-Type: text/plain; charset="utf-8" + Content-Transfer-Encoding: 7bit + MIME-Version: 1.0 + + Or face the comfy chair. + """)) + payload = m.get_payload(0) + self.assertIsInstance(payload, self.message) + self.assertEqual(str(payload), str(content)) + self.assertIsInstance(m.get_content(), self.message) + self.assertEqual(str(m.get_content()), str(content)) + + def test_set_message_with_non_ascii_and_coercion_to_7bit(self): + m = self._make_message() + m['Subject'] = "Escape report" + content = self._make_message() + content['To'] = 'police@monty.org' + content['From'] = 'victim@monty.org' + content['Subject'] = "Help" + content.set_content("j'ai un problème de python. il est sorti de son" + " vivarium.") + raw_data_manager.set_content(m, content) + self.assertEqual(bytes(m), textwrap.dedent("""\ + Subject: Escape report + Content-Type: message/rfc822 + Content-Transfer-Encoding: 8bit + + To: police@monty.org + From: victim@monty.org + Subject: Help + Content-Type: text/plain; charset="utf-8" + Content-Transfer-Encoding: 8bit + MIME-Version: 1.0 + + j'ai un problème de python. il est sorti de son vivarium. + """).encode('utf-8')) + # The choice of base64 for the body encoding is because generator + # doesn't bother with heuristics and uses it unconditionally for utf-8 + # text. + # XXX: the first cte should be 7bit, too...that's a generator bug. + # XXX: the line length in the body also looks like a generator bug. + self.assertEqual(m.as_string(maxheaderlen=self.policy.max_line_length), + textwrap.dedent("""\ + Subject: Escape report + Content-Type: message/rfc822 + Content-Transfer-Encoding: 8bit + + To: police@monty.org + From: victim@monty.org + Subject: Help + Content-Type: text/plain; charset="utf-8" + Content-Transfer-Encoding: base64 + MIME-Version: 1.0 + + aidhaSB1biBwcm9ibMOobWUgZGUgcHl0aG9uLiBpbCBlc3Qgc29ydGkgZGUgc29uIHZpdmFyaXVt + Lgo= + """)) + self.assertIsInstance(m.get_content(), self.message) + self.assertEqual(str(m.get_content()), str(content)) + + def test_set_message_invalid_cte_raises(self): + m = self._make_message() + content = self._make_message() + for cte in 'quoted-printable base64'.split(): + for subtype in 'rfc822 external-body'.split(): + with self.subTest(cte=cte, subtype=subtype): + with self.assertRaises(ValueError) as ar: + m.set_content(content, subtype, cte=cte) + exc = str(ar.exception) + self.assertIn(cte, exc) + self.assertIn(subtype, exc) + subtype = 'external-body' + for cte in '8bit binary'.split(): + with self.subTest(cte=cte, subtype=subtype): + with self.assertRaises(ValueError) as ar: + m.set_content(content, subtype, cte=cte) + exc = str(ar.exception) + self.assertIn(cte, exc) + self.assertIn(subtype, exc) + + def test_set_image_jpg(self): + for content in (b"bogus content", + bytearray(b"bogus content"), + memoryview(b"bogus content")): + with self.subTest(content=content): + m = self._make_message() + raw_data_manager.set_content(m, content, 'image', 'jpeg') + self.assertEqual(str(m), textwrap.dedent("""\ + Content-Type: image/jpeg + Content-Transfer-Encoding: base64 + + Ym9ndXMgY29udGVudA== + """)) + self.assertEqual(m.get_payload(decode=True), content) + self.assertEqual(m.get_content(), content) + + def test_set_audio_aif_with_quoted_printable_cte(self): + # Why you would use qp, I don't know, but it is technically supported. + # XXX: the incorrect line length is because binascii.b2a_qp doesn't + # support a line length parameter, but we must use it to get newline + # encoding. + # XXX: what about that lack of tailing newline? Do we actually handle + # that correctly in all cases? That is, if the *source* has an + # unencoded newline, do we add an extra newline to the returned payload + # or not? And can that actually be disambiguated based on the RFC? + m = self._make_message() + content = b'b\xFFgus\tcon\nt\rent ' + b'z'*100 + m.set_content(content, 'audio', 'aif', cte='quoted-printable') + self.assertEqual(bytes(m), textwrap.dedent("""\ + Content-Type: audio/aif + Content-Transfer-Encoding: quoted-printable + MIME-Version: 1.0 + + b=FFgus=09con=0At=0Dent=20zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz= + zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz""").encode('latin-1')) + self.assertEqual(m.get_payload(decode=True), content) + self.assertEqual(m.get_content(), content) + + def test_set_video_mpeg_with_binary_cte(self): + m = self._make_message() + content = b'b\xFFgus\tcon\nt\rent ' + b'z'*100 + m.set_content(content, 'video', 'mpeg', cte='binary') + self.assertEqual(bytes(m), textwrap.dedent("""\ + Content-Type: video/mpeg + Content-Transfer-Encoding: binary + MIME-Version: 1.0 + + """).encode('ascii') + + # XXX: the second \n ought to be a \r, but generator gets it wrong. + # THIS MEANS WE DON'T ACTUALLY SUPPORT THE 'binary' CTE. + b'b\xFFgus\tcon\nt\nent zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz' + + b'zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz') + self.assertEqual(m.get_payload(decode=True), content) + self.assertEqual(m.get_content(), content) + + def test_set_application_octet_stream_with_8bit_cte(self): + # In 8bit mode, univeral line end logic applies. It is up to the + # application to make sure the lines are short enough; we don't check. + m = self._make_message() + content = b'b\xFFgus\tcon\nt\rent\n' + b'z'*60 + b'\n' + m.set_content(content, 'application', 'octet-stream', cte='8bit') + self.assertEqual(bytes(m), textwrap.dedent("""\ + Content-Type: application/octet-stream + Content-Transfer-Encoding: 8bit + MIME-Version: 1.0 + + """).encode('ascii') + + b'b\xFFgus\tcon\nt\nent\n' + + b'zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz\n') + self.assertEqual(m.get_payload(decode=True), content) + self.assertEqual(m.get_content(), content) + + def test_set_headers_from_header_objects(self): + m = self._make_message() + content = "Simple message.\n" + header_factory = self.policy.header_factory + raw_data_manager.set_content(m, content, headers=( + header_factory("To", "foo@example.com"), + header_factory("From", "foo@example.com"), + header_factory("Subject", "I'm talking to myself."))) + self.assertEqual(str(m), textwrap.dedent("""\ + Content-Type: text/plain; charset="utf-8" + To: foo@example.com + From: foo@example.com + Subject: I'm talking to myself. + Content-Transfer-Encoding: 7bit + + Simple message. + """)) + + def test_set_headers_from_strings(self): + m = self._make_message() + content = "Simple message.\n" + raw_data_manager.set_content(m, content, headers=( + "X-Foo-Header: foo", + "X-Bar-Header: bar",)) + self.assertEqual(str(m), textwrap.dedent("""\ + Content-Type: text/plain; charset="utf-8" + X-Foo-Header: foo + X-Bar-Header: bar + Content-Transfer-Encoding: 7bit + + Simple message. + """)) + + def test_set_headers_with_invalid_duplicate_string_header_raises(self): + m = self._make_message() + content = "Simple message.\n" + with self.assertRaisesRegex(ValueError, 'Content-Type'): + raw_data_manager.set_content(m, content, headers=( + "Content-Type: foo/bar",) + ) + + def test_set_headers_with_invalid_duplicate_header_header_raises(self): + m = self._make_message() + content = "Simple message.\n" + header_factory = self.policy.header_factory + with self.assertRaisesRegex(ValueError, 'Content-Type'): + raw_data_manager.set_content(m, content, headers=( + header_factory("Content-Type", " foo/bar"),) + ) + + def test_set_headers_with_defective_string_header_raises(self): + m = self._make_message() + content = "Simple message.\n" + with self.assertRaisesRegex(ValueError, 'a@fairly@@invalid@address'): + raw_data_manager.set_content(m, content, headers=( + 'To: a@fairly@@invalid@address',) + ) + print(m['To'].defects) + + def test_set_headers_with_defective_header_header_raises(self): + m = self._make_message() + content = "Simple message.\n" + header_factory = self.policy.header_factory + with self.assertRaisesRegex(ValueError, 'a@fairly@@invalid@address'): + raw_data_manager.set_content(m, content, headers=( + header_factory('To', 'a@fairly@@invalid@address'),) + ) + print(m['To'].defects) + + def test_set_disposition_inline(self): + m = self._make_message() + m.set_content('foo', disposition='inline') + self.assertEqual(m['Content-Disposition'], 'inline') + + def test_set_disposition_attachment(self): + m = self._make_message() + m.set_content('foo', disposition='attachment') + self.assertEqual(m['Content-Disposition'], 'attachment') + + def test_set_disposition_foo(self): + m = self._make_message() + m.set_content('foo', disposition='foo') + self.assertEqual(m['Content-Disposition'], 'foo') + + # XXX: we should have a 'strict' policy mode (beyond raise_on_defect) that + # would cause 'foo' above to raise. + + def test_set_filename(self): + m = self._make_message() + m.set_content('foo', filename='bar.txt') + self.assertEqual(m['Content-Disposition'], + 'attachment; filename="bar.txt"') + + def test_set_filename_and_disposition_inline(self): + m = self._make_message() + m.set_content('foo', disposition='inline', filename='bar.txt') + self.assertEqual(m['Content-Disposition'], 'inline; filename="bar.txt"') + + def test_set_non_ascii_filename(self): + m = self._make_message() + m.set_content('foo', filename='ábárî.txt') + self.assertEqual(bytes(m), textwrap.dedent("""\ + Content-Type: text/plain; charset="utf-8" + Content-Transfer-Encoding: 7bit + Content-Disposition: attachment; + filename*=utf-8''%C3%A1b%C3%A1r%C3%AE.txt + MIME-Version: 1.0 + + foo + """).encode('ascii')) + + content_object_params = { + 'text_plain': ('content', ()), + 'text_html': ('content', ('html',)), + 'application_octet_stream': (b'content', + ('application', 'octet_stream')), + 'image_jpeg': (b'content', ('image', 'jpeg')), + 'message_rfc822': (message(), ()), + 'message_external_body': (message(), ('external-body',)), + } + + def content_object_as_header_receiver(self, obj, mimetype): + m = self._make_message() + m.set_content(obj, *mimetype, headers=( + 'To: foo@example.com', + 'From: bar@simple.net')) + self.assertEqual(m['to'], 'foo@example.com') + self.assertEqual(m['from'], 'bar@simple.net') + + def content_object_as_disposition_inline_receiver(self, obj, mimetype): + m = self._make_message() + m.set_content(obj, *mimetype, disposition='inline') + self.assertEqual(m['Content-Disposition'], 'inline') + + def content_object_as_non_ascii_filename_receiver(self, obj, mimetype): + m = self._make_message() + m.set_content(obj, *mimetype, disposition='inline', filename='bár.txt') + self.assertEqual(m['Content-Disposition'], 'inline; filename="bár.txt"') + self.assertEqual(m.get_filename(), "bár.txt") + self.assertEqual(m['Content-Disposition'].params['filename'], "bár.txt") + + def content_object_as_cid_receiver(self, obj, mimetype): + m = self._make_message() + m.set_content(obj, *mimetype, cid='some_random_stuff') + self.assertEqual(m['Content-ID'], 'some_random_stuff') + + def content_object_as_params_receiver(self, obj, mimetype): + m = self._make_message() + params = {'foo': 'bár', 'abc': 'xyz'} + m.set_content(obj, *mimetype, params=params) + if isinstance(obj, str): + params['charset'] = 'utf-8' + self.assertEqual(m['Content-Type'].params, params) + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_email/test_email.py b/Lib/test/test_email/test_email.py index 51fe756df7..26ed96ccc4 100644 --- a/Lib/test/test_email/test_email.py +++ b/Lib/test/test_email/test_email.py @@ -281,15 +281,42 @@ class TestMessageAPI(TestEmailBase): self.assertIn('TO', msg) def test_as_string(self): - eq = self.ndiffAssertEqual msg = self._msgobj('msg_01.txt') with openfile('msg_01.txt') as fp: text = fp.read() - eq(text, str(msg)) + self.assertEqual(text, str(msg)) fullrepr = msg.as_string(unixfrom=True) lines = fullrepr.split('\n') self.assertTrue(lines[0].startswith('From ')) - eq(text, NL.join(lines[1:])) + self.assertEqual(text, NL.join(lines[1:])) + + def test_as_string_policy(self): + msg = self._msgobj('msg_01.txt') + newpolicy = msg.policy.clone(linesep='\r\n') + fullrepr = msg.as_string(policy=newpolicy) + s = StringIO() + g = Generator(s, policy=newpolicy) + g.flatten(msg) + self.assertEqual(fullrepr, s.getvalue()) + + def test_as_bytes(self): + msg = self._msgobj('msg_01.txt') + with openfile('msg_01.txt') as fp: + data = fp.read().encode('ascii') + self.assertEqual(data, bytes(msg)) + fullrepr = msg.as_bytes(unixfrom=True) + lines = fullrepr.split(b'\n') + self.assertTrue(lines[0].startswith(b'From ')) + self.assertEqual(data, b'\n'.join(lines[1:])) + + def test_as_bytes_policy(self): + msg = self._msgobj('msg_01.txt') + newpolicy = msg.policy.clone(linesep='\r\n') + fullrepr = msg.as_bytes(policy=newpolicy) + s = BytesIO() + g = BytesGenerator(s,policy=newpolicy) + g.flatten(msg) + self.assertEqual(fullrepr, s.getvalue()) # test_headerregistry.TestContentTypeHeader.bad_params def test_bad_param(self): @@ -742,8 +769,15 @@ class TestEncoders(unittest.TestCase): # whose output character set is 7bit gets a transfer-encoding # of 7bit. eq = self.assertEqual - msg = MIMEText('文', _charset='euc-jp') + msg = MIMEText('文\n', _charset='euc-jp') eq(msg['content-transfer-encoding'], '7bit') + eq(msg.as_string(), textwrap.dedent("""\ + MIME-Version: 1.0 + Content-Type: text/plain; charset="iso-2022-jp" + Content-Transfer-Encoding: 7bit + + \x1b$BJ8\x1b(B + """)) def test_qp_encode_latin1(self): msg = MIMEText('\xe1\xf6\n', 'text', 'ISO-8859-1') diff --git a/Lib/test/test_email/test_headerregistry.py b/Lib/test/test_email/test_headerregistry.py index adaf3e8fe4..7d64a3829c 100644 --- a/Lib/test/test_email/test_headerregistry.py +++ b/Lib/test/test_email/test_headerregistry.py @@ -661,7 +661,7 @@ class TestContentTypeHeader(TestHeaderBase): 'text/plain; name="ascii_is_the_default"'), 'rfc2231_bad_character_in_charset_parameter_value': ( - "text/plain; charset*=ascii''utf-8%E2%80%9D", + "text/plain; charset*=ascii''utf-8%F1%F2%F3", 'text/plain', 'text', 'plain', @@ -669,6 +669,18 @@ class TestContentTypeHeader(TestHeaderBase): [errors.UndecodableBytesDefect], 'text/plain; charset="utf-8\uFFFD\uFFFD\uFFFD"'), + 'rfc2231_utf_8_in_supposedly_ascii_charset_parameter_value': ( + "text/plain; charset*=ascii''utf-8%E2%80%9D", + 'text/plain', + 'text', + 'plain', + {'charset': 'utf-8”'}, + [errors.UndecodableBytesDefect], + 'text/plain; charset="utf-8”"', + ), + # XXX: if the above were *re*folded, it would get tagged as utf-8 + # instead of ascii in the param, since it now contains non-ASCII. + 'rfc2231_encoded_then_unencoded_segments': ( ('application/x-foo;' '\tname*0*="us-ascii\'en-us\'My";' diff --git a/Lib/test/test_email/test_message.py b/Lib/test/test_email/test_message.py index 8cc3f80e46..c761c62d04 100644 --- a/Lib/test/test_email/test_message.py +++ b/Lib/test/test_email/test_message.py @@ -1,6 +1,13 @@ import unittest -from email import policy -from test.test_email import TestEmailBase +import textwrap +from email import policy, message_from_string +from email.message import EmailMessage, MIMEPart +from test.test_email import TestEmailBase, parameterize + + +# Helper. +def first(iterable): + return next(filter(lambda x: x is not None, iterable), None) class Test(TestEmailBase): @@ -13,6 +20,753 @@ class Test(TestEmailBase): with self.assertRaises(ValueError): m['To'] = 'xyz@abc' + def test_rfc2043_auto_decoded_and_emailmessage_used(self): + m = message_from_string(textwrap.dedent("""\ + Subject: Ayons asperges pour le =?utf-8?q?d=C3=A9jeuner?= + From: =?utf-8?q?Pep=C3=A9?= Le Pew <pepe@example.com> + To: "Penelope Pussycat" <"penelope@example.com"> + MIME-Version: 1.0 + Content-Type: text/plain; charset="utf-8" + + sample text + """), policy=policy.default) + self.assertEqual(m['subject'], "Ayons asperges pour le déjeuner") + self.assertEqual(m['from'], "Pepé Le Pew <pepe@example.com>") + self.assertIsInstance(m, EmailMessage) + + +@parameterize +class TestEmailMessageBase: + + policy = policy.default + + # The first argument is a triple (related, html, plain) of indices into the + # list returned by 'walk' called on a Message constructed from the third. + # The indices indicate which part should match the corresponding part-type + # when passed to get_body (ie: the "first" part of that type in the + # message). The second argument is a list of indices into the 'walk' list + # of the attachments that should be returned by a call to + # 'iter_attachments'. The third argument is a list of indices into 'walk' + # that should be returned by a call to 'iter_parts'. Note that the first + # item returned by 'walk' is the Message itself. + + message_params = { + + 'empty_message': ( + (None, None, 0), + (), + (), + ""), + + 'non_mime_plain': ( + (None, None, 0), + (), + (), + textwrap.dedent("""\ + To: foo@example.com + + simple text body + """)), + + 'mime_non_text': ( + (None, None, None), + (), + (), + textwrap.dedent("""\ + To: foo@example.com + MIME-Version: 1.0 + Content-Type: image/jpg + + bogus body. + """)), + + 'plain_html_alternative': ( + (None, 2, 1), + (), + (1, 2), + textwrap.dedent("""\ + To: foo@example.com + MIME-Version: 1.0 + Content-Type: multipart/alternative; boundary="===" + + preamble + + --=== + Content-Type: text/plain + + simple body + + --=== + Content-Type: text/html + + <p>simple body</p> + --===-- + """)), + + 'plain_html_mixed': ( + (None, 2, 1), + (), + (1, 2), + textwrap.dedent("""\ + To: foo@example.com + MIME-Version: 1.0 + Content-Type: multipart/mixed; boundary="===" + + preamble + + --=== + Content-Type: text/plain + + simple body + + --=== + Content-Type: text/html + + <p>simple body</p> + + --===-- + """)), + + 'plain_html_attachment_mixed': ( + (None, None, 1), + (2,), + (1, 2), + textwrap.dedent("""\ + To: foo@example.com + MIME-Version: 1.0 + Content-Type: multipart/mixed; boundary="===" + + --=== + Content-Type: text/plain + + simple body + + --=== + Content-Type: text/html + Content-Disposition: attachment + + <p>simple body</p> + + --===-- + """)), + + 'html_text_attachment_mixed': ( + (None, 2, None), + (1,), + (1, 2), + textwrap.dedent("""\ + To: foo@example.com + MIME-Version: 1.0 + Content-Type: multipart/mixed; boundary="===" + + --=== + Content-Type: text/plain + Content-Disposition: AtTaChment + + simple body + + --=== + Content-Type: text/html + + <p>simple body</p> + + --===-- + """)), + + 'html_text_attachment_inline_mixed': ( + (None, 2, 1), + (), + (1, 2), + textwrap.dedent("""\ + To: foo@example.com + MIME-Version: 1.0 + Content-Type: multipart/mixed; boundary="===" + + --=== + Content-Type: text/plain + Content-Disposition: InLine + + simple body + + --=== + Content-Type: text/html + Content-Disposition: inline + + <p>simple body</p> + + --===-- + """)), + + # RFC 2387 + 'related': ( + (0, 1, None), + (2,), + (1, 2), + textwrap.dedent("""\ + To: foo@example.com + MIME-Version: 1.0 + Content-Type: multipart/related; boundary="==="; type=text/html + + --=== + Content-Type: text/html + + <p>simple body</p> + + --=== + Content-Type: image/jpg + Content-ID: <image1> + + bogus data + + --===-- + """)), + + # This message structure will probably never be seen in the wild, but + # it proves we distinguish between text parts based on 'start'. The + # content would not, of course, actually work :) + 'related_with_start': ( + (0, 2, None), + (1,), + (1, 2), + textwrap.dedent("""\ + To: foo@example.com + MIME-Version: 1.0 + Content-Type: multipart/related; boundary="==="; type=text/html; + start="<body>" + + --=== + Content-Type: text/html + Content-ID: <include> + + useless text + + --=== + Content-Type: text/html + Content-ID: <body> + + <p>simple body</p> + <!--#include file="<include>"--> + + --===-- + """)), + + + 'mixed_alternative_plain_related': ( + (3, 4, 2), + (6, 7), + (1, 6, 7), + textwrap.dedent("""\ + To: foo@example.com + MIME-Version: 1.0 + Content-Type: multipart/mixed; boundary="===" + + --=== + Content-Type: multipart/alternative; boundary="+++" + + --+++ + Content-Type: text/plain + + simple body + + --+++ + Content-Type: multipart/related; boundary="___" + + --___ + Content-Type: text/html + + <p>simple body</p> + + --___ + Content-Type: image/jpg + Content-ID: <image1@cid> + + bogus jpg body + + --___-- + + --+++-- + + --=== + Content-Type: image/jpg + Content-Disposition: attachment + + bogus jpg body + + --=== + Content-Type: image/jpg + Content-Disposition: AttacHmenT + + another bogus jpg body + + --===-- + """)), + + # This structure suggested by Stephen J. Turnbull...may not exist/be + # supported in the wild, but we want to support it. + 'mixed_related_alternative_plain_html': ( + (1, 4, 3), + (6, 7), + (1, 6, 7), + textwrap.dedent("""\ + To: foo@example.com + MIME-Version: 1.0 + Content-Type: multipart/mixed; boundary="===" + + --=== + Content-Type: multipart/related; boundary="+++" + + --+++ + Content-Type: multipart/alternative; boundary="___" + + --___ + Content-Type: text/plain + + simple body + + --___ + Content-Type: text/html + + <p>simple body</p> + + --___-- + + --+++ + Content-Type: image/jpg + Content-ID: <image1@cid> + + bogus jpg body + + --+++-- + + --=== + Content-Type: image/jpg + Content-Disposition: attachment + + bogus jpg body + + --=== + Content-Type: image/jpg + Content-Disposition: attachment + + another bogus jpg body + + --===-- + """)), + + # Same thing, but proving we only look at the root part, which is the + # first one if there isn't any start parameter. That is, this is a + # broken related. + 'mixed_related_alternative_plain_html_wrong_order': ( + (1, None, None), + (6, 7), + (1, 6, 7), + textwrap.dedent("""\ + To: foo@example.com + MIME-Version: 1.0 + Content-Type: multipart/mixed; boundary="===" + + --=== + Content-Type: multipart/related; boundary="+++" + + --+++ + Content-Type: image/jpg + Content-ID: <image1@cid> + + bogus jpg body + + --+++ + Content-Type: multipart/alternative; boundary="___" + + --___ + Content-Type: text/plain + + simple body + + --___ + Content-Type: text/html + + <p>simple body</p> + + --___-- + + --+++-- + + --=== + Content-Type: image/jpg + Content-Disposition: attachment + + bogus jpg body + + --=== + Content-Type: image/jpg + Content-Disposition: attachment + + another bogus jpg body + + --===-- + """)), + + 'message_rfc822': ( + (None, None, None), + (), + (), + textwrap.dedent("""\ + To: foo@example.com + MIME-Version: 1.0 + Content-Type: message/rfc822 + + To: bar@example.com + From: robot@examp.com + + this is a message body. + """)), + + 'mixed_text_message_rfc822': ( + (None, None, 1), + (2,), + (1, 2), + textwrap.dedent("""\ + To: foo@example.com + MIME-Version: 1.0 + Content-Type: multipart/mixed; boundary="===" + + --=== + Content-Type: text/plain + + Your message has bounced, ser. + + --=== + Content-Type: message/rfc822 + + To: bar@example.com + From: robot@examp.com + + this is a message body. + + --===-- + """)), + + } + + def message_as_get_body(self, body_parts, attachments, parts, msg): + m = self._str_msg(msg) + allparts = list(m.walk()) + expected = [None if n is None else allparts[n] for n in body_parts] + related = 0; html = 1; plain = 2 + self.assertEqual(m.get_body(), first(expected)) + self.assertEqual(m.get_body(preferencelist=( + 'related', 'html', 'plain')), + first(expected)) + self.assertEqual(m.get_body(preferencelist=('related', 'html')), + first(expected[related:html+1])) + self.assertEqual(m.get_body(preferencelist=('related', 'plain')), + first([expected[related], expected[plain]])) + self.assertEqual(m.get_body(preferencelist=('html', 'plain')), + first(expected[html:plain+1])) + self.assertEqual(m.get_body(preferencelist=['related']), + expected[related]) + self.assertEqual(m.get_body(preferencelist=['html']), expected[html]) + self.assertEqual(m.get_body(preferencelist=['plain']), expected[plain]) + self.assertEqual(m.get_body(preferencelist=('plain', 'html')), + first(expected[plain:html-1:-1])) + self.assertEqual(m.get_body(preferencelist=('plain', 'related')), + first([expected[plain], expected[related]])) + self.assertEqual(m.get_body(preferencelist=('html', 'related')), + first(expected[html::-1])) + self.assertEqual(m.get_body(preferencelist=('plain', 'html', 'related')), + first(expected[::-1])) + self.assertEqual(m.get_body(preferencelist=('html', 'plain', 'related')), + first([expected[html], + expected[plain], + expected[related]])) + + def message_as_iter_attachment(self, body_parts, attachments, parts, msg): + m = self._str_msg(msg) + allparts = list(m.walk()) + attachments = [allparts[n] for n in attachments] + self.assertEqual(list(m.iter_attachments()), attachments) + + def message_as_iter_parts(self, body_parts, attachments, parts, msg): + m = self._str_msg(msg) + allparts = list(m.walk()) + parts = [allparts[n] for n in parts] + self.assertEqual(list(m.iter_parts()), parts) + + class _TestContentManager: + def get_content(self, msg, *args, **kw): + return msg, args, kw + def set_content(self, msg, *args, **kw): + self.msg = msg + self.args = args + self.kw = kw + + def test_get_content_with_cm(self): + m = self._str_msg('') + cm = self._TestContentManager() + self.assertEqual(m.get_content(content_manager=cm), (m, (), {})) + msg, args, kw = m.get_content('foo', content_manager=cm, bar=1, k=2) + self.assertEqual(msg, m) + self.assertEqual(args, ('foo',)) + self.assertEqual(kw, dict(bar=1, k=2)) + + def test_get_content_default_cm_comes_from_policy(self): + p = policy.default.clone(content_manager=self._TestContentManager()) + m = self._str_msg('', policy=p) + self.assertEqual(m.get_content(), (m, (), {})) + msg, args, kw = m.get_content('foo', bar=1, k=2) + self.assertEqual(msg, m) + self.assertEqual(args, ('foo',)) + self.assertEqual(kw, dict(bar=1, k=2)) + + def test_set_content_with_cm(self): + m = self._str_msg('') + cm = self._TestContentManager() + m.set_content(content_manager=cm) + self.assertEqual(cm.msg, m) + self.assertEqual(cm.args, ()) + self.assertEqual(cm.kw, {}) + m.set_content('foo', content_manager=cm, bar=1, k=2) + self.assertEqual(cm.msg, m) + self.assertEqual(cm.args, ('foo',)) + self.assertEqual(cm.kw, dict(bar=1, k=2)) + + def test_set_content_default_cm_comes_from_policy(self): + cm = self._TestContentManager() + p = policy.default.clone(content_manager=cm) + m = self._str_msg('', policy=p) + m.set_content() + self.assertEqual(cm.msg, m) + self.assertEqual(cm.args, ()) + self.assertEqual(cm.kw, {}) + m.set_content('foo', bar=1, k=2) + self.assertEqual(cm.msg, m) + self.assertEqual(cm.args, ('foo',)) + self.assertEqual(cm.kw, dict(bar=1, k=2)) + + # outcome is whether xxx_method should raise ValueError error when called + # on multipart/subtype. Blank outcome means it depends on xxx (add + # succeeds, make raises). Note: 'none' means there are content-type + # headers but payload is None...this happening in practice would be very + # unusual, so treating it as if there were content seems reasonable. + # method subtype outcome + subtype_params = ( + ('related', 'no_content', 'succeeds'), + ('related', 'none', 'succeeds'), + ('related', 'plain', 'succeeds'), + ('related', 'related', ''), + ('related', 'alternative', 'raises'), + ('related', 'mixed', 'raises'), + ('alternative', 'no_content', 'succeeds'), + ('alternative', 'none', 'succeeds'), + ('alternative', 'plain', 'succeeds'), + ('alternative', 'related', 'succeeds'), + ('alternative', 'alternative', ''), + ('alternative', 'mixed', 'raises'), + ('mixed', 'no_content', 'succeeds'), + ('mixed', 'none', 'succeeds'), + ('mixed', 'plain', 'succeeds'), + ('mixed', 'related', 'succeeds'), + ('mixed', 'alternative', 'succeeds'), + ('mixed', 'mixed', ''), + ) + + def _make_subtype_test_message(self, subtype): + m = self.message() + payload = None + msg_headers = [ + ('To', 'foo@bar.com'), + ('From', 'bar@foo.com'), + ] + if subtype != 'no_content': + ('content-shadow', 'Logrus'), + msg_headers.append(('X-Random-Header', 'Corwin')) + if subtype == 'text': + payload = '' + msg_headers.append(('Content-Type', 'text/plain')) + m.set_payload('') + elif subtype != 'no_content': + payload = [] + msg_headers.append(('Content-Type', 'multipart/' + subtype)) + msg_headers.append(('X-Trump', 'Random')) + m.set_payload(payload) + for name, value in msg_headers: + m[name] = value + return m, msg_headers, payload + + def _check_disallowed_subtype_raises(self, m, method_name, subtype, method): + with self.assertRaises(ValueError) as ar: + getattr(m, method)() + exc_text = str(ar.exception) + self.assertIn(subtype, exc_text) + self.assertIn(method_name, exc_text) + + def _check_make_multipart(self, m, msg_headers, payload): + count = 0 + for name, value in msg_headers: + if not name.lower().startswith('content-'): + self.assertEqual(m[name], value) + count += 1 + self.assertEqual(len(m), count+1) # +1 for new Content-Type + part = next(m.iter_parts()) + count = 0 + for name, value in msg_headers: + if name.lower().startswith('content-'): + self.assertEqual(part[name], value) + count += 1 + self.assertEqual(len(part), count) + self.assertEqual(part.get_payload(), payload) + + def subtype_as_make(self, method, subtype, outcome): + m, msg_headers, payload = self._make_subtype_test_message(subtype) + make_method = 'make_' + method + if outcome in ('', 'raises'): + self._check_disallowed_subtype_raises(m, method, subtype, make_method) + return + getattr(m, make_method)() + self.assertEqual(m.get_content_maintype(), 'multipart') + self.assertEqual(m.get_content_subtype(), method) + if subtype == 'no_content': + self.assertEqual(len(m.get_payload()), 0) + self.assertEqual(m.items(), + msg_headers + [('Content-Type', + 'multipart/'+method)]) + else: + self.assertEqual(len(m.get_payload()), 1) + self._check_make_multipart(m, msg_headers, payload) + + def subtype_as_make_with_boundary(self, method, subtype, outcome): + # Doing all variation is a bit of overkill... + m = self.message() + if outcome in ('', 'raises'): + m['Content-Type'] = 'multipart/' + subtype + with self.assertRaises(ValueError) as cm: + getattr(m, 'make_' + method)() + return + if subtype == 'plain': + m['Content-Type'] = 'text/plain' + elif subtype != 'no_content': + m['Content-Type'] = 'multipart/' + subtype + getattr(m, 'make_' + method)(boundary="abc") + self.assertTrue(m.is_multipart()) + self.assertEqual(m.get_boundary(), 'abc') + + def test_policy_on_part_made_by_make_comes_from_message(self): + for method in ('make_related', 'make_alternative', 'make_mixed'): + m = self.message(policy=self.policy.clone(content_manager='foo')) + m['Content-Type'] = 'text/plain' + getattr(m, method)() + self.assertEqual(m.get_payload(0).policy.content_manager, 'foo') + + class _TestSetContentManager: + def set_content(self, msg, content, *args, **kw): + msg['Content-Type'] = 'text/plain' + msg.set_payload(content) + + def subtype_as_add(self, method, subtype, outcome): + m, msg_headers, payload = self._make_subtype_test_message(subtype) + cm = self._TestSetContentManager() + add_method = 'add_attachment' if method=='mixed' else 'add_' + method + if outcome == 'raises': + self._check_disallowed_subtype_raises(m, method, subtype, add_method) + return + getattr(m, add_method)('test', content_manager=cm) + self.assertEqual(m.get_content_maintype(), 'multipart') + self.assertEqual(m.get_content_subtype(), method) + if method == subtype or subtype == 'no_content': + self.assertEqual(len(m.get_payload()), 1) + for name, value in msg_headers: + self.assertEqual(m[name], value) + part = m.get_payload()[0] + else: + self.assertEqual(len(m.get_payload()), 2) + self._check_make_multipart(m, msg_headers, payload) + part = m.get_payload()[1] + self.assertEqual(part.get_content_type(), 'text/plain') + self.assertEqual(part.get_payload(), 'test') + if method=='mixed': + self.assertEqual(part['Content-Disposition'], 'attachment') + elif method=='related': + self.assertEqual(part['Content-Disposition'], 'inline') + else: + # Otherwise we don't guess. + self.assertIsNone(part['Content-Disposition']) + + class _TestSetRaisingContentManager: + def set_content(self, msg, content, *args, **kw): + raise Exception('test') + + def test_default_content_manager_for_add_comes_from_policy(self): + cm = self._TestSetRaisingContentManager() + m = self.message(policy=self.policy.clone(content_manager=cm)) + for method in ('add_related', 'add_alternative', 'add_attachment'): + with self.assertRaises(Exception) as ar: + getattr(m, method)('') + self.assertEqual(str(ar.exception), 'test') + + def message_as_clear(self, body_parts, attachments, parts, msg): + m = self._str_msg(msg) + m.clear() + self.assertEqual(len(m), 0) + self.assertEqual(list(m.items()), []) + self.assertIsNone(m.get_payload()) + self.assertEqual(list(m.iter_parts()), []) + + def message_as_clear_content(self, body_parts, attachments, parts, msg): + m = self._str_msg(msg) + expected_headers = [h for h in m.keys() + if not h.lower().startswith('content-')] + m.clear_content() + self.assertEqual(list(m.keys()), expected_headers) + self.assertIsNone(m.get_payload()) + self.assertEqual(list(m.iter_parts()), []) + + def test_is_attachment(self): + m = self._make_message() + self.assertFalse(m.is_attachment) + m['Content-Disposition'] = 'inline' + self.assertFalse(m.is_attachment) + m.replace_header('Content-Disposition', 'attachment') + self.assertTrue(m.is_attachment) + m.replace_header('Content-Disposition', 'AtTachMent') + self.assertTrue(m.is_attachment) + + + +class TestEmailMessage(TestEmailMessageBase, TestEmailBase): + message = EmailMessage + + def test_set_content_adds_MIME_Version(self): + m = self._str_msg('') + cm = self._TestContentManager() + self.assertNotIn('MIME-Version', m) + m.set_content(content_manager=cm) + self.assertEqual(m['MIME-Version'], '1.0') + + class _MIME_Version_adding_CM: + def set_content(self, msg, *args, **kw): + msg['MIME-Version'] = '1.0' + + def test_set_content_does_not_duplicate_MIME_Version(self): + m = self._str_msg('') + cm = self._MIME_Version_adding_CM() + self.assertNotIn('MIME-Version', m) + m.set_content(content_manager=cm) + self.assertEqual(m['MIME-Version'], '1.0') + + +class TestMIMEPart(TestEmailMessageBase, TestEmailBase): + # Doing the full test run here may seem a bit redundant, since the two + # classes are almost identical. But what if they drift apart? So we do + # the full tests so that any future drift doesn't introduce bugs. + message = MIMEPart + + def test_set_content_does_not_add_MIME_Version(self): + m = self._str_msg('') + cm = self._TestContentManager() + self.assertNotIn('MIME-Version', m) + m.set_content(content_manager=cm) + self.assertNotIn('MIME-Version', m) + if __name__ == '__main__': unittest.main() diff --git a/Lib/test/test_email/test_policy.py b/Lib/test/test_email/test_policy.py index 983bd49a11..06ad5f24b8 100644 --- a/Lib/test/test_email/test_policy.py +++ b/Lib/test/test_email/test_policy.py @@ -30,6 +30,7 @@ class PolicyAPITests(unittest.TestCase): 'raise_on_defect': False, 'header_factory': email.policy.EmailPolicy.header_factory, 'refold_source': 'long', + 'content_manager': email.policy.EmailPolicy.content_manager, }) # For each policy under test, we give here what we expect the defaults to diff --git a/Lib/test/test_email/torture_test.py b/Lib/test/test_email/torture_test.py index 544b1bbb39..19cf64f0c7 100644 --- a/Lib/test/test_email/torture_test.py +++ b/Lib/test/test_email/torture_test.py @@ -27,7 +27,7 @@ def openfile(filename): # Prevent this test from running in the Python distro try: openfile('crispin-torture.txt') -except IOError: +except OSError: raise TestSkipped diff --git a/Lib/test/test_ensurepip.py b/Lib/test/test_ensurepip.py new file mode 100644 index 0000000000..8644a651bb --- /dev/null +++ b/Lib/test/test_ensurepip.py @@ -0,0 +1,343 @@ +import unittest +import unittest.mock +import test.support +import os +import os.path +import contextlib +import sys + +import ensurepip +import ensurepip._uninstall + +# pip currently requires ssl support, so we ensure we handle +# it being missing (http://bugs.python.org/issue19744) +ensurepip_no_ssl = test.support.import_fresh_module("ensurepip", + blocked=["ssl"]) +try: + import ssl +except ImportError: + def requires_usable_pip(f): + deco = unittest.skip(ensurepip._MISSING_SSL_MESSAGE) + return deco(f) +else: + def requires_usable_pip(f): + return f + +class TestEnsurePipVersion(unittest.TestCase): + + def test_returns_version(self): + self.assertEqual(ensurepip._PIP_VERSION, ensurepip.version()) + +class EnsurepipMixin: + + def setUp(self): + run_pip_patch = unittest.mock.patch("ensurepip._run_pip") + self.run_pip = run_pip_patch.start() + self.addCleanup(run_pip_patch.stop) + + # Avoid side effects on the actual os module + real_devnull = os.devnull + os_patch = unittest.mock.patch("ensurepip.os") + patched_os = os_patch.start() + self.addCleanup(os_patch.stop) + patched_os.devnull = real_devnull + patched_os.path = os.path + self.os_environ = patched_os.environ = os.environ.copy() + + +class TestBootstrap(EnsurepipMixin, unittest.TestCase): + + @requires_usable_pip + def test_basic_bootstrapping(self): + ensurepip.bootstrap() + + self.run_pip.assert_called_once_with( + [ + "install", "--no-index", "--find-links", + unittest.mock.ANY, "setuptools", "pip", + ], + unittest.mock.ANY, + ) + + additional_paths = self.run_pip.call_args[0][1] + self.assertEqual(len(additional_paths), 2) + + @requires_usable_pip + def test_bootstrapping_with_root(self): + ensurepip.bootstrap(root="/foo/bar/") + + self.run_pip.assert_called_once_with( + [ + "install", "--no-index", "--find-links", + unittest.mock.ANY, "--root", "/foo/bar/", + "setuptools", "pip", + ], + unittest.mock.ANY, + ) + + @requires_usable_pip + def test_bootstrapping_with_user(self): + ensurepip.bootstrap(user=True) + + self.run_pip.assert_called_once_with( + [ + "install", "--no-index", "--find-links", + unittest.mock.ANY, "--user", "setuptools", "pip", + ], + unittest.mock.ANY, + ) + + @requires_usable_pip + def test_bootstrapping_with_upgrade(self): + ensurepip.bootstrap(upgrade=True) + + self.run_pip.assert_called_once_with( + [ + "install", "--no-index", "--find-links", + unittest.mock.ANY, "--upgrade", "setuptools", "pip", + ], + unittest.mock.ANY, + ) + + @requires_usable_pip + def test_bootstrapping_with_verbosity_1(self): + ensurepip.bootstrap(verbosity=1) + + self.run_pip.assert_called_once_with( + [ + "install", "--no-index", "--find-links", + unittest.mock.ANY, "-v", "setuptools", "pip", + ], + unittest.mock.ANY, + ) + + @requires_usable_pip + def test_bootstrapping_with_verbosity_2(self): + ensurepip.bootstrap(verbosity=2) + + self.run_pip.assert_called_once_with( + [ + "install", "--no-index", "--find-links", + unittest.mock.ANY, "-vv", "setuptools", "pip", + ], + unittest.mock.ANY, + ) + + @requires_usable_pip + def test_bootstrapping_with_verbosity_3(self): + ensurepip.bootstrap(verbosity=3) + + self.run_pip.assert_called_once_with( + [ + "install", "--no-index", "--find-links", + unittest.mock.ANY, "-vvv", "setuptools", "pip", + ], + unittest.mock.ANY, + ) + + @requires_usable_pip + def test_bootstrapping_with_regular_install(self): + ensurepip.bootstrap() + self.assertEqual(self.os_environ["ENSUREPIP_OPTIONS"], "install") + + @requires_usable_pip + def test_bootstrapping_with_alt_install(self): + ensurepip.bootstrap(altinstall=True) + self.assertEqual(self.os_environ["ENSUREPIP_OPTIONS"], "altinstall") + + @requires_usable_pip + def test_bootstrapping_with_default_pip(self): + ensurepip.bootstrap(default_pip=True) + self.assertNotIn("ENSUREPIP_OPTIONS", self.os_environ) + + def test_altinstall_default_pip_conflict(self): + with self.assertRaises(ValueError): + ensurepip.bootstrap(altinstall=True, default_pip=True) + self.run_pip.assert_not_called() + + @requires_usable_pip + def test_pip_environment_variables_removed(self): + # ensurepip deliberately ignores all pip environment variables + # See http://bugs.python.org/issue19734 for details + self.os_environ["PIP_THIS_SHOULD_GO_AWAY"] = "test fodder" + ensurepip.bootstrap() + self.assertNotIn("PIP_THIS_SHOULD_GO_AWAY", self.os_environ) + + @requires_usable_pip + def test_pip_config_file_disabled(self): + # ensurepip deliberately ignores the pip config file + # See http://bugs.python.org/issue20053 for details + ensurepip.bootstrap() + self.assertEqual(self.os_environ["PIP_CONFIG_FILE"], os.devnull) + +@contextlib.contextmanager +def fake_pip(version=ensurepip._PIP_VERSION): + if version is None: + pip = None + else: + class FakePip(): + __version__ = version + pip = FakePip() + sentinel = object() + orig_pip = sys.modules.get("pip", sentinel) + sys.modules["pip"] = pip + try: + yield pip + finally: + if orig_pip is sentinel: + del sys.modules["pip"] + else: + sys.modules["pip"] = orig_pip + +class TestUninstall(EnsurepipMixin, unittest.TestCase): + + def test_uninstall_skipped_when_not_installed(self): + with fake_pip(None): + ensurepip._uninstall_helper() + self.run_pip.assert_not_called() + + def test_uninstall_fails_with_wrong_version(self): + with fake_pip("not a valid version"): + with self.assertRaises(RuntimeError): + ensurepip._uninstall_helper() + self.run_pip.assert_not_called() + + + @requires_usable_pip + def test_uninstall(self): + with fake_pip(): + ensurepip._uninstall_helper() + + self.run_pip.assert_called_once_with( + ["uninstall", "-y", "pip", "setuptools"] + ) + + @requires_usable_pip + def test_uninstall_with_verbosity_1(self): + with fake_pip(): + ensurepip._uninstall_helper(verbosity=1) + + self.run_pip.assert_called_once_with( + ["uninstall", "-y", "-v", "pip", "setuptools"] + ) + + @requires_usable_pip + def test_uninstall_with_verbosity_2(self): + with fake_pip(): + ensurepip._uninstall_helper(verbosity=2) + + self.run_pip.assert_called_once_with( + ["uninstall", "-y", "-vv", "pip", "setuptools"] + ) + + @requires_usable_pip + def test_uninstall_with_verbosity_3(self): + with fake_pip(): + ensurepip._uninstall_helper(verbosity=3) + + self.run_pip.assert_called_once_with( + ["uninstall", "-y", "-vvv", "pip", "setuptools"] + ) + + @requires_usable_pip + def test_pip_environment_variables_removed(self): + # ensurepip deliberately ignores all pip environment variables + # See http://bugs.python.org/issue19734 for details + self.os_environ["PIP_THIS_SHOULD_GO_AWAY"] = "test fodder" + with fake_pip(): + ensurepip._uninstall_helper() + self.assertNotIn("PIP_THIS_SHOULD_GO_AWAY", self.os_environ) + + @requires_usable_pip + def test_pip_config_file_disabled(self): + # ensurepip deliberately ignores the pip config file + # See http://bugs.python.org/issue20053 for details + with fake_pip(): + ensurepip._uninstall_helper() + self.assertEqual(self.os_environ["PIP_CONFIG_FILE"], os.devnull) + + +class TestMissingSSL(EnsurepipMixin, unittest.TestCase): + + def setUp(self): + sys.modules["ensurepip"] = ensurepip_no_ssl + @self.addCleanup + def restore_module(): + sys.modules["ensurepip"] = ensurepip + super().setUp() + + def test_bootstrap_requires_ssl(self): + self.os_environ["PIP_THIS_SHOULD_STAY"] = "test fodder" + with self.assertRaisesRegex(RuntimeError, "requires SSL/TLS"): + ensurepip_no_ssl.bootstrap() + self.run_pip.assert_not_called() + self.assertIn("PIP_THIS_SHOULD_STAY", self.os_environ) + + def test_uninstall_requires_ssl(self): + self.os_environ["PIP_THIS_SHOULD_STAY"] = "test fodder" + with self.assertRaisesRegex(RuntimeError, "requires SSL/TLS"): + with fake_pip(): + ensurepip_no_ssl._uninstall_helper() + self.run_pip.assert_not_called() + self.assertIn("PIP_THIS_SHOULD_STAY", self.os_environ) + + def test_main_exits_early_with_warning(self): + with test.support.captured_stderr() as stderr: + ensurepip_no_ssl._main(["--version"]) + warning = stderr.getvalue().strip() + self.assertTrue(warning.endswith("requires SSL/TLS"), warning) + self.run_pip.assert_not_called() + +# Basic testing of the main functions and their argument parsing + +EXPECTED_VERSION_OUTPUT = "pip " + ensurepip._PIP_VERSION + +class TestBootstrappingMainFunction(EnsurepipMixin, unittest.TestCase): + + @requires_usable_pip + def test_bootstrap_version(self): + with test.support.captured_stdout() as stdout: + with self.assertRaises(SystemExit): + ensurepip._main(["--version"]) + result = stdout.getvalue().strip() + self.assertEqual(result, EXPECTED_VERSION_OUTPUT) + self.run_pip.assert_not_called() + + @requires_usable_pip + def test_basic_bootstrapping(self): + ensurepip._main([]) + + self.run_pip.assert_called_once_with( + [ + "install", "--no-index", "--find-links", + unittest.mock.ANY, "setuptools", "pip", + ], + unittest.mock.ANY, + ) + + additional_paths = self.run_pip.call_args[0][1] + self.assertEqual(len(additional_paths), 2) + +class TestUninstallationMainFunction(EnsurepipMixin, unittest.TestCase): + + def test_uninstall_version(self): + with test.support.captured_stdout() as stdout: + with self.assertRaises(SystemExit): + ensurepip._uninstall._main(["--version"]) + result = stdout.getvalue().strip() + self.assertEqual(result, EXPECTED_VERSION_OUTPUT) + self.run_pip.assert_not_called() + + @requires_usable_pip + def test_basic_uninstall(self): + with fake_pip(): + ensurepip._uninstall._main([]) + + self.run_pip.assert_called_once_with( + ["uninstall", "-y", "pip", "setuptools"] + ) + + + +if __name__ == "__main__": + test.support.run_unittest(__name__) diff --git a/Lib/test/test_enum.py b/Lib/test/test_enum.py new file mode 100644 index 0000000000..b8ef632a0d --- /dev/null +++ b/Lib/test/test_enum.py @@ -0,0 +1,1594 @@ +import enum +import inspect +import pydoc +import unittest +from collections import OrderedDict +from enum import Enum, IntEnum, EnumMeta, unique +from io import StringIO +from pickle import dumps, loads, PicklingError, HIGHEST_PROTOCOL + +# for pickle tests +try: + class Stooges(Enum): + LARRY = 1 + CURLY = 2 + MOE = 3 +except Exception as exc: + Stooges = exc + +try: + class IntStooges(int, Enum): + LARRY = 1 + CURLY = 2 + MOE = 3 +except Exception as exc: + IntStooges = exc + +try: + class FloatStooges(float, Enum): + LARRY = 1.39 + CURLY = 2.72 + MOE = 3.142596 +except Exception as exc: + FloatStooges = exc + +# for pickle test and subclass tests +try: + class StrEnum(str, Enum): + 'accepts only string values' + class Name(StrEnum): + BDFL = 'Guido van Rossum' + FLUFL = 'Barry Warsaw' +except Exception as exc: + Name = exc + +try: + Question = Enum('Question', 'who what when where why', module=__name__) +except Exception as exc: + Question = exc + +try: + Answer = Enum('Answer', 'him this then there because') +except Exception as exc: + Answer = exc + +try: + Theory = Enum('Theory', 'rule law supposition', qualname='spanish_inquisition') +except Exception as exc: + Theory = exc + +# for doctests +try: + class Fruit(Enum): + tomato = 1 + banana = 2 + cherry = 3 +except Exception: + pass + +def test_pickle_dump_load(assertion, source, target=None, + *, protocol=(0, HIGHEST_PROTOCOL)): + start, stop = protocol + if target is None: + target = source + for protocol in range(start, stop+1): + assertion(loads(dumps(source, protocol=protocol)), target) + +def test_pickle_exception(assertion, exception, obj, + *, protocol=(0, HIGHEST_PROTOCOL)): + start, stop = protocol + for protocol in range(start, stop+1): + with assertion(exception): + dumps(obj, protocol=protocol) + +class TestHelpers(unittest.TestCase): + # _is_descriptor, _is_sunder, _is_dunder + + def test_is_descriptor(self): + class foo: + pass + for attr in ('__get__','__set__','__delete__'): + obj = foo() + self.assertFalse(enum._is_descriptor(obj)) + setattr(obj, attr, 1) + self.assertTrue(enum._is_descriptor(obj)) + + def test_is_sunder(self): + for s in ('_a_', '_aa_'): + self.assertTrue(enum._is_sunder(s)) + + for s in ('a', 'a_', '_a', '__a', 'a__', '__a__', '_a__', '__a_', '_', + '__', '___', '____', '_____',): + self.assertFalse(enum._is_sunder(s)) + + def test_is_dunder(self): + for s in ('__a__', '__aa__'): + self.assertTrue(enum._is_dunder(s)) + for s in ('a', 'a_', '_a', '__a', 'a__', '_a_', '_a__', '__a_', '_', + '__', '___', '____', '_____',): + self.assertFalse(enum._is_dunder(s)) + + +class TestEnum(unittest.TestCase): + + def setUp(self): + class Season(Enum): + SPRING = 1 + SUMMER = 2 + AUTUMN = 3 + WINTER = 4 + self.Season = Season + + class Konstants(float, Enum): + E = 2.7182818 + PI = 3.1415926 + TAU = 2 * PI + self.Konstants = Konstants + + class Grades(IntEnum): + A = 5 + B = 4 + C = 3 + D = 2 + F = 0 + self.Grades = Grades + + class Directional(str, Enum): + EAST = 'east' + WEST = 'west' + NORTH = 'north' + SOUTH = 'south' + self.Directional = Directional + + from datetime import date + class Holiday(date, Enum): + NEW_YEAR = 2013, 1, 1 + IDES_OF_MARCH = 2013, 3, 15 + self.Holiday = Holiday + + def test_dir_on_class(self): + Season = self.Season + self.assertEqual( + set(dir(Season)), + set(['__class__', '__doc__', '__members__', '__module__', + 'SPRING', 'SUMMER', 'AUTUMN', 'WINTER']), + ) + + def test_dir_on_item(self): + Season = self.Season + self.assertEqual( + set(dir(Season.WINTER)), + set(['__class__', '__doc__', '__module__', 'name', 'value']), + ) + + def test_dir_with_added_behavior(self): + class Test(Enum): + this = 'that' + these = 'those' + def wowser(self): + return ("Wowser! I'm %s!" % self.name) + self.assertEqual( + set(dir(Test)), + set(['__class__', '__doc__', '__members__', '__module__', 'this', 'these']), + ) + self.assertEqual( + set(dir(Test.this)), + set(['__class__', '__doc__', '__module__', 'name', 'value', 'wowser']), + ) + + def test_enum_in_enum_out(self): + Season = self.Season + self.assertIs(Season(Season.WINTER), Season.WINTER) + + def test_enum_value(self): + Season = self.Season + self.assertEqual(Season.SPRING.value, 1) + + def test_intenum_value(self): + self.assertEqual(IntStooges.CURLY.value, 2) + + def test_enum(self): + Season = self.Season + lst = list(Season) + self.assertEqual(len(lst), len(Season)) + self.assertEqual(len(Season), 4, Season) + self.assertEqual( + [Season.SPRING, Season.SUMMER, Season.AUTUMN, Season.WINTER], lst) + + for i, season in enumerate('SPRING SUMMER AUTUMN WINTER'.split(), 1): + e = Season(i) + self.assertEqual(e, getattr(Season, season)) + self.assertEqual(e.value, i) + self.assertNotEqual(e, i) + self.assertEqual(e.name, season) + self.assertIn(e, Season) + self.assertIs(type(e), Season) + self.assertIsInstance(e, Season) + self.assertEqual(str(e), 'Season.' + season) + self.assertEqual( + repr(e), + '<Season.{0}: {1}>'.format(season, i), + ) + + def test_value_name(self): + Season = self.Season + self.assertEqual(Season.SPRING.name, 'SPRING') + self.assertEqual(Season.SPRING.value, 1) + with self.assertRaises(AttributeError): + Season.SPRING.name = 'invierno' + with self.assertRaises(AttributeError): + Season.SPRING.value = 2 + + def test_changing_member(self): + Season = self.Season + with self.assertRaises(AttributeError): + Season.WINTER = 'really cold' + + def test_attribute_deletion(self): + class Season(Enum): + SPRING = 1 + SUMMER = 2 + AUTUMN = 3 + WINTER = 4 + + def spam(cls): + pass + + self.assertTrue(hasattr(Season, 'spam')) + del Season.spam + self.assertFalse(hasattr(Season, 'spam')) + + with self.assertRaises(AttributeError): + del Season.SPRING + with self.assertRaises(AttributeError): + del Season.DRY + with self.assertRaises(AttributeError): + del Season.SPRING.name + + def test_invalid_names(self): + with self.assertRaises(ValueError): + class Wrong(Enum): + mro = 9 + with self.assertRaises(ValueError): + class Wrong(Enum): + _create_= 11 + with self.assertRaises(ValueError): + class Wrong(Enum): + _get_mixins_ = 9 + with self.assertRaises(ValueError): + class Wrong(Enum): + _find_new_ = 1 + with self.assertRaises(ValueError): + class Wrong(Enum): + _any_name_ = 9 + + def test_contains(self): + Season = self.Season + self.assertIn(Season.AUTUMN, Season) + self.assertNotIn(3, Season) + + val = Season(3) + self.assertIn(val, Season) + + class OtherEnum(Enum): + one = 1; two = 2 + self.assertNotIn(OtherEnum.two, Season) + + def test_comparisons(self): + Season = self.Season + with self.assertRaises(TypeError): + Season.SPRING < Season.WINTER + with self.assertRaises(TypeError): + Season.SPRING > 4 + + self.assertNotEqual(Season.SPRING, 1) + + class Part(Enum): + SPRING = 1 + CLIP = 2 + BARREL = 3 + + self.assertNotEqual(Season.SPRING, Part.SPRING) + with self.assertRaises(TypeError): + Season.SPRING < Part.CLIP + + def test_enum_duplicates(self): + class Season(Enum): + SPRING = 1 + SUMMER = 2 + AUTUMN = FALL = 3 + WINTER = 4 + ANOTHER_SPRING = 1 + lst = list(Season) + self.assertEqual( + lst, + [Season.SPRING, Season.SUMMER, + Season.AUTUMN, Season.WINTER, + ]) + self.assertIs(Season.FALL, Season.AUTUMN) + self.assertEqual(Season.FALL.value, 3) + self.assertEqual(Season.AUTUMN.value, 3) + self.assertIs(Season(3), Season.AUTUMN) + self.assertIs(Season(1), Season.SPRING) + self.assertEqual(Season.FALL.name, 'AUTUMN') + self.assertEqual( + [k for k,v in Season.__members__.items() if v.name != k], + ['FALL', 'ANOTHER_SPRING'], + ) + + def test_duplicate_name(self): + with self.assertRaises(TypeError): + class Color(Enum): + red = 1 + green = 2 + blue = 3 + red = 4 + + with self.assertRaises(TypeError): + class Color(Enum): + red = 1 + green = 2 + blue = 3 + def red(self): + return 'red' + + with self.assertRaises(TypeError): + class Color(Enum): + @property + def red(self): + return 'redder' + red = 1 + green = 2 + blue = 3 + + + def test_enum_with_value_name(self): + class Huh(Enum): + name = 1 + value = 2 + self.assertEqual( + list(Huh), + [Huh.name, Huh.value], + ) + self.assertIs(type(Huh.name), Huh) + self.assertEqual(Huh.name.name, 'name') + self.assertEqual(Huh.name.value, 1) + + def test_format_enum(self): + Season = self.Season + self.assertEqual('{}'.format(Season.SPRING), + '{}'.format(str(Season.SPRING))) + self.assertEqual( '{:}'.format(Season.SPRING), + '{:}'.format(str(Season.SPRING))) + self.assertEqual('{:20}'.format(Season.SPRING), + '{:20}'.format(str(Season.SPRING))) + self.assertEqual('{:^20}'.format(Season.SPRING), + '{:^20}'.format(str(Season.SPRING))) + self.assertEqual('{:>20}'.format(Season.SPRING), + '{:>20}'.format(str(Season.SPRING))) + self.assertEqual('{:<20}'.format(Season.SPRING), + '{:<20}'.format(str(Season.SPRING))) + + def test_format_enum_custom(self): + class TestFloat(float, Enum): + one = 1.0 + two = 2.0 + def __format__(self, spec): + return 'TestFloat success!' + self.assertEqual('{}'.format(TestFloat.one), 'TestFloat success!') + + def assertFormatIsValue(self, spec, member): + self.assertEqual(spec.format(member), spec.format(member.value)) + + def test_format_enum_date(self): + Holiday = self.Holiday + self.assertFormatIsValue('{}', Holiday.IDES_OF_MARCH) + self.assertFormatIsValue('{:}', Holiday.IDES_OF_MARCH) + self.assertFormatIsValue('{:20}', Holiday.IDES_OF_MARCH) + self.assertFormatIsValue('{:^20}', Holiday.IDES_OF_MARCH) + self.assertFormatIsValue('{:>20}', Holiday.IDES_OF_MARCH) + self.assertFormatIsValue('{:<20}', Holiday.IDES_OF_MARCH) + self.assertFormatIsValue('{:%Y %m}', Holiday.IDES_OF_MARCH) + self.assertFormatIsValue('{:%Y %m %M:00}', Holiday.IDES_OF_MARCH) + + def test_format_enum_float(self): + Konstants = self.Konstants + self.assertFormatIsValue('{}', Konstants.TAU) + self.assertFormatIsValue('{:}', Konstants.TAU) + self.assertFormatIsValue('{:20}', Konstants.TAU) + self.assertFormatIsValue('{:^20}', Konstants.TAU) + self.assertFormatIsValue('{:>20}', Konstants.TAU) + self.assertFormatIsValue('{:<20}', Konstants.TAU) + self.assertFormatIsValue('{:n}', Konstants.TAU) + self.assertFormatIsValue('{:5.2}', Konstants.TAU) + self.assertFormatIsValue('{:f}', Konstants.TAU) + + def test_format_enum_int(self): + Grades = self.Grades + self.assertFormatIsValue('{}', Grades.C) + self.assertFormatIsValue('{:}', Grades.C) + self.assertFormatIsValue('{:20}', Grades.C) + self.assertFormatIsValue('{:^20}', Grades.C) + self.assertFormatIsValue('{:>20}', Grades.C) + self.assertFormatIsValue('{:<20}', Grades.C) + self.assertFormatIsValue('{:+}', Grades.C) + self.assertFormatIsValue('{:08X}', Grades.C) + self.assertFormatIsValue('{:b}', Grades.C) + + def test_format_enum_str(self): + Directional = self.Directional + self.assertFormatIsValue('{}', Directional.WEST) + self.assertFormatIsValue('{:}', Directional.WEST) + self.assertFormatIsValue('{:20}', Directional.WEST) + self.assertFormatIsValue('{:^20}', Directional.WEST) + self.assertFormatIsValue('{:>20}', Directional.WEST) + self.assertFormatIsValue('{:<20}', Directional.WEST) + + def test_hash(self): + Season = self.Season + dates = {} + dates[Season.WINTER] = '1225' + dates[Season.SPRING] = '0315' + dates[Season.SUMMER] = '0704' + dates[Season.AUTUMN] = '1031' + self.assertEqual(dates[Season.AUTUMN], '1031') + + def test_intenum_from_scratch(self): + class phy(int, Enum): + pi = 3 + tau = 2 * pi + self.assertTrue(phy.pi < phy.tau) + + def test_intenum_inherited(self): + class IntEnum(int, Enum): + pass + class phy(IntEnum): + pi = 3 + tau = 2 * pi + self.assertTrue(phy.pi < phy.tau) + + def test_floatenum_from_scratch(self): + class phy(float, Enum): + pi = 3.1415926 + tau = 2 * pi + self.assertTrue(phy.pi < phy.tau) + + def test_floatenum_inherited(self): + class FloatEnum(float, Enum): + pass + class phy(FloatEnum): + pi = 3.1415926 + tau = 2 * pi + self.assertTrue(phy.pi < phy.tau) + + def test_strenum_from_scratch(self): + class phy(str, Enum): + pi = 'Pi' + tau = 'Tau' + self.assertTrue(phy.pi < phy.tau) + + def test_strenum_inherited(self): + class StrEnum(str, Enum): + pass + class phy(StrEnum): + pi = 'Pi' + tau = 'Tau' + self.assertTrue(phy.pi < phy.tau) + + + def test_intenum(self): + class WeekDay(IntEnum): + SUNDAY = 1 + MONDAY = 2 + TUESDAY = 3 + WEDNESDAY = 4 + THURSDAY = 5 + FRIDAY = 6 + SATURDAY = 7 + + self.assertEqual(['a', 'b', 'c'][WeekDay.MONDAY], 'c') + self.assertEqual([i for i in range(WeekDay.TUESDAY)], [0, 1, 2]) + + lst = list(WeekDay) + self.assertEqual(len(lst), len(WeekDay)) + self.assertEqual(len(WeekDay), 7) + target = 'SUNDAY MONDAY TUESDAY WEDNESDAY THURSDAY FRIDAY SATURDAY' + target = target.split() + for i, weekday in enumerate(target, 1): + e = WeekDay(i) + self.assertEqual(e, i) + self.assertEqual(int(e), i) + self.assertEqual(e.name, weekday) + self.assertIn(e, WeekDay) + self.assertEqual(lst.index(e)+1, i) + self.assertTrue(0 < e < 8) + self.assertIs(type(e), WeekDay) + self.assertIsInstance(e, int) + self.assertIsInstance(e, Enum) + + def test_intenum_duplicates(self): + class WeekDay(IntEnum): + SUNDAY = 1 + MONDAY = 2 + TUESDAY = TEUSDAY = 3 + WEDNESDAY = 4 + THURSDAY = 5 + FRIDAY = 6 + SATURDAY = 7 + self.assertIs(WeekDay.TEUSDAY, WeekDay.TUESDAY) + self.assertEqual(WeekDay(3).name, 'TUESDAY') + self.assertEqual([k for k,v in WeekDay.__members__.items() + if v.name != k], ['TEUSDAY', ]) + + def test_pickle_enum(self): + if isinstance(Stooges, Exception): + raise Stooges + test_pickle_dump_load(self.assertIs, Stooges.CURLY) + test_pickle_dump_load(self.assertIs, Stooges) + + def test_pickle_int(self): + if isinstance(IntStooges, Exception): + raise IntStooges + test_pickle_dump_load(self.assertIs, IntStooges.CURLY) + test_pickle_dump_load(self.assertIs, IntStooges) + + def test_pickle_float(self): + if isinstance(FloatStooges, Exception): + raise FloatStooges + test_pickle_dump_load(self.assertIs, FloatStooges.CURLY) + test_pickle_dump_load(self.assertIs, FloatStooges) + + def test_pickle_enum_function(self): + if isinstance(Answer, Exception): + raise Answer + test_pickle_dump_load(self.assertIs, Answer.him) + test_pickle_dump_load(self.assertIs, Answer) + + def test_pickle_enum_function_with_module(self): + if isinstance(Question, Exception): + raise Question + test_pickle_dump_load(self.assertIs, Question.who) + test_pickle_dump_load(self.assertIs, Question) + + def test_enum_function_with_qualname(self): + if isinstance(Theory, Exception): + raise Theory + self.assertEqual(Theory.__qualname__, 'spanish_inquisition') + + def test_class_nested_enum_and_pickle_protocol_four(self): + # would normally just have this directly in the class namespace + class NestedEnum(Enum): + twigs = 'common' + shiny = 'rare' + + self.__class__.NestedEnum = NestedEnum + self.NestedEnum.__qualname__ = '%s.NestedEnum' % self.__class__.__name__ + test_pickle_exception( + self.assertRaises, PicklingError, self.NestedEnum.twigs, + protocol=(0, 3)) + test_pickle_dump_load(self.assertIs, self.NestedEnum.twigs, + protocol=(4, HIGHEST_PROTOCOL)) + + def test_exploding_pickle(self): + BadPickle = Enum( + 'BadPickle', 'dill sweet bread-n-butter', module=__name__) + globals()['BadPickle'] = BadPickle + # now break BadPickle to test exception raising + enum._make_class_unpicklable(BadPickle) + test_pickle_exception(self.assertRaises, TypeError, BadPickle.dill) + test_pickle_exception(self.assertRaises, PicklingError, BadPickle) + + def test_string_enum(self): + class SkillLevel(str, Enum): + master = 'what is the sound of one hand clapping?' + journeyman = 'why did the chicken cross the road?' + apprentice = 'knock, knock!' + self.assertEqual(SkillLevel.apprentice, 'knock, knock!') + + def test_getattr_getitem(self): + class Period(Enum): + morning = 1 + noon = 2 + evening = 3 + night = 4 + self.assertIs(Period(2), Period.noon) + self.assertIs(getattr(Period, 'night'), Period.night) + self.assertIs(Period['morning'], Period.morning) + + def test_getattr_dunder(self): + Season = self.Season + self.assertTrue(getattr(Season, '__eq__')) + + def test_iteration_order(self): + class Season(Enum): + SUMMER = 2 + WINTER = 4 + AUTUMN = 3 + SPRING = 1 + self.assertEqual( + list(Season), + [Season.SUMMER, Season.WINTER, Season.AUTUMN, Season.SPRING], + ) + + def test_reversed_iteration_order(self): + self.assertEqual( + list(reversed(self.Season)), + [self.Season.WINTER, self.Season.AUTUMN, self.Season.SUMMER, + self.Season.SPRING] + ) + + def test_programatic_function_string(self): + SummerMonth = Enum('SummerMonth', 'june july august') + lst = list(SummerMonth) + self.assertEqual(len(lst), len(SummerMonth)) + self.assertEqual(len(SummerMonth), 3, SummerMonth) + self.assertEqual( + [SummerMonth.june, SummerMonth.july, SummerMonth.august], + lst, + ) + for i, month in enumerate('june july august'.split(), 1): + e = SummerMonth(i) + self.assertEqual(int(e.value), i) + self.assertNotEqual(e, i) + self.assertEqual(e.name, month) + self.assertIn(e, SummerMonth) + self.assertIs(type(e), SummerMonth) + + def test_programatic_function_string_list(self): + SummerMonth = Enum('SummerMonth', ['june', 'july', 'august']) + lst = list(SummerMonth) + self.assertEqual(len(lst), len(SummerMonth)) + self.assertEqual(len(SummerMonth), 3, SummerMonth) + self.assertEqual( + [SummerMonth.june, SummerMonth.july, SummerMonth.august], + lst, + ) + for i, month in enumerate('june july august'.split(), 1): + e = SummerMonth(i) + self.assertEqual(int(e.value), i) + self.assertNotEqual(e, i) + self.assertEqual(e.name, month) + self.assertIn(e, SummerMonth) + self.assertIs(type(e), SummerMonth) + + def test_programatic_function_iterable(self): + SummerMonth = Enum( + 'SummerMonth', + (('june', 1), ('july', 2), ('august', 3)) + ) + lst = list(SummerMonth) + self.assertEqual(len(lst), len(SummerMonth)) + self.assertEqual(len(SummerMonth), 3, SummerMonth) + self.assertEqual( + [SummerMonth.june, SummerMonth.july, SummerMonth.august], + lst, + ) + for i, month in enumerate('june july august'.split(), 1): + e = SummerMonth(i) + self.assertEqual(int(e.value), i) + self.assertNotEqual(e, i) + self.assertEqual(e.name, month) + self.assertIn(e, SummerMonth) + self.assertIs(type(e), SummerMonth) + + def test_programatic_function_from_dict(self): + SummerMonth = Enum( + 'SummerMonth', + OrderedDict((('june', 1), ('july', 2), ('august', 3))) + ) + lst = list(SummerMonth) + self.assertEqual(len(lst), len(SummerMonth)) + self.assertEqual(len(SummerMonth), 3, SummerMonth) + self.assertEqual( + [SummerMonth.june, SummerMonth.july, SummerMonth.august], + lst, + ) + for i, month in enumerate('june july august'.split(), 1): + e = SummerMonth(i) + self.assertEqual(int(e.value), i) + self.assertNotEqual(e, i) + self.assertEqual(e.name, month) + self.assertIn(e, SummerMonth) + self.assertIs(type(e), SummerMonth) + + def test_programatic_function_type(self): + SummerMonth = Enum('SummerMonth', 'june july august', type=int) + lst = list(SummerMonth) + self.assertEqual(len(lst), len(SummerMonth)) + self.assertEqual(len(SummerMonth), 3, SummerMonth) + self.assertEqual( + [SummerMonth.june, SummerMonth.july, SummerMonth.august], + lst, + ) + for i, month in enumerate('june july august'.split(), 1): + e = SummerMonth(i) + self.assertEqual(e, i) + self.assertEqual(e.name, month) + self.assertIn(e, SummerMonth) + self.assertIs(type(e), SummerMonth) + + def test_programatic_function_type_from_subclass(self): + SummerMonth = IntEnum('SummerMonth', 'june july august') + lst = list(SummerMonth) + self.assertEqual(len(lst), len(SummerMonth)) + self.assertEqual(len(SummerMonth), 3, SummerMonth) + self.assertEqual( + [SummerMonth.june, SummerMonth.july, SummerMonth.august], + lst, + ) + for i, month in enumerate('june july august'.split(), 1): + e = SummerMonth(i) + self.assertEqual(e, i) + self.assertEqual(e.name, month) + self.assertIn(e, SummerMonth) + self.assertIs(type(e), SummerMonth) + + def test_subclassing(self): + if isinstance(Name, Exception): + raise Name + self.assertEqual(Name.BDFL, 'Guido van Rossum') + self.assertTrue(Name.BDFL, Name('Guido van Rossum')) + self.assertIs(Name.BDFL, getattr(Name, 'BDFL')) + test_pickle_dump_load(self.assertIs, Name.BDFL) + + def test_extending(self): + class Color(Enum): + red = 1 + green = 2 + blue = 3 + with self.assertRaises(TypeError): + class MoreColor(Color): + cyan = 4 + magenta = 5 + yellow = 6 + + def test_exclude_methods(self): + class whatever(Enum): + this = 'that' + these = 'those' + def really(self): + return 'no, not %s' % self.value + self.assertIsNot(type(whatever.really), whatever) + self.assertEqual(whatever.this.really(), 'no, not that') + + def test_wrong_inheritance_order(self): + with self.assertRaises(TypeError): + class Wrong(Enum, str): + NotHere = 'error before this point' + + def test_intenum_transitivity(self): + class number(IntEnum): + one = 1 + two = 2 + three = 3 + class numero(IntEnum): + uno = 1 + dos = 2 + tres = 3 + self.assertEqual(number.one, numero.uno) + self.assertEqual(number.two, numero.dos) + self.assertEqual(number.three, numero.tres) + + def test_wrong_enum_in_call(self): + class Monochrome(Enum): + black = 0 + white = 1 + class Gender(Enum): + male = 0 + female = 1 + self.assertRaises(ValueError, Monochrome, Gender.male) + + def test_wrong_enum_in_mixed_call(self): + class Monochrome(IntEnum): + black = 0 + white = 1 + class Gender(Enum): + male = 0 + female = 1 + self.assertRaises(ValueError, Monochrome, Gender.male) + + def test_mixed_enum_in_call_1(self): + class Monochrome(IntEnum): + black = 0 + white = 1 + class Gender(IntEnum): + male = 0 + female = 1 + self.assertIs(Monochrome(Gender.female), Monochrome.white) + + def test_mixed_enum_in_call_2(self): + class Monochrome(Enum): + black = 0 + white = 1 + class Gender(IntEnum): + male = 0 + female = 1 + self.assertIs(Monochrome(Gender.male), Monochrome.black) + + def test_flufl_enum(self): + class Fluflnum(Enum): + def __int__(self): + return int(self.value) + class MailManOptions(Fluflnum): + option1 = 1 + option2 = 2 + option3 = 3 + self.assertEqual(int(MailManOptions.option1), 1) + + def test_introspection(self): + class Number(IntEnum): + one = 100 + two = 200 + self.assertIs(Number.one._member_type_, int) + self.assertIs(Number._member_type_, int) + class String(str, Enum): + yarn = 'soft' + rope = 'rough' + wire = 'hard' + self.assertIs(String.yarn._member_type_, str) + self.assertIs(String._member_type_, str) + class Plain(Enum): + vanilla = 'white' + one = 1 + self.assertIs(Plain.vanilla._member_type_, object) + self.assertIs(Plain._member_type_, object) + + def test_no_such_enum_member(self): + class Color(Enum): + red = 1 + green = 2 + blue = 3 + with self.assertRaises(ValueError): + Color(4) + with self.assertRaises(KeyError): + Color['chartreuse'] + + def test_new_repr(self): + class Color(Enum): + red = 1 + green = 2 + blue = 3 + def __repr__(self): + return "don't you just love shades of %s?" % self.name + self.assertEqual( + repr(Color.blue), + "don't you just love shades of blue?", + ) + + def test_inherited_repr(self): + class MyEnum(Enum): + def __repr__(self): + return "My name is %s." % self.name + class MyIntEnum(int, MyEnum): + this = 1 + that = 2 + theother = 3 + self.assertEqual(repr(MyIntEnum.that), "My name is that.") + + def test_multiple_mixin_mro(self): + class auto_enum(type(Enum)): + def __new__(metacls, cls, bases, classdict): + temp = type(classdict)() + names = set(classdict._member_names) + i = 0 + for k in classdict._member_names: + v = classdict[k] + if v is Ellipsis: + v = i + else: + i = v + i += 1 + temp[k] = v + for k, v in classdict.items(): + if k not in names: + temp[k] = v + return super(auto_enum, metacls).__new__( + metacls, cls, bases, temp) + + class AutoNumberedEnum(Enum, metaclass=auto_enum): + pass + + class AutoIntEnum(IntEnum, metaclass=auto_enum): + pass + + class TestAutoNumber(AutoNumberedEnum): + a = ... + b = 3 + c = ... + + class TestAutoInt(AutoIntEnum): + a = ... + b = 3 + c = ... + + def test_subclasses_with_getnewargs(self): + class NamedInt(int): + __qualname__ = 'NamedInt' # needed for pickle protocol 4 + def __new__(cls, *args): + _args = args + name, *args = args + if len(args) == 0: + raise TypeError("name and value must be specified") + self = int.__new__(cls, *args) + self._intname = name + self._args = _args + return self + def __getnewargs__(self): + return self._args + @property + def __name__(self): + return self._intname + def __repr__(self): + # repr() is updated to include the name and type info + return "{}({!r}, {})".format(type(self).__name__, + self.__name__, + int.__repr__(self)) + def __str__(self): + # str() is unchanged, even if it relies on the repr() fallback + base = int + base_str = base.__str__ + if base_str.__objclass__ is object: + return base.__repr__(self) + return base_str(self) + # for simplicity, we only define one operator that + # propagates expressions + def __add__(self, other): + temp = int(self) + int( other) + if isinstance(self, NamedInt) and isinstance(other, NamedInt): + return NamedInt( + '({0} + {1})'.format(self.__name__, other.__name__), + temp ) + else: + return temp + + class NEI(NamedInt, Enum): + __qualname__ = 'NEI' # needed for pickle protocol 4 + x = ('the-x', 1) + y = ('the-y', 2) + + + self.assertIs(NEI.__new__, Enum.__new__) + self.assertEqual(repr(NEI.x + NEI.y), "NamedInt('(the-x + the-y)', 3)") + globals()['NamedInt'] = NamedInt + globals()['NEI'] = NEI + NI5 = NamedInt('test', 5) + self.assertEqual(NI5, 5) + test_pickle_dump_load(self.assertEqual, NI5, 5) + self.assertEqual(NEI.y.value, 2) + test_pickle_dump_load(self.assertIs, NEI.y) + test_pickle_dump_load(self.assertIs, NEI) + + def test_subclasses_with_getnewargs_ex(self): + class NamedInt(int): + __qualname__ = 'NamedInt' # needed for pickle protocol 4 + def __new__(cls, *args): + _args = args + name, *args = args + if len(args) == 0: + raise TypeError("name and value must be specified") + self = int.__new__(cls, *args) + self._intname = name + self._args = _args + return self + def __getnewargs_ex__(self): + return self._args, {} + @property + def __name__(self): + return self._intname + def __repr__(self): + # repr() is updated to include the name and type info + return "{}({!r}, {})".format(type(self).__name__, + self.__name__, + int.__repr__(self)) + def __str__(self): + # str() is unchanged, even if it relies on the repr() fallback + base = int + base_str = base.__str__ + if base_str.__objclass__ is object: + return base.__repr__(self) + return base_str(self) + # for simplicity, we only define one operator that + # propagates expressions + def __add__(self, other): + temp = int(self) + int( other) + if isinstance(self, NamedInt) and isinstance(other, NamedInt): + return NamedInt( + '({0} + {1})'.format(self.__name__, other.__name__), + temp ) + else: + return temp + + class NEI(NamedInt, Enum): + __qualname__ = 'NEI' # needed for pickle protocol 4 + x = ('the-x', 1) + y = ('the-y', 2) + + + self.assertIs(NEI.__new__, Enum.__new__) + self.assertEqual(repr(NEI.x + NEI.y), "NamedInt('(the-x + the-y)', 3)") + globals()['NamedInt'] = NamedInt + globals()['NEI'] = NEI + NI5 = NamedInt('test', 5) + self.assertEqual(NI5, 5) + test_pickle_dump_load(self.assertEqual, NI5, 5, protocol=(4, 4)) + self.assertEqual(NEI.y.value, 2) + test_pickle_dump_load(self.assertIs, NEI.y, protocol=(4, 4)) + test_pickle_dump_load(self.assertIs, NEI) + + def test_subclasses_with_reduce(self): + class NamedInt(int): + __qualname__ = 'NamedInt' # needed for pickle protocol 4 + def __new__(cls, *args): + _args = args + name, *args = args + if len(args) == 0: + raise TypeError("name and value must be specified") + self = int.__new__(cls, *args) + self._intname = name + self._args = _args + return self + def __reduce__(self): + return self.__class__, self._args + @property + def __name__(self): + return self._intname + def __repr__(self): + # repr() is updated to include the name and type info + return "{}({!r}, {})".format(type(self).__name__, + self.__name__, + int.__repr__(self)) + def __str__(self): + # str() is unchanged, even if it relies on the repr() fallback + base = int + base_str = base.__str__ + if base_str.__objclass__ is object: + return base.__repr__(self) + return base_str(self) + # for simplicity, we only define one operator that + # propagates expressions + def __add__(self, other): + temp = int(self) + int( other) + if isinstance(self, NamedInt) and isinstance(other, NamedInt): + return NamedInt( + '({0} + {1})'.format(self.__name__, other.__name__), + temp ) + else: + return temp + + class NEI(NamedInt, Enum): + __qualname__ = 'NEI' # needed for pickle protocol 4 + x = ('the-x', 1) + y = ('the-y', 2) + + + self.assertIs(NEI.__new__, Enum.__new__) + self.assertEqual(repr(NEI.x + NEI.y), "NamedInt('(the-x + the-y)', 3)") + globals()['NamedInt'] = NamedInt + globals()['NEI'] = NEI + NI5 = NamedInt('test', 5) + self.assertEqual(NI5, 5) + test_pickle_dump_load(self.assertEqual, NI5, 5) + self.assertEqual(NEI.y.value, 2) + test_pickle_dump_load(self.assertIs, NEI.y) + test_pickle_dump_load(self.assertIs, NEI) + + def test_subclasses_with_reduce_ex(self): + class NamedInt(int): + __qualname__ = 'NamedInt' # needed for pickle protocol 4 + def __new__(cls, *args): + _args = args + name, *args = args + if len(args) == 0: + raise TypeError("name and value must be specified") + self = int.__new__(cls, *args) + self._intname = name + self._args = _args + return self + def __reduce_ex__(self, proto): + return self.__class__, self._args + @property + def __name__(self): + return self._intname + def __repr__(self): + # repr() is updated to include the name and type info + return "{}({!r}, {})".format(type(self).__name__, + self.__name__, + int.__repr__(self)) + def __str__(self): + # str() is unchanged, even if it relies on the repr() fallback + base = int + base_str = base.__str__ + if base_str.__objclass__ is object: + return base.__repr__(self) + return base_str(self) + # for simplicity, we only define one operator that + # propagates expressions + def __add__(self, other): + temp = int(self) + int( other) + if isinstance(self, NamedInt) and isinstance(other, NamedInt): + return NamedInt( + '({0} + {1})'.format(self.__name__, other.__name__), + temp ) + else: + return temp + + class NEI(NamedInt, Enum): + __qualname__ = 'NEI' # needed for pickle protocol 4 + x = ('the-x', 1) + y = ('the-y', 2) + + + self.assertIs(NEI.__new__, Enum.__new__) + self.assertEqual(repr(NEI.x + NEI.y), "NamedInt('(the-x + the-y)', 3)") + globals()['NamedInt'] = NamedInt + globals()['NEI'] = NEI + NI5 = NamedInt('test', 5) + self.assertEqual(NI5, 5) + test_pickle_dump_load(self.assertEqual, NI5, 5) + self.assertEqual(NEI.y.value, 2) + test_pickle_dump_load(self.assertIs, NEI.y) + test_pickle_dump_load(self.assertIs, NEI) + + def test_subclasses_without_direct_pickle_support(self): + class NamedInt(int): + __qualname__ = 'NamedInt' + def __new__(cls, *args): + _args = args + name, *args = args + if len(args) == 0: + raise TypeError("name and value must be specified") + self = int.__new__(cls, *args) + self._intname = name + self._args = _args + return self + @property + def __name__(self): + return self._intname + def __repr__(self): + # repr() is updated to include the name and type info + return "{}({!r}, {})".format(type(self).__name__, + self.__name__, + int.__repr__(self)) + def __str__(self): + # str() is unchanged, even if it relies on the repr() fallback + base = int + base_str = base.__str__ + if base_str.__objclass__ is object: + return base.__repr__(self) + return base_str(self) + # for simplicity, we only define one operator that + # propagates expressions + def __add__(self, other): + temp = int(self) + int( other) + if isinstance(self, NamedInt) and isinstance(other, NamedInt): + return NamedInt( + '({0} + {1})'.format(self.__name__, other.__name__), + temp ) + else: + return temp + + class NEI(NamedInt, Enum): + __qualname__ = 'NEI' + x = ('the-x', 1) + y = ('the-y', 2) + + self.assertIs(NEI.__new__, Enum.__new__) + self.assertEqual(repr(NEI.x + NEI.y), "NamedInt('(the-x + the-y)', 3)") + globals()['NamedInt'] = NamedInt + globals()['NEI'] = NEI + NI5 = NamedInt('test', 5) + self.assertEqual(NI5, 5) + self.assertEqual(NEI.y.value, 2) + test_pickle_exception(self.assertRaises, TypeError, NEI.x) + test_pickle_exception(self.assertRaises, PicklingError, NEI) + + def test_subclasses_without_direct_pickle_support_using_name(self): + class NamedInt(int): + __qualname__ = 'NamedInt' + def __new__(cls, *args): + _args = args + name, *args = args + if len(args) == 0: + raise TypeError("name and value must be specified") + self = int.__new__(cls, *args) + self._intname = name + self._args = _args + return self + @property + def __name__(self): + return self._intname + def __repr__(self): + # repr() is updated to include the name and type info + return "{}({!r}, {})".format(type(self).__name__, + self.__name__, + int.__repr__(self)) + def __str__(self): + # str() is unchanged, even if it relies on the repr() fallback + base = int + base_str = base.__str__ + if base_str.__objclass__ is object: + return base.__repr__(self) + return base_str(self) + # for simplicity, we only define one operator that + # propagates expressions + def __add__(self, other): + temp = int(self) + int( other) + if isinstance(self, NamedInt) and isinstance(other, NamedInt): + return NamedInt( + '({0} + {1})'.format(self.__name__, other.__name__), + temp ) + else: + return temp + + class NEI(NamedInt, Enum): + __qualname__ = 'NEI' + x = ('the-x', 1) + y = ('the-y', 2) + def __reduce_ex__(self, proto): + return getattr, (self.__class__, self._name_) + + self.assertIs(NEI.__new__, Enum.__new__) + self.assertEqual(repr(NEI.x + NEI.y), "NamedInt('(the-x + the-y)', 3)") + globals()['NamedInt'] = NamedInt + globals()['NEI'] = NEI + NI5 = NamedInt('test', 5) + self.assertEqual(NI5, 5) + self.assertEqual(NEI.y.value, 2) + test_pickle_dump_load(self.assertIs, NEI.y) + test_pickle_dump_load(self.assertIs, NEI) + + def test_tuple_subclass(self): + class SomeTuple(tuple, Enum): + __qualname__ = 'SomeTuple' # needed for pickle protocol 4 + first = (1, 'for the money') + second = (2, 'for the show') + third = (3, 'for the music') + self.assertIs(type(SomeTuple.first), SomeTuple) + self.assertIsInstance(SomeTuple.second, tuple) + self.assertEqual(SomeTuple.third, (3, 'for the music')) + globals()['SomeTuple'] = SomeTuple + test_pickle_dump_load(self.assertIs, SomeTuple.first) + + def test_duplicate_values_give_unique_enum_items(self): + class AutoNumber(Enum): + first = () + second = () + third = () + def __new__(cls): + value = len(cls.__members__) + 1 + obj = object.__new__(cls) + obj._value_ = value + return obj + def __int__(self): + return int(self._value_) + self.assertEqual( + list(AutoNumber), + [AutoNumber.first, AutoNumber.second, AutoNumber.third], + ) + self.assertEqual(int(AutoNumber.second), 2) + self.assertEqual(AutoNumber.third.value, 3) + self.assertIs(AutoNumber(1), AutoNumber.first) + + def test_inherited_new_from_enhanced_enum(self): + class AutoNumber(Enum): + def __new__(cls): + value = len(cls.__members__) + 1 + obj = object.__new__(cls) + obj._value_ = value + return obj + def __int__(self): + return int(self._value_) + class Color(AutoNumber): + red = () + green = () + blue = () + self.assertEqual(list(Color), [Color.red, Color.green, Color.blue]) + self.assertEqual(list(map(int, Color)), [1, 2, 3]) + + def test_inherited_new_from_mixed_enum(self): + class AutoNumber(IntEnum): + def __new__(cls): + value = len(cls.__members__) + 1 + obj = int.__new__(cls, value) + obj._value_ = value + return obj + class Color(AutoNumber): + red = () + green = () + blue = () + self.assertEqual(list(Color), [Color.red, Color.green, Color.blue]) + self.assertEqual(list(map(int, Color)), [1, 2, 3]) + + def test_equality(self): + class AlwaysEqual: + def __eq__(self, other): + return True + class OrdinaryEnum(Enum): + a = 1 + self.assertEqual(AlwaysEqual(), OrdinaryEnum.a) + self.assertEqual(OrdinaryEnum.a, AlwaysEqual()) + + def test_ordered_mixin(self): + class OrderedEnum(Enum): + def __ge__(self, other): + if self.__class__ is other.__class__: + return self._value_ >= other._value_ + return NotImplemented + def __gt__(self, other): + if self.__class__ is other.__class__: + return self._value_ > other._value_ + return NotImplemented + def __le__(self, other): + if self.__class__ is other.__class__: + return self._value_ <= other._value_ + return NotImplemented + def __lt__(self, other): + if self.__class__ is other.__class__: + return self._value_ < other._value_ + return NotImplemented + class Grade(OrderedEnum): + A = 5 + B = 4 + C = 3 + D = 2 + F = 1 + self.assertGreater(Grade.A, Grade.B) + self.assertLessEqual(Grade.F, Grade.C) + self.assertLess(Grade.D, Grade.A) + self.assertGreaterEqual(Grade.B, Grade.B) + self.assertEqual(Grade.B, Grade.B) + self.assertNotEqual(Grade.C, Grade.D) + + def test_extending2(self): + class Shade(Enum): + def shade(self): + print(self.name) + class Color(Shade): + red = 1 + green = 2 + blue = 3 + with self.assertRaises(TypeError): + class MoreColor(Color): + cyan = 4 + magenta = 5 + yellow = 6 + + def test_extending3(self): + class Shade(Enum): + def shade(self): + return self.name + class Color(Shade): + def hex(self): + return '%s hexlified!' % self.value + class MoreColor(Color): + cyan = 4 + magenta = 5 + yellow = 6 + self.assertEqual(MoreColor.magenta.hex(), '5 hexlified!') + + + def test_no_duplicates(self): + class UniqueEnum(Enum): + def __init__(self, *args): + cls = self.__class__ + if any(self.value == e.value for e in cls): + a = self.name + e = cls(self.value).name + raise ValueError( + "aliases not allowed in UniqueEnum: %r --> %r" + % (a, e) + ) + class Color(UniqueEnum): + red = 1 + green = 2 + blue = 3 + with self.assertRaises(ValueError): + class Color(UniqueEnum): + red = 1 + green = 2 + blue = 3 + grene = 2 + + def test_init(self): + class Planet(Enum): + MERCURY = (3.303e+23, 2.4397e6) + VENUS = (4.869e+24, 6.0518e6) + EARTH = (5.976e+24, 6.37814e6) + MARS = (6.421e+23, 3.3972e6) + JUPITER = (1.9e+27, 7.1492e7) + SATURN = (5.688e+26, 6.0268e7) + URANUS = (8.686e+25, 2.5559e7) + NEPTUNE = (1.024e+26, 2.4746e7) + def __init__(self, mass, radius): + self.mass = mass # in kilograms + self.radius = radius # in meters + @property + def surface_gravity(self): + # universal gravitational constant (m3 kg-1 s-2) + G = 6.67300E-11 + return G * self.mass / (self.radius * self.radius) + self.assertEqual(round(Planet.EARTH.surface_gravity, 2), 9.80) + self.assertEqual(Planet.EARTH.value, (5.976e+24, 6.37814e6)) + + def test_nonhash_value(self): + class AutoNumberInAList(Enum): + def __new__(cls): + value = [len(cls.__members__) + 1] + obj = object.__new__(cls) + obj._value_ = value + return obj + class ColorInAList(AutoNumberInAList): + red = () + green = () + blue = () + self.assertEqual(list(ColorInAList), [ColorInAList.red, ColorInAList.green, ColorInAList.blue]) + for enum, value in zip(ColorInAList, range(3)): + value += 1 + self.assertEqual(enum.value, [value]) + self.assertIs(ColorInAList([value]), enum) + + def test_conflicting_types_resolved_in_new(self): + class LabelledIntEnum(int, Enum): + def __new__(cls, *args): + value, label = args + obj = int.__new__(cls, value) + obj.label = label + obj._value_ = value + return obj + + class LabelledList(LabelledIntEnum): + unprocessed = (1, "Unprocessed") + payment_complete = (2, "Payment Complete") + + self.assertEqual(list(LabelledList), [LabelledList.unprocessed, LabelledList.payment_complete]) + self.assertEqual(LabelledList.unprocessed, 1) + self.assertEqual(LabelledList(1), LabelledList.unprocessed) + + +class TestUnique(unittest.TestCase): + + def test_unique_clean(self): + @unique + class Clean(Enum): + one = 1 + two = 'dos' + tres = 4.0 + @unique + class Cleaner(IntEnum): + single = 1 + double = 2 + triple = 3 + + def test_unique_dirty(self): + with self.assertRaisesRegex(ValueError, 'tres.*one'): + @unique + class Dirty(Enum): + one = 1 + two = 'dos' + tres = 1 + with self.assertRaisesRegex( + ValueError, + 'double.*single.*turkey.*triple', + ): + @unique + class Dirtier(IntEnum): + single = 1 + double = 1 + triple = 3 + turkey = 3 + + +expected_help_output = """ +Help on class Color in module %s: + +class Color(enum.Enum) + | Method resolution order: + | Color + | enum.Enum + | builtins.object + |\x20\x20 + | Data and other attributes defined here: + |\x20\x20 + | blue = <Color.blue: 3> + |\x20\x20 + | green = <Color.green: 2> + |\x20\x20 + | red = <Color.red: 1> + |\x20\x20 + | ---------------------------------------------------------------------- + | Data descriptors inherited from enum.Enum: + |\x20\x20 + | name + | The name of the Enum member. + |\x20\x20 + | value + | The value of the Enum member. + |\x20\x20 + | ---------------------------------------------------------------------- + | Data descriptors inherited from enum.EnumMeta: + |\x20\x20 + | __members__ + | Returns a mapping of member name->value. + |\x20\x20\x20\x20\x20\x20 + | This mapping lists all enum members, including aliases. Note that this + | is a read-only view of the internal mapping. +""".strip() + +class TestStdLib(unittest.TestCase): + + class Color(Enum): + red = 1 + green = 2 + blue = 3 + + def test_pydoc(self): + # indirectly test __objclass__ + expected_text = expected_help_output % __name__ + output = StringIO() + helper = pydoc.Helper(output=output) + helper(self.Color) + result = output.getvalue().strip() + if result != expected_text: + print_diffs(expected_text, result) + self.fail("outputs are not equal, see diff above") + + def test_inspect_getmembers(self): + values = dict(( + ('__class__', EnumMeta), + ('__doc__', None), + ('__members__', self.Color.__members__), + ('__module__', __name__), + ('blue', self.Color.blue), + ('green', self.Color.green), + ('name', Enum.__dict__['name']), + ('red', self.Color.red), + ('value', Enum.__dict__['value']), + )) + result = dict(inspect.getmembers(self.Color)) + self.assertEqual(values.keys(), result.keys()) + failed = False + for k in values.keys(): + if result[k] != values[k]: + print() + print('\n%s\n key: %s\n result: %s\nexpected: %s\n%s\n' % + ('=' * 75, k, result[k], values[k], '=' * 75), sep='') + failed = True + if failed: + self.fail("result does not equal expected, see print above") + + def test_inspect_classify_class_attrs(self): + # indirectly test __objclass__ + from inspect import Attribute + values = [ + Attribute(name='__class__', kind='data', + defining_class=object, object=EnumMeta), + Attribute(name='__doc__', kind='data', + defining_class=self.Color, object=None), + Attribute(name='__members__', kind='property', + defining_class=EnumMeta, object=EnumMeta.__members__), + Attribute(name='__module__', kind='data', + defining_class=self.Color, object=__name__), + Attribute(name='blue', kind='data', + defining_class=self.Color, object=self.Color.blue), + Attribute(name='green', kind='data', + defining_class=self.Color, object=self.Color.green), + Attribute(name='red', kind='data', + defining_class=self.Color, object=self.Color.red), + Attribute(name='name', kind='data', + defining_class=Enum, object=Enum.__dict__['name']), + Attribute(name='value', kind='data', + defining_class=Enum, object=Enum.__dict__['value']), + ] + values.sort(key=lambda item: item.name) + result = list(inspect.classify_class_attrs(self.Color)) + result.sort(key=lambda item: item.name) + failed = False + for v, r in zip(values, result): + if r != v: + print('\n%s\n%s\n%s\n%s\n' % ('=' * 75, r, v, '=' * 75), sep='') + failed = True + if failed: + self.fail("result does not equal expected, see print above") + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_enumerate.py b/Lib/test/test_enumerate.py index 4af217b733..8742afcea8 100644 --- a/Lib/test/test_enumerate.py +++ b/Lib/test/test_enumerate.py @@ -1,4 +1,5 @@ import unittest +import operator import sys import pickle @@ -168,15 +169,12 @@ class TestReversed(unittest.TestCase, PickleTest): x = range(1) self.assertEqual(type(reversed(x)), type(iter(x))) - @support.cpython_only def test_len(self): - # This is an implementation detail, not an interface requirement - from test.test_iterlen import len for s in ('hello', tuple('hello'), list('hello'), range(5)): - self.assertEqual(len(reversed(s)), len(s)) + self.assertEqual(operator.length_hint(reversed(s)), len(s)) r = reversed(s) list(r) - self.assertEqual(len(r), 0) + self.assertEqual(operator.length_hint(r), 0) class SeqWithWeirdLen: called = False def __len__(self): @@ -187,7 +185,7 @@ class TestReversed(unittest.TestCase, PickleTest): def __getitem__(self, index): return index r = reversed(SeqWithWeirdLen()) - self.assertRaises(ZeroDivisionError, len, r) + self.assertRaises(ZeroDivisionError, operator.length_hint, r) def test_gc(self): diff --git a/Lib/test/test_epoll.py b/Lib/test/test_epoll.py index 871efb2d31..0c88cc424c 100644 --- a/Lib/test/test_epoll.py +++ b/Lib/test/test_epoll.py @@ -21,10 +21,11 @@ """ Tests for epoll wrapper. """ -import socket import errno -import time +import os import select +import socket +import time import unittest from test import support @@ -33,7 +34,7 @@ if not hasattr(select, "epoll"): try: select.epoll() -except IOError as e: +except OSError as e: if e.errno == errno.ENOSYS: raise unittest.SkipTest("kernel doesn't support epoll()") raise @@ -55,7 +56,7 @@ class TestEPoll(unittest.TestCase): client.setblocking(False) try: client.connect(('127.0.0.1', self.serverSocket.getsockname()[1])) - except socket.error as e: + except OSError as e: self.assertEqual(e.args[0], errno.EINPROGRESS) else: raise AssertionError("Connect should have raised EINPROGRESS") @@ -86,6 +87,13 @@ class TestEPoll(unittest.TestCase): self.assertRaises(TypeError, select.epoll, ['foo']) self.assertRaises(TypeError, select.epoll, {}) + def test_context_manager(self): + with select.epoll(16) as ep: + self.assertGreater(ep.fileno(), 0) + self.assertFalse(ep.closed) + self.assertTrue(ep.closed) + self.assertRaises(ValueError, ep.fileno) + def test_add(self): server, client = self._connected_pair() @@ -114,12 +122,12 @@ class TestEPoll(unittest.TestCase): # ValueError: file descriptor cannot be a negative integer (-1) self.assertRaises(ValueError, ep.register, -1, select.EPOLLIN | select.EPOLLOUT) - # IOError: [Errno 9] Bad file descriptor - self.assertRaises(IOError, ep.register, 10000, + # OSError: [Errno 9] Bad file descriptor + self.assertRaises(OSError, ep.register, 10000, select.EPOLLIN | select.EPOLLOUT) # registering twice also raises an exception ep.register(server, select.EPOLLIN | select.EPOLLOUT) - self.assertRaises(IOError, ep.register, server, + self.assertRaises(OSError, ep.register, server, select.EPOLLIN | select.EPOLLOUT) finally: ep.close() @@ -141,7 +149,7 @@ class TestEPoll(unittest.TestCase): ep.close() try: ep2.poll(1, 4) - except IOError as e: + except OSError as e: self.assertEqual(e.args[0], errno.EBADF, e) else: self.fail("epoll on closed fd didn't raise EBADF") @@ -217,6 +225,35 @@ class TestEPoll(unittest.TestCase): server.close() ep.unregister(fd) + def test_close(self): + open_file = open(__file__, "rb") + self.addCleanup(open_file.close) + fd = open_file.fileno() + epoll = select.epoll() + + # test fileno() method and closed attribute + self.assertIsInstance(epoll.fileno(), int) + self.assertFalse(epoll.closed) + + # test close() + epoll.close() + self.assertTrue(epoll.closed) + self.assertRaises(ValueError, epoll.fileno) + + # close() can be called more than once + epoll.close() + + # operations must fail with ValueError("I/O operation on closed ...") + self.assertRaises(ValueError, epoll.modify, fd, select.EPOLLIN) + self.assertRaises(ValueError, epoll.poll, 1.0) + self.assertRaises(ValueError, epoll.register, fd, select.EPOLLIN) + self.assertRaises(ValueError, epoll.unregister, fd) + + def test_fd_non_inheritable(self): + epoll = select.epoll() + self.addCleanup(epoll.close) + self.assertEqual(os.get_inheritable(epoll.fileno()), False) + def test_main(): support.run_unittest(TestEPoll) diff --git a/Lib/test/test_exceptions.py b/Lib/test/test_exceptions.py index 8d11d90abb..62a2e324b1 100644 --- a/Lib/test/test_exceptions.py +++ b/Lib/test/test_exceptions.py @@ -8,8 +8,8 @@ import weakref import errno from test.support import (TESTFN, captured_output, check_impl_detail, - check_warnings, cpython_only, gc_collect, - no_tracing, run_unittest, unlink) + check_warnings, cpython_only, gc_collect, run_unittest, + no_tracing, unlink) class NaiveException(Exception): def __init__(self, x): @@ -257,22 +257,22 @@ class ExceptionTests(unittest.TestCase): {'args' : ('foo', 1)}), (SystemExit, ('foo',), {'args' : ('foo',), 'code' : 'foo'}), - (IOError, ('foo',), + (OSError, ('foo',), {'args' : ('foo',), 'filename' : None, 'errno' : None, 'strerror' : None}), - (IOError, ('foo', 'bar'), + (OSError, ('foo', 'bar'), {'args' : ('foo', 'bar'), 'filename' : None, 'errno' : 'foo', 'strerror' : 'bar'}), - (IOError, ('foo', 'bar', 'baz'), + (OSError, ('foo', 'bar', 'baz'), {'args' : ('foo', 'bar'), 'filename' : 'baz', 'errno' : 'foo', 'strerror' : 'bar'}), - (IOError, ('foo', 'bar', 'baz', 'quux'), - {'args' : ('foo', 'bar', 'baz', 'quux')}), - (EnvironmentError, ('errnoStr', 'strErrorStr', 'filenameStr'), + (OSError, ('foo', 'bar', 'baz', None, 'quux'), + {'args' : ('foo', 'bar'), 'filename' : 'baz', 'filename2': 'quux'}), + (OSError, ('errnoStr', 'strErrorStr', 'filenameStr'), {'args' : ('errnoStr', 'strErrorStr'), 'strerror' : 'strErrorStr', 'errno' : 'errnoStr', 'filename' : 'filenameStr'}), - (EnvironmentError, (1, 'strErrorStr', 'filenameStr'), + (OSError, (1, 'strErrorStr', 'filenameStr'), {'args' : (1, 'strErrorStr'), 'errno' : 1, 'strerror' : 'strErrorStr', 'filename' : 'filenameStr'}), (SyntaxError, (), {'msg' : None, 'text' : None, @@ -422,7 +422,7 @@ class ExceptionTests(unittest.TestCase): self.assertIsNone(e.__context__) self.assertIsNone(e.__cause__) - class MyException(EnvironmentError): + class MyException(OSError): pass e = MyException() @@ -968,8 +968,5 @@ class ImportErrorTests(unittest.TestCase): self.assertEqual(str(arg), str(exc)) -def test_main(): - run_unittest(ExceptionTests, ImportErrorTests) - if __name__ == '__main__': unittest.main() diff --git a/Lib/test/test_faulthandler.py b/Lib/test/test_faulthandler.py index 770e70c77d..ebd99edc8f 100644 --- a/Lib/test/test_faulthandler.py +++ b/Lib/test/test_faulthandler.py @@ -19,18 +19,6 @@ except ImportError: TIMEOUT = 0.5 -try: - from resource import setrlimit, RLIMIT_CORE, error as resource_error -except ImportError: - prepare_subprocess = None -else: - def prepare_subprocess(): - # don't create core file - try: - setrlimit(RLIMIT_CORE, (0, 0)) - except (ValueError, resource_error): - pass - def expected_traceback(lineno1, lineno2, header, min_count=1): regex = header regex += ' File "<string>", line %s in func\n' % lineno1 @@ -59,10 +47,8 @@ class FaultHandlerTests(unittest.TestCase): build, and replace "Current thread 0x00007f8d8fbd9700" by "Current thread XXX". """ - options = {} - if prepare_subprocess: - options['preexec_fn'] = prepare_subprocess - process = script_helper.spawn_python('-c', code, **options) + with support.SuppressCrashReport(): + process = script_helper.spawn_python('-c', code) stdout, stderr = process.communicate() exitcode = process.wait() output = support.strip_python_stderr(stdout) @@ -86,9 +72,9 @@ class FaultHandlerTests(unittest.TestCase): Raise an error if the output doesn't match the expected format. """ if all_threads: - header = 'Current thread XXX' + header = 'Current thread XXX (most recent call first)' else: - header = 'Traceback (most recent call first)' + header = 'Stack (most recent call first)' regex = """ ^Fatal Python error: {name} @@ -101,8 +87,7 @@ class FaultHandlerTests(unittest.TestCase): header=re.escape(header)) if other_regex: regex += '|' + other_regex - with support.suppress_crash_popup(): - output, exitcode = self.get_output(code, filename) + output, exitcode = self.get_output(code, filename) output = '\n'.join(output) self.assertRegex(output, regex) self.assertNotEqual(exitcode, 0) @@ -232,8 +217,7 @@ faulthandler.disable() faulthandler._sigsegv() """.strip() not_expected = 'Fatal Python error' - with support.suppress_crash_popup(): - stderr, exitcode = self.get_output(code) + stderr, exitcode = self.get_output(code) stder = '\n'.join(stderr) self.assertTrue(not_expected not in stderr, "%r is present in %r" % (not_expected, stderr)) @@ -264,16 +248,34 @@ faulthandler._sigsegv() def test_disabled_by_default(self): # By default, the module should be disabled code = "import faulthandler; print(faulthandler.is_enabled())" - rc, stdout, stderr = assert_python_ok("-c", code) - stdout = (stdout + stderr).strip() - self.assertEqual(stdout, b"False") + args = (sys.executable, '-E', '-c', code) + # don't use assert_python_ok() because it always enable faulthandler + output = subprocess.check_output(args) + self.assertEqual(output.rstrip(), b"False") def test_sys_xoptions(self): # Test python -X faulthandler code = "import faulthandler; print(faulthandler.is_enabled())" - rc, stdout, stderr = assert_python_ok("-X", "faulthandler", "-c", code) - stdout = (stdout + stderr).strip() - self.assertEqual(stdout, b"True") + args = (sys.executable, "-E", "-X", "faulthandler", "-c", code) + # don't use assert_python_ok() because it always enable faulthandler + output = subprocess.check_output(args) + self.assertEqual(output.rstrip(), b"True") + + def test_env_var(self): + # empty env var + code = "import faulthandler; print(faulthandler.is_enabled())" + args = (sys.executable, "-c", code) + env = os.environ.copy() + env['PYTHONFAULTHANDLER'] = '' + # don't use assert_python_ok() because it always enable faulthandler + output = subprocess.check_output(args, env=env) + self.assertEqual(output.rstrip(), b"False") + + # non-empty env var + env = os.environ.copy() + env['PYTHONFAULTHANDLER'] = '1' + output = subprocess.check_output(args, env=env) + self.assertEqual(output.rstrip(), b"True") def check_dump_traceback(self, filename): """ @@ -304,7 +306,7 @@ funcA() else: lineno = 8 expected = [ - 'Traceback (most recent call first):', + 'Stack (most recent call first):', ' File "<string>", line %s in funcB' % lineno, ' File "<string>", line 11 in funcA', ' File "<string>", line 13 in <module>' @@ -336,7 +338,7 @@ def {func_name}(): func_name=func_name, ) expected = [ - 'Traceback (most recent call first):', + 'Stack (most recent call first):', ' File "<string>", line 4 in %s' % truncated, ' File "<string>", line 6 in <module>' ] @@ -390,13 +392,13 @@ waiter.join() else: lineno = 10 regex = """ -^Thread 0x[0-9a-f]+: +^Thread 0x[0-9a-f]+ \(most recent call first\): (?: File ".*threading.py", line [0-9]+ in [_a-z]+ ){{1,3}} File "<string>", line 23 in run File ".*threading.py", line [0-9]+ in _bootstrap_inner File ".*threading.py", line [0-9]+ in _bootstrap -Current thread XXX: +Current thread XXX \(most recent call first\): File "<string>", line {lineno} in dump File "<string>", line 28 in <module>$ """.strip() @@ -459,7 +461,7 @@ if file is not None: count = loops if repeat: count *= 2 - header = r'Timeout \(%s\)!\nThread 0x[0-9a-f]+:\n' % timeout_str + header = r'Timeout \(%s\)!\nThread 0x[0-9a-f]+ \(most recent call first\):\n' % timeout_str regex = expected_traceback(9, 20, header, min_count=count) self.assertRegex(trace, regex) else: @@ -561,9 +563,9 @@ sys.exit(exitcode) trace = '\n'.join(trace) if not unregister: if all_threads: - regex = 'Current thread XXX:\n' + regex = 'Current thread XXX \(most recent call first\):\n' else: - regex = 'Traceback \(most recent call first\):\n' + regex = 'Stack \(most recent call first\):\n' regex = expected_traceback(7, 28, regex) self.assertRegex(trace, regex) else: @@ -590,8 +592,5 @@ sys.exit(exitcode) self.check_register(chain=True) -def test_main(): - support.run_unittest(FaultHandlerTests) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_fcntl.py b/Lib/test/test_fcntl.py index 1810c4e95e..e3b7ed20a0 100644 --- a/Lib/test/test_fcntl.py +++ b/Lib/test/test_fcntl.py @@ -1,7 +1,4 @@ """Test program for the fcntl C module. - -OS/2+EMX doesn't support the file locking operations. - """ import platform import os @@ -39,8 +36,6 @@ def get_lockdata(): lockdata = struct.pack('qqihhi', 0, 0, 0, fcntl.F_WRLCK, 0, 0) elif sys.platform in ['aix3', 'aix4', 'hp-uxB', 'unixware7']: lockdata = struct.pack('hhlllii', fcntl.F_WRLCK, 0, 0, 0, 0, 0, 0) - elif sys.platform in ['os2emx']: - lockdata = None else: lockdata = struct.pack('hh'+start_len+'hh', fcntl.F_WRLCK, 0, 0, 0, 0, 0) if lockdata: @@ -72,18 +67,20 @@ class TestFcntl(unittest.TestCase): rv = fcntl.fcntl(self.f.fileno(), fcntl.F_SETFL, os.O_NONBLOCK) if verbose: print('Status from fcntl with O_NONBLOCK: ', rv) - if sys.platform not in ['os2emx']: - rv = fcntl.fcntl(self.f.fileno(), fcntl.F_SETLKW, lockdata) - if verbose: - print('String from fcntl with F_SETLKW: ', repr(rv)) + rv = fcntl.fcntl(self.f.fileno(), fcntl.F_SETLKW, lockdata) + if verbose: + print('String from fcntl with F_SETLKW: ', repr(rv)) self.f.close() def test_fcntl_file_descriptor(self): # again, but pass the file rather than numeric descriptor self.f = open(TESTFN, 'wb') rv = fcntl.fcntl(self.f, fcntl.F_SETFL, os.O_NONBLOCK) - if sys.platform not in ['os2emx']: - rv = fcntl.fcntl(self.f, fcntl.F_SETLKW, lockdata) + if verbose: + print('Status from fcntl with O_NONBLOCK: ', rv) + rv = fcntl.fcntl(self.f, fcntl.F_SETLKW, lockdata) + if verbose: + print('String from fcntl with F_SETLKW: ', repr(rv)) self.f.close() def test_fcntl_bad_file(self): @@ -127,6 +124,26 @@ class TestFcntl(unittest.TestCase): finally: os.close(fd) + def test_flock(self): + # Solaris needs readable file for shared lock + self.f = open(TESTFN, 'wb+') + fileno = self.f.fileno() + fcntl.flock(fileno, fcntl.LOCK_SH) + fcntl.flock(fileno, fcntl.LOCK_UN) + fcntl.flock(self.f, fcntl.LOCK_SH | fcntl.LOCK_NB) + fcntl.flock(self.f, fcntl.LOCK_UN) + fcntl.flock(fileno, fcntl.LOCK_EX) + fcntl.flock(fileno, fcntl.LOCK_UN) + + self.assertRaises(ValueError, fcntl.flock, -1, fcntl.LOCK_SH) + self.assertRaises(TypeError, fcntl.flock, 'spam', fcntl.LOCK_SH) + + @cpython_only + def test_flock_overflow(self): + import _testcapi + self.assertRaises(OverflowError, fcntl.flock, _testcapi.INT_MAX+1, + fcntl.LOCK_SH) + def test_main(): run_unittest(TestFcntl) diff --git a/Lib/test/test_file.py b/Lib/test/test_file.py index 76b1694e98..d54e976143 100644 --- a/Lib/test/test_file.py +++ b/Lib/test/test_file.py @@ -87,7 +87,7 @@ class AutoFileTests: self.assertTrue(not f.closed) if hasattr(f, "readinto"): - self.assertRaises((IOError, TypeError), f.readinto, "") + self.assertRaises((OSError, TypeError), f.readinto, "") f.close() self.assertTrue(f.closed) @@ -126,7 +126,7 @@ class AutoFileTests: self.assertEqual(self.f.__exit__(*sys.exc_info()), None) def testReadWhenWriting(self): - self.assertRaises(IOError, self.f.read) + self.assertRaises(OSError, self.f.read) class CAutoFileTests(AutoFileTests, unittest.TestCase): open = io.open @@ -177,7 +177,7 @@ class OtherFileTests: d = int(f.read().decode("ascii")) f.close() f.close() - except IOError as msg: + except OSError as msg: self.fail('error setting buffer size %d: %s' % (s, str(msg))) self.assertEqual(d, s) diff --git a/Lib/test/test_filecmp.py b/Lib/test/test_filecmp.py index 09599794ad..c7cb3fcf36 100644 --- a/Lib/test/test_filecmp.py +++ b/Lib/test/test_filecmp.py @@ -1,4 +1,3 @@ - import os, filecmp, shutil, tempfile import unittest from test import support @@ -40,15 +39,27 @@ class FileCompareTestCase(unittest.TestCase): self.assertFalse(filecmp.cmp(self.name, self.dir), "File and directory compare as equal") + def test_cache_clear(self): + first_compare = filecmp.cmp(self.name, self.name_same, shallow=False) + second_compare = filecmp.cmp(self.name, self.name_diff, shallow=False) + filecmp.clear_cache() + self.assertTrue(len(filecmp._cache) == 0, + "Cache not cleared after calling clear_cache") + class DirCompareTestCase(unittest.TestCase): def setUp(self): tmpdir = tempfile.gettempdir() self.dir = os.path.join(tmpdir, 'dir') self.dir_same = os.path.join(tmpdir, 'dir-same') self.dir_diff = os.path.join(tmpdir, 'dir-diff') + + # Another dir is created under dir_same, but it has a name from the + # ignored list so it should not affect testing results. + self.dir_ignored = os.path.join(self.dir_same, '.hg') + self.caseinsensitive = os.path.normcase('A') == os.path.normcase('a') data = 'Contents of file go here.\n' - for dir in [self.dir, self.dir_same, self.dir_diff]: + for dir in (self.dir, self.dir_same, self.dir_diff, self.dir_ignored): shutil.rmtree(dir, True) os.mkdir(dir) if self.caseinsensitive and dir is self.dir_same: @@ -64,9 +75,11 @@ class DirCompareTestCase(unittest.TestCase): output.close() def tearDown(self): - shutil.rmtree(self.dir) - shutil.rmtree(self.dir_same) - shutil.rmtree(self.dir_diff) + for dir in (self.dir, self.dir_same, self.dir_diff): + shutil.rmtree(dir) + + def test_default_ignores(self): + self.assertIn('.hg', filecmp.DEFAULT_IGNORES) def test_cmpfiles(self): self.assertTrue(filecmp.cmpfiles(self.dir, self.dir, ['file']) == diff --git a/Lib/test/test_fileinput.py b/Lib/test/test_fileinput.py index a7624d3c3d..a537bc8c7c 100644 --- a/Lib/test/test_fileinput.py +++ b/Lib/test/test_fileinput.py @@ -22,7 +22,7 @@ except ImportError: from io import StringIO from fileinput import FileInput, hook_encoded -from test.support import verbose, TESTFN, run_unittest +from test.support import verbose, TESTFN, run_unittest, check_warnings from test.support import unlink as safe_unlink @@ -224,8 +224,10 @@ class FileInputTests(unittest.TestCase): try: # try opening in universal newline mode t1 = writeTmp(1, [b"A\nB\r\nC\rD"], mode="wb") - fi = FileInput(files=t1, mode="U") - lines = list(fi) + with check_warnings(('', DeprecationWarning)): + fi = FileInput(files=t1, mode="U") + with check_warnings(('', DeprecationWarning)): + lines = list(fi) self.assertEqual(lines, ["A\n", "B\n", "C\n", "D"]) finally: remove_tempfiles(t1) @@ -293,8 +295,8 @@ class FileInputTests(unittest.TestCase): try: t1 = writeTmp(1, [""]) with FileInput(files=t1) as fi: - raise IOError - except IOError: + raise OSError + except OSError: self.assertEqual(fi._files, ()) finally: remove_tempfiles(t1) @@ -866,27 +868,13 @@ class Test_hook_encoded(unittest.TestCase): self.assertEqual(lines, expected_lines) check('r', ['A\n', 'B\n', 'C\n', 'D\u20ac']) - check('rU', ['A\n', 'B\n', 'C\n', 'D\u20ac']) - check('U', ['A\n', 'B\n', 'C\n', 'D\u20ac']) + with self.assertWarns(DeprecationWarning): + check('rU', ['A\n', 'B\n', 'C\n', 'D\u20ac']) + with self.assertWarns(DeprecationWarning): + check('U', ['A\n', 'B\n', 'C\n', 'D\u20ac']) with self.assertRaises(ValueError): check('rb', ['A\n', 'B\r\n', 'C\r', 'D\u20ac']) -def test_main(): - run_unittest( - BufferSizesTests, - FileInputTests, - Test_fileinput_input, - Test_fileinput_close, - Test_fileinput_nextfile, - Test_fileinput_filename, - Test_fileinput_lineno, - Test_fileinput_filelineno, - Test_fileinput_fileno, - Test_fileinput_isfirstline, - Test_fileinput_isstdin, - Test_hook_compressed, - Test_hook_encoded, - ) if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_fileio.py b/Lib/test/test_fileio.py index 444be91920..c37482ed70 100644 --- a/Lib/test/test_fileio.py +++ b/Lib/test/test_fileio.py @@ -144,16 +144,16 @@ class AutoFileTests(unittest.TestCase): # Unix calls dircheck() and returns "[Errno 21]: Is a directory" try: _FileIO('.', 'r') - except IOError as e: + except OSError as e: self.assertNotEqual(e.errno, 0) self.assertEqual(e.filename, ".") else: - self.fail("Should have raised IOError") + self.fail("Should have raised OSError") @unittest.skipIf(os.name == 'nt', "test only works on a POSIX-like system") def testOpenDirFD(self): fd = os.open('.', os.O_RDONLY) - with self.assertRaises(IOError) as cm: + with self.assertRaises(OSError) as cm: _FileIO(fd, 'r') os.close(fd) self.assertEqual(cm.exception.errno, errno.EISDIR) @@ -171,7 +171,7 @@ class AutoFileTests(unittest.TestCase): finally: try: self.f.close() - except IOError: + except OSError: pass return wrapper @@ -183,14 +183,14 @@ class AutoFileTests(unittest.TestCase): os.close(f.fileno()) try: func(self, f) - except IOError as e: + except OSError as e: self.assertEqual(e.errno, errno.EBADF) else: - self.fail("Should have raised IOError") + self.fail("Should have raised OSError") finally: try: self.f.close() - except IOError: + except OSError: pass return wrapper @@ -237,7 +237,7 @@ class AutoFileTests(unittest.TestCase): def ReopenForRead(self): try: self.f.close() - except IOError: + except OSError: pass self.f = _FileIO(TESTFN, 'r') os.close(self.f.fileno()) @@ -285,7 +285,7 @@ class OtherFileTests(unittest.TestCase): if sys.platform != "win32": try: f = _FileIO("/dev/tty", "a") - except EnvironmentError: + except OSError: # When run in a cron job there just aren't any # ttys, so skip the test. This also handles other # OS'es that don't support /dev/tty. @@ -360,7 +360,7 @@ class OtherFileTests(unittest.TestCase): self.assertRaises(OSError, _FileIO, make_bad_fd()) if sys.platform == 'win32': import msvcrt - self.assertRaises(IOError, msvcrt.get_osfhandle, make_bad_fd()) + self.assertRaises(OSError, msvcrt.get_osfhandle, make_bad_fd()) @cpython_only def testInvalidFd_overflow(self): diff --git a/Lib/test/test_finalization.py b/Lib/test/test_finalization.py new file mode 100644 index 0000000000..03ac1aa274 --- /dev/null +++ b/Lib/test/test_finalization.py @@ -0,0 +1,522 @@ +""" +Tests for object finalization semantics, as outlined in PEP 442. +""" + +import contextlib +import gc +import unittest +import weakref + +try: + from _testcapi import with_tp_del +except ImportError: + def with_tp_del(cls): + class C(object): + def __new__(cls, *args, **kwargs): + raise TypeError('requires _testcapi.with_tp_del') + return C + +from test import support + + +class NonGCSimpleBase: + """ + The base class for all the objects under test, equipped with various + testing features. + """ + + survivors = [] + del_calls = [] + tp_del_calls = [] + errors = [] + + _cleaning = False + + __slots__ = () + + @classmethod + def _cleanup(cls): + cls.survivors.clear() + cls.errors.clear() + gc.garbage.clear() + gc.collect() + cls.del_calls.clear() + cls.tp_del_calls.clear() + + @classmethod + @contextlib.contextmanager + def test(cls): + """ + A context manager to use around all finalization tests. + """ + with support.disable_gc(): + cls.del_calls.clear() + cls.tp_del_calls.clear() + NonGCSimpleBase._cleaning = False + try: + yield + if cls.errors: + raise cls.errors[0] + finally: + NonGCSimpleBase._cleaning = True + cls._cleanup() + + def check_sanity(self): + """ + Check the object is sane (non-broken). + """ + + def __del__(self): + """ + PEP 442 finalizer. Record that this was called, check the + object is in a sane state, and invoke a side effect. + """ + try: + if not self._cleaning: + self.del_calls.append(id(self)) + self.check_sanity() + self.side_effect() + except Exception as e: + self.errors.append(e) + + def side_effect(self): + """ + A side effect called on destruction. + """ + + +class SimpleBase(NonGCSimpleBase): + + def __init__(self): + self.id_ = id(self) + + def check_sanity(self): + assert self.id_ == id(self) + + +class NonGC(NonGCSimpleBase): + __slots__ = () + +class NonGCResurrector(NonGCSimpleBase): + __slots__ = () + + def side_effect(self): + """ + Resurrect self by storing self in a class-wide list. + """ + self.survivors.append(self) + +class Simple(SimpleBase): + pass + +class SimpleResurrector(NonGCResurrector, SimpleBase): + pass + + +class TestBase: + + def setUp(self): + self.old_garbage = gc.garbage[:] + gc.garbage[:] = [] + + def tearDown(self): + # None of the tests here should put anything in gc.garbage + try: + self.assertEqual(gc.garbage, []) + finally: + del self.old_garbage + gc.collect() + + def assert_del_calls(self, ids): + self.assertEqual(sorted(SimpleBase.del_calls), sorted(ids)) + + def assert_tp_del_calls(self, ids): + self.assertEqual(sorted(SimpleBase.tp_del_calls), sorted(ids)) + + def assert_survivors(self, ids): + self.assertEqual(sorted(id(x) for x in SimpleBase.survivors), sorted(ids)) + + def assert_garbage(self, ids): + self.assertEqual(sorted(id(x) for x in gc.garbage), sorted(ids)) + + def clear_survivors(self): + SimpleBase.survivors.clear() + + +class SimpleFinalizationTest(TestBase, unittest.TestCase): + """ + Test finalization without refcycles. + """ + + def test_simple(self): + with SimpleBase.test(): + s = Simple() + ids = [id(s)] + wr = weakref.ref(s) + del s + gc.collect() + self.assert_del_calls(ids) + self.assert_survivors([]) + self.assertIs(wr(), None) + gc.collect() + self.assert_del_calls(ids) + self.assert_survivors([]) + + def test_simple_resurrect(self): + with SimpleBase.test(): + s = SimpleResurrector() + ids = [id(s)] + wr = weakref.ref(s) + del s + gc.collect() + self.assert_del_calls(ids) + self.assert_survivors(ids) + self.assertIsNot(wr(), None) + self.clear_survivors() + gc.collect() + self.assert_del_calls(ids) + self.assert_survivors([]) + self.assertIs(wr(), None) + + def test_non_gc(self): + with SimpleBase.test(): + s = NonGC() + self.assertFalse(gc.is_tracked(s)) + ids = [id(s)] + del s + gc.collect() + self.assert_del_calls(ids) + self.assert_survivors([]) + gc.collect() + self.assert_del_calls(ids) + self.assert_survivors([]) + + def test_non_gc_resurrect(self): + with SimpleBase.test(): + s = NonGCResurrector() + self.assertFalse(gc.is_tracked(s)) + ids = [id(s)] + del s + gc.collect() + self.assert_del_calls(ids) + self.assert_survivors(ids) + self.clear_survivors() + gc.collect() + self.assert_del_calls(ids * 2) + self.assert_survivors(ids) + + +class SelfCycleBase: + + def __init__(self): + super().__init__() + self.ref = self + + def check_sanity(self): + super().check_sanity() + assert self.ref is self + +class SimpleSelfCycle(SelfCycleBase, Simple): + pass + +class SelfCycleResurrector(SelfCycleBase, SimpleResurrector): + pass + +class SuicidalSelfCycle(SelfCycleBase, Simple): + + def side_effect(self): + """ + Explicitly break the reference cycle. + """ + self.ref = None + + +class SelfCycleFinalizationTest(TestBase, unittest.TestCase): + """ + Test finalization of an object having a single cyclic reference to + itself. + """ + + def test_simple(self): + with SimpleBase.test(): + s = SimpleSelfCycle() + ids = [id(s)] + wr = weakref.ref(s) + del s + gc.collect() + self.assert_del_calls(ids) + self.assert_survivors([]) + self.assertIs(wr(), None) + gc.collect() + self.assert_del_calls(ids) + self.assert_survivors([]) + + def test_simple_resurrect(self): + # Test that __del__ can resurrect the object being finalized. + with SimpleBase.test(): + s = SelfCycleResurrector() + ids = [id(s)] + wr = weakref.ref(s) + del s + gc.collect() + self.assert_del_calls(ids) + self.assert_survivors(ids) + # XXX is this desirable? + self.assertIs(wr(), None) + # When trying to destroy the object a second time, __del__ + # isn't called anymore (and the object isn't resurrected). + self.clear_survivors() + gc.collect() + self.assert_del_calls(ids) + self.assert_survivors([]) + self.assertIs(wr(), None) + + def test_simple_suicide(self): + # Test the GC is able to deal with an object that kills its last + # reference during __del__. + with SimpleBase.test(): + s = SuicidalSelfCycle() + ids = [id(s)] + wr = weakref.ref(s) + del s + gc.collect() + self.assert_del_calls(ids) + self.assert_survivors([]) + self.assertIs(wr(), None) + gc.collect() + self.assert_del_calls(ids) + self.assert_survivors([]) + self.assertIs(wr(), None) + + +class ChainedBase: + + def chain(self, left): + self.suicided = False + self.left = left + left.right = self + + def check_sanity(self): + super().check_sanity() + if self.suicided: + assert self.left is None + assert self.right is None + else: + left = self.left + if left.suicided: + assert left.right is None + else: + assert left.right is self + right = self.right + if right.suicided: + assert right.left is None + else: + assert right.left is self + +class SimpleChained(ChainedBase, Simple): + pass + +class ChainedResurrector(ChainedBase, SimpleResurrector): + pass + +class SuicidalChained(ChainedBase, Simple): + + def side_effect(self): + """ + Explicitly break the reference cycle. + """ + self.suicided = True + self.left = None + self.right = None + + +class CycleChainFinalizationTest(TestBase, unittest.TestCase): + """ + Test finalization of a cyclic chain. These tests are similar in + spirit to the self-cycle tests above, but the collectable object + graph isn't trivial anymore. + """ + + def build_chain(self, classes): + nodes = [cls() for cls in classes] + for i in range(len(nodes)): + nodes[i].chain(nodes[i-1]) + return nodes + + def check_non_resurrecting_chain(self, classes): + N = len(classes) + with SimpleBase.test(): + nodes = self.build_chain(classes) + ids = [id(s) for s in nodes] + wrs = [weakref.ref(s) for s in nodes] + del nodes + gc.collect() + self.assert_del_calls(ids) + self.assert_survivors([]) + self.assertEqual([wr() for wr in wrs], [None] * N) + gc.collect() + self.assert_del_calls(ids) + + def check_resurrecting_chain(self, classes): + N = len(classes) + with SimpleBase.test(): + nodes = self.build_chain(classes) + N = len(nodes) + ids = [id(s) for s in nodes] + survivor_ids = [id(s) for s in nodes if isinstance(s, SimpleResurrector)] + wrs = [weakref.ref(s) for s in nodes] + del nodes + gc.collect() + self.assert_del_calls(ids) + self.assert_survivors(survivor_ids) + # XXX desirable? + self.assertEqual([wr() for wr in wrs], [None] * N) + self.clear_survivors() + gc.collect() + self.assert_del_calls(ids) + self.assert_survivors([]) + + def test_homogenous(self): + self.check_non_resurrecting_chain([SimpleChained] * 3) + + def test_homogenous_resurrect(self): + self.check_resurrecting_chain([ChainedResurrector] * 3) + + def test_homogenous_suicidal(self): + self.check_non_resurrecting_chain([SuicidalChained] * 3) + + def test_heterogenous_suicidal_one(self): + self.check_non_resurrecting_chain([SuicidalChained, SimpleChained] * 2) + + def test_heterogenous_suicidal_two(self): + self.check_non_resurrecting_chain( + [SuicidalChained] * 2 + [SimpleChained] * 2) + + def test_heterogenous_resurrect_one(self): + self.check_resurrecting_chain([ChainedResurrector, SimpleChained] * 2) + + def test_heterogenous_resurrect_two(self): + self.check_resurrecting_chain( + [ChainedResurrector, SimpleChained, SuicidalChained] * 2) + + def test_heterogenous_resurrect_three(self): + self.check_resurrecting_chain( + [ChainedResurrector] * 2 + [SimpleChained] * 2 + [SuicidalChained] * 2) + + +# NOTE: the tp_del slot isn't automatically inherited, so we have to call +# with_tp_del() for each instantiated class. + +class LegacyBase(SimpleBase): + + def __del__(self): + try: + # Do not invoke side_effect here, since we are now exercising + # the tp_del slot. + if not self._cleaning: + self.del_calls.append(id(self)) + self.check_sanity() + except Exception as e: + self.errors.append(e) + + def __tp_del__(self): + """ + Legacy (pre-PEP 442) finalizer, mapped to a tp_del slot. + """ + try: + if not self._cleaning: + self.tp_del_calls.append(id(self)) + self.check_sanity() + self.side_effect() + except Exception as e: + self.errors.append(e) + +@with_tp_del +class Legacy(LegacyBase): + pass + +@with_tp_del +class LegacyResurrector(LegacyBase): + + def side_effect(self): + """ + Resurrect self by storing self in a class-wide list. + """ + self.survivors.append(self) + +@with_tp_del +class LegacySelfCycle(SelfCycleBase, LegacyBase): + pass + + +@support.cpython_only +class LegacyFinalizationTest(TestBase, unittest.TestCase): + """ + Test finalization of objects with a tp_del. + """ + + def tearDown(self): + # These tests need to clean up a bit more, since they create + # uncollectable objects. + gc.garbage.clear() + gc.collect() + super().tearDown() + + def test_legacy(self): + with SimpleBase.test(): + s = Legacy() + ids = [id(s)] + wr = weakref.ref(s) + del s + gc.collect() + self.assert_del_calls(ids) + self.assert_tp_del_calls(ids) + self.assert_survivors([]) + self.assertIs(wr(), None) + gc.collect() + self.assert_del_calls(ids) + self.assert_tp_del_calls(ids) + + def test_legacy_resurrect(self): + with SimpleBase.test(): + s = LegacyResurrector() + ids = [id(s)] + wr = weakref.ref(s) + del s + gc.collect() + self.assert_del_calls(ids) + self.assert_tp_del_calls(ids) + self.assert_survivors(ids) + # weakrefs are cleared before tp_del is called. + self.assertIs(wr(), None) + self.clear_survivors() + gc.collect() + self.assert_del_calls(ids) + self.assert_tp_del_calls(ids * 2) + self.assert_survivors(ids) + self.assertIs(wr(), None) + + def test_legacy_self_cycle(self): + # Self-cycles with legacy finalizers end up in gc.garbage. + with SimpleBase.test(): + s = LegacySelfCycle() + ids = [id(s)] + wr = weakref.ref(s) + del s + gc.collect() + self.assert_del_calls([]) + self.assert_tp_del_calls([]) + self.assert_survivors([]) + self.assert_garbage(ids) + self.assertIsNot(wr(), None) + # Break the cycle to allow collection + gc.garbage[0].ref = None + self.assert_garbage([]) + self.assertIs(wr(), None) + + +def test_main(): + support.run_unittest(__name__) + +if __name__ == "__main__": + test_main() diff --git a/Lib/test/test_float.py b/Lib/test/test_float.py index 96d907a671..e1e1f0423e 100644 --- a/Lib/test/test_float.py +++ b/Lib/test/test_float.py @@ -41,6 +41,7 @@ class GeneralFloatCases(unittest.TestCase): self.assertRaises(ValueError, float, "-.") self.assertRaises(ValueError, float, b"-") self.assertRaises(TypeError, float, {}) + self.assertRaisesRegex(TypeError, "not 'dict'", float, {}) # Lone surrogate self.assertRaises(UnicodeEncodeError, float, '\uD8F0') # check that we don't accept alternate exponent markers diff --git a/Lib/test/test_fork1.py b/Lib/test/test_fork1.py index 8192c38a44..e0626dffdc 100644 --- a/Lib/test/test_fork1.py +++ b/Lib/test/test_fork1.py @@ -1,7 +1,7 @@ """This test checks for correct fork() behavior. """ -import imp +import _imp as imp import os import signal import sys diff --git a/Lib/test/test_format.py b/Lib/test/test_format.py index ea5a731d98..fc71e4818e 100644 --- a/Lib/test/test_format.py +++ b/Lib/test/test_format.py @@ -142,7 +142,8 @@ class FormatTest(unittest.TestCase): testformat("%#+027.23X", big, "+0X0001234567890ABCDEF12345") # same, except no 0 flag testformat("%#+27.23X", big, " +0X001234567890ABCDEF12345") - testformat("%x", float(big), "123456_______________", 6) + with self.assertWarns(DeprecationWarning): + testformat("%x", float(big), "123456_______________", 6) big = 0o12345670123456701234567012345670 # 32 octal digits testformat("%o", big, "12345670123456701234567012345670") testformat("%o", -big, "-12345670123456701234567012345670") @@ -182,7 +183,8 @@ class FormatTest(unittest.TestCase): testformat("%034.33o", big, "0012345670123456701234567012345670") # base marker shouldn't change that testformat("%0#34.33o", big, "0o012345670123456701234567012345670") - testformat("%o", float(big), "123456__________________________", 6) + with self.assertWarns(DeprecationWarning): + testformat("%o", float(big), "123456__________________________", 6) # Some small ints, in both Python int and flavors). testformat("%d", 42, "42") testformat("%d", -42, "-42") @@ -193,7 +195,8 @@ class FormatTest(unittest.TestCase): testformat("%#x", 1, "0x1") testformat("%#X", 1, "0X1") testformat("%#X", 1, "0X1") - testformat("%#x", 1.0, "0x1") + with self.assertWarns(DeprecationWarning): + testformat("%#x", 1.0, "0x1") testformat("%#o", 1, "0o1") testformat("%#o", 1, "0o1") testformat("%#o", 0, "0o0") @@ -210,12 +213,14 @@ class FormatTest(unittest.TestCase): testformat("%x", -0x42, "-42") testformat("%x", 0x42, "42") testformat("%x", -0x42, "-42") - testformat("%x", float(0x42), "42") + with self.assertWarns(DeprecationWarning): + testformat("%x", float(0x42), "42") testformat("%o", 0o42, "42") testformat("%o", -0o42, "-42") testformat("%o", 0o42, "42") testformat("%o", -0o42, "-42") - testformat("%o", float(0o42), "42") + with self.assertWarns(DeprecationWarning): + testformat("%o", float(0o42), "42") testformat("%r", "\u0378", "'\\u0378'") # non printable testformat("%a", "\u0378", "'\\u0378'") # non printable testformat("%r", "\u0374", "'\u0374'") # printable @@ -307,10 +312,25 @@ class FormatTest(unittest.TestCase): finally: locale.setlocale(locale.LC_ALL, oldloc) + @support.cpython_only + def test_optimisations(self): + text = "abcde" # 5 characters + + self.assertIs("%s" % text, text) + self.assertIs("%.5s" % text, text) + self.assertIs("%.10s" % text, text) + self.assertIs("%1s" % text, text) + self.assertIs("%5s" % text, text) + self.assertIs("{0}".format(text), text) + self.assertIs("{0:s}".format(text), text) + self.assertIs("{0:.5s}".format(text), text) + self.assertIs("{0:.10s}".format(text), text) + self.assertIs("{0:1s}".format(text), text) + self.assertIs("{0:5s}".format(text), text) -def test_main(): - support.run_unittest(FormatTest) + self.assertIs(text % (), text) + self.assertIs(text.format(), text) def test_precision(self): f = 1.2 @@ -318,14 +338,12 @@ def test_main(): self.assertEqual(format(f, ".3f"), "1.200") with self.assertRaises(ValueError) as cm: format(f, ".%sf" % (sys.maxsize + 1)) - self.assertEqual(str(cm.exception), "precision too big") c = complex(f) self.assertEqual(format(c, ".0f"), "1+0j") self.assertEqual(format(c, ".3f"), "1.200+0.000j") with self.assertRaises(ValueError) as cm: format(c, ".%sf" % (sys.maxsize + 1)) - self.assertEqual(str(cm.exception), "precision too big") @support.cpython_only def test_precision_c_limits(self): diff --git a/Lib/test/test_fractions.py b/Lib/test/test_fractions.py index 1fad9215c0..33365322b9 100644 --- a/Lib/test/test_fractions.py +++ b/Lib/test/test_fractions.py @@ -146,9 +146,10 @@ class FractionTest(unittest.TestCase): self.assertEqual((0, 1), _components(F(-0.0))) self.assertEqual((3602879701896397, 36028797018963968), _components(F(0.1))) - self.assertRaises(TypeError, F, float('nan')) - self.assertRaises(TypeError, F, float('inf')) - self.assertRaises(TypeError, F, float('-inf')) + # bug 16469: error types should be consistent with float -> int + self.assertRaises(ValueError, F, float('nan')) + self.assertRaises(OverflowError, F, float('inf')) + self.assertRaises(OverflowError, F, float('-inf')) def testInitFromDecimal(self): self.assertEqual((11, 10), @@ -157,10 +158,11 @@ class FractionTest(unittest.TestCase): _components(F(Decimal('3.5e-2')))) self.assertEqual((0, 1), _components(F(Decimal('.000e20')))) - self.assertRaises(TypeError, F, Decimal('nan')) - self.assertRaises(TypeError, F, Decimal('snan')) - self.assertRaises(TypeError, F, Decimal('inf')) - self.assertRaises(TypeError, F, Decimal('-inf')) + # bug 16469: error types should be consistent with decimal -> int + self.assertRaises(ValueError, F, Decimal('nan')) + self.assertRaises(ValueError, F, Decimal('snan')) + self.assertRaises(OverflowError, F, Decimal('inf')) + self.assertRaises(OverflowError, F, Decimal('-inf')) def testFromString(self): self.assertEqual((5, 1), _components(F("5"))) @@ -248,14 +250,15 @@ class FractionTest(unittest.TestCase): inf = 1e1000 nan = inf - inf + # bug 16469: error types should be consistent with float -> int self.assertRaisesMessage( - TypeError, "Cannot convert inf to Fraction.", + OverflowError, "Cannot convert inf to Fraction.", F.from_float, inf) self.assertRaisesMessage( - TypeError, "Cannot convert -inf to Fraction.", + OverflowError, "Cannot convert -inf to Fraction.", F.from_float, -inf) self.assertRaisesMessage( - TypeError, "Cannot convert nan to Fraction.", + ValueError, "Cannot convert nan to Fraction.", F.from_float, nan) def testFromDecimal(self): @@ -268,17 +271,18 @@ class FractionTest(unittest.TestCase): self.assertEqual(1 - F(1, 10**30), F.from_decimal(Decimal("0." + "9" * 30))) + # bug 16469: error types should be consistent with decimal -> int self.assertRaisesMessage( - TypeError, "Cannot convert Infinity to Fraction.", + OverflowError, "Cannot convert Infinity to Fraction.", F.from_decimal, Decimal("inf")) self.assertRaisesMessage( - TypeError, "Cannot convert -Infinity to Fraction.", + OverflowError, "Cannot convert -Infinity to Fraction.", F.from_decimal, Decimal("-inf")) self.assertRaisesMessage( - TypeError, "Cannot convert NaN to Fraction.", + ValueError, "Cannot convert NaN to Fraction.", F.from_decimal, Decimal("nan")) self.assertRaisesMessage( - TypeError, "Cannot convert sNaN to Fraction.", + ValueError, "Cannot convert sNaN to Fraction.", F.from_decimal, Decimal("snan")) def testLimitDenominator(self): diff --git a/Lib/test/test_frame.py b/Lib/test/test_frame.py new file mode 100644 index 0000000000..2dd5780c88 --- /dev/null +++ b/Lib/test/test_frame.py @@ -0,0 +1,116 @@ +import gc +import sys +import unittest +import weakref + +from test import support + + +class ClearTest(unittest.TestCase): + """ + Tests for frame.clear(). + """ + + def inner(self, x=5, **kwargs): + 1/0 + + def outer(self, **kwargs): + try: + self.inner(**kwargs) + except ZeroDivisionError as e: + exc = e + return exc + + def clear_traceback_frames(self, tb): + """ + Clear all frames in a traceback. + """ + while tb is not None: + tb.tb_frame.clear() + tb = tb.tb_next + + def test_clear_locals(self): + class C: + pass + c = C() + wr = weakref.ref(c) + exc = self.outer(c=c) + del c + support.gc_collect() + # A reference to c is held through the frames + self.assertIsNot(None, wr()) + self.clear_traceback_frames(exc.__traceback__) + support.gc_collect() + # The reference was released by .clear() + self.assertIs(None, wr()) + + def test_clear_generator(self): + endly = False + def g(): + nonlocal endly + try: + yield + inner() + finally: + endly = True + gen = g() + next(gen) + self.assertFalse(endly) + # Clearing the frame closes the generator + gen.gi_frame.clear() + self.assertTrue(endly) + + def test_clear_executing(self): + # Attempting to clear an executing frame is forbidden. + try: + 1/0 + except ZeroDivisionError as e: + f = e.__traceback__.tb_frame + with self.assertRaises(RuntimeError): + f.clear() + with self.assertRaises(RuntimeError): + f.f_back.clear() + + def test_clear_executing_generator(self): + # Attempting to clear an executing generator frame is forbidden. + endly = False + def g(): + nonlocal endly + try: + 1/0 + except ZeroDivisionError as e: + f = e.__traceback__.tb_frame + with self.assertRaises(RuntimeError): + f.clear() + with self.assertRaises(RuntimeError): + f.f_back.clear() + yield f + finally: + endly = True + gen = g() + f = next(gen) + self.assertFalse(endly) + # Clearing the frame closes the generator + f.clear() + self.assertTrue(endly) + + @support.cpython_only + def test_clear_refcycles(self): + # .clear() doesn't leave any refcycle behind + with support.disable_gc(): + class C: + pass + c = C() + wr = weakref.ref(c) + exc = self.outer(c=c) + del c + self.assertIsNot(None, wr()) + self.clear_traceback_frames(exc.__traceback__) + self.assertIs(None, wr()) + + +def test_main(): + support.run_unittest(__name__) + +if __name__ == "__main__": + test_main() diff --git a/Lib/test/test_frozen.py b/Lib/test/test_frozen.py deleted file mode 100644 index fd6761c651..0000000000 --- a/Lib/test/test_frozen.py +++ /dev/null @@ -1,79 +0,0 @@ -# Test the frozen module defined in frozen.c. - -from test.support import captured_stdout, run_unittest -import unittest -import sys - -class FrozenTests(unittest.TestCase): - - module_attrs = frozenset(['__builtins__', '__cached__', '__doc__', - '__loader__', '__name__', - '__package__']) - package_attrs = frozenset(list(module_attrs) + ['__path__']) - - def test_frozen(self): - with captured_stdout() as stdout: - try: - import __hello__ - except ImportError as x: - self.fail("import __hello__ failed:" + str(x)) - self.assertEqual(__hello__.initialized, True) - expect = set(self.module_attrs) - expect.add('initialized') - self.assertEqual(set(dir(__hello__)), expect) - self.assertEqual(stdout.getvalue(), 'Hello world!\n') - - with captured_stdout() as stdout: - try: - import __phello__ - except ImportError as x: - self.fail("import __phello__ failed:" + str(x)) - self.assertEqual(__phello__.initialized, True) - expect = set(self.package_attrs) - expect.add('initialized') - if not "__phello__.spam" in sys.modules: - self.assertEqual(set(dir(__phello__)), expect) - else: - expect.add('spam') - self.assertEqual(set(dir(__phello__)), expect) - self.assertEqual(__phello__.__path__, [__phello__.__name__]) - self.assertEqual(stdout.getvalue(), 'Hello world!\n') - - with captured_stdout() as stdout: - try: - import __phello__.spam - except ImportError as x: - self.fail("import __phello__.spam failed:" + str(x)) - self.assertEqual(__phello__.spam.initialized, True) - spam_expect = set(self.module_attrs) - spam_expect.add('initialized') - self.assertEqual(set(dir(__phello__.spam)), spam_expect) - phello_expect = set(self.package_attrs) - phello_expect.add('initialized') - phello_expect.add('spam') - self.assertEqual(set(dir(__phello__)), phello_expect) - self.assertEqual(stdout.getvalue(), 'Hello world!\n') - - try: - import __phello__.foo - except ImportError: - pass - else: - self.fail("import __phello__.foo should have failed") - - try: - import __phello__.foo - except ImportError: - pass - else: - self.fail("import __phello__.foo should have failed") - - del sys.modules['__hello__'] - del sys.modules['__phello__'] - del sys.modules['__phello__.spam'] - -def test_main(): - run_unittest(FrozenTests) - -if __name__ == "__main__": - test_main() diff --git a/Lib/test/test_ftplib.py b/Lib/test/test_ftplib.py index 6c95c491ff..a9bf30b3d5 100644 --- a/Lib/test/test_ftplib.py +++ b/Lib/test/test_ftplib.py @@ -15,12 +15,16 @@ try: import ssl except ImportError: ssl = None + HAS_SNI = False +else: + from ssl import HAS_SNI from unittest import TestCase, skipUnless from test import support from test.support import HOST, HOSTv6 threading = support.import_module('threading') +TIMEOUT = 3 # the dummy data returned by server over the data channel when # RETR, LIST, NLST, MLSD commands are issued RETR_DATA = 'abcde12345\r\n' * 1000 @@ -126,7 +130,7 @@ class DummyFTPHandler(asynchat.async_chat): addr = list(map(int, arg.split(','))) ip = '%d.%d.%d.%d' %tuple(addr[:4]) port = (addr[4] * 256) + addr[5] - s = socket.create_connection((ip, port), timeout=2) + s = socket.create_connection((ip, port), timeout=TIMEOUT) self.dtp = self.dtp_handler(s, baseclass=self) self.push('200 active data connection established') @@ -134,7 +138,7 @@ class DummyFTPHandler(asynchat.async_chat): with socket.socket() as sock: sock.bind((self.socket.getsockname()[0], 0)) sock.listen(5) - sock.settimeout(10) + sock.settimeout(TIMEOUT) ip, port = sock.getsockname()[:2] ip = ip.replace('.', ','); p1 = port / 256; p2 = port % 256 self.push('227 entering passive mode (%s,%d,%d)' %(ip, p1, p2)) @@ -144,7 +148,7 @@ class DummyFTPHandler(asynchat.async_chat): def cmd_eprt(self, arg): af, ip, port = arg.split(arg[0])[1:-1] port = int(port) - s = socket.create_connection((ip, port), timeout=2) + s = socket.create_connection((ip, port), timeout=TIMEOUT) self.dtp = self.dtp_handler(s, baseclass=self) self.push('200 active data connection established') @@ -152,7 +156,7 @@ class DummyFTPHandler(asynchat.async_chat): with socket.socket(socket.AF_INET6) as sock: sock.bind((self.socket.getsockname()[0], 0)) sock.listen(5) - sock.settimeout(10) + sock.settimeout(TIMEOUT) port = sock.getsockname()[1] self.push('229 entering extended passive mode (|||%d|)' %port) conn, addr = sock.accept() @@ -300,7 +304,8 @@ class DummyFTPServer(asyncore.dispatcher, threading.Thread): if ssl is not None: - CERTFILE = os.path.join(os.path.dirname(__file__), "keycert.pem") + CERTFILE = os.path.join(os.path.dirname(__file__), "keycert3.pem") + CAFILE = os.path.join(os.path.dirname(__file__), "pycacert.pem") class SSLConnection(asyncore.dispatcher): """An asyncore.dispatcher subclass supporting TLS/SSL.""" @@ -327,7 +332,7 @@ if ssl is not None: elif err.args[0] == ssl.SSL_ERROR_EOF: return self.handle_close() raise - except socket.error as err: + except OSError as err: if err.args[0] == errno.ECONNABORTED: return self.handle_close() else: @@ -341,7 +346,7 @@ if ssl is not None: if err.args[0] in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE): return - except socket.error as err: + except OSError as err: # Any "socket error" corresponds to a SSL_ERROR_SYSCALL return # from OpenSSL's SSL_shutdown(), corresponding to a # closed socket condition. See also: @@ -460,7 +465,7 @@ class TestFTPClass(TestCase): def setUp(self): self.server = DummyFTPServer((HOST, 0)) self.server.start() - self.client = ftplib.FTP(timeout=2) + self.client = ftplib.FTP(timeout=TIMEOUT) self.client.connect(self.server.host, self.server.port) def tearDown(self): @@ -488,7 +493,7 @@ class TestFTPClass(TestCase): def test_all_errors(self): exceptions = (ftplib.error_reply, ftplib.error_temp, ftplib.error_perm, - ftplib.error_proto, ftplib.Error, IOError, EOFError) + ftplib.error_proto, ftplib.Error, OSError, EOFError) for x in exceptions: try: raise x('exception not included in all_errors set') @@ -678,7 +683,7 @@ class TestFTPClass(TestCase): def test_makepasv(self): host, port = self.client.makepasv() - conn = socket.create_connection((host, port), 10) + conn = socket.create_connection((host, port), timeout=TIMEOUT) conn.close() # IPv4 is in use, just make sure send_epsv has not been used self.assertEqual(self.server.handler_instance.last_received_cmd, 'pasv') @@ -691,12 +696,12 @@ class TestFTPClass(TestCase): return False try: self.client.sendcmd('noop') - except (socket.error, EOFError): + except (OSError, EOFError): return False return True # base test - with ftplib.FTP(timeout=2) as self.client: + with ftplib.FTP(timeout=TIMEOUT) as self.client: self.client.connect(self.server.host, self.server.port) self.client.sendcmd('noop') self.assertTrue(is_client_connected()) @@ -704,7 +709,7 @@ class TestFTPClass(TestCase): self.assertFalse(is_client_connected()) # QUIT sent inside the with block - with ftplib.FTP(timeout=2) as self.client: + with ftplib.FTP(timeout=TIMEOUT) as self.client: self.client.connect(self.server.host, self.server.port) self.client.sendcmd('noop') self.client.quit() @@ -714,7 +719,7 @@ class TestFTPClass(TestCase): # force a wrong response code to be sent on QUIT: error_perm # is expected and the connection is supposed to be closed try: - with ftplib.FTP(timeout=2) as self.client: + with ftplib.FTP(timeout=TIMEOUT) as self.client: self.client.connect(self.server.host, self.server.port) self.client.sendcmd('noop') self.server.handler_instance.next_response = '550 error on quit' @@ -736,7 +741,7 @@ class TestFTPClass(TestCase): source_address=(HOST, port)) self.assertEqual(self.client.sock.getsockname()[1], port) self.client.quit() - except IOError as e: + except OSError as e: if e.errno == errno.EADDRINUSE: self.skipTest("couldn't bind to port %d" % port) raise @@ -747,7 +752,7 @@ class TestFTPClass(TestCase): try: with self.client.transfercmd('list') as sock: self.assertEqual(sock.getsockname()[1], port) - except IOError as e: + except OSError as e: if e.errno == errno.EADDRINUSE: self.skipTest("couldn't bind to port %d" % port) raise @@ -785,7 +790,7 @@ class TestIPv6Environment(TestCase): def setUp(self): self.server = DummyFTPServer((HOSTv6, 0), af=socket.AF_INET6) self.server.start() - self.client = ftplib.FTP() + self.client = ftplib.FTP(timeout=TIMEOUT) self.client.connect(self.server.host, self.server.port) def tearDown(self): @@ -802,7 +807,7 @@ class TestIPv6Environment(TestCase): def test_makepasv(self): host, port = self.client.makepasv() - conn = socket.create_connection((host, port), 10) + conn = socket.create_connection((host, port), timeout=TIMEOUT) conn.close() self.assertEqual(self.server.handler_instance.last_received_cmd, 'epsv') @@ -829,7 +834,7 @@ class TestTLS_FTPClassMixin(TestFTPClass): def setUp(self): self.server = DummyTLS_FTPServer((HOST, 0)) self.server.start() - self.client = ftplib.FTP_TLS(timeout=2) + self.client = ftplib.FTP_TLS(timeout=TIMEOUT) self.client.connect(self.server.host, self.server.port) # enable TLS self.client.auth() @@ -843,7 +848,7 @@ class TestTLS_FTPClass(TestCase): def setUp(self): self.server = DummyTLS_FTPServer((HOST, 0)) self.server.start() - self.client = ftplib.FTP_TLS(timeout=2) + self.client = ftplib.FTP_TLS(timeout=TIMEOUT) self.client.connect(self.server.host, self.server.port) def tearDown(self): @@ -903,7 +908,7 @@ class TestTLS_FTPClass(TestCase): self.assertRaises(ValueError, ftplib.FTP_TLS, certfile=CERTFILE, keyfile=CERTFILE, context=ctx) - self.client = ftplib.FTP_TLS(context=ctx, timeout=2) + self.client = ftplib.FTP_TLS(context=ctx, timeout=TIMEOUT) self.client.connect(self.server.host, self.server.port) self.assertNotIsInstance(self.client.sock, ssl.SSLSocket) self.client.auth() @@ -922,6 +927,37 @@ class TestTLS_FTPClass(TestCase): self.client.ccc() self.assertRaises(ValueError, self.client.sock.unwrap) + @skipUnless(HAS_SNI, 'No SNI support in ssl module') + def test_check_hostname(self): + self.client.quit() + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx.verify_mode = ssl.CERT_REQUIRED + ctx.check_hostname = True + ctx.load_verify_locations(CAFILE) + self.client = ftplib.FTP_TLS(context=ctx, timeout=TIMEOUT) + + # 127.0.0.1 doesn't match SAN + self.client.connect(self.server.host, self.server.port) + with self.assertRaises(ssl.CertificateError): + self.client.auth() + # exception quits connection + + self.client.connect(self.server.host, self.server.port) + self.client.prot_p() + with self.assertRaises(ssl.CertificateError): + with self.client.transfercmd("list") as sock: + pass + self.client.quit() + + self.client.connect("localhost", self.server.port) + self.client.auth() + self.client.quit() + + self.client.connect("localhost", self.server.port) + self.client.prot_p() + with self.client.transfercmd("list") as sock: + pass + class TestTimeouts(TestCase): @@ -1017,8 +1053,19 @@ class TestTimeouts(TestCase): ftp.close() +class TestNetrcDeprecation(TestCase): + + def test_deprecation(self): + with support.temp_cwd(), support.EnvironmentVarGuard() as env: + env['HOME'] = os.getcwd() + open('.netrc', 'w').close() + with self.assertWarns(DeprecationWarning): + ftplib.Netrc() + + + def test_main(): - tests = [TestFTPClass, TestTimeouts, + tests = [TestFTPClass, TestTimeouts, TestNetrcDeprecation, TestIPv6Environment, TestTLS_FTPClassMixin, TestTLS_FTPClass] diff --git a/Lib/test/test_funcattrs.py b/Lib/test/test_funcattrs.py index c8ed83020e..5094f7ba1f 100644 --- a/Lib/test/test_funcattrs.py +++ b/Lib/test/test_funcattrs.py @@ -7,6 +7,11 @@ def global_function(): def inner_function(): class LocalClass: pass + global inner_global_function + def inner_global_function(): + def inner_function2(): + pass + return inner_function2 return LocalClass return lambda: inner_function @@ -116,6 +121,8 @@ class FunctionPropertiesTest(FuncAttrsTest): 'global_function.<locals>.inner_function') self.assertEqual(global_function()()().__qualname__, 'global_function.<locals>.inner_function.<locals>.LocalClass') + self.assertEqual(inner_global_function.__qualname__, 'inner_global_function') + self.assertEqual(inner_global_function().__qualname__, 'inner_global_function.<locals>.inner_function2') self.b.__qualname__ = 'c' self.assertEqual(self.b.__qualname__, 'c') self.b.__qualname__ = 'd' diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py index d1ce2a9c8e..75ae7f3a15 100644 --- a/Lib/test/test_functools.py +++ b/Lib/test/test_functools.py @@ -1,66 +1,52 @@ -import functools +import abc import collections +from itertools import permutations +import pickle +from random import choice import sys -import unittest from test import support +import unittest from weakref import proxy -import pickle -from random import choice -@staticmethod -def PythonPartial(func, *args, **keywords): - 'Pure Python approximation of partial()' - def newfunc(*fargs, **fkeywords): - newkeywords = keywords.copy() - newkeywords.update(fkeywords) - return func(*(args + fargs), **newkeywords) - newfunc.func = func - newfunc.args = args - newfunc.keywords = keywords - return newfunc +import functools + +py_functools = support.import_fresh_module('functools', blocked=['_functools']) +c_functools = support.import_fresh_module('functools', fresh=['_functools']) + +decimal = support.import_fresh_module('decimal', fresh=['_decimal']) + def capture(*args, **kw): """capture all positional and keyword arguments""" return args, kw + def signature(part): """ return the signature of a partial object """ return (part.func, part.args, part.keywords, part.__dict__) -class TestPartial(unittest.TestCase): - thetype = functools.partial +class TestPartial: def test_basic_examples(self): - p = self.thetype(capture, 1, 2, a=10, b=20) + p = self.partial(capture, 1, 2, a=10, b=20) + self.assertTrue(callable(p)) self.assertEqual(p(3, 4, b=30, c=40), ((1, 2, 3, 4), dict(a=10, b=30, c=40))) - p = self.thetype(map, lambda x: x*10) + p = self.partial(map, lambda x: x*10) self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40]) def test_attributes(self): - p = self.thetype(capture, 1, 2, a=10, b=20) + p = self.partial(capture, 1, 2, a=10, b=20) # attributes should be readable self.assertEqual(p.func, capture) self.assertEqual(p.args, (1, 2)) self.assertEqual(p.keywords, dict(a=10, b=20)) - # attributes should not be writable - self.assertRaises(AttributeError, setattr, p, 'func', map) - self.assertRaises(AttributeError, setattr, p, 'args', (1, 2)) - self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2)) - - p = self.thetype(hex) - try: - del p.__dict__ - except TypeError: - pass - else: - self.fail('partial object allowed __dict__ to be deleted') def test_argument_checking(self): - self.assertRaises(TypeError, self.thetype) # need at least a func arg + self.assertRaises(TypeError, self.partial) # need at least a func arg try: - self.thetype(2)() + self.partial(2)() except TypeError: pass else: @@ -71,7 +57,7 @@ class TestPartial(unittest.TestCase): def func(a=10, b=20): return a d = {'a':3} - p = self.thetype(func, a=5) + p = self.partial(func, a=5) self.assertEqual(p(**d), 3) self.assertEqual(d, {'a':3}) p(b=7) @@ -80,20 +66,20 @@ class TestPartial(unittest.TestCase): def test_arg_combinations(self): # exercise special code paths for zero args in either partial # object or the caller - p = self.thetype(capture) + p = self.partial(capture) self.assertEqual(p(), ((), {})) self.assertEqual(p(1,2), ((1,2), {})) - p = self.thetype(capture, 1, 2) + p = self.partial(capture, 1, 2) self.assertEqual(p(), ((1,2), {})) self.assertEqual(p(3,4), ((1,2,3,4), {})) def test_kw_combinations(self): # exercise special code paths for no keyword args in # either the partial object or the caller - p = self.thetype(capture) + p = self.partial(capture) self.assertEqual(p(), ((), {})) self.assertEqual(p(a=1), ((), {'a':1})) - p = self.thetype(capture, a=1) + p = self.partial(capture, a=1) self.assertEqual(p(), ((), {'a':1})) self.assertEqual(p(b=2), ((), {'a':1, 'b':2})) # keyword args in the call override those in the partial object @@ -102,7 +88,7 @@ class TestPartial(unittest.TestCase): def test_positional(self): # make sure positional arguments are captured correctly for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]: - p = self.thetype(capture, *args) + p = self.partial(capture, *args) expected = args + ('x',) got, empty = p('x') self.assertTrue(expected == got and empty == {}) @@ -110,14 +96,14 @@ class TestPartial(unittest.TestCase): def test_keyword(self): # make sure keyword arguments are captured correctly for a in ['a', 0, None, 3.5]: - p = self.thetype(capture, a=a) + p = self.partial(capture, a=a) expected = {'a':a,'x':None} empty, got = p(x=None) self.assertTrue(expected == got and empty == ()) def test_no_side_effects(self): # make sure there are no side effects that affect subsequent calls - p = self.thetype(capture, 0, a=1) + p = self.partial(capture, 0, a=1) args1, kw1 = p(1, b=2) self.assertTrue(args1 == (0,1) and kw1 == {'a':1,'b':2}) args2, kw2 = p() @@ -126,13 +112,13 @@ class TestPartial(unittest.TestCase): def test_error_propagation(self): def f(x, y): x / y - self.assertRaises(ZeroDivisionError, self.thetype(f, 1, 0)) - self.assertRaises(ZeroDivisionError, self.thetype(f, 1), 0) - self.assertRaises(ZeroDivisionError, self.thetype(f), 1, 0) - self.assertRaises(ZeroDivisionError, self.thetype(f, y=0), 1) + self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0)) + self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0) + self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0) + self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1) def test_weakref(self): - f = self.thetype(int, base=16) + f = self.partial(int, base=16) p = proxy(f) self.assertEqual(f.func, p.func) f = None @@ -140,39 +126,61 @@ class TestPartial(unittest.TestCase): def test_with_bound_and_unbound_methods(self): data = list(map(str, range(10))) - join = self.thetype(str.join, '') + join = self.partial(str.join, '') self.assertEqual(join(data), '0123456789') - join = self.thetype(''.join) + join = self.partial(''.join) self.assertEqual(join(data), '0123456789') + +@unittest.skipUnless(c_functools, 'requires the C _functools module') +class TestPartialC(TestPartial, unittest.TestCase): + if c_functools: + partial = c_functools.partial + + def test_attributes_unwritable(self): + # attributes should not be writable + p = self.partial(capture, 1, 2, a=10, b=20) + self.assertRaises(AttributeError, setattr, p, 'func', map) + self.assertRaises(AttributeError, setattr, p, 'args', (1, 2)) + self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2)) + + p = self.partial(hex) + try: + del p.__dict__ + except TypeError: + pass + else: + self.fail('partial object allowed __dict__ to be deleted') + def test_repr(self): args = (object(), object()) args_repr = ', '.join(repr(a) for a in args) - kwargs = {'a': object(), 'b': object()} + #kwargs = {'a': object(), 'b': object()} + kwargs = {'a': object()} kwargs_repr = ', '.join("%s=%r" % (k, v) for k, v in kwargs.items()) - if self.thetype is functools.partial: + if self.partial is c_functools.partial: name = 'functools.partial' else: - name = self.thetype.__name__ + name = self.partial.__name__ - f = self.thetype(capture) + f = self.partial(capture) self.assertEqual('{}({!r})'.format(name, capture), repr(f)) - f = self.thetype(capture, *args) + f = self.partial(capture, *args) self.assertEqual('{}({!r}, {})'.format(name, capture, args_repr), repr(f)) - f = self.thetype(capture, **kwargs) + f = self.partial(capture, **kwargs) self.assertEqual('{}({!r}, {})'.format(name, capture, kwargs_repr), repr(f)) - f = self.thetype(capture, *args, **kwargs) + f = self.partial(capture, *args, **kwargs) self.assertEqual('{}({!r}, {}, {})'.format(name, capture, args_repr, kwargs_repr), repr(f)) def test_pickle(self): - f = self.thetype(signature, 'asdf', bar=True) + f = self.partial(signature, 'asdf', bar=True) f.add_something_to__dict__ = True f_copy = pickle.loads(pickle.dumps(f)) self.assertEqual(signature(f), signature(f_copy)) @@ -191,30 +199,140 @@ class TestPartial(unittest.TestCase): return {} raise IndexError - f = self.thetype(object) + f = self.partial(object) self.assertRaisesRegex(SystemError, "new style getargs format but argument is not a tuple", f.__setstate__, BadSequence()) -class PartialSubclass(functools.partial): - pass -class TestPartialSubclass(TestPartial): +class TestPartialPy(TestPartial, unittest.TestCase): + partial = staticmethod(py_functools.partial) + + +if c_functools: + class PartialSubclass(c_functools.partial): + pass - thetype = PartialSubclass -class TestPythonPartial(TestPartial): +@unittest.skipUnless(c_functools, 'requires the C _functools module') +class TestPartialCSubclass(TestPartialC): + if c_functools: + partial = PartialSubclass - thetype = PythonPartial - # the python version hasn't a nice repr - test_repr = None +class TestPartialMethod(unittest.TestCase): - # the python version isn't picklable - test_pickle = test_setstate_refcount = None + class A(object): + nothing = functools.partialmethod(capture) + positional = functools.partialmethod(capture, 1) + keywords = functools.partialmethod(capture, a=2) + both = functools.partialmethod(capture, 3, b=4) + + nested = functools.partialmethod(positional, 5) + + over_partial = functools.partialmethod(functools.partial(capture, c=6), 7) + + static = functools.partialmethod(staticmethod(capture), 8) + cls = functools.partialmethod(classmethod(capture), d=9) + + a = A() + + def test_arg_combinations(self): + self.assertEqual(self.a.nothing(), ((self.a,), {})) + self.assertEqual(self.a.nothing(5), ((self.a, 5), {})) + self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6})) + self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6})) + + self.assertEqual(self.a.positional(), ((self.a, 1), {})) + self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {})) + self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6})) + self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6})) + + self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2})) + self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2})) + self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6})) + self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6})) + + self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4})) + self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4})) + self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6})) + self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6})) + + self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6})) + + def test_nested(self): + self.assertEqual(self.a.nested(), ((self.a, 1, 5), {})) + self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {})) + self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7})) + self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7})) + + self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7})) + + def test_over_partial(self): + self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6})) + self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6})) + self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8})) + self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8})) + + self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8})) + + def test_bound_method_introspection(self): + obj = self.a + self.assertIs(obj.both.__self__, obj) + self.assertIs(obj.nested.__self__, obj) + self.assertIs(obj.over_partial.__self__, obj) + self.assertIs(obj.cls.__self__, self.A) + self.assertIs(self.A.cls.__self__, self.A) + + def test_unbound_method_retrieval(self): + obj = self.A + self.assertFalse(hasattr(obj.both, "__self__")) + self.assertFalse(hasattr(obj.nested, "__self__")) + self.assertFalse(hasattr(obj.over_partial, "__self__")) + self.assertFalse(hasattr(obj.static, "__self__")) + self.assertFalse(hasattr(self.a.static, "__self__")) + + def test_descriptors(self): + for obj in [self.A, self.a]: + with self.subTest(obj=obj): + self.assertEqual(obj.static(), ((8,), {})) + self.assertEqual(obj.static(5), ((8, 5), {})) + self.assertEqual(obj.static(d=8), ((8,), {'d': 8})) + self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8})) + + self.assertEqual(obj.cls(), ((self.A,), {'d': 9})) + self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9})) + self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9})) + self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9})) + + def test_overriding_keywords(self): + self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3})) + self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3})) + + def test_invalid_args(self): + with self.assertRaises(TypeError): + class B(object): + method = functools.partialmethod(None, 1) + + def test_repr(self): + self.assertEqual(repr(vars(self.A)['both']), + 'functools.partialmethod({}, 3, b=4)'.format(capture)) + + def test_abstract(self): + class Abstract(abc.ABCMeta): + + @abc.abstractmethod + def add(self, x, y): + pass + + add5 = functools.partialmethod(add, 5) + + self.assertTrue(Abstract.add.__isabstractmethod__) + self.assertTrue(Abstract.add5.__isabstractmethod__) + + for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]: + self.assertFalse(getattr(func, '__isabstractmethod__', False)) - # the python version isn't a type - test_attributes = None class TestUpdateWrapper(unittest.TestCase): @@ -223,19 +341,26 @@ class TestUpdateWrapper(unittest.TestCase): updated=functools.WRAPPER_UPDATES): # Check attributes were assigned for name in assigned: - self.assertTrue(getattr(wrapper, name) is getattr(wrapped, name)) + self.assertIs(getattr(wrapper, name), getattr(wrapped, name)) # Check attributes were updated for name in updated: wrapper_attr = getattr(wrapper, name) wrapped_attr = getattr(wrapped, name) for key in wrapped_attr: - self.assertTrue(wrapped_attr[key] is wrapper_attr[key]) + if name == "__dict__" and key == "__wrapped__": + # __wrapped__ is overwritten by the update code + continue + self.assertIs(wrapped_attr[key], wrapper_attr[key]) + # Check __wrapped__ + self.assertIs(wrapper.__wrapped__, wrapped) + def _default_update(self): def f(a:'This is a new annotation'): """This is a test""" pass f.attr = 'This is also a test' + f.__wrapped__ = "This is a bald faced lie" def wrapper(b:'This is the prior annotation'): pass functools.update_wrapper(wrapper, f) @@ -322,6 +447,7 @@ class TestUpdateWrapper(unittest.TestCase): self.assertTrue(wrapper.__doc__.startswith('max(')) self.assertEqual(wrapper.__annotations__, {}) + class TestWraps(TestUpdateWrapper): def _default_update(self): @@ -329,14 +455,15 @@ class TestWraps(TestUpdateWrapper): """This is a test""" pass f.attr = 'This is also a test' + f.__wrapped__ = "This is still a bald faced lie" @functools.wraps(f) def wrapper(): pass - self.check_wrapper(wrapper, f) return wrapper, f def test_default_update(self): wrapper, f = self._default_update() + self.check_wrapper(wrapper, f) self.assertEqual(wrapper.__name__, 'f') self.assertEqual(wrapper.__qualname__, f.__qualname__) self.assertEqual(wrapper.attr, 'This is also a test') @@ -382,6 +509,7 @@ class TestWraps(TestUpdateWrapper): self.assertEqual(wrapper.attr, 'This is a different test') self.assertEqual(wrapper.dict_attr, f.dict_attr) + class TestReduce(unittest.TestCase): func = functools.reduce @@ -462,24 +590,29 @@ class TestReduce(unittest.TestCase): d = {"one": 1, "two": 2, "three": 3} self.assertEqual(self.func(add, d), "".join(d.keys())) -class TestCmpToKey(unittest.TestCase): + +class TestCmpToKey: def test_cmp_to_key(self): def cmp1(x, y): return (x > y) - (x < y) - key = functools.cmp_to_key(cmp1) + key = self.cmp_to_key(cmp1) self.assertEqual(key(3), key(3)) self.assertGreater(key(3), key(1)) + self.assertGreaterEqual(key(3), key(3)) + def cmp2(x, y): return int(x) - int(y) - key = functools.cmp_to_key(cmp2) + key = self.cmp_to_key(cmp2) self.assertEqual(key(4.0), key('4')) self.assertLess(key(2), key('35')) + self.assertLessEqual(key(2), key('35')) + self.assertNotEqual(key(2), key('35')) def test_cmp_to_key_arguments(self): def cmp1(x, y): return (x > y) - (x < y) - key = functools.cmp_to_key(mycmp=cmp1) + key = self.cmp_to_key(mycmp=cmp1) self.assertEqual(key(obj=3), key(obj=3)) self.assertGreater(key(obj=3), key(obj=1)) with self.assertRaises((TypeError, AttributeError)): @@ -487,10 +620,10 @@ class TestCmpToKey(unittest.TestCase): with self.assertRaises((TypeError, AttributeError)): 1 < key(3) # lhs is not a K object with self.assertRaises(TypeError): - key = functools.cmp_to_key() # too few args + key = self.cmp_to_key() # too few args with self.assertRaises(TypeError): - key = functools.cmp_to_key(cmp1, None) # too many args - key = functools.cmp_to_key(cmp1) + key = self.cmp_to_key(cmp1, None) # too many args + key = self.cmp_to_key(cmp1) with self.assertRaises(TypeError): key() # too few args with self.assertRaises(TypeError): @@ -499,7 +632,7 @@ class TestCmpToKey(unittest.TestCase): def test_bad_cmp(self): def cmp1(x, y): raise ZeroDivisionError - key = functools.cmp_to_key(cmp1) + key = self.cmp_to_key(cmp1) with self.assertRaises(ZeroDivisionError): key(3) > key(1) @@ -514,13 +647,13 @@ class TestCmpToKey(unittest.TestCase): def test_obj_field(self): def cmp1(x, y): return (x > y) - (x < y) - key = functools.cmp_to_key(mycmp=cmp1) + key = self.cmp_to_key(mycmp=cmp1) self.assertEqual(key(50).obj, 50) def test_sort_int(self): def mycmp(x, y): return y - x - self.assertEqual(sorted(range(5), key=functools.cmp_to_key(mycmp)), + self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)), [4, 3, 2, 1, 0]) def test_sort_int_str(self): @@ -528,18 +661,29 @@ class TestCmpToKey(unittest.TestCase): x, y = int(x), int(y) return (x > y) - (x < y) values = [5, '3', 7, 2, '0', '1', 4, '10', 1] - values = sorted(values, key=functools.cmp_to_key(mycmp)) + values = sorted(values, key=self.cmp_to_key(mycmp)) self.assertEqual([int(value) for value in values], [0, 1, 1, 2, 3, 4, 5, 7, 10]) def test_hash(self): def mycmp(x, y): return y - x - key = functools.cmp_to_key(mycmp) + key = self.cmp_to_key(mycmp) k = key(10) self.assertRaises(TypeError, hash, k) self.assertNotIsInstance(k, collections.Hashable) + +@unittest.skipUnless(c_functools, 'requires the C _functools module') +class TestCmpToKeyC(TestCmpToKey, unittest.TestCase): + if c_functools: + cmp_to_key = c_functools.cmp_to_key + + +class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase): + cmp_to_key = staticmethod(py_functools.cmp_to_key) + + class TestTotalOrdering(unittest.TestCase): def test_total_ordering_lt(self): @@ -557,6 +701,7 @@ class TestTotalOrdering(unittest.TestCase): self.assertTrue(A(2) >= A(1)) self.assertTrue(A(2) <= A(2)) self.assertTrue(A(2) >= A(2)) + self.assertFalse(A(1) > A(2)) def test_total_ordering_le(self): @functools.total_ordering @@ -573,6 +718,7 @@ class TestTotalOrdering(unittest.TestCase): self.assertTrue(A(2) >= A(1)) self.assertTrue(A(2) <= A(2)) self.assertTrue(A(2) >= A(2)) + self.assertFalse(A(1) >= A(2)) def test_total_ordering_gt(self): @functools.total_ordering @@ -589,6 +735,7 @@ class TestTotalOrdering(unittest.TestCase): self.assertTrue(A(2) >= A(1)) self.assertTrue(A(2) <= A(2)) self.assertTrue(A(2) >= A(2)) + self.assertFalse(A(2) < A(1)) def test_total_ordering_ge(self): @functools.total_ordering @@ -605,6 +752,7 @@ class TestTotalOrdering(unittest.TestCase): self.assertTrue(A(2) >= A(1)) self.assertTrue(A(2) <= A(2)) self.assertTrue(A(2) >= A(2)) + self.assertFalse(A(2) <= A(1)) def test_total_ordering_no_overwrite(self): # new methods should not overwrite existing @@ -624,27 +772,118 @@ class TestTotalOrdering(unittest.TestCase): class A: pass - def test_bug_10042(self): + def test_type_error_when_not_implemented(self): + # bug 10042; ensure stack overflow does not occur + # when decorated types return NotImplemented @functools.total_ordering - class TestTO: + class ImplementsLessThan: def __init__(self, value): self.value = value def __eq__(self, other): - if isinstance(other, TestTO): + if isinstance(other, ImplementsLessThan): return self.value == other.value return False def __lt__(self, other): - if isinstance(other, TestTO): + if isinstance(other, ImplementsLessThan): return self.value < other.value - raise TypeError - with self.assertRaises(TypeError): - TestTO(8) <= () + return NotImplemented + + @functools.total_ordering + class ImplementsGreaterThan: + def __init__(self, value): + self.value = value + def __eq__(self, other): + if isinstance(other, ImplementsGreaterThan): + return self.value == other.value + return False + def __gt__(self, other): + if isinstance(other, ImplementsGreaterThan): + return self.value > other.value + return NotImplemented + + @functools.total_ordering + class ImplementsLessThanEqualTo: + def __init__(self, value): + self.value = value + def __eq__(self, other): + if isinstance(other, ImplementsLessThanEqualTo): + return self.value == other.value + return False + def __le__(self, other): + if isinstance(other, ImplementsLessThanEqualTo): + return self.value <= other.value + return NotImplemented + + @functools.total_ordering + class ImplementsGreaterThanEqualTo: + def __init__(self, value): + self.value = value + def __eq__(self, other): + if isinstance(other, ImplementsGreaterThanEqualTo): + return self.value == other.value + return False + def __ge__(self, other): + if isinstance(other, ImplementsGreaterThanEqualTo): + return self.value >= other.value + return NotImplemented + + @functools.total_ordering + class ComparatorNotImplemented: + def __init__(self, value): + self.value = value + def __eq__(self, other): + if isinstance(other, ComparatorNotImplemented): + return self.value == other.value + return False + def __lt__(self, other): + return NotImplemented + + with self.subTest("LT < 1"), self.assertRaises(TypeError): + ImplementsLessThan(-1) < 1 + + with self.subTest("LT < LE"), self.assertRaises(TypeError): + ImplementsLessThan(0) < ImplementsLessThanEqualTo(0) + + with self.subTest("LT < GT"), self.assertRaises(TypeError): + ImplementsLessThan(1) < ImplementsGreaterThan(1) + + with self.subTest("LE <= LT"), self.assertRaises(TypeError): + ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2) + + with self.subTest("LE <= GE"), self.assertRaises(TypeError): + ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3) + + with self.subTest("GT > GE"), self.assertRaises(TypeError): + ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4) + + with self.subTest("GT > LT"), self.assertRaises(TypeError): + ImplementsGreaterThan(5) > ImplementsLessThan(5) + + with self.subTest("GE >= GT"), self.assertRaises(TypeError): + ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6) + + with self.subTest("GE >= LE"), self.assertRaises(TypeError): + ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7) + + with self.subTest("GE when equal"): + a = ComparatorNotImplemented(8) + b = ComparatorNotImplemented(8) + self.assertEqual(a, b) + with self.assertRaises(TypeError): + a >= b + + with self.subTest("LE when equal"): + a = ComparatorNotImplemented(9) + b = ComparatorNotImplemented(9) + self.assertEqual(a, b) + with self.assertRaises(TypeError): + a <= b class TestLRU(unittest.TestCase): def test_lru(self): def orig(x, y): - return 3*x+y + return 3 * x + y f = functools.lru_cache(maxsize=20)(orig) hits, misses, maxsize, currsize = f.cache_info() self.assertEqual(maxsize, 20) @@ -749,7 +988,7 @@ class TestLRU(unittest.TestCase): # Verify that user_function exceptions get passed through without # creating a hard-to-read chained exception. # http://bugs.python.org/issue13177 - for maxsize in (None, 100): + for maxsize in (None, 128): @functools.lru_cache(maxsize) def func(i): return 'abc'[i] @@ -762,7 +1001,7 @@ class TestLRU(unittest.TestCase): func(15) def test_lru_with_types(self): - for maxsize in (None, 100): + for maxsize in (None, 128): @functools.lru_cache(maxsize=maxsize, typed=True) def square(x): return x * x @@ -777,6 +1016,36 @@ class TestLRU(unittest.TestCase): self.assertEqual(square.cache_info().hits, 4) self.assertEqual(square.cache_info().misses, 4) + def test_lru_with_keyword_args(self): + @functools.lru_cache() + def fib(n): + if n < 2: + return n + return fib(n=n-1) + fib(n=n-2) + self.assertEqual( + [fib(n=number) for number in range(16)], + [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610] + ) + self.assertEqual(fib.cache_info(), + functools._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16)) + fib.cache_clear() + self.assertEqual(fib.cache_info(), + functools._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0)) + + def test_lru_with_keyword_args_maxsize_none(self): + @functools.lru_cache(maxsize=None) + def fib(n): + if n < 2: + return n + return fib(n=n-1) + fib(n=n-2) + self.assertEqual([fib(n=number) for number in range(16)], + [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]) + self.assertEqual(fib.cache_info(), + functools._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16)) + fib.cache_clear() + self.assertEqual(fib.cache_info(), + functools._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)) + def test_need_for_rlock(self): # This will deadlock on an LRU cache that uses a regular lock @@ -802,17 +1071,495 @@ class TestLRU(unittest.TestCase): DoubleEq(2)) # Verify the correct return value +class TestSingleDispatch(unittest.TestCase): + def test_simple_overloads(self): + @functools.singledispatch + def g(obj): + return "base" + def g_int(i): + return "integer" + g.register(int, g_int) + self.assertEqual(g("str"), "base") + self.assertEqual(g(1), "integer") + self.assertEqual(g([1,2,3]), "base") + + def test_mro(self): + @functools.singledispatch + def g(obj): + return "base" + class A: + pass + class C(A): + pass + class B(A): + pass + class D(C, B): + pass + def g_A(a): + return "A" + def g_B(b): + return "B" + g.register(A, g_A) + g.register(B, g_B) + self.assertEqual(g(A()), "A") + self.assertEqual(g(B()), "B") + self.assertEqual(g(C()), "A") + self.assertEqual(g(D()), "B") + + def test_register_decorator(self): + @functools.singledispatch + def g(obj): + return "base" + @g.register(int) + def g_int(i): + return "int %s" % (i,) + self.assertEqual(g(""), "base") + self.assertEqual(g(12), "int 12") + self.assertIs(g.dispatch(int), g_int) + self.assertIs(g.dispatch(object), g.dispatch(str)) + # Note: in the assert above this is not g. + # @singledispatch returns the wrapper. + + def test_wrapping_attributes(self): + @functools.singledispatch + def g(obj): + "Simple test" + return "Test" + self.assertEqual(g.__name__, "g") + if sys.flags.optimize < 2: + self.assertEqual(g.__doc__, "Simple test") + + @unittest.skipUnless(decimal, 'requires _decimal') + @support.cpython_only + def test_c_classes(self): + @functools.singledispatch + def g(obj): + return "base" + @g.register(decimal.DecimalException) + def _(obj): + return obj.args + subn = decimal.Subnormal("Exponent < Emin") + rnd = decimal.Rounded("Number got rounded") + self.assertEqual(g(subn), ("Exponent < Emin",)) + self.assertEqual(g(rnd), ("Number got rounded",)) + @g.register(decimal.Subnormal) + def _(obj): + return "Too small to care." + self.assertEqual(g(subn), "Too small to care.") + self.assertEqual(g(rnd), ("Number got rounded",)) + + def test_compose_mro(self): + # None of the examples in this test depend on haystack ordering. + c = collections + mro = functools._compose_mro + bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set] + for haystack in permutations(bases): + m = mro(dict, haystack) + self.assertEqual(m, [dict, c.MutableMapping, c.Mapping, c.Sized, + c.Iterable, c.Container, object]) + bases = [c.Container, c.Mapping, c.MutableMapping, c.OrderedDict] + for haystack in permutations(bases): + m = mro(c.ChainMap, haystack) + self.assertEqual(m, [c.ChainMap, c.MutableMapping, c.Mapping, + c.Sized, c.Iterable, c.Container, object]) + + # If there's a generic function with implementations registered for + # both Sized and Container, passing a defaultdict to it results in an + # ambiguous dispatch which will cause a RuntimeError (see + # test_mro_conflicts). + bases = [c.Container, c.Sized, str] + for haystack in permutations(bases): + m = mro(c.defaultdict, [c.Sized, c.Container, str]) + self.assertEqual(m, [c.defaultdict, dict, c.Sized, c.Container, + object]) + + # MutableSequence below is registered directly on D. In other words, it + # preceeds MutableMapping which means single dispatch will always + # choose MutableSequence here. + class D(c.defaultdict): + pass + c.MutableSequence.register(D) + bases = [c.MutableSequence, c.MutableMapping] + for haystack in permutations(bases): + m = mro(D, bases) + self.assertEqual(m, [D, c.MutableSequence, c.Sequence, + c.defaultdict, dict, c.MutableMapping, + c.Mapping, c.Sized, c.Iterable, c.Container, + object]) + + # Container and Callable are registered on different base classes and + # a generic function supporting both should always pick the Callable + # implementation if a C instance is passed. + class C(c.defaultdict): + def __call__(self): + pass + bases = [c.Sized, c.Callable, c.Container, c.Mapping] + for haystack in permutations(bases): + m = mro(C, haystack) + self.assertEqual(m, [C, c.Callable, c.defaultdict, dict, c.Mapping, + c.Sized, c.Iterable, c.Container, object]) + + def test_register_abc(self): + c = collections + d = {"a": "b"} + l = [1, 2, 3] + s = {object(), None} + f = frozenset(s) + t = (1, 2, 3) + @functools.singledispatch + def g(obj): + return "base" + self.assertEqual(g(d), "base") + self.assertEqual(g(l), "base") + self.assertEqual(g(s), "base") + self.assertEqual(g(f), "base") + self.assertEqual(g(t), "base") + g.register(c.Sized, lambda obj: "sized") + self.assertEqual(g(d), "sized") + self.assertEqual(g(l), "sized") + self.assertEqual(g(s), "sized") + self.assertEqual(g(f), "sized") + self.assertEqual(g(t), "sized") + g.register(c.MutableMapping, lambda obj: "mutablemapping") + self.assertEqual(g(d), "mutablemapping") + self.assertEqual(g(l), "sized") + self.assertEqual(g(s), "sized") + self.assertEqual(g(f), "sized") + self.assertEqual(g(t), "sized") + g.register(c.ChainMap, lambda obj: "chainmap") + self.assertEqual(g(d), "mutablemapping") # irrelevant ABCs registered + self.assertEqual(g(l), "sized") + self.assertEqual(g(s), "sized") + self.assertEqual(g(f), "sized") + self.assertEqual(g(t), "sized") + g.register(c.MutableSequence, lambda obj: "mutablesequence") + self.assertEqual(g(d), "mutablemapping") + self.assertEqual(g(l), "mutablesequence") + self.assertEqual(g(s), "sized") + self.assertEqual(g(f), "sized") + self.assertEqual(g(t), "sized") + g.register(c.MutableSet, lambda obj: "mutableset") + self.assertEqual(g(d), "mutablemapping") + self.assertEqual(g(l), "mutablesequence") + self.assertEqual(g(s), "mutableset") + self.assertEqual(g(f), "sized") + self.assertEqual(g(t), "sized") + g.register(c.Mapping, lambda obj: "mapping") + self.assertEqual(g(d), "mutablemapping") # not specific enough + self.assertEqual(g(l), "mutablesequence") + self.assertEqual(g(s), "mutableset") + self.assertEqual(g(f), "sized") + self.assertEqual(g(t), "sized") + g.register(c.Sequence, lambda obj: "sequence") + self.assertEqual(g(d), "mutablemapping") + self.assertEqual(g(l), "mutablesequence") + self.assertEqual(g(s), "mutableset") + self.assertEqual(g(f), "sized") + self.assertEqual(g(t), "sequence") + g.register(c.Set, lambda obj: "set") + self.assertEqual(g(d), "mutablemapping") + self.assertEqual(g(l), "mutablesequence") + self.assertEqual(g(s), "mutableset") + self.assertEqual(g(f), "set") + self.assertEqual(g(t), "sequence") + g.register(dict, lambda obj: "dict") + self.assertEqual(g(d), "dict") + self.assertEqual(g(l), "mutablesequence") + self.assertEqual(g(s), "mutableset") + self.assertEqual(g(f), "set") + self.assertEqual(g(t), "sequence") + g.register(list, lambda obj: "list") + self.assertEqual(g(d), "dict") + self.assertEqual(g(l), "list") + self.assertEqual(g(s), "mutableset") + self.assertEqual(g(f), "set") + self.assertEqual(g(t), "sequence") + g.register(set, lambda obj: "concrete-set") + self.assertEqual(g(d), "dict") + self.assertEqual(g(l), "list") + self.assertEqual(g(s), "concrete-set") + self.assertEqual(g(f), "set") + self.assertEqual(g(t), "sequence") + g.register(frozenset, lambda obj: "frozen-set") + self.assertEqual(g(d), "dict") + self.assertEqual(g(l), "list") + self.assertEqual(g(s), "concrete-set") + self.assertEqual(g(f), "frozen-set") + self.assertEqual(g(t), "sequence") + g.register(tuple, lambda obj: "tuple") + self.assertEqual(g(d), "dict") + self.assertEqual(g(l), "list") + self.assertEqual(g(s), "concrete-set") + self.assertEqual(g(f), "frozen-set") + self.assertEqual(g(t), "tuple") + + def test_c3_abc(self): + c = collections + mro = functools._c3_mro + class A(object): + pass + class B(A): + def __len__(self): + return 0 # implies Sized + @c.Container.register + class C(object): + pass + class D(object): + pass # unrelated + class X(D, C, B): + def __call__(self): + pass # implies Callable + expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object] + for abcs in permutations([c.Sized, c.Callable, c.Container]): + self.assertEqual(mro(X, abcs=abcs), expected) + # unrelated ABCs don't appear in the resulting MRO + many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable] + self.assertEqual(mro(X, abcs=many_abcs), expected) + + def test_mro_conflicts(self): + c = collections + @functools.singledispatch + def g(arg): + return "base" + class O(c.Sized): + def __len__(self): + return 0 + o = O() + self.assertEqual(g(o), "base") + g.register(c.Iterable, lambda arg: "iterable") + g.register(c.Container, lambda arg: "container") + g.register(c.Sized, lambda arg: "sized") + g.register(c.Set, lambda arg: "set") + self.assertEqual(g(o), "sized") + c.Iterable.register(O) + self.assertEqual(g(o), "sized") # because it's explicitly in __mro__ + c.Container.register(O) + self.assertEqual(g(o), "sized") # see above: Sized is in __mro__ + c.Set.register(O) + self.assertEqual(g(o), "set") # because c.Set is a subclass of + # c.Sized and c.Container + class P: + pass + p = P() + self.assertEqual(g(p), "base") + c.Iterable.register(P) + self.assertEqual(g(p), "iterable") + c.Container.register(P) + with self.assertRaises(RuntimeError) as re_one: + g(p) + self.assertIn( + str(re_one.exception), + (("Ambiguous dispatch: <class 'collections.abc.Container'> " + "or <class 'collections.abc.Iterable'>"), + ("Ambiguous dispatch: <class 'collections.abc.Iterable'> " + "or <class 'collections.abc.Container'>")), + ) + class Q(c.Sized): + def __len__(self): + return 0 + q = Q() + self.assertEqual(g(q), "sized") + c.Iterable.register(Q) + self.assertEqual(g(q), "sized") # because it's explicitly in __mro__ + c.Set.register(Q) + self.assertEqual(g(q), "set") # because c.Set is a subclass of + # c.Sized and c.Iterable + @functools.singledispatch + def h(arg): + return "base" + @h.register(c.Sized) + def _(arg): + return "sized" + @h.register(c.Container) + def _(arg): + return "container" + # Even though Sized and Container are explicit bases of MutableMapping, + # this ABC is implicitly registered on defaultdict which makes all of + # MutableMapping's bases implicit as well from defaultdict's + # perspective. + with self.assertRaises(RuntimeError) as re_two: + h(c.defaultdict(lambda: 0)) + self.assertIn( + str(re_two.exception), + (("Ambiguous dispatch: <class 'collections.abc.Container'> " + "or <class 'collections.abc.Sized'>"), + ("Ambiguous dispatch: <class 'collections.abc.Sized'> " + "or <class 'collections.abc.Container'>")), + ) + class R(c.defaultdict): + pass + c.MutableSequence.register(R) + @functools.singledispatch + def i(arg): + return "base" + @i.register(c.MutableMapping) + def _(arg): + return "mapping" + @i.register(c.MutableSequence) + def _(arg): + return "sequence" + r = R() + self.assertEqual(i(r), "sequence") + class S: + pass + class T(S, c.Sized): + def __len__(self): + return 0 + t = T() + self.assertEqual(h(t), "sized") + c.Container.register(T) + self.assertEqual(h(t), "sized") # because it's explicitly in the MRO + class U: + def __len__(self): + return 0 + u = U() + self.assertEqual(h(u), "sized") # implicit Sized subclass inferred + # from the existence of __len__() + c.Container.register(U) + # There is no preference for registered versus inferred ABCs. + with self.assertRaises(RuntimeError) as re_three: + h(u) + self.assertIn( + str(re_three.exception), + (("Ambiguous dispatch: <class 'collections.abc.Container'> " + "or <class 'collections.abc.Sized'>"), + ("Ambiguous dispatch: <class 'collections.abc.Sized'> " + "or <class 'collections.abc.Container'>")), + ) + class V(c.Sized, S): + def __len__(self): + return 0 + @functools.singledispatch + def j(arg): + return "base" + @j.register(S) + def _(arg): + return "s" + @j.register(c.Container) + def _(arg): + return "container" + v = V() + self.assertEqual(j(v), "s") + c.Container.register(V) + self.assertEqual(j(v), "container") # because it ends up right after + # Sized in the MRO + + def test_cache_invalidation(self): + from collections import UserDict + class TracingDict(UserDict): + def __init__(self, *args, **kwargs): + super(TracingDict, self).__init__(*args, **kwargs) + self.set_ops = [] + self.get_ops = [] + def __getitem__(self, key): + result = self.data[key] + self.get_ops.append(key) + return result + def __setitem__(self, key, value): + self.set_ops.append(key) + self.data[key] = value + def clear(self): + self.data.clear() + _orig_wkd = functools.WeakKeyDictionary + td = TracingDict() + functools.WeakKeyDictionary = lambda: td + c = collections + @functools.singledispatch + def g(arg): + return "base" + d = {} + l = [] + self.assertEqual(len(td), 0) + self.assertEqual(g(d), "base") + self.assertEqual(len(td), 1) + self.assertEqual(td.get_ops, []) + self.assertEqual(td.set_ops, [dict]) + self.assertEqual(td.data[dict], g.registry[object]) + self.assertEqual(g(l), "base") + self.assertEqual(len(td), 2) + self.assertEqual(td.get_ops, []) + self.assertEqual(td.set_ops, [dict, list]) + self.assertEqual(td.data[dict], g.registry[object]) + self.assertEqual(td.data[list], g.registry[object]) + self.assertEqual(td.data[dict], td.data[list]) + self.assertEqual(g(l), "base") + self.assertEqual(g(d), "base") + self.assertEqual(td.get_ops, [list, dict]) + self.assertEqual(td.set_ops, [dict, list]) + g.register(list, lambda arg: "list") + self.assertEqual(td.get_ops, [list, dict]) + self.assertEqual(len(td), 0) + self.assertEqual(g(d), "base") + self.assertEqual(len(td), 1) + self.assertEqual(td.get_ops, [list, dict]) + self.assertEqual(td.set_ops, [dict, list, dict]) + self.assertEqual(td.data[dict], + functools._find_impl(dict, g.registry)) + self.assertEqual(g(l), "list") + self.assertEqual(len(td), 2) + self.assertEqual(td.get_ops, [list, dict]) + self.assertEqual(td.set_ops, [dict, list, dict, list]) + self.assertEqual(td.data[list], + functools._find_impl(list, g.registry)) + class X: + pass + c.MutableMapping.register(X) # Will not invalidate the cache, + # not using ABCs yet. + self.assertEqual(g(d), "base") + self.assertEqual(g(l), "list") + self.assertEqual(td.get_ops, [list, dict, dict, list]) + self.assertEqual(td.set_ops, [dict, list, dict, list]) + g.register(c.Sized, lambda arg: "sized") + self.assertEqual(len(td), 0) + self.assertEqual(g(d), "sized") + self.assertEqual(len(td), 1) + self.assertEqual(td.get_ops, [list, dict, dict, list]) + self.assertEqual(td.set_ops, [dict, list, dict, list, dict]) + self.assertEqual(g(l), "list") + self.assertEqual(len(td), 2) + self.assertEqual(td.get_ops, [list, dict, dict, list]) + self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list]) + self.assertEqual(g(l), "list") + self.assertEqual(g(d), "sized") + self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict]) + self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list]) + g.dispatch(list) + g.dispatch(dict) + self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict, + list, dict]) + self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list]) + c.MutableSet.register(X) # Will invalidate the cache. + self.assertEqual(len(td), 2) # Stale cache. + self.assertEqual(g(l), "list") + self.assertEqual(len(td), 1) + g.register(c.MutableMapping, lambda arg: "mutablemapping") + self.assertEqual(len(td), 0) + self.assertEqual(g(d), "mutablemapping") + self.assertEqual(len(td), 1) + self.assertEqual(g(l), "list") + self.assertEqual(len(td), 2) + g.register(dict, lambda arg: "dict") + self.assertEqual(g(d), "dict") + self.assertEqual(g(l), "list") + g._clear_cache() + self.assertEqual(len(td), 0) + functools.WeakKeyDictionary = _orig_wkd + + def test_main(verbose=None): test_classes = ( - TestPartial, - TestPartialSubclass, - TestPythonPartial, + TestPartialC, + TestPartialPy, + TestPartialCSubclass, + TestPartialMethod, TestUpdateWrapper, TestTotalOrdering, - TestCmpToKey, + TestCmpToKeyC, + TestCmpToKeyPy, TestWraps, TestReduce, TestLRU, + TestSingleDispatch, ) support.run_unittest(*test_classes) diff --git a/Lib/test/test_future.py b/Lib/test/test_future.py index a0c156f5f7..beac993e4d 100644 --- a/Lib/test/test_future.py +++ b/Lib/test/test_future.py @@ -82,6 +82,14 @@ class FutureTest(unittest.TestCase): else: self.fail("expected exception didn't occur") + def test_badfuture10(self): + try: + from test import badsyntax_future10 + except SyntaxError as msg: + self.assertEqual(get_error_location(msg), ("badsyntax_future10", '3')) + else: + self.fail("expected exception didn't occur") + def test_parserhack(self): # test that the parser.c::future_hack function works as expected # Note: although this test must pass, it's not testing the original diff --git a/Lib/test/test_gc.py b/Lib/test/test_gc.py index c59b72eacf..7eb104a6b5 100644 --- a/Lib/test/test_gc.py +++ b/Lib/test/test_gc.py @@ -1,6 +1,8 @@ import unittest from test.support import (verbose, refcount_test, run_unittest, - strip_python_stderr) + strip_python_stderr, cpython_only) +from test.script_helper import assert_python_ok, make_script, temp_dir + import sys import time import gc @@ -11,6 +13,15 @@ try: except ImportError: threading = None +try: + from _testcapi import with_tp_del +except ImportError: + def with_tp_del(cls): + class C(object): + def __new__(cls, *args, **kwargs): + raise TypeError('requires _testcapi.with_tp_del') + return C + ### Support code ############################################################################### @@ -38,6 +49,7 @@ class GC_Detector(object): # gc collects it. self.wr = weakref.ref(C1055820(666), it_happened) +@with_tp_del class Uncollectable(object): """Create a reference cycle with multiple __del__ methods. @@ -50,7 +62,7 @@ class Uncollectable(object): self.partner = Uncollectable(partner=self) else: self.partner = partner - def __del__(self): + def __tp_del__(self): pass ### Tests @@ -139,11 +151,13 @@ class GCTests(unittest.TestCase): del a self.assertNotEqual(gc.collect(), 0) - def test_finalizer(self): + @cpython_only + def test_legacy_finalizer(self): # A() is uncollectable if it is part of a cycle, make sure it shows up # in gc.garbage. + @with_tp_del class A: - def __del__(self): pass + def __tp_del__(self): pass class B: pass a = A() @@ -163,11 +177,13 @@ class GCTests(unittest.TestCase): self.fail("didn't find obj in garbage (finalizer)") gc.garbage.remove(obj) - def test_finalizer_newclass(self): + @cpython_only + def test_legacy_finalizer_newclass(self): # A() is uncollectable if it is part of a cycle, make sure it shows up # in gc.garbage. + @with_tp_del class A(object): - def __del__(self): pass + def __tp_del__(self): pass class B(object): pass a = A() @@ -564,16 +580,19 @@ class GCTests(unittest.TestCase): # would be damaged, with an empty __dict__. self.assertEqual(x, None) + @cpython_only def test_garbage_at_shutdown(self): import subprocess code = """if 1: import gc + import _testcapi + @_testcapi.with_tp_del class X: def __init__(self, name): self.name = name def __repr__(self): return "<X %%r>" %% self.name - def __del__(self): + def __tp_del__(self): pass x = X('first') @@ -610,6 +629,66 @@ class GCTests(unittest.TestCase): stderr = run_command(code % "gc.DEBUG_SAVEALL") self.assertNotIn(b"uncollectable objects at shutdown", stderr) + def test_gc_main_module_at_shutdown(self): + # Create a reference cycle through the __main__ module and check + # it gets collected at interpreter shutdown. + code = """if 1: + import weakref + class C: + def __del__(self): + print('__del__ called') + l = [C()] + l.append(l) + """ + rc, out, err = assert_python_ok('-c', code) + self.assertEqual(out.strip(), b'__del__ called') + + def test_gc_ordinary_module_at_shutdown(self): + # Same as above, but with a non-__main__ module. + with temp_dir() as script_dir: + module = """if 1: + import weakref + class C: + def __del__(self): + print('__del__ called') + l = [C()] + l.append(l) + """ + code = """if 1: + import sys + sys.path.insert(0, %r) + import gctest + """ % (script_dir,) + make_script(script_dir, 'gctest', module) + rc, out, err = assert_python_ok('-c', code) + self.assertEqual(out.strip(), b'__del__ called') + + def test_get_stats(self): + stats = gc.get_stats() + self.assertEqual(len(stats), 3) + for st in stats: + self.assertIsInstance(st, dict) + self.assertEqual(set(st), + {"collected", "collections", "uncollectable"}) + self.assertGreaterEqual(st["collected"], 0) + self.assertGreaterEqual(st["collections"], 0) + self.assertGreaterEqual(st["uncollectable"], 0) + # Check that collection counts are incremented correctly + if gc.isenabled(): + self.addCleanup(gc.enable) + gc.disable() + old = gc.get_stats() + gc.collect(0) + new = gc.get_stats() + self.assertEqual(new[0]["collections"], old[0]["collections"] + 1) + self.assertEqual(new[1]["collections"], old[1]["collections"]) + self.assertEqual(new[2]["collections"], old[2]["collections"]) + gc.collect(2) + new = gc.get_stats() + self.assertEqual(new[0]["collections"], old[0]["collections"] + 1) + self.assertEqual(new[1]["collections"], old[1]["collections"]) + self.assertEqual(new[2]["collections"], old[2]["collections"] + 1) + class GCCallbackTests(unittest.TestCase): def setUp(self): @@ -696,6 +775,7 @@ class GCCallbackTests(unittest.TestCase): info = v[2] self.assertEqual(info["generation"], 2) + @cpython_only def test_collect_garbage(self): self.preclean() # Each of these cause four objects to be garbage: Two diff --git a/Lib/test/test_gdb.py b/Lib/test/test_gdb.py index 846422b837..0e254a2487 100644 --- a/Lib/test/test_gdb.py +++ b/Lib/test/test_gdb.py @@ -5,6 +5,7 @@ import os import re +import pprint import subprocess import sys import sysconfig @@ -17,6 +18,7 @@ try: except ImportError: _thread = None +from test import support from test.support import run_unittest, findfile, python_is_optimized try: @@ -40,6 +42,8 @@ if not sysconfig.is_python_build(): checkout_hook_path = os.path.join(os.path.dirname(sys.executable), 'python-gdb.py') +PYTHONHASHSEED = '123' + def run_gdb(*args, **env_vars): """Runs gdb in --batch mode with the additional arguments given by *args. @@ -144,7 +148,7 @@ class DebuggerTests(unittest.TestCase): # print (' '.join(args)) # Use "args" to invoke gdb, capturing stdout, stderr: - out, err = run_gdb(*args, PYTHONHASHSEED='0') + out, err = run_gdb(*args, PYTHONHASHSEED=PYTHONHASHSEED) errlines = err.splitlines() unexpected_errlines = [] @@ -165,6 +169,10 @@ class DebuggerTests(unittest.TestCase): 'linux-gate.so', 'Do you need "set solib-search-path" or ' '"set sysroot"?', + 'warning: Source file is more recent than executable.', + # Issue #19753: missing symbols on System Z + 'Missing separate debuginfo for ', + 'Try: zypper install -C ', ) for line in errlines: if not line.startswith(ignore_patterns): @@ -248,9 +256,8 @@ class PrettyPrintTests(DebuggerTests): def test_dicts(self): 'Verify the pretty-printing of dictionaries' self.assertGdbRepr({}) - self.assertGdbRepr({'foo': 'bar'}) - self.assertGdbRepr({'foo': 'bar', 'douglas': 42}, - "{'foo': 'bar', 'douglas': 42}") + self.assertGdbRepr({'foo': 'bar'}, "{'foo': 'bar'}") + self.assertGdbRepr({'foo': 'bar', 'douglas': 42}, "{'douglas': 42, 'foo': 'bar'}") def test_lists(self): 'Verify the pretty-printing of lists' @@ -305,26 +312,30 @@ class PrettyPrintTests(DebuggerTests): def test_tuples(self): 'Verify the pretty-printing of tuples' - self.assertGdbRepr(tuple()) + self.assertGdbRepr(tuple(), '()') self.assertGdbRepr((1,), '(1,)') self.assertGdbRepr(('foo', 'bar', 'baz')) def test_sets(self): 'Verify the pretty-printing of sets' - self.assertGdbRepr(set()) + if (gdb_major_version, gdb_minor_version) < (7, 3): + self.skipTest("pretty-printing of sets needs gdb 7.3 or later") + self.assertGdbRepr(set(), 'set()') self.assertGdbRepr(set(['a', 'b']), "{'a', 'b'}") self.assertGdbRepr(set([4, 5, 6]), "{4, 5, 6}") # Ensure that we handle sets containing the "dummy" key value, # which happens on deletion: gdb_repr, gdb_output = self.get_gdb_repr('''s = set(['a','b']) -s.pop() +s.remove('a') id(s)''') self.assertEqual(gdb_repr, "{'b'}") def test_frozensets(self): 'Verify the pretty-printing of frozensets' - self.assertGdbRepr(frozenset()) + if (gdb_major_version, gdb_minor_version) < (7, 3): + self.skipTest("pretty-printing of frozensets needs gdb 7.3 or later") + self.assertGdbRepr(frozenset(), 'frozenset()') self.assertGdbRepr(frozenset(['a', 'b']), "frozenset({'a', 'b'})") self.assertGdbRepr(frozenset([4, 5, 6]), "frozenset({4, 5, 6})") @@ -841,6 +852,10 @@ class PyLocalsTests(DebuggerTests): r".*\na = 1\nb = 2\nc = 3\n.*") def test_main(): + if support.verbose: + print("GDB version:") + for line in os.fsdecode(gdb_version).splitlines(): + print(" " * 4 + line) run_unittest(PrettyPrintTests, PyListTests, StackNavigationTests, diff --git a/Lib/test/test_generators.py b/Lib/test/test_generators.py index 958054aef5..91afe47799 100644 --- a/Lib/test/test_generators.py +++ b/Lib/test/test_generators.py @@ -1,3 +1,55 @@ +import gc +import sys +import unittest +import weakref + +from test import support + + +class FinalizationTest(unittest.TestCase): + + def test_frame_resurrect(self): + # A generator frame can be resurrected by a generator's finalization. + def gen(): + nonlocal frame + try: + yield + finally: + frame = sys._getframe() + + g = gen() + wr = weakref.ref(g) + next(g) + del g + support.gc_collect() + self.assertIs(wr(), None) + self.assertTrue(frame) + del frame + support.gc_collect() + + def test_refcycle(self): + # A generator caught in a refcycle gets finalized anyway. + old_garbage = gc.garbage[:] + finalized = False + def gen(): + nonlocal finalized + try: + g = yield + yield 1 + finally: + finalized = True + + g = gen() + next(g) + g.send(g) + self.assertGreater(sys.getrefcount(g), 2) + self.assertFalse(finalized) + del g + support.gc_collect() + self.assertTrue(finalized) + self.assertEqual(gc.garbage, old_garbage) + + tutorial_tests = """ Let's try a simple generator: @@ -384,8 +436,8 @@ From the Iterators list, about the types of these things. >>> [s for s in dir(i) if not s.startswith('_')] ['close', 'gi_code', 'gi_frame', 'gi_running', 'send', 'throw'] >>> from test.support import HAVE_DOCSTRINGS ->>> print(i.__next__.__doc__ if HAVE_DOCSTRINGS else 'x.__next__() <==> next(x)') -x.__next__() <==> next(x) +>>> print(i.__next__.__doc__ if HAVE_DOCSTRINGS else 'Implement next(self).') +Implement next(self). >>> iter(i) is i True >>> import types @@ -1729,9 +1781,7 @@ Our ill-behaved code should be invoked during GC: >>> g = f() >>> next(g) >>> del g ->>> sys.stderr.getvalue().startswith( -... "Exception RuntimeError: 'generator ignored GeneratorExit' in " -... ) +>>> "RuntimeError: generator ignored GeneratorExit" in sys.stderr.getvalue() True >>> sys.stderr = old @@ -1841,22 +1891,23 @@ to test. ... sys.stderr = io.StringIO() ... class Leaker: ... def __del__(self): -... raise RuntimeError +... def invoke(message): +... raise RuntimeError(message) +... invoke("test") ... ... l = Leaker() ... del l ... err = sys.stderr.getvalue().strip() -... err.startswith( -... "Exception RuntimeError: RuntimeError() in <" -... ) -... err.endswith("> ignored") -... len(err.splitlines()) +... "Exception ignored in" in err +... "RuntimeError: test" in err +... "Traceback" in err +... "in invoke" in err ... finally: ... sys.stderr = old True True -1 - +True +True These refleak tests should perhaps be in a testfile of their own, @@ -1881,6 +1932,7 @@ __test__ = {"tut": tutorial_tests, # so this works as expected in both ways of running regrtest. def test_main(verbose=None): from test import support, test_generators + support.run_unittest(__name__) support.run_doctest(test_generators, verbose) # This part isn't needed for regrtest, but for running the test directly. diff --git a/Lib/test/test_genericpath.py b/Lib/test/test_genericpath.py index b5068e4c98..e59ed4d21c 100644 --- a/Lib/test/test_genericpath.py +++ b/Lib/test/test_genericpath.py @@ -188,12 +188,93 @@ class GenericTest: support.unlink(support.TESTFN) safe_rmdir(support.TESTFN) + @staticmethod + def _create_file(filename): + with open(filename, 'wb') as f: + f.write(b'foo') + + def test_samefile(self): + try: + test_fn = support.TESTFN + "1" + self._create_file(test_fn) + self.assertTrue(self.pathmodule.samefile(test_fn, test_fn)) + self.assertRaises(TypeError, self.pathmodule.samefile) + finally: + os.remove(test_fn) + + @support.skip_unless_symlink + def test_samefile_on_symlink(self): + self._test_samefile_on_link_func(os.symlink) + + def test_samefile_on_link(self): + self._test_samefile_on_link_func(os.link) + + def _test_samefile_on_link_func(self, func): + try: + test_fn1 = support.TESTFN + "1" + test_fn2 = support.TESTFN + "2" + self._create_file(test_fn1) + + func(test_fn1, test_fn2) + self.assertTrue(self.pathmodule.samefile(test_fn1, test_fn2)) + os.remove(test_fn2) + + self._create_file(test_fn2) + self.assertFalse(self.pathmodule.samefile(test_fn1, test_fn2)) + finally: + os.remove(test_fn1) + os.remove(test_fn2) + + def test_samestat(self): + try: + test_fn = support.TESTFN + "1" + self._create_file(test_fn) + test_fns = [test_fn]*2 + stats = map(os.stat, test_fns) + self.assertTrue(self.pathmodule.samestat(*stats)) + finally: + os.remove(test_fn) + + @support.skip_unless_symlink + def test_samestat_on_symlink(self): + self._test_samestat_on_link_func(os.symlink) + + def test_samestat_on_link(self): + self._test_samestat_on_link_func(os.link) + + def _test_samestat_on_link_func(self, func): + try: + test_fn1 = support.TESTFN + "1" + test_fn2 = support.TESTFN + "2" + self._create_file(test_fn1) + test_fns = (test_fn1, test_fn2) + func(*test_fns) + stats = map(os.stat, test_fns) + self.assertTrue(self.pathmodule.samestat(*stats)) + os.remove(test_fn2) + + self._create_file(test_fn2) + stats = map(os.stat, test_fns) + self.assertFalse(self.pathmodule.samestat(*stats)) + + self.assertRaises(TypeError, self.pathmodule.samestat) + finally: + os.remove(test_fn1) + os.remove(test_fn2) + + def test_sameopenfile(self): + fname = support.TESTFN + "1" + with open(fname, "wb") as a, open(fname, "wb") as b: + self.assertTrue(self.pathmodule.sameopenfile( + a.fileno(), b.fileno())) + class TestGenericTest(GenericTest, unittest.TestCase): # Issue 16852: GenericTest can't inherit from unittest.TestCase # for test discovery purposes; CommonTest inherits from GenericTest # and is only meant to be inherited by others. pathmodule = genericpath + # Following TestCase is not supposed to be run from test_genericpath. # It is inherited by other test modules (macpath, ntpath, posixpath). @@ -348,7 +429,6 @@ class CommonTest(GenericTest): else: self.skipTest("need support.TESTFN_NONASCII") - # Test non-ASCII, non-UTF8 bytes in the path. with warnings.catch_warnings(): warnings.simplefilter("ignore", DeprecationWarning) with support.temp_cwd(name): diff --git a/Lib/test/test_genexps.py b/Lib/test/test_genexps.py index 203b336fde..fb531d6d47 100644 --- a/Lib/test/test_genexps.py +++ b/Lib/test/test_genexps.py @@ -222,8 +222,8 @@ Check that generator attributes are present True >>> from test.support import HAVE_DOCSTRINGS - >>> print(g.__next__.__doc__ if HAVE_DOCSTRINGS else 'x.__next__() <==> next(x)') - x.__next__() <==> next(x) + >>> print(g.__next__.__doc__ if HAVE_DOCSTRINGS else 'Implement next(self).') + Implement next(self). >>> import types >>> isinstance(g, types.GeneratorType) True diff --git a/Lib/test/test_getargs2.py b/Lib/test/test_getargs2.py index fe64ee1669..1853a2dbed 100644 --- a/Lib/test/test_getargs2.py +++ b/Lib/test/test_getargs2.py @@ -3,38 +3,40 @@ from test import support # Skip this test if the _testcapi module isn't available. support.import_module('_testcapi') from _testcapi import getargs_keywords, getargs_keyword_only - -""" -> How about the following counterproposal. This also changes some of -> the other format codes to be a little more regular. -> -> Code C type Range check -> -> b unsigned char 0..UCHAR_MAX -> h signed short SHRT_MIN..SHRT_MAX -> B unsigned char none ** -> H unsigned short none ** -> k * unsigned long none -> I * unsigned int 0..UINT_MAX - - -> i int INT_MIN..INT_MAX -> l long LONG_MIN..LONG_MAX - -> K * unsigned long long none -> L long long LLONG_MIN..LLONG_MAX - -> Notes: -> -> * New format codes. -> -> ** Changed from previous "range-and-a-half" to "none"; the -> range-and-a-half checking wasn't particularly useful. - -Plus a C API or two, e.g. PyInt_AsLongMask() -> -unsigned long and PyInt_AsLongLongMask() -> unsigned -long long (if that exists). -""" +try: + from _testcapi import getargs_L, getargs_K +except ImportError: + getargs_L = None # PY_LONG_LONG not available + +# > How about the following counterproposal. This also changes some of +# > the other format codes to be a little more regular. +# > +# > Code C type Range check +# > +# > b unsigned char 0..UCHAR_MAX +# > h signed short SHRT_MIN..SHRT_MAX +# > B unsigned char none ** +# > H unsigned short none ** +# > k * unsigned long none +# > I * unsigned int 0..UINT_MAX +# +# +# > i int INT_MIN..INT_MAX +# > l long LONG_MIN..LONG_MAX +# +# > K * unsigned long long none +# > L long long LLONG_MIN..LLONG_MAX +# +# > Notes: +# > +# > * New format codes. +# > +# > ** Changed from previous "range-and-a-half" to "none"; the +# > range-and-a-half checking wasn't particularly useful. +# +# Plus a C API or two, e.g. PyInt_AsLongMask() -> +# unsigned long and PyInt_AsLongLongMask() -> unsigned +# long long (if that exists). LARGE = 0x7FFFFFFF VERY_LARGE = 0xFF0000121212121212121242 @@ -77,7 +79,8 @@ class Unsigned_TestCase(unittest.TestCase): self.assertEqual(99, getargs_b(Int())) self.assertEqual(0, getargs_b(IntSubclass())) self.assertRaises(TypeError, getargs_b, BadInt()) - self.assertEqual(1, getargs_b(BadInt2())) + with self.assertWarns(DeprecationWarning): + self.assertEqual(1, getargs_b(BadInt2())) self.assertEqual(0, getargs_b(BadInt3())) self.assertRaises(OverflowError, getargs_b, -1) @@ -95,7 +98,8 @@ class Unsigned_TestCase(unittest.TestCase): self.assertEqual(99, getargs_B(Int())) self.assertEqual(0, getargs_B(IntSubclass())) self.assertRaises(TypeError, getargs_B, BadInt()) - self.assertEqual(1, getargs_B(BadInt2())) + with self.assertWarns(DeprecationWarning): + self.assertEqual(1, getargs_B(BadInt2())) self.assertEqual(0, getargs_B(BadInt3())) self.assertEqual(UCHAR_MAX, getargs_B(-1)) @@ -113,7 +117,8 @@ class Unsigned_TestCase(unittest.TestCase): self.assertEqual(99, getargs_H(Int())) self.assertEqual(0, getargs_H(IntSubclass())) self.assertRaises(TypeError, getargs_H, BadInt()) - self.assertEqual(1, getargs_H(BadInt2())) + with self.assertWarns(DeprecationWarning): + self.assertEqual(1, getargs_H(BadInt2())) self.assertEqual(0, getargs_H(BadInt3())) self.assertEqual(USHRT_MAX, getargs_H(-1)) @@ -132,7 +137,8 @@ class Unsigned_TestCase(unittest.TestCase): self.assertEqual(99, getargs_I(Int())) self.assertEqual(0, getargs_I(IntSubclass())) self.assertRaises(TypeError, getargs_I, BadInt()) - self.assertEqual(1, getargs_I(BadInt2())) + with self.assertWarns(DeprecationWarning): + self.assertEqual(1, getargs_I(BadInt2())) self.assertEqual(0, getargs_I(BadInt3())) self.assertEqual(UINT_MAX, getargs_I(-1)) @@ -172,7 +178,8 @@ class Signed_TestCase(unittest.TestCase): self.assertEqual(99, getargs_h(Int())) self.assertEqual(0, getargs_h(IntSubclass())) self.assertRaises(TypeError, getargs_h, BadInt()) - self.assertEqual(1, getargs_h(BadInt2())) + with self.assertWarns(DeprecationWarning): + self.assertEqual(1, getargs_h(BadInt2())) self.assertEqual(0, getargs_h(BadInt3())) self.assertRaises(OverflowError, getargs_h, SHRT_MIN-1) @@ -190,7 +197,8 @@ class Signed_TestCase(unittest.TestCase): self.assertEqual(99, getargs_i(Int())) self.assertEqual(0, getargs_i(IntSubclass())) self.assertRaises(TypeError, getargs_i, BadInt()) - self.assertEqual(1, getargs_i(BadInt2())) + with self.assertWarns(DeprecationWarning): + self.assertEqual(1, getargs_i(BadInt2())) self.assertEqual(0, getargs_i(BadInt3())) self.assertRaises(OverflowError, getargs_i, INT_MIN-1) @@ -208,7 +216,8 @@ class Signed_TestCase(unittest.TestCase): self.assertEqual(99, getargs_l(Int())) self.assertEqual(0, getargs_l(IntSubclass())) self.assertRaises(TypeError, getargs_l, BadInt()) - self.assertEqual(1, getargs_l(BadInt2())) + with self.assertWarns(DeprecationWarning): + self.assertEqual(1, getargs_l(BadInt2())) self.assertEqual(0, getargs_l(BadInt3())) self.assertRaises(OverflowError, getargs_l, LONG_MIN-1) @@ -239,6 +248,7 @@ class Signed_TestCase(unittest.TestCase): self.assertRaises(OverflowError, getargs_n, VERY_LARGE) +@unittest.skipIf(getargs_L is None, 'PY_LONG_LONG is not available') class LongLong_TestCase(unittest.TestCase): def test_L(self): from _testcapi import getargs_L @@ -249,7 +259,8 @@ class LongLong_TestCase(unittest.TestCase): self.assertEqual(99, getargs_L(Int())) self.assertEqual(0, getargs_L(IntSubclass())) self.assertRaises(TypeError, getargs_L, BadInt()) - self.assertEqual(1, getargs_L(BadInt2())) + with self.assertWarns(DeprecationWarning): + self.assertEqual(1, getargs_L(BadInt2())) self.assertEqual(0, getargs_L(BadInt3())) self.assertRaises(OverflowError, getargs_L, LLONG_MIN-1) @@ -600,24 +611,5 @@ class Unicode_TestCase(unittest.TestCase): self.assertIsNone(getargs_Z_hash(None)) -def test_main(): - tests = [ - Signed_TestCase, - Unsigned_TestCase, - Boolean_TestCase, - Tuple_TestCase, - Keywords_TestCase, - KeywordOnly_TestCase, - Bytes_TestCase, - Unicode_TestCase, - ] - try: - from _testcapi import getargs_L, getargs_K - except ImportError: - pass # PY_LONG_LONG not available - else: - tests.append(LongLong_TestCase) - support.run_unittest(*tests) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_glob.py b/Lib/test/test_glob.py index eb9aeb5776..a5ab8d6c3e 100644 --- a/Lib/test/test_glob.py +++ b/Lib/test/test_glob.py @@ -169,6 +169,28 @@ class GlobTests(unittest.TestCase): eq(glob.glob('\\\\*\\*\\'), []) eq(glob.glob(b'\\\\*\\*\\'), []) + def check_escape(self, arg, expected): + self.assertEqual(glob.escape(arg), expected) + self.assertEqual(glob.escape(os.fsencode(arg)), os.fsencode(expected)) + + def test_escape(self): + check = self.check_escape + check('abc', 'abc') + check('[', '[[]') + check('?', '[?]') + check('*', '[*]') + check('[[_/*?*/_]]', '[[][[]_/[*][?][*]/_]]') + check('/[[_/*?*/_]]/', '/[[][[]_/[*][?][*]/_]]/') + + @unittest.skipUnless(sys.platform == "win32", "Win32 specific test") + def test_escape_windows(self): + check = self.check_escape + check('?:?', '?:[?]') + check('*:*', '*:[*]') + check(r'\\?\c:\?', r'\\?\c:\[?]') + check(r'\\*\*\*', r'\\*\*\[*]') + check('//?/c:/?', '//?/c:/[?]') + check('//*/*/*', '//*/*/[*]') def test_main(): run_unittest(GlobTests) diff --git a/Lib/test/test_grammar.py b/Lib/test/test_grammar.py index 6b326bdaa1..9da6338099 100644 --- a/Lib/test/test_grammar.py +++ b/Lib/test/test_grammar.py @@ -314,6 +314,13 @@ class GrammarTests(unittest.TestCase): self.assertEqual(f.__annotations__, {'b': 1, 'c': 2, 'e': 3, 'g': 6, 'h': 7, 'j': 9, 'k': 11, 'return': 12}) + # Check for issue #20625 -- annotations mangling + class Spam: + def f(self, *, __kw:1): + pass + class Ham(Spam): pass + self.assertEquals(Spam.f.__annotations__, {'_Spam__kw': 1}) + self.assertEquals(Ham.f.__annotations__, {'_Spam__kw': 1}) # Check for SF Bug #1697248 - mixing decorators and a return annotation def null(x): return x @null diff --git a/Lib/test/test_gzip.py b/Lib/test/test_gzip.py index c94d54e02e..52894077c9 100644 --- a/Lib/test/test_gzip.py +++ b/Lib/test/test_gzip.py @@ -130,6 +130,14 @@ class TestGzip(BaseTest): if not ztxt: break self.assertEqual(contents, b'a'*201) + def test_exclusive_write(self): + with gzip.GzipFile(self.filename, 'xb') as f: + f.write(data1 * 50) + with gzip.GzipFile(self.filename, 'rb') as f: + self.assertEqual(f.read(), data1 * 50) + with self.assertRaises(FileExistsError): + gzip.GzipFile(self.filename, 'xb') + def test_buffered_reader(self): # Issue #7471: a GzipFile can be wrapped in a BufferedReader for # performance. @@ -205,6 +213,9 @@ class TestGzip(BaseTest): self.test_write() with gzip.GzipFile(self.filename, 'r') as f: self.assertEqual(f.myfileobj.mode, 'rb') + support.unlink(self.filename) + with gzip.GzipFile(self.filename, 'x') as f: + self.assertEqual(f.myfileobj.mode, 'xb') def test_1647484(self): for mode in ('wb', 'rb'): @@ -388,6 +399,20 @@ class TestGzip(BaseTest): datac = gzip.compress(data) self.assertEqual(gzip.decompress(datac), data) + def test_read_truncated(self): + data = data1*50 + # Drop the CRC (4 bytes) and file size (4 bytes). + truncated = gzip.compress(data)[:-8] + with gzip.GzipFile(fileobj=io.BytesIO(truncated)) as f: + self.assertRaises(EOFError, f.read) + with gzip.GzipFile(fileobj=io.BytesIO(truncated)) as f: + self.assertEqual(f.read(len(data)), data) + self.assertRaises(EOFError, f.read, 1) + # Incomplete 10-byte header. + for i in range(2, 10): + with gzip.GzipFile(fileobj=io.BytesIO(truncated[:i])) as f: + self.assertRaises(EOFError, f.read, 1) + def test_read_with_extra(self): # Gzip data with an extra field gzdata = (b'\x1f\x8b\x08\x04\xb2\x17cQ\x02\xff' @@ -399,35 +424,59 @@ class TestGzip(BaseTest): class TestOpen(BaseTest): def test_binary_modes(self): uncompressed = data1 * 50 + with gzip.open(self.filename, "wb") as f: f.write(uncompressed) with open(self.filename, "rb") as f: file_data = gzip.decompress(f.read()) self.assertEqual(file_data, uncompressed) + with gzip.open(self.filename, "rb") as f: self.assertEqual(f.read(), uncompressed) + with gzip.open(self.filename, "ab") as f: f.write(uncompressed) with open(self.filename, "rb") as f: file_data = gzip.decompress(f.read()) self.assertEqual(file_data, uncompressed * 2) + with self.assertRaises(FileExistsError): + gzip.open(self.filename, "xb") + support.unlink(self.filename) + with gzip.open(self.filename, "xb") as f: + f.write(uncompressed) + with open(self.filename, "rb") as f: + file_data = gzip.decompress(f.read()) + self.assertEqual(file_data, uncompressed) + def test_implicit_binary_modes(self): # Test implicit binary modes (no "b" or "t" in mode string). uncompressed = data1 * 50 + with gzip.open(self.filename, "w") as f: f.write(uncompressed) with open(self.filename, "rb") as f: file_data = gzip.decompress(f.read()) self.assertEqual(file_data, uncompressed) + with gzip.open(self.filename, "r") as f: self.assertEqual(f.read(), uncompressed) + with gzip.open(self.filename, "a") as f: f.write(uncompressed) with open(self.filename, "rb") as f: file_data = gzip.decompress(f.read()) self.assertEqual(file_data, uncompressed * 2) + with self.assertRaises(FileExistsError): + gzip.open(self.filename, "x") + support.unlink(self.filename) + with gzip.open(self.filename, "x") as f: + f.write(uncompressed) + with open(self.filename, "rb") as f: + file_data = gzip.decompress(f.read()) + self.assertEqual(file_data, uncompressed) + def test_text_modes(self): uncompressed = data1.decode("ascii") * 50 uncompressed_raw = uncompressed.replace("\n", os.linesep) @@ -462,6 +511,8 @@ class TestOpen(BaseTest): with self.assertRaises(ValueError): gzip.open(self.filename, "wbt") with self.assertRaises(ValueError): + gzip.open(self.filename, "xbt") + with self.assertRaises(ValueError): gzip.open(self.filename, "rb", encoding="utf-8") with self.assertRaises(ValueError): gzip.open(self.filename, "rb", errors="ignore") diff --git a/Lib/test/test_hash.py b/Lib/test/test_hash.py index 22ebbccb22..f647c6f7d2 100644 --- a/Lib/test/test_hash.py +++ b/Lib/test/test_hash.py @@ -12,6 +12,40 @@ from collections import Hashable IS_64BIT = sys.maxsize > 2**32 +def lcg(x, length=16): + """Linear congruential generator""" + if x == 0: + return bytes(length) + out = bytearray(length) + for i in range(length): + x = (214013 * x + 2531011) & 0x7fffffff + out[i] = (x >> 16) & 0xff + return bytes(out) + +def pysiphash(uint64): + """Convert SipHash24 output to Py_hash_t + """ + assert 0 <= uint64 < (1 << 64) + # simple unsigned to signed int64 + if uint64 > (1 << 63) - 1: + int64 = uint64 - (1 << 64) + else: + int64 = uint64 + # mangle uint64 to uint32 + uint32 = (uint64 ^ uint64 >> 32) & 0xffffffff + # simple unsigned to signed int32 + if uint32 > (1 << 31) - 1: + int32 = uint32 - (1 << 32) + else: + int32 = uint32 + return int32, int64 + +def skip_unless_internalhash(test): + """Skip decorator for tests that depend on SipHash24 or FNV""" + ok = sys.hash_info.algorithm in {"fnv", "siphash24"} + msg = "Requires SipHash24 or FNV" + return test if ok else unittest.skip(msg)(test) + class HashEqualityTestCase(unittest.TestCase): @@ -161,12 +195,64 @@ class HashRandomizationTests: self.assertNotEqual(run1, run2) class StringlikeHashRandomizationTests(HashRandomizationTests): + repr_ = None + repr_long = None + + # 32bit little, 64bit little, 32bit big, 64bit big + known_hashes = { + 'djba33x': [ # only used for small strings + # seed 0, 'abc' + [193485960, 193485960, 193485960, 193485960], + # seed 42, 'abc' + [-678966196, 573763426263223372, -820489388, -4282905804826039665], + ], + 'siphash24': [ + # NOTE: PyUCS2 layout depends on endianess + # seed 0, 'abc' + [1198583518, 4596069200710135518, 1198583518, 4596069200710135518], + # seed 42, 'abc' + [273876886, -4501618152524544106, 273876886, -4501618152524544106], + # seed 42, 'abcdefghijk' + [-1745215313, 4436719588892876975, -1745215313, 4436719588892876975], + # seed 0, 'äú∑ℇ' + [493570806, 5749986484189612790, -1006381564, -5915111450199468540], + # seed 42, 'äú∑ℇ' + [-1677110816, -2947981342227738144, -1860207793, -4296699217652516017], + ], + 'fnv': [ + # seed 0, 'abc' + [-1600925533, 1453079729188098211, -1600925533, + 1453079729188098211], + # seed 42, 'abc' + [-206076799, -4410911502303878509, -1024014457, + -3570150969479994130], + # seed 42, 'abcdefghijk' + [811136751, -5046230049376118746, -77208053 , + -4779029615281019666], + # seed 0, 'äú∑ℇ' + [44402817, 8998297579845987431, -1956240331, + -782697888614047887], + # seed 42, 'äú∑ℇ' + [-283066365, -4576729883824601543, -271871407, + -3927695501187247084], + ] + } + + def get_expected_hash(self, position, length): + if length < sys.hash_info.cutoff: + algorithm = "djba33x" + else: + algorithm = sys.hash_info.algorithm + if sys.byteorder == 'little': + platform = 1 if IS_64BIT else 0 + else: + assert(sys.byteorder == 'big') + platform = 3 if IS_64BIT else 2 + return self.known_hashes[algorithm][position][platform] + def test_null_hash(self): # PYTHONHASHSEED=0 disables the randomized hash - if IS_64BIT: - known_hash_of_obj = 1453079729188098211 - else: - known_hash_of_obj = -1600925533 + known_hash_of_obj = self.get_expected_hash(0, 3) # Randomization is enabled by default: self.assertNotEqual(self.get_hash(self.repr_), known_hash_of_obj) @@ -174,39 +260,53 @@ class StringlikeHashRandomizationTests(HashRandomizationTests): # It can also be disabled by setting the seed to 0: self.assertEqual(self.get_hash(self.repr_, seed=0), known_hash_of_obj) + @skip_unless_internalhash def test_fixed_hash(self): # test a fixed seed for the randomized hash # Note that all types share the same values: - if IS_64BIT: - if sys.byteorder == 'little': - h = -4410911502303878509 - else: - h = -3570150969479994130 - else: - if sys.byteorder == 'little': - h = -206076799 - else: - h = -1024014457 + h = self.get_expected_hash(1, 3) self.assertEqual(self.get_hash(self.repr_, seed=42), h) + @skip_unless_internalhash + def test_long_fixed_hash(self): + if self.repr_long is None: + return + h = self.get_expected_hash(2, 11) + self.assertEqual(self.get_hash(self.repr_long, seed=42), h) + + class StrHashRandomizationTests(StringlikeHashRandomizationTests, unittest.TestCase): repr_ = repr('abc') + repr_long = repr('abcdefghijk') + repr_ucs2 = repr('äú∑ℇ') + @skip_unless_internalhash def test_empty_string(self): self.assertEqual(hash(""), 0) + @skip_unless_internalhash + def test_ucs2_string(self): + h = self.get_expected_hash(3, 6) + self.assertEqual(self.get_hash(self.repr_ucs2, seed=0), h) + h = self.get_expected_hash(4, 6) + self.assertEqual(self.get_hash(self.repr_ucs2, seed=42), h) + class BytesHashRandomizationTests(StringlikeHashRandomizationTests, unittest.TestCase): repr_ = repr(b'abc') + repr_long = repr(b'abcdefghijk') + @skip_unless_internalhash def test_empty_string(self): self.assertEqual(hash(b""), 0) class MemoryviewHashRandomizationTests(StringlikeHashRandomizationTests, unittest.TestCase): repr_ = "memoryview(b'abc')" + repr_long = "memoryview(b'abcdefghijk')" + @skip_unless_internalhash def test_empty_string(self): self.assertEqual(hash(memoryview(b"")), 0) @@ -224,5 +324,23 @@ class DatetimeTimeTests(DatetimeTests, unittest.TestCase): repr_ = repr(datetime.time(0)) +class HashDistributionTestCase(unittest.TestCase): + + def test_hash_distribution(self): + # check for hash collision + base = "abcdefghabcdefg" + for i in range(1, len(base)): + prefix = base[:i] + with self.subTest(prefix=prefix): + s15 = set() + s255 = set() + for c in range(256): + h = hash(prefix + chr(c)) + s15.add(h & 0xf) + s255.add(h & 0xff) + # SipHash24 distribution depends on key, usually > 60% + self.assertGreater(len(s15), 8, prefix) + self.assertGreater(len(s255), 128, prefix) + if __name__ == "__main__": unittest.main() diff --git a/Lib/test/test_hashlib.py b/Lib/test/test_hashlib.py index c0470d6018..36a98cb027 100644 --- a/Lib/test/test_hashlib.py +++ b/Lib/test/test_hashlib.py @@ -18,11 +18,13 @@ except ImportError: import unittest import warnings from test import support -from test.support import _4G, bigmemtest +from test.support import _4G, bigmemtest, import_fresh_module # Were we compiled --with-pydebug or with #define Py_DEBUG? COMPILED_WITH_PYDEBUG = hasattr(sys, 'gettotalrefcount') +c_hashlib = import_fresh_module('hashlib', fresh=['_hashlib']) +py_hashlib = import_fresh_module('hashlib', blocked=['_hashlib']) def hexstr(s): assert isinstance(s, bytes), repr(s) @@ -36,7 +38,7 @@ def hexstr(s): class HashLibTestCase(unittest.TestCase): supported_hash_names = ( 'md5', 'MD5', 'sha1', 'SHA1', 'sha224', 'SHA224', 'sha256', 'SHA256', - 'sha384', 'SHA384', 'sha512', 'SHA512' ) + 'sha384', 'SHA384', 'sha512', 'SHA512') # Issue #14693: fallback modules are always compiled under POSIX _warn_on_extension_import = os.name == 'posix' or COMPILED_WITH_PYDEBUG @@ -72,27 +74,31 @@ class HashLibTestCase(unittest.TestCase): if _hashlib: # These two algorithms should always be present when this module # is compiled. If not, something was compiled wrong. - assert hasattr(_hashlib, 'openssl_md5') - assert hasattr(_hashlib, 'openssl_sha1') + self.assertTrue(hasattr(_hashlib, 'openssl_md5')) + self.assertTrue(hasattr(_hashlib, 'openssl_sha1')) for algorithm, constructors in self.constructors_to_test.items(): constructor = getattr(_hashlib, 'openssl_'+algorithm, None) if constructor: constructors.add(constructor) + def add_builtin_constructor(name): + constructor = getattr(hashlib, "__get_builtin_constructor")(name) + self.constructors_to_test[name].add(constructor) + _md5 = self._conditional_import_module('_md5') if _md5: - self.constructors_to_test['md5'].add(_md5.md5) + add_builtin_constructor('md5') _sha1 = self._conditional_import_module('_sha1') if _sha1: - self.constructors_to_test['sha1'].add(_sha1.sha1) + add_builtin_constructor('sha1') _sha256 = self._conditional_import_module('_sha256') if _sha256: - self.constructors_to_test['sha224'].add(_sha256.sha224) - self.constructors_to_test['sha256'].add(_sha256.sha256) + add_builtin_constructor('sha224') + add_builtin_constructor('sha256') _sha512 = self._conditional_import_module('_sha512') if _sha512: - self.constructors_to_test['sha384'].add(_sha512.sha384) - self.constructors_to_test['sha512'].add(_sha512.sha512) + add_builtin_constructor('sha384') + add_builtin_constructor('sha512') super(HashLibTestCase, self).__init__(*args, **kwargs) @@ -121,8 +127,10 @@ class HashLibTestCase(unittest.TestCase): self.assertRaises(TypeError, hashlib.new, 1) def test_get_builtin_constructor(self): - get_builtin_constructor = hashlib.__dict__[ - '__get_builtin_constructor'] + get_builtin_constructor = getattr(hashlib, + '__get_builtin_constructor') + builtin_constructor_cache = getattr(hashlib, + '__builtin_constructor_cache') self.assertRaises(ValueError, get_builtin_constructor, 'test') try: import _md5 @@ -130,6 +138,8 @@ class HashLibTestCase(unittest.TestCase): pass # This forces an ImportError for "import _md5" statements sys.modules['_md5'] = None + # clear the cache + builtin_constructor_cache.clear() try: self.assertRaises(ValueError, get_builtin_constructor, 'md5') finally: @@ -138,13 +148,23 @@ class HashLibTestCase(unittest.TestCase): else: del sys.modules['_md5'] self.assertRaises(TypeError, get_builtin_constructor, 3) + constructor = get_builtin_constructor('md5') + self.assertIs(constructor, _md5.md5) + self.assertEqual(sorted(builtin_constructor_cache), ['MD5', 'md5']) def test_hexdigest(self): for cons in self.hash_constructors: h = cons() - assert isinstance(h.digest(), bytes), name + self.assertIsInstance(h.digest(), bytes) self.assertEqual(hexstr(h.digest()), h.hexdigest()) + def test_name_attribute(self): + for cons in self.hash_constructors: + h = cons() + self.assertIsInstance(h.name, str) + self.assertIn(h.name, self.supported_hash_names) + self.assertEqual(h.name, hashlib.new(h.name).name) + def test_large_update(self): aas = b'a' * 128 bees = b'b' * 127 @@ -213,8 +233,9 @@ class HashLibTestCase(unittest.TestCase): self.assertEqual(m.block_size, block_size) self.assertEqual(m.digest_size, digest_size) self.assertEqual(len(m.digest()), digest_size) - self.assertEqual(m.name.lower(), name.lower()) - self.assertIn(name.split("_")[0], repr(m).lower()) + self.assertEqual(m.name, name) + # split for sha3_512 / _sha3.sha3 object + self.assertIn(name.split("_")[0], repr(m)) def test_blocksize_name(self): self.check_blocksize_name('md5', 64, 16) @@ -352,6 +373,12 @@ class HashLibTestCase(unittest.TestCase): "e718483d0ce769644e2e42c7bc15b4638e1f98b13b2044285632a803afa973eb"+ "de0ff244877ea60a4cb0432ce577c31beb009c5c2c49aa2e4eadb217ad8cc09b") + @unittest.skipIf(sys.maxsize < _4G + 5, 'test cannot run on 32-bit systems') + @bigmemtest(size=_4G + 5, memuse=1, dry_run=False) + def test_case_sha3_224_huge(self, size): + self.check('sha3_224', b'A'*size, + '58ef60057c9dddb6a87477e9ace5a26f0d9db01881cf9b10a9f8c224') + def test_gil(self): # Check things work fine with an input larger than the size required # for multithreaded operation (which is hardwired to 2048). @@ -400,8 +427,8 @@ class HashLibTestCase(unittest.TestCase): events = [] for threadnum in range(num_threads): chunk_size = len(data) // (10**threadnum) - assert chunk_size > 0 - assert chunk_size % len(smallest_data) == 0 + self.assertGreater(chunk_size, 0) + self.assertEqual(chunk_size % len(smallest_data), 0) event = threading.Event() events.append(event) threading.Thread(target=hash_in_chunks, @@ -412,8 +439,95 @@ class HashLibTestCase(unittest.TestCase): self.assertEqual(expected_hash, hasher.hexdigest()) -def test_main(): - support.run_unittest(HashLibTestCase) + +class KDFTests(unittest.TestCase): + + pbkdf2_test_vectors = [ + (b'password', b'salt', 1, None), + (b'password', b'salt', 2, None), + (b'password', b'salt', 4096, None), + # too slow, it takes over a minute on a fast CPU. + #(b'password', b'salt', 16777216, None), + (b'passwordPASSWORDpassword', b'saltSALTsaltSALTsaltSALTsaltSALTsalt', + 4096, -1), + (b'pass\0word', b'sa\0lt', 4096, 16), + ] + + pbkdf2_results = { + "sha1": [ + # offical test vectors from RFC 6070 + (bytes.fromhex('0c60c80f961f0e71f3a9b524af6012062fe037a6'), None), + (bytes.fromhex('ea6c014dc72d6f8ccd1ed92ace1d41f0d8de8957'), None), + (bytes.fromhex('4b007901b765489abead49d926f721d065a429c1'), None), + #(bytes.fromhex('eefe3d61cd4da4e4e9945b3d6ba2158c2634e984'), None), + (bytes.fromhex('3d2eec4fe41c849b80c8d83662c0e44a8b291a964c' + 'f2f07038'), 25), + (bytes.fromhex('56fa6aa75548099dcc37d7f03425e0c3'), None),], + "sha256": [ + (bytes.fromhex('120fb6cffcf8b32c43e7225256c4f837' + 'a86548c92ccc35480805987cb70be17b'), None), + (bytes.fromhex('ae4d0c95af6b46d32d0adff928f06dd0' + '2a303f8ef3c251dfd6e2d85a95474c43'), None), + (bytes.fromhex('c5e478d59288c841aa530db6845c4c8d' + '962893a001ce4e11a4963873aa98134a'), None), + #(bytes.fromhex('cf81c66fe8cfc04d1f31ecb65dab4089' + # 'f7f179e89b3b0bcb17ad10e3ac6eba46'), None), + (bytes.fromhex('348c89dbcbd32b2f32d814b8116e84cf2b17' + '347ebc1800181c4e2a1fb8dd53e1c635518c7dac47e9'), 40), + (bytes.fromhex('89b69d0516f829893c696226650a8687'), None),], + "sha512": [ + (bytes.fromhex('867f70cf1ade02cff3752599a3a53dc4af34c7a669815ae5' + 'd513554e1c8cf252c02d470a285a0501bad999bfe943c08f' + '050235d7d68b1da55e63f73b60a57fce'), None), + (bytes.fromhex('e1d9c16aa681708a45f5c7c4e215ceb66e011a2e9f004071' + '3f18aefdb866d53cf76cab2868a39b9f7840edce4fef5a82' + 'be67335c77a6068e04112754f27ccf4e'), None), + (bytes.fromhex('d197b1b33db0143e018b12f3d1d1479e6cdebdcc97c5c0f8' + '7f6902e072f457b5143f30602641b3d55cd335988cb36b84' + '376060ecd532e039b742a239434af2d5'), None), + (bytes.fromhex('8c0511f4c6e597c6ac6315d8f0362e225f3c501495ba23b8' + '68c005174dc4ee71115b59f9e60cd9532fa33e0f75aefe30' + '225c583a186cd82bd4daea9724a3d3b8'), 64), + (bytes.fromhex('9d9e9c4cd21fe4be24d5b8244c759665'), None),], + } + + def _test_pbkdf2_hmac(self, pbkdf2): + for digest_name, results in self.pbkdf2_results.items(): + for i, vector in enumerate(self.pbkdf2_test_vectors): + password, salt, rounds, dklen = vector + expected, overwrite_dklen = results[i] + if overwrite_dklen: + dklen = overwrite_dklen + out = pbkdf2(digest_name, password, salt, rounds, dklen) + self.assertEqual(out, expected, + (digest_name, password, salt, rounds, dklen)) + out = pbkdf2(digest_name, memoryview(password), + memoryview(salt), rounds, dklen) + out = pbkdf2(digest_name, bytearray(password), + bytearray(salt), rounds, dklen) + self.assertEqual(out, expected) + if dklen is None: + out = pbkdf2(digest_name, password, salt, rounds) + self.assertEqual(out, expected, + (digest_name, password, salt, rounds)) + + self.assertRaises(TypeError, pbkdf2, b'sha1', b'pass', b'salt', 1) + self.assertRaises(TypeError, pbkdf2, 'sha1', 'pass', 'salt', 1) + self.assertRaises(ValueError, pbkdf2, 'sha1', b'pass', b'salt', 0) + self.assertRaises(ValueError, pbkdf2, 'sha1', b'pass', b'salt', -1) + self.assertRaises(ValueError, pbkdf2, 'sha1', b'pass', b'salt', 1, 0) + self.assertRaises(ValueError, pbkdf2, 'sha1', b'pass', b'salt', 1, -1) + with self.assertRaisesRegex(ValueError, 'unsupported hash type'): + pbkdf2('unknown', b'pass', b'salt', 1) + + def test_pbkdf2_hmac_py(self): + self._test_pbkdf2_hmac(py_hashlib.pbkdf2_hmac) + + @unittest.skipUnless(hasattr(c_hashlib, 'pbkdf2_hmac'), + ' test requires OpenSSL > 1.0') + def test_pbkdf2_hmac_c(self): + self._test_pbkdf2_hmac(c_hashlib.pbkdf2_hmac) + if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_hmac.py b/Lib/test/test_hmac.py index 4ca7cec44c..cde56fd76f 100644 --- a/Lib/test/test_hmac.py +++ b/Lib/test/test_hmac.py @@ -1,17 +1,39 @@ +import functools import hmac import hashlib import unittest import warnings from test import support + +def ignore_warning(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", + category=PendingDeprecationWarning) + return func(*args, **kwargs) + return wrapper + + class TestVectorsTestCase(unittest.TestCase): def test_md5_vectors(self): # Test the HMAC module against test vectors from the RFC. def md5test(key, data, digest): - h = hmac.HMAC(key, data) + h = hmac.HMAC(key, data, digestmod=hashlib.md5) self.assertEqual(h.hexdigest().upper(), digest.upper()) + self.assertEqual(h.name, "hmac-md5") + self.assertEqual(h.digest_size, 16) + self.assertEqual(h.block_size, 64) + + h = hmac.HMAC(key, data, digestmod='md5') + self.assertEqual(h.hexdigest().upper(), digest.upper()) + self.assertEqual(h.name, "hmac-md5") + self.assertEqual(h.digest_size, 16) + self.assertEqual(h.block_size, 64) + md5test(b"\x0b" * 16, b"Hi There", @@ -46,6 +68,16 @@ class TestVectorsTestCase(unittest.TestCase): def shatest(key, data, digest): h = hmac.HMAC(key, data, digestmod=hashlib.sha1) self.assertEqual(h.hexdigest().upper(), digest.upper()) + self.assertEqual(h.name, "hmac-sha1") + self.assertEqual(h.digest_size, 20) + self.assertEqual(h.block_size, 64) + + h = hmac.HMAC(key, data, digestmod='sha1') + self.assertEqual(h.hexdigest().upper(), digest.upper()) + self.assertEqual(h.name, "hmac-sha1") + self.assertEqual(h.digest_size, 20) + self.assertEqual(h.block_size, 64) + shatest(b"\x0b" * 20, b"Hi There", @@ -76,10 +108,21 @@ class TestVectorsTestCase(unittest.TestCase): b"and Larger Than One Block-Size Data"), "e8e99d0f45237d786d6bbaa7965c7808bbff1a91") - def _rfc4231_test_cases(self, hashfunc): + def _rfc4231_test_cases(self, hashfunc, hash_name, digest_size, block_size): def hmactest(key, data, hexdigests): + hmac_name = "hmac-" + hash_name h = hmac.HMAC(key, data, digestmod=hashfunc) self.assertEqual(h.hexdigest().lower(), hexdigests[hashfunc]) + self.assertEqual(h.name, hmac_name) + self.assertEqual(h.digest_size, digest_size) + self.assertEqual(h.block_size, block_size) + + h = hmac.HMAC(key, data, digestmod=hash_name) + self.assertEqual(h.hexdigest().lower(), hexdigests[hashfunc]) + self.assertEqual(h.name, hmac_name) + self.assertEqual(h.digest_size, digest_size) + self.assertEqual(h.block_size, block_size) + # 4.2. Test Case 1 hmactest(key = b'\x0b'*20, @@ -189,16 +232,16 @@ class TestVectorsTestCase(unittest.TestCase): }) def test_sha224_rfc4231(self): - self._rfc4231_test_cases(hashlib.sha224) + self._rfc4231_test_cases(hashlib.sha224, 'sha224', 28, 64) def test_sha256_rfc4231(self): - self._rfc4231_test_cases(hashlib.sha256) + self._rfc4231_test_cases(hashlib.sha256, 'sha256', 32, 64) def test_sha384_rfc4231(self): - self._rfc4231_test_cases(hashlib.sha384) + self._rfc4231_test_cases(hashlib.sha384, 'sha384', 48, 128) def test_sha512_rfc4231(self): - self._rfc4231_test_cases(hashlib.sha512) + self._rfc4231_test_cases(hashlib.sha512, 'sha512', 64, 128) def test_legacy_block_size_warnings(self): class MockCrazyHash(object): @@ -222,46 +265,74 @@ class TestVectorsTestCase(unittest.TestCase): hmac.HMAC(b'a', b'b', digestmod=MockCrazyHash) self.fail('Expected warning about small block_size') + def test_with_digestmod_warning(self): + with self.assertWarns(PendingDeprecationWarning): + key = b"\x0b" * 16 + data = b"Hi There" + digest = "9294727A3638BB1C13F48EF8158BFC9D" + h = hmac.HMAC(key, data) + self.assertEqual(h.hexdigest().upper(), digest) class ConstructorTestCase(unittest.TestCase): + @ignore_warning def test_normal(self): # Standard constructor call. failed = 0 try: h = hmac.HMAC(b"key") - except: + except Exception: self.fail("Standard constructor call raised exception.") + @ignore_warning def test_with_str_key(self): # Pass a key of type str, which is an error, because it expects a key # of type bytes with self.assertRaises(TypeError): h = hmac.HMAC("key") + @ignore_warning def test_dot_new_with_str_key(self): # Pass a key of type str, which is an error, because it expects a key # of type bytes with self.assertRaises(TypeError): h = hmac.new("key") + @ignore_warning def test_withtext(self): # Constructor call with text. try: h = hmac.HMAC(b"key", b"hash this!") - except: + except Exception: self.fail("Constructor call with text argument raised exception.") + self.assertEqual(h.hexdigest(), '34325b639da4cfd95735b381e28cb864') + + def test_with_bytearray(self): + try: + h = hmac.HMAC(bytearray(b"key"), bytearray(b"hash this!"), + digestmod="md5") + except Exception: + self.fail("Constructor call with bytearray arguments raised exception.") + self.assertEqual(h.hexdigest(), '34325b639da4cfd95735b381e28cb864') + + def test_with_memoryview_msg(self): + try: + h = hmac.HMAC(b"key", memoryview(b"hash this!"), digestmod="md5") + except Exception: + self.fail("Constructor call with memoryview msg raised exception.") + self.assertEqual(h.hexdigest(), '34325b639da4cfd95735b381e28cb864') def test_withmodule(self): # Constructor call with text and digest module. try: h = hmac.HMAC(b"key", b"", hashlib.sha1) - except: + except Exception: self.fail("Constructor call with hashlib.sha1 raised exception.") class SanityTestCase(unittest.TestCase): + @ignore_warning def test_default_is_md5(self): # Testing if HMAC defaults to MD5 algorithm. # NOTE: this whitebox test depends on the hmac class internals @@ -272,19 +343,19 @@ class SanityTestCase(unittest.TestCase): # Exercising all methods once. # This must not raise any exceptions try: - h = hmac.HMAC(b"my secret key") + h = hmac.HMAC(b"my secret key", digestmod="md5") h.update(b"compute the hash of this text!") dig = h.digest() dig = h.hexdigest() h2 = h.copy() - except: + except Exception: self.fail("Exception raised during normal usage of HMAC class.") class CopyTestCase(unittest.TestCase): def test_attributes(self): # Testing if attributes are of same type. - h1 = hmac.HMAC(b"key") + h1 = hmac.HMAC(b"key", digestmod="md5") h2 = h1.copy() self.assertTrue(h1.digest_cons == h2.digest_cons, "digest constructors don't match.") @@ -295,7 +366,7 @@ class CopyTestCase(unittest.TestCase): def test_realcopy(self): # Testing if the copy method created a real copy. - h1 = hmac.HMAC(b"key") + h1 = hmac.HMAC(b"key", digestmod="md5") h2 = h1.copy() # Using id() in case somebody has overridden __eq__/__ne__. self.assertTrue(id(h1) != id(h2), "No real copy of the HMAC instance.") @@ -306,7 +377,7 @@ class CopyTestCase(unittest.TestCase): def test_equality(self): # Testing if the copy has the same digests. - h1 = hmac.HMAC(b"key") + h1 = hmac.HMAC(b"key", digestmod="md5") h1.update(b"some random text") h2 = h1.copy() self.assertEqual(h1.digest(), h2.digest(), diff --git a/Lib/test/test_html.py b/Lib/test/test_html.py index 30dac58e6d..5e9f382a2d 100644 --- a/Lib/test/test_html.py +++ b/Lib/test/test_html.py @@ -16,9 +16,89 @@ class HtmlTests(unittest.TestCase): html.escape('\'<script>"&foo;"</script>\'', False), '\'<script>"&foo;"</script>\'') + def test_unescape(self): + numeric_formats = ['&#%d', '&#%d;', '&#x%x', '&#x%x;'] + errmsg = 'unescape(%r) should have returned %r' + def check(text, expected): + self.assertEqual(html.unescape(text), expected, + msg=errmsg % (text, expected)) + def check_num(num, expected): + for format in numeric_formats: + text = format % num + self.assertEqual(html.unescape(text), expected, + msg=errmsg % (text, expected)) + # check text with no character references + check('no character references', 'no character references') + # check & followed by invalid chars + check('&\n&\t& &&', '&\n&\t& &&') + # check & followed by numbers and letters + check('&0 &9 &a &0; &9; &a;', '&0 &9 &a &0; &9; &a;') + # check incomplete entities at the end of the string + for x in ['&', '&#', '&#x', '&#X', '&#y', '&#xy', '&#Xy']: + check(x, x) + check(x+';', x+';') + # check several combinations of numeric character references, + # possibly followed by different characters + formats = ['&#%d', '&#%07d', '&#%d;', '&#%07d;', + '&#x%x', '&#x%06x', '&#x%x;', '&#x%06x;', + '&#x%X', '&#x%06X', '&#X%x;', '&#X%06x;'] + for num, char in zip([65, 97, 34, 38, 0x2603, 0x101234], + ['A', 'a', '"', '&', '\u2603', '\U00101234']): + for s in formats: + check(s % num, char) + for end in [' ', 'X']: + check((s+end) % num, char+end) + # check invalid codepoints + for cp in [0xD800, 0xDB00, 0xDC00, 0xDFFF, 0x110000]: + check_num(cp, '\uFFFD') + # check more invalid codepoints + for cp in [0x1, 0xb, 0xe, 0x7f, 0xfffe, 0xffff, 0x10fffe, 0x10ffff]: + check_num(cp, '') + # check invalid numbers + for num, ch in zip([0x0d, 0x80, 0x95, 0x9d], '\r\u20ac\u2022\x9d'): + check_num(num, ch) + # check small numbers + check_num(0, '\uFFFD') + check_num(9, '\t') + # check a big number + check_num(1000000000000000000, '\uFFFD') + # check that multiple trailing semicolons are handled correctly + for e in ['";', '";', '";', '";']: + check(e, '";') + # check that semicolons in the middle don't create problems + for e in ['"quot;', '"quot;', '"quot;', '"quot;']: + check(e, '"quot;') + # check triple adjacent charrefs + for e in ['"', '"', '"', '"']: + check(e*3, '"""') + check((e+';')*3, '"""') + # check that the case is respected + for e in ['&', '&', '&', '&']: + check(e, '&') + for e in ['&Amp', '&Amp;']: + check(e, e) + # check that non-existent named entities are returned unchanged + check('&svadilfari;', '&svadilfari;') + # the following examples are in the html5 specs + check('¬it', '¬it') + check('¬it;', '¬it;') + check('¬in', '¬in') + check('∉', '∉') + # a similar example with a long name + check('¬ReallyAnExistingNamedCharacterReference;', + '¬ReallyAnExistingNamedCharacterReference;') + # longest valid name + check('∳', '∳') + # check a charref that maps to two unicode chars + check('∾̳', '\u223E\u0333') + check('&acE', '&acE') + # see #12888 + check('{ ' * 1050, '{ ' * 1050) + # see #15156 + check('ÉricÉric&alphacentauriαcentauri', + 'ÉricÉric&alphacentauriαcentauri') + check('&co;', '&co;') -def test_main(): - run_unittest(HtmlTests) if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_htmlparser.py b/Lib/test/test_htmlparser.py index 11d9c9ce8b..2d771a2a97 100644 --- a/Lib/test/test_htmlparser.py +++ b/Lib/test/test_htmlparser.py @@ -70,6 +70,18 @@ class EventCollectorExtra(EventCollector): self.append(("starttag_text", self.get_starttag_text())) +class EventCollectorCharrefs(EventCollector): + + def get_events(self): + return self.events + + def handle_charref(self, data): + self.fail('This should never be called with convert_charrefs=True') + + def handle_entityref(self, data): + self.fail('This should never be called with convert_charrefs=True') + + class TestCaseBase(unittest.TestCase): def get_collector(self): @@ -84,26 +96,30 @@ class TestCaseBase(unittest.TestCase): parser.close() events = parser.get_events() if events != expected_events: - self.fail("received events did not match expected events\n" - "Expected:\n" + pprint.pformat(expected_events) + + self.fail("received events did not match expected events" + + "\nSource:\n" + repr(source) + + "\nExpected:\n" + pprint.pformat(expected_events) + "\nReceived:\n" + pprint.pformat(events)) def _run_check_extra(self, source, events): - self._run_check(source, events, EventCollectorExtra()) + self._run_check(source, events, + EventCollectorExtra(convert_charrefs=False)) def _parse_error(self, source): def parse(source=source): parser = self.get_collector() parser.feed(source) parser.close() - self.assertRaises(html.parser.HTMLParseError, parse) + with self.assertRaises(html.parser.HTMLParseError): + with self.assertWarns(DeprecationWarning): + parse() class HTMLParserStrictTestCase(TestCaseBase): def get_collector(self): with support.check_warnings(("", DeprecationWarning), quite=False): - return EventCollector(strict=True) + return EventCollector(strict=True, convert_charrefs=False) def test_processing_instruction_only(self): self._run_check("<?processing instruction>", [ @@ -339,7 +355,7 @@ text self._run_check(s, [("starttag", element_lower, []), ("data", content), ("endtag", element_lower)], - collector=Collector()) + collector=Collector(convert_charrefs=False)) def test_comments(self): html = ("<!-- I'm a valid comment -->" @@ -367,11 +383,60 @@ text ('comment', '[if lte IE 7]>pretty?<![endif]')] self._run_check(html, expected) + def test_convert_charrefs(self): + collector = lambda: EventCollectorCharrefs(convert_charrefs=True) + self.assertTrue(collector().convert_charrefs) + charrefs = ['"', '"', '"', '"', '"', '"'] + # check charrefs in the middle of the text/attributes + expected = [('starttag', 'a', [('href', 'foo"zar')]), + ('data', 'a"z'), ('endtag', 'a')] + for charref in charrefs: + self._run_check('<a href="foo{0}zar">a{0}z</a>'.format(charref), + expected, collector=collector()) + # check charrefs at the beginning/end of the text/attributes + expected = [('data', '"'), + ('starttag', 'a', [('x', '"'), ('y', '"X'), ('z', 'X"')]), + ('data', '"'), ('endtag', 'a'), ('data', '"')] + for charref in charrefs: + self._run_check('{0}<a x="{0}" y="{0}X" z="X{0}">' + '{0}</a>{0}'.format(charref), + expected, collector=collector()) + # check charrefs in <script>/<style> elements + for charref in charrefs: + text = 'X'.join([charref]*3) + expected = [('data', '"'), + ('starttag', 'script', []), ('data', text), + ('endtag', 'script'), ('data', '"'), + ('starttag', 'style', []), ('data', text), + ('endtag', 'style'), ('data', '"')] + self._run_check('{1}<script>{0}</script>{1}' + '<style>{0}</style>{1}'.format(text, charref), + expected, collector=collector()) + # check truncated charrefs at the end of the file + html = '&quo &# &#x' + for x in range(1, len(html)): + self._run_check(html[:x], [('data', html[:x])], + collector=collector()) + # check a string with no charrefs + self._run_check('no charrefs here', [('data', 'no charrefs here')], + collector=collector()) + class HTMLParserTolerantTestCase(HTMLParserStrictTestCase): def get_collector(self): - return EventCollector(strict=False) + return EventCollector(convert_charrefs=False) + + def test_deprecation_warnings(self): + with self.assertWarns(DeprecationWarning): + EventCollector() # convert_charrefs not passed explicitly + with self.assertWarns(DeprecationWarning): + EventCollector(strict=True) + with self.assertWarns(DeprecationWarning): + EventCollector(strict=False) + with self.assertRaises(html.parser.HTMLParseError): + with self.assertWarns(DeprecationWarning): + EventCollector().error('test') def test_tolerant_parsing(self): self._run_check('<html <html>te>>xt&a<<bc</a></html>\n' @@ -564,17 +629,12 @@ class HTMLParserTolerantTestCase(HTMLParserStrictTestCase): for html, expected in data: self._run_check(html, expected) - def test_unescape_function(self): + def test_unescape_method(self): + from html import unescape p = self.get_collector() - self.assertEqual(p.unescape('&#bad;'),'&#bad;') - self.assertEqual(p.unescape('&'),'&') - # see #12888 - self.assertEqual(p.unescape('{ ' * 1050), '{ ' * 1050) - # see #15156 - self.assertEqual(p.unescape('ÉricÉric' - '&alphacentauriαcentauri'), - 'ÉricÉric&alphacentauriαcentauri') - self.assertEqual(p.unescape('&co;'), '&co;') + with self.assertWarns(DeprecationWarning): + s = '""""""&#bad;' + self.assertEqual(p.unescape(s), unescape(s)) def test_broken_comments(self): html = ('<! not really a comment >' @@ -630,7 +690,7 @@ class AttributesStrictTestCase(TestCaseBase): def get_collector(self): with support.check_warnings(("", DeprecationWarning), quite=False): - return EventCollector(strict=True) + return EventCollector(strict=True, convert_charrefs=False) def test_attr_syntax(self): output = [ @@ -691,7 +751,7 @@ class AttributesStrictTestCase(TestCaseBase): class AttributesTolerantTestCase(AttributesStrictTestCase): def get_collector(self): - return EventCollector(strict=False) + return EventCollector(convert_charrefs=False) def test_attr_funky_names2(self): self._run_check( diff --git a/Lib/test/test_httplib.py b/Lib/test/test_httplib.py index c8ded92478..30b6c0cfcb 100644 --- a/Lib/test/test_httplib.py +++ b/Lib/test/test_httplib.py @@ -27,8 +27,10 @@ class FakeSocket: self.text = text self.fileclass = fileclass self.data = b'' + self.sendall_calls = 0 def sendall(self, data): + self.sendall_calls += 1 self.data += data def makefile(self, mode, bufsize=None): @@ -45,7 +47,7 @@ class EPipeSocket(FakeSocket): def sendall(self, data): if self.pipe_trigger in data: - raise socket.error(errno.EPIPE, "gotcha") + raise OSError(errno.EPIPE, "gotcha") self.data += data def close(self): @@ -600,7 +602,7 @@ class BasicTest(TestCase): b"Content-Length") conn = client.HTTPConnection("example.com") conn.sock = sock - self.assertRaises(socket.error, + self.assertRaises(OSError, lambda: conn.request("PUT", "/url", "body")) resp = conn.getresponse() self.assertEqual(401, resp.status) @@ -646,6 +648,28 @@ class BasicTest(TestCase): resp.close() self.assertTrue(resp.closed) + def test_delayed_ack_opt(self): + # Test that Nagle/delayed_ack optimistaion works correctly. + + # For small payloads, it should coalesce the body with + # headers, resulting in a single sendall() call + conn = client.HTTPConnection('example.com') + sock = FakeSocket(None) + conn.sock = sock + body = b'x' * (conn.mss - 1) + conn.request('POST', '/', body) + self.assertEqual(sock.sendall_calls, 1) + + # For large payloads, it should send the headers and + # then the body, resulting in more than one sendall() + # call + conn = client.HTTPConnection('example.com') + sock = FakeSocket(None) + conn.sock = sock + body = b'x' * conn.mss + conn.request('POST', '/', body) + self.assertGreater(sock.sendall_calls, 1) + class OfflineTest(TestCase): def test_responses(self): self.assertEqual(client.responses[client.NOT_FOUND], "Not Found") @@ -736,7 +760,7 @@ class HTTPSTest(TestCase): def make_server(self, certfile): from test.ssl_servers import make_https_server - return make_https_server(self, certfile) + return make_https_server(self, certfile=certfile) def test_attributes(self): # simple test to check it's storing the timeout diff --git a/Lib/test/test_httpservers.py b/Lib/test/test_httpservers.py index 4e62963725..15694afd1d 100644 --- a/Lib/test/test_httpservers.py +++ b/Lib/test/test_httpservers.py @@ -92,6 +92,13 @@ class BaseHTTPServerTestCase(BaseTestCase): def do_KEYERROR(self): self.send_error(999) + def do_NOTFOUND(self): + self.send_error(404) + + def do_EXPLAINERROR(self): + self.send_error(999, "Short Message", + "This is a long \n explaination") + def do_CUSTOM(self): self.send_response(999) self.send_header('Content-Type', 'text/html') @@ -203,6 +210,12 @@ class BaseHTTPServerTestCase(BaseTestCase): res = self.con.getresponse() self.assertEqual(res.status, 999) + def test_return_explain_error(self): + self.con.request('EXPLAINERROR', '/') + res = self.con.getresponse() + self.assertEqual(res.status, 999) + self.assertTrue(int(res.getheader('Content-Length'))) + def test_latin1_header(self): self.con.request('LATINONEHEADER', '/', headers={ 'X-Special-Incoming': 'Ärger mit Unicode' @@ -211,6 +224,14 @@ class BaseHTTPServerTestCase(BaseTestCase): self.assertEqual(res.getheader('X-Special'), 'Dängerous Mind') self.assertEqual(res.read(), 'Ärger mit Unicode'.encode('utf-8')) + def test_error_content_length(self): + # Issue #16088: standard error responses should have a content-length + self.con.request('NOTFOUND', '/') + res = self.con.getresponse() + self.assertEqual(res.status, 404) + data = res.read() + self.assertEqual(int(res.getheader('Content-Length')), len(data)) + class SimpleHTTPServerTestCase(BaseTestCase): class request_handler(NoLogRequestHandler, SimpleHTTPRequestHandler): diff --git a/Lib/test/test_imaplib.py b/Lib/test/test_imaplib.py index 7c9afd9f39..5e0eaeab2b 100644 --- a/Lib/test/test_imaplib.py +++ b/Lib/test/test_imaplib.py @@ -18,8 +18,12 @@ try: import ssl except ImportError: ssl = None + HAS_SNI = False +else: + from ssl import HAS_SNI CERTFILE = None +CAFILE = None class TestImaplib(unittest.TestCase): @@ -125,7 +129,7 @@ class SimpleIMAPHandler(socketserver.StreamRequestHandler): # Naked sockets return empty strings.. return line += part - except IOError: + except OSError: # ..but SSLSockets raise exceptions. return if line.endswith(b'\r\n'): @@ -347,6 +351,26 @@ class ThreadedNetworkedTestsSSL(BaseThreadedNetworkedTests): server_class = SecureTCPServer imap_class = IMAP4_SSL + @reap_threads + @unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module') + def test_ssl_verified(self): + ssl_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + ssl_context.verify_mode = ssl.CERT_REQUIRED + ssl_context.check_hostname = True + ssl_context.load_verify_locations(CAFILE) + + with self.assertRaisesRegex(ssl.CertificateError, + "hostname '127.0.0.1' doesn't match 'localhost'"): + with self.reaped_server(SimpleIMAPHandler) as server: + client = self.imap_class(*server.server_address, + ssl_context=ssl_context) + client.shutdown() + + with self.reaped_server(SimpleIMAPHandler) as server: + client = self.imap_class("localhost", server.server_address[1], + ssl_context=ssl_context) + client.shutdown() + class RemoteIMAPTest(unittest.TestCase): host = 'cyrus.andrew.cmu.edu' @@ -459,11 +483,15 @@ def load_tests(*args): if support.is_resource_enabled('network'): if ssl: - global CERTFILE + global CERTFILE, CAFILE CERTFILE = os.path.join(os.path.dirname(__file__) or os.curdir, - "keycert.pem") + "keycert3.pem") if not os.path.exists(CERTFILE): raise support.TestFailed("Can't read certificate files!") + CAFILE = os.path.join(os.path.dirname(__file__) or os.curdir, + "pycacert.pem") + if not os.path.exists(CAFILE): + raise support.TestFailed("Can't read CA file!") tests.extend([ ThreadedNetworkedTests, ThreadedNetworkedTestsSSL, RemoteIMAPTest, RemoteIMAP_SSLTest, RemoteIMAP_STARTTLSTest, diff --git a/Lib/test/test_imp.py b/Lib/test/test_imp.py index 71c220f03e..024f43894b 100644 --- a/Lib/test/test_imp.py +++ b/Lib/test/test_imp.py @@ -1,14 +1,29 @@ -import imp +try: + import _thread +except ImportError: + _thread = None import importlib import os import os.path import shutil import sys from test import support -from test.test_importlib import util import unittest import warnings +with warnings.catch_warnings(): + warnings.simplefilter('ignore', PendingDeprecationWarning) + import imp + +def requires_load_dynamic(meth): + """Decorator to skip a test if not running under CPython or lacking + imp.load_dynamic().""" + meth = support.cpython_only(meth) + return unittest.skipIf(not hasattr(imp, 'load_dynamic'), + 'imp.load_dynamic() required')(meth) + + +@unittest.skipIf(_thread is None, '_thread module is required') class LockTests(unittest.TestCase): """Very basic test of import lock functions.""" @@ -150,7 +165,7 @@ class ImportTests(unittest.TestCase): self.assertIsNotNone(file) self.assertTrue(filename[:-3].endswith(temp_mod_name)) self.assertEqual(info[0], '.py') - self.assertEqual(info[1], 'U') + self.assertEqual(info[1], 'r') self.assertEqual(info[2], imp.PY_SOURCE) mod = imp.load_module(temp_mod_name, file, filename, info) @@ -208,9 +223,7 @@ class ImportTests(unittest.TestCase): self.assertIs(orig_path, new_os.path) self.assertIsNot(orig_getenv, new_os.getenv) - @support.cpython_only - @unittest.skipIf(not hasattr(imp, 'load_dynamic'), - 'imp.load_dynamic() required') + @requires_load_dynamic def test_issue15828_load_extensions(self): # Issue 15828 picked up that the adapter between the old imp API # and importlib couldn't handle C extensions @@ -222,6 +235,22 @@ class ImportTests(unittest.TestCase): mod = imp.load_module(example, *x) self.assertEqual(mod.__name__, example) + @requires_load_dynamic + def test_issue16421_multiple_modules_in_one_dll(self): + # Issue 16421: loading several modules from the same compiled file fails + m = '_testimportmultiple' + fileobj, pathname, description = imp.find_module(m) + fileobj.close() + mod0 = imp.load_dynamic(m, pathname) + mod1 = imp.load_dynamic('_testimportmultiple_foo', pathname) + mod2 = imp.load_dynamic('_testimportmultiple_bar', pathname) + self.assertEqual(mod0.__name__, m) + self.assertEqual(mod1.__name__, '_testimportmultiple_foo') + self.assertEqual(mod2.__name__, '_testimportmultiple_bar') + with self.assertRaises(ImportError): + imp.load_dynamic('nonexistent', pathname) + + @requires_load_dynamic def test_load_dynamic_ImportError_path(self): # Issue #1559549 added `name` and `path` attributes to ImportError # in order to provide better detail. Issue #10854 implemented those @@ -233,14 +262,12 @@ class ImportTests(unittest.TestCase): self.assertIn(path, err.exception.path) self.assertEqual(name, err.exception.name) - @support.cpython_only - @unittest.skipIf(not hasattr(imp, 'load_dynamic'), - 'imp.load_dynamic() required') + @requires_load_dynamic def test_load_module_extension_file_is_None(self): # When loading an extension module and the file is None, open one # on the behalf of imp.load_dynamic(). # Issue #15902 - name = '_heapq' + name = '_testimportmultiple' found = imp.find_module(name) if found[0] is not None: found[0].close() @@ -248,6 +275,15 @@ class ImportTests(unittest.TestCase): self.skipTest("found module doesn't appear to be a C extension") imp.load_module(name, None, *found[1:]) + @unittest.skipIf(sys.dont_write_bytecode, + "test meaningful only when writing bytecode") + def test_bug7732(self): + with support.temp_cwd(): + source = support.TESTFN + '.py' + os.mkdir(source) + self.assertRaisesRegex(ImportError, '^No module', + imp.find_module, support.TESTFN, ["."]) + def test_multiple_calls_to_get_data(self): # Issue #18755: make sure multiple calls to get_data() can succeed. loader = imp._LoadSourceCompatibility('imp', imp.__file__, @@ -293,22 +329,6 @@ class ReloadTests(unittest.TestCase): with self.assertRaisesRegex(ImportError, 'html'): imp.reload(parser) - def test_module_replaced(self): - # see #18698 - def code(): - module = type(sys)('top_level') - module.spam = 3 - sys.modules['top_level'] = module - mock = util.mock_modules('top_level', - module_code={'top_level': code}) - with mock: - with util.import_state(meta_path=[mock]): - module = importlib.import_module('top_level') - reloaded = imp.reload(module) - actual = sys.modules['top_level'] - self.assertEqual(actual.spam, 3) - self.assertEqual(reloaded.spam, 3) - class PEP3147Tests(unittest.TestCase): """Tests of PEP 3147.""" @@ -461,20 +481,5 @@ class NullImporterTests(unittest.TestCase): os.rmdir(name) -def test_main(): - tests = [ - ImportTests, - PEP3147Tests, - ReloadTests, - NullImporterTests, - ] - try: - import _thread - except ImportError: - pass - else: - tests.append(LockTests) - support.run_unittest(*tests) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_import.py b/Lib/test/test_import.py index df08e6a86f..bda1541fdb 100644 --- a/Lib/test/test_import.py +++ b/Lib/test/test_import.py @@ -1,9 +1,8 @@ # We import importlib *ASAP* in order to test #15386 import importlib +import importlib.util from importlib._bootstrap import _get_sourcefile import builtins -import imp -from test.test_importlib.import_ import util as importlib_util import marshal import os import platform @@ -22,7 +21,7 @@ import test.support from test.support import ( EnvironmentVarGuard, TESTFN, check_warnings, forget, is_jython, make_legacy_pyc, rmtree, run_unittest, swap_attr, swap_item, temp_umask, - unlink, unload, create_empty_file, cpython_only) + unlink, unload, create_empty_file, cpython_only, TESTFN_UNENCODABLE) from test import script_helper @@ -70,8 +69,6 @@ class ImportTests(unittest.TestCase): def tearDown(self): unload(TESTFN) - setUp = tearDown - def test_case_sensitivity(self): # Brief digression to test that import is case-sensitive: if we got # this far, we know for sure that "random" exists. @@ -129,16 +126,6 @@ class ImportTests(unittest.TestCase): finally: del sys.path[0] - @skip_if_dont_write_bytecode - def test_bug7732(self): - source = TESTFN + '.py' - os.mkdir(source) - try: - self.assertRaisesRegex(ImportError, '^No module', - imp.find_module, TESTFN, ["."]) - finally: - os.rmdir(source) - def test_module_with_large_stack(self, module='longlist'): # Regression test for http://bugs.python.org/issue561858. filename = module + '.py' @@ -161,16 +148,24 @@ class ImportTests(unittest.TestCase): sys.path.append('') importlib.invalidate_caches() + namespace = {} try: make_legacy_pyc(filename) # This used to crash. - exec('import ' + module) + exec('import ' + module, None, namespace) finally: # Cleanup. del sys.path[-1] unlink(filename + 'c') unlink(filename + 'o') + # Remove references to the module (unload the module) + namespace.clear() + try: + del sys.modules[module] + except KeyError: + pass + def test_failing_import_sticks(self): source = TESTFN + ".py" with open(source, "w") as f: @@ -225,7 +220,7 @@ class ImportTests(unittest.TestCase): with open(source, "w") as f: f.write("a = 10\nb=20//0\n") - self.assertRaises(ZeroDivisionError, imp.reload, mod) + self.assertRaises(ZeroDivisionError, importlib.reload, mod) # But we still expect the module to be in sys.modules. mod = sys.modules.get(TESTFN) self.assertIsNot(mod, None, "expected module to be in sys.modules") @@ -280,7 +275,7 @@ class ImportTests(unittest.TestCase): import sys class C: def __del__(self): - import imp + import importlib sys.argv.insert(0, C()) """)) script_helper.assert_python_ok(testfn) @@ -291,7 +286,7 @@ class ImportTests(unittest.TestCase): sys.path.insert(0, os.curdir) try: source = TESTFN + ".py" - compiled = imp.cache_from_source(source) + compiled = importlib.util.cache_from_source(source) with open(source, 'w') as f: pass try: @@ -322,6 +317,14 @@ class ImportTests(unittest.TestCase): stdout, stderr = popen.communicate() self.assertIn(b"ImportError", stdout) + def test_from_import_message_for_nonexistent_module(self): + with self.assertRaisesRegex(ImportError, "^No module named 'bogus'"): + from bogus import foo + + def test_from_import_message_for_existing_module(self): + with self.assertRaisesRegex(ImportError, "^cannot import name 'bogus'"): + from re import bogus + @skip_if_dont_write_bytecode class FilePermissionTests(unittest.TestCase): @@ -332,7 +335,7 @@ class FilePermissionTests(unittest.TestCase): def test_creation_mode(self): mask = 0o022 with temp_umask(mask), _ready_to_import() as (name, path): - cached_path = imp.cache_from_source(path) + cached_path = importlib.util.cache_from_source(path) module = __import__(name) if not os.path.exists(cached_path): self.fail("__import__ did not result in creation of " @@ -350,7 +353,7 @@ class FilePermissionTests(unittest.TestCase): # permissions of .pyc should match those of .py, regardless of mask mode = 0o600 with temp_umask(0o022), _ready_to_import() as (name, path): - cached_path = imp.cache_from_source(path) + cached_path = importlib.util.cache_from_source(path) os.chmod(path, mode) __import__(name) if not os.path.exists(cached_path): @@ -365,7 +368,7 @@ class FilePermissionTests(unittest.TestCase): def test_cached_readonly(self): mode = 0o400 with temp_umask(0o022), _ready_to_import() as (name, path): - cached_path = imp.cache_from_source(path) + cached_path = importlib.util.cache_from_source(path) os.chmod(path, mode) __import__(name) if not os.path.exists(cached_path): @@ -405,7 +408,7 @@ class FilePermissionTests(unittest.TestCase): bytecode_only = path + "c" else: bytecode_only = path + "o" - os.rename(imp.cache_from_source(path), bytecode_only) + os.rename(importlib.util.cache_from_source(path), bytecode_only) m = __import__(name) self.assertEqual(m.x, 'rewritten') @@ -427,7 +430,7 @@ func_filename = func.__code__.co_filename """ dir_name = os.path.abspath(TESTFN) file_name = os.path.join(dir_name, module_name) + os.extsep + "py" - compiled_name = imp.cache_from_source(file_name) + compiled_name = importlib.util.cache_from_source(file_name) def setUp(self): self.sys_path = sys.path[:] @@ -488,7 +491,7 @@ func_filename = func.__code__.co_filename header = f.read(12) code = marshal.load(f) constants = list(code.co_consts) - foreign_code = test_main.__code__ + foreign_code = importlib.import_module.__code__ pos = constants.index(1) constants[pos] = foreign_code code = type(code)(code.co_argcount, code.co_kwonlyargcount, @@ -630,7 +633,7 @@ class OverridingImportBuiltinTests(unittest.TestCase): class PycacheTests(unittest.TestCase): # Test the various PEP 3147 related behaviors. - tag = imp.get_tag() + tag = sys.implementation.cache_tag def _clean(self): forget(TESTFN) @@ -678,10 +681,11 @@ class PycacheTests(unittest.TestCase): # With PEP 3147 cache layout, removing the source but leaving the pyc # file does not satisfy the import. __import__(TESTFN) - pyc_file = imp.cache_from_source(self.source) + pyc_file = importlib.util.cache_from_source(self.source) self.assertTrue(os.path.exists(pyc_file)) os.remove(self.source) forget(TESTFN) + importlib.invalidate_caches() self.assertRaises(ImportError, __import__, TESTFN) @skip_if_dont_write_bytecode @@ -703,7 +707,7 @@ class PycacheTests(unittest.TestCase): def test___cached__(self): # Modules now also have an __cached__ that points to the pyc file. m = __import__(TESTFN) - pyc_file = imp.cache_from_source(TESTFN + '.py') + pyc_file = importlib.util.cache_from_source(TESTFN + '.py') self.assertEqual(m.__cached__, os.path.join(os.curdir, pyc_file)) @skip_if_dont_write_bytecode @@ -738,10 +742,10 @@ class PycacheTests(unittest.TestCase): pass importlib.invalidate_caches() m = __import__('pep3147.foo') - init_pyc = imp.cache_from_source( + init_pyc = importlib.util.cache_from_source( os.path.join('pep3147', '__init__.py')) self.assertEqual(m.__cached__, os.path.join(os.curdir, init_pyc)) - foo_pyc = imp.cache_from_source(os.path.join('pep3147', 'foo.py')) + foo_pyc = importlib.util.cache_from_source(os.path.join('pep3147', 'foo.py')) self.assertEqual(sys.modules['pep3147.foo'].__cached__, os.path.join(os.curdir, foo_pyc)) @@ -765,10 +769,10 @@ class PycacheTests(unittest.TestCase): unload('pep3147') importlib.invalidate_caches() m = __import__('pep3147.foo') - init_pyc = imp.cache_from_source( + init_pyc = importlib.util.cache_from_source( os.path.join('pep3147', '__init__.py')) self.assertEqual(m.__cached__, os.path.join(os.curdir, init_pyc)) - foo_pyc = imp.cache_from_source(os.path.join('pep3147', 'foo.py')) + foo_pyc = importlib.util.cache_from_source(os.path.join('pep3147', 'foo.py')) self.assertEqual(sys.modules['pep3147.foo'].__cached__, os.path.join(os.curdir, foo_pyc)) @@ -852,7 +856,6 @@ class ImportlibBootstrapTests(unittest.TestCase): from importlib import machinery mod = sys.modules['_frozen_importlib'] self.assertIs(machinery.FileFinder, mod.FileFinder) - self.assertIs(imp.new_module, mod.new_module) @cpython_only @@ -1033,11 +1036,14 @@ class ImportTracebackTests(unittest.TestCase): # away from the traceback. self.create_module("foo", "") importlib = sys.modules['_frozen_importlib'] - old_load_module = importlib.SourceLoader.load_module + if 'load_module' in vars(importlib.SourceLoader): + old_exec_module = importlib.SourceLoader.exec_module + else: + old_exec_module = None try: - def load_module(*args): + def exec_module(*args): 1/0 - importlib.SourceLoader.load_module = load_module + importlib.SourceLoader.exec_module = exec_module try: import foo except ZeroDivisionError as e: @@ -1046,19 +1052,21 @@ class ImportTracebackTests(unittest.TestCase): self.fail("ZeroDivisionError should have been raised") self.assert_traceback(tb, [__file__, '<frozen importlib', __file__]) finally: - importlib.SourceLoader.load_module = old_load_module - + if old_exec_module is None: + del importlib.SourceLoader.exec_module + else: + importlib.SourceLoader.exec_module = old_exec_module -def test_main(verbose=None): - run_unittest(ImportTests, PycacheTests, FilePermissionTests, - PycRewritingTests, PathsTests, RelativeImportTests, - OverridingImportBuiltinTests, - ImportlibBootstrapTests, GetSourcefileTests, - TestSymbolicallyLinkedPackage, - ImportTracebackTests) + @unittest.skipUnless(TESTFN_UNENCODABLE, 'need TESTFN_UNENCODABLE') + def test_unencodable_filename(self): + # Issue #11619: The Python parser and the import machinery must not + # encode filenames, especially on Windows + pyname = script_helper.make_script('', TESTFN_UNENCODABLE, 'pass') + name = pyname[:-3] + script_helper.assert_python_ok("-c", "mod = __import__(%a)" % name, + __isolated=False) if __name__ == '__main__': # Test needs to be a package, so we can do relative imports. - from test.test_import import test_main - test_main() + unittest.main() diff --git a/Lib/test/test_importhooks.py b/Lib/test/test_importhooks.py deleted file mode 100644 index 2a22d1a186..0000000000 --- a/Lib/test/test_importhooks.py +++ /dev/null @@ -1,250 +0,0 @@ -import sys -import imp -import os -import unittest -from test import support - - -test_src = """\ -def get_name(): - return __name__ -def get_file(): - return __file__ -""" - -absimp = "import sub\n" -relimp = "from . import sub\n" -deeprelimp = "from .... import sub\n" -futimp = "from __future__ import absolute_import\n" - -reload_src = test_src+"""\ -reloaded = True -""" - -test_co = compile(test_src, "<???>", "exec") -reload_co = compile(reload_src, "<???>", "exec") - -test2_oldabs_co = compile(absimp + test_src, "<???>", "exec") -test2_newabs_co = compile(futimp + absimp + test_src, "<???>", "exec") -test2_newrel_co = compile(relimp + test_src, "<???>", "exec") -test2_deeprel_co = compile(deeprelimp + test_src, "<???>", "exec") -test2_futrel_co = compile(futimp + relimp + test_src, "<???>", "exec") - -test_path = "!!!_test_!!!" - - -class TestImporter: - - modules = { - "hooktestmodule": (False, test_co), - "hooktestpackage": (True, test_co), - "hooktestpackage.sub": (True, test_co), - "hooktestpackage.sub.subber": (True, test_co), - "hooktestpackage.oldabs": (False, test2_oldabs_co), - "hooktestpackage.newabs": (False, test2_newabs_co), - "hooktestpackage.newrel": (False, test2_newrel_co), - "hooktestpackage.sub.subber.subest": (True, test2_deeprel_co), - "hooktestpackage.futrel": (False, test2_futrel_co), - "sub": (False, test_co), - "reloadmodule": (False, test_co), - } - - def __init__(self, path=test_path): - if path != test_path: - # if our class is on sys.path_hooks, we must raise - # ImportError for any path item that we can't handle. - raise ImportError - self.path = path - - def _get__path__(self): - raise NotImplementedError - - def find_module(self, fullname, path=None): - if fullname in self.modules: - return self - else: - return None - - def load_module(self, fullname): - ispkg, code = self.modules[fullname] - mod = sys.modules.setdefault(fullname,imp.new_module(fullname)) - mod.__file__ = "<%s>" % self.__class__.__name__ - mod.__loader__ = self - if ispkg: - mod.__path__ = self._get__path__() - exec(code, mod.__dict__) - return mod - - -class MetaImporter(TestImporter): - def _get__path__(self): - return [] - -class PathImporter(TestImporter): - def _get__path__(self): - return [self.path] - - -class ImportBlocker: - """Place an ImportBlocker instance on sys.meta_path and you - can be sure the modules you specified can't be imported, even - if it's a builtin.""" - def __init__(self, *namestoblock): - self.namestoblock = dict.fromkeys(namestoblock) - def find_module(self, fullname, path=None): - if fullname in self.namestoblock: - return self - return None - def load_module(self, fullname): - raise ImportError("I dare you") - - -class ImpWrapper: - - def __init__(self, path=None): - if path is not None and not os.path.isdir(path): - raise ImportError - self.path = path - - def find_module(self, fullname, path=None): - subname = fullname.split(".")[-1] - if subname != fullname and self.path is None: - return None - if self.path is None: - path = None - else: - path = [self.path] - try: - file, filename, stuff = imp.find_module(subname, path) - except ImportError: - return None - return ImpLoader(file, filename, stuff) - - -class ImpLoader: - - def __init__(self, file, filename, stuff): - self.file = file - self.filename = filename - self.stuff = stuff - - def load_module(self, fullname): - mod = imp.load_module(fullname, self.file, self.filename, self.stuff) - if self.file: - self.file.close() - mod.__loader__ = self # for introspection - return mod - - -class ImportHooksBaseTestCase(unittest.TestCase): - - def setUp(self): - self.path = sys.path[:] - self.meta_path = sys.meta_path[:] - self.path_hooks = sys.path_hooks[:] - sys.path_importer_cache.clear() - self.modules_before = support.modules_setup() - - def tearDown(self): - sys.path[:] = self.path - sys.meta_path[:] = self.meta_path - sys.path_hooks[:] = self.path_hooks - sys.path_importer_cache.clear() - support.modules_cleanup(*self.modules_before) - - -class ImportHooksTestCase(ImportHooksBaseTestCase): - - def doTestImports(self, importer=None): - import hooktestmodule - import hooktestpackage - import hooktestpackage.sub - import hooktestpackage.sub.subber - self.assertEqual(hooktestmodule.get_name(), - "hooktestmodule") - self.assertEqual(hooktestpackage.get_name(), - "hooktestpackage") - self.assertEqual(hooktestpackage.sub.get_name(), - "hooktestpackage.sub") - self.assertEqual(hooktestpackage.sub.subber.get_name(), - "hooktestpackage.sub.subber") - if importer: - self.assertEqual(hooktestmodule.__loader__, importer) - self.assertEqual(hooktestpackage.__loader__, importer) - self.assertEqual(hooktestpackage.sub.__loader__, importer) - self.assertEqual(hooktestpackage.sub.subber.__loader__, importer) - - TestImporter.modules['reloadmodule'] = (False, test_co) - import reloadmodule - self.assertFalse(hasattr(reloadmodule,'reloaded')) - - import hooktestpackage.newrel - self.assertEqual(hooktestpackage.newrel.get_name(), - "hooktestpackage.newrel") - self.assertEqual(hooktestpackage.newrel.sub, - hooktestpackage.sub) - - import hooktestpackage.sub.subber.subest as subest - self.assertEqual(subest.get_name(), - "hooktestpackage.sub.subber.subest") - self.assertEqual(subest.sub, - hooktestpackage.sub) - - import hooktestpackage.futrel - self.assertEqual(hooktestpackage.futrel.get_name(), - "hooktestpackage.futrel") - self.assertEqual(hooktestpackage.futrel.sub, - hooktestpackage.sub) - - import sub - self.assertEqual(sub.get_name(), "sub") - - import hooktestpackage.oldabs - self.assertEqual(hooktestpackage.oldabs.get_name(), - "hooktestpackage.oldabs") - self.assertEqual(hooktestpackage.oldabs.sub, sub) - - import hooktestpackage.newabs - self.assertEqual(hooktestpackage.newabs.get_name(), - "hooktestpackage.newabs") - self.assertEqual(hooktestpackage.newabs.sub, sub) - - def testMetaPath(self): - i = MetaImporter() - sys.meta_path.append(i) - self.doTestImports(i) - - def testPathHook(self): - sys.path_hooks.insert(0, PathImporter) - sys.path.append(test_path) - self.doTestImports() - - def testBlocker(self): - mname = "exceptions" # an arbitrary harmless builtin module - support.unload(mname) - sys.meta_path.append(ImportBlocker(mname)) - self.assertRaises(ImportError, __import__, mname) - - def testImpWrapper(self): - i = ImpWrapper() - sys.meta_path.append(i) - sys.path_hooks.insert(0, ImpWrapper) - mnames = ( - "colorsys", "urllib.parse", "distutils.core", "sys", - ) - for mname in mnames: - parent = mname.split(".")[0] - for n in list(sys.modules): - if n.startswith(parent): - del sys.modules[n] - for mname in mnames: - m = __import__(mname, globals(), locals(), ["__dummy__"]) - # to make sure we actually handled the import - self.assertTrue(hasattr(m, "__loader__")) - - -def test_main(): - support.run_unittest(ImportHooksTestCase) - -if __name__ == "__main__": - test_main() diff --git a/Lib/test/test_importlib/__main__.py b/Lib/test/test_importlib/__main__.py index c39712871f..14bd5bc017 100644 --- a/Lib/test/test_importlib/__main__.py +++ b/Lib/test/test_importlib/__main__.py @@ -4,17 +4,6 @@ Specifying the ``--builtin`` flag will run tests, where applicable, with builtins.__import__ instead of importlib.__import__. """ -from . import test_main - - if __name__ == '__main__': - import argparse - - parser = argparse.ArgumentParser(description='Execute the importlib test ' - 'suite') - parser.add_argument('-b', '--builtin', action='store_true', default=False, - help='use builtins.__import__() instead of importlib') - args = parser.parse_args() - if args.builtin: - util.using___import__ = True + from . import test_main test_main() diff --git a/Lib/test/test_importlib/abc.py b/Lib/test/test_importlib/abc.py index 2c17ac329b..2070dadc23 100644 --- a/Lib/test/test_importlib/abc.py +++ b/Lib/test/test_importlib/abc.py @@ -2,7 +2,7 @@ import abc import unittest -class FinderTests(unittest.TestCase, metaclass=abc.ABCMeta): +class FinderTests(metaclass=abc.ABCMeta): """Basic tests for a finder to pass.""" @@ -39,7 +39,7 @@ class FinderTests(unittest.TestCase, metaclass=abc.ABCMeta): pass -class LoaderTests(unittest.TestCase, metaclass=abc.ABCMeta): +class LoaderTests(metaclass=abc.ABCMeta): @abc.abstractmethod def test_module(self): @@ -81,11 +81,6 @@ class LoaderTests(unittest.TestCase, metaclass=abc.ABCMeta): pass @abc.abstractmethod - def test_module_reuse(self): - """If a module is already in sys.modules, it should be reused.""" - pass - - @abc.abstractmethod def test_state_after_failure(self): """If a module is already in sys.modules and a reload fails (e.g. a SyntaxError), the module should be in the state it was before diff --git a/Lib/test/test_importlib/builtin/test_finder.py b/Lib/test/test_importlib/builtin/test_finder.py index da0ef28589..934562feb6 100644 --- a/Lib/test/test_importlib/builtin/test_finder.py +++ b/Lib/test/test_importlib/builtin/test_finder.py @@ -1,11 +1,53 @@ -from importlib import machinery from .. import abc from .. import util from . import util as builtin_util +frozen_machinery, source_machinery = util.import_importlib('importlib.machinery') + import sys import unittest + +class FindSpecTests(abc.FinderTests): + + """Test find_spec() for built-in modules.""" + + def test_module(self): + # Common case. + with util.uncache(builtin_util.NAME): + found = self.machinery.BuiltinImporter.find_spec(builtin_util.NAME) + self.assertTrue(found) + self.assertEqual(found.origin, 'built-in') + + # Built-in modules cannot be a package. + test_package = None + + # Built-in modules cannobt be in a package. + test_module_in_package = None + + # Built-in modules cannot be a package. + test_package_in_package = None + + # Built-in modules cannot be a package. + test_package_over_module = None + + def test_failure(self): + name = 'importlib' + assert name not in sys.builtin_module_names + spec = self.machinery.BuiltinImporter.find_spec(name) + self.assertIsNone(spec) + + def test_ignore_path(self): + # The value for 'path' should always trigger a failed import. + with util.uncache(builtin_util.NAME): + spec = self.machinery.BuiltinImporter.find_spec(builtin_util.NAME, + ['pkg']) + self.assertIsNone(spec) + +Frozen_FindSpecTests, Source_FindSpecTests = util.test_both(FindSpecTests, + machinery=[frozen_machinery, source_machinery]) + + class FinderTests(abc.FinderTests): """Test find_module() for built-in modules.""" @@ -13,8 +55,9 @@ class FinderTests(abc.FinderTests): def test_module(self): # Common case. with util.uncache(builtin_util.NAME): - found = machinery.BuiltinImporter.find_module(builtin_util.NAME) + found = self.machinery.BuiltinImporter.find_module(builtin_util.NAME) self.assertTrue(found) + self.assertTrue(hasattr(found, 'load_module')) # Built-in modules cannot be a package. test_package = test_package_in_package = test_package_over_module = None @@ -24,22 +67,19 @@ class FinderTests(abc.FinderTests): def test_failure(self): assert 'importlib' not in sys.builtin_module_names - loader = machinery.BuiltinImporter.find_module('importlib') + loader = self.machinery.BuiltinImporter.find_module('importlib') self.assertIsNone(loader) def test_ignore_path(self): # The value for 'path' should always trigger a failed import. with util.uncache(builtin_util.NAME): - loader = machinery.BuiltinImporter.find_module(builtin_util.NAME, + loader = self.machinery.BuiltinImporter.find_module(builtin_util.NAME, ['pkg']) self.assertIsNone(loader) - - -def test_main(): - from test.support import run_unittest - run_unittest(FinderTests) +Frozen_FinderTests, Source_FinderTests = util.test_both(FinderTests, + machinery=[frozen_machinery, source_machinery]) if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_importlib/builtin/test_loader.py b/Lib/test/test_importlib/builtin/test_loader.py index 12a79c6ba9..a636f778fa 100644 --- a/Lib/test/test_importlib/builtin/test_loader.py +++ b/Lib/test/test_importlib/builtin/test_loader.py @@ -1,9 +1,9 @@ -import importlib -from importlib import machinery from .. import abc from .. import util from . import util as builtin_util +frozen_machinery, source_machinery = util.import_importlib('importlib.machinery') + import sys import types import unittest @@ -13,8 +13,9 @@ class LoaderTests(abc.LoaderTests): """Test load_module() for built-in modules.""" - verification = {'__name__': 'errno', '__package__': '', - '__loader__': machinery.BuiltinImporter} + def setUp(self): + self.verification = {'__name__': 'errno', '__package__': '', + '__loader__': self.machinery.BuiltinImporter} def verify(self, module): """Verify that the module matches against what it should have.""" @@ -23,8 +24,8 @@ class LoaderTests(abc.LoaderTests): self.assertEqual(getattr(module, attr), value) self.assertIn(module.__name__, sys.modules) - load_module = staticmethod(lambda name: - machinery.BuiltinImporter.load_module(name)) + def load_module(self, name): + return self.machinery.BuiltinImporter.load_module(name) def test_module(self): # Common case. @@ -55,45 +56,51 @@ class LoaderTests(abc.LoaderTests): def test_already_imported(self): # Using the name of a module already imported but not a built-in should # still fail. - assert hasattr(importlib, '__file__') # Not a built-in. + module_name = 'builtin_reload_test' + assert module_name not in sys.builtin_module_names + with util.uncache(module_name): + module = types.ModuleType(module_name) + sys.modules[module_name] = module with self.assertRaises(ImportError) as cm: - self.load_module('importlib') - self.assertEqual(cm.exception.name, 'importlib') + self.load_module(module_name) + self.assertEqual(cm.exception.name, module_name) + + +Frozen_LoaderTests, Source_LoaderTests = util.test_both(LoaderTests, + machinery=[frozen_machinery, source_machinery]) -class InspectLoaderTests(unittest.TestCase): +class InspectLoaderTests: """Tests for InspectLoader methods for BuiltinImporter.""" def test_get_code(self): # There is no code object. - result = machinery.BuiltinImporter.get_code(builtin_util.NAME) + result = self.machinery.BuiltinImporter.get_code(builtin_util.NAME) self.assertIsNone(result) def test_get_source(self): # There is no source. - result = machinery.BuiltinImporter.get_source(builtin_util.NAME) + result = self.machinery.BuiltinImporter.get_source(builtin_util.NAME) self.assertIsNone(result) def test_is_package(self): # Cannot be a package. - result = machinery.BuiltinImporter.is_package(builtin_util.NAME) + result = self.machinery.BuiltinImporter.is_package(builtin_util.NAME) self.assertTrue(not result) def test_not_builtin(self): # Modules not built-in should raise ImportError. for meth_name in ('get_code', 'get_source', 'is_package'): - method = getattr(machinery.BuiltinImporter, meth_name) + method = getattr(self.machinery.BuiltinImporter, meth_name) with self.assertRaises(ImportError) as cm: method(builtin_util.BAD_NAME) self.assertRaises(builtin_util.BAD_NAME) - - -def test_main(): - from test.support import run_unittest - run_unittest(LoaderTests, InspectLoaderTests) +Frozen_InspectLoaderTests, Source_InspectLoaderTests = util.test_both( + InspectLoaderTests, + machinery=[frozen_machinery, source_machinery]) if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_importlib/extension/test_case_sensitivity.py b/Lib/test/test_importlib/extension/test_case_sensitivity.py index 76c53e4fd6..bb2528e626 100644 --- a/Lib/test/test_importlib/extension/test_case_sensitivity.py +++ b/Lib/test/test_importlib/extension/test_case_sensitivity.py @@ -1,22 +1,27 @@ -import imp +from importlib import _bootstrap import sys from test import support import unittest -from importlib import _bootstrap + from .. import util from . import util as ext_util +frozen_machinery, source_machinery = util.import_importlib('importlib.machinery') + +# XXX find_spec tests + +@unittest.skipIf(ext_util.FILENAME is None, '_testcapi not available') @util.case_insensitive_tests -class ExtensionModuleCaseSensitivityTest(unittest.TestCase): +class ExtensionModuleCaseSensitivityTest: def find_module(self): good_name = ext_util.NAME bad_name = good_name.upper() assert good_name != bad_name - finder = _bootstrap.FileFinder(ext_util.PATH, - (_bootstrap.ExtensionFileLoader, - _bootstrap.EXTENSION_SUFFIXES)) + finder = self.machinery.FileFinder(ext_util.PATH, + (self.machinery.ExtensionFileLoader, + self.machinery.EXTENSION_SUFFIXES)) return finder.find_module(bad_name) def test_case_sensitive(self): @@ -37,14 +42,10 @@ class ExtensionModuleCaseSensitivityTest(unittest.TestCase): loader = self.find_module() self.assertTrue(hasattr(loader, 'load_module')) - - - -def test_main(): - if ext_util.FILENAME is None: - return - support.run_unittest(ExtensionModuleCaseSensitivityTest) +Frozen_ExtensionCaseSensitivity, Source_ExtensionCaseSensitivity = util.test_both( + ExtensionModuleCaseSensitivityTest, + machinery=[frozen_machinery, source_machinery]) if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_importlib/extension/test_finder.py b/Lib/test/test_importlib/extension/test_finder.py index 37f67727b0..990f29c4e5 100644 --- a/Lib/test/test_importlib/extension/test_finder.py +++ b/Lib/test/test_importlib/extension/test_finder.py @@ -1,18 +1,25 @@ -from importlib import machinery from .. import abc +from .. import util as test_util from . import util +machinery = test_util.import_importlib('importlib.machinery') + import unittest +import warnings + +# XXX find_spec tests class FinderTests(abc.FinderTests): """Test the finder for extension modules.""" def find_module(self, fullname): - importer = machinery.FileFinder(util.PATH, - (machinery.ExtensionFileLoader, - machinery.EXTENSION_SUFFIXES)) - return importer.find_module(fullname) + importer = self.machinery.FileFinder(util.PATH, + (self.machinery.ExtensionFileLoader, + self.machinery.EXTENSION_SUFFIXES)) + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + return importer.find_module(fullname) def test_module(self): self.assertTrue(self.find_module(util.NAME)) @@ -29,11 +36,9 @@ class FinderTests(abc.FinderTests): def test_failure(self): self.assertIsNone(self.find_module('asdfjkl;')) - -def test_main(): - from test.support import run_unittest - run_unittest(FinderTests) +Frozen_FinderTests, Source_FinderTests = test_util.test_both( + FinderTests, machinery=machinery) if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_importlib/extension/test_loader.py b/Lib/test/test_importlib/extension/test_loader.py index e790675f58..fd9abf269b 100644 --- a/Lib/test/test_importlib/extension/test_loader.py +++ b/Lib/test/test_importlib/extension/test_loader.py @@ -1,10 +1,12 @@ -from importlib import machinery from . import util as ext_util from .. import abc from .. import util +machinery = util.import_importlib('importlib.machinery') + import os.path import sys +import types import unittest @@ -13,8 +15,8 @@ class LoaderTests(abc.LoaderTests): """Test load_module() for extension modules.""" def setUp(self): - self.loader = machinery.ExtensionFileLoader(ext_util.NAME, - ext_util.FILEPATH) + self.loader = self.machinery.ExtensionFileLoader(ext_util.NAME, + ext_util.FILEPATH) def load_module(self, fullname): return self.loader.load_module(fullname) @@ -26,6 +28,15 @@ class LoaderTests(abc.LoaderTests): with self.assertRaises(ImportError): self.load_module('XXX') + def test_equality(self): + other = self.machinery.ExtensionFileLoader(ext_util.NAME, + ext_util.FILEPATH) + self.assertEqual(self.loader, other) + + def test_inequality(self): + other = self.machinery.ExtensionFileLoader('_' + ext_util.NAME, + ext_util.FILEPATH) + self.assertNotEqual(self.loader, other) def test_module(self): with util.uncache(ext_util.NAME): @@ -36,7 +47,7 @@ class LoaderTests(abc.LoaderTests): self.assertEqual(getattr(module, attr), value) self.assertIn(ext_util.NAME, sys.modules) self.assertIsInstance(module.__loader__, - machinery.ExtensionFileLoader) + self.machinery.ExtensionFileLoader) # No extension module as __init__ available for testing. test_package = None @@ -61,16 +72,15 @@ class LoaderTests(abc.LoaderTests): def test_is_package(self): self.assertFalse(self.loader.is_package(ext_util.NAME)) - for suffix in machinery.EXTENSION_SUFFIXES: + for suffix in self.machinery.EXTENSION_SUFFIXES: path = os.path.join('some', 'path', 'pkg', '__init__' + suffix) - loader = machinery.ExtensionFileLoader('pkg', path) + loader = self.machinery.ExtensionFileLoader('pkg', path) self.assertTrue(loader.is_package('pkg')) +Frozen_LoaderTests, Source_LoaderTests = util.test_both( + LoaderTests, machinery=machinery) -def test_main(): - from test.support import run_unittest - run_unittest(LoaderTests) if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_importlib/extension/test_path_hook.py b/Lib/test/test_importlib/extension/test_path_hook.py index 1d969a1b28..49d6734711 100644 --- a/Lib/test/test_importlib/extension/test_path_hook.py +++ b/Lib/test/test_importlib/extension/test_path_hook.py @@ -1,32 +1,32 @@ -from importlib import machinery +from .. import util as test_util from . import util +machinery = test_util.import_importlib('importlib.machinery') + import collections -import imp import sys import unittest -class PathHookTests(unittest.TestCase): +class PathHookTests: """Test the path hook for extension modules.""" # XXX Should it only succeed for pre-existing directories? # XXX Should it only work for directories containing an extension module? def hook(self, entry): - return machinery.FileFinder.path_hook((machinery.ExtensionFileLoader, - machinery.EXTENSION_SUFFIXES))(entry) + return self.machinery.FileFinder.path_hook( + (self.machinery.ExtensionFileLoader, + self.machinery.EXTENSION_SUFFIXES))(entry) def test_success(self): # Path hook should handle a directory where a known extension module # exists. self.assertTrue(hasattr(self.hook(util.PATH), 'find_module')) - -def test_main(): - from test.support import run_unittest - run_unittest(PathHookTests) +Frozen_PathHooksTests, Source_PathHooksTests = test_util.test_both( + PathHookTests, machinery=machinery) if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_importlib/extension/util.py b/Lib/test/test_importlib/extension/util.py index a266dd98c8..8d089f0d7c 100644 --- a/Lib/test/test_importlib/extension/util.py +++ b/Lib/test/test_importlib/extension/util.py @@ -1,4 +1,3 @@ -import imp from importlib import machinery import os import sys diff --git a/Lib/test/test_importlib/frozen/test_finder.py b/Lib/test/test_importlib/frozen/test_finder.py index f0abe0e989..f9f97f3851 100644 --- a/Lib/test/test_importlib/frozen/test_finder.py +++ b/Lib/test/test_importlib/frozen/test_finder.py @@ -1,15 +1,52 @@ -from importlib import machinery from .. import abc +from .. import util + +machinery = util.import_importlib('importlib.machinery') import unittest +class FindSpecTests(abc.FinderTests): + + """Test finding frozen modules.""" + + def find(self, name, path=None): + finder = self.machinery.FrozenImporter + return finder.find_spec(name, path) + + def test_module(self): + name = '__hello__' + spec = self.find(name) + self.assertEqual(spec.origin, 'frozen') + + def test_package(self): + spec = self.find('__phello__') + self.assertIsNotNone(spec) + + def test_module_in_package(self): + spec = self.find('__phello__.spam', ['__phello__']) + self.assertIsNotNone(spec) + + # No frozen package within another package to test with. + test_package_in_package = None + + # No easy way to test. + test_package_over_module = None + + def test_failure(self): + spec = self.find('<not real>') + self.assertIsNone(spec) + +Frozen_FindSpecTests, Source_FindSpecTests = util.test_both(FindSpecTests, + machinery=machinery) + + class FinderTests(abc.FinderTests): """Test finding frozen modules.""" def find(self, name, path=None): - finder = machinery.FrozenImporter + finder = self.machinery.FrozenImporter return finder.find_module(name, path) def test_module(self): @@ -35,11 +72,9 @@ class FinderTests(abc.FinderTests): loader = self.find('<not real>') self.assertIsNone(loader) - -def test_main(): - from test.support import run_unittest - run_unittest(FinderTests) +Frozen_FinderTests, Source_FinderTests = util.test_both(FinderTests, + machinery=machinery) if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_importlib/frozen/test_loader.py b/Lib/test/test_importlib/frozen/test_loader.py index f3a8bf67c7..7c0146456d 100644 --- a/Lib/test/test_importlib/frozen/test_loader.py +++ b/Lib/test/test_importlib/frozen/test_loader.py @@ -1,18 +1,104 @@ -from importlib import machinery -import imp -import unittest from .. import abc from .. import util + +machinery = util.import_importlib('importlib.machinery') + + +import sys from test.support import captured_stdout +import types +import unittest +import warnings + + +class ExecModuleTests(abc.LoaderTests): + + def exec_module(self, name): + with util.uncache(name), captured_stdout() as stdout: + spec = self.machinery.ModuleSpec( + name, self.machinery.FrozenImporter, origin='frozen', + is_package=self.machinery.FrozenImporter.is_package(name)) + module = types.ModuleType(name) + module.__spec__ = spec + assert not hasattr(module, 'initialized') + self.machinery.FrozenImporter.exec_module(module) + self.assertTrue(module.initialized) + self.assertTrue(hasattr(module, '__spec__')) + self.assertEqual(module.__spec__.origin, 'frozen') + return module, stdout.getvalue() + + def test_module(self): + name = '__hello__' + module, output = self.exec_module(name) + check = {'__name__': name} + for attr, value in check.items(): + self.assertEqual(getattr(module, attr), value) + self.assertEqual(output, 'Hello world!\n') + self.assertTrue(hasattr(module, '__spec__')) + + def test_package(self): + name = '__phello__' + module, output = self.exec_module(name) + check = {'__name__': name} + for attr, value in check.items(): + attr_value = getattr(module, attr) + self.assertEqual(attr_value, value, + 'for {name}.{attr}, {given!r} != {expected!r}'.format( + name=name, attr=attr, given=attr_value, + expected=value)) + self.assertEqual(output, 'Hello world!\n') + + def test_lacking_parent(self): + name = '__phello__.spam' + with util.uncache('__phello__'): + module, output = self.exec_module(name) + check = {'__name__': name} + for attr, value in check.items(): + attr_value = getattr(module, attr) + self.assertEqual(attr_value, value, + 'for {name}.{attr}, {given} != {expected!r}'.format( + name=name, attr=attr, given=attr_value, + expected=value)) + self.assertEqual(output, 'Hello world!\n') + + def test_module_repr(self): + name = '__hello__' + module, output = self.exec_module(name) + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + repr_str = self.machinery.FrozenImporter.module_repr(module) + self.assertEqual(repr_str, + "<module '__hello__' (frozen)>") + + def test_module_repr_indirect(self): + name = '__hello__' + module, output = self.exec_module(name) + self.assertEqual(repr(module), + "<module '__hello__' (frozen)>") + + # No way to trigger an error in a frozen module. + test_state_after_failure = None + + def test_unloadable(self): + assert self.machinery.FrozenImporter.find_module('_not_real') is None + with self.assertRaises(ImportError) as cm: + self.exec_module('_not_real') + self.assertEqual(cm.exception.name, '_not_real') + +Frozen_ExecModuleTests, Source_ExecModuleTests = util.test_both(ExecModuleTests, + machinery=machinery) + class LoaderTests(abc.LoaderTests): def test_module(self): with util.uncache('__hello__'), captured_stdout() as stdout: - module = machinery.FrozenImporter.load_module('__hello__') + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + module = self.machinery.FrozenImporter.load_module('__hello__') check = {'__name__': '__hello__', '__package__': '', - '__loader__': machinery.FrozenImporter, + '__loader__': self.machinery.FrozenImporter, } for attr, value in check.items(): self.assertEqual(getattr(module, attr), value) @@ -21,11 +107,13 @@ class LoaderTests(abc.LoaderTests): def test_package(self): with util.uncache('__phello__'), captured_stdout() as stdout: - module = machinery.FrozenImporter.load_module('__phello__') + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + module = self.machinery.FrozenImporter.load_module('__phello__') check = {'__name__': '__phello__', '__package__': '__phello__', - '__path__': ['__phello__'], - '__loader__': machinery.FrozenImporter, + '__path__': [], + '__loader__': self.machinery.FrozenImporter, } for attr, value in check.items(): attr_value = getattr(module, attr) @@ -38,10 +126,12 @@ class LoaderTests(abc.LoaderTests): def test_lacking_parent(self): with util.uncache('__phello__', '__phello__.spam'), \ captured_stdout() as stdout: - module = machinery.FrozenImporter.load_module('__phello__.spam') + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + module = self.machinery.FrozenImporter.load_module('__phello__.spam') check = {'__name__': '__phello__.spam', '__package__': '__phello__', - '__loader__': machinery.FrozenImporter, + '__loader__': self.machinery.FrozenImporter, } for attr, value in check.items(): attr_value = getattr(module, attr) @@ -53,29 +143,43 @@ class LoaderTests(abc.LoaderTests): def test_module_reuse(self): with util.uncache('__hello__'), captured_stdout() as stdout: - module1 = machinery.FrozenImporter.load_module('__hello__') - module2 = machinery.FrozenImporter.load_module('__hello__') + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + module1 = self.machinery.FrozenImporter.load_module('__hello__') + module2 = self.machinery.FrozenImporter.load_module('__hello__') self.assertIs(module1, module2) self.assertEqual(stdout.getvalue(), 'Hello world!\nHello world!\n') def test_module_repr(self): with util.uncache('__hello__'), captured_stdout(): - module = machinery.FrozenImporter.load_module('__hello__') - self.assertEqual(repr(module), + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + module = self.machinery.FrozenImporter.load_module('__hello__') + repr_str = self.machinery.FrozenImporter.module_repr(module) + self.assertEqual(repr_str, "<module '__hello__' (frozen)>") + def test_module_repr_indirect(self): + with util.uncache('__hello__'), captured_stdout(): + module = self.machinery.FrozenImporter.load_module('__hello__') + self.assertEqual(repr(module), + "<module '__hello__' (frozen)>") + # No way to trigger an error in a frozen module. test_state_after_failure = None def test_unloadable(self): - assert machinery.FrozenImporter.find_module('_not_real') is None + assert self.machinery.FrozenImporter.find_module('_not_real') is None with self.assertRaises(ImportError) as cm: - machinery.FrozenImporter.load_module('_not_real') + self.machinery.FrozenImporter.load_module('_not_real') self.assertEqual(cm.exception.name, '_not_real') +Frozen_LoaderTests, Source_LoaderTests = util.test_both(LoaderTests, + machinery=machinery) -class InspectLoaderTests(unittest.TestCase): + +class InspectLoaderTests: """Tests for the InspectLoader methods for FrozenImporter.""" @@ -83,15 +187,15 @@ class InspectLoaderTests(unittest.TestCase): # Make sure that the code object is good. name = '__hello__' with captured_stdout() as stdout: - code = machinery.FrozenImporter.get_code(name) - mod = imp.new_module(name) + code = self.machinery.FrozenImporter.get_code(name) + mod = types.ModuleType(name) exec(code, mod.__dict__) self.assertTrue(hasattr(mod, 'initialized')) self.assertEqual(stdout.getvalue(), 'Hello world!\n') def test_get_source(self): # Should always return None. - result = machinery.FrozenImporter.get_source('__hello__') + result = self.machinery.FrozenImporter.get_source('__hello__') self.assertIsNone(result) def test_is_package(self): @@ -99,22 +203,20 @@ class InspectLoaderTests(unittest.TestCase): test_for = (('__hello__', False), ('__phello__', True), ('__phello__.spam', False)) for name, is_package in test_for: - result = machinery.FrozenImporter.is_package(name) + result = self.machinery.FrozenImporter.is_package(name) self.assertEqual(bool(result), is_package) def test_failure(self): # Raise ImportError for modules that are not frozen. for meth_name in ('get_code', 'get_source', 'is_package'): - method = getattr(machinery.FrozenImporter, meth_name) + method = getattr(self.machinery.FrozenImporter, meth_name) with self.assertRaises(ImportError) as cm: method('importlib') self.assertEqual(cm.exception.name, 'importlib') - -def test_main(): - from test.support import run_unittest - run_unittest(LoaderTests, InspectLoaderTests) +Frozen_ILTests, Source_ILTests = util.test_both(InspectLoaderTests, + machinery=machinery) if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_importlib/import_/test___loader__.py b/Lib/test/test_importlib/import_/test___loader__.py new file mode 100644 index 0000000000..6df80101fa --- /dev/null +++ b/Lib/test/test_importlib/import_/test___loader__.py @@ -0,0 +1,70 @@ +from importlib import machinery +import sys +import types +import unittest + +from .. import util +from . import util as import_util + + +class SpecLoaderMock: + + def find_spec(self, fullname, path=None, target=None): + return machinery.ModuleSpec(fullname, self) + + def exec_module(self, module): + pass + + +class SpecLoaderAttributeTests: + + def test___loader__(self): + loader = SpecLoaderMock() + with util.uncache('blah'), util.import_state(meta_path=[loader]): + module = self.__import__('blah') + self.assertEqual(loader, module.__loader__) + +Frozen_SpecTests, Source_SpecTests = util.test_both( + SpecLoaderAttributeTests, __import__=import_util.__import__) + + +class LoaderMock: + + def find_module(self, fullname, path=None): + return self + + def load_module(self, fullname): + sys.modules[fullname] = self.module + return self.module + + +class LoaderAttributeTests: + + def test___loader___missing(self): + module = types.ModuleType('blah') + try: + del module.__loader__ + except AttributeError: + pass + loader = LoaderMock() + loader.module = module + with util.uncache('blah'), util.import_state(meta_path=[loader]): + module = self.__import__('blah') + self.assertEqual(loader, module.__loader__) + + def test___loader___is_None(self): + module = types.ModuleType('blah') + module.__loader__ = None + loader = LoaderMock() + loader.module = module + with util.uncache('blah'), util.import_state(meta_path=[loader]): + returned_module = self.__import__('blah') + self.assertEqual(loader, module.__loader__) + + +Frozen_Tests, Source_Tests = util.test_both(LoaderAttributeTests, + __import__=import_util.__import__) + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_importlib/import_/test___package__.py b/Lib/test/test_importlib/import_/test___package__.py index 783cde1729..2e19725680 100644 --- a/Lib/test/test_importlib/import_/test___package__.py +++ b/Lib/test/test_importlib/import_/test___package__.py @@ -9,7 +9,7 @@ from .. import util from . import util as import_util -class Using__package__(unittest.TestCase): +class Using__package__: """Use of __package__ supercedes the use of __name__/__path__ to calculate what package a module belongs to. The basic algorithm is [__package__]:: @@ -36,10 +36,10 @@ class Using__package__(unittest.TestCase): def test_using___package__(self): # [__package__] - with util.mock_modules('pkg.__init__', 'pkg.fake') as importer: + with self.mock_modules('pkg.__init__', 'pkg.fake') as importer: with util.import_state(meta_path=[importer]): - import_util.import_('pkg.fake') - module = import_util.import_('', + self.__import__('pkg.fake') + module = self.__import__('', globals={'__package__': 'pkg.fake'}, fromlist=['attr'], level=2) self.assertEqual(module.__name__, 'pkg') @@ -49,10 +49,10 @@ class Using__package__(unittest.TestCase): globals_ = {'__name__': 'pkg.fake', '__path__': []} if package_as_None: globals_['__package__'] = None - with util.mock_modules('pkg.__init__', 'pkg.fake') as importer: + with self.mock_modules('pkg.__init__', 'pkg.fake') as importer: with util.import_state(meta_path=[importer]): - import_util.import_('pkg.fake') - module = import_util.import_('', globals= globals_, + self.__import__('pkg.fake') + module = self.__import__('', globals= globals_, fromlist=['attr'], level=2) self.assertEqual(module.__name__, 'pkg') @@ -63,16 +63,27 @@ class Using__package__(unittest.TestCase): def test_bad__package__(self): globals = {'__package__': '<not real>'} with self.assertRaises(SystemError): - import_util.import_('', globals, {}, ['relimport'], 1) + self.__import__('', globals, {}, ['relimport'], 1) def test_bunk__package__(self): globals = {'__package__': 42} with self.assertRaises(TypeError): - import_util.import_('', globals, {}, ['relimport'], 1) + self.__import__('', globals, {}, ['relimport'], 1) +class Using__package__PEP302(Using__package__): + mock_modules = util.mock_modules -@import_util.importlib_only -class Setting__package__(unittest.TestCase): +Frozen_UsingPackagePEP302, Source_UsingPackagePEP302 = util.test_both( + Using__package__PEP302, __import__=import_util.__import__) + +class Using__package__PEP302(Using__package__): + mock_modules = util.mock_spec + +Frozen_UsingPackagePEP451, Source_UsingPackagePEP451 = util.test_both( + Using__package__PEP302, __import__=import_util.__import__) + + +class Setting__package__: """Because __package__ is a new feature, it is not always set by a loader. Import will set it as needed to help with the transition to relying on @@ -84,36 +95,39 @@ class Setting__package__(unittest.TestCase): """ + __import__ = import_util.__import__[1] + # [top-level] def test_top_level(self): - with util.mock_modules('top_level') as mock: + with self.mock_modules('top_level') as mock: with util.import_state(meta_path=[mock]): del mock['top_level'].__package__ - module = import_util.import_('top_level') + module = self.__import__('top_level') self.assertEqual(module.__package__, '') # [package] def test_package(self): - with util.mock_modules('pkg.__init__') as mock: + with self.mock_modules('pkg.__init__') as mock: with util.import_state(meta_path=[mock]): del mock['pkg'].__package__ - module = import_util.import_('pkg') + module = self.__import__('pkg') self.assertEqual(module.__package__, 'pkg') # [submodule] def test_submodule(self): - with util.mock_modules('pkg.__init__', 'pkg.mod') as mock: + with self.mock_modules('pkg.__init__', 'pkg.mod') as mock: with util.import_state(meta_path=[mock]): del mock['pkg.mod'].__package__ - pkg = import_util.import_('pkg.mod') + pkg = self.__import__('pkg.mod') module = getattr(pkg, 'mod') self.assertEqual(module.__package__, 'pkg') +class Setting__package__PEP302(Setting__package__, unittest.TestCase): + mock_modules = util.mock_modules -def test_main(): - from test.support import run_unittest - run_unittest(Using__package__, Setting__package__) +class Setting__package__PEP451(Setting__package__, unittest.TestCase): + mock_modules = util.mock_spec if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_importlib/import_/test_api.py b/Lib/test/test_importlib/import_/test_api.py index 3d4cd94889..439c105e9b 100644 --- a/Lib/test/test_importlib/import_/test_api.py +++ b/Lib/test/test_importlib/import_/test_api.py @@ -1,23 +1,41 @@ -from .. import util as importlib_test_util -from . import util -import imp +from .. import util +from . import util as import_util + +from importlib import machinery import sys +import types import unittest +PKG_NAME = 'fine' +SUBMOD_NAME = 'fine.bogus' + + +class BadSpecFinderLoader: + @classmethod + def find_spec(cls, fullname, path=None, target=None): + if fullname == SUBMOD_NAME: + spec = machinery.ModuleSpec(fullname, cls) + return spec + + @staticmethod + def exec_module(module): + if module.__name__ == SUBMOD_NAME: + raise ImportError('I cannot be loaded!') + class BadLoaderFinder: - bad = 'fine.bogus' @classmethod def find_module(cls, fullname, path): - if fullname == cls.bad: + if fullname == SUBMOD_NAME: return cls + @classmethod def load_module(cls, fullname): - if fullname == cls.bad: + if fullname == SUBMOD_NAME: raise ImportError('I cannot be loaded!') -class APITest(unittest.TestCase): +class APITest: """Test API-specific details for __import__ (e.g. raising the right exception when passing in an int for the module name).""" @@ -25,43 +43,52 @@ class APITest(unittest.TestCase): def test_name_requires_rparition(self): # Raise TypeError if a non-string is passed in for the module name. with self.assertRaises(TypeError): - util.import_(42) + self.__import__(42) def test_negative_level(self): # Raise ValueError when a negative level is specified. # PEP 328 did away with sys.module None entries and the ambiguity of # absolute/relative imports. with self.assertRaises(ValueError): - util.import_('os', globals(), level=-1) + self.__import__('os', globals(), level=-1) def test_nonexistent_fromlist_entry(self): # If something in fromlist doesn't exist, that's okay. # issue15715 - mod = imp.new_module('fine') + mod = types.ModuleType(PKG_NAME) mod.__path__ = ['XXX'] - with importlib_test_util.import_state(meta_path=[BadLoaderFinder]): - with importlib_test_util.uncache('fine'): - sys.modules['fine'] = mod - util.import_('fine', fromlist=['not here']) + with util.import_state(meta_path=[self.bad_finder_loader]): + with util.uncache(PKG_NAME): + sys.modules[PKG_NAME] = mod + self.__import__(PKG_NAME, fromlist=['not here']) def test_fromlist_load_error_propagates(self): # If something in fromlist triggers an exception not related to not # existing, let that exception propagate. # issue15316 - mod = imp.new_module('fine') + mod = types.ModuleType(PKG_NAME) mod.__path__ = ['XXX'] - with importlib_test_util.import_state(meta_path=[BadLoaderFinder]): - with importlib_test_util.uncache('fine'): - sys.modules['fine'] = mod + with util.import_state(meta_path=[self.bad_finder_loader]): + with util.uncache(PKG_NAME): + sys.modules[PKG_NAME] = mod with self.assertRaises(ImportError): - util.import_('fine', fromlist=['bogus']) + self.__import__(PKG_NAME, + fromlist=[SUBMOD_NAME.rpartition('.')[-1]]) + + +class OldAPITests(APITest): + bad_finder_loader = BadLoaderFinder + +Frozen_OldAPITests, Source_OldAPITests = util.test_both( + OldAPITests, __import__=import_util.__import__) +class SpecAPITests(APITest): + bad_finder_loader = BadSpecFinderLoader -def test_main(): - from test.support import run_unittest - run_unittest(APITest) +Frozen_SpecAPITests, Source_SpecAPITests = util.test_both( + SpecAPITests, __import__=import_util.__import__) if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_importlib/import_/test_caching.py b/Lib/test/test_importlib/import_/test_caching.py index 207378a6fb..c292ee4ac2 100644 --- a/Lib/test/test_importlib/import_/test_caching.py +++ b/Lib/test/test_importlib/import_/test_caching.py @@ -6,7 +6,7 @@ from types import MethodType import unittest -class UseCache(unittest.TestCase): +class UseCache: """When it comes to sys.modules, import prefers it over anything else. @@ -21,12 +21,13 @@ class UseCache(unittest.TestCase): ImportError is raised [None in cache]. """ + def test_using_cache(self): # [use cache] module_to_use = "some module found!" with util.uncache('some_module'): sys.modules['some_module'] = module_to_use - module = import_util.import_('some_module') + module = self.__import__('some_module') self.assertEqual(id(module_to_use), id(module)) def test_None_in_cache(self): @@ -35,9 +36,19 @@ class UseCache(unittest.TestCase): with util.uncache(name): sys.modules[name] = None with self.assertRaises(ImportError) as cm: - import_util.import_(name) + self.__import__(name) self.assertEqual(cm.exception.name, name) +Frozen_UseCache, Source_UseCache = util.test_both( + UseCache, __import__=import_util.__import__) + + +class ImportlibUseCache(UseCache, unittest.TestCase): + + # Pertinent only to PEP 302; exec_module() doesn't return a module. + + __import__ = import_util.__import__[1] + def create_mock(self, *names, return_=None): mock = util.mock_modules(*names) original_load = mock.load_module @@ -49,40 +60,33 @@ class UseCache(unittest.TestCase): # __import__ inconsistent between loaders and built-in import when it comes # to when to use the module in sys.modules and when not to. - @import_util.importlib_only def test_using_cache_after_loader(self): # [from cache on return] with self.create_mock('module') as mock: with util.import_state(meta_path=[mock]): - module = import_util.import_('module') + module = self.__import__('module') self.assertEqual(id(module), id(sys.modules['module'])) # See test_using_cache_after_loader() for reasoning. - @import_util.importlib_only def test_using_cache_for_assigning_to_attribute(self): # [from cache to attribute] with self.create_mock('pkg.__init__', 'pkg.module') as importer: with util.import_state(meta_path=[importer]): - module = import_util.import_('pkg.module') + module = self.__import__('pkg.module') self.assertTrue(hasattr(module, 'module')) self.assertEqual(id(module.module), id(sys.modules['pkg.module'])) # See test_using_cache_after_loader() for reasoning. - @import_util.importlib_only def test_using_cache_for_fromlist(self): # [from cache for fromlist] with self.create_mock('pkg.__init__', 'pkg.module') as importer: with util.import_state(meta_path=[importer]): - module = import_util.import_('pkg', fromlist=['module']) + module = self.__import__('pkg', fromlist=['module']) self.assertTrue(hasattr(module, 'module')) self.assertEqual(id(module.module), id(sys.modules['pkg.module'])) -def test_main(): - from test.support import run_unittest - run_unittest(UseCache) - if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_importlib/import_/test_fromlist.py b/Lib/test/test_importlib/import_/test_fromlist.py index c16c33710f..58f244bddc 100644 --- a/Lib/test/test_importlib/import_/test_fromlist.py +++ b/Lib/test/test_importlib/import_/test_fromlist.py @@ -1,10 +1,10 @@ """Test that the semantics relating to the 'fromlist' argument are correct.""" from .. import util from . import util as import_util -import imp import unittest -class ReturnValue(unittest.TestCase): + +class ReturnValue: """The use of fromlist influences what import returns. @@ -17,20 +17,23 @@ class ReturnValue(unittest.TestCase): def test_return_from_import(self): # [import return] - with util.mock_modules('pkg.__init__', 'pkg.module') as importer: + with util.mock_spec('pkg.__init__', 'pkg.module') as importer: with util.import_state(meta_path=[importer]): - module = import_util.import_('pkg.module') + module = self.__import__('pkg.module') self.assertEqual(module.__name__, 'pkg') def test_return_from_from_import(self): # [from return] with util.mock_modules('pkg.__init__', 'pkg.module')as importer: with util.import_state(meta_path=[importer]): - module = import_util.import_('pkg.module', fromlist=['attr']) + module = self.__import__('pkg.module', fromlist=['attr']) self.assertEqual(module.__name__, 'pkg.module') +Frozen_ReturnValue, Source_ReturnValue = util.test_both( + ReturnValue, __import__=import_util.__import__) + -class HandlingFromlist(unittest.TestCase): +class HandlingFromlist: """Using fromlist triggers different actions based on what is being asked of it. @@ -49,14 +52,14 @@ class HandlingFromlist(unittest.TestCase): # [object case] with util.mock_modules('module') as importer: with util.import_state(meta_path=[importer]): - module = import_util.import_('module', fromlist=['attr']) + module = self.__import__('module', fromlist=['attr']) self.assertEqual(module.__name__, 'module') def test_nonexistent_object(self): # [bad object] with util.mock_modules('module') as importer: with util.import_state(meta_path=[importer]): - module = import_util.import_('module', fromlist=['non_existent']) + module = self.__import__('module', fromlist=['non_existent']) self.assertEqual(module.__name__, 'module') self.assertTrue(not hasattr(module, 'non_existent')) @@ -64,7 +67,7 @@ class HandlingFromlist(unittest.TestCase): # [module] with util.mock_modules('pkg.__init__', 'pkg.module') as importer: with util.import_state(meta_path=[importer]): - module = import_util.import_('pkg', fromlist=['module']) + module = self.__import__('pkg', fromlist=['module']) self.assertEqual(module.__name__, 'pkg') self.assertTrue(hasattr(module, 'module')) self.assertEqual(module.module.__name__, 'pkg.module') @@ -79,13 +82,13 @@ class HandlingFromlist(unittest.TestCase): module_code={'pkg.mod': module_code}) as importer: with util.import_state(meta_path=[importer]): with self.assertRaises(ImportError) as exc: - import_util.import_('pkg', fromlist=['mod']) + self.__import__('pkg', fromlist=['mod']) self.assertEqual('i_do_not_exist', exc.exception.name) def test_empty_string(self): with util.mock_modules('pkg.__init__', 'pkg.mod') as importer: with util.import_state(meta_path=[importer]): - module = import_util.import_('pkg.mod', fromlist=['']) + module = self.__import__('pkg.mod', fromlist=['']) self.assertEqual(module.__name__, 'pkg.mod') def basic_star_test(self, fromlist=['*']): @@ -93,7 +96,7 @@ class HandlingFromlist(unittest.TestCase): with util.mock_modules('pkg.__init__', 'pkg.module') as mock: with util.import_state(meta_path=[mock]): mock['pkg'].__all__ = ['module'] - module = import_util.import_('pkg', fromlist=fromlist) + module = self.__import__('pkg', fromlist=fromlist) self.assertEqual(module.__name__, 'pkg') self.assertTrue(hasattr(module, 'module')) self.assertEqual(module.module.__name__, 'pkg.module') @@ -111,17 +114,16 @@ class HandlingFromlist(unittest.TestCase): with context as mock: with util.import_state(meta_path=[mock]): mock['pkg'].__all__ = ['module1'] - module = import_util.import_('pkg', fromlist=['module2', '*']) + module = self.__import__('pkg', fromlist=['module2', '*']) self.assertEqual(module.__name__, 'pkg') self.assertTrue(hasattr(module, 'module1')) self.assertTrue(hasattr(module, 'module2')) self.assertEqual(module.module1.__name__, 'pkg.module1') self.assertEqual(module.module2.__name__, 'pkg.module2') +Frozen_FromList, Source_FromList = util.test_both( + HandlingFromlist, __import__=import_util.__import__) -def test_main(): - from test.support import run_unittest - run_unittest(ReturnValue, HandlingFromlist) if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_importlib/import_/test_meta_path.py b/Lib/test/test_importlib/import_/test_meta_path.py index 4d85f80e67..dc45420621 100644 --- a/Lib/test/test_importlib/import_/test_meta_path.py +++ b/Lib/test/test_importlib/import_/test_meta_path.py @@ -7,7 +7,7 @@ import unittest import warnings -class CallingOrder(unittest.TestCase): +class CallingOrder: """Calls to the importers on sys.meta_path happen in order that they are specified in the sequence, starting with the first importer @@ -18,23 +18,18 @@ class CallingOrder(unittest.TestCase): def test_first_called(self): # [first called] mod = 'top_level' - first = util.mock_modules(mod) - second = util.mock_modules(mod) - with util.mock_modules(mod) as first, util.mock_modules(mod) as second: - first.modules[mod] = 42 - second.modules[mod] = -13 + with util.mock_spec(mod) as first, util.mock_spec(mod) as second: with util.import_state(meta_path=[first, second]): - self.assertEqual(import_util.import_(mod), 42) + self.assertIs(self.__import__(mod), first.modules[mod]) def test_continuing(self): # [continuing] mod_name = 'for_real' - with util.mock_modules('nonexistent') as first, \ - util.mock_modules(mod_name) as second: - first.find_module = lambda self, fullname, path=None: None - second.modules[mod_name] = 42 + with util.mock_spec('nonexistent') as first, \ + util.mock_spec(mod_name) as second: + first.find_spec = lambda self, fullname, path=None, parent=None: None with util.import_state(meta_path=[first, second]): - self.assertEqual(import_util.import_(mod_name), 42) + self.assertIs(self.__import__(mod_name), second.modules[mod_name]) def test_empty(self): # Raise an ImportWarning if sys.meta_path is empty. @@ -46,41 +41,42 @@ class CallingOrder(unittest.TestCase): with util.import_state(meta_path=[]): with warnings.catch_warnings(record=True) as w: warnings.simplefilter('always') - self.assertIsNone(importlib._bootstrap._find_module('nothing', - None)) + self.assertIsNone(importlib._bootstrap._find_spec('nothing', + None)) self.assertEqual(len(w), 1) self.assertTrue(issubclass(w[-1].category, ImportWarning)) +Frozen_CallingOrder, Source_CallingOrder = util.test_both( + CallingOrder, __import__=import_util.__import__) -class CallSignature(unittest.TestCase): + +class CallSignature: """If there is no __path__ entry on the parent module, then 'path' is None [no path]. Otherwise, the value for __path__ is passed in for the 'path' argument [path set].""" - def log(self, fxn): + def log_finder(self, importer): + fxn = getattr(importer, self.finder_name) log = [] def wrapper(self, *args, **kwargs): log.append([args, kwargs]) return fxn(*args, **kwargs) return log, wrapper - def test_no_path(self): # [no path] mod_name = 'top_level' assert '.' not in mod_name - with util.mock_modules(mod_name) as importer: - log, wrapped_call = self.log(importer.find_module) - importer.find_module = MethodType(wrapped_call, importer) + with self.mock_modules(mod_name) as importer: + log, wrapped_call = self.log_finder(importer) + setattr(importer, self.finder_name, MethodType(wrapped_call, importer)) with util.import_state(meta_path=[importer]): - import_util.import_(mod_name) + self.__import__(mod_name) assert len(log) == 1 args = log[0][0] kwargs = log[0][1] # Assuming all arguments are positional. - self.assertEqual(len(args), 2) - self.assertEqual(len(kwargs), 0) self.assertEqual(args[0], mod_name) self.assertIsNone(args[1]) @@ -90,12 +86,12 @@ class CallSignature(unittest.TestCase): mod_name = pkg_name + '.module' path = [42] assert '.' in mod_name - with util.mock_modules(pkg_name+'.__init__', mod_name) as importer: + with self.mock_modules(pkg_name+'.__init__', mod_name) as importer: importer.modules[pkg_name].__path__ = path - log, wrapped_call = self.log(importer.find_module) - importer.find_module = MethodType(wrapped_call, importer) + log, wrapped_call = self.log_finder(importer) + setattr(importer, self.finder_name, MethodType(wrapped_call, importer)) with util.import_state(meta_path=[importer]): - import_util.import_(mod_name) + self.__import__(mod_name) assert len(log) == 2 args = log[1][0] kwargs = log[1][1] @@ -104,12 +100,20 @@ class CallSignature(unittest.TestCase): self.assertEqual(args[0], mod_name) self.assertIs(args[1], path) +class CallSignaturePEP302(CallSignature): + mock_modules = util.mock_modules + finder_name = 'find_module' + +Frozen_CallSignaturePEP302, Source_CallSignaturePEP302 = util.test_both( + CallSignaturePEP302, __import__=import_util.__import__) +class CallSignaturePEP451(CallSignature): + mock_modules = util.mock_spec + finder_name = 'find_spec' -def test_main(): - from test.support import run_unittest - run_unittest(CallingOrder, CallSignature) +Frozen_CallSignaturePEP451, Source_CallSignaturePEP451 = util.test_both( + CallSignaturePEP451, __import__=import_util.__import__) if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_importlib/import_/test_packages.py b/Lib/test/test_importlib/import_/test_packages.py index bfa18dc217..55a5d14985 100644 --- a/Lib/test/test_importlib/import_/test_packages.py +++ b/Lib/test/test_importlib/import_/test_packages.py @@ -6,37 +6,37 @@ import importlib from test import support -class ParentModuleTests(unittest.TestCase): +class ParentModuleTests: """Importing a submodule should import the parent modules.""" def test_import_parent(self): - with util.mock_modules('pkg.__init__', 'pkg.module') as mock: + with util.mock_spec('pkg.__init__', 'pkg.module') as mock: with util.import_state(meta_path=[mock]): - module = import_util.import_('pkg.module') + module = self.__import__('pkg.module') self.assertIn('pkg', sys.modules) def test_bad_parent(self): - with util.mock_modules('pkg.module') as mock: + with util.mock_spec('pkg.module') as mock: with util.import_state(meta_path=[mock]): with self.assertRaises(ImportError) as cm: - import_util.import_('pkg.module') + self.__import__('pkg.module') self.assertEqual(cm.exception.name, 'pkg') def test_raising_parent_after_importing_child(self): def __init__(): import pkg.module 1/0 - mock = util.mock_modules('pkg.__init__', 'pkg.module', + mock = util.mock_spec('pkg.__init__', 'pkg.module', module_code={'pkg': __init__}) with mock: with util.import_state(meta_path=[mock]): with self.assertRaises(ZeroDivisionError): - import_util.import_('pkg') + self.__import__('pkg') self.assertNotIn('pkg', sys.modules) self.assertIn('pkg.module', sys.modules) with self.assertRaises(ZeroDivisionError): - import_util.import_('pkg.module') + self.__import__('pkg.module') self.assertNotIn('pkg', sys.modules) self.assertIn('pkg.module', sys.modules) @@ -44,17 +44,17 @@ class ParentModuleTests(unittest.TestCase): def __init__(): from . import module 1/0 - mock = util.mock_modules('pkg.__init__', 'pkg.module', + mock = util.mock_spec('pkg.__init__', 'pkg.module', module_code={'pkg': __init__}) with mock: with util.import_state(meta_path=[mock]): with self.assertRaises((ZeroDivisionError, ImportError)): # This raises ImportError on the "from . import module" # line, not sure why. - import_util.import_('pkg') + self.__import__('pkg') self.assertNotIn('pkg', sys.modules) with self.assertRaises((ZeroDivisionError, ImportError)): - import_util.import_('pkg.module') + self.__import__('pkg.module') self.assertNotIn('pkg', sys.modules) # XXX False #self.assertIn('pkg.module', sys.modules) @@ -63,7 +63,7 @@ class ParentModuleTests(unittest.TestCase): def __init__(): from ..subpkg import module 1/0 - mock = util.mock_modules('pkg.__init__', 'pkg.subpkg.__init__', + mock = util.mock_spec('pkg.__init__', 'pkg.subpkg.__init__', 'pkg.subpkg.module', module_code={'pkg.subpkg': __init__}) with mock: @@ -71,10 +71,10 @@ class ParentModuleTests(unittest.TestCase): with self.assertRaises((ZeroDivisionError, ImportError)): # This raises ImportError on the "from ..subpkg import module" # line, not sure why. - import_util.import_('pkg.subpkg') + self.__import__('pkg.subpkg') self.assertNotIn('pkg.subpkg', sys.modules) with self.assertRaises((ZeroDivisionError, ImportError)): - import_util.import_('pkg.subpkg.module') + self.__import__('pkg.subpkg.module') self.assertNotIn('pkg.subpkg', sys.modules) # XXX False #self.assertIn('pkg.subpkg.module', sys.modules) @@ -83,7 +83,7 @@ class ParentModuleTests(unittest.TestCase): # Try to import a submodule from a non-package should raise ImportError. assert not hasattr(sys, '__path__') with self.assertRaises(ImportError) as cm: - import_util.import_('sys.no_submodules_here') + self.__import__('sys.no_submodules_here') self.assertEqual(cm.exception.name, 'sys.no_submodules_here') def test_module_not_package_but_side_effects(self): @@ -93,20 +93,18 @@ class ParentModuleTests(unittest.TestCase): subname = name + '.b' def module_injection(): sys.modules[subname] = 'total bunk' - mock_modules = util.mock_modules('mod', + mock_spec = util.mock_spec('mod', module_code={'mod': module_injection}) - with mock_modules as mock: + with mock_spec as mock: with util.import_state(meta_path=[mock]): try: - submodule = import_util.import_(subname) + submodule = self.__import__(subname) finally: support.unload(subname) - -def test_main(): - from test.support import run_unittest - run_unittest(ParentModuleTests) +Frozen_ParentTests, Source_ParentTests = util.test_both( + ParentModuleTests, __import__=import_util.__import__) if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_importlib/import_/test_path.py b/Lib/test/test_importlib/import_/test_path.py index d82b7f6b0c..713e75447a 100644 --- a/Lib/test/test_importlib/import_/test_path.py +++ b/Lib/test/test_importlib/import_/test_path.py @@ -1,8 +1,9 @@ -from importlib import _bootstrap -from importlib import machinery -from importlib import import_module from .. import util from . import util as import_util + +importlib = util.import_importlib('importlib') +machinery = util.import_importlib('importlib.machinery') + import os import sys from types import ModuleType @@ -11,25 +12,25 @@ import warnings import zipimport -class FinderTests(unittest.TestCase): +class FinderTests: """Tests for PathFinder.""" def test_failure(self): - # Test None returned upon not finding a suitable finder. + # Test None returned upon not finding a suitable loader. module = '<test module>' with util.import_state(): - self.assertIsNone(machinery.PathFinder.find_module(module)) + self.assertIsNone(self.machinery.PathFinder.find_module(module)) def test_sys_path(self): # Test that sys.path is used when 'path' is None. # Implicitly tests that sys.path_importer_cache is used. module = '<test module>' path = '<test path>' - importer = util.mock_modules(module) + importer = util.mock_spec(module) with util.import_state(path_importer_cache={path: importer}, path=[path]): - loader = machinery.PathFinder.find_module(module) + loader = self.machinery.PathFinder.find_module(module) self.assertIs(loader, importer) def test_path(self): @@ -37,29 +38,29 @@ class FinderTests(unittest.TestCase): # Implicitly tests that sys.path_importer_cache is used. module = '<test module>' path = '<test path>' - importer = util.mock_modules(module) + importer = util.mock_spec(module) with util.import_state(path_importer_cache={path: importer}): - loader = machinery.PathFinder.find_module(module, [path]) + loader = self.machinery.PathFinder.find_module(module, [path]) self.assertIs(loader, importer) def test_empty_list(self): # An empty list should not count as asking for sys.path. module = 'module' path = '<test path>' - importer = util.mock_modules(module) + importer = util.mock_spec(module) with util.import_state(path_importer_cache={path: importer}, path=[path]): - self.assertIsNone(machinery.PathFinder.find_module('module', [])) + self.assertIsNone(self.machinery.PathFinder.find_module('module', [])) def test_path_hooks(self): # Test that sys.path_hooks is used. # Test that sys.path_importer_cache is set. module = '<test module>' path = '<test path>' - importer = util.mock_modules(module) + importer = util.mock_spec(module) hook = import_util.mock_path_hook(path, importer=importer) with util.import_state(path_hooks=[hook]): - loader = machinery.PathFinder.find_module(module, [path]) + loader = self.machinery.PathFinder.find_module(module, [path]) self.assertIs(loader, importer) self.assertIn(path, sys.path_importer_cache) self.assertIs(sys.path_importer_cache[path], importer) @@ -72,7 +73,7 @@ class FinderTests(unittest.TestCase): path=[path_entry]): with warnings.catch_warnings(record=True) as w: warnings.simplefilter('always') - self.assertIsNone(machinery.PathFinder.find_module('os')) + self.assertIsNone(self.machinery.PathFinder.find_module('os')) self.assertIsNone(sys.path_importer_cache[path_entry]) self.assertEqual(len(w), 1) self.assertTrue(issubclass(w[-1].category, ImportWarning)) @@ -81,12 +82,12 @@ class FinderTests(unittest.TestCase): # The empty string should create a finder using the cwd. path = '' module = '<test module>' - importer = util.mock_modules(module) - hook = import_util.mock_path_hook(os.curdir, importer=importer) + importer = util.mock_spec(module) + hook = import_util.mock_path_hook(os.getcwd(), importer=importer) with util.import_state(path=[path], path_hooks=[hook]): - loader = machinery.PathFinder.find_module(module) + loader = self.machinery.PathFinder.find_module(module) self.assertIs(loader, importer) - self.assertIn(os.curdir, sys.path_importer_cache) + self.assertIn(os.getcwd(), sys.path_importer_cache) def test_None_on_sys_path(self): # Putting None in sys.path[0] caused an import regression from Python @@ -96,8 +97,8 @@ class FinderTests(unittest.TestCase): new_path_importer_cache = sys.path_importer_cache.copy() new_path_importer_cache.pop(None, None) new_path_hooks = [zipimport.zipimporter, - _bootstrap.FileFinder.path_hook( - *_bootstrap._get_supported_file_loaders())] + self.machinery.FileFinder.path_hook( + *self.importlib._bootstrap._get_supported_file_loaders())] missing = object() email = sys.modules.pop('email', missing) try: @@ -105,16 +106,15 @@ class FinderTests(unittest.TestCase): path=new_path, path_importer_cache=new_path_importer_cache, path_hooks=new_path_hooks): - module = import_module('email') + module = self.importlib.import_module('email') self.assertIsInstance(module, ModuleType) finally: if email is not missing: sys.modules['email'] = email +Frozen_FinderTests, Source_FinderTests = util.test_both( + FinderTests, importlib=importlib, machinery=machinery) -def test_main(): - from test.support import run_unittest - run_unittest(FinderTests) if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_importlib/import_/test_relative_imports.py b/Lib/test/test_importlib/import_/test_relative_imports.py index 4569c26424..b216e9cc13 100644 --- a/Lib/test/test_importlib/import_/test_relative_imports.py +++ b/Lib/test/test_importlib/import_/test_relative_imports.py @@ -4,7 +4,7 @@ from . import util as import_util import sys import unittest -class RelativeImports(unittest.TestCase): +class RelativeImports: """PEP 328 introduced relative imports. This allows for imports to occur from within a package without having to specify the actual package name. @@ -64,7 +64,7 @@ class RelativeImports(unittest.TestCase): uncache_names.append(name) else: uncache_names.append(name[:-len('.__init__')]) - with util.mock_modules(*create) as importer: + with util.mock_spec(*create) as importer: with util.import_state(meta_path=[importer]): for global_ in globals_: with util.uncache(*uncache_names): @@ -76,8 +76,8 @@ class RelativeImports(unittest.TestCase): create = 'pkg.__init__', 'pkg.mod2' globals_ = {'__package__': 'pkg'}, {'__name__': 'pkg.mod1'} def callback(global_): - import_util.import_('pkg') # For __import__(). - module = import_util.import_('', global_, fromlist=['mod2'], level=1) + self.__import__('pkg') # For __import__(). + module = self.__import__('', global_, fromlist=['mod2'], level=1) self.assertEqual(module.__name__, 'pkg') self.assertTrue(hasattr(module, 'mod2')) self.assertEqual(module.mod2.attr, 'pkg.mod2') @@ -88,8 +88,8 @@ class RelativeImports(unittest.TestCase): create = 'pkg.__init__', 'pkg.mod2' globals_ = {'__package__': 'pkg'}, {'__name__': 'pkg.mod1'} def callback(global_): - import_util.import_('pkg') # For __import__(). - module = import_util.import_('mod2', global_, fromlist=['attr'], + self.__import__('pkg') # For __import__(). + module = self.__import__('mod2', global_, fromlist=['attr'], level=1) self.assertEqual(module.__name__, 'pkg.mod2') self.assertEqual(module.attr, 'pkg.mod2') @@ -101,8 +101,8 @@ class RelativeImports(unittest.TestCase): globals_ = ({'__package__': 'pkg'}, {'__name__': 'pkg', '__path__': ['blah']}) def callback(global_): - import_util.import_('pkg') # For __import__(). - module = import_util.import_('', global_, fromlist=['module'], + self.__import__('pkg') # For __import__(). + module = self.__import__('', global_, fromlist=['module'], level=1) self.assertEqual(module.__name__, 'pkg') self.assertTrue(hasattr(module, 'module')) @@ -114,8 +114,8 @@ class RelativeImports(unittest.TestCase): create = 'pkg.__init__', 'pkg.module' globals_ = {'__package__': 'pkg'}, {'__name__': 'pkg.module'} def callback(global_): - import_util.import_('pkg') # For __import__(). - module = import_util.import_('', global_, fromlist=['attr'], level=1) + self.__import__('pkg') # For __import__(). + module = self.__import__('', global_, fromlist=['attr'], level=1) self.assertEqual(module.__name__, 'pkg') self.relative_import_test(create, globals_, callback) @@ -126,7 +126,7 @@ class RelativeImports(unittest.TestCase): globals_ = ({'__package__': 'pkg.subpkg1'}, {'__name__': 'pkg.subpkg1', '__path__': ['blah']}) def callback(global_): - module = import_util.import_('', global_, fromlist=['subpkg2'], + module = self.__import__('', global_, fromlist=['subpkg2'], level=2) self.assertEqual(module.__name__, 'pkg') self.assertTrue(hasattr(module, 'subpkg2')) @@ -142,8 +142,8 @@ class RelativeImports(unittest.TestCase): {'__name__': 'pkg.pkg1.pkg2.pkg3.pkg4.pkg5', '__path__': ['blah']}) def callback(global_): - import_util.import_(globals_[0]['__package__']) - module = import_util.import_('', global_, fromlist=['attr'], level=6) + self.__import__(globals_[0]['__package__']) + module = self.__import__('', global_, fromlist=['attr'], level=6) self.assertEqual(module.__name__, 'pkg') self.relative_import_test(create, globals_, callback) @@ -153,9 +153,9 @@ class RelativeImports(unittest.TestCase): globals_ = ({'__package__': 'pkg'}, {'__name__': 'pkg', '__path__': ['blah']}) def callback(global_): - import_util.import_('pkg') + self.__import__('pkg') with self.assertRaises(ValueError): - import_util.import_('', global_, fromlist=['top_level'], + self.__import__('', global_, fromlist=['top_level'], level=2) self.relative_import_test(create, globals_, callback) @@ -164,16 +164,16 @@ class RelativeImports(unittest.TestCase): create = ['top_level', 'pkg.__init__', 'pkg.module'] globals_ = {'__package__': 'pkg'}, {'__name__': 'pkg.module'} def callback(global_): - import_util.import_('pkg') + self.__import__('pkg') with self.assertRaises(ValueError): - import_util.import_('', global_, fromlist=['top_level'], + self.__import__('', global_, fromlist=['top_level'], level=2) self.relative_import_test(create, globals_, callback) def test_empty_name_w_level_0(self): # [empty name] with self.assertRaises(ValueError): - import_util.import_('') + self.__import__('') def test_import_from_different_package(self): # Test importing from a different package than the caller. @@ -186,8 +186,8 @@ class RelativeImports(unittest.TestCase): '__runpy_pkg__.uncle.cousin.nephew'] globals_ = {'__package__': '__runpy_pkg__.__runpy_pkg__'} def callback(global_): - import_util.import_('__runpy_pkg__.__runpy_pkg__') - module = import_util.import_('uncle.cousin', globals_, {}, + self.__import__('__runpy_pkg__.__runpy_pkg__') + module = self.__import__('uncle.cousin', globals_, {}, fromlist=['nephew'], level=2) self.assertEqual(module.__name__, '__runpy_pkg__.uncle.cousin') @@ -198,20 +198,19 @@ class RelativeImports(unittest.TestCase): create = ['crash.__init__', 'crash.mod'] globals_ = [{'__package__': 'crash', '__name__': 'crash'}] def callback(global_): - import_util.import_('crash') - mod = import_util.import_('mod', global_, {}, [], 1) + self.__import__('crash') + mod = self.__import__('mod', global_, {}, [], 1) self.assertEqual(mod.__name__, 'crash.mod') self.relative_import_test(create, globals_, callback) def test_relative_import_no_globals(self): # No globals for a relative import is an error. with self.assertRaises(KeyError): - import_util.import_('sys', level=1) + self.__import__('sys', level=1) +Frozen_RelativeImports, Source_RelativeImports = util.test_both( + RelativeImports, __import__=import_util.__import__) -def test_main(): - from test.support import run_unittest - run_unittest(RelativeImports) if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_importlib/import_/util.py b/Lib/test/test_importlib/import_/util.py index 86ac065e64..dcb490f967 100644 --- a/Lib/test/test_importlib/import_/util.py +++ b/Lib/test/test_importlib/import_/util.py @@ -1,22 +1,14 @@ +from .. import util + +frozen_importlib, source_importlib = util.import_importlib('importlib') + +import builtins import functools import importlib import unittest -using___import__ = False - - -def import_(*args, **kwargs): - """Delegate to allow for injecting different implementations of import.""" - if using___import__: - return __import__(*args, **kwargs) - else: - return importlib.__import__(*args, **kwargs) - - -def importlib_only(fxn): - """Decorator to skip a test if using __builtins__.__import__.""" - return unittest.skipIf(using___import__, "importlib-specific test")(fxn) +__import__ = staticmethod(builtins.__import__), staticmethod(source_importlib.__import__) def mock_path_hook(*entries, importer): diff --git a/Lib/test/test_importlib/source/test_abc_loader.py b/Lib/test/test_importlib/source/test_abc_loader.py deleted file mode 100644 index 0d912b6469..0000000000 --- a/Lib/test/test_importlib/source/test_abc_loader.py +++ /dev/null @@ -1,906 +0,0 @@ -import importlib -from importlib import abc - -from .. import abc as testing_abc -from .. import util -from . import util as source_util - -import imp -import inspect -import io -import marshal -import os -import sys -import types -import unittest -import warnings - - -class SourceOnlyLoaderMock(abc.SourceLoader): - - # Globals that should be defined for all modules. - source = (b"_ = '::'.join([__name__, __file__, __cached__, __package__, " - b"repr(__loader__)])") - - def __init__(self, path): - self.path = path - - def get_data(self, path): - assert self.path == path - return self.source - - def get_filename(self, fullname): - return self.path - - def module_repr(self, module): - return '<module>' - - -class SourceLoaderMock(SourceOnlyLoaderMock): - - source_mtime = 1 - - def __init__(self, path, magic=imp.get_magic()): - super().__init__(path) - self.bytecode_path = imp.cache_from_source(self.path) - self.source_size = len(self.source) - data = bytearray(magic) - data.extend(importlib._w_long(self.source_mtime)) - data.extend(importlib._w_long(self.source_size)) - code_object = compile(self.source, self.path, 'exec', - dont_inherit=True) - data.extend(marshal.dumps(code_object)) - self.bytecode = bytes(data) - self.written = {} - - def get_data(self, path): - if path == self.path: - return super().get_data(path) - elif path == self.bytecode_path: - return self.bytecode - else: - raise IOError - - def path_stats(self, path): - assert path == self.path - return {'mtime': self.source_mtime, 'size': self.source_size} - - def set_data(self, path, data): - self.written[path] = bytes(data) - return path == self.bytecode_path - - -class PyLoaderMock(abc.PyLoader): - - # Globals that should be defined for all modules. - source = (b"_ = '::'.join([__name__, __file__, __package__, " - b"repr(__loader__)])") - - def __init__(self, data): - """Take a dict of 'module_name: path' pairings. - - Paths should have no file extension, allowing packages to be denoted by - ending in '__init__'. - - """ - self.module_paths = data - self.path_to_module = {val:key for key,val in data.items()} - - def get_data(self, path): - if path not in self.path_to_module: - raise IOError - return self.source - - def is_package(self, name): - filename = os.path.basename(self.get_filename(name)) - return os.path.splitext(filename)[0] == '__init__' - - def source_path(self, name): - try: - return self.module_paths[name] - except KeyError: - raise ImportError - - def get_filename(self, name): - """Silence deprecation warning.""" - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - path = super().get_filename(name) - assert len(w) == 1 - assert issubclass(w[0].category, DeprecationWarning) - return path - - def module_repr(self): - return '<module>' - - -class PyLoaderCompatMock(PyLoaderMock): - - """Mock that matches what is suggested to have a loader that is compatible - from Python 3.1 onwards.""" - - def get_filename(self, fullname): - try: - return self.module_paths[fullname] - except KeyError: - raise ImportError - - def source_path(self, fullname): - try: - return self.get_filename(fullname) - except ImportError: - return None - - -class PyPycLoaderMock(abc.PyPycLoader, PyLoaderMock): - - default_mtime = 1 - - def __init__(self, source, bc={}): - """Initialize mock. - - 'bc' is a dict keyed on a module's name. The value is dict with - possible keys of 'path', 'mtime', 'magic', and 'bc'. Except for 'path', - each of those keys control if any part of created bytecode is to - deviate from default values. - - """ - super().__init__(source) - self.module_bytecode = {} - self.path_to_bytecode = {} - self.bytecode_to_path = {} - for name, data in bc.items(): - self.path_to_bytecode[data['path']] = name - self.bytecode_to_path[name] = data['path'] - magic = data.get('magic', imp.get_magic()) - mtime = importlib._w_long(data.get('mtime', self.default_mtime)) - source_size = importlib._w_long(len(self.source) & 0xFFFFFFFF) - if 'bc' in data: - bc = data['bc'] - else: - bc = self.compile_bc(name) - self.module_bytecode[name] = magic + mtime + source_size + bc - - def compile_bc(self, name): - source_path = self.module_paths.get(name, '<test>') or '<test>' - code = compile(self.source, source_path, 'exec') - return marshal.dumps(code) - - def source_mtime(self, name): - if name in self.module_paths: - return self.default_mtime - elif name in self.module_bytecode: - return None - else: - raise ImportError - - def bytecode_path(self, name): - try: - return self.bytecode_to_path[name] - except KeyError: - if name in self.module_paths: - return None - else: - raise ImportError - - def write_bytecode(self, name, bytecode): - self.module_bytecode[name] = bytecode - return True - - def get_data(self, path): - if path in self.path_to_module: - return super().get_data(path) - elif path in self.path_to_bytecode: - name = self.path_to_bytecode[path] - return self.module_bytecode[name] - else: - raise IOError - - def is_package(self, name): - try: - return super().is_package(name) - except TypeError: - return '__init__' in self.bytecode_to_path[name] - - def get_code(self, name): - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - code_object = super().get_code(name) - assert len(w) == 1 - assert issubclass(w[0].category, DeprecationWarning) - return code_object - -class PyLoaderTests(testing_abc.LoaderTests): - - """Tests for importlib.abc.PyLoader.""" - - mocker = PyLoaderMock - - def eq_attrs(self, ob, **kwargs): - for attr, val in kwargs.items(): - found = getattr(ob, attr) - self.assertEqual(found, val, - "{} attribute: {} != {}".format(attr, found, val)) - - def test_module(self): - name = '<module>' - path = os.path.join('', 'path', 'to', 'module') - mock = self.mocker({name: path}) - with util.uncache(name): - module = mock.load_module(name) - self.assertIn(name, sys.modules) - self.eq_attrs(module, __name__=name, __file__=path, __package__='', - __loader__=mock) - self.assertTrue(not hasattr(module, '__path__')) - return mock, name - - def test_package(self): - name = '<pkg>' - path = os.path.join('path', 'to', name, '__init__') - mock = self.mocker({name: path}) - with util.uncache(name): - module = mock.load_module(name) - self.assertIn(name, sys.modules) - self.eq_attrs(module, __name__=name, __file__=path, - __path__=[os.path.dirname(path)], __package__=name, - __loader__=mock) - return mock, name - - def test_lacking_parent(self): - name = 'pkg.mod' - path = os.path.join('path', 'to', 'pkg', 'mod') - mock = self.mocker({name: path}) - with util.uncache(name): - module = mock.load_module(name) - self.assertIn(name, sys.modules) - self.eq_attrs(module, __name__=name, __file__=path, __package__='pkg', - __loader__=mock) - self.assertFalse(hasattr(module, '__path__')) - return mock, name - - def test_module_reuse(self): - name = 'mod' - path = os.path.join('path', 'to', 'mod') - module = imp.new_module(name) - mock = self.mocker({name: path}) - with util.uncache(name): - sys.modules[name] = module - loaded_module = mock.load_module(name) - self.assertIs(loaded_module, module) - self.assertIs(sys.modules[name], module) - return mock, name - - def test_state_after_failure(self): - name = "mod" - module = imp.new_module(name) - module.blah = None - mock = self.mocker({name: os.path.join('path', 'to', 'mod')}) - mock.source = b"1/0" - with util.uncache(name): - sys.modules[name] = module - with self.assertRaises(ZeroDivisionError): - mock.load_module(name) - self.assertIs(sys.modules[name], module) - self.assertTrue(hasattr(module, 'blah')) - return mock - - def test_unloadable(self): - name = "mod" - mock = self.mocker({name: os.path.join('path', 'to', 'mod')}) - mock.source = b"1/0" - with util.uncache(name): - with self.assertRaises(ZeroDivisionError): - mock.load_module(name) - self.assertNotIn(name, sys.modules) - return mock - - -class PyLoaderCompatTests(PyLoaderTests): - - """Test that the suggested code to make a loader that is compatible from - Python 3.1 forward works.""" - - mocker = PyLoaderCompatMock - - -class PyLoaderInterfaceTests(unittest.TestCase): - - """Tests for importlib.abc.PyLoader to make sure that when source_path() - doesn't return a path everything works as expected.""" - - def test_no_source_path(self): - # No source path should lead to ImportError. - name = 'mod' - mock = PyLoaderMock({}) - with util.uncache(name), self.assertRaises(ImportError): - mock.load_module(name) - - def test_source_path_is_None(self): - name = 'mod' - mock = PyLoaderMock({name: None}) - with util.uncache(name), self.assertRaises(ImportError): - mock.load_module(name) - - def test_get_filename_with_source_path(self): - # get_filename() should return what source_path() returns. - name = 'mod' - path = os.path.join('path', 'to', 'source') - mock = PyLoaderMock({name: path}) - with util.uncache(name): - self.assertEqual(mock.get_filename(name), path) - - def test_get_filename_no_source_path(self): - # get_filename() should raise ImportError if source_path returns None. - name = 'mod' - mock = PyLoaderMock({name: None}) - with util.uncache(name), self.assertRaises(ImportError): - mock.get_filename(name) - - -class PyPycLoaderTests(PyLoaderTests): - - """Tests for importlib.abc.PyPycLoader.""" - - mocker = PyPycLoaderMock - - @source_util.writes_bytecode_files - def verify_bytecode(self, mock, name): - assert name in mock.module_paths - self.assertIn(name, mock.module_bytecode) - magic = mock.module_bytecode[name][:4] - self.assertEqual(magic, imp.get_magic()) - mtime = importlib._r_long(mock.module_bytecode[name][4:8]) - self.assertEqual(mtime, 1) - source_size = mock.module_bytecode[name][8:12] - self.assertEqual(len(mock.source) & 0xFFFFFFFF, - importlib._r_long(source_size)) - bc = mock.module_bytecode[name][12:] - self.assertEqual(bc, mock.compile_bc(name)) - - def test_module(self): - mock, name = super().test_module() - self.verify_bytecode(mock, name) - - def test_package(self): - mock, name = super().test_package() - self.verify_bytecode(mock, name) - - def test_lacking_parent(self): - mock, name = super().test_lacking_parent() - self.verify_bytecode(mock, name) - - def test_module_reuse(self): - mock, name = super().test_module_reuse() - self.verify_bytecode(mock, name) - - def test_state_after_failure(self): - super().test_state_after_failure() - - def test_unloadable(self): - super().test_unloadable() - - -class PyPycLoaderInterfaceTests(unittest.TestCase): - - """Test for the interface of importlib.abc.PyPycLoader.""" - - def get_filename_check(self, src_path, bc_path, expect): - name = 'mod' - mock = PyPycLoaderMock({name: src_path}, {name: {'path': bc_path}}) - with util.uncache(name): - assert mock.source_path(name) == src_path - assert mock.bytecode_path(name) == bc_path - self.assertEqual(mock.get_filename(name), expect) - - def test_filename_with_source_bc(self): - # When source and bytecode paths present, return the source path. - self.get_filename_check('source_path', 'bc_path', 'source_path') - - def test_filename_with_source_no_bc(self): - # With source but no bc, return source path. - self.get_filename_check('source_path', None, 'source_path') - - def test_filename_with_no_source_bc(self): - # With not source but bc, return the bc path. - self.get_filename_check(None, 'bc_path', 'bc_path') - - def test_filename_with_no_source_or_bc(self): - # With no source or bc, raise ImportError. - name = 'mod' - mock = PyPycLoaderMock({name: None}, {name: {'path': None}}) - with util.uncache(name), self.assertRaises(ImportError): - mock.get_filename(name) - - -class SkipWritingBytecodeTests(unittest.TestCase): - - """Test that bytecode is properly handled based on - sys.dont_write_bytecode.""" - - @source_util.writes_bytecode_files - def run_test(self, dont_write_bytecode): - name = 'mod' - mock = PyPycLoaderMock({name: os.path.join('path', 'to', 'mod')}) - sys.dont_write_bytecode = dont_write_bytecode - with util.uncache(name): - mock.load_module(name) - self.assertIsNot(name in mock.module_bytecode, dont_write_bytecode) - - def test_no_bytecode_written(self): - self.run_test(True) - - def test_bytecode_written(self): - self.run_test(False) - - -class RegeneratedBytecodeTests(unittest.TestCase): - - """Test that bytecode is regenerated as expected.""" - - @source_util.writes_bytecode_files - def test_different_magic(self): - # A different magic number should lead to new bytecode. - name = 'mod' - bad_magic = b'\x00\x00\x00\x00' - assert bad_magic != imp.get_magic() - mock = PyPycLoaderMock({name: os.path.join('path', 'to', 'mod')}, - {name: {'path': os.path.join('path', 'to', - 'mod.bytecode'), - 'magic': bad_magic}}) - with util.uncache(name): - mock.load_module(name) - self.assertIn(name, mock.module_bytecode) - magic = mock.module_bytecode[name][:4] - self.assertEqual(magic, imp.get_magic()) - - @source_util.writes_bytecode_files - def test_old_mtime(self): - # Bytecode with an older mtime should be regenerated. - name = 'mod' - old_mtime = PyPycLoaderMock.default_mtime - 1 - mock = PyPycLoaderMock({name: os.path.join('path', 'to', 'mod')}, - {name: {'path': 'path/to/mod.bytecode', 'mtime': old_mtime}}) - with util.uncache(name): - mock.load_module(name) - self.assertIn(name, mock.module_bytecode) - mtime = importlib._r_long(mock.module_bytecode[name][4:8]) - self.assertEqual(mtime, PyPycLoaderMock.default_mtime) - - -class BadBytecodeFailureTests(unittest.TestCase): - - """Test import failures when there is no source and parts of the bytecode - is bad.""" - - def test_bad_magic(self): - # A bad magic number should lead to an ImportError. - name = 'mod' - bad_magic = b'\x00\x00\x00\x00' - bc = {name: - {'path': os.path.join('path', 'to', 'mod'), - 'magic': bad_magic}} - mock = PyPycLoaderMock({name: None}, bc) - with util.uncache(name), self.assertRaises(ImportError) as cm: - mock.load_module(name) - self.assertEqual(cm.exception.name, name) - - def test_no_bytecode(self): - # Missing code object bytecode should lead to an EOFError. - name = 'mod' - bc = {name: {'path': os.path.join('path', 'to', 'mod'), 'bc': b''}} - mock = PyPycLoaderMock({name: None}, bc) - with util.uncache(name), self.assertRaises(EOFError): - mock.load_module(name) - - def test_bad_bytecode(self): - # Malformed code object bytecode should lead to a ValueError. - name = 'mod' - bc = {name: {'path': os.path.join('path', 'to', 'mod'), 'bc': b'1234'}} - mock = PyPycLoaderMock({name: None}, bc) - with util.uncache(name), self.assertRaises(ValueError): - mock.load_module(name) - - -def raise_ImportError(*args, **kwargs): - raise ImportError - -class MissingPathsTests(unittest.TestCase): - - """Test what happens when a source or bytecode path does not exist (either - from *_path returning None or raising ImportError).""" - - def test_source_path_None(self): - # Bytecode should be used when source_path returns None, along with - # __file__ being set to the bytecode path. - name = 'mod' - bytecode_path = 'path/to/mod' - mock = PyPycLoaderMock({name: None}, {name: {'path': bytecode_path}}) - with util.uncache(name): - module = mock.load_module(name) - self.assertEqual(module.__file__, bytecode_path) - - # Testing for bytecode_path returning None handled by all tests where no - # bytecode initially exists. - - def test_all_paths_None(self): - # If all *_path methods return None, raise ImportError. - name = 'mod' - mock = PyPycLoaderMock({name: None}) - with util.uncache(name), self.assertRaises(ImportError) as cm: - mock.load_module(name) - self.assertEqual(cm.exception.name, name) - - def test_source_path_ImportError(self): - # An ImportError from source_path should trigger an ImportError. - name = 'mod' - mock = PyPycLoaderMock({}, {name: {'path': os.path.join('path', 'to', - 'mod')}}) - with util.uncache(name), self.assertRaises(ImportError): - mock.load_module(name) - - def test_bytecode_path_ImportError(self): - # An ImportError from bytecode_path should trigger an ImportError. - name = 'mod' - mock = PyPycLoaderMock({name: os.path.join('path', 'to', 'mod')}) - bad_meth = types.MethodType(raise_ImportError, mock) - mock.bytecode_path = bad_meth - with util.uncache(name), self.assertRaises(ImportError) as cm: - mock.load_module(name) - - -class SourceLoaderTestHarness(unittest.TestCase): - - def setUp(self, *, is_package=True, **kwargs): - self.package = 'pkg' - if is_package: - self.path = os.path.join(self.package, '__init__.py') - self.name = self.package - else: - module_name = 'mod' - self.path = os.path.join(self.package, '.'.join(['mod', 'py'])) - self.name = '.'.join([self.package, module_name]) - self.cached = imp.cache_from_source(self.path) - self.loader = self.loader_mock(self.path, **kwargs) - - def verify_module(self, module): - self.assertEqual(module.__name__, self.name) - self.assertEqual(module.__file__, self.path) - self.assertEqual(module.__cached__, self.cached) - self.assertEqual(module.__package__, self.package) - self.assertEqual(module.__loader__, self.loader) - values = module._.split('::') - self.assertEqual(values[0], self.name) - self.assertEqual(values[1], self.path) - self.assertEqual(values[2], self.cached) - self.assertEqual(values[3], self.package) - self.assertEqual(values[4], repr(self.loader)) - - def verify_code(self, code_object): - module = imp.new_module(self.name) - module.__file__ = self.path - module.__cached__ = self.cached - module.__package__ = self.package - module.__loader__ = self.loader - module.__path__ = [] - exec(code_object, module.__dict__) - self.verify_module(module) - - -class SourceOnlyLoaderTests(SourceLoaderTestHarness): - - """Test importlib.abc.SourceLoader for source-only loading. - - Reload testing is subsumed by the tests for - importlib.util.module_for_loader. - - """ - - loader_mock = SourceOnlyLoaderMock - - def test_get_source(self): - # Verify the source code is returned as a string. - # If an IOError is raised by get_data then raise ImportError. - expected_source = self.loader.source.decode('utf-8') - self.assertEqual(self.loader.get_source(self.name), expected_source) - def raise_IOError(path): - raise IOError - self.loader.get_data = raise_IOError - with self.assertRaises(ImportError) as cm: - self.loader.get_source(self.name) - self.assertEqual(cm.exception.name, self.name) - - def test_is_package(self): - # Properly detect when loading a package. - self.setUp(is_package=False) - self.assertFalse(self.loader.is_package(self.name)) - self.setUp(is_package=True) - self.assertTrue(self.loader.is_package(self.name)) - self.assertFalse(self.loader.is_package(self.name + '.__init__')) - - def test_get_code(self): - # Verify the code object is created. - code_object = self.loader.get_code(self.name) - self.verify_code(code_object) - - def test_load_module(self): - # Loading a module should set __name__, __loader__, __package__, - # __path__ (for packages), __file__, and __cached__. - # The module should also be put into sys.modules. - with util.uncache(self.name): - module = self.loader.load_module(self.name) - self.verify_module(module) - self.assertEqual(module.__path__, [os.path.dirname(self.path)]) - self.assertIn(self.name, sys.modules) - - def test_package_settings(self): - # __package__ needs to be set, while __path__ is set on if the module - # is a package. - # Testing the values for a package are covered by test_load_module. - self.setUp(is_package=False) - with util.uncache(self.name): - module = self.loader.load_module(self.name) - self.verify_module(module) - self.assertTrue(not hasattr(module, '__path__')) - - def test_get_source_encoding(self): - # Source is considered encoded in UTF-8 by default unless otherwise - # specified by an encoding line. - source = "_ = 'ü'" - self.loader.source = source.encode('utf-8') - returned_source = self.loader.get_source(self.name) - self.assertEqual(returned_source, source) - source = "# coding: latin-1\n_ = ü" - self.loader.source = source.encode('latin-1') - returned_source = self.loader.get_source(self.name) - self.assertEqual(returned_source, source) - - -@unittest.skipIf(sys.dont_write_bytecode, "sys.dont_write_bytecode is true") -class SourceLoaderBytecodeTests(SourceLoaderTestHarness): - - """Test importlib.abc.SourceLoader's use of bytecode. - - Source-only testing handled by SourceOnlyLoaderTests. - - """ - - loader_mock = SourceLoaderMock - - def verify_code(self, code_object, *, bytecode_written=False): - super().verify_code(code_object) - if bytecode_written: - self.assertIn(self.cached, self.loader.written) - data = bytearray(imp.get_magic()) - data.extend(importlib._w_long(self.loader.source_mtime)) - data.extend(importlib._w_long(self.loader.source_size)) - data.extend(marshal.dumps(code_object)) - self.assertEqual(self.loader.written[self.cached], bytes(data)) - - def test_code_with_everything(self): - # When everything should work. - code_object = self.loader.get_code(self.name) - self.verify_code(code_object) - - def test_no_bytecode(self): - # If no bytecode exists then move on to the source. - self.loader.bytecode_path = "<does not exist>" - # Sanity check - with self.assertRaises(IOError): - bytecode_path = imp.cache_from_source(self.path) - self.loader.get_data(bytecode_path) - code_object = self.loader.get_code(self.name) - self.verify_code(code_object, bytecode_written=True) - - def test_code_bad_timestamp(self): - # Bytecode is only used when the timestamp matches the source EXACTLY. - for source_mtime in (0, 2): - assert source_mtime != self.loader.source_mtime - original = self.loader.source_mtime - self.loader.source_mtime = source_mtime - # If bytecode is used then EOFError would be raised by marshal. - self.loader.bytecode = self.loader.bytecode[8:] - code_object = self.loader.get_code(self.name) - self.verify_code(code_object, bytecode_written=True) - self.loader.source_mtime = original - - def test_code_bad_magic(self): - # Skip over bytecode with a bad magic number. - self.setUp(magic=b'0000') - # If bytecode is used then EOFError would be raised by marshal. - self.loader.bytecode = self.loader.bytecode[8:] - code_object = self.loader.get_code(self.name) - self.verify_code(code_object, bytecode_written=True) - - def test_dont_write_bytecode(self): - # Bytecode is not written if sys.dont_write_bytecode is true. - # Can assume it is false already thanks to the skipIf class decorator. - try: - sys.dont_write_bytecode = True - self.loader.bytecode_path = "<does not exist>" - code_object = self.loader.get_code(self.name) - self.assertNotIn(self.cached, self.loader.written) - finally: - sys.dont_write_bytecode = False - - def test_no_set_data(self): - # If set_data is not defined, one can still read bytecode. - self.setUp(magic=b'0000') - original_set_data = self.loader.__class__.set_data - try: - del self.loader.__class__.set_data - code_object = self.loader.get_code(self.name) - self.verify_code(code_object) - finally: - self.loader.__class__.set_data = original_set_data - - def test_set_data_raises_exceptions(self): - # Raising NotImplementedError or IOError is okay for set_data. - def raise_exception(exc): - def closure(*args, **kwargs): - raise exc - return closure - - self.setUp(magic=b'0000') - self.loader.set_data = raise_exception(NotImplementedError) - code_object = self.loader.get_code(self.name) - self.verify_code(code_object) - - -class SourceLoaderGetSourceTests(unittest.TestCase): - - """Tests for importlib.abc.SourceLoader.get_source().""" - - def test_default_encoding(self): - # Should have no problems with UTF-8 text. - name = 'mod' - mock = SourceOnlyLoaderMock('mod.file') - source = 'x = "ü"' - mock.source = source.encode('utf-8') - returned_source = mock.get_source(name) - self.assertEqual(returned_source, source) - - def test_decoded_source(self): - # Decoding should work. - name = 'mod' - mock = SourceOnlyLoaderMock("mod.file") - source = "# coding: Latin-1\nx='ü'" - assert source.encode('latin-1') != source.encode('utf-8') - mock.source = source.encode('latin-1') - returned_source = mock.get_source(name) - self.assertEqual(returned_source, source) - - def test_universal_newlines(self): - # PEP 302 says universal newlines should be used. - name = 'mod' - mock = SourceOnlyLoaderMock('mod.file') - source = "x = 42\r\ny = -13\r\n" - mock.source = source.encode('utf-8') - expect = io.IncrementalNewlineDecoder(None, True).decode(source) - self.assertEqual(mock.get_source(name), expect) - - -class AbstractMethodImplTests(unittest.TestCase): - - """Test the concrete abstractmethod implementations.""" - - class MetaPathFinder(abc.MetaPathFinder): - def find_module(self, fullname, path): - super().find_module(fullname, path) - - class PathEntryFinder(abc.PathEntryFinder): - def find_module(self, _): - super().find_module(_) - - def find_loader(self, _): - super().find_loader(_) - - class Finder(abc.Finder): - def find_module(self, fullname, path): - super().find_module(fullname, path) - - class Loader(abc.Loader): - def load_module(self, fullname): - super().load_module(fullname) - def module_repr(self, module): - super().module_repr(module) - - class ResourceLoader(Loader, abc.ResourceLoader): - def get_data(self, _): - super().get_data(_) - - class InspectLoader(Loader, abc.InspectLoader): - def is_package(self, _): - super().is_package(_) - - def get_code(self, _): - super().get_code(_) - - def get_source(self, _): - super().get_source(_) - - class ExecutionLoader(InspectLoader, abc.ExecutionLoader): - def get_filename(self, _): - super().get_filename(_) - - class SourceLoader(ResourceLoader, ExecutionLoader, abc.SourceLoader): - pass - - class PyLoader(ResourceLoader, InspectLoader, abc.PyLoader): - def source_path(self, _): - super().source_path(_) - - class PyPycLoader(PyLoader, abc.PyPycLoader): - def bytecode_path(self, _): - super().bytecode_path(_) - - def source_mtime(self, _): - super().source_mtime(_) - - def write_bytecode(self, _, _2): - super().write_bytecode(_, _2) - - def raises_NotImplementedError(self, ins, *args): - for method_name in args: - method = getattr(ins, method_name) - arg_count = len(inspect.getfullargspec(method)[0]) - 1 - args = [''] * arg_count - try: - method(*args) - except NotImplementedError: - pass - else: - msg = "{}.{} did not raise NotImplementedError" - self.fail(msg.format(ins.__class__.__name__, method_name)) - - def test_Loader(self): - self.raises_NotImplementedError(self.Loader(), 'load_module') - - # XXX misplaced; should be somewhere else - def test_Finder(self): - self.raises_NotImplementedError(self.Finder(), 'find_module') - - def test_ResourceLoader(self): - self.raises_NotImplementedError(self.ResourceLoader(), 'load_module', - 'get_data') - - def test_InspectLoader(self): - self.raises_NotImplementedError(self.InspectLoader(), 'load_module', - 'is_package', 'get_code', 'get_source') - - def test_ExecutionLoader(self): - self.raises_NotImplementedError(self.ExecutionLoader(), 'load_module', - 'is_package', 'get_code', 'get_source', - 'get_filename') - - def test_SourceLoader(self): - ins = self.SourceLoader() - # Required abstractmethods. - self.raises_NotImplementedError(ins, 'get_filename', 'get_data') - # Optional abstractmethods. - self.raises_NotImplementedError(ins,'path_stats', 'set_data') - - def test_PyLoader(self): - self.raises_NotImplementedError(self.PyLoader(), 'source_path', - 'get_data', 'is_package') - - def test_PyPycLoader(self): - self.raises_NotImplementedError(self.PyPycLoader(), 'source_path', - 'source_mtime', 'bytecode_path', - 'write_bytecode') - - -def test_main(): - from test.support import run_unittest - run_unittest(PyLoaderTests, PyLoaderCompatTests, - PyLoaderInterfaceTests, - PyPycLoaderTests, PyPycLoaderInterfaceTests, - SkipWritingBytecodeTests, RegeneratedBytecodeTests, - BadBytecodeFailureTests, MissingPathsTests, - SourceOnlyLoaderTests, - SourceLoaderBytecodeTests, - SourceLoaderGetSourceTests, - AbstractMethodImplTests) - - -if __name__ == '__main__': - test_main() diff --git a/Lib/test/test_importlib/source/test_case_sensitivity.py b/Lib/test/test_importlib/source/test_case_sensitivity.py index 241173fb44..efd3146d90 100644 --- a/Lib/test/test_importlib/source/test_case_sensitivity.py +++ b/Lib/test/test_importlib/source/test_case_sensitivity.py @@ -1,9 +1,10 @@ """Test case-sensitivity (PEP 235).""" -from importlib import _bootstrap -from importlib import machinery from .. import util from . import util as source_util -import imp + +importlib = util.import_importlib('importlib') +machinery = util.import_importlib('importlib.machinery') + import os import sys from test import support as test_support @@ -11,7 +12,7 @@ import unittest @util.case_insensitive_tests -class CaseSensitivityTest(unittest.TestCase): +class CaseSensitivityTest: """PEP 235 dictates that on case-preserving, case-insensitive file systems that imports are case-sensitive unless the PYTHONCASEOK environment @@ -20,13 +21,12 @@ class CaseSensitivityTest(unittest.TestCase): name = 'MoDuLe' assert name != name.lower() - def find(self, path): - finder = machinery.FileFinder(path, - (machinery.SourceFileLoader, - machinery.SOURCE_SUFFIXES), - (machinery.SourcelessFileLoader, - machinery.BYTECODE_SUFFIXES)) - return finder.find_module(self.name) + def finder(self, path): + return self.machinery.FileFinder(path, + (self.machinery.SourceFileLoader, + self.machinery.SOURCE_SUFFIXES), + (self.machinery.SourcelessFileLoader, + self.machinery.BYTECODE_SUFFIXES)) def sensitivity_test(self): """Look for a module with matching and non-matching sensitivity.""" @@ -36,35 +36,48 @@ class CaseSensitivityTest(unittest.TestCase): with context as mapping: sensitive_path = os.path.join(mapping['.root'], 'sensitive') insensitive_path = os.path.join(mapping['.root'], 'insensitive') - return self.find(sensitive_path), self.find(insensitive_path) + sensitive_finder = self.finder(sensitive_path) + insensitive_finder = self.finder(insensitive_path) + return self.find(sensitive_finder), self.find(insensitive_finder) def test_sensitive(self): with test_support.EnvironmentVarGuard() as env: env.unset('PYTHONCASEOK') - if b'PYTHONCASEOK' in _bootstrap._os.environ: + if b'PYTHONCASEOK' in self.importlib._bootstrap._os.environ: self.skipTest('os.environ changes not reflected in ' '_os.environ') sensitive, insensitive = self.sensitivity_test() - self.assertTrue(hasattr(sensitive, 'load_module')) + self.assertIsNotNone(sensitive) self.assertIn(self.name, sensitive.get_filename(self.name)) self.assertIsNone(insensitive) def test_insensitive(self): with test_support.EnvironmentVarGuard() as env: env.set('PYTHONCASEOK', '1') - if b'PYTHONCASEOK' not in _bootstrap._os.environ: + if b'PYTHONCASEOK' not in self.importlib._bootstrap._os.environ: self.skipTest('os.environ changes not reflected in ' '_os.environ') sensitive, insensitive = self.sensitivity_test() - self.assertTrue(hasattr(sensitive, 'load_module')) + self.assertIsNotNone(sensitive) self.assertIn(self.name, sensitive.get_filename(self.name)) - self.assertTrue(hasattr(insensitive, 'load_module')) + self.assertIsNotNone(insensitive) self.assertIn(self.name, insensitive.get_filename(self.name)) +class CaseSensitivityTestPEP302(CaseSensitivityTest): + def find(self, finder): + return finder.find_module(self.name) + +Frozen_CaseSensitivityTestPEP302, Source_CaseSensitivityTestPEP302 = util.test_both( + CaseSensitivityTestPEP302, importlib=importlib, machinery=machinery) + +class CaseSensitivityTestPEP451(CaseSensitivityTest): + def find(self, finder): + found = finder.find_spec(self.name) + return found.loader if found is not None else found -def test_main(): - test_support.run_unittest(CaseSensitivityTest) +Frozen_CaseSensitivityTestPEP451, Source_CaseSensitivityTestPEP451 = util.test_both( + CaseSensitivityTestPEP451, importlib=importlib, machinery=machinery) if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_importlib/source/test_file_loader.py b/Lib/test/test_importlib/source/test_file_loader.py index d59969a8ce..25a3dae672 100644 --- a/Lib/test/test_importlib/source/test_file_loader.py +++ b/Lib/test/test_importlib/source/test_file_loader.py @@ -1,52 +1,52 @@ -from importlib import machinery -import importlib -import importlib.abc from .. import abc from .. import util from . import util as source_util +importlib = util.import_importlib('importlib') +importlib_abc = util.import_importlib('importlib.abc') +machinery = util.import_importlib('importlib.machinery') +importlib_util = util.import_importlib('importlib.util') + import errno -import imp import marshal import os import py_compile import shutil import stat import sys +import types import unittest +import warnings -from test.support import make_legacy_pyc +from test.support import make_legacy_pyc, unload -class SimpleTest(unittest.TestCase): +class SimpleTest(abc.LoaderTests): """Should have no issue importing a source module [basic]. And if there is a syntax error, it should raise a SyntaxError [syntax error]. """ - def test_load_module_API(self): - # If fullname is not specified that assume self.name is desired. - class TesterMixin(importlib.abc.Loader): - def load_module(self, fullname): return fullname - def module_repr(self, module): return '<module>' + def setUp(self): + self.name = 'spam' + self.filepath = os.path.join('ham', self.name + '.py') + self.loader = self.machinery.SourceFileLoader(self.name, self.filepath) - class Tester(importlib.abc.FileLoader, TesterMixin): - def get_code(self, _): pass - def get_source(self, _): pass - def is_package(self, _): pass + def test_load_module_API(self): + class Tester(self.abc.FileLoader): + def get_source(self, _): return 'attr = 42' + def is_package(self, _): return False - name = 'mod_name' - loader = Tester(name, 'some_path') - self.assertEqual(name, loader.load_module()) - self.assertEqual(name, loader.load_module(None)) - self.assertEqual(name, loader.load_module(name)) - with self.assertRaises(ImportError): - loader.load_module(loader.name + 'XXX') + loader = Tester('blah', 'blah.py') + self.addCleanup(unload, 'blah') + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + module = loader.load_module() # Should not raise an exception. def test_get_filename_API(self): # If fullname is not set then assume self.path is desired. - class Tester(importlib.abc.FileLoader): + class Tester(self.abc.FileLoader): def get_code(self, _): pass def get_source(self, _): pass def is_package(self, _): pass @@ -61,11 +61,21 @@ class SimpleTest(unittest.TestCase): with self.assertRaises(ImportError): loader.get_filename(name + 'XXX') + def test_equality(self): + other = self.machinery.SourceFileLoader(self.name, self.filepath) + self.assertEqual(self.loader, other) + + def test_inequality(self): + other = self.machinery.SourceFileLoader('_' + self.name, self.filepath) + self.assertNotEqual(self.loader, other) + # [basic] def test_module(self): with source_util.create_modules('_temp') as mapping: - loader = machinery.SourceFileLoader('_temp', mapping['_temp']) - module = loader.load_module('_temp') + loader = self.machinery.SourceFileLoader('_temp', mapping['_temp']) + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + module = loader.load_module('_temp') self.assertIn('_temp', sys.modules) check = {'__name__': '_temp', '__file__': mapping['_temp'], '__package__': ''} @@ -74,9 +84,11 @@ class SimpleTest(unittest.TestCase): def test_package(self): with source_util.create_modules('_pkg.__init__') as mapping: - loader = machinery.SourceFileLoader('_pkg', + loader = self.machinery.SourceFileLoader('_pkg', mapping['_pkg.__init__']) - module = loader.load_module('_pkg') + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + module = loader.load_module('_pkg') self.assertIn('_pkg', sys.modules) check = {'__name__': '_pkg', '__file__': mapping['_pkg.__init__'], '__path__': [os.path.dirname(mapping['_pkg.__init__'])], @@ -87,9 +99,11 @@ class SimpleTest(unittest.TestCase): def test_lacking_parent(self): with source_util.create_modules('_pkg.__init__', '_pkg.mod')as mapping: - loader = machinery.SourceFileLoader('_pkg.mod', + loader = self.machinery.SourceFileLoader('_pkg.mod', mapping['_pkg.mod']) - module = loader.load_module('_pkg.mod') + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + module = loader.load_module('_pkg.mod') self.assertIn('_pkg.mod', sys.modules) check = {'__name__': '_pkg.mod', '__file__': mapping['_pkg.mod'], '__package__': '_pkg'} @@ -102,13 +116,17 @@ class SimpleTest(unittest.TestCase): def test_module_reuse(self): with source_util.create_modules('_temp') as mapping: - loader = machinery.SourceFileLoader('_temp', mapping['_temp']) - module = loader.load_module('_temp') + loader = self.machinery.SourceFileLoader('_temp', mapping['_temp']) + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + module = loader.load_module('_temp') module_id = id(module) module_dict_id = id(module.__dict__) with open(mapping['_temp'], 'w') as file: file.write("testing_var = 42\n") - module = loader.load_module('_temp') + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + module = loader.load_module('_temp') self.assertIn('testing_var', module.__dict__, "'testing_var' not in " "{0}".format(list(module.__dict__.keys()))) @@ -122,14 +140,20 @@ class SimpleTest(unittest.TestCase): value = '<test>' name = '_temp' with source_util.create_modules(name) as mapping: - orig_module = imp.new_module(name) + orig_module = types.ModuleType(name) for attr in attributes: setattr(orig_module, attr, value) with open(mapping[name], 'w') as file: file.write('+++ bad syntax +++') - loader = machinery.SourceFileLoader('_temp', mapping['_temp']) + loader = self.machinery.SourceFileLoader('_temp', mapping['_temp']) + with self.assertRaises(SyntaxError): + loader.exec_module(orig_module) + for attr in attributes: + self.assertEqual(getattr(orig_module, attr), value) with self.assertRaises(SyntaxError): - loader.load_module(name) + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + loader.load_module(name) for attr in attributes: self.assertEqual(getattr(orig_module, attr), value) @@ -138,9 +162,11 @@ class SimpleTest(unittest.TestCase): with source_util.create_modules('_temp') as mapping: with open(mapping['_temp'], 'w') as file: file.write('=') - loader = machinery.SourceFileLoader('_temp', mapping['_temp']) + loader = self.machinery.SourceFileLoader('_temp', mapping['_temp']) with self.assertRaises(SyntaxError): - loader.load_module('_temp') + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + loader.load_module('_temp') self.assertNotIn('_temp', sys.modules) def test_file_from_empty_string_dir(self): @@ -151,14 +177,16 @@ class SimpleTest(unittest.TestCase): file.write("# test file for importlib") try: with util.uncache('_temp'): - loader = machinery.SourceFileLoader('_temp', file_path) - mod = loader.load_module('_temp') + loader = self.machinery.SourceFileLoader('_temp', file_path) + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + mod = loader.load_module('_temp') self.assertEqual(file_path, mod.__file__) - self.assertEqual(imp.cache_from_source(file_path), + self.assertEqual(self.util.cache_from_source(file_path), mod.__cached__) finally: os.unlink(file_path) - pycache = os.path.dirname(imp.cache_from_source(file_path)) + pycache = os.path.dirname(self.util.cache_from_source(file_path)) if os.path.exists(pycache): shutil.rmtree(pycache) @@ -167,7 +195,7 @@ class SimpleTest(unittest.TestCase): # truncated rather than raise an OverflowError. with source_util.create_modules('_temp') as mapping: source = mapping['_temp'] - compiled = imp.cache_from_source(source) + compiled = self.util.cache_from_source(source) with open(source, 'w') as f: f.write("x = 5") try: @@ -178,20 +206,48 @@ class SimpleTest(unittest.TestCase): if e.errno != getattr(errno, 'EOVERFLOW', None): raise self.skipTest("cannot set modification time to large integer ({})".format(e)) - loader = machinery.SourceFileLoader('_temp', mapping['_temp']) - mod = loader.load_module('_temp') + loader = self.machinery.SourceFileLoader('_temp', mapping['_temp']) + # PEP 451 + module = types.ModuleType('_temp') + module.__spec__ = self.util.spec_from_loader('_temp', loader) + loader.exec_module(module) + self.assertEqual(module.x, 5) + self.assertTrue(os.path.exists(compiled)) + os.unlink(compiled) + # PEP 302 + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + mod = loader.load_module('_temp') # XXX # Sanity checks. self.assertEqual(mod.__cached__, compiled) self.assertEqual(mod.x, 5) # The pyc file was created. - os.stat(compiled) + self.assertTrue(os.path.exists(compiled)) + + def test_unloadable(self): + loader = self.machinery.SourceFileLoader('good name', {}) + module = types.ModuleType('bad name') + module.__spec__ = self.machinery.ModuleSpec('bad name', loader) + with self.assertRaises(ImportError): + loader.exec_module(module) + with self.assertRaises(ImportError): + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + loader.load_module('bad name') + +Frozen_SimpleTest, Source_SimpleTest = util.test_both( + SimpleTest, importlib=importlib, machinery=machinery, abc=importlib_abc, + util=importlib_util) -class BadBytecodeTest(unittest.TestCase): +class BadBytecodeTest: def import_(self, file, module_name): loader = self.loader(module_name, file) - module = loader.load_module(module_name) + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + # XXX Change to use exec_module(). + module = loader.load_module(module_name) self.assertIn(module_name, sys.modules) def manipulate_bytecode(self, name, mapping, manipulator, *, @@ -204,7 +260,7 @@ class BadBytecodeTest(unittest.TestCase): pass py_compile.compile(mapping[name]) if not del_source: - bytecode_path = imp.cache_from_source(mapping[name]) + bytecode_path = self.util.cache_from_source(mapping[name]) else: os.unlink(mapping[name]) bytecode_path = make_legacy_pyc(mapping[name]) @@ -290,10 +346,29 @@ class BadBytecodeTest(unittest.TestCase): lambda bc: b'\x00\x00\x00\x00' + bc[4:]) test('_temp', mapping, bc_path) +class BadBytecodeTestPEP451(BadBytecodeTest): + + def import_(self, file, module_name): + loader = self.loader(module_name, file) + module = types.ModuleType(module_name) + module.__spec__ = self.util.spec_from_loader(module_name, loader) + loader.exec_module(module) + +class BadBytecodeTestPEP302(BadBytecodeTest): + + def import_(self, file, module_name): + loader = self.loader(module_name, file) + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + module = loader.load_module(module_name) + self.assertIn(module_name, sys.modules) -class SourceLoaderBadBytecodeTest(BadBytecodeTest): - loader = machinery.SourceFileLoader +class SourceLoaderBadBytecodeTest: + + @classmethod + def setUpClass(cls): + cls.loader = cls.machinery.SourceFileLoader @source_util.writes_bytecode_files def test_empty_file(self): @@ -332,7 +407,8 @@ class SourceLoaderBadBytecodeTest(BadBytecodeTest): def test(name, mapping, bytecode_path): self.import_(mapping[name], name) with open(bytecode_path, 'rb') as bytecode_file: - self.assertEqual(bytecode_file.read(4), imp.get_magic()) + self.assertEqual(bytecode_file.read(4), + self.util.MAGIC_NUMBER) self._test_bad_magic(test) @@ -382,13 +458,13 @@ class SourceLoaderBadBytecodeTest(BadBytecodeTest): zeros = b'\x00\x00\x00\x00' with source_util.create_modules('_temp') as mapping: py_compile.compile(mapping['_temp']) - bytecode_path = imp.cache_from_source(mapping['_temp']) + bytecode_path = self.util.cache_from_source(mapping['_temp']) with open(bytecode_path, 'r+b') as bytecode_file: bytecode_file.seek(4) bytecode_file.write(zeros) self.import_(mapping['_temp'], '_temp') source_mtime = os.path.getmtime(mapping['_temp']) - source_timestamp = importlib._w_long(source_mtime) + source_timestamp = self.importlib._w_long(source_mtime) with open(bytecode_path, 'rb') as bytecode_file: bytecode_file.seek(4) self.assertEqual(bytecode_file.read(4), source_timestamp) @@ -400,7 +476,7 @@ class SourceLoaderBadBytecodeTest(BadBytecodeTest): with source_util.create_modules('_temp') as mapping: # Create bytecode that will need to be re-created. py_compile.compile(mapping['_temp']) - bytecode_path = imp.cache_from_source(mapping['_temp']) + bytecode_path = self.util.cache_from_source(mapping['_temp']) with open(bytecode_path, 'r+b') as bytecode_file: bytecode_file.seek(0) bytecode_file.write(b'\x00\x00\x00\x00') @@ -408,16 +484,34 @@ class SourceLoaderBadBytecodeTest(BadBytecodeTest): os.chmod(bytecode_path, stat.S_IRUSR | stat.S_IRGRP | stat.S_IROTH) try: - # Should not raise IOError! + # Should not raise OSError! self.import_(mapping['_temp'], '_temp') finally: # Make writable for eventual clean-up. os.chmod(bytecode_path, stat.S_IWUSR) +class SourceLoaderBadBytecodeTestPEP451( + SourceLoaderBadBytecodeTest, BadBytecodeTestPEP451): + pass + +Frozen_SourceBadBytecodePEP451, Source_SourceBadBytecodePEP451 = util.test_both( + SourceLoaderBadBytecodeTestPEP451, importlib=importlib, machinery=machinery, + abc=importlib_abc, util=importlib_util) + +class SourceLoaderBadBytecodeTestPEP302( + SourceLoaderBadBytecodeTest, BadBytecodeTestPEP302): + pass + +Frozen_SourceBadBytecodePEP302, Source_SourceBadBytecodePEP302 = util.test_both( + SourceLoaderBadBytecodeTestPEP302, importlib=importlib, machinery=machinery, + abc=importlib_abc, util=importlib_util) -class SourcelessLoaderBadBytecodeTest(BadBytecodeTest): - loader = machinery.SourcelessFileLoader +class SourcelessLoaderBadBytecodeTest: + + @classmethod + def setUpClass(cls): + cls.loader = cls.machinery.SourcelessFileLoader def test_empty_file(self): def test(name, mapping, bytecode_path): @@ -472,14 +566,22 @@ class SourcelessLoaderBadBytecodeTest(BadBytecodeTest): def test_non_code_marshal(self): self._test_non_code_marshal(del_source=True) +class SourcelessLoaderBadBytecodeTestPEP451(SourcelessLoaderBadBytecodeTest, + BadBytecodeTestPEP451): + pass + +Frozen_SourcelessBadBytecodePEP451, Source_SourcelessBadBytecodePEP451 = util.test_both( + SourcelessLoaderBadBytecodeTestPEP451, importlib=importlib, + machinery=machinery, abc=importlib_abc, util=importlib_util) + +class SourcelessLoaderBadBytecodeTestPEP302(SourcelessLoaderBadBytecodeTest, + BadBytecodeTestPEP302): + pass -def test_main(): - from test.support import run_unittest - run_unittest(SimpleTest, - SourceLoaderBadBytecodeTest, - SourcelessLoaderBadBytecodeTest - ) +Frozen_SourcelessBadBytecodePEP302, Source_SourcelessBadBytecodePEP302 = util.test_both( + SourcelessLoaderBadBytecodeTestPEP302, importlib=importlib, + machinery=machinery, abc=importlib_abc, util=importlib_util) if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_importlib/source/test_finder.py b/Lib/test/test_importlib/source/test_finder.py index 8e4986835d..473297bdc8 100644 --- a/Lib/test/test_importlib/source/test_finder.py +++ b/Lib/test/test_importlib/source/test_finder.py @@ -1,9 +1,10 @@ from .. import abc +from .. import util from . import util as source_util -from importlib import machinery +machinery = util.import_importlib('importlib.machinery') + import errno -import imp import os import py_compile import stat @@ -39,14 +40,15 @@ class FinderTests(abc.FinderTests): """ def get_finder(self, root): - loader_details = [(machinery.SourceFileLoader, - machinery.SOURCE_SUFFIXES), - (machinery.SourcelessFileLoader, - machinery.BYTECODE_SUFFIXES)] - return machinery.FileFinder(root, *loader_details) + loader_details = [(self.machinery.SourceFileLoader, + self.machinery.SOURCE_SUFFIXES), + (self.machinery.SourcelessFileLoader, + self.machinery.BYTECODE_SUFFIXES)] + return self.machinery.FileFinder(root, *loader_details) def import_(self, root, module): - return self.get_finder(root).find_module(module) + finder = self.get_finder(root) + return self._find(finder, module, loader_only=True) def run_test(self, test, create=None, *, compile_=None, unlink=None): """Test the finding of 'test' with the creation of modules listed in @@ -124,20 +126,20 @@ class FinderTests(abc.FinderTests): def test_empty_string_for_dir(self): # The empty string from sys.path means to search in the cwd. - finder = machinery.FileFinder('', (machinery.SourceFileLoader, - machinery.SOURCE_SUFFIXES)) + finder = self.machinery.FileFinder('', (self.machinery.SourceFileLoader, + self.machinery.SOURCE_SUFFIXES)) with open('mod.py', 'w') as file: file.write("# test file for importlib") try: - loader = finder.find_module('mod') + loader = self._find(finder, 'mod', loader_only=True) self.assertTrue(hasattr(loader, 'load_module')) finally: os.unlink('mod.py') def test_invalidate_caches(self): # invalidate_caches() should reset the mtime. - finder = machinery.FileFinder('', (machinery.SourceFileLoader, - machinery.SOURCE_SUFFIXES)) + finder = self.machinery.FileFinder('', (self.machinery.SourceFileLoader, + self.machinery.SOURCE_SUFFIXES)) finder._path_mtime = 42 finder.invalidate_caches() self.assertEqual(finder._path_mtime, -1) @@ -147,8 +149,10 @@ class FinderTests(abc.FinderTests): mod = 'mod' with source_util.create_modules(mod) as mapping: finder = self.get_finder(mapping['.root']) - self.assertIsNotNone(finder.find_module(mod)) - self.assertIsNone(finder.find_module(mod)) + found = self._find(finder, 'mod', loader_only=True) + self.assertIsNotNone(found) + found = self._find(finder, 'mod', loader_only=True) + self.assertIsNone(found) @unittest.skipUnless(sys.platform != 'win32', 'os.chmod() does not support the needed arguments under Windows') @@ -172,20 +176,57 @@ class FinderTests(abc.FinderTests): self.addCleanup(cleanup, tempdir) os.chmod(tempdir.name, stat.S_IWUSR | stat.S_IXUSR) finder = self.get_finder(tempdir.name) - self.assertEqual((None, []), finder.find_loader('doesnotexist')) + found = self._find(finder, 'doesnotexist') + self.assertEqual(found, self.NOT_FOUND) def test_ignore_file(self): # If a directory got changed to a file from underneath us, then don't # worry about looking for submodules. with tempfile.NamedTemporaryFile() as file_obj: finder = self.get_finder(file_obj.name) - self.assertEqual((None, []), finder.find_loader('doesnotexist')) + found = self._find(finder, 'doesnotexist') + self.assertEqual(found, self.NOT_FOUND) + + +class FinderTestsPEP451(FinderTests): + + NOT_FOUND = None + + def _find(self, finder, name, loader_only=False): + spec = finder.find_spec(name) + return spec.loader if spec is not None else spec + +Frozen_FinderTestsPEP451, Source_FinderTestsPEP451 = util.test_both( + FinderTestsPEP451, machinery=machinery) + + +class FinderTestsPEP420(FinderTests): + + NOT_FOUND = (None, []) + + def _find(self, finder, name, loader_only=False): + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + loader_portions = finder.find_loader(name) + return loader_portions[0] if loader_only else loader_portions + +Frozen_FinderTestsPEP420, Source_FinderTestsPEP420 = util.test_both( + FinderTestsPEP420, machinery=machinery) + + +class FinderTestsPEP302(FinderTests): + + NOT_FOUND = None + + def _find(self, finder, name, loader_only=False): + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + return finder.find_module(name) +Frozen_FinderTestsPEP302, Source_FinderTestsPEP302 = util.test_both( + FinderTestsPEP302, machinery=machinery) -def test_main(): - from test.support import run_unittest - run_unittest(FinderTests) if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_importlib/source/test_path_hook.py b/Lib/test/test_importlib/source/test_path_hook.py index 6a78792f07..92da77265b 100644 --- a/Lib/test/test_importlib/source/test_path_hook.py +++ b/Lib/test/test_importlib/source/test_path_hook.py @@ -1,17 +1,18 @@ +from .. import util from . import util as source_util -from importlib import machinery -import imp +machinery = util.import_importlib('importlib.machinery') + import unittest -class PathHookTest(unittest.TestCase): +class PathHookTest: """Test the path hook for source.""" def path_hook(self): - return machinery.FileFinder.path_hook((machinery.SourceFileLoader, - machinery.SOURCE_SUFFIXES)) + return self.machinery.FileFinder.path_hook((self.machinery.SourceFileLoader, + self.machinery.SOURCE_SUFFIXES)) def test_success(self): with source_util.create_modules('dummy') as mapping: @@ -22,11 +23,8 @@ class PathHookTest(unittest.TestCase): # The empty string represents the cwd. self.assertTrue(hasattr(self.path_hook()(''), 'find_module')) - -def test_main(): - from test.support import run_unittest - run_unittest(PathHookTest) +Frozen_PathHookTest, Source_PathHooktest = util.test_both(PathHookTest, machinery=machinery) if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_importlib/source/test_source_encoding.py b/Lib/test/test_importlib/source/test_source_encoding.py index ba02b44274..c62dfa108d 100644 --- a/Lib/test/test_importlib/source/test_source_encoding.py +++ b/Lib/test/test_importlib/source/test_source_encoding.py @@ -1,19 +1,24 @@ +from .. import util from . import util as source_util -from importlib import _bootstrap +machinery = util.import_importlib('importlib.machinery') + import codecs +import importlib.util import re import sys +import types # Because sys.path gets essentially blanked, need to have unicodedata already # imported for the parser to use. import unicodedata import unittest +import warnings CODING_RE = re.compile(r'^[ \t\f]*#.*coding[:=][ \t]*([-\w.]+)', re.ASCII) -class EncodingTest(unittest.TestCase): +class EncodingTest: """PEP 3120 makes UTF-8 the default encoding for source code [default encoding]. @@ -35,9 +40,9 @@ class EncodingTest(unittest.TestCase): with source_util.create_modules(self.module_name) as mapping: with open(mapping[self.module_name], 'wb') as file: file.write(source) - loader = _bootstrap.SourceFileLoader(self.module_name, + loader = self.machinery.SourceFileLoader(self.module_name, mapping[self.module_name]) - return loader.load_module(self.module_name) + return self.load(loader) def create_source(self, encoding): encoding_line = "# coding={0}".format(encoding) @@ -84,8 +89,29 @@ class EncodingTest(unittest.TestCase): with self.assertRaises(SyntaxError): self.run_test(source) +class EncodingTestPEP451(EncodingTest): + + def load(self, loader): + module = types.ModuleType(self.module_name) + module.__spec__ = importlib.util.spec_from_loader(self.module_name, loader) + loader.exec_module(module) + return module + +Frozen_EncodingTestPEP451, Source_EncodingTestPEP451 = util.test_both( + EncodingTestPEP451, machinery=machinery) + +class EncodingTestPEP302(EncodingTest): + + def load(self, loader): + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + return loader.load_module(self.module_name) + +Frozen_EncodingTestPEP302, Source_EncodingTestPEP302 = util.test_both( + EncodingTestPEP302, machinery=machinery) -class LineEndingTest(unittest.TestCase): + +class LineEndingTest: r"""Source written with the three types of line endings (\n, \r\n, \r) need to be readable [cr][crlf][lf].""" @@ -97,9 +123,9 @@ class LineEndingTest(unittest.TestCase): with source_util.create_modules(module_name) as mapping: with open(mapping[module_name], 'wb') as file: file.write(source) - loader = _bootstrap.SourceFileLoader(module_name, - mapping[module_name]) - return loader.load_module(module_name) + loader = self.machinery.SourceFileLoader(module_name, + mapping[module_name]) + return self.load(loader, module_name) # [cr] def test_cr(self): @@ -113,11 +139,27 @@ class LineEndingTest(unittest.TestCase): def test_lf(self): self.run_test(b'\n') +class LineEndingTestPEP451(LineEndingTest): + + def load(self, loader, module_name): + module = types.ModuleType(module_name) + module.__spec__ = importlib.util.spec_from_loader(module_name, loader) + loader.exec_module(module) + return module + +Frozen_LineEndingTestPEP451, Source_LineEndingTestPEP451 = util.test_both( + LineEndingTestPEP451, machinery=machinery) + +class LineEndingTestPEP302(LineEndingTest): + + def load(self, loader, module_name): + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + return loader.load_module(module_name) -def test_main(): - from test.support import run_unittest - run_unittest(EncodingTest, LineEndingTest) +Frozen_LineEndingTestPEP302, Source_LineEndingTestPEP302 = util.test_both( + LineEndingTestPEP302, machinery=machinery) if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_importlib/source/util.py b/Lib/test/test_importlib/source/util.py index ae65663a67..63cd25adc4 100644 --- a/Lib/test/test_importlib/source/util.py +++ b/Lib/test/test_importlib/source/util.py @@ -2,7 +2,6 @@ from .. import util import contextlib import errno import functools -import imp import os import os.path import sys diff --git a/Lib/test/test_importlib/test_abc.py b/Lib/test/test_importlib/test_abc.py index c620c3771b..09b72949ca 100644 --- a/Lib/test/test_importlib/test_abc.py +++ b/Lib/test/test_importlib/test_abc.py @@ -1,9 +1,23 @@ -from importlib import abc -from importlib import machinery +import contextlib import inspect +import io +import marshal +import os +import sys +from test import support +import types import unittest +from unittest import mock +import warnings +from . import util +frozen_init, source_init = util.import_importlib('importlib') +frozen_abc, source_abc = util.import_importlib('importlib.abc') +machinery = util.import_importlib('importlib.machinery') +frozen_util, source_util = util.import_importlib('importlib.util') + +##### Inheritance ############################################################## class InheritanceTests: """Test that the specified class is a subclass/superclass of the expected @@ -14,8 +28,19 @@ class InheritanceTests: def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self.superclasses = [getattr(self.abc, class_name) + for class_name in self.superclass_names] + if hasattr(self, 'subclass_names'): + # Because test.support.import_fresh_module() creates a new + # importlib._bootstrap per module, inheritance checks fail when + # checking across module boundaries (i.e. the _bootstrap in abc is + # not the same as the one in machinery). That means stealing one of + # the modules from the other to make sure the same instance is used. + self.subclasses = [getattr(self.abc.machinery, class_name) + for class_name in self.subclass_names] assert self.subclasses or self.superclasses, self.__class__ - self.__test = getattr(abc, self.__class__.__name__) + testing = self.__class__.__name__.partition('_')[2] + self.__test = getattr(self.abc, testing) def test_subclasses(self): # Test that the expected subclasses inherit. @@ -29,75 +54,898 @@ class InheritanceTests: self.assertTrue(issubclass(self.__test, superclass), "{0} is not a superclass of {1}".format(superclass, self.__test)) +def create_inheritance_tests(base_class): + def set_frozen(ns): + ns['abc'] = frozen_abc + def set_source(ns): + ns['abc'] = source_abc -class MetaPathFinder(InheritanceTests, unittest.TestCase): + classes = [] + for prefix, ns_set in [('Frozen', set_frozen), ('Source', set_source)]: + classes.append(types.new_class('_'.join([prefix, base_class.__name__]), + (base_class, unittest.TestCase), + exec_body=ns_set)) + return classes - superclasses = [abc.Finder] - subclasses = [machinery.BuiltinImporter, machinery.FrozenImporter, - machinery.PathFinder, machinery.WindowsRegistryFinder] +class MetaPathFinder(InheritanceTests): + superclass_names = ['Finder'] + subclass_names = ['BuiltinImporter', 'FrozenImporter', 'PathFinder', + 'WindowsRegistryFinder'] -class PathEntryFinder(InheritanceTests, unittest.TestCase): +tests = create_inheritance_tests(MetaPathFinder) +Frozen_MetaPathFinderInheritanceTests, Source_MetaPathFinderInheritanceTests = tests - superclasses = [abc.Finder] - subclasses = [machinery.FileFinder] +class PathEntryFinder(InheritanceTests): + superclass_names = ['Finder'] + subclass_names = ['FileFinder'] -class Loader(InheritanceTests, unittest.TestCase): +tests = create_inheritance_tests(PathEntryFinder) +Frozen_PathEntryFinderInheritanceTests, Source_PathEntryFinderInheritanceTests = tests - subclasses = [abc.PyLoader] +class ResourceLoader(InheritanceTests): + superclass_names = ['Loader'] -class ResourceLoader(InheritanceTests, unittest.TestCase): +tests = create_inheritance_tests(ResourceLoader) +Frozen_ResourceLoaderInheritanceTests, Source_ResourceLoaderInheritanceTests = tests - superclasses = [abc.Loader] +class InspectLoader(InheritanceTests): + superclass_names = ['Loader'] + subclass_names = ['BuiltinImporter', 'FrozenImporter', 'ExtensionFileLoader'] -class InspectLoader(InheritanceTests, unittest.TestCase): +tests = create_inheritance_tests(InspectLoader) +Frozen_InspectLoaderInheritanceTests, Source_InspectLoaderInheritanceTests = tests - superclasses = [abc.Loader] - subclasses = [abc.PyLoader, machinery.BuiltinImporter, - machinery.FrozenImporter, machinery.ExtensionFileLoader] +class ExecutionLoader(InheritanceTests): + superclass_names = ['InspectLoader'] + subclass_names = ['ExtensionFileLoader'] -class ExecutionLoader(InheritanceTests, unittest.TestCase): +tests = create_inheritance_tests(ExecutionLoader) +Frozen_ExecutionLoaderInheritanceTests, Source_ExecutionLoaderInheritanceTests = tests - superclasses = [abc.InspectLoader] - subclasses = [abc.PyLoader] +class FileLoader(InheritanceTests): + superclass_names = ['ResourceLoader', 'ExecutionLoader'] + subclass_names = ['SourceFileLoader', 'SourcelessFileLoader'] -class FileLoader(InheritanceTests, unittest.TestCase): +tests = create_inheritance_tests(FileLoader) +Frozen_FileLoaderInheritanceTests, Source_FileLoaderInheritanceTests = tests - superclasses = [abc.ResourceLoader, abc.ExecutionLoader] - subclasses = [machinery.SourceFileLoader, machinery.SourcelessFileLoader] +class SourceLoader(InheritanceTests): + superclass_names = ['ResourceLoader', 'ExecutionLoader'] + subclass_names = ['SourceFileLoader'] -class SourceLoader(InheritanceTests, unittest.TestCase): +tests = create_inheritance_tests(SourceLoader) +Frozen_SourceLoaderInheritanceTests, Source_SourceLoaderInheritanceTests = tests - superclasses = [abc.ResourceLoader, abc.ExecutionLoader] - subclasses = [machinery.SourceFileLoader] +##### Default return values #################################################### +def make_abc_subclasses(base_class): + classes = [] + for kind, abc in [('Frozen', frozen_abc), ('Source', source_abc)]: + name = '_'.join([kind, base_class.__name__]) + base_classes = base_class, getattr(abc, base_class.__name__) + classes.append(types.new_class(name, base_classes)) + return classes +def make_return_value_tests(base_class, test_class): + frozen_class, source_class = make_abc_subclasses(base_class) + tests = [] + for prefix, class_in_test in [('Frozen', frozen_class), ('Source', source_class)]: + def set_ns(ns): + ns['ins'] = class_in_test() + tests.append(types.new_class('_'.join([prefix, test_class.__name__]), + (test_class, unittest.TestCase), + exec_body=set_ns)) + return tests -class PyLoader(InheritanceTests, unittest.TestCase): - superclasses = [abc.Loader, abc.ResourceLoader, abc.ExecutionLoader] +class MetaPathFinder: + def find_module(self, fullname, path): + return super().find_module(fullname, path) -class PyPycLoader(InheritanceTests, unittest.TestCase): +Frozen_MPF, Source_MPF = make_abc_subclasses(MetaPathFinder) - superclasses = [abc.PyLoader] +class MetaPathFinderDefaultsTests: -def test_main(): - from test.support import run_unittest - classes = [] - for class_ in globals().values(): - if (inspect.isclass(class_) and - issubclass(class_, unittest.TestCase) and - issubclass(class_, InheritanceTests)): - classes.append(class_) - run_unittest(*classes) + def test_find_module(self): + # Default should return None. + self.assertIsNone(self.ins.find_module('something', None)) + + def test_invalidate_caches(self): + # Calling the method is a no-op. + self.ins.invalidate_caches() + + +tests = make_return_value_tests(MetaPathFinder, MetaPathFinderDefaultsTests) +Frozen_MPFDefaultTests, Source_MPFDefaultTests = tests + + +class PathEntryFinder: + + def find_loader(self, fullname): + return super().find_loader(fullname) + +Frozen_PEF, Source_PEF = make_abc_subclasses(PathEntryFinder) + + +class PathEntryFinderDefaultsTests: + + def test_find_loader(self): + self.assertEqual((None, []), self.ins.find_loader('something')) + + def find_module(self): + self.assertEqual(None, self.ins.find_module('something')) + + def test_invalidate_caches(self): + # Should be a no-op. + self.ins.invalidate_caches() + + +tests = make_return_value_tests(PathEntryFinder, PathEntryFinderDefaultsTests) +Frozen_PEFDefaultTests, Source_PEFDefaultTests = tests + + +class Loader: + + def load_module(self, fullname): + return super().load_module(fullname) + + +Frozen_L, Source_L = make_abc_subclasses(Loader) + + +class LoaderDefaultsTests: + + def test_load_module(self): + with self.assertRaises(ImportError): + self.ins.load_module('something') + + def test_module_repr(self): + mod = types.ModuleType('blah') + with self.assertRaises(NotImplementedError): + self.ins.module_repr(mod) + original_repr = repr(mod) + mod.__loader__ = self.ins + # Should still return a proper repr. + self.assertTrue(repr(mod)) + + +tests = make_return_value_tests(Loader, LoaderDefaultsTests) +Frozen_LDefaultTests, SourceLDefaultTests = tests + + +class ResourceLoader(Loader): + + def get_data(self, path): + return super().get_data(path) + + +Frozen_RL, Source_RL = make_abc_subclasses(ResourceLoader) + + +class ResourceLoaderDefaultsTests: + + def test_get_data(self): + with self.assertRaises(IOError): + self.ins.get_data('/some/path') + + +tests = make_return_value_tests(ResourceLoader, ResourceLoaderDefaultsTests) +Frozen_RLDefaultTests, Source_RLDefaultTests = tests + + +class InspectLoader(Loader): + + def is_package(self, fullname): + return super().is_package(fullname) + + def get_source(self, fullname): + return super().get_source(fullname) + + +Frozen_IL, Source_IL = make_abc_subclasses(InspectLoader) + + +class InspectLoaderDefaultsTests: + + def test_is_package(self): + with self.assertRaises(ImportError): + self.ins.is_package('blah') + + def test_get_source(self): + with self.assertRaises(ImportError): + self.ins.get_source('blah') + + +tests = make_return_value_tests(InspectLoader, InspectLoaderDefaultsTests) +Frozen_ILDefaultTests, Source_ILDefaultTests = tests + + +class ExecutionLoader(InspectLoader): + + def get_filename(self, fullname): + return super().get_filename(fullname) + +Frozen_EL, Source_EL = make_abc_subclasses(ExecutionLoader) + + +class ExecutionLoaderDefaultsTests: + + def test_get_filename(self): + with self.assertRaises(ImportError): + self.ins.get_filename('blah') + + +tests = make_return_value_tests(ExecutionLoader, InspectLoaderDefaultsTests) +Frozen_ELDefaultTests, Source_ELDefaultsTests = tests + +##### MetaPathFinder concrete methods ########################################## + +class MetaPathFinderFindModuleTests: + + @classmethod + def finder(cls, spec): + class MetaPathSpecFinder(cls.abc.MetaPathFinder): + + def find_spec(self, fullname, path, target=None): + self.called_for = fullname, path + return spec + + return MetaPathSpecFinder() + + def test_no_spec(self): + finder = self.finder(None) + path = ['a', 'b', 'c'] + name = 'blah' + found = finder.find_module(name, path) + self.assertIsNone(found) + self.assertEqual(name, finder.called_for[0]) + self.assertEqual(path, finder.called_for[1]) + + def test_spec(self): + loader = object() + spec = self.util.spec_from_loader('blah', loader) + finder = self.finder(spec) + found = finder.find_module('blah', None) + self.assertIs(found, spec.loader) + + +Frozen_MPFFindModuleTests, Source_MPFFindModuleTests = util.test_both( + MetaPathFinderFindModuleTests, + abc=(frozen_abc, source_abc), + util=(frozen_util, source_util)) + +##### PathEntryFinder concrete methods ######################################### + +class PathEntryFinderFindLoaderTests: + + @classmethod + def finder(cls, spec): + class PathEntrySpecFinder(cls.abc.PathEntryFinder): + + def find_spec(self, fullname, target=None): + self.called_for = fullname + return spec + + return PathEntrySpecFinder() + + def test_no_spec(self): + finder = self.finder(None) + name = 'blah' + found = finder.find_loader(name) + self.assertIsNone(found[0]) + self.assertEqual([], found[1]) + self.assertEqual(name, finder.called_for) + + def test_spec_with_loader(self): + loader = object() + spec = self.util.spec_from_loader('blah', loader) + finder = self.finder(spec) + found = finder.find_loader('blah') + self.assertIs(found[0], spec.loader) + + def test_spec_with_portions(self): + spec = self.machinery.ModuleSpec('blah', None) + paths = ['a', 'b', 'c'] + spec.submodule_search_locations = paths + finder = self.finder(spec) + found = finder.find_loader('blah') + self.assertIsNone(found[0]) + self.assertEqual(paths, found[1]) + + +Frozen_PEFFindLoaderTests, Source_PEFFindLoaderTests = util.test_both( + PathEntryFinderFindLoaderTests, + abc=(frozen_abc, source_abc), + machinery=machinery, + util=(frozen_util, source_util)) + + +##### Loader concrete methods ################################################## +class LoaderLoadModuleTests: + + def loader(self): + class SpecLoader(self.abc.Loader): + found = None + def exec_module(self, module): + self.found = module + + def is_package(self, fullname): + """Force some non-default module state to be set.""" + return True + + return SpecLoader() + + def test_fresh(self): + loader = self.loader() + name = 'blah' + with util.uncache(name): + loader.load_module(name) + module = loader.found + self.assertIs(sys.modules[name], module) + self.assertEqual(loader, module.__loader__) + self.assertEqual(loader, module.__spec__.loader) + self.assertEqual(name, module.__name__) + self.assertEqual(name, module.__spec__.name) + self.assertIsNotNone(module.__path__) + self.assertIsNotNone(module.__path__, + module.__spec__.submodule_search_locations) + + def test_reload(self): + name = 'blah' + loader = self.loader() + module = types.ModuleType(name) + module.__spec__ = self.util.spec_from_loader(name, loader) + module.__loader__ = loader + with util.uncache(name): + sys.modules[name] = module + loader.load_module(name) + found = loader.found + self.assertIs(found, sys.modules[name]) + self.assertIs(module, sys.modules[name]) + + +Frozen_LoaderLoadModuleTests, Source_LoaderLoadModuleTests = util.test_both( + LoaderLoadModuleTests, + abc=(frozen_abc, source_abc), + util=(frozen_util, source_util)) + + +##### InspectLoader concrete methods ########################################### +class InspectLoaderSourceToCodeTests: + + def source_to_module(self, data, path=None): + """Help with source_to_code() tests.""" + module = types.ModuleType('blah') + loader = self.InspectLoaderSubclass() + if path is None: + code = loader.source_to_code(data) + else: + code = loader.source_to_code(data, path) + exec(code, module.__dict__) + return module + + def test_source_to_code_source(self): + # Since compile() can handle strings, so should source_to_code(). + source = 'attr = 42' + module = self.source_to_module(source) + self.assertTrue(hasattr(module, 'attr')) + self.assertEqual(module.attr, 42) + + def test_source_to_code_bytes(self): + # Since compile() can handle bytes, so should source_to_code(). + source = b'attr = 42' + module = self.source_to_module(source) + self.assertTrue(hasattr(module, 'attr')) + self.assertEqual(module.attr, 42) + + def test_source_to_code_path(self): + # Specifying a path should set it for the code object. + path = 'path/to/somewhere' + loader = self.InspectLoaderSubclass() + code = loader.source_to_code('', path) + self.assertEqual(code.co_filename, path) + + def test_source_to_code_no_path(self): + # Not setting a path should still work and be set to <string> since that + # is a pre-existing practice as a default to compile(). + loader = self.InspectLoaderSubclass() + code = loader.source_to_code('') + self.assertEqual(code.co_filename, '<string>') + + +class Frozen_ILSourceToCodeTests(InspectLoaderSourceToCodeTests, unittest.TestCase): + InspectLoaderSubclass = Frozen_IL + +class Source_ILSourceToCodeTests(InspectLoaderSourceToCodeTests, unittest.TestCase): + InspectLoaderSubclass = Source_IL + + +class InspectLoaderGetCodeTests: + + def test_get_code(self): + # Test success. + module = types.ModuleType('blah') + with mock.patch.object(self.InspectLoaderSubclass, 'get_source') as mocked: + mocked.return_value = 'attr = 42' + loader = self.InspectLoaderSubclass() + code = loader.get_code('blah') + exec(code, module.__dict__) + self.assertEqual(module.attr, 42) + + def test_get_code_source_is_None(self): + # If get_source() is None then this should be None. + with mock.patch.object(self.InspectLoaderSubclass, 'get_source') as mocked: + mocked.return_value = None + loader = self.InspectLoaderSubclass() + code = loader.get_code('blah') + self.assertIsNone(code) + + def test_get_code_source_not_found(self): + # If there is no source then there is no code object. + loader = self.InspectLoaderSubclass() + with self.assertRaises(ImportError): + loader.get_code('blah') + + +class Frozen_ILGetCodeTests(InspectLoaderGetCodeTests, unittest.TestCase): + InspectLoaderSubclass = Frozen_IL + +class Source_ILGetCodeTests(InspectLoaderGetCodeTests, unittest.TestCase): + InspectLoaderSubclass = Source_IL + + +class InspectLoaderLoadModuleTests: + + """Test InspectLoader.load_module().""" + + module_name = 'blah' + + def setUp(self): + support.unload(self.module_name) + self.addCleanup(support.unload, self.module_name) + + def mock_get_code(self): + return mock.patch.object(self.InspectLoaderSubclass, 'get_code') + + def test_get_code_ImportError(self): + # If get_code() raises ImportError, it should propagate. + with self.mock_get_code() as mocked_get_code: + mocked_get_code.side_effect = ImportError + with self.assertRaises(ImportError): + loader = self.InspectLoaderSubclass() + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + loader.load_module(self.module_name) + + def test_get_code_None(self): + # If get_code() returns None, raise ImportError. + with self.mock_get_code() as mocked_get_code: + mocked_get_code.return_value = None + with self.assertRaises(ImportError): + loader = self.InspectLoaderSubclass() + loader.load_module(self.module_name) + + def test_module_returned(self): + # The loaded module should be returned. + code = compile('attr = 42', '<string>', 'exec') + with self.mock_get_code() as mocked_get_code: + mocked_get_code.return_value = code + loader = self.InspectLoaderSubclass() + module = loader.load_module(self.module_name) + self.assertEqual(module, sys.modules[self.module_name]) + + +class Frozen_ILLoadModuleTests(InspectLoaderLoadModuleTests, unittest.TestCase): + InspectLoaderSubclass = Frozen_IL + +class Source_ILLoadModuleTests(InspectLoaderLoadModuleTests, unittest.TestCase): + InspectLoaderSubclass = Source_IL + + +##### ExecutionLoader concrete methods ######################################### +class ExecutionLoaderGetCodeTests: + + def mock_methods(self, *, get_source=False, get_filename=False): + source_mock_context, filename_mock_context = None, None + if get_source: + source_mock_context = mock.patch.object(self.ExecutionLoaderSubclass, + 'get_source') + if get_filename: + filename_mock_context = mock.patch.object(self.ExecutionLoaderSubclass, + 'get_filename') + return source_mock_context, filename_mock_context + + def test_get_code(self): + path = 'blah.py' + source_mock_context, filename_mock_context = self.mock_methods( + get_source=True, get_filename=True) + with source_mock_context as source_mock, filename_mock_context as name_mock: + source_mock.return_value = 'attr = 42' + name_mock.return_value = path + loader = self.ExecutionLoaderSubclass() + code = loader.get_code('blah') + self.assertEqual(code.co_filename, path) + module = types.ModuleType('blah') + exec(code, module.__dict__) + self.assertEqual(module.attr, 42) + + def test_get_code_source_is_None(self): + # If get_source() is None then this should be None. + source_mock_context, _ = self.mock_methods(get_source=True) + with source_mock_context as mocked: + mocked.return_value = None + loader = self.ExecutionLoaderSubclass() + code = loader.get_code('blah') + self.assertIsNone(code) + + def test_get_code_source_not_found(self): + # If there is no source then there is no code object. + loader = self.ExecutionLoaderSubclass() + with self.assertRaises(ImportError): + loader.get_code('blah') + + def test_get_code_no_path(self): + # If get_filename() raises ImportError then simply skip setting the path + # on the code object. + source_mock_context, filename_mock_context = self.mock_methods( + get_source=True, get_filename=True) + with source_mock_context as source_mock, filename_mock_context as name_mock: + source_mock.return_value = 'attr = 42' + name_mock.side_effect = ImportError + loader = self.ExecutionLoaderSubclass() + code = loader.get_code('blah') + self.assertEqual(code.co_filename, '<string>') + module = types.ModuleType('blah') + exec(code, module.__dict__) + self.assertEqual(module.attr, 42) + + +class Frozen_ELGetCodeTests(ExecutionLoaderGetCodeTests, unittest.TestCase): + ExecutionLoaderSubclass = Frozen_EL + +class Source_ELGetCodeTests(ExecutionLoaderGetCodeTests, unittest.TestCase): + ExecutionLoaderSubclass = Source_EL + + +##### SourceLoader concrete methods ############################################ +class SourceLoader: + + # Globals that should be defined for all modules. + source = (b"_ = '::'.join([__name__, __file__, __cached__, __package__, " + b"repr(__loader__)])") + + def __init__(self, path): + self.path = path + + def get_data(self, path): + if path != self.path: + raise IOError + return self.source + + def get_filename(self, fullname): + return self.path + + def module_repr(self, module): + return '<module>' + + +Frozen_SourceOnlyL, Source_SourceOnlyL = make_abc_subclasses(SourceLoader) + + +class SourceLoader(SourceLoader): + + source_mtime = 1 + + def __init__(self, path, magic=None): + super().__init__(path) + self.bytecode_path = self.util.cache_from_source(self.path) + self.source_size = len(self.source) + if magic is None: + magic = self.util.MAGIC_NUMBER + data = bytearray(magic) + data.extend(self.init._w_long(self.source_mtime)) + data.extend(self.init._w_long(self.source_size)) + code_object = compile(self.source, self.path, 'exec', + dont_inherit=True) + data.extend(marshal.dumps(code_object)) + self.bytecode = bytes(data) + self.written = {} + + def get_data(self, path): + if path == self.path: + return super().get_data(path) + elif path == self.bytecode_path: + return self.bytecode + else: + raise OSError + + def path_stats(self, path): + if path != self.path: + raise IOError + return {'mtime': self.source_mtime, 'size': self.source_size} + + def set_data(self, path, data): + self.written[path] = bytes(data) + return path == self.bytecode_path + + +Frozen_SL, Source_SL = make_abc_subclasses(SourceLoader) +Frozen_SL.util = frozen_util +Source_SL.util = source_util +Frozen_SL.init = frozen_init +Source_SL.init = source_init + + +class SourceLoaderTestHarness: + + def setUp(self, *, is_package=True, **kwargs): + self.package = 'pkg' + if is_package: + self.path = os.path.join(self.package, '__init__.py') + self.name = self.package + else: + module_name = 'mod' + self.path = os.path.join(self.package, '.'.join(['mod', 'py'])) + self.name = '.'.join([self.package, module_name]) + self.cached = self.util.cache_from_source(self.path) + self.loader = self.loader_mock(self.path, **kwargs) + + def verify_module(self, module): + self.assertEqual(module.__name__, self.name) + self.assertEqual(module.__file__, self.path) + self.assertEqual(module.__cached__, self.cached) + self.assertEqual(module.__package__, self.package) + self.assertEqual(module.__loader__, self.loader) + values = module._.split('::') + self.assertEqual(values[0], self.name) + self.assertEqual(values[1], self.path) + self.assertEqual(values[2], self.cached) + self.assertEqual(values[3], self.package) + self.assertEqual(values[4], repr(self.loader)) + + def verify_code(self, code_object): + module = types.ModuleType(self.name) + module.__file__ = self.path + module.__cached__ = self.cached + module.__package__ = self.package + module.__loader__ = self.loader + module.__path__ = [] + exec(code_object, module.__dict__) + self.verify_module(module) + + +class SourceOnlyLoaderTests(SourceLoaderTestHarness): + + """Test importlib.abc.SourceLoader for source-only loading. + + Reload testing is subsumed by the tests for + importlib.util.module_for_loader. + + """ + + def test_get_source(self): + # Verify the source code is returned as a string. + # If an OSError is raised by get_data then raise ImportError. + expected_source = self.loader.source.decode('utf-8') + self.assertEqual(self.loader.get_source(self.name), expected_source) + def raise_OSError(path): + raise OSError + self.loader.get_data = raise_OSError + with self.assertRaises(ImportError) as cm: + self.loader.get_source(self.name) + self.assertEqual(cm.exception.name, self.name) + + def test_is_package(self): + # Properly detect when loading a package. + self.setUp(is_package=False) + self.assertFalse(self.loader.is_package(self.name)) + self.setUp(is_package=True) + self.assertTrue(self.loader.is_package(self.name)) + self.assertFalse(self.loader.is_package(self.name + '.__init__')) + + def test_get_code(self): + # Verify the code object is created. + code_object = self.loader.get_code(self.name) + self.verify_code(code_object) + + def test_source_to_code(self): + # Verify the compiled code object. + code = self.loader.source_to_code(self.loader.source, self.path) + self.verify_code(code) + + def test_load_module(self): + # Loading a module should set __name__, __loader__, __package__, + # __path__ (for packages), __file__, and __cached__. + # The module should also be put into sys.modules. + with util.uncache(self.name): + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + module = self.loader.load_module(self.name) + self.verify_module(module) + self.assertEqual(module.__path__, [os.path.dirname(self.path)]) + self.assertIn(self.name, sys.modules) + + def test_package_settings(self): + # __package__ needs to be set, while __path__ is set on if the module + # is a package. + # Testing the values for a package are covered by test_load_module. + self.setUp(is_package=False) + with util.uncache(self.name): + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + module = self.loader.load_module(self.name) + self.verify_module(module) + self.assertTrue(not hasattr(module, '__path__')) + + def test_get_source_encoding(self): + # Source is considered encoded in UTF-8 by default unless otherwise + # specified by an encoding line. + source = "_ = 'ü'" + self.loader.source = source.encode('utf-8') + returned_source = self.loader.get_source(self.name) + self.assertEqual(returned_source, source) + source = "# coding: latin-1\n_ = ü" + self.loader.source = source.encode('latin-1') + returned_source = self.loader.get_source(self.name) + self.assertEqual(returned_source, source) + + +class Frozen_SourceOnlyLTests(SourceOnlyLoaderTests, unittest.TestCase): + loader_mock = Frozen_SourceOnlyL + util = frozen_util + +class Source_SourceOnlyLTests(SourceOnlyLoaderTests, unittest.TestCase): + loader_mock = Source_SourceOnlyL + util = source_util + + +@unittest.skipIf(sys.dont_write_bytecode, "sys.dont_write_bytecode is true") +class SourceLoaderBytecodeTests(SourceLoaderTestHarness): + + """Test importlib.abc.SourceLoader's use of bytecode. + + Source-only testing handled by SourceOnlyLoaderTests. + + """ + + def verify_code(self, code_object, *, bytecode_written=False): + super().verify_code(code_object) + if bytecode_written: + self.assertIn(self.cached, self.loader.written) + data = bytearray(self.util.MAGIC_NUMBER) + data.extend(self.init._w_long(self.loader.source_mtime)) + data.extend(self.init._w_long(self.loader.source_size)) + data.extend(marshal.dumps(code_object)) + self.assertEqual(self.loader.written[self.cached], bytes(data)) + + def test_code_with_everything(self): + # When everything should work. + code_object = self.loader.get_code(self.name) + self.verify_code(code_object) + + def test_no_bytecode(self): + # If no bytecode exists then move on to the source. + self.loader.bytecode_path = "<does not exist>" + # Sanity check + with self.assertRaises(OSError): + bytecode_path = self.util.cache_from_source(self.path) + self.loader.get_data(bytecode_path) + code_object = self.loader.get_code(self.name) + self.verify_code(code_object, bytecode_written=True) + + def test_code_bad_timestamp(self): + # Bytecode is only used when the timestamp matches the source EXACTLY. + for source_mtime in (0, 2): + assert source_mtime != self.loader.source_mtime + original = self.loader.source_mtime + self.loader.source_mtime = source_mtime + # If bytecode is used then EOFError would be raised by marshal. + self.loader.bytecode = self.loader.bytecode[8:] + code_object = self.loader.get_code(self.name) + self.verify_code(code_object, bytecode_written=True) + self.loader.source_mtime = original + + def test_code_bad_magic(self): + # Skip over bytecode with a bad magic number. + self.setUp(magic=b'0000') + # If bytecode is used then EOFError would be raised by marshal. + self.loader.bytecode = self.loader.bytecode[8:] + code_object = self.loader.get_code(self.name) + self.verify_code(code_object, bytecode_written=True) + + def test_dont_write_bytecode(self): + # Bytecode is not written if sys.dont_write_bytecode is true. + # Can assume it is false already thanks to the skipIf class decorator. + try: + sys.dont_write_bytecode = True + self.loader.bytecode_path = "<does not exist>" + code_object = self.loader.get_code(self.name) + self.assertNotIn(self.cached, self.loader.written) + finally: + sys.dont_write_bytecode = False + + def test_no_set_data(self): + # If set_data is not defined, one can still read bytecode. + self.setUp(magic=b'0000') + original_set_data = self.loader.__class__.mro()[1].set_data + try: + del self.loader.__class__.mro()[1].set_data + code_object = self.loader.get_code(self.name) + self.verify_code(code_object) + finally: + self.loader.__class__.mro()[1].set_data = original_set_data + + def test_set_data_raises_exceptions(self): + # Raising NotImplementedError or OSError is okay for set_data. + def raise_exception(exc): + def closure(*args, **kwargs): + raise exc + return closure + + self.setUp(magic=b'0000') + self.loader.set_data = raise_exception(NotImplementedError) + code_object = self.loader.get_code(self.name) + self.verify_code(code_object) + + +class Frozen_SLBytecodeTests(SourceLoaderBytecodeTests, unittest.TestCase): + loader_mock = Frozen_SL + init = frozen_init + util = frozen_util + +class SourceSLBytecodeTests(SourceLoaderBytecodeTests, unittest.TestCase): + loader_mock = Source_SL + init = source_init + util = source_util + + +class SourceLoaderGetSourceTests: + + """Tests for importlib.abc.SourceLoader.get_source().""" + + def test_default_encoding(self): + # Should have no problems with UTF-8 text. + name = 'mod' + mock = self.SourceOnlyLoaderMock('mod.file') + source = 'x = "ü"' + mock.source = source.encode('utf-8') + returned_source = mock.get_source(name) + self.assertEqual(returned_source, source) + + def test_decoded_source(self): + # Decoding should work. + name = 'mod' + mock = self.SourceOnlyLoaderMock("mod.file") + source = "# coding: Latin-1\nx='ü'" + assert source.encode('latin-1') != source.encode('utf-8') + mock.source = source.encode('latin-1') + returned_source = mock.get_source(name) + self.assertEqual(returned_source, source) + + def test_universal_newlines(self): + # PEP 302 says universal newlines should be used. + name = 'mod' + mock = self.SourceOnlyLoaderMock('mod.file') + source = "x = 42\r\ny = -13\r\n" + mock.source = source.encode('utf-8') + expect = io.IncrementalNewlineDecoder(None, True).decode(source) + self.assertEqual(mock.get_source(name), expect) + + +class Frozen_SourceOnlyLGetSourceTests(SourceLoaderGetSourceTests, unittest.TestCase): + SourceOnlyLoaderMock = Frozen_SourceOnlyL + +class Source_SourceOnlyLGetSourceTests(SourceLoaderGetSourceTests, unittest.TestCase): + SourceOnlyLoaderMock = Source_SourceOnlyL if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_importlib/test_api.py b/Lib/test/test_importlib/test_api.py index b1a58944f3..744001b4c0 100644 --- a/Lib/test/test_importlib/test_api.py +++ b/Lib/test/test_importlib/test_api.py @@ -1,14 +1,18 @@ from . import util -import imp -import importlib -from importlib import machinery + +frozen_init, source_init = util.import_importlib('importlib') +frozen_util, source_util = util.import_importlib('importlib.util') +frozen_machinery, source_machinery = util.import_importlib('importlib.machinery') + +import os.path import sys from test import support import types import unittest +import warnings -class ImportModuleTests(unittest.TestCase): +class ImportModuleTests: """Test importlib.import_module.""" @@ -16,7 +20,7 @@ class ImportModuleTests(unittest.TestCase): # Test importing a top-level module. with util.mock_modules('top_level') as mock: with util.import_state(meta_path=[mock]): - module = importlib.import_module('top_level') + module = self.init.import_module('top_level') self.assertEqual(module.__name__, 'top_level') def test_absolute_package_import(self): @@ -26,7 +30,7 @@ class ImportModuleTests(unittest.TestCase): name = '{0}.mod'.format(pkg_name) with util.mock_modules(pkg_long_name, name) as mock: with util.import_state(meta_path=[mock]): - module = importlib.import_module(name) + module = self.init.import_module(name) self.assertEqual(module.__name__, name) def test_shallow_relative_package_import(self): @@ -38,17 +42,17 @@ class ImportModuleTests(unittest.TestCase): relative_name = '.{0}'.format(module_name) with util.mock_modules(pkg_long_name, absolute_name) as mock: with util.import_state(meta_path=[mock]): - importlib.import_module(pkg_name) - module = importlib.import_module(relative_name, pkg_name) + self.init.import_module(pkg_name) + module = self.init.import_module(relative_name, pkg_name) self.assertEqual(module.__name__, absolute_name) def test_deep_relative_package_import(self): modules = ['a.__init__', 'a.b.__init__', 'a.c'] with util.mock_modules(*modules) as mock: with util.import_state(meta_path=[mock]): - importlib.import_module('a') - importlib.import_module('a.b') - module = importlib.import_module('..c', 'a.b') + self.init.import_module('a') + self.init.import_module('a.b') + module = self.init.import_module('..c', 'a.b') self.assertEqual(module.__name__, 'a.c') def test_absolute_import_with_package(self): @@ -59,15 +63,15 @@ class ImportModuleTests(unittest.TestCase): name = '{0}.mod'.format(pkg_name) with util.mock_modules(pkg_long_name, name) as mock: with util.import_state(meta_path=[mock]): - importlib.import_module(pkg_name) - module = importlib.import_module(name, pkg_name) + self.init.import_module(pkg_name) + module = self.init.import_module(name, pkg_name) self.assertEqual(module.__name__, name) def test_relative_import_wo_package(self): # Relative imports cannot happen without the 'package' argument being # set. with self.assertRaises(TypeError): - importlib.import_module('.support') + self.init.import_module('.support') def test_loaded_once(self): @@ -76,7 +80,7 @@ class ImportModuleTests(unittest.TestCase): # module currently being imported. b_load_count = 0 def load_a(): - importlib.import_module('a.b') + self.init.import_module('a.b') def load_b(): nonlocal b_load_count b_load_count += 1 @@ -84,11 +88,17 @@ class ImportModuleTests(unittest.TestCase): modules = ['a.__init__', 'a.b'] with util.mock_modules(*modules, module_code=code) as mock: with util.import_state(meta_path=[mock]): - importlib.import_module('a.b') + self.init.import_module('a.b') self.assertEqual(b_load_count, 1) +class Frozen_ImportModuleTests(ImportModuleTests, unittest.TestCase): + init = frozen_init + +class Source_ImportModuleTests(ImportModuleTests, unittest.TestCase): + init = source_init -class FindLoaderTests(unittest.TestCase): + +class FindLoaderTests: class FakeMetaFinder: @staticmethod @@ -98,29 +108,51 @@ class FindLoaderTests(unittest.TestCase): # If a module with __loader__ is in sys.modules, then return it. name = 'some_mod' with util.uncache(name): - module = imp.new_module(name) + module = types.ModuleType(name) loader = 'a loader!' module.__loader__ = loader sys.modules[name] = module - found = importlib.find_loader(name) + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + found = self.init.find_loader(name) self.assertEqual(loader, found) def test_sys_modules_loader_is_None(self): # If sys.modules[name].__loader__ is None, raise ValueError. name = 'some_mod' with util.uncache(name): - module = imp.new_module(name) + module = types.ModuleType(name) module.__loader__ = None sys.modules[name] = module with self.assertRaises(ValueError): - importlib.find_loader(name) + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + self.init.find_loader(name) + + def test_sys_modules_loader_is_not_set(self): + # Should raise ValueError + # Issue #17099 + name = 'some_mod' + with util.uncache(name): + module = types.ModuleType(name) + try: + del module.__loader__ + except AttributeError: + pass + sys.modules[name] = module + with self.assertRaises(ValueError): + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + self.init.find_loader(name) def test_success(self): # Return the loader found on sys.meta_path. name = 'some_mod' with util.uncache(name): with util.import_state(meta_path=[self.FakeMetaFinder]): - self.assertEqual((name, None), importlib.find_loader(name)) + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + self.assertEqual((name, None), self.init.find_loader(name)) def test_success_path(self): # Searching on a path should work. @@ -128,15 +160,200 @@ class FindLoaderTests(unittest.TestCase): path = 'path to some place' with util.uncache(name): with util.import_state(meta_path=[self.FakeMetaFinder]): - self.assertEqual((name, path), - importlib.find_loader(name, path)) + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + self.assertEqual((name, path), + self.init.find_loader(name, path)) def test_nothing(self): # None is returned upon failure to find a loader. - self.assertIsNone(importlib.find_loader('nevergoingtofindthismodule')) + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + self.assertIsNone(self.init.find_loader('nevergoingtofindthismodule')) + +class Frozen_FindLoaderTests(FindLoaderTests, unittest.TestCase): + init = frozen_init + +class Source_FindLoaderTests(FindLoaderTests, unittest.TestCase): + init = source_init -class InvalidateCacheTests(unittest.TestCase): +class ReloadTests: + + """Test module reloading for builtin and extension modules.""" + + def test_reload_modules(self): + for mod in ('tokenize', 'time', 'marshal'): + with self.subTest(module=mod): + with support.CleanImport(mod): + module = self.init.import_module(mod) + self.init.reload(module) + + def test_module_replaced(self): + def code(): + import sys + module = type(sys)('top_level') + module.spam = 3 + sys.modules['top_level'] = module + mock = util.mock_modules('top_level', + module_code={'top_level': code}) + with mock: + with util.import_state(meta_path=[mock]): + module = self.init.import_module('top_level') + reloaded = self.init.reload(module) + actual = sys.modules['top_level'] + self.assertEqual(actual.spam, 3) + self.assertEqual(reloaded.spam, 3) + + def test_reload_missing_loader(self): + with support.CleanImport('types'): + import types + loader = types.__loader__ + del types.__loader__ + reloaded = self.init.reload(types) + + self.assertIs(reloaded, types) + self.assertIs(sys.modules['types'], types) + self.assertEqual(reloaded.__loader__.path, loader.path) + + def test_reload_loader_replaced(self): + with support.CleanImport('types'): + import types + types.__loader__ = None + self.init.invalidate_caches() + reloaded = self.init.reload(types) + + self.assertIsNot(reloaded.__loader__, None) + self.assertIs(reloaded, types) + self.assertIs(sys.modules['types'], types) + + def test_reload_location_changed(self): + name = 'spam' + with support.temp_cwd(None) as cwd: + with util.uncache('spam'): + with support.DirsOnSysPath(cwd): + # Start as a plain module. + self.init.invalidate_caches() + path = os.path.join(cwd, name + '.py') + cached = self.util.cache_from_source(path) + expected = {'__name__': name, + '__package__': '', + '__file__': path, + '__cached__': cached, + '__doc__': None, + '__builtins__': __builtins__, + } + support.create_empty_file(path) + module = self.init.import_module(name) + ns = vars(module) + loader = ns.pop('__loader__') + spec = ns.pop('__spec__') + self.assertEqual(spec.name, name) + self.assertEqual(spec.loader, loader) + self.assertEqual(loader.path, path) + self.assertEqual(ns, expected) + + # Change to a package. + self.init.invalidate_caches() + init_path = os.path.join(cwd, name, '__init__.py') + cached = self.util.cache_from_source(init_path) + expected = {'__name__': name, + '__package__': name, + '__file__': init_path, + '__cached__': cached, + '__path__': [os.path.dirname(init_path)], + '__doc__': None, + '__builtins__': __builtins__, + } + os.mkdir(name) + os.rename(path, init_path) + reloaded = self.init.reload(module) + ns = vars(reloaded) + loader = ns.pop('__loader__') + spec = ns.pop('__spec__') + self.assertEqual(spec.name, name) + self.assertEqual(spec.loader, loader) + self.assertIs(reloaded, module) + self.assertEqual(loader.path, init_path) + self.maxDiff = None + self.assertEqual(ns, expected) + + def test_reload_namespace_changed(self): + name = 'spam' + with support.temp_cwd(None) as cwd: + with util.uncache('spam'): + with support.DirsOnSysPath(cwd): + # Start as a namespace package. + self.init.invalidate_caches() + bad_path = os.path.join(cwd, name, '__init.py') + cached = self.util.cache_from_source(bad_path) + expected = {'__name__': name, + '__package__': name, + '__doc__': None, + } + os.mkdir(name) + with open(bad_path, 'w') as init_file: + init_file.write('eggs = None') + module = self.init.import_module(name) + ns = vars(module) + loader = ns.pop('__loader__') + path = ns.pop('__path__') + spec = ns.pop('__spec__') + self.assertEqual(spec.name, name) + self.assertIs(spec.loader, None) + self.assertIsNot(loader, None) + self.assertEqual(set(path), + set([os.path.dirname(bad_path)])) + with self.assertRaises(AttributeError): + # a NamespaceLoader + loader.path + self.assertEqual(ns, expected) + + # Change to a regular package. + self.init.invalidate_caches() + init_path = os.path.join(cwd, name, '__init__.py') + cached = self.util.cache_from_source(init_path) + expected = {'__name__': name, + '__package__': name, + '__file__': init_path, + '__cached__': cached, + '__path__': [os.path.dirname(init_path)], + '__doc__': None, + '__builtins__': __builtins__, + 'eggs': None, + } + os.rename(bad_path, init_path) + reloaded = self.init.reload(module) + ns = vars(reloaded) + loader = ns.pop('__loader__') + spec = ns.pop('__spec__') + self.assertEqual(spec.name, name) + self.assertEqual(spec.loader, loader) + self.assertIs(reloaded, module) + self.assertEqual(loader.path, init_path) + self.assertEqual(ns, expected) + + def test_reload_submodule(self): + # See #19851. + name = 'spam' + subname = 'ham' + with util.temp_module(name, pkg=True) as pkg_dir: + fullname, _ = util.submodule(name, subname, pkg_dir) + ham = self.init.import_module(fullname) + reloaded = self.init.reload(ham) + self.assertIs(reloaded, ham) + + +class Frozen_ReloadTests(ReloadTests, unittest.TestCase): + init = frozen_init + util = frozen_util + +class Source_ReloadTests(ReloadTests, unittest.TestCase): + init = source_init + util = source_util + + +class InvalidateCacheTests: def test_method_called(self): # If defined the method should be called. @@ -155,48 +372,65 @@ class InvalidateCacheTests(unittest.TestCase): self.addCleanup(lambda: sys.path_importer_cache.__delitem__(key)) sys.path_importer_cache[key] = path_ins self.addCleanup(lambda: sys.meta_path.remove(meta_ins)) - importlib.invalidate_caches() + self.init.invalidate_caches() self.assertTrue(meta_ins.called) self.assertTrue(path_ins.called) def test_method_lacking(self): # There should be no issues if the method is not defined. key = 'gobbledeegook' - sys.path_importer_cache[key] = imp.NullImporter('abc') + sys.path_importer_cache[key] = None self.addCleanup(lambda: sys.path_importer_cache.__delitem__(key)) - importlib.invalidate_caches() # Shouldn't trigger an exception. + self.init.invalidate_caches() # Shouldn't trigger an exception. + +class Frozen_InvalidateCacheTests(InvalidateCacheTests, unittest.TestCase): + init = frozen_init + +class Source_InvalidateCacheTests(InvalidateCacheTests, unittest.TestCase): + init = source_init class FrozenImportlibTests(unittest.TestCase): def test_no_frozen_importlib(self): # Should be able to import w/o _frozen_importlib being defined. - module = support.import_fresh_module('importlib', blocked=['_frozen_importlib']) - self.assertFalse(isinstance(module.__loader__, - machinery.FrozenImporter)) + # Can't do an isinstance() check since separate copies of importlib + # may have been used for import, so just check the name is not for the + # frozen loader. + self.assertNotEqual(source_init.__loader__.__class__.__name__, + 'FrozenImporter') -class StartupTests(unittest.TestCase): +class StartupTests: def test_everyone_has___loader__(self): # Issue #17098: all modules should have __loader__ defined. for name, module in sys.modules.items(): if isinstance(module, types.ModuleType): - if name in sys.builtin_module_names: - self.assertEqual(importlib.machinery.BuiltinImporter, - module.__loader__) - elif imp.is_frozen(name): - self.assertEqual(importlib.machinery.FrozenImporter, - module.__loader__) - -def test_main(): - from test.support import run_unittest - run_unittest(ImportModuleTests, - FindLoaderTests, - InvalidateCacheTests, - FrozenImportlibTests, - StartupTests) + with self.subTest(name=name): + self.assertTrue(hasattr(module, '__loader__'), + '{!r} lacks a __loader__ attribute'.format(name)) + if self.machinery.BuiltinImporter.find_module(name): + self.assertIsNot(module.__loader__, None) + elif self.machinery.FrozenImporter.find_module(name): + self.assertIsNot(module.__loader__, None) + + def test_everyone_has___spec__(self): + for name, module in sys.modules.items(): + if isinstance(module, types.ModuleType): + with self.subTest(name=name): + self.assertTrue(hasattr(module, '__spec__')) + if self.machinery.BuiltinImporter.find_module(name): + self.assertIsNot(module.__spec__, None) + elif self.machinery.FrozenImporter.find_module(name): + self.assertIsNot(module.__spec__, None) + +class Frozen_StartupTests(StartupTests, unittest.TestCase): + machinery = frozen_machinery + +class Source_StartupTests(StartupTests, unittest.TestCase): + machinery = source_machinery if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_importlib/test_locks.py b/Lib/test/test_importlib/test_locks.py index c373b11256..dc97ba1567 100644 --- a/Lib/test/test_importlib/test_locks.py +++ b/Lib/test/test_importlib/test_locks.py @@ -1,4 +1,8 @@ -from importlib import _bootstrap +from . import util +frozen_init, source_init = util.import_importlib('importlib') +frozen_bootstrap = frozen_init._bootstrap +source_bootstrap = source_init._bootstrap + import sys import time import unittest @@ -13,14 +17,9 @@ except ImportError: else: from test import lock_tests - -LockType = _bootstrap._ModuleLock -DeadlockError = _bootstrap._DeadlockError - - if threading is not None: - class ModuleLockAsRLockTests(lock_tests.RLockTests): - locktype = staticmethod(lambda: LockType("some_lock")) + class ModuleLockAsRLockTests: + locktype = classmethod(lambda cls: cls.LockType("some_lock")) # _is_owned() unsupported test__is_owned = None @@ -34,13 +33,21 @@ if threading is not None: # _release_save() unsupported test_release_save_unacquired = None + class Frozen_ModuleLockAsRLockTests(ModuleLockAsRLockTests, lock_tests.RLockTests): + LockType = frozen_bootstrap._ModuleLock + + class Source_ModuleLockAsRLockTests(ModuleLockAsRLockTests, lock_tests.RLockTests): + LockType = source_bootstrap._ModuleLock + else: - class ModuleLockAsRLockTests(unittest.TestCase): + class Frozen_ModuleLockAsRLockTests(unittest.TestCase): pass + class Source_ModuleLockAsRLockTests(unittest.TestCase): + pass -@unittest.skipUnless(threading, "threads needed for this test") -class DeadlockAvoidanceTests(unittest.TestCase): + +class DeadlockAvoidanceTests: def setUp(self): try: @@ -55,7 +62,7 @@ class DeadlockAvoidanceTests(unittest.TestCase): def run_deadlock_avoidance_test(self, create_deadlock): NLOCKS = 10 - locks = [LockType(str(i)) for i in range(NLOCKS)] + locks = [self.LockType(str(i)) for i in range(NLOCKS)] pairs = [(locks[i], locks[(i+1)%NLOCKS]) for i in range(NLOCKS)] if create_deadlock: NTHREADS = NLOCKS @@ -67,7 +74,7 @@ class DeadlockAvoidanceTests(unittest.TestCase): """Try to acquire the lock. Return True on success, False on deadlock.""" try: lock.acquire() - except DeadlockError: + except self.DeadlockError: return False else: return True @@ -99,30 +106,50 @@ class DeadlockAvoidanceTests(unittest.TestCase): self.assertEqual(results.count((True, False)), 0) self.assertEqual(results.count((True, True)), len(results)) +@unittest.skipUnless(threading, "threads needed for this test") +class Frozen_DeadlockAvoidanceTests(DeadlockAvoidanceTests, unittest.TestCase): + LockType = frozen_bootstrap._ModuleLock + DeadlockError = frozen_bootstrap._DeadlockError + +@unittest.skipUnless(threading, "threads needed for this test") +class Source_DeadlockAvoidanceTests(DeadlockAvoidanceTests, unittest.TestCase): + LockType = source_bootstrap._ModuleLock + DeadlockError = source_bootstrap._DeadlockError -class LifetimeTests(unittest.TestCase): + +class LifetimeTests: def test_lock_lifetime(self): name = "xyzzy" - self.assertNotIn(name, _bootstrap._module_locks) - lock = _bootstrap._get_module_lock(name) - self.assertIn(name, _bootstrap._module_locks) + self.assertNotIn(name, self.bootstrap._module_locks) + lock = self.bootstrap._get_module_lock(name) + self.assertIn(name, self.bootstrap._module_locks) wr = weakref.ref(lock) del lock support.gc_collect() - self.assertNotIn(name, _bootstrap._module_locks) + self.assertNotIn(name, self.bootstrap._module_locks) self.assertIsNone(wr()) def test_all_locks(self): support.gc_collect() - self.assertEqual(0, len(_bootstrap._module_locks), _bootstrap._module_locks) + self.assertEqual(0, len(self.bootstrap._module_locks), + self.bootstrap._module_locks) + +class Frozen_LifetimeTests(LifetimeTests, unittest.TestCase): + bootstrap = frozen_bootstrap + +class Source_LifetimeTests(LifetimeTests, unittest.TestCase): + bootstrap = source_bootstrap @support.reap_threads def test_main(): - support.run_unittest(ModuleLockAsRLockTests, - DeadlockAvoidanceTests, - LifetimeTests) + support.run_unittest(Frozen_ModuleLockAsRLockTests, + Source_ModuleLockAsRLockTests, + Frozen_DeadlockAvoidanceTests, + Source_DeadlockAvoidanceTests, + Frozen_LifetimeTests, + Source_LifetimeTests) if __name__ == '__main__': diff --git a/Lib/test/test_importlib/test_spec.py b/Lib/test/test_importlib/test_spec.py new file mode 100644 index 0000000000..71541f692e --- /dev/null +++ b/Lib/test/test_importlib/test_spec.py @@ -0,0 +1,957 @@ +from . import util + +frozen_init, source_init = util.import_importlib('importlib') +frozen_bootstrap = frozen_init._bootstrap +source_bootstrap = source_init._bootstrap +frozen_machinery, source_machinery = util.import_importlib('importlib.machinery') +frozen_util, source_util = util.import_importlib('importlib.util') + +import os.path +from test.support import CleanImport +import unittest +import sys +import warnings + + + +class TestLoader: + + def __init__(self, path=None, is_package=None): + self.path = path + self.package = is_package + + def __repr__(self): + return '<TestLoader object>' + + def __getattr__(self, name): + if name == 'get_filename' and self.path is not None: + return self._get_filename + if name == 'is_package': + return self._is_package + raise AttributeError(name) + + def _get_filename(self, name): + return self.path + + def _is_package(self, name): + return self.package + + +class NewLoader(TestLoader): + + EGGS = 1 + + def exec_module(self, module): + module.eggs = self.EGGS + + +class LegacyLoader(TestLoader): + + HAM = -1 + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + + @frozen_util.module_for_loader + def load_module(self, module): + module.ham = self.HAM + return module + + +class ModuleSpecTests: + + def setUp(self): + self.name = 'spam' + self.path = 'spam.py' + self.cached = self.util.cache_from_source(self.path) + self.loader = TestLoader() + self.spec = self.machinery.ModuleSpec(self.name, self.loader) + self.loc_spec = self.machinery.ModuleSpec(self.name, self.loader, + origin=self.path) + self.loc_spec._set_fileattr = True + + def test_default(self): + spec = self.machinery.ModuleSpec(self.name, self.loader) + + self.assertEqual(spec.name, self.name) + self.assertEqual(spec.loader, self.loader) + self.assertIs(spec.origin, None) + self.assertIs(spec.loader_state, None) + self.assertIs(spec.submodule_search_locations, None) + self.assertIs(spec.cached, None) + self.assertFalse(spec.has_location) + + def test_default_no_loader(self): + spec = self.machinery.ModuleSpec(self.name, None) + + self.assertEqual(spec.name, self.name) + self.assertIs(spec.loader, None) + self.assertIs(spec.origin, None) + self.assertIs(spec.loader_state, None) + self.assertIs(spec.submodule_search_locations, None) + self.assertIs(spec.cached, None) + self.assertFalse(spec.has_location) + + def test_default_is_package_false(self): + spec = self.machinery.ModuleSpec(self.name, self.loader, + is_package=False) + + self.assertEqual(spec.name, self.name) + self.assertEqual(spec.loader, self.loader) + self.assertIs(spec.origin, None) + self.assertIs(spec.loader_state, None) + self.assertIs(spec.submodule_search_locations, None) + self.assertIs(spec.cached, None) + self.assertFalse(spec.has_location) + + def test_default_is_package_true(self): + spec = self.machinery.ModuleSpec(self.name, self.loader, + is_package=True) + + self.assertEqual(spec.name, self.name) + self.assertEqual(spec.loader, self.loader) + self.assertIs(spec.origin, None) + self.assertIs(spec.loader_state, None) + self.assertEqual(spec.submodule_search_locations, []) + self.assertIs(spec.cached, None) + self.assertFalse(spec.has_location) + + def test_has_location_setter(self): + spec = self.machinery.ModuleSpec(self.name, self.loader, + origin='somewhere') + self.assertFalse(spec.has_location) + spec.has_location = True + self.assertTrue(spec.has_location) + + def test_equality(self): + other = type(sys.implementation)(name=self.name, + loader=self.loader, + origin=None, + submodule_search_locations=None, + has_location=False, + cached=None, + ) + + self.assertTrue(self.spec == other) + + def test_equality_location(self): + other = type(sys.implementation)(name=self.name, + loader=self.loader, + origin=self.path, + submodule_search_locations=None, + has_location=True, + cached=self.cached, + ) + + self.assertEqual(self.loc_spec, other) + + def test_inequality(self): + other = type(sys.implementation)(name='ham', + loader=self.loader, + origin=None, + submodule_search_locations=None, + has_location=False, + cached=None, + ) + + self.assertNotEqual(self.spec, other) + + def test_inequality_incomplete(self): + other = type(sys.implementation)(name=self.name, + loader=self.loader, + ) + + self.assertNotEqual(self.spec, other) + + def test_package(self): + spec = self.machinery.ModuleSpec('spam.eggs', self.loader) + + self.assertEqual(spec.parent, 'spam') + + def test_package_is_package(self): + spec = self.machinery.ModuleSpec('spam.eggs', self.loader, + is_package=True) + + self.assertEqual(spec.parent, 'spam.eggs') + + # cached + + def test_cached_set(self): + before = self.spec.cached + self.spec.cached = 'there' + after = self.spec.cached + + self.assertIs(before, None) + self.assertEqual(after, 'there') + + def test_cached_no_origin(self): + spec = self.machinery.ModuleSpec(self.name, self.loader) + + self.assertIs(spec.cached, None) + + def test_cached_with_origin_not_location(self): + spec = self.machinery.ModuleSpec(self.name, self.loader, + origin=self.path) + + self.assertIs(spec.cached, None) + + def test_cached_source(self): + expected = self.util.cache_from_source(self.path) + + self.assertEqual(self.loc_spec.cached, expected) + + def test_cached_source_unknown_suffix(self): + self.loc_spec.origin = 'spam.spamspamspam' + + self.assertIs(self.loc_spec.cached, None) + + def test_cached_source_missing_cache_tag(self): + original = sys.implementation.cache_tag + sys.implementation.cache_tag = None + try: + cached = self.loc_spec.cached + finally: + sys.implementation.cache_tag = original + + self.assertIs(cached, None) + + def test_cached_sourceless(self): + self.loc_spec.origin = 'spam.pyc' + + self.assertEqual(self.loc_spec.cached, 'spam.pyc') + + +class Frozen_ModuleSpecTests(ModuleSpecTests, unittest.TestCase): + util = frozen_util + machinery = frozen_machinery + + +class Source_ModuleSpecTests(ModuleSpecTests, unittest.TestCase): + util = source_util + machinery = source_machinery + + +class ModuleSpecMethodsTests: + + def setUp(self): + self.name = 'spam' + self.path = 'spam.py' + self.cached = self.util.cache_from_source(self.path) + self.loader = TestLoader() + self.spec = self.machinery.ModuleSpec(self.name, self.loader) + self.loc_spec = self.machinery.ModuleSpec(self.name, self.loader, + origin=self.path) + self.loc_spec._set_fileattr = True + + # init_module_attrs + + def test_init_module_attrs(self): + module = type(sys)(self.name) + spec = self.machinery.ModuleSpec(self.name, self.loader) + self.bootstrap._SpecMethods(spec).init_module_attrs(module) + + self.assertEqual(module.__name__, spec.name) + self.assertIs(module.__loader__, spec.loader) + self.assertEqual(module.__package__, spec.parent) + self.assertIs(module.__spec__, spec) + self.assertFalse(hasattr(module, '__path__')) + self.assertFalse(hasattr(module, '__file__')) + self.assertFalse(hasattr(module, '__cached__')) + + def test_init_module_attrs_package(self): + module = type(sys)(self.name) + spec = self.machinery.ModuleSpec(self.name, self.loader) + spec.submodule_search_locations = ['spam', 'ham'] + self.bootstrap._SpecMethods(spec).init_module_attrs(module) + + self.assertEqual(module.__name__, spec.name) + self.assertIs(module.__loader__, spec.loader) + self.assertEqual(module.__package__, spec.parent) + self.assertIs(module.__spec__, spec) + self.assertIs(module.__path__, spec.submodule_search_locations) + self.assertFalse(hasattr(module, '__file__')) + self.assertFalse(hasattr(module, '__cached__')) + + def test_init_module_attrs_location(self): + module = type(sys)(self.name) + spec = self.loc_spec + self.bootstrap._SpecMethods(spec).init_module_attrs(module) + + self.assertEqual(module.__name__, spec.name) + self.assertIs(module.__loader__, spec.loader) + self.assertEqual(module.__package__, spec.parent) + self.assertIs(module.__spec__, spec) + self.assertFalse(hasattr(module, '__path__')) + self.assertEqual(module.__file__, spec.origin) + self.assertEqual(module.__cached__, + self.util.cache_from_source(spec.origin)) + + def test_init_module_attrs_different_name(self): + module = type(sys)('eggs') + spec = self.machinery.ModuleSpec(self.name, self.loader) + self.bootstrap._SpecMethods(spec).init_module_attrs(module) + + self.assertEqual(module.__name__, spec.name) + + def test_init_module_attrs_different_spec(self): + module = type(sys)(self.name) + module.__spec__ = self.machinery.ModuleSpec('eggs', object()) + spec = self.machinery.ModuleSpec(self.name, self.loader) + self.bootstrap._SpecMethods(spec).init_module_attrs(module) + + self.assertEqual(module.__name__, spec.name) + self.assertIs(module.__loader__, spec.loader) + self.assertEqual(module.__package__, spec.parent) + self.assertIs(module.__spec__, spec) + + def test_init_module_attrs_already_set(self): + module = type(sys)('ham.eggs') + module.__loader__ = object() + module.__package__ = 'ham' + module.__path__ = ['eggs'] + module.__file__ = 'ham/eggs/__init__.py' + module.__cached__ = self.util.cache_from_source(module.__file__) + original = vars(module).copy() + spec = self.loc_spec + spec.submodule_search_locations = [''] + self.bootstrap._SpecMethods(spec).init_module_attrs(module) + + self.assertIs(module.__loader__, original['__loader__']) + self.assertEqual(module.__package__, original['__package__']) + self.assertIs(module.__path__, original['__path__']) + self.assertEqual(module.__file__, original['__file__']) + self.assertEqual(module.__cached__, original['__cached__']) + + def test_init_module_attrs_immutable(self): + module = object() + spec = self.loc_spec + spec.submodule_search_locations = [''] + self.bootstrap._SpecMethods(spec).init_module_attrs(module) + + self.assertFalse(hasattr(module, '__name__')) + self.assertFalse(hasattr(module, '__loader__')) + self.assertFalse(hasattr(module, '__package__')) + self.assertFalse(hasattr(module, '__spec__')) + self.assertFalse(hasattr(module, '__path__')) + self.assertFalse(hasattr(module, '__file__')) + self.assertFalse(hasattr(module, '__cached__')) + + # create() + + def test_create(self): + created = self.bootstrap._SpecMethods(self.spec).create() + + self.assertEqual(created.__name__, self.spec.name) + self.assertIs(created.__loader__, self.spec.loader) + self.assertEqual(created.__package__, self.spec.parent) + self.assertIs(created.__spec__, self.spec) + self.assertFalse(hasattr(created, '__path__')) + self.assertFalse(hasattr(created, '__file__')) + self.assertFalse(hasattr(created, '__cached__')) + + def test_create_from_loader(self): + module = type(sys.implementation)() + class CreatingLoader(TestLoader): + def create_module(self, spec): + return module + self.spec.loader = CreatingLoader() + created = self.bootstrap._SpecMethods(self.spec).create() + + self.assertIs(created, module) + self.assertEqual(created.__name__, self.spec.name) + self.assertIs(created.__loader__, self.spec.loader) + self.assertEqual(created.__package__, self.spec.parent) + self.assertIs(created.__spec__, self.spec) + self.assertFalse(hasattr(created, '__path__')) + self.assertFalse(hasattr(created, '__file__')) + self.assertFalse(hasattr(created, '__cached__')) + + def test_create_from_loader_not_handled(self): + class CreatingLoader(TestLoader): + def create_module(self, spec): + return None + self.spec.loader = CreatingLoader() + created = self.bootstrap._SpecMethods(self.spec).create() + + self.assertEqual(created.__name__, self.spec.name) + self.assertIs(created.__loader__, self.spec.loader) + self.assertEqual(created.__package__, self.spec.parent) + self.assertIs(created.__spec__, self.spec) + self.assertFalse(hasattr(created, '__path__')) + self.assertFalse(hasattr(created, '__file__')) + self.assertFalse(hasattr(created, '__cached__')) + + # exec() + + def test_exec(self): + self.spec.loader = NewLoader() + module = self.bootstrap._SpecMethods(self.spec).create() + sys.modules[self.name] = module + self.assertFalse(hasattr(module, 'eggs')) + self.bootstrap._SpecMethods(self.spec).exec(module) + + self.assertEqual(module.eggs, 1) + + # load() + + def test_load(self): + self.spec.loader = NewLoader() + with CleanImport(self.spec.name): + loaded = self.bootstrap._SpecMethods(self.spec).load() + installed = sys.modules[self.spec.name] + + self.assertEqual(loaded.eggs, 1) + self.assertIs(loaded, installed) + + def test_load_replaced(self): + replacement = object() + class ReplacingLoader(TestLoader): + def exec_module(self, module): + sys.modules[module.__name__] = replacement + self.spec.loader = ReplacingLoader() + with CleanImport(self.spec.name): + loaded = self.bootstrap._SpecMethods(self.spec).load() + installed = sys.modules[self.spec.name] + + self.assertIs(loaded, replacement) + self.assertIs(installed, replacement) + + def test_load_failed(self): + class FailedLoader(TestLoader): + def exec_module(self, module): + raise RuntimeError + self.spec.loader = FailedLoader() + with CleanImport(self.spec.name): + with self.assertRaises(RuntimeError): + loaded = self.bootstrap._SpecMethods(self.spec).load() + self.assertNotIn(self.spec.name, sys.modules) + + def test_load_failed_removed(self): + class FailedLoader(TestLoader): + def exec_module(self, module): + del sys.modules[module.__name__] + raise RuntimeError + self.spec.loader = FailedLoader() + with CleanImport(self.spec.name): + with self.assertRaises(RuntimeError): + loaded = self.bootstrap._SpecMethods(self.spec).load() + self.assertNotIn(self.spec.name, sys.modules) + + def test_load_legacy(self): + self.spec.loader = LegacyLoader() + with CleanImport(self.spec.name): + loaded = self.bootstrap._SpecMethods(self.spec).load() + + self.assertEqual(loaded.ham, -1) + + def test_load_legacy_attributes(self): + self.spec.loader = LegacyLoader() + with CleanImport(self.spec.name): + loaded = self.bootstrap._SpecMethods(self.spec).load() + + self.assertIs(loaded.__loader__, self.spec.loader) + self.assertEqual(loaded.__package__, self.spec.parent) + self.assertIs(loaded.__spec__, self.spec) + + def test_load_legacy_attributes_immutable(self): + module = object() + class ImmutableLoader(TestLoader): + def load_module(self, name): + sys.modules[name] = module + return module + self.spec.loader = ImmutableLoader() + with CleanImport(self.spec.name): + loaded = self.bootstrap._SpecMethods(self.spec).load() + + self.assertIs(sys.modules[self.spec.name], module) + + # reload() + + def test_reload(self): + self.spec.loader = NewLoader() + with CleanImport(self.spec.name): + loaded = self.bootstrap._SpecMethods(self.spec).load() + reloaded = self.bootstrap._SpecMethods(self.spec).exec(loaded) + installed = sys.modules[self.spec.name] + + self.assertEqual(loaded.eggs, 1) + self.assertIs(reloaded, loaded) + self.assertIs(installed, loaded) + + def test_reload_modified(self): + self.spec.loader = NewLoader() + with CleanImport(self.spec.name): + loaded = self.bootstrap._SpecMethods(self.spec).load() + loaded.eggs = 2 + reloaded = self.bootstrap._SpecMethods(self.spec).exec(loaded) + + self.assertEqual(loaded.eggs, 1) + self.assertIs(reloaded, loaded) + + def test_reload_extra_attributes(self): + self.spec.loader = NewLoader() + with CleanImport(self.spec.name): + loaded = self.bootstrap._SpecMethods(self.spec).load() + loaded.available = False + reloaded = self.bootstrap._SpecMethods(self.spec).exec(loaded) + + self.assertFalse(loaded.available) + self.assertIs(reloaded, loaded) + + def test_reload_init_module_attrs(self): + self.spec.loader = NewLoader() + with CleanImport(self.spec.name): + loaded = self.bootstrap._SpecMethods(self.spec).load() + loaded.__name__ = 'ham' + del loaded.__loader__ + del loaded.__package__ + del loaded.__spec__ + self.bootstrap._SpecMethods(self.spec).exec(loaded) + + self.assertEqual(loaded.__name__, self.spec.name) + self.assertIs(loaded.__loader__, self.spec.loader) + self.assertEqual(loaded.__package__, self.spec.parent) + self.assertIs(loaded.__spec__, self.spec) + self.assertFalse(hasattr(loaded, '__path__')) + self.assertFalse(hasattr(loaded, '__file__')) + self.assertFalse(hasattr(loaded, '__cached__')) + + def test_reload_legacy(self): + self.spec.loader = LegacyLoader() + with CleanImport(self.spec.name): + loaded = self.bootstrap._SpecMethods(self.spec).load() + reloaded = self.bootstrap._SpecMethods(self.spec).exec(loaded) + installed = sys.modules[self.spec.name] + + self.assertEqual(loaded.ham, -1) + self.assertIs(reloaded, loaded) + self.assertIs(installed, loaded) + + +class Frozen_ModuleSpecMethodsTests(ModuleSpecMethodsTests, unittest.TestCase): + bootstrap = frozen_bootstrap + machinery = frozen_machinery + util = frozen_util + + +class Source_ModuleSpecMethodsTests(ModuleSpecMethodsTests, unittest.TestCase): + bootstrap = source_bootstrap + machinery = source_machinery + util = source_util + + +class ModuleReprTests: + + def setUp(self): + self.module = type(os)('spam') + self.spec = self.machinery.ModuleSpec('spam', TestLoader()) + + def test_module___loader___module_repr(self): + class Loader: + def module_repr(self, module): + return '<delicious {}>'.format(module.__name__) + self.module.__loader__ = Loader() + modrepr = self.bootstrap._module_repr(self.module) + + self.assertEqual(modrepr, '<delicious spam>') + + def test_module___loader___module_repr_bad(self): + class Loader(TestLoader): + def module_repr(self, module): + raise Exception + self.module.__loader__ = Loader() + modrepr = self.bootstrap._module_repr(self.module) + + self.assertEqual(modrepr, + '<module {!r} (<TestLoader object>)>'.format('spam')) + + def test_module___spec__(self): + origin = 'in a hole, in the ground' + self.spec.origin = origin + self.module.__spec__ = self.spec + modrepr = self.bootstrap._module_repr(self.module) + + self.assertEqual(modrepr, '<module {!r} ({})>'.format('spam', origin)) + + def test_module___spec___location(self): + location = 'in_a_galaxy_far_far_away.py' + self.spec.origin = location + self.spec._set_fileattr = True + self.module.__spec__ = self.spec + modrepr = self.bootstrap._module_repr(self.module) + + self.assertEqual(modrepr, + '<module {!r} from {!r}>'.format('spam', location)) + + def test_module___spec___no_origin(self): + self.spec.loader = TestLoader() + self.module.__spec__ = self.spec + modrepr = self.bootstrap._module_repr(self.module) + + self.assertEqual(modrepr, + '<module {!r} (<TestLoader object>)>'.format('spam')) + + def test_module___spec___no_origin_no_loader(self): + self.spec.loader = None + self.module.__spec__ = self.spec + modrepr = self.bootstrap._module_repr(self.module) + + self.assertEqual(modrepr, '<module {!r}>'.format('spam')) + + def test_module_no_name(self): + del self.module.__name__ + modrepr = self.bootstrap._module_repr(self.module) + + self.assertEqual(modrepr, '<module {!r}>'.format('?')) + + def test_module_with_file(self): + filename = 'e/i/e/i/o/spam.py' + self.module.__file__ = filename + modrepr = self.bootstrap._module_repr(self.module) + + self.assertEqual(modrepr, + '<module {!r} from {!r}>'.format('spam', filename)) + + def test_module_no_file(self): + self.module.__loader__ = TestLoader() + modrepr = self.bootstrap._module_repr(self.module) + + self.assertEqual(modrepr, + '<module {!r} (<TestLoader object>)>'.format('spam')) + + def test_module_no_file_no_loader(self): + modrepr = self.bootstrap._module_repr(self.module) + + self.assertEqual(modrepr, '<module {!r}>'.format('spam')) + + +class Frozen_ModuleReprTests(ModuleReprTests, unittest.TestCase): + bootstrap = frozen_bootstrap + machinery = frozen_machinery + util = frozen_util + + +class Source_ModuleReprTests(ModuleReprTests, unittest.TestCase): + bootstrap = source_bootstrap + machinery = source_machinery + util = source_util + + +class FactoryTests: + + def setUp(self): + self.name = 'spam' + self.path = 'spam.py' + self.cached = self.util.cache_from_source(self.path) + self.loader = TestLoader() + self.fileloader = TestLoader(self.path) + self.pkgloader = TestLoader(self.path, True) + + # spec_from_loader() + + def test_spec_from_loader_default(self): + spec = self.util.spec_from_loader(self.name, self.loader) + + self.assertEqual(spec.name, self.name) + self.assertEqual(spec.loader, self.loader) + self.assertIs(spec.origin, None) + self.assertIs(spec.loader_state, None) + self.assertIs(spec.submodule_search_locations, None) + self.assertIs(spec.cached, None) + self.assertFalse(spec.has_location) + + def test_spec_from_loader_default_with_bad_is_package(self): + class Loader: + def is_package(self, name): + raise ImportError + loader = Loader() + spec = self.util.spec_from_loader(self.name, loader) + + self.assertEqual(spec.name, self.name) + self.assertEqual(spec.loader, loader) + self.assertIs(spec.origin, None) + self.assertIs(spec.loader_state, None) + self.assertIs(spec.submodule_search_locations, None) + self.assertIs(spec.cached, None) + self.assertFalse(spec.has_location) + + def test_spec_from_loader_origin(self): + origin = 'somewhere over the rainbow' + spec = self.util.spec_from_loader(self.name, self.loader, + origin=origin) + + self.assertEqual(spec.name, self.name) + self.assertEqual(spec.loader, self.loader) + self.assertIs(spec.origin, origin) + self.assertIs(spec.loader_state, None) + self.assertIs(spec.submodule_search_locations, None) + self.assertIs(spec.cached, None) + self.assertFalse(spec.has_location) + + def test_spec_from_loader_is_package_false(self): + spec = self.util.spec_from_loader(self.name, self.loader, + is_package=False) + + self.assertEqual(spec.name, self.name) + self.assertEqual(spec.loader, self.loader) + self.assertIs(spec.origin, None) + self.assertIs(spec.loader_state, None) + self.assertIs(spec.submodule_search_locations, None) + self.assertIs(spec.cached, None) + self.assertFalse(spec.has_location) + + def test_spec_from_loader_is_package_true(self): + spec = self.util.spec_from_loader(self.name, self.loader, + is_package=True) + + self.assertEqual(spec.name, self.name) + self.assertEqual(spec.loader, self.loader) + self.assertIs(spec.origin, None) + self.assertIs(spec.loader_state, None) + self.assertEqual(spec.submodule_search_locations, []) + self.assertIs(spec.cached, None) + self.assertFalse(spec.has_location) + + def test_spec_from_loader_origin_and_is_package(self): + origin = 'where the streets have no name' + spec = self.util.spec_from_loader(self.name, self.loader, + origin=origin, is_package=True) + + self.assertEqual(spec.name, self.name) + self.assertEqual(spec.loader, self.loader) + self.assertIs(spec.origin, origin) + self.assertIs(spec.loader_state, None) + self.assertEqual(spec.submodule_search_locations, []) + self.assertIs(spec.cached, None) + self.assertFalse(spec.has_location) + + def test_spec_from_loader_is_package_with_loader_false(self): + loader = TestLoader(is_package=False) + spec = self.util.spec_from_loader(self.name, loader) + + self.assertEqual(spec.name, self.name) + self.assertEqual(spec.loader, loader) + self.assertIs(spec.origin, None) + self.assertIs(spec.loader_state, None) + self.assertIs(spec.submodule_search_locations, None) + self.assertIs(spec.cached, None) + self.assertFalse(spec.has_location) + + def test_spec_from_loader_is_package_with_loader_true(self): + loader = TestLoader(is_package=True) + spec = self.util.spec_from_loader(self.name, loader) + + self.assertEqual(spec.name, self.name) + self.assertEqual(spec.loader, loader) + self.assertIs(spec.origin, None) + self.assertIs(spec.loader_state, None) + self.assertEqual(spec.submodule_search_locations, []) + self.assertIs(spec.cached, None) + self.assertFalse(spec.has_location) + + def test_spec_from_loader_default_with_file_loader(self): + spec = self.util.spec_from_loader(self.name, self.fileloader) + + self.assertEqual(spec.name, self.name) + self.assertEqual(spec.loader, self.fileloader) + self.assertEqual(spec.origin, self.path) + self.assertIs(spec.loader_state, None) + self.assertIs(spec.submodule_search_locations, None) + self.assertEqual(spec.cached, self.cached) + self.assertTrue(spec.has_location) + + def test_spec_from_loader_is_package_false_with_fileloader(self): + spec = self.util.spec_from_loader(self.name, self.fileloader, + is_package=False) + + self.assertEqual(spec.name, self.name) + self.assertEqual(spec.loader, self.fileloader) + self.assertEqual(spec.origin, self.path) + self.assertIs(spec.loader_state, None) + self.assertIs(spec.submodule_search_locations, None) + self.assertEqual(spec.cached, self.cached) + self.assertTrue(spec.has_location) + + def test_spec_from_loader_is_package_true_with_fileloader(self): + spec = self.util.spec_from_loader(self.name, self.fileloader, + is_package=True) + + self.assertEqual(spec.name, self.name) + self.assertEqual(spec.loader, self.fileloader) + self.assertEqual(spec.origin, self.path) + self.assertIs(spec.loader_state, None) + self.assertEqual(spec.submodule_search_locations, ['']) + self.assertEqual(spec.cached, self.cached) + self.assertTrue(spec.has_location) + + # spec_from_file_location() + + def test_spec_from_file_location_default(self): + if self.machinery is source_machinery: + raise unittest.SkipTest('not sure why this is breaking...') + spec = self.util.spec_from_file_location(self.name, self.path) + + self.assertEqual(spec.name, self.name) + self.assertIsInstance(spec.loader, + self.machinery.SourceFileLoader) + self.assertEqual(spec.loader.name, self.name) + self.assertEqual(spec.loader.path, self.path) + self.assertEqual(spec.origin, self.path) + self.assertIs(spec.loader_state, None) + self.assertIs(spec.submodule_search_locations, None) + self.assertEqual(spec.cached, self.cached) + self.assertTrue(spec.has_location) + + def test_spec_from_file_location_default_without_location(self): + spec = self.util.spec_from_file_location(self.name) + + self.assertIs(spec, None) + + def test_spec_from_file_location_default_bad_suffix(self): + spec = self.util.spec_from_file_location(self.name, 'spam.eggs') + + self.assertIs(spec, None) + + def test_spec_from_file_location_loader_no_location(self): + spec = self.util.spec_from_file_location(self.name, + loader=self.fileloader) + + self.assertEqual(spec.name, self.name) + self.assertEqual(spec.loader, self.fileloader) + self.assertEqual(spec.origin, self.path) + self.assertIs(spec.loader_state, None) + self.assertIs(spec.submodule_search_locations, None) + self.assertEqual(spec.cached, self.cached) + self.assertTrue(spec.has_location) + + def test_spec_from_file_location_loader_no_location_no_get_filename(self): + spec = self.util.spec_from_file_location(self.name, + loader=self.loader) + + self.assertEqual(spec.name, self.name) + self.assertEqual(spec.loader, self.loader) + self.assertEqual(spec.origin, '<unknown>') + self.assertIs(spec.loader_state, None) + self.assertIs(spec.submodule_search_locations, None) + self.assertIs(spec.cached, None) + self.assertTrue(spec.has_location) + + def test_spec_from_file_location_loader_no_location_bad_get_filename(self): + class Loader: + def get_filename(self, name): + raise ImportError + loader = Loader() + spec = self.util.spec_from_file_location(self.name, loader=loader) + + self.assertEqual(spec.name, self.name) + self.assertEqual(spec.loader, loader) + self.assertEqual(spec.origin, '<unknown>') + self.assertIs(spec.loader_state, None) + self.assertIs(spec.submodule_search_locations, None) + self.assertIs(spec.cached, None) + self.assertTrue(spec.has_location) + + def test_spec_from_file_location_smsl_none(self): + spec = self.util.spec_from_file_location(self.name, self.path, + loader=self.fileloader, + submodule_search_locations=None) + + self.assertEqual(spec.name, self.name) + self.assertEqual(spec.loader, self.fileloader) + self.assertEqual(spec.origin, self.path) + self.assertIs(spec.loader_state, None) + self.assertIs(spec.submodule_search_locations, None) + self.assertEqual(spec.cached, self.cached) + self.assertTrue(spec.has_location) + + def test_spec_from_file_location_smsl_empty(self): + spec = self.util.spec_from_file_location(self.name, self.path, + loader=self.fileloader, + submodule_search_locations=[]) + + self.assertEqual(spec.name, self.name) + self.assertEqual(spec.loader, self.fileloader) + self.assertEqual(spec.origin, self.path) + self.assertIs(spec.loader_state, None) + self.assertEqual(spec.submodule_search_locations, ['']) + self.assertEqual(spec.cached, self.cached) + self.assertTrue(spec.has_location) + + def test_spec_from_file_location_smsl_not_empty(self): + spec = self.util.spec_from_file_location(self.name, self.path, + loader=self.fileloader, + submodule_search_locations=['eggs']) + + self.assertEqual(spec.name, self.name) + self.assertEqual(spec.loader, self.fileloader) + self.assertEqual(spec.origin, self.path) + self.assertIs(spec.loader_state, None) + self.assertEqual(spec.submodule_search_locations, ['eggs']) + self.assertEqual(spec.cached, self.cached) + self.assertTrue(spec.has_location) + + def test_spec_from_file_location_smsl_default(self): + spec = self.util.spec_from_file_location(self.name, self.path, + loader=self.pkgloader) + + self.assertEqual(spec.name, self.name) + self.assertEqual(spec.loader, self.pkgloader) + self.assertEqual(spec.origin, self.path) + self.assertIs(spec.loader_state, None) + self.assertEqual(spec.submodule_search_locations, ['']) + self.assertEqual(spec.cached, self.cached) + self.assertTrue(spec.has_location) + + def test_spec_from_file_location_smsl_default_not_package(self): + class Loader: + def is_package(self, name): + return False + loader = Loader() + spec = self.util.spec_from_file_location(self.name, self.path, + loader=loader) + + self.assertEqual(spec.name, self.name) + self.assertEqual(spec.loader, loader) + self.assertEqual(spec.origin, self.path) + self.assertIs(spec.loader_state, None) + self.assertIs(spec.submodule_search_locations, None) + self.assertEqual(spec.cached, self.cached) + self.assertTrue(spec.has_location) + + def test_spec_from_file_location_smsl_default_no_is_package(self): + spec = self.util.spec_from_file_location(self.name, self.path, + loader=self.fileloader) + + self.assertEqual(spec.name, self.name) + self.assertEqual(spec.loader, self.fileloader) + self.assertEqual(spec.origin, self.path) + self.assertIs(spec.loader_state, None) + self.assertIs(spec.submodule_search_locations, None) + self.assertEqual(spec.cached, self.cached) + self.assertTrue(spec.has_location) + + def test_spec_from_file_location_smsl_default_bad_is_package(self): + class Loader: + def is_package(self, name): + raise ImportError + loader = Loader() + spec = self.util.spec_from_file_location(self.name, self.path, + loader=loader) + + self.assertEqual(spec.name, self.name) + self.assertEqual(spec.loader, loader) + self.assertEqual(spec.origin, self.path) + self.assertIs(spec.loader_state, None) + self.assertIs(spec.submodule_search_locations, None) + self.assertEqual(spec.cached, self.cached) + self.assertTrue(spec.has_location) + + +class Frozen_FactoryTests(FactoryTests, unittest.TestCase): + util = frozen_util + machinery = frozen_machinery + + +class Source_FactoryTests(FactoryTests, unittest.TestCase): + util = source_util + machinery = source_machinery diff --git a/Lib/test/test_importlib/test_util.py b/Lib/test/test_importlib/test_util.py index efc8977fb4..b2823c6a78 100644 --- a/Lib/test/test_importlib/test_util.py +++ b/Lib/test/test_importlib/test_util.py @@ -1,23 +1,66 @@ from importlib import util from . import util as test_util -import imp +frozen_init, source_init = test_util.import_importlib('importlib') +frozen_machinery, source_machinery = test_util.import_importlib('importlib.machinery') +frozen_util, source_util = test_util.import_importlib('importlib.util') + +import os import sys +from test import support import types import unittest +import warnings + + +class DecodeSourceBytesTests: + + source = "string ='ü'" + + def test_ut8_default(self): + source_bytes = self.source.encode('utf-8') + self.assertEqual(self.util.decode_source(source_bytes), self.source) + + def test_specified_encoding(self): + source = '# coding=latin-1\n' + self.source + source_bytes = source.encode('latin-1') + assert source_bytes != source.encode('utf-8') + self.assertEqual(self.util.decode_source(source_bytes), source) + def test_universal_newlines(self): + source = '\r\n'.join([self.source, self.source]) + source_bytes = source.encode('utf-8') + self.assertEqual(self.util.decode_source(source_bytes), + '\n'.join([self.source, self.source])) -class ModuleForLoaderTests(unittest.TestCase): +Frozen_DecodeSourceBytesTests, Source_DecodeSourceBytesTests = test_util.test_both( + DecodeSourceBytesTests, util=[frozen_util, source_util]) + + +class ModuleForLoaderTests: """Tests for importlib.util.module_for_loader.""" + @classmethod + def module_for_loader(cls, func): + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + return cls.util.module_for_loader(func) + + def test_warning(self): + # Should raise a PendingDeprecationWarning when used. + with warnings.catch_warnings(): + warnings.simplefilter('error', DeprecationWarning) + with self.assertRaises(DeprecationWarning): + func = self.util.module_for_loader(lambda x: x) + def return_module(self, name): - fxn = util.module_for_loader(lambda self, module: module) + fxn = self.module_for_loader(lambda self, module: module) return fxn(self, name) def raise_exception(self, name): def to_wrap(self, module): raise ImportError - fxn = util.module_for_loader(to_wrap) + fxn = self.module_for_loader(to_wrap) try: fxn(self, name) except ImportError: @@ -35,12 +78,23 @@ class ModuleForLoaderTests(unittest.TestCase): def test_reload(self): # Test that a module is reused if already in sys.modules. + class FakeLoader: + def is_package(self, name): + return True + @self.module_for_loader + def load_module(self, module): + return module name = 'a.b.c' - module = imp.new_module('a.b.c') + module = types.ModuleType('a.b.c') + module.__loader__ = 42 + module.__package__ = 42 with test_util.uncache(name): sys.modules[name] = module - returned_module = self.return_module(name) + loader = FakeLoader() + returned_module = loader.load_module(name) self.assertIs(returned_module, sys.modules[name]) + self.assertEqual(module.__loader__, loader) + self.assertEqual(module.__package__, name) def test_new_module_failure(self): # Test that a module is removed from sys.modules if added but an @@ -53,7 +107,7 @@ class ModuleForLoaderTests(unittest.TestCase): def test_reload_failure(self): # Test that a failure on reload leaves the module in-place. name = 'a.b.c' - module = imp.new_module(name) + module = types.ModuleType(name) with test_util.uncache(name): sys.modules[name] = module self.raise_exception(name) @@ -61,7 +115,7 @@ class ModuleForLoaderTests(unittest.TestCase): def test_decorator_attrs(self): def fxn(self, module): pass - wrapped = util.module_for_loader(fxn) + wrapped = self.module_for_loader(fxn) self.assertEqual(wrapped.__name__, fxn.__name__) self.assertEqual(wrapped.__qualname__, fxn.__qualname__) @@ -87,7 +141,7 @@ class ModuleForLoaderTests(unittest.TestCase): self._pkg = is_package def is_package(self, name): return self._pkg - @util.module_for_loader + @self.module_for_loader def load_module(self, module): return module @@ -107,8 +161,11 @@ class ModuleForLoaderTests(unittest.TestCase): self.assertIs(module.__loader__, loader) self.assertEqual(module.__package__, name) +Frozen_ModuleForLoaderTests, Source_ModuleForLoaderTests = test_util.test_both( + ModuleForLoaderTests, util=[frozen_util, source_util]) + -class SetPackageTests(unittest.TestCase): +class SetPackageTests: """Tests for importlib.util.set_package.""" @@ -116,34 +173,36 @@ class SetPackageTests(unittest.TestCase): """Verify the module has the expected value for __package__ after passing through set_package.""" fxn = lambda: module - wrapped = util.set_package(fxn) - wrapped() + wrapped = self.util.set_package(fxn) + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + wrapped() self.assertTrue(hasattr(module, '__package__')) self.assertEqual(expect, module.__package__) def test_top_level(self): # __package__ should be set to the empty string if a top-level module. # Implicitly tests when package is set to None. - module = imp.new_module('module') + module = types.ModuleType('module') module.__package__ = None self.verify(module, '') def test_package(self): # Test setting __package__ for a package. - module = imp.new_module('pkg') + module = types.ModuleType('pkg') module.__path__ = ['<path>'] module.__package__ = None self.verify(module, 'pkg') def test_submodule(self): # Test __package__ for a module in a package. - module = imp.new_module('pkg.mod') + module = types.ModuleType('pkg.mod') module.__package__ = None self.verify(module, 'pkg') def test_setting_if_missing(self): # __package__ should be set if it is missing. - module = imp.new_module('mod') + module = types.ModuleType('mod') if hasattr(module, '__package__'): delattr(module, '__package__') self.verify(module, '') @@ -151,58 +210,383 @@ class SetPackageTests(unittest.TestCase): def test_leaving_alone(self): # If __package__ is set and not None then leave it alone. for value in (True, False): - module = imp.new_module('mod') + module = types.ModuleType('mod') module.__package__ = value self.verify(module, value) def test_decorator_attrs(self): def fxn(module): pass - wrapped = util.set_package(fxn) + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + wrapped = self.util.set_package(fxn) self.assertEqual(wrapped.__name__, fxn.__name__) self.assertEqual(wrapped.__qualname__, fxn.__qualname__) +Frozen_SetPackageTests, Source_SetPackageTests = test_util.test_both( + SetPackageTests, util=[frozen_util, source_util]) -class ResolveNameTests(unittest.TestCase): + +class SetLoaderTests: + + """Tests importlib.util.set_loader().""" + + class DummyLoader: + @util.set_loader + def load_module(self, module): + return self.module + + def test_no_attribute(self): + loader = self.DummyLoader() + loader.module = types.ModuleType('blah') + try: + del loader.module.__loader__ + except AttributeError: + pass + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + self.assertEqual(loader, loader.load_module('blah').__loader__) + + def test_attribute_is_None(self): + loader = self.DummyLoader() + loader.module = types.ModuleType('blah') + loader.module.__loader__ = None + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + self.assertEqual(loader, loader.load_module('blah').__loader__) + + def test_not_reset(self): + loader = self.DummyLoader() + loader.module = types.ModuleType('blah') + loader.module.__loader__ = 42 + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + self.assertEqual(42, loader.load_module('blah').__loader__) + +class Frozen_SetLoaderTests(SetLoaderTests, unittest.TestCase): + class DummyLoader: + @frozen_util.set_loader + def load_module(self, module): + return self.module + +class Source_SetLoaderTests(SetLoaderTests, unittest.TestCase): + class DummyLoader: + @source_util.set_loader + def load_module(self, module): + return self.module + + +class ResolveNameTests: """Tests importlib.util.resolve_name().""" def test_absolute(self): # bacon - self.assertEqual('bacon', util.resolve_name('bacon', None)) + self.assertEqual('bacon', self.util.resolve_name('bacon', None)) def test_aboslute_within_package(self): # bacon in spam - self.assertEqual('bacon', util.resolve_name('bacon', 'spam')) + self.assertEqual('bacon', self.util.resolve_name('bacon', 'spam')) def test_no_package(self): # .bacon in '' with self.assertRaises(ValueError): - util.resolve_name('.bacon', '') + self.util.resolve_name('.bacon', '') def test_in_package(self): # .bacon in spam self.assertEqual('spam.eggs.bacon', - util.resolve_name('.bacon', 'spam.eggs')) + self.util.resolve_name('.bacon', 'spam.eggs')) def test_other_package(self): # ..bacon in spam.bacon self.assertEqual('spam.bacon', - util.resolve_name('..bacon', 'spam.eggs')) + self.util.resolve_name('..bacon', 'spam.eggs')) def test_escape(self): # ..bacon in spam with self.assertRaises(ValueError): - util.resolve_name('..bacon', 'spam') + self.util.resolve_name('..bacon', 'spam') + +Frozen_ResolveNameTests, Source_ResolveNameTests = test_util.test_both( + ResolveNameTests, + util=[frozen_util, source_util]) + + +class FindSpecTests: + + class FakeMetaFinder: + @staticmethod + def find_spec(name, path=None, target=None): return name, path, target + + def test_sys_modules(self): + name = 'some_mod' + with test_util.uncache(name): + module = types.ModuleType(name) + loader = 'a loader!' + spec = self.machinery.ModuleSpec(name, loader) + module.__loader__ = loader + module.__spec__ = spec + sys.modules[name] = module + found = self.util.find_spec(name) + self.assertEqual(found, spec) + + def test_sys_modules_without___loader__(self): + name = 'some_mod' + with test_util.uncache(name): + module = types.ModuleType(name) + del module.__loader__ + loader = 'a loader!' + spec = self.machinery.ModuleSpec(name, loader) + module.__spec__ = spec + sys.modules[name] = module + found = self.util.find_spec(name) + self.assertEqual(found, spec) + + def test_sys_modules_spec_is_None(self): + name = 'some_mod' + with test_util.uncache(name): + module = types.ModuleType(name) + module.__spec__ = None + sys.modules[name] = module + with self.assertRaises(ValueError): + self.util.find_spec(name) + + def test_sys_modules_loader_is_None(self): + name = 'some_mod' + with test_util.uncache(name): + module = types.ModuleType(name) + spec = self.machinery.ModuleSpec(name, None) + module.__spec__ = spec + sys.modules[name] = module + found = self.util.find_spec(name) + self.assertEqual(found, spec) + def test_sys_modules_spec_is_not_set(self): + name = 'some_mod' + with test_util.uncache(name): + module = types.ModuleType(name) + try: + del module.__spec__ + except AttributeError: + pass + sys.modules[name] = module + with self.assertRaises(ValueError): + self.util.find_spec(name) -def test_main(): - from test import support - support.run_unittest( - ModuleForLoaderTests, - SetPackageTests, - ResolveNameTests - ) + def test_success(self): + name = 'some_mod' + with test_util.uncache(name): + with test_util.import_state(meta_path=[self.FakeMetaFinder]): + self.assertEqual((name, None, None), + self.util.find_spec(name)) + +# def test_success_path(self): +# # Searching on a path should work. +# name = 'some_mod' +# path = 'path to some place' +# with test_util.uncache(name): +# with test_util.import_state(meta_path=[self.FakeMetaFinder]): +# self.assertEqual((name, path, None), +# self.util.find_spec(name, path)) + + def test_nothing(self): + # None is returned upon failure to find a loader. + self.assertIsNone(self.util.find_spec('nevergoingtofindthismodule')) + + def test_find_submodule(self): + name = 'spam' + subname = 'ham' + with test_util.temp_module(name, pkg=True) as pkg_dir: + fullname, _ = test_util.submodule(name, subname, pkg_dir) + spec = self.util.find_spec(fullname) + self.assertIsNot(spec, None) + self.assertIn(name, sorted(sys.modules)) + self.assertNotIn(fullname, sorted(sys.modules)) + # Ensure successive calls behave the same. + spec_again = self.util.find_spec(fullname) + self.assertEqual(spec_again, spec) + + def test_find_submodule_parent_already_imported(self): + name = 'spam' + subname = 'ham' + with test_util.temp_module(name, pkg=True) as pkg_dir: + self.init.import_module(name) + fullname, _ = test_util.submodule(name, subname, pkg_dir) + spec = self.util.find_spec(fullname) + self.assertIsNot(spec, None) + self.assertIn(name, sorted(sys.modules)) + self.assertNotIn(fullname, sorted(sys.modules)) + # Ensure successive calls behave the same. + spec_again = self.util.find_spec(fullname) + self.assertEqual(spec_again, spec) + + def test_find_relative_module(self): + name = 'spam' + subname = 'ham' + with test_util.temp_module(name, pkg=True) as pkg_dir: + fullname, _ = test_util.submodule(name, subname, pkg_dir) + relname = '.' + subname + spec = self.util.find_spec(relname, name) + self.assertIsNot(spec, None) + self.assertIn(name, sorted(sys.modules)) + self.assertNotIn(fullname, sorted(sys.modules)) + # Ensure successive calls behave the same. + spec_again = self.util.find_spec(fullname) + self.assertEqual(spec_again, spec) + + def test_find_relative_module_missing_package(self): + name = 'spam' + subname = 'ham' + with test_util.temp_module(name, pkg=True) as pkg_dir: + fullname, _ = test_util.submodule(name, subname, pkg_dir) + relname = '.' + subname + with self.assertRaises(ValueError): + self.util.find_spec(relname) + self.assertNotIn(name, sorted(sys.modules)) + self.assertNotIn(fullname, sorted(sys.modules)) + + +class Frozen_FindSpecTests(FindSpecTests, unittest.TestCase): + init = frozen_init + machinery = frozen_machinery + util = frozen_util + +class Source_FindSpecTests(FindSpecTests, unittest.TestCase): + init = source_init + machinery = source_machinery + util = source_util + + +class MagicNumberTests: + + def test_length(self): + # Should be 4 bytes. + self.assertEqual(len(self.util.MAGIC_NUMBER), 4) + + def test_incorporates_rn(self): + # The magic number uses \r\n to come out wrong when splitting on lines. + self.assertTrue(self.util.MAGIC_NUMBER.endswith(b'\r\n')) + +Frozen_MagicNumberTests, Source_MagicNumberTests = test_util.test_both( + MagicNumberTests, util=[frozen_util, source_util]) + + +class PEP3147Tests: + + """Tests of PEP 3147-related functions: cache_from_source and source_from_cache.""" + + tag = sys.implementation.cache_tag + + @unittest.skipUnless(sys.implementation.cache_tag is not None, + 'requires sys.implementation.cache_tag not be None') + def test_cache_from_source(self): + # Given the path to a .py file, return the path to its PEP 3147 + # defined .pyc file (i.e. under __pycache__). + path = os.path.join('foo', 'bar', 'baz', 'qux.py') + expect = os.path.join('foo', 'bar', 'baz', '__pycache__', + 'qux.{}.pyc'.format(self.tag)) + self.assertEqual(self.util.cache_from_source(path, True), expect) + + def test_cache_from_source_no_cache_tag(self): + # No cache tag means NotImplementedError. + with support.swap_attr(sys.implementation, 'cache_tag', None): + with self.assertRaises(NotImplementedError): + self.util.cache_from_source('whatever.py') + + def test_cache_from_source_no_dot(self): + # Directory with a dot, filename without dot. + path = os.path.join('foo.bar', 'file') + expect = os.path.join('foo.bar', '__pycache__', + 'file{}.pyc'.format(self.tag)) + self.assertEqual(self.util.cache_from_source(path, True), expect) + + def test_cache_from_source_optimized(self): + # Given the path to a .py file, return the path to its PEP 3147 + # defined .pyo file (i.e. under __pycache__). + path = os.path.join('foo', 'bar', 'baz', 'qux.py') + expect = os.path.join('foo', 'bar', 'baz', '__pycache__', + 'qux.{}.pyo'.format(self.tag)) + self.assertEqual(self.util.cache_from_source(path, False), expect) + + def test_cache_from_source_cwd(self): + path = 'foo.py' + expect = os.path.join('__pycache__', 'foo.{}.pyc'.format(self.tag)) + self.assertEqual(self.util.cache_from_source(path, True), expect) + + def test_cache_from_source_override(self): + # When debug_override is not None, it can be any true-ish or false-ish + # value. + path = os.path.join('foo', 'bar', 'baz.py') + partial_expect = os.path.join('foo', 'bar', '__pycache__', + 'baz.{}.py'.format(self.tag)) + self.assertEqual(self.util.cache_from_source(path, []), partial_expect + 'o') + self.assertEqual(self.util.cache_from_source(path, [17]), + partial_expect + 'c') + # However if the bool-ishness can't be determined, the exception + # propagates. + class Bearish: + def __bool__(self): raise RuntimeError + with self.assertRaises(RuntimeError): + self.util.cache_from_source('/foo/bar/baz.py', Bearish()) + + @unittest.skipUnless(os.sep == '\\' and os.altsep == '/', + 'test meaningful only where os.altsep is defined') + def test_sep_altsep_and_sep_cache_from_source(self): + # Windows path and PEP 3147 where sep is right of altsep. + self.assertEqual( + self.util.cache_from_source('\\foo\\bar\\baz/qux.py', True), + '\\foo\\bar\\baz\\__pycache__\\qux.{}.pyc'.format(self.tag)) + + @unittest.skipUnless(sys.implementation.cache_tag is not None, + 'requires sys.implementation.cache_tag to not be ' + 'None') + def test_source_from_cache(self): + # Given the path to a PEP 3147 defined .pyc file, return the path to + # its source. This tests the good path. + path = os.path.join('foo', 'bar', 'baz', '__pycache__', + 'qux.{}.pyc'.format(self.tag)) + expect = os.path.join('foo', 'bar', 'baz', 'qux.py') + self.assertEqual(self.util.source_from_cache(path), expect) + + def test_source_from_cache_no_cache_tag(self): + # If sys.implementation.cache_tag is None, raise NotImplementedError. + path = os.path.join('blah', '__pycache__', 'whatever.pyc') + with support.swap_attr(sys.implementation, 'cache_tag', None): + with self.assertRaises(NotImplementedError): + self.util.source_from_cache(path) + + def test_source_from_cache_bad_path(self): + # When the path to a pyc file is not in PEP 3147 format, a ValueError + # is raised. + self.assertRaises( + ValueError, self.util.source_from_cache, '/foo/bar/bazqux.pyc') + + def test_source_from_cache_no_slash(self): + # No slashes at all in path -> ValueError + self.assertRaises( + ValueError, self.util.source_from_cache, 'foo.cpython-32.pyc') + + def test_source_from_cache_too_few_dots(self): + # Too few dots in final path component -> ValueError + self.assertRaises( + ValueError, self.util.source_from_cache, '__pycache__/foo.pyc') + + def test_source_from_cache_too_many_dots(self): + # Too many dots in final path component -> ValueError + self.assertRaises( + ValueError, self.util.source_from_cache, + '__pycache__/foo.cpython-32.foo.pyc') + + def test_source_from_cache_no__pycache__(self): + # Another problem with the path -> ValueError + self.assertRaises( + ValueError, self.util.source_from_cache, + '/foo/bar/foo.cpython-32.foo.pyc') + +Frozen_PEP3147Tests, Source_PEP3147Tests = test_util.test_both( + PEP3147Tests, + util=[frozen_util, source_util]) if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_importlib/test_windows.py b/Lib/test/test_importlib/test_windows.py new file mode 100644 index 0000000000..96b4adcd05 --- /dev/null +++ b/Lib/test/test_importlib/test_windows.py @@ -0,0 +1,29 @@ +from . import util +frozen_machinery, source_machinery = util.import_importlib('importlib.machinery') + +import sys +import unittest + + +@unittest.skipUnless(sys.platform.startswith('win'), 'requires Windows') +class WindowsRegistryFinderTests: + + # XXX Need a test that finds the spec via the registry. + + def test_find_spec_missing(self): + spec = self.machinery.WindowsRegistryFinder.find_spec('spam') + self.assertIs(spec, None) + + def test_find_module_missing(self): + loader = self.machinery.WindowsRegistryFinder.find_module('spam') + self.assertIs(loader, None) + + +class Frozen_WindowsRegistryFinderTests(WindowsRegistryFinderTests, + unittest.TestCase): + machinery = frozen_machinery + + +class Source_WindowsRegistryFinderTests(WindowsRegistryFinderTests, + unittest.TestCase): + machinery = source_machinery diff --git a/Lib/test/test_importlib/util.py b/Lib/test/test_importlib/util.py index ef32f7d690..885cec3b29 100644 --- a/Lib/test/test_importlib/util.py +++ b/Lib/test/test_importlib/util.py @@ -1,9 +1,31 @@ from contextlib import contextmanager -import imp +from importlib import util, invalidate_caches import os.path from test import support import unittest import sys +import types + + +def import_importlib(module_name): + """Import a module from importlib both w/ and w/o _frozen_importlib.""" + fresh = ('importlib',) if '.' in module_name else () + frozen = support.import_fresh_module(module_name) + source = support.import_fresh_module(module_name, fresh=fresh, + blocked=('_frozen_importlib',)) + return frozen, source + + +def test_both(test_class, **kwargs): + frozen_tests = types.new_class('Frozen_'+test_class.__name__, + (test_class, unittest.TestCase)) + source_tests = types.new_class('Source_'+test_class.__name__, + (test_class, unittest.TestCase)) + frozen_tests.__module__ = source_tests.__module__ = test_class.__module__ + for attr, (frozen_value, source_value) in kwargs.items(): + setattr(frozen_tests, attr, frozen_value) + setattr(source_tests, attr, source_value) + return frozen_tests, source_tests CASE_INSENSITIVE_FS = True @@ -24,6 +46,13 @@ def case_insensitive_tests(test): "requires a case-insensitive filesystem")(test) +def submodule(parent, name, pkg_dir, content=''): + path = os.path.join(pkg_dir, name + '.py') + with open(path, 'w') as subfile: + subfile.write(content) + return '{}.{}'.format(parent, name), path + + @contextmanager def uncache(*names): """Uncache a module from sys.modules. @@ -49,6 +78,31 @@ def uncache(*names): except KeyError: pass + +@contextmanager +def temp_module(name, content='', *, pkg=False): + conflicts = [n for n in sys.modules if n.partition('.')[0] == name] + with support.temp_cwd(None) as cwd: + with uncache(name, *conflicts): + with support.DirsOnSysPath(cwd): + invalidate_caches() + + location = os.path.join(cwd, name) + if pkg: + modpath = os.path.join(location, '__init__.py') + os.mkdir(name) + else: + modpath = location + '.py' + if content is None: + # Make sure the module file gets created. + content = '' + if content is not None: + # not a namespace package + with open(modpath, 'w') as modfile: + modfile.write(content) + yield location + + @contextmanager def import_state(**kwargs): """Context manager to manage the various importers and stored state in the @@ -80,9 +134,9 @@ def import_state(**kwargs): setattr(sys, attr, value) -class mock_modules: +class _ImporterMock: - """A mock importer/loader.""" + """Base class to help with creating importer mocks.""" def __init__(self, *names, module_code={}): self.modules = {} @@ -98,7 +152,7 @@ class mock_modules: package = name.rsplit('.', 1)[0] else: package = import_name - module = imp.new_module(import_name) + module = types.ModuleType(import_name) module.__loader__ = self module.__file__ = '<mock __file__>' module.__package__ = package @@ -112,6 +166,19 @@ class mock_modules: def __getitem__(self, name): return self.modules[name] + def __enter__(self): + self._uncache = uncache(*self.modules.keys()) + self._uncache.__enter__() + return self + + def __exit__(self, *exc_info): + self._uncache.__exit__(None, None, None) + + +class mock_modules(_ImporterMock): + + """Importer mock using PEP 302 APIs.""" + def find_module(self, fullname, path=None): if fullname not in self.modules: return None @@ -131,10 +198,28 @@ class mock_modules: raise return self.modules[fullname] - def __enter__(self): - self._uncache = uncache(*self.modules.keys()) - self._uncache.__enter__() - return self +class mock_spec(_ImporterMock): - def __exit__(self, *exc_info): - self._uncache.__exit__(None, None, None) + """Importer mock using PEP 451 APIs.""" + + def find_spec(self, fullname, path=None, parent=None): + try: + module = self.modules[fullname] + except KeyError: + return None + is_package = hasattr(module, '__path__') + spec = util.spec_from_file_location( + fullname, module.__file__, loader=self, + submodule_search_locations=getattr(module, '__path__', None)) + return spec + + def create_module(self, spec): + if spec.name not in self.modules: + raise ImportError + return self.modules[spec.name] + + def exec_module(self, module): + try: + self.module_code[module.__spec__.name]() + except KeyError: + pass diff --git a/Lib/test/test_index.py b/Lib/test/test_index.py index 87857119e1..a2ac32132e 100644 --- a/Lib/test/test_index.py +++ b/Lib/test/test_index.py @@ -81,7 +81,8 @@ class BaseTestCase(unittest.TestCase): return True bad_int = BadInt() - n = operator.index(bad_int) + with self.assertWarns(DeprecationWarning): + n = operator.index(bad_int) self.assertEqual(n, 1) bad_int = BadInt2() diff --git a/Lib/test/test_inspect.py b/Lib/test/test_inspect.py index 5cbec9bb17..0dc74512a2 100644 --- a/Lib/test/test_inspect.py +++ b/Lib/test/test_inspect.py @@ -1,22 +1,32 @@ -import re -import sys -import types -import unittest +import collections +import datetime +import functools +import importlib import inspect +import io import linecache -import datetime -import collections import os -import shutil from os.path import normcase +import _pickle +import re +import shutil +import sys +import types +import unicodedata +import unittest +import unittest.mock -from test.support import run_unittest, TESTFN, DirsOnSysPath +try: + from concurrent.futures import ThreadPoolExecutor +except ImportError: + ThreadPoolExecutor = None +from test.support import run_unittest, TESTFN, DirsOnSysPath, cpython_only +from test.support import MISSING_C_DOCSTRINGS +from test.script_helper import assert_python_ok, assert_python_failure from test import inspect_fodder as mod from test import inspect_fodder2 as mod2 -# C module for test_findsource_binary -import unicodedata # Functions tested in this suite: # ismodule, isclass, ismethod, isfunction, istraceback, isframe, iscode, @@ -120,7 +130,6 @@ class TestPredicates(IsTestBase): def test_get_slot_members(self): class C(object): __slots__ = ("a", "b") - x = C() x.a = 42 members = dict(inspect.getmembers(x)) @@ -310,6 +319,16 @@ class TestRetrievingSourceCode(GetSourceBase): def test_getfile(self): self.assertEqual(inspect.getfile(mod.StupidGit), mod.__file__) + def test_getfile_class_without_module(self): + class CM(type): + @property + def __module__(cls): + raise AttributeError + class C(metaclass=CM): + pass + with self.assertRaises(TypeError): + inspect.getfile(C) + def test_getmodule_recursion(self): from types import ModuleType name = '__inspect_dummy' @@ -418,14 +437,14 @@ class TestBuggyCases(GetSourceBase): unicodedata.__file__[-4:] in (".pyc", ".pyo"), "unicodedata is not an external binary module") def test_findsource_binary(self): - self.assertRaises(IOError, inspect.getsource, unicodedata) - self.assertRaises(IOError, inspect.findsource, unicodedata) + self.assertRaises(OSError, inspect.getsource, unicodedata) + self.assertRaises(OSError, inspect.findsource, unicodedata) def test_findsource_code_in_linecache(self): lines = ["x=1"] co = compile(lines[0], "_dynamically_created_file", "exec") - self.assertRaises(IOError, inspect.findsource, co) - self.assertRaises(IOError, inspect.getsource, co) + self.assertRaises(OSError, inspect.findsource, co) + self.assertRaises(OSError, inspect.getsource, co) linecache.cache[co.co_filename] = (1, None, lines, co.co_filename) try: self.assertEqual(inspect.findsource(co), (lines,0)) @@ -463,13 +482,13 @@ class _BrokenDataDescriptor(object): A broken data descriptor. See bug #1785. """ def __get__(*args): - raise AssertionError("should not __get__ data descriptors") + raise AttributeError("broken data descriptor") def __set__(*args): raise RuntimeError def __getattr__(*args): - raise AssertionError("should not __getattr__ data descriptors") + raise AttributeError("broken data descriptor") class _BrokenMethodDescriptor(object): @@ -477,10 +496,10 @@ class _BrokenMethodDescriptor(object): A broken method descriptor. See bug #1785. """ def __get__(*args): - raise AssertionError("should not __get__ method descriptors") + raise AttributeError("broken method descriptor") def __getattr__(*args): - raise AssertionError("should not __getattr__ method descriptors") + raise AttributeError("broken method descriptor") # Helper for testing classify_class_attrs. @@ -559,6 +578,96 @@ class TestClassesAndFunctions(unittest.TestCase): kwonlyargs_e=['arg'], formatted='(*, arg)') + def test_argspec_api_ignores_wrapped(self): + # Issue 20684: low level introspection API must ignore __wrapped__ + @functools.wraps(mod.spam) + def ham(x, y): + pass + # Basic check + self.assertArgSpecEquals(ham, ['x', 'y'], formatted='(x, y)') + self.assertFullArgSpecEquals(ham, ['x', 'y'], formatted='(x, y)') + self.assertFullArgSpecEquals(functools.partial(ham), + ['x', 'y'], formatted='(x, y)') + # Other variants + def check_method(f): + self.assertArgSpecEquals(f, ['self', 'x', 'y'], + formatted='(self, x, y)') + class C: + @functools.wraps(mod.spam) + def ham(self, x, y): + pass + pham = functools.partialmethod(ham) + @functools.wraps(mod.spam) + def __call__(self, x, y): + pass + check_method(C()) + check_method(C.ham) + check_method(C().ham) + check_method(C.pham) + check_method(C().pham) + + class C_new: + @functools.wraps(mod.spam) + def __new__(self, x, y): + pass + check_method(C_new) + + class C_init: + @functools.wraps(mod.spam) + def __init__(self, x, y): + pass + check_method(C_init) + + def test_getfullargspec_signature_attr(self): + def test(): + pass + spam_param = inspect.Parameter('spam', inspect.Parameter.POSITIONAL_ONLY) + test.__signature__ = inspect.Signature(parameters=(spam_param,)) + + self.assertFullArgSpecEquals(test, args_e=['spam'], formatted='(spam)') + + def test_getfullargspec_signature_annos(self): + def test(a:'spam') -> 'ham': pass + spec = inspect.getfullargspec(test) + self.assertEqual(test.__annotations__, spec.annotations) + + def test(): pass + spec = inspect.getfullargspec(test) + self.assertEqual(test.__annotations__, spec.annotations) + + @unittest.skipIf(MISSING_C_DOCSTRINGS, + "Signature information for builtins requires docstrings") + def test_getfullargspec_builtin_methods(self): + self.assertFullArgSpecEquals(_pickle.Pickler.dump, + args_e=['self', 'obj'], formatted='(self, obj)') + + self.assertFullArgSpecEquals(_pickle.Pickler(io.BytesIO()).dump, + args_e=['self', 'obj'], formatted='(self, obj)') + + self.assertFullArgSpecEquals( + os.stat, + args_e=['path'], + kwonlyargs_e=['dir_fd', 'follow_symlinks'], + kwonlydefaults_e={'dir_fd': None, 'follow_symlinks': True}, + formatted='(path, *, dir_fd=None, follow_symlinks=True)') + + @cpython_only + @unittest.skipIf(MISSING_C_DOCSTRINGS, + "Signature information for builtins requires docstrings") + def test_getfullagrspec_builtin_func(self): + import _testcapi + builtin = _testcapi.docstring_with_signature_with_defaults + spec = inspect.getfullargspec(builtin) + self.assertEqual(spec.defaults[0], 'avocado') + + @cpython_only + @unittest.skipIf(MISSING_C_DOCSTRINGS, + "Signature information for builtins requires docstrings") + def test_getfullagrspec_builtin_func_no_signature(self): + import _testcapi + builtin = _testcapi.docstring_no_signature + with self.assertRaises(TypeError): + inspect.getfullargspec(builtin) def test_getargspec_method(self): class A(object): @@ -588,6 +697,10 @@ class TestClassesAndFunctions(unittest.TestCase): md = _BrokenMethodDescriptor() attrs = attrs_wo_objs(A) + + self.assertIn(('__new__', 'method', object), attrs, 'missing __new__') + self.assertIn(('__init__', 'method', object), attrs, 'missing __init__') + self.assertIn(('s', 'static method', A), attrs, 'missing static method') self.assertIn(('c', 'class method', A), attrs, 'missing class method') self.assertIn(('p', 'property', A), attrs, 'missing property') @@ -650,6 +763,88 @@ class TestClassesAndFunctions(unittest.TestCase): if isinstance(builtin, type): inspect.classify_class_attrs(builtin) + def test_classify_DynamicClassAttribute(self): + class Meta(type): + def __getattr__(self, name): + if name == 'ham': + return 'spam' + return super().__getattr__(name) + class VA(metaclass=Meta): + @types.DynamicClassAttribute + def ham(self): + return 'eggs' + should_find_dca = inspect.Attribute('ham', 'data', VA, VA.__dict__['ham']) + self.assertIn(should_find_dca, inspect.classify_class_attrs(VA)) + should_find_ga = inspect.Attribute('ham', 'data', Meta, 'spam') + self.assertIn(should_find_ga, inspect.classify_class_attrs(VA)) + + def test_classify_metaclass_class_attribute(self): + class Meta(type): + fish = 'slap' + def __dir__(self): + return ['__class__', '__modules__', '__name__', 'fish'] + class Class(metaclass=Meta): + pass + should_find = inspect.Attribute('fish', 'data', Meta, 'slap') + self.assertIn(should_find, inspect.classify_class_attrs(Class)) + + def test_classify_VirtualAttribute(self): + class Meta(type): + def __dir__(cls): + return ['__class__', '__module__', '__name__', 'BOOM'] + def __getattr__(self, name): + if name =='BOOM': + return 42 + return super().__getattr(name) + class Class(metaclass=Meta): + pass + should_find = inspect.Attribute('BOOM', 'data', Meta, 42) + self.assertIn(should_find, inspect.classify_class_attrs(Class)) + + def test_classify_VirtualAttribute_multi_classes(self): + class Meta1(type): + def __dir__(cls): + return ['__class__', '__module__', '__name__', 'one'] + def __getattr__(self, name): + if name =='one': + return 1 + return super().__getattr__(name) + class Meta2(type): + def __dir__(cls): + return ['__class__', '__module__', '__name__', 'two'] + def __getattr__(self, name): + if name =='two': + return 2 + return super().__getattr__(name) + class Meta3(Meta1, Meta2): + def __dir__(cls): + return list(sorted(set(['__class__', '__module__', '__name__', 'three'] + + Meta1.__dir__(cls) + Meta2.__dir__(cls)))) + def __getattr__(self, name): + if name =='three': + return 3 + return super().__getattr__(name) + class Class1(metaclass=Meta1): + pass + class Class2(Class1, metaclass=Meta3): + pass + + should_find1 = inspect.Attribute('one', 'data', Meta1, 1) + should_find2 = inspect.Attribute('two', 'data', Meta2, 2) + should_find3 = inspect.Attribute('three', 'data', Meta3, 3) + cca = inspect.classify_class_attrs(Class2) + for sf in (should_find1, should_find2, should_find3): + self.assertIn(sf, cca) + + def test_classify_class_attrs_with_buggy_dir(self): + class M(type): + def __dir__(cls): + return ['__class__', '__name__', 'missing'] + class C(metaclass=M): + pass + attrs = [a[0] for a in inspect.classify_class_attrs(C)] + self.assertNotIn('missing', attrs) + def test_getmembers_descriptors(self): class A(object): dd = _BrokenDataDescriptor() @@ -693,6 +888,28 @@ class TestClassesAndFunctions(unittest.TestCase): self.assertIn(('f', b.f), inspect.getmembers(b)) self.assertIn(('f', b.f), inspect.getmembers(b, inspect.ismethod)) + def test_getmembers_VirtualAttribute(self): + class M(type): + def __getattr__(cls, name): + if name == 'eggs': + return 'scrambled' + return super().__getattr__(name) + class A(metaclass=M): + @types.DynamicClassAttribute + def eggs(self): + return 'spam' + self.assertIn(('eggs', 'scrambled'), inspect.getmembers(A)) + self.assertIn(('eggs', 'spam'), inspect.getmembers(A())) + + def test_getmembers_with_buggy_dir(self): + class M(type): + def __dir__(cls): + return ['__class__', '__name__', 'missing'] + class C(metaclass=M): + pass + attrs = [a[0] for a in inspect.getmembers(C)] + self.assertNotIn('missing', attrs) + _global_ref = object() class TestGetClosureVars(unittest.TestCase): @@ -1080,6 +1297,15 @@ class TestGetattrStatic(unittest.TestCase): self.assertEqual(inspect.getattr_static(Thing, 'x'), Thing.x) + def test_classVirtualAttribute(self): + class Thing(object): + @types.DynamicClassAttribute + def x(self): + return self._x + _x = object() + + self.assertEqual(inspect.getattr_static(Thing, 'x'), Thing.__dict__['x']) + def test_inherited_classattribute(self): class Thing(object): x = object() @@ -1390,11 +1616,13 @@ class TestSignatureObject(unittest.TestCase): self.assertEqual(str(S()), '()') - def test(po, pk, *args, ko, **kwargs): + def test(po, pk, pod=42, pkd=100, *args, ko, **kwargs): pass sig = inspect.signature(test) po = sig.parameters['po'].replace(kind=P.POSITIONAL_ONLY) + pod = sig.parameters['pod'].replace(kind=P.POSITIONAL_ONLY) pk = sig.parameters['pk'] + pkd = sig.parameters['pkd'] args = sig.parameters['args'] ko = sig.parameters['ko'] kwargs = sig.parameters['kwargs'] @@ -1417,6 +1645,15 @@ class TestSignatureObject(unittest.TestCase): with self.assertRaisesRegex(ValueError, 'duplicate parameter name'): S((po, pk, args, kwargs2, ko)) + with self.assertRaisesRegex(ValueError, 'follows default argument'): + S((pod, po)) + + with self.assertRaisesRegex(ValueError, 'follows default argument'): + S((po, pkd, pk)) + + with self.assertRaisesRegex(ValueError, 'follows default argument'): + S((pkd, pk)) + def test_signature_immutability(self): def test(a): pass @@ -1461,19 +1698,95 @@ class TestSignatureObject(unittest.TestCase): ('kwargs', ..., int, "var_keyword")), ...)) - def test_signature_on_builtin_function(self): - with self.assertRaisesRegex(ValueError, 'not supported by signature'): - inspect.signature(type) - with self.assertRaisesRegex(ValueError, 'not supported by signature'): - # support for 'wrapper_descriptor' - inspect.signature(type.__call__) - with self.assertRaisesRegex(ValueError, 'not supported by signature'): - # support for 'method-wrapper' - inspect.signature(min.__call__) - with self.assertRaisesRegex(ValueError, - 'no signature found for builtin function'): - # support for 'method-wrapper' - inspect.signature(min) + @cpython_only + @unittest.skipIf(MISSING_C_DOCSTRINGS, + "Signature information for builtins requires docstrings") + def test_signature_on_builtins(self): + import _testcapi + + def test_unbound_method(o): + """Use this to test unbound methods (things that should have a self)""" + signature = inspect.signature(o) + self.assertTrue(isinstance(signature, inspect.Signature)) + self.assertEqual(list(signature.parameters.values())[0].name, 'self') + return signature + + def test_callable(o): + """Use this to test bound methods or normal callables (things that don't expect self)""" + signature = inspect.signature(o) + self.assertTrue(isinstance(signature, inspect.Signature)) + if signature.parameters: + self.assertNotEqual(list(signature.parameters.values())[0].name, 'self') + return signature + + signature = test_callable(_testcapi.docstring_with_signature_with_defaults) + def p(name): return signature.parameters[name].default + self.assertEqual(p('s'), 'avocado') + self.assertEqual(p('b'), b'bytes') + self.assertEqual(p('d'), 3.14) + self.assertEqual(p('i'), 35) + self.assertEqual(p('n'), None) + self.assertEqual(p('t'), True) + self.assertEqual(p('f'), False) + self.assertEqual(p('local'), 3) + self.assertEqual(p('sys'), sys.maxsize) + self.assertEqual(p('exp'), sys.maxsize - 1) + + test_callable(object) + + # normal method + # (PyMethodDescr_Type, "method_descriptor") + test_unbound_method(_pickle.Pickler.dump) + d = _pickle.Pickler(io.StringIO()) + test_callable(d.dump) + + # static method + test_callable(str.maketrans) + test_callable('abc'.maketrans) + + # class method + test_callable(dict.fromkeys) + test_callable({}.fromkeys) + + # wrapper around slot (PyWrapperDescr_Type, "wrapper_descriptor") + test_unbound_method(type.__call__) + test_unbound_method(int.__add__) + test_callable((3).__add__) + + # _PyMethodWrapper_Type + # support for 'method-wrapper' + test_callable(min.__call__) + + # This doesn't work now. + # (We don't have a valid signature for "type" in 3.4) + with self.assertRaisesRegex(ValueError, "no signature found"): + class ThisWorksNow: + __call__ = type + test_callable(ThisWorksNow()) + + @cpython_only + @unittest.skipIf(MISSING_C_DOCSTRINGS, + "Signature information for builtins requires docstrings") + def test_signature_on_decorated_builtins(self): + import _testcapi + func = _testcapi.docstring_with_signature_with_defaults + + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs) -> int: + return func(*args, **kwargs) + return wrapper + + decorated_func = decorator(func) + + self.assertEqual(inspect.signature(func), + inspect.signature(decorated_func)) + + @cpython_only + def test_signature_on_builtins_no_signature(self): + import _testcapi + with self.assertRaisesRegex(ValueError, 'no signature found for builtin'): + inspect.signature(_testcapi.docstring_no_signature) def test_signature_on_non_function(self): with self.assertRaisesRegex(TypeError, 'is not a callable object'): @@ -1482,18 +1795,112 @@ class TestSignatureObject(unittest.TestCase): with self.assertRaisesRegex(TypeError, 'is not a Python function'): inspect.Signature.from_function(42) + def test_signature_from_builtin_errors(self): + with self.assertRaisesRegex(TypeError, 'is not a Python builtin'): + inspect.Signature.from_builtin(42) + + def test_signature_from_functionlike_object(self): + def func(a,b, *args, kwonly=True, kwonlyreq, **kwargs): + pass + + class funclike: + # Has to be callable, and have correct + # __code__, __annotations__, __defaults__, __name__, + # and __kwdefaults__ attributes + + def __init__(self, func): + self.__name__ = func.__name__ + self.__code__ = func.__code__ + self.__annotations__ = func.__annotations__ + self.__defaults__ = func.__defaults__ + self.__kwdefaults__ = func.__kwdefaults__ + self.func = func + + def __call__(self, *args, **kwargs): + return self.func(*args, **kwargs) + + sig_func = inspect.Signature.from_function(func) + + sig_funclike = inspect.Signature.from_function(funclike(func)) + self.assertEqual(sig_funclike, sig_func) + + sig_funclike = inspect.signature(funclike(func)) + self.assertEqual(sig_funclike, sig_func) + + # If object is not a duck type of function, then + # signature will try to get a signature for its '__call__' + # method + fl = funclike(func) + del fl.__defaults__ + self.assertEqual(self.signature(fl), + ((('args', ..., ..., "var_positional"), + ('kwargs', ..., ..., "var_keyword")), + ...)) + + # Test with cython-like builtins: + _orig_isdesc = inspect.ismethoddescriptor + def _isdesc(obj): + if hasattr(obj, '_builtinmock'): + return True + return _orig_isdesc(obj) + + with unittest.mock.patch('inspect.ismethoddescriptor', _isdesc): + builtin_func = funclike(func) + # Make sure that our mock setup is working + self.assertFalse(inspect.ismethoddescriptor(builtin_func)) + builtin_func._builtinmock = True + self.assertTrue(inspect.ismethoddescriptor(builtin_func)) + self.assertEqual(inspect.signature(builtin_func), sig_func) + + def test_signature_functionlike_class(self): + # We only want to duck type function-like objects, + # not classes. + + def func(a,b, *args, kwonly=True, kwonlyreq, **kwargs): + pass + + class funclike: + def __init__(self, marker): + pass + + __name__ = func.__name__ + __code__ = func.__code__ + __annotations__ = func.__annotations__ + __defaults__ = func.__defaults__ + __kwdefaults__ = func.__kwdefaults__ + + with self.assertRaisesRegex(TypeError, 'is not a Python function'): + inspect.Signature.from_function(funclike) + + self.assertEqual(str(inspect.signature(funclike)), '(marker)') + def test_signature_on_method(self): class Test: - def foo(self, arg1, arg2=1) -> int: + def __init__(*args): + pass + def m1(self, arg1, arg2=1) -> int: + pass + def m2(*args): + pass + def __call__(*, a): pass - meth = Test().foo - - self.assertEqual(self.signature(meth), + self.assertEqual(self.signature(Test().m1), ((('arg1', ..., ..., "positional_or_keyword"), ('arg2', 1, ..., "positional_or_keyword")), int)) + self.assertEqual(self.signature(Test().m2), + ((('args', ..., ..., "var_positional"),), + ...)) + + self.assertEqual(self.signature(Test), + ((('args', ..., ..., "var_positional"),), + ...)) + + with self.assertRaisesRegex(ValueError, 'invalid method signature'): + self.signature(Test()) + def test_signature_on_classmethod(self): class Test: @classmethod @@ -1687,6 +2094,38 @@ class TestSignatureObject(unittest.TestCase): ba = inspect.signature(_foo).bind(12, 14) self.assertEqual(_foo(*ba.args, **ba.kwargs), (12, 14, 13)) + def test_signature_on_partialmethod(self): + from functools import partialmethod + + class Spam: + def test(): + pass + ham = partialmethod(test) + + with self.assertRaisesRegex(ValueError, "has incorrect arguments"): + inspect.signature(Spam.ham) + + class Spam: + def test(it, a, *, c) -> 'spam': + pass + ham = partialmethod(test, c=1) + + self.assertEqual(self.signature(Spam.ham), + ((('it', ..., ..., 'positional_or_keyword'), + ('a', ..., ..., 'positional_or_keyword'), + ('c', 1, ..., 'keyword_only')), + 'spam')) + + self.assertEqual(self.signature(Spam().ham), + ((('a', ..., ..., 'positional_or_keyword'), + ('c', 1, ..., 'keyword_only')), + 'spam')) + + def test_signature_on_fake_partialmethod(self): + def foo(a): pass + foo._partialmethod = 'spam' + self.assertEqual(str(inspect.signature(foo)), '(a)') + def test_signature_on_decorated(self): import functools @@ -1736,6 +2175,17 @@ class TestSignatureObject(unittest.TestCase): ((('b', ..., ..., "positional_or_keyword"),), ...)) + # Test we handle __signature__ partway down the wrapper stack + def wrapped_foo_call(): + pass + wrapped_foo_call.__wrapped__ = Foo.__call__ + + self.assertEqual(self.signature(wrapped_foo_call), + ((('a', ..., ..., "positional_or_keyword"), + ('b', ..., ..., "positional_or_keyword")), + ...)) + + def test_signature_on_class(self): class C: def __init__(self, a): @@ -1817,6 +2267,49 @@ class TestSignatureObject(unittest.TestCase): ('bar', 2, ..., "keyword_only")), ...)) + @unittest.skipIf(MISSING_C_DOCSTRINGS, + "Signature information for builtins requires docstrings") + def test_signature_on_class_without_init(self): + # Test classes without user-defined __init__ or __new__ + class C: pass + self.assertEqual(str(inspect.signature(C)), '()') + class D(C): pass + self.assertEqual(str(inspect.signature(D)), '()') + + # Test meta-classes without user-defined __init__ or __new__ + class C(type): pass + class D(C): pass + with self.assertRaisesRegex(ValueError, "callable.*is not supported"): + self.assertEqual(inspect.signature(C), None) + with self.assertRaisesRegex(ValueError, "callable.*is not supported"): + self.assertEqual(inspect.signature(D), None) + + @unittest.skipIf(MISSING_C_DOCSTRINGS, + "Signature information for builtins requires docstrings") + def test_signature_on_builtin_class(self): + self.assertEqual(str(inspect.signature(_pickle.Pickler)), + '(file, protocol=None, fix_imports=True)') + + class P(_pickle.Pickler): pass + class EmptyTrait: pass + class P2(EmptyTrait, P): pass + self.assertEqual(str(inspect.signature(P)), + '(file, protocol=None, fix_imports=True)') + self.assertEqual(str(inspect.signature(P2)), + '(file, protocol=None, fix_imports=True)') + + class P3(P2): + def __init__(self, spam): + pass + self.assertEqual(str(inspect.signature(P3)), '(spam)') + + class MetaP(type): + def __call__(cls, foo, bar): + pass + class P4(P2, metaclass=MetaP): + pass + self.assertEqual(str(inspect.signature(P4)), '(foo, bar)') + def test_signature_on_callable_objects(self): class Foo: def __call__(self, a): @@ -1838,18 +2331,16 @@ class TestSignatureObject(unittest.TestCase): ((('a', ..., ..., "positional_or_keyword"),), ...)) - class ToFail: - __call__ = type - with self.assertRaisesRegex(ValueError, "not supported by signature"): - inspect.signature(ToFail()) - - class Wrapped: pass Wrapped.__wrapped__ = lambda a: None self.assertEqual(self.signature(Wrapped), ((('a', ..., ..., "positional_or_keyword"),), ...)) + # wrapper loop: + Wrapped.__wrapped__ = Wrapped + with self.assertRaisesRegex(ValueError, 'wrapper loop'): + self.signature(Wrapped) def test_signature_on_lambdas(self): self.assertEqual(self.signature((lambda a=10: a)), @@ -1923,6 +2414,7 @@ class TestSignatureObject(unittest.TestCase): def test_signature_str_positional_only(self): P = inspect.Parameter + S = inspect.Signature def test(a_po, *, b, **kwargs): return a_po, kwargs @@ -1933,14 +2425,20 @@ class TestSignatureObject(unittest.TestCase): test.__signature__ = sig.replace(parameters=new_params) self.assertEqual(str(inspect.signature(test)), - '(<a_po>, *, b, **kwargs)') + '(a_po, /, *, b, **kwargs)') - sig = inspect.signature(test) - new_params = list(sig.parameters.values()) - new_params[0] = new_params[0].replace(name=None) - test.__signature__ = sig.replace(parameters=new_params) - self.assertEqual(str(inspect.signature(test)), - '(<0>, *, b, **kwargs)') + self.assertEqual(str(S(parameters=[P('foo', P.POSITIONAL_ONLY)])), + '(foo, /)') + + self.assertEqual(str(S(parameters=[ + P('foo', P.POSITIONAL_ONLY), + P('bar', P.VAR_KEYWORD)])), + '(foo, /, **bar)') + + self.assertEqual(str(S(parameters=[ + P('foo', P.POSITIONAL_ONLY), + P('bar', P.VAR_POSITIONAL)])), + '(foo, /, *bar)') def test_signature_replace_anno(self): def test() -> 42: @@ -1955,6 +2453,22 @@ class TestSignatureObject(unittest.TestCase): self.assertEqual(sig.return_annotation, 42) self.assertEqual(sig, inspect.signature(test)) + def test_signature_on_mangled_parameters(self): + class Spam: + def foo(self, __p1:1=2, *, __p2:2=3): + pass + class Ham(Spam): + pass + + self.assertEqual(self.signature(Spam.foo), + ((('self', ..., ..., "positional_or_keyword"), + ('_Spam__p1', 2, 1, "positional_or_keyword"), + ('_Spam__p2', 3, 2, "keyword_only")), + ...)) + + self.assertEqual(self.signature(Spam.foo), + self.signature(Ham.foo)) + class TestParameterObject(unittest.TestCase): def test_signature_parameter_kinds(self): @@ -1979,10 +2493,13 @@ class TestParameterObject(unittest.TestCase): with self.assertRaisesRegex(ValueError, 'not a valid parameter name'): inspect.Parameter('1', kind=inspect.Parameter.VAR_KEYWORD) - with self.assertRaisesRegex(ValueError, - 'non-positional-only parameter'): + with self.assertRaisesRegex(TypeError, 'name must be a str'): inspect.Parameter(None, kind=inspect.Parameter.VAR_KEYWORD) + with self.assertRaisesRegex(ValueError, + 'is not a valid parameter name'): + inspect.Parameter('$', kind=inspect.Parameter.VAR_KEYWORD) + with self.assertRaisesRegex(ValueError, 'cannot have default values'): inspect.Parameter('a', default=42, kind=inspect.Parameter.VAR_KEYWORD) @@ -2031,7 +2548,8 @@ class TestParameterObject(unittest.TestCase): self.assertEqual(p2.name, 'bar') self.assertNotEqual(p2, p) - with self.assertRaisesRegex(ValueError, 'not a valid parameter name'): + with self.assertRaisesRegex(ValueError, + 'name is a required attribute'): p2 = p2.replace(name=p2.empty) p2 = p2.replace(name='foo', default=None) @@ -2053,14 +2571,11 @@ class TestParameterObject(unittest.TestCase): self.assertEqual(p2, p) def test_signature_parameter_positional_only(self): - p = inspect.Parameter(None, kind=inspect.Parameter.POSITIONAL_ONLY) - self.assertEqual(str(p), '<>') - - p = p.replace(name='1') - self.assertEqual(str(p), '<1>') + with self.assertRaisesRegex(TypeError, 'name must be a str'): + inspect.Parameter(None, kind=inspect.Parameter.POSITIONAL_ONLY) def test_signature_parameter_immutability(self): - p = inspect.Parameter(None, kind=inspect.Parameter.POSITIONAL_ONLY) + p = inspect.Parameter('spam', kind=inspect.Parameter.KEYWORD_ONLY) with self.assertRaises(AttributeError): p.foo = 'bar' @@ -2258,6 +2773,15 @@ class TestSignatureBind(unittest.TestCase): self.assertEqual(self.call(test, 1, 2, 4, 5, bar=6), (1, 2, 4, 5, 6, {})) + self.assertEqual(self.call(test, 1, 2), + (1, 2, 3, 42, 50, {})) + + self.assertEqual(self.call(test, 1, 2, foo=4, bar=5), + (1, 2, 3, 4, 5, {})) + + with self.assertRaisesRegex(TypeError, "but was passed as a keyword"): + self.call(test, 1, 2, foo=4, bar=5, c_po=10) + with self.assertRaisesRegex(TypeError, "parameter is positional only"): self.call(test, 1, 2, c_po=4) @@ -2274,6 +2798,22 @@ class TestSignatureBind(unittest.TestCase): ba = sig.bind(1, self=2, b=3) self.assertEqual(ba.args, (1, 2, 3)) + def test_signature_bind_vararg_name(self): + def test(a, *args): + return a, args + sig = inspect.signature(test) + + with self.assertRaisesRegex(TypeError, "too many keyword arguments"): + sig.bind(a=0, args=1) + + def test(*args, **kwargs): + return args, kwargs + self.assertEqual(self.call(test, args=1), ((), {'args': 1})) + + sig = inspect.signature(test) + ba = sig.bind(args=1) + self.assertEqual(ba.arguments, {'kwargs': {'args': 1}}) + class TestBoundArguments(unittest.TestCase): def test_signature_bound_arguments_unhashable(self): @@ -2301,6 +2841,168 @@ class TestBoundArguments(unittest.TestCase): self.assertNotEqual(ba, ba4) +class TestSignaturePrivateHelpers(unittest.TestCase): + def test_signature_get_bound_param(self): + getter = inspect._signature_get_bound_param + + self.assertEqual(getter('($self)'), 'self') + self.assertEqual(getter('($self, obj)'), 'self') + self.assertEqual(getter('($cls, /, obj)'), 'cls') + + def _strip_non_python_syntax(self, input, + clean_signature, self_parameter, last_positional_only): + computed_clean_signature, \ + computed_self_parameter, \ + computed_last_positional_only = \ + inspect._signature_strip_non_python_syntax(input) + self.assertEqual(computed_clean_signature, clean_signature) + self.assertEqual(computed_self_parameter, self_parameter) + self.assertEqual(computed_last_positional_only, last_positional_only) + + def test_signature_strip_non_python_syntax(self): + self._strip_non_python_syntax( + "($module, /, path, mode, *, dir_fd=None, " + + "effective_ids=False,\n follow_symlinks=True)", + "(module, path, mode, *, dir_fd=None, " + + "effective_ids=False, follow_symlinks=True)", + 0, + 0) + + self._strip_non_python_syntax( + "($module, word, salt, /)", + "(module, word, salt)", + 0, + 2) + + self._strip_non_python_syntax( + "(x, y=None, z=None, /)", + "(x, y=None, z=None)", + None, + 2) + + self._strip_non_python_syntax( + "(x, y=None, z=None)", + "(x, y=None, z=None)", + None, + None) + + self._strip_non_python_syntax( + "(x,\n y=None,\n z = None )", + "(x, y=None, z=None)", + None, + None) + + self._strip_non_python_syntax( + "", + "", + None, + None) + + self._strip_non_python_syntax( + None, + None, + None, + None) + + +class TestUnwrap(unittest.TestCase): + + def test_unwrap_one(self): + def func(a, b): + return a + b + wrapper = functools.lru_cache(maxsize=20)(func) + self.assertIs(inspect.unwrap(wrapper), func) + + def test_unwrap_several(self): + def func(a, b): + return a + b + wrapper = func + for __ in range(10): + @functools.wraps(wrapper) + def wrapper(): + pass + self.assertIsNot(wrapper.__wrapped__, func) + self.assertIs(inspect.unwrap(wrapper), func) + + def test_stop(self): + def func1(a, b): + return a + b + @functools.wraps(func1) + def func2(): + pass + @functools.wraps(func2) + def wrapper(): + pass + func2.stop_here = 1 + unwrapped = inspect.unwrap(wrapper, + stop=(lambda f: hasattr(f, "stop_here"))) + self.assertIs(unwrapped, func2) + + def test_cycle(self): + def func1(): pass + func1.__wrapped__ = func1 + with self.assertRaisesRegex(ValueError, 'wrapper loop'): + inspect.unwrap(func1) + + def func2(): pass + func2.__wrapped__ = func1 + func1.__wrapped__ = func2 + with self.assertRaisesRegex(ValueError, 'wrapper loop'): + inspect.unwrap(func1) + with self.assertRaisesRegex(ValueError, 'wrapper loop'): + inspect.unwrap(func2) + + def test_unhashable(self): + def func(): pass + func.__wrapped__ = None + class C: + __hash__ = None + __wrapped__ = func + self.assertIsNone(inspect.unwrap(C())) + +class TestMain(unittest.TestCase): + def test_only_source(self): + module = importlib.import_module('unittest') + rc, out, err = assert_python_ok('-m', 'inspect', + 'unittest') + lines = out.decode().splitlines() + # ignore the final newline + self.assertEqual(lines[:-1], inspect.getsource(module).splitlines()) + self.assertEqual(err, b'') + + @unittest.skipIf(ThreadPoolExecutor is None, + 'threads required to test __qualname__ for source files') + def test_qualname_source(self): + rc, out, err = assert_python_ok('-m', 'inspect', + 'concurrent.futures:ThreadPoolExecutor') + lines = out.decode().splitlines() + # ignore the final newline + self.assertEqual(lines[:-1], + inspect.getsource(ThreadPoolExecutor).splitlines()) + self.assertEqual(err, b'') + + def test_builtins(self): + module = importlib.import_module('unittest') + _, out, err = assert_python_failure('-m', 'inspect', + 'sys') + lines = err.decode().splitlines() + self.assertEqual(lines, ["Can't get info for builtin modules."]) + + def test_details(self): + module = importlib.import_module('unittest') + rc, out, err = assert_python_ok('-m', 'inspect', + 'unittest', '--details') + output = out.decode() + # Just a quick sanity check on the output + self.assertIn(module.__name__, output) + self.assertIn(module.__file__, output) + if not sys.flags.optimize: + self.assertIn(module.__cached__, output) + self.assertEqual(err, b'') + + + + def test_main(): run_unittest( TestDecorators, TestRetrievingSourceCode, TestOneliners, TestBuggyCases, @@ -2308,7 +3010,8 @@ def test_main(): TestGetcallargsFunctions, TestGetcallargsMethods, TestGetcallargsUnboundMethods, TestGetattrStatic, TestGetGeneratorState, TestNoEOL, TestSignatureObject, TestSignatureBind, TestParameterObject, - TestBoundArguments, TestGetClosureVars + TestBoundArguments, TestSignaturePrivateHelpers, TestGetClosureVars, + TestUnwrap, TestMain ) if __name__ == "__main__": diff --git a/Lib/test/test_int.py b/Lib/test/test_int.py index c489510537..e94602e88d 100644 --- a/Lib/test/test_int.py +++ b/Lib/test/test_int.py @@ -228,6 +228,47 @@ class IntTestCases(unittest.TestCase): self.assertRaises(TypeError, int, base=10) self.assertRaises(TypeError, int, base=0) + def test_int_base_limits(self): + """Testing the supported limits of the int() base parameter.""" + self.assertEqual(int('0', 5), 0) + with self.assertRaises(ValueError): + int('0', 1) + with self.assertRaises(ValueError): + int('0', 37) + with self.assertRaises(ValueError): + int('0', -909) # An old magic value base from Python 2. + with self.assertRaises(ValueError): + int('0', base=0-(2**234)) + with self.assertRaises(ValueError): + int('0', base=2**234) + # Bases 2 through 36 are supported. + for base in range(2,37): + self.assertEqual(int('0', base=base), 0) + + def test_int_base_bad_types(self): + """Not integer types are not valid bases; issue16772.""" + with self.assertRaises(TypeError): + int('0', 5.5) + with self.assertRaises(TypeError): + int('0', 5.0) + + def test_int_base_indexable(self): + class MyIndexable(object): + def __init__(self, value): + self.value = value + def __index__(self): + return self.value + + # Check out of range bases. + for base in 2**100, -2**100, 1, 37: + with self.assertRaises(ValueError): + int('43', base) + + # Check in-range bases. + self.assertEqual(int('101', base=MyIndexable(2)), 5) + self.assertEqual(int('101', base=MyIndexable(10)), 101) + self.assertEqual(int('101', base=MyIndexable(36)), 1 + 36**2) + def test_non_numeric_input_types(self): # Test possible non-numeric types for the argument x, including # subclasses of the explicitly documented accepted types. @@ -359,15 +400,18 @@ class IntTestCases(unittest.TestCase): return True bad_int = BadInt() - n = int(bad_int) + with self.assertWarns(DeprecationWarning): + n = int(bad_int) self.assertEqual(n, 1) bad_int = BadInt2() - n = int(bad_int) + with self.assertWarns(DeprecationWarning): + n = int(bad_int) self.assertEqual(n, 1) bad_int = TruncReturnsBadInt() - n = int(bad_int) + with self.assertWarns(DeprecationWarning): + n = int(bad_int) self.assertEqual(n, 1) good_int = TruncReturnsIntSubclass() diff --git a/Lib/test/test_io.py b/Lib/test/test_io.py index ac6d478e32..267537fdeb 100644 --- a/Lib/test/test_io.py +++ b/Lib/test/test_io.py @@ -35,6 +35,7 @@ import weakref from collections import deque, UserList from itertools import cycle, count from test import support +from test.script_helper import assert_python_ok import codecs import io # C implementation of io @@ -164,7 +165,7 @@ class CloseFailureIO(MockRawIO): def close(self): if not self.closed: self.closed = 1 - raise IOError + raise OSError class CCloseFailureIO(CloseFailureIO, io.RawIOBase): pass @@ -595,9 +596,9 @@ class IOTest(unittest.TestCase): def test_flush_error_on_close(self): f = self.open(support.TESTFN, "wb", buffering=0) def bad_flush(): - raise IOError() + raise OSError() f.flush = bad_flush - self.assertRaises(IOError, f.close) # exception not swallowed + self.assertRaises(OSError, f.close) # exception not swallowed self.assertTrue(f.closed) def test_multi_close(self): @@ -757,7 +758,7 @@ class CommonBufferedTests: if s: # The destructor *may* have printed an unraisable error, check it self.assertEqual(len(s.splitlines()), 1) - self.assertTrue(s.startswith("Exception IOError: "), s) + self.assertTrue(s.startswith("Exception OSError: "), s) self.assertTrue(s.endswith(" ignored"), s) def test_repr(self): @@ -773,22 +774,22 @@ class CommonBufferedTests: def test_flush_error_on_close(self): raw = self.MockRawIO() def bad_flush(): - raise IOError() + raise OSError() raw.flush = bad_flush b = self.tp(raw) - self.assertRaises(IOError, b.close) # exception not swallowed + self.assertRaises(OSError, b.close) # exception not swallowed self.assertTrue(b.closed) def test_close_error_on_close(self): raw = self.MockRawIO() def bad_flush(): - raise IOError('flush') + raise OSError('flush') def bad_close(): - raise IOError('close') + raise OSError('close') raw.close = bad_close b = self.tp(raw) b.flush = bad_flush - with self.assertRaises(IOError) as err: # exception not swallowed + with self.assertRaises(OSError) as err: # exception not swallowed b.close() self.assertEqual(err.exception.args, ('close',)) self.assertEqual(err.exception.__context__.args, ('flush',)) @@ -828,6 +829,14 @@ class SizeofTest: bufio = self.tp(rawio, buffer_size=bufsize2) self.assertEqual(sys.getsizeof(bufio), size + bufsize2) + @support.cpython_only + def test_buffer_freeing(self) : + bufsize = 4096 + rawio = self.MockRawIO() + bufio = self.tp(rawio, buffer_size=bufsize) + size = sys.getsizeof(bufio) - bufsize + bufio.close() + self.assertEqual(sys.getsizeof(bufio), size) class BufferedReaderTest(unittest.TestCase, CommonBufferedTests): read_mode = "rb" @@ -1012,8 +1021,8 @@ class BufferedReaderTest(unittest.TestCase, CommonBufferedTests): def test_misbehaved_io(self): rawio = self.MisbehavedRawIO((b"abc", b"d", b"efg")) bufio = self.tp(rawio) - self.assertRaises(IOError, bufio.seek, 0) - self.assertRaises(IOError, bufio.tell) + self.assertRaises(OSError, bufio.seek, 0) + self.assertRaises(OSError, bufio.tell) def test_no_extraneous_read(self): # Issue #9550; when the raw IO object has satisfied the read request, @@ -1064,17 +1073,18 @@ class CBufferedReaderTest(BufferedReaderTest, SizeofTest): bufio = self.tp(rawio) # _pyio.BufferedReader seems to implement reading different, so that # checking this is not so easy. - self.assertRaises(IOError, bufio.read, 10) + self.assertRaises(OSError, bufio.read, 10) def test_garbage_collection(self): # C BufferedReader objects are collected. # The Python version has __del__, so it ends into gc.garbage instead - rawio = self.FileIO(support.TESTFN, "w+b") - f = self.tp(rawio) - f.f = f - wr = weakref.ref(f) - del f - support.gc_collect() + with support.check_warnings(('', ResourceWarning)): + rawio = self.FileIO(support.TESTFN, "w+b") + f = self.tp(rawio) + f.f = f + wr = weakref.ref(f) + del f + support.gc_collect() self.assertTrue(wr() is None, wr) def test_args_error(self): @@ -1327,9 +1337,9 @@ class BufferedWriterTest(unittest.TestCase, CommonBufferedTests): def test_misbehaved_io(self): rawio = self.MisbehavedRawIO() bufio = self.tp(rawio, 5) - self.assertRaises(IOError, bufio.seek, 0) - self.assertRaises(IOError, bufio.tell) - self.assertRaises(IOError, bufio.write, b"abcdef") + self.assertRaises(OSError, bufio.seek, 0) + self.assertRaises(OSError, bufio.tell) + self.assertRaises(OSError, bufio.write, b"abcdef") def test_max_buffer_size_removal(self): with self.assertRaises(TypeError): @@ -1338,11 +1348,11 @@ class BufferedWriterTest(unittest.TestCase, CommonBufferedTests): def test_write_error_on_close(self): raw = self.MockRawIO() def bad_write(b): - raise IOError() + raise OSError() raw.write = bad_write b = self.tp(raw) b.write(b'spam') - self.assertRaises(IOError, b.close) # exception not swallowed + self.assertRaises(OSError, b.close) # exception not swallowed self.assertTrue(b.closed) @@ -1373,13 +1383,14 @@ class CBufferedWriterTest(BufferedWriterTest, SizeofTest): # C BufferedWriter objects are collected, and collecting them flushes # all data to disk. # The Python version has __del__, so it ends into gc.garbage instead - rawio = self.FileIO(support.TESTFN, "w+b") - f = self.tp(rawio) - f.write(b"123xxx") - f.x = f - wr = weakref.ref(f) - del f - support.gc_collect() + with support.check_warnings(('', ResourceWarning)): + rawio = self.FileIO(support.TESTFN, "w+b") + f = self.tp(rawio) + f.write(b"123xxx") + f.x = f + wr = weakref.ref(f) + del f + support.gc_collect() self.assertTrue(wr() is None, wr) with self.open(support.TESTFN, "rb") as f: self.assertEqual(f.read(), b"123xxx") @@ -1426,14 +1437,14 @@ class BufferedRWPairTest(unittest.TestCase): def readable(self): return False - self.assertRaises(IOError, self.tp, NotReadable(), self.MockRawIO()) + self.assertRaises(OSError, self.tp, NotReadable(), self.MockRawIO()) def test_constructor_with_not_writeable(self): class NotWriteable(MockRawIO): def writable(self): return False - self.assertRaises(IOError, self.tp, self.MockRawIO(), NotWriteable()) + self.assertRaises(OSError, self.tp, self.MockRawIO(), NotWriteable()) def test_read(self): pair = self.tp(self.BytesIO(b"abcdef"), self.MockRawIO()) @@ -1955,6 +1966,15 @@ class TextIOWrapperTest(unittest.TestCase): self.assertRaises(TypeError, t.__init__, b, newline=42) self.assertRaises(ValueError, t.__init__, b, newline='xyzzy') + def test_non_text_encoding_codecs_are_rejected(self): + # Ensure the constructor complains if passed a codec that isn't + # marked as a text encoding + # http://bugs.python.org/issue20404 + r = self.BytesIO() + b = self.BufferedWriter(r) + with self.assertRaisesRegex(LookupError, "is not a text encoding"): + self.TextIOWrapper(b, encoding="hex") + def test_detach(self): r = self.BytesIO() b = self.BufferedWriter(r) @@ -2200,7 +2220,7 @@ class TextIOWrapperTest(unittest.TestCase): if s: # The destructor *may* have printed an unraisable error, check it self.assertEqual(len(s.splitlines()), 1) - self.assertTrue(s.startswith("Exception IOError: "), s) + self.assertTrue(s.startswith("Exception OSError: "), s) self.assertTrue(s.endswith(" ignored"), s) # Systematic tests of the text I/O API @@ -2272,7 +2292,7 @@ class TextIOWrapperTest(unittest.TestCase): f.seek(0) for line in f: self.assertEqual(line, "\xff\n") - self.assertRaises(IOError, f.tell) + self.assertRaises(OSError, f.tell) self.assertEqual(f.tell(), p2) f.close() @@ -2376,7 +2396,7 @@ class TextIOWrapperTest(unittest.TestCase): def readable(self): return False txt = self.TextIOWrapper(UnReadable()) - self.assertRaises(IOError, txt.read) + self.assertRaises(OSError, txt.read) def test_read_one_by_one(self): txt = self.TextIOWrapper(self.BytesIO(b"AA\r\nBB")) @@ -2551,9 +2571,9 @@ class TextIOWrapperTest(unittest.TestCase): def test_flush_error_on_close(self): txt = self.TextIOWrapper(self.BytesIO(self.testdata), encoding="ascii") def bad_flush(): - raise IOError() + raise OSError() txt.flush = bad_flush - self.assertRaises(IOError, txt.close) # exception not swallowed + self.assertRaises(OSError, txt.close) # exception not swallowed self.assertTrue(txt.closed) def test_multi_close(self): @@ -2607,19 +2627,64 @@ class TextIOWrapperTest(unittest.TestCase): def test_illegal_decoder(self): # Issue #17106 + # Bypass the early encoding check added in issue 20404 + def _make_illegal_wrapper(): + quopri = codecs.lookup("quopri") + quopri._is_text_encoding = True + try: + t = self.TextIOWrapper(self.BytesIO(b'aaaaaa'), + newline='\n', encoding="quopri") + finally: + quopri._is_text_encoding = False + return t # Crash when decoder returns non-string - t = self.TextIOWrapper(self.BytesIO(b'aaaaaa'), newline='\n', - encoding='quopri_codec') + t = _make_illegal_wrapper() self.assertRaises(TypeError, t.read, 1) - t = self.TextIOWrapper(self.BytesIO(b'aaaaaa'), newline='\n', - encoding='quopri_codec') + t = _make_illegal_wrapper() self.assertRaises(TypeError, t.readline) - t = self.TextIOWrapper(self.BytesIO(b'aaaaaa'), newline='\n', - encoding='quopri_codec') + t = _make_illegal_wrapper() self.assertRaises(TypeError, t.read) + def _check_create_at_shutdown(self, **kwargs): + # Issue #20037: creating a TextIOWrapper at shutdown + # shouldn't crash the interpreter. + iomod = self.io.__name__ + code = """if 1: + import codecs + import {iomod} as io + + # Avoid looking up codecs at shutdown + codecs.lookup('utf-8') + + class C: + def __init__(self): + self.buf = io.BytesIO() + def __del__(self): + io.TextIOWrapper(self.buf, **{kwargs}) + print("ok") + c = C() + """.format(iomod=iomod, kwargs=kwargs) + return assert_python_ok("-c", code) + + def test_create_at_shutdown_without_encoding(self): + rc, out, err = self._check_create_at_shutdown() + if err: + # Can error out with a RuntimeError if the module state + # isn't found. + self.assertIn(self.shutdown_error, err.decode()) + else: + self.assertEqual("ok", out.decode().strip()) + + def test_create_at_shutdown_with_encoding(self): + rc, out, err = self._check_create_at_shutdown(encoding='utf-8', + errors='strict') + self.assertFalse(err) + self.assertEqual("ok", out.decode().strip()) + class CTextIOWrapperTest(TextIOWrapperTest): + io = io + shutdown_error = "RuntimeError: could not find io module state" def test_initialization(self): r = self.BytesIO(b"\xc3\xa9\n\n") @@ -2634,14 +2699,15 @@ class CTextIOWrapperTest(TextIOWrapperTest): # C TextIOWrapper objects are collected, and collecting them flushes # all data to disk. # The Python version has __del__, so it ends in gc.garbage instead. - rawio = io.FileIO(support.TESTFN, "wb") - b = self.BufferedWriter(rawio) - t = self.TextIOWrapper(b, encoding="ascii") - t.write("456def") - t.x = t - wr = weakref.ref(t) - del t - support.gc_collect() + with support.check_warnings(('', ResourceWarning)): + rawio = io.FileIO(support.TESTFN, "wb") + b = self.BufferedWriter(rawio) + t = self.TextIOWrapper(b, encoding="ascii") + t.write("456def") + t.x = t + wr = weakref.ref(t) + del t + support.gc_collect() self.assertTrue(wr() is None, wr) with self.open(support.TESTFN, "rb") as f: self.assertEqual(f.read(), b"456def") @@ -2662,7 +2728,9 @@ class CTextIOWrapperTest(TextIOWrapperTest): class PyTextIOWrapperTest(TextIOWrapperTest): - pass + io = pyio + #shutdown_error = "LookupError: unknown encoding: ascii" + shutdown_error = "TypeError: 'NoneType' object is not iterable" class IncrementalNewlineDecoderTest(unittest.TestCase): @@ -2801,7 +2869,8 @@ class MiscIOTest(unittest.TestCase): self.assertEqual(f.mode, "wb") f.close() - f = self.open(support.TESTFN, "U") + with support.check_warnings(('', DeprecationWarning)): + f = self.open(support.TESTFN, "U") self.assertEqual(f.name, support.TESTFN) self.assertEqual(f.buffer.name, support.TESTFN) self.assertEqual(f.buffer.raw.name, support.TESTFN) @@ -2931,7 +3000,7 @@ class MiscIOTest(unittest.TestCase): for fd in fds: try: os.close(fd) - except EnvironmentError as e: + except OSError as e: if e.errno != errno.EBADF: raise self.addCleanup(cleanup_fds) @@ -3097,15 +3166,18 @@ class SignalsTest(unittest.TestCase): try: wio = self.io.open(w, **fdopen_kwargs) t.start() - signal.alarm(1) # Fill the pipe enough that the write will be blocking. # It will be interrupted by the timer armed above. Since the # other thread has read one byte, the low-level write will # return with a successful (partial) result rather than an EINTR. # The buffered IO layer must check for pending signal # handlers, which in this case will invoke alarm_interrupt(). - self.assertRaises(ZeroDivisionError, - wio.write, item * (support.PIPE_MAX_SIZE // len(item) + 1)) + signal.alarm(1) + try: + self.assertRaises(ZeroDivisionError, + wio.write, item * (support.PIPE_MAX_SIZE // len(item) + 1)) + finally: + signal.alarm(0) t.join() # We got one byte, get another one and check that it isn't a # repeat of the first one. @@ -3119,7 +3191,7 @@ class SignalsTest(unittest.TestCase): # buffer, and block again. try: wio.close() - except IOError as e: + except OSError as e: if e.errno != errno.EBADF: raise @@ -3247,7 +3319,7 @@ class SignalsTest(unittest.TestCase): # buffer, and could block (in case of failure). try: wio.close() - except IOError as e: + except OSError as e: if e.errno != errno.EBADF: raise diff --git a/Lib/test/test_ioctl.py b/Lib/test/test_ioctl.py index 531c9afbb5..efe9f516cc 100644 --- a/Lib/test/test_ioctl.py +++ b/Lib/test/test_ioctl.py @@ -8,7 +8,7 @@ get_attribute(termios, 'TIOCGPGRP') #Can't run tests without this feature try: tty = open("/dev/tty", "rb") -except IOError: +except OSError: raise unittest.SkipTest("Unable to open /dev/tty") else: # Skip if another process is in foreground @@ -86,8 +86,6 @@ class IoctlTests(unittest.TestCase): os.close(mfd) os.close(sfd) -def test_main(): - run_unittest(IoctlTests) if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_ipaddress.py b/Lib/test/test_ipaddress.py index 5bfcf41429..f2947b9c8a 100644 --- a/Lib/test/test_ipaddress.py +++ b/Lib/test/test_ipaddress.py @@ -1358,6 +1358,14 @@ class IpaddrUnitTest(unittest.TestCase): self.assertEqual(True, ipaddress.ip_network( '127.42.0.0/16').is_loopback) self.assertEqual(False, ipaddress.ip_network('128.0.0.0').is_loopback) + self.assertEqual(False, + ipaddress.ip_network('100.64.0.0/10').is_private) + self.assertEqual(False, ipaddress.ip_network('100.64.0.0/10').is_global) + + self.assertEqual(True, + ipaddress.ip_network('192.0.2.128/25').is_private) + self.assertEqual(True, + ipaddress.ip_network('192.0.3.0/24').is_global) # test addresses self.assertEqual(True, ipaddress.ip_address('0.0.0.0').is_unspecified) @@ -1423,6 +1431,10 @@ class IpaddrUnitTest(unittest.TestCase): self.assertEqual(False, ipaddress.ip_network('::1').is_unspecified) self.assertEqual(False, ipaddress.ip_network('::/127').is_unspecified) + self.assertEqual(True, + ipaddress.ip_network('2001::1/128').is_private) + self.assertEqual(True, + ipaddress.ip_network('200::1/128').is_global) # test addresses self.assertEqual(True, ipaddress.ip_address('ffff::').is_multicast) self.assertEqual(True, ipaddress.ip_address(2**128 - 1).is_multicast) diff --git a/Lib/test/test_iterlen.py b/Lib/test/test_iterlen.py index 9101f6c884..152f5fc0cb 100644 --- a/Lib/test/test_iterlen.py +++ b/Lib/test/test_iterlen.py @@ -45,31 +45,21 @@ import unittest from test import support from itertools import repeat from collections import deque -from builtins import len as _len +from operator import length_hint n = 10 -def len(obj): - try: - return _len(obj) - except TypeError: - try: - # note: this is an internal undocumented API, - # don't rely on it in your own programs - return obj.__length_hint__() - except AttributeError: - raise TypeError class TestInvariantWithoutMutations: def test_invariant(self): it = self.it for i in reversed(range(1, n+1)): - self.assertEqual(len(it), i) + self.assertEqual(length_hint(it), i) next(it) - self.assertEqual(len(it), 0) + self.assertEqual(length_hint(it), 0) self.assertRaises(StopIteration, next, it) - self.assertEqual(len(it), 0) + self.assertEqual(length_hint(it), 0) class TestTemporarilyImmutable(TestInvariantWithoutMutations): @@ -78,12 +68,12 @@ class TestTemporarilyImmutable(TestInvariantWithoutMutations): # length immutability during iteration it = self.it - self.assertEqual(len(it), n) + self.assertEqual(length_hint(it), n) next(it) - self.assertEqual(len(it), n-1) + self.assertEqual(length_hint(it), n-1) self.mutate() self.assertRaises(RuntimeError, next, it) - self.assertEqual(len(it), 0) + self.assertEqual(length_hint(it), 0) ## ------- Concrete Type Tests ------- @@ -92,10 +82,6 @@ class TestRepeat(TestInvariantWithoutMutations, unittest.TestCase): def setUp(self): self.it = repeat(None, n) - def test_no_len_for_infinite_repeat(self): - # The repeat() object can also be infinite - self.assertRaises(TypeError, len, repeat(None)) - class TestXrange(TestInvariantWithoutMutations, unittest.TestCase): def setUp(self): @@ -167,14 +153,15 @@ class TestList(TestInvariantWithoutMutations, unittest.TestCase): it = iter(d) next(it) next(it) - self.assertEqual(len(it), n-2) + self.assertEqual(length_hint(it), n - 2) d.append(n) - self.assertEqual(len(it), n-1) # grow with append + self.assertEqual(length_hint(it), n - 1) # grow with append d[1:] = [] - self.assertEqual(len(it), 0) + self.assertEqual(length_hint(it), 0) self.assertEqual(list(it), []) d.extend(range(20)) - self.assertEqual(len(it), 0) + self.assertEqual(length_hint(it), 0) + class TestListReversed(TestInvariantWithoutMutations, unittest.TestCase): @@ -186,32 +173,41 @@ class TestListReversed(TestInvariantWithoutMutations, unittest.TestCase): it = reversed(d) next(it) next(it) - self.assertEqual(len(it), n-2) + self.assertEqual(length_hint(it), n - 2) d.append(n) - self.assertEqual(len(it), n-2) # ignore append + self.assertEqual(length_hint(it), n - 2) # ignore append d[1:] = [] - self.assertEqual(len(it), 0) + self.assertEqual(length_hint(it), 0) self.assertEqual(list(it), []) # confirm invariant d.extend(range(20)) - self.assertEqual(len(it), 0) + self.assertEqual(length_hint(it), 0) ## -- Check to make sure exceptions are not suppressed by __length_hint__() class BadLen(object): - def __iter__(self): return iter(range(10)) + def __iter__(self): + return iter(range(10)) + def __len__(self): raise RuntimeError('hello') + class BadLengthHint(object): - def __iter__(self): return iter(range(10)) + def __iter__(self): + return iter(range(10)) + def __length_hint__(self): raise RuntimeError('hello') + class NoneLengthHint(object): - def __iter__(self): return iter(range(10)) + def __iter__(self): + return iter(range(10)) + def __length_hint__(self): - return None + return NotImplemented + class TestLengthHintExceptions(unittest.TestCase): diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py index 514a6b7a20..21f1bdeb82 100644 --- a/Lib/test/test_itertools.py +++ b/Lib/test/test_itertools.py @@ -10,6 +10,8 @@ import random import copy import pickle from functools import reduce +import sys +import struct maxsize = support.MAX_Py_ssize_t minsize = -maxsize-1 @@ -1729,9 +1731,8 @@ class TestVariousIteratorArgs(unittest.TestCase): class LengthTransparency(unittest.TestCase): def test_repeat(self): - from test.test_iterlen import len - self.assertEqual(len(repeat(None, 50)), 50) - self.assertRaises(TypeError, len, repeat(None)) + self.assertEqual(operator.length_hint(repeat(None, 50)), 50) + self.assertEqual(operator.length_hint(repeat(None), 12), 12) class RegressionTests(unittest.TestCase): @@ -1808,6 +1809,44 @@ class SubclassWithKwargsTest(unittest.TestCase): # we expect type errors because of wrong argument count self.assertNotIn("does not take keyword arguments", err.args[0]) +@support.cpython_only +class SizeofTest(unittest.TestCase): + def setUp(self): + self.ssize_t = struct.calcsize('n') + + check_sizeof = support.check_sizeof + + def test_product_sizeof(self): + basesize = support.calcobjsize('3Pi') + check = self.check_sizeof + check(product('ab', '12'), basesize + 2 * self.ssize_t) + check(product(*(('abc',) * 10)), basesize + 10 * self.ssize_t) + + def test_combinations_sizeof(self): + basesize = support.calcobjsize('3Pni') + check = self.check_sizeof + check(combinations('abcd', 3), basesize + 3 * self.ssize_t) + check(combinations(range(10), 4), basesize + 4 * self.ssize_t) + + def test_combinations_with_replacement_sizeof(self): + cwr = combinations_with_replacement + basesize = support.calcobjsize('3Pni') + check = self.check_sizeof + check(cwr('abcd', 3), basesize + 3 * self.ssize_t) + check(cwr(range(10), 4), basesize + 4 * self.ssize_t) + + def test_permutations_sizeof(self): + basesize = support.calcobjsize('4Pni') + check = self.check_sizeof + check(permutations('abcd'), + basesize + 4 * self.ssize_t + 4 * self.ssize_t) + check(permutations('abcd', 3), + basesize + 4 * self.ssize_t + 3 * self.ssize_t) + check(permutations('abcde', 3), + basesize + 5 * self.ssize_t + 3 * self.ssize_t) + check(permutations(range(10), 4), + basesize + 10 * self.ssize_t + 4 * self.ssize_t) + libreftest = """ Doctest for examples in the library reference: libitertools.tex @@ -2043,7 +2082,8 @@ __test__ = {'libreftest' : libreftest} def test_main(verbose=None): test_classes = (TestBasicOps, TestVariousIteratorArgs, TestGC, RegressionTests, LengthTransparency, - SubclassWithKwargsTest, TestExamples) + SubclassWithKwargsTest, TestExamples, + SizeofTest) support.run_unittest(*test_classes) # verify reference counting diff --git a/Lib/test/test_json/test_decode.py b/Lib/test/test_json/test_decode.py index d23e285909..35c02de88c 100644 --- a/Lib/test/test_json/test_decode.py +++ b/Lib/test/test_json/test_decode.py @@ -1,5 +1,5 @@ import decimal -from io import StringIO +from io import StringIO, BytesIO from collections import OrderedDict from test.test_json import PyTest, CTest @@ -70,5 +70,26 @@ class TestDecode: msg = 'escape' self.assertRaisesRegex(ValueError, msg, self.loads, s) + def test_invalid_input_type(self): + msg = 'the JSON object must be str' + for value in [1, 3.14, b'bytes', b'\xff\x00', [], {}, None]: + self.assertRaisesRegex(TypeError, msg, self.loads, value) + with self.assertRaisesRegex(TypeError, msg): + self.json.load(BytesIO(b'[1,2,3]')) + + def test_string_with_utf8_bom(self): + # see #18958 + bom_json = "[1,2,3]".encode('utf-8-sig').decode('utf-8') + with self.assertRaises(ValueError) as cm: + self.loads(bom_json) + self.assertIn('BOM', str(cm.exception)) + with self.assertRaises(ValueError) as cm: + self.json.load(StringIO(bom_json)) + self.assertIn('BOM', str(cm.exception)) + # make sure that the BOM is not detected in the middle of a string + bom_in_str = '"{}"'.format(''.encode('utf-8-sig').decode('utf-8')) + self.assertEqual(self.loads(bom_in_str), '\ufeff') + self.assertEqual(self.json.load(StringIO(bom_in_str)), '\ufeff') + class TestPyDecode(TestDecode, PyTest): pass class TestCDecode(TestDecode, CTest): pass diff --git a/Lib/test/test_json/test_enum.py b/Lib/test/test_json/test_enum.py new file mode 100644 index 0000000000..10f414898b --- /dev/null +++ b/Lib/test/test_json/test_enum.py @@ -0,0 +1,120 @@ +from enum import Enum, IntEnum +from math import isnan +from test.test_json import PyTest, CTest + +SMALL = 1 +BIG = 1<<32 +HUGE = 1<<64 +REALLY_HUGE = 1<<96 + +class BigNum(IntEnum): + small = SMALL + big = BIG + huge = HUGE + really_huge = REALLY_HUGE + +E = 2.718281 +PI = 3.141593 +TAU = 2 * PI + +class FloatNum(float, Enum): + e = E + pi = PI + tau = TAU + +INF = float('inf') +NEG_INF = float('-inf') +NAN = float('nan') + +class WierdNum(float, Enum): + inf = INF + neg_inf = NEG_INF + nan = NAN + +class TestEnum: + + def test_floats(self): + for enum in FloatNum: + self.assertEqual(self.dumps(enum), repr(enum.value)) + self.assertEqual(float(self.dumps(enum)), enum) + self.assertEqual(self.loads(self.dumps(enum)), enum) + + def test_weird_floats(self): + for enum, expected in zip(WierdNum, ('Infinity', '-Infinity', 'NaN')): + self.assertEqual(self.dumps(enum), expected) + if not isnan(enum): + self.assertEqual(float(self.dumps(enum)), enum) + self.assertEqual(self.loads(self.dumps(enum)), enum) + else: + self.assertTrue(isnan(float(self.dumps(enum)))) + self.assertTrue(isnan(self.loads(self.dumps(enum)))) + + def test_ints(self): + for enum in BigNum: + self.assertEqual(self.dumps(enum), str(enum.value)) + self.assertEqual(int(self.dumps(enum)), enum) + self.assertEqual(self.loads(self.dumps(enum)), enum) + + def test_list(self): + self.assertEqual(self.dumps(list(BigNum)), + str([SMALL, BIG, HUGE, REALLY_HUGE])) + self.assertEqual(self.loads(self.dumps(list(BigNum))), + list(BigNum)) + self.assertEqual(self.dumps(list(FloatNum)), + str([E, PI, TAU])) + self.assertEqual(self.loads(self.dumps(list(FloatNum))), + list(FloatNum)) + self.assertEqual(self.dumps(list(WierdNum)), + '[Infinity, -Infinity, NaN]') + self.assertEqual(self.loads(self.dumps(list(WierdNum)))[:2], + list(WierdNum)[:2]) + self.assertTrue(isnan(self.loads(self.dumps(list(WierdNum)))[2])) + + def test_dict_keys(self): + s, b, h, r = BigNum + e, p, t = FloatNum + i, j, n = WierdNum + d = { + s:'tiny', b:'large', h:'larger', r:'largest', + e:"Euler's number", p:'pi', t:'tau', + i:'Infinity', j:'-Infinity', n:'NaN', + } + nd = self.loads(self.dumps(d)) + self.assertEqual(nd[str(SMALL)], 'tiny') + self.assertEqual(nd[str(BIG)], 'large') + self.assertEqual(nd[str(HUGE)], 'larger') + self.assertEqual(nd[str(REALLY_HUGE)], 'largest') + self.assertEqual(nd[repr(E)], "Euler's number") + self.assertEqual(nd[repr(PI)], 'pi') + self.assertEqual(nd[repr(TAU)], 'tau') + self.assertEqual(nd['Infinity'], 'Infinity') + self.assertEqual(nd['-Infinity'], '-Infinity') + self.assertEqual(nd['NaN'], 'NaN') + + def test_dict_values(self): + d = dict( + tiny=BigNum.small, + large=BigNum.big, + larger=BigNum.huge, + largest=BigNum.really_huge, + e=FloatNum.e, + pi=FloatNum.pi, + tau=FloatNum.tau, + i=WierdNum.inf, + j=WierdNum.neg_inf, + n=WierdNum.nan, + ) + nd = self.loads(self.dumps(d)) + self.assertEqual(nd['tiny'], SMALL) + self.assertEqual(nd['large'], BIG) + self.assertEqual(nd['larger'], HUGE) + self.assertEqual(nd['largest'], REALLY_HUGE) + self.assertEqual(nd['e'], E) + self.assertEqual(nd['pi'], PI) + self.assertEqual(nd['tau'], TAU) + self.assertEqual(nd['i'], INF) + self.assertEqual(nd['j'], NEG_INF) + self.assertTrue(isnan(nd['n'])) + +class TestPyEnum(TestEnum, PyTest): pass +class TestCEnum(TestEnum, CTest): pass diff --git a/Lib/test/test_json/test_fail.py b/Lib/test/test_json/test_fail.py index 3dd877afdb..7caafdbddd 100644 --- a/Lib/test/test_json/test_fail.py +++ b/Lib/test/test_json/test_fail.py @@ -1,4 +1,5 @@ from test.test_json import PyTest, CTest +import re # 2007-10-05 JSONDOCS = [ @@ -100,6 +101,94 @@ class TestFail: #This is for python encoder self.assertRaises(TypeError, self.dumps, data, indent=True) + def test_truncated_input(self): + test_cases = [ + ('', 'Expecting value', 0), + ('[', 'Expecting value', 1), + ('[42', "Expecting ',' delimiter", 3), + ('[42,', 'Expecting value', 4), + ('["', 'Unterminated string starting at', 1), + ('["spam', 'Unterminated string starting at', 1), + ('["spam"', "Expecting ',' delimiter", 7), + ('["spam",', 'Expecting value', 8), + ('{', 'Expecting property name enclosed in double quotes', 1), + ('{"', 'Unterminated string starting at', 1), + ('{"spam', 'Unterminated string starting at', 1), + ('{"spam"', "Expecting ':' delimiter", 7), + ('{"spam":', 'Expecting value', 8), + ('{"spam":42', "Expecting ',' delimiter", 10), + ('{"spam":42,', 'Expecting property name enclosed in double quotes', 11), + ] + test_cases += [ + ('"', 'Unterminated string starting at', 0), + ('"spam', 'Unterminated string starting at', 0), + ] + for data, msg, idx in test_cases: + self.assertRaisesRegex(ValueError, + r'^{0}: line 1 column {1} \(char {2}\)'.format( + re.escape(msg), idx + 1, idx), + self.loads, data) + + def test_unexpected_data(self): + test_cases = [ + ('[,', 'Expecting value', 1), + ('{"spam":[}', 'Expecting value', 9), + ('[42:', "Expecting ',' delimiter", 3), + ('[42 "spam"', "Expecting ',' delimiter", 4), + ('[42,]', 'Expecting value', 4), + ('{"spam":[42}', "Expecting ',' delimiter", 11), + ('["]', 'Unterminated string starting at', 1), + ('["spam":', "Expecting ',' delimiter", 7), + ('["spam",]', 'Expecting value', 8), + ('{:', 'Expecting property name enclosed in double quotes', 1), + ('{,', 'Expecting property name enclosed in double quotes', 1), + ('{42', 'Expecting property name enclosed in double quotes', 1), + ('[{]', 'Expecting property name enclosed in double quotes', 2), + ('{"spam",', "Expecting ':' delimiter", 7), + ('{"spam"}', "Expecting ':' delimiter", 7), + ('[{"spam"]', "Expecting ':' delimiter", 8), + ('{"spam":}', 'Expecting value', 8), + ('[{"spam":]', 'Expecting value', 9), + ('{"spam":42 "ham"', "Expecting ',' delimiter", 11), + ('[{"spam":42]', "Expecting ',' delimiter", 11), + ('{"spam":42,}', 'Expecting property name enclosed in double quotes', 11), + ] + for data, msg, idx in test_cases: + self.assertRaisesRegex(ValueError, + r'^{0}: line 1 column {1} \(char {2}\)'.format( + re.escape(msg), idx + 1, idx), + self.loads, data) + + def test_extra_data(self): + test_cases = [ + ('[]]', 'Extra data', 2), + ('{}}', 'Extra data', 2), + ('[],[]', 'Extra data', 2), + ('{},{}', 'Extra data', 2), + ] + test_cases += [ + ('42,"spam"', 'Extra data', 2), + ('"spam",42', 'Extra data', 6), + ] + for data, msg, idx in test_cases: + self.assertRaisesRegex(ValueError, + r'^{0}: line 1 column {1} - line 1 column {2}' + r' \(char {3} - {4}\)'.format( + re.escape(msg), idx + 1, len(data) + 1, idx, len(data)), + self.loads, data) + + def test_linecol(self): + test_cases = [ + ('!', 1, 1, 0), + (' !', 1, 2, 1), + ('\n!', 2, 1, 1), + ('\n \n\n !', 4, 6, 10), + ] + for data, line, col, idx in test_cases: + self.assertRaisesRegex(ValueError, + r'^Expecting value: line {0} column {1}' + r' \(char {2}\)$'.format(line, col, idx), + self.loads, data) class TestPyFail(TestFail, PyTest): pass class TestCFail(TestFail, CTest): pass diff --git a/Lib/test/test_json/test_indent.py b/Lib/test/test_json/test_indent.py index a4d4d20142..e07856f33c 100644 --- a/Lib/test/test_json/test_indent.py +++ b/Lib/test/test_json/test_indent.py @@ -32,6 +32,8 @@ class TestIndent: d1 = self.dumps(h) d2 = self.dumps(h, indent=2, sort_keys=True, separators=(',', ': ')) d3 = self.dumps(h, indent='\t', sort_keys=True, separators=(',', ': ')) + d4 = self.dumps(h, indent=2, sort_keys=True) + d5 = self.dumps(h, indent='\t', sort_keys=True) h1 = self.loads(d1) h2 = self.loads(d2) @@ -42,6 +44,8 @@ class TestIndent: self.assertEqual(h3, h) self.assertEqual(d2, expect.expandtabs(2)) self.assertEqual(d3, expect) + self.assertEqual(d4, d2) + self.assertEqual(d5, d3) def test_indent0(self): h = {3: 1} diff --git a/Lib/test/test_keyword.py b/Lib/test/test_keyword.py new file mode 100644 index 0000000000..af99f52c63 --- /dev/null +++ b/Lib/test/test_keyword.py @@ -0,0 +1,138 @@ +import keyword +import unittest +from test import support +import filecmp +import os +import sys +import subprocess +import shutil +import textwrap + +KEYWORD_FILE = support.findfile('keyword.py') +GRAMMAR_FILE = os.path.join(os.path.split(__file__)[0], + '..', '..', 'Python', 'graminit.c') +TEST_PY_FILE = 'keyword_test.py' +GRAMMAR_TEST_FILE = 'graminit_test.c' +PY_FILE_WITHOUT_KEYWORDS = 'minimal_keyword.py' +NONEXISTENT_FILE = 'not_here.txt' + + +class Test_iskeyword(unittest.TestCase): + def test_true_is_a_keyword(self): + self.assertTrue(keyword.iskeyword('True')) + + def test_uppercase_true_is_not_a_keyword(self): + self.assertFalse(keyword.iskeyword('TRUE')) + + def test_none_value_is_not_a_keyword(self): + self.assertFalse(keyword.iskeyword(None)) + + # This is probably an accident of the current implementation, but should be + # preserved for backward compatibility. + def test_changing_the_kwlist_does_not_affect_iskeyword(self): + oldlist = keyword.kwlist + self.addCleanup(setattr, keyword, 'kwlist', oldlist) + keyword.kwlist = ['its', 'all', 'eggs', 'beans', 'and', 'a', 'slice'] + self.assertFalse(keyword.iskeyword('eggs')) + + +class TestKeywordGeneration(unittest.TestCase): + + def _copy_file_without_generated_keywords(self, source_file, dest_file): + with open(source_file, 'rb') as fp: + lines = fp.readlines() + nl = lines[0][len(lines[0].strip()):] + with open(dest_file, 'wb') as fp: + fp.writelines(lines[:lines.index(b"#--start keywords--" + nl) + 1]) + fp.writelines(lines[lines.index(b"#--end keywords--" + nl):]) + + def _generate_keywords(self, grammar_file, target_keyword_py_file): + proc = subprocess.Popen([sys.executable, + KEYWORD_FILE, + grammar_file, + target_keyword_py_file], stderr=subprocess.PIPE) + stderr = proc.communicate()[1] + return proc.returncode, stderr + + @unittest.skipIf(not os.path.exists(GRAMMAR_FILE), + 'test only works from source build directory') + def test_real_grammar_and_keyword_file(self): + self._copy_file_without_generated_keywords(KEYWORD_FILE, TEST_PY_FILE) + self.addCleanup(support.unlink, TEST_PY_FILE) + self.assertFalse(filecmp.cmp(KEYWORD_FILE, TEST_PY_FILE)) + self.assertEqual((0, b''), self._generate_keywords(GRAMMAR_FILE, + TEST_PY_FILE)) + self.assertTrue(filecmp.cmp(KEYWORD_FILE, TEST_PY_FILE)) + + def test_grammar(self): + self._copy_file_without_generated_keywords(KEYWORD_FILE, TEST_PY_FILE) + self.addCleanup(support.unlink, TEST_PY_FILE) + with open(GRAMMAR_TEST_FILE, 'w') as fp: + # Some of these are probably implementation accidents. + fp.writelines(textwrap.dedent("""\ + {2, 1}, + {11, "encoding_decl", 0, 2, states_79, + "\000\000\040\000\000\000\000\000\000\000\000\000" + "\000\000\000\000\000\000\000\000\000"}, + {1, "jello"}, + {326, 0}, + {1, "turnip"}, + \t{1, "This one is tab indented" + {278, 0}, + {1, "crazy but legal" + "also legal" {1, " + {1, "continue"}, + {1, "lemon"}, + {1, "tomato"}, + {1, "wigii"}, + {1, 'no good'} + {283, 0}, + {1, "too many spaces"}""")) + self.addCleanup(support.unlink, GRAMMAR_TEST_FILE) + self._generate_keywords(GRAMMAR_TEST_FILE, TEST_PY_FILE) + expected = [ + " 'This one is tab indented',", + " 'also legal',", + " 'continue',", + " 'crazy but legal',", + " 'jello',", + " 'lemon',", + " 'tomato',", + " 'turnip',", + " 'wigii',", + ] + with open(TEST_PY_FILE) as fp: + lines = fp.read().splitlines() + start = lines.index("#--start keywords--") + 1 + end = lines.index("#--end keywords--") + actual = lines[start:end] + self.assertEqual(actual, expected) + + def test_empty_grammar_results_in_no_keywords(self): + self._copy_file_without_generated_keywords(KEYWORD_FILE, + PY_FILE_WITHOUT_KEYWORDS) + self.addCleanup(support.unlink, PY_FILE_WITHOUT_KEYWORDS) + shutil.copyfile(KEYWORD_FILE, TEST_PY_FILE) + self.addCleanup(support.unlink, TEST_PY_FILE) + self.assertEqual((0, b''), self._generate_keywords(os.devnull, + TEST_PY_FILE)) + self.assertTrue(filecmp.cmp(TEST_PY_FILE, PY_FILE_WITHOUT_KEYWORDS)) + + def test_keywords_py_without_markers_produces_error(self): + rc, stderr = self._generate_keywords(os.devnull, os.devnull) + self.assertNotEqual(rc, 0) + self.assertRegex(stderr, b'does not contain format markers') + + def test_missing_grammar_file_produces_error(self): + rc, stderr = self._generate_keywords(NONEXISTENT_FILE, KEYWORD_FILE) + self.assertNotEqual(rc, 0) + self.assertRegex(stderr, b'(?ms)' + NONEXISTENT_FILE.encode()) + + def test_missing_keywords_py_file_produces_error(self): + rc, stderr = self._generate_keywords(os.devnull, NONEXISTENT_FILE) + self.assertNotEqual(rc, 0) + self.assertRegex(stderr, b'(?ms)' + NONEXISTENT_FILE.encode()) + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_keywordonlyarg.py b/Lib/test/test_keywordonlyarg.py index 6c2ff00362..7f315d4cb4 100644 --- a/Lib/test/test_keywordonlyarg.py +++ b/Lib/test/test_keywordonlyarg.py @@ -174,6 +174,18 @@ class KeywordOnlyArgTestCase(unittest.TestCase): return __a self.assertEqual(X().f(), 42) + def test_default_evaluation_order(self): + # See issue 16967 + a = 42 + with self.assertRaises(NameError) as err: + def f(v=a, x=b, *, y=c, z=d): + pass + self.assertEqual(str(err.exception), "name 'b' is not defined") + with self.assertRaises(NameError) as err: + f = lambda v=a, x=b, *, y=c, z=d: None + self.assertEqual(str(err.exception), "name 'b' is not defined") + + def test_main(): run_unittest(KeywordOnlyArgTestCase) diff --git a/Lib/test/test_kqueue.py b/Lib/test/test_kqueue.py index e5e6058334..bafdeba6ff 100644 --- a/Lib/test/test_kqueue.py +++ b/Lib/test/test_kqueue.py @@ -94,7 +94,7 @@ class TestKQueue(unittest.TestCase): client.setblocking(False) try: client.connect(('127.0.0.1', serverSocket.getsockname()[1])) - except socket.error as e: + except OSError as e: self.assertEqual(e.args[0], errno.EINPROGRESS) else: #raise AssertionError("Connect should have raised EINPROGRESS") @@ -185,6 +185,33 @@ class TestKQueue(unittest.TestCase): b.close() kq.close() + def test_close(self): + open_file = open(__file__, "rb") + self.addCleanup(open_file.close) + fd = open_file.fileno() + kqueue = select.kqueue() + + # test fileno() method and closed attribute + self.assertIsInstance(kqueue.fileno(), int) + self.assertFalse(kqueue.closed) + + # test close() + kqueue.close() + self.assertTrue(kqueue.closed) + self.assertRaises(ValueError, kqueue.fileno) + + # close() can be called more than once + kqueue.close() + + # operations must fail with ValueError("I/O operation on closed ...") + self.assertRaises(ValueError, kqueue.control, None, 4) + + def test_fd_non_inheritable(self): + kqueue = select.kqueue() + self.addCleanup(kqueue.close) + self.assertEqual(os.get_inheritable(kqueue.fileno()), False) + + def test_main(): support.run_unittest(TestKQueue) diff --git a/Lib/test/test_largefile.py b/Lib/test/test_largefile.py index 63ee697250..5b276e76ff 100644 --- a/Lib/test/test_largefile.py +++ b/Lib/test/test_largefile.py @@ -159,7 +159,7 @@ def setUpModule(): # Seeking is not enough of a test: you must write and flush, too! f.write(b'x') f.flush() - except (IOError, OverflowError): + except (OSError, OverflowError): raise unittest.SkipTest("filesystem does not have " "largefile support") finally: diff --git a/Lib/test/test_locale.py b/Lib/test/test_locale.py index 76347b7746..ab990ae63f 100644 --- a/Lib/test/test_locale.py +++ b/Lib/test/test_locale.py @@ -371,7 +371,8 @@ class NormalizeTest(unittest.TestCase): def test_locale_alias(self): for localename, alias in locale.locale_alias.items(): - self.check(localename, alias) + with self.subTest(locale=(localename, alias)): + self.check(localename, alias) def test_empty(self): self.check('', '') @@ -383,6 +384,7 @@ class NormalizeTest(unittest.TestCase): def test_english(self): self.check('en', 'en_US.ISO8859-1') self.check('EN', 'en_US.ISO8859-1') + self.check('en.iso88591', 'en_US.ISO8859-1') self.check('en_US', 'en_US.ISO8859-1') self.check('en_us', 'en_US.ISO8859-1') self.check('en_GB', 'en_GB.ISO8859-1') @@ -391,7 +393,10 @@ class NormalizeTest(unittest.TestCase): self.check('en_US:UTF-8', 'en_US.UTF-8') self.check('en_US.ISO8859-1', 'en_US.ISO8859-1') self.check('en_US.US-ASCII', 'en_US.ISO8859-1') + self.check('en_US.88591', 'en_US.ISO8859-1') + self.check('en_US.885915', 'en_US.ISO8859-15') self.check('english', 'en_EN.ISO8859-1') + self.check('english_uk.ascii', 'en_GB.ISO8859-1') def test_hyphenated_encoding(self): self.check('az_AZ.iso88599e', 'az_AZ.ISO8859-9E') @@ -411,10 +416,12 @@ class NormalizeTest(unittest.TestCase): def test_euro_modifier(self): self.check('de_DE@euro', 'de_DE.ISO8859-15') self.check('en_US.ISO8859-15@euro', 'en_US.ISO8859-15') + self.check('de_DE.utf8@euro', 'de_DE.UTF-8') def test_latin_modifier(self): self.check('be_BY.UTF-8@latin', 'be_BY.UTF-8@latin') self.check('sr_RS.UTF-8@latin', 'sr_RS.UTF-8@latin') + self.check('sr_RS.UTF-8@latn', 'sr_RS.UTF-8@latin') def test_valencia_modifier(self): self.check('ca_ES.UTF-8@valencia', 'ca_ES.UTF-8@valencia') @@ -435,6 +442,39 @@ class NormalizeTest(unittest.TestCase): self.check('sd_IN', 'sd_IN.UTF-8') self.check('sd', 'sd_IN.UTF-8') + def test_euc_encoding(self): + self.check('ja_jp.euc', 'ja_JP.eucJP') + self.check('ja_jp.eucjp', 'ja_JP.eucJP') + self.check('ko_kr.euc', 'ko_KR.eucKR') + self.check('ko_kr.euckr', 'ko_KR.eucKR') + self.check('zh_cn.euc', 'zh_CN.eucCN') + self.check('zh_tw.euc', 'zh_TW.eucTW') + self.check('zh_tw.euctw', 'zh_TW.eucTW') + + def test_japanese(self): + self.check('ja', 'ja_JP.eucJP') + self.check('ja.jis', 'ja_JP.JIS7') + self.check('ja.sjis', 'ja_JP.SJIS') + self.check('ja_jp', 'ja_JP.eucJP') + self.check('ja_jp.ajec', 'ja_JP.eucJP') + self.check('ja_jp.euc', 'ja_JP.eucJP') + self.check('ja_jp.eucjp', 'ja_JP.eucJP') + self.check('ja_jp.iso-2022-jp', 'ja_JP.JIS7') + self.check('ja_jp.iso2022jp', 'ja_JP.JIS7') + self.check('ja_jp.jis', 'ja_JP.JIS7') + self.check('ja_jp.jis7', 'ja_JP.JIS7') + self.check('ja_jp.mscode', 'ja_JP.SJIS') + self.check('ja_jp.pck', 'ja_JP.SJIS') + self.check('ja_jp.sjis', 'ja_JP.SJIS') + self.check('ja_jp.ujis', 'ja_JP.eucJP') + self.check('ja_jp.utf8', 'ja_JP.UTF-8') + self.check('japan', 'ja_JP.eucJP') + self.check('japanese', 'ja_JP.eucJP') + self.check('japanese-euc', 'ja_JP.eucJP') + self.check('japanese.euc', 'ja_JP.eucJP') + self.check('japanese.sjis', 'ja_JP.SJIS') + self.check('jp_jp', 'ja_JP.eucJP') + class TestMiscellaneous(unittest.TestCase): def test_getpreferredencoding(self): diff --git a/Lib/test/test_logging.py b/Lib/test/test_logging.py index f34172a371..e72b036e10 100644 --- a/Lib/test/test_logging.py +++ b/Lib/test/test_logging.py @@ -24,6 +24,7 @@ import logging.handlers import logging.config import codecs +import configparser import datetime import pickle import io @@ -38,6 +39,7 @@ import socket import struct import sys import tempfile +from test.script_helper import assert_python_ok from test.support import (captured_stdout, run_with_locale, run_unittest, patch, requires_zlib, TestHandler, Matcher) import textwrap @@ -56,7 +58,9 @@ try: import smtpd from urllib.parse import urlparse, parse_qs from socketserver import (ThreadingUDPServer, DatagramRequestHandler, - ThreadingTCPServer, StreamRequestHandler) + ThreadingTCPServer, StreamRequestHandler, + ThreadingUnixStreamServer, + ThreadingUnixDatagramServer) except ImportError: threading = None try: @@ -73,13 +77,12 @@ try: except ImportError: pass - class BaseTest(unittest.TestCase): """Base class for logging tests.""" log_format = "%(name)s -> %(levelname)s: %(message)s" - expected_log_pat = r"^([\w.]+) -> ([\w]+): ([\d]+)$" + expected_log_pat = r"^([\w.]+) -> (\w+): (\d+)$" message_num = 0 def setUp(self): @@ -91,7 +94,8 @@ class BaseTest(unittest.TestCase): self.saved_handlers = logging._handlers.copy() self.saved_handler_list = logging._handlerList[:] self.saved_loggers = saved_loggers = logger_dict.copy() - self.saved_level_names = logging._levelNames.copy() + self.saved_name_to_level = logging._nameToLevel.copy() + self.saved_level_to_name = logging._levelToName.copy() self.logger_states = logger_states = {} for name in saved_loggers: logger_states[name] = getattr(saved_loggers[name], @@ -133,8 +137,10 @@ class BaseTest(unittest.TestCase): self.root_logger.setLevel(self.original_logging_level) logging._acquireLock() try: - logging._levelNames.clear() - logging._levelNames.update(self.saved_level_names) + logging._levelToName.clear() + logging._levelToName.update(self.saved_level_to_name) + logging._nameToLevel.clear() + logging._nameToLevel.update(self.saved_name_to_level) logging._handlers.clear() logging._handlers.update(self.saved_handlers) logging._handlerList[:] = self.saved_handler_list @@ -148,12 +154,12 @@ class BaseTest(unittest.TestCase): finally: logging._releaseLock() - def assert_log_lines(self, expected_values, stream=None): + def assert_log_lines(self, expected_values, stream=None, pat=None): """Match the collected log lines against the regular expression self.expected_log_pat, and compare the extracted group values to the expected_values list of tuples.""" stream = stream or self.stream - pat = re.compile(self.expected_log_pat) + pat = re.compile(pat or self.expected_log_pat) actual_lines = stream.getvalue().splitlines() self.assertEqual(len(actual_lines), len(expected_values)) for actual, expected in zip(actual_lines, expected_values): @@ -428,7 +434,7 @@ class CustomLevelsAndFiltersTest(BaseTest): """Test various filtering possibilities with custom logging levels.""" # Skip the logger name group. - expected_log_pat = r"^[\w.]+ -> ([\w]+): ([\d]+)$" + expected_log_pat = r"^[\w.]+ -> (\w+): (\d+)$" def setUp(self): BaseTest.setUp(self) @@ -558,7 +564,7 @@ class HandlerTest(BaseTest): self.assertEqual(h.facility, h.LOG_USER) self.assertTrue(h.unixsocket) h.close() - except socket.error: # syslogd might not be available + except OSError: # syslogd might not be available pass for method in ('GET', 'POST', 'PUT'): if method == 'PUT': @@ -583,6 +589,7 @@ class HandlerTest(BaseTest): for _ in range(tries): try: os.unlink(fname) + self.deletion_time = time.time() except OSError: pass time.sleep(0.004 * random.randint(0, 4)) @@ -590,6 +597,9 @@ class HandlerTest(BaseTest): del_count = 500 log_count = 500 + self.handle_time = None + self.deletion_time = None + for delay in (False, True): fd, fn = tempfile.mkstemp('.log', 'test_logging-3-') os.close(fd) @@ -603,7 +613,14 @@ class HandlerTest(BaseTest): for _ in range(log_count): time.sleep(0.005) r = logging.makeLogRecord({'msg': 'testing' }) - h.handle(r) + try: + self.handle_time = time.time() + h.handle(r) + except Exception: + print('Deleted at %s, ' + 'opened at %s' % (self.deletion_time, + self.handle_time)) + raise finally: remover.join() h.close() @@ -645,41 +662,6 @@ class StreamHandlerTest(BaseTest): # -- if it proves to be of wider utility than just test_logging if threading: - class TestSMTPChannel(smtpd.SMTPChannel): - """ - This derived class has had to be created because smtpd does not - support use of custom channel maps, although they are allowed by - asyncore's design. Issue #11959 has been raised to address this, - and if resolved satisfactorily, some of this code can be removed. - """ - def __init__(self, server, conn, addr, sockmap): - asynchat.async_chat.__init__(self, conn, sockmap) - self.smtp_server = server - self.conn = conn - self.addr = addr - self.data_size_limit = None - self.received_lines = [] - self.smtp_state = self.COMMAND - self.seen_greeting = '' - self.mailfrom = None - self.rcpttos = [] - self.received_data = '' - self.fqdn = socket.getfqdn() - self.num_bytes = 0 - try: - self.peer = conn.getpeername() - except socket.error as err: - # a race condition may occur if the other end is closing - # before we can get the peername - self.close() - if err.args[0] != errno.ENOTCONN: - raise - return - self.push('220 %s %s' % (self.fqdn, smtpd.__version__)) - self.set_terminator(b'\r\n') - self.extended_smtp = False - - class TestSMTPServer(smtpd.SMTPServer): """ This class implements a test SMTP server. @@ -700,37 +682,14 @@ if threading: :func:`asyncore.loop`. This avoids changing the :mod:`asyncore` module's global state. """ - channel_class = TestSMTPChannel def __init__(self, addr, handler, poll_interval, sockmap): - self._localaddr = addr - self._remoteaddr = None - self.data_size_limit = None - self.sockmap = sockmap - asyncore.dispatcher.__init__(self, map=sockmap) - try: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.setblocking(0) - self.set_socket(sock, map=sockmap) - # try to re-use a server port if possible - self.set_reuse_addr() - self.bind(addr) - self.port = sock.getsockname()[1] - self.listen(5) - except: - self.close() - raise + smtpd.SMTPServer.__init__(self, addr, None, map=sockmap) + self.port = self.socket.getsockname()[1] self._handler = handler self._thread = None self.poll_interval = poll_interval - def handle_accepted(self, conn, addr): - """ - Redefined only because the base class does not pass in a - map, forcing use of a global in :mod:`asyncore`. - """ - channel = self.channel_class(self, conn, addr, self.sockmap) - def process_message(self, peer, mailfrom, rcpttos, data): """ Delegates to the handler passed in to the server's constructor. @@ -761,8 +720,8 @@ if threading: :func:`asyncore.loop`. """ try: - asyncore.loop(poll_interval, map=self.sockmap) - except select.error: + asyncore.loop(poll_interval, map=self._map) + except OSError: # On FreeBSD 8, closing the server repeatably # raises this error. We swallow it if the # server has been closed. @@ -869,7 +828,7 @@ if threading: sock, addr = self.socket.accept() if self.sslctx: sock = self.sslctx.wrap_socket(sock, server_side=True) - except socket.error as e: + except OSError as e: # socket errors are silenced by the caller, print them here sys.stderr.write("Got an error:\n%s\n" % e) raise @@ -906,6 +865,9 @@ if threading: super(TestTCPServer, self).server_bind() self.port = self.socket.getsockname()[1] + class TestUnixStreamServer(TestTCPServer): + address_family = socket.AF_UNIX + class TestUDPServer(ControlMixin, ThreadingUDPServer): """ A UDP server which is controllable using :class:`ControlMixin`. @@ -935,7 +897,7 @@ if threading: if data: try: super(DelegatingUDPRequestHandler, self).finish() - except socket.error: + except OSError: if not self.server._closed: raise @@ -953,6 +915,9 @@ if threading: super(TestUDPServer, self).server_close() self._closed = True + class TestUnixDatagramServer(TestUDPServer): + address_family = socket.AF_UNIX + # - end of server_helper section @unittest.skipUnless(threading, 'Threading required for this test.') @@ -991,7 +956,7 @@ class MemoryHandlerTest(BaseTest): """Tests for the MemoryHandler.""" # Do not bother with a logger name group. - expected_log_pat = r"^[\w.]+ -> ([\w]+): ([\d]+)$" + expected_log_pat = r"^[\w.]+ -> (\w+): (\d+)$" def setUp(self): BaseTest.setUp(self) @@ -1044,7 +1009,7 @@ class ConfigFileTest(BaseTest): """Reading logging config from a .ini-style config file.""" - expected_log_pat = r"^([\w]+) \+\+ ([\w]+)$" + expected_log_pat = r"^(\w+) \+\+ (\w+)$" # config0 is a standard configuration. config0 = """ @@ -1292,6 +1257,24 @@ class ConfigFileTest(BaseTest): # Original logger output is empty. self.assert_log_lines([]) + def test_config0_using_cp_ok(self): + # A simple config file which overrides the default settings. + with captured_stdout() as output: + file = io.StringIO(textwrap.dedent(self.config0)) + cp = configparser.ConfigParser() + cp.read_file(file) + logging.config.fileConfig(cp) + logger = logging.getLogger() + # Won't output anything + logger.info(self.next_message()) + # Outputs a message + logger.error(self.next_message()) + self.assert_log_lines([ + ('ERROR', '2'), + ], stream=output) + # Original logger output is empty. + self.assert_log_lines([]) + def test_config1_ok(self, config=config1): # A config file defining a sub-parser as well. with captured_stdout() as output: @@ -1381,7 +1364,7 @@ class ConfigFileTest(BaseTest): def test_logger_disabling(self): self.apply_config(self.disable_test) - logger = logging.getLogger('foo') + logger = logging.getLogger('some_pristine_logger') self.assertFalse(logger.disabled) self.apply_config(self.disable_test) self.assertTrue(logger.disabled) @@ -1394,17 +1377,23 @@ class SocketHandlerTest(BaseTest): """Test for SocketHandler objects.""" + if threading: + server_class = TestTCPServer + address = ('localhost', 0) + def setUp(self): """Set up a TCP server to receive log messages, and a SocketHandler pointing to that server's address and port.""" BaseTest.setUp(self) - addr = ('localhost', 0) - self.server = server = TestTCPServer(addr, self.handle_socket, - 0.01) + self.server = server = self.server_class(self.address, + self.handle_socket, 0.01) server.start() server.ready.wait() - self.sock_hdlr = logging.handlers.SocketHandler('localhost', - server.port) + hcls = logging.handlers.SocketHandler + if isinstance(server.server_address, tuple): + self.sock_hdlr = hcls('localhost', server.port) + else: + self.sock_hdlr = hcls(server.server_address, None) self.log_output = '' self.root_logger.removeHandler(self.root_logger.handlers[0]) self.root_logger.addHandler(self.sock_hdlr) @@ -1444,35 +1433,69 @@ class SocketHandlerTest(BaseTest): self.assertEqual(self.log_output, "spam\neggs\n") def test_noserver(self): + # Avoid timing-related failures due to SocketHandler's own hard-wired + # one-second timeout on socket.create_connection() (issue #16264). + self.sock_hdlr.retryStart = 2.5 # Kill the server self.server.stop(2.0) - #The logging call should try to connect, which should fail + # The logging call should try to connect, which should fail try: raise RuntimeError('Deliberate mistake') except RuntimeError: self.root_logger.exception('Never sent') self.root_logger.error('Never sent, either') now = time.time() - self.assertTrue(self.sock_hdlr.retryTime > now) + self.assertGreater(self.sock_hdlr.retryTime, now) time.sleep(self.sock_hdlr.retryTime - now + 0.001) self.root_logger.error('Nor this') +def _get_temp_domain_socket(): + fd, fn = tempfile.mkstemp(prefix='test_logging_', suffix='.sock') + os.close(fd) + # just need a name - file can't be present, or we'll get an + # 'address already in use' error. + os.remove(fn) + return fn + +@unittest.skipUnless(threading, 'Threading required for this test.') +class UnixSocketHandlerTest(SocketHandlerTest): + + """Test for SocketHandler with unix sockets.""" + + if threading: + server_class = TestUnixStreamServer + + def setUp(self): + # override the definition in the base class + self.address = _get_temp_domain_socket() + SocketHandlerTest.setUp(self) + + def tearDown(self): + SocketHandlerTest.tearDown(self) + os.remove(self.address) @unittest.skipUnless(threading, 'Threading required for this test.') class DatagramHandlerTest(BaseTest): """Test for DatagramHandler.""" + if threading: + server_class = TestUDPServer + address = ('localhost', 0) + def setUp(self): """Set up a UDP server to receive log messages, and a DatagramHandler pointing to that server's address and port.""" BaseTest.setUp(self) - addr = ('localhost', 0) - self.server = server = TestUDPServer(addr, self.handle_datagram, 0.01) + self.server = server = self.server_class(self.address, + self.handle_datagram, 0.01) server.start() server.ready.wait() - self.sock_hdlr = logging.handlers.DatagramHandler('localhost', - server.port) + hcls = logging.handlers.DatagramHandler + if isinstance(server.server_address, tuple): + self.sock_hdlr = hcls('localhost', server.port) + else: + self.sock_hdlr = hcls(server.server_address, None) self.log_output = '' self.root_logger.removeHandler(self.root_logger.handlers[0]) self.root_logger.addHandler(self.sock_hdlr) @@ -1507,21 +1530,44 @@ class DatagramHandlerTest(BaseTest): @unittest.skipUnless(threading, 'Threading required for this test.') +class UnixDatagramHandlerTest(DatagramHandlerTest): + + """Test for DatagramHandler using Unix sockets.""" + + if threading: + server_class = TestUnixDatagramServer + + def setUp(self): + # override the definition in the base class + self.address = _get_temp_domain_socket() + DatagramHandlerTest.setUp(self) + + def tearDown(self): + DatagramHandlerTest.tearDown(self) + os.remove(self.address) + +@unittest.skipUnless(threading, 'Threading required for this test.') class SysLogHandlerTest(BaseTest): """Test for SysLogHandler using UDP.""" + if threading: + server_class = TestUDPServer + address = ('localhost', 0) + def setUp(self): """Set up a UDP server to receive log messages, and a SysLogHandler pointing to that server's address and port.""" BaseTest.setUp(self) - addr = ('localhost', 0) - self.server = server = TestUDPServer(addr, self.handle_datagram, - 0.01) + self.server = server = self.server_class(self.address, + self.handle_datagram, 0.01) server.start() server.ready.wait() - self.sl_hdlr = logging.handlers.SysLogHandler(('localhost', - server.port)) + hcls = logging.handlers.SysLogHandler + if isinstance(server.server_address, tuple): + self.sl_hdlr = hcls(('localhost', server.port)) + else: + self.sl_hdlr = hcls(server.server_address) self.log_output = '' self.root_logger.removeHandler(self.root_logger.handlers[0]) self.root_logger.addHandler(self.sl_hdlr) @@ -1559,6 +1605,23 @@ class SysLogHandlerTest(BaseTest): @unittest.skipUnless(threading, 'Threading required for this test.') +class UnixSysLogHandlerTest(SysLogHandlerTest): + + """Test for SysLogHandler with Unix sockets.""" + + if threading: + server_class = TestUnixDatagramServer + + def setUp(self): + # override the definition in the base class + self.address = _get_temp_domain_socket() + SysLogHandlerTest.setUp(self) + + def tearDown(self): + SysLogHandlerTest.tearDown(self) + os.remove(self.address) + +@unittest.skipUnless(threading, 'Threading required for this test.') class HTTPHandlerTest(BaseTest): """Test for HTTPHandler.""" @@ -1779,7 +1842,7 @@ class WarningsTest(BaseTest): logger.removeHandler(h) s = stream.getvalue() h.close() - self.assertTrue(s.find("UserWarning: I'm warning you...\n") > 0) + self.assertGreater(s.find("UserWarning: I'm warning you...\n"), 0) #See if an explicit file uses the original implementation a_file = io.StringIO() @@ -1817,7 +1880,7 @@ class ConfigDictTest(BaseTest): """Reading logging config from a dictionary.""" - expected_log_pat = r"^([\w]+) \+\+ ([\w]+)$" + expected_log_pat = r"^(\w+) \+\+ (\w+)$" # config0 is a standard configuration. config0 = { @@ -2389,6 +2452,32 @@ class ConfigDictTest(BaseTest): }, } + # As config0, but with properties + config14 = { + 'version': 1, + 'formatters': { + 'form1' : { + 'format' : '%(levelname)s ++ %(message)s', + }, + }, + 'handlers' : { + 'hand1' : { + 'class' : 'logging.StreamHandler', + 'formatter' : 'form1', + 'level' : 'NOTSET', + 'stream' : 'ext://sys.stdout', + '.': { + 'foo': 'bar', + 'terminator': '!\n', + } + }, + }, + 'root' : { + 'level' : 'WARNING', + 'handlers' : ['hand1'], + }, + } + out_of_order = { "version": 1, "formatters": { @@ -2656,11 +2745,20 @@ class ConfigDictTest(BaseTest): def test_config13_failure(self): self.assertRaises(Exception, self.apply_config, self.config13) + def test_config14_ok(self): + with captured_stdout() as output: + self.apply_config(self.config14) + h = logging._handlers['hand1'] + self.assertEqual(h.foo, 'bar') + self.assertEqual(h.terminator, '!\n') + logging.warning('Exclamation') + self.assertTrue(output.getvalue().endswith('Exclamation!\n')) + @unittest.skipUnless(threading, 'listen() needs threading to work') - def setup_via_listener(self, text): + def setup_via_listener(self, text, verify=None): text = text.encode("utf-8") # Ask for a randomly assigned port (by using port 0) - t = logging.config.listen(0) + t = logging.config.listen(0, verify) t.start() t.ready.wait() # Now get the port allocated @@ -2720,6 +2818,69 @@ class ConfigDictTest(BaseTest): # Original logger output is empty. self.assert_log_lines([]) + @unittest.skipUnless(threading, 'Threading required for this test.') + def test_listen_verify(self): + + def verify_fail(stuff): + return None + + def verify_reverse(stuff): + return stuff[::-1] + + logger = logging.getLogger("compiler.parser") + to_send = textwrap.dedent(ConfigFileTest.config1) + # First, specify a verification function that will fail. + # We expect to see no output, since our configuration + # never took effect. + with captured_stdout() as output: + self.setup_via_listener(to_send, verify_fail) + # Both will output a message + logger.info(self.next_message()) + logger.error(self.next_message()) + self.assert_log_lines([], stream=output) + # Original logger output has the stuff we logged. + self.assert_log_lines([ + ('INFO', '1'), + ('ERROR', '2'), + ], pat=r"^[\w.]+ -> (\w+): (\d+)$") + + # Now, perform no verification. Our configuration + # should take effect. + + with captured_stdout() as output: + self.setup_via_listener(to_send) # no verify callable specified + logger = logging.getLogger("compiler.parser") + # Both will output a message + logger.info(self.next_message()) + logger.error(self.next_message()) + self.assert_log_lines([ + ('INFO', '3'), + ('ERROR', '4'), + ], stream=output) + # Original logger output still has the stuff we logged before. + self.assert_log_lines([ + ('INFO', '1'), + ('ERROR', '2'), + ], pat=r"^[\w.]+ -> (\w+): (\d+)$") + + # Now, perform verification which transforms the bytes. + + with captured_stdout() as output: + self.setup_via_listener(to_send[::-1], verify_reverse) + logger = logging.getLogger("compiler.parser") + # Both will output a message + logger.info(self.next_message()) + logger.error(self.next_message()) + self.assert_log_lines([ + ('INFO', '5'), + ('ERROR', '6'), + ], stream=output) + # Original logger output still has the stuff we logged before. + self.assert_log_lines([ + ('INFO', '1'), + ('ERROR', '2'), + ], pat=r"^[\w.]+ -> (\w+): (\d+)$") + def test_out_of_order(self): self.apply_config(self.out_of_order) handler = logging.getLogger('mymodule').handlers[0] @@ -2779,14 +2940,14 @@ class ChildLoggerTest(BaseTest): l2 = logging.getLogger('def.ghi') c1 = r.getChild('xyz') c2 = r.getChild('uvw.xyz') - self.assertTrue(c1 is logging.getLogger('xyz')) - self.assertTrue(c2 is logging.getLogger('uvw.xyz')) + self.assertIs(c1, logging.getLogger('xyz')) + self.assertIs(c2, logging.getLogger('uvw.xyz')) c1 = l1.getChild('def') c2 = c1.getChild('ghi') c3 = l1.getChild('def.ghi') - self.assertTrue(c1 is logging.getLogger('abc.def')) - self.assertTrue(c2 is logging.getLogger('abc.def.ghi')) - self.assertTrue(c2 is c3) + self.assertIs(c1, logging.getLogger('abc.def')) + self.assertIs(c2, logging.getLogger('abc.def.ghi')) + self.assertIs(c2, c3) class DerivedLogRecord(logging.LogRecord): @@ -2829,7 +2990,7 @@ class LogRecordFactoryTest(BaseTest): class QueueHandlerTest(BaseTest): # Do not bother with a logger name group. - expected_log_pat = r"^[\w.]+ -> ([\w]+): ([\d]+)$" + expected_log_pat = r"^[\w.]+ -> (\w+): (\d+)$" def setUp(self): BaseTest.setUp(self) @@ -3123,13 +3284,13 @@ class ShutdownTest(BaseTest): self.assertEqual('0 - release', self.called[-1]) def test_with_ioerror_in_acquire(self): - self._test_with_failure_in_method('acquire', IOError) + self._test_with_failure_in_method('acquire', OSError) def test_with_ioerror_in_flush(self): - self._test_with_failure_in_method('flush', IOError) + self._test_with_failure_in_method('flush', OSError) def test_with_ioerror_in_close(self): - self._test_with_failure_in_method('close', IOError) + self._test_with_failure_in_method('close', OSError) def test_with_valueerror_in_acquire(self): self._test_with_failure_in_method('acquire', ValueError) @@ -3235,6 +3396,25 @@ class ModuleLevelMiscTest(BaseTest): logging.setLoggerClass(logging.Logger) self.assertEqual(logging.getLoggerClass(), logging.Logger) + def test_logging_at_shutdown(self): + # Issue #20037 + code = """if 1: + import logging + + class A: + def __del__(self): + try: + raise ValueError("some error") + except Exception: + logging.exception("exception in __del__") + + a = A()""" + rc, out, err = assert_python_ok("-c", code) + err = err.decode() + self.assertIn("exception in __del__", err) + self.assertIn("ValueError: some error", err) + + class LogRecordTest(BaseTest): def test_str_rep(self): r = logging.makeLogRecord({}) @@ -3352,6 +3532,12 @@ class BasicConfigTest(unittest.TestCase): "ERROR:root:Log an error") def test_filename(self): + + def cleanup(h1, h2, fn): + h1.close() + h2.close() + os.remove(fn) + logging.basicConfig(filename='test.log') self.assertEqual(len(logging.root.handlers), 1) @@ -3359,17 +3545,23 @@ class BasicConfigTest(unittest.TestCase): self.assertIsInstance(handler, logging.FileHandler) expected = logging.FileHandler('test.log', 'a') - self.addCleanup(expected.close) self.assertEqual(handler.stream.mode, expected.stream.mode) self.assertEqual(handler.stream.name, expected.stream.name) + self.addCleanup(cleanup, handler, expected, 'test.log') def test_filemode(self): + + def cleanup(h1, h2, fn): + h1.close() + h2.close() + os.remove(fn) + logging.basicConfig(filename='test.log', filemode='wb') handler = logging.root.handlers[0] expected = logging.FileHandler('test.log', 'wb') - self.addCleanup(expected.close) self.assertEqual(handler.stream.mode, expected.stream.mode) + self.addCleanup(cleanup, handler, expected, 'test.log') def test_stream(self): stream = io.StringIO() @@ -3825,6 +4017,63 @@ class TimedRotatingFileHandlerTest(BaseFileTest): assertRaises(ValueError, logging.handlers.TimedRotatingFileHandler, self.fn, 'W7', delay=True) + def test_compute_rollover_daily_attime(self): + currentTime = 0 + atTime = datetime.time(12, 0, 0) + rh = logging.handlers.TimedRotatingFileHandler( + self.fn, when='MIDNIGHT', interval=1, backupCount=0, utc=True, + atTime=atTime) + try: + actual = rh.computeRollover(currentTime) + self.assertEqual(actual, currentTime + 12 * 60 * 60) + + actual = rh.computeRollover(currentTime + 13 * 60 * 60) + self.assertEqual(actual, currentTime + 36 * 60 * 60) + finally: + rh.close() + + #@unittest.skipIf(True, 'Temporarily skipped while failures investigated.') + def test_compute_rollover_weekly_attime(self): + currentTime = int(time.time()) + today = currentTime - currentTime % 86400 + + atTime = datetime.time(12, 0, 0) + + wday = time.gmtime(today).tm_wday + for day in range(7): + rh = logging.handlers.TimedRotatingFileHandler( + self.fn, when='W%d' % day, interval=1, backupCount=0, utc=True, + atTime=atTime) + try: + if wday > day: + # The rollover day has already passed this week, so we + # go over into next week + expected = (7 - wday + day) + else: + expected = (day - wday) + # At this point expected is in days from now, convert to seconds + expected *= 24 * 60 * 60 + # Add in the rollover time + expected += 12 * 60 * 60 + # Add in adjustment for today + expected += today + actual = rh.computeRollover(today) + if actual != expected: + print('failed in timezone: %d' % time.timezone) + print('local vars: %s' % locals()) + self.assertEqual(actual, expected) + if day == wday: + # goes into following week + expected += 7 * 24 * 60 * 60 + actual = rh.computeRollover(today + 13 * 60 * 60) + if actual != expected: + print('failed in timezone: %d' % time.timezone) + print('local vars: %s' % locals()) + self.assertEqual(actual, expected) + finally: + rh.close() + + def secs(**kw): return datetime.timedelta(**kw) // datetime.timedelta(seconds=1) @@ -3883,7 +4132,7 @@ class NTEventLogHandlerTest(BaseTest): h.handle(r) h.close() # Now see if the event is recorded - self.assertTrue(num_recs < win32evtlog.GetNumberOfEventLogRecords(elh)) + self.assertLess(num_recs, win32evtlog.GetNumberOfEventLogRecords(elh)) flags = win32evtlog.EVENTLOG_BACKWARDS_READ | \ win32evtlog.EVENTLOG_SEQUENTIAL_READ found = False @@ -3916,7 +4165,8 @@ def test_main(): SMTPHandlerTest, FileHandlerTest, RotatingFileHandlerTest, LastResortTest, LogRecordTest, ExceptionTest, SysLogHandlerTest, HTTPHandlerTest, NTEventLogHandlerTest, - TimedRotatingFileHandlerTest + TimedRotatingFileHandlerTest, UnixSocketHandlerTest, + UnixDatagramHandlerTest, UnixSysLogHandlerTest ) if __name__ == "__main__": diff --git a/Lib/test/test_long.py b/Lib/test/test_long.py index 3f9c1e295b..13152ecf12 100644 --- a/Lib/test/test_long.py +++ b/Lib/test/test_long.py @@ -1079,7 +1079,7 @@ class LongTest(unittest.TestCase): self.assertRaises(OverflowError, (256).to_bytes, 1, 'big', signed=True) self.assertRaises(OverflowError, (256).to_bytes, 1, 'little', signed=False) self.assertRaises(OverflowError, (256).to_bytes, 1, 'little', signed=True) - self.assertRaises(OverflowError, (-1).to_bytes, 2, 'big', signed=False), + self.assertRaises(OverflowError, (-1).to_bytes, 2, 'big', signed=False) self.assertRaises(OverflowError, (-1).to_bytes, 2, 'little', signed=False) self.assertEqual((0).to_bytes(0, 'big'), b'') self.assertEqual((1).to_bytes(5, 'big'), b'\x00\x00\x00\x00\x01') diff --git a/Lib/test/test_lzma.py b/Lib/test/test_lzma.py index 160b0e8b0f..3e73d7f1cf 100644 --- a/Lib/test/test_lzma.py +++ b/Lib/test/test_lzma.py @@ -383,6 +383,8 @@ class FileTestCase(unittest.TestCase): pass with LZMAFile(BytesIO(), "w") as f: pass + with LZMAFile(BytesIO(), "x") as f: + pass with LZMAFile(BytesIO(), "a") as f: pass @@ -410,13 +412,29 @@ class FileTestCase(unittest.TestCase): with LZMAFile(TESTFN, "ab"): pass + def test_init_with_x_mode(self): + self.addCleanup(unlink, TESTFN) + for mode in ("x", "xb"): + unlink(TESTFN) + with LZMAFile(TESTFN, mode): + pass + with self.assertRaises(FileExistsError): + with LZMAFile(TESTFN, mode): + pass + def test_init_bad_mode(self): with self.assertRaises(ValueError): LZMAFile(BytesIO(COMPRESSED_XZ), (3, "x")) with self.assertRaises(ValueError): LZMAFile(BytesIO(COMPRESSED_XZ), "") with self.assertRaises(ValueError): - LZMAFile(BytesIO(COMPRESSED_XZ), "x") + LZMAFile(BytesIO(COMPRESSED_XZ), "xt") + with self.assertRaises(ValueError): + LZMAFile(BytesIO(COMPRESSED_XZ), "x+") + with self.assertRaises(ValueError): + LZMAFile(BytesIO(COMPRESSED_XZ), "rx") + with self.assertRaises(ValueError): + LZMAFile(BytesIO(COMPRESSED_XZ), "wx") with self.assertRaises(ValueError): LZMAFile(BytesIO(COMPRESSED_XZ), "rt") with self.assertRaises(ValueError): @@ -698,6 +716,20 @@ class FileTestCase(unittest.TestCase): with LZMAFile(BytesIO(COMPRESSED_XZ[:128])) as f: self.assertRaises(EOFError, f.read) + def test_read_truncated(self): + # Drop stream footer: CRC (4 bytes), index size (4 bytes), + # flags (2 bytes) and magic number (2 bytes). + truncated = COMPRESSED_XZ[:-12] + with LZMAFile(BytesIO(truncated)) as f: + self.assertRaises(EOFError, f.read) + with LZMAFile(BytesIO(truncated)) as f: + self.assertEqual(f.read(len(INPUT)), INPUT) + self.assertRaises(EOFError, f.read, 1) + # Incomplete 12-byte header. + for i in range(12): + with LZMAFile(BytesIO(truncated[:i])) as f: + self.assertRaises(EOFError, f.read, 1) + def test_read_bad_args(self): f = LZMAFile(BytesIO(COMPRESSED_XZ)) f.close() @@ -1041,8 +1073,6 @@ class OpenTestCase(unittest.TestCase): with self.assertRaises(ValueError): lzma.open(TESTFN, "") with self.assertRaises(ValueError): - lzma.open(TESTFN, "x") - with self.assertRaises(ValueError): lzma.open(TESTFN, "rbt") with self.assertRaises(ValueError): lzma.open(TESTFN, "rb", encoding="utf-8") @@ -1091,6 +1121,16 @@ class OpenTestCase(unittest.TestCase): with lzma.open(bio, "rt", newline="\r") as f: self.assertEqual(f.readlines(), [text]) + def test_x_mode(self): + self.addCleanup(unlink, TESTFN) + for mode in ("x", "xb", "xt"): + unlink(TESTFN) + with lzma.open(TESTFN, mode): + pass + with self.assertRaises(FileExistsError): + with lzma.open(TESTFN, mode): + pass + class MiscellaneousTestCase(unittest.TestCase): diff --git a/Lib/test/test_mailbox.py b/Lib/test/test_mailbox.py index 5e9e30acf1..d40509368a 100644 --- a/Lib/test/test_mailbox.py +++ b/Lib/test/test_mailbox.py @@ -596,7 +596,7 @@ class TestMaildir(TestMailbox, unittest.TestCase): def setUp(self): TestMailbox.setUp(self) - if os.name in ('nt', 'os2') or sys.platform == 'cygwin': + if (os.name == 'nt') or (sys.platform == 'cygwin'): self._box.colon = '!' def assertMailboxEmpty(self): diff --git a/Lib/test/test_marshal.py b/Lib/test/test_marshal.py index 0b1413715d..068c4713e1 100644 --- a/Lib/test/test_marshal.py +++ b/Lib/test/test_marshal.py @@ -22,37 +22,13 @@ class HelperMixin: class IntTestCase(unittest.TestCase, HelperMixin): def test_ints(self): - # Test the full range of Python ints. - n = sys.maxsize + # Test a range of Python ints larger than the machine word size. + n = sys.maxsize ** 2 while n: for expected in (-n, n): self.helper(expected) n = n >> 1 - def test_int64(self): - # Simulate int marshaling on a 64-bit box. This is most interesting if - # we're running the test on a 32-bit box, of course. - - def to_little_endian_string(value, nbytes): - b = bytearray() - for i in range(nbytes): - b.append(value & 0xff) - value >>= 8 - return b - - maxint64 = (1 << 63) - 1 - minint64 = -maxint64-1 - - for base in maxint64, minint64, -maxint64, -(minint64 >> 1): - while base: - s = b'I' + to_little_endian_string(base, 8) - got = marshal.loads(s) - self.assertEqual(base, got) - if base == -1: # a fixed-point for shifting right 1 - base = 0 - else: - base >>= 1 - def test_bool(self): for b in (True, False): self.helper(b) @@ -199,10 +175,14 @@ class BugsTestCase(unittest.TestCase): except Exception: pass - def test_loads_recursion(self): + def test_loads_2x_code(self): s = b'c' + (b'X' * 4*4) + b'{' * 2**20 self.assertRaises(ValueError, marshal.loads, s) + def test_loads_recursion(self): + s = b'c' + (b'X' * 4*5) + b'{' * 2**20 + self.assertRaises(ValueError, marshal.loads, s) + def test_recursion_limit(self): # Create a deeply nested structure. head = last = [] @@ -280,15 +260,20 @@ class BugsTestCase(unittest.TestCase): def test_bad_reader(self): class BadReader(io.BytesIO): - def read(self, n=-1): - b = super().read(n) + def readinto(self, buf): + n = super().readinto(buf) if n is not None and n > 4: - b += b' ' * 10**6 - return b + n += 10**6 + return n for value in (1.0, 1j, b'0123456789', '0123456789'): self.assertRaises(ValueError, marshal.load, BadReader(marshal.dumps(value))) + def _test_eof(self): + data = marshal.dumps(("hello", "dolly", None)) + for i in range(len(data)): + self.assertRaises(EOFError, marshal.loads, data[0: i]) + LARGE_SIZE = 2**31 pointer_size = 8 if sys.maxsize > 0xFFFFFFFF else 4 @@ -333,6 +318,122 @@ class LargeValuesTestCase(unittest.TestCase): def test_bytearray(self, size): self.check_unmarshallable(bytearray(size)) +def CollectObjectIDs(ids, obj): + """Collect object ids seen in a structure""" + if id(obj) in ids: + return + ids.add(id(obj)) + if isinstance(obj, (list, tuple, set, frozenset)): + for e in obj: + CollectObjectIDs(ids, e) + elif isinstance(obj, dict): + for k, v in obj.items(): + CollectObjectIDs(ids, k) + CollectObjectIDs(ids, v) + return len(ids) + +class InstancingTestCase(unittest.TestCase, HelperMixin): + intobj = 123321 + floatobj = 1.2345 + strobj = "abcde"*3 + dictobj = {"hello":floatobj, "goodbye":floatobj, floatobj:"hello"} + + def helper3(self, rsample, recursive=False, simple=False): + #we have two instances + sample = (rsample, rsample) + + n0 = CollectObjectIDs(set(), sample) + + s3 = marshal.dumps(sample, 3) + n3 = CollectObjectIDs(set(), marshal.loads(s3)) + + #same number of instances generated + self.assertEqual(n3, n0) + + if not recursive: + #can compare with version 2 + s2 = marshal.dumps(sample, 2) + n2 = CollectObjectIDs(set(), marshal.loads(s2)) + #old format generated more instances + self.assertGreater(n2, n0) + + #if complex objects are in there, old format is larger + if not simple: + self.assertGreater(len(s2), len(s3)) + else: + self.assertGreaterEqual(len(s2), len(s3)) + + def testInt(self): + self.helper(self.intobj) + self.helper3(self.intobj, simple=True) + + def testFloat(self): + self.helper(self.floatobj) + self.helper3(self.floatobj) + + def testStr(self): + self.helper(self.strobj) + self.helper3(self.strobj) + + def testDict(self): + self.helper(self.dictobj) + self.helper3(self.dictobj) + + def testModule(self): + with open(__file__, "rb") as f: + code = f.read() + if __file__.endswith(".py"): + code = compile(code, __file__, "exec") + self.helper(code) + self.helper3(code) + + def testRecursion(self): + d = dict(self.dictobj) + d["self"] = d + self.helper3(d, recursive=True) + l = [self.dictobj] + l.append(l) + self.helper3(l, recursive=True) + +class CompatibilityTestCase(unittest.TestCase): + def _test(self, version): + with open(__file__, "rb") as f: + code = f.read() + if __file__.endswith(".py"): + code = compile(code, __file__, "exec") + data = marshal.dumps(code, version) + marshal.loads(data) + + def test0To3(self): + self._test(0) + + def test1To3(self): + self._test(1) + + def test2To3(self): + self._test(2) + + def test3To3(self): + self._test(3) + +class InterningTestCase(unittest.TestCase, HelperMixin): + strobj = "this is an interned string" + strobj = sys.intern(strobj) + + def testIntern(self): + s = marshal.loads(marshal.dumps(self.strobj)) + self.assertEqual(s, self.strobj) + self.assertEqual(id(s), id(self.strobj)) + s2 = sys.intern(s) + self.assertEqual(id(s2), id(s)) + + def testNoIntern(self): + s = marshal.loads(marshal.dumps(self.strobj, 2)) + self.assertEqual(s, self.strobj) + self.assertNotEqual(id(s), id(self.strobj)) + s2 = sys.intern(s) + self.assertNotEqual(id(s2), id(s)) + def test_main(): support.run_unittest(IntTestCase, diff --git a/Lib/test/test_memoryio.py b/Lib/test/test_memoryio.py index 4fa9a195f9..9ef293e708 100644 --- a/Lib/test/test_memoryio.py +++ b/Lib/test/test_memoryio.py @@ -520,12 +520,12 @@ class TextIOTestMixin: def test_relative_seek(self): memio = self.ioclass() - self.assertRaises(IOError, memio.seek, -1, 1) - self.assertRaises(IOError, memio.seek, 3, 1) - self.assertRaises(IOError, memio.seek, -3, 1) - self.assertRaises(IOError, memio.seek, -1, 2) - self.assertRaises(IOError, memio.seek, 1, 1) - self.assertRaises(IOError, memio.seek, 1, 2) + self.assertRaises(OSError, memio.seek, -1, 1) + self.assertRaises(OSError, memio.seek, 3, 1) + self.assertRaises(OSError, memio.seek, -3, 1) + self.assertRaises(OSError, memio.seek, -1, 2) + self.assertRaises(OSError, memio.seek, 1, 1) + self.assertRaises(OSError, memio.seek, 1, 2) def test_textio_properties(self): memio = self.ioclass() diff --git a/Lib/test/test_memoryview.py b/Lib/test/test_memoryview.py index bf0eaad2ac..e7df8a762c 100644 --- a/Lib/test/test_memoryview.py +++ b/Lib/test/test_memoryview.py @@ -352,6 +352,15 @@ class AbstractMemoryTests: self.assertIs(wr(), None) self.assertIs(L[0], b) + def test_reversed(self): + for tp in self._types: + b = tp(self._source) + m = self._view(b) + aslist = list(reversed(m.tolist())) + self.assertEqual(list(reversed(m)), aslist) + self.assertEqual(list(reversed(m)), list(m[::-1])) + + # Variations on source objects for the buffer: bytes-like objects, then arrays # with itemsize > 1. # NOTE: support for multi-dimensional objects is unimplemented. diff --git a/Lib/test/test_minidom.py b/Lib/test/test_minidom.py index 113c7fd71d..a1516c242d 100644 --- a/Lib/test/test_minidom.py +++ b/Lib/test/test_minidom.py @@ -360,6 +360,15 @@ class MinidomTest(unittest.TestCase): and el.getAttribute("spam2") == "bam2") dom.unlink() + def testGetAttrList(self): + pass + + def testGetAttrValues(self): + pass + + def testGetAttrLength(self): + pass + def testGetAttribute(self): dom = Document() child = dom.appendChild( @@ -380,6 +389,8 @@ class MinidomTest(unittest.TestCase): self.assertEqual(child2.getAttributeNS("http://www.python.org", "missing"), '') + def testGetAttributeNode(self): pass + def testGetElementsByTagNameNS(self): d="""<foo xmlns:minidom='http://pyxml.sf.net/minidom'> <minidom:myelem/> @@ -450,6 +461,8 @@ class MinidomTest(unittest.TestCase): self.confirm(str(node) == repr(node)) dom.unlink() + def testTextNodeRepr(self): pass + def testWriteXML(self): str = '<?xml version="1.0" ?><a b="c"/>' dom = parseString(str) @@ -513,6 +526,14 @@ class MinidomTest(unittest.TestCase): and pi.localName is None and pi.namespaceURI == xml.dom.EMPTY_NAMESPACE) + def testProcessingInstructionRepr(self): pass + + def testTextRepr(self): pass + + def testWriteText(self): pass + + def testDocumentElement(self): pass + def testTooManyDocumentElements(self): doc = parseString("<doc/>") elem = doc.createElement("extra") @@ -521,6 +542,26 @@ class MinidomTest(unittest.TestCase): elem.unlink() doc.unlink() + def testCreateElementNS(self): pass + + def testCreateAttributeNS(self): pass + + def testParse(self): pass + + def testParseString(self): pass + + def testComment(self): pass + + def testAttrListItem(self): pass + + def testAttrListItems(self): pass + + def testAttrListItemNS(self): pass + + def testAttrListKeys(self): pass + + def testAttrListKeysNS(self): pass + def testRemoveNamedItem(self): doc = parseString("<doc a=''/>") e = doc.documentElement @@ -540,6 +581,30 @@ class MinidomTest(unittest.TestCase): self.assertRaises(xml.dom.NotFoundErr, attrs.removeNamedItemNS, "http://xml.python.org/", "b") + def testAttrListValues(self): pass + + def testAttrListLength(self): pass + + def testAttrList__getitem__(self): pass + + def testAttrList__setitem__(self): pass + + def testSetAttrValueandNodeValue(self): pass + + def testParseElement(self): pass + + def testParseAttributes(self): pass + + def testParseElementNamespaces(self): pass + + def testParseAttributeNamespaces(self): pass + + def testParseProcessingInstructions(self): pass + + def testChildNodes(self): pass + + def testFirstChild(self): pass + def testHasChildNodes(self): dom = parseString("<doc><foo/></doc>") doc = dom.documentElement diff --git a/Lib/test/test_mmap.py b/Lib/test/test_mmap.py index 899df8d818..6ca5e1b730 100644 --- a/Lib/test/test_mmap.py +++ b/Lib/test/test_mmap.py @@ -1,11 +1,12 @@ from test.support import (TESTFN, run_unittest, import_module, unlink, - requires, _2G, _4G) + requires, _2G, _4G, gc_collect) import unittest import os import re import itertools import socket import sys +import weakref # Skip test if we can't import mmap. mmap = import_module('mmap') @@ -245,7 +246,7 @@ class MmapTests(unittest.TestCase): def test_bad_file_desc(self): # Try opening a bad file descriptor... - self.assertRaises(mmap.error, mmap.mmap, -2, 4096) + self.assertRaises(OSError, mmap.mmap, -2, 4096) def test_tougher_find(self): # Do a tougher .find() test. SF bug 515943 pointed out that, in 2.2, @@ -655,7 +656,7 @@ class MmapTests(unittest.TestCase): m = mmap.mmap(f.fileno(), 0) f.close() try: - m.resize(0) # will raise WindowsError + m.resize(0) # will raise OSError except: pass try: @@ -671,7 +672,7 @@ class MmapTests(unittest.TestCase): # parameters to _get_osfhandle. s = socket.socket() try: - with self.assertRaises(mmap.error): + with self.assertRaises(OSError): m = mmap.mmap(s.fileno(), 10) finally: s.close() @@ -682,14 +683,23 @@ class MmapTests(unittest.TestCase): self.assertTrue(m.closed) def test_context_manager_exception(self): - # Test that the IOError gets passed through + # Test that the OSError gets passed through with self.assertRaises(Exception) as exc: with mmap.mmap(-1, 10) as m: - raise IOError - self.assertIsInstance(exc.exception, IOError, + raise OSError + self.assertIsInstance(exc.exception, OSError, "wrong exception raised in context manager") self.assertTrue(m.closed, "context manager failed") + def test_weakref(self): + # Check mmap objects are weakrefable + mm = mmap.mmap(-1, 16) + wr = weakref.ref(mm) + self.assertIs(wr(), mm) + del mm + gc_collect() + self.assertIs(wr(), None) + class LargeMmapTests(unittest.TestCase): def setUp(self): @@ -707,7 +717,7 @@ class LargeMmapTests(unittest.TestCase): f.seek(num_zeroes) f.write(tail) f.flush() - except (IOError, OverflowError): + except (OSError, OverflowError): f.close() raise unittest.SkipTest("filesystem does not have largefile support") return f diff --git a/Lib/test/test_module.py b/Lib/test/test_module.py index e5a2525d18..1230293670 100644 --- a/Lib/test/test_module.py +++ b/Lib/test/test_module.py @@ -1,6 +1,8 @@ # Test the module type import unittest +import weakref from test.support import run_unittest, gc_collect +from test.script_helper import assert_python_ok import sys ModuleType = type(sys) @@ -33,7 +35,12 @@ class ModuleTests(unittest.TestCase): foo = ModuleType("foo") self.assertEqual(foo.__name__, "foo") self.assertEqual(foo.__doc__, None) - self.assertEqual(foo.__dict__, {"__name__": "foo", "__doc__": None}) + self.assertIs(foo.__loader__, None) + self.assertIs(foo.__package__, None) + self.assertIs(foo.__spec__, None) + self.assertEqual(foo.__dict__, {"__name__": "foo", "__doc__": None, + "__loader__": None, "__package__": None, + "__spec__": None}) def test_ascii_docstring(self): # ASCII docstring @@ -41,7 +48,9 @@ class ModuleTests(unittest.TestCase): self.assertEqual(foo.__name__, "foo") self.assertEqual(foo.__doc__, "foodoc") self.assertEqual(foo.__dict__, - {"__name__": "foo", "__doc__": "foodoc"}) + {"__name__": "foo", "__doc__": "foodoc", + "__loader__": None, "__package__": None, + "__spec__": None}) def test_unicode_docstring(self): # Unicode docstring @@ -49,7 +58,9 @@ class ModuleTests(unittest.TestCase): self.assertEqual(foo.__name__, "foo") self.assertEqual(foo.__doc__, "foodoc\u1234") self.assertEqual(foo.__dict__, - {"__name__": "foo", "__doc__": "foodoc\u1234"}) + {"__name__": "foo", "__doc__": "foodoc\u1234", + "__loader__": None, "__package__": None, + "__spec__": None}) def test_reinit(self): # Reinitialization should not replace the __dict__ @@ -61,10 +72,10 @@ class ModuleTests(unittest.TestCase): self.assertEqual(foo.__doc__, "foodoc") self.assertEqual(foo.bar, 42) self.assertEqual(foo.__dict__, - {"__name__": "foo", "__doc__": "foodoc", "bar": 42}) + {"__name__": "foo", "__doc__": "foodoc", "bar": 42, + "__loader__": None, "__package__": None, "__spec__": None}) self.assertTrue(foo.__dict__ is d) - @unittest.expectedFailure def test_dont_clear_dict(self): # See issue 7140. def f(): @@ -89,6 +100,14 @@ a = A(destroyed)""" gc_collect() self.assertEqual(destroyed, [1]) + def test_weakref(self): + m = ModuleType("foo") + wr = weakref.ref(m) + self.assertIs(wr(), m) + del m + gc_collect() + self.assertIs(wr(), None) + def test_module_repr_minimal(self): # reprs when modules have no __file__, __name__, or __loader__ m = ModuleType('foo') @@ -110,13 +129,19 @@ a = A(destroyed)""" m.__file__ = '/tmp/foo.py' self.assertEqual(repr(m), "<module '?' from '/tmp/foo.py'>") + def test_module_repr_with_loader_as_None(self): + m = ModuleType('foo') + assert m.__loader__ is None + self.assertEqual(repr(m), "<module 'foo'>") + def test_module_repr_with_bare_loader_but_no_name(self): m = ModuleType('foo') del m.__name__ # Yes, a class not an instance. m.__loader__ = BareLoader + loader_repr = repr(BareLoader) self.assertEqual( - repr(m), "<module '?' (<class 'test.test_module.BareLoader'>)>") + repr(m), "<module '?' ({})>".format(loader_repr)) def test_module_repr_with_full_loader_but_no_name(self): # m.__loader__.module_repr() will fail because the module has no @@ -126,15 +151,17 @@ a = A(destroyed)""" del m.__name__ # Yes, a class not an instance. m.__loader__ = FullLoader + loader_repr = repr(FullLoader) self.assertEqual( - repr(m), "<module '?' (<class 'test.test_module.FullLoader'>)>") + repr(m), "<module '?' ({})>".format(loader_repr)) def test_module_repr_with_bare_loader(self): m = ModuleType('foo') # Yes, a class not an instance. m.__loader__ = BareLoader + module_repr = repr(BareLoader) self.assertEqual( - repr(m), "<module 'foo' (<class 'test.test_module.BareLoader'>)>") + repr(m), "<module 'foo' ({})>".format(module_repr)) def test_module_repr_with_full_loader(self): m = ModuleType('foo') @@ -164,8 +191,25 @@ a = A(destroyed)""" def test_module_repr_source(self): r = repr(unittest) - self.assertEqual(r[:25], "<module 'unittest' from '") - self.assertEqual(r[-13:], "__init__.py'>") + starts_with = "<module 'unittest' from '" + ends_with = "__init__.py'>" + self.assertEqual(r[:len(starts_with)], starts_with, + '{!r} does not start with {!r}'.format(r, starts_with)) + self.assertEqual(r[-len(ends_with):], ends_with, + '{!r} does not end with {!r}'.format(r, ends_with)) + + def test_module_finalization_at_shutdown(self): + # Module globals and builtins should still be available during shutdown + rc, out, err = assert_python_ok("-c", "from test import final_a") + self.assertFalse(err) + lines = out.splitlines() + self.assertEqual(set(lines), { + b"x = a", + b"x = b", + b"final_a.x = a", + b"final_b.x = b", + b"len = len", + b"shutil.rmtree = rmtree"}) # frozen and namespace module reprs are tested in importlib. diff --git a/Lib/test/test_multibytecodec.py b/Lib/test/test_multibytecodec.py index 6889184cce..ea592c7501 100644 --- a/Lib/test/test_multibytecodec.py +++ b/Lib/test/test_multibytecodec.py @@ -175,57 +175,28 @@ class Test_StreamReader(unittest.TestCase): support.unlink(TESTFN) class Test_StreamWriter(unittest.TestCase): - if len('\U00012345') == 2: # UCS2 - def test_gb18030(self): - s= io.BytesIO() - c = codecs.getwriter('gb18030')(s) - c.write('123') - self.assertEqual(s.getvalue(), b'123') - c.write('\U00012345') - self.assertEqual(s.getvalue(), b'123\x907\x959') - c.write('\U00012345'[0]) - self.assertEqual(s.getvalue(), b'123\x907\x959') - c.write('\U00012345'[1] + '\U00012345' + '\uac00\u00ac') - self.assertEqual(s.getvalue(), - b'123\x907\x959\x907\x959\x907\x959\x827\xcf5\x810\x851') - c.write('\U00012345'[0]) - self.assertEqual(s.getvalue(), - b'123\x907\x959\x907\x959\x907\x959\x827\xcf5\x810\x851') - self.assertRaises(UnicodeError, c.reset) - self.assertEqual(s.getvalue(), - b'123\x907\x959\x907\x959\x907\x959\x827\xcf5\x810\x851') - - def test_utf_8(self): - s= io.BytesIO() - c = codecs.getwriter('utf-8')(s) - c.write('123') - self.assertEqual(s.getvalue(), b'123') - c.write('\U00012345') - self.assertEqual(s.getvalue(), b'123\xf0\x92\x8d\x85') - - # Python utf-8 codec can't buffer surrogate pairs yet. - if 0: - c.write('\U00012345'[0]) - self.assertEqual(s.getvalue(), b'123\xf0\x92\x8d\x85') - c.write('\U00012345'[1] + '\U00012345' + '\uac00\u00ac') - self.assertEqual(s.getvalue(), - b'123\xf0\x92\x8d\x85\xf0\x92\x8d\x85\xf0\x92\x8d\x85' - b'\xea\xb0\x80\xc2\xac') - c.write('\U00012345'[0]) - self.assertEqual(s.getvalue(), - b'123\xf0\x92\x8d\x85\xf0\x92\x8d\x85\xf0\x92\x8d\x85' - b'\xea\xb0\x80\xc2\xac') - c.reset() - self.assertEqual(s.getvalue(), - b'123\xf0\x92\x8d\x85\xf0\x92\x8d\x85\xf0\x92\x8d\x85' - b'\xea\xb0\x80\xc2\xac\xed\xa0\x88') - c.write('\U00012345'[1]) - self.assertEqual(s.getvalue(), - b'123\xf0\x92\x8d\x85\xf0\x92\x8d\x85\xf0\x92\x8d\x85' - b'\xea\xb0\x80\xc2\xac\xed\xa0\x88\xed\xbd\x85') - - else: # UCS4 - pass + def test_gb18030(self): + s= io.BytesIO() + c = codecs.getwriter('gb18030')(s) + c.write('123') + self.assertEqual(s.getvalue(), b'123') + c.write('\U00012345') + self.assertEqual(s.getvalue(), b'123\x907\x959') + c.write('\uac00\u00ac') + self.assertEqual(s.getvalue(), + b'123\x907\x959\x827\xcf5\x810\x851') + + def test_utf_8(self): + s= io.BytesIO() + c = codecs.getwriter('utf-8')(s) + c.write('123') + self.assertEqual(s.getvalue(), b'123') + c.write('\U00012345') + self.assertEqual(s.getvalue(), b'123\xf0\x92\x8d\x85') + c.write('\uac00\u00ac') + self.assertEqual(s.getvalue(), + b'123\xf0\x92\x8d\x85' + b'\xea\xb0\x80\xc2\xac') def test_streamwriter_strwrite(self): s = io.BytesIO() diff --git a/Lib/test/test_multiprocessing_fork.py b/Lib/test/test_multiprocessing_fork.py new file mode 100644 index 0000000000..2bf4e75644 --- /dev/null +++ b/Lib/test/test_multiprocessing_fork.py @@ -0,0 +1,7 @@ +import unittest +import test._test_multiprocessing + +test._test_multiprocessing.install_tests_in_module_dict(globals(), 'fork') + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_multiprocessing_forkserver.py b/Lib/test/test_multiprocessing_forkserver.py new file mode 100644 index 0000000000..193a04a5fc --- /dev/null +++ b/Lib/test/test_multiprocessing_forkserver.py @@ -0,0 +1,7 @@ +import unittest +import test._test_multiprocessing + +test._test_multiprocessing.install_tests_in_module_dict(globals(), 'forkserver') + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_multiprocessing_main_handling.py b/Lib/test/test_multiprocessing_main_handling.py new file mode 100644 index 0000000000..7ea8e7806c --- /dev/null +++ b/Lib/test/test_multiprocessing_main_handling.py @@ -0,0 +1,285 @@ +# tests __main__ module handling in multiprocessing + +import importlib +import importlib.machinery +import zipimport +import unittest +import sys +import os +import os.path +import py_compile + +from test import support +from test.script_helper import ( + make_pkg, make_script, make_zip_pkg, make_zip_script, + assert_python_ok, assert_python_failure, temp_dir, + spawn_python, kill_python) + +# Skip tests if _multiprocessing wasn't built. +_multiprocessing = support.import_module('_multiprocessing') +# Look up which start methods are available to test +import multiprocessing +AVAILABLE_START_METHODS = set(multiprocessing.get_all_start_methods()) + +verbose = support.verbose + +test_source = """\ +# multiprocessing includes all sorts of shenanigans to make __main__ +# attributes accessible in the subprocess in a pickle compatible way. + +# We run the "doesn't work in the interactive interpreter" example from +# the docs to make sure it *does* work from an executed __main__, +# regardless of the invocation mechanism + +import sys +import time +from multiprocessing import Pool, set_start_method + +# We use this __main__ defined function in the map call below in order to +# check that multiprocessing in correctly running the unguarded +# code in child processes and then making it available as __main__ +def f(x): + return x*x + +# Check explicit relative imports +if "check_sibling" in __file__: + # We're inside a package and not in a __main__.py file + # so make sure explicit relative imports work correctly + from . import sibling + +if __name__ == '__main__': + start_method = sys.argv[1] + set_start_method(start_method) + p = Pool(5) + results = [] + p.map_async(f, [1, 2, 3], callback=results.extend) + deadline = time.time() + 10 # up to 10 s to report the results + while not results: + time.sleep(0.05) + if time.time() > deadline: + raise RuntimeError("Timed out waiting for results") + results.sort() + print(start_method, "->", results) +""" + +test_source_main_skipped_in_children = """\ +# __main__.py files have an implied "if __name__ == '__main__'" so +# multiprocessing should always skip running them in child processes + +# This means we can't use __main__ defined functions in child processes, +# so we just use "int" as a passthrough operation below + +if __name__ != "__main__": + raise RuntimeError("Should only be called as __main__!") + +import sys +import time +from multiprocessing import Pool, set_start_method + +start_method = sys.argv[1] +set_start_method(start_method) +p = Pool(5) +results = [] +p.map_async(int, [1, 4, 9], callback=results.extend) +deadline = time.time() + 10 # up to 10 s to report the results +while not results: + time.sleep(0.05) + if time.time() > deadline: + raise RuntimeError("Timed out waiting for results") +results.sort() +print(start_method, "->", results) +""" + +# These helpers were copied from test_cmd_line_script & tweaked a bit... + +def _make_test_script(script_dir, script_basename, + source=test_source, omit_suffix=False): + to_return = make_script(script_dir, script_basename, + source, omit_suffix) + # Hack to check explicit relative imports + if script_basename == "check_sibling": + make_script(script_dir, "sibling", "") + importlib.invalidate_caches() + return to_return + +def _make_test_zip_pkg(zip_dir, zip_basename, pkg_name, script_basename, + source=test_source, depth=1): + to_return = make_zip_pkg(zip_dir, zip_basename, pkg_name, script_basename, + source, depth) + importlib.invalidate_caches() + return to_return + +# There's no easy way to pass the script directory in to get +# -m to work (avoiding that is the whole point of making +# directories and zipfiles executable!) +# So we fake it for testing purposes with a custom launch script +launch_source = """\ +import sys, os.path, runpy +sys.path.insert(0, %s) +runpy._run_module_as_main(%r) +""" + +def _make_launch_script(script_dir, script_basename, module_name, path=None): + if path is None: + path = "os.path.dirname(__file__)" + else: + path = repr(path) + source = launch_source % (path, module_name) + to_return = make_script(script_dir, script_basename, source) + importlib.invalidate_caches() + return to_return + +class MultiProcessingCmdLineMixin(): + maxDiff = None # Show full tracebacks on subprocess failure + + def setUp(self): + if self.start_method not in AVAILABLE_START_METHODS: + self.skipTest("%r start method not available" % self.start_method) + + def _check_output(self, script_name, exit_code, out, err): + if verbose > 1: + print("Output from test script %r:" % script_name) + print(out) + self.assertEqual(exit_code, 0) + self.assertEqual(err.decode('utf-8'), '') + expected_results = "%s -> [1, 4, 9]" % self.start_method + self.assertEqual(out.decode('utf-8').strip(), expected_results) + + def _check_script(self, script_name, *cmd_line_switches): + if not __debug__: + cmd_line_switches += ('-' + 'O' * sys.flags.optimize,) + run_args = cmd_line_switches + (script_name, self.start_method) + rc, out, err = assert_python_ok(*run_args, __isolated=False) + self._check_output(script_name, rc, out, err) + + def test_basic_script(self): + with temp_dir() as script_dir: + script_name = _make_test_script(script_dir, 'script') + self._check_script(script_name) + + def test_basic_script_no_suffix(self): + with temp_dir() as script_dir: + script_name = _make_test_script(script_dir, 'script', + omit_suffix=True) + self._check_script(script_name) + + def test_ipython_workaround(self): + # Some versions of the IPython launch script are missing the + # __name__ = "__main__" guard, and multiprocessing has long had + # a workaround for that case + # See https://github.com/ipython/ipython/issues/4698 + source = test_source_main_skipped_in_children + with temp_dir() as script_dir: + script_name = _make_test_script(script_dir, 'ipython', + source=source) + self._check_script(script_name) + script_no_suffix = _make_test_script(script_dir, 'ipython', + source=source, + omit_suffix=True) + self._check_script(script_no_suffix) + + def test_script_compiled(self): + with temp_dir() as script_dir: + script_name = _make_test_script(script_dir, 'script') + py_compile.compile(script_name, doraise=True) + os.remove(script_name) + pyc_file = support.make_legacy_pyc(script_name) + self._check_script(pyc_file) + + def test_directory(self): + source = self.main_in_children_source + with temp_dir() as script_dir: + script_name = _make_test_script(script_dir, '__main__', + source=source) + self._check_script(script_dir) + + def test_directory_compiled(self): + source = self.main_in_children_source + with temp_dir() as script_dir: + script_name = _make_test_script(script_dir, '__main__', + source=source) + py_compile.compile(script_name, doraise=True) + os.remove(script_name) + pyc_file = support.make_legacy_pyc(script_name) + self._check_script(script_dir) + + def test_zipfile(self): + source = self.main_in_children_source + with temp_dir() as script_dir: + script_name = _make_test_script(script_dir, '__main__', + source=source) + zip_name, run_name = make_zip_script(script_dir, 'test_zip', script_name) + self._check_script(zip_name) + + def test_zipfile_compiled(self): + source = self.main_in_children_source + with temp_dir() as script_dir: + script_name = _make_test_script(script_dir, '__main__', + source=source) + compiled_name = py_compile.compile(script_name, doraise=True) + zip_name, run_name = make_zip_script(script_dir, 'test_zip', compiled_name) + self._check_script(zip_name) + + def test_module_in_package(self): + with temp_dir() as script_dir: + pkg_dir = os.path.join(script_dir, 'test_pkg') + make_pkg(pkg_dir) + script_name = _make_test_script(pkg_dir, 'check_sibling') + launch_name = _make_launch_script(script_dir, 'launch', + 'test_pkg.check_sibling') + self._check_script(launch_name) + + def test_module_in_package_in_zipfile(self): + with temp_dir() as script_dir: + zip_name, run_name = _make_test_zip_pkg(script_dir, 'test_zip', 'test_pkg', 'script') + launch_name = _make_launch_script(script_dir, 'launch', 'test_pkg.script', zip_name) + self._check_script(launch_name) + + def test_module_in_subpackage_in_zipfile(self): + with temp_dir() as script_dir: + zip_name, run_name = _make_test_zip_pkg(script_dir, 'test_zip', 'test_pkg', 'script', depth=2) + launch_name = _make_launch_script(script_dir, 'launch', 'test_pkg.test_pkg.script', zip_name) + self._check_script(launch_name) + + def test_package(self): + source = self.main_in_children_source + with temp_dir() as script_dir: + pkg_dir = os.path.join(script_dir, 'test_pkg') + make_pkg(pkg_dir) + script_name = _make_test_script(pkg_dir, '__main__', + source=source) + launch_name = _make_launch_script(script_dir, 'launch', 'test_pkg') + self._check_script(launch_name) + + def test_package_compiled(self): + source = self.main_in_children_source + with temp_dir() as script_dir: + pkg_dir = os.path.join(script_dir, 'test_pkg') + make_pkg(pkg_dir) + script_name = _make_test_script(pkg_dir, '__main__', + source=source) + compiled_name = py_compile.compile(script_name, doraise=True) + os.remove(script_name) + pyc_file = support.make_legacy_pyc(script_name) + launch_name = _make_launch_script(script_dir, 'launch', 'test_pkg') + self._check_script(launch_name) + +# Test all supported start methods (setupClass skips as appropriate) + +class SpawnCmdLineTest(MultiProcessingCmdLineMixin, unittest.TestCase): + start_method = 'spawn' + main_in_children_source = test_source_main_skipped_in_children + +class ForkCmdLineTest(MultiProcessingCmdLineMixin, unittest.TestCase): + start_method = 'fork' + main_in_children_source = test_source + +class ForkServerCmdLineTest(MultiProcessingCmdLineMixin, unittest.TestCase): + start_method = 'forkserver' + main_in_children_source = test_source_main_skipped_in_children + +def tearDownModule(): + support.reap_children() + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_multiprocessing_spawn.py b/Lib/test/test_multiprocessing_spawn.py new file mode 100644 index 0000000000..334ae9e8f7 --- /dev/null +++ b/Lib/test/test_multiprocessing_spawn.py @@ -0,0 +1,7 @@ +import unittest +import test._test_multiprocessing + +test._test_multiprocessing.install_tests_in_module_dict(globals(), 'spawn') + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_namespace_pkgs.py b/Lib/test/test_namespace_pkgs.py index 7067b12e8f..6639612631 100644 --- a/Lib/test/test_namespace_pkgs.py +++ b/Lib/test/test_namespace_pkgs.py @@ -1,7 +1,10 @@ -import sys import contextlib -import unittest +import importlib.abc +import importlib.machinery import os +import sys +import types +import unittest from test.test_importlib import util from test.support import run_unittest @@ -286,9 +289,5 @@ class ModuleAndNamespacePackageInSameDir(NamespacePackageTest): self.assertEqual(a_test.attr, 'in module') -def test_main(): - run_unittest(*NamespacePackageTest.__subclasses__()) - - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_nntplib.py b/Lib/test/test_nntplib.py index 1d52713f07..71a4ec022b 100644 --- a/Lib/test/test_nntplib.py +++ b/Lib/test/test_nntplib.py @@ -266,7 +266,7 @@ class NetworkedNNTPTestsMixin: return False try: server.help() - except (socket.error, EOFError): + except (OSError, EOFError): return False return True diff --git a/Lib/test/test_normalization.py b/Lib/test/test_normalization.py index 28ede34cd7..ab2eeb77da 100644 --- a/Lib/test/test_normalization.py +++ b/Lib/test/test_normalization.py @@ -43,7 +43,7 @@ class NormalizationTest(unittest.TestCase): try: testdata = open_urlresource(TESTDATAURL, encoding="utf-8", check=check_version) - except (IOError, HTTPException): + except (OSError, HTTPException): self.skipTest("Could not retrieve " + TESTDATAURL) self.addCleanup(testdata.close) for line in testdata: diff --git a/Lib/test/test_ntpath.py b/Lib/test/test_ntpath.py index e9e1d715ef..000fc75c20 100644 --- a/Lib/test/test_ntpath.py +++ b/Lib/test/test_ntpath.py @@ -306,6 +306,40 @@ class TestNtpath(unittest.TestCase): # dialogs (#4804) ntpath.sameopenfile(-1, -1) + def test_ismount(self): + self.assertTrue(ntpath.ismount("c:\\")) + self.assertTrue(ntpath.ismount("C:\\")) + self.assertTrue(ntpath.ismount("c:/")) + self.assertTrue(ntpath.ismount("C:/")) + self.assertTrue(ntpath.ismount("\\\\.\\c:\\")) + self.assertTrue(ntpath.ismount("\\\\.\\C:\\")) + + self.assertTrue(ntpath.ismount(b"c:\\")) + self.assertTrue(ntpath.ismount(b"C:\\")) + self.assertTrue(ntpath.ismount(b"c:/")) + self.assertTrue(ntpath.ismount(b"C:/")) + self.assertTrue(ntpath.ismount(b"\\\\.\\c:\\")) + self.assertTrue(ntpath.ismount(b"\\\\.\\C:\\")) + + with support.temp_dir() as d: + self.assertFalse(ntpath.ismount(d)) + + if sys.platform == "win32": + # + # Make sure the current folder isn't the root folder + # (or any other volume root). The drive-relative + # locations below cannot then refer to mount points + # + drive, path = ntpath.splitdrive(sys.executable) + with support.change_cwd(os.path.dirname(sys.executable)): + self.assertFalse(ntpath.ismount(drive.lower())) + self.assertFalse(ntpath.ismount(drive.upper())) + + self.assertTrue(ntpath.ismount("\\\\localhost\\c$")) + self.assertTrue(ntpath.ismount("\\\\localhost\\c$\\")) + + self.assertTrue(ntpath.ismount(b"\\\\localhost\\c$")) + self.assertTrue(ntpath.ismount(b"\\\\localhost\\c$\\")) class NtCommonTest(test_genericpath.CommonTest, unittest.TestCase): pathmodule = ntpath diff --git a/Lib/test/test_openpty.py b/Lib/test/test_openpty.py index 63843705b1..47851072c1 100644 --- a/Lib/test/test_openpty.py +++ b/Lib/test/test_openpty.py @@ -4,7 +4,7 @@ import os, unittest from test.support import run_unittest if not hasattr(os, "openpty"): - raise unittest.SkipTest("No openpty() available.") + raise unittest.SkipTest("os.openpty() not available.") class OpenptyTest(unittest.TestCase): diff --git a/Lib/test/test_operator.py b/Lib/test/test_operator.py index fa608b9a52..ab58a98365 100644 --- a/Lib/test/test_operator.py +++ b/Lib/test/test_operator.py @@ -1,8 +1,10 @@ -import operator import unittest from test import support +py_operator = support.import_fresh_module('operator', blocked=['_operator']) +c_operator = support.import_fresh_module('operator', fresh=['_operator']) + class Seq1: def __init__(self, lst): self.lst = lst @@ -32,8 +34,9 @@ class Seq2(object): return other * self.lst -class OperatorTestCase(unittest.TestCase): +class OperatorTestCase: def test_lt(self): + operator = self.module self.assertRaises(TypeError, operator.lt) self.assertRaises(TypeError, operator.lt, 1j, 2j) self.assertFalse(operator.lt(1, 0)) @@ -44,6 +47,7 @@ class OperatorTestCase(unittest.TestCase): self.assertTrue(operator.lt(1, 2.0)) def test_le(self): + operator = self.module self.assertRaises(TypeError, operator.le) self.assertRaises(TypeError, operator.le, 1j, 2j) self.assertFalse(operator.le(1, 0)) @@ -54,6 +58,7 @@ class OperatorTestCase(unittest.TestCase): self.assertTrue(operator.le(1, 2.0)) def test_eq(self): + operator = self.module class C(object): def __eq__(self, other): raise SyntaxError @@ -67,6 +72,7 @@ class OperatorTestCase(unittest.TestCase): self.assertFalse(operator.eq(1, 2.0)) def test_ne(self): + operator = self.module class C(object): def __ne__(self, other): raise SyntaxError @@ -80,6 +86,7 @@ class OperatorTestCase(unittest.TestCase): self.assertTrue(operator.ne(1, 2.0)) def test_ge(self): + operator = self.module self.assertRaises(TypeError, operator.ge) self.assertRaises(TypeError, operator.ge, 1j, 2j) self.assertTrue(operator.ge(1, 0)) @@ -90,6 +97,7 @@ class OperatorTestCase(unittest.TestCase): self.assertFalse(operator.ge(1, 2.0)) def test_gt(self): + operator = self.module self.assertRaises(TypeError, operator.gt) self.assertRaises(TypeError, operator.gt, 1j, 2j) self.assertTrue(operator.gt(1, 0)) @@ -100,22 +108,26 @@ class OperatorTestCase(unittest.TestCase): self.assertFalse(operator.gt(1, 2.0)) def test_abs(self): + operator = self.module self.assertRaises(TypeError, operator.abs) self.assertRaises(TypeError, operator.abs, None) self.assertEqual(operator.abs(-1), 1) self.assertEqual(operator.abs(1), 1) def test_add(self): + operator = self.module self.assertRaises(TypeError, operator.add) self.assertRaises(TypeError, operator.add, None, None) self.assertTrue(operator.add(3, 4) == 7) def test_bitwise_and(self): + operator = self.module self.assertRaises(TypeError, operator.and_) self.assertRaises(TypeError, operator.and_, None, None) self.assertTrue(operator.and_(0xf, 0xa) == 0xa) def test_concat(self): + operator = self.module self.assertRaises(TypeError, operator.concat) self.assertRaises(TypeError, operator.concat, None, None) self.assertTrue(operator.concat('py', 'thon') == 'python') @@ -125,12 +137,14 @@ class OperatorTestCase(unittest.TestCase): self.assertRaises(TypeError, operator.concat, 13, 29) def test_countOf(self): + operator = self.module self.assertRaises(TypeError, operator.countOf) self.assertRaises(TypeError, operator.countOf, None, None) self.assertTrue(operator.countOf([1, 2, 1, 3, 1, 4], 3) == 1) self.assertTrue(operator.countOf([1, 2, 1, 3, 1, 4], 5) == 0) def test_delitem(self): + operator = self.module a = [4, 3, 2, 1] self.assertRaises(TypeError, operator.delitem, a) self.assertRaises(TypeError, operator.delitem, a, None) @@ -138,33 +152,39 @@ class OperatorTestCase(unittest.TestCase): self.assertTrue(a == [4, 2, 1]) def test_floordiv(self): + operator = self.module self.assertRaises(TypeError, operator.floordiv, 5) self.assertRaises(TypeError, operator.floordiv, None, None) self.assertTrue(operator.floordiv(5, 2) == 2) def test_truediv(self): + operator = self.module self.assertRaises(TypeError, operator.truediv, 5) self.assertRaises(TypeError, operator.truediv, None, None) self.assertTrue(operator.truediv(5, 2) == 2.5) def test_getitem(self): + operator = self.module a = range(10) self.assertRaises(TypeError, operator.getitem) self.assertRaises(TypeError, operator.getitem, a, None) self.assertTrue(operator.getitem(a, 2) == 2) def test_indexOf(self): + operator = self.module self.assertRaises(TypeError, operator.indexOf) self.assertRaises(TypeError, operator.indexOf, None, None) self.assertTrue(operator.indexOf([4, 3, 2, 1], 3) == 1) self.assertRaises(ValueError, operator.indexOf, [4, 3, 2, 1], 0) def test_invert(self): + operator = self.module self.assertRaises(TypeError, operator.invert) self.assertRaises(TypeError, operator.invert, None) self.assertEqual(operator.inv(4), -5) def test_lshift(self): + operator = self.module self.assertRaises(TypeError, operator.lshift) self.assertRaises(TypeError, operator.lshift, None, 42) self.assertTrue(operator.lshift(5, 1) == 10) @@ -172,16 +192,19 @@ class OperatorTestCase(unittest.TestCase): self.assertRaises(ValueError, operator.lshift, 2, -1) def test_mod(self): + operator = self.module self.assertRaises(TypeError, operator.mod) self.assertRaises(TypeError, operator.mod, None, 42) self.assertTrue(operator.mod(5, 2) == 1) def test_mul(self): + operator = self.module self.assertRaises(TypeError, operator.mul) self.assertRaises(TypeError, operator.mul, None, None) self.assertTrue(operator.mul(5, 2) == 10) def test_neg(self): + operator = self.module self.assertRaises(TypeError, operator.neg) self.assertRaises(TypeError, operator.neg, None) self.assertEqual(operator.neg(5), -5) @@ -190,11 +213,13 @@ class OperatorTestCase(unittest.TestCase): self.assertEqual(operator.neg(-0), 0) def test_bitwise_or(self): + operator = self.module self.assertRaises(TypeError, operator.or_) self.assertRaises(TypeError, operator.or_, None, None) self.assertTrue(operator.or_(0xa, 0x5) == 0xf) def test_pos(self): + operator = self.module self.assertRaises(TypeError, operator.pos) self.assertRaises(TypeError, operator.pos, None) self.assertEqual(operator.pos(5), 5) @@ -203,14 +228,15 @@ class OperatorTestCase(unittest.TestCase): self.assertEqual(operator.pos(-0), 0) def test_pow(self): + operator = self.module self.assertRaises(TypeError, operator.pow) self.assertRaises(TypeError, operator.pow, None, None) self.assertEqual(operator.pow(3,5), 3**5) - self.assertEqual(operator.__pow__(3,5), 3**5) self.assertRaises(TypeError, operator.pow, 1) self.assertRaises(TypeError, operator.pow, 1, 2, 3) def test_rshift(self): + operator = self.module self.assertRaises(TypeError, operator.rshift) self.assertRaises(TypeError, operator.rshift, None, 42) self.assertTrue(operator.rshift(5, 1) == 2) @@ -218,12 +244,14 @@ class OperatorTestCase(unittest.TestCase): self.assertRaises(ValueError, operator.rshift, 2, -1) def test_contains(self): + operator = self.module self.assertRaises(TypeError, operator.contains) self.assertRaises(TypeError, operator.contains, None, None) self.assertTrue(operator.contains(range(4), 2)) self.assertFalse(operator.contains(range(4), 5)) def test_setitem(self): + operator = self.module a = list(range(3)) self.assertRaises(TypeError, operator.setitem, a) self.assertRaises(TypeError, operator.setitem, a, None, None) @@ -232,11 +260,13 @@ class OperatorTestCase(unittest.TestCase): self.assertRaises(IndexError, operator.setitem, a, 4, 2) def test_sub(self): + operator = self.module self.assertRaises(TypeError, operator.sub) self.assertRaises(TypeError, operator.sub, None, None) self.assertTrue(operator.sub(5, 2) == 3) def test_truth(self): + operator = self.module class C(object): def __bool__(self): raise SyntaxError @@ -248,11 +278,13 @@ class OperatorTestCase(unittest.TestCase): self.assertFalse(operator.truth([])) def test_bitwise_xor(self): + operator = self.module self.assertRaises(TypeError, operator.xor) self.assertRaises(TypeError, operator.xor, None, None) self.assertTrue(operator.xor(0xb, 0xc) == 0x7) def test_is(self): + operator = self.module a = b = 'xyzpdq' c = a[:3] + b[3:] self.assertRaises(TypeError, operator.is_) @@ -260,6 +292,7 @@ class OperatorTestCase(unittest.TestCase): self.assertFalse(operator.is_(a,c)) def test_is_not(self): + operator = self.module a = b = 'xyzpdq' c = a[:3] + b[3:] self.assertRaises(TypeError, operator.is_not) @@ -267,6 +300,7 @@ class OperatorTestCase(unittest.TestCase): self.assertTrue(operator.is_not(a,c)) def test_attrgetter(self): + operator = self.module class A: pass a = A() @@ -316,6 +350,7 @@ class OperatorTestCase(unittest.TestCase): self.assertEqual(f(a), ('arthur', 'thomas', 'johnson')) def test_itemgetter(self): + operator = self.module a = 'ABCDE' f = operator.itemgetter(2) self.assertEqual(f(a), 'C') @@ -350,12 +385,15 @@ class OperatorTestCase(unittest.TestCase): self.assertRaises(TypeError, operator.itemgetter(2, 'x', 5), data) def test_methodcaller(self): + operator = self.module self.assertRaises(TypeError, operator.methodcaller) class A: def foo(self, *args, **kwds): return args[0] + args[1] def bar(self, f=42): return f + def baz(*args, **kwds): + return kwds['name'], kwds['self'] a = A() f = operator.methodcaller('foo') self.assertRaises(IndexError, f, a) @@ -366,8 +404,11 @@ class OperatorTestCase(unittest.TestCase): self.assertRaises(TypeError, f, a, a) f = operator.methodcaller('bar', f=5) self.assertEqual(f(a), 5) + f = operator.methodcaller('baz', name='spam', self='eggs') + self.assertEqual(f(a), ('spam', 'eggs')) def test_inplace(self): + operator = self.module class C(object): def __iadd__ (self, other): return "iadd" def __iand__ (self, other): return "iand" @@ -396,37 +437,48 @@ class OperatorTestCase(unittest.TestCase): self.assertEqual(operator.itruediv (c, 5), "itruediv") self.assertEqual(operator.ixor (c, 5), "ixor") self.assertEqual(operator.iconcat (c, c), "iadd") - self.assertEqual(operator.__iadd__ (c, 5), "iadd") - self.assertEqual(operator.__iand__ (c, 5), "iand") - self.assertEqual(operator.__ifloordiv__(c, 5), "ifloordiv") - self.assertEqual(operator.__ilshift__ (c, 5), "ilshift") - self.assertEqual(operator.__imod__ (c, 5), "imod") - self.assertEqual(operator.__imul__ (c, 5), "imul") - self.assertEqual(operator.__ior__ (c, 5), "ior") - self.assertEqual(operator.__ipow__ (c, 5), "ipow") - self.assertEqual(operator.__irshift__ (c, 5), "irshift") - self.assertEqual(operator.__isub__ (c, 5), "isub") - self.assertEqual(operator.__itruediv__ (c, 5), "itruediv") - self.assertEqual(operator.__ixor__ (c, 5), "ixor") - self.assertEqual(operator.__iconcat__ (c, c), "iadd") - -def test_main(verbose=None): - import sys - test_classes = ( - OperatorTestCase, - ) - - support.run_unittest(*test_classes) - - # verify reference counting - if verbose and hasattr(sys, "gettotalrefcount"): - import gc - counts = [None] * 5 - for i in range(len(counts)): - support.run_unittest(*test_classes) - gc.collect() - counts[i] = sys.gettotalrefcount() - print(counts) + + def test_length_hint(self): + operator = self.module + class X(object): + def __init__(self, value): + self.value = value + + def __length_hint__(self): + if type(self.value) is type: + raise self.value + else: + return self.value + + self.assertEqual(operator.length_hint([], 2), 0) + self.assertEqual(operator.length_hint(iter([1, 2, 3])), 3) + + self.assertEqual(operator.length_hint(X(2)), 2) + self.assertEqual(operator.length_hint(X(NotImplemented), 4), 4) + self.assertEqual(operator.length_hint(X(TypeError), 12), 12) + with self.assertRaises(TypeError): + operator.length_hint(X("abc")) + with self.assertRaises(ValueError): + operator.length_hint(X(-2)) + with self.assertRaises(LookupError): + operator.length_hint(X(LookupError)) + + def test_dunder_is_original(self): + operator = self.module + + names = [name for name in dir(operator) if not name.startswith('_')] + for name in names: + orig = getattr(operator, name) + dunder = getattr(operator, '__' + name.strip('_') + '__', None) + if dunder: + self.assertIs(dunder, orig) + +class PyOperatorTestCase(OperatorTestCase, unittest.TestCase): + module = py_operator + +@unittest.skipUnless(c_operator, 'requires _operator') +class COperatorTestCase(OperatorTestCase, unittest.TestCase): + module = c_operator if __name__ == "__main__": - test_main(verbose=True) + unittest.main() diff --git a/Lib/test/test_optparse.py b/Lib/test/test_optparse.py index eaee504ac8..3c8c6124a2 100644 --- a/Lib/test/test_optparse.py +++ b/Lib/test/test_optparse.py @@ -730,7 +730,7 @@ class TestStandard(BaseTest): def test_short_and_long_option_split(self): self.assertParseOK(["-a", "xyz", "--foo", "bar"], {'a': 'xyz', 'boo': None, 'foo': ["bar"]}, - []), + []) def test_short_option_split_long_option_append(self): self.assertParseOK(["--foo=bar", "-b", "123", "--foo", "baz"], @@ -740,15 +740,15 @@ class TestStandard(BaseTest): def test_short_option_split_one_positional_arg(self): self.assertParseOK(["-a", "foo", "bar"], {'a': "foo", 'boo': None, 'foo': None}, - ["bar"]), + ["bar"]) def test_short_option_consumes_separator(self): self.assertParseOK(["-a", "--", "foo", "bar"], {'a': "--", 'boo': None, 'foo': None}, - ["foo", "bar"]), + ["foo", "bar"]) self.assertParseOK(["-a", "--", "--foo", "bar"], {'a': "--", 'boo': None, 'foo': ["bar"]}, - []), + []) def test_short_option_joined_and_separator(self): self.assertParseOK(["-ab", "--", "--foo", "bar"], diff --git a/Lib/test/test_os.py b/Lib/test/test_os.py index d70a0aeaf8..50240939d8 100644 --- a/Lib/test/test_os.py +++ b/Lib/test/test_os.py @@ -24,6 +24,9 @@ import itertools import stat import locale import codecs +import decimal +import fractions +import pickle try: import threading except ImportError: @@ -32,6 +35,10 @@ try: import resource except ImportError: resource = None +try: + import fcntl +except ImportError: + fcntl = None from test.script_helper import assert_python_ok @@ -256,6 +263,13 @@ class StatAttributeTests(unittest.TestCase): warnings.simplefilter("ignore", DeprecationWarning) self.check_stat_attributes(fname) + def test_stat_result_pickle(self): + result = os.stat(self.fname) + p = pickle.dumps(result) + self.assertIn(b'\x03cos\nstat_result\n', p) + unpickled = pickle.loads(p) + self.assertEqual(result, unpickled) + @unittest.skipUnless(hasattr(os, 'statvfs'), 'test needs os.statvfs()') def test_statvfs_attributes(self): try: @@ -263,7 +277,7 @@ class StatAttributeTests(unittest.TestCase): except OSError as e: # On AtheOS, glibc always returns ENOSYS if e.errno == errno.ENOSYS: - self.skipTest('glibc always returns ENOSYS on AtheOS') + self.skipTest('os.statvfs() failed with ENOSYS') # Make sure direct access works self.assertEqual(result.f_bfree, result[3]) @@ -300,6 +314,21 @@ class StatAttributeTests(unittest.TestCase): except TypeError: pass + @unittest.skipUnless(hasattr(os, 'statvfs'), + "need os.statvfs()") + def test_statvfs_result_pickle(self): + try: + result = os.statvfs(self.fname) + except OSError as e: + # On AtheOS, glibc always returns ENOSYS + if e.errno == errno.ENOSYS: + self.skipTest('os.statvfs() failed with ENOSYS') + + p = pickle.dumps(result) + self.assertIn(b'\x03cos\nstatvfs_result\n', p) + unpickled = pickle.loads(p) + self.assertEqual(result, unpickled) + def test_utime_dir(self): delta = 1000000 st = os.stat(support.TESTFN) @@ -478,9 +507,9 @@ class StatAttributeTests(unittest.TestCase): # Verify that an open file can be stat'ed try: os.stat(r"c:\pagefile.sys") - except WindowsError as e: - if e.errno == 2: # file does not exist; cannot run test - self.skipTest(r'c:\pagefile.sys does not exist') + except FileNotFoundError: + self.skipTest(r'c:\pagefile.sys does not exist') + except OSError as e: self.fail("Could not stat pagefile.sys") @unittest.skipUnless(sys.platform == "win32", "Win32 specific tests") @@ -876,6 +905,17 @@ class MakedirTests(unittest.TestCase): os.makedirs(path, mode=mode, exist_ok=True) os.umask(old_mask) + @unittest.skipUnless(hasattr(os, 'chown'), 'test needs os.chown') + def test_chown_uid_gid_arguments_must_be_index(self): + stat = os.stat(support.TESTFN) + uid = stat.st_uid + gid = stat.st_gid + for value in (-1.0, -1j, decimal.Decimal(-1), fractions.Fraction(-2, 2)): + self.assertRaises(TypeError, os.chown, support.TESTFN, value, gid) + self.assertRaises(TypeError, os.chown, support.TESTFN, uid, value) + self.assertIsNone(os.chown(support.TESTFN, uid, gid)) + self.assertIsNone(os.chown(support.TESTFN, -1, -1)) + def test_exist_ok_s_isgid_directory(self): path = os.path.join(support.TESTFN, 'dir1') S_ISGID = stat.S_ISGID @@ -1133,27 +1173,27 @@ class ExecTests(unittest.TestCase): @unittest.skipUnless(sys.platform == "win32", "Win32 specific tests") class Win32ErrorTests(unittest.TestCase): def test_rename(self): - self.assertRaises(WindowsError, os.rename, support.TESTFN, support.TESTFN+".bak") + self.assertRaises(OSError, os.rename, support.TESTFN, support.TESTFN+".bak") def test_remove(self): - self.assertRaises(WindowsError, os.remove, support.TESTFN) + self.assertRaises(OSError, os.remove, support.TESTFN) def test_chdir(self): - self.assertRaises(WindowsError, os.chdir, support.TESTFN) + self.assertRaises(OSError, os.chdir, support.TESTFN) def test_mkdir(self): f = open(support.TESTFN, "w") try: - self.assertRaises(WindowsError, os.mkdir, support.TESTFN) + self.assertRaises(OSError, os.mkdir, support.TESTFN) finally: f.close() os.unlink(support.TESTFN) def test_utime(self): - self.assertRaises(WindowsError, os.utime, support.TESTFN, None) + self.assertRaises(OSError, os.utime, support.TESTFN, None) def test_chmod(self): - self.assertRaises(WindowsError, os.chmod, support.TESTFN, 0) + self.assertRaises(OSError, os.chmod, support.TESTFN, 0) class TestInvalidFD(unittest.TestCase): singles = ["fchdir", "dup", "fdopen", "fdatasync", "fstat", @@ -1287,41 +1327,57 @@ class PosixUidGidTests(unittest.TestCase): @unittest.skipUnless(hasattr(os, 'setuid'), 'test needs os.setuid()') def test_setuid(self): if os.getuid() != 0: - self.assertRaises(os.error, os.setuid, 0) + self.assertRaises(OSError, os.setuid, 0) self.assertRaises(OverflowError, os.setuid, 1<<32) @unittest.skipUnless(hasattr(os, 'setgid'), 'test needs os.setgid()') def test_setgid(self): if os.getuid() != 0 and not HAVE_WHEEL_GROUP: - self.assertRaises(os.error, os.setgid, 0) + self.assertRaises(OSError, os.setgid, 0) self.assertRaises(OverflowError, os.setgid, 1<<32) @unittest.skipUnless(hasattr(os, 'seteuid'), 'test needs os.seteuid()') def test_seteuid(self): if os.getuid() != 0: - self.assertRaises(os.error, os.seteuid, 0) + self.assertRaises(OSError, os.seteuid, 0) self.assertRaises(OverflowError, os.seteuid, 1<<32) @unittest.skipUnless(hasattr(os, 'setegid'), 'test needs os.setegid()') def test_setegid(self): if os.getuid() != 0 and not HAVE_WHEEL_GROUP: - self.assertRaises(os.error, os.setegid, 0) + self.assertRaises(OSError, os.setegid, 0) self.assertRaises(OverflowError, os.setegid, 1<<32) @unittest.skipUnless(hasattr(os, 'setreuid'), 'test needs os.setreuid()') def test_setreuid(self): if os.getuid() != 0: - self.assertRaises(os.error, os.setreuid, 0, 0) + self.assertRaises(OSError, os.setreuid, 0, 0) self.assertRaises(OverflowError, os.setreuid, 1<<32, 0) self.assertRaises(OverflowError, os.setreuid, 0, 1<<32) + @unittest.skipUnless(hasattr(os, 'setreuid'), 'test needs os.setreuid()') + def test_setreuid_neg1(self): + # Needs to accept -1. We run this in a subprocess to avoid + # altering the test runner's process state (issue8045). + subprocess.check_call([ + sys.executable, '-c', + 'import os,sys;os.setreuid(-1,-1);sys.exit(0)']) + @unittest.skipUnless(hasattr(os, 'setregid'), 'test needs os.setregid()') def test_setregid(self): if os.getuid() != 0 and not HAVE_WHEEL_GROUP: - self.assertRaises(os.error, os.setregid, 0, 0) + self.assertRaises(OSError, os.setregid, 0, 0) self.assertRaises(OverflowError, os.setregid, 1<<32, 0) self.assertRaises(OverflowError, os.setregid, 0, 1<<32) + @unittest.skipUnless(hasattr(os, 'setregid'), 'test needs os.setregid()') + def test_setregid_neg1(self): + # Needs to accept -1. We run this in a subprocess to avoid + # altering the test runner's process state (issue8045). + subprocess.check_call([ + sys.executable, '-c', + 'import os,sys;os.setregid(-1,-1);sys.exit(0)']) + @unittest.skipIf(sys.platform == "win32", "Posix specific tests") class Pep383Tests(unittest.TestCase): def setUp(self): @@ -1511,6 +1567,52 @@ class Win32KillTests(unittest.TestCase): @unittest.skipUnless(sys.platform == "win32", "Win32 specific tests") +class Win32ListdirTests(unittest.TestCase): + """Test listdir on Windows.""" + + def setUp(self): + self.created_paths = [] + for i in range(2): + dir_name = 'SUB%d' % i + dir_path = os.path.join(support.TESTFN, dir_name) + file_name = 'FILE%d' % i + file_path = os.path.join(support.TESTFN, file_name) + os.makedirs(dir_path) + with open(file_path, 'w') as f: + f.write("I'm %s and proud of it. Blame test_os.\n" % file_path) + self.created_paths.extend([dir_name, file_name]) + self.created_paths.sort() + + def tearDown(self): + shutil.rmtree(support.TESTFN) + + def test_listdir_no_extended_path(self): + """Test when the path is not an "extended" path.""" + # unicode + self.assertEqual( + sorted(os.listdir(support.TESTFN)), + self.created_paths) + # bytes + self.assertEqual( + sorted(os.listdir(os.fsencode(support.TESTFN))), + [os.fsencode(path) for path in self.created_paths]) + + def test_listdir_extended_path(self): + """Test when the path starts with '\\\\?\\'.""" + # See: http://msdn.microsoft.com/en-us/library/windows/desktop/aa365247(v=vs.85).aspx#maxpath + # unicode + path = '\\\\?\\' + os.path.abspath(support.TESTFN) + self.assertEqual( + sorted(os.listdir(path)), + self.created_paths) + # bytes + path = b'\\\\?\\' + os.fsencode(os.path.abspath(support.TESTFN)) + self.assertEqual( + sorted(os.listdir(path)), + [os.fsencode(path) for path in self.created_paths]) + + +@unittest.skipUnless(sys.platform == "win32", "Win32 specific tests") @support.skip_unless_symlink class Win32SymlinkTests(unittest.TestCase): filelink = 'filelinktest' @@ -2175,6 +2277,197 @@ class TermsizeTests(unittest.TestCase): self.assertEqual(expected, actual) +class OSErrorTests(unittest.TestCase): + def setUp(self): + class Str(str): + pass + + self.bytes_filenames = [] + self.unicode_filenames = [] + if support.TESTFN_UNENCODABLE is not None: + decoded = support.TESTFN_UNENCODABLE + else: + decoded = support.TESTFN + self.unicode_filenames.append(decoded) + self.unicode_filenames.append(Str(decoded)) + if support.TESTFN_UNDECODABLE is not None: + encoded = support.TESTFN_UNDECODABLE + else: + encoded = os.fsencode(support.TESTFN) + self.bytes_filenames.append(encoded) + self.bytes_filenames.append(memoryview(encoded)) + + self.filenames = self.bytes_filenames + self.unicode_filenames + + def test_oserror_filename(self): + funcs = [ + (self.filenames, os.chdir,), + (self.filenames, os.chmod, 0o777), + (self.filenames, os.lstat,), + (self.filenames, os.open, os.O_RDONLY), + (self.filenames, os.rmdir,), + (self.filenames, os.stat,), + (self.filenames, os.unlink,), + ] + if sys.platform == "win32": + funcs.extend(( + (self.bytes_filenames, os.rename, b"dst"), + (self.bytes_filenames, os.replace, b"dst"), + (self.unicode_filenames, os.rename, "dst"), + (self.unicode_filenames, os.replace, "dst"), + # Issue #16414: Don't test undecodable names with listdir() + # because of a Windows bug. + # + # With the ANSI code page 932, os.listdir(b'\xe7') return an + # empty list (instead of failing), whereas os.listdir(b'\xff') + # raises a FileNotFoundError. It looks like a Windows bug: + # b'\xe7' directory does not exist, FindFirstFileA(b'\xe7') + # fails with ERROR_FILE_NOT_FOUND (2), instead of + # ERROR_PATH_NOT_FOUND (3). + (self.unicode_filenames, os.listdir,), + )) + else: + funcs.extend(( + (self.filenames, os.listdir,), + (self.filenames, os.rename, "dst"), + (self.filenames, os.replace, "dst"), + )) + if hasattr(os, "chown"): + funcs.append((self.filenames, os.chown, 0, 0)) + if hasattr(os, "lchown"): + funcs.append((self.filenames, os.lchown, 0, 0)) + if hasattr(os, "truncate"): + funcs.append((self.filenames, os.truncate, 0)) + if hasattr(os, "chflags"): + funcs.append((self.filenames, os.chflags, 0)) + if hasattr(os, "lchflags"): + funcs.append((self.filenames, os.lchflags, 0)) + if hasattr(os, "chroot"): + funcs.append((self.filenames, os.chroot,)) + if hasattr(os, "link"): + if sys.platform == "win32": + funcs.append((self.bytes_filenames, os.link, b"dst")) + funcs.append((self.unicode_filenames, os.link, "dst")) + else: + funcs.append((self.filenames, os.link, "dst")) + if hasattr(os, "listxattr"): + funcs.extend(( + (self.filenames, os.listxattr,), + (self.filenames, os.getxattr, "user.test"), + (self.filenames, os.setxattr, "user.test", b'user'), + (self.filenames, os.removexattr, "user.test"), + )) + if hasattr(os, "lchmod"): + funcs.append((self.filenames, os.lchmod, 0o777)) + if hasattr(os, "readlink"): + if sys.platform == "win32": + funcs.append((self.unicode_filenames, os.readlink,)) + else: + funcs.append((self.filenames, os.readlink,)) + + for filenames, func, *func_args in funcs: + for name in filenames: + try: + func(name, *func_args) + except OSError as err: + self.assertIs(err.filename, name) + else: + self.fail("No exception thrown by {}".format(func)) + +class CPUCountTests(unittest.TestCase): + def test_cpu_count(self): + cpus = os.cpu_count() + if cpus is not None: + self.assertIsInstance(cpus, int) + self.assertGreater(cpus, 0) + else: + self.skipTest("Could not determine the number of CPUs") + + +class FDInheritanceTests(unittest.TestCase): + def test_get_set_inheritable(self): + fd = os.open(__file__, os.O_RDONLY) + self.addCleanup(os.close, fd) + self.assertEqual(os.get_inheritable(fd), False) + + os.set_inheritable(fd, True) + self.assertEqual(os.get_inheritable(fd), True) + + @unittest.skipIf(fcntl is None, "need fcntl") + def test_get_inheritable_cloexec(self): + fd = os.open(__file__, os.O_RDONLY) + self.addCleanup(os.close, fd) + self.assertEqual(os.get_inheritable(fd), False) + + # clear FD_CLOEXEC flag + flags = fcntl.fcntl(fd, fcntl.F_GETFD) + flags &= ~fcntl.FD_CLOEXEC + fcntl.fcntl(fd, fcntl.F_SETFD, flags) + + self.assertEqual(os.get_inheritable(fd), True) + + @unittest.skipIf(fcntl is None, "need fcntl") + def test_set_inheritable_cloexec(self): + fd = os.open(__file__, os.O_RDONLY) + self.addCleanup(os.close, fd) + self.assertEqual(fcntl.fcntl(fd, fcntl.F_GETFD) & fcntl.FD_CLOEXEC, + fcntl.FD_CLOEXEC) + + os.set_inheritable(fd, True) + self.assertEqual(fcntl.fcntl(fd, fcntl.F_GETFD) & fcntl.FD_CLOEXEC, + 0) + + def test_open(self): + fd = os.open(__file__, os.O_RDONLY) + self.addCleanup(os.close, fd) + self.assertEqual(os.get_inheritable(fd), False) + + @unittest.skipUnless(hasattr(os, 'pipe'), "need os.pipe()") + def test_pipe(self): + rfd, wfd = os.pipe() + self.addCleanup(os.close, rfd) + self.addCleanup(os.close, wfd) + self.assertEqual(os.get_inheritable(rfd), False) + self.assertEqual(os.get_inheritable(wfd), False) + + def test_dup(self): + fd1 = os.open(__file__, os.O_RDONLY) + self.addCleanup(os.close, fd1) + + fd2 = os.dup(fd1) + self.addCleanup(os.close, fd2) + self.assertEqual(os.get_inheritable(fd2), False) + + @unittest.skipUnless(hasattr(os, 'dup2'), "need os.dup2()") + def test_dup2(self): + fd = os.open(__file__, os.O_RDONLY) + self.addCleanup(os.close, fd) + + # inheritable by default + fd2 = os.open(__file__, os.O_RDONLY) + try: + os.dup2(fd, fd2) + self.assertEqual(os.get_inheritable(fd2), True) + finally: + os.close(fd2) + + # force non-inheritable + fd3 = os.open(__file__, os.O_RDONLY) + try: + os.dup2(fd, fd3, inheritable=False) + self.assertEqual(os.get_inheritable(fd3), False) + finally: + os.close(fd3) + + @unittest.skipUnless(hasattr(os, 'openpty'), "need os.openpty()") + def test_openpty(self): + master_fd, slave_fd = os.openpty() + self.addCleanup(os.close, master_fd) + self.addCleanup(os.close, slave_fd) + self.assertEqual(os.get_inheritable(master_fd), False) + self.assertEqual(os.get_inheritable(slave_fd), False) + + @support.reap_threads def test_main(): support.run_unittest( @@ -2192,6 +2485,7 @@ def test_main(): PosixUidGidTests, Pep383Tests, Win32KillTests, + Win32ListdirTests, Win32SymlinkTests, NonLocalSymlinkTests, FSEncodingTests, @@ -2204,7 +2498,10 @@ def test_main(): ExtendedAttributeTests, Win32DeprecatedBytesAPI, TermsizeTests, + OSErrorTests, RemoveDirsTests, + CPUCountTests, + FDInheritanceTests, ) if __name__ == "__main__": diff --git a/Lib/test/test_ossaudiodev.py b/Lib/test/test_ossaudiodev.py index 3908a0506e..c9e2a24767 100644 --- a/Lib/test/test_ossaudiodev.py +++ b/Lib/test/test_ossaudiodev.py @@ -44,7 +44,7 @@ class OSSAudioDevTests(unittest.TestCase): def play_sound_file(self, data, rate, ssize, nchannels): try: dsp = ossaudiodev.open('w') - except IOError as msg: + except OSError as msg: if msg.args[0] in (errno.EACCES, errno.ENOENT, errno.ENODEV, errno.EBUSY): raise unittest.SkipTest(msg) @@ -190,7 +190,7 @@ class OSSAudioDevTests(unittest.TestCase): def test_main(): try: dsp = ossaudiodev.open('w') - except (ossaudiodev.error, IOError) as msg: + except (ossaudiodev.error, OSError) as msg: if msg.args[0] in (errno.EACCES, errno.ENOENT, errno.ENODEV, errno.EBUSY): raise unittest.SkipTest(msg) diff --git a/Lib/test/test_pathlib.py b/Lib/test/test_pathlib.py new file mode 100644 index 0000000000..2268f94a72 --- /dev/null +++ b/Lib/test/test_pathlib.py @@ -0,0 +1,1840 @@ +import collections +import io +import os +import errno +import pathlib +import pickle +import shutil +import socket +import stat +import sys +import tempfile +import unittest +from contextlib import contextmanager + +from test import support +TESTFN = support.TESTFN + +try: + import grp, pwd +except ImportError: + grp = pwd = None + + +class _BaseFlavourTest(object): + + def _check_parse_parts(self, arg, expected): + f = self.flavour.parse_parts + sep = self.flavour.sep + altsep = self.flavour.altsep + actual = f([x.replace('/', sep) for x in arg]) + self.assertEqual(actual, expected) + if altsep: + actual = f([x.replace('/', altsep) for x in arg]) + self.assertEqual(actual, expected) + + def test_parse_parts_common(self): + check = self._check_parse_parts + sep = self.flavour.sep + # Unanchored parts + check([], ('', '', [])) + check(['a'], ('', '', ['a'])) + check(['a/'], ('', '', ['a'])) + check(['a', 'b'], ('', '', ['a', 'b'])) + # Expansion + check(['a/b'], ('', '', ['a', 'b'])) + check(['a/b/'], ('', '', ['a', 'b'])) + check(['a', 'b/c', 'd'], ('', '', ['a', 'b', 'c', 'd'])) + # Collapsing and stripping excess slashes + check(['a', 'b//c', 'd'], ('', '', ['a', 'b', 'c', 'd'])) + check(['a', 'b/c/', 'd'], ('', '', ['a', 'b', 'c', 'd'])) + # Eliminating standalone dots + check(['.'], ('', '', [])) + check(['.', '.', 'b'], ('', '', ['b'])) + check(['a', '.', 'b'], ('', '', ['a', 'b'])) + check(['a', '.', '.'], ('', '', ['a'])) + # The first part is anchored + check(['/a/b'], ('', sep, [sep, 'a', 'b'])) + check(['/a', 'b'], ('', sep, [sep, 'a', 'b'])) + check(['/a/', 'b'], ('', sep, [sep, 'a', 'b'])) + # Ignoring parts before an anchored part + check(['a', '/b', 'c'], ('', sep, [sep, 'b', 'c'])) + check(['a', '/b', '/c'], ('', sep, [sep, 'c'])) + + +class PosixFlavourTest(_BaseFlavourTest, unittest.TestCase): + flavour = pathlib._posix_flavour + + def test_parse_parts(self): + check = self._check_parse_parts + # Collapsing of excess leading slashes, except for the double-slash + # special case. + check(['//a', 'b'], ('', '//', ['//', 'a', 'b'])) + check(['///a', 'b'], ('', '/', ['/', 'a', 'b'])) + check(['////a', 'b'], ('', '/', ['/', 'a', 'b'])) + # Paths which look like NT paths aren't treated specially + check(['c:a'], ('', '', ['c:a'])) + check(['c:\\a'], ('', '', ['c:\\a'])) + check(['\\a'], ('', '', ['\\a'])) + + def test_splitroot(self): + f = self.flavour.splitroot + self.assertEqual(f(''), ('', '', '')) + self.assertEqual(f('a'), ('', '', 'a')) + self.assertEqual(f('a/b'), ('', '', 'a/b')) + self.assertEqual(f('a/b/'), ('', '', 'a/b/')) + self.assertEqual(f('/a'), ('', '/', 'a')) + self.assertEqual(f('/a/b'), ('', '/', 'a/b')) + self.assertEqual(f('/a/b/'), ('', '/', 'a/b/')) + # The root is collapsed when there are redundant slashes + # except when there are exactly two leading slashes, which + # is a special case in POSIX. + self.assertEqual(f('//a'), ('', '//', 'a')) + self.assertEqual(f('///a'), ('', '/', 'a')) + self.assertEqual(f('///a/b'), ('', '/', 'a/b')) + # Paths which look like NT paths aren't treated specially + self.assertEqual(f('c:/a/b'), ('', '', 'c:/a/b')) + self.assertEqual(f('\\/a/b'), ('', '', '\\/a/b')) + self.assertEqual(f('\\a\\b'), ('', '', '\\a\\b')) + + +class NTFlavourTest(_BaseFlavourTest, unittest.TestCase): + flavour = pathlib._windows_flavour + + def test_parse_parts(self): + check = self._check_parse_parts + # First part is anchored + check(['c:'], ('c:', '', ['c:'])) + check(['c:\\'], ('c:', '\\', ['c:\\'])) + check(['\\'], ('', '\\', ['\\'])) + check(['c:a'], ('c:', '', ['c:', 'a'])) + check(['c:\\a'], ('c:', '\\', ['c:\\', 'a'])) + check(['\\a'], ('', '\\', ['\\', 'a'])) + # UNC paths + check(['\\\\a\\b'], ('\\\\a\\b', '\\', ['\\\\a\\b\\'])) + check(['\\\\a\\b\\'], ('\\\\a\\b', '\\', ['\\\\a\\b\\'])) + check(['\\\\a\\b\\c'], ('\\\\a\\b', '\\', ['\\\\a\\b\\', 'c'])) + # Second part is anchored, so that the first part is ignored + check(['a', 'Z:b', 'c'], ('Z:', '', ['Z:', 'b', 'c'])) + check(['a', 'Z:\\b', 'c'], ('Z:', '\\', ['Z:\\', 'b', 'c'])) + check(['a', '\\b', 'c'], ('', '\\', ['\\', 'b', 'c'])) + # UNC paths + check(['a', '\\\\b\\c', 'd'], ('\\\\b\\c', '\\', ['\\\\b\\c\\', 'd'])) + # Collapsing and stripping excess slashes + check(['a', 'Z:\\\\b\\\\c\\', 'd\\'], ('Z:', '\\', ['Z:\\', 'b', 'c', 'd'])) + # UNC paths + check(['a', '\\\\b\\c\\\\', 'd'], ('\\\\b\\c', '\\', ['\\\\b\\c\\', 'd'])) + # Extended paths + check(['\\\\?\\c:\\'], ('\\\\?\\c:', '\\', ['\\\\?\\c:\\'])) + check(['\\\\?\\c:\\a'], ('\\\\?\\c:', '\\', ['\\\\?\\c:\\', 'a'])) + # Extended UNC paths (format is "\\?\UNC\server\share") + check(['\\\\?\\UNC\\b\\c'], ('\\\\?\\UNC\\b\\c', '\\', ['\\\\?\\UNC\\b\\c\\'])) + check(['\\\\?\\UNC\\b\\c\\d'], ('\\\\?\\UNC\\b\\c', '\\', ['\\\\?\\UNC\\b\\c\\', 'd'])) + + def test_splitroot(self): + f = self.flavour.splitroot + self.assertEqual(f(''), ('', '', '')) + self.assertEqual(f('a'), ('', '', 'a')) + self.assertEqual(f('a\\b'), ('', '', 'a\\b')) + self.assertEqual(f('\\a'), ('', '\\', 'a')) + self.assertEqual(f('\\a\\b'), ('', '\\', 'a\\b')) + self.assertEqual(f('c:a\\b'), ('c:', '', 'a\\b')) + self.assertEqual(f('c:\\a\\b'), ('c:', '\\', 'a\\b')) + # Redundant slashes in the root are collapsed + self.assertEqual(f('\\\\a'), ('', '\\', 'a')) + self.assertEqual(f('\\\\\\a/b'), ('', '\\', 'a/b')) + self.assertEqual(f('c:\\\\a'), ('c:', '\\', 'a')) + self.assertEqual(f('c:\\\\\\a/b'), ('c:', '\\', 'a/b')) + # Valid UNC paths + self.assertEqual(f('\\\\a\\b'), ('\\\\a\\b', '\\', '')) + self.assertEqual(f('\\\\a\\b\\'), ('\\\\a\\b', '\\', '')) + self.assertEqual(f('\\\\a\\b\\c\\d'), ('\\\\a\\b', '\\', 'c\\d')) + # These are non-UNC paths (according to ntpath.py and test_ntpath) + # However, command.com says such paths are invalid, so it's + # difficult to know what the right semantics are + self.assertEqual(f('\\\\\\a\\b'), ('', '\\', 'a\\b')) + self.assertEqual(f('\\\\a'), ('', '\\', 'a')) + + +# +# Tests for the pure classes +# + +class _BasePurePathTest(object): + + # keys are canonical paths, values are list of tuples of arguments + # supposed to produce equal paths + equivalences = { + 'a/b': [ + ('a', 'b'), ('a/', 'b'), ('a', 'b/'), ('a/', 'b/'), + ('a/b/',), ('a//b',), ('a//b//',), + # empty components get removed + ('', 'a', 'b'), ('a', '', 'b'), ('a', 'b', ''), + ], + '/b/c/d': [ + ('a', '/b/c', 'd'), ('a', '///b//c', 'd/'), + ('/a', '/b/c', 'd'), + # empty components get removed + ('/', 'b', '', 'c/d'), ('/', '', 'b/c/d'), ('', '/b/c/d'), + ], + } + + def setUp(self): + p = self.cls('a') + self.flavour = p._flavour + self.sep = self.flavour.sep + self.altsep = self.flavour.altsep + + def test_constructor_common(self): + P = self.cls + p = P('a') + self.assertIsInstance(p, P) + P('a', 'b', 'c') + P('/a', 'b', 'c') + P('a/b/c') + P('/a/b/c') + self.assertEqual(P(P('a')), P('a')) + self.assertEqual(P(P('a'), 'b'), P('a/b')) + self.assertEqual(P(P('a'), P('b')), P('a/b')) + + def test_join_common(self): + P = self.cls + p = P('a/b') + pp = p.joinpath('c') + self.assertEqual(pp, P('a/b/c')) + self.assertIs(type(pp), type(p)) + pp = p.joinpath('c', 'd') + self.assertEqual(pp, P('a/b/c/d')) + pp = p.joinpath(P('c')) + self.assertEqual(pp, P('a/b/c')) + pp = p.joinpath('/c') + self.assertEqual(pp, P('/c')) + + def test_div_common(self): + # Basically the same as joinpath() + P = self.cls + p = P('a/b') + pp = p / 'c' + self.assertEqual(pp, P('a/b/c')) + self.assertIs(type(pp), type(p)) + pp = p / 'c/d' + self.assertEqual(pp, P('a/b/c/d')) + pp = p / 'c' / 'd' + self.assertEqual(pp, P('a/b/c/d')) + pp = 'c' / p / 'd' + self.assertEqual(pp, P('c/a/b/d')) + pp = p / P('c') + self.assertEqual(pp, P('a/b/c')) + pp = p/ '/c' + self.assertEqual(pp, P('/c')) + + def _check_str(self, expected, args): + p = self.cls(*args) + self.assertEqual(str(p), expected.replace('/', self.sep)) + + def test_str_common(self): + # Canonicalized paths roundtrip + for pathstr in ('a', 'a/b', 'a/b/c', '/', '/a/b', '/a/b/c'): + self._check_str(pathstr, (pathstr,)) + # Special case for the empty path + self._check_str('.', ('',)) + # Other tests for str() are in test_equivalences() + + def test_as_posix_common(self): + P = self.cls + for pathstr in ('a', 'a/b', 'a/b/c', '/', '/a/b', '/a/b/c'): + self.assertEqual(P(pathstr).as_posix(), pathstr) + # Other tests for as_posix() are in test_equivalences() + + def test_as_bytes_common(self): + sep = os.fsencode(self.sep) + P = self.cls + self.assertEqual(bytes(P('a/b')), b'a' + sep + b'b') + + def test_as_uri_common(self): + P = self.cls + with self.assertRaises(ValueError): + P('a').as_uri() + with self.assertRaises(ValueError): + P().as_uri() + + def test_repr_common(self): + for pathstr in ('a', 'a/b', 'a/b/c', '/', '/a/b', '/a/b/c'): + p = self.cls(pathstr) + clsname = p.__class__.__name__ + r = repr(p) + # The repr() is in the form ClassName("forward-slashes path") + self.assertTrue(r.startswith(clsname + '('), r) + self.assertTrue(r.endswith(')'), r) + inner = r[len(clsname) + 1 : -1] + self.assertEqual(eval(inner), p.as_posix()) + # The repr() roundtrips + q = eval(r, pathlib.__dict__) + self.assertIs(q.__class__, p.__class__) + self.assertEqual(q, p) + self.assertEqual(repr(q), r) + + def test_eq_common(self): + P = self.cls + self.assertEqual(P('a/b'), P('a/b')) + self.assertEqual(P('a/b'), P('a', 'b')) + self.assertNotEqual(P('a/b'), P('a')) + self.assertNotEqual(P('a/b'), P('/a/b')) + self.assertNotEqual(P('a/b'), P()) + self.assertNotEqual(P('/a/b'), P('/')) + self.assertNotEqual(P(), P('/')) + self.assertNotEqual(P(), "") + self.assertNotEqual(P(), {}) + self.assertNotEqual(P(), int) + + def test_match_common(self): + P = self.cls + self.assertRaises(ValueError, P('a').match, '') + self.assertRaises(ValueError, P('a').match, '.') + # Simple relative pattern + self.assertTrue(P('b.py').match('b.py')) + self.assertTrue(P('a/b.py').match('b.py')) + self.assertTrue(P('/a/b.py').match('b.py')) + self.assertFalse(P('a.py').match('b.py')) + self.assertFalse(P('b/py').match('b.py')) + self.assertFalse(P('/a.py').match('b.py')) + self.assertFalse(P('b.py/c').match('b.py')) + # Wilcard relative pattern + self.assertTrue(P('b.py').match('*.py')) + self.assertTrue(P('a/b.py').match('*.py')) + self.assertTrue(P('/a/b.py').match('*.py')) + self.assertFalse(P('b.pyc').match('*.py')) + self.assertFalse(P('b./py').match('*.py')) + self.assertFalse(P('b.py/c').match('*.py')) + # Multi-part relative pattern + self.assertTrue(P('ab/c.py').match('a*/*.py')) + self.assertTrue(P('/d/ab/c.py').match('a*/*.py')) + self.assertFalse(P('a.py').match('a*/*.py')) + self.assertFalse(P('/dab/c.py').match('a*/*.py')) + self.assertFalse(P('ab/c.py/d').match('a*/*.py')) + # Absolute pattern + self.assertTrue(P('/b.py').match('/*.py')) + self.assertFalse(P('b.py').match('/*.py')) + self.assertFalse(P('a/b.py').match('/*.py')) + self.assertFalse(P('/a/b.py').match('/*.py')) + # Multi-part absolute pattern + self.assertTrue(P('/a/b.py').match('/a/*.py')) + self.assertFalse(P('/ab.py').match('/a/*.py')) + self.assertFalse(P('/a/b/c.py').match('/a/*.py')) + + def test_ordering_common(self): + # Ordering is tuple-alike + def assertLess(a, b): + self.assertLess(a, b) + self.assertGreater(b, a) + P = self.cls + a = P('a') + b = P('a/b') + c = P('abc') + d = P('b') + assertLess(a, b) + assertLess(a, c) + assertLess(a, d) + assertLess(b, c) + assertLess(c, d) + P = self.cls + a = P('/a') + b = P('/a/b') + c = P('/abc') + d = P('/b') + assertLess(a, b) + assertLess(a, c) + assertLess(a, d) + assertLess(b, c) + assertLess(c, d) + with self.assertRaises(TypeError): + P() < {} + + def test_parts_common(self): + # `parts` returns a tuple + sep = self.sep + P = self.cls + p = P('a/b') + parts = p.parts + self.assertEqual(parts, ('a', 'b')) + # The object gets reused + self.assertIs(parts, p.parts) + # When the path is absolute, the anchor is a separate part + p = P('/a/b') + parts = p.parts + self.assertEqual(parts, (sep, 'a', 'b')) + + def test_equivalences(self): + for k, tuples in self.equivalences.items(): + canon = k.replace('/', self.sep) + posix = k.replace(self.sep, '/') + if canon != posix: + tuples = tuples + [ + tuple(part.replace('/', self.sep) for part in t) + for t in tuples + ] + tuples.append((posix, )) + pcanon = self.cls(canon) + for t in tuples: + p = self.cls(*t) + self.assertEqual(p, pcanon, "failed with args {}".format(t)) + self.assertEqual(hash(p), hash(pcanon)) + self.assertEqual(str(p), canon) + self.assertEqual(p.as_posix(), posix) + + def test_parent_common(self): + # Relative + P = self.cls + p = P('a/b/c') + self.assertEqual(p.parent, P('a/b')) + self.assertEqual(p.parent.parent, P('a')) + self.assertEqual(p.parent.parent.parent, P()) + self.assertEqual(p.parent.parent.parent.parent, P()) + # Anchored + p = P('/a/b/c') + self.assertEqual(p.parent, P('/a/b')) + self.assertEqual(p.parent.parent, P('/a')) + self.assertEqual(p.parent.parent.parent, P('/')) + self.assertEqual(p.parent.parent.parent.parent, P('/')) + + def test_parents_common(self): + # Relative + P = self.cls + p = P('a/b/c') + par = p.parents + self.assertEqual(len(par), 3) + self.assertEqual(par[0], P('a/b')) + self.assertEqual(par[1], P('a')) + self.assertEqual(par[2], P('.')) + self.assertEqual(list(par), [P('a/b'), P('a'), P('.')]) + with self.assertRaises(IndexError): + par[-1] + with self.assertRaises(IndexError): + par[3] + with self.assertRaises(TypeError): + par[0] = p + # Anchored + p = P('/a/b/c') + par = p.parents + self.assertEqual(len(par), 3) + self.assertEqual(par[0], P('/a/b')) + self.assertEqual(par[1], P('/a')) + self.assertEqual(par[2], P('/')) + self.assertEqual(list(par), [P('/a/b'), P('/a'), P('/')]) + with self.assertRaises(IndexError): + par[3] + + def test_drive_common(self): + P = self.cls + self.assertEqual(P('a/b').drive, '') + self.assertEqual(P('/a/b').drive, '') + self.assertEqual(P('').drive, '') + + def test_root_common(self): + P = self.cls + sep = self.sep + self.assertEqual(P('').root, '') + self.assertEqual(P('a/b').root, '') + self.assertEqual(P('/').root, sep) + self.assertEqual(P('/a/b').root, sep) + + def test_anchor_common(self): + P = self.cls + sep = self.sep + self.assertEqual(P('').anchor, '') + self.assertEqual(P('a/b').anchor, '') + self.assertEqual(P('/').anchor, sep) + self.assertEqual(P('/a/b').anchor, sep) + + def test_name_common(self): + P = self.cls + self.assertEqual(P('').name, '') + self.assertEqual(P('.').name, '') + self.assertEqual(P('/').name, '') + self.assertEqual(P('a/b').name, 'b') + self.assertEqual(P('/a/b').name, 'b') + self.assertEqual(P('/a/b/.').name, 'b') + self.assertEqual(P('a/b.py').name, 'b.py') + self.assertEqual(P('/a/b.py').name, 'b.py') + + def test_suffix_common(self): + P = self.cls + self.assertEqual(P('').suffix, '') + self.assertEqual(P('.').suffix, '') + self.assertEqual(P('..').suffix, '') + self.assertEqual(P('/').suffix, '') + self.assertEqual(P('a/b').suffix, '') + self.assertEqual(P('/a/b').suffix, '') + self.assertEqual(P('/a/b/.').suffix, '') + self.assertEqual(P('a/b.py').suffix, '.py') + self.assertEqual(P('/a/b.py').suffix, '.py') + self.assertEqual(P('a/.hgrc').suffix, '') + self.assertEqual(P('/a/.hgrc').suffix, '') + self.assertEqual(P('a/.hg.rc').suffix, '.rc') + self.assertEqual(P('/a/.hg.rc').suffix, '.rc') + self.assertEqual(P('a/b.tar.gz').suffix, '.gz') + self.assertEqual(P('/a/b.tar.gz').suffix, '.gz') + self.assertEqual(P('a/Some name. Ending with a dot.').suffix, '') + self.assertEqual(P('/a/Some name. Ending with a dot.').suffix, '') + + def test_suffixes_common(self): + P = self.cls + self.assertEqual(P('').suffixes, []) + self.assertEqual(P('.').suffixes, []) + self.assertEqual(P('/').suffixes, []) + self.assertEqual(P('a/b').suffixes, []) + self.assertEqual(P('/a/b').suffixes, []) + self.assertEqual(P('/a/b/.').suffixes, []) + self.assertEqual(P('a/b.py').suffixes, ['.py']) + self.assertEqual(P('/a/b.py').suffixes, ['.py']) + self.assertEqual(P('a/.hgrc').suffixes, []) + self.assertEqual(P('/a/.hgrc').suffixes, []) + self.assertEqual(P('a/.hg.rc').suffixes, ['.rc']) + self.assertEqual(P('/a/.hg.rc').suffixes, ['.rc']) + self.assertEqual(P('a/b.tar.gz').suffixes, ['.tar', '.gz']) + self.assertEqual(P('/a/b.tar.gz').suffixes, ['.tar', '.gz']) + self.assertEqual(P('a/Some name. Ending with a dot.').suffixes, []) + self.assertEqual(P('/a/Some name. Ending with a dot.').suffixes, []) + + def test_stem_common(self): + P = self.cls + self.assertEqual(P('').stem, '') + self.assertEqual(P('.').stem, '') + self.assertEqual(P('..').stem, '..') + self.assertEqual(P('/').stem, '') + self.assertEqual(P('a/b').stem, 'b') + self.assertEqual(P('a/b.py').stem, 'b') + self.assertEqual(P('a/.hgrc').stem, '.hgrc') + self.assertEqual(P('a/.hg.rc').stem, '.hg') + self.assertEqual(P('a/b.tar.gz').stem, 'b.tar') + self.assertEqual(P('a/Some name. Ending with a dot.').stem, + 'Some name. Ending with a dot.') + + def test_with_name_common(self): + P = self.cls + self.assertEqual(P('a/b').with_name('d.xml'), P('a/d.xml')) + self.assertEqual(P('/a/b').with_name('d.xml'), P('/a/d.xml')) + self.assertEqual(P('a/b.py').with_name('d.xml'), P('a/d.xml')) + self.assertEqual(P('/a/b.py').with_name('d.xml'), P('/a/d.xml')) + self.assertEqual(P('a/Dot ending.').with_name('d.xml'), P('a/d.xml')) + self.assertEqual(P('/a/Dot ending.').with_name('d.xml'), P('/a/d.xml')) + self.assertRaises(ValueError, P('').with_name, 'd.xml') + self.assertRaises(ValueError, P('.').with_name, 'd.xml') + self.assertRaises(ValueError, P('/').with_name, 'd.xml') + + def test_with_suffix_common(self): + P = self.cls + self.assertEqual(P('a/b').with_suffix('.gz'), P('a/b.gz')) + self.assertEqual(P('/a/b').with_suffix('.gz'), P('/a/b.gz')) + self.assertEqual(P('a/b.py').with_suffix('.gz'), P('a/b.gz')) + self.assertEqual(P('/a/b.py').with_suffix('.gz'), P('/a/b.gz')) + # Path doesn't have a "filename" component + self.assertRaises(ValueError, P('').with_suffix, '.gz') + self.assertRaises(ValueError, P('.').with_suffix, '.gz') + self.assertRaises(ValueError, P('/').with_suffix, '.gz') + # Invalid suffix + self.assertRaises(ValueError, P('a/b').with_suffix, 'gz') + self.assertRaises(ValueError, P('a/b').with_suffix, '/') + self.assertRaises(ValueError, P('a/b').with_suffix, '/.gz') + self.assertRaises(ValueError, P('a/b').with_suffix, 'c/d') + self.assertRaises(ValueError, P('a/b').with_suffix, '.c/.d') + + def test_relative_to_common(self): + P = self.cls + p = P('a/b') + self.assertRaises(TypeError, p.relative_to) + self.assertRaises(TypeError, p.relative_to, b'a') + self.assertEqual(p.relative_to(P()), P('a/b')) + self.assertEqual(p.relative_to(''), P('a/b')) + self.assertEqual(p.relative_to(P('a')), P('b')) + self.assertEqual(p.relative_to('a'), P('b')) + self.assertEqual(p.relative_to('a/'), P('b')) + self.assertEqual(p.relative_to(P('a/b')), P()) + self.assertEqual(p.relative_to('a/b'), P()) + # With several args + self.assertEqual(p.relative_to('a', 'b'), P()) + # Unrelated paths + self.assertRaises(ValueError, p.relative_to, P('c')) + self.assertRaises(ValueError, p.relative_to, P('a/b/c')) + self.assertRaises(ValueError, p.relative_to, P('a/c')) + self.assertRaises(ValueError, p.relative_to, P('/a')) + p = P('/a/b') + self.assertEqual(p.relative_to(P('/')), P('a/b')) + self.assertEqual(p.relative_to('/'), P('a/b')) + self.assertEqual(p.relative_to(P('/a')), P('b')) + self.assertEqual(p.relative_to('/a'), P('b')) + self.assertEqual(p.relative_to('/a/'), P('b')) + self.assertEqual(p.relative_to(P('/a/b')), P()) + self.assertEqual(p.relative_to('/a/b'), P()) + # Unrelated paths + self.assertRaises(ValueError, p.relative_to, P('/c')) + self.assertRaises(ValueError, p.relative_to, P('/a/b/c')) + self.assertRaises(ValueError, p.relative_to, P('/a/c')) + self.assertRaises(ValueError, p.relative_to, P()) + self.assertRaises(ValueError, p.relative_to, '') + self.assertRaises(ValueError, p.relative_to, P('a')) + + def test_pickling_common(self): + P = self.cls + p = P('/a/b') + for proto in range(0, pickle.HIGHEST_PROTOCOL + 1): + dumped = pickle.dumps(p, proto) + pp = pickle.loads(dumped) + self.assertIs(pp.__class__, p.__class__) + self.assertEqual(pp, p) + self.assertEqual(hash(pp), hash(p)) + self.assertEqual(str(pp), str(p)) + + +class PurePosixPathTest(_BasePurePathTest, unittest.TestCase): + cls = pathlib.PurePosixPath + + def test_root(self): + P = self.cls + self.assertEqual(P('/a/b').root, '/') + self.assertEqual(P('///a/b').root, '/') + # POSIX special case for two leading slashes + self.assertEqual(P('//a/b').root, '//') + + def test_eq(self): + P = self.cls + self.assertNotEqual(P('a/b'), P('A/b')) + self.assertEqual(P('/a'), P('///a')) + self.assertNotEqual(P('/a'), P('//a')) + + def test_as_uri(self): + P = self.cls + self.assertEqual(P('/').as_uri(), 'file:///') + self.assertEqual(P('/a/b.c').as_uri(), 'file:///a/b.c') + self.assertEqual(P('/a/b%#c').as_uri(), 'file:///a/b%25%23c') + + def test_as_uri_non_ascii(self): + from urllib.parse import quote_from_bytes + P = self.cls + try: + os.fsencode('\xe9') + except UnicodeEncodeError: + self.skipTest("\\xe9 cannot be encoded to the filesystem encoding") + self.assertEqual(P('/a/b\xe9').as_uri(), + 'file:///a/b' + quote_from_bytes(os.fsencode('\xe9'))) + + def test_match(self): + P = self.cls + self.assertFalse(P('A.py').match('a.PY')) + + def test_is_absolute(self): + P = self.cls + self.assertFalse(P().is_absolute()) + self.assertFalse(P('a').is_absolute()) + self.assertFalse(P('a/b/').is_absolute()) + self.assertTrue(P('/').is_absolute()) + self.assertTrue(P('/a').is_absolute()) + self.assertTrue(P('/a/b/').is_absolute()) + self.assertTrue(P('//a').is_absolute()) + self.assertTrue(P('//a/b').is_absolute()) + + def test_is_reserved(self): + P = self.cls + self.assertIs(False, P('').is_reserved()) + self.assertIs(False, P('/').is_reserved()) + self.assertIs(False, P('/foo/bar').is_reserved()) + self.assertIs(False, P('/dev/con/PRN/NUL').is_reserved()) + + def test_join(self): + P = self.cls + p = P('//a') + pp = p.joinpath('b') + self.assertEqual(pp, P('//a/b')) + pp = P('/a').joinpath('//c') + self.assertEqual(pp, P('//c')) + pp = P('//a').joinpath('/c') + self.assertEqual(pp, P('/c')) + + def test_div(self): + # Basically the same as joinpath() + P = self.cls + p = P('//a') + pp = p / 'b' + self.assertEqual(pp, P('//a/b')) + pp = P('/a') / '//c' + self.assertEqual(pp, P('//c')) + pp = P('//a') / '/c' + self.assertEqual(pp, P('/c')) + + +class PureWindowsPathTest(_BasePurePathTest, unittest.TestCase): + cls = pathlib.PureWindowsPath + + equivalences = _BasePurePathTest.equivalences.copy() + equivalences.update({ + 'c:a': [ ('c:', 'a'), ('c:', 'a/'), ('/', 'c:', 'a') ], + 'c:/a': [ + ('c:/', 'a'), ('c:', '/', 'a'), ('c:', '/a'), + ('/z', 'c:/', 'a'), ('//x/y', 'c:/', 'a'), + ], + '//a/b/': [ ('//a/b',) ], + '//a/b/c': [ + ('//a/b', 'c'), ('//a/b/', 'c'), + ], + }) + + def test_str(self): + p = self.cls('a/b/c') + self.assertEqual(str(p), 'a\\b\\c') + p = self.cls('c:/a/b/c') + self.assertEqual(str(p), 'c:\\a\\b\\c') + p = self.cls('//a/b') + self.assertEqual(str(p), '\\\\a\\b\\') + p = self.cls('//a/b/c') + self.assertEqual(str(p), '\\\\a\\b\\c') + p = self.cls('//a/b/c/d') + self.assertEqual(str(p), '\\\\a\\b\\c\\d') + + def test_eq(self): + P = self.cls + self.assertEqual(P('c:a/b'), P('c:a/b')) + self.assertEqual(P('c:a/b'), P('c:', 'a', 'b')) + self.assertNotEqual(P('c:a/b'), P('d:a/b')) + self.assertNotEqual(P('c:a/b'), P('c:/a/b')) + self.assertNotEqual(P('/a/b'), P('c:/a/b')) + # Case-insensitivity + self.assertEqual(P('a/B'), P('A/b')) + self.assertEqual(P('C:a/B'), P('c:A/b')) + self.assertEqual(P('//Some/SHARE/a/B'), P('//somE/share/A/b')) + + def test_as_uri(self): + from urllib.parse import quote_from_bytes + P = self.cls + with self.assertRaises(ValueError): + P('/a/b').as_uri() + with self.assertRaises(ValueError): + P('c:a/b').as_uri() + self.assertEqual(P('c:/').as_uri(), 'file:///c:/') + self.assertEqual(P('c:/a/b.c').as_uri(), 'file:///c:/a/b.c') + self.assertEqual(P('c:/a/b%#c').as_uri(), 'file:///c:/a/b%25%23c') + self.assertEqual(P('c:/a/b\xe9').as_uri(), 'file:///c:/a/b%C3%A9') + self.assertEqual(P('//some/share/').as_uri(), 'file://some/share/') + self.assertEqual(P('//some/share/a/b.c').as_uri(), + 'file://some/share/a/b.c') + self.assertEqual(P('//some/share/a/b%#c\xe9').as_uri(), + 'file://some/share/a/b%25%23c%C3%A9') + + def test_match_common(self): + P = self.cls + # Absolute patterns + self.assertTrue(P('c:/b.py').match('/*.py')) + self.assertTrue(P('c:/b.py').match('c:*.py')) + self.assertTrue(P('c:/b.py').match('c:/*.py')) + self.assertFalse(P('d:/b.py').match('c:/*.py')) # wrong drive + self.assertFalse(P('b.py').match('/*.py')) + self.assertFalse(P('b.py').match('c:*.py')) + self.assertFalse(P('b.py').match('c:/*.py')) + self.assertFalse(P('c:b.py').match('/*.py')) + self.assertFalse(P('c:b.py').match('c:/*.py')) + self.assertFalse(P('/b.py').match('c:*.py')) + self.assertFalse(P('/b.py').match('c:/*.py')) + # UNC patterns + self.assertTrue(P('//some/share/a.py').match('/*.py')) + self.assertTrue(P('//some/share/a.py').match('//some/share/*.py')) + self.assertFalse(P('//other/share/a.py').match('//some/share/*.py')) + self.assertFalse(P('//some/share/a/b.py').match('//some/share/*.py')) + # Case-insensitivity + self.assertTrue(P('B.py').match('b.PY')) + self.assertTrue(P('c:/a/B.Py').match('C:/A/*.pY')) + self.assertTrue(P('//Some/Share/B.Py').match('//somE/sharE/*.pY')) + + def test_ordering_common(self): + # Case-insensitivity + def assertOrderedEqual(a, b): + self.assertLessEqual(a, b) + self.assertGreaterEqual(b, a) + P = self.cls + p = P('c:A/b') + q = P('C:a/B') + assertOrderedEqual(p, q) + self.assertFalse(p < q) + self.assertFalse(p > q) + p = P('//some/Share/A/b') + q = P('//Some/SHARE/a/B') + assertOrderedEqual(p, q) + self.assertFalse(p < q) + self.assertFalse(p > q) + + def test_parts(self): + P = self.cls + p = P('c:a/b') + parts = p.parts + self.assertEqual(parts, ('c:', 'a', 'b')) + p = P('c:/a/b') + parts = p.parts + self.assertEqual(parts, ('c:\\', 'a', 'b')) + p = P('//a/b/c/d') + parts = p.parts + self.assertEqual(parts, ('\\\\a\\b\\', 'c', 'd')) + + def test_parent(self): + # Anchored + P = self.cls + p = P('z:a/b/c') + self.assertEqual(p.parent, P('z:a/b')) + self.assertEqual(p.parent.parent, P('z:a')) + self.assertEqual(p.parent.parent.parent, P('z:')) + self.assertEqual(p.parent.parent.parent.parent, P('z:')) + p = P('z:/a/b/c') + self.assertEqual(p.parent, P('z:/a/b')) + self.assertEqual(p.parent.parent, P('z:/a')) + self.assertEqual(p.parent.parent.parent, P('z:/')) + self.assertEqual(p.parent.parent.parent.parent, P('z:/')) + p = P('//a/b/c/d') + self.assertEqual(p.parent, P('//a/b/c')) + self.assertEqual(p.parent.parent, P('//a/b')) + self.assertEqual(p.parent.parent.parent, P('//a/b')) + + def test_parents(self): + # Anchored + P = self.cls + p = P('z:a/b/') + par = p.parents + self.assertEqual(len(par), 2) + self.assertEqual(par[0], P('z:a')) + self.assertEqual(par[1], P('z:')) + self.assertEqual(list(par), [P('z:a'), P('z:')]) + with self.assertRaises(IndexError): + par[2] + p = P('z:/a/b/') + par = p.parents + self.assertEqual(len(par), 2) + self.assertEqual(par[0], P('z:/a')) + self.assertEqual(par[1], P('z:/')) + self.assertEqual(list(par), [P('z:/a'), P('z:/')]) + with self.assertRaises(IndexError): + par[2] + p = P('//a/b/c/d') + par = p.parents + self.assertEqual(len(par), 2) + self.assertEqual(par[0], P('//a/b/c')) + self.assertEqual(par[1], P('//a/b')) + self.assertEqual(list(par), [P('//a/b/c'), P('//a/b')]) + with self.assertRaises(IndexError): + par[2] + + def test_drive(self): + P = self.cls + self.assertEqual(P('c:').drive, 'c:') + self.assertEqual(P('c:a/b').drive, 'c:') + self.assertEqual(P('c:/').drive, 'c:') + self.assertEqual(P('c:/a/b/').drive, 'c:') + self.assertEqual(P('//a/b').drive, '\\\\a\\b') + self.assertEqual(P('//a/b/').drive, '\\\\a\\b') + self.assertEqual(P('//a/b/c/d').drive, '\\\\a\\b') + + def test_root(self): + P = self.cls + self.assertEqual(P('c:').root, '') + self.assertEqual(P('c:a/b').root, '') + self.assertEqual(P('c:/').root, '\\') + self.assertEqual(P('c:/a/b/').root, '\\') + self.assertEqual(P('//a/b').root, '\\') + self.assertEqual(P('//a/b/').root, '\\') + self.assertEqual(P('//a/b/c/d').root, '\\') + + def test_anchor(self): + P = self.cls + self.assertEqual(P('c:').anchor, 'c:') + self.assertEqual(P('c:a/b').anchor, 'c:') + self.assertEqual(P('c:/').anchor, 'c:\\') + self.assertEqual(P('c:/a/b/').anchor, 'c:\\') + self.assertEqual(P('//a/b').anchor, '\\\\a\\b\\') + self.assertEqual(P('//a/b/').anchor, '\\\\a\\b\\') + self.assertEqual(P('//a/b/c/d').anchor, '\\\\a\\b\\') + + def test_name(self): + P = self.cls + self.assertEqual(P('c:').name, '') + self.assertEqual(P('c:/').name, '') + self.assertEqual(P('c:a/b').name, 'b') + self.assertEqual(P('c:/a/b').name, 'b') + self.assertEqual(P('c:a/b.py').name, 'b.py') + self.assertEqual(P('c:/a/b.py').name, 'b.py') + self.assertEqual(P('//My.py/Share.php').name, '') + self.assertEqual(P('//My.py/Share.php/a/b').name, 'b') + + def test_suffix(self): + P = self.cls + self.assertEqual(P('c:').suffix, '') + self.assertEqual(P('c:/').suffix, '') + self.assertEqual(P('c:a/b').suffix, '') + self.assertEqual(P('c:/a/b').suffix, '') + self.assertEqual(P('c:a/b.py').suffix, '.py') + self.assertEqual(P('c:/a/b.py').suffix, '.py') + self.assertEqual(P('c:a/.hgrc').suffix, '') + self.assertEqual(P('c:/a/.hgrc').suffix, '') + self.assertEqual(P('c:a/.hg.rc').suffix, '.rc') + self.assertEqual(P('c:/a/.hg.rc').suffix, '.rc') + self.assertEqual(P('c:a/b.tar.gz').suffix, '.gz') + self.assertEqual(P('c:/a/b.tar.gz').suffix, '.gz') + self.assertEqual(P('c:a/Some name. Ending with a dot.').suffix, '') + self.assertEqual(P('c:/a/Some name. Ending with a dot.').suffix, '') + self.assertEqual(P('//My.py/Share.php').suffix, '') + self.assertEqual(P('//My.py/Share.php/a/b').suffix, '') + + def test_suffixes(self): + P = self.cls + self.assertEqual(P('c:').suffixes, []) + self.assertEqual(P('c:/').suffixes, []) + self.assertEqual(P('c:a/b').suffixes, []) + self.assertEqual(P('c:/a/b').suffixes, []) + self.assertEqual(P('c:a/b.py').suffixes, ['.py']) + self.assertEqual(P('c:/a/b.py').suffixes, ['.py']) + self.assertEqual(P('c:a/.hgrc').suffixes, []) + self.assertEqual(P('c:/a/.hgrc').suffixes, []) + self.assertEqual(P('c:a/.hg.rc').suffixes, ['.rc']) + self.assertEqual(P('c:/a/.hg.rc').suffixes, ['.rc']) + self.assertEqual(P('c:a/b.tar.gz').suffixes, ['.tar', '.gz']) + self.assertEqual(P('c:/a/b.tar.gz').suffixes, ['.tar', '.gz']) + self.assertEqual(P('//My.py/Share.php').suffixes, []) + self.assertEqual(P('//My.py/Share.php/a/b').suffixes, []) + self.assertEqual(P('c:a/Some name. Ending with a dot.').suffixes, []) + self.assertEqual(P('c:/a/Some name. Ending with a dot.').suffixes, []) + + def test_stem(self): + P = self.cls + self.assertEqual(P('c:').stem, '') + self.assertEqual(P('c:.').stem, '') + self.assertEqual(P('c:..').stem, '..') + self.assertEqual(P('c:/').stem, '') + self.assertEqual(P('c:a/b').stem, 'b') + self.assertEqual(P('c:a/b.py').stem, 'b') + self.assertEqual(P('c:a/.hgrc').stem, '.hgrc') + self.assertEqual(P('c:a/.hg.rc').stem, '.hg') + self.assertEqual(P('c:a/b.tar.gz').stem, 'b.tar') + self.assertEqual(P('c:a/Some name. Ending with a dot.').stem, + 'Some name. Ending with a dot.') + + def test_with_name(self): + P = self.cls + self.assertEqual(P('c:a/b').with_name('d.xml'), P('c:a/d.xml')) + self.assertEqual(P('c:/a/b').with_name('d.xml'), P('c:/a/d.xml')) + self.assertEqual(P('c:a/Dot ending.').with_name('d.xml'), P('c:a/d.xml')) + self.assertEqual(P('c:/a/Dot ending.').with_name('d.xml'), P('c:/a/d.xml')) + self.assertRaises(ValueError, P('c:').with_name, 'd.xml') + self.assertRaises(ValueError, P('c:/').with_name, 'd.xml') + self.assertRaises(ValueError, P('//My/Share').with_name, 'd.xml') + + def test_with_suffix(self): + P = self.cls + self.assertEqual(P('c:a/b').with_suffix('.gz'), P('c:a/b.gz')) + self.assertEqual(P('c:/a/b').with_suffix('.gz'), P('c:/a/b.gz')) + self.assertEqual(P('c:a/b.py').with_suffix('.gz'), P('c:a/b.gz')) + self.assertEqual(P('c:/a/b.py').with_suffix('.gz'), P('c:/a/b.gz')) + # Path doesn't have a "filename" component + self.assertRaises(ValueError, P('').with_suffix, '.gz') + self.assertRaises(ValueError, P('.').with_suffix, '.gz') + self.assertRaises(ValueError, P('/').with_suffix, '.gz') + self.assertRaises(ValueError, P('//My/Share').with_suffix, '.gz') + # Invalid suffix + self.assertRaises(ValueError, P('c:a/b').with_suffix, 'gz') + self.assertRaises(ValueError, P('c:a/b').with_suffix, '/') + self.assertRaises(ValueError, P('c:a/b').with_suffix, '\\') + self.assertRaises(ValueError, P('c:a/b').with_suffix, 'c:') + self.assertRaises(ValueError, P('c:a/b').with_suffix, '/.gz') + self.assertRaises(ValueError, P('c:a/b').with_suffix, '\\.gz') + self.assertRaises(ValueError, P('c:a/b').with_suffix, 'c:.gz') + self.assertRaises(ValueError, P('c:a/b').with_suffix, 'c/d') + self.assertRaises(ValueError, P('c:a/b').with_suffix, 'c\\d') + self.assertRaises(ValueError, P('c:a/b').with_suffix, '.c/d') + self.assertRaises(ValueError, P('c:a/b').with_suffix, '.c\\d') + + def test_relative_to(self): + P = self.cls + p = P('C:Foo/Bar') + self.assertEqual(p.relative_to(P('c:')), P('Foo/Bar')) + self.assertEqual(p.relative_to('c:'), P('Foo/Bar')) + self.assertEqual(p.relative_to(P('c:foO')), P('Bar')) + self.assertEqual(p.relative_to('c:foO'), P('Bar')) + self.assertEqual(p.relative_to('c:foO/'), P('Bar')) + self.assertEqual(p.relative_to(P('c:foO/baR')), P()) + self.assertEqual(p.relative_to('c:foO/baR'), P()) + # Unrelated paths + self.assertRaises(ValueError, p.relative_to, P()) + self.assertRaises(ValueError, p.relative_to, '') + self.assertRaises(ValueError, p.relative_to, P('d:')) + self.assertRaises(ValueError, p.relative_to, P('/')) + self.assertRaises(ValueError, p.relative_to, P('Foo')) + self.assertRaises(ValueError, p.relative_to, P('/Foo')) + self.assertRaises(ValueError, p.relative_to, P('C:/Foo')) + self.assertRaises(ValueError, p.relative_to, P('C:Foo/Bar/Baz')) + self.assertRaises(ValueError, p.relative_to, P('C:Foo/Baz')) + p = P('C:/Foo/Bar') + self.assertEqual(p.relative_to(P('c:')), P('/Foo/Bar')) + self.assertEqual(p.relative_to('c:'), P('/Foo/Bar')) + self.assertEqual(str(p.relative_to(P('c:'))), '\\Foo\\Bar') + self.assertEqual(str(p.relative_to('c:')), '\\Foo\\Bar') + self.assertEqual(p.relative_to(P('c:/')), P('Foo/Bar')) + self.assertEqual(p.relative_to('c:/'), P('Foo/Bar')) + self.assertEqual(p.relative_to(P('c:/foO')), P('Bar')) + self.assertEqual(p.relative_to('c:/foO'), P('Bar')) + self.assertEqual(p.relative_to('c:/foO/'), P('Bar')) + self.assertEqual(p.relative_to(P('c:/foO/baR')), P()) + self.assertEqual(p.relative_to('c:/foO/baR'), P()) + # Unrelated paths + self.assertRaises(ValueError, p.relative_to, P('C:/Baz')) + self.assertRaises(ValueError, p.relative_to, P('C:/Foo/Bar/Baz')) + self.assertRaises(ValueError, p.relative_to, P('C:/Foo/Baz')) + self.assertRaises(ValueError, p.relative_to, P('C:Foo')) + self.assertRaises(ValueError, p.relative_to, P('d:')) + self.assertRaises(ValueError, p.relative_to, P('d:/')) + self.assertRaises(ValueError, p.relative_to, P('/')) + self.assertRaises(ValueError, p.relative_to, P('/Foo')) + self.assertRaises(ValueError, p.relative_to, P('//C/Foo')) + # UNC paths + p = P('//Server/Share/Foo/Bar') + self.assertEqual(p.relative_to(P('//sErver/sHare')), P('Foo/Bar')) + self.assertEqual(p.relative_to('//sErver/sHare'), P('Foo/Bar')) + self.assertEqual(p.relative_to('//sErver/sHare/'), P('Foo/Bar')) + self.assertEqual(p.relative_to(P('//sErver/sHare/Foo')), P('Bar')) + self.assertEqual(p.relative_to('//sErver/sHare/Foo'), P('Bar')) + self.assertEqual(p.relative_to('//sErver/sHare/Foo/'), P('Bar')) + self.assertEqual(p.relative_to(P('//sErver/sHare/Foo/Bar')), P()) + self.assertEqual(p.relative_to('//sErver/sHare/Foo/Bar'), P()) + # Unrelated paths + self.assertRaises(ValueError, p.relative_to, P('/Server/Share/Foo')) + self.assertRaises(ValueError, p.relative_to, P('c:/Server/Share/Foo')) + self.assertRaises(ValueError, p.relative_to, P('//z/Share/Foo')) + self.assertRaises(ValueError, p.relative_to, P('//Server/z/Foo')) + + def test_is_absolute(self): + P = self.cls + # Under NT, only paths with both a drive and a root are absolute + self.assertFalse(P().is_absolute()) + self.assertFalse(P('a').is_absolute()) + self.assertFalse(P('a/b/').is_absolute()) + self.assertFalse(P('/').is_absolute()) + self.assertFalse(P('/a').is_absolute()) + self.assertFalse(P('/a/b/').is_absolute()) + self.assertFalse(P('c:').is_absolute()) + self.assertFalse(P('c:a').is_absolute()) + self.assertFalse(P('c:a/b/').is_absolute()) + self.assertTrue(P('c:/').is_absolute()) + self.assertTrue(P('c:/a').is_absolute()) + self.assertTrue(P('c:/a/b/').is_absolute()) + # UNC paths are absolute by definition + self.assertTrue(P('//a/b').is_absolute()) + self.assertTrue(P('//a/b/').is_absolute()) + self.assertTrue(P('//a/b/c').is_absolute()) + self.assertTrue(P('//a/b/c/d').is_absolute()) + + def test_join(self): + P = self.cls + p = P('C:/a/b') + pp = p.joinpath('x/y') + self.assertEqual(pp, P('C:/a/b/x/y')) + pp = p.joinpath('/x/y') + self.assertEqual(pp, P('C:/x/y')) + # Joining with a different drive => the first path is ignored, even + # if the second path is relative. + pp = p.joinpath('D:x/y') + self.assertEqual(pp, P('D:x/y')) + pp = p.joinpath('D:/x/y') + self.assertEqual(pp, P('D:/x/y')) + pp = p.joinpath('//host/share/x/y') + self.assertEqual(pp, P('//host/share/x/y')) + # Joining with the same drive => the first path is appended to if + # the second path is relative. + pp = p.joinpath('c:x/y') + self.assertEqual(pp, P('C:/a/b/x/y')) + pp = p.joinpath('c:/x/y') + self.assertEqual(pp, P('C:/x/y')) + + def test_div(self): + # Basically the same as joinpath() + P = self.cls + p = P('C:/a/b') + self.assertEqual(p / 'x/y', P('C:/a/b/x/y')) + self.assertEqual(p / 'x' / 'y', P('C:/a/b/x/y')) + self.assertEqual(p / '/x/y', P('C:/x/y')) + self.assertEqual(p / '/x' / 'y', P('C:/x/y')) + # Joining with a different drive => the first path is ignored, even + # if the second path is relative. + self.assertEqual(p / 'D:x/y', P('D:x/y')) + self.assertEqual(p / 'D:' / 'x/y', P('D:x/y')) + self.assertEqual(p / 'D:/x/y', P('D:/x/y')) + self.assertEqual(p / 'D:' / '/x/y', P('D:/x/y')) + self.assertEqual(p / '//host/share/x/y', P('//host/share/x/y')) + # Joining with the same drive => the first path is appended to if + # the second path is relative. + self.assertEqual(p / 'c:x/y', P('C:/a/b/x/y')) + self.assertEqual(p / 'c:/x/y', P('C:/x/y')) + + def test_is_reserved(self): + P = self.cls + self.assertIs(False, P('').is_reserved()) + self.assertIs(False, P('/').is_reserved()) + self.assertIs(False, P('/foo/bar').is_reserved()) + self.assertIs(True, P('con').is_reserved()) + self.assertIs(True, P('NUL').is_reserved()) + self.assertIs(True, P('NUL.txt').is_reserved()) + self.assertIs(True, P('com1').is_reserved()) + self.assertIs(True, P('com9.bar').is_reserved()) + self.assertIs(False, P('bar.com9').is_reserved()) + self.assertIs(True, P('lpt1').is_reserved()) + self.assertIs(True, P('lpt9.bar').is_reserved()) + self.assertIs(False, P('bar.lpt9').is_reserved()) + # Only the last component matters + self.assertIs(False, P('c:/NUL/con/baz').is_reserved()) + # UNC paths are never reserved + self.assertIs(False, P('//my/share/nul/con/aux').is_reserved()) + + +class PurePathTest(_BasePurePathTest, unittest.TestCase): + cls = pathlib.PurePath + + def test_concrete_class(self): + p = self.cls('a') + self.assertIs(type(p), + pathlib.PureWindowsPath if os.name == 'nt' else pathlib.PurePosixPath) + + def test_different_flavours_unequal(self): + p = pathlib.PurePosixPath('a') + q = pathlib.PureWindowsPath('a') + self.assertNotEqual(p, q) + + def test_different_flavours_unordered(self): + p = pathlib.PurePosixPath('a') + q = pathlib.PureWindowsPath('a') + with self.assertRaises(TypeError): + p < q + with self.assertRaises(TypeError): + p <= q + with self.assertRaises(TypeError): + p > q + with self.assertRaises(TypeError): + p >= q + + +# +# Tests for the concrete classes +# + +# Make sure any symbolic links in the base test path are resolved +BASE = os.path.realpath(TESTFN) +join = lambda *x: os.path.join(BASE, *x) +rel_join = lambda *x: os.path.join(TESTFN, *x) + +def symlink_skip_reason(): + if not pathlib.supports_symlinks: + return "no system support for symlinks" + try: + os.symlink(__file__, BASE) + except OSError as e: + return str(e) + else: + support.unlink(BASE) + return None + +symlink_skip_reason = symlink_skip_reason() + +only_nt = unittest.skipIf(os.name != 'nt', + 'test requires a Windows-compatible system') +only_posix = unittest.skipIf(os.name == 'nt', + 'test requires a POSIX-compatible system') +with_symlinks = unittest.skipIf(symlink_skip_reason, symlink_skip_reason) + + +@only_posix +class PosixPathAsPureTest(PurePosixPathTest): + cls = pathlib.PosixPath + +@only_nt +class WindowsPathAsPureTest(PureWindowsPathTest): + cls = pathlib.WindowsPath + + +class _BasePathTest(object): + """Tests for the FS-accessing functionalities of the Path classes.""" + + # (BASE) + # | + # |-- dirA/ + # |-- linkC -> "../dirB" + # |-- dirB/ + # | |-- fileB + # |-- linkD -> "../dirB" + # |-- dirC/ + # | |-- fileC + # | |-- fileD + # |-- fileA + # |-- linkA -> "fileA" + # |-- linkB -> "dirB" + # + + def setUp(self): + os.mkdir(BASE) + self.addCleanup(shutil.rmtree, BASE) + os.mkdir(join('dirA')) + os.mkdir(join('dirB')) + os.mkdir(join('dirC')) + os.mkdir(join('dirC', 'dirD')) + with open(join('fileA'), 'wb') as f: + f.write(b"this is file A\n") + with open(join('dirB', 'fileB'), 'wb') as f: + f.write(b"this is file B\n") + with open(join('dirC', 'fileC'), 'wb') as f: + f.write(b"this is file C\n") + with open(join('dirC', 'dirD', 'fileD'), 'wb') as f: + f.write(b"this is file D\n") + if not symlink_skip_reason: + # Relative symlinks + os.symlink('fileA', join('linkA')) + os.symlink('non-existing', join('brokenLink')) + self.dirlink('dirB', join('linkB')) + self.dirlink(os.path.join('..', 'dirB'), join('dirA', 'linkC')) + # This one goes upwards but doesn't create a loop + self.dirlink(os.path.join('..', 'dirB'), join('dirB', 'linkD')) + + if os.name == 'nt': + # Workaround for http://bugs.python.org/issue13772 + def dirlink(self, src, dest): + os.symlink(src, dest, target_is_directory=True) + else: + def dirlink(self, src, dest): + os.symlink(src, dest) + + def assertSame(self, path_a, path_b): + self.assertTrue(os.path.samefile(str(path_a), str(path_b)), + "%r and %r don't point to the same file" % + (path_a, path_b)) + + def assertFileNotFound(self, func, *args, **kwargs): + with self.assertRaises(FileNotFoundError) as cm: + func(*args, **kwargs) + self.assertEqual(cm.exception.errno, errno.ENOENT) + + def _test_cwd(self, p): + q = self.cls(os.getcwd()) + self.assertEqual(p, q) + self.assertEqual(str(p), str(q)) + self.assertIs(type(p), type(q)) + self.assertTrue(p.is_absolute()) + + def test_cwd(self): + p = self.cls.cwd() + self._test_cwd(p) + + def test_empty_path(self): + # The empty path points to '.' + p = self.cls('') + self.assertEqual(p.stat(), os.stat('.')) + + def test_exists(self): + P = self.cls + p = P(BASE) + self.assertIs(True, p.exists()) + self.assertIs(True, (p / 'dirA').exists()) + self.assertIs(True, (p / 'fileA').exists()) + if not symlink_skip_reason: + self.assertIs(True, (p / 'linkA').exists()) + self.assertIs(True, (p / 'linkB').exists()) + self.assertIs(False, (p / 'foo').exists()) + self.assertIs(False, P('/xyzzy').exists()) + + def test_open_common(self): + p = self.cls(BASE) + with (p / 'fileA').open('r') as f: + self.assertIsInstance(f, io.TextIOBase) + self.assertEqual(f.read(), "this is file A\n") + with (p / 'fileA').open('rb') as f: + self.assertIsInstance(f, io.BufferedIOBase) + self.assertEqual(f.read().strip(), b"this is file A") + with (p / 'fileA').open('rb', buffering=0) as f: + self.assertIsInstance(f, io.RawIOBase) + self.assertEqual(f.read().strip(), b"this is file A") + + def test_iterdir(self): + P = self.cls + p = P(BASE) + it = p.iterdir() + paths = set(it) + expected = ['dirA', 'dirB', 'dirC', 'fileA'] + if not symlink_skip_reason: + expected += ['linkA', 'linkB', 'brokenLink'] + self.assertEqual(paths, { P(BASE, q) for q in expected }) + + @with_symlinks + def test_iterdir_symlink(self): + # __iter__ on a symlink to a directory + P = self.cls + p = P(BASE, 'linkB') + paths = set(p.iterdir()) + expected = { P(BASE, 'linkB', q) for q in ['fileB', 'linkD'] } + self.assertEqual(paths, expected) + + def test_iterdir_nodir(self): + # __iter__ on something that is not a directory + p = self.cls(BASE, 'fileA') + with self.assertRaises(OSError) as cm: + next(p.iterdir()) + # ENOENT or EINVAL under Windows, ENOTDIR otherwise + # (see issue #12802) + self.assertIn(cm.exception.errno, (errno.ENOTDIR, + errno.ENOENT, errno.EINVAL)) + + def test_glob_common(self): + def _check(glob, expected): + self.assertEqual(set(glob), { P(BASE, q) for q in expected }) + P = self.cls + p = P(BASE) + it = p.glob("fileA") + self.assertIsInstance(it, collections.Iterator) + _check(it, ["fileA"]) + _check(p.glob("fileB"), []) + _check(p.glob("dir*/file*"), ["dirB/fileB", "dirC/fileC"]) + if symlink_skip_reason: + _check(p.glob("*A"), ['dirA', 'fileA']) + else: + _check(p.glob("*A"), ['dirA', 'fileA', 'linkA']) + if symlink_skip_reason: + _check(p.glob("*B/*"), ['dirB/fileB']) + else: + _check(p.glob("*B/*"), ['dirB/fileB', 'dirB/linkD', + 'linkB/fileB', 'linkB/linkD']) + if symlink_skip_reason: + _check(p.glob("*/fileB"), ['dirB/fileB']) + else: + _check(p.glob("*/fileB"), ['dirB/fileB', 'linkB/fileB']) + + def test_rglob_common(self): + def _check(glob, expected): + self.assertEqual(set(glob), { P(BASE, q) for q in expected }) + P = self.cls + p = P(BASE) + it = p.rglob("fileA") + self.assertIsInstance(it, collections.Iterator) + # XXX cannot test because of symlink loops in the test setup + #_check(it, ["fileA"]) + #_check(p.rglob("fileB"), ["dirB/fileB"]) + #_check(p.rglob("*/fileA"), [""]) + #_check(p.rglob("*/fileB"), ["dirB/fileB"]) + #_check(p.rglob("file*"), ["fileA", "dirB/fileB"]) + # No symlink loops here + p = P(BASE, "dirC") + _check(p.rglob("file*"), ["dirC/fileC", "dirC/dirD/fileD"]) + _check(p.rglob("*/*"), ["dirC/dirD/fileD"]) + + def test_glob_dotdot(self): + # ".." is not special in globs + P = self.cls + p = P(BASE) + self.assertEqual(set(p.glob("..")), { P(BASE, "..") }) + self.assertEqual(set(p.glob("dirA/../file*")), { P(BASE, "dirA/../fileA") }) + self.assertEqual(set(p.glob("../xyzzy")), set()) + + def _check_resolve_relative(self, p, expected): + q = p.resolve() + self.assertEqual(q, expected) + + def _check_resolve_absolute(self, p, expected): + q = p.resolve() + self.assertEqual(q, expected) + + @with_symlinks + def test_resolve_common(self): + P = self.cls + p = P(BASE, 'foo') + with self.assertRaises(OSError) as cm: + p.resolve() + self.assertEqual(cm.exception.errno, errno.ENOENT) + # These are all relative symlinks + p = P(BASE, 'dirB', 'fileB') + self._check_resolve_relative(p, p) + p = P(BASE, 'linkA') + self._check_resolve_relative(p, P(BASE, 'fileA')) + p = P(BASE, 'dirA', 'linkC', 'fileB') + self._check_resolve_relative(p, P(BASE, 'dirB', 'fileB')) + p = P(BASE, 'dirB', 'linkD', 'fileB') + self._check_resolve_relative(p, P(BASE, 'dirB', 'fileB')) + # Now create absolute symlinks + d = tempfile.mkdtemp(suffix='-dirD') + self.addCleanup(shutil.rmtree, d) + os.symlink(os.path.join(d), join('dirA', 'linkX')) + os.symlink(join('dirB'), os.path.join(d, 'linkY')) + p = P(BASE, 'dirA', 'linkX', 'linkY', 'fileB') + self._check_resolve_absolute(p, P(BASE, 'dirB', 'fileB')) + + @with_symlinks + def test_resolve_dot(self): + # See https://bitbucket.org/pitrou/pathlib/issue/9/pathresolve-fails-on-complex-symlinks + p = self.cls(BASE) + self.dirlink('.', join('0')) + self.dirlink(os.path.join('0', '0'), join('1')) + self.dirlink(os.path.join('1', '1'), join('2')) + q = p / '2' + self.assertEqual(q.resolve(), p) + + def test_with(self): + p = self.cls(BASE) + it = p.iterdir() + it2 = p.iterdir() + next(it2) + with p: + pass + # I/O operation on closed path + self.assertRaises(ValueError, next, it) + self.assertRaises(ValueError, next, it2) + self.assertRaises(ValueError, p.open) + self.assertRaises(ValueError, p.resolve) + self.assertRaises(ValueError, p.absolute) + self.assertRaises(ValueError, p.__enter__) + + def test_chmod(self): + p = self.cls(BASE) / 'fileA' + mode = p.stat().st_mode + # Clear writable bit + new_mode = mode & ~0o222 + p.chmod(new_mode) + self.assertEqual(p.stat().st_mode, new_mode) + # Set writable bit + new_mode = mode | 0o222 + p.chmod(new_mode) + self.assertEqual(p.stat().st_mode, new_mode) + + # XXX also need a test for lchmod + + def test_stat(self): + p = self.cls(BASE) / 'fileA' + st = p.stat() + self.assertEqual(p.stat(), st) + # Change file mode by flipping write bit + p.chmod(st.st_mode ^ 0o222) + self.addCleanup(p.chmod, st.st_mode) + self.assertNotEqual(p.stat(), st) + + @with_symlinks + def test_lstat(self): + p = self.cls(BASE)/ 'linkA' + st = p.stat() + self.assertNotEqual(st, p.lstat()) + + def test_lstat_nosymlink(self): + p = self.cls(BASE) / 'fileA' + st = p.stat() + self.assertEqual(st, p.lstat()) + + @unittest.skipUnless(pwd, "the pwd module is needed for this test") + def test_owner(self): + p = self.cls(BASE) / 'fileA' + uid = p.stat().st_uid + try: + name = pwd.getpwuid(uid).pw_name + except KeyError: + self.skipTest( + "user %d doesn't have an entry in the system database" % uid) + self.assertEqual(name, p.owner()) + + @unittest.skipUnless(grp, "the grp module is needed for this test") + def test_group(self): + p = self.cls(BASE) / 'fileA' + gid = p.stat().st_gid + try: + name = grp.getgrgid(gid).gr_name + except KeyError: + self.skipTest( + "group %d doesn't have an entry in the system database" % gid) + self.assertEqual(name, p.group()) + + def test_unlink(self): + p = self.cls(BASE) / 'fileA' + p.unlink() + self.assertFileNotFound(p.stat) + self.assertFileNotFound(p.unlink) + + def test_rmdir(self): + p = self.cls(BASE) / 'dirA' + for q in p.iterdir(): + q.unlink() + p.rmdir() + self.assertFileNotFound(p.stat) + self.assertFileNotFound(p.unlink) + + def test_rename(self): + P = self.cls(BASE) + p = P / 'fileA' + size = p.stat().st_size + # Renaming to another path + q = P / 'dirA' / 'fileAA' + p.rename(q) + self.assertEqual(q.stat().st_size, size) + self.assertFileNotFound(p.stat) + # Renaming to a str of a relative path + r = rel_join('fileAAA') + q.rename(r) + self.assertEqual(os.stat(r).st_size, size) + self.assertFileNotFound(q.stat) + + def test_replace(self): + P = self.cls(BASE) + p = P / 'fileA' + size = p.stat().st_size + # Replacing a non-existing path + q = P / 'dirA' / 'fileAA' + p.replace(q) + self.assertEqual(q.stat().st_size, size) + self.assertFileNotFound(p.stat) + # Replacing another (existing) path + r = rel_join('dirB', 'fileB') + q.replace(r) + self.assertEqual(os.stat(r).st_size, size) + self.assertFileNotFound(q.stat) + + def test_touch_common(self): + P = self.cls(BASE) + p = P / 'newfileA' + self.assertFalse(p.exists()) + p.touch() + self.assertTrue(p.exists()) + st = p.stat() + old_mtime = st.st_mtime + old_mtime_ns = st.st_mtime_ns + # Rewind the mtime sufficiently far in the past to work around + # filesystem-specific timestamp granularity. + os.utime(str(p), (old_mtime - 10, old_mtime - 10)) + # The file mtime should be refreshed by calling touch() again + p.touch() + st = p.stat() + self.assertGreaterEqual(st.st_mtime_ns, old_mtime_ns) + self.assertGreaterEqual(st.st_mtime, old_mtime) + # Now with exist_ok=False + p = P / 'newfileB' + self.assertFalse(p.exists()) + p.touch(mode=0o700, exist_ok=False) + self.assertTrue(p.exists()) + self.assertRaises(OSError, p.touch, exist_ok=False) + + def test_touch_nochange(self): + P = self.cls(BASE) + p = P / 'fileA' + p.touch() + with p.open('rb') as f: + self.assertEqual(f.read().strip(), b"this is file A") + + def test_mkdir(self): + P = self.cls(BASE) + p = P / 'newdirA' + self.assertFalse(p.exists()) + p.mkdir() + self.assertTrue(p.exists()) + self.assertTrue(p.is_dir()) + with self.assertRaises(OSError) as cm: + p.mkdir() + self.assertEqual(cm.exception.errno, errno.EEXIST) + + def test_mkdir_parents(self): + # Creating a chain of directories + p = self.cls(BASE, 'newdirB', 'newdirC') + self.assertFalse(p.exists()) + with self.assertRaises(OSError) as cm: + p.mkdir() + self.assertEqual(cm.exception.errno, errno.ENOENT) + p.mkdir(parents=True) + self.assertTrue(p.exists()) + self.assertTrue(p.is_dir()) + with self.assertRaises(OSError) as cm: + p.mkdir(parents=True) + self.assertEqual(cm.exception.errno, errno.EEXIST) + # test `mode` arg + mode = stat.S_IMODE(p.stat().st_mode) # default mode + p = self.cls(BASE, 'newdirD', 'newdirE') + p.mkdir(0o555, parents=True) + self.assertTrue(p.exists()) + self.assertTrue(p.is_dir()) + if os.name != 'nt': + # the directory's permissions follow the mode argument + self.assertEqual(stat.S_IMODE(p.stat().st_mode), 0o7555 & mode) + # the parent's permissions follow the default process settings + self.assertEqual(stat.S_IMODE(p.parent.stat().st_mode), mode) + + @with_symlinks + def test_symlink_to(self): + P = self.cls(BASE) + target = P / 'fileA' + # Symlinking a path target + link = P / 'dirA' / 'linkAA' + link.symlink_to(target) + self.assertEqual(link.stat(), target.stat()) + self.assertNotEqual(link.lstat(), target.stat()) + # Symlinking a str target + link = P / 'dirA' / 'linkAAA' + link.symlink_to(str(target)) + self.assertEqual(link.stat(), target.stat()) + self.assertNotEqual(link.lstat(), target.stat()) + self.assertFalse(link.is_dir()) + # Symlinking to a directory + target = P / 'dirB' + link = P / 'dirA' / 'linkAAAA' + link.symlink_to(target, target_is_directory=True) + self.assertEqual(link.stat(), target.stat()) + self.assertNotEqual(link.lstat(), target.stat()) + self.assertTrue(link.is_dir()) + self.assertTrue(list(link.iterdir())) + + def test_is_dir(self): + P = self.cls(BASE) + self.assertTrue((P / 'dirA').is_dir()) + self.assertFalse((P / 'fileA').is_dir()) + self.assertFalse((P / 'non-existing').is_dir()) + if not symlink_skip_reason: + self.assertFalse((P / 'linkA').is_dir()) + self.assertTrue((P / 'linkB').is_dir()) + self.assertFalse((P/ 'brokenLink').is_dir()) + + def test_is_file(self): + P = self.cls(BASE) + self.assertTrue((P / 'fileA').is_file()) + self.assertFalse((P / 'dirA').is_file()) + self.assertFalse((P / 'non-existing').is_file()) + if not symlink_skip_reason: + self.assertTrue((P / 'linkA').is_file()) + self.assertFalse((P / 'linkB').is_file()) + self.assertFalse((P/ 'brokenLink').is_file()) + + def test_is_symlink(self): + P = self.cls(BASE) + self.assertFalse((P / 'fileA').is_symlink()) + self.assertFalse((P / 'dirA').is_symlink()) + self.assertFalse((P / 'non-existing').is_symlink()) + if not symlink_skip_reason: + self.assertTrue((P / 'linkA').is_symlink()) + self.assertTrue((P / 'linkB').is_symlink()) + self.assertTrue((P/ 'brokenLink').is_symlink()) + + def test_is_fifo_false(self): + P = self.cls(BASE) + self.assertFalse((P / 'fileA').is_fifo()) + self.assertFalse((P / 'dirA').is_fifo()) + self.assertFalse((P / 'non-existing').is_fifo()) + + @unittest.skipUnless(hasattr(os, "mkfifo"), "os.mkfifo() required") + def test_is_fifo_true(self): + P = self.cls(BASE, 'myfifo') + os.mkfifo(str(P)) + self.assertTrue(P.is_fifo()) + self.assertFalse(P.is_socket()) + self.assertFalse(P.is_file()) + + def test_is_socket_false(self): + P = self.cls(BASE) + self.assertFalse((P / 'fileA').is_socket()) + self.assertFalse((P / 'dirA').is_socket()) + self.assertFalse((P / 'non-existing').is_socket()) + + @unittest.skipUnless(hasattr(socket, "AF_UNIX"), "Unix sockets required") + def test_is_socket_true(self): + P = self.cls(BASE, 'mysock') + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + self.addCleanup(sock.close) + try: + sock.bind(str(P)) + except OSError as e: + if "AF_UNIX path too long" in str(e): + self.skipTest("cannot bind Unix socket: " + str(e)) + self.assertTrue(P.is_socket()) + self.assertFalse(P.is_fifo()) + self.assertFalse(P.is_file()) + + def test_is_block_device_false(self): + P = self.cls(BASE) + self.assertFalse((P / 'fileA').is_block_device()) + self.assertFalse((P / 'dirA').is_block_device()) + self.assertFalse((P / 'non-existing').is_block_device()) + + def test_is_char_device_false(self): + P = self.cls(BASE) + self.assertFalse((P / 'fileA').is_char_device()) + self.assertFalse((P / 'dirA').is_char_device()) + self.assertFalse((P / 'non-existing').is_char_device()) + + def test_is_char_device_true(self): + # Under Unix, /dev/null should generally be a char device + P = self.cls('/dev/null') + if not P.exists(): + self.skipTest("/dev/null required") + self.assertTrue(P.is_char_device()) + self.assertFalse(P.is_block_device()) + self.assertFalse(P.is_file()) + + def test_pickling_common(self): + p = self.cls(BASE, 'fileA') + for proto in range(0, pickle.HIGHEST_PROTOCOL + 1): + dumped = pickle.dumps(p, proto) + pp = pickle.loads(dumped) + self.assertEqual(pp.stat(), p.stat()) + + def test_parts_interning(self): + P = self.cls + p = P('/usr/bin/foo') + q = P('/usr/local/bin') + # 'usr' + self.assertIs(p.parts[1], q.parts[1]) + # 'bin' + self.assertIs(p.parts[2], q.parts[3]) + + def _check_complex_symlinks(self, link0_target): + # Test solving a non-looping chain of symlinks (issue #19887) + P = self.cls(BASE) + self.dirlink(os.path.join('link0', 'link0'), join('link1')) + self.dirlink(os.path.join('link1', 'link1'), join('link2')) + self.dirlink(os.path.join('link2', 'link2'), join('link3')) + self.dirlink(link0_target, join('link0')) + + # Resolve absolute paths + p = (P / 'link0').resolve() + self.assertEqual(p, P) + self.assertEqual(str(p), BASE) + p = (P / 'link1').resolve() + self.assertEqual(p, P) + self.assertEqual(str(p), BASE) + p = (P / 'link2').resolve() + self.assertEqual(p, P) + self.assertEqual(str(p), BASE) + p = (P / 'link3').resolve() + self.assertEqual(p, P) + self.assertEqual(str(p), BASE) + + # Resolve relative paths + old_path = os.getcwd() + os.chdir(BASE) + try: + p = self.cls('link0').resolve() + self.assertEqual(p, P) + self.assertEqual(str(p), BASE) + p = self.cls('link1').resolve() + self.assertEqual(p, P) + self.assertEqual(str(p), BASE) + p = self.cls('link2').resolve() + self.assertEqual(p, P) + self.assertEqual(str(p), BASE) + p = self.cls('link3').resolve() + self.assertEqual(p, P) + self.assertEqual(str(p), BASE) + finally: + os.chdir(old_path) + + @with_symlinks + def test_complex_symlinks_absolute(self): + self._check_complex_symlinks(BASE) + + @with_symlinks + def test_complex_symlinks_relative(self): + self._check_complex_symlinks('.') + + @with_symlinks + def test_complex_symlinks_relative_dot_dot(self): + self._check_complex_symlinks(os.path.join('dirA', '..')) + + +class PathTest(_BasePathTest, unittest.TestCase): + cls = pathlib.Path + + def test_concrete_class(self): + p = self.cls('a') + self.assertIs(type(p), + pathlib.WindowsPath if os.name == 'nt' else pathlib.PosixPath) + + def test_unsupported_flavour(self): + if os.name == 'nt': + self.assertRaises(NotImplementedError, pathlib.PosixPath) + else: + self.assertRaises(NotImplementedError, pathlib.WindowsPath) + + +@only_posix +class PosixPathTest(_BasePathTest, unittest.TestCase): + cls = pathlib.PosixPath + + def _check_symlink_loop(self, *args): + path = self.cls(*args) + with self.assertRaises(RuntimeError): + print(path.resolve()) + + def test_open_mode(self): + old_mask = os.umask(0) + self.addCleanup(os.umask, old_mask) + p = self.cls(BASE) + with (p / 'new_file').open('wb'): + pass + st = os.stat(join('new_file')) + self.assertEqual(stat.S_IMODE(st.st_mode), 0o666) + os.umask(0o022) + with (p / 'other_new_file').open('wb'): + pass + st = os.stat(join('other_new_file')) + self.assertEqual(stat.S_IMODE(st.st_mode), 0o644) + + def test_touch_mode(self): + old_mask = os.umask(0) + self.addCleanup(os.umask, old_mask) + p = self.cls(BASE) + (p / 'new_file').touch() + st = os.stat(join('new_file')) + self.assertEqual(stat.S_IMODE(st.st_mode), 0o666) + os.umask(0o022) + (p / 'other_new_file').touch() + st = os.stat(join('other_new_file')) + self.assertEqual(stat.S_IMODE(st.st_mode), 0o644) + (p / 'masked_new_file').touch(mode=0o750) + st = os.stat(join('masked_new_file')) + self.assertEqual(stat.S_IMODE(st.st_mode), 0o750) + + @with_symlinks + def test_resolve_loop(self): + # Loop detection for broken symlinks under POSIX + P = self.cls + # Loops with relative symlinks + os.symlink('linkX/inside', join('linkX')) + self._check_symlink_loop(BASE, 'linkX') + os.symlink('linkY', join('linkY')) + self._check_symlink_loop(BASE, 'linkY') + os.symlink('linkZ/../linkZ', join('linkZ')) + self._check_symlink_loop(BASE, 'linkZ') + # Loops with absolute symlinks + os.symlink(join('linkU/inside'), join('linkU')) + self._check_symlink_loop(BASE, 'linkU') + os.symlink(join('linkV'), join('linkV')) + self._check_symlink_loop(BASE, 'linkV') + os.symlink(join('linkW/../linkW'), join('linkW')) + self._check_symlink_loop(BASE, 'linkW') + + def test_glob(self): + P = self.cls + p = P(BASE) + given = set(p.glob("FILEa")) + expect = set() if not support.fs_is_case_insensitive(BASE) else given + self.assertEqual(given, expect) + self.assertEqual(set(p.glob("FILEa*")), set()) + + def test_rglob(self): + P = self.cls + p = P(BASE, "dirC") + given = set(p.rglob("FILEd")) + expect = set() if not support.fs_is_case_insensitive(BASE) else given + self.assertEqual(given, expect) + self.assertEqual(set(p.rglob("FILEd*")), set()) + + +@only_nt +class WindowsPathTest(_BasePathTest, unittest.TestCase): + cls = pathlib.WindowsPath + + def test_glob(self): + P = self.cls + p = P(BASE) + self.assertEqual(set(p.glob("FILEa")), { P(BASE, "fileA") }) + + def test_rglob(self): + P = self.cls + p = P(BASE, "dirC") + self.assertEqual(set(p.rglob("FILEd")), { P(BASE, "dirC/dirD/fileD") }) + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_pdb.py b/Lib/test/test_pdb.py index 03084e4d12..74253b338f 100644 --- a/Lib/test/test_pdb.py +++ b/Lib/test/test_pdb.py @@ -1,9 +1,9 @@ # A test suite for pdb; not very comprehensive at the moment. import doctest -import imp import pdb import sys +import types import unittest import subprocess import textwrap @@ -205,7 +205,8 @@ def test_pdb_breakpoint_commands(): ... 'enable 1', ... 'clear 1', ... 'commands 2', - ... 'print 42', + ... 'p "42"', + ... 'print("42", 7*6)', # Issue 18764 (not about breakpoints) ... 'end', ... 'continue', # will stop at breakpoint 2 (line 4) ... 'clear', # clear all! @@ -252,11 +253,13 @@ def test_pdb_breakpoint_commands(): (Pdb) clear 1 Deleted breakpoint 1 at <doctest test.test_pdb.test_pdb_breakpoint_commands[0]>:3 (Pdb) commands 2 - (com) print 42 + (com) p "42" + (com) print("42", 7*6) (com) end (Pdb) continue 1 - 42 + '42' + 42 42 > <doctest test.test_pdb.test_pdb_breakpoint_commands[0]>(4)test_function() -> print(2) (Pdb) clear @@ -464,7 +467,7 @@ def test_pdb_skip_modules(): # Module for testing skipping of module that makes a callback -mod = imp.new_module('module_to_skip') +mod = types.ModuleType('module_to_skip') exec('def foo_pony(callback): x = 1; callback(); return None', mod.__dict__) @@ -597,6 +600,311 @@ def test_pdb_run_with_code_object(): (Pdb) continue """ +def test_next_until_return_at_return_event(): + """Test that pdb stops after a next/until/return issued at a return debug event. + + >>> def test_function_2(): + ... x = 1 + ... x = 2 + + >>> def test_function(): + ... import pdb; pdb.Pdb(nosigint=True).set_trace() + ... test_function_2() + ... test_function_2() + ... test_function_2() + ... end = 1 + + >>> with PdbTestInput(['break test_function_2', + ... 'continue', + ... 'return', + ... 'next', + ... 'continue', + ... 'return', + ... 'until', + ... 'continue', + ... 'return', + ... 'return', + ... 'continue']): + ... test_function() + > <doctest test.test_pdb.test_next_until_return_at_return_event[1]>(3)test_function() + -> test_function_2() + (Pdb) break test_function_2 + Breakpoint 1 at <doctest test.test_pdb.test_next_until_return_at_return_event[0]>:1 + (Pdb) continue + > <doctest test.test_pdb.test_next_until_return_at_return_event[0]>(2)test_function_2() + -> x = 1 + (Pdb) return + --Return-- + > <doctest test.test_pdb.test_next_until_return_at_return_event[0]>(3)test_function_2()->None + -> x = 2 + (Pdb) next + > <doctest test.test_pdb.test_next_until_return_at_return_event[1]>(4)test_function() + -> test_function_2() + (Pdb) continue + > <doctest test.test_pdb.test_next_until_return_at_return_event[0]>(2)test_function_2() + -> x = 1 + (Pdb) return + --Return-- + > <doctest test.test_pdb.test_next_until_return_at_return_event[0]>(3)test_function_2()->None + -> x = 2 + (Pdb) until + > <doctest test.test_pdb.test_next_until_return_at_return_event[1]>(5)test_function() + -> test_function_2() + (Pdb) continue + > <doctest test.test_pdb.test_next_until_return_at_return_event[0]>(2)test_function_2() + -> x = 1 + (Pdb) return + --Return-- + > <doctest test.test_pdb.test_next_until_return_at_return_event[0]>(3)test_function_2()->None + -> x = 2 + (Pdb) return + > <doctest test.test_pdb.test_next_until_return_at_return_event[1]>(6)test_function() + -> end = 1 + (Pdb) continue + """ + +def test_pdb_next_command_for_generator(): + """Testing skip unwindng stack on yield for generators for "next" command + + >>> def test_gen(): + ... yield 0 + ... return 1 + ... yield 2 + + >>> def test_function(): + ... import pdb; pdb.Pdb(nosigint=True).set_trace() + ... it = test_gen() + ... try: + ... assert next(it) == 0 + ... next(it) + ... except StopIteration as ex: + ... assert ex.value == 1 + ... print("finished") + + >>> with PdbTestInput(['step', + ... 'step', + ... 'step', + ... 'next', + ... 'next', + ... 'step', + ... 'step', + ... 'continue']): + ... test_function() + > <doctest test.test_pdb.test_pdb_next_command_for_generator[1]>(3)test_function() + -> it = test_gen() + (Pdb) step + > <doctest test.test_pdb.test_pdb_next_command_for_generator[1]>(4)test_function() + -> try: + (Pdb) step + > <doctest test.test_pdb.test_pdb_next_command_for_generator[1]>(5)test_function() + -> assert next(it) == 0 + (Pdb) step + --Call-- + > <doctest test.test_pdb.test_pdb_next_command_for_generator[0]>(1)test_gen() + -> def test_gen(): + (Pdb) next + > <doctest test.test_pdb.test_pdb_next_command_for_generator[0]>(2)test_gen() + -> yield 0 + (Pdb) next + > <doctest test.test_pdb.test_pdb_next_command_for_generator[0]>(3)test_gen() + -> return 1 + (Pdb) step + --Return-- + > <doctest test.test_pdb.test_pdb_next_command_for_generator[0]>(3)test_gen()->1 + -> return 1 + (Pdb) step + StopIteration: 1 + > <doctest test.test_pdb.test_pdb_next_command_for_generator[1]>(6)test_function() + -> next(it) + (Pdb) continue + finished + """ + +def test_pdb_return_command_for_generator(): + """Testing no unwindng stack on yield for generators + for "return" command + + >>> def test_gen(): + ... yield 0 + ... return 1 + ... yield 2 + + >>> def test_function(): + ... import pdb; pdb.Pdb(nosigint=True).set_trace() + ... it = test_gen() + ... try: + ... assert next(it) == 0 + ... next(it) + ... except StopIteration as ex: + ... assert ex.value == 1 + ... print("finished") + + >>> with PdbTestInput(['step', + ... 'step', + ... 'step', + ... 'return', + ... 'step', + ... 'step', + ... 'continue']): + ... test_function() + > <doctest test.test_pdb.test_pdb_return_command_for_generator[1]>(3)test_function() + -> it = test_gen() + (Pdb) step + > <doctest test.test_pdb.test_pdb_return_command_for_generator[1]>(4)test_function() + -> try: + (Pdb) step + > <doctest test.test_pdb.test_pdb_return_command_for_generator[1]>(5)test_function() + -> assert next(it) == 0 + (Pdb) step + --Call-- + > <doctest test.test_pdb.test_pdb_return_command_for_generator[0]>(1)test_gen() + -> def test_gen(): + (Pdb) return + StopIteration: 1 + > <doctest test.test_pdb.test_pdb_return_command_for_generator[1]>(6)test_function() + -> next(it) + (Pdb) step + > <doctest test.test_pdb.test_pdb_return_command_for_generator[1]>(7)test_function() + -> except StopIteration as ex: + (Pdb) step + > <doctest test.test_pdb.test_pdb_return_command_for_generator[1]>(8)test_function() + -> assert ex.value == 1 + (Pdb) continue + finished + """ + +def test_pdb_until_command_for_generator(): + """Testing no unwindng stack on yield for generators + for "until" command if target breakpoing is not reached + + >>> def test_gen(): + ... yield 0 + ... yield 1 + ... yield 2 + + >>> def test_function(): + ... import pdb; pdb.Pdb(nosigint=True).set_trace() + ... for i in test_gen(): + ... print(i) + ... print("finished") + + >>> with PdbTestInput(['step', + ... 'until 4', + ... 'step', + ... 'step', + ... 'continue']): + ... test_function() + > <doctest test.test_pdb.test_pdb_until_command_for_generator[1]>(3)test_function() + -> for i in test_gen(): + (Pdb) step + --Call-- + > <doctest test.test_pdb.test_pdb_until_command_for_generator[0]>(1)test_gen() + -> def test_gen(): + (Pdb) until 4 + 0 + 1 + > <doctest test.test_pdb.test_pdb_until_command_for_generator[0]>(4)test_gen() + -> yield 2 + (Pdb) step + --Return-- + > <doctest test.test_pdb.test_pdb_until_command_for_generator[0]>(4)test_gen()->2 + -> yield 2 + (Pdb) step + > <doctest test.test_pdb.test_pdb_until_command_for_generator[1]>(4)test_function() + -> print(i) + (Pdb) continue + 2 + finished + """ + +def test_pdb_next_command_in_generator_for_loop(): + """The next command on returning from a generator controled by a for loop. + + >>> def test_gen(): + ... yield 0 + ... return 1 + + >>> def test_function(): + ... import pdb; pdb.Pdb(nosigint=True).set_trace() + ... for i in test_gen(): + ... print('value', i) + ... x = 123 + + >>> with PdbTestInput(['break test_gen', + ... 'continue', + ... 'next', + ... 'next', + ... 'next', + ... 'continue']): + ... test_function() + > <doctest test.test_pdb.test_pdb_next_command_in_generator_for_loop[1]>(3)test_function() + -> for i in test_gen(): + (Pdb) break test_gen + Breakpoint 6 at <doctest test.test_pdb.test_pdb_next_command_in_generator_for_loop[0]>:1 + (Pdb) continue + > <doctest test.test_pdb.test_pdb_next_command_in_generator_for_loop[0]>(2)test_gen() + -> yield 0 + (Pdb) next + value 0 + > <doctest test.test_pdb.test_pdb_next_command_in_generator_for_loop[0]>(3)test_gen() + -> return 1 + (Pdb) next + Internal StopIteration: 1 + > <doctest test.test_pdb.test_pdb_next_command_in_generator_for_loop[1]>(3)test_function() + -> for i in test_gen(): + (Pdb) next + > <doctest test.test_pdb.test_pdb_next_command_in_generator_for_loop[1]>(5)test_function() + -> x = 123 + (Pdb) continue + """ + +def test_pdb_next_command_subiterator(): + """The next command in a generator with a subiterator. + + >>> def test_subgenerator(): + ... yield 0 + ... return 1 + + >>> def test_gen(): + ... x = yield from test_subgenerator() + ... return x + + >>> def test_function(): + ... import pdb; pdb.Pdb(nosigint=True).set_trace() + ... for i in test_gen(): + ... print('value', i) + ... x = 123 + + >>> with PdbTestInput(['step', + ... 'step', + ... 'next', + ... 'next', + ... 'next', + ... 'continue']): + ... test_function() + > <doctest test.test_pdb.test_pdb_next_command_subiterator[2]>(3)test_function() + -> for i in test_gen(): + (Pdb) step + --Call-- + > <doctest test.test_pdb.test_pdb_next_command_subiterator[1]>(1)test_gen() + -> def test_gen(): + (Pdb) step + > <doctest test.test_pdb.test_pdb_next_command_subiterator[1]>(2)test_gen() + -> x = yield from test_subgenerator() + (Pdb) next + value 0 + > <doctest test.test_pdb.test_pdb_next_command_subiterator[1]>(3)test_gen() + -> return x + (Pdb) next + Internal StopIteration: 1 + > <doctest test.test_pdb.test_pdb_next_command_subiterator[2]>(3)test_function() + -> for i in test_gen(): + (Pdb) next + > <doctest test.test_pdb.test_pdb_next_command_subiterator[2]>(5)test_function() + -> x = 123 + (Pdb) continue + """ + class PdbTestCase(unittest.TestCase): @@ -617,6 +925,36 @@ class PdbTestCase(unittest.TestCase): stderr = stderr and bytes.decode(stderr) return stdout, stderr + def _assert_find_function(self, file_content, func_name, expected): + file_content = textwrap.dedent(file_content) + + with open(support.TESTFN, 'w') as f: + f.write(file_content) + + expected = None if not expected else ( + expected[0], support.TESTFN, expected[1]) + self.assertEqual( + expected, pdb.find_function(func_name, support.TESTFN)) + + def test_find_function_empty_file(self): + self._assert_find_function('', 'foo', None) + + def test_find_function_found(self): + self._assert_find_function( + """\ + def foo(): + pass + + def bar(): + pass + + def quux(): + pass + """, + 'bar', + ('bar', 4), + ) + def test_issue7964(self): # open the file as binary so we can force \r\n newline with open(support.TESTFN, 'wb') as f: diff --git a/Lib/test/test_peepholer.py b/Lib/test/test_peepholer.py index 1cacdea569..50257923fe 100644 --- a/Lib/test/test_peepholer.py +++ b/Lib/test/test_peepholer.py @@ -5,44 +5,28 @@ from io import StringIO import unittest from math import copysign -def disassemble(func): - f = StringIO() - tmp = sys.stdout - sys.stdout = f - try: - dis.dis(func) - finally: - sys.stdout = tmp - result = f.getvalue() - f.close() - return result +from test.bytecode_helper import BytecodeTestCase -def dis_single(line): - return disassemble(compile(line, '', 'single')) - - -class TestTranforms(unittest.TestCase): +class TestTranforms(BytecodeTestCase): def test_unot(self): # UNARY_NOT POP_JUMP_IF_FALSE --> POP_JUMP_IF_TRUE' def unot(x): if not x == 2: del x - asm = disassemble(unot) - for elem in ('UNARY_NOT', 'POP_JUMP_IF_FALSE'): - self.assertNotIn(elem, asm) - for elem in ('POP_JUMP_IF_TRUE',): - self.assertIn(elem, asm) + self.assertNotInBytecode(unot, 'UNARY_NOT') + self.assertNotInBytecode(unot, 'POP_JUMP_IF_FALSE') + self.assertInBytecode(unot, 'POP_JUMP_IF_TRUE') def test_elim_inversion_of_is_or_in(self): - for line, elem in ( - ('not a is b', '(is not)',), - ('not a in b', '(not in)',), - ('not a is not b', '(is)',), - ('not a not in b', '(in)',), + for line, cmp_op in ( + ('not a is b', 'is not',), + ('not a in b', 'not in',), + ('not a is not b', 'is',), + ('not a not in b', 'in',), ): - asm = dis_single(line) - self.assertIn(elem, asm) + code = compile(line, '', 'single') + self.assertInBytecode(code, 'COMPARE_OP', cmp_op) def test_global_as_constant(self): # LOAD_GLOBAL None/True/False --> LOAD_CONST None/True/False @@ -56,17 +40,14 @@ class TestTranforms(unittest.TestCase): def h(x): False return x - for func, name in ((f, 'None'), (g, 'True'), (h, 'False')): - asm = disassemble(func) - for elem in ('LOAD_GLOBAL',): - self.assertNotIn(elem, asm) - for elem in ('LOAD_CONST', '('+name+')'): - self.assertIn(elem, asm) + for func, elem in ((f, None), (g, True), (h, False)): + self.assertNotInBytecode(func, 'LOAD_GLOBAL') + self.assertInBytecode(func, 'LOAD_CONST', elem) def f(): 'Adding a docstring made this test fail in Py2.5.0' return None - self.assertIn('LOAD_CONST', disassemble(f)) - self.assertNotIn('LOAD_GLOBAL', disassemble(f)) + self.assertNotInBytecode(f, 'LOAD_GLOBAL') + self.assertInBytecode(f, 'LOAD_CONST', None) def test_while_one(self): # Skip over: LOAD_CONST trueconst POP_JUMP_IF_FALSE xx @@ -74,11 +55,10 @@ class TestTranforms(unittest.TestCase): while 1: pass return list - asm = disassemble(f) for elem in ('LOAD_CONST', 'POP_JUMP_IF_FALSE'): - self.assertNotIn(elem, asm) + self.assertNotInBytecode(f, elem) for elem in ('JUMP_ABSOLUTE',): - self.assertIn(elem, asm) + self.assertInBytecode(f, elem) def test_pack_unpack(self): for line, elem in ( @@ -86,28 +66,30 @@ class TestTranforms(unittest.TestCase): ('a, b = a, b', 'ROT_TWO',), ('a, b, c = a, b, c', 'ROT_THREE',), ): - asm = dis_single(line) - self.assertIn(elem, asm) - self.assertNotIn('BUILD_TUPLE', asm) - self.assertNotIn('UNPACK_TUPLE', asm) + code = compile(line,'','single') + self.assertInBytecode(code, elem) + self.assertNotInBytecode(code, 'BUILD_TUPLE') + self.assertNotInBytecode(code, 'UNPACK_TUPLE') def test_folding_of_tuples_of_constants(self): for line, elem in ( - ('a = 1,2,3', '((1, 2, 3))'), - ('("a","b","c")', "(('a', 'b', 'c'))"), - ('a,b,c = 1,2,3', '((1, 2, 3))'), - ('(None, 1, None)', '((None, 1, None))'), - ('((1, 2), 3, 4)', '(((1, 2), 3, 4))'), + ('a = 1,2,3', (1, 2, 3)), + ('("a","b","c")', ('a', 'b', 'c')), + ('a,b,c = 1,2,3', (1, 2, 3)), + ('(None, 1, None)', (None, 1, None)), + ('((1, 2), 3, 4)', ((1, 2), 3, 4)), ): - asm = dis_single(line) - self.assertIn(elem, asm) - self.assertNotIn('BUILD_TUPLE', asm) + code = compile(line,'','single') + self.assertInBytecode(code, 'LOAD_CONST', elem) + self.assertNotInBytecode(code, 'BUILD_TUPLE') # Long tuples should be folded too. - asm = dis_single(repr(tuple(range(10000)))) + code = compile(repr(tuple(range(10000))),'','single') + self.assertNotInBytecode(code, 'BUILD_TUPLE') # One LOAD_CONST for the tuple, one for the None return value - self.assertEqual(asm.count('LOAD_CONST'), 2) - self.assertNotIn('BUILD_TUPLE', asm) + load_consts = [instr for instr in dis.get_instructions(code) + if instr.opname == 'LOAD_CONST'] + self.assertEqual(len(load_consts), 2) # Bug 1053819: Tuple of constants misidentified when presented with: # . . . opcode_with_arg 100 unary_opcode BUILD_TUPLE 1 . . . @@ -129,14 +111,14 @@ class TestTranforms(unittest.TestCase): def test_folding_of_lists_of_constants(self): for line, elem in ( # in/not in constants with BUILD_LIST should be folded to a tuple: - ('a in [1,2,3]', '(1, 2, 3)'), - ('a not in ["a","b","c"]', "(('a', 'b', 'c'))"), - ('a in [None, 1, None]', '((None, 1, None))'), - ('a not in [(1, 2), 3, 4]', '(((1, 2), 3, 4))'), + ('a in [1,2,3]', (1, 2, 3)), + ('a not in ["a","b","c"]', ('a', 'b', 'c')), + ('a in [None, 1, None]', (None, 1, None)), + ('a not in [(1, 2), 3, 4]', ((1, 2), 3, 4)), ): - asm = dis_single(line) - self.assertIn(elem, asm) - self.assertNotIn('BUILD_LIST', asm) + code = compile(line, '', 'single') + self.assertInBytecode(code, 'LOAD_CONST', elem) + self.assertNotInBytecode(code, 'BUILD_LIST') def test_folding_of_sets_of_constants(self): for line, elem in ( @@ -147,18 +129,9 @@ class TestTranforms(unittest.TestCase): ('a not in {(1, 2), 3, 4}', frozenset({(1, 2), 3, 4})), ('a in {1, 2, 3, 3, 2, 1}', frozenset({1, 2, 3})), ): - asm = dis_single(line) - self.assertNotIn('BUILD_SET', asm) - - # Verify that the frozenset 'elem' is in the disassembly - # The ordering of the elements in repr( frozenset ) isn't - # guaranteed, so we jump through some hoops to ensure that we have - # the frozenset we expect: - self.assertIn('frozenset', asm) - # Extract the frozenset literal from the disassembly: - m = re.match(r'.*(frozenset\({.*}\)).*', asm, re.DOTALL) - self.assertTrue(m) - self.assertEqual(eval(m.group(1)), elem) + code = compile(line, '', 'single') + self.assertNotInBytecode(code, 'BUILD_SET') + self.assertInBytecode(code, 'LOAD_CONST', elem) # Ensure that the resulting code actually works: def f(a): @@ -176,98 +149,103 @@ class TestTranforms(unittest.TestCase): def test_folding_of_binops_on_constants(self): for line, elem in ( - ('a = 2+3+4', '(9)'), # chained fold - ('"@"*4', "('@@@@')"), # check string ops - ('a="abc" + "def"', "('abcdef')"), # check string ops - ('a = 3**4', '(81)'), # binary power - ('a = 3*4', '(12)'), # binary multiply - ('a = 13//4', '(3)'), # binary floor divide - ('a = 14%4', '(2)'), # binary modulo - ('a = 2+3', '(5)'), # binary add - ('a = 13-4', '(9)'), # binary subtract - ('a = (12,13)[1]', '(13)'), # binary subscr - ('a = 13 << 2', '(52)'), # binary lshift - ('a = 13 >> 2', '(3)'), # binary rshift - ('a = 13 & 7', '(5)'), # binary and - ('a = 13 ^ 7', '(10)'), # binary xor - ('a = 13 | 7', '(15)'), # binary or + ('a = 2+3+4', 9), # chained fold + ('"@"*4', '@@@@'), # check string ops + ('a="abc" + "def"', 'abcdef'), # check string ops + ('a = 3**4', 81), # binary power + ('a = 3*4', 12), # binary multiply + ('a = 13//4', 3), # binary floor divide + ('a = 14%4', 2), # binary modulo + ('a = 2+3', 5), # binary add + ('a = 13-4', 9), # binary subtract + ('a = (12,13)[1]', 13), # binary subscr + ('a = 13 << 2', 52), # binary lshift + ('a = 13 >> 2', 3), # binary rshift + ('a = 13 & 7', 5), # binary and + ('a = 13 ^ 7', 10), # binary xor + ('a = 13 | 7', 15), # binary or ): - asm = dis_single(line) - self.assertIn(elem, asm, asm) - self.assertNotIn('BINARY_', asm) + code = compile(line, '', 'single') + self.assertInBytecode(code, 'LOAD_CONST', elem) + for instr in dis.get_instructions(code): + self.assertFalse(instr.opname.startswith('BINARY_')) # Verify that unfoldables are skipped - asm = dis_single('a=2+"b"') - self.assertIn('(2)', asm) - self.assertIn("('b')", asm) + code = compile('a=2+"b"', '', 'single') + self.assertInBytecode(code, 'LOAD_CONST', 2) + self.assertInBytecode(code, 'LOAD_CONST', 'b') # Verify that large sequences do not result from folding - asm = dis_single('a="x"*1000') - self.assertIn('(1000)', asm) + code = compile('a="x"*1000', '', 'single') + self.assertInBytecode(code, 'LOAD_CONST', 1000) def test_binary_subscr_on_unicode(self): # valid code get optimized - asm = dis_single('"foo"[0]') - self.assertIn("('f')", asm) - self.assertNotIn('BINARY_SUBSCR', asm) - asm = dis_single('"\u0061\uffff"[1]') - self.assertIn("('\\uffff')", asm) - self.assertNotIn('BINARY_SUBSCR', asm) - asm = dis_single('"\U00012345abcdef"[3]') - self.assertIn("('c')", asm) - self.assertNotIn('BINARY_SUBSCR', asm) + code = compile('"foo"[0]', '', 'single') + self.assertInBytecode(code, 'LOAD_CONST', 'f') + self.assertNotInBytecode(code, 'BINARY_SUBSCR') + code = compile('"\u0061\uffff"[1]', '', 'single') + self.assertInBytecode(code, 'LOAD_CONST', '\uffff') + self.assertNotInBytecode(code,'BINARY_SUBSCR') + + # With PEP 393, non-BMP char get optimized + code = compile('"\U00012345"[0]', '', 'single') + self.assertInBytecode(code, 'LOAD_CONST', '\U00012345') + self.assertNotInBytecode(code, 'BINARY_SUBSCR') # invalid code doesn't get optimized # out of range - asm = dis_single('"fuu"[10]') - self.assertIn('BINARY_SUBSCR', asm) + code = compile('"fuu"[10]', '', 'single') + self.assertInBytecode(code, 'BINARY_SUBSCR') def test_folding_of_unaryops_on_constants(self): for line, elem in ( - ('-0.5', '(-0.5)'), # unary negative - ('-0.0', '(-0.0)'), # -0.0 - ('-(1.0-1.0)','(-0.0)'), # -0.0 after folding - ('-0', '(0)'), # -0 - ('~-2', '(1)'), # unary invert - ('+1', '(1)'), # unary positive + ('-0.5', -0.5), # unary negative + ('-0.0', -0.0), # -0.0 + ('-(1.0-1.0)', -0.0), # -0.0 after folding + ('-0', 0), # -0 + ('~-2', 1), # unary invert + ('+1', 1), # unary positive ): - asm = dis_single(line) - self.assertIn(elem, asm, asm) - self.assertNotIn('UNARY_', asm) + code = compile(line, '', 'single') + self.assertInBytecode(code, 'LOAD_CONST', elem) + for instr in dis.get_instructions(code): + self.assertFalse(instr.opname.startswith('UNARY_')) # Check that -0.0 works after marshaling def negzero(): return -(1.0-1.0) - self.assertNotIn('UNARY_', disassemble(negzero)) - self.assertTrue(copysign(1.0, negzero()) < 0) + for instr in dis.get_instructions(code): + self.assertFalse(instr.opname.startswith('UNARY_')) # Verify that unfoldables are skipped - for line, elem in ( - ('-"abc"', "('abc')"), # unary negative - ('~"abc"', "('abc')"), # unary invert + for line, elem, opname in ( + ('-"abc"', 'abc', 'UNARY_NEGATIVE'), + ('~"abc"', 'abc', 'UNARY_INVERT'), ): - asm = dis_single(line) - self.assertIn(elem, asm, asm) - self.assertIn('UNARY_', asm) + code = compile(line, '', 'single') + self.assertInBytecode(code, 'LOAD_CONST', elem) + self.assertInBytecode(code, opname) def test_elim_extra_return(self): # RETURN LOAD_CONST None RETURN --> RETURN def f(x): return x - asm = disassemble(f) - self.assertNotIn('LOAD_CONST', asm) - self.assertNotIn('(None)', asm) - self.assertEqual(asm.split().count('RETURN_VALUE'), 1) + self.assertNotInBytecode(f, 'LOAD_CONST', None) + returns = [instr for instr in dis.get_instructions(f) + if instr.opname == 'RETURN_VALUE'] + self.assertEqual(len(returns), 1) def test_elim_jump_to_return(self): # JUMP_FORWARD to RETURN --> RETURN def f(cond, true_value, false_value): return true_value if cond else false_value - asm = disassemble(f) - self.assertNotIn('JUMP_FORWARD', asm) - self.assertNotIn('JUMP_ABSOLUTE', asm) - self.assertEqual(asm.split().count('RETURN_VALUE'), 2) + self.assertNotInBytecode(f, 'JUMP_FORWARD') + self.assertNotInBytecode(f, 'JUMP_ABSOLUTE') + returns = [instr for instr in dis.get_instructions(f) + if instr.opname == 'RETURN_VALUE'] + self.assertEqual(len(returns), 2) def test_elim_jump_after_return1(self): # Eliminate dead code: jumps immediately after returns can't be reached @@ -280,48 +258,53 @@ class TestTranforms(unittest.TestCase): if cond1: return 4 return 5 return 6 - asm = disassemble(f) - self.assertNotIn('JUMP_FORWARD', asm) - self.assertNotIn('JUMP_ABSOLUTE', asm) - self.assertEqual(asm.split().count('RETURN_VALUE'), 6) + self.assertNotInBytecode(f, 'JUMP_FORWARD') + self.assertNotInBytecode(f, 'JUMP_ABSOLUTE') + returns = [instr for instr in dis.get_instructions(f) + if instr.opname == 'RETURN_VALUE'] + self.assertEqual(len(returns), 6) def test_elim_jump_after_return2(self): # Eliminate dead code: jumps immediately after returns can't be reached def f(cond1, cond2): while 1: if cond1: return 4 - asm = disassemble(f) - self.assertNotIn('JUMP_FORWARD', asm) + self.assertNotInBytecode(f, 'JUMP_FORWARD') # There should be one jump for the while loop. - self.assertEqual(asm.split().count('JUMP_ABSOLUTE'), 1) - self.assertEqual(asm.split().count('RETURN_VALUE'), 2) + returns = [instr for instr in dis.get_instructions(f) + if instr.opname == 'JUMP_ABSOLUTE'] + self.assertEqual(len(returns), 1) + returns = [instr for instr in dis.get_instructions(f) + if instr.opname == 'RETURN_VALUE'] + self.assertEqual(len(returns), 2) def test_make_function_doesnt_bail(self): def f(): def g()->1+1: pass return g - asm = disassemble(f) - self.assertNotIn('BINARY_ADD', asm) + self.assertNotInBytecode(f, 'BINARY_ADD') def test_constant_folding(self): # Issue #11244: aggressive constant folding. exprs = [ - "3 * -5", - "-3 * 5", - "2 * (3 * 4)", - "(2 * 3) * 4", - "(-1, 2, 3)", - "(1, -2, 3)", - "(1, 2, -3)", - "(1, 2, -3) * 6", - "lambda x: x in {(3 * -5) + (-1 - 6), (1, -2, 3) * 2, None}", + '3 * -5', + '-3 * 5', + '2 * (3 * 4)', + '(2 * 3) * 4', + '(-1, 2, 3)', + '(1, -2, 3)', + '(1, 2, -3)', + '(1, 2, -3) * 6', + 'lambda x: x in {(3 * -5) + (-1 - 6), (1, -2, 3) * 2, None}', ] for e in exprs: - asm = dis_single(e) - self.assertNotIn('UNARY_', asm, e) - self.assertNotIn('BINARY_', asm, e) - self.assertNotIn('BUILD_', asm, e) + code = compile(e, '', 'single') + for instr in dis.get_instructions(code): + self.assertFalse(instr.opname.startswith('UNARY_')) + self.assertFalse(instr.opname.startswith('BINARY_')) + self.assertFalse(instr.opname.startswith('BUILD_')) + class TestBuglets(unittest.TestCase): @@ -343,7 +326,7 @@ def test_main(verbose=None): support.run_unittest(*test_classes) # verify reference counting - if verbose and hasattr(sys, "gettotalrefcount"): + if verbose and hasattr(sys, 'gettotalrefcount'): import gc counts = [None] * 5 for i in range(len(counts)): diff --git a/Lib/test/test_pep247.py b/Lib/test/test_pep247.py index 7f10472856..b85a26ad65 100644 --- a/Lib/test/test_pep247.py +++ b/Lib/test/test_pep247.py @@ -15,12 +15,14 @@ class Pep247Test(unittest.TestCase): self.assertTrue(module.digest_size is None or module.digest_size > 0) self.check_object(module.new, module.digest_size, key) - def check_object(self, cls, digest_size, key): + def check_object(self, cls, digest_size, key, digestmod=None): if key is not None: - obj1 = cls(key) - obj2 = cls(key, b'string') - h1 = cls(key, b'string').digest() - obj3 = cls(key) + if digestmod is None: + digestmod = md5 + obj1 = cls(key, digestmod=digestmod) + obj2 = cls(key, b'string', digestmod=digestmod) + h1 = cls(key, b'string', digestmod=digestmod).digest() + obj3 = cls(key, digestmod=digestmod) obj3.update(b'string') h2 = obj3.digest() else: diff --git a/Lib/test/test_pep277.py b/Lib/test/test_pep277.py index 4b16cbb383..9bae6dcad7 100644 --- a/Lib/test/test_pep277.py +++ b/Lib/test/test_pep277.py @@ -99,10 +99,6 @@ class UnicodeFileTests(unittest.TestCase): with self.assertRaises(expected_exception) as c: fn(filename) exc_filename = c.exception.filename - # listdir may append a wildcard to the filename - if fn is os.listdir and sys.platform == 'win32': - exc_filename, _, wildcard = exc_filename.rpartition(os.sep) - self.assertEqual(wildcard, '*.*') if check_filename: self.assertEqual(exc_filename, filename, "Function '%s(%a) failed " "with bad filename in the exception: %a" % diff --git a/Lib/test/test_pep3151.py b/Lib/test/test_pep3151.py index 2792c10f98..7d4a5d8446 100644 --- a/Lib/test/test_pep3151.py +++ b/Lib/test/test_pep3151.py @@ -157,8 +157,6 @@ class AttributesTest(unittest.TestCase): e.characters_written = 5 self.assertEqual(e.characters_written, 5) - # XXX VMSError not tested - class ExplicitSubclassingTest(unittest.TestCase): diff --git a/Lib/test/test_pep352.py b/Lib/test/test_pep352.py index 558cdb56d2..7c98c460b9 100644 --- a/Lib/test/test_pep352.py +++ b/Lib/test/test_pep352.py @@ -1,7 +1,6 @@ import unittest import builtins import warnings -from test.support import run_unittest import os from platform import system as platform_system @@ -180,8 +179,6 @@ class UsageTests(unittest.TestCase): # Catching a string is bad. self.catch_fails("spam") -def test_main(): - run_unittest(ExceptionClassTests, UsageTests) if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_pickle.py b/Lib/test/test_pickle.py index fbe96ac2a8..0b2fe1ef2a 100644 --- a/Lib/test/test_pickle.py +++ b/Lib/test/test_pickle.py @@ -83,13 +83,17 @@ class PyPicklerUnpicklerObjectTests(AbstractPicklerUnpicklerObjectTests): class PyDispatchTableTests(AbstractDispatchTableTests): + pickler_class = pickle._Pickler + def get_dispatch_table(self): return pickle.dispatch_table.copy() class PyChainDispatchTableTests(AbstractDispatchTableTests): + pickler_class = pickle._Pickler + def get_dispatch_table(self): return collections.ChainMap({}, pickle.dispatch_table) diff --git a/Lib/test/test_pkg.py b/Lib/test/test_pkg.py index 355efe7566..9883000945 100644 --- a/Lib/test/test_pkg.py +++ b/Lib/test/test_pkg.py @@ -199,14 +199,14 @@ class TestPkg(unittest.TestCase): import t5 self.assertEqual(fixdir(dir(t5)), ['__cached__', '__doc__', '__file__', '__loader__', - '__name__', '__package__', '__path__', 'foo', - 'string', 't5']) + '__name__', '__package__', '__path__', '__spec__', + 'foo', 'string', 't5']) self.assertEqual(fixdir(dir(t5.foo)), ['__cached__', '__doc__', '__file__', '__loader__', - '__name__', '__package__', 'string']) + '__name__', '__package__', '__spec__', 'string']) self.assertEqual(fixdir(dir(t5.string)), ['__cached__', '__doc__', '__file__', '__loader__', - '__name__', '__package__', 'spam']) + '__name__', '__package__', '__spec__', 'spam']) def test_6(self): hier = [ @@ -222,14 +222,15 @@ class TestPkg(unittest.TestCase): import t6 self.assertEqual(fixdir(dir(t6)), ['__all__', '__cached__', '__doc__', '__file__', - '__loader__', '__name__', '__package__', '__path__']) + '__loader__', '__name__', '__package__', '__path__', + '__spec__']) s = """ import t6 from t6 import * self.assertEqual(fixdir(dir(t6)), ['__all__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', - '__path__', 'eggs', 'ham', 'spam']) + '__path__', '__spec__', 'eggs', 'ham', 'spam']) self.assertEqual(dir(), ['eggs', 'ham', 'self', 'spam', 't6']) """ self.run_code(s) @@ -256,18 +257,19 @@ class TestPkg(unittest.TestCase): import t7 as tas self.assertEqual(fixdir(dir(tas)), ['__cached__', '__doc__', '__file__', '__loader__', - '__name__', '__package__', '__path__']) + '__name__', '__package__', '__path__', '__spec__']) self.assertFalse(t7) from t7 import sub as subpar self.assertEqual(fixdir(dir(subpar)), ['__cached__', '__doc__', '__file__', '__loader__', - '__name__', '__package__', '__path__']) + '__name__', '__package__', '__path__', '__spec__']) self.assertFalse(t7) self.assertFalse(sub) from t7.sub import subsub as subsubsub self.assertEqual(fixdir(dir(subsubsub)), ['__cached__', '__doc__', '__file__', '__loader__', - '__name__', '__package__', '__path__', 'spam']) + '__name__', '__package__', '__path__', '__spec__', + 'spam']) self.assertFalse(t7) self.assertFalse(sub) self.assertFalse(subsub) diff --git a/Lib/test/test_pkgimport.py b/Lib/test/test_pkgimport.py index a8426b5c9a..370b2aae28 100644 --- a/Lib/test/test_pkgimport.py +++ b/Lib/test/test_pkgimport.py @@ -6,7 +6,7 @@ import random import tempfile import unittest -from imp import cache_from_source +from importlib.util import cache_from_source from test.support import run_unittest, create_empty_file class TestImport(unittest.TestCase): diff --git a/Lib/test/test_pkgutil.py b/Lib/test/test_pkgutil.py index fd0661450a..aaa9d8df6b 100644 --- a/Lib/test/test_pkgutil.py +++ b/Lib/test/test_pkgutil.py @@ -1,12 +1,13 @@ from test.support import run_unittest, unload, check_warnings import unittest import sys -import imp import importlib +from importlib.util import spec_from_file_location import pkgutil import os import os.path import tempfile +import types import shutil import zipfile @@ -103,23 +104,20 @@ class PkgutilTests(unittest.TestCase): class PkgutilPEP302Tests(unittest.TestCase): class MyTestLoader(object): - def load_module(self, fullname): - # Create an empty module - mod = sys.modules.setdefault(fullname, imp.new_module(fullname)) - mod.__file__ = "<%s>" % self.__class__.__name__ - mod.__loader__ = self - # Make it a package - mod.__path__ = [] + def exec_module(self, mod): # Count how many times the module is reloaded - mod.__dict__['loads'] = mod.__dict__.get('loads',0) + 1 - return mod + mod.__dict__['loads'] = mod.__dict__.get('loads', 0) + 1 def get_data(self, path): return "Hello, world!" class MyTestImporter(object): - def find_module(self, fullname, path=None): - return PkgutilPEP302Tests.MyTestLoader() + def find_spec(self, fullname, path=None, target=None): + loader = PkgutilPEP302Tests.MyTestLoader() + return spec_from_file_location(fullname, + '<%s>' % loader.__class__.__name__, + loader=loader, + submodule_search_locations=[]) def setUp(self): sys.meta_path.insert(0, self.MyTestImporter()) @@ -200,6 +198,8 @@ class ExtendPathTests(unittest.TestCase): dirname = self.create_init(pkgname) pathitem = os.path.join(dirname, pkgname) fullname = '{}.{}'.format(pkgname, modname) + sys.modules.pop(fullname, None) + sys.modules.pop(pkgname, None) try: self.create_submodule(dirname, pkgname, modname, 0) @@ -208,11 +208,19 @@ class ExtendPathTests(unittest.TestCase): importers = list(iter_importers(fullname)) expected_importer = get_importer(pathitem) for finder in importers: + spec = pkgutil._get_spec(finder, fullname) + loader = spec.loader + try: + loader = loader.loader + except AttributeError: + # For now we still allow raw loaders from + # find_module(). + pass self.assertIsInstance(finder, importlib.machinery.FileFinder) self.assertEqual(finder, expected_importer) - self.assertIsInstance(finder.find_module(fullname), + self.assertIsInstance(loader, importlib.machinery.SourceFileLoader) - self.assertIsNone(finder.find_module(pkgname)) + self.assertIsNone(pkgutil._get_spec(finder, pkgname)) with self.assertRaises(ImportError): list(iter_importers('invalid.module')) @@ -222,8 +230,11 @@ class ExtendPathTests(unittest.TestCase): finally: shutil.rmtree(dirname) del sys.path[0] - del sys.modules['spam'] - del sys.modules['spam.eggs'] + try: + del sys.modules['spam'] + del sys.modules['spam.eggs'] + except KeyError: + pass def test_mixed_namespace(self): diff --git a/Lib/test/test_plistlib.py b/Lib/test/test_plistlib.py index 9e86b3de64..bb7cc159b4 100644 --- a/Lib/test/test_plistlib.py +++ b/Lib/test/test_plistlib.py @@ -1,94 +1,95 @@ -# Copyright (C) 2003 Python Software Foundation +# Copyright (C) 2003-2013 Python Software Foundation import unittest import plistlib import os import datetime +import codecs +import binascii +import collections +import struct from test import support +from io import BytesIO +ALL_FORMATS=(plistlib.FMT_XML, plistlib.FMT_BINARY) -# This test data was generated through Cocoa's NSDictionary class -TESTDATA = b"""<?xml version="1.0" encoding="UTF-8"?> -<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" \ -"http://www.apple.com/DTDs/PropertyList-1.0.dtd"> -<plist version="1.0"> -<dict> - <key>aDate</key> - <date>2004-10-26T10:33:33Z</date> - <key>aDict</key> - <dict> - <key>aFalseValue</key> - <false/> - <key>aTrueValue</key> - <true/> - <key>aUnicodeValue</key> - <string>M\xc3\xa4ssig, Ma\xc3\x9f</string> - <key>anotherString</key> - <string><hello & 'hi' there!></string> - <key>deeperDict</key> - <dict> - <key>a</key> - <integer>17</integer> - <key>b</key> - <real>32.5</real> - <key>c</key> - <array> - <integer>1</integer> - <integer>2</integer> - <string>text</string> - </array> - </dict> - </dict> - <key>aFloat</key> - <real>0.5</real> - <key>aList</key> - <array> - <string>A</string> - <string>B</string> - <integer>12</integer> - <real>32.5</real> - <array> - <integer>1</integer> - <integer>2</integer> - <integer>3</integer> - </array> - </array> - <key>aString</key> - <string>Doodah</string> - <key>anEmptyDict</key> - <dict/> - <key>anEmptyList</key> - <array/> - <key>anInt</key> - <integer>728</integer> - <key>nestedData</key> - <array> - <data> - PGxvdHMgb2YgYmluYXJ5IGd1bms+AAECAzxsb3RzIG9mIGJpbmFyeSBndW5r - PgABAgM8bG90cyBvZiBiaW5hcnkgZ3Vuaz4AAQIDPGxvdHMgb2YgYmluYXJ5 - IGd1bms+AAECAzxsb3RzIG9mIGJpbmFyeSBndW5rPgABAgM8bG90cyBvZiBi - aW5hcnkgZ3Vuaz4AAQIDPGxvdHMgb2YgYmluYXJ5IGd1bms+AAECAzxsb3Rz - IG9mIGJpbmFyeSBndW5rPgABAgM8bG90cyBvZiBiaW5hcnkgZ3Vuaz4AAQID - PGxvdHMgb2YgYmluYXJ5IGd1bms+AAECAw== - </data> - </array> - <key>someData</key> - <data> - PGJpbmFyeSBndW5rPg== - </data> - <key>someMoreData</key> - <data> - PGxvdHMgb2YgYmluYXJ5IGd1bms+AAECAzxsb3RzIG9mIGJpbmFyeSBndW5rPgABAgM8 - bG90cyBvZiBiaW5hcnkgZ3Vuaz4AAQIDPGxvdHMgb2YgYmluYXJ5IGd1bms+AAECAzxs - b3RzIG9mIGJpbmFyeSBndW5rPgABAgM8bG90cyBvZiBiaW5hcnkgZ3Vuaz4AAQIDPGxv - dHMgb2YgYmluYXJ5IGd1bms+AAECAzxsb3RzIG9mIGJpbmFyeSBndW5rPgABAgM8bG90 - cyBvZiBiaW5hcnkgZ3Vuaz4AAQIDPGxvdHMgb2YgYmluYXJ5IGd1bms+AAECAw== - </data> - <key>\xc3\x85benraa</key> - <string>That was a unicode key.</string> -</dict> -</plist> -""".replace(b" " * 8, b"\t") # Apple as well as plistlib.py output hard tabs +# The testdata is generated using Mac/Tools/plistlib_generate_testdata.py +# (which using PyObjC to control the Cocoa classes for generating plists) +TESTDATA={ + plistlib.FMT_XML: binascii.a2b_base64(b''' + PD94bWwgdmVyc2lvbj0iMS4wIiBlbmNvZGluZz0iVVRGLTgiPz4KPCFET0NU + WVBFIHBsaXN0IFBVQkxJQyAiLS8vQXBwbGUvL0RURCBQTElTVCAxLjAvL0VO + IiAiaHR0cDovL3d3dy5hcHBsZS5jb20vRFREcy9Qcm9wZXJ0eUxpc3QtMS4w + LmR0ZCI+CjxwbGlzdCB2ZXJzaW9uPSIxLjAiPgo8ZGljdD4KCTxrZXk+YUJp + Z0ludDwva2V5PgoJPGludGVnZXI+OTIyMzM3MjAzNjg1NDc3NTc2NDwvaW50 + ZWdlcj4KCTxrZXk+YUJpZ0ludDI8L2tleT4KCTxpbnRlZ2VyPjkyMjMzNzIw + MzY4NTQ3NzU4NTI8L2ludGVnZXI+Cgk8a2V5PmFEYXRlPC9rZXk+Cgk8ZGF0 + ZT4yMDA0LTEwLTI2VDEwOjMzOjMzWjwvZGF0ZT4KCTxrZXk+YURpY3Q8L2tl + eT4KCTxkaWN0PgoJCTxrZXk+YUZhbHNlVmFsdWU8L2tleT4KCQk8ZmFsc2Uv + PgoJCTxrZXk+YVRydWVWYWx1ZTwva2V5PgoJCTx0cnVlLz4KCQk8a2V5PmFV + bmljb2RlVmFsdWU8L2tleT4KCQk8c3RyaW5nPk3DpHNzaWcsIE1hw588L3N0 + cmluZz4KCQk8a2V5PmFub3RoZXJTdHJpbmc8L2tleT4KCQk8c3RyaW5nPiZs + dDtoZWxsbyAmYW1wOyAnaGknIHRoZXJlISZndDs8L3N0cmluZz4KCQk8a2V5 + PmRlZXBlckRpY3Q8L2tleT4KCQk8ZGljdD4KCQkJPGtleT5hPC9rZXk+CgkJ + CTxpbnRlZ2VyPjE3PC9pbnRlZ2VyPgoJCQk8a2V5PmI8L2tleT4KCQkJPHJl + YWw+MzIuNTwvcmVhbD4KCQkJPGtleT5jPC9rZXk+CgkJCTxhcnJheT4KCQkJ + CTxpbnRlZ2VyPjE8L2ludGVnZXI+CgkJCQk8aW50ZWdlcj4yPC9pbnRlZ2Vy + PgoJCQkJPHN0cmluZz50ZXh0PC9zdHJpbmc+CgkJCTwvYXJyYXk+CgkJPC9k + aWN0PgoJPC9kaWN0PgoJPGtleT5hRmxvYXQ8L2tleT4KCTxyZWFsPjAuNTwv + cmVhbD4KCTxrZXk+YUxpc3Q8L2tleT4KCTxhcnJheT4KCQk8c3RyaW5nPkE8 + L3N0cmluZz4KCQk8c3RyaW5nPkI8L3N0cmluZz4KCQk8aW50ZWdlcj4xMjwv + aW50ZWdlcj4KCQk8cmVhbD4zMi41PC9yZWFsPgoJCTxhcnJheT4KCQkJPGlu + dGVnZXI+MTwvaW50ZWdlcj4KCQkJPGludGVnZXI+MjwvaW50ZWdlcj4KCQkJ + PGludGVnZXI+MzwvaW50ZWdlcj4KCQk8L2FycmF5PgoJPC9hcnJheT4KCTxr + ZXk+YU5lZ2F0aXZlQmlnSW50PC9rZXk+Cgk8aW50ZWdlcj4tODAwMDAwMDAw + MDA8L2ludGVnZXI+Cgk8a2V5PmFOZWdhdGl2ZUludDwva2V5PgoJPGludGVn + ZXI+LTU8L2ludGVnZXI+Cgk8a2V5PmFTdHJpbmc8L2tleT4KCTxzdHJpbmc+ + RG9vZGFoPC9zdHJpbmc+Cgk8a2V5PmFuRW1wdHlEaWN0PC9rZXk+Cgk8ZGlj + dC8+Cgk8a2V5PmFuRW1wdHlMaXN0PC9rZXk+Cgk8YXJyYXkvPgoJPGtleT5h + bkludDwva2V5PgoJPGludGVnZXI+NzI4PC9pbnRlZ2VyPgoJPGtleT5uZXN0 + ZWREYXRhPC9rZXk+Cgk8YXJyYXk+CgkJPGRhdGE+CgkJUEd4dmRITWdiMlln + WW1sdVlYSjVJR2QxYm1zK0FBRUNBenhzYjNSeklHOW1JR0pwYm1GeWVTQm5k + VzVyCgkJUGdBQkFnTThiRzkwY3lCdlppQmlhVzVoY25rZ1ozVnVhejRBQVFJ + RFBHeHZkSE1nYjJZZ1ltbHVZWEo1CgkJSUdkMWJtcytBQUVDQXp4c2IzUnpJ + RzltSUdKcGJtRnllU0JuZFc1clBnQUJBZ004Ykc5MGN5QnZaaUJpCgkJYVc1 + aGNua2daM1Z1YXo0QUFRSURQR3h2ZEhNZ2IyWWdZbWx1WVhKNUlHZDFibXMr + QUFFQ0F6eHNiM1J6CgkJSUc5bUlHSnBibUZ5ZVNCbmRXNXJQZ0FCQWdNOGJH + OTBjeUJ2WmlCaWFXNWhjbmtnWjNWdWF6NEFBUUlECgkJUEd4dmRITWdiMlln + WW1sdVlYSjVJR2QxYm1zK0FBRUNBdz09CgkJPC9kYXRhPgoJPC9hcnJheT4K + CTxrZXk+c29tZURhdGE8L2tleT4KCTxkYXRhPgoJUEdKcGJtRnllU0JuZFc1 + clBnPT0KCTwvZGF0YT4KCTxrZXk+c29tZU1vcmVEYXRhPC9rZXk+Cgk8ZGF0 + YT4KCVBHeHZkSE1nYjJZZ1ltbHVZWEo1SUdkMWJtcytBQUVDQXp4c2IzUnpJ + RzltSUdKcGJtRnllU0JuZFc1clBnQUJBZ004CgliRzkwY3lCdlppQmlhVzVo + Y25rZ1ozVnVhejRBQVFJRFBHeHZkSE1nYjJZZ1ltbHVZWEo1SUdkMWJtcytB + QUVDQXp4cwoJYjNSeklHOW1JR0pwYm1GeWVTQm5kVzVyUGdBQkFnTThiRzkw + Y3lCdlppQmlhVzVoY25rZ1ozVnVhejRBQVFJRFBHeHYKCWRITWdiMllnWW1s + dVlYSjVJR2QxYm1zK0FBRUNBenhzYjNSeklHOW1JR0pwYm1GeWVTQm5kVzVy + UGdBQkFnTThiRzkwCgljeUJ2WmlCaWFXNWhjbmtnWjNWdWF6NEFBUUlEUEd4 + dmRITWdiMllnWW1sdVlYSjVJR2QxYm1zK0FBRUNBdz09Cgk8L2RhdGE+Cgk8 + a2V5PsOFYmVucmFhPC9rZXk+Cgk8c3RyaW5nPlRoYXQgd2FzIGEgdW5pY29k + ZSBrZXkuPC9zdHJpbmc+CjwvZGljdD4KPC9wbGlzdD4K'''), + plistlib.FMT_BINARY: binascii.a2b_base64(b''' + YnBsaXN0MDDfEBABAgMEBQYHCAkKCwwNDg8QERITFCgpLzAxMjM0NTc2OFdh + QmlnSW50WGFCaWdJbnQyVWFEYXRlVWFEaWN0VmFGbG9hdFVhTGlzdF8QD2FO + ZWdhdGl2ZUJpZ0ludFxhTmVnYXRpdmVJbnRXYVN0cmluZ1thbkVtcHR5RGlj + dFthbkVtcHR5TGlzdFVhbkludFpuZXN0ZWREYXRhWHNvbWVEYXRhXHNvbWVN + b3JlRGF0YWcAxQBiAGUAbgByAGEAYRN/////////1BQAAAAAAAAAAIAAAAAA + AAAsM0GcuX30AAAA1RUWFxgZGhscHR5bYUZhbHNlVmFsdWVaYVRydWVWYWx1 + ZV1hVW5pY29kZVZhbHVlXWFub3RoZXJTdHJpbmdaZGVlcGVyRGljdAgJawBN + AOQAcwBzAGkAZwAsACAATQBhAN9fEBU8aGVsbG8gJiAnaGknIHRoZXJlIT7T + HyAhIiMkUWFRYlFjEBEjQEBAAAAAAACjJSYnEAEQAlR0ZXh0Iz/gAAAAAAAA + pSorLCMtUUFRQhAMoyUmLhADE////+1foOAAE//////////7VkRvb2RhaNCg + EQLYoTZPEPo8bG90cyBvZiBiaW5hcnkgZ3Vuaz4AAQIDPGxvdHMgb2YgYmlu + YXJ5IGd1bms+AAECAzxsb3RzIG9mIGJpbmFyeSBndW5rPgABAgM8bG90cyBv + ZiBiaW5hcnkgZ3Vuaz4AAQIDPGxvdHMgb2YgYmluYXJ5IGd1bms+AAECAzxs + b3RzIG9mIGJpbmFyeSBndW5rPgABAgM8bG90cyBvZiBiaW5hcnkgZ3Vuaz4A + AQIDPGxvdHMgb2YgYmluYXJ5IGd1bms+AAECAzxsb3RzIG9mIGJpbmFyeSBn + dW5rPgABAgM8bG90cyBvZiBiaW5hcnkgZ3Vuaz4AAQIDTTxiaW5hcnkgZ3Vu + az5fEBdUaGF0IHdhcyBhIHVuaWNvZGUga2V5LgAIACsAMwA8AEIASABPAFUA + ZwB0AHwAiACUAJoApQCuALsAygDTAOQA7QD4AQQBDwEdASsBNgE3ATgBTwFn + AW4BcAFyAXQBdgF/AYMBhQGHAYwBlQGbAZ0BnwGhAaUBpwGwAbkBwAHBAcIB + xQHHAsQC0gAAAAAAAAIBAAAAAAAAADkAAAAAAAAAAAAAAAAAAALs'''), +} class TestPlistlib(unittest.TestCase): @@ -99,12 +100,16 @@ class TestPlistlib(unittest.TestCase): except: pass - def _create(self): + def _create(self, fmt=None): pl = dict( aString="Doodah", aList=["A", "B", 12, 32.5, [1, 2, 3]], aFloat = 0.5, anInt = 728, + aBigInt = 2 ** 63 - 44, + aBigInt2 = 2 ** 63 + 44, + aNegativeInt = -5, + aNegativeBigInt = -80000000000, aDict=dict( anotherString="<hello & 'hi' there!>", aUnicodeValue='M\xe4ssig, Ma\xdf', @@ -112,9 +117,9 @@ class TestPlistlib(unittest.TestCase): aFalseValue=False, deeperDict=dict(a=17, b=32.5, c=[1, 2, "text"]), ), - someData = plistlib.Data(b"<binary gunk>"), - someMoreData = plistlib.Data(b"<lots of binary gunk>\0\1\2\3" * 10), - nestedData = [plistlib.Data(b"<lots of binary gunk>\0\1\2\3" * 10)], + someData = b"<binary gunk>", + someMoreData = b"<lots of binary gunk>\0\1\2\3" * 10, + nestedData = [b"<lots of binary gunk>\0\1\2\3" * 10], aDate = datetime.datetime(2004, 10, 26, 10, 33, 33), anEmptyDict = dict(), anEmptyList = list() @@ -129,49 +134,215 @@ class TestPlistlib(unittest.TestCase): def test_io(self): pl = self._create() - plistlib.writePlist(pl, support.TESTFN) - pl2 = plistlib.readPlist(support.TESTFN) + with open(support.TESTFN, 'wb') as fp: + plistlib.dump(pl, fp) + + with open(support.TESTFN, 'rb') as fp: + pl2 = plistlib.load(fp) + self.assertEqual(dict(pl), dict(pl2)) + self.assertRaises(AttributeError, plistlib.dump, pl, 'filename') + self.assertRaises(AttributeError, plistlib.load, 'filename') + + def test_invalid_type(self): + pl = [ object() ] + + for fmt in ALL_FORMATS: + with self.subTest(fmt=fmt): + self.assertRaises(TypeError, plistlib.dumps, pl, fmt=fmt) + + def test_int(self): + for pl in [0, 2**8-1, 2**8, 2**16-1, 2**16, 2**32-1, 2**32, + 2**63-1, 2**64-1, 1, -2**63]: + for fmt in ALL_FORMATS: + with self.subTest(pl=pl, fmt=fmt): + data = plistlib.dumps(pl, fmt=fmt) + pl2 = plistlib.loads(data) + self.assertIsInstance(pl2, int) + self.assertEqual(pl, pl2) + data2 = plistlib.dumps(pl2, fmt=fmt) + self.assertEqual(data, data2) + + for fmt in ALL_FORMATS: + for pl in (2 ** 64 + 1, 2 ** 127-1, -2**64, -2 ** 127): + with self.subTest(pl=pl, fmt=fmt): + self.assertRaises(OverflowError, plistlib.dumps, + pl, fmt=fmt) + def test_bytes(self): pl = self._create() - data = plistlib.writePlistToBytes(pl) - pl2 = plistlib.readPlistFromBytes(data) + data = plistlib.dumps(pl) + pl2 = plistlib.loads(data) + self.assertNotIsInstance(pl, plistlib._InternalDict) self.assertEqual(dict(pl), dict(pl2)) - data2 = plistlib.writePlistToBytes(pl2) + data2 = plistlib.dumps(pl2) self.assertEqual(data, data2) def test_indentation_array(self): - data = [[[[[[[[{'test': plistlib.Data(b'aaaaaa')}]]]]]]]] - self.assertEqual(plistlib.readPlistFromBytes(plistlib.writePlistToBytes(data)), data) + data = [[[[[[[[{'test': b'aaaaaa'}]]]]]]]] + self.assertEqual(plistlib.loads(plistlib.dumps(data)), data) def test_indentation_dict(self): - data = {'1': {'2': {'3': {'4': {'5': {'6': {'7': {'8': {'9': plistlib.Data(b'aaaaaa')}}}}}}}}} - self.assertEqual(plistlib.readPlistFromBytes(plistlib.writePlistToBytes(data)), data) + data = {'1': {'2': {'3': {'4': {'5': {'6': {'7': {'8': {'9': b'aaaaaa'}}}}}}}}} + self.assertEqual(plistlib.loads(plistlib.dumps(data)), data) def test_indentation_dict_mix(self): - data = {'1': {'2': [{'3': [[[[[{'test': plistlib.Data(b'aaaaaa')}]]]]]}]}} - self.assertEqual(plistlib.readPlistFromBytes(plistlib.writePlistToBytes(data)), data) + data = {'1': {'2': [{'3': [[[[[{'test': b'aaaaaa'}]]]]]}]}} + self.assertEqual(plistlib.loads(plistlib.dumps(data)), data) def test_appleformatting(self): - pl = plistlib.readPlistFromBytes(TESTDATA) - data = plistlib.writePlistToBytes(pl) - self.assertEqual(data, TESTDATA, - "generated data was not identical to Apple's output") + for use_builtin_types in (True, False): + for fmt in ALL_FORMATS: + with self.subTest(fmt=fmt, use_builtin_types=use_builtin_types): + pl = plistlib.loads(TESTDATA[fmt], + use_builtin_types=use_builtin_types) + data = plistlib.dumps(pl, fmt=fmt) + self.assertEqual(data, TESTDATA[fmt], + "generated data was not identical to Apple's output") + def test_appleformattingfromliteral(self): - pl = self._create() - pl2 = plistlib.readPlistFromBytes(TESTDATA) - self.assertEqual(dict(pl), dict(pl2), - "generated data was not identical to Apple's output") + self.maxDiff = None + for fmt in ALL_FORMATS: + with self.subTest(fmt=fmt): + pl = self._create(fmt=fmt) + pl2 = plistlib.loads(TESTDATA[fmt]) + self.assertEqual(dict(pl), dict(pl2), + "generated data was not identical to Apple's output") def test_bytesio(self): - from io import BytesIO - b = BytesIO() - pl = self._create() - plistlib.writePlist(pl, b) - pl2 = plistlib.readPlist(BytesIO(b.getvalue())) - self.assertEqual(dict(pl), dict(pl2)) + for fmt in ALL_FORMATS: + with self.subTest(fmt=fmt): + b = BytesIO() + pl = self._create(fmt=fmt) + plistlib.dump(pl, b, fmt=fmt) + pl2 = plistlib.load(BytesIO(b.getvalue())) + self.assertEqual(dict(pl), dict(pl2)) + + def test_keysort_bytesio(self): + pl = collections.OrderedDict() + pl['b'] = 1 + pl['a'] = 2 + pl['c'] = 3 + + for fmt in ALL_FORMATS: + for sort_keys in (False, True): + with self.subTest(fmt=fmt, sort_keys=sort_keys): + b = BytesIO() + + plistlib.dump(pl, b, fmt=fmt, sort_keys=sort_keys) + pl2 = plistlib.load(BytesIO(b.getvalue()), + dict_type=collections.OrderedDict) + + self.assertEqual(dict(pl), dict(pl2)) + if sort_keys: + self.assertEqual(list(pl2.keys()), ['a', 'b', 'c']) + else: + self.assertEqual(list(pl2.keys()), ['b', 'a', 'c']) + + def test_keysort(self): + pl = collections.OrderedDict() + pl['b'] = 1 + pl['a'] = 2 + pl['c'] = 3 + + for fmt in ALL_FORMATS: + for sort_keys in (False, True): + with self.subTest(fmt=fmt, sort_keys=sort_keys): + data = plistlib.dumps(pl, fmt=fmt, sort_keys=sort_keys) + pl2 = plistlib.loads(data, dict_type=collections.OrderedDict) + + self.assertEqual(dict(pl), dict(pl2)) + if sort_keys: + self.assertEqual(list(pl2.keys()), ['a', 'b', 'c']) + else: + self.assertEqual(list(pl2.keys()), ['b', 'a', 'c']) + + def test_keys_no_string(self): + pl = { 42: 'aNumber' } + + for fmt in ALL_FORMATS: + with self.subTest(fmt=fmt): + self.assertRaises(TypeError, plistlib.dumps, pl, fmt=fmt) + + b = BytesIO() + self.assertRaises(TypeError, plistlib.dump, pl, b, fmt=fmt) + + def test_skipkeys(self): + pl = { + 42: 'aNumber', + 'snake': 'aWord', + } + + for fmt in ALL_FORMATS: + with self.subTest(fmt=fmt): + data = plistlib.dumps( + pl, fmt=fmt, skipkeys=True, sort_keys=False) + + pl2 = plistlib.loads(data) + self.assertEqual(pl2, {'snake': 'aWord'}) + + fp = BytesIO() + plistlib.dump( + pl, fp, fmt=fmt, skipkeys=True, sort_keys=False) + data = fp.getvalue() + pl2 = plistlib.loads(fp.getvalue()) + self.assertEqual(pl2, {'snake': 'aWord'}) + + def test_tuple_members(self): + pl = { + 'first': (1, 2), + 'second': (1, 2), + 'third': (3, 4), + } + + for fmt in ALL_FORMATS: + with self.subTest(fmt=fmt): + data = plistlib.dumps(pl, fmt=fmt) + pl2 = plistlib.loads(data) + self.assertEqual(pl2, { + 'first': [1, 2], + 'second': [1, 2], + 'third': [3, 4], + }) + self.assertIsNot(pl2['first'], pl2['second']) + + def test_list_members(self): + pl = { + 'first': [1, 2], + 'second': [1, 2], + 'third': [3, 4], + } + + for fmt in ALL_FORMATS: + with self.subTest(fmt=fmt): + data = plistlib.dumps(pl, fmt=fmt) + pl2 = plistlib.loads(data) + self.assertEqual(pl2, { + 'first': [1, 2], + 'second': [1, 2], + 'third': [3, 4], + }) + self.assertIsNot(pl2['first'], pl2['second']) + + def test_dict_members(self): + pl = { + 'first': {'a': 1}, + 'second': {'a': 1}, + 'third': {'b': 2 }, + } + + for fmt in ALL_FORMATS: + with self.subTest(fmt=fmt): + data = plistlib.dumps(pl, fmt=fmt) + pl2 = plistlib.loads(data) + self.assertEqual(pl2, { + 'first': {'a': 1}, + 'second': {'a': 1}, + 'third': {'b': 2 }, + }) + self.assertIsNot(pl2['first'], pl2['second']) def test_controlcharacters(self): for i in range(128): @@ -179,25 +350,27 @@ class TestPlistlib(unittest.TestCase): testString = "string containing %s" % c if i >= 32 or c in "\r\n\t": # \r, \n and \t are the only legal control chars in XML - plistlib.writePlistToBytes(testString) + plistlib.dumps(testString, fmt=plistlib.FMT_XML) else: self.assertRaises(ValueError, - plistlib.writePlistToBytes, + plistlib.dumps, testString) def test_nondictroot(self): - test1 = "abc" - test2 = [1, 2, 3, "abc"] - result1 = plistlib.readPlistFromBytes(plistlib.writePlistToBytes(test1)) - result2 = plistlib.readPlistFromBytes(plistlib.writePlistToBytes(test2)) - self.assertEqual(test1, result1) - self.assertEqual(test2, result2) + for fmt in ALL_FORMATS: + with self.subTest(fmt=fmt): + test1 = "abc" + test2 = [1, 2, 3, "abc"] + result1 = plistlib.loads(plistlib.dumps(test1, fmt=fmt)) + result2 = plistlib.loads(plistlib.dumps(test2, fmt=fmt)) + self.assertEqual(test1, result1) + self.assertEqual(test2, result2) def test_invalidarray(self): for i in ["<key>key inside an array</key>", "<key>key inside an array2</key><real>3</real>", "<true/><key>key inside an array3</key>"]: - self.assertRaises(ValueError, plistlib.readPlistFromBytes, + self.assertRaises(ValueError, plistlib.loads, ("<plist><array>%s</array></plist>"%i).encode()) def test_invaliddict(self): @@ -206,22 +379,130 @@ class TestPlistlib(unittest.TestCase): "<string>missing key</string>", "<key>k1</key><string>v1</string><real>5.3</real>" "<key>k1</key><key>k2</key><string>double key</string>"]: - self.assertRaises(ValueError, plistlib.readPlistFromBytes, + self.assertRaises(ValueError, plistlib.loads, ("<plist><dict>%s</dict></plist>"%i).encode()) - self.assertRaises(ValueError, plistlib.readPlistFromBytes, + self.assertRaises(ValueError, plistlib.loads, ("<plist><array><dict>%s</dict></array></plist>"%i).encode()) def test_invalidinteger(self): - self.assertRaises(ValueError, plistlib.readPlistFromBytes, + self.assertRaises(ValueError, plistlib.loads, b"<plist><integer>not integer</integer></plist>") def test_invalidreal(self): - self.assertRaises(ValueError, plistlib.readPlistFromBytes, + self.assertRaises(ValueError, plistlib.loads, b"<plist><integer>not real</integer></plist>") + def test_xml_encodings(self): + base = TESTDATA[plistlib.FMT_XML] + + for xml_encoding, encoding, bom in [ + (b'utf-8', 'utf-8', codecs.BOM_UTF8), + (b'utf-16', 'utf-16-le', codecs.BOM_UTF16_LE), + (b'utf-16', 'utf-16-be', codecs.BOM_UTF16_BE), + # Expat does not support UTF-32 + #(b'utf-32', 'utf-32-le', codecs.BOM_UTF32_LE), + #(b'utf-32', 'utf-32-be', codecs.BOM_UTF32_BE), + ]: + + pl = self._create(fmt=plistlib.FMT_XML) + with self.subTest(encoding=encoding): + data = base.replace(b'UTF-8', xml_encoding) + data = bom + data.decode('utf-8').encode(encoding) + pl2 = plistlib.loads(data) + self.assertEqual(dict(pl), dict(pl2)) + + +class TestPlistlibDeprecated(unittest.TestCase): + def test_io_deprecated(self): + pl_in = { + 'key': 42, + 'sub': { + 'key': 9, + 'alt': 'value', + 'data': b'buffer', + } + } + pl_out = plistlib._InternalDict({ + 'key': 42, + 'sub': plistlib._InternalDict({ + 'key': 9, + 'alt': 'value', + 'data': plistlib.Data(b'buffer'), + }) + }) + + self.addCleanup(support.unlink, support.TESTFN) + with self.assertWarns(DeprecationWarning): + plistlib.writePlist(pl_in, support.TESTFN) + + with self.assertWarns(DeprecationWarning): + pl2 = plistlib.readPlist(support.TESTFN) + + self.assertEqual(pl_out, pl2) + + os.unlink(support.TESTFN) + + with open(support.TESTFN, 'wb') as fp: + with self.assertWarns(DeprecationWarning): + plistlib.writePlist(pl_in, fp) + + with open(support.TESTFN, 'rb') as fp: + with self.assertWarns(DeprecationWarning): + pl2 = plistlib.readPlist(fp) + + self.assertEqual(pl_out, pl2) + + def test_bytes_deprecated(self): + pl = { + 'key': 42, + 'sub': { + 'key': 9, + 'alt': 'value', + 'data': b'buffer', + } + } + with self.assertWarns(DeprecationWarning): + data = plistlib.writePlistToBytes(pl) + + with self.assertWarns(DeprecationWarning): + pl2 = plistlib.readPlistFromBytes(data) + + self.assertIsInstance(pl2, plistlib._InternalDict) + self.assertEqual(pl2, plistlib._InternalDict( + key=42, + sub=plistlib._InternalDict( + key=9, + alt='value', + data=plistlib.Data(b'buffer'), + ) + )) + + with self.assertWarns(DeprecationWarning): + data2 = plistlib.writePlistToBytes(pl2) + self.assertEqual(data, data2) + + def test_dataobject_deprecated(self): + in_data = { 'key': plistlib.Data(b'hello') } + out_data = { 'key': b'hello' } + + buf = plistlib.dumps(in_data) + + cur = plistlib.loads(buf) + self.assertEqual(cur, out_data) + self.assertNotEqual(cur, in_data) + + cur = plistlib.loads(buf, use_builtin_types=False) + self.assertNotEqual(cur, out_data) + self.assertEqual(cur, in_data) + + with self.assertWarns(DeprecationWarning): + cur = plistlib.readPlistFromBytes(buf) + self.assertNotEqual(cur, out_data) + self.assertEqual(cur, in_data) + def test_main(): - support.run_unittest(TestPlistlib) + support.run_unittest(TestPlistlib, TestPlistlibDeprecated) if __name__ == '__main__': diff --git a/Lib/test/test_poll.py b/Lib/test/test_poll.py index a07a7199c1..a0a332bdbc 100644 --- a/Lib/test/test_poll.py +++ b/Lib/test/test_poll.py @@ -1,6 +1,7 @@ # Test case for the os.poll() function import os +import subprocess import random import select try: @@ -14,7 +15,7 @@ from test.support import TESTFN, run_unittest, reap_threads, cpython_only try: select.poll except AttributeError: - raise unittest.SkipTest("select.poll not defined -- skipping test_poll") + raise unittest.SkipTest("select.poll not defined") def find_ready_matching(ready, flag): @@ -75,13 +76,11 @@ class PollTests(unittest.TestCase): self.assertEqual(bufs, [MSG] * NUM_PIPES) - def poll_unit_tests(self): + def test_poll_unit_tests(self): # returns NVAL for invalid file descriptor - FD = 42 - try: - os.close(FD) - except OSError: - pass + FD, w = os.pipe() + os.close(FD) + os.close(w) p = select.poll() p.register(FD) r = p.poll() @@ -124,7 +123,9 @@ class PollTests(unittest.TestCase): def test_poll2(self): cmd = 'for i in 0 1 2 3 4 5 6 7 8 9; do echo testing...; sleep 1; done' - p = os.popen(cmd, 'r') + proc = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, + bufsize=0) + p = proc.stdout pollster = select.poll() pollster.register( p, select.POLLIN ) for tout in (0, 1000, 2000, 4000, 8000, 16000) + (-1,)*10: @@ -134,7 +135,7 @@ class PollTests(unittest.TestCase): fd, flags = fdlist[0] if flags & select.POLLHUP: line = p.readline() - if line != "": + if line != b"": self.fail('error: pipe seems to be closed, but still returns data') continue @@ -142,6 +143,7 @@ class PollTests(unittest.TestCase): line = p.readline() if not line: break + self.assertEqual(line, b'testing...\n') continue else: self.fail('Unexpected return value from select.poll: %s' % fdlist) diff --git a/Lib/test/test_poplib.py b/Lib/test/test_poplib.py index 5809be66bb..d076fc1b9c 100644 --- a/Lib/test/test_poplib.py +++ b/Lib/test/test_poplib.py @@ -18,6 +18,19 @@ threading = test_support.import_module('threading') HOST = test_support.HOST PORT = 0 +SUPPORTS_SSL = False +if hasattr(poplib, 'POP3_SSL'): + import ssl + from ssl import HAS_SNI + + SUPPORTS_SSL = True + CERTFILE = os.path.join(os.path.dirname(__file__) or os.curdir, "keycert3.pem") + CAFILE = os.path.join(os.path.dirname(__file__) or os.curdir, "pycacert.pem") +else: + HAS_SNI = False + +requires_ssl = skipUnless(SUPPORTS_SSL, 'SSL not supported') + # the dummy data returned by server when LIST and RETR commands are issued LIST_RESP = b'1 1\r\n2 2\r\n3 3\r\n4 4\r\n5 5\r\n.\r\n' RETR_RESP = b"""From: postmaster@python.org\ @@ -33,11 +46,15 @@ line3\r\n\ class DummyPOP3Handler(asynchat.async_chat): + CAPAS = {'UIDL': [], 'IMPLEMENTATION': ['python-testlib-pop-server']} + def __init__(self, conn): asynchat.async_chat.__init__(self, conn) self.set_terminator(b"\r\n") self.in_buffer = [] self.push('+OK dummy pop3 server ready. <timestamp>') + self.tls_active = False + self.tls_starting = False def collect_incoming_data(self, data): self.in_buffer.append(data) @@ -112,6 +129,65 @@ class DummyPOP3Handler(asynchat.async_chat): self.push('+OK closing.') self.close_when_done() + def _get_capas(self): + _capas = dict(self.CAPAS) + if not self.tls_active and SUPPORTS_SSL: + _capas['STLS'] = [] + return _capas + + def cmd_capa(self, arg): + self.push('+OK Capability list follows') + if self._get_capas(): + for cap, params in self._get_capas().items(): + _ln = [cap] + if params: + _ln.extend(params) + self.push(' '.join(_ln)) + self.push('.') + + if SUPPORTS_SSL: + + def cmd_stls(self, arg): + if self.tls_active is False: + self.push('+OK Begin TLS negotiation') + tls_sock = ssl.wrap_socket(self.socket, certfile=CERTFILE, + server_side=True, + do_handshake_on_connect=False, + suppress_ragged_eofs=False) + self.del_channel() + self.set_socket(tls_sock) + self.tls_active = True + self.tls_starting = True + self.in_buffer = [] + self._do_tls_handshake() + else: + self.push('-ERR Command not permitted when TLS active') + + def _do_tls_handshake(self): + try: + self.socket.do_handshake() + except ssl.SSLError as err: + if err.args[0] in (ssl.SSL_ERROR_WANT_READ, + ssl.SSL_ERROR_WANT_WRITE): + return + elif err.args[0] == ssl.SSL_ERROR_EOF: + return self.handle_close() + raise + except OSError as err: + if err.args[0] == errno.ECONNABORTED: + return self.handle_close() + else: + self.tls_active = True + self.tls_starting = False + + def handle_read(self): + if self.tls_starting: + self._do_tls_handshake() + else: + try: + asynchat.async_chat.handle_read(self) + except ssl.SSLEOFError: + self.handle_close() class DummyPOP3Server(asyncore.dispatcher, threading.Thread): @@ -236,19 +312,43 @@ class TestPOP3Class(TestCase): self.client.uidl() self.client.uidl('foo') + def test_capa(self): + capa = self.client.capa() + self.assertTrue('IMPLEMENTATION' in capa.keys()) + def test_quit(self): resp = self.client.quit() self.assertTrue(resp) self.assertIsNone(self.client.sock) self.assertIsNone(self.client.file) + @requires_ssl + def test_stls_capa(self): + capa = self.client.capa() + self.assertTrue('STLS' in capa.keys()) + + @requires_ssl + def test_stls(self): + expected = b'+OK Begin TLS negotiation' + resp = self.client.stls() + self.assertEqual(resp, expected) + + @requires_ssl + @skipUnless(HAS_SNI, 'No SNI support in ssl module') + def test_stls_context(self): + expected = b'+OK Begin TLS negotiation' + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx.load_verify_locations(CAFILE) + ctx.verify_mode = ssl.CERT_REQUIRED + ctx.check_hostname = True + with self.assertRaises(ssl.CertificateError): + resp = self.client.stls(context=ctx) + self.client = poplib.POP3("localhost", self.server.port, timeout=3) + resp = self.client.stls(context=ctx) + self.assertEqual(resp, expected) -SUPPORTS_SSL = False -if hasattr(poplib, 'POP3_SSL'): - import ssl - SUPPORTS_SSL = True - CERTFILE = os.path.join(os.path.dirname(__file__) or os.curdir, "keycert.pem") +if SUPPORTS_SSL: class DummyPOP3_SSLHandler(DummyPOP3Handler): @@ -260,35 +360,13 @@ if hasattr(poplib, 'POP3_SSL'): self.del_channel() self.set_socket(ssl_socket) # Must try handshake before calling push() - self._ssl_accepting = True - self._do_ssl_handshake() + self.tls_active = True + self.tls_starting = True + self._do_tls_handshake() self.set_terminator(b"\r\n") self.in_buffer = [] self.push('+OK dummy pop3 server ready. <timestamp>') - def _do_ssl_handshake(self): - try: - self.socket.do_handshake() - except ssl.SSLError as err: - if err.args[0] in (ssl.SSL_ERROR_WANT_READ, - ssl.SSL_ERROR_WANT_WRITE): - return - elif err.args[0] == ssl.SSL_ERROR_EOF: - return self.handle_close() - raise - except socket.error as err: - if err.args[0] == errno.ECONNABORTED: - return self.handle_close() - else: - self._ssl_accepting = False - - def handle_read(self): - if self._ssl_accepting: - self._do_ssl_handshake() - else: - DummyPOP3Handler.handle_read(self) - -requires_ssl = skipUnless(SUPPORTS_SSL, 'SSL not supported') @requires_ssl class TestPOP3_SSLClass(TestPOP3Class): @@ -306,20 +384,60 @@ class TestPOP3_SSLClass(TestPOP3Class): def test_context(self): ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) self.assertRaises(ValueError, poplib.POP3_SSL, self.server.host, - self.server.port, keyfile=CERTFILE, context=ctx) + self.server.port, keyfile=CERTFILE, context=ctx) self.assertRaises(ValueError, poplib.POP3_SSL, self.server.host, - self.server.port, certfile=CERTFILE, context=ctx) + self.server.port, certfile=CERTFILE, context=ctx) self.assertRaises(ValueError, poplib.POP3_SSL, self.server.host, - self.server.port, keyfile=CERTFILE, - certfile=CERTFILE, context=ctx) + self.server.port, keyfile=CERTFILE, + certfile=CERTFILE, context=ctx) self.client.quit() self.client = poplib.POP3_SSL(self.server.host, self.server.port, - context=ctx) + context=ctx) self.assertIsInstance(self.client.sock, ssl.SSLSocket) self.assertIs(self.client.sock.context, ctx) self.assertTrue(self.client.noop().startswith(b'+OK')) + def test_stls(self): + self.assertRaises(poplib.error_proto, self.client.stls) + + test_stls_context = test_stls + + def test_stls_capa(self): + capa = self.client.capa() + self.assertFalse('STLS' in capa.keys()) + + +@requires_ssl +class TestPOP3_TLSClass(TestPOP3Class): + # repeat previous tests by using poplib.POP3.stls() + + def setUp(self): + self.server = DummyPOP3Server((HOST, PORT)) + self.server.start() + self.client = poplib.POP3(self.server.host, self.server.port, timeout=3) + self.client.stls() + + def tearDown(self): + if self.client.file is not None and self.client.sock is not None: + try: + self.client.quit() + except poplib.error_proto: + # happens in the test_too_long_lines case; the overlong + # response will be treated as response to QUIT and raise + # this exception + self.client.close() + self.server.stop() + + def test_stls(self): + self.assertRaises(poplib.error_proto, self.client.stls) + + test_stls_context = test_stls + + def test_stls_capa(self): + capa = self.client.capa() + self.assertFalse(b'STLS' in capa.keys()) + class TestTimeouts(TestCase): @@ -377,7 +495,7 @@ class TestTimeouts(TestCase): def test_main(): tests = [TestPOP3Class, TestTimeouts, - TestPOP3_SSLClass] + TestPOP3_SSLClass, TestPOP3_TLSClass] thread_info = test_support.threading_setup() try: test_support.run_unittest(*tests) diff --git a/Lib/test/test_posix.py b/Lib/test/test_posix.py index 9408e5c1b5..009f29da32 100644 --- a/Lib/test/test_posix.py +++ b/Lib/test/test_posix.py @@ -603,8 +603,10 @@ class PosixTester(unittest.TestCase): r, w = os.pipe2(os.O_CLOEXEC|os.O_NONBLOCK) self.addCleanup(os.close, r) self.addCleanup(os.close, w) - self.assertTrue(fcntl.fcntl(r, fcntl.F_GETFD) & fcntl.FD_CLOEXEC) - self.assertTrue(fcntl.fcntl(w, fcntl.F_GETFD) & fcntl.FD_CLOEXEC) + self.assertFalse(os.get_inheritable(r)) + self.assertFalse(os.get_inheritable(w)) + self.assertTrue(fcntl.fcntl(r, fcntl.F_GETFL) & os.O_NONBLOCK) + self.assertTrue(fcntl.fcntl(w, fcntl.F_GETFL) & os.O_NONBLOCK) # try reading from an empty pipe: this should fail, not block self.assertRaises(OSError, os.read, r, 1) # try a write big enough to fill-up the pipe: this should either @@ -652,7 +654,7 @@ class PosixTester(unittest.TestCase): self.assertEqual(st.st_flags | stat.UF_IMMUTABLE, new_st.st_flags) try: fd = open(target_file, 'w+') - except IOError as e: + except OSError as e: self.assertEqual(e.errno, errno.EPERM) finally: posix.chflags(target_file, st.st_flags) @@ -710,41 +712,39 @@ class PosixTester(unittest.TestCase): @unittest.skipUnless(hasattr(posix, 'getcwd'), 'test needs posix.getcwd()') def test_getcwd_long_pathnames(self): - if hasattr(posix, 'getcwd'): - dirname = 'getcwd-test-directory-0123456789abcdef-01234567890abcdef' - curdir = os.getcwd() - base_path = os.path.abspath(support.TESTFN) + '.getcwd' + dirname = 'getcwd-test-directory-0123456789abcdef-01234567890abcdef' + curdir = os.getcwd() + base_path = os.path.abspath(support.TESTFN) + '.getcwd' - try: - os.mkdir(base_path) - os.chdir(base_path) - except: -# Just returning nothing instead of the SkipTest exception, -# because the test results in Error in that case. -# Is that ok? -# raise unittest.SkipTest("cannot create directory for testing") - return - - def _create_and_do_getcwd(dirname, current_path_length = 0): - try: - os.mkdir(dirname) - except: - raise unittest.SkipTest("mkdir cannot create directory sufficiently deep for getcwd test") + try: + os.mkdir(base_path) + os.chdir(base_path) + except: + # Just returning nothing instead of the SkipTest exception, because + # the test results in Error in that case. Is that ok? + # raise unittest.SkipTest("cannot create directory for testing") + return - os.chdir(dirname) - try: - os.getcwd() - if current_path_length < 1027: - _create_and_do_getcwd(dirname, current_path_length + len(dirname) + 1) - finally: - os.chdir('..') - os.rmdir(dirname) + def _create_and_do_getcwd(dirname, current_path_length = 0): + try: + os.mkdir(dirname) + except: + raise unittest.SkipTest("mkdir cannot create directory sufficiently deep for getcwd test") - _create_and_do_getcwd(dirname) + os.chdir(dirname) + try: + os.getcwd() + if current_path_length < 1027: + _create_and_do_getcwd(dirname, current_path_length + len(dirname) + 1) + finally: + os.chdir('..') + os.rmdir(dirname) - finally: - os.chdir(curdir) - support.rmtree(base_path) + _create_and_do_getcwd(dirname) + + finally: + os.chdir(curdir) + support.rmtree(base_path) @unittest.skipUnless(hasattr(posix, 'getgrouplist'), "test needs posix.getgrouplist()") @unittest.skipUnless(hasattr(pwd, 'getpwuid'), "test needs pwd.getpwuid()") @@ -1121,6 +1121,23 @@ class PosixTester(unittest.TestCase): # http://lists.freebsd.org/pipermail/freebsd-amd64/2012-January/014332.html raise unittest.SkipTest("OSError raised!") + def test_path_error2(self): + """ + Test functions that call path_error2(), providing two filenames in their exceptions. + """ + for name in ("rename", "replace", "link", "symlink"): + function = getattr(os, name, None) + + if function: + for dst in ("noodly2", support.TESTFN): + try: + function('doesnotexistfilename', dst) + except OSError as e: + self.assertIn("'doesnotexistfilename' -> '{}'".format(dst), str(e)) + break + else: + self.fail("No valid path_error2() test for os." + name) + class PosixGroupsTester(unittest.TestCase): def setUp(self): diff --git a/Lib/test/test_posixpath.py b/Lib/test/test_posixpath.py index 0e7d866485..412849cff3 100644 --- a/Lib/test/test_posixpath.py +++ b/Lib/test/test_posixpath.py @@ -186,63 +186,6 @@ class PosixPathTest(unittest.TestCase): if not f.close(): f.close() - @staticmethod - def _create_file(filename): - with open(filename, 'wb') as f: - f.write(b'foo') - - def test_samefile(self): - test_fn = support.TESTFN + "1" - self._create_file(test_fn) - self.assertTrue(posixpath.samefile(test_fn, test_fn)) - self.assertRaises(TypeError, posixpath.samefile) - - @unittest.skipIf( - sys.platform.startswith('win'), - "posixpath.samefile does not work on links in Windows") - @unittest.skipUnless(hasattr(os, "symlink"), - "Missing symlink implementation") - def test_samefile_on_links(self): - test_fn1 = support.TESTFN + "1" - test_fn2 = support.TESTFN + "2" - self._create_file(test_fn1) - - os.symlink(test_fn1, test_fn2) - self.assertTrue(posixpath.samefile(test_fn1, test_fn2)) - os.remove(test_fn2) - - self._create_file(test_fn2) - self.assertFalse(posixpath.samefile(test_fn1, test_fn2)) - - - def test_samestat(self): - test_fn = support.TESTFN + "1" - self._create_file(test_fn) - test_fns = [test_fn]*2 - stats = map(os.stat, test_fns) - self.assertTrue(posixpath.samestat(*stats)) - - @unittest.skipIf( - sys.platform.startswith('win'), - "posixpath.samestat does not work on links in Windows") - @unittest.skipUnless(hasattr(os, "symlink"), - "Missing symlink implementation") - def test_samestat_on_links(self): - test_fn1 = support.TESTFN + "1" - test_fn2 = support.TESTFN + "2" - self._create_file(test_fn1) - test_fns = (test_fn1, test_fn2) - os.symlink(*test_fns) - stats = map(os.stat, test_fns) - self.assertTrue(posixpath.samestat(*stats)) - os.remove(test_fn2) - - self._create_file(test_fn2) - stats = map(os.stat, test_fns) - self.assertFalse(posixpath.samestat(*stats)) - - self.assertRaises(TypeError, posixpath.samestat) - def test_ismount(self): self.assertIs(posixpath.ismount("/"), True) with warnings.catch_warnings(): @@ -595,11 +538,6 @@ class PosixPathTest(unittest.TestCase): finally: os.getcwdb = real_getcwdb - def test_sameopenfile(self): - fname = support.TESTFN + "1" - with open(fname, "wb") as a, open(fname, "wb") as b: - self.assertTrue(posixpath.sameopenfile(a.fileno(), b.fileno())) - class PosixCommonTest(test_genericpath.CommonTest, unittest.TestCase): pathmodule = posixpath diff --git a/Lib/test/test_pprint.py b/Lib/test/test_pprint.py index a85298e26f..3d364c4595 100644 --- a/Lib/test/test_pprint.py +++ b/Lib/test/test_pprint.py @@ -1,3 +1,5 @@ +# -*- coding: utf-8 -*- + import pprint import test.support import unittest @@ -530,6 +532,54 @@ frozenset2({0, self.assertEqual(pprint.pformat(dict.fromkeys(keys, 0)), '{%r: 0, %r: 0}' % tuple(sorted(keys, key=id))) + def test_str_wrap(self): + # pprint tries to wrap strings intelligently + fox = 'the quick brown fox jumped over a lazy dog' + self.assertEqual(pprint.pformat(fox, width=20), """\ +'the quick brown ' +'fox jumped over ' +'a lazy dog'""") + self.assertEqual(pprint.pformat({'a': 1, 'b': fox, 'c': 2}, + width=26), """\ +{'a': 1, + 'b': 'the quick brown ' + 'fox jumped over ' + 'a lazy dog', + 'c': 2}""") + # With some special characters + # - \n always triggers a new line in the pprint + # - \t and \n are escaped + # - non-ASCII is allowed + # - an apostrophe doesn't disrupt the pprint + special = "Portons dix bons \"whiskys\"\nà l'avocat goujat\t qui fumait au zoo" + self.assertEqual(pprint.pformat(special, width=20), """\ +'Portons dix bons ' +'"whiskys"\\n' +"à l'avocat " +'goujat\\t qui ' +'fumait au zoo'""") + # An unwrappable string is formatted as its repr + unwrappable = "x" * 100 + self.assertEqual(pprint.pformat(unwrappable, width=80), repr(unwrappable)) + self.assertEqual(pprint.pformat(''), "''") + # Check that the pprint is a usable repr + special *= 10 + for width in range(3, 40): + formatted = pprint.pformat(special, width=width) + self.assertEqual(eval("(" + formatted + ")"), special) + + def test_compact(self): + o = ([list(range(i * i)) for i in range(5)] + + [list(range(i)) for i in range(6)]) + expected = """\ +[[], [0], [0, 1, 2, 3], + [0, 1, 2, 3, 4, 5, 6, 7, 8], + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15], + [], [0], [0, 1], [0, 1, 2], [0, 1, 2, 3], + [0, 1, 2, 3, 4]]""" + self.assertEqual(pprint.pformat(o, width=48, compact=True), expected) + class DottedPrettyPrinter(pprint.PrettyPrinter): diff --git a/Lib/test/test_print.py b/Lib/test/test_print.py index 9d6dbea46b..7eea349115 100644 --- a/Lib/test/test_print.py +++ b/Lib/test/test_print.py @@ -1,62 +1,55 @@ -"""Test correct operation of the print function. -""" - -# In 2.6, this gives us the behavior we want. In 3.0, it has -# no function, but it still must parse correctly. -from __future__ import print_function - import unittest -from test import support +from io import StringIO -try: - # 3.x - from io import StringIO -except ImportError: - # 2.x - from StringIO import StringIO +from test import support NotDefined = object() # A dispatch table all 8 combinations of providing -# sep, end, and file +# sep, end, and file. # I use this machinery so that I'm not just passing default -# values to print, I'm either passing or not passing in the -# arguments +# values to print, I'm either passing or not passing in the +# arguments. dispatch = { (False, False, False): - lambda args, sep, end, file: print(*args), + lambda args, sep, end, file: print(*args), (False, False, True): - lambda args, sep, end, file: print(file=file, *args), + lambda args, sep, end, file: print(file=file, *args), (False, True, False): - lambda args, sep, end, file: print(end=end, *args), + lambda args, sep, end, file: print(end=end, *args), (False, True, True): - lambda args, sep, end, file: print(end=end, file=file, *args), + lambda args, sep, end, file: print(end=end, file=file, *args), (True, False, False): - lambda args, sep, end, file: print(sep=sep, *args), + lambda args, sep, end, file: print(sep=sep, *args), (True, False, True): - lambda args, sep, end, file: print(sep=sep, file=file, *args), + lambda args, sep, end, file: print(sep=sep, file=file, *args), (True, True, False): - lambda args, sep, end, file: print(sep=sep, end=end, *args), + lambda args, sep, end, file: print(sep=sep, end=end, *args), (True, True, True): - lambda args, sep, end, file: print(sep=sep, end=end, file=file, *args), - } + lambda args, sep, end, file: print(sep=sep, end=end, file=file, *args), +} + # Class used to test __str__ and print class ClassWith__str__: def __init__(self, x): self.x = x + def __str__(self): return self.x + class TestPrint(unittest.TestCase): + """Test correct operation of the print function.""" + def check(self, expected, args, - sep=NotDefined, end=NotDefined, file=NotDefined): + sep=NotDefined, end=NotDefined, file=NotDefined): # Capture sys.stdout in a StringIO. Call print with args, - # and with sep, end, and file, if they're defined. Result - # must match expected. + # and with sep, end, and file, if they're defined. Result + # must match expected. - # Look up the actual function to call, based on if sep, end, and file - # are defined + # Look up the actual function to call, based on if sep, end, + # and file are defined. fn = dispatch[(sep is not NotDefined, end is not NotDefined, file is not NotDefined)] @@ -69,7 +62,7 @@ class TestPrint(unittest.TestCase): def test_print(self): def x(expected, args, sep=NotDefined, end=NotDefined): # Run the test 2 ways: not using file, and using - # file directed to a StringIO + # file directed to a StringIO. self.check(expected, args, sep=sep, end=end) @@ -101,11 +94,6 @@ class TestPrint(unittest.TestCase): x('*\n', (ClassWith__str__('*'),)) x('abc 1\n', (ClassWith__str__('abc'), 1)) -# # 2.x unicode tests -# x(u'1 2\n', ('1', u'2')) -# x(u'u\1234\n', (u'u\1234',)) -# x(u' abc 1\n', (' ', ClassWith__str__(u'abc'), 1)) - # errors self.assertRaises(TypeError, print, '', sep=3) self.assertRaises(TypeError, print, '', end=3) @@ -113,12 +101,14 @@ class TestPrint(unittest.TestCase): def test_print_flush(self): # operation of the flush flag - class filelike(): + class filelike: def __init__(self): self.written = '' self.flushed = 0 + def write(self, str): self.written += str + def flush(self): self.flushed += 1 @@ -130,15 +120,13 @@ class TestPrint(unittest.TestCase): self.assertEqual(f.flushed, 2) # ensure exceptions from flush are passed through - class noflush(): + class noflush: def write(self, str): pass + def flush(self): raise RuntimeError self.assertRaises(RuntimeError, print, 1, file=noflush(), flush=True) -def test_main(): - support.run_unittest(TestPrint) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_profile.py b/Lib/test/test_profile.py index cd7ec58e23..1fc3c42669 100644 --- a/Lib/test/test_profile.py +++ b/Lib/test/test_profile.py @@ -3,9 +3,11 @@ import sys import pstats import unittest +import os from difflib import unified_diff from io import StringIO -from test.support import run_unittest +from test.support import TESTFN, run_unittest, unlink +from contextlib import contextmanager import profile from test.profilee import testfunc, timer @@ -14,9 +16,13 @@ from test.profilee import testfunc, timer class ProfileTest(unittest.TestCase): profilerclass = profile.Profile + profilermodule = profile methodnames = ['print_stats', 'print_callers', 'print_callees'] expected_max_output = ':0(max)' + def tearDown(self): + unlink(TESTFN) + def get_expected_output(self): return _ProfileOutput @@ -74,6 +80,19 @@ class ProfileTest(unittest.TestCase): self.assertIn(self.expected_max_output, res, "Profiling {0!r} didn't report max:\n{1}".format(stmt, res)) + def test_run(self): + with silent(): + self.profilermodule.run("int('1')") + self.profilermodule.run("int('1')", filename=TESTFN) + self.assertTrue(os.path.exists(TESTFN)) + + def test_runctx(self): + with silent(): + self.profilermodule.runctx("testfunc()", globals(), locals()) + self.profilermodule.runctx("testfunc()", globals(), locals(), + filename=TESTFN) + self.assertTrue(os.path.exists(TESTFN)) + def regenerate_expected_output(filename, cls): filename = filename.rstrip('co') @@ -95,6 +114,14 @@ def regenerate_expected_output(filename, cls): method, results[i+1])) f.write('\nif __name__ == "__main__":\n main()\n') +@contextmanager +def silent(): + stdout = sys.stdout + try: + sys.stdout = StringIO() + yield + finally: + sys.stdout = stdout def test_main(): run_unittest(ProfileTest) diff --git a/Lib/test/test_pty.py b/Lib/test/test_pty.py index 29297f8841..8916861f5b 100644 --- a/Lib/test/test_pty.py +++ b/Lib/test/test_pty.py @@ -187,7 +187,7 @@ class PtyTest(unittest.TestCase): ##debug("Reading from master_fd now that the child has exited") ##try: ## s1 = os.read(master_fd, 1024) - ##except os.error: + ##except OSError: ## pass ##else: ## raise TestFailed("Read from master_fd did not raise exception") diff --git a/Lib/test/test_py_compile.py b/Lib/test/test_py_compile.py index f3c1a6a44b..154c08a6a2 100644 --- a/Lib/test/test_py_compile.py +++ b/Lib/test/test_py_compile.py @@ -1,7 +1,9 @@ -import imp +import importlib.util import os import py_compile import shutil +import stat +import sys import tempfile import unittest @@ -13,7 +15,7 @@ class PyCompileTests(unittest.TestCase): self.directory = tempfile.mkdtemp() self.source_path = os.path.join(self.directory, '_test.py') self.pyc_path = self.source_path + 'c' - self.cache_path = imp.cache_from_source(self.source_path) + self.cache_path = importlib.util.cache_from_source(self.source_path) self.cwd_drive = os.path.splitdrive(os.getcwd())[0] # In these tests we compute relative paths. When using Windows, the # current working directory path and the 'self.source_path' might be @@ -35,6 +37,26 @@ class PyCompileTests(unittest.TestCase): self.assertTrue(os.path.exists(self.pyc_path)) self.assertFalse(os.path.exists(self.cache_path)) + def test_do_not_overwrite_symlinks(self): + # In the face of a cfile argument being a symlink, bail out. + # Issue #17222 + try: + os.symlink(self.pyc_path + '.actual', self.pyc_path) + except (NotImplementedError, OSError): + self.skipTest('need to be able to create a symlink for a file') + else: + assert os.path.islink(self.pyc_path) + with self.assertRaises(FileExistsError): + py_compile.compile(self.source_path, self.pyc_path) + + @unittest.skipIf(not os.path.exists(os.devnull) or os.path.isfile(os.devnull), + 'requires os.devnull and for it to be a non-regular file') + def test_do_not_overwrite_nonregular_files(self): + # In the face of a cfile argument being a non-regular file, bail out. + # Issue #17222 + with self.assertRaises(FileExistsError): + py_compile.compile(self.source_path, os.devnull) + def test_cache_path(self): py_compile.compile(self.source_path) self.assertTrue(os.path.exists(self.cache_path)) @@ -54,8 +76,22 @@ class PyCompileTests(unittest.TestCase): self.assertTrue(os.path.exists(self.pyc_path)) self.assertFalse(os.path.exists(self.cache_path)) -def test_main(): - support.run_unittest(PyCompileTests) + @unittest.skipIf(hasattr(os, 'geteuid') and os.geteuid() == 0, + 'non-root user required') + @unittest.skipIf(os.name == 'nt', + 'cannot control directory permissions on Windows') + def test_exceptions_propagate(self): + # Make sure that exceptions raised thanks to issues with writing + # bytecode. + # http://bugs.python.org/issue17244 + mode = os.stat(self.directory) + os.chmod(self.directory, stat.S_IREAD) + try: + with self.assertRaises(IOError): + py_compile.compile(self.source_path, self.pyc_path) + finally: + os.chmod(self.directory, mode.st_mode) + if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_pyclbr.py b/Lib/test/test_pyclbr.py index e83989e2d8..88aff898d4 100644 --- a/Lib/test/test_pyclbr.py +++ b/Lib/test/test_pyclbr.py @@ -142,7 +142,7 @@ class PyclbrTest(TestCase): self.checkModule('pyclbr') self.checkModule('ast') self.checkModule('doctest', ignore=("TestResults", "_SpoofOut", - "DocTestCase")) + "DocTestCase", '_DocTestSuite')) self.checkModule('difflib', ignore=("Match",)) def test_decorators(self): @@ -158,7 +158,7 @@ class PyclbrTest(TestCase): cm('random', ignore=('Random',)) # from _random import Random as CoreGenerator cm('cgi', ignore=('log',)) # set with = in module cm('pickle') - cm('aifc', ignore=('openfp',)) # set with = in module + cm('aifc', ignore=('openfp', '_aifc_params')) # set with = in module cm('sre_parse', ignore=('dump',)) # from sre_constants import * cm('pdb') cm('pydoc') diff --git a/Lib/test/test_pydoc.py b/Lib/test/test_pydoc.py index 43f4163dae..f26fb15f53 100644 --- a/Lib/test/test_pydoc.py +++ b/Lib/test/test_pydoc.py @@ -6,11 +6,13 @@ import difflib import inspect import pydoc import keyword +import _pickle import pkgutil import re import string import test.support import time +import types import unittest import xml.etree import textwrap @@ -20,7 +22,7 @@ from test.script_helper import assert_python_ok from test.support import ( TESTFN, rmtree, reap_children, reap_threads, captured_output, captured_stdout, - captured_stderr, unlink + captured_stderr, unlink, requires_docstrings ) from test import pydoc_mod @@ -29,10 +31,6 @@ try: except ImportError: threading = None -# Just in case sys.modules["test"] has the optional attribute __loader__. -if hasattr(pydoc_mod, "__loader__"): - del pydoc_mod.__loader__ - if test.support.HAVE_DOCSTRINGS: expected_data_docstrings = ( 'dictionary for instance variables (if defined)', @@ -212,6 +210,75 @@ missing_pattern = "no Python documentation found for '%s'" # output pattern for module with bad imports badimport_pattern = "problem in %s - ImportError: No module named %r" +expected_dynamicattribute_pattern = """ +Help on class DA in module %s: + +class DA(builtins.object) + | Data descriptors defined here: + |\x20\x20 + | __dict__%s + |\x20\x20 + | __weakref__%s + |\x20\x20 + | ham + |\x20\x20 + | ---------------------------------------------------------------------- + | Data and other attributes inherited from Meta: + |\x20\x20 + | ham = 'spam' +""".strip() + +expected_virtualattribute_pattern1 = """ +Help on class Class in module %s: + +class Class(builtins.object) + | Data and other attributes inherited from Meta: + |\x20\x20 + | LIFE = 42 +""".strip() + +expected_virtualattribute_pattern2 = """ +Help on class Class1 in module %s: + +class Class1(builtins.object) + | Data and other attributes inherited from Meta1: + |\x20\x20 + | one = 1 +""".strip() + +expected_virtualattribute_pattern3 = """ +Help on class Class2 in module %s: + +class Class2(Class1) + | Method resolution order: + | Class2 + | Class1 + | builtins.object + |\x20\x20 + | Data and other attributes inherited from Meta1: + |\x20\x20 + | one = 1 + |\x20\x20 + | ---------------------------------------------------------------------- + | Data and other attributes inherited from Meta3: + |\x20\x20 + | three = 3 + |\x20\x20 + | ---------------------------------------------------------------------- + | Data and other attributes inherited from Meta2: + |\x20\x20 + | two = 2 +""".strip() + +expected_missingattribute_pattern = """ +Help on class C in module %s: + +class C(builtins.object) + | Data and other attributes defined here: + |\x20\x20 + | here = 'present!' +""".strip() + def run_pydoc(module_name, *args, **env): """ Runs pydoc on the specified module. Returns the stripped @@ -320,6 +387,16 @@ class PydocDocTest(unittest.TestCase): print_diffs(expected_text, result) self.fail("outputs are not equal, see diff above") + def test_text_enum_member_with_value_zero(self): + # Test issue #20654 to ensure enum member with value 0 can be + # displayed. It used to throw KeyError: 'zero'. + import enum + class BinaryInteger(enum.IntEnum): + zero = 0 + one = 1 + doc = pydoc.render_doc(BinaryInteger) + self.assertIn('<BinaryInteger.zero: 0>', doc) + def test_issue8225(self): # Test issue8225 to ensure no doc link appears for xml.etree result, doc_loc = get_pydoc_text(xml.etree) @@ -421,6 +498,38 @@ class PydocDocTest(unittest.TestCase): synopsis = pydoc.synopsis(TESTFN, {}) self.assertEqual(synopsis, 'line 1: h\xe9') + def test_synopsis_sourceless(self): + expected = os.__doc__.splitlines()[0] + filename = os.__cached__ + synopsis = pydoc.synopsis(filename) + + self.assertEqual(synopsis, expected) + + def test_splitdoc_with_description(self): + example_string = "I Am A Doc\n\n\nHere is my description" + self.assertEqual(pydoc.splitdoc(example_string), + ('I Am A Doc', '\nHere is my description')) + + def test_is_object_or_method(self): + doc = pydoc.Doc() + # Bound Method + self.assertTrue(pydoc._is_some_method(doc.fail)) + # Method Descriptor + self.assertTrue(pydoc._is_some_method(int.__add__)) + # String + self.assertFalse(pydoc._is_some_method("I am not a method")) + + def test_is_package_when_not_package(self): + with test.support.temp_cwd() as test_dir: + self.assertFalse(pydoc.ispackage(test_dir)) + + def test_is_package_when_is_package(self): + with test.support.temp_cwd() as test_dir: + init_path = os.path.join(test_dir, '__init__.py') + open(init_path, 'w').close() + self.assertTrue(pydoc.ispackage(test_dir)) + os.remove(init_path) + def test_allmethods(self): # issue 17476: allmethods was no longer returning unbound methods. # This test is a bit fragile in the face of changes to object and type, @@ -509,6 +618,55 @@ class PydocImportTest(PydocBaseTest): self.assertEqual(out.getvalue(), '') self.assertEqual(err.getvalue(), '') + @unittest.skip('causes undesireable side-effects (#20128)') + def test_modules(self): + # See Helper.listmodules(). + num_header_lines = 2 + num_module_lines_min = 5 # Playing it safe. + num_footer_lines = 3 + expected = num_header_lines + num_module_lines_min + num_footer_lines + + output = StringIO() + helper = pydoc.Helper(output=output) + helper('modules') + result = output.getvalue().strip() + num_lines = len(result.splitlines()) + + self.assertGreaterEqual(num_lines, expected) + + @unittest.skip('causes undesireable side-effects (#20128)') + def test_modules_search(self): + # See Helper.listmodules(). + expected = 'pydoc - ' + + output = StringIO() + helper = pydoc.Helper(output=output) + with captured_stdout() as help_io: + helper('modules pydoc') + result = help_io.getvalue() + + self.assertIn(expected, result) + + @unittest.skip('some buildbots are not cooperating (#20128)') + def test_modules_search_builtin(self): + expected = 'gc - ' + + output = StringIO() + helper = pydoc.Helper(output=output) + with captured_stdout() as help_io: + helper('modules garbage') + result = help_io.getvalue() + + self.assertTrue(result.startswith(expected)) + + def test_importfile(self): + loaded_pydoc = pydoc.importfile(pydoc.__file__) + + self.assertIsNot(loaded_pydoc, pydoc) + self.assertEqual(loaded_pydoc.__name__, 'pydoc') + self.assertEqual(loaded_pydoc.__file__, pydoc.__file__) + self.assertEqual(loaded_pydoc.__spec__, pydoc.__spec__) + class TestDescriptions(unittest.TestCase): @@ -544,6 +702,42 @@ class TestDescriptions(unittest.TestCase): self.assertIsNone(pydoc.locate(name)) self.assertRaises(ImportError, pydoc.render_doc, name) + @staticmethod + def _get_summary_line(o): + text = pydoc.plain(pydoc.render_doc(o)) + lines = text.split('\n') + assert len(lines) >= 2 + return lines[2] + + # these should include "self" + def test_unbound_python_method(self): + self.assertEqual(self._get_summary_line(textwrap.TextWrapper.wrap), + "wrap(self, text)") + + @requires_docstrings + def test_unbound_builtin_method(self): + self.assertEqual(self._get_summary_line(_pickle.Pickler.dump), + "dump(self, obj, /)") + + # these no longer include "self" + def test_bound_python_method(self): + t = textwrap.TextWrapper() + self.assertEqual(self._get_summary_line(t.wrap), + "wrap(text) method of textwrap.TextWrapper instance") + + @requires_docstrings + def test_bound_builtin_method(self): + s = StringIO() + p = _pickle.Pickler(s) + self.assertEqual(self._get_summary_line(p.dump), + "dump(obj, /) method of _pickle.Pickler instance") + + # this should *never* include self! + @requires_docstrings + def test_module_level_callable(self): + self.assertEqual(self._get_summary_line(os.stat), + "stat(path, *, dir_fd=None, follow_symlinks=True)") + @unittest.skipUnless(threading, 'Threading required for this test.') class PydocServerTest(unittest.TestCase): @@ -615,6 +809,128 @@ class TestHelper(unittest.TestCase): self.assertEqual(sorted(pydoc.Helper.keywords), sorted(keyword.kwlist)) +class PydocWithMetaClasses(unittest.TestCase): + @unittest.skipIf(sys.flags.optimize >= 2, + "Docstrings are omitted with -O2 and above") + @unittest.skipIf(hasattr(sys, 'gettrace') and sys.gettrace(), + 'trace function introduces __locals__ unexpectedly') + def test_DynamicClassAttribute(self): + class Meta(type): + def __getattr__(self, name): + if name == 'ham': + return 'spam' + return super().__getattr__(name) + class DA(metaclass=Meta): + @types.DynamicClassAttribute + def ham(self): + return 'eggs' + expected_text_data_docstrings = tuple('\n | ' + s if s else '' + for s in expected_data_docstrings) + output = StringIO() + helper = pydoc.Helper(output=output) + helper(DA) + expected_text = expected_dynamicattribute_pattern % ( + (__name__,) + expected_text_data_docstrings[:2]) + result = output.getvalue().strip() + if result != expected_text: + print_diffs(expected_text, result) + self.fail("outputs are not equal, see diff above") + + @unittest.skipIf(sys.flags.optimize >= 2, + "Docstrings are omitted with -O2 and above") + @unittest.skipIf(hasattr(sys, 'gettrace') and sys.gettrace(), + 'trace function introduces __locals__ unexpectedly') + def test_virtualClassAttributeWithOneMeta(self): + class Meta(type): + def __dir__(cls): + return ['__class__', '__module__', '__name__', 'LIFE'] + def __getattr__(self, name): + if name =='LIFE': + return 42 + return super().__getattr(name) + class Class(metaclass=Meta): + pass + output = StringIO() + helper = pydoc.Helper(output=output) + helper(Class) + expected_text = expected_virtualattribute_pattern1 % __name__ + result = output.getvalue().strip() + if result != expected_text: + print_diffs(expected_text, result) + self.fail("outputs are not equal, see diff above") + + @unittest.skipIf(sys.flags.optimize >= 2, + "Docstrings are omitted with -O2 and above") + @unittest.skipIf(hasattr(sys, 'gettrace') and sys.gettrace(), + 'trace function introduces __locals__ unexpectedly') + def test_virtualClassAttributeWithTwoMeta(self): + class Meta1(type): + def __dir__(cls): + return ['__class__', '__module__', '__name__', 'one'] + def __getattr__(self, name): + if name =='one': + return 1 + return super().__getattr__(name) + class Meta2(type): + def __dir__(cls): + return ['__class__', '__module__', '__name__', 'two'] + def __getattr__(self, name): + if name =='two': + return 2 + return super().__getattr__(name) + class Meta3(Meta1, Meta2): + def __dir__(cls): + return list(sorted(set( + ['__class__', '__module__', '__name__', 'three'] + + Meta1.__dir__(cls) + Meta2.__dir__(cls)))) + def __getattr__(self, name): + if name =='three': + return 3 + return super().__getattr__(name) + class Class1(metaclass=Meta1): + pass + class Class2(Class1, metaclass=Meta3): + pass + fail1 = fail2 = False + output = StringIO() + helper = pydoc.Helper(output=output) + helper(Class1) + expected_text1 = expected_virtualattribute_pattern2 % __name__ + result1 = output.getvalue().strip() + if result1 != expected_text1: + print_diffs(expected_text1, result1) + fail1 = True + output = StringIO() + helper = pydoc.Helper(output=output) + helper(Class2) + expected_text2 = expected_virtualattribute_pattern3 % __name__ + result2 = output.getvalue().strip() + if result2 != expected_text2: + print_diffs(expected_text2, result2) + fail2 = True + if fail1 or fail2: + self.fail("outputs are not equal, see diff above") + + @unittest.skipIf(sys.flags.optimize >= 2, + "Docstrings are omitted with -O2 and above") + @unittest.skipIf(hasattr(sys, 'gettrace') and sys.gettrace(), + 'trace function introduces __locals__ unexpectedly') + def test_buggy_dir(self): + class M(type): + def __dir__(cls): + return ['__class__', '__name__', 'missing', 'here'] + class C(metaclass=M): + here = 'present!' + output = StringIO() + helper = pydoc.Helper(output=output) + helper(C) + expected_text = expected_missingattribute_pattern % __name__ + result = output.getvalue().strip() + if result != expected_text: + print_diffs(expected_text, result) + self.fail("outputs are not equal, see diff above") + + @reap_threads def test_main(): try: @@ -624,6 +940,7 @@ def test_main(): PydocServerTest, PydocUrlHandlerTest, TestHelper, + PydocWithMetaClasses, ) finally: reap_children() diff --git a/Lib/test/test_random.py b/Lib/test/test_random.py index a475207fbd..37a222d622 100644 --- a/Lib/test/test_random.py +++ b/Lib/test/test_random.py @@ -1,8 +1,10 @@ import unittest +import unittest.mock import random import time import pickle import warnings +from functools import partial from math import log, exp, pi, fsum, sin from test import support @@ -44,6 +46,48 @@ class TestBasicOps: self.assertRaises(TypeError, self.gen.seed, 1, 2, 3, 4) self.assertRaises(TypeError, type(self.gen), []) + @unittest.mock.patch('random._urandom') # os.urandom + def test_seed_when_randomness_source_not_found(self, urandom_mock): + # Random.seed() uses time.time() when an operating system specific + # randomness source is not found. To test this on machines were it + # exists, run the above test, test_seedargs(), again after mocking + # os.urandom() so that it raises the exception expected when the + # randomness source is not available. + urandom_mock.side_effect = NotImplementedError + self.test_seedargs() + + def test_shuffle(self): + shuffle = self.gen.shuffle + lst = [] + shuffle(lst) + self.assertEqual(lst, []) + lst = [37] + shuffle(lst) + self.assertEqual(lst, [37]) + seqs = [list(range(n)) for n in range(10)] + shuffled_seqs = [list(range(n)) for n in range(10)] + for shuffled_seq in shuffled_seqs: + shuffle(shuffled_seq) + for (seq, shuffled_seq) in zip(seqs, shuffled_seqs): + self.assertEqual(len(seq), len(shuffled_seq)) + self.assertEqual(set(seq), set(shuffled_seq)) + # The above tests all would pass if the shuffle was a + # no-op. The following non-deterministic test covers that. It + # asserts that the shuffled sequence of 1000 distinct elements + # must be different from the original one. Although there is + # mathematically a non-zero probability that this could + # actually happen in a genuinely random shuffle, it is + # completely negligible, given that the number of possible + # permutations of 1000 objects is 1000! (factorial of 1000), + # which is considerably larger than the number of atoms in the + # universe... + lst = list(range(1000)) + shuffled_lst = list(range(1000)) + shuffle(shuffled_lst) + self.assertTrue(lst != shuffled_lst) + shuffle(lst) + self.assertTrue(lst != shuffled_lst) + def test_choice(self): choice = self.gen.choice with self.assertRaises(IndexError): @@ -63,6 +107,8 @@ class TestBasicOps: self.assertEqual(len(uniq), k) self.assertTrue(uniq <= set(population)) self.assertEqual(self.gen.sample([], 0), []) # test edge case N==k==0 + # Exception raised if size of sample exceeds that of population + self.assertRaises(ValueError, self.gen.sample, population, N+1) def test_sample_distribution(self): # For the entire allowable range of 0 <= k <= N, validate that @@ -203,6 +249,25 @@ class SystemRandom_TestBasicOps(TestBasicOps, unittest.TestCase): self.assertEqual(set(range(start,stop)), set([self.gen.randrange(start,stop) for i in range(100)])) + def test_randrange_nonunit_step(self): + rint = self.gen.randrange(0, 10, 2) + self.assertIn(rint, (0, 2, 4, 6, 8)) + rint = self.gen.randrange(0, 2, 2) + self.assertEqual(rint, 0) + + def test_randrange_errors(self): + raises = partial(self.assertRaises, ValueError, self.gen.randrange) + # Empty range + raises(3, 3) + raises(-721) + raises(0, 100, -12) + # Non-integer start/stop + raises(3.14159) + raises(0, 2.71828) + # Zero and non-integer step + raises(0, 42, 0) + raises(0, 42, 3.14159) + def test_genrandbits(self): # Verify ranges for k in range(1, 1000): @@ -272,6 +337,16 @@ class MersenneTwister_TestBasicOps(TestBasicOps, unittest.TestCase): # Last element s/b an int also self.assertRaises(TypeError, self.gen.setstate, (2, (0,)*624+('a',), None)) + # Little trick to make "tuple(x % (2**32) for x in internalstate)" + # raise ValueError. I cannot think of a simple way to achieve this, so + # I am opting for using a generator as the middle argument of setstate + # which attempts to cast a NaN to integer. + state_values = self.gen.getstate()[1] + state_values = list(state_values) + state_values[-1] = float('nan') + state = (int(x) for x in state_values) + self.assertRaises(TypeError, self.gen.setstate, (2, state, None)) + def test_referenceImplementation(self): # Compare the python implementation with results from the original # code. Create 2000 53-bit precision random floats. Compare only @@ -411,6 +486,38 @@ class MersenneTwister_TestBasicOps(TestBasicOps, unittest.TestCase): self.assertEqual(k, numbits) # note the stronger assertion self.assertTrue(2**k > n > 2**(k-1)) # note the stronger assertion + @unittest.mock.patch('random.Random.random') + def test_randbelow_overriden_random(self, random_mock): + # Random._randbelow() can only use random() when the built-in one + # has been overridden but no new getrandbits() method was supplied. + random_mock.side_effect = random.SystemRandom().random + maxsize = 1<<random.BPF + with warnings.catch_warnings(): + warnings.simplefilter("ignore", UserWarning) + # Population range too large (n >= maxsize) + self.gen._randbelow(maxsize+1, maxsize = maxsize) + self.gen._randbelow(5640, maxsize = maxsize) + + # This might be going too far to test a single line, but because of our + # noble aim of achieving 100% test coverage we need to write a case in + # which the following line in Random._randbelow() gets executed: + # + # rem = maxsize % n + # limit = (maxsize - rem) / maxsize + # r = random() + # while r >= limit: + # r = random() # <== *This line* <==< + # + # Therefore, to guarantee that the while loop is executed at least + # once, we need to mock random() so that it returns a number greater + # than 'limit' the first time it gets called. + + n = 42 + epsilon = 0.01 + limit = (maxsize - (maxsize % n)) / maxsize + random_mock.side_effect = [limit + epsilon, limit - epsilon] + self.gen._randbelow(n, maxsize = maxsize) + def test_randrange_bug_1590891(self): start = 1000000000000 stop = -100000000000000000000 @@ -528,6 +635,106 @@ class TestDistributions(unittest.TestCase): random.vonmisesvariate(0, 1e15) random.vonmisesvariate(0, 1e100) + def test_gammavariate_errors(self): + # Both alpha and beta must be > 0.0 + self.assertRaises(ValueError, random.gammavariate, -1, 3) + self.assertRaises(ValueError, random.gammavariate, 0, 2) + self.assertRaises(ValueError, random.gammavariate, 2, 0) + self.assertRaises(ValueError, random.gammavariate, 1, -3) + + @unittest.mock.patch('random.Random.random') + def test_gammavariate_full_code_coverage(self, random_mock): + # There are three different possibilities in the current implementation + # of random.gammavariate(), depending on the value of 'alpha'. What we + # are going to do here is to fix the values returned by random() to + # generate test cases that provide 100% line coverage of the method. + + # #1: alpha > 1.0: we want the first random number to be outside the + # [1e-7, .9999999] range, so that the continue statement executes + # once. The values of u1 and u2 will be 0.5 and 0.3, respectively. + random_mock.side_effect = [1e-8, 0.5, 0.3] + returned_value = random.gammavariate(1.1, 2.3) + self.assertAlmostEqual(returned_value, 2.53) + + # #2: alpha == 1: first random number less than 1e-7 to that the body + # of the while loop executes once. Then random.random() returns 0.45, + # which causes while to stop looping and the algorithm to terminate. + random_mock.side_effect = [1e-8, 0.45] + returned_value = random.gammavariate(1.0, 3.14) + self.assertAlmostEqual(returned_value, 2.507314166123803) + + # #3: 0 < alpha < 1. This is the most complex region of code to cover, + # as there are multiple if-else statements. Let's take a look at the + # source code, and determine the values that we need accordingly: + # + # while 1: + # u = random() + # b = (_e + alpha)/_e + # p = b*u + # if p <= 1.0: # <=== (A) + # x = p ** (1.0/alpha) + # else: # <=== (B) + # x = -_log((b-p)/alpha) + # u1 = random() + # if p > 1.0: # <=== (C) + # if u1 <= x ** (alpha - 1.0): # <=== (D) + # break + # elif u1 <= _exp(-x): # <=== (E) + # break + # return x * beta + # + # First, we want (A) to be True. For that we need that: + # b*random() <= 1.0 + # r1 = random() <= 1.0 / b + # + # We now get to the second if-else branch, and here, since p <= 1.0, + # (C) is False and we take the elif branch, (E). For it to be True, + # so that the break is executed, we need that: + # r2 = random() <= _exp(-x) + # r2 <= _exp(-(p ** (1.0/alpha))) + # r2 <= _exp(-((b*r1) ** (1.0/alpha))) + + _e = random._e + _exp = random._exp + _log = random._log + alpha = 0.35 + beta = 1.45 + b = (_e + alpha)/_e + epsilon = 0.01 + + r1 = 0.8859296441566 # 1.0 / b + r2 = 0.3678794411714 # _exp(-((b*r1) ** (1.0/alpha))) + + # These four "random" values result in the following trace: + # (A) True, (E) False --> [next iteration of while] + # (A) True, (E) True --> [while loop breaks] + random_mock.side_effect = [r1, r2 + epsilon, r1, r2] + returned_value = random.gammavariate(alpha, beta) + self.assertAlmostEqual(returned_value, 1.4499999999997544) + + # Let's now make (A) be False. If this is the case, when we get to the + # second if-else 'p' is greater than 1, so (C) evaluates to True. We + # now encounter a second if statement, (D), which in order to execute + # must satisfy the following condition: + # r2 <= x ** (alpha - 1.0) + # r2 <= (-_log((b-p)/alpha)) ** (alpha - 1.0) + # r2 <= (-_log((b-(b*r1))/alpha)) ** (alpha - 1.0) + r1 = 0.8959296441566 # (1.0 / b) + epsilon -- so that (A) is False + r2 = 0.9445400408898141 + + # And these four values result in the following trace: + # (B) and (C) True, (D) False --> [next iteration of while] + # (B) and (C) True, (D) True [while loop breaks] + random_mock.side_effect = [r1, r2 + epsilon, r1, r2] + returned_value = random.gammavariate(alpha, beta) + self.assertAlmostEqual(returned_value, 1.5830349561760781) + + @unittest.mock.patch('random.Random.gammavariate') + def test_betavariate_return_zero(self, gammavariate_mock): + # betavariate() returns zero when the Gamma distribution + # that it uses internally returns this same value. + gammavariate_mock.return_value = 0.0 + self.assertEqual(0.0, random.betavariate(2.71828, 3.14159)) class TestModule(unittest.TestCase): def testMagicConstants(self): diff --git a/Lib/test/test_range.py b/Lib/test/test_range.py index 2a13bfeabd..d9d4cf88a2 100644 --- a/Lib/test/test_range.py +++ b/Lib/test/test_range.py @@ -313,7 +313,7 @@ class RangeTest(unittest.TestCase): self.assertRaises(TypeError, range, IN()) # Test use of user-defined classes in slice indices. - self.assertEqual(list(range(10)[:I(5)]), list(range(5))) + self.assertEqual(range(10)[:I(5)], range(5)) with self.assertRaises(RuntimeError): range(0, 10)[:IX()] @@ -353,9 +353,10 @@ class RangeTest(unittest.TestCase): (13, 21, 3), (-2, 2, 2), (2**65, 2**65+2)] for proto in range(pickle.HIGHEST_PROTOCOL + 1): for t in testcases: - r = range(*t) - self.assertEqual(list(pickle.loads(pickle.dumps(r, proto))), - list(r)) + with self.subTest(proto=proto, test=t): + r = range(*t) + self.assertEqual(list(pickle.loads(pickle.dumps(r, proto))), + list(r)) def test_iterator_pickling(self): testcases = [(13,), (0, 11), (-22, 10), (20, 3, -1), diff --git a/Lib/test/test_re.py b/Lib/test/test_re.py index 1c6f45dfa6..a229e235ca 100644 --- a/Lib/test/test_re.py +++ b/Lib/test/test_re.py @@ -3,10 +3,12 @@ from test.support import verbose, run_unittest, gc_collect, bigmemtest, _2G, \ import io import re from re import Scanner +import sre_compile import sre_constants import sys import string import traceback +import unittest from weakref import proxy # Misc tests from Tim Peters' re.doc @@ -15,10 +17,26 @@ from weakref import proxy # what you're doing. Some of these tests were carefully modeled to # cover most of the code. -import unittest +class S(str): + def __getitem__(self, index): + return S(super().__getitem__(index)) + +class B(bytes): + def __getitem__(self, index): + return B(super().__getitem__(index)) class ReTests(unittest.TestCase): + def assertTypedEqual(self, actual, expect, msg=None): + self.assertEqual(actual, expect, msg) + def recurse(actual, expect): + if isinstance(expect, (tuple, list)): + for x, y in zip(actual, expect): + recurse(x, y) + else: + self.assertIs(type(actual), type(expect), msg) + recurse(actual, expect) + def test_keep_buffer(self): # See bug 14212 b = bytearray(b'x') @@ -53,6 +71,15 @@ class ReTests(unittest.TestCase): return str(int_value + 1) def test_basic_re_sub(self): + self.assertTypedEqual(re.sub('y', 'a', 'xyz'), 'xaz') + self.assertTypedEqual(re.sub('y', S('a'), S('xyz')), 'xaz') + self.assertTypedEqual(re.sub(b'y', b'a', b'xyz'), b'xaz') + self.assertTypedEqual(re.sub(b'y', B(b'a'), B(b'xyz')), b'xaz') + self.assertTypedEqual(re.sub(b'y', bytearray(b'a'), bytearray(b'xyz')), b'xaz') + self.assertTypedEqual(re.sub(b'y', memoryview(b'a'), memoryview(b'xyz')), b'xaz') + for y in ("\xe0", "\u0430", "\U0001d49c"): + self.assertEqual(re.sub(y, 'a', 'x%sz' % y), 'xaz') + self.assertEqual(re.sub("(?i)b+", "x", "bbbb BBBB"), 'x x') self.assertEqual(re.sub(r'\d+', self.bump_num, '08.2 -2 23x99y'), '9.3 -3 24x100y') @@ -210,10 +237,29 @@ class ReTests(unittest.TestCase): self.assertEqual(re.subn("b*", "x", "xyz", 2), ('xxxyz', 2)) def test_re_split(self): - self.assertEqual(re.split(":", ":a:b::c"), ['', 'a', 'b', '', 'c']) - self.assertEqual(re.split(":*", ":a:b::c"), ['', 'a', 'b', 'c']) - self.assertEqual(re.split("(:*)", ":a:b::c"), - ['', ':', 'a', ':', 'b', '::', 'c']) + for string in ":a:b::c", S(":a:b::c"): + self.assertTypedEqual(re.split(":", string), + ['', 'a', 'b', '', 'c']) + self.assertTypedEqual(re.split(":*", string), + ['', 'a', 'b', 'c']) + self.assertTypedEqual(re.split("(:*)", string), + ['', ':', 'a', ':', 'b', '::', 'c']) + for string in (b":a:b::c", B(b":a:b::c"), bytearray(b":a:b::c"), + memoryview(b":a:b::c")): + self.assertTypedEqual(re.split(b":", string), + [b'', b'a', b'b', b'', b'c']) + self.assertTypedEqual(re.split(b":*", string), + [b'', b'a', b'b', b'c']) + self.assertTypedEqual(re.split(b"(:*)", string), + [b'', b':', b'a', b':', b'b', b'::', b'c']) + for a, b, c in ("\xe0\xdf\xe7", "\u0430\u0431\u0432", + "\U0001d49c\U0001d49e\U0001d4b5"): + string = ":%s:%s::%s" % (a, b, c) + self.assertEqual(re.split(":", string), ['', a, b, '', c]) + self.assertEqual(re.split(":*", string), ['', a, b, c]) + self.assertEqual(re.split("(:*)", string), + ['', ':', a, ':', b, '::', c]) + self.assertEqual(re.split("(?::*)", ":a:b::c"), ['', 'a', 'b', 'c']) self.assertEqual(re.split("(:)*", ":a:b::c"), ['', ':', 'a', ':', 'b', ':', 'c']) @@ -235,22 +281,53 @@ class ReTests(unittest.TestCase): def test_re_findall(self): self.assertEqual(re.findall(":+", "abc"), []) - self.assertEqual(re.findall(":+", "a:b::c:::d"), [":", "::", ":::"]) - self.assertEqual(re.findall("(:+)", "a:b::c:::d"), [":", "::", ":::"]) - self.assertEqual(re.findall("(:)(:*)", "a:b::c:::d"), [(":", ""), - (":", ":"), - (":", "::")]) + for string in "a:b::c:::d", S("a:b::c:::d"): + self.assertTypedEqual(re.findall(":+", string), + [":", "::", ":::"]) + self.assertTypedEqual(re.findall("(:+)", string), + [":", "::", ":::"]) + self.assertTypedEqual(re.findall("(:)(:*)", string), + [(":", ""), (":", ":"), (":", "::")]) + for string in (b"a:b::c:::d", B(b"a:b::c:::d"), bytearray(b"a:b::c:::d"), + memoryview(b"a:b::c:::d")): + self.assertTypedEqual(re.findall(b":+", string), + [b":", b"::", b":::"]) + self.assertTypedEqual(re.findall(b"(:+)", string), + [b":", b"::", b":::"]) + self.assertTypedEqual(re.findall(b"(:)(:*)", string), + [(b":", b""), (b":", b":"), (b":", b"::")]) + for x in ("\xe0", "\u0430", "\U0001d49c"): + xx = x * 2 + xxx = x * 3 + string = "a%sb%sc%sd" % (x, xx, xxx) + self.assertEqual(re.findall("%s+" % x, string), [x, xx, xxx]) + self.assertEqual(re.findall("(%s+)" % x, string), [x, xx, xxx]) + self.assertEqual(re.findall("(%s)(%s*)" % (x, x), string), + [(x, ""), (x, x), (x, xx)]) def test_bug_117612(self): self.assertEqual(re.findall(r"(a|(b))", "aba"), [("a", ""),("b", "b"),("a", "")]) def test_re_match(self): - self.assertEqual(re.match('a', 'a').groups(), ()) - self.assertEqual(re.match('(a)', 'a').groups(), ('a',)) - self.assertEqual(re.match(r'(a)', 'a').group(0), 'a') - self.assertEqual(re.match(r'(a)', 'a').group(1), 'a') - self.assertEqual(re.match(r'(a)', 'a').group(1, 1), ('a', 'a')) + for string in 'a', S('a'): + self.assertEqual(re.match('a', string).groups(), ()) + self.assertEqual(re.match('(a)', string).groups(), ('a',)) + self.assertEqual(re.match('(a)', string).group(0), 'a') + self.assertEqual(re.match('(a)', string).group(1), 'a') + self.assertEqual(re.match('(a)', string).group(1, 1), ('a', 'a')) + for string in b'a', B(b'a'), bytearray(b'a'), memoryview(b'a'): + self.assertEqual(re.match(b'a', string).groups(), ()) + self.assertEqual(re.match(b'(a)', string).groups(), (b'a',)) + self.assertEqual(re.match(b'(a)', string).group(0), b'a') + self.assertEqual(re.match(b'(a)', string).group(1), b'a') + self.assertEqual(re.match(b'(a)', string).group(1, 1), (b'a', b'a')) + for a in ("\xe0", "\u0430", "\U0001d49c"): + self.assertEqual(re.match(a, a).groups(), ()) + self.assertEqual(re.match('(%s)' % a, a).groups(), (a,)) + self.assertEqual(re.match('(%s)' % a, a).group(0), a) + self.assertEqual(re.match('(%s)' % a, a).group(1), a) + self.assertEqual(re.match('(%s)' % a, a).group(1, 1), (a, a)) pat = re.compile('((a)|(b))(c)?') self.assertEqual(pat.match('a').groups(), ('a', 'a', None, None)) @@ -272,6 +349,36 @@ class ReTests(unittest.TestCase): (None, 'b', None)) self.assertEqual(pat.match('ac').group(1, 'b2', 3), ('a', None, 'c')) + def test_re_fullmatch(self): + # Issue 16203: Proposal: add re.fullmatch() method. + self.assertEqual(re.fullmatch(r"a", "a").span(), (0, 1)) + for string in "ab", S("ab"): + self.assertEqual(re.fullmatch(r"a|ab", string).span(), (0, 2)) + for string in b"ab", B(b"ab"), bytearray(b"ab"), memoryview(b"ab"): + self.assertEqual(re.fullmatch(br"a|ab", string).span(), (0, 2)) + for a, b in "\xe0\xdf", "\u0430\u0431", "\U0001d49c\U0001d49e": + r = r"%s|%s" % (a, a + b) + self.assertEqual(re.fullmatch(r, a + b).span(), (0, 2)) + self.assertEqual(re.fullmatch(r".*?$", "abc").span(), (0, 3)) + self.assertEqual(re.fullmatch(r".*?", "abc").span(), (0, 3)) + self.assertEqual(re.fullmatch(r"a.*?b", "ab").span(), (0, 2)) + self.assertEqual(re.fullmatch(r"a.*?b", "abb").span(), (0, 3)) + self.assertEqual(re.fullmatch(r"a.*?b", "axxb").span(), (0, 4)) + self.assertIsNone(re.fullmatch(r"a+", "ab")) + self.assertIsNone(re.fullmatch(r"abc$", "abc\n")) + self.assertIsNone(re.fullmatch(r"abc\Z", "abc\n")) + self.assertIsNone(re.fullmatch(r"(?m)abc$", "abc\n")) + self.assertEqual(re.fullmatch(r"ab(?=c)cd", "abcd").span(), (0, 4)) + self.assertEqual(re.fullmatch(r"ab(?<=b)cd", "abcd").span(), (0, 4)) + self.assertEqual(re.fullmatch(r"(?=a|ab)ab", "ab").span(), (0, 2)) + + self.assertEqual( + re.compile(r"bc").fullmatch("abcd", pos=1, endpos=3).span(), (1, 3)) + self.assertEqual( + re.compile(r".*?$").fullmatch("abcd", pos=1, endpos=3).span(), (1, 3)) + self.assertEqual( + re.compile(r".*?").fullmatch("abcd", pos=1, endpos=3).span(), (1, 3)) + def test_re_groupref_exists(self): self.assertEqual(re.match('^(\()?([^()]+)(?(1)\))$', '(a)').groups(), ('(', 'a')) @@ -1053,6 +1160,28 @@ class ReTests(unittest.TestCase): self.assertEqual(re.compile(pattern, re.S).findall(b'xyz'), [b'xyz'], msg=pattern) + def test_match_repr(self): + for string in '[abracadabra]', S('[abracadabra]'): + m = re.search(r'(.+)(.*?)\1', string) + self.assertEqual(repr(m), "<%s.%s object; " + "span=(1, 12), match='abracadabra'>" % + (type(m).__module__, type(m).__qualname__)) + for string in (b'[abracadabra]', B(b'[abracadabra]'), + bytearray(b'[abracadabra]'), + memoryview(b'[abracadabra]')): + m = re.search(rb'(.+)(.*?)\1', string) + self.assertEqual(repr(m), "<%s.%s object; " + "span=(1, 12), match=b'abracadabra'>" % + (type(m).__module__, type(m).__qualname__)) + + first, second = list(re.finditer("(aa)|(bb)", "aa bb")) + self.assertEqual(repr(first), "<%s.%s object; " + "span=(0, 2), match='aa'>" % + (type(second).__module__, type(first).__qualname__)) + self.assertEqual(repr(second), "<%s.%s object; " + "span=(3, 5), match='bb'>" % + (type(second).__module__, type(second).__qualname__)) + def test_bug_2537(self): # issue 2537: empty submatches @@ -1077,6 +1206,83 @@ class ReTests(unittest.TestCase): ['literal 102 ', 'literal 111 ', 'literal 111 ']) +class PatternReprTests(unittest.TestCase): + def check(self, pattern, expected): + self.assertEqual(repr(re.compile(pattern)), expected) + + def check_flags(self, pattern, flags, expected): + self.assertEqual(repr(re.compile(pattern, flags)), expected) + + def test_without_flags(self): + self.check('random pattern', + "re.compile('random pattern')") + + def test_single_flag(self): + self.check_flags('random pattern', re.IGNORECASE, + "re.compile('random pattern', re.IGNORECASE)") + + def test_multiple_flags(self): + self.check_flags('random pattern', re.I|re.S|re.X, + "re.compile('random pattern', " + "re.IGNORECASE|re.DOTALL|re.VERBOSE)") + + def test_unicode_flag(self): + self.check_flags('random pattern', re.U, + "re.compile('random pattern')") + self.check_flags('random pattern', re.I|re.S|re.U, + "re.compile('random pattern', " + "re.IGNORECASE|re.DOTALL)") + + def test_inline_flags(self): + self.check('(?i)pattern', + "re.compile('(?i)pattern', re.IGNORECASE)") + + def test_unknown_flags(self): + self.check_flags('random pattern', 0x123000, + "re.compile('random pattern', 0x123000)") + self.check_flags('random pattern', 0x123000|re.I, + "re.compile('random pattern', re.IGNORECASE|0x123000)") + + def test_bytes(self): + self.check(b'bytes pattern', + "re.compile(b'bytes pattern')") + self.check_flags(b'bytes pattern', re.A, + "re.compile(b'bytes pattern', re.ASCII)") + + def test_quotes(self): + self.check('random "double quoted" pattern', + '''re.compile('random "double quoted" pattern')''') + self.check("random 'single quoted' pattern", + '''re.compile("random 'single quoted' pattern")''') + self.check('''both 'single' and "double" quotes''', + '''re.compile('both \\'single\\' and "double" quotes')''') + + def test_long_pattern(self): + pattern = 'Very %spattern' % ('long ' * 1000) + r = repr(re.compile(pattern)) + self.assertLess(len(r), 300) + self.assertEqual(r[:30], "re.compile('Very long long lon") + r = repr(re.compile(pattern, re.I)) + self.assertLess(len(r), 300) + self.assertEqual(r[:30], "re.compile('Very long long lon") + self.assertEqual(r[-16:], ", re.IGNORECASE)") + + +class ImplementationTest(unittest.TestCase): + """ + Test implementation details of the re module. + """ + + def test_overlap_table(self): + f = sre_compile._generate_overlap_table + self.assertEqual(f(""), []) + self.assertEqual(f("a"), [0]) + self.assertEqual(f("abcd"), [0, 0, 0, 0]) + self.assertEqual(f("aaaa"), [0, 1, 2, 3]) + self.assertEqual(f("ababba"), [0, 0, 1, 2, 0, 1]) + self.assertEqual(f("abcabdac"), [0, 0, 0, 1, 2, 0, 1, 0]) + + def run_re_tests(): from test.re_tests import tests, SUCCEED, FAIL, SYNTAX_ERROR if verbose: @@ -1206,7 +1412,7 @@ def run_re_tests(): def test_main(): - run_unittest(ReTests) + run_unittest(__name__) run_re_tests() if __name__ == "__main__": diff --git a/Lib/test/test_regrtest.py b/Lib/test/test_regrtest.py new file mode 100644 index 0000000000..a398a4f836 --- /dev/null +++ b/Lib/test/test_regrtest.py @@ -0,0 +1,275 @@ +""" +Tests of regrtest.py. +""" + +import argparse +import faulthandler +import getopt +import os.path +import unittest +from test import regrtest, support + +class ParseArgsTestCase(unittest.TestCase): + + """Test regrtest's argument parsing.""" + + def checkError(self, args, msg): + with support.captured_stderr() as err, self.assertRaises(SystemExit): + regrtest._parse_args(args) + self.assertIn(msg, err.getvalue()) + + def test_help(self): + for opt in '-h', '--help': + with self.subTest(opt=opt): + with support.captured_stdout() as out, \ + self.assertRaises(SystemExit): + regrtest._parse_args([opt]) + self.assertIn('Run Python regression tests.', out.getvalue()) + + @unittest.skipUnless(hasattr(faulthandler, 'dump_traceback_later'), + "faulthandler.dump_traceback_later() required") + def test_timeout(self): + ns = regrtest._parse_args(['--timeout', '4.2']) + self.assertEqual(ns.timeout, 4.2) + self.checkError(['--timeout'], 'expected one argument') + self.checkError(['--timeout', 'foo'], 'invalid float value') + + def test_wait(self): + ns = regrtest._parse_args(['--wait']) + self.assertTrue(ns.wait) + + def test_slaveargs(self): + ns = regrtest._parse_args(['--slaveargs', '[[], {}]']) + self.assertEqual(ns.slaveargs, '[[], {}]') + self.checkError(['--slaveargs'], 'expected one argument') + + def test_start(self): + for opt in '-S', '--start': + with self.subTest(opt=opt): + ns = regrtest._parse_args([opt, 'foo']) + self.assertEqual(ns.start, 'foo') + self.checkError([opt], 'expected one argument') + + def test_verbose(self): + ns = regrtest._parse_args(['-v']) + self.assertEqual(ns.verbose, 1) + ns = regrtest._parse_args(['-vvv']) + self.assertEqual(ns.verbose, 3) + ns = regrtest._parse_args(['--verbose']) + self.assertEqual(ns.verbose, 1) + ns = regrtest._parse_args(['--verbose'] * 3) + self.assertEqual(ns.verbose, 3) + ns = regrtest._parse_args([]) + self.assertEqual(ns.verbose, 0) + + def test_verbose2(self): + for opt in '-w', '--verbose2': + with self.subTest(opt=opt): + ns = regrtest._parse_args([opt]) + self.assertTrue(ns.verbose2) + + def test_verbose3(self): + for opt in '-W', '--verbose3': + with self.subTest(opt=opt): + ns = regrtest._parse_args([opt]) + self.assertTrue(ns.verbose3) + + def test_quiet(self): + for opt in '-q', '--quiet': + with self.subTest(opt=opt): + ns = regrtest._parse_args([opt]) + self.assertTrue(ns.quiet) + self.assertEqual(ns.verbose, 0) + + def test_slow(self): + for opt in '-o', '--slow': + with self.subTest(opt=opt): + ns = regrtest._parse_args([opt]) + self.assertTrue(ns.print_slow) + + def test_header(self): + ns = regrtest._parse_args(['--header']) + self.assertTrue(ns.header) + + def test_randomize(self): + for opt in '-r', '--randomize': + with self.subTest(opt=opt): + ns = regrtest._parse_args([opt]) + self.assertTrue(ns.randomize) + + def test_randseed(self): + ns = regrtest._parse_args(['--randseed', '12345']) + self.assertEqual(ns.random_seed, 12345) + self.assertTrue(ns.randomize) + self.checkError(['--randseed'], 'expected one argument') + self.checkError(['--randseed', 'foo'], 'invalid int value') + + def test_fromfile(self): + for opt in '-f', '--fromfile': + with self.subTest(opt=opt): + ns = regrtest._parse_args([opt, 'foo']) + self.assertEqual(ns.fromfile, 'foo') + self.checkError([opt], 'expected one argument') + self.checkError([opt, 'foo', '-s'], "don't go together") + + def test_exclude(self): + for opt in '-x', '--exclude': + with self.subTest(opt=opt): + ns = regrtest._parse_args([opt]) + self.assertTrue(ns.exclude) + + def test_single(self): + for opt in '-s', '--single': + with self.subTest(opt=opt): + ns = regrtest._parse_args([opt]) + self.assertTrue(ns.single) + self.checkError([opt, '-f', 'foo'], "don't go together") + + def test_match(self): + for opt in '-m', '--match': + with self.subTest(opt=opt): + ns = regrtest._parse_args([opt, 'pattern']) + self.assertEqual(ns.match_tests, 'pattern') + self.checkError([opt], 'expected one argument') + + def test_failfast(self): + for opt in '-G', '--failfast': + with self.subTest(opt=opt): + ns = regrtest._parse_args([opt, '-v']) + self.assertTrue(ns.failfast) + ns = regrtest._parse_args([opt, '-W']) + self.assertTrue(ns.failfast) + self.checkError([opt], '-G/--failfast needs either -v or -W') + + def test_use(self): + for opt in '-u', '--use': + with self.subTest(opt=opt): + ns = regrtest._parse_args([opt, 'gui,network']) + self.assertEqual(ns.use_resources, ['gui', 'network']) + ns = regrtest._parse_args([opt, 'gui,none,network']) + self.assertEqual(ns.use_resources, ['network']) + expected = list(regrtest.RESOURCE_NAMES) + expected.remove('gui') + ns = regrtest._parse_args([opt, 'all,-gui']) + self.assertEqual(ns.use_resources, expected) + self.checkError([opt], 'expected one argument') + self.checkError([opt, 'foo'], 'invalid resource') + + def test_memlimit(self): + for opt in '-M', '--memlimit': + with self.subTest(opt=opt): + ns = regrtest._parse_args([opt, '4G']) + self.assertEqual(ns.memlimit, '4G') + self.checkError([opt], 'expected one argument') + + def test_testdir(self): + ns = regrtest._parse_args(['--testdir', 'foo']) + self.assertEqual(ns.testdir, os.path.join(support.SAVEDCWD, 'foo')) + self.checkError(['--testdir'], 'expected one argument') + + def test_runleaks(self): + for opt in '-L', '--runleaks': + with self.subTest(opt=opt): + ns = regrtest._parse_args([opt]) + self.assertTrue(ns.runleaks) + + def test_huntrleaks(self): + for opt in '-R', '--huntrleaks': + with self.subTest(opt=opt): + ns = regrtest._parse_args([opt, ':']) + self.assertEqual(ns.huntrleaks, (5, 4, 'reflog.txt')) + ns = regrtest._parse_args([opt, '6:']) + self.assertEqual(ns.huntrleaks, (6, 4, 'reflog.txt')) + ns = regrtest._parse_args([opt, ':3']) + self.assertEqual(ns.huntrleaks, (5, 3, 'reflog.txt')) + ns = regrtest._parse_args([opt, '6:3:leaks.log']) + self.assertEqual(ns.huntrleaks, (6, 3, 'leaks.log')) + self.checkError([opt], 'expected one argument') + self.checkError([opt, '6'], + 'needs 2 or 3 colon-separated arguments') + self.checkError([opt, 'foo:'], 'invalid huntrleaks value') + self.checkError([opt, '6:foo'], 'invalid huntrleaks value') + + def test_multiprocess(self): + for opt in '-j', '--multiprocess': + with self.subTest(opt=opt): + ns = regrtest._parse_args([opt, '2']) + self.assertEqual(ns.use_mp, 2) + self.checkError([opt], 'expected one argument') + self.checkError([opt, 'foo'], 'invalid int value') + self.checkError([opt, '2', '-T'], "don't go together") + self.checkError([opt, '2', '-l'], "don't go together") + self.checkError([opt, '2', '-M', '4G'], "don't go together") + + def test_coverage(self): + for opt in '-T', '--coverage': + with self.subTest(opt=opt): + ns = regrtest._parse_args([opt]) + self.assertTrue(ns.trace) + + def test_coverdir(self): + for opt in '-D', '--coverdir': + with self.subTest(opt=opt): + ns = regrtest._parse_args([opt, 'foo']) + self.assertEqual(ns.coverdir, + os.path.join(support.SAVEDCWD, 'foo')) + self.checkError([opt], 'expected one argument') + + def test_nocoverdir(self): + for opt in '-N', '--nocoverdir': + with self.subTest(opt=opt): + ns = regrtest._parse_args([opt]) + self.assertIsNone(ns.coverdir) + + def test_threshold(self): + for opt in '-t', '--threshold': + with self.subTest(opt=opt): + ns = regrtest._parse_args([opt, '1000']) + self.assertEqual(ns.threshold, 1000) + self.checkError([opt], 'expected one argument') + self.checkError([opt, 'foo'], 'invalid int value') + + def test_nowindows(self): + for opt in '-n', '--nowindows': + with self.subTest(opt=opt): + ns = regrtest._parse_args([opt]) + self.assertTrue(ns.nowindows) + + def test_forever(self): + for opt in '-F', '--forever': + with self.subTest(opt=opt): + ns = regrtest._parse_args([opt]) + self.assertTrue(ns.forever) + + + def test_unrecognized_argument(self): + self.checkError(['--xxx'], 'usage:') + + def test_long_option__partial(self): + ns = regrtest._parse_args(['--qui']) + self.assertTrue(ns.quiet) + self.assertEqual(ns.verbose, 0) + + def test_two_options(self): + ns = regrtest._parse_args(['--quiet', '--exclude']) + self.assertTrue(ns.quiet) + self.assertEqual(ns.verbose, 0) + self.assertTrue(ns.exclude) + + def test_option_with_empty_string_value(self): + ns = regrtest._parse_args(['--start', '']) + self.assertEqual(ns.start, '') + + def test_arg(self): + ns = regrtest._parse_args(['foo']) + self.assertEqual(ns.args, ['foo']) + + def test_option_and_arg(self): + ns = regrtest._parse_args(['--quiet', 'foo']) + self.assertTrue(ns.quiet) + self.assertEqual(ns.verbose, 0) + self.assertEqual(ns.args, ['foo']) + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_reprlib.py b/Lib/test/test_reprlib.py index c233b1d67b..ae67f06bc4 100644 --- a/Lib/test/test_reprlib.py +++ b/Lib/test/test_reprlib.py @@ -3,11 +3,11 @@ Nick Mathewson """ -import imp import sys import os import shutil import importlib +import importlib.util import unittest from test.support import run_unittest, create_empty_file, verbose @@ -248,7 +248,8 @@ class LongReprTest(unittest.TestCase): source_path_len += 2 * (len(self.longname) + 1) # a path separator + `module_name` + ".py" source_path_len += len(module_name) + 1 + len(".py") - cached_path_len = source_path_len + len(imp.cache_from_source("x.py")) - len("x.py") + cached_path_len = (source_path_len + + len(importlib.util.cache_from_source("x.py")) - len("x.py")) if os.name == 'nt' and cached_path_len >= 258: # Under Windows, the max path len is 260 including C's terminating # NUL character. @@ -259,6 +260,7 @@ class LongReprTest(unittest.TestCase): print("cached_path_len =", cached_path_len) def test_module(self): + self.maxDiff = None self._check_path_limitations(self.pkgname) create_empty_file(os.path.join(self.subpkgname, self.pkgname + '.py')) importlib.invalidate_caches() diff --git a/Lib/test/test_resource.py b/Lib/test/test_resource.py index 0cf61cb17a..2ecae0fc45 100644 --- a/Lib/test/test_resource.py +++ b/Lib/test/test_resource.py @@ -1,3 +1,6 @@ +import contextlib +import sys +import os import unittest from test import support import time @@ -61,7 +64,7 @@ class ResourceTest(unittest.TestCase): for i in range(5): time.sleep(.1) f.flush() - except IOError: + except OSError: if not limit_set: raise if limit_set: @@ -129,6 +132,33 @@ class ResourceTest(unittest.TestCase): self.assertIsInstance(pagesize, int) self.assertGreaterEqual(pagesize, 0) + @unittest.skipUnless(sys.platform == 'linux', 'test requires Linux') + def test_linux_constants(self): + for attr in ['MSGQUEUE', 'NICE', 'RTPRIO', 'RTTIME', 'SIGPENDING']: + with contextlib.suppress(AttributeError): + self.assertIsInstance(getattr(resource, 'RLIMIT_' + attr), int) + + @support.requires_freebsd_version(9) + def test_freebsd_contants(self): + for attr in ['SWAP', 'SBSIZE', 'NPTS']: + with contextlib.suppress(AttributeError): + self.assertIsInstance(getattr(resource, 'RLIMIT_' + attr), int) + + @unittest.skipUnless(hasattr(resource, 'prlimit'), 'no prlimit') + @support.requires_linux_version(2, 6, 36) + def test_prlimit(self): + self.assertRaises(TypeError, resource.prlimit) + if os.geteuid() != 0: + self.assertRaises(PermissionError, resource.prlimit, + 1, resource.RLIMIT_AS) + self.assertRaises(ProcessLookupError, resource.prlimit, + -1, resource.RLIMIT_AS) + limit = resource.getrlimit(resource.RLIMIT_AS) + self.assertEqual(resource.prlimit(0, resource.RLIMIT_AS), limit) + self.assertEqual(resource.prlimit(0, resource.RLIMIT_AS, limit), + limit) + + def test_main(verbose=None): support.run_unittest(ResourceTest) diff --git a/Lib/test/test_runpy.py b/Lib/test/test_runpy.py index 2ddba3413a..786b813b32 100644 --- a/Lib/test/test_runpy.py +++ b/Lib/test/test_runpy.py @@ -5,7 +5,7 @@ import os.path import sys import re import tempfile -import importlib +import importlib, importlib.machinery, importlib.util import py_compile from test.support import ( forget, make_legacy_pyc, run_unittest, unload, verbose, no_tracing, @@ -47,6 +47,7 @@ implicit_namespace = { "__cached__": None, "__package__": None, "__doc__": None, + "__spec__": None } example_namespace = { "sys": sys, @@ -67,11 +68,19 @@ class CodeExecutionMixin: # testing occurs at those upper layers as well, not just at the utility # layer + # Figuring out the loader details in advance is hard to do, so we skip + # checking the full details of loader and loader_state + CHECKED_SPEC_ATTRIBUTES = ["name", "parent", "origin", "cached", + "has_location", "submodule_search_locations"] + def assertNamespaceMatches(self, result_ns, expected_ns): """Check two namespaces match. Ignores any unspecified interpreter created names """ + # Avoid side effects + result_ns = result_ns.copy() + expected_ns = expected_ns.copy() # Impls are permitted to add extra names, so filter them out for k in list(result_ns): if k.startswith("__") and k.endswith("__"): @@ -79,7 +88,25 @@ class CodeExecutionMixin: result_ns.pop(k) if k not in expected_ns["nested"]: result_ns["nested"].pop(k) - # Don't use direct dict comparison - the diffs are too hard to debug + # Spec equality includes the loader, so we take the spec out of the + # result namespace and check that separately + result_spec = result_ns.pop("__spec__") + expected_spec = expected_ns.pop("__spec__") + if expected_spec is None: + self.assertIsNone(result_spec) + else: + # If an expected loader is set, we just check we got the right + # type, rather than checking for full equality + if expected_spec.loader is not None: + self.assertEqual(type(result_spec.loader), + type(expected_spec.loader)) + for attr in self.CHECKED_SPEC_ATTRIBUTES: + k = "__spec__." + attr + actual = (k, getattr(result_spec, attr)) + expected = (k, getattr(expected_spec, attr)) + self.assertEqual(actual, expected) + # For the rest, we still don't use direct dict comparison on the + # namespace, as the diffs are too hard to debug if anything breaks self.assertEqual(set(result_ns), set(expected_ns)) for k in result_ns: actual = (k, result_ns[k]) @@ -130,12 +157,16 @@ class ExecutionLayerTestCase(unittest.TestCase, CodeExecutionMixin): mod_fname = "Some other nonsense" mod_loader = "Now you're just being silly" mod_package = '' # Treat as a top level module + mod_spec = importlib.machinery.ModuleSpec(mod_name, + origin=mod_fname, + loader=mod_loader) expected_ns = example_namespace.copy() expected_ns.update({ "__name__": mod_name, "__file__": mod_fname, "__loader__": mod_loader, "__package__": mod_package, + "__spec__": mod_spec, "run_argv0": mod_fname, "run_name_in_sys_modules": True, "module_in_sys_modules": True, @@ -144,9 +175,7 @@ class ExecutionLayerTestCase(unittest.TestCase, CodeExecutionMixin): return _run_module_code(example_source, init_globals, mod_name, - mod_fname, - mod_loader, - mod_package) + mod_spec) self.check_code_execution(create_ns, expected_ns) # TODO: Use self.addCleanup to get rid of a lot of try-finally blocks @@ -176,31 +205,43 @@ class RunModuleTestCase(unittest.TestCase, CodeExecutionMixin): def test_library_module(self): self.assertEqual(run_module("runpy")["__name__"], "runpy") - def _add_pkg_dir(self, pkg_dir): + def _add_pkg_dir(self, pkg_dir, namespace=False): os.mkdir(pkg_dir) + if namespace: + return None pkg_fname = os.path.join(pkg_dir, "__init__.py") create_empty_file(pkg_fname) return pkg_fname - def _make_pkg(self, source, depth, mod_base="runpy_test"): + def _make_pkg(self, source, depth, mod_base="runpy_test", + *, namespace=False, parent_namespaces=False): + # Enforce a couple of internal sanity checks on test cases + if (namespace or parent_namespaces) and not depth: + raise RuntimeError("Can't mark top level module as a " + "namespace package") pkg_name = "__runpy_pkg__" test_fname = mod_base+os.extsep+"py" pkg_dir = sub_dir = os.path.realpath(tempfile.mkdtemp()) if verbose > 1: print(" Package tree in:", sub_dir) sys.path.insert(0, pkg_dir) if verbose > 1: print(" Updated sys.path:", sys.path[0]) - for i in range(depth): - sub_dir = os.path.join(sub_dir, pkg_name) - pkg_fname = self._add_pkg_dir(sub_dir) - if verbose > 1: print(" Next level in:", sub_dir) - if verbose > 1: print(" Created:", pkg_fname) + if depth: + namespace_flags = [parent_namespaces] * depth + namespace_flags[-1] = namespace + for namespace_flag in namespace_flags: + sub_dir = os.path.join(sub_dir, pkg_name) + pkg_fname = self._add_pkg_dir(sub_dir, namespace_flag) + if verbose > 1: print(" Next level in:", sub_dir) + if verbose > 1: print(" Created:", pkg_fname) mod_fname = os.path.join(sub_dir, test_fname) mod_file = open(mod_fname, "w") mod_file.write(source) mod_file.close() if verbose > 1: print(" Created:", mod_fname) mod_name = (pkg_name+".")*depth + mod_base - return pkg_dir, mod_fname, mod_name + mod_spec = importlib.util.spec_from_file_location(mod_name, + mod_fname) + return pkg_dir, mod_fname, mod_name, mod_spec def _del_pkg(self, top, depth, mod_name): for entry in list(sys.modules): @@ -230,19 +271,29 @@ class RunModuleTestCase(unittest.TestCase, CodeExecutionMixin): def _fix_ns_for_legacy_pyc(self, ns, alter_sys): char_to_add = "c" if __debug__ else "o" ns["__file__"] += char_to_add + ns["__cached__"] = ns["__file__"] + spec = ns["__spec__"] + new_spec = importlib.util.spec_from_file_location(spec.name, + ns["__file__"]) + ns["__spec__"] = new_spec if alter_sys: ns["run_argv0"] += char_to_add - def _check_module(self, depth, alter_sys=False): - pkg_dir, mod_fname, mod_name = ( - self._make_pkg(example_source, depth)) + def _check_module(self, depth, alter_sys=False, + *, namespace=False, parent_namespaces=False): + pkg_dir, mod_fname, mod_name, mod_spec = ( + self._make_pkg(example_source, depth, + namespace=namespace, + parent_namespaces=parent_namespaces)) forget(mod_name) expected_ns = example_namespace.copy() expected_ns.update({ "__name__": mod_name, "__file__": mod_fname, + "__cached__": mod_spec.cached, "__package__": mod_name.rpartition(".")[0], + "__spec__": mod_spec, }) if alter_sys: expected_ns.update({ @@ -269,16 +320,21 @@ class RunModuleTestCase(unittest.TestCase, CodeExecutionMixin): self._del_pkg(pkg_dir, depth, mod_name) if verbose > 1: print("Module executed successfully") - def _check_package(self, depth, alter_sys=False): - pkg_dir, mod_fname, mod_name = ( - self._make_pkg(example_source, depth, "__main__")) + def _check_package(self, depth, alter_sys=False, + *, namespace=False, parent_namespaces=False): + pkg_dir, mod_fname, mod_name, mod_spec = ( + self._make_pkg(example_source, depth, "__main__", + namespace=namespace, + parent_namespaces=parent_namespaces)) pkg_name = mod_name.rpartition(".")[0] forget(mod_name) expected_ns = example_namespace.copy() expected_ns.update({ "__name__": mod_name, "__file__": mod_fname, + "__cached__": importlib.util.cache_from_source(mod_fname), "__package__": pkg_name, + "__spec__": mod_spec, }) if alter_sys: expected_ns.update({ @@ -334,7 +390,7 @@ from __future__ import absolute_import from . import sibling from ..uncle.cousin import nephew """ - pkg_dir, mod_fname, mod_name = ( + pkg_dir, mod_fname, mod_name, mod_spec = ( self._make_pkg(contents, depth)) if run_name is None: expected_name = mod_name @@ -373,11 +429,31 @@ from ..uncle.cousin import nephew if verbose > 1: print("Testing package depth:", depth) self._check_module(depth) + def test_run_module_in_namespace_package(self): + for depth in range(1, 4): + if verbose > 1: print("Testing package depth:", depth) + self._check_module(depth, namespace=True, parent_namespaces=True) + def test_run_package(self): for depth in range(1, 4): if verbose > 1: print("Testing package depth:", depth) self._check_package(depth) + def test_run_package_in_namespace_package(self): + for depth in range(1, 4): + if verbose > 1: print("Testing package depth:", depth) + self._check_package(depth, parent_namespaces=True) + + def test_run_namespace_package(self): + for depth in range(1, 4): + if verbose > 1: print("Testing package depth:", depth) + self._check_package(depth, namespace=True) + + def test_run_namespace_package_in_namespace_package(self): + for depth in range(1, 4): + if verbose > 1: print("Testing package depth:", depth) + self._check_package(depth, namespace=True, parent_namespaces=True) + def test_run_module_alter_sys(self): for depth in range(4): if verbose > 1: print("Testing package depth:", depth) @@ -401,14 +477,16 @@ from ..uncle.cousin import nephew def test_run_name(self): depth = 1 run_name = "And now for something completely different" - pkg_dir, mod_fname, mod_name = ( + pkg_dir, mod_fname, mod_name, mod_spec = ( self._make_pkg(example_source, depth)) forget(mod_name) expected_ns = example_namespace.copy() expected_ns.update({ "__name__": run_name, "__file__": mod_fname, + "__cached__": importlib.util.cache_from_source(mod_fname), "__package__": mod_name.rpartition(".")[0], + "__spec__": mod_spec, }) def create_ns(init_globals): return run_module(mod_name, init_globals, run_name) @@ -437,7 +515,7 @@ from ..uncle.cousin import nephew pkg_name = ".".join([base_name] * max_depth) expected_packages.add(pkg_name) expected_modules.add(pkg_name + ".runpy_test") - pkg_dir, mod_fname, mod_name = ( + pkg_dir, mod_fname, mod_name, mod_spec = ( self._make_pkg("", max_depth)) self.addCleanup(self._del_pkg, pkg_dir, max_depth, mod_name) for depth in range(2, max_depth+1): @@ -455,21 +533,39 @@ from ..uncle.cousin import nephew class RunPathTestCase(unittest.TestCase, CodeExecutionMixin): """Unit tests for runpy.run_path""" - def _make_test_script(self, script_dir, script_basename, source=None): + def _make_test_script(self, script_dir, script_basename, + source=None, omit_suffix=False): if source is None: source = example_source - return make_script(script_dir, script_basename, source) + return make_script(script_dir, script_basename, + source, omit_suffix) def _check_script(self, script_name, expected_name, expected_file, - expected_argv0): + expected_argv0, mod_name=None, + expect_spec=True, check_loader=True): # First check is without run_name def create_ns(init_globals): return run_path(script_name, init_globals) expected_ns = example_namespace.copy() + if mod_name is None: + spec_name = expected_name + else: + spec_name = mod_name + if expect_spec: + mod_spec = importlib.util.spec_from_file_location(spec_name, + expected_file) + mod_cached = mod_spec.cached + if not check_loader: + mod_spec.loader = None + else: + mod_spec = mod_cached = None + expected_ns.update({ "__name__": expected_name, "__file__": expected_file, + "__cached__": mod_cached, "__package__": "", + "__spec__": mod_spec, "run_argv0": expected_argv0, "run_name_in_sys_modules": True, "module_in_sys_modules": True, @@ -479,6 +575,12 @@ class RunPathTestCase(unittest.TestCase, CodeExecutionMixin): run_name = "prove.issue15230.is.fixed" def create_ns(init_globals): return run_path(script_name, init_globals, run_name) + if expect_spec and mod_name is None: + mod_spec = importlib.util.spec_from_file_location(run_name, + expected_file) + if not check_loader: + mod_spec.loader = None + expected_ns["__spec__"] = mod_spec expected_ns["__name__"] = run_name expected_ns["__package__"] = run_name.rpartition(".")[0] self.check_code_execution(create_ns, expected_ns) @@ -492,7 +594,15 @@ class RunPathTestCase(unittest.TestCase, CodeExecutionMixin): mod_name = 'script' script_name = self._make_test_script(script_dir, mod_name) self._check_script(script_name, "<run_path>", script_name, - script_name) + script_name, expect_spec=False) + + def test_basic_script_no_suffix(self): + with temp_dir() as script_dir: + mod_name = 'script' + script_name = self._make_test_script(script_dir, mod_name, + omit_suffix=True) + self._check_script(script_name, "<run_path>", script_name, + script_name, expect_spec=False) def test_script_compiled(self): with temp_dir() as script_dir: @@ -501,14 +611,14 @@ class RunPathTestCase(unittest.TestCase, CodeExecutionMixin): compiled_name = py_compile.compile(script_name, doraise=True) os.remove(script_name) self._check_script(compiled_name, "<run_path>", compiled_name, - compiled_name) + compiled_name, expect_spec=False) def test_directory(self): with temp_dir() as script_dir: mod_name = '__main__' script_name = self._make_test_script(script_dir, mod_name) self._check_script(script_dir, "<run_path>", script_name, - script_dir) + script_dir, mod_name=mod_name) def test_directory_compiled(self): with temp_dir() as script_dir: @@ -519,7 +629,7 @@ class RunPathTestCase(unittest.TestCase, CodeExecutionMixin): if not sys.dont_write_bytecode: legacy_pyc = make_legacy_pyc(script_name) self._check_script(script_dir, "<run_path>", legacy_pyc, - script_dir) + script_dir, mod_name=mod_name) def test_directory_error(self): with temp_dir() as script_dir: @@ -533,7 +643,8 @@ class RunPathTestCase(unittest.TestCase, CodeExecutionMixin): mod_name = '__main__' script_name = self._make_test_script(script_dir, mod_name) zip_name, fname = make_zip_script(script_dir, 'test_zip', script_name) - self._check_script(zip_name, "<run_path>", fname, zip_name) + self._check_script(zip_name, "<run_path>", fname, zip_name, + mod_name=mod_name, check_loader=False) def test_zipfile_compiled(self): with temp_dir() as script_dir: @@ -542,7 +653,8 @@ class RunPathTestCase(unittest.TestCase, CodeExecutionMixin): compiled_name = py_compile.compile(script_name, doraise=True) zip_name, fname = make_zip_script(script_dir, 'test_zip', compiled_name) - self._check_script(zip_name, "<run_path>", fname, zip_name) + self._check_script(zip_name, "<run_path>", fname, zip_name, + mod_name=mod_name, check_loader=False) def test_zipfile_error(self): with temp_dir() as script_dir: @@ -575,12 +687,5 @@ s = "non-ASCII: h\xe9" self.assertEqual(result['s'], "non-ASCII: h\xe9") -def test_main(): - run_unittest( - ExecutionLayerTestCase, - RunModuleTestCase, - RunPathTestCase - ) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_sched.py b/Lib/test/test_sched.py index 79fa7d3e3d..fe8e785092 100644 --- a/Lib/test/test_sched.py +++ b/Lib/test/test_sched.py @@ -195,8 +195,5 @@ class TestCase(unittest.TestCase): self.assertEqual(l, []) -def test_main(): - support.run_unittest(TestCase) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_scope.py b/Lib/test/test_scope.py index 26ce0424f8..b325545f32 100644 --- a/Lib/test/test_scope.py +++ b/Lib/test/test_scope.py @@ -715,6 +715,19 @@ class ScopeTests(unittest.TestCase): def b(): global a + def testClassNamespaceOverridesClosure(self): + # See #17853. + x = 42 + class X: + locals()["x"] = 43 + y = x + self.assertEqual(X.y, 43) + class X: + locals()["x"] = 43 + del x + self.assertFalse(hasattr(X, "x")) + self.assertEqual(x, 42) + @cpython_only def testCellLeak(self): # Issue 17927. @@ -743,10 +756,6 @@ class ScopeTests(unittest.TestCase): del tester self.assertIsNone(ref()) - def test__Class__Global(self): - s = "class X:\n global __class__\n def f(self): super()" - self.assertRaises(SyntaxError, exec, s) - def test_main(): run_unittest(ScopeTests) diff --git a/Lib/test/test_select.py b/Lib/test/test_select.py index ddb9a0f67e..8f9a1c9d88 100644 --- a/Lib/test/test_select.py +++ b/Lib/test/test_select.py @@ -5,7 +5,7 @@ import sys import unittest from test import support -@unittest.skipIf(sys.platform[:3] in ('win', 'os2', 'riscos'), +@unittest.skipIf((sys.platform[:3]=='win'), "can't easily test on this system") class SelectTestCase(unittest.TestCase): @@ -32,7 +32,7 @@ class SelectTestCase(unittest.TestCase): fp.close() try: select.select([fd], [], [], 0) - except select.error as err: + except OSError as err: self.assertEqual(err.errno, errno.EBADF) else: self.fail("exception not raised") diff --git a/Lib/test/test_selectors.py b/Lib/test/test_selectors.py new file mode 100644 index 0000000000..34edd7630d --- /dev/null +++ b/Lib/test/test_selectors.py @@ -0,0 +1,453 @@ +import errno +import os +import random +import selectors +import signal +import socket +from test import support +from time import sleep +import unittest +import unittest.mock +try: + from time import monotonic as time +except ImportError: + from time import time as time +try: + import resource +except ImportError: + resource = None + + +if hasattr(socket, 'socketpair'): + socketpair = socket.socketpair +else: + def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0): + with socket.socket(family, type, proto) as l: + l.bind((support.HOST, 0)) + l.listen(3) + c = socket.socket(family, type, proto) + try: + c.connect(l.getsockname()) + caddr = c.getsockname() + while True: + a, addr = l.accept() + # check that we've got the correct client + if addr == caddr: + return c, a + a.close() + except OSError: + c.close() + raise + + +def find_ready_matching(ready, flag): + match = [] + for key, events in ready: + if events & flag: + match.append(key.fileobj) + return match + + +class BaseSelectorTestCase(unittest.TestCase): + + def make_socketpair(self): + rd, wr = socketpair() + self.addCleanup(rd.close) + self.addCleanup(wr.close) + return rd, wr + + def test_register(self): + s = self.SELECTOR() + self.addCleanup(s.close) + + rd, wr = self.make_socketpair() + + key = s.register(rd, selectors.EVENT_READ, "data") + self.assertIsInstance(key, selectors.SelectorKey) + self.assertEqual(key.fileobj, rd) + self.assertEqual(key.fd, rd.fileno()) + self.assertEqual(key.events, selectors.EVENT_READ) + self.assertEqual(key.data, "data") + + # register an unknown event + self.assertRaises(ValueError, s.register, 0, 999999) + + # register an invalid FD + self.assertRaises(ValueError, s.register, -10, selectors.EVENT_READ) + + # register twice + self.assertRaises(KeyError, s.register, rd, selectors.EVENT_READ) + + # register the same FD, but with a different object + self.assertRaises(KeyError, s.register, rd.fileno(), + selectors.EVENT_READ) + + def test_unregister(self): + s = self.SELECTOR() + self.addCleanup(s.close) + + rd, wr = self.make_socketpair() + + s.register(rd, selectors.EVENT_READ) + s.unregister(rd) + + # unregister an unknown file obj + self.assertRaises(KeyError, s.unregister, 999999) + + # unregister twice + self.assertRaises(KeyError, s.unregister, rd) + + def test_unregister_after_fd_close(self): + s = self.SELECTOR() + self.addCleanup(s.close) + rd, wr = self.make_socketpair() + r, w = rd.fileno(), wr.fileno() + s.register(r, selectors.EVENT_READ) + s.register(w, selectors.EVENT_WRITE) + rd.close() + wr.close() + s.unregister(r) + s.unregister(w) + + @unittest.skipUnless(os.name == 'posix', "requires posix") + def test_unregister_after_fd_close_and_reuse(self): + s = self.SELECTOR() + self.addCleanup(s.close) + rd, wr = self.make_socketpair() + r, w = rd.fileno(), wr.fileno() + s.register(r, selectors.EVENT_READ) + s.register(w, selectors.EVENT_WRITE) + rd2, wr2 = self.make_socketpair() + rd.close() + wr.close() + os.dup2(rd2.fileno(), r) + os.dup2(wr2.fileno(), w) + self.addCleanup(os.close, r) + self.addCleanup(os.close, w) + s.unregister(r) + s.unregister(w) + + def test_unregister_after_socket_close(self): + s = self.SELECTOR() + self.addCleanup(s.close) + rd, wr = self.make_socketpair() + s.register(rd, selectors.EVENT_READ) + s.register(wr, selectors.EVENT_WRITE) + rd.close() + wr.close() + s.unregister(rd) + s.unregister(wr) + + def test_modify(self): + s = self.SELECTOR() + self.addCleanup(s.close) + + rd, wr = self.make_socketpair() + + key = s.register(rd, selectors.EVENT_READ) + + # modify events + key2 = s.modify(rd, selectors.EVENT_WRITE) + self.assertNotEqual(key.events, key2.events) + self.assertEqual(key2, s.get_key(rd)) + + s.unregister(rd) + + # modify data + d1 = object() + d2 = object() + + key = s.register(rd, selectors.EVENT_READ, d1) + key2 = s.modify(rd, selectors.EVENT_READ, d2) + self.assertEqual(key.events, key2.events) + self.assertNotEqual(key.data, key2.data) + self.assertEqual(key2, s.get_key(rd)) + self.assertEqual(key2.data, d2) + + # modify unknown file obj + self.assertRaises(KeyError, s.modify, 999999, selectors.EVENT_READ) + + # modify use a shortcut + d3 = object() + s.register = unittest.mock.Mock() + s.unregister = unittest.mock.Mock() + + s.modify(rd, selectors.EVENT_READ, d3) + self.assertFalse(s.register.called) + self.assertFalse(s.unregister.called) + + def test_close(self): + s = self.SELECTOR() + self.addCleanup(s.close) + + rd, wr = self.make_socketpair() + + s.register(rd, selectors.EVENT_READ) + s.register(wr, selectors.EVENT_WRITE) + + s.close() + self.assertRaises(KeyError, s.get_key, rd) + self.assertRaises(KeyError, s.get_key, wr) + + def test_get_key(self): + s = self.SELECTOR() + self.addCleanup(s.close) + + rd, wr = self.make_socketpair() + + key = s.register(rd, selectors.EVENT_READ, "data") + self.assertEqual(key, s.get_key(rd)) + + # unknown file obj + self.assertRaises(KeyError, s.get_key, 999999) + + def test_get_map(self): + s = self.SELECTOR() + self.addCleanup(s.close) + + rd, wr = self.make_socketpair() + + keys = s.get_map() + self.assertFalse(keys) + self.assertEqual(len(keys), 0) + self.assertEqual(list(keys), []) + key = s.register(rd, selectors.EVENT_READ, "data") + self.assertIn(rd, keys) + self.assertEqual(key, keys[rd]) + self.assertEqual(len(keys), 1) + self.assertEqual(list(keys), [rd.fileno()]) + self.assertEqual(list(keys.values()), [key]) + + # unknown file obj + with self.assertRaises(KeyError): + keys[999999] + + # Read-only mapping + with self.assertRaises(TypeError): + del keys[rd] + + def test_select(self): + s = self.SELECTOR() + self.addCleanup(s.close) + + rd, wr = self.make_socketpair() + + s.register(rd, selectors.EVENT_READ) + wr_key = s.register(wr, selectors.EVENT_WRITE) + + result = s.select() + for key, events in result: + self.assertTrue(isinstance(key, selectors.SelectorKey)) + self.assertTrue(events) + self.assertFalse(events & ~(selectors.EVENT_READ | + selectors.EVENT_WRITE)) + + self.assertEqual([(wr_key, selectors.EVENT_WRITE)], result) + + def test_context_manager(self): + s = self.SELECTOR() + self.addCleanup(s.close) + + rd, wr = self.make_socketpair() + + with s as sel: + sel.register(rd, selectors.EVENT_READ) + sel.register(wr, selectors.EVENT_WRITE) + + self.assertRaises(KeyError, s.get_key, rd) + self.assertRaises(KeyError, s.get_key, wr) + + def test_fileno(self): + s = self.SELECTOR() + self.addCleanup(s.close) + + if hasattr(s, 'fileno'): + fd = s.fileno() + self.assertTrue(isinstance(fd, int)) + self.assertGreaterEqual(fd, 0) + + def test_selector(self): + s = self.SELECTOR() + self.addCleanup(s.close) + + NUM_SOCKETS = 12 + MSG = b" This is a test." + MSG_LEN = len(MSG) + readers = [] + writers = [] + r2w = {} + w2r = {} + + for i in range(NUM_SOCKETS): + rd, wr = self.make_socketpair() + s.register(rd, selectors.EVENT_READ) + s.register(wr, selectors.EVENT_WRITE) + readers.append(rd) + writers.append(wr) + r2w[rd] = wr + w2r[wr] = rd + + bufs = [] + + while writers: + ready = s.select() + ready_writers = find_ready_matching(ready, selectors.EVENT_WRITE) + if not ready_writers: + self.fail("no sockets ready for writing") + wr = random.choice(ready_writers) + wr.send(MSG) + + for i in range(10): + ready = s.select() + ready_readers = find_ready_matching(ready, + selectors.EVENT_READ) + if ready_readers: + break + # there might be a delay between the write to the write end and + # the read end is reported ready + sleep(0.1) + else: + self.fail("no sockets ready for reading") + self.assertEqual([w2r[wr]], ready_readers) + rd = ready_readers[0] + buf = rd.recv(MSG_LEN) + self.assertEqual(len(buf), MSG_LEN) + bufs.append(buf) + s.unregister(r2w[rd]) + s.unregister(rd) + writers.remove(r2w[rd]) + + self.assertEqual(bufs, [MSG] * NUM_SOCKETS) + + def test_timeout(self): + s = self.SELECTOR() + self.addCleanup(s.close) + + rd, wr = self.make_socketpair() + + s.register(wr, selectors.EVENT_WRITE) + t = time() + self.assertEqual(1, len(s.select(0))) + self.assertEqual(1, len(s.select(-1))) + self.assertLess(time() - t, 0.5) + + s.unregister(wr) + s.register(rd, selectors.EVENT_READ) + t = time() + self.assertFalse(s.select(0)) + self.assertFalse(s.select(-1)) + self.assertLess(time() - t, 0.5) + + t0 = time() + self.assertFalse(s.select(1)) + t1 = time() + dt = t1 - t0 + self.assertTrue(0.8 <= dt <= 1.6, dt) + + @unittest.skipUnless(hasattr(signal, "alarm"), + "signal.alarm() required for this test") + def test_select_interrupt(self): + s = self.SELECTOR() + self.addCleanup(s.close) + + rd, wr = self.make_socketpair() + + orig_alrm_handler = signal.signal(signal.SIGALRM, lambda *args: None) + self.addCleanup(signal.signal, signal.SIGALRM, orig_alrm_handler) + self.addCleanup(signal.alarm, 0) + + signal.alarm(1) + + s.register(rd, selectors.EVENT_READ) + t = time() + self.assertFalse(s.select(2)) + self.assertLess(time() - t, 2.5) + + +class ScalableSelectorMixIn: + + # see issue #18963 for why it's skipped on older OS X versions + @support.requires_mac_ver(10, 5) + @unittest.skipUnless(resource, "Test needs resource module") + def test_above_fd_setsize(self): + # A scalable implementation should have no problem with more than + # FD_SETSIZE file descriptors. Since we don't know the value, we just + # try to set the soft RLIMIT_NOFILE to the hard RLIMIT_NOFILE ceiling. + soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE) + try: + resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard)) + self.addCleanup(resource.setrlimit, resource.RLIMIT_NOFILE, + (soft, hard)) + NUM_FDS = hard + except (OSError, ValueError): + NUM_FDS = soft + + # guard for already allocated FDs (stdin, stdout...) + NUM_FDS -= 32 + + s = self.SELECTOR() + self.addCleanup(s.close) + + for i in range(NUM_FDS // 2): + try: + rd, wr = self.make_socketpair() + except OSError: + # too many FDs, skip - note that we should only catch EMFILE + # here, but apparently *BSD and Solaris can fail upon connect() + # or bind() with EADDRNOTAVAIL, so let's be safe + self.skipTest("FD limit reached") + + try: + s.register(rd, selectors.EVENT_READ) + s.register(wr, selectors.EVENT_WRITE) + except OSError as e: + if e.errno == errno.ENOSPC: + # this can be raised by epoll if we go over + # fs.epoll.max_user_watches sysctl + self.skipTest("FD limit reached") + raise + + self.assertEqual(NUM_FDS // 2, len(s.select())) + + +class DefaultSelectorTestCase(BaseSelectorTestCase): + + SELECTOR = selectors.DefaultSelector + + +class SelectSelectorTestCase(BaseSelectorTestCase): + + SELECTOR = selectors.SelectSelector + + +@unittest.skipUnless(hasattr(selectors, 'PollSelector'), + "Test needs selectors.PollSelector") +class PollSelectorTestCase(BaseSelectorTestCase, ScalableSelectorMixIn): + + SELECTOR = getattr(selectors, 'PollSelector', None) + + +@unittest.skipUnless(hasattr(selectors, 'EpollSelector'), + "Test needs selectors.EpollSelector") +class EpollSelectorTestCase(BaseSelectorTestCase, ScalableSelectorMixIn): + + SELECTOR = getattr(selectors, 'EpollSelector', None) + + +@unittest.skipUnless(hasattr(selectors, 'KqueueSelector'), + "Test needs selectors.KqueueSelector)") +class KqueueSelectorTestCase(BaseSelectorTestCase, ScalableSelectorMixIn): + + SELECTOR = getattr(selectors, 'KqueueSelector', None) + + +def test_main(): + tests = [DefaultSelectorTestCase, SelectSelectorTestCase, + PollSelectorTestCase, EpollSelectorTestCase, + KqueueSelectorTestCase] + support.run_unittest(*tests) + support.reap_children() + + +if __name__ == "__main__": + test_main() diff --git a/Lib/test/test_set.py b/Lib/test/test_set.py index c0bca2fb88..bfef621093 100644 --- a/Lib/test/test_set.py +++ b/Lib/test/test_set.py @@ -848,8 +848,6 @@ class TestBasicOps: for v in self.set: self.assertIn(v, self.values) setiter = iter(self.set) - # note: __length_hint__ is an internal undocumented API, - # don't rely on it in your own programs self.assertEqual(setiter.__length_hint__(), len(self.set)) def test_pickling(self): diff --git a/Lib/test/test_shelve.py b/Lib/test/test_shelve.py index 13c126566d..bd51d868fe 100644 --- a/Lib/test/test_shelve.py +++ b/Lib/test/test_shelve.py @@ -148,6 +148,19 @@ class TestCase(unittest.TestCase): p2 = d[encodedkey] self.assertNotEqual(p1, p2) # Write creates new object in store + def test_with(self): + d1 = {} + with shelve.Shelf(d1, protocol=2, writeback=False) as s: + s['key1'] = [1,2,3,4] + self.assertEqual(s['key1'], [1,2,3,4]) + self.assertEqual(len(s), 1) + self.assertRaises(ValueError, len, s) + try: + s['key1'] + except ValueError: + pass + else: + self.fail('Closed shelf should not find a key') from test import mapping_tests diff --git a/Lib/test/test_shutil.py b/Lib/test/test_shutil.py index c3fb8a7d2c..a483fe127c 100644 --- a/Lib/test/test_shutil.py +++ b/Lib/test/test_shutil.py @@ -18,7 +18,8 @@ from shutil import (_make_tarball, _make_zipfile, make_archive, register_archive_format, unregister_archive_format, get_archive_formats, Error, unpack_archive, register_unpack_format, RegistryError, - unregister_unpack_format, get_unpack_formats) + unregister_unpack_format, get_unpack_formats, + SameFileError) import tarfile import warnings @@ -746,13 +747,13 @@ class TestShutil(unittest.TestCase): shutil.copytree(src_dir, dst_dir) self.assertEqual(os.stat(src_dir).st_mode, os.stat(dst_dir).st_mode) self.assertEqual(os.stat(os.path.join(src_dir, 'permissive.txt')).st_mode, - os.stat(os.path.join(dst_dir, 'permissive.txt')).st_mode) + os.stat(os.path.join(dst_dir, 'permissive.txt')).st_mode) self.assertEqual(os.stat(os.path.join(src_dir, 'restrictive.txt')).st_mode, - os.stat(os.path.join(dst_dir, 'restrictive.txt')).st_mode) + os.stat(os.path.join(dst_dir, 'restrictive.txt')).st_mode) restrictive_subdir_dst = os.path.join(dst_dir, os.path.split(restrictive_subdir)[1]) self.assertEqual(os.stat(restrictive_subdir).st_mode, - os.stat(restrictive_subdir_dst).st_mode) + os.stat(restrictive_subdir_dst).st_mode) @unittest.skipIf(os.name == 'nt', 'temporarily disabled on Windows') @unittest.skipUnless(hasattr(os, 'link'), 'requires os.link') @@ -765,7 +766,7 @@ class TestShutil(unittest.TestCase): with open(src, 'w') as f: f.write('cheddar') os.link(src, dst) - self.assertRaises(shutil.Error, shutil.copyfile, src, dst) + self.assertRaises(shutil.SameFileError, shutil.copyfile, src, dst) with open(src, 'r') as f: self.assertEqual(f.read(), 'cheddar') os.remove(dst) @@ -785,7 +786,7 @@ class TestShutil(unittest.TestCase): # to TESTFN/TESTFN/cheese, while it should point at # TESTFN/cheese. os.symlink('cheese', dst) - self.assertRaises(shutil.Error, shutil.copyfile, src, dst) + self.assertRaises(shutil.SameFileError, shutil.copyfile, src, dst) with open(src, 'r') as f: self.assertEqual(f.read(), 'cheddar') os.remove(dst) @@ -1293,6 +1294,16 @@ class TestShutil(unittest.TestCase): self.assertTrue(os.path.exists(rv)) self.assertEqual(read_file(src_file), read_file(dst_file)) + def test_copyfile_same_file(self): + # copyfile() should raise SameFileError if the source and destination + # are the same. + src_dir = self.mkdtemp() + src_file = os.path.join(src_dir, 'foo') + write_file(src_file, 'foo') + self.assertRaises(SameFileError, shutil.copyfile, src_file, src_file) + # But Error should work too, to stay backward compatible. + self.assertRaises(Error, shutil.copyfile, src_file, src_file) + def test_copytree_return_value(self): # copytree returns its destination path. src_dir = self.mkdtemp() @@ -1601,7 +1612,7 @@ class TestCopyFile(unittest.TestCase): self._exited_with = exc_type, exc_val, exc_tb if self._raise_in_exit: self._raised = True - raise IOError("Cannot close") + raise OSError("Cannot close") return self._suppress_at_exit def tearDown(self): @@ -1615,12 +1626,12 @@ class TestCopyFile(unittest.TestCase): def test_w_source_open_fails(self): def _open(filename, mode='r'): if filename == 'srcfile': - raise IOError('Cannot open "srcfile"') + raise OSError('Cannot open "srcfile"') assert 0 # shouldn't reach here. self._set_shutil_open(_open) - self.assertRaises(IOError, shutil.copyfile, 'srcfile', 'destfile') + self.assertRaises(OSError, shutil.copyfile, 'srcfile', 'destfile') def test_w_dest_open_fails(self): @@ -1630,14 +1641,14 @@ class TestCopyFile(unittest.TestCase): if filename == 'srcfile': return srcfile if filename == 'destfile': - raise IOError('Cannot open "destfile"') + raise OSError('Cannot open "destfile"') assert 0 # shouldn't reach here. self._set_shutil_open(_open) shutil.copyfile('srcfile', 'destfile') self.assertTrue(srcfile._entered) - self.assertTrue(srcfile._exited_with[0] is IOError) + self.assertTrue(srcfile._exited_with[0] is OSError) self.assertEqual(srcfile._exited_with[1].args, ('Cannot open "destfile"',)) @@ -1659,7 +1670,7 @@ class TestCopyFile(unittest.TestCase): self.assertTrue(srcfile._entered) self.assertTrue(destfile._entered) self.assertTrue(destfile._raised) - self.assertTrue(srcfile._exited_with[0] is IOError) + self.assertTrue(srcfile._exited_with[0] is OSError) self.assertEqual(srcfile._exited_with[1].args, ('Cannot close',)) @@ -1677,7 +1688,7 @@ class TestCopyFile(unittest.TestCase): self._set_shutil_open(_open) - self.assertRaises(IOError, + self.assertRaises(OSError, shutil.copyfile, 'srcfile', 'destfile') self.assertTrue(srcfile._entered) self.assertTrue(destfile._entered) @@ -1748,9 +1759,5 @@ class TermsizeTests(unittest.TestCase): self.assertEqual(expected, actual) -def test_main(): - support.run_unittest(TestShutil, TestMove, TestCopyFile, - TermsizeTests, TestWhich) - if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_signal.py b/Lib/test/test_signal.py index 191d4cf691..a6f2c64857 100644 --- a/Lib/test/test_signal.py +++ b/Lib/test/test_signal.py @@ -15,9 +15,6 @@ try: except ImportError: threading = None -if sys.platform in ('os2', 'riscos'): - raise unittest.SkipTest("Can't test signal on %s" % sys.platform) - class HandlerBCalled(Exception): pass @@ -36,7 +33,7 @@ def exit_subprocess(): def ignoring_eintr(__func, *args, **kwargs): try: return __func(*args, **kwargs) - except EnvironmentError as e: + except OSError as e: if e.errno != errno.EINTR: raise return None @@ -278,6 +275,57 @@ class WakeupSignalTests(unittest.TestCase): assert_python_ok('-c', code) + def test_wakeup_write_error(self): + # Issue #16105: write() errors in the C signal handler should not + # pass silently. + # Use a subprocess to have only one thread. + code = """if 1: + import errno + import fcntl + import os + import signal + import sys + import time + from test.support import captured_stderr + + def handler(signum, frame): + 1/0 + + signal.signal(signal.SIGALRM, handler) + r, w = os.pipe() + flags = fcntl.fcntl(r, fcntl.F_GETFL, 0) + fcntl.fcntl(r, fcntl.F_SETFL, flags | os.O_NONBLOCK) + + # Set wakeup_fd a read-only file descriptor to trigger the error + signal.set_wakeup_fd(r) + try: + with captured_stderr() as err: + signal.alarm(1) + time.sleep(5.0) + except ZeroDivisionError: + # An ignored exception should have been printed out on stderr + err = err.getvalue() + if ('Exception ignored when trying to write to the signal wakeup fd' + not in err): + raise AssertionError(err) + if ('OSError: [Errno %d]' % errno.EBADF) not in err: + raise AssertionError(err) + else: + raise AssertionError("ZeroDivisionError not raised") + """ + r, w = os.pipe() + try: + os.write(r, b'x') + except OSError: + pass + else: + self.skipTest("OS doesn't report write() error on the read end of a pipe") + finally: + os.close(r) + os.close(w) + + assert_python_ok('-c', code) + def test_wakeup_fd_early(self): self.check_wakeup("""def test(): import select @@ -315,10 +363,10 @@ class WakeupSignalTests(unittest.TestCase): # We attempt to get a signal during the select call try: select.select([read], [], [], TIMEOUT_FULL) - except select.error: + except OSError: pass else: - raise Exception("select.error not raised") + raise Exception("OSError not raised") after_time = time.time() dt = after_time - before_time if dt >= TIMEOUT_HALF: diff --git a/Lib/test/test_site.py b/Lib/test/test_site.py index f3bd168168..2d214f4baf 100644 --- a/Lib/test/test_site.py +++ b/Lib/test/test_site.py @@ -230,11 +230,7 @@ class HelperFunctionsTests(unittest.TestCase): site.PREFIXES = ['xoxo'] dirs = site.getsitepackages() - if sys.platform in ('os2emx', 'riscos'): - self.assertEqual(len(dirs), 1) - wanted = os.path.join('xoxo', 'Lib', 'site-packages') - self.assertEqual(dirs[0], wanted) - elif (sys.platform == "darwin" and + if (sys.platform == "darwin" and sysconfig.get_config_var("PYTHONFRAMEWORK")): # OS X framework builds site.PREFIXES = ['Python.framework'] @@ -432,5 +428,38 @@ class ImportSideEffectTests(unittest.TestCase): self.assertEqual(code, 200, msg="Can't find " + url) +class StartupImportTests(unittest.TestCase): + + def test_startup_imports(self): + # This tests checks which modules are loaded by Python when it + # initially starts upon startup. + popen = subprocess.Popen([sys.executable, '-I', '-v', '-c', + 'import sys; print(set(sys.modules))'], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + stdout, stderr = popen.communicate() + stdout = stdout.decode('utf-8') + stderr = stderr.decode('utf-8') + modules = eval(stdout) + + self.assertIn('site', modules) + + # http://bugs.python.org/issue19205 + re_mods = {'re', '_sre', 'sre_compile', 'sre_constants', 'sre_parse'} + # _osx_support uses the re module in many placs + if sys.platform != 'darwin': + self.assertFalse(modules.intersection(re_mods), stderr) + # http://bugs.python.org/issue9548 + self.assertNotIn('locale', modules, stderr) + if sys.platform != 'darwin': + # http://bugs.python.org/issue19209 + self.assertNotIn('copyreg', modules, stderr) + # http://bugs.python.org/issue19218> + collection_mods = {'_collections', 'collections', 'functools', + 'heapq', 'itertools', 'keyword', 'operator', + 'reprlib', 'types', 'weakref'} + self.assertFalse(modules.intersection(collection_mods), stderr) + + if __name__ == "__main__": unittest.main() diff --git a/Lib/test/test_slice.py b/Lib/test/test_slice.py index 2df9271da7..9203d5eb19 100644 --- a/Lib/test/test_slice.py +++ b/Lib/test/test_slice.py @@ -4,8 +4,70 @@ import unittest from test import support from pickle import loads, dumps +import itertools +import operator import sys + +def evaluate_slice_index(arg): + """ + Helper function to convert a slice argument to an integer, and raise + TypeError with a suitable message on failure. + + """ + if hasattr(arg, '__index__'): + return operator.index(arg) + else: + raise TypeError( + "slice indices must be integers or " + "None or have an __index__ method") + +def slice_indices(slice, length): + """ + Reference implementation for the slice.indices method. + + """ + # Compute step and length as integers. + length = operator.index(length) + step = 1 if slice.step is None else evaluate_slice_index(slice.step) + + # Raise ValueError for negative length or zero step. + if length < 0: + raise ValueError("length should not be negative") + if step == 0: + raise ValueError("slice step cannot be zero") + + # Find lower and upper bounds for start and stop. + lower = -1 if step < 0 else 0 + upper = length - 1 if step < 0 else length + + # Compute start. + if slice.start is None: + start = upper if step < 0 else lower + else: + start = evaluate_slice_index(slice.start) + start = max(start + length, lower) if start < 0 else min(start, upper) + + # Compute stop. + if slice.stop is None: + stop = lower if step < 0 else upper + else: + stop = evaluate_slice_index(slice.stop) + stop = max(stop + length, lower) if stop < 0 else min(stop, upper) + + return start, stop, step + + +# Class providing an __index__ method. Used for testing slice.indices. + +class MyIndexable(object): + def __init__(self, value): + self.value = value + + def __index__(self): + return self.value + + class SliceTest(unittest.TestCase): def test_constructor(self): @@ -75,6 +137,22 @@ class SliceTest(unittest.TestCase): s = slice(obj) self.assertTrue(s.stop is obj) + def check_indices(self, slice, length): + try: + actual = slice.indices(length) + except ValueError: + actual = "valueerror" + try: + expected = slice_indices(slice, length) + except ValueError: + expected = "valueerror" + self.assertEqual(actual, expected) + + if length >= 0 and slice.step != 0: + actual = range(*slice.indices(length)) + expected = range(length)[slice] + self.assertEqual(actual, expected) + def test_indices(self): self.assertEqual(slice(None ).indices(10), (0, 10, 1)) self.assertEqual(slice(None, None, 2).indices(10), (0, 10, 2)) @@ -108,7 +186,41 @@ class SliceTest(unittest.TestCase): self.assertEqual(list(range(10))[::sys.maxsize - 1], [0]) - self.assertRaises(OverflowError, slice(None).indices, 1<<100) + # Check a variety of start, stop, step and length values, including + # values exceeding sys.maxsize (see issue #14794). + vals = [None, -2**100, -2**30, -53, -7, -1, 0, 1, 7, 53, 2**30, 2**100] + lengths = [0, 1, 7, 53, 2**30, 2**100] + for slice_args in itertools.product(vals, repeat=3): + s = slice(*slice_args) + for length in lengths: + self.check_indices(s, length) + self.check_indices(slice(0, 10, 1), -3) + + # Negative length should raise ValueError + with self.assertRaises(ValueError): + slice(None).indices(-1) + + # Zero step should raise ValueError + with self.assertRaises(ValueError): + slice(0, 10, 0).indices(5) + + # Using a start, stop or step or length that can't be interpreted as an + # integer should give a TypeError ... + with self.assertRaises(TypeError): + slice(0.0, 10, 1).indices(5) + with self.assertRaises(TypeError): + slice(0, 10.0, 1).indices(5) + with self.assertRaises(TypeError): + slice(0, 10, 1.0).indices(5) + with self.assertRaises(TypeError): + slice(0, 10, 1).indices(5.0) + + # ... but it should be fine to use a custom class that provides index. + self.assertEqual(slice(0, 10, 1).indices(5), (0, 5, 1)) + self.assertEqual(slice(MyIndexable(0), 10, 1).indices(5), (0, 5, 1)) + self.assertEqual(slice(0, MyIndexable(10), 1).indices(5), (0, 5, 1)) + self.assertEqual(slice(0, 10, MyIndexable(1)).indices(5), (0, 5, 1)) + self.assertEqual(slice(0, 10, 1).indices(MyIndexable(5)), (0, 5, 1)) def test_setslice_without_getslice(self): tmp = [] diff --git a/Lib/test/test_smtplib.py b/Lib/test/test_smtplib.py index 3bbc6b7e53..3f7b9b4175 100644 --- a/Lib/test/test_smtplib.py +++ b/Lib/test/test_smtplib.py @@ -222,7 +222,7 @@ class DebuggingServerTests(unittest.TestCase): self.assertEqual(smtp.source_address, ('127.0.0.1', port)) self.assertEqual(smtp.local_hostname, 'localhost') smtp.quit() - except IOError as e: + except OSError as e: if e.errno == errno.EADDRINUSE: self.skipTest("couldn't bind to port %d" % port) raise @@ -524,12 +524,6 @@ class DebuggingServerTests(unittest.TestCase): class NonConnectingTests(unittest.TestCase): - def setUp(self): - smtplib.socket = mock_socket - - def tearDown(self): - smtplib.socket = socket - def testNotConnected(self): # Test various operations on an unconnected SMTP object that # should raise exceptions (at present the attempt in SMTP.send @@ -541,10 +535,10 @@ class NonConnectingTests(unittest.TestCase): smtp.send, 'test msg') def testNonnumericPort(self): - # check that non-numeric port raises socket.error - self.assertRaises(mock_socket.error, smtplib.SMTP, + # check that non-numeric port raises OSError + self.assertRaises(OSError, smtplib.SMTP, "localhost", "bogus") - self.assertRaises(mock_socket.error, smtplib.SMTP, + self.assertRaises(OSError, smtplib.SMTP, "localhost:bogus") @@ -852,6 +846,15 @@ class SMTPSimTests(unittest.TestCase): self.assertIn(sim_auth_credentials['cram-md5'], str(err)) smtp.close() + def testAUTH_multiple(self): + # Test that multiple authentication methods are tried. + self.serv.add_feature("AUTH BOGUS PLAIN LOGIN CRAM-MD5") + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', timeout=15) + try: smtp.login(sim_auth[0], sim_auth[1]) + except smtplib.SMTPAuthenticationError as err: + self.assertIn(sim_auth_login_password, str(err)) + smtp.close() + def test_with_statement(self): with smtplib.SMTP(HOST, self.port) as smtp: code, message = smtp.noop() diff --git a/Lib/test/test_smtpnet.py b/Lib/test/test_smtpnet.py index e97cf363bb..7f9db2ddf0 100644 --- a/Lib/test/test_smtpnet.py +++ b/Lib/test/test_smtpnet.py @@ -1,23 +1,35 @@ import unittest from test import support import smtplib +import socket ssl = support.import_module("ssl") support.requires("network") +def check_ssl_verifiy(host, port): + context = ssl.create_default_context() + with socket.create_connection((host, port)) as sock: + try: + sock = context.wrap_socket(sock, server_hostname=host) + except Exception: + return False + else: + sock.close() + return True + class SmtpTest(unittest.TestCase): testServer = 'smtp.gmail.com' remotePort = 25 - context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) def test_connect_starttls(self): support.get_attribute(smtplib, 'SMTP_SSL') + context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) with support.transient_internet(self.testServer): server = smtplib.SMTP(self.testServer, self.remotePort) try: - server.starttls(context=self.context) + server.starttls(context=context) except smtplib.SMTPException as e: if e.args[0] == 'STARTTLS extension not supported by server.': unittest.skip(e.args[0]) @@ -30,7 +42,7 @@ class SmtpTest(unittest.TestCase): class SmtpSSLTest(unittest.TestCase): testServer = 'smtp.gmail.com' remotePort = 465 - context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + can_verify = check_ssl_verifiy(testServer, remotePort) def test_connect(self): support.get_attribute(smtplib, 'SMTP_SSL') @@ -47,9 +59,19 @@ class SmtpSSLTest(unittest.TestCase): server.quit() def test_connect_using_sslcontext(self): + context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + support.get_attribute(smtplib, 'SMTP_SSL') + with support.transient_internet(self.testServer): + server = smtplib.SMTP_SSL(self.testServer, self.remotePort, context=context) + server.ehlo() + server.quit() + + @unittest.skipUnless(can_verify, "SSL certificate can't be verified") + def test_connect_using_sslcontext_verified(self): support.get_attribute(smtplib, 'SMTP_SSL') + context = ssl.create_default_context() with support.transient_internet(self.testServer): - server = smtplib.SMTP_SSL(self.testServer, self.remotePort, context=self.context) + server = smtplib.SMTP_SSL(self.testServer, self.remotePort, context=context) server.ehlo() server.quit() diff --git a/Lib/test/test_sndhdr.py b/Lib/test/test_sndhdr.py index 10046887d7..5e0abe0b36 100644 --- a/Lib/test/test_sndhdr.py +++ b/Lib/test/test_sndhdr.py @@ -12,7 +12,7 @@ class TestFormats(unittest.TestCase): ('sndhdr.hcom', ('hcom', 22050.0, 1, -1, 8)), ('sndhdr.sndt', ('sndt', 44100, 1, 5, 8)), ('sndhdr.voc', ('voc', 0, 1, -1, 8)), - ('sndhdr.wav', ('wav', 44100, 2, -1, 16)), + ('sndhdr.wav', ('wav', 44100, 2, 5, 16)), ): filename = findfile(filename, subdir="sndhdrdata") what = sndhdr.what(filename) diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py index 450aee1d43..e94f5396c5 100644 --- a/Lib/test/test_socket.py +++ b/Lib/test/test_socket.py @@ -1,6 +1,5 @@ import unittest from test import support -from unittest.case import _ExpectedFailure import errno import io @@ -21,13 +20,13 @@ import math import pickle import struct try: - import fcntl -except ImportError: - fcntl = False -try: import multiprocessing except ImportError: multiprocessing = False +try: + import fcntl +except ImportError: + fcntl = None HOST = support.HOST MSG = 'Michael Gilfix was here\u1234\r\n'.encode('utf-8') ## test unicode string and carriage return @@ -43,7 +42,7 @@ def _have_socket_can(): """Check whether CAN sockets are supported on this host.""" try: s = socket.socket(socket.PF_CAN, socket.SOCK_RAW, socket.CAN_RAW) - except (AttributeError, socket.error, OSError): + except (AttributeError, OSError): return False else: s.close() @@ -118,12 +117,42 @@ class SocketCANTest(unittest.TestCase): interface = 'vcan0' bufsize = 128 + """The CAN frame structure is defined in <linux/can.h>: + + struct can_frame { + canid_t can_id; /* 32 bit CAN_ID + EFF/RTR/ERR flags */ + __u8 can_dlc; /* data length code: 0 .. 8 */ + __u8 data[8] __attribute__((aligned(8))); + }; + """ + can_frame_fmt = "=IB3x8s" + can_frame_size = struct.calcsize(can_frame_fmt) + + """The Broadcast Management Command frame structure is defined + in <linux/can/bcm.h>: + + struct bcm_msg_head { + __u32 opcode; + __u32 flags; + __u32 count; + struct timeval ival1, ival2; + canid_t can_id; + __u32 nframes; + struct can_frame frames[0]; + } + + `bcm_msg_head` must be 8 bytes aligned because of the `frames` member (see + `struct can_frame` definition). Must use native not standard types for packing. + """ + bcm_cmd_msg_fmt = "@3I4l2I" + bcm_cmd_msg_fmt += "x" * (struct.calcsize(bcm_cmd_msg_fmt) % 8) + def setUp(self): self.s = socket.socket(socket.PF_CAN, socket.SOCK_RAW, socket.CAN_RAW) self.addCleanup(self.s.close) try: self.s.bind((self.interface,)) - except socket.error: + except OSError: self.skipTest('network interface `%s` does not exist' % self.interface) @@ -239,9 +268,6 @@ class ThreadableTest: raise TypeError("test_func must be a callable function") try: test_func() - except _ExpectedFailure: - # We deliberately ignore expected failures - pass except BaseException as e: self.queue.put(e) finally: @@ -292,7 +318,7 @@ class ThreadedCANSocketTest(SocketCANTest, ThreadableTest): self.cli = socket.socket(socket.PF_CAN, socket.SOCK_RAW, socket.CAN_RAW) try: self.cli.bind((self.interface,)) - except socket.error: + except OSError: # skipTest should not be called here, and will be called in the # server instead pass @@ -540,11 +566,7 @@ class SCTPStreamBase(InetTestBase): class Inet6TestBase(InetTestBase): """Base class for IPv6 socket tests.""" - # Don't use "localhost" here - it may not have an IPv6 address - # assigned to it by default (e.g. in /etc/hosts), and if someone - # has assigned it an IPv4-mapped address, then it's unlikely to - # work with the full IPv6 API. - host = "::1" + host = support.HOSTv6 class UDP6TestBase(Inet6TestBase): """Base class for UDP-over-IPv6 tests.""" @@ -605,7 +627,7 @@ def requireSocket(*args): for obj in args] try: s = socket.socket(*callargs) - except socket.error as e: + except OSError as e: # XXX: check errno? err = str(e) else: @@ -623,8 +645,17 @@ class GeneralModuleTests(unittest.TestCase): def test_repr(self): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.addCleanup(s.close) - self.assertTrue(repr(s).startswith("<socket.socket object")) + with s: + self.assertIn('fd=%i' % s.fileno(), repr(s)) + self.assertIn('family=%s' % socket.AF_INET, repr(s)) + self.assertIn('type=%s' % socket.SOCK_STREAM, repr(s)) + self.assertIn('proto=0', repr(s)) + self.assertNotIn('raddr', repr(s)) + s.bind(('127.0.0.1', 0)) + self.assertIn('laddr', repr(s)) + self.assertIn(str(s.getsockname()), repr(s)) + self.assertIn('[closed]', repr(s)) + self.assertNotIn('laddr', repr(s)) def test_weakref(self): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) @@ -642,11 +673,11 @@ class GeneralModuleTests(unittest.TestCase): def testSocketError(self): # Testing socket module exceptions msg = "Error raising socket exception (%s)." - with self.assertRaises(socket.error, msg=msg % 'socket.error'): - raise socket.error - with self.assertRaises(socket.error, msg=msg % 'socket.herror'): + with self.assertRaises(OSError, msg=msg % 'OSError'): + raise OSError + with self.assertRaises(OSError, msg=msg % 'socket.herror'): raise socket.herror - with self.assertRaises(socket.error, msg=msg % 'socket.gaierror'): + with self.assertRaises(OSError, msg=msg % 'socket.gaierror'): raise socket.gaierror def testSendtoErrors(self): @@ -709,13 +740,13 @@ class GeneralModuleTests(unittest.TestCase): hostname = socket.gethostname() try: ip = socket.gethostbyname(hostname) - except socket.error: + except OSError: # Probably name lookup wasn't set up right; skip this test self.skipTest('name lookup failure') self.assertTrue(ip.find('.') >= 0, "Error resolving host to ip.") try: hname, aliases, ipaddrs = socket.gethostbyaddr(ip) - except socket.error: + except OSError: # Probably a similar problem as above; skip this test self.skipTest('name lookup failure') all_host_names = [hostname, hname] + aliases @@ -723,13 +754,27 @@ class GeneralModuleTests(unittest.TestCase): if not fqhn in all_host_names: self.fail("Error testing host resolution mechanisms. (fqdn: %s, all: %s)" % (fqhn, repr(all_host_names))) + def test_host_resolution(self): + for addr in ['0.1.1.~1', '1+.1.1.1', '::1q', '::1::2', + '1:1:1:1:1:1:1:1:1']: + self.assertRaises(OSError, socket.gethostbyname, addr) + self.assertRaises(OSError, socket.gethostbyaddr, addr) + + for addr in [support.HOST, '10.0.0.1', '255.255.255.255']: + self.assertEqual(socket.gethostbyname(addr), addr) + + # we don't test support.HOSTv6 because there's a chance it doesn't have + # a matching name entry (e.g. 'ip6-localhost') + for host in [support.HOST]: + self.assertIn(host, socket.gethostbyaddr(host)[2]) + @unittest.skipUnless(hasattr(socket, 'sethostname'), "test needs socket.sethostname()") @unittest.skipUnless(hasattr(socket, 'gethostname'), "test needs socket.gethostname()") def test_sethostname(self): oldhn = socket.gethostname() try: socket.sethostname('new') - except socket.error as e: + except OSError as e: if e.errno == errno.EPERM: self.skipTest("test should be run as root") else: @@ -763,8 +808,8 @@ class GeneralModuleTests(unittest.TestCase): 'socket.if_nameindex() not available.') def testInvalidInterfaceNameIndex(self): # test nonexistent interface index/name - self.assertRaises(socket.error, socket.if_indextoname, 0) - self.assertRaises(socket.error, socket.if_nametoindex, '_DEADBEEF') + self.assertRaises(OSError, socket.if_indextoname, 0) + self.assertRaises(OSError, socket.if_nametoindex, '_DEADBEEF') # test with invalid values self.assertRaises(TypeError, socket.if_nametoindex, 0) self.assertRaises(TypeError, socket.if_indextoname, '_DEADBEEF') @@ -786,7 +831,7 @@ class GeneralModuleTests(unittest.TestCase): try: # On some versions, this crashes the interpreter. socket.getnameinfo(('x', 0, 0, 0), 0) - except socket.error: + except OSError: pass def testNtoH(self): @@ -833,17 +878,17 @@ class GeneralModuleTests(unittest.TestCase): try: port = socket.getservbyname(service, 'tcp') break - except socket.error: + except OSError: pass else: - raise socket.error + raise OSError # Try same call with optional protocol omitted port2 = socket.getservbyname(service) eq(port, port2) # Try udp, but don't barf if it doesn't exist try: udpport = socket.getservbyname(service, 'udp') - except socket.error: + except OSError: udpport = None else: eq(udpport, port) @@ -899,7 +944,7 @@ class GeneralModuleTests(unittest.TestCase): g = lambda a: inet_pton(AF_INET, a) assertInvalid = lambda func,a: self.assertRaises( - (socket.error, ValueError), func, a + (OSError, ValueError), func, a ) self.assertEqual(b'\x00\x00\x00\x00', f('0.0.0.0')) @@ -932,9 +977,17 @@ class GeneralModuleTests(unittest.TestCase): self.skipTest('IPv6 not available') except ImportError: self.skipTest('could not import needed symbols from socket') + + if sys.platform == "win32": + try: + inet_pton(AF_INET6, '::') + except OSError as e: + if e.winerror == 10022: + self.skipTest('IPv6 might not be supported') + f = lambda a: inet_pton(AF_INET6, a) assertInvalid = lambda a: self.assertRaises( - (socket.error, ValueError), f, a + (OSError, ValueError), f, a ) self.assertEqual(b'\x00' * 16, f('::')) @@ -983,7 +1036,7 @@ class GeneralModuleTests(unittest.TestCase): from socket import inet_ntoa as f, inet_ntop, AF_INET g = lambda a: inet_ntop(AF_INET, a) assertInvalid = lambda func,a: self.assertRaises( - (socket.error, ValueError), func, a + (OSError, ValueError), func, a ) self.assertEqual('1.0.1.0', f(b'\x01\x00\x01\x00')) @@ -1010,9 +1063,17 @@ class GeneralModuleTests(unittest.TestCase): self.skipTest('IPv6 not available') except ImportError: self.skipTest('could not import needed symbols from socket') + + if sys.platform == "win32": + try: + inet_ntop(AF_INET6, b'\x00' * 16) + except OSError as e: + if e.winerror == 10022: + self.skipTest('IPv6 might not be supported') + f = lambda a: inet_ntop(AF_INET6, a) assertInvalid = lambda a: self.assertRaises( - (socket.error, ValueError), f, a + (OSError, ValueError), f, a ) self.assertEqual('::', f(b'\x00' * 16)) @@ -1040,7 +1101,7 @@ class GeneralModuleTests(unittest.TestCase): # At least for eCos. This is required for the S/390 to pass. try: my_ip_addr = socket.gethostbyname(socket.gethostname()) - except socket.error: + except OSError: # Probably name lookup wasn't set up right; skip this test self.skipTest('name lookup failure') self.assertIn(name[0], ("0.0.0.0", my_ip_addr), '%s invalid' % name[0]) @@ -1067,13 +1128,19 @@ class GeneralModuleTests(unittest.TestCase): sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.settimeout(1) sock.close() - self.assertRaises(socket.error, sock.send, b"spam") + self.assertRaises(OSError, sock.send, b"spam") def testNewAttributes(self): # testing .family, .type and .protocol + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.assertEqual(sock.family, socket.AF_INET) - self.assertEqual(sock.type, socket.SOCK_STREAM) + if hasattr(socket, 'SOCK_CLOEXEC'): + self.assertIn(sock.type, + (socket.SOCK_STREAM | socket.SOCK_CLOEXEC, + socket.SOCK_STREAM)) + else: + self.assertEqual(sock.type, socket.SOCK_STREAM) self.assertEqual(sock.proto, 0) sock.close() @@ -1126,9 +1193,12 @@ class GeneralModuleTests(unittest.TestCase): socket.getaddrinfo(HOST, 80) socket.getaddrinfo(HOST, None) # test family and socktype filters - infos = socket.getaddrinfo(HOST, None, socket.AF_INET) - for family, _, _, _, _ in infos: + infos = socket.getaddrinfo(HOST, 80, socket.AF_INET, socket.SOCK_STREAM) + for family, type, _, _, _ in infos: self.assertEqual(family, socket.AF_INET) + self.assertEqual(str(family), 'AddressFamily.AF_INET') + self.assertEqual(type, socket.SOCK_STREAM) + self.assertEqual(str(type), 'SocketType.SOCK_STREAM') infos = socket.getaddrinfo(HOST, None, 0, socket.SOCK_STREAM) for _, socktype, _, _, _ in infos: self.assertEqual(socktype, socket.SOCK_STREAM) @@ -1176,7 +1246,7 @@ class GeneralModuleTests(unittest.TestCase): def test_getnameinfo(self): # only IP addresses are allowed - self.assertRaises(socket.error, socket.getnameinfo, ('mail.python.org',0), 0) + self.assertRaises(OSError, socket.getnameinfo, ('mail.python.org',0), 0) @unittest.skipUnless(support.is_resource_enabled('network'), 'network is not enabled') @@ -1291,10 +1361,31 @@ class GeneralModuleTests(unittest.TestCase): @unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 required for this test.') def test_flowinfo(self): self.assertRaises(OverflowError, socket.getnameinfo, - ('::1',0, 0xffffffff), 0) + (support.HOSTv6, 0, 0xffffffff), 0) with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s: - self.assertRaises(OverflowError, s.bind, ('::1', 0, -10)) - + self.assertRaises(OverflowError, s.bind, (support.HOSTv6, 0, -10)) + + def test_str_for_enums(self): + # Make sure that the AF_* and SOCK_* constants have enum-like string + # reprs. + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + self.assertEqual(str(s.family), 'AddressFamily.AF_INET') + self.assertEqual(str(s.type), 'SocketType.SOCK_STREAM') + + @unittest.skipIf(os.name == 'nt', 'Will not work on Windows') + def test_uknown_socket_family_repr(self): + # Test that when created with a family that's not one of the known + # AF_*/SOCK_* constants, socket.family just returns the number. + # + # To do this we fool socket.socket into believing it already has an + # open fd because on this path it doesn't actually verify the family and + # type and populates the socket object. + # + # On Windows this trick won't work, so the test is skipped. + fd, _ = tempfile.mkstemp() + with socket.socket(family=42424, type=13331, fileno=fd) as s: + self.assertEqual(s.family, 42424) + self.assertEqual(s.type, 13331) @unittest.skipUnless(HAVE_SOCKET_CAN, 'SocketCan required for this test.') class BasicCANTest(unittest.TestCase): @@ -1304,10 +1395,35 @@ class BasicCANTest(unittest.TestCase): socket.PF_CAN socket.CAN_RAW + @unittest.skipUnless(hasattr(socket, "CAN_BCM"), + 'socket.CAN_BCM required for this test.') + def testBCMConstants(self): + socket.CAN_BCM + + # opcodes + socket.CAN_BCM_TX_SETUP # create (cyclic) transmission task + socket.CAN_BCM_TX_DELETE # remove (cyclic) transmission task + socket.CAN_BCM_TX_READ # read properties of (cyclic) transmission task + socket.CAN_BCM_TX_SEND # send one CAN frame + socket.CAN_BCM_RX_SETUP # create RX content filter subscription + socket.CAN_BCM_RX_DELETE # remove RX content filter subscription + socket.CAN_BCM_RX_READ # read properties of RX content filter subscription + socket.CAN_BCM_TX_STATUS # reply to TX_READ request + socket.CAN_BCM_TX_EXPIRED # notification on performed transmissions (count=0) + socket.CAN_BCM_RX_STATUS # reply to RX_READ request + socket.CAN_BCM_RX_TIMEOUT # cyclic message is absent + socket.CAN_BCM_RX_CHANGED # updated CAN frame (detected content change) + def testCreateSocket(self): with socket.socket(socket.PF_CAN, socket.SOCK_RAW, socket.CAN_RAW) as s: pass + @unittest.skipUnless(hasattr(socket, "CAN_BCM"), + 'socket.CAN_BCM required for this test.') + def testCreateBCMSocket(self): + with socket.socket(socket.PF_CAN, socket.SOCK_DGRAM, socket.CAN_BCM) as s: + pass + def testBindAny(self): with socket.socket(socket.PF_CAN, socket.SOCK_RAW, socket.CAN_RAW) as s: s.bind(('', )) @@ -1315,7 +1431,7 @@ class BasicCANTest(unittest.TestCase): def testTooLongInterfaceName(self): # most systems limit IFNAMSIZ to 16, take 1024 to be sure with socket.socket(socket.PF_CAN, socket.SOCK_RAW, socket.CAN_RAW) as s: - self.assertRaisesRegex(socket.error, 'interface name too long', + self.assertRaisesRegex(OSError, 'interface name too long', s.bind, ('x' * 1024,)) @unittest.skipUnless(hasattr(socket, "CAN_RAW_LOOPBACK"), @@ -1340,19 +1456,8 @@ class BasicCANTest(unittest.TestCase): @unittest.skipUnless(HAVE_SOCKET_CAN, 'SocketCan required for this test.') -@unittest.skipUnless(thread, 'Threading required for this test.') class CANTest(ThreadedCANSocketTest): - """The CAN frame structure is defined in <linux/can.h>: - - struct can_frame { - canid_t can_id; /* 32 bit CAN_ID + EFF/RTR/ERR flags */ - __u8 can_dlc; /* data length code: 0 .. 8 */ - __u8 data[8] __attribute__((aligned(8))); - }; - """ - can_frame_fmt = "=IB3x8s" - def __init__(self, methodName='runTest'): ThreadedCANSocketTest.__init__(self, methodName=methodName) @@ -1401,6 +1506,46 @@ class CANTest(ThreadedCANSocketTest): self.cf2 = self.build_can_frame(0x12, b'\x99\x22\x33') self.cli.send(self.cf2) + @unittest.skipUnless(hasattr(socket, "CAN_BCM"), + 'socket.CAN_BCM required for this test.') + def _testBCM(self): + cf, addr = self.cli.recvfrom(self.bufsize) + self.assertEqual(self.cf, cf) + can_id, can_dlc, data = self.dissect_can_frame(cf) + self.assertEqual(self.can_id, can_id) + self.assertEqual(self.data, data) + + @unittest.skipUnless(hasattr(socket, "CAN_BCM"), + 'socket.CAN_BCM required for this test.') + def testBCM(self): + bcm = socket.socket(socket.PF_CAN, socket.SOCK_DGRAM, socket.CAN_BCM) + self.addCleanup(bcm.close) + bcm.connect((self.interface,)) + self.can_id = 0x123 + self.data = bytes([0xc0, 0xff, 0xee]) + self.cf = self.build_can_frame(self.can_id, self.data) + opcode = socket.CAN_BCM_TX_SEND + flags = 0 + count = 0 + ival1_seconds = ival1_usec = ival2_seconds = ival2_usec = 0 + bcm_can_id = 0x0222 + nframes = 1 + assert len(self.cf) == 16 + header = struct.pack(self.bcm_cmd_msg_fmt, + opcode, + flags, + count, + ival1_seconds, + ival1_usec, + ival2_seconds, + ival2_usec, + bcm_can_id, + nframes, + ) + header_plus_frame = header + self.cf + bytes_sent = bcm.send(header_plus_frame) + self.assertEqual(bytes_sent, len(header_plus_frame)) + @unittest.skipUnless(HAVE_SOCKET_RDS, 'RDS sockets required for this test.') class BasicRDSTest(unittest.TestCase): @@ -1623,7 +1768,7 @@ class BasicTCPTest(SocketConnectedTest): self.assertEqual(f, fileno) # cli_conn cannot be used anymore... self.assertTrue(self.cli_conn._closed) - self.assertRaises(socket.error, self.cli_conn.recv, 1024) + self.assertRaises(OSError, self.cli_conn.recv, 1024) self.cli_conn.close() # ...but we can create another socket using the (still open) # file descriptor @@ -1992,7 +2137,7 @@ class SendmsgTests(SendrecvmsgServerTimeoutBase): def _testSendmsgExcessCmsgReject(self): if not hasattr(socket, "CMSG_SPACE"): # Can only send one item - with self.assertRaises(socket.error) as cm: + with self.assertRaises(OSError) as cm: self.sendmsgToServer([MSG], [(0, 0, b""), (0, 0, b"")]) self.assertIsNone(cm.exception.errno) self.sendToServer(b"done") @@ -2003,7 +2148,7 @@ class SendmsgTests(SendrecvmsgServerTimeoutBase): def _testSendmsgAfterClose(self): self.cli_sock.close() - self.assertRaises(socket.error, self.sendmsgToServer, [MSG]) + self.assertRaises(OSError, self.sendmsgToServer, [MSG]) class SendmsgStreamTests(SendmsgTests): @@ -2047,7 +2192,7 @@ class SendmsgStreamTests(SendmsgTests): @testSendmsgDontWait.client_skip def _testSendmsgDontWait(self): try: - with self.assertRaises(socket.error) as cm: + with self.assertRaises(OSError) as cm: while True: self.sendmsgToServer([b"a"*512], [], socket.MSG_DONTWAIT) self.assertIn(cm.exception.errno, @@ -2067,9 +2212,9 @@ class SendmsgConnectionlessTests(SendmsgTests): pass def _testSendmsgNoDestAddr(self): - self.assertRaises(socket.error, self.cli_sock.sendmsg, + self.assertRaises(OSError, self.cli_sock.sendmsg, [MSG]) - self.assertRaises(socket.error, self.cli_sock.sendmsg, + self.assertRaises(OSError, self.cli_sock.sendmsg, [MSG], [], 0, None) @@ -2155,7 +2300,7 @@ class RecvmsgGenericTests(SendrecvmsgBase): def testRecvmsgAfterClose(self): # Check that recvmsg[_into]() fails on a closed socket. self.serv_sock.close() - self.assertRaises(socket.error, self.doRecvmsg, self.serv_sock, 1024) + self.assertRaises(OSError, self.doRecvmsg, self.serv_sock, 1024) def _testRecvmsgAfterClose(self): pass @@ -2606,7 +2751,7 @@ class SCMRightsTest(SendrecvmsgServerTimeoutBase): # call fails, just send msg with no ancillary data. try: nbytes = self.sendmsgToServer([msg], ancdata) - except socket.error as e: + except OSError as e: # Check that it was the system call that failed self.assertIsInstance(e.errno, int) nbytes = self.sendmsgToServer([msg]) @@ -2984,7 +3129,7 @@ class RFC3542AncillaryTest(SendrecvmsgServerTimeoutBase): array.array("i", [self.traffic_class]).tobytes() + b"\x00"), (socket.IPPROTO_IPV6, socket.IPV6_HOPLIMIT, array.array("i", [self.hop_limit]))]) - except socket.error as e: + except OSError as e: self.assertIsInstance(e.errno, int) nbytes = self.sendmsgToServer( [MSG], @@ -3442,10 +3587,10 @@ class InterruptedRecvTimeoutTest(InterruptedTimeoutBase, UDPTestBase): self.serv.settimeout(self.timeout) def checkInterruptedRecv(self, func, *args, **kwargs): - # Check that func(*args, **kwargs) raises socket.error with an + # Check that func(*args, **kwargs) raises OSError with an # errno of EINTR when interrupted by a signal. self.setAlarm(self.alarm_time) - with self.assertRaises(socket.error) as cm: + with self.assertRaises(OSError) as cm: func(*args, **kwargs) self.assertNotIsInstance(cm.exception, socket.timeout) self.assertEqual(cm.exception.errno, errno.EINTR) @@ -3502,9 +3647,9 @@ class InterruptedSendTimeoutTest(InterruptedTimeoutBase, def checkInterruptedSend(self, func, *args, **kwargs): # Check that func(*args, **kwargs), run in a loop, raises - # socket.error with an errno of EINTR when interrupted by a + # OSError with an errno of EINTR when interrupted by a # signal. - with self.assertRaises(socket.error) as cm: + with self.assertRaises(OSError) as cm: while True: self.setAlarm(self.alarm_time) func(*args, **kwargs) @@ -3603,7 +3748,7 @@ class NonBlockingTCPTests(ThreadedTCPSocketTest): start = time.time() try: self.serv.accept() - except socket.error: + except OSError: pass end = time.time() self.assertTrue((end - start) < 1.0, "Error setting non-blocking mode.") @@ -3638,7 +3783,7 @@ class NonBlockingTCPTests(ThreadedTCPSocketTest): start = time.time() try: self.serv.accept() - except socket.error: + except OSError: pass end = time.time() self.assertTrue((end - start) < 1.0, "Error creating with non-blocking mode.") @@ -3668,7 +3813,7 @@ class NonBlockingTCPTests(ThreadedTCPSocketTest): self.serv.setblocking(0) try: conn, addr = self.serv.accept() - except socket.error: + except OSError: pass else: self.fail("Error trying to do non-blocking accept.") @@ -3698,7 +3843,7 @@ class NonBlockingTCPTests(ThreadedTCPSocketTest): conn.setblocking(0) try: msg = conn.recv(len(MSG)) - except socket.error: + except OSError: pass else: self.fail("Error trying to do non-blocking recv.") @@ -3781,7 +3926,7 @@ class FileObjectClassTestCase(SocketConnectedTest): # First read raises a timeout self.assertRaises(socket.timeout, self.read_file.read, 1) # Second read is disallowed - with self.assertRaises(IOError) as ctx: + with self.assertRaises(OSError) as ctx: self.read_file.read(1) self.assertIn("cannot read from timed out object", str(ctx.exception)) @@ -3873,7 +4018,7 @@ class FileObjectClassTestCase(SocketConnectedTest): self.read_file.close() self.assertRaises(ValueError, self.read_file.fileno) self.cli_conn.close() - self.assertRaises(socket.error, self.cli_conn.getsockname) + self.assertRaises(OSError, self.cli_conn.getsockname) def _testRealClose(self): pass @@ -3910,7 +4055,7 @@ class FileObjectInterruptedTestCase(unittest.TestCase): @staticmethod def _raise_eintr(): - raise socket.error(errno.EINTR, "interrupted") + raise OSError(errno.EINTR, "interrupted") def _textiowrap_mock_socket(self, mock, buffering=-1): raw = socket.SocketIO(mock, "r") @@ -4022,7 +4167,7 @@ class UnbufferedFileObjectClassTestCase(FileObjectClassTestCase): self.assertEqual(msg, self.read_msg) # ...until the file is itself closed self.read_file.close() - self.assertRaises(socket.error, self.cli_conn.recv, 1024) + self.assertRaises(OSError, self.cli_conn.recv, 1024) def _testMakefileClose(self): self.write_file.write(self.write_msg) @@ -4171,7 +4316,7 @@ class NetworkConnectionNoServer(unittest.TestCase): port = support.find_unused_port() cli = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.addCleanup(cli.close) - with self.assertRaises(socket.error) as cm: + with self.assertRaises(OSError) as cm: cli.connect((HOST, port)) self.assertEqual(cm.exception.errno, errno.ECONNREFUSED) @@ -4179,7 +4324,7 @@ class NetworkConnectionNoServer(unittest.TestCase): # Issue #9792: errors raised by create_connection() should have # a proper errno attribute. port = support.find_unused_port() - with self.assertRaises(socket.error) as cm: + with self.assertRaises(OSError) as cm: socket.create_connection((HOST, port)) # Issue #16257: create_connection() calls getaddrinfo() against @@ -4327,7 +4472,7 @@ class TCPTimeoutTest(SocketTCPTest): foo = self.serv.accept() except socket.timeout: self.fail("caught timeout instead of error (TCP)") - except socket.error: + except OSError: ok = True except: self.fail("caught unexpected exception (TCP)") @@ -4384,7 +4529,7 @@ class UDPTimeoutTest(SocketUDPTest): foo = self.serv.recv(1024) except socket.timeout: self.fail("caught timeout instead of error (UDP)") - except socket.error: + except OSError: ok = True except: self.fail("caught unexpected exception (UDP)") @@ -4394,10 +4539,10 @@ class UDPTimeoutTest(SocketUDPTest): class TestExceptions(unittest.TestCase): def testExceptionTree(self): - self.assertTrue(issubclass(socket.error, Exception)) - self.assertTrue(issubclass(socket.herror, socket.error)) - self.assertTrue(issubclass(socket.gaierror, socket.error)) - self.assertTrue(issubclass(socket.timeout, socket.error)) + self.assertTrue(issubclass(OSError, Exception)) + self.assertTrue(issubclass(socket.herror, OSError)) + self.assertTrue(issubclass(socket.gaierror, OSError)) + self.assertTrue(issubclass(socket.timeout, OSError)) @unittest.skipUnless(sys.platform == 'linux', 'Linux specific test') class TestLinuxAbstractNamespace(unittest.TestCase): @@ -4424,7 +4569,7 @@ class TestLinuxAbstractNamespace(unittest.TestCase): def testNameOverflow(self): address = "\x00" + "h" * self.UNIX_PATH_MAX with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as s: - self.assertRaises(socket.error, s.bind, address) + self.assertRaises(OSError, s.bind, address) def testStrName(self): # Check that an abstract name can be passed as a string. @@ -4681,7 +4826,7 @@ class ContextManagersTest(ThreadedTCPSocketTest): self.assertTrue(sock._closed) # exception inside with block with socket.socket() as sock: - self.assertRaises(socket.error, sock.sendall, b'foo') + self.assertRaises(OSError, sock.sendall, b'foo') self.assertTrue(sock._closed) def testCreateConnectionBase(self): @@ -4709,19 +4854,76 @@ class ContextManagersTest(ThreadedTCPSocketTest): with socket.create_connection(address) as sock: sock.close() self.assertTrue(sock._closed) - self.assertRaises(socket.error, sock.sendall, b'foo') + self.assertRaises(OSError, sock.sendall, b'foo') -@unittest.skipUnless(hasattr(socket, "SOCK_CLOEXEC"), - "SOCK_CLOEXEC not defined") -@unittest.skipUnless(fcntl, "module fcntl not available") -class CloexecConstantTest(unittest.TestCase): +class InheritanceTest(unittest.TestCase): + @unittest.skipUnless(hasattr(socket, "SOCK_CLOEXEC"), + "SOCK_CLOEXEC not defined") @support.requires_linux_version(2, 6, 28) def test_SOCK_CLOEXEC(self): with socket.socket(socket.AF_INET, socket.SOCK_STREAM | socket.SOCK_CLOEXEC) as s: self.assertTrue(s.type & socket.SOCK_CLOEXEC) - self.assertTrue(fcntl.fcntl(s, fcntl.F_GETFD) & fcntl.FD_CLOEXEC) + self.assertFalse(s.get_inheritable()) + + def test_default_inheritable(self): + sock = socket.socket() + with sock: + self.assertEqual(sock.get_inheritable(), False) + + def test_dup(self): + sock = socket.socket() + with sock: + newsock = sock.dup() + sock.close() + with newsock: + self.assertEqual(newsock.get_inheritable(), False) + + def test_set_inheritable(self): + sock = socket.socket() + with sock: + sock.set_inheritable(True) + self.assertEqual(sock.get_inheritable(), True) + + sock.set_inheritable(False) + self.assertEqual(sock.get_inheritable(), False) + + @unittest.skipIf(fcntl is None, "need fcntl") + def test_get_inheritable_cloexec(self): + sock = socket.socket() + with sock: + fd = sock.fileno() + self.assertEqual(sock.get_inheritable(), False) + + # clear FD_CLOEXEC flag + flags = fcntl.fcntl(fd, fcntl.F_GETFD) + flags &= ~fcntl.FD_CLOEXEC + fcntl.fcntl(fd, fcntl.F_SETFD, flags) + + self.assertEqual(sock.get_inheritable(), True) + + @unittest.skipIf(fcntl is None, "need fcntl") + def test_set_inheritable_cloexec(self): + sock = socket.socket() + with sock: + fd = sock.fileno() + self.assertEqual(fcntl.fcntl(fd, fcntl.F_GETFD) & fcntl.FD_CLOEXEC, + fcntl.FD_CLOEXEC) + + sock.set_inheritable(True) + self.assertEqual(fcntl.fcntl(fd, fcntl.F_GETFD) & fcntl.FD_CLOEXEC, + 0) + + + @unittest.skipUnless(hasattr(socket, "socketpair"), + "need socket.socketpair()") + def test_socketpair(self): + s1, s2 = socket.socketpair() + self.addCleanup(s1.close) + self.addCleanup(s2.close) + self.assertEqual(s1.get_inheritable(), False) + self.assertEqual(s2.get_inheritable(), False) @unittest.skipUnless(hasattr(socket, "SOCK_NONBLOCK"), @@ -4890,7 +5092,7 @@ def test_main(): NetworkConnectionAttributesTest, NetworkConnectionBehaviourTest, ContextManagersTest, - CloexecConstantTest, + InheritanceTest, NonblockConstantTest ]) tests.append(BasicSocketPairTest) diff --git a/Lib/test/test_socketserver.py b/Lib/test/test_socketserver.py index 59d8e5dcb7..0617b30a69 100644 --- a/Lib/test/test_socketserver.py +++ b/Lib/test/test_socketserver.py @@ -2,8 +2,8 @@ Test suite for socketserver. """ +import _imp as imp import contextlib -import imp import os import select import signal @@ -29,7 +29,7 @@ HOST = test.support.HOST HAVE_UNIX_SOCKETS = hasattr(socket, "AF_UNIX") requires_unix_sockets = unittest.skipUnless(HAVE_UNIX_SOCKETS, 'requires Unix sockets') -HAVE_FORKING = hasattr(os, "fork") and os.name != "os2" +HAVE_FORKING = hasattr(os, "fork") requires_forking = unittest.skipUnless(HAVE_FORKING, 'requires forking') def signal_alarm(n): @@ -85,7 +85,7 @@ class SocketServerTest(unittest.TestCase): for fn in self.test_files: try: os.remove(fn) - except os.error: + except OSError: pass self.test_files[:] = [] @@ -96,21 +96,7 @@ class SocketServerTest(unittest.TestCase): # XXX: We need a way to tell AF_UNIX to pick its own name # like AF_INET provides port==0. dir = None - if os.name == 'os2': - dir = '\socket' fn = tempfile.mktemp(prefix='unix_socket.', dir=dir) - if os.name == 'os2': - # AF_UNIX socket names on OS/2 require a specific prefix - # which can't include a drive letter and must also use - # backslashes as directory separators - if fn[1] == ':': - fn = fn[2:] - if fn[0] in (os.sep, os.altsep): - fn = fn[1:] - if os.sep == '/': - fn = fn.replace(os.sep, os.altsep) - else: - fn = fn.replace(os.altsep, os.sep) self.test_files.append(fn) return fn diff --git a/Lib/test/test_pep263.py b/Lib/test/test_source_encoding.py index 324ae38619..cd9d2b374c 100644 --- a/Lib/test/test_pep263.py +++ b/Lib/test/test_source_encoding.py @@ -1,9 +1,12 @@ # -*- coding: koi8-r -*- import unittest -from test import support +from test.support import TESTFN, unlink, unload +import importlib +import os +import sys -class PEP263Test(unittest.TestCase): +class SourceEncodingTest(unittest.TestCase): def test_pep263(self): self.assertEqual( @@ -72,9 +75,61 @@ class PEP263Test(unittest.TestCase): with self.assertRaisesRegex(SyntaxError, 'BOM'): compile(b'\xef\xbb\xbf# -*- coding: fake -*-\n', 'dummy', 'exec') + def test_bad_coding(self): + module_name = 'bad_coding' + self.verify_bad_module(module_name) -def test_main(): - support.run_unittest(PEP263Test) + def test_bad_coding2(self): + module_name = 'bad_coding2' + self.verify_bad_module(module_name) -if __name__=="__main__": - test_main() + def verify_bad_module(self, module_name): + self.assertRaises(SyntaxError, __import__, 'test.' + module_name) + + path = os.path.dirname(__file__) + filename = os.path.join(path, module_name + '.py') + with open(filename, "rb") as fp: + bytes = fp.read() + self.assertRaises(SyntaxError, compile, bytes, filename, 'exec') + + def test_exec_valid_coding(self): + d = {} + exec(b'# coding: cp949\na = "\xaa\xa7"\n', d) + self.assertEqual(d['a'], '\u3047') + + def test_file_parse(self): + # issue1134: all encodings outside latin-1 and utf-8 fail on + # multiline strings and long lines (>512 columns) + unload(TESTFN) + filename = TESTFN + ".py" + f = open(filename, "w", encoding="cp1252") + sys.path.insert(0, os.curdir) + try: + with f: + f.write("# -*- coding: cp1252 -*-\n") + f.write("'''A short string\n") + f.write("'''\n") + f.write("'A very long string %s'\n" % ("X" * 1000)) + + importlib.invalidate_caches() + __import__(TESTFN) + finally: + del sys.path[0] + unlink(filename) + unlink(filename + "c") + unlink(filename + "o") + unload(TESTFN) + + def test_error_from_string(self): + # See http://bugs.python.org/issue6289 + input = "# coding: ascii\n\N{SNOWMAN}".encode('utf-8') + with self.assertRaises(SyntaxError) as c: + compile(input, "<string>", "exec") + expected = "'ascii' codec can't decode byte 0xe2 in position 16: " \ + "ordinal not in range(128)" + self.assertTrue(c.exception.args[0].startswith(expected), + msg=c.exception.args[0]) + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py index 9fc6027afb..0dc04c08ea 100644 --- a/Lib/test/test_ssl.py +++ b/Lib/test/test_ssl.py @@ -6,6 +6,7 @@ from test import support import socket import select import time +import datetime import gc import os import errno @@ -17,19 +18,15 @@ import asyncore import weakref import platform import functools +from unittest import mock ssl = support.import_module("ssl") -PROTOCOLS = [ - ssl.PROTOCOL_SSLv3, - ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1 -] -if hasattr(ssl, 'PROTOCOL_SSLv2'): - PROTOCOLS.append(ssl.PROTOCOL_SSLv2) - +PROTOCOLS = sorted(ssl._PROTOCOL_NAMES) HOST = support.HOST -data_file = lambda name: os.path.join(os.path.dirname(__file__), name) +def data_file(*name): + return os.path.join(os.path.dirname(__file__), *name) # The custom key and certificate files used in test_ssl are generated # using Lib/test/make_ssl_certs.py. @@ -47,6 +44,17 @@ ONLYKEY_PROTECTED = data_file("ssl_key.passwd.pem") KEY_PASSWORD = "somepass" CAPATH = data_file("capath") BYTES_CAPATH = os.fsencode(CAPATH) +CAFILE_NEURONIO = data_file("capath", "4e1295a3.0") +CAFILE_CACERT = data_file("capath", "5ed36f99.0") + + +# empty CRL +CRLFILE = data_file("revocation.crl") + +# Two keys and certs signed by the same CA (for SNI tests) +SIGNED_CERTFILE = data_file("keycert3.pem") +SIGNED_CERTFILE2 = data_file("keycert4.pem") +SIGNING_CA = data_file("pycacert.pem") SVN_PYTHON_ORG_ROOT_CERT = data_file("https_svn_python_org_root.pem") @@ -60,6 +68,7 @@ NULLBYTECERT = data_file("nullbytecert.pem") DHFILE = data_file("dh512.pem") BYTES_DHFILE = os.fsencode(DHFILE) + def handle_error(prefix): exc_format = ' '.join(traceback.format_exception(*sys.exc_info())) if support.verbose: @@ -73,6 +82,23 @@ def no_sslv2_implies_sslv3_hello(): # 0.9.7h or higher return ssl.OPENSSL_VERSION_INFO >= (0, 9, 7, 8, 15) +def have_verify_flags(): + # 0.9.8 or higher + return ssl.OPENSSL_VERSION_INFO >= (0, 9, 8, 0, 15) + +def asn1time(cert_time): + # Some versions of OpenSSL ignore seconds, see #18207 + # 0.9.8.i + if ssl._OPENSSL_API_VERSION == (0, 9, 8, 9, 15): + fmt = "%b %d %H:%M:%S %Y GMT" + dt = datetime.datetime.strptime(cert_time, fmt) + dt = dt.replace(second=0) + cert_time = dt.strftime(fmt) + # %d adds leading zero but ASN1_TIME_print() uses leading space + if cert_time[4] == "0": + cert_time = cert_time[:4] + " " + cert_time[5:] + + return cert_time # Issue #9415: Ubuntu hijacks their OpenSSL and forcefully disables SSLv2 def skip_if_broken_ubuntu_ssl(func): @@ -90,14 +116,12 @@ def skip_if_broken_ubuntu_ssl(func): else: return func +needs_sni = unittest.skipUnless(ssl.HAS_SNI, "SNI support needed for this test") + class BasicSocketTests(unittest.TestCase): def test_constants(self): - #ssl.PROTOCOL_SSLv2 - ssl.PROTOCOL_SSLv23 - ssl.PROTOCOL_SSLv3 - ssl.PROTOCOL_TLSv1 ssl.CERT_NONE ssl.CERT_OPTIONAL ssl.CERT_REQUIRED @@ -179,8 +203,9 @@ class BasicSocketTests(unittest.TestCase): (('organizationName', 'Python Software Foundation'),), (('commonName', 'localhost'),)) ) - self.assertEqual(p['notAfter'], 'Oct 5 23:01:56 2020 GMT') - self.assertEqual(p['notBefore'], 'Oct 8 23:01:56 2010 GMT') + # Note the next three asserts will fail if the keys are regenerated + self.assertEqual(p['notAfter'], asn1time('Oct 5 23:01:56 2020 GMT')) + self.assertEqual(p['notBefore'], asn1time('Oct 8 23:01:56 2010 GMT')) self.assertEqual(p['serialNumber'], 'D7C7381919AFC24E') self.assertEqual(p['subject'], ((('countryName', 'XY'),), @@ -198,6 +223,12 @@ class BasicSocketTests(unittest.TestCase): (('DNS', 'projects.developer.nokia.com'), ('DNS', 'projects.forum.nokia.com')) ) + # extra OCSP and AIA fields + self.assertEqual(p['OCSP'], ('http://ocsp.verisign.com',)) + self.assertEqual(p['caIssuers'], + ('http://SVRIntl-G3-aia.verisign.com/SVRIntlG3.cer',)) + self.assertEqual(p['crlDistributionPoints'], + ('http://SVRIntl-G3-crl.verisign.com/SVRIntlG3.crl',)) def test_parse_cert_CVE_2013_4238(self): p = ssl._ssl._test_decode_cert(NULLBYTECERT) @@ -280,15 +311,15 @@ class BasicSocketTests(unittest.TestCase): def test_wrapped_unconnected(self): # Methods on an unconnected SSLSocket propagate the original - # socket.error raise by the underlying socket object. + # OSError raise by the underlying socket object. s = socket.socket(socket.AF_INET) with ssl.wrap_socket(s) as ss: - self.assertRaises(socket.error, ss.recv, 1) - self.assertRaises(socket.error, ss.recv_into, bytearray(b'x')) - self.assertRaises(socket.error, ss.recvfrom, 1) - self.assertRaises(socket.error, ss.recvfrom_into, bytearray(b'x'), 1) - self.assertRaises(socket.error, ss.send, b'x') - self.assertRaises(socket.error, ss.sendto, b'x', ('0.0.0.0', 0)) + self.assertRaises(OSError, ss.recv, 1) + self.assertRaises(OSError, ss.recv_into, bytearray(b'x')) + self.assertRaises(OSError, ss.recvfrom, 1) + self.assertRaises(OSError, ss.recvfrom_into, bytearray(b'x'), 1) + self.assertRaises(OSError, ss.send, b'x') + self.assertRaises(OSError, ss.sendto, b'x', ('0.0.0.0', 0)) def test_timeout(self): # Issue #8524: when creating an SSL socket, the timeout of the @@ -313,15 +344,15 @@ class BasicSocketTests(unittest.TestCase): with ssl.wrap_socket(sock, server_side=True, certfile=CERTFILE) as s: self.assertRaisesRegex(ValueError, "can't connect in server-side mode", s.connect, (HOST, 8080)) - with self.assertRaises(IOError) as cm: + with self.assertRaises(OSError) as cm: with socket.socket() as sock: ssl.wrap_socket(sock, certfile=WRONGCERT) self.assertEqual(cm.exception.errno, errno.ENOENT) - with self.assertRaises(IOError) as cm: + with self.assertRaises(OSError) as cm: with socket.socket() as sock: ssl.wrap_socket(sock, certfile=CERTFILE, keyfile=WRONGCERT) self.assertEqual(cm.exception.errno, errno.ENOENT) - with self.assertRaises(IOError) as cm: + with self.assertRaises(OSError) as cm: with socket.socket() as sock: ssl.wrap_socket(sock, certfile=WRONGCERT, keyfile=WRONGCERT) self.assertEqual(cm.exception.errno, errno.ENOENT) @@ -493,6 +524,114 @@ class BasicSocketTests(unittest.TestCase): support.gc_collect() self.assertIn(r, str(cm.warning.args[0])) + def test_get_default_verify_paths(self): + paths = ssl.get_default_verify_paths() + self.assertEqual(len(paths), 6) + self.assertIsInstance(paths, ssl.DefaultVerifyPaths) + + with support.EnvironmentVarGuard() as env: + env["SSL_CERT_DIR"] = CAPATH + env["SSL_CERT_FILE"] = CERTFILE + paths = ssl.get_default_verify_paths() + self.assertEqual(paths.cafile, CERTFILE) + self.assertEqual(paths.capath, CAPATH) + + @unittest.skipUnless(sys.platform == "win32", "Windows specific") + def test_enum_certificates(self): + self.assertTrue(ssl.enum_certificates("CA")) + self.assertTrue(ssl.enum_certificates("ROOT")) + + self.assertRaises(TypeError, ssl.enum_certificates) + self.assertRaises(WindowsError, ssl.enum_certificates, "") + + trust_oids = set() + for storename in ("CA", "ROOT"): + store = ssl.enum_certificates(storename) + self.assertIsInstance(store, list) + for element in store: + self.assertIsInstance(element, tuple) + self.assertEqual(len(element), 3) + cert, enc, trust = element + self.assertIsInstance(cert, bytes) + self.assertIn(enc, {"x509_asn", "pkcs_7_asn"}) + self.assertIsInstance(trust, (set, bool)) + if isinstance(trust, set): + trust_oids.update(trust) + + serverAuth = "1.3.6.1.5.5.7.3.1" + self.assertIn(serverAuth, trust_oids) + + @unittest.skipUnless(sys.platform == "win32", "Windows specific") + def test_enum_crls(self): + self.assertTrue(ssl.enum_crls("CA")) + self.assertRaises(TypeError, ssl.enum_crls) + self.assertRaises(WindowsError, ssl.enum_crls, "") + + crls = ssl.enum_crls("CA") + self.assertIsInstance(crls, list) + for element in crls: + self.assertIsInstance(element, tuple) + self.assertEqual(len(element), 2) + self.assertIsInstance(element[0], bytes) + self.assertIn(element[1], {"x509_asn", "pkcs_7_asn"}) + + + def test_asn1object(self): + expected = (129, 'serverAuth', 'TLS Web Server Authentication', + '1.3.6.1.5.5.7.3.1') + + val = ssl._ASN1Object('1.3.6.1.5.5.7.3.1') + self.assertEqual(val, expected) + self.assertEqual(val.nid, 129) + self.assertEqual(val.shortname, 'serverAuth') + self.assertEqual(val.longname, 'TLS Web Server Authentication') + self.assertEqual(val.oid, '1.3.6.1.5.5.7.3.1') + self.assertIsInstance(val, ssl._ASN1Object) + self.assertRaises(ValueError, ssl._ASN1Object, 'serverAuth') + + val = ssl._ASN1Object.fromnid(129) + self.assertEqual(val, expected) + self.assertIsInstance(val, ssl._ASN1Object) + self.assertRaises(ValueError, ssl._ASN1Object.fromnid, -1) + with self.assertRaisesRegex(ValueError, "unknown NID 100000"): + ssl._ASN1Object.fromnid(100000) + for i in range(1000): + try: + obj = ssl._ASN1Object.fromnid(i) + except ValueError: + pass + else: + self.assertIsInstance(obj.nid, int) + self.assertIsInstance(obj.shortname, str) + self.assertIsInstance(obj.longname, str) + self.assertIsInstance(obj.oid, (str, type(None))) + + val = ssl._ASN1Object.fromname('TLS Web Server Authentication') + self.assertEqual(val, expected) + self.assertIsInstance(val, ssl._ASN1Object) + self.assertEqual(ssl._ASN1Object.fromname('serverAuth'), expected) + self.assertEqual(ssl._ASN1Object.fromname('1.3.6.1.5.5.7.3.1'), + expected) + with self.assertRaisesRegex(ValueError, "unknown object 'serverauth'"): + ssl._ASN1Object.fromname('serverauth') + + def test_purpose_enum(self): + val = ssl._ASN1Object('1.3.6.1.5.5.7.3.1') + self.assertIsInstance(ssl.Purpose.SERVER_AUTH, ssl._ASN1Object) + self.assertEqual(ssl.Purpose.SERVER_AUTH, val) + self.assertEqual(ssl.Purpose.SERVER_AUTH.nid, 129) + self.assertEqual(ssl.Purpose.SERVER_AUTH.shortname, 'serverAuth') + self.assertEqual(ssl.Purpose.SERVER_AUTH.oid, + '1.3.6.1.5.5.7.3.1') + + val = ssl._ASN1Object('1.3.6.1.5.5.7.3.2') + self.assertIsInstance(ssl.Purpose.CLIENT_AUTH, ssl._ASN1Object) + self.assertEqual(ssl.Purpose.CLIENT_AUTH, val) + self.assertEqual(ssl.Purpose.CLIENT_AUTH.nid, 130) + self.assertEqual(ssl.Purpose.CLIENT_AUTH.shortname, 'clientAuth') + self.assertEqual(ssl.Purpose.CLIENT_AUTH.oid, + '1.3.6.1.5.5.7.3.2') + def test_unsupported_dtls(self): s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) self.addCleanup(s.close) @@ -509,11 +648,8 @@ class ContextTests(unittest.TestCase): @skip_if_broken_ubuntu_ssl def test_constructor(self): - if hasattr(ssl, 'PROTOCOL_SSLv2'): - ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv2) - ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv3) - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + for protocol in PROTOCOLS: + ssl.SSLContext(protocol) self.assertRaises(TypeError, ssl.SSLContext) self.assertRaises(ValueError, ssl.SSLContext, -1) self.assertRaises(ValueError, ssl.SSLContext, 42) @@ -550,7 +686,7 @@ class ContextTests(unittest.TestCase): with self.assertRaises(ValueError): ctx.options = 0 - def test_verify(self): + def test_verify_mode(self): ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) # Default value self.assertEqual(ctx.verify_mode, ssl.CERT_NONE) @@ -565,13 +701,32 @@ class ContextTests(unittest.TestCase): with self.assertRaises(ValueError): ctx.verify_mode = 42 + @unittest.skipUnless(have_verify_flags(), + "verify_flags need OpenSSL > 0.9.8") + def test_verify_flags(self): + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + # default value by OpenSSL + self.assertEqual(ctx.verify_flags, ssl.VERIFY_DEFAULT) + ctx.verify_flags = ssl.VERIFY_CRL_CHECK_LEAF + self.assertEqual(ctx.verify_flags, ssl.VERIFY_CRL_CHECK_LEAF) + ctx.verify_flags = ssl.VERIFY_CRL_CHECK_CHAIN + self.assertEqual(ctx.verify_flags, ssl.VERIFY_CRL_CHECK_CHAIN) + ctx.verify_flags = ssl.VERIFY_DEFAULT + self.assertEqual(ctx.verify_flags, ssl.VERIFY_DEFAULT) + # supports any value + ctx.verify_flags = ssl.VERIFY_CRL_CHECK_LEAF | ssl.VERIFY_X509_STRICT + self.assertEqual(ctx.verify_flags, + ssl.VERIFY_CRL_CHECK_LEAF | ssl.VERIFY_X509_STRICT) + with self.assertRaises(TypeError): + ctx.verify_flags = None + def test_load_cert_chain(self): ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) # Combined key and cert in a single file ctx.load_cert_chain(CERTFILE) ctx.load_cert_chain(CERTFILE, keyfile=CERTFILE) self.assertRaises(TypeError, ctx.load_cert_chain, keyfile=CERTFILE) - with self.assertRaises(IOError) as cm: + with self.assertRaises(OSError) as cm: ctx.load_cert_chain(WRONGCERT) self.assertEqual(cm.exception.errno, errno.ENOENT) with self.assertRaisesRegex(ssl.SSLError, "PEM lib"): @@ -655,8 +810,8 @@ class ContextTests(unittest.TestCase): ctx.load_verify_locations(BYTES_CERTFILE) ctx.load_verify_locations(cafile=BYTES_CERTFILE, capath=None) self.assertRaises(TypeError, ctx.load_verify_locations) - self.assertRaises(TypeError, ctx.load_verify_locations, None, None) - with self.assertRaises(IOError) as cm: + self.assertRaises(TypeError, ctx.load_verify_locations, None, None, None) + with self.assertRaises(OSError) as cm: ctx.load_verify_locations(WRONGCERT) self.assertEqual(cm.exception.errno, errno.ENOENT) with self.assertRaisesRegex(ssl.SSLError, "PEM lib"): @@ -667,6 +822,64 @@ class ContextTests(unittest.TestCase): # Issue #10989: crash if the second argument type is invalid self.assertRaises(TypeError, ctx.load_verify_locations, None, True) + def test_load_verify_cadata(self): + # test cadata + with open(CAFILE_CACERT) as f: + cacert_pem = f.read() + cacert_der = ssl.PEM_cert_to_DER_cert(cacert_pem) + with open(CAFILE_NEURONIO) as f: + neuronio_pem = f.read() + neuronio_der = ssl.PEM_cert_to_DER_cert(neuronio_pem) + + # test PEM + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + self.assertEqual(ctx.cert_store_stats()["x509_ca"], 0) + ctx.load_verify_locations(cadata=cacert_pem) + self.assertEqual(ctx.cert_store_stats()["x509_ca"], 1) + ctx.load_verify_locations(cadata=neuronio_pem) + self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2) + # cert already in hash table + ctx.load_verify_locations(cadata=neuronio_pem) + self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2) + + # combined + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + combined = "\n".join((cacert_pem, neuronio_pem)) + ctx.load_verify_locations(cadata=combined) + self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2) + + # with junk around the certs + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + combined = ["head", cacert_pem, "other", neuronio_pem, "again", + neuronio_pem, "tail"] + ctx.load_verify_locations(cadata="\n".join(combined)) + self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2) + + # test DER + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx.load_verify_locations(cadata=cacert_der) + ctx.load_verify_locations(cadata=neuronio_der) + self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2) + # cert already in hash table + ctx.load_verify_locations(cadata=cacert_der) + self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2) + + # combined + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + combined = b"".join((cacert_der, neuronio_der)) + ctx.load_verify_locations(cadata=combined) + self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2) + + # error cases + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + self.assertRaises(TypeError, ctx.load_verify_locations, cadata=object) + + with self.assertRaisesRegex(ssl.SSLError, "no start line"): + ctx.load_verify_locations(cadata="broken") + with self.assertRaisesRegex(ssl.SSLError, "not enough data"): + ctx.load_verify_locations(cadata=b"broken") + + def test_load_dh_params(self): ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ctx.load_dh_params(DHFILE) @@ -714,6 +927,158 @@ class ContextTests(unittest.TestCase): self.assertRaises(ValueError, ctx.set_ecdh_curve, "foo") self.assertRaises(ValueError, ctx.set_ecdh_curve, b"foo") + @needs_sni + def test_sni_callback(self): + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + + # set_servername_callback expects a callable, or None + self.assertRaises(TypeError, ctx.set_servername_callback) + self.assertRaises(TypeError, ctx.set_servername_callback, 4) + self.assertRaises(TypeError, ctx.set_servername_callback, "") + self.assertRaises(TypeError, ctx.set_servername_callback, ctx) + + def dummycallback(sock, servername, ctx): + pass + ctx.set_servername_callback(None) + ctx.set_servername_callback(dummycallback) + + @needs_sni + def test_sni_callback_refcycle(self): + # Reference cycles through the servername callback are detected + # and cleared. + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + def dummycallback(sock, servername, ctx, cycle=ctx): + pass + ctx.set_servername_callback(dummycallback) + wr = weakref.ref(ctx) + del ctx, dummycallback + gc.collect() + self.assertIs(wr(), None) + + def test_cert_store_stats(self): + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + self.assertEqual(ctx.cert_store_stats(), + {'x509_ca': 0, 'crl': 0, 'x509': 0}) + ctx.load_cert_chain(CERTFILE) + self.assertEqual(ctx.cert_store_stats(), + {'x509_ca': 0, 'crl': 0, 'x509': 0}) + ctx.load_verify_locations(CERTFILE) + self.assertEqual(ctx.cert_store_stats(), + {'x509_ca': 0, 'crl': 0, 'x509': 1}) + ctx.load_verify_locations(SVN_PYTHON_ORG_ROOT_CERT) + self.assertEqual(ctx.cert_store_stats(), + {'x509_ca': 1, 'crl': 0, 'x509': 2}) + + def test_get_ca_certs(self): + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + self.assertEqual(ctx.get_ca_certs(), []) + # CERTFILE is not flagged as X509v3 Basic Constraints: CA:TRUE + ctx.load_verify_locations(CERTFILE) + self.assertEqual(ctx.get_ca_certs(), []) + # but SVN_PYTHON_ORG_ROOT_CERT is a CA cert + ctx.load_verify_locations(SVN_PYTHON_ORG_ROOT_CERT) + self.assertEqual(ctx.get_ca_certs(), + [{'issuer': ((('organizationName', 'Root CA'),), + (('organizationalUnitName', 'http://www.cacert.org'),), + (('commonName', 'CA Cert Signing Authority'),), + (('emailAddress', 'support@cacert.org'),)), + 'notAfter': asn1time('Mar 29 12:29:49 2033 GMT'), + 'notBefore': asn1time('Mar 30 12:29:49 2003 GMT'), + 'serialNumber': '00', + 'crlDistributionPoints': ('https://www.cacert.org/revoke.crl',), + 'subject': ((('organizationName', 'Root CA'),), + (('organizationalUnitName', 'http://www.cacert.org'),), + (('commonName', 'CA Cert Signing Authority'),), + (('emailAddress', 'support@cacert.org'),)), + 'version': 3}]) + + with open(SVN_PYTHON_ORG_ROOT_CERT) as f: + pem = f.read() + der = ssl.PEM_cert_to_DER_cert(pem) + self.assertEqual(ctx.get_ca_certs(True), [der]) + + def test_load_default_certs(self): + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx.load_default_certs() + + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx.load_default_certs(ssl.Purpose.SERVER_AUTH) + ctx.load_default_certs() + + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx.load_default_certs(ssl.Purpose.CLIENT_AUTH) + + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + self.assertRaises(TypeError, ctx.load_default_certs, None) + self.assertRaises(TypeError, ctx.load_default_certs, 'SERVER_AUTH') + + def test_create_default_context(self): + ctx = ssl.create_default_context() + self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLSv1) + self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED) + self.assertTrue(ctx.check_hostname) + self.assertEqual(ctx.options & ssl.OP_NO_SSLv2, ssl.OP_NO_SSLv2) + + with open(SIGNING_CA) as f: + cadata = f.read() + ctx = ssl.create_default_context(cafile=SIGNING_CA, capath=CAPATH, + cadata=cadata) + self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLSv1) + self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED) + self.assertEqual(ctx.options & ssl.OP_NO_SSLv2, ssl.OP_NO_SSLv2) + + ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLSv1) + self.assertEqual(ctx.verify_mode, ssl.CERT_NONE) + self.assertEqual(ctx.options & ssl.OP_NO_SSLv2, ssl.OP_NO_SSLv2) + + def test__create_stdlib_context(self): + ctx = ssl._create_stdlib_context() + self.assertEqual(ctx.protocol, ssl.PROTOCOL_SSLv23) + self.assertEqual(ctx.verify_mode, ssl.CERT_NONE) + self.assertFalse(ctx.check_hostname) + self.assertEqual(ctx.options & ssl.OP_NO_SSLv2, ssl.OP_NO_SSLv2) + + ctx = ssl._create_stdlib_context(ssl.PROTOCOL_TLSv1) + self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLSv1) + self.assertEqual(ctx.verify_mode, ssl.CERT_NONE) + self.assertEqual(ctx.options & ssl.OP_NO_SSLv2, ssl.OP_NO_SSLv2) + + ctx = ssl._create_stdlib_context(ssl.PROTOCOL_TLSv1, + cert_reqs=ssl.CERT_REQUIRED, + check_hostname=True) + self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLSv1) + self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED) + self.assertTrue(ctx.check_hostname) + self.assertEqual(ctx.options & ssl.OP_NO_SSLv2, ssl.OP_NO_SSLv2) + + ctx = ssl._create_stdlib_context(purpose=ssl.Purpose.CLIENT_AUTH) + self.assertEqual(ctx.protocol, ssl.PROTOCOL_SSLv23) + self.assertEqual(ctx.verify_mode, ssl.CERT_NONE) + self.assertEqual(ctx.options & ssl.OP_NO_SSLv2, ssl.OP_NO_SSLv2) + + def test_check_hostname(self): + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + self.assertFalse(ctx.check_hostname) + + # Requires CERT_REQUIRED or CERT_OPTIONAL + with self.assertRaises(ValueError): + ctx.check_hostname = True + ctx.verify_mode = ssl.CERT_REQUIRED + self.assertFalse(ctx.check_hostname) + ctx.check_hostname = True + self.assertTrue(ctx.check_hostname) + + ctx.verify_mode = ssl.CERT_OPTIONAL + ctx.check_hostname = True + self.assertTrue(ctx.check_hostname) + + # Cannot set CERT_NONE with check_hostname enabled + with self.assertRaises(ValueError): + ctx.verify_mode = ssl.CERT_NONE + ctx.check_hostname = False + self.assertFalse(ctx.check_hostname) + class SSLErrorTests(unittest.TestCase): @@ -919,6 +1284,28 @@ class NetworkedTests(unittest.TestCase): finally: s.close() + def test_connect_cadata(self): + with open(CAFILE_CACERT) as f: + pem = f.read() + der = ssl.PEM_cert_to_DER_cert(pem) + with support.transient_internet("svn.python.org"): + ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + ctx.verify_mode = ssl.CERT_REQUIRED + ctx.load_verify_locations(cadata=pem) + with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s: + s.connect(("svn.python.org", 443)) + cert = s.getpeercert() + self.assertTrue(cert) + + # same with DER + ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + ctx.verify_mode = ssl.CERT_REQUIRED + ctx.load_verify_locations(cadata=der) + with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s: + s.connect(("svn.python.org", 443)) + cert = s.getpeercert() + self.assertTrue(cert) + @unittest.skipIf(os.name == "nt", "Can't use a socket as a file under Windows") def test_makefile_close(self): # Issue #5238: creating a file-like object with makefile() shouldn't @@ -1031,6 +1418,36 @@ class NetworkedTests(unittest.TestCase): finally: s.close() + def test_get_ca_certs_capath(self): + # capath certs are loaded on request + with support.transient_internet("svn.python.org"): + ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + ctx.verify_mode = ssl.CERT_REQUIRED + ctx.load_verify_locations(capath=CAPATH) + self.assertEqual(ctx.get_ca_certs(), []) + s = ctx.wrap_socket(socket.socket(socket.AF_INET)) + s.connect(("svn.python.org", 443)) + try: + cert = s.getpeercert() + self.assertTrue(cert) + finally: + s.close() + self.assertEqual(len(ctx.get_ca_certs()), 1) + + @needs_sni + def test_context_setget(self): + # Check that the context of a connected socket can be replaced. + with support.transient_internet("svn.python.org"): + ctx1 = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx2 = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + s = socket.socket(socket.AF_INET) + with ctx1.wrap_socket(s) as ss: + ss.connect(("svn.python.org", 443)) + self.assertIs(ss.context, ctx1) + self.assertIs(ss._sslobj.context, ctx1) + ss.context = ctx2 + self.assertIs(ss.context, ctx2) + self.assertIs(ss._sslobj.context, ctx2) try: import threading @@ -1158,7 +1575,7 @@ else: sys.stdout.write(" server: read %r (%s), sending back %r (%s)...\n" % (msg, ctype, msg.lower(), ctype)) self.write(msg.lower()) - except socket.error: + except OSError: if self.server.chatty: handle_error("Test server failure:\n") self.close() @@ -1268,7 +1685,7 @@ else: return self.handle_close() except ssl.SSLError: raise - except socket.error as err: + except OSError as err: if err.args[0] == errno.ECONNABORTED: return self.handle_close() else: @@ -1372,19 +1789,19 @@ else: except ssl.SSLError as x: if support.verbose: sys.stdout.write("\nSSLError is %s\n" % x.args[1]) - except socket.error as x: + except OSError as x: if support.verbose: - sys.stdout.write("\nsocket.error is %s\n" % x.args[1]) - except IOError as x: + sys.stdout.write("\nOSError is %s\n" % x.args[1]) + except OSError as x: if x.errno != errno.ENOENT: raise if support.verbose: - sys.stdout.write("\IOError is %s\n" % str(x)) + sys.stdout.write("\OSError is %s\n" % str(x)) else: raise AssertionError("Use of invalid cert should have failed!") def server_params_test(client_context, server_context, indata=b"FOO\n", - chatty=True, connectionchatty=False): + chatty=True, connectionchatty=False, sni_name=None): """ Launch a server, connect a client to it and try various reads and writes. @@ -1394,7 +1811,8 @@ else: chatty=chatty, connectionchatty=False) with server: - with client_context.wrap_socket(socket.socket()) as s: + with client_context.wrap_socket(socket.socket(), + server_hostname=sni_name) as s: s.connect((HOST, server.port)) for arg in [indata, bytearray(indata), memoryview(indata)]: if connectionchatty: @@ -1418,6 +1836,7 @@ else: stats.update({ 'compression': s.compression(), 'cipher': s.cipher(), + 'peercert': s.getpeercert(), 'client_npn_protocol': s.selected_npn_protocol() }) s.close() @@ -1443,12 +1862,15 @@ else: client_context.options |= client_options server_context = ssl.SSLContext(server_protocol) server_context.options |= server_options + + # NOTE: we must enable "ALL" ciphers on the client, otherwise an + # SSLv23 client will send an SSLv3 hello (rather than SSLv2) + # starting from OpenSSL 1.0.0 (see issue #8322). + if client_context.protocol == ssl.PROTOCOL_SSLv23: + client_context.set_ciphers("ALL") + for ctx in (client_context, server_context): ctx.verify_mode = certsreqs - # NOTE: we must enable "ALL" ciphers, otherwise an SSLv23 client - # will send an SSLv3 hello (rather than SSLv2) starting from - # OpenSSL 1.0.0 (see issue #8322). - ctx.set_ciphers("ALL") ctx.load_cert_chain(CERTFILE) ctx.load_verify_locations(CERTFILE) try: @@ -1459,7 +1881,7 @@ else: except ssl.SSLError: if expect_success: raise - except socket.error as e: + except OSError as e: if expect_success or e.errno != errno.ECONNRESET: raise else: @@ -1478,10 +1900,11 @@ else: if support.verbose: sys.stdout.write("\n") for protocol in PROTOCOLS: - context = ssl.SSLContext(protocol) - context.load_cert_chain(CERTFILE) - server_params_test(context, context, - chatty=True, connectionchatty=True) + with self.subTest(protocol=ssl._PROTOCOL_NAMES[protocol]): + context = ssl.SSLContext(protocol) + context.load_cert_chain(CERTFILE) + server_params_test(context, context, + chatty=True, connectionchatty=True) def test_getpeercert(self): if support.verbose: @@ -1492,8 +1915,14 @@ else: context.load_cert_chain(CERTFILE) server = ThreadedEchoServer(context=context, chatty=False) with server: - s = context.wrap_socket(socket.socket()) + s = context.wrap_socket(socket.socket(), + do_handshake_on_connect=False) s.connect((HOST, server.port)) + # getpeercert() raise ValueError while the handshake isn't + # done. + with self.assertRaises(ValueError): + s.getpeercert() + s.do_handshake() cert = s.getpeercert() self.assertTrue(cert, "Can't get peer certificate.") cipher = s.cipher() @@ -1515,6 +1944,87 @@ else: self.assertLess(before, after) s.close() + @unittest.skipUnless(have_verify_flags(), + "verify_flags need OpenSSL > 0.9.8") + def test_crl_check(self): + if support.verbose: + sys.stdout.write("\n") + + server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + server_context.load_cert_chain(SIGNED_CERTFILE) + + context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + context.verify_mode = ssl.CERT_REQUIRED + context.load_verify_locations(SIGNING_CA) + self.assertEqual(context.verify_flags, ssl.VERIFY_DEFAULT) + + # VERIFY_DEFAULT should pass + server = ThreadedEchoServer(context=server_context, chatty=True) + with server: + with context.wrap_socket(socket.socket()) as s: + s.connect((HOST, server.port)) + cert = s.getpeercert() + self.assertTrue(cert, "Can't get peer certificate.") + + # VERIFY_CRL_CHECK_LEAF without a loaded CRL file fails + context.verify_flags |= ssl.VERIFY_CRL_CHECK_LEAF + + server = ThreadedEchoServer(context=server_context, chatty=True) + with server: + with context.wrap_socket(socket.socket()) as s: + with self.assertRaisesRegex(ssl.SSLError, + "certificate verify failed"): + s.connect((HOST, server.port)) + + # now load a CRL file. The CRL file is signed by the CA. + context.load_verify_locations(CRLFILE) + + server = ThreadedEchoServer(context=server_context, chatty=True) + with server: + with context.wrap_socket(socket.socket()) as s: + s.connect((HOST, server.port)) + cert = s.getpeercert() + self.assertTrue(cert, "Can't get peer certificate.") + + @needs_sni + def test_check_hostname(self): + if support.verbose: + sys.stdout.write("\n") + + server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + server_context.load_cert_chain(SIGNED_CERTFILE) + + context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + context.verify_mode = ssl.CERT_REQUIRED + context.check_hostname = True + context.load_verify_locations(SIGNING_CA) + + # correct hostname should verify + server = ThreadedEchoServer(context=server_context, chatty=True) + with server: + with context.wrap_socket(socket.socket(), + server_hostname="localhost") as s: + s.connect((HOST, server.port)) + cert = s.getpeercert() + self.assertTrue(cert, "Can't get peer certificate.") + + # incorrect hostname should raise an exception + server = ThreadedEchoServer(context=server_context, chatty=True) + with server: + with context.wrap_socket(socket.socket(), + server_hostname="invalid") as s: + with self.assertRaisesRegex(ssl.CertificateError, + "hostname 'invalid' doesn't match 'localhost'"): + s.connect((HOST, server.port)) + + # missing server_hostname arg should cause an exception, too + server = ThreadedEchoServer(context=server_context, chatty=True) + with server: + with socket.socket() as s: + with self.assertRaisesRegex(ValueError, + "check_hostname requires server_hostname"): + context.wrap_socket(s) + def test_empty_cert(self): """Connecting with an empty cert file""" bad_cert_test(os.path.join(os.path.dirname(__file__) or os.curdir, @@ -1533,7 +2043,7 @@ else: "badkey.pem")) def test_rude_shutdown(self): - """A brutal shutdown of an SSL server should raise an IOError + """A brutal shutdown of an SSL server should raise an OSError in the client when attempting handshake. """ listener_ready = threading.Event() @@ -1561,7 +2071,7 @@ else: listener_gone.wait() try: ssl_sock = ssl.wrap_socket(c) - except IOError: + except OSError: pass else: self.fail('connecting to closed SSL socket should have failed') @@ -1604,7 +2114,7 @@ else: if hasattr(ssl, 'PROTOCOL_SSLv2'): try: try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv2, True) - except (ssl.SSLError, socket.error) as x: + except OSError as x: # this fails on some older versions of OpenSSL (0.9.7l, for instance) if support.verbose: sys.stdout.write( @@ -1664,6 +2174,49 @@ else: try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv23, False, client_options=ssl.OP_NO_TLSv1) + @skip_if_broken_ubuntu_ssl + @unittest.skipUnless(hasattr(ssl, "PROTOCOL_TLSv1_1"), + "TLS version 1.1 not supported.") + def test_protocol_tlsv1_1(self): + """Connecting to a TLSv1.1 server with various client options. + Testing against older TLS versions.""" + if support.verbose: + sys.stdout.write("\n") + try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_1, True) + if hasattr(ssl, 'PROTOCOL_SSLv2'): + try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv2, False) + try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv3, False) + try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv23, False, + client_options=ssl.OP_NO_TLSv1_1) + + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_1, True) + try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1, False) + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_1, False) + + + @skip_if_broken_ubuntu_ssl + @unittest.skipUnless(hasattr(ssl, "PROTOCOL_TLSv1_2"), + "TLS version 1.2 not supported.") + def test_protocol_tlsv1_2(self): + """Connecting to a TLSv1.2 server with various client options. + Testing against older TLS versions.""" + if support.verbose: + sys.stdout.write("\n") + try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_2, True, + server_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2, + client_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2,) + if hasattr(ssl, 'PROTOCOL_SSLv2'): + try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv2, False) + try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv3, False) + try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv23, False, + client_options=ssl.OP_NO_TLSv1_2) + + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_2, True) + try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1, False) + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_2, False) + try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_1, False) + try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_2, False) + def test_starttls(self): """Switching from clear text to encrypted and back again.""" msgs = (b"msg 1", b"MSG 2", b"STARTTLS", b"MSG 3", b"msg 4", b"ENDTLS", b"msg 5", b"msg 6") @@ -1724,7 +2277,7 @@ else: def test_socketserver(self): """Using a SocketServer to create and manage SSL connections.""" - server = make_https_server(self, CERTFILE) + server = make_https_server(self, certfile=CERTFILE) # try to connect if support.verbose: sys.stdout.write('\n') @@ -1980,6 +2533,20 @@ else: self.assertIsInstance(remote, ssl.SSLSocket) self.assertEqual(peer, client_addr) + def test_getpeercert_enotconn(self): + context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + with context.wrap_socket(socket.socket()) as sock: + with self.assertRaises(OSError) as cm: + sock.getpeercert() + self.assertEqual(cm.exception.errno, errno.ENOTCONN) + + def test_do_handshake_enotconn(self): + context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + with context.wrap_socket(socket.socket()) as sock: + with self.assertRaises(OSError) as cm: + sock.do_handshake() + self.assertEqual(cm.exception.errno, errno.ENOTCONN) + def test_default_ciphers(self): context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) try: @@ -1991,7 +2558,7 @@ else: ssl_version=ssl.PROTOCOL_SSLv23, chatty=False) as server: with context.wrap_socket(socket.socket()) as s: - with self.assertRaises((OSError, ssl.SSLError)): + with self.assertRaises(OSError): s.connect((HOST, server.port)) self.assertIn("no shared cipher", str(server.conn_errors[0])) @@ -2124,6 +2691,124 @@ else: if len(stats['server_npn_protocols']) else 'nothing' self.assertEqual(server_result, expected, msg % (server_result, "server")) + def sni_contexts(self): + server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + server_context.load_cert_chain(SIGNED_CERTFILE) + other_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + other_context.load_cert_chain(SIGNED_CERTFILE2) + client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + client_context.verify_mode = ssl.CERT_REQUIRED + client_context.load_verify_locations(SIGNING_CA) + return server_context, other_context, client_context + + def check_common_name(self, stats, name): + cert = stats['peercert'] + self.assertIn((('commonName', name),), cert['subject']) + + @needs_sni + def test_sni_callback(self): + calls = [] + server_context, other_context, client_context = self.sni_contexts() + + def servername_cb(ssl_sock, server_name, initial_context): + calls.append((server_name, initial_context)) + if server_name is not None: + ssl_sock.context = other_context + server_context.set_servername_callback(servername_cb) + + stats = server_params_test(client_context, server_context, + chatty=True, + sni_name='supermessage') + # The hostname was fetched properly, and the certificate was + # changed for the connection. + self.assertEqual(calls, [("supermessage", server_context)]) + # CERTFILE4 was selected + self.check_common_name(stats, 'fakehostname') + + calls = [] + # The callback is called with server_name=None + stats = server_params_test(client_context, server_context, + chatty=True, + sni_name=None) + self.assertEqual(calls, [(None, server_context)]) + self.check_common_name(stats, 'localhost') + + # Check disabling the callback + calls = [] + server_context.set_servername_callback(None) + + stats = server_params_test(client_context, server_context, + chatty=True, + sni_name='notfunny') + # Certificate didn't change + self.check_common_name(stats, 'localhost') + self.assertEqual(calls, []) + + @needs_sni + def test_sni_callback_alert(self): + # Returning a TLS alert is reflected to the connecting client + server_context, other_context, client_context = self.sni_contexts() + + def cb_returning_alert(ssl_sock, server_name, initial_context): + return ssl.ALERT_DESCRIPTION_ACCESS_DENIED + server_context.set_servername_callback(cb_returning_alert) + + with self.assertRaises(ssl.SSLError) as cm: + stats = server_params_test(client_context, server_context, + chatty=False, + sni_name='supermessage') + self.assertEqual(cm.exception.reason, 'TLSV1_ALERT_ACCESS_DENIED') + + @needs_sni + def test_sni_callback_raising(self): + # Raising fails the connection with a TLS handshake failure alert. + server_context, other_context, client_context = self.sni_contexts() + + def cb_raising(ssl_sock, server_name, initial_context): + 1/0 + server_context.set_servername_callback(cb_raising) + + with self.assertRaises(ssl.SSLError) as cm, \ + support.captured_stderr() as stderr: + stats = server_params_test(client_context, server_context, + chatty=False, + sni_name='supermessage') + self.assertEqual(cm.exception.reason, 'SSLV3_ALERT_HANDSHAKE_FAILURE') + self.assertIn("ZeroDivisionError", stderr.getvalue()) + + @needs_sni + def test_sni_callback_wrong_return_type(self): + # Returning the wrong return type terminates the TLS connection + # with an internal error alert. + server_context, other_context, client_context = self.sni_contexts() + + def cb_wrong_return_type(ssl_sock, server_name, initial_context): + return "foo" + server_context.set_servername_callback(cb_wrong_return_type) + + with self.assertRaises(ssl.SSLError) as cm, \ + support.captured_stderr() as stderr: + stats = server_params_test(client_context, server_context, + chatty=False, + sni_name='supermessage') + self.assertEqual(cm.exception.reason, 'TLSV1_ALERT_INTERNAL_ERROR') + self.assertIn("TypeError", stderr.getvalue()) + + def test_read_write_after_close_raises_valuerror(self): + context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + context.verify_mode = ssl.CERT_REQUIRED + context.load_verify_locations(CERTFILE) + context.load_cert_chain(CERTFILE) + server = ThreadedEchoServer(context=context, chatty=False) + + with server: + s = context.wrap_socket(socket.socket()) + s.connect((HOST, server.port)) + s.close() + + self.assertRaises(ValueError, s.read, 1024) + self.assertRaises(ValueError, s.write, b'hello') + def test_main(verbose=False): if support.verbose: @@ -2143,10 +2828,16 @@ def test_main(verbose=False): (ssl.OPENSSL_VERSION, ssl.OPENSSL_VERSION_INFO)) print(" under %s" % plat) print(" HAS_SNI = %r" % ssl.HAS_SNI) + print(" OP_ALL = 0x%8x" % ssl.OP_ALL) + try: + print(" OP_NO_TLSv1_1 = 0x%8x" % ssl.OP_NO_TLSv1_1) + except AttributeError: + pass for filename in [ CERTFILE, SVN_PYTHON_ORG_ROOT_CERT, BYTES_CERTFILE, ONLYCERT, ONLYKEY, BYTES_ONLYCERT, BYTES_ONLYKEY, + SIGNED_CERTFILE, SIGNED_CERTFILE2, SIGNING_CA, BADCERT, BADKEY, EMPTYCERT]: if not os.path.exists(filename): raise support.TestFailed("Can't read certificate file %r" % filename) diff --git a/Lib/test/test_stat.py b/Lib/test/test_stat.py index eb3f07a63d..af6ced4204 100644 --- a/Lib/test/test_stat.py +++ b/Lib/test/test_stat.py @@ -1,9 +1,13 @@ import unittest import os -from test.support import TESTFN, run_unittest, import_fresh_module -import stat +from test.support import TESTFN, import_fresh_module + +c_stat = import_fresh_module('stat', fresh=['_stat']) +py_stat = import_fresh_module('stat', blocked=['_stat']) + +class TestFilemode: + statmod = None -class TestFilemode(unittest.TestCase): file_flags = {'SF_APPEND', 'SF_ARCHIVED', 'SF_IMMUTABLE', 'SF_NOUNLINK', 'SF_SNAPSHOT', 'UF_APPEND', 'UF_COMPRESSED', 'UF_HIDDEN', 'UF_IMMUTABLE', 'UF_NODUMP', 'UF_NOUNLINK', 'UF_OPAQUE'} @@ -63,17 +67,17 @@ class TestFilemode(unittest.TestCase): st_mode = os.lstat(fname).st_mode else: st_mode = os.stat(fname).st_mode - modestr = stat.filemode(st_mode) + modestr = self.statmod.filemode(st_mode) return st_mode, modestr def assertS_IS(self, name, mode): # test format, lstrip is for S_IFIFO - fmt = getattr(stat, "S_IF" + name.lstrip("F")) - self.assertEqual(stat.S_IFMT(mode), fmt) + fmt = getattr(self.statmod, "S_IF" + name.lstrip("F")) + self.assertEqual(self.statmod.S_IFMT(mode), fmt) # test that just one function returns true testname = "S_IS" + name for funcname in self.format_funcs: - func = getattr(stat, funcname, None) + func = getattr(self.statmod, funcname, None) if func is None: if funcname == testname: raise ValueError(funcname) @@ -91,35 +95,35 @@ class TestFilemode(unittest.TestCase): st_mode, modestr = self.get_mode() self.assertEqual(modestr, '-rwx------') self.assertS_IS("REG", st_mode) - self.assertEqual(stat.S_IMODE(st_mode), - stat.S_IRWXU) + self.assertEqual(self.statmod.S_IMODE(st_mode), + self.statmod.S_IRWXU) os.chmod(TESTFN, 0o070) st_mode, modestr = self.get_mode() self.assertEqual(modestr, '----rwx---') self.assertS_IS("REG", st_mode) - self.assertEqual(stat.S_IMODE(st_mode), - stat.S_IRWXG) + self.assertEqual(self.statmod.S_IMODE(st_mode), + self.statmod.S_IRWXG) os.chmod(TESTFN, 0o007) st_mode, modestr = self.get_mode() self.assertEqual(modestr, '-------rwx') self.assertS_IS("REG", st_mode) - self.assertEqual(stat.S_IMODE(st_mode), - stat.S_IRWXO) + self.assertEqual(self.statmod.S_IMODE(st_mode), + self.statmod.S_IRWXO) os.chmod(TESTFN, 0o444) st_mode, modestr = self.get_mode() self.assertS_IS("REG", st_mode) self.assertEqual(modestr, '-r--r--r--') - self.assertEqual(stat.S_IMODE(st_mode), 0o444) + self.assertEqual(self.statmod.S_IMODE(st_mode), 0o444) else: os.chmod(TESTFN, 0o700) st_mode, modestr = self.get_mode() self.assertEqual(modestr[:3], '-rw') self.assertS_IS("REG", st_mode) - self.assertEqual(stat.S_IFMT(st_mode), - stat.S_IFREG) + self.assertEqual(self.statmod.S_IFMT(st_mode), + self.statmod.S_IFREG) def test_directory(self): os.mkdir(TESTFN) @@ -165,25 +169,34 @@ class TestFilemode(unittest.TestCase): def test_module_attributes(self): for key, value in self.stat_struct.items(): - modvalue = getattr(stat, key) + modvalue = getattr(self.statmod, key) self.assertEqual(value, modvalue, key) for key, value in self.permission_bits.items(): - modvalue = getattr(stat, key) + modvalue = getattr(self.statmod, key) self.assertEqual(value, modvalue, key) for key in self.file_flags: - modvalue = getattr(stat, key) + modvalue = getattr(self.statmod, key) self.assertIsInstance(modvalue, int) for key in self.formats: - modvalue = getattr(stat, key) + modvalue = getattr(self.statmod, key) self.assertIsInstance(modvalue, int) for key in self.format_funcs: - func = getattr(stat, key) + func = getattr(self.statmod, key) self.assertTrue(callable(func)) self.assertEqual(func(0), 0) -def test_main(): - run_unittest(TestFilemode) +class TestFilemodeCStat(TestFilemode, unittest.TestCase): + statmod = c_stat + + formats = TestFilemode.formats | {'S_IFDOOR', 'S_IFPORT', 'S_IFWHT'} + format_funcs = TestFilemode.format_funcs | {'S_ISDOOR', 'S_ISPORT', + 'S_ISWHT'} + + +class TestFilemodePyStat(TestFilemode, unittest.TestCase): + statmod = py_stat + if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_statistics.py b/Lib/test/test_statistics.py new file mode 100644 index 0000000000..6f5c9a649d --- /dev/null +++ b/Lib/test/test_statistics.py @@ -0,0 +1,1574 @@ +"""Test suite for statistics module, including helper NumericTestCase and +approx_equal function. + +""" + +import collections +import decimal +import doctest +import math +import random +import sys +import types +import unittest + +from decimal import Decimal +from fractions import Fraction + + +# Module to be tested. +import statistics + + +# === Helper functions and class === + +def _calc_errors(actual, expected): + """Return the absolute and relative errors between two numbers. + + >>> _calc_errors(100, 75) + (25, 0.25) + >>> _calc_errors(100, 100) + (0, 0.0) + + Returns the (absolute error, relative error) between the two arguments. + """ + base = max(abs(actual), abs(expected)) + abs_err = abs(actual - expected) + rel_err = abs_err/base if base else float('inf') + return (abs_err, rel_err) + + +def approx_equal(x, y, tol=1e-12, rel=1e-7): + """approx_equal(x, y [, tol [, rel]]) => True|False + + Return True if numbers x and y are approximately equal, to within some + margin of error, otherwise return False. Numbers which compare equal + will also compare approximately equal. + + x is approximately equal to y if the difference between them is less than + an absolute error tol or a relative error rel, whichever is bigger. + + If given, both tol and rel must be finite, non-negative numbers. If not + given, default values are tol=1e-12 and rel=1e-7. + + >>> approx_equal(1.2589, 1.2587, tol=0.0003, rel=0) + True + >>> approx_equal(1.2589, 1.2587, tol=0.0001, rel=0) + False + + Absolute error is defined as abs(x-y); if that is less than or equal to + tol, x and y are considered approximately equal. + + Relative error is defined as abs((x-y)/x) or abs((x-y)/y), whichever is + smaller, provided x or y are not zero. If that figure is less than or + equal to rel, x and y are considered approximately equal. + + Complex numbers are not directly supported. If you wish to compare to + complex numbers, extract their real and imaginary parts and compare them + individually. + + NANs always compare unequal, even with themselves. Infinities compare + approximately equal if they have the same sign (both positive or both + negative). Infinities with different signs compare unequal; so do + comparisons of infinities with finite numbers. + """ + if tol < 0 or rel < 0: + raise ValueError('error tolerances must be non-negative') + # NANs are never equal to anything, approximately or otherwise. + if math.isnan(x) or math.isnan(y): + return False + # Numbers which compare equal also compare approximately equal. + if x == y: + # This includes the case of two infinities with the same sign. + return True + if math.isinf(x) or math.isinf(y): + # This includes the case of two infinities of opposite sign, or + # one infinity and one finite number. + return False + # Two finite numbers. + actual_error = abs(x - y) + allowed_error = max(tol, rel*max(abs(x), abs(y))) + return actual_error <= allowed_error + + +# This class exists only as somewhere to stick a docstring containing +# doctests. The following docstring and tests were originally in a separate +# module. Now that it has been merged in here, I need somewhere to hang the. +# docstring. Ultimately, this class will die, and the information below will +# either become redundant, or be moved into more appropriate places. +class _DoNothing: + """ + When doing numeric work, especially with floats, exact equality is often + not what you want. Due to round-off error, it is often a bad idea to try + to compare floats with equality. Instead the usual procedure is to test + them with some (hopefully small!) allowance for error. + + The ``approx_equal`` function allows you to specify either an absolute + error tolerance, or a relative error, or both. + + Absolute error tolerances are simple, but you need to know the magnitude + of the quantities being compared: + + >>> approx_equal(12.345, 12.346, tol=1e-3) + True + >>> approx_equal(12.345e6, 12.346e6, tol=1e-3) # tol is too small. + False + + Relative errors are more suitable when the values you are comparing can + vary in magnitude: + + >>> approx_equal(12.345, 12.346, rel=1e-4) + True + >>> approx_equal(12.345e6, 12.346e6, rel=1e-4) + True + + but a naive implementation of relative error testing can run into trouble + around zero. + + If you supply both an absolute tolerance and a relative error, the + comparison succeeds if either individual test succeeds: + + >>> approx_equal(12.345e6, 12.346e6, tol=1e-3, rel=1e-4) + True + + """ + pass + + + +# We prefer this for testing numeric values that may not be exactly equal, +# and avoid using TestCase.assertAlmostEqual, because it sucks :-) + +class NumericTestCase(unittest.TestCase): + """Unit test class for numeric work. + + This subclasses TestCase. In addition to the standard method + ``TestCase.assertAlmostEqual``, ``assertApproxEqual`` is provided. + """ + # By default, we expect exact equality, unless overridden. + tol = rel = 0 + + def assertApproxEqual( + self, first, second, tol=None, rel=None, msg=None + ): + """Test passes if ``first`` and ``second`` are approximately equal. + + This test passes if ``first`` and ``second`` are equal to + within ``tol``, an absolute error, or ``rel``, a relative error. + + If either ``tol`` or ``rel`` are None or not given, they default to + test attributes of the same name (by default, 0). + + The objects may be either numbers, or sequences of numbers. Sequences + are tested element-by-element. + + >>> class MyTest(NumericTestCase): + ... def test_number(self): + ... x = 1.0/6 + ... y = sum([x]*6) + ... self.assertApproxEqual(y, 1.0, tol=1e-15) + ... def test_sequence(self): + ... a = [1.001, 1.001e-10, 1.001e10] + ... b = [1.0, 1e-10, 1e10] + ... self.assertApproxEqual(a, b, rel=1e-3) + ... + >>> import unittest + >>> from io import StringIO # Suppress test runner output. + >>> suite = unittest.TestLoader().loadTestsFromTestCase(MyTest) + >>> unittest.TextTestRunner(stream=StringIO()).run(suite) + <unittest.runner.TextTestResult run=2 errors=0 failures=0> + + """ + if tol is None: + tol = self.tol + if rel is None: + rel = self.rel + if ( + isinstance(first, collections.Sequence) and + isinstance(second, collections.Sequence) + ): + check = self._check_approx_seq + else: + check = self._check_approx_num + check(first, second, tol, rel, msg) + + def _check_approx_seq(self, first, second, tol, rel, msg): + if len(first) != len(second): + standardMsg = ( + "sequences differ in length: %d items != %d items" + % (len(first), len(second)) + ) + msg = self._formatMessage(msg, standardMsg) + raise self.failureException(msg) + for i, (a,e) in enumerate(zip(first, second)): + self._check_approx_num(a, e, tol, rel, msg, i) + + def _check_approx_num(self, first, second, tol, rel, msg, idx=None): + if approx_equal(first, second, tol, rel): + # Test passes. Return early, we are done. + return None + # Otherwise we failed. + standardMsg = self._make_std_err_msg(first, second, tol, rel, idx) + msg = self._formatMessage(msg, standardMsg) + raise self.failureException(msg) + + @staticmethod + def _make_std_err_msg(first, second, tol, rel, idx): + # Create the standard error message for approx_equal failures. + assert first != second + template = ( + ' %r != %r\n' + ' values differ by more than tol=%r and rel=%r\n' + ' -> absolute error = %r\n' + ' -> relative error = %r' + ) + if idx is not None: + header = 'numeric sequences first differ at index %d.\n' % idx + template = header + template + # Calculate actual errors: + abs_err, rel_err = _calc_errors(first, second) + return template % (first, second, tol, rel, abs_err, rel_err) + + +# ======================== +# === Test the helpers === +# ======================== + + +# --- Tests for approx_equal --- + +class ApproxEqualSymmetryTest(unittest.TestCase): + # Test symmetry of approx_equal. + + def test_relative_symmetry(self): + # Check that approx_equal treats relative error symmetrically. + # (a-b)/a is usually not equal to (a-b)/b. Ensure that this + # doesn't matter. + # + # Note: the reason for this test is that an early version + # of approx_equal was not symmetric. A relative error test + # would pass, or fail, depending on which value was passed + # as the first argument. + # + args1 = [2456, 37.8, -12.45, Decimal('2.54'), Fraction(17, 54)] + args2 = [2459, 37.2, -12.41, Decimal('2.59'), Fraction(15, 54)] + assert len(args1) == len(args2) + for a, b in zip(args1, args2): + self.do_relative_symmetry(a, b) + + def do_relative_symmetry(self, a, b): + a, b = min(a, b), max(a, b) + assert a < b + delta = b - a # The absolute difference between the values. + rel_err1, rel_err2 = abs(delta/a), abs(delta/b) + # Choose an error margin halfway between the two. + rel = (rel_err1 + rel_err2)/2 + # Now see that values a and b compare approx equal regardless of + # which is given first. + self.assertTrue(approx_equal(a, b, tol=0, rel=rel)) + self.assertTrue(approx_equal(b, a, tol=0, rel=rel)) + + def test_symmetry(self): + # Test that approx_equal(a, b) == approx_equal(b, a) + args = [-23, -2, 5, 107, 93568] + delta = 2 + for a in args: + for type_ in (int, float, Decimal, Fraction): + x = type_(a)*100 + y = x + delta + r = abs(delta/max(x, y)) + # There are five cases to check: + # 1) actual error <= tol, <= rel + self.do_symmetry_test(x, y, tol=delta, rel=r) + self.do_symmetry_test(x, y, tol=delta+1, rel=2*r) + # 2) actual error > tol, > rel + self.do_symmetry_test(x, y, tol=delta-1, rel=r/2) + # 3) actual error <= tol, > rel + self.do_symmetry_test(x, y, tol=delta, rel=r/2) + # 4) actual error > tol, <= rel + self.do_symmetry_test(x, y, tol=delta-1, rel=r) + self.do_symmetry_test(x, y, tol=delta-1, rel=2*r) + # 5) exact equality test + self.do_symmetry_test(x, x, tol=0, rel=0) + self.do_symmetry_test(x, y, tol=0, rel=0) + + def do_symmetry_test(self, a, b, tol, rel): + template = "approx_equal comparisons don't match for %r" + flag1 = approx_equal(a, b, tol, rel) + flag2 = approx_equal(b, a, tol, rel) + self.assertEqual(flag1, flag2, template.format((a, b, tol, rel))) + + +class ApproxEqualExactTest(unittest.TestCase): + # Test the approx_equal function with exactly equal values. + # Equal values should compare as approximately equal. + # Test cases for exactly equal values, which should compare approx + # equal regardless of the error tolerances given. + + def do_exactly_equal_test(self, x, tol, rel): + result = approx_equal(x, x, tol=tol, rel=rel) + self.assertTrue(result, 'equality failure for x=%r' % x) + result = approx_equal(-x, -x, tol=tol, rel=rel) + self.assertTrue(result, 'equality failure for x=%r' % -x) + + def test_exactly_equal_ints(self): + # Test that equal int values are exactly equal. + for n in [42, 19740, 14974, 230, 1795, 700245, 36587]: + self.do_exactly_equal_test(n, 0, 0) + + def test_exactly_equal_floats(self): + # Test that equal float values are exactly equal. + for x in [0.42, 1.9740, 1497.4, 23.0, 179.5, 70.0245, 36.587]: + self.do_exactly_equal_test(x, 0, 0) + + def test_exactly_equal_fractions(self): + # Test that equal Fraction values are exactly equal. + F = Fraction + for f in [F(1, 2), F(0), F(5, 3), F(9, 7), F(35, 36), F(3, 7)]: + self.do_exactly_equal_test(f, 0, 0) + + def test_exactly_equal_decimals(self): + # Test that equal Decimal values are exactly equal. + D = Decimal + for d in map(D, "8.2 31.274 912.04 16.745 1.2047".split()): + self.do_exactly_equal_test(d, 0, 0) + + def test_exactly_equal_absolute(self): + # Test that equal values are exactly equal with an absolute error. + for n in [16, 1013, 1372, 1198, 971, 4]: + # Test as ints. + self.do_exactly_equal_test(n, 0.01, 0) + # Test as floats. + self.do_exactly_equal_test(n/10, 0.01, 0) + # Test as Fractions. + f = Fraction(n, 1234) + self.do_exactly_equal_test(f, 0.01, 0) + + def test_exactly_equal_absolute_decimals(self): + # Test equal Decimal values are exactly equal with an absolute error. + self.do_exactly_equal_test(Decimal("3.571"), Decimal("0.01"), 0) + self.do_exactly_equal_test(-Decimal("81.3971"), Decimal("0.01"), 0) + + def test_exactly_equal_relative(self): + # Test that equal values are exactly equal with a relative error. + for x in [8347, 101.3, -7910.28, Fraction(5, 21)]: + self.do_exactly_equal_test(x, 0, 0.01) + self.do_exactly_equal_test(Decimal("11.68"), 0, Decimal("0.01")) + + def test_exactly_equal_both(self): + # Test that equal values are equal when both tol and rel are given. + for x in [41017, 16.742, -813.02, Fraction(3, 8)]: + self.do_exactly_equal_test(x, 0.1, 0.01) + D = Decimal + self.do_exactly_equal_test(D("7.2"), D("0.1"), D("0.01")) + + +class ApproxEqualUnequalTest(unittest.TestCase): + # Unequal values should compare unequal with zero error tolerances. + # Test cases for unequal values, with exact equality test. + + def do_exactly_unequal_test(self, x): + for a in (x, -x): + result = approx_equal(a, a+1, tol=0, rel=0) + self.assertFalse(result, 'inequality failure for x=%r' % a) + + def test_exactly_unequal_ints(self): + # Test unequal int values are unequal with zero error tolerance. + for n in [951, 572305, 478, 917, 17240]: + self.do_exactly_unequal_test(n) + + def test_exactly_unequal_floats(self): + # Test unequal float values are unequal with zero error tolerance. + for x in [9.51, 5723.05, 47.8, 9.17, 17.24]: + self.do_exactly_unequal_test(x) + + def test_exactly_unequal_fractions(self): + # Test that unequal Fractions are unequal with zero error tolerance. + F = Fraction + for f in [F(1, 5), F(7, 9), F(12, 11), F(101, 99023)]: + self.do_exactly_unequal_test(f) + + def test_exactly_unequal_decimals(self): + # Test that unequal Decimals are unequal with zero error tolerance. + for d in map(Decimal, "3.1415 298.12 3.47 18.996 0.00245".split()): + self.do_exactly_unequal_test(d) + + +class ApproxEqualInexactTest(unittest.TestCase): + # Inexact test cases for approx_error. + # Test cases when comparing two values that are not exactly equal. + + # === Absolute error tests === + + def do_approx_equal_abs_test(self, x, delta): + template = "Test failure for x={!r}, y={!r}" + for y in (x + delta, x - delta): + msg = template.format(x, y) + self.assertTrue(approx_equal(x, y, tol=2*delta, rel=0), msg) + self.assertFalse(approx_equal(x, y, tol=delta/2, rel=0), msg) + + def test_approx_equal_absolute_ints(self): + # Test approximate equality of ints with an absolute error. + for n in [-10737, -1975, -7, -2, 0, 1, 9, 37, 423, 9874, 23789110]: + self.do_approx_equal_abs_test(n, 10) + self.do_approx_equal_abs_test(n, 2) + + def test_approx_equal_absolute_floats(self): + # Test approximate equality of floats with an absolute error. + for x in [-284.126, -97.1, -3.4, -2.15, 0.5, 1.0, 7.8, 4.23, 3817.4]: + self.do_approx_equal_abs_test(x, 1.5) + self.do_approx_equal_abs_test(x, 0.01) + self.do_approx_equal_abs_test(x, 0.0001) + + def test_approx_equal_absolute_fractions(self): + # Test approximate equality of Fractions with an absolute error. + delta = Fraction(1, 29) + numerators = [-84, -15, -2, -1, 0, 1, 5, 17, 23, 34, 71] + for f in (Fraction(n, 29) for n in numerators): + self.do_approx_equal_abs_test(f, delta) + self.do_approx_equal_abs_test(f, float(delta)) + + def test_approx_equal_absolute_decimals(self): + # Test approximate equality of Decimals with an absolute error. + delta = Decimal("0.01") + for d in map(Decimal, "1.0 3.5 36.08 61.79 7912.3648".split()): + self.do_approx_equal_abs_test(d, delta) + self.do_approx_equal_abs_test(-d, delta) + + def test_cross_zero(self): + # Test for the case of the two values having opposite signs. + self.assertTrue(approx_equal(1e-5, -1e-5, tol=1e-4, rel=0)) + + # === Relative error tests === + + def do_approx_equal_rel_test(self, x, delta): + template = "Test failure for x={!r}, y={!r}" + for y in (x*(1+delta), x*(1-delta)): + msg = template.format(x, y) + self.assertTrue(approx_equal(x, y, tol=0, rel=2*delta), msg) + self.assertFalse(approx_equal(x, y, tol=0, rel=delta/2), msg) + + def test_approx_equal_relative_ints(self): + # Test approximate equality of ints with a relative error. + self.assertTrue(approx_equal(64, 47, tol=0, rel=0.36)) + self.assertTrue(approx_equal(64, 47, tol=0, rel=0.37)) + # --- + self.assertTrue(approx_equal(449, 512, tol=0, rel=0.125)) + self.assertTrue(approx_equal(448, 512, tol=0, rel=0.125)) + self.assertFalse(approx_equal(447, 512, tol=0, rel=0.125)) + + def test_approx_equal_relative_floats(self): + # Test approximate equality of floats with a relative error. + for x in [-178.34, -0.1, 0.1, 1.0, 36.97, 2847.136, 9145.074]: + self.do_approx_equal_rel_test(x, 0.02) + self.do_approx_equal_rel_test(x, 0.0001) + + def test_approx_equal_relative_fractions(self): + # Test approximate equality of Fractions with a relative error. + F = Fraction + delta = Fraction(3, 8) + for f in [F(3, 84), F(17, 30), F(49, 50), F(92, 85)]: + for d in (delta, float(delta)): + self.do_approx_equal_rel_test(f, d) + self.do_approx_equal_rel_test(-f, d) + + def test_approx_equal_relative_decimals(self): + # Test approximate equality of Decimals with a relative error. + for d in map(Decimal, "0.02 1.0 5.7 13.67 94.138 91027.9321".split()): + self.do_approx_equal_rel_test(d, Decimal("0.001")) + self.do_approx_equal_rel_test(-d, Decimal("0.05")) + + # === Both absolute and relative error tests === + + # There are four cases to consider: + # 1) actual error <= both absolute and relative error + # 2) actual error <= absolute error but > relative error + # 3) actual error <= relative error but > absolute error + # 4) actual error > both absolute and relative error + + def do_check_both(self, a, b, tol, rel, tol_flag, rel_flag): + check = self.assertTrue if tol_flag else self.assertFalse + check(approx_equal(a, b, tol=tol, rel=0)) + check = self.assertTrue if rel_flag else self.assertFalse + check(approx_equal(a, b, tol=0, rel=rel)) + check = self.assertTrue if (tol_flag or rel_flag) else self.assertFalse + check(approx_equal(a, b, tol=tol, rel=rel)) + + def test_approx_equal_both1(self): + # Test actual error <= both absolute and relative error. + self.do_check_both(7.955, 7.952, 0.004, 3.8e-4, True, True) + self.do_check_both(-7.387, -7.386, 0.002, 0.0002, True, True) + + def test_approx_equal_both2(self): + # Test actual error <= absolute error but > relative error. + self.do_check_both(7.955, 7.952, 0.004, 3.7e-4, True, False) + + def test_approx_equal_both3(self): + # Test actual error <= relative error but > absolute error. + self.do_check_both(7.955, 7.952, 0.001, 3.8e-4, False, True) + + def test_approx_equal_both4(self): + # Test actual error > both absolute and relative error. + self.do_check_both(2.78, 2.75, 0.01, 0.001, False, False) + self.do_check_both(971.44, 971.47, 0.02, 3e-5, False, False) + + +class ApproxEqualSpecialsTest(unittest.TestCase): + # Test approx_equal with NANs and INFs and zeroes. + + def test_inf(self): + for type_ in (float, Decimal): + inf = type_('inf') + self.assertTrue(approx_equal(inf, inf)) + self.assertTrue(approx_equal(inf, inf, 0, 0)) + self.assertTrue(approx_equal(inf, inf, 1, 0.01)) + self.assertTrue(approx_equal(-inf, -inf)) + self.assertFalse(approx_equal(inf, -inf)) + self.assertFalse(approx_equal(inf, 1000)) + + def test_nan(self): + for type_ in (float, Decimal): + nan = type_('nan') + for other in (nan, type_('inf'), 1000): + self.assertFalse(approx_equal(nan, other)) + + def test_float_zeroes(self): + nzero = math.copysign(0.0, -1) + self.assertTrue(approx_equal(nzero, 0.0, tol=0.1, rel=0.1)) + + def test_decimal_zeroes(self): + nzero = Decimal("-0.0") + self.assertTrue(approx_equal(nzero, Decimal(0), tol=0.1, rel=0.1)) + + +class TestApproxEqualErrors(unittest.TestCase): + # Test error conditions of approx_equal. + + def test_bad_tol(self): + # Test negative tol raises. + self.assertRaises(ValueError, approx_equal, 100, 100, -1, 0.1) + + def test_bad_rel(self): + # Test negative rel raises. + self.assertRaises(ValueError, approx_equal, 100, 100, 1, -0.1) + + +# --- Tests for NumericTestCase --- + +# The formatting routine that generates the error messages is complex enough +# that it too needs testing. + +class TestNumericTestCase(unittest.TestCase): + # The exact wording of NumericTestCase error messages is *not* guaranteed, + # but we need to give them some sort of test to ensure that they are + # generated correctly. As a compromise, we look for specific substrings + # that are expected to be found even if the overall error message changes. + + def do_test(self, args): + actual_msg = NumericTestCase._make_std_err_msg(*args) + expected = self.generate_substrings(*args) + for substring in expected: + self.assertIn(substring, actual_msg) + + def test_numerictestcase_is_testcase(self): + # Ensure that NumericTestCase actually is a TestCase. + self.assertTrue(issubclass(NumericTestCase, unittest.TestCase)) + + def test_error_msg_numeric(self): + # Test the error message generated for numeric comparisons. + args = (2.5, 4.0, 0.5, 0.25, None) + self.do_test(args) + + def test_error_msg_sequence(self): + # Test the error message generated for sequence comparisons. + args = (3.75, 8.25, 1.25, 0.5, 7) + self.do_test(args) + + def generate_substrings(self, first, second, tol, rel, idx): + """Return substrings we expect to see in error messages.""" + abs_err, rel_err = _calc_errors(first, second) + substrings = [ + 'tol=%r' % tol, + 'rel=%r' % rel, + 'absolute error = %r' % abs_err, + 'relative error = %r' % rel_err, + ] + if idx is not None: + substrings.append('differ at index %d' % idx) + return substrings + + +# ======================================= +# === Tests for the statistics module === +# ======================================= + + +class GlobalsTest(unittest.TestCase): + module = statistics + expected_metadata = ["__doc__", "__all__"] + + def test_meta(self): + # Test for the existence of metadata. + for meta in self.expected_metadata: + self.assertTrue(hasattr(self.module, meta), + "%s not present" % meta) + + def test_check_all(self): + # Check everything in __all__ exists and is public. + module = self.module + for name in module.__all__: + # No private names in __all__: + self.assertFalse(name.startswith("_"), + 'private name "%s" in __all__' % name) + # And anything in __all__ must exist: + self.assertTrue(hasattr(module, name), + 'missing name "%s" in __all__' % name) + + +class DocTests(unittest.TestCase): + @unittest.skipIf(sys.flags.optimize >= 2, + "Docstrings are omitted with -OO and above") + def test_doc_tests(self): + failed, tried = doctest.testmod(statistics) + self.assertGreater(tried, 0) + self.assertEqual(failed, 0) + +class StatisticsErrorTest(unittest.TestCase): + def test_has_exception(self): + errmsg = ( + "Expected StatisticsError to be a ValueError, but got a" + " subclass of %r instead." + ) + self.assertTrue(hasattr(statistics, 'StatisticsError')) + self.assertTrue( + issubclass(statistics.StatisticsError, ValueError), + errmsg % statistics.StatisticsError.__base__ + ) + + +# === Tests for private utility functions === + +class ExactRatioTest(unittest.TestCase): + # Test _exact_ratio utility. + + def test_int(self): + for i in (-20, -3, 0, 5, 99, 10**20): + self.assertEqual(statistics._exact_ratio(i), (i, 1)) + + def test_fraction(self): + numerators = (-5, 1, 12, 38) + for n in numerators: + f = Fraction(n, 37) + self.assertEqual(statistics._exact_ratio(f), (n, 37)) + + def test_float(self): + self.assertEqual(statistics._exact_ratio(0.125), (1, 8)) + self.assertEqual(statistics._exact_ratio(1.125), (9, 8)) + data = [random.uniform(-100, 100) for _ in range(100)] + for x in data: + num, den = statistics._exact_ratio(x) + self.assertEqual(x, num/den) + + def test_decimal(self): + D = Decimal + _exact_ratio = statistics._exact_ratio + self.assertEqual(_exact_ratio(D("0.125")), (125, 1000)) + self.assertEqual(_exact_ratio(D("12.345")), (12345, 1000)) + self.assertEqual(_exact_ratio(D("-1.98")), (-198, 100)) + + +class DecimalToRatioTest(unittest.TestCase): + # Test _decimal_to_ratio private function. + + def testSpecialsRaise(self): + # Test that NANs and INFs raise ValueError. + # Non-special values are covered by _exact_ratio above. + for d in (Decimal('NAN'), Decimal('sNAN'), Decimal('INF')): + self.assertRaises(ValueError, statistics._decimal_to_ratio, d) + + def test_sign(self): + # Test sign is calculated correctly. + numbers = [Decimal("9.8765e12"), Decimal("9.8765e-12")] + for d in numbers: + # First test positive decimals. + assert d > 0 + num, den = statistics._decimal_to_ratio(d) + self.assertGreaterEqual(num, 0) + self.assertGreater(den, 0) + # Then test negative decimals. + num, den = statistics._decimal_to_ratio(-d) + self.assertLessEqual(num, 0) + self.assertGreater(den, 0) + + def test_negative_exponent(self): + # Test result when the exponent is negative. + t = statistics._decimal_to_ratio(Decimal("0.1234")) + self.assertEqual(t, (1234, 10000)) + + def test_positive_exponent(self): + # Test results when the exponent is positive. + t = statistics._decimal_to_ratio(Decimal("1.234e7")) + self.assertEqual(t, (12340000, 1)) + + def test_regression_20536(self): + # Regression test for issue 20536. + # See http://bugs.python.org/issue20536 + t = statistics._decimal_to_ratio(Decimal("1e2")) + self.assertEqual(t, (100, 1)) + t = statistics._decimal_to_ratio(Decimal("1.47e5")) + self.assertEqual(t, (147000, 1)) + + +class CheckTypeTest(unittest.TestCase): + # Test _check_type private function. + + def test_allowed(self): + # Test that a type which should be allowed is allowed. + allowed = set([int, float]) + statistics._check_type(int, allowed) + statistics._check_type(float, allowed) + + def test_not_allowed(self): + # Test that a type which should not be allowed raises. + allowed = set([int, float]) + self.assertRaises(TypeError, statistics._check_type, Decimal, allowed) + + def test_add_to_allowed(self): + # Test that a second type will be added to the allowed set. + allowed = set([int]) + statistics._check_type(float, allowed) + self.assertEqual(allowed, set([int, float])) + + +# === Tests for public functions === + +class UnivariateCommonMixin: + # Common tests for most univariate functions that take a data argument. + + def test_no_args(self): + # Fail if given no arguments. + self.assertRaises(TypeError, self.func) + + def test_empty_data(self): + # Fail when the data argument (first argument) is empty. + for empty in ([], (), iter([])): + self.assertRaises(statistics.StatisticsError, self.func, empty) + + def prepare_data(self): + """Return int data for various tests.""" + data = list(range(10)) + while data == sorted(data): + random.shuffle(data) + return data + + def test_no_inplace_modifications(self): + # Test that the function does not modify its input data. + data = self.prepare_data() + assert len(data) != 1 # Necessary to avoid infinite loop. + assert data != sorted(data) + saved = data[:] + assert data is not saved + _ = self.func(data) + self.assertListEqual(data, saved, "data has been modified") + + def test_order_doesnt_matter(self): + # Test that the order of data points doesn't change the result. + + # CAUTION: due to floating point rounding errors, the result actually + # may depend on the order. Consider this test representing an ideal. + # To avoid this test failing, only test with exact values such as ints + # or Fractions. + data = [1, 2, 3, 3, 3, 4, 5, 6]*100 + expected = self.func(data) + random.shuffle(data) + actual = self.func(data) + self.assertEqual(expected, actual) + + def test_type_of_data_collection(self): + # Test that the type of iterable data doesn't effect the result. + class MyList(list): + pass + class MyTuple(tuple): + pass + def generator(data): + return (obj for obj in data) + data = self.prepare_data() + expected = self.func(data) + for kind in (list, tuple, iter, MyList, MyTuple, generator): + result = self.func(kind(data)) + self.assertEqual(result, expected) + + def test_range_data(self): + # Test that functions work with range objects. + data = range(20, 50, 3) + expected = self.func(list(data)) + self.assertEqual(self.func(data), expected) + + def test_bad_arg_types(self): + # Test that function raises when given data of the wrong type. + + # Don't roll the following into a loop like this: + # for bad in list_of_bad: + # self.check_for_type_error(bad) + # + # Since assertRaises doesn't show the arguments that caused the test + # failure, it is very difficult to debug these test failures when the + # following are in a loop. + self.check_for_type_error(None) + self.check_for_type_error(23) + self.check_for_type_error(42.0) + self.check_for_type_error(object()) + + def check_for_type_error(self, *args): + self.assertRaises(TypeError, self.func, *args) + + def test_type_of_data_element(self): + # Check the type of data elements doesn't affect the numeric result. + # This is a weaker test than UnivariateTypeMixin.testTypesConserved, + # because it checks the numeric result by equality, but not by type. + class MyFloat(float): + def __truediv__(self, other): + return type(self)(super().__truediv__(other)) + def __add__(self, other): + return type(self)(super().__add__(other)) + __radd__ = __add__ + + raw = self.prepare_data() + expected = self.func(raw) + for kind in (float, MyFloat, Decimal, Fraction): + data = [kind(x) for x in raw] + result = type(expected)(self.func(data)) + self.assertEqual(result, expected) + + +class UnivariateTypeMixin: + """Mixin class for type-conserving functions. + + This mixin class holds test(s) for functions which conserve the type of + individual data points. E.g. the mean of a list of Fractions should itself + be a Fraction. + + Not all tests to do with types need go in this class. Only those that + rely on the function returning the same type as its input data. + """ + def test_types_conserved(self): + # Test that functions keeps the same type as their data points. + # (Excludes mixed data types.) This only tests the type of the return + # result, not the value. + class MyFloat(float): + def __truediv__(self, other): + return type(self)(super().__truediv__(other)) + def __sub__(self, other): + return type(self)(super().__sub__(other)) + def __rsub__(self, other): + return type(self)(super().__rsub__(other)) + def __pow__(self, other): + return type(self)(super().__pow__(other)) + def __add__(self, other): + return type(self)(super().__add__(other)) + __radd__ = __add__ + + data = self.prepare_data() + for kind in (float, Decimal, Fraction, MyFloat): + d = [kind(x) for x in data] + result = self.func(d) + self.assertIs(type(result), kind) + + +class TestSum(NumericTestCase, UnivariateCommonMixin, UnivariateTypeMixin): + # Test cases for statistics._sum() function. + + def setUp(self): + self.func = statistics._sum + + def test_empty_data(self): + # Override test for empty data. + for data in ([], (), iter([])): + self.assertEqual(self.func(data), 0) + self.assertEqual(self.func(data, 23), 23) + self.assertEqual(self.func(data, 2.3), 2.3) + + def test_ints(self): + self.assertEqual(self.func([1, 5, 3, -4, -8, 20, 42, 1]), 60) + self.assertEqual(self.func([4, 2, 3, -8, 7], 1000), 1008) + + def test_floats(self): + self.assertEqual(self.func([0.25]*20), 5.0) + self.assertEqual(self.func([0.125, 0.25, 0.5, 0.75], 1.5), 3.125) + + def test_fractions(self): + F = Fraction + self.assertEqual(self.func([Fraction(1, 1000)]*500), Fraction(1, 2)) + + def test_decimals(self): + D = Decimal + data = [D("0.001"), D("5.246"), D("1.702"), D("-0.025"), + D("3.974"), D("2.328"), D("4.617"), D("2.843"), + ] + self.assertEqual(self.func(data), Decimal("20.686")) + + def test_compare_with_math_fsum(self): + # Compare with the math.fsum function. + # Ideally we ought to get the exact same result, but sometimes + # we differ by a very slight amount :-( + data = [random.uniform(-100, 1000) for _ in range(1000)] + self.assertApproxEqual(self.func(data), math.fsum(data), rel=2e-16) + + def test_start_argument(self): + # Test that the optional start argument works correctly. + data = [random.uniform(1, 1000) for _ in range(100)] + t = self.func(data) + self.assertEqual(t+42, self.func(data, 42)) + self.assertEqual(t-23, self.func(data, -23)) + self.assertEqual(t+1e20, self.func(data, 1e20)) + + def test_strings_fail(self): + # Sum of strings should fail. + self.assertRaises(TypeError, self.func, [1, 2, 3], '999') + self.assertRaises(TypeError, self.func, [1, 2, 3, '999']) + + def test_bytes_fail(self): + # Sum of bytes should fail. + self.assertRaises(TypeError, self.func, [1, 2, 3], b'999') + self.assertRaises(TypeError, self.func, [1, 2, 3, b'999']) + + def test_mixed_sum(self): + # Mixed input types are not (currently) allowed. + # Check that mixed data types fail. + self.assertRaises(TypeError, self.func, [1, 2.0, Fraction(1, 2)]) + # And so does mixed start argument. + self.assertRaises(TypeError, self.func, [1, 2.0], Decimal(1)) + + +class SumTortureTest(NumericTestCase): + def test_torture(self): + # Tim Peters' torture test for sum, and variants of same. + self.assertEqual(statistics._sum([1, 1e100, 1, -1e100]*10000), 20000.0) + self.assertEqual(statistics._sum([1e100, 1, 1, -1e100]*10000), 20000.0) + self.assertApproxEqual( + statistics._sum([1e-100, 1, 1e-100, -1]*10000), 2.0e-96, rel=5e-16 + ) + + +class SumSpecialValues(NumericTestCase): + # Test that sum works correctly with IEEE-754 special values. + + def test_nan(self): + for type_ in (float, Decimal): + nan = type_('nan') + result = statistics._sum([1, nan, 2]) + self.assertIs(type(result), type_) + self.assertTrue(math.isnan(result)) + + def check_infinity(self, x, inf): + """Check x is an infinity of the same type and sign as inf.""" + self.assertTrue(math.isinf(x)) + self.assertIs(type(x), type(inf)) + self.assertEqual(x > 0, inf > 0) + assert x == inf + + def do_test_inf(self, inf): + # Adding a single infinity gives infinity. + result = statistics._sum([1, 2, inf, 3]) + self.check_infinity(result, inf) + # Adding two infinities of the same sign also gives infinity. + result = statistics._sum([1, 2, inf, 3, inf, 4]) + self.check_infinity(result, inf) + + def test_float_inf(self): + inf = float('inf') + for sign in (+1, -1): + self.do_test_inf(sign*inf) + + def test_decimal_inf(self): + inf = Decimal('inf') + for sign in (+1, -1): + self.do_test_inf(sign*inf) + + def test_float_mismatched_infs(self): + # Test that adding two infinities of opposite sign gives a NAN. + inf = float('inf') + result = statistics._sum([1, 2, inf, 3, -inf, 4]) + self.assertTrue(math.isnan(result)) + + def test_decimal_mismatched_infs_to_nan(self): + # Test adding Decimal INFs with opposite sign returns NAN. + inf = Decimal('inf') + data = [1, 2, inf, 3, -inf, 4] + with decimal.localcontext(decimal.ExtendedContext): + self.assertTrue(math.isnan(statistics._sum(data))) + + def test_decimal_mismatched_infs_to_nan(self): + # Test adding Decimal INFs with opposite sign raises InvalidOperation. + inf = Decimal('inf') + data = [1, 2, inf, 3, -inf, 4] + with decimal.localcontext(decimal.BasicContext): + self.assertRaises(decimal.InvalidOperation, statistics._sum, data) + + def test_decimal_snan_raises(self): + # Adding sNAN should raise InvalidOperation. + sNAN = Decimal('sNAN') + data = [1, sNAN, 2] + self.assertRaises(decimal.InvalidOperation, statistics._sum, data) + + +# === Tests for averages === + +class AverageMixin(UnivariateCommonMixin): + # Mixin class holding common tests for averages. + + def test_single_value(self): + # Average of a single value is the value itself. + for x in (23, 42.5, 1.3e15, Fraction(15, 19), Decimal('0.28')): + self.assertEqual(self.func([x]), x) + + def test_repeated_single_value(self): + # The average of a single repeated value is the value itself. + for x in (3.5, 17, 2.5e15, Fraction(61, 67), Decimal('4.9712')): + for count in (2, 5, 10, 20): + data = [x]*count + self.assertEqual(self.func(data), x) + + +class TestMean(NumericTestCase, AverageMixin, UnivariateTypeMixin): + def setUp(self): + self.func = statistics.mean + + def test_torture_pep(self): + # "Torture Test" from PEP-450. + self.assertEqual(self.func([1e100, 1, 3, -1e100]), 1) + + def test_ints(self): + # Test mean with ints. + data = [0, 1, 2, 3, 3, 3, 4, 5, 5, 6, 7, 7, 7, 7, 8, 9] + random.shuffle(data) + self.assertEqual(self.func(data), 4.8125) + + def test_floats(self): + # Test mean with floats. + data = [17.25, 19.75, 20.0, 21.5, 21.75, 23.25, 25.125, 27.5] + random.shuffle(data) + self.assertEqual(self.func(data), 22.015625) + + def test_decimals(self): + # Test mean with ints. + D = Decimal + data = [D("1.634"), D("2.517"), D("3.912"), D("4.072"), D("5.813")] + random.shuffle(data) + self.assertEqual(self.func(data), D("3.5896")) + + def test_fractions(self): + # Test mean with Fractions. + F = Fraction + data = [F(1, 2), F(2, 3), F(3, 4), F(4, 5), F(5, 6), F(6, 7), F(7, 8)] + random.shuffle(data) + self.assertEqual(self.func(data), F(1479, 1960)) + + def test_inf(self): + # Test mean with infinities. + raw = [1, 3, 5, 7, 9] # Use only ints, to avoid TypeError later. + for kind in (float, Decimal): + for sign in (1, -1): + inf = kind("inf")*sign + data = raw + [inf] + result = self.func(data) + self.assertTrue(math.isinf(result)) + self.assertEqual(result, inf) + + def test_mismatched_infs(self): + # Test mean with infinities of opposite sign. + data = [2, 4, 6, float('inf'), 1, 3, 5, float('-inf')] + result = self.func(data) + self.assertTrue(math.isnan(result)) + + def test_nan(self): + # Test mean with NANs. + raw = [1, 3, 5, 7, 9] # Use only ints, to avoid TypeError later. + for kind in (float, Decimal): + inf = kind("nan") + data = raw + [inf] + result = self.func(data) + self.assertTrue(math.isnan(result)) + + def test_big_data(self): + # Test adding a large constant to every data point. + c = 1e9 + data = [3.4, 4.5, 4.9, 6.7, 6.8, 7.2, 8.0, 8.1, 9.4] + expected = self.func(data) + c + assert expected != c + result = self.func([x+c for x in data]) + self.assertEqual(result, expected) + + def test_doubled_data(self): + # Mean of [a,b,c...z] should be same as for [a,a,b,b,c,c...z,z]. + data = [random.uniform(-3, 5) for _ in range(1000)] + expected = self.func(data) + actual = self.func(data*2) + self.assertApproxEqual(actual, expected) + + def test_regression_20561(self): + # Regression test for issue 20561. + # See http://bugs.python.org/issue20561 + d = Decimal('1e4') + self.assertEqual(statistics.mean([d]), d) + + +class TestMedian(NumericTestCase, AverageMixin): + # Common tests for median and all median.* functions. + def setUp(self): + self.func = statistics.median + + def prepare_data(self): + """Overload method from UnivariateCommonMixin.""" + data = super().prepare_data() + if len(data)%2 != 1: + data.append(2) + return data + + def test_even_ints(self): + # Test median with an even number of int data points. + data = [1, 2, 3, 4, 5, 6] + assert len(data)%2 == 0 + self.assertEqual(self.func(data), 3.5) + + def test_odd_ints(self): + # Test median with an odd number of int data points. + data = [1, 2, 3, 4, 5, 6, 9] + assert len(data)%2 == 1 + self.assertEqual(self.func(data), 4) + + def test_odd_fractions(self): + # Test median works with an odd number of Fractions. + F = Fraction + data = [F(1, 7), F(2, 7), F(3, 7), F(4, 7), F(5, 7)] + assert len(data)%2 == 1 + random.shuffle(data) + self.assertEqual(self.func(data), F(3, 7)) + + def test_even_fractions(self): + # Test median works with an even number of Fractions. + F = Fraction + data = [F(1, 7), F(2, 7), F(3, 7), F(4, 7), F(5, 7), F(6, 7)] + assert len(data)%2 == 0 + random.shuffle(data) + self.assertEqual(self.func(data), F(1, 2)) + + def test_odd_decimals(self): + # Test median works with an odd number of Decimals. + D = Decimal + data = [D('2.5'), D('3.1'), D('4.2'), D('5.7'), D('5.8')] + assert len(data)%2 == 1 + random.shuffle(data) + self.assertEqual(self.func(data), D('4.2')) + + def test_even_decimals(self): + # Test median works with an even number of Decimals. + D = Decimal + data = [D('1.2'), D('2.5'), D('3.1'), D('4.2'), D('5.7'), D('5.8')] + assert len(data)%2 == 0 + random.shuffle(data) + self.assertEqual(self.func(data), D('3.65')) + + +class TestMedianDataType(NumericTestCase, UnivariateTypeMixin): + # Test conservation of data element type for median. + def setUp(self): + self.func = statistics.median + + def prepare_data(self): + data = list(range(15)) + assert len(data)%2 == 1 + while data == sorted(data): + random.shuffle(data) + return data + + +class TestMedianLow(TestMedian, UnivariateTypeMixin): + def setUp(self): + self.func = statistics.median_low + + def test_even_ints(self): + # Test median_low with an even number of ints. + data = [1, 2, 3, 4, 5, 6] + assert len(data)%2 == 0 + self.assertEqual(self.func(data), 3) + + def test_even_fractions(self): + # Test median_low works with an even number of Fractions. + F = Fraction + data = [F(1, 7), F(2, 7), F(3, 7), F(4, 7), F(5, 7), F(6, 7)] + assert len(data)%2 == 0 + random.shuffle(data) + self.assertEqual(self.func(data), F(3, 7)) + + def test_even_decimals(self): + # Test median_low works with an even number of Decimals. + D = Decimal + data = [D('1.1'), D('2.2'), D('3.3'), D('4.4'), D('5.5'), D('6.6')] + assert len(data)%2 == 0 + random.shuffle(data) + self.assertEqual(self.func(data), D('3.3')) + + +class TestMedianHigh(TestMedian, UnivariateTypeMixin): + def setUp(self): + self.func = statistics.median_high + + def test_even_ints(self): + # Test median_high with an even number of ints. + data = [1, 2, 3, 4, 5, 6] + assert len(data)%2 == 0 + self.assertEqual(self.func(data), 4) + + def test_even_fractions(self): + # Test median_high works with an even number of Fractions. + F = Fraction + data = [F(1, 7), F(2, 7), F(3, 7), F(4, 7), F(5, 7), F(6, 7)] + assert len(data)%2 == 0 + random.shuffle(data) + self.assertEqual(self.func(data), F(4, 7)) + + def test_even_decimals(self): + # Test median_high works with an even number of Decimals. + D = Decimal + data = [D('1.1'), D('2.2'), D('3.3'), D('4.4'), D('5.5'), D('6.6')] + assert len(data)%2 == 0 + random.shuffle(data) + self.assertEqual(self.func(data), D('4.4')) + + +class TestMedianGrouped(TestMedian): + # Test median_grouped. + # Doesn't conserve data element types, so don't use TestMedianType. + def setUp(self): + self.func = statistics.median_grouped + + def test_odd_number_repeated(self): + # Test median.grouped with repeated median values. + data = [12, 13, 14, 14, 14, 15, 15] + assert len(data)%2 == 1 + self.assertEqual(self.func(data), 14) + #--- + data = [12, 13, 14, 14, 14, 14, 15] + assert len(data)%2 == 1 + self.assertEqual(self.func(data), 13.875) + #--- + data = [5, 10, 10, 15, 20, 20, 20, 20, 25, 25, 30] + assert len(data)%2 == 1 + self.assertEqual(self.func(data, 5), 19.375) + #--- + data = [16, 18, 18, 18, 18, 20, 20, 20, 22, 22, 22, 24, 24, 26, 28] + assert len(data)%2 == 1 + self.assertApproxEqual(self.func(data, 2), 20.66666667, tol=1e-8) + + def test_even_number_repeated(self): + # Test median.grouped with repeated median values. + data = [5, 10, 10, 15, 20, 20, 20, 25, 25, 30] + assert len(data)%2 == 0 + self.assertApproxEqual(self.func(data, 5), 19.16666667, tol=1e-8) + #--- + data = [2, 3, 4, 4, 4, 5] + assert len(data)%2 == 0 + self.assertApproxEqual(self.func(data), 3.83333333, tol=1e-8) + #--- + data = [2, 3, 3, 4, 4, 4, 5, 5, 5, 5, 6, 6] + assert len(data)%2 == 0 + self.assertEqual(self.func(data), 4.5) + #--- + data = [3, 4, 4, 4, 5, 5, 5, 5, 6, 6] + assert len(data)%2 == 0 + self.assertEqual(self.func(data), 4.75) + + def test_repeated_single_value(self): + # Override method from AverageMixin. + # Yet again, failure of median_grouped to conserve the data type + # causes me headaches :-( + for x in (5.3, 68, 4.3e17, Fraction(29, 101), Decimal('32.9714')): + for count in (2, 5, 10, 20): + data = [x]*count + self.assertEqual(self.func(data), float(x)) + + def test_odd_fractions(self): + # Test median_grouped works with an odd number of Fractions. + F = Fraction + data = [F(5, 4), F(9, 4), F(13, 4), F(13, 4), F(17, 4)] + assert len(data)%2 == 1 + random.shuffle(data) + self.assertEqual(self.func(data), 3.0) + + def test_even_fractions(self): + # Test median_grouped works with an even number of Fractions. + F = Fraction + data = [F(5, 4), F(9, 4), F(13, 4), F(13, 4), F(17, 4), F(17, 4)] + assert len(data)%2 == 0 + random.shuffle(data) + self.assertEqual(self.func(data), 3.25) + + def test_odd_decimals(self): + # Test median_grouped works with an odd number of Decimals. + D = Decimal + data = [D('5.5'), D('6.5'), D('6.5'), D('7.5'), D('8.5')] + assert len(data)%2 == 1 + random.shuffle(data) + self.assertEqual(self.func(data), 6.75) + + def test_even_decimals(self): + # Test median_grouped works with an even number of Decimals. + D = Decimal + data = [D('5.5'), D('5.5'), D('6.5'), D('6.5'), D('7.5'), D('8.5')] + assert len(data)%2 == 0 + random.shuffle(data) + self.assertEqual(self.func(data), 6.5) + #--- + data = [D('5.5'), D('5.5'), D('6.5'), D('7.5'), D('7.5'), D('8.5')] + assert len(data)%2 == 0 + random.shuffle(data) + self.assertEqual(self.func(data), 7.0) + + def test_interval(self): + # Test median_grouped with interval argument. + data = [2.25, 2.5, 2.5, 2.75, 2.75, 3.0, 3.0, 3.25, 3.5, 3.75] + self.assertEqual(self.func(data, 0.25), 2.875) + data = [2.25, 2.5, 2.5, 2.75, 2.75, 2.75, 3.0, 3.0, 3.25, 3.5, 3.75] + self.assertApproxEqual(self.func(data, 0.25), 2.83333333, tol=1e-8) + data = [220, 220, 240, 260, 260, 260, 260, 280, 280, 300, 320, 340] + self.assertEqual(self.func(data, 20), 265.0) + + +class TestMode(NumericTestCase, AverageMixin, UnivariateTypeMixin): + # Test cases for the discrete version of mode. + def setUp(self): + self.func = statistics.mode + + def prepare_data(self): + """Overload method from UnivariateCommonMixin.""" + # Make sure test data has exactly one mode. + return [1, 1, 1, 1, 3, 4, 7, 9, 0, 8, 2] + + def test_range_data(self): + # Override test from UnivariateCommonMixin. + data = range(20, 50, 3) + self.assertRaises(statistics.StatisticsError, self.func, data) + + def test_nominal_data(self): + # Test mode with nominal data. + data = 'abcbdb' + self.assertEqual(self.func(data), 'b') + data = 'fe fi fo fum fi fi'.split() + self.assertEqual(self.func(data), 'fi') + + def test_discrete_data(self): + # Test mode with discrete numeric data. + data = list(range(10)) + for i in range(10): + d = data + [i] + random.shuffle(d) + self.assertEqual(self.func(d), i) + + def test_bimodal_data(self): + # Test mode with bimodal data. + data = [1, 1, 2, 2, 2, 2, 3, 4, 5, 6, 6, 6, 6, 7, 8, 9, 9] + assert data.count(2) == data.count(6) == 4 + # Check for an exception. + self.assertRaises(statistics.StatisticsError, self.func, data) + + def test_unique_data_failure(self): + # Test mode exception when data points are all unique. + data = list(range(10)) + self.assertRaises(statistics.StatisticsError, self.func, data) + + def test_none_data(self): + # Test that mode raises TypeError if given None as data. + + # This test is necessary because the implementation of mode uses + # collections.Counter, which accepts None and returns an empty dict. + self.assertRaises(TypeError, self.func, None) + + def test_counter_data(self): + # Test that a Counter is treated like any other iterable. + data = collections.Counter([1, 1, 1, 2]) + # Since the keys of the counter are treated as data points, not the + # counts, this should raise. + self.assertRaises(statistics.StatisticsError, self.func, data) + + + +# === Tests for variances and standard deviations === + +class VarianceStdevMixin(UnivariateCommonMixin): + # Mixin class holding common tests for variance and std dev. + + # Subclasses should inherit from this before NumericTestClass, in order + # to see the rel attribute below. See testShiftData for an explanation. + + rel = 1e-12 + + def test_single_value(self): + # Deviation of a single value is zero. + for x in (11, 19.8, 4.6e14, Fraction(21, 34), Decimal('8.392')): + self.assertEqual(self.func([x]), 0) + + def test_repeated_single_value(self): + # The deviation of a single repeated value is zero. + for x in (7.2, 49, 8.1e15, Fraction(3, 7), Decimal('62.4802')): + for count in (2, 3, 5, 15): + data = [x]*count + self.assertEqual(self.func(data), 0) + + def test_domain_error_regression(self): + # Regression test for a domain error exception. + # (Thanks to Geremy Condra.) + data = [0.123456789012345]*10000 + # All the items are identical, so variance should be exactly zero. + # We allow some small round-off error, but not much. + result = self.func(data) + self.assertApproxEqual(result, 0.0, tol=5e-17) + self.assertGreaterEqual(result, 0) # A negative result must fail. + + def test_shift_data(self): + # Test that shifting the data by a constant amount does not affect + # the variance or stdev. Or at least not much. + + # Due to rounding, this test should be considered an ideal. We allow + # some tolerance away from "no change at all" by setting tol and/or rel + # attributes. Subclasses may set tighter or looser error tolerances. + raw = [1.03, 1.27, 1.94, 2.04, 2.58, 3.14, 4.75, 4.98, 5.42, 6.78] + expected = self.func(raw) + # Don't set shift too high, the bigger it is, the more rounding error. + shift = 1e5 + data = [x + shift for x in raw] + self.assertApproxEqual(self.func(data), expected) + + def test_shift_data_exact(self): + # Like test_shift_data, but result is always exact. + raw = [1, 3, 3, 4, 5, 7, 9, 10, 11, 16] + assert all(x==int(x) for x in raw) + expected = self.func(raw) + shift = 10**9 + data = [x + shift for x in raw] + self.assertEqual(self.func(data), expected) + + def test_iter_list_same(self): + # Test that iter data and list data give the same result. + + # This is an explicit test that iterators and lists are treated the + # same; justification for this test over and above the similar test + # in UnivariateCommonMixin is that an earlier design had variance and + # friends swap between one- and two-pass algorithms, which would + # sometimes give different results. + data = [random.uniform(-3, 8) for _ in range(1000)] + expected = self.func(data) + self.assertEqual(self.func(iter(data)), expected) + + +class TestPVariance(VarianceStdevMixin, NumericTestCase, UnivariateTypeMixin): + # Tests for population variance. + def setUp(self): + self.func = statistics.pvariance + + def test_exact_uniform(self): + # Test the variance against an exact result for uniform data. + data = list(range(10000)) + random.shuffle(data) + expected = (10000**2 - 1)/12 # Exact value. + self.assertEqual(self.func(data), expected) + + def test_ints(self): + # Test population variance with int data. + data = [4, 7, 13, 16] + exact = 22.5 + self.assertEqual(self.func(data), exact) + + def test_fractions(self): + # Test population variance with Fraction data. + F = Fraction + data = [F(1, 4), F(1, 4), F(3, 4), F(7, 4)] + exact = F(3, 8) + result = self.func(data) + self.assertEqual(result, exact) + self.assertIsInstance(result, Fraction) + + def test_decimals(self): + # Test population variance with Decimal data. + D = Decimal + data = [D("12.1"), D("12.2"), D("12.5"), D("12.9")] + exact = D('0.096875') + result = self.func(data) + self.assertEqual(result, exact) + self.assertIsInstance(result, Decimal) + + +class TestVariance(VarianceStdevMixin, NumericTestCase, UnivariateTypeMixin): + # Tests for sample variance. + def setUp(self): + self.func = statistics.variance + + def test_single_value(self): + # Override method from VarianceStdevMixin. + for x in (35, 24.7, 8.2e15, Fraction(19, 30), Decimal('4.2084')): + self.assertRaises(statistics.StatisticsError, self.func, [x]) + + def test_ints(self): + # Test sample variance with int data. + data = [4, 7, 13, 16] + exact = 30 + self.assertEqual(self.func(data), exact) + + def test_fractions(self): + # Test sample variance with Fraction data. + F = Fraction + data = [F(1, 4), F(1, 4), F(3, 4), F(7, 4)] + exact = F(1, 2) + result = self.func(data) + self.assertEqual(result, exact) + self.assertIsInstance(result, Fraction) + + def test_decimals(self): + # Test sample variance with Decimal data. + D = Decimal + data = [D(2), D(2), D(7), D(9)] + exact = 4*D('9.5')/D(3) + result = self.func(data) + self.assertEqual(result, exact) + self.assertIsInstance(result, Decimal) + + +class TestPStdev(VarianceStdevMixin, NumericTestCase): + # Tests for population standard deviation. + def setUp(self): + self.func = statistics.pstdev + + def test_compare_to_variance(self): + # Test that stdev is, in fact, the square root of variance. + data = [random.uniform(-17, 24) for _ in range(1000)] + expected = math.sqrt(statistics.pvariance(data)) + self.assertEqual(self.func(data), expected) + + +class TestStdev(VarianceStdevMixin, NumericTestCase): + # Tests for sample standard deviation. + def setUp(self): + self.func = statistics.stdev + + def test_single_value(self): + # Override method from VarianceStdevMixin. + for x in (81, 203.74, 3.9e14, Fraction(5, 21), Decimal('35.719')): + self.assertRaises(statistics.StatisticsError, self.func, [x]) + + def test_compare_to_variance(self): + # Test that stdev is, in fact, the square root of variance. + data = [random.uniform(-2, 9) for _ in range(1000)] + expected = math.sqrt(statistics.variance(data)) + self.assertEqual(self.func(data), expected) + + +# === Run tests === + +def load_tests(loader, tests, ignore): + """Used for doctest/unittest integration.""" + tests.addTests(doctest.DocTestSuite()) + return tests + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_strftime.py b/Lib/test/test_strftime.py index 4f39cc79ac..772cd0688c 100644 --- a/Lib/test/test_strftime.py +++ b/Lib/test/test_strftime.py @@ -182,15 +182,15 @@ class Y1900Tests(unittest.TestCase): a date before 1900 is passed with a format string containing "%y" """ - @unittest.skipUnless(sys.platform == "win32", "Only applies to Windows") - def test_y_before_1900_win(self): - with self.assertRaises(ValueError): - time.strftime("%y", (1899, 1, 1, 0, 0, 0, 0, 0, 0)) - - @unittest.skipIf(sys.platform == "win32", "Doesn't apply on Windows") - def test_y_before_1900_nonwin(self): - self.assertEqual( - time.strftime("%y", (1899, 1, 1, 0, 0, 0, 0, 0, 0)), "99") + def test_y_before_1900(self): + # Issue #13674, #19634 + t = (1899, 1, 1, 0, 0, 0, 0, 0, 0) + if (sys.platform == "win32" + or sys.platform.startswith(("aix", "sunos", "solaris"))): + with self.assertRaises(ValueError): + time.strftime("%y", t) + else: + self.assertEqual(time.strftime("%y", t), "99") def test_y_1900(self): self.assertEqual( @@ -200,8 +200,5 @@ class Y1900Tests(unittest.TestCase): self.assertEqual( time.strftime("%y", (2013, 1, 1, 0, 0, 0, 0, 0, 0)), "13") -def test_main(): - support.run_unittest(StrftimeTest, Y1900Tests) - if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_strtod.py b/Lib/test/test_strtod.py index c5979fc989..41b6e5f833 100644 --- a/Lib/test/test_strtod.py +++ b/Lib/test/test_strtod.py @@ -248,7 +248,7 @@ class StrtodTests(unittest.TestCase): else: assert False, "expected ValueError" - @test.support.bigmemtest(size=test.support._2G+10, memuse=4, dry_run=False) + @test.support.bigmemtest(size=test.support._2G+10, memuse=3, dry_run=False) def test_oversized_digit_strings(self, maxsize): # Input string whose length doesn't fit in an INT. s = "1." + "1" * maxsize diff --git a/Lib/test/test_struct.py b/Lib/test/test_struct.py index 50cae052f3..0107eebca6 100644 --- a/Lib/test/test_struct.py +++ b/Lib/test/test_struct.py @@ -1,4 +1,6 @@ +from collections import abc import array +import operator import unittest import struct import sys @@ -6,7 +8,6 @@ import sys from test import support ISBIGENDIAN = sys.byteorder == "big" -IS32BIT = sys.maxsize == 0x7fffffff integer_codes = 'b', 'B', 'h', 'H', 'i', 'I', 'l', 'L', 'q', 'Q', 'n', 'N' byteorders = '', '@', '=', '<', '>', '!' @@ -489,7 +490,7 @@ class StructTest(unittest.TestCase): def test_bool(self): class ExplodingBool(object): def __bool__(self): - raise IOError + raise OSError for prefix in tuple("<>!=")+('',): false = (), [], [], '', 0 true = [1], 'test', 5, -1, 0xffffffff+1, 0xffffffff/2 @@ -520,10 +521,10 @@ class StructTest(unittest.TestCase): try: struct.pack(prefix + '?', ExplodingBool()) - except IOError: + except OSError: pass else: - self.fail("Expected IOError: struct.pack(%r, " + self.fail("Expected OSError: struct.pack(%r, " "ExplodingBool())" % (prefix + '?')) for c in [b'\x01', b'\x7f', b'\xff', b'\x0f', b'\xf0']: @@ -536,10 +537,6 @@ class StructTest(unittest.TestCase): hugecount2 = '{}b{}H'.format(sys.maxsize//2, sys.maxsize//2) self.assertRaises(struct.error, struct.calcsize, hugecount2) - if IS32BIT: - def test_crasher(self): - self.assertRaises(MemoryError, struct.pack, "357913941b", "a") - def test_trailing_counter(self): store = array.array('b', b' '*100) @@ -576,7 +573,7 @@ class StructTest(unittest.TestCase): # The size of 'PyStructObject' totalsize = support.calcobjsize('2n3P') # The size taken up by the 'formatcode' dynamic array - totalsize += struct.calcsize('P2n0P') * (number_of_codes + 1) + totalsize += struct.calcsize('P3n0P') * (number_of_codes + 1) support.check_sizeof(self, struct.Struct(format_str), totalsize) @support.cpython_only @@ -587,14 +584,84 @@ class StructTest(unittest.TestCase): self.check_sizeof('B' * 1234, 1234) self.check_sizeof('fd', 2) self.check_sizeof('xxxxxxxxxxxxxx', 0) - self.check_sizeof('100H', 100) + self.check_sizeof('100H', 1) self.check_sizeof('187s', 1) self.check_sizeof('20p', 1) self.check_sizeof('0s', 1) self.check_sizeof('0c', 0) + +class UnpackIteratorTest(unittest.TestCase): + """ + Tests for iterative unpacking (struct.Struct.iter_unpack). + """ + + def test_construct(self): + def _check_iterator(it): + self.assertIsInstance(it, abc.Iterator) + self.assertIsInstance(it, abc.Iterable) + s = struct.Struct('>ibcp') + it = s.iter_unpack(b"") + _check_iterator(it) + it = s.iter_unpack(b"1234567") + _check_iterator(it) + # Wrong bytes length + with self.assertRaises(struct.error): + s.iter_unpack(b"123456") + with self.assertRaises(struct.error): + s.iter_unpack(b"12345678") + # Zero-length struct + s = struct.Struct('>') + with self.assertRaises(struct.error): + s.iter_unpack(b"") + with self.assertRaises(struct.error): + s.iter_unpack(b"12") + + def test_iterate(self): + s = struct.Struct('>IB') + b = bytes(range(1, 16)) + it = s.iter_unpack(b) + self.assertEqual(next(it), (0x01020304, 5)) + self.assertEqual(next(it), (0x06070809, 10)) + self.assertEqual(next(it), (0x0b0c0d0e, 15)) + self.assertRaises(StopIteration, next, it) + self.assertRaises(StopIteration, next, it) + + def test_arbitrary_buffer(self): + s = struct.Struct('>IB') + b = bytes(range(1, 11)) + it = s.iter_unpack(memoryview(b)) + self.assertEqual(next(it), (0x01020304, 5)) + self.assertEqual(next(it), (0x06070809, 10)) + self.assertRaises(StopIteration, next, it) + self.assertRaises(StopIteration, next, it) + + def test_length_hint(self): + lh = operator.length_hint + s = struct.Struct('>IB') + b = bytes(range(1, 16)) + it = s.iter_unpack(b) + self.assertEqual(lh(it), 3) + next(it) + self.assertEqual(lh(it), 2) + next(it) + self.assertEqual(lh(it), 1) + next(it) + self.assertEqual(lh(it), 0) + self.assertRaises(StopIteration, next, it) + self.assertEqual(lh(it), 0) + + def test_module_func(self): + # Sanity check for the global struct.iter_unpack() + it = struct.iter_unpack('>IB', bytes(range(1, 11))) + self.assertEqual(next(it), (0x01020304, 5)) + self.assertEqual(next(it), (0x06070809, 10)) + self.assertRaises(StopIteration, next, it) + self.assertRaises(StopIteration, next, it) + + def test_main(): - support.run_unittest(StructTest) + support.run_unittest(__name__) if __name__ == '__main__': test_main() diff --git a/Lib/test/test_structseq.py b/Lib/test/test_structseq.py index a89e9556c1..353d0eadcc 100644 --- a/Lib/test/test_structseq.py +++ b/Lib/test/test_structseq.py @@ -38,7 +38,7 @@ class StructSeqTest(unittest.TestCase): # os.stat() gives a complicated struct sequence. st = os.stat(__file__) rep = repr(st) - self.assertTrue(rep.startswith(os.name + ".stat_result")) + self.assertTrue(rep.startswith("os.stat_result")) self.assertIn("st_mode=", rep) self.assertIn("st_ino=", rep) self.assertIn("st_dev=", rep) diff --git a/Lib/test/test_subprocess.py b/Lib/test/test_subprocess.py index f5a70a3873..46e012d142 100644 --- a/Lib/test/test_subprocess.py +++ b/Lib/test/test_subprocess.py @@ -11,6 +11,7 @@ import errno import tempfile import time import re +import selectors import sysconfig import warnings import select @@ -19,10 +20,6 @@ import gc import textwrap try: - import resource -except ImportError: - resource = None -try: import threading except ImportError: threading = None @@ -162,8 +159,28 @@ class ProcessTestCase(BaseTestCase): stderr=subprocess.STDOUT) self.assertIn(b'BDFL', output) + def test_check_output_stdin_arg(self): + # check_output() can be called with stdin set to a file + tf = tempfile.TemporaryFile() + self.addCleanup(tf.close) + tf.write(b'pear') + tf.seek(0) + output = subprocess.check_output( + [sys.executable, "-c", + "import sys; sys.stdout.write(sys.stdin.read().upper())"], + stdin=tf) + self.assertIn(b'PEAR', output) + + def test_check_output_input_arg(self): + # check_output() can be called with input set to a string + output = subprocess.check_output( + [sys.executable, "-c", + "import sys; sys.stdout.write(sys.stdin.read().upper())"], + input=b'pear') + self.assertIn(b'PEAR', output) + def test_check_output_stdout_arg(self): - # check_output() function stderr redirected to stdout + # check_output() refuses to accept 'stdout' argument with self.assertRaises(ValueError) as c: output = subprocess.check_output( [sys.executable, "-c", "print('will not be run')"], @@ -171,6 +188,20 @@ class ProcessTestCase(BaseTestCase): self.fail("Expected ValueError when stdout arg supplied.") self.assertIn('stdout', c.exception.args[0]) + def test_check_output_stdin_with_input_arg(self): + # check_output() refuses to accept 'stdin' with 'input' + tf = tempfile.TemporaryFile() + self.addCleanup(tf.close) + tf.write(b'pear') + tf.seek(0) + with self.assertRaises(ValueError) as c: + output = subprocess.check_output( + [sys.executable, "-c", "print('will not be run')"], + stdin=tf, input=b'hare') + self.fail("Expected ValueError when stdin and input args supplied.") + self.assertIn('stdin', c.exception.args[0]) + self.assertIn('input', c.exception.args[0]) + def test_check_output_timeout(self): # check_output() function with timeout arg with self.assertRaises(subprocess.TimeoutExpired) as c: @@ -853,8 +884,9 @@ class ProcessTestCase(BaseTestCase): # # UTF-16 and UTF-32-BE are sufficient to check both with BOM and # without, and UTF-16 and UTF-32. + import _bootlocale for encoding in ['utf-16', 'utf-32-be']: - old_getpreferredencoding = locale.getpreferredencoding + old_getpreferredencoding = _bootlocale.getpreferredencoding # Indirectly via io.TextIOWrapper, Popen() defaults to # locale.getpreferredencoding(False) and earlier in Python 3.2 to # locale.getpreferredencoding(). @@ -865,7 +897,7 @@ class ProcessTestCase(BaseTestCase): encoding) args = [sys.executable, '-c', code] try: - locale.getpreferredencoding = getpreferredencoding + _bootlocale.getpreferredencoding = getpreferredencoding # We set stdin to be non-None because, as of this writing, # a different code path is used when the number of pipes is # zero or one. @@ -874,7 +906,7 @@ class ProcessTestCase(BaseTestCase): stdout=subprocess.PIPE) stdout, stderr = popen.communicate(input='') finally: - locale.getpreferredencoding = old_getpreferredencoding + _bootlocale.getpreferredencoding = old_getpreferredencoding self.assertEqual(stdout, '1\n2\n3\n4') def test_no_leaking(self): @@ -982,8 +1014,7 @@ class ProcessTestCase(BaseTestCase): # value for that limit, but Windows has 2048, so we loop # 1024 times (each call leaked two fds). for i in range(1024): - # Windows raises IOError. Others raise OSError. - with self.assertRaises(EnvironmentError) as c: + with self.assertRaises(OSError) as c: subprocess.Popen(['nonexisting_i_hope'], stdout=subprocess.PIPE, stderr=subprocess.PIPE) @@ -1114,47 +1145,6 @@ class ProcessTestCase(BaseTestCase): fds_after_exception = os.listdir(fd_directory) self.assertEqual(fds_before_popen, fds_after_exception) - -# context manager -class _SuppressCoreFiles(object): - """Try to prevent core files from being created.""" - old_limit = None - - def __enter__(self): - """Try to save previous ulimit, then set it to (0, 0).""" - if resource is not None: - try: - self.old_limit = resource.getrlimit(resource.RLIMIT_CORE) - resource.setrlimit(resource.RLIMIT_CORE, (0, 0)) - except (ValueError, resource.error): - pass - - if sys.platform == 'darwin': - # Check if the 'Crash Reporter' on OSX was configured - # in 'Developer' mode and warn that it will get triggered - # when it is. - # - # This assumes that this context manager is used in tests - # that might trigger the next manager. - value = subprocess.Popen(['/usr/bin/defaults', 'read', - 'com.apple.CrashReporter', 'DialogType'], - stdout=subprocess.PIPE).communicate()[0] - if value.strip() == b'developer': - print("this tests triggers the Crash Reporter, " - "that is intentional", end='') - sys.stdout.flush() - - def __exit__(self, *args): - """Return core file behavior to default.""" - if self.old_limit is None: - return - if resource is not None: - try: - resource.setrlimit(resource.RLIMIT_CORE, self.old_limit) - except (ValueError, resource.error): - pass - - @unittest.skipIf(mswindows, "POSIX specific tests") class POSIXProcessTestCase(BaseTestCase): @@ -1243,7 +1233,7 @@ class POSIXProcessTestCase(BaseTestCase): def test_run_abort(self): # returncode handles signal termination - with _SuppressCoreFiles(): + with support.SuppressCrashReport(): p = subprocess.Popen([sys.executable, "-c", 'import os; os.abort()']) p.wait() @@ -1266,7 +1256,7 @@ class POSIXProcessTestCase(BaseTestCase): try: p = subprocess.Popen([sys.executable, "-c", ""], preexec_fn=raise_it) - except RuntimeError as e: + except subprocess.SubprocessError as e: self.assertTrue( subprocess._posixsubprocess, "Expected a ValueError from the preexec_fn") @@ -1306,9 +1296,10 @@ class POSIXProcessTestCase(BaseTestCase): """Issue16140: Don't double close pipes on preexec error.""" def raise_it(): - raise RuntimeError("force the _execute_child() errpipe_data path.") + raise subprocess.SubprocessError( + "force the _execute_child() errpipe_data path.") - with self.assertRaises(RuntimeError): + with self.assertRaises(subprocess.SubprocessError): self._TestExecuteChildPopen( self, [sys.executable, "-c", "pass"], stdin=subprocess.PIPE, stdout=subprocess.PIPE, @@ -1507,16 +1498,28 @@ class POSIXProcessTestCase(BaseTestCase): # Terminating a dead process self._kill_dead_process('terminate') + def _save_fds(self, save_fds): + fds = [] + for fd in save_fds: + inheritable = os.get_inheritable(fd) + saved = os.dup(fd) + fds.append((fd, saved, inheritable)) + return fds + + def _restore_fds(self, fds): + for fd, saved, inheritable in fds: + os.dup2(saved, fd, inheritable=inheritable) + os.close(saved) + def check_close_std_fds(self, fds): # Issue #9905: test that subprocess pipes still work properly with # some standard fds closed stdin = 0 - newfds = [] - for a in fds: - b = os.dup(a) - newfds.append(b) - if a == 0: - stdin = b + saved_fds = self._save_fds(fds) + for fd, saved, inheritable in saved_fds: + if fd == 0: + stdin = saved + break try: for fd in fds: os.close(fd) @@ -1531,10 +1534,7 @@ class POSIXProcessTestCase(BaseTestCase): err = support.strip_python_stderr(err) self.assertEqual((out, err), (b'apple', b'orange')) finally: - for b, a in zip(newfds, fds): - os.dup2(b, a) - for b in newfds: - os.close(b) + self._restore_fds(saved_fds) def test_close_fd_0(self): self.check_close_std_fds([0]) @@ -1595,7 +1595,7 @@ class POSIXProcessTestCase(BaseTestCase): os.lseek(temp_fds[1], 0, 0) # move the standard file descriptors out of the way - saved_fds = [os.dup(fd) for fd in range(3)] + saved_fds = self._save_fds(range(3)) try: # duplicate the file objects over the standard fd's for fd, temp_fd in enumerate(temp_fds): @@ -1611,10 +1611,7 @@ class POSIXProcessTestCase(BaseTestCase): stderr=temp_fds[0]) p.wait() finally: - # restore the original fd's underneath sys.stdin, etc. - for std, saved in enumerate(saved_fds): - os.dup2(saved, std) - os.close(saved) + self._restore_fds(saved_fds) for fd in temp_fds: os.lseek(fd, 0, 0) @@ -1638,7 +1635,7 @@ class POSIXProcessTestCase(BaseTestCase): os.unlink(fname) # save a copy of the standard file descriptors - saved_fds = [os.dup(fd) for fd in range(3)] + saved_fds = self._save_fds(range(3)) try: # duplicate the temp files over the standard fd's 0, 1, 2 for fd, temp_fd in enumerate(temp_fds): @@ -1664,9 +1661,7 @@ class POSIXProcessTestCase(BaseTestCase): out = os.read(stdout_no, 1024) err = support.strip_python_stderr(os.read(stderr_no, 1024)) finally: - for std, saved in enumerate(saved_fds): - os.dup2(saved, std) - os.close(saved) + self._restore_fds(saved_fds) self.assertEqual(out, b"got STDIN") self.assertEqual(err, b"err") @@ -1698,40 +1693,47 @@ class POSIXProcessTestCase(BaseTestCase): # Pure Python implementations keeps the message self.assertIsNone(subprocess._posixsubprocess) self.assertEqual(str(err), "surrogate:\uDCff") - except RuntimeError as err: + except subprocess.SubprocessError as err: # _posixsubprocess uses a default message self.assertIsNotNone(subprocess._posixsubprocess) self.assertEqual(str(err), "Exception occurred in preexec_fn.") else: - self.fail("Expected ValueError or RuntimeError") + self.fail("Expected ValueError or subprocess.SubprocessError") def test_undecodable_env(self): for key, value in (('test', 'abc\uDCFF'), ('test\uDCFF', '42')): + encoded_value = value.encode("ascii", "surrogateescape") + # test str with surrogates script = "import os; print(ascii(os.getenv(%s)))" % repr(key) env = os.environ.copy() env[key] = value - # Use C locale to get ascii for the locale encoding to force + # Use C locale to get ASCII for the locale encoding to force # surrogate-escaping of \xFF in the child process; otherwise it can # be decoded as-is if the default locale is latin-1. env['LC_ALL'] = 'C' + if sys.platform.startswith("aix"): + # On AIX, the C locale uses the Latin1 encoding + decoded_value = encoded_value.decode("latin1", "surrogateescape") + else: + # On other UNIXes, the C locale uses the ASCII encoding + decoded_value = value stdout = subprocess.check_output( [sys.executable, "-c", script], env=env) stdout = stdout.rstrip(b'\n\r') - self.assertEqual(stdout.decode('ascii'), ascii(value)) + self.assertEqual(stdout.decode('ascii'), ascii(decoded_value)) # test bytes key = key.encode("ascii", "surrogateescape") - value = value.encode("ascii", "surrogateescape") script = "import os; print(ascii(os.getenvb(%s)))" % repr(key) env = os.environ.copy() - env[key] = value + env[key] = encoded_value stdout = subprocess.check_output( [sys.executable, "-c", script], env=env) stdout = stdout.rstrip(b'\n\r') - self.assertEqual(stdout.decode('ascii'), ascii(value)) + self.assertEqual(stdout.decode('ascii'), ascii(encoded_value)) def test_bytes_program(self): abs_program = os.fsencode(sys.executable) @@ -1837,6 +1839,9 @@ class POSIXProcessTestCase(BaseTestCase): self.addCleanup(os.close, fd) open_fds.add(fd) + for fd in open_fds: + os.set_inheritable(fd, True) + p = subprocess.Popen([sys.executable, fd_status], stdout=subprocess.PIPE, close_fds=False) output, ignored = p.communicate() @@ -1881,6 +1886,8 @@ class POSIXProcessTestCase(BaseTestCase): fds = os.pipe() self.addCleanup(os.close, fds[0]) self.addCleanup(os.close, fds[1]) + os.set_inheritable(fds[0], True) + os.set_inheritable(fds[1], True) open_fds.update(fds) for fd in open_fds: @@ -1903,6 +1910,32 @@ class POSIXProcessTestCase(BaseTestCase): close_fds=False, pass_fds=(fd, ))) self.assertIn('overriding close_fds', str(context.warning)) + def test_pass_fds_inheritable(self): + script = support.findfile("fd_status.py", subdir="subprocessdata") + + inheritable, non_inheritable = os.pipe() + self.addCleanup(os.close, inheritable) + self.addCleanup(os.close, non_inheritable) + os.set_inheritable(inheritable, True) + os.set_inheritable(non_inheritable, False) + pass_fds = (inheritable, non_inheritable) + args = [sys.executable, script] + args += list(map(str, pass_fds)) + + p = subprocess.Popen(args, + stdout=subprocess.PIPE, close_fds=True, + pass_fds=pass_fds) + output, ignored = p.communicate() + fds = set(map(int, output.split(b','))) + + # the inheritable file descriptor must be inherited, so its inheritable + # flag must be set in the child process after fork() and before exec() + self.assertEqual(fds, set(pass_fds), "output=%a" % output) + + # inheritable flag must not be changed in the parent process + self.assertEqual(os.get_inheritable(inheritable), True) + self.assertEqual(os.get_inheritable(non_inheritable), False) + def test_stdout_stdin_are_single_inout_fd(self): with io.open(os.devnull, "r+") as inout: p = subprocess.Popen([sys.executable, "-c", "import sys; sys.exit(0)"], @@ -1990,7 +2023,7 @@ class POSIXProcessTestCase(BaseTestCase): # let some time for the process to exit, and create a new Popen: this # should trigger the wait() of p time.sleep(0.2) - with self.assertRaises(EnvironmentError) as c: + with self.assertRaises(OSError) as c: with subprocess.Popen(['nonexisting_i_hope'], stdout=subprocess.PIPE, stderr=subprocess.PIPE) as proc: @@ -2175,15 +2208,16 @@ class CommandTests(unittest.TestCase): os.rmdir(dir) -@unittest.skipUnless(getattr(subprocess, '_has_poll', False), - "poll system call not supported") +@unittest.skipUnless(hasattr(selectors, 'PollSelector'), + "Test needs selectors.PollSelector") class ProcessTestCaseNoPoll(ProcessTestCase): def setUp(self): - subprocess._has_poll = False + self.orig_selector = subprocess._PopenSelector + subprocess._PopenSelector = selectors.SelectSelector ProcessTestCase.setUp(self) def tearDown(self): - subprocess._has_poll = True + subprocess._PopenSelector = self.orig_selector ProcessTestCase.tearDown(self) diff --git a/Lib/test/test_sunau.py b/Lib/test/test_sunau.py index 1655317c07..0f4134e982 100644 --- a/Lib/test/test_sunau.py +++ b/Lib/test/test_sunau.py @@ -1,6 +1,7 @@ from test.support import TESTFN import unittest from test import audiotests +from audioop import byteswap import sys import sunau @@ -46,6 +47,31 @@ class SunauPCM16Test(SunauTest, unittest.TestCase): """) +class SunauPCM24Test(SunauTest, unittest.TestCase): + sndfilename = 'pluck-pcm24.au' + sndfilenframes = 3307 + nchannels = 2 + sampwidth = 3 + framerate = 11025 + nframes = 48 + comptype = 'NONE' + compname = 'not compressed' + frames = bytes.fromhex("""\ + 022D65FFEB9D 4B5A0F00FA54 3113C304EE2B 80DCD6084303 \ + CBDEC006B261 48A99803F2F8 BFE82401B07D 036BFBFE7B5D \ + B85756FA3EC9 B4B055F3502B 299830EBCB62 1A5CA7E6D99A \ + EDFA3EE491BD C625EBE27884 0E05A9E0B6CF EF2929E02922 \ + 5758D8E27067 FB3557E83E16 1377BFEF8402 D82C5BF7272A \ + 978F16FB7745 F5F865FC1013 086635FB9C4E DF30FCFB40EE \ + 117FE0FA3438 3EE6B8FB5AC3 BC77A3FCB2F4 66D6DAFF5F32 \ + CF13B9041275 431D69097A8C C1BB600EC74E 5120B912A2BA \ + EEDF641754C0 8207001664B7 7FFFFF14453F 8000001294E6 \ + 499C1B0EB3B2 52B73E0DBCA0 EFB2B20F5FD8 CE3CDB0FBE12 \ + E4B49C0CEA2D 6344A80A5A7C 08C8FE0A1FFE 2BB9860B0A0E \ + 51486F0E44E1 8BCC64113B05 B6F4EC0EEB36 4413170A5B48 \ + """) + + class SunauPCM32Test(SunauTest, unittest.TestCase): sndfilename = 'pluck-pcm32.au' sndfilenframes = 3307 @@ -89,7 +115,7 @@ class SunauULAWTest(SunauTest, unittest.TestCase): E5040CBC 617C0A3C 08BC0A3C 2C7C0B3C 517C0E3C 8A8410FC B6840EBC 457C0A3C \ """) if sys.byteorder != 'big': - frames = audiotests.byteswap2(frames) + frames = byteswap(frames, 2) if __name__ == "__main__": diff --git a/Lib/test/test_sundry.py b/Lib/test/test_sundry.py index 8808c47edd..e99ca9ed37 100644 --- a/Lib/test/test_sundry.py +++ b/Lib/test/test_sundry.py @@ -1,19 +1,26 @@ """Do a minimal test of all the modules that aren't otherwise tested.""" - -from test import support +import importlib import sys +from test import support import unittest class TestUntestedModules(unittest.TestCase): - def test_at_least_import_untested_modules(self): + def test_untested_modules_can_be_imported(self): + untested = ('bdb', 'encodings', 'formatter', + 'nturl2path', 'tabnanny') with support.check_warnings(quiet=True): - import bdb - import cgitb + for name in untested: + try: + support.import_module('test.test_{}'.format(name)) + except unittest.SkipTest: + importlib.import_module(name) + else: + self.fail('{} has tests even though test_sundry claims ' + 'otherwise'.format(name)) import distutils.bcppcompiler import distutils.ccompiler import distutils.cygwinccompiler - import distutils.emxccompiler import distutils.filelist if sys.platform.startswith('win'): import distutils.msvccompiler @@ -39,26 +46,14 @@ class TestUntestedModules(unittest.TestCase): import distutils.command.sdist import distutils.command.upload - import encodings - import formatter import html.entities - import keyword - import mailcap - import nturl2path - import os2emxpath - import pstats - import py_compile - import sndhdr - import tabnanny + try: - import tty # not available on Windows + import tty # Not available on Windows except ImportError: if support.verbose: print("skipping tty") -def test_main(): - support.run_unittest(TestUntestedModules) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_super.py b/Lib/test/test_super.py index 1e272ee2cf..37fc2d9134 100644 --- a/Lib/test/test_super.py +++ b/Lib/test/test_super.py @@ -44,6 +44,11 @@ class G(A): class TestSuper(unittest.TestCase): + def tearDown(self): + # This fixes the damage that test_various___class___pathologies does. + nonlocal __class__ + __class__ = TestSuper + def test_basics_working(self): self.assertEqual(D().f(), 'ABCD') @@ -81,8 +86,7 @@ class TestSuper(unittest.TestCase): self.assertEqual(E().f(), 'AE') - @unittest.expectedFailure - def test___class___set(self): + def test_various___class___pathologies(self): # See issue #12370 class X(A): def f(self): @@ -91,6 +95,31 @@ class TestSuper(unittest.TestCase): x = X() self.assertEqual(x.f(), 'A') self.assertEqual(x.__class__, 413) + class X: + x = __class__ + def f(): + __class__ + self.assertIs(X.x, type(self)) + with self.assertRaises(NameError) as e: + exec("""class X: + __class__ + def f(): + __class__""", globals(), {}) + self.assertIs(type(e.exception), NameError) # Not UnboundLocalError + class X: + global __class__ + __class__ = 42 + def f(): + __class__ + self.assertEqual(globals()["__class__"], 42) + del globals()["__class__"] + self.assertNotIn("__class__", X.__dict__) + class X: + nonlocal __class__ + __class__ = 42 + def f(): + __class__ + self.assertEqual(__class__, 42) def test___class___instancemethod(self): # See issue #14857 diff --git a/Lib/test/test_support.py b/Lib/test/test_support.py index 5913044877..70e63968f2 100644 --- a/Lib/test/test_support.py +++ b/Lib/test/test_support.py @@ -304,6 +304,7 @@ class TestSupport(unittest.TestCase): # args_from_interpreter_flags # can_symlink # skip_unless_symlink + # SuppressCrashReport def test_main(): diff --git a/Lib/test/test_syntax.py b/Lib/test/test_syntax.py index 5926b69c93..a9d3628819 100644 --- a/Lib/test/test_syntax.py +++ b/Lib/test/test_syntax.py @@ -33,7 +33,7 @@ SyntaxError: invalid syntax >>> None = 1 Traceback (most recent call last): -SyntaxError: assignment to keyword +SyntaxError: can't assign to keyword It's a syntax error to assign to the empty tuple. Why isn't it an error to assign to the empty list? It will always raise some error at @@ -233,7 +233,7 @@ Traceback (most recent call last): SyntaxError: can't assign to generator expression >>> None += 1 Traceback (most recent call last): -SyntaxError: assignment to keyword +SyntaxError: can't assign to keyword >>> f() += 1 Traceback (most recent call last): SyntaxError: can't assign to function call diff --git a/Lib/test/test_sys.py b/Lib/test/test_sys.py index f768e9b7c3..5a9699ff20 100644 --- a/Lib/test/test_sys.py +++ b/Lib/test/test_sys.py @@ -7,6 +7,9 @@ import textwrap import warnings import operator import codecs +import gc +import sysconfig +import platform # count the number of test runs, used to create unique # strings to intern in test_intern() @@ -227,7 +230,7 @@ class SysModuleTest(unittest.TestCase): sys.setrecursionlimit(%d) f()""") - with test.support.suppress_crash_popup(): + with test.support.SuppressCrashReport(): for i in (50, 1000): sub = subprocess.Popen([sys.executable, '-c', code % i], stderr=subprocess.PIPE) @@ -408,7 +411,7 @@ class SysModuleTest(unittest.TestCase): self.assertEqual(type(sys.int_info.sizeof_digit), int) self.assertIsInstance(sys.hexversion, int) - self.assertEqual(len(sys.hash_info), 5) + self.assertEqual(len(sys.hash_info), 9) self.assertLess(sys.hash_info.modulus, 2**sys.hash_info.width) # sys.hash_info.modulus should be a prime; we do a quick # probable primality test (doesn't exclude the possibility of @@ -423,6 +426,22 @@ class SysModuleTest(unittest.TestCase): self.assertIsInstance(sys.hash_info.inf, int) self.assertIsInstance(sys.hash_info.nan, int) self.assertIsInstance(sys.hash_info.imag, int) + algo = sysconfig.get_config_var("Py_HASH_ALGORITHM") + if sys.hash_info.algorithm in {"fnv", "siphash24"}: + self.assertIn(sys.hash_info.hash_bits, {32, 64}) + self.assertIn(sys.hash_info.seed_bits, {32, 64, 128}) + + if algo == 1: + self.assertEqual(sys.hash_info.algorithm, "siphash24") + elif algo == 2: + self.assertEqual(sys.hash_info.algorithm, "fnv") + else: + self.assertIn(sys.hash_info.algorithm, {"fnv", "siphash24"}) + else: + # PY_HASH_EXTERNAL + self.assertEqual(algo, 0) + self.assertGreaterEqual(sys.hash_info.cutoff, 0) + self.assertLess(sys.hash_info.cutoff, 8) self.assertIsInstance(sys.maxsize, int) self.assertIsInstance(sys.maxunicode, int) @@ -460,7 +479,7 @@ class SysModuleTest(unittest.TestCase): def test_thread_info(self): info = sys.thread_info self.assertEqual(len(info), 3) - self.assertIn(info.name, ('nt', 'os2', 'pthread', 'solaris', None)) + self.assertIn(info.name, ('nt', 'pthread', 'solaris', None)) self.assertIn(info.lock, ('semaphore', 'mutex+cond', None)) def test_43581(self): @@ -493,7 +512,7 @@ class SysModuleTest(unittest.TestCase): attrs = ("debug", "inspect", "interactive", "optimize", "dont_write_bytecode", "no_user_site", "no_site", "ignore_environment", "verbose", - "bytes_warning", "quiet", "hash_randomization") + "bytes_warning", "quiet", "hash_randomization", "isolated") for attr in attrs: self.assertTrue(hasattr(sys.flags, attr), attr) self.assertEqual(type(getattr(sys.flags, attr)), int, attr) @@ -522,6 +541,42 @@ class SysModuleTest(unittest.TestCase): out = p.communicate()[0].strip() self.assertEqual(out, b'?') + env["PYTHONIOENCODING"] = "ascii" + p = subprocess.Popen([sys.executable, "-c", 'print(chr(0xa2))'], + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + env=env) + out, err = p.communicate() + self.assertEqual(out, b'') + self.assertIn(b'UnicodeEncodeError:', err) + self.assertIn(rb"'\xa2'", err) + + env["PYTHONIOENCODING"] = "ascii:" + p = subprocess.Popen([sys.executable, "-c", 'print(chr(0xa2))'], + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + env=env) + out, err = p.communicate() + self.assertEqual(out, b'') + self.assertIn(b'UnicodeEncodeError:', err) + self.assertIn(rb"'\xa2'", err) + + env["PYTHONIOENCODING"] = ":surrogateescape" + p = subprocess.Popen([sys.executable, "-c", 'print(chr(0xdcbd))'], + stdout=subprocess.PIPE, env=env) + out = p.communicate()[0].strip() + self.assertEqual(out, b'\xbd') + + @unittest.skipUnless(test.support.FS_NONASCII, + 'requires OS support of non-ASCII encodings') + def test_ioencoding_nonascii(self): + env = dict(os.environ) + + env["PYTHONIOENCODING"] = "" + p = subprocess.Popen([sys.executable, "-c", + 'print(%a)' % test.support.FS_NONASCII], + stdout=subprocess.PIPE, env=env) + out = p.communicate()[0].strip() + self.assertEqual(out, os.fsencode(test.support.FS_NONASCII)) + @unittest.skipIf(sys.base_prefix != sys.prefix, 'Test is not venv-compatible') def test_executable(self): @@ -589,6 +644,39 @@ class SysModuleTest(unittest.TestCase): ret, out, err = assert_python_ok(*args) self.assertIn(b"free PyDictObjects", err) + # The function has no parameter + self.assertRaises(TypeError, sys._debugmallocstats, True) + + @unittest.skipUnless(hasattr(sys, "getallocatedblocks"), + "sys.getallocatedblocks unavailable on this build") + def test_getallocatedblocks(self): + # Some sanity checks + with_pymalloc = sysconfig.get_config_var('WITH_PYMALLOC') + a = sys.getallocatedblocks() + self.assertIs(type(a), int) + if with_pymalloc: + self.assertGreater(a, 0) + else: + # When WITH_PYMALLOC isn't available, we don't know anything + # about the underlying implementation: the function might + # return 0 or something greater. + self.assertGreaterEqual(a, 0) + try: + # While we could imagine a Python session where the number of + # multiple buffer objects would exceed the sharing of references, + # it is unlikely to happen in a normal test run. + self.assertLess(a, sys.gettotalrefcount()) + except AttributeError: + # gettotalrefcount() not available + pass + gc.collect() + b = sys.getallocatedblocks() + self.assertLessEqual(b, a) + gc.collect() + c = sys.getallocatedblocks() + self.assertIn(c, range(b - 50, b + 50)) + + @test.support.cpython_only class SizeofTest(unittest.TestCase): @@ -634,7 +722,7 @@ class SizeofTest(unittest.TestCase): samples = [b'', b'u'*100000] for sample in samples: x = bytearray(sample) - check(x, vsize('inP') + x.__alloc__()) + check(x, vsize('n2Pi') + x.__alloc__()) # bytearray_iterator check(iter(bytearray()), size('nP')) # cell @@ -713,7 +801,7 @@ class SizeofTest(unittest.TestCase): nfrees = len(x.f_code.co_freevars) extras = x.f_code.co_stacksize + x.f_code.co_nlocals +\ ncells + nfrees - 1 - check(x, vsize('12P3i' + CO_MAXBLOCKS*'3i' + 'P' + extras*'P')) + check(x, vsize('12P3ic' + CO_MAXBLOCKS*'3i' + 'P' + extras*'P')) # function def func(): pass check(func, size('12P')) @@ -759,7 +847,7 @@ class SizeofTest(unittest.TestCase): # memoryview check(memoryview(b''), size('Pnin 2P2n2i5P 3cPn')) # module - check(unittest, size('PnP')) + check(unittest, size('PnPPP')) # None check(None, size('')) # NotImplementedType @@ -813,11 +901,11 @@ class SizeofTest(unittest.TestCase): check((1,2,3), vsize('') + 3*self.P) # type # static type: PyTypeObject - s = vsize('P2n15Pl4Pn9Pn11PI') + s = vsize('P2n15Pl4Pn9Pn11PIP') check(int, s) # (PyTypeObject + PyNumberMethods + PyMappingMethods + # PySequenceMethods + PyBufferProcs + 4P) - s = vsize('P2n15Pl4Pn9Pn11PI') + struct.calcsize('34P 3P 10P 2P 4P') + s = vsize('P2n15Pl4Pn9Pn11PIP') + struct.calcsize('34P 3P 10P 2P 4P') # Separate block for PyDictKeysObject with 4 entries s += struct.calcsize("2nPn") + 4*struct.calcsize("n2P") # class diff --git a/Lib/test/test_sysconfig.py b/Lib/test/test_sysconfig.py index 03f67fd476..804c446716 100644 --- a/Lib/test/test_sysconfig.py +++ b/Lib/test/test_sysconfig.py @@ -5,7 +5,7 @@ import subprocess import shutil from copy import copy -from test.support import (run_unittest, TESTFN, unlink, +from test.support import (run_unittest, TESTFN, unlink, check_warnings, captured_stdout, skip_unless_symlink) import sysconfig @@ -234,7 +234,7 @@ class TestSysConfig(unittest.TestCase): self.assertTrue(os.path.isfile(config_h), config_h) def test_get_scheme_names(self): - wanted = ('nt', 'nt_user', 'os2', 'os2_home', 'osx_framework_user', + wanted = ('nt', 'nt_user', 'osx_framework_user', 'posix_home', 'posix_prefix', 'posix_user') self.assertEqual(get_scheme_names(), wanted) @@ -305,14 +305,13 @@ class TestSysConfig(unittest.TestCase): if 'MACOSX_DEPLOYMENT_TARGET' in env: del env['MACOSX_DEPLOYMENT_TARGET'] - with open('/dev/null', 'w') as devnull_fp: - p = subprocess.Popen([ - sys.executable, '-c', - 'import sysconfig; print(sysconfig.get_platform())', - ], - stdout=subprocess.PIPE, - stderr=devnull_fp, - env=env) + p = subprocess.Popen([ + sys.executable, '-c', + 'import sysconfig; print(sysconfig.get_platform())', + ], + stdout=subprocess.PIPE, + stderr=subprocess.DEVNULL, + env=env) test_platform = p.communicate()[0].strip() test_platform = test_platform.decode('utf-8') status = p.wait() @@ -325,20 +324,19 @@ class TestSysConfig(unittest.TestCase): env = os.environ.copy() env['MACOSX_DEPLOYMENT_TARGET'] = '10.1' - with open('/dev/null') as dev_null: - p = subprocess.Popen([ - sys.executable, '-c', - 'import sysconfig; print(sysconfig.get_platform())', - ], - stdout=subprocess.PIPE, - stderr=dev_null, - env=env) - test_platform = p.communicate()[0].strip() - test_platform = test_platform.decode('utf-8') - status = p.wait() - - self.assertEqual(status, 0) - self.assertEqual(my_platform, test_platform) + p = subprocess.Popen([ + sys.executable, '-c', + 'import sysconfig; print(sysconfig.get_platform())', + ], + stdout=subprocess.PIPE, + stderr=subprocess.DEVNULL, + env=env) + test_platform = p.communicate()[0].strip() + test_platform = test_platform.decode('utf-8') + status = p.wait() + + self.assertEqual(status, 0) + self.assertEqual(my_platform, test_platform) def test_srcdir(self): # See Issues #15322, #15364. @@ -371,6 +369,26 @@ class TestSysConfig(unittest.TestCase): os.chdir(cwd) self.assertEqual(srcdir, srcdir2) + @unittest.skipIf(sysconfig.get_config_var('EXT_SUFFIX') is None, + 'EXT_SUFFIX required for this test') + def test_SO_deprecation(self): + self.assertWarns(DeprecationWarning, + sysconfig.get_config_var, 'SO') + + @unittest.skipIf(sysconfig.get_config_var('EXT_SUFFIX') is None, + 'EXT_SUFFIX required for this test') + def test_SO_value(self): + with check_warnings(('', DeprecationWarning)): + self.assertEqual(sysconfig.get_config_var('SO'), + sysconfig.get_config_var('EXT_SUFFIX')) + + @unittest.skipIf(sysconfig.get_config_var('EXT_SUFFIX') is None, + 'EXT_SUFFIX required for this test') + def test_SO_in_vars(self): + vars = sysconfig.get_config_vars() + self.assertIsNotNone(vars['SO']) + self.assertEqual(vars['SO'], vars['EXT_SUFFIX']) + class MakefileTests(unittest.TestCase): diff --git a/Lib/test/test_syslog.py b/Lib/test/test_syslog.py index 4e7621e5ec..b7fd2bdabb 100644 --- a/Lib/test/test_syslog.py +++ b/Lib/test/test_syslog.py @@ -32,6 +32,10 @@ class Test(unittest.TestCase): def test_log_upto(self): syslog.LOG_UPTO(syslog.LOG_INFO) + def test_openlog_noargs(self): + syslog.openlog() + syslog.syslog('test message from python test_syslog') + def test_main(): support.run_unittest(__name__) diff --git a/Lib/test/test_tarfile.py b/Lib/test/test_tarfile.py index 68fe60801d..92f8bfe920 100644 --- a/Lib/test/test_tarfile.py +++ b/Lib/test/test_tarfile.py @@ -7,7 +7,7 @@ from hashlib import md5 import unittest import tarfile -from test import support +from test import support, script_helper # Check for our compression modules. try: @@ -27,11 +27,13 @@ def md5sum(data): return md5(data).hexdigest() TEMPDIR = os.path.abspath(support.TESTFN) + "-tardir" +tarextdir = TEMPDIR + '-extract-test' tarname = support.findfile("testtar.tar") gzipname = os.path.join(TEMPDIR, "testtar.tar.gz") bz2name = os.path.join(TEMPDIR, "testtar.tar.bz2") xzname = os.path.join(TEMPDIR, "testtar.tar.xz") tmpname = os.path.join(TEMPDIR, "tmp.tar") +dotlessname = os.path.join(TEMPDIR, "testtar") md5_regtype = "65f477c818ad9e15f7feab0c6d37742f" md5_sparse = "a54fbc4ca4f4399a90e1b27164012fc6" @@ -271,7 +273,7 @@ class ListTest(ReadTest, unittest.TestCase): # ?rw-r--r-- tarfile/tarfile 7011 2003-01-06 07:19:43 ustar/conttype # ?rw-r--r-- tarfile/tarfile 7011 2003-01-06 07:19:43 ustar/regtype # ... - self.assertRegex(out, (br'-rw-r--r-- tarfile/tarfile\s+7011 ' + self.assertRegex(out, (br'\?rw-r--r-- tarfile/tarfile\s+7011 ' br'\d{4}-\d\d-\d\d\s+\d\d:\d\d:\d\d ' br'ustar/\w+type ?\r?\n') * 2) # Make sure it prints the source of link with verbose flag @@ -319,12 +321,7 @@ class CommonReadTest(ReadTest): def test_non_existent_tarfile(self): # Test for issue11513: prevent non-existent gzipped tarfiles raising # multiple exceptions. - test = 'xxx' - if sys.platform == 'win32' and '|' in self.mode: - # Issue #20384: On Windows os.open() error message doesn't - # contain file name. - test = '' - with self.assertRaisesRegex(FileNotFoundError, test): + with self.assertRaisesRegex(FileNotFoundError, "xxx"): tarfile.open("xxx", self.mode) def test_null_tarfile(self): @@ -1847,6 +1844,171 @@ class MiscTest(unittest.TestCase): tarfile.itn(0x10000000000, 6, tarfile.GNU_FORMAT) +class CommandLineTest(unittest.TestCase): + + def tarfilecmd(self, *args, **kwargs): + rc, out, err = script_helper.assert_python_ok('-m', 'tarfile', *args, + **kwargs) + return out.replace(os.linesep.encode(), b'\n') + + def tarfilecmd_failure(self, *args): + return script_helper.assert_python_failure('-m', 'tarfile', *args) + + def make_simple_tarfile(self, tar_name): + files = [support.findfile('tokenize_tests.txt'), + support.findfile('tokenize_tests-no-coding-cookie-' + 'and-utf8-bom-sig-only.txt')] + self.addCleanup(support.unlink, tar_name) + with tarfile.open(tar_name, 'w') as tf: + for tardata in files: + tf.add(tardata, arcname=os.path.basename(tardata)) + + def test_test_command(self): + for tar_name in testtarnames: + for opt in '-t', '--test': + out = self.tarfilecmd(opt, tar_name) + self.assertEqual(out, b'') + + def test_test_command_verbose(self): + for tar_name in testtarnames: + for opt in '-v', '--verbose': + out = self.tarfilecmd(opt, '-t', tar_name) + self.assertIn(b'is a tar archive.\n', out) + + def test_test_command_invalid_file(self): + zipname = support.findfile('zipdir.zip') + rc, out, err = self.tarfilecmd_failure('-t', zipname) + self.assertIn(b' is not a tar archive.', err) + self.assertEqual(out, b'') + self.assertEqual(rc, 1) + + for tar_name in testtarnames: + with self.subTest(tar_name=tar_name): + with open(tar_name, 'rb') as f: + data = f.read() + try: + with open(tmpname, 'wb') as f: + f.write(data[:511]) + rc, out, err = self.tarfilecmd_failure('-t', tmpname) + self.assertEqual(out, b'') + self.assertEqual(rc, 1) + finally: + support.unlink(tmpname) + + def test_list_command(self): + for tar_name in testtarnames: + with support.captured_stdout() as t: + with tarfile.open(tar_name, 'r') as tf: + tf.list(verbose=False) + expected = t.getvalue().encode('ascii', 'backslashreplace') + for opt in '-l', '--list': + out = self.tarfilecmd(opt, tar_name, + PYTHONIOENCODING='ascii') + self.assertEqual(out, expected) + + def test_list_command_verbose(self): + for tar_name in testtarnames: + with support.captured_stdout() as t: + with tarfile.open(tar_name, 'r') as tf: + tf.list(verbose=True) + expected = t.getvalue().encode('ascii', 'backslashreplace') + for opt in '-v', '--verbose': + out = self.tarfilecmd(opt, '-l', tar_name, + PYTHONIOENCODING='ascii') + self.assertEqual(out, expected) + + def test_list_command_invalid_file(self): + zipname = support.findfile('zipdir.zip') + rc, out, err = self.tarfilecmd_failure('-l', zipname) + self.assertIn(b' is not a tar archive.', err) + self.assertEqual(out, b'') + self.assertEqual(rc, 1) + + def test_create_command(self): + files = [support.findfile('tokenize_tests.txt'), + support.findfile('tokenize_tests-no-coding-cookie-' + 'and-utf8-bom-sig-only.txt')] + for opt in '-c', '--create': + try: + out = self.tarfilecmd(opt, tmpname, *files) + self.assertEqual(out, b'') + with tarfile.open(tmpname) as tar: + tar.getmembers() + finally: + support.unlink(tmpname) + + def test_create_command_verbose(self): + files = [support.findfile('tokenize_tests.txt'), + support.findfile('tokenize_tests-no-coding-cookie-' + 'and-utf8-bom-sig-only.txt')] + for opt in '-v', '--verbose': + try: + out = self.tarfilecmd(opt, '-c', tmpname, *files) + self.assertIn(b' file created.', out) + with tarfile.open(tmpname) as tar: + tar.getmembers() + finally: + support.unlink(tmpname) + + def test_create_command_dotless_filename(self): + files = [support.findfile('tokenize_tests.txt')] + try: + out = self.tarfilecmd('-c', dotlessname, *files) + self.assertEqual(out, b'') + with tarfile.open(dotlessname) as tar: + tar.getmembers() + finally: + support.unlink(dotlessname) + + def test_create_command_dot_started_filename(self): + tar_name = os.path.join(TEMPDIR, ".testtar") + files = [support.findfile('tokenize_tests.txt')] + try: + out = self.tarfilecmd('-c', tar_name, *files) + self.assertEqual(out, b'') + with tarfile.open(tar_name) as tar: + tar.getmembers() + finally: + support.unlink(tar_name) + + def test_extract_command(self): + self.make_simple_tarfile(tmpname) + for opt in '-e', '--extract': + try: + with support.temp_cwd(tarextdir): + out = self.tarfilecmd(opt, tmpname) + self.assertEqual(out, b'') + finally: + support.rmtree(tarextdir) + + def test_extract_command_verbose(self): + self.make_simple_tarfile(tmpname) + for opt in '-v', '--verbose': + try: + with support.temp_cwd(tarextdir): + out = self.tarfilecmd(opt, '-e', tmpname) + self.assertIn(b' file is extracted.', out) + finally: + support.rmtree(tarextdir) + + def test_extract_command_different_directory(self): + self.make_simple_tarfile(tmpname) + try: + with support.temp_cwd(tarextdir): + out = self.tarfilecmd('-e', tmpname, 'spamdir') + self.assertEqual(out, b'') + finally: + support.rmtree(tarextdir) + + def test_extract_command_invalid_file(self): + zipname = support.findfile('zipdir.zip') + with support.temp_cwd(tarextdir): + rc, out, err = self.tarfilecmd_failure('-e', zipname) + self.assertIn(b' is not a tar archive.', err) + self.assertEqual(out, b'') + self.assertEqual(rc, 1) + + class ContextManagerTest(unittest.TestCase): def test_basic(self): @@ -1855,20 +2017,20 @@ class ContextManagerTest(unittest.TestCase): self.assertTrue(tar.closed, "context manager failed") def test_closed(self): - # The __enter__() method is supposed to raise IOError + # The __enter__() method is supposed to raise OSError # if the TarFile object is already closed. tar = tarfile.open(tarname) tar.close() - with self.assertRaises(IOError): + with self.assertRaises(OSError): with tar: pass def test_exception(self): - # Test if the IOError exception is passed through properly. + # Test if the OSError exception is passed through properly. with self.assertRaises(Exception) as exc: with tarfile.open(tarname) as tar: - raise IOError - self.assertIsInstance(exc.exception, IOError, + raise OSError + self.assertIsInstance(exc.exception, OSError, "wrong exception raised in context manager") self.assertTrue(tar.closed, "context manager failed") @@ -1974,6 +2136,8 @@ def setUpModule(): support.unlink(TEMPDIR) os.makedirs(TEMPDIR) + global testtarnames + testtarnames = [tarname] with open(tarname, "rb") as fobj: data = fobj.read() @@ -1981,6 +2145,7 @@ def setUpModule(): for c in GzipTest, Bz2Test, LzmaTest: if c.open: support.unlink(c.tarname) + testtarnames.append(c.tarname) with c.open(c.tarname, "wb") as tar: tar.write(data) diff --git a/Lib/test/test_tcl.py b/Lib/test/test_tcl.py index 5226ee676e..d12fb22484 100644 --- a/Lib/test/test_tcl.py +++ b/Lib/test/test_tcl.py @@ -504,40 +504,6 @@ class TclTest(unittest.TestCase): for arg, res in testcases: self.assertEqual(split(arg), res, msg=arg) - def test_merge(self): - with support.check_warnings(('merge is deprecated', - DeprecationWarning)): - merge = self.interp.tk.merge - call = self.interp.tk.call - testcases = [ - ((), ''), - (('a',), 'a'), - ((2,), '2'), - (('',), '{}'), - ('{', '\\{'), - (('a', 'b', 'c'), 'a b c'), - ((' ', '\t', '\r', '\n'), '{ } {\t} {\r} {\n}'), - (('a', ' ', 'c'), 'a { } c'), - (('a', '€'), 'a €'), - (('a', '\U000104a2'), 'a \U000104a2'), - (('a', b'\xe2\x82\xac'), 'a €'), - (('a', ('b', 'c')), 'a {b c}'), - (('a', 2), 'a 2'), - (('a', 3.4), 'a 3.4'), - (('a', (2, 3.4)), 'a {2 3.4}'), - ((), ''), - ((call('list', 1, '2', (3.4,)),), '{1 2 3.4}'), - ] - if tcl_version >= (8, 5): - testcases += [ - ((call('dict', 'create', 12, '\u20ac', b'\xe2\x82\xac', (3.4,)),), - '{12 € € 3.4}'), - ] - for args, res in testcases: - self.assertEqual(merge(*args), res, msg=args) - self.assertRaises(UnicodeDecodeError, merge, b'\x80') - self.assertRaises(UnicodeEncodeError, merge, '\udc80') - class BigmemTclTest(unittest.TestCase): diff --git a/Lib/test/test_telnetlib.py b/Lib/test/test_telnetlib.py index c9f2ccb462..ee1c35705b 100644 --- a/Lib/test/test_telnetlib.py +++ b/Lib/test/test_telnetlib.py @@ -1,10 +1,9 @@ import socket -import select +import selectors import telnetlib import time import contextlib -import unittest from unittest import TestCase from test import support threading = support.import_module('threading') @@ -112,40 +111,37 @@ class TelnetAlike(telnetlib.Telnet): self._messages += out.getvalue() return -def mock_select(*s_args): - block = False - for l in s_args: - for fob in l: - if isinstance(fob, TelnetAlike): - block = fob.sock.block - if block: - return [[], [], []] - else: - return s_args - -class MockPoller(object): - test_case = None # Set during TestCase setUp. +class MockSelector(selectors.BaseSelector): def __init__(self): - self._file_objs = [] + self.keys = {} + + @property + def resolution(self): + return 1e-3 + + def register(self, fileobj, events, data=None): + key = selectors.SelectorKey(fileobj, 0, events, data) + self.keys[fileobj] = key + return key - def register(self, fd, eventmask): - self.test_case.assertTrue(hasattr(fd, 'fileno'), fd) - self.test_case.assertEqual(eventmask, select.POLLIN|select.POLLPRI) - self._file_objs.append(fd) + def unregister(self, fileobj): + return self.keys.pop(fileobj) - def poll(self, timeout=None): + def select(self, timeout=None): block = False - for fob in self._file_objs: - if isinstance(fob, TelnetAlike): - block = fob.sock.block + for fileobj in self.keys: + if isinstance(fileobj, TelnetAlike): + block = fileobj.sock.block + break if block: return [] else: - return zip(self._file_objs, [select.POLLIN]*len(self._file_objs)) + return [(key, key.events) for key in self.keys.values()] + + def get_map(self): + return self.keys - def unregister(self, fd): - self._file_objs.remove(fd) @contextlib.contextmanager def test_socket(reads): @@ -159,7 +155,7 @@ def test_socket(reads): socket.create_connection = old_conn return -def test_telnet(reads=(), cls=TelnetAlike, use_poll=None): +def test_telnet(reads=(), cls=TelnetAlike): ''' return a telnetlib.Telnet object that uses a SocketStub with reads queued up to be read ''' for x in reads: @@ -167,29 +163,14 @@ def test_telnet(reads=(), cls=TelnetAlike, use_poll=None): with test_socket(reads): telnet = cls('dummy', 0) telnet._messages = '' # debuglevel output - if use_poll is not None: - if use_poll and not telnet._has_poll: - raise unittest.SkipTest('select.poll() required.') - telnet._has_poll = use_poll return telnet - class ExpectAndReadTestCase(TestCase): def setUp(self): - self.old_select = select.select - select.select = mock_select - self.old_poll = False - if hasattr(select, 'poll'): - self.old_poll = select.poll - select.poll = MockPoller - MockPoller.test_case = self - + self.old_selector = telnetlib._TelnetSelector + telnetlib._TelnetSelector = MockSelector def tearDown(self): - if self.old_poll: - MockPoller.test_case = None - select.poll = self.old_poll - select.select = self.old_select - + telnetlib._TelnetSelector = self.old_selector class ReadTests(ExpectAndReadTestCase): def test_read_until(self): @@ -208,22 +189,6 @@ class ReadTests(ExpectAndReadTestCase): data = telnet.read_until(b'match') self.assertEqual(data, expect) - def test_read_until_with_poll(self): - """Use select.poll() to implement telnet.read_until().""" - want = [b'x' * 10, b'match', b'y' * 10] - telnet = test_telnet(want, use_poll=True) - select.select = lambda *_: self.fail('unexpected select() call.') - data = telnet.read_until(b'match') - self.assertEqual(data, b''.join(want[:-1])) - - def test_read_until_with_select(self): - """Use select.select() to implement telnet.read_until().""" - want = [b'x' * 10, b'match', b'y' * 10] - telnet = test_telnet(want, use_poll=False) - if self.old_poll: - select.poll = lambda *_: self.fail('unexpected poll() call.') - data = telnet.read_until(b'match') - self.assertEqual(data, b''.join(want[:-1])) def test_read_all(self): """ @@ -427,23 +392,6 @@ class ExpectTests(ExpectAndReadTestCase): (_,_,data) = telnet.expect([b'match']) self.assertEqual(data, b''.join(want[:-1])) - def test_expect_with_poll(self): - """Use select.poll() to implement telnet.expect().""" - want = [b'x' * 10, b'match', b'y' * 10] - telnet = test_telnet(want, use_poll=True) - select.select = lambda *_: self.fail('unexpected select() call.') - (_,_,data) = telnet.expect([b'match']) - self.assertEqual(data, b''.join(want[:-1])) - - def test_expect_with_select(self): - """Use select.select() to implement telnet.expect().""" - want = [b'x' * 10, b'match', b'y' * 10] - telnet = test_telnet(want, use_poll=False) - if self.old_poll: - select.poll = lambda *_: self.fail('unexpected poll() call.') - (_,_,data) = telnet.expect([b'match']) - self.assertEqual(data, b''.join(want[:-1])) - def test_main(verbose=None): support.run_unittest(GeneralTests, ReadTests, WriteTests, OptionTests, diff --git a/Lib/test/test_tempfile.py b/Lib/test/test_tempfile.py index 5a970355e6..cf2ae080be 100644 --- a/Lib/test/test_tempfile.py +++ b/Lib/test/test_tempfile.py @@ -37,7 +37,7 @@ else: # Common functionality. class BaseTestCase(unittest.TestCase): - str_check = re.compile(r"[a-zA-Z0-9_-]{6}$") + str_check = re.compile(r"^[a-z0-9_-]{8}$") def setUp(self): self._warnings_manager = support.check_warnings() @@ -64,7 +64,7 @@ class BaseTestCase(unittest.TestCase): nbase = nbase[len(pre):len(nbase)-len(suf)] self.assertTrue(self.str_check.match(nbase), - "random string '%s' does not match /^[a-zA-Z0-9_-]{6}$/" + "random string '%s' does not match ^[a-z0-9_-]{8}$" % nbase) @@ -153,7 +153,7 @@ class TestRandomNameSequence(BaseTestCase): # via any bugs above try: os.kill(pid, signal.SIGKILL) - except EnvironmentError: + except OSError: pass os.close(read_fd) os.close(write_fd) @@ -192,7 +192,7 @@ class TestCandidateTempdirList(BaseTestCase): try: dirname = os.getcwd() - except (AttributeError, os.error): + except (AttributeError, OSError): dirname = os.curdir self.assertIn(dirname, cand) @@ -332,7 +332,7 @@ class TestMkstempInner(BaseTestCase): file = self.do_create() mode = stat.S_IMODE(os.stat(file.name).st_mode) expected = 0o600 - if sys.platform in ('win32', 'os2emx'): + if sys.platform == 'win32': # There's no distinction among 'user', 'group' and 'world'; # replicate the 'user' bits. user = expected >> 6 @@ -349,6 +349,7 @@ class TestMkstempInner(BaseTestCase): v="q" file = self.do_create() + self.assertEqual(os.get_inheritable(file.fd), False) fd = "%d" % file.fd try: @@ -365,7 +366,7 @@ class TestMkstempInner(BaseTestCase): # On Windows a spawn* /path/ with embedded spaces shouldn't be quoted, # but an arg with embedded spaces should be decorated with double # quotes on each end - if sys.platform in ('win32',): + if sys.platform == 'win32': decorated = '"%s"' % sys.executable tester = '"%s"' % tester else: @@ -475,6 +476,20 @@ class TestGetTempDir(BaseTestCase): self.assertTrue(a is b) + def test_case_sensitive(self): + # gettempdir should not flatten its case + # even on a case-insensitive file system + case_sensitive_tempdir = tempfile.mkdtemp("-Temp") + _tempdir, tempfile.tempdir = tempfile.tempdir, None + try: + with support.EnvironmentVarGuard() as env: + # Fake the first env var which is checked as a candidate + env["TMPDIR"] = case_sensitive_tempdir + self.assertEqual(tempfile.gettempdir(), case_sensitive_tempdir) + finally: + tempfile.tempdir = _tempdir + support.rmdir(case_sensitive_tempdir) + class TestMkstemp(BaseTestCase): """Test mkstemp().""" @@ -563,7 +578,7 @@ class TestMkdtemp(BaseTestCase): mode = stat.S_IMODE(os.stat(dir).st_mode) mode &= 0o777 # Mask off sticky bits inherited from /tmp expected = 0o700 - if sys.platform in ('win32', 'os2emx'): + if sys.platform == 'win32': # There's no distinction among 'user', 'group' and 'world'; # replicate the 'user' bits. user = expected >> 6 @@ -1139,8 +1154,9 @@ class TestTemporaryDirectory(BaseTestCase): def test_del_on_shutdown(self): # A TemporaryDirectory may be cleaned up during shutdown with self.do_create() as dir: - for mod in ('os', 'shutil', 'sys', 'tempfile', 'warnings'): + for mod in ('builtins', 'os', 'shutil', 'sys', 'tempfile', 'warnings'): code = """if True: + import builtins import os import shutil import sys @@ -1165,6 +1181,7 @@ class TestTemporaryDirectory(BaseTestCase): "TemporaryDirectory %s exists after cleanup" % tmp_name) err = err.decode('utf-8', 'backslashreplace') self.assertNotIn("Exception ", err) + self.assertIn("ResourceWarning: Implicitly cleaning up", err) def test_warnings_on_cleanup(self): # ResourceWarning will be triggered by __del__ diff --git a/Lib/test/test_textwrap.py b/Lib/test/test_textwrap.py index c86f5cfae8..1bba77eb9e 100644 --- a/Lib/test/test_textwrap.py +++ b/Lib/test/test_textwrap.py @@ -9,9 +9,8 @@ # import unittest -from test import support -from textwrap import TextWrapper, wrap, fill, dedent, indent +from textwrap import TextWrapper, wrap, fill, dedent, indent, shorten class BaseTestCase(unittest.TestCase): @@ -430,6 +429,90 @@ What a mess! self.check_wrap(text, 7, ["aa \xe4\xe4-", "\xe4\xe4"]) +class MaxLinesTestCase(BaseTestCase): + text = "Hello there, how are you this fine day? I'm glad to hear it!" + + def test_simple(self): + self.check_wrap(self.text, 12, + ["Hello [...]"], + max_lines=0) + self.check_wrap(self.text, 12, + ["Hello [...]"], + max_lines=1) + self.check_wrap(self.text, 12, + ["Hello there,", + "how [...]"], + max_lines=2) + self.check_wrap(self.text, 13, + ["Hello there,", + "how are [...]"], + max_lines=2) + self.check_wrap(self.text, 80, [self.text], max_lines=1) + self.check_wrap(self.text, 12, + ["Hello there,", + "how are you", + "this fine", + "day? I'm", + "glad to hear", + "it!"], + max_lines=6) + + def test_spaces(self): + # strip spaces before placeholder + self.check_wrap(self.text, 12, + ["Hello there,", + "how are you", + "this fine", + "day? [...]"], + max_lines=4) + # placeholder at the start of line + self.check_wrap(self.text, 6, + ["Hello", + "[...]"], + max_lines=2) + # final spaces + self.check_wrap(self.text + ' ' * 10, 12, + ["Hello there,", + "how are you", + "this fine", + "day? I'm", + "glad to hear", + "it!"], + max_lines=6) + + def test_placeholder(self): + self.check_wrap(self.text, 12, + ["Hello..."], + max_lines=1, + placeholder='...') + self.check_wrap(self.text, 12, + ["Hello there,", + "how are..."], + max_lines=2, + placeholder='...') + # long placeholder and indentation + with self.assertRaises(ValueError): + wrap(self.text, 16, initial_indent=' ', + max_lines=1, placeholder=' [truncated]...') + with self.assertRaises(ValueError): + wrap(self.text, 16, subsequent_indent=' ', + max_lines=2, placeholder=' [truncated]...') + self.check_wrap(self.text, 16, + [" Hello there,", + " [truncated]..."], + max_lines=2, + initial_indent=' ', + subsequent_indent=' ', + placeholder=' [truncated]...') + self.check_wrap(self.text, 16, + [" [truncated]..."], + max_lines=1, + initial_indent=' ', + subsequent_indent=' ', + placeholder=' [truncated]...') + self.check_wrap(self.text, 80, [self.text], placeholder='.' * 1000) + + class LongWordTestCase (BaseTestCase): def setUp(self): self.wrapper = TextWrapper() @@ -490,6 +573,14 @@ How *do* you spell that odd word, anyways? result = wrap(self.text, width=30, break_long_words=0) self.check(result, expect) + def test_max_lines_long(self): + self.check_wrap(self.text, 12, + ['Did you say ', + '"supercalifr', + 'agilisticexp', + '[...]'], + max_lines=4) + class IndentTestCases(BaseTestCase): @@ -777,12 +868,62 @@ class IndentTestCase(unittest.TestCase): self.assertEqual(indent(text, prefix, predicate), expect) -def test_main(): - support.run_unittest(WrapTestCase, - LongWordTestCase, - IndentTestCases, - DedentTestCase, - IndentTestCase) +class ShortenTestCase(BaseTestCase): + + def check_shorten(self, text, width, expect, **kwargs): + result = shorten(text, width, **kwargs) + self.check(result, expect) + + def test_simple(self): + # Simple case: just words, spaces, and a bit of punctuation + text = "Hello there, how are you this fine day? I'm glad to hear it!" + + self.check_shorten(text, 18, "Hello there, [...]") + self.check_shorten(text, len(text), text) + self.check_shorten(text, len(text) - 1, + "Hello there, how are you this fine day? " + "I'm glad to [...]") + + def test_placeholder(self): + text = "Hello there, how are you this fine day? I'm glad to hear it!" + + self.check_shorten(text, 17, "Hello there,$$", placeholder='$$') + self.check_shorten(text, 18, "Hello there, how$$", placeholder='$$') + self.check_shorten(text, 18, "Hello there, $$", placeholder=' $$') + self.check_shorten(text, len(text), text, placeholder='$$') + self.check_shorten(text, len(text) - 1, + "Hello there, how are you this fine day? " + "I'm glad to hear$$", placeholder='$$') + + def test_empty_string(self): + self.check_shorten("", 6, "") + + def test_whitespace(self): + # Whitespace collapsing + text = """ + This is a paragraph that already has + line breaks and \t tabs too.""" + self.check_shorten(text, 62, + "This is a paragraph that already has line " + "breaks and tabs too.") + self.check_shorten(text, 61, + "This is a paragraph that already has line " + "breaks and [...]") + + self.check_shorten("hello world! ", 12, "hello world!") + self.check_shorten("hello world! ", 11, "hello [...]") + # The leading space is trimmed from the placeholder + # (it would be ugly otherwise). + self.check_shorten("hello world! ", 10, "[...]") + + def test_width_too_small_for_placeholder(self): + shorten("x" * 20, width=8, placeholder="(......)") + with self.assertRaises(ValueError): + shorten("x" * 20, width=8, placeholder="(.......)") + + def test_first_word_too_long_but_placeholder_fits(self): + self.check_shorten("Helloo", 5, "[...]") + if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_thread.py b/Lib/test/test_thread.py index 54372813b9..614490199a 100644 --- a/Lib/test/test_thread.py +++ b/Lib/test/test_thread.py @@ -68,7 +68,7 @@ class ThreadRunningTests(BasicThreadTest): thread.stack_size(0) self.assertEqual(thread.stack_size(), 0, "stack_size not reset to default") - @unittest.skipIf(os.name not in ("nt", "os2", "posix"), 'test meant for nt, os2, and posix') + @unittest.skipIf(os.name not in ("nt", "posix"), 'test meant for nt and posix') def test_nt_and_posix_stack_size(self): try: thread.stack_size(4096) diff --git a/Lib/test/test_threaded_import.py b/Lib/test/test_threaded_import.py index 6c2965ba6e..c7bfea07a1 100644 --- a/Lib/test/test_threaded_import.py +++ b/Lib/test/test_threaded_import.py @@ -5,8 +5,8 @@ # complains several times about module random having no attribute # randrange, and then Python hangs. +import _imp as imp import os -import imp import importlib import sys import time @@ -57,7 +57,7 @@ circular_imports_modules = { } class Finder: - """A dummy finder to detect concurrent access to its find_module() + """A dummy finder to detect concurrent access to its find_spec() method.""" def __init__(self): @@ -65,8 +65,8 @@ class Finder: self.x = 0 self.lock = threading.Lock() - def find_module(self, name, path=None): - # Simulate some thread-unsafe behaviour. If calls to find_module() + def find_spec(self, name, path=None, target=None): + # Simulate some thread-unsafe behaviour. If calls to find_spec() # are properly serialized, `x` will end up the same as `numcalls`. # Otherwise not. assert imp.lock_held() @@ -80,7 +80,7 @@ class FlushingFinder: """A dummy finder which flushes sys.path_importer_cache when it gets called.""" - def find_module(self, name, path=None): + def find_spec(self, name, path=None, target=None): sys.path_importer_cache.clear() @@ -145,13 +145,13 @@ class ThreadedImportTests(unittest.TestCase): # dedicated meta_path entry. flushing_finder = FlushingFinder() def path_hook(path): - finder.find_module('') + finder.find_spec('') raise ImportError sys.path_hooks.insert(0, path_hook) sys.meta_path.append(flushing_finder) try: # Flush the cache a first time - flushing_finder.find_module('') + flushing_finder.find_spec('') numtests = self.check_parallel_module_init() self.assertGreater(finder.numcalls, 0) self.assertEqual(finder.x, finder.numcalls) diff --git a/Lib/test/test_threading.py b/Lib/test/test_threading.py index 11c8979517..7170f60f1c 100644 --- a/Lib/test/test_threading.py +++ b/Lib/test/test_threading.py @@ -17,13 +17,18 @@ import weakref import os from test.script_helper import assert_python_ok, assert_python_failure import subprocess -try: - import _testcapi -except ImportError: - _testcapi = None from test import lock_tests + +# Between fork() and exec(), only async-safe functions are allowed (issues +# #12316 and #11870), and fork() from a worker thread is known to trigger +# problems with some operating systems (issue #3863): skip problematic tests +# on platforms known to behave badly. +platforms_to_skip = ('freebsd4', 'freebsd5', 'freebsd6', 'netbsd5', + 'hp-ux11') + + # A trivial mutable counter. class Counter(object): def __init__(self): @@ -103,7 +108,7 @@ class ThreadTests(BaseTestCase): if verbose: print('waiting for all tasks to complete') for t in threads: - t.join(NUMTASKS) + t.join() self.assertTrue(not t.is_alive()) self.assertNotEqual(t.ident, 0) self.assertFalse(t.ident is None) @@ -471,6 +476,127 @@ class ThreadTests(BaseTestCase): pid, status = os.waitpid(pid, 0) self.assertEqual(0, status) + def test_main_thread(self): + main = threading.main_thread() + self.assertEqual(main.name, 'MainThread') + self.assertEqual(main.ident, threading.current_thread().ident) + self.assertEqual(main.ident, threading.get_ident()) + + def f(): + self.assertNotEqual(threading.main_thread().ident, + threading.current_thread().ident) + th = threading.Thread(target=f) + th.start() + th.join() + + @unittest.skipUnless(hasattr(os, 'fork'), "test needs os.fork()") + @unittest.skipUnless(hasattr(os, 'waitpid'), "test needs os.waitpid()") + def test_main_thread_after_fork(self): + code = """if 1: + import os, threading + + pid = os.fork() + if pid == 0: + main = threading.main_thread() + print(main.name) + print(main.ident == threading.current_thread().ident) + print(main.ident == threading.get_ident()) + else: + os.waitpid(pid, 0) + """ + _, out, err = assert_python_ok("-c", code) + data = out.decode().replace('\r', '') + self.assertEqual(err, b"") + self.assertEqual(data, "MainThread\nTrue\nTrue\n") + + @unittest.skipIf(sys.platform in platforms_to_skip, "due to known OS bug") + @unittest.skipUnless(hasattr(os, 'fork'), "test needs os.fork()") + @unittest.skipUnless(hasattr(os, 'waitpid'), "test needs os.waitpid()") + def test_main_thread_after_fork_from_nonmain_thread(self): + code = """if 1: + import os, threading, sys + + def f(): + pid = os.fork() + if pid == 0: + main = threading.main_thread() + print(main.name) + print(main.ident == threading.current_thread().ident) + print(main.ident == threading.get_ident()) + # stdout is fully buffered because not a tty, + # we have to flush before exit. + sys.stdout.flush() + else: + os.waitpid(pid, 0) + + th = threading.Thread(target=f) + th.start() + th.join() + """ + _, out, err = assert_python_ok("-c", code) + data = out.decode().replace('\r', '') + self.assertEqual(err, b"") + self.assertEqual(data, "Thread-1\nTrue\nTrue\n") + + def test_tstate_lock(self): + # Test an implementation detail of Thread objects. + started = _thread.allocate_lock() + finish = _thread.allocate_lock() + started.acquire() + finish.acquire() + def f(): + started.release() + finish.acquire() + time.sleep(0.01) + # The tstate lock is None until the thread is started + t = threading.Thread(target=f) + self.assertIs(t._tstate_lock, None) + t.start() + started.acquire() + self.assertTrue(t.is_alive()) + # The tstate lock can't be acquired when the thread is running + # (or suspended). + tstate_lock = t._tstate_lock + self.assertFalse(tstate_lock.acquire(timeout=0), False) + finish.release() + # When the thread ends, the state_lock can be successfully + # acquired. + self.assertTrue(tstate_lock.acquire(timeout=5), False) + # But is_alive() is still True: we hold _tstate_lock now, which + # prevents is_alive() from knowing the thread's end-of-life C code + # is done. + self.assertTrue(t.is_alive()) + # Let is_alive() find out the C code is done. + tstate_lock.release() + self.assertFalse(t.is_alive()) + # And verify the thread disposed of _tstate_lock. + self.assertTrue(t._tstate_lock is None) + + def test_repr_stopped(self): + # Verify that "stopped" shows up in repr(Thread) appropriately. + started = _thread.allocate_lock() + finish = _thread.allocate_lock() + started.acquire() + finish.acquire() + def f(): + started.release() + finish.acquire() + t = threading.Thread(target=f) + t.start() + started.acquire() + self.assertIn("started", repr(t)) + finish.release() + # "stopped" should appear in the repr in a reasonable amount of time. + # Implementation detail: as of this writing, that's trivially true + # if .join() is called, and almost trivially true if .is_alive() is + # called. The detail we're testing here is that "stopped" shows up + # "all on its own". + LOOKING_FOR = "stopped" + for i in range(500): + if LOOKING_FOR in repr(t): + break + time.sleep(0.01) + self.assertIn(LOOKING_FOR, repr(t)) # we waited at least 5 seconds def test_BoundedSemaphore_limit(self): # BoundedSemaphore should raise ValueError if released too often. @@ -490,14 +616,48 @@ class ThreadTests(BaseTestCase): t.join() self.assertRaises(ValueError, bs.release) -class ThreadJoinOnShutdown(BaseTestCase): + @cpython_only + def test_frame_tstate_tracing(self): + # Issue #14432: Crash when a generator is created in a C thread that is + # destroyed while the generator is still used. The issue was that a + # generator contains a frame, and the frame kept a reference to the + # Python state of the destroyed C thread. The crash occurs when a trace + # function is setup. + + def noop_trace(frame, event, arg): + # no operation + return noop_trace + + def generator(): + while 1: + yield "genereator" + + def callback(): + if callback.gen is None: + callback.gen = generator() + return next(callback.gen) + callback.gen = None + + old_trace = sys.gettrace() + sys.settrace(noop_trace) + try: + # Install a trace function + threading.settrace(noop_trace) - # Between fork() and exec(), only async-safe functions are allowed (issues - # #12316 and #11870), and fork() from a worker thread is known to trigger - # problems with some operating systems (issue #3863): skip problematic tests - # on platforms known to behave badly. - platforms_to_skip = ('freebsd4', 'freebsd5', 'freebsd6', 'netbsd5', - 'os2emx', 'hp-ux11') + # Create a generator in a C thread which exits after the call + import _testcapi + _testcapi.call_in_temporary_c_thread(callback) + + # Call the generator in a different Python thread, check that the + # generator didn't keep a reference to the destroyed thread state + for test in range(3): + # The trace function is still called here + callback() + finally: + sys.settrace(old_trace) + + +class ThreadJoinOnShutdown(BaseTestCase): def _run_and_join(self, script): script = """if 1: @@ -570,144 +730,8 @@ class ThreadJoinOnShutdown(BaseTestCase): """ self._run_and_join(script) - def assertScriptHasOutput(self, script, expected_output): - rc, out, err = assert_python_ok("-c", script) - data = out.decode().replace('\r', '') - self.assertEqual(data, expected_output) - - @unittest.skipUnless(hasattr(os, 'fork'), "needs os.fork()") - @unittest.skipIf(sys.platform in platforms_to_skip, "due to known OS bug") - def test_4_joining_across_fork_in_worker_thread(self): - # There used to be a possible deadlock when forking from a child - # thread. See http://bugs.python.org/issue6643. - - # The script takes the following steps: - # - The main thread in the parent process starts a new thread and then - # tries to join it. - # - The join operation acquires the Lock inside the thread's _block - # Condition. (See threading.py:Thread.join().) - # - We stub out the acquire method on the condition to force it to wait - # until the child thread forks. (See LOCK ACQUIRED HERE) - # - The child thread forks. (See LOCK HELD and WORKER THREAD FORKS - # HERE) - # - The main thread of the parent process enters Condition.wait(), - # which releases the lock on the child thread. - # - The child process returns. Without the necessary fix, when the - # main thread of the child process (which used to be the child thread - # in the parent process) attempts to exit, it will try to acquire the - # lock in the Thread._block Condition object and hang, because the - # lock was held across the fork. - - script = """if 1: - import os, time, threading - - finish_join = False - start_fork = False - - def worker(): - # Wait until this thread's lock is acquired before forking to - # create the deadlock. - global finish_join - while not start_fork: - time.sleep(0.01) - # LOCK HELD: Main thread holds lock across this call. - childpid = os.fork() - finish_join = True - if childpid != 0: - # Parent process just waits for child. - os.waitpid(childpid, 0) - # Child process should just return. - - w = threading.Thread(target=worker) - - # Stub out the private condition variable's lock acquire method. - # This acquires the lock and then waits until the child has forked - # before returning, which will release the lock soon after. If - # someone else tries to fix this test case by acquiring this lock - # before forking instead of resetting it, the test case will - # deadlock when it shouldn't. - condition = w._block - orig_acquire = condition.acquire - call_count_lock = threading.Lock() - call_count = 0 - def my_acquire(): - global call_count - global start_fork - orig_acquire() # LOCK ACQUIRED HERE - start_fork = True - if call_count == 0: - while not finish_join: - time.sleep(0.01) # WORKER THREAD FORKS HERE - with call_count_lock: - call_count += 1 - condition.acquire = my_acquire - - w.start() - w.join() - print('end of main') - """ - self.assertScriptHasOutput(script, "end of main\n") - - @unittest.skipUnless(hasattr(os, 'fork'), "needs os.fork()") @unittest.skipIf(sys.platform in platforms_to_skip, "due to known OS bug") - def test_5_clear_waiter_locks_to_avoid_crash(self): - # Check that a spawned thread that forks doesn't segfault on certain - # platforms, namely OS X. This used to happen if there was a waiter - # lock in the thread's condition variable's waiters list. Even though - # we know the lock will be held across the fork, it is not safe to - # release locks held across forks on all platforms, so releasing the - # waiter lock caused a segfault on OS X. Furthermore, since locks on - # OS X are (as of this writing) implemented with a mutex + condition - # variable instead of a semaphore, while we know that the Python-level - # lock will be acquired, we can't know if the internal mutex will be - # acquired at the time of the fork. - - script = """if True: - import os, time, threading - - start_fork = False - - def worker(): - # Wait until the main thread has attempted to join this thread - # before continuing. - while not start_fork: - time.sleep(0.01) - childpid = os.fork() - if childpid != 0: - # Parent process just waits for child. - (cpid, rc) = os.waitpid(childpid, 0) - assert cpid == childpid - assert rc == 0 - print('end of worker thread') - else: - # Child process should just return. - pass - - w = threading.Thread(target=worker) - - # Stub out the private condition variable's _release_save method. - # This releases the condition's lock and flips the global that - # causes the worker to fork. At this point, the problematic waiter - # lock has been acquired once by the waiter and has been put onto - # the waiters list. - condition = w._block - orig_release_save = condition._release_save - def my_release_save(): - global start_fork - orig_release_save() - # Waiter lock held here, condition lock released. - start_fork = True - condition._release_save = my_release_save - - w.start() - w.join() - print('end of main thread') - """ - output = "end of worker thread\nend of main thread\n" - self.assertScriptHasOutput(script, output) - - @unittest.skipIf(sys.platform in platforms_to_skip, "due to known OS bug") - def test_6_daemon_threads(self): + def test_4_daemon_threads(self): # Check that a daemon thread cannot crash the interpreter on shutdown # by manipulating internal structures that are being disposed of in # the main thread. @@ -773,45 +797,111 @@ class ThreadJoinOnShutdown(BaseTestCase): for t in threads: t.join() - @cpython_only - @unittest.skipIf(_testcapi is None, "need _testcapi module") - def test_frame_tstate_tracing(self): - # Issue #14432: Crash when a generator is created in a C thread that is - # destroyed while the generator is still used. The issue was that a - # generator contains a frame, and the frame kept a reference to the - # Python state of the destroyed C thread. The crash occurs when a trace - # function is setup. + @unittest.skipUnless(hasattr(os, 'fork'), "needs os.fork()") + def test_clear_threads_states_after_fork(self): + # Issue #17094: check that threads states are cleared after fork() - def noop_trace(frame, event, arg): - # no operation - return noop_trace + # start a bunch of threads + threads = [] + for i in range(16): + t = threading.Thread(target=lambda : time.sleep(0.3)) + threads.append(t) + t.start() - def generator(): - while 1: - yield "genereator" + pid = os.fork() + if pid == 0: + # check that threads states have been cleared + if len(sys._current_frames()) == 1: + os._exit(0) + else: + os._exit(1) + else: + _, status = os.waitpid(pid, 0) + self.assertEqual(0, status) - def callback(): - if callback.gen is None: - callback.gen = generator() - return next(callback.gen) - callback.gen = None + for t in threads: + t.join() - old_trace = sys.gettrace() - sys.settrace(noop_trace) - try: - # Install a trace function - threading.settrace(noop_trace) - # Create a generator in a C thread which exits after the call - _testcapi.call_in_temporary_c_thread(callback) +class SubinterpThreadingTests(BaseTestCase): - # Call the generator in a different Python thread, check that the - # generator didn't keep a reference to the destroyed thread state - for test in range(3): - # The trace function is still called here - callback() - finally: - sys.settrace(old_trace) + def test_threads_join(self): + # Non-daemon threads should be joined at subinterpreter shutdown + # (issue #18808) + r, w = os.pipe() + self.addCleanup(os.close, r) + self.addCleanup(os.close, w) + code = r"""if 1: + import os + import threading + import time + + def f(): + # Sleep a bit so that the thread is still running when + # Py_EndInterpreter is called. + time.sleep(0.05) + os.write(%d, b"x") + threading.Thread(target=f).start() + """ % (w,) + ret = test.support.run_in_subinterp(code) + self.assertEqual(ret, 0) + # The thread was joined properly. + self.assertEqual(os.read(r, 1), b"x") + + def test_threads_join_2(self): + # Same as above, but a delay gets introduced after the thread's + # Python code returned but before the thread state is deleted. + # To achieve this, we register a thread-local object which sleeps + # a bit when deallocated. + r, w = os.pipe() + self.addCleanup(os.close, r) + self.addCleanup(os.close, w) + code = r"""if 1: + import os + import threading + import time + + class Sleeper: + def __del__(self): + time.sleep(0.05) + + tls = threading.local() + + def f(): + # Sleep a bit so that the thread is still running when + # Py_EndInterpreter is called. + time.sleep(0.05) + tls.x = Sleeper() + os.write(%d, b"x") + threading.Thread(target=f).start() + """ % (w,) + ret = test.support.run_in_subinterp(code) + self.assertEqual(ret, 0) + # The thread was joined properly. + self.assertEqual(os.read(r, 1), b"x") + + @cpython_only + def test_daemon_threads_fatal_error(self): + subinterp_code = r"""if 1: + import os + import threading + import time + + def f(): + # Make sure the daemon thread is still running when + # Py_EndInterpreter is called. + time.sleep(10) + threading.Thread(target=f, daemon=True).start() + """ + script = r"""if 1: + import _testcapi + + _testcapi.run_in_subinterp(%r) + """ % (subinterp_code,) + with test.support.SuppressCrashReport(): + rc, out, err = assert_python_failure("-c", script) + self.assertIn("Fatal Python error: Py_EndInterpreter: " + "not the last thread", err.decode()) class ThreadingExceptionTests(BaseTestCase): diff --git a/Lib/test/test_threadsignals.py b/Lib/test/test_threadsignals.py index f975a75e85..9d92742375 100644 --- a/Lib/test/test_threadsignals.py +++ b/Lib/test/test_threadsignals.py @@ -8,7 +8,7 @@ from test.support import run_unittest, import_module thread = import_module('_thread') import time -if sys.platform[:3] in ('win', 'os2') or sys.platform=='riscos': +if (sys.platform[:3] == 'win'): raise unittest.SkipTest("Can't test signal on %s" % sys.platform) process_pid = os.getpid() @@ -74,6 +74,9 @@ class ThreadSignals(unittest.TestCase): @unittest.skipIf(USING_PTHREAD_COND, 'POSIX condition variables cannot be interrupted') + # Issue #20564: sem_timedwait() cannot be interrupted on OpenBSD + @unittest.skipIf(sys.platform.startswith('openbsd'), + 'lock cannot be interrupted on OpenBSD') def test_lock_acquire_interruption(self): # Mimic receiving a SIGINT (KeyboardInterrupt) with SIGALRM while stuck # in a deadlock. @@ -97,6 +100,9 @@ class ThreadSignals(unittest.TestCase): @unittest.skipIf(USING_PTHREAD_COND, 'POSIX condition variables cannot be interrupted') + # Issue #20564: sem_timedwait() cannot be interrupted on OpenBSD + @unittest.skipIf(sys.platform.startswith('openbsd'), + 'lock cannot be interrupted on OpenBSD') def test_rlock_acquire_interruption(self): # Mimic receiving a SIGINT (KeyboardInterrupt) with SIGALRM while stuck # in a deadlock. diff --git a/Lib/test/test_time.py b/Lib/test/test_time.py index 1a4d873a57..be7ddcc34d 100644 --- a/Lib/test/test_time.py +++ b/Lib/test/test_time.py @@ -14,6 +14,8 @@ except ImportError: SIZEOF_INT = sysconfig.get_config_var('SIZEOF_INT') or 4 TIME_MAXYEAR = (1 << 8 * SIZEOF_INT - 1) - 1 TIME_MINYEAR = -TIME_MAXYEAR - 1 +_PyTime_ROUND_DOWN = 0 +_PyTime_ROUND_UP = 1 class TimeTestCase(unittest.TestCase): @@ -226,7 +228,7 @@ class TimeTestCase(unittest.TestCase): self.assertEqual(time.ctime(t), 'Sun Sep 16 01:03:52 1973') t = time.mktime((2000, 1, 1, 0, 0, 0, 0, 0, -1)) self.assertEqual(time.ctime(t), 'Sat Jan 1 00:00:00 2000') - for year in [-100, 100, 1000, 2000, 10000]: + for year in [-100, 100, 1000, 2000, 2050, 10000]: try: testval = time.mktime((year, 1, 10) + (0,)*6) except (ValueError, OverflowError): @@ -344,6 +346,13 @@ class TimeTestCase(unittest.TestCase): def test_mktime(self): # Issue #1726687 for t in (-2, -1, 0, 1): + if sys.platform.startswith('aix') and t == -1: + # Issue #11188, #19748: mktime() returns -1 on error. On Linux, + # the tm_wday field is used as a sentinel () to detect if -1 is + # really an error or a valid timestamp. On AIX, tm_wday is + # unchanged even on success and so cannot be used as a + # sentinel. + continue try: tt = time.localtime(t) except (OverflowError, OSError): @@ -371,6 +380,14 @@ class TimeTestCase(unittest.TestCase): @unittest.skipUnless(hasattr(time, 'monotonic'), 'need time.monotonic') def test_monotonic(self): + # monotonic() should not go backward + times = [time.monotonic() for n in range(100)] + t1 = times[0] + for t2 in times[1:]: + self.assertGreaterEqual(t2, t1, "times=%s" % times) + t1 = t2 + + # monotonic() includes time elapsed during a sleep t1 = time.monotonic() time.sleep(0.5) t2 = time.monotonic() @@ -379,6 +396,7 @@ class TimeTestCase(unittest.TestCase): # Issue #20101: On some Windows machines, dt may be slightly low self.assertTrue(0.45 <= dt <= 1.0, dt) + # monotonic() is a monotonic but non adjustable clock info = time.get_clock_info('monotonic') self.assertTrue(info.monotonic) self.assertFalse(info.adjustable) @@ -576,58 +594,116 @@ class TestPytime(unittest.TestCase): @support.cpython_only def test_time_t(self): from _testcapi import pytime_object_to_time_t - for obj, time_t in ( - (0, 0), - (-1, -1), - (-1.0, -1), - (-1.9, -1), - (1.0, 1), - (1.9, 1), + for obj, time_t, rnd in ( + # Round towards zero + (0, 0, _PyTime_ROUND_DOWN), + (-1, -1, _PyTime_ROUND_DOWN), + (-1.0, -1, _PyTime_ROUND_DOWN), + (-1.9, -1, _PyTime_ROUND_DOWN), + (1.0, 1, _PyTime_ROUND_DOWN), + (1.9, 1, _PyTime_ROUND_DOWN), + # Round away from zero + (0, 0, _PyTime_ROUND_UP), + (-1, -1, _PyTime_ROUND_UP), + (-1.0, -1, _PyTime_ROUND_UP), + (-1.9, -2, _PyTime_ROUND_UP), + (1.0, 1, _PyTime_ROUND_UP), + (1.9, 2, _PyTime_ROUND_UP), ): - self.assertEqual(pytime_object_to_time_t(obj), time_t) + self.assertEqual(pytime_object_to_time_t(obj, rnd), time_t) + rnd = _PyTime_ROUND_DOWN for invalid in self.invalid_values: - self.assertRaises(OverflowError, pytime_object_to_time_t, invalid) + self.assertRaises(OverflowError, + pytime_object_to_time_t, invalid, rnd) @support.cpython_only def test_timeval(self): from _testcapi import pytime_object_to_timeval - for obj, timeval in ( - (0, (0, 0)), - (-1, (-1, 0)), - (-1.0, (-1, 0)), - (1e-6, (0, 1)), - (-1e-6, (-1, 999999)), - (-1.2, (-2, 800000)), - (1.1234560, (1, 123456)), - (1.1234569, (1, 123456)), - (-1.1234560, (-2, 876544)), - (-1.1234561, (-2, 876543)), + for obj, timeval, rnd in ( + # Round towards zero + (0, (0, 0), _PyTime_ROUND_DOWN), + (-1, (-1, 0), _PyTime_ROUND_DOWN), + (-1.0, (-1, 0), _PyTime_ROUND_DOWN), + (1e-6, (0, 1), _PyTime_ROUND_DOWN), + (1e-7, (0, 0), _PyTime_ROUND_DOWN), + (-1e-6, (-1, 999999), _PyTime_ROUND_DOWN), + (-1e-7, (-1, 999999), _PyTime_ROUND_DOWN), + (-1.2, (-2, 800000), _PyTime_ROUND_DOWN), + (0.9999999, (0, 999999), _PyTime_ROUND_DOWN), + (0.0000041, (0, 4), _PyTime_ROUND_DOWN), + (1.1234560, (1, 123456), _PyTime_ROUND_DOWN), + (1.1234569, (1, 123456), _PyTime_ROUND_DOWN), + (-0.0000040, (-1, 999996), _PyTime_ROUND_DOWN), + (-0.0000041, (-1, 999995), _PyTime_ROUND_DOWN), + (-1.1234560, (-2, 876544), _PyTime_ROUND_DOWN), + (-1.1234561, (-2, 876543), _PyTime_ROUND_DOWN), + # Round away from zero + (0, (0, 0), _PyTime_ROUND_UP), + (-1, (-1, 0), _PyTime_ROUND_UP), + (-1.0, (-1, 0), _PyTime_ROUND_UP), + (1e-6, (0, 1), _PyTime_ROUND_UP), + (1e-7, (0, 1), _PyTime_ROUND_UP), + (-1e-6, (-1, 999999), _PyTime_ROUND_UP), + (-1e-7, (-1, 999999), _PyTime_ROUND_UP), + (-1.2, (-2, 800000), _PyTime_ROUND_UP), + (0.9999999, (1, 0), _PyTime_ROUND_UP), + (0.0000041, (0, 5), _PyTime_ROUND_UP), + (1.1234560, (1, 123457), _PyTime_ROUND_UP), + (1.1234569, (1, 123457), _PyTime_ROUND_UP), + (-0.0000040, (-1, 999996), _PyTime_ROUND_UP), + (-0.0000041, (-1, 999995), _PyTime_ROUND_UP), + (-1.1234560, (-2, 876544), _PyTime_ROUND_UP), + (-1.1234561, (-2, 876543), _PyTime_ROUND_UP), ): - self.assertEqual(pytime_object_to_timeval(obj), timeval) + with self.subTest(obj=obj, round=rnd, timeval=timeval): + self.assertEqual(pytime_object_to_timeval(obj, rnd), timeval) + rnd = _PyTime_ROUND_DOWN for invalid in self.invalid_values: - self.assertRaises(OverflowError, pytime_object_to_timeval, invalid) + self.assertRaises(OverflowError, + pytime_object_to_timeval, invalid, rnd) @support.cpython_only def test_timespec(self): from _testcapi import pytime_object_to_timespec - for obj, timespec in ( - (0, (0, 0)), - (-1, (-1, 0)), - (-1.0, (-1, 0)), - (1e-9, (0, 1)), - (-1e-9, (-1, 999999999)), - (-1.2, (-2, 800000000)), - (1.1234567890, (1, 123456789)), - (1.1234567899, (1, 123456789)), - (-1.1234567890, (-2, 876543211)), - (-1.1234567891, (-2, 876543210)), + for obj, timespec, rnd in ( + # Round towards zero + (0, (0, 0), _PyTime_ROUND_DOWN), + (-1, (-1, 0), _PyTime_ROUND_DOWN), + (-1.0, (-1, 0), _PyTime_ROUND_DOWN), + (1e-9, (0, 1), _PyTime_ROUND_DOWN), + (1e-10, (0, 0), _PyTime_ROUND_DOWN), + (-1e-9, (-1, 999999999), _PyTime_ROUND_DOWN), + (-1e-10, (-1, 999999999), _PyTime_ROUND_DOWN), + (-1.2, (-2, 800000000), _PyTime_ROUND_DOWN), + (0.9999999999, (0, 999999999), _PyTime_ROUND_DOWN), + (1.1234567890, (1, 123456789), _PyTime_ROUND_DOWN), + (1.1234567899, (1, 123456789), _PyTime_ROUND_DOWN), + (-1.1234567890, (-2, 876543211), _PyTime_ROUND_DOWN), + (-1.1234567891, (-2, 876543210), _PyTime_ROUND_DOWN), + # Round away from zero + (0, (0, 0), _PyTime_ROUND_UP), + (-1, (-1, 0), _PyTime_ROUND_UP), + (-1.0, (-1, 0), _PyTime_ROUND_UP), + (1e-9, (0, 1), _PyTime_ROUND_UP), + (1e-10, (0, 1), _PyTime_ROUND_UP), + (-1e-9, (-1, 999999999), _PyTime_ROUND_UP), + (-1e-10, (-1, 999999999), _PyTime_ROUND_UP), + (-1.2, (-2, 800000000), _PyTime_ROUND_UP), + (0.9999999999, (1, 0), _PyTime_ROUND_UP), + (1.1234567890, (1, 123456790), _PyTime_ROUND_UP), + (1.1234567899, (1, 123456790), _PyTime_ROUND_UP), + (-1.1234567890, (-2, 876543211), _PyTime_ROUND_UP), + (-1.1234567891, (-2, 876543210), _PyTime_ROUND_UP), ): - self.assertEqual(pytime_object_to_timespec(obj), timespec) + with self.subTest(obj=obj, round=rnd, timespec=timespec): + self.assertEqual(pytime_object_to_timespec(obj, rnd), timespec) + rnd = _PyTime_ROUND_DOWN for invalid in self.invalid_values: - self.assertRaises(OverflowError, pytime_object_to_timespec, invalid) + self.assertRaises(OverflowError, + pytime_object_to_timespec, invalid, rnd) @unittest.skipUnless(time._STRUCT_TM_ITEMS == 11, "needs tm_zone support") def test_localtime_timezone(self): diff --git a/Lib/test/test_timeout.py b/Lib/test/test_timeout.py index dcf201b507..703c43ab2b 100644 --- a/Lib/test/test_timeout.py +++ b/Lib/test/test_timeout.py @@ -207,7 +207,7 @@ class TCPTimeoutTestCase(TimeoutTestCase): sock.connect((whitehole)) except socket.timeout: pass - except IOError as err: + except OSError as err: if err.errno == errno.ECONNREFUSED: skip = False finally: diff --git a/Lib/test/test_tools.py b/Lib/test/test_tools.py index f971515899..1bf7d54c1e 100644 --- a/Lib/test/test_tools.py +++ b/Lib/test/test_tools.py @@ -6,6 +6,7 @@ Tools directory of a Python checkout or tarball, such as reindent.py. import os import sys +import importlib._bootstrap import importlib.machinery import unittest from unittest import mock @@ -405,8 +406,8 @@ class PdepsTests(unittest.TestCase): @classmethod def setUpClass(self): path = os.path.join(scriptsdir, 'pdeps.py') - loader = importlib.machinery.SourceFileLoader('pdeps', path) - self.pdeps = loader.load_module() + spec = importlib.util.spec_from_file_location('pdeps', path) + self.pdeps = importlib._bootstrap._SpecMethods(spec).load() @classmethod def tearDownClass(self): @@ -430,8 +431,8 @@ class Gprof2htmlTests(unittest.TestCase): def setUp(self): path = os.path.join(scriptsdir, 'gprof2html.py') - loader = importlib.machinery.SourceFileLoader('gprof2html', path) - self.gprof = loader.load_module() + spec = importlib.util.spec_from_file_location('gprof2html', path) + self.gprof = importlib._bootstrap._SpecMethods(spec).load() oldargv = sys.argv def fixup(): sys.argv = oldargv diff --git a/Lib/test/test_traceback.py b/Lib/test/test_traceback.py index c38c65bb31..c29556354e 100644 --- a/Lib/test/test_traceback.py +++ b/Lib/test/test_traceback.py @@ -172,32 +172,76 @@ class SyntaxTracebackCases(unittest.TestCase): class TracebackFormatTests(unittest.TestCase): + def some_exception(self): + raise KeyError('blah') + @cpython_only - def test_traceback_format(self): + def check_traceback_format(self, cleanup_func=None): from _testcapi import traceback_print try: - raise KeyError('blah') + self.some_exception() except KeyError: type_, value, tb = sys.exc_info() + if cleanup_func is not None: + # Clear the inner frames, not this one + cleanup_func(tb.tb_next) traceback_fmt = 'Traceback (most recent call last):\n' + \ ''.join(traceback.format_tb(tb)) file_ = StringIO() traceback_print(tb, file_) python_fmt = file_.getvalue() + # Call all _tb and _exc functions + with captured_output("stderr") as tbstderr: + traceback.print_tb(tb) + tbfile = StringIO() + traceback.print_tb(tb, file=tbfile) + with captured_output("stderr") as excstderr: + traceback.print_exc() + excfmt = traceback.format_exc() + excfile = StringIO() + traceback.print_exc(file=excfile) else: raise Error("unable to create test traceback string") # Make sure that Python and the traceback module format the same thing self.assertEqual(traceback_fmt, python_fmt) + # Now verify the _tb func output + self.assertEqual(tbstderr.getvalue(), tbfile.getvalue()) + # Now verify the _exc func output + self.assertEqual(excstderr.getvalue(), excfile.getvalue()) + self.assertEqual(excfmt, excfile.getvalue()) # Make sure that the traceback is properly indented. tb_lines = python_fmt.splitlines() - self.assertEqual(len(tb_lines), 3) - banner, location, source_line = tb_lines + self.assertEqual(len(tb_lines), 5) + banner = tb_lines[0] + location, source_line = tb_lines[-2:] self.assertTrue(banner.startswith('Traceback')) self.assertTrue(location.startswith(' File')) self.assertTrue(source_line.startswith(' raise')) + def test_traceback_format(self): + self.check_traceback_format() + + def test_traceback_format_with_cleared_frames(self): + # Check that traceback formatting also works with a clear()ed frame + def cleanup_tb(tb): + tb.tb_frame.clear() + self.check_traceback_format(cleanup_tb) + + def test_stack_format(self): + # Verify _stack functions. Note we have to use _getframe(1) to + # compare them without this frame appearing in the output + with captured_output("stderr") as ststderr: + traceback.print_stack(sys._getframe(1)) + stfile = StringIO() + traceback.print_stack(sys._getframe(1), file=stfile) + self.assertEqual(ststderr.getvalue(), stfile.getvalue()) + + stfmt = traceback.format_stack(sys._getframe(1)) + + self.assertEqual(ststderr.getvalue(), "".join(stfmt)) + cause_message = ( "\nThe above exception was the direct cause " @@ -370,6 +414,36 @@ class CExcReportingTests(BaseExceptionReportingTests, unittest.TestCase): return s.getvalue() +class MiscTracebackCases(unittest.TestCase): + # + # Check non-printing functions in traceback module + # + + def test_clear(self): + def outer(): + middle() + def middle(): + inner() + def inner(): + i = 1 + 1/0 + + try: + outer() + except: + type_, value, tb = sys.exc_info() + + # Initial assertion: there's one local in the inner frame. + inner_frame = tb.tb_next.tb_next.tb_next.tb_frame + self.assertEqual(len(inner_frame.f_locals), 1) + + # Clear traceback frames + traceback.clear_frames(tb) + + # Local variable dict should now be empty. + self.assertEqual(len(inner_frame.f_locals), 0) + + def test_main(): run_unittest(__name__) diff --git a/Lib/test/test_tracemalloc.py b/Lib/test/test_tracemalloc.py new file mode 100644 index 0000000000..d1e5aef5c8 --- /dev/null +++ b/Lib/test/test_tracemalloc.py @@ -0,0 +1,818 @@ +import contextlib +import os +import sys +import tracemalloc +import unittest +from unittest.mock import patch +from test.script_helper import assert_python_ok, assert_python_failure +from test import support +try: + import threading +except ImportError: + threading = None + +EMPTY_STRING_SIZE = sys.getsizeof(b'') + +def get_frames(nframe, lineno_delta): + frames = [] + frame = sys._getframe(1) + for index in range(nframe): + code = frame.f_code + lineno = frame.f_lineno + lineno_delta + frames.append((code.co_filename, lineno)) + lineno_delta = 0 + frame = frame.f_back + if frame is None: + break + return tuple(frames) + +def allocate_bytes(size): + nframe = tracemalloc.get_traceback_limit() + bytes_len = (size - EMPTY_STRING_SIZE) + frames = get_frames(nframe, 1) + data = b'x' * bytes_len + return data, tracemalloc.Traceback(frames) + +def create_snapshots(): + traceback_limit = 2 + + raw_traces = [ + (10, (('a.py', 2), ('b.py', 4))), + (10, (('a.py', 2), ('b.py', 4))), + (10, (('a.py', 2), ('b.py', 4))), + + (2, (('a.py', 5), ('b.py', 4))), + + (66, (('b.py', 1),)), + + (7, (('<unknown>', 0),)), + ] + snapshot = tracemalloc.Snapshot(raw_traces, traceback_limit) + + raw_traces2 = [ + (10, (('a.py', 2), ('b.py', 4))), + (10, (('a.py', 2), ('b.py', 4))), + (10, (('a.py', 2), ('b.py', 4))), + + (2, (('a.py', 5), ('b.py', 4))), + (5000, (('a.py', 5), ('b.py', 4))), + + (400, (('c.py', 578),)), + ] + snapshot2 = tracemalloc.Snapshot(raw_traces2, traceback_limit) + + return (snapshot, snapshot2) + +def frame(filename, lineno): + return tracemalloc._Frame((filename, lineno)) + +def traceback(*frames): + return tracemalloc.Traceback(frames) + +def traceback_lineno(filename, lineno): + return traceback((filename, lineno)) + +def traceback_filename(filename): + return traceback_lineno(filename, 0) + + +class TestTracemallocEnabled(unittest.TestCase): + def setUp(self): + if tracemalloc.is_tracing(): + self.skipTest("tracemalloc must be stopped before the test") + + tracemalloc.start(1) + + def tearDown(self): + tracemalloc.stop() + + def test_get_tracemalloc_memory(self): + data = [allocate_bytes(123) for count in range(1000)] + size = tracemalloc.get_tracemalloc_memory() + self.assertGreaterEqual(size, 0) + + tracemalloc.clear_traces() + size2 = tracemalloc.get_tracemalloc_memory() + self.assertGreaterEqual(size2, 0) + self.assertLessEqual(size2, size) + + def test_get_object_traceback(self): + tracemalloc.clear_traces() + obj_size = 12345 + obj, obj_traceback = allocate_bytes(obj_size) + traceback = tracemalloc.get_object_traceback(obj) + self.assertEqual(traceback, obj_traceback) + + def test_set_traceback_limit(self): + obj_size = 10 + + tracemalloc.stop() + self.assertRaises(ValueError, tracemalloc.start, -1) + + tracemalloc.stop() + tracemalloc.start(10) + obj2, obj2_traceback = allocate_bytes(obj_size) + traceback = tracemalloc.get_object_traceback(obj2) + self.assertEqual(len(traceback), 10) + self.assertEqual(traceback, obj2_traceback) + + tracemalloc.stop() + tracemalloc.start(1) + obj, obj_traceback = allocate_bytes(obj_size) + traceback = tracemalloc.get_object_traceback(obj) + self.assertEqual(len(traceback), 1) + self.assertEqual(traceback, obj_traceback) + + def find_trace(self, traces, traceback): + for trace in traces: + if trace[1] == traceback._frames: + return trace + + self.fail("trace not found") + + def test_get_traces(self): + tracemalloc.clear_traces() + obj_size = 12345 + obj, obj_traceback = allocate_bytes(obj_size) + + traces = tracemalloc._get_traces() + trace = self.find_trace(traces, obj_traceback) + + self.assertIsInstance(trace, tuple) + size, traceback = trace + self.assertEqual(size, obj_size) + self.assertEqual(traceback, obj_traceback._frames) + + tracemalloc.stop() + self.assertEqual(tracemalloc._get_traces(), []) + + def test_get_traces_intern_traceback(self): + # dummy wrappers to get more useful and identical frames in the traceback + def allocate_bytes2(size): + return allocate_bytes(size) + def allocate_bytes3(size): + return allocate_bytes2(size) + def allocate_bytes4(size): + return allocate_bytes3(size) + + # Ensure that two identical tracebacks are not duplicated + tracemalloc.stop() + tracemalloc.start(4) + obj_size = 123 + obj1, obj1_traceback = allocate_bytes4(obj_size) + obj2, obj2_traceback = allocate_bytes4(obj_size) + + traces = tracemalloc._get_traces() + + trace1 = self.find_trace(traces, obj1_traceback) + trace2 = self.find_trace(traces, obj2_traceback) + size1, traceback1 = trace1 + size2, traceback2 = trace2 + self.assertEqual(traceback2, traceback1) + self.assertIs(traceback2, traceback1) + + def test_get_traced_memory(self): + # Python allocates some internals objects, so the test must tolerate + # a small difference between the expected size and the real usage + max_error = 2048 + + # allocate one object + obj_size = 1024 * 1024 + tracemalloc.clear_traces() + obj, obj_traceback = allocate_bytes(obj_size) + size, peak_size = tracemalloc.get_traced_memory() + self.assertGreaterEqual(size, obj_size) + self.assertGreaterEqual(peak_size, size) + + self.assertLessEqual(size - obj_size, max_error) + self.assertLessEqual(peak_size - size, max_error) + + # destroy the object + obj = None + size2, peak_size2 = tracemalloc.get_traced_memory() + self.assertLess(size2, size) + self.assertGreaterEqual(size - size2, obj_size - max_error) + self.assertGreaterEqual(peak_size2, peak_size) + + # clear_traces() must reset traced memory counters + tracemalloc.clear_traces() + self.assertEqual(tracemalloc.get_traced_memory(), (0, 0)) + + # allocate another object + obj, obj_traceback = allocate_bytes(obj_size) + size, peak_size = tracemalloc.get_traced_memory() + self.assertGreaterEqual(size, obj_size) + + # stop() also resets traced memory counters + tracemalloc.stop() + self.assertEqual(tracemalloc.get_traced_memory(), (0, 0)) + + def test_clear_traces(self): + obj, obj_traceback = allocate_bytes(123) + traceback = tracemalloc.get_object_traceback(obj) + self.assertIsNotNone(traceback) + + tracemalloc.clear_traces() + traceback2 = tracemalloc.get_object_traceback(obj) + self.assertIsNone(traceback2) + + def test_is_tracing(self): + tracemalloc.stop() + self.assertFalse(tracemalloc.is_tracing()) + + tracemalloc.start() + self.assertTrue(tracemalloc.is_tracing()) + + def test_snapshot(self): + obj, source = allocate_bytes(123) + + # take a snapshot + snapshot = tracemalloc.take_snapshot() + + # write on disk + snapshot.dump(support.TESTFN) + self.addCleanup(support.unlink, support.TESTFN) + + # load from disk + snapshot2 = tracemalloc.Snapshot.load(support.TESTFN) + self.assertEqual(snapshot2.traces, snapshot.traces) + + # tracemalloc must be tracing memory allocations to take a snapshot + tracemalloc.stop() + with self.assertRaises(RuntimeError) as cm: + tracemalloc.take_snapshot() + self.assertEqual(str(cm.exception), + "the tracemalloc module must be tracing memory " + "allocations to take a snapshot") + + def test_snapshot_save_attr(self): + # take a snapshot with a new attribute + snapshot = tracemalloc.take_snapshot() + snapshot.test_attr = "new" + snapshot.dump(support.TESTFN) + self.addCleanup(support.unlink, support.TESTFN) + + # load() should recreates the attribute + snapshot2 = tracemalloc.Snapshot.load(support.TESTFN) + self.assertEqual(snapshot2.test_attr, "new") + + def fork_child(self): + if not tracemalloc.is_tracing(): + return 2 + + obj_size = 12345 + obj, obj_traceback = allocate_bytes(obj_size) + traceback = tracemalloc.get_object_traceback(obj) + if traceback is None: + return 3 + + # everything is fine + return 0 + + @unittest.skipUnless(hasattr(os, 'fork'), 'need os.fork()') + def test_fork(self): + # check that tracemalloc is still working after fork + pid = os.fork() + if not pid: + # child + exitcode = 1 + try: + exitcode = self.fork_child() + finally: + os._exit(exitcode) + else: + pid2, status = os.waitpid(pid, 0) + self.assertTrue(os.WIFEXITED(status)) + exitcode = os.WEXITSTATUS(status) + self.assertEqual(exitcode, 0) + + +class TestSnapshot(unittest.TestCase): + maxDiff = 4000 + + def test_create_snapshot(self): + raw_traces = [(5, (('a.py', 2),))] + + with contextlib.ExitStack() as stack: + stack.enter_context(patch.object(tracemalloc, 'is_tracing', + return_value=True)) + stack.enter_context(patch.object(tracemalloc, 'get_traceback_limit', + return_value=5)) + stack.enter_context(patch.object(tracemalloc, '_get_traces', + return_value=raw_traces)) + + snapshot = tracemalloc.take_snapshot() + self.assertEqual(snapshot.traceback_limit, 5) + self.assertEqual(len(snapshot.traces), 1) + trace = snapshot.traces[0] + self.assertEqual(trace.size, 5) + self.assertEqual(len(trace.traceback), 1) + self.assertEqual(trace.traceback[0].filename, 'a.py') + self.assertEqual(trace.traceback[0].lineno, 2) + + def test_filter_traces(self): + snapshot, snapshot2 = create_snapshots() + filter1 = tracemalloc.Filter(False, "b.py") + filter2 = tracemalloc.Filter(True, "a.py", 2) + filter3 = tracemalloc.Filter(True, "a.py", 5) + + original_traces = list(snapshot.traces._traces) + + # exclude b.py + snapshot3 = snapshot.filter_traces((filter1,)) + self.assertEqual(snapshot3.traces._traces, [ + (10, (('a.py', 2), ('b.py', 4))), + (10, (('a.py', 2), ('b.py', 4))), + (10, (('a.py', 2), ('b.py', 4))), + (2, (('a.py', 5), ('b.py', 4))), + (7, (('<unknown>', 0),)), + ]) + + # filter_traces() must not touch the original snapshot + self.assertEqual(snapshot.traces._traces, original_traces) + + # only include two lines of a.py + snapshot4 = snapshot3.filter_traces((filter2, filter3)) + self.assertEqual(snapshot4.traces._traces, [ + (10, (('a.py', 2), ('b.py', 4))), + (10, (('a.py', 2), ('b.py', 4))), + (10, (('a.py', 2), ('b.py', 4))), + (2, (('a.py', 5), ('b.py', 4))), + ]) + + # No filter: just duplicate the snapshot + snapshot5 = snapshot.filter_traces(()) + self.assertIsNot(snapshot5, snapshot) + self.assertIsNot(snapshot5.traces, snapshot.traces) + self.assertEqual(snapshot5.traces, snapshot.traces) + + def test_snapshot_group_by_line(self): + snapshot, snapshot2 = create_snapshots() + tb_0 = traceback_lineno('<unknown>', 0) + tb_a_2 = traceback_lineno('a.py', 2) + tb_a_5 = traceback_lineno('a.py', 5) + tb_b_1 = traceback_lineno('b.py', 1) + tb_c_578 = traceback_lineno('c.py', 578) + + # stats per file and line + stats1 = snapshot.statistics('lineno') + self.assertEqual(stats1, [ + tracemalloc.Statistic(tb_b_1, 66, 1), + tracemalloc.Statistic(tb_a_2, 30, 3), + tracemalloc.Statistic(tb_0, 7, 1), + tracemalloc.Statistic(tb_a_5, 2, 1), + ]) + + # stats per file and line (2) + stats2 = snapshot2.statistics('lineno') + self.assertEqual(stats2, [ + tracemalloc.Statistic(tb_a_5, 5002, 2), + tracemalloc.Statistic(tb_c_578, 400, 1), + tracemalloc.Statistic(tb_a_2, 30, 3), + ]) + + # stats diff per file and line + statistics = snapshot2.compare_to(snapshot, 'lineno') + self.assertEqual(statistics, [ + tracemalloc.StatisticDiff(tb_a_5, 5002, 5000, 2, 1), + tracemalloc.StatisticDiff(tb_c_578, 400, 400, 1, 1), + tracemalloc.StatisticDiff(tb_b_1, 0, -66, 0, -1), + tracemalloc.StatisticDiff(tb_0, 0, -7, 0, -1), + tracemalloc.StatisticDiff(tb_a_2, 30, 0, 3, 0), + ]) + + def test_snapshot_group_by_file(self): + snapshot, snapshot2 = create_snapshots() + tb_0 = traceback_filename('<unknown>') + tb_a = traceback_filename('a.py') + tb_b = traceback_filename('b.py') + tb_c = traceback_filename('c.py') + + # stats per file + stats1 = snapshot.statistics('filename') + self.assertEqual(stats1, [ + tracemalloc.Statistic(tb_b, 66, 1), + tracemalloc.Statistic(tb_a, 32, 4), + tracemalloc.Statistic(tb_0, 7, 1), + ]) + + # stats per file (2) + stats2 = snapshot2.statistics('filename') + self.assertEqual(stats2, [ + tracemalloc.Statistic(tb_a, 5032, 5), + tracemalloc.Statistic(tb_c, 400, 1), + ]) + + # stats diff per file + diff = snapshot2.compare_to(snapshot, 'filename') + self.assertEqual(diff, [ + tracemalloc.StatisticDiff(tb_a, 5032, 5000, 5, 1), + tracemalloc.StatisticDiff(tb_c, 400, 400, 1, 1), + tracemalloc.StatisticDiff(tb_b, 0, -66, 0, -1), + tracemalloc.StatisticDiff(tb_0, 0, -7, 0, -1), + ]) + + def test_snapshot_group_by_traceback(self): + snapshot, snapshot2 = create_snapshots() + + # stats per file + tb1 = traceback(('a.py', 2), ('b.py', 4)) + tb2 = traceback(('a.py', 5), ('b.py', 4)) + tb3 = traceback(('b.py', 1)) + tb4 = traceback(('<unknown>', 0)) + stats1 = snapshot.statistics('traceback') + self.assertEqual(stats1, [ + tracemalloc.Statistic(tb3, 66, 1), + tracemalloc.Statistic(tb1, 30, 3), + tracemalloc.Statistic(tb4, 7, 1), + tracemalloc.Statistic(tb2, 2, 1), + ]) + + # stats per file (2) + tb5 = traceback(('c.py', 578)) + stats2 = snapshot2.statistics('traceback') + self.assertEqual(stats2, [ + tracemalloc.Statistic(tb2, 5002, 2), + tracemalloc.Statistic(tb5, 400, 1), + tracemalloc.Statistic(tb1, 30, 3), + ]) + + # stats diff per file + diff = snapshot2.compare_to(snapshot, 'traceback') + self.assertEqual(diff, [ + tracemalloc.StatisticDiff(tb2, 5002, 5000, 2, 1), + tracemalloc.StatisticDiff(tb5, 400, 400, 1, 1), + tracemalloc.StatisticDiff(tb3, 0, -66, 0, -1), + tracemalloc.StatisticDiff(tb4, 0, -7, 0, -1), + tracemalloc.StatisticDiff(tb1, 30, 0, 3, 0), + ]) + + self.assertRaises(ValueError, + snapshot.statistics, 'traceback', cumulative=True) + + def test_snapshot_group_by_cumulative(self): + snapshot, snapshot2 = create_snapshots() + tb_0 = traceback_filename('<unknown>') + tb_a = traceback_filename('a.py') + tb_b = traceback_filename('b.py') + tb_a_2 = traceback_lineno('a.py', 2) + tb_a_5 = traceback_lineno('a.py', 5) + tb_b_1 = traceback_lineno('b.py', 1) + tb_b_4 = traceback_lineno('b.py', 4) + + # per file + stats = snapshot.statistics('filename', True) + self.assertEqual(stats, [ + tracemalloc.Statistic(tb_b, 98, 5), + tracemalloc.Statistic(tb_a, 32, 4), + tracemalloc.Statistic(tb_0, 7, 1), + ]) + + # per line + stats = snapshot.statistics('lineno', True) + self.assertEqual(stats, [ + tracemalloc.Statistic(tb_b_1, 66, 1), + tracemalloc.Statistic(tb_b_4, 32, 4), + tracemalloc.Statistic(tb_a_2, 30, 3), + tracemalloc.Statistic(tb_0, 7, 1), + tracemalloc.Statistic(tb_a_5, 2, 1), + ]) + + def test_trace_format(self): + snapshot, snapshot2 = create_snapshots() + trace = snapshot.traces[0] + self.assertEqual(str(trace), 'a.py:2: 10 B') + traceback = trace.traceback + self.assertEqual(str(traceback), 'a.py:2') + frame = traceback[0] + self.assertEqual(str(frame), 'a.py:2') + + def test_statistic_format(self): + snapshot, snapshot2 = create_snapshots() + stats = snapshot.statistics('lineno') + stat = stats[0] + self.assertEqual(str(stat), + 'b.py:1: size=66 B, count=1, average=66 B') + + def test_statistic_diff_format(self): + snapshot, snapshot2 = create_snapshots() + stats = snapshot2.compare_to(snapshot, 'lineno') + stat = stats[0] + self.assertEqual(str(stat), + 'a.py:5: size=5002 B (+5000 B), count=2 (+1), average=2501 B') + + def test_slices(self): + snapshot, snapshot2 = create_snapshots() + self.assertEqual(snapshot.traces[:2], + (snapshot.traces[0], snapshot.traces[1])) + + traceback = snapshot.traces[0].traceback + self.assertEqual(traceback[:2], + (traceback[0], traceback[1])) + + def test_format_traceback(self): + snapshot, snapshot2 = create_snapshots() + def getline(filename, lineno): + return ' <%s, %s>' % (filename, lineno) + with unittest.mock.patch('tracemalloc.linecache.getline', + side_effect=getline): + tb = snapshot.traces[0].traceback + self.assertEqual(tb.format(), + [' File "a.py", line 2', + ' <a.py, 2>', + ' File "b.py", line 4', + ' <b.py, 4>']) + + self.assertEqual(tb.format(limit=1), + [' File "a.py", line 2', + ' <a.py, 2>']) + + self.assertEqual(tb.format(limit=-1), + []) + + +class TestFilters(unittest.TestCase): + maxDiff = 2048 + + def test_filter_attributes(self): + # test default values + f = tracemalloc.Filter(True, "abc") + self.assertEqual(f.inclusive, True) + self.assertEqual(f.filename_pattern, "abc") + self.assertIsNone(f.lineno) + self.assertEqual(f.all_frames, False) + + # test custom values + f = tracemalloc.Filter(False, "test.py", 123, True) + self.assertEqual(f.inclusive, False) + self.assertEqual(f.filename_pattern, "test.py") + self.assertEqual(f.lineno, 123) + self.assertEqual(f.all_frames, True) + + # parameters passed by keyword + f = tracemalloc.Filter(inclusive=False, filename_pattern="test.py", lineno=123, all_frames=True) + self.assertEqual(f.inclusive, False) + self.assertEqual(f.filename_pattern, "test.py") + self.assertEqual(f.lineno, 123) + self.assertEqual(f.all_frames, True) + + # read-only attribute + self.assertRaises(AttributeError, setattr, f, "filename_pattern", "abc") + + def test_filter_match(self): + # filter without line number + f = tracemalloc.Filter(True, "abc") + self.assertTrue(f._match_frame("abc", 0)) + self.assertTrue(f._match_frame("abc", 5)) + self.assertTrue(f._match_frame("abc", 10)) + self.assertFalse(f._match_frame("12356", 0)) + self.assertFalse(f._match_frame("12356", 5)) + self.assertFalse(f._match_frame("12356", 10)) + + f = tracemalloc.Filter(False, "abc") + self.assertFalse(f._match_frame("abc", 0)) + self.assertFalse(f._match_frame("abc", 5)) + self.assertFalse(f._match_frame("abc", 10)) + self.assertTrue(f._match_frame("12356", 0)) + self.assertTrue(f._match_frame("12356", 5)) + self.assertTrue(f._match_frame("12356", 10)) + + # filter with line number > 0 + f = tracemalloc.Filter(True, "abc", 5) + self.assertFalse(f._match_frame("abc", 0)) + self.assertTrue(f._match_frame("abc", 5)) + self.assertFalse(f._match_frame("abc", 10)) + self.assertFalse(f._match_frame("12356", 0)) + self.assertFalse(f._match_frame("12356", 5)) + self.assertFalse(f._match_frame("12356", 10)) + + f = tracemalloc.Filter(False, "abc", 5) + self.assertTrue(f._match_frame("abc", 0)) + self.assertFalse(f._match_frame("abc", 5)) + self.assertTrue(f._match_frame("abc", 10)) + self.assertTrue(f._match_frame("12356", 0)) + self.assertTrue(f._match_frame("12356", 5)) + self.assertTrue(f._match_frame("12356", 10)) + + # filter with line number 0 + f = tracemalloc.Filter(True, "abc", 0) + self.assertTrue(f._match_frame("abc", 0)) + self.assertFalse(f._match_frame("abc", 5)) + self.assertFalse(f._match_frame("abc", 10)) + self.assertFalse(f._match_frame("12356", 0)) + self.assertFalse(f._match_frame("12356", 5)) + self.assertFalse(f._match_frame("12356", 10)) + + f = tracemalloc.Filter(False, "abc", 0) + self.assertFalse(f._match_frame("abc", 0)) + self.assertTrue(f._match_frame("abc", 5)) + self.assertTrue(f._match_frame("abc", 10)) + self.assertTrue(f._match_frame("12356", 0)) + self.assertTrue(f._match_frame("12356", 5)) + self.assertTrue(f._match_frame("12356", 10)) + + def test_filter_match_filename(self): + def fnmatch(inclusive, filename, pattern): + f = tracemalloc.Filter(inclusive, pattern) + return f._match_frame(filename, 0) + + self.assertTrue(fnmatch(True, "abc", "abc")) + self.assertFalse(fnmatch(True, "12356", "abc")) + self.assertFalse(fnmatch(True, "<unknown>", "abc")) + + self.assertFalse(fnmatch(False, "abc", "abc")) + self.assertTrue(fnmatch(False, "12356", "abc")) + self.assertTrue(fnmatch(False, "<unknown>", "abc")) + + def test_filter_match_filename_joker(self): + def fnmatch(filename, pattern): + filter = tracemalloc.Filter(True, pattern) + return filter._match_frame(filename, 0) + + # empty string + self.assertFalse(fnmatch('abc', '')) + self.assertFalse(fnmatch('', 'abc')) + self.assertTrue(fnmatch('', '')) + self.assertTrue(fnmatch('', '*')) + + # no * + self.assertTrue(fnmatch('abc', 'abc')) + self.assertFalse(fnmatch('abc', 'abcd')) + self.assertFalse(fnmatch('abc', 'def')) + + # a* + self.assertTrue(fnmatch('abc', 'a*')) + self.assertTrue(fnmatch('abc', 'abc*')) + self.assertFalse(fnmatch('abc', 'b*')) + self.assertFalse(fnmatch('abc', 'abcd*')) + + # a*b + self.assertTrue(fnmatch('abc', 'a*c')) + self.assertTrue(fnmatch('abcdcx', 'a*cx')) + self.assertFalse(fnmatch('abb', 'a*c')) + self.assertFalse(fnmatch('abcdce', 'a*cx')) + + # a*b*c + self.assertTrue(fnmatch('abcde', 'a*c*e')) + self.assertTrue(fnmatch('abcbdefeg', 'a*bd*eg')) + self.assertFalse(fnmatch('abcdd', 'a*c*e')) + self.assertFalse(fnmatch('abcbdefef', 'a*bd*eg')) + + # replace .pyc and .pyo suffix with .py + self.assertTrue(fnmatch('a.pyc', 'a.py')) + self.assertTrue(fnmatch('a.pyo', 'a.py')) + self.assertTrue(fnmatch('a.py', 'a.pyc')) + self.assertTrue(fnmatch('a.py', 'a.pyo')) + + if os.name == 'nt': + # case insensitive + self.assertTrue(fnmatch('aBC', 'ABc')) + self.assertTrue(fnmatch('aBcDe', 'Ab*dE')) + + self.assertTrue(fnmatch('a.pyc', 'a.PY')) + self.assertTrue(fnmatch('a.PYO', 'a.py')) + self.assertTrue(fnmatch('a.py', 'a.PYC')) + self.assertTrue(fnmatch('a.PY', 'a.pyo')) + else: + # case sensitive + self.assertFalse(fnmatch('aBC', 'ABc')) + self.assertFalse(fnmatch('aBcDe', 'Ab*dE')) + + self.assertFalse(fnmatch('a.pyc', 'a.PY')) + self.assertFalse(fnmatch('a.PYO', 'a.py')) + self.assertFalse(fnmatch('a.py', 'a.PYC')) + self.assertFalse(fnmatch('a.PY', 'a.pyo')) + + if os.name == 'nt': + # normalize alternate separator "/" to the standard separator "\" + self.assertTrue(fnmatch(r'a/b', r'a\b')) + self.assertTrue(fnmatch(r'a\b', r'a/b')) + self.assertTrue(fnmatch(r'a/b\c', r'a\b/c')) + self.assertTrue(fnmatch(r'a/b/c', r'a\b\c')) + else: + # there is no alternate separator + self.assertFalse(fnmatch(r'a/b', r'a\b')) + self.assertFalse(fnmatch(r'a\b', r'a/b')) + self.assertFalse(fnmatch(r'a/b\c', r'a\b/c')) + self.assertFalse(fnmatch(r'a/b/c', r'a\b\c')) + + def test_filter_match_trace(self): + t1 = (("a.py", 2), ("b.py", 3)) + t2 = (("b.py", 4), ("b.py", 5)) + t3 = (("c.py", 5), ('<unknown>', 0)) + unknown = (('<unknown>', 0),) + + f = tracemalloc.Filter(True, "b.py", all_frames=True) + self.assertTrue(f._match_traceback(t1)) + self.assertTrue(f._match_traceback(t2)) + self.assertFalse(f._match_traceback(t3)) + self.assertFalse(f._match_traceback(unknown)) + + f = tracemalloc.Filter(True, "b.py", all_frames=False) + self.assertFalse(f._match_traceback(t1)) + self.assertTrue(f._match_traceback(t2)) + self.assertFalse(f._match_traceback(t3)) + self.assertFalse(f._match_traceback(unknown)) + + f = tracemalloc.Filter(False, "b.py", all_frames=True) + self.assertFalse(f._match_traceback(t1)) + self.assertFalse(f._match_traceback(t2)) + self.assertTrue(f._match_traceback(t3)) + self.assertTrue(f._match_traceback(unknown)) + + f = tracemalloc.Filter(False, "b.py", all_frames=False) + self.assertTrue(f._match_traceback(t1)) + self.assertFalse(f._match_traceback(t2)) + self.assertTrue(f._match_traceback(t3)) + self.assertTrue(f._match_traceback(unknown)) + + f = tracemalloc.Filter(False, "<unknown>", all_frames=False) + self.assertTrue(f._match_traceback(t1)) + self.assertTrue(f._match_traceback(t2)) + self.assertTrue(f._match_traceback(t3)) + self.assertFalse(f._match_traceback(unknown)) + + f = tracemalloc.Filter(True, "<unknown>", all_frames=True) + self.assertFalse(f._match_traceback(t1)) + self.assertFalse(f._match_traceback(t2)) + self.assertTrue(f._match_traceback(t3)) + self.assertTrue(f._match_traceback(unknown)) + + f = tracemalloc.Filter(False, "<unknown>", all_frames=True) + self.assertTrue(f._match_traceback(t1)) + self.assertTrue(f._match_traceback(t2)) + self.assertFalse(f._match_traceback(t3)) + self.assertFalse(f._match_traceback(unknown)) + + +class TestCommandLine(unittest.TestCase): + def test_env_var(self): + # not tracing by default + code = 'import tracemalloc; print(tracemalloc.is_tracing())' + ok, stdout, stderr = assert_python_ok('-c', code) + stdout = stdout.rstrip() + self.assertEqual(stdout, b'False') + + # PYTHON* environment variables must be ignored when -E option is + # present + code = 'import tracemalloc; print(tracemalloc.is_tracing())' + ok, stdout, stderr = assert_python_ok('-E', '-c', code, PYTHONTRACEMALLOC='1') + stdout = stdout.rstrip() + self.assertEqual(stdout, b'False') + + # tracing at startup + code = 'import tracemalloc; print(tracemalloc.is_tracing())' + ok, stdout, stderr = assert_python_ok('-c', code, PYTHONTRACEMALLOC='1') + stdout = stdout.rstrip() + self.assertEqual(stdout, b'True') + + # start and set the number of frames + code = 'import tracemalloc; print(tracemalloc.get_traceback_limit())' + ok, stdout, stderr = assert_python_ok('-c', code, PYTHONTRACEMALLOC='10') + stdout = stdout.rstrip() + self.assertEqual(stdout, b'10') + + def test_env_var_invalid(self): + for nframe in (-1, 0, 2**30): + with self.subTest(nframe=nframe): + with support.SuppressCrashReport(): + ok, stdout, stderr = assert_python_failure( + '-c', 'pass', + PYTHONTRACEMALLOC=str(nframe)) + self.assertIn(b'PYTHONTRACEMALLOC: invalid ' + b'number of frames', + stderr) + + def test_sys_xoptions(self): + for xoptions, nframe in ( + ('tracemalloc', 1), + ('tracemalloc=1', 1), + ('tracemalloc=15', 15), + ): + with self.subTest(xoptions=xoptions, nframe=nframe): + code = 'import tracemalloc; print(tracemalloc.get_traceback_limit())' + ok, stdout, stderr = assert_python_ok('-X', xoptions, '-c', code) + stdout = stdout.rstrip() + self.assertEqual(stdout, str(nframe).encode('ascii')) + + def test_sys_xoptions_invalid(self): + for nframe in (-1, 0, 2**30): + with self.subTest(nframe=nframe): + with support.SuppressCrashReport(): + args = ('-X', 'tracemalloc=%s' % nframe, '-c', 'pass') + ok, stdout, stderr = assert_python_failure(*args) + self.assertIn(b'-X tracemalloc=NFRAME: invalid ' + b'number of frames', + stderr) + + +def test_main(): + support.run_unittest( + TestTracemallocEnabled, + TestSnapshot, + TestFilters, + TestCommandLine, + ) + +if __name__ == "__main__": + test_main() diff --git a/Lib/test/test_types.py b/Lib/test/test_types.py index 3ee4c6bc5f..18e6b0a89c 100644 --- a/Lib/test/test_types.py +++ b/Lib/test/test_types.py @@ -1,7 +1,8 @@ # Python test set -- part 6, built-in types -from test.support import run_unittest, run_with_locale +from test.support import run_unittest, run_with_locale, cpython_only import collections +import pickle import locale import sys import types @@ -1077,9 +1078,19 @@ class SimpleNamespaceTests(unittest.TestCase): ns2 = types.SimpleNamespace() ns2.x = "spam" ns2._y = 5 + name = "namespace" - self.assertEqual(repr(ns1), "namespace(w=3, x=1, y=2)") - self.assertEqual(repr(ns2), "namespace(_y=5, x='spam')") + self.assertEqual(repr(ns1), "{name}(w=3, x=1, y=2)".format(name=name)) + self.assertEqual(repr(ns2), "{name}(_y=5, x='spam')".format(name=name)) + + def test_equal(self): + ns1 = types.SimpleNamespace(x=1) + ns2 = types.SimpleNamespace() + ns2.x = 1 + + self.assertEqual(types.SimpleNamespace(), types.SimpleNamespace()) + self.assertEqual(ns1, ns2) + self.assertNotEqual(ns2, types.SimpleNamespace()) def test_nested(self): ns1 = types.SimpleNamespace(a=1, b=2) @@ -1117,11 +1128,12 @@ class SimpleNamespaceTests(unittest.TestCase): ns1.spam = ns1 ns2.spam = ns3 ns3.spam = ns2 + name = "namespace" + repr1 = "{name}(c='cookie', spam={name}(...))".format(name=name) + repr2 = "{name}(spam={name}(spam={name}(...), x=1))".format(name=name) - self.assertEqual(repr(ns1), - "namespace(c='cookie', spam=namespace(...))") - self.assertEqual(repr(ns2), - "namespace(spam=namespace(spam=namespace(...), x=1))") + self.assertEqual(repr(ns1), repr1) + self.assertEqual(repr(ns2), repr2) def test_as_dict(self): ns = types.SimpleNamespace(spam='spamspamspam') @@ -1144,10 +1156,45 @@ class SimpleNamespaceTests(unittest.TestCase): self.assertIs(type(spam), Spam) self.assertEqual(vars(spam), {'ham': 8, 'eggs': 9}) + def test_pickle(self): + ns = types.SimpleNamespace(breakfast="spam", lunch="spam") + + for protocol in range(pickle.HIGHEST_PROTOCOL + 1): + pname = "protocol {}".format(protocol) + try: + ns_pickled = pickle.dumps(ns, protocol) + except TypeError as e: + raise TypeError(pname) from e + ns_roundtrip = pickle.loads(ns_pickled) + + self.assertEqual(ns, ns_roundtrip, pname) + + +class SharedKeyTests(unittest.TestCase): + + @cpython_only + def test_subclasses(self): + # Verify that subclasses can share keys (per PEP 412) + class A: + pass + class B(A): + pass + + a, b = A(), B() + self.assertEqual(sys.getsizeof(vars(a)), sys.getsizeof(vars(b))) + self.assertLess(sys.getsizeof(vars(a)), sys.getsizeof({})) + a.x, a.y, a.z, a.w = range(4) + self.assertNotEqual(sys.getsizeof(vars(a)), sys.getsizeof(vars(b))) + a2 = A() + self.assertEqual(sys.getsizeof(vars(a)), sys.getsizeof(vars(a2))) + self.assertLess(sys.getsizeof(vars(a)), sys.getsizeof({})) + b.u, b.v, b.w, b.t = range(4) + self.assertLess(sys.getsizeof(vars(b)), sys.getsizeof({})) + def test_main(): run_unittest(TypesTests, MappingProxyTests, ClassCreationTests, - SimpleNamespaceTests) + SimpleNamespaceTests, SharedKeyTests) if __name__ == '__main__': test_main() diff --git a/Lib/test/test_ucn.py b/Lib/test/test_ucn.py index de71680c63..be7e9cdcd7 100644 --- a/Lib/test/test_ucn.py +++ b/Lib/test/test_ucn.py @@ -177,7 +177,7 @@ class UnicodeNamesTest(unittest.TestCase): try: testdata = support.open_urlresource(url, encoding="utf-8", check=check_version) - except (IOError, HTTPException): + except (OSError, HTTPException): self.skipTest("Could not retrieve " + url) self.addCleanup(testdata.close) for line in testdata: diff --git a/Lib/test/test_unicode.py b/Lib/test/test_unicode.py index c2ede07a83..7e7091872f 100644 --- a/Lib/test/test_unicode.py +++ b/Lib/test/test_unicode.py @@ -7,6 +7,7 @@ Written by Marc-Andre Lemburg (mal@lemburg.com). """#" import _string import codecs +import itertools import struct import sys import unittest @@ -31,6 +32,16 @@ def search_function(encoding): return None codecs.register(search_function) +def duplicate_string(text): + """ + Try to get a fresh clone of the specified text: + new object with a reference count of 1. + + This is a best-effort: latin1 single letters and the empty + string ('') are singletons and cannot be cloned. + """ + return text.encode().decode() + class UnicodeTest(string_tests.CommonTest, string_tests.MixinStrUnicodeUserStringTest, string_tests.MixinStrUnicodeTest, @@ -863,11 +874,9 @@ class UnicodeTest(string_tests.CommonTest, self.assertEqual('{0:d}'.format(G('data')), 'G(data)') self.assertEqual('{0!s}'.format(G('data')), 'string is data') - msg = 'object.__format__ with a non-empty format string is deprecated' - with support.check_warnings((msg, DeprecationWarning)): - self.assertEqual('{0:^10}'.format(E('data')), ' E(data) ') - self.assertEqual('{0:^10s}'.format(E('data')), ' E(data) ') - self.assertEqual('{0:>15s}'.format(G('data')), ' string is data') + self.assertRaises(TypeError, '{0:^10}'.format, E('data')) + self.assertRaises(TypeError, '{0:^10s}'.format, E('data')) + self.assertRaises(TypeError, '{0:>15s}'.format, G('data')) self.assertEqual("{0:date: %Y-%m-%d}".format(I(year=2007, month=8, @@ -903,7 +912,7 @@ class UnicodeTest(string_tests.CommonTest, self.assertRaises(ValueError, "{0".format) self.assertRaises(IndexError, "{0.}".format) self.assertRaises(ValueError, "{0.}".format, 0) - self.assertRaises(IndexError, "{0[}".format) + self.assertRaises(ValueError, "{0[}".format) self.assertRaises(ValueError, "{0[}".format, []) self.assertRaises(KeyError, "{0]}".format) self.assertRaises(ValueError, "{0.[]}".format, 0) @@ -955,6 +964,15 @@ class UnicodeTest(string_tests.CommonTest, '') self.assertEqual("{[{}]}".format({"{}": 5}), "5") + self.assertEqual("{[{}]}".format({"{}" : "a"}), "a") + self.assertEqual("{[{]}".format({"{" : "a"}), "a") + self.assertEqual("{[}]}".format({"}" : "a"}), "a") + self.assertEqual("{[[]}".format({"[" : "a"}), "a") + self.assertEqual("{[!]}".format({"!" : "a"}), "a") + self.assertRaises(ValueError, "{a{}b}".format, 42) + self.assertRaises(ValueError, "{a{b}".format, 42) + self.assertRaises(ValueError, "{[}".format, 42) + self.assertEqual("0x{:0{:d}X}".format(0x0,16), "0x0000000000000000") def test_format_map(self): @@ -1108,6 +1126,67 @@ class UnicodeTest(string_tests.CommonTest, self.assertEqual('%.1s' % "a\xe9\u20ac", 'a') self.assertEqual('%.2s' % "a\xe9\u20ac", 'a\xe9') + #issue 19995 + class PsuedoInt: + def __init__(self, value): + self.value = int(value) + def __int__(self): + return self.value + def __index__(self): + return self.value + class PsuedoFloat: + def __init__(self, value): + self.value = float(value) + def __int__(self): + return int(self.value) + pi = PsuedoFloat(3.1415) + letter_m = PsuedoInt(109) + self.assertEqual('%x' % 42, '2a') + self.assertEqual('%X' % 15, 'F') + self.assertEqual('%o' % 9, '11') + self.assertEqual('%c' % 109, 'm') + self.assertEqual('%x' % letter_m, '6d') + self.assertEqual('%X' % letter_m, '6D') + self.assertEqual('%o' % letter_m, '155') + self.assertEqual('%c' % letter_m, 'm') + self.assertWarns(DeprecationWarning, '%x'.__mod__, pi), + self.assertWarns(DeprecationWarning, '%x'.__mod__, 3.14), + self.assertWarns(DeprecationWarning, '%X'.__mod__, 2.11), + self.assertWarns(DeprecationWarning, '%o'.__mod__, 1.79), + self.assertWarns(DeprecationWarning, '%c'.__mod__, pi), + + def test_formatting_with_enum(self): + # issue18780 + import enum + class Float(float, enum.Enum): + PI = 3.1415926 + class Int(enum.IntEnum): + IDES = 15 + class Str(str, enum.Enum): + ABC = 'abc' + # Testing Unicode formatting strings... + self.assertEqual("%s, %s" % (Str.ABC, Str.ABC), + 'Str.ABC, Str.ABC') + self.assertEqual("%s, %s, %d, %i, %u, %f, %5.2f" % + (Str.ABC, Str.ABC, + Int.IDES, Int.IDES, Int.IDES, + Float.PI, Float.PI), + 'Str.ABC, Str.ABC, 15, 15, 15, 3.141593, 3.14') + + # formatting jobs delegated from the string implementation: + self.assertEqual('...%(foo)s...' % {'foo':Str.ABC}, + '...Str.ABC...') + self.assertEqual('...%(foo)s...' % {'foo':Int.IDES}, + '...Int.IDES...') + self.assertEqual('...%(foo)i...' % {'foo':Int.IDES}, + '...15...') + self.assertEqual('...%(foo)d...' % {'foo':Int.IDES}, + '...15...') + self.assertEqual('...%(foo)u...' % {'foo':Int.IDES, 'def':Float.PI}, + '...15...') + self.assertEqual('...%(foo)f...' % {'foo':Float.PI,'def':123}, + '...3.141593...') + def test_formatting_huge_precision(self): format_string = "%.{}f".format(sys.maxsize + 1) with self.assertRaises(ValueError): @@ -1788,10 +1867,10 @@ class UnicodeTest(string_tests.CommonTest, # 0-127 s = bytes(range(128)) for encoding in ( - 'cp037', 'cp1026', + 'cp037', 'cp1026', 'cp273', 'cp437', 'cp500', 'cp720', 'cp737', 'cp775', 'cp850', 'cp852', 'cp855', 'cp858', 'cp860', 'cp861', 'cp862', - 'cp863', 'cp865', 'cp866', + 'cp863', 'cp865', 'cp866', 'cp1125', 'iso8859_10', 'iso8859_13', 'iso8859_14', 'iso8859_15', 'iso8859_2', 'iso8859_3', 'iso8859_4', 'iso8859_5', 'iso8859_6', 'iso8859_7', 'iso8859_9', 'koi8_r', 'latin_1', @@ -1816,10 +1895,10 @@ class UnicodeTest(string_tests.CommonTest, # 128-255 s = bytes(range(128, 256)) for encoding in ( - 'cp037', 'cp1026', + 'cp037', 'cp1026', 'cp273', 'cp437', 'cp500', 'cp720', 'cp737', 'cp775', 'cp850', 'cp852', 'cp855', 'cp858', 'cp860', 'cp861', 'cp862', - 'cp863', 'cp865', 'cp866', + 'cp863', 'cp865', 'cp866', 'cp1125', 'iso8859_10', 'iso8859_13', 'iso8859_14', 'iso8859_15', 'iso8859_2', 'iso8859_4', 'iso8859_5', 'iso8859_9', 'koi8_r', 'latin_1', @@ -2010,9 +2089,10 @@ class UnicodeTest(string_tests.CommonTest, # Test PyUnicode_FromFormat() def test_from_format(self): support.import_module('ctypes') - from ctypes import (pythonapi, py_object, + from ctypes import ( + pythonapi, py_object, sizeof, c_int, c_long, c_longlong, c_ssize_t, - c_uint, c_ulong, c_ulonglong, c_size_t) + c_uint, c_ulong, c_ulonglong, c_size_t, c_void_p) name = "PyUnicode_FromFormat" _PyUnicode_FromFormat = getattr(pythonapi, name) _PyUnicode_FromFormat.restype = py_object @@ -2023,9 +2103,13 @@ class UnicodeTest(string_tests.CommonTest, for arg in args) return _PyUnicode_FromFormat(format, *cargs) + def check_format(expected, format, *args): + text = PyUnicode_FromFormat(format, *args) + self.assertEqual(expected, text) + # ascii format, non-ascii argument - text = PyUnicode_FromFormat(b'ascii\x7f=%U', 'unicode\xe9') - self.assertEqual(text, 'ascii\x7f=unicode\xe9') + check_format('ascii\x7f=unicode\xe9', + b'ascii\x7f=%U', 'unicode\xe9') # non-ascii format, ascii argument: ensure that PyUnicode_FromFormatV() # raises an error @@ -2035,64 +2119,205 @@ class UnicodeTest(string_tests.CommonTest, PyUnicode_FromFormat, b'unicode\xe9=%s', 'ascii') # test "%c" - self.assertEqual(PyUnicode_FromFormat(b'%c', c_int(0xabcd)), '\uabcd') - self.assertEqual(PyUnicode_FromFormat(b'%c', c_int(0x10ffff)), '\U0010ffff') + check_format('\uabcd', + b'%c', c_int(0xabcd)) + check_format('\U0010ffff', + b'%c', c_int(0x10ffff)) with self.assertRaises(OverflowError): PyUnicode_FromFormat(b'%c', c_int(0x110000)) # Issue #18183 - self.assertEqual( - PyUnicode_FromFormat(b'%c%c', c_int(0x10000), c_int(0x100000)), - '\U00010000\U00100000') + check_format('\U00010000\U00100000', + b'%c%c', c_int(0x10000), c_int(0x100000)) # test "%" - self.assertEqual(PyUnicode_FromFormat(b'%'), '%') - self.assertEqual(PyUnicode_FromFormat(b'%%'), '%') - self.assertEqual(PyUnicode_FromFormat(b'%%s'), '%s') - self.assertEqual(PyUnicode_FromFormat(b'[%%]'), '[%]') - self.assertEqual(PyUnicode_FromFormat(b'%%%s', b'abc'), '%abc') + check_format('%', + b'%') + check_format('%', + b'%%') + check_format('%s', + b'%%s') + check_format('[%]', + b'[%%]') + check_format('%abc', + b'%%%s', b'abc') + + # truncated string + check_format('abc', + b'%.3s', b'abcdef') + check_format('abc[\ufffd', + b'%.5s', 'abc[\u20ac]'.encode('utf8')) + check_format("'\\u20acABC'", + b'%A', '\u20acABC') + check_format("'\\u20", + b'%.5A', '\u20acABCDEF') + check_format("'\u20acABC'", + b'%R', '\u20acABC') + check_format("'\u20acA", + b'%.3R', '\u20acABCDEF') + check_format('\u20acAB', + b'%.3S', '\u20acABCDEF') + check_format('\u20acAB', + b'%.3U', '\u20acABCDEF') + check_format('\u20acAB', + b'%.3V', '\u20acABCDEF', None) + check_format('abc[\ufffd', + b'%.5V', None, 'abc[\u20ac]'.encode('utf8')) + + # following tests comes from #7330 + # test width modifier and precision modifier with %S + check_format("repr= abc", + b'repr=%5S', 'abc') + check_format("repr=ab", + b'repr=%.2S', 'abc') + check_format("repr= ab", + b'repr=%5.2S', 'abc') + + # test width modifier and precision modifier with %R + check_format("repr= 'abc'", + b'repr=%8R', 'abc') + check_format("repr='ab", + b'repr=%.3R', 'abc') + check_format("repr= 'ab", + b'repr=%5.3R', 'abc') + + # test width modifier and precision modifier with %A + check_format("repr= 'abc'", + b'repr=%8A', 'abc') + check_format("repr='ab", + b'repr=%.3A', 'abc') + check_format("repr= 'ab", + b'repr=%5.3A', 'abc') + + # test width modifier and precision modifier with %s + check_format("repr= abc", + b'repr=%5s', b'abc') + check_format("repr=ab", + b'repr=%.2s', b'abc') + check_format("repr= ab", + b'repr=%5.2s', b'abc') + + # test width modifier and precision modifier with %U + check_format("repr= abc", + b'repr=%5U', 'abc') + check_format("repr=ab", + b'repr=%.2U', 'abc') + check_format("repr= ab", + b'repr=%5.2U', 'abc') + + # test width modifier and precision modifier with %V + check_format("repr= abc", + b'repr=%5V', 'abc', b'123') + check_format("repr=ab", + b'repr=%.2V', 'abc', b'123') + check_format("repr= ab", + b'repr=%5.2V', 'abc', b'123') + check_format("repr= 123", + b'repr=%5V', None, b'123') + check_format("repr=12", + b'repr=%.2V', None, b'123') + check_format("repr= 12", + b'repr=%5.2V', None, b'123') # test integer formats (%i, %d, %u) - self.assertEqual(PyUnicode_FromFormat(b'%03i', c_int(10)), '010') - self.assertEqual(PyUnicode_FromFormat(b'%0.4i', c_int(10)), '0010') - self.assertEqual(PyUnicode_FromFormat(b'%i', c_int(-123)), '-123') - self.assertEqual(PyUnicode_FromFormat(b'%li', c_long(-123)), '-123') - self.assertEqual(PyUnicode_FromFormat(b'%lli', c_longlong(-123)), '-123') - self.assertEqual(PyUnicode_FromFormat(b'%zi', c_ssize_t(-123)), '-123') - - self.assertEqual(PyUnicode_FromFormat(b'%d', c_int(-123)), '-123') - self.assertEqual(PyUnicode_FromFormat(b'%ld', c_long(-123)), '-123') - self.assertEqual(PyUnicode_FromFormat(b'%lld', c_longlong(-123)), '-123') - self.assertEqual(PyUnicode_FromFormat(b'%zd', c_ssize_t(-123)), '-123') - - self.assertEqual(PyUnicode_FromFormat(b'%u', c_uint(123)), '123') - self.assertEqual(PyUnicode_FromFormat(b'%lu', c_ulong(123)), '123') - self.assertEqual(PyUnicode_FromFormat(b'%llu', c_ulonglong(123)), '123') - self.assertEqual(PyUnicode_FromFormat(b'%zu', c_size_t(123)), '123') + check_format('010', + b'%03i', c_int(10)) + check_format('0010', + b'%0.4i', c_int(10)) + check_format('-123', + b'%i', c_int(-123)) + check_format('-123', + b'%li', c_long(-123)) + check_format('-123', + b'%lli', c_longlong(-123)) + check_format('-123', + b'%zi', c_ssize_t(-123)) + + check_format('-123', + b'%d', c_int(-123)) + check_format('-123', + b'%ld', c_long(-123)) + check_format('-123', + b'%lld', c_longlong(-123)) + check_format('-123', + b'%zd', c_ssize_t(-123)) + + check_format('123', + b'%u', c_uint(123)) + check_format('123', + b'%lu', c_ulong(123)) + check_format('123', + b'%llu', c_ulonglong(123)) + check_format('123', + b'%zu', c_size_t(123)) + + # test long output + min_longlong = -(2 ** (8 * sizeof(c_longlong) - 1)) + max_longlong = -min_longlong - 1 + check_format(str(min_longlong), + b'%lld', c_longlong(min_longlong)) + check_format(str(max_longlong), + b'%lld', c_longlong(max_longlong)) + max_ulonglong = 2 ** (8 * sizeof(c_ulonglong)) - 1 + check_format(str(max_ulonglong), + b'%llu', c_ulonglong(max_ulonglong)) + PyUnicode_FromFormat(b'%p', c_void_p(-1)) + + # test padding (width and/or precision) + check_format('123'.rjust(10, '0'), + b'%010i', c_int(123)) + check_format('123'.rjust(100), + b'%100i', c_int(123)) + check_format('123'.rjust(100, '0'), + b'%.100i', c_int(123)) + check_format('123'.rjust(80, '0').rjust(100), + b'%100.80i', c_int(123)) + + check_format('123'.rjust(10, '0'), + b'%010u', c_uint(123)) + check_format('123'.rjust(100), + b'%100u', c_uint(123)) + check_format('123'.rjust(100, '0'), + b'%.100u', c_uint(123)) + check_format('123'.rjust(80, '0').rjust(100), + b'%100.80u', c_uint(123)) + + check_format('123'.rjust(10, '0'), + b'%010x', c_int(0x123)) + check_format('123'.rjust(100), + b'%100x', c_int(0x123)) + check_format('123'.rjust(100, '0'), + b'%.100x', c_int(0x123)) + check_format('123'.rjust(80, '0').rjust(100), + b'%100.80x', c_int(0x123)) # test %A - text = PyUnicode_FromFormat(b'%%A:%A', 'abc\xe9\uabcd\U0010ffff') - self.assertEqual(text, r"%A:'abc\xe9\uabcd\U0010ffff'") + check_format(r"%A:'abc\xe9\uabcd\U0010ffff'", + b'%%A:%A', 'abc\xe9\uabcd\U0010ffff') # test %V - text = PyUnicode_FromFormat(b'repr=%V', 'abc', b'xyz') - self.assertEqual(text, 'repr=abc') + check_format('repr=abc', + b'repr=%V', 'abc', b'xyz') # Test string decode from parameter of %s using utf-8. # b'\xe4\xba\xba\xe6\xb0\x91' is utf-8 encoded byte sequence of # '\u4eba\u6c11' - text = PyUnicode_FromFormat(b'repr=%V', None, b'\xe4\xba\xba\xe6\xb0\x91') - self.assertEqual(text, 'repr=\u4eba\u6c11') + check_format('repr=\u4eba\u6c11', + b'repr=%V', None, b'\xe4\xba\xba\xe6\xb0\x91') #Test replace error handler. - text = PyUnicode_FromFormat(b'repr=%V', None, b'abc\xff') - self.assertEqual(text, 'repr=abc\ufffd') + check_format('repr=abc\ufffd', + b'repr=%V', None, b'abc\xff') # not supported: copy the raw format string. these tests are just here # to check for crashs and should not be considered as specifications - self.assertEqual(PyUnicode_FromFormat(b'%1%s', b'abc'), '%s') - self.assertEqual(PyUnicode_FromFormat(b'%1abc'), '%1abc') - self.assertEqual(PyUnicode_FromFormat(b'%+i', c_int(10)), '%+i') - self.assertEqual(PyUnicode_FromFormat(b'%.%s', b'abc'), '%.%s') + check_format('%s', + b'%1%s', b'abc') + check_format('%1abc', + b'%1abc') + check_format('%+i', + b'%+i', c_int(10)) + check_format('%.%s', + b'%.%s', b'abc') # Test PyUnicode_AsWideChar() @support.cpython_only @@ -2220,6 +2445,80 @@ class UnicodeTest(string_tests.CommonTest, self.assertNotEqual(abc, abcdef) self.assertEqual(abcdef.decode('unicode_internal'), text) + def test_compare(self): + # Issue #17615 + N = 10 + ascii = 'a' * N + ascii2 = 'z' * N + latin = '\x80' * N + latin2 = '\xff' * N + bmp = '\u0100' * N + bmp2 = '\uffff' * N + astral = '\U00100000' * N + astral2 = '\U0010ffff' * N + strings = ( + ascii, ascii2, + latin, latin2, + bmp, bmp2, + astral, astral2) + for text1, text2 in itertools.combinations(strings, 2): + equal = (text1 is text2) + self.assertEqual(text1 == text2, equal) + self.assertEqual(text1 != text2, not equal) + + if equal: + self.assertTrue(text1 <= text2) + self.assertTrue(text1 >= text2) + + # text1 is text2: duplicate strings to skip the "str1 == str2" + # optimization in unicode_compare_eq() and really compare + # character per character + copy1 = duplicate_string(text1) + copy2 = duplicate_string(text2) + self.assertIsNot(copy1, copy2) + + self.assertTrue(copy1 == copy2) + self.assertFalse(copy1 != copy2) + + self.assertTrue(copy1 <= copy2) + self.assertTrue(copy2 >= copy2) + + self.assertTrue(ascii < ascii2) + self.assertTrue(ascii < latin) + self.assertTrue(ascii < bmp) + self.assertTrue(ascii < astral) + self.assertFalse(ascii >= ascii2) + self.assertFalse(ascii >= latin) + self.assertFalse(ascii >= bmp) + self.assertFalse(ascii >= astral) + + self.assertFalse(latin < ascii) + self.assertTrue(latin < latin2) + self.assertTrue(latin < bmp) + self.assertTrue(latin < astral) + self.assertTrue(latin >= ascii) + self.assertFalse(latin >= latin2) + self.assertFalse(latin >= bmp) + self.assertFalse(latin >= astral) + + self.assertFalse(bmp < ascii) + self.assertFalse(bmp < latin) + self.assertTrue(bmp < bmp2) + self.assertTrue(bmp < astral) + self.assertTrue(bmp >= ascii) + self.assertTrue(bmp >= latin) + self.assertFalse(bmp >= bmp2) + self.assertFalse(bmp >= astral) + + self.assertFalse(astral < ascii) + self.assertFalse(astral < latin) + self.assertFalse(astral < bmp2) + self.assertTrue(astral < astral2) + self.assertTrue(astral >= ascii) + self.assertTrue(astral >= latin) + self.assertTrue(astral >= bmp2) + self.assertFalse(astral >= astral2) + class StringModuleTest(unittest.TestCase): def test_formatter_parser(self): diff --git a/Lib/test/test_unicodedata.py b/Lib/test/test_unicodedata.py index 99aa0033cd..707b30e50c 100644 --- a/Lib/test/test_unicodedata.py +++ b/Lib/test/test_unicodedata.py @@ -21,7 +21,7 @@ errors = 'surrogatepass' class UnicodeMethodsTest(unittest.TestCase): # update this, if the database changes - expectedchecksum = 'bf7a78f1a532421b5033600102e23a92044dbba9' + expectedchecksum = 'e74e878de71b6e780ffac271785c3cb58f6251f3' def test_method_checksum(self): h = hashlib.sha1() @@ -80,7 +80,7 @@ class UnicodeDatabaseTest(unittest.TestCase): class UnicodeFunctionsTest(UnicodeDatabaseTest): # update this, if the database changes - expectedchecksum = '17fe2f12b788e4fff5479b469c4404bb6ecf841f' + expectedchecksum = 'f0b74d26776331cc7bdc3a4698f037d73f2cee2b' def test_function_checksum(self): data = [] h = hashlib.sha1() diff --git a/Lib/test/test_urllib.py b/Lib/test/test_urllib.py index 2dec4e9f88..1a40566620 100644 --- a/Lib/test/test_urllib.py +++ b/Lib/test/test_urllib.py @@ -16,6 +16,7 @@ from nturl2path import url2pathname, pathname2url from base64 import b64encode import collections + def hexescape(char): """Escape char as RFC 2396 specifies""" hex_repr = hex(ord(char))[2:].upper() @@ -238,7 +239,7 @@ class urlopen_HttpTests(unittest.TestCase, FakeHTTPMixin): self.check_read(b"1.1") def test_read_bogus(self): - # urlopen() should raise IOError for many error codes. + # urlopen() should raise OSError for many error codes. self.fakehttp(b'''HTTP/1.1 401 Authentication Required Date: Wed, 02 Jan 2008 03:03:54 GMT Server: Apache/1.3.33 (Debian GNU/Linux) mod_ssl/2.8.22 OpenSSL/0.9.7e @@ -246,12 +247,12 @@ Connection: close Content-Type: text/html; charset=iso-8859-1 ''') try: - self.assertRaises(IOError, urlopen, "http://python.org/") + self.assertRaises(OSError, urlopen, "http://python.org/") finally: self.unfakehttp() def test_invalid_redirect(self): - # urlopen() should raise IOError for many error codes. + # urlopen() should raise OSError for many error codes. self.fakehttp(b'''HTTP/1.1 302 Found Date: Wed, 02 Jan 2008 03:03:54 GMT Server: Apache/1.3.33 (Debian GNU/Linux) mod_ssl/2.8.22 OpenSSL/0.9.7e @@ -266,19 +267,20 @@ Content-Type: text/html; charset=iso-8859-1 self.unfakehttp() def test_empty_socket(self): - # urlopen() raises IOError if the underlying socket does not send any + # urlopen() raises OSError if the underlying socket does not send any # data. (#1680230) self.fakehttp(b'') try: - self.assertRaises(IOError, urlopen, "http://something") + self.assertRaises(OSError, urlopen, "http://something") finally: self.unfakehttp() def test_missing_localfile(self): # Test for #10836 - # 3.3 - URLError is not captured, explicit IOError is raised. - with self.assertRaises(IOError): + with self.assertRaises(urllib.error.URLError) as e: urlopen('file://localhost/a/file/which/doesnot/exists.py') + self.assertTrue(e.exception.filename) + self.assertTrue(e.exception.reason) def test_file_notexists(self): fd, tmp_file = tempfile.mkstemp() @@ -291,20 +293,21 @@ Content-Type: text/html; charset=iso-8859-1 os.close(fd) os.unlink(tmp_file) self.assertFalse(os.path.exists(tmp_file)) - # 3.3 - IOError instead of URLError - with self.assertRaises(IOError): + with self.assertRaises(urllib.error.URLError): urlopen(tmp_fileurl) def test_ftp_nohost(self): test_ftp_url = 'ftp:///path' - # 3.3 - IOError instead of URLError - with self.assertRaises(IOError): + with self.assertRaises(urllib.error.URLError) as e: urlopen(test_ftp_url) + self.assertFalse(e.exception.filename) + self.assertTrue(e.exception.reason) def test_ftp_nonexisting(self): - # 3.3 - IOError instead of URLError - with self.assertRaises(IOError): + with self.assertRaises(urllib.error.URLError) as e: urlopen('ftp://localhost/a/file/which/doesnot/exists.py') + self.assertFalse(e.exception.filename) + self.assertTrue(e.exception.reason) def test_userpass_inurl(self): @@ -341,6 +344,79 @@ Content-Type: text/html; charset=iso-8859-1 with support.check_warnings(('',DeprecationWarning)): urllib.request.URLopener() +class urlopen_DataTests(unittest.TestCase): + """Test urlopen() opening a data URL.""" + + def setUp(self): + # text containing URL special- and unicode-characters + self.text = "test data URLs :;,%=& \u00f6 \u00c4 " + # 2x1 pixel RGB PNG image with one black and one white pixel + self.image = ( + b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x02\x00\x00\x00' + b'\x01\x08\x02\x00\x00\x00{@\xe8\xdd\x00\x00\x00\x01sRGB\x00\xae' + b'\xce\x1c\xe9\x00\x00\x00\x0fIDAT\x08\xd7c```\xf8\xff\xff?\x00' + b'\x06\x01\x02\xfe\no/\x1e\x00\x00\x00\x00IEND\xaeB`\x82') + + self.text_url = ( + "data:text/plain;charset=UTF-8,test%20data%20URLs%20%3A%3B%2C%25%3" + "D%26%20%C3%B6%20%C3%84%20") + self.text_url_base64 = ( + "data:text/plain;charset=ISO-8859-1;base64,dGVzdCBkYXRhIFVSTHMgOjs" + "sJT0mIPYgxCA%3D") + # base64 encoded data URL that contains ignorable spaces, + # such as "\n", " ", "%0A", and "%20". + self.image_url = ( + "\n" + "QOjdAAAAAXNSR0IArs4c6QAAAA9JREFUCNdj%0AYGBg%2BP//PwAGAQL%2BCm8 " + "vHgAAAABJRU5ErkJggg%3D%3D%0A%20") + + self.text_url_resp = urllib.request.urlopen(self.text_url) + self.text_url_base64_resp = urllib.request.urlopen( + self.text_url_base64) + self.image_url_resp = urllib.request.urlopen(self.image_url) + + def test_interface(self): + # Make sure object returned by urlopen() has the specified methods + for attr in ("read", "readline", "readlines", + "close", "info", "geturl", "getcode", "__iter__"): + self.assertTrue(hasattr(self.text_url_resp, attr), + "object returned by urlopen() lacks %s attribute" % + attr) + + def test_info(self): + self.assertIsInstance(self.text_url_resp.info(), email.message.Message) + self.assertEqual(self.text_url_base64_resp.info().get_params(), + [('text/plain', ''), ('charset', 'ISO-8859-1')]) + self.assertEqual(self.image_url_resp.info()['content-length'], + str(len(self.image))) + self.assertEqual(urllib.request.urlopen("data:,").info().get_params(), + [('text/plain', ''), ('charset', 'US-ASCII')]) + + def test_geturl(self): + self.assertEqual(self.text_url_resp.geturl(), self.text_url) + self.assertEqual(self.text_url_base64_resp.geturl(), + self.text_url_base64) + self.assertEqual(self.image_url_resp.geturl(), self.image_url) + + def test_read_text(self): + self.assertEqual(self.text_url_resp.read().decode( + dict(self.text_url_resp.info().get_params())['charset']), self.text) + + def test_read_text_base64(self): + self.assertEqual(self.text_url_base64_resp.read().decode( + dict(self.text_url_base64_resp.info().get_params())['charset']), + self.text) + + def test_read_image(self): + self.assertEqual(self.image_url_resp.read(), self.image) + + def test_missing_comma(self): + self.assertRaises(ValueError,urllib.request.urlopen,'data:text/plain') + + def test_invalid_base64_data(self): + # missing padding character + self.assertRaises(ValueError,urllib.request.urlopen,'data:;base64,Cg=') + class urlretrieve_FileTests(unittest.TestCase): """Test urllib.urlretrieve() on local files""" @@ -1291,6 +1367,7 @@ class URLopener_Tests(unittest.TestCase): # self.assertEqual(ftp.ftp.sock.gettimeout(), 30) # ftp.close() + class RequestTests(unittest.TestCase): """Unit tests for urllib.request.Request.""" diff --git a/Lib/test/test_urllib2.py b/Lib/test/test_urllib2.py index 3559b1b789..b32b4090e8 100644 --- a/Lib/test/test_urllib2.py +++ b/Lib/test/test_urllib2.py @@ -11,6 +11,7 @@ import urllib.request # The proxy bypass method imported below has logic specific to the OSX # proxy config data structure but is testable on all platforms. from urllib.request import Request, OpenerDirector, _proxy_bypass_macosx_sysconf +from urllib.parse import urlparse import urllib.error # XXX @@ -118,6 +119,15 @@ class RequestHdrsTests(unittest.TestCase): self.assertIsNone(req.get_header("Not-there")) self.assertEqual(req.get_header("Not-there", "default"), "default") + req.remove_header("Spam-eggs") + self.assertFalse(req.has_header("Spam-eggs")) + + req.add_unredirected_header("Unredirected-spam", "Eggs") + self.assertTrue(req.has_header("Unredirected-spam")) + + req.remove_header("Unredirected-spam") + self.assertFalse(req.has_header("Unredirected-spam")) + def test_password_manager(self): mgr = urllib.request.HTTPPasswordMgr() @@ -311,8 +321,7 @@ class MockHTTPClass: if body: self.data = body if self.raise_on_endheaders: - import socket - raise socket.error() + raise OSError() def getresponse(self): return MockHTTPResponse(MockFile(), {}, 200, "OK") @@ -597,27 +606,6 @@ class OpenerDirectorTests(unittest.TestCase): if args[1] is not None: self.assertIsInstance(args[1], MockResponse) - def test_method_deprecations(self): - req = Request("http://www.example.com") - - with self.assertWarns(DeprecationWarning): - req.add_data("data") - with self.assertWarns(DeprecationWarning): - req.get_data() - with self.assertWarns(DeprecationWarning): - req.has_data() - with self.assertWarns(DeprecationWarning): - req.get_host() - with self.assertWarns(DeprecationWarning): - req.get_selector() - with self.assertWarns(DeprecationWarning): - req.is_unverifiable() - with self.assertWarns(DeprecationWarning): - req.get_origin_req_host() - with self.assertWarns(DeprecationWarning): - req.get_type() - - def sanepathname2url(path): try: path.encode("utf-8") @@ -812,7 +800,7 @@ class HandlerTests(unittest.TestCase): ("Foo", "bar"), ("Spam", "eggs")]) self.assertEqual(http.data, data) - # check socket.error converted to URLError + # check OSError converted to URLError http.raise_on_endheaders = True self.assertRaises(urllib.error.URLError, h.do_open, http, req) @@ -917,6 +905,36 @@ class HandlerTests(unittest.TestCase): p_ds_req = h.do_request_(ds_req) self.assertEqual(p_ds_req.unredirected_hdrs["Host"],"example.com") + def test_full_url_setter(self): + # Checks to ensure that components are set correctly after setting the + # full_url of a Request object + + urls = [ + 'http://example.com?foo=bar#baz', + 'http://example.com?foo=bar&spam=eggs#bash', + 'http://example.com', + ] + + # testing a reusable request instance, but the url parameter is + # required, so just use a dummy one to instantiate + r = Request('http://example.com') + for url in urls: + r.full_url = url + parsed = urlparse(url) + + self.assertEqual(r.get_full_url(), url) + # full_url setter uses splittag to split into components. + # splittag sets the fragment as None while urlparse sets it to '' + self.assertEqual(r.fragment or '', parsed.fragment) + self.assertEqual(urlparse(r.get_full_url()).query, parsed.query) + + def test_full_url_deleter(self): + r = Request('http://www.example.com') + del r.full_url + self.assertIsNone(r.full_url) + self.assertIsNone(r.fragment) + self.assertEqual(r.selector, '') + def test_fixpath_in_weirdurls(self): # Issue4493: urllib2 to supply '/' when to urls where path does not # start with'/' @@ -1417,6 +1435,21 @@ class MiscTests(unittest.TestCase): self.opener_has_handler(o, MyHTTPHandler) self.opener_has_handler(o, MyOtherHTTPHandler) + @unittest.skipUnless(support.is_resource_enabled('network'), + 'test requires network access') + def test_issue16464(self): + opener = urllib.request.build_opener() + request = urllib.request.Request("http://www.python.org/~jeremy/") + self.assertEqual(None, request.data) + + opener.open(request, "1".encode("us-ascii")) + self.assertEqual(b"1", request.data) + self.assertEqual("1", request.get_header("Content-length")) + + opener.open(request, "1234567890".encode("us-ascii")) + self.assertEqual(b"1234567890", request.data) + self.assertEqual("10", request.get_header("Content-length")) + def test_HTTPError_interface(self): """ Issue 13211 reveals that HTTPError didn't implement the URLError @@ -1428,23 +1461,31 @@ class MiscTests(unittest.TestCase): err = urllib.error.HTTPError(url, code, msg, hdrs, fp) self.assertTrue(hasattr(err, 'reason')) self.assertEqual(err.reason, 'something bad happened') - self.assertTrue(hasattr(err, 'hdrs')) - self.assertEqual(err.hdrs, 'Content-Length: 42') + self.assertTrue(hasattr(err, 'headers')) + self.assertEqual(err.headers, 'Content-Length: 42') expected_errmsg = 'HTTP Error %s: %s' % (err.code, err.msg) self.assertEqual(str(err), expected_errmsg) - class RequestTests(unittest.TestCase): + class PutRequest(Request): + method='PUT' def setUp(self): self.get = Request("http://www.python.org/~jeremy/") self.post = Request("http://www.python.org/~jeremy/", "data", headers={"X-Test": "test"}) + self.head = Request("http://www.python.org/~jeremy/", method='HEAD') + self.put = self.PutRequest("http://www.python.org/~jeremy/") + self.force_post = self.PutRequest("http://www.python.org/~jeremy/", + method="POST") def test_method(self): self.assertEqual("POST", self.post.get_method()) self.assertEqual("GET", self.get.get_method()) + self.assertEqual("HEAD", self.head.get_method()) + self.assertEqual("PUT", self.put.get_method()) + self.assertEqual("POST", self.force_post.get_method()) def test_data(self): self.assertFalse(self.get.data) @@ -1453,6 +1494,25 @@ class RequestTests(unittest.TestCase): self.assertTrue(self.get.data) self.assertEqual("POST", self.get.get_method()) + # issue 16464 + # if we change data we need to remove content-length header + # (cause it's most probably calculated for previous value) + def test_setting_data_should_remove_content_length(self): + self.assertNotIn("Content-length", self.get.unredirected_hdrs) + self.get.add_unredirected_header("Content-length", 42) + self.assertEqual(42, self.get.unredirected_hdrs["Content-length"]) + self.get.data = "spam" + self.assertNotIn("Content-length", self.get.unredirected_hdrs) + + # issue 17485 same for deleting data. + def test_deleting_data_should_remove_content_length(self): + self.assertNotIn("Content-length", self.get.unredirected_hdrs) + self.get.data = 'foo' + self.get.add_unredirected_header("Content-length", 3) + self.assertEqual(3, self.get.unredirected_hdrs["Content-length"]) + del self.get.data + self.assertNotIn("Content-length", self.get.unredirected_hdrs) + def test_get_full_url(self): self.assertEqual("http://www.python.org/~jeremy/", self.get.get_full_url()) @@ -1494,21 +1554,13 @@ class RequestTests(unittest.TestCase): req = Request(url) self.assertEqual(req.get_full_url(), url) - def test_HTTPError_interface_call(self): - """ - Issue 15701 - HTTPError interface has info method available from URLError - """ - err = urllib.request.HTTPError(msg="something bad happened", url=None, - code=None, hdrs='Content-Length:42', fp=None) - self.assertTrue(hasattr(err, 'reason')) - assert hasattr(err, 'reason') - assert hasattr(err, 'info') - assert callable(err.info) - try: - err.info() - except AttributeError: - self.fail('err.info call failed.') - self.assertEqual(err.info(), "Content-Length:42") + def test_url_fullurl_get_full_url(self): + urls = ['http://docs.python.org', + 'http://docs.python.org/library/urllib2.html#OK', + 'http://www.python.org/?qs=query#fragment=true' ] + for url in urls: + req = Request(url) + self.assertEqual(req.get_full_url(), req.full_url) def test_main(verbose=None): from test import test_urllib2 diff --git a/Lib/test/test_urllib2_localnet.py b/Lib/test/test_urllib2_localnet.py index 31c638fc4e..83626f4fd3 100644 --- a/Lib/test/test_urllib2_localnet.py +++ b/Lib/test/test_urllib2_localnet.py @@ -7,7 +7,10 @@ import unittest import hashlib from test import support threading = support.import_module('threading') - +try: + import ssl +except ImportError: + ssl = None here = os.path.dirname(__file__) # Self-signed cert file for 'localhost' @@ -15,6 +18,7 @@ CERT_localhost = os.path.join(here, 'keycert.pem') # Self-signed cert file for 'fakehostname' CERT_fakehostname = os.path.join(here, 'keycert2.pem') + # Loopback http server infrastructure class LoopbackHttpServer(http.server.HTTPServer): @@ -387,14 +391,14 @@ class TestUrlopen(unittest.TestCase): handler.port = port return handler - def start_https_server(self, responses=None, certfile=CERT_localhost): + def start_https_server(self, responses=None, **kwargs): if not hasattr(urllib.request, 'HTTPSHandler'): self.skipTest('ssl support required') from test.ssl_servers import make_https_server if responses is None: responses = [(200, [], b"we care a bit")] handler = GetRequestHandler(responses) - server = make_https_server(self, certfile=certfile, handler_class=handler) + server = make_https_server(self, handler_class=handler, **kwargs) handler.port = server.port return handler @@ -484,6 +488,21 @@ class TestUrlopen(unittest.TestCase): self.urlopen("https://localhost:%s/bizarre" % handler.port, cadefault=True) + def test_https_sni(self): + if ssl is None: + self.skipTest("ssl module required") + if not ssl.HAS_SNI: + self.skipTest("SNI support required in OpenSSL") + sni_name = None + def cb_sni(ssl_sock, server_name, initial_context): + nonlocal sni_name + sni_name = server_name + context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + context.set_servername_callback(cb_sni) + handler = self.start_https_server(context=context, certfile=CERT_localhost) + self.urlopen("https://localhost:%s" % handler.port) + self.assertEqual(sni_name, "localhost") + def test_sending_headers(self): handler = self.start_server() req = urllib.request.Request("http://localhost:%s/" % handler.port, @@ -530,7 +549,7 @@ class TestUrlopen(unittest.TestCase): # so we run the test only when -unetwork/-uall is specified to # mitigate the problem a bit (see #17564) support.requires('network') - self.assertRaises(IOError, + self.assertRaises(OSError, # Given that both VeriSign and various ISPs have in # the past or are presently hijacking various invalid # domain name requests in an attempt to boost traffic diff --git a/Lib/test/test_urllib2net.py b/Lib/test/test_urllib2net.py index 4981c2410b..78891fd797 100644 --- a/Lib/test/test_urllib2net.py +++ b/Lib/test/test_urllib2net.py @@ -12,6 +12,8 @@ try: except ImportError: ssl = None +support.requires("network") + TIMEOUT = 60 # seconds @@ -162,6 +164,14 @@ class OtherNetworkTests(unittest.TestCase): self.assertEqual(res.geturl(), "http://docs.python.org/2/glossary.html#glossary") + def test_redirect_url_withfrag(self): + redirect_url_with_frag = "http://bitly.com/urllibredirecttest" + with support.transient_internet(redirect_url_with_frag): + req = urllib.request.Request(redirect_url_with_frag) + res = urllib.request.urlopen(req) + self.assertEqual(res.geturl(), + "http://docs.python.org/3.4/glossary.html#term-global-interpreter-lock") + def test_custom_headers(self): url = "http://www.example.com" with support.transient_internet(url): @@ -214,7 +224,7 @@ class OtherNetworkTests(unittest.TestCase): debug(url) try: f = urlopen(url, req, TIMEOUT) - except EnvironmentError as err: + except OSError as err: debug(err) if expected_err: msg = ("Didn't get expected error(s) %s for %s %s, got %s: %s" % @@ -328,35 +338,5 @@ class TimeoutTest(unittest.TestCase): self.assertEqual(u.fp.fp.raw._sock.gettimeout(), 60) -@unittest.skipUnless(ssl, "requires SSL support") -class HTTPSTests(unittest.TestCase): - - def test_sni(self): - self.skipTest("test disabled - test server needed") - # Checks that Server Name Indication works, if supported by the - # OpenSSL linked to. - # The ssl module itself doesn't have server-side support for SNI, - # so we rely on a third-party test site. - expect_sni = ssl.HAS_SNI - with support.transient_internet("XXX"): - u = urllib.request.urlopen("XXX") - contents = u.readall() - if expect_sni: - self.assertIn(b"Great", contents) - self.assertNotIn(b"Unfortunately", contents) - else: - self.assertNotIn(b"Great", contents) - self.assertIn(b"Unfortunately", contents) - - -def test_main(): - support.requires("network") - support.run_unittest(AuthTests, - HTTPSTests, - OtherNetworkTests, - CloseSocketTest, - TimeoutTest, - ) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_urllibnet.py b/Lib/test/test_urllibnet.py index 38afb69250..4da89fcaaa 100644 --- a/Lib/test/test_urllibnet.py +++ b/Lib/test/test_urllibnet.py @@ -121,16 +121,15 @@ class urlopenNetworkTests(unittest.TestCase): else: # This happens with some overzealous DNS providers such as OpenDNS self.skipTest("%r should not resolve for test to work" % bogus_domain) - self.assertRaises(IOError, - # SF patch 809915: In Sep 2003, VeriSign started - # highjacking invalid .com and .net addresses to - # boost traffic to their own site. This test - # started failing then. One hopes the .invalid - # domain will be spared to serve its defined - # purpose. - # urllib.urlopen, "http://www.sadflkjsasadf.com/") - urllib.request.urlopen, - "http://sadflkjsasf.i.nvali.d/") + failure_explanation = ('opening an invalid URL did not raise OSError; ' + 'can be caused by a broken DNS server ' + '(e.g. returns 404 or hijacks page)') + with self.assertRaises(OSError, msg=failure_explanation): + # SF patch 809915: In Sep 2003, VeriSign started highjacking + # invalid .com and .net addresses to boost traffic to their own + # site. This test started failing then. One hopes the .invalid + # domain will be spared to serve its defined purpose. + urllib.request.urlopen("http://sadflkjsasf.i.nvali.d/") class urlretrieveNetworkTests(unittest.TestCase): diff --git a/Lib/test/test_urlparse.py b/Lib/test/test_urlparse.py index d67cf258b0..393481148d 100644 --- a/Lib/test/test_urlparse.py +++ b/Lib/test/test_urlparse.py @@ -863,6 +863,14 @@ class UrlParseTestCase(unittest.TestCase): self.assertEqual(p1.path, '863-1234') self.assertEqual(p1.params, 'phone-context=+1-914-555') + def test_unwrap(self): + url = urllib.parse.unwrap('<URL:type://host/path>') + self.assertEqual(url, 'type://host/path') + + def test_Quoter_repr(self): + quoter = urllib.parse.Quoter(urllib.parse._ALWAYS_SAFE) + self.assertIn('Quoter', repr(quoter)) + def test_main(): support.run_unittest(UrlParseTestCase) diff --git a/Lib/test/test_userdict.py b/Lib/test/test_userdict.py index 137c445eae..2ca99298d4 100644 --- a/Lib/test/test_userdict.py +++ b/Lib/test/test_userdict.py @@ -45,7 +45,8 @@ class UserDictTest(mapping_tests.TestHashMappingProtocol): # Test __repr__ self.assertEqual(str(u0), str(d0)) self.assertEqual(repr(u1), repr(d1)) - self.assertEqual(repr(u2), repr(d2)) + self.assertIn(repr(u2), ("{'one': 1, 'two': 2}", + "{'two': 2, 'one': 1}")) # Test rich comparison and __len__ all = [d0, d1, d2, u, u0, u1, u2, uu, uu0, uu1, uu2] @@ -89,9 +90,9 @@ class UserDictTest(mapping_tests.TestHashMappingProtocol): self.assertNotEqual(m2a, m2) # Test keys, items, values - self.assertEqual(u2.keys(), d2.keys()) - self.assertEqual(u2.items(), d2.items()) - self.assertEqual(list(u2.values()), list(d2.values())) + self.assertEqual(sorted(u2.keys()), sorted(d2.keys())) + self.assertEqual(sorted(u2.items()), sorted(d2.items())) + self.assertEqual(sorted(u2.values()), sorted(d2.values())) # Test "in". for i in u2.keys(): diff --git a/Lib/test/test_venv.py b/Lib/test/test_venv.py index 0fae88bcf9..1084a99a5e 100644 --- a/Lib/test/test_venv.py +++ b/Lib/test/test_venv.py @@ -5,6 +5,7 @@ Copyright (C) 2011-2012 Vinay Sajip. Licensed to the PSF under a contributor agreement. """ +import ensurepip import os import os.path import shutil @@ -12,10 +13,28 @@ import subprocess import sys import tempfile from test.support import (captured_stdout, captured_stderr, run_unittest, - can_symlink) + can_symlink, EnvironmentVarGuard) +import textwrap import unittest import venv +# pip currently requires ssl support, so we ensure we handle +# it being missing (http://bugs.python.org/issue19744) +try: + import ssl +except ImportError: + ssl = None + +skipInVenv = unittest.skipIf(sys.prefix != sys.base_prefix, + 'Test not appropriate in a venv') + +# os.path.exists('nul') is False: http://bugs.python.org/issue20541 +if os.devnull.lower() == 'nul': + failsOnWindows = unittest.expectedFailure +else: + def failsOnWindows(f): + return f + class BaseTest(unittest.TestCase): """Base class for venv tests.""" @@ -83,8 +102,7 @@ class BasicTest(BaseTest): print(' %r' % os.listdir(bd)) self.assertTrue(os.path.exists(fn), 'File %r should exist.' % fn) - @unittest.skipIf(sys.prefix != sys.base_prefix, 'Test not appropriate ' - 'in a venv') + @skipInVenv def test_prefixes(self): """ Test that the prefix values are as expected. @@ -109,13 +127,68 @@ class BasicTest(BaseTest): out, err = p.communicate() self.assertEqual(out.strip(), expected.encode()) + if sys.platform == 'win32': + ENV_SUBDIRS = ( + ('Scripts',), + ('Include',), + ('Lib',), + ('Lib', 'site-packages'), + ) + else: + ENV_SUBDIRS = ( + ('bin',), + ('include',), + ('lib',), + ('lib', 'python%d.%d' % sys.version_info[:2]), + ('lib', 'python%d.%d' % sys.version_info[:2], 'site-packages'), + ) + + def create_contents(self, paths, filename): + """ + Create some files in the environment which are unrelated + to the virtual environment. + """ + for subdirs in paths: + d = os.path.join(self.env_dir, *subdirs) + os.mkdir(d) + fn = os.path.join(d, filename) + with open(fn, 'wb') as f: + f.write(b'Still here?') + def test_overwrite_existing(self): """ - Test control of overwriting an existing environment directory. + Test creating environment in an existing directory. """ - self.assertRaises(ValueError, venv.create, self.env_dir) + self.create_contents(self.ENV_SUBDIRS, 'foo') + venv.create(self.env_dir) + for subdirs in self.ENV_SUBDIRS: + fn = os.path.join(self.env_dir, *(subdirs + ('foo',))) + self.assertTrue(os.path.exists(fn)) + with open(fn, 'rb') as f: + self.assertEqual(f.read(), b'Still here?') + builder = venv.EnvBuilder(clear=True) builder.create(self.env_dir) + for subdirs in self.ENV_SUBDIRS: + fn = os.path.join(self.env_dir, *(subdirs + ('foo',))) + self.assertFalse(os.path.exists(fn)) + + def clear_directory(self, path): + for fn in os.listdir(path): + fn = os.path.join(path, fn) + if os.path.islink(fn) or os.path.isfile(fn): + os.remove(fn) + elif os.path.isdir(fn): + shutil.rmtree(fn) + + def test_unoverwritable_fails(self): + #create a file clashing with directories in the env dir + for paths in self.ENV_SUBDIRS[:3]: + fn = os.path.join(self.env_dir, *paths) + with open(fn, 'wb') as f: + f.write(b'') + self.assertRaises((ValueError, OSError), venv.create, self.env_dir) + self.clear_directory(self.env_dir) def test_upgrade(self): """ @@ -162,8 +235,7 @@ class BasicTest(BaseTest): # run the test, the pyvenv.cfg in the venv created in the test will # point to the venv being used to run the test, and we lose the link # to the source build - so Python can't initialise properly. - @unittest.skipIf(sys.prefix != sys.base_prefix, 'Test not appropriate ' - 'in a venv') + @skipInVenv def test_executable(self): """ Test that the sys.executable value is as expected. @@ -192,8 +264,129 @@ class BasicTest(BaseTest): out, err = p.communicate() self.assertEqual(out.strip(), envpy.encode()) + +@skipInVenv +class EnsurePipTest(BaseTest): + """Test venv module installation of pip.""" + def assert_pip_not_installed(self): + envpy = os.path.join(os.path.realpath(self.env_dir), + self.bindir, self.exe) + try_import = 'try:\n import pip\nexcept ImportError:\n print("OK")' + cmd = [envpy, '-c', try_import] + p = subprocess.Popen(cmd, stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + out, err = p.communicate() + # We force everything to text, so unittest gives the detailed diff + # if we get unexpected results + err = err.decode("latin-1") # Force to text, prevent decoding errors + self.assertEqual(err, "") + out = out.decode("latin-1") # Force to text, prevent decoding errors + self.assertEqual(out.strip(), "OK") + + + def test_no_pip_by_default(self): + shutil.rmtree(self.env_dir) + self.run_with_capture(venv.create, self.env_dir) + self.assert_pip_not_installed() + + def test_explicit_no_pip(self): + shutil.rmtree(self.env_dir) + self.run_with_capture(venv.create, self.env_dir, with_pip=False) + self.assert_pip_not_installed() + + @failsOnWindows + def test_devnull_exists_and_is_empty(self): + # Fix for issue #20053 uses os.devnull to force a config file to + # appear empty. However http://bugs.python.org/issue20541 means + # that doesn't currently work properly on Windows. Once that is + # fixed, the "win_location" part of test_with_pip should be restored + self.assertTrue(os.path.exists(os.devnull)) + with open(os.devnull, "rb") as f: + self.assertEqual(f.read(), b"") + + # Requesting pip fails without SSL (http://bugs.python.org/issue19744) + @unittest.skipIf(ssl is None, ensurepip._MISSING_SSL_MESSAGE) + def test_with_pip(self): + shutil.rmtree(self.env_dir) + with EnvironmentVarGuard() as envvars: + # pip's cross-version compatibility may trigger deprecation + # warnings in current versions of Python. Ensure related + # environment settings don't cause venv to fail. + envvars["PYTHONWARNINGS"] = "e" + # ensurepip is different enough from a normal pip invocation + # that we want to ensure it ignores the normal pip environment + # variable settings. We set PIP_NO_INSTALL here specifically + # to check that ensurepip (and hence venv) ignores it. + # See http://bugs.python.org/issue19734 + envvars["PIP_NO_INSTALL"] = "1" + # Also check that we ignore the pip configuration file + # See http://bugs.python.org/issue20053 + with tempfile.TemporaryDirectory() as home_dir: + envvars["HOME"] = home_dir + bad_config = "[global]\nno-install=1" + # Write to both config file names on all platforms to reduce + # cross-platform variation in test code behaviour + win_location = ("pip", "pip.ini") + posix_location = (".pip", "pip.conf") + # Skips win_location due to http://bugs.python.org/issue20541 + for dirname, fname in (posix_location,): + dirpath = os.path.join(home_dir, dirname) + os.mkdir(dirpath) + fpath = os.path.join(dirpath, fname) + with open(fpath, 'w') as f: + f.write(bad_config) + + # Actually run the create command with all that unhelpful + # config in place to ensure we ignore it + try: + self.run_with_capture(venv.create, self.env_dir, + with_pip=True) + except subprocess.CalledProcessError as exc: + # The output this produces can be a little hard to read, + # but at least it has all the details + details = exc.output.decode(errors="replace") + msg = "{}\n\n**Subprocess Output**\n{}" + self.fail(msg.format(exc, details)) + # Ensure pip is available in the virtual environment + envpy = os.path.join(os.path.realpath(self.env_dir), self.bindir, self.exe) + cmd = [envpy, '-Im', 'pip', '--version'] + p = subprocess.Popen(cmd, stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + out, err = p.communicate() + # We force everything to text, so unittest gives the detailed diff + # if we get unexpected results + err = err.decode("latin-1") # Force to text, prevent decoding errors + self.assertEqual(err, "") + out = out.decode("latin-1") # Force to text, prevent decoding errors + expected_version = "pip {}".format(ensurepip.version()) + self.assertEqual(out[:len(expected_version)], expected_version) + env_dir = os.fsencode(self.env_dir).decode("latin-1") + self.assertIn(env_dir, out) + + # http://bugs.python.org/issue19728 + # Check the private uninstall command provided for the Windows + # installers works (at least in a virtual environment) + cmd = [envpy, '-Im', 'ensurepip._uninstall'] + with EnvironmentVarGuard() as envvars: + p = subprocess.Popen(cmd, stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + out, err = p.communicate() + # We force everything to text, so unittest gives the detailed diff + # if we get unexpected results + err = err.decode("latin-1") # Force to text, prevent decoding errors + self.assertEqual(err, "") + # Being fairly specific regarding the expected behaviour for the + # initial bundling phase in Python 3.4. If the output changes in + # future pip versions, this test can likely be relaxed further. + out = out.decode("latin-1") # Force to text, prevent decoding errors + self.assertIn("Successfully uninstalled pip", out) + self.assertIn("Successfully uninstalled setuptools", out) + # Check pip is now gone from the virtual environment + self.assert_pip_not_installed() + + def test_main(): - run_unittest(BasicTest) + run_unittest(BasicTest, EnsurePipTest) if __name__ == "__main__": test_main() diff --git a/Lib/test/test_wait3.py b/Lib/test/test_wait3.py index bd06c8d8bd..f6a065d850 100644 --- a/Lib/test/test_wait3.py +++ b/Lib/test/test_wait3.py @@ -7,15 +7,11 @@ import unittest from test.fork_wait import ForkWait from test.support import run_unittest, reap_children -try: - os.fork -except AttributeError: - raise unittest.SkipTest("os.fork not defined -- skipping test_wait3") +if not hasattr(os, 'fork'): + raise unittest.SkipTest("os.fork not defined") -try: - os.wait3 -except AttributeError: - raise unittest.SkipTest("os.wait3 not defined -- skipping test_wait3") +if not hasattr(os, 'wait3'): + raise unittest.SkipTest("os.wait3 not defined") class Wait3Test(ForkWait): def wait_impl(self, cpid): diff --git a/Lib/test/test_warnings.py b/Lib/test/test_warnings.py index 10076af3aa..eec2c24218 100644 --- a/Lib/test/test_warnings.py +++ b/Lib/test/test_warnings.py @@ -329,6 +329,19 @@ class WarnTests(BaseTest): warning_tests.__name__ = module_name sys.argv = argv + def test_warn_explicit_non_ascii_filename(self): + with original_warnings.catch_warnings(record=True, + module=self.module) as w: + self.module.resetwarnings() + self.module.filterwarnings("always", category=UserWarning) + for filename in ("nonascii\xe9\u20ac", "surrogate\udc80"): + try: + os.fsencode(filename) + except UnicodeEncodeError: + continue + self.module.warn_explicit("text", UserWarning, filename, 1) + self.assertEqual(w[-1].filename, filename) + def test_warn_explicit_type_errors(self): # warn_explicit() should error out gracefully if it is given objects # of the wrong types. @@ -766,6 +779,25 @@ class BootstrapTest(unittest.TestCase): # Use -W to load warnings module at startup assert_python_ok('-c', 'pass', '-W', 'always', PYTHONPATH=cwd) +class FinalizationTest(unittest.TestCase): + def test_finalization(self): + # Issue #19421: warnings.warn() should not crash + # during Python finalization + code = """ +import warnings +warn = warnings.warn + +class A: + def __del__(self): + warn("test") + +a=A() + """ + rc, out, err = assert_python_ok("-c", code) + # note: "__main__" filename is not correct, it should be the name + # of the script + self.assertEqual(err, b'__main__:7: UserWarning: test') + def setUpModule(): py_warnings.onceregistry.clear() diff --git a/Lib/test/test_wave.py b/Lib/test/test_wave.py index f64201f9fc..3eff773bca 100644 --- a/Lib/test/test_wave.py +++ b/Lib/test/test_wave.py @@ -1,6 +1,7 @@ from test.support import TESTFN import unittest from test import audiotests +from audioop import byteswap import sys import wave @@ -8,9 +9,6 @@ import wave class WaveTest(audiotests.AudioWriteTests, audiotests.AudioTestsWithSourceFile): module = wave - test_unseekable_write = None - test_unseekable_overflowed_write = None - test_unseekable_incompleted_write = None class WavePCM8Test(WaveTest, unittest.TestCase): @@ -48,13 +46,7 @@ class WavePCM16Test(WaveTest, unittest.TestCase): E4B50CEB 63440A5A 08CA0A1F 2BBA0B0B 51460E47 8BCB113C B6F50EEA 44150A59 \ """) if sys.byteorder != 'big': - frames = audiotests.byteswap2(frames) - - if sys.byteorder == 'big': - @unittest.expectedFailure - def test_unseekable_incompleted_write(self): - super().test_unseekable_incompleted_write() - + frames = byteswap(frames, 2) class WavePCM24Test(WaveTest, unittest.TestCase): @@ -81,7 +73,7 @@ class WavePCM24Test(WaveTest, unittest.TestCase): 51486F0E44E1 8BCC64113B05 B6F4EC0EEB36 4413170A5B48 \ """) if sys.byteorder != 'big': - frames = audiotests.byteswap3(frames) + frames = byteswap(frames, 3) class WavePCM32Test(WaveTest, unittest.TestCase): @@ -108,12 +100,7 @@ class WavePCM32Test(WaveTest, unittest.TestCase): 51486F800E44E190 8BCC6480113B0580 B6F4EC000EEB3630 441317800A5B48A0 \ """) if sys.byteorder != 'big': - frames = audiotests.byteswap4(frames) - - if sys.byteorder == 'big': - @unittest.expectedFailure - def test_unseekable_incompleted_write(self): - super().test_unseekable_incompleted_write() + frames = byteswap(frames, 4) if __name__ == '__main__': diff --git a/Lib/test/test_weakref.py b/Lib/test/test_weakref.py index 85ab717c69..cccb515252 100644 --- a/Lib/test/test_weakref.py +++ b/Lib/test/test_weakref.py @@ -7,11 +7,15 @@ import operator import contextlib import copy -from test import support +from test import support, script_helper # Used in ReferencesTestCase.test_ref_created_during_del() . ref_from_del = None +# Used by FinalizeTestCase as a global that may be replaced by None +# when the interpreter shuts down. +_global_var = 'foobar' + class C: def method(self): pass @@ -47,6 +51,11 @@ class Object: return NotImplemented def __hash__(self): return hash(self.arg) + def some_method(self): + return 4 + def other_method(self): + return 5 + class RefCycle: def __init__(self): @@ -794,6 +803,30 @@ class ReferencesTestCase(TestBase): del root gc.collect() + def test_callback_attribute(self): + x = Object(1) + callback = lambda ref: None + ref1 = weakref.ref(x, callback) + self.assertIs(ref1.__callback__, callback) + + ref2 = weakref.ref(x) + self.assertIsNone(ref2.__callback__) + + def test_callback_attribute_after_deletion(self): + x = Object(1) + ref = weakref.ref(x, self.callback) + self.assertIsNotNone(ref.__callback__) + del x + support.gc_collect() + self.assertIsNone(ref.__callback__) + + def test_set_callback_attribute(self): + x = Object(1) + callback = lambda ref: None + ref1 = weakref.ref(x, callback) + with self.assertRaises(AttributeError): + ref1.__callback__ = lambda ref: None + class SubclassableWeakrefTestCase(TestBase): @@ -898,6 +931,140 @@ class SubclassableWeakrefTestCase(TestBase): self.assertEqual(self.cbcalled, 0) +class WeakMethodTestCase(unittest.TestCase): + + def _subclass(self): + """Return a Object subclass overriding `some_method`.""" + class C(Object): + def some_method(self): + return 6 + return C + + def test_alive(self): + o = Object(1) + r = weakref.WeakMethod(o.some_method) + self.assertIsInstance(r, weakref.ReferenceType) + self.assertIsInstance(r(), type(o.some_method)) + self.assertIs(r().__self__, o) + self.assertIs(r().__func__, o.some_method.__func__) + self.assertEqual(r()(), 4) + + def test_object_dead(self): + o = Object(1) + r = weakref.WeakMethod(o.some_method) + del o + gc.collect() + self.assertIs(r(), None) + + def test_method_dead(self): + C = self._subclass() + o = C(1) + r = weakref.WeakMethod(o.some_method) + del C.some_method + gc.collect() + self.assertIs(r(), None) + + def test_callback_when_object_dead(self): + # Test callback behaviour when object dies first. + C = self._subclass() + calls = [] + def cb(arg): + calls.append(arg) + o = C(1) + r = weakref.WeakMethod(o.some_method, cb) + del o + gc.collect() + self.assertEqual(calls, [r]) + # Callback is only called once. + C.some_method = Object.some_method + gc.collect() + self.assertEqual(calls, [r]) + + def test_callback_when_method_dead(self): + # Test callback behaviour when method dies first. + C = self._subclass() + calls = [] + def cb(arg): + calls.append(arg) + o = C(1) + r = weakref.WeakMethod(o.some_method, cb) + del C.some_method + gc.collect() + self.assertEqual(calls, [r]) + # Callback is only called once. + del o + gc.collect() + self.assertEqual(calls, [r]) + + @support.cpython_only + def test_no_cycles(self): + # A WeakMethod doesn't create any reference cycle to itself. + o = Object(1) + def cb(_): + pass + r = weakref.WeakMethod(o.some_method, cb) + wr = weakref.ref(r) + del r + self.assertIs(wr(), None) + + def test_equality(self): + def _eq(a, b): + self.assertTrue(a == b) + self.assertFalse(a != b) + def _ne(a, b): + self.assertTrue(a != b) + self.assertFalse(a == b) + x = Object(1) + y = Object(1) + a = weakref.WeakMethod(x.some_method) + b = weakref.WeakMethod(y.some_method) + c = weakref.WeakMethod(x.other_method) + d = weakref.WeakMethod(y.other_method) + # Objects equal, same method + _eq(a, b) + _eq(c, d) + # Objects equal, different method + _ne(a, c) + _ne(a, d) + _ne(b, c) + _ne(b, d) + # Objects unequal, same or different method + z = Object(2) + e = weakref.WeakMethod(z.some_method) + f = weakref.WeakMethod(z.other_method) + _ne(a, e) + _ne(a, f) + _ne(b, e) + _ne(b, f) + del x, y, z + gc.collect() + # Dead WeakMethods compare by identity + refs = a, b, c, d, e, f + for q in refs: + for r in refs: + self.assertEqual(q == r, q is r) + self.assertEqual(q != r, q is not r) + + def test_hashing(self): + # Alive WeakMethods are hashable if the underlying object is + # hashable. + x = Object(1) + y = Object(1) + a = weakref.WeakMethod(x.some_method) + b = weakref.WeakMethod(y.some_method) + c = weakref.WeakMethod(y.other_method) + # Since WeakMethod objects are equal, the hashes should be equal. + self.assertEqual(hash(a), hash(b)) + ha = hash(a) + # Dead WeakMethods retain their old hash value + del x, y + gc.collect() + self.assertEqual(hash(a), ha) + self.assertEqual(hash(b), ha) + # If it wasn't hashed when alive, a dead WeakMethod cannot be hashed. + self.assertRaises(TypeError, hash, c) + + class MappingTestCase(TestBase): COUNT = 10 @@ -1385,6 +1552,151 @@ class WeakKeyDictionaryTestCase(mapping_tests.BasicTestMappingProtocol): def _reference(self): return self.__ref.copy() + +class FinalizeTestCase(unittest.TestCase): + + class A: + pass + + def _collect_if_necessary(self): + # we create no ref-cycles so in CPython no gc should be needed + if sys.implementation.name != 'cpython': + support.gc_collect() + + def test_finalize(self): + def add(x,y,z): + res.append(x + y + z) + return x + y + z + + a = self.A() + + res = [] + f = weakref.finalize(a, add, 67, 43, z=89) + self.assertEqual(f.alive, True) + self.assertEqual(f.peek(), (a, add, (67,43), {'z':89})) + self.assertEqual(f(), 199) + self.assertEqual(f(), None) + self.assertEqual(f(), None) + self.assertEqual(f.peek(), None) + self.assertEqual(f.detach(), None) + self.assertEqual(f.alive, False) + self.assertEqual(res, [199]) + + res = [] + f = weakref.finalize(a, add, 67, 43, 89) + self.assertEqual(f.peek(), (a, add, (67,43,89), {})) + self.assertEqual(f.detach(), (a, add, (67,43,89), {})) + self.assertEqual(f(), None) + self.assertEqual(f(), None) + self.assertEqual(f.peek(), None) + self.assertEqual(f.detach(), None) + self.assertEqual(f.alive, False) + self.assertEqual(res, []) + + res = [] + f = weakref.finalize(a, add, x=67, y=43, z=89) + del a + self._collect_if_necessary() + self.assertEqual(f(), None) + self.assertEqual(f(), None) + self.assertEqual(f.peek(), None) + self.assertEqual(f.detach(), None) + self.assertEqual(f.alive, False) + self.assertEqual(res, [199]) + + def test_order(self): + a = self.A() + res = [] + + f1 = weakref.finalize(a, res.append, 'f1') + f2 = weakref.finalize(a, res.append, 'f2') + f3 = weakref.finalize(a, res.append, 'f3') + f4 = weakref.finalize(a, res.append, 'f4') + f5 = weakref.finalize(a, res.append, 'f5') + + # make sure finalizers can keep themselves alive + del f1, f4 + + self.assertTrue(f2.alive) + self.assertTrue(f3.alive) + self.assertTrue(f5.alive) + + self.assertTrue(f5.detach()) + self.assertFalse(f5.alive) + + f5() # nothing because previously unregistered + res.append('A') + f3() # => res.append('f3') + self.assertFalse(f3.alive) + res.append('B') + f3() # nothing because previously called + res.append('C') + del a + self._collect_if_necessary() + # => res.append('f4') + # => res.append('f2') + # => res.append('f1') + self.assertFalse(f2.alive) + res.append('D') + f2() # nothing because previously called by gc + + expected = ['A', 'f3', 'B', 'C', 'f4', 'f2', 'f1', 'D'] + self.assertEqual(res, expected) + + def test_all_freed(self): + # we want a weakrefable subclass of weakref.finalize + class MyFinalizer(weakref.finalize): + pass + + a = self.A() + res = [] + def callback(): + res.append(123) + f = MyFinalizer(a, callback) + + wr_callback = weakref.ref(callback) + wr_f = weakref.ref(f) + del callback, f + + self.assertIsNotNone(wr_callback()) + self.assertIsNotNone(wr_f()) + + del a + self._collect_if_necessary() + + self.assertIsNone(wr_callback()) + self.assertIsNone(wr_f()) + self.assertEqual(res, [123]) + + @classmethod + def run_in_child(cls): + def error(): + # Create an atexit finalizer from inside a finalizer called + # at exit. This should be the next to be run. + g1 = weakref.finalize(cls, print, 'g1') + print('f3 error') + 1/0 + + # cls should stay alive till atexit callbacks run + f1 = weakref.finalize(cls, print, 'f1', _global_var) + f2 = weakref.finalize(cls, print, 'f2', _global_var) + f3 = weakref.finalize(cls, error) + f4 = weakref.finalize(cls, print, 'f4', _global_var) + + assert f1.atexit == True + f2.atexit = False + assert f3.atexit == True + assert f4.atexit == True + + def test_atexit(self): + prog = ('from test.test_weakref import FinalizeTestCase;'+ + 'FinalizeTestCase.run_in_child()') + rc, out, err = script_helper.assert_python_ok('-c', prog) + out = out.decode('ascii').splitlines() + self.assertEqual(out, ['f4 foobar', 'f3 error', 'g1', 'f1 foobar']) + self.assertTrue(b'ZeroDivisionError' in err) + + libreftest = """ Doctest for examples in the library reference: weakref.rst >>> import weakref @@ -1473,10 +1785,12 @@ __test__ = {'libreftest' : libreftest} def test_main(): support.run_unittest( ReferencesTestCase, + WeakMethodTestCase, MappingTestCase, WeakValueDictionaryTestCase, WeakKeyDictionaryTestCase, SubclassableWeakrefTestCase, + FinalizeTestCase, ) support.run_doctest(sys.modules[__name__]) diff --git a/Lib/test/test_winreg.py b/Lib/test/test_winreg.py index cb4cde9bc6..ef4ce552f1 100644 --- a/Lib/test/test_winreg.py +++ b/Lib/test/test_winreg.py @@ -8,7 +8,7 @@ threading = support.import_module("threading") from platform import machine # Do this first so test will be skipped if module doesn't exist -support.import_module('winreg') +support.import_module('winreg', required_on=['win']) # Now import everything from winreg import * @@ -57,13 +57,13 @@ class BaseWinregTests(unittest.TestCase): def delete_tree(self, root, subkey): try: hkey = OpenKey(root, subkey, KEY_ALL_ACCESS) - except WindowsError: + except OSError: # subkey does not exist return while True: try: subsubkey = EnumKey(hkey, 0) - except WindowsError: + except OSError: # no more subkeys break self.delete_tree(hkey, subsubkey) @@ -100,7 +100,7 @@ class BaseWinregTests(unittest.TestCase): QueryInfoKey(int_sub_key) self.fail("It appears the CloseKey() function does " "not close the actual key!") - except EnvironmentError: + except OSError: pass # ... and close that key that way :-) int_key = int(key) @@ -109,7 +109,7 @@ class BaseWinregTests(unittest.TestCase): QueryInfoKey(int_key) self.fail("It appears the key.Close() function " "does not close the actual key!") - except EnvironmentError: + except OSError: pass def _read_test_data(self, root_key, subkeystr="sub_key", OpenKey=OpenKey): @@ -126,7 +126,7 @@ class BaseWinregTests(unittest.TestCase): while 1: try: data = EnumValue(sub_key, index) - except EnvironmentError: + except OSError: break self.assertEqual(data in test_data, True, "Didn't read back the correct test data") @@ -147,7 +147,7 @@ class BaseWinregTests(unittest.TestCase): try: EnumKey(key, 1) self.fail("Was able to get a second key when I only have one!") - except EnvironmentError: + except OSError: pass key.Close() @@ -171,7 +171,7 @@ class BaseWinregTests(unittest.TestCase): # Shouldnt be able to delete it twice! DeleteKey(key, subkeystr) self.fail("Deleting the key twice succeeded") - except EnvironmentError: + except OSError: pass key.Close() DeleteKey(root_key, test_key_name) @@ -179,7 +179,7 @@ class BaseWinregTests(unittest.TestCase): try: key = OpenKey(root_key, test_key_name) self.fail("Could open the non-existent key") - except WindowsError: # Use this error name this time + except OSError: # Use this error name this time pass def _test_all(self, root_key, subkeystr="sub_key"): @@ -230,7 +230,7 @@ class LocalWinregTests(BaseWinregTests): def test_inexistant_remote_registry(self): connect = lambda: ConnectRegistry("abcdefghijkl", HKEY_CURRENT_USER) - self.assertRaises(WindowsError, connect) + self.assertRaises(OSError, connect) def testExpandEnvironmentStrings(self): r = ExpandEnvironmentStrings("%windir%\\test") @@ -242,8 +242,8 @@ class LocalWinregTests(BaseWinregTests): try: with ConnectRegistry(None, HKEY_LOCAL_MACHINE) as h: self.assertNotEqual(h.handle, 0) - raise WindowsError - except WindowsError: + raise OSError + except OSError: self.assertEqual(h.handle, 0) def test_changing_value(self): @@ -407,7 +407,7 @@ class Win64WinregTests(BaseWinregTests): open_fail = lambda: OpenKey(HKEY_CURRENT_USER, test_reflect_key_name, 0, KEY_READ | KEY_WOW64_64KEY) - self.assertRaises(WindowsError, open_fail) + self.assertRaises(OSError, open_fail) # Now explicitly open the 64-bit version of the key with OpenKey(HKEY_CURRENT_USER, test_reflect_key_name, 0, @@ -447,7 +447,7 @@ class Win64WinregTests(BaseWinregTests): open_fail = lambda: OpenKeyEx(HKEY_CURRENT_USER, test_reflect_key_name, 0, KEY_READ | KEY_WOW64_64KEY) - self.assertRaises(WindowsError, open_fail) + self.assertRaises(OSError, open_fail) # Make sure the 32-bit key is actually there with OpenKeyEx(HKEY_CURRENT_USER, test_reflect_key_name, 0, diff --git a/Lib/test/test_winsound.py b/Lib/test/test_winsound.py index 069adc3981..83618b6a44 100644 --- a/Lib/test/test_winsound.py +++ b/Lib/test/test_winsound.py @@ -22,7 +22,7 @@ def has_sound(sound): key = winreg.OpenKeyEx(winreg.HKEY_CURRENT_USER, "AppEvents\Schemes\Apps\.Default\{0}\.Default".format(sound)) return winreg.EnumValue(key, 0)[1] != "" - except WindowsError: + except OSError: return False class BeepTest(unittest.TestCase): diff --git a/Lib/test/test_xml_etree.py b/Lib/test/test_xml_etree.py index 7bd8a2c7ee..89971f16ca 100644 --- a/Lib/test/test_xml_etree.py +++ b/Lib/test/test_xml_etree.py @@ -10,6 +10,7 @@ import io import operator import pickle import sys +import types import unittest import weakref @@ -240,7 +241,6 @@ class ElementTreeTest(unittest.TestCase): self.assertEqual(ET.XML, ET.fromstring) self.assertEqual(ET.PI, ET.ProcessingInstruction) - self.assertEqual(ET.XMLParser, ET.XMLTreeBuilder) def test_simpleops(self): # Basic method sanity checks. @@ -433,15 +433,6 @@ class ElementTreeTest(unittest.TestCase): ' <empty-element />\n' '</root>') - parser = ET.XMLTreeBuilder() # 1.2 compatibility - parser.feed(data) - self.serialize_check(parser.close(), - '<root>\n' - ' <element key="value">text</element>\n' - ' <element>text</element>tail\n' - ' <empty-element />\n' - '</root>') - target = ET.TreeBuilder() parser = ET.XMLParser(target=target) parser.feed(data) @@ -706,9 +697,9 @@ class ElementTreeTest(unittest.TestCase): 'iso8859-13', 'iso8859-14', 'iso8859-15', 'iso8859-16', 'cp437', 'cp720', 'cp737', 'cp775', 'cp850', 'cp852', 'cp855', 'cp856', 'cp857', 'cp858', 'cp860', 'cp861', 'cp862', - 'cp863', 'cp865', 'cp866', 'cp869', 'cp874', 'cp1006', 'cp1250', - 'cp1251', 'cp1252', 'cp1253', 'cp1254', 'cp1255', 'cp1256', - 'cp1257', 'cp1258', + 'cp863', 'cp865', 'cp866', 'cp869', 'cp874', 'cp1006', 'cp1125', + 'cp1250', 'cp1251', 'cp1252', 'cp1253', 'cp1254', 'cp1255', + 'cp1256', 'cp1257', 'cp1258', 'mac-cyrillic', 'mac-greek', 'mac-iceland', 'mac-latin2', 'mac-roman', 'mac-turkish', 'iso2022-jp', 'iso2022-jp-1', 'iso2022-jp-2', 'iso2022-jp-2004', @@ -964,6 +955,160 @@ class ElementTreeTest(unittest.TestCase): self.assertEqual(serialized, expected) +class XMLPullParserTest(unittest.TestCase): + + def _feed(self, parser, data, chunk_size=None): + if chunk_size is None: + parser.feed(data) + else: + for i in range(0, len(data), chunk_size): + parser.feed(data[i:i+chunk_size]) + + def assert_event_tags(self, parser, expected): + events = parser.read_events() + self.assertEqual([(action, elem.tag) for action, elem in events], + expected) + + def test_simple_xml(self): + for chunk_size in (None, 1, 5): + with self.subTest(chunk_size=chunk_size): + parser = ET.XMLPullParser() + self.assert_event_tags(parser, []) + self._feed(parser, "<!-- comment -->\n", chunk_size) + self.assert_event_tags(parser, []) + self._feed(parser, + "<root>\n <element key='value'>text</element", + chunk_size) + self.assert_event_tags(parser, []) + self._feed(parser, ">\n", chunk_size) + self.assert_event_tags(parser, [('end', 'element')]) + self._feed(parser, "<element>text</element>tail\n", chunk_size) + self._feed(parser, "<empty-element/>\n", chunk_size) + self.assert_event_tags(parser, [ + ('end', 'element'), + ('end', 'empty-element'), + ]) + self._feed(parser, "</root>\n", chunk_size) + self.assert_event_tags(parser, [('end', 'root')]) + self.assertIsNone(parser.close()) + + def test_feed_while_iterating(self): + parser = ET.XMLPullParser() + it = parser.read_events() + self._feed(parser, "<root>\n <element key='value'>text</element>\n") + action, elem = next(it) + self.assertEqual((action, elem.tag), ('end', 'element')) + self._feed(parser, "</root>\n") + action, elem = next(it) + self.assertEqual((action, elem.tag), ('end', 'root')) + with self.assertRaises(StopIteration): + next(it) + + def test_simple_xml_with_ns(self): + parser = ET.XMLPullParser() + self.assert_event_tags(parser, []) + self._feed(parser, "<!-- comment -->\n") + self.assert_event_tags(parser, []) + self._feed(parser, "<root xmlns='namespace'>\n") + self.assert_event_tags(parser, []) + self._feed(parser, "<element key='value'>text</element") + self.assert_event_tags(parser, []) + self._feed(parser, ">\n") + self.assert_event_tags(parser, [('end', '{namespace}element')]) + self._feed(parser, "<element>text</element>tail\n") + self._feed(parser, "<empty-element/>\n") + self.assert_event_tags(parser, [ + ('end', '{namespace}element'), + ('end', '{namespace}empty-element'), + ]) + self._feed(parser, "</root>\n") + self.assert_event_tags(parser, [('end', '{namespace}root')]) + self.assertIsNone(parser.close()) + + def test_ns_events(self): + parser = ET.XMLPullParser(events=('start-ns', 'end-ns')) + self._feed(parser, "<!-- comment -->\n") + self._feed(parser, "<root xmlns='namespace'>\n") + self.assertEqual( + list(parser.read_events()), + [('start-ns', ('', 'namespace'))]) + self._feed(parser, "<element key='value'>text</element") + self._feed(parser, ">\n") + self._feed(parser, "<element>text</element>tail\n") + self._feed(parser, "<empty-element/>\n") + self._feed(parser, "</root>\n") + self.assertEqual(list(parser.read_events()), [('end-ns', None)]) + self.assertIsNone(parser.close()) + + def test_events(self): + parser = ET.XMLPullParser(events=()) + self._feed(parser, "<root/>\n") + self.assert_event_tags(parser, []) + + parser = ET.XMLPullParser(events=('start', 'end')) + self._feed(parser, "<!-- comment -->\n") + self.assert_event_tags(parser, []) + self._feed(parser, "<root>\n") + self.assert_event_tags(parser, [('start', 'root')]) + self._feed(parser, "<element key='value'>text</element") + self.assert_event_tags(parser, [('start', 'element')]) + self._feed(parser, ">\n") + self.assert_event_tags(parser, [('end', 'element')]) + self._feed(parser, + "<element xmlns='foo'>text<empty-element/></element>tail\n") + self.assert_event_tags(parser, [ + ('start', '{foo}element'), + ('start', '{foo}empty-element'), + ('end', '{foo}empty-element'), + ('end', '{foo}element'), + ]) + self._feed(parser, "</root>") + self.assertIsNone(parser.close()) + self.assert_event_tags(parser, [('end', 'root')]) + + parser = ET.XMLPullParser(events=('start',)) + self._feed(parser, "<!-- comment -->\n") + self.assert_event_tags(parser, []) + self._feed(parser, "<root>\n") + self.assert_event_tags(parser, [('start', 'root')]) + self._feed(parser, "<element key='value'>text</element") + self.assert_event_tags(parser, [('start', 'element')]) + self._feed(parser, ">\n") + self.assert_event_tags(parser, []) + self._feed(parser, + "<element xmlns='foo'>text<empty-element/></element>tail\n") + self.assert_event_tags(parser, [ + ('start', '{foo}element'), + ('start', '{foo}empty-element'), + ]) + self._feed(parser, "</root>") + self.assertIsNone(parser.close()) + + def test_events_sequence(self): + # Test that events can be some sequence that's not just a tuple or list + eventset = {'end', 'start'} + parser = ET.XMLPullParser(events=eventset) + self._feed(parser, "<foo>bar</foo>") + self.assert_event_tags(parser, [('start', 'foo'), ('end', 'foo')]) + + class DummyIter: + def __init__(self): + self.events = iter(['start', 'end', 'start-ns']) + def __iter__(self): + return self + def __next__(self): + return next(self.events) + + parser = ET.XMLPullParser(events=DummyIter()) + self._feed(parser, "<foo>bar</foo>") + self.assert_event_tags(parser, [('start', 'foo'), ('end', 'foo')]) + + + def test_unknown_event(self): + with self.assertRaises(ValueError): + ET.XMLPullParser(events=('start', 'end', 'bogus')) + + # # xinclude tests (samples from appendix C of the xinclude specification) @@ -1305,7 +1450,7 @@ class BugsTest(unittest.TestCase): # Don't crash when using custom entities. ENTITIES = {'rsquo': '\u2019', 'lsquo': '\u2018'} - parser = ET.XMLTreeBuilder() + parser = ET.XMLParser() parser.entity.update(ENTITIES) parser.feed("""<?xml version="1.0" encoding="UTF-8"?> <!DOCTYPE patent-application-publication SYSTEM "pap-v15-2001-01-31.dtd" []> @@ -1643,6 +1788,11 @@ class ElementFindTest(unittest.TestCase): self.assertEqual(e.find('./tag[last()-1]').attrib['class'], 'c') self.assertEqual(e.find('./tag[last()-2]').attrib['class'], 'b') + self.assertRaisesRegex(SyntaxError, 'XPath', e.find, './tag[0]') + self.assertRaisesRegex(SyntaxError, 'XPath', e.find, './tag[-1]') + self.assertRaisesRegex(SyntaxError, 'XPath', e.find, './tag[last()-0]') + self.assertRaisesRegex(SyntaxError, 'XPath', e.find, './tag[last()+1]') + def test_findall(self): e = ET.XML(SAMPLE_XML) e[2] = ET.XML(SAMPLE_SECTION) @@ -1902,7 +2052,7 @@ class TreeBuilderTest(unittest.TestCase): # Mimick SimpleTAL's behaviour (issue #16089): both versions of # TreeBuilder should be able to cope with a subclass of the # pure Python Element class. - base = ET._Element + base = ET._Element_Py # Not from a C extension self.assertEqual(base.__module__, 'xml.etree.ElementTree') # Force some multiple inheritance with a C class to make things @@ -2261,6 +2411,18 @@ class IOTest(unittest.TestCase): ET.tostring(root, 'utf-16'), b''.join(ET.tostringlist(root, 'utf-16'))) + def test_short_empty_elements(self): + root = ET.fromstring('<tag>a<x />b<y></y>c</tag>') + self.assertEqual( + ET.tostring(root, 'unicode'), + '<tag>a<x />b<y />c</tag>') + self.assertEqual( + ET.tostring(root, 'unicode', short_empty_elements=True), + '<tag>a<x />b<y />c</tag>') + self.assertEqual( + ET.tostring(root, 'unicode', short_empty_elements=False), + '<tag>a<x></x>b<y></y>c</tag>') + class ParseErrorTest(unittest.TestCase): def test_subclass(self): @@ -2326,8 +2488,11 @@ class NoAcceleratorTest(unittest.TestCase): # Test that the C accelerator was not imported for pyET def test_correct_import_pyET(self): - self.assertEqual(pyET.Element.__module__, 'xml.etree.ElementTree') - self.assertEqual(pyET.SubElement.__module__, 'xml.etree.ElementTree') + # The type of methods defined in Python code is types.FunctionType, + # while the type of methods defined inside _elementtree is + # <class 'wrapper_descriptor'> + self.assertIsInstance(pyET.Element.__init__, types.FunctionType) + self.assertIsInstance(pyET.XMLParser.__init__, types.FunctionType) # -------------------------------------------------------------------- @@ -2397,6 +2562,7 @@ def test_main(module=None): ElementIterTest, TreeBuilderTest, XMLParserTest, + XMLPullParserTest, BugsTest, ] diff --git a/Lib/test/test_xml_etree_c.py b/Lib/test/test_xml_etree_c.py index d04b7dc97f..816aa8645d 100644 --- a/Lib/test/test_xml_etree_c.py +++ b/Lib/test/test_xml_etree_c.py @@ -2,6 +2,7 @@ import sys, struct from test import support from test.support import import_fresh_module +import types import unittest cET = import_fresh_module('xml.etree.ElementTree', @@ -31,14 +32,22 @@ class TestAliasWorking(unittest.TestCase): @unittest.skipUnless(cET, 'requires _elementtree') +@support.cpython_only class TestAcceleratorImported(unittest.TestCase): # Test that the C accelerator was imported, as expected def test_correct_import_cET(self): + # SubElement is a function so it retains _elementtree as its module. self.assertEqual(cET.SubElement.__module__, '_elementtree') def test_correct_import_cET_alias(self): self.assertEqual(cET_alias.SubElement.__module__, '_elementtree') + def test_parser_comes_from_C(self): + # The type of methods defined in Python code is types.FunctionType, + # while the type of methods defined inside _elementtree is + # <class 'wrapper_descriptor'> + self.assertNotIsInstance(cET.Element.__init__, types.FunctionType) + @unittest.skipUnless(cET, 'requires _elementtree') @support.cpython_only diff --git a/Lib/test/test_xmlrpc.py b/Lib/test/test_xmlrpc.py index cc52259c82..99b3eda13e 100644 --- a/Lib/test/test_xmlrpc.py +++ b/Lib/test/test_xmlrpc.py @@ -15,6 +15,10 @@ import contextlib from test import support try: + import gzip +except ImportError: + gzip = None +try: import threading except ImportError: threading = None @@ -216,7 +220,7 @@ class XMLRPCTestCase(unittest.TestCase): xmlrpc.client.ServerProxy('https://localhost:9999').bad_function() except NotImplementedError: self.assertFalse(has_ssl, "xmlrpc client's error with SSL support") - except socket.error: + except OSError: self.assertTrue(has_ssl) class HelperTestCase(unittest.TestCase): @@ -376,6 +380,11 @@ def http_server(evt, numrequests, requestHandler=None): if name == 'div': return 'This is the div function' + class Fixture: + @staticmethod + def getData(): + return '42' + def my_function(): '''This is my function''' return True @@ -407,7 +416,8 @@ def http_server(evt, numrequests, requestHandler=None): serv.register_function(pow) serv.register_function(lambda x,y: x+y, 'add') serv.register_function(my_function) - serv.register_instance(TestInstanceClass()) + testInstance = TestInstanceClass() + serv.register_instance(testInstance, allow_dotted_names=True) evt.set() # handle up to 'numrequests' requests @@ -500,7 +510,7 @@ def is_unavailable_exception(e): return True exc_mess = e.headers.get('X-exception') except AttributeError: - # Ignore socket.errors here. + # Ignore OSErrors here. exc_mess = str(e) if exc_mess and 'temporarily unavailable' in exc_mess.lower(): @@ -515,7 +525,7 @@ def make_request_and_skipIf(condition, reason): def make_request_and_skip(self): try: xmlrpclib.ServerProxy(URL).my_function() - except (xmlrpclib.ProtocolError, socket.error) as e: + except (xmlrpclib.ProtocolError, OSError) as e: if not is_unavailable_exception(e): raise raise unittest.SkipTest(reason) @@ -553,7 +563,7 @@ class SimpleServerTestCase(BaseServerTestCase): try: p = xmlrpclib.ServerProxy(URL) self.assertEqual(p.pow(6,8), 6**8) - except (xmlrpclib.ProtocolError, socket.error) as e: + except (xmlrpclib.ProtocolError, OSError) as e: # ignore failures due to non-blocking socket 'unavailable' errors if not is_unavailable_exception(e): # protocol error; provide additional information in test output @@ -566,7 +576,7 @@ class SimpleServerTestCase(BaseServerTestCase): p = xmlrpclib.ServerProxy(URL) self.assertEqual(p.add(start_string, end_string), start_string + end_string) - except (xmlrpclib.ProtocolError, socket.error) as e: + except (xmlrpclib.ProtocolError, OSError) as e: # ignore failures due to non-blocking socket 'unavailable' errors if not is_unavailable_exception(e): # protocol error; provide additional information in test output @@ -587,12 +597,13 @@ class SimpleServerTestCase(BaseServerTestCase): def test_introspection1(self): expected_methods = set(['pow', 'div', 'my_function', 'add', 'system.listMethods', 'system.methodHelp', - 'system.methodSignature', 'system.multicall']) + 'system.methodSignature', 'system.multicall', + 'Fixture']) try: p = xmlrpclib.ServerProxy(URL) meth = p.system.listMethods() self.assertEqual(set(meth), expected_methods) - except (xmlrpclib.ProtocolError, socket.error) as e: + except (xmlrpclib.ProtocolError, OSError) as e: # ignore failures due to non-blocking socket 'unavailable' errors if not is_unavailable_exception(e): # protocol error; provide additional information in test output @@ -605,7 +616,7 @@ class SimpleServerTestCase(BaseServerTestCase): p = xmlrpclib.ServerProxy(URL) divhelp = p.system.methodHelp('div') self.assertEqual(divhelp, 'This is the div function') - except (xmlrpclib.ProtocolError, socket.error) as e: + except (xmlrpclib.ProtocolError, OSError) as e: # ignore failures due to non-blocking socket 'unavailable' errors if not is_unavailable_exception(e): # protocol error; provide additional information in test output @@ -619,7 +630,7 @@ class SimpleServerTestCase(BaseServerTestCase): p = xmlrpclib.ServerProxy(URL) myfunction = p.system.methodHelp('my_function') self.assertEqual(myfunction, 'This is my function') - except (xmlrpclib.ProtocolError, socket.error) as e: + except (xmlrpclib.ProtocolError, OSError) as e: # ignore failures due to non-blocking socket 'unavailable' errors if not is_unavailable_exception(e): # protocol error; provide additional information in test output @@ -632,7 +643,7 @@ class SimpleServerTestCase(BaseServerTestCase): p = xmlrpclib.ServerProxy(URL) divsig = p.system.methodSignature('div') self.assertEqual(divsig, 'signatures not supported') - except (xmlrpclib.ProtocolError, socket.error) as e: + except (xmlrpclib.ProtocolError, OSError) as e: # ignore failures due to non-blocking socket 'unavailable' errors if not is_unavailable_exception(e): # protocol error; provide additional information in test output @@ -649,7 +660,7 @@ class SimpleServerTestCase(BaseServerTestCase): self.assertEqual(add_result, 2+3) self.assertEqual(pow_result, 6**8) self.assertEqual(div_result, 127//42) - except (xmlrpclib.ProtocolError, socket.error) as e: + except (xmlrpclib.ProtocolError, OSError) as e: # ignore failures due to non-blocking socket 'unavailable' errors if not is_unavailable_exception(e): # protocol error; provide additional information in test output @@ -670,7 +681,7 @@ class SimpleServerTestCase(BaseServerTestCase): self.assertEqual(result.results[0]['faultString'], '<class \'Exception\'>:method "this_is_not_exists" ' 'is not supported') - except (xmlrpclib.ProtocolError, socket.error) as e: + except (xmlrpclib.ProtocolError, OSError) as e: # ignore failures due to non-blocking socket 'unavailable' errors if not is_unavailable_exception(e): # protocol error; provide additional information in test output @@ -686,6 +697,12 @@ class SimpleServerTestCase(BaseServerTestCase): # This avoids waiting for the socket timeout. self.test_simple1() + def test_allow_dotted_names_true(self): + # XXX also need allow_dotted_names_false test. + server = xmlrpclib.ServerProxy("http://%s:%d/RPC2" % (ADDR, PORT)) + data = server.Fixture.getData() + self.assertEqual(data, '42') + def test_unicode_host(self): server = xmlrpclib.ServerProxy("http://%s:%d/RPC2" % (ADDR, PORT)) self.assertEqual(server.add("a", "\xe9"), "a\xe9") @@ -793,6 +810,7 @@ class KeepaliveServerTestCase2(BaseKeepaliveServerTestCase): #A test case that verifies that gzip encoding works in both directions #(for a request and the response) +@unittest.skipIf(gzip is None, 'requires gzip') class GzipServerTestCase(BaseServerTestCase): #a request handler that supports keep-alive and logs requests into a #class variable @@ -923,7 +941,7 @@ class FailingServerTestCase(unittest.TestCase): try: p = xmlrpclib.ServerProxy(URL) self.assertEqual(p.pow(6,8), 6**8) - except (xmlrpclib.ProtocolError, socket.error) as e: + except (xmlrpclib.ProtocolError, OSError) as e: # ignore failures due to non-blocking socket 'unavailable' errors if not is_unavailable_exception(e): # protocol error; provide additional information in test output @@ -936,7 +954,7 @@ class FailingServerTestCase(unittest.TestCase): try: p = xmlrpclib.ServerProxy(URL) p.pow(6,8) - except (xmlrpclib.ProtocolError, socket.error) as e: + except (xmlrpclib.ProtocolError, OSError) as e: # ignore failures due to non-blocking socket 'unavailable' errors if not is_unavailable_exception(e) and hasattr(e, "headers"): # The two server-side error headers shouldn't be sent back in this case @@ -956,7 +974,7 @@ class FailingServerTestCase(unittest.TestCase): try: p = xmlrpclib.ServerProxy(URL) p.pow(6,8) - except (xmlrpclib.ProtocolError, socket.error) as e: + except (xmlrpclib.ProtocolError, OSError) as e: # ignore failures due to non-blocking socket 'unavailable' errors if not is_unavailable_exception(e) and hasattr(e, "headers"): # We should get error info in the response @@ -1084,23 +1102,13 @@ class UseBuiltinTypesTestCase(unittest.TestCase): @support.reap_threads def test_main(): - xmlrpc_tests = [XMLRPCTestCase, HelperTestCase, DateTimeTestCase, - BinaryTestCase, FaultTestCase] - xmlrpc_tests.append(UseBuiltinTypesTestCase) - xmlrpc_tests.append(SimpleServerTestCase) - xmlrpc_tests.append(KeepaliveServerTestCase1) - xmlrpc_tests.append(KeepaliveServerTestCase2) - try: - import gzip - xmlrpc_tests.append(GzipServerTestCase) - except ImportError: - pass #gzip not supported in this build - xmlrpc_tests.append(MultiPathServerTestCase) - xmlrpc_tests.append(ServerProxyTestCase) - xmlrpc_tests.append(FailingServerTestCase) - xmlrpc_tests.append(CGIHandlerTestCase) - - support.run_unittest(*xmlrpc_tests) + support.run_unittest(XMLRPCTestCase, HelperTestCase, DateTimeTestCase, + BinaryTestCase, FaultTestCase, UseBuiltinTypesTestCase, + SimpleServerTestCase, KeepaliveServerTestCase1, + KeepaliveServerTestCase2, GzipServerTestCase, + MultiPathServerTestCase, ServerProxyTestCase, FailingServerTestCase, + CGIHandlerTestCase) + if __name__ == "__main__": test_main() diff --git a/Lib/test/test_xmlrpc_net.py b/Lib/test/test_xmlrpc_net.py index 77f66ea57b..b60b82f3e2 100644 --- a/Lib/test/test_xmlrpc_net.py +++ b/Lib/test/test_xmlrpc_net.py @@ -7,31 +7,7 @@ from test import support import xmlrpc.client as xmlrpclib -class CurrentTimeTest(unittest.TestCase): - - def test_current_time(self): - # Get the current time from xmlrpc.com. This code exercises - # the minimal HTTP functionality in xmlrpclib. - self.skipTest("time.xmlrpc.com is unreliable") - server = xmlrpclib.ServerProxy("http://time.xmlrpc.com/RPC2") - try: - t0 = server.currentTime.getCurrentTime() - except socket.error as e: - self.skipTest("network error: %s" % e) - - # Perform a minimal sanity check on the result, just to be sure - # the request means what we think it means. - t1 = xmlrpclib.DateTime() - - dt0 = xmlrpclib._datetime_type(t0.value) - dt1 = xmlrpclib._datetime_type(t1.value) - if dt0 > dt1: - delta = dt0 - dt1 - else: - delta = dt1 - dt0 - # The difference between the system time here and the system - # time on the server should not be too big. - self.assertTrue(delta.days <= 1) +class PythonBuildersTest(unittest.TestCase): def test_python_builders(self): # Get the list of builders from the XMLRPC buildbot interface at @@ -39,7 +15,7 @@ class CurrentTimeTest(unittest.TestCase): server = xmlrpclib.ServerProxy("http://buildbot.python.org/all/xmlrpc/") try: builders = server.getAllBuilders() - except socket.error as e: + except OSError as e: self.skipTest("network error: %s" % e) self.addCleanup(lambda: server('close')()) @@ -51,7 +27,7 @@ class CurrentTimeTest(unittest.TestCase): def test_main(): support.requires("network") - support.run_unittest(CurrentTimeTest) + support.run_unittest(PythonBuildersTest) if __name__ == "__main__": test_main() diff --git a/Lib/test/test_zipfile.py b/Lib/test/test_zipfile.py index 12a0f71f5d..1bef5750a1 100644 --- a/Lib/test/test_zipfile.py +++ b/Lib/test/test_zipfile.py @@ -1,7 +1,7 @@ import io import os import sys -import imp +import importlib.util import time import shutil import struct @@ -14,7 +14,7 @@ from random import randint, random, getrandbits from test.support import (TESTFN, findfile, unlink, requires_zlib, requires_bz2, requires_lzma, - captured_stdout) + captured_stdout, check_warnings) TESTFN2 = TESTFN + "2" TESTFNDIR = TESTFN + "d" @@ -35,6 +35,10 @@ def get_files(test): yield f test.assertFalse(f.closed) +def openU(zipfp, fn): + with check_warnings(('', DeprecationWarning)): + return zipfp.open(fn, 'rU') + class AbstractTestsWithSourceFile: @classmethod def setUpClass(cls): @@ -537,12 +541,12 @@ class StoredTestZip64InSmallFiles(AbstractTestZip64InSmallFiles, compression = zipfile.ZIP_STORED def large_file_exception_test(self, f, compression): - with zipfile.ZipFile(f, "w", compression) as zipfp: + with zipfile.ZipFile(f, "w", compression, allowZip64=False) as zipfp: self.assertRaises(zipfile.LargeZipFile, zipfp.write, TESTFN, "another.name") def large_file_exception_test2(self, f, compression): - with zipfile.ZipFile(f, "w", compression) as zipfp: + with zipfile.ZipFile(f, "w", compression, allowZip64=False) as zipfp: self.assertRaises(zipfile.LargeZipFile, zipfp.writestr, "another.name", self.data) @@ -588,7 +592,7 @@ class PyZipFileTests(unittest.TestCase): if os.altsep is not None: path_split.extend(fn.split(os.altsep)) if '__pycache__' in path_split: - fn = imp.source_from_cache(fn) + fn = importlib.util.source_from_cache(fn) else: fn = fn[:-1] @@ -622,6 +626,34 @@ class PyZipFileTests(unittest.TestCase): self.assertCompiledIn('email/__init__.py', names) self.assertCompiledIn('email/mime/text.py', names) + def test_write_filtered_python_package(self): + import test + packagedir = os.path.dirname(test.__file__) + + with TemporaryFile() as t, zipfile.PyZipFile(t, "w") as zipfp: + + # first make sure that the test folder gives error messages + # (on the badsyntax_... files) + with captured_stdout() as reportSIO: + zipfp.writepy(packagedir) + reportStr = reportSIO.getvalue() + self.assertTrue('SyntaxError' in reportStr) + + # then check that the filter works on the whole package + with captured_stdout() as reportSIO: + zipfp.writepy(packagedir, filterfunc=lambda whatever: False) + reportStr = reportSIO.getvalue() + self.assertTrue('SyntaxError' not in reportStr) + + # then check that the filter works on individual files + with captured_stdout() as reportSIO, self.assertWarns(UserWarning): + zipfp.writepy(packagedir, filterfunc=lambda fn: + 'bad' not in fn) + reportStr = reportSIO.getvalue() + if reportStr: + print(reportStr) + self.assertTrue('SyntaxError' not in reportStr) + def test_write_with_optimization(self): import email packagedir = os.path.dirname(email.__file__) @@ -631,7 +663,7 @@ class PyZipFileTests(unittest.TestCase): ext = '.pyo' if optlevel == 1 else '.pyc' with TemporaryFile() as t, \ - zipfile.PyZipFile(t, "w", optimize=optlevel) as zipfp: + zipfile.PyZipFile(t, "w", optimize=optlevel) as zipfp: zipfp.writepy(packagedir) names = zipfp.namelist() @@ -661,6 +693,26 @@ class PyZipFileTests(unittest.TestCase): finally: shutil.rmtree(TESTFN2) + def test_write_python_directory_filtered(self): + os.mkdir(TESTFN2) + try: + with open(os.path.join(TESTFN2, "mod1.py"), "w") as fp: + fp.write("print(42)\n") + + with open(os.path.join(TESTFN2, "mod2.py"), "w") as fp: + fp.write("print(42 * 42)\n") + + with TemporaryFile() as t, zipfile.PyZipFile(t, "w") as zipfp: + zipfp.writepy(TESTFN2, filterfunc=lambda fn: + not fn.endswith('mod2.py')) + + names = zipfp.namelist() + self.assertCompiledIn('mod1.py', names) + self.assertNotIn('mod2.py', names) + + finally: + shutil.rmtree(TESTFN2) + def test_write_non_pyfile(self): with TemporaryFile() as t, zipfile.PyZipFile(t, "w") as zipfp: with open(TESTFN, 'w') as f: @@ -764,25 +816,25 @@ class ExtractTests(unittest.TestCase): def test_extract_hackers_arcnames_windows_only(self): """Test combination of path fixing and windows name sanitization.""" windows_hacknames = [ - (r'..\foo\bar', 'foo/bar'), - (r'..\/foo\/bar', 'foo/bar'), - (r'foo/\..\/bar', 'foo/bar'), - (r'foo\/../\bar', 'foo/bar'), - (r'C:foo/bar', 'foo/bar'), - (r'C:/foo/bar', 'foo/bar'), - (r'C://foo/bar', 'foo/bar'), - (r'C:\foo\bar', 'foo/bar'), - (r'//conky/mountpoint/foo/bar', 'foo/bar'), - (r'\\conky\mountpoint\foo\bar', 'foo/bar'), - (r'///conky/mountpoint/foo/bar', 'conky/mountpoint/foo/bar'), - (r'\\\conky\mountpoint\foo\bar', 'conky/mountpoint/foo/bar'), - (r'//conky//mountpoint/foo/bar', 'conky/mountpoint/foo/bar'), - (r'\\conky\\mountpoint\foo\bar', 'conky/mountpoint/foo/bar'), - (r'//?/C:/foo/bar', 'foo/bar'), - (r'\\?\C:\foo\bar', 'foo/bar'), - (r'C:/../C:/foo/bar', 'C_/foo/bar'), - (r'a:b\c<d>e|f"g?h*i', 'b/c_d_e_f_g_h_i'), - ('../../foo../../ba..r', 'foo/ba..r'), + (r'..\foo\bar', 'foo/bar'), + (r'..\/foo\/bar', 'foo/bar'), + (r'foo/\..\/bar', 'foo/bar'), + (r'foo\/../\bar', 'foo/bar'), + (r'C:foo/bar', 'foo/bar'), + (r'C:/foo/bar', 'foo/bar'), + (r'C://foo/bar', 'foo/bar'), + (r'C:\foo\bar', 'foo/bar'), + (r'//conky/mountpoint/foo/bar', 'foo/bar'), + (r'\\conky\mountpoint\foo\bar', 'foo/bar'), + (r'///conky/mountpoint/foo/bar', 'conky/mountpoint/foo/bar'), + (r'\\\conky\mountpoint\foo\bar', 'conky/mountpoint/foo/bar'), + (r'//conky//mountpoint/foo/bar', 'conky/mountpoint/foo/bar'), + (r'\\conky\\mountpoint\foo\bar', 'conky/mountpoint/foo/bar'), + (r'//?/C:/foo/bar', 'foo/bar'), + (r'\\?\C:\foo\bar', 'foo/bar'), + (r'C:/../C:/foo/bar', 'C_/foo/bar'), + (r'a:b\c<d>e|f"g?h*i', 'b/c_d_e_f_g_h_i'), + ('../../foo../../ba..r', 'foo/ba..r'), ] self._test_extract_hackers_arcnames(windows_hacknames) @@ -860,6 +912,17 @@ class OtherTests(unittest.TestCase): data += zipfp.read(info) self.assertIn(data, {b"foobar", b"barfoo"}) + def test_universal_deprecation(self): + f = io.BytesIO() + with zipfile.ZipFile(f, "w") as zipfp: + zipfp.writestr('spam.txt', b'ababagalamaga') + + with zipfile.ZipFile(f, "r") as zipfp: + for mode in 'U', 'rU': + with self.assertWarns(DeprecationWarning): + zipopen = zipfp.open('spam.txt', mode) + zipopen.close() + def test_universal_readaheads(self): f = io.BytesIO() @@ -869,7 +932,7 @@ class OtherTests(unittest.TestCase): data2 = b'' with zipfile.ZipFile(f, 'r') as zipfp, \ - zipfp.open(TESTFN, 'rU') as zipopen: + openU(zipfp, TESTFN) as zipopen: for line in zipopen: data2 += line @@ -910,10 +973,10 @@ class OtherTests(unittest.TestCase): def test_unsupported_version(self): # File has an extract_version of 120 data = (b'PK\x03\x04x\x00\x00\x00\x00\x00!p\xa1@\x00\x00\x00\x00\x00\x00' - b'\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00xPK\x01\x02x\x03x\x00\x00\x00\x00' - b'\x00!p\xa1@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00' - b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x80\x01\x00\x00\x00\x00xPK\x05\x06' - b'\x00\x00\x00\x00\x01\x00\x01\x00/\x00\x00\x00\x1f\x00\x00\x00\x00\x00') + b'\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00xPK\x01\x02x\x03x\x00\x00\x00\x00' + b'\x00!p\xa1@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00' + b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x80\x01\x00\x00\x00\x00xPK\x05\x06' + b'\x00\x00\x00\x00\x01\x00\x01\x00/\x00\x00\x00\x1f\x00\x00\x00\x00\x00') self.assertRaises(NotImplementedError, zipfile.ZipFile, io.BytesIO(data), 'r') @@ -946,7 +1009,7 @@ class OtherTests(unittest.TestCase): try: with zipfile.ZipFile(TESTFN, 'a') as zf: zf.writestr(filename, content) - except IOError: + except OSError: self.fail('Could not append data to a non-existent zip file.') self.assertTrue(os.path.exists(TESTFN)) @@ -1018,7 +1081,7 @@ class OtherTests(unittest.TestCase): fp.seek(0, 0) self.assertTrue(zipfile.is_zipfile(fp)) - def test_non_existent_file_raises_IOError(self): + def test_non_existent_file_raises_OSError(self): # make sure we don't raise an AttributeError when a partially-constructed # ZipFile instance is finalized; this tests for regression on SF tracker # bug #403871. @@ -1030,7 +1093,7 @@ class OtherTests(unittest.TestCase): # it is ignored, but the user should be sufficiently annoyed by # the message on the output that regression will be noticed # quickly. - self.assertRaises(IOError, zipfile.ZipFile, TESTFN) + self.assertRaises(OSError, zipfile.ZipFile, TESTFN) def test_empty_file_raises_BadZipFile(self): f = open(TESTFN, 'w') @@ -1099,11 +1162,11 @@ class OtherTests(unittest.TestCase): def test_unsupported_compression(self): # data is declared as shrunk, but actually deflated data = (b'PK\x03\x04.\x00\x00\x00\x01\x00\xe4C\xa1@\x00\x00\x00' - b'\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00x\x03\x00PK\x01' - b'\x02.\x03.\x00\x00\x00\x01\x00\xe4C\xa1@\x00\x00\x00\x00\x02\x00\x00' - b'\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00' - b'\x80\x01\x00\x00\x00\x00xPK\x05\x06\x00\x00\x00\x00\x01\x00\x01\x00' - b'/\x00\x00\x00!\x00\x00\x00\x00\x00') + b'\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00x\x03\x00PK\x01' + b'\x02.\x03.\x00\x00\x00\x01\x00\xe4C\xa1@\x00\x00\x00\x00\x02\x00\x00' + b'\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00' + b'\x80\x01\x00\x00\x00\x00xPK\x05\x06\x00\x00\x00\x00\x01\x00\x01\x00' + b'/\x00\x00\x00!\x00\x00\x00\x00\x00') with zipfile.ZipFile(io.BytesIO(data), 'r') as zipf: self.assertRaises(NotImplementedError, zipf.open, 'x') @@ -1218,7 +1281,7 @@ class OtherTests(unittest.TestCase): def test_open_empty_file(self): # Issue 1710703: Check that opening a file with less than 22 bytes # raises a BadZipFile exception (rather than the previously unhelpful - # IOError) + # OSError) f = open(TESTFN, 'w') f.close() self.assertRaises(zipfile.BadZipFile, zipfile.ZipFile, TESTFN, 'r') @@ -1266,57 +1329,57 @@ class AbstractBadCrcTests: class StoredBadCrcTests(AbstractBadCrcTests, unittest.TestCase): compression = zipfile.ZIP_STORED zip_with_bad_crc = ( - b'PK\003\004\024\0\0\0\0\0 \213\212;:r' - b'\253\377\f\0\0\0\f\0\0\0\005\0\0\000af' - b'ilehello,AworldP' - b'K\001\002\024\003\024\0\0\0\0\0 \213\212;:' - b'r\253\377\f\0\0\0\f\0\0\0\005\0\0\0\0' - b'\0\0\0\0\0\0\0\200\001\0\0\0\000afi' - b'lePK\005\006\0\0\0\0\001\0\001\0003\000' - b'\0\0/\0\0\0\0\0') + b'PK\003\004\024\0\0\0\0\0 \213\212;:r' + b'\253\377\f\0\0\0\f\0\0\0\005\0\0\000af' + b'ilehello,AworldP' + b'K\001\002\024\003\024\0\0\0\0\0 \213\212;:' + b'r\253\377\f\0\0\0\f\0\0\0\005\0\0\0\0' + b'\0\0\0\0\0\0\0\200\001\0\0\0\000afi' + b'lePK\005\006\0\0\0\0\001\0\001\0003\000' + b'\0\0/\0\0\0\0\0') @requires_zlib class DeflateBadCrcTests(AbstractBadCrcTests, unittest.TestCase): compression = zipfile.ZIP_DEFLATED zip_with_bad_crc = ( - b'PK\x03\x04\x14\x00\x00\x00\x08\x00n}\x0c=FA' - b'KE\x10\x00\x00\x00n\x00\x00\x00\x05\x00\x00\x00af' - b'ile\xcbH\xcd\xc9\xc9W(\xcf/\xcaI\xc9\xa0' - b'=\x13\x00PK\x01\x02\x14\x03\x14\x00\x00\x00\x08\x00n' - b'}\x0c=FAKE\x10\x00\x00\x00n\x00\x00\x00\x05' - b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x80\x01\x00\x00\x00' - b'\x00afilePK\x05\x06\x00\x00\x00\x00\x01\x00' - b'\x01\x003\x00\x00\x003\x00\x00\x00\x00\x00') + b'PK\x03\x04\x14\x00\x00\x00\x08\x00n}\x0c=FA' + b'KE\x10\x00\x00\x00n\x00\x00\x00\x05\x00\x00\x00af' + b'ile\xcbH\xcd\xc9\xc9W(\xcf/\xcaI\xc9\xa0' + b'=\x13\x00PK\x01\x02\x14\x03\x14\x00\x00\x00\x08\x00n' + b'}\x0c=FAKE\x10\x00\x00\x00n\x00\x00\x00\x05' + b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x80\x01\x00\x00\x00' + b'\x00afilePK\x05\x06\x00\x00\x00\x00\x01\x00' + b'\x01\x003\x00\x00\x003\x00\x00\x00\x00\x00') @requires_bz2 class Bzip2BadCrcTests(AbstractBadCrcTests, unittest.TestCase): compression = zipfile.ZIP_BZIP2 zip_with_bad_crc = ( - b'PK\x03\x04\x14\x03\x00\x00\x0c\x00nu\x0c=FA' - b'KE8\x00\x00\x00n\x00\x00\x00\x05\x00\x00\x00af' - b'ileBZh91AY&SY\xd4\xa8\xca' - b'\x7f\x00\x00\x0f\x11\x80@\x00\x06D\x90\x80 \x00 \xa5' - b'P\xd9!\x03\x03\x13\x13\x13\x89\xa9\xa9\xc2u5:\x9f' - b'\x8b\xb9"\x9c(HjTe?\x80PK\x01\x02\x14' - b'\x03\x14\x03\x00\x00\x0c\x00nu\x0c=FAKE8' - b'\x00\x00\x00n\x00\x00\x00\x05\x00\x00\x00\x00\x00\x00\x00\x00' - b'\x00 \x80\x80\x81\x00\x00\x00\x00afilePK' - b'\x05\x06\x00\x00\x00\x00\x01\x00\x01\x003\x00\x00\x00[\x00' - b'\x00\x00\x00\x00') + b'PK\x03\x04\x14\x03\x00\x00\x0c\x00nu\x0c=FA' + b'KE8\x00\x00\x00n\x00\x00\x00\x05\x00\x00\x00af' + b'ileBZh91AY&SY\xd4\xa8\xca' + b'\x7f\x00\x00\x0f\x11\x80@\x00\x06D\x90\x80 \x00 \xa5' + b'P\xd9!\x03\x03\x13\x13\x13\x89\xa9\xa9\xc2u5:\x9f' + b'\x8b\xb9"\x9c(HjTe?\x80PK\x01\x02\x14' + b'\x03\x14\x03\x00\x00\x0c\x00nu\x0c=FAKE8' + b'\x00\x00\x00n\x00\x00\x00\x05\x00\x00\x00\x00\x00\x00\x00\x00' + b'\x00 \x80\x80\x81\x00\x00\x00\x00afilePK' + b'\x05\x06\x00\x00\x00\x00\x01\x00\x01\x003\x00\x00\x00[\x00' + b'\x00\x00\x00\x00') @requires_lzma class LzmaBadCrcTests(AbstractBadCrcTests, unittest.TestCase): compression = zipfile.ZIP_LZMA zip_with_bad_crc = ( - b'PK\x03\x04\x14\x03\x00\x00\x0e\x00nu\x0c=FA' - b'KE\x1b\x00\x00\x00n\x00\x00\x00\x05\x00\x00\x00af' - b'ile\t\x04\x05\x00]\x00\x00\x00\x04\x004\x19I' - b'\xee\x8d\xe9\x17\x89:3`\tq!.8\x00PK' - b'\x01\x02\x14\x03\x14\x03\x00\x00\x0e\x00nu\x0c=FA' - b'KE\x1b\x00\x00\x00n\x00\x00\x00\x05\x00\x00\x00\x00\x00' - b'\x00\x00\x00\x00 \x80\x80\x81\x00\x00\x00\x00afil' - b'ePK\x05\x06\x00\x00\x00\x00\x01\x00\x01\x003\x00\x00' - b'\x00>\x00\x00\x00\x00\x00') + b'PK\x03\x04\x14\x03\x00\x00\x0e\x00nu\x0c=FA' + b'KE\x1b\x00\x00\x00n\x00\x00\x00\x05\x00\x00\x00af' + b'ile\t\x04\x05\x00]\x00\x00\x00\x04\x004\x19I' + b'\xee\x8d\xe9\x17\x89:3`\tq!.8\x00PK' + b'\x01\x02\x14\x03\x14\x03\x00\x00\x0e\x00nu\x0c=FA' + b'KE\x1b\x00\x00\x00n\x00\x00\x00\x05\x00\x00\x00\x00\x00' + b'\x00\x00\x00\x00 \x80\x80\x81\x00\x00\x00\x00afil' + b'ePK\x05\x06\x00\x00\x00\x00\x01\x00\x01\x003\x00\x00' + b'\x00>\x00\x00\x00\x00\x00') class DecryptionTests(unittest.TestCase): @@ -1325,22 +1388,22 @@ class DecryptionTests(unittest.TestCase): ZIP file.""" data = ( - b'PK\x03\x04\x14\x00\x01\x00\x00\x00n\x92i.#y\xef?&\x00\x00\x00\x1a\x00' - b'\x00\x00\x08\x00\x00\x00test.txt\xfa\x10\xa0gly|\xfa-\xc5\xc0=\xf9y' - b'\x18\xe0\xa8r\xb3Z}Lg\xbc\xae\xf9|\x9b\x19\xe4\x8b\xba\xbb)\x8c\xb0\xdbl' - b'PK\x01\x02\x14\x00\x14\x00\x01\x00\x00\x00n\x92i.#y\xef?&\x00\x00\x00' - b'\x1a\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x00\x01\x00 \x00\xb6\x81' - b'\x00\x00\x00\x00test.txtPK\x05\x06\x00\x00\x00\x00\x01\x00\x01\x006\x00' - b'\x00\x00L\x00\x00\x00\x00\x00' ) + b'PK\x03\x04\x14\x00\x01\x00\x00\x00n\x92i.#y\xef?&\x00\x00\x00\x1a\x00' + b'\x00\x00\x08\x00\x00\x00test.txt\xfa\x10\xa0gly|\xfa-\xc5\xc0=\xf9y' + b'\x18\xe0\xa8r\xb3Z}Lg\xbc\xae\xf9|\x9b\x19\xe4\x8b\xba\xbb)\x8c\xb0\xdbl' + b'PK\x01\x02\x14\x00\x14\x00\x01\x00\x00\x00n\x92i.#y\xef?&\x00\x00\x00' + b'\x1a\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x00\x01\x00 \x00\xb6\x81' + b'\x00\x00\x00\x00test.txtPK\x05\x06\x00\x00\x00\x00\x01\x00\x01\x006\x00' + b'\x00\x00L\x00\x00\x00\x00\x00' ) data2 = ( - b'PK\x03\x04\x14\x00\t\x00\x08\x00\xcf}38xu\xaa\xb2\x14\x00\x00\x00\x00\x02' - b'\x00\x00\x04\x00\x15\x00zeroUT\t\x00\x03\xd6\x8b\x92G\xda\x8b\x92GUx\x04' - b'\x00\xe8\x03\xe8\x03\xc7<M\xb5a\xceX\xa3Y&\x8b{oE\xd7\x9d\x8c\x98\x02\xc0' - b'PK\x07\x08xu\xaa\xb2\x14\x00\x00\x00\x00\x02\x00\x00PK\x01\x02\x17\x03' - b'\x14\x00\t\x00\x08\x00\xcf}38xu\xaa\xb2\x14\x00\x00\x00\x00\x02\x00\x00' - b'\x04\x00\r\x00\x00\x00\x00\x00\x00\x00\x00\x00\xa4\x81\x00\x00\x00\x00ze' - b'roUT\x05\x00\x03\xd6\x8b\x92GUx\x00\x00PK\x05\x06\x00\x00\x00\x00\x01' - b'\x00\x01\x00?\x00\x00\x00[\x00\x00\x00\x00\x00' ) + b'PK\x03\x04\x14\x00\t\x00\x08\x00\xcf}38xu\xaa\xb2\x14\x00\x00\x00\x00\x02' + b'\x00\x00\x04\x00\x15\x00zeroUT\t\x00\x03\xd6\x8b\x92G\xda\x8b\x92GUx\x04' + b'\x00\xe8\x03\xe8\x03\xc7<M\xb5a\xceX\xa3Y&\x8b{oE\xd7\x9d\x8c\x98\x02\xc0' + b'PK\x07\x08xu\xaa\xb2\x14\x00\x00\x00\x00\x02\x00\x00PK\x01\x02\x17\x03' + b'\x14\x00\t\x00\x08\x00\xcf}38xu\xaa\xb2\x14\x00\x00\x00\x00\x02\x00\x00' + b'\x04\x00\r\x00\x00\x00\x00\x00\x00\x00\x00\x00\xa4\x81\x00\x00\x00\x00ze' + b'roUT\x05\x00\x03\xd6\x8b\x92GUx\x00\x00PK\x05\x06\x00\x00\x00\x00\x01' + b'\x00\x01\x00?\x00\x00\x00[\x00\x00\x00\x00\x00' ) plain = b'zipfile.py encryption test' plain2 = b'\x00'*512 @@ -1599,7 +1662,7 @@ class AbstractUniversalNewlineTests: # Read the ZIP archive with zipfile.ZipFile(f, "r") as zipfp: for sep, fn in self.arcfiles.items(): - with zipfp.open(fn, "rU") as fp: + with openU(zipfp, fn) as fp: zipdata = fp.read() self.assertEqual(self.arcdata[sep], zipdata) @@ -1613,7 +1676,7 @@ class AbstractUniversalNewlineTests: # Read the ZIP archive with zipfile.ZipFile(f, "r") as zipfp: for sep, fn in self.arcfiles.items(): - with zipfp.open(fn, "rU") as zipopen: + with openU(zipfp, fn) as zipopen: data = b'' while True: read = zipopen.readline() @@ -1638,7 +1701,7 @@ class AbstractUniversalNewlineTests: # Read the ZIP archive with zipfile.ZipFile(f, "r") as zipfp: for sep, fn in self.arcfiles.items(): - with zipfp.open(fn, "rU") as zipopen: + with openU(zipfp, fn) as zipopen: for line in self.line_gen: linedata = zipopen.readline() self.assertEqual(linedata, line + b'\n') @@ -1653,7 +1716,7 @@ class AbstractUniversalNewlineTests: # Read the ZIP archive with zipfile.ZipFile(f, "r") as zipfp: for sep, fn in self.arcfiles.items(): - with zipfp.open(fn, "rU") as fp: + with openU(zipfp, fn) as fp: ziplines = fp.readlines() for line, zipline in zip(self.line_gen, ziplines): self.assertEqual(zipline, line + b'\n') @@ -1668,7 +1731,7 @@ class AbstractUniversalNewlineTests: # Read the ZIP archive with zipfile.ZipFile(f, "r") as zipfp: for sep, fn in self.arcfiles.items(): - with zipfp.open(fn, "rU") as fp: + with openU(zipfp, fn) as fp: for line, zipline in zip(self.line_gen, fp): self.assertEqual(zipline, line + b'\n') @@ -1702,6 +1765,5 @@ class LzmaUniversalNewlineTests(AbstractUniversalNewlineTests, unittest.TestCase): compression = zipfile.ZIP_LZMA - if __name__ == "__main__": unittest.main() diff --git a/Lib/test/test_zipfile64.py b/Lib/test/test_zipfile64.py index a8fb7aba54..498d464db3 100644 --- a/Lib/test/test_zipfile64.py +++ b/Lib/test/test_zipfile64.py @@ -38,7 +38,7 @@ class TestsWithSourceFile(unittest.TestCase): def zipTest(self, f, compression): # Create the ZIP archive. - zipfp = zipfile.ZipFile(f, "w", compression, allowZip64=True) + zipfp = zipfile.ZipFile(f, "w", compression) # It will contain enough copies of self.data to reach about 6GB of # raw data to store. @@ -92,7 +92,7 @@ class OtherTests(unittest.TestCase): def testMoreThan64kFiles(self): # This test checks that more than 64k files can be added to an archive, # and that the resulting archive can be read properly by ZipFile - zipf = zipfile.ZipFile(TESTFN, mode="w") + zipf = zipfile.ZipFile(TESTFN, mode="w", allowZip64=False) zipf.debug = 100 numfiles = (1 << 16) * 3//2 for i in range(numfiles): diff --git a/Lib/test/test_zipimport.py b/Lib/test/test_zipimport.py index 37603b9bcd..1e351c8c8a 100644 --- a/Lib/test/test_zipimport.py +++ b/Lib/test/test_zipimport.py @@ -1,13 +1,12 @@ import sys import os import marshal -import imp +import importlib.util import struct import time import unittest from test import support -from test.test_importhooks import ImportHooksBaseTestCase, test_src, test_co from zipfile import ZipFile, ZipInfo, ZIP_STORED, ZIP_DEFLATED @@ -17,6 +16,14 @@ import doctest import inspect import io from traceback import extract_tb, extract_stack, print_tb + +test_src = """\ +def get_name(): + return __name__ +def get_file(): + return __file__ +""" +test_co = compile(test_src, "<???>", "exec") raise_src = 'def do_raise(): raise TypeError\n' def make_pyc(co, mtime, size): @@ -27,7 +34,8 @@ def make_pyc(co, mtime, size): mtime = int(mtime) else: mtime = int(-0x100000000 + int(mtime)) - pyc = imp.get_magic() + struct.pack("<ii", int(mtime), size & 0xFFFFFFFF) + data + pyc = (importlib.util.MAGIC_NUMBER + + struct.pack("<ii", int(mtime), size & 0xFFFFFFFF) + data) return pyc def module_path_to_dotted_name(path): @@ -42,10 +50,27 @@ TESTPACK = "ziptestpackage" TESTPACK2 = "ziptestpackage2" TEMP_ZIP = os.path.abspath("junk95142.zip") -pyc_file = imp.cache_from_source(TESTMOD + '.py') +pyc_file = importlib.util.cache_from_source(TESTMOD + '.py') pyc_ext = ('.pyc' if __debug__ else '.pyo') +class ImportHooksBaseTestCase(unittest.TestCase): + + def setUp(self): + self.path = sys.path[:] + self.meta_path = sys.meta_path[:] + self.path_hooks = sys.path_hooks[:] + sys.path_importer_cache.clear() + self.modules_before = support.modules_setup() + + def tearDown(self): + sys.path[:] = self.path + sys.meta_path[:] = self.meta_path + sys.path_hooks[:] = self.path_hooks + sys.path_importer_cache.clear() + support.modules_cleanup(*self.modules_before) + + class UncompressedZipImportTestCase(ImportHooksBaseTestCase): compression = ZIP_STORED @@ -196,6 +221,7 @@ class UncompressedZipImportTestCase(ImportHooksBaseTestCase): for name, (mtime, data) in files.items(): zinfo = ZipInfo(name, time.localtime(mtime)) zinfo.compress_type = self.compression + zinfo.comment = b"spam" z.writestr(zinfo, data) z.close() @@ -245,6 +271,7 @@ class UncompressedZipImportTestCase(ImportHooksBaseTestCase): for name, (mtime, data) in files.items(): zinfo = ZipInfo(name, time.localtime(mtime)) zinfo.compress_type = self.compression + zinfo.comment = b"eggs" z.writestr(zinfo, data) z.close() @@ -459,7 +486,7 @@ class BadFileZipImportTestCase(unittest.TestCase): self.assertRaises(error, z.load_module, 'abc') self.assertRaises(error, z.get_code, 'abc') - self.assertRaises(IOError, z.get_data, 'abc') + self.assertRaises(OSError, z.get_data, 'abc') self.assertRaises(error, z.get_source, 'abc') self.assertRaises(error, z.is_package, 'abc') finally: diff --git a/Lib/test/test_zipimport_support.py b/Lib/test/test_zipimport_support.py index f7f3398015..84ba5e047a 100644 --- a/Lib/test/test_zipimport_support.py +++ b/Lib/test/test_zipimport_support.py @@ -227,13 +227,15 @@ class ZipSupportTests(unittest.TestCase): p = spawn_python(script_name) p.stdin.write(b'l\n') data = kill_python(p) - self.assertIn(script_name.encode('utf-8'), data) + # bdb/pdb applies normcase to its filename before displaying + self.assertIn(os.path.normcase(script_name.encode('utf-8')), data) zip_name, run_name = make_zip_script(d, "test_zip", script_name, '__main__.py') p = spawn_python(zip_name) p.stdin.write(b'l\n') data = kill_python(p) - self.assertIn(run_name.encode('utf-8'), data) + # bdb/pdb applies normcase to its filename before displaying + self.assertIn(os.path.normcase(run_name.encode('utf-8')), data) def test_main(): diff --git a/Lib/test/tf_inherit_check.py b/Lib/test/tf_inherit_check.py index 92ebd95e52..afe50d2325 100644 --- a/Lib/test/tf_inherit_check.py +++ b/Lib/test/tf_inherit_check.py @@ -11,7 +11,7 @@ try: try: os.write(fd, b"blat") - except os.error: + except OSError: # Success -- could not write to fd. sys.exit(0) else: |