diff options
author | Tyler Seip <Tyler.Seip@mongodb.com> | 2021-11-05 17:19:41 +0000 |
---|---|---|
committer | Evergreen Agent <no-reply@evergreen.mongodb.com> | 2022-03-16 19:39:43 +0000 |
commit | a4f74872c86ce146a7079e397aa54d5f3711a4f4 (patch) | |
tree | b4dd94306c25b470097e9b066edaacb68921699b | |
parent | fcb579b971ae344eae8aa1e1e00db7e627d9e013 (diff) | |
download | mongo-a4f74872c86ce146a7079e397aa54d5f3711a4f4.tar.gz |
SERVER-60678: Add peeking to SessionASIO
-rw-r--r-- | src/mongo/transport/SConscript | 1 | ||||
-rw-r--r-- | src/mongo/transport/asio_utils.h | 26 | ||||
-rw-r--r-- | src/mongo/transport/asio_utils_test.cpp | 149 | ||||
-rw-r--r-- | src/mongo/transport/session_asio.cpp | 17 |
4 files changed, 186 insertions, 7 deletions
diff --git a/src/mongo/transport/SConscript b/src/mongo/transport/SConscript index 4bca2d682de..51ede8d2e85 100644 --- a/src/mongo/transport/SConscript +++ b/src/mongo/transport/SConscript @@ -181,6 +181,7 @@ env.Library( tlEnv.CppUnitTest( target='transport_test', source=[ + 'asio_utils_test.cpp', 'message_compressor_manager_test.cpp', 'message_compressor_registry_test.cpp', 'transport_layer_asio_test.cpp', diff --git a/src/mongo/transport/asio_utils.h b/src/mongo/transport/asio_utils.h index fc63d800492..e786d9f7eb9 100644 --- a/src/mongo/transport/asio_utils.h +++ b/src/mongo/transport/asio_utils.h @@ -63,7 +63,7 @@ inline HostAndPort endpointToHostAndPort(const asio::generic::stream_protocol::e Status errorCodeToStatus(const std::error_code& ec); -/* +/** * The ASIO implementation of poll (i.e. socket.wait()) cannot poll for a mask of events, and * doesn't support timeouts. * @@ -78,6 +78,30 @@ StatusWith<unsigned> pollASIOSocket(asio::generic::stream_protocol::socket& sock unsigned mask, Milliseconds timeout); +/** + * Attempts to fill up the passed in buffer sequence with bytes from the underlying stream + * without blocking. Returns the number of bytes we were actually able to fill in. Throws + * on failure to read socket for reasons other than blocking. + */ +template <typename Stream, typename MutableBufferSequence> +size_t peekASIOStream(Stream& stream, const MutableBufferSequence& buffers) { + std::error_code ec; + size_t bytesRead; + do { + bytesRead = stream.receive(buffers, stream.message_peek, ec); + } while (ec == asio::error::interrupted); + + // On a completely empty socket, receive returns 0 bytes read and sets + // the error code to either would_block or try_again. Since this isn't + // actually an error condition for our purposes, we ignore these two + // errors. + if (ec != asio::error::would_block && ec != asio::error::try_again) { + uassertStatusOK(errorCodeToStatus(ec)); + } + + return bytesRead; +} + #ifdef MONGO_CONFIG_SSL /** * Peeks at a fragment of a client issued TLS handshake packet. Returns a TLS alert diff --git a/src/mongo/transport/asio_utils_test.cpp b/src/mongo/transport/asio_utils_test.cpp new file mode 100644 index 00000000000..fd96a9bf44b --- /dev/null +++ b/src/mongo/transport/asio_utils_test.cpp @@ -0,0 +1,149 @@ +/** + * Copyright (C) 2021-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/transport/asio_utils.h" + +#include <asio.hpp> + +#include "mongo/unittest/assert_that.h" +#include "mongo/unittest/unittest.h" +#include "mongo/util/assert_util.h" +#include "mongo/util/time_support.h" + +namespace mongo::transport { +namespace { + +using namespace unittest::match; + +template <typename Stream> +void writeToSocketAndPollForResponse(Stream& writeSocket, Stream& readSocket, StringData data) { + // Write our payload to our socket. + asio::write(writeSocket, asio::const_buffer(data.rawData(), data.size())); + + // Poll the other end of the connection for data before returning. Wait up to a second for data + // to appear. + const auto start = Date_t::now(); + while (!readSocket.available()) { + if (Date_t::now() - start >= Seconds(1)) { + FAIL("Data was not successfully transmitted across socket pair in time"); + } + } +} + +template <typename Stream> +void peekAllSubstrings(Stream& writeSocket, Stream& readSocket, StringData data) { + writeToSocketAndPollForResponse(writeSocket, readSocket, data); + + // Peek from the socket for all substrings up to and including the full payload size. + // We should never block here. + for (size_t bufferSize = 0; bufferSize <= data.size(); ++bufferSize) { + auto inBuffer = std::make_unique<char[]>(bufferSize); + const auto bytesRead = + peekASIOStream(readSocket, asio::mutable_buffer(inBuffer.get(), bufferSize)); + ASSERT_THAT(StringData(inBuffer.get(), bytesRead), Eq(data.substr(0, bufferSize))); + } +} + +template <typename Stream> +void peekPastBuffer(Stream& writeSocket, Stream& readSocket, StringData data) { + writeToSocketAndPollForResponse(writeSocket, readSocket, data); + + // Peek from the socket more than is available. We should just get what is available. + const auto bufferSize = data.size() + 1; + for (size_t attemptCount = 0; attemptCount < 3; ++attemptCount) { + auto inBuffer = std::make_unique<char[]>(bufferSize); + const auto bytesRead = + peekASIOStream(readSocket, asio::mutable_buffer(inBuffer.get(), bufferSize)); + ASSERT_THAT(StringData(inBuffer.get(), bytesRead), Eq(data)); + } +} + +#ifdef ASIO_HAS_LOCAL_SOCKETS +auto prepareUnixSocketPair(asio::io_context& io_context) { + asio::local::stream_protocol::socket writeSocket(io_context); + asio::local::stream_protocol::socket readSocket(io_context); + asio::local::connect_pair(writeSocket, readSocket); + readSocket.non_blocking(true); + + return std::pair(std::move(writeSocket), std::move(readSocket)); +} + +TEST(ASIOUtils, PeekAvailableBytes) { + asio::io_context io_context; + auto [writeSocket, readSocket] = prepareUnixSocketPair(io_context); + + peekAllSubstrings(writeSocket, readSocket, "example"_sd); +} + +TEST(ASIOUtils, PeekPastAvailableBytes) { + asio::io_context io_context; + auto [writeSocket, readSocket] = prepareUnixSocketPair(io_context); + + peekPastBuffer(writeSocket, readSocket, "example"_sd); +} +#endif // ASIO_HAS_LOCAL_SOCKETS + +auto prepareTCPSocketPair(asio::io_context& io_context) { + // Make a local loopback connection on an arbitrary ephemeral port. + asio::ip::tcp::endpoint ep(asio::ip::make_address("127.0.0.1"), 0); + asio::ip::tcp::acceptor acceptor(io_context, ep.protocol()); + { + std::error_code ec; + acceptor.bind(ep, ec); + uassertStatusOK(errorCodeToStatus(ec)); + } + acceptor.listen(); + + asio::ip::tcp::socket readSocket(io_context, ep.protocol()); + readSocket.connect(acceptor.local_endpoint()); + asio::ip::tcp::socket writeSocket(io_context); + acceptor.accept(writeSocket); + writeSocket.non_blocking(false); + // Set no_delay so that our output doesn't get buffered in a kernel buffer. + writeSocket.set_option(asio::ip::tcp::no_delay(true)); + readSocket.non_blocking(true); + + return std::pair(std::move(writeSocket), std::move(readSocket)); +} + +TEST(ASIOUtils, PeekAvailableBytesTCP) { + asio::io_context io_context; + auto [writeSocket, readSocket] = prepareTCPSocketPair(io_context); + + peekAllSubstrings(writeSocket, readSocket, "example"_sd); +} + +TEST(ASIOUtils, PeekPastAvailableBytesTCP) { + asio::io_context io_context; + auto [writeSocket, readSocket] = prepareTCPSocketPair(io_context); + + peekPastBuffer(writeSocket, readSocket, "example"_sd); +} +} // namespace +} // namespace mongo::transport diff --git a/src/mongo/transport/session_asio.cpp b/src/mongo/transport/session_asio.cpp index 80f9e87aac9..34e2507a2b8 100644 --- a/src/mongo/transport/session_asio.cpp +++ b/src/mongo/transport/session_asio.cpp @@ -33,6 +33,8 @@ #include "mongo/config.h" #include "mongo/logv2/log.h" +#include "mongo/transport/asio_utils.h" +#include "mongo/util/assert_util.h" namespace mongo::transport { @@ -246,17 +248,20 @@ bool TransportLayerASIO::ASIOSession::isConnected() { auto revents = swPollEvents.getValue(); if (revents & POLLIN) { - char testByte; - int size = ::recv(getSocket().native_handle(), &testByte, sizeof(testByte), MSG_PEEK); - if (size == sizeof(testByte)) { + try { + char testByte; + const auto bytesRead = + peekASIOStream(getSocket(), asio::buffer(&testByte, sizeof(testByte))); + uassert(ErrorCodes::SocketException, + "Couldn't peek from underlying socket", + bytesRead == sizeof(testByte)); return true; - } else if (size == -1) { + } catch (const DBException& e) { LOGV2_WARNING(4615610, "Failed to check socket connectivity: {error}", "Failed to check socket connectivity", - "error"_attr = errnoWithDescription(errno)); + "error"_attr = e); } - // If size == 0 then we got disconnected and we should return false. } return false; |