summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/mongo/transport/SConscript27
-rw-r--r--src/mongo/transport/asio_utils.h8
-rw-r--r--src/mongo/transport/mock_session.h2
-rw-r--r--src/mongo/transport/session.h10
-rw-r--r--src/mongo/transport/session_asio.h217
-rw-r--r--src/mongo/transport/transport_layer_asio.cpp1
-rw-r--r--src/mongo/transport/transport_layer_asio_test.cpp203
-rw-r--r--src/third_party/asio-master/asio/include/asio/detail/impl/socket_ops.ipp57
-rw-r--r--src/third_party/asio-master/patches/0003-MONGO-HACK-allow-blocking-sockets-to-timeout.patch78
9 files changed, 468 insertions, 135 deletions
diff --git a/src/mongo/transport/SConscript b/src/mongo/transport/SConscript
index 92120a8ce98..8527328b08c 100644
--- a/src/mongo/transport/SConscript
+++ b/src/mongo/transport/SConscript
@@ -67,6 +67,21 @@ tlEnv.Library(
],
)
+tlEnv.CppUnitTest(
+ target='transport_layer_asio_test',
+ source=[
+ 'transport_layer_asio_test.cpp',
+ ],
+ LIBDEPS=[
+ 'transport_layer',
+ '$BUILD_DIR/mongo/base',
+ '$BUILD_DIR/mongo/db/service_context_noop_init',
+ ],
+ LIBDEPS_PRIVATE=[
+ '$BUILD_DIR/third_party/shim_asio',
+ ],
+)
+
tlEnv.Library(
target='service_executor',
source=[
@@ -183,15 +198,3 @@ env.CppUnitTest(
]
)
-env.CppUnitTest(
- target='transport_layer_asio_test',
- source=[
- 'transport_layer_asio_test.cpp',
- ],
- LIBDEPS=[
- 'transport_layer',
- '$BUILD_DIR/mongo/base',
- '$BUILD_DIR/mongo/db/service_context_noop_init',
- ],
-)
-
diff --git a/src/mongo/transport/asio_utils.h b/src/mongo/transport/asio_utils.h
index 141ec4f78ea..89e7821f5b8 100644
--- a/src/mongo/transport/asio_utils.h
+++ b/src/mongo/transport/asio_utils.h
@@ -54,6 +54,14 @@ inline Status errorCodeToStatus(const std::error_code& ec) {
if (!ec)
return Status::OK();
+#ifdef _WIN32
+ if (ec == asio::error::timed_out) {
+#else
+ if (ec == asio::error::try_again || ec == asio::error::would_block) {
+#endif
+ return {ErrorCodes::NetworkTimeout, "Socket operation timed out"};
+ }
+
// If the ec.category() is a mongoErrorCategory() then this error was propogated from
// mongodb code and we should just pass the error cdoe along as-is.
ErrorCodes::Error errorCode = (ec.category() == mongoErrorCategory())
diff --git a/src/mongo/transport/mock_session.h b/src/mongo/transport/mock_session.h
index 2f091854193..411e57bd0ec 100644
--- a/src/mongo/transport/mock_session.h
+++ b/src/mongo/transport/mock_session.h
@@ -103,6 +103,8 @@ public:
cb(sinkMessage(message));
}
+ void setTimeout(boost::optional<Milliseconds>) override {}
+
explicit MockSession(TransportLayer* tl)
: _tl(checked_cast<TransportLayerMock*>(tl)), _remote(), _local() {}
explicit MockSession(HostAndPort remote, HostAndPort local, TransportLayer* tl)
diff --git a/src/mongo/transport/session.h b/src/mongo/transport/session.h
index d25d32b9027..db80aa0814d 100644
--- a/src/mongo/transport/session.h
+++ b/src/mongo/transport/session.h
@@ -112,6 +112,16 @@ public:
virtual Status sinkMessage(Message message) = 0;
virtual void asyncSinkMessage(Message message, std::function<void(Status)> cb) = 0;
+ /**
+ * This should only be used to detect when the remote host has disappeared without
+ * notice. It does NOT work correctly for ensuring that operations complete or fail
+ * by some deadline.
+ *
+ * This timeout will only effect calls sourceMessage()/sinkMessage(). Async operations do not
+ * currently support timeouts.
+ */
+ virtual void setTimeout(boost::optional<Milliseconds> timeout) = 0;
+
virtual const HostAndPort& remote() const = 0;
virtual const HostAndPort& local() const = 0;
diff --git a/src/mongo/transport/session_asio.h b/src/mongo/transport/session_asio.h
index 3d925a7cd62..f18dfac6fe1 100644
--- a/src/mongo/transport/session_asio.h
+++ b/src/mongo/transport/session_asio.h
@@ -104,7 +104,7 @@ public:
ensureSync();
auto out = StatusWith<Message>(ErrorCodes::InternalError, "uninitialized...");
bool called = false;
- sourceMessageImpl(true, [&](StatusWith<Message> in) {
+ sourceMessageImpl([&](StatusWith<Message> in) {
out = std::move(in);
called = true;
});
@@ -114,7 +114,7 @@ public:
void asyncSourceMessage(std::function<void(StatusWith<Message>)> cb) override {
ensureAsync();
- sourceMessageImpl(false, std::move(cb));
+ sourceMessageImpl(std::move(cb));
}
Status sinkMessage(Message message) override {
@@ -124,8 +124,7 @@ public:
size_t size;
bool called = false;
- write(true,
- asio::buffer(message.buf(), message.size()),
+ write(asio::buffer(message.buf(), message.size()),
[&](const std::error_code& ec_, size_t size_) {
ec = ec_;
size = size_;
@@ -144,7 +143,7 @@ public:
void asyncSinkMessage(Message message, std::function<void(Status)> cb) override {
ensureAsync();
- write(false, asio::buffer(message.buf(), message.size()), [
+ write(asio::buffer(message.buf(), message.size()), [
message, // keep the buffer alive.
cb = std::move(cb),
this
@@ -157,10 +156,56 @@ public:
networkCounter.hitPhysicalOut(message.size());
cb(Status::OK());
});
- };
+ }
+ void setTimeout(boost::optional<Milliseconds> timeout) override {
+ invariant(!timeout || timeout->count() > 0);
+ _configuredTimeout = timeout;
+ }
private:
+ template <int Name>
+ class ASIOSocketTimeoutOption {
+ public:
+#ifdef _WIN32
+ using TimeoutType = DWORD;
+
+ ASIOSocketTimeoutOption(Milliseconds timeoutVal) : _timeout(timeoutVal.count()) {}
+
+#else
+ using TimeoutType = timeval;
+
+ ASIOSocketTimeoutOption(Milliseconds timeoutVal) {
+ _timeout.tv_sec = duration_cast<Seconds>(timeoutVal).count();
+ const auto minusSeconds = timeoutVal - Seconds{_timeout.tv_sec};
+ _timeout.tv_usec = duration_cast<Microseconds>(minusSeconds).count();
+ }
+#endif
+
+ template <typename Protocol>
+ int name(const Protocol&) const {
+ return Name;
+ }
+
+ template <typename Protocol>
+ const TimeoutType* data(const Protocol&) const {
+ return &_timeout;
+ }
+
+ template <typename Protocol>
+ std::size_t size(const Protocol&) const {
+ return sizeof(_timeout);
+ }
+
+ template <typename Protocol>
+ int level(const Protocol&) const {
+ return SOL_SOCKET;
+ }
+
+ private:
+ TimeoutType _timeout;
+ };
+
GenericSocket& getSocket() {
#ifdef MONGO_CONFIG_SSL
if (_sslSocket) {
@@ -179,68 +224,65 @@ private:
}
template <typename Callback>
- void sourceMessageImpl(bool sync, Callback&& cb) {
+ void sourceMessageImpl(Callback&& cb) {
static constexpr auto kHeaderSize = sizeof(MSGHEADER::Value);
auto headerBuffer = SharedBuffer::allocate(kHeaderSize);
auto ptr = headerBuffer.get();
- read(
- sync,
- asio::buffer(ptr, kHeaderSize),
- [ sync, cb = std::forward<Callback>(cb), headerBuffer = std::move(headerBuffer), this ](
- const std::error_code& ec, size_t size) mutable {
-
- if (ec)
- return cb(errorCodeToStatus(ec));
- invariant(size == kHeaderSize);
-
- const auto msgLen = size_t(MSGHEADER::View(headerBuffer.get()).getMessageLength());
- if (msgLen < kHeaderSize || msgLen > MaxMessageSizeBytes) {
- StringBuilder sb;
- sb << "recv(): message msgLen " << msgLen << " is invalid. "
- << "Min " << kHeaderSize << " Max: " << MaxMessageSizeBytes;
- const auto str = sb.str();
- LOG(0) << str;
-
- return cb(Status(ErrorCodes::ProtocolError, str));
- }
-
- if (msgLen == size) {
- // This probably isn't a real case since all (current) messages have bodies.
- networkCounter.hitPhysicalIn(msgLen);
- return cb(Message(std::move(headerBuffer)));
- }
-
- auto buffer = SharedBuffer::allocate(msgLen);
- memcpy(buffer.get(), headerBuffer.get(), kHeaderSize);
-
- MsgData::View msgView(buffer.get());
- read(sync,
- asio::buffer(msgView.data(), msgView.dataLen()),
- [ cb = std::move(cb), buffer = std::move(buffer), msgLen, this ](
- const std::error_code& ec, size_t size) mutable {
- if (ec)
- return cb(errorCodeToStatus(ec));
- networkCounter.hitPhysicalIn(msgLen);
- return cb(Message(std::move(buffer)));
- });
- });
+ read(asio::buffer(ptr, kHeaderSize),
+ [ cb = std::forward<Callback>(cb), headerBuffer = std::move(headerBuffer), this ](
+ const std::error_code& ec, size_t size) mutable {
+ if (ec) {
+ return cb(errorCodeToStatus(ec));
+ }
+
+ invariant(size == kHeaderSize);
+
+ const auto msgLen = size_t(MSGHEADER::View(headerBuffer.get()).getMessageLength());
+ if (msgLen < kHeaderSize || msgLen > MaxMessageSizeBytes) {
+ StringBuilder sb;
+ sb << "recv(): message msgLen " << msgLen << " is invalid. "
+ << "Min " << kHeaderSize << " Max: " << MaxMessageSizeBytes;
+ const auto str = sb.str();
+ LOG(0) << str;
+
+ return cb(Status(ErrorCodes::ProtocolError, str));
+ }
+
+ if (msgLen == size) {
+ // This probably isn't a real case since all (current) messages have bodies.
+ networkCounter.hitPhysicalIn(msgLen);
+ return cb(Message(std::move(headerBuffer)));
+ }
+
+ auto buffer = SharedBuffer::allocate(msgLen);
+ memcpy(buffer.get(), headerBuffer.get(), kHeaderSize);
+
+ MsgData::View msgView(buffer.get());
+ read(asio::buffer(msgView.data(), msgView.dataLen()),
+ [ cb = std::move(cb), buffer = std::move(buffer), msgLen, this ](
+ const std::error_code& ec, size_t size) mutable {
+ if (ec) {
+ return cb(errorCodeToStatus(ec));
+ }
+
+ networkCounter.hitPhysicalIn(msgLen);
+ return cb(Message(std::move(buffer)));
+ });
+ });
}
-
template <typename MutableBufferSequence, typename CompleteHandler>
- void read(bool sync, const MutableBufferSequence& buffers, CompleteHandler&& handler) {
+ void read(const MutableBufferSequence& buffers, CompleteHandler&& handler) {
#ifdef MONGO_CONFIG_SSL
if (_sslSocket) {
- return opportunisticRead(
- sync, *_sslSocket, buffers, std::forward<CompleteHandler>(handler));
+ return opportunisticRead(*_sslSocket, buffers, std::forward<CompleteHandler>(handler));
} else if (!_ranHandshake) {
invariant(asio::buffer_size(buffers) >= sizeof(MSGHEADER::Value));
- auto postHandshakeCb = [this, sync, buffers, handler](Status status,
- bool needsRead) mutable {
+ auto postHandshakeCb = [this, buffers, handler](Status status, bool needsRead) mutable {
if (status.isOK()) {
if (needsRead) {
- read(sync, buffers, handler);
+ read(buffers, handler);
} else {
std::error_code ec;
handler(ec, asio::buffer_size(buffers));
@@ -250,47 +292,63 @@ private:
}
};
- auto handshakeRecvCb =
- [ this, postHandshakeCb = std::move(postHandshakeCb), sync, buffers ](
- const std::error_code& ec, size_t size) mutable {
+ auto handshakeRecvCb = [ this, postHandshakeCb = std::move(postHandshakeCb), buffers ](
+ const std::error_code& ec, size_t size) mutable {
_ranHandshake = true;
if (ec) {
postHandshakeCb(errorCodeToStatus(ec), size);
return;
}
- maybeHandshakeSSL(sync, buffers, std::move(postHandshakeCb));
+ maybeHandshakeSSL(buffers, std::move(postHandshakeCb));
};
- return opportunisticRead(sync, _socket, buffers, std::move(handshakeRecvCb));
+ return opportunisticRead(_socket, buffers, std::move(handshakeRecvCb));
}
#endif
- return opportunisticRead(sync, _socket, buffers, std::forward<CompleteHandler>(handler));
+ return opportunisticRead(_socket, buffers, std::forward<CompleteHandler>(handler));
}
template <typename ConstBufferSequence, typename CompleteHandler>
- void write(bool sync, const ConstBufferSequence& buffers, CompleteHandler&& handler) {
+ void write(const ConstBufferSequence& buffers, CompleteHandler&& handler) {
#ifdef MONGO_CONFIG_SSL
if (_sslSocket) {
- return opportunisticWrite(
- sync, *_sslSocket, buffers, std::forward<CompleteHandler>(handler));
+ return opportunisticWrite(*_sslSocket, buffers, std::forward<CompleteHandler>(handler));
}
#endif
- return opportunisticWrite(sync, _socket, buffers, std::forward<CompleteHandler>(handler));
+ return opportunisticWrite(_socket, buffers, std::forward<CompleteHandler>(handler));
}
void ensureSync() {
- if (_blockingMode == Sync)
- return;
asio::error_code ec;
- getSocket().non_blocking(false, ec);
- fassertStatusOK(40490, errorCodeToStatus(ec));
- _blockingMode = Sync;
+ if (_blockingMode != Sync) {
+ getSocket().non_blocking(false, ec);
+ fassertStatusOK(40490, errorCodeToStatus(ec));
+ _blockingMode = Sync;
+ }
+
+ if (_socketTimeout != _configuredTimeout) {
+ // Change boost::none (which means no timeout) into a zero value for the socket option,
+ // which also means no timeout.
+ auto timeout = _configuredTimeout.value_or(Milliseconds{0});
+ getSocket().set_option(ASIOSocketTimeoutOption<SO_SNDTIMEO>(timeout), ec);
+ uassertStatusOK(errorCodeToStatus(ec));
+
+ getSocket().set_option(ASIOSocketTimeoutOption<SO_RCVTIMEO>(timeout), ec);
+ uassertStatusOK(errorCodeToStatus(ec));
+
+ _socketTimeout = _configuredTimeout;
+ }
}
void ensureAsync() {
if (_blockingMode == Async)
return;
+
+ // Socket timeouts currently only effect synchronous calls, so make sure the caller isn't
+ // expecting a socket timeout when they do an async operation.
+ invariant(!_configuredTimeout);
+
asio::error_code ec;
getSocket().non_blocking(true, ec);
fassertStatusOK(50706, errorCodeToStatus(ec));
@@ -298,13 +356,13 @@ private:
}
template <typename Stream, typename MutableBufferSequence, typename CompleteHandler>
- void opportunisticRead(bool sync,
- Stream& stream,
+ void opportunisticRead(Stream& stream,
const MutableBufferSequence& buffers,
CompleteHandler&& handler) {
std::error_code ec;
auto size = asio::read(stream, buffers, ec);
- if ((ec == asio::error::would_block || ec == asio::error::try_again) && !sync) {
+ if (((ec == asio::error::would_block) || (ec == asio::error::try_again)) &&
+ (_blockingMode == Async)) {
// asio::read is a loop internally, so some of buffers may have been read into already.
// So we need to adjust the buffers passed into async_read to be offset by size, if
// size is > 0.
@@ -325,13 +383,13 @@ private:
}
template <typename Stream, typename ConstBufferSequence, typename CompleteHandler>
- void opportunisticWrite(bool sync,
- Stream& stream,
+ void opportunisticWrite(Stream& stream,
const ConstBufferSequence& buffers,
CompleteHandler&& handler) {
std::error_code ec;
auto size = asio::write(stream, buffers, ec);
- if ((ec == asio::error::would_block || ec == asio::error::try_again) && !sync) {
+ if (((ec == asio::error::would_block) || (ec == asio::error::try_again)) &&
+ (_blockingMode == Async)) {
// asio::write is a loop internally, so some of buffers may have been read into already.
// So we need to adjust the buffers passed into async_write to be offset by size, if
// size is > 0.
@@ -353,7 +411,7 @@ private:
#ifdef MONGO_CONFIG_SSL
template <typename MutableBufferSequence, typename HandshakeCb>
- void maybeHandshakeSSL(bool sync, const MutableBufferSequence& buffer, HandshakeCb onComplete) {
+ void maybeHandshakeSSL(const MutableBufferSequence& buffer, HandshakeCb onComplete) {
invariant(asio::buffer_size(buffer) >= sizeof(MSGHEADER::Value));
MSGHEADER::ConstView headerView(asio::buffer_cast<char*>(buffer));
auto responseTo = headerView.getResponseToMsgId();
@@ -408,7 +466,7 @@ private:
onComplete(ec ? errorCodeToStatus(ec) : Status::OK(), true);
};
- if (sync) {
+ if (_blockingMode == Sync) {
std::error_code ec;
_sslSocket->handshake(asio::ssl::stream_base::server, buffer, ec);
handshakeCompleteCb(ec, asio::buffer_size(buffer));
@@ -441,6 +499,9 @@ private:
HostAndPort _remote;
HostAndPort _local;
+ boost::optional<Milliseconds> _configuredTimeout;
+ boost::optional<Milliseconds> _socketTimeout;
+
GenericSocket _socket;
#ifdef MONGO_CONFIG_SSL
boost::optional<asio::ssl::stream<decltype(_socket)>> _sslSocket;
diff --git a/src/mongo/transport/transport_layer_asio.cpp b/src/mongo/transport/transport_layer_asio.cpp
index df0319a9b14..8e437ef949e 100644
--- a/src/mongo/transport/transport_layer_asio.cpp
+++ b/src/mongo/transport/transport_layer_asio.cpp
@@ -40,7 +40,6 @@
#include "mongo/util/net/ssl.hpp"
#endif
-#include "mongo/base/checked_cast.h"
#include "mongo/base/system_error.h"
#include "mongo/db/server_options.h"
#include "mongo/db/service_context.h"
diff --git a/src/mongo/transport/transport_layer_asio_test.cpp b/src/mongo/transport/transport_layer_asio_test.cpp
index e82e0c1353a..594dfaaf3d9 100644
--- a/src/mongo/transport/transport_layer_asio_test.cpp
+++ b/src/mongo/transport/transport_layer_asio_test.cpp
@@ -36,8 +36,11 @@
#include "mongo/unittest/unittest.h"
#include "mongo/util/assert_util.h"
#include "mongo/util/log.h"
+#include "mongo/util/net/op_msg.h"
#include "mongo/util/net/sock.h"
+#include "asio.hpp"
+
namespace mongo {
namespace {
@@ -133,6 +136,9 @@ TEST(TransportLayerASIO, PortZeroConnect) {
ServerGlobalParams params;
params.noUnixSocket = true;
transport::TransportLayerASIO::Options opts(&params);
+
+ // TODO SERVER-30212 should clean this up and assign a port from the supplied port range
+ // provided by resmoke.
opts.port = 0;
return opts;
}();
@@ -153,5 +159,202 @@ TEST(TransportLayerASIO, PortZeroConnect) {
tla.shutdown();
}
+class TimeoutSEP : public ServiceEntryPoint {
+public:
+ void endAllSessions(transport::Session::TagMask tags) override {
+ MONGO_UNREACHABLE;
+ }
+
+ bool shutdown(Milliseconds timeout) override {
+ return true;
+ }
+
+ Stats sessionStats() const override {
+ return {};
+ }
+
+ size_t numOpenSessions() const override {
+ return 0;
+ }
+
+ DbResponse handleRequest(OperationContext* opCtx, const Message& request) override {
+ MONGO_UNREACHABLE;
+ }
+
+ bool waitForTimeout(boost::optional<Milliseconds> timeout = boost::none) {
+ stdx::unique_lock<stdx::mutex> lk(_mutex);
+ bool ret = true;
+ if (timeout) {
+ ret = _cond.wait_for(lk, timeout->toSystemDuration(), [this] { return _finished; });
+ } else {
+ _cond.wait(lk, [this] { return _finished; });
+ }
+
+ _finished = false;
+ return ret;
+ }
+
+protected:
+ void notifyComplete() {
+ stdx::unique_lock<stdx::mutex> lk(_mutex);
+ _finished = true;
+ _cond.notify_one();
+ }
+
+private:
+ stdx::mutex _mutex;
+ stdx::condition_variable _cond;
+ bool _finished = false;
+};
+
+class TimeoutSyncSEP : public TimeoutSEP {
+public:
+ enum Mode { kShouldTimeout, kNoTimeout };
+ TimeoutSyncSEP(Mode mode) : _mode(mode) {}
+
+ void startSession(transport::SessionHandle session) override {
+ log() << "Accepted connection from " << session->remote();
+ stdx::thread([ this, session = std::move(session) ] {
+ log() << "waiting for message";
+ session->setTimeout(Milliseconds{500});
+ auto status = session->sourceMessage().getStatus();
+ if (_mode == kShouldTimeout) {
+ ASSERT_EQ(status, ErrorCodes::NetworkTimeout);
+ log() << "message timed out";
+ } else {
+ ASSERT_OK(status);
+ log() << "message received okay";
+ }
+
+ notifyComplete();
+ }).detach();
+ }
+
+private:
+ Mode _mode;
+};
+
+class TimeoutConnector {
+public:
+ TimeoutConnector(int port, bool sendRequest)
+ : _ctx(), _sock(_ctx), _endpoint(asio::ip::address_v4::loopback(), port) {
+ std::error_code ec;
+ _sock.connect(_endpoint, ec);
+ ASSERT_EQ(ec, std::error_code());
+
+ if (sendRequest) {
+ sendMessage();
+ }
+ }
+
+ void sendMessage() {
+ OpMsgBuilder builder;
+ builder.setBody(BSON("ping" << 1));
+ Message msg = builder.finish();
+ msg.header().setResponseToMsgId(0);
+ msg.header().setId(0);
+
+ std::error_code ec;
+ asio::write(_sock, asio::buffer(msg.buf(), msg.size()), ec);
+ ASSERT_FALSE(ec);
+ }
+
+private:
+ asio::io_context _ctx;
+ asio::ip::tcp::socket _sock;
+ asio::ip::tcp::endpoint _endpoint;
+};
+
+std::unique_ptr<transport::TransportLayerASIO> makeAndStartTL(ServiceEntryPoint* sep) {
+ auto options = [] {
+ ServerGlobalParams params;
+ params.noUnixSocket = true;
+ transport::TransportLayerASIO::Options opts(&params);
+ opts.port = 0;
+ return opts;
+ }();
+
+ auto tla = std::make_unique<transport::TransportLayerASIO>(options, sep);
+ ASSERT_OK(tla->setup());
+ ASSERT_OK(tla->start());
+
+ return tla;
+}
+
+/* check that timeouts actually time out */
+TEST(TransportLayerASIO, SourceSyncTimeoutTimesOut) {
+ TimeoutSyncSEP sep(TimeoutSyncSEP::kShouldTimeout);
+ auto tla = makeAndStartTL(&sep);
+
+ TimeoutConnector connector(tla->listenerPort(), false);
+
+ sep.waitForTimeout();
+ tla->shutdown();
+}
+
+/* check that timeouts don't time out unless there's an actual timeout */
+TEST(TransportLayerASIO, SourceSyncTimeoutSucceeds) {
+ TimeoutSyncSEP sep(TimeoutSyncSEP::kNoTimeout);
+ auto tla = makeAndStartTL(&sep);
+
+ TimeoutConnector connector(tla->listenerPort(), true);
+
+ sep.waitForTimeout();
+ tla->shutdown();
+}
+
+/* check that switching from timeouts to no timeouts correctly resets the timeout to unlimited */
+class TimeoutSwitchModesSEP : public TimeoutSEP {
+public:
+ void startSession(transport::SessionHandle session) override {
+ log() << "Accepted connection from " << session->remote();
+ stdx::thread worker([ this, session = std::move(session) ] {
+ log() << "waiting for message";
+ auto sourceMessage = [&] { return session->sourceMessage().getStatus(); };
+
+ // the first message we source should time out.
+ session->setTimeout(Milliseconds{500});
+ ASSERT_EQ(sourceMessage(), ErrorCodes::NetworkTimeout);
+ notifyComplete();
+
+ log() << "timed out successfully";
+
+ // get the session back in a known state with the timeout still in place
+ ASSERT_OK(sourceMessage());
+ notifyComplete();
+
+ log() << "waiting for message without a timeout";
+
+ // this should block and timeout the waitForComplete mutex, and the session should wait
+ // for a while to make sure this isn't timing out and then send a message to unblock
+ // the this call to recv
+ session->setTimeout(boost::none);
+ ASSERT_OK(sourceMessage());
+
+ notifyComplete();
+ log() << "ending test";
+ });
+ worker.detach();
+ }
+};
+
+TEST(TransportLayerASIO, SwitchTimeoutModes) {
+ TimeoutSwitchModesSEP sep;
+ auto tla = makeAndStartTL(&sep);
+
+ TimeoutConnector connector(tla->listenerPort(), false);
+
+ ASSERT_TRUE(sep.waitForTimeout());
+
+ connector.sendMessage();
+ ASSERT_TRUE(sep.waitForTimeout());
+
+ ASSERT_FALSE(sep.waitForTimeout(Milliseconds{1000}));
+ connector.sendMessage();
+ ASSERT_TRUE(sep.waitForTimeout());
+
+ tla->shutdown();
+}
+
} // namespace
} // namespace mongo
diff --git a/src/third_party/asio-master/asio/include/asio/detail/impl/socket_ops.ipp b/src/third_party/asio-master/asio/include/asio/detail/impl/socket_ops.ipp
index b3b1a0cf811..2f89889fac8 100644
--- a/src/third_party/asio-master/asio/include/asio/detail/impl/socket_ops.ipp
+++ b/src/third_party/asio-master/asio/include/asio/detail/impl/socket_ops.ipp
@@ -803,33 +803,19 @@ size_t sync_recv(socket_type s, state_type state, buf* bufs,
return 0;
}
- // Read some data.
- for (;;)
- {
- // Try to complete the operation without blocking.
- signed_size_type bytes = socket_ops::recv(s, bufs, count, flags, ec);
-
- // Check if operation succeeded.
- if (bytes > 0)
- return bytes;
+ signed_size_type bytes = socket_ops::recv(s, bufs, count, flags, ec);
- // Check for EOF.
- if ((state & stream_oriented) && bytes == 0)
- {
- ec = asio::error::eof;
- return 0;
- }
+ // Check if operation succeeded.
+ if (bytes > 0)
+ return bytes;
- // Operation failed.
- if ((state & user_set_non_blocking)
- || (ec != asio::error::would_block
- && ec != asio::error::try_again))
- return 0;
-
- // Wait for socket to become ready.
- if (socket_ops::poll_read(s, 0, -1, ec) < 0)
- return 0;
+ // Check for EOF.
+ if ((state & stream_oriented) && bytes == 0)
+ {
+ ec = asio::error::eof;
}
+
+ return 0;
}
#if defined(ASIO_HAS_IOCP)
@@ -1203,26 +1189,9 @@ size_t sync_send(socket_type s, state_type state, const buf* bufs,
return 0;
}
- // Read some data.
- for (;;)
- {
- // Try to complete the operation without blocking.
- signed_size_type bytes = socket_ops::send(s, bufs, count, flags, ec);
-
- // Check if operation succeeded.
- if (bytes >= 0)
- return bytes;
-
- // Operation failed.
- if ((state & user_set_non_blocking)
- || (ec != asio::error::would_block
- && ec != asio::error::try_again))
- return 0;
-
- // Wait for socket to become ready.
- if (socket_ops::poll_write(s, 0, -1, ec) < 0)
- return 0;
- }
+ // Write some data
+ signed_size_type bytes = socket_ops::send(s, bufs, count, flags, ec);
+ return bytes;
}
#if defined(ASIO_HAS_IOCP)
diff --git a/src/third_party/asio-master/patches/0003-MONGO-HACK-allow-blocking-sockets-to-timeout.patch b/src/third_party/asio-master/patches/0003-MONGO-HACK-allow-blocking-sockets-to-timeout.patch
new file mode 100644
index 00000000000..ba9186dd010
--- /dev/null
+++ b/src/third_party/asio-master/patches/0003-MONGO-HACK-allow-blocking-sockets-to-timeout.patch
@@ -0,0 +1,78 @@
+diff --git a/src/third_party/asio-master/asio/include/asio/detail/impl/socket_ops.ipp b/src/third_party/asio-master/asio/include/asio/detail/impl/socket_ops.ipp
+index b3b1a0cf81..2f89889fac 100644
+--- a/src/third_party/asio-master/asio/include/asio/detail/impl/socket_ops.ipp
++++ b/src/third_party/asio-master/asio/include/asio/detail/impl/socket_ops.ipp
+@@ -803,33 +803,19 @@ size_t sync_recv(socket_type s, state_type state, buf* bufs,
+ return 0;
+ }
+
+- // Read some data.
+- for (;;)
+- {
+- // Try to complete the operation without blocking.
+- signed_size_type bytes = socket_ops::recv(s, bufs, count, flags, ec);
+-
+- // Check if operation succeeded.
+- if (bytes > 0)
+- return bytes;
++ signed_size_type bytes = socket_ops::recv(s, bufs, count, flags, ec);
+
+- // Check for EOF.
+- if ((state & stream_oriented) && bytes == 0)
+- {
+- ec = asio::error::eof;
+- return 0;
+- }
++ // Check if operation succeeded.
++ if (bytes > 0)
++ return bytes;
+
+- // Operation failed.
+- if ((state & user_set_non_blocking)
+- || (ec != asio::error::would_block
+- && ec != asio::error::try_again))
+- return 0;
+-
+- // Wait for socket to become ready.
+- if (socket_ops::poll_read(s, 0, -1, ec) < 0)
+- return 0;
++ // Check for EOF.
++ if ((state & stream_oriented) && bytes == 0)
++ {
++ ec = asio::error::eof;
+ }
++
++ return 0;
+ }
+
+ #if defined(ASIO_HAS_IOCP)
+@@ -1203,26 +1189,9 @@ size_t sync_send(socket_type s, state_type state, const buf* bufs,
+ return 0;
+ }
+
+- // Read some data.
+- for (;;)
+- {
+- // Try to complete the operation without blocking.
+- signed_size_type bytes = socket_ops::send(s, bufs, count, flags, ec);
+-
+- // Check if operation succeeded.
+- if (bytes >= 0)
+- return bytes;
+-
+- // Operation failed.
+- if ((state & user_set_non_blocking)
+- || (ec != asio::error::would_block
+- && ec != asio::error::try_again))
+- return 0;
+-
+- // Wait for socket to become ready.
+- if (socket_ops::poll_write(s, 0, -1, ec) < 0)
+- return 0;
+- }
++ // Write some data
++ signed_size_type bytes = socket_ops::send(s, bufs, count, flags, ec);
++ return bytes;
+ }
+
+ #if defined(ASIO_HAS_IOCP)