diff options
author | Jenkins <jenkins@review.openstack.org> | 2013-10-22 21:40:17 +0000 |
---|---|---|
committer | Gerrit Code Review <review@openstack.org> | 2013-10-22 21:40:17 +0000 |
commit | a80479b340a4fd6e557aa9547707c9867ff72526 (patch) | |
tree | a58329db603b6cabe76e58168df6a857a83edd55 | |
parent | 2532be0e893af523f0d09d033a44b85a17522165 (diff) | |
parent | 7a8aa34c86530e34e7ed7d6d87344af2c2db0489 (diff) | |
download | taskflow-a80479b340a4fd6e557aa9547707c9867ff72526.tar.gz |
Merge "Support for optional task arguments"0.1
-rw-r--r-- | taskflow/task.py | 31 | ||||
-rw-r--r-- | taskflow/tests/unit/test_task.py | 13 | ||||
-rw-r--r-- | taskflow/tests/unit/test_utils.py | 23 | ||||
-rw-r--r-- | taskflow/utils/reflection.py | 13 |
4 files changed, 52 insertions, 28 deletions
diff --git a/taskflow/task.py b/taskflow/task.py index 5925fde..feed142 100644 --- a/taskflow/task.py +++ b/taskflow/task.py @@ -78,34 +78,33 @@ def _build_rebind_dict(args, rebind_args): raise TypeError('Invalid rebind value: %s' % rebind_args) -def _check_args_mapping(task_name, rebind, args, accepts_kwargs): - args = set(args) - rebind = set(rebind.keys()) - extra_args = rebind - args - missing_args = args - rebind - if not accepts_kwargs and extra_args: - raise ValueError('Extra arguments given to task %s: %s' - % (task_name, sorted(extra_args))) - if missing_args: - raise ValueError('Missing arguments for task %s: %s' - % (task_name, sorted(missing_args))) - - def _build_arg_mapping(task_name, reqs, rebind_args, function, do_infer): """Given a function, its requirements and a rebind mapping this helper function will build the correct argument mapping for the given function as well as verify that the final argument mapping does not have missing or extra arguments (where applicable). """ - task_args = reflection.get_required_callable_args(function) - accepts_kwargs = reflection.accepts_kwargs(function) + task_args = reflection.get_callable_args(function, required_only=True) result = {} if reqs: result.update((a, a) for a in reqs) if do_infer: result.update((a, a) for a in task_args) result.update(_build_rebind_dict(task_args, rebind_args)) - _check_args_mapping(task_name, result, task_args, accepts_kwargs) + + if not reflection.accepts_kwargs(function): + all_args = reflection.get_callable_args(function, required_only=False) + extra_args = set(result) - set(all_args) + if extra_args: + extra_args_str = ', '.join(sorted(extra_args)) + raise ValueError('Extra arguments given to task %s: %s' + % (task_name, extra_args_str)) + + # NOTE(imelnikov): don't use set to preserve order in error message + missing_args = [arg for arg in task_args if arg not in result] + if missing_args: + raise ValueError('Missing arguments for task %s: %s' + % (task_name, ' ,'.join(missing_args))) return result diff --git a/taskflow/tests/unit/test_task.py b/taskflow/tests/unit/test_task.py index e661aa4..e7b4365 100644 --- a/taskflow/tests/unit/test_task.py +++ b/taskflow/tests/unit/test_task.py @@ -31,6 +31,11 @@ class KwargsTask(task.Task): pass +class DefaultArgTask(task.Task): + def execute(self, spam, eggs=()): + pass + + class DefaultProvidesTask(task.Task): default_provides = 'def' @@ -98,6 +103,14 @@ class TaskTestCase(test.TestCase): with self.assertRaisesRegexp(ValueError, '^Missing arguments'): MyTask(auto_extract=False, requires=('spam', 'eggs')) + def test_requires_ignores_optional(self): + my_task = DefaultArgTask() + self.assertEquals(my_task.requires, set(['spam'])) + + def test_requires_allows_optional(self): + my_task = DefaultArgTask(requires=('spam', 'eggs')) + self.assertEquals(my_task.requires, set(['spam', 'eggs'])) + def test_rebind_all_args(self): my_task = MyTask(rebind={'spam': 'a', 'eggs': 'b', 'context': 'c'}) self.assertEquals(my_task.rebind, { diff --git a/taskflow/tests/unit/test_utils.py b/taskflow/tests/unit/test_utils.py index ef88709..816d322 100644 --- a/taskflow/tests/unit/test_utils.py +++ b/taskflow/tests/unit/test_utils.py @@ -99,41 +99,46 @@ class GetCallableNameTest(test.TestCase): '__call__'))) -class GetRequiredCallableArgsTest(test.TestCase): +class GetCallableArgsTest(test.TestCase): def test_mere_function(self): - result = reflection.get_required_callable_args(mere_function) + result = reflection.get_callable_args(mere_function) self.assertEquals(['a', 'b'], result) def test_function_with_defaults(self): - result = reflection.get_required_callable_args(function_with_defs) + result = reflection.get_callable_args(function_with_defs) + self.assertEquals(['a', 'b', 'optional'], result) + + def test_required_only(self): + result = reflection.get_callable_args(function_with_defs, + required_only=True) self.assertEquals(['a', 'b'], result) def test_method(self): - result = reflection.get_required_callable_args(Class.method) + result = reflection.get_callable_args(Class.method) self.assertEquals(['self', 'c', 'd'], result) def test_instance_method(self): - result = reflection.get_required_callable_args(Class().method) + result = reflection.get_callable_args(Class().method) self.assertEquals(['c', 'd'], result) def test_class_method(self): - result = reflection.get_required_callable_args(Class.class_method) + result = reflection.get_callable_args(Class.class_method) self.assertEquals(['g', 'h'], result) def test_class_constructor(self): - result = reflection.get_required_callable_args(ClassWithInit) + result = reflection.get_callable_args(ClassWithInit) self.assertEquals(['k', 'l'], result) def test_class_with_call(self): - result = reflection.get_required_callable_args(CallableClass()) + result = reflection.get_callable_args(CallableClass()) self.assertEquals(['i', 'j'], result) def test_decorators_work(self): @lock_utils.locked def locked_fun(x, y): pass - result = reflection.get_required_callable_args(locked_fun) + result = reflection.get_callable_args(locked_fun) self.assertEquals(['x', 'y'], result) diff --git a/taskflow/utils/reflection.py b/taskflow/utils/reflection.py index a630aae..6621244 100644 --- a/taskflow/utils/reflection.py +++ b/taskflow/utils/reflection.py @@ -111,11 +111,18 @@ def _get_arg_spec(function): return inspect.getargspec(function), bound -def get_required_callable_args(function): - """Get names of argument required by callable""" +def get_callable_args(function, required_only=False): + """Get names of callable arguments + + Special arguments (like *args and **kwargs) are not included into + output. + + If required_only is True, optional arguments (with default values) + are not included into output. + """ argspec, bound = _get_arg_spec(function) f_args = argspec.args - if argspec.defaults: + if required_only and argspec.defaults: f_args = f_args[:-len(argspec.defaults)] if bound: f_args = f_args[1:] |