From c83e50d7275adf2a5e946ba2c4b0861fcd9dc69b Mon Sep 17 00:00:00 2001 From: "A. Jesse Jiryu Davis" Date: Sat, 4 May 2019 09:03:54 -0400 Subject: SERVER-28679 Set OP_MSG checksum --- src/mongo/base/error_codes.err | 5 +- src/mongo/client/async_client.cpp | 1 + src/mongo/client/dbclient_connection.cpp | 8 ++ src/mongo/client/dbclient_cursor_test.cpp | 6 +- src/mongo/db/traffic_reader.cpp | 3 +- src/mongo/dbtests/SConscript | 1 + src/mongo/dbtests/shared_buffer.cpp | 133 +++++++++++++++++++++ .../embedded/mongo_embedded/mongo_embedded.cpp | 1 + src/mongo/rpc/SConscript | 2 + src/mongo/rpc/message.h | 8 ++ src/mongo/rpc/op_msg.cpp | 45 ++++++- src/mongo/rpc/op_msg.h | 31 +++++ src/mongo/rpc/op_msg_integration_test.cpp | 69 ++++++++++- src/mongo/rpc/op_msg_test.cpp | 109 ++++++++++++----- src/mongo/rpc/reply_builder_test.cpp | 12 +- src/mongo/tools/bridge.cpp | 4 + src/mongo/transport/service_state_machine.cpp | 17 ++- src/mongo/transport/transport_layer_asio_test.cpp | 1 + src/mongo/util/shared_buffer.h | 19 +++ 19 files changed, 433 insertions(+), 42 deletions(-) create mode 100644 src/mongo/dbtests/shared_buffer.cpp (limited to 'src') diff --git a/src/mongo/base/error_codes.err b/src/mongo/base/error_codes.err index 14be694a0bc..b34ac7c5c61 100644 --- a/src/mongo/base/error_codes.err +++ b/src/mongo/base/error_codes.err @@ -287,6 +287,7 @@ error_code("IndexBuildAlreadyInProgress", 285) error_code("ChangeStreamHistoryLost", 286) # The code below is for internal use only and must never be returned in a network response error_code("TransactionCoordinatorDeadlineTaskCanceled", 287) +error_code("ChecksumMismatch", 288) # Error codes 4000-8999 are reserved. @@ -348,9 +349,9 @@ error_class("ShutdownError", ["ShutdownInProgress", "InterruptedAtShutdown"]) # indicates that it cannot be executed as normal and must abort its intended work. error_class("CancelationError", ["ShutdownInProgress", "InterruptedAtShutdown", "CallbackCanceled"]) -#TODO SERVER-28679 add checksum failure. error_class("ConnectionFatalMessageParseError", ["IllegalOpMsgFlag", - "TooManyDocumentSequences"]) + "TooManyDocumentSequences", + "ChecksumMismatch"]) error_class("ExceededTimeLimitError", ["ExceededTimeLimit", "MaxTimeMSExpired", "NetworkInterfaceExceededTimeLimit"]) diff --git a/src/mongo/client/async_client.cpp b/src/mongo/client/async_client.cpp index e99af8a258b..7410b1a89d9 100644 --- a/src/mongo/client/async_client.cpp +++ b/src/mongo/client/async_client.cpp @@ -202,6 +202,7 @@ Future AsyncDBClient::_call(Message request, const BatonHandle& baton) auto msgId = nextMessageId(); request.header().setId(msgId); request.header().setResponseToMsgId(0); + OpMsg::appendChecksum(&request); return _session->asyncSinkMessage(request, baton) .then([this, baton] { return _session->asyncSourceMessage(baton); }) diff --git a/src/mongo/client/dbclient_connection.cpp b/src/mongo/client/dbclient_connection.cpp index 2341e2a8a69..b01edcc993c 100644 --- a/src/mongo/client/dbclient_connection.cpp +++ b/src/mongo/client/dbclient_connection.cpp @@ -85,6 +85,8 @@ using std::endl; using std::map; using std::string; +MONGO_FAIL_POINT_DEFINE(dbClientConnectionDisableChecksum); + namespace { /** @@ -576,6 +578,9 @@ void DBClientConnection::say(Message& toSend, bool isRetry, string* actualServer toSend.header().setId(nextMessageId()); toSend.header().setResponseToMsgId(0); + if (!MONGO_FAIL_POINT(dbClientConnectionDisableChecksum)) { + OpMsg::appendChecksum(&toSend); + } uassertStatusOK( _session->sinkMessage(uassertStatusOK(_compressorManager.compressMessage(toSend)))); killSessionOnError.dismiss(); @@ -619,6 +624,9 @@ bool DBClientConnection::call(Message& toSend, toSend.header().setId(nextMessageId()); toSend.header().setResponseToMsgId(0); + if (!MONGO_FAIL_POINT(dbClientConnectionDisableChecksum)) { + OpMsg::appendChecksum(&toSend); + } auto swm = _compressorManager.compressMessage(toSend); uassertStatusOK(swm.getStatus()); diff --git a/src/mongo/client/dbclient_cursor_test.cpp b/src/mongo/client/dbclient_cursor_test.cpp index 6429df5c604..292ce2c8bb5 100644 --- a/src/mongo/client/dbclient_cursor_test.cpp +++ b/src/mongo/client/dbclient_cursor_test.cpp @@ -59,12 +59,14 @@ public: const auto reqId = nextMessageId(); toSend.header().setId(reqId); toSend.header().setResponseToMsgId(0); + OpMsg::appendChecksum(&toSend); _lastSent = toSend; // Mock response. response = _mockCallResponse; response.header().setId(nextMessageId()); response.header().setResponseToMsgId(reqId); + OpMsg::appendChecksum(&response); return true; } @@ -164,7 +166,7 @@ TEST_F(DBClientCursorTest, DBClientCursorHandlesOpMsgExhaustCorrectly) { auto m = conn.getLastSentMessage(); ASSERT(!m.empty()); auto msg = OpMsg::parse(m); - ASSERT_EQ(OpMsg::flags(m), 0U); + ASSERT_EQ(OpMsg::flags(m), OpMsg::kChecksumPresent); ASSERT_EQ(msg.body.getStringField("find"), nss.coll()); ASSERT_EQ(msg.body["batchSize"].number(), 0); @@ -433,7 +435,7 @@ TEST_F(DBClientCursorTest, DBClientCursorPassesReadOnceFlag) { auto m = conn.getLastSentMessage(); ASSERT(!m.empty()); auto msg = OpMsg::parse(m); - ASSERT_EQ(OpMsg::flags(m), 0U); + ASSERT_EQ(OpMsg::flags(m), OpMsg::kChecksumPresent); ASSERT_EQ(msg.body.getStringField("find"), nss.coll()); ASSERT_EQ(msg.body["batchSize"].number(), 0); ASSERT_TRUE(msg.body.getBoolField("readOnce")) << msg.body; diff --git a/src/mongo/db/traffic_reader.cpp b/src/mongo/db/traffic_reader.cpp index 7ce7ddcb112..18fa2baf7dd 100644 --- a/src/mongo/db/traffic_reader.cpp +++ b/src/mongo/db/traffic_reader.cpp @@ -192,7 +192,8 @@ void addOpType(TrafficReaderPacket& packet, BSONObjBuilder* builder) { if (packet.message.getNetworkOp() == dbMsg) { Message message; message.setData(dbMsg, packet.message.data(), packet.message.dataLen()); - + // Some header fields like requestId are missing, so the checksum won't match. + OpMsg::removeChecksum(&message); auto opMsg = rpc::opMsgRequestFromAnyProtocol(message); builder->append("opType", opMsg.getCommandName()); } else { diff --git a/src/mongo/dbtests/SConscript b/src/mongo/dbtests/SConscript index 8f9f1286a5a..c53dcb90f96 100644 --- a/src/mongo/dbtests/SConscript +++ b/src/mongo/dbtests/SConscript @@ -85,6 +85,7 @@ if not has_option('noshell') and usemozjs: 'mock_dbclient_conn_test.cpp', 'mock_replica_set_test.cpp', 'multikey_paths_test.cpp', + 'shared_buffer.cpp', 'pdfiletests.cpp', 'plan_executor_invalidation_test.cpp', 'plan_ranking.cpp', diff --git a/src/mongo/dbtests/shared_buffer.cpp b/src/mongo/dbtests/shared_buffer.cpp new file mode 100644 index 00000000000..1cb632d7b8c --- /dev/null +++ b/src/mongo/dbtests/shared_buffer.cpp @@ -0,0 +1,133 @@ +/** + * Copyright (C) 2019-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * . + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/platform/basic.h" + +#include "mongo/base/string_data.h" +#include "mongo/unittest/unittest.h" +#include "mongo/util/shared_buffer.h" + +namespace mongo { +namespace { + +using SharedBufferTest = unittest::Test; + +TEST_F(SharedBufferTest, ReallocOrCopyNull) { + SharedBuffer buf; + ASSERT_EQ(buf.capacity(), 0u); + ASSERT(!buf); + ASSERT(!buf.isShared()); + buf.reallocOrCopy(10); + ASSERT(buf); + ASSERT(!buf.isShared()); + ASSERT_EQ(buf.capacity(), 10u); +} + +TEST_F(SharedBufferTest, ReallocOrCopyNullShared) { + // null SharedBuffers are never considered "shared", even when copied. + SharedBuffer buf; + const SharedBuffer sharer = buf; + ASSERT_EQ(buf.capacity(), 0u); + ASSERT(!buf); + ASSERT(!buf.isShared()); + buf.reallocOrCopy(10); + ASSERT(buf); + ASSERT(!buf.isShared()); + ASSERT_EQ(buf.capacity(), 10u); + ASSERT_EQ(sharer.capacity(), 0u); +} + +SharedBuffer makeBuffer() { + SharedBuffer buf = SharedBuffer::allocate(4); + memcpy(buf.get(), "foo", 4); + return buf; +} + +TEST_F(SharedBufferTest, ReallocOrCopyGrow) { + SharedBuffer buf = makeBuffer(); + ASSERT_EQ(buf.capacity(), 4u); + ASSERT(buf); + ASSERT(!buf.isShared()); + buf.reallocOrCopy(10); + ASSERT(buf); + ASSERT(!buf.isShared()); + ASSERT_EQ(buf.capacity(), 10u); + ASSERT_EQ("foo"_sd, buf.get()); +} + +TEST_F(SharedBufferTest, ReallocOrCopyGrowShared) { + SharedBuffer buf = makeBuffer(); + const SharedBuffer sharer = buf; + ASSERT_EQ(buf.capacity(), 4u); + ASSERT(buf); + ASSERT(buf.isShared()); + buf.reallocOrCopy(10); + ASSERT(buf); + ASSERT(!buf.isShared()); + ASSERT_EQ(buf.capacity(), 10u); + ASSERT_EQ(sharer.capacity(), 4u); + ASSERT_EQ("foo"_sd, buf.get()); + ASSERT_EQ("foo"_sd, sharer.get()); + ASSERT_NE(buf.get(), sharer.get()); +} + +TEST_F(SharedBufferTest, ReallocOrCopyShrink) { + SharedBuffer buf = makeBuffer(); + ASSERT_EQ(buf.capacity(), 4u); + ASSERT(buf); + ASSERT(!buf.isShared()); + // The buffer is already at least 1 byte. + buf.reallocOrCopy(1); + ASSERT(buf); + // We copy it anyway. + ASSERT(!buf.isShared()); + ASSERT_EQ(buf.capacity(), 1u); + ASSERT_EQ('f', buf.get()[0]); +} + +TEST_F(SharedBufferTest, ReallocOrCopyShrinkShared) { + SharedBuffer buf = makeBuffer(); + const SharedBuffer sharer = buf; + ASSERT_EQ(buf.capacity(), 4u); + ASSERT(buf); + ASSERT(buf.isShared()); + // The buffer is already at least 1 byte. + buf.reallocOrCopy(1); + ASSERT(buf); + // We copy it anyway. + ASSERT(!buf.isShared()); + ASSERT_EQ(buf.capacity(), 1u); + ASSERT_EQ(sharer.capacity(), 4u); + ASSERT_EQ('f', buf.get()[0]); + ASSERT_EQ("foo"_sd, sharer.get()); + ASSERT_NE(buf.get(), sharer.get()); +} + +} // namespace +} // namespace mongo diff --git a/src/mongo/embedded/mongo_embedded/mongo_embedded.cpp b/src/mongo/embedded/mongo_embedded/mongo_embedded.cpp index b8e67033b14..e75edd1c0a3 100644 --- a/src/mongo/embedded/mongo_embedded/mongo_embedded.cpp +++ b/src/mongo/embedded/mongo_embedded/mongo_embedded.cpp @@ -410,6 +410,7 @@ void client_wire_protocol_rpc(mongo_embedded_v1_client* const client, client->response = sep->handleRequest(opCtx.get(), msg); + // Note that we skip OP_MSG's optional checksum for embedded. MsgData::View outMessage(client->response.response.buf()); outMessage.setId(nextMessageId()); outMessage.setResponseToMsgId(msg.header().getId()); diff --git a/src/mongo/rpc/SConscript b/src/mongo/rpc/SConscript index 9ce0c49fe40..af84d991ab1 100644 --- a/src/mongo/rpc/SConscript +++ b/src/mongo/rpc/SConscript @@ -36,6 +36,7 @@ env.Library( '$BUILD_DIR/mongo/bson/util/bson_extract', '$BUILD_DIR/mongo/db/bson/dotted_path_support', '$BUILD_DIR/mongo/db/server_options_core', + '$BUILD_DIR/third_party/wiredtiger/wiredtiger_checksum', ], ) @@ -141,6 +142,7 @@ env.CppUnitTest( ], LIBDEPS=[ '$BUILD_DIR/mongo/client/clientdriver_minimal', + '$BUILD_DIR/third_party/wiredtiger/wiredtiger_checksum', 'rpc', ], ) diff --git a/src/mongo/rpc/message.h b/src/mongo/rpc/message.h index 3689f2b2512..2acb34a632a 100644 --- a/src/mongo/rpc/message.h +++ b/src/mongo/rpc/message.h @@ -424,6 +424,14 @@ public: return size() - sizeof(MSGHEADER::Value); } + size_t capacity() const { + return _buf.capacity(); + } + + void realloc(size_t size) { + _buf.reallocOrCopy(size); + } + void reset() { _buf = {}; } diff --git a/src/mongo/rpc/op_msg.cpp b/src/mongo/rpc/op_msg.cpp index dd15704d4bc..d6dcad993c7 100644 --- a/src/mongo/rpc/op_msg.cpp +++ b/src/mongo/rpc/op_msg.cpp @@ -42,6 +42,7 @@ #include "mongo/util/bufreader.h" #include "mongo/util/hex.h" #include "mongo/util/log.h" +#include "third_party/wiredtiger/wiredtiger.h" namespace mongo { namespace { @@ -58,6 +59,18 @@ enum class Section : uint8_t { kDocSequence = 1, }; +constexpr int kCrc32Size = 4; + +// All fields including size, requestId, and responseTo must already be set. The size must already +// include the final 4-byte checksum. +uint32_t calculateChecksum(const Message& message) { + if (message.operation() != dbMsg) { + return 0; + } + + invariant(OpMsg::isFlagSet(message, OpMsg::kChecksumPresent)); + return wiredtiger_crc32c_func()(message.singleData().view2ptr(), message.size() - kCrc32Size); +} } // namespace uint32_t OpMsg::flags(const Message& message) { @@ -76,6 +89,31 @@ void OpMsg::replaceFlags(Message* message, uint32_t flags) { DataView(message->singleData().data()).write>(flags); } +uint32_t OpMsg::getChecksum(const Message& message) { + invariant(message.operation() == dbMsg); + invariant(isFlagSet(message, kChecksumPresent)); + return BufReader(message.singleData().view2ptr() + message.size() - kCrc32Size, kCrc32Size) + .read>(); +} + +void OpMsg::appendChecksum(Message* message) { + if (message->operation() != dbMsg) { + return; + } + + invariant(!isFlagSet(*message, kChecksumPresent)); + setFlag(message, kChecksumPresent); + const size_t newSize = message->size() + kCrc32Size; + if (message->capacity() < newSize) { + message->realloc(newSize); + } + + // Everything before the checksum, including the final size, is covered by the checksum. + message->header().setLen(newSize); + DataView(message->singleData().view2ptr() + newSize - kCrc32Size) + .write>(calculateChecksum(*message)); +} + OpMsg OpMsg::parse(const Message& message) try { // It is the caller's responsibility to call the correct parser for a given message type. invariant(!message.empty()); @@ -87,7 +125,6 @@ OpMsg OpMsg::parse(const Message& message) try { << std::bitset<32>(flags).to_string(), !containsUnknownRequiredFlags(flags)); - constexpr int kCrc32Size = 4; const bool haveChecksum = flags & kChecksumPresent; const int checksumSize = haveChecksum ? kCrc32Size : 0; @@ -152,6 +189,12 @@ OpMsg OpMsg::parse(const Message& message) try { !inBody); } + if (haveChecksum) { + uassert(ErrorCodes::ChecksumMismatch, + "OP_MSG checksum does not match contents", + OpMsg::getChecksum(message) == calculateChecksum(message)); + } + return msg; } catch (const DBException& ex) { LOG(1) << "invalid message: " << ex.code() << " " << redact(ex) << " -- " diff --git a/src/mongo/rpc/op_msg.h b/src/mongo/rpc/op_msg.h index 5983ed63226..3b13b586bff 100644 --- a/src/mongo/rpc/op_msg.h +++ b/src/mongo/rpc/op_msg.h @@ -73,6 +73,37 @@ struct OpMsg { replaceFlags(message, flags(*message) | flag); } + /** + * Removes a flag from the list of set flags in message. + * Only legal on an otherwise valid OP_MSG message. + */ + static void clearFlag(Message* message, uint32_t flag) { + replaceFlags(message, flags(*message) & ~flag); + } + + /** + * Retrieves the checksum stored at the end of the message. + */ + static uint32_t getChecksum(const Message& message); + + /** + * Add a checksum at the end of the message. Call this after setting size, requestId, and + * responseTo. The checksumPresent flag must *not* already be set. + */ + static void appendChecksum(Message* message); + + /** + * If the checksum is present, unsets the checksumPresent flag and shrinks message by 4 bytes. + */ + static void removeChecksum(Message* message) { + if (!isFlagSet(*message, kChecksumPresent)) { + return; + } + + clearFlag(message, kChecksumPresent); + message->header().setLen(message->size() - 4); + } + /** * Parses and returns an OpMsg containing unowned BSON. */ diff --git a/src/mongo/rpc/op_msg_integration_test.cpp b/src/mongo/rpc/op_msg_integration_test.cpp index 098b9af1323..561264b5db7 100644 --- a/src/mongo/rpc/op_msg_integration_test.cpp +++ b/src/mongo/rpc/op_msg_integration_test.cpp @@ -196,6 +196,9 @@ TEST(OpMsg, CloseConnectionOnFireAndForgetNotMasterError) { // rather than say() so that we get an error back when the connection is closed. Normally // using call() if kMoreToCome set results in blocking forever. OpMsg::setFlag(&request, OpMsg::kMoreToCome); + // conn.call() calculated the request checksum, but setFlag() makes it invalid. Clear the + // checksum so the next conn.call() recalculates it. + OpMsg::removeChecksum(&request); ASSERT(!conn.call(request, reply, /*assertOK*/ false, nullptr)); uassertStatusOK(conn.connect(host, "integration_test")); // Reconnect. @@ -224,6 +227,7 @@ TEST(OpMsg, CloseConnectionOnFireAndForgetNotMasterError) { // Round-trip command claims to succeed due to w:0. + OpMsg::removeChecksum(&request); OpMsg::replaceFlags(&request, 0); ASSERT(conn.call(request, reply, /*assertOK*/ true, nullptr)); ASSERT_OK(getStatusFromCommandResult( @@ -231,6 +235,7 @@ TEST(OpMsg, CloseConnectionOnFireAndForgetNotMasterError) { // Fire-and-forget should still close connection. OpMsg::setFlag(&request, OpMsg::kMoreToCome); + OpMsg::removeChecksum(&request); ASSERT(!conn.call(request, reply, /*assertOK*/ false, nullptr)); break; @@ -270,7 +275,19 @@ TEST(OpMsg, DocumentSequenceReturnsWork) { << "admin")); } -TEST(OpMsg, ServerHandlesExhaustCorrectly) { +constexpr auto kDisableChecksum = "dbClientConnectionDisableChecksum"; + +void disableClientChecksum() { + auto failPoint = getGlobalFailPointRegistry()->getFailPoint(kDisableChecksum); + failPoint->setMode(FailPoint::alwaysOn); +} + +void enableClientChecksum() { + auto failPoint = getGlobalFailPointRegistry()->getFailPoint(kDisableChecksum); + failPoint->setMode(FailPoint::off); +} + +void exhaustTest(bool enableChecksum) { std::string errMsg; auto conn = std::unique_ptr( unittest::getFixtureConnectionString().connect("integration_test", errMsg)); @@ -281,6 +298,12 @@ TEST(OpMsg, ServerHandlesExhaustCorrectly) { return; } + if (!enableChecksum) { + disableClientChecksum(); + } + + ON_BLOCK_EXIT([&] { enableClientChecksum(); }); + NamespaceString nss("test", "coll"); conn->dropCollection(nss.toString()); @@ -301,6 +324,8 @@ TEST(OpMsg, ServerHandlesExhaustCorrectly) { const long long cursorId = res["cursor"]["id"].numberLong(); ASSERT(res["cursor"]["firstBatch"].Array().empty()); ASSERT(!OpMsg::isFlagSet(reply, OpMsg::kMoreToCome)); + // Reply has checksum if and only if the request did. + ASSERT_EQ(OpMsg::isFlagSet(reply, OpMsg::kChecksumPresent), enableChecksum); // Construct getMore request with exhaust flag. Set batch size so we will need multiple batches // to exhaust the cursor. @@ -314,6 +339,7 @@ TEST(OpMsg, ServerHandlesExhaustCorrectly) { ASSERT(conn->call(request, reply)); auto lastRequestId = reply.header().getId(); ASSERT(OpMsg::isFlagSet(reply, OpMsg::kMoreToCome)); + ASSERT_EQ(OpMsg::isFlagSet(reply, OpMsg::kChecksumPresent), enableChecksum); res = OpMsg::parse(reply).body; ASSERT_OK(getStatusFromCommandResult(res)); ASSERT_EQ(res["cursor"]["id"].numberLong(), cursorId); @@ -326,6 +352,7 @@ TEST(OpMsg, ServerHandlesExhaustCorrectly) { conn->recv(reply, lastRequestId); lastRequestId = reply.header().getId(); ASSERT(OpMsg::isFlagSet(reply, OpMsg::kMoreToCome)); + ASSERT_EQ(OpMsg::isFlagSet(reply, OpMsg::kChecksumPresent), enableChecksum); res = OpMsg::parse(reply).body; ASSERT_OK(getStatusFromCommandResult(res)); ASSERT_EQ(res["cursor"]["id"].numberLong(), cursorId); @@ -337,6 +364,7 @@ TEST(OpMsg, ServerHandlesExhaustCorrectly) { // Receive terminal batch. ASSERT(conn->recv(reply, lastRequestId)); ASSERT(!OpMsg::isFlagSet(reply, OpMsg::kMoreToCome)); + ASSERT_EQ(OpMsg::isFlagSet(reply, OpMsg::kChecksumPresent), enableChecksum); res = OpMsg::parse(reply).body; ASSERT_OK(getStatusFromCommandResult(res)); ASSERT_EQ(res["cursor"]["id"].numberLong(), 0); @@ -345,6 +373,14 @@ TEST(OpMsg, ServerHandlesExhaustCorrectly) { ASSERT_BSONOBJ_EQ(nextBatch[0].embeddedObject(), BSON("_id" << 4)); } +TEST(OpMsg, ServerHandlesExhaustCorrectly) { + exhaustTest(false); +} + +TEST(OpMsg, ServerHandlesExhaustCorrectlyWithChecksum) { + exhaustTest(true); +} + TEST(OpMsg, ExhaustWithDBClientCursorBehavesCorrectly) { // This test simply tries to verify that using the exhaust option with DBClientCursor works // correctly. The externally visible behavior should technically be the same as a non-exhaust @@ -400,4 +436,35 @@ TEST(OpMsg, ExhaustWithDBClientCursorBehavesCorrectly) { ASSERT(!cursor->more()); ASSERT(cursor->isDead()); } + +void checksumTest(bool enableChecksum) { + // The server replies with a checksum if and only if the request has a checksum. + std::string errMsg; + auto conn = std::unique_ptr( + unittest::getFixtureConnectionString().connect("integration_test", errMsg)); + uassert(ErrorCodes::SocketException, errMsg, conn); + + if (!enableChecksum) { + disableClientChecksum(); + } + + ON_BLOCK_EXIT([&] { enableClientChecksum(); }); + + auto opMsgRequest = OpMsgRequest::fromDBAndBody("admin", BSON("ping" << 1)); + auto request = opMsgRequest.serialize(); + + Message reply; + ASSERT(conn->call(request, reply)); + + auto opMsgReply = OpMsg::parse(reply); + ASSERT_EQ(OpMsg::isFlagSet(reply, OpMsg::kChecksumPresent), enableChecksum); +} + +TEST(OpMsg, ServerRepliesWithoutChecksumToRequestWithoutChecksum) { + checksumTest(true); +} + +TEST(OpMsg, ServerRepliesWithChecksumToRequestWithChecksum) { + checksumTest(true); +} } // namespace mongo diff --git a/src/mongo/rpc/op_msg_test.cpp b/src/mongo/rpc/op_msg_test.cpp index 314f6b99284..bf280768638 100644 --- a/src/mongo/rpc/op_msg_test.cpp +++ b/src/mongo/rpc/op_msg_test.cpp @@ -41,6 +41,7 @@ #include "mongo/unittest/unittest.h" #include "mongo/util/hex.h" #include "mongo/util/log.h" +#include "third_party/wiredtiger/wiredtiger.h" namespace mongo { namespace { @@ -94,14 +95,19 @@ public: explicit Sized(T&&... args) { buffer.skip(sizeof(int32_t)); append(args...); - DataView(buffer.buf()).write>(buffer.len()); + updateSize(); } // Adds extra to the stored size. Use this to produce illegal messages. Sized&& addToSize(int32_t extra) && { - DataView(buffer.buf()).write>(buffer.len() + extra); + updateSize(extra); return std::move(*this); } + +protected: + void updateSize(int32_t extra = 0) { + DataView(buffer.buf()).write>(buffer.len() + extra); + } }; // A Bytes that puts the standard message header at the front. @@ -127,7 +133,25 @@ public: } OpMsgBytes&& addToSize(int32_t extra) && { - DataView(buffer.buf()).write>(buffer.len() + extra); + updateSize(extra); + return std::move(*this); + } + + OpMsgBytes&& appendChecksum() && { + // Reserve space at the end for the checksum. + append(0); + updateSize(); + // Checksum all bits except the checksum itself. + uint32_t checksum = wiredtiger_crc32c_func()(buffer.buf(), buffer.len() - 4); + // Write the checksum bits at the end. + auto checksumBits = DataView(buffer.buf() + buffer.len() - sizeof(checksum)); + checksumBits.write>(checksum); + return std::move(*this); + } + + OpMsgBytes&& appendChecksum(uint32_t checksum) && { + append(checksum); + updateSize(); return std::move(*this); } }; @@ -158,9 +182,6 @@ const char kDocSequenceSection = 1; const uint32_t kNoFlags = 0; const uint32_t kHaveChecksum = 1; -// CRC filler value -const uint32_t kFakeCRC = 0; // TODO will need to compute real crc when SERVER-28679 is done. - TEST_F(OpMsgParser, SucceedsWithJustBody) { auto msg = OpMsgBytes{ kNoFlags, // @@ -172,13 +193,12 @@ TEST_F(OpMsgParser, SucceedsWithJustBody) { ASSERT_EQ(msg.sequences.size(), 0u); } -TEST_F(OpMsgParser, IgnoresCrcIfPresent) { // Until SERVER-28679 is done. - auto msg = OpMsgBytes{ - kHaveChecksum, // - kBodySection, - fromjson("{ping: 1}"), - kFakeCRC, // If not ignored, this would be read as a second body. - }.parse(); +TEST_F(OpMsgParser, SucceedsWithChecksum) { + auto msg = OpMsgBytes{kHaveChecksum, // + kBodySection, + fromjson("{ping: 1}")} + .appendChecksum() + .parse(); ASSERT_BSONOBJ_EQ(msg.body, fromjson("{ping: 1}")); ASSERT_EQ(msg.sequences.size(), 0u); @@ -422,12 +442,13 @@ TEST_F(OpMsgParser, FailsIfBodyTooBig) { } TEST_F(OpMsgParser, FailsIfBodyTooBigIntoChecksum) { - auto msg = OpMsgBytes{ - kHaveChecksum, // - kBodySection, - fromjson("{ping: 1}"), - kFakeCRC, - }.addToSize(-1); // Shrink message so body extends past end. + auto msg = + OpMsgBytes{ + kHaveChecksum, // + kBodySection, + fromjson("{ping: 1}"), + }.appendChecksum() + .addToSize(-1); // Shrink message so body extends past end. ASSERT_THROWS_CODE(msg.parse(), AssertionException, ErrorCodes::InvalidBSON); } @@ -449,19 +470,19 @@ TEST_F(OpMsgParser, FailsIfDocumentSequenceTooBig) { } TEST_F(OpMsgParser, FailsIfDocumentSequenceTooBigIntoChecksum) { - auto msg = OpMsgBytes{ - kHaveChecksum, // - kBodySection, - fromjson("{ping: 1}"), - - kDocSequenceSection, - Sized{ - "docs", // - fromjson("{a: 1}"), - }, + auto msg = + OpMsgBytes{ + kHaveChecksum, // + kBodySection, + fromjson("{ping: 1}"), - kFakeCRC, - }.addToSize(-1); // Shrink message so body extends past end. + kDocSequenceSection, + Sized{ + "docs", // + fromjson("{a: 1}"), + }, + }.appendChecksum() + .addToSize(-1); // Shrink message so body extends past end. ASSERT_THROWS_CODE(msg.parse(), AssertionException, ErrorCodes::Overflow); } @@ -594,6 +615,18 @@ TEST_F(OpMsgParser, SucceedsWithUnknownOptionalFlags) { } } +TEST_F(OpMsgParser, FailsWithChecksumMismatch) { + auto msg = OpMsgBytes{kHaveChecksum, // + kBodySection, + fromjson("{ping: 1}")} + .appendChecksum(123); + + ASSERT_THROWS_WITH_CHECK(msg.parse(), AssertionException, [](const DBException& ex) { + ASSERT_EQ(ex.toStatus().code(), ErrorCodes::ChecksumMismatch); + ASSERT(ErrorCodes::isConnectionFatalMessageParseError(ex.toStatus().code())); + }); +} + void testSerializer(const Message& fromSerializer, OpMsgBytes&& expected) { const auto expectedMsg = expected.done(); ASSERT_EQ(fromSerializer.operation(), dbMsg); @@ -860,5 +893,19 @@ TEST(OpMsgRequest, FromDbAndBodyDoesNotCopy) { ASSERT_BSONOBJ_EQ(msg.body, fromjson("{ping: 1, $db: 'db'}")); ASSERT_EQ(static_cast(msg.body.objdata()), bodyPtr); } + +TEST(OpMsgTest, ChecksumResizesMessage) { + auto msg = OpMsgBytes{kNoFlags, // + kBodySection, + fromjson("{ping: 1}")} + .done(); + + // Test that appendChecksum() resizes the buffer if necessary. + const auto capacity = msg.sharedBuffer().capacity(); + OpMsg::appendChecksum(&msg); + ASSERT_EQ(msg.sharedBuffer().capacity(), capacity + 4); + // The checksum is correct. + OpMsg::parse(msg); +} } // namespace } // namespace mongo diff --git a/src/mongo/rpc/reply_builder_test.cpp b/src/mongo/rpc/reply_builder_test.cpp index 5bb4ff13320..fcdb036c650 100644 --- a/src/mongo/rpc/reply_builder_test.cpp +++ b/src/mongo/rpc/reply_builder_test.cpp @@ -151,6 +151,9 @@ TEST(OpMsgReplyBuilder, CommandError) { replyBuilder.setCommandReply(status, extraObj); replyBuilder.getBodyBuilder().appendElements(metadata); auto msg = replyBuilder.done(); + msg.header().setId(124); + msg.header().setResponseToMsgId(123); + OpMsg::appendChecksum(&msg); rpc::OpMsgReply parsed(&msg); @@ -172,7 +175,9 @@ void testRoundTrip(rpc::ReplyBuilderInterface& replyBuilder, bool unifiedBodyAnd replyBuilder.getBodyBuilder().appendElements(metadata); auto msg = replyBuilder.done(); - + msg.header().setId(124); + msg.header().setResponseToMsgId(123); + OpMsg::appendChecksum(&msg); T parsed(&msg); if (unifiedBodyAndMetadata) { @@ -197,7 +202,10 @@ void testErrors(rpc::ReplyBuilderInterface& replyBuilder) { replyBuilder.setCommandReply(status); replyBuilder.getBodyBuilder().appendElements(buildMetadata()); - const auto msg = replyBuilder.done(); + auto msg = replyBuilder.done(); + msg.header().setId(124); + msg.header().setResponseToMsgId(123); + OpMsg::appendChecksum(&msg); T parsed(&msg); const Status result = getStatusFromCommandResult(parsed.getCommandReply()); diff --git a/src/mongo/tools/bridge.cpp b/src/mongo/tools/bridge.cpp index 4cce5d997b8..b3908037c4b 100644 --- a/src/mongo/tools/bridge.cpp +++ b/src/mongo/tools/bridge.cpp @@ -383,6 +383,10 @@ DbResponse ServiceEntryPointBridge::handleRequest(OperationContext* opCtx, const } else { dest.setExhaust(false); } + + // The original checksum won't be valid once the network layer replaces requestId. Remove it + // because the network layer re-checksums the response. + OpMsg::removeChecksum(&response); return {std::move(response), exhaustNS}; } else { return {Message()}; diff --git a/src/mongo/transport/service_state_machine.cpp b/src/mongo/transport/service_state_machine.cpp index ebcad102106..8e751d9327c 100644 --- a/src/mongo/transport/service_state_machine.cpp +++ b/src/mongo/transport/service_state_machine.cpp @@ -140,14 +140,23 @@ Message makeExhaustMessage(Message requestMsg, DbResponse* dbresponse) { return Message(); } - // Indicate that the response is part of an exhaust stream. + const bool checksumPresent = OpMsg::isFlagSet(requestMsg, OpMsg::kChecksumPresent); + OpMsg::removeChecksum(&dbresponse->response); + // Indicate that the response is part of an exhaust stream. Re-checksum if needed. OpMsg::setFlag(&dbresponse->response, OpMsg::kMoreToCome); + if (checksumPresent) { + OpMsg::appendChecksum(&dbresponse->response); + } // Return an augmented form of the initial request, which is to be used as the next request to // be processed by the database. The id of the response is used as the request id of this - // 'synthetic' request. + // 'synthetic' request. Re-checksum if needed. + OpMsg::removeChecksum(&requestMsg); requestMsg.header().setId(dbresponse->response.header().getId()); requestMsg.header().setResponseToMsgId(dbresponse->response.header().getResponseToMsgId()); + if (checksumPresent) { + OpMsg::appendChecksum(&requestMsg); + } return requestMsg; } @@ -454,10 +463,14 @@ void ServiceStateMachine::_processMessage(ThreadGuard guard) { Message& toSink = dbresponse.response; if (!toSink.empty()) { invariant(!OpMsg::isFlagSet(_inMessage, OpMsg::kMoreToCome)); + invariant(!OpMsg::isFlagSet(toSink, OpMsg::kChecksumPresent)); // Update the header for the response message. toSink.header().setId(nextMessageId()); toSink.header().setResponseToMsgId(_inMessage.header().getId()); + if (OpMsg::isFlagSet(_inMessage, OpMsg::kChecksumPresent)) { + OpMsg::appendChecksum(&toSink); + } // If the incoming message has the exhaust flag set and is a 'getMore' command, then we // bypass the normal RPC behavior. We will sink the response to the network, but we also diff --git a/src/mongo/transport/transport_layer_asio_test.cpp b/src/mongo/transport/transport_layer_asio_test.cpp index 7bbaab7cd59..45552516ee0 100644 --- a/src/mongo/transport/transport_layer_asio_test.cpp +++ b/src/mongo/transport/transport_layer_asio_test.cpp @@ -259,6 +259,7 @@ public: Message msg = builder.finish(); msg.header().setResponseToMsgId(0); msg.header().setId(0); + OpMsg::appendChecksum(&msg); std::error_code ec; asio::write(_sock, asio::buffer(msg.buf(), msg.size()), ec); diff --git a/src/mongo/util/shared_buffer.h b/src/mongo/util/shared_buffer.h index e18db9b7e3c..b0bc09c769b 100644 --- a/src/mongo/util/shared_buffer.h +++ b/src/mongo/util/shared_buffer.h @@ -29,6 +29,8 @@ #pragma once +#include + #include #include "mongo/platform/atomic_word.h" @@ -73,6 +75,23 @@ public: _holder = std::move(tmp._holder); } + /** + * Resizes the buffer, copying the current contents. If shared, an exclusive copy is made. + */ + void reallocOrCopy(size_t size) { + if (isShared()) { + auto tmp = SharedBuffer::allocate(size); + memcpy(tmp._holder->data(), + _holder->data(), + std::min(size, static_cast(_holder->_capacity))); + swap(tmp); + } else if (_holder) { + realloc(size); + } else { + *this = SharedBuffer::allocate(size); + } + } + char* get() const { return _holder ? _holder->data() : NULL; } -- cgit v1.2.1