summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthew Booth <mbooth@redhat.com>2015-10-19 14:11:23 +0100
committerMatthew Booth <mbooth@redhat.com>2015-10-23 16:15:06 +0100
commitd700c382791b6352bb80a0dc455589085881669f (patch)
tree254c1649008c6867e7714f992cb995328b545751
parentb93d20854393fcc660a32760fcf1af7a45e10225 (diff)
downloadoslo-messaging-d700c382791b6352bb80a0dc455589085881669f.tar.gz
Robustify locking in MessageHandlingServer
This change formalises locking in MessageHandlingServer. It allows the user to make calls in any order and it will ensure, with locking, that these will be reordered appropriately. It also adds locking for internal state when using the blocking executor, which closes a number of races. It fixes a regression introduced in change gI3cfbe1bf02d451e379b1dcc23dacb0139c03be76. If multiple threads called wait() simultaneously, only 1 of them would wait and the others would return immediately, despite message handling not having completed. With this change only 1 will call the underlying wait, but all will wait on its completion. We add a common logging mechanism when waiting too long. Specifically, we now log a single message when waiting on any lock for longer than 30 seconds. We remove DummyCondition as it no longer has any users. Change-Id: I9d516b208446963dcd80b75e2d5a2cecb1187efa
-rw-r--r--oslo_messaging/_utils.py23
-rw-r--r--oslo_messaging/server.py260
-rw-r--r--oslo_messaging/tests/rpc/test_server.py210
3 files changed, 402 insertions, 91 deletions
diff --git a/oslo_messaging/_utils.py b/oslo_messaging/_utils.py
index 1bb20b0..cec94bb 100644
--- a/oslo_messaging/_utils.py
+++ b/oslo_messaging/_utils.py
@@ -116,29 +116,6 @@ def fetch_current_thread_functor():
return lambda: threading.current_thread()
-class DummyCondition(object):
- def acquire(self):
- pass
-
- def notify(self):
- pass
-
- def notify_all(self):
- pass
-
- def wait(self, timeout=None):
- pass
-
- def release(self):
- pass
-
- def __enter__(self):
- self.acquire()
-
- def __exit__(self, type, value, traceback):
- self.release()
-
-
class DummyLock(object):
def acquire(self):
pass
diff --git a/oslo_messaging/server.py b/oslo_messaging/server.py
index 491ccbf..f1739ad 100644
--- a/oslo_messaging/server.py
+++ b/oslo_messaging/server.py
@@ -23,16 +23,17 @@ __all__ = [
'ServerListenError',
]
+import functools
+import inspect
import logging
import threading
+import traceback
from oslo_service import service
from oslo_utils import timeutils
from stevedore import driver
from oslo_messaging._drivers import base as driver_base
-from oslo_messaging._i18n import _LW
-from oslo_messaging import _utils
from oslo_messaging import exceptions
LOG = logging.getLogger(__name__)
@@ -62,7 +63,170 @@ class ServerListenError(MessagingServerError):
self.ex = ex
-class MessageHandlingServer(service.ServiceBase):
+class _OrderedTask(object):
+ """A task which must be executed in a particular order.
+
+ A caller may wait for this task to complete by calling
+ `wait_for_completion`.
+
+ A caller may run this task with `run_once`, which will ensure that however
+ many times the task is called it only runs once. Simultaneous callers will
+ block until the running task completes, which means that any caller can be
+ sure that the task has completed after run_once returns.
+ """
+
+ INIT = 0 # The task has not yet started
+ RUNNING = 1 # The task is running somewhere
+ COMPLETE = 2 # The task has run somewhere
+
+ # We generate a log message if we wait for a lock longer than
+ # LOG_AFTER_WAIT_SECS seconds
+ LOG_AFTER_WAIT_SECS = 30
+
+ def __init__(self, name):
+ """Create a new _OrderedTask.
+
+ :param name: The name of this task. Used in log messages.
+ """
+
+ super(_OrderedTask, self).__init__()
+
+ self._name = name
+ self._cond = threading.Condition()
+ self._state = self.INIT
+
+ def _wait(self, condition, warn_msg):
+ """Wait while condition() is true. Write a log message if condition()
+ has not become false within LOG_AFTER_WAIT_SECS.
+ """
+ with timeutils.StopWatch(duration=self.LOG_AFTER_WAIT_SECS) as w:
+ logged = False
+ while condition():
+ wait = None if logged else w.leftover()
+ self._cond.wait(wait)
+
+ if not logged and w.expired():
+ LOG.warn(warn_msg)
+ LOG.debug(''.join(traceback.format_stack()))
+ # Only log once. After than we wait indefinitely without
+ # logging.
+ logged = True
+
+ def wait_for_completion(self, caller):
+ """Wait until this task has completed.
+
+ :param caller: The name of the task which is waiting.
+ """
+ with self._cond:
+ self._wait(lambda: self._state != self.COMPLETE,
+ '%s has been waiting for %s to complete for longer '
+ 'than %i seconds'
+ % (caller, self._name, self.LOG_AFTER_WAIT_SECS))
+
+ def run_once(self, fn):
+ """Run a task exactly once. If it is currently running in another
+ thread, wait for it to complete. If it has already run, return
+ immediately without running it again.
+
+ :param fn: The task to run. It must be a callable taking no arguments.
+ It may optionally return another callable, which also takes
+ no arguments, which will be executed after completion has
+ been signaled to other threads.
+ """
+ with self._cond:
+ if self._state == self.INIT:
+ self._state = self.RUNNING
+ # Note that nothing waits on RUNNING, so no need to notify
+
+ # We need to release the condition lock before calling out to
+ # prevent deadlocks. Reacquire it immediately afterwards.
+ self._cond.release()
+ try:
+ post_fn = fn()
+ finally:
+ self._cond.acquire()
+ self._state = self.COMPLETE
+ self._cond.notify_all()
+
+ if post_fn is not None:
+ # Release the condition lock before calling out to prevent
+ # deadlocks. Reacquire it immediately afterwards.
+ self._cond.release()
+ try:
+ post_fn()
+ finally:
+ self._cond.acquire()
+ elif self._state == self.RUNNING:
+ self._wait(lambda: self._state == self.RUNNING,
+ '%s has been waiting on another thread to complete '
+ 'for longer than %i seconds'
+ % (self._name, self.LOG_AFTER_WAIT_SECS))
+
+
+class _OrderedTaskRunner(object):
+ """Mixin for a class which executes ordered tasks."""
+
+ def __init__(self, *args, **kwargs):
+ super(_OrderedTaskRunner, self).__init__(*args, **kwargs)
+
+ # Get a list of methods on this object which have the _ordered
+ # attribute
+ self._tasks = [name
+ for (name, member) in inspect.getmembers(self)
+ if inspect.ismethod(member) and
+ getattr(member, '_ordered', False)]
+ self.init_task_states()
+
+ def init_task_states(self):
+ # Note that we don't need to lock this. Once created, the _states dict
+ # is immutable. Get and set are (individually) atomic operations in
+ # Python, and we only set after the dict is fully created.
+ self._states = {task: _OrderedTask(task) for task in self._tasks}
+
+ @staticmethod
+ def decorate_ordered(fn, state, after):
+
+ @functools.wraps(fn)
+ def wrapper(self, *args, **kwargs):
+ # Store the states we started with in case the state wraps on us
+ # while we're sleeping. We must wait and run_once in the same
+ # epoch. If the epoch ended while we were sleeping, run_once will
+ # safely do nothing.
+ states = self._states
+
+ # Wait for the given preceding state to complete
+ if after is not None:
+ states[after].wait_for_completion(state)
+
+ # Run this state
+ states[state].run_once(lambda: fn(self, *args, **kwargs))
+ return wrapper
+
+
+def ordered(after=None):
+ """A method which will be executed as an ordered task. The method will be
+ called exactly once, however many times it is called. If it is called
+ multiple times simultaneously it will only be called once, but all callers
+ will wait until execution is complete.
+
+ If `after` is given, this method will not run until `after` has completed.
+
+ :param after: Optionally, another method decorated with `ordered`. Wait for
+ the completion of `after` before executing this method.
+ """
+ if after is not None:
+ after = after.__name__
+
+ def _ordered(fn):
+ # Set an attribute on the method so we can find it later
+ setattr(fn, '_ordered', True)
+ state = fn.__name__
+
+ return _OrderedTaskRunner.decorate_ordered(fn, state, after)
+ return _ordered
+
+
+class MessageHandlingServer(service.ServiceBase, _OrderedTaskRunner):
"""Server for handling messages.
Connect a transport to a dispatcher that knows how to process the
@@ -94,29 +258,18 @@ class MessageHandlingServer(service.ServiceBase):
self.dispatcher = dispatcher
self.executor = executor
- # NOTE(sileht): we use a lock to protect the state change of the
- # server, we don't want to call stop until the transport driver
- # is fully started. Except for the blocking executor that have
- # start() that doesn't return
- if self.executor != "blocking":
- self._state_cond = threading.Condition()
- self._dummy_cond = False
- else:
- self._state_cond = _utils.DummyCondition()
- self._dummy_cond = True
-
try:
mgr = driver.DriverManager('oslo.messaging.executors',
self.executor)
except RuntimeError as ex:
raise ExecutorLoadFailure(self.executor, ex)
- else:
- self._executor_cls = mgr.driver
- self._executor_obj = None
- self._running = False
+
+ self._executor_cls = mgr.driver
+ self._executor_obj = None
super(MessageHandlingServer, self).__init__()
+ @ordered()
def start(self):
"""Start handling incoming messages.
@@ -131,24 +284,21 @@ class MessageHandlingServer(service.ServiceBase):
choose to dispatch messages in a new thread, coroutine or simply the
current thread.
"""
- if self._executor_obj is not None:
- return
- with self._state_cond:
- if self._executor_obj is not None:
- return
- try:
- listener = self.dispatcher._listen(self.transport)
- except driver_base.TransportDriverError as ex:
- raise ServerListenError(self.target, ex)
- self._executor_obj = self._executor_cls(self.conf, listener,
- self.dispatcher)
- self._executor_obj.start()
- self._running = True
- self._state_cond.notify_all()
+ try:
+ listener = self.dispatcher._listen(self.transport)
+ except driver_base.TransportDriverError as ex:
+ raise ServerListenError(self.target, ex)
+ executor = self._executor_cls(self.conf, listener, self.dispatcher)
+ executor.start()
+ self._executor_obj = executor
if self.executor == 'blocking':
- self._executor_obj.execute()
+ # N.B. This will be executed unlocked and unordered, so
+ # we can't rely on the value of self._executor_obj when this runs.
+ # We explicitly pass the local variable.
+ return lambda: executor.execute()
+ @ordered(after=start)
def stop(self):
"""Stop handling incoming messages.
@@ -157,12 +307,9 @@ class MessageHandlingServer(service.ServiceBase):
some messages, and underlying driver resources associated to this
server are still in use. See 'wait' for more details.
"""
- with self._state_cond:
- if self._executor_obj is not None:
- self._running = False
- self._executor_obj.stop()
- self._state_cond.notify_all()
+ self._executor_obj.stop()
+ @ordered(after=stop)
def wait(self):
"""Wait for message processing to complete.
@@ -173,37 +320,14 @@ class MessageHandlingServer(service.ServiceBase):
Once it's finished, the underlying driver resources associated to this
server are released (like closing useless network connections).
"""
- with self._state_cond:
- if self._running:
- LOG.warn(_LW("wait() should be called after stop() as it "
- "waits for existing messages to finish "
- "processing"))
- w = timeutils.StopWatch()
- w.start()
- while self._running:
- # NOTE(harlowja): 1.0 seconds was mostly chosen at
- # random, but it seems like a reasonable value to
- # use to avoid spamming the logs with to much
- # information.
- self._state_cond.wait(1.0)
- if self._running and not self._dummy_cond:
- LOG.warn(
- _LW("wait() should have been called"
- " after stop() as wait() waits for existing"
- " messages to finish processing, it has"
- " been %0.2f seconds and stop() still has"
- " not been called"), w.elapsed())
- executor = self._executor_obj
+ try:
+ self._executor_obj.wait()
+ finally:
+ # Close listener connection after processing all messages
+ self._executor_obj.listener.cleanup()
self._executor_obj = None
- if executor is not None:
- # We are the lucky calling thread to wait on the executor to
- # actually finish.
- try:
- executor.wait()
- finally:
- # Close listener connection after processing all messages
- executor.listener.cleanup()
- executor = None
+
+ self.init_task_states()
def reset(self):
"""Reset service.
diff --git a/oslo_messaging/tests/rpc/test_server.py b/oslo_messaging/tests/rpc/test_server.py
index 258dacb..1a2d2aa 100644
--- a/oslo_messaging/tests/rpc/test_server.py
+++ b/oslo_messaging/tests/rpc/test_server.py
@@ -13,6 +13,8 @@
# License for the specific language governing permissions and limitations
# under the License.
+import eventlet
+import time
import threading
from oslo_config import cfg
@@ -20,6 +22,7 @@ import testscenarios
import mock
import oslo_messaging
+from oslo_messaging import server as server_module
from oslo_messaging.tests import utils as test_utils
load_tests = testscenarios.load_tests_apply_scenarios
@@ -528,3 +531,210 @@ class TestMultipleServers(test_utils.BaseTestCase, ServerSetupMixin):
TestMultipleServers.generate_scenarios()
+
+class TestServerLocking(test_utils.BaseTestCase):
+ def setUp(self):
+ super(TestServerLocking, self).setUp(conf=cfg.ConfigOpts())
+
+ def _logmethod(name):
+ def method(self):
+ with self._lock:
+ self._calls.append(name)
+ return method
+
+ executors = []
+ class FakeExecutor(object):
+ def __init__(self, *args, **kwargs):
+ self._lock = threading.Lock()
+ self._calls = []
+ self.listener = mock.MagicMock()
+ executors.append(self)
+
+ start = _logmethod('start')
+ stop = _logmethod('stop')
+ wait = _logmethod('wait')
+ execute = _logmethod('execute')
+ self.executors = executors
+
+ self.server = oslo_messaging.MessageHandlingServer(mock.Mock(),
+ mock.Mock())
+ self.server._executor_cls = FakeExecutor
+
+ def test_start_stop_wait(self):
+ # Test a simple execution of start, stop, wait in order
+
+ thread = eventlet.spawn(self.server.start)
+ self.server.stop()
+ self.server.wait()
+
+ self.assertEqual(len(self.executors), 1)
+ executor = self.executors[0]
+ self.assertEqual(executor._calls,
+ ['start', 'execute', 'stop', 'wait'])
+ self.assertTrue(executor.listener.cleanup.called)
+
+ def test_reversed_order(self):
+ # Test that if we call wait, stop, start, these will be correctly
+ # reordered
+
+ wait = eventlet.spawn(self.server.wait)
+ # This is non-deterministic, but there's not a great deal we can do
+ # about that
+ eventlet.sleep(0)
+
+ stop = eventlet.spawn(self.server.stop)
+ eventlet.sleep(0)
+
+ start = eventlet.spawn(self.server.start)
+
+ self.server.wait()
+
+ self.assertEqual(len(self.executors), 1)
+ executor = self.executors[0]
+ self.assertEqual(executor._calls,
+ ['start', 'execute', 'stop', 'wait'])
+
+ def test_wait_for_running_task(self):
+ # Test that if 2 threads call a method simultaneously, both will wait,
+ # but only 1 will call the underlying executor method.
+
+ start_event = threading.Event()
+ finish_event = threading.Event()
+
+ running_event = threading.Event()
+ done_event = threading.Event()
+
+ runner = [None]
+ class SteppingFakeExecutor(self.server._executor_cls):
+ def start(self):
+ # Tell the test which thread won the race
+ runner[0] = eventlet.getcurrent()
+ running_event.set()
+
+ start_event.wait()
+ super(SteppingFakeExecutor, self).start()
+ done_event.set()
+
+ finish_event.wait()
+ self.server._executor_cls = SteppingFakeExecutor
+
+ start1 = eventlet.spawn(self.server.start)
+ start2 = eventlet.spawn(self.server.start)
+
+ # Wait until one of the threads starts running
+ running_event.wait()
+ runner = runner[0]
+ waiter = start2 if runner == start1 else start2
+
+ waiter_finished = threading.Event()
+ waiter.link(lambda _: waiter_finished.set())
+
+ # At this point, runner is running start(), and waiter() is waiting for
+ # it to complete. runner has not yet logged anything.
+ self.assertEqual(1, len(self.executors))
+ executor = self.executors[0]
+
+ self.assertEqual(executor._calls, [])
+ self.assertFalse(waiter_finished.is_set())
+
+ # Let the runner log the call
+ start_event.set()
+ done_event.wait()
+
+ # We haven't signalled completion yet, so execute shouldn't have run
+ self.assertEqual(executor._calls, ['start'])
+ self.assertFalse(waiter_finished.is_set())
+
+ # Let the runner complete
+ finish_event.set()
+ waiter.wait()
+ runner.wait()
+
+ # Check that both threads have finished, start was only called once,
+ # and execute ran
+ self.assertTrue(waiter_finished.is_set())
+ self.assertEqual(executor._calls, ['start', 'execute'])
+
+ def test_state_wrapping(self):
+ # Test that we behave correctly if a thread waits, and the server state
+ # has wrapped when it it next scheduled
+
+ # Ensure that if 2 threads wait for the completion of 'start', the
+ # first will wait until complete_event is signalled, but the second
+ # will continue
+ complete_event = threading.Event()
+ complete_waiting_callback = threading.Event()
+
+ start_state = self.server._states['start']
+ old_wait_for_completion = start_state.wait_for_completion
+ waited = [False]
+ def new_wait_for_completion(*args, **kwargs):
+ if not waited[0]:
+ waited[0] = True
+ complete_waiting_callback.set()
+ complete_event.wait()
+ old_wait_for_completion(*args, **kwargs)
+ start_state.wait_for_completion = new_wait_for_completion
+
+ # thread1 will wait for start to complete until we signal it
+ thread1 = eventlet.spawn(self.server.stop)
+ thread1_finished = threading.Event()
+ thread1.link(lambda _: thread1_finished.set())
+
+ self.server.start()
+ complete_waiting_callback.wait()
+
+ # The server should have started, but stop should not have been called
+ self.assertEqual(1, len(self.executors))
+ self.assertEqual(self.executors[0]._calls, ['start', 'execute'])
+ self.assertFalse(thread1_finished.is_set())
+
+ self.server.stop()
+ self.server.wait()
+
+ # We should have gone through all the states, and thread1 should still
+ # be waiting
+ self.assertEqual(1, len(self.executors))
+ self.assertEqual(self.executors[0]._calls, ['start', 'execute',
+ 'stop', 'wait'])
+ self.assertFalse(thread1_finished.is_set())
+
+ # Start again
+ self.server.start()
+
+ # We should now record 2 executors
+ self.assertEqual(2, len(self.executors))
+ self.assertEqual(self.executors[0]._calls, ['start', 'execute',
+ 'stop', 'wait'])
+ self.assertEqual(self.executors[1]._calls, ['start', 'execute'])
+ self.assertFalse(thread1_finished.is_set())
+
+ # Allow thread1 to complete
+ complete_event.set()
+ thread1_finished.wait()
+
+ # thread1 should now have finished, and stop should not have been
+ # called again on either the first or second executor
+ self.assertEqual(2, len(self.executors))
+ self.assertEqual(self.executors[0]._calls, ['start', 'execute',
+ 'stop', 'wait'])
+ self.assertEqual(self.executors[1]._calls, ['start', 'execute'])
+ self.assertTrue(thread1_finished.is_set())
+
+ @mock.patch.object(server_module._OrderedTask,
+ 'LOG_AFTER_WAIT_SECS', 1)
+ @mock.patch.object(server_module, 'LOG')
+ def test_timeout_logging(self, mock_log):
+ # Test that we generate a log message if we wait longer than
+ # LOG_AFTER_WAIT_SECS
+
+ log_event = threading.Event()
+ mock_log.warn.side_effect = lambda _: log_event.set()
+
+ # Call stop without calling start. We should log a wait after 1 second
+ thread = eventlet.spawn(self.server.stop)
+ log_event.wait()
+
+ # Redundant given that we already waited, but it's nice to assert
+ self.assertTrue(mock_log.warn.called)
+ thread.kill()