author     Jonathan Reams <jbreams@mongodb.com>  2017-07-21 12:51:16 -0400
committer  Jonathan Reams <jbreams@mongodb.com>  2017-07-28 11:00:14 -0400
commit     6f5f53f79552aacbbfba8c7e61cf7b15f58f3f3f (patch)
tree       983ec64121fbe98e400e7c47ea266c07da3712b3 /src/mongo/transport
parent     3d313292efac3654393d30e0eb439e3df7728171 (diff)
download   mongo-6f5f53f79552aacbbfba8c7e61cf7b15f58f3f3f.tar.gz
SERVER-30260 Fix race condition in endAllSessions
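
The race this fixes is visible in the diff below: terminate() used to reach the
transport Session through _dbClientPtr->session(), but once _cleanupSession()
called Client::releaseCurrent() that pointer was dangling, so endAllSessions()
racing with a connection that was already ending could touch freed memory. The
SSM now keeps its own _sessionHandle copy, valid for the machine's whole
lifetime, and endAllSessions() terminates sessions while still holding
_sessionsMutex.

A minimal sketch of the ownership change, with hypothetical stand-in types
rather than the real transport classes:

#include <iostream>
#include <memory>

struct Session {
    int id = 42;
    void end() { std::cout << "ending session " << id << "\n"; }
};

class StateMachine {
public:
    explicit StateMachine(std::shared_ptr<Session> session)
        : _sessionHandle(std::move(session)),                // the fix: own a copy
          _client(std::make_unique<Owner>(_sessionHandle)) {}

    // Stand-in for _cleanupSession(): destroys the Client, after which any
    // access through it would be a use-after-free.
    void cleanup() { _client.reset(); }

    // Before the fix this dereferenced the (possibly destroyed) Client; now
    // it uses the machine's own handle, which outlives cleanup().
    void terminate() { _sessionHandle->end(); }

private:
    struct Owner {  // stand-in for the Client that owns a session handle
        explicit Owner(std::shared_ptr<Session> s) : session(std::move(s)) {}
        std::shared_ptr<Session> session;
    };
    std::shared_ptr<Session> _sessionHandle;
    std::unique_ptr<Owner> _client;
};

int main() {
    StateMachine ssm(std::make_shared<Session>());
    ssm.cleanup();    // the session already ended on another path...
    ssm.terminate();  // ...but terminating is still safe
}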
Diffstat (limited to 'src/mongo/transport')
-rw-r--r--  src/mongo/transport/service_entry_point_impl.cpp    23
-rw-r--r--  src/mongo/transport/service_state_machine.cpp       97
-rw-r--r--  src/mongo/transport/service_state_machine.h         35
-rw-r--r--  src/mongo/transport/service_state_machine_test.cpp  24
4 files changed, 101 insertions, 78 deletions
diff --git a/src/mongo/transport/service_entry_point_impl.cpp b/src/mongo/transport/service_entry_point_impl.cpp
index cd2fa26ea1e..5b07f8b2c6c 100644
--- a/src/mongo/transport/service_entry_point_impl.cpp
+++ b/src/mongo/transport/service_entry_point_impl.cpp
@@ -100,33 +100,14 @@ void ServiceEntryPointImpl::startSession(transport::SessionHandle session) {
}
void ServiceEntryPointImpl::endAllSessions(transport::Session::TagMask tags) {
- SSMList connsToEnd;
-
// While holding the _sessionsMutex, loop over all the current connections, and if their tags
- // do not match the requested tags to skip, create a copy of their shared_ptr and place it in
- // connsToEnd.
- //
- // This will ensure that sessions to be ended will live at least long enough for us to call
- // their terminate() function, even if they've already ended because of an i/o error.
+ // do not match the requested tags to skip, terminate the session.
{
stdx::unique_lock<decltype(_sessionsMutex)> lk(_sessionsMutex);
for (auto& ssm : _sessions) {
- if (ssm->session()->getTags() & tags) {
- log() << "Skip closing connection for connection # " << ssm->session()->id();
- } else {
- connsToEnd.emplace_back(ssm);
- }
+ ssm->terminateIfTagsDontMatch(tags);
}
}
-
- // Loop through all the connections we marked for ending and call terminate on them. They will
- // then remove themselves from _sessions whenever they transition to the next state.
- //
- // If they've already ended, then this is a noop, and the SSM will be destroyed when connsToEnd
- // goes out of scope.
- for (auto& ssm : connsToEnd) {
- ssm->terminate();
- }
}
std::size_t ServiceEntryPointImpl::getNumberOfConnections() const {
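
With the SSM owning its own SessionHandle, the copy-then-terminate dance above
becomes unnecessary: terminateIfTagsDontMatch() only asks the TransportLayer to
end the session and never blocks on cleanup, so it is safe to call while still
holding _sessionsMutex, and no session can end and unregister itself between
the unlock and the terminate() call. A rough sketch of the new loop shape,
again with hypothetical stand-in types rather than the real entry point:

#include <list>
#include <memory>
#include <mutex>

using TagMask = unsigned;

struct Session {
    TagMask tags = 0;
    void end() { /* request shutdown; returns immediately */ }
};

class EntryPoint {
public:
    void endAllSessions(TagMask skipTags) {
        // Terminating under the lock is fine because end() never blocks.
        std::lock_guard<std::mutex> lk(_sessionsMutex);
        for (auto& s : _sessions) {
            if (!(s->tags & skipTags))  // sessions with matching tags are skipped
                s->end();
        }
    }

private:
    std::mutex _sessionsMutex;
    std::list<std::shared_ptr<Session>> _sessions;
};

int main() {
    EntryPoint ep;
    ep.endAllSessions(/*skipTags=*/0x1);
}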
diff --git a/src/mongo/transport/service_state_machine.cpp b/src/mongo/transport/service_state_machine.cpp
index d30a99dcb63..9c2d335cee7 100644
--- a/src/mongo/transport/service_state_machine.cpp
+++ b/src/mongo/transport/service_state_machine.cpp
@@ -182,17 +182,17 @@ ServiceStateMachine::ServiceStateMachine(ServiceContext* svcContext,
_sep{svcContext->getServiceEntryPoint()},
_sync(sync),
_serviceContext(svcContext),
+ _sessionHandle(session),
_dbClient{svcContext->makeClient("conn", std::move(session))},
_dbClientPtr{_dbClient.get()},
- _threadName{str::stream() << "conn" << _dbClient->session()->id()},
+ _threadName{str::stream() << "conn" << _session()->id()},
_currentOwningThread{stdx::this_thread::get_id()} {}
-const transport::SessionHandle& ServiceStateMachine::session() const {
- // The _dbClientPtr should always point to our Client which should always own our SessionHandle
- return _dbClientPtr->session();
+const transport::SessionHandle& ServiceStateMachine::_session() const {
+ return _sessionHandle;
}
-void ServiceStateMachine::sourceCallback(Status status) {
+void ServiceStateMachine::_sourceCallback(Status status) {
// The first thing to do is create a ThreadGuard which will take ownership of the SSM in this
// thread.
ThreadGuard guard(this);
@@ -200,12 +200,12 @@ void ServiceStateMachine::sourceCallback(Status status) {
// runNext() so that this thread can do other useful work with its timeslice instead of going
// to sleep while waiting for the SSM to be released.
if (!guard) {
- return scheduleFunc([this, status] { sourceCallback(status); });
+ return _scheduleFunc([this, status] { _sourceCallback(status); });
}
// Make sure we just called sourceMessage();
invariant(state() == State::SourceWait);
- auto remote = session()->remote();
+ auto remote = _session()->remote();
if (status.isOK()) {
_state.store(State::Process);
@@ -224,16 +224,16 @@ void ServiceStateMachine::sourceCallback(Status status) {
_state.store(State::EndSession);
} else {
log() << "Error receiving request from client: " << status << ". Ending connection from "
- << remote << " (connection id: " << session()->id() << ")";
+ << remote << " (connection id: " << _session()->id() << ")";
_state.store(State::EndSession);
}
// There was an error receiving a message from the client and we've already printed the error
// so call runNextInGuard() to clean up the session without waiting.
- runNextInGuard(guard);
+ _runNextInGuard(guard);
}
-void ServiceStateMachine::sinkCallback(Status status) {
+void ServiceStateMachine::_sinkCallback(Status status) {
// The first thing to do is create a ThreadGuard which will take ownership of the SSM in this
// thread.
ThreadGuard guard(this);
@@ -241,7 +241,7 @@ void ServiceStateMachine::sinkCallback(Status status) {
// runNext() so that this thread can do other useful work with its timeslice instead of going
// to sleep while waiting for the SSM to be released.
if (!guard) {
- return scheduleFunc([this, status] { sinkCallback(status); });
+ return _scheduleFunc([this, status] { _sinkCallback(status); });
}
invariant(state() == State::SinkWait);
@@ -253,20 +253,24 @@ void ServiceStateMachine::sinkCallback(Status status) {
// scheduleNext() to unwind the stack and do the next step.
if (!status.isOK()) {
log() << "Error sending response to client: " << status << ". Ending connection from "
- << session()->remote() << " (connection id: " << session()->id() << ")";
+ << _session()->remote() << " (connection id: " << _session()->id() << ")";
_state.store(State::EndSession);
- return runNextInGuard(guard);
+ return _runNextInGuard(guard);
} else if (inExhaust) {
_state.store(State::Process);
} else {
_state.store(State::Source);
}
- // Call scheduleNext() to unwind the stack and run next step
- scheduleNext();
+ // If the session ended, then runNext to clean it up
+ if (state() == State::EndSession) {
+ _runNextInGuard(guard);
+ } else { // Otherwise scheduleNext to unwind the stack and run the next step later
+ scheduleNext();
+ }
}
-void ServiceStateMachine::processMessage() {
+void ServiceStateMachine::_processMessage(ThreadGuard& guard) {
// This may have been called just after a failure to source a message, in which case this
// should return early so the session can be cleaned up.
if (state() != State::Process) {
@@ -274,7 +278,7 @@ void ServiceStateMachine::processMessage() {
}
invariant(!_inMessage.empty());
- auto& compressorMgr = MessageCompressorManager::forSession(session());
+ auto& compressorMgr = MessageCompressorManager::forSession(_session());
if (_inMessage.operation() == dbCompressed) {
auto swm = compressorMgr.decompressMessage(_inMessage);
@@ -288,7 +292,7 @@ void ServiceStateMachine::processMessage() {
networkCounter.hitLogicalIn(_inMessage.size());
// Pass sourced Message to handler to generate response.
- auto opCtx = cc().makeOperationContext();
+ auto opCtx = Client::getCurrent()->makeOperationContext();
// The handleRequest is implemented in a subclass for mongod/mongos and actually does all the
// database work for this request.
@@ -321,14 +325,14 @@ void ServiceStateMachine::processMessage() {
}
// Sink our response to the client
- auto ticket = session()->sinkMessage(toSink);
+ auto ticket = _session()->sinkMessage(toSink);
_state.store(State::SinkWait);
if (_sync) {
- sinkCallback(session()->getTransportLayer()->wait(std::move(ticket)));
+ _sinkCallback(_session()->getTransportLayer()->wait(std::move(ticket)));
} else {
- session()->getTransportLayer()->asyncWait(
- std::move(ticket), [this](Status status) { sinkCallback(status); });
+ _session()->getTransportLayer()->asyncWait(
+ std::move(ticket), [this](Status status) { _sinkCallback(status); });
}
} else {
_state.store(State::Source);
@@ -347,10 +351,10 @@ void ServiceStateMachine::runNext() {
if (!guard) {
return scheduleNext();
}
- return runNextInGuard(guard);
+ return _runNextInGuard(guard);
}
-void ServiceStateMachine::runNextInGuard(ThreadGuard& guard) {
+void ServiceStateMachine::_runNextInGuard(ThreadGuard& guard) {
auto curState = state();
invariant(curState != State::Ended);
@@ -367,21 +371,21 @@ void ServiceStateMachine::runNextInGuard(ThreadGuard& guard) {
case State::Source: {
invariant(_inMessage.empty());
- auto ticket = session()->sourceMessage(&_inMessage);
+ auto ticket = _session()->sourceMessage(&_inMessage);
_state.store(State::SourceWait);
if (_sync) {
- sourceCallback([&] {
+ _sourceCallback([this](auto ticket) {
MONGO_IDLE_THREAD_BLOCK;
- return session()->getTransportLayer()->wait(std::move(ticket));
- }());
+ return _session()->getTransportLayer()->wait(std::move(ticket));
+ }(std::move(ticket)));
} else {
- session()->getTransportLayer()->asyncWait(
- std::move(ticket), [this](Status status) { sourceCallback(status); });
+ _session()->getTransportLayer()->asyncWait(
+ std::move(ticket), [this](Status status) { _sourceCallback(status); });
break;
}
}
case State::Process:
- processMessage();
+ _processMessage(guard);
break;
case State::EndSession:
// This will get handled below in an if statement. That way if an error occurs
@@ -396,7 +400,7 @@ void ServiceStateMachine::runNextInGuard(ThreadGuard& guard) {
}
if (state() == State::EndSession) {
- cleanupSession();
+ _cleanupSession(guard);
}
return;
@@ -413,18 +417,23 @@ void ServiceStateMachine::runNextInGuard(ThreadGuard& guard) {
}
_state.store(State::EndSession);
- cleanupSession();
+ _cleanupSession(guard);
}
void ServiceStateMachine::scheduleNext() {
- maybeScheduleFunc(_serviceContext->getServiceExecutor(), [this] { runNext(); });
+ _maybeScheduleFunc(_serviceContext->getServiceExecutor(), [this] { runNext(); });
}
-void ServiceStateMachine::terminate() {
+void ServiceStateMachine::terminateIfTagsDontMatch(transport::Session::TagMask tags) {
if (state() == State::Ended)
return;
- auto tl = session()->getTransportLayer();
- tl->end(session());
+
+ if (_session()->getTags() & tags) {
+ log() << "Skip closing connection for connection # " << _session()->id();
+ return;
+ }
+
+ _session()->getTransportLayer()->end(_session());
}
void ServiceStateMachine::setCleanupHook(stdx::function<void()> hook) {
@@ -436,18 +445,24 @@ ServiceStateMachine::State ServiceStateMachine::state() {
return _state.load();
}
-void ServiceStateMachine::cleanupSession() {
+void ServiceStateMachine::_cleanupSession(ThreadGuard& guard) {
_state.store(State::Ended);
- auto tl = session()->getTransportLayer();
+ auto tl = _session()->getTransportLayer();
+ auto remote = _session()->remote();
_inMessage.reset();
- auto remote = session()->remote();
+ // By ignoring the return value of Client::releaseCurrent() we destroy the session.
+ // _dbClient is now nullptr and _dbClientPtr is invalid and should never be accessed.
Client::releaseCurrent();
if (!serverGlobalParams.quiet.load()) {
- auto conns = tl->sessionStats().numOpenSessions;
+ // Get the number of open sessions minus 1 (this one will get cleaned up when
+ // this SSM gets destroyed)
+ // TODO Switch to using ServiceEntryPointImpl::getNumberOfConnections(), or move this
+ // into the ServiceEntryPoint
+ auto conns = tl->sessionStats().numOpenSessions - 1;
const char* word = (conns == 1 ? " connection" : " connections");
log() << "end connection " << remote << " (" << conns << word << " now open)";
}
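
One detail worth noting above: in the synchronous State::Source path, the
ticket is now moved into an immediately invoked lambda as a parameter instead
of being captured by reference from the enclosing frame. A standalone sketch
of that pattern, using a stand-in Ticket and wait() rather than the real
transport API:

#include <iostream>
#include <string>
#include <utility>

struct Ticket {
    std::string payload;
};

// Stand-in for TransportLayer::wait(): consumes the ticket, returns a result.
static std::string wait(Ticket t) {
    return "waited on " + t.payload;
}

int main() {
    Ticket ticket{"source"};
    // The generic lambda runs immediately; the ticket is moved in by value,
    // so the call owns it outright instead of referencing the outer frame.
    auto status = [](auto t) { return wait(std::move(t)); }(std::move(ticket));
    std::cout << status << "\n";
}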
diff --git a/src/mongo/transport/service_state_machine.h b/src/mongo/transport/service_state_machine.h
index e3212ad21b0..bd56bf374b4 100644
--- a/src/mongo/transport/service_state_machine.h
+++ b/src/mongo/transport/service_state_machine.h
@@ -113,21 +113,17 @@ public:
State state();
/*
- * Terminates the associated transport Session, and requests that the next call to runNext
- * should end the session. If the session has already ended, this does nothing.
+ * Terminates the associated transport Session if its tags don't match the supplied tags.
+ *
+ * This does not block waiting for the session to terminate and clean itself up; it returns immediately.
*/
- void terminate();
+ void terminateIfTagsDontMatch(transport::Session::TagMask tags);
/*
* Sets a function to be called after the session is ended
*/
void setCleanupHook(stdx::function<void()> hook);
- /*
- * Gets the transport::Session associated with this connection
- */
- const transport::SessionHandle& session() const;
-
private:
/*
* A class that wraps up lifetime management of the _dbClient and _threadName for runNext();
@@ -141,7 +137,7 @@ private:
* callbacks to run.
*/
template <typename Executor, typename Func>
- void maybeScheduleFunc(Executor* svcExec, Func&& func) {
+ void _maybeScheduleFunc(Executor* svcExec, Func&& func) {
if (svcExec) {
uassertStatusOK(svcExec->schedule(
[ func = std::move(func), anchor = shared_from_this() ] { func(); }));
@@ -149,36 +145,41 @@ private:
}
template <typename Func>
- void scheduleFunc(Func&& func) {
+ void _scheduleFunc(Func&& func) {
auto svcExec = _serviceContext->getServiceExecutor();
invariant(svcExec);
- maybeScheduleFunc(svcExec, func);
+ _maybeScheduleFunc(svcExec, func);
}
/*
+ * Gets the transport::Session associated with this connection
+ */
+ const transport::SessionHandle& _session() const;
+
+ /*
* This is the actual implementation of runNext() that gets called after the ThreadGuard
* has been successfully created. If any callbacks (like sourceCallback()) need to call
* runNext() and already own a ThreadGuard, they should call this with that guard as the
* argument.
*/
- void runNextInGuard(ThreadGuard& guard);
+ void _runNextInGuard(ThreadGuard& guard);
/*
* This function actually calls into the database and processes a request. It's broken out
* into its own inline function for better readability.
*/
- inline void processMessage();
+ inline void _processMessage(ThreadGuard& guard);
/*
* These get called by the TransportLayer when requested network I/O has completed.
*/
- void sourceCallback(Status status);
- void sinkCallback(Status status);
+ void _sourceCallback(Status status);
+ void _sinkCallback(Status status);
/*
* Releases all the resources associated with the session and calls the cleanupHook.
*/
- void cleanupSession();
+ void _cleanupSession(ThreadGuard& guard);
AtomicWord<State> _state{State::Created};
@@ -186,6 +187,8 @@ private:
bool _sync;
ServiceContext* const _serviceContext;
+
+ transport::SessionHandle _sessionHandle;
ServiceContext::UniqueClient _dbClient;
const Client* _dbClientPtr;
const std::string _threadName;
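
The callbacks in this class all follow the ThreadGuard idiom described above:
try to claim exclusive ownership of the state machine in the current thread,
and if another thread already owns it, reschedule the work rather than block.
A simplified, hypothetical sketch of that shape (the real guard also manages
the Client and the thread name):

#include <atomic>
#include <functional>
#include <iostream>
#include <vector>

class Machine {
public:
    class ThreadGuard {
    public:
        // Claims ownership if no other thread currently holds it.
        explicit ThreadGuard(Machine* m)
            : _m(m), _owned(!m->_ownedFlag.exchange(true)) {}
        ~ThreadGuard() {
            if (_owned)
                _m->_ownedFlag.store(false);
        }
        explicit operator bool() const { return _owned; }

    private:
        Machine* _m;
        bool _owned;
    };

    void callback() {
        ThreadGuard guard(this);
        if (!guard) {
            // Another thread owns the machine: unwind and retry later.
            schedule([this] { callback(); });
            return;
        }
        std::cout << "running the next step with ownership\n";
    }

    void schedule(std::function<void()> f) { _pending.push_back(std::move(f)); }

private:
    std::atomic<bool> _ownedFlag{false};
    std::vector<std::function<void()>> _pending;
};

int main() {
    Machine m;
    m.callback();
}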
diff --git a/src/mongo/transport/service_state_machine_test.cpp b/src/mongo/transport/service_state_machine_test.cpp
index fd65cb0cf6a..0af41b6e593 100644
--- a/src/mongo/transport/service_state_machine_test.cpp
+++ b/src/mongo/transport/service_state_machine_test.cpp
@@ -283,4 +283,28 @@ TEST_F(ServiceStateMachineFixture, TestSinkError) {
ASSERT_TRUE(_tl->ranSink());
}
+// This test checks that after the SSM has been cleaned up, the SessionHandle that it passed
+// into the Client doesn't have any dangling shared_ptr copies.
+TEST_F(ServiceStateMachineFixture, TestSessionCleanupOnDestroy) {
+ // Set a cleanup hook so we know that the cleanup hook actually gets run when the session
+ // is destroyed
+ bool hookRan = false;
+ _ssm->setCleanupHook([&hookRan] { hookRan = true; });
+
+ // Do a regular ping test so that all the processMessage/sinkMessage code gets exercised
+ ASSERT_EQ(ServiceStateMachine::State::Source, runPingTest());
+
+ // Set the next run up to fail on source (like a disconnected client) and run it
+ _tl->setNextFailure(MockTL::Source);
+ _ssm->runNext();
+ ASSERT_EQ(ServiceStateMachine::State::Ended, _ssm->state());
+
+ // Check that after the failure and the session getting cleaned up, the SSM
+ // shared_ptr only has one remaining owner (the fixture's copy in _ssm)
+ ASSERT_EQ(_ssm.use_count(), 1);
+
+ // Make sure the cleanup hook actually ran.
+ ASSERT_TRUE(hookRan);
+}
+
} // namespace mongo
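
The new test relies on std::shared_ptr::use_count() to prove that nothing
retained an owner past cleanup. A minimal standalone illustration of that
verification technique, separate from the fixture's real types:

#include <cassert>
#include <memory>

struct Holder {
    std::shared_ptr<int> retained;  // simulates a copy held by a callback
};

int main() {
    auto handle = std::make_shared<int>(7);

    {
        Holder h;
        h.retained = handle;              // a second owner now exists...
        assert(handle.use_count() == 2);
    }                                     // ...until cleanup destroys it

    // The same shape as ASSERT_EQ(_ssm.use_count(), 1) above: if this fires,
    // some code path kept a reference alive past cleanup.
    assert(handle.use_count() == 1);
}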