diff options
Diffstat (limited to 'Lib/test/test_multiprocessing.py')
-rw-r--r-- | Lib/test/test_multiprocessing.py | 1379 |
1 files changed, 1219 insertions, 160 deletions
diff --git a/Lib/test/test_multiprocessing.py b/Lib/test/test_multiprocessing.py index 05fba7fb68..aa66db4b7e 100644 --- a/Lib/test/test_multiprocessing.py +++ b/Lib/test/test_multiprocessing.py @@ -1,5 +1,3 @@ -#!/usr/bin/env python3 - # # Unit tests for the multiprocessing package # @@ -8,6 +6,7 @@ import unittest import queue as pyqueue import time import io +import itertools import sys import os import gc @@ -17,6 +16,8 @@ import array import socket import random import logging +import struct +import operator import test.support import test.script_helper @@ -84,6 +85,13 @@ HAVE_GETVALUE = not getattr(_multiprocessing, WIN32 = (sys.platform == "win32") +from multiprocessing.connection import wait + +def wait_for_handle(handle, timeout): + if timeout is not None and timeout < 0.0: + timeout = None + return wait([handle], timeout) + try: MAXFD = os.sysconf("SC_OPEN_MAX") except: @@ -185,7 +193,7 @@ class _TestProcess(BaseTestCase): def test_current(self): if self.TYPE == 'threads': - return + self.skipTest('test not appropriate for {}'.format(self.TYPE)) current = self.current_process() authkey = current.authkey @@ -197,6 +205,18 @@ class _TestProcess(BaseTestCase): self.assertEqual(current.ident, os.getpid()) self.assertEqual(current.exitcode, None) + def test_daemon_argument(self): + if self.TYPE == "threads": + self.skipTest('test not appropriate for {}'.format(self.TYPE)) + + # By default uses the current process's daemon flag. + proc0 = self.Process(target=self._test) + self.assertEqual(proc0.daemon, self.current_process().daemon) + proc1 = self.Process(target=self._test, daemon=True) + self.assertTrue(proc1.daemon) + proc2 = self.Process(target=self._test, daemon=False) + self.assertFalse(proc2.daemon) + @classmethod def _test(cls, q, *args, **kwds): current = cls.current_process() @@ -248,11 +268,11 @@ class _TestProcess(BaseTestCase): @classmethod def _test_terminate(cls): - time.sleep(1000) + time.sleep(100) def test_terminate(self): if self.TYPE == 'threads': - return + self.skipTest('test not appropriate for {}'.format(self.TYPE)) p = self.Process(target=self._test_terminate) p.daemon = True @@ -262,10 +282,31 @@ class _TestProcess(BaseTestCase): self.assertIn(p, self.active_children()) self.assertEqual(p.exitcode, None) + join = TimingWrapper(p.join) + + self.assertEqual(join(0), None) + self.assertTimingAlmostEqual(join.elapsed, 0.0) + self.assertEqual(p.is_alive(), True) + + self.assertEqual(join(-1), None) + self.assertTimingAlmostEqual(join.elapsed, 0.0) + self.assertEqual(p.is_alive(), True) + p.terminate() - join = TimingWrapper(p.join) - self.assertEqual(join(), None) + if hasattr(signal, 'alarm'): + def handler(*args): + raise RuntimeError('join took too long: %s' % p) + old_handler = signal.signal(signal.SIGALRM, handler) + try: + signal.alarm(10) + self.assertEqual(join(), None) + signal.alarm(0) + finally: + signal.signal(signal.SIGALRM, old_handler) + else: + self.assertEqual(join(), None) + self.assertTimingAlmostEqual(join.elapsed, 0.0) self.assertEqual(p.is_alive(), False) @@ -329,6 +370,26 @@ class _TestProcess(BaseTestCase): ] self.assertEqual(result, expected) + @classmethod + def _test_sentinel(cls, event): + event.wait(10.0) + + def test_sentinel(self): + if self.TYPE == "threads": + self.skipTest('test not appropriate for {}'.format(self.TYPE)) + event = self.Event() + p = self.Process(target=self._test_sentinel, args=(event,)) + with self.assertRaises(ValueError): + p.sentinel + p.start() + self.addCleanup(p.join) + sentinel = p.sentinel + self.assertIsInstance(sentinel, int) + self.assertFalse(wait_for_handle(sentinel, timeout=0.0)) + event.set() + p.join() + self.assertTrue(wait_for_handle(sentinel, timeout=DELTA)) + # # # @@ -371,7 +432,7 @@ class _TestSubclassingProcess(BaseTestCase): def test_stderr_flush(self): # sys.stderr is flushed at process shutdown (issue #13812) if self.TYPE == "threads": - return + self.skipTest('test not appropriate for {}'.format(self.TYPE)) testfn = test.support.TESTFN self.addCleanup(test.support.unlink, testfn) @@ -399,12 +460,12 @@ class _TestSubclassingProcess(BaseTestCase): def test_sys_exit(self): # See Issue 13854 if self.TYPE == 'threads': - return + self.skipTest('test not appropriate for {}'.format(self.TYPE)) testfn = test.support.TESTFN self.addCleanup(test.support.unlink, testfn) - for reason, code in (([1, 2, 3], 1), ('ignore this', 0)): + for reason, code in (([1, 2, 3], 1), ('ignore this', 1)): p = self.Process(target=self._test_sys_exit, args=(reason, testfn)) p.daemon = True p.start() @@ -608,7 +669,7 @@ class _TestQueue(BaseTestCase): try: self.assertEqual(q.qsize(), 0) except NotImplementedError: - return + self.skipTest('qsize method not implemented') q.put(1) self.assertEqual(q.qsize(), 1) q.put(5) @@ -648,6 +709,13 @@ class _TestQueue(BaseTestCase): for p in workers: p.join() + def test_timeout(self): + q = multiprocessing.Queue() + start = time.time() + self.assertRaises(pyqueue.Empty, q.get, True, 0.2) + delta = time.time() - start + self.assertGreaterEqual(delta, 0.18) + # # # @@ -709,7 +777,7 @@ class _TestSemaphore(BaseTestCase): def test_timeout(self): if self.TYPE != 'processes': - return + self.skipTest('test not appropriate for {}'.format(self.TYPE)) sem = self.Semaphore(0) acquire = TimingWrapper(sem.acquire) @@ -868,6 +936,104 @@ class _TestCondition(BaseTestCase): self.assertEqual(res, False) self.assertTimingAlmostEqual(wait.elapsed, TIMEOUT1) + @classmethod + def _test_waitfor_f(cls, cond, state): + with cond: + state.value = 0 + cond.notify() + result = cond.wait_for(lambda : state.value==4) + if not result or state.value != 4: + sys.exit(1) + + @unittest.skipUnless(HAS_SHAREDCTYPES, 'needs sharedctypes') + def test_waitfor(self): + # based on test in test/lock_tests.py + cond = self.Condition() + state = self.Value('i', -1) + + p = self.Process(target=self._test_waitfor_f, args=(cond, state)) + p.daemon = True + p.start() + + with cond: + result = cond.wait_for(lambda : state.value==0) + self.assertTrue(result) + self.assertEqual(state.value, 0) + + for i in range(4): + time.sleep(0.01) + with cond: + state.value += 1 + cond.notify() + + p.join(5) + self.assertFalse(p.is_alive()) + self.assertEqual(p.exitcode, 0) + + @classmethod + def _test_waitfor_timeout_f(cls, cond, state, success, sem): + sem.release() + with cond: + expected = 0.1 + dt = time.time() + result = cond.wait_for(lambda : state.value==4, timeout=expected) + dt = time.time() - dt + # borrow logic in assertTimeout() from test/lock_tests.py + if not result and expected * 0.6 < dt < expected * 10.0: + success.value = True + + @unittest.skipUnless(HAS_SHAREDCTYPES, 'needs sharedctypes') + def test_waitfor_timeout(self): + # based on test in test/lock_tests.py + cond = self.Condition() + state = self.Value('i', 0) + success = self.Value('i', False) + sem = self.Semaphore(0) + + p = self.Process(target=self._test_waitfor_timeout_f, + args=(cond, state, success, sem)) + p.daemon = True + p.start() + self.assertTrue(sem.acquire(timeout=10)) + + # Only increment 3 times, so state == 4 is never reached. + for i in range(3): + time.sleep(0.01) + with cond: + state.value += 1 + cond.notify() + + p.join(5) + self.assertTrue(success.value) + + @classmethod + def _test_wait_result(cls, c, pid): + with c: + c.notify() + time.sleep(1) + if pid is not None: + os.kill(pid, signal.SIGINT) + + def test_wait_result(self): + if isinstance(self, ProcessesMixin) and sys.platform != 'win32': + pid = os.getpid() + else: + pid = None + + c = self.Condition() + with c: + self.assertFalse(c.wait(0)) + self.assertFalse(c.wait(0.1)) + + p = self.Process(target=self._test_wait_result, args=(c, pid)) + p.start() + + self.assertTrue(c.wait(10)) + if pid is not None: + self.assertRaises(KeyboardInterrupt, c.wait, 10) + + p.join() + class _TestEvent(BaseTestCase): @@ -911,6 +1077,340 @@ class _TestEvent(BaseTestCase): self.assertEqual(wait(), True) # +# Tests for Barrier - adapted from tests in test/lock_tests.py +# + +# Many of the tests for threading.Barrier use a list as an atomic +# counter: a value is appended to increment the counter, and the +# length of the list gives the value. We use the class DummyList +# for the same purpose. + +class _DummyList(object): + + def __init__(self): + wrapper = multiprocessing.heap.BufferWrapper(struct.calcsize('i')) + lock = multiprocessing.Lock() + self.__setstate__((wrapper, lock)) + self._lengthbuf[0] = 0 + + def __setstate__(self, state): + (self._wrapper, self._lock) = state + self._lengthbuf = self._wrapper.create_memoryview().cast('i') + + def __getstate__(self): + return (self._wrapper, self._lock) + + def append(self, _): + with self._lock: + self._lengthbuf[0] += 1 + + def __len__(self): + with self._lock: + return self._lengthbuf[0] + +def _wait(): + # A crude wait/yield function not relying on synchronization primitives. + time.sleep(0.01) + + +class Bunch(object): + """ + A bunch of threads. + """ + def __init__(self, namespace, f, args, n, wait_before_exit=False): + """ + Construct a bunch of `n` threads running the same function `f`. + If `wait_before_exit` is True, the threads won't terminate until + do_finish() is called. + """ + self.f = f + self.args = args + self.n = n + self.started = namespace.DummyList() + self.finished = namespace.DummyList() + self._can_exit = namespace.Event() + if not wait_before_exit: + self._can_exit.set() + for i in range(n): + p = namespace.Process(target=self.task) + p.daemon = True + p.start() + + def task(self): + pid = os.getpid() + self.started.append(pid) + try: + self.f(*self.args) + finally: + self.finished.append(pid) + self._can_exit.wait(30) + assert self._can_exit.is_set() + + def wait_for_started(self): + while len(self.started) < self.n: + _wait() + + def wait_for_finished(self): + while len(self.finished) < self.n: + _wait() + + def do_finish(self): + self._can_exit.set() + + +class AppendTrue(object): + def __init__(self, obj): + self.obj = obj + def __call__(self): + self.obj.append(True) + + +class _TestBarrier(BaseTestCase): + """ + Tests for Barrier objects. + """ + N = 5 + defaultTimeout = 30.0 # XXX Slow Windows buildbots need generous timeout + + def setUp(self): + self.barrier = self.Barrier(self.N, timeout=self.defaultTimeout) + + def tearDown(self): + self.barrier.abort() + self.barrier = None + + def DummyList(self): + if self.TYPE == 'threads': + return [] + elif self.TYPE == 'manager': + return self.manager.list() + else: + return _DummyList() + + def run_threads(self, f, args): + b = Bunch(self, f, args, self.N-1) + f(*args) + b.wait_for_finished() + + @classmethod + def multipass(cls, barrier, results, n): + m = barrier.parties + assert m == cls.N + for i in range(n): + results[0].append(True) + assert len(results[1]) == i * m + barrier.wait() + results[1].append(True) + assert len(results[0]) == (i + 1) * m + barrier.wait() + try: + assert barrier.n_waiting == 0 + except NotImplementedError: + pass + assert not barrier.broken + + def test_barrier(self, passes=1): + """ + Test that a barrier is passed in lockstep + """ + results = [self.DummyList(), self.DummyList()] + self.run_threads(self.multipass, (self.barrier, results, passes)) + + def test_barrier_10(self): + """ + Test that a barrier works for 10 consecutive runs + """ + return self.test_barrier(10) + + @classmethod + def _test_wait_return_f(cls, barrier, queue): + res = barrier.wait() + queue.put(res) + + def test_wait_return(self): + """ + test the return value from barrier.wait + """ + queue = self.Queue() + self.run_threads(self._test_wait_return_f, (self.barrier, queue)) + results = [queue.get() for i in range(self.N)] + self.assertEqual(results.count(0), 1) + + @classmethod + def _test_action_f(cls, barrier, results): + barrier.wait() + if len(results) != 1: + raise RuntimeError + + def test_action(self): + """ + Test the 'action' callback + """ + results = self.DummyList() + barrier = self.Barrier(self.N, action=AppendTrue(results)) + self.run_threads(self._test_action_f, (barrier, results)) + self.assertEqual(len(results), 1) + + @classmethod + def _test_abort_f(cls, barrier, results1, results2): + try: + i = barrier.wait() + if i == cls.N//2: + raise RuntimeError + barrier.wait() + results1.append(True) + except threading.BrokenBarrierError: + results2.append(True) + except RuntimeError: + barrier.abort() + + def test_abort(self): + """ + Test that an abort will put the barrier in a broken state + """ + results1 = self.DummyList() + results2 = self.DummyList() + self.run_threads(self._test_abort_f, + (self.barrier, results1, results2)) + self.assertEqual(len(results1), 0) + self.assertEqual(len(results2), self.N-1) + self.assertTrue(self.barrier.broken) + + @classmethod + def _test_reset_f(cls, barrier, results1, results2, results3): + i = barrier.wait() + if i == cls.N//2: + # Wait until the other threads are all in the barrier. + while barrier.n_waiting < cls.N-1: + time.sleep(0.001) + barrier.reset() + else: + try: + barrier.wait() + results1.append(True) + except threading.BrokenBarrierError: + results2.append(True) + # Now, pass the barrier again + barrier.wait() + results3.append(True) + + def test_reset(self): + """ + Test that a 'reset' on a barrier frees the waiting threads + """ + results1 = self.DummyList() + results2 = self.DummyList() + results3 = self.DummyList() + self.run_threads(self._test_reset_f, + (self.barrier, results1, results2, results3)) + self.assertEqual(len(results1), 0) + self.assertEqual(len(results2), self.N-1) + self.assertEqual(len(results3), self.N) + + @classmethod + def _test_abort_and_reset_f(cls, barrier, barrier2, + results1, results2, results3): + try: + i = barrier.wait() + if i == cls.N//2: + raise RuntimeError + barrier.wait() + results1.append(True) + except threading.BrokenBarrierError: + results2.append(True) + except RuntimeError: + barrier.abort() + # Synchronize and reset the barrier. Must synchronize first so + # that everyone has left it when we reset, and after so that no + # one enters it before the reset. + if barrier2.wait() == cls.N//2: + barrier.reset() + barrier2.wait() + barrier.wait() + results3.append(True) + + def test_abort_and_reset(self): + """ + Test that a barrier can be reset after being broken. + """ + results1 = self.DummyList() + results2 = self.DummyList() + results3 = self.DummyList() + barrier2 = self.Barrier(self.N) + + self.run_threads(self._test_abort_and_reset_f, + (self.barrier, barrier2, results1, results2, results3)) + self.assertEqual(len(results1), 0) + self.assertEqual(len(results2), self.N-1) + self.assertEqual(len(results3), self.N) + + @classmethod + def _test_timeout_f(cls, barrier, results): + i = barrier.wait() + if i == cls.N//2: + # One thread is late! + time.sleep(1.0) + try: + barrier.wait(0.5) + except threading.BrokenBarrierError: + results.append(True) + + def test_timeout(self): + """ + Test wait(timeout) + """ + results = self.DummyList() + self.run_threads(self._test_timeout_f, (self.barrier, results)) + self.assertEqual(len(results), self.barrier.parties) + + @classmethod + def _test_default_timeout_f(cls, barrier, results): + i = barrier.wait(cls.defaultTimeout) + if i == cls.N//2: + # One thread is later than the default timeout + time.sleep(1.0) + try: + barrier.wait() + except threading.BrokenBarrierError: + results.append(True) + + def test_default_timeout(self): + """ + Test the barrier's default timeout + """ + barrier = self.Barrier(self.N, timeout=0.5) + results = self.DummyList() + self.run_threads(self._test_default_timeout_f, (barrier, results)) + self.assertEqual(len(results), barrier.parties) + + def test_single_thread(self): + b = self.Barrier(1) + b.wait() + b.wait() + + @classmethod + def _test_thousand_f(cls, barrier, passes, conn, lock): + for i in range(passes): + barrier.wait() + with lock: + conn.send(i) + + def test_thousand(self): + if self.TYPE == 'manager': + self.skipTest('test not appropriate for {}'.format(self.TYPE)) + passes = 1000 + lock = self.Lock() + conn, child_conn = self.Pipe(False) + for j in range(self.N): + p = self.Process(target=self._test_thousand_f, + args=(self.barrier, passes, child_conn, lock)) + p.start() + + for i in range(passes): + for j in range(self.N): + self.assertEqual(conn.recv(), i) + +# # # @@ -1130,8 +1630,23 @@ def sqr(x, wait=0.0): time.sleep(wait) return x*x +def mul(x, y): + return x*y + class _TestPool(BaseTestCase): + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.pool = cls.Pool(4) + + @classmethod + def tearDownClass(cls): + cls.pool.terminate() + cls.pool.join() + cls.pool = None + super().tearDownClass() + def test_apply(self): papply = self.pool.apply self.assertEqual(papply(sqr, (5,)), sqr(5)) @@ -1143,6 +1658,47 @@ class _TestPool(BaseTestCase): self.assertEqual(pmap(sqr, list(range(100)), chunksize=20), list(map(sqr, list(range(100))))) + def test_starmap(self): + psmap = self.pool.starmap + tuples = list(zip(range(10), range(9,-1, -1))) + self.assertEqual(psmap(mul, tuples), + list(itertools.starmap(mul, tuples))) + tuples = list(zip(range(100), range(99,-1, -1))) + self.assertEqual(psmap(mul, tuples, chunksize=20), + list(itertools.starmap(mul, tuples))) + + def test_starmap_async(self): + tuples = list(zip(range(100), range(99,-1, -1))) + self.assertEqual(self.pool.starmap_async(mul, tuples).get(), + list(itertools.starmap(mul, tuples))) + + def test_map_async(self): + self.assertEqual(self.pool.map_async(sqr, list(range(10))).get(), + list(map(sqr, list(range(10))))) + + def test_map_async_callbacks(self): + call_args = self.manager.list() if self.TYPE == 'manager' else [] + self.pool.map_async(int, ['1'], + callback=call_args.append, + error_callback=call_args.append).wait() + self.assertEqual(1, len(call_args)) + self.assertEqual([1], call_args[0]) + self.pool.map_async(int, ['a'], + callback=call_args.append, + error_callback=call_args.append).wait() + self.assertEqual(2, len(call_args)) + self.assertIsInstance(call_args[1], ValueError) + + def test_map_unplicklable(self): + # Issue #19425 -- failure to pickle should not cause a hang + if self.TYPE == 'threads': + self.skipTest('test not appropriate for {}'.format(self.TYPE)) + class A(object): + def __reduce__(self): + raise RuntimeError('cannot pickle') + with self.assertRaises(RuntimeError): + self.pool.map(sqr, [A()]*10) + def test_map_chunksize(self): try: self.pool.map_async(sqr, [], chunksize=1).get(timeout=TIMEOUT1) @@ -1156,7 +1712,7 @@ class _TestPool(BaseTestCase): self.assertTimingAlmostEqual(get.elapsed, TIMEOUT1) def test_async_timeout(self): - res = self.pool.apply_async(sqr, (6, TIMEOUT2 + 0.2)) + res = self.pool.apply_async(sqr, (6, TIMEOUT2 + 1.0)) get = TimingWrapper(res.get) self.assertRaises(multiprocessing.TimeoutError, get, timeout=TIMEOUT2) self.assertTimingAlmostEqual(get.elapsed, TIMEOUT2) @@ -1192,15 +1748,6 @@ class _TestPool(BaseTestCase): p.join() def test_terminate(self): - if self.TYPE == 'manager': - # On Unix a forked process increfs each shared object to - # which its parent process held a reference. If the - # forked process gets terminated then there is likely to - # be a reference leak. So to prevent - # _TestZZZNumberOfObjects from failing we skip this test - # when using a manager. - return - result = self.pool.map_async( time.sleep, [0.1 for i in range(10000)], chunksize=1 ) @@ -1221,6 +1768,15 @@ class _TestPool(BaseTestCase): p.close() p.join() + def test_context(self): + if self.TYPE == 'processes': + L = list(range(10)) + expected = [sqr(i) for i in L] + with multiprocessing.Pool(2) as p: + r = p.map_async(sqr, L) + self.assertEqual(r.get(), expected) + self.assertRaises(ValueError, p.map_async, sqr, L) + def raising(): raise KeyError("key") @@ -1312,30 +1868,6 @@ class _TestPoolWorkerLifetime(BaseTestCase): for (j, res) in enumerate(results): self.assertEqual(res.get(), sqr(j)) - -# -# Test that manager has expected number of shared objects left -# - -class _TestZZZNumberOfObjects(BaseTestCase): - # Because test cases are sorted alphabetically, this one will get - # run after all the other tests for the manager. It tests that - # there have been no "reference leaks" for the manager's shared - # objects. Note the comment in _TestPool.test_terminate(). - ALLOWED_TYPES = ('manager',) - - def test_number_of_objects(self): - EXPECTED_NUMBER = 1 # the pool object is still alive - multiprocessing.active_children() # discard dead process objs - gc.collect() # do garbage collection - refs = self.manager._number_of_objects() - debug_info = self.manager._debug_info() - if refs != EXPECTED_NUMBER: - print(self.manager._debug_info()) - print(debug_info) - - self.assertEqual(refs, EXPECTED_NUMBER) - # # Test of creating a customized manager class # @@ -1376,7 +1908,27 @@ class _TestMyManager(BaseTestCase): def test_mymanager(self): manager = MyManager() manager.start() + self.common(manager) + manager.shutdown() + + # If the manager process exited cleanly then the exitcode + # will be zero. Otherwise (after a short timeout) + # terminate() is used, resulting in an exitcode of -SIGTERM. + self.assertEqual(manager._process.exitcode, 0) + + def test_mymanager_context(self): + with MyManager() as manager: + self.common(manager) + self.assertEqual(manager._process.exitcode, 0) + + def test_mymanager_context_prestarted(self): + manager = MyManager() + manager.start() + with manager: + self.common(manager) + self.assertEqual(manager._process.exitcode, 0) + def common(self, manager): foo = manager.Foo() bar = manager.Bar() baz = manager.baz() @@ -1399,7 +1951,6 @@ class _TestMyManager(BaseTestCase): self.assertEqual(list(baz), [i*i for i in range(10)]) - manager.shutdown() # # Test of connecting to a remote server and using xmlrpclib for serialization @@ -1437,7 +1988,7 @@ class _TestRemoteManager(BaseTestCase): authkey = os.urandom(32) manager = QueueManager( - address=('localhost', 0), authkey=authkey, serializer=SERIALIZER + address=(test.support.HOST, 0), authkey=authkey, serializer=SERIALIZER ) manager.start() @@ -1475,7 +2026,7 @@ class _TestManagerRestart(BaseTestCase): def test_rapid_restart(self): authkey = os.urandom(32) manager = QueueManager( - address=('localhost', 0), authkey=authkey, serializer=SERIALIZER) + address=(test.support.HOST, 0), authkey=authkey, serializer=SERIALIZER) srvr = manager.get_server() addr = srvr.address # Close the connection.Listener socket which gets opened as a part @@ -1494,7 +2045,7 @@ class _TestManagerRestart(BaseTestCase): address=addr, authkey=authkey, serializer=SERIALIZER) try: manager.start() - except IOError as e: + except OSError as e: if e.errno != errno.EADDRINUSE: raise # Retry after some time, in case the old socket was lingering @@ -1570,6 +2121,9 @@ class _TestConnection(BaseTestCase): self.assertEqual(poll(), False) self.assertTimingAlmostEqual(poll.elapsed, 0) + self.assertEqual(poll(-1), False) + self.assertTimingAlmostEqual(poll.elapsed, 0) + self.assertEqual(poll(TIMEOUT1), False) self.assertTimingAlmostEqual(poll.elapsed, TIMEOUT1) @@ -1605,9 +2159,9 @@ class _TestConnection(BaseTestCase): self.assertEqual(reader.writable, False) self.assertEqual(writer.readable, False) self.assertEqual(writer.writable, True) - self.assertRaises(IOError, reader.send, 2) - self.assertRaises(IOError, writer.recv) - self.assertRaises(IOError, writer.poll) + self.assertRaises(OSError, reader.send, 2) + self.assertRaises(OSError, writer.recv) + self.assertRaises(OSError, writer.poll) def test_spawn_close(self): # We test that a pipe connection can be closed by parent @@ -1632,7 +2186,7 @@ class _TestConnection(BaseTestCase): def test_sendbytes(self): if self.TYPE != 'processes': - return + self.skipTest('test not appropriate for {}'.format(self.TYPE)) msg = latin('abcdefghijklmnopqrstuvwxyz') a, b = self.Pipe() @@ -1756,6 +2310,43 @@ class _TestConnection(BaseTestCase): self.assertRaises(RuntimeError, reduction.recv_handle, conn) p.join() + def test_context(self): + a, b = self.Pipe() + + with a, b: + a.send(1729) + self.assertEqual(b.recv(), 1729) + if self.TYPE == 'processes': + self.assertFalse(a.closed) + self.assertFalse(b.closed) + + if self.TYPE == 'processes': + self.assertTrue(a.closed) + self.assertTrue(b.closed) + self.assertRaises(OSError, a.recv) + self.assertRaises(OSError, b.recv) + +class _TestListener(BaseTestCase): + + ALLOWED_TYPES = ('processes',) + + def test_multiple_bind(self): + for family in self.connection.families: + l = self.connection.Listener(family=family) + self.addCleanup(l.close) + self.assertRaises(OSError, self.connection.Listener, + l.address, family) + + def test_context(self): + with self.connection.Listener() as l: + with self.connection.Client(l.address) as c: + with l.accept() as d: + c.send(1729) + self.assertEqual(d.recv(), 1729) + + if self.TYPE == 'processes': + self.assertRaises(OSError, l.accept) + class _TestListenerClient(BaseTestCase): ALLOWED_TYPES = ('processes', 'threads') @@ -1793,52 +2384,146 @@ class _TestListenerClient(BaseTestCase): p.join() l.close() + def test_issue16955(self): + for fam in self.connection.families: + l = self.connection.Listener(family=fam) + c = self.connection.Client(l.address) + a = l.accept() + a.send_bytes(b"hello") + self.assertTrue(c.poll(1)) + a.close() + c.close() + l.close() + +class _TestPoll(BaseTestCase): + + ALLOWED_TYPES = ('processes', 'threads') + + def test_empty_string(self): + a, b = self.Pipe() + self.assertEqual(a.poll(), False) + b.send_bytes(b'') + self.assertEqual(a.poll(), True) + self.assertEqual(a.poll(), True) + + @classmethod + def _child_strings(cls, conn, strings): + for s in strings: + time.sleep(0.1) + conn.send_bytes(s) + conn.close() + + def test_strings(self): + strings = (b'hello', b'', b'a', b'b', b'', b'bye', b'', b'lop') + a, b = self.Pipe() + p = self.Process(target=self._child_strings, args=(b, strings)) + p.start() + + for s in strings: + for i in range(200): + if a.poll(0.01): + break + x = a.recv_bytes() + self.assertEqual(s, x) + + p.join() + + @classmethod + def _child_boundaries(cls, r): + # Polling may "pull" a message in to the child process, but we + # don't want it to pull only part of a message, as that would + # corrupt the pipe for any other processes which might later + # read from it. + r.poll(5) + + def test_boundaries(self): + r, w = self.Pipe(False) + p = self.Process(target=self._child_boundaries, args=(r,)) + p.start() + time.sleep(2) + L = [b"first", b"second"] + for obj in L: + w.send_bytes(obj) + w.close() + p.join() + self.assertIn(r.recv_bytes(), L) + + @classmethod + def _child_dont_merge(cls, b): + b.send_bytes(b'a') + b.send_bytes(b'b') + b.send_bytes(b'cd') + + def test_dont_merge(self): + a, b = self.Pipe() + self.assertEqual(a.poll(0.0), False) + self.assertEqual(a.poll(0.1), False) + + p = self.Process(target=self._child_dont_merge, args=(b,)) + p.start() + + self.assertEqual(a.recv_bytes(), b'a') + self.assertEqual(a.poll(1.0), True) + self.assertEqual(a.poll(1.0), True) + self.assertEqual(a.recv_bytes(), b'b') + self.assertEqual(a.poll(1.0), True) + self.assertEqual(a.poll(1.0), True) + self.assertEqual(a.poll(0.0), True) + self.assertEqual(a.recv_bytes(), b'cd') + + p.join() + # # Test of sending connection and socket objects between processes # -""" + +@unittest.skipUnless(HAS_REDUCTION, "test needs multiprocessing.reduction") class _TestPicklingConnections(BaseTestCase): ALLOWED_TYPES = ('processes',) - def _listener(self, conn, families): + @classmethod + def tearDownClass(cls): + from multiprocessing.reduction import resource_sharer + resource_sharer.stop(timeout=5) + + @classmethod + def _listener(cls, conn, families): for fam in families: - l = self.connection.Listener(family=fam) + l = cls.connection.Listener(family=fam) conn.send(l.address) new_conn = l.accept() conn.send(new_conn) + new_conn.close() + l.close() - if self.TYPE == 'processes': - l = socket.socket() - l.bind(('localhost', 0)) - conn.send(l.getsockname()) - l.listen(1) - new_conn, addr = l.accept() - conn.send(new_conn) + l = socket.socket() + l.bind((test.support.HOST, 0)) + l.listen(1) + conn.send(l.getsockname()) + new_conn, addr = l.accept() + conn.send(new_conn) + new_conn.close() + l.close() conn.recv() - def _remote(self, conn): + @classmethod + def _remote(cls, conn): for (address, msg) in iter(conn.recv, None): - client = self.connection.Client(address) + client = cls.connection.Client(address) client.send(msg.upper()) client.close() - if self.TYPE == 'processes': - address, msg = conn.recv() - client = socket.socket() - client.connect(address) - client.sendall(msg.upper()) - client.close() + address, msg = conn.recv() + client = socket.socket() + client.connect(address) + client.sendall(msg.upper()) + client.close() conn.close() def test_pickling(self): - try: - multiprocessing.allow_connection_pickling() - except ImportError: - return - families = self.connection.families lconn, lconn0 = self.Pipe() @@ -1862,16 +2547,19 @@ class _TestPicklingConnections(BaseTestCase): rconn.send(None) - if self.TYPE == 'processes': - msg = latin('This connection uses a normal socket') - address = lconn.recv() - rconn.send((address, msg)) - if hasattr(socket, 'fromfd'): - new_conn = lconn.recv() - self.assertEqual(new_conn.recv(100), msg.upper()) - else: - # XXX On Windows with Py2.6 need to backport fromfd() - discard = lconn.recv_bytes() + msg = latin('This connection uses a normal socket') + address = lconn.recv() + rconn.send((address, msg)) + new_conn = lconn.recv() + buf = [] + while True: + s = new_conn.recv(100) + if not s: + break + buf.append(s) + buf = b''.join(buf) + self.assertEqual(buf, msg.upper()) + new_conn.close() lconn.send(None) @@ -1880,7 +2568,46 @@ class _TestPicklingConnections(BaseTestCase): lp.join() rp.join() -""" + + @classmethod + def child_access(cls, conn): + w = conn.recv() + w.send('all is well') + w.close() + + r = conn.recv() + msg = r.recv() + conn.send(msg*2) + + conn.close() + + def test_access(self): + # On Windows, if we do not specify a destination pid when + # using DupHandle then we need to be careful to use the + # correct access flags for DuplicateHandle(), or else + # DupHandle.detach() will raise PermissionError. For example, + # for a read only pipe handle we should use + # access=FILE_GENERIC_READ. (Unfortunately + # DUPLICATE_SAME_ACCESS does not work.) + conn, child_conn = self.Pipe() + p = self.Process(target=self.child_access, args=(child_conn,)) + p.daemon = True + p.start() + child_conn.close() + + r, w = self.Pipe(duplex=False) + conn.send(w) + w.close() + self.assertEqual(r.recv(), 'all is well') + r.close() + + r, w = self.Pipe(duplex=False) + conn.send(r) + r.close() + w.send('foobar') + w.close() + self.assertEqual(conn.recv(), 'foobar'*2) + # # # @@ -2207,37 +2934,36 @@ class TestInvalidHandle(unittest.TestCase): @unittest.skipIf(WIN32, "skipped on Windows") def test_invalid_handles(self): - conn = _multiprocessing.Connection(44977608) - self.assertRaises(IOError, conn.poll) - self.assertRaises(IOError, _multiprocessing.Connection, -1) + conn = multiprocessing.connection.Connection(44977608) + try: + self.assertRaises((ValueError, OSError), conn.poll) + finally: + # Hack private attribute _handle to avoid printing an error + # in conn.__del__ + conn._handle = None + self.assertRaises((ValueError, OSError), + multiprocessing.connection.Connection, -1) # # Functions used to create test cases from the base ones in this module # -def get_attributes(Source, names): - d = {} - for name in names: - obj = getattr(Source, name) - if type(obj) == type(get_attributes): - obj = staticmethod(obj) - d[name] = obj - return d - def create_test_cases(Mixin, type): result = {} glob = globals() Type = type.capitalize() + ALL_TYPES = {'processes', 'threads', 'manager'} for name in list(glob.keys()): if name.startswith('_Test'): base = glob[name] + assert set(base.ALLOWED_TYPES) <= ALL_TYPES, set(base.ALLOWED_TYPES) if type in base.ALLOWED_TYPES: newname = 'With' + Type + name[1:] - class Temp(base, unittest.TestCase, Mixin): + class Temp(base, Mixin, unittest.TestCase): pass result[newname] = Temp - Temp.__name__ = newname + Temp.__name__ = Temp.__qualname__ = newname Temp.__module__ = Mixin.__module__ return result @@ -2248,12 +2974,24 @@ def create_test_cases(Mixin, type): class ProcessesMixin(object): TYPE = 'processes' Process = multiprocessing.Process - locals().update(get_attributes(multiprocessing, ( - 'Queue', 'Lock', 'RLock', 'Semaphore', 'BoundedSemaphore', - 'Condition', 'Event', 'Value', 'Array', 'RawValue', - 'RawArray', 'current_process', 'active_children', 'Pipe', - 'connection', 'JoinableQueue', 'Pool' - ))) + connection = multiprocessing.connection + current_process = staticmethod(multiprocessing.current_process) + active_children = staticmethod(multiprocessing.active_children) + Pool = staticmethod(multiprocessing.Pool) + Pipe = staticmethod(multiprocessing.Pipe) + Queue = staticmethod(multiprocessing.Queue) + JoinableQueue = staticmethod(multiprocessing.JoinableQueue) + Lock = staticmethod(multiprocessing.Lock) + RLock = staticmethod(multiprocessing.RLock) + Semaphore = staticmethod(multiprocessing.Semaphore) + BoundedSemaphore = staticmethod(multiprocessing.BoundedSemaphore) + Condition = staticmethod(multiprocessing.Condition) + Event = staticmethod(multiprocessing.Event) + Barrier = staticmethod(multiprocessing.Barrier) + Value = staticmethod(multiprocessing.Value) + Array = staticmethod(multiprocessing.Array) + RawValue = staticmethod(multiprocessing.RawValue) + RawArray = staticmethod(multiprocessing.RawArray) testcases_processes = create_test_cases(ProcessesMixin, type='processes') globals().update(testcases_processes) @@ -2262,12 +3000,48 @@ globals().update(testcases_processes) class ManagerMixin(object): TYPE = 'manager' Process = multiprocessing.Process - manager = object.__new__(multiprocessing.managers.SyncManager) - locals().update(get_attributes(manager, ( - 'Queue', 'Lock', 'RLock', 'Semaphore', 'BoundedSemaphore', - 'Condition', 'Event', 'Value', 'Array', 'list', 'dict', - 'Namespace', 'JoinableQueue', 'Pool' - ))) + Queue = property(operator.attrgetter('manager.Queue')) + JoinableQueue = property(operator.attrgetter('manager.JoinableQueue')) + Lock = property(operator.attrgetter('manager.Lock')) + RLock = property(operator.attrgetter('manager.RLock')) + Semaphore = property(operator.attrgetter('manager.Semaphore')) + BoundedSemaphore = property(operator.attrgetter('manager.BoundedSemaphore')) + Condition = property(operator.attrgetter('manager.Condition')) + Event = property(operator.attrgetter('manager.Event')) + Barrier = property(operator.attrgetter('manager.Barrier')) + Value = property(operator.attrgetter('manager.Value')) + Array = property(operator.attrgetter('manager.Array')) + list = property(operator.attrgetter('manager.list')) + dict = property(operator.attrgetter('manager.dict')) + Namespace = property(operator.attrgetter('manager.Namespace')) + + @classmethod + def Pool(cls, *args, **kwds): + return cls.manager.Pool(*args, **kwds) + + @classmethod + def setUpClass(cls): + cls.manager = multiprocessing.Manager() + + @classmethod + def tearDownClass(cls): + # only the manager process should be returned by active_children() + # but this can take a bit on slow machines, so wait a few seconds + # if there are other children too (see #17395) + t = 0.01 + while len(multiprocessing.active_children()) > 1 and t < 5: + time.sleep(t) + t *= 2 + gc.collect() # do garbage collection + if cls.manager._number_of_objects() != 0: + # This is not really an error since some tests do not + # ensure that all processes which hold a reference to a + # managed object have been joined. + print('Shared objects which still exist at manager shutdown:') + print(cls.manager._debug_info()) + cls.manager.shutdown() + cls.manager.join() + cls.manager = None testcases_manager = create_test_cases(ManagerMixin, type='manager') globals().update(testcases_manager) @@ -2276,16 +3050,27 @@ globals().update(testcases_manager) class ThreadsMixin(object): TYPE = 'threads' Process = multiprocessing.dummy.Process - locals().update(get_attributes(multiprocessing.dummy, ( - 'Queue', 'Lock', 'RLock', 'Semaphore', 'BoundedSemaphore', - 'Condition', 'Event', 'Value', 'Array', 'current_process', - 'active_children', 'Pipe', 'connection', 'dict', 'list', - 'Namespace', 'JoinableQueue', 'Pool' - ))) + connection = multiprocessing.dummy.connection + current_process = staticmethod(multiprocessing.dummy.current_process) + active_children = staticmethod(multiprocessing.dummy.active_children) + Pool = staticmethod(multiprocessing.Pool) + Pipe = staticmethod(multiprocessing.dummy.Pipe) + Queue = staticmethod(multiprocessing.dummy.Queue) + JoinableQueue = staticmethod(multiprocessing.dummy.JoinableQueue) + Lock = staticmethod(multiprocessing.dummy.Lock) + RLock = staticmethod(multiprocessing.dummy.RLock) + Semaphore = staticmethod(multiprocessing.dummy.Semaphore) + BoundedSemaphore = staticmethod(multiprocessing.dummy.BoundedSemaphore) + Condition = staticmethod(multiprocessing.dummy.Condition) + Event = staticmethod(multiprocessing.dummy.Event) + Barrier = staticmethod(multiprocessing.dummy.Barrier) + Value = staticmethod(multiprocessing.dummy.Value) + Array = staticmethod(multiprocessing.dummy.Array) testcases_threads = create_test_cases(ThreadsMixin, type='threads') globals().update(testcases_threads) + class OtherTest(unittest.TestCase): # TODO: add more tests for deliver/answer challenge. def test_deliver_challenge_auth_failure(self): @@ -2330,6 +3115,7 @@ class TestInitializers(unittest.TestCase): def tearDown(self): self.mgr.shutdown() + self.mgr.join() def test_manager_initializer(self): m = multiprocessing.managers.SyncManager() @@ -2337,6 +3123,7 @@ class TestInitializers(unittest.TestCase): m.start(initializer, (self.ns,)) self.assertEqual(self.ns.test, 1) m.shutdown() + m.join() def test_pool_initializer(self): self.assertRaises(TypeError, multiprocessing.Pool, initializer=1) @@ -2350,15 +3137,15 @@ class TestInitializers(unittest.TestCase): # Verifies os.close(sys.stdin.fileno) vs. sys.stdin.close() behavior # -def _ThisSubProcess(q): +def _this_sub_process(q): try: item = q.get(block=False) except pyqueue.Empty: pass -def _TestProcess(q): +def _test_process(q): queue = multiprocessing.Queue() - subProc = multiprocessing.Process(target=_ThisSubProcess, args=(queue,)) + subProc = multiprocessing.Process(target=_this_sub_process, args=(queue,)) subProc.daemon = True subProc.start() subProc.join() @@ -2369,6 +3156,8 @@ def _afunc(x): def pool_in_process(): pool = multiprocessing.Pool(processes=4) x = pool.map(_afunc, [1, 2, 3, 4, 5, 6, 7]) + pool.close() + pool.join() class _file_like(object): def __init__(self, delegate): @@ -2395,7 +3184,7 @@ class TestStdinBadfiledescriptor(unittest.TestCase): def test_queue_in_process(self): queue = multiprocessing.Queue() - proc = multiprocessing.Process(target=_TestProcess, args=(queue,)) + proc = multiprocessing.Process(target=_test_process, args=(queue,)) proc.start() proc.join() @@ -2413,6 +3202,181 @@ class TestStdinBadfiledescriptor(unittest.TestCase): assert sio.getvalue() == 'foo' +class TestWait(unittest.TestCase): + + @classmethod + def _child_test_wait(cls, w, slow): + for i in range(10): + if slow: + time.sleep(random.random()*0.1) + w.send((i, os.getpid())) + w.close() + + def test_wait(self, slow=False): + from multiprocessing.connection import wait + readers = [] + procs = [] + messages = [] + + for i in range(4): + r, w = multiprocessing.Pipe(duplex=False) + p = multiprocessing.Process(target=self._child_test_wait, args=(w, slow)) + p.daemon = True + p.start() + w.close() + readers.append(r) + procs.append(p) + self.addCleanup(p.join) + + while readers: + for r in wait(readers): + try: + msg = r.recv() + except EOFError: + readers.remove(r) + r.close() + else: + messages.append(msg) + + messages.sort() + expected = sorted((i, p.pid) for i in range(10) for p in procs) + self.assertEqual(messages, expected) + + @classmethod + def _child_test_wait_socket(cls, address, slow): + s = socket.socket() + s.connect(address) + for i in range(10): + if slow: + time.sleep(random.random()*0.1) + s.sendall(('%s\n' % i).encode('ascii')) + s.close() + + def test_wait_socket(self, slow=False): + from multiprocessing.connection import wait + l = socket.socket() + l.bind((test.support.HOST, 0)) + l.listen(4) + addr = l.getsockname() + readers = [] + procs = [] + dic = {} + + for i in range(4): + p = multiprocessing.Process(target=self._child_test_wait_socket, + args=(addr, slow)) + p.daemon = True + p.start() + procs.append(p) + self.addCleanup(p.join) + + for i in range(4): + r, _ = l.accept() + readers.append(r) + dic[r] = [] + l.close() + + while readers: + for r in wait(readers): + msg = r.recv(32) + if not msg: + readers.remove(r) + r.close() + else: + dic[r].append(msg) + + expected = ''.join('%s\n' % i for i in range(10)).encode('ascii') + for v in dic.values(): + self.assertEqual(b''.join(v), expected) + + def test_wait_slow(self): + self.test_wait(True) + + def test_wait_socket_slow(self): + self.test_wait_socket(True) + + def test_wait_timeout(self): + from multiprocessing.connection import wait + + expected = 5 + a, b = multiprocessing.Pipe() + + start = time.time() + res = wait([a, b], expected) + delta = time.time() - start + + self.assertEqual(res, []) + self.assertLess(delta, expected * 2) + self.assertGreater(delta, expected * 0.5) + + b.send(None) + + start = time.time() + res = wait([a, b], 20) + delta = time.time() - start + + self.assertEqual(res, [a]) + self.assertLess(delta, 0.4) + + @classmethod + def signal_and_sleep(cls, sem, period): + sem.release() + time.sleep(period) + + def test_wait_integer(self): + from multiprocessing.connection import wait + + expected = 3 + sorted_ = lambda l: sorted(l, key=lambda x: id(x)) + sem = multiprocessing.Semaphore(0) + a, b = multiprocessing.Pipe() + p = multiprocessing.Process(target=self.signal_and_sleep, + args=(sem, expected)) + + p.start() + self.assertIsInstance(p.sentinel, int) + self.assertTrue(sem.acquire(timeout=20)) + + start = time.time() + res = wait([a, p.sentinel, b], expected + 20) + delta = time.time() - start + + self.assertEqual(res, [p.sentinel]) + self.assertLess(delta, expected + 2) + self.assertGreater(delta, expected - 2) + + a.send(None) + + start = time.time() + res = wait([a, p.sentinel, b], 20) + delta = time.time() - start + + self.assertEqual(sorted_(res), sorted_([p.sentinel, b])) + self.assertLess(delta, 0.4) + + b.send(None) + + start = time.time() + res = wait([a, p.sentinel, b], 20) + delta = time.time() - start + + self.assertEqual(sorted_(res), sorted_([a, p.sentinel, b])) + self.assertLess(delta, 0.4) + + p.terminate() + p.join() + + def test_neg_timeout(self): + from multiprocessing.connection import wait + a, b = multiprocessing.Pipe() + t = time.time() + res = wait([a], timeout=-1) + t = time.time() - t + self.assertEqual(res, []) + self.assertLess(t, 1) + a.close() + b.close() + # # Issue 14151: Test invalid family on invalid environment # @@ -2430,6 +3394,38 @@ class TestInvalidFamily(unittest.TestCase): multiprocessing.connection.Listener('/var/test.pipe') # +# Issue 12098: check sys.flags of child matches that for parent +# + +class TestFlags(unittest.TestCase): + @classmethod + def run_in_grandchild(cls, conn): + conn.send(tuple(sys.flags)) + + @classmethod + def run_in_child(cls): + import json + r, w = multiprocessing.Pipe(duplex=False) + p = multiprocessing.Process(target=cls.run_in_grandchild, args=(w,)) + p.start() + grandchild_flags = r.recv() + p.join() + r.close() + w.close() + flags = (tuple(sys.flags), grandchild_flags) + print(json.dumps(flags)) + + def test_flags(self): + import json, subprocess + # start child process using unusual flags + prog = ('from test.test_multiprocessing import TestFlags; ' + + 'TestFlags.run_in_child()') + data = subprocess.check_output( + [sys.executable, '-E', '-S', '-O', '-c', prog]) + child_flags, grandchild_flags = json.loads(data.decode('ascii')) + self.assertEqual(child_flags, grandchild_flags) + +# # Test interaction with socket timeouts - see Issue #6056 # @@ -2480,59 +3476,122 @@ class TestNoForkBomb(unittest.TestCase): self.assertEqual('', err.decode('ascii')) # +# Issue #17555: ForkAwareThreadLock # + +class TestForkAwareThreadLock(unittest.TestCase): + # We recurisvely start processes. Issue #17555 meant that the + # after fork registry would get duplicate entries for the same + # lock. The size of the registry at generation n was ~2**n. + + @classmethod + def child(cls, n, conn): + if n > 1: + p = multiprocessing.Process(target=cls.child, args=(n-1, conn)) + p.start() + p.join() + else: + conn.send(len(util._afterfork_registry)) + conn.close() + + def test_lock(self): + r, w = multiprocessing.Pipe(False) + l = util.ForkAwareThreadLock() + old_size = len(util._afterfork_registry) + p = multiprocessing.Process(target=self.child, args=(5, w)) + p.start() + new_size = r.recv() + p.join() + self.assertLessEqual(new_size, old_size) + +# +# Issue #17097: EINTR should be ignored by recv(), send(), accept() etc # -testcases_other = [OtherTest, TestInvalidHandle, TestInitializers, - TestStdinBadfiledescriptor, TestInvalidFamily, - TestTimeouts, TestNoForkBomb] +class TestIgnoreEINTR(unittest.TestCase): + + @classmethod + def _test_ignore(cls, conn): + def handler(signum, frame): + pass + signal.signal(signal.SIGUSR1, handler) + conn.send('ready') + x = conn.recv() + conn.send(x) + conn.send_bytes(b'x'*(1024*1024)) # sending 1 MB should block + + @unittest.skipUnless(hasattr(signal, 'SIGUSR1'), 'requires SIGUSR1') + def test_ignore(self): + conn, child_conn = multiprocessing.Pipe() + try: + p = multiprocessing.Process(target=self._test_ignore, + args=(child_conn,)) + p.daemon = True + p.start() + child_conn.close() + self.assertEqual(conn.recv(), 'ready') + time.sleep(0.1) + os.kill(p.pid, signal.SIGUSR1) + time.sleep(0.1) + conn.send(1234) + self.assertEqual(conn.recv(), 1234) + time.sleep(0.1) + os.kill(p.pid, signal.SIGUSR1) + self.assertEqual(conn.recv_bytes(), b'x'*(1024*1024)) + time.sleep(0.1) + p.join() + finally: + conn.close() + + @classmethod + def _test_ignore_listener(cls, conn): + def handler(signum, frame): + pass + signal.signal(signal.SIGUSR1, handler) + l = multiprocessing.connection.Listener() + conn.send(l.address) + a = l.accept() + a.send('welcome') + + @unittest.skipUnless(hasattr(signal, 'SIGUSR1'), 'requires SIGUSR1') + def test_ignore_listener(self): + conn, child_conn = multiprocessing.Pipe() + try: + p = multiprocessing.Process(target=self._test_ignore_listener, + args=(child_conn,)) + p.daemon = True + p.start() + child_conn.close() + address = conn.recv() + time.sleep(0.1) + os.kill(p.pid, signal.SIGUSR1) + time.sleep(0.1) + client = multiprocessing.connection.Client(address) + self.assertEqual(client.recv(), 'welcome') + p.join() + finally: + conn.close() # # # -def test_main(run=None): +def setUpModule(): if sys.platform.startswith("linux"): try: lock = multiprocessing.RLock() except OSError: - raise unittest.SkipTest("OSError raises on RLock creation, see issue 3111!") - + raise unittest.SkipTest("OSError raises on RLock creation, " + "see issue 3111!") check_enough_semaphores() - - if run is None: - from test.support import run_unittest as run - util.get_temp_dir() # creates temp directory for use by all processes - multiprocessing.get_logger().setLevel(LOG_LEVEL) - ProcessesMixin.pool = multiprocessing.Pool(4) - ThreadsMixin.pool = multiprocessing.dummy.Pool(4) - ManagerMixin.manager.__init__() - ManagerMixin.manager.start() - ManagerMixin.pool = ManagerMixin.manager.Pool(4) - - testcases = ( - sorted(testcases_processes.values(), key=lambda tc:tc.__name__) + - sorted(testcases_threads.values(), key=lambda tc:tc.__name__) + - sorted(testcases_manager.values(), key=lambda tc:tc.__name__) + - testcases_other - ) - - loadTestsFromTestCase = unittest.defaultTestLoader.loadTestsFromTestCase - suite = unittest.TestSuite(loadTestsFromTestCase(tc) for tc in testcases) - run(suite) - - ThreadsMixin.pool.terminate() - ProcessesMixin.pool.terminate() - ManagerMixin.pool.terminate() - ManagerMixin.manager.shutdown() - del ProcessesMixin.pool, ThreadsMixin.pool, ManagerMixin.pool +def tearDownModule(): + # pause a bit so we don't get warning about dangling threads/processes + time.sleep(0.5) -def main(): - test_main(unittest.TextTestRunner(verbosity=2).run) if __name__ == '__main__': - main() + unittest.main() |