summaryrefslogtreecommitdiff
path: root/taskflow/utils/threading_utils.py
diff options
context:
space:
mode:
authorJoshua Harlow <harlowja@gmail.com>2015-01-24 00:45:36 -0800
committerJoshua Harlow <harlowja@gmail.com>2015-01-24 18:33:51 -0800
commitca82e20efe8f5c5d50b3db89be0342710ef7f73b (patch)
treee4682847164001d229d70f1ecc94739fd80958af /taskflow/utils/threading_utils.py
parent1ae7a8e67b79f1ea7533525ef27271978365afe9 (diff)
downloadtaskflow-ca82e20efe8f5c5d50b3db89be0342710ef7f73b.tar.gz
Add a thread bundle helper utility + tests
To make it easier to create a bunch of threads in a single call (and stop them in a single call) create a concept of a thread bundle (similar to a thread group) that will call into a provided set of factories to get a thread, activate callbacks to notify others that a thread is about to start or stop and then perform the start or stop of the bound threads in a orderly manner. Change-Id: I7d233cccb230b716af41243ad27220b988eec14c
Diffstat (limited to 'taskflow/utils/threading_utils.py')
-rw-r--r--taskflow/utils/threading_utils.py104
1 files changed, 104 insertions, 0 deletions
diff --git a/taskflow/utils/threading_utils.py b/taskflow/utils/threading_utils.py
index 5048401..cea0760 100644
--- a/taskflow/utils/threading_utils.py
+++ b/taskflow/utils/threading_utils.py
@@ -14,10 +14,12 @@
# License for the specific language governing permissions and limitations
# under the License.
+import collections
import multiprocessing
import sys
import threading
+import six
from six.moves import _thread
@@ -71,3 +73,105 @@ def daemon_thread(target, *args, **kwargs):
# unless the daemon property is set to True.
thread.daemon = True
return thread
+
+
+# Container for thread creator + associated callbacks.
+_ThreadBuilder = collections.namedtuple('_ThreadBuilder',
+ ['thread_factory',
+ 'before_start', 'after_start',
+ 'before_join', 'after_join'])
+_ThreadBuilder.callables = tuple([
+ # Attribute name -> none allowed as a valid value...
+ ('thread_factory', False),
+ ('before_start', True),
+ ('after_start', True),
+ ('before_join', True),
+ ('after_join', True),
+])
+
+
+class ThreadBundle(object):
+ """A group/bundle of threads that start/stop together."""
+
+ def __init__(self):
+ self._threads = []
+ self._lock = threading.Lock()
+
+ def bind(self, thread_factory,
+ before_start=None, after_start=None,
+ before_join=None, after_join=None):
+ """Adds a thread (to-be) into this bundle (with given callbacks).
+
+ NOTE(harlowja): callbacks provided should not attempt to call
+ mutating methods (:meth:`.stop`, :meth:`.start`,
+ :meth:`.bind` ...) on this object as that will result
+ in dead-lock since the lock on this object is not
+ meant to be (and is not) reentrant...
+ """
+ builder = _ThreadBuilder(thread_factory,
+ before_start, after_start,
+ before_join, after_join)
+ for attr_name, none_allowed in builder.callables:
+ cb = getattr(builder, attr_name)
+ if cb is None and none_allowed:
+ continue
+ if not six.callable(cb):
+ raise ValueError("Provided callback for argument"
+ " '%s' must be callable" % attr_name)
+ with self._lock:
+ self._threads.append([
+ builder,
+ # The built thread.
+ None,
+ # Whether the built thread was started (and should have
+ # ran or still be running).
+ False,
+ ])
+
+ @staticmethod
+ def _trigger_callback(callback, thread):
+ if callback is not None:
+ callback(thread)
+
+ def start(self):
+ """Creates & starts all associated threads (that are not running)."""
+ count = 0
+ with self._lock:
+ for i, (builder, thread, started) in enumerate(self._threads):
+ if thread and started:
+ continue
+ if not thread:
+ self._threads[i][1] = thread = builder.thread_factory()
+ self._trigger_callback(builder.before_start, thread)
+ thread.start()
+ count += 1
+ try:
+ self._trigger_callback(builder.after_start, thread)
+ finally:
+ # Just incase the 'after_start' callback blows up make sure
+ # we always set this...
+ self._threads[i][2] = started = True
+ return count
+
+ def stop(self):
+ """Stops & joins all associated threads (that have been started)."""
+ count = 0
+ with self._lock:
+ for i, (builder, thread, started) in enumerate(self._threads):
+ if not thread or not started:
+ continue
+ self._trigger_callback(builder.before_join, thread)
+ thread.join()
+ count += 1
+ try:
+ self._trigger_callback(builder.after_join, thread)
+ finally:
+ # Just incase the 'after_join' callback blows up make sure
+ # we always set/reset these...
+ self._threads[i][1] = thread = None
+ self._threads[i][2] = started = False
+ return count
+
+ def __len__(self):
+ """Returns how many threads (to-be) are in this bundle."""
+ return len(self._threads)