diff options
author | Jonathan Reams <jbreams@mongodb.com> | 2017-07-21 12:51:16 -0400 |
---|---|---|
committer | Jonathan Reams <jbreams@mongodb.com> | 2017-07-28 11:00:14 -0400 |
commit | 6f5f53f79552aacbbfba8c7e61cf7b15f58f3f3f (patch) | |
tree | 983ec64121fbe98e400e7c47ea266c07da3712b3 /src/mongo/transport | |
parent | 3d313292efac3654393d30e0eb439e3df7728171 (diff) | |
download | mongo-6f5f53f79552aacbbfba8c7e61cf7b15f58f3f3f.tar.gz |
SERVER-30260 Fix race condition in endAllSessions
Diffstat (limited to 'src/mongo/transport')
-rw-r--r-- | src/mongo/transport/service_entry_point_impl.cpp | 23 | ||||
-rw-r--r-- | src/mongo/transport/service_state_machine.cpp | 97 | ||||
-rw-r--r-- | src/mongo/transport/service_state_machine.h | 35 | ||||
-rw-r--r-- | src/mongo/transport/service_state_machine_test.cpp | 24 |
4 files changed, 101 insertions, 78 deletions
diff --git a/src/mongo/transport/service_entry_point_impl.cpp b/src/mongo/transport/service_entry_point_impl.cpp index cd2fa26ea1e..5b07f8b2c6c 100644 --- a/src/mongo/transport/service_entry_point_impl.cpp +++ b/src/mongo/transport/service_entry_point_impl.cpp @@ -100,33 +100,14 @@ void ServiceEntryPointImpl::startSession(transport::SessionHandle session) { } void ServiceEntryPointImpl::endAllSessions(transport::Session::TagMask tags) { - SSMList connsToEnd; - // While holding the _sesionsMutex, loop over all the current connections, and if their tags - // do not match the requested tags to skip, create a copy of their shared_ptr and place it in - // connsToEnd. - // - // This will ensure that sessions to be ended will live at least long enough for us to call - // their terminate() function, even if they've already ended because of an i/o error. + // do not match the requested tags to skip, terminate the session. { stdx::unique_lock<decltype(_sessionsMutex)> lk(_sessionsMutex); for (auto& ssm : _sessions) { - if (ssm->session()->getTags() & tags) { - log() << "Skip closing connection for connection # " << ssm->session()->id(); - } else { - connsToEnd.emplace_back(ssm); - } + ssm->terminateIfTagsDontMatch(tags); } } - - // Loop through all the connections we marked for ending and call terminate on them. They will - // then remove themselves from _sessions whenever they transition to the next state. - // - // If they've already ended, then this is a noop, and the SSM will be destroyed when connsToEnd - // goes out of scope. 
- for (auto& ssm : connsToEnd) { - ssm->terminate(); - } } std::size_t ServiceEntryPointImpl::getNumberOfConnections() const { diff --git a/src/mongo/transport/service_state_machine.cpp b/src/mongo/transport/service_state_machine.cpp index d30a99dcb63..9c2d335cee7 100644 --- a/src/mongo/transport/service_state_machine.cpp +++ b/src/mongo/transport/service_state_machine.cpp @@ -182,17 +182,17 @@ ServiceStateMachine::ServiceStateMachine(ServiceContext* svcContext, _sep{svcContext->getServiceEntryPoint()}, _sync(sync), _serviceContext(svcContext), + _sessionHandle(session), _dbClient{svcContext->makeClient("conn", std::move(session))}, _dbClientPtr{_dbClient.get()}, - _threadName{str::stream() << "conn" << _dbClient->session()->id()}, + _threadName{str::stream() << "conn" << _session()->id()}, _currentOwningThread{stdx::this_thread::get_id()} {} -const transport::SessionHandle& ServiceStateMachine::session() const { - // The _dbClientPtr should always point to our Client which should always own our SessionHandle - return _dbClientPtr->session(); +const transport::SessionHandle& ServiceStateMachine::_session() const { + return _sessionHandle; } -void ServiceStateMachine::sourceCallback(Status status) { +void ServiceStateMachine::_sourceCallback(Status status) { // The first thing to do is create a ThreadGuard which will take ownership of the SSM in this // thread. ThreadGuard guard(this); @@ -200,12 +200,12 @@ void ServiceStateMachine::sourceCallback(Status status) { // runNext() so that this thread can do other useful work with its timeslice instead of going // to sleep while waiting for the SSM to be released. 
if (!guard) { - return scheduleFunc([this, status] { sourceCallback(status); }); + return _scheduleFunc([this, status] { _sourceCallback(status); }); } // Make sure we just called sourceMessage(); invariant(state() == State::SourceWait); - auto remote = session()->remote(); + auto remote = _session()->remote(); if (status.isOK()) { _state.store(State::Process); @@ -224,16 +224,16 @@ void ServiceStateMachine::sourceCallback(Status status) { _state.store(State::EndSession); } else { log() << "Error receiving request from client: " << status << ". Ending connection from " - << remote << " (connection id: " << session()->id() << ")"; + << remote << " (connection id: " << _session()->id() << ")"; _state.store(State::EndSession); } // There was an error receiving a message from the client and we've already printed the error // so call runNextInGuard() to clean up the session without waiting. - runNextInGuard(guard); + _runNextInGuard(guard); } -void ServiceStateMachine::sinkCallback(Status status) { +void ServiceStateMachine::_sinkCallback(Status status) { // The first thing to do is create a ThreadGuard which will take ownership of the SSM in this // thread. ThreadGuard guard(this); @@ -241,7 +241,7 @@ void ServiceStateMachine::sinkCallback(Status status) { // runNext() so that this thread can do other useful work with its timeslice instead of going // to sleep while waiting for the SSM to be released. if (!guard) { - return scheduleFunc([this, status] { sinkCallback(status); }); + return _scheduleFunc([this, status] { _sinkCallback(status); }); } invariant(state() == State::SinkWait); @@ -253,20 +253,24 @@ void ServiceStateMachine::sinkCallback(Status status) { // scheduleNext() to unwind the stack and do the next step. if (!status.isOK()) { log() << "Error sending response to client: " << status << ". 
Ending connection from " - << session()->remote() << " (connection id: " << session()->id() << ")"; + << _session()->remote() << " (connection id: " << _session()->id() << ")"; _state.store(State::EndSession); - return runNextInGuard(guard); + return _runNextInGuard(guard); } else if (inExhaust) { _state.store(State::Process); } else { _state.store(State::Source); } - // Call scheduleNext() to unwind the stack and run next step - scheduleNext(); + // If the session ended, then runNext to clean it up + if (state() == State::EndSession) { + _runNextInGuard(guard); + } else { // Otherwise scheduleNext to unwind the stack and run the next step later + scheduleNext(); + } } -void ServiceStateMachine::processMessage() { +void ServiceStateMachine::_processMessage(ThreadGuard& guard) { // This may have been called just after a failure to source a message, in which case this // should return early so the session can be cleaned up. if (state() != State::Process) { @@ -274,7 +278,7 @@ void ServiceStateMachine::processMessage() { } invariant(!_inMessage.empty()); - auto& compressorMgr = MessageCompressorManager::forSession(session()); + auto& compressorMgr = MessageCompressorManager::forSession(_session()); if (_inMessage.operation() == dbCompressed) { auto swm = compressorMgr.decompressMessage(_inMessage); @@ -288,7 +292,7 @@ void ServiceStateMachine::processMessage() { networkCounter.hitLogicalIn(_inMessage.size()); // Pass sourced Message to handler to generate response. - auto opCtx = cc().makeOperationContext(); + auto opCtx = Client::getCurrent()->makeOperationContext(); // The handleRequest is implemented in a subclass for mongod/mongos and actually all the // database work for this request. 
@@ -321,14 +325,14 @@ void ServiceStateMachine::processMessage() { } // Sink our response to the client - auto ticket = session()->sinkMessage(toSink); + auto ticket = _session()->sinkMessage(toSink); _state.store(State::SinkWait); if (_sync) { - sinkCallback(session()->getTransportLayer()->wait(std::move(ticket))); + _sinkCallback(_session()->getTransportLayer()->wait(std::move(ticket))); } else { - session()->getTransportLayer()->asyncWait( - std::move(ticket), [this](Status status) { sinkCallback(status); }); + _session()->getTransportLayer()->asyncWait( + std::move(ticket), [this](Status status) { _sinkCallback(status); }); } } else { _state.store(State::Source); @@ -347,10 +351,10 @@ void ServiceStateMachine::runNext() { if (!guard) { return scheduleNext(); } - return runNextInGuard(guard); + return _runNextInGuard(guard); } -void ServiceStateMachine::runNextInGuard(ThreadGuard& guard) { +void ServiceStateMachine::_runNextInGuard(ThreadGuard& guard) { auto curState = state(); invariant(curState != State::Ended); @@ -367,21 +371,21 @@ void ServiceStateMachine::runNextInGuard(ThreadGuard& guard) { case State::Source: { invariant(_inMessage.empty()); - auto ticket = session()->sourceMessage(&_inMessage); + auto ticket = _session()->sourceMessage(&_inMessage); _state.store(State::SourceWait); if (_sync) { - sourceCallback([&] { + _sourceCallback([this](auto ticket) { MONGO_IDLE_THREAD_BLOCK; - return session()->getTransportLayer()->wait(std::move(ticket)); - }()); + return _session()->getTransportLayer()->wait(std::move(ticket)); + }(std::move(ticket))); } else { - session()->getTransportLayer()->asyncWait( - std::move(ticket), [this](Status status) { sourceCallback(status); }); + _session()->getTransportLayer()->asyncWait( + std::move(ticket), [this](Status status) { _sourceCallback(status); }); break; } } case State::Process: - processMessage(); + _processMessage(guard); break; case State::EndSession: // This will get handled below in an if statement. 
That way if an error occurs @@ -396,7 +400,7 @@ void ServiceStateMachine::runNextInGuard(ThreadGuard& guard) { } if (state() == State::EndSession) { - cleanupSession(); + _cleanupSession(guard); } return; @@ -413,18 +417,23 @@ void ServiceStateMachine::runNextInGuard(ThreadGuard& guard) { } _state.store(State::EndSession); - cleanupSession(); + _cleanupSession(guard); } void ServiceStateMachine::scheduleNext() { - maybeScheduleFunc(_serviceContext->getServiceExecutor(), [this] { runNext(); }); + _maybeScheduleFunc(_serviceContext->getServiceExecutor(), [this] { runNext(); }); } -void ServiceStateMachine::terminate() { +void ServiceStateMachine::terminateIfTagsDontMatch(transport::Session::TagMask tags) { if (state() == State::Ended) return; - auto tl = session()->getTransportLayer(); - tl->end(session()); + + if (_session()->getTags() & tags) { + log() << "Skip closing connection for connection # " << _session()->id(); + return; + } + + _session()->getTransportLayer()->end(_session()); } void ServiceStateMachine::setCleanupHook(stdx::function<void()> hook) { @@ -436,18 +445,24 @@ ServiceStateMachine::State ServiceStateMachine::state() { return _state.load(); } -void ServiceStateMachine::cleanupSession() { +void ServiceStateMachine::_cleanupSession(ThreadGuard& guard) { _state.store(State::Ended); - auto tl = session()->getTransportLayer(); + auto tl = _session()->getTransportLayer(); + auto remote = _session()->remote(); _inMessage.reset(); - auto remote = session()->remote(); + // By ignoring the return value of Client::releaseCurrent() we destroy the session. + // _dbClient is now nullptr and _dbClientPtr is invalid and should never be accessed. 
Client::releaseCurrent(); if (!serverGlobalParams.quiet.load()) { - auto conns = tl->sessionStats().numOpenSessions; + // Get the number of open sessions minus 1 (this one will get cleaned up when + // this SSM gets destroyed) + // TODO Switch to using ServiceEntryPointImpl::getNumberOfConnections(), or move this + // into the ServiceEntryPoint + auto conns = tl->sessionStats().numOpenSessions - 1; const char* word = (conns == 1 ? " connection" : " connections"); log() << "end connection " << remote << " (" << conns << word << " now open)"; } diff --git a/src/mongo/transport/service_state_machine.h b/src/mongo/transport/service_state_machine.h index e3212ad21b0..bd56bf374b4 100644 --- a/src/mongo/transport/service_state_machine.h +++ b/src/mongo/transport/service_state_machine.h @@ -113,21 +113,17 @@ public: State state(); /* - * Terminates the associated transport Session, and requests that the next call to runNext - * should end the session. If the session has already ended, this does nothing. + * Terminates the associated transport Session if its tags don't match the supplied tags. + * + * This will not block on the session terminating and cleaning itself up, it returns immediately. */ - void terminate(); + void terminateIfTagsDontMatch(transport::Session::TagMask tags); /* * Sets a function to be called after the session is ended */ void setCleanupHook(stdx::function<void()> hook); - /* - * Gets the transport::Session associated with this connection - */ - const transport::SessionHandle& session() const; - private: /* * A class that wraps up lifetime management of the _dbClient and _threadName for runNext(); @@ -141,7 +137,7 @@ private: * callbacks to run. 
*/ template <typename Executor, typename Func> - void maybeScheduleFunc(Executor* svcExec, Func&& func) { + void _maybeScheduleFunc(Executor* svcExec, Func&& func) { if (svcExec) { uassertStatusOK(svcExec->schedule( [ func = std::move(func), anchor = shared_from_this() ] { func(); })); @@ -149,36 +145,41 @@ private: } template <typename Func> - void scheduleFunc(Func&& func) { + void _scheduleFunc(Func&& func) { auto svcExec = _serviceContext->getServiceExecutor(); invariant(svcExec); - maybeScheduleFunc(svcExec, func); + _maybeScheduleFunc(svcExec, func); } /* + * Gets the transport::Session associated with this connection + */ + const transport::SessionHandle& _session() const; + + /* * This is the actual implementation of runNext() that gets called after the ThreadGuard * has been successfully created. If any callbacks (like sourceCallback()) need to call * runNext() and already own a ThreadGuard, they should call this with that guard as the * argument. */ - void runNextInGuard(ThreadGuard& guard); + void _runNextInGuard(ThreadGuard& guard); /* * This function actually calls into the database and processes a request. It's broken out * into its own inline function for better readability. */ - inline void processMessage(); + inline void _processMessage(ThreadGuard& guard); /* * These get called by the TransportLayer when requested network I/O has completed. */ - void sourceCallback(Status status); - void sinkCallback(Status status); + void _sourceCallback(Status status); + void _sinkCallback(Status status); /* * Releases all the resources associated with the session and call the cleanupHook. 
*/ - void cleanupSession(); + void _cleanupSession(ThreadGuard& guard); AtomicWord<State> _state{State::Created}; @@ -186,6 +187,8 @@ private: bool _sync; ServiceContext* const _serviceContext; + + transport::SessionHandle _sessionHandle; ServiceContext::UniqueClient _dbClient; const Client* _dbClientPtr; const std::string _threadName; diff --git a/src/mongo/transport/service_state_machine_test.cpp b/src/mongo/transport/service_state_machine_test.cpp index fd65cb0cf6a..0af41b6e593 100644 --- a/src/mongo/transport/service_state_machine_test.cpp +++ b/src/mongo/transport/service_state_machine_test.cpp @@ -283,4 +283,28 @@ TEST_F(ServiceStateMachineFixture, TestSinkError) { ASSERT_TRUE(_tl->ranSink()); } +// This test checks that after the SSM has been cleaned up, the SessionHandle that it passed +// into the Client doesn't have any dangling shared_ptr copies. +TEST_F(ServiceStateMachineFixture, TestSessionCleanupOnDestroy) { + // Set a cleanup hook so we know that the cleanup hook actually gets run when the session + // is destroyed + bool hookRan = false; + _ssm->setCleanupHook([&hookRan] { hookRan = true; }); + + // Do a regular ping test so that all the processMessage/sinkMessage code gets exercised + ASSERT_EQ(ServiceStateMachine::State::Source, runPingTest()); + + // Set the next run up to fail on source (like a disconnected client) and run it + _tl->setNextFailure(MockTL::Source); + _ssm->runNext(); + ASSERT_EQ(ServiceStateMachine::State::Ended, _ssm->state()); + + // Check that after the failure and the session getting cleaned up that the SessionHandle + // only has one use (our copy in _sessionHandle) + ASSERT_EQ(_ssm.use_count(), 1); + + // Make sure the cleanup hook actually ran. + ASSERT_TRUE(hookRan); +} + } // namespace mongo |