summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJenkins <jenkins@review.openstack.org>2013-10-22 21:40:17 +0000
committerGerrit Code Review <review@openstack.org>2013-10-22 21:40:17 +0000
commita80479b340a4fd6e557aa9547707c9867ff72526 (patch)
treea58329db603b6cabe76e58168df6a857a83edd55
parent2532be0e893af523f0d09d033a44b85a17522165 (diff)
parent7a8aa34c86530e34e7ed7d6d87344af2c2db0489 (diff)
downloadtaskflow-a80479b340a4fd6e557aa9547707c9867ff72526.tar.gz
Merge "Support for optional task arguments"0.1
-rw-r--r--taskflow/task.py31
-rw-r--r--taskflow/tests/unit/test_task.py13
-rw-r--r--taskflow/tests/unit/test_utils.py23
-rw-r--r--taskflow/utils/reflection.py13
4 files changed, 52 insertions, 28 deletions
diff --git a/taskflow/task.py b/taskflow/task.py
index 5925fde..feed142 100644
--- a/taskflow/task.py
+++ b/taskflow/task.py
@@ -78,34 +78,33 @@ def _build_rebind_dict(args, rebind_args):
raise TypeError('Invalid rebind value: %s' % rebind_args)
-def _check_args_mapping(task_name, rebind, args, accepts_kwargs):
- args = set(args)
- rebind = set(rebind.keys())
- extra_args = rebind - args
- missing_args = args - rebind
- if not accepts_kwargs and extra_args:
- raise ValueError('Extra arguments given to task %s: %s'
- % (task_name, sorted(extra_args)))
- if missing_args:
- raise ValueError('Missing arguments for task %s: %s'
- % (task_name, sorted(missing_args)))
-
-
def _build_arg_mapping(task_name, reqs, rebind_args, function, do_infer):
"""Given a function, its requirements and a rebind mapping this helper
function will build the correct argument mapping for the given function as
well as verify that the final argument mapping does not have missing or
extra arguments (where applicable).
"""
- task_args = reflection.get_required_callable_args(function)
- accepts_kwargs = reflection.accepts_kwargs(function)
+ task_args = reflection.get_callable_args(function, required_only=True)
result = {}
if reqs:
result.update((a, a) for a in reqs)
if do_infer:
result.update((a, a) for a in task_args)
result.update(_build_rebind_dict(task_args, rebind_args))
- _check_args_mapping(task_name, result, task_args, accepts_kwargs)
+
+ if not reflection.accepts_kwargs(function):
+ all_args = reflection.get_callable_args(function, required_only=False)
+ extra_args = set(result) - set(all_args)
+ if extra_args:
+ extra_args_str = ', '.join(sorted(extra_args))
+ raise ValueError('Extra arguments given to task %s: %s'
+ % (task_name, extra_args_str))
+
+ # NOTE(imelnikov): don't use set to preserve order in error message
+ missing_args = [arg for arg in task_args if arg not in result]
+ if missing_args:
+ raise ValueError('Missing arguments for task %s: %s'
+ % (task_name, ' ,'.join(missing_args)))
return result
diff --git a/taskflow/tests/unit/test_task.py b/taskflow/tests/unit/test_task.py
index e661aa4..e7b4365 100644
--- a/taskflow/tests/unit/test_task.py
+++ b/taskflow/tests/unit/test_task.py
@@ -31,6 +31,11 @@ class KwargsTask(task.Task):
pass
+class DefaultArgTask(task.Task):
+ def execute(self, spam, eggs=()):
+ pass
+
+
class DefaultProvidesTask(task.Task):
default_provides = 'def'
@@ -98,6 +103,14 @@ class TaskTestCase(test.TestCase):
with self.assertRaisesRegexp(ValueError, '^Missing arguments'):
MyTask(auto_extract=False, requires=('spam', 'eggs'))
+ def test_requires_ignores_optional(self):
+ my_task = DefaultArgTask()
+ self.assertEquals(my_task.requires, set(['spam']))
+
+ def test_requires_allows_optional(self):
+ my_task = DefaultArgTask(requires=('spam', 'eggs'))
+ self.assertEquals(my_task.requires, set(['spam', 'eggs']))
+
def test_rebind_all_args(self):
my_task = MyTask(rebind={'spam': 'a', 'eggs': 'b', 'context': 'c'})
self.assertEquals(my_task.rebind, {
diff --git a/taskflow/tests/unit/test_utils.py b/taskflow/tests/unit/test_utils.py
index ef88709..816d322 100644
--- a/taskflow/tests/unit/test_utils.py
+++ b/taskflow/tests/unit/test_utils.py
@@ -99,41 +99,46 @@ class GetCallableNameTest(test.TestCase):
'__call__')))
-class GetRequiredCallableArgsTest(test.TestCase):
+class GetCallableArgsTest(test.TestCase):
def test_mere_function(self):
- result = reflection.get_required_callable_args(mere_function)
+ result = reflection.get_callable_args(mere_function)
self.assertEquals(['a', 'b'], result)
def test_function_with_defaults(self):
- result = reflection.get_required_callable_args(function_with_defs)
+ result = reflection.get_callable_args(function_with_defs)
+ self.assertEquals(['a', 'b', 'optional'], result)
+
+ def test_required_only(self):
+ result = reflection.get_callable_args(function_with_defs,
+ required_only=True)
self.assertEquals(['a', 'b'], result)
def test_method(self):
- result = reflection.get_required_callable_args(Class.method)
+ result = reflection.get_callable_args(Class.method)
self.assertEquals(['self', 'c', 'd'], result)
def test_instance_method(self):
- result = reflection.get_required_callable_args(Class().method)
+ result = reflection.get_callable_args(Class().method)
self.assertEquals(['c', 'd'], result)
def test_class_method(self):
- result = reflection.get_required_callable_args(Class.class_method)
+ result = reflection.get_callable_args(Class.class_method)
self.assertEquals(['g', 'h'], result)
def test_class_constructor(self):
- result = reflection.get_required_callable_args(ClassWithInit)
+ result = reflection.get_callable_args(ClassWithInit)
self.assertEquals(['k', 'l'], result)
def test_class_with_call(self):
- result = reflection.get_required_callable_args(CallableClass())
+ result = reflection.get_callable_args(CallableClass())
self.assertEquals(['i', 'j'], result)
def test_decorators_work(self):
@lock_utils.locked
def locked_fun(x, y):
pass
- result = reflection.get_required_callable_args(locked_fun)
+ result = reflection.get_callable_args(locked_fun)
self.assertEquals(['x', 'y'], result)
diff --git a/taskflow/utils/reflection.py b/taskflow/utils/reflection.py
index a630aae..6621244 100644
--- a/taskflow/utils/reflection.py
+++ b/taskflow/utils/reflection.py
@@ -111,11 +111,18 @@ def _get_arg_spec(function):
return inspect.getargspec(function), bound
-def get_required_callable_args(function):
- """Get names of argument required by callable"""
+def get_callable_args(function, required_only=False):
+ """Get names of callable arguments
+
+ Special arguments (like *args and **kwargs) are not included into
+ output.
+
+ If required_only is True, optional arguments (with default values)
+ are not included into output.
+ """
argspec, bound = _get_arg_spec(function)
f_args = argspec.args
- if argspec.defaults:
+ if required_only and argspec.defaults:
f_args = f_args[:-len(argspec.defaults)]
if bound:
f_args = f_args[1:]