diff options
Diffstat (limited to 'src/mongo')
-rw-r--r-- | src/mongo/transport/SConscript | 27 | ||||
-rw-r--r-- | src/mongo/transport/asio_utils.h | 8 | ||||
-rw-r--r-- | src/mongo/transport/mock_session.h | 2 | ||||
-rw-r--r-- | src/mongo/transport/session.h | 10 | ||||
-rw-r--r-- | src/mongo/transport/session_asio.h | 217 | ||||
-rw-r--r-- | src/mongo/transport/transport_layer_asio.cpp | 1 | ||||
-rw-r--r-- | src/mongo/transport/transport_layer_asio_test.cpp | 203 |
7 files changed, 377 insertions, 91 deletions
diff --git a/src/mongo/transport/SConscript b/src/mongo/transport/SConscript index 92120a8ce98..8527328b08c 100644 --- a/src/mongo/transport/SConscript +++ b/src/mongo/transport/SConscript @@ -67,6 +67,21 @@ tlEnv.Library( ], ) +tlEnv.CppUnitTest( + target='transport_layer_asio_test', + source=[ + 'transport_layer_asio_test.cpp', + ], + LIBDEPS=[ + 'transport_layer', + '$BUILD_DIR/mongo/base', + '$BUILD_DIR/mongo/db/service_context_noop_init', + ], + LIBDEPS_PRIVATE=[ + '$BUILD_DIR/third_party/shim_asio', + ], +) + tlEnv.Library( target='service_executor', source=[ @@ -183,15 +198,3 @@ env.CppUnitTest( ] ) -env.CppUnitTest( - target='transport_layer_asio_test', - source=[ - 'transport_layer_asio_test.cpp', - ], - LIBDEPS=[ - 'transport_layer', - '$BUILD_DIR/mongo/base', - '$BUILD_DIR/mongo/db/service_context_noop_init', - ], -) - diff --git a/src/mongo/transport/asio_utils.h b/src/mongo/transport/asio_utils.h index 141ec4f78ea..89e7821f5b8 100644 --- a/src/mongo/transport/asio_utils.h +++ b/src/mongo/transport/asio_utils.h @@ -54,6 +54,14 @@ inline Status errorCodeToStatus(const std::error_code& ec) { if (!ec) return Status::OK(); +#ifdef _WIN32 + if (ec == asio::error::timed_out) { +#else + if (ec == asio::error::try_again || ec == asio::error::would_block) { +#endif + return {ErrorCodes::NetworkTimeout, "Socket operation timed out"}; + } + // If the ec.category() is a mongoErrorCategory() then this error was propogated from // mongodb code and we should just pass the error cdoe along as-is. 
ErrorCodes::Error errorCode = (ec.category() == mongoErrorCategory()) diff --git a/src/mongo/transport/mock_session.h b/src/mongo/transport/mock_session.h index 2f091854193..411e57bd0ec 100644 --- a/src/mongo/transport/mock_session.h +++ b/src/mongo/transport/mock_session.h @@ -103,6 +103,8 @@ public: cb(sinkMessage(message)); } + void setTimeout(boost::optional<Milliseconds>) override {} + explicit MockSession(TransportLayer* tl) : _tl(checked_cast<TransportLayerMock*>(tl)), _remote(), _local() {} explicit MockSession(HostAndPort remote, HostAndPort local, TransportLayer* tl) diff --git a/src/mongo/transport/session.h b/src/mongo/transport/session.h index d25d32b9027..db80aa0814d 100644 --- a/src/mongo/transport/session.h +++ b/src/mongo/transport/session.h @@ -112,6 +112,16 @@ public: virtual Status sinkMessage(Message message) = 0; virtual void asyncSinkMessage(Message message, std::function<void(Status)> cb) = 0; + /** + * This should only be used to detect when the remote host has disappeared without + * notice. It does NOT work correctly for ensuring that operations complete or fail + * by some deadline. + * + * This timeout will only effect calls sourceMessage()/sinkMessage(). Async operations do not + * currently support timeouts. 
+ */ + virtual void setTimeout(boost::optional<Milliseconds> timeout) = 0; + virtual const HostAndPort& remote() const = 0; virtual const HostAndPort& local() const = 0; diff --git a/src/mongo/transport/session_asio.h b/src/mongo/transport/session_asio.h index 3d925a7cd62..f18dfac6fe1 100644 --- a/src/mongo/transport/session_asio.h +++ b/src/mongo/transport/session_asio.h @@ -104,7 +104,7 @@ public: ensureSync(); auto out = StatusWith<Message>(ErrorCodes::InternalError, "uninitialized..."); bool called = false; - sourceMessageImpl(true, [&](StatusWith<Message> in) { + sourceMessageImpl([&](StatusWith<Message> in) { out = std::move(in); called = true; }); @@ -114,7 +114,7 @@ public: void asyncSourceMessage(std::function<void(StatusWith<Message>)> cb) override { ensureAsync(); - sourceMessageImpl(false, std::move(cb)); + sourceMessageImpl(std::move(cb)); } Status sinkMessage(Message message) override { @@ -124,8 +124,7 @@ public: size_t size; bool called = false; - write(true, - asio::buffer(message.buf(), message.size()), + write(asio::buffer(message.buf(), message.size()), [&](const std::error_code& ec_, size_t size_) { ec = ec_; size = size_; @@ -144,7 +143,7 @@ public: void asyncSinkMessage(Message message, std::function<void(Status)> cb) override { ensureAsync(); - write(false, asio::buffer(message.buf(), message.size()), [ + write(asio::buffer(message.buf(), message.size()), [ message, // keep the buffer alive. 
cb = std::move(cb), this @@ -157,10 +156,56 @@ public: networkCounter.hitPhysicalOut(message.size()); cb(Status::OK()); }); - }; + } + void setTimeout(boost::optional<Milliseconds> timeout) override { + invariant(!timeout || timeout->count() > 0); + _configuredTimeout = timeout; + } private: + template <int Name> + class ASIOSocketTimeoutOption { + public: +#ifdef _WIN32 + using TimeoutType = DWORD; + + ASIOSocketTimeoutOption(Milliseconds timeoutVal) : _timeout(timeoutVal.count()) {} + +#else + using TimeoutType = timeval; + + ASIOSocketTimeoutOption(Milliseconds timeoutVal) { + _timeout.tv_sec = duration_cast<Seconds>(timeoutVal).count(); + const auto minusSeconds = timeoutVal - Seconds{_timeout.tv_sec}; + _timeout.tv_usec = duration_cast<Microseconds>(minusSeconds).count(); + } +#endif + + template <typename Protocol> + int name(const Protocol&) const { + return Name; + } + + template <typename Protocol> + const TimeoutType* data(const Protocol&) const { + return &_timeout; + } + + template <typename Protocol> + std::size_t size(const Protocol&) const { + return sizeof(_timeout); + } + + template <typename Protocol> + int level(const Protocol&) const { + return SOL_SOCKET; + } + + private: + TimeoutType _timeout; + }; + GenericSocket& getSocket() { #ifdef MONGO_CONFIG_SSL if (_sslSocket) { @@ -179,68 +224,65 @@ private: } template <typename Callback> - void sourceMessageImpl(bool sync, Callback&& cb) { + void sourceMessageImpl(Callback&& cb) { static constexpr auto kHeaderSize = sizeof(MSGHEADER::Value); auto headerBuffer = SharedBuffer::allocate(kHeaderSize); auto ptr = headerBuffer.get(); - read( - sync, - asio::buffer(ptr, kHeaderSize), - [ sync, cb = std::forward<Callback>(cb), headerBuffer = std::move(headerBuffer), this ]( - const std::error_code& ec, size_t size) mutable { - - if (ec) - return cb(errorCodeToStatus(ec)); - invariant(size == kHeaderSize); - - const auto msgLen = size_t(MSGHEADER::View(headerBuffer.get()).getMessageLength()); - if (msgLen < 
kHeaderSize || msgLen > MaxMessageSizeBytes) { - StringBuilder sb; - sb << "recv(): message msgLen " << msgLen << " is invalid. " - << "Min " << kHeaderSize << " Max: " << MaxMessageSizeBytes; - const auto str = sb.str(); - LOG(0) << str; - - return cb(Status(ErrorCodes::ProtocolError, str)); - } - - if (msgLen == size) { - // This probably isn't a real case since all (current) messages have bodies. - networkCounter.hitPhysicalIn(msgLen); - return cb(Message(std::move(headerBuffer))); - } - - auto buffer = SharedBuffer::allocate(msgLen); - memcpy(buffer.get(), headerBuffer.get(), kHeaderSize); - - MsgData::View msgView(buffer.get()); - read(sync, - asio::buffer(msgView.data(), msgView.dataLen()), - [ cb = std::move(cb), buffer = std::move(buffer), msgLen, this ]( - const std::error_code& ec, size_t size) mutable { - if (ec) - return cb(errorCodeToStatus(ec)); - networkCounter.hitPhysicalIn(msgLen); - return cb(Message(std::move(buffer))); - }); - }); + read(asio::buffer(ptr, kHeaderSize), + [ cb = std::forward<Callback>(cb), headerBuffer = std::move(headerBuffer), this ]( + const std::error_code& ec, size_t size) mutable { + if (ec) { + return cb(errorCodeToStatus(ec)); + } + + invariant(size == kHeaderSize); + + const auto msgLen = size_t(MSGHEADER::View(headerBuffer.get()).getMessageLength()); + if (msgLen < kHeaderSize || msgLen > MaxMessageSizeBytes) { + StringBuilder sb; + sb << "recv(): message msgLen " << msgLen << " is invalid. " + << "Min " << kHeaderSize << " Max: " << MaxMessageSizeBytes; + const auto str = sb.str(); + LOG(0) << str; + + return cb(Status(ErrorCodes::ProtocolError, str)); + } + + if (msgLen == size) { + // This probably isn't a real case since all (current) messages have bodies. 
+ networkCounter.hitPhysicalIn(msgLen); + return cb(Message(std::move(headerBuffer))); + } + + auto buffer = SharedBuffer::allocate(msgLen); + memcpy(buffer.get(), headerBuffer.get(), kHeaderSize); + + MsgData::View msgView(buffer.get()); + read(asio::buffer(msgView.data(), msgView.dataLen()), + [ cb = std::move(cb), buffer = std::move(buffer), msgLen, this ]( + const std::error_code& ec, size_t size) mutable { + if (ec) { + return cb(errorCodeToStatus(ec)); + } + + networkCounter.hitPhysicalIn(msgLen); + return cb(Message(std::move(buffer))); + }); + }); } - template <typename MutableBufferSequence, typename CompleteHandler> - void read(bool sync, const MutableBufferSequence& buffers, CompleteHandler&& handler) { + void read(const MutableBufferSequence& buffers, CompleteHandler&& handler) { #ifdef MONGO_CONFIG_SSL if (_sslSocket) { - return opportunisticRead( - sync, *_sslSocket, buffers, std::forward<CompleteHandler>(handler)); + return opportunisticRead(*_sslSocket, buffers, std::forward<CompleteHandler>(handler)); } else if (!_ranHandshake) { invariant(asio::buffer_size(buffers) >= sizeof(MSGHEADER::Value)); - auto postHandshakeCb = [this, sync, buffers, handler](Status status, - bool needsRead) mutable { + auto postHandshakeCb = [this, buffers, handler](Status status, bool needsRead) mutable { if (status.isOK()) { if (needsRead) { - read(sync, buffers, handler); + read(buffers, handler); } else { std::error_code ec; handler(ec, asio::buffer_size(buffers)); @@ -250,47 +292,63 @@ private: } }; - auto handshakeRecvCb = - [ this, postHandshakeCb = std::move(postHandshakeCb), sync, buffers ]( - const std::error_code& ec, size_t size) mutable { + auto handshakeRecvCb = [ this, postHandshakeCb = std::move(postHandshakeCb), buffers ]( + const std::error_code& ec, size_t size) mutable { _ranHandshake = true; if (ec) { postHandshakeCb(errorCodeToStatus(ec), size); return; } - maybeHandshakeSSL(sync, buffers, std::move(postHandshakeCb)); + maybeHandshakeSSL(buffers, 
std::move(postHandshakeCb)); }; - return opportunisticRead(sync, _socket, buffers, std::move(handshakeRecvCb)); + return opportunisticRead(_socket, buffers, std::move(handshakeRecvCb)); } #endif - return opportunisticRead(sync, _socket, buffers, std::forward<CompleteHandler>(handler)); + return opportunisticRead(_socket, buffers, std::forward<CompleteHandler>(handler)); } template <typename ConstBufferSequence, typename CompleteHandler> - void write(bool sync, const ConstBufferSequence& buffers, CompleteHandler&& handler) { + void write(const ConstBufferSequence& buffers, CompleteHandler&& handler) { #ifdef MONGO_CONFIG_SSL if (_sslSocket) { - return opportunisticWrite( - sync, *_sslSocket, buffers, std::forward<CompleteHandler>(handler)); + return opportunisticWrite(*_sslSocket, buffers, std::forward<CompleteHandler>(handler)); } #endif - return opportunisticWrite(sync, _socket, buffers, std::forward<CompleteHandler>(handler)); + return opportunisticWrite(_socket, buffers, std::forward<CompleteHandler>(handler)); } void ensureSync() { - if (_blockingMode == Sync) - return; asio::error_code ec; - getSocket().non_blocking(false, ec); - fassertStatusOK(40490, errorCodeToStatus(ec)); - _blockingMode = Sync; + if (_blockingMode != Sync) { + getSocket().non_blocking(false, ec); + fassertStatusOK(40490, errorCodeToStatus(ec)); + _blockingMode = Sync; + } + + if (_socketTimeout != _configuredTimeout) { + // Change boost::none (which means no timeout) into a zero value for the socket option, + // which also means no timeout. 
+ auto timeout = _configuredTimeout.value_or(Milliseconds{0}); + getSocket().set_option(ASIOSocketTimeoutOption<SO_SNDTIMEO>(timeout), ec); + uassertStatusOK(errorCodeToStatus(ec)); + + getSocket().set_option(ASIOSocketTimeoutOption<SO_RCVTIMEO>(timeout), ec); + uassertStatusOK(errorCodeToStatus(ec)); + + _socketTimeout = _configuredTimeout; + } } void ensureAsync() { if (_blockingMode == Async) return; + + // Socket timeouts currently only effect synchronous calls, so make sure the caller isn't + // expecting a socket timeout when they do an async operation. + invariant(!_configuredTimeout); + asio::error_code ec; getSocket().non_blocking(true, ec); fassertStatusOK(50706, errorCodeToStatus(ec)); @@ -298,13 +356,13 @@ private: } template <typename Stream, typename MutableBufferSequence, typename CompleteHandler> - void opportunisticRead(bool sync, - Stream& stream, + void opportunisticRead(Stream& stream, const MutableBufferSequence& buffers, CompleteHandler&& handler) { std::error_code ec; auto size = asio::read(stream, buffers, ec); - if ((ec == asio::error::would_block || ec == asio::error::try_again) && !sync) { + if (((ec == asio::error::would_block) || (ec == asio::error::try_again)) && + (_blockingMode == Async)) { // asio::read is a loop internally, so some of buffers may have been read into already. // So we need to adjust the buffers passed into async_read to be offset by size, if // size is > 0. 
@@ -325,13 +383,13 @@ private: } template <typename Stream, typename ConstBufferSequence, typename CompleteHandler> - void opportunisticWrite(bool sync, - Stream& stream, + void opportunisticWrite(Stream& stream, const ConstBufferSequence& buffers, CompleteHandler&& handler) { std::error_code ec; auto size = asio::write(stream, buffers, ec); - if ((ec == asio::error::would_block || ec == asio::error::try_again) && !sync) { + if (((ec == asio::error::would_block) || (ec == asio::error::try_again)) && + (_blockingMode == Async)) { // asio::write is a loop internally, so some of buffers may have been read into already. // So we need to adjust the buffers passed into async_write to be offset by size, if // size is > 0. @@ -353,7 +411,7 @@ private: #ifdef MONGO_CONFIG_SSL template <typename MutableBufferSequence, typename HandshakeCb> - void maybeHandshakeSSL(bool sync, const MutableBufferSequence& buffer, HandshakeCb onComplete) { + void maybeHandshakeSSL(const MutableBufferSequence& buffer, HandshakeCb onComplete) { invariant(asio::buffer_size(buffer) >= sizeof(MSGHEADER::Value)); MSGHEADER::ConstView headerView(asio::buffer_cast<char*>(buffer)); auto responseTo = headerView.getResponseToMsgId(); @@ -408,7 +466,7 @@ private: onComplete(ec ? 
errorCodeToStatus(ec) : Status::OK(), true); }; - if (sync) { + if (_blockingMode == Sync) { std::error_code ec; _sslSocket->handshake(asio::ssl::stream_base::server, buffer, ec); handshakeCompleteCb(ec, asio::buffer_size(buffer)); @@ -441,6 +499,9 @@ private: HostAndPort _remote; HostAndPort _local; + boost::optional<Milliseconds> _configuredTimeout; + boost::optional<Milliseconds> _socketTimeout; + GenericSocket _socket; #ifdef MONGO_CONFIG_SSL boost::optional<asio::ssl::stream<decltype(_socket)>> _sslSocket; diff --git a/src/mongo/transport/transport_layer_asio.cpp b/src/mongo/transport/transport_layer_asio.cpp index df0319a9b14..8e437ef949e 100644 --- a/src/mongo/transport/transport_layer_asio.cpp +++ b/src/mongo/transport/transport_layer_asio.cpp @@ -40,7 +40,6 @@ #include "mongo/util/net/ssl.hpp" #endif -#include "mongo/base/checked_cast.h" #include "mongo/base/system_error.h" #include "mongo/db/server_options.h" #include "mongo/db/service_context.h" diff --git a/src/mongo/transport/transport_layer_asio_test.cpp b/src/mongo/transport/transport_layer_asio_test.cpp index e82e0c1353a..594dfaaf3d9 100644 --- a/src/mongo/transport/transport_layer_asio_test.cpp +++ b/src/mongo/transport/transport_layer_asio_test.cpp @@ -36,8 +36,11 @@ #include "mongo/unittest/unittest.h" #include "mongo/util/assert_util.h" #include "mongo/util/log.h" +#include "mongo/util/net/op_msg.h" #include "mongo/util/net/sock.h" +#include "asio.hpp" + namespace mongo { namespace { @@ -133,6 +136,9 @@ TEST(TransportLayerASIO, PortZeroConnect) { ServerGlobalParams params; params.noUnixSocket = true; transport::TransportLayerASIO::Options opts(&params); + + // TODO SERVER-30212 should clean this up and assign a port from the supplied port range + // provided by resmoke. 
opts.port = 0; return opts; }(); @@ -153,5 +159,202 @@ TEST(TransportLayerASIO, PortZeroConnect) { tla.shutdown(); } +class TimeoutSEP : public ServiceEntryPoint { +public: + void endAllSessions(transport::Session::TagMask tags) override { + MONGO_UNREACHABLE; + } + + bool shutdown(Milliseconds timeout) override { + return true; + } + + Stats sessionStats() const override { + return {}; + } + + size_t numOpenSessions() const override { + return 0; + } + + DbResponse handleRequest(OperationContext* opCtx, const Message& request) override { + MONGO_UNREACHABLE; + } + + bool waitForTimeout(boost::optional<Milliseconds> timeout = boost::none) { + stdx::unique_lock<stdx::mutex> lk(_mutex); + bool ret = true; + if (timeout) { + ret = _cond.wait_for(lk, timeout->toSystemDuration(), [this] { return _finished; }); + } else { + _cond.wait(lk, [this] { return _finished; }); + } + + _finished = false; + return ret; + } + +protected: + void notifyComplete() { + stdx::unique_lock<stdx::mutex> lk(_mutex); + _finished = true; + _cond.notify_one(); + } + +private: + stdx::mutex _mutex; + stdx::condition_variable _cond; + bool _finished = false; +}; + +class TimeoutSyncSEP : public TimeoutSEP { +public: + enum Mode { kShouldTimeout, kNoTimeout }; + TimeoutSyncSEP(Mode mode) : _mode(mode) {} + + void startSession(transport::SessionHandle session) override { + log() << "Accepted connection from " << session->remote(); + stdx::thread([ this, session = std::move(session) ] { + log() << "waiting for message"; + session->setTimeout(Milliseconds{500}); + auto status = session->sourceMessage().getStatus(); + if (_mode == kShouldTimeout) { + ASSERT_EQ(status, ErrorCodes::NetworkTimeout); + log() << "message timed out"; + } else { + ASSERT_OK(status); + log() << "message received okay"; + } + + notifyComplete(); + }).detach(); + } + +private: + Mode _mode; +}; + +class TimeoutConnector { +public: + TimeoutConnector(int port, bool sendRequest) + : _ctx(), _sock(_ctx), 
_endpoint(asio::ip::address_v4::loopback(), port) { + std::error_code ec; + _sock.connect(_endpoint, ec); + ASSERT_EQ(ec, std::error_code()); + + if (sendRequest) { + sendMessage(); + } + } + + void sendMessage() { + OpMsgBuilder builder; + builder.setBody(BSON("ping" << 1)); + Message msg = builder.finish(); + msg.header().setResponseToMsgId(0); + msg.header().setId(0); + + std::error_code ec; + asio::write(_sock, asio::buffer(msg.buf(), msg.size()), ec); + ASSERT_FALSE(ec); + } + +private: + asio::io_context _ctx; + asio::ip::tcp::socket _sock; + asio::ip::tcp::endpoint _endpoint; +}; + +std::unique_ptr<transport::TransportLayerASIO> makeAndStartTL(ServiceEntryPoint* sep) { + auto options = [] { + ServerGlobalParams params; + params.noUnixSocket = true; + transport::TransportLayerASIO::Options opts(&params); + opts.port = 0; + return opts; + }(); + + auto tla = std::make_unique<transport::TransportLayerASIO>(options, sep); + ASSERT_OK(tla->setup()); + ASSERT_OK(tla->start()); + + return tla; +} + +/* check that timeouts actually time out */ +TEST(TransportLayerASIO, SourceSyncTimeoutTimesOut) { + TimeoutSyncSEP sep(TimeoutSyncSEP::kShouldTimeout); + auto tla = makeAndStartTL(&sep); + + TimeoutConnector connector(tla->listenerPort(), false); + + sep.waitForTimeout(); + tla->shutdown(); +} + +/* check that timeouts don't time out unless there's an actual timeout */ +TEST(TransportLayerASIO, SourceSyncTimeoutSucceeds) { + TimeoutSyncSEP sep(TimeoutSyncSEP::kNoTimeout); + auto tla = makeAndStartTL(&sep); + + TimeoutConnector connector(tla->listenerPort(), true); + + sep.waitForTimeout(); + tla->shutdown(); +} + +/* check that switching from timeouts to no timeouts correctly resets the timeout to unlimited */ +class TimeoutSwitchModesSEP : public TimeoutSEP { +public: + void startSession(transport::SessionHandle session) override { + log() << "Accepted connection from " << session->remote(); + stdx::thread worker([ this, session = std::move(session) ] { + log() << 
"waiting for message"; + auto sourceMessage = [&] { return session->sourceMessage().getStatus(); }; + + // the first message we source should time out. + session->setTimeout(Milliseconds{500}); + ASSERT_EQ(sourceMessage(), ErrorCodes::NetworkTimeout); + notifyComplete(); + + log() << "timed out successfully"; + + // get the session back in a known state with the timeout still in place + ASSERT_OK(sourceMessage()); + notifyComplete(); + + log() << "waiting for message without a timeout"; + + // this should block and timeout the waitForComplete mutex, and the session should wait + // for a while to make sure this isn't timing out and then send a message to unblock + // the this call to recv + session->setTimeout(boost::none); + ASSERT_OK(sourceMessage()); + + notifyComplete(); + log() << "ending test"; + }); + worker.detach(); + } +}; + +TEST(TransportLayerASIO, SwitchTimeoutModes) { + TimeoutSwitchModesSEP sep; + auto tla = makeAndStartTL(&sep); + + TimeoutConnector connector(tla->listenerPort(), false); + + ASSERT_TRUE(sep.waitForTimeout()); + + connector.sendMessage(); + ASSERT_TRUE(sep.waitForTimeout()); + + ASSERT_FALSE(sep.waitForTimeout(Milliseconds{1000})); + connector.sendMessage(); + ASSERT_TRUE(sep.waitForTimeout()); + + tla->shutdown(); +} + } // namespace } // namespace mongo |