diff options
-rw-r--r-- | taskflow/engines/action_engine/process_executor.py | 44 | ||||
-rw-r--r-- | taskflow/tests/unit/test_engines.py | 42 | ||||
-rw-r--r-- | taskflow/tests/utils.py | 10 |
3 files changed, 87 insertions, 9 deletions
diff --git a/taskflow/engines/action_engine/process_executor.py b/taskflow/engines/action_engine/process_executor.py index 85d37e0..8699ed5 100644 --- a/taskflow/engines/action_engine/process_executor.py +++ b/taskflow/engines/action_engine/process_executor.py @@ -34,6 +34,7 @@ import six from taskflow.engines.action_engine import executor as base from taskflow import logging from taskflow import task as ta +from taskflow.types import notifier as nt from taskflow.utils import iter_utils from taskflow.utils import misc from taskflow.utils import schema_utils as su @@ -675,16 +676,38 @@ class ParallelProcessTaskExecutor(base.ParallelTaskExecutor): # so that when the clone runs in another process that this task # can receive the same notifications (thus making it look like the # the notifications are transparently happening in this process). - needed = set() + proxy_event_types = set() for (event_type, listeners) in task.notifier.listeners_iter(): if listeners: - needed.add(event_type) + proxy_event_types.add(event_type) if progress_callback is not None: - needed.add(ta.EVENT_UPDATE_PROGRESS) - if needed: + proxy_event_types.add(ta.EVENT_UPDATE_PROGRESS) + if nt.Notifier.ANY in proxy_event_types: + # NOTE(harlowja): If ANY is present, just have it be + # the **only** event registered, as all other events will be + # sent if ANY is registered (due to the nature of ANY sending + # all the things); if we also include the other event types + # in this set if ANY is present we will receive duplicate + # messages in this process (the one where the local + # task callbacks are being triggered). For example the + # emissions of the tasks notifier (that is running out + # of process) will for specific events send messages for + # its ANY event type **and** the specific event + # type (2 messages, when we just want one) which will + # cause > 1 notify() call on the local tasks notifier, which + # causes more local callback triggering than we want + # to actually happen. + proxy_event_types = set([nt.Notifier.ANY]) + if proxy_event_types: + # This sender acts as our forwarding proxy target, it + # will be sent pickled to the process that will execute + # the needed task and it will do the work of using the + # channel object to send back messages to this process for + # dispatch into the local task. sender = EventSender(channel) - for event_type in needed: + for event_type in proxy_event_types: clone.notifier.register(event_type, sender) + return bool(proxy_event_types) def register(): if progress_callback is not None: @@ -698,14 +721,17 @@ class ParallelProcessTaskExecutor(base.ParallelTaskExecutor): progress_callback) self._dispatcher.targets.pop(identity, None) - rebind_task() - register() + should_register = rebind_task() + if should_register: + register() try: fut = self._executor.submit(func, clone, *args, **kwargs) except RuntimeError: with excutils.save_and_reraise_exception(): - deregister() + if should_register: + deregister() fut.atom = task - fut.add_done_callback(deregister) + if should_register: + fut.add_done_callback(deregister) return fut diff --git a/taskflow/tests/unit/test_engines.py b/taskflow/tests/unit/test_engines.py index 6fd86f2..89a1a49 100644 --- a/taskflow/tests/unit/test_engines.py +++ b/taskflow/tests/unit/test_engines.py @@ -1527,6 +1527,48 @@ class ParallelEngineWithProcessTest(EngineTaskTest, max_workers=self._EXECUTOR_WORKERS, **kwargs) + def test_update_progress_notifications_proxied(self): + captured = collections.defaultdict(list) + + def notify_me(event_type, details): + captured[event_type].append(details) + + a = utils.MultiProgressingTask('a') + a.notifier.register(a.notifier.ANY, notify_me) + progress_chunks = list(x / 10.0 for x in range(1, 10)) + e = self._make_engine(a, store={'progress_chunks': progress_chunks}) + e.run() + + self.assertEqual(11, len(captured[task.EVENT_UPDATE_PROGRESS])) + + def test_custom_notifications_proxied(self): + captured = collections.defaultdict(list) + + def notify_me(event_type, details): + captured[event_type].append(details) + + a = utils.EmittingTask('a') + a.notifier.register(a.notifier.ANY, notify_me) + e = self._make_engine(a) + e.run() + + self.assertEqual(1, len(captured['hi'])) + self.assertEqual(2, len(captured[task.EVENT_UPDATE_PROGRESS])) + + def test_just_custom_notifications_proxied(self): + captured = collections.defaultdict(list) + + def notify_me(event_type, details): + captured[event_type].append(details) + + a = utils.EmittingTask('a') + a.notifier.register('hi', notify_me) + e = self._make_engine(a) + e.run() + + self.assertEqual(1, len(captured['hi'])) + self.assertEqual(0, len(captured[task.EVENT_UPDATE_PROGRESS])) + class WorkerBasedEngineTest(EngineTaskTest, EngineMultipleResultsTest, diff --git a/taskflow/tests/utils.py b/taskflow/tests/utils.py index 471da9b..58cd9ab 100644 --- a/taskflow/tests/utils.py +++ b/taskflow/tests/utils.py @@ -19,6 +19,7 @@ import string import threading import time +from oslo_utils import timeutils import redis import six @@ -104,6 +105,15 @@ class DummyTask(task.Task): pass +class EmittingTask(task.Task): + TASK_EVENTS = (task.EVENT_UPDATE_PROGRESS, 'hi') + + def execute(self, *args, **kwargs): + self.notifier.notify('hi', + details={'sent_on': timeutils.utcnow(), + 'args': args, 'kwargs': kwargs}) + + class AddOneSameProvidesRequires(task.Task): default_provides = 'value' |