summaryrefslogtreecommitdiff
path: root/taskflow/tests/utils.py
diff options
context:
space:
mode:
authorAnastasia Karpinska <akarpinska@griddynamics.com>2013-10-17 13:16:06 +0300
committerAnastasia Karpinska <akarpinska@griddynamics.com>2013-10-17 15:35:18 +0300
commite7508eb8e53217d889fe8af24a92cdd0dbb28068 (patch)
tree332dc2dca36610daa1a00971c802591659ad7643 /taskflow/tests/utils.py
parent176d0a920effa3c4052271de8cedc3b0228a97fc (diff)
downloadtaskflow-e7508eb8e53217d889fe8af24a92cdd0dbb28068.tar.gz
Unit tests refactoring
* duplicated tests were removed * common tasks moved to utils Change-Id: I69c91a264ec668b1333db8fd907298262af098cb
Diffstat (limited to 'taskflow/tests/utils.py')
-rw-r--r--taskflow/tests/utils.py179
1 files changed, 163 insertions, 16 deletions
diff --git a/taskflow/tests/utils.py b/taskflow/tests/utils.py
index db31a2a..81107ae 100644
--- a/taskflow/tests/utils.py
+++ b/taskflow/tests/utils.py
@@ -16,8 +16,11 @@
# License for the specific language governing permissions and limitations
# under the License.
+import contextlib
import six
+import time
+from taskflow.persistence.backends import impl_memory
from taskflow import task
ARGS_KEY = '__args__'
@@ -45,6 +48,19 @@ def make_reverting_task(token, blowup=False):
name='do_apply_%s' % token)
+class DummyTask(task.Task):
+ def execute(self, context, *args, **kwargs):
+ pass
+
+
+if six.PY3:
+ RUNTIME_ERROR_CLASSES = ['RuntimeError', 'Exception',
+ 'BaseException', 'object']
+else:
+ RUNTIME_ERROR_CLASSES = ['RuntimeError', 'StandardError', 'Exception',
+ 'BaseException', 'object']
+
+
class ProvidesRequiresTask(task.Task):
def __init__(self, name, provides, requires, return_tuple=True):
super(ProvidesRequiresTask, self).__init__(name=name,
@@ -52,28 +68,159 @@ class ProvidesRequiresTask(task.Task):
requires=requires)
self.return_tuple = isinstance(provides, (tuple, list))
- def execute(self, context, *args, **kwargs):
- if ORDER_KEY not in context:
- context[ORDER_KEY] = []
- context[ORDER_KEY].append({
- 'name': self.name,
- KWARGS_KEY: kwargs,
- ARGS_KEY: args,
- })
+ def execute(self, *args, **kwargs):
if self.return_tuple:
return tuple(range(len(self.provides)))
else:
return dict((k, k) for k in self.provides)
-class DummyTask(task.Task):
- def execute(self, context, *args, **kwargs):
+class SaveOrderTask(task.Task):
+
+ def __init__(self, values=None, name=None, sleep=None,
+ *args, **kwargs):
+ super(SaveOrderTask, self).__init__(name=name, *args, **kwargs)
+ if values is None:
+ self.values = []
+ else:
+ self.values = values
+ self._sleep = sleep
+
+ def execute(self, **kwargs):
+ self.update_progress(0.0)
+ if self._sleep:
+ time.sleep(self._sleep)
+ self.values.append(self.name)
+ self.update_progress(1.0)
+ return 5
+
+ def revert(self, **kwargs):
+ self.update_progress(0)
+ if self._sleep:
+ time.sleep(self._sleep)
+ self.values.append(self.name + ' reverted(%s)'
+ % kwargs.get('result'))
+ self.update_progress(1.0)
+
+
+class FailingTask(SaveOrderTask):
+
+ def execute(self, **kwargs):
+ self.update_progress(0)
+ if self._sleep:
+ time.sleep(self._sleep)
+ self.update_progress(0.99)
+ raise RuntimeError('Woot!')
+
+
+class NastyTask(task.Task):
+ def execute(self, **kwargs):
pass
+ def revert(self, **kwargs):
+ raise RuntimeError('Gotcha!')
-if six.PY3:
- RUNTIME_ERROR_CLASSES = ['RuntimeError', 'Exception',
- 'BaseException', 'object']
-else:
- RUNTIME_ERROR_CLASSES = ['RuntimeError', 'StandardError', 'Exception',
- 'BaseException', 'object']
+
+class TaskNoRequiresNoReturns(task.Task):
+
+ def execute(self, **kwargs):
+ pass
+
+ def revert(self, **kwargs):
+ pass
+
+
+class TaskOneArg(task.Task):
+
+ def execute(self, x, **kwargs):
+ pass
+
+ def revert(self, x, **kwargs):
+ pass
+
+
+class TaskMultiArg(task.Task):
+
+ def execute(self, x, y, z, **kwargs):
+ pass
+
+ def revert(self, x, y, z, **kwargs):
+ pass
+
+
+class TaskOneReturn(task.Task):
+
+ def execute(self, **kwargs):
+ return 1
+
+ def revert(self, **kwargs):
+ pass
+
+
+class TaskMultiReturn(task.Task):
+
+ def execute(self, **kwargs):
+ return 1, 3, 5
+
+ def revert(self, **kwargs):
+ pass
+
+
+class TaskOneArgOneReturn(task.Task):
+
+ def execute(self, x, **kwargs):
+ return 1
+
+ def revert(self, x, **kwargs):
+ pass
+
+
+class TaskMultiArgOneReturn(task.Task):
+
+ def execute(self, x, y, z, **kwargs):
+ return x + y + z
+
+ def revert(self, x, y, z, **kwargs):
+ pass
+
+
+class TaskMultiArgMultiReturn(task.Task):
+
+ def execute(self, x, y, z, **kwargs):
+ return 1, 3, 5
+
+ def revert(self, x, y, z, **kwargs):
+ pass
+
+
+class TaskMultiDictk(task.Task):
+
+ def execute(self):
+ output = {}
+ for i, k in enumerate(sorted(self.provides)):
+ output[k] = i
+ return output
+
+
+class NeverRunningTask(task.Task):
+ def execute(self, **kwargs):
+ assert False, 'This method should not be called'
+
+ def revert(self, **kwargs):
+ assert False, 'This method should not be called'
+
+
+class EngineTestBase(object):
+ def setUp(self):
+ super(EngineTestBase, self).setUp()
+ self.values = []
+ self.backend = impl_memory.MemoryBackend(conf={})
+
+ def tearDown(self):
+ super(EngineTestBase, self).tearDown()
+ with contextlib.closing(self.backend) as be:
+ with contextlib.closing(be.get_connection()) as conn:
+ conn.clear_all()
+
+ def _make_engine(self, flow, flow_detail=None):
+ raise NotImplementedError()