summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--taskflow/engines/action_engine/process_executor.py44
-rw-r--r--taskflow/tests/unit/test_engines.py42
-rw-r--r--taskflow/tests/utils.py10
3 files changed, 87 insertions, 9 deletions
diff --git a/taskflow/engines/action_engine/process_executor.py b/taskflow/engines/action_engine/process_executor.py
index 85d37e0..8699ed5 100644
--- a/taskflow/engines/action_engine/process_executor.py
+++ b/taskflow/engines/action_engine/process_executor.py
@@ -34,6 +34,7 @@ import six
from taskflow.engines.action_engine import executor as base
from taskflow import logging
from taskflow import task as ta
+from taskflow.types import notifier as nt
from taskflow.utils import iter_utils
from taskflow.utils import misc
from taskflow.utils import schema_utils as su
@@ -675,16 +676,38 @@ class ParallelProcessTaskExecutor(base.ParallelTaskExecutor):
# so that when the clone runs in another process that this task
# can receive the same notifications (thus making it look like the
# the notifications are transparently happening in this process).
- needed = set()
+ proxy_event_types = set()
for (event_type, listeners) in task.notifier.listeners_iter():
if listeners:
- needed.add(event_type)
+ proxy_event_types.add(event_type)
if progress_callback is not None:
- needed.add(ta.EVENT_UPDATE_PROGRESS)
- if needed:
+ proxy_event_types.add(ta.EVENT_UPDATE_PROGRESS)
+ if nt.Notifier.ANY in proxy_event_types:
+ # NOTE(harlowja): If ANY is present, just have it be
+ # the **only** event registered, as all other events will be
+ # sent if ANY is registered (due to the nature of ANY sending
+ # all the things); if we also include the other event types
+ # in this set if ANY is present we will receive duplicate
+ # messages in this process (the one where the local
+ # task callbacks are being triggered). For example the
+ # emissions of the tasks notifier (that is running out
+ # of process) will for specific events send messages for
+ # its ANY event type **and** the specific event
+ # type (2 messages, when we just want one) which will
+ # cause > 1 notify() call on the local tasks notifier, which
+ # causes more local callback triggering than we want
+ # to actually happen.
+ proxy_event_types = set([nt.Notifier.ANY])
+ if proxy_event_types:
+ # This sender acts as our forwarding proxy target, it
+ # will be sent pickled to the process that will execute
+ # the needed task and it will do the work of using the
+ # channel object to send back messages to this process for
+ # dispatch into the local task.
sender = EventSender(channel)
- for event_type in needed:
+ for event_type in proxy_event_types:
clone.notifier.register(event_type, sender)
+ return bool(proxy_event_types)
def register():
if progress_callback is not None:
@@ -698,14 +721,17 @@ class ParallelProcessTaskExecutor(base.ParallelTaskExecutor):
progress_callback)
self._dispatcher.targets.pop(identity, None)
- rebind_task()
- register()
+ should_register = rebind_task()
+ if should_register:
+ register()
try:
fut = self._executor.submit(func, clone, *args, **kwargs)
except RuntimeError:
with excutils.save_and_reraise_exception():
- deregister()
+ if should_register:
+ deregister()
fut.atom = task
- fut.add_done_callback(deregister)
+ if should_register:
+ fut.add_done_callback(deregister)
return fut
diff --git a/taskflow/tests/unit/test_engines.py b/taskflow/tests/unit/test_engines.py
index 6fd86f2..89a1a49 100644
--- a/taskflow/tests/unit/test_engines.py
+++ b/taskflow/tests/unit/test_engines.py
@@ -1527,6 +1527,48 @@ class ParallelEngineWithProcessTest(EngineTaskTest,
max_workers=self._EXECUTOR_WORKERS,
**kwargs)
+ def test_update_progress_notifications_proxied(self):
+ captured = collections.defaultdict(list)
+
+ def notify_me(event_type, details):
+ captured[event_type].append(details)
+
+ a = utils.MultiProgressingTask('a')
+ a.notifier.register(a.notifier.ANY, notify_me)
+ progress_chunks = list(x / 10.0 for x in range(1, 10))
+ e = self._make_engine(a, store={'progress_chunks': progress_chunks})
+ e.run()
+
+ self.assertEqual(11, len(captured[task.EVENT_UPDATE_PROGRESS]))
+
+ def test_custom_notifications_proxied(self):
+ captured = collections.defaultdict(list)
+
+ def notify_me(event_type, details):
+ captured[event_type].append(details)
+
+ a = utils.EmittingTask('a')
+ a.notifier.register(a.notifier.ANY, notify_me)
+ e = self._make_engine(a)
+ e.run()
+
+ self.assertEqual(1, len(captured['hi']))
+ self.assertEqual(2, len(captured[task.EVENT_UPDATE_PROGRESS]))
+
+ def test_just_custom_notifications_proxied(self):
+ captured = collections.defaultdict(list)
+
+ def notify_me(event_type, details):
+ captured[event_type].append(details)
+
+ a = utils.EmittingTask('a')
+ a.notifier.register('hi', notify_me)
+ e = self._make_engine(a)
+ e.run()
+
+ self.assertEqual(1, len(captured['hi']))
+ self.assertEqual(0, len(captured[task.EVENT_UPDATE_PROGRESS]))
+
class WorkerBasedEngineTest(EngineTaskTest,
EngineMultipleResultsTest,
diff --git a/taskflow/tests/utils.py b/taskflow/tests/utils.py
index 471da9b..58cd9ab 100644
--- a/taskflow/tests/utils.py
+++ b/taskflow/tests/utils.py
@@ -19,6 +19,7 @@ import string
import threading
import time
+from oslo_utils import timeutils
import redis
import six
@@ -104,6 +105,15 @@ class DummyTask(task.Task):
pass
+class EmittingTask(task.Task):
+ TASK_EVENTS = (task.EVENT_UPDATE_PROGRESS, 'hi')
+
+ def execute(self, *args, **kwargs):
+ self.notifier.notify('hi',
+ details={'sent_on': timeutils.utcnow(),
+ 'args': args, 'kwargs': kwargs})
+
+
class AddOneSameProvidesRequires(task.Task):
default_provides = 'value'