diff options
author | Jason Carey <jcarey@argv.me> | 2019-05-13 18:24:36 -0400 |
---|---|---|
committer | Jason Carey <jcarey@argv.me> | 2019-06-06 09:00:28 -0400 |
commit | efa1ae064b9206f82136a8d14cbb86d47e8754b1 (patch) | |
tree | c4666fa197f837b5a0feaa8f980299a8eed7146a /src/mongo/executor | |
parent | b1ff28c63836aa13112cf3499574160a5950c6ec (diff) | |
download | mongo-efa1ae064b9206f82136a8d14cbb86d47e8754b1.tar.gz |
SERVER-41133 Add TE::scheduleRemoteCommandOnAny
Add support for a mode for the task executor where rather than
targetting a single host, we target any of a set of hosts. This should
behave identically to scheduleRemoteCommand, except that we concurrently
get() connections from the connection pool for each host, preferring the
first which is available
Diffstat (limited to 'src/mongo/executor')
18 files changed, 533 insertions, 235 deletions
diff --git a/src/mongo/executor/network_connection_hook.h b/src/mongo/executor/network_connection_hook.h index 368e79111ae..32186b985e6 100644 --- a/src/mongo/executor/network_connection_hook.h +++ b/src/mongo/executor/network_connection_hook.h @@ -32,6 +32,7 @@ #include <boost/optional.hpp> #include "mongo/bson/bsonobj.h" +#include "mongo/executor/remote_command_request.h" namespace mongo { @@ -43,7 +44,6 @@ struct HostAndPort; namespace executor { struct RemoteCommandResponse; -struct RemoteCommandRequest; /** * An hooking interface for augmenting an implementation of NetworkInterface with domain-specific diff --git a/src/mongo/executor/network_interface.h b/src/mongo/executor/network_interface.h index 033dfb00d30..8bc864e3df6 100644 --- a/src/mongo/executor/network_interface.h +++ b/src/mongo/executor/network_interface.h @@ -54,7 +54,8 @@ class NetworkInterface { public: using Response = RemoteCommandResponse; - using RemoteCommandCompletionFn = unique_function<void(const TaskExecutor::ResponseStatus&)>; + using RemoteCommandCompletionFn = + unique_function<void(const TaskExecutor::ResponseOnAnyStatus&)>; virtual ~NetworkInterface(); @@ -144,19 +145,20 @@ public: * function will not run. */ virtual Status startCommand(const TaskExecutor::CallbackHandle& cbHandle, - RemoteCommandRequest& request, + RemoteCommandRequestOnAny& request, RemoteCommandCompletionFn&& onFinish, const BatonHandle& baton = nullptr) = 0; - Future<TaskExecutor::ResponseStatus> startCommand(const TaskExecutor::CallbackHandle& cbHandle, - RemoteCommandRequest& request, - const BatonHandle& baton = nullptr) { - auto pf = makePromiseFuture<TaskExecutor::ResponseStatus>(); + Future<TaskExecutor::ResponseOnAnyStatus> startCommand( + const TaskExecutor::CallbackHandle& cbHandle, + RemoteCommandRequestOnAny& request, + const BatonHandle& baton = nullptr) { + auto pf = makePromiseFuture<TaskExecutor::ResponseOnAnyStatus>(); auto status = startCommand( cbHandle, request, - [p = std::move(pf.promise)](const TaskExecutor::ResponseStatus& rs) mutable { + [p = std::move(pf.promise)](const TaskExecutor::ResponseOnAnyStatus& rs) mutable { p.emplaceValue(rs); }, baton); diff --git a/src/mongo/executor/network_interface_integration_fixture.cpp b/src/mongo/executor/network_interface_integration_fixture.cpp index f9ea9080d33..2729326442a 100644 --- a/src/mongo/executor/network_interface_integration_fixture.cpp +++ b/src/mongo/executor/network_interface_integration_fixture.cpp @@ -86,12 +86,21 @@ PseudoRandom* NetworkInterfaceIntegrationFixture::getRandomNumberGenerator() { void NetworkInterfaceIntegrationFixture::startCommand(const TaskExecutor::CallbackHandle& cbHandle, RemoteCommandRequest& request, StartCommandCB onFinish) { - net().startCommand(cbHandle, request, onFinish).transitional_ignore(); + RemoteCommandRequestOnAny rcroa{request}; + + auto cb = [onFinish = std::move(onFinish)](const TaskExecutor::ResponseOnAnyStatus& rs) { + onFinish(rs); + }; + invariant(net().startCommand(cbHandle, rcroa, std::move(cb))); } Future<RemoteCommandResponse> NetworkInterfaceIntegrationFixture::runCommand( const TaskExecutor::CallbackHandle& cbHandle, RemoteCommandRequest request) { - return net().startCommand(cbHandle, request); + RemoteCommandRequestOnAny rcroa{request}; + + return net().startCommand(cbHandle, rcroa).then([](TaskExecutor::ResponseOnAnyStatus roa) { + return RemoteCommandResponse(roa); + }); } RemoteCommandResponse NetworkInterfaceIntegrationFixture::runCommandSync( diff --git a/src/mongo/executor/network_interface_mock.cpp b/src/mongo/executor/network_interface_mock.cpp index c702647214f..7f156744803 100644 --- a/src/mongo/executor/network_interface_mock.cpp +++ b/src/mongo/executor/network_interface_mock.cpp @@ -105,7 +105,7 @@ std::string NetworkInterfaceMock::getHostName() { } Status NetworkInterfaceMock::startCommand(const CallbackHandle& cbHandle, - RemoteCommandRequest& request, + RemoteCommandRequestOnAny& request, RemoteCommandCompletionFn&& onFinish, const BatonHandle& baton) { if (inShutdown()) { @@ -117,11 +117,14 @@ Status NetworkInterfaceMock::startCommand(const CallbackHandle& cbHandle, const Date_t now = _now_inlock(); auto op = NetworkOperation(cbHandle, request, now, std::move(onFinish)); + // network interface mock only works with single target requests + invariant(request.target.size() == 1); + // If we don't have a hook, or we have already 'connected' to this host, enqueue the op. - if (!_hook || _connections.count(request.target)) { + if (!_hook || _connections.count(request.target[0])) { _enqueueOperation_inlock(std::move(op)); } else { - _connectThenEnqueueOperation_inlock(request.target, std::move(op)); + _connectThenEnqueueOperation_inlock(request.target[0], std::move(op)); } return Status::OK(); @@ -531,7 +534,8 @@ void NetworkInterfaceMock::_connectThenEnqueueOperation_inlock(const HostAndPort auto cbh = op.getCallbackHandle(); // The completion handler for the postconnect command schedules the original command. - auto postconnectCompletionHandler = [ this, op = std::move(op) ](ResponseStatus rs) mutable { + auto postconnectCompletionHandler = + [ this, op = std::move(op) ](TaskExecutor::ResponseOnAnyStatus rs) mutable { stdx::lock_guard<stdx::mutex> lk(_mutex); if (!rs.isOK()) { op.setResponse(_now_inlock(), rs); @@ -663,17 +667,21 @@ NetworkInterfaceMock::NetworkOperation::NetworkOperation() _response(kUnsetResponse), _onFinish() {} -NetworkInterfaceMock::NetworkOperation::NetworkOperation(const CallbackHandle& cbHandle, - const RemoteCommandRequest& theRequest, - Date_t theRequestDate, - RemoteCommandCompletionFn onFinish) +NetworkInterfaceMock::NetworkOperation::NetworkOperation( + const CallbackHandle& cbHandle, + const RemoteCommandRequestOnAny& theRequest, + Date_t theRequestDate, + RemoteCommandCompletionFn onFinish) : _requestDate(theRequestDate), _nextConsiderationDate(theRequestDate), _responseDate(), _cbHandle(cbHandle), - _request(theRequest), + _requestOnAny(theRequest), + _request(theRequest, 0), _response(kUnsetResponse), - _onFinish(std::move(onFinish)) {} + _onFinish(std::move(onFinish)) { + invariant(theRequest.target.size() == 1); +} std::string NetworkInterfaceMock::NetworkOperation::getDiagnosticString() const { return str::stream() << "NetworkOperation -- request:'" << _request.toString() @@ -699,7 +707,7 @@ void NetworkInterfaceMock::NetworkOperation::setResponse(Date_t responseDate, void NetworkInterfaceMock::NetworkOperation::finishResponse() { invariant(_onFinish); - _onFinish(_response); + _onFinish({_request.target, _response}); _onFinish = RemoteCommandCompletionFn(); } diff --git a/src/mongo/executor/network_interface_mock.h b/src/mongo/executor/network_interface_mock.h index 498f88abae1..05742d67cc0 100644 --- a/src/mongo/executor/network_interface_mock.h +++ b/src/mongo/executor/network_interface_mock.h @@ -110,7 +110,7 @@ public: Date_t now() override; std::string getHostName() override; Status startCommand(const TaskExecutor::CallbackHandle& cbHandle, - RemoteCommandRequest& request, + RemoteCommandRequestOnAny& request, RemoteCommandCompletionFn&& onFinish, const BatonHandle& baton = nullptr) override; @@ -434,7 +434,7 @@ class NetworkInterfaceMock::NetworkOperation { public: NetworkOperation(); NetworkOperation(const TaskExecutor::CallbackHandle& cbHandle, - const RemoteCommandRequest& theRequest, + const RemoteCommandRequestOnAny& theRequest, Date_t theRequestDate, RemoteCommandCompletionFn onFinish); @@ -464,6 +464,13 @@ public: /** * Gets the request that initiated this operation. */ + const RemoteCommandRequestOnAny& getRequestOnAny() const { + return _requestOnAny; + } + + /** + * Gets the request that initiated this operation. + */ const RemoteCommandRequest& getRequest() const { return _request; } @@ -506,6 +513,7 @@ private: Date_t _nextConsiderationDate; Date_t _responseDate; TaskExecutor::CallbackHandle _cbHandle; + RemoteCommandRequestOnAny _requestOnAny; RemoteCommandRequest _request; TaskExecutor::ResponseStatus _response; RemoteCommandCompletionFn _onFinish; diff --git a/src/mongo/executor/network_interface_mock_test.cpp b/src/mongo/executor/network_interface_mock_test.cpp index f3dd78cd0d4..d31244d368c 100644 --- a/src/mongo/executor/network_interface_mock_test.cpp +++ b/src/mongo/executor/network_interface_mock_test.cpp @@ -90,6 +90,9 @@ public: net().shutdown(); } + RemoteCommandRequestOnAny kUnimportantRequest{ + {testHost()}, "testDB", BSON("test" << 1), rpc::makeEmptyMetadata(), nullptr}; + private: NetworkInterfaceMock _net; ThreadPoolMock _executor; @@ -162,18 +165,20 @@ TEST_F(NetworkInterfaceMockTest, ConnectionHook) { bool commandFinished = false; bool gotCorrectCommandReply = false; - RemoteCommandRequest actualCommandExpected{ - testHost(), "testDB", BSON("test" << 1), rpc::makeEmptyMetadata(), nullptr}; - RemoteCommandResponse actualResponseExpected{BSON("1212121212" - << "12121212121212"), - Milliseconds(0)}; - - ASSERT_OK(net().startCommand(cb, actualCommandExpected, [&](RemoteCommandResponse resp) { - commandFinished = true; - if (resp.isOK()) { - gotCorrectCommandReply = (actualResponseExpected.toString() == resp.toString()); - } - })); + RemoteCommandRequestOnAny actualCommandExpected{ + {testHost()}, "testDB", BSON("test" << 1), rpc::makeEmptyMetadata(), nullptr}; + RemoteCommandOnAnyResponse actualResponseExpected{testHost(), + BSON("1212121212" + << "12121212121212"), + Milliseconds(0)}; + + ASSERT_OK( + net().startCommand(cb, actualCommandExpected, [&](const RemoteCommandOnAnyResponse& resp) { + commandFinished = true; + if (resp.isOK()) { + gotCorrectCommandReply = (actualResponseExpected == resp); + } + })); // At this point validate and makeRequest should have been called. ASSERT(validateCalled); @@ -240,8 +245,8 @@ TEST_F(NetworkInterfaceMockTest, ConnectionHookFailedValidation) { bool commandFinished = false; bool statusPropagated = false; - RemoteCommandRequest request; - ASSERT_OK(net().startCommand(cb, request, [&](RemoteCommandResponse resp) { + RemoteCommandRequestOnAny request{kUnimportantRequest}; + ASSERT_OK(net().startCommand(cb, request, [&](const RemoteCommandOnAnyResponse& resp) { commandFinished = true; statusPropagated = resp.status.code() == ErrorCodes::ConflictingOperationInProgress; @@ -280,9 +285,9 @@ TEST_F(NetworkInterfaceMockTest, ConnectionHookNoRequest) { bool commandFinished = false; - RemoteCommandRequest request; + RemoteCommandRequestOnAny request{kUnimportantRequest}; ASSERT_OK(net().startCommand( - cb, request, [&](RemoteCommandResponse resp) { commandFinished = true; })); + cb, request, [&](const RemoteCommandOnAnyResponse& resp) { commandFinished = true; })); { net().enterNetwork(); @@ -317,8 +322,8 @@ TEST_F(NetworkInterfaceMockTest, ConnectionHookMakeRequestFails) { bool commandFinished = false; bool errorPropagated = false; - RemoteCommandRequest request; - ASSERT_OK(net().startCommand(cb, request, [&](RemoteCommandResponse resp) { + RemoteCommandRequestOnAny request{kUnimportantRequest}; + ASSERT_OK(net().startCommand(cb, request, [&](const RemoteCommandOnAnyResponse& resp) { commandFinished = true; errorPropagated = resp.status.code() == ErrorCodes::InvalidSyncSource; })); @@ -354,8 +359,8 @@ TEST_F(NetworkInterfaceMockTest, ConnectionHookHandleReplyFails) { bool commandFinished = false; bool errorPropagated = false; - RemoteCommandRequest request; - ASSERT_OK(net().startCommand(cb, request, [&](RemoteCommandResponse resp) { + RemoteCommandRequestOnAny request{kUnimportantRequest}; + ASSERT_OK(net().startCommand(cb, request, [&](const RemoteCommandOnAnyResponse& resp) { commandFinished = true; errorPropagated = resp.status.code() == ErrorCodes::CappedPositionLost; })); @@ -426,8 +431,8 @@ TEST_F(NetworkInterfaceMockTest, StartCommandReturnsNotOKIfShutdownHasStarted) { tearDown(); TaskExecutor::CallbackHandle cb{}; - RemoteCommandRequest request; - ASSERT_NOT_OK(net().startCommand(cb, request, [](RemoteCommandResponse resp) {})); + RemoteCommandRequestOnAny request{kUnimportantRequest}; + ASSERT_NOT_OK(net().startCommand(cb, request, [](const RemoteCommandOnAnyResponse& resp) {})); } TEST_F(NetworkInterfaceMockTest, SetAlarmReturnsNotOKIfShutdownHasStarted) { @@ -441,11 +446,13 @@ TEST_F(NetworkInterfaceMockTest, CommandTimeout) { startNetwork(); TaskExecutor::CallbackHandle cb; - RemoteCommandRequest request; + RemoteCommandRequestOnAny request{kUnimportantRequest}; request.timeout = Milliseconds(2000); ErrorCodes::Error statusPropagated = ErrorCodes::OK; - auto finishFn = [&](RemoteCommandResponse resp) { statusPropagated = resp.status.code(); }; + auto finishFn = [&](const RemoteCommandOnAnyResponse& resp) { + statusPropagated = resp.status.code(); + }; // // Command times out. diff --git a/src/mongo/executor/network_interface_tl.cpp b/src/mongo/executor/network_interface_tl.cpp index 80b91244fea..c7528397951 100644 --- a/src/mongo/executor/network_interface_tl.cpp +++ b/src/mongo/executor/network_interface_tl.cpp @@ -40,6 +40,7 @@ #include "mongo/util/concurrency/idle_thread_block.h" #include "mongo/util/log.h" #include "mongo/util/net/socket_utils.h" +#include "mongo/util/strong_weak_finish_line.h" namespace mongo { namespace executor { @@ -173,19 +174,19 @@ Date_t NetworkInterfaceTL::now() { } NetworkInterfaceTL::CommandState::CommandState(NetworkInterfaceTL* interface_, - RemoteCommandRequest request_, + RemoteCommandRequestOnAny request_, const TaskExecutor::CallbackHandle& cbHandle_, - Promise<RemoteCommandResponse> promise_) + Promise<RemoteCommandOnAnyResponse> promise_) : interface(interface_), - request(std::move(request_)), + requestOnAny(std::move(request_)), cbHandle(cbHandle_), promise(std::move(promise_)) {} auto NetworkInterfaceTL::CommandState::make(NetworkInterfaceTL* interface, - RemoteCommandRequest request, + RemoteCommandRequestOnAny request, const TaskExecutor::CallbackHandle& cbHandle, - Promise<RemoteCommandResponse> promise) { + Promise<RemoteCommandOnAnyResponse> promise) { auto state = std::make_shared<CommandState>(interface, std::move(request), cbHandle, std::move(promise)); @@ -207,7 +208,7 @@ NetworkInterfaceTL::CommandState::~CommandState() { } Status NetworkInterfaceTL::startCommand(const TaskExecutor::CallbackHandle& cbHandle, - RemoteCommandRequest& request, + RemoteCommandRequestOnAny& request, RemoteCommandCompletionFn&& onFinish, const BatonHandle& baton) { if (inShutdown()) { @@ -227,36 +228,38 @@ Status NetworkInterfaceTL::startCommand(const TaskExecutor::CallbackHandle& cbHa request.metadata = newMetadata.obj(); } - auto pf = makePromiseFuture<RemoteCommandResponse>(); + auto cmdPF = makePromiseFuture<RemoteCommandOnAnyResponse>(); - auto state = CommandState::make(this, request, cbHandle, std::move(pf.promise)); - state->start = now(); - if (state->request.timeout != state->request.kNoTimeout) { - state->deadline = state->start + state->request.timeout; + auto cmdState = CommandState::make(this, request, cbHandle, std::move(cmdPF.promise)); + cmdState->start = now(); + if (cmdState->requestOnAny.timeout != cmdState->requestOnAny.kNoTimeout) { + cmdState->deadline = cmdState->start + cmdState->requestOnAny.timeout; } auto executor = baton ? ExecutorPtr(baton) : ExecutorPtr(_reactor); - std::move(pf.future) + std::move(cmdPF.future) .thenRunOn(executor) - .onError([requestId = state->request.id](auto error)->StatusWith<RemoteCommandResponse> { - LOG(2) << "Failed to get connection from pool for request " << requestId << ": " - << redact(error); - - // The TransportLayer has, for historical reasons returned SocketException for - // network errors, but sharding assumes HostUnreachable on network errors. - if (error == ErrorCodes::SocketException) { - error = Status(ErrorCodes::HostUnreachable, error.reason()); - } - return error; - }) - .getAsync([ this, state, onFinish = std::move(onFinish) ]( - StatusWith<RemoteCommandResponse> response) { - auto duration = now() - state->start; + .onError([requestId = cmdState->requestOnAny.id](auto error) + ->StatusWith<RemoteCommandOnAnyResponse> { + LOG(2) << "Failed to get connection from pool for request " << requestId + << ": " << redact(error); + + // The TransportLayer has, for historical reasons returned SocketException + // for network errors, but sharding assumes HostUnreachable on network + // errors. + if (error == ErrorCodes::SocketException) { + error = Status(ErrorCodes::HostUnreachable, error.reason()); + } + return error; + }) + .getAsync([ this, cmdState, onFinish = std::move(onFinish) ]( + StatusWith<RemoteCommandOnAnyResponse> response) { + auto duration = now() - cmdState->start; if (!response.isOK()) { - onFinish(RemoteCommandResponse(response.getStatus(), duration)); + onFinish(RemoteCommandOnAnyResponse(boost::none, response.getStatus(), duration)); } else { const auto& rs = response.getValue(); - LOG(2) << "Request " << state->request.id << " finished with response: " + LOG(2) << "Request " << cmdState->requestOnAny.id << " finished with response: " << redact(rs.isOK() ? rs.data.toString() : rs.status.toString()); onFinish(rs); } @@ -267,23 +270,70 @@ Status NetworkInterfaceTL::startCommand(const TaskExecutor::CallbackHandle& cbHa return Status::OK(); } - _pool->get(request.target, request.sslMode, request.timeout) - .thenRunOn(executor) - .then([this, state, baton](auto conn) { + auto[connPromise, connFuture] = makePromiseFuture<ConnectionPool::ConnectionHandle>(); + + std::move(connFuture).thenRunOn(executor).getAsync([this, cmdState, baton](auto swConn) { + auto status = swConn.getStatus(); + + if (status.isOK()) { if (MONGO_FAIL_POINT(networkInterfaceDiscardCommandsAfterAcquireConn)) { log() << "Discarding command due to failpoint after acquireConn"; - return; + } else { + try { + _onAcquireConn(cmdState, std::move(swConn.getValue()), baton); + } catch (const DBException& ex) { + status = ex.toStatus(); + } } + } - _onAcquireConn(state, std::move(conn), baton); - }) - .getAsync([this, state](auto status) { - // If we couldn't get a connection or _onAcquireConn threw, then we should clean up - // here. - if (!status.isOK() && !state->done.swap(true)) { - state->promise.setError(status); + if (!status.isOK() && !cmdState->done.swap(true)) { + cmdState->promise.setError(status); + } + }); + + struct ConnState { + explicit ConnState(size_t n, Promise<ConnectionPool::ConnectionHandle> p) + : finishLine(n), promise(std::move(p)) {} + + StrongWeakFinishLine finishLine; + Promise<ConnectionPool::ConnectionHandle> promise; + }; + + auto connState = std::make_shared<ConnState>(request.target.size(), std::move(connPromise)); + + for (size_t idx = 0; idx < request.target.size() && !connState->finishLine.isReady(); ++idx) { + auto getConnectionCallback = [ connState, cmdState, idx ]( + StatusWith<ConnectionPool::ConnectionHandle> swConn) noexcept { + if (swConn.isOK()) { + if (connState->finishLine.arriveStrongly()) { + cmdState->request.emplace(cmdState->requestOnAny, idx); + connState->promise.emplaceValue(std::move(swConn.getValue())); + } else { + swConn.getValue()->indicateSuccess(); + } + } else { + LOG(2) << "Failed to get connection from pool for request " + << cmdState->requestOnAny.id << ": " << swConn.getStatus(); + + if (connState->finishLine.arriveWeakly()) { + connState->promise.setError(swConn.getStatus()); + } } - }); + }; + + if (auto semi = _pool->get(request.target[idx], request.sslMode, request.timeout); + semi.isReady()) { + // If we have a connection in hand, stay on thread and immediately handle it + getConnectionCallback(std::move(semi).getNoThrow()); + } else { + // Otherwise route all connection management over the networking reactor, to ensure we + // promptly return connections to the pool that may not be needed + std::move(semi) + .thenRunOn(ExecutorPtr(_reactor)) + .getAsync(std::move(getConnectionCallback)); + } + } return Status::OK(); } @@ -312,7 +362,7 @@ void NetworkInterfaceTL::_onAcquireConn(std::shared_ptr<CommandState> state, "connection from the pool, took " << connDuration << ", timeout was set to " - << state->request.timeout); + << state->requestOnAny.timeout); } state->timer = _reactor->makeTimer(); @@ -333,9 +383,9 @@ void NetworkInterfaceTL::_onAcquireConn(std::shared_ptr<CommandState> state, } const std::string message = str::stream() - << "Request " << state->request.id << " timed out" + << "Request " << state->requestOnAny.id << " timed out" << ", deadline was " << state->deadline.toString() << ", op was " - << redact(state->request.toString()); + << redact(state->requestOnAny.toString()); LOG(2) << message; state->promise.setError( @@ -345,21 +395,22 @@ void NetworkInterfaceTL::_onAcquireConn(std::shared_ptr<CommandState> state, }); } - client->runCommandRequest(state->request, baton) + client->runCommandRequest(*state->request, baton) .then([this, state](RemoteCommandResponse response) { if (state->done.load()) { uasserted(ErrorCodes::CallbackCanceled, "Callback was canceled"); } + const auto& target = state->conn->getHostAndPort(); + if (_metadataHook && response.status.isOK()) { - auto target = state->conn->getHostAndPort().toString(); response.status = - _metadataHook->readReplyMetadata(nullptr, std::move(target), response.data); + _metadataHook->readReplyMetadata(nullptr, target.toString(), response.data); } - return RemoteCommandResponse(std::move(response)); + return RemoteCommandOnAnyResponse(target, std::move(response)); }) - .getAsync([this, state, baton](StatusWith<RemoteCommandResponse> swr) { + .getAsync([this, state, baton](StatusWith<RemoteCommandOnAnyResponse> swr) { if (!swr.isOK()) { state->conn->indicateFailure(swr.getStatus()); } else if (!swr.getValue().isOK()) { @@ -414,10 +465,11 @@ void NetworkInterfaceTL::cancelCommand(const TaskExecutor::CallbackHandle& cbHan _counters.canceled++; } - LOG(2) << "Canceling operation; original request was: " << redact(state->request.toString()); + LOG(2) << "Canceling operation; original request was: " + << redact(state->requestOnAny.toString()); state->promise.setError({ErrorCodes::CallbackCanceled, str::stream() << "Command canceled; original request was: " - << redact(state->request.toString())}); + << redact(state->requestOnAny.toString())}); if (state->conn) { auto client = checked_cast<connection_pool_tl::TLConnection*>(state->conn.get()); client->client()->cancel(baton); diff --git a/src/mongo/executor/network_interface_tl.h b/src/mongo/executor/network_interface_tl.h index 7e7b24b8c71..8ce93e1da08 100644 --- a/src/mongo/executor/network_interface_tl.h +++ b/src/mongo/executor/network_interface_tl.h @@ -65,7 +65,7 @@ public: void signalWorkAvailable() override; Date_t now() override; Status startCommand(const TaskExecutor::CallbackHandle& cbHandle, - RemoteCommandRequest& request, + RemoteCommandRequestOnAny& request, RemoteCommandCompletionFn&& onFinish, const BatonHandle& baton) override; @@ -86,21 +86,22 @@ public: private: struct CommandState { CommandState(NetworkInterfaceTL* interface_, - RemoteCommandRequest request_, + RemoteCommandRequestOnAny request_, const TaskExecutor::CallbackHandle& cbHandle_, - Promise<RemoteCommandResponse> promise_); + Promise<RemoteCommandOnAnyResponse> promise_); ~CommandState(); // Create a new CommandState in a shared_ptr // Prefer this over raw construction static auto make(NetworkInterfaceTL* interface, - RemoteCommandRequest request, + RemoteCommandRequestOnAny request, const TaskExecutor::CallbackHandle& cbHandle, - Promise<RemoteCommandResponse> promise); + Promise<RemoteCommandOnAnyResponse> promise); NetworkInterfaceTL* interface; - RemoteCommandRequest request; + RemoteCommandRequestOnAny requestOnAny; + boost::optional<RemoteCommandRequest> request; TaskExecutor::CallbackHandle cbHandle; Date_t deadline = RemoteCommandRequest::kNoExpirationDate; Date_t start; @@ -109,7 +110,7 @@ private: std::unique_ptr<transport::ReactorTimer> timer; AtomicWord<bool> done; - Promise<RemoteCommandResponse> promise; + Promise<RemoteCommandOnAnyResponse> promise; }; struct AlarmState { diff --git a/src/mongo/executor/remote_command_request.cpp b/src/mongo/executor/remote_command_request.cpp index 65cbb3686e2..e46b483a9cf 100644 --- a/src/mongo/executor/remote_command_request.cpp +++ b/src/mongo/executor/remote_command_request.cpp @@ -31,12 +31,15 @@ #include "mongo/executor/remote_command_request.h" -#include <ostream> +#include <fmt/format.h> #include "mongo/bson/simple_bsonobj_comparator.h" #include "mongo/platform/atomic_word.h" +#include "mongo/util/if_constexpr.h" #include "mongo/util/str.h" +using namespace fmt::literals; + namespace mongo { namespace executor { namespace { @@ -47,43 +50,69 @@ AtomicWord<unsigned long long> requestIdCounter(0); } // namespace -constexpr Milliseconds RemoteCommandRequest::kNoTimeout; -constexpr Date_t RemoteCommandRequest::kNoExpirationDate; +constexpr Milliseconds RemoteCommandRequestBase::kNoTimeout; -RemoteCommandRequest::RemoteCommandRequest() : id(requestIdCounter.addAndFetch(1)) {} +constexpr Date_t RemoteCommandRequestBase::kNoExpirationDate; -RemoteCommandRequest::RemoteCommandRequest(RequestId requestId, - const HostAndPort& theTarget, - const std::string& theDbName, - const BSONObj& theCmdObj, - const BSONObj& metadataObj, - OperationContext* opCtx, - Milliseconds timeoutMillis) +RemoteCommandRequestBase::RemoteCommandRequestBase(RequestId requestId, + const std::string& theDbName, + const BSONObj& theCmdObj, + const BSONObj& metadataObj, + OperationContext* opCtx, + Milliseconds timeoutMillis) : id(requestId), - target(theTarget), dbname(theDbName), metadata(metadataObj), cmdObj(theCmdObj), opCtx(opCtx), timeout(timeoutMillis) {} -RemoteCommandRequest::RemoteCommandRequest(const HostAndPort& theTarget, - const std::string& theDbName, - const BSONObj& theCmdObj, - const BSONObj& metadataObj, - OperationContext* opCtx, - Milliseconds timeoutMillis) - : RemoteCommandRequest(requestIdCounter.addAndFetch(1), - theTarget, - theDbName, - theCmdObj, - metadataObj, - opCtx, - timeoutMillis) {} - -std::string RemoteCommandRequest::toString() const { +RemoteCommandRequestBase::RemoteCommandRequestBase() : id(requestIdCounter.addAndFetch(1)) {} + +template <typename T> +RemoteCommandRequestImpl<T>::RemoteCommandRequestImpl() = default; + +template <typename T> +RemoteCommandRequestImpl<T>::RemoteCommandRequestImpl(RequestId requestId, + const T& theTarget, + const std::string& theDbName, + const BSONObj& theCmdObj, + const BSONObj& metadataObj, + OperationContext* opCtx, + Milliseconds timeoutMillis) + : RemoteCommandRequestBase(requestId, theDbName, theCmdObj, metadataObj, opCtx, timeoutMillis), + target(theTarget) { + IF_CONSTEXPR(std::is_same_v<T, std::vector<HostAndPort>>) { + invariant(!theTarget.empty()); + } +} + +template <typename T> +RemoteCommandRequestImpl<T>::RemoteCommandRequestImpl(const T& theTarget, + const std::string& theDbName, + const BSONObj& theCmdObj, + const BSONObj& metadataObj, + OperationContext* opCtx, + Milliseconds timeoutMillis) + : RemoteCommandRequestImpl(requestIdCounter.addAndFetch(1), + theTarget, + theDbName, + theCmdObj, + metadataObj, + opCtx, + timeoutMillis) {} + +template <typename T> +std::string RemoteCommandRequestImpl<T>::toString() const { str::stream out; - out << "RemoteCommand " << id << " -- target:" << target.toString() << " db:" << dbname; + out << "RemoteCommand " << id << " -- target:"; + IF_CONSTEXPR(std::is_same_v<HostAndPort, T>) { + out << target.toString(); + } + else { + out << "[{}]"_format(fmt::join(target, ", ")); + } + out << " db:" << dbname; if (expirationDate != kNoExpirationDate) { out << " expDate:" << expirationDate.toString(); @@ -93,7 +122,8 @@ std::string RemoteCommandRequest::toString() const { return out; } -bool RemoteCommandRequest::operator==(const RemoteCommandRequest& rhs) const { +template <typename T> +bool RemoteCommandRequestImpl<T>::operator==(const RemoteCommandRequestImpl& rhs) const { if (this == &rhs) { return true; } @@ -103,13 +133,13 @@ bool RemoteCommandRequest::operator==(const RemoteCommandRequest& rhs) const { timeout == rhs.timeout; } -bool RemoteCommandRequest::operator!=(const RemoteCommandRequest& rhs) const { +template <typename T> +bool RemoteCommandRequestImpl<T>::operator!=(const RemoteCommandRequestImpl& rhs) const { return !(*this == rhs); } -std::ostream& operator<<(std::ostream& os, const RemoteCommandRequest& request) { - return os << request.toString(); -} +template struct RemoteCommandRequestImpl<HostAndPort>; +template struct RemoteCommandRequestImpl<std::vector<HostAndPort>>; } // namespace executor } // namespace mongo diff --git a/src/mongo/executor/remote_command_request.h b/src/mongo/executor/remote_command_request.h index 81469c10878..b776a14eafa 100644 --- a/src/mongo/executor/remote_command_request.h +++ b/src/mongo/executor/remote_command_request.h @@ -35,16 +35,14 @@ #include "mongo/db/jsobj.h" #include "mongo/rpc/metadata.h" #include "mongo/transport/transport_layer.h" +#include "mongo/util/concepts.h" #include "mongo/util/net/hostandport.h" #include "mongo/util/time_support.h" namespace mongo { namespace executor { -/** - * Type of object describing a command to execute against a remote MongoDB node. - */ -struct RemoteCommandRequest { +struct RemoteCommandRequestBase { // Indicates that there is no timeout for the request to complete static constexpr Milliseconds kNoTimeout{-1}; @@ -54,40 +52,17 @@ struct RemoteCommandRequest { // Type to represent the internal id of this request typedef uint64_t RequestId; - RemoteCommandRequest(); - - RemoteCommandRequest(RequestId requestId, - const HostAndPort& theTarget, - const std::string& theDbName, - const BSONObj& theCmdObj, - const BSONObj& metadataObj, - OperationContext* opCtx, - Milliseconds timeoutMillis); - - RemoteCommandRequest(const HostAndPort& theTarget, - const std::string& theDbName, - const BSONObj& theCmdObj, - const BSONObj& metadataObj, - OperationContext* opCtx, - Milliseconds timeoutMillis = kNoTimeout); - - RemoteCommandRequest(const HostAndPort& theTarget, - const std::string& theDbName, - const BSONObj& theCmdObj, - OperationContext* opCtx, - Milliseconds timeoutMillis = kNoTimeout) - : RemoteCommandRequest( - theTarget, theDbName, theCmdObj, rpc::makeEmptyMetadata(), opCtx, timeoutMillis) {} - - std::string toString() const; - - bool operator==(const RemoteCommandRequest& rhs) const; - bool operator!=(const RemoteCommandRequest& rhs) const; + RemoteCommandRequestBase(); + RemoteCommandRequestBase(RequestId requestId, + const std::string& theDbName, + const BSONObj& theCmdObj, + const BSONObj& metadataObj, + OperationContext* opCtx, + Milliseconds timeoutMillis); // Internal id of this request. Not interpreted and used for tracing purposes only. RequestId id; - HostAndPort target; std::string dbname; BSONObj metadata{rpc::makeEmptyMetadata()}; BSONObj cmdObj; @@ -107,9 +82,70 @@ struct RemoteCommandRequest { Date_t expirationDate = kNoExpirationDate; transport::ConnectSSLMode sslMode = transport::kGlobalSSLMode; + +protected: + ~RemoteCommandRequestBase() = default; +}; + +/** + * Type of object describing a command to execute against a remote MongoDB node. + */ +template <typename Target> +struct RemoteCommandRequestImpl : RemoteCommandRequestBase { + RemoteCommandRequestImpl(); + + // Allow implicit conversion from RemoteCommandRequest to RemoteCommandRequestOnAny + REQUIRES_FOR_NON_TEMPLATE(std::is_same_v<Target, std::vector<HostAndPort>>) + RemoteCommandRequestImpl(const RemoteCommandRequestImpl<HostAndPort>& other) + : RemoteCommandRequestBase(other), target({other.target}) {} + + // Allow conversion from RemoteCommandRequestOnAny to RemoteCommandRequest with the index of a + // particular host + REQUIRES_FOR_NON_TEMPLATE(std::is_same_v<Target, HostAndPort>) + RemoteCommandRequestImpl(const RemoteCommandRequestImpl<std::vector<HostAndPort>>& other, + size_t idx) + : RemoteCommandRequestBase(other), target(other.target[idx]) {} + + RemoteCommandRequestImpl(RequestId requestId, + const Target& theTarget, + const std::string& theDbName, + const BSONObj& theCmdObj, + const BSONObj& metadataObj, + OperationContext* opCtx, + Milliseconds timeoutMillis); + + RemoteCommandRequestImpl(const Target& theTarget, + const std::string& theDbName, + const BSONObj& theCmdObj, + const BSONObj& metadataObj, + OperationContext* opCtx, + Milliseconds timeoutMillis = kNoTimeout); + + RemoteCommandRequestImpl(const Target& theTarget, + const std::string& theDbName, + const BSONObj& theCmdObj, + OperationContext* opCtx, + Milliseconds timeoutMillis = kNoTimeout) + : RemoteCommandRequestImpl( + theTarget, theDbName, theCmdObj, rpc::makeEmptyMetadata(), opCtx, timeoutMillis) {} + + std::string toString() const; + + bool operator==(const RemoteCommandRequestImpl& rhs) const; + bool operator!=(const RemoteCommandRequestImpl& rhs) const; + + friend std::ostream& operator<<(std::ostream& os, const RemoteCommandRequestImpl& response) { + return (os << response.toString()); + } + + Target target; }; -std::ostream& operator<<(std::ostream& os, const RemoteCommandRequest& response); +extern template struct RemoteCommandRequestImpl<HostAndPort>; +extern template struct RemoteCommandRequestImpl<std::vector<HostAndPort>>; + +using RemoteCommandRequest = RemoteCommandRequestImpl<HostAndPort>; +using RemoteCommandRequestOnAny = RemoteCommandRequestImpl<std::vector<HostAndPort>>; } // namespace executor } // namespace mongo diff --git a/src/mongo/executor/remote_command_response.cpp b/src/mongo/executor/remote_command_response.cpp index cfe23390805..8baafdb3d67 100644 --- a/src/mongo/executor/remote_command_response.cpp +++ b/src/mongo/executor/remote_command_response.cpp @@ -38,48 +38,37 @@ namespace mongo { namespace executor { -RemoteCommandResponse::RemoteCommandResponse(ErrorCodes::Error code, std::string reason) +RemoteCommandResponseBase::RemoteCommandResponseBase(ErrorCodes::Error code, std::string reason) : status(code, reason){}; -RemoteCommandResponse::RemoteCommandResponse(ErrorCodes::Error code, - std::string reason, - Milliseconds millis) +RemoteCommandResponseBase::RemoteCommandResponseBase(ErrorCodes::Error code, + std::string reason, + Milliseconds millis) : elapsedMillis(millis), status(code, reason) {} -RemoteCommandResponse::RemoteCommandResponse(Status s) : status(std::move(s)) { +RemoteCommandResponseBase::RemoteCommandResponseBase(Status s) : status(std::move(s)) { invariant(!isOK()); }; -RemoteCommandResponse::RemoteCommandResponse(Status s, Milliseconds millis) +RemoteCommandResponseBase::RemoteCommandResponseBase(Status s, Milliseconds millis) : elapsedMillis(millis), status(std::move(s)) { invariant(!isOK()); }; -RemoteCommandResponse::RemoteCommandResponse(BSONObj dataObj, Milliseconds millis) +RemoteCommandResponseBase::RemoteCommandResponseBase(BSONObj dataObj, Milliseconds millis) : data(std::move(dataObj)), elapsedMillis(millis) { // The buffer backing the default empty BSONObj has static duration so it is effectively // owned. invariant(data.isOwned() || data.objdata() == BSONObj().objdata()); }; -RemoteCommandResponse::RemoteCommandResponse(Message messageArg, - BSONObj dataObj, - Milliseconds millis) - : message(std::make_shared<const Message>(std::move(messageArg))), - data(std::move(dataObj)), - elapsedMillis(millis) { - if (!data.isOwned()) { - data.shareOwnershipWith(message->sharedBuffer()); - } -} - // TODO(amidvidy): we currently discard output docs when we use this constructor. We should // have RCR hold those too, but we need more machinery before that is possible. -RemoteCommandResponse::RemoteCommandResponse(const rpc::ReplyInterface& rpcReply, - Milliseconds millis) - : RemoteCommandResponse(rpcReply.getCommandReply(), std::move(millis)) {} +RemoteCommandResponseBase::RemoteCommandResponseBase(const rpc::ReplyInterface& rpcReply, + Milliseconds millis) + : RemoteCommandResponseBase(rpcReply.getCommandReply(), std::move(millis)) {} -bool RemoteCommandResponse::isOK() const { +bool RemoteCommandResponseBase::isOK() const { return status.isOK(); } @@ -104,5 +93,64 @@ std::ostream& operator<<(std::ostream& os, const RemoteCommandResponse& response return os << response.toString(); } +RemoteCommandResponse::RemoteCommandResponse(const RemoteCommandOnAnyResponse& other) + : RemoteCommandResponseBase(other) {} + +RemoteCommandOnAnyResponse::RemoteCommandOnAnyResponse(boost::optional<HostAndPort> hp, + ErrorCodes::Error code, + std::string reason) + : RemoteCommandResponseBase(code, std::move(reason)), target(std::move(hp)) {} + +RemoteCommandOnAnyResponse::RemoteCommandOnAnyResponse(boost::optional<HostAndPort> hp, + ErrorCodes::Error code, + std::string reason, + Milliseconds millis) + : RemoteCommandResponseBase(code, std::move(reason), millis), target(std::move(hp)) {} + +RemoteCommandOnAnyResponse::RemoteCommandOnAnyResponse(boost::optional<HostAndPort> hp, Status s) + : RemoteCommandResponseBase(std::move(s)), target(std::move(hp)) {} + +RemoteCommandOnAnyResponse::RemoteCommandOnAnyResponse(boost::optional<HostAndPort> hp, + Status s, + Milliseconds millis) + : RemoteCommandResponseBase(std::move(s), millis), target(std::move(hp)) {} + +RemoteCommandOnAnyResponse::RemoteCommandOnAnyResponse(HostAndPort hp, + BSONObj dataObj, + Milliseconds millis) + : RemoteCommandResponseBase(std::move(dataObj), millis), target(std::move(hp)) {} + +RemoteCommandOnAnyResponse::RemoteCommandOnAnyResponse(HostAndPort hp, + const rpc::ReplyInterface& rpcReply, + Milliseconds millis) + : RemoteCommandResponseBase(rpcReply, millis), target(std::move(hp)) {} + +RemoteCommandOnAnyResponse::RemoteCommandOnAnyResponse(boost::optional<HostAndPort> hp, + const RemoteCommandResponse& other) + : RemoteCommandResponseBase(other), target(std::move(hp)) {} + +bool RemoteCommandOnAnyResponse::operator==(const RemoteCommandOnAnyResponse& rhs) const { + if (this == &rhs) { + return true; + } + SimpleBSONObjComparator bsonComparator; + return bsonComparator.evaluate(data == rhs.data) && elapsedMillis == rhs.elapsedMillis && + target == rhs.target; +} + +bool RemoteCommandOnAnyResponse::operator!=(const RemoteCommandOnAnyResponse& rhs) const { + return !(*this == rhs); +} + +std::string RemoteCommandOnAnyResponse::toString() const { + return str::stream() << "RemoteOnAnyResponse -- " + << " cmd:" << data.toString() << " target: " + << (!target ? StringData("[none]") : StringData(target->toString())); +} + +std::ostream& operator<<(std::ostream& os, const RemoteCommandOnAnyResponse& response) { + return os << response.toString(); +} + } // namespace executor } // namespace mongo diff --git a/src/mongo/executor/remote_command_response.h b/src/mongo/executor/remote_command_response.h index 18e6da0fcd7..2cd90b37974 100644 --- a/src/mongo/executor/remote_command_response.h +++ b/src/mongo/executor/remote_command_response.h @@ -37,6 +37,7 @@ #include "mongo/base/status.h" #include "mongo/db/jsobj.h" #include "mongo/rpc/message.h" +#include "mongo/util/net/hostandport.h" #include "mongo/util/time_support.h" namespace mongo { @@ -51,37 +52,85 @@ namespace executor { /** * Type of object describing the response of previously sent RemoteCommandRequest. */ -struct RemoteCommandResponse { - RemoteCommandResponse() = default; +struct RemoteCommandResponseBase { + RemoteCommandResponseBase() = default; - RemoteCommandResponse(ErrorCodes::Error code, std::string reason); + RemoteCommandResponseBase(ErrorCodes::Error code, std::string reason); - RemoteCommandResponse(ErrorCodes::Error code, std::string reason, Milliseconds millis); + RemoteCommandResponseBase(ErrorCodes::Error code, std::string reason, Milliseconds millis); - RemoteCommandResponse(Status s); + RemoteCommandResponseBase(Status s); - RemoteCommandResponse(Status s, Milliseconds millis); + RemoteCommandResponseBase(Status s, Milliseconds millis); - RemoteCommandResponse(BSONObj dataObj, Milliseconds millis); + RemoteCommandResponseBase(BSONObj dataObj, Milliseconds millis); - RemoteCommandResponse(Message messageArg, BSONObj dataObj, Milliseconds millis); - - RemoteCommandResponse(const rpc::ReplyInterface& rpcReply, Milliseconds millis); + RemoteCommandResponseBase(const rpc::ReplyInterface& rpcReply, Milliseconds millis); bool isOK() const; + BSONObj data; // Always owned. May point into message. + boost::optional<Milliseconds> elapsedMillis; + Status status = Status::OK(); + +protected: + ~RemoteCommandResponseBase() = default; +}; + +struct RemoteCommandOnAnyResponse; + +struct RemoteCommandResponse : RemoteCommandResponseBase { + using RemoteCommandResponseBase::RemoteCommandResponseBase; + + RemoteCommandResponse(const RemoteCommandOnAnyResponse& other); + std::string toString() const; bool operator==(const RemoteCommandResponse& rhs) const; bool operator!=(const RemoteCommandResponse& rhs) const; - std::shared_ptr<const Message> message; // May be null. - BSONObj data; // Always owned. May point into message. - boost::optional<Milliseconds> elapsedMillis; - Status status = Status::OK(); + friend std::ostream& operator<<(std::ostream& os, const RemoteCommandResponse& request); }; -std::ostream& operator<<(std::ostream& os, const RemoteCommandResponse& request); +/** + * This type is a RemoteCommandResponse + the target that the origin request was actually run on. + * + * For the moment, it is only returned by scheduleRemoteCommandOnAny, and should be thought of as a + * different return type for that rpc api, rather than a higher-information RemoteCommandResponse. + */ +struct RemoteCommandOnAnyResponse : RemoteCommandResponseBase { + RemoteCommandOnAnyResponse() = default; + + RemoteCommandOnAnyResponse(boost::optional<HostAndPort> hp, + ErrorCodes::Error code, + std::string reason); + + RemoteCommandOnAnyResponse(boost::optional<HostAndPort> hp, + ErrorCodes::Error code, + std::string reason, + Milliseconds millis); + + RemoteCommandOnAnyResponse(boost::optional<HostAndPort> hp, Status s); + + RemoteCommandOnAnyResponse(boost::optional<HostAndPort> hp, Status s, Milliseconds millis); + + RemoteCommandOnAnyResponse(HostAndPort hp, BSONObj dataObj, Milliseconds millis); + + RemoteCommandOnAnyResponse(HostAndPort hp, + const rpc::ReplyInterface& rpcReply, + Milliseconds millis); + + RemoteCommandOnAnyResponse(boost::optional<HostAndPort> hp, const RemoteCommandResponse& other); + + std::string toString() const; + + bool operator==(const RemoteCommandOnAnyResponse& rhs) const; + bool operator!=(const RemoteCommandOnAnyResponse& rhs) const; + + boost::optional<HostAndPort> target; + + friend std::ostream& operator<<(std::ostream& os, const RemoteCommandOnAnyResponse& request); +}; } // namespace executor } // namespace mongo diff --git a/src/mongo/executor/scoped_task_executor.cpp b/src/mongo/executor/scoped_task_executor.cpp index ba7c1fc7f3b..e891d3c69a8 100644 --- a/src/mongo/executor/scoped_task_executor.cpp +++ b/src/mongo/executor/scoped_task_executor.cpp @@ -131,12 +131,13 @@ public: std::move(work)); } - StatusWith<CallbackHandle> scheduleRemoteCommand(const RemoteCommandRequest& request, - const RemoteCommandCallbackFn& cb, - const BatonHandle& baton = nullptr) override { + StatusWith<CallbackHandle> scheduleRemoteCommandOnAny( + const RemoteCommandRequestOnAny& request, + const RemoteCommandOnAnyCallbackFn& cb, + const BatonHandle& baton = nullptr) override { return _wrapCallback( [&](auto&& x) { - return _executor->scheduleRemoteCommand(request, std::move(x), baton); + return _executor->scheduleRemoteCommandOnAny(request, std::move(x), baton); }, cb); } @@ -250,9 +251,9 @@ private: args.status = kShutdownStatus; } else { - static_assert(std::is_same_v<ArgsT, RemoteCommandCallbackArgs>, + static_assert(std::is_same_v<ArgsT, RemoteCommandOnAnyCallbackArgs>, "_wrapCallback only supports CallbackArgs and " - "RemoteCommandCallbackArgs"); + "RemoteCommandOnAnyCallbackArgs"); args.response.status = kShutdownStatus; } diff --git a/src/mongo/executor/task_executor.cpp b/src/mongo/executor/task_executor.cpp index dea064860c7..83d69b525fd 100644 --- a/src/mongo/executor/task_executor.cpp +++ b/src/mongo/executor/task_executor.cpp @@ -79,6 +79,20 @@ TaskExecutor::RemoteCommandCallbackArgs::RemoteCommandCallbackArgs( const ResponseStatus& theResponse) : executor(theExecutor), myHandle(theHandle), request(theRequest), response(theResponse) {} +TaskExecutor::RemoteCommandCallbackArgs::RemoteCommandCallbackArgs( + const RemoteCommandOnAnyCallbackArgs& other, size_t idx) + : executor(other.executor), + myHandle(other.myHandle), + request(other.request, idx), + response(other.response) {} + +TaskExecutor::RemoteCommandOnAnyCallbackArgs::RemoteCommandOnAnyCallbackArgs( + TaskExecutor* theExecutor, + const CallbackHandle& theHandle, + const RemoteCommandRequestOnAny& theRequest, + const ResponseOnAnyStatus& theResponse) + : executor(theExecutor), myHandle(theHandle), request(theRequest), response(theResponse) {} + TaskExecutor::CallbackState* TaskExecutor::getCallbackFromHandle(const CallbackHandle& cbHandle) { return cbHandle.getCallback(); } @@ -96,5 +110,15 @@ void TaskExecutor::setCallbackForHandle(CallbackHandle* cbHandle, cbHandle->setCallback(std::move(callback)); } + +StatusWith<TaskExecutor::CallbackHandle> TaskExecutor::scheduleRemoteCommand( + const RemoteCommandRequest& request, + const RemoteCommandCallbackFn& cb, + const BatonHandle& baton) { + return scheduleRemoteCommandOnAny(request, [cb](const RemoteCommandOnAnyCallbackArgs& args) { + cb({args, 0}); + }); +} + } // namespace executor } // namespace mongo diff --git a/src/mongo/executor/task_executor.h b/src/mongo/executor/task_executor.h index e06048e7269..cb1b2f00cef 100644 --- a/src/mongo/executor/task_executor.h +++ b/src/mongo/executor/task_executor.h @@ -79,12 +79,14 @@ class TaskExecutor : public OutOfLineExecutor { public: struct CallbackArgs; struct RemoteCommandCallbackArgs; + struct RemoteCommandOnAnyCallbackArgs; class CallbackState; class CallbackHandle; class EventState; class EventHandle; using ResponseStatus = RemoteCommandResponse; + using ResponseOnAnyStatus = RemoteCommandOnAnyResponse; /** * Type of a regular callback function. @@ -106,6 +108,9 @@ public: */ using RemoteCommandCallbackFn = stdx::function<void(const RemoteCommandCallbackArgs&)>; + using RemoteCommandOnAnyCallbackFn = + stdx::function<void(const RemoteCommandOnAnyCallbackArgs&)>; + /** * Destroys the task executor. Implicitly performs the equivalent of shutdown() and join() * before returning, if necessary. @@ -254,9 +259,13 @@ public: * Contract: Implementations should guarantee that callback should be called *after* doing any * processing related to the callback. */ - virtual StatusWith<CallbackHandle> scheduleRemoteCommand( - const RemoteCommandRequest& request, - const RemoteCommandCallbackFn& cb, + virtual StatusWith<CallbackHandle> scheduleRemoteCommand(const RemoteCommandRequest& request, + const RemoteCommandCallbackFn& cb, + const BatonHandle& baton = nullptr); + + virtual StatusWith<CallbackHandle> scheduleRemoteCommandOnAny( + const RemoteCommandRequestOnAny& request, + const RemoteCommandOnAnyCallbackFn& cb, const BatonHandle& baton = nullptr) = 0; /** @@ -457,11 +466,25 @@ struct TaskExecutor::RemoteCommandCallbackArgs { const RemoteCommandRequest& theRequest, const ResponseStatus& theResponse); + RemoteCommandCallbackArgs(const RemoteCommandOnAnyCallbackArgs& other, size_t idx); + TaskExecutor* executor; CallbackHandle myHandle; RemoteCommandRequest request; ResponseStatus response; }; +struct TaskExecutor::RemoteCommandOnAnyCallbackArgs { + RemoteCommandOnAnyCallbackArgs(TaskExecutor* theExecutor, + const CallbackHandle& theHandle, + const RemoteCommandRequestOnAny& theRequest, + const ResponseOnAnyStatus& theResponse); + + TaskExecutor* executor; + CallbackHandle myHandle; + RemoteCommandRequestOnAny request; + ResponseOnAnyStatus response; +}; + } // namespace executor } // namespace mongo diff --git a/src/mongo/executor/task_executor_test_fixture.h b/src/mongo/executor/task_executor_test_fixture.h index f0d599cddf2..f8815b45360 100644 --- a/src/mongo/executor/task_executor_test_fixture.h +++ b/src/mongo/executor/task_executor_test_fixture.h @@ -32,12 +32,12 @@ #include <memory> #include "mongo/base/string_data.h" +#include "mongo/executor/remote_command_request.h" #include "mongo/unittest/unittest.h" namespace mongo { namespace executor { -struct RemoteCommandRequest; class TaskExecutor; class NetworkInterface; class NetworkInterfaceMock; diff --git a/src/mongo/executor/thread_pool_task_executor.cpp b/src/mongo/executor/thread_pool_task_executor.cpp index 73fca1533d9..7ef0669aea0 100644 --- a/src/mongo/executor/thread_pool_task_executor.cpp +++ b/src/mongo/executor/thread_pool_task_executor.cpp @@ -391,10 +391,10 @@ using ResponseStatus = TaskExecutor::ResponseStatus; // which expects a RemoteCommandResponse as part of RemoteCommandCallbackArgs, // can be run despite a RemoteCommandResponse never having been created. void remoteCommandFinished(const TaskExecutor::CallbackArgs& cbData, - const TaskExecutor::RemoteCommandCallbackFn& cb, - const RemoteCommandRequest& request, - const ResponseStatus& rs) { - cb(TaskExecutor::RemoteCommandCallbackArgs(cbData.executor, cbData.myHandle, request, rs)); + const TaskExecutor::RemoteCommandOnAnyCallbackFn& cb, + const RemoteCommandRequestOnAny& request, + const TaskExecutor::ResponseOnAnyStatus& rs) { + cb({cbData.executor, cbData.myHandle, request, rs}); } // If the request failed to receive a connection from the pool, @@ -402,11 +402,10 @@ void remoteCommandFinished(const TaskExecutor::CallbackArgs& cbData, // which expects a RemoteCommandResponse as part of RemoteCommandCallbackArgs, // can be run despite a RemoteCommandResponse never having been created. void remoteCommandFailedEarly(const TaskExecutor::CallbackArgs& cbData, - const TaskExecutor::RemoteCommandCallbackFn& cb, - const RemoteCommandRequest& request) { + const TaskExecutor::RemoteCommandOnAnyCallbackFn& cb, + const RemoteCommandRequestOnAny& request) { invariant(!cbData.status.isOK()); - cb(TaskExecutor::RemoteCommandCallbackArgs( - cbData.executor, cbData.myHandle, request, {cbData.status})); + cb({cbData.executor, cbData.myHandle, request, {boost::none, cbData.status}}); } // The command names that the initial sync test fixture pauses on during the collection cloning @@ -416,9 +415,9 @@ const auto initialSyncPauseCmds = } // namespace -StatusWith<TaskExecutor::CallbackHandle> ThreadPoolTaskExecutor::scheduleRemoteCommand( - const RemoteCommandRequest& request, - const RemoteCommandCallbackFn& cb, +StatusWith<TaskExecutor::CallbackHandle> ThreadPoolTaskExecutor::scheduleRemoteCommandOnAny( + const RemoteCommandRequestOnAny& request, + const RemoteCommandOnAnyCallbackFn& cb, const BatonHandle& baton) { if (MONGO_FAIL_POINT(initialSyncFuzzerSynchronizationPoint1)) { @@ -441,7 +440,7 @@ StatusWith<TaskExecutor::CallbackHandle> ThreadPoolTaskExecutor::scheduleRemoteC } } - RemoteCommandRequest scheduledRequest = request; + RemoteCommandRequestOnAny scheduledRequest = request; if (request.timeout == RemoteCommandRequest::kNoTimeout) { scheduledRequest.expirationDate = RemoteCommandRequest::kNoExpirationDate; } else { @@ -467,7 +466,7 @@ StatusWith<TaskExecutor::CallbackHandle> ThreadPoolTaskExecutor::scheduleRemoteC auto commandStatus = _net->startCommand( swCbHandle.getValue(), scheduledRequest, - [this, scheduledRequest, cbState, cb](const ResponseStatus& response) { + [this, scheduledRequest, cbState, cb](const ResponseOnAnyStatus& response) { using std::swap; CallbackFn newCb = [cb, scheduledRequest, response](const CallbackArgs& cbData) { remoteCommandFinished(cbData, cb, scheduledRequest, response); diff --git a/src/mongo/executor/thread_pool_task_executor.h b/src/mongo/executor/thread_pool_task_executor.h index 89010b451e1..8285785d748 100644 --- a/src/mongo/executor/thread_pool_task_executor.h +++ b/src/mongo/executor/thread_pool_task_executor.h @@ -84,9 +84,10 @@ public: void waitForEvent(const EventHandle& event) override; StatusWith<CallbackHandle> scheduleWork(CallbackFn&& work) override; StatusWith<CallbackHandle> scheduleWorkAt(Date_t when, CallbackFn&& work) override; - StatusWith<CallbackHandle> scheduleRemoteCommand(const RemoteCommandRequest& request, - const RemoteCommandCallbackFn& cb, - const BatonHandle& baton = nullptr) override; + StatusWith<CallbackHandle> scheduleRemoteCommandOnAny( + const RemoteCommandRequestOnAny& request, + const RemoteCommandOnAnyCallbackFn& cb, + const BatonHandle& baton = nullptr) override; void cancel(const CallbackHandle& cbHandle) override; void wait(const CallbackHandle& cbHandle, Interruptible* interruptible = Interruptible::notInterruptible()) override; |