diff options
author | James Cammarata <jimi@sngx.net> | 2016-05-27 12:22:57 -0400 |
---|---|---|
committer | James Cammarata <jimi@sngx.net> | 2016-05-29 08:35:06 -0400 |
commit | d93f0e0d9bf2e77b71e16298b95c3180f0c8d332 (patch) | |
tree | 0eb8f315842a27ae61c27ca492e008ed2148ca5b | |
parent | d81b9ca29e1e2098c1a12b0fc5892b66706acb20 (diff) | |
download | ansible_chain_map.tar.gz |
Create and use AnsibleChainMap for the master dict of variablesansible_chain_map
-rw-r--r-- | lib/ansible/executor/task_executor.py | 20 | ||||
-rw-r--r-- | lib/ansible/playbook/play_context.py | 4 | ||||
-rw-r--r-- | lib/ansible/template/__init__.py | 3 | ||||
-rw-r--r-- | lib/ansible/vars/__init__.py | 54 | ||||
-rw-r--r-- | lib/ansible/vars/chain_map.py | 108 | ||||
-rw-r--r-- | lib/ansible/vars/hostvars.py | 2 | ||||
-rw-r--r-- | test/units/executor/test_task_executor.py | 30 |
7 files changed, 172 insertions, 49 deletions
diff --git a/lib/ansible/executor/task_executor.py b/lib/ansible/executor/task_executor.py index f35ba897f1..51ec0a5155 100644 --- a/lib/ansible/executor/task_executor.py +++ b/lib/ansible/executor/task_executor.py @@ -26,6 +26,8 @@ import sys import time import traceback +from collections import MutableMapping + from ansible.compat.six import iteritems, string_types, binary_type from ansible import constants as C @@ -434,13 +436,13 @@ class TaskExecutor: if delay < 0: delay = 1 - # make a copy of the job vars here, in case we need to update them - # with the registered variable value later on when testing conditions - vars_copy = variables.copy() - display.debug("starting attempt loop") result = None for attempt in range(1, retries + 1): + # make a copy of the job vars here, in case we need to update them + # with the registered variable value later on when testing conditions + vars_copy = variables.copy() + display.debug("running the handler") try: result = self._handler.run(task_vars=variables) @@ -454,7 +456,7 @@ class TaskExecutor: # update the local copy of vars with the registered value, if specified, # or any facts which may have been generated by the module execution if self._task.register: - vars_copy[self._task.register] = wrap_var(result.copy()) + vars_copy.push({self._task.register: wrap_var(result.copy())}) if self._task.async > 0: # the async_wrapper module returns dumped JSON via its stdout @@ -490,7 +492,7 @@ class TaskExecutor: return failed_when_result if 'ansible_facts' in result: - vars_copy.update(result['ansible_facts']) + vars_copy.push(result['ansible_facts']) # set the failed property if the result has a non-zero rc. This will be # overridden below if the failed_when property is set @@ -525,10 +527,10 @@ class TaskExecutor: # do the final update of the local variables here, for both registered # values and any facts which may have been created if self._task.register: - variables[self._task.register] = wrap_var(result) + variables.push({self._task.register: wrap_var(result)}) if 'ansible_facts' in result: - variables.update(result['ansible_facts']) + variables.push(result['ansible_facts']) # save the notification target in the result, if it was specified, as # this task may be running in a loop in which case the notification @@ -609,7 +611,7 @@ class TaskExecutor: # now replace the interpreter values with those that may have come # from the delegated-to host delegated_vars = variables.get('ansible_delegated_vars', dict()).get(self._task.delegate_to, dict()) - if isinstance(delegated_vars, dict): + if isinstance(delegated_vars, MutableMapping): for i in delegated_vars: if isinstance(i, string_types) and i.startswith("ansible_") and i.endswith("_interpreter"): variables[i] = delegated_vars[i] diff --git a/lib/ansible/playbook/play_context.py b/lib/ansible/playbook/play_context.py index a79b4e1988..9783a6b574 100644 --- a/lib/ansible/playbook/play_context.py +++ b/lib/ansible/playbook/play_context.py @@ -26,6 +26,8 @@ import random import re import string +from collections import MutableMapping + from ansible.compat.six import iteritems, string_types from ansible import constants as C from ansible.errors import AnsibleError @@ -375,7 +377,7 @@ class PlayContext(Base): continue # if delegation task ONLY use delegated host vars, avoid delegated FOR host vars if task.delegate_to is not None: - if isinstance(delegated_vars, dict) and variable_name in delegated_vars: + if isinstance(delegated_vars, MutableMapping) and variable_name in delegated_vars: setattr(new_info, attr, delegated_vars[variable_name]) attrs_considered.append(attr) elif variable_name in variables: diff --git a/lib/ansible/template/__init__.py b/lib/ansible/template/__init__.py index 4e1f7b7fd8..978d5fae3c 100644 --- a/lib/ansible/template/__init__.py +++ b/lib/ansible/template/__init__.py @@ -25,6 +25,7 @@ import os import re from io import StringIO +from collections import MutableMapping from ansible.compat.six import string_types, text_type, binary_type from jinja2 import Environment @@ -273,7 +274,7 @@ class Templar: are being changed. ''' - assert isinstance(variables, dict) + assert isinstance(variables, MutableMapping) self._available_variables = variables self._cached_result = {} diff --git a/lib/ansible/vars/__init__.py b/lib/ansible/vars/__init__.py index e1f569abd2..0e43230310 100644 --- a/lib/ansible/vars/__init__.py +++ b/lib/ansible/vars/__init__.py @@ -42,6 +42,7 @@ from ansible.plugins.cache import FactCache from ansible.template import Templar from ansible.utils.listify import listify_lookup_plugin_terms from ansible.utils.vars import combine_vars +from ansible.vars.chain_map import AnsibleChainMap from ansible.vars.unsafe_proxy import wrap_var try: @@ -212,7 +213,7 @@ class VariableManager: display.debug("vars are cached, returning them now") return VARIABLE_CACHE[cache_entry] - all_vars = dict() + all_vars = AnsibleChainMap() magic_variables = self._get_magic_variables( loader=loader, play=play, @@ -226,13 +227,13 @@ class VariableManager: # first we compile any vars specified in defaults/main.yml # for all roles within the specified play for role in play.get_roles(): - all_vars = combine_vars(all_vars, role.get_default_vars()) + all_vars.push(role.get_default_vars()) # if we have a task in this context, and that task has a role, make # sure it sees its defaults above any other roles, as we previously # (v1) made sure each task had a copy of its roles default vars if task and task._role is not None: - all_vars = combine_vars(all_vars, task._role.get_default_vars(dep_chain=task._block.get_dep_chain())) + all_vars.push(task._role.get_default_vars(dep_chain=task._block.get_dep_chain())) if host: # next, if a host is specified, we load any vars from group_vars @@ -243,20 +244,20 @@ class VariableManager: if 'all' in self._group_vars_files: data = preprocess_vars(self._group_vars_files['all']) for item in data: - all_vars = combine_vars(all_vars, item) + all_vars.push(item) # we merge in vars from groups specified in the inventory (INI or script) - all_vars = combine_vars(all_vars, host.get_group_vars()) + all_vars.push(host.get_group_vars()) for group in sorted(host.get_groups(), key=lambda g: g.depth): if group.name in self._group_vars_files and group.name != 'all': for data in self._group_vars_files[group.name]: data = preprocess_vars(data) for item in data: - all_vars = combine_vars(all_vars, item) + all_vars.push(item) # then we merge in vars from the host specified in the inventory (INI or script) - all_vars = combine_vars(all_vars, host.get_vars()) + all_vars.push(host.get_vars()) # then we merge in the host_vars/<hostname> file, if it exists host_name = host.get_name() @@ -264,17 +265,17 @@ class VariableManager: for data in self._host_vars_files[host_name]: data = preprocess_vars(data) for item in data: - all_vars = combine_vars(all_vars, item) + all_vars.push(item) # finally, the facts caches for this host, if it exists try: host_facts = wrap_var(self._fact_cache.get(host.name, dict())) - all_vars = combine_vars(all_vars, host_facts) + all_vars.push(host_facts) except KeyError: pass if play: - all_vars = combine_vars(all_vars, play.get_vars()) + all_vars.push(play.get_vars()) for vars_file_item in play.get_vars_files(): # create a set of temporary vars here, which incorporate the extra @@ -300,7 +301,7 @@ class VariableManager: data = preprocess_vars(loader.load_from_file(vars_file)) if data is not None: for item in data: - all_vars = combine_vars(all_vars, item) + all_vars.push(item) break except AnsibleFileNotFound as e: # we continue on loader failures @@ -320,43 +321,43 @@ class VariableManager: if not C.DEFAULT_PRIVATE_ROLE_VARS: for role in play.get_roles(): - all_vars = combine_vars(all_vars, role.get_vars(include_params=False)) + all_vars.push(role.get_vars(include_params=False)) if task: if task._role: - all_vars = combine_vars(all_vars, task._role.get_vars(include_params=False)) - all_vars = combine_vars(all_vars, task._role.get_role_params(task._block.get_dep_chain())) - all_vars = combine_vars(all_vars, task.get_vars()) + all_vars.push(task._role.get_vars(include_params=False)) + all_vars.push(task._role.get_role_params(task._block.get_dep_chain())) + all_vars.push(task.get_vars()) if host: - all_vars = combine_vars(all_vars, self._vars_cache.get(host.get_name(), dict())) - all_vars = combine_vars(all_vars, self._nonpersistent_fact_cache.get(host.name, dict())) + all_vars.push(self._vars_cache.get(host.get_name(), dict())) + all_vars.push(self._nonpersistent_fact_cache.get(host.name, dict())) # special case for include tasks, where the include params # may be specified in the vars field for the task, which should # have higher precedence than the vars/np facts above if task: - all_vars = combine_vars(all_vars, task.get_include_params()) + all_vars.push(task.get_include_params()) - all_vars = combine_vars(all_vars, self._extra_vars) - all_vars = combine_vars(all_vars, magic_variables) + all_vars.push(self._extra_vars) + all_vars.push(magic_variables) # special case for the 'environment' magic variable, as someone # may have set it as a variable and we don't want to stomp on it if task: if 'environment' not in all_vars: - all_vars['environment'] = task.environment + all_vars.push(dict(environment=task.environment)) else: display.warning("The variable 'environment' appears to be used already, which is also used internally for environment variables set on the task/block/play. You should use a different variable name to avoid conflicts with this internal variable") # if we have a task and we're delegating to another host, figure out the # variables for that host now so we don't have to rely on hostvars later if task and task.delegate_to is not None and include_delegate_to: - all_vars['ansible_delegated_vars'] = self._get_delegated_vars(loader, play, task, all_vars) + all_vars.push(dict(ansible_delegated_vars=self._get_delegated_vars(loader, play, task, all_vars))) #VARIABLE_CACHE[cache_entry] = all_vars if task or play: - all_vars['vars'] = all_vars.copy() + all_vars.push(dict(vars=all_vars.to_dict())) display.debug("done with get_vars()") return all_vars @@ -442,10 +443,11 @@ class VariableManager: for item in items: # update the variables with the item value for templating, in case we need it if item is not None: - vars_copy['item'] = item - - templar.set_available_variables(vars_copy) + vars_copy.push(dict(item=item)) delegated_host_name = templar.template(task.delegate_to, fail_on_undefined=False) + if item is not None: + vars_copy.pop() + if delegated_host_name in delegated_host_vars: # no need to repeat ourselves, as the delegate_to value # does not appear to be tied to the loop item variable diff --git a/lib/ansible/vars/chain_map.py b/lib/ansible/vars/chain_map.py new file mode 100644 index 0000000000..9221aa4d0c --- /dev/null +++ b/lib/ansible/vars/chain_map.py @@ -0,0 +1,108 @@ +# (c) 2016, Ansible, Inc. <support@ansible.com> +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +from collections import MutableMapping + +from ansible import constants as C +from ansible.utils.vars import merge_hash + +class AnsibleChainMap(MutableMapping): + ''' + A variation of the ChainMap idea, which is extended here to + also support merging dicts values from multiple levels. + ''' + def __init__(self, *args, **kwargs): + self._maps = [dict()] + + def __str__(self): + return str(self.to_dict()) + + def __getitem__(self, k): + if C.DEFAULT_HASH_BEHAVIOUR == 'merge': + tmp = None + found = False + for m in self._maps: + if k in m: + if isinstance(m[k], dict) and isinstance(tmp, dict): + tmp = merge_hash(tmp, m[k]) + else: + tmp = m[k] + found = True + if found: + return tmp + else: + for m in reversed(self._maps): + if k in m: + return m[k] + raise KeyError + + def __setitem__(self, k, v): + ''' + This sets the key to the value specified if it is found in any + mapping in the list, otherwise it is set in the default dict + (slot 0 in the maps list) + ''' + for m in reversed(self._maps): + if k in m: + m[k] = v + break + else: + self._maps[0][k] = v + + def __delitem__(self, k): + ''' + This deletes the key in ALL maps contained within the list. + ''' + for m in self._maps: + if k in m: + del m[k] + + def __iter__(self): + for k in self.keys(): + yield k + + def __len__(self): + return len(self.keys()) + + def keys(self): + key_set = set() + for m in self._maps: + key_set.update(m) + + return list(key_set) + + def update(self, m): + assert isinstance(m, MutableMapping) + self.push(m) + + def push(self, m): + self._maps.append(m) + + def pop(self): + return self._maps.pop() + + def copy(self): + new_map = AnsibleChainMap() + new_map._maps = self._maps[:] + return new_map + + def to_dict(self): + return dict((k, self[k]) for k in self.keys()) diff --git a/lib/ansible/vars/hostvars.py b/lib/ansible/vars/hostvars.py index c4010447b7..3f3da84a78 100644 --- a/lib/ansible/vars/hostvars.py +++ b/lib/ansible/vars/hostvars.py @@ -69,7 +69,7 @@ class HostVars(collections.Mapping): if host is None: raise j2undefined - data = self._variable_manager.get_vars(loader=self._loader, host=host, include_hostvars=False) + data = self._variable_manager.get_vars(loader=self._loader, host=host, include_hostvars=False).to_dict() sha1_hash = sha1(str(data).encode('utf-8')).hexdigest() if sha1_hash in self._cached_result: diff --git a/test/units/executor/test_task_executor.py b/test/units/executor/test_task_executor.py index 741c76b330..4f300ea935 100644 --- a/test/units/executor/test_task_executor.py +++ b/test/units/executor/test_task_executor.py @@ -26,6 +26,7 @@ from ansible.errors import AnsibleError, AnsibleParserError from ansible.executor.task_executor import TaskExecutor from ansible.playbook.play_context import PlayContext from ansible.plugins import action_loader, lookup_loader +from ansible.vars.chain_map import AnsibleChainMap from units.mock.loader import DictDataLoader @@ -44,7 +45,7 @@ class TestTaskExecutor(unittest.TestCase): mock_play_context = MagicMock() mock_shared_loader = MagicMock() new_stdin = None - job_vars = dict() + job_vars = AnsibleChainMap() mock_queue = MagicMock() te = TaskExecutor( host = mock_host, @@ -71,7 +72,7 @@ class TestTaskExecutor(unittest.TestCase): mock_queue = MagicMock() new_stdin = None - job_vars = dict() + job_vars = AnsibleChainMap() te = TaskExecutor( host = mock_host, @@ -114,7 +115,7 @@ class TestTaskExecutor(unittest.TestCase): mock_shared_loader.lookup_loader = lookup_loader new_stdin = None - job_vars = dict() + job_vars = AnsibleChainMap() mock_queue = MagicMock() te = TaskExecutor( @@ -151,7 +152,7 @@ class TestTaskExecutor(unittest.TestCase): mock_queue = MagicMock() new_stdin = None - job_vars = dict() + job_vars = AnsibleChainMap() te = TaskExecutor( host = mock_host, @@ -197,7 +198,8 @@ class TestTaskExecutor(unittest.TestCase): mock_queue = MagicMock() new_stdin = None - job_vars = dict(pkg_mgr='yum') + job_vars = AnsibleChainMap() + job_vars.push(dict(pkg_mgr='yum')) te = TaskExecutor( host = mock_host, @@ -246,7 +248,8 @@ class TestTaskExecutor(unittest.TestCase): # an error later. If so, we can throw it now instead. # Squashing in this case would not be intuitive as the user is being # explicit in using each list entry as a key. - job_vars = dict(pkg_mgr='yum', packages={ "a": "foo", "b": "bar", "foo": "baz", "bar": "quux" }) + job_vars = AnsibleChainMap() + job_vars.push(dict(pkg_mgr='yum', packages={ "a": "foo", "b": "bar", "foo": "baz", "bar": "quux" })) items = [['a', 'b'], ['foo', 'bar']] mock_task.action = 'yum' mock_task.args = {'name': '{{ packages[item] }}'} @@ -283,7 +286,8 @@ class TestTaskExecutor(unittest.TestCase): # # Squashing lists - job_vars = dict(pkg_mgr='yum') + job_vars = AnsibleChainMap() + job_vars.push(dict(pkg_mgr='yum')) items = [['a', 'b'], ['foo', 'bar']] mock_task.action = 'yum' mock_task.args = {'name': '{{ item }}'} @@ -300,7 +304,8 @@ class TestTaskExecutor(unittest.TestCase): self.assertEqual(new_items, items) # Another way to retrieve from a dict - job_vars = dict(pkg_mgr='yum') + job_vars.pop() + job_vars.push(dict(pkg_mgr='yum')) items = [{'package': 'foo'}, {'package': 'bar'}] mock_task.action = 'yum' mock_task.args = {'name': '{{ item["package"] }}'} @@ -328,7 +333,8 @@ class TestTaskExecutor(unittest.TestCase): # dict(name='c', state='absent')]) # Could do something like this to recover from bad deps in a package - job_vars = dict(pkg_mgr='yum', packages=['a', 'b']) + job_vars.pop() + job_vars.push(dict(pkg_mgr='yum', packages=['a', 'b'])) items = [ 'absent', 'latest' ] mock_task.action = 'yum' mock_task.args = {'name': '{{ packages }}', 'state': '{{ item }}'} @@ -369,7 +375,8 @@ class TestTaskExecutor(unittest.TestCase): shared_loader = None new_stdin = None - job_vars = dict(omit="XXXXXXXXXXXXXXXXXXX") + job_vars = AnsibleChainMap() + job_vars.push(dict(omit="XXXXXXXXXXXXXXXXXXX")) te = TaskExecutor( host = mock_host, @@ -424,7 +431,8 @@ class TestTaskExecutor(unittest.TestCase): shared_loader.action_loader = action_loader new_stdin = None - job_vars = dict(omit="XXXXXXXXXXXXXXXXXXX") + job_vars = AnsibleChainMap() + job_vars.push(dict(omit="XXXXXXXXXXXXXXXXXXX")) te = TaskExecutor( host = mock_host, |