diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/mongo/transport/SConscript | 2 | ||||
-rw-r--r-- | src/mongo/transport/message_compressor_manager.cpp | 30 | ||||
-rw-r--r-- | src/mongo/transport/message_compressor_manager_test.cpp | 89 | ||||
-rw-r--r-- | src/mongo/transport/message_compressor_snappy.cpp | 136 |
4 files changed, 241 insertions, 16 deletions
diff --git a/src/mongo/transport/SConscript b/src/mongo/transport/SConscript index 857ca9d9dfb..a4ca32867d6 100644 --- a/src/mongo/transport/SConscript +++ b/src/mongo/transport/SConscript @@ -176,7 +176,7 @@ env.CppUnitTest( zlibEnv = env.Clone() -zlibEnv.InjectThirdPartyIncludePaths(libraries=['zlib']) +zlibEnv.InjectThirdPartyIncludePaths(libraries=['zlib', 'snappy']) zlibEnv.Library( target='message_compressor', source=[ diff --git a/src/mongo/transport/message_compressor_manager.cpp b/src/mongo/transport/message_compressor_manager.cpp index e889bbefc9d..5e911ba0a8d 100644 --- a/src/mongo/transport/message_compressor_manager.cpp +++ b/src/mongo/transport/message_compressor_manager.cpp @@ -52,23 +52,31 @@ struct CompressionHeader { uint8_t compressorId; void serialize(DataRangeCursor* cursor) { - cursor->writeAndAdvance<LittleEndian<int32_t>>(originalOpCode).transitional_ignore(); - cursor->writeAndAdvance<LittleEndian<int32_t>>(uncompressedSize).transitional_ignore(); - cursor->writeAndAdvance<LittleEndian<uint8_t>>(compressorId).transitional_ignore(); + uassertStatusOK(cursor->writeAndAdvance<LittleEndian<int32_t>>(originalOpCode)); + uassertStatusOK(cursor->writeAndAdvance<LittleEndian<int32_t>>(uncompressedSize)); + uassertStatusOK(cursor->writeAndAdvance<LittleEndian<uint8_t>>(compressorId)); } CompressionHeader(int32_t _opcode, int32_t _size, uint8_t _id) : originalOpCode{_opcode}, uncompressedSize{_size}, compressorId{_id} {} CompressionHeader(ConstDataRangeCursor* cursor) { - originalOpCode = cursor->readAndAdvance<LittleEndian<std::int32_t>>().getValue(); - uncompressedSize = cursor->readAndAdvance<LittleEndian<std::int32_t>>().getValue(); - compressorId = cursor->readAndAdvance<LittleEndian<uint8_t>>().getValue(); + originalOpCode = _readWithChecking<LittleEndian<std::int32_t>>(cursor); + uncompressedSize = _readWithChecking<LittleEndian<std::int32_t>>(cursor); + compressorId = _readWithChecking<LittleEndian<uint8_t>>(cursor); } static size_t size() { return sizeof(originalOpCode) + sizeof(uncompressedSize) + sizeof(compressorId); } + +private: + template <typename T> + T _readWithChecking(ConstDataRangeCursor* cursor) { + auto sw = cursor->readAndAdvance<T>(); + uassertStatusOK(sw.getStatus()); + return sw.getValue(); + } }; const transport::Session::Decoration<MessageCompressorManager> getForSession = @@ -136,6 +144,9 @@ StatusWith<Message> MessageCompressorManager::decompressMessage(const Message& m MessageCompressorId* compressorId) { auto inputHeader = msg.header(); ConstDataRangeCursor input(inputHeader.data(), inputHeader.data() + inputHeader.dataLen()); + if (input.length() < CompressionHeader::size()) { + return {ErrorCodes::BadValue, "Invalid compressed message header"}; + } CompressionHeader compressionHeader(&input); auto compressor = _registry->getCompressor(compressionHeader.compressorId); @@ -150,7 +161,12 @@ StatusWith<Message> MessageCompressorManager::decompressMessage(const Message& m LOG(3) << "Decompressing message with " << compressor->getName(); - auto bufferSize = compressionHeader.uncompressedSize + MsgData::MsgDataHeaderSize; + size_t bufferSize = compressionHeader.uncompressedSize + MsgData::MsgDataHeaderSize; + if (bufferSize > MaxMessageSizeBytes) { + return {ErrorCodes::BadValue, + "Decompressed message would be larger than maximum message size"}; + } + auto outputMessageBuffer = SharedBuffer::allocate(bufferSize); MsgData::View outMessage(outputMessageBuffer.get()); outMessage.setId(inputHeader.getId()); diff --git a/src/mongo/transport/message_compressor_manager_test.cpp b/src/mongo/transport/message_compressor_manager_test.cpp index 7d6b64bc58c..79eba2e7263 100644 --- a/src/mongo/transport/message_compressor_manager_test.cpp +++ b/src/mongo/transport/message_compressor_manager_test.cpp @@ -26,6 +26,8 @@ * it in the license file. */ +#define MONGO_LOG_DEFAULT_COMPONENT ::mongo::logger::LogComponent::kNetwork + #include "mongo/platform/basic.h" #include "mongo/bson/bsonobjbuilder.h" @@ -36,6 +38,7 @@ #include "mongo/transport/message_compressor_snappy.h" #include "mongo/transport/message_compressor_zlib.h" #include "mongo/unittest/unittest.h" +#include "mongo/util/log.h" #include "mongo/util/net/message.h" #include <string> @@ -128,6 +131,40 @@ void checkFidelity(const Message& msg, std::unique_ptr<MessageCompressorBase> co ASSERT_EQ(memcmp(decompressedMsgView.data(), originalView.data(), originalView.dataLen()), 0); } +void checkOverflow(std::unique_ptr<MessageCompressorBase> compressor) { + // This is our test data that we're going to try to compress/decompress into a buffer that's + // way too small. + const auto data = std::string{ + "We embrace reality. We apply high-quality thinking and rigor." + "We have courage in our convictions but work hard to ensure biases " + "or personal beliefs do not get in the way of finding the best solution."}; + ConstDataRange input(data.data(), data.size()); + + // This is our tiny buffer that should cause an error. + std::array<char, 16> smallBuffer; + DataRange smallOutput(smallBuffer.data(), smallBuffer.size()); + + // This is a normal sized buffer that we can store a compressed version of our test data safely + std::vector<char> normalBuffer; + normalBuffer.resize(compressor->getMaxCompressedSize(data.size())); + DataRange normalRange(normalBuffer.data(), normalBuffer.size()); + ASSERT_OK(compressor->compressData(input, normalRange)); + + // Check that compressing the test data into a small buffer fails + ASSERT_NOT_OK(compressor->compressData(input, smallOutput)); + + // Check that decompressing compressed test data into a small buffer fails + ASSERT_NOT_OK(compressor->decompressData(normalRange, smallOutput)); + + // Check that decompressing a valid buffer that's missing data doesn't overflow the + // source buffer. + std::vector<char> scratch; + scratch.resize(data.size()); + ConstDataRange tooSmallRange(normalBuffer.data(), normalBuffer.size() / 2); + ASSERT_NOT_OK( + compressor->decompressData(tooSmallRange, DataRange(scratch.data(), scratch.size()))); +} + Message buildMessage() { const auto data = std::string{"Hello, world!"}; const auto bufferSize = MsgData::MsgDataHeaderSize + data.size(); @@ -195,6 +232,14 @@ TEST(ZlibMessageCompressor, Fidelity) { checkFidelity(testMessage, stdx::make_unique<ZlibMessageCompressor>()); } +TEST(SnappyMessageCompressor, Overflow) { + checkOverflow(stdx::make_unique<SnappyMessageCompressor>()); +} + +TEST(ZlibMessageCompressor, Overflow) { + checkOverflow(stdx::make_unique<ZlibMessageCompressor>()); +} + TEST(MessageCompressorManager, SERVER_28008) { // Create a client and server that will negotiate the same compressors, @@ -248,5 +293,47 @@ TEST(MessageCompressorManager, SERVER_28008) { ASSERT_EQ(compressorId, zlibId); } -} // namespace mongo +TEST(MessageCompressorManager, MessageSizeTooLarge) { + auto registry = buildRegistry(); + MessageCompressorManager compManager(®istry); + + auto badMessageBuffer = SharedBuffer::allocate(128); + MsgData::View badMessage(badMessageBuffer.get()); + badMessage.setId(1); + badMessage.setResponseToMsgId(0); + badMessage.setOperation(dbCompressed); + badMessage.setLen(128); + + DataRangeCursor cursor(badMessage.data(), badMessage.data() + badMessage.dataLen()); + uassertStatusOK(cursor.writeAndAdvance<LittleEndian<int32_t>>(dbQuery)); + uassertStatusOK(cursor.writeAndAdvance<LittleEndian<int32_t>>(MaxMessageSizeBytes + 1)); + uassertStatusOK( + cursor.writeAndAdvance<LittleEndian<uint8_t>>(registry.getCompressor("noop")->getId())); + + auto status = compManager.decompressMessage(Message(badMessageBuffer), nullptr).getStatus(); + ASSERT_NOT_OK(status); +} + +TEST(MessageCompressorManager, RuntMessage) { + auto registry = buildRegistry(); + MessageCompressorManager compManager(®istry); + + auto badMessageBuffer = SharedBuffer::allocate(128); + MsgData::View badMessage(badMessageBuffer.get()); + badMessage.setId(1); + badMessage.setResponseToMsgId(0); + badMessage.setOperation(dbCompressed); + badMessage.setLen(MsgData::MsgDataHeaderSize + 8); + + // This is a totally bogus compression header of just the orginal opcode + 0 byte uncompressed + // size + DataRangeCursor cursor(badMessage.data(), badMessage.data() + badMessage.dataLen()); + uassertStatusOK(cursor.writeAndAdvance<LittleEndian<int32_t>>(dbQuery)); + uassertStatusOK(cursor.writeAndAdvance<LittleEndian<int32_t>>(0)); + + auto status = compManager.decompressMessage(Message(badMessageBuffer), nullptr).getStatus(); + ASSERT_NOT_OK(status); +} + } // namespace +} // namespace mongo diff --git a/src/mongo/transport/message_compressor_snappy.cpp b/src/mongo/transport/message_compressor_snappy.cpp index db1e0c9dfca..c722513f452 100644 --- a/src/mongo/transport/message_compressor_snappy.cpp +++ b/src/mongo/transport/message_compressor_snappy.cpp @@ -30,26 +30,136 @@ #include "mongo/platform/basic.h" +#include "mongo/base/data_range_cursor.h" #include "mongo/base/init.h" #include "mongo/stdx/memory.h" #include "mongo/transport/message_compressor_registry.h" #include "mongo/transport/message_compressor_snappy.h" -#include "third_party/snappy-1.1.3/snappy.h" +#include <snappy-sinksource.h> +#include <snappy.h> namespace mongo { +namespace { +class SnappySourceSinkException : public DBException { +public: + SnappySourceSinkException(Status status) : DBException(status) {} +}; + +// This is a bounds-checking version of snappy::UncheckedByteArraySink. +// +// If the amount of scratch buffer space requested by snappy is larger than the sink +// buffer, than it will allocate a new temporary buffer so that snappy can finish. +// If the amount of scratch buffer space requested by snappy is less than or equal to +// the size of the sink buffer, than it will just return the sink buffer. +// +// If the scratch buffer is the sink buffer, than Append will just advance the buffer +// cursor and do bounds checking without any copying. +// +// Appending data past the end of the sink buffer will throw a SnappySourcesinkexception. +class DataRangeSink final : public snappy::Sink { +public: + DataRangeSink(DataRange buffer) : _cursor(buffer) {} + + char* GetAppendBuffer(size_t length, char* scratch) final { + if (length > _cursor.length()) { + _scratch.resize(length); + return _scratch.data(); + } + + return const_cast<char*>(_cursor.data()); + } + + void AppendAndTakeOwnership(char* data, + size_t n, + void (*deleter)(void*, const char*, size_t), + void* deleterArg) final { + Append(data, n); + if (data != _cursor.data()) { + (*deleter)(deleterArg, data, n); + } + } + + void Append(const char* bytes, size_t n) final { + Status status = Status::OK(); + if (bytes == _cursor.data()) { + status = _cursor.advance(n); + } else { + ConstDataRange toWrite(bytes, n); + status = _cursor.writeAndAdvance(toWrite); + } + if (!status.isOK()) { + throw SnappySourceSinkException(std::move(status)); + } + } + + char* GetAppendBufferVariable(size_t minSize, + size_t desiredSizeHint, + char* scratch, + size_t scratchSize, + size_t* allocatedSize) { + if (desiredSizeHint > _cursor.length() || minSize > _cursor.length()) { + _scratch.resize(desiredSizeHint); + *allocatedSize = _scratch.size(); + return _scratch.data(); + } + + *allocatedSize = _cursor.length(); + return const_cast<char*>(_cursor.data()); + } + +private: + DataRangeCursor _cursor; + std::vector<char> _scratch; +}; + +class ConstDataRangeSource final : public snappy::Source { +public: + ConstDataRangeSource(ConstDataRange buffer) : _cursor(buffer) {} + + size_t Available() const final { + return _cursor.length(); + } + + const char* Peek(size_t* len) final { + *len = _cursor.length(); + return _cursor.data(); + } + + void Skip(size_t n) final { + auto status = _cursor.advance(n); + if (!status.isOK()) { + throw SnappySourceSinkException(std::move(status)); + } + } + +private: + ConstDataRangeCursor _cursor; +}; + +} // namespace SnappyMessageCompressor::SnappyMessageCompressor() : MessageCompressorBase(MessageCompressor::kSnappy) {} std::size_t SnappyMessageCompressor::getMaxCompressedSize(size_t inputSize) { - return snappy::MaxCompressedLength(inputSize); + // Testing has shown that snappy typically requests two additional bytes of buffer space when + // compressing beyond what snappy::MaxCompressedLength returns. So by padding this by 2 more + // bytes, we can avoid additional allocations/copies during compression. + return snappy::MaxCompressedLength(inputSize) + 2; } StatusWith<std::size_t> SnappyMessageCompressor::compressData(ConstDataRange input, DataRange output) { size_t outLength; - snappy::RawCompress(input.data(), input.length(), const_cast<char*>(output.data()), &outLength); + ConstDataRangeSource source(input); + DataRangeSink sink(output); + + try { + outLength = snappy::Compress(&source, &sink); + } catch (const SnappySourceSinkException& e) { + return e.toStatus(); + } counterHitCompress(input.length(), outLength); return {outLength}; @@ -57,11 +167,23 @@ StatusWith<std::size_t> SnappyMessageCompressor::compressData(ConstDataRange inp StatusWith<std::size_t> SnappyMessageCompressor::decompressData(ConstDataRange input, DataRange output) { - bool ret = - snappy::RawUncompress(input.data(), input.length(), const_cast<char*>(output.data())); + try { + uint32_t expectedLength = 0; + ConstDataRangeSource lengthCheckSource(input); + if (!snappy::GetUncompressedLength(&lengthCheckSource, &expectedLength) || + expectedLength > output.length()) { + return {ErrorCodes::BadValue, "Compressed message was invalid or corrupted"}; + } + + ConstDataRangeSource source(input); + DataRangeSink sink(output); - if (!ret) { - return Status{ErrorCodes::BadValue, "Compressed message was invalid or corrupted"}; + bool ret = snappy::Uncompress(&source, &sink); + if (!ret) { + return Status{ErrorCodes::BadValue, "Compressed message was invalid or corrupted"}; + } + } catch (const SnappySourceSinkException& e) { + return e.toStatus(); } counterHitDecompress(input.length(), output.length()); |