summaryrefslogtreecommitdiff
path: root/lib/ansible/executor
diff options
context:
space:
mode:
Diffstat (limited to 'lib/ansible/executor')
-rw-r--r--lib/ansible/executor/module_common.py8
-rw-r--r--lib/ansible/executor/task_queue_manager.py152
2 files changed, 147 insertions, 13 deletions
diff --git a/lib/ansible/executor/module_common.py b/lib/ansible/executor/module_common.py
index df2be06f70..1243dad932 100644
--- a/lib/ansible/executor/module_common.py
+++ b/lib/ansible/executor/module_common.py
@@ -37,7 +37,7 @@ from ansible.utils.unicode import to_bytes, to_unicode
# Must import strategy and use write_locks from there
# If we import write_locks directly then we end up binding a
# variable to the object and then it never gets updated.
-from ansible.plugins import strategy
+from ansible.executor.task_queue_manager import action_write_locks
try:
from __main__ import display
@@ -596,16 +596,16 @@ def _find_snippet_imports(module_name, module_data, module_path, module_args, ta
display.debug('ANSIBALLZ: using cached module: %s' % cached_module_filename)
zipdata = open(cached_module_filename, 'rb').read()
else:
- if module_name in strategy.action_write_locks:
+ if module_name in action_write_locks:
display.debug('ANSIBALLZ: Using lock for %s' % module_name)
- lock = strategy.action_write_locks[module_name]
+ lock = action_write_locks[module_name]
else:
# If the action plugin directly invokes the module (instead of
# going through a strategy) then we don't have a cross-process
# Lock specifically for this module. Use the "unexpected
# module" lock instead
display.debug('ANSIBALLZ: Using generic lock for %s' % module_name)
- lock = strategy.action_write_locks[None]
+ lock = action_write_locks[None]
display.debug('ANSIBALLZ: Acquiring lock')
with lock:
diff --git a/lib/ansible/executor/task_queue_manager.py b/lib/ansible/executor/task_queue_manager.py
index c3313ae50a..622db49bdc 100644
--- a/lib/ansible/executor/task_queue_manager.py
+++ b/lib/ansible/executor/task_queue_manager.py
@@ -22,14 +22,20 @@ __metaclass__ = type
import multiprocessing
import os
import tempfile
+import time
+
+from collections import deque
+from threading import Thread, Lock
from ansible import constants as C
from ansible.errors import AnsibleError
from ansible.executor.play_iterator import PlayIterator
+from ansible.executor.process.worker import WorkerProcess
from ansible.executor.stats import AggregateStats
+from ansible.module_utils.facts import Facts
from ansible.playbook.block import Block
from ansible.playbook.play_context import PlayContext
-from ansible.plugins import callback_loader, strategy_loader, module_loader
+from ansible.plugins import action_loader, callback_loader, connection_loader, filter_loader, lookup_loader, module_loader, strategy_loader, test_loader
from ansible.template import Templar
from ansible.vars.hostvars import HostVars
from ansible.plugins.callback import CallbackBase
@@ -46,6 +52,42 @@ except ImportError:
__all__ = ['TaskQueueManager']
+if 'action_write_locks' not in globals():
+ # Do not initialize this more than once because it seems to bash
+ # the existing one. multiprocessing must be reloading the module
+ # when it forks?
+ action_write_locks = dict()
+
+ # Below is a Lock for use when we weren't expecting a named module.
+ # It gets used when an action plugin directly invokes a module instead
+ # of going through the strategies. Slightly less efficient as all
+ # processes with unexpected module names will wait on this lock
+ action_write_locks[None] = Lock()
+
+ # These plugins are called directly by action plugins (not going through
+ # a strategy). We precreate them here as an optimization
+ mods = set(p['name'] for p in Facts.PKG_MGRS)
+ mods.update(('copy', 'file', 'setup', 'slurp', 'stat'))
+ for mod_name in mods:
+ action_write_locks[mod_name] = Lock()
+
+# TODO: this should probably be in the plugins/__init__.py, with
+# a smarter mechanism to set all of the attributes based on
+# the loaders created there
+class SharedPluginLoaderObj:
+ '''
+ A simple object to make pass the various plugin loaders to
+ the forked processes over the queue easier
+ '''
+ def __init__(self):
+ self.action_loader = action_loader
+ self.connection_loader = connection_loader
+ self.filter_loader = filter_loader
+ self.test_loader = test_loader
+ self.lookup_loader = lookup_loader
+ self.module_loader = module_loader
+
+
class TaskQueueManager:
'''
@@ -98,12 +140,102 @@ class TaskQueueManager:
self._failed_hosts = dict()
self._unreachable_hosts = dict()
+ self._workers = []
+ self._queue_thread = None
+ self._queued_tasks = deque()
+ self._queued_tasks_lock = Lock()
+
self._final_q = multiprocessing.Queue()
# A temporary file (opened pre-fork) used by connection
# plugins for inter-process locking.
self._connection_lockfile = tempfile.TemporaryFile()
+ def _queue_thread_main(self):
+ global action_write_locks
+
+ cur_worker = 0
+ while not self._terminated:
+ try:
+ self._queued_tasks_lock.acquire()
+ (host, task, task_vars, play_context) = self._queued_tasks.pop()
+ self._queued_tasks_lock.release()
+ except IndexError:
+ self._queued_tasks_lock.release()
+ time.sleep(0.001)
+ continue
+
+ # Add a write lock for tasks.
+ # Maybe this should be added somewhere further up the call stack but
+ # this is the earliest in the code where we have task (1) extracted
+ # into its own variable and (2) there's only a single code path
+ # leading to the module being run. This is called by three
+ # functions: __init__.py::_do_handler_run(), linear.py::run(), and
+ # free.py::run() so we'd have to add to all three to do it there.
+ # The next common higher level is __init__.py::run() and that has
+ # tasks inside of play_iterator so we'd have to extract them to do it
+ # there.
+
+ if task.action not in action_write_locks:
+ display.debug('Creating lock for %s' % task.action)
+ action_write_locks[task.action] = Lock()
+
+ # create a dummy object with plugin loaders set as an easier
+ # way to share them with the forked processes
+ shared_loader_obj = SharedPluginLoaderObj()
+
+ try:
+ queued = False
+ starting_worker = cur_worker
+ while True:
+ try:
+ (worker_prc, rslt_q) = self._workers[cur_worker]
+ except IndexError:
+ cur_worker = 0
+ continue
+
+ if worker_prc is None or not worker_prc.is_alive():
+ worker_prc = WorkerProcess(self._final_q, task_vars, host, task, play_context, self._loader, self._variable_manager, shared_loader_obj)
+ self._workers[cur_worker][0] = worker_prc
+ worker_prc.start()
+ display.debug("worker is %d (out of %d available)" % (cur_worker+1, len(self._workers)))
+ queued = True
+
+ cur_worker += 1
+ if cur_worker >= len(self._workers):
+ cur_worker = 0
+
+ if queued:
+ break
+ elif cur_worker == starting_worker:
+ if self._terminated:
+ break
+ time.sleep(0.0001)
+
+ # if we didn't queue it, we must have broken out inside the
+ # while loop meaning we're terminated, so break again
+ if not queued:
+ break
+
+ except (EOFError, IOError, AssertionError) as e:
+ # most likely an abort
+ display.debug("got an error while queuing: %s" % e)
+ break
+
+ time.sleep(0.0001)
+
+ def queue_task(self, host, task, task_vars, play_context):
+ self._queued_tasks_lock.acquire()
+ self._queued_tasks.append((host, task, task_vars, play_context))
+ self._queued_tasks_lock.release()
+
+ def queue_multiple_tasks(self, items, play_context):
+ self._queued_tasks_lock.acquire()
+ for item in items:
+ (host, task, task_vars) = item
+ self._queued_tasks.append((host, task, task_vars, play_context))
+ self._queued_tasks_lock.release()
+
def _initialize_processes(self, num):
self._workers = []
@@ -111,6 +243,9 @@ class TaskQueueManager:
rslt_q = multiprocessing.Queue()
self._workers.append([None, rslt_q])
+ self._queue_thread = Thread(target=self._queue_thread_main)
+ self._queue_thread.start()
+
def _initialize_notified_handlers(self, play):
'''
Clears and initializes the shared notified handlers dict with entries
@@ -294,14 +429,13 @@ class TaskQueueManager:
self._cleanup_processes()
def _cleanup_processes(self):
- if hasattr(self, '_workers'):
- for (worker_prc, rslt_q) in self._workers:
- rslt_q.close()
- if worker_prc and worker_prc.is_alive():
- try:
- worker_prc.terminate()
- except AttributeError:
- pass
+ for (worker_prc, rslt_q) in self._workers:
+ rslt_q.close()
+ if worker_prc and worker_prc.is_alive():
+ try:
+ worker_prc.terminate()
+ except AttributeError:
+ pass
def clear_failed_hosts(self):
self._failed_hosts = dict()