diff options
author | Billy Donahue <billy.donahue@mongodb.com> | 2021-07-27 01:50:52 -0400 |
---|---|---|
committer | Evergreen Agent <no-reply@evergreen.mongodb.com> | 2021-08-06 04:15:30 +0000 |
commit | a57e4d409a81be929d4830199797c675322ae164 (patch) | |
tree | a7d41987d1264fab0fdc28914e8148f2912a5b1f | |
parent | 84a1a0599614a8c077da4b3899aba5a647202746 (diff) | |
download | mongo-a57e4d409a81be929d4830199797c675322ae164.tar.gz |
SERVER-58204 physical layout: session_asio.cpp and asio_utils.cpp
- bring GenericSocket into ASIOSession class
- refactor/minimize AsyncHandlerHelper traits
- no logging in headers
- fix sign extension bugs and conform to Winsock select requirements
-rw-r--r-- | src/mongo/transport/SConscript | 2 | ||||
-rw-r--r-- | src/mongo/transport/asio_utils.cpp | 321 | ||||
-rw-r--r-- | src/mongo/transport/asio_utils.h | 437 | ||||
-rw-r--r-- | src/mongo/transport/session_asio.cpp | 720 | ||||
-rw-r--r-- | src/mongo/transport/session_asio.h | 677 | ||||
-rw-r--r-- | src/mongo/transport/transport_layer_asio.cpp | 9 |
6 files changed, 1171 insertions, 995 deletions
diff --git a/src/mongo/transport/SConscript b/src/mongo/transport/SConscript index b0dc247c673..d7898ee475e 100644 --- a/src/mongo/transport/SConscript +++ b/src/mongo/transport/SConscript @@ -53,6 +53,8 @@ tlEnv.Library( target='transport_layer', source=[ 'transport_layer_asio.cpp', + 'asio_utils.cpp', + 'session_asio.cpp', 'transport_options.idl', ], LIBDEPS=[ diff --git a/src/mongo/transport/asio_utils.cpp b/src/mongo/transport/asio_utils.cpp new file mode 100644 index 00000000000..c954dc829eb --- /dev/null +++ b/src/mongo/transport/asio_utils.cpp @@ -0,0 +1,321 @@ +/** + * Copyright (C) 2021-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kNetwork +#include "mongo/transport/asio_utils.h" + +#include "mongo/config.h" +#include "mongo/logv2/log.h" + +namespace mongo::transport { + +Status errorCodeToStatus(const std::error_code& ec) { + if (!ec) + return Status::OK(); + + if (ec == asio::error::operation_aborted) { + return {ErrorCodes::CallbackCanceled, "Callback was canceled"}; + } + +#ifdef _WIN32 + if (ec == asio::error::timed_out) { +#else + if (ec == asio::error::try_again || ec == asio::error::would_block) { +#endif + return {ErrorCodes::NetworkTimeout, "Socket operation timed out"}; + } else if (ec == asio::error::eof) { + return {ErrorCodes::HostUnreachable, "Connection closed by peer"}; + } else if (ec == asio::error::connection_reset) { + return {ErrorCodes::HostUnreachable, "Connection reset by peer"}; + } else if (ec == asio::error::network_reset) { + return {ErrorCodes::HostUnreachable, "Connection reset by network"}; + } + + // If the ec.category() is a mongoErrorCategory() then this error was propogated from + // mongodb code and we should just pass the error cdoe along as-is. + ErrorCodes::Error errorCode = (ec.category() == mongoErrorCategory()) + ? ErrorCodes::Error(ec.value()) + // Otherwise it's an error code from the network and we should pass it along as a + // SocketException + : ErrorCodes::SocketException; + // Either way, include the error message. + return {errorCode, ec.message()}; +} + +template <typename T> +auto toUnsignedEquivalent(T x) { + return static_cast<std::make_unsigned_t<T>>(x); +} + +template <typename Dur> +timeval toTimeval(Dur dur) { + auto sec = duration_cast<Seconds>(dur); + timeval tv{}; + tv.tv_sec = sec.count(); + tv.tv_usec = duration_cast<Microseconds>(dur - sec).count(); + return tv; +} + +StatusWith<unsigned> pollASIOSocket(asio::generic::stream_protocol::socket& socket, + unsigned mask, + Milliseconds timeout) { +#ifdef _WIN32 + // On Windows, use `select` to approximate `poll`. + // Windows `select` has a couple special rules: + // - any empty fd_set args *must* be passed as nullptr. + // - the fd_set args can't *all* be nullptr. + struct FlagFdSet { + unsigned pollFlag; + fd_set fds; + }; + std::array sets{ + FlagFdSet{toUnsignedEquivalent(POLLIN)}, + FlagFdSet{toUnsignedEquivalent(POLLOUT)}, + FlagFdSet{toUnsignedEquivalent(POLLERR)}, + }; + auto fd = socket.native_handle(); + mask |= POLLERR; // Always interested in errors. + for (auto& [pollFlag, fds] : sets) { + FD_ZERO(&fds); + if (mask & pollFlag) + FD_SET(fd, &fds); + } + + auto timeoutTv = toTimeval(timeout); + auto fdsPtr = [&](size_t i) { + fd_set* ptr = &sets[i].fds; + return FD_ISSET(fd, ptr) ? ptr : nullptr; + }; + int result = ::select(fd + 1, fdsPtr(0), fdsPtr(1), fdsPtr(2), &timeoutTv); + if (result == SOCKET_ERROR) { + auto errDesc = errnoWithDescription(WSAGetLastError()); + return {ErrorCodes::InternalError, errDesc}; + } else if (result == 0) { + return {ErrorCodes::NetworkTimeout, "Timed out waiting for poll"}; + } + + unsigned revents = 0; + for (auto& [pollFlag, fds] : sets) + if (FD_ISSET(fd, &fds)) + revents |= pollFlag; + return revents; +#else + pollfd pollItem = {}; + pollItem.fd = socket.native_handle(); + pollItem.events = mask; + + int result; + boost::optional<Date_t> expiration; + if (timeout.count() > 0) { + expiration = Date_t::now() + timeout; + } + do { + Milliseconds curTimeout; + if (expiration) { + curTimeout = *expiration - Date_t::now(); + if (curTimeout.count() <= 0) { + result = 0; + break; + } + } else { + curTimeout = timeout; + } + result = ::poll(&pollItem, 1, curTimeout.count()); + } while (result == -1 && errno == EINTR); + + if (result == -1) { + int errCode = errno; + return {ErrorCodes::InternalError, errnoWithDescription(errCode)}; + } else if (result == 0) { + return {ErrorCodes::NetworkTimeout, "Timed out waiting for poll"}; + } + return toUnsignedEquivalent(pollItem.revents); +#endif +} + +#ifdef MONGO_CONFIG_SSL +boost::optional<std::array<std::uint8_t, 7>> checkTLSRequest(const asio::const_buffer& buffer) { + // This method's caller should have read in at least one MSGHEADER::Value's worth of data. + // The fragment we are about to examine must be strictly smaller. + static const size_t sizeOfTLSFragmentToRead = 11; + invariant(buffer.size() >= sizeOfTLSFragmentToRead); + + static_assert(sizeOfTLSFragmentToRead < sizeof(MSGHEADER::Value), + "checkTLSRequest's caller read a MSGHEADER::Value, which must be larger than " + "message containing the TLS version"); + + /** + * The fragment we are to examine is a record, containing a handshake, containing a + * ClientHello. We wish to examine the advertised protocol version in the ClientHello. + * The following roughly describes the contents of these structures. Note that we do not + * need, or wish to, examine the entire ClientHello, we're looking exclusively for the + * client_version. + * + * Below is a rough description of the payload we will be examining. We shall perform some + * basic checks to ensure the payload matches these expectations. If it does not, we should + * bail out, and not emit protocol version alerts. + * + * enum {alert(21), handshake(22)} ContentType; + * TLSPlaintext { + * ContentType type = handshake(22), + * ProtocolVersion version; // Irrelevant. Clients send the real version in ClientHello. + * uint16 length; + * fragment, see Handshake stuct for contents + * ... + * } + * + * enum {client_hello(1)} HandshakeType; + * Handshake { + * HandshakeType msg_type = client_hello(1); + * uint24_t length; + * ClientHello body; + * } + * + * ClientHello { + * ProtocolVersion client_version; // <- This is the value we want to extract. + * } + */ + + static const std::uint8_t ContentType_handshake = 22; + static const std::uint8_t HandshakeType_client_hello = 1; + + using ProtocolVersion = std::array<std::uint8_t, 2>; + static const ProtocolVersion tls10VersionBytes{3, 1}; + static const ProtocolVersion tls11VersionBytes{3, 2}; + + auto request = reinterpret_cast<const char*>(buffer.data()); + auto cdr = ConstDataRangeCursor(request, request + buffer.size()); + + // Parse the record header. + // Extract the ContentType from the header, and ensure it is a handshake. + StatusWith<std::uint8_t> record_ContentType = cdr.readAndAdvanceNoThrow<std::uint8_t>(); + if (!record_ContentType.isOK() || record_ContentType.getValue() != ContentType_handshake) { + return boost::none; + } + // Skip the record's ProtocolVersion. Clients tend to send TLS 1.0 in + // the record, but then their real protocol version in the enclosed ClientHello. + StatusWith<ProtocolVersion> record_protocol_version = + cdr.readAndAdvanceNoThrow<ProtocolVersion>(); + if (!record_protocol_version.isOK()) { + return boost::none; + } + // Parse the record length. It should be be larger than the remaining expected payload. + auto record_length = cdr.readAndAdvanceNoThrow<BigEndian<std::uint16_t>>(); + if (!record_length.isOK() || record_length.getValue() < cdr.length()) { + return boost::none; + } + + // Parse the handshake header. + // Extract the HandshakeType, and ensure it is a ClientHello. + StatusWith<std::uint8_t> handshake_type = cdr.readAndAdvanceNoThrow<std::uint8_t>(); + if (!handshake_type.isOK() || handshake_type.getValue() != HandshakeType_client_hello) { + return boost::none; + } + // Extract the handshake length, and ensure it is larger than the remaining expected + // payload. This requires a little work because the packet represents it with a uint24_t. + StatusWith<std::array<std::uint8_t, 3>> handshake_length_bytes = + cdr.readAndAdvanceNoThrow<std::array<std::uint8_t, 3>>(); + if (!handshake_length_bytes.isOK()) { + return boost::none; + } + std::uint32_t handshake_length = 0; + for (std::uint8_t handshake_byte : handshake_length_bytes.getValue()) { + handshake_length <<= 8; + handshake_length |= handshake_byte; + } + if (handshake_length < cdr.length()) { + return boost::none; + } + StatusWith<ProtocolVersion> client_version = cdr.readAndAdvanceNoThrow<ProtocolVersion>(); + if (!client_version.isOK()) { + return boost::none; + } + + // Invariant: We read exactly as much data as expected. + invariant((cdr.data() - request) == sizeOfTLSFragmentToRead); + + auto isProtocolDisabled = [](SSLParams::Protocols protocol) { + const auto& params = getSSLGlobalParams(); + return std::find(params.sslDisabledProtocols.begin(), + params.sslDisabledProtocols.end(), + protocol) != params.sslDisabledProtocols.end(); + }; + + auto makeTLSProtocolVersionAlert = + [](const std::array<std::uint8_t, 2>& versionBytes) -> std::array<std::uint8_t, 7> { + /** + * The structure for this alert packet is as follows: + * TLSPlaintext { + * ContentType type = alert(21); + * ProtocolVersion = versionBytes; + * uint16_t length = 2 + * fragment = AlertDescription { + * AlertLevel level = fatal(2); + * AlertDescription = protocol_version(70); + * } + * + */ + return std::array<std::uint8_t, 7>{ + 0x15, versionBytes[0], versionBytes[1], 0x00, 0x02, 0x02, 0x46}; + }; + + ProtocolVersion version = client_version.getValue(); + if (version == tls10VersionBytes && isProtocolDisabled(SSLParams::Protocols::TLS1_0)) { + return makeTLSProtocolVersionAlert(version); + } else if (client_version == tls11VersionBytes && + isProtocolDisabled(SSLParams::Protocols::TLS1_1)) { + return makeTLSProtocolVersionAlert(version); + } + // TLS1.2 cannot be distinguished from TLS1.3, just by looking at the ProtocolVersion bytes. + // TLS 1.3 compatible clients advertise a "supported_versions" extension, which we would + // have to extract here. + // Hopefully by the time this matters, OpenSSL will properly emit protocol_version alerts. + + return boost::none; +} +#endif + +void failedSetSocketOption(const std::system_error& ex, + StringData note, + BSONObj optionDescription) { + LOGV2_INFO(5693100, + "Asio socket.set_option failed with std::system_error", + "note"_attr = note, + "option"_attr = optionDescription, + "error"_attr = [&ex] { + return BSONObjBuilder{} + .append("what", ex.what()) + .append("message", ex.code().message()) + .append("category", ex.code().category().name()) + .append("value", ex.code().value()) + .obj(); + }()); +} + +} // namespace mongo::transport diff --git a/src/mongo/transport/asio_utils.h b/src/mongo/transport/asio_utils.h index e903843d43c..2475e90154b 100644 --- a/src/mongo/transport/asio_utils.h +++ b/src/mongo/transport/asio_utils.h @@ -39,15 +39,16 @@ #include "mongo/base/string_data.h" #include "mongo/base/system_error.h" #include "mongo/config.h" +#include "mongo/stdx/type_traits.h" #include "mongo/util/errno_util.h" #include "mongo/util/future.h" #include "mongo/util/hex.h" #include "mongo/util/net/hostandport.h" #include "mongo/util/net/sockaddr.h" #include "mongo/util/net/ssl_manager.h" +#include "mongo/util/net/ssl_options.h" -namespace mongo { -namespace transport { +namespace mongo::transport { inline SockAddr endpointToSockAddr(const asio::generic::stream_protocol::endpoint& endPoint) { SockAddr wrappedAddr(endPoint.data(), endPoint.size()); @@ -59,38 +60,7 @@ inline HostAndPort endpointToHostAndPort(const asio::generic::stream_protocol::e return HostAndPort(endpointToSockAddr(endPoint).toString(true)); } -inline Status errorCodeToStatus(const std::error_code& ec) { - if (!ec) - return Status::OK(); - - if (ec == asio::error::operation_aborted) { - return {ErrorCodes::CallbackCanceled, "Callback was canceled"}; - } - -#ifdef _WIN32 - if (ec == asio::error::timed_out) { -#else - if (ec == asio::error::try_again || ec == asio::error::would_block) { -#endif - return {ErrorCodes::NetworkTimeout, "Socket operation timed out"}; - } else if (ec == asio::error::eof) { - return {ErrorCodes::HostUnreachable, "Connection closed by peer"}; - } else if (ec == asio::error::connection_reset) { - return {ErrorCodes::HostUnreachable, "Connection reset by peer"}; - } else if (ec == asio::error::network_reset) { - return {ErrorCodes::HostUnreachable, "Connection reset by network"}; - } - - // If the ec.category() is a mongoErrorCategory() then this error was propogated from - // mongodb code and we should just pass the error cdoe along as-is. - ErrorCodes::Error errorCode = (ec.category() == mongoErrorCategory()) - ? ErrorCodes::Error(ec.value()) - // Otherwise it's an error code from the network and we should pass it along as a - // SocketException - : ErrorCodes::SocketException; - // Either way, include the error message. - return {errorCode, ec.message()}; -} +Status errorCodeToStatus(const std::error_code& ec); /* * The ASIO implementation of poll (i.e. socket.wait()) cannot poll for a mask of events, and @@ -103,227 +73,25 @@ inline Status errorCodeToStatus(const std::error_code& ec) { * check whether it matches the expected events mask. * - On error: it returns a Status(ErrorCodes::InternalError) */ -template <typename Socket, typename EventsMask> -StatusWith<EventsMask> pollASIOSocket(Socket& socket, EventsMask mask, Milliseconds timeout) { -#ifdef _WIN32 - fd_set readfds; - fd_set writefds; - fd_set errfds; - - FD_ZERO(&readfds); - FD_ZERO(&writefds); - FD_ZERO(&errfds); - - auto fd = socket.native_handle(); - if (mask & POLLIN) { - FD_SET(fd, &readfds); - } - if (mask & POLLOUT) { - FD_SET(fd, &writefds); - } - FD_SET(fd, &errfds); - - timeval timeoutTv{}; - auto timeoutUs = duration_cast<Microseconds>(timeout); - if (timeoutUs >= Seconds{1}) { - auto timeoutSec = duration_cast<Seconds>(timeoutUs); - timeoutTv.tv_sec = timeoutSec.count(); - timeoutUs -= timeoutSec; - } - timeoutTv.tv_usec = timeoutUs.count(); - int result = ::select(1, &readfds, &writefds, &errfds, &timeoutTv); - if (result == SOCKET_ERROR) { - auto errDesc = errnoWithDescription(WSAGetLastError()); - return {ErrorCodes::InternalError, errDesc}; - } - int revents = (FD_ISSET(fd, &readfds) ? POLLIN : 0) | (FD_ISSET(fd, &writefds) ? POLLOUT : 0) | - (FD_ISSET(fd, &errfds) ? POLLERR : 0); -#else - pollfd pollItem = {}; - pollItem.fd = socket.native_handle(); - pollItem.events = mask; - - int result; - boost::optional<Date_t> expiration; - if (timeout.count() > 0) { - expiration = Date_t::now() + timeout; - } - do { - Milliseconds curTimeout; - if (expiration) { - curTimeout = *expiration - Date_t::now(); - if (curTimeout.count() <= 0) { - result = 0; - break; - } - } else { - curTimeout = timeout; - } - result = ::poll(&pollItem, 1, curTimeout.count()); - } while (result == -1 && errno == EINTR); - - if (result == -1) { - int errCode = errno; - return {ErrorCodes::InternalError, errnoWithDescription(errCode)}; - } - int revents = pollItem.revents; -#endif - - if (result == 0) { - return {ErrorCodes::NetworkTimeout, "Timed out waiting for poll"}; - } else { - return revents; - } -} +StatusWith<unsigned> pollASIOSocket(asio::generic::stream_protocol::socket& socket, + unsigned mask, + Milliseconds timeout); #ifdef MONGO_CONFIG_SSL /** * Peeks at a fragment of a client issued TLS handshake packet. Returns a TLS alert * packet if the client has selected a protocol which has been disabled by the server. */ -template <typename Buffer> -boost::optional<std::array<std::uint8_t, 7>> checkTLSRequest(const Buffer& buffers) { - // This method's caller should have read in at least one MSGHEADER::Value's worth of data. - // The fragment we are about to examine must be strictly smaller. - static const size_t sizeOfTLSFragmentToRead = 11; - invariant(asio::buffer_size(buffers) >= sizeOfTLSFragmentToRead); - - static_assert(sizeOfTLSFragmentToRead < sizeof(MSGHEADER::Value), - "checkTLSRequest's caller read a MSGHEADER::Value, which must be larger than " - "message containing the TLS version"); - - /** - * The fragment we are to examine is a record, containing a handshake, containing a - * ClientHello. We wish to examine the advertised protocol version in the ClientHello. - * The following roughly describes the contents of these structures. Note that we do not - * need, or wish to, examine the entire ClientHello, we're looking exclusively for the - * client_version. - * - * Below is a rough description of the payload we will be examining. We shall perform some - * basic checks to ensure the payload matches these expectations. If it does not, we should - * bail out, and not emit protocol version alerts. - * - * enum {alert(21), handshake(22)} ContentType; - * TLSPlaintext { - * ContentType type = handshake(22), - * ProtocolVersion version; // Irrelevant. Clients send the real version in ClientHello. - * uint16 length; - * fragment, see Handshake stuct for contents - * ... - * } - * - * enum {client_hello(1)} HandshakeType; - * Handshake { - * HandshakeType msg_type = client_hello(1); - * uint24_t length; - * ClientHello body; - * } - * - * ClientHello { - * ProtocolVersion client_version; // <- This is the value we want to extract. - * } - */ - - static const std::uint8_t ContentType_handshake = 22; - static const std::uint8_t HandshakeType_client_hello = 1; - - using ProtocolVersion = std::array<std::uint8_t, 2>; - static const ProtocolVersion tls10VersionBytes{3, 1}; - static const ProtocolVersion tls11VersionBytes{3, 2}; - - auto request = asio::buffer_cast<const char*>(buffers); - auto cdr = ConstDataRangeCursor(request, request + asio::buffer_size(buffers)); - - // Parse the record header. - // Extract the ContentType from the header, and ensure it is a handshake. - StatusWith<std::uint8_t> record_ContentType = cdr.readAndAdvanceNoThrow<std::uint8_t>(); - if (!record_ContentType.isOK() || record_ContentType.getValue() != ContentType_handshake) { - return boost::none; - } - // Skip the record's ProtocolVersion. Clients tend to send TLS 1.0 in - // the record, but then their real protocol version in the enclosed ClientHello. - StatusWith<ProtocolVersion> record_protocol_version = - cdr.readAndAdvanceNoThrow<ProtocolVersion>(); - if (!record_protocol_version.isOK()) { - return boost::none; - } - // Parse the record length. It should be be larger than the remaining expected payload. - auto record_length = cdr.readAndAdvanceNoThrow<BigEndian<std::uint16_t>>(); - if (!record_length.isOK() || record_length.getValue() < cdr.length()) { - return boost::none; - } - - // Parse the handshake header. - // Extract the HandshakeType, and ensure it is a ClientHello. - StatusWith<std::uint8_t> handshake_type = cdr.readAndAdvanceNoThrow<std::uint8_t>(); - if (!handshake_type.isOK() || handshake_type.getValue() != HandshakeType_client_hello) { - return boost::none; - } - // Extract the handshake length, and ensure it is larger than the remaining expected - // payload. This requires a little work because the packet represents it with a uint24_t. - StatusWith<std::array<std::uint8_t, 3>> handshake_length_bytes = - cdr.readAndAdvanceNoThrow<std::array<std::uint8_t, 3>>(); - if (!handshake_length_bytes.isOK()) { - return boost::none; - } - std::uint32_t handshake_length = 0; - for (std::uint8_t handshake_byte : handshake_length_bytes.getValue()) { - handshake_length <<= 8; - handshake_length |= handshake_byte; - } - if (handshake_length < cdr.length()) { - return boost::none; - } - StatusWith<ProtocolVersion> client_version = cdr.readAndAdvanceNoThrow<ProtocolVersion>(); - if (!client_version.isOK()) { - return boost::none; - } - - // Invariant: We read exactly as much data as expected. - invariant((cdr.data() - request) == sizeOfTLSFragmentToRead); - - auto isProtocolDisabled = [](SSLParams::Protocols protocol) { - const auto& params = getSSLGlobalParams(); - return std::find(params.sslDisabledProtocols.begin(), - params.sslDisabledProtocols.end(), - protocol) != params.sslDisabledProtocols.end(); - }; - - auto makeTLSProtocolVersionAlert = - [](const std::array<std::uint8_t, 2>& versionBytes) -> std::array<std::uint8_t, 7> { - /** - * The structure for this alert packet is as follows: - * TLSPlaintext { - * ContentType type = alert(21); - * ProtocolVersion = versionBytes; - * uint16_t length = 2 - * fragment = AlertDescription { - * AlertLevel level = fatal(2); - * AlertDescription = protocol_version(70); - * } - * - */ - return std::array<std::uint8_t, 7>{ - 0x15, versionBytes[0], versionBytes[1], 0x00, 0x02, 0x02, 0x46}; - }; - - ProtocolVersion version = client_version.getValue(); - if (version == tls10VersionBytes && isProtocolDisabled(SSLParams::Protocols::TLS1_0)) { - return makeTLSProtocolVersionAlert(version); - } else if (client_version == tls11VersionBytes && - isProtocolDisabled(SSLParams::Protocols::TLS1_1)) { - return makeTLSProtocolVersionAlert(version); - } - // TLS1.2 cannot be distinguished from TLS1.3, just by looking at the ProtocolVersion bytes. - // TLS 1.3 compatible clients advertise a "supported_versions" extension, which we would - // have to extract here. - // Hopefully by the time this matters, OpenSSL will properly emit protocol_version alerts. - - return boost::none; -} +boost::optional<std::array<std::uint8_t, 7>> checkTLSRequest(const asio::const_buffer& buffer); #endif /** + * setSocketOption failed. Log the error. + * This is in the .cpp file just to keep LOGV2 out of this header. + */ +void failedSetSocketOption(const std::system_error& ex, StringData note, BSONObj optionDescription); + +/** * Calls Asio `socket.set_option(opt)` with better failure diagnostics. * To be used instead of Asio `socket.set_option``, because errors are hard to diagnose. * Emits a log message about what option was attempted and what went wrong with @@ -342,26 +110,14 @@ void setSocketOption(Socket& socket, const Option& opt, StringData note) { try { socket.set_option(opt); } catch (const std::system_error& ex) { - LOGV2_INFO(5693100, - "Asio socket.set_option failed with std::system_error", - "note"_attr = note, - "option"_attr = - [&opt, p = socket.local_endpoint().protocol()] { - return BSONObjBuilder{} - .append("level", opt.level(p)) - .append("name", opt.name(p)) - .append("data", hexdump(opt.data(p), opt.size(p))) - .obj(); - }(), - "error"_attr = - [&ex] { - return BSONObjBuilder{} - .append("what", ex.what()) - .append("message", ex.code().message()) - .append("category", ex.code().category().name()) - .append("value", ex.code().value()) - .obj(); - }()); + BSONObj optionDescription = [&opt, p = socket.local_endpoint().protocol()] { + return BSONObjBuilder{} + .append("level", opt.level(p)) + .append("name", opt.name(p)) + .append("data", hexdump(opt.data(p), opt.size(p))) + .obj(); + }(); + failedSetSocketOption(ex, note, optionDescription); throw; } } @@ -390,106 +146,89 @@ void setSocketOption(Socket& socket, const Option& opt, std::error_code& ec, Str * Example: * Future<size_t> future = my_socket.async_read_some(my_buffer, UseFuture{}); */ -struct UseFuture {}; - -namespace use_future_details { - -template <typename... Args> -struct AsyncHandlerHelper { - using Result = std::tuple<Args...>; - static void complete(Promise<Result>* promise, Args... args) { - promise->emplaceValue(args...); - } +struct UseFuture { + template <typename... Args> + class Adapter; }; -template <> -struct AsyncHandlerHelper<> { - using Result = void; - static void complete(Promise<Result>* promise) { - promise->emplaceValue(); - } -}; - -template <typename Arg> -struct AsyncHandlerHelper<Arg> { - using Result = Arg; - static void complete(Promise<Result>* promise, Arg arg) { - promise->emplaceValue(arg); - } -}; +template <typename... ArgsFromAsio> +class UseFuture::Adapter { +private: + template <typename Dum, typename... Ts> + struct ArgPack : stdx::type_identity<std::tuple<Ts...>> {}; + template <typename Dum> + struct ArgPack<Dum> : stdx::type_identity<void> {}; + template <typename Dum, typename T> + struct ArgPack<Dum, T> : stdx::type_identity<T> {}; -template <typename... Args> -struct AsyncHandlerHelper<std::error_code, Args...> { - using Helper = AsyncHandlerHelper<Args...>; - using Result = typename Helper::Result; - - template <typename... Args2> - static void complete(Promise<Result>* promise, std::error_code ec, Args2&&... args) { - if (ec) { - promise->setError(errorCodeToStatus(ec)); - } else { - Helper::complete(promise, std::forward<Args2>(args)...); + /** + * If an Asio callback takes a leading error_code, it's stripped from + * the Future's value_type. Any errors reported by Asio will instead + * be delivered by setting the Future's error Status. + */ + template <typename Dum, typename... Ts> + struct StripError : ArgPack<Dum, Ts...> {}; + template <typename Dum, typename... Ts> + struct StripError<Dum, std::error_code, Ts...> : ArgPack<Dum, Ts...> {}; + + using Result = typename StripError<void, ArgsFromAsio...>::type; + + struct Handler { + private: + template <typename... As> + void _onSuccess(As&&... args) { + promise.emplaceValue(std::forward<As>(args)...); } - } -}; - -template <> -struct AsyncHandlerHelper<std::error_code> { - using Result = void; - static void complete(Promise<Result>* promise, std::error_code ec) { - if (ec) { - promise->setError(errorCodeToStatus(ec)); - } else { - promise->emplaceValue(); + template <typename... As> + void _onInvoke(As&&... args) { + _onSuccess(std::forward<As>(args)...); + } + template <typename... As> + void _onInvoke(std::error_code ec, As&&... args) { + if (ec) { + promise.setError(errorCodeToStatus(ec)); + return; + } + _onSuccess(std::forward<As>(args)...); } - } -}; - -template <typename... Args> -struct AsyncHandler { - using Helper = AsyncHandlerHelper<Args...>; - using Result = typename Helper::Result; - explicit AsyncHandler(UseFuture) {} + public: + explicit Handler(const UseFuture&) {} - template <typename... Args2> - void operator()(Args2&&... args) { - Helper::complete(&promise, std::forward<Args2>(args)...); - } + template <typename... As> + void operator()(As&&... args) { + static_assert((std::is_same_v<std::decay_t<As>, std::decay_t<ArgsFromAsio>> && ...), + "Unexpected argument list from Asio async result callback."); + _onInvoke(std::forward<As>(args)...); + } - Promise<Result> promise; -}; + Promise<Result> promise; + }; -template <typename... Args> -struct AsyncResult { - using completion_handler_type = AsyncHandler<Args...>; - using RealResult = typename AsyncHandler<Args...>::Result; - using return_type = Future<RealResult>; +public: + using return_type = Future<Result>; + using completion_handler_type = Handler; - explicit AsyncResult(completion_handler_type& handler) { - auto pf = makePromiseFuture<RealResult>(); - fut = std::move(pf.future); - handler.promise = std::move(pf.promise); + explicit Adapter(Handler& handler) { + auto&& [p, f] = makePromiseFuture<Result>(); + _fut = std::move(f); + handler.promise = std::move(p); } - auto get() { - return std::move(fut); + return_type get() { + return std::move(_fut); } - Future<RealResult> fut; +private: + Future<Result> _fut; }; -} // namespace use_future_details -} // namespace transport -} // namespace mongo +} // namespace mongo::transport namespace asio { -template <typename Comp, typename Sig> -class async_result; - -template <typename Result, typename... Args> -class async_result<::mongo::transport::UseFuture, Result(Args...)> - : public ::mongo::transport::use_future_details::AsyncResult<Args...> { - using ::mongo::transport::use_future_details::AsyncResult<Args...>::AsyncResult; +template <typename... Args> +class async_result<mongo::transport::UseFuture, void(Args...)> + : public mongo::transport::UseFuture::Adapter<Args...> { + using mongo::transport::UseFuture::Adapter<Args...>::Adapter; }; } // namespace asio diff --git a/src/mongo/transport/session_asio.cpp b/src/mongo/transport/session_asio.cpp new file mode 100644 index 00000000000..b4f27331874 --- /dev/null +++ b/src/mongo/transport/session_asio.cpp @@ -0,0 +1,720 @@ +/** + * Copyright (C) 2021-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kNetwork + +#include "mongo/transport/session_asio.h" + +#include "mongo/config.h" +#include "mongo/logv2/log.h" + +namespace mongo::transport { + +MONGO_FAIL_POINT_DEFINE(transportLayerASIOshortOpportunisticReadWrite); + +namespace { + +template <int Name> +class ASIOSocketTimeoutOption { +public: +#ifdef _WIN32 + using TimeoutType = DWORD; + + ASIOSocketTimeoutOption(Milliseconds timeoutVal) : _timeout(timeoutVal.count()) {} + +#else + using TimeoutType = timeval; + + ASIOSocketTimeoutOption(Milliseconds timeoutVal) { + _timeout.tv_sec = duration_cast<Seconds>(timeoutVal).count(); + const auto minusSeconds = timeoutVal - Seconds{_timeout.tv_sec}; + _timeout.tv_usec = duration_cast<Microseconds>(minusSeconds).count(); + } +#endif + + template <typename Protocol> + int name(const Protocol&) const { + return Name; + } + + template <typename Protocol> + const TimeoutType* data(const Protocol&) const { + return &_timeout; + } + + template <typename Protocol> + std::size_t size(const Protocol&) const { + return sizeof(_timeout); + } + + template <typename Protocol> + int level(const Protocol&) const { + return SOL_SOCKET; + } + +private: + TimeoutType _timeout; +}; + +} // namespace + + +TransportLayerASIO::ASIOSession::ASIOSession( + TransportLayerASIO* tl, + GenericSocket socket, + bool isIngressSession, + Endpoint endpoint, + std::shared_ptr<const SSLConnectionContext> transientSSLContext) try + : _socket(std::move(socket)), + _tl(tl), + _isIngressSession(isIngressSession) { + auto family = endpointToSockAddr(_socket.local_endpoint()).getType(); + if (family == AF_INET || family == AF_INET6) { + setSocketOption(_socket, asio::ip::tcp::no_delay(true), "session no delay"); + setSocketOption(_socket, asio::socket_base::keep_alive(true), "session keep alive"); + setSocketKeepAliveParams(_socket.native_handle()); + } + + _localAddr = endpointToSockAddr(_socket.local_endpoint()); + + if (endpoint == Endpoint()) { + // Inbound connection, query socket for remote. + _remoteAddr = endpointToSockAddr(_socket.remote_endpoint()); + } else { + // Outbound connection, get remote from resolved endpoint. + // Necessary for TCP_FASTOPEN where the remote isn't connected yet. + _remoteAddr = endpointToSockAddr(endpoint); + } + + _local = HostAndPort(_localAddr.toString(true)); + _remote = HostAndPort(_remoteAddr.toString(true)); +#ifdef MONGO_CONFIG_SSL + _sslContext = transientSSLContext ? transientSSLContext : *tl->_sslContext; + if (transientSSLContext) { + logv2::DynamicAttributes attrs; + if (transientSSLContext->targetClusterURI) { + attrs.add("targetClusterURI", *transientSSLContext->targetClusterURI); + } + attrs.add("isIngress", isIngressSession); + attrs.add("connectionId", id()); + attrs.add("remote", remote()); + LOGV2(5271001, "Initializing the ASIOSession with transient SSL context", attrs); + } +#endif +} catch (const DBException&) { + throw; +} catch (const asio::system_error& error) { + uasserted(ErrorCodes::SocketException, error.what()); +} catch (...) { + uasserted(50797, str::stream() << "Unknown exception while configuring socket."); +} + +void TransportLayerASIO::ASIOSession::end() { + if (getSocket().is_open()) { + std::error_code ec; + getSocket().shutdown(GenericSocket::shutdown_both, ec); + if ((ec) && (ec != asio::error::not_connected)) { + LOGV2_ERROR(23841, + "Error shutting down socket: {error}", + "Error shutting down socket", + "error"_attr = ec.message()); + } + } +} + +StatusWith<Message> TransportLayerASIO::ASIOSession::sourceMessage() noexcept try { + ensureSync(); + return sourceMessageImpl().getNoThrow(); +} catch (const DBException& ex) { + return ex.toStatus(); +} + +Future<Message> TransportLayerASIO::ASIOSession::asyncSourceMessage( + const BatonHandle& baton) noexcept try { + ensureAsync(); + return sourceMessageImpl(baton); +} catch (const DBException& ex) { + return ex.toStatus(); +} + +Status TransportLayerASIO::ASIOSession::waitForData() noexcept try { + ensureSync(); + asio::error_code ec; + getSocket().wait(asio::ip::tcp::socket::wait_read, ec); + return errorCodeToStatus(ec); +} catch (const DBException& ex) { + return ex.toStatus(); +} + +Future<void> TransportLayerASIO::ASIOSession::asyncWaitForData() noexcept try { + ensureAsync(); + return getSocket().async_wait(asio::ip::tcp::socket::wait_read, UseFuture{}); +} catch (const DBException& ex) { + return ex.toStatus(); +} + +Status TransportLayerASIO::ASIOSession::sinkMessage(Message message) noexcept try { + ensureSync(); + + return write(asio::buffer(message.buf(), message.size())) + .then([this, &message] { + if (_isIngressSession) { + networkCounter.hitPhysicalOut(message.size()); + } + }) + .getNoThrow(); +} catch (const DBException& ex) { + return ex.toStatus(); +} + +Future<void> TransportLayerASIO::ASIOSession::asyncSinkMessage( + Message message, const BatonHandle& baton) noexcept try { + ensureAsync(); + return write(asio::buffer(message.buf(), message.size()), baton) + .then([this, message /*keep the buffer alive*/]() { + if (_isIngressSession) { + networkCounter.hitPhysicalOut(message.size()); + } + }); +} catch (const DBException& ex) { + return ex.toStatus(); +} + +void TransportLayerASIO::ASIOSession::cancelAsyncOperations(const BatonHandle& baton) { + LOGV2_DEBUG(4615608, + 3, + "Cancelling outstanding I/O operations on connection to {remote}", + "Cancelling outstanding I/O operations on connection to remote", + "remote"_attr = _remote); + if (baton && baton->networking() && baton->networking()->cancelSession(*this)) { + // If we have a baton, it was for networking, and it owned our session, then we're done. + return; + } + + getSocket().cancel(); +} + +void TransportLayerASIO::ASIOSession::setTimeout(boost::optional<Milliseconds> timeout) { + invariant(!timeout || timeout->count() > 0); + _configuredTimeout = timeout; +} + +bool TransportLayerASIO::ASIOSession::isConnected() { + // socket.is_open() only returns whether the socket is a valid file descriptor and + // if we haven't marked this socket as closed already. + if (!getSocket().is_open()) + return false; + + auto swPollEvents = pollASIOSocket(getSocket(), POLLIN, Milliseconds{0}); + if (!swPollEvents.isOK()) { + if (swPollEvents != ErrorCodes::NetworkTimeout) { + LOGV2_WARNING(4615609, + "Failed to poll socket for connectivity check: {error}", + "Failed to poll socket for connectivity check", + "error"_attr = swPollEvents.getStatus()); + return false; + } + return true; + } + + auto revents = swPollEvents.getValue(); + if (revents & POLLIN) { + char testByte; + int size = ::recv(getSocket().native_handle(), &testByte, sizeof(testByte), MSG_PEEK); + if (size == sizeof(testByte)) { + return true; + } else if (size == -1) { + LOGV2_WARNING(4615610, + "Failed to check socket connectivity: {error}", + "Failed to check socket connectivity", + "error"_attr = errnoWithDescription(errno)); + } + // If size == 0 then we got disconnected and we should return false. + } + + return false; +} + +#ifdef MONGO_CONFIG_SSL +const SSLConfiguration* TransportLayerASIO::ASIOSession::getSSLConfiguration() const { + if (_sslContext->manager) { + return &_sslContext->manager->getSSLConfiguration(); + } + return nullptr; +} + +const std::shared_ptr<SSLManagerInterface> TransportLayerASIO::ASIOSession::getSSLManager() const { + return _sslContext->manager; +} + +// The unique_lock here is held by TransportLayerASIO to synchronize with the asyncConnect +// timeout callback. It will be unlocked before the SSL actually handshake begins. +Future<void> TransportLayerASIO::ASIOSession::handshakeSSLForEgressWithLock( + stdx::unique_lock<Latch> lk, const HostAndPort& target, const ReactorHandle& reactor) { + if (!_sslContext->egress) { + return Future<void>::makeReady( + Status(ErrorCodes::SSLHandshakeFailed, "SSL requested but SSL support is disabled")); + } + + _sslSocket.emplace(std::move(_socket), *_sslContext->egress, removeFQDNRoot(target.host())); + lk.unlock(); + + auto doHandshake = [&] { + if (_blockingMode == Sync) { + std::error_code ec; + _sslSocket->handshake(asio::ssl::stream_base::client, ec); + return futurize(ec); + } else { + return _sslSocket->async_handshake(asio::ssl::stream_base::client, UseFuture{}); + } + }; + return doHandshake().then([this, target, reactor] { + _ranHandshake = true; + + return getSSLManager() + ->parseAndValidatePeerCertificate( + _sslSocket->native_handle(), _sslSocket->get_sni(), target.host(), target, reactor) + .then([this](SSLPeerInfo info) { SSLPeerInfo::forSession(shared_from_this()) = info; }); + }); +} + +// For synchronous connections where we don't have an async timer, just take a dummy lock and +// pass it to the WithLock version of handshakeSSLForEgress +Future<void> TransportLayerASIO::ASIOSession::handshakeSSLForEgress(const HostAndPort& target) { + auto mutex = MONGO_MAKE_LATCH(); + return handshakeSSLForEgressWithLock(stdx::unique_lock<Latch>(mutex), target, nullptr); +} +#endif + +void TransportLayerASIO::ASIOSession::ensureSync() { + asio::error_code ec; + if (_blockingMode != Sync) { + getSocket().non_blocking(false, ec); + fassert(40490, errorCodeToStatus(ec)); + _blockingMode = Sync; + } + + if (_socketTimeout != _configuredTimeout) { + // Change boost::none (which means no timeout) into a zero value for the socket option, + // which also means no timeout. + auto timeout = _configuredTimeout.value_or(Milliseconds{0}); + setSocketOption( + getSocket(), ASIOSocketTimeoutOption<SO_SNDTIMEO>(timeout), ec, "session send timeout"); + if (auto status = errorCodeToStatus(ec); !status.isOK()) { + tasserted(5342000, status.reason()); + } + + setSocketOption(getSocket(), + ASIOSocketTimeoutOption<SO_RCVTIMEO>(timeout), + ec, + "session receive timeout"); + if (auto status = errorCodeToStatus(ec); !status.isOK()) { + tasserted(5342001, status.reason()); + } + + _socketTimeout = _configuredTimeout; + } +} + +void TransportLayerASIO::ASIOSession::ensureAsync() { + if (_blockingMode == Async) + return; + + // Socket timeouts currently only effect synchronous calls, so make sure the caller isn't + // expecting a socket timeout when they do an async operation. + invariant(!_configuredTimeout); + + asio::error_code ec; + getSocket().non_blocking(true, ec); + fassert(50706, errorCodeToStatus(ec)); + _blockingMode = Async; +} + +auto TransportLayerASIO::ASIOSession::getSocket() -> GenericSocket& { +#ifdef MONGO_CONFIG_SSL + if (_sslSocket) { + return static_cast<GenericSocket&>(_sslSocket->lowest_layer()); + } +#endif + return _socket; +} + +Future<Message> TransportLayerASIO::ASIOSession::sourceMessageImpl(const BatonHandle& baton) { + static constexpr auto kHeaderSize = sizeof(MSGHEADER::Value); + + auto headerBuffer = SharedBuffer::allocate(kHeaderSize); + auto ptr = headerBuffer.get(); + return read(asio::buffer(ptr, kHeaderSize), baton) + .then([headerBuffer = std::move(headerBuffer), this, baton]() mutable { + if (checkForHTTPRequest(asio::buffer(headerBuffer.get(), kHeaderSize))) { + return sendHTTPResponse(baton); + } + + const auto msgLen = size_t(MSGHEADER::View(headerBuffer.get()).getMessageLength()); + if (msgLen < kHeaderSize || msgLen > MaxMessageSizeBytes) { + StringBuilder sb; + sb << "recv(): message msgLen " << msgLen << " is invalid. " + << "Min " << kHeaderSize << " Max: " << MaxMessageSizeBytes; + const auto str = sb.str(); + LOGV2(4615638, + "recv(): message msgLen {msgLen} is invalid. Min: {min} Max: {max}", + "recv(): message mstLen is invalid.", + "msgLen"_attr = msgLen, + "min"_attr = kHeaderSize, + "max"_attr = MaxMessageSizeBytes); + + return Future<Message>::makeReady(Status(ErrorCodes::ProtocolError, str)); + } + + if (msgLen == kHeaderSize) { + // This probably isn't a real case since all (current) messages have bodies. + if (_isIngressSession) { + networkCounter.hitPhysicalIn(msgLen); + } + return Future<Message>::makeReady(Message(std::move(headerBuffer))); + } + + auto buffer = SharedBuffer::allocate(msgLen); + memcpy(buffer.get(), headerBuffer.get(), kHeaderSize); + + MsgData::View msgView(buffer.get()); + return read(asio::buffer(msgView.data(), msgView.dataLen()), baton) + .then([this, buffer = std::move(buffer), msgLen]() mutable { + if (_isIngressSession) { + networkCounter.hitPhysicalIn(msgLen); + } + return Message(std::move(buffer)); + }); + }); +} + +template <typename MutableBufferSequence> +Future<void> TransportLayerASIO::ASIOSession::read(const MutableBufferSequence& buffers, + const BatonHandle& baton) { + // TODO SERVER-47229 Guard active ops for cancellation here. +#ifdef MONGO_CONFIG_SSL + if (_sslSocket) { + return opportunisticRead(*_sslSocket, buffers, baton); + } else if (!_ranHandshake) { + invariant(asio::buffer_size(buffers) >= sizeof(MSGHEADER::Value)); + + return opportunisticRead(_socket, buffers, baton) + .then([this, buffers]() mutable { + _ranHandshake = true; + return maybeHandshakeSSLForIngress(buffers); + }) + .then([this, buffers, baton](bool needsRead) mutable { + if (needsRead) { + return read(buffers, baton); + } else { + return Future<void>::makeReady(); + } + }); + } +#endif + return opportunisticRead(_socket, buffers, baton); +} + +template <typename ConstBufferSequence> +Future<void> TransportLayerASIO::ASIOSession::write(const ConstBufferSequence& buffers, + const BatonHandle& baton) { + // TODO SERVER-47229 Guard active ops for cancellation here. +#ifdef MONGO_CONFIG_SSL + _ranHandshake = true; + if (_sslSocket) { +#ifdef __linux__ + // We do some trickery in asio (see moreToSend), which appears to work well on linux, + // but fails on other platforms. + return opportunisticWrite(*_sslSocket, buffers, baton); +#else + if (_blockingMode == Async) { + // Opportunistic writes are broken for async egress SSL (switching between blocking + // and non-blocking mode corrupts the TLS exchange). + return asio::async_write(*_sslSocket, buffers, UseFuture{}).ignoreValue(); + } else { + return opportunisticWrite(*_sslSocket, buffers, baton); + } +#endif + } +#endif // MONGO_CONFIG_SSL + return opportunisticWrite(_socket, buffers, baton); +} + +template <typename Stream, typename MutableBufferSequence> +Future<void> TransportLayerASIO::ASIOSession::opportunisticRead( + Stream& stream, const MutableBufferSequence& buffers, const BatonHandle& baton) { + std::error_code ec; + size_t size; + + if (MONGO_unlikely(transportLayerASIOshortOpportunisticReadWrite.shouldFail()) && + _blockingMode == Async) { + asio::mutable_buffer localBuffer = buffers; + + if (buffers.size()) { + localBuffer = asio::mutable_buffer(buffers.data(), 1); + } + + do { + size = asio::read(stream, localBuffer, ec); + } while (ec == asio::error::interrupted); // retry syscall EINTR + + if (!ec && buffers.size() > 1) { + ec = asio::error::would_block; + } + } else { + do { + size = asio::read(stream, buffers, ec); + } while (ec == asio::error::interrupted); // retry syscall EINTR + } + + if (((ec == asio::error::would_block) || (ec == asio::error::try_again)) && + (_blockingMode == Async)) { + // asio::read is a loop internally, so some of buffers may have been read into already. + // So we need to adjust the buffers passed into async_read to be offset by size, if + // size is > 0. + MutableBufferSequence asyncBuffers(buffers); + if (size > 0) { + asyncBuffers += size; + } + + if (auto networkingBaton = baton ? baton->networking() : nullptr; + networkingBaton && networkingBaton->canWait()) { + return networkingBaton->addSession(*this, NetworkingBaton::Type::In) + .onError([](Status error) { + if (ErrorCodes::isShutdownError(error)) { + // If the baton has detached, it will cancel its polling. We catch that + // error here and return Status::OK so that we invoke + // opportunisticRead() again and switch to asio::async_read() below. + return Status::OK(); + } + + return error; + }) + .then([&stream, asyncBuffers, baton, this] { + return opportunisticRead(stream, asyncBuffers, baton); + }); + } + + return asio::async_read(stream, asyncBuffers, UseFuture{}).ignoreValue(); + } else { + return futurize(ec); + } +} + +#ifdef MONGO_CONFIG_SSL +boost::optional<std::string> TransportLayerASIO::ASIOSession::getSniName() const { + return SSLPeerInfo::forSession(shared_from_this()).sniName; +} +#endif + +template <typename Stream, typename ConstBufferSequence> +Future<void> TransportLayerASIO::ASIOSession::opportunisticWrite(Stream& stream, + const ConstBufferSequence& buffers, + const BatonHandle& baton) { + std::error_code ec; + std::size_t size; + + if (MONGO_unlikely(transportLayerASIOshortOpportunisticReadWrite.shouldFail()) && + _blockingMode == Async) { + asio::const_buffer localBuffer = buffers; + + if (buffers.size()) { + localBuffer = asio::const_buffer(buffers.data(), 1); + } + + do { + size = asio::write(stream, localBuffer, ec); + } while (ec == asio::error::interrupted); // retry syscall EINTR + if (!ec && buffers.size() > 1) { + ec = asio::error::would_block; + } + } else { + do { + size = asio::write(stream, buffers, ec); + } while (ec == asio::error::interrupted); // retry syscall EINTR + } + + if (((ec == asio::error::would_block) || (ec == asio::error::try_again)) && + (_blockingMode == Async)) { + + // asio::write is a loop internally, so some of buffers may have been read into already. + // So we need to adjust the buffers passed into async_write to be offset by size, if + // size is > 0. + ConstBufferSequence asyncBuffers(buffers); + if (size > 0) { + asyncBuffers += size; + } + + if (auto more = moreToSend(stream, asyncBuffers, baton)) { + return std::move(*more); + } + + if (auto networkingBaton = baton ? baton->networking() : nullptr; + networkingBaton && networkingBaton->canWait()) { + return networkingBaton->addSession(*this, NetworkingBaton::Type::Out) + .onError([](Status error) { + if (ErrorCodes::isCancellationError(error)) { + // If the baton has detached, it will cancel its polling. We catch that + // error here and return Status::OK so that we invoke + // opportunisticWrite() again and switch to asio::async_write() below. + return Status::OK(); + } + + return error; + }) + .then([&stream, asyncBuffers, baton, this] { + return opportunisticWrite(stream, asyncBuffers, baton); + }); + } + + return asio::async_write(stream, asyncBuffers, UseFuture{}).ignoreValue(); + } else { + return futurize(ec); + } +} + +#ifdef MONGO_CONFIG_SSL +template <typename MutableBufferSequence> +Future<bool> TransportLayerASIO::ASIOSession::maybeHandshakeSSLForIngress( + const MutableBufferSequence& buffer) { + invariant(asio::buffer_size(buffer) >= sizeof(MSGHEADER::Value)); + MSGHEADER::ConstView headerView(asio::buffer_cast<char*>(buffer)); + auto responseTo = headerView.getResponseToMsgId(); + + if (checkForHTTPRequest(buffer)) { + return Future<bool>::makeReady(false); + } + // This logic was taken from the old mongo/util/net/sock.cpp. + // + // It lets us run both TLS and unencrypted mongo over the same port. + // + // The first message received from the client should have the responseTo field of the wire + // protocol message needs to be 0 or -1. Otherwise the connection is either sending + // garbage or a TLS Hello packet which will be caught by the TLS handshake. + if (responseTo != 0 && responseTo != -1) { + if (!_sslContext->ingress) { + return Future<bool>::makeReady( + Status(ErrorCodes::SSLHandshakeFailed, + "SSL handshake received but server is started without SSL support")); + } + + auto tlsAlert = checkTLSRequest(buffer); + if (tlsAlert) { + return opportunisticWrite(getSocket(), asio::buffer(tlsAlert->data(), tlsAlert->size())) + .then([] { + return Future<bool>::makeReady( + Status(ErrorCodes::SSLHandshakeFailed, + "SSL handshake failed, as client requested disabled protocol")); + }); + } + + _sslSocket.emplace(std::move(_socket), *_sslContext->ingress, ""); + auto doHandshake = [&] { + if (_blockingMode == Sync) { + std::error_code ec; + _sslSocket->handshake(asio::ssl::stream_base::server, buffer, ec); + return futurize(ec, asio::buffer_size(buffer)); + } else { + return _sslSocket->async_handshake( + asio::ssl::stream_base::server, buffer, UseFuture{}); + } + }; + return doHandshake().then([this](size_t size) { + if (_sslSocket->get_sni()) { + auto sniName = _sslSocket->get_sni().get(); + LOGV2_DEBUG( + 4908000, 2, "Client connected with SNI extension", "sniName"_attr = sniName); + } else { + LOGV2_DEBUG(4908001, 2, "Client connected without SNI extension"); + } + if (SSLPeerInfo::forSession(shared_from_this()).subjectName.empty()) { + return getSSLManager() + ->parseAndValidatePeerCertificate( + _sslSocket->native_handle(), _sslSocket->get_sni(), "", _remote, nullptr) + .then([this](SSLPeerInfo info) -> bool { + SSLPeerInfo::forSession(shared_from_this()) = info; + return true; + }); + } + + return Future<bool>::makeReady(true); + }); + } else if (_tl->_sslMode() == SSLParams::SSLMode_requireSSL) { + uasserted(ErrorCodes::SSLHandshakeFailed, + "The server is configured to only allow SSL connections"); + } else { + if (!sslGlobalParams.disableNonSSLConnectionLogging && + _tl->_sslMode() == SSLParams::SSLMode_preferSSL) { + LOGV2(23838, + "SSL mode is set to 'preferred' and connection {connectionId} to {remote} is " + "not using SSL.", + "SSL mode is set to 'preferred' and connection to remote is not using SSL.", + "connectionId"_attr = id(), + "remote"_attr = remote()); + } + return Future<bool>::makeReady(false); + } +} +#endif // MONGO_CONFIG_SSL + +template <typename Buffer> +bool TransportLayerASIO::ASIOSession::checkForHTTPRequest(const Buffer& buffers) { + invariant(asio::buffer_size(buffers) >= 4); + const StringData bufferAsStr(asio::buffer_cast<const char*>(buffers), 4); + return (bufferAsStr == "GET "_sd); +} + +Future<Message> TransportLayerASIO::ASIOSession::sendHTTPResponse(const BatonHandle& baton) { + constexpr auto userMsg = + "It looks like you are trying to access MongoDB over HTTP" + " on the native driver port.\r\n"_sd; + + static const std::string httpResp = str::stream() << "HTTP/1.0 200 OK\r\n" + "Connection: close\r\n" + "Content-Type: text/plain\r\n" + "Content-Length: " + << userMsg.size() << "\r\n\r\n" + << userMsg; + + return write(asio::buffer(httpResp.data(), httpResp.size()), baton) + .onError([](const Status& status) { + return Status(ErrorCodes::ProtocolError, + str::stream() + << "Client sent an HTTP request over a native MongoDB connection, " + "but there was an error sending a response: " + << status.toString()); + }) + .then([] { + return StatusWith<Message>( + ErrorCodes::ProtocolError, + "Client sent an HTTP request over a native MongoDB connection"); + }); +} + +} // namespace mongo::transport diff --git a/src/mongo/transport/session_asio.h b/src/mongo/transport/session_asio.h index a91e710f6f9..b77d95c5c7f 100644 --- a/src/mongo/transport/session_asio.h +++ b/src/mongo/transport/session_asio.h @@ -51,10 +51,9 @@ #include "mongo/util/net/ssl.hpp" #endif -namespace mongo { -namespace transport { +namespace mongo::transport { -MONGO_FAIL_POINT_DEFINE(transportLayerASIOshortOpportunisticReadWrite); +extern FailPoint transportLayerASIOshortOpportunisticReadWrite; template <typename SuccessValue> auto futurize(const std::error_code& ec, SuccessValue&& successValue) { @@ -65,7 +64,7 @@ auto futurize(const std::error_code& ec, SuccessValue&& successValue) { return Result::makeReady(successValue); } -Future<void> futurize(const std::error_code& ec) { +inline Future<void> futurize(const std::error_code& ec) { using Result = Future<void>; if (MONGO_unlikely(ec)) { return Result::makeReady(errorCodeToStatus(ec)); @@ -73,13 +72,10 @@ Future<void> futurize(const std::error_code& ec) { return Result::makeReady(); } -using GenericSocket = asio::generic::stream_protocol::socket; - class TransportLayerASIO::ASIOSession final : public Session { - ASIOSession(const ASIOSession&) = delete; - ASIOSession& operator=(const ASIOSession&) = delete; - public: + using GenericSocket = asio::generic::stream_protocol::socket; + using Endpoint = asio::generic::stream_protocol::endpoint; // If the socket is disconnected while any of these options are being set, this constructor @@ -88,50 +84,10 @@ public: GenericSocket socket, bool isIngressSession, Endpoint endpoint = Endpoint(), - std::shared_ptr<const SSLConnectionContext> transientSSLContext = nullptr) try - : _socket(std::move(socket)), - _tl(tl), - _isIngressSession(isIngressSession) { - auto family = endpointToSockAddr(_socket.local_endpoint()).getType(); - if (family == AF_INET || family == AF_INET6) { - setSocketOption(_socket, asio::ip::tcp::no_delay(true), "session no delay"); - setSocketOption(_socket, asio::socket_base::keep_alive(true), "session keep alive"); - setSocketKeepAliveParams(_socket.native_handle()); - } + std::shared_ptr<const SSLConnectionContext> transientSSLContext = nullptr); - _localAddr = endpointToSockAddr(_socket.local_endpoint()); - - if (endpoint == Endpoint()) { - // Inbound connection, query socket for remote. - _remoteAddr = endpointToSockAddr(_socket.remote_endpoint()); - } else { - // Outbound connection, get remote from resolved endpoint. - // Necessary for TCP_FASTOPEN where the remote isn't connected yet. - _remoteAddr = endpointToSockAddr(endpoint); - } - - _local = HostAndPort(_localAddr.toString(true)); - _remote = HostAndPort(_remoteAddr.toString(true)); -#ifdef MONGO_CONFIG_SSL - _sslContext = transientSSLContext ? transientSSLContext : *tl->_sslContext; - if (transientSSLContext) { - logv2::DynamicAttributes attrs; - if (transientSSLContext->targetClusterURI) { - attrs.add("targetClusterURI", *transientSSLContext->targetClusterURI); - } - attrs.add("isIngress", isIngressSession); - attrs.add("connectionId", id()); - attrs.add("remote", remote()); - LOGV2(5271001, "Initializing the ASIOSession with transient SSL context", attrs); - } -#endif - } catch (const DBException&) { - throw; - } catch (const asio::system_error& error) { - uasserted(ErrorCodes::SocketException, error.what()); - } catch (...) { - uasserted(50797, str::stream() << "Unknown exception while configuring socket."); - } + ASIOSession(const ASIOSession&) = delete; + ASIOSession& operator=(const ASIOSession&) = delete; ~ASIOSession() { end(); @@ -157,142 +113,31 @@ public: return _localAddr; } - void end() override { - if (getSocket().is_open()) { - std::error_code ec; - getSocket().shutdown(GenericSocket::shutdown_both, ec); - if ((ec) && (ec != asio::error::not_connected)) { - LOGV2_ERROR(23841, - "Error shutting down socket: {error}", - "Error shutting down socket", - "error"_attr = ec.message()); - } - } - } + void end() override; - StatusWith<Message> sourceMessage() noexcept override try { - ensureSync(); - return sourceMessageImpl().getNoThrow(); - } catch (const DBException& ex) { - return ex.toStatus(); - } + StatusWith<Message> sourceMessage() noexcept override; - Future<Message> asyncSourceMessage(const BatonHandle& baton = nullptr) noexcept override try { - ensureAsync(); - return sourceMessageImpl(baton); - } catch (const DBException& ex) { - return ex.toStatus(); - } + Future<Message> asyncSourceMessage(const BatonHandle& baton = nullptr) noexcept override; - Status waitForData() noexcept override try { - ensureSync(); - asio::error_code ec; - getSocket().wait(asio::ip::tcp::socket::wait_read, ec); - return errorCodeToStatus(ec); - } catch (const DBException& ex) { - return ex.toStatus(); - } + Status waitForData() noexcept override; - Future<void> asyncWaitForData() noexcept override try { - ensureAsync(); - return getSocket().async_wait(asio::ip::tcp::socket::wait_read, UseFuture{}); - } catch (const DBException& ex) { - return ex.toStatus(); - } + Future<void> asyncWaitForData() noexcept override; - Status sinkMessage(Message message) noexcept override try { - ensureSync(); - - return write(asio::buffer(message.buf(), message.size())) - .then([this, &message] { - if (_isIngressSession) { - networkCounter.hitPhysicalOut(message.size()); - } - }) - .getNoThrow(); - } catch (const DBException& ex) { - return ex.toStatus(); - } + Status sinkMessage(Message message) noexcept override; Future<void> asyncSinkMessage(Message message, - const BatonHandle& baton = nullptr) noexcept override try { - ensureAsync(); - return write(asio::buffer(message.buf(), message.size()), baton) - .then([this, message /*keep the buffer alive*/]() { - if (_isIngressSession) { - networkCounter.hitPhysicalOut(message.size()); - } - }); - } catch (const DBException& ex) { - return ex.toStatus(); - } + const BatonHandle& baton = nullptr) noexcept override; - void cancelAsyncOperations(const BatonHandle& baton = nullptr) override { - LOGV2_DEBUG(4615608, - 3, - "Cancelling outstanding I/O operations on connection to {remote}", - "Cancelling outstanding I/O operations on connection to remote", - "remote"_attr = _remote); - if (baton && baton->networking() && baton->networking()->cancelSession(*this)) { - // If we have a baton, it was for networking, and it owned our session, then we're done. - return; - } + void cancelAsyncOperations(const BatonHandle& baton = nullptr) override; - getSocket().cancel(); - } - - void setTimeout(boost::optional<Milliseconds> timeout) override { - invariant(!timeout || timeout->count() > 0); - _configuredTimeout = timeout; - } - - bool isConnected() override { - // socket.is_open() only returns whether the socket is a valid file descriptor and - // if we haven't marked this socket as closed already. - if (!getSocket().is_open()) - return false; - - auto swPollEvents = pollASIOSocket(getSocket(), POLLIN, Milliseconds{0}); - if (!swPollEvents.isOK()) { - if (swPollEvents != ErrorCodes::NetworkTimeout) { - LOGV2_WARNING(4615609, - "Failed to poll socket for connectivity check: {error}", - "Failed to poll socket for connectivity check", - "error"_attr = swPollEvents.getStatus()); - return false; - } - return true; - } - - auto revents = swPollEvents.getValue(); - if (revents & POLLIN) { - char testByte; - int size = ::recv(getSocket().native_handle(), &testByte, sizeof(testByte), MSG_PEEK); - if (size == sizeof(testByte)) { - return true; - } else if (size == -1) { - LOGV2_WARNING(4615610, - "Failed to check socket connectivity: {error}", - "Failed to check socket connectivity", - "error"_attr = errnoWithDescription(errno)); - } - // If size == 0 then we got disconnected and we should return false. - } + void setTimeout(boost::optional<Milliseconds> timeout) override; - return false; - } + bool isConnected() override; #ifdef MONGO_CONFIG_SSL - const SSLConfiguration* getSSLConfiguration() const override { - if (_sslContext->manager) { - return &_sslContext->manager->getSSLConfiguration(); - } - return nullptr; - } + const SSLConfiguration* getSSLConfiguration() const override; - const std::shared_ptr<SSLManagerInterface> getSSLManager() const override { - return _sslContext->manager; - } + const std::shared_ptr<SSLManagerInterface> getSSLManager() const override; #endif protected: @@ -304,305 +149,32 @@ protected: // timeout callback. It will be unlocked before the SSL actually handshake begins. Future<void> handshakeSSLForEgressWithLock(stdx::unique_lock<Latch> lk, const HostAndPort& target, - const ReactorHandle& reactor) { - if (!_sslContext->egress) { - return Future<void>::makeReady(Status(ErrorCodes::SSLHandshakeFailed, - "SSL requested but SSL support is disabled")); - } - - _sslSocket.emplace(std::move(_socket), *_sslContext->egress, removeFQDNRoot(target.host())); - lk.unlock(); - - auto doHandshake = [&] { - if (_blockingMode == Sync) { - std::error_code ec; - _sslSocket->handshake(asio::ssl::stream_base::client, ec); - return futurize(ec); - } else { - return _sslSocket->async_handshake(asio::ssl::stream_base::client, UseFuture{}); - } - }; - return doHandshake().then([this, target, reactor] { - _ranHandshake = true; - - return getSSLManager() - ->parseAndValidatePeerCertificate(_sslSocket->native_handle(), - _sslSocket->get_sni(), - target.host(), - target, - reactor) - .then([this](SSLPeerInfo info) { - SSLPeerInfo::forSession(shared_from_this()) = info; - }); - }); - } + const ReactorHandle& reactor); // For synchronous connections where we don't have an async timer, just take a dummy lock and // pass it to the WithLock version of handshakeSSLForEgress - Future<void> handshakeSSLForEgress(const HostAndPort& target) { - auto mutex = MONGO_MAKE_LATCH(); - return handshakeSSLForEgressWithLock(stdx::unique_lock<Latch>(mutex), target, nullptr); - } + Future<void> handshakeSSLForEgress(const HostAndPort& target); #endif - void ensureSync() { - asio::error_code ec; - if (_blockingMode != Sync) { - getSocket().non_blocking(false, ec); - fassert(40490, errorCodeToStatus(ec)); - _blockingMode = Sync; - } + void ensureSync(); - if (_socketTimeout != _configuredTimeout) { - // Change boost::none (which means no timeout) into a zero value for the socket option, - // which also means no timeout. - auto timeout = _configuredTimeout.value_or(Milliseconds{0}); - setSocketOption(getSocket(), - ASIOSocketTimeoutOption<SO_SNDTIMEO>(timeout), - ec, - "session send timeout"); - if (auto status = errorCodeToStatus(ec); !status.isOK()) { - tasserted(5342000, status.reason()); - } - - setSocketOption(getSocket(), - ASIOSocketTimeoutOption<SO_RCVTIMEO>(timeout), - ec, - "session receive timeout"); - if (auto status = errorCodeToStatus(ec); !status.isOK()) { - tasserted(5342001, status.reason()); - } - - _socketTimeout = _configuredTimeout; - } - } - - void ensureAsync() { - if (_blockingMode == Async) - return; - - // Socket timeouts currently only effect synchronous calls, so make sure the caller isn't - // expecting a socket timeout when they do an async operation. - invariant(!_configuredTimeout); - - asio::error_code ec; - getSocket().non_blocking(true, ec); - fassert(50706, errorCodeToStatus(ec)); - _blockingMode = Async; - } + void ensureAsync(); private: - template <int Name> - class ASIOSocketTimeoutOption { - public: -#ifdef _WIN32 - using TimeoutType = DWORD; + GenericSocket& getSocket(); - ASIOSocketTimeoutOption(Milliseconds timeoutVal) : _timeout(timeoutVal.count()) {} - -#else - using TimeoutType = timeval; - - ASIOSocketTimeoutOption(Milliseconds timeoutVal) { - _timeout.tv_sec = duration_cast<Seconds>(timeoutVal).count(); - const auto minusSeconds = timeoutVal - Seconds{_timeout.tv_sec}; - _timeout.tv_usec = duration_cast<Microseconds>(minusSeconds).count(); - } -#endif - - template <typename Protocol> - int name(const Protocol&) const { - return Name; - } - - template <typename Protocol> - const TimeoutType* data(const Protocol&) const { - return &_timeout; - } - - template <typename Protocol> - std::size_t size(const Protocol&) const { - return sizeof(_timeout); - } - - template <typename Protocol> - int level(const Protocol&) const { - return SOL_SOCKET; - } - - private: - TimeoutType _timeout; - }; - - GenericSocket& getSocket() { -#ifdef MONGO_CONFIG_SSL - if (_sslSocket) { - return static_cast<GenericSocket&>(_sslSocket->lowest_layer()); - } -#endif - return _socket; - } - - Future<Message> sourceMessageImpl(const BatonHandle& baton = nullptr) { - static constexpr auto kHeaderSize = sizeof(MSGHEADER::Value); - - auto headerBuffer = SharedBuffer::allocate(kHeaderSize); - auto ptr = headerBuffer.get(); - return read(asio::buffer(ptr, kHeaderSize), baton) - .then([headerBuffer = std::move(headerBuffer), this, baton]() mutable { - if (checkForHTTPRequest(asio::buffer(headerBuffer.get(), kHeaderSize))) { - return sendHTTPResponse(baton); - } - - const auto msgLen = size_t(MSGHEADER::View(headerBuffer.get()).getMessageLength()); - if (msgLen < kHeaderSize || msgLen > MaxMessageSizeBytes) { - StringBuilder sb; - sb << "recv(): message msgLen " << msgLen << " is invalid. " - << "Min " << kHeaderSize << " Max: " << MaxMessageSizeBytes; - const auto str = sb.str(); - LOGV2(4615638, - "recv(): message msgLen {msgLen} is invalid. Min: {min} Max: {max}", - "recv(): message mstLen is invalid.", - "msgLen"_attr = msgLen, - "min"_attr = kHeaderSize, - "max"_attr = MaxMessageSizeBytes); - - return Future<Message>::makeReady(Status(ErrorCodes::ProtocolError, str)); - } - - if (msgLen == kHeaderSize) { - // This probably isn't a real case since all (current) messages have bodies. - if (_isIngressSession) { - networkCounter.hitPhysicalIn(msgLen); - } - return Future<Message>::makeReady(Message(std::move(headerBuffer))); - } - - auto buffer = SharedBuffer::allocate(msgLen); - memcpy(buffer.get(), headerBuffer.get(), kHeaderSize); - - MsgData::View msgView(buffer.get()); - return read(asio::buffer(msgView.data(), msgView.dataLen()), baton) - .then([this, buffer = std::move(buffer), msgLen]() mutable { - if (_isIngressSession) { - networkCounter.hitPhysicalIn(msgLen); - } - return Message(std::move(buffer)); - }); - }); - } + Future<Message> sourceMessageImpl(const BatonHandle& baton = nullptr); template <typename MutableBufferSequence> - Future<void> read(const MutableBufferSequence& buffers, const BatonHandle& baton = nullptr) { - // TODO SERVER-47229 Guard active ops for cancellation here. -#ifdef MONGO_CONFIG_SSL - if (_sslSocket) { - return opportunisticRead(*_sslSocket, buffers, baton); - } else if (!_ranHandshake) { - invariant(asio::buffer_size(buffers) >= sizeof(MSGHEADER::Value)); - - return opportunisticRead(_socket, buffers, baton) - .then([this, buffers]() mutable { - _ranHandshake = true; - return maybeHandshakeSSLForIngress(buffers); - }) - .then([this, buffers, baton](bool needsRead) mutable { - if (needsRead) { - return read(buffers, baton); - } else { - return Future<void>::makeReady(); - } - }); - } -#endif - return opportunisticRead(_socket, buffers, baton); - } + Future<void> read(const MutableBufferSequence& buffers, const BatonHandle& baton = nullptr); template <typename ConstBufferSequence> - Future<void> write(const ConstBufferSequence& buffers, const BatonHandle& baton = nullptr) { - // TODO SERVER-47229 Guard active ops for cancellation here. -#ifdef MONGO_CONFIG_SSL - _ranHandshake = true; - if (_sslSocket) { -#ifdef __linux__ - // We do some trickery in asio (see moreToSend), which appears to work well on linux, - // but fails on other platforms. - return opportunisticWrite(*_sslSocket, buffers, baton); -#else - if (_blockingMode == Async) { - // Opportunistic writes are broken for async egress SSL (switching between blocking - // and non-blocking mode corrupts the TLS exchange). - return asio::async_write(*_sslSocket, buffers, UseFuture{}).ignoreValue(); - } else { - return opportunisticWrite(*_sslSocket, buffers, baton); - } -#endif - } -#endif - return opportunisticWrite(_socket, buffers, baton); - } + Future<void> write(const ConstBufferSequence& buffers, const BatonHandle& baton = nullptr); template <typename Stream, typename MutableBufferSequence> Future<void> opportunisticRead(Stream& stream, const MutableBufferSequence& buffers, - const BatonHandle& baton = nullptr) { - std::error_code ec; - size_t size; - - if (MONGO_unlikely(transportLayerASIOshortOpportunisticReadWrite.shouldFail()) && - _blockingMode == Async) { - asio::mutable_buffer localBuffer = buffers; - - if (buffers.size()) { - localBuffer = asio::mutable_buffer(buffers.data(), 1); - } - - do { - size = asio::read(stream, localBuffer, ec); - } while (ec == asio::error::interrupted); // retry syscall EINTR - - if (!ec && buffers.size() > 1) { - ec = asio::error::would_block; - } - } else { - do { - size = asio::read(stream, buffers, ec); - } while (ec == asio::error::interrupted); // retry syscall EINTR - } - - if (((ec == asio::error::would_block) || (ec == asio::error::try_again)) && - (_blockingMode == Async)) { - // asio::read is a loop internally, so some of buffers may have been read into already. - // So we need to adjust the buffers passed into async_read to be offset by size, if - // size is > 0. - MutableBufferSequence asyncBuffers(buffers); - if (size > 0) { - asyncBuffers += size; - } - - if (auto networkingBaton = baton ? baton->networking() : nullptr; - networkingBaton && networkingBaton->canWait()) { - return networkingBaton->addSession(*this, NetworkingBaton::Type::In) - .onError([](Status error) { - if (ErrorCodes::isShutdownError(error)) { - // If the baton has detached, it will cancel its polling. We catch that - // error here and return Status::OK so that we invoke - // opportunisticRead() again and switch to asio::async_read() below. - return Status::OK(); - } - - return error; - }) - .then([&stream, asyncBuffers, baton, this] { - return opportunisticRead(stream, asyncBuffers, baton); - }); - } - - return asio::async_read(stream, asyncBuffers, UseFuture{}).ignoreValue(); - } else { - return futurize(ec); - } - } + const BatonHandle& baton = nullptr); /** * moreToSend checks the ssl socket after an opportunisticWrite. If there are still bytes to @@ -635,204 +207,26 @@ private: return boost::none; } - boost::optional<std::string> getSniName() const override { - return SSLPeerInfo::forSession(shared_from_this()).sniName; - } + boost::optional<std::string> getSniName() const override; #endif template <typename Stream, typename ConstBufferSequence> Future<void> opportunisticWrite(Stream& stream, const ConstBufferSequence& buffers, - const BatonHandle& baton = nullptr) { - std::error_code ec; - std::size_t size; - - if (MONGO_unlikely(transportLayerASIOshortOpportunisticReadWrite.shouldFail()) && - _blockingMode == Async) { - asio::const_buffer localBuffer = buffers; - - if (buffers.size()) { - localBuffer = asio::const_buffer(buffers.data(), 1); - } - - do { - size = asio::write(stream, localBuffer, ec); - } while (ec == asio::error::interrupted); // retry syscall EINTR - if (!ec && buffers.size() > 1) { - ec = asio::error::would_block; - } - } else { - do { - size = asio::write(stream, buffers, ec); - } while (ec == asio::error::interrupted); // retry syscall EINTR - } - - if (((ec == asio::error::would_block) || (ec == asio::error::try_again)) && - (_blockingMode == Async)) { - - // asio::write is a loop internally, so some of buffers may have been read into already. - // So we need to adjust the buffers passed into async_write to be offset by size, if - // size is > 0. - ConstBufferSequence asyncBuffers(buffers); - if (size > 0) { - asyncBuffers += size; - } - - if (auto more = moreToSend(stream, asyncBuffers, baton)) { - return std::move(*more); - } - - if (auto networkingBaton = baton ? baton->networking() : nullptr; - networkingBaton && networkingBaton->canWait()) { - return networkingBaton->addSession(*this, NetworkingBaton::Type::Out) - .onError([](Status error) { - if (ErrorCodes::isCancellationError(error)) { - // If the baton has detached, it will cancel its polling. We catch that - // error here and return Status::OK so that we invoke - // opportunisticWrite() again and switch to asio::async_write() below. - return Status::OK(); - } - - return error; - }) - .then([&stream, asyncBuffers, baton, this] { - return opportunisticWrite(stream, asyncBuffers, baton); - }); - } - - return asio::async_write(stream, asyncBuffers, UseFuture{}).ignoreValue(); - } else { - return futurize(ec); - } - } + const BatonHandle& baton = nullptr); #ifdef MONGO_CONFIG_SSL template <typename MutableBufferSequence> - Future<bool> maybeHandshakeSSLForIngress(const MutableBufferSequence& buffer) { - invariant(asio::buffer_size(buffer) >= sizeof(MSGHEADER::Value)); - MSGHEADER::ConstView headerView(asio::buffer_cast<char*>(buffer)); - auto responseTo = headerView.getResponseToMsgId(); - - if (checkForHTTPRequest(buffer)) { - return Future<bool>::makeReady(false); - } - // This logic was taken from the old mongo/util/net/sock.cpp. - // - // It lets us run both TLS and unencrypted mongo over the same port. - // - // The first message received from the client should have the responseTo field of the wire - // protocol message needs to be 0 or -1. Otherwise the connection is either sending - // garbage or a TLS Hello packet which will be caught by the TLS handshake. - if (responseTo != 0 && responseTo != -1) { - if (!_sslContext->ingress) { - return Future<bool>::makeReady( - Status(ErrorCodes::SSLHandshakeFailed, - "SSL handshake received but server is started without SSL support")); - } - - auto tlsAlert = checkTLSRequest(buffer); - if (tlsAlert) { - return opportunisticWrite(getSocket(), - asio::buffer(tlsAlert->data(), tlsAlert->size())) - .then([] { - return Future<bool>::makeReady( - Status(ErrorCodes::SSLHandshakeFailed, - "SSL handshake failed, as client requested disabled protocol")); - }); - } - - _sslSocket.emplace(std::move(_socket), *_sslContext->ingress, ""); - auto doHandshake = [&] { - if (_blockingMode == Sync) { - std::error_code ec; - _sslSocket->handshake(asio::ssl::stream_base::server, buffer, ec); - return futurize(ec, asio::buffer_size(buffer)); - } else { - return _sslSocket->async_handshake( - asio::ssl::stream_base::server, buffer, UseFuture{}); - } - }; - return doHandshake().then([this](size_t size) { - if (_sslSocket->get_sni()) { - auto sniName = _sslSocket->get_sni().get(); - LOGV2_DEBUG(4908000, - 2, - "Client connected with SNI extension", - "sniName"_attr = sniName); - } else { - LOGV2_DEBUG(4908001, 2, "Client connected without SNI extension"); - } - if (SSLPeerInfo::forSession(shared_from_this()).subjectName.empty()) { - return getSSLManager() - ->parseAndValidatePeerCertificate(_sslSocket->native_handle(), - _sslSocket->get_sni(), - "", - _remote, - nullptr) - .then([this](SSLPeerInfo info) -> bool { - SSLPeerInfo::forSession(shared_from_this()) = info; - return true; - }); - } - - return Future<bool>::makeReady(true); - }); - } else if (_tl->_sslMode() == SSLParams::SSLMode_requireSSL) { - uasserted(ErrorCodes::SSLHandshakeFailed, - "The server is configured to only allow SSL connections"); - } else { - if (!sslGlobalParams.disableNonSSLConnectionLogging && - _tl->_sslMode() == SSLParams::SSLMode_preferSSL) { - LOGV2(23838, - "SSL mode is set to 'preferred' and connection {connectionId} to {remote} is " - "not using SSL.", - "SSL mode is set to 'preferred' and connection to remote is not using SSL.", - "connectionId"_attr = id(), - "remote"_attr = remote()); - } - return Future<bool>::makeReady(false); - } - } + Future<bool> maybeHandshakeSSLForIngress(const MutableBufferSequence& buffer); #endif template <typename Buffer> - bool checkForHTTPRequest(const Buffer& buffers) { - invariant(asio::buffer_size(buffers) >= 4); - const StringData bufferAsStr(asio::buffer_cast<const char*>(buffers), 4); - return (bufferAsStr == "GET "_sd); - } + bool checkForHTTPRequest(const Buffer& buffers); // Called from read() to send an HTTP response back to a client that's trying to use HTTP // over a native MongoDB port. This returns a Future<Message> to match its only caller, but it // always contains an error, so it could really return Future<Anything> - Future<Message> sendHTTPResponse(const BatonHandle& baton = nullptr) { - constexpr auto userMsg = - "It looks like you are trying to access MongoDB over HTTP" - " on the native driver port.\r\n"_sd; - - static const std::string httpResp = str::stream() << "HTTP/1.0 200 OK\r\n" - "Connection: close\r\n" - "Content-Type: text/plain\r\n" - "Content-Length: " - << userMsg.size() << "\r\n\r\n" - << userMsg; - - return write(asio::buffer(httpResp.data(), httpResp.size()), baton) - .onError( - [](const Status& status) { - return Status( - ErrorCodes::ProtocolError, - str::stream() - << "Client sent an HTTP request over a native MongoDB connection, " - "but there was an error sending a response: " - << status.toString()); - }) - .then([] { - return StatusWith<Message>( - ErrorCodes::ProtocolError, - "Client sent an HTTP request over a native MongoDB connection"); - }); - } + Future<Message> sendHTTPResponse(const BatonHandle& baton = nullptr); enum BlockingMode { Unknown, @@ -862,5 +256,4 @@ private: bool _isIngressSession; }; -} // namespace transport -} // namespace mongo +} // namespace mongo::transport diff --git a/src/mongo/transport/transport_layer_asio.cpp b/src/mongo/transport/transport_layer_asio.cpp index 8323948621e..ae03407e8af 100644 --- a/src/mongo/transport/transport_layer_asio.cpp +++ b/src/mongo/transport/transport_layer_asio.cpp @@ -468,7 +468,7 @@ StatusWith<SessionHandle> TransportLayerASIO::connect( } std::error_code ec; - GenericSocket sock(*_egressReactor); + ASIOSession::GenericSocket sock(*_egressReactor); WrappedResolver resolver(*_egressReactor); Date_t timeBefore = Date_t::now(); @@ -529,7 +529,7 @@ StatusWith<TransportLayerASIO::ASIOSessionHandle> TransportLayerASIO::_doSyncCon const HostAndPort& peer, const Milliseconds& timeout, boost::optional<TransientSSLParams> transientSSLParams) { - GenericSocket sock(*_egressReactor); + ASIOSession::GenericSocket sock(*_egressReactor); std::error_code ec; const auto protocol = endpoint->protocol(); @@ -622,7 +622,7 @@ Future<SessionHandle> TransportLayerASIO::asyncConnect( Promise<SessionHandle> promise; Mutex mutex = MONGO_MAKE_LATCH(HierarchicalAcquisitionLevel(0), "AsyncConnectState::mutex"); - GenericSocket socket; + ASIOSession::GenericSocket socket; ASIOReactorTimer timeoutTimer; WrappedResolver resolver; WrappedEndpoint resolvedEndpoint; @@ -1185,7 +1185,8 @@ ReactorHandle TransportLayerASIO::getReactor(WhichReactor which) { } void TransportLayerASIO::_acceptConnection(GenericAcceptor& acceptor) { - auto acceptCb = [this, &acceptor](const std::error_code& ec, GenericSocket peerSocket) mutable { + auto acceptCb = [this, &acceptor](const std::error_code& ec, + ASIOSession::GenericSocket peerSocket) mutable { if (auto lk = stdx::lock_guard(_mutex); _isShutdown) { return; } |