summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJames Cammarata <jimi@sngx.net>2016-05-27 12:22:57 -0400
committerJames Cammarata <jimi@sngx.net>2016-05-29 08:35:06 -0400
commitd93f0e0d9bf2e77b71e16298b95c3180f0c8d332 (patch)
tree0eb8f315842a27ae61c27ca492e008ed2148ca5b
parentd81b9ca29e1e2098c1a12b0fc5892b66706acb20 (diff)
downloadansible_chain_map.tar.gz
Create and use AnsibleChainMap for the master dict of variablesansible_chain_map
-rw-r--r--lib/ansible/executor/task_executor.py20
-rw-r--r--lib/ansible/playbook/play_context.py4
-rw-r--r--lib/ansible/template/__init__.py3
-rw-r--r--lib/ansible/vars/__init__.py54
-rw-r--r--lib/ansible/vars/chain_map.py108
-rw-r--r--lib/ansible/vars/hostvars.py2
-rw-r--r--test/units/executor/test_task_executor.py30
7 files changed, 172 insertions, 49 deletions
diff --git a/lib/ansible/executor/task_executor.py b/lib/ansible/executor/task_executor.py
index f35ba897f1..51ec0a5155 100644
--- a/lib/ansible/executor/task_executor.py
+++ b/lib/ansible/executor/task_executor.py
@@ -26,6 +26,8 @@ import sys
import time
import traceback
+from collections import MutableMapping
+
from ansible.compat.six import iteritems, string_types, binary_type
from ansible import constants as C
@@ -434,13 +436,13 @@ class TaskExecutor:
if delay < 0:
delay = 1
- # make a copy of the job vars here, in case we need to update them
- # with the registered variable value later on when testing conditions
- vars_copy = variables.copy()
-
display.debug("starting attempt loop")
result = None
for attempt in range(1, retries + 1):
+ # make a copy of the job vars here, in case we need to update them
+ # with the registered variable value later on when testing conditions
+ vars_copy = variables.copy()
+
display.debug("running the handler")
try:
result = self._handler.run(task_vars=variables)
@@ -454,7 +456,7 @@ class TaskExecutor:
# update the local copy of vars with the registered value, if specified,
# or any facts which may have been generated by the module execution
if self._task.register:
- vars_copy[self._task.register] = wrap_var(result.copy())
+ vars_copy.push({self._task.register: wrap_var(result.copy())})
if self._task.async > 0:
# the async_wrapper module returns dumped JSON via its stdout
@@ -490,7 +492,7 @@ class TaskExecutor:
return failed_when_result
if 'ansible_facts' in result:
- vars_copy.update(result['ansible_facts'])
+ vars_copy.push(result['ansible_facts'])
# set the failed property if the result has a non-zero rc. This will be
# overridden below if the failed_when property is set
@@ -525,10 +527,10 @@ class TaskExecutor:
# do the final update of the local variables here, for both registered
# values and any facts which may have been created
if self._task.register:
- variables[self._task.register] = wrap_var(result)
+ variables.push({self._task.register: wrap_var(result)})
if 'ansible_facts' in result:
- variables.update(result['ansible_facts'])
+ variables.push(result['ansible_facts'])
# save the notification target in the result, if it was specified, as
# this task may be running in a loop in which case the notification
@@ -609,7 +611,7 @@ class TaskExecutor:
# now replace the interpreter values with those that may have come
# from the delegated-to host
delegated_vars = variables.get('ansible_delegated_vars', dict()).get(self._task.delegate_to, dict())
- if isinstance(delegated_vars, dict):
+ if isinstance(delegated_vars, MutableMapping):
for i in delegated_vars:
if isinstance(i, string_types) and i.startswith("ansible_") and i.endswith("_interpreter"):
variables[i] = delegated_vars[i]
diff --git a/lib/ansible/playbook/play_context.py b/lib/ansible/playbook/play_context.py
index a79b4e1988..9783a6b574 100644
--- a/lib/ansible/playbook/play_context.py
+++ b/lib/ansible/playbook/play_context.py
@@ -26,6 +26,8 @@ import random
import re
import string
+from collections import MutableMapping
+
from ansible.compat.six import iteritems, string_types
from ansible import constants as C
from ansible.errors import AnsibleError
@@ -375,7 +377,7 @@ class PlayContext(Base):
continue
# if delegation task ONLY use delegated host vars, avoid delegated FOR host vars
if task.delegate_to is not None:
- if isinstance(delegated_vars, dict) and variable_name in delegated_vars:
+ if isinstance(delegated_vars, MutableMapping) and variable_name in delegated_vars:
setattr(new_info, attr, delegated_vars[variable_name])
attrs_considered.append(attr)
elif variable_name in variables:
diff --git a/lib/ansible/template/__init__.py b/lib/ansible/template/__init__.py
index 4e1f7b7fd8..978d5fae3c 100644
--- a/lib/ansible/template/__init__.py
+++ b/lib/ansible/template/__init__.py
@@ -25,6 +25,7 @@ import os
import re
from io import StringIO
+from collections import MutableMapping
from ansible.compat.six import string_types, text_type, binary_type
from jinja2 import Environment
@@ -273,7 +274,7 @@ class Templar:
are being changed.
'''
- assert isinstance(variables, dict)
+ assert isinstance(variables, MutableMapping)
self._available_variables = variables
self._cached_result = {}
diff --git a/lib/ansible/vars/__init__.py b/lib/ansible/vars/__init__.py
index e1f569abd2..0e43230310 100644
--- a/lib/ansible/vars/__init__.py
+++ b/lib/ansible/vars/__init__.py
@@ -42,6 +42,7 @@ from ansible.plugins.cache import FactCache
from ansible.template import Templar
from ansible.utils.listify import listify_lookup_plugin_terms
from ansible.utils.vars import combine_vars
+from ansible.vars.chain_map import AnsibleChainMap
from ansible.vars.unsafe_proxy import wrap_var
try:
@@ -212,7 +213,7 @@ class VariableManager:
display.debug("vars are cached, returning them now")
return VARIABLE_CACHE[cache_entry]
- all_vars = dict()
+ all_vars = AnsibleChainMap()
magic_variables = self._get_magic_variables(
loader=loader,
play=play,
@@ -226,13 +227,13 @@ class VariableManager:
# first we compile any vars specified in defaults/main.yml
# for all roles within the specified play
for role in play.get_roles():
- all_vars = combine_vars(all_vars, role.get_default_vars())
+ all_vars.push(role.get_default_vars())
# if we have a task in this context, and that task has a role, make
# sure it sees its defaults above any other roles, as we previously
# (v1) made sure each task had a copy of its roles default vars
if task and task._role is not None:
- all_vars = combine_vars(all_vars, task._role.get_default_vars(dep_chain=task._block.get_dep_chain()))
+ all_vars.push(task._role.get_default_vars(dep_chain=task._block.get_dep_chain()))
if host:
# next, if a host is specified, we load any vars from group_vars
@@ -243,20 +244,20 @@ class VariableManager:
if 'all' in self._group_vars_files:
data = preprocess_vars(self._group_vars_files['all'])
for item in data:
- all_vars = combine_vars(all_vars, item)
+ all_vars.push(item)
# we merge in vars from groups specified in the inventory (INI or script)
- all_vars = combine_vars(all_vars, host.get_group_vars())
+ all_vars.push(host.get_group_vars())
for group in sorted(host.get_groups(), key=lambda g: g.depth):
if group.name in self._group_vars_files and group.name != 'all':
for data in self._group_vars_files[group.name]:
data = preprocess_vars(data)
for item in data:
- all_vars = combine_vars(all_vars, item)
+ all_vars.push(item)
# then we merge in vars from the host specified in the inventory (INI or script)
- all_vars = combine_vars(all_vars, host.get_vars())
+ all_vars.push(host.get_vars())
# then we merge in the host_vars/<hostname> file, if it exists
host_name = host.get_name()
@@ -264,17 +265,17 @@ class VariableManager:
for data in self._host_vars_files[host_name]:
data = preprocess_vars(data)
for item in data:
- all_vars = combine_vars(all_vars, item)
+ all_vars.push(item)
# finally, the facts caches for this host, if it exists
try:
host_facts = wrap_var(self._fact_cache.get(host.name, dict()))
- all_vars = combine_vars(all_vars, host_facts)
+ all_vars.push(host_facts)
except KeyError:
pass
if play:
- all_vars = combine_vars(all_vars, play.get_vars())
+ all_vars.push(play.get_vars())
for vars_file_item in play.get_vars_files():
# create a set of temporary vars here, which incorporate the extra
@@ -300,7 +301,7 @@ class VariableManager:
data = preprocess_vars(loader.load_from_file(vars_file))
if data is not None:
for item in data:
- all_vars = combine_vars(all_vars, item)
+ all_vars.push(item)
break
except AnsibleFileNotFound as e:
# we continue on loader failures
@@ -320,43 +321,43 @@ class VariableManager:
if not C.DEFAULT_PRIVATE_ROLE_VARS:
for role in play.get_roles():
- all_vars = combine_vars(all_vars, role.get_vars(include_params=False))
+ all_vars.push(role.get_vars(include_params=False))
if task:
if task._role:
- all_vars = combine_vars(all_vars, task._role.get_vars(include_params=False))
- all_vars = combine_vars(all_vars, task._role.get_role_params(task._block.get_dep_chain()))
- all_vars = combine_vars(all_vars, task.get_vars())
+ all_vars.push(task._role.get_vars(include_params=False))
+ all_vars.push(task._role.get_role_params(task._block.get_dep_chain()))
+ all_vars.push(task.get_vars())
if host:
- all_vars = combine_vars(all_vars, self._vars_cache.get(host.get_name(), dict()))
- all_vars = combine_vars(all_vars, self._nonpersistent_fact_cache.get(host.name, dict()))
+ all_vars.push(self._vars_cache.get(host.get_name(), dict()))
+ all_vars.push(self._nonpersistent_fact_cache.get(host.name, dict()))
# special case for include tasks, where the include params
# may be specified in the vars field for the task, which should
# have higher precedence than the vars/np facts above
if task:
- all_vars = combine_vars(all_vars, task.get_include_params())
+ all_vars.push(task.get_include_params())
- all_vars = combine_vars(all_vars, self._extra_vars)
- all_vars = combine_vars(all_vars, magic_variables)
+ all_vars.push(self._extra_vars)
+ all_vars.push(magic_variables)
# special case for the 'environment' magic variable, as someone
# may have set it as a variable and we don't want to stomp on it
if task:
if 'environment' not in all_vars:
- all_vars['environment'] = task.environment
+ all_vars.push(dict(environment=task.environment))
else:
display.warning("The variable 'environment' appears to be used already, which is also used internally for environment variables set on the task/block/play. You should use a different variable name to avoid conflicts with this internal variable")
# if we have a task and we're delegating to another host, figure out the
# variables for that host now so we don't have to rely on hostvars later
if task and task.delegate_to is not None and include_delegate_to:
- all_vars['ansible_delegated_vars'] = self._get_delegated_vars(loader, play, task, all_vars)
+ all_vars.push(dict(ansible_delegated_vars=self._get_delegated_vars(loader, play, task, all_vars)))
#VARIABLE_CACHE[cache_entry] = all_vars
if task or play:
- all_vars['vars'] = all_vars.copy()
+ all_vars.push(dict(vars=all_vars.to_dict()))
display.debug("done with get_vars()")
return all_vars
@@ -442,10 +443,11 @@ class VariableManager:
for item in items:
# update the variables with the item value for templating, in case we need it
if item is not None:
- vars_copy['item'] = item
-
- templar.set_available_variables(vars_copy)
+ vars_copy.push(dict(item=item))
delegated_host_name = templar.template(task.delegate_to, fail_on_undefined=False)
+ if item is not None:
+ vars_copy.pop()
+
if delegated_host_name in delegated_host_vars:
# no need to repeat ourselves, as the delegate_to value
# does not appear to be tied to the loop item variable
diff --git a/lib/ansible/vars/chain_map.py b/lib/ansible/vars/chain_map.py
new file mode 100644
index 0000000000..9221aa4d0c
--- /dev/null
+++ b/lib/ansible/vars/chain_map.py
@@ -0,0 +1,108 @@
+# (c) 2016, Ansible, Inc. <support@ansible.com>
+#
+# This file is part of Ansible
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
+
+# Make coding more python3-ish
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+from collections import MutableMapping
+
+from ansible import constants as C
+from ansible.utils.vars import merge_hash
+
+class AnsibleChainMap(MutableMapping):
+ '''
+ A variation of the ChainMap idea, which is extended here to
+ also support merging dicts values from multiple levels.
+ '''
+ def __init__(self, *args, **kwargs):
+ self._maps = [dict()]
+
+ def __str__(self):
+ return str(self.to_dict())
+
+ def __getitem__(self, k):
+ if C.DEFAULT_HASH_BEHAVIOUR == 'merge':
+ tmp = None
+ found = False
+ for m in self._maps:
+ if k in m:
+ if isinstance(m[k], dict) and isinstance(tmp, dict):
+ tmp = merge_hash(tmp, m[k])
+ else:
+ tmp = m[k]
+ found = True
+ if found:
+ return tmp
+ else:
+ for m in reversed(self._maps):
+ if k in m:
+ return m[k]
+ raise KeyError
+
+ def __setitem__(self, k, v):
+ '''
+ This sets the key to the value specified if it is found in any
+ mapping in the list, otherwise it is set in the default dict
+ (slot 0 in the maps list)
+ '''
+ for m in reversed(self._maps):
+ if k in m:
+ m[k] = v
+ break
+ else:
+ self._maps[0][k] = v
+
+ def __delitem__(self, k):
+ '''
+ This deletes the key in ALL maps contained within the list.
+ '''
+ for m in self._maps:
+ if k in m:
+ del m[k]
+
+ def __iter__(self):
+ for k in self.keys():
+ yield k
+
+ def __len__(self):
+ return len(self.keys())
+
+ def keys(self):
+ key_set = set()
+ for m in self._maps:
+ key_set.update(m)
+
+ return list(key_set)
+
+ def update(self, m):
+ assert isinstance(m, MutableMapping)
+ self.push(m)
+
+ def push(self, m):
+ self._maps.append(m)
+
+ def pop(self):
+ return self._maps.pop()
+
+ def copy(self):
+ new_map = AnsibleChainMap()
+ new_map._maps = self._maps[:]
+ return new_map
+
+ def to_dict(self):
+ return dict((k, self[k]) for k in self.keys())
diff --git a/lib/ansible/vars/hostvars.py b/lib/ansible/vars/hostvars.py
index c4010447b7..3f3da84a78 100644
--- a/lib/ansible/vars/hostvars.py
+++ b/lib/ansible/vars/hostvars.py
@@ -69,7 +69,7 @@ class HostVars(collections.Mapping):
if host is None:
raise j2undefined
- data = self._variable_manager.get_vars(loader=self._loader, host=host, include_hostvars=False)
+ data = self._variable_manager.get_vars(loader=self._loader, host=host, include_hostvars=False).to_dict()
sha1_hash = sha1(str(data).encode('utf-8')).hexdigest()
if sha1_hash in self._cached_result:
diff --git a/test/units/executor/test_task_executor.py b/test/units/executor/test_task_executor.py
index 741c76b330..4f300ea935 100644
--- a/test/units/executor/test_task_executor.py
+++ b/test/units/executor/test_task_executor.py
@@ -26,6 +26,7 @@ from ansible.errors import AnsibleError, AnsibleParserError
from ansible.executor.task_executor import TaskExecutor
from ansible.playbook.play_context import PlayContext
from ansible.plugins import action_loader, lookup_loader
+from ansible.vars.chain_map import AnsibleChainMap
from units.mock.loader import DictDataLoader
@@ -44,7 +45,7 @@ class TestTaskExecutor(unittest.TestCase):
mock_play_context = MagicMock()
mock_shared_loader = MagicMock()
new_stdin = None
- job_vars = dict()
+ job_vars = AnsibleChainMap()
mock_queue = MagicMock()
te = TaskExecutor(
host = mock_host,
@@ -71,7 +72,7 @@ class TestTaskExecutor(unittest.TestCase):
mock_queue = MagicMock()
new_stdin = None
- job_vars = dict()
+ job_vars = AnsibleChainMap()
te = TaskExecutor(
host = mock_host,
@@ -114,7 +115,7 @@ class TestTaskExecutor(unittest.TestCase):
mock_shared_loader.lookup_loader = lookup_loader
new_stdin = None
- job_vars = dict()
+ job_vars = AnsibleChainMap()
mock_queue = MagicMock()
te = TaskExecutor(
@@ -151,7 +152,7 @@ class TestTaskExecutor(unittest.TestCase):
mock_queue = MagicMock()
new_stdin = None
- job_vars = dict()
+ job_vars = AnsibleChainMap()
te = TaskExecutor(
host = mock_host,
@@ -197,7 +198,8 @@ class TestTaskExecutor(unittest.TestCase):
mock_queue = MagicMock()
new_stdin = None
- job_vars = dict(pkg_mgr='yum')
+ job_vars = AnsibleChainMap()
+ job_vars.push(dict(pkg_mgr='yum'))
te = TaskExecutor(
host = mock_host,
@@ -246,7 +248,8 @@ class TestTaskExecutor(unittest.TestCase):
# an error later. If so, we can throw it now instead.
# Squashing in this case would not be intuitive as the user is being
# explicit in using each list entry as a key.
- job_vars = dict(pkg_mgr='yum', packages={ "a": "foo", "b": "bar", "foo": "baz", "bar": "quux" })
+ job_vars = AnsibleChainMap()
+ job_vars.push(dict(pkg_mgr='yum', packages={ "a": "foo", "b": "bar", "foo": "baz", "bar": "quux" }))
items = [['a', 'b'], ['foo', 'bar']]
mock_task.action = 'yum'
mock_task.args = {'name': '{{ packages[item] }}'}
@@ -283,7 +286,8 @@ class TestTaskExecutor(unittest.TestCase):
#
# Squashing lists
- job_vars = dict(pkg_mgr='yum')
+ job_vars = AnsibleChainMap()
+ job_vars.push(dict(pkg_mgr='yum'))
items = [['a', 'b'], ['foo', 'bar']]
mock_task.action = 'yum'
mock_task.args = {'name': '{{ item }}'}
@@ -300,7 +304,8 @@ class TestTaskExecutor(unittest.TestCase):
self.assertEqual(new_items, items)
# Another way to retrieve from a dict
- job_vars = dict(pkg_mgr='yum')
+ job_vars.pop()
+ job_vars.push(dict(pkg_mgr='yum'))
items = [{'package': 'foo'}, {'package': 'bar'}]
mock_task.action = 'yum'
mock_task.args = {'name': '{{ item["package"] }}'}
@@ -328,7 +333,8 @@ class TestTaskExecutor(unittest.TestCase):
# dict(name='c', state='absent')])
# Could do something like this to recover from bad deps in a package
- job_vars = dict(pkg_mgr='yum', packages=['a', 'b'])
+ job_vars.pop()
+ job_vars.push(dict(pkg_mgr='yum', packages=['a', 'b']))
items = [ 'absent', 'latest' ]
mock_task.action = 'yum'
mock_task.args = {'name': '{{ packages }}', 'state': '{{ item }}'}
@@ -369,7 +375,8 @@ class TestTaskExecutor(unittest.TestCase):
shared_loader = None
new_stdin = None
- job_vars = dict(omit="XXXXXXXXXXXXXXXXXXX")
+ job_vars = AnsibleChainMap()
+ job_vars.push(dict(omit="XXXXXXXXXXXXXXXXXXX"))
te = TaskExecutor(
host = mock_host,
@@ -424,7 +431,8 @@ class TestTaskExecutor(unittest.TestCase):
shared_loader.action_loader = action_loader
new_stdin = None
- job_vars = dict(omit="XXXXXXXXXXXXXXXXXXX")
+ job_vars = AnsibleChainMap()
+ job_vars.push(dict(omit="XXXXXXXXXXXXXXXXXXX"))
te = TaskExecutor(
host = mock_host,