summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPeter Sprygada <privateip@users.noreply.github.com>2017-02-14 14:38:30 -0500
committerGitHub <noreply@github.com>2017-02-14 14:38:30 -0500
commitb0c01bbb820802bca53cd7274ef161ff64335b1f (patch)
treefaed75ff11b78210adbf2fb6813c5648f92c0c67
parent009ac075b7de7450efddc1e097af0b2f9589c595 (diff)
downloadansible-b0c01bbb820802bca53cd7274ef161ff64335b1f.tar.gz
updates network_common lib (#21306)
* removes connection functions refactored into connection * updates ComplexDict and ComplexList objects to use with AnsibleModule * updates modules to add new argument to ComplexList & ComplexDict
-rw-r--r--lib/ansible/module_utils/network_common.py198
-rw-r--r--lib/ansible/modules/network/eos/eos_command.py5
-rw-r--r--lib/ansible/modules/network/eos/eos_system.py4
-rw-r--r--lib/ansible/modules/network/ios/ios_command.py2
-rw-r--r--lib/ansible/modules/network/ios/ios_system.py6
-rw-r--r--lib/ansible/modules/network/iosxr/iosxr_command.py2
-rw-r--r--lib/ansible/modules/network/vyos/vyos_command.py3
7 files changed, 114 insertions, 106 deletions
diff --git a/lib/ansible/module_utils/network_common.py b/lib/ansible/module_utils/network_common.py
index 82c0cc9f8f..add4ce18ea 100644
--- a/lib/ansible/module_utils/network_common.py
+++ b/lib/ansible/module_utils/network_common.py
@@ -25,13 +25,8 @@
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
-import socket
-import struct
-import signal
-
-from ansible.module_utils.basic import get_exception
-from ansible.module_utils._text import to_bytes, to_native
from ansible.module_utils.six import iteritems
+from ansible.module_utils.basic import AnsibleFallbackNotFound
def to_list(val):
if isinstance(val, (list, tuple, set)):
@@ -41,103 +36,116 @@ def to_list(val):
else:
return list()
-class ComplexDict:
-
- def __init__(self, attrs):
+class ComplexDict(object):
+ """Transforms a dict to with an argument spec
+
+ This class will take a dict and apply an Ansible argument spec to the
+ values. The resulting dict will contain all of the keys in the param
+ with appropriate values set.
+
+ Example::
+
+ argument_spec = dict(
+ command=dict(key=True),
+ display=dict(default='text', choices=['text', 'json']),
+ validate=dict(type='bool')
+ )
+ transform = ComplexDict(argument_spec, module)
+ value = dict(command='foo')
+ result = transform(value)
+ print result
+ {'command': 'foo', 'display': 'text', 'validate': None}
+
+ Supported argument spec:
+ * key - specifies how to map a single value to a dict
+ * read_from - read and apply the argument_spec from the module
+ * required - a value is required
+ * type - type of value (uses AnsibleModule type checker)
+ * fallback - implements fallback function
+ * choices - set of valid options
+ * default - default value
+
+ """
+
+ def __init__(self, attrs, module):
self._attributes = attrs
+ self._module = module
self.attr_names = frozenset(self._attributes.keys())
+
+ self._has_key = False
for name, attr in iteritems(self._attributes):
+ if attr.get('read_from'):
+ spec = self._module.argument_spec.get(attr['read_from'])
+ if not spec:
+ raise ValueError('argument_spec %s does not exist' % attr['read_from'])
+ for key, value in iteritems(spec):
+ if key not in attr:
+ attr[key] = value
+
if attr.get('key'):
+ if self._has_key:
+ raise ValueError('only one key value can be specified')
+ self_has_key = True
attr['required'] = True
- def __call__(self, value):
- if isinstance(value, dict):
- unknown = set(value.keys()).difference(self.attr_names)
- if unknown:
- raise ValueError('invalid keys: %s' % ','.join(unknown))
- for name, attr in iteritems(self._attributes):
- if attr.get('required') and name not in value:
- raise ValueError('missing required attribute %s' % name)
- if not value.get(name):
- value[name] = attr.get('default')
- return value
- else:
- obj = {}
- for name, attr in iteritems(self._attributes):
- if attr.get('key'):
- obj[name] = value
- else:
- obj[name] = attr.get('default')
- return obj
-
-
-class ComplexList:
-
- def __init__(self, attrs):
- self._attributes = attrs
- self.attr_names = frozenset(self._attributes.keys())
+
+ def _dict(self, value):
+ obj = {}
for name, attr in iteritems(self._attributes):
if attr.get('key'):
- attr['required'] = True
+ obj[name] = value
+ else:
+ obj[name] = attr.get('default')
+ return obj
+
+ def __call__(self, value):
+ if not isinstance(value, dict):
+ value = self._dict(value)
+ unknown = set(value).difference(self.attr_names)
+ if unknown:
+ raise ValueError('invalid keys: %s' % ','.join(unknown))
+
+ for name, attr in iteritems(self._attributes):
+ if not value.get(name):
+ value[name] = attr.get('default')
+
+ if attr.get('fallback') and not value.get(name):
+ fallback = attr.get('fallback', (None,))
+ fallback_strategy = fallback[0]
+ fallback_args = []
+ fallback_kwargs = {}
+ if fallback_strategy is not None:
+ for item in fallback[1:]:
+ if isinstance(item, dict):
+ fallback_kwargs = item
+ else:
+ fallback_args = item
+ try:
+ value[name] = fallback_strategy(*fallback_args, **fallback_kwargs)
+ except AnsibleFallbackNotFound:
+ continue
+
+ if attr.get('required') and value.get(name) is None:
+ raise ValueError('missing required attribute %s' % name)
+
+ if 'choices' in attr:
+ if value[name] not in attr['choices']:
+ raise ValueError('%s must be one of %s, got %s' % \
+ (name, ', '.join(attr['choices']), value[name]))
+
+ if value[name] is not None:
+ value_type = attr.get('type', 'str')
+ type_checker = self._module._CHECK_ARGUMENT_TYPES_DISPATCHER[value_type]
+ type_checker(value[name])
+
+ return value
+
+class ComplexList(ComplexDict):
+ """Extends ```ComplexDict``` to handle a list of dicts """
def __call__(self, values):
- objects = list()
- for value in values:
- if isinstance(value, dict):
- for name, attr in iteritems(self._attributes):
- if attr.get('required') and name not in value:
- raise ValueError('missing required attr %s' % name)
- if not value.get(name):
- value[name] = attr.get('default')
- objects.append(value)
- else:
- obj = {}
- for name, attr in iteritems(self._attributes):
- if attr.get('key'):
- obj[name] = value
- else:
- obj[name] = attr.get('default')
- objects.append(obj)
- return objects
-
-def send_data(s, data):
- packed_len = struct.pack('!Q',len(data))
- return s.sendall(packed_len + data)
-
-def recv_data(s):
- header_len = 8 # size of a packed unsigned long long
- data = to_bytes("")
- while len(data) < header_len:
- d = s.recv(header_len - len(data))
- if not d:
- return None
- data += d
- data_len = struct.unpack('!Q',data[:header_len])[0]
- data = data[header_len:]
- while len(data) < data_len:
- d = s.recv(data_len - len(data))
- if not d:
- return None
- data += d
- return data
-
-def exec_command(module, command):
- try:
- sf = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
- sf.connect(module._socket_path)
-
- data = "EXEC: %s" % command
- send_data(sf, to_bytes(data.strip()))
-
- rc = int(recv_data(sf), 10)
- stdout = recv_data(sf)
- stderr = recv_data(sf)
- except socket.error:
- exc = get_exception()
- sf.close()
- module.fail_json(msg='unable to connect to socket', err=str(exc))
-
- sf.close()
-
- return (rc, to_native(stdout), to_native(stderr))
+ if not isinstance(values, (list, tuple)):
+ raise TypeError('value must be an ordered iterable')
+ return [(super(ComplexList, self).__call__(v)) for v in values]
+
diff --git a/lib/ansible/modules/network/eos/eos_command.py b/lib/ansible/modules/network/eos/eos_command.py
index 11b0f59eda..ecce61c8c4 100644
--- a/lib/ansible/modules/network/eos/eos_command.py
+++ b/lib/ansible/modules/network/eos/eos_command.py
@@ -143,13 +143,14 @@ def to_lines(stdout):
return lines
def parse_commands(module, warnings):
- transform = ComplexList(dict(
+ spec = dict(
command=dict(key=True),
output=dict(),
prompt=dict(),
response=dict()
- ))
+ )
+ transform = ComplexList(spec, module)
commands = transform(module.params['commands'])
for index, item in enumerate(commands):
diff --git a/lib/ansible/modules/network/eos/eos_system.py b/lib/ansible/modules/network/eos/eos_system.py
index ae83ef0fbc..bd4d48be3a 100644
--- a/lib/ansible/modules/network/eos/eos_system.py
+++ b/lib/ansible/modules/network/eos/eos_system.py
@@ -272,12 +272,12 @@ def map_params_to_obj(module):
lookup_source = ComplexList(dict(
interface=dict(key=True),
vrf=dict()
- ))
+ ), module)
name_servers = ComplexList(dict(
server=dict(key=True),
vrf=dict(default='default')
- ))
+ ), module)
for arg, cast in [('lookup_source', lookup_source), ('name_servers', name_servers)]:
if module.params[arg] is not None:
diff --git a/lib/ansible/modules/network/ios/ios_command.py b/lib/ansible/modules/network/ios/ios_command.py
index 2fa2bbb6fc..4abd06c6e4 100644
--- a/lib/ansible/modules/network/ios/ios_command.py
+++ b/lib/ansible/modules/network/ios/ios_command.py
@@ -149,7 +149,7 @@ def parse_commands(module, warnings):
command=dict(key=True),
prompt=dict(),
response=dict()
- ))
+ ), module)
commands = command(module.params['commands'])
for index, item in enumerate(commands):
if module.check_mode and not item['command'].startswith('show'):
diff --git a/lib/ansible/modules/network/ios/ios_system.py b/lib/ansible/modules/network/ios/ios_system.py
index 3e2772dfda..1bc6edb53d 100644
--- a/lib/ansible/modules/network/ios/ios_system.py
+++ b/lib/ansible/modules/network/ios/ios_system.py
@@ -311,17 +311,17 @@ def map_params_to_obj(module):
domain_name = ComplexList(dict(
name=dict(key=True),
vrf=dict()
- ))
+ ), module)
domain_search = ComplexList(dict(
name=dict(key=True),
vrf=dict()
- ))
+ ), module)
name_servers = ComplexList(dict(
server=dict(key=True),
vrf=dict()
- ))
+ ), module)
for arg, cast in [('domain_name', domain_name),
('domain_search', domain_search),
diff --git a/lib/ansible/modules/network/iosxr/iosxr_command.py b/lib/ansible/modules/network/iosxr/iosxr_command.py
index 5560321b92..a81ce18ade 100644
--- a/lib/ansible/modules/network/iosxr/iosxr_command.py
+++ b/lib/ansible/modules/network/iosxr/iosxr_command.py
@@ -163,7 +163,7 @@ def parse_commands(module, warnings):
command=dict(key=True),
prompt=dict(),
response=dict()
- ))
+ ), module)
commands = command(module.params['commands'])
for index, item in enumerate(commands):
diff --git a/lib/ansible/modules/network/vyos/vyos_command.py b/lib/ansible/modules/network/vyos/vyos_command.py
index 63218a94eb..2aefa622a2 100644
--- a/lib/ansible/modules/network/vyos/vyos_command.py
+++ b/lib/ansible/modules/network/vyos/vyos_command.py
@@ -152,8 +152,7 @@ def parse_commands(module, warnings):
command=dict(key=True),
prompt=dict(),
response=dict(),
- ))
-
+ ), module)
commands = command(module.params['commands'])
for index, cmd in enumerate(commands):