diff options
Diffstat (limited to 'Lib/multiprocessing/pool.py')
-rw-r--r-- | Lib/multiprocessing/pool.py | 133 |
1 files changed, 81 insertions, 52 deletions
diff --git a/Lib/multiprocessing/pool.py b/Lib/multiprocessing/pool.py index 0f2dab48eb..1cb2d95678 100644 --- a/Lib/multiprocessing/pool.py +++ b/Lib/multiprocessing/pool.py @@ -7,7 +7,7 @@ # Licensed to PSF under a Contributor Agreement. # -__all__ = ['Pool'] +__all__ = ['Pool', 'ThreadPool'] # # Imports @@ -17,10 +17,14 @@ import threading import queue import itertools import collections +import os import time +import traceback -from multiprocessing import Process, cpu_count, TimeoutError -from multiprocessing.util import Finalize, debug +# If threading is available then ThreadPool should be provided. Therefore +# we avoid top-level imports which are liable to fail on some systems. +from . import util +from . import get_context, cpu_count, TimeoutError # # Constants representing the state of a pool @@ -43,6 +47,29 @@ def starmapstar(args): return list(itertools.starmap(args[0], args[1])) # +# Hack to embed stringification of remote traceback in local traceback +# + +class RemoteTraceback(Exception): + def __init__(self, tb): + self.tb = tb + def __str__(self): + return self.tb + +class ExceptionWithTraceback: + def __init__(self, exc, tb): + tb = traceback.format_exception(type(exc), exc, tb) + tb = ''.join(tb) + self.exc = exc + self.tb = '\n"""\n%s"""' % tb + def __reduce__(self): + return rebuild_exc, (self.exc, self.tb) + +def rebuild_exc(exc, tb): + exc.__cause__ = RemoteTraceback(tb) + return exc + +# # Code run by worker processes # @@ -78,28 +105,29 @@ def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None): while maxtasks is None or (maxtasks and completed < maxtasks): try: task = get() - except (EOFError, IOError): - debug('worker got EOFError or IOError -- exiting') + except (EOFError, OSError): + util.debug('worker got EOFError or OSError -- exiting') break if task is None: - debug('worker got sentinel -- exiting') + util.debug('worker got sentinel -- exiting') break job, i, func, args, kwds = task try: result = (True, func(*args, **kwds)) except Exception as e: + e = ExceptionWithTraceback(e, e.__traceback__) result = (False, e) try: put((job, i, result)) except Exception as e: wrapped = MaybeEncodingError(e, result[1]) - debug("Possible encoding error while sending result: %s" % ( + util.debug("Possible encoding error while sending result: %s" % ( wrapped)) put((job, i, (False, wrapped))) completed += 1 - debug('worker exiting after %d tasks' % completed) + util.debug('worker exiting after %d tasks' % completed) # # Class representing a process pool @@ -109,10 +137,12 @@ class Pool(object): ''' Class which supports an async version of applying functions to arguments. ''' - Process = Process + def Process(self, *args, **kwds): + return self._ctx.Process(*args, **kwds) def __init__(self, processes=None, initializer=None, initargs=(), - maxtasksperchild=None): + maxtasksperchild=None, context=None): + self._ctx = context or get_context() self._setup_queues() self._taskqueue = queue.Queue() self._cache = {} @@ -122,10 +152,7 @@ class Pool(object): self._initargs = initargs if processes is None: - try: - processes = cpu_count() - except NotImplementedError: - processes = 1 + processes = os.cpu_count() or 1 if processes < 1: raise ValueError("Number of processes must be at least 1") @@ -162,7 +189,7 @@ class Pool(object): self._result_handler._state = RUN self._result_handler.start() - self._terminate = Finalize( + self._terminate = util.Finalize( self, self._terminate_pool, args=(self._taskqueue, self._inqueue, self._outqueue, self._pool, self._worker_handler, self._task_handler, @@ -179,7 +206,7 @@ class Pool(object): worker = self._pool[i] if worker.exitcode is not None: # worker exited - debug('cleaning up worker %d' % i) + util.debug('cleaning up worker %d' % i) worker.join() cleaned = True del self._pool[i] @@ -199,7 +226,7 @@ class Pool(object): w.name = w.name.replace('Process', 'PoolWorker') w.daemon = True w.start() - debug('added worker') + util.debug('added worker') def _maintain_pool(self): """Clean up any exited workers and start replacements for them. @@ -208,9 +235,8 @@ class Pool(object): self._repopulate_pool() def _setup_queues(self): - from .queues import SimpleQueue - self._inqueue = SimpleQueue() - self._outqueue = SimpleQueue() + self._inqueue = self._ctx.SimpleQueue() + self._outqueue = self._ctx.SimpleQueue() self._quick_put = self._inqueue._writer.send self._quick_get = self._outqueue._reader.recv @@ -336,7 +362,7 @@ class Pool(object): time.sleep(0.1) # send sentinel to stop workers pool._taskqueue.put(None) - debug('worker handler exiting') + util.debug('worker handler exiting') @staticmethod def _handle_tasks(taskqueue, put, outqueue, pool, cache): @@ -346,7 +372,7 @@ class Pool(object): i = -1 for i, task in enumerate(taskseq): if thread._state: - debug('task handler found thread._state != RUN') + util.debug('task handler found thread._state != RUN') break try: put(task) @@ -358,27 +384,27 @@ class Pool(object): pass else: if set_length: - debug('doing set_length()') + util.debug('doing set_length()') set_length(i+1) continue break else: - debug('task handler got sentinel') + util.debug('task handler got sentinel') try: # tell result handler to finish when cache is empty - debug('task handler sending sentinel to result handler') + util.debug('task handler sending sentinel to result handler') outqueue.put(None) # tell workers there is no more work - debug('task handler sending sentinel to workers') + util.debug('task handler sending sentinel to workers') for p in pool: put(None) - except IOError: - debug('task handler got IOError when sending sentinels') + except OSError: + util.debug('task handler got OSError when sending sentinels') - debug('task handler exiting') + util.debug('task handler exiting') @staticmethod def _handle_results(outqueue, get, cache): @@ -387,17 +413,17 @@ class Pool(object): while 1: try: task = get() - except (IOError, EOFError): - debug('result handler got EOFError/IOError -- exiting') + except (OSError, EOFError): + util.debug('result handler got EOFError/OSError -- exiting') return if thread._state: assert thread._state == TERMINATE - debug('result handler found thread._state=TERMINATE') + util.debug('result handler found thread._state=TERMINATE') break if task is None: - debug('result handler got sentinel') + util.debug('result handler got sentinel') break job, i, obj = task @@ -409,12 +435,12 @@ class Pool(object): while cache and thread._state != TERMINATE: try: task = get() - except (IOError, EOFError): - debug('result handler got EOFError/IOError -- exiting') + except (OSError, EOFError): + util.debug('result handler got EOFError/OSError -- exiting') return if task is None: - debug('result handler ignoring extra sentinel') + util.debug('result handler ignoring extra sentinel') continue job, i, obj = task try: @@ -423,7 +449,7 @@ class Pool(object): pass if hasattr(outqueue, '_reader'): - debug('ensuring that outqueue is not full') + util.debug('ensuring that outqueue is not full') # If we don't make room available in outqueue then # attempts to add the sentinel (None) to outqueue may # block. There is guaranteed to be no more than 2 sentinels. @@ -432,10 +458,10 @@ class Pool(object): if not outqueue._reader.poll(): break get() - except (IOError, EOFError): + except (OSError, EOFError): pass - debug('result handler exiting: len(cache)=%s, thread._state=%s', + util.debug('result handler exiting: len(cache)=%s, thread._state=%s', len(cache), thread._state) @staticmethod @@ -453,19 +479,19 @@ class Pool(object): ) def close(self): - debug('closing pool') + util.debug('closing pool') if self._state == RUN: self._state = CLOSE self._worker_handler._state = CLOSE def terminate(self): - debug('terminating pool') + util.debug('terminating pool') self._state = TERMINATE self._worker_handler._state = TERMINATE self._terminate() def join(self): - debug('joining pool') + util.debug('joining pool') assert self._state in (CLOSE, TERMINATE) self._worker_handler.join() self._task_handler.join() @@ -476,7 +502,7 @@ class Pool(object): @staticmethod def _help_stuff_finish(inqueue, task_handler, size): # task_handler may be blocked trying to put items on inqueue - debug('removing tasks from inqueue until task handler finished') + util.debug('removing tasks from inqueue until task handler finished') inqueue._rlock.acquire() while task_handler.is_alive() and inqueue._reader.poll(): inqueue._reader.recv() @@ -486,12 +512,12 @@ class Pool(object): def _terminate_pool(cls, taskqueue, inqueue, outqueue, pool, worker_handler, task_handler, result_handler, cache): # this is guaranteed to only be called once - debug('finalizing pool') + util.debug('finalizing pool') worker_handler._state = TERMINATE task_handler._state = TERMINATE - debug('helping task handler/workers to finish') + util.debug('helping task handler/workers to finish') cls._help_stuff_finish(inqueue, task_handler, len(pool)) assert result_handler.is_alive() or len(cache) == 0 @@ -501,31 +527,31 @@ class Pool(object): # We must wait for the worker handler to exit before terminating # workers because we don't want workers to be restarted behind our back. - debug('joining worker handler') + util.debug('joining worker handler') if threading.current_thread() is not worker_handler: worker_handler.join() # Terminate workers which haven't already finished. if pool and hasattr(pool[0], 'terminate'): - debug('terminating workers') + util.debug('terminating workers') for p in pool: if p.exitcode is None: p.terminate() - debug('joining task handler') + util.debug('joining task handler') if threading.current_thread() is not task_handler: task_handler.join() - debug('joining result handler') + util.debug('joining result handler') if threading.current_thread() is not result_handler: result_handler.join() if pool and hasattr(pool[0], 'terminate'): - debug('joining pool workers') + util.debug('joining pool workers') for p in pool: if p.is_alive(): # worker has not yet exited - debug('cleaning up worker %d' % p.pid) + util.debug('cleaning up worker %d' % p.pid) p.join() def __enter__(self): @@ -711,7 +737,10 @@ class IMapUnorderedIterator(IMapIterator): class ThreadPool(Pool): - from .dummy import Process + @staticmethod + def Process(*args, **kwds): + from .dummy import Process + return Process(*args, **kwds) def __init__(self, processes=None, initializer=None, initargs=()): Pool.__init__(self, processes, initializer, initargs) |