summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorA. Jesse Jiryu Davis <jesse@mongodb.com>2019-05-04 09:03:54 -0400
committerA. Jesse Jiryu Davis <jesse@mongodb.com>2019-05-14 19:23:10 -0400
commitc83e50d7275adf2a5e946ba2c4b0861fcd9dc69b (patch)
tree2f672def64169a68c2017a460896aae6ce67c2e5 /src
parent089dd83af48cf198916e2dca50742378d4c3d361 (diff)
downloadmongo-c83e50d7275adf2a5e946ba2c4b0861fcd9dc69b.tar.gz
SERVER-28679 Set OP_MSG checksum
Diffstat (limited to 'src')
-rw-r--r--src/mongo/base/error_codes.err5
-rw-r--r--src/mongo/client/async_client.cpp1
-rw-r--r--src/mongo/client/dbclient_connection.cpp8
-rw-r--r--src/mongo/client/dbclient_cursor_test.cpp6
-rw-r--r--src/mongo/db/traffic_reader.cpp3
-rw-r--r--src/mongo/dbtests/SConscript1
-rw-r--r--src/mongo/dbtests/shared_buffer.cpp133
-rw-r--r--src/mongo/embedded/mongo_embedded/mongo_embedded.cpp1
-rw-r--r--src/mongo/rpc/SConscript2
-rw-r--r--src/mongo/rpc/message.h8
-rw-r--r--src/mongo/rpc/op_msg.cpp45
-rw-r--r--src/mongo/rpc/op_msg.h31
-rw-r--r--src/mongo/rpc/op_msg_integration_test.cpp69
-rw-r--r--src/mongo/rpc/op_msg_test.cpp109
-rw-r--r--src/mongo/rpc/reply_builder_test.cpp12
-rw-r--r--src/mongo/tools/bridge.cpp4
-rw-r--r--src/mongo/transport/service_state_machine.cpp17
-rw-r--r--src/mongo/transport/transport_layer_asio_test.cpp1
-rw-r--r--src/mongo/util/shared_buffer.h19
19 files changed, 433 insertions, 42 deletions
diff --git a/src/mongo/base/error_codes.err b/src/mongo/base/error_codes.err
index 14be694a0bc..b34ac7c5c61 100644
--- a/src/mongo/base/error_codes.err
+++ b/src/mongo/base/error_codes.err
@@ -287,6 +287,7 @@ error_code("IndexBuildAlreadyInProgress", 285)
error_code("ChangeStreamHistoryLost", 286)
# The code below is for internal use only and must never be returned in a network response
error_code("TransactionCoordinatorDeadlineTaskCanceled", 287)
+error_code("ChecksumMismatch", 288)
# Error codes 4000-8999 are reserved.
@@ -348,9 +349,9 @@ error_class("ShutdownError", ["ShutdownInProgress", "InterruptedAtShutdown"])
# indicates that it cannot be executed as normal and must abort its intended work.
error_class("CancelationError", ["ShutdownInProgress", "InterruptedAtShutdown", "CallbackCanceled"])
-#TODO SERVER-28679 add checksum failure.
error_class("ConnectionFatalMessageParseError", ["IllegalOpMsgFlag",
- "TooManyDocumentSequences"])
+ "TooManyDocumentSequences",
+ "ChecksumMismatch"])
error_class("ExceededTimeLimitError", ["ExceededTimeLimit", "MaxTimeMSExpired", "NetworkInterfaceExceededTimeLimit"])
diff --git a/src/mongo/client/async_client.cpp b/src/mongo/client/async_client.cpp
index e99af8a258b..7410b1a89d9 100644
--- a/src/mongo/client/async_client.cpp
+++ b/src/mongo/client/async_client.cpp
@@ -202,6 +202,7 @@ Future<Message> AsyncDBClient::_call(Message request, const BatonHandle& baton)
auto msgId = nextMessageId();
request.header().setId(msgId);
request.header().setResponseToMsgId(0);
+ OpMsg::appendChecksum(&request);
return _session->asyncSinkMessage(request, baton)
.then([this, baton] { return _session->asyncSourceMessage(baton); })
diff --git a/src/mongo/client/dbclient_connection.cpp b/src/mongo/client/dbclient_connection.cpp
index 2341e2a8a69..b01edcc993c 100644
--- a/src/mongo/client/dbclient_connection.cpp
+++ b/src/mongo/client/dbclient_connection.cpp
@@ -85,6 +85,8 @@ using std::endl;
using std::map;
using std::string;
+MONGO_FAIL_POINT_DEFINE(dbClientConnectionDisableChecksum);
+
namespace {
/**
@@ -576,6 +578,9 @@ void DBClientConnection::say(Message& toSend, bool isRetry, string* actualServer
toSend.header().setId(nextMessageId());
toSend.header().setResponseToMsgId(0);
+ if (!MONGO_FAIL_POINT(dbClientConnectionDisableChecksum)) {
+ OpMsg::appendChecksum(&toSend);
+ }
uassertStatusOK(
_session->sinkMessage(uassertStatusOK(_compressorManager.compressMessage(toSend))));
killSessionOnError.dismiss();
@@ -619,6 +624,9 @@ bool DBClientConnection::call(Message& toSend,
toSend.header().setId(nextMessageId());
toSend.header().setResponseToMsgId(0);
+ if (!MONGO_FAIL_POINT(dbClientConnectionDisableChecksum)) {
+ OpMsg::appendChecksum(&toSend);
+ }
auto swm = _compressorManager.compressMessage(toSend);
uassertStatusOK(swm.getStatus());
diff --git a/src/mongo/client/dbclient_cursor_test.cpp b/src/mongo/client/dbclient_cursor_test.cpp
index 6429df5c604..292ce2c8bb5 100644
--- a/src/mongo/client/dbclient_cursor_test.cpp
+++ b/src/mongo/client/dbclient_cursor_test.cpp
@@ -59,12 +59,14 @@ public:
const auto reqId = nextMessageId();
toSend.header().setId(reqId);
toSend.header().setResponseToMsgId(0);
+ OpMsg::appendChecksum(&toSend);
_lastSent = toSend;
// Mock response.
response = _mockCallResponse;
response.header().setId(nextMessageId());
response.header().setResponseToMsgId(reqId);
+ OpMsg::appendChecksum(&response);
return true;
}
@@ -164,7 +166,7 @@ TEST_F(DBClientCursorTest, DBClientCursorHandlesOpMsgExhaustCorrectly) {
auto m = conn.getLastSentMessage();
ASSERT(!m.empty());
auto msg = OpMsg::parse(m);
- ASSERT_EQ(OpMsg::flags(m), 0U);
+ ASSERT_EQ(OpMsg::flags(m), OpMsg::kChecksumPresent);
ASSERT_EQ(msg.body.getStringField("find"), nss.coll());
ASSERT_EQ(msg.body["batchSize"].number(), 0);
@@ -433,7 +435,7 @@ TEST_F(DBClientCursorTest, DBClientCursorPassesReadOnceFlag) {
auto m = conn.getLastSentMessage();
ASSERT(!m.empty());
auto msg = OpMsg::parse(m);
- ASSERT_EQ(OpMsg::flags(m), 0U);
+ ASSERT_EQ(OpMsg::flags(m), OpMsg::kChecksumPresent);
ASSERT_EQ(msg.body.getStringField("find"), nss.coll());
ASSERT_EQ(msg.body["batchSize"].number(), 0);
ASSERT_TRUE(msg.body.getBoolField("readOnce")) << msg.body;
diff --git a/src/mongo/db/traffic_reader.cpp b/src/mongo/db/traffic_reader.cpp
index 7ce7ddcb112..18fa2baf7dd 100644
--- a/src/mongo/db/traffic_reader.cpp
+++ b/src/mongo/db/traffic_reader.cpp
@@ -192,7 +192,8 @@ void addOpType(TrafficReaderPacket& packet, BSONObjBuilder* builder) {
if (packet.message.getNetworkOp() == dbMsg) {
Message message;
message.setData(dbMsg, packet.message.data(), packet.message.dataLen());
-
+ // Some header fields like requestId are missing, so the checksum won't match.
+ OpMsg::removeChecksum(&message);
auto opMsg = rpc::opMsgRequestFromAnyProtocol(message);
builder->append("opType", opMsg.getCommandName());
} else {
diff --git a/src/mongo/dbtests/SConscript b/src/mongo/dbtests/SConscript
index 8f9f1286a5a..c53dcb90f96 100644
--- a/src/mongo/dbtests/SConscript
+++ b/src/mongo/dbtests/SConscript
@@ -85,6 +85,7 @@ if not has_option('noshell') and usemozjs:
'mock_dbclient_conn_test.cpp',
'mock_replica_set_test.cpp',
'multikey_paths_test.cpp',
+ 'shared_buffer.cpp',
'pdfiletests.cpp',
'plan_executor_invalidation_test.cpp',
'plan_ranking.cpp',
diff --git a/src/mongo/dbtests/shared_buffer.cpp b/src/mongo/dbtests/shared_buffer.cpp
new file mode 100644
index 00000000000..1cb632d7b8c
--- /dev/null
+++ b/src/mongo/dbtests/shared_buffer.cpp
@@ -0,0 +1,133 @@
+/**
+ * Copyright (C) 2019-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/platform/basic.h"
+
+#include "mongo/base/string_data.h"
+#include "mongo/unittest/unittest.h"
+#include "mongo/util/shared_buffer.h"
+
+namespace mongo {
+namespace {
+
+using SharedBufferTest = unittest::Test;
+
+TEST_F(SharedBufferTest, ReallocOrCopyNull) {
+ SharedBuffer buf;
+ ASSERT_EQ(buf.capacity(), 0u);
+ ASSERT(!buf);
+ ASSERT(!buf.isShared());
+ buf.reallocOrCopy(10);
+ ASSERT(buf);
+ ASSERT(!buf.isShared());
+ ASSERT_EQ(buf.capacity(), 10u);
+}
+
+TEST_F(SharedBufferTest, ReallocOrCopyNullShared) {
+ // null SharedBuffers are never considered "shared", even when copied.
+ SharedBuffer buf;
+ const SharedBuffer sharer = buf;
+ ASSERT_EQ(buf.capacity(), 0u);
+ ASSERT(!buf);
+ ASSERT(!buf.isShared());
+ buf.reallocOrCopy(10);
+ ASSERT(buf);
+ ASSERT(!buf.isShared());
+ ASSERT_EQ(buf.capacity(), 10u);
+ ASSERT_EQ(sharer.capacity(), 0u);
+}
+
+SharedBuffer makeBuffer() {
+ SharedBuffer buf = SharedBuffer::allocate(4);
+ memcpy(buf.get(), "foo", 4);
+ return buf;
+}
+
+TEST_F(SharedBufferTest, ReallocOrCopyGrow) {
+ SharedBuffer buf = makeBuffer();
+ ASSERT_EQ(buf.capacity(), 4u);
+ ASSERT(buf);
+ ASSERT(!buf.isShared());
+ buf.reallocOrCopy(10);
+ ASSERT(buf);
+ ASSERT(!buf.isShared());
+ ASSERT_EQ(buf.capacity(), 10u);
+ ASSERT_EQ("foo"_sd, buf.get());
+}
+
+TEST_F(SharedBufferTest, ReallocOrCopyGrowShared) {
+ SharedBuffer buf = makeBuffer();
+ const SharedBuffer sharer = buf;
+ ASSERT_EQ(buf.capacity(), 4u);
+ ASSERT(buf);
+ ASSERT(buf.isShared());
+ buf.reallocOrCopy(10);
+ ASSERT(buf);
+ ASSERT(!buf.isShared());
+ ASSERT_EQ(buf.capacity(), 10u);
+ ASSERT_EQ(sharer.capacity(), 4u);
+ ASSERT_EQ("foo"_sd, buf.get());
+ ASSERT_EQ("foo"_sd, sharer.get());
+ ASSERT_NE(buf.get(), sharer.get());
+}
+
+TEST_F(SharedBufferTest, ReallocOrCopyShrink) {
+ SharedBuffer buf = makeBuffer();
+ ASSERT_EQ(buf.capacity(), 4u);
+ ASSERT(buf);
+ ASSERT(!buf.isShared());
+ // The buffer is already at least 1 byte.
+ buf.reallocOrCopy(1);
+ ASSERT(buf);
+ // We copy it anyway.
+ ASSERT(!buf.isShared());
+ ASSERT_EQ(buf.capacity(), 1u);
+ ASSERT_EQ('f', buf.get()[0]);
+}
+
+TEST_F(SharedBufferTest, ReallocOrCopyShrinkShared) {
+ SharedBuffer buf = makeBuffer();
+ const SharedBuffer sharer = buf;
+ ASSERT_EQ(buf.capacity(), 4u);
+ ASSERT(buf);
+ ASSERT(buf.isShared());
+ // The buffer is already at least 1 byte.
+ buf.reallocOrCopy(1);
+ ASSERT(buf);
+ // We copy it anyway.
+ ASSERT(!buf.isShared());
+ ASSERT_EQ(buf.capacity(), 1u);
+ ASSERT_EQ(sharer.capacity(), 4u);
+ ASSERT_EQ('f', buf.get()[0]);
+ ASSERT_EQ("foo"_sd, sharer.get());
+ ASSERT_NE(buf.get(), sharer.get());
+}
+
+} // namespace
+} // namespace mongo
diff --git a/src/mongo/embedded/mongo_embedded/mongo_embedded.cpp b/src/mongo/embedded/mongo_embedded/mongo_embedded.cpp
index b8e67033b14..e75edd1c0a3 100644
--- a/src/mongo/embedded/mongo_embedded/mongo_embedded.cpp
+++ b/src/mongo/embedded/mongo_embedded/mongo_embedded.cpp
@@ -410,6 +410,7 @@ void client_wire_protocol_rpc(mongo_embedded_v1_client* const client,
client->response = sep->handleRequest(opCtx.get(), msg);
+ // Note that we skip OP_MSG's optional checksum for embedded.
MsgData::View outMessage(client->response.response.buf());
outMessage.setId(nextMessageId());
outMessage.setResponseToMsgId(msg.header().getId());
diff --git a/src/mongo/rpc/SConscript b/src/mongo/rpc/SConscript
index 9ce0c49fe40..af84d991ab1 100644
--- a/src/mongo/rpc/SConscript
+++ b/src/mongo/rpc/SConscript
@@ -36,6 +36,7 @@ env.Library(
'$BUILD_DIR/mongo/bson/util/bson_extract',
'$BUILD_DIR/mongo/db/bson/dotted_path_support',
'$BUILD_DIR/mongo/db/server_options_core',
+ '$BUILD_DIR/third_party/wiredtiger/wiredtiger_checksum',
],
)
@@ -141,6 +142,7 @@ env.CppUnitTest(
],
LIBDEPS=[
'$BUILD_DIR/mongo/client/clientdriver_minimal',
+ '$BUILD_DIR/third_party/wiredtiger/wiredtiger_checksum',
'rpc',
],
)
diff --git a/src/mongo/rpc/message.h b/src/mongo/rpc/message.h
index 3689f2b2512..2acb34a632a 100644
--- a/src/mongo/rpc/message.h
+++ b/src/mongo/rpc/message.h
@@ -424,6 +424,14 @@ public:
return size() - sizeof(MSGHEADER::Value);
}
+ size_t capacity() const {
+ return _buf.capacity();
+ }
+
+ void realloc(size_t size) {
+ _buf.reallocOrCopy(size);
+ }
+
void reset() {
_buf = {};
}
diff --git a/src/mongo/rpc/op_msg.cpp b/src/mongo/rpc/op_msg.cpp
index dd15704d4bc..d6dcad993c7 100644
--- a/src/mongo/rpc/op_msg.cpp
+++ b/src/mongo/rpc/op_msg.cpp
@@ -42,6 +42,7 @@
#include "mongo/util/bufreader.h"
#include "mongo/util/hex.h"
#include "mongo/util/log.h"
+#include "third_party/wiredtiger/wiredtiger.h"
namespace mongo {
namespace {
@@ -58,6 +59,18 @@ enum class Section : uint8_t {
kDocSequence = 1,
};
+constexpr int kCrc32Size = 4;
+
+// All fields including size, requestId, and responseTo must already be set. The size must already
+// include the final 4-byte checksum.
+uint32_t calculateChecksum(const Message& message) {
+ if (message.operation() != dbMsg) {
+ return 0;
+ }
+
+ invariant(OpMsg::isFlagSet(message, OpMsg::kChecksumPresent));
+ return wiredtiger_crc32c_func()(message.singleData().view2ptr(), message.size() - kCrc32Size);
+}
} // namespace
uint32_t OpMsg::flags(const Message& message) {
@@ -76,6 +89,31 @@ void OpMsg::replaceFlags(Message* message, uint32_t flags) {
DataView(message->singleData().data()).write<LittleEndian<uint32_t>>(flags);
}
+uint32_t OpMsg::getChecksum(const Message& message) {
+ invariant(message.operation() == dbMsg);
+ invariant(isFlagSet(message, kChecksumPresent));
+ return BufReader(message.singleData().view2ptr() + message.size() - kCrc32Size, kCrc32Size)
+ .read<LittleEndian<uint32_t>>();
+}
+
+void OpMsg::appendChecksum(Message* message) {
+ if (message->operation() != dbMsg) {
+ return;
+ }
+
+ invariant(!isFlagSet(*message, kChecksumPresent));
+ setFlag(message, kChecksumPresent);
+ const size_t newSize = message->size() + kCrc32Size;
+ if (message->capacity() < newSize) {
+ message->realloc(newSize);
+ }
+
+ // Everything before the checksum, including the final size, is covered by the checksum.
+ message->header().setLen(newSize);
+ DataView(message->singleData().view2ptr() + newSize - kCrc32Size)
+ .write<LittleEndian<uint32_t>>(calculateChecksum(*message));
+}
+
OpMsg OpMsg::parse(const Message& message) try {
// It is the caller's responsibility to call the correct parser for a given message type.
invariant(!message.empty());
@@ -87,7 +125,6 @@ OpMsg OpMsg::parse(const Message& message) try {
<< std::bitset<32>(flags).to_string(),
!containsUnknownRequiredFlags(flags));
- constexpr int kCrc32Size = 4;
const bool haveChecksum = flags & kChecksumPresent;
const int checksumSize = haveChecksum ? kCrc32Size : 0;
@@ -152,6 +189,12 @@ OpMsg OpMsg::parse(const Message& message) try {
!inBody);
}
+ if (haveChecksum) {
+ uassert(ErrorCodes::ChecksumMismatch,
+ "OP_MSG checksum does not match contents",
+ OpMsg::getChecksum(message) == calculateChecksum(message));
+ }
+
return msg;
} catch (const DBException& ex) {
LOG(1) << "invalid message: " << ex.code() << " " << redact(ex) << " -- "
diff --git a/src/mongo/rpc/op_msg.h b/src/mongo/rpc/op_msg.h
index 5983ed63226..3b13b586bff 100644
--- a/src/mongo/rpc/op_msg.h
+++ b/src/mongo/rpc/op_msg.h
@@ -74,6 +74,37 @@ struct OpMsg {
}
/**
+ * Removes a flag from the list of set flags in message.
+ * Only legal on an otherwise valid OP_MSG message.
+ */
+ static void clearFlag(Message* message, uint32_t flag) {
+ replaceFlags(message, flags(*message) & ~flag);
+ }
+
+ /**
+ * Retrieves the checksum stored at the end of the message.
+ */
+ static uint32_t getChecksum(const Message& message);
+
+ /**
+ * Add a checksum at the end of the message. Call this after setting size, requestId, and
+ * responseTo. The checksumPresent flag must *not* already be set.
+ */
+ static void appendChecksum(Message* message);
+
+ /**
+ * If the checksum is present, unsets the checksumPresent flag and shrinks message by 4 bytes.
+ */
+ static void removeChecksum(Message* message) {
+ if (!isFlagSet(*message, kChecksumPresent)) {
+ return;
+ }
+
+ clearFlag(message, kChecksumPresent);
+ message->header().setLen(message->size() - 4);
+ }
+
+ /**
* Parses and returns an OpMsg containing unowned BSON.
*/
static OpMsg parse(const Message& message);
diff --git a/src/mongo/rpc/op_msg_integration_test.cpp b/src/mongo/rpc/op_msg_integration_test.cpp
index 098b9af1323..561264b5db7 100644
--- a/src/mongo/rpc/op_msg_integration_test.cpp
+++ b/src/mongo/rpc/op_msg_integration_test.cpp
@@ -196,6 +196,9 @@ TEST(OpMsg, CloseConnectionOnFireAndForgetNotMasterError) {
// rather than say() so that we get an error back when the connection is closed. Normally
// using call() if kMoreToCome set results in blocking forever.
OpMsg::setFlag(&request, OpMsg::kMoreToCome);
+ // conn.call() calculated the request checksum, but setFlag() makes it invalid. Clear the
+ // checksum so the next conn.call() recalculates it.
+ OpMsg::removeChecksum(&request);
ASSERT(!conn.call(request, reply, /*assertOK*/ false, nullptr));
uassertStatusOK(conn.connect(host, "integration_test")); // Reconnect.
@@ -224,6 +227,7 @@ TEST(OpMsg, CloseConnectionOnFireAndForgetNotMasterError) {
// Round-trip command claims to succeed due to w:0.
+ OpMsg::removeChecksum(&request);
OpMsg::replaceFlags(&request, 0);
ASSERT(conn.call(request, reply, /*assertOK*/ true, nullptr));
ASSERT_OK(getStatusFromCommandResult(
@@ -231,6 +235,7 @@ TEST(OpMsg, CloseConnectionOnFireAndForgetNotMasterError) {
// Fire-and-forget should still close connection.
OpMsg::setFlag(&request, OpMsg::kMoreToCome);
+ OpMsg::removeChecksum(&request);
ASSERT(!conn.call(request, reply, /*assertOK*/ false, nullptr));
break;
@@ -270,7 +275,19 @@ TEST(OpMsg, DocumentSequenceReturnsWork) {
<< "admin"));
}
-TEST(OpMsg, ServerHandlesExhaustCorrectly) {
+constexpr auto kDisableChecksum = "dbClientConnectionDisableChecksum";
+
+void disableClientChecksum() {
+ auto failPoint = getGlobalFailPointRegistry()->getFailPoint(kDisableChecksum);
+ failPoint->setMode(FailPoint::alwaysOn);
+}
+
+void enableClientChecksum() {
+ auto failPoint = getGlobalFailPointRegistry()->getFailPoint(kDisableChecksum);
+ failPoint->setMode(FailPoint::off);
+}
+
+void exhaustTest(bool enableChecksum) {
std::string errMsg;
auto conn = std::unique_ptr<DBClientBase>(
unittest::getFixtureConnectionString().connect("integration_test", errMsg));
@@ -281,6 +298,12 @@ TEST(OpMsg, ServerHandlesExhaustCorrectly) {
return;
}
+ if (!enableChecksum) {
+ disableClientChecksum();
+ }
+
+ ON_BLOCK_EXIT([&] { enableClientChecksum(); });
+
NamespaceString nss("test", "coll");
conn->dropCollection(nss.toString());
@@ -301,6 +324,8 @@ TEST(OpMsg, ServerHandlesExhaustCorrectly) {
const long long cursorId = res["cursor"]["id"].numberLong();
ASSERT(res["cursor"]["firstBatch"].Array().empty());
ASSERT(!OpMsg::isFlagSet(reply, OpMsg::kMoreToCome));
+ // Reply has checksum if and only if the request did.
+ ASSERT_EQ(OpMsg::isFlagSet(reply, OpMsg::kChecksumPresent), enableChecksum);
// Construct getMore request with exhaust flag. Set batch size so we will need multiple batches
// to exhaust the cursor.
@@ -314,6 +339,7 @@ TEST(OpMsg, ServerHandlesExhaustCorrectly) {
ASSERT(conn->call(request, reply));
auto lastRequestId = reply.header().getId();
ASSERT(OpMsg::isFlagSet(reply, OpMsg::kMoreToCome));
+ ASSERT_EQ(OpMsg::isFlagSet(reply, OpMsg::kChecksumPresent), enableChecksum);
res = OpMsg::parse(reply).body;
ASSERT_OK(getStatusFromCommandResult(res));
ASSERT_EQ(res["cursor"]["id"].numberLong(), cursorId);
@@ -326,6 +352,7 @@ TEST(OpMsg, ServerHandlesExhaustCorrectly) {
conn->recv(reply, lastRequestId);
lastRequestId = reply.header().getId();
ASSERT(OpMsg::isFlagSet(reply, OpMsg::kMoreToCome));
+ ASSERT_EQ(OpMsg::isFlagSet(reply, OpMsg::kChecksumPresent), enableChecksum);
res = OpMsg::parse(reply).body;
ASSERT_OK(getStatusFromCommandResult(res));
ASSERT_EQ(res["cursor"]["id"].numberLong(), cursorId);
@@ -337,6 +364,7 @@ TEST(OpMsg, ServerHandlesExhaustCorrectly) {
// Receive terminal batch.
ASSERT(conn->recv(reply, lastRequestId));
ASSERT(!OpMsg::isFlagSet(reply, OpMsg::kMoreToCome));
+ ASSERT_EQ(OpMsg::isFlagSet(reply, OpMsg::kChecksumPresent), enableChecksum);
res = OpMsg::parse(reply).body;
ASSERT_OK(getStatusFromCommandResult(res));
ASSERT_EQ(res["cursor"]["id"].numberLong(), 0);
@@ -345,6 +373,14 @@ TEST(OpMsg, ServerHandlesExhaustCorrectly) {
ASSERT_BSONOBJ_EQ(nextBatch[0].embeddedObject(), BSON("_id" << 4));
}
+TEST(OpMsg, ServerHandlesExhaustCorrectly) {
+ exhaustTest(false);
+}
+
+TEST(OpMsg, ServerHandlesExhaustCorrectlyWithChecksum) {
+ exhaustTest(true);
+}
+
TEST(OpMsg, ExhaustWithDBClientCursorBehavesCorrectly) {
// This test simply tries to verify that using the exhaust option with DBClientCursor works
// correctly. The externally visible behavior should technically be the same as a non-exhaust
@@ -400,4 +436,35 @@ TEST(OpMsg, ExhaustWithDBClientCursorBehavesCorrectly) {
ASSERT(!cursor->more());
ASSERT(cursor->isDead());
}
+
+void checksumTest(bool enableChecksum) {
+ // The server replies with a checksum if and only if the request has a checksum.
+ std::string errMsg;
+ auto conn = std::unique_ptr<DBClientBase>(
+ unittest::getFixtureConnectionString().connect("integration_test", errMsg));
+ uassert(ErrorCodes::SocketException, errMsg, conn);
+
+ if (!enableChecksum) {
+ disableClientChecksum();
+ }
+
+ ON_BLOCK_EXIT([&] { enableClientChecksum(); });
+
+ auto opMsgRequest = OpMsgRequest::fromDBAndBody("admin", BSON("ping" << 1));
+ auto request = opMsgRequest.serialize();
+
+ Message reply;
+ ASSERT(conn->call(request, reply));
+
+ auto opMsgReply = OpMsg::parse(reply);
+ ASSERT_EQ(OpMsg::isFlagSet(reply, OpMsg::kChecksumPresent), enableChecksum);
+}
+
+TEST(OpMsg, ServerRepliesWithoutChecksumToRequestWithoutChecksum) {
+ checksumTest(true);
+}
+
+TEST(OpMsg, ServerRepliesWithChecksumToRequestWithChecksum) {
+ checksumTest(true);
+}
} // namespace mongo
diff --git a/src/mongo/rpc/op_msg_test.cpp b/src/mongo/rpc/op_msg_test.cpp
index 314f6b99284..bf280768638 100644
--- a/src/mongo/rpc/op_msg_test.cpp
+++ b/src/mongo/rpc/op_msg_test.cpp
@@ -41,6 +41,7 @@
#include "mongo/unittest/unittest.h"
#include "mongo/util/hex.h"
#include "mongo/util/log.h"
+#include "third_party/wiredtiger/wiredtiger.h"
namespace mongo {
namespace {
@@ -94,14 +95,19 @@ public:
explicit Sized(T&&... args) {
buffer.skip(sizeof(int32_t));
append(args...);
- DataView(buffer.buf()).write<LittleEndian<int32_t>>(buffer.len());
+ updateSize();
}
// Adds extra to the stored size. Use this to produce illegal messages.
Sized&& addToSize(int32_t extra) && {
- DataView(buffer.buf()).write<LittleEndian<int32_t>>(buffer.len() + extra);
+ updateSize(extra);
return std::move(*this);
}
+
+protected:
+ void updateSize(int32_t extra = 0) {
+ DataView(buffer.buf()).write<LittleEndian<int32_t>>(buffer.len() + extra);
+ }
};
// A Bytes that puts the standard message header at the front.
@@ -127,7 +133,25 @@ public:
}
OpMsgBytes&& addToSize(int32_t extra) && {
- DataView(buffer.buf()).write<LittleEndian<int32_t>>(buffer.len() + extra);
+ updateSize(extra);
+ return std::move(*this);
+ }
+
+ OpMsgBytes&& appendChecksum() && {
+ // Reserve space at the end for the checksum.
+ append<uint32_t>(0);
+ updateSize();
+ // Checksum all bits except the checksum itself.
+ uint32_t checksum = wiredtiger_crc32c_func()(buffer.buf(), buffer.len() - 4);
+ // Write the checksum bits at the end.
+ auto checksumBits = DataView(buffer.buf() + buffer.len() - sizeof(checksum));
+ checksumBits.write<LittleEndian<uint32_t>>(checksum);
+ return std::move(*this);
+ }
+
+ OpMsgBytes&& appendChecksum(uint32_t checksum) && {
+ append(checksum);
+ updateSize();
return std::move(*this);
}
};
@@ -158,9 +182,6 @@ const char kDocSequenceSection = 1;
const uint32_t kNoFlags = 0;
const uint32_t kHaveChecksum = 1;
-// CRC filler value
-const uint32_t kFakeCRC = 0; // TODO will need to compute real crc when SERVER-28679 is done.
-
TEST_F(OpMsgParser, SucceedsWithJustBody) {
auto msg = OpMsgBytes{
kNoFlags, //
@@ -172,13 +193,12 @@ TEST_F(OpMsgParser, SucceedsWithJustBody) {
ASSERT_EQ(msg.sequences.size(), 0u);
}
-TEST_F(OpMsgParser, IgnoresCrcIfPresent) { // Until SERVER-28679 is done.
- auto msg = OpMsgBytes{
- kHaveChecksum, //
- kBodySection,
- fromjson("{ping: 1}"),
- kFakeCRC, // If not ignored, this would be read as a second body.
- }.parse();
+TEST_F(OpMsgParser, SucceedsWithChecksum) {
+ auto msg = OpMsgBytes{kHaveChecksum, //
+ kBodySection,
+ fromjson("{ping: 1}")}
+ .appendChecksum()
+ .parse();
ASSERT_BSONOBJ_EQ(msg.body, fromjson("{ping: 1}"));
ASSERT_EQ(msg.sequences.size(), 0u);
@@ -422,12 +442,13 @@ TEST_F(OpMsgParser, FailsIfBodyTooBig) {
}
TEST_F(OpMsgParser, FailsIfBodyTooBigIntoChecksum) {
- auto msg = OpMsgBytes{
- kHaveChecksum, //
- kBodySection,
- fromjson("{ping: 1}"),
- kFakeCRC,
- }.addToSize(-1); // Shrink message so body extends past end.
+ auto msg =
+ OpMsgBytes{
+ kHaveChecksum, //
+ kBodySection,
+ fromjson("{ping: 1}"),
+ }.appendChecksum()
+ .addToSize(-1); // Shrink message so body extends past end.
ASSERT_THROWS_CODE(msg.parse(), AssertionException, ErrorCodes::InvalidBSON);
}
@@ -449,19 +470,19 @@ TEST_F(OpMsgParser, FailsIfDocumentSequenceTooBig) {
}
TEST_F(OpMsgParser, FailsIfDocumentSequenceTooBigIntoChecksum) {
- auto msg = OpMsgBytes{
- kHaveChecksum, //
- kBodySection,
- fromjson("{ping: 1}"),
-
- kDocSequenceSection,
- Sized{
- "docs", //
- fromjson("{a: 1}"),
- },
+ auto msg =
+ OpMsgBytes{
+ kHaveChecksum, //
+ kBodySection,
+ fromjson("{ping: 1}"),
- kFakeCRC,
- }.addToSize(-1); // Shrink message so body extends past end.
+ kDocSequenceSection,
+ Sized{
+ "docs", //
+ fromjson("{a: 1}"),
+ },
+ }.appendChecksum()
+ .addToSize(-1); // Shrink message so body extends past end.
ASSERT_THROWS_CODE(msg.parse(), AssertionException, ErrorCodes::Overflow);
}
@@ -594,6 +615,18 @@ TEST_F(OpMsgParser, SucceedsWithUnknownOptionalFlags) {
}
}
+TEST_F(OpMsgParser, FailsWithChecksumMismatch) {
+ auto msg = OpMsgBytes{kHaveChecksum, //
+ kBodySection,
+ fromjson("{ping: 1}")}
+ .appendChecksum(123);
+
+ ASSERT_THROWS_WITH_CHECK(msg.parse(), AssertionException, [](const DBException& ex) {
+ ASSERT_EQ(ex.toStatus().code(), ErrorCodes::ChecksumMismatch);
+ ASSERT(ErrorCodes::isConnectionFatalMessageParseError(ex.toStatus().code()));
+ });
+}
+
void testSerializer(const Message& fromSerializer, OpMsgBytes&& expected) {
const auto expectedMsg = expected.done();
ASSERT_EQ(fromSerializer.operation(), dbMsg);
@@ -860,5 +893,19 @@ TEST(OpMsgRequest, FromDbAndBodyDoesNotCopy) {
ASSERT_BSONOBJ_EQ(msg.body, fromjson("{ping: 1, $db: 'db'}"));
ASSERT_EQ(static_cast<const void*>(msg.body.objdata()), bodyPtr);
}
+
+TEST(OpMsgTest, ChecksumResizesMessage) {
+ auto msg = OpMsgBytes{kNoFlags, //
+ kBodySection,
+ fromjson("{ping: 1}")}
+ .done();
+
+ // Test that appendChecksum() resizes the buffer if necessary.
+ const auto capacity = msg.sharedBuffer().capacity();
+ OpMsg::appendChecksum(&msg);
+ ASSERT_EQ(msg.sharedBuffer().capacity(), capacity + 4);
+ // The checksum is correct.
+ OpMsg::parse(msg);
+}
} // namespace
} // namespace mongo
diff --git a/src/mongo/rpc/reply_builder_test.cpp b/src/mongo/rpc/reply_builder_test.cpp
index 5bb4ff13320..fcdb036c650 100644
--- a/src/mongo/rpc/reply_builder_test.cpp
+++ b/src/mongo/rpc/reply_builder_test.cpp
@@ -151,6 +151,9 @@ TEST(OpMsgReplyBuilder, CommandError) {
replyBuilder.setCommandReply(status, extraObj);
replyBuilder.getBodyBuilder().appendElements(metadata);
auto msg = replyBuilder.done();
+ msg.header().setId(124);
+ msg.header().setResponseToMsgId(123);
+ OpMsg::appendChecksum(&msg);
rpc::OpMsgReply parsed(&msg);
@@ -172,7 +175,9 @@ void testRoundTrip(rpc::ReplyBuilderInterface& replyBuilder, bool unifiedBodyAnd
replyBuilder.getBodyBuilder().appendElements(metadata);
auto msg = replyBuilder.done();
-
+ msg.header().setId(124);
+ msg.header().setResponseToMsgId(123);
+ OpMsg::appendChecksum(&msg);
T parsed(&msg);
if (unifiedBodyAndMetadata) {
@@ -197,7 +202,10 @@ void testErrors(rpc::ReplyBuilderInterface& replyBuilder) {
replyBuilder.setCommandReply(status);
replyBuilder.getBodyBuilder().appendElements(buildMetadata());
- const auto msg = replyBuilder.done();
+ auto msg = replyBuilder.done();
+ msg.header().setId(124);
+ msg.header().setResponseToMsgId(123);
+ OpMsg::appendChecksum(&msg);
T parsed(&msg);
const Status result = getStatusFromCommandResult(parsed.getCommandReply());
diff --git a/src/mongo/tools/bridge.cpp b/src/mongo/tools/bridge.cpp
index 4cce5d997b8..b3908037c4b 100644
--- a/src/mongo/tools/bridge.cpp
+++ b/src/mongo/tools/bridge.cpp
@@ -383,6 +383,10 @@ DbResponse ServiceEntryPointBridge::handleRequest(OperationContext* opCtx, const
} else {
dest.setExhaust(false);
}
+
+ // The original checksum won't be valid once the network layer replaces requestId. Remove it
+ // because the network layer re-checksums the response.
+ OpMsg::removeChecksum(&response);
return {std::move(response), exhaustNS};
} else {
return {Message()};
diff --git a/src/mongo/transport/service_state_machine.cpp b/src/mongo/transport/service_state_machine.cpp
index ebcad102106..8e751d9327c 100644
--- a/src/mongo/transport/service_state_machine.cpp
+++ b/src/mongo/transport/service_state_machine.cpp
@@ -140,14 +140,23 @@ Message makeExhaustMessage(Message requestMsg, DbResponse* dbresponse) {
return Message();
}
- // Indicate that the response is part of an exhaust stream.
+ const bool checksumPresent = OpMsg::isFlagSet(requestMsg, OpMsg::kChecksumPresent);
+ OpMsg::removeChecksum(&dbresponse->response);
+ // Indicate that the response is part of an exhaust stream. Re-checksum if needed.
OpMsg::setFlag(&dbresponse->response, OpMsg::kMoreToCome);
+ if (checksumPresent) {
+ OpMsg::appendChecksum(&dbresponse->response);
+ }
// Return an augmented form of the initial request, which is to be used as the next request to
// be processed by the database. The id of the response is used as the request id of this
- // 'synthetic' request.
+ // 'synthetic' request. Re-checksum if needed.
+ OpMsg::removeChecksum(&requestMsg);
requestMsg.header().setId(dbresponse->response.header().getId());
requestMsg.header().setResponseToMsgId(dbresponse->response.header().getResponseToMsgId());
+ if (checksumPresent) {
+ OpMsg::appendChecksum(&requestMsg);
+ }
return requestMsg;
}
@@ -454,10 +463,14 @@ void ServiceStateMachine::_processMessage(ThreadGuard guard) {
Message& toSink = dbresponse.response;
if (!toSink.empty()) {
invariant(!OpMsg::isFlagSet(_inMessage, OpMsg::kMoreToCome));
+ invariant(!OpMsg::isFlagSet(toSink, OpMsg::kChecksumPresent));
// Update the header for the response message.
toSink.header().setId(nextMessageId());
toSink.header().setResponseToMsgId(_inMessage.header().getId());
+ if (OpMsg::isFlagSet(_inMessage, OpMsg::kChecksumPresent)) {
+ OpMsg::appendChecksum(&toSink);
+ }
// If the incoming message has the exhaust flag set and is a 'getMore' command, then we
// bypass the normal RPC behavior. We will sink the response to the network, but we also
diff --git a/src/mongo/transport/transport_layer_asio_test.cpp b/src/mongo/transport/transport_layer_asio_test.cpp
index 7bbaab7cd59..45552516ee0 100644
--- a/src/mongo/transport/transport_layer_asio_test.cpp
+++ b/src/mongo/transport/transport_layer_asio_test.cpp
@@ -259,6 +259,7 @@ public:
Message msg = builder.finish();
msg.header().setResponseToMsgId(0);
msg.header().setId(0);
+ OpMsg::appendChecksum(&msg);
std::error_code ec;
asio::write(_sock, asio::buffer(msg.buf(), msg.size()), ec);
diff --git a/src/mongo/util/shared_buffer.h b/src/mongo/util/shared_buffer.h
index e18db9b7e3c..b0bc09c769b 100644
--- a/src/mongo/util/shared_buffer.h
+++ b/src/mongo/util/shared_buffer.h
@@ -29,6 +29,8 @@
#pragma once
+#include <algorithm>
+
#include <boost/intrusive_ptr.hpp>
#include "mongo/platform/atomic_word.h"
@@ -73,6 +75,23 @@ public:
_holder = std::move(tmp._holder);
}
+ /**
+ * Resizes the buffer, copying the current contents. If shared, an exclusive copy is made.
+ */
+ void reallocOrCopy(size_t size) {
+ if (isShared()) {
+ auto tmp = SharedBuffer::allocate(size);
+ memcpy(tmp._holder->data(),
+ _holder->data(),
+ std::min(size, static_cast<size_t>(_holder->_capacity)));
+ swap(tmp);
+ } else if (_holder) {
+ realloc(size);
+ } else {
+ *this = SharedBuffer::allocate(size);
+ }
+ }
+
char* get() const {
return _holder ? _holder->data() : NULL;
}