summaryrefslogtreecommitdiff
path: root/src/mongo/transport/message_compressor_snappy.cpp
diff options
context:
space:
mode:
authorJonathan Reams <jbreams@mongodb.com>2017-09-26 11:44:28 -0400
committerJonathan Reams <jbreams@mongodb.com>2017-09-29 14:37:47 -0400
commit59ead734faa8aa51f0c53bf2bd39d0a0247ddf99 (patch)
tree23e48df5b6276e8edd601d4c652492af60483423 /src/mongo/transport/message_compressor_snappy.cpp
parent79da41f50a567ce1b0df5a4ab7a7eb5109414762 (diff)
downloadmongo-59ead734faa8aa51f0c53bf2bd39d0a0247ddf99.tar.gz
SERVER-31273 Use Source/Sink version of snappy functions
Diffstat (limited to 'src/mongo/transport/message_compressor_snappy.cpp')
-rw-r--r--src/mongo/transport/message_compressor_snappy.cpp136
1 files changed, 129 insertions, 7 deletions
diff --git a/src/mongo/transport/message_compressor_snappy.cpp b/src/mongo/transport/message_compressor_snappy.cpp
index db1e0c9dfca..c722513f452 100644
--- a/src/mongo/transport/message_compressor_snappy.cpp
+++ b/src/mongo/transport/message_compressor_snappy.cpp
@@ -30,26 +30,136 @@
#include "mongo/platform/basic.h"
+#include "mongo/base/data_range_cursor.h"
#include "mongo/base/init.h"
#include "mongo/stdx/memory.h"
#include "mongo/transport/message_compressor_registry.h"
#include "mongo/transport/message_compressor_snappy.h"
-#include "third_party/snappy-1.1.3/snappy.h"
+#include <snappy-sinksource.h>
+#include <snappy.h>
namespace mongo {
+namespace {
+class SnappySourceSinkException : public DBException {  // Internal exception used by the Source/Sink adapters in this file; carries a failed Status.
+public:
+ SnappySourceSinkException(Status status) : DBException(status) {}  // Thrown from inside snappy callbacks; compressData/decompressData catch it and return e.toStatus().
+};
+
+// This is a bounds-checking version of snappy::UncheckedByteArraySink.
+//
+// If the amount of scratch buffer space requested by snappy is larger than the sink
+// buffer, then it will allocate a new temporary buffer so that snappy can finish.
+// If the amount of scratch buffer space requested by snappy is less than or equal to
+// the size of the sink buffer, then it will just return the sink buffer.
+//
+// If the scratch buffer is the sink buffer, then Append will just advance the buffer
+// cursor and do bounds checking without any copying.
+//
+// Appending data past the end of the sink buffer will throw a SnappySourceSinkException.
+class DataRangeSink final : public snappy::Sink {  // snappy::Sink adapter that writes into a caller-supplied DataRange through a cursor.
+public:
+ DataRangeSink(DataRange buffer) : _cursor(buffer) {}  // The cursor starts at the beginning of `buffer`.
+
+ char* GetAppendBuffer(size_t length, char* scratch) final {  // snappy asks for `length` writable bytes; the `scratch` argument is not used here.
+ if (length > _cursor.length()) {  // Request exceeds what the cursor has left: hand out the internal scratch vector instead.
+ _scratch.resize(length);
+ return _scratch.data();
+ }
+
+ return const_cast<char*>(_cursor.data());  // Otherwise let snappy write straight into the sink buffer at the cursor position.
+ }
+
+ void AppendAndTakeOwnership(char* data,
+ size_t n,
+ void (*deleter)(void*, const char*, size_t),
+ void* deleterArg) final {
+ Append(data, n);  // May throw SnappySourceSinkException, in which case the deleter below is not reached.
+ if (data != _cursor.data()) {  // NOTE(review): Append() has already advanced _cursor by this point, so this compares against the post-advance position - confirm intended.
+ (*deleter)(deleterArg, data, n);
+ }
+ }
+
+ void Append(const char* bytes, size_t n) final {  // Commits n bytes to the sink; throws SnappySourceSinkException on failure.
+ Status status = Status::OK();
+ if (bytes == _cursor.data()) {  // Bytes are already in place (pointer was handed out by GetAppendBuffer*): only advance the cursor.
+ status = _cursor.advance(n);
+ } else {  // Bytes live elsewhere (e.g. in _scratch): copy them into the sink buffer and advance.
+ ConstDataRange toWrite(bytes, n);
+ status = _cursor.writeAndAdvance(toWrite);
+ }
+ if (!status.isOK()) {  // A cursor failure (e.g. going past the end of the sink buffer) is surfaced as an exception.
+ throw SnappySourceSinkException(std::move(status));
+ }
+ }
+
+ char* GetAppendBufferVariable(size_t minSize,  // Variable-size counterpart of GetAppendBuffer. NOTE(review): not marked final like the other methods - confirm it overrides snappy::Sink.
+ size_t desiredSizeHint,
+ char* scratch,  // Not used; the internal _scratch vector is used instead.
+ size_t scratchSize,  // Not used.
+ size_t* allocatedSize) {  // Out-parameter: receives the size of the returned buffer.
+ if (desiredSizeHint > _cursor.length() || minSize > _cursor.length()) {  // Either size exceeds what the cursor has left: fall back to _scratch.
+ _scratch.resize(desiredSizeHint);  // NOTE(review): sized by the hint only; assumes desiredSizeHint >= minSize - confirm against snappy's contract.
+ *allocatedSize = _scratch.size();
+ return _scratch.data();
+ }
+
+ *allocatedSize = _cursor.length();  // Offer everything the cursor has left.
+ return const_cast<char*>(_cursor.data());
+ }
+
+private:
+ DataRangeCursor _cursor;  // Write position within the caller's output buffer.
+ std::vector<char> _scratch;  // Temporary buffer handed to snappy when a request does not fit in _cursor.
+};
+
+class ConstDataRangeSource final : public snappy::Source {  // snappy::Source adapter that reads from a ConstDataRange through a cursor.
+public:
+ ConstDataRangeSource(ConstDataRange buffer) : _cursor(buffer) {}
+
+ size_t Available() const final {  // Reports the cursor's current length.
+ return _cursor.length();
+ }
+
+ const char* Peek(size_t* len) final {  // Exposes the cursor's current data pointer and length without advancing.
+ *len = _cursor.length();
+ return _cursor.data();
+ }
+
+ void Skip(size_t n) final {  // Advances the cursor by n bytes; throws SnappySourceSinkException if advance() fails.
+ auto status = _cursor.advance(n);
+ if (!status.isOK()) {
+ throw SnappySourceSinkException(std::move(status));
+ }
+ }
+
+private:
+ ConstDataRangeCursor _cursor;  // Read position within the input range.
+};
+
+} // namespace
SnappyMessageCompressor::SnappyMessageCompressor()
: MessageCompressorBase(MessageCompressor::kSnappy) {}
std::size_t SnappyMessageCompressor::getMaxCompressedSize(size_t inputSize) {  // After this change: snappy::MaxCompressedLength(inputSize) plus 2 bytes of padding.
- return snappy::MaxCompressedLength(inputSize);
+ // Testing has shown that snappy typically requests two additional bytes of buffer space when
+ // compressing beyond what snappy::MaxCompressedLength returns. So by padding this by 2 more
+ // bytes, we can avoid additional allocations/copies during compression.
+ return snappy::MaxCompressedLength(inputSize) + 2;
}
StatusWith<std::size_t> SnappyMessageCompressor::compressData(ConstDataRange input,
DataRange output) {
size_t outLength;
- snappy::RawCompress(input.data(), input.length(), const_cast<char*>(output.data()), &outLength);
+ ConstDataRangeSource source(input);
+ DataRangeSink sink(output);
+
+ try {
+ outLength = snappy::Compress(&source, &sink);
+ } catch (const SnappySourceSinkException& e) {
+ return e.toStatus();
+ }
counterHitCompress(input.length(), outLength);
return {outLength};
@@ -57,11 +167,23 @@ StatusWith<std::size_t> SnappyMessageCompressor::compressData(ConstDataRange inp
StatusWith<std::size_t> SnappyMessageCompressor::decompressData(ConstDataRange input,
DataRange output) {
- bool ret =
- snappy::RawUncompress(input.data(), input.length(), const_cast<char*>(output.data()));
+ try {
+ uint32_t expectedLength = 0;
+ ConstDataRangeSource lengthCheckSource(input);
+ if (!snappy::GetUncompressedLength(&lengthCheckSource, &expectedLength) ||
+ expectedLength > output.length()) {
+ return {ErrorCodes::BadValue, "Compressed message was invalid or corrupted"};
+ }
+
+ ConstDataRangeSource source(input);
+ DataRangeSink sink(output);
- if (!ret) {
- return Status{ErrorCodes::BadValue, "Compressed message was invalid or corrupted"};
+ bool ret = snappy::Uncompress(&source, &sink);
+ if (!ret) {
+ return Status{ErrorCodes::BadValue, "Compressed message was invalid or corrupted"};
+ }
+ } catch (const SnappySourceSinkException& e) {
+ return e.toStatus();
}
counterHitDecompress(input.length(), output.length());