summaryrefslogtreecommitdiff
path: root/src/mongo/rpc
diff options
context:
space:
mode:
authorA. Jesse Jiryu Davis <jesse@mongodb.com>2019-05-04 09:03:54 -0400
committerA. Jesse Jiryu Davis <jesse@mongodb.com>2019-05-14 19:23:10 -0400
commitc83e50d7275adf2a5e946ba2c4b0861fcd9dc69b (patch)
tree2f672def64169a68c2017a460896aae6ce67c2e5 /src/mongo/rpc
parent089dd83af48cf198916e2dca50742378d4c3d361 (diff)
downloadmongo-c83e50d7275adf2a5e946ba2c4b0861fcd9dc69b.tar.gz
SERVER-28679 Set OP_MSG checksum
Diffstat (limited to 'src/mongo/rpc')
-rw-r--r--src/mongo/rpc/SConscript2
-rw-r--r--src/mongo/rpc/message.h8
-rw-r--r--src/mongo/rpc/op_msg.cpp45
-rw-r--r--src/mongo/rpc/op_msg.h31
-rw-r--r--src/mongo/rpc/op_msg_integration_test.cpp69
-rw-r--r--src/mongo/rpc/op_msg_test.cpp109
-rw-r--r--src/mongo/rpc/reply_builder_test.cpp12
7 files changed, 241 insertions, 35 deletions
diff --git a/src/mongo/rpc/SConscript b/src/mongo/rpc/SConscript
index 9ce0c49fe40..af84d991ab1 100644
--- a/src/mongo/rpc/SConscript
+++ b/src/mongo/rpc/SConscript
@@ -36,6 +36,7 @@ env.Library(
'$BUILD_DIR/mongo/bson/util/bson_extract',
'$BUILD_DIR/mongo/db/bson/dotted_path_support',
'$BUILD_DIR/mongo/db/server_options_core',
+ '$BUILD_DIR/third_party/wiredtiger/wiredtiger_checksum',
],
)
@@ -141,6 +142,7 @@ env.CppUnitTest(
],
LIBDEPS=[
'$BUILD_DIR/mongo/client/clientdriver_minimal',
+ '$BUILD_DIR/third_party/wiredtiger/wiredtiger_checksum',
'rpc',
],
)
diff --git a/src/mongo/rpc/message.h b/src/mongo/rpc/message.h
index 3689f2b2512..2acb34a632a 100644
--- a/src/mongo/rpc/message.h
+++ b/src/mongo/rpc/message.h
@@ -424,6 +424,14 @@ public:
return size() - sizeof(MSGHEADER::Value);
}
+ size_t capacity() const {
+ return _buf.capacity();
+ }
+
+ void realloc(size_t size) {
+ _buf.reallocOrCopy(size);
+ }
+
void reset() {
_buf = {};
}
diff --git a/src/mongo/rpc/op_msg.cpp b/src/mongo/rpc/op_msg.cpp
index dd15704d4bc..d6dcad993c7 100644
--- a/src/mongo/rpc/op_msg.cpp
+++ b/src/mongo/rpc/op_msg.cpp
@@ -42,6 +42,7 @@
#include "mongo/util/bufreader.h"
#include "mongo/util/hex.h"
#include "mongo/util/log.h"
+#include "third_party/wiredtiger/wiredtiger.h"
namespace mongo {
namespace {
@@ -58,6 +59,18 @@ enum class Section : uint8_t {
kDocSequence = 1,
};
+constexpr int kCrc32Size = 4;
+
+// All fields including size, requestId, and responseTo must already be set. The size must already
+// include the final 4-byte checksum.
+uint32_t calculateChecksum(const Message& message) {
+ if (message.operation() != dbMsg) {
+ return 0;
+ }
+
+ invariant(OpMsg::isFlagSet(message, OpMsg::kChecksumPresent));
+ return wiredtiger_crc32c_func()(message.singleData().view2ptr(), message.size() - kCrc32Size);
+}
} // namespace
uint32_t OpMsg::flags(const Message& message) {
@@ -76,6 +89,31 @@ void OpMsg::replaceFlags(Message* message, uint32_t flags) {
DataView(message->singleData().data()).write<LittleEndian<uint32_t>>(flags);
}
+uint32_t OpMsg::getChecksum(const Message& message) {
+ invariant(message.operation() == dbMsg);
+ invariant(isFlagSet(message, kChecksumPresent));
+ return BufReader(message.singleData().view2ptr() + message.size() - kCrc32Size, kCrc32Size)
+ .read<LittleEndian<uint32_t>>();
+}
+
+void OpMsg::appendChecksum(Message* message) {
+ if (message->operation() != dbMsg) {
+ return;
+ }
+
+ invariant(!isFlagSet(*message, kChecksumPresent));
+ setFlag(message, kChecksumPresent);
+ const size_t newSize = message->size() + kCrc32Size;
+ if (message->capacity() < newSize) {
+ message->realloc(newSize);
+ }
+
+ // Everything before the checksum, including the final size, is covered by the checksum.
+ message->header().setLen(newSize);
+ DataView(message->singleData().view2ptr() + newSize - kCrc32Size)
+ .write<LittleEndian<uint32_t>>(calculateChecksum(*message));
+}
+
OpMsg OpMsg::parse(const Message& message) try {
// It is the caller's responsibility to call the correct parser for a given message type.
invariant(!message.empty());
@@ -87,7 +125,6 @@ OpMsg OpMsg::parse(const Message& message) try {
<< std::bitset<32>(flags).to_string(),
!containsUnknownRequiredFlags(flags));
- constexpr int kCrc32Size = 4;
const bool haveChecksum = flags & kChecksumPresent;
const int checksumSize = haveChecksum ? kCrc32Size : 0;
@@ -152,6 +189,12 @@ OpMsg OpMsg::parse(const Message& message) try {
!inBody);
}
+ if (haveChecksum) {
+ uassert(ErrorCodes::ChecksumMismatch,
+ "OP_MSG checksum does not match contents",
+ OpMsg::getChecksum(message) == calculateChecksum(message));
+ }
+
return msg;
} catch (const DBException& ex) {
LOG(1) << "invalid message: " << ex.code() << " " << redact(ex) << " -- "
diff --git a/src/mongo/rpc/op_msg.h b/src/mongo/rpc/op_msg.h
index 5983ed63226..3b13b586bff 100644
--- a/src/mongo/rpc/op_msg.h
+++ b/src/mongo/rpc/op_msg.h
@@ -74,6 +74,37 @@ struct OpMsg {
}
/**
+ * Removes a flag from the list of set flags in message.
+ * Only legal on an otherwise valid OP_MSG message.
+ */
+ static void clearFlag(Message* message, uint32_t flag) {
+ replaceFlags(message, flags(*message) & ~flag);
+ }
+
+ /**
+ * Retrieves the checksum stored at the end of the message.
+ */
+ static uint32_t getChecksum(const Message& message);
+
+ /**
+ * Add a checksum at the end of the message. Call this after setting size, requestId, and
+ * responseTo. The checksumPresent flag must *not* already be set.
+ */
+ static void appendChecksum(Message* message);
+
+ /**
+ * If the checksum is present, unsets the checksumPresent flag and shrinks message by 4 bytes.
+ */
+ static void removeChecksum(Message* message) {
+ if (!isFlagSet(*message, kChecksumPresent)) {
+ return;
+ }
+
+ clearFlag(message, kChecksumPresent);
+ message->header().setLen(message->size() - 4);
+ }
+
+ /**
* Parses and returns an OpMsg containing unowned BSON.
*/
static OpMsg parse(const Message& message);
diff --git a/src/mongo/rpc/op_msg_integration_test.cpp b/src/mongo/rpc/op_msg_integration_test.cpp
index 098b9af1323..561264b5db7 100644
--- a/src/mongo/rpc/op_msg_integration_test.cpp
+++ b/src/mongo/rpc/op_msg_integration_test.cpp
@@ -196,6 +196,9 @@ TEST(OpMsg, CloseConnectionOnFireAndForgetNotMasterError) {
// rather than say() so that we get an error back when the connection is closed. Normally
// using call() if kMoreToCome set results in blocking forever.
OpMsg::setFlag(&request, OpMsg::kMoreToCome);
+ // conn.call() calculated the request checksum, but setFlag() makes it invalid. Clear the
+ // checksum so the next conn.call() recalculates it.
+ OpMsg::removeChecksum(&request);
ASSERT(!conn.call(request, reply, /*assertOK*/ false, nullptr));
uassertStatusOK(conn.connect(host, "integration_test")); // Reconnect.
@@ -224,6 +227,7 @@ TEST(OpMsg, CloseConnectionOnFireAndForgetNotMasterError) {
// Round-trip command claims to succeed due to w:0.
+ OpMsg::removeChecksum(&request);
OpMsg::replaceFlags(&request, 0);
ASSERT(conn.call(request, reply, /*assertOK*/ true, nullptr));
ASSERT_OK(getStatusFromCommandResult(
@@ -231,6 +235,7 @@ TEST(OpMsg, CloseConnectionOnFireAndForgetNotMasterError) {
// Fire-and-forget should still close connection.
OpMsg::setFlag(&request, OpMsg::kMoreToCome);
+ OpMsg::removeChecksum(&request);
ASSERT(!conn.call(request, reply, /*assertOK*/ false, nullptr));
break;
@@ -270,7 +275,19 @@ TEST(OpMsg, DocumentSequenceReturnsWork) {
<< "admin"));
}
-TEST(OpMsg, ServerHandlesExhaustCorrectly) {
+constexpr auto kDisableChecksum = "dbClientConnectionDisableChecksum";
+
+void disableClientChecksum() {
+ auto failPoint = getGlobalFailPointRegistry()->getFailPoint(kDisableChecksum);
+ failPoint->setMode(FailPoint::alwaysOn);
+}
+
+void enableClientChecksum() {
+ auto failPoint = getGlobalFailPointRegistry()->getFailPoint(kDisableChecksum);
+ failPoint->setMode(FailPoint::off);
+}
+
+void exhaustTest(bool enableChecksum) {
std::string errMsg;
auto conn = std::unique_ptr<DBClientBase>(
unittest::getFixtureConnectionString().connect("integration_test", errMsg));
@@ -281,6 +298,12 @@ TEST(OpMsg, ServerHandlesExhaustCorrectly) {
return;
}
+ if (!enableChecksum) {
+ disableClientChecksum();
+ }
+
+ ON_BLOCK_EXIT([&] { enableClientChecksum(); });
+
NamespaceString nss("test", "coll");
conn->dropCollection(nss.toString());
@@ -301,6 +324,8 @@ TEST(OpMsg, ServerHandlesExhaustCorrectly) {
const long long cursorId = res["cursor"]["id"].numberLong();
ASSERT(res["cursor"]["firstBatch"].Array().empty());
ASSERT(!OpMsg::isFlagSet(reply, OpMsg::kMoreToCome));
+ // Reply has checksum if and only if the request did.
+ ASSERT_EQ(OpMsg::isFlagSet(reply, OpMsg::kChecksumPresent), enableChecksum);
// Construct getMore request with exhaust flag. Set batch size so we will need multiple batches
// to exhaust the cursor.
@@ -314,6 +339,7 @@ TEST(OpMsg, ServerHandlesExhaustCorrectly) {
ASSERT(conn->call(request, reply));
auto lastRequestId = reply.header().getId();
ASSERT(OpMsg::isFlagSet(reply, OpMsg::kMoreToCome));
+ ASSERT_EQ(OpMsg::isFlagSet(reply, OpMsg::kChecksumPresent), enableChecksum);
res = OpMsg::parse(reply).body;
ASSERT_OK(getStatusFromCommandResult(res));
ASSERT_EQ(res["cursor"]["id"].numberLong(), cursorId);
@@ -326,6 +352,7 @@ TEST(OpMsg, ServerHandlesExhaustCorrectly) {
conn->recv(reply, lastRequestId);
lastRequestId = reply.header().getId();
ASSERT(OpMsg::isFlagSet(reply, OpMsg::kMoreToCome));
+ ASSERT_EQ(OpMsg::isFlagSet(reply, OpMsg::kChecksumPresent), enableChecksum);
res = OpMsg::parse(reply).body;
ASSERT_OK(getStatusFromCommandResult(res));
ASSERT_EQ(res["cursor"]["id"].numberLong(), cursorId);
@@ -337,6 +364,7 @@ TEST(OpMsg, ServerHandlesExhaustCorrectly) {
// Receive terminal batch.
ASSERT(conn->recv(reply, lastRequestId));
ASSERT(!OpMsg::isFlagSet(reply, OpMsg::kMoreToCome));
+ ASSERT_EQ(OpMsg::isFlagSet(reply, OpMsg::kChecksumPresent), enableChecksum);
res = OpMsg::parse(reply).body;
ASSERT_OK(getStatusFromCommandResult(res));
ASSERT_EQ(res["cursor"]["id"].numberLong(), 0);
@@ -345,6 +373,14 @@ TEST(OpMsg, ServerHandlesExhaustCorrectly) {
ASSERT_BSONOBJ_EQ(nextBatch[0].embeddedObject(), BSON("_id" << 4));
}
+TEST(OpMsg, ServerHandlesExhaustCorrectly) {
+ exhaustTest(false);
+}
+
+TEST(OpMsg, ServerHandlesExhaustCorrectlyWithChecksum) {
+ exhaustTest(true);
+}
+
TEST(OpMsg, ExhaustWithDBClientCursorBehavesCorrectly) {
// This test simply tries to verify that using the exhaust option with DBClientCursor works
// correctly. The externally visible behavior should technically be the same as a non-exhaust
@@ -400,4 +436,35 @@ TEST(OpMsg, ExhaustWithDBClientCursorBehavesCorrectly) {
ASSERT(!cursor->more());
ASSERT(cursor->isDead());
}
+
+void checksumTest(bool enableChecksum) {
+ // The server replies with a checksum if and only if the request has a checksum.
+ std::string errMsg;
+ auto conn = std::unique_ptr<DBClientBase>(
+ unittest::getFixtureConnectionString().connect("integration_test", errMsg));
+ uassert(ErrorCodes::SocketException, errMsg, conn);
+
+ if (!enableChecksum) {
+ disableClientChecksum();
+ }
+
+ ON_BLOCK_EXIT([&] { enableClientChecksum(); });
+
+ auto opMsgRequest = OpMsgRequest::fromDBAndBody("admin", BSON("ping" << 1));
+ auto request = opMsgRequest.serialize();
+
+ Message reply;
+ ASSERT(conn->call(request, reply));
+
+ auto opMsgReply = OpMsg::parse(reply);
+ ASSERT_EQ(OpMsg::isFlagSet(reply, OpMsg::kChecksumPresent), enableChecksum);
+}
+
+TEST(OpMsg, ServerRepliesWithoutChecksumToRequestWithoutChecksum) {
+ checksumTest(true);
+}
+
+TEST(OpMsg, ServerRepliesWithChecksumToRequestWithChecksum) {
+ checksumTest(true);
+}
} // namespace mongo
diff --git a/src/mongo/rpc/op_msg_test.cpp b/src/mongo/rpc/op_msg_test.cpp
index 314f6b99284..bf280768638 100644
--- a/src/mongo/rpc/op_msg_test.cpp
+++ b/src/mongo/rpc/op_msg_test.cpp
@@ -41,6 +41,7 @@
#include "mongo/unittest/unittest.h"
#include "mongo/util/hex.h"
#include "mongo/util/log.h"
+#include "third_party/wiredtiger/wiredtiger.h"
namespace mongo {
namespace {
@@ -94,14 +95,19 @@ public:
explicit Sized(T&&... args) {
buffer.skip(sizeof(int32_t));
append(args...);
- DataView(buffer.buf()).write<LittleEndian<int32_t>>(buffer.len());
+ updateSize();
}
// Adds extra to the stored size. Use this to produce illegal messages.
Sized&& addToSize(int32_t extra) && {
- DataView(buffer.buf()).write<LittleEndian<int32_t>>(buffer.len() + extra);
+ updateSize(extra);
return std::move(*this);
}
+
+protected:
+ void updateSize(int32_t extra = 0) {
+ DataView(buffer.buf()).write<LittleEndian<int32_t>>(buffer.len() + extra);
+ }
};
// A Bytes that puts the standard message header at the front.
@@ -127,7 +133,25 @@ public:
}
OpMsgBytes&& addToSize(int32_t extra) && {
- DataView(buffer.buf()).write<LittleEndian<int32_t>>(buffer.len() + extra);
+ updateSize(extra);
+ return std::move(*this);
+ }
+
+ OpMsgBytes&& appendChecksum() && {
+ // Reserve space at the end for the checksum.
+ append<uint32_t>(0);
+ updateSize();
+ // Checksum all bits except the checksum itself.
+ uint32_t checksum = wiredtiger_crc32c_func()(buffer.buf(), buffer.len() - 4);
+ // Write the checksum bits at the end.
+ auto checksumBits = DataView(buffer.buf() + buffer.len() - sizeof(checksum));
+ checksumBits.write<LittleEndian<uint32_t>>(checksum);
+ return std::move(*this);
+ }
+
+ OpMsgBytes&& appendChecksum(uint32_t checksum) && {
+ append(checksum);
+ updateSize();
return std::move(*this);
}
};
@@ -158,9 +182,6 @@ const char kDocSequenceSection = 1;
const uint32_t kNoFlags = 0;
const uint32_t kHaveChecksum = 1;
-// CRC filler value
-const uint32_t kFakeCRC = 0; // TODO will need to compute real crc when SERVER-28679 is done.
-
TEST_F(OpMsgParser, SucceedsWithJustBody) {
auto msg = OpMsgBytes{
kNoFlags, //
@@ -172,13 +193,12 @@ TEST_F(OpMsgParser, SucceedsWithJustBody) {
ASSERT_EQ(msg.sequences.size(), 0u);
}
-TEST_F(OpMsgParser, IgnoresCrcIfPresent) { // Until SERVER-28679 is done.
- auto msg = OpMsgBytes{
- kHaveChecksum, //
- kBodySection,
- fromjson("{ping: 1}"),
- kFakeCRC, // If not ignored, this would be read as a second body.
- }.parse();
+TEST_F(OpMsgParser, SucceedsWithChecksum) {
+ auto msg = OpMsgBytes{kHaveChecksum, //
+ kBodySection,
+ fromjson("{ping: 1}")}
+ .appendChecksum()
+ .parse();
ASSERT_BSONOBJ_EQ(msg.body, fromjson("{ping: 1}"));
ASSERT_EQ(msg.sequences.size(), 0u);
@@ -422,12 +442,13 @@ TEST_F(OpMsgParser, FailsIfBodyTooBig) {
}
TEST_F(OpMsgParser, FailsIfBodyTooBigIntoChecksum) {
- auto msg = OpMsgBytes{
- kHaveChecksum, //
- kBodySection,
- fromjson("{ping: 1}"),
- kFakeCRC,
- }.addToSize(-1); // Shrink message so body extends past end.
+ auto msg =
+ OpMsgBytes{
+ kHaveChecksum, //
+ kBodySection,
+ fromjson("{ping: 1}"),
+ }.appendChecksum()
+ .addToSize(-1); // Shrink message so body extends past end.
ASSERT_THROWS_CODE(msg.parse(), AssertionException, ErrorCodes::InvalidBSON);
}
@@ -449,19 +470,19 @@ TEST_F(OpMsgParser, FailsIfDocumentSequenceTooBig) {
}
TEST_F(OpMsgParser, FailsIfDocumentSequenceTooBigIntoChecksum) {
- auto msg = OpMsgBytes{
- kHaveChecksum, //
- kBodySection,
- fromjson("{ping: 1}"),
-
- kDocSequenceSection,
- Sized{
- "docs", //
- fromjson("{a: 1}"),
- },
+ auto msg =
+ OpMsgBytes{
+ kHaveChecksum, //
+ kBodySection,
+ fromjson("{ping: 1}"),
- kFakeCRC,
- }.addToSize(-1); // Shrink message so body extends past end.
+ kDocSequenceSection,
+ Sized{
+ "docs", //
+ fromjson("{a: 1}"),
+ },
+ }.appendChecksum()
+ .addToSize(-1); // Shrink message so body extends past end.
ASSERT_THROWS_CODE(msg.parse(), AssertionException, ErrorCodes::Overflow);
}
@@ -594,6 +615,18 @@ TEST_F(OpMsgParser, SucceedsWithUnknownOptionalFlags) {
}
}
+TEST_F(OpMsgParser, FailsWithChecksumMismatch) {
+ auto msg = OpMsgBytes{kHaveChecksum, //
+ kBodySection,
+ fromjson("{ping: 1}")}
+ .appendChecksum(123);
+
+ ASSERT_THROWS_WITH_CHECK(msg.parse(), AssertionException, [](const DBException& ex) {
+ ASSERT_EQ(ex.toStatus().code(), ErrorCodes::ChecksumMismatch);
+ ASSERT(ErrorCodes::isConnectionFatalMessageParseError(ex.toStatus().code()));
+ });
+}
+
void testSerializer(const Message& fromSerializer, OpMsgBytes&& expected) {
const auto expectedMsg = expected.done();
ASSERT_EQ(fromSerializer.operation(), dbMsg);
@@ -860,5 +893,19 @@ TEST(OpMsgRequest, FromDbAndBodyDoesNotCopy) {
ASSERT_BSONOBJ_EQ(msg.body, fromjson("{ping: 1, $db: 'db'}"));
ASSERT_EQ(static_cast<const void*>(msg.body.objdata()), bodyPtr);
}
+
+TEST(OpMsgTest, ChecksumResizesMessage) {
+ auto msg = OpMsgBytes{kNoFlags, //
+ kBodySection,
+ fromjson("{ping: 1}")}
+ .done();
+
+ // Test that appendChecksum() resizes the buffer if necessary.
+ const auto capacity = msg.sharedBuffer().capacity();
+ OpMsg::appendChecksum(&msg);
+ ASSERT_EQ(msg.sharedBuffer().capacity(), capacity + 4);
+ // The checksum is correct.
+ OpMsg::parse(msg);
+}
} // namespace
} // namespace mongo
diff --git a/src/mongo/rpc/reply_builder_test.cpp b/src/mongo/rpc/reply_builder_test.cpp
index 5bb4ff13320..fcdb036c650 100644
--- a/src/mongo/rpc/reply_builder_test.cpp
+++ b/src/mongo/rpc/reply_builder_test.cpp
@@ -151,6 +151,9 @@ TEST(OpMsgReplyBuilder, CommandError) {
replyBuilder.setCommandReply(status, extraObj);
replyBuilder.getBodyBuilder().appendElements(metadata);
auto msg = replyBuilder.done();
+ msg.header().setId(124);
+ msg.header().setResponseToMsgId(123);
+ OpMsg::appendChecksum(&msg);
rpc::OpMsgReply parsed(&msg);
@@ -172,7 +175,9 @@ void testRoundTrip(rpc::ReplyBuilderInterface& replyBuilder, bool unifiedBodyAnd
replyBuilder.getBodyBuilder().appendElements(metadata);
auto msg = replyBuilder.done();
-
+ msg.header().setId(124);
+ msg.header().setResponseToMsgId(123);
+ OpMsg::appendChecksum(&msg);
T parsed(&msg);
if (unifiedBodyAndMetadata) {
@@ -197,7 +202,10 @@ void testErrors(rpc::ReplyBuilderInterface& replyBuilder) {
replyBuilder.setCommandReply(status);
replyBuilder.getBodyBuilder().appendElements(buildMetadata());
- const auto msg = replyBuilder.done();
+ auto msg = replyBuilder.done();
+ msg.header().setId(124);
+ msg.header().setResponseToMsgId(123);
+ OpMsg::appendChecksum(&msg);
T parsed(&msg);
const Status result = getStatusFromCommandResult(parsed.getCommandReply());