summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/mongo/transport/SConscript2
-rw-r--r--src/mongo/transport/message_compressor_manager.cpp30
-rw-r--r--src/mongo/transport/message_compressor_manager_test.cpp89
-rw-r--r--src/mongo/transport/message_compressor_snappy.cpp136
4 files changed, 241 insertions, 16 deletions
diff --git a/src/mongo/transport/SConscript b/src/mongo/transport/SConscript
index 857ca9d9dfb..a4ca32867d6 100644
--- a/src/mongo/transport/SConscript
+++ b/src/mongo/transport/SConscript
@@ -176,7 +176,7 @@ env.CppUnitTest(
zlibEnv = env.Clone()
-zlibEnv.InjectThirdPartyIncludePaths(libraries=['zlib'])
+zlibEnv.InjectThirdPartyIncludePaths(libraries=['zlib', 'snappy'])
zlibEnv.Library(
target='message_compressor',
source=[
diff --git a/src/mongo/transport/message_compressor_manager.cpp b/src/mongo/transport/message_compressor_manager.cpp
index e889bbefc9d..5e911ba0a8d 100644
--- a/src/mongo/transport/message_compressor_manager.cpp
+++ b/src/mongo/transport/message_compressor_manager.cpp
@@ -52,23 +52,31 @@ struct CompressionHeader {
uint8_t compressorId;
void serialize(DataRangeCursor* cursor) {
- cursor->writeAndAdvance<LittleEndian<int32_t>>(originalOpCode).transitional_ignore();
- cursor->writeAndAdvance<LittleEndian<int32_t>>(uncompressedSize).transitional_ignore();
- cursor->writeAndAdvance<LittleEndian<uint8_t>>(compressorId).transitional_ignore();
+ uassertStatusOK(cursor->writeAndAdvance<LittleEndian<int32_t>>(originalOpCode));
+ uassertStatusOK(cursor->writeAndAdvance<LittleEndian<int32_t>>(uncompressedSize));
+ uassertStatusOK(cursor->writeAndAdvance<LittleEndian<uint8_t>>(compressorId));
}
CompressionHeader(int32_t _opcode, int32_t _size, uint8_t _id)
: originalOpCode{_opcode}, uncompressedSize{_size}, compressorId{_id} {}
CompressionHeader(ConstDataRangeCursor* cursor) {
- originalOpCode = cursor->readAndAdvance<LittleEndian<std::int32_t>>().getValue();
- uncompressedSize = cursor->readAndAdvance<LittleEndian<std::int32_t>>().getValue();
- compressorId = cursor->readAndAdvance<LittleEndian<uint8_t>>().getValue();
+ originalOpCode = _readWithChecking<LittleEndian<std::int32_t>>(cursor);
+ uncompressedSize = _readWithChecking<LittleEndian<std::int32_t>>(cursor);
+ compressorId = _readWithChecking<LittleEndian<uint8_t>>(cursor);
}
static size_t size() {
return sizeof(originalOpCode) + sizeof(uncompressedSize) + sizeof(compressorId);
}
+
+private:
+ template <typename T>
+ T _readWithChecking(ConstDataRangeCursor* cursor) {
+ auto sw = cursor->readAndAdvance<T>();
+ uassertStatusOK(sw.getStatus());
+ return sw.getValue();
+ }
};
const transport::Session::Decoration<MessageCompressorManager> getForSession =
@@ -136,6 +144,9 @@ StatusWith<Message> MessageCompressorManager::decompressMessage(const Message& m
MessageCompressorId* compressorId) {
auto inputHeader = msg.header();
ConstDataRangeCursor input(inputHeader.data(), inputHeader.data() + inputHeader.dataLen());
+ if (input.length() < CompressionHeader::size()) {
+ return {ErrorCodes::BadValue, "Invalid compressed message header"};
+ }
CompressionHeader compressionHeader(&input);
auto compressor = _registry->getCompressor(compressionHeader.compressorId);
@@ -150,7 +161,12 @@ StatusWith<Message> MessageCompressorManager::decompressMessage(const Message& m
LOG(3) << "Decompressing message with " << compressor->getName();
- auto bufferSize = compressionHeader.uncompressedSize + MsgData::MsgDataHeaderSize;
+ size_t bufferSize = compressionHeader.uncompressedSize + MsgData::MsgDataHeaderSize;
+ if (bufferSize > MaxMessageSizeBytes) {
+ return {ErrorCodes::BadValue,
+ "Decompressed message would be larger than maximum message size"};
+ }
+
auto outputMessageBuffer = SharedBuffer::allocate(bufferSize);
MsgData::View outMessage(outputMessageBuffer.get());
outMessage.setId(inputHeader.getId());
diff --git a/src/mongo/transport/message_compressor_manager_test.cpp b/src/mongo/transport/message_compressor_manager_test.cpp
index 7d6b64bc58c..79eba2e7263 100644
--- a/src/mongo/transport/message_compressor_manager_test.cpp
+++ b/src/mongo/transport/message_compressor_manager_test.cpp
@@ -26,6 +26,8 @@
* it in the license file.
*/
+#define MONGO_LOG_DEFAULT_COMPONENT ::mongo::logger::LogComponent::kNetwork
+
#include "mongo/platform/basic.h"
#include "mongo/bson/bsonobjbuilder.h"
@@ -36,6 +38,7 @@
#include "mongo/transport/message_compressor_snappy.h"
#include "mongo/transport/message_compressor_zlib.h"
#include "mongo/unittest/unittest.h"
+#include "mongo/util/log.h"
#include "mongo/util/net/message.h"
#include <string>
@@ -128,6 +131,40 @@ void checkFidelity(const Message& msg, std::unique_ptr<MessageCompressorBase> co
ASSERT_EQ(memcmp(decompressedMsgView.data(), originalView.data(), originalView.dataLen()), 0);
}
+void checkOverflow(std::unique_ptr<MessageCompressorBase> compressor) {
+ // This is our test data that we're going to try to compress/decompress into a buffer that's
+ // way too small.
+ const auto data = std::string{
+ "We embrace reality. We apply high-quality thinking and rigor."
+ "We have courage in our convictions but work hard to ensure biases "
+ "or personal beliefs do not get in the way of finding the best solution."};
+ ConstDataRange input(data.data(), data.size());
+
+ // This is our tiny buffer that should cause an error.
+ std::array<char, 16> smallBuffer;
+ DataRange smallOutput(smallBuffer.data(), smallBuffer.size());
+
+ // This is a normal sized buffer that we can store a compressed version of our test data safely
+ std::vector<char> normalBuffer;
+ normalBuffer.resize(compressor->getMaxCompressedSize(data.size()));
+ DataRange normalRange(normalBuffer.data(), normalBuffer.size());
+ ASSERT_OK(compressor->compressData(input, normalRange));
+
+ // Check that compressing the test data into a small buffer fails
+ ASSERT_NOT_OK(compressor->compressData(input, smallOutput));
+
+ // Check that decompressing compressed test data into a small buffer fails
+ ASSERT_NOT_OK(compressor->decompressData(normalRange, smallOutput));
+
+ // Check that decompressing a valid buffer that's missing data doesn't overflow the
+ // source buffer.
+ std::vector<char> scratch;
+ scratch.resize(data.size());
+ ConstDataRange tooSmallRange(normalBuffer.data(), normalBuffer.size() / 2);
+ ASSERT_NOT_OK(
+ compressor->decompressData(tooSmallRange, DataRange(scratch.data(), scratch.size())));
+}
+
Message buildMessage() {
const auto data = std::string{"Hello, world!"};
const auto bufferSize = MsgData::MsgDataHeaderSize + data.size();
@@ -195,6 +232,14 @@ TEST(ZlibMessageCompressor, Fidelity) {
checkFidelity(testMessage, stdx::make_unique<ZlibMessageCompressor>());
}
+TEST(SnappyMessageCompressor, Overflow) {
+ checkOverflow(stdx::make_unique<SnappyMessageCompressor>());
+}
+
+TEST(ZlibMessageCompressor, Overflow) {
+ checkOverflow(stdx::make_unique<ZlibMessageCompressor>());
+}
+
TEST(MessageCompressorManager, SERVER_28008) {
// Create a client and server that will negotiate the same compressors,
@@ -248,5 +293,47 @@ TEST(MessageCompressorManager, SERVER_28008) {
ASSERT_EQ(compressorId, zlibId);
}
-} // namespace mongo
+TEST(MessageCompressorManager, MessageSizeTooLarge) {
+ auto registry = buildRegistry();
+ MessageCompressorManager compManager(&registry);
+
+ auto badMessageBuffer = SharedBuffer::allocate(128);
+ MsgData::View badMessage(badMessageBuffer.get());
+ badMessage.setId(1);
+ badMessage.setResponseToMsgId(0);
+ badMessage.setOperation(dbCompressed);
+ badMessage.setLen(128);
+
+ DataRangeCursor cursor(badMessage.data(), badMessage.data() + badMessage.dataLen());
+ uassertStatusOK(cursor.writeAndAdvance<LittleEndian<int32_t>>(dbQuery));
+ uassertStatusOK(cursor.writeAndAdvance<LittleEndian<int32_t>>(MaxMessageSizeBytes + 1));
+ uassertStatusOK(
+ cursor.writeAndAdvance<LittleEndian<uint8_t>>(registry.getCompressor("noop")->getId()));
+
+ auto status = compManager.decompressMessage(Message(badMessageBuffer), nullptr).getStatus();
+ ASSERT_NOT_OK(status);
+}
+
+TEST(MessageCompressorManager, RuntMessage) {
+ auto registry = buildRegistry();
+ MessageCompressorManager compManager(&registry);
+
+ auto badMessageBuffer = SharedBuffer::allocate(128);
+ MsgData::View badMessage(badMessageBuffer.get());
+ badMessage.setId(1);
+ badMessage.setResponseToMsgId(0);
+ badMessage.setOperation(dbCompressed);
+ badMessage.setLen(MsgData::MsgDataHeaderSize + 8);
+
+ // This is a totally bogus compression header of just the orginal opcode + 0 byte uncompressed
+ // size
+ DataRangeCursor cursor(badMessage.data(), badMessage.data() + badMessage.dataLen());
+ uassertStatusOK(cursor.writeAndAdvance<LittleEndian<int32_t>>(dbQuery));
+ uassertStatusOK(cursor.writeAndAdvance<LittleEndian<int32_t>>(0));
+
+ auto status = compManager.decompressMessage(Message(badMessageBuffer), nullptr).getStatus();
+ ASSERT_NOT_OK(status);
+}
+
} // namespace
+} // namespace mongo
diff --git a/src/mongo/transport/message_compressor_snappy.cpp b/src/mongo/transport/message_compressor_snappy.cpp
index db1e0c9dfca..c722513f452 100644
--- a/src/mongo/transport/message_compressor_snappy.cpp
+++ b/src/mongo/transport/message_compressor_snappy.cpp
@@ -30,26 +30,136 @@
#include "mongo/platform/basic.h"
+#include "mongo/base/data_range_cursor.h"
#include "mongo/base/init.h"
#include "mongo/stdx/memory.h"
#include "mongo/transport/message_compressor_registry.h"
#include "mongo/transport/message_compressor_snappy.h"
-#include "third_party/snappy-1.1.3/snappy.h"
+#include <snappy-sinksource.h>
+#include <snappy.h>
namespace mongo {
+namespace {
+class SnappySourceSinkException : public DBException {
+public:
+ SnappySourceSinkException(Status status) : DBException(status) {}
+};
+
+// This is a bounds-checking version of snappy::UncheckedByteArraySink.
+//
+// If the amount of scratch buffer space requested by snappy is larger than the sink
+// buffer, than it will allocate a new temporary buffer so that snappy can finish.
+// If the amount of scratch buffer space requested by snappy is less than or equal to
+// the size of the sink buffer, than it will just return the sink buffer.
+//
+// If the scratch buffer is the sink buffer, than Append will just advance the buffer
+// cursor and do bounds checking without any copying.
+//
+// Appending data past the end of the sink buffer will throw a SnappySourcesinkexception.
+class DataRangeSink final : public snappy::Sink {
+public:
+ DataRangeSink(DataRange buffer) : _cursor(buffer) {}
+
+ char* GetAppendBuffer(size_t length, char* scratch) final {
+ if (length > _cursor.length()) {
+ _scratch.resize(length);
+ return _scratch.data();
+ }
+
+ return const_cast<char*>(_cursor.data());
+ }
+
+ void AppendAndTakeOwnership(char* data,
+ size_t n,
+ void (*deleter)(void*, const char*, size_t),
+ void* deleterArg) final {
+ Append(data, n);
+ if (data != _cursor.data()) {
+ (*deleter)(deleterArg, data, n);
+ }
+ }
+
+ void Append(const char* bytes, size_t n) final {
+ Status status = Status::OK();
+ if (bytes == _cursor.data()) {
+ status = _cursor.advance(n);
+ } else {
+ ConstDataRange toWrite(bytes, n);
+ status = _cursor.writeAndAdvance(toWrite);
+ }
+ if (!status.isOK()) {
+ throw SnappySourceSinkException(std::move(status));
+ }
+ }
+
+ char* GetAppendBufferVariable(size_t minSize,
+ size_t desiredSizeHint,
+ char* scratch,
+ size_t scratchSize,
+ size_t* allocatedSize) {
+ if (desiredSizeHint > _cursor.length() || minSize > _cursor.length()) {
+ _scratch.resize(desiredSizeHint);
+ *allocatedSize = _scratch.size();
+ return _scratch.data();
+ }
+
+ *allocatedSize = _cursor.length();
+ return const_cast<char*>(_cursor.data());
+ }
+
+private:
+ DataRangeCursor _cursor;
+ std::vector<char> _scratch;
+};
+
+class ConstDataRangeSource final : public snappy::Source {
+public:
+ ConstDataRangeSource(ConstDataRange buffer) : _cursor(buffer) {}
+
+ size_t Available() const final {
+ return _cursor.length();
+ }
+
+ const char* Peek(size_t* len) final {
+ *len = _cursor.length();
+ return _cursor.data();
+ }
+
+ void Skip(size_t n) final {
+ auto status = _cursor.advance(n);
+ if (!status.isOK()) {
+ throw SnappySourceSinkException(std::move(status));
+ }
+ }
+
+private:
+ ConstDataRangeCursor _cursor;
+};
+
+} // namespace
SnappyMessageCompressor::SnappyMessageCompressor()
: MessageCompressorBase(MessageCompressor::kSnappy) {}
std::size_t SnappyMessageCompressor::getMaxCompressedSize(size_t inputSize) {
- return snappy::MaxCompressedLength(inputSize);
+ // Testing has shown that snappy typically requests two additional bytes of buffer space when
+ // compressing beyond what snappy::MaxCompressedLength returns. So by padding this by 2 more
+ // bytes, we can avoid additional allocations/copies during compression.
+ return snappy::MaxCompressedLength(inputSize) + 2;
}
StatusWith<std::size_t> SnappyMessageCompressor::compressData(ConstDataRange input,
DataRange output) {
size_t outLength;
- snappy::RawCompress(input.data(), input.length(), const_cast<char*>(output.data()), &outLength);
+ ConstDataRangeSource source(input);
+ DataRangeSink sink(output);
+
+ try {
+ outLength = snappy::Compress(&source, &sink);
+ } catch (const SnappySourceSinkException& e) {
+ return e.toStatus();
+ }
counterHitCompress(input.length(), outLength);
return {outLength};
@@ -57,11 +167,23 @@ StatusWith<std::size_t> SnappyMessageCompressor::compressData(ConstDataRange inp
StatusWith<std::size_t> SnappyMessageCompressor::decompressData(ConstDataRange input,
DataRange output) {
- bool ret =
- snappy::RawUncompress(input.data(), input.length(), const_cast<char*>(output.data()));
+ try {
+ uint32_t expectedLength = 0;
+ ConstDataRangeSource lengthCheckSource(input);
+ if (!snappy::GetUncompressedLength(&lengthCheckSource, &expectedLength) ||
+ expectedLength > output.length()) {
+ return {ErrorCodes::BadValue, "Compressed message was invalid or corrupted"};
+ }
+
+ ConstDataRangeSource source(input);
+ DataRangeSink sink(output);
- if (!ret) {
- return Status{ErrorCodes::BadValue, "Compressed message was invalid or corrupted"};
+ bool ret = snappy::Uncompress(&source, &sink);
+ if (!ret) {
+ return Status{ErrorCodes::BadValue, "Compressed message was invalid or corrupted"};
+ }
+ } catch (const SnappySourceSinkException& e) {
+ return e.toStatus();
}
counterHitDecompress(input.length(), output.length());