summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/mongo/db/request_execution_context.h11
-rw-r--r--src/mongo/db/service_entry_point_common.cpp5
-rw-r--r--src/mongo/s/commands/cluster_command_test_fixture.cpp4
-rw-r--r--src/mongo/s/commands/strategy.cpp8
-rw-r--r--src/mongo/s/commands/strategy.h3
-rw-r--r--src/mongo/s/service_entry_point_mongos.cpp301
6 files changed, 215 insertions, 117 deletions
diff --git a/src/mongo/db/request_execution_context.h b/src/mongo/db/request_execution_context.h
index b34c7e4f3b7..44f06f36369 100644
--- a/src/mongo/db/request_execution_context.h
+++ b/src/mongo/db/request_execution_context.h
@@ -57,22 +57,21 @@ public:
RequestExecutionContext(const RequestExecutionContext&) = delete;
RequestExecutionContext(RequestExecutionContext&&) = delete;
- explicit RequestExecutionContext(OperationContext* opCtx) : _opCtx(opCtx) {}
+ RequestExecutionContext(OperationContext* opCtx, Message message)
+ : _opCtx(opCtx),
+ _message(std::move(message)),
+ _dbmsg(std::make_unique<DbMessage>(_message.get())) {}
auto getOpCtx() const {
invariant(_isOnClientThread());
return _opCtx;
}
- void setMessage(Message message) {
- invariant(_isOnClientThread() && !_message);
- _message = std::move(message);
- _dbmsg = std::make_unique<DbMessage>(_message.get());
- }
const Message& getMessage() const {
invariant(_isOnClientThread() && _message);
return _message.get();
}
+
DbMessage& getDbMessage() const {
invariant(_isOnClientThread() && _dbmsg);
return *_dbmsg.get();
diff --git a/src/mongo/db/service_entry_point_common.cpp b/src/mongo/db/service_entry_point_common.cpp
index af22da58f4e..c6e26cfc796 100644
--- a/src/mongo/db/service_entry_point_common.cpp
+++ b/src/mongo/db/service_entry_point_common.cpp
@@ -141,10 +141,7 @@ struct HandleRequest {
ExecutionContext(OperationContext* opCtx,
Message msg,
std::unique_ptr<const ServiceEntryPointCommon::Hooks> hooks)
- : RequestExecutionContext(opCtx), behaviors(std::move(hooks)) {
- // It also initializes dbMessage, which is accessible via getDbMessage()
- setMessage(std::move(msg));
- }
+ : RequestExecutionContext(opCtx, std::move(msg)), behaviors(std::move(hooks)) {}
~ExecutionContext() = default;
Client& client() const {
diff --git a/src/mongo/s/commands/cluster_command_test_fixture.cpp b/src/mongo/s/commands/cluster_command_test_fixture.cpp
index 31bee683dbf..c182dca1514 100644
--- a/src/mongo/s/commands/cluster_command_test_fixture.cpp
+++ b/src/mongo/s/commands/cluster_command_test_fixture.cpp
@@ -125,7 +125,9 @@ DbResponse ClusterCommandTestFixture::runCommand(BSONObj cmd) {
auto clusterGLE = ClusterLastErrorInfo::get(client.get());
clusterGLE->newRequest();
- return Strategy::clientCommand(opCtx.get(), opMsgRequest.serialize());
+ AlternativeClientRegion acr(client);
+ auto rec = std::make_shared<RequestExecutionContext>(opCtx.get(), opMsgRequest.serialize());
+ return Strategy::clientCommand(std::move(rec)).get();
}
void ClusterCommandTestFixture::runCommandSuccessful(BSONObj cmd, bool isTargeted) {
diff --git a/src/mongo/s/commands/strategy.cpp b/src/mongo/s/commands/strategy.cpp
index f3f27506da7..eb386f3763b 100644
--- a/src/mongo/s/commands/strategy.cpp
+++ b/src/mongo/s/commands/strategy.cpp
@@ -1028,7 +1028,9 @@ DbResponse Strategy::queryOp(OperationContext* opCtx, const NamespaceString& nss
cursorId)};
}
-DbResponse Strategy::clientCommand(OperationContext* opCtx, const Message& m) {
+Future<DbResponse> Strategy::clientCommand(std::shared_ptr<RequestExecutionContext> rec) try {
+ auto opCtx = rec->getOpCtx();
+ const Message& m = rec->getMessage();
auto reply = rpc::makeReplyBuilder(rpc::protocolForMessage(m));
BSONObjBuilder errorBuilder;
@@ -1098,7 +1100,7 @@ DbResponse Strategy::clientCommand(OperationContext* opCtx, const Message& m) {
}
if (OpMsg::isFlagSet(m, OpMsg::kMoreToCome)) {
- return {}; // Don't reply.
+ return DbResponse{}; // Don't reply.
}
DbResponse dbResponse;
@@ -1112,6 +1114,8 @@ DbResponse Strategy::clientCommand(OperationContext* opCtx, const Message& m) {
dbResponse.response = reply->done();
return dbResponse;
+} catch (const DBException& e) {
+ return e.toStatus();
}
DbResponse Strategy::getMore(OperationContext* opCtx, const NamespaceString& nss, DbMessage* dbm) {
diff --git a/src/mongo/s/commands/strategy.h b/src/mongo/s/commands/strategy.h
index 3807a63a7f6..85fafda1acb 100644
--- a/src/mongo/s/commands/strategy.h
+++ b/src/mongo/s/commands/strategy.h
@@ -33,6 +33,7 @@
#include "mongo/client/connection_string.h"
#include "mongo/db/query/explain_options.h"
+#include "mongo/db/request_execution_context.h"
#include "mongo/s/client/shard.h"
namespace mongo {
@@ -82,7 +83,7 @@ public:
* Catches StaleConfigException errors and retries the command automatically after refreshing
* the metadata for the failing namespace.
*/
- static DbResponse clientCommand(OperationContext* opCtx, const Message& message);
+ static Future<DbResponse> clientCommand(std::shared_ptr<RequestExecutionContext> rec);
/**
* Helper to run an explain of a find operation on the shards. Fills 'out' with the result of
diff --git a/src/mongo/s/service_entry_point_mongos.cpp b/src/mongo/s/service_entry_point_mongos.cpp
index a10cec0ee7b..870d58968aa 100644
--- a/src/mongo/s/service_entry_point_mongos.cpp
+++ b/src/mongo/s/service_entry_point_mongos.cpp
@@ -29,6 +29,8 @@
#define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kNetwork
+#include <memory>
+
#include "mongo/platform/basic.h"
#include "mongo/s/service_entry_point_mongos.h"
@@ -40,12 +42,12 @@
#include "mongo/db/dbmessage.h"
#include "mongo/db/lasterror.h"
#include "mongo/db/operation_context.h"
+#include "mongo/db/request_execution_context.h"
#include "mongo/db/service_context.h"
#include "mongo/logv2/log.h"
#include "mongo/rpc/message.h"
#include "mongo/s/cluster_last_error_info.h"
#include "mongo/s/commands/strategy.h"
-#include "mongo/util/scopeguard.h"
namespace mongo {
@@ -60,16 +62,50 @@ BSONObj buildErrReply(const DBException& ex) {
} // namespace
+// Allows for decomposing `handleRequest` into parts and simplifies composing the future-chain.
+struct HandleRequest : public std::enable_shared_from_this<HandleRequest> {
+ struct OpRunnerBase;
-Future<DbResponse> ServiceEntryPointMongos::handleRequest(OperationContext* opCtx,
- const Message& message) noexcept try {
- const int32_t msgId = message.header().getId();
- const NetworkOp op = message.operation();
+ HandleRequest(OperationContext* opCtx, const Message& message)
+ : rec(std::make_shared<RequestExecutionContext>(opCtx, message)),
+ op(message.operation()),
+ msgId(message.header().getId()),
+ nsString(getNamespaceString(rec->getDbMessage())) {}
+
+ // Prepares the environment for handling the request (e.g., setting up `ClusterLastErrorInfo`).
+ void setupEnvironment();
+
+ // Returns a future that does the heavy lifting of running client commands.
+ Future<DbResponse> handleRequest();
+
+ // Runs on successful execution of the future returned by `handleRequest`.
+ void onSuccess(const DbResponse&);
+
+ // Returns a future-chain to handle the request and prepare the response.
+ Future<DbResponse> run();
+
+ static NamespaceString getNamespaceString(const DbMessage& dbmsg) {
+ if (!dbmsg.messageShouldHaveNs())
+ return {};
+ return NamespaceString(dbmsg.getns());
+ }
+
+ const std::shared_ptr<RequestExecutionContext> rec;
+ const NetworkOp op;
+ const int32_t msgId;
+ const NamespaceString nsString;
+
+ boost::optional<long long> slowMsOverride;
+};
+
+void HandleRequest::setupEnvironment() {
+ using namespace fmt::literals;
+ auto opCtx = rec->getOpCtx();
// This exception will not be returned to the caller, but will be logged and will close the
// connection
uassert(ErrorCodes::IllegalOperation,
- str::stream() << "Message type " << op << " is not supported.",
+ "Message type {} is not supported."_format(op),
isSupportedRequestNetworkOp(op) &&
op != dbCompressed); // Decompression should be handled above us.
@@ -84,115 +120,174 @@ Future<DbResponse> ServiceEntryPointMongos::handleRequest(OperationContext* opCt
AuthorizationSession::get(opCtx->getClient())->startRequest(opCtx);
CurOp::get(opCtx)->ensureStarted();
+}
- DbMessage dbm(message);
+// The base for various operation runners that handle the request, and often generate a DbResponse.
+struct HandleRequest::OpRunnerBase {
+ explicit OpRunnerBase(std::shared_ptr<HandleRequest> hr) : hr(std::move(hr)) {}
+ virtual ~OpRunnerBase() = default;
+ virtual Future<DbResponse> run() = 0;
+ const std::shared_ptr<HandleRequest> hr;
+};
+
+struct CommandOpRunner final : public HandleRequest::OpRunnerBase {
+ using HandleRequest::OpRunnerBase::OpRunnerBase;
+ Future<DbResponse> run() override {
+ return Strategy::clientCommand(hr->rec).tap([hr = hr](const DbResponse&) {
+ // Hello should take kMaxAwaitTimeMs at most, log if it takes twice that.
+ if (auto command = CurOp::get(hr->rec->getOpCtx())->getCommand();
+ command && (command->getName() == "hello")) {
+ hr->slowMsOverride =
+ 2 * durationCount<Milliseconds>(SingleServerIsMasterMonitor::kMaxAwaitTime);
+ }
+ });
+ }
+};
+
+// The base for operations that may throw exceptions, but should not cause the connection to close.
+struct OpRunner : public HandleRequest::OpRunnerBase {
+ using HandleRequest::OpRunnerBase::OpRunnerBase;
+ virtual DbResponse runOperation() = 0;
+ Future<DbResponse> run() override;
+};
+
+Future<DbResponse> OpRunner::run() try {
+ using namespace fmt::literals;
+ const NamespaceString& nss = hr->nsString;
+ const DbMessage& dbm = hr->rec->getDbMessage();
+
+ if (dbm.messageShouldHaveNs()) {
+ uassert(ErrorCodes::InvalidNamespace, "Invalid ns [{}]"_format(nss.ns()), nss.isValid());
+
+ uassert(ErrorCodes::IllegalOperation,
+ "Can't use 'local' database through mongos",
+ nss.db() != NamespaceString::kLocalDb);
+ }
- // This is before the try block since it handles all exceptions that should not cause the
- // connection to close.
- if (op == dbMsg || (op == dbQuery && NamespaceString(dbm.getns()).isCommand())) {
- auto dbResponse = Strategy::clientCommand(opCtx, message);
+ LOGV2_DEBUG(22867,
+ 3,
+ "Request::process begin ns: {namespace} msg id: {msgId} op: {operation}",
+ "Starting operation",
+ "namespace"_attr = nss,
+ "msgId"_attr = hr->msgId,
+ "operation"_attr = networkOpToString(hr->op));
- // Hello should take kMaxAwaitTimeMs at most, log if it takes twice that.
- boost::optional<long long> slowMsOverride;
- if (auto command = CurOp::get(opCtx)->getCommand();
- command && (command->getName() == "hello")) {
- slowMsOverride =
- 2 * durationCount<Milliseconds>(SingleServerIsMasterMonitor::kMaxAwaitTime);
- }
+ auto dbResponse = runOperation();
- // Mark the op as complete, populate the response length, and log it if appropriate.
- CurOp::get(opCtx)->completeAndLogOperation(
- opCtx, logv2::LogComponent::kCommand, dbResponse.response.size(), slowMsOverride);
+ LOGV2_DEBUG(22868,
+ 3,
+ "Request::process end ns: {namespace} msg id: {msgId} op: {operation}",
+ "Done processing operation",
+ "namespace"_attr = nss,
+ "msgId"_attr = hr->msgId,
+ "operation"_attr = networkOpToString(hr->op));
- return Future<DbResponse>::makeReady(std::move(dbResponse));
- }
+ return Future<DbResponse>::makeReady(std::move(dbResponse));
+} catch (const DBException& ex) {
+ LOGV2_DEBUG(22869,
+ 1,
+ "Exception thrown while processing {operation} op for {namespace}: {error}",
+ "Got an error while processing operation",
+ "operation"_attr = networkOpToString(hr->op),
+ "namespace"_attr = hr->nsString.ns(),
+ "error"_attr = ex);
- NamespaceString nss;
DbResponse dbResponse;
- try {
- if (dbm.messageShouldHaveNs()) {
- nss = NamespaceString(StringData(dbm.getns()));
-
- uassert(ErrorCodes::InvalidNamespace,
- str::stream() << "Invalid ns [" << nss.ns() << "]",
- nss.isValid());
-
- uassert(ErrorCodes::IllegalOperation,
- "Can't use 'local' database through mongos",
- nss.db() != NamespaceString::kLocalDb);
- }
-
-
- LOGV2_DEBUG(22867,
- 3,
- "Request::process begin ns: {namespace} msg id: {msgId} op: {operation}",
- "Starting operation",
- "namespace"_attr = nss,
- "msgId"_attr = msgId,
- "operation"_attr = networkOpToString(op));
-
- switch (op) {
- case dbQuery:
- // Commands are handled above through Strategy::clientCommand().
- invariant(!nss.isCommand());
- opCtx->markKillOnClientDisconnect();
- dbResponse = Strategy::queryOp(opCtx, nss, &dbm);
- break;
-
- case dbGetMore:
- dbResponse = Strategy::getMore(opCtx, nss, &dbm);
- break;
-
- case dbKillCursors:
- Strategy::killCursors(opCtx, &dbm); // No Response.
- break;
-
- case dbInsert:
- case dbUpdate:
- case dbDelete:
- Strategy::writeOp(opCtx, &dbm); // No Response.
- break;
-
- default:
- MONGO_UNREACHABLE;
- }
-
- LOGV2_DEBUG(22868,
- 3,
- "Request::process end ns: {namespace} msg id: {msgId} op: {operation}",
- "Done processing operation",
- "namespace"_attr = nss,
- "msgId"_attr = msgId,
- "operation"_attr = networkOpToString(op));
-
- } catch (const DBException& ex) {
- LOGV2_DEBUG(22869,
- 1,
- "Exception thrown while processing {operation} op for {namespace}: {error}",
- "Got an error while processing operation",
- "operation"_attr = networkOpToString(op),
- "namespace"_attr = nss.ns(),
- "error"_attr = ex);
-
- if (op == dbQuery || op == dbGetMore) {
- dbResponse = replyToQuery(buildErrReply(ex), ResultFlag_ErrSet);
- } else {
- // No Response.
- }
-
- // We *always* populate the last error for now
- LastError::get(opCtx->getClient()).setLastError(ex.code(), ex.what());
- CurOp::get(opCtx)->debug().errInfo = ex.toStatus();
+ if (hr->op == dbQuery || hr->op == dbGetMore) {
+ dbResponse = replyToQuery(buildErrReply(ex), ResultFlag_ErrSet);
+ } else {
+ // No Response.
+ }
+
+ // We *always* populate the last error for now
+ auto opCtx = hr->rec->getOpCtx();
+ LastError::get(opCtx->getClient()).setLastError(ex.code(), ex.what());
+
+ CurOp::get(opCtx)->debug().errInfo = ex.toStatus();
+
+ return Future<DbResponse>::makeReady(std::move(dbResponse));
+}
+
+struct QueryOpRunner final : public OpRunner {
+ using OpRunner::OpRunner;
+ DbResponse runOperation() override {
+ // Commands are handled through CommandOpRunner and Strategy::clientCommand().
+ invariant(!hr->nsString.isCommand());
+ hr->rec->getOpCtx()->markKillOnClientDisconnect();
+ return Strategy::queryOp(hr->rec->getOpCtx(), hr->nsString, &hr->rec->getDbMessage());
+ }
+};
+
+struct GetMoreOpRunner final : public OpRunner {
+ using OpRunner::OpRunner;
+ DbResponse runOperation() override {
+ return Strategy::getMore(hr->rec->getOpCtx(), hr->nsString, &hr->rec->getDbMessage());
+ }
+};
+
+struct KillCursorsOpRunner final : public OpRunner {
+ using OpRunner::OpRunner;
+ DbResponse runOperation() override {
+ Strategy::killCursors(hr->rec->getOpCtx(), &hr->rec->getDbMessage()); // No Response.
+ return {};
}
+};
+struct WriteOpRunner final : public OpRunner {
+ using OpRunner::OpRunner;
+ DbResponse runOperation() override {
+ Strategy::writeOp(hr->rec->getOpCtx(), &hr->rec->getDbMessage()); // No Response.
+ return {};
+ }
+};
+
+Future<DbResponse> HandleRequest::handleRequest() {
+ switch (op) {
+ case dbQuery:
+ if (!nsString.isCommand())
+ return std::make_unique<QueryOpRunner>(shared_from_this())->run();
+ // FALLTHROUGH: it's a query containing a command
+ case dbMsg:
+ return std::make_unique<CommandOpRunner>(shared_from_this())->run();
+ case dbGetMore:
+ return std::make_unique<GetMoreOpRunner>(shared_from_this())->run();
+ case dbKillCursors:
+ return std::make_unique<KillCursorsOpRunner>(shared_from_this())->run();
+ case dbInsert:
+ case dbUpdate:
+ case dbDelete:
+ return std::make_unique<WriteOpRunner>(shared_from_this())->run();
+ default:
+ MONGO_UNREACHABLE;
+ }
+}
+
+void HandleRequest::onSuccess(const DbResponse& dbResponse) {
+ auto opCtx = rec->getOpCtx();
// Mark the op as complete, populate the response length, and log it if appropriate.
CurOp::get(opCtx)->completeAndLogOperation(
- opCtx, logv2::LogComponent::kCommand, dbResponse.response.size());
+ opCtx, logv2::LogComponent::kCommand, dbResponse.response.size(), slowMsOverride);
+}
- return Future<DbResponse>::makeReady(std::move(dbResponse));
-} catch (const DBException& e) {
- LOGV2(4879803, "Failed to handle request", "error"_attr = redact(e));
- return e.toStatus();
+Future<DbResponse> HandleRequest::run() {
+ auto fp = makePromiseFuture<void>();
+ auto future = std::move(fp.future)
+ .then([this, anchor = shared_from_this()] { setupEnvironment(); })
+ .then([this, anchor = shared_from_this()] { return handleRequest(); })
+ .tap([this, anchor = shared_from_this()](const DbResponse& dbResponse) {
+ onSuccess(dbResponse);
+ })
+ .tapError([](Status status) {
+ LOGV2(4879803, "Failed to handle request", "error"_attr = redact(status));
+ });
+ fp.promise.emplaceValue();
+ return future;
+}
+
+Future<DbResponse> ServiceEntryPointMongos::handleRequest(OperationContext* opCtx,
+ const Message& message) noexcept {
+ auto hr = std::make_shared<HandleRequest>(opCtx, message);
+ return hr->run();
}
} // namespace mongo