diff options
author | A. Jesse Jiryu Davis <jesse@mongodb.com> | 2019-05-04 09:03:54 -0400 |
---|---|---|
committer | A. Jesse Jiryu Davis <jesse@mongodb.com> | 2019-05-14 19:23:10 -0400 |
commit | c83e50d7275adf2a5e946ba2c4b0861fcd9dc69b (patch) | |
tree | 2f672def64169a68c2017a460896aae6ce67c2e5 /src/mongo/rpc | |
parent | 089dd83af48cf198916e2dca50742378d4c3d361 (diff) | |
download | mongo-c83e50d7275adf2a5e946ba2c4b0861fcd9dc69b.tar.gz |
SERVER-28679 Set OP_MSG checksum
Diffstat (limited to 'src/mongo/rpc')
-rw-r--r-- | src/mongo/rpc/SConscript | 2 | ||||
-rw-r--r-- | src/mongo/rpc/message.h | 8 | ||||
-rw-r--r-- | src/mongo/rpc/op_msg.cpp | 45 | ||||
-rw-r--r-- | src/mongo/rpc/op_msg.h | 31 | ||||
-rw-r--r-- | src/mongo/rpc/op_msg_integration_test.cpp | 69 | ||||
-rw-r--r-- | src/mongo/rpc/op_msg_test.cpp | 109 | ||||
-rw-r--r-- | src/mongo/rpc/reply_builder_test.cpp | 12 |
7 files changed, 241 insertions, 35 deletions
diff --git a/src/mongo/rpc/SConscript b/src/mongo/rpc/SConscript index 9ce0c49fe40..af84d991ab1 100644 --- a/src/mongo/rpc/SConscript +++ b/src/mongo/rpc/SConscript @@ -36,6 +36,7 @@ env.Library( '$BUILD_DIR/mongo/bson/util/bson_extract', '$BUILD_DIR/mongo/db/bson/dotted_path_support', '$BUILD_DIR/mongo/db/server_options_core', + '$BUILD_DIR/third_party/wiredtiger/wiredtiger_checksum', ], ) @@ -141,6 +142,7 @@ env.CppUnitTest( ], LIBDEPS=[ '$BUILD_DIR/mongo/client/clientdriver_minimal', + '$BUILD_DIR/third_party/wiredtiger/wiredtiger_checksum', 'rpc', ], ) diff --git a/src/mongo/rpc/message.h b/src/mongo/rpc/message.h index 3689f2b2512..2acb34a632a 100644 --- a/src/mongo/rpc/message.h +++ b/src/mongo/rpc/message.h @@ -424,6 +424,14 @@ public: return size() - sizeof(MSGHEADER::Value); } + size_t capacity() const { + return _buf.capacity(); + } + + void realloc(size_t size) { + _buf.reallocOrCopy(size); + } + void reset() { _buf = {}; } diff --git a/src/mongo/rpc/op_msg.cpp b/src/mongo/rpc/op_msg.cpp index dd15704d4bc..d6dcad993c7 100644 --- a/src/mongo/rpc/op_msg.cpp +++ b/src/mongo/rpc/op_msg.cpp @@ -42,6 +42,7 @@ #include "mongo/util/bufreader.h" #include "mongo/util/hex.h" #include "mongo/util/log.h" +#include "third_party/wiredtiger/wiredtiger.h" namespace mongo { namespace { @@ -58,6 +59,18 @@ enum class Section : uint8_t { kDocSequence = 1, }; +constexpr int kCrc32Size = 4; + +// All fields including size, requestId, and responseTo must already be set. The size must already +// include the final 4-byte checksum. +uint32_t calculateChecksum(const Message& message) { + if (message.operation() != dbMsg) { + return 0; + } + + invariant(OpMsg::isFlagSet(message, OpMsg::kChecksumPresent)); + return wiredtiger_crc32c_func()(message.singleData().view2ptr(), message.size() - kCrc32Size); +} } // namespace uint32_t OpMsg::flags(const Message& message) { @@ -76,6 +89,31 @@ void OpMsg::replaceFlags(Message* message, uint32_t flags) { DataView(message->singleData().data()).write<LittleEndian<uint32_t>>(flags); } +uint32_t OpMsg::getChecksum(const Message& message) { + invariant(message.operation() == dbMsg); + invariant(isFlagSet(message, kChecksumPresent)); + return BufReader(message.singleData().view2ptr() + message.size() - kCrc32Size, kCrc32Size) + .read<LittleEndian<uint32_t>>(); +} + +void OpMsg::appendChecksum(Message* message) { + if (message->operation() != dbMsg) { + return; + } + + invariant(!isFlagSet(*message, kChecksumPresent)); + setFlag(message, kChecksumPresent); + const size_t newSize = message->size() + kCrc32Size; + if (message->capacity() < newSize) { + message->realloc(newSize); + } + + // Everything before the checksum, including the final size, is covered by the checksum. + message->header().setLen(newSize); + DataView(message->singleData().view2ptr() + newSize - kCrc32Size) + .write<LittleEndian<uint32_t>>(calculateChecksum(*message)); +} + OpMsg OpMsg::parse(const Message& message) try { // It is the caller's responsibility to call the correct parser for a given message type. invariant(!message.empty()); @@ -87,7 +125,6 @@ OpMsg OpMsg::parse(const Message& message) try { << std::bitset<32>(flags).to_string(), !containsUnknownRequiredFlags(flags)); - constexpr int kCrc32Size = 4; const bool haveChecksum = flags & kChecksumPresent; const int checksumSize = haveChecksum ? kCrc32Size : 0; @@ -152,6 +189,12 @@ OpMsg OpMsg::parse(const Message& message) try { !inBody); } + if (haveChecksum) { + uassert(ErrorCodes::ChecksumMismatch, + "OP_MSG checksum does not match contents", + OpMsg::getChecksum(message) == calculateChecksum(message)); + } + return msg; } catch (const DBException& ex) { LOG(1) << "invalid message: " << ex.code() << " " << redact(ex) << " -- " diff --git a/src/mongo/rpc/op_msg.h b/src/mongo/rpc/op_msg.h index 5983ed63226..3b13b586bff 100644 --- a/src/mongo/rpc/op_msg.h +++ b/src/mongo/rpc/op_msg.h @@ -74,6 +74,37 @@ struct OpMsg { } /** + * Removes a flag from the list of set flags in message. + * Only legal on an otherwise valid OP_MSG message. + */ + static void clearFlag(Message* message, uint32_t flag) { + replaceFlags(message, flags(*message) & ~flag); + } + + /** + * Retrieves the checksum stored at the end of the message. + */ + static uint32_t getChecksum(const Message& message); + + /** + * Add a checksum at the end of the message. Call this after setting size, requestId, and + * responseTo. The checksumPresent flag must *not* already be set. + */ + static void appendChecksum(Message* message); + + /** + * If the checksum is present, unsets the checksumPresent flag and shrinks message by 4 bytes. + */ + static void removeChecksum(Message* message) { + if (!isFlagSet(*message, kChecksumPresent)) { + return; + } + + clearFlag(message, kChecksumPresent); + message->header().setLen(message->size() - 4); + } + + /** * Parses and returns an OpMsg containing unowned BSON. */ static OpMsg parse(const Message& message); diff --git a/src/mongo/rpc/op_msg_integration_test.cpp b/src/mongo/rpc/op_msg_integration_test.cpp index 098b9af1323..561264b5db7 100644 --- a/src/mongo/rpc/op_msg_integration_test.cpp +++ b/src/mongo/rpc/op_msg_integration_test.cpp @@ -196,6 +196,9 @@ TEST(OpMsg, CloseConnectionOnFireAndForgetNotMasterError) { // rather than say() so that we get an error back when the connection is closed. Normally // using call() if kMoreToCome set results in blocking forever. OpMsg::setFlag(&request, OpMsg::kMoreToCome); + // conn.call() calculated the request checksum, but setFlag() makes it invalid. Clear the + // checksum so the next conn.call() recalculates it. + OpMsg::removeChecksum(&request); ASSERT(!conn.call(request, reply, /*assertOK*/ false, nullptr)); uassertStatusOK(conn.connect(host, "integration_test")); // Reconnect. @@ -224,6 +227,7 @@ TEST(OpMsg, CloseConnectionOnFireAndForgetNotMasterError) { // Round-trip command claims to succeed due to w:0. + OpMsg::removeChecksum(&request); OpMsg::replaceFlags(&request, 0); ASSERT(conn.call(request, reply, /*assertOK*/ true, nullptr)); ASSERT_OK(getStatusFromCommandResult( @@ -231,6 +235,7 @@ TEST(OpMsg, CloseConnectionOnFireAndForgetNotMasterError) { // Fire-and-forget should still close connection. OpMsg::setFlag(&request, OpMsg::kMoreToCome); + OpMsg::removeChecksum(&request); ASSERT(!conn.call(request, reply, /*assertOK*/ false, nullptr)); break; @@ -270,7 +275,19 @@ TEST(OpMsg, DocumentSequenceReturnsWork) { << "admin")); } -TEST(OpMsg, ServerHandlesExhaustCorrectly) { +constexpr auto kDisableChecksum = "dbClientConnectionDisableChecksum"; + +void disableClientChecksum() { + auto failPoint = getGlobalFailPointRegistry()->getFailPoint(kDisableChecksum); + failPoint->setMode(FailPoint::alwaysOn); +} + +void enableClientChecksum() { + auto failPoint = getGlobalFailPointRegistry()->getFailPoint(kDisableChecksum); + failPoint->setMode(FailPoint::off); +} + +void exhaustTest(bool enableChecksum) { std::string errMsg; auto conn = std::unique_ptr<DBClientBase>( unittest::getFixtureConnectionString().connect("integration_test", errMsg)); @@ -281,6 +298,12 @@ TEST(OpMsg, ServerHandlesExhaustCorrectly) { return; } + if (!enableChecksum) { + disableClientChecksum(); + } + + ON_BLOCK_EXIT([&] { enableClientChecksum(); }); + NamespaceString nss("test", "coll"); conn->dropCollection(nss.toString()); @@ -301,6 +324,8 @@ TEST(OpMsg, ServerHandlesExhaustCorrectly) { const long long cursorId = res["cursor"]["id"].numberLong(); ASSERT(res["cursor"]["firstBatch"].Array().empty()); ASSERT(!OpMsg::isFlagSet(reply, OpMsg::kMoreToCome)); + // Reply has checksum if and only if the request did. + ASSERT_EQ(OpMsg::isFlagSet(reply, OpMsg::kChecksumPresent), enableChecksum); // Construct getMore request with exhaust flag. Set batch size so we will need multiple batches // to exhaust the cursor. @@ -314,6 +339,7 @@ TEST(OpMsg, ServerHandlesExhaustCorrectly) { ASSERT(conn->call(request, reply)); auto lastRequestId = reply.header().getId(); ASSERT(OpMsg::isFlagSet(reply, OpMsg::kMoreToCome)); + ASSERT_EQ(OpMsg::isFlagSet(reply, OpMsg::kChecksumPresent), enableChecksum); res = OpMsg::parse(reply).body; ASSERT_OK(getStatusFromCommandResult(res)); ASSERT_EQ(res["cursor"]["id"].numberLong(), cursorId); @@ -326,6 +352,7 @@ TEST(OpMsg, ServerHandlesExhaustCorrectly) { conn->recv(reply, lastRequestId); lastRequestId = reply.header().getId(); ASSERT(OpMsg::isFlagSet(reply, OpMsg::kMoreToCome)); + ASSERT_EQ(OpMsg::isFlagSet(reply, OpMsg::kChecksumPresent), enableChecksum); res = OpMsg::parse(reply).body; ASSERT_OK(getStatusFromCommandResult(res)); ASSERT_EQ(res["cursor"]["id"].numberLong(), cursorId); @@ -337,6 +364,7 @@ TEST(OpMsg, ServerHandlesExhaustCorrectly) { // Receive terminal batch. ASSERT(conn->recv(reply, lastRequestId)); ASSERT(!OpMsg::isFlagSet(reply, OpMsg::kMoreToCome)); + ASSERT_EQ(OpMsg::isFlagSet(reply, OpMsg::kChecksumPresent), enableChecksum); res = OpMsg::parse(reply).body; ASSERT_OK(getStatusFromCommandResult(res)); ASSERT_EQ(res["cursor"]["id"].numberLong(), 0); @@ -345,6 +373,14 @@ TEST(OpMsg, ServerHandlesExhaustCorrectly) { ASSERT_BSONOBJ_EQ(nextBatch[0].embeddedObject(), BSON("_id" << 4)); } +TEST(OpMsg, ServerHandlesExhaustCorrectly) { + exhaustTest(false); +} + +TEST(OpMsg, ServerHandlesExhaustCorrectlyWithChecksum) { + exhaustTest(true); +} + TEST(OpMsg, ExhaustWithDBClientCursorBehavesCorrectly) { // This test simply tries to verify that using the exhaust option with DBClientCursor works // correctly. The externally visible behavior should technically be the same as a non-exhaust @@ -400,4 +436,35 @@ TEST(OpMsg, ExhaustWithDBClientCursorBehavesCorrectly) { ASSERT(!cursor->more()); ASSERT(cursor->isDead()); } + +void checksumTest(bool enableChecksum) { + // The server replies with a checksum if and only if the request has a checksum. + std::string errMsg; + auto conn = std::unique_ptr<DBClientBase>( + unittest::getFixtureConnectionString().connect("integration_test", errMsg)); + uassert(ErrorCodes::SocketException, errMsg, conn); + + if (!enableChecksum) { + disableClientChecksum(); + } + + ON_BLOCK_EXIT([&] { enableClientChecksum(); }); + + auto opMsgRequest = OpMsgRequest::fromDBAndBody("admin", BSON("ping" << 1)); + auto request = opMsgRequest.serialize(); + + Message reply; + ASSERT(conn->call(request, reply)); + + auto opMsgReply = OpMsg::parse(reply); + ASSERT_EQ(OpMsg::isFlagSet(reply, OpMsg::kChecksumPresent), enableChecksum); +} + +TEST(OpMsg, ServerRepliesWithoutChecksumToRequestWithoutChecksum) { + checksumTest(true); +} + +TEST(OpMsg, ServerRepliesWithChecksumToRequestWithChecksum) { + checksumTest(true); +} } // namespace mongo diff --git a/src/mongo/rpc/op_msg_test.cpp b/src/mongo/rpc/op_msg_test.cpp index 314f6b99284..bf280768638 100644 --- a/src/mongo/rpc/op_msg_test.cpp +++ b/src/mongo/rpc/op_msg_test.cpp @@ -41,6 +41,7 @@ #include "mongo/unittest/unittest.h" #include "mongo/util/hex.h" #include "mongo/util/log.h" +#include "third_party/wiredtiger/wiredtiger.h" namespace mongo { namespace { @@ -94,14 +95,19 @@ public: explicit Sized(T&&... args) { buffer.skip(sizeof(int32_t)); append(args...); - DataView(buffer.buf()).write<LittleEndian<int32_t>>(buffer.len()); + updateSize(); } // Adds extra to the stored size. Use this to produce illegal messages. Sized&& addToSize(int32_t extra) && { - DataView(buffer.buf()).write<LittleEndian<int32_t>>(buffer.len() + extra); + updateSize(extra); return std::move(*this); } + +protected: + void updateSize(int32_t extra = 0) { + DataView(buffer.buf()).write<LittleEndian<int32_t>>(buffer.len() + extra); + } }; // A Bytes that puts the standard message header at the front. @@ -127,7 +133,25 @@ public: } OpMsgBytes&& addToSize(int32_t extra) && { - DataView(buffer.buf()).write<LittleEndian<int32_t>>(buffer.len() + extra); + updateSize(extra); + return std::move(*this); + } + + OpMsgBytes&& appendChecksum() && { + // Reserve space at the end for the checksum. + append<uint32_t>(0); + updateSize(); + // Checksum all bits except the checksum itself. + uint32_t checksum = wiredtiger_crc32c_func()(buffer.buf(), buffer.len() - 4); + // Write the checksum bits at the end. + auto checksumBits = DataView(buffer.buf() + buffer.len() - sizeof(checksum)); + checksumBits.write<LittleEndian<uint32_t>>(checksum); + return std::move(*this); + } + + OpMsgBytes&& appendChecksum(uint32_t checksum) && { + append(checksum); + updateSize(); return std::move(*this); } }; @@ -158,9 +182,6 @@ const char kDocSequenceSection = 1; const uint32_t kNoFlags = 0; const uint32_t kHaveChecksum = 1; -// CRC filler value -const uint32_t kFakeCRC = 0; // TODO will need to compute real crc when SERVER-28679 is done. - TEST_F(OpMsgParser, SucceedsWithJustBody) { auto msg = OpMsgBytes{ kNoFlags, // @@ -172,13 +193,12 @@ TEST_F(OpMsgParser, SucceedsWithJustBody) { ASSERT_EQ(msg.sequences.size(), 0u); } -TEST_F(OpMsgParser, IgnoresCrcIfPresent) { // Until SERVER-28679 is done. - auto msg = OpMsgBytes{ - kHaveChecksum, // - kBodySection, - fromjson("{ping: 1}"), - kFakeCRC, // If not ignored, this would be read as a second body. - }.parse(); +TEST_F(OpMsgParser, SucceedsWithChecksum) { + auto msg = OpMsgBytes{kHaveChecksum, // + kBodySection, + fromjson("{ping: 1}")} + .appendChecksum() + .parse(); ASSERT_BSONOBJ_EQ(msg.body, fromjson("{ping: 1}")); ASSERT_EQ(msg.sequences.size(), 0u); @@ -422,12 +442,13 @@ TEST_F(OpMsgParser, FailsIfBodyTooBig) { } TEST_F(OpMsgParser, FailsIfBodyTooBigIntoChecksum) { - auto msg = OpMsgBytes{ - kHaveChecksum, // - kBodySection, - fromjson("{ping: 1}"), - kFakeCRC, - }.addToSize(-1); // Shrink message so body extends past end. + auto msg = + OpMsgBytes{ + kHaveChecksum, // + kBodySection, + fromjson("{ping: 1}"), + }.appendChecksum() + .addToSize(-1); // Shrink message so body extends past end. ASSERT_THROWS_CODE(msg.parse(), AssertionException, ErrorCodes::InvalidBSON); } @@ -449,19 +470,19 @@ TEST_F(OpMsgParser, FailsIfDocumentSequenceTooBig) { } TEST_F(OpMsgParser, FailsIfDocumentSequenceTooBigIntoChecksum) { - auto msg = OpMsgBytes{ - kHaveChecksum, // - kBodySection, - fromjson("{ping: 1}"), - - kDocSequenceSection, - Sized{ - "docs", // - fromjson("{a: 1}"), - }, + auto msg = + OpMsgBytes{ + kHaveChecksum, // + kBodySection, + fromjson("{ping: 1}"), - kFakeCRC, - }.addToSize(-1); // Shrink message so body extends past end. + kDocSequenceSection, + Sized{ + "docs", // + fromjson("{a: 1}"), + }, + }.appendChecksum() + .addToSize(-1); // Shrink message so body extends past end. ASSERT_THROWS_CODE(msg.parse(), AssertionException, ErrorCodes::Overflow); } @@ -594,6 +615,18 @@ TEST_F(OpMsgParser, SucceedsWithUnknownOptionalFlags) { } } +TEST_F(OpMsgParser, FailsWithChecksumMismatch) { + auto msg = OpMsgBytes{kHaveChecksum, // + kBodySection, + fromjson("{ping: 1}")} + .appendChecksum(123); + + ASSERT_THROWS_WITH_CHECK(msg.parse(), AssertionException, [](const DBException& ex) { + ASSERT_EQ(ex.toStatus().code(), ErrorCodes::ChecksumMismatch); + ASSERT(ErrorCodes::isConnectionFatalMessageParseError(ex.toStatus().code())); + }); +} + void testSerializer(const Message& fromSerializer, OpMsgBytes&& expected) { const auto expectedMsg = expected.done(); ASSERT_EQ(fromSerializer.operation(), dbMsg); @@ -860,5 +893,19 @@ TEST(OpMsgRequest, FromDbAndBodyDoesNotCopy) { ASSERT_BSONOBJ_EQ(msg.body, fromjson("{ping: 1, $db: 'db'}")); ASSERT_EQ(static_cast<const void*>(msg.body.objdata()), bodyPtr); } + +TEST(OpMsgTest, ChecksumResizesMessage) { + auto msg = OpMsgBytes{kNoFlags, // + kBodySection, + fromjson("{ping: 1}")} + .done(); + + // Test that appendChecksum() resizes the buffer if necessary. + const auto capacity = msg.sharedBuffer().capacity(); + OpMsg::appendChecksum(&msg); + ASSERT_EQ(msg.sharedBuffer().capacity(), capacity + 4); + // The checksum is correct. + OpMsg::parse(msg); +} } // namespace } // namespace mongo diff --git a/src/mongo/rpc/reply_builder_test.cpp b/src/mongo/rpc/reply_builder_test.cpp index 5bb4ff13320..fcdb036c650 100644 --- a/src/mongo/rpc/reply_builder_test.cpp +++ b/src/mongo/rpc/reply_builder_test.cpp @@ -151,6 +151,9 @@ TEST(OpMsgReplyBuilder, CommandError) { replyBuilder.setCommandReply(status, extraObj); replyBuilder.getBodyBuilder().appendElements(metadata); auto msg = replyBuilder.done(); + msg.header().setId(124); + msg.header().setResponseToMsgId(123); + OpMsg::appendChecksum(&msg); rpc::OpMsgReply parsed(&msg); @@ -172,7 +175,9 @@ void testRoundTrip(rpc::ReplyBuilderInterface& replyBuilder, bool unifiedBodyAnd replyBuilder.getBodyBuilder().appendElements(metadata); auto msg = replyBuilder.done(); - + msg.header().setId(124); + msg.header().setResponseToMsgId(123); + OpMsg::appendChecksum(&msg); T parsed(&msg); if (unifiedBodyAndMetadata) { @@ -197,7 +202,10 @@ void testErrors(rpc::ReplyBuilderInterface& replyBuilder) { replyBuilder.setCommandReply(status); replyBuilder.getBodyBuilder().appendElements(buildMetadata()); - const auto msg = replyBuilder.done(); + auto msg = replyBuilder.done(); + msg.header().setId(124); + msg.header().setResponseToMsgId(123); + OpMsg::appendChecksum(&msg); T parsed(&msg); const Status result = getStatusFromCommandResult(parsed.getCommandReply()); |