| author    | Jonathan Reams <jbreams@mongodb.com> | 2017-10-16 16:26:02 -0400 |
|-----------|--------------------------------------|---------------------------|
| committer | Jonathan Reams <jbreams@mongodb.com> | 2017-11-06 17:19:10 -0500 |
| commit    | e7837911c89af144fe012e5063f8ca88c4c66956 (patch) | |
| tree      | eb2c141aa289033a400ede246e3478083e5e81bf /src | |
| parent    | dc712619bf21f7c577f28b3f8281bf4c25362511 (diff) | |
| download  | mongo-e7837911c89af144fe012e5063f8ca88c4c66956.tar.gz | |
SERVER-31538 Ensure the ServiceStateMachine always gets cleaned up on error/termination
Diffstat (limited to 'src')
| -rw-r--r-- | src/mongo/transport/service_entry_point_impl.cpp   |  10 |
| -rw-r--r-- | src/mongo/transport/service_state_machine.cpp      | 313 |
| -rw-r--r-- | src/mongo/transport/service_state_machine.h        |  85 |
| -rw-r--r-- | src/mongo/transport/service_state_machine_test.cpp | 276 |
4 files changed, 497 insertions, 187 deletions
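The core of this change is an explicit ownership model for the ServiceStateMachine: a ThreadGuard claims the SSM with a compare-and-swap before touching its private state, and a synchronous (thread-per-connection) transport marks the claim as static so the Client and thread name stay swapped in for the life of the connection. The standalone sketch below is not part of this commit; the `Ownership` values, `ThreadGuard`, and `markStaticOwnership()` names come from the diff, while the simplified `StateMachine` struct and `main()` are purely illustrative.

```cpp
// Standalone illustration of the claim/release pattern introduced by this
// commit's ThreadGuard rewrite (simplified; see the real diff below).
#include <atomic>
#include <cassert>
#include <iostream>

enum class Ownership { kUnowned, kOwned, kStatic };

struct StateMachine {
    std::atomic<Ownership> owned{Ownership::kUnowned};
};

class ThreadGuard {
public:
    explicit ThreadGuard(StateMachine* sm) : _sm(sm) {
        auto expected = Ownership::kUnowned;
        // Claim the state machine; if it is statically owned, this thread
        // already holds it and nothing needs to be swapped in.
        if (!_sm->owned.compare_exchange_strong(expected, Ownership::kOwned)) {
            assert(expected == Ownership::kStatic);
        }
    }

    // A thread-per-connection transport keeps ownership permanently.
    void markStaticOwnership() {
        _sm->owned.store(Ownership::kStatic);
    }

    ~ThreadGuard() {
        // Only a regular (kOwned) claim is released; a static claim persists.
        auto expected = Ownership::kOwned;
        _sm->owned.compare_exchange_strong(expected, Ownership::kUnowned);
    }

private:
    StateMachine* _sm;
};

int main() {
    StateMachine sm;
    {
        ThreadGuard guard(&sm);       // claim for one state transition
    }                                 // released: back to kUnowned
    std::cout << static_cast<int>(sm.owned.load()) << '\n';  // prints 0

    {
        ThreadGuard guard(&sm);
        guard.markStaticOwnership();  // synchronous transport: keep the claim
    }
    std::cout << static_cast<int>(sm.owned.load()) << '\n';  // prints 2
}
```

In the actual change, ServiceEntryPointImpl picks `Ownership::kStatic` when the service executor runs in `transport::Mode::kSynchronous` and `Ownership::kOwned` otherwise, then calls `ssm->start(ownership)` in place of the old `scheduleNext()`.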
diff --git a/src/mongo/transport/service_entry_point_impl.cpp b/src/mongo/transport/service_entry_point_impl.cpp index df02f234f11..adbdc48a6ec 100644 --- a/src/mongo/transport/service_entry_point_impl.cpp +++ b/src/mongo/transport/service_entry_point_impl.cpp @@ -86,9 +86,9 @@ void ServiceEntryPointImpl::startSession(transport::SessionHandle session) { const bool quiet = serverGlobalParams.quiet.load(); size_t connectionCount; + auto transportMode = _svcCtx->getServiceExecutor()->transportMode(); - auto ssm = ServiceStateMachine::create( - _svcCtx, session, _svcCtx->getServiceExecutor()->transportMode()); + auto ssm = ServiceStateMachine::create(_svcCtx, session, transportMode); { stdx::lock_guard<decltype(_sessionsMutex)> lk(_sessionsMutex); connectionCount = _sessions.size() + 1; @@ -129,7 +129,11 @@ void ServiceEntryPointImpl::startSession(transport::SessionHandle session) { }); - ssm->scheduleNext(); + auto ownership = ServiceStateMachine::Ownership::kOwned; + if (transportMode == transport::Mode::kSynchronous) { + ownership = ServiceStateMachine::Ownership::kStatic; + } + ssm->start(ownership); } void ServiceEntryPointImpl::endAllSessions(transport::Session::TagMask tags) { diff --git a/src/mongo/transport/service_state_machine.cpp b/src/mongo/transport/service_state_machine.cpp index 47764de2444..263c3c06b2d 100644 --- a/src/mongo/transport/service_state_machine.cpp +++ b/src/mongo/transport/service_state_machine.cpp @@ -32,6 +32,7 @@ #include "mongo/transport/service_state_machine.h" +#include "mongo/config.h" #include "mongo/db/client.h" #include "mongo/db/dbmessage.h" #include "mongo/db/stats/counters.h" @@ -89,91 +90,128 @@ bool setExhaustMessage(Message* m, const DbResponse& dbresponse) { } // namespace -using transport::TransportLayer; using transport::ServiceExecutor; +using transport::TransportLayer; /* * This class wraps up the logic for swapping/unswapping the Client during runNext(). + * + * In debug builds this also ensures that only one thread is working on the SSM at once. */ class ServiceStateMachine::ThreadGuard { ThreadGuard(ThreadGuard&) = delete; ThreadGuard& operator=(ThreadGuard&) = delete; public: - explicit ThreadGuard(ServiceStateMachine* ssm) - : _ssm{ssm}, - _haveTakenOwnership{!_ssm->_isOwned.test_and_set()}, - _oldThreadName{getThreadName().toString()} { - const auto currentOwningThread = _ssm->_currentOwningThread.load(); - const auto currentThreadId = stdx::this_thread::get_id(); - - // If this is true, then we are the "owner" of the Client and we should swap the - // client/thread name before doing any work. 
- if (_haveTakenOwnership) { - _ssm->_currentOwningThread.store(currentThreadId); + explicit ThreadGuard(ServiceStateMachine* ssm) : _ssm{ssm} { + auto owned = _ssm->_owned.compareAndSwap(Ownership::kUnowned, Ownership::kOwned); + if (owned == Ownership::kStatic) { + dassert(haveClient()); + dassert(Client::getCurrent() == _ssm->_dbClientPtr); + _haveTakenOwnership = true; + return; + } + +#ifdef MONGO_CONFIG_DEBUG_BUILD + invariant(owned == Ownership::kUnowned); + _ssm->_owningThread.store(stdx::this_thread::get_id()); +#endif - // Set up the thread name + // Set up the thread name + auto oldThreadName = getThreadName(); + if (oldThreadName != _ssm->_threadName) { + _ssm->_oldThreadName = getThreadName().toString(); setThreadName(_ssm->_threadName); + } + + // Swap the current Client so calls to cc() work as expected + Client::setCurrent(std::move(_ssm->_dbClient)); + _haveTakenOwnership = true; + } - // These are sanity checks to make sure that the Client is what we expect it to be - invariant(!haveClient()); - invariant(_ssm->_dbClient.get() == _ssm->_dbClientPtr); + // Constructing from a moved ThreadGuard invalidates the other thread guard. + ThreadGuard(ThreadGuard&& other) + : _ssm(other._ssm), _haveTakenOwnership(other._haveTakenOwnership) { + other._haveTakenOwnership = false; + } - // Swap the current Client so calls to cc() work as expected - Client::setCurrent(std::move(_ssm->_dbClient)); - } else if (currentOwningThread != currentThreadId) { - // If the currentOwningThread does not equal the currentThreadId, then another thread - // currently "owns" the Client and we should reschedule ourself. - _okayToRunNext = false; + ThreadGuard& operator=(ThreadGuard&& other) { + if (this != &other) { + _ssm = other._ssm; + _haveTakenOwnership = other._haveTakenOwnership; + other._haveTakenOwnership = false; } - } + return *this; + }; + + ThreadGuard() = delete; ~ThreadGuard() { - // If we are not the owner of the SSM, then do nothing. Something higher up the call stack - // will have to clean up. - if (!_haveTakenOwnership) - return; + if (_haveTakenOwnership) + release(); + } - // If the session has ended, then assume that it's unsafe to do anything but call the - // cleanup hook. + explicit operator bool() const { +#ifdef MONGO_CONFIG_DEBUG_BUILD + if (_haveTakenOwnership) { + invariant(_ssm->_owned.load() != Ownership::kUnowned); + invariant(_ssm->_owningThread.load() == stdx::this_thread::get_id()); + return true; + } else { + return false; + } +#else + return _haveTakenOwnership; +#endif + } + + void markStaticOwnership() { + dassert(static_cast<bool>(*this)); + _ssm->_owned.store(Ownership::kStatic); + } + + void release() { + auto owned = _ssm->_owned.load(); + +#ifdef MONGO_CONFIG_DEBUG_BUILD + dassert(_haveTakenOwnership); + dassert(owned != Ownership::kUnowned); + dassert(_ssm->_owningThread.load() == stdx::this_thread::get_id()); +#endif + if (owned != Ownership::kStatic) { + if (haveClient()) { + _ssm->_dbClient = Client::releaseCurrent(); + } + + if (!_ssm->_oldThreadName.empty()) { + setThreadName(_ssm->_oldThreadName); + } + } + + // If the session has ended, then it's unsafe to do anything but call the cleanup hook. if (_ssm->state() == State::Ended) { - // The cleanup hook may change as soon as we unlock the mutex, so move it out of the - // ssm before unlocking the lock. + // The cleanup hook gets moved out of _ssm->_cleanupHook so that it can only be called + // once. 
auto cleanupHook = std::move(_ssm->_cleanupHook); if (cleanupHook) cleanupHook(); + // It's very important that the Guard returns here and that the SSM's state does not + // get modified in any way after the cleanup hook is called. return; } - // Otherwise swap thread locals and thread names back into the SSM so its ready for the - // next run. - if (haveClient()) { - _ssm->_dbClient = Client::releaseCurrent(); + _haveTakenOwnership = false; + // If owned != Ownership::kOwned here then it can only equal Ownership::kStatic and we + // should just return + if (owned == Ownership::kOwned) { + _ssm->_owned.store(Ownership::kUnowned); } - setThreadName(_oldThreadName); - _ssm->_isOwned.clear(); - } - - // This bool operator reflects whether the ThreadGuard was able to take ownership of the thread - // either higher up the call chain, or in this call. If this returns false, then it is not safe - // to assume the thread has been setup correctly, or that any mutable state of the SSM is safe - // to access except for the current _state value. - explicit operator bool() const { - return _okayToRunNext; - } - - // Returns whether the thread guard is the owner of the SSM's state or not. Callers can use this - // to determine whether their callchain is recursive. - bool isOwner() const { - return _haveTakenOwnership; } private: ServiceStateMachine* _ssm; - bool _haveTakenOwnership; - const std::string _oldThreadName; - bool _okayToRunNext = true; + bool _haveTakenOwnership = false; }; std::shared_ptr<ServiceStateMachine> ServiceStateMachine::create(ServiceContext* svcContext, @@ -192,27 +230,52 @@ ServiceStateMachine::ServiceStateMachine(ServiceContext* svcContext, _sessionHandle(session), _dbClient{svcContext->makeClient("conn", std::move(session))}, _dbClientPtr{_dbClient.get()}, - _threadName{str::stream() << "conn" << _session()->id()}, - _currentOwningThread{stdx::this_thread::get_id()} {} + _threadName{str::stream() << "conn" << _session()->id()} {} const transport::SessionHandle& ServiceStateMachine::_session() const { return _sessionHandle; } +void ServiceStateMachine::_sourceMessage(ThreadGuard guard) { + invariant(_inMessage.empty()); + auto ticket = _session()->sourceMessage(&_inMessage); + + _state.store(State::SourceWait); + guard.release(); + + if (_transportMode == transport::Mode::kSynchronous) { + _sourceCallback([this](auto ticket) { + MONGO_IDLE_THREAD_BLOCK; + return _session()->getTransportLayer()->wait(std::move(ticket)); + }(std::move(ticket))); + } else if (_transportMode == transport::Mode::kAsynchronous) { + _session()->getTransportLayer()->asyncWait( + std::move(ticket), [this](Status status) { _sourceCallback(status); }); + } +} + +void ServiceStateMachine::_sinkMessage(ThreadGuard guard, Message toSink) { + // Sink our response to the client + auto ticket = _session()->sinkMessage(toSink); + + _state.store(State::SinkWait); + guard.release(); + + if (_transportMode == transport::Mode::kSynchronous) { + _sinkCallback(_session()->getTransportLayer()->wait(std::move(ticket))); + } else if (_transportMode == transport::Mode::kAsynchronous) { + _session()->getTransportLayer()->asyncWait( + std::move(ticket), [this](Status status) { _sinkCallback(status); }); + } +} + void ServiceStateMachine::_sourceCallback(Status status) { // The first thing to do is create a ThreadGuard which will take ownership of the SSM in this // thread. 
ThreadGuard guard(this); - // If the guard wasn't able to take ownership of the thread, then reschedule this call to - // runNext() so that this thread can do other useful work with its timeslice instead of going - // to sleep while waiting for the SSM to be released. - if (!guard) { - return _scheduleFunc([this, status] { _sourceCallback(status); }, - ServiceExecutor::kDeferredTask); - } // Make sure we just called sourceMessage(); - invariant(state() == State::SourceWait); + dassert(state() == State::SourceWait); auto remote = _session()->remote(); if (status.isOK()) { @@ -225,7 +288,7 @@ void ServiceStateMachine::_sourceCallback(Status status) { // If this callback doesn't own the ThreadGuard, then we're being called recursively, // and the executor shouldn't start a new thread to process the message - it can use this // one just after this returns. - return scheduleNext(ServiceExecutor::kMayRecurse); + return _scheduleNextWithGuard(std::move(guard), ServiceExecutor::kMayRecurse); } else if (ErrorCodes::isInterruption(status.code()) || ErrorCodes::isNetworkError(status.code())) { LOG(2) << "Session from " << remote << " encountered a network error during SourceMessage"; @@ -242,22 +305,15 @@ void ServiceStateMachine::_sourceCallback(Status status) { // There was an error receiving a message from the client and we've already printed the error // so call runNextInGuard() to clean up the session without waiting. - _runNextInGuard(guard); + _runNextInGuard(std::move(guard)); } void ServiceStateMachine::_sinkCallback(Status status) { // The first thing to do is create a ThreadGuard which will take ownership of the SSM in this // thread. ThreadGuard guard(this); - // If the guard wasn't able to take ownership of the thread, then reschedule this call to - // runNext() so that this thread can do other useful work with its timeslice instead of going - // to sleep while waiting for the SSM to be released. - if (!guard) { - return _scheduleFunc([this, status] { _sinkCallback(status); }, - ServiceExecutor::kDeferredTask); - } - invariant(state() == State::SinkWait); + dassert(state() == State::SinkWait); // If there was an error sinking the message to the client, then we should print an error and // end the session. No need to unwind the stack, so this will runNextInGuard() and return. @@ -268,22 +324,19 @@ void ServiceStateMachine::_sinkCallback(Status status) { log() << "Error sending response to client: " << status << ". Ending connection from " << _session()->remote() << " (connection id: " << _session()->id() << ")"; _state.store(State::EndSession); - return _runNextInGuard(guard); + return _runNextInGuard(std::move(guard)); } else if (_inExhaust) { _state.store(State::Process); } else { _state.store(State::Source); } - return scheduleNext(ServiceExecutor::kDeferredTask | ServiceExecutor::kMayYieldBeforeSchedule); + return _scheduleNextWithGuard(std::move(guard), + ServiceExecutor::kDeferredTask | + ServiceExecutor::kMayYieldBeforeSchedule); } -void ServiceStateMachine::_processMessage(ThreadGuard& guard) { - // This may have been called just after a failure to source a message, in which case this - // should return early so the session can be cleaned up. 
- if (state() != State::Process) { - return; - } +void ServiceStateMachine::_processMessage(ThreadGuard guard) { invariant(!_inMessage.empty()); auto& compressorMgr = MessageCompressorManager::forSession(_session()); @@ -332,42 +385,22 @@ void ServiceStateMachine::_processMessage(ThreadGuard& guard) { uassertStatusOK(swm.getStatus()); toSink = swm.getValue(); } + _sinkMessage(std::move(guard), std::move(toSink)); - // Sink our response to the client - auto ticket = _session()->sinkMessage(toSink); - - _state.store(State::SinkWait); - if (_transportMode == transport::Mode::kSynchronous) { - _sinkCallback(_session()->getTransportLayer()->wait(std::move(ticket))); - } else if (_transportMode == transport::Mode::kAsynchronous) { - _session()->getTransportLayer()->asyncWait( - std::move(ticket), [this](Status status) { _sinkCallback(status); }); - } else { - MONGO_UNREACHABLE; - } } else { _state.store(State::Source); _inMessage.reset(); - return scheduleNext(ServiceExecutor::kDeferredTask); + return _scheduleNextWithGuard(std::move(guard), ServiceExecutor::kDeferredTask); } } void ServiceStateMachine::runNext() { - // The first thing to do is create a ThreadGuard which will take ownership of the SSM in this - // thread. - ThreadGuard guard(this); - // If the guard wasn't able to take ownership of the thread, then reschedule this call to - // runNext() so that this thread can do other useful work with its timeslice instead of going - // to sleep while waiting for the SSM to be released. - if (!guard) { - return scheduleNext(ServiceExecutor::kDeferredTask); - } - return _runNextInGuard(guard); + return _runNextInGuard(ThreadGuard(this)); } -void ServiceStateMachine::_runNextInGuard(ThreadGuard& guard) { +void ServiceStateMachine::_runNextInGuard(ThreadGuard guard) { auto curState = state(); - invariant(curState != State::Ended); + dassert(curState != State::Ended); // If this is the first run of the SSM, then update its state to Source if (curState == State::Created) { @@ -376,29 +409,14 @@ void ServiceStateMachine::_runNextInGuard(ThreadGuard& guard) { } // Make sure the current Client got set correctly - invariant(Client::getCurrent() == _dbClientPtr); + dassert(Client::getCurrent() == _dbClientPtr); try { switch (curState) { - case State::Source: { - invariant(_inMessage.empty()); - - auto ticket = _session()->sourceMessage(&_inMessage); - _state.store(State::SourceWait); - if (_transportMode == transport::Mode::kSynchronous) { - _sourceCallback([this](auto ticket) { - MONGO_IDLE_THREAD_BLOCK; - return _session()->getTransportLayer()->wait(std::move(ticket)); - }(std::move(ticket))); - } else if (_transportMode == transport::Mode::kAsynchronous) { - _session()->getTransportLayer()->asyncWait( - std::move(ticket), [this](Status status) { _sourceCallback(status); }); - } else { - MONGO_UNREACHABLE; - } + case State::Source: + _sourceMessage(std::move(guard)); break; - } case State::Process: - _processMessage(guard); + _processMessage(std::move(guard)); break; case State::EndSession: // This will get handled below in an if statement. 
That way if an error occurs @@ -409,7 +427,10 @@ void ServiceStateMachine::_runNextInGuard(ThreadGuard& guard) { } if (state() == State::EndSession) { - _cleanupSession(guard); + if (!guard) { + guard = ThreadGuard(this); + } + _cleanupSession(std::move(guard)); } return; @@ -421,12 +442,40 @@ void ServiceStateMachine::_runNextInGuard(ThreadGuard& guard) { quickExit(EXIT_UNCAUGHT); } + if (!guard) { + guard = ThreadGuard(this); + } _state.store(State::EndSession); - _cleanupSession(guard); + _cleanupSession(std::move(guard)); } -void ServiceStateMachine::scheduleNext(ServiceExecutor::ScheduleFlags flags) { - _scheduleFunc([this] { runNext(); }, flags); +void ServiceStateMachine::start(Ownership ownershipModel) { + _scheduleNextWithGuard( + ThreadGuard(this), transport::ServiceExecutor::kEmptyFlags, ownershipModel); +} + +void ServiceStateMachine::_scheduleNextWithGuard(ThreadGuard guard, + transport::ServiceExecutor::ScheduleFlags flags, + Ownership ownershipModel) { + auto func = [ ssm = shared_from_this(), ownershipModel ] { + ThreadGuard guard(ssm.get()); + if (ownershipModel == Ownership::kStatic) + guard.markStaticOwnership(); + ssm->_runNextInGuard(std::move(guard)); + }; + guard.release(); + Status status = _serviceContext->getServiceExecutor()->schedule(std::move(func), flags); + if (status.isOK()) { + return; + } + + // We've had an error, reacquire the ThreadGuard and destroy the SSM + ThreadGuard terminateGuard(this); + + // The service executor failed to schedule the task. This could for example be that we failed + // to start a worker thread. Terminate this connection to leave the system in a valid state. + _terminateAndLogIfError(status); + _cleanupSession(std::move(terminateGuard)); } void ServiceStateMachine::terminate() { @@ -449,7 +498,7 @@ void ServiceStateMachine::terminateIfTagsDontMatch(transport::Session::TagMask t return; } - _session()->getTransportLayer()->end(_session()); + terminate(); } void ServiceStateMachine::setCleanupHook(stdx::function<void()> hook) { @@ -468,7 +517,7 @@ void ServiceStateMachine::_terminateAndLogIfError(Status status) { } } -void ServiceStateMachine::_cleanupSession(ThreadGuard& guard) { +void ServiceStateMachine::_cleanupSession(ThreadGuard guard) { _state.store(State::Ended); _inMessage.reset(); diff --git a/src/mongo/transport/service_state_machine.h b/src/mongo/transport/service_state_machine.h index 35812b86940..60c6981acaf 100644 --- a/src/mongo/transport/service_state_machine.h +++ b/src/mongo/transport/service_state_machine.h @@ -31,6 +31,7 @@ #include <atomic> #include "mongo/base/status.h" +#include "mongo/config.h" #include "mongo/db/service_context.h" #include "mongo/platform/atomic_word.h" #include "mongo/stdx/functional.h" @@ -73,12 +74,12 @@ public: transport::Mode transportMode); /* - * Any state may transition to EndSession in case of an error, otherwise the valid state - * transitions are: - * Source -> SourceWait -> Process -> SinkWait -> Source (standard RPC) - * Source -> SourceWait -> Process -> SinkWait -> Process -> SinkWait ... (exhaust) - * Source -> SourceWait -> Process -> Source (fire-and-forget) - */ + * Any state may transition to EndSession in case of an error, otherwise the valid state + * transitions are: + * Source -> SourceWait -> Process -> SinkWait -> Source (standard RPC) + * Source -> SourceWait -> Process -> SinkWait -> Process -> SinkWait ... 
(exhaust) + * Source -> SourceWait -> Process -> Source (fire-and-forget) + */ enum class State { Created, // The session has been created, but no operations have been performed yet Source, // Request a new Message from the network to handle @@ -91,6 +92,18 @@ public: }; /* + * When start() is called with Ownership::kOwned, the SSM will swap the Client/thread name + * whenever it runs a stage of the state machine, and then unswap them out when leaving the SSM. + * + * With Ownership::kStatic, it will assume that the SSM will only ever be run from one thread, + * and that thread will not be used for other SSM's. It will swap in the Client/thread name + * for the first run and leave them in place. + * + * kUnowned is used internally to mark that the SSM is inactive. + */ + enum class Ownership { kUnowned, kOwned, kStatic }; + + /* * runNext() will run the current state of the state machine. It also handles all the error * handling and state management for requests. * @@ -104,14 +117,12 @@ public: void runNext(); /* - * scheduleNext() schedules a call to runNext() in the future. This will be implemented with - * an async TransportLayer. + * start() schedules a call to runNext() in the future. * * It is guaranteed to unwind the stack, and not call runNext() recursively, but is not - * guaranteed that runNext() will run after this returns. + * guaranteed that runNext() will run after this return */ - void scheduleNext( - transport::ServiceExecutor::ScheduleFlags flags = transport::ServiceExecutor::kEmptyFlags); + void start(Ownership ownershipModel); /* * Gets the current state of connection for testing/diagnostic purposes. @@ -147,29 +158,21 @@ private: friend class ThreadGuard; /* - * Terminates the associated transport Session if status indicate error. - * - * This will not block on the session terminating cleaning itself up, it returns immediately. - */ + * Terminates the associated transport Session if status indicate error. + * + * This will not block on the session terminating cleaning itself up, it returns immediately. + */ void _terminateAndLogIfError(Status status); /* - * This and scheduleFunc() are helper functions to schedule tasks on the serviceExecutor - * while maintaining a shared_ptr copy to anchor the lifetime of the SSM while waiting for - * callbacks to run. - */ - template <typename Func> - void _scheduleFunc(Func&& func, transport::ServiceExecutor::ScheduleFlags flags) { - Status status = _serviceContext->getServiceExecutor()->schedule( - [ func = std::move(func), anchor = shared_from_this() ] { func(); }, flags); - if (!status.isOK()) { - // The service executor failed to schedule the task - // This could for example be that we failed to start - // a worker thread. Terminate this connection to - // leave the system in a valid state. - _terminateAndLogIfError(status); - } - } + * This is a helper function to schedule tasks on the serviceExecutor maintaining a shared_ptr + * copy to anchor the lifetime of the SSM while waiting for callbacks to run. + * + * If scheduling the function fails, the SSM will be terminated and cleaned up immediately + */ + void _scheduleNextWithGuard(ThreadGuard guard, + transport::ServiceExecutor::ScheduleFlags flags, + Ownership ownershipModel = Ownership::kOwned); /* * Gets the transport::Session associated with this connection @@ -182,13 +185,13 @@ private: * runNext() and already own a ThreadGuard, they should call this with that guard as the * argument. 
*/ - void _runNextInGuard(ThreadGuard& guard); + void _runNextInGuard(ThreadGuard guard); /* * This function actually calls into the database and processes a request. It's broken out * into its own inline function for better readability. */ - inline void _processMessage(ThreadGuard& guard); + inline void _processMessage(ThreadGuard guard); /* * These get called by the TransportLayer when requested network I/O has completed. @@ -197,9 +200,16 @@ private: void _sinkCallback(Status status); /* + * Source/Sink message from the TransportLayer. These will invalidate the ThreadGuard just + * before waiting on the TL. + */ + void _sourceMessage(ThreadGuard guard); + void _sinkMessage(ThreadGuard guard, Message toSink); + + /* * Releases all the resources associated with the session and call the cleanupHook. */ - void _cleanupSession(ThreadGuard& guard); + void _cleanupSession(ThreadGuard guard); AtomicWord<State> _state{State::Created}; @@ -218,8 +228,11 @@ private: boost::optional<MessageCompressorId> _compressorId; Message _inMessage; - AtomicWord<stdx::thread::id> _currentOwningThread; - std::atomic_flag _isOwned = ATOMIC_FLAG_INIT; // NOLINT + AtomicWord<Ownership> _owned{Ownership::kUnowned}; +#if MONGO_CONFIG_DEBUG_BUILD + AtomicWord<stdx::thread::id> _owningThread; +#endif + std::string _oldThreadName; }; template <typename T> diff --git a/src/mongo/transport/service_state_machine_test.cpp b/src/mongo/transport/service_state_machine_test.cpp index 646431580f7..0e4ca8c15c1 100644 --- a/src/mongo/transport/service_state_machine_test.cpp +++ b/src/mongo/transport/service_state_machine_test.cpp @@ -39,7 +39,7 @@ #include "mongo/transport/mock_session.h" #include "mongo/transport/mock_ticket.h" #include "mongo/transport/service_entry_point.h" -#include "mongo/transport/service_executor_noop.h" +#include "mongo/transport/service_executor.h" #include "mongo/transport/service_state_machine.h" #include "mongo/transport/transport_layer_mock.h" #include "mongo/unittest/unittest.h" @@ -51,6 +51,11 @@ namespace mongo { namespace { +inline std::string stateToString(ServiceStateMachine::State state) { + std::string ret = str::stream() << state; + return ret; +} + class MockSEP : public ServiceEntryPoint { public: virtual ~MockSEP() = default; @@ -122,9 +127,9 @@ public: return TransportLayer::TicketSessionClosedStatus; } - if (_nextMessage) { - *message = *_nextMessage; - } + OpMsgBuilder builder; + builder.setBody(BSON("ping" << 1)); + *message = builder.finish(); return TransportLayerMock::sourceMessage(session, message, expiration); } @@ -154,9 +159,10 @@ public: ASSERT_EQ(_ssm->state(), _lastTicketSource ? ServiceStateMachine::State::SourceWait : ServiceStateMachine::State::SinkWait); - std::stringstream ss; - ss << _ssm->state(); - log() << "In wait. ssm state: " << ss.str(); + + log() << "In wait. 
ssm state: " << stateToString(_ssm->state()); + if (_waitHook) + _waitHook(); return TransportLayerMock::wait(std::move(ticket)); } @@ -164,10 +170,6 @@ public: MONGO_UNREACHABLE; } - void setNextMessage(Message&& message) { - _nextMessage = std::move(message); - } - void setSSM(ServiceStateMachine* ssm) { _ssm = ssm; } @@ -190,14 +192,18 @@ public: return _ranSource; } + void setWaitHook(stdx::function<void()> hook) { + _waitHook = std::move(hook); + } + private: bool _lastTicketSource = true; bool _ranSink = false; bool _ranSource = false; - boost::optional<Message> _nextMessage; FailureMode _nextShouldFail = Nothing; Message _lastSunk; ServiceStateMachine* _ssm; + stdx::function<void()> _waitHook; }; Message buildRequest(BSONObj input) { @@ -206,6 +212,61 @@ Message buildRequest(BSONObj input) { return builder.finish(); } +class MockServiceExecutor : public ServiceExecutor { +public: + explicit MockServiceExecutor(ServiceContext* ctx) {} + + using ScheduleHook = stdx::function<bool(Task)>; + + Status start() override { + return Status::OK(); + } + Status shutdown(Milliseconds timeout) override { + return Status::OK(); + } + Status schedule(Task task, ScheduleFlags flags) override { + if (!_scheduleHook) { + return Status::OK(); + } else { + return _scheduleHook(std::move(task)) ? Status::OK() : Status{ErrorCodes::InternalError, + "Hook returned error!"}; + } + } + + Mode transportMode() const override { + return Mode::kSynchronous; + } + + void appendStats(BSONObjBuilder* bob) const override {} + + void setScheduleHook(ScheduleHook hook) { + _scheduleHook = std::move(hook); + } + +private: + ScheduleHook _scheduleHook; +}; + +class SimpleEvent { +public: + void signal() { + stdx::unique_lock<stdx::mutex> lk(_mutex); + _signaled = true; + _cond.notify_one(); + } + + void wait() { + stdx::unique_lock<stdx::mutex> lk(_mutex); + _cond.wait(lk, [this] { return _signaled; }); + _signaled = false; + } + +private: + stdx::mutex _mutex; + stdx::condition_variable _cond; + bool _signaled = false; +}; + using State = ServiceStateMachine::State; class ServiceStateMachineFixture : public unittest::Test { @@ -223,7 +284,9 @@ protected: _sep = sep.get(); sc->setServiceEntryPoint(std::move(sep)); - sc->setServiceExecutor(stdx::make_unique<ServiceExecutorNoop>(sc)); + auto se = stdx::make_unique<MockServiceExecutor>(sc); + _sexec = se.get(); + sc->setServiceExecutor(std::move(se)); auto tl = stdx::make_unique<MockTL>(); _tl = tl.get(); @@ -236,7 +299,7 @@ protected: } void tearDown() override { - getGlobalServiceContext()->getTransportLayer()->shutdown(); + _tl->shutdown(); } void runPingTest(State first, State second); @@ -244,14 +307,13 @@ protected: MockTL* _tl; MockSEP* _sep; + MockServiceExecutor* _sexec; SessionHandle _session; std::shared_ptr<ServiceStateMachine> _ssm; bool _ranHandler; }; void ServiceStateMachineFixture::runPingTest(State first, State second) { - _tl->setNextMessage(buildRequest(BSON("ping" << 1))); - ASSERT_FALSE(haveClient()); ASSERT_EQ(_ssm->state(), State::Created); log() << "run next"; @@ -329,5 +391,187 @@ TEST_F(ServiceStateMachineFixture, TestSessionCleanupOnDestroy) { ASSERT_TRUE(hookRan); } +// This tests that SSMs that fail to schedule their first task get cleaned up correctly. +// (i.e. we couldn't create a worker thread after accept()). 
+TEST_F(ServiceStateMachineFixture, ScheduleFailureDuringCreateCleanup) { + _sexec->setScheduleHook([](auto) { return false; }); + // Set a cleanup hook so we know that the cleanup hook actually gets run when the session + // is destroyed + bool hookRan = false; + _ssm->setCleanupHook([&hookRan] { hookRan = true; }); + + _ssm->start(ServiceStateMachine::Ownership::kOwned); + ASSERT_EQ(State::Ended, _ssm->state()); + ASSERT_EQ(_ssm.use_count(), 1); + ASSERT_TRUE(hookRan); +} + +// This tests that calling terminate() actually ends and cleans up the SSM during all the +// states. +TEST_F(ServiceStateMachineFixture, TerminateWorksForAllStates) { + SimpleEvent hookRan, okayToContinue; + + auto cleanupHook = [&hookRan] { + log() << "Cleaning up session"; + hookRan.signal(); + }; + + // This is a shared hook between the executor/TL that lets us notify the test that the SSM + // has reached a certain state and then gets terminated during that state. + State waitFor = State::Created; + SimpleEvent atDesiredState; + auto waitForHook = [this, &waitFor, &atDesiredState, &okayToContinue]() { + log() << "Checking for wakeup at " << stateToString(_ssm->state()) << ". Expecting " + << stateToString(waitFor); + if (_ssm->state() == waitFor) { + atDesiredState.signal(); + okayToContinue.wait(); + } + }; + + // This wraps the waitForHook so that schedules always succeed. + _sexec->setScheduleHook([waitForHook](auto) { + waitForHook(); + return true; + }); + + // This just lets us intercept calls to _tl->wait() and terminate during them. + _tl->setWaitHook(waitForHook); + + // Run this same test for each state. + auto states = {State::Source, State::SourceWait, State::Process, State::SinkWait}; + for (const auto testState : states) { + log() << "Testing termination during " << stateToString(testState); + + // Reset the _ssm to a fresh SSM and reset our tracking variables. + _ssm = ServiceStateMachine::create( + getGlobalServiceContext(), _tl->createSession(), transport::Mode::kSynchronous); + _tl->setSSM(_ssm.get()); + _ssm->setCleanupHook(cleanupHook); + + waitFor = testState; + // This is a dummy thread that just advances the SSM while we track its state/kill it + stdx::thread runner([ssm = _ssm] { + while (ssm->state() != State::Ended) { + ssm->runNext(); + } + }); + + // Wait for the SSM to advance to the expected state + atDesiredState.wait(); + log() << "Terminating session at " << stateToString(_ssm->state()); + + // Terminate the SSM + _ssm->terminate(); + + // Notify the waitForHook to continue and end the session + okayToContinue.signal(); + + // Wait for the SSM to terminate and the thread to end. + hookRan.wait(); + runner.join(); + + // Verify that the SSM terminated and is in the correct state + ASSERT_EQ(State::Ended, _ssm->state()); + ASSERT_EQ(_ssm.use_count(), 1); + } +} + +// This tests that calling terminate() actually ends and cleans up the SSM during all states, and +// with schedule() returning an error for each state. +TEST_F(ServiceStateMachineFixture, TerminateWorksForAllStatesWithScheduleFailure) { + // Set a cleanup hook so we know that the cleanup hook actually gets run when the session + // is destroyed + SimpleEvent hookRan, okayToContinue; + bool scheduleFailed = false; + + auto cleanupHook = [&hookRan] { + log() << "Cleaning up session"; + hookRan.signal(); + }; + + // This is a shared hook between the executor/TL that lets us notify the test that the SSM + // has reached a certain state and then gets terminated during that state. 
+ State waitFor = State::Created; + SimpleEvent atDesiredState; + auto waitForHook = [this, &waitFor, &scheduleFailed, &okayToContinue, &atDesiredState]() { + log() << "Checking for wakeup at " << stateToString(_ssm->state()) << ". Expecting " + << stateToString(waitFor); + if (_ssm->state() == waitFor) { + atDesiredState.signal(); + okayToContinue.wait(); + scheduleFailed = true; + return false; + } + return true; + }; + + _sexec->setScheduleHook([waitForHook](auto) { return waitForHook(); }); + // This wraps the waitForHook and discards its return status. + _tl->setWaitHook([waitForHook] { waitForHook(); }); + + auto states = {State::Source, State::SourceWait, State::Process, State::SinkWait}; + for (const auto testState : states) { + log() << "Testing termination during " << stateToString(testState); + _ssm = ServiceStateMachine::create( + getGlobalServiceContext(), _tl->createSession(), transport::Mode::kSynchronous); + _tl->setSSM(_ssm.get()); + scheduleFailed = false; + _ssm->setCleanupHook(cleanupHook); + + waitFor = testState; + // This is a dummy thread that just advances the SSM while we track its state/kill it + stdx::thread runner([ ssm = _ssm, &scheduleFailed ] { + while (ssm->state() != State::Ended && !scheduleFailed) { + ssm->runNext(); + } + }); + + // Wait for the SSM to advance to the expected state + atDesiredState.wait(); + ASSERT_EQ(_ssm->state(), testState); + log() << "Terminating session at " << stateToString(_ssm->state()); + + // Terminate the SSM + _ssm->terminate(); + + // Notify the waitForHook to continue and end the session + okayToContinue.signal(); + hookRan.wait(); + runner.join(); + + // Verify that the SSM terminated and is in the correct state + ASSERT_EQ(State::Ended, _ssm->state()); + ASSERT_EQ(_ssm.use_count(), 1); + } +} + +// This makes sure that the SSM can run recursively by forcing the ServiceExecutor to run everything +// recursively +TEST_F(ServiceStateMachineFixture, SSMRunsRecursively) { + // This lets us force the SSM to only run once. After sinking the first response, the next call + // to sourceMessage will return with an error. + _tl->setWaitHook([this] { + if (_ssm->state() == State::SinkWait) { + _tl->setNextFailure(); + } + }); + + // The scheduleHook just runs the task, effectively making this a recursive executor. + int recursionDepth = 0; + _sexec->setScheduleHook([&recursionDepth](auto task) { + log() << "running task in executor. depth: " << ++recursionDepth; + task(); + return true; + }); + + _ssm->runNext(); + // Check that the SSM actually ran, is ended, and actually ran recursively + ASSERT_EQ(recursionDepth, 2); + ASSERT_TRUE(_tl->ranSource()); + ASSERT_TRUE(_tl->ranSink()); + ASSERT_EQ(_ssm->state(), State::Ended); +} + } // namespace } // namespace mongo |
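The new TerminateWorksForAllStates tests coordinate the test thread and the connection thread with the small `SimpleEvent` helper added above. The self-contained sketch below is not part of this commit; `SimpleEvent` mirrors the class in the test file (with `std::` in place of `stdx::`), while the worker thread standing in for the SSM runner and the log lines are illustrative assumptions.

```cpp
// Minimal reproduction of the SimpleEvent handshake used by the new tests.
// The "worker" stands in for the thread driving ServiceStateMachine::runNext().
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>

class SimpleEvent {
public:
    void signal() {
        std::unique_lock<std::mutex> lk(_mutex);
        _signaled = true;
        _cond.notify_one();
    }
    void wait() {
        std::unique_lock<std::mutex> lk(_mutex);
        _cond.wait(lk, [this] { return _signaled; });
        _signaled = false;  // auto-reset so the event can be reused per state
    }

private:
    std::mutex _mutex;
    std::condition_variable _cond;
    bool _signaled = false;
};

int main() {
    SimpleEvent atDesiredState, okayToContinue;

    std::thread worker([&] {
        atDesiredState.signal();   // "SSM reached the state under test"
        okayToContinue.wait();     // blocked until the test terminates the SSM
        std::cout << "worker finishing up\n";
    });

    atDesiredState.wait();         // test: wait for the worker to get there
    std::cout << "test: terminating session now\n";
    okayToContinue.signal();       // let the worker observe the termination
    worker.join();
}
```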