author    Jonathan Reams <jbreams@mongodb.com>    2017-10-16 16:26:02 -0400
committer Jonathan Reams <jbreams@mongodb.com>    2017-11-06 17:19:10 -0500
commit    e7837911c89af144fe012e5063f8ca88c4c66956 (patch)
tree      eb2c141aa289033a400ede246e3478083e5e81bf /src
parent    dc712619bf21f7c577f28b3f8281bf4c25362511 (diff)
SERVER-31538 Ensure the ServiceStateMachine always gets cleaned up on error/termination
Diffstat (limited to 'src')
-rw-r--r--  src/mongo/transport/service_entry_point_impl.cpp     10
-rw-r--r--  src/mongo/transport/service_state_machine.cpp        313
-rw-r--r--  src/mongo/transport/service_state_machine.h            85
-rw-r--r--  src/mongo/transport/service_state_machine_test.cpp   276
4 files changed, 497 insertions, 187 deletions
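The point of the change is that a ServiceStateMachine whose next task cannot be scheduled (for example because a worker thread could not be started) must still terminate its session and run its cleanup hook rather than leak. Below is a minimal, self-contained sketch of that control flow; Status and SketchStateMachine are invented names for illustration, not the real ServiceStateMachine or ServiceExecutor API.

    #include <functional>
    #include <iostream>

    struct Status {
        bool ok;
        const char* reason;
    };

    struct SketchStateMachine {
        // Stand-in for ServiceExecutor::schedule(): returns a non-OK Status on failure.
        std::function<Status(std::function<void()>)> schedule;
        std::function<void()> cleanupHook;

        void scheduleNext() {
            Status status = schedule([] { /* would run the next state of the machine */ });
            if (status.ok)
                return;
            // The executor could not queue the task (e.g. no worker thread could be
            // started). End the session now so the system stays in a valid state.
            std::cerr << "terminating session: " << status.reason << '\n';
            if (cleanupHook)
                cleanupHook();
        }
    };

    int main() {
        SketchStateMachine ssm;
        ssm.schedule = [](std::function<void()>) { return Status{false, "no worker threads"}; };
        ssm.cleanupHook = [] { std::cout << "session cleaned up\n"; };
        ssm.scheduleNext();  // logs the failure, then runs the cleanup hook exactly once
    }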
diff --git a/src/mongo/transport/service_entry_point_impl.cpp b/src/mongo/transport/service_entry_point_impl.cpp
index df02f234f11..adbdc48a6ec 100644
--- a/src/mongo/transport/service_entry_point_impl.cpp
+++ b/src/mongo/transport/service_entry_point_impl.cpp
@@ -86,9 +86,9 @@ void ServiceEntryPointImpl::startSession(transport::SessionHandle session) {
const bool quiet = serverGlobalParams.quiet.load();
size_t connectionCount;
+ auto transportMode = _svcCtx->getServiceExecutor()->transportMode();
- auto ssm = ServiceStateMachine::create(
- _svcCtx, session, _svcCtx->getServiceExecutor()->transportMode());
+ auto ssm = ServiceStateMachine::create(_svcCtx, session, transportMode);
{
stdx::lock_guard<decltype(_sessionsMutex)> lk(_sessionsMutex);
connectionCount = _sessions.size() + 1;
@@ -129,7 +129,11 @@ void ServiceEntryPointImpl::startSession(transport::SessionHandle session) {
});
- ssm->scheduleNext();
+ auto ownership = ServiceStateMachine::Ownership::kOwned;
+ if (transportMode == transport::Mode::kSynchronous) {
+ ownership = ServiceStateMachine::Ownership::kStatic;
+ }
+ ssm->start(ownership);
}
void ServiceEntryPointImpl::endAllSessions(transport::Session::TagMask tags) {
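The hunk above chooses the ownership model from the executor's transport mode: a synchronous executor dedicates a thread to the connection, so the state machine can keep that thread's Client and thread name swapped in for its whole lifetime (kStatic), while an asynchronous executor multiplexes threads and therefore needs ownership taken and released around every run (kOwned). A compressed sketch of just that decision follows; the enum values mirror the ones added in this commit, the function name is hypothetical.

    enum class Mode { kSynchronous, kAsynchronous };     // stand-in for transport::Mode
    enum class Ownership { kUnowned, kOwned, kStatic };  // mirrors the enum added below

    Ownership ownershipFor(Mode transportMode) {
        // One dedicated thread per connection: swap the Client in once and keep it.
        if (transportMode == Mode::kSynchronous)
            return Ownership::kStatic;
        // Thread pool: take and release ownership around every state-machine run.
        return Ownership::kOwned;
    }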
diff --git a/src/mongo/transport/service_state_machine.cpp b/src/mongo/transport/service_state_machine.cpp
index 47764de2444..263c3c06b2d 100644
--- a/src/mongo/transport/service_state_machine.cpp
+++ b/src/mongo/transport/service_state_machine.cpp
@@ -32,6 +32,7 @@
#include "mongo/transport/service_state_machine.h"
+#include "mongo/config.h"
#include "mongo/db/client.h"
#include "mongo/db/dbmessage.h"
#include "mongo/db/stats/counters.h"
@@ -89,91 +90,128 @@ bool setExhaustMessage(Message* m, const DbResponse& dbresponse) {
} // namespace
-using transport::TransportLayer;
using transport::ServiceExecutor;
+using transport::TransportLayer;
/*
* This class wraps up the logic for swapping/unswapping the Client during runNext().
+ *
+ * In debug builds this also ensures that only one thread is working on the SSM at once.
*/
class ServiceStateMachine::ThreadGuard {
ThreadGuard(ThreadGuard&) = delete;
ThreadGuard& operator=(ThreadGuard&) = delete;
public:
- explicit ThreadGuard(ServiceStateMachine* ssm)
- : _ssm{ssm},
- _haveTakenOwnership{!_ssm->_isOwned.test_and_set()},
- _oldThreadName{getThreadName().toString()} {
- const auto currentOwningThread = _ssm->_currentOwningThread.load();
- const auto currentThreadId = stdx::this_thread::get_id();
-
- // If this is true, then we are the "owner" of the Client and we should swap the
- // client/thread name before doing any work.
- if (_haveTakenOwnership) {
- _ssm->_currentOwningThread.store(currentThreadId);
+ explicit ThreadGuard(ServiceStateMachine* ssm) : _ssm{ssm} {
+ auto owned = _ssm->_owned.compareAndSwap(Ownership::kUnowned, Ownership::kOwned);
+ if (owned == Ownership::kStatic) {
+ dassert(haveClient());
+ dassert(Client::getCurrent() == _ssm->_dbClientPtr);
+ _haveTakenOwnership = true;
+ return;
+ }
+
+#ifdef MONGO_CONFIG_DEBUG_BUILD
+ invariant(owned == Ownership::kUnowned);
+ _ssm->_owningThread.store(stdx::this_thread::get_id());
+#endif
- // Set up the thread name
+ // Set up the thread name
+ auto oldThreadName = getThreadName();
+ if (oldThreadName != _ssm->_threadName) {
+ _ssm->_oldThreadName = getThreadName().toString();
setThreadName(_ssm->_threadName);
+ }
+
+ // Swap the current Client so calls to cc() work as expected
+ Client::setCurrent(std::move(_ssm->_dbClient));
+ _haveTakenOwnership = true;
+ }
- // These are sanity checks to make sure that the Client is what we expect it to be
- invariant(!haveClient());
- invariant(_ssm->_dbClient.get() == _ssm->_dbClientPtr);
+ // Constructing from a moved ThreadGuard invalidates the other thread guard.
+ ThreadGuard(ThreadGuard&& other)
+ : _ssm(other._ssm), _haveTakenOwnership(other._haveTakenOwnership) {
+ other._haveTakenOwnership = false;
+ }
- // Swap the current Client so calls to cc() work as expected
- Client::setCurrent(std::move(_ssm->_dbClient));
- } else if (currentOwningThread != currentThreadId) {
- // If the currentOwningThread does not equal the currentThreadId, then another thread
- // currently "owns" the Client and we should reschedule ourself.
- _okayToRunNext = false;
+ ThreadGuard& operator=(ThreadGuard&& other) {
+ if (this != &other) {
+ _ssm = other._ssm;
+ _haveTakenOwnership = other._haveTakenOwnership;
+ other._haveTakenOwnership = false;
}
- }
+ return *this;
+ };
+
+ ThreadGuard() = delete;
~ThreadGuard() {
- // If we are not the owner of the SSM, then do nothing. Something higher up the call stack
- // will have to clean up.
- if (!_haveTakenOwnership)
- return;
+ if (_haveTakenOwnership)
+ release();
+ }
- // If the session has ended, then assume that it's unsafe to do anything but call the
- // cleanup hook.
+ explicit operator bool() const {
+#ifdef MONGO_CONFIG_DEBUG_BUILD
+ if (_haveTakenOwnership) {
+ invariant(_ssm->_owned.load() != Ownership::kUnowned);
+ invariant(_ssm->_owningThread.load() == stdx::this_thread::get_id());
+ return true;
+ } else {
+ return false;
+ }
+#else
+ return _haveTakenOwnership;
+#endif
+ }
+
+ void markStaticOwnership() {
+ dassert(static_cast<bool>(*this));
+ _ssm->_owned.store(Ownership::kStatic);
+ }
+
+ void release() {
+ auto owned = _ssm->_owned.load();
+
+#ifdef MONGO_CONFIG_DEBUG_BUILD
+ dassert(_haveTakenOwnership);
+ dassert(owned != Ownership::kUnowned);
+ dassert(_ssm->_owningThread.load() == stdx::this_thread::get_id());
+#endif
+ if (owned != Ownership::kStatic) {
+ if (haveClient()) {
+ _ssm->_dbClient = Client::releaseCurrent();
+ }
+
+ if (!_ssm->_oldThreadName.empty()) {
+ setThreadName(_ssm->_oldThreadName);
+ }
+ }
+
+ // If the session has ended, then it's unsafe to do anything but call the cleanup hook.
if (_ssm->state() == State::Ended) {
- // The cleanup hook may change as soon as we unlock the mutex, so move it out of the
- // ssm before unlocking the lock.
+ // The cleanup hook gets moved out of _ssm->_cleanupHook so that it can only be called
+ // once.
auto cleanupHook = std::move(_ssm->_cleanupHook);
if (cleanupHook)
cleanupHook();
+ // It's very important that the Guard returns here and that the SSM's state does not
+ // get modified in any way after the cleanup hook is called.
return;
}
- // Otherwise swap thread locals and thread names back into the SSM so its ready for the
- // next run.
- if (haveClient()) {
- _ssm->_dbClient = Client::releaseCurrent();
+ _haveTakenOwnership = false;
+ // If owned != Ownership::kOwned here then it can only equal Ownership::kStatic and we
+ // should just return
+ if (owned == Ownership::kOwned) {
+ _ssm->_owned.store(Ownership::kUnowned);
}
- setThreadName(_oldThreadName);
- _ssm->_isOwned.clear();
- }
-
- // This bool operator reflects whether the ThreadGuard was able to take ownership of the thread
- // either higher up the call chain, or in this call. If this returns false, then it is not safe
- // to assume the thread has been setup correctly, or that any mutable state of the SSM is safe
- // to access except for the current _state value.
- explicit operator bool() const {
- return _okayToRunNext;
- }
-
- // Returns whether the thread guard is the owner of the SSM's state or not. Callers can use this
- // to determine whether their callchain is recursive.
- bool isOwner() const {
- return _haveTakenOwnership;
}
private:
ServiceStateMachine* _ssm;
- bool _haveTakenOwnership;
- const std::string _oldThreadName;
- bool _okayToRunNext = true;
+ bool _haveTakenOwnership = false;
};
std::shared_ptr<ServiceStateMachine> ServiceStateMachine::create(ServiceContext* svcContext,
@@ -192,27 +230,52 @@ ServiceStateMachine::ServiceStateMachine(ServiceContext* svcContext,
_sessionHandle(session),
_dbClient{svcContext->makeClient("conn", std::move(session))},
_dbClientPtr{_dbClient.get()},
- _threadName{str::stream() << "conn" << _session()->id()},
- _currentOwningThread{stdx::this_thread::get_id()} {}
+ _threadName{str::stream() << "conn" << _session()->id()} {}
const transport::SessionHandle& ServiceStateMachine::_session() const {
return _sessionHandle;
}
+void ServiceStateMachine::_sourceMessage(ThreadGuard guard) {
+ invariant(_inMessage.empty());
+ auto ticket = _session()->sourceMessage(&_inMessage);
+
+ _state.store(State::SourceWait);
+ guard.release();
+
+ if (_transportMode == transport::Mode::kSynchronous) {
+ _sourceCallback([this](auto ticket) {
+ MONGO_IDLE_THREAD_BLOCK;
+ return _session()->getTransportLayer()->wait(std::move(ticket));
+ }(std::move(ticket)));
+ } else if (_transportMode == transport::Mode::kAsynchronous) {
+ _session()->getTransportLayer()->asyncWait(
+ std::move(ticket), [this](Status status) { _sourceCallback(status); });
+ }
+}
+
+void ServiceStateMachine::_sinkMessage(ThreadGuard guard, Message toSink) {
+ // Sink our response to the client
+ auto ticket = _session()->sinkMessage(toSink);
+
+ _state.store(State::SinkWait);
+ guard.release();
+
+ if (_transportMode == transport::Mode::kSynchronous) {
+ _sinkCallback(_session()->getTransportLayer()->wait(std::move(ticket)));
+ } else if (_transportMode == transport::Mode::kAsynchronous) {
+ _session()->getTransportLayer()->asyncWait(
+ std::move(ticket), [this](Status status) { _sinkCallback(status); });
+ }
+}
+
void ServiceStateMachine::_sourceCallback(Status status) {
// The first thing to do is create a ThreadGuard which will take ownership of the SSM in this
// thread.
ThreadGuard guard(this);
- // If the guard wasn't able to take ownership of the thread, then reschedule this call to
- // runNext() so that this thread can do other useful work with its timeslice instead of going
- // to sleep while waiting for the SSM to be released.
- if (!guard) {
- return _scheduleFunc([this, status] { _sourceCallback(status); },
- ServiceExecutor::kDeferredTask);
- }
// Make sure we just called sourceMessage();
- invariant(state() == State::SourceWait);
+ dassert(state() == State::SourceWait);
auto remote = _session()->remote();
if (status.isOK()) {
@@ -225,7 +288,7 @@ void ServiceStateMachine::_sourceCallback(Status status) {
// If this callback doesn't own the ThreadGuard, then we're being called recursively,
// and the executor shouldn't start a new thread to process the message - it can use this
// one just after this returns.
- return scheduleNext(ServiceExecutor::kMayRecurse);
+ return _scheduleNextWithGuard(std::move(guard), ServiceExecutor::kMayRecurse);
} else if (ErrorCodes::isInterruption(status.code()) ||
ErrorCodes::isNetworkError(status.code())) {
LOG(2) << "Session from " << remote << " encountered a network error during SourceMessage";
@@ -242,22 +305,15 @@ void ServiceStateMachine::_sourceCallback(Status status) {
// There was an error receiving a message from the client and we've already printed the error
// so call runNextInGuard() to clean up the session without waiting.
- _runNextInGuard(guard);
+ _runNextInGuard(std::move(guard));
}
void ServiceStateMachine::_sinkCallback(Status status) {
// The first thing to do is create a ThreadGuard which will take ownership of the SSM in this
// thread.
ThreadGuard guard(this);
- // If the guard wasn't able to take ownership of the thread, then reschedule this call to
- // runNext() so that this thread can do other useful work with its timeslice instead of going
- // to sleep while waiting for the SSM to be released.
- if (!guard) {
- return _scheduleFunc([this, status] { _sinkCallback(status); },
- ServiceExecutor::kDeferredTask);
- }
- invariant(state() == State::SinkWait);
+ dassert(state() == State::SinkWait);
// If there was an error sinking the message to the client, then we should print an error and
// end the session. No need to unwind the stack, so this will runNextInGuard() and return.
@@ -268,22 +324,19 @@ void ServiceStateMachine::_sinkCallback(Status status) {
log() << "Error sending response to client: " << status << ". Ending connection from "
<< _session()->remote() << " (connection id: " << _session()->id() << ")";
_state.store(State::EndSession);
- return _runNextInGuard(guard);
+ return _runNextInGuard(std::move(guard));
} else if (_inExhaust) {
_state.store(State::Process);
} else {
_state.store(State::Source);
}
- return scheduleNext(ServiceExecutor::kDeferredTask | ServiceExecutor::kMayYieldBeforeSchedule);
+ return _scheduleNextWithGuard(std::move(guard),
+ ServiceExecutor::kDeferredTask |
+ ServiceExecutor::kMayYieldBeforeSchedule);
}
-void ServiceStateMachine::_processMessage(ThreadGuard& guard) {
- // This may have been called just after a failure to source a message, in which case this
- // should return early so the session can be cleaned up.
- if (state() != State::Process) {
- return;
- }
+void ServiceStateMachine::_processMessage(ThreadGuard guard) {
invariant(!_inMessage.empty());
auto& compressorMgr = MessageCompressorManager::forSession(_session());
@@ -332,42 +385,22 @@ void ServiceStateMachine::_processMessage(ThreadGuard& guard) {
uassertStatusOK(swm.getStatus());
toSink = swm.getValue();
}
+ _sinkMessage(std::move(guard), std::move(toSink));
- // Sink our response to the client
- auto ticket = _session()->sinkMessage(toSink);
-
- _state.store(State::SinkWait);
- if (_transportMode == transport::Mode::kSynchronous) {
- _sinkCallback(_session()->getTransportLayer()->wait(std::move(ticket)));
- } else if (_transportMode == transport::Mode::kAsynchronous) {
- _session()->getTransportLayer()->asyncWait(
- std::move(ticket), [this](Status status) { _sinkCallback(status); });
- } else {
- MONGO_UNREACHABLE;
- }
} else {
_state.store(State::Source);
_inMessage.reset();
- return scheduleNext(ServiceExecutor::kDeferredTask);
+ return _scheduleNextWithGuard(std::move(guard), ServiceExecutor::kDeferredTask);
}
}
void ServiceStateMachine::runNext() {
- // The first thing to do is create a ThreadGuard which will take ownership of the SSM in this
- // thread.
- ThreadGuard guard(this);
- // If the guard wasn't able to take ownership of the thread, then reschedule this call to
- // runNext() so that this thread can do other useful work with its timeslice instead of going
- // to sleep while waiting for the SSM to be released.
- if (!guard) {
- return scheduleNext(ServiceExecutor::kDeferredTask);
- }
- return _runNextInGuard(guard);
+ return _runNextInGuard(ThreadGuard(this));
}
-void ServiceStateMachine::_runNextInGuard(ThreadGuard& guard) {
+void ServiceStateMachine::_runNextInGuard(ThreadGuard guard) {
auto curState = state();
- invariant(curState != State::Ended);
+ dassert(curState != State::Ended);
// If this is the first run of the SSM, then update its state to Source
if (curState == State::Created) {
@@ -376,29 +409,14 @@ void ServiceStateMachine::_runNextInGuard(ThreadGuard& guard) {
}
// Make sure the current Client got set correctly
- invariant(Client::getCurrent() == _dbClientPtr);
+ dassert(Client::getCurrent() == _dbClientPtr);
try {
switch (curState) {
- case State::Source: {
- invariant(_inMessage.empty());
-
- auto ticket = _session()->sourceMessage(&_inMessage);
- _state.store(State::SourceWait);
- if (_transportMode == transport::Mode::kSynchronous) {
- _sourceCallback([this](auto ticket) {
- MONGO_IDLE_THREAD_BLOCK;
- return _session()->getTransportLayer()->wait(std::move(ticket));
- }(std::move(ticket)));
- } else if (_transportMode == transport::Mode::kAsynchronous) {
- _session()->getTransportLayer()->asyncWait(
- std::move(ticket), [this](Status status) { _sourceCallback(status); });
- } else {
- MONGO_UNREACHABLE;
- }
+ case State::Source:
+ _sourceMessage(std::move(guard));
break;
- }
case State::Process:
- _processMessage(guard);
+ _processMessage(std::move(guard));
break;
case State::EndSession:
// This will get handled below in an if statement. That way if an error occurs
@@ -409,7 +427,10 @@ void ServiceStateMachine::_runNextInGuard(ThreadGuard& guard) {
}
if (state() == State::EndSession) {
- _cleanupSession(guard);
+ if (!guard) {
+ guard = ThreadGuard(this);
+ }
+ _cleanupSession(std::move(guard));
}
return;
@@ -421,12 +442,40 @@ void ServiceStateMachine::_runNextInGuard(ThreadGuard& guard) {
quickExit(EXIT_UNCAUGHT);
}
+ if (!guard) {
+ guard = ThreadGuard(this);
+ }
_state.store(State::EndSession);
- _cleanupSession(guard);
+ _cleanupSession(std::move(guard));
}
-void ServiceStateMachine::scheduleNext(ServiceExecutor::ScheduleFlags flags) {
- _scheduleFunc([this] { runNext(); }, flags);
+void ServiceStateMachine::start(Ownership ownershipModel) {
+ _scheduleNextWithGuard(
+ ThreadGuard(this), transport::ServiceExecutor::kEmptyFlags, ownershipModel);
+}
+
+void ServiceStateMachine::_scheduleNextWithGuard(ThreadGuard guard,
+ transport::ServiceExecutor::ScheduleFlags flags,
+ Ownership ownershipModel) {
+ auto func = [ ssm = shared_from_this(), ownershipModel ] {
+ ThreadGuard guard(ssm.get());
+ if (ownershipModel == Ownership::kStatic)
+ guard.markStaticOwnership();
+ ssm->_runNextInGuard(std::move(guard));
+ };
+ guard.release();
+ Status status = _serviceContext->getServiceExecutor()->schedule(std::move(func), flags);
+ if (status.isOK()) {
+ return;
+ }
+
+ // We've had an error, reacquire the ThreadGuard and destroy the SSM
+ ThreadGuard terminateGuard(this);
+
+ // The service executor failed to schedule the task. This could for example be that we failed
+ // to start a worker thread. Terminate this connection to leave the system in a valid state.
+ _terminateAndLogIfError(status);
+ _cleanupSession(std::move(terminateGuard));
}
void ServiceStateMachine::terminate() {
@@ -449,7 +498,7 @@ void ServiceStateMachine::terminateIfTagsDontMatch(transport::Session::TagMask t
return;
}
- _session()->getTransportLayer()->end(_session());
+ terminate();
}
void ServiceStateMachine::setCleanupHook(stdx::function<void()> hook) {
@@ -468,7 +517,7 @@ void ServiceStateMachine::_terminateAndLogIfError(Status status) {
}
}
-void ServiceStateMachine::_cleanupSession(ThreadGuard& guard) {
+void ServiceStateMachine::_cleanupSession(ThreadGuard guard) {
_state.store(State::Ended);
_inMessage.reset();
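The ThreadGuard rewrite above replaces the old atomic_flag with a three-state ownership word: kUnowned while nobody is running the machine, kOwned while a guard holds it, and kStatic once a dedicated thread has claimed it permanently. A simplified, self-contained sketch of that acquire/release protocol is shown here; std::atomic stands in for MongoDB's AtomicWord, and the Client/thread-name swapping is omitted.

    #include <atomic>
    #include <cassert>

    enum class Ownership { kUnowned, kOwned, kStatic };

    class GuardSketch {
    public:
        explicit GuardSketch(std::atomic<Ownership>* owned) : _owned(owned) {
            auto expected = Ownership::kUnowned;
            if (!_owned->compare_exchange_strong(expected, Ownership::kOwned)) {
                // Failed to move kUnowned -> kOwned: the only legal reason is that a
                // dedicated thread already holds static ownership of the machine.
                assert(expected == Ownership::kStatic);
            }
            _acquired = true;
        }

        ~GuardSketch() {
            if (_acquired)
                release();
        }

        void release() {
            // Static ownership is never handed back; kOwned drops to kUnowned so
            // another thread may pick the machine up.
            if (_owned->load() == Ownership::kOwned)
                _owned->store(Ownership::kUnowned);
            _acquired = false;
        }

    private:
        std::atomic<Ownership>* _owned;
        bool _acquired = false;
    };

    int main() {
        std::atomic<Ownership> owned{Ownership::kUnowned};
        {
            GuardSketch guard(&owned);  // takes ownership for this scope
        }                               // destructor releases it again
        assert(owned.load() == Ownership::kUnowned);
    }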
diff --git a/src/mongo/transport/service_state_machine.h b/src/mongo/transport/service_state_machine.h
index 35812b86940..60c6981acaf 100644
--- a/src/mongo/transport/service_state_machine.h
+++ b/src/mongo/transport/service_state_machine.h
@@ -31,6 +31,7 @@
#include <atomic>
#include "mongo/base/status.h"
+#include "mongo/config.h"
#include "mongo/db/service_context.h"
#include "mongo/platform/atomic_word.h"
#include "mongo/stdx/functional.h"
@@ -73,12 +74,12 @@ public:
transport::Mode transportMode);
/*
- * Any state may transition to EndSession in case of an error, otherwise the valid state
- * transitions are:
- * Source -> SourceWait -> Process -> SinkWait -> Source (standard RPC)
- * Source -> SourceWait -> Process -> SinkWait -> Process -> SinkWait ... (exhaust)
- * Source -> SourceWait -> Process -> Source (fire-and-forget)
- */
+ * Any state may transition to EndSession in case of an error, otherwise the valid state
+ * transitions are:
+ * Source -> SourceWait -> Process -> SinkWait -> Source (standard RPC)
+ * Source -> SourceWait -> Process -> SinkWait -> Process -> SinkWait ... (exhaust)
+ * Source -> SourceWait -> Process -> Source (fire-and-forget)
+ */
enum class State {
Created, // The session has been created, but no operations have been performed yet
Source, // Request a new Message from the network to handle
@@ -91,6 +92,18 @@ public:
};
/*
+ * When start() is called with Ownership::kOwned, the SSM will swap the Client/thread name
+ * whenever it runs a stage of the state machine, and then unswap them out when leaving the SSM.
+ *
+ * With Ownership::kStatic, it will assume that the SSM will only ever be run from one thread,
+ * and that thread will not be used for other SSM's. It will swap in the Client/thread name
+ * for the first run and leave them in place.
+ *
+ * kUnowned is used internally to mark that the SSM is inactive.
+ */
+ enum class Ownership { kUnowned, kOwned, kStatic };
+
+ /*
* runNext() will run the current state of the state machine. It also handles all the error
* handling and state management for requests.
*
@@ -104,14 +117,12 @@ public:
void runNext();
/*
- * scheduleNext() schedules a call to runNext() in the future. This will be implemented with
- * an async TransportLayer.
+ * start() schedules a call to runNext() in the future.
*
* It is guaranteed to unwind the stack, and not call runNext() recursively, but is not
- * guaranteed that runNext() will run after this returns.
+ * guaranteed that runNext() will run after this returns.
*/
- void scheduleNext(
- transport::ServiceExecutor::ScheduleFlags flags = transport::ServiceExecutor::kEmptyFlags);
+ void start(Ownership ownershipModel);
/*
* Gets the current state of connection for testing/diagnostic purposes.
@@ -147,29 +158,21 @@ private:
friend class ThreadGuard;
/*
- * Terminates the associated transport Session if status indicate error.
- *
- * This will not block on the session terminating cleaning itself up, it returns immediately.
- */
+ * Terminates the associated transport Session if the status indicates an error.
+ *
+ * This will not block on the session terminating and cleaning itself up; it returns immediately.
+ */
void _terminateAndLogIfError(Status status);
/*
- * This and scheduleFunc() are helper functions to schedule tasks on the serviceExecutor
- * while maintaining a shared_ptr copy to anchor the lifetime of the SSM while waiting for
- * callbacks to run.
- */
- template <typename Func>
- void _scheduleFunc(Func&& func, transport::ServiceExecutor::ScheduleFlags flags) {
- Status status = _serviceContext->getServiceExecutor()->schedule(
- [ func = std::move(func), anchor = shared_from_this() ] { func(); }, flags);
- if (!status.isOK()) {
- // The service executor failed to schedule the task
- // This could for example be that we failed to start
- // a worker thread. Terminate this connection to
- // leave the system in a valid state.
- _terminateAndLogIfError(status);
- }
- }
+ * This is a helper function to schedule tasks on the serviceExecutor maintaining a shared_ptr
+ * copy to anchor the lifetime of the SSM while waiting for callbacks to run.
+ *
+ * If scheduling the function fails, the SSM will be terminated and cleaned up immediately
+ */
+ void _scheduleNextWithGuard(ThreadGuard guard,
+ transport::ServiceExecutor::ScheduleFlags flags,
+ Ownership ownershipModel = Ownership::kOwned);
/*
* Gets the transport::Session associated with this connection
@@ -182,13 +185,13 @@ private:
* runNext() and already own a ThreadGuard, they should call this with that guard as the
* argument.
*/
- void _runNextInGuard(ThreadGuard& guard);
+ void _runNextInGuard(ThreadGuard guard);
/*
* This function actually calls into the database and processes a request. It's broken out
* into its own inline function for better readability.
*/
- inline void _processMessage(ThreadGuard& guard);
+ inline void _processMessage(ThreadGuard guard);
/*
* These get called by the TransportLayer when requested network I/O has completed.
@@ -197,9 +200,16 @@ private:
void _sinkCallback(Status status);
/*
+ * Source/Sink message from the TransportLayer. These will invalidate the ThreadGuard just
+ * before waiting on the TL.
+ */
+ void _sourceMessage(ThreadGuard guard);
+ void _sinkMessage(ThreadGuard guard, Message toSink);
+
+ /*
* Releases all the resources associated with the session and call the cleanupHook.
*/
- void _cleanupSession(ThreadGuard& guard);
+ void _cleanupSession(ThreadGuard guard);
AtomicWord<State> _state{State::Created};
@@ -218,8 +228,11 @@ private:
boost::optional<MessageCompressorId> _compressorId;
Message _inMessage;
- AtomicWord<stdx::thread::id> _currentOwningThread;
- std::atomic_flag _isOwned = ATOMIC_FLAG_INIT; // NOLINT
+ AtomicWord<Ownership> _owned{Ownership::kUnowned};
+#ifdef MONGO_CONFIG_DEBUG_BUILD
+ AtomicWord<stdx::thread::id> _owningThread;
+#endif
+ std::string _oldThreadName;
};
template <typename T>
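The new _scheduleNextWithGuard keeps the state machine alive across the asynchronous hop by capturing shared_from_this() in the scheduled task, so the last reference cannot disappear while the executor still holds work for it. The sketch below shows that lifetime-anchoring idiom with a plain vector standing in for the ServiceExecutor; MachineSketch and its members are made-up names.

    #include <functional>
    #include <iostream>
    #include <memory>
    #include <vector>

    struct MachineSketch : std::enable_shared_from_this<MachineSketch> {
        void scheduleNext(std::vector<std::function<void()>>* queue) {
            // The lambda holds a shared_ptr copy; even if the caller drops its
            // reference, the machine survives until the task has run or been dropped.
            queue->push_back([self = shared_from_this()] { self->runNext(); });
        }
        void runNext() { std::cout << "running next state\n"; }
        ~MachineSketch() { std::cout << "machine destroyed\n"; }
    };

    int main() {
        std::vector<std::function<void()>> queue;      // stand-in for a ServiceExecutor
        std::make_shared<MachineSketch>()->scheduleNext(&queue);
        for (auto& task : queue)
            task();     // prints "running next state"
        queue.clear();  // last reference dropped here; prints "machine destroyed"
    }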
diff --git a/src/mongo/transport/service_state_machine_test.cpp b/src/mongo/transport/service_state_machine_test.cpp
index 646431580f7..0e4ca8c15c1 100644
--- a/src/mongo/transport/service_state_machine_test.cpp
+++ b/src/mongo/transport/service_state_machine_test.cpp
@@ -39,7 +39,7 @@
#include "mongo/transport/mock_session.h"
#include "mongo/transport/mock_ticket.h"
#include "mongo/transport/service_entry_point.h"
-#include "mongo/transport/service_executor_noop.h"
+#include "mongo/transport/service_executor.h"
#include "mongo/transport/service_state_machine.h"
#include "mongo/transport/transport_layer_mock.h"
#include "mongo/unittest/unittest.h"
@@ -51,6 +51,11 @@
namespace mongo {
namespace {
+inline std::string stateToString(ServiceStateMachine::State state) {
+ std::string ret = str::stream() << state;
+ return ret;
+}
+
class MockSEP : public ServiceEntryPoint {
public:
virtual ~MockSEP() = default;
@@ -122,9 +127,9 @@ public:
return TransportLayer::TicketSessionClosedStatus;
}
- if (_nextMessage) {
- *message = *_nextMessage;
- }
+ OpMsgBuilder builder;
+ builder.setBody(BSON("ping" << 1));
+ *message = builder.finish();
return TransportLayerMock::sourceMessage(session, message, expiration);
}
@@ -154,9 +159,10 @@ public:
ASSERT_EQ(_ssm->state(),
_lastTicketSource ? ServiceStateMachine::State::SourceWait
: ServiceStateMachine::State::SinkWait);
- std::stringstream ss;
- ss << _ssm->state();
- log() << "In wait. ssm state: " << ss.str();
+
+ log() << "In wait. ssm state: " << stateToString(_ssm->state());
+ if (_waitHook)
+ _waitHook();
return TransportLayerMock::wait(std::move(ticket));
}
@@ -164,10 +170,6 @@ public:
MONGO_UNREACHABLE;
}
- void setNextMessage(Message&& message) {
- _nextMessage = std::move(message);
- }
-
void setSSM(ServiceStateMachine* ssm) {
_ssm = ssm;
}
@@ -190,14 +192,18 @@ public:
return _ranSource;
}
+ void setWaitHook(stdx::function<void()> hook) {
+ _waitHook = std::move(hook);
+ }
+
private:
bool _lastTicketSource = true;
bool _ranSink = false;
bool _ranSource = false;
- boost::optional<Message> _nextMessage;
FailureMode _nextShouldFail = Nothing;
Message _lastSunk;
ServiceStateMachine* _ssm;
+ stdx::function<void()> _waitHook;
};
Message buildRequest(BSONObj input) {
@@ -206,6 +212,61 @@ Message buildRequest(BSONObj input) {
return builder.finish();
}
+class MockServiceExecutor : public ServiceExecutor {
+public:
+ explicit MockServiceExecutor(ServiceContext* ctx) {}
+
+ using ScheduleHook = stdx::function<bool(Task)>;
+
+ Status start() override {
+ return Status::OK();
+ }
+ Status shutdown(Milliseconds timeout) override {
+ return Status::OK();
+ }
+ Status schedule(Task task, ScheduleFlags flags) override {
+ if (!_scheduleHook) {
+ return Status::OK();
+ } else {
+ return _scheduleHook(std::move(task)) ? Status::OK() : Status{ErrorCodes::InternalError,
+ "Hook returned error!"};
+ }
+ }
+
+ Mode transportMode() const override {
+ return Mode::kSynchronous;
+ }
+
+ void appendStats(BSONObjBuilder* bob) const override {}
+
+ void setScheduleHook(ScheduleHook hook) {
+ _scheduleHook = std::move(hook);
+ }
+
+private:
+ ScheduleHook _scheduleHook;
+};
+
+class SimpleEvent {
+public:
+ void signal() {
+ stdx::unique_lock<stdx::mutex> lk(_mutex);
+ _signaled = true;
+ _cond.notify_one();
+ }
+
+ void wait() {
+ stdx::unique_lock<stdx::mutex> lk(_mutex);
+ _cond.wait(lk, [this] { return _signaled; });
+ _signaled = false;
+ }
+
+private:
+ stdx::mutex _mutex;
+ stdx::condition_variable _cond;
+ bool _signaled = false;
+};
+
using State = ServiceStateMachine::State;
class ServiceStateMachineFixture : public unittest::Test {
@@ -223,7 +284,9 @@ protected:
_sep = sep.get();
sc->setServiceEntryPoint(std::move(sep));
- sc->setServiceExecutor(stdx::make_unique<ServiceExecutorNoop>(sc));
+ auto se = stdx::make_unique<MockServiceExecutor>(sc);
+ _sexec = se.get();
+ sc->setServiceExecutor(std::move(se));
auto tl = stdx::make_unique<MockTL>();
_tl = tl.get();
@@ -236,7 +299,7 @@ protected:
}
void tearDown() override {
- getGlobalServiceContext()->getTransportLayer()->shutdown();
+ _tl->shutdown();
}
void runPingTest(State first, State second);
@@ -244,14 +307,13 @@ protected:
MockTL* _tl;
MockSEP* _sep;
+ MockServiceExecutor* _sexec;
SessionHandle _session;
std::shared_ptr<ServiceStateMachine> _ssm;
bool _ranHandler;
};
void ServiceStateMachineFixture::runPingTest(State first, State second) {
- _tl->setNextMessage(buildRequest(BSON("ping" << 1)));
-
ASSERT_FALSE(haveClient());
ASSERT_EQ(_ssm->state(), State::Created);
log() << "run next";
@@ -329,5 +391,187 @@ TEST_F(ServiceStateMachineFixture, TestSessionCleanupOnDestroy) {
ASSERT_TRUE(hookRan);
}
+// This tests that SSMs that fail to schedule their first task get cleaned up correctly.
+// (i.e. we couldn't create a worker thread after accept()).
+TEST_F(ServiceStateMachineFixture, ScheduleFailureDuringCreateCleanup) {
+ _sexec->setScheduleHook([](auto) { return false; });
+ // Set a cleanup hook so we know that the cleanup hook actually gets run when the session
+ // is destroyed
+ bool hookRan = false;
+ _ssm->setCleanupHook([&hookRan] { hookRan = true; });
+
+ _ssm->start(ServiceStateMachine::Ownership::kOwned);
+ ASSERT_EQ(State::Ended, _ssm->state());
+ ASSERT_EQ(_ssm.use_count(), 1);
+ ASSERT_TRUE(hookRan);
+}
+
+// This tests that calling terminate() actually ends and cleans up the SSM during all the
+// states.
+TEST_F(ServiceStateMachineFixture, TerminateWorksForAllStates) {
+ SimpleEvent hookRan, okayToContinue;
+
+ auto cleanupHook = [&hookRan] {
+ log() << "Cleaning up session";
+ hookRan.signal();
+ };
+
+ // This is a shared hook between the executor/TL that lets us notify the test that the SSM
+ // has reached a certain state and then gets terminated during that state.
+ State waitFor = State::Created;
+ SimpleEvent atDesiredState;
+ auto waitForHook = [this, &waitFor, &atDesiredState, &okayToContinue]() {
+ log() << "Checking for wakeup at " << stateToString(_ssm->state()) << ". Expecting "
+ << stateToString(waitFor);
+ if (_ssm->state() == waitFor) {
+ atDesiredState.signal();
+ okayToContinue.wait();
+ }
+ };
+
+ // This wraps the waitForHook so that schedules always succeed.
+ _sexec->setScheduleHook([waitForHook](auto) {
+ waitForHook();
+ return true;
+ });
+
+ // This just lets us intercept calls to _tl->wait() and terminate during them.
+ _tl->setWaitHook(waitForHook);
+
+ // Run this same test for each state.
+ auto states = {State::Source, State::SourceWait, State::Process, State::SinkWait};
+ for (const auto testState : states) {
+ log() << "Testing termination during " << stateToString(testState);
+
+ // Reset the _ssm to a fresh SSM and reset our tracking variables.
+ _ssm = ServiceStateMachine::create(
+ getGlobalServiceContext(), _tl->createSession(), transport::Mode::kSynchronous);
+ _tl->setSSM(_ssm.get());
+ _ssm->setCleanupHook(cleanupHook);
+
+ waitFor = testState;
+ // This is a dummy thread that just advances the SSM while we track its state/kill it
+ stdx::thread runner([ssm = _ssm] {
+ while (ssm->state() != State::Ended) {
+ ssm->runNext();
+ }
+ });
+
+ // Wait for the SSM to advance to the expected state
+ atDesiredState.wait();
+ log() << "Terminating session at " << stateToString(_ssm->state());
+
+ // Terminate the SSM
+ _ssm->terminate();
+
+ // Notify the waitForHook to continue and end the session
+ okayToContinue.signal();
+
+ // Wait for the SSM to terminate and the thread to end.
+ hookRan.wait();
+ runner.join();
+
+ // Verify that the SSM terminated and is in the correct state
+ ASSERT_EQ(State::Ended, _ssm->state());
+ ASSERT_EQ(_ssm.use_count(), 1);
+ }
+}
+
+// This tests that calling terminate() actually ends and cleans up the SSM during all states, and
+// with schedule() returning an error for each state.
+TEST_F(ServiceStateMachineFixture, TerminateWorksForAllStatesWithScheduleFailure) {
+ // Set a cleanup hook so we know that the cleanup hook actually gets run when the session
+ // is destroyed
+ SimpleEvent hookRan, okayToContinue;
+ bool scheduleFailed = false;
+
+ auto cleanupHook = [&hookRan] {
+ log() << "Cleaning up session";
+ hookRan.signal();
+ };
+
+ // This is a shared hook between the executor/TL that lets us notify the test that the SSM
+ // has reached a certain state and then gets terminated during that state.
+ State waitFor = State::Created;
+ SimpleEvent atDesiredState;
+ auto waitForHook = [this, &waitFor, &scheduleFailed, &okayToContinue, &atDesiredState]() {
+ log() << "Checking for wakeup at " << stateToString(_ssm->state()) << ". Expecting "
+ << stateToString(waitFor);
+ if (_ssm->state() == waitFor) {
+ atDesiredState.signal();
+ okayToContinue.wait();
+ scheduleFailed = true;
+ return false;
+ }
+ return true;
+ };
+
+ _sexec->setScheduleHook([waitForHook](auto) { return waitForHook(); });
+ // This wraps the waitForHook and discards its return status.
+ _tl->setWaitHook([waitForHook] { waitForHook(); });
+
+ auto states = {State::Source, State::SourceWait, State::Process, State::SinkWait};
+ for (const auto testState : states) {
+ log() << "Testing termination during " << stateToString(testState);
+ _ssm = ServiceStateMachine::create(
+ getGlobalServiceContext(), _tl->createSession(), transport::Mode::kSynchronous);
+ _tl->setSSM(_ssm.get());
+ scheduleFailed = false;
+ _ssm->setCleanupHook(cleanupHook);
+
+ waitFor = testState;
+ // This is a dummy thread that just advances the SSM while we track its state/kill it
+ stdx::thread runner([ ssm = _ssm, &scheduleFailed ] {
+ while (ssm->state() != State::Ended && !scheduleFailed) {
+ ssm->runNext();
+ }
+ });
+
+ // Wait for the SSM to advance to the expected state
+ atDesiredState.wait();
+ ASSERT_EQ(_ssm->state(), testState);
+ log() << "Terminating session at " << stateToString(_ssm->state());
+
+ // Terminate the SSM
+ _ssm->terminate();
+
+ // Notify the waitForHook to continue and end the session
+ okayToContinue.signal();
+ hookRan.wait();
+ runner.join();
+
+ // Verify that the SSM terminated and is in the correct state
+ ASSERT_EQ(State::Ended, _ssm->state());
+ ASSERT_EQ(_ssm.use_count(), 1);
+ }
+}
+
+// This makes sure that the SSM can run recursively by forcing the ServiceExecutor to run everything
+// recursively
+TEST_F(ServiceStateMachineFixture, SSMRunsRecursively) {
+ // This lets us force the SSM to only run once. After sinking the first response, the next call
+ // to sourceMessage will return with an error.
+ _tl->setWaitHook([this] {
+ if (_ssm->state() == State::SinkWait) {
+ _tl->setNextFailure();
+ }
+ });
+
+ // The scheduleHook just runs the task, effectively making this a recursive executor.
+ int recursionDepth = 0;
+ _sexec->setScheduleHook([&recursionDepth](auto task) {
+ log() << "running task in executor. depth: " << ++recursionDepth;
+ task();
+ return true;
+ });
+
+ _ssm->runNext();
+ // Check that the SSM actually ran, is ended, and actually ran recursively
+ ASSERT_EQ(recursionDepth, 2);
+ ASSERT_TRUE(_tl->ranSource());
+ ASSERT_TRUE(_tl->ranSink());
+ ASSERT_EQ(_ssm->state(), State::Ended);
+}
+
} // namespace
} // namespace mongo
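The SimpleEvent helper added to the test file above is a small auto-resetting event built from a mutex and condition variable; the tests use it to park one thread until another has driven the state machine to a known state. A standalone usage sketch, written against std:: primitives rather than MongoDB's stdx wrappers:

    #include <condition_variable>
    #include <iostream>
    #include <mutex>
    #include <thread>

    class SimpleEvent {
    public:
        void signal() {
            std::lock_guard<std::mutex> lk(_mutex);
            _signaled = true;
            _cond.notify_one();
        }
        void wait() {
            std::unique_lock<std::mutex> lk(_mutex);
            _cond.wait(lk, [this] { return _signaled; });
            _signaled = false;  // auto-reset so the event can be reused
        }

    private:
        std::mutex _mutex;
        std::condition_variable _cond;
        bool _signaled = false;
    };

    int main() {
        SimpleEvent atDesiredState;
        std::thread worker([&] {
            std::cout << "worker reached the desired state\n";
            atDesiredState.signal();
        });
        atDesiredState.wait();  // blocks until the worker signals
        std::cout << "main observed the state; terminating\n";
        worker.join();
    }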