summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CHANGELOG.md1
-rw-r--r--README.md1
-rw-r--r--docsite/rst/playbooks_intro.rst26
-rw-r--r--lib/ansible/executor/play_iterator.py3
-rw-r--r--lib/ansible/executor/task_queue_manager.py25
-rw-r--r--lib/ansible/module_utils/eos.py508
-rw-r--r--lib/ansible/module_utils/netcfg.py279
-rw-r--r--lib/ansible/module_utils/netcmd.py202
-rw-r--r--lib/ansible/module_utils/network.py282
-rw-r--r--lib/ansible/module_utils/shell.py5
-rw-r--r--lib/ansible/playbook/handler.py4
-rw-r--r--lib/ansible/playbook/task.py2
-rw-r--r--lib/ansible/plugins/strategy/__init__.py89
-rw-r--r--test/units/plugins/strategies/test_strategy_base.py35
14 files changed, 1072 insertions, 390 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6a5c627cfd..11dc06c1c4 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,7 @@ Ansible Changes By Release
###Major Changes:
+* Added the `listen` feature for modules. This feature allows tasks to more easily notify multiple handlers, as well as making it easier for handlers from decoupled roles to be notified.
* Added support for binary modules
* The service module has been changed to use system specific modules if they exist and fallback to the old service module if they cannot be found or detected.
diff --git a/README.md b/README.md
index 544fd3694a..4eb8bebbe6 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,4 @@
[![PyPI version](https://img.shields.io/pypi/v/ansible.svg)](https://pypi.python.org/pypi/ansible)
-[![PyPI downloads](https://img.shields.io/pypi/dm/ansible.svg)](https://pypi.python.org/pypi/ansible)
[![Build Status](https://api.shippable.com/projects/573f79d02a8192902e20e34b/badge?branch=devel)](https://app.shippable.com/projects/573f79d02a8192902e20e34b)
diff --git a/docsite/rst/playbooks_intro.rst b/docsite/rst/playbooks_intro.rst
index 1f24250d07..f135ee7742 100644
--- a/docsite/rst/playbooks_intro.rst
+++ b/docsite/rst/playbooks_intro.rst
@@ -378,15 +378,31 @@ Here's an example handlers section::
- name: restart apache
service: name=apache state=restarted
-Handlers are best used to restart services and trigger reboots. You probably
-won't need them for much else.
+As of Ansible 2.2, handlers can also "listen" to generic topics, and tasks can notify those topics as follows::
+
+ handlers:
+ - name: restart memcached
+ service: name=memcached state=restarted
+ listen: "restart web services"
+ - name: restart apache
+ service: name=apache state=restarted
+ listen: "restart web services"
+
+ tasks:
+ - name: restart everything
+ command: echo "this task will restart the web services"
+ notify: "restart web services"
+
+This use makes it much easier to trigger multiple handlers. It also decouples handlers from their names,
+making it easier to share handlers among playbooks and roles (especially when using 3rd party roles from
+a shared source like Galaxy).
.. note::
- * Notify handlers are always run in the same order they are defined, `not` in the order listed in the notify-statement.
- * Handler names live in a global namespace.
+ * Notify handlers are always run in the same order they are defined, `not` in the order listed in the notify-statement. This is also the case for handlers using `listen`.
+ * Handler names and `listen` topics live in a global namespace.
* If two handler tasks have the same name, only one will run.
`* <https://github.com/ansible/ansible/issues/4943>`_
- * You cannot notify a handler that is defined inside of an include
+ * You cannot notify a handler that is defined inside of an include. As of Ansible 2.1, this does work, however the include must be `static`.
Roles are described later on, but it's worthwhile to point out that:
diff --git a/lib/ansible/executor/play_iterator.py b/lib/ansible/executor/play_iterator.py
index 72742c1a34..4ccc709e66 100644
--- a/lib/ansible/executor/play_iterator.py
+++ b/lib/ansible/executor/play_iterator.py
@@ -212,9 +212,6 @@ class PlayIterator:
# plays won't try to advance)
play_context.start_at_task = None
- # Extend the play handlers list to include the handlers defined in roles
- self._play.handlers.extend(play.compile_roles_handlers())
-
def get_host_state(self, host):
# Since we're using the PlayIterator to carry forward failed hosts,
# in the event that a previous host was not in the current inventory
diff --git a/lib/ansible/executor/task_queue_manager.py b/lib/ansible/executor/task_queue_manager.py
index 8fe7404328..10f3b50d4b 100644
--- a/lib/ansible/executor/task_queue_manager.py
+++ b/lib/ansible/executor/task_queue_manager.py
@@ -93,6 +93,7 @@ class TaskQueueManager:
# this dictionary is used to keep track of notified handlers
self._notified_handlers = dict()
+ self._listening_handlers = dict()
# dictionaries to keep track of failed/unreachable hosts
self._failed_hosts = dict()
@@ -114,17 +115,21 @@ class TaskQueueManager:
self._result_prc = ResultProcess(self._final_q, self._workers)
self._result_prc.start()
- def _initialize_notified_handlers(self, handlers):
+ def _initialize_notified_handlers(self, play):
'''
Clears and initializes the shared notified handlers dict with entries
for each handler in the play, which is an empty array that will contain
inventory hostnames for those hosts triggering the handler.
'''
+ handlers = play.handlers
+ for role in play.roles:
+ handlers.extend(role._handler_blocks)
+
# Zero the dictionary first by removing any entries there.
# Proxied dicts don't support iteritems, so we have to use keys()
- for key in self._notified_handlers.keys():
- del self._notified_handlers[key]
+ self._notified_handlers.clear()
+ self._listening_handlers.clear()
def _process_block(b):
temp_list = []
@@ -139,9 +144,14 @@ class TaskQueueManager:
for handler_block in handlers:
handler_list.extend(_process_block(handler_block))
- # then initialize it with the handler names from the handler list
+ # then initialize it with the given handler list
for handler in handler_list:
- self._notified_handlers[handler.get_name()] = []
+ if handler not in self._notified_handlers:
+ self._notified_handlers[handler] = []
+ if handler.listen:
+ if handler.listen not in self._listening_handlers:
+ self._listening_handlers[handler.listen] = []
+ self._listening_handlers[handler.listen].append(handler.get_name())
def load_callbacks(self):
'''
@@ -226,7 +236,7 @@ class TaskQueueManager:
self.send_callback('v2_playbook_on_play_start', new_play)
# initialize the shared dictionary containing the notified handlers
- self._initialize_notified_handlers(new_play.handlers)
+ self._initialize_notified_handlers(new_play)
# load the specified strategy (or the default linear one)
strategy = strategy_loader.get(new_play.strategy, self)
@@ -299,9 +309,6 @@ class TaskQueueManager:
def get_loader(self):
return self._loader
- def get_notified_handlers(self):
- return self._notified_handlers
-
def get_workers(self):
return self._workers[:]
diff --git a/lib/ansible/module_utils/eos.py b/lib/ansible/module_utils/eos.py
index b89ad26179..430e521ed7 100644
--- a/lib/ansible/module_utils/eos.py
+++ b/lib/ansible/module_utils/eos.py
@@ -17,78 +17,224 @@
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
#
+import collections
import re
-from ansible.module_utils.basic import AnsibleModule, env_fallback, get_exception
-from ansible.module_utils.shell import Shell, ShellError, Command, HAS_PARAMIKO
-from ansible.module_utils.netcfg import parse
-from ansible.module_utils.urls import fetch_url
+from ansible.module_utils.basic import json
+from ansible.module_utils.network import NetCli, NetworkError, get_module, Command
+from ansible.module_utils.network import add_argument, register_transport, to_list
+from ansible.module_utils.netcfg import NetworkConfig
+from ansible.module_utils.urls import fetch_url, url_argument_spec
NET_PASSWD_RE = re.compile(r"[\r\n]?password: $", re.I)
-NET_COMMON_ARGS = dict(
- host=dict(required=True),
- port=dict(type='int'),
- username=dict(fallback=(env_fallback, ['ANSIBLE_NET_USERNAME'])),
- password=dict(no_log=True, fallback=(env_fallback, ['ANSIBLE_NET_PASSWORD'])),
- ssh_keyfile=dict(fallback=(env_fallback, ['ANSIBLE_NET_SSH_KEYFILE']), type='path'),
- authorize=dict(default=False, fallback=(env_fallback, ['ANSIBLE_NET_AUTHORIZE']), type='bool'),
- auth_pass=dict(no_log=True, fallback=(env_fallback, ['ANSIBLE_NET_AUTH_PASS'])),
- transport=dict(default='cli', choices=['cli', 'eapi']),
- use_ssl=dict(default=True, type='bool'),
- provider=dict(type='dict')
-)
-
-CLI_PROMPTS_RE = [
- re.compile(r"[\r\n]?[\w+\-\.:\/\[\]]+(?:\([^\)]+\)){,3}(?:>|#) ?$"),
- re.compile(r"\[\w+\@[\w\-\.]+(?: [^\]])\] ?[>#\$] ?$")
-]
-
-CLI_ERRORS_RE = [
- re.compile(r"% ?Error"),
- re.compile(r"^% \w+", re.M),
- re.compile(r"% ?Bad secret"),
- re.compile(r"invalid input", re.I),
- re.compile(r"(?:incomplete|ambiguous) command", re.I),
- re.compile(r"connection timed out", re.I),
- re.compile(r"[^\r\n]+ not found", re.I),
- re.compile(r"'[^']' +returned error code: ?\d+"),
- re.compile(r"[^\r\n]\/bin\/(?:ba)?sh")
-]
-
-
-def to_list(val):
- if isinstance(val, (list, tuple)):
- return list(val)
- elif val is not None:
- return [val]
+EAPI_FORMATS = ['json', 'text']
+
+add_argument('use_ssl', dict(default=True, type='bool'))
+add_argument('validate_certs', dict(default=True, type='bool'))
+
+ModuleStub = collections.namedtuple('ModuleStub', 'params fail_json')
+
+def argument_spec():
+ return dict(
+ # config options
+ running_config=dict(aliases=['config']),
+ config_session=dict(default='ansible_session'),
+ save_config=dict(default=False, aliases=['save']),
+ force=dict(type='bool', default=False)
+ )
+eos_argument_spec = argument_spec()
+
+def get_config(module):
+ config = module.params['running_config']
+ if not config:
+ config = module.config.get_config(include_defaults=False)
+ return NetworkConfig(indent=3, contents=config)
+
+def load_config(module, candidate):
+
+ if not module.params['force']:
+ config = get_config(module)
+ commands = candidate.difference(config)
else:
- return list()
+ commands = str(candidate)
+ commands = [str(c).strip() for c in commands]
-class Eapi(object):
+ session = module.params['config_session']
+ save_config = module.params['save_config']
- def __init__(self, module):
- self.module = module
+ result = dict(changed=False)
- # sets the module_utils/urls.py req parameters
- self.module.params['url_username'] = module.params['username']
- self.module.params['url_password'] = module.params['password']
+ if commands:
+ if module._diff:
+ diff = module.config.load_config(commands, session_name=session)
+ if diff:
+ result['diff'] = dict(prepared=diff)
+
+ if not module.check_mode:
+ module.config.commit_config(session)
+ if save_config:
+ module.config.save_config()
+ else:
+ module.config.abort_config(session_name=session)
+
+ if not module.check_mode:
+ module.config(commands)
+ if save_config:
+ module.config.save_config()
+
+ result['changed'] = True
+ result['updates'] = commands
+
+ return result
+
+def expand_intf_range(interfaces):
+ match = re.match(r'([a-zA-Z]+)(.+)', interfaces)
+ if not match:
+ raise ValueError('could not parse interface range')
+
+ name = match.group(1)
+ values = match.group(2).split(',')
+
+ indicies = list()
+
+ for val in values:
+ tokens = val.split('-')
+
+ # single index value to handle
+ if len(tokens) == 1:
+ indicies.append(tokens[0])
+
+ elif len(tokens) == 2:
+ pairs = list()
+ mod = 0
+
+ for token in tokens:
+ parts = token.split('/')
+
+ if len(parts) == 1:
+ port = parts[0]
+ if port == '$':
+ port = last_port
+ pairs.append((mod, int(port)))
+
+ elif len(parts) == 2:
+ mod = int(parts[0])
+ port = parts[1]
+ if port == '$':
+ port = last_port
+ pairs.append((mod, int(port)))
+
+ else:
+ raise ValueError('unable to parse interface')
+
+ if pairs[0][0] == pairs[1][0]:
+ # same module
+ mod = pairs[0][0]
+ start = pairs[0][1]
+ end = pairs[1][1] + 1
+
+ for i in range(start, end):
+ if mod == 0:
+ indicies.append(i)
+ else:
+ indicies.append('%s/%s' % (mod, i))
+ else:
+ # span modules
+ start_mod, start_port = pairs[0]
+ end_mod, end_port = pairs[1]
+ end_port += 1
+
+ for i in range(start_port, last_port+1):
+ indicies.append('%s/%s' % (start_mod, i))
+ for i in range(first_port, end_port):
+ indicies.append('%s/%s' % (end_mod, i))
+
+ return ['%s%s' % (name, index) for index in indicies]
+
+class EosConfigMixin(object):
+
+ def configure(self, commands, **kwargs):
+ commands = prepare_config(commands)
+ responses = self.execute(commands)
+ responses.pop(0)
+ return responses
+
+ def get_config(self, **kwargs):
+ cmd = 'show running-config'
+ if kwargs.get('include_defaults') is True:
+ cmd += ' all'
+ return self.execute([cmd])[0]
+
+ def load_config(self, commands, session_name='ansible_temp_session', **kwargs):
+ commands = to_list(commands)
+ commands.insert(0, 'configure session %s' % session_name)
+ commands.append('show session-config diffs')
+ commands.append('end')
+ responses = self.execute(commands)
+ return responses[-2]
+
+ def replace_config(self, contents, params, **kwargs):
+ remote_user = params['username']
+ remote_path = '/home/%s/ansible-config' % remote_user
+
+ commands = [
+ 'bash echo "%s" > %s' % (contents, remote_path),
+ 'diff running-config file:/%s' % remote_path,
+ 'config replace file:/%s' % remote_path,
+ ]
+
+ responses = self.run_commands(commands)
+ return responses[-2]
+
+ def commit_config(self, session_name):
+ session = 'configure session %s' % session_name
+ commands = [session, 'commit', 'no %s' % session]
+ self.execute(commands)
+
+ def abort_config(self, session_name):
+ command = 'no configure session %s' % session_name
+ self.execute([command])
+
+ def save_config(self):
+ self.execute(['copy running-config startup-config'])
+
+class Eapi(EosConfigMixin):
+
+ def __init__(self):
self.url = None
+ self.url_args = ModuleStub(url_argument_spec(), self._error)
self.enable = None
+ self.default_output = 'json'
+ self._connected = False
- def _get_body(self, commands, encoding, reqid=None):
+ def _error(self, msg):
+ raise NetworkError(msg, url=self.url)
+
+ def _get_body(self, commands, format, reqid=None):
"""Create a valid eAPI JSON-RPC request message
"""
- params = dict(version=1, cmds=commands, format=encoding)
+
+ if format not in EAPI_FORMATS:
+ msg = 'invalid format, received %s, expected one of %s' % \
+ (format, ','.join(EAPI_FORMATS))
+ self._error(msg=msg)
+
+ params = dict(version=1, cmds=commands, format=format)
return dict(jsonrpc='2.0', id=reqid, method='runCmds', params=params)
- def connect(self):
- host = self.module.params['host']
- port = self.module.params['port']
+ def connect(self, params, **kwargs):
+ host = params['host']
+ port = params['port']
- if self.module.params['use_ssl']:
+ # sets the module_utils/urls.py req parameters
+ self.url_args.params['url_username'] = params['username']
+ self.url_args.params['url_password'] = params['password']
+ self.url_args.params['validate_certs'] = params['validate_certs']
+
+ if params['use_ssl']:
proto = 'https'
if not port:
port = 443
@@ -98,176 +244,146 @@ class Eapi(object):
port = 80
self.url = '%s://%s:%s/command-api' % (proto, host, port)
+ self._connected = True
+
+ def disconnect(self, **kwargs):
+ self.url = None
+ self._connected = False
- def authorize(self):
- if self.module.params['auth_pass']:
- passwd = self.module.params['auth_pass']
+ def authorize(self, params, **kwargs):
+ if params.get('auth_pass'):
+ passwd = params['auth_pass']
self.enable = dict(cmd='enable', input=passwd)
else:
self.enable = 'enable'
- def send(self, commands, encoding='json'):
- """Send commands to the device.
- """
- clist = to_list(commands)
-
- if self.enable is not None:
- clist.insert(0, self.enable)
- data = self._get_body(clist, encoding)
- data = self.module.jsonify(data)
+ ### implementation of network.Cli ###
- headers = {'Content-Type': 'application/json-rpc'}
+ def run_commands(self, commands):
+ output = None
+ cmds = list()
+ responses = list()
- response, headers = fetch_url(self.module, self.url, data=data,
- headers=headers, method='POST')
+ for cmd in commands:
+ if output and output != cmd.output:
+ responses.extend(self.execute(cmds, format=output))
+ cmds = list()
- if headers['status'] != 200:
- self.module.fail_json(**headers)
+ output = cmd.output
+ cmds.append(str(cmd))
- response = self.module.from_json(response.read())
- if 'error' in response:
- err = response['error']
- self.module.fail_json(msg='json-rpc error', commands=commands, **err)
+ if cmds:
+ responses.extend(self.execute(cmds, format=output))
- if self.enable:
- response['result'].pop(0)
+ for index, cmd in enumerate(commands):
+ if cmd.output == 'text':
+ responses[index] = responses[index].get('output')
- return response['result']
+ return responses
+ def execute(self, commands, format='json', **kwargs):
+ """Send commands to the device.
+ """
+ if self.url is None:
+ raise NetworkError('Not connected to endpoint.')
+ if self.enable is not None:
+ commands.insert(0, self.enable)
-class Cli(object):
+ data = self._get_body(commands, format)
+ data = json.dumps(data)
- def __init__(self, module):
- self.module = module
- self.shell = None
+ headers = {'Content-Type': 'application/json-rpc'}
- def connect(self, **kwargs):
- host = self.module.params['host']
- port = self.module.params['port'] or 22
+ response, headers = fetch_url(
+ self.url_args, self.url, data=data, headers=headers,
+ method='POST'
+ )
- username = self.module.params['username']
- password = self.module.params['password']
- key_filename = self.module.params['ssh_keyfile']
+ if headers['status'] != 200:
+ raise NetworkError(**headers)
try:
- self.shell = Shell(prompts_re=CLI_PROMPTS_RE, errors_re=CLI_ERRORS_RE)
- self.shell.open(host, port=port, username=username, password=password, key_filename=key_filename)
- except ShellError:
- e = get_exception()
- msg = 'failed to connect to %s:%s - %s' % (host, port, str(e))
- self.module.fail_json(msg=msg)
-
- def authorize(self):
- passwd = self.module.params['auth_pass']
- self.send(Command('enable', prompt=NET_PASSWD_RE, response=passwd))
-
- def send(self, commands):
- try:
- return self.shell.send(commands)
- except ShellError:
- e = get_exception()
- self.module.fail_json(msg=e.message, commands=commands)
+ response = json.loads(response.read())
+ except ValueError:
+ raise NetworkError('unable to load response from device')
+ if 'error' in response:
+ err = response['error']
+ raise NetworkError(
+ msg=err['message'], code=err['code'], data=err['data'],
+ commands=commands
+ )
-class NetworkModule(AnsibleModule):
-
- def __init__(self, *args, **kwargs):
- super(NetworkModule, self).__init__(*args, **kwargs)
- self.connection = None
- self._config = None
- self._connected = False
-
- @property
- def connected(self):
- return self._connected
-
- @property
- def config(self):
- if not self._config:
- self._config = self.get_config()
- return self._config
-
- def _load_params(self):
- super(NetworkModule, self)._load_params()
- provider = self.params.get('provider') or dict()
- for key, value in provider.items():
- if key in NET_COMMON_ARGS:
- if self.params.get(key) is None and value is not None:
- self.params[key] = value
-
- def connect(self):
- cls = globals().get(str(self.params['transport']).capitalize())
- try:
- self.connection = cls(self)
- except TypeError:
- e = get_exception()
- self.fail_json(msg=e.message)
-
- self.connection.connect()
- self.connection.send('terminal length 0')
+ if self.enable:
+ response['result'].pop(0)
- if self.params['authorize']:
- self.connection.authorize()
+ return response['result']
+ def get_config(self, **kwargs):
+ return self.run_commands(['show running-config'], format='text')[0]
+Eapi = register_transport('eapi')(Eapi)
+
+
+class Cli(NetCli, EosConfigMixin):
+ CLI_PROMPTS_RE = [
+ re.compile(r"[\r\n]?[\w+\-\.:\/\[\]]+(?:\([^\)]+\)){,3}(?:>|#) ?$"),
+ re.compile(r"\[\w+\@[\w\-\.]+(?: [^\]])\] ?[>#\$] ?$")
+ ]
+
+ CLI_ERRORS_RE = [
+ re.compile(r"% ?Error"),
+ re.compile(r"^% \w+", re.M),
+ re.compile(r"% ?Bad secret"),
+ re.compile(r"invalid input", re.I),
+ re.compile(r"(?:incomplete|ambiguous) command", re.I),
+ re.compile(r"connection timed out", re.I),
+ re.compile(r"[^\r\n]+ not found", re.I),
+ re.compile(r"'[^']' +returned error code: ?\d+"),
+ re.compile(r"[^\r\n]\/bin\/(?:ba)?sh")
+ ]
+
+ def __init__(self):
+ super(Cli, self).__init__()
+
+ def connect(self, params, **kwargs):
+ super(Cli, self).connect(params, kickstart=True, **kwargs)
+ self.shell.send('terminal length 0')
self._connected = True
- def configure(self, commands, replace=False):
- if replace:
- responses = self.config_replace(commands)
- else:
- responses = self.config_terminal(commands)
- return responses
-
- def config_terminal(self, commands):
- commands = to_list(commands)
- commands.insert(0, 'configure terminal')
- responses = self.execute(commands)
- responses.pop(0)
+ def authorize(self, params, **kwargs):
+ passwd = params['auth_pass']
+ self.execute(Command('enable', prompt=NET_PASSWD_RE, response=passwd))
+
+ ### implementation of network.Cli ###
+
+ def run_commands(self, commands):
+ cmds = list(prepare_commands(commands))
+ responses = self.execute(cmds)
+ for index, cmd in enumerate(commands):
+ if cmd.output == 'json':
+ try:
+ responses[index] = json.loads(responses[index])
+ except ValueError:
+ raise NetworkError(
+ msg='unable to load response from device',
+ response=responses[index]
+ )
return responses
+Cli = register_transport('cli', default=True)(Cli)
- def config_replace(self, commands):
- if self.params['transport'] == 'cli':
- self.fail_json(msg='config replace only supported over eapi')
- cmd = 'configure replace terminal:'
- commands = '\n'.join(to_list(commands))
- command = dict(cmd=cmd, input=commands)
- return self.execute(command)
-
- def execute(self, commands, **kwargs):
- if not self.connected:
- self.connect()
- return self.connection.send(commands, **kwargs)
-
- def disconnect(self):
- self.connection.close()
- self._connected = False
+def prepare_config(commands):
+ commands = to_list(commands)
+ commands.insert(0, 'configure terminal')
+ commands.append('end')
+ return commands
- def parse_config(self, cfg):
- return parse(cfg, indent=3)
- def get_config(self):
- cmd = 'show running-config'
- if self.params.get('include_defaults'):
- cmd += ' all'
- if self.params['transport'] == 'cli':
- return self.execute(cmd)[0]
+def prepare_commands(commands):
+ jsonify = lambda x: '%s | json' % x
+ for cmd in to_list(commands):
+ if cmd.output == 'json':
+ cmd = jsonify(cmd)
else:
- resp = self.execute(cmd, encoding='text')
- return resp[0]['output']
-
-
-def get_module(**kwargs):
- """Return instance of NetworkModule
- """
- argument_spec = NET_COMMON_ARGS.copy()
- if kwargs.get('argument_spec'):
- argument_spec.update(kwargs['argument_spec'])
- kwargs['argument_spec'] = argument_spec
-
- module = NetworkModule(**kwargs)
-
- if module.params['transport'] == 'cli' and not HAS_PARAMIKO:
- module.fail_json(msg='paramiko is required but does not appear to be installed')
-
- return module
+ cmd = str(cmd)
+ yield cmd
diff --git a/lib/ansible/module_utils/netcfg.py b/lib/ansible/module_utils/netcfg.py
index 6f6bbee6e1..71cb57ea65 100644
--- a/lib/ansible/module_utils/netcfg.py
+++ b/lib/ansible/module_utils/netcfg.py
@@ -18,11 +18,14 @@
#
import re
+import time
import collections
import itertools
import shlex
+import itertools
from ansible.module_utils.basic import BOOLEANS_TRUE, BOOLEANS_FALSE
+from ansible.module_utils.network import to_list
DEFAULT_COMMENT_TOKENS = ['#', '!']
@@ -34,6 +37,13 @@ class ConfigLine(object):
self.parents = list()
self.raw = None
+ @property
+ def line(self):
+ line = ['set']
+ line.extend([p.text for p in self.parents])
+ line.append(self.text)
+ return ' '.join(line)
+
def __str__(self):
return self.raw
@@ -49,16 +59,20 @@ def ignore_line(text, tokens=None):
if text.startswith(item):
return True
+def get_next(iterable):
+ item, next_item = itertools.tee(iterable, 2)
+ next_item = itertools.islice(next_item, 1, None)
+ return itertools.izip_longest(item, next_item)
+
def parse(lines, indent, comment_tokens=None):
toplevel = re.compile(r'\S')
childline = re.compile(r'^\s*(.+)$')
- repl = r'([{|}|;])'
ancestors = list()
config = list()
for line in str(lines).split('\n'):
- text = str(re.sub(repl, '', line)).strip()
+ text = str(re.sub(r'([{};])', '', line)).strip()
cfg = ConfigLine(text)
cfg.raw = line
@@ -108,11 +122,23 @@ class NetworkConfig(object):
def items(self):
return self._config
+ @property
+ def lines(self):
+ lines = list()
+ for item, next_item in get_next(self.items):
+ if next_item is None:
+ lines.append(item.line)
+ elif not next_item.line.startswith(item.line):
+ lines.append(item.line)
+ return lines
+
def __str__(self):
- config = collections.OrderedDict()
- for item in self._config:
- self.expand(item, config)
- return '\n'.join(self.flatten(config))
+ text = ''
+ for item in self.items:
+ if not item.parents:
+ expand = self.get_section(item.text)
+ text += '%s\n' % self.get_section(item.text)
+ return str(text).strip()
def load(self, contents):
self._config = parse(contents, indent=self.indent)
@@ -167,6 +193,45 @@ class NetworkConfig(object):
if c.raw not in current_level:
current_level[c.raw] = collections.OrderedDict()
+ def to_lines(self, section):
+ lines = list()
+ for entry in section[1:]:
+ line = ['set']
+ line.extend([p.text for p in entry.parents])
+ line.append(entry.text)
+ lines.append(' '.join(line))
+ return lines
+
+ def to_block(self, section):
+ return '\n'.join([item.raw for item in section])
+
+ def get_section(self, path):
+ try:
+ section = self.get_section_objects(path)
+ if self._device_os == 'junos':
+ return self.to_lines(section)
+ return self.to_block(section)
+ except ValueError:
+ return list()
+
+ def get_section_objects(self, path):
+ if not isinstance(path, list):
+ path = [path]
+ obj = self.get_object(path)
+ if not obj:
+ raise ValueError('path does not exist in config')
+ return self.expand_section(obj)
+
+ def expand_section(self, configobj, S=None):
+ if S is None:
+ S = list()
+ S.append(configobj)
+ for child in configobj.children:
+ if child in S:
+ continue
+ self.expand_section(child, S)
+ return S
+
def flatten(self, data, obj=None):
if obj is None:
obj = list()
@@ -237,155 +302,83 @@ class NetworkConfig(object):
return self.flatten(diffs)
- def _build_children(self, children, parents=None, offset=0):
- for item in children:
- line = ConfigLine(item)
- line.raw = item.rjust(len(item) + offset)
- if parents:
- line.parents = parents
- parents[-1].children.append(line)
- yield line
-
- def add(self, lines, parents=None):
- offset = 0
+ def replace(self, replace, text=None, regex=None, parents=None,
+ add_if_missing=False, ignore_whitespace=False):
+ match = None
- config = list()
- parent = None
parents = parents or list()
+ if text is None and regex is None:
+ raise ValueError('missing required arguments')
- for item in parents:
- line = ConfigLine(item)
- line.raw = item.rjust(len(item) + offset)
- config.append(line)
- if parent:
- parent.children.append(line)
- if parent.parents:
- line.parents.append(*parent.parents)
- line.parents.append(parent)
- parent = line
- offset += self.indent
-
- self._config.extend(config)
- self._config.extend(list(self._build_children(lines, config, offset)))
-
-
-
-class Conditional(object):
- """Used in command modules to evaluate waitfor conditions
- """
-
- OPERATORS = {
- 'eq': ['eq', '=='],
- 'neq': ['neq', 'ne', '!='],
- 'gt': ['gt', '>'],
- 'ge': ['ge', '>='],
- 'lt': ['lt', '<'],
- 'le': ['le', '<='],
- 'contains': ['contains']
- }
-
- def __init__(self, conditional, encoding='json'):
- self.raw = conditional
- self.encoding = encoding
-
- key, op, val = shlex.split(conditional)
- self.key = key
- self.func = self.func(op)
- self.value = self._cast_value(val)
-
- def __call__(self, data):
- value = self.get_value(dict(result=data))
- return self.func(value)
-
- def _cast_value(self, value):
- if value in BOOLEANS_TRUE:
- return True
- elif value in BOOLEANS_FALSE:
- return False
- elif re.match(r'^\d+\.d+$', value):
- return float(value)
- elif re.match(r'^\d+$', value):
- return int(value)
- else:
- return unicode(value)
+ if not regex:
+ regex = ['^%s$' % text]
- def func(self, oper):
- for func, operators in self.OPERATORS.items():
- if oper in operators:
- return getattr(self, func)
- raise AttributeError('unknown operator: %s' % oper)
+ patterns = [re.compile(r, re.I) for r in to_list(regex)]
- def get_value(self, result):
- if self.encoding in ['json', 'text']:
- return self.get_json(result)
- elif self.encoding == 'xml':
- return self.get_xml(result.get('result'))
-
- def get_xml(self, result):
- parts = self.key.split('.')
+ for item in self.items:
+ for regexp in patterns:
+ string = ignore_whitespace is True and item.text or item.raw
+ if regexp.search(item.text):
+ if item.text != replace:
+ if parents == [p.text for p in item.parents]:
+ match = item
+ break
- value_index = None
- match = re.match(r'^\S+(\[)(\d+)\]', parts[-1])
if match:
- start, end = match.regs[1]
- parts[-1] = parts[-1][0:start]
- value_index = int(match.group(2))
-
- path = '/'.join(parts[1:])
- path = '/%s' % path
- path += '/text()'
-
- index = int(re.match(r'result\[(\d+)\]', parts[0]).group(1))
- values = result[index].xpath(path)
-
- if value_index is not None:
- return values[value_index].strip()
- return [v.strip() for v in values]
-
- def get_json(self, result):
- parts = re.split(r'\.(?=[^\]]*(?:\[|$))', self.key)
- for part in parts:
- match = re.findall(r'\[(\S+?)\]', part)
- if match:
- key = part[:part.find('[')]
- result = result[key]
- for m in match:
- try:
- m = int(m)
- except ValueError:
- m = str(m)
- result = result[m]
- else:
- result = result.get(part)
- return result
-
- def number(self, value):
- if '.' in str(value):
- return float(value)
- else:
- return int(value)
-
- def eq(self, value):
- return value == self.value
+ match.text = replace
+ indent = len(match.raw) - len(match.raw.lstrip())
+ match.raw = replace.rjust(len(replace) + indent)
- def neq(self, value):
- return value != self.value
+ elif add_if_missing:
+ self.add(replace, parents=parents)
- def gt(self, value):
- return self.number(value) > self.value
- def ge(self, value):
- return self.number(value) >= self.value
-
- def lt(self, value):
- return self.number(value) < self.value
-
- def le(self, value):
- return self.number(value) <= self.value
+ def add(self, lines, parents=None):
+ """Adds one or lines of configuration
+ """
- def contains(self, value):
- return str(self.value) in value
+ ancestors = list()
+ offset = 0
+ obj = None
+ ## global config command
+ if not parents:
+ for line in to_list(lines):
+ item = ConfigLine(line)
+ item.raw = line
+ if item not in self.items:
+ self.items.append(item)
+ else:
+ for index, p in enumerate(parents):
+ try:
+ i = index + 1
+ obj = self.get_section_objects(parents[:i])[0]
+ ancestors.append(obj)
+
+ except ValueError:
+ # add parent to config
+ offset = index * self.indent
+ obj = ConfigLine(p)
+ obj.raw = p.rjust(len(p) + offset)
+ if ancestors:
+ obj.parents = list(ancestors)
+ ancestors[-1].children.append(obj)
+ self.items.append(obj)
+ ancestors.append(obj)
+
+ # add child objects
+ for line in to_list(lines):
+ # check if child already exists
+ for child in ancestors[-1].children:
+ if child.text == line:
+ break
+ else:
+ offset = len(parents) * self.indent
+ item = ConfigLine(line)
+ item.raw = line.rjust(len(line) + offset)
+ item.parents = ancestors
+ ancestors[-1].children.append(item)
+ self.items.append(item)
diff --git a/lib/ansible/module_utils/netcmd.py b/lib/ansible/module_utils/netcmd.py
new file mode 100644
index 0000000000..11254b78c9
--- /dev/null
+++ b/lib/ansible/module_utils/netcmd.py
@@ -0,0 +1,202 @@
+#
+# (c) 2015 Peter Sprygada, <psprygada@ansible.com>
+#
+# This file is part of Ansible
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
+#
+
+import re
+import time
+import collections
+import itertools
+import shlex
+
+from ansible.module_utils.basic import BOOLEANS_TRUE, BOOLEANS_FALSE
+
+class Conditional(object):
+ """Used in command modules to evaluate waitfor conditions
+ """
+
+ OPERATORS = {
+ 'eq': ['eq', '=='],
+ 'neq': ['neq', 'ne', '!='],
+ 'gt': ['gt', '>'],
+ 'ge': ['ge', '>='],
+ 'lt': ['lt', '<'],
+ 'le': ['le', '<='],
+ 'contains': ['contains'],
+ 'matches': ['matches']
+ }
+
+ def __init__(self, conditional, encoding='json'):
+ self.raw = conditional
+ self.encoding = encoding
+
+ key, op, val = shlex.split(conditional)
+ self.key = key
+ self.func = self.func(op)
+ self.value = self._cast_value(val)
+
+ def __call__(self, data):
+ value = self.get_value(dict(result=data))
+ return self.func(value)
+
+ def _cast_value(self, value):
+ if value in BOOLEANS_TRUE:
+ return True
+ elif value in BOOLEANS_FALSE:
+ return False
+ elif re.match(r'^\d+\.d+$', value):
+ return float(value)
+ elif re.match(r'^\d+$', value):
+ return int(value)
+ else:
+ return unicode(value)
+
+ def func(self, oper):
+ for func, operators in self.OPERATORS.items():
+ if oper in operators:
+ return getattr(self, func)
+ raise AttributeError('unknown operator: %s' % oper)
+
+ def get_value(self, result):
+ if self.encoding in ['json', 'text']:
+ return self.get_json(result)
+ elif self.encoding == 'xml':
+ return self.get_xml(result.get('result'))
+
+ def get_xml(self, result):
+ parts = self.key.split('.')
+
+ value_index = None
+ match = re.match(r'^\S+(\[)(\d+)\]', parts[-1])
+ if match:
+ start, end = match.regs[1]
+ parts[-1] = parts[-1][0:start]
+ value_index = int(match.group(2))
+
+ path = '/'.join(parts[1:])
+ path = '/%s' % path
+ path += '/text()'
+
+ index = int(re.match(r'result\[(\d+)\]', parts[0]).group(1))
+ values = result[index].xpath(path)
+
+ if value_index is not None:
+ return values[value_index].strip()
+ return [v.strip() for v in values]
+
+ def get_json(self, result):
+ parts = re.split(r'\.(?=[^\]]*(?:\[|$))', self.key)
+ for part in parts:
+ match = re.findall(r'\[(\S+?)\]', part)
+ if match:
+ key = part[:part.find('[')]
+ result = result[key]
+ for m in match:
+ try:
+ m = int(m)
+ except ValueError:
+ m = str(m)
+ result = result[m]
+ else:
+ result = result.get(part)
+ return result
+
+ def number(self, value):
+ if '.' in str(value):
+ return float(value)
+ else:
+ return int(value)
+
+ def eq(self, value):
+ return value == self.value
+
+ def neq(self, value):
+ return value != self.value
+
+ def gt(self, value):
+ return self.number(value) > self.value
+
+ def ge(self, value):
+ return self.number(value) >= self.value
+
+ def lt(self, value):
+ return self.number(value) < self.value
+
+ def le(self, value):
+ return self.number(value) <= self.value
+
+ def contains(self, value):
+ return str(self.value) in value
+
+ def matches(self, value):
+ match = re.search(value, self.value, re.M)
+ return match is not None
+
+
+class FailedConditionsError(Exception):
+
+ def __init__(self, msg, failed_conditions):
+ super(FailedConditionsError, self).__init__(msg)
+ self.failed_conditions = failed_conditions
+
+class CommandRunner(collections.Mapping):
+
+ def __init__(self, module):
+ self.module = module
+
+ self.items = dict()
+ self.conditionals = set()
+
+ self.retries = 10
+ self.interval = 1
+
+ def __getitem__(self, key):
+ return self.items[key]
+
+ def __len__(self):
+ return len(self.items)
+
+ def __iter__(self):
+ return iter(self.items)
+
+ def add_command(self, command, output=None):
+ self.module.cli.add_commands(command, output=output)
+
+ def add_conditional(self, condition):
+ self.conditionals.add(Conditional(condition))
+
+ def run_commands(self):
+ responses = self.module.cli.run_commands()
+ for cmd, resp in itertools.izip(self.module.cli.commands, responses):
+ self.items[str(cmd)] = resp
+
+ def run(self):
+ while self.retries > 0:
+ self.run_commands()
+ for item in list(self.conditionals):
+ if item(self.items.values()):
+ self.conditionals.remove(item)
+
+ if not self.conditionals:
+ break
+
+ time.sleep(self.interval)
+ self.retries -= 1
+ else:
+ failed_conditions = [item.raw for item in self.conditionals]
+ raise FailedConditionsError('timeout waiting for value', failed_conditions)
+
diff --git a/lib/ansible/module_utils/network.py b/lib/ansible/module_utils/network.py
new file mode 100644
index 0000000000..19c0c2c532
--- /dev/null
+++ b/lib/ansible/module_utils/network.py
@@ -0,0 +1,282 @@
+#
+# (c) 2015 Peter Sprygada, <psprygada@ansible.com>
+#
+# This file is part of Ansible
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
+#
+
+from ansible.module_utils.basic import AnsibleModule
+from ansible.module_utils.basic import env_fallback, get_exception
+from ansible.module_utils.shell import Shell, ShellError, HAS_PARAMIKO
+
+NET_TRANSPORT_ARGS = dict(
+ host=dict(required=True),
+ port=dict(type='int'),
+ username=dict(fallback=(env_fallback, ['ANSIBLE_NET_USERNAME'])),
+ password=dict(no_log=True, fallback=(env_fallback, ['ANSIBLE_NET_PASSWORD'])),
+ ssh_keyfile=dict(fallback=(env_fallback, ['ANSIBLE_NET_SSH_KEYFILE']), type='path'),
+ authorize=dict(default=False, fallback=(env_fallback, ['ANSIBLE_NET_AUTHORIZE']), type='bool'),
+ auth_pass=dict(no_log=True, fallback=(env_fallback, ['ANSIBLE_NET_AUTH_PASS'])),
+ provider=dict(type='dict'),
+ transport=dict(choices=list()),
+ timeout=dict(default=10, type='int')
+)
+
+NET_CONNECTION_ARGS = dict()
+
+NET_CONNECTIONS = dict()
+
+
+def to_list(val):
+ if isinstance(val, (list, tuple)):
+ return list(val)
+ elif val is not None:
+ return [val]
+ else:
+ return list()
+
+def connect(module):
+ try:
+ if not module.connected:
+ module.connection.connect(module.params)
+ if module.params['authorize']:
+ module.connection.authorize(module.params)
+ except NetworkError:
+ exc = get_exception()
+ module.fail_json(msg=exc.message)
+
+def disconnect(module):
+ try:
+ if module.connected:
+ module.connection.disconnect()
+ except NetworkError:
+ exc = get_exception()
+ module.fail_json(msg=exc.message)
+
+
+class Command(object):
+
+ def __init__(self, command, output=None, prompt=None, response=None):
+ self.command = command
+ self.output = output
+ self.prompt = prompt
+ self.response = response
+ self.conditions = set()
+
+ def __str__(self):
+ return self.command
+
+class Cli(object):
+
+ def __init__(self, connection):
+ self.connection = connection
+ self.default_output = connection.default_output or 'text'
+ self.commands = list()
+
+ def __call__(self, commands, output=None):
+ commands = self.to_command(commands, output)
+ return self.connection.run_commands(commands)
+
+ def to_command(self, commands, output=None):
+ output = output or self.default_output
+ objects = list()
+ for cmd in to_list(commands):
+ if not isinstance(cmd, Command):
+ cmd = Command(cmd, output)
+ objects.append(cmd)
+ return objects
+
+ def add_commands(self, commands, output=None):
+ commands = self.to_command(commands, output)
+ self.commands.extend(commands)
+
+ def run_commands(self):
+ return self.connection.run_commands(self.commands)
+
+class Config(object):
+
+ def __init__(self, connection):
+ self.connection = connection
+
+ def invoke(self, method, *args, **kwargs):
+ try:
+ return method(*args, **kwargs)
+ except AttributeError:
+ exc = get_exception()
+ raise NetworkError('undefined method "%s"' % method.__name__, exc=str(exc))
+ except NetworkError:
+ if raise_exc:
+ raise
+ exc = get_exception()
+ self.fail_json(msg=exc.message, **exc.kwargs)
+ except NotImplementedError:
+ raise NetworkError('method not supported "%s"' % method.__name__)
+
+ def __call__(self, commands):
+ lines = to_list(commands)
+ return self.invoke(self.connection.configure, commands)
+
+ def load_config(self, commands, **kwargs):
+ commands = to_list(commands)
+ return self.invoke(self.connection.load_config, commands, **kwargs)
+
+ def get_config(self, **kwargs):
+ return self.invoke(self.connection.get_config, **kwargs)
+
+ def commit_config(self, **kwargs):
+ return self.invoke(self.connection.commit_config, **kwargs)
+
+ def abort_config(self, **kwargs):
+ return self.invoke(self.connection.abort_config, **kwargs)
+
+ def save_config(self):
+ return self.invoke(self.connection.save_config)
+
+
+class NetworkError(Exception):
+
+ def __init__(self, msg, **kwargs):
+ super(NetworkError, self).__init__(msg)
+ self.kwargs = kwargs
+
+
+class NetworkModule(AnsibleModule):
+
+ def __init__(self, *args, **kwargs):
+ super(NetworkModule, self).__init__(*args, **kwargs)
+ self.connection = None
+ self._cli = None
+ self._config = None
+
+ @property
+ def cli(self):
+ if not self.connected:
+ connect(self)
+ if self._cli:
+ return self._cli
+ self._cli = Cli(self.connection)
+ return self._cli
+
+ @property
+ def config(self):
+ if not self.connected:
+ connect(self)
+ if self._config:
+ return self._config
+ self._config = Config(self.connection)
+ return self._config
+
+ @property
+ def connected(self):
+ return self.connection._connected
+
+ def _load_params(self):
+ super(NetworkModule, self)._load_params()
+ provider = self.params.get('provider') or dict()
+ for key, value in provider.items():
+ for args in [NET_TRANSPORT_ARGS, NET_CONNECTION_ARGS]:
+ if key in args:
+ if self.params.get(key) is None and value is not None:
+ self.params[key] = value
+
+
+class NetCli(object):
+ """Basic paramiko-based ssh transport any NetworkModule can use."""
+
+ def __init__(self):
+ if not HAS_PARAMIKO:
+ raise NetworkError(
+ msg='paramiko is required but does not appear to be installed. '
+ 'It can be installed using `pip install paramiko`'
+ )
+
+ self.shell = None
+ self._connected = False
+ self.default_output = 'text'
+
+ def connect(self, params, kickstart, **kwargs):
+ host = params['host']
+ port = params.get('port') or 22
+
+ username = params['username']
+ password = params.get('password')
+ key_file = params.get('ssh_keyfile')
+ timeout = params['timeout']
+
+ try:
+ self.shell = Shell(
+ kickstart=kickstart,
+ prompts_re=self.CLI_PROMPTS_RE,
+ errors_re=self.CLI_ERRORS_RE,
+ )
+ self.shell.open(
+ host, port=port, username=username, password=password,
+ key_filename=key_file, timeout=timeout,
+ )
+ except ShellError:
+ exc = get_exception()
+ raise NetworkError(
+ msg='failed to connect to %s:%s' % (host, port), exc=str(exc)
+ )
+
+ def disconnect(self, **kwargs):
+ self._connected = False
+ self.shell.close()
+
+ def execute(self, commands, **kwargs):
+ try:
+ return self.shell.send(commands)
+ except ShellError:
+ exc = get_exception()
+ raise NetworkError(exc.message, commands=commands)
+
+
+def get_module(connect_on_load=True, **kwargs):
+ argument_spec = NET_TRANSPORT_ARGS.copy()
+ argument_spec['transport']['choices'] = NET_CONNECTIONS.keys()
+ argument_spec.update(NET_CONNECTION_ARGS.copy())
+
+ if kwargs.get('argument_spec'):
+ argument_spec.update(kwargs['argument_spec'])
+ kwargs['argument_spec'] = argument_spec
+
+ module = NetworkModule(**kwargs)
+
+ try:
+ transport = module.params['transport'] or '__default__'
+ cls = NET_CONNECTIONS[transport]
+ module.connection = cls()
+ except KeyError:
+ module.fail_json(msg='Unknown transport or no default transport specified')
+ except (TypeError, NetworkError):
+ exc = get_exception()
+ module.fail_json(msg=exc.message)
+
+ if connect_on_load:
+ connect(module)
+
+ return module
+
+def register_transport(transport, default=False):
+ def register(cls):
+ NET_CONNECTIONS[transport] = cls
+ if default:
+ NET_CONNECTIONS['__default__'] = cls
+ return cls
+ return register
+
+def add_argument(key, value):
+ NET_CONNECTION_ARGS[key] = value
+
diff --git a/lib/ansible/module_utils/shell.py b/lib/ansible/module_utils/shell.py
index 641f6927ab..5e17df2573 100644
--- a/lib/ansible/module_utils/shell.py
+++ b/lib/ansible/module_utils/shell.py
@@ -19,6 +19,8 @@
import re
import socket
+from ansible.module_utils.basic import get_exception
+
# py2 vs py3; replace with six via ziploader
try:
from StringIO import StringIO
@@ -156,6 +158,9 @@ class Shell(object):
responses.append(self.receive(command))
except socket.timeout:
raise ShellError("timeout trying to send command", cmd)
+ except socket.error:
+ exc = get_exception()
+ raise ShellError("problem sending command to host: %s" % exc.message)
return responses
def close(self):
diff --git a/lib/ansible/playbook/handler.py b/lib/ansible/playbook/handler.py
index c8c1572e48..a611b72259 100644
--- a/lib/ansible/playbook/handler.py
+++ b/lib/ansible/playbook/handler.py
@@ -20,11 +20,13 @@ from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
from ansible.errors import AnsibleError
-#from ansible.inventory.host import Host
+from ansible.playbook.attribute import FieldAttribute
from ansible.playbook.task import Task
class Handler(Task):
+ _listen = FieldAttribute(isa='list')
+
def __init__(self, block=None, role=None, task_include=None):
self._flagged_hosts = []
diff --git a/lib/ansible/playbook/task.py b/lib/ansible/playbook/task.py
index d6eabf366c..309579c85a 100644
--- a/lib/ansible/playbook/task.py
+++ b/lib/ansible/playbook/task.py
@@ -105,7 +105,7 @@ class Task(Base, Conditional, Taggable, Become):
def get_name(self):
''' return the name of the task '''
- if self._role and self.name:
+ if self._role and self.name and ("%s : " % self._role._role_name) not in self.name:
return "%s : %s" % (self._role.get_name(), self.name)
elif self.name:
return self.name
diff --git a/lib/ansible/plugins/strategy/__init__.py b/lib/ansible/plugins/strategy/__init__.py
index 2881bf419e..8864d57736 100644
--- a/lib/ansible/plugins/strategy/__init__.py
+++ b/lib/ansible/plugins/strategy/__init__.py
@@ -100,7 +100,8 @@ class StrategyBase:
self._tqm = tqm
self._inventory = tqm.get_inventory()
self._workers = tqm.get_workers()
- self._notified_handlers = tqm.get_notified_handlers()
+ self._notified_handlers = tqm._notified_handlers
+ self._listening_handlers = tqm._listening_handlers
self._variable_manager = tqm.get_variable_manager()
self._loader = tqm.get_loader()
self._final_q = tqm._final_q
@@ -318,12 +319,63 @@ class StrategyBase:
original_host = get_original_host(task_result._host)
original_task = iterator.get_original_task(original_host, task_result._task)
- if handler_name not in self._notified_handlers:
- self._notified_handlers[handler_name] = []
- if original_host not in self._notified_handlers[handler_name]:
- self._notified_handlers[handler_name].append(original_host)
- display.vv("NOTIFIED HANDLER %s" % (handler_name,))
+ def search_handler_blocks(handler_name, handler_blocks):
+ for handler_block in handler_blocks:
+ for handler_task in handler_block.block:
+ handler_vars = self._variable_manager.get_vars(loader=self._loader, play=iterator._play, task=handler_task)
+ templar = Templar(loader=self._loader, variables=handler_vars)
+ try:
+ # first we check with the full result of get_name(), which may
+ # include the role name (if the handler is from a role). If that
+ # is not found, we resort to the simple name field, which doesn't
+ # have anything extra added to it.
+ target_handler_name = templar.template(handler_task.name)
+ if target_handler_name == handler_name:
+ return handler_task
+ else:
+ target_handler_name = templar.template(handler_task.get_name())
+ if target_handler_name == handler_name:
+ return handler_task
+ except (UndefinedError, AnsibleUndefinedVariable):
+ # We skip this handler due to the fact that it may be using
+ # a variable in the name that was conditionally included via
+ # set_fact or some other method, and we don't want to error
+ # out unnecessarily
+ continue
+ return None
+
+ # Find the handler using the above helper. First we look up the
+ # dependency chain of the current task (if it's from a role), otherwise
+ # we just look through the list of handlers in the current play/all
+ # roles and use the first one that matches the notify name
+ target_handler = None
+ if original_task._role:
+ target_handler = search_handler_blocks(handler_name, original_task._role.get_handler_blocks())
+ if target_handler is None:
+ target_handler = search_handler_blocks(handler_name, iterator._play.handlers)
+ if target_handler is None:
+ if handler_name in self._listening_handlers:
+ for listening_handler_name in self._listening_handlers[handler_name]:
+ listening_handler = None
+ if original_task._role:
+ listening_handler = search_handler_blocks(listening_handler_name, original_task._role.get_handler_blocks())
+ if listening_handler is None:
+ listening_handler = search_handler_blocks(listening_handler_name, iterator._play.handlers)
+ if listening_handler is None:
+ raise AnsibleError("The requested handler listener '%s' was not found in any of the known handlers" % listening_handler_name)
+
+ if original_host not in self._notified_handlers[listening_handler]:
+ self._notified_handlers[listening_handler].append(original_host)
+ display.vv("NOTIFIED HANDLER %s" % (listening_handler_name,))
+ else:
+ raise AnsibleError("The requested handler '%s' was found in neither the main handlers list nor the listening handlers list" % handler_name)
+ else:
+ if target_handler in self._notified_handlers:
+ if original_host not in self._notified_handlers[target_handler]:
+ self._notified_handlers[target_handler].append(original_host)
+ # FIXME: should this be a callback?
+ display.vv("NOTIFIED HANDLER %s" % (handler_name,))
elif result[0] == 'register_host_var':
# essentially the same as 'set_host_var' below, however we
@@ -572,25 +624,8 @@ class StrategyBase:
# but this may take some work in the iterator and gets tricky when
# we consider the ability of meta tasks to flush handlers
for handler in handler_block.block:
- handler_vars = self._variable_manager.get_vars(loader=self._loader, play=iterator._play, task=handler)
- templar = Templar(loader=self._loader, variables=handler_vars)
- try:
- # first we check with the full result of get_name(), which may
- # include the role name (if the handler is from a role). If that
- # is not found, we resort to the simple name field, which doesn't
- # have anything extra added to it.
- handler_name = templar.template(handler.name)
- if handler_name not in self._notified_handlers:
- handler_name = templar.template(handler.get_name())
- except (UndefinedError, AnsibleUndefinedVariable):
- # We skip this handler due to the fact that it may be using
- # a variable in the name that was conditionally included via
- # set_fact or some other method, and we don't want to error
- # out unnecessarily
- continue
-
- if handler_name in self._notified_handlers and len(self._notified_handlers[handler_name]):
- result = self._do_handler_run(handler, handler_name, iterator=iterator, play_context=play_context)
+ if handler in self._notified_handlers and len(self._notified_handlers[handler]):
+ result = self._do_handler_run(handler, handler.get_name(), iterator=iterator, play_context=play_context)
if not result:
break
return result
@@ -608,7 +643,7 @@ class StrategyBase:
handler.name = saved_name
if notified_hosts is None:
- notified_hosts = self._notified_handlers[handler_name]
+ notified_hosts = self._notified_handlers[handler]
run_once = False
try:
@@ -671,7 +706,7 @@ class StrategyBase:
continue
# wipe the notification list
- self._notified_handlers[handler_name] = []
+ self._notified_handlers[handler] = []
display.debug("done running handlers, result is: %s" % result)
return result
diff --git a/test/units/plugins/strategies/test_strategy_base.py b/test/units/plugins/strategies/test_strategy_base.py
index 338bcc1fd6..e079fa8d48 100644
--- a/test/units/plugins/strategies/test_strategy_base.py
+++ b/test/units/plugins/strategies/test_strategy_base.py
@@ -45,12 +45,16 @@ class TestStrategyBase(unittest.TestCase):
mock_tqm = MagicMock(TaskQueueManager)
mock_tqm._final_q = MagicMock()
mock_tqm._options = MagicMock()
+ mock_tqm._notified_handlers = {}
+ mock_tqm._listening_handlers = {}
strategy_base = StrategyBase(tqm=mock_tqm)
def test_strategy_base_run(self):
mock_tqm = MagicMock(TaskQueueManager)
mock_tqm._final_q = MagicMock()
mock_tqm._stats = MagicMock()
+ mock_tqm._notified_handlers = {}
+ mock_tqm._listening_handlers = {}
mock_tqm.send_callback.return_value = None
mock_iterator = MagicMock()
@@ -62,6 +66,8 @@ class TestStrategyBase(unittest.TestCase):
mock_tqm._failed_hosts = dict()
mock_tqm._unreachable_hosts = dict()
mock_tqm._options = MagicMock()
+ mock_tqm._notified_handlers = {}
+ mock_tqm._listening_handlers = {}
strategy_base = StrategyBase(tqm=mock_tqm)
mock_host = MagicMock()
@@ -89,6 +95,8 @@ class TestStrategyBase(unittest.TestCase):
mock_tqm = MagicMock()
mock_tqm._final_q = MagicMock()
+ mock_tqm._notified_handlers = {}
+ mock_tqm._listening_handlers = {}
mock_tqm.get_inventory.return_value = mock_inventory
mock_play = MagicMock()
@@ -153,6 +161,8 @@ class TestStrategyBase(unittest.TestCase):
mock_tqm._failed_hosts = dict()
mock_tqm._unreachable_hosts = dict()
mock_tqm.send_callback.return_value = None
+ mock_tqm._notified_handlers = {}
+ mock_tqm._listening_handlers = {}
queue_items = []
def _queue_empty(*args, **kwargs):
@@ -171,7 +181,10 @@ class TestStrategyBase(unittest.TestCase):
mock_tqm._stats = MagicMock()
mock_tqm._stats.increment.return_value = None
+ mock_play = MagicMock()
+
mock_iterator = MagicMock()
+ mock_iterator._play = mock_play
mock_iterator.mark_host_failed.return_value = None
mock_iterator.get_next_task_for_host.return_value = (None, None)
@@ -184,6 +197,18 @@ class TestStrategyBase(unittest.TestCase):
mock_task._role = None
mock_task.ignore_errors = False
+ mock_handler_task = MagicMock(Handler)
+ mock_handler_task.action = 'foo'
+ mock_handler_task.get_name.return_value = "test handler"
+ mock_handler_task.has_triggered.return_value = False
+
+ mock_handler_block = MagicMock()
+ mock_handler_block.block = [mock_handler_task]
+ mock_play.handlers = [mock_handler_block]
+
+ mock_tqm._notified_handlers = {mock_handler_task: []}
+ mock_tqm._listening_handlers = {}
+
mock_group = MagicMock()
mock_group.add_host.return_value = None
@@ -211,7 +236,6 @@ class TestStrategyBase(unittest.TestCase):
strategy_base._inventory = mock_inventory
strategy_base._variable_manager = mock_var_mgr
strategy_base._blocked_hosts = dict()
- strategy_base._notified_handlers = dict()
results = strategy_base._wait_on_pending_results(iterator=mock_iterator)
self.assertEqual(len(results), 0)
@@ -281,8 +305,8 @@ class TestStrategyBase(unittest.TestCase):
self.assertEqual(len(results), 0)
self.assertEqual(strategy_base._pending_results, 1)
self.assertIn('test01', strategy_base._blocked_hosts)
- self.assertIn('test handler', strategy_base._notified_handlers)
- self.assertIn(mock_host, strategy_base._notified_handlers['test handler'])
+ self.assertIn(mock_handler_task, strategy_base._notified_handlers)
+ self.assertIn(mock_host, strategy_base._notified_handlers[mock_handler_task])
queue_items.append(('set_host_var', mock_host, mock_task, None, 'foo', 'bar'))
results = strategy_base._process_pending_results(iterator=mock_iterator)
@@ -308,6 +332,8 @@ class TestStrategyBase(unittest.TestCase):
mock_tqm = MagicMock()
mock_tqm._final_q = MagicMock()
+ mock_tqm._notified_handlers = {}
+ mock_tqm._listening_handlers = {}
strategy_base = StrategyBase(tqm=mock_tqm)
strategy_base._loader = fake_loader
@@ -379,13 +405,14 @@ class TestStrategyBase(unittest.TestCase):
passwords=None,
)
tqm._initialize_processes(3)
+ tqm._initialize_notified_handlers(mock_play)
tqm.hostvars = dict()
try:
strategy_base = StrategyBase(tqm=tqm)
strategy_base._inventory = mock_inventory
- strategy_base._notified_handlers = {"test handler": [mock_host]}
+ strategy_base._notified_handlers = {mock_handler_task: [mock_host]}
task_result = TaskResult(Host('host01'), Handler(), dict(changed=False))
tqm._final_q.put(('host_task_ok', task_result))