/** * Copyright (C) 2018-present MongoDB, Inc. * * This program is free software: you can redistribute it and/or modify * it under the terms of the Server Side Public License, version 1, * as published by MongoDB, Inc. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * Server Side Public License for more details. * * You should have received a copy of the Server Side Public License * along with this program. If not, see * . * * As a special exception, the copyright holders give permission to link the * code of portions of this program with the OpenSSL library under certain * conditions as described in each individual source file and distribute * linked combinations including the program with the OpenSSL library. You * must comply with the Server Side Public License in all respects for * all of the code used other than as permitted herein. If you modify file(s) * with this exception, you may extend this exception to your version of the * file(s), but you are not obligated to do so. If you do not wish to do so, * delete this exception statement from your version. If you delete this * exception statement from all source files in the program, then also delete * it in the license file. */ #include "mongo/platform/basic.h" #include #include #include #include "mongo/bson/bsonobjbuilder.h" #include "mongo/rpc/message.h" #include "mongo/transport/message_compressor_manager.h" #include "mongo/transport/message_compressor_noop.h" #include "mongo/transport/message_compressor_registry.h" #include "mongo/transport/message_compressor_snappy.h" #include "mongo/transport/message_compressor_zlib.h" #include "mongo/transport/message_compressor_zstd.h" #include "mongo/unittest/unittest.h" #define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kTest namespace mongo { namespace { const auto assertOk = [](auto&& sw) { ASSERT(sw.isOK()); return sw.getValue(); }; MessageCompressorRegistry buildRegistry() { MessageCompressorRegistry ret; auto compressor = std::make_unique(); std::vector compressorList = {compressor->getName()}; ret.setSupportedCompressors(std::move(compressorList)); ret.registerImplementation(std::move(compressor)); ret.finalizeSupportedCompressors().transitional_ignore(); return ret; } void checkNegotiationResult(const BSONObj& result, const std::vector& algos) { auto compressorsList = result.getField("compression"); if (algos.empty()) { ASSERT_TRUE(compressorsList.eoo()); return; } ASSERT_TRUE(!compressorsList.eoo()); ASSERT_TRUE(compressorsList.isABSONObj()); auto compressorsListObj = compressorsList.Obj(); std::vector resultAlgos; for (const auto& e : compressorsListObj) { resultAlgos.push_back(e.checkAndGetStringData().toString()); } ASSERT_EQ(algos.size(), resultAlgos.size()); for (size_t i = 0; i < algos.size(); i++) { ASSERT_EQ(algos[i], resultAlgos[i]); } } void checkServerNegotiation(const boost::optional>& input, const std::vector& expected) { auto registry = buildRegistry(); MessageCompressorManager manager(®istry); BSONObjBuilder serverOutput; manager.serverNegotiate(input, &serverOutput); checkNegotiationResult(serverOutput.done(), expected); } void checkFidelity(const Message& msg, std::unique_ptr compressor) { MessageCompressorRegistry registry; const auto originalView = msg.singleData(); const auto compressorName = compressor->getName(); std::vector compressorList = {compressorName}; registry.setSupportedCompressors(std::move(compressorList)); registry.registerImplementation(std::move(compressor)); registry.finalizeSupportedCompressors().transitional_ignore(); MessageCompressorManager mgr(®istry); BSONObjBuilder negotiatorOut; std::vector negotiator({compressorName}); mgr.serverNegotiate(negotiator, &negotiatorOut); checkNegotiationResult(negotiatorOut.done(), {compressorName}); auto swm = mgr.compressMessage(msg); ASSERT_OK(swm.getStatus()); auto compressedMsg = std::move(swm.getValue()); const auto compressedMsgView = compressedMsg.singleData(); ASSERT_EQ(compressedMsgView.getId(), originalView.getId()); ASSERT_EQ(compressedMsgView.getResponseToMsgId(), originalView.getResponseToMsgId()); ASSERT_EQ(compressedMsgView.getNetworkOp(), dbCompressed); swm = mgr.decompressMessage(compressedMsg); ASSERT_OK(swm.getStatus()); auto decompressedMsg = std::move(swm.getValue()); const auto decompressedMsgView = decompressedMsg.singleData(); ASSERT_EQ(decompressedMsgView.getId(), originalView.getId()); ASSERT_EQ(decompressedMsgView.getResponseToMsgId(), originalView.getResponseToMsgId()); ASSERT_EQ(decompressedMsgView.getNetworkOp(), originalView.getNetworkOp()); ASSERT_EQ(decompressedMsgView.getLen(), originalView.getLen()); ASSERT_EQ(memcmp(decompressedMsgView.data(), originalView.data(), originalView.dataLen()), 0); } void checkOverflow(std::unique_ptr compressor) { // This is our test data that we're going to try to compress/decompress into a buffer that's // way too small. const std::string data = "We embrace reality. We apply high-quality thinking and rigor." "We have courage in our convictions but work hard to ensure biases " "or personal beliefs do not get in the way of finding the best solution."; ConstDataRange input(data.data(), data.size()); // This is our tiny buffer that should cause an error. std::array smallBuffer; DataRange smallOutput(smallBuffer.data(), smallBuffer.size()); // This is a normal sized buffer that we can store a compressed version of our test data safely std::vector normalBuffer; normalBuffer.resize(compressor->getMaxCompressedSize(data.size())); auto sws = compressor->compressData(input, DataRange(normalBuffer.data(), normalBuffer.size())); ASSERT_OK(sws); DataRange normalRange = DataRange(normalBuffer.data(), sws.getValue()); // Check that compressing the test data into a small buffer fails ASSERT_NOT_OK(compressor->compressData(input, smallOutput)); // Check that decompressing compressed test data into a small buffer fails ASSERT_NOT_OK(compressor->decompressData(normalRange, smallOutput)); // Check that decompressing a valid buffer that's missing data doesn't overflow the // source buffer. std::vector scratch; scratch.resize(data.size()); ConstDataRange tooSmallRange(normalBuffer.data(), normalBuffer.size() / 2); ASSERT_NOT_OK( compressor->decompressData(tooSmallRange, DataRange(scratch.data(), scratch.size()))); } Message buildMessage() { const auto data = std::string{"Hello, world!"}; const auto bufferSize = MsgData::MsgDataHeaderSize + data.size(); auto buf = SharedBuffer::allocate(bufferSize); MsgData::View testView(buf.get()); testView.setId(123456); testView.setResponseToMsgId(654321); testView.setOperation(dbQuery); testView.setLen(bufferSize); memcpy(testView.data(), data.data(), data.size()); return Message{buf}; } TEST(MessageCompressorManager, NoCompressionRequested) { auto input = BSON("isMaster" << 1); checkServerNegotiation(boost::none, {}); } TEST(MessageCompressorManager, NormalCompressionRequested) { std::vector input{"noop"_sd}; checkServerNegotiation(input, {"noop"}); } TEST(MessageCompressorManager, BadCompressionRequested) { std::vector input{"fakecompressor"_sd}; checkServerNegotiation(input, {}); } TEST(MessageCompressorManager, BadAndGoodCompressionRequested) { std::vector input{"fakecompressor"_sd, "noop"_sd}; checkServerNegotiation(input, {"noop"}); } // Transitional: Parse BSON "isMaster"-like docs for compressor lists. boost::optional> parseBSON(BSONObj input) { auto elem = input["compression"]; if (!elem) { return boost::none; } uassert(ErrorCodes::BadValue, str::stream() << "'compression' is not an array: " << elem, elem.type() == Array); std::vector ret; for (const auto& e : elem.Obj()) { uassert(ErrorCodes::BadValue, str::stream() << "'compression' element is not a string: " << e, e.type() == String); ret.push_back(e.valueStringData()); } return ret; } TEST(MessageCompressorManager, FullNormalCompression) { auto registry = buildRegistry(); MessageCompressorManager clientManager(®istry); MessageCompressorManager serverManager(®istry); BSONObjBuilder clientOutput; clientManager.clientBegin(&clientOutput); auto clientObj = clientOutput.done(); checkNegotiationResult(clientObj, {"noop"}); BSONObjBuilder serverOutput; serverManager.serverNegotiate(parseBSON(clientObj), &serverOutput); auto serverObj = serverOutput.done(); checkNegotiationResult(serverObj, {"noop"}); clientManager.clientFinish(serverObj); } TEST(NoopMessageCompressor, Fidelity) { auto testMessage = buildMessage(); checkFidelity(testMessage, std::make_unique()); } TEST(SnappyMessageCompressor, Fidelity) { auto testMessage = buildMessage(); checkFidelity(testMessage, std::make_unique()); } TEST(ZlibMessageCompressor, Fidelity) { auto testMessage = buildMessage(); checkFidelity(testMessage, std::make_unique()); } TEST(ZstdMessageCompressor, Fidelity) { auto testMessage = buildMessage(); checkFidelity(testMessage, std::make_unique()); } TEST(SnappyMessageCompressor, Overflow) { checkOverflow(std::make_unique()); } TEST(ZlibMessageCompressor, Overflow) { checkOverflow(std::make_unique()); } TEST(ZstdMessageCompressor, Overflow) { checkOverflow(std::make_unique()); } TEST(MessageCompressorManager, SERVER_28008) { // Create a client and server that will negotiate the same compressors, // but with a different ordering for the preferred compressor. std::unique_ptr zstdCompressor = std::make_unique(); const auto zstdId = zstdCompressor->getId(); std::unique_ptr zlibCompressor = std::make_unique(); const auto zlibId = zlibCompressor->getId(); std::unique_ptr snappyCompressor = std::make_unique(); const auto snappyId = snappyCompressor->getId(); MessageCompressorRegistry registry; registry.setSupportedCompressors( {snappyCompressor->getName(), zlibCompressor->getName(), zstdCompressor->getName()}); registry.registerImplementation(std::move(zlibCompressor)); registry.registerImplementation(std::move(zstdCompressor)); registry.registerImplementation(std::move(snappyCompressor)); ASSERT_OK(registry.finalizeSupportedCompressors()); MessageCompressorManager clientManager(®istry); MessageCompressorManager serverManager(®istry); // Do negotiation BSONObjBuilder clientOutput; clientManager.clientBegin(&clientOutput); auto clientObj = clientOutput.done(); BSONObjBuilder serverOutput; serverManager.serverNegotiate(parseBSON(clientObj), &serverOutput); auto serverObj = serverOutput.done(); clientManager.clientFinish(serverObj); // The preferred compressor is snappy. Check that we round trip as snappy by default. auto toSend = buildMessage(); toSend = assertOk(clientManager.compressMessage(toSend, nullptr)); MessageCompressorId compressorId; auto recvd = assertOk(serverManager.decompressMessage(toSend, &compressorId)); ASSERT_EQ(compressorId, snappyId); toSend = assertOk(serverManager.compressMessage(recvd, nullptr)); recvd = assertOk(clientManager.decompressMessage(toSend, &compressorId)); ASSERT_EQ(compressorId, snappyId); // Now, force the client to send as zLib. We should round trip as // zlib if we feed the out compresor id parameter from // decompressMessage back in to compressMessage. toSend = buildMessage(); toSend = assertOk(clientManager.compressMessage(toSend, &zlibId)); recvd = assertOk(serverManager.decompressMessage(toSend, &compressorId)); ASSERT_EQ(compressorId, zlibId); toSend = assertOk(serverManager.compressMessage(recvd, &compressorId)); recvd = assertOk(clientManager.decompressMessage(toSend, &compressorId)); ASSERT_EQ(compressorId, zlibId); // Then, force the client to send as zstd. We should round trip as // zstd if we feed the out compresor id parameter from // decompressMessage back in to compressMessage. toSend = buildMessage(); toSend = assertOk(clientManager.compressMessage(toSend, &zstdId)); recvd = assertOk(serverManager.decompressMessage(toSend, &compressorId)); ASSERT_EQ(compressorId, zstdId); toSend = assertOk(serverManager.compressMessage(recvd, &compressorId)); recvd = assertOk(clientManager.decompressMessage(toSend, &compressorId)); ASSERT_EQ(compressorId, zstdId); } TEST(MessageCompressorManager, MessageSizeTooLarge) { auto registry = buildRegistry(); MessageCompressorManager compManager(®istry); auto badMessageBuffer = SharedBuffer::allocate(128); MsgData::View badMessage(badMessageBuffer.get()); badMessage.setId(1); badMessage.setResponseToMsgId(0); badMessage.setOperation(dbCompressed); badMessage.setLen(128); DataRangeCursor cursor(badMessage.data(), badMessage.data() + badMessage.dataLen()); cursor.writeAndAdvance>(dbQuery); cursor.writeAndAdvance>(MaxMessageSizeBytes + 1); cursor.writeAndAdvance>(registry.getCompressor("noop")->getId()); auto status = compManager.decompressMessage(Message(badMessageBuffer), nullptr).getStatus(); ASSERT_NOT_OK(status); } TEST(MessageCompressorManager, MessageSizeMax32Bit) { auto registry = buildRegistry(); MessageCompressorManager compManager(®istry); auto badMessageBuffer = SharedBuffer::allocate(128); MsgData::View badMessage(badMessageBuffer.get()); badMessage.setId(1); badMessage.setResponseToMsgId(0); badMessage.setOperation(dbCompressed); badMessage.setLen(128); DataRangeCursor cursor(badMessage.data(), badMessage.data() + badMessage.dataLen()); cursor.writeAndAdvance>(dbQuery); cursor.writeAndAdvance>(std::numeric_limits::max()); cursor.writeAndAdvance>(registry.getCompressor("noop")->getId()); auto status = compManager.decompressMessage(Message(badMessageBuffer), nullptr).getStatus(); ASSERT_NOT_OK(status); } TEST(MessageCompressorManager, MessageSizeTooSmall) { auto registry = buildRegistry(); MessageCompressorManager compManager(®istry); auto badMessageBuffer = SharedBuffer::allocate(128); MsgData::View badMessage(badMessageBuffer.get()); badMessage.setId(1); badMessage.setResponseToMsgId(0); badMessage.setOperation(dbCompressed); badMessage.setLen(128); DataRangeCursor cursor(badMessage.data(), badMessage.data() + badMessage.dataLen()); cursor.writeAndAdvance>(dbQuery); cursor.writeAndAdvance>(-1); cursor.writeAndAdvance>(registry.getCompressor("noop")->getId()); auto status = compManager.decompressMessage(Message(badMessageBuffer), nullptr).getStatus(); ASSERT_NOT_OK(status); } TEST(MessageCompressorManager, RuntMessage) { auto registry = buildRegistry(); MessageCompressorManager compManager(®istry); auto badMessageBuffer = SharedBuffer::allocate(128); MsgData::View badMessage(badMessageBuffer.get()); badMessage.setId(1); badMessage.setResponseToMsgId(0); badMessage.setOperation(dbCompressed); badMessage.setLen(MsgData::MsgDataHeaderSize + 8); // This is a totally bogus compression header of just the orginal opcode + 0 byte uncompressed // size DataRangeCursor cursor(badMessage.data(), badMessage.data() + badMessage.dataLen()); cursor.writeAndAdvance>(dbQuery); cursor.writeAndAdvance>(0); auto status = compManager.decompressMessage(Message(badMessageBuffer), nullptr).getStatus(); ASSERT_NOT_OK(status); } } // namespace } // namespace mongo