From 60f76436c144a08aa6b74bfefd559ac0188202f6 Mon Sep 17 00:00:00 2001 From: Martin Krizek Date: Fri, 9 Dec 2022 13:33:13 +0100 Subject: Simplify AnsibleJ2Vars by using ChainMap for vars (#78713) Co-authored-by: Matt Martz --- changelogs/fragments/ansiblej2vars-chainmap.yml | 2 + lib/ansible/template/vars.py | 148 ++++++++---------------- test/units/template/test_vars.py | 23 ++-- 3 files changed, 58 insertions(+), 115 deletions(-) create mode 100644 changelogs/fragments/ansiblej2vars-chainmap.yml diff --git a/changelogs/fragments/ansiblej2vars-chainmap.yml b/changelogs/fragments/ansiblej2vars-chainmap.yml new file mode 100644 index 0000000000..04175e332c --- /dev/null +++ b/changelogs/fragments/ansiblej2vars-chainmap.yml @@ -0,0 +1,2 @@ +minor_changes: + - "``AnsibleJ2Vars`` class that acts as a storage for all variables for templating purposes now uses ``collections.ChainMap`` internally." diff --git a/lib/ansible/template/vars.py b/lib/ansible/template/vars.py index fd1b812458..a7a9402ce8 100644 --- a/lib/ansible/template/vars.py +++ b/lib/ansible/template/vars.py @@ -1,25 +1,7 @@ # (c) 2012, Michael DeHaan -# -# This file is part of Ansible -# -# Ansible is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# Ansible is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with Ansible. If not, see . - -# Make coding more python3-ish -from __future__ import (absolute_import, division, print_function) -__metaclass__ = type - -from collections.abc import Mapping +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from collections import ChainMap from jinja2.utils import missing @@ -30,99 +12,65 @@ from ansible.module_utils._text import to_native __all__ = ['AnsibleJ2Vars'] -class AnsibleJ2Vars(Mapping): - ''' - Helper class to template all variable content before jinja2 sees it. This is - done by hijacking the variable storage that jinja2 uses, and overriding __contains__ - and __getitem__ to look like a dict. Added bonus is avoiding duplicating the large - hashes that inject tends to be. +def _process_locals(_l): + if _l is None: + return {} + return { + k: v for k, v in _l.items() + if v is not missing + and k not in {'context', 'environment', 'template'} # NOTE is this really needed? + } - To facilitate using builtin jinja2 things like range, globals are also handled here. - ''' - def __init__(self, templar, globals, locals=None): - ''' - Initializes this object with a valid Templar() object, as - well as several dictionaries of variables representing - different scopes (in jinja2 terminology). - ''' +class AnsibleJ2Vars(ChainMap): + """Helper variable storage class that allows for nested variables templating: `foo: "{{ bar }}"`.""" + def __init__(self, templar, globals, locals=None): self._templar = templar - self._globals = globals - self._locals = dict() - if isinstance(locals, dict): - for key, val in locals.items(): - if val is not missing: - if key[:2] == 'l_': - self._locals[key[2:]] = val - elif key not in ('context', 'environment', 'template'): - self._locals[key] = val - - def __contains__(self, k): - if k in self._locals: - return True - if k in self._templar.available_variables: - return True - if k in self._globals: - return True - return False - - def __iter__(self): - keys = set() - keys.update(self._templar.available_variables, self._locals, self._globals) - return iter(keys) - - def __len__(self): - keys = set() - keys.update(self._templar.available_variables, self._locals, self._globals) - return len(keys) + super().__init__( + _process_locals(locals), # first mapping has the highest precedence + self._templar.available_variables, + globals, + ) def __getitem__(self, varname): - if varname in self._locals: - return self._locals[varname] - if varname in self._templar.available_variables: - variable = self._templar.available_variables[varname] - elif varname in self._globals: - return self._globals[varname] - else: - raise KeyError("undefined variable: %s" % varname) - - # HostVars is special, return it as-is, as is the special variable - # 'vars', which contains the vars structure + variable = super().__getitem__(varname) + from ansible.vars.hostvars import HostVars - if isinstance(variable, dict) and varname == "vars" or isinstance(variable, HostVars) or hasattr(variable, '__UNSAFE__'): + if (varname == "vars" and isinstance(variable, dict)) or isinstance(variable, HostVars) or hasattr(variable, '__UNSAFE__'): return variable - else: - value = None - try: - value = self._templar.template(variable) - except AnsibleUndefinedVariable as e: - # Instead of failing here prematurely, return an Undefined - # object which fails only after its first usage allowing us to - # do lazy evaluation and passing it into filters/tests that - # operate on such objects. - return self._templar.environment.undefined( - hint=f"{variable}: {e.message}", - name=varname, - exc=AnsibleUndefinedVariable, - ) - except Exception as e: - msg = getattr(e, 'message', None) or to_native(e) - raise AnsibleError("An unhandled exception occurred while templating '%s'. " - "Error was a %s, original message: %s" % (to_native(variable), type(e), msg)) - - return value + + try: + return self._templar.template(variable) + except AnsibleUndefinedVariable as e: + # Instead of failing here prematurely, return an Undefined + # object which fails only after its first usage allowing us to + # do lazy evaluation and passing it into filters/tests that + # operate on such objects. + return self._templar.environment.undefined( + hint=f"{variable}: {e.message}", + name=varname, + exc=AnsibleUndefinedVariable, + ) + except Exception as e: + msg = getattr(e, 'message', None) or to_native(e) + raise AnsibleError( + f"An unhandled exception occurred while templating '{to_native(variable)}'. " + f"Error was a {type(e)}, original message: {msg}" + ) def add_locals(self, locals): - ''' - If locals are provided, create a copy of self containing those + """If locals are provided, create a copy of self containing those locals in addition to what is already in this variable proxy. - ''' + """ if locals is None: return self + current_locals = self.maps[0] + current_globals = self.maps[2] + # prior to version 2.9, locals contained all of the vars and not just the current # local vars so this was not necessary for locals to propagate down to nested includes - new_locals = self._locals | locals + new_locals = current_locals | locals - return AnsibleJ2Vars(self._templar, self._globals, locals=new_locals) + return AnsibleJ2Vars(self._templar, current_globals, locals=new_locals) diff --git a/test/units/template/test_vars.py b/test/units/template/test_vars.py index 514104f23b..f43cfac462 100644 --- a/test/units/template/test_vars.py +++ b/test/units/template/test_vars.py @@ -19,23 +19,16 @@ from __future__ import (absolute_import, division, print_function) __metaclass__ = type -from units.compat import unittest -from unittest.mock import MagicMock - +from ansible.template import Templar from ansible.template.vars import AnsibleJ2Vars -class TestVars(unittest.TestCase): - def setUp(self): - self.mock_templar = MagicMock(name='mock_templar') +def test_globals_empty(): + assert isinstance(dict(AnsibleJ2Vars(Templar(None), {})), dict) - def test_globals_empty(self): - ajvars = AnsibleJ2Vars(self.mock_templar, {}) - res = dict(ajvars) - self.assertIsInstance(res, dict) - def test_globals(self): - res = dict(AnsibleJ2Vars(self.mock_templar, {'foo': 'bar', 'blip': [1, 2, 3]})) - self.assertIsInstance(res, dict) - self.assertIn('foo', res) - self.assertEqual(res['foo'], 'bar') +def test_globals(): + res = dict(AnsibleJ2Vars(Templar(None), {'foo': 'bar', 'blip': [1, 2, 3]})) + assert isinstance(res, dict) + assert 'foo' in res + assert res['foo'] == 'bar' -- cgit v1.2.1