summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorToshio Kuratomi <toshio@fedoraproject.org>2015-06-29 08:05:58 -0700
committerToshio Kuratomi <toshio@fedoraproject.org>2015-06-29 08:05:58 -0700
commitbe6db1a730270a8e89636da9630dcac8e3e093fc (patch)
tree5b43932419e15dc732ee73ae61aa1a08dee2ab19
parent881dbb6da122598029107e63dc6b1cfe51f2bc2c (diff)
downloadansible-argspec-path-and-refactor.tar.gz
Refactor the argspec type checking and add path as a typeargspec-path-and-refactor
-rw-r--r--lib/ansible/module_utils/basic.py146
1 files changed, 90 insertions, 56 deletions
diff --git a/lib/ansible/module_utils/basic.py b/lib/ansible/module_utils/basic.py
index ffd159601d..e89809ff12 100644
--- a/lib/ansible/module_utils/basic.py
+++ b/lib/ansible/module_utils/basic.py
@@ -351,9 +351,9 @@ class AnsibleModule(object):
self.check_mode = False
self.no_log = no_log
self.cleanup_files = []
-
+
self.aliases = {}
-
+
if add_file_common_args:
for k, v in FILE_COMMON_ARGUMENTS.iteritems():
if k not in self.argument_spec:
@@ -366,7 +366,7 @@ class AnsibleModule(object):
self.params = self._load_params()
self._legal_inputs = ['_ansible_check_mode', '_ansible_no_log']
-
+
self.aliases = self._handle_aliases()
if check_invalid_arguments:
@@ -380,6 +380,16 @@ class AnsibleModule(object):
self._set_defaults(pre=True)
+
+ self._CHECK_ARGUMENT_TYPES_DISPATCHER = {
+ 'str': self._check_type_str,
+ 'list': self._check_type_list,
+ 'dict': self._check_type_dict,
+ 'bool': self._check_type_bool,
+ 'int': self._check_type_int,
+ 'float': self._check_type_float,
+ 'path': self._check_type_path,
+ }
if not bypass_checks:
self._check_required_arguments()
self._check_argument_values()
@@ -1021,6 +1031,76 @@ class AnsibleModule(object):
return (str, e)
return str
+ def _check_type_str(self, value):
+ if isinstance(value, basestring):
+ return value
+ # Note: This could throw a unicode error if value's __str__() method
+ # returns non-ascii. Have to port utils.to_bytes() if that happens
+ return str(value)
+
+ def _check_type_list(self, value):
+ if isinstance(value, list):
+ return value
+
+ if isinstance(value, basestring):
+ return value.split(",")
+ elif isinstance(value, int) or isinstance(value, float):
+ return [ str(value) ]
+
+ raise TypeError('%s cannot be converted to a list' % type(value))
+
+ def _check_type_dict(self, value):
+ if isinstance(value, dict):
+ return value
+
+ if isinstance(value, basestring):
+ if value.startswith("{"):
+ try:
+ return json.loads(value)
+ except:
+ (result, exc) = self.safe_eval(value, dict(), include_exceptions=True)
+ if exc is not None:
+ raise TypeError('unable to evaluate string as dictionary')
+ return result
+ elif '=' in value:
+ return dict([x.strip().split("=", 1) for x in value.split(",")])
+ else:
+ raise TypeError("dictionary requested, could not parse JSON or key=value")
+
+ raise TypeError('%s cannot be converted to a dict' % type(value))
+
+ def _check_type_bool(self, value):
+ if isinstance(value, bool):
+ return value
+
+ if isinstance(value, basestring):
+ return self.boolean(value)
+
+ raise TypeError('%s cannot be converted to a bool' % type(value))
+
+ def _check_type_int(self, value):
+ if isinstance(value, int):
+ return value
+
+ if isinstance(value, basestring):
+ return int(value)
+
+ raise TypeError('%s cannot be converted to an int' % type(value))
+
+ def _check_type_float(self, value):
+ if isinstance(value, float):
+ return value
+
+ if isinstance(value, basestring):
+ return float(value)
+
+ raise TypeError('%s cannot be converted to a float' % type(value))
+
+ def _check_type_path(self, value):
+ value = self._check_type_str(value)
+ return os.path.expanduser(os.path.expandvars(value))
+
+
def _check_argument_types(self):
''' ensure all arguments have the requested type '''
for (k, v) in self.argument_spec.iteritems():
@@ -1034,59 +1114,13 @@ class AnsibleModule(object):
is_invalid = False
try:
- if wanted == 'str':
- if not isinstance(value, basestring):
- self.params[k] = str(value)
- elif wanted == 'list':
- if not isinstance(value, list):
- if isinstance(value, basestring):
- self.params[k] = value.split(",")
- elif isinstance(value, int) or isinstance(value, float):
- self.params[k] = [ str(value) ]
- else:
- is_invalid = True
- elif wanted == 'dict':
- if not isinstance(value, dict):
- if isinstance(value, basestring):
- if value.startswith("{"):
- try:
- self.params[k] = json.loads(value)
- except:
- (result, exc) = self.safe_eval(value, dict(), include_exceptions=True)
- if exc is not None:
- self.fail_json(msg="unable to evaluate dictionary for %s" % k)
- self.params[k] = result
- elif '=' in value:
- self.params[k] = dict([x.strip().split("=", 1) for x in value.split(",")])
- else:
- self.fail_json(msg="dictionary requested, could not parse JSON or key=value")
- else:
- is_invalid = True
- elif wanted == 'bool':
- if not isinstance(value, bool):
- if isinstance(value, basestring):
- self.params[k] = self.boolean(value)
- else:
- is_invalid = True
- elif wanted == 'int':
- if not isinstance(value, int):
- if isinstance(value, basestring):
- self.params[k] = int(value)
- else:
- is_invalid = True
- elif wanted == 'float':
- if not isinstance(value, float):
- if isinstance(value, basestring):
- self.params[k] = float(value)
- else:
- is_invalid = True
- else:
- self.fail_json(msg="implementation error: unknown type %s requested for %s" % (wanted, k))
-
- if is_invalid:
- self.fail_json(msg="argument %s is of invalid type: %s, required: %s" % (k, type(value), wanted))
- except ValueError:
- self.fail_json(msg="value of argument %s is not of type %s and we were unable to automatically convert" % (k, wanted))
+ type_checker = self._CHECK_ARGUMENT_TYPES_DISPATCHER[wanted]
+ except KeyError:
+ self.fail_json(msg="implementation error: unknown type %s requested for %s" % (wanted, k))
+ try:
+ self.params[k] = type_checker(value)
+ except (TypeError, ValueError):
+ self.fail_json(msg="argument %s is of type %s and we were unable to convert to %s" % (k, type(value), wanted))
def _set_defaults(self, pre=True):
for (k,v) in self.argument_spec.iteritems():