Diffstat (limited to 'lib/ansible/executor')
-rw-r--r--  lib/ansible/executor/module_common.py      |   8
-rw-r--r--  lib/ansible/executor/task_queue_manager.py | 152
2 files changed, 147 insertions, 13 deletions
diff --git a/lib/ansible/executor/module_common.py b/lib/ansible/executor/module_common.py
index df2be06f70..1243dad932 100644
--- a/lib/ansible/executor/module_common.py
+++ b/lib/ansible/executor/module_common.py
@@ -37,7 +37,7 @@ from ansible.utils.unicode import to_bytes, to_unicode
 # Must import strategy and use write_locks from there
 # If we import write_locks directly then we end up binding a
 # variable to the object and then it never gets updated.
-from ansible.plugins import strategy
+from ansible.executor.task_queue_manager import action_write_locks
 
 try:
     from __main__ import display
@@ -596,16 +596,16 @@ def _find_snippet_imports(module_name, module_data, module_path, module_args, ta
             display.debug('ANSIBALLZ: using cached module: %s' % cached_module_filename)
             zipdata = open(cached_module_filename, 'rb').read()
         else:
-            if module_name in strategy.action_write_locks:
+            if module_name in action_write_locks:
                 display.debug('ANSIBALLZ: Using lock for %s' % module_name)
-                lock = strategy.action_write_locks[module_name]
+                lock = action_write_locks[module_name]
             else:
                 # If the action plugin directly invokes the module (instead of
                 # going through a strategy) then we don't have a cross-process
                 # Lock specifically for this module.  Use the "unexpected
                 # module" lock instead
                 display.debug('ANSIBALLZ: Using generic lock for %s' % module_name)
-                lock = strategy.action_write_locks[None]
+                lock = action_write_locks[None]
 
             display.debug('ANSIBALLZ: Acquiring lock')
             with lock:
diff --git a/lib/ansible/executor/task_queue_manager.py b/lib/ansible/executor/task_queue_manager.py
index c3313ae50a..622db49bdc 100644
--- a/lib/ansible/executor/task_queue_manager.py
+++ b/lib/ansible/executor/task_queue_manager.py
@@ -22,14 +22,20 @@ __metaclass__ = type
 import multiprocessing
 import os
 import tempfile
+import time
+
+from collections import deque
+from threading import Thread, Lock
 
 from ansible import constants as C
 from ansible.errors import AnsibleError
 from ansible.executor.play_iterator import PlayIterator
+from ansible.executor.process.worker import WorkerProcess
 from ansible.executor.stats import AggregateStats
+from ansible.module_utils.facts import Facts
 from ansible.playbook.block import Block
 from ansible.playbook.play_context import PlayContext
-from ansible.plugins import callback_loader, strategy_loader, module_loader
+from ansible.plugins import action_loader, callback_loader, connection_loader, filter_loader, lookup_loader, module_loader, strategy_loader, test_loader
 from ansible.template import Templar
 from ansible.vars.hostvars import HostVars
 from ansible.plugins.callback import CallbackBase
@@ -46,6 +52,42 @@ except ImportError:
 
 __all__ = ['TaskQueueManager']
 
+if 'action_write_locks' not in globals():
+    # Do not initialize this more than once because it seems to bash
+    # the existing one.  multiprocessing must be reloading the module
+    # when it forks?
+    action_write_locks = dict()
+
+    # Below is a Lock for use when we weren't expecting a named module.
+    # It gets used when an action plugin directly invokes a module instead
+    # of going through the strategies.  Slightly less efficient as all
+    # processes with unexpected module names will wait on this lock
+    action_write_locks[None] = Lock()
+
+    # These plugins are called directly by action plugins (not going through
+    # a strategy).  We precreate them here as an optimization
+    mods = set(p['name'] for p in Facts.PKG_MGRS)
+    mods.update(('copy', 'file', 'setup', 'slurp', 'stat'))
+    for mod_name in mods:
+        action_write_locks[mod_name] = Lock()
+
+# TODO: this should probably be in the plugins/__init__.py, with
+#       a smarter mechanism to set all of the attributes based on
+#       the loaders created there
+class SharedPluginLoaderObj:
+    '''
+    A simple object to make pass the various plugin loaders to
+    the forked processes over the queue easier
+    '''
+    def __init__(self):
+        self.action_loader = action_loader
+        self.connection_loader = connection_loader
+        self.filter_loader = filter_loader
+        self.test_loader = test_loader
+        self.lookup_loader = lookup_loader
+        self.module_loader = module_loader
+
+
 class TaskQueueManager:
 
     '''
@@ -98,12 +140,102 @@ class TaskQueueManager:
         self._failed_hosts = dict()
         self._unreachable_hosts = dict()
 
+        self._workers = []
+        self._queue_thread = None
+        self._queued_tasks = deque()
+        self._queued_tasks_lock = Lock()
+
         self._final_q = multiprocessing.Queue()
 
         # A temporary file (opened pre-fork) used by connection
        # plugins for inter-process locking.
         self._connection_lockfile = tempfile.TemporaryFile()
 
+    def _queue_thread_main(self):
+        global action_write_locks
+
+        cur_worker = 0
+        while not self._terminated:
+            try:
+                self._queued_tasks_lock.acquire()
+                (host, task, task_vars, play_context) = self._queued_tasks.pop()
+                self._queued_tasks_lock.release()
+            except IndexError:
+                self._queued_tasks_lock.release()
+                time.sleep(0.001)
+                continue
+
+            # Add a write lock for tasks.
+            # Maybe this should be added somewhere further up the call stack but
+            # this is the earliest in the code where we have task (1) extracted
+            # into its own variable and (2) there's only a single code path
+            # leading to the module being run.  This is called by three
+            # functions: __init__.py::_do_handler_run(), linear.py::run(), and
+            # free.py::run() so we'd have to add to all three to do it there.
+            # The next common higher level is __init__.py::run() and that has
+            # tasks inside of play_iterator so we'd have to extract them to do it
+            # there.
+
+            if task.action not in action_write_locks:
+                display.debug('Creating lock for %s' % task.action)
+                action_write_locks[task.action] = Lock()
+
+            # create a dummy object with plugin loaders set as an easier
+            # way to share them with the forked processes
+            shared_loader_obj = SharedPluginLoaderObj()
+
+            try:
+                queued = False
+                starting_worker = cur_worker
+                while True:
+                    try:
+                        (worker_prc, rslt_q) = self._workers[cur_worker]
+                    except IndexError:
+                        cur_worker = 0
+                        continue
+
+                    if worker_prc is None or not worker_prc.is_alive():
+                        worker_prc = WorkerProcess(self._final_q, task_vars, host, task, play_context, self._loader, self._variable_manager, shared_loader_obj)
+                        self._workers[cur_worker][0] = worker_prc
+                        worker_prc.start()
+                        display.debug("worker is %d (out of %d available)" % (cur_worker+1, len(self._workers)))
+                        queued = True
+
+                    cur_worker += 1
+                    if cur_worker >= len(self._workers):
+                        cur_worker = 0
+
+                    if queued:
+                        break
+                    elif cur_worker == starting_worker:
+                        if self._terminated:
+                            break
+                        time.sleep(0.0001)
+
+                # if we didn't queue it, we must have broken out inside the
+                # while loop meaning we're terminated, so break again
+                if not queued:
+                    break
+
+            except (EOFError, IOError, AssertionError) as e:
+                # most likely an abort
+                display.debug("got an error while queuing: %s" % e)
+                break
+
+            time.sleep(0.0001)
+
+    def queue_task(self, host, task, task_vars, play_context):
+        self._queued_tasks_lock.acquire()
+        self._queued_tasks.append((host, task, task_vars, play_context))
+        self._queued_tasks_lock.release()
+
+    def queue_multiple_tasks(self, items, play_context):
+        self._queued_tasks_lock.acquire()
+        for item in items:
+            (host, task, task_vars) = item
+            self._queued_tasks.append((host, task, task_vars, play_context))
+        self._queued_tasks_lock.release()
+
     def _initialize_processes(self, num):
         self._workers = []
 
@@ -111,6 +243,9 @@
             rslt_q = multiprocessing.Queue()
             self._workers.append([None, rslt_q])
 
+        self._queue_thread = Thread(target=self._queue_thread_main)
+        self._queue_thread.start()
+
     def _initialize_notified_handlers(self, play):
         '''
         Clears and initializes the shared notified handlers dict with entries
@@ -294,14 +429,13 @@
        self._cleanup_processes()
 
     def _cleanup_processes(self):
-        if hasattr(self, '_workers'):
-            for (worker_prc, rslt_q) in self._workers:
-                rslt_q.close()
-                if worker_prc and worker_prc.is_alive():
-                    try:
-                        worker_prc.terminate()
-                    except AttributeError:
-                        pass
+        for (worker_prc, rslt_q) in self._workers:
+            rslt_q.close()
+            if worker_prc and worker_prc.is_alive():
+                try:
+                    worker_prc.terminate()
+                except AttributeError:
+                    pass
 
     def clear_failed_hosts(self):
         self._failed_hosts = dict()
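The core change in task_queue_manager.py is that callers no longer fork a WorkerProcess themselves: they hand (host, task, task_vars, play_context) tuples to the TaskQueueManager via queue_task()/queue_multiple_tasks(), and a single background thread pops them off a lock-protected deque and assigns each one to the first free worker slot, round-robin. Below is a minimal, standalone sketch of that producer/consumer shape, assuming plain threads as stand-in workers; the MiniQueueManager name and the print() "work" are illustrative only and not part of the commit.

# Illustrative sketch only -- plain threads stand in for Ansible's WorkerProcess.
import time
from collections import deque
from threading import Lock, Thread

class MiniQueueManager:
    def __init__(self, num_workers=2):
        self._queued_tasks = deque()
        self._queued_tasks_lock = Lock()
        self._workers = [[None] for _ in range(num_workers)]  # one slot per worker
        self._terminated = False
        self._queue_thread = Thread(target=self._queue_thread_main)
        self._queue_thread.start()

    def queue_task(self, item):
        # producer side: callers enqueue work instead of starting workers themselves
        with self._queued_tasks_lock:
            self._queued_tasks.append(item)

    def _queue_thread_main(self):
        cur_worker = 0
        while not self._terminated:
            with self._queued_tasks_lock:
                item = self._queued_tasks.pop() if self._queued_tasks else None
            if item is None:
                time.sleep(0.001)
                continue
            # round-robin over the worker slots until one is free (empty or finished)
            while True:
                slot = self._workers[cur_worker]
                free = slot[0] is None or not slot[0].is_alive()
                if free:
                    slot[0] = Thread(target=print, args=("running", item))
                    slot[0].start()
                cur_worker = (cur_worker + 1) % len(self._workers)
                if free:
                    break
                time.sleep(0.0001)

    def terminate(self):
        self._terminated = True
        self._queue_thread.join()

qm = MiniQueueManager()
for i in range(5):
    qm.queue_task("task-%d" % i)
time.sleep(0.1)
qm.terminate()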
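On the module_common.py side the only change is where the per-module write locks come from: action_write_locks now lives in task_queue_manager.py so that strategies and directly-invoked action plugins share one table. The lookup pattern is unchanged: use the pre-created lock for a known module name, otherwise fall back to the generic lock stored under the None key. A small self-contained sketch of that pattern follows; the write_cached_module helper and cache paths are made up for illustration and do not appear in the commit.

# Sketch of the action_write_locks lookup pattern; write_cached_module and
# cache_dir below are illustrative names, not taken from the commit.
import os
import tempfile
from threading import Lock

action_write_locks = {None: Lock()}            # generic fallback lock
for mod_name in ('copy', 'file', 'setup', 'slurp', 'stat'):
    action_write_locks[mod_name] = Lock()      # pre-created, commonly used modules

def write_cached_module(module_name, data, cache_dir):
    # Pick the module-specific lock if one was pre-created, otherwise serialize
    # on the shared "unexpected module" lock keyed by None.
    lock = action_write_locks.get(module_name, action_write_locks[None])
    with lock:
        path = os.path.join(cache_dir, module_name)
        with open(path, 'wb') as f:
            f.write(data)
    return path

cache_dir = tempfile.mkdtemp()
print(write_cached_module('setup', b'zipdata...', cache_dir))
print(write_cached_module('unexpected_module', b'zipdata...', cache_dir))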