diff options
-rw-r--r-- | src/mongo/SConscript | 4 | ||||
-rw-r--r-- | src/mongo/db/db.cpp | 2 | ||||
-rw-r--r-- | src/mongo/s/server.cpp | 3 | ||||
-rw-r--r-- | src/mongo/transport/SConscript | 25 | ||||
-rw-r--r-- | src/mongo/transport/mock_ticket.h | 9 | ||||
-rw-r--r-- | src/mongo/transport/service_entry_point_impl.cpp | 136 | ||||
-rw-r--r-- | src/mongo/transport/service_entry_point_impl.h | 12 | ||||
-rw-r--r-- | src/mongo/transport/service_entry_point_utils.cpp | 63 | ||||
-rw-r--r-- | src/mongo/transport/service_entry_point_utils.h | 3 | ||||
-rw-r--r-- | src/mongo/transport/service_state_machine.cpp | 371 | ||||
-rw-r--r-- | src/mongo/transport/service_state_machine.h | 152 | ||||
-rw-r--r-- | src/mongo/transport/service_state_machine_test.cpp | 295 | ||||
-rw-r--r-- | src/mongo/transport/ticket.cpp | 12 | ||||
-rw-r--r-- | src/mongo/transport/ticket.h | 5 | ||||
-rw-r--r-- | src/mongo/transport/transport_layer_mock.cpp | 2 |
15 files changed, 900 insertions, 194 deletions
diff --git a/src/mongo/SConscript b/src/mongo/SConscript index 803080e4ab6..1687378dda5 100644 --- a/src/mongo/SConscript +++ b/src/mongo/SConscript @@ -306,7 +306,7 @@ mongod = env.Program( 'executor/network_interface_factory', 'rpc/rpc', 's/commands/shared_cluster_commands', - 'transport/service_entry_point_utils', + 'transport/service_entry_point', 'transport/transport_layer_legacy', 'util/clock_sources', 'util/fail_point', @@ -362,7 +362,7 @@ env.Install( 's/is_mongos', 's/sharding_egress_metadata_hook_for_mongos', 's/sharding_initialization', - 'transport/service_entry_point_utils', + 'transport/service_entry_point', 'transport/transport_layer_legacy', 'util/clock_sources', 'util/fail_point', diff --git a/src/mongo/db/db.cpp b/src/mongo/db/db.cpp index 3803808efab..d635c7927fd 100644 --- a/src/mongo/db/db.cpp +++ b/src/mongo/db/db.cpp @@ -489,7 +489,7 @@ ExitCode _initAndListen(int listenPort) { options.ipList = serverGlobalParams.bind_ip; globalServiceContext->setServiceEntryPoint( - stdx::make_unique<ServiceEntryPointMongod>(globalServiceContext->getTransportLayer())); + stdx::make_unique<ServiceEntryPointMongod>(globalServiceContext)); // Create, start, and attach the TL auto transportLayer = stdx::make_unique<transport::TransportLayerLegacy>( diff --git a/src/mongo/s/server.cpp b/src/mongo/s/server.cpp index 791dbae5e91..74fe11548c5 100644 --- a/src/mongo/s/server.cpp +++ b/src/mongo/s/server.cpp @@ -254,8 +254,7 @@ static ExitCode runMongosServer() { opts.port = serverGlobalParams.port; opts.ipList = serverGlobalParams.bind_ip; - auto sep = - stdx::make_unique<ServiceEntryPointMongos>(getGlobalServiceContext()->getTransportLayer()); + auto sep = stdx::make_unique<ServiceEntryPointMongos>(getGlobalServiceContext()); auto sepPtr = sep.get(); getGlobalServiceContext()->setServiceEntryPoint(std::move(sep)); diff --git a/src/mongo/transport/SConscript b/src/mongo/transport/SConscript index 4d66ab981e1..e623b0a88af 100644 --- a/src/mongo/transport/SConscript +++ b/src/mongo/transport/SConscript @@ -17,6 +17,7 @@ env.CppUnitTest( env.Library( target='transport_layer_common', source=[ + 'service_entry_point_utils.cpp', 'session.cpp', 'ticket.cpp', 'transport_layer.cpp', @@ -71,13 +72,14 @@ env.Library( ) env.Library( - target='service_entry_point_utils', + target='service_entry_point', source=[ - 'service_entry_point_utils.cpp', 'service_entry_point_impl.cpp', + 'service_state_machine.cpp', ], LIBDEPS=[ "$BUILD_DIR/mongo/db/service_context", + "$BUILD_DIR/mongo/util/processinfo", 'transport_layer_common', ], ) @@ -94,6 +96,25 @@ env.CppUnitTest( ) env.CppUnitTest( + target='service_state_machine_test', + source=[ + 'service_state_machine_test.cpp', + ], + LIBDEPS=[ + 'service_entry_point', + 'transport_layer_common', + 'transport_layer_mock', + '$BUILD_DIR/mongo/db/dbmessage', + '$BUILD_DIR/mongo/db/service_context', + '$BUILD_DIR/mongo/rpc/command_reply', + '$BUILD_DIR/mongo/rpc/command_request', + '$BUILD_DIR/mongo/unittest/unittest', + '$BUILD_DIR/mongo/util/clock_source_mock', + '$BUILD_DIR/mongo/util/decorable', + ], +) + +env.CppUnitTest( target='transport_layer_mock_test', source=[ 'transport_layer_mock_test.cpp', diff --git a/src/mongo/transport/mock_ticket.h b/src/mongo/transport/mock_ticket.h index 437daae23f6..bbf4eaa9df4 100644 --- a/src/mongo/transport/mock_ticket.h +++ b/src/mongo/transport/mock_ticket.h @@ -50,11 +50,11 @@ public: MockTicket(const SessionHandle& session, Message* message, Date_t expiration = Ticket::kNoExpirationDate) - : _id(session->id()), _message(message), _expiration(expiration) {} + : _session(session), _id(session->id()), _message(message), _expiration(expiration) {} // Sink constructor MockTicket(const SessionHandle& session, Date_t expiration = Ticket::kNoExpirationDate) - : _id(session->id()), _expiration(expiration) {} + : _session(session), _id(session->id()), _expiration(expiration) {} SessionId sessionId() const override { return _id; @@ -68,7 +68,12 @@ public: return _message; } + SessionHandle session() const { + return _session.lock(); + } + private: + std::weak_ptr<Session> _session; Session::Id _id; boost::optional<Message*> _message; Date_t _expiration; diff --git a/src/mongo/transport/service_entry_point_impl.cpp b/src/mongo/transport/service_entry_point_impl.cpp index e6db7f875cd..c802e0a8471 100644 --- a/src/mongo/transport/service_entry_point_impl.cpp +++ b/src/mongo/transport/service_entry_point_impl.cpp @@ -34,136 +34,36 @@ #include <vector> -#include "mongo/db/assemble_response.h" -#include "mongo/db/client.h" -#include "mongo/db/dbmessage.h" -#include "mongo/stdx/thread.h" #include "mongo/transport/service_entry_point_utils.h" +#include "mongo/transport/service_state_machine.h" #include "mongo/transport/session.h" -#include "mongo/transport/ticket.h" -#include "mongo/transport/transport_layer.h" -#include "mongo/util/concurrency/idle_thread_block.h" -#include "mongo/util/exit.h" -#include "mongo/util/log.h" -#include "mongo/util/net/message.h" -#include "mongo/util/net/socket_exception.h" -#include "mongo/util/net/thread_idle_callback.h" -#include "mongo/util/quick_exit.h" +#include "mongo/util/processinfo.h" #include "mongo/util/scopeguard.h" namespace mongo { -namespace { - -// Set up proper headers for formatting an exhaust request, if we need to -bool setExhaustMessage(Message* m, const DbResponse& dbresponse) { - MsgData::View header = dbresponse.response.header(); - QueryResult::View qr = header.view2ptr(); - long long cursorid = qr.getCursorId(); - - if (!cursorid) { - return false; - } - - verify(dbresponse.exhaustNS.size() && dbresponse.exhaustNS[0]); - - auto ns = dbresponse.exhaustNS; // reset() will free this - - m->reset(); - - BufBuilder b(512); - b.appendNum(static_cast<int>(0) /* size set later in appendData() */); - b.appendNum(header.getId()); - b.appendNum(header.getResponseToMsgId()); - b.appendNum(static_cast<int>(dbGetMore)); - b.appendNum(static_cast<int>(0)); - b.appendStr(ns); - b.appendNum(static_cast<int>(0)); // ntoreturn - b.appendNum(cursorid); - - MsgData::View(b.buf()).setLen(b.len()); - m->setData(b.release()); - - return true; -} - -} // namespace - -using transport::Session; -using transport::TransportLayer; void ServiceEntryPointImpl::startSession(transport::SessionHandle session) { // Pass ownership of the transport::SessionHandle into our worker thread. When this // thread exits, the session will end. - launchWrappedServiceEntryWorkerThread( - std::move(session), [this](const transport::SessionHandle& session) { - _nWorkers.fetchAndAdd(1); - auto guard = MakeGuard([&] { _nWorkers.fetchAndSubtract(1); }); - - _sessionLoop(session); - }); -} - -void ServiceEntryPointImpl::_sessionLoop(const transport::SessionHandle& session) { - Message inMessage; - bool inExhaust = false; - int64_t counter = 0; - - while (true) { - // 1. Source a Message from the client (unless we are exhausting) - if (!inExhaust) { - inMessage.reset(); - auto status = [&] { - MONGO_IDLE_THREAD_BLOCK; - return session->sourceMessage(&inMessage).wait(); - }(); - - if (ErrorCodes::isInterruption(status.code()) || - ErrorCodes::isNetworkError(status.code())) { - break; - } - - // Our session may have been closed internally. - if (status == TransportLayer::TicketSessionClosedStatus) { - break; - } - - uassertStatusOK(status); - } - - // 2. Pass sourced Message to handler to generate response. - auto opCtx = cc().makeOperationContext(); - - // The handleRequest is implemented in a subclass for mongod/mongos and actually all the - // database work for this request. - DbResponse dbresponse = this->handleRequest(opCtx.get(), inMessage, session->remote()); - - // opCtx must be destroyed here so that the operation cannot show - // up in currentOp results after the response reaches the client - opCtx.reset(); - - // 3. Format our response, if we have one - Message& toSink = dbresponse.response; - if (!toSink.empty()) { - toSink.header().setId(nextMessageId()); - toSink.header().setResponseToMsgId(inMessage.header().getId()); - - // If this is an exhaust cursor, don't source more Messages - if (dbresponse.exhaustNS.size() > 0 && setExhaustMessage(&inMessage, dbresponse)) { - inExhaust = true; - } else { - inExhaust = false; + launchServiceWorkerThread([ this, session = std::move(session) ]() mutable { + _nWorkers.addAndFetch(1); + const auto guard = MakeGuard([this] { _nWorkers.subtractAndFetch(1); }); + + ServiceStateMachine ssm(_svcCtx, std::move(session), true); + const auto numCores = [] { + ProcessInfo p; + if (auto availCores = p.getNumAvailableCores()) { + return static_cast<unsigned>(*availCores); } + return static_cast<unsigned>(p.getNumCores()); + }(); - // 4. Sink our response to the client - uassertStatusOK(session->sinkMessage(toSink).wait()); - } else { - inExhaust = false; - } - - if ((counter++ & 0xf) == 0) { - markThreadIdle(); + while (ssm.state() != ServiceStateMachine::State::Ended) { + ssm.runNext(); + if (_nWorkers.load() > numCores) + stdx::this_thread::yield(); } - } + }); } } // namespace mongo diff --git a/src/mongo/transport/service_entry_point_impl.h b/src/mongo/transport/service_entry_point_impl.h index aeb5ce5016e..6dd6047d9d2 100644 --- a/src/mongo/transport/service_entry_point_impl.h +++ b/src/mongo/transport/service_entry_point_impl.h @@ -32,16 +32,14 @@ #include "mongo/base/disallow_copying.h" #include "mongo/platform/atomic_word.h" +#include "mongo/stdx/mutex.h" #include "mongo/transport/service_entry_point.h" namespace mongo { - -struct DbResponse; -class OperationContext; +class ServiceContext; namespace transport { class Session; -class TransportLayer; } // namespace transport /** @@ -55,7 +53,7 @@ class ServiceEntryPointImpl : public ServiceEntryPoint { MONGO_DISALLOW_COPYING(ServiceEntryPointImpl); public: - explicit ServiceEntryPointImpl(transport::TransportLayer* tl) : _tl(tl) {} + explicit ServiceEntryPointImpl(ServiceContext* svcCtx) : _svcCtx(svcCtx) {} void startSession(transport::SessionHandle session) final; @@ -64,9 +62,7 @@ public: } private: - void _sessionLoop(const transport::SessionHandle& session); - - transport::TransportLayer* _tl; + ServiceContext* _svcCtx; AtomicWord<std::size_t> _nWorkers; }; diff --git a/src/mongo/transport/service_entry_point_utils.cpp b/src/mongo/transport/service_entry_point_utils.cpp index 61987c37972..1ce9230d104 100644 --- a/src/mongo/transport/service_entry_point_utils.cpp +++ b/src/mongo/transport/service_entry_point_utils.cpp @@ -32,16 +32,12 @@ #include "mongo/transport/service_entry_point_utils.h" -#include "mongo/db/client.h" -#include "mongo/db/server_options.h" +#include "mongo/stdx/functional.h" #include "mongo/stdx/memory.h" -#include "mongo/transport/session.h" -#include "mongo/transport/transport_layer.h" +#include "mongo/stdx/thread.h" #include "mongo/util/assert_util.h" #include "mongo/util/debug_util.h" #include "mongo/util/log.h" -#include "mongo/util/net/socket_exception.h" -#include "mongo/util/quick_exit.h" #ifdef __linux__ // TODO: consider making this ifndef _WIN32 #include <sys/resource.h> @@ -54,59 +50,16 @@ namespace mongo { namespace { - -/** - * This object takes ownership of transport::SessionHandle. - */ -struct Context { - Context(transport::SessionHandle session, - stdx::function<void(const transport::SessionHandle&)> task) - : session(std::move(session)), task(std::move(task)) {} - - transport::SessionHandle session; - stdx::function<void(const transport::SessionHandle&)> task; -}; - -void* runFunc(void* ptr) { - std::unique_ptr<Context> ctx(static_cast<Context*>(ptr)); - - auto client = getGlobalServiceContext()->makeClient("conn", ctx->session); - setThreadName(str::stream() << "conn" << ctx->session->id()); - - Client::setCurrent(std::move(client)); - - auto tl = ctx->session->getTransportLayer(); - - try { - ctx->task(ctx->session); - } catch (const AssertionException& e) { - log() << "AssertionException handling request, closing client connection: " << e; - } catch (const SocketException& e) { - log() << "SocketException handling request, closing client connection: " << e; - } catch (const DBException& e) { - // must be right above std::exception to avoid catching subclasses - log() << "DBException handling request, closing client connection: " << e; - } catch (const std::exception& e) { - error() << "Uncaught std::exception: " << e.what() << ", terminating"; - quickExit(EXIT_UNCAUGHT); - } - - tl->end(ctx->session); - - if (!serverGlobalParams.quiet.load()) { - auto conns = tl->sessionStats().numOpenSessions; - const char* word = (conns == 1 ? " connection" : " connections"); - log() << "end connection " << ctx->session->remote() << " (" << conns << word - << " now open)"; - } +void* runFunc(void* ctx) { + std::unique_ptr<stdx::function<void()>> taskPtr(static_cast<stdx::function<void()>*>(ctx)); + (*taskPtr)(); return nullptr; } } // namespace -void launchWrappedServiceEntryWorkerThread( - transport::SessionHandle session, stdx::function<void(const transport::SessionHandle&)> task) { - auto ctx = stdx::make_unique<Context>(std::move(session), std::move(task)); +void launchServiceWorkerThread(stdx::function<void()> task) { + auto ctx = stdx::make_unique<stdx::function<void()>>(std::move(task)); try { #ifndef __linux__ // TODO: consider making this ifdef _WIN32 @@ -152,7 +105,7 @@ void launchWrappedServiceEntryWorkerThread( #endif // __linux__ } catch (...) { - log() << "failed to create service entry worker thread for " << ctx->session->remote(); + log() << "failed to create service entry worker thread"; } } diff --git a/src/mongo/transport/service_entry_point_utils.h b/src/mongo/transport/service_entry_point_utils.h index 1c1634af6d5..79f6dbd0171 100644 --- a/src/mongo/transport/service_entry_point_utils.h +++ b/src/mongo/transport/service_entry_point_utils.h @@ -33,7 +33,6 @@ namespace mongo { -void launchWrappedServiceEntryWorkerThread( - transport::SessionHandle session, stdx::function<void(const transport::SessionHandle&)> task); +void launchServiceWorkerThread(stdx::function<void()> task); } // namespace mongo diff --git a/src/mongo/transport/service_state_machine.cpp b/src/mongo/transport/service_state_machine.cpp new file mode 100644 index 00000000000..7af9171426f --- /dev/null +++ b/src/mongo/transport/service_state_machine.cpp @@ -0,0 +1,371 @@ +/** + * Copyright (C) 2017 MongoDB Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, version 3, + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the GNU Affero General Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#define MONGO_LOG_DEFAULT_COMPONENT ::mongo::logger::LogComponent::kNetwork + +#include "mongo/platform/basic.h" + +#include "mongo/transport/service_state_machine.h" + +#include "mongo/db/assemble_response.h" +#include "mongo/db/client.h" +#include "mongo/db/dbmessage.h" +#include "mongo/stdx/memory.h" +#include "mongo/transport/service_entry_point.h" +#include "mongo/transport/session.h" +#include "mongo/transport/ticket.h" +#include "mongo/transport/transport_layer.h" +#include "mongo/util/assert_util.h" +#include "mongo/util/concurrency/idle_thread_block.h" +#include "mongo/util/concurrency/thread_name.h" +#include "mongo/util/debug_util.h" +#include "mongo/util/exit.h" +#include "mongo/util/log.h" +#include "mongo/util/net/message.h" +#include "mongo/util/net/socket_exception.h" +#include "mongo/util/net/thread_idle_callback.h" +#include "mongo/util/quick_exit.h" +#include "mongo/util/scopeguard.h" + +namespace mongo { +namespace { +// Set up proper headers for formatting an exhaust request, if we need to +bool setExhaustMessage(Message* m, const DbResponse& dbresponse) { + MsgData::View header = dbresponse.response.header(); + QueryResult::View qr = header.view2ptr(); + long long cursorid = qr.getCursorId(); + + if (!cursorid) { + return false; + } + + invariant(dbresponse.exhaustNS.size() && dbresponse.exhaustNS[0]); + + auto ns = dbresponse.exhaustNS; // m->reset() will free this so we must cache a copy + + m->reset(); + + // Rebuild out the response. + BufBuilder b(512); + b.appendNum(static_cast<int>(0) /* size set later in setLen() */); + b.appendNum(header.getId()); // message id + b.appendNum(header.getResponseToMsgId()); // in response to + b.appendNum(static_cast<int>(dbGetMore)); // opCode is OP_GET_MORE + b.appendNum(static_cast<int>(0)); // Must be ZERO (reserved) + b.appendStr(ns); // Namespace + b.appendNum(static_cast<int>(0)); // ntoreturn + b.appendNum(cursorid); // cursor id from the OP_REPLY + + MsgData::View(b.buf()).setLen(b.len()); + m->setData(b.release()); + + return true; +} + +} // namespace + +using transport::TransportLayer; +ServiceStateMachine::ServiceStateMachine(ServiceContext* svcContext, + transport::SessionHandle session, + bool sync) + : _state{State::Source}, + _sep{svcContext->getServiceEntryPoint()}, + _sync(sync), + _dbClient{svcContext->makeClient("conn", std::move(session))}, + _dbClientPtr{_dbClient.get()}, + _threadName{str::stream() << "conn" << _dbClient->session()->id()}, + _currentOwningThread{stdx::this_thread::get_id()} {} + +const transport::SessionHandle& ServiceStateMachine::session() const { + // The _dbClientPtr should always point to our Client which should always own our SessionHandle + return _dbClientPtr->session(); +} + +void ServiceStateMachine::sourceCallback(Status status) { + // Make sure we just called sourceMessage(); + invariant(_state == State::SourceWait); + auto remote = session()->remote(); + + if (status.isOK()) { + _state = State::Process; + } else if (ErrorCodes::isInterruption(status.code()) || + ErrorCodes::isNetworkError(status.code())) { + LOG(2) << "Session from " << remote << " encountered a network error during SourceMessage"; + _state = State::EndSession; + } else if (status == TransportLayer::TicketSessionClosedStatus) { + // Our session may have been closed internally. + LOG(2) << "Session from " << remote << " was closed internally during SourceMessage"; + _state = State::EndSession; + } else { + log() << "Error receiving request from client: " << status << ". Ending connection from " + << remote << " (connection id: " << session()->id() << ")"; + _state = State::EndSession; + } + + // In asyncronous mode this is the entrypoint back into the database from the network layer + // after a message has been received, so we want to call runNext() to process the message. + // + // In synchronous mode, runNext() will fall through to call processMessage() so we avoid + // the recursive call. + if (!_sync) + return runNext(); +} + +void ServiceStateMachine::sinkCallback(Status status) { + invariant(_state == State::SinkWait); + + if (!status.isOK()) { + log() << "Error sending response to client: " << status << ". Ending connection from " + << session()->remote() << " (connection id: " << session()->id() << ")"; + _state = State::EndSession; + } else if (inExhaust) { + _state = State::Process; + } else { + _state = State::Source; + } + + return scheduleNext(); +} + +void ServiceStateMachine::processMessage() { + // This may have been called just after a failure to source a message, in which case this + // should return early so the session can be cleaned up. + if (_state != State::Process) { + return; + } + invariant(!_inMessage.empty()); + + // 2. Pass sourced Message to handler to generate response. + auto opCtx = cc().makeOperationContext(); + + // The handleRequest is implemented in a subclass for mongod/mongos and actually all the + // database work for this request. + DbResponse dbresponse = _sep->handleRequest(opCtx.get(), _inMessage, session()->remote()); + + // opCtx must be destroyed here so that the operation cannot show + // up in currentOp results after the response reaches the client + opCtx.reset(); + + // 3. Format our response, if we have one + Message& toSink = dbresponse.response; + if (!toSink.empty()) { + toSink.header().setId(nextMessageId()); + toSink.header().setResponseToMsgId(_inMessage.header().getId()); + + // If this is an exhaust cursor, don't source more Messages + if (dbresponse.exhaustNS.size() > 0 && setExhaustMessage(&_inMessage, dbresponse)) { + inExhaust = true; + } else { + inExhaust = false; + _inMessage.reset(); + } + + // 4. Sink our response to the client + auto ticket = session()->sinkMessage(toSink); + _state = State::SinkWait; + if (_sync) { + sinkCallback(session()->getTransportLayer()->wait(std::move(ticket))); + } else { + session()->getTransportLayer()->asyncWait( + std::move(ticket), [this](Status status) { sinkCallback(status); }); + } + } else { + _state = State::Source; + _inMessage.reset(); + return scheduleNext(); + } +} + +/* + * This class wraps up the logic for swapping/unswapping the Client during runNext(). + */ +class ServiceStateMachine::ThreadGuard { + ThreadGuard(ThreadGuard&) = delete; + ThreadGuard& operator=(ThreadGuard&) = delete; + +public: + explicit ThreadGuard(ServiceStateMachine* ssm) + : _ssm{ssm}, + _haveTakenOwnership{!_ssm->_isOwned.test_and_set()}, + _oldThreadName{getThreadName().toString()} { + const auto currentOwningThread = _ssm->_currentOwningThread.load(); + const auto currentThreadId = stdx::this_thread::get_id(); + + // If this is true, then we are the "owner" of the Client and we should swap the + // client/thread name before doing any work. + if (_haveTakenOwnership) { + _ssm->_currentOwningThread.store(currentThreadId); + + // Set up the thread name + setThreadName(_ssm->_threadName); + + // These are sanity checks to make sure that the Client is what we expect it to be + invariant(!haveClient()); + invariant(_ssm->_dbClient.get() == _ssm->_dbClientPtr); + + // Swap the current Client so calls to cc() work as expected + Client::setCurrent(std::move(_ssm->_dbClient)); + } else if (currentOwningThread != currentThreadId) { + // If the currentOwningThread does not equal the currentThreadId, then another thread + // currently "owns" the Client and we should reschedule ourself. + _okayToRunNext = false; + } + } + + ~ThreadGuard() { + if (!_haveTakenOwnership) + return; + + if (haveClient()) { + _ssm->_dbClient = Client::releaseCurrent(); + } + setThreadName(_oldThreadName); + _ssm->_isOwned.clear(); + } + + void dismiss() { + _haveTakenOwnership = false; + } + + explicit operator bool() const { + return _okayToRunNext; + } + +private: + ServiceStateMachine* _ssm; + bool _haveTakenOwnership; + const std::string _oldThreadName; + bool _okayToRunNext = true; +}; + +void ServiceStateMachine::runNext() { + ThreadGuard guard(this); + if (!guard) + return scheduleNext(); + + // Make sure the current Client got set correctly + invariant(Client::getCurrent() == _dbClientPtr); + try { + switch (_state) { + case State::Source: { + invariant(_inMessage.empty()); + + auto ticket = session()->sourceMessage(&_inMessage); + _state = State::SourceWait; + if (_sync) { + MONGO_IDLE_THREAD_BLOCK; + sourceCallback(session()->getTransportLayer()->wait(std::move(ticket))); + } else { + session()->getTransportLayer()->asyncWait( + std::move(ticket), [this](Status status) { sourceCallback(status); }); + break; + } + } + case State::Process: + processMessage(); + break; + case State::EndSession: + // This will get handled below in an if statement. That way if an error occurs + // you don't have to call runNext() again to clean up the session. + break; + default: + MONGO_UNREACHABLE; + } + + if (_state == State::EndSession) { + guard.dismiss(); + endSession(); + } + + if ((_counter++ & 0xf) == 0) { + markThreadIdle(); + }; + return; + } catch (const AssertionException& e) { + log() << "AssertionException handling request, closing client connection: " << e; + } catch (const SocketException& e) { + log() << "SocketException handling request, closing client connection: " << e; + } catch (const DBException& e) { + // must be right above std::exception to avoid catching subclasses + log() << "DBException handling request, closing client connection: " << e; + } catch (const std::exception& e) { + error() << "Uncaught std::exception: " << e.what() << ", terminating"; + quickExit(EXIT_UNCAUGHT); + } + + _state = State::EndSession; + guard.dismiss(); + endSession(); +} + +// TODO: Right now this is a noop because we only run in synchronous mode. When an async +// TransportLayer is written, this will call the serviceexecutor to schedule calls to runNext(). +void ServiceStateMachine::scheduleNext() {} + +void ServiceStateMachine::endSession() { + auto tl = session()->getTransportLayer(); + + _inMessage.reset(); + auto remote = session()->remote(); + + Client::releaseCurrent(); + + if (!serverGlobalParams.quiet.load()) { + auto conns = tl->sessionStats().numOpenSessions; + const char* word = (conns == 1 ? " connection" : " connections"); + log() << "end connection " << remote << " (" << conns << word << " now open)"; + } + + _state = State::Ended; +} + +std::ostream& operator<<(std::ostream& stream, const ServiceStateMachine::State& state) { + switch (state) { + case ServiceStateMachine::State::Source: + stream << "source"; + break; + case ServiceStateMachine::State::SourceWait: + stream << "sourceWait"; + break; + case ServiceStateMachine::State::Process: + stream << "process"; + break; + case ServiceStateMachine::State::SinkWait: + stream << "sinkWait"; + break; + case ServiceStateMachine::State::EndSession: + stream << "endSession"; + break; + case ServiceStateMachine::State::Ended: + stream << "ended"; + break; + } + return stream; +} + +} // namespace mongo diff --git a/src/mongo/transport/service_state_machine.h b/src/mongo/transport/service_state_machine.h new file mode 100644 index 00000000000..31ef1ee4028 --- /dev/null +++ b/src/mongo/transport/service_state_machine.h @@ -0,0 +1,152 @@ +/** + * Copyright (C) 2017 MongoDB Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, version 3, + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the GNU Affero General Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include <atomic> + +#include "mongo/base/status.h" +#include "mongo/db/service_context.h" +#include "mongo/platform/atomic_word.h" +#include "mongo/stdx/mutex.h" +#include "mongo/stdx/thread.h" +#include "mongo/transport/session.h" + +namespace mongo { +class ServiceEntryPoint; + +namespace transport { +class ServiceExecutorBase; +} // namespace transport + +/* + * The ServiceStateMachine holds the state of a single client connection and represents the + * lifecycle of each user request as a state machine. It is the glue between the stateless + * ServiceEntryPoint and TransportLayer that ties network and database logic together for a + * user. + */ +class ServiceStateMachine { + ServiceStateMachine(ServiceStateMachine&) = delete; + ServiceStateMachine& operator=(ServiceStateMachine&) = delete; + +public: + ServiceStateMachine() = default; + ServiceStateMachine(ServiceStateMachine&&) = default; + ServiceStateMachine& operator=(ServiceStateMachine&&) = default; + + ServiceStateMachine(ServiceContext* svcContext, transport::SessionHandle session, bool sync); + + /* + * Any state may transition to EndSession in case of an error, otherwise the valid state + * transitions are: + * Source -> SourceWait -> Process -> SinkWait -> Source (standard RPC) + * Source -> SourceWait -> Process -> SinkWait -> Process -> SinkWait ... (exhaust) + * Source -> SourceWait -> Process -> Source (fire-and-forget) + */ + enum class State { + Source, // Request a new Message from the network to handle + SourceWait, // Wait for the new Message to arrive from the network + Process, // Run the Message through the database + SinkWait, // Wait for the database result to be sent by the network + EndSession, // End the session - the ServiceStateMachine will be invalid after this + Ended // The session has ended. It is illegal to call any method besides + // state() if this is the current state. + }; + + /* + * runNext() will run the current state of the state machine. It also handles all the error + * handling and state management for requests. + * + * Each state function (processMessage(), sinkCallback(), etc) should always unwind the stack + * if they have just completed a database operation to make sure that this doesn't infinitely + * recurse. + */ + void runNext(); + + /* + * scheduleNext() schedules a call to runNext() in the future. This will be implemented with + * an async TransportLayer. + * + * It is guaranteed to unwind the stack, and not call runNext() recursively, but is not + * guaranteed that runNext() will run after this returns. + */ + void scheduleNext(); + + /* + * Gets the current state of connection for testing/diagnostic purposes. + */ + State state() const { + return _state; + } + + /* + * Explicitly ends the session. + */ + void endSession(); + +private: + /* + * This function actually calls into the database and processes a request. It's broken out + * into its own inline function for better readability. + */ + inline void processMessage(); + + /* + * These get called by the TransportLayer when requested network I/O has completed. + */ + void sourceCallback(Status status); + void sinkCallback(Status status); + + /* + * A class that wraps up lifetime management of the _dbClient and _threadName for runNext(); + */ + class ThreadGuard; + friend class ThreadGuard; + + const transport::SessionHandle& session() const; + + State _state{State::Source}; + + ServiceEntryPoint* _sep; + bool _sync; + + ServiceContext::UniqueClient _dbClient; + const Client* _dbClientPtr; + const std::string _threadName; + + bool inExhaust = false; + Message _inMessage; + int64_t _counter = 0; + + AtomicWord<stdx::thread::id> _currentOwningThread; + std::atomic_flag _isOwned = ATOMIC_FLAG_INIT; // NOLINT +}; + +std::ostream& operator<<(std::ostream& stream, const ServiceStateMachine::State& state); + +} // namespace mongo diff --git a/src/mongo/transport/service_state_machine_test.cpp b/src/mongo/transport/service_state_machine_test.cpp new file mode 100644 index 00000000000..338a111d5e0 --- /dev/null +++ b/src/mongo/transport/service_state_machine_test.cpp @@ -0,0 +1,295 @@ +/** + * Copyright (C) 2017 MongoDB Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, version 3, + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the GNU Affero General Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#define MONGO_LOG_DEFAULT_COMPONENT ::mongo::logger::LogComponent::kNetwork + +#include "mongo/platform/basic.h" + +#include "mongo/base/checked_cast.h" +#include "mongo/bson/bsonobj.h" +#include "mongo/bson/bsonobjbuilder.h" +#include "mongo/db/dbmessage.h" +#include "mongo/db/service_context_noop.h" +#include "mongo/rpc/command_reply.h" +#include "mongo/rpc/command_reply_builder.h" +#include "mongo/rpc/command_request.h" +#include "mongo/rpc/command_request_builder.h" +#include "mongo/stdx/memory.h" +#include "mongo/transport/mock_session.h" +#include "mongo/transport/mock_ticket.h" +#include "mongo/transport/service_entry_point.h" +#include "mongo/transport/service_state_machine.h" +#include "mongo/transport/transport_layer_mock.h" +#include "mongo/unittest/unittest.h" +#include "mongo/util/assert_util.h" +#include "mongo/util/clock_source_mock.h" +#include "mongo/util/log.h" +#include "mongo/util/tick_source_mock.h" + +namespace mongo { +namespace { +class MockSEP : public ServiceEntryPoint { +public: + virtual ~MockSEP() = default; + + void startSession(transport::SessionHandle session) override {} + + DbResponse handleRequest(OperationContext* opCtx, + const Message& request, + const HostAndPort& client) override { + log() << "In handleRequest"; + _ranHandler = true; + ASSERT_TRUE(haveClient()); + + rpc::CommandRequest req(&request); + ASSERT_BSONOBJ_EQ(BSON("ping" << 1), req.getCommandArgs()); + + // Build out a dummy reply + rpc::CommandReplyBuilder builder; + builder.setRawCommandReply(BSON("ok" << 1)); + builder.setMetadata(BSONObj{}); + + if (_uassertInHandler) + uassert(40469, "Synthetic uassert failure", false); + + return DbResponse{builder.done()}; + } + + void setUassertInHandler() { + _uassertInHandler = true; + } + + bool ranHandler() { + bool ret = _ranHandler; + _ranHandler = false; + return ret; + } + +private: + bool _uassertInHandler = false; + bool _ranHandler = false; +}; + +using namespace transport; +class MockTL : public TransportLayerMock { +public: + ~MockTL() = default; + + Ticket sourceMessage(const SessionHandle& session, + Message* message, + Date_t expiration = Ticket::kNoExpirationDate) override { + ASSERT_EQ(_ssm->state(), ServiceStateMachine::State::Source); + _lastTicketSource = true; + + _ranSource = true; + log() << "In sourceMessage"; + + if (_nextShouldFail & Source) { + return TransportLayer::TicketSessionClosedStatus; + } + + if (_nextMessage) { + *message = *_nextMessage; + } + + return TransportLayerMock::sourceMessage(session, message, expiration); + } + + Ticket sinkMessage(const SessionHandle& session, + const Message& message, + Date_t expiration = Ticket::kNoExpirationDate) override { + ASSERT_EQ(_ssm->state(), ServiceStateMachine::State::Process); + _lastTicketSource = false; + + log() << "In sinkMessage"; + _ranSink = true; + + if (_nextShouldFail & Sink) { + return TransportLayer::TicketSessionClosedStatus; + } + + _lastSunk = message; + + return TransportLayerMock::sinkMessage(session, message, expiration); + } + + Status wait(Ticket&& ticket) override { + if (!ticket.valid()) { + return ticket.status(); + } + ASSERT_EQ(_ssm->state(), + _lastTicketSource ? ServiceStateMachine::State::SourceWait + : ServiceStateMachine::State::SinkWait); + std::stringstream ss; + ss << _ssm->state(); + log() << "In wait. ssm state: " << ss.str(); + return TransportLayerMock::wait(std::move(ticket)); + } + + void asyncWait(Ticket&& ticket, TicketCallback callback) override { + MONGO_UNREACHABLE; + } + + void setNextMessage(Message&& message) { + _nextMessage = std::move(message); + } + + void setSSM(ServiceStateMachine* ssm) { + _ssm = ssm; + } + + enum FailureMode { Nothing = 0, Source = 0x1, Sink = 0x10 }; + + void setNextFailure(FailureMode mode = Source) { + _nextShouldFail = mode; + } + + Message&& getLastSunk() { + return std::move(_lastSunk); + } + + bool ranSink() const { + return _ranSink; + } + + bool ranSource() const { + return _ranSource; + } + +private: + bool _lastTicketSource = true; + bool _ranSink = false; + bool _ranSource = false; + boost::optional<Message> _nextMessage; + FailureMode _nextShouldFail = Nothing; + Message _lastSunk; + ServiceStateMachine* _ssm; +}; + +Message buildRequest(BSONObj input) { + rpc::CommandRequestBuilder builder; + builder.setDatabase("admin"); + builder.setCommandName("ping"); + builder.setCommandArgs(input); + builder.setMetadata(BSONObj{}); + + return builder.done(); +} + +} // namespace + +class ServiceStateMachineFixture : public unittest::Test { +protected: + void setUp() override { + + auto scOwned = stdx::make_unique<ServiceContextNoop>(); + auto sc = scOwned.get(); + setGlobalServiceContext(std::move(scOwned)); + + sc->setTickSource(stdx::make_unique<TickSourceMock>()); + sc->setFastClockSource(stdx::make_unique<ClockSourceMock>()); + + auto sep = stdx::make_unique<MockSEP>(); + _sep = sep.get(); + sc->setServiceEntryPoint(std::move(sep)); + + auto tl = stdx::make_unique<MockTL>(); + _tl = tl.get(); + sc->addAndStartTransportLayer(std::move(tl)); + + _ssm = stdx::make_unique<ServiceStateMachine>( + getGlobalServiceContext(), _tl->createSession(), true); + _tl->setSSM(_ssm.get()); + } + + void tearDown() override { + getGlobalServiceContext()->getTransportLayer()->shutdown(); + } + + ServiceStateMachine::State runPingTest(); + void checkPingOk(); + + MockTL* _tl; + MockSEP* _sep; + SessionHandle _session; + std::unique_ptr<ServiceStateMachine> _ssm; + bool _ranHandler; +}; + +ServiceStateMachine::State ServiceStateMachineFixture::runPingTest() { + _tl->setNextMessage(buildRequest(BSON("ping" << 1))); + + ASSERT_FALSE(haveClient()); + ASSERT_EQ(_ssm->state(), ServiceStateMachine::State::Source); + log() << "run next"; + _ssm->runNext(); + auto ret = _ssm->state(); + ASSERT_FALSE(haveClient()); + + return ret; +} + +void ServiceStateMachineFixture::checkPingOk() { + auto msg = _tl->getLastSunk(); + rpc::CommandReply reply(&msg); + + ASSERT_BSONOBJ_EQ(reply.getCommandReply(), BSON("ok" << 1)); +} + +TEST_F(ServiceStateMachineFixture, TestOkaySimpleCommand) { + ASSERT_EQ(ServiceStateMachine::State::Source, runPingTest()); + checkPingOk(); +} + +TEST_F(ServiceStateMachineFixture, TestThrowHandling) { + _sep->setUassertInHandler(); + + ASSERT_EQ(ServiceStateMachine::State::Ended, runPingTest()); + ASSERT_THROWS(checkPingOk(), MsgAssertionException); + ASSERT_TRUE(_tl->ranSource()); + ASSERT_FALSE(_tl->ranSink()); +} + +TEST_F(ServiceStateMachineFixture, TestSourceError) { + _tl->setNextFailure(MockTL::Source); + + ASSERT_EQ(ServiceStateMachine::State::Ended, runPingTest()); + ASSERT_THROWS(checkPingOk(), MsgAssertionException); + ASSERT_TRUE(_tl->ranSource()); + ASSERT_FALSE(_tl->ranSink()); +} + +TEST_F(ServiceStateMachineFixture, TestSinkError) { + _tl->setNextFailure(MockTL::Sink); + + ASSERT_EQ(ServiceStateMachine::State::Ended, runPingTest()); + ASSERT_TRUE(_tl->ranSource()); + ASSERT_TRUE(_tl->ranSink()); +} + +} // namespace mongo diff --git a/src/mongo/transport/ticket.cpp b/src/mongo/transport/ticket.cpp index f7ee14778c4..d1003957b80 100644 --- a/src/mongo/transport/ticket.cpp +++ b/src/mongo/transport/ticket.cpp @@ -54,10 +54,22 @@ Ticket::Ticket(Ticket&&) = default; Ticket& Ticket::operator=(Ticket&&) = default; Status Ticket::wait()&& { + // If the ticket is invalid then _tl is a nullptr and we should return early. + if (!valid()) + return status(); + + invariant(_tl); return _tl->wait(std::move(*this)); } void Ticket::asyncWait(TicketCallback cb)&& { + // If the ticket is invalid then _tl is a nullptr and we should return early. + if (!valid()) { + cb(status()); + return; + } + + invariant(_tl); return _tl->asyncWait(std::move(*this), std::move(cb)); } diff --git a/src/mongo/transport/ticket.h b/src/mongo/transport/ticket.h index 63b064d104b..a1ace6f3a08 100644 --- a/src/mongo/transport/ticket.h +++ b/src/mongo/transport/ticket.h @@ -103,6 +103,9 @@ public: * Asynchronously wait for this ticket to be filled. * * This is this-rvalue qualified because it consumes the ticket + * + * If the ticket has expired or is not valid when asyncWait is called, cb will be called + * immediately and inline with the error status. */ void asyncWait(TicketCallback cb) &&; @@ -135,7 +138,7 @@ public: } private: - TransportLayer* _tl; + TransportLayer* _tl = nullptr; Status _status = Status::OK(); std::unique_ptr<TicketImpl> _ticket; }; diff --git a/src/mongo/transport/transport_layer_mock.cpp b/src/mongo/transport/transport_layer_mock.cpp index 6dfd6083757..88784056fbf 100644 --- a/src/mongo/transport/transport_layer_mock.cpp +++ b/src/mongo/transport/transport_layer_mock.cpp @@ -86,7 +86,7 @@ Status TransportLayerMock::wait(Ticket&& ticket) { } void TransportLayerMock::asyncWait(Ticket&& ticket, TicketCallback callback) { - callback(Status::OK()); + callback(wait(std::move(ticket))); } TransportLayer::Stats TransportLayerMock::sessionStats() { |