diff options
author | Adam Midvidy <amidvidy@gmail.com> | 2015-08-14 11:26:02 -0400 |
---|---|---|
committer | Adam Midvidy <amidvidy@gmail.com> | 2015-08-18 11:45:13 -0400 |
commit | e6ddd3da54a42088e5fa524df34f06def842eda9 (patch) | |
tree | 62c9173b288c07188e6952635ab013d1d168f9e3 /src | |
parent | 4c61da5028cafcea65dd35d7f0e02941495cda98 (diff) | |
download | mongo-e6ddd3da54a42088e5fa524df34f06def842eda9.tar.gz |
SERVER-19420 implement connection hook API in NetworkInterfaceASIO
Diffstat (limited to 'src')
-rw-r--r-- | src/mongo/executor/SConscript | 2 | ||||
-rw-r--r-- | src/mongo/executor/async_mock_stream_factory.cpp | 105 | ||||
-rw-r--r-- | src/mongo/executor/async_mock_stream_factory.h | 57 | ||||
-rw-r--r-- | src/mongo/executor/network_interface_asio.cpp | 6 | ||||
-rw-r--r-- | src/mongo/executor/network_interface_asio.h | 7 | ||||
-rw-r--r-- | src/mongo/executor/network_interface_asio_auth.cpp | 49 | ||||
-rw-r--r-- | src/mongo/executor/network_interface_asio_command.cpp | 44 | ||||
-rw-r--r-- | src/mongo/executor/network_interface_asio_test.cpp | 423 | ||||
-rw-r--r-- | src/mongo/executor/network_interface_mock_test.cpp | 41 | ||||
-rw-r--r-- | src/mongo/executor/remote_command_request.h | 12 | ||||
-rw-r--r-- | src/mongo/executor/remote_command_response.cpp | 8 | ||||
-rw-r--r-- | src/mongo/executor/remote_command_response.h | 7 | ||||
-rw-r--r-- | src/mongo/executor/test_network_connection_hook.h | 87 |
13 files changed, 658 insertions, 190 deletions
diff --git a/src/mongo/executor/SConscript b/src/mongo/executor/SConscript index 22e895e1f5d..3d3d6b05f7a 100644 --- a/src/mongo/executor/SConscript +++ b/src/mongo/executor/SConscript @@ -72,7 +72,6 @@ env.Library(target='network_test_env', env.Library( target='network_interface_asio', source=[ - 'async_mock_stream_factory.cpp', 'async_secure_stream.cpp', 'async_secure_stream_factory.cpp', 'async_stream.cpp', @@ -96,6 +95,7 @@ env.Library( env.CppUnitTest( target='network_interface_asio_test', source=[ + 'async_mock_stream_factory.cpp', 'network_interface_asio_test.cpp', ], LIBDEPS=[ diff --git a/src/mongo/executor/async_mock_stream_factory.cpp b/src/mongo/executor/async_mock_stream_factory.cpp index 88229496110..ef007870176 100644 --- a/src/mongo/executor/async_mock_stream_factory.cpp +++ b/src/mongo/executor/async_mock_stream_factory.cpp @@ -32,12 +32,18 @@ #include "mongo/executor/async_mock_stream_factory.h" +#include <exception> #include <iterator> #include <system_error> +#include "mongo/rpc/command_reply_builder.h" +#include "mongo/rpc/factory.h" +#include "mongo/rpc/legacy_reply_builder.h" +#include "mongo/rpc/request_interface.h" #include "mongo/stdx/memory.h" #include "mongo/util/assert_util.h" #include "mongo/util/log.h" +#include "mongo/util/net/message.h" namespace mongo { namespace executor { @@ -86,9 +92,9 @@ void AsyncMockStreamFactory::MockStream::connect(asio::ip::tcp::resolver::iterat ConnectHandler&& connectHandler) { { stdx::unique_lock<stdx::mutex> lk(_mutex); - log() << "connect() for: " << _target; - _block_inlock(&lk); + // Block before returning from connect. + _block_inlock(kBlockedBeforeConnect, &lk); } _io_service->post([connectHandler, endpoints] { connectHandler(std::error_code()); }); } @@ -96,15 +102,13 @@ void AsyncMockStreamFactory::MockStream::connect(asio::ip::tcp::resolver::iterat void AsyncMockStreamFactory::MockStream::write(asio::const_buffer buf, StreamHandler&& writeHandler) { stdx::unique_lock<stdx::mutex> lk(_mutex); - log() << "write() for: " << _target; - auto begin = asio::buffer_cast<const uint8_t*>(buf); auto size = asio::buffer_size(buf); _writeQueue.push({begin, begin + size}); // Block after data is written. - _block_inlock(&lk); + _block_inlock(kBlockedAfterWrite, &lk); lk.unlock(); _io_service->post([writeHandler, size] { writeHandler(std::error_code(), size); }); @@ -113,10 +117,8 @@ void AsyncMockStreamFactory::MockStream::write(asio::const_buffer buf, void AsyncMockStreamFactory::MockStream::read(asio::mutable_buffer buf, StreamHandler&& readHandler) { stdx::unique_lock<stdx::mutex> lk(_mutex); - log() << "read() for: " << _target; - // Block before data is read. - _block_inlock(&lk); + _block_inlock(kBlockedBeforeRead, &lk); auto nextRead = std::move(_readQueue.front()); _readQueue.pop(); @@ -138,26 +140,26 @@ void AsyncMockStreamFactory::MockStream::read(asio::mutable_buffer buf, void AsyncMockStreamFactory::MockStream::pushRead(std::vector<uint8_t> toRead) { stdx::unique_lock<stdx::mutex> lk(_mutex); - invariant(_blocked); + invariant(_state != kRunning); _readQueue.emplace(std::move(toRead)); } std::vector<uint8_t> AsyncMockStreamFactory::MockStream::popWrite() { stdx::unique_lock<stdx::mutex> lk(_mutex); - invariant(_blocked); + invariant(_state != kRunning); auto nextWrite = std::move(_writeQueue.front()); _writeQueue.pop(); return nextWrite; } -void AsyncMockStreamFactory::MockStream::_block_inlock(stdx::unique_lock<stdx::mutex>* lk) { - log() << "blocking in stream for: " << _target; - invariant(!_blocked); - _blocked = true; +void AsyncMockStreamFactory::MockStream::_block_inlock(StreamState state, + stdx::unique_lock<stdx::mutex>* lk) { + invariant(_state == kRunning); + _state = state; lk->unlock(); _cv.notify_one(); lk->lock(); - _cv.wait(*lk, [this]() { return !_blocked; }); + _cv.wait(*lk, [this]() { return _state == kRunning; }); } void AsyncMockStreamFactory::MockStream::unblock() { @@ -166,18 +168,79 @@ void AsyncMockStreamFactory::MockStream::unblock() { } void AsyncMockStreamFactory::MockStream::_unblock_inlock(stdx::unique_lock<stdx::mutex>* lk) { - log() << "unblocking stream for: " << _target; - invariant(_blocked); - _blocked = false; + invariant(_state != kRunning); + _state = kRunning; lk->unlock(); _cv.notify_one(); lk->lock(); } -void AsyncMockStreamFactory::MockStream::waitUntilBlocked() { +auto AsyncMockStreamFactory::MockStream::waitUntilBlocked() -> StreamState { stdx::unique_lock<stdx::mutex> lk(_mutex); - log() << "waiting until stream for " << _target << " has blocked"; - _cv.wait(lk, [this]() { return _blocked; }); + _cv.wait(lk, [this]() { return _state != kRunning; }); + return _state; +} + +HostAndPort AsyncMockStreamFactory::MockStream::target() { + return _target; +} + +void AsyncMockStreamFactory::MockStream::simulateServer( + rpc::Protocol proto, + const stdx::function<RemoteCommandResponse(RemoteCommandRequest)> replyFunc) { + std::exception_ptr ex; + uint32_t messageId = 0; + + RemoteCommandResponse resp; + { + WriteEvent write{this}; + + std::vector<uint8_t> messageData = popWrite(); + Message msg(messageData.data(), false); + + auto parsedRequest = rpc::makeRequest(&msg); + ASSERT(parsedRequest->getProtocol() == proto); + + RemoteCommandRequest rcr(target(), *parsedRequest); + + messageId = msg.header().getId(); + + // So we can allow ASSERTs in replyFunc, we capture any exceptions, but rethrow + // them later to prevent deadlock + try { + resp = replyFunc(std::move(rcr)); + } catch (...) { + ex = std::current_exception(); + } + } + + auto replyBuilder = rpc::makeReplyBuilder(proto); + replyBuilder->setMetadata(resp.metadata); + replyBuilder->setCommandReply(resp.data); + + auto replyMsg = replyBuilder->done(); + replyMsg->header().setResponseTo(messageId); + + { + // The first read will be for the header. + ReadEvent read{this}; + auto hdrBytes = reinterpret_cast<const uint8_t*>(replyMsg->header().view2ptr()); + pushRead({hdrBytes, hdrBytes + sizeof(MSGHEADER::Value)}); + } + + { + // The second read will be for the message data. + ReadEvent read{this}; + auto dataBytes = reinterpret_cast<const uint8_t*>(replyMsg->buf()); + auto pastHeader = dataBytes; + std::advance(pastHeader, sizeof(MSGHEADER::Value)); + pushRead({pastHeader, dataBytes + static_cast<std::size_t>(replyMsg->size())}); + } + + if (ex) { + // Rethrow ASSERTS after the NIA completes it's Write-Read-Read sequence. + std::rethrow_exception(ex); + } } } // namespace executor diff --git a/src/mongo/executor/async_mock_stream_factory.h b/src/mongo/executor/async_mock_stream_factory.h index 0948b7b13cd..856f12e7187 100644 --- a/src/mongo/executor/async_mock_stream_factory.h +++ b/src/mongo/executor/async_mock_stream_factory.h @@ -36,8 +36,13 @@ #include "mongo/executor/async_stream_factory_interface.h" #include "mongo/executor/async_stream_interface.h" -#include "mongo/stdx/mutex.h" +#include "mongo/executor/remote_command_request.h" +#include "mongo/executor/remote_command_response.h" +#include "mongo/rpc/protocol.h" #include "mongo/stdx/condition_variable.h" +#include "mongo/stdx/functional.h" +#include "mongo/stdx/mutex.h" +#include "mongo/unittest/unittest.h" #include "mongo/util/net/hostandport.h" namespace mongo { @@ -57,6 +62,15 @@ public: MockStream(asio::io_service* io_service, AsyncMockStreamFactory* factory, const HostAndPort& target); + + // Use unscoped enum so we can specialize on it + enum StreamState { + kRunning, + kBlockedBeforeConnect, + kBlockedBeforeRead, + kBlockedAfterWrite, + }; + ~MockStream(); void connect(asio::ip::tcp::resolver::iterator endpoints, @@ -64,27 +78,32 @@ public: void write(asio::const_buffer buf, StreamHandler&& writeHandler) override; void read(asio::mutable_buffer buf, StreamHandler&& readHandler) override; - void waitUntilBlocked(); + HostAndPort target(); + + StreamState waitUntilBlocked(); std::vector<uint8_t> popWrite(); void pushRead(std::vector<uint8_t> toRead); void unblock(); + void simulateServer( + rpc::Protocol proto, + const stdx::function<RemoteCommandResponse(RemoteCommandRequest)> replyFunc); + private: void _unblock_inlock(stdx::unique_lock<stdx::mutex>* lk); - void _block_inlock(stdx::unique_lock<stdx::mutex>* lk); + void _block_inlock(StreamState state, stdx::unique_lock<stdx::mutex>* lk); asio::io_service* _io_service; AsyncMockStreamFactory* _factory; HostAndPort _target; - stdx::mutex _mutex; stdx::condition_variable _cv; - bool _blocked{false}; + StreamState _state{kRunning}; std::queue<std::vector<uint8_t>> _readQueue; std::queue<std::vector<uint8_t>> _writeQueue; @@ -102,5 +121,33 @@ private: std::unordered_map<HostAndPort, MockStream*> _streams; }; +template <int EventType> +class StreamEvent { +public: + StreamEvent(AsyncMockStreamFactory::MockStream* stream) : _stream(stream) { + ASSERT(stream->waitUntilBlocked() == EventType); + } + + void skip() { + _stream->unblock(); + skipped = true; + } + + ~StreamEvent() { + if (!skipped) { + skip(); + } + } + +private: + bool skipped = false; + AsyncMockStreamFactory::MockStream* _stream = nullptr; +}; + +using ReadEvent = StreamEvent<AsyncMockStreamFactory::MockStream::StreamState::kBlockedBeforeRead>; +using WriteEvent = StreamEvent<AsyncMockStreamFactory::MockStream::StreamState::kBlockedAfterWrite>; +using ConnectEvent = + StreamEvent<AsyncMockStreamFactory::MockStream::StreamState::kBlockedBeforeConnect>; + } // namespace executor } // namespace mongo diff --git a/src/mongo/executor/network_interface_asio.cpp b/src/mongo/executor/network_interface_asio.cpp index b122d76ab87..4b40543a53b 100644 --- a/src/mongo/executor/network_interface_asio.cpp +++ b/src/mongo/executor/network_interface_asio.cpp @@ -48,7 +48,13 @@ namespace executor { NetworkInterfaceASIO::NetworkInterfaceASIO( std::unique_ptr<AsyncStreamFactoryInterface> streamFactory) + : NetworkInterfaceASIO(std::move(streamFactory), nullptr) {} + +NetworkInterfaceASIO::NetworkInterfaceASIO( + std::unique_ptr<AsyncStreamFactoryInterface> streamFactory, + std::unique_ptr<NetworkConnectionHook> networkConnectionHook) : _io_service(), + _hook(std::move(networkConnectionHook)), _resolver(_io_service), _state(State::kReady), _streamFactory(std::move(streamFactory)), diff --git a/src/mongo/executor/network_interface_asio.h b/src/mongo/executor/network_interface_asio.h index 724a16d7c9a..7cd1a2f2993 100644 --- a/src/mongo/executor/network_interface_asio.h +++ b/src/mongo/executor/network_interface_asio.h @@ -38,6 +38,7 @@ #include "mongo/base/status.h" #include "mongo/base/system_error.h" +#include "mongo/executor/network_connection_hook.h" #include "mongo/executor/network_interface.h" #include "mongo/executor/remote_command_request.h" #include "mongo/executor/remote_command_response.h" @@ -54,7 +55,6 @@ namespace executor { class AsyncStreamFactoryInterface; class AsyncStreamInterface; -class NetworkConnectionHook; /** * Implementation of the replication system's network interface using Christopher @@ -62,6 +62,8 @@ class NetworkConnectionHook; */ class NetworkInterfaceASIO final : public NetworkInterface { public: + NetworkInterfaceASIO(std::unique_ptr<AsyncStreamFactoryInterface> streamFactory, + std::unique_ptr<NetworkConnectionHook> networkConnectionHook); NetworkInterfaceASIO(std::unique_ptr<AsyncStreamFactoryInterface> streamFactory); std::string getDiagnosticString() override; std::string getHostName() override; @@ -239,6 +241,7 @@ private: void _setupSocket(AsyncOp* op, asio::ip::tcp::resolver::iterator endpoints); void _runIsMaster(AsyncOp* op); + void _runConnectionHook(AsyncOp* op); void _authenticate(AsyncOp* op); // Communication state machine @@ -254,6 +257,8 @@ private: asio::io_service _io_service; stdx::thread _serviceRunner; + std::unique_ptr<NetworkConnectionHook> _hook; + asio::ip::tcp::resolver _resolver; std::atomic<State> _state; diff --git a/src/mongo/executor/network_interface_asio_auth.cpp b/src/mongo/executor/network_interface_asio_auth.cpp index a6f1fdcfa38..f239389408a 100644 --- a/src/mongo/executor/network_interface_asio_auth.cpp +++ b/src/mongo/executor/network_interface_asio_auth.cpp @@ -62,33 +62,40 @@ void NetworkInterfaceASIO::_runIsMaster(AsyncOp* op) { // Callback to parse protocol information out of received ismaster response auto parseIsMaster = [this, op]() { - try { - auto commandReply = rpc::makeReply(&(op->command().toRecv())); - BSONObj isMasterReply = commandReply->getCommandReply(); - auto protocolSet = rpc::parseProtocolSetFromIsMasterReply(isMasterReply); - if (!protocolSet.isOK()) - return _completeOperation(op, protocolSet.getStatus()); - - op->connection().setServerProtocols(protocolSet.getValue()); + auto swCommandReply = op->command().response(rpc::Protocol::kOpQuery, now()); + if (!swCommandReply.isOK()) { + return _completeOperation(op, swCommandReply.getStatus()); + } - // Set the operation protocol - auto negotiatedProtocol = rpc::negotiate(op->connection().serverProtocols(), - op->connection().clientProtocols()); + auto commandReply = std::move(swCommandReply.getValue()); - if (!negotiatedProtocol.isOK()) { - return _completeOperation(op, negotiatedProtocol.getStatus()); + if (_hook) { + // Run the validation hook. + auto validHost = _hook->validateHost(op->request().target, commandReply); + if (!validHost.isOK()) { + return _completeOperation(op, validHost); } + } + + auto protocolSet = rpc::parseProtocolSetFromIsMasterReply(commandReply.data); + if (!protocolSet.isOK()) + return _completeOperation(op, protocolSet.getStatus()); - op->setOperationProtocol(negotiatedProtocol.getValue()); + op->connection().setServerProtocols(protocolSet.getValue()); - // Advance the state machine - return _authenticate(op); + // Set the operation protocol + auto negotiatedProtocol = + rpc::negotiate(op->connection().serverProtocols(), op->connection().clientProtocols()); - } catch (...) { - // makeReply will throw if the reply was invalid. - return _completeOperation(op, exceptionToStatus()); + if (!negotiatedProtocol.isOK()) { + return _completeOperation(op, negotiatedProtocol.getStatus()); } + + op->setOperationProtocol(negotiatedProtocol.getValue()); + + return _authenticate(op); + }; _asyncRunCommand(&cmd, @@ -105,7 +112,7 @@ void NetworkInterfaceASIO::_authenticate(AsyncOp* op) { // This check is sufficient to see if auth is enabled on the system, // and avoids creating dependencies on deeper, less accessible auth code. if (!isInternalAuthSet()) { - return asio::post(_io_service, [this, op]() { _beginCommunication(op); }); + return _runConnectionHook(op); } // We will only have a valid clientName if SSL is enabled. @@ -136,7 +143,7 @@ void NetworkInterfaceASIO::_authenticate(AsyncOp* op) { auto authHook = [this, op](auth::AuthResponse response) { if (!response.isOK()) return _completeOperation(op, response); - return _beginCommunication(op); + return _runConnectionHook(op); }; auto params = getInternalUserAuthParamsWithFallback(); diff --git a/src/mongo/executor/network_interface_asio_command.cpp b/src/mongo/executor/network_interface_asio_command.cpp index 35237245ea2..fa1a560521b 100644 --- a/src/mongo/executor/network_interface_asio_command.cpp +++ b/src/mongo/executor/network_interface_asio_command.cpp @@ -41,6 +41,7 @@ #include "mongo/rpc/factory.h" #include "mongo/rpc/protocol.h" #include "mongo/rpc/reply_interface.h" +#include "mongo/rpc/request_builder_interface.h" #include "mongo/stdx/memory.h" #include "mongo/util/assert_util.h" #include "mongo/util/log.h" @@ -271,5 +272,48 @@ void NetworkInterfaceASIO::_asyncRunCommand(AsyncCommand* cmd, NetworkOpHandler asyncSendMessage(cmd->conn().stream(), &cmd->toSend(), std::move(sendMessageCallback)); } +void NetworkInterfaceASIO::_runConnectionHook(AsyncOp* op) { + if (!_hook) { + return _beginCommunication(op); + } + + auto swOptionalRequest = _hook->makeRequest(op->request().target); + + if (!swOptionalRequest.isOK()) { + return _completeOperation(op, swOptionalRequest.getStatus()); + } + + auto optionalRequest = std::move(swOptionalRequest.getValue()); + + if (optionalRequest == boost::none) { + return _beginCommunication(op); + } + + auto& cmd = op->beginCommand(*optionalRequest, op->operationProtocol(), now()); + + auto finishHook = [this, op]() { + auto response = op->command().response(op->operationProtocol(), now()); + + if (!response.isOK()) { + return _completeOperation(op, response.getStatus()); + } + + auto handleStatus = + _hook->handleReply(op->request().target, std::move(response.getValue())); + + if (!handleStatus.isOK()) { + return _completeOperation(op, handleStatus); + } + + return _beginCommunication(op); + }; + + return _asyncRunCommand(&cmd, + [this, op, finishHook](std::error_code ec, std::size_t bytes) { + _validateAndRun(op, ec, finishHook); + }); +} + + } // namespace executor } // namespace mongo diff --git a/src/mongo/executor/network_interface_asio_test.cpp b/src/mongo/executor/network_interface_asio_test.cpp index a1aeb324eab..da3cde07874 100644 --- a/src/mongo/executor/network_interface_asio_test.cpp +++ b/src/mongo/executor/network_interface_asio_test.cpp @@ -30,15 +30,15 @@ #include "mongo/platform/basic.h" +#include <boost/optional.hpp> + +#include "mongo/base/status_with.h" #include "mongo/db/jsobj.h" #include "mongo/db/wire_version.h" #include "mongo/executor/async_mock_stream_factory.h" #include "mongo/executor/network_interface_asio.h" -#include "mongo/rpc/command_reply_builder.h" -#include "mongo/rpc/factory.h" +#include "mongo/executor/test_network_connection_hook.h" #include "mongo/rpc/legacy_reply_builder.h" -#include "mongo/rpc/protocol.h" -#include "mongo/rpc/request_interface.h" #include "mongo/stdx/future.h" #include "mongo/stdx/memory.h" #include "mongo/unittest/unittest.h" @@ -49,6 +49,8 @@ namespace mongo { namespace executor { namespace { +HostAndPort testHost{"localhost", 20000}; + class NetworkInterfaceASIOTest : public mongo::unittest::Test { public: void setUp() override { @@ -73,7 +75,11 @@ public: return *_streamFactory; } -private: + void simulateServerReply(AsyncMockStreamFactory::MockStream* stream, + rpc::Protocol proto, + const stdx::function<RemoteCommandResponse(RemoteCommandRequest)>) {} + +protected: AsyncMockStreamFactory* _streamFactory; std::unique_ptr<NetworkInterfaceASIO> _net; }; @@ -104,124 +110,339 @@ TEST_F(NetworkInterfaceASIOTest, StartCommand) { auto stream = streamFactory().blockUntilStreamExists(testHost); // Allow stream to connect. - { - stream->waitUntilBlocked(); - auto guard = MakeGuard([&] { stream->unblock(); }); - } - - log() << "connected"; - - uint32_t isMasterMsgId = 0; - - { - stream->waitUntilBlocked(); - auto guard = MakeGuard([&] { stream->unblock(); }); - - log() << "NIA blocked after writing isMaster request"; - - // Check that an isMaster has been run on the stream - std::vector<uint8_t> messageData = stream->popWrite(); - Message msg(messageData.data(), false); - - auto request = rpc::makeRequest(&msg); - - ASSERT_EQ(request->getCommandName(), "isMaster"); - ASSERT_EQ(request->getDatabase(), "admin"); - - isMasterMsgId = msg.header().getId(); - - // Check that we used OP_QUERY. - ASSERT(request->getProtocol() == rpc::Protocol::kOpQuery); - } - - rpc::LegacyReplyBuilder replyBuilder; - replyBuilder.setMetadata(BSONObj()); - replyBuilder.setCommandReply(BSON("minWireVersion" << mongo::minWireVersion << "maxWireVersion" - << mongo::maxWireVersion)); - auto replyMsg = replyBuilder.done(); - replyMsg->header().setResponseTo(isMasterMsgId); - - { - stream->waitUntilBlocked(); - auto guard = MakeGuard([&] { stream->unblock(); }); - - log() << "NIA blocked before reading isMaster reply header"; - - // write out the full message now, even though another read() call will read the rest. - auto hdrBytes = reinterpret_cast<const uint8_t*>(replyMsg->header().view2ptr()); + ConnectEvent{stream}.skip(); - stream->pushRead({hdrBytes, hdrBytes + sizeof(MSGHEADER::Value)}); + // simulate isMaster reply. + stream->simulateServer( + rpc::Protocol::kOpQuery, + [](RemoteCommandRequest request) -> RemoteCommandResponse { + ASSERT_EQ(std::string{request.cmdObj.firstElementFieldName()}, "isMaster"); + ASSERT_EQ(request.dbname, "admin"); - auto dataBytes = reinterpret_cast<const uint8_t*>(replyMsg->buf()); - auto pastHeader = dataBytes; - std::advance(pastHeader, sizeof(MSGHEADER::Value)); // skip the header this time - - stream->pushRead({pastHeader, dataBytes + static_cast<std::size_t>(replyMsg->size())}); - } - - { - stream->waitUntilBlocked(); - auto guard = MakeGuard([&] { stream->unblock(); }); - log() << "NIA blocked before reading isMaster reply data"; - } + RemoteCommandResponse response; + response.data = BSON("minWireVersion" << mongo::minWireVersion << "maxWireVersion" + << mongo::maxWireVersion); + return response; + }); + auto expectedMetadata = BSON("meep" + << "beep"); auto expectedCommandReply = BSON("boop" << "bop" << "ok" << 1.0); - auto expectedMetadata = BSON("meep" - << "beep"); - { - stream->waitUntilBlocked(); - auto guard = MakeGuard([&] { stream->unblock(); }); + // simulate user command + stream->simulateServer(rpc::Protocol::kOpCommandV1, + [&](RemoteCommandRequest request) -> RemoteCommandResponse { + ASSERT_EQ(std::string{request.cmdObj.firstElementFieldName()}, + "foo"); + ASSERT_EQ(request.dbname, "testDB"); - log() << "blocked after write(), reading user command request"; + RemoteCommandResponse response; + response.data = expectedCommandReply; + response.metadata = expectedMetadata; + return response; + }); - std::vector<uint8_t> messageData{stream->popWrite()}; + auto res = fut.get(); - Message msg(messageData.data(), false); - auto request = rpc::makeRequest(&msg); + ASSERT(callbackCalled); + ASSERT_EQ(res.data, expectedCommandReply); + ASSERT_EQ(res.metadata, expectedMetadata); +} - // the command we requested should be running. - ASSERT_EQ(request->getCommandName(), "foo"); - ASSERT_EQ(request->getDatabase(), "testDB"); +class NetworkInterfaceASIOConnectionHookTest : public NetworkInterfaceASIOTest { +public: + void setUp() override {} - // we should be using op command given our previous isMaster reply. - ASSERT(request->getProtocol() == rpc::Protocol::kOpCommandV1); + void start(std::unique_ptr<NetworkConnectionHook> hook) { + auto factory = stdx::make_unique<AsyncMockStreamFactory>(); + // keep unowned pointer, but pass ownership to NIA + _streamFactory = factory.get(); + _net = stdx::make_unique<NetworkInterfaceASIO>(std::move(factory), std::move(hook)); + _net->startup(); + } +}; - rpc::CommandReplyBuilder replyBuilder; - replyBuilder.setMetadata(expectedMetadata).setCommandReply(expectedCommandReply); - auto replyMsg = replyBuilder.done(); +template <typename T> +void assertThrowsStatus(stdx::future<T>&& fut, const Status& s) { + ASSERT([&] { + try { + std::forward<stdx::future<T>>(fut).get(); + return false; + } catch (const DBException& ex) { + return ex.toStatus() == s; + } + }()); +} - replyMsg->header().setResponseTo(msg.header().getId()); +TEST_F(NetworkInterfaceASIOConnectionHookTest, ValidateHostInvalid) { + bool validateCalled = false; + bool hostCorrect = false; + bool isMasterReplyCorrect = false; + bool makeRequestCalled = false; + bool handleReplyCalled = false; + + auto validationFailedStatus = Status(ErrorCodes::AlreadyInitialized, "blahhhhh"); + + start(makeTestHook( + [&](const HostAndPort& remoteHost, const RemoteCommandResponse& isMasterReply) { + validateCalled = true; + hostCorrect = (remoteHost == testHost); + isMasterReplyCorrect = (isMasterReply.data["TESTKEY"].str() == "TESTVALUE"); + return validationFailedStatus; + }, + [&](const HostAndPort& remoteHost) -> StatusWith<boost::optional<RemoteCommandRequest>> { + makeRequestCalled = true; + return {boost::none}; + }, + [&](const HostAndPort& remoteHost, RemoteCommandResponse&& response) { + handleReplyCalled = true; + return Status::OK(); + })); + + stdx::promise<RemoteCommandResponse> done; + auto doneFuture = done.get_future(); + + net().startCommand({}, + {testHost, + "blah", + BSON("foo" + << "bar")}, + [&](StatusWith<RemoteCommandResponse> result) { + try { + done.set_value(uassertStatusOK(result)); + } catch (...) { + done.set_exception(std::current_exception()); + } + }); - // write out the full message now, even though another read() call will read the rest. - auto hdrBytes = reinterpret_cast<const uint8_t*>(replyMsg->header().view2ptr()); + auto stream = streamFactory().blockUntilStreamExists(testHost); - stream->pushRead({hdrBytes, hdrBytes + sizeof(MSGHEADER::Value)}); + ConnectEvent{stream}.skip(); + + // simulate isMaster reply. + stream->simulateServer(rpc::Protocol::kOpQuery, + [](RemoteCommandRequest request) -> RemoteCommandResponse { + RemoteCommandResponse response; + response.data = BSON("minWireVersion" + << mongo::minWireVersion << "maxWireVersion" + << mongo::maxWireVersion << "TESTKEY" + << "TESTVALUE"); + return response; + }); + + // we should stop here. + assertThrowsStatus(std::move(doneFuture), validationFailedStatus); + ASSERT(validateCalled); + ASSERT(hostCorrect); + ASSERT(isMasterReplyCorrect); + + ASSERT(!makeRequestCalled); + ASSERT(!handleReplyCalled); +} - auto dataBytes = reinterpret_cast<const uint8_t*>(replyMsg->buf()); - auto pastHeader = dataBytes; - std::advance(pastHeader, sizeof(MSGHEADER::Value)); // skip the header this time +TEST_F(NetworkInterfaceASIOConnectionHookTest, MakeRequestReturnsError) { + bool makeRequestCalled = false; + bool handleReplyCalled = false; + + Status makeRequestError{ErrorCodes::DBPathInUse, "bloooh"}; + + start(makeTestHook( + [&](const HostAndPort& remoteHost, const RemoteCommandResponse& isMasterReply) + -> Status { return Status::OK(); }, + [&](const HostAndPort& remoteHost) -> StatusWith<boost::optional<RemoteCommandRequest>> { + makeRequestCalled = true; + return makeRequestError; + }, + [&](const HostAndPort& remoteHost, RemoteCommandResponse&& response) { + handleReplyCalled = true; + return Status::OK(); + })); + + stdx::promise<RemoteCommandResponse> done; + auto doneFuture = done.get_future(); + + net().startCommand({}, + {testHost, + "blah", + BSON("foo" + << "bar")}, + [&](StatusWith<RemoteCommandResponse> result) { + try { + done.set_value(uassertStatusOK(result)); + } catch (...) { + done.set_exception(std::current_exception()); + } + }); - stream->pushRead({pastHeader, dataBytes + static_cast<std::size_t>(replyMsg->size())}); - } + auto stream = streamFactory().blockUntilStreamExists(testHost); + ConnectEvent{stream}.skip(); + + // simulate isMaster reply. + stream->simulateServer(rpc::Protocol::kOpQuery, + [](RemoteCommandRequest request) -> RemoteCommandResponse { + RemoteCommandResponse response; + response.data = BSON("minWireVersion" << mongo::minWireVersion + << "maxWireVersion" + << mongo::maxWireVersion); + return response; + }); + + // We should stop here. + assertThrowsStatus(std::move(doneFuture), makeRequestError); + + ASSERT(makeRequestCalled); + ASSERT(!handleReplyCalled); +} - { - stream->waitUntilBlocked(); - auto guard = MakeGuard([&] { stream->unblock(); }); - } +TEST_F(NetworkInterfaceASIOConnectionHookTest, MakeRequestReturnsNone) { + bool makeRequestCalled = false; + bool handleReplyCalled = false; + + start(makeTestHook( + [&](const HostAndPort& remoteHost, const RemoteCommandResponse& isMasterReply) + -> Status { return Status::OK(); }, + [&](const HostAndPort& remoteHost) -> StatusWith<boost::optional<RemoteCommandRequest>> { + makeRequestCalled = true; + return {boost::none}; + }, + [&](const HostAndPort& remoteHost, RemoteCommandResponse&& response) { + handleReplyCalled = true; + return Status::OK(); + })); + + stdx::promise<RemoteCommandResponse> done; + auto doneFuture = done.get_future(); + + auto commandRequest = BSON("foo" + << "bar"); + + net().startCommand({}, + {testHost, "blah", commandRequest}, + [&](StatusWith<RemoteCommandResponse> result) { + try { + done.set_value(uassertStatusOK(result)); + } catch (...) { + done.set_exception(std::current_exception()); + } + }); - { - stream->waitUntilBlocked(); - auto guard = MakeGuard([&] { stream->unblock(); }); - } - auto res = fut.get(); + auto stream = streamFactory().blockUntilStreamExists(testHost); + ConnectEvent{stream}.skip(); + + // simulate isMaster reply. + stream->simulateServer(rpc::Protocol::kOpQuery, + [](RemoteCommandRequest request) -> RemoteCommandResponse { + RemoteCommandResponse response; + response.data = BSON("minWireVersion" << mongo::minWireVersion + << "maxWireVersion" + << mongo::maxWireVersion); + return response; + }); + + auto commandReply = BSON("foo" + << "boo" + << "ok" << 1.0); + + auto metadata = BSON("aaa" + << "bbb"); + + // Simulate user command. + stream->simulateServer(rpc::Protocol::kOpCommandV1, + [&](RemoteCommandRequest request) -> RemoteCommandResponse { + ASSERT_EQ(commandRequest, request.cmdObj); + + RemoteCommandResponse response; + response.data = commandReply; + response.metadata = metadata; + return response; + }); + + // We should get back the reply now. + auto reply = doneFuture.get(); + ASSERT_EQ(reply.data, commandReply); + ASSERT_EQ(reply.metadata, metadata); +} - ASSERT(callbackCalled); - ASSERT_EQ(res.data, expectedCommandReply); - ASSERT_EQ(res.metadata, expectedMetadata); +TEST_F(NetworkInterfaceASIOConnectionHookTest, HandleReplyReturnsError) { + bool makeRequestCalled = false; + + bool handleReplyCalled = false; + bool handleReplyArgumentCorrect = false; + + BSONObj hookCommandRequest = BSON("1ddd" + << "fff"); + BSONObj hookRequestMetadata = BSON("wdwd" << 1212); + + BSONObj hookCommandReply = BSON("blah" + << "blah" + << "ok" << 1.0); + BSONObj hookReplyMetadata = BSON("1111" << 2222); + + Status handleReplyError{ErrorCodes::AuthSchemaIncompatible, "daowdjkpowkdjpow"}; + + start(makeTestHook( + [&](const HostAndPort& remoteHost, const RemoteCommandResponse& isMasterReply) + -> Status { return Status::OK(); }, + [&](const HostAndPort& remoteHost) -> StatusWith<boost::optional<RemoteCommandRequest>> { + makeRequestCalled = true; + return {boost::make_optional<RemoteCommandRequest>( + {testHost, "foo", hookCommandRequest, hookRequestMetadata})}; + + }, + [&](const HostAndPort& remoteHost, RemoteCommandResponse&& response) { + handleReplyCalled = true; + handleReplyArgumentCorrect = + (response.data == hookCommandReply) && (response.metadata == hookReplyMetadata); + return handleReplyError; + })); + + stdx::promise<RemoteCommandResponse> done; + auto doneFuture = done.get_future(); + auto commandRequest = BSON("foo" + << "bar"); + net().startCommand({}, + {testHost, "blah", commandRequest}, + [&](StatusWith<RemoteCommandResponse> result) { + try { + done.set_value(uassertStatusOK(result)); + } catch (...) { + done.set_exception(std::current_exception()); + } + }); + + auto stream = streamFactory().blockUntilStreamExists(testHost); + ConnectEvent{stream}.skip(); + + // simulate isMaster reply. + stream->simulateServer(rpc::Protocol::kOpQuery, + [](RemoteCommandRequest request) -> RemoteCommandResponse { + RemoteCommandResponse response; + response.data = BSON("minWireVersion" << mongo::minWireVersion + << "maxWireVersion" + << mongo::maxWireVersion); + return response; + }); + + // Simulate hook reply + stream->simulateServer(rpc::Protocol::kOpCommandV1, + [&](RemoteCommandRequest request) -> RemoteCommandResponse { + ASSERT_EQ(request.cmdObj, hookCommandRequest); + ASSERT_EQ(request.metadata, hookRequestMetadata); + + RemoteCommandResponse response; + response.data = hookCommandReply; + response.metadata = hookReplyMetadata; + return response; + }); + + assertThrowsStatus(std::move(doneFuture), handleReplyError); + + ASSERT(makeRequestCalled); + ASSERT(handleReplyCalled); + ASSERT(handleReplyArgumentCorrect); } TEST_F(NetworkInterfaceASIOTest, setAlarm) { diff --git a/src/mongo/executor/network_interface_mock_test.cpp b/src/mongo/executor/network_interface_mock_test.cpp index 011e0f0924e..283ec71fe59 100644 --- a/src/mongo/executor/network_interface_mock_test.cpp +++ b/src/mongo/executor/network_interface_mock_test.cpp @@ -33,9 +33,10 @@ #include <utility> #include "mongo/base/status.h" -#include "mongo/executor/network_interface.h" #include "mongo/executor/network_connection_hook.h" +#include "mongo/executor/network_interface.h" #include "mongo/executor/network_interface_mock.h" +#include "mongo/executor/test_network_connection_hook.h" #include "mongo/executor/thread_pool_mock.h" #include "mongo/stdx/memory.h" #include "mongo/unittest/unittest.h" @@ -44,44 +45,6 @@ namespace mongo { namespace executor { namespace { -template <typename ValidateFunc, typename RequestFunc, typename ReplyFunc> -class TestConnectionHook final : public NetworkConnectionHook { -public: - TestConnectionHook(ValidateFunc&& validateFunc, - RequestFunc&& requestFunc, - ReplyFunc&& replyFunc) - : _validateFunc(std::forward<ValidateFunc>(validateFunc)), - _requestFunc(std::forward<RequestFunc>(requestFunc)), - _replyFunc(std::forward<ReplyFunc>(replyFunc)) {} - - Status validateHost(const HostAndPort& remoteHost, - const RemoteCommandResponse& isMasterReply) override { - return _validateFunc(remoteHost, isMasterReply); - } - - StatusWith<boost::optional<RemoteCommandRequest>> makeRequest(const HostAndPort& remoteHost) { - return _requestFunc(remoteHost); - } - - Status handleReply(const HostAndPort& remoteHost, RemoteCommandResponse&& response) { - return _replyFunc(remoteHost, std::move(response)); - } - -private: - ValidateFunc _validateFunc; - RequestFunc _requestFunc; - ReplyFunc _replyFunc; -}; - -template <typename Val, typename Req, typename Rep> -static std::unique_ptr<TestConnectionHook<Val, Req, Rep>> makeTestHook(Val&& validateFunc, - Req&& requestFunc, - Rep&& replyFunc) { - return stdx::make_unique<TestConnectionHook<Val, Req, Rep>>(std::forward<Val>(validateFunc), - std::forward<Req>(requestFunc), - std::forward<Rep>(replyFunc)); -} - class NetworkInterfaceMockTest : public mongo::unittest::Test { public: NetworkInterfaceMockTest() : _net{}, _executor(&_net, 1) {} diff --git a/src/mongo/executor/remote_command_request.h b/src/mongo/executor/remote_command_request.h index 325455ec1da..91a3689f61c 100644 --- a/src/mongo/executor/remote_command_request.h +++ b/src/mongo/executor/remote_command_request.h @@ -31,8 +31,9 @@ #include <string> #include "mongo/db/jsobj.h" -#include "mongo/util/net/hostandport.h" #include "mongo/rpc/metadata.h" +#include "mongo/rpc/request_interface.h" +#include "mongo/util/net/hostandport.h" #include "mongo/util/time_support.h" namespace mongo { @@ -72,6 +73,15 @@ struct RemoteCommandRequest { : RemoteCommandRequest( theTarget, theDbName, theCmdObj, rpc::makeEmptyMetadata(), timeoutMillis) {} + RemoteCommandRequest(const HostAndPort& theTarget, + const rpc::RequestInterface& request, + const Milliseconds timeoutMillis = kNoTimeout) + : RemoteCommandRequest(theTarget, + request.getDatabase().toString(), + request.getCommandArgs(), + request.getMetadata(), + timeoutMillis) {} + std::string toString() const; HostAndPort target; diff --git a/src/mongo/executor/remote_command_response.cpp b/src/mongo/executor/remote_command_response.cpp index a27d6856f24..a0b1baeeb79 100644 --- a/src/mongo/executor/remote_command_response.cpp +++ b/src/mongo/executor/remote_command_response.cpp @@ -30,11 +30,19 @@ #include "mongo/executor/remote_command_response.h" +#include "mongo/rpc/reply_interface.h" #include "mongo/util/mongoutils/str.h" namespace mongo { namespace executor { +// TODO(amidvidy): we currently discard output docs when we use this constructor. We should +// have RCR hold those too, but we need more machinery before that is possible. +RemoteCommandResponse::RemoteCommandResponse(const rpc::ReplyInterface& rpcReply, + Milliseconds millis) + : RemoteCommandResponse(rpcReply.getCommandReply(), rpcReply.getMetadata(), std::move(millis)) { +} + std::string RemoteCommandResponse::toString() const { return str::stream() << "RemoteResponse -- " << " cmd:" << data.toString(); diff --git a/src/mongo/executor/remote_command_response.h b/src/mongo/executor/remote_command_response.h index 0f3e1c9ce0e..65d3f72c66b 100644 --- a/src/mongo/executor/remote_command_response.h +++ b/src/mongo/executor/remote_command_response.h @@ -34,6 +34,11 @@ #include "mongo/util/time_support.h" namespace mongo { + +namespace rpc { +class ReplyInterface; +} // namespace rpc + namespace executor { @@ -46,6 +51,8 @@ struct RemoteCommandResponse { RemoteCommandResponse(BSONObj dataObj, BSONObj metadataObj, Milliseconds millis) : data(std::move(dataObj)), metadata(std::move(metadataObj)), elapsedMillis(millis) {} + RemoteCommandResponse(const rpc::ReplyInterface& rpcReply, Milliseconds millis); + std::string toString() const; BSONObj data; diff --git a/src/mongo/executor/test_network_connection_hook.h b/src/mongo/executor/test_network_connection_hook.h new file mode 100644 index 00000000000..0294b23cd38 --- /dev/null +++ b/src/mongo/executor/test_network_connection_hook.h @@ -0,0 +1,87 @@ +/** + * Copyright (C) 2015 MongoDB Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, version 3, + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the GNU Affero General Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include <boost/optional.hpp> +#include <memory> + +#include "mongo/base/status_with.h" +#include "mongo/executor/network_connection_hook.h" +#include "mongo/stdx/memory.h" + +namespace mongo { +namespace executor { + + +/** + * A utility for creating one-off NetworkConnectionHook instances from inline lambdas. This is + * only to be used in testing code, not in production. + */ +template <typename ValidateFunc, typename RequestFunc, typename ReplyFunc> +class TestConnectionHook final : public NetworkConnectionHook { +public: + TestConnectionHook(ValidateFunc&& validateFunc, + RequestFunc&& requestFunc, + ReplyFunc&& replyFunc) + : _validateFunc(std::forward<ValidateFunc>(validateFunc)), + _requestFunc(std::forward<RequestFunc>(requestFunc)), + _replyFunc(std::forward<ReplyFunc>(replyFunc)) {} + + Status validateHost(const HostAndPort& remoteHost, + const RemoteCommandResponse& isMasterReply) override { + return _validateFunc(remoteHost, isMasterReply); + } + + StatusWith<boost::optional<RemoteCommandRequest>> makeRequest(const HostAndPort& remoteHost) { + return _requestFunc(remoteHost); + } + + Status handleReply(const HostAndPort& remoteHost, RemoteCommandResponse&& response) { + return _replyFunc(remoteHost, std::move(response)); + } + +private: + ValidateFunc _validateFunc; + RequestFunc _requestFunc; + ReplyFunc _replyFunc; +}; + +/** + * Factory function for TestConnectionHook instances. Needed for template type deduction, so that + * one can instantiate a TestConnectionHook instance without uttering the unutterable (types). + */ +template <typename Val, typename Req, typename Rep> +std::unique_ptr<TestConnectionHook<Val, Req, Rep>> makeTestHook(Val&& validateFunc, + Req&& requestFunc, + Rep&& replyFunc) { + return stdx::make_unique<TestConnectionHook<Val, Req, Rep>>(std::forward<Val>(validateFunc), + std::forward<Req>(requestFunc), + std::forward<Rep>(replyFunc)); +} + +} // namespace executor +} // namespace mongo |