diff options
| author | Anastasia Karpinska <akarpinska@griddynamics.com> | 2013-10-17 13:16:06 +0300 |
|---|---|---|
| committer | Anastasia Karpinska <akarpinska@griddynamics.com> | 2013-10-17 15:35:18 +0300 |
| commit | e7508eb8e53217d889fe8af24a92cdd0dbb28068 (patch) | |
| tree | 332dc2dca36610daa1a00971c802591659ad7643 /taskflow/tests/utils.py | |
| parent | 176d0a920effa3c4052271de8cedc3b0228a97fc (diff) | |
| download | taskflow-e7508eb8e53217d889fe8af24a92cdd0dbb28068.tar.gz | |
Unit tests refactoring
* duplicated tests were removed
* common tasks moved to utils
Change-Id: I69c91a264ec668b1333db8fd907298262af098cb
Diffstat (limited to 'taskflow/tests/utils.py')
| -rw-r--r-- | taskflow/tests/utils.py | 179 |
1 files changed, 163 insertions, 16 deletions
diff --git a/taskflow/tests/utils.py b/taskflow/tests/utils.py index db31a2a..81107ae 100644 --- a/taskflow/tests/utils.py +++ b/taskflow/tests/utils.py @@ -16,8 +16,11 @@ # License for the specific language governing permissions and limitations # under the License. +import contextlib import six +import time +from taskflow.persistence.backends import impl_memory from taskflow import task ARGS_KEY = '__args__' @@ -45,6 +48,19 @@ def make_reverting_task(token, blowup=False): name='do_apply_%s' % token) +class DummyTask(task.Task): + def execute(self, context, *args, **kwargs): + pass + + +if six.PY3: + RUNTIME_ERROR_CLASSES = ['RuntimeError', 'Exception', + 'BaseException', 'object'] +else: + RUNTIME_ERROR_CLASSES = ['RuntimeError', 'StandardError', 'Exception', + 'BaseException', 'object'] + + class ProvidesRequiresTask(task.Task): def __init__(self, name, provides, requires, return_tuple=True): super(ProvidesRequiresTask, self).__init__(name=name, @@ -52,28 +68,159 @@ class ProvidesRequiresTask(task.Task): requires=requires) self.return_tuple = isinstance(provides, (tuple, list)) - def execute(self, context, *args, **kwargs): - if ORDER_KEY not in context: - context[ORDER_KEY] = [] - context[ORDER_KEY].append({ - 'name': self.name, - KWARGS_KEY: kwargs, - ARGS_KEY: args, - }) + def execute(self, *args, **kwargs): if self.return_tuple: return tuple(range(len(self.provides))) else: return dict((k, k) for k in self.provides) -class DummyTask(task.Task): - def execute(self, context, *args, **kwargs): +class SaveOrderTask(task.Task): + + def __init__(self, values=None, name=None, sleep=None, + *args, **kwargs): + super(SaveOrderTask, self).__init__(name=name, *args, **kwargs) + if values is None: + self.values = [] + else: + self.values = values + self._sleep = sleep + + def execute(self, **kwargs): + self.update_progress(0.0) + if self._sleep: + time.sleep(self._sleep) + self.values.append(self.name) + self.update_progress(1.0) + return 5 + + def revert(self, **kwargs): + self.update_progress(0) + if self._sleep: + time.sleep(self._sleep) + self.values.append(self.name + ' reverted(%s)' + % kwargs.get('result')) + self.update_progress(1.0) + + +class FailingTask(SaveOrderTask): + + def execute(self, **kwargs): + self.update_progress(0) + if self._sleep: + time.sleep(self._sleep) + self.update_progress(0.99) + raise RuntimeError('Woot!') + + +class NastyTask(task.Task): + def execute(self, **kwargs): pass + def revert(self, **kwargs): + raise RuntimeError('Gotcha!') -if six.PY3: - RUNTIME_ERROR_CLASSES = ['RuntimeError', 'Exception', - 'BaseException', 'object'] -else: - RUNTIME_ERROR_CLASSES = ['RuntimeError', 'StandardError', 'Exception', - 'BaseException', 'object'] + +class TaskNoRequiresNoReturns(task.Task): + + def execute(self, **kwargs): + pass + + def revert(self, **kwargs): + pass + + +class TaskOneArg(task.Task): + + def execute(self, x, **kwargs): + pass + + def revert(self, x, **kwargs): + pass + + +class TaskMultiArg(task.Task): + + def execute(self, x, y, z, **kwargs): + pass + + def revert(self, x, y, z, **kwargs): + pass + + +class TaskOneReturn(task.Task): + + def execute(self, **kwargs): + return 1 + + def revert(self, **kwargs): + pass + + +class TaskMultiReturn(task.Task): + + def execute(self, **kwargs): + return 1, 3, 5 + + def revert(self, **kwargs): + pass + + +class TaskOneArgOneReturn(task.Task): + + def execute(self, x, **kwargs): + return 1 + + def revert(self, x, **kwargs): + pass + + +class TaskMultiArgOneReturn(task.Task): + + def execute(self, x, y, z, **kwargs): + return x + y + z + + def revert(self, x, y, z, **kwargs): + pass + + +class TaskMultiArgMultiReturn(task.Task): + + def execute(self, x, y, z, **kwargs): + return 1, 3, 5 + + def revert(self, x, y, z, **kwargs): + pass + + +class TaskMultiDictk(task.Task): + + def execute(self): + output = {} + for i, k in enumerate(sorted(self.provides)): + output[k] = i + return output + + +class NeverRunningTask(task.Task): + def execute(self, **kwargs): + assert False, 'This method should not be called' + + def revert(self, **kwargs): + assert False, 'This method should not be called' + + +class EngineTestBase(object): + def setUp(self): + super(EngineTestBase, self).setUp() + self.values = [] + self.backend = impl_memory.MemoryBackend(conf={}) + + def tearDown(self): + super(EngineTestBase, self).tearDown() + with contextlib.closing(self.backend) as be: + with contextlib.closing(be.get_connection()) as conn: + conn.clear_all() + + def _make_engine(self, flow, flow_detail=None): + raise NotImplementedError() |
