diff options
-rw-r--r-- | src/mongo/db/request_execution_context.h | 11 | ||||
-rw-r--r-- | src/mongo/db/service_entry_point_common.cpp | 5 | ||||
-rw-r--r-- | src/mongo/s/commands/cluster_command_test_fixture.cpp | 4 | ||||
-rw-r--r-- | src/mongo/s/commands/strategy.cpp | 8 | ||||
-rw-r--r-- | src/mongo/s/commands/strategy.h | 3 | ||||
-rw-r--r-- | src/mongo/s/service_entry_point_mongos.cpp | 301 |
6 files changed, 215 insertions, 117 deletions
diff --git a/src/mongo/db/request_execution_context.h b/src/mongo/db/request_execution_context.h index b34c7e4f3b7..44f06f36369 100644 --- a/src/mongo/db/request_execution_context.h +++ b/src/mongo/db/request_execution_context.h @@ -57,22 +57,21 @@ public: RequestExecutionContext(const RequestExecutionContext&) = delete; RequestExecutionContext(RequestExecutionContext&&) = delete; - explicit RequestExecutionContext(OperationContext* opCtx) : _opCtx(opCtx) {} + RequestExecutionContext(OperationContext* opCtx, Message message) + : _opCtx(opCtx), + _message(std::move(message)), + _dbmsg(std::make_unique<DbMessage>(_message.get())) {} auto getOpCtx() const { invariant(_isOnClientThread()); return _opCtx; } - void setMessage(Message message) { - invariant(_isOnClientThread() && !_message); - _message = std::move(message); - _dbmsg = std::make_unique<DbMessage>(_message.get()); - } const Message& getMessage() const { invariant(_isOnClientThread() && _message); return _message.get(); } + DbMessage& getDbMessage() const { invariant(_isOnClientThread() && _dbmsg); return *_dbmsg.get(); diff --git a/src/mongo/db/service_entry_point_common.cpp b/src/mongo/db/service_entry_point_common.cpp index af22da58f4e..c6e26cfc796 100644 --- a/src/mongo/db/service_entry_point_common.cpp +++ b/src/mongo/db/service_entry_point_common.cpp @@ -141,10 +141,7 @@ struct HandleRequest { ExecutionContext(OperationContext* opCtx, Message msg, std::unique_ptr<const ServiceEntryPointCommon::Hooks> hooks) - : RequestExecutionContext(opCtx), behaviors(std::move(hooks)) { - // It also initializes dbMessage, which is accessible via getDbMessage() - setMessage(std::move(msg)); - } + : RequestExecutionContext(opCtx, std::move(msg)), behaviors(std::move(hooks)) {} ~ExecutionContext() = default; Client& client() const { diff --git a/src/mongo/s/commands/cluster_command_test_fixture.cpp b/src/mongo/s/commands/cluster_command_test_fixture.cpp index 31bee683dbf..c182dca1514 100644 --- a/src/mongo/s/commands/cluster_command_test_fixture.cpp +++ b/src/mongo/s/commands/cluster_command_test_fixture.cpp @@ -125,7 +125,9 @@ DbResponse ClusterCommandTestFixture::runCommand(BSONObj cmd) { auto clusterGLE = ClusterLastErrorInfo::get(client.get()); clusterGLE->newRequest(); - return Strategy::clientCommand(opCtx.get(), opMsgRequest.serialize()); + AlternativeClientRegion acr(client); + auto rec = std::make_shared<RequestExecutionContext>(opCtx.get(), opMsgRequest.serialize()); + return Strategy::clientCommand(std::move(rec)).get(); } void ClusterCommandTestFixture::runCommandSuccessful(BSONObj cmd, bool isTargeted) { diff --git a/src/mongo/s/commands/strategy.cpp b/src/mongo/s/commands/strategy.cpp index f3f27506da7..eb386f3763b 100644 --- a/src/mongo/s/commands/strategy.cpp +++ b/src/mongo/s/commands/strategy.cpp @@ -1028,7 +1028,9 @@ DbResponse Strategy::queryOp(OperationContext* opCtx, const NamespaceString& nss cursorId)}; } -DbResponse Strategy::clientCommand(OperationContext* opCtx, const Message& m) { +Future<DbResponse> Strategy::clientCommand(std::shared_ptr<RequestExecutionContext> rec) try { + auto opCtx = rec->getOpCtx(); + const Message& m = rec->getMessage(); auto reply = rpc::makeReplyBuilder(rpc::protocolForMessage(m)); BSONObjBuilder errorBuilder; @@ -1098,7 +1100,7 @@ DbResponse Strategy::clientCommand(OperationContext* opCtx, const Message& m) { } if (OpMsg::isFlagSet(m, OpMsg::kMoreToCome)) { - return {}; // Don't reply. + return DbResponse{}; // Don't reply. } DbResponse dbResponse; @@ -1112,6 +1114,8 @@ DbResponse Strategy::clientCommand(OperationContext* opCtx, const Message& m) { dbResponse.response = reply->done(); return dbResponse; +} catch (const DBException& e) { + return e.toStatus(); } DbResponse Strategy::getMore(OperationContext* opCtx, const NamespaceString& nss, DbMessage* dbm) { diff --git a/src/mongo/s/commands/strategy.h b/src/mongo/s/commands/strategy.h index 3807a63a7f6..85fafda1acb 100644 --- a/src/mongo/s/commands/strategy.h +++ b/src/mongo/s/commands/strategy.h @@ -33,6 +33,7 @@ #include "mongo/client/connection_string.h" #include "mongo/db/query/explain_options.h" +#include "mongo/db/request_execution_context.h" #include "mongo/s/client/shard.h" namespace mongo { @@ -82,7 +83,7 @@ public: * Catches StaleConfigException errors and retries the command automatically after refreshing * the metadata for the failing namespace. */ - static DbResponse clientCommand(OperationContext* opCtx, const Message& message); + static Future<DbResponse> clientCommand(std::shared_ptr<RequestExecutionContext> rec); /** * Helper to run an explain of a find operation on the shards. Fills 'out' with the result of diff --git a/src/mongo/s/service_entry_point_mongos.cpp b/src/mongo/s/service_entry_point_mongos.cpp index a10cec0ee7b..870d58968aa 100644 --- a/src/mongo/s/service_entry_point_mongos.cpp +++ b/src/mongo/s/service_entry_point_mongos.cpp @@ -29,6 +29,8 @@ #define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kNetwork +#include <memory> + #include "mongo/platform/basic.h" #include "mongo/s/service_entry_point_mongos.h" @@ -40,12 +42,12 @@ #include "mongo/db/dbmessage.h" #include "mongo/db/lasterror.h" #include "mongo/db/operation_context.h" +#include "mongo/db/request_execution_context.h" #include "mongo/db/service_context.h" #include "mongo/logv2/log.h" #include "mongo/rpc/message.h" #include "mongo/s/cluster_last_error_info.h" #include "mongo/s/commands/strategy.h" -#include "mongo/util/scopeguard.h" namespace mongo { @@ -60,16 +62,50 @@ BSONObj buildErrReply(const DBException& ex) { } // namespace +// Allows for decomposing `handleRequest` into parts and simplifies composing the future-chain. +struct HandleRequest : public std::enable_shared_from_this<HandleRequest> { + struct OpRunnerBase; -Future<DbResponse> ServiceEntryPointMongos::handleRequest(OperationContext* opCtx, - const Message& message) noexcept try { - const int32_t msgId = message.header().getId(); - const NetworkOp op = message.operation(); + HandleRequest(OperationContext* opCtx, const Message& message) + : rec(std::make_shared<RequestExecutionContext>(opCtx, message)), + op(message.operation()), + msgId(message.header().getId()), + nsString(getNamespaceString(rec->getDbMessage())) {} + + // Prepares the environment for handling the request (e.g., setting up `ClusterLastErrorInfo`). + void setupEnvironment(); + + // Returns a future that does the heavy lifting of running client commands. + Future<DbResponse> handleRequest(); + + // Runs on successful execution of the future returned by `handleRequest`. + void onSuccess(const DbResponse&); + + // Returns a future-chain to handle the request and prepare the response. + Future<DbResponse> run(); + + static NamespaceString getNamespaceString(const DbMessage& dbmsg) { + if (!dbmsg.messageShouldHaveNs()) + return {}; + return NamespaceString(dbmsg.getns()); + } + + const std::shared_ptr<RequestExecutionContext> rec; + const NetworkOp op; + const int32_t msgId; + const NamespaceString nsString; + + boost::optional<long long> slowMsOverride; +}; + +void HandleRequest::setupEnvironment() { + using namespace fmt::literals; + auto opCtx = rec->getOpCtx(); // This exception will not be returned to the caller, but will be logged and will close the // connection uassert(ErrorCodes::IllegalOperation, - str::stream() << "Message type " << op << " is not supported.", + "Message type {} is not supported."_format(op), isSupportedRequestNetworkOp(op) && op != dbCompressed); // Decompression should be handled above us. @@ -84,115 +120,174 @@ Future<DbResponse> ServiceEntryPointMongos::handleRequest(OperationContext* opCt AuthorizationSession::get(opCtx->getClient())->startRequest(opCtx); CurOp::get(opCtx)->ensureStarted(); +} - DbMessage dbm(message); +// The base for various operation runners that handle the request, and often generate a DbResponse. +struct HandleRequest::OpRunnerBase { + explicit OpRunnerBase(std::shared_ptr<HandleRequest> hr) : hr(std::move(hr)) {} + virtual ~OpRunnerBase() = default; + virtual Future<DbResponse> run() = 0; + const std::shared_ptr<HandleRequest> hr; +}; + +struct CommandOpRunner final : public HandleRequest::OpRunnerBase { + using HandleRequest::OpRunnerBase::OpRunnerBase; + Future<DbResponse> run() override { + return Strategy::clientCommand(hr->rec).tap([hr = hr](const DbResponse&) { + // Hello should take kMaxAwaitTimeMs at most, log if it takes twice that. + if (auto command = CurOp::get(hr->rec->getOpCtx())->getCommand(); + command && (command->getName() == "hello")) { + hr->slowMsOverride = + 2 * durationCount<Milliseconds>(SingleServerIsMasterMonitor::kMaxAwaitTime); + } + }); + } +}; + +// The base for operations that may throw exceptions, but should not cause the connection to close. +struct OpRunner : public HandleRequest::OpRunnerBase { + using HandleRequest::OpRunnerBase::OpRunnerBase; + virtual DbResponse runOperation() = 0; + Future<DbResponse> run() override; +}; + +Future<DbResponse> OpRunner::run() try { + using namespace fmt::literals; + const NamespaceString& nss = hr->nsString; + const DbMessage& dbm = hr->rec->getDbMessage(); + + if (dbm.messageShouldHaveNs()) { + uassert(ErrorCodes::InvalidNamespace, "Invalid ns [{}]"_format(nss.ns()), nss.isValid()); + + uassert(ErrorCodes::IllegalOperation, + "Can't use 'local' database through mongos", + nss.db() != NamespaceString::kLocalDb); + } - // This is before the try block since it handles all exceptions that should not cause the - // connection to close. - if (op == dbMsg || (op == dbQuery && NamespaceString(dbm.getns()).isCommand())) { - auto dbResponse = Strategy::clientCommand(opCtx, message); + LOGV2_DEBUG(22867, + 3, + "Request::process begin ns: {namespace} msg id: {msgId} op: {operation}", + "Starting operation", + "namespace"_attr = nss, + "msgId"_attr = hr->msgId, + "operation"_attr = networkOpToString(hr->op)); - // Hello should take kMaxAwaitTimeMs at most, log if it takes twice that. - boost::optional<long long> slowMsOverride; - if (auto command = CurOp::get(opCtx)->getCommand(); - command && (command->getName() == "hello")) { - slowMsOverride = - 2 * durationCount<Milliseconds>(SingleServerIsMasterMonitor::kMaxAwaitTime); - } + auto dbResponse = runOperation(); - // Mark the op as complete, populate the response length, and log it if appropriate. - CurOp::get(opCtx)->completeAndLogOperation( - opCtx, logv2::LogComponent::kCommand, dbResponse.response.size(), slowMsOverride); + LOGV2_DEBUG(22868, + 3, + "Request::process end ns: {namespace} msg id: {msgId} op: {operation}", + "Done processing operation", + "namespace"_attr = nss, + "msgId"_attr = hr->msgId, + "operation"_attr = networkOpToString(hr->op)); - return Future<DbResponse>::makeReady(std::move(dbResponse)); - } + return Future<DbResponse>::makeReady(std::move(dbResponse)); +} catch (const DBException& ex) { + LOGV2_DEBUG(22869, + 1, + "Exception thrown while processing {operation} op for {namespace}: {error}", + "Got an error while processing operation", + "operation"_attr = networkOpToString(hr->op), + "namespace"_attr = hr->nsString.ns(), + "error"_attr = ex); - NamespaceString nss; DbResponse dbResponse; - try { - if (dbm.messageShouldHaveNs()) { - nss = NamespaceString(StringData(dbm.getns())); - - uassert(ErrorCodes::InvalidNamespace, - str::stream() << "Invalid ns [" << nss.ns() << "]", - nss.isValid()); - - uassert(ErrorCodes::IllegalOperation, - "Can't use 'local' database through mongos", - nss.db() != NamespaceString::kLocalDb); - } - - - LOGV2_DEBUG(22867, - 3, - "Request::process begin ns: {namespace} msg id: {msgId} op: {operation}", - "Starting operation", - "namespace"_attr = nss, - "msgId"_attr = msgId, - "operation"_attr = networkOpToString(op)); - - switch (op) { - case dbQuery: - // Commands are handled above through Strategy::clientCommand(). - invariant(!nss.isCommand()); - opCtx->markKillOnClientDisconnect(); - dbResponse = Strategy::queryOp(opCtx, nss, &dbm); - break; - - case dbGetMore: - dbResponse = Strategy::getMore(opCtx, nss, &dbm); - break; - - case dbKillCursors: - Strategy::killCursors(opCtx, &dbm); // No Response. - break; - - case dbInsert: - case dbUpdate: - case dbDelete: - Strategy::writeOp(opCtx, &dbm); // No Response. - break; - - default: - MONGO_UNREACHABLE; - } - - LOGV2_DEBUG(22868, - 3, - "Request::process end ns: {namespace} msg id: {msgId} op: {operation}", - "Done processing operation", - "namespace"_attr = nss, - "msgId"_attr = msgId, - "operation"_attr = networkOpToString(op)); - - } catch (const DBException& ex) { - LOGV2_DEBUG(22869, - 1, - "Exception thrown while processing {operation} op for {namespace}: {error}", - "Got an error while processing operation", - "operation"_attr = networkOpToString(op), - "namespace"_attr = nss.ns(), - "error"_attr = ex); - - if (op == dbQuery || op == dbGetMore) { - dbResponse = replyToQuery(buildErrReply(ex), ResultFlag_ErrSet); - } else { - // No Response. - } - - // We *always* populate the last error for now - LastError::get(opCtx->getClient()).setLastError(ex.code(), ex.what()); - CurOp::get(opCtx)->debug().errInfo = ex.toStatus(); + if (hr->op == dbQuery || hr->op == dbGetMore) { + dbResponse = replyToQuery(buildErrReply(ex), ResultFlag_ErrSet); + } else { + // No Response. + } + + // We *always* populate the last error for now + auto opCtx = hr->rec->getOpCtx(); + LastError::get(opCtx->getClient()).setLastError(ex.code(), ex.what()); + + CurOp::get(opCtx)->debug().errInfo = ex.toStatus(); + + return Future<DbResponse>::makeReady(std::move(dbResponse)); +} + +struct QueryOpRunner final : public OpRunner { + using OpRunner::OpRunner; + DbResponse runOperation() override { + // Commands are handled through CommandOpRunner and Strategy::clientCommand(). + invariant(!hr->nsString.isCommand()); + hr->rec->getOpCtx()->markKillOnClientDisconnect(); + return Strategy::queryOp(hr->rec->getOpCtx(), hr->nsString, &hr->rec->getDbMessage()); + } +}; + +struct GetMoreOpRunner final : public OpRunner { + using OpRunner::OpRunner; + DbResponse runOperation() override { + return Strategy::getMore(hr->rec->getOpCtx(), hr->nsString, &hr->rec->getDbMessage()); + } +}; + +struct KillCursorsOpRunner final : public OpRunner { + using OpRunner::OpRunner; + DbResponse runOperation() override { + Strategy::killCursors(hr->rec->getOpCtx(), &hr->rec->getDbMessage()); // No Response. + return {}; } +}; +struct WriteOpRunner final : public OpRunner { + using OpRunner::OpRunner; + DbResponse runOperation() override { + Strategy::writeOp(hr->rec->getOpCtx(), &hr->rec->getDbMessage()); // No Response. + return {}; + } +}; + +Future<DbResponse> HandleRequest::handleRequest() { + switch (op) { + case dbQuery: + if (!nsString.isCommand()) + return std::make_unique<QueryOpRunner>(shared_from_this())->run(); + // FALLTHROUGH: it's a query containing a command + case dbMsg: + return std::make_unique<CommandOpRunner>(shared_from_this())->run(); + case dbGetMore: + return std::make_unique<GetMoreOpRunner>(shared_from_this())->run(); + case dbKillCursors: + return std::make_unique<KillCursorsOpRunner>(shared_from_this())->run(); + case dbInsert: + case dbUpdate: + case dbDelete: + return std::make_unique<WriteOpRunner>(shared_from_this())->run(); + default: + MONGO_UNREACHABLE; + } +} + +void HandleRequest::onSuccess(const DbResponse& dbResponse) { + auto opCtx = rec->getOpCtx(); // Mark the op as complete, populate the response length, and log it if appropriate. CurOp::get(opCtx)->completeAndLogOperation( - opCtx, logv2::LogComponent::kCommand, dbResponse.response.size()); + opCtx, logv2::LogComponent::kCommand, dbResponse.response.size(), slowMsOverride); +} - return Future<DbResponse>::makeReady(std::move(dbResponse)); -} catch (const DBException& e) { - LOGV2(4879803, "Failed to handle request", "error"_attr = redact(e)); - return e.toStatus(); +Future<DbResponse> HandleRequest::run() { + auto fp = makePromiseFuture<void>(); + auto future = std::move(fp.future) + .then([this, anchor = shared_from_this()] { setupEnvironment(); }) + .then([this, anchor = shared_from_this()] { return handleRequest(); }) + .tap([this, anchor = shared_from_this()](const DbResponse& dbResponse) { + onSuccess(dbResponse); + }) + .tapError([](Status status) { + LOGV2(4879803, "Failed to handle request", "error"_attr = redact(status)); + }); + fp.promise.emplaceValue(); + return future; +} + +Future<DbResponse> ServiceEntryPointMongos::handleRequest(OperationContext* opCtx, + const Message& message) noexcept { + auto hr = std::make_shared<HandleRequest>(opCtx, message); + return hr->run(); } } // namespace mongo |