/**
* Copyright (C) 2016 MongoDB Inc.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License, version 3,
* as published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*
* As a special exception, the copyright holders give permission to link the
* code of portions of this program with the OpenSSL library under certain
* conditions as described in each individual source file and distribute
* linked combinations including the program with the OpenSSL library. You
* must comply with the GNU Affero General Public License in all respects for
* all of the code used other than as permitted herein. If you modify file(s)
* with this exception, you may extend this exception to your version of the
* file(s), but you are not obligated to do so. If you do not wish to do so,
* delete this exception statement from your version. If you delete this
* exception statement from all source files in the program, then also delete
* it in the license file.
*/
#define MONGO_LOG_DEFAULT_COMPONENT ::mongo::logger::LogComponent::kNetwork
#include "mongo/platform/basic.h"

#include "mongo/bson/bsonobjbuilder.h"
#include "mongo/stdx/memory.h"
#include "mongo/transport/message_compressor_manager.h"
#include "mongo/transport/message_compressor_noop.h"
#include "mongo/transport/message_compressor_registry.h"
#include "mongo/transport/message_compressor_snappy.h"
#include "mongo/unittest/unittest.h"
#include "mongo/util/log.h"
#include "mongo/util/net/message.h"

#include <array>
#include <cstring>
#include <string>
#include <vector>
namespace mongo {
namespace {
/**
 * Builds a finalized MessageCompressorRegistry whose only supported and registered
 * compressor is a NoopMessageCompressor.
 */
MessageCompressorRegistry buildRegistry() {
    MessageCompressorRegistry ret;
    auto compressor = stdx::make_unique<NoopMessageCompressor>();

    // Read the name before ownership of the compressor is moved into the registry below.
    std::vector<std::string> compressorList = {compressor->getName()};
    ret.setSupportedCompressors(std::move(compressorList));
    ret.registerImplementation(std::move(compressor));
    ret.finalizeSupportedCompressors();

    return ret;
}
void checkNegotiationResult(const BSONObj& result, const std::vector& algos) {
auto compressorsList = result.getField("compression");
if (algos.empty()) {
ASSERT_TRUE(compressorsList.eoo());
return;
}
ASSERT_TRUE(!compressorsList.eoo());
ASSERT_TRUE(compressorsList.isABSONObj());
auto compressorsListObj = compressorsList.Obj();
std::vector resultAlgos;
for (const auto& e : compressorsListObj) {
resultAlgos.push_back(e.checkAndGetStringData().toString());
}
ASSERT_EQ(algos.size(), resultAlgos.size());
for (size_t i = 0; i < algos.size(); i++) {
ASSERT_EQ(algos[i], resultAlgos[i]);
}
}
/**
 * Runs server-side negotiation of `input` against a fresh registry from buildRegistry()
 * and asserts that the reply advertises exactly `expected` (see checkNegotiationResult).
 */
void checkServerNegotiation(const BSONObj& input, const std::vector<std::string>& expected) {
    auto registry = buildRegistry();
    MessageCompressorManager manager(&registry);

    BSONObjBuilder serverOutput;
    manager.serverNegotiate(input, &serverOutput);
    checkNegotiationResult(serverOutput.done(), expected);
}
void checkFidelity(const Message& msg, std::unique_ptr compressor) {
MessageCompressorRegistry registry;
const auto originalView = msg.singleData();
const auto compressorName = compressor->getName();
std::vector compressorList = {compressorName};
registry.setSupportedCompressors(std::move(compressorList));
registry.registerImplementation(std::move(compressor));
registry.finalizeSupportedCompressors();
MessageCompressorManager mgr(®istry);
auto negotiator = BSON("isMaster" << 1 << "compression" << BSON_ARRAY(compressorName));
BSONObjBuilder negotiatorOut;
mgr.serverNegotiate(negotiator, &negotiatorOut);
checkNegotiationResult(negotiatorOut.done(), {compressorName});
auto swm = mgr.compressMessage(msg);
ASSERT_OK(swm.getStatus());
auto compressedMsg = std::move(swm.getValue());
const auto compressedMsgView = compressedMsg.singleData();
ASSERT_EQ(compressedMsgView.getId(), originalView.getId());
ASSERT_EQ(compressedMsgView.getResponseToMsgId(), originalView.getResponseToMsgId());
ASSERT_EQ(compressedMsgView.getNetworkOp(), dbCompressed);
swm = mgr.decompressMessage(compressedMsg);
ASSERT_OK(swm.getStatus());
auto decompressedMsg = std::move(swm.getValue());
const auto decompressedMsgView = decompressedMsg.singleData();
ASSERT_EQ(decompressedMsgView.getId(), originalView.getId());
ASSERT_EQ(decompressedMsgView.getResponseToMsgId(), originalView.getResponseToMsgId());
ASSERT_EQ(decompressedMsgView.getNetworkOp(), originalView.getNetworkOp());
ASSERT_EQ(decompressedMsgView.getLen(), originalView.getLen());
ASSERT_EQ(memcmp(decompressedMsgView.data(), originalView.data(), originalView.dataLen()), 0);
}
void checkOverflow(std::unique_ptr compressor) {
// This is our test data that we're going to try to compress/decompress into a buffer that's
// way too small.
const std::string data =
"We embrace reality. We apply high-quality thinking and rigor."
"We have courage in our convictions but work hard to ensure biases "
"or personal beliefs do not get in the way of finding the best solution.";
ConstDataRange input(data.data(), data.size());
// This is our tiny buffer that should cause an error.
std::array smallBuffer;
DataRange smallOutput(smallBuffer.data(), smallBuffer.size());
// This is a normal sized buffer that we can store a compressed version of our test data safely
std::vector normalBuffer;
normalBuffer.resize(compressor->getMaxCompressedSize(data.size()));
auto sws = compressor->compressData(input, DataRange(normalBuffer.data(), normalBuffer.size()));
ASSERT_OK(sws);
DataRange normalRange = DataRange(normalBuffer.data(), sws.getValue());
// Check that compressing the test data into a small buffer fails
ASSERT_NOT_OK(compressor->compressData(input, smallOutput));
// Check that decompressing compressed test data into a small buffer fails
ASSERT_NOT_OK(compressor->decompressData(normalRange, smallOutput));
// Check that decompressing a valid buffer that's missing data doesn't overflow the
// source buffer.
std::vector scratch;
scratch.resize(data.size());
ConstDataRange tooSmallRange(normalBuffer.data(), normalBuffer.size() / 2);
ASSERT_NOT_OK(
compressor->decompressData(tooSmallRange, DataRange(scratch.data(), scratch.size())));
}
/**
 * Creates a single-buffer dbQuery message with id 123456 and responseTo 654321.
 * Its body is the characters of "Hello, world!" copied without a trailing NUL.
 */
Message buildMessage() {
    const std::string payload("Hello, world!");
    const auto totalSize = MsgData::MsgDataHeaderSize + payload.size();

    auto buffer = SharedBuffer::allocate(totalSize);
    MsgData::View view(buffer.get());

    // Populate the header fields; each setter writes an independent field.
    view.setLen(totalSize);
    view.setId(123456);
    view.setResponseToMsgId(654321);
    view.setOperation(dbQuery);

    memcpy(view.data(), payload.data(), payload.size());
    return Message{buffer};
}
// An isMaster that carries no "compression" field gets no "compression" field back.
TEST(MessageCompressorManager, NoCompressionRequested) {
    checkServerNegotiation(BSON("isMaster" << 1), {});
}
// Requesting the one registered compressor ("noop") gets it echoed back.
TEST(MessageCompressorManager, NormalCompressionRequested) {
    const BSONObj request = BSON("isMaster" << 1 << "compression" << BSON_ARRAY("noop"));
    checkServerNegotiation(request, {"noop"});
}
// Requesting only an unregistered compressor gets no "compression" field back.
TEST(MessageCompressorManager, BadCompressionRequested) {
    checkServerNegotiation(
        BSON("isMaster" << 1 << "compression" << BSON_ARRAY("fakecompressor")), {});
}
// With one unknown and one registered compressor requested, only the registered one
// ("noop") appears in the reply.
TEST(MessageCompressorManager, BadAndGoodCompressionRequested) {
    const auto requested = BSON_ARRAY("fakecompressor"
                                      << "noop");
    const BSONObj request = BSON("isMaster" << 1 << "compression" << requested);
    checkServerNegotiation(request, {"noop"});
}
/**
 * Drives a complete client/server handshake through two managers that share one registry.
 * Checks that both the client's offer and the server's reply list exactly "noop".
 */
TEST(MessageCompressorManager, FullNormalCompression) {
    auto registry = buildRegistry();
    MessageCompressorManager clientManager(&registry);
    MessageCompressorManager serverManager(&registry);

    // Client side: advertise every compressor the registry supports.
    BSONObjBuilder clientOutput;
    clientManager.clientBegin(&clientOutput);
    auto clientObj = clientOutput.done();
    checkNegotiationResult(clientObj, {"noop"});

    // Server side: answer the client's offer.
    BSONObjBuilder serverOutput;
    serverManager.serverNegotiate(clientObj, &serverOutput);
    auto serverObj = serverOutput.done();
    checkNegotiationResult(serverObj, {"noop"});

    // Client side: consume the server's reply to complete the handshake.
    clientManager.clientFinish(serverObj);
}
// Round-trips buildMessage() through the no-op compressor (see checkFidelity).
TEST(NoopMessageCompressor, Fidelity) {
    auto testMessage = buildMessage();
    checkFidelity(testMessage, stdx::make_unique<NoopMessageCompressor>());
}
// Round-trips buildMessage() through the snappy compressor (see checkFidelity).
TEST(SnappyMessageCompressor, Fidelity) {
    auto testMessage = buildMessage();
    checkFidelity(testMessage, stdx::make_unique<SnappyMessageCompressor>());
}
// Checks that the snappy compressor rejects undersized and truncated buffers (see checkOverflow).
TEST(SnappyMessageCompressor, Overflow) {
    checkOverflow(stdx::make_unique<SnappyMessageCompressor>());
}
/**
 * A dbCompressed message whose header claims an uncompressed size of MaxMessageSizeBytes + 1
 * must be rejected by decompressMessage().
 */
TEST(MessageCompressorManager, MessageSizeTooLarge) {
    auto registry = buildRegistry();
    MessageCompressorManager compManager(&registry);

    auto badMessageBuffer = SharedBuffer::allocate(128);
    MsgData::View badMessage(badMessageBuffer.get());
    badMessage.setId(1);
    badMessage.setResponseToMsgId(0);
    badMessage.setOperation(dbCompressed);
    badMessage.setLen(128);

    // Hand-written compression header, in order: original opcode, claimed uncompressed
    // size (one byte over the limit), and the id of the registered "noop" compressor.
    DataRangeCursor cursor(badMessage.data(), badMessage.data() + badMessage.dataLen());
    uassertStatusOK(cursor.writeAndAdvance<LittleEndian<int32_t>>(dbQuery));
    uassertStatusOK(cursor.writeAndAdvance<LittleEndian<int32_t>>(MaxMessageSizeBytes + 1));
    uassertStatusOK(
        cursor.writeAndAdvance<LittleEndian<uint8_t>>(registry.getCompressor("noop")->getId()));

    auto status = compManager.decompressMessage(Message(badMessageBuffer)).getStatus();
    ASSERT_NOT_OK(status);
}
/**
 * A dbCompressed message whose body is too short to hold a complete compression header
 * must be rejected by decompressMessage().
 */
TEST(MessageCompressorManager, RuntMessage) {
    auto registry = buildRegistry();
    MessageCompressorManager compManager(&registry);

    auto badMessageBuffer = SharedBuffer::allocate(128);
    MsgData::View badMessage(badMessageBuffer.get());
    badMessage.setId(1);
    badMessage.setResponseToMsgId(0);
    badMessage.setOperation(dbCompressed);
    // The body is exactly 8 bytes: room for two int32 fields and nothing else.
    badMessage.setLen(MsgData::MsgDataHeaderSize + 8);

    // A deliberately truncated compression header. It holds only the original opcode and a
    // 0-byte uncompressed size, with no compressor id byte after them.
    DataRangeCursor cursor(badMessage.data(), badMessage.data() + badMessage.dataLen());
    uassertStatusOK(cursor.writeAndAdvance<LittleEndian<int32_t>>(dbQuery));
    uassertStatusOK(cursor.writeAndAdvance<LittleEndian<int32_t>>(0));

    auto status = compManager.decompressMessage(Message(badMessageBuffer)).getStatus();
    ASSERT_NOT_OK(status);
}
} // namespace
} // namespace mongo