summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBilly Donahue <billy.donahue@mongodb.com>2021-07-27 01:50:52 -0400
committerEvergreen Agent <no-reply@evergreen.mongodb.com>2021-08-06 04:15:30 +0000
commita57e4d409a81be929d4830199797c675322ae164 (patch)
treea7d41987d1264fab0fdc28914e8148f2912a5b1f
parent84a1a0599614a8c077da4b3899aba5a647202746 (diff)
downloadmongo-a57e4d409a81be929d4830199797c675322ae164.tar.gz
SERVER-58204 physical layout: session_asio.cpp and asio_utils.cpp
- bring GenericSocket into ASIOSession class - refactor/minimize AsyncHandlerHelper traits - no logging in headers - fix sign extension bugs and conform to Winsock select requirements
-rw-r--r--src/mongo/transport/SConscript2
-rw-r--r--src/mongo/transport/asio_utils.cpp321
-rw-r--r--src/mongo/transport/asio_utils.h437
-rw-r--r--src/mongo/transport/session_asio.cpp720
-rw-r--r--src/mongo/transport/session_asio.h677
-rw-r--r--src/mongo/transport/transport_layer_asio.cpp9
6 files changed, 1171 insertions, 995 deletions
diff --git a/src/mongo/transport/SConscript b/src/mongo/transport/SConscript
index b0dc247c673..d7898ee475e 100644
--- a/src/mongo/transport/SConscript
+++ b/src/mongo/transport/SConscript
@@ -53,6 +53,8 @@ tlEnv.Library(
target='transport_layer',
source=[
'transport_layer_asio.cpp',
+ 'asio_utils.cpp',
+ 'session_asio.cpp',
'transport_options.idl',
],
LIBDEPS=[
diff --git a/src/mongo/transport/asio_utils.cpp b/src/mongo/transport/asio_utils.cpp
new file mode 100644
index 00000000000..c954dc829eb
--- /dev/null
+++ b/src/mongo/transport/asio_utils.cpp
@@ -0,0 +1,321 @@
+/**
+ * Copyright (C) 2021-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kNetwork
+#include "mongo/transport/asio_utils.h"
+
+#include "mongo/config.h"
+#include "mongo/logv2/log.h"
+
+namespace mongo::transport {
+
+Status errorCodeToStatus(const std::error_code& ec) {
+ if (!ec)
+ return Status::OK();
+
+ if (ec == asio::error::operation_aborted) {
+ return {ErrorCodes::CallbackCanceled, "Callback was canceled"};
+ }
+
+#ifdef _WIN32
+ if (ec == asio::error::timed_out) {
+#else
+ if (ec == asio::error::try_again || ec == asio::error::would_block) {
+#endif
+ return {ErrorCodes::NetworkTimeout, "Socket operation timed out"};
+ } else if (ec == asio::error::eof) {
+ return {ErrorCodes::HostUnreachable, "Connection closed by peer"};
+ } else if (ec == asio::error::connection_reset) {
+ return {ErrorCodes::HostUnreachable, "Connection reset by peer"};
+ } else if (ec == asio::error::network_reset) {
+ return {ErrorCodes::HostUnreachable, "Connection reset by network"};
+ }
+
+ // If the ec.category() is a mongoErrorCategory() then this error was propogated from
+ // mongodb code and we should just pass the error cdoe along as-is.
+ ErrorCodes::Error errorCode = (ec.category() == mongoErrorCategory())
+ ? ErrorCodes::Error(ec.value())
+ // Otherwise it's an error code from the network and we should pass it along as a
+ // SocketException
+ : ErrorCodes::SocketException;
+ // Either way, include the error message.
+ return {errorCode, ec.message()};
+}
+
+template <typename T>
+auto toUnsignedEquivalent(T x) {
+ return static_cast<std::make_unsigned_t<T>>(x);
+}
+
+template <typename Dur>
+timeval toTimeval(Dur dur) {
+ auto sec = duration_cast<Seconds>(dur);
+ timeval tv{};
+ tv.tv_sec = sec.count();
+ tv.tv_usec = duration_cast<Microseconds>(dur - sec).count();
+ return tv;
+}
+
+StatusWith<unsigned> pollASIOSocket(asio::generic::stream_protocol::socket& socket,
+ unsigned mask,
+ Milliseconds timeout) {
+#ifdef _WIN32
+ // On Windows, use `select` to approximate `poll`.
+ // Windows `select` has a couple special rules:
+ // - any empty fd_set args *must* be passed as nullptr.
+ // - the fd_set args can't *all* be nullptr.
+ struct FlagFdSet {
+ unsigned pollFlag;
+ fd_set fds;
+ };
+ std::array sets{
+ FlagFdSet{toUnsignedEquivalent(POLLIN)},
+ FlagFdSet{toUnsignedEquivalent(POLLOUT)},
+ FlagFdSet{toUnsignedEquivalent(POLLERR)},
+ };
+ auto fd = socket.native_handle();
+ mask |= POLLERR; // Always interested in errors.
+ for (auto& [pollFlag, fds] : sets) {
+ FD_ZERO(&fds);
+ if (mask & pollFlag)
+ FD_SET(fd, &fds);
+ }
+
+ auto timeoutTv = toTimeval(timeout);
+ auto fdsPtr = [&](size_t i) {
+ fd_set* ptr = &sets[i].fds;
+ return FD_ISSET(fd, ptr) ? ptr : nullptr;
+ };
+ int result = ::select(fd + 1, fdsPtr(0), fdsPtr(1), fdsPtr(2), &timeoutTv);
+ if (result == SOCKET_ERROR) {
+ auto errDesc = errnoWithDescription(WSAGetLastError());
+ return {ErrorCodes::InternalError, errDesc};
+ } else if (result == 0) {
+ return {ErrorCodes::NetworkTimeout, "Timed out waiting for poll"};
+ }
+
+ unsigned revents = 0;
+ for (auto& [pollFlag, fds] : sets)
+ if (FD_ISSET(fd, &fds))
+ revents |= pollFlag;
+ return revents;
+#else
+ pollfd pollItem = {};
+ pollItem.fd = socket.native_handle();
+ pollItem.events = mask;
+
+ int result;
+ boost::optional<Date_t> expiration;
+ if (timeout.count() > 0) {
+ expiration = Date_t::now() + timeout;
+ }
+ do {
+ Milliseconds curTimeout;
+ if (expiration) {
+ curTimeout = *expiration - Date_t::now();
+ if (curTimeout.count() <= 0) {
+ result = 0;
+ break;
+ }
+ } else {
+ curTimeout = timeout;
+ }
+ result = ::poll(&pollItem, 1, curTimeout.count());
+ } while (result == -1 && errno == EINTR);
+
+ if (result == -1) {
+ int errCode = errno;
+ return {ErrorCodes::InternalError, errnoWithDescription(errCode)};
+ } else if (result == 0) {
+ return {ErrorCodes::NetworkTimeout, "Timed out waiting for poll"};
+ }
+ return toUnsignedEquivalent(pollItem.revents);
+#endif
+}
+
+#ifdef MONGO_CONFIG_SSL
+boost::optional<std::array<std::uint8_t, 7>> checkTLSRequest(const asio::const_buffer& buffer) {
+ // This method's caller should have read in at least one MSGHEADER::Value's worth of data.
+ // The fragment we are about to examine must be strictly smaller.
+ static const size_t sizeOfTLSFragmentToRead = 11;
+ invariant(buffer.size() >= sizeOfTLSFragmentToRead);
+
+ static_assert(sizeOfTLSFragmentToRead < sizeof(MSGHEADER::Value),
+ "checkTLSRequest's caller read a MSGHEADER::Value, which must be larger than "
+ "message containing the TLS version");
+
+ /**
+ * The fragment we are to examine is a record, containing a handshake, containing a
+ * ClientHello. We wish to examine the advertised protocol version in the ClientHello.
+ * The following roughly describes the contents of these structures. Note that we do not
+ * need, or wish to, examine the entire ClientHello, we're looking exclusively for the
+ * client_version.
+ *
+ * Below is a rough description of the payload we will be examining. We shall perform some
+ * basic checks to ensure the payload matches these expectations. If it does not, we should
+ * bail out, and not emit protocol version alerts.
+ *
+ * enum {alert(21), handshake(22)} ContentType;
+ * TLSPlaintext {
+ * ContentType type = handshake(22),
+ * ProtocolVersion version; // Irrelevant. Clients send the real version in ClientHello.
+ * uint16 length;
+ * fragment, see Handshake stuct for contents
+ * ...
+ * }
+ *
+ * enum {client_hello(1)} HandshakeType;
+ * Handshake {
+ * HandshakeType msg_type = client_hello(1);
+ * uint24_t length;
+ * ClientHello body;
+ * }
+ *
+ * ClientHello {
+ * ProtocolVersion client_version; // <- This is the value we want to extract.
+ * }
+ */
+
+ static const std::uint8_t ContentType_handshake = 22;
+ static const std::uint8_t HandshakeType_client_hello = 1;
+
+ using ProtocolVersion = std::array<std::uint8_t, 2>;
+ static const ProtocolVersion tls10VersionBytes{3, 1};
+ static const ProtocolVersion tls11VersionBytes{3, 2};
+
+ auto request = reinterpret_cast<const char*>(buffer.data());
+ auto cdr = ConstDataRangeCursor(request, request + buffer.size());
+
+ // Parse the record header.
+ // Extract the ContentType from the header, and ensure it is a handshake.
+ StatusWith<std::uint8_t> record_ContentType = cdr.readAndAdvanceNoThrow<std::uint8_t>();
+ if (!record_ContentType.isOK() || record_ContentType.getValue() != ContentType_handshake) {
+ return boost::none;
+ }
+ // Skip the record's ProtocolVersion. Clients tend to send TLS 1.0 in
+ // the record, but then their real protocol version in the enclosed ClientHello.
+ StatusWith<ProtocolVersion> record_protocol_version =
+ cdr.readAndAdvanceNoThrow<ProtocolVersion>();
+ if (!record_protocol_version.isOK()) {
+ return boost::none;
+ }
+ // Parse the record length. It should be be larger than the remaining expected payload.
+ auto record_length = cdr.readAndAdvanceNoThrow<BigEndian<std::uint16_t>>();
+ if (!record_length.isOK() || record_length.getValue() < cdr.length()) {
+ return boost::none;
+ }
+
+ // Parse the handshake header.
+ // Extract the HandshakeType, and ensure it is a ClientHello.
+ StatusWith<std::uint8_t> handshake_type = cdr.readAndAdvanceNoThrow<std::uint8_t>();
+ if (!handshake_type.isOK() || handshake_type.getValue() != HandshakeType_client_hello) {
+ return boost::none;
+ }
+ // Extract the handshake length, and ensure it is larger than the remaining expected
+ // payload. This requires a little work because the packet represents it with a uint24_t.
+ StatusWith<std::array<std::uint8_t, 3>> handshake_length_bytes =
+ cdr.readAndAdvanceNoThrow<std::array<std::uint8_t, 3>>();
+ if (!handshake_length_bytes.isOK()) {
+ return boost::none;
+ }
+ std::uint32_t handshake_length = 0;
+ for (std::uint8_t handshake_byte : handshake_length_bytes.getValue()) {
+ handshake_length <<= 8;
+ handshake_length |= handshake_byte;
+ }
+ if (handshake_length < cdr.length()) {
+ return boost::none;
+ }
+ StatusWith<ProtocolVersion> client_version = cdr.readAndAdvanceNoThrow<ProtocolVersion>();
+ if (!client_version.isOK()) {
+ return boost::none;
+ }
+
+ // Invariant: We read exactly as much data as expected.
+ invariant((cdr.data() - request) == sizeOfTLSFragmentToRead);
+
+ auto isProtocolDisabled = [](SSLParams::Protocols protocol) {
+ const auto& params = getSSLGlobalParams();
+ return std::find(params.sslDisabledProtocols.begin(),
+ params.sslDisabledProtocols.end(),
+ protocol) != params.sslDisabledProtocols.end();
+ };
+
+ auto makeTLSProtocolVersionAlert =
+ [](const std::array<std::uint8_t, 2>& versionBytes) -> std::array<std::uint8_t, 7> {
+ /**
+ * The structure for this alert packet is as follows:
+ * TLSPlaintext {
+ * ContentType type = alert(21);
+ * ProtocolVersion = versionBytes;
+ * uint16_t length = 2
+ * fragment = AlertDescription {
+ * AlertLevel level = fatal(2);
+ * AlertDescription = protocol_version(70);
+ * }
+ *
+ */
+ return std::array<std::uint8_t, 7>{
+ 0x15, versionBytes[0], versionBytes[1], 0x00, 0x02, 0x02, 0x46};
+ };
+
+ ProtocolVersion version = client_version.getValue();
+ if (version == tls10VersionBytes && isProtocolDisabled(SSLParams::Protocols::TLS1_0)) {
+ return makeTLSProtocolVersionAlert(version);
+ } else if (client_version == tls11VersionBytes &&
+ isProtocolDisabled(SSLParams::Protocols::TLS1_1)) {
+ return makeTLSProtocolVersionAlert(version);
+ }
+ // TLS1.2 cannot be distinguished from TLS1.3, just by looking at the ProtocolVersion bytes.
+ // TLS 1.3 compatible clients advertise a "supported_versions" extension, which we would
+ // have to extract here.
+ // Hopefully by the time this matters, OpenSSL will properly emit protocol_version alerts.
+
+ return boost::none;
+}
+#endif
+
+void failedSetSocketOption(const std::system_error& ex,
+ StringData note,
+ BSONObj optionDescription) {
+ LOGV2_INFO(5693100,
+ "Asio socket.set_option failed with std::system_error",
+ "note"_attr = note,
+ "option"_attr = optionDescription,
+ "error"_attr = [&ex] {
+ return BSONObjBuilder{}
+ .append("what", ex.what())
+ .append("message", ex.code().message())
+ .append("category", ex.code().category().name())
+ .append("value", ex.code().value())
+ .obj();
+ }());
+}
+
+} // namespace mongo::transport
diff --git a/src/mongo/transport/asio_utils.h b/src/mongo/transport/asio_utils.h
index e903843d43c..2475e90154b 100644
--- a/src/mongo/transport/asio_utils.h
+++ b/src/mongo/transport/asio_utils.h
@@ -39,15 +39,16 @@
#include "mongo/base/string_data.h"
#include "mongo/base/system_error.h"
#include "mongo/config.h"
+#include "mongo/stdx/type_traits.h"
#include "mongo/util/errno_util.h"
#include "mongo/util/future.h"
#include "mongo/util/hex.h"
#include "mongo/util/net/hostandport.h"
#include "mongo/util/net/sockaddr.h"
#include "mongo/util/net/ssl_manager.h"
+#include "mongo/util/net/ssl_options.h"
-namespace mongo {
-namespace transport {
+namespace mongo::transport {
inline SockAddr endpointToSockAddr(const asio::generic::stream_protocol::endpoint& endPoint) {
SockAddr wrappedAddr(endPoint.data(), endPoint.size());
@@ -59,38 +60,7 @@ inline HostAndPort endpointToHostAndPort(const asio::generic::stream_protocol::e
return HostAndPort(endpointToSockAddr(endPoint).toString(true));
}
-inline Status errorCodeToStatus(const std::error_code& ec) {
- if (!ec)
- return Status::OK();
-
- if (ec == asio::error::operation_aborted) {
- return {ErrorCodes::CallbackCanceled, "Callback was canceled"};
- }
-
-#ifdef _WIN32
- if (ec == asio::error::timed_out) {
-#else
- if (ec == asio::error::try_again || ec == asio::error::would_block) {
-#endif
- return {ErrorCodes::NetworkTimeout, "Socket operation timed out"};
- } else if (ec == asio::error::eof) {
- return {ErrorCodes::HostUnreachable, "Connection closed by peer"};
- } else if (ec == asio::error::connection_reset) {
- return {ErrorCodes::HostUnreachable, "Connection reset by peer"};
- } else if (ec == asio::error::network_reset) {
- return {ErrorCodes::HostUnreachable, "Connection reset by network"};
- }
-
- // If the ec.category() is a mongoErrorCategory() then this error was propogated from
- // mongodb code and we should just pass the error cdoe along as-is.
- ErrorCodes::Error errorCode = (ec.category() == mongoErrorCategory())
- ? ErrorCodes::Error(ec.value())
- // Otherwise it's an error code from the network and we should pass it along as a
- // SocketException
- : ErrorCodes::SocketException;
- // Either way, include the error message.
- return {errorCode, ec.message()};
-}
+Status errorCodeToStatus(const std::error_code& ec);
/*
* The ASIO implementation of poll (i.e. socket.wait()) cannot poll for a mask of events, and
@@ -103,227 +73,25 @@ inline Status errorCodeToStatus(const std::error_code& ec) {
* check whether it matches the expected events mask.
* - On error: it returns a Status(ErrorCodes::InternalError)
*/
-template <typename Socket, typename EventsMask>
-StatusWith<EventsMask> pollASIOSocket(Socket& socket, EventsMask mask, Milliseconds timeout) {
-#ifdef _WIN32
- fd_set readfds;
- fd_set writefds;
- fd_set errfds;
-
- FD_ZERO(&readfds);
- FD_ZERO(&writefds);
- FD_ZERO(&errfds);
-
- auto fd = socket.native_handle();
- if (mask & POLLIN) {
- FD_SET(fd, &readfds);
- }
- if (mask & POLLOUT) {
- FD_SET(fd, &writefds);
- }
- FD_SET(fd, &errfds);
-
- timeval timeoutTv{};
- auto timeoutUs = duration_cast<Microseconds>(timeout);
- if (timeoutUs >= Seconds{1}) {
- auto timeoutSec = duration_cast<Seconds>(timeoutUs);
- timeoutTv.tv_sec = timeoutSec.count();
- timeoutUs -= timeoutSec;
- }
- timeoutTv.tv_usec = timeoutUs.count();
- int result = ::select(1, &readfds, &writefds, &errfds, &timeoutTv);
- if (result == SOCKET_ERROR) {
- auto errDesc = errnoWithDescription(WSAGetLastError());
- return {ErrorCodes::InternalError, errDesc};
- }
- int revents = (FD_ISSET(fd, &readfds) ? POLLIN : 0) | (FD_ISSET(fd, &writefds) ? POLLOUT : 0) |
- (FD_ISSET(fd, &errfds) ? POLLERR : 0);
-#else
- pollfd pollItem = {};
- pollItem.fd = socket.native_handle();
- pollItem.events = mask;
-
- int result;
- boost::optional<Date_t> expiration;
- if (timeout.count() > 0) {
- expiration = Date_t::now() + timeout;
- }
- do {
- Milliseconds curTimeout;
- if (expiration) {
- curTimeout = *expiration - Date_t::now();
- if (curTimeout.count() <= 0) {
- result = 0;
- break;
- }
- } else {
- curTimeout = timeout;
- }
- result = ::poll(&pollItem, 1, curTimeout.count());
- } while (result == -1 && errno == EINTR);
-
- if (result == -1) {
- int errCode = errno;
- return {ErrorCodes::InternalError, errnoWithDescription(errCode)};
- }
- int revents = pollItem.revents;
-#endif
-
- if (result == 0) {
- return {ErrorCodes::NetworkTimeout, "Timed out waiting for poll"};
- } else {
- return revents;
- }
-}
+StatusWith<unsigned> pollASIOSocket(asio::generic::stream_protocol::socket& socket,
+ unsigned mask,
+ Milliseconds timeout);
#ifdef MONGO_CONFIG_SSL
/**
* Peeks at a fragment of a client issued TLS handshake packet. Returns a TLS alert
* packet if the client has selected a protocol which has been disabled by the server.
*/
-template <typename Buffer>
-boost::optional<std::array<std::uint8_t, 7>> checkTLSRequest(const Buffer& buffers) {
- // This method's caller should have read in at least one MSGHEADER::Value's worth of data.
- // The fragment we are about to examine must be strictly smaller.
- static const size_t sizeOfTLSFragmentToRead = 11;
- invariant(asio::buffer_size(buffers) >= sizeOfTLSFragmentToRead);
-
- static_assert(sizeOfTLSFragmentToRead < sizeof(MSGHEADER::Value),
- "checkTLSRequest's caller read a MSGHEADER::Value, which must be larger than "
- "message containing the TLS version");
-
- /**
- * The fragment we are to examine is a record, containing a handshake, containing a
- * ClientHello. We wish to examine the advertised protocol version in the ClientHello.
- * The following roughly describes the contents of these structures. Note that we do not
- * need, or wish to, examine the entire ClientHello, we're looking exclusively for the
- * client_version.
- *
- * Below is a rough description of the payload we will be examining. We shall perform some
- * basic checks to ensure the payload matches these expectations. If it does not, we should
- * bail out, and not emit protocol version alerts.
- *
- * enum {alert(21), handshake(22)} ContentType;
- * TLSPlaintext {
- * ContentType type = handshake(22),
- * ProtocolVersion version; // Irrelevant. Clients send the real version in ClientHello.
- * uint16 length;
- * fragment, see Handshake stuct for contents
- * ...
- * }
- *
- * enum {client_hello(1)} HandshakeType;
- * Handshake {
- * HandshakeType msg_type = client_hello(1);
- * uint24_t length;
- * ClientHello body;
- * }
- *
- * ClientHello {
- * ProtocolVersion client_version; // <- This is the value we want to extract.
- * }
- */
-
- static const std::uint8_t ContentType_handshake = 22;
- static const std::uint8_t HandshakeType_client_hello = 1;
-
- using ProtocolVersion = std::array<std::uint8_t, 2>;
- static const ProtocolVersion tls10VersionBytes{3, 1};
- static const ProtocolVersion tls11VersionBytes{3, 2};
-
- auto request = asio::buffer_cast<const char*>(buffers);
- auto cdr = ConstDataRangeCursor(request, request + asio::buffer_size(buffers));
-
- // Parse the record header.
- // Extract the ContentType from the header, and ensure it is a handshake.
- StatusWith<std::uint8_t> record_ContentType = cdr.readAndAdvanceNoThrow<std::uint8_t>();
- if (!record_ContentType.isOK() || record_ContentType.getValue() != ContentType_handshake) {
- return boost::none;
- }
- // Skip the record's ProtocolVersion. Clients tend to send TLS 1.0 in
- // the record, but then their real protocol version in the enclosed ClientHello.
- StatusWith<ProtocolVersion> record_protocol_version =
- cdr.readAndAdvanceNoThrow<ProtocolVersion>();
- if (!record_protocol_version.isOK()) {
- return boost::none;
- }
- // Parse the record length. It should be be larger than the remaining expected payload.
- auto record_length = cdr.readAndAdvanceNoThrow<BigEndian<std::uint16_t>>();
- if (!record_length.isOK() || record_length.getValue() < cdr.length()) {
- return boost::none;
- }
-
- // Parse the handshake header.
- // Extract the HandshakeType, and ensure it is a ClientHello.
- StatusWith<std::uint8_t> handshake_type = cdr.readAndAdvanceNoThrow<std::uint8_t>();
- if (!handshake_type.isOK() || handshake_type.getValue() != HandshakeType_client_hello) {
- return boost::none;
- }
- // Extract the handshake length, and ensure it is larger than the remaining expected
- // payload. This requires a little work because the packet represents it with a uint24_t.
- StatusWith<std::array<std::uint8_t, 3>> handshake_length_bytes =
- cdr.readAndAdvanceNoThrow<std::array<std::uint8_t, 3>>();
- if (!handshake_length_bytes.isOK()) {
- return boost::none;
- }
- std::uint32_t handshake_length = 0;
- for (std::uint8_t handshake_byte : handshake_length_bytes.getValue()) {
- handshake_length <<= 8;
- handshake_length |= handshake_byte;
- }
- if (handshake_length < cdr.length()) {
- return boost::none;
- }
- StatusWith<ProtocolVersion> client_version = cdr.readAndAdvanceNoThrow<ProtocolVersion>();
- if (!client_version.isOK()) {
- return boost::none;
- }
-
- // Invariant: We read exactly as much data as expected.
- invariant((cdr.data() - request) == sizeOfTLSFragmentToRead);
-
- auto isProtocolDisabled = [](SSLParams::Protocols protocol) {
- const auto& params = getSSLGlobalParams();
- return std::find(params.sslDisabledProtocols.begin(),
- params.sslDisabledProtocols.end(),
- protocol) != params.sslDisabledProtocols.end();
- };
-
- auto makeTLSProtocolVersionAlert =
- [](const std::array<std::uint8_t, 2>& versionBytes) -> std::array<std::uint8_t, 7> {
- /**
- * The structure for this alert packet is as follows:
- * TLSPlaintext {
- * ContentType type = alert(21);
- * ProtocolVersion = versionBytes;
- * uint16_t length = 2
- * fragment = AlertDescription {
- * AlertLevel level = fatal(2);
- * AlertDescription = protocol_version(70);
- * }
- *
- */
- return std::array<std::uint8_t, 7>{
- 0x15, versionBytes[0], versionBytes[1], 0x00, 0x02, 0x02, 0x46};
- };
-
- ProtocolVersion version = client_version.getValue();
- if (version == tls10VersionBytes && isProtocolDisabled(SSLParams::Protocols::TLS1_0)) {
- return makeTLSProtocolVersionAlert(version);
- } else if (client_version == tls11VersionBytes &&
- isProtocolDisabled(SSLParams::Protocols::TLS1_1)) {
- return makeTLSProtocolVersionAlert(version);
- }
- // TLS1.2 cannot be distinguished from TLS1.3, just by looking at the ProtocolVersion bytes.
- // TLS 1.3 compatible clients advertise a "supported_versions" extension, which we would
- // have to extract here.
- // Hopefully by the time this matters, OpenSSL will properly emit protocol_version alerts.
-
- return boost::none;
-}
+boost::optional<std::array<std::uint8_t, 7>> checkTLSRequest(const asio::const_buffer& buffer);
#endif
/**
+ * setSocketOption failed. Log the error.
+ * This is in the .cpp file just to keep LOGV2 out of this header.
+ */
+void failedSetSocketOption(const std::system_error& ex, StringData note, BSONObj optionDescription);
+
+/**
* Calls Asio `socket.set_option(opt)` with better failure diagnostics.
* To be used instead of Asio `socket.set_option``, because errors are hard to diagnose.
* Emits a log message about what option was attempted and what went wrong with
@@ -342,26 +110,14 @@ void setSocketOption(Socket& socket, const Option& opt, StringData note) {
try {
socket.set_option(opt);
} catch (const std::system_error& ex) {
- LOGV2_INFO(5693100,
- "Asio socket.set_option failed with std::system_error",
- "note"_attr = note,
- "option"_attr =
- [&opt, p = socket.local_endpoint().protocol()] {
- return BSONObjBuilder{}
- .append("level", opt.level(p))
- .append("name", opt.name(p))
- .append("data", hexdump(opt.data(p), opt.size(p)))
- .obj();
- }(),
- "error"_attr =
- [&ex] {
- return BSONObjBuilder{}
- .append("what", ex.what())
- .append("message", ex.code().message())
- .append("category", ex.code().category().name())
- .append("value", ex.code().value())
- .obj();
- }());
+ BSONObj optionDescription = [&opt, p = socket.local_endpoint().protocol()] {
+ return BSONObjBuilder{}
+ .append("level", opt.level(p))
+ .append("name", opt.name(p))
+ .append("data", hexdump(opt.data(p), opt.size(p)))
+ .obj();
+ }();
+ failedSetSocketOption(ex, note, optionDescription);
throw;
}
}
@@ -390,106 +146,89 @@ void setSocketOption(Socket& socket, const Option& opt, std::error_code& ec, Str
* Example:
* Future<size_t> future = my_socket.async_read_some(my_buffer, UseFuture{});
*/
-struct UseFuture {};
-
-namespace use_future_details {
-
-template <typename... Args>
-struct AsyncHandlerHelper {
- using Result = std::tuple<Args...>;
- static void complete(Promise<Result>* promise, Args... args) {
- promise->emplaceValue(args...);
- }
+struct UseFuture {
+ template <typename... Args>
+ class Adapter;
};
-template <>
-struct AsyncHandlerHelper<> {
- using Result = void;
- static void complete(Promise<Result>* promise) {
- promise->emplaceValue();
- }
-};
-
-template <typename Arg>
-struct AsyncHandlerHelper<Arg> {
- using Result = Arg;
- static void complete(Promise<Result>* promise, Arg arg) {
- promise->emplaceValue(arg);
- }
-};
+template <typename... ArgsFromAsio>
+class UseFuture::Adapter {
+private:
+ template <typename Dum, typename... Ts>
+ struct ArgPack : stdx::type_identity<std::tuple<Ts...>> {};
+ template <typename Dum>
+ struct ArgPack<Dum> : stdx::type_identity<void> {};
+ template <typename Dum, typename T>
+ struct ArgPack<Dum, T> : stdx::type_identity<T> {};
-template <typename... Args>
-struct AsyncHandlerHelper<std::error_code, Args...> {
- using Helper = AsyncHandlerHelper<Args...>;
- using Result = typename Helper::Result;
-
- template <typename... Args2>
- static void complete(Promise<Result>* promise, std::error_code ec, Args2&&... args) {
- if (ec) {
- promise->setError(errorCodeToStatus(ec));
- } else {
- Helper::complete(promise, std::forward<Args2>(args)...);
+ /**
+ * If an Asio callback takes a leading error_code, it's stripped from
+ * the Future's value_type. Any errors reported by Asio will instead
+ * be delivered by setting the Future's error Status.
+ */
+ template <typename Dum, typename... Ts>
+ struct StripError : ArgPack<Dum, Ts...> {};
+ template <typename Dum, typename... Ts>
+ struct StripError<Dum, std::error_code, Ts...> : ArgPack<Dum, Ts...> {};
+
+ using Result = typename StripError<void, ArgsFromAsio...>::type;
+
+ struct Handler {
+ private:
+ template <typename... As>
+ void _onSuccess(As&&... args) {
+ promise.emplaceValue(std::forward<As>(args)...);
}
- }
-};
-
-template <>
-struct AsyncHandlerHelper<std::error_code> {
- using Result = void;
- static void complete(Promise<Result>* promise, std::error_code ec) {
- if (ec) {
- promise->setError(errorCodeToStatus(ec));
- } else {
- promise->emplaceValue();
+ template <typename... As>
+ void _onInvoke(As&&... args) {
+ _onSuccess(std::forward<As>(args)...);
+ }
+ template <typename... As>
+ void _onInvoke(std::error_code ec, As&&... args) {
+ if (ec) {
+ promise.setError(errorCodeToStatus(ec));
+ return;
+ }
+ _onSuccess(std::forward<As>(args)...);
}
- }
-};
-
-template <typename... Args>
-struct AsyncHandler {
- using Helper = AsyncHandlerHelper<Args...>;
- using Result = typename Helper::Result;
- explicit AsyncHandler(UseFuture) {}
+ public:
+ explicit Handler(const UseFuture&) {}
- template <typename... Args2>
- void operator()(Args2&&... args) {
- Helper::complete(&promise, std::forward<Args2>(args)...);
- }
+ template <typename... As>
+ void operator()(As&&... args) {
+ static_assert((std::is_same_v<std::decay_t<As>, std::decay_t<ArgsFromAsio>> && ...),
+ "Unexpected argument list from Asio async result callback.");
+ _onInvoke(std::forward<As>(args)...);
+ }
- Promise<Result> promise;
-};
+ Promise<Result> promise;
+ };
-template <typename... Args>
-struct AsyncResult {
- using completion_handler_type = AsyncHandler<Args...>;
- using RealResult = typename AsyncHandler<Args...>::Result;
- using return_type = Future<RealResult>;
+public:
+ using return_type = Future<Result>;
+ using completion_handler_type = Handler;
- explicit AsyncResult(completion_handler_type& handler) {
- auto pf = makePromiseFuture<RealResult>();
- fut = std::move(pf.future);
- handler.promise = std::move(pf.promise);
+ explicit Adapter(Handler& handler) {
+ auto&& [p, f] = makePromiseFuture<Result>();
+ _fut = std::move(f);
+ handler.promise = std::move(p);
}
- auto get() {
- return std::move(fut);
+ return_type get() {
+ return std::move(_fut);
}
- Future<RealResult> fut;
+private:
+ Future<Result> _fut;
};
-} // namespace use_future_details
-} // namespace transport
-} // namespace mongo
+} // namespace mongo::transport
namespace asio {
-template <typename Comp, typename Sig>
-class async_result;
-
-template <typename Result, typename... Args>
-class async_result<::mongo::transport::UseFuture, Result(Args...)>
- : public ::mongo::transport::use_future_details::AsyncResult<Args...> {
- using ::mongo::transport::use_future_details::AsyncResult<Args...>::AsyncResult;
+template <typename... Args>
+class async_result<mongo::transport::UseFuture, void(Args...)>
+ : public mongo::transport::UseFuture::Adapter<Args...> {
+ using mongo::transport::UseFuture::Adapter<Args...>::Adapter;
};
} // namespace asio
diff --git a/src/mongo/transport/session_asio.cpp b/src/mongo/transport/session_asio.cpp
new file mode 100644
index 00000000000..b4f27331874
--- /dev/null
+++ b/src/mongo/transport/session_asio.cpp
@@ -0,0 +1,720 @@
+/**
+ * Copyright (C) 2021-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kNetwork
+
+#include "mongo/transport/session_asio.h"
+
+#include "mongo/config.h"
+#include "mongo/logv2/log.h"
+
+namespace mongo::transport {
+
+MONGO_FAIL_POINT_DEFINE(transportLayerASIOshortOpportunisticReadWrite);
+
+namespace {
+
+template <int Name>
+class ASIOSocketTimeoutOption {
+public:
+#ifdef _WIN32
+ using TimeoutType = DWORD;
+
+ ASIOSocketTimeoutOption(Milliseconds timeoutVal) : _timeout(timeoutVal.count()) {}
+
+#else
+ using TimeoutType = timeval;
+
+ ASIOSocketTimeoutOption(Milliseconds timeoutVal) {
+ _timeout.tv_sec = duration_cast<Seconds>(timeoutVal).count();
+ const auto minusSeconds = timeoutVal - Seconds{_timeout.tv_sec};
+ _timeout.tv_usec = duration_cast<Microseconds>(minusSeconds).count();
+ }
+#endif
+
+ template <typename Protocol>
+ int name(const Protocol&) const {
+ return Name;
+ }
+
+ template <typename Protocol>
+ const TimeoutType* data(const Protocol&) const {
+ return &_timeout;
+ }
+
+ template <typename Protocol>
+ std::size_t size(const Protocol&) const {
+ return sizeof(_timeout);
+ }
+
+ template <typename Protocol>
+ int level(const Protocol&) const {
+ return SOL_SOCKET;
+ }
+
+private:
+ TimeoutType _timeout;
+};
+
+} // namespace
+
+
+TransportLayerASIO::ASIOSession::ASIOSession(
+ TransportLayerASIO* tl,
+ GenericSocket socket,
+ bool isIngressSession,
+ Endpoint endpoint,
+ std::shared_ptr<const SSLConnectionContext> transientSSLContext) try
+ : _socket(std::move(socket)),
+ _tl(tl),
+ _isIngressSession(isIngressSession) {
+ auto family = endpointToSockAddr(_socket.local_endpoint()).getType();
+ if (family == AF_INET || family == AF_INET6) {
+ setSocketOption(_socket, asio::ip::tcp::no_delay(true), "session no delay");
+ setSocketOption(_socket, asio::socket_base::keep_alive(true), "session keep alive");
+ setSocketKeepAliveParams(_socket.native_handle());
+ }
+
+ _localAddr = endpointToSockAddr(_socket.local_endpoint());
+
+ if (endpoint == Endpoint()) {
+ // Inbound connection, query socket for remote.
+ _remoteAddr = endpointToSockAddr(_socket.remote_endpoint());
+ } else {
+ // Outbound connection, get remote from resolved endpoint.
+ // Necessary for TCP_FASTOPEN where the remote isn't connected yet.
+ _remoteAddr = endpointToSockAddr(endpoint);
+ }
+
+ _local = HostAndPort(_localAddr.toString(true));
+ _remote = HostAndPort(_remoteAddr.toString(true));
+#ifdef MONGO_CONFIG_SSL
+ _sslContext = transientSSLContext ? transientSSLContext : *tl->_sslContext;
+ if (transientSSLContext) {
+ logv2::DynamicAttributes attrs;
+ if (transientSSLContext->targetClusterURI) {
+ attrs.add("targetClusterURI", *transientSSLContext->targetClusterURI);
+ }
+ attrs.add("isIngress", isIngressSession);
+ attrs.add("connectionId", id());
+ attrs.add("remote", remote());
+ LOGV2(5271001, "Initializing the ASIOSession with transient SSL context", attrs);
+ }
+#endif
+} catch (const DBException&) {
+ throw;
+} catch (const asio::system_error& error) {
+ uasserted(ErrorCodes::SocketException, error.what());
+} catch (...) {
+ uasserted(50797, str::stream() << "Unknown exception while configuring socket.");
+}
+
+void TransportLayerASIO::ASIOSession::end() {
+ if (getSocket().is_open()) {
+ std::error_code ec;
+ getSocket().shutdown(GenericSocket::shutdown_both, ec);
+ if ((ec) && (ec != asio::error::not_connected)) {
+ LOGV2_ERROR(23841,
+ "Error shutting down socket: {error}",
+ "Error shutting down socket",
+ "error"_attr = ec.message());
+ }
+ }
+}
+
+StatusWith<Message> TransportLayerASIO::ASIOSession::sourceMessage() noexcept try {
+ ensureSync();
+ return sourceMessageImpl().getNoThrow();
+} catch (const DBException& ex) {
+ return ex.toStatus();
+}
+
+Future<Message> TransportLayerASIO::ASIOSession::asyncSourceMessage(
+ const BatonHandle& baton) noexcept try {
+ ensureAsync();
+ return sourceMessageImpl(baton);
+} catch (const DBException& ex) {
+ return ex.toStatus();
+}
+
+Status TransportLayerASIO::ASIOSession::waitForData() noexcept try {
+ ensureSync();
+ asio::error_code ec;
+ getSocket().wait(asio::ip::tcp::socket::wait_read, ec);
+ return errorCodeToStatus(ec);
+} catch (const DBException& ex) {
+ return ex.toStatus();
+}
+
+Future<void> TransportLayerASIO::ASIOSession::asyncWaitForData() noexcept try {
+ ensureAsync();
+ return getSocket().async_wait(asio::ip::tcp::socket::wait_read, UseFuture{});
+} catch (const DBException& ex) {
+ return ex.toStatus();
+}
+
+Status TransportLayerASIO::ASIOSession::sinkMessage(Message message) noexcept try {
+ ensureSync();
+
+ return write(asio::buffer(message.buf(), message.size()))
+ .then([this, &message] {
+ if (_isIngressSession) {
+ networkCounter.hitPhysicalOut(message.size());
+ }
+ })
+ .getNoThrow();
+} catch (const DBException& ex) {
+ return ex.toStatus();
+}
+
+Future<void> TransportLayerASIO::ASIOSession::asyncSinkMessage(
+ Message message, const BatonHandle& baton) noexcept try {
+ ensureAsync();
+ return write(asio::buffer(message.buf(), message.size()), baton)
+ .then([this, message /*keep the buffer alive*/]() {
+ if (_isIngressSession) {
+ networkCounter.hitPhysicalOut(message.size());
+ }
+ });
+} catch (const DBException& ex) {
+ return ex.toStatus();
+}
+
+void TransportLayerASIO::ASIOSession::cancelAsyncOperations(const BatonHandle& baton) {
+ LOGV2_DEBUG(4615608,
+ 3,
+ "Cancelling outstanding I/O operations on connection to {remote}",
+ "Cancelling outstanding I/O operations on connection to remote",
+ "remote"_attr = _remote);
+ if (baton && baton->networking() && baton->networking()->cancelSession(*this)) {
+ // If we have a baton, it was for networking, and it owned our session, then we're done.
+ return;
+ }
+
+ getSocket().cancel();
+}
+
+void TransportLayerASIO::ASIOSession::setTimeout(boost::optional<Milliseconds> timeout) {
+ invariant(!timeout || timeout->count() > 0);
+ _configuredTimeout = timeout;
+}
+
+bool TransportLayerASIO::ASIOSession::isConnected() {
+ // socket.is_open() only returns whether the socket is a valid file descriptor and
+ // if we haven't marked this socket as closed already.
+ if (!getSocket().is_open())
+ return false;
+
+ auto swPollEvents = pollASIOSocket(getSocket(), POLLIN, Milliseconds{0});
+ if (!swPollEvents.isOK()) {
+ if (swPollEvents != ErrorCodes::NetworkTimeout) {
+ LOGV2_WARNING(4615609,
+ "Failed to poll socket for connectivity check: {error}",
+ "Failed to poll socket for connectivity check",
+ "error"_attr = swPollEvents.getStatus());
+ return false;
+ }
+ return true;
+ }
+
+ auto revents = swPollEvents.getValue();
+ if (revents & POLLIN) {
+ char testByte;
+ int size = ::recv(getSocket().native_handle(), &testByte, sizeof(testByte), MSG_PEEK);
+ if (size == sizeof(testByte)) {
+ return true;
+ } else if (size == -1) {
+ LOGV2_WARNING(4615610,
+ "Failed to check socket connectivity: {error}",
+ "Failed to check socket connectivity",
+ "error"_attr = errnoWithDescription(errno));
+ }
+ // If size == 0 then we got disconnected and we should return false.
+ }
+
+ return false;
+}
+
+#ifdef MONGO_CONFIG_SSL
+const SSLConfiguration* TransportLayerASIO::ASIOSession::getSSLConfiguration() const {
+ if (_sslContext->manager) {
+ return &_sslContext->manager->getSSLConfiguration();
+ }
+ return nullptr;
+}
+
+const std::shared_ptr<SSLManagerInterface> TransportLayerASIO::ASIOSession::getSSLManager() const {
+ return _sslContext->manager;
+}
+
+// The unique_lock here is held by TransportLayerASIO to synchronize with the asyncConnect
+// timeout callback. It will be unlocked before the SSL actually handshake begins.
+Future<void> TransportLayerASIO::ASIOSession::handshakeSSLForEgressWithLock(
+ stdx::unique_lock<Latch> lk, const HostAndPort& target, const ReactorHandle& reactor) {
+ if (!_sslContext->egress) {
+ return Future<void>::makeReady(
+ Status(ErrorCodes::SSLHandshakeFailed, "SSL requested but SSL support is disabled"));
+ }
+
+ _sslSocket.emplace(std::move(_socket), *_sslContext->egress, removeFQDNRoot(target.host()));
+ lk.unlock();
+
+ auto doHandshake = [&] {
+ if (_blockingMode == Sync) {
+ std::error_code ec;
+ _sslSocket->handshake(asio::ssl::stream_base::client, ec);
+ return futurize(ec);
+ } else {
+ return _sslSocket->async_handshake(asio::ssl::stream_base::client, UseFuture{});
+ }
+ };
+ return doHandshake().then([this, target, reactor] {
+ _ranHandshake = true;
+
+ return getSSLManager()
+ ->parseAndValidatePeerCertificate(
+ _sslSocket->native_handle(), _sslSocket->get_sni(), target.host(), target, reactor)
+ .then([this](SSLPeerInfo info) { SSLPeerInfo::forSession(shared_from_this()) = info; });
+ });
+}
+
+// For synchronous connections where we don't have an async timer, just take a dummy lock and
+// pass it to the WithLock version of handshakeSSLForEgress
+Future<void> TransportLayerASIO::ASIOSession::handshakeSSLForEgress(const HostAndPort& target) {
+ auto mutex = MONGO_MAKE_LATCH();
+ return handshakeSSLForEgressWithLock(stdx::unique_lock<Latch>(mutex), target, nullptr);
+}
+#endif
+
+void TransportLayerASIO::ASIOSession::ensureSync() {
+ asio::error_code ec;
+ if (_blockingMode != Sync) {
+ getSocket().non_blocking(false, ec);
+ fassert(40490, errorCodeToStatus(ec));
+ _blockingMode = Sync;
+ }
+
+ if (_socketTimeout != _configuredTimeout) {
+ // Change boost::none (which means no timeout) into a zero value for the socket option,
+ // which also means no timeout.
+ auto timeout = _configuredTimeout.value_or(Milliseconds{0});
+ setSocketOption(
+ getSocket(), ASIOSocketTimeoutOption<SO_SNDTIMEO>(timeout), ec, "session send timeout");
+ if (auto status = errorCodeToStatus(ec); !status.isOK()) {
+ tasserted(5342000, status.reason());
+ }
+
+ setSocketOption(getSocket(),
+ ASIOSocketTimeoutOption<SO_RCVTIMEO>(timeout),
+ ec,
+ "session receive timeout");
+ if (auto status = errorCodeToStatus(ec); !status.isOK()) {
+ tasserted(5342001, status.reason());
+ }
+
+ _socketTimeout = _configuredTimeout;
+ }
+}
+
+void TransportLayerASIO::ASIOSession::ensureAsync() {
+ if (_blockingMode == Async)
+ return;
+
+ // Socket timeouts currently only effect synchronous calls, so make sure the caller isn't
+ // expecting a socket timeout when they do an async operation.
+ invariant(!_configuredTimeout);
+
+ asio::error_code ec;
+ getSocket().non_blocking(true, ec);
+ fassert(50706, errorCodeToStatus(ec));
+ _blockingMode = Async;
+}
+
+auto TransportLayerASIO::ASIOSession::getSocket() -> GenericSocket& {
+#ifdef MONGO_CONFIG_SSL
+ if (_sslSocket) {
+ return static_cast<GenericSocket&>(_sslSocket->lowest_layer());
+ }
+#endif
+ return _socket;
+}
+
+Future<Message> TransportLayerASIO::ASIOSession::sourceMessageImpl(const BatonHandle& baton) {
+ static constexpr auto kHeaderSize = sizeof(MSGHEADER::Value);
+
+ auto headerBuffer = SharedBuffer::allocate(kHeaderSize);
+ auto ptr = headerBuffer.get();
+ return read(asio::buffer(ptr, kHeaderSize), baton)
+ .then([headerBuffer = std::move(headerBuffer), this, baton]() mutable {
+ if (checkForHTTPRequest(asio::buffer(headerBuffer.get(), kHeaderSize))) {
+ return sendHTTPResponse(baton);
+ }
+
+ const auto msgLen = size_t(MSGHEADER::View(headerBuffer.get()).getMessageLength());
+ if (msgLen < kHeaderSize || msgLen > MaxMessageSizeBytes) {
+ StringBuilder sb;
+ sb << "recv(): message msgLen " << msgLen << " is invalid. "
+ << "Min " << kHeaderSize << " Max: " << MaxMessageSizeBytes;
+ const auto str = sb.str();
+ LOGV2(4615638,
+ "recv(): message msgLen {msgLen} is invalid. Min: {min} Max: {max}",
+ "recv(): message mstLen is invalid.",
+ "msgLen"_attr = msgLen,
+ "min"_attr = kHeaderSize,
+ "max"_attr = MaxMessageSizeBytes);
+
+ return Future<Message>::makeReady(Status(ErrorCodes::ProtocolError, str));
+ }
+
+ if (msgLen == kHeaderSize) {
+ // This probably isn't a real case since all (current) messages have bodies.
+ if (_isIngressSession) {
+ networkCounter.hitPhysicalIn(msgLen);
+ }
+ return Future<Message>::makeReady(Message(std::move(headerBuffer)));
+ }
+
+ auto buffer = SharedBuffer::allocate(msgLen);
+ memcpy(buffer.get(), headerBuffer.get(), kHeaderSize);
+
+ MsgData::View msgView(buffer.get());
+ return read(asio::buffer(msgView.data(), msgView.dataLen()), baton)
+ .then([this, buffer = std::move(buffer), msgLen]() mutable {
+ if (_isIngressSession) {
+ networkCounter.hitPhysicalIn(msgLen);
+ }
+ return Message(std::move(buffer));
+ });
+ });
+}
+
+template <typename MutableBufferSequence>
+Future<void> TransportLayerASIO::ASIOSession::read(const MutableBufferSequence& buffers,
+ const BatonHandle& baton) {
+ // TODO SERVER-47229 Guard active ops for cancellation here.
+#ifdef MONGO_CONFIG_SSL
+ if (_sslSocket) {
+ return opportunisticRead(*_sslSocket, buffers, baton);
+ } else if (!_ranHandshake) {
+ invariant(asio::buffer_size(buffers) >= sizeof(MSGHEADER::Value));
+
+ return opportunisticRead(_socket, buffers, baton)
+ .then([this, buffers]() mutable {
+ _ranHandshake = true;
+ return maybeHandshakeSSLForIngress(buffers);
+ })
+ .then([this, buffers, baton](bool needsRead) mutable {
+ if (needsRead) {
+ return read(buffers, baton);
+ } else {
+ return Future<void>::makeReady();
+ }
+ });
+ }
+#endif
+ return opportunisticRead(_socket, buffers, baton);
+}
+
+template <typename ConstBufferSequence>
+Future<void> TransportLayerASIO::ASIOSession::write(const ConstBufferSequence& buffers,
+ const BatonHandle& baton) {
+ // TODO SERVER-47229 Guard active ops for cancellation here.
+#ifdef MONGO_CONFIG_SSL
+ _ranHandshake = true;
+ if (_sslSocket) {
+#ifdef __linux__
+ // We do some trickery in asio (see moreToSend), which appears to work well on linux,
+ // but fails on other platforms.
+ return opportunisticWrite(*_sslSocket, buffers, baton);
+#else
+ if (_blockingMode == Async) {
+ // Opportunistic writes are broken for async egress SSL (switching between blocking
+ // and non-blocking mode corrupts the TLS exchange).
+ return asio::async_write(*_sslSocket, buffers, UseFuture{}).ignoreValue();
+ } else {
+ return opportunisticWrite(*_sslSocket, buffers, baton);
+ }
+#endif
+ }
+#endif // MONGO_CONFIG_SSL
+ return opportunisticWrite(_socket, buffers, baton);
+}
+
+template <typename Stream, typename MutableBufferSequence>
+Future<void> TransportLayerASIO::ASIOSession::opportunisticRead(
+ Stream& stream, const MutableBufferSequence& buffers, const BatonHandle& baton) {
+ std::error_code ec;
+ size_t size;
+
+ if (MONGO_unlikely(transportLayerASIOshortOpportunisticReadWrite.shouldFail()) &&
+ _blockingMode == Async) {
+ asio::mutable_buffer localBuffer = buffers;
+
+ if (buffers.size()) {
+ localBuffer = asio::mutable_buffer(buffers.data(), 1);
+ }
+
+ do {
+ size = asio::read(stream, localBuffer, ec);
+ } while (ec == asio::error::interrupted); // retry syscall EINTR
+
+ if (!ec && buffers.size() > 1) {
+ ec = asio::error::would_block;
+ }
+ } else {
+ do {
+ size = asio::read(stream, buffers, ec);
+ } while (ec == asio::error::interrupted); // retry syscall EINTR
+ }
+
+ if (((ec == asio::error::would_block) || (ec == asio::error::try_again)) &&
+ (_blockingMode == Async)) {
+ // asio::read is a loop internally, so some of buffers may have been read into already.
+ // So we need to adjust the buffers passed into async_read to be offset by size, if
+ // size is > 0.
+ MutableBufferSequence asyncBuffers(buffers);
+ if (size > 0) {
+ asyncBuffers += size;
+ }
+
+ if (auto networkingBaton = baton ? baton->networking() : nullptr;
+ networkingBaton && networkingBaton->canWait()) {
+ return networkingBaton->addSession(*this, NetworkingBaton::Type::In)
+ .onError([](Status error) {
+ if (ErrorCodes::isShutdownError(error)) {
+ // If the baton has detached, it will cancel its polling. We catch that
+ // error here and return Status::OK so that we invoke
+ // opportunisticRead() again and switch to asio::async_read() below.
+ return Status::OK();
+ }
+
+ return error;
+ })
+ .then([&stream, asyncBuffers, baton, this] {
+ return opportunisticRead(stream, asyncBuffers, baton);
+ });
+ }
+
+ return asio::async_read(stream, asyncBuffers, UseFuture{}).ignoreValue();
+ } else {
+ return futurize(ec);
+ }
+}
+
+#ifdef MONGO_CONFIG_SSL
+boost::optional<std::string> TransportLayerASIO::ASIOSession::getSniName() const {
+ return SSLPeerInfo::forSession(shared_from_this()).sniName;
+}
+#endif
+
+template <typename Stream, typename ConstBufferSequence>
+Future<void> TransportLayerASIO::ASIOSession::opportunisticWrite(Stream& stream,
+ const ConstBufferSequence& buffers,
+ const BatonHandle& baton) {
+ std::error_code ec;
+ std::size_t size;
+
+ if (MONGO_unlikely(transportLayerASIOshortOpportunisticReadWrite.shouldFail()) &&
+ _blockingMode == Async) {
+ asio::const_buffer localBuffer = buffers;
+
+ if (buffers.size()) {
+ localBuffer = asio::const_buffer(buffers.data(), 1);
+ }
+
+ do {
+ size = asio::write(stream, localBuffer, ec);
+ } while (ec == asio::error::interrupted); // retry syscall EINTR
+ if (!ec && buffers.size() > 1) {
+ ec = asio::error::would_block;
+ }
+ } else {
+ do {
+ size = asio::write(stream, buffers, ec);
+ } while (ec == asio::error::interrupted); // retry syscall EINTR
+ }
+
+ if (((ec == asio::error::would_block) || (ec == asio::error::try_again)) &&
+ (_blockingMode == Async)) {
+
+ // asio::write is a loop internally, so some of buffers may have been read into already.
+ // So we need to adjust the buffers passed into async_write to be offset by size, if
+ // size is > 0.
+ ConstBufferSequence asyncBuffers(buffers);
+ if (size > 0) {
+ asyncBuffers += size;
+ }
+
+ if (auto more = moreToSend(stream, asyncBuffers, baton)) {
+ return std::move(*more);
+ }
+
+ if (auto networkingBaton = baton ? baton->networking() : nullptr;
+ networkingBaton && networkingBaton->canWait()) {
+ return networkingBaton->addSession(*this, NetworkingBaton::Type::Out)
+ .onError([](Status error) {
+ if (ErrorCodes::isCancellationError(error)) {
+ // If the baton has detached, it will cancel its polling. We catch that
+ // error here and return Status::OK so that we invoke
+ // opportunisticWrite() again and switch to asio::async_write() below.
+ return Status::OK();
+ }
+
+ return error;
+ })
+ .then([&stream, asyncBuffers, baton, this] {
+ return opportunisticWrite(stream, asyncBuffers, baton);
+ });
+ }
+
+ return asio::async_write(stream, asyncBuffers, UseFuture{}).ignoreValue();
+ } else {
+ return futurize(ec);
+ }
+}
+
+#ifdef MONGO_CONFIG_SSL
+template <typename MutableBufferSequence>
+Future<bool> TransportLayerASIO::ASIOSession::maybeHandshakeSSLForIngress(
+ const MutableBufferSequence& buffer) {
+ invariant(asio::buffer_size(buffer) >= sizeof(MSGHEADER::Value));
+ MSGHEADER::ConstView headerView(asio::buffer_cast<char*>(buffer));
+ auto responseTo = headerView.getResponseToMsgId();
+
+ if (checkForHTTPRequest(buffer)) {
+ return Future<bool>::makeReady(false);
+ }
+ // This logic was taken from the old mongo/util/net/sock.cpp.
+ //
+ // It lets us run both TLS and unencrypted mongo over the same port.
+ //
+ // The first message received from the client should have the responseTo field of the wire
+ // protocol message needs to be 0 or -1. Otherwise the connection is either sending
+ // garbage or a TLS Hello packet which will be caught by the TLS handshake.
+ if (responseTo != 0 && responseTo != -1) {
+ if (!_sslContext->ingress) {
+ return Future<bool>::makeReady(
+ Status(ErrorCodes::SSLHandshakeFailed,
+ "SSL handshake received but server is started without SSL support"));
+ }
+
+ auto tlsAlert = checkTLSRequest(buffer);
+ if (tlsAlert) {
+ return opportunisticWrite(getSocket(), asio::buffer(tlsAlert->data(), tlsAlert->size()))
+ .then([] {
+ return Future<bool>::makeReady(
+ Status(ErrorCodes::SSLHandshakeFailed,
+ "SSL handshake failed, as client requested disabled protocol"));
+ });
+ }
+
+ _sslSocket.emplace(std::move(_socket), *_sslContext->ingress, "");
+ auto doHandshake = [&] {
+ if (_blockingMode == Sync) {
+ std::error_code ec;
+ _sslSocket->handshake(asio::ssl::stream_base::server, buffer, ec);
+ return futurize(ec, asio::buffer_size(buffer));
+ } else {
+ return _sslSocket->async_handshake(
+ asio::ssl::stream_base::server, buffer, UseFuture{});
+ }
+ };
+ return doHandshake().then([this](size_t size) {
+ if (_sslSocket->get_sni()) {
+ auto sniName = _sslSocket->get_sni().get();
+ LOGV2_DEBUG(
+ 4908000, 2, "Client connected with SNI extension", "sniName"_attr = sniName);
+ } else {
+ LOGV2_DEBUG(4908001, 2, "Client connected without SNI extension");
+ }
+ if (SSLPeerInfo::forSession(shared_from_this()).subjectName.empty()) {
+ return getSSLManager()
+ ->parseAndValidatePeerCertificate(
+ _sslSocket->native_handle(), _sslSocket->get_sni(), "", _remote, nullptr)
+ .then([this](SSLPeerInfo info) -> bool {
+ SSLPeerInfo::forSession(shared_from_this()) = info;
+ return true;
+ });
+ }
+
+ return Future<bool>::makeReady(true);
+ });
+ } else if (_tl->_sslMode() == SSLParams::SSLMode_requireSSL) {
+ uasserted(ErrorCodes::SSLHandshakeFailed,
+ "The server is configured to only allow SSL connections");
+ } else {
+ if (!sslGlobalParams.disableNonSSLConnectionLogging &&
+ _tl->_sslMode() == SSLParams::SSLMode_preferSSL) {
+ LOGV2(23838,
+ "SSL mode is set to 'preferred' and connection {connectionId} to {remote} is "
+ "not using SSL.",
+ "SSL mode is set to 'preferred' and connection to remote is not using SSL.",
+ "connectionId"_attr = id(),
+ "remote"_attr = remote());
+ }
+ return Future<bool>::makeReady(false);
+ }
+}
+#endif // MONGO_CONFIG_SSL
+
+template <typename Buffer>
+bool TransportLayerASIO::ASIOSession::checkForHTTPRequest(const Buffer& buffers) {
+ invariant(asio::buffer_size(buffers) >= 4);
+ const StringData bufferAsStr(asio::buffer_cast<const char*>(buffers), 4);
+ return (bufferAsStr == "GET "_sd);
+}
+
+Future<Message> TransportLayerASIO::ASIOSession::sendHTTPResponse(const BatonHandle& baton) {
+ constexpr auto userMsg =
+ "It looks like you are trying to access MongoDB over HTTP"
+ " on the native driver port.\r\n"_sd;
+
+ static const std::string httpResp = str::stream() << "HTTP/1.0 200 OK\r\n"
+ "Connection: close\r\n"
+ "Content-Type: text/plain\r\n"
+ "Content-Length: "
+ << userMsg.size() << "\r\n\r\n"
+ << userMsg;
+
+ return write(asio::buffer(httpResp.data(), httpResp.size()), baton)
+ .onError([](const Status& status) {
+ return Status(ErrorCodes::ProtocolError,
+ str::stream()
+ << "Client sent an HTTP request over a native MongoDB connection, "
+ "but there was an error sending a response: "
+ << status.toString());
+ })
+ .then([] {
+ return StatusWith<Message>(
+ ErrorCodes::ProtocolError,
+ "Client sent an HTTP request over a native MongoDB connection");
+ });
+}
+
+} // namespace mongo::transport
diff --git a/src/mongo/transport/session_asio.h b/src/mongo/transport/session_asio.h
index a91e710f6f9..b77d95c5c7f 100644
--- a/src/mongo/transport/session_asio.h
+++ b/src/mongo/transport/session_asio.h
@@ -51,10 +51,9 @@
#include "mongo/util/net/ssl.hpp"
#endif
-namespace mongo {
-namespace transport {
+namespace mongo::transport {
-MONGO_FAIL_POINT_DEFINE(transportLayerASIOshortOpportunisticReadWrite);
+extern FailPoint transportLayerASIOshortOpportunisticReadWrite;
template <typename SuccessValue>
auto futurize(const std::error_code& ec, SuccessValue&& successValue) {
@@ -65,7 +64,7 @@ auto futurize(const std::error_code& ec, SuccessValue&& successValue) {
return Result::makeReady(successValue);
}
-Future<void> futurize(const std::error_code& ec) {
+inline Future<void> futurize(const std::error_code& ec) {
using Result = Future<void>;
if (MONGO_unlikely(ec)) {
return Result::makeReady(errorCodeToStatus(ec));
@@ -73,13 +72,10 @@ Future<void> futurize(const std::error_code& ec) {
return Result::makeReady();
}
-using GenericSocket = asio::generic::stream_protocol::socket;
-
class TransportLayerASIO::ASIOSession final : public Session {
- ASIOSession(const ASIOSession&) = delete;
- ASIOSession& operator=(const ASIOSession&) = delete;
-
public:
+ using GenericSocket = asio::generic::stream_protocol::socket;
+
using Endpoint = asio::generic::stream_protocol::endpoint;
// If the socket is disconnected while any of these options are being set, this constructor
@@ -88,50 +84,10 @@ public:
GenericSocket socket,
bool isIngressSession,
Endpoint endpoint = Endpoint(),
- std::shared_ptr<const SSLConnectionContext> transientSSLContext = nullptr) try
- : _socket(std::move(socket)),
- _tl(tl),
- _isIngressSession(isIngressSession) {
- auto family = endpointToSockAddr(_socket.local_endpoint()).getType();
- if (family == AF_INET || family == AF_INET6) {
- setSocketOption(_socket, asio::ip::tcp::no_delay(true), "session no delay");
- setSocketOption(_socket, asio::socket_base::keep_alive(true), "session keep alive");
- setSocketKeepAliveParams(_socket.native_handle());
- }
+ std::shared_ptr<const SSLConnectionContext> transientSSLContext = nullptr);
- _localAddr = endpointToSockAddr(_socket.local_endpoint());
-
- if (endpoint == Endpoint()) {
- // Inbound connection, query socket for remote.
- _remoteAddr = endpointToSockAddr(_socket.remote_endpoint());
- } else {
- // Outbound connection, get remote from resolved endpoint.
- // Necessary for TCP_FASTOPEN where the remote isn't connected yet.
- _remoteAddr = endpointToSockAddr(endpoint);
- }
-
- _local = HostAndPort(_localAddr.toString(true));
- _remote = HostAndPort(_remoteAddr.toString(true));
-#ifdef MONGO_CONFIG_SSL
- _sslContext = transientSSLContext ? transientSSLContext : *tl->_sslContext;
- if (transientSSLContext) {
- logv2::DynamicAttributes attrs;
- if (transientSSLContext->targetClusterURI) {
- attrs.add("targetClusterURI", *transientSSLContext->targetClusterURI);
- }
- attrs.add("isIngress", isIngressSession);
- attrs.add("connectionId", id());
- attrs.add("remote", remote());
- LOGV2(5271001, "Initializing the ASIOSession with transient SSL context", attrs);
- }
-#endif
- } catch (const DBException&) {
- throw;
- } catch (const asio::system_error& error) {
- uasserted(ErrorCodes::SocketException, error.what());
- } catch (...) {
- uasserted(50797, str::stream() << "Unknown exception while configuring socket.");
- }
+ ASIOSession(const ASIOSession&) = delete;
+ ASIOSession& operator=(const ASIOSession&) = delete;
~ASIOSession() {
end();
@@ -157,142 +113,31 @@ public:
return _localAddr;
}
- void end() override {
- if (getSocket().is_open()) {
- std::error_code ec;
- getSocket().shutdown(GenericSocket::shutdown_both, ec);
- if ((ec) && (ec != asio::error::not_connected)) {
- LOGV2_ERROR(23841,
- "Error shutting down socket: {error}",
- "Error shutting down socket",
- "error"_attr = ec.message());
- }
- }
- }
+ void end() override;
- StatusWith<Message> sourceMessage() noexcept override try {
- ensureSync();
- return sourceMessageImpl().getNoThrow();
- } catch (const DBException& ex) {
- return ex.toStatus();
- }
+ StatusWith<Message> sourceMessage() noexcept override;
- Future<Message> asyncSourceMessage(const BatonHandle& baton = nullptr) noexcept override try {
- ensureAsync();
- return sourceMessageImpl(baton);
- } catch (const DBException& ex) {
- return ex.toStatus();
- }
+ Future<Message> asyncSourceMessage(const BatonHandle& baton = nullptr) noexcept override;
- Status waitForData() noexcept override try {
- ensureSync();
- asio::error_code ec;
- getSocket().wait(asio::ip::tcp::socket::wait_read, ec);
- return errorCodeToStatus(ec);
- } catch (const DBException& ex) {
- return ex.toStatus();
- }
+ Status waitForData() noexcept override;
- Future<void> asyncWaitForData() noexcept override try {
- ensureAsync();
- return getSocket().async_wait(asio::ip::tcp::socket::wait_read, UseFuture{});
- } catch (const DBException& ex) {
- return ex.toStatus();
- }
+ Future<void> asyncWaitForData() noexcept override;
- Status sinkMessage(Message message) noexcept override try {
- ensureSync();
-
- return write(asio::buffer(message.buf(), message.size()))
- .then([this, &message] {
- if (_isIngressSession) {
- networkCounter.hitPhysicalOut(message.size());
- }
- })
- .getNoThrow();
- } catch (const DBException& ex) {
- return ex.toStatus();
- }
+ Status sinkMessage(Message message) noexcept override;
Future<void> asyncSinkMessage(Message message,
- const BatonHandle& baton = nullptr) noexcept override try {
- ensureAsync();
- return write(asio::buffer(message.buf(), message.size()), baton)
- .then([this, message /*keep the buffer alive*/]() {
- if (_isIngressSession) {
- networkCounter.hitPhysicalOut(message.size());
- }
- });
- } catch (const DBException& ex) {
- return ex.toStatus();
- }
+ const BatonHandle& baton = nullptr) noexcept override;
- void cancelAsyncOperations(const BatonHandle& baton = nullptr) override {
- LOGV2_DEBUG(4615608,
- 3,
- "Cancelling outstanding I/O operations on connection to {remote}",
- "Cancelling outstanding I/O operations on connection to remote",
- "remote"_attr = _remote);
- if (baton && baton->networking() && baton->networking()->cancelSession(*this)) {
- // If we have a baton, it was for networking, and it owned our session, then we're done.
- return;
- }
+ void cancelAsyncOperations(const BatonHandle& baton = nullptr) override;
- getSocket().cancel();
- }
-
- void setTimeout(boost::optional<Milliseconds> timeout) override {
- invariant(!timeout || timeout->count() > 0);
- _configuredTimeout = timeout;
- }
-
- bool isConnected() override {
- // socket.is_open() only returns whether the socket is a valid file descriptor and
- // if we haven't marked this socket as closed already.
- if (!getSocket().is_open())
- return false;
-
- auto swPollEvents = pollASIOSocket(getSocket(), POLLIN, Milliseconds{0});
- if (!swPollEvents.isOK()) {
- if (swPollEvents != ErrorCodes::NetworkTimeout) {
- LOGV2_WARNING(4615609,
- "Failed to poll socket for connectivity check: {error}",
- "Failed to poll socket for connectivity check",
- "error"_attr = swPollEvents.getStatus());
- return false;
- }
- return true;
- }
-
- auto revents = swPollEvents.getValue();
- if (revents & POLLIN) {
- char testByte;
- int size = ::recv(getSocket().native_handle(), &testByte, sizeof(testByte), MSG_PEEK);
- if (size == sizeof(testByte)) {
- return true;
- } else if (size == -1) {
- LOGV2_WARNING(4615610,
- "Failed to check socket connectivity: {error}",
- "Failed to check socket connectivity",
- "error"_attr = errnoWithDescription(errno));
- }
- // If size == 0 then we got disconnected and we should return false.
- }
+ void setTimeout(boost::optional<Milliseconds> timeout) override;
- return false;
- }
+ bool isConnected() override;
#ifdef MONGO_CONFIG_SSL
- const SSLConfiguration* getSSLConfiguration() const override {
- if (_sslContext->manager) {
- return &_sslContext->manager->getSSLConfiguration();
- }
- return nullptr;
- }
+ const SSLConfiguration* getSSLConfiguration() const override;
- const std::shared_ptr<SSLManagerInterface> getSSLManager() const override {
- return _sslContext->manager;
- }
+ const std::shared_ptr<SSLManagerInterface> getSSLManager() const override;
#endif
protected:
@@ -304,305 +149,32 @@ protected:
// timeout callback. It will be unlocked before the SSL actually handshake begins.
Future<void> handshakeSSLForEgressWithLock(stdx::unique_lock<Latch> lk,
const HostAndPort& target,
- const ReactorHandle& reactor) {
- if (!_sslContext->egress) {
- return Future<void>::makeReady(Status(ErrorCodes::SSLHandshakeFailed,
- "SSL requested but SSL support is disabled"));
- }
-
- _sslSocket.emplace(std::move(_socket), *_sslContext->egress, removeFQDNRoot(target.host()));
- lk.unlock();
-
- auto doHandshake = [&] {
- if (_blockingMode == Sync) {
- std::error_code ec;
- _sslSocket->handshake(asio::ssl::stream_base::client, ec);
- return futurize(ec);
- } else {
- return _sslSocket->async_handshake(asio::ssl::stream_base::client, UseFuture{});
- }
- };
- return doHandshake().then([this, target, reactor] {
- _ranHandshake = true;
-
- return getSSLManager()
- ->parseAndValidatePeerCertificate(_sslSocket->native_handle(),
- _sslSocket->get_sni(),
- target.host(),
- target,
- reactor)
- .then([this](SSLPeerInfo info) {
- SSLPeerInfo::forSession(shared_from_this()) = info;
- });
- });
- }
+ const ReactorHandle& reactor);
// For synchronous connections where we don't have an async timer, just take a dummy lock and
// pass it to the WithLock version of handshakeSSLForEgress
- Future<void> handshakeSSLForEgress(const HostAndPort& target) {
- auto mutex = MONGO_MAKE_LATCH();
- return handshakeSSLForEgressWithLock(stdx::unique_lock<Latch>(mutex), target, nullptr);
- }
+ Future<void> handshakeSSLForEgress(const HostAndPort& target);
#endif
- void ensureSync() {
- asio::error_code ec;
- if (_blockingMode != Sync) {
- getSocket().non_blocking(false, ec);
- fassert(40490, errorCodeToStatus(ec));
- _blockingMode = Sync;
- }
+ void ensureSync();
- if (_socketTimeout != _configuredTimeout) {
- // Change boost::none (which means no timeout) into a zero value for the socket option,
- // which also means no timeout.
- auto timeout = _configuredTimeout.value_or(Milliseconds{0});
- setSocketOption(getSocket(),
- ASIOSocketTimeoutOption<SO_SNDTIMEO>(timeout),
- ec,
- "session send timeout");
- if (auto status = errorCodeToStatus(ec); !status.isOK()) {
- tasserted(5342000, status.reason());
- }
-
- setSocketOption(getSocket(),
- ASIOSocketTimeoutOption<SO_RCVTIMEO>(timeout),
- ec,
- "session receive timeout");
- if (auto status = errorCodeToStatus(ec); !status.isOK()) {
- tasserted(5342001, status.reason());
- }
-
- _socketTimeout = _configuredTimeout;
- }
- }
-
- void ensureAsync() {
- if (_blockingMode == Async)
- return;
-
- // Socket timeouts currently only effect synchronous calls, so make sure the caller isn't
- // expecting a socket timeout when they do an async operation.
- invariant(!_configuredTimeout);
-
- asio::error_code ec;
- getSocket().non_blocking(true, ec);
- fassert(50706, errorCodeToStatus(ec));
- _blockingMode = Async;
- }
+ void ensureAsync();
private:
- template <int Name>
- class ASIOSocketTimeoutOption {
- public:
-#ifdef _WIN32
- using TimeoutType = DWORD;
+ GenericSocket& getSocket();
- ASIOSocketTimeoutOption(Milliseconds timeoutVal) : _timeout(timeoutVal.count()) {}
-
-#else
- using TimeoutType = timeval;
-
- ASIOSocketTimeoutOption(Milliseconds timeoutVal) {
- _timeout.tv_sec = duration_cast<Seconds>(timeoutVal).count();
- const auto minusSeconds = timeoutVal - Seconds{_timeout.tv_sec};
- _timeout.tv_usec = duration_cast<Microseconds>(minusSeconds).count();
- }
-#endif
-
- template <typename Protocol>
- int name(const Protocol&) const {
- return Name;
- }
-
- template <typename Protocol>
- const TimeoutType* data(const Protocol&) const {
- return &_timeout;
- }
-
- template <typename Protocol>
- std::size_t size(const Protocol&) const {
- return sizeof(_timeout);
- }
-
- template <typename Protocol>
- int level(const Protocol&) const {
- return SOL_SOCKET;
- }
-
- private:
- TimeoutType _timeout;
- };
-
- GenericSocket& getSocket() {
-#ifdef MONGO_CONFIG_SSL
- if (_sslSocket) {
- return static_cast<GenericSocket&>(_sslSocket->lowest_layer());
- }
-#endif
- return _socket;
- }
-
- Future<Message> sourceMessageImpl(const BatonHandle& baton = nullptr) {
- static constexpr auto kHeaderSize = sizeof(MSGHEADER::Value);
-
- auto headerBuffer = SharedBuffer::allocate(kHeaderSize);
- auto ptr = headerBuffer.get();
- return read(asio::buffer(ptr, kHeaderSize), baton)
- .then([headerBuffer = std::move(headerBuffer), this, baton]() mutable {
- if (checkForHTTPRequest(asio::buffer(headerBuffer.get(), kHeaderSize))) {
- return sendHTTPResponse(baton);
- }
-
- const auto msgLen = size_t(MSGHEADER::View(headerBuffer.get()).getMessageLength());
- if (msgLen < kHeaderSize || msgLen > MaxMessageSizeBytes) {
- StringBuilder sb;
- sb << "recv(): message msgLen " << msgLen << " is invalid. "
- << "Min " << kHeaderSize << " Max: " << MaxMessageSizeBytes;
- const auto str = sb.str();
- LOGV2(4615638,
- "recv(): message msgLen {msgLen} is invalid. Min: {min} Max: {max}",
- "recv(): message mstLen is invalid.",
- "msgLen"_attr = msgLen,
- "min"_attr = kHeaderSize,
- "max"_attr = MaxMessageSizeBytes);
-
- return Future<Message>::makeReady(Status(ErrorCodes::ProtocolError, str));
- }
-
- if (msgLen == kHeaderSize) {
- // This probably isn't a real case since all (current) messages have bodies.
- if (_isIngressSession) {
- networkCounter.hitPhysicalIn(msgLen);
- }
- return Future<Message>::makeReady(Message(std::move(headerBuffer)));
- }
-
- auto buffer = SharedBuffer::allocate(msgLen);
- memcpy(buffer.get(), headerBuffer.get(), kHeaderSize);
-
- MsgData::View msgView(buffer.get());
- return read(asio::buffer(msgView.data(), msgView.dataLen()), baton)
- .then([this, buffer = std::move(buffer), msgLen]() mutable {
- if (_isIngressSession) {
- networkCounter.hitPhysicalIn(msgLen);
- }
- return Message(std::move(buffer));
- });
- });
- }
+ Future<Message> sourceMessageImpl(const BatonHandle& baton = nullptr);
template <typename MutableBufferSequence>
- Future<void> read(const MutableBufferSequence& buffers, const BatonHandle& baton = nullptr) {
- // TODO SERVER-47229 Guard active ops for cancellation here.
-#ifdef MONGO_CONFIG_SSL
- if (_sslSocket) {
- return opportunisticRead(*_sslSocket, buffers, baton);
- } else if (!_ranHandshake) {
- invariant(asio::buffer_size(buffers) >= sizeof(MSGHEADER::Value));
-
- return opportunisticRead(_socket, buffers, baton)
- .then([this, buffers]() mutable {
- _ranHandshake = true;
- return maybeHandshakeSSLForIngress(buffers);
- })
- .then([this, buffers, baton](bool needsRead) mutable {
- if (needsRead) {
- return read(buffers, baton);
- } else {
- return Future<void>::makeReady();
- }
- });
- }
-#endif
- return opportunisticRead(_socket, buffers, baton);
- }
+ Future<void> read(const MutableBufferSequence& buffers, const BatonHandle& baton = nullptr);
template <typename ConstBufferSequence>
- Future<void> write(const ConstBufferSequence& buffers, const BatonHandle& baton = nullptr) {
- // TODO SERVER-47229 Guard active ops for cancellation here.
-#ifdef MONGO_CONFIG_SSL
- _ranHandshake = true;
- if (_sslSocket) {
-#ifdef __linux__
- // We do some trickery in asio (see moreToSend), which appears to work well on linux,
- // but fails on other platforms.
- return opportunisticWrite(*_sslSocket, buffers, baton);
-#else
- if (_blockingMode == Async) {
- // Opportunistic writes are broken for async egress SSL (switching between blocking
- // and non-blocking mode corrupts the TLS exchange).
- return asio::async_write(*_sslSocket, buffers, UseFuture{}).ignoreValue();
- } else {
- return opportunisticWrite(*_sslSocket, buffers, baton);
- }
-#endif
- }
-#endif
- return opportunisticWrite(_socket, buffers, baton);
- }
+ Future<void> write(const ConstBufferSequence& buffers, const BatonHandle& baton = nullptr);
template <typename Stream, typename MutableBufferSequence>
Future<void> opportunisticRead(Stream& stream,
const MutableBufferSequence& buffers,
- const BatonHandle& baton = nullptr) {
- std::error_code ec;
- size_t size;
-
- if (MONGO_unlikely(transportLayerASIOshortOpportunisticReadWrite.shouldFail()) &&
- _blockingMode == Async) {
- asio::mutable_buffer localBuffer = buffers;
-
- if (buffers.size()) {
- localBuffer = asio::mutable_buffer(buffers.data(), 1);
- }
-
- do {
- size = asio::read(stream, localBuffer, ec);
- } while (ec == asio::error::interrupted); // retry syscall EINTR
-
- if (!ec && buffers.size() > 1) {
- ec = asio::error::would_block;
- }
- } else {
- do {
- size = asio::read(stream, buffers, ec);
- } while (ec == asio::error::interrupted); // retry syscall EINTR
- }
-
- if (((ec == asio::error::would_block) || (ec == asio::error::try_again)) &&
- (_blockingMode == Async)) {
- // asio::read is a loop internally, so some of buffers may have been read into already.
- // So we need to adjust the buffers passed into async_read to be offset by size, if
- // size is > 0.
- MutableBufferSequence asyncBuffers(buffers);
- if (size > 0) {
- asyncBuffers += size;
- }
-
- if (auto networkingBaton = baton ? baton->networking() : nullptr;
- networkingBaton && networkingBaton->canWait()) {
- return networkingBaton->addSession(*this, NetworkingBaton::Type::In)
- .onError([](Status error) {
- if (ErrorCodes::isShutdownError(error)) {
- // If the baton has detached, it will cancel its polling. We catch that
- // error here and return Status::OK so that we invoke
- // opportunisticRead() again and switch to asio::async_read() below.
- return Status::OK();
- }
-
- return error;
- })
- .then([&stream, asyncBuffers, baton, this] {
- return opportunisticRead(stream, asyncBuffers, baton);
- });
- }
-
- return asio::async_read(stream, asyncBuffers, UseFuture{}).ignoreValue();
- } else {
- return futurize(ec);
- }
- }
+ const BatonHandle& baton = nullptr);
/**
* moreToSend checks the ssl socket after an opportunisticWrite. If there are still bytes to
@@ -635,204 +207,26 @@ private:
return boost::none;
}
- boost::optional<std::string> getSniName() const override {
- return SSLPeerInfo::forSession(shared_from_this()).sniName;
- }
+ boost::optional<std::string> getSniName() const override;
#endif
template <typename Stream, typename ConstBufferSequence>
Future<void> opportunisticWrite(Stream& stream,
const ConstBufferSequence& buffers,
- const BatonHandle& baton = nullptr) {
- std::error_code ec;
- std::size_t size;
-
- if (MONGO_unlikely(transportLayerASIOshortOpportunisticReadWrite.shouldFail()) &&
- _blockingMode == Async) {
- asio::const_buffer localBuffer = buffers;
-
- if (buffers.size()) {
- localBuffer = asio::const_buffer(buffers.data(), 1);
- }
-
- do {
- size = asio::write(stream, localBuffer, ec);
- } while (ec == asio::error::interrupted); // retry syscall EINTR
- if (!ec && buffers.size() > 1) {
- ec = asio::error::would_block;
- }
- } else {
- do {
- size = asio::write(stream, buffers, ec);
- } while (ec == asio::error::interrupted); // retry syscall EINTR
- }
-
- if (((ec == asio::error::would_block) || (ec == asio::error::try_again)) &&
- (_blockingMode == Async)) {
-
- // asio::write is a loop internally, so some of buffers may have been read into already.
- // So we need to adjust the buffers passed into async_write to be offset by size, if
- // size is > 0.
- ConstBufferSequence asyncBuffers(buffers);
- if (size > 0) {
- asyncBuffers += size;
- }
-
- if (auto more = moreToSend(stream, asyncBuffers, baton)) {
- return std::move(*more);
- }
-
- if (auto networkingBaton = baton ? baton->networking() : nullptr;
- networkingBaton && networkingBaton->canWait()) {
- return networkingBaton->addSession(*this, NetworkingBaton::Type::Out)
- .onError([](Status error) {
- if (ErrorCodes::isCancellationError(error)) {
- // If the baton has detached, it will cancel its polling. We catch that
- // error here and return Status::OK so that we invoke
- // opportunisticWrite() again and switch to asio::async_write() below.
- return Status::OK();
- }
-
- return error;
- })
- .then([&stream, asyncBuffers, baton, this] {
- return opportunisticWrite(stream, asyncBuffers, baton);
- });
- }
-
- return asio::async_write(stream, asyncBuffers, UseFuture{}).ignoreValue();
- } else {
- return futurize(ec);
- }
- }
+ const BatonHandle& baton = nullptr);
#ifdef MONGO_CONFIG_SSL
template <typename MutableBufferSequence>
- Future<bool> maybeHandshakeSSLForIngress(const MutableBufferSequence& buffer) {
- invariant(asio::buffer_size(buffer) >= sizeof(MSGHEADER::Value));
- MSGHEADER::ConstView headerView(asio::buffer_cast<char*>(buffer));
- auto responseTo = headerView.getResponseToMsgId();
-
- if (checkForHTTPRequest(buffer)) {
- return Future<bool>::makeReady(false);
- }
- // This logic was taken from the old mongo/util/net/sock.cpp.
- //
- // It lets us run both TLS and unencrypted mongo over the same port.
- //
- // The first message received from the client should have the responseTo field of the wire
- // protocol message needs to be 0 or -1. Otherwise the connection is either sending
- // garbage or a TLS Hello packet which will be caught by the TLS handshake.
- if (responseTo != 0 && responseTo != -1) {
- if (!_sslContext->ingress) {
- return Future<bool>::makeReady(
- Status(ErrorCodes::SSLHandshakeFailed,
- "SSL handshake received but server is started without SSL support"));
- }
-
- auto tlsAlert = checkTLSRequest(buffer);
- if (tlsAlert) {
- return opportunisticWrite(getSocket(),
- asio::buffer(tlsAlert->data(), tlsAlert->size()))
- .then([] {
- return Future<bool>::makeReady(
- Status(ErrorCodes::SSLHandshakeFailed,
- "SSL handshake failed, as client requested disabled protocol"));
- });
- }
-
- _sslSocket.emplace(std::move(_socket), *_sslContext->ingress, "");
- auto doHandshake = [&] {
- if (_blockingMode == Sync) {
- std::error_code ec;
- _sslSocket->handshake(asio::ssl::stream_base::server, buffer, ec);
- return futurize(ec, asio::buffer_size(buffer));
- } else {
- return _sslSocket->async_handshake(
- asio::ssl::stream_base::server, buffer, UseFuture{});
- }
- };
- return doHandshake().then([this](size_t size) {
- if (_sslSocket->get_sni()) {
- auto sniName = _sslSocket->get_sni().get();
- LOGV2_DEBUG(4908000,
- 2,
- "Client connected with SNI extension",
- "sniName"_attr = sniName);
- } else {
- LOGV2_DEBUG(4908001, 2, "Client connected without SNI extension");
- }
- if (SSLPeerInfo::forSession(shared_from_this()).subjectName.empty()) {
- return getSSLManager()
- ->parseAndValidatePeerCertificate(_sslSocket->native_handle(),
- _sslSocket->get_sni(),
- "",
- _remote,
- nullptr)
- .then([this](SSLPeerInfo info) -> bool {
- SSLPeerInfo::forSession(shared_from_this()) = info;
- return true;
- });
- }
-
- return Future<bool>::makeReady(true);
- });
- } else if (_tl->_sslMode() == SSLParams::SSLMode_requireSSL) {
- uasserted(ErrorCodes::SSLHandshakeFailed,
- "The server is configured to only allow SSL connections");
- } else {
- if (!sslGlobalParams.disableNonSSLConnectionLogging &&
- _tl->_sslMode() == SSLParams::SSLMode_preferSSL) {
- LOGV2(23838,
- "SSL mode is set to 'preferred' and connection {connectionId} to {remote} is "
- "not using SSL.",
- "SSL mode is set to 'preferred' and connection to remote is not using SSL.",
- "connectionId"_attr = id(),
- "remote"_attr = remote());
- }
- return Future<bool>::makeReady(false);
- }
- }
+ Future<bool> maybeHandshakeSSLForIngress(const MutableBufferSequence& buffer);
#endif
template <typename Buffer>
- bool checkForHTTPRequest(const Buffer& buffers) {
- invariant(asio::buffer_size(buffers) >= 4);
- const StringData bufferAsStr(asio::buffer_cast<const char*>(buffers), 4);
- return (bufferAsStr == "GET "_sd);
- }
+ bool checkForHTTPRequest(const Buffer& buffers);
// Called from read() to send an HTTP response back to a client that's trying to use HTTP
// over a native MongoDB port. This returns a Future<Message> to match its only caller, but it
// always contains an error, so it could really return Future<Anything>
- Future<Message> sendHTTPResponse(const BatonHandle& baton = nullptr) {
- constexpr auto userMsg =
- "It looks like you are trying to access MongoDB over HTTP"
- " on the native driver port.\r\n"_sd;
-
- static const std::string httpResp = str::stream() << "HTTP/1.0 200 OK\r\n"
- "Connection: close\r\n"
- "Content-Type: text/plain\r\n"
- "Content-Length: "
- << userMsg.size() << "\r\n\r\n"
- << userMsg;
-
- return write(asio::buffer(httpResp.data(), httpResp.size()), baton)
- .onError(
- [](const Status& status) {
- return Status(
- ErrorCodes::ProtocolError,
- str::stream()
- << "Client sent an HTTP request over a native MongoDB connection, "
- "but there was an error sending a response: "
- << status.toString());
- })
- .then([] {
- return StatusWith<Message>(
- ErrorCodes::ProtocolError,
- "Client sent an HTTP request over a native MongoDB connection");
- });
- }
+ Future<Message> sendHTTPResponse(const BatonHandle& baton = nullptr);
enum BlockingMode {
Unknown,
@@ -862,5 +256,4 @@ private:
bool _isIngressSession;
};
-} // namespace transport
-} // namespace mongo
+} // namespace mongo::transport
diff --git a/src/mongo/transport/transport_layer_asio.cpp b/src/mongo/transport/transport_layer_asio.cpp
index 8323948621e..ae03407e8af 100644
--- a/src/mongo/transport/transport_layer_asio.cpp
+++ b/src/mongo/transport/transport_layer_asio.cpp
@@ -468,7 +468,7 @@ StatusWith<SessionHandle> TransportLayerASIO::connect(
}
std::error_code ec;
- GenericSocket sock(*_egressReactor);
+ ASIOSession::GenericSocket sock(*_egressReactor);
WrappedResolver resolver(*_egressReactor);
Date_t timeBefore = Date_t::now();
@@ -529,7 +529,7 @@ StatusWith<TransportLayerASIO::ASIOSessionHandle> TransportLayerASIO::_doSyncCon
const HostAndPort& peer,
const Milliseconds& timeout,
boost::optional<TransientSSLParams> transientSSLParams) {
- GenericSocket sock(*_egressReactor);
+ ASIOSession::GenericSocket sock(*_egressReactor);
std::error_code ec;
const auto protocol = endpoint->protocol();
@@ -622,7 +622,7 @@ Future<SessionHandle> TransportLayerASIO::asyncConnect(
Promise<SessionHandle> promise;
Mutex mutex = MONGO_MAKE_LATCH(HierarchicalAcquisitionLevel(0), "AsyncConnectState::mutex");
- GenericSocket socket;
+ ASIOSession::GenericSocket socket;
ASIOReactorTimer timeoutTimer;
WrappedResolver resolver;
WrappedEndpoint resolvedEndpoint;
@@ -1185,7 +1185,8 @@ ReactorHandle TransportLayerASIO::getReactor(WhichReactor which) {
}
void TransportLayerASIO::_acceptConnection(GenericAcceptor& acceptor) {
- auto acceptCb = [this, &acceptor](const std::error_code& ec, GenericSocket peerSocket) mutable {
+ auto acceptCb = [this, &acceptor](const std::error_code& ec,
+ ASIOSession::GenericSocket peerSocket) mutable {
if (auto lk = stdx::lock_guard(_mutex); _isShutdown) {
return;
}