diff options
Diffstat (limited to 'src/mongo/transport')
21 files changed, 312 insertions, 79 deletions
diff --git a/src/mongo/transport/SConscript b/src/mongo/transport/SConscript index 4b41ec5d085..83fb5729e01 100644 --- a/src/mongo/transport/SConscript +++ b/src/mongo/transport/SConscript @@ -23,6 +23,7 @@ env.Library( '$BUILD_DIR/mongo/unittest/unittest', '$BUILD_DIR/mongo/util/foundation', '$BUILD_DIR/mongo/util/net/network', + '$BUILD_DIR/mongo/transport/message_compressor', ], ) @@ -99,16 +100,19 @@ env.CppUnitTest( LIBDEPS=[ 'transport_layer_mock', ], +) env.Library( target='message_compressor', source=[ - 'message_compressor_registry.cpp', 'message_compressor_manager.cpp', + 'message_compressor_registry.cpp', + 'message_compressor_snappy.cpp', ], LIBDEPS=[ '$BUILD_DIR/mongo/base', '$BUILD_DIR/mongo/util/options_parser/options_parser', + '$BUILD_DIR/third_party/shim_snappy', ] ) diff --git a/src/mongo/transport/message_compressor_base.h b/src/mongo/transport/message_compressor_base.h index c0a2c14b72d..a2c05a6069e 100644 --- a/src/mongo/transport/message_compressor_base.h +++ b/src/mongo/transport/message_compressor_base.h @@ -33,8 +33,17 @@ #include "mongo/base/string_data.h" #include "mongo/platform/atomic_word.h" +#include <type_traits> + namespace mongo { -using MessageCompressorId = uint8_t; +enum class MessageCompressor : uint8_t { + kNoop = 0, + kSnappy = 1, + kExtended = 255, +}; + +StringData getMessageCompressorName(MessageCompressor id); +using MessageCompressorId = std::underlying_type<MessageCompressor>::type; class MessageCompressorBase { MONGO_DISALLOW_COPYING(MessageCompressorBase); @@ -109,8 +118,9 @@ protected: /* * This is called by sub-classes to intialize their ID/name fields. */ - MessageCompressorBase(MessageCompressorId id, StringData name) - : _id{id}, _name{name.toString()} {} + MessageCompressorBase(MessageCompressor id) + : _id{static_cast<MessageCompressorId>(id)}, + _name{getMessageCompressorName(id).toString()} {} /* * Called by sub-classes to bump their bytesIn/bytesOut counters for compression diff --git a/src/mongo/transport/message_compressor_manager.cpp b/src/mongo/transport/message_compressor_manager.cpp index 42f3e9c8c0d..78fa19ad8a3 100644 --- a/src/mongo/transport/message_compressor_manager.cpp +++ b/src/mongo/transport/message_compressor_manager.cpp @@ -59,10 +59,10 @@ struct CompressionHeader { CompressionHeader(int32_t _opcode, int32_t _size, uint8_t _id) : originalOpCode{_opcode}, uncompressedSize{_size}, compressorId{_id} {} - CompressionHeader(ConstDataRangeCursor cursor) { - originalOpCode = cursor.readAndAdvance<LittleEndian<std::int32_t>>().getValue(); - uncompressedSize = cursor.readAndAdvance<LittleEndian<std::int32_t>>().getValue(); - compressorId = cursor.readAndAdvance<LittleEndian<uint8_t>>().getValue(); + CompressionHeader(ConstDataRangeCursor* cursor) { + originalOpCode = cursor->readAndAdvance<LittleEndian<std::int32_t>>().getValue(); + uncompressedSize = cursor->readAndAdvance<LittleEndian<std::int32_t>>().getValue(); + compressorId = cursor->readAndAdvance<LittleEndian<uint8_t>>().getValue(); } static size_t size() { @@ -93,6 +93,8 @@ StatusWith<Message> MessageCompressorManager::compressMessage(const Message& msg inputHeader.getNetworkOp(), inputHeader.dataLen(), compressor->getId()); if (bufferSize > MaxMessageSizeBytes) { + LOG(3) << "Compressed message would be larger than " << MaxMessageSizeBytes + << ", returning original uncompressed message"; return {msg}; } @@ -122,7 +124,7 @@ StatusWith<Message> MessageCompressorManager::compressMessage(const Message& msg StatusWith<Message> MessageCompressorManager::decompressMessage(const Message& msg) { auto inputHeader = msg.header(); ConstDataRangeCursor input(inputHeader.data(), inputHeader.data() + inputHeader.dataLen()); - CompressionHeader compressionHeader(input); + CompressionHeader compressionHeader(&input); auto compressor = _registry->getCompressor(compressionHeader.compressorId); if (!compressor) { @@ -145,6 +147,10 @@ StatusWith<Message> MessageCompressorManager::decompressMessage(const Message& m if (!sws.isOK()) return sws.getStatus(); + if (sws.getValue() != static_cast<std::size_t>(compressionHeader.uncompressedSize)) { + return {ErrorCodes::BadValue, "Decompressing message returned less data than expected"}; + } + outMessage.setLen(sws.getValue() + MsgData::MsgDataHeaderSize); return {Message(outputMessageBuffer)}; @@ -224,8 +230,8 @@ void MessageCompressorManager::serverNegotiate(const BSONObj& input, BSONObjBuil if ((cur = _registry->getCompressor(curName))) { LOG(3) << cur->getName() << " is supported"; _negotiated.push_back(cur); - } else { // Otherwise the compressor is not supported and we skip over it. - LOG(3) << cur->getName() << " is not supported"; + } else { // Otherwise the compressor is not supported and we skip over it. + LOG(3) << curName << " is not supported"; } } diff --git a/src/mongo/transport/message_compressor_manager.h b/src/mongo/transport/message_compressor_manager.h index a3d35027ffa..7af084996db 100644 --- a/src/mongo/transport/message_compressor_manager.h +++ b/src/mongo/transport/message_compressor_manager.h @@ -48,13 +48,16 @@ public: /* * Default constructor. Uses the global MessageCompressorRegistry. */ - explicit MessageCompressorManager(); + MessageCompressorManager(); /* * Constructs a manager from a specific MessageCompressorRegistry - used by the unit tests * to test various registry configurations. */ - MessageCompressorManager(MessageCompressorRegistry* factory); + explicit MessageCompressorManager(MessageCompressorRegistry* factory); + + MessageCompressorManager(MessageCompressorManager&&) = default; + MessageCompressorManager& operator=(MessageCompressorManager&&) = default; /* * Called by a client constructing an isMaster request. This function will append the result diff --git a/src/mongo/transport/message_compressor_manager_test.cpp b/src/mongo/transport/message_compressor_manager_test.cpp index c431f102e23..383bb1e3260 100644 --- a/src/mongo/transport/message_compressor_manager_test.cpp +++ b/src/mongo/transport/message_compressor_manager_test.cpp @@ -30,10 +30,11 @@ #include "mongo/bson/bsonobjbuilder.h" #include "mongo/stdx/memory.h" -#include "mongo/transport/message_compressor_registry.h" #include "mongo/transport/message_compressor_manager.h" #include "mongo/transport/message_compressor_noop.h" +#include "mongo/transport/message_compressor_registry.h" #include "mongo/unittest/unittest.h" +#include "mongo/util/net/message.h" #include <string> #include <vector> @@ -49,7 +50,7 @@ MessageCompressorRegistry buildRegistry() { ret.registerImplementation(std::move(compressor)); ret.finalizeSupportedCompressors(); - return std::move(ret); + return ret; } void checkNegotiationResult(const BSONObj& result, const std::vector<std::string>& algos) { @@ -80,7 +81,57 @@ void checkServerNegotiation(const BSONObj& input, const std::vector<std::string> manager.serverNegotiate(input, &serverOutput); checkNegotiationResult(serverOutput.done(), expected); } -} // namespace + +void checkFidelity(const Message& msg, std::unique_ptr<MessageCompressorBase> compressor) { + MessageCompressorRegistry registry; + const auto originalView = msg.singleData(); + const auto compressorName = compressor->getName(); + + std::vector<std::string> compressorList = {compressorName}; + registry.setSupportedCompressors(std::move(compressorList)); + registry.registerImplementation(std::move(compressor)); + registry.finalizeSupportedCompressors(); + + MessageCompressorManager mgr(®istry); + auto negotiator = BSON("isMaster" << 1 << "compression" << BSON_ARRAY(compressorName)); + BSONObjBuilder negotiatorOut; + mgr.serverNegotiate(negotiator, &negotiatorOut); + checkNegotiationResult(negotiatorOut.done(), {compressorName}); + + auto swm = mgr.compressMessage(msg); + ASSERT_OK(swm.getStatus()); + auto compressedMsg = std::move(swm.getValue()); + const auto compressedMsgView = compressedMsg.singleData(); + + ASSERT_EQ(compressedMsgView.getId(), originalView.getId()); + ASSERT_EQ(compressedMsgView.getResponseToMsgId(), originalView.getResponseToMsgId()); + ASSERT_EQ(compressedMsgView.getNetworkOp(), dbCompressed); + + swm = mgr.decompressMessage(compressedMsg); + ASSERT_OK(swm.getStatus()); + auto decompressedMsg = std::move(swm.getValue()); + + const auto decompressedMsgView = decompressedMsg.singleData(); + ASSERT_EQ(decompressedMsgView.getId(), originalView.getId()); + ASSERT_EQ(decompressedMsgView.getResponseToMsgId(), originalView.getResponseToMsgId()); + ASSERT_EQ(decompressedMsgView.getNetworkOp(), originalView.getNetworkOp()); + ASSERT_EQ(decompressedMsgView.getLen(), originalView.getLen()); + + ASSERT_EQ(memcmp(decompressedMsgView.data(), originalView.data(), originalView.dataLen()), 0); +} + +Message buildMessage() { + const auto data = std::string{"Hello, world!"}; + const auto bufferSize = MsgData::MsgDataHeaderSize + data.size(); + auto buf = SharedBuffer::allocate(bufferSize); + MsgData::View testView(buf.get()); + testView.setId(123456); + testView.setResponseToMsgId(654321); + testView.setOperation(dbQuery); + testView.setLen(bufferSize); + memcpy(testView.data(), data.data(), data.size()); + return Message{buf}; +} TEST(MessageCompressorManager, NoCompressionRequested) { auto input = BSON("isMaster" << 1); @@ -120,4 +171,16 @@ TEST(MessageCompressorManager, FullNormalCompression) { clientManager.clientFinish(serverObj); } + +TEST(NoopMessageCompressor, Fidelity) { + auto testMessage = buildMessage(); + checkFidelity(testMessage, stdx::make_unique<NoopMessageCompressor>()); +} + +TEST(SnappyMessageCompressor, Fidelity) { + auto testMessage = buildMessage(); + checkFidelity(testMessage, stdx::make_unique<NoopMessageCompressor>()); +} + } // namespace mongo +} // namespace diff --git a/src/mongo/transport/message_compressor_noop.h b/src/mongo/transport/message_compressor_noop.h index 54a2c795f9d..b0602482b78 100644 --- a/src/mongo/transport/message_compressor_noop.h +++ b/src/mongo/transport/message_compressor_noop.h @@ -26,28 +26,26 @@ * it in the license file. */ -#define MONGO_LOG_DEFAULT_COMPONENT ::mongo::logger::LogComponent::kNetwork - #include "mongo/transport/message_compressor_base.h" namespace mongo { class NoopMessageCompressor final : public MessageCompressorBase { public: - NoopMessageCompressor() : MessageCompressorBase(0, "noop") {} + NoopMessageCompressor() : MessageCompressorBase(MessageCompressor::kNoop) {} std::size_t getMaxCompressedSize(size_t inputSize) override { return inputSize; } StatusWith<std::size_t> compressData(ConstDataRange input, DataRange output) override { - memcpy(const_cast<char*>(output.data()), input.data(), input.length()); + output.write(input); counterHitCompress(input.length(), input.length()); return {input.length()}; } StatusWith<std::size_t> decompressData(ConstDataRange input, DataRange output) override { - memcpy(const_cast<char*>(output.data()), input.data(), input.length()); + output.write(input); counterHitDecompress(input.length(), input.length()); return {input.length()}; } diff --git a/src/mongo/transport/message_compressor_registry.cpp b/src/mongo/transport/message_compressor_registry.cpp index e64d73b063e..a9cf4cda0b9 100644 --- a/src/mongo/transport/message_compressor_registry.cpp +++ b/src/mongo/transport/message_compressor_registry.cpp @@ -33,6 +33,7 @@ #include "mongo/base/init.h" #include "mongo/stdx/memory.h" #include "mongo/transport/message_compressor_noop.h" +#include "mongo/transport/message_compressor_snappy.h" #include "mongo/util/options_parser/option_section.h" #include <boost/algorithm/string/classification.hpp> @@ -40,6 +41,18 @@ namespace mongo { +StringData getMessageCompressorName(MessageCompressor id) { + switch (id) { + case MessageCompressor::kNoop: + return "noop"_sd; + case MessageCompressor::kSnappy: + return "snappy"_sd; + default: + fassert(40269, "Invalid message compressor ID"); + } + MONGO_UNREACHABLE; +} + MessageCompressorRegistry& MessageCompressorRegistry::get() { static MessageCompressorRegistry globalRegistry; return globalRegistry; @@ -48,7 +61,7 @@ MessageCompressorRegistry& MessageCompressorRegistry::get() { void MessageCompressorRegistry::registerImplementation( std::unique_ptr<MessageCompressorBase> impl) { // It's an error to register a compressor that's already been registered - fassert(40254, + fassert(40270, _compressorsByName.find(impl->getName()) == _compressorsByName.end() && _compressorsByIds[impl->getId()] == nullptr); @@ -61,13 +74,15 @@ void MessageCompressorRegistry::registerImplementation( _compressorsByIds[impl->getId()] = std::move(impl); } -void MessageCompressorRegistry::finalizeSupportedCompressors() { - // Remove compressor names from the compressorNames list if they were never registered. - // This prevents _compressorNames from having totally bogus names specified by users. - std::remove_if( - _compressorNames.begin(), _compressorNames.end(), [this](const std::string& name) { - return _compressorsByName.find(name) == _compressorsByName.end(); - }); +Status MessageCompressorRegistry::finalizeSupportedCompressors() { + for (auto it = _compressorNames.begin(); it != _compressorNames.end(); ++it) { + if (_compressorsByName.find(*it) == _compressorsByName.end()) { + std::stringstream ss; + ss << "Invalid network message compressor specified in configuration: " << *it; + return {ErrorCodes::BadValue, ss.str()}; + } + } + return Status::OK(); } const std::vector<std::string>& MessageCompressorRegistry::getCompressorNames() const { @@ -119,7 +134,7 @@ Status storeMessageCompressionOptions(const moe::Environment& params) { // This instantiates and registers the "noop" compressor. It must happen after option storage // because that's when the configuration of the compressors gets set. MONGO_INITIALIZER_GENERAL(NoopMessageCompressorInit, - ("EndStartupOptionHandling"), + ("EndStartupOptionStorage"), ("AllCompressorsRegistered")) (InitializerContext* context) { auto& compressorRegistry = MessageCompressorRegistry::get(); @@ -131,7 +146,6 @@ MONGO_INITIALIZER_GENERAL(NoopMessageCompressorInit, // any compressor. It must be run after all the compressors have registered themselves with // the global registry. MONGO_INITIALIZER(AllCompressorsRegistered)(InitializerContext* context) { - MessageCompressorRegistry::get().finalizeSupportedCompressors(); - return Status::OK(); + return MessageCompressorRegistry::get().finalizeSupportedCompressors(); } } // namespace mongo diff --git a/src/mongo/transport/message_compressor_registry.h b/src/mongo/transport/message_compressor_registry.h index 721185cfb1c..9d8549ed0e3 100644 --- a/src/mongo/transport/message_compressor_registry.h +++ b/src/mongo/transport/message_compressor_registry.h @@ -107,7 +107,7 @@ public: * calls to registerImplementation. It will remove any compressor names that aren't keys in * the _compressors map. */ - void finalizeSupportedCompressors(); + Status finalizeSupportedCompressors(); private: StringMap<MessageCompressorBase*> _compressorsByName; diff --git a/src/mongo/transport/message_compressor_registry_test.cpp b/src/mongo/transport/message_compressor_registry_test.cpp index b3a766f2f55..a14f067e606 100644 --- a/src/mongo/transport/message_compressor_registry_test.cpp +++ b/src/mongo/transport/message_compressor_registry_test.cpp @@ -73,15 +73,17 @@ TEST(MessageCompressorRegistry, NothingRegistered) { TEST(MessageCompressorRegistry, SetSupported) { MessageCompressorRegistry registry; auto compressor = stdx::make_unique<NoopMessageCompressor>(); - auto compressorPtr = compressor.get(); + auto compressorId = compressor->getId(); + auto compressorName = compressor->getName(); std::vector<std::string> compressorList = {"foobar"}; registry.setSupportedCompressors(std::move(compressorList)); registry.registerImplementation(std::move(compressor)); - registry.finalizeSupportedCompressors(); + auto ret = registry.finalizeSupportedCompressors(); + ASSERT_NOT_OK(ret); - ASSERT_NULL(registry.getCompressor(compressorPtr->getName())); - ASSERT_NULL(registry.getCompressor(compressorPtr->getId())); + ASSERT_NULL(registry.getCompressor(compressorId)); + ASSERT_NULL(registry.getCompressor(compressorName)); } } // namespace } // namespace mongo diff --git a/src/mongo/transport/message_compressor_snappy.cpp b/src/mongo/transport/message_compressor_snappy.cpp new file mode 100644 index 00000000000..db1e0c9dfca --- /dev/null +++ b/src/mongo/transport/message_compressor_snappy.cpp @@ -0,0 +1,80 @@ +/** + * Copyright (C) 2016 MongoDB Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, version 3, + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the GNU Affero General Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#define MONGO_LOG_DEFAULT_COMPONENT ::mongo::logger::LogComponent::kNetwork + +#include "mongo/platform/basic.h" + +#include "mongo/base/init.h" +#include "mongo/stdx/memory.h" +#include "mongo/transport/message_compressor_registry.h" +#include "mongo/transport/message_compressor_snappy.h" + +#include "third_party/snappy-1.1.3/snappy.h" + +namespace mongo { + +SnappyMessageCompressor::SnappyMessageCompressor() + : MessageCompressorBase(MessageCompressor::kSnappy) {} + +std::size_t SnappyMessageCompressor::getMaxCompressedSize(size_t inputSize) { + return snappy::MaxCompressedLength(inputSize); +} + +StatusWith<std::size_t> SnappyMessageCompressor::compressData(ConstDataRange input, + DataRange output) { + size_t outLength; + snappy::RawCompress(input.data(), input.length(), const_cast<char*>(output.data()), &outLength); + + counterHitCompress(input.length(), outLength); + return {outLength}; +} + +StatusWith<std::size_t> SnappyMessageCompressor::decompressData(ConstDataRange input, + DataRange output) { + bool ret = + snappy::RawUncompress(input.data(), input.length(), const_cast<char*>(output.data())); + + if (!ret) { + return Status{ErrorCodes::BadValue, "Compressed message was invalid or corrupted"}; + } + + counterHitDecompress(input.length(), output.length()); + return output.length(); +} + + +MONGO_INITIALIZER_GENERAL(SnappyMessageCompressorInit, + ("EndStartupOptionHandling"), + ("AllCompressorsRegistered")) +(InitializerContext* context) { + auto& compressorRegistry = MessageCompressorRegistry::get(); + compressorRegistry.registerImplementation(stdx::make_unique<SnappyMessageCompressor>()); + return Status::OK(); +} +} // namespace mongo diff --git a/src/mongo/transport/message_compressor_snappy.h b/src/mongo/transport/message_compressor_snappy.h new file mode 100644 index 00000000000..3521370df09 --- /dev/null +++ b/src/mongo/transport/message_compressor_snappy.h @@ -0,0 +1,44 @@ +/** + * Copyright (C) 2016 MongoDB Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, version 3, + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the GNU Affero General Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/transport/message_compressor_base.h" + +namespace mongo { +class SnappyMessageCompressor final : public MessageCompressorBase { +public: + SnappyMessageCompressor(); + + std::size_t getMaxCompressedSize(size_t inputSize) override; + + StatusWith<std::size_t> compressData(ConstDataRange input, DataRange output) override; + + StatusWith<std::size_t> decompressData(ConstDataRange input, DataRange output) override; +}; + + +} // namespace mongo diff --git a/src/mongo/transport/service_entry_point_test_suite.cpp b/src/mongo/transport/service_entry_point_test_suite.cpp index efd43fd9a9b..841851d5ccd 100644 --- a/src/mongo/transport/service_entry_point_test_suite.cpp +++ b/src/mongo/transport/service_entry_point_test_suite.cpp @@ -127,13 +127,13 @@ ServiceEntryPointTestSuite::MockTLHarness::MockTLHarness() _asyncWait(kDefaultAsyncWait), _end(kDefaultEnd) {} -Ticket ServiceEntryPointTestSuite::MockTLHarness::sourceMessage(const Session& session, +Ticket ServiceEntryPointTestSuite::MockTLHarness::sourceMessage(Session& session, Message* message, Date_t expiration) { return _sourceMessage(session, message, expiration); } -Ticket ServiceEntryPointTestSuite::MockTLHarness::sinkMessage(const Session& session, +Ticket ServiceEntryPointTestSuite::MockTLHarness::sinkMessage(Session& session, const Message& message, Date_t expiration) { return _sinkMessage(session, message, expiration); @@ -191,19 +191,17 @@ Status ServiceEntryPointTestSuite::MockTLHarness::_waitOnceThenError(transport:: return _defaultWait(std::move(ticket)); } -Ticket ServiceEntryPointTestSuite::MockTLHarness::_defaultSource(const Session& s, - Message* m, - Date_t d) { +Ticket ServiceEntryPointTestSuite::MockTLHarness::_defaultSource(Session& s, Message* m, Date_t d) { return Ticket(this, stdx::make_unique<ServiceEntryPointTestSuite::MockTicket>(s, m, d)); } -Ticket ServiceEntryPointTestSuite::MockTLHarness::_defaultSink(const Session& s, +Ticket ServiceEntryPointTestSuite::MockTLHarness::_defaultSink(Session& s, const Message&, Date_t d) { return Ticket(this, stdx::make_unique<ServiceEntryPointTestSuite::MockTicket>(s, d)); } -Ticket ServiceEntryPointTestSuite::MockTLHarness::_sinkThenErrorOnWait(const Session& s, +Ticket ServiceEntryPointTestSuite::MockTLHarness::_sinkThenErrorOnWait(Session& s, const Message& m, Date_t d) { _wait = stdx::bind(&ServiceEntryPointTestSuite::MockTLHarness::_waitOnceThenError, this, _1); @@ -264,7 +262,7 @@ void ServiceEntryPointTestSuite::halfLifeCycleTest() { // Step 1: SEP gets a ticket to source a Message // Step 2: SEP calls wait() on the ticket and receives a Message // Step 3: SEP gets a ticket to sink a Message - _tl->_sinkMessage = [this](const Session& session, const Message& m, Date_t expiration) { + _tl->_sinkMessage = [this](Session& session, const Message& m, Date_t expiration) { // Step 4: SEP calls wait() on the ticket and receives an error _tl->_wait = diff --git a/src/mongo/transport/service_entry_point_test_suite.h b/src/mongo/transport/service_entry_point_test_suite.h index 623f0533267..2249c86b9bd 100644 --- a/src/mongo/transport/service_entry_point_test_suite.h +++ b/src/mongo/transport/service_entry_point_test_suite.h @@ -120,11 +120,11 @@ private: MockTLHarness(); transport::Ticket sourceMessage( - const transport::Session& session, + transport::Session& session, Message* message, Date_t expiration = transport::Ticket::kNoExpirationDate) override; transport::Ticket sinkMessage( - const transport::Session& session, + transport::Session& session, const Message& message, Date_t expiration = transport::Ticket::kNoExpirationDate) override; Status wait(transport::Ticket&& ticket) override; @@ -141,10 +141,8 @@ private: ServiceEntryPointTestSuite::MockTicket* getMockTicket(const transport::Ticket& ticket); // Mocked method hooks - stdx::function<transport::Ticket(const transport::Session&, Message*, Date_t)> - _sourceMessage; - stdx::function<transport::Ticket(const transport::Session&, const Message&, Date_t)> - _sinkMessage; + stdx::function<transport::Ticket(transport::Session&, Message*, Date_t)> _sourceMessage; + stdx::function<transport::Ticket(transport::Session&, const Message&, Date_t)> _sinkMessage; stdx::function<Status(transport::Ticket)> _wait; stdx::function<void(transport::Ticket, TicketCallback)> _asyncWait; stdx::function<void(const transport::Session&)> _end; @@ -154,11 +152,9 @@ private: stdx::function<void(void)> _shutdown = [] {}; // Pre-set hook methods - transport::Ticket _defaultSource(const transport::Session& s, Message* m, Date_t d); - transport::Ticket _defaultSink(const transport::Session& s, const Message&, Date_t d); - transport::Ticket _sinkThenErrorOnWait(const transport::Session& s, - const Message& m, - Date_t d); + transport::Ticket _defaultSource(transport::Session& s, Message* m, Date_t d); + transport::Ticket _defaultSink(transport::Session& s, const Message&, Date_t d); + transport::Ticket _sinkThenErrorOnWait(transport::Session& s, const Message& m, Date_t d); Status _defaultWait(transport::Ticket ticket); Status _waitError(transport::Ticket ticket); diff --git a/src/mongo/transport/session.h b/src/mongo/transport/session.h index 8551f607037..c7ec5cd28f0 100644 --- a/src/mongo/transport/session.h +++ b/src/mongo/transport/session.h @@ -29,6 +29,7 @@ #pragma once #include "mongo/base/disallow_copying.h" +#include "mongo/transport/message_compressor_manager.h" #include "mongo/transport/session_id.h" #include "mongo/transport/ticket.h" #include "mongo/util/net/hostandport.h" @@ -152,6 +153,10 @@ public: return _ended; } + MessageCompressorManager& getCompressorManager() { + return _messageCompressorManager; + } + private: bool _ended = false; @@ -163,6 +168,8 @@ private: TagMask _tags; TransportLayer* _tl; + + MessageCompressorManager _messageCompressorManager; }; } // namespace transport diff --git a/src/mongo/transport/transport_layer.h b/src/mongo/transport/transport_layer.h index 44995fc375b..3634ec68de5 100644 --- a/src/mongo/transport/transport_layer.h +++ b/src/mongo/transport/transport_layer.h @@ -98,7 +98,7 @@ public: * TransportLayer is unable to source a Message, this will be a failed status, * and the passed-in Message buffer may be left in an invalid state. */ - virtual Ticket sourceMessage(const Session& session, + virtual Ticket sourceMessage(Session& session, Message* message, Date_t expiration = Ticket::kNoExpirationDate) = 0; @@ -117,7 +117,7 @@ public: * This method does NOT take ownership of the sunk Message, which must be cleaned * up by the caller. */ - virtual Ticket sinkMessage(const Session& session, + virtual Ticket sinkMessage(Session& session, const Message& message, Date_t expiration = Ticket::kNoExpirationDate) = 0; diff --git a/src/mongo/transport/transport_layer_legacy.cpp b/src/mongo/transport/transport_layer_legacy.cpp index 8d1f95b6a23..3c004eb793f 100644 --- a/src/mongo/transport/transport_layer_legacy.cpp +++ b/src/mongo/transport/transport_layer_legacy.cpp @@ -98,13 +98,19 @@ Status TransportLayerLegacy::start() { TransportLayerLegacy::~TransportLayerLegacy() = default; -Ticket TransportLayerLegacy::sourceMessage(const Session& session, - Message* message, - Date_t expiration) { - auto sourceCb = [message](AbstractMessagingPort* amp) -> Status { +Ticket TransportLayerLegacy::sourceMessage(Session& session, Message* message, Date_t expiration) { + auto& compressorMgr = session.getCompressorManager(); + auto sourceCb = [message, &compressorMgr](AbstractMessagingPort* amp) -> Status { if (!amp->recv(*message)) { return {ErrorCodes::HostUnreachable, "Recv failed"}; } + + if (message->operation() == dbCompressed) { + auto swm = compressorMgr.decompressMessage(*message); + if (!swm.isOK()) + return swm.getStatus(); + *message = swm.getValue(); + } return Status::OK(); }; @@ -137,12 +143,18 @@ TransportLayer::Stats TransportLayerLegacy::sessionStats() { return stats; } -Ticket TransportLayerLegacy::sinkMessage(const Session& session, +Ticket TransportLayerLegacy::sinkMessage(Session& session, const Message& message, Date_t expiration) { - auto sinkCb = [&message](AbstractMessagingPort* amp) -> Status { + auto& compressorMgr = session.getCompressorManager(); + auto sinkCb = [&message, &compressorMgr](AbstractMessagingPort* amp) -> Status { try { - amp->say(message); + auto swm = compressorMgr.compressMessage(message); + if (!swm.isOK()) + return swm.getStatus(); + const auto& compressedMessage = swm.getValue(); + amp->say(compressedMessage); + return Status::OK(); } catch (const SocketException& e) { return {ErrorCodes::HostUnreachable, e.what()}; diff --git a/src/mongo/transport/transport_layer_legacy.h b/src/mongo/transport/transport_layer_legacy.h index 1ef193b4754..7472212ffc9 100644 --- a/src/mongo/transport/transport_layer_legacy.h +++ b/src/mongo/transport/transport_layer_legacy.h @@ -67,11 +67,11 @@ public: Status setup(); Status start() override; - Ticket sourceMessage(const Session& session, + Ticket sourceMessage(Session& session, Message* message, Date_t expiration = Ticket::kNoExpirationDate) override; - Ticket sinkMessage(const Session& session, + Ticket sinkMessage(Session& session, const Message& message, Date_t expiration = Ticket::kNoExpirationDate) override; diff --git a/src/mongo/transport/transport_layer_manager.cpp b/src/mongo/transport/transport_layer_manager.cpp index 7ba1797a21f..e513155e5cd 100644 --- a/src/mongo/transport/transport_layer_manager.cpp +++ b/src/mongo/transport/transport_layer_manager.cpp @@ -43,13 +43,11 @@ namespace transport { TransportLayerManager::TransportLayerManager() = default; -Ticket TransportLayerManager::sourceMessage(const Session& session, - Message* message, - Date_t expiration) { +Ticket TransportLayerManager::sourceMessage(Session& session, Message* message, Date_t expiration) { return session.getTransportLayer()->sourceMessage(session, message, expiration); } -Ticket TransportLayerManager::sinkMessage(const Session& session, +Ticket TransportLayerManager::sinkMessage(Session& session, const Message& message, Date_t expiration) { return session.getTransportLayer()->sinkMessage(session, message, expiration); diff --git a/src/mongo/transport/transport_layer_manager.h b/src/mongo/transport/transport_layer_manager.h index aeed86edcfe..20d27d6571c 100644 --- a/src/mongo/transport/transport_layer_manager.h +++ b/src/mongo/transport/transport_layer_manager.h @@ -54,10 +54,10 @@ class TransportLayerManager final : public TransportLayer { public: TransportLayerManager(); - Ticket sourceMessage(const Session& session, + Ticket sourceMessage(Session& session, Message* message, Date_t expiration = Ticket::kNoExpirationDate) override; - Ticket sinkMessage(const Session& session, + Ticket sinkMessage(Session& session, const Message& message, Date_t expiration = Ticket::kNoExpirationDate) override; diff --git a/src/mongo/transport/transport_layer_mock.cpp b/src/mongo/transport/transport_layer_mock.cpp index 3f71b5d16e0..e7fa76d2e9b 100644 --- a/src/mongo/transport/transport_layer_mock.cpp +++ b/src/mongo/transport/transport_layer_mock.cpp @@ -64,9 +64,7 @@ boost::optional<Message*> TransportLayerMock::TicketMock::msg() const { TransportLayerMock::TransportLayerMock() : _shutdown(false) {} -Ticket TransportLayerMock::sourceMessage(const Session& session, - Message* message, - Date_t expiration) { +Ticket TransportLayerMock::sourceMessage(Session& session, Message* message, Date_t expiration) { if (inShutdown()) { return Ticket(TransportLayer::ShutdownStatus); } else if (!owns(session.id())) { @@ -79,7 +77,7 @@ Ticket TransportLayerMock::sourceMessage(const Session& session, stdx::make_unique<TransportLayerMock::TicketMock>(&session, message, expiration)); } -Ticket TransportLayerMock::sinkMessage(const Session& session, +Ticket TransportLayerMock::sinkMessage(Session& session, const Message& message, Date_t expiration) { if (inShutdown()) { diff --git a/src/mongo/transport/transport_layer_mock.h b/src/mongo/transport/transport_layer_mock.h index f519713e9bc..38ab3eed0f1 100644 --- a/src/mongo/transport/transport_layer_mock.h +++ b/src/mongo/transport/transport_layer_mock.h @@ -76,10 +76,10 @@ public: TransportLayerMock(); ~TransportLayerMock(); - Ticket sourceMessage(const Session& session, + Ticket sourceMessage(Session& session, Message* message, Date_t expiration = Ticket::kNoExpirationDate) override; - Ticket sinkMessage(const Session& session, + Ticket sinkMessage(Session& session, const Message& message, Date_t expiration = Ticket::kNoExpirationDate) override; |