diff options
Diffstat (limited to 'src/mongo/rpc')
26 files changed, 2159 insertions, 26 deletions
diff --git a/src/mongo/rpc/SConscript b/src/mongo/rpc/SConscript index c3b364020a7..637c62ffaa7 100644 --- a/src/mongo/rpc/SConscript +++ b/src/mongo/rpc/SConscript @@ -24,14 +24,19 @@ env.Library( 'protocol', ], source=[ + 'message.cpp', + 'op_msg.cpp', 'protocol.cpp', ], LIBDEPS=[ '$BUILD_DIR/mongo/base', - '$BUILD_DIR/mongo/bson/util/bson_extract', - '$BUILD_DIR/mongo/db/service_context', '$BUILD_DIR/mongo/db/wire_version', ], + LIBDEPS_PRIVATE=[ + '$BUILD_DIR/mongo/bson/util/bson_extract', + '$BUILD_DIR/mongo/db/bson/dotted_path_support', + '$BUILD_DIR/mongo/db/server_options_core', + ], ) env.Library( @@ -180,6 +185,7 @@ env.CppUnitTest( 'get_status_from_command_result_test.cpp', 'legacy_request_test.cpp', 'object_check_test.cpp', + 'op_msg_test.cpp', 'protocol_test.cpp', 'reply_builder_test.cpp', ], @@ -238,3 +244,15 @@ env.CppUnitTest( 'client_metadata', ] ) + +env.CppIntegrationTest( + target='op_msg_integration_test', + source=[ + 'op_msg_integration_test.cpp', + ], + LIBDEPS=[ + 'protocol', + '$BUILD_DIR/mongo/client/clientdriver', + '$BUILD_DIR/mongo/util/version_impl', + ], +) diff --git a/src/mongo/rpc/command_reply.cpp b/src/mongo/rpc/command_reply.cpp index 884a63fae61..79de0cc5506 100644 --- a/src/mongo/rpc/command_reply.cpp +++ b/src/mongo/rpc/command_reply.cpp @@ -36,8 +36,8 @@ #include "mongo/base/data_range_cursor.h" #include "mongo/base/data_type_validated.h" #include "mongo/bson/simple_bsonobj_comparator.h" +#include "mongo/rpc/message.h" #include "mongo/rpc/object_check.h" -#include "mongo/util/net/message.h" namespace mongo { namespace rpc { diff --git a/src/mongo/rpc/command_reply_builder.cpp b/src/mongo/rpc/command_reply_builder.cpp index 001522163ce..534ad779f86 100644 --- a/src/mongo/rpc/command_reply_builder.cpp +++ b/src/mongo/rpc/command_reply_builder.cpp @@ -32,10 +32,10 @@ #include <utility> +#include "mongo/rpc/message.h" #include "mongo/stdx/memory.h" #include "mongo/util/assert_util.h" #include "mongo/util/mongoutils/str.h" -#include "mongo/util/net/message.h" namespace mongo { namespace rpc { diff --git a/src/mongo/rpc/command_reply_builder.h b/src/mongo/rpc/command_reply_builder.h index b990ca3d8fa..cce27f90b01 100644 --- a/src/mongo/rpc/command_reply_builder.h +++ b/src/mongo/rpc/command_reply_builder.h @@ -32,9 +32,9 @@ #include "mongo/base/status.h" #include "mongo/db/jsobj.h" +#include "mongo/rpc/message.h" #include "mongo/rpc/protocol.h" #include "mongo/rpc/reply_builder_interface.h" -#include "mongo/util/net/message.h" namespace mongo { namespace rpc { diff --git a/src/mongo/rpc/command_reply_test.cpp b/src/mongo/rpc/command_reply_test.cpp index 77bf5ba59a1..e207c497b2b 100644 --- a/src/mongo/rpc/command_reply_test.cpp +++ b/src/mongo/rpc/command_reply_test.cpp @@ -37,9 +37,9 @@ #include "mongo/base/data_view.h" #include "mongo/db/jsobj.h" #include "mongo/rpc/command_reply.h" +#include "mongo/rpc/message.h" #include "mongo/stdx/memory.h" #include "mongo/unittest/unittest.h" -#include "mongo/util/net/message.h" namespace { diff --git a/src/mongo/rpc/command_request.cpp b/src/mongo/rpc/command_request.cpp index 98e85f02f5f..bce2323bfe5 100644 --- a/src/mongo/rpc/command_request.cpp +++ b/src/mongo/rpc/command_request.cpp @@ -40,10 +40,10 @@ #include "mongo/bson/simple_bsonobj_comparator.h" #include "mongo/client/read_preference.h" #include "mongo/db/jsobj.h" +#include "mongo/rpc/message.h" #include "mongo/rpc/object_check.h" #include "mongo/util/assert_util.h" #include "mongo/util/mongoutils/str.h" -#include "mongo/util/net/message.h" namespace mongo { namespace rpc { diff --git a/src/mongo/rpc/command_request.h b/src/mongo/rpc/command_request.h index 07cefccdb0e..d2163185422 100644 --- a/src/mongo/rpc/command_request.h +++ b/src/mongo/rpc/command_request.h @@ -29,8 +29,8 @@ #pragma once #include "mongo/db/jsobj.h" -#include "mongo/util/net/message.h" -#include "mongo/util/net/op_msg.h" +#include "mongo/rpc/message.h" +#include "mongo/rpc/op_msg.h" namespace mongo { namespace rpc { diff --git a/src/mongo/rpc/command_request_builder.h b/src/mongo/rpc/command_request_builder.h index 0e3ae41fbca..e890dd978b1 100644 --- a/src/mongo/rpc/command_request_builder.h +++ b/src/mongo/rpc/command_request_builder.h @@ -28,8 +28,8 @@ #pragma once -#include "mongo/util/net/message.h" -#include "mongo/util/net/op_msg.h" +#include "mongo/rpc/message.h" +#include "mongo/rpc/op_msg.h" namespace mongo { namespace rpc { diff --git a/src/mongo/rpc/command_request_test.cpp b/src/mongo/rpc/command_request_test.cpp index aad6dd1d3ff..5902766c2be 100644 --- a/src/mongo/rpc/command_request_test.cpp +++ b/src/mongo/rpc/command_request_test.cpp @@ -35,9 +35,9 @@ #include "mongo/db/jsobj.h" #include "mongo/rpc/command_request.h" #include "mongo/rpc/command_request_builder.h" +#include "mongo/rpc/message.h" #include "mongo/unittest/unittest.h" #include "mongo/util/assert_util.h" -#include "mongo/util/net/message.h" namespace { diff --git a/src/mongo/rpc/factory.cpp b/src/mongo/rpc/factory.cpp index 6ef513c0284..94bd71f6c39 100644 --- a/src/mongo/rpc/factory.cpp +++ b/src/mongo/rpc/factory.cpp @@ -38,12 +38,12 @@ #include "mongo/rpc/legacy_reply_builder.h" #include "mongo/rpc/legacy_request.h" #include "mongo/rpc/legacy_request_builder.h" +#include "mongo/rpc/message.h" #include "mongo/rpc/op_msg_rpc_impls.h" #include "mongo/rpc/protocol.h" #include "mongo/stdx/memory.h" #include "mongo/util/assert_util.h" #include "mongo/util/mongoutils/str.h" -#include "mongo/util/net/message.h" namespace mongo { namespace rpc { diff --git a/src/mongo/rpc/factory.h b/src/mongo/rpc/factory.h index fb3f7c0b853..724809ed771 100644 --- a/src/mongo/rpc/factory.h +++ b/src/mongo/rpc/factory.h @@ -28,8 +28,8 @@ #pragma once +#include "mongo/rpc/op_msg.h" #include "mongo/rpc/protocol.h" -#include "mongo/util/net/op_msg.h" #include <memory> diff --git a/src/mongo/rpc/legacy_reply_builder.h b/src/mongo/rpc/legacy_reply_builder.h index fba41aabdfe..22f40092df6 100644 --- a/src/mongo/rpc/legacy_reply_builder.h +++ b/src/mongo/rpc/legacy_reply_builder.h @@ -32,9 +32,9 @@ #include "mongo/base/status.h" #include "mongo/bson/util/builder.h" +#include "mongo/rpc/message.h" #include "mongo/rpc/protocol.h" #include "mongo/rpc/reply_builder_interface.h" -#include "mongo/util/net/message.h" namespace mongo { namespace rpc { diff --git a/src/mongo/rpc/legacy_request.h b/src/mongo/rpc/legacy_request.h index 81c0bb9d94c..a5435ae09be 100644 --- a/src/mongo/rpc/legacy_request.h +++ b/src/mongo/rpc/legacy_request.h @@ -28,8 +28,8 @@ #pragma once -#include "mongo/util/net/message.h" -#include "mongo/util/net/op_msg.h" +#include "mongo/rpc/message.h" +#include "mongo/rpc/op_msg.h" namespace mongo { namespace rpc { diff --git a/src/mongo/rpc/legacy_request_builder.cpp b/src/mongo/rpc/legacy_request_builder.cpp index 80ace2b994e..d7978ef65a7 100644 --- a/src/mongo/rpc/legacy_request_builder.cpp +++ b/src/mongo/rpc/legacy_request_builder.cpp @@ -36,10 +36,10 @@ #include "mongo/client/dbclientinterface.h" #include "mongo/client/read_preference.h" #include "mongo/db/namespace_string.h" +#include "mongo/rpc/message.h" #include "mongo/rpc/metadata.h" #include "mongo/stdx/memory.h" #include "mongo/util/assert_util.h" -#include "mongo/util/net/message.h" namespace mongo { namespace rpc { diff --git a/src/mongo/rpc/legacy_request_builder.h b/src/mongo/rpc/legacy_request_builder.h index 31c8c9b0f77..37a4057d6c2 100644 --- a/src/mongo/rpc/legacy_request_builder.h +++ b/src/mongo/rpc/legacy_request_builder.h @@ -28,8 +28,8 @@ #pragma once -#include "mongo/util/net/message.h" -#include "mongo/util/net/op_msg.h" +#include "mongo/rpc/message.h" +#include "mongo/rpc/op_msg.h" namespace mongo { namespace rpc { diff --git a/src/mongo/rpc/message.cpp b/src/mongo/rpc/message.cpp new file mode 100644 index 00000000000..f5a7822bd6b --- /dev/null +++ b/src/mongo/rpc/message.cpp @@ -0,0 +1,45 @@ +/** + * Copyright (C) 2017 MongoDB Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, version 3, + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the GNU Affero General Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/platform/basic.h" + +#include "mongo/rpc/message.h" + +#include "mongo/platform/atomic_word.h" + +namespace mongo { + +namespace { +AtomicWord<int32_t> NextMsgId; +} // namespace + +int32_t nextMessageId() { + return NextMsgId.fetchAndAdd(1); +} + +} // namespace mongo diff --git a/src/mongo/rpc/message.h b/src/mongo/rpc/message.h new file mode 100644 index 00000000000..a2fe2ec8c9d --- /dev/null +++ b/src/mongo/rpc/message.h @@ -0,0 +1,478 @@ +/** + * Copyright (C) 2017 MongoDB Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, version 3, + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the GNU Affero General Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include <cstdint> + +#include "mongo/base/data_type_endian.h" +#include "mongo/base/data_view.h" +#include "mongo/base/encoded_value_storage.h" +#include "mongo/base/static_assert.h" +#include "mongo/util/mongoutils/str.h" + +namespace mongo { + +/** + * Maximum accepted message size on the wire protocol. + */ +const size_t MaxMessageSizeBytes = 48 * 1000 * 1000; + +enum NetworkOp : int32_t { + opInvalid = 0, + opReply = 1, /* reply. responseTo is set. */ + dbUpdate = 2001, /* update object */ + dbInsert = 2002, + // dbGetByOID = 2003, + dbQuery = 2004, + dbGetMore = 2005, + dbDelete = 2006, + dbKillCursors = 2007, + // dbCommand_DEPRECATED = 2008, // + // dbCommandReply_DEPRECATED = 2009, // + dbCommand = 2010, + dbCommandReply = 2011, + dbCompressed = 2012, + dbMsg = 2013, +}; + +inline bool isSupportedRequestNetworkOp(NetworkOp op) { + switch (op) { + case dbUpdate: + case dbInsert: + case dbQuery: + case dbGetMore: + case dbDelete: + case dbKillCursors: + case dbCommand: + case dbCompressed: + case dbMsg: + return true; + case dbCommandReply: + case opReply: + default: + return false; + } +} + +enum class LogicalOp { + opInvalid, + opUpdate, + opInsert, + opQuery, + opGetMore, + opDelete, + opKillCursors, + opCommand, + opCompressed, +}; + +inline LogicalOp networkOpToLogicalOp(NetworkOp networkOp) { + switch (networkOp) { + case dbUpdate: + return LogicalOp::opUpdate; + case dbInsert: + return LogicalOp::opInsert; + case dbQuery: + return LogicalOp::opQuery; + case dbGetMore: + return LogicalOp::opGetMore; + case dbDelete: + return LogicalOp::opDelete; + case dbKillCursors: + return LogicalOp::opKillCursors; + case dbMsg: + case dbCommand: + return LogicalOp::opCommand; + case dbCompressed: + return LogicalOp::opCompressed; + default: + int op = int(networkOp); + massert(34348, str::stream() << "cannot translate opcode " << op, !op); + return LogicalOp::opInvalid; + } +} + +inline const char* networkOpToString(NetworkOp networkOp) { + switch (networkOp) { + case opInvalid: + return "none"; + case opReply: + return "reply"; + case dbUpdate: + return "update"; + case dbInsert: + return "insert"; + case dbQuery: + return "query"; + case dbGetMore: + return "getmore"; + case dbDelete: + return "remove"; + case dbKillCursors: + return "killcursors"; + case dbCommand: + return "command"; + case dbCommandReply: + return "commandReply"; + case dbCompressed: + return "compressed"; + case dbMsg: + return "msg"; + default: + int op = static_cast<int>(networkOp); + massert(16141, str::stream() << "cannot translate opcode " << op, !op); + return ""; + } +} + +inline const char* logicalOpToString(LogicalOp logicalOp) { + switch (logicalOp) { + case LogicalOp::opInvalid: + return "none"; + case LogicalOp::opUpdate: + return "update"; + case LogicalOp::opInsert: + return "insert"; + case LogicalOp::opQuery: + return "query"; + case LogicalOp::opGetMore: + return "getmore"; + case LogicalOp::opDelete: + return "remove"; + case LogicalOp::opKillCursors: + return "killcursors"; + case LogicalOp::opCommand: + return "command"; + case LogicalOp::opCompressed: + return "compressed"; + default: + MONGO_UNREACHABLE; + } +} + +namespace MSGHEADER { + +#pragma pack(1) +/** + * See http://dochub.mongodb.org/core/mongowireprotocol + */ +struct Layout { + int32_t messageLength; // total message size, including this + int32_t requestID; // identifier for this message + int32_t responseTo; // requestID from the original request + // (used in responses from db) + int32_t opCode; +}; +#pragma pack() + +class ConstView { +public: + typedef ConstDataView view_type; + + ConstView(const char* data) : _data(data) {} + + const char* view2ptr() const { + return data().view(); + } + + int32_t getMessageLength() const { + return data().read<LittleEndian<int32_t>>(offsetof(Layout, messageLength)); + } + + int32_t getRequestMsgId() const { + return data().read<LittleEndian<int32_t>>(offsetof(Layout, requestID)); + } + + int32_t getResponseToMsgId() const { + return data().read<LittleEndian<int32_t>>(offsetof(Layout, responseTo)); + } + + int32_t getOpCode() const { + return data().read<LittleEndian<int32_t>>(offsetof(Layout, opCode)); + } + +protected: + const view_type& data() const { + return _data; + } + +private: + view_type _data; +}; + +class View : public ConstView { +public: + typedef DataView view_type; + + View(char* data) : ConstView(data) {} + + using ConstView::view2ptr; + char* view2ptr() { + return data().view(); + } + + void setMessageLength(int32_t value) { + data().write(tagLittleEndian(value), offsetof(Layout, messageLength)); + } + + void setRequestMsgId(int32_t value) { + data().write(tagLittleEndian(value), offsetof(Layout, requestID)); + } + + void setResponseToMsgId(int32_t value) { + data().write(tagLittleEndian(value), offsetof(Layout, responseTo)); + } + + void setOpCode(int32_t value) { + data().write(tagLittleEndian(value), offsetof(Layout, opCode)); + } + +private: + view_type data() const { + return const_cast<char*>(ConstView::view2ptr()); + } +}; + +class Value : public EncodedValueStorage<Layout, ConstView, View> { +public: + Value() { + MONGO_STATIC_ASSERT(sizeof(Value) == sizeof(Layout)); + } + + Value(ZeroInitTag_t zit) : EncodedValueStorage<Layout, ConstView, View>(zit) {} +}; + +} // namespace MSGHEADER + +namespace MsgData { + +#pragma pack(1) +struct Layout { + MSGHEADER::Layout header; + char data[4]; +}; +#pragma pack() + +class ConstView { +public: + ConstView(const char* storage) : _storage(storage) {} + + const char* view2ptr() const { + return storage().view(); + } + + int32_t getLen() const { + return header().getMessageLength(); + } + + int32_t getId() const { + return header().getRequestMsgId(); + } + + int32_t getResponseToMsgId() const { + return header().getResponseToMsgId(); + } + + NetworkOp getNetworkOp() const { + return NetworkOp(header().getOpCode()); + } + + const char* data() const { + return storage().view(offsetof(Layout, data)); + } + + bool valid() const { + if (getLen() <= 0 || getLen() > (4 * BSONObjMaxInternalSize)) + return false; + if (getNetworkOp() < 0 || getNetworkOp() > 30000) + return false; + return true; + } + + int64_t getCursor() const { + verify(getResponseToMsgId() > 0); + verify(getNetworkOp() == opReply); + return ConstDataView(data() + sizeof(int32_t)).read<LittleEndian<int64_t>>(); + } + + int dataLen() const; // len without header + +protected: + const ConstDataView& storage() const { + return _storage; + } + + MSGHEADER::ConstView header() const { + return storage().view(offsetof(Layout, header)); + } + +private: + ConstDataView _storage; +}; + +class View : public ConstView { +public: + View(char* storage) : ConstView(storage) {} + + using ConstView::view2ptr; + char* view2ptr() { + return storage().view(); + } + + void setLen(int value) { + return header().setMessageLength(value); + } + + void setId(int32_t value) { + return header().setRequestMsgId(value); + } + + void setResponseToMsgId(int32_t value) { + return header().setResponseToMsgId(value); + } + + void setOperation(int value) { + return header().setOpCode(value); + } + + using ConstView::data; + char* data() { + return storage().view(offsetof(Layout, data)); + } + +private: + DataView storage() const { + return const_cast<char*>(ConstView::view2ptr()); + } + + MSGHEADER::View header() const { + return storage().view(offsetof(Layout, header)); + } +}; + +class Value : public EncodedValueStorage<Layout, ConstView, View> { +public: + Value() { + MONGO_STATIC_ASSERT(sizeof(Value) == sizeof(Layout)); + } + + Value(ZeroInitTag_t zit) : EncodedValueStorage<Layout, ConstView, View>(zit) {} +}; + +const int MsgDataHeaderSize = sizeof(Value) - 4; + +inline int ConstView::dataLen() const { + return getLen() - MsgDataHeaderSize; +} + +} // namespace MsgData + +class Message { +public: + Message() = default; + explicit Message(SharedBuffer data) : _buf(std::move(data)) {} + + MsgData::View header() const { + verify(!empty()); + return _buf.get(); + } + + NetworkOp operation() const { + return header().getNetworkOp(); + } + + MsgData::View singleData() const { + massert(13273, "single data buffer expected", _buf); + return header(); + } + + bool empty() const { + return !_buf; + } + + int size() const { + if (_buf) { + return MsgData::ConstView(_buf.get()).getLen(); + } + return 0; + } + + int dataSize() const { + return size() - sizeof(MSGHEADER::Value); + } + + void reset() { + _buf = {}; + } + + // use to set first buffer if empty + void setData(SharedBuffer buf) { + verify(empty()); + _buf = std::move(buf); + } + void setData(int operation, const char* msgtxt) { + setData(operation, msgtxt, strlen(msgtxt) + 1); + } + void setData(int operation, const char* msgdata, size_t len) { + verify(empty()); + size_t dataLen = len + sizeof(MsgData::Value) - 4; + _buf = SharedBuffer::allocate(dataLen); + MsgData::View d = _buf.get(); + if (len) + memcpy(d.data(), msgdata, len); + d.setLen(dataLen); + d.setOperation(operation); + } + + char* buf() { + return _buf.get(); + } + + const char* buf() const { + return _buf.get(); + } + + SharedBuffer sharedBuffer() { + return _buf; + } + + ConstSharedBuffer sharedBuffer() const { + return _buf; + } + +private: + SharedBuffer _buf; +}; + +/** + * Returns an always incrementing value to be used to assign to the next received network message. + */ +int32_t nextMessageId(); + +} // namespace mongo diff --git a/src/mongo/rpc/metadata.h b/src/mongo/rpc/metadata.h index 461927f32a7..b2b3ff13a3e 100644 --- a/src/mongo/rpc/metadata.h +++ b/src/mongo/rpc/metadata.h @@ -31,8 +31,8 @@ #include <tuple> #include "mongo/base/status_with.h" +#include "mongo/rpc/op_msg.h" #include "mongo/stdx/functional.h" -#include "mongo/util/net/op_msg.h" namespace mongo { class BSONObj; diff --git a/src/mongo/rpc/op_msg.cpp b/src/mongo/rpc/op_msg.cpp new file mode 100644 index 00000000000..f886027e192 --- /dev/null +++ b/src/mongo/rpc/op_msg.cpp @@ -0,0 +1,250 @@ +/** + * Copyright (C) 2017 MongoDB Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, version 3, + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the GNU Affero General Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#define MONGO_LOG_DEFAULT_COMPONENT ::mongo::logger::LogComponent::kNetwork + +#include "mongo/platform/basic.h" + +#include "mongo/rpc/op_msg.h" + +#include <bitset> +#include <set> + +#include "mongo/base/data_type_endian.h" +#include "mongo/db/bson/dotted_path_support.h" +#include "mongo/rpc/object_check.h" +#include "mongo/util/bufreader.h" +#include "mongo/util/hex.h" +#include "mongo/util/log.h" + +namespace mongo { +namespace { + +auto kAllSupportedFlags = OpMsg::kChecksumPresent | OpMsg::kMoreToCome; + +bool containsUnknownRequiredFlags(uint32_t flags) { + const uint32_t kRequiredFlagMask = 0xffff; // Low 2 bytes are required, high 2 are optional. + return (flags & ~kAllSupportedFlags & kRequiredFlagMask) != 0; +} + +enum class Section : uint8_t { + kBody = 0, + kDocSequence = 1, +}; + +} // namespace + +uint32_t OpMsg::flags(const Message& message) { + if (message.operation() != dbMsg) + return 0; // Other command protocols are the same as no flags set. + + return BufReader(message.singleData().data(), message.dataSize()) + .read<LittleEndian<uint32_t>>(); +} + +void OpMsg::replaceFlags(Message* message, uint32_t flags) { + invariant(!message->empty()); + invariant(message->operation() == dbMsg); + invariant(message->dataSize() >= static_cast<int>(sizeof(uint32_t))); + + DataView(message->singleData().data()).write<LittleEndian<uint32_t>>(flags); +} + +OpMsg OpMsg::parse(const Message& message) try { + // It is the caller's responsibility to call the correct parser for a given message type. + invariant(!message.empty()); + invariant(message.operation() == dbMsg); + + const uint32_t flags = OpMsg::flags(message); + uassert(ErrorCodes::IllegalOpMsgFlag, + str::stream() << "Message contains illegal flags value: Ob" + << std::bitset<32>(flags).to_string(), + !containsUnknownRequiredFlags(flags)); + + constexpr int kCrc32Size = 4; + const bool haveChecksum = flags & kChecksumPresent; + const int checksumSize = haveChecksum ? kCrc32Size : 0; + + // The sections begin after the flags and before the checksum (if present). + BufReader sectionsBuf(message.singleData().data() + sizeof(flags), + message.dataSize() - sizeof(flags) - checksumSize); + + // TODO some validation may make more sense in the IDL parser. I've tagged them with comments. + bool haveBody = false; + OpMsg msg; + while (!sectionsBuf.atEof()) { + const auto sectionKind = sectionsBuf.read<Section>(); + switch (sectionKind) { + case Section::kBody: { + uassert(40430, "Multiple body sections in message", !haveBody); + haveBody = true; + msg.body = sectionsBuf.read<Validated<BSONObj>>(); + break; + } + + case Section::kDocSequence: { + // We use an O(N^2) algorithm here and an O(N*M) algorithm below. These are fastest + // for the current small values of N, but would be problematic if it is large. + // If we need more document sequences, raise the limit and use a better algorithm. + uassert(ErrorCodes::TooManyDocumentSequences, + "Too many document sequences in OP_MSG", + msg.sequences.size() < 2); // Limit is <=2 since we are about to add one. + + // The first 4 bytes are the total size, including themselves. + const auto remainingSize = + sectionsBuf.read<LittleEndian<int32_t>>() - sizeof(int32_t); + BufReader seqBuf(sectionsBuf.skip(remainingSize), remainingSize); + const auto name = seqBuf.readCStr(); + uassert(40431, + str::stream() << "Duplicate document sequence: " << name, + !msg.getSequence(name)); // TODO IDL + + msg.sequences.push_back({name.toString()}); + while (!seqBuf.atEof()) { + msg.sequences.back().objs.push_back(seqBuf.read<Validated<BSONObj>>()); + } + break; + } + + default: + // Using uint32_t so we append as a decimal number rather than as a char. + uasserted(40432, str::stream() << "Unknown section kind " << uint32_t(sectionKind)); + } + } + + uassert(40587, "OP_MSG messages must have a body", haveBody); + + // Detect duplicates between doc sequences and body. TODO IDL + // Technically this is O(N*M) but N is at most 2. + for (const auto& docSeq : msg.sequences) { + const char* name = docSeq.name.c_str(); // Pointer is redirected by next call. + auto inBody = + !dotted_path_support::extractElementAtPathOrArrayAlongPath(msg.body, name).eoo(); + uassert(40433, + str::stream() << "Duplicate field between body and document sequence " + << docSeq.name, + !inBody); + } + + return msg; +} catch (const DBException& ex) { + LOG(1) << "invalid message: " << ex.code() << " " << redact(ex) << " -- " + << redact(hexdump(message.singleData().view2ptr(), message.size())); + throw; +} + +Message OpMsg::serialize() const { + OpMsgBuilder builder; + for (auto&& seq : sequences) { + auto docSeq = builder.beginDocSequence(seq.name); + for (auto&& obj : seq.objs) { + docSeq.append(obj); + } + } + builder.beginBody().appendElements(body); + return builder.finish(); +} + +void OpMsg::shareOwnershipWith(const ConstSharedBuffer& buffer) { + if (!body.isOwned()) { + body.shareOwnershipWith(buffer); + } + for (auto&& seq : sequences) { + for (auto&& obj : seq.objs) { + if (!obj.isOwned()) { + obj.shareOwnershipWith(buffer); + } + } + } +} + +auto OpMsgBuilder::beginDocSequence(StringData name) -> DocSequenceBuilder { + invariant(_state == kEmpty || _state == kDocSequence); + invariant(!_openBuilder); + _openBuilder = true; + _state = kDocSequence; + _buf.appendStruct(Section::kDocSequence); + int sizeOffset = _buf.len(); + _buf.skip(sizeof(int32_t)); // section size. + _buf.appendStr(name, true); + return DocSequenceBuilder(this, &_buf, sizeOffset); +} + +void OpMsgBuilder::finishDocumentStream(DocSequenceBuilder* docSequenceBuilder) { + invariant(_state == kDocSequence); + invariant(_openBuilder); + _openBuilder = false; + const int32_t size = _buf.len() - docSequenceBuilder->_sizeOffset; + invariant(size > 0); + DataView(_buf.buf()).write<LittleEndian<int32_t>>(size, docSequenceBuilder->_sizeOffset); +} + +BSONObjBuilder OpMsgBuilder::beginBody() { + invariant(_state == kEmpty || _state == kDocSequence); + _state = kBody; + _buf.appendStruct(Section::kBody); + invariant(_bodyStart == 0); + _bodyStart = _buf.len(); // Cannot be 0. + return BSONObjBuilder(_buf); +} + +BSONObjBuilder OpMsgBuilder::resumeBody() { + invariant(_state == kBody); + invariant(_bodyStart != 0); + return BSONObjBuilder(BSONObjBuilder::ResumeBuildingTag(), _buf, _bodyStart); +} + +AtomicBool OpMsgBuilder::disableDupeFieldCheck_forTest{false}; + +Message OpMsgBuilder::finish() { + if (kDebugBuild && !disableDupeFieldCheck_forTest.load()) { + std::set<StringData> seenFields; + for (auto elem : resumeBody().asTempObj()) { + if (!(seenFields.insert(elem.fieldNameStringData()).second)) { + severe() << "OP_MSG with duplicate field '" << elem.fieldNameStringData() + << "' : " << redact(resumeBody().asTempObj()); + fassert(40474, false); + } + } + } + + invariant(_state == kBody); + invariant(_bodyStart); + invariant(!_openBuilder); + _state = kDone; + + const auto size = _buf.len(); + MSGHEADER::View header(_buf.buf()); + header.setMessageLength(size); + // header.setRequestMsgId(...); // These are currently filled in by the networking layer. + // header.setResponseToMsgId(...); + header.setOpCode(dbMsg); + return Message(_buf.release()); +} + +} // namespace mongo diff --git a/src/mongo/rpc/op_msg.h b/src/mongo/rpc/op_msg.h new file mode 100644 index 00000000000..ed7f1993d86 --- /dev/null +++ b/src/mongo/rpc/op_msg.h @@ -0,0 +1,313 @@ +/** + * Copyright (C) 2017 MongoDB Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, version 3, + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the GNU Affero General Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include <algorithm> +#include <string> +#include <vector> + +#include "mongo/base/string_data.h" +#include "mongo/db/jsobj.h" +#include "mongo/rpc/message.h" + +namespace mongo { + +struct OpMsg { + struct DocumentSequence { + std::string name; + std::vector<BSONObj> objs; + }; + + static constexpr uint32_t kChecksumPresent = 1 << 0; + static constexpr uint32_t kMoreToCome = 1 << 1; + + /** + * Returns the unvalidated flags for the given message if it is an OP_MSG message. + * Returns 0 for other message kinds since they are the equivalent of no flags set. + * Throws if the message is too small to hold flags. + */ + static uint32_t flags(const Message& message); + static bool isFlagSet(const Message& message, uint32_t flag) { + return flags(message) & flag; + } + + /** + * Replaces the flags in message with the supplied flags. + * Only legal on an otherwise valid OP_MSG message. + */ + static void replaceFlags(Message* message, uint32_t flags); + + /** + * Adds flag to the list of set flags in message. + * Only legal on an otherwise valid OP_MSG message. + */ + static void setFlag(Message* message, uint32_t flag) { + replaceFlags(message, flags(*message) | flag); + } + + /** + * Parses and returns an OpMsg containing unowned BSON. + */ + static OpMsg parse(const Message& message); + + /** + * Parses and returns an OpMsg containing owned BSON. + */ + static OpMsg parseOwned(const Message& message) { + auto msg = parse(message); + msg.shareOwnershipWith(message.sharedBuffer()); + return msg; + } + + Message serialize() const; + + /** + * Makes all BSONObjs in this object share ownership with buffer. + */ + void shareOwnershipWith(const ConstSharedBuffer& buffer); + + /** + * Returns a pointer to the sequence with the given name or nullptr if there are none. + */ + const DocumentSequence* getSequence(StringData name) const { + // Getting N sequences is technically O(N**2) but because there currently is at most 2 + // sequences, this does either 1 or 2 comparisons. Consider making sequences a StringMap if + // there will be many sequences. This problem may also just go away with the IDL project. + auto it = std::find_if( + sequences.begin(), sequences.end(), [&](const auto& seq) { return seq.name == name; }); + return it == sequences.end() ? nullptr : &*it; + } + + BSONObj body; + std::vector<DocumentSequence> sequences; +}; + +/** + * An OpMsg that represents a request. This is a separate type from OpMsg only to provide better + * type-safety along with a place to hang request-specific methods. + */ +struct OpMsgRequest : public OpMsg { + // TODO in C++17 remove constructors so we can use aggregate initialization. + OpMsgRequest() = default; + explicit OpMsgRequest(OpMsg&& generic) : OpMsg(std::move(generic)) {} + + static OpMsgRequest parse(const Message& message) { + return OpMsgRequest(OpMsg::parse(message)); + } + + static OpMsgRequest fromDBAndBody(StringData db, + BSONObj body, + const BSONObj& extraFields = {}) { + OpMsgRequest request; + request.body = ([&] { + BSONObjBuilder bodyBuilder(std::move(body)); + bodyBuilder.appendElements(extraFields); + bodyBuilder.append("$db", db); + return bodyBuilder.obj(); + }()); + return request; + } + + StringData getDatabase() const { + if (auto elem = body["$db"]) + return elem.checkAndGetStringData(); + uasserted(40571, "OP_MSG requests require a $db argument"); + } + + StringData getCommandName() const { + return body.firstElementFieldName(); + } + + // DO NOT ADD MEMBERS! Since this type is essentially a strong typedef (see the class comment), + // it should not hold more data than an OpMsg. It should be freely interconvertible with OpMsg + // without issues like slicing. +}; + +/** + * Builds an OP_MSG message in-place in a Message buffer. + * + * While the OP_MSG format imposes no ordering of sections, in order to efficiently support our + * usage patterns, this class requires that all document sequences (if any) are built before the + * body. This allows repeatedly appending fields to the body until right before it is ready to be + * sent. + */ +class OpMsgBuilder { + MONGO_DISALLOW_COPYING(OpMsgBuilder); + +public: + OpMsgBuilder() { + skipHeaderAndFlags(); + } + + /** + * See the documentation for DocSequenceBuilder below. + */ + class DocSequenceBuilder; + DocSequenceBuilder beginDocSequence(StringData name); + + /** + * Returns an empty builder for the body. + * It is an error to call this if a body has already been begun. You must destroy or call + * done() on the returned builder before calling any methods on this object. + */ + BSONObjBuilder beginBody(); + void setBody(const BSONObj& body) { + beginBody().appendElements(body); + } + + /** + * Returns a builder that can be used to append new fields to the body. + * It is an error to call this if beginBody() hasn't been called yet. It is an error to append + * elements with field names that already exist in the body. You must destroy or call done() on + * the returned builder before calling any methods on this object. + * + * TODO decide if it is worth keeping the begin/resume distinction in the public API. + */ + BSONObjBuilder resumeBody(); + void appendElementsToBody(const BSONObj& body) { + resumeBody().appendElements(body); + } + + /** + * Finish building and return a Message ready to give to the networking layer for transmission. + * It is illegal to call any methods on this object after calling this. + */ + Message finish(); + + /** + * Reset this object to its initial empty state. All previously appended data is lost. + */ + void reset() { + invariant(!_openBuilder); + + _buf.reset(); + skipHeaderAndFlags(); + _bodyStart = 0; + _state = kEmpty; + _openBuilder = false; + } + + /** + * Set to true in tests that need to be able to generate duplicate top-level fields to see how + * the server handles them. Is false by default, although the check only happens in debug + * builds. + */ + static AtomicBool disableDupeFieldCheck_forTest; + +private: + friend class DocSequenceBuilder; + + enum State { + kEmpty, + kDocSequence, + kBody, + kDone, + }; + + void finishDocumentStream(DocSequenceBuilder* docSequenceBuilder); + + void skipHeaderAndFlags() { + _buf.skip(sizeof(MSGHEADER::Layout)); // This is filled in by finish(). + _buf.appendNum(uint32_t(0)); // flags (currently always 0). + } + + // When adding members, remember to update reset(). + BufBuilder _buf; + int _bodyStart = 0; + State _state = kEmpty; + bool _openBuilder = false; +}; + +/** + * Builds a document sequence in an OpMsgBuilder. + * + * Example: + * + * auto docSeq = msgBuilder.beginDocSequence("some.sequence"); + * + * docSeq.append(BSON("a" << 1)); // Copy an obj into the sequence + * + * auto bob = docSeq.appendBuilder(); // Build an obj in-place + * bob.append("a", 2); + * bob.doneFast(); + * + * docSeq.done(); // Or just let it go out of scope. + */ +class OpMsgBuilder::DocSequenceBuilder { + MONGO_DISALLOW_COPYING(DocSequenceBuilder); + +public: + DocSequenceBuilder(DocSequenceBuilder&& other) + : _buf(other._buf), _msgBuilder(other._msgBuilder), _sizeOffset(other._sizeOffset) { + other._buf = nullptr; + } + + ~DocSequenceBuilder() { + if (_buf) + done(); + } + + /** + * Indicates that the caller is done with this stream prior to destruction. + * Following this call, it is illegal to call any methods on this object. + */ + void done() { + invariant(_buf); + _msgBuilder->finishDocumentStream(this); + _buf = nullptr; + } + + /** + * Appends a single document to this sequence. + */ + void append(const BSONObj& obj) { + _buf->appendBuf(obj.objdata(), obj.objsize()); + } + + /** + * Returns a BSONObjBuilder that appends a single document to this sequence in place. + * It is illegal to call any methods on this DocSequenceBuilder until the returned builder + * is destroyed or done()/doneFast() is called on it. + */ + BSONObjBuilder appendBuilder() { + return BSONObjBuilder(*_buf); + } + +private: + friend OpMsgBuilder; + + DocSequenceBuilder(OpMsgBuilder* msgBuilder, BufBuilder* buf, int sizeOffset) + : _buf(buf), _msgBuilder(msgBuilder), _sizeOffset(sizeOffset) {} + + BufBuilder* _buf; + OpMsgBuilder* const _msgBuilder; + const int _sizeOffset; +}; + +} // namespace mongo diff --git a/src/mongo/rpc/op_msg_integration_test.cpp b/src/mongo/rpc/op_msg_integration_test.cpp new file mode 100644 index 00000000000..6dd9fec7f2f --- /dev/null +++ b/src/mongo/rpc/op_msg_integration_test.cpp @@ -0,0 +1,168 @@ +/** + * Copyright (C) 2017 MongoDB Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, version 3, + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the GNU Affero General Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/platform/basic.h" + +#include "mongo/client/dbclientinterface.h" +#include "mongo/rpc/get_status_from_command_result.h" +#include "mongo/rpc/op_msg.h" +#include "mongo/unittest/integration_test.h" +#include "mongo/unittest/unittest.h" +#include "mongo/util/scopeguard.h" + +namespace mongo { + +TEST(OpMsg, UnknownRequiredFlagClosesConnection) { + std::string errMsg; + auto conn = std::unique_ptr<DBClientBase>( + unittest::getFixtureConnectionString().connect("integration_test", errMsg)); + uassert(ErrorCodes::SocketException, errMsg, conn); + + auto request = OpMsgRequest::fromDBAndBody("admin", BSON("ping" << 1)).serialize(); + OpMsg::setFlag(&request, 1u << 15); // This should be the last required flag to be assigned. + + Message reply; + ASSERT(!conn->call(request, reply, /*assertOK*/ false)); +} + +TEST(OpMsg, UnknownOptionalFlagIsIgnored) { + std::string errMsg; + auto conn = std::unique_ptr<DBClientBase>( + unittest::getFixtureConnectionString().connect("integration_test", errMsg)); + uassert(ErrorCodes::SocketException, errMsg, conn); + + auto request = OpMsgRequest::fromDBAndBody("admin", BSON("ping" << 1)).serialize(); + OpMsg::setFlag(&request, 1u << 31); // This should be the last optional flag to be assigned. + + Message reply; + ASSERT(conn->call(request, reply)); + uassertStatusOK(getStatusFromCommandResult( + conn->parseCommandReplyMessage(conn->getServerAddress(), reply)->getCommandReply())); +} + +TEST(OpMsg, FireAndForgetInsertWorks) { + std::string errMsg; + auto conn = std::unique_ptr<DBClientBase>( + unittest::getFixtureConnectionString().connect("integration_test", errMsg)); + uassert(ErrorCodes::SocketException, errMsg, conn); + + conn->dropCollection("test.collection"); + + conn->runFireAndForgetCommand(OpMsgRequest::fromDBAndBody("test", fromjson(R"({ + insert: "collection", + writeConcern: {w: 0}, + documents: [ + {a: 1} + ] + })"))); + + ASSERT_EQ(conn->count("test.collection"), 1u); +} + +TEST(OpMsg, CloseConnectionOnFireAndForgetNotMasterError) { + const auto connStr = unittest::getFixtureConnectionString(); + + // This test only works against a replica set. + if (connStr.type() != ConnectionString::SET) { + return; + } + + bool foundSecondary = false; + for (auto host : connStr.getServers()) { + DBClientConnection conn; + uassertStatusOK(conn.connect(host, "integration_test")); + bool isMaster; + ASSERT(conn.isMaster(isMaster)); + if (isMaster) + continue; + foundSecondary = true; + + auto request = OpMsgRequest::fromDBAndBody("test", fromjson(R"({ + insert: "collection", + writeConcern: {w: 0}, + documents: [ + {a: 1} + ] + })")).serialize(); + + // Round-trip command fails with NotMaster error. Note that this failure is in command + // dispatch which ignores w:0. + Message reply; + ASSERT(conn.call(request, reply, /*assertOK*/ true, nullptr)); + ASSERT_EQ( + getStatusFromCommandResult( + conn.parseCommandReplyMessage(conn.getServerAddress(), reply)->getCommandReply()), + ErrorCodes::NotMaster); + + // Fire-and-forget closes connection when it sees that error. Note that this is using call() + // rather than say() so that we get an error back when the connection is closed. Normally + // using call() if kMoreToCome set results in blocking forever. + OpMsg::setFlag(&request, OpMsg::kMoreToCome); + ASSERT(!conn.call(request, reply, /*assertOK*/ false, nullptr)); + + uassertStatusOK(conn.connect(host, "integration_test")); // Reconnect. + + // Disable eager checking of master to simulate a stepdown occurring after the check. This + // should respect w:0. + BSONObj output; + ASSERT(conn.runCommand("admin", + fromjson(R"({ + configureFailPoint: 'skipCheckingForNotMasterInCommandDispatch', + mode: 'alwaysOn' + })"), + output)) + << output; + ON_BLOCK_EXIT([&] { + uassertStatusOK(conn.connect(host, "integration_test-cleanup")); + ASSERT(conn.runCommand("admin", + fromjson(R"({ + configureFailPoint: + 'skipCheckingForNotMasterInCommandDispatch', + mode: 'off' + })"), + output)) + << output; + }); + + + // Round-trip command claims to succeed due to w:0. + OpMsg::replaceFlags(&request, 0); + ASSERT(conn.call(request, reply, /*assertOK*/ true, nullptr)); + ASSERT_OK(getStatusFromCommandResult( + conn.parseCommandReplyMessage(conn.getServerAddress(), reply)->getCommandReply())); + + // Fire-and-forget should still close connection. + OpMsg::setFlag(&request, OpMsg::kMoreToCome); + ASSERT(!conn.call(request, reply, /*assertOK*/ false, nullptr)); + + break; + } + ASSERT(foundSecondary); +} + +} // namespace mongo diff --git a/src/mongo/rpc/op_msg_rpc_impls.h b/src/mongo/rpc/op_msg_rpc_impls.h index 67149810d04..4725456630f 100644 --- a/src/mongo/rpc/op_msg_rpc_impls.h +++ b/src/mongo/rpc/op_msg_rpc_impls.h @@ -28,9 +28,9 @@ #pragma once +#include "mongo/rpc/op_msg.h" #include "mongo/rpc/reply_builder_interface.h" #include "mongo/rpc/reply_interface.h" -#include "mongo/util/net/op_msg.h" namespace mongo { namespace rpc { diff --git a/src/mongo/rpc/op_msg_test.cpp b/src/mongo/rpc/op_msg_test.cpp new file mode 100644 index 00000000000..29df95c9136 --- /dev/null +++ b/src/mongo/rpc/op_msg_test.cpp @@ -0,0 +1,863 @@ +/** + * Copyright (C) 2017 MongoDB Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, version 3, + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the GNU Affero General Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#define MONGO_LOG_DEFAULT_COMPONENT ::mongo::logger::LogComponent::kDefault + +#include "mongo/platform/basic.h" + +#include <type_traits> + +#include "mongo/base/static_assert.h" +#include "mongo/bson/json.h" +#include "mongo/bson/util/builder.h" +#include "mongo/db/jsobj.h" +#include "mongo/rpc/op_msg.h" +#include "mongo/unittest/unittest.h" +#include "mongo/util/hex.h" +#include "mongo/util/log.h" + +namespace mongo { +namespace { + +// Makes a SharedBuffer out of arguments passed to constructor. +class Bytes { +public: + template <typename... T> + explicit Bytes(T&&... args) { + append(args...); + } + +protected: + void append() {} // no-op base case + + template <typename T, typename... Rest> + std::enable_if_t<std::is_integral<T>::value> append(T arg, Rest&&... rest) { + // Make sure BufBuilder has a real overload of this exact type and it isn't implicitly + // converted. + (void)static_cast<void (BufBuilder::*)(T)>(&BufBuilder::appendNum); + + buffer.appendNum(arg); // automatically little endian. + append(rest...); + } + + template <typename... Rest> + void append(const BSONObj& arg, Rest&&... rest) { + arg.appendSelfToBufBuilder(buffer); + append(rest...); + } + + template <typename... Rest> + void append(const Bytes& arg, Rest&&... rest) { + buffer.appendBuf(arg.buffer.buf(), arg.buffer.len()); + append(rest...); + } + + template <typename... Rest> + void append(StringData arg, Rest&&... rest) { + buffer.appendStr(arg, /* null terminate*/ true); + append(rest...); + } + + BufBuilder buffer; +}; + +// A Bytes that puts the size of the buffer at the front as a little-endian int32 +class Sized : public Bytes { +public: + template <typename... T> + explicit Sized(T&&... args) { + buffer.skip(sizeof(int32_t)); + append(args...); + DataView(buffer.buf()).write<LittleEndian<int32_t>>(buffer.len()); + } + + // Adds extra to the stored size. Use this to produce illegal messages. + Sized&& addToSize(int32_t extra) && { + DataView(buffer.buf()).write<LittleEndian<int32_t>>(buffer.len() + extra); + return std::move(*this); + } +}; + +// A Bytes that puts the standard message header at the front. +class OpMsgBytes : public Sized { +public: + template <typename... T> + explicit OpMsgBytes(T&&... args) + : Sized{int32_t{1}, // requestId + int32_t{2}, // replyId + int32_t{dbMsg}, // opCode + args...} {} + + Message done() { + const auto orig = Message(buffer.release()); + // Copy the message to an exact-sized allocation so ASAN can detect out-of-bounds accesses. + auto copy = SharedBuffer::allocate(orig.size()); + memcpy(copy.get(), orig.buf(), orig.size()); + return Message(std::move(copy)); + } + + OpMsg parse() { + return OpMsg::parseOwned(done()); + } + + OpMsgBytes&& addToSize(int32_t extra) && { + DataView(buffer.buf()).write<LittleEndian<int32_t>>(buffer.len() + extra); + return std::move(*this); + } +}; + +// Fixture class to raise log verbosity so that invalid messages are printed by the parser. +class OpMsgParser : public unittest::Test { +public: + void setUp() override { + _original = + logger::globalLogDomain()->getMinimumLogSeverity(logger::LogComponent::kNetwork); + logger::globalLogDomain()->setMinimumLoggedSeverity(logger::LogComponent::kNetwork, + logger::LogSeverity::Debug(1)); + } + void tearDown() override { + logger::globalLogDomain()->setMinimumLoggedSeverity(logger::LogComponent::kNetwork, + _original); + } + +private: + logger::LogSeverity _original = logger::LogSeverity::Debug(0); +}; + +// Section bytes +const char kBodySection = 0; +const char kDocSequenceSection = 1; + +// Flags +const uint32_t kNoFlags = 0; +const uint32_t kHaveChecksum = 1; + +// CRC filler value +const uint32_t kFakeCRC = 0; // TODO will need to compute real crc when SERVER-28679 is done. + +TEST_F(OpMsgParser, SucceedsWithJustBody) { + auto msg = OpMsgBytes{ + kNoFlags, // + kBodySection, + fromjson("{ping: 1}"), + }.parse(); + + ASSERT_BSONOBJ_EQ(msg.body, fromjson("{ping: 1}")); + ASSERT_EQ(msg.sequences.size(), 0u); +} + +TEST_F(OpMsgParser, IgnoresCrcIfPresent) { // Until SERVER-28679 is done. + auto msg = OpMsgBytes{ + kHaveChecksum, // + kBodySection, + fromjson("{ping: 1}"), + kFakeCRC, // If not ignored, this would be read as a second body. + }.parse(); + + ASSERT_BSONOBJ_EQ(msg.body, fromjson("{ping: 1}")); + ASSERT_EQ(msg.sequences.size(), 0u); +} + +TEST_F(OpMsgParser, SucceedsWithBodyThenSequence) { + auto msg = OpMsgBytes{ + kNoFlags, // + kBodySection, + fromjson("{ping: 1}"), + + kDocSequenceSection, + Sized{ + "docs", // + fromjson("{a: 1}"), + fromjson("{a: 2}"), + }, + }.parse(); + + ASSERT_BSONOBJ_EQ(msg.body, fromjson("{ping: 1}")); + ASSERT_EQ(msg.sequences.size(), 1u); + ASSERT_EQ(msg.sequences[0].name, "docs"); + ASSERT_EQ(msg.sequences[0].objs.size(), 2u); + ASSERT_BSONOBJ_EQ(msg.sequences[0].objs[0], fromjson("{a: 1}")); + ASSERT_BSONOBJ_EQ(msg.sequences[0].objs[1], fromjson("{a: 2}")); +} + +TEST_F(OpMsgParser, SucceedsWithSequenceThenBody) { + auto msg = OpMsgBytes{ + kNoFlags, // + kDocSequenceSection, + Sized{ + "docs", // + fromjson("{a: 1}"), + }, + + kBodySection, + fromjson("{ping: 1}"), + }.parse(); + + ASSERT_BSONOBJ_EQ(msg.body, fromjson("{ping: 1}")); + ASSERT_EQ(msg.sequences.size(), 1u); + ASSERT_EQ(msg.sequences[0].name, "docs"); + ASSERT_EQ(msg.sequences[0].objs.size(), 1u); + ASSERT_BSONOBJ_EQ(msg.sequences[0].objs[0], fromjson("{a: 1}")); +} + +TEST_F(OpMsgParser, SucceedsWithSequenceThenBodyThenSequence) { + auto msg = OpMsgBytes{ + kNoFlags, // + kDocSequenceSection, + Sized{ + "empty", // + }, + + kBodySection, + fromjson("{ping: 1}"), + + kDocSequenceSection, + Sized{ + "docs", // + fromjson("{a: 1}"), + }, + }.parse(); + + ASSERT_BSONOBJ_EQ(msg.body, fromjson("{ping: 1}")); + ASSERT_EQ(msg.sequences.size(), 2u); + ASSERT_EQ(msg.sequences[0].name, "empty"); + ASSERT_EQ(msg.sequences[0].objs.size(), 0u); + ASSERT_EQ(msg.sequences[1].name, "docs"); + ASSERT_EQ(msg.sequences[1].objs.size(), 1u); + ASSERT_BSONOBJ_EQ(msg.sequences[1].objs[0], fromjson("{a: 1}")); +} + +TEST_F(OpMsgParser, SucceedsWithSequenceThenSequenceThenBody) { + auto msg = OpMsgBytes{ + kNoFlags, // + kDocSequenceSection, + Sized{ + "empty", // + }, + + kDocSequenceSection, + Sized{ + "docs", // + fromjson("{a: 1}"), + }, + + kBodySection, + fromjson("{ping: 1}"), + }.parse(); + + ASSERT_BSONOBJ_EQ(msg.body, fromjson("{ping: 1}")); + ASSERT_EQ(msg.sequences.size(), 2u); + ASSERT_EQ(msg.sequences[0].name, "empty"); + ASSERT_EQ(msg.sequences[0].objs.size(), 0u); + ASSERT_EQ(msg.sequences[1].name, "docs"); + ASSERT_EQ(msg.sequences[1].objs.size(), 1u); + ASSERT_BSONOBJ_EQ(msg.sequences[1].objs[0], fromjson("{a: 1}")); +} + +TEST_F(OpMsgParser, SucceedsWithBodyThenSequenceThenSequence) { + auto msg = OpMsgBytes{ + kNoFlags, // + kBodySection, + fromjson("{ping: 1}"), + + kDocSequenceSection, + Sized{ + "docs", // + fromjson("{a: 1}"), + }, + + kDocSequenceSection, + Sized{ + "empty", // + }, + }.parse(); + + ASSERT_BSONOBJ_EQ(msg.body, fromjson("{ping: 1}")); + ASSERT_EQ(msg.sequences.size(), 2u); + ASSERT_EQ(msg.sequences[0].name, "docs"); + ASSERT_EQ(msg.sequences[0].objs.size(), 1u); + ASSERT_BSONOBJ_EQ(msg.sequences[0].objs[0], fromjson("{a: 1}")); + ASSERT_EQ(msg.sequences[1].name, "empty"); + ASSERT_EQ(msg.sequences[1].objs.size(), 0u); +} + +TEST_F(OpMsgParser, FailsIfNoBody) { + auto msg = OpMsgBytes{ + kNoFlags, // + }; + + ASSERT_THROWS_CODE(msg.parse(), AssertionException, 40587); +} + +TEST_F(OpMsgParser, FailsIfNoBodyEvenWithSequence) { + auto msg = OpMsgBytes{ + kNoFlags, // + + kDocSequenceSection, + Sized{"docs"}, + }; + + ASSERT_THROWS_CODE(msg.parse(), AssertionException, 40587); +} + +TEST_F(OpMsgParser, FailsIfTwoBodies) { + auto msg = OpMsgBytes{ + kNoFlags, // + kBodySection, + fromjson("{ping: 1}"), + + kBodySection, + fromjson("{pong: 1}"), + }; + + ASSERT_THROWS_CODE(msg.parse(), AssertionException, 40430); +} + +TEST_F(OpMsgParser, FailsIfDuplicateSequences) { + auto msg = OpMsgBytes{ + kNoFlags, // + kBodySection, + fromjson("{ping: 1}"), + + kDocSequenceSection, + Sized{"docs"}, + + kDocSequenceSection, + Sized{"docs"}, + }; + + ASSERT_THROWS_CODE(msg.parse(), AssertionException, 40431); +} + +TEST_F(OpMsgParser, FailsIfDuplicateSequenceWithBody) { + auto msg = OpMsgBytes{ + kNoFlags, // + kBodySection, + fromjson("{ping: 1, 'docs': []}"), + + kDocSequenceSection, + Sized{"docs"}, + }; + + ASSERT_THROWS_CODE(msg.parse(), AssertionException, 40433); +} + +TEST_F(OpMsgParser, FailsIfDuplicateSequenceWithBodyNested) { + auto msg = OpMsgBytes{ + kNoFlags, // + kBodySection, + fromjson("{ping: 1, a: {b:[]}}"), + + kDocSequenceSection, + Sized{"a.b"}, + }; + + ASSERT_THROWS_CODE(msg.parse(), AssertionException, 40433); +} + +TEST_F(OpMsgParser, SucceedsIfSequenceAndBodyHaveCommonPrefix) { + auto msg = OpMsgBytes{ + kNoFlags, // + kBodySection, + fromjson("{cursor: {ns: 'foo.bar', id: 1}}"), + + kDocSequenceSection, + Sized{ + "cursor.firstBatch", // + fromjson("{_id: 1}"), + }, + }.parse(); + + ASSERT_BSONOBJ_EQ(msg.body, fromjson("{cursor: {ns: 'foo.bar', id: 1}}")); + ASSERT_EQ(msg.sequences.size(), 1u); + ASSERT_EQ(msg.sequences[0].name, "cursor.firstBatch"); + ASSERT_EQ(msg.sequences[0].objs.size(), 1u); + ASSERT_BSONOBJ_EQ(msg.sequences[0].objs[0], fromjson("{_id: 1}")); +} + +TEST_F(OpMsgParser, FailsIfUnknownSectionKind) { + auto msg = OpMsgBytes{ + kNoFlags, // + '\x99', // This is where a section kind would go + Sized{}, + }; + + ASSERT_THROWS_CODE(msg.parse(), AssertionException, 40432); +} + +TEST_F(OpMsgParser, FailsIfBodyTooBig) { + auto msg = OpMsgBytes{ + kNoFlags, // + kBodySection, + fromjson("{ping: 1}"), + }.addToSize(-1); // Shrink message so body extends past end. + + ASSERT_THROWS_CODE(msg.parse(), AssertionException, ErrorCodes::InvalidBSON); +} + +TEST_F(OpMsgParser, FailsIfBodyTooBigIntoChecksum) { + auto msg = OpMsgBytes{ + kHaveChecksum, // + kBodySection, + fromjson("{ping: 1}"), + kFakeCRC, + }.addToSize(-1); // Shrink message so body extends past end. + + ASSERT_THROWS_CODE(msg.parse(), AssertionException, ErrorCodes::InvalidBSON); +} + +TEST_F(OpMsgParser, FailsIfDocumentSequenceTooBig) { + auto msg = OpMsgBytes{ + kNoFlags, // + kBodySection, + fromjson("{ping: 1}"), + + kDocSequenceSection, + Sized{ + "docs", // + fromjson("{a: 1}"), + }, + }.addToSize(-1); // Shrink message so body extends past end. + + ASSERT_THROWS_CODE(msg.parse(), AssertionException, ErrorCodes::Overflow); +} + +TEST_F(OpMsgParser, FailsIfDocumentSequenceTooBigIntoChecksum) { + auto msg = OpMsgBytes{ + kHaveChecksum, // + kBodySection, + fromjson("{ping: 1}"), + + kDocSequenceSection, + Sized{ + "docs", // + fromjson("{a: 1}"), + }, + + kFakeCRC, + }.addToSize(-1); // Shrink message so body extends past end. + + ASSERT_THROWS_CODE(msg.parse(), AssertionException, ErrorCodes::Overflow); +} + +TEST_F(OpMsgParser, FailsIfDocumentInSequenceTooBig) { + auto msg = OpMsgBytes{ + kNoFlags, // + kBodySection, + fromjson("{ping: 1}"), + + kDocSequenceSection, + Sized{ + "docs", // + fromjson("{a: 1}"), + }.addToSize(-1), // Shrink sequence so document extends past end. + }; + + ASSERT_THROWS_CODE(msg.parse(), AssertionException, ErrorCodes::InvalidBSON); +} + +TEST_F(OpMsgParser, FailsIfNameOfDocumentSequenceTooBig) { + auto msg = OpMsgBytes{ + kNoFlags, // + kBodySection, + fromjson("{ping: 1}"), + + kDocSequenceSection, + Sized{ + "foo", + }.addToSize(-1), // Shrink sequence so document extends past end. + }; + + ASSERT_THROWS_CODE(msg.parse(), AssertionException, ErrorCodes::Overflow); +} + +TEST_F(OpMsgParser, FailsIfNameOfDocumentSequenceHasNoNulTerminator) { + auto msg = OpMsgBytes{ + kNoFlags, // + kBodySection, + fromjson("{ping: 1}"), + + kDocSequenceSection, + Sized{'f', 'o', 'o'}, + // No '\0' at end of document. ASAN should complain if we keep looking for one. + }; + + ASSERT_THROWS_CODE(msg.parse(), AssertionException, ErrorCodes::Overflow); +} + +TEST_F(OpMsgParser, FailsIfTooManyDocumentSequences) { + auto msg = OpMsgBytes{ + kNoFlags, // + kBodySection, + fromjson("{ping: 1}"), + + kDocSequenceSection, + Sized{"foo"}, + + kDocSequenceSection, + Sized{"bar"}, + + kDocSequenceSection, + Sized{"baz"}, + }; + + ASSERT_THROWS_WITH_CHECK( + msg.parse(), ExceptionFor<ErrorCodes::TooManyDocumentSequences>, [](const DBException& ex) { + ASSERT(ex.isA<ErrorCategory::ConnectionFatalMessageParseError>()); + }); +} + +TEST_F(OpMsgParser, FailsIfNoRoomForFlags) { + // Flags are 4 bytes. Try 0-3 bytes. + ASSERT_THROWS_CODE(OpMsgBytes{}.parse(), AssertionException, ErrorCodes::Overflow); + ASSERT_THROWS_CODE(OpMsgBytes{'\0'}.parse(), AssertionException, ErrorCodes::Overflow); + ASSERT_THROWS_CODE((OpMsgBytes{'\0', '\0'}.parse()), AssertionException, ErrorCodes::Overflow); + ASSERT_THROWS_CODE( + (OpMsgBytes{'\0', '\0', '\0'}.parse()), AssertionException, ErrorCodes::Overflow); + + ASSERT_THROWS_CODE(OpMsg::flags(OpMsgBytes{}.done()), AssertionException, ErrorCodes::Overflow); + ASSERT_THROWS_CODE( + OpMsg::flags(OpMsgBytes{'\0'}.done()), AssertionException, ErrorCodes::Overflow); + ASSERT_THROWS_CODE( + OpMsg::flags(OpMsgBytes{'\0', '\0'}.done()), AssertionException, ErrorCodes::Overflow); + ASSERT_THROWS_CODE(OpMsg::flags(OpMsgBytes{'\0', '\0', '\0'}.done()), + AssertionException, + ErrorCodes::Overflow); +} + +TEST_F(OpMsgParser, FlagExtractionWorks) { + ASSERT_EQ(OpMsg::flags(OpMsgBytes{0u}.done()), 0u); // All clear. + ASSERT_EQ(OpMsg::flags(OpMsgBytes{~0u}.done()), ~0u); // All set. + + for (auto i = uint32_t(0); i < 32u; i++) { + const auto flags = uint32_t(1) << i; + ASSERT_EQ(OpMsg::flags(OpMsgBytes{flags}.done()), flags) << flags; + ASSERT(OpMsg::isFlagSet(OpMsgBytes{flags}.done(), flags)) << flags; + ASSERT(!OpMsg::isFlagSet(OpMsgBytes{~flags}.done(), flags)) << flags; + ASSERT(!OpMsg::isFlagSet(OpMsgBytes{0u}.done(), flags)) << flags; + ASSERT(OpMsg::isFlagSet(OpMsgBytes{~0u}.done(), flags)) << flags; + } +} + +TEST_F(OpMsgParser, FailsWithUnknownRequiredFlags) { + // Bits 0 and 1 are known, and bits >= 16 are optional. + for (auto i = uint32_t(2); i < 16u; i++) { + auto flags = uint32_t(1) << i; + auto msg = OpMsgBytes{ + flags, // + kBodySection, + fromjson("{ping: 1}"), + }; + + ASSERT_THROWS_WITH_CHECK(msg.parse(), AssertionException, [](const DBException& ex) { + ASSERT_EQ(ex.toStatus().code(), ErrorCodes::IllegalOpMsgFlag); + ASSERT(ErrorCodes::isConnectionFatalMessageParseError(ex.toStatus().code())); + }); + } +} + +TEST_F(OpMsgParser, SucceedsWithUnknownOptionalFlags) { + // bits >= 16 are optional. + for (auto i = uint32_t(16); i < 32u; i++) { + auto flags = uint32_t(1) << i; + OpMsgBytes{ + flags, // + kBodySection, + fromjson("{ping: 1}"), + }.parse(); + } +} + +void testSerializer(const Message& fromSerializer, OpMsgBytes&& expected) { + const auto expectedMsg = expected.done(); + ASSERT_EQ(fromSerializer.operation(), dbMsg); + // Ignoring request and reply ids since they aren't handled by OP_MSG code. + + auto gotSD = StringData(fromSerializer.singleData().data(), fromSerializer.dataSize()); + auto expectedSD = StringData(expectedMsg.singleData().data(), expectedMsg.dataSize()); + if (gotSD == expectedSD) + return; + + size_t commonLength = + std::mismatch(gotSD.begin(), gotSD.end(), expectedSD.begin(), expectedSD.end()).first - + gotSD.begin(); + + log() << "Mismatch after " << commonLength << " bytes."; + log() << "Common prefix: " << hexdump(gotSD.rawData(), commonLength); + log() << "Got suffix : " + << hexdump(gotSD.rawData() + commonLength, gotSD.size() - commonLength); + log() << "Expected suffix: " + << hexdump(expectedSD.rawData() + commonLength, expectedSD.size() - commonLength); + FAIL("Serialization didn't match expected data. See above for details."); +} + +TEST(OpMsgSerializer, JustBody) { + OpMsg msg; + msg.body = fromjson("{ping: 1}"); + + testSerializer(msg.serialize(), + OpMsgBytes{ + kNoFlags, // + kBodySection, + fromjson("{ping: 1}"), + }); +} + +TEST(OpMsgSerializer, BodyAndSequence) { + OpMsg msg; + msg.body = fromjson("{ping: 1}"); + msg.sequences = {{"docs", {fromjson("{a:1}"), fromjson("{a:2}")}}}; + + testSerializer(msg.serialize(), + OpMsgBytes{ + kNoFlags, // + kDocSequenceSection, + Sized{ + "docs", // + fromjson("{a: 1}"), + fromjson("{a: 2}"), + }, + + kBodySection, + fromjson("{ping: 1}"), + }); +} + +TEST(OpMsgSerializer, BodyAndEmptySequence) { + OpMsg msg; + msg.body = fromjson("{ping: 1}"); + msg.sequences = {{"docs", {}}}; + + testSerializer(msg.serialize(), + OpMsgBytes{ + kNoFlags, // + kDocSequenceSection, + Sized{ + "docs", // + }, + + kBodySection, + fromjson("{ping: 1}"), + }); +} + +TEST(OpMsgSerializer, BodyAndTwoSequences) { + OpMsg msg; + msg.body = fromjson("{ping: 1}"); + msg.sequences = { + {"a", {fromjson("{a: 1}")}}, // + {"b", {fromjson("{b: 1}")}}, + }; + + testSerializer(msg.serialize(), + OpMsgBytes{ + kNoFlags, // + kDocSequenceSection, + Sized{ + "a", // + fromjson("{a: 1}"), + }, + + kDocSequenceSection, + Sized{ + "b", // + fromjson("{b: 1}"), + }, + + kBodySection, + fromjson("{ping: 1}"), + }); +} + +TEST(OpMsgSerializer, BodyAndSequenceInPlace) { + OpMsgBuilder builder; + + auto emptySeq = builder.beginDocSequence("empty"); + emptySeq.done(); + + { + auto seq = builder.beginDocSequence("docs"); + seq.append(fromjson("{a: 1}")); + seq.appendBuilder().append("a", 2); + } + + builder.beginBody().append("ping", 1); + builder.resumeBody().append("$db", "foo"); + + testSerializer(builder.finish(), + OpMsgBytes{ + kNoFlags, // + kDocSequenceSection, + Sized{ + "empty", + }, + + kDocSequenceSection, + Sized{ + "docs", // + fromjson("{a: 1}"), + fromjson("{a: 2}"), + }, + + kBodySection, + fromjson("{ping: 1, $db: 'foo'}"), + }); +} + +TEST(OpMsgSerializer, BodyAndInPlaceSequenceInPlaceWithReset) { + OpMsgBuilder builder; + + auto emptySeq = builder.beginDocSequence("empty"); + emptySeq.done(); + + { + auto seq = builder.beginDocSequence("docs"); + seq.append(fromjson("{a: 1}")); + seq.appendBuilder().append("a", 2); + } + + builder.beginBody().append("ping", 1); + builder.resumeBody().append("$db", "foo"); + + builder.reset(); + + // Everything above shouldn't matter. + + { + auto seq = builder.beginDocSequence("docs2"); + seq.append(fromjson("{b: 1}")); + seq.appendBuilder().append("b", 2); + } + + builder.beginBody().append("pong", 1); + + testSerializer(builder.finish(), + OpMsgBytes{ + kNoFlags, // + kDocSequenceSection, + Sized{ + "docs2", // + fromjson("{b: 1}"), + fromjson("{b: 2}"), + }, + + kBodySection, + fromjson("{pong: 1}"), + }); +} + +TEST(OpMsgSerializer, ReplaceFlagsWorks) { + { + auto msg = OpMsgBytes{~0u}.done(); + OpMsg::replaceFlags(&msg, 0u); + ASSERT_EQ(OpMsg::flags(msg), 0u); + } + { + auto msg = OpMsgBytes{0u}.done(); + OpMsg::replaceFlags(&msg, ~0u); + ASSERT_EQ(OpMsg::flags(msg), ~0u); + } + + for (auto i = uint32_t(0); i < 32u; i++) { + auto flags = uint32_t(1) << i; + { + auto msg = OpMsgBytes{0u}.done(); + OpMsg::replaceFlags(&msg, flags); + ASSERT_EQ(OpMsg::flags(msg), flags) << flags; + } + { + auto msg = OpMsgBytes{~0u}.done(); + OpMsg::replaceFlags(&msg, flags); + ASSERT_EQ(OpMsg::flags(msg), flags) << flags; + } + { + auto msg = OpMsgBytes{~flags}.done(); + OpMsg::replaceFlags(&msg, flags); + ASSERT_EQ(OpMsg::flags(msg), flags) << flags; + } + } +} + +TEST(OpMsgSerializer, SetFlagWorks) { + for (auto i = uint32_t(0); i < 32u; i++) { + auto flags = uint32_t(1) << i; + { + auto msg = OpMsgBytes{0u}.done(); + OpMsg::setFlag(&msg, flags); + ASSERT_EQ(OpMsg::flags(msg), flags) << flags; + } + { + auto msg = OpMsgBytes{~0u}.done(); + OpMsg::setFlag(&msg, flags); + ASSERT_EQ(OpMsg::flags(msg), ~0u) << flags; + } + { + auto msg = OpMsgBytes{~flags}.done(); + OpMsg::setFlag(&msg, flags); + ASSERT_EQ(OpMsg::flags(msg), ~0u) << flags; + } + } +} + +TEST(OpMsgRequest, GetDatabaseWorks) { + OpMsgRequest msg; + msg.body = fromjson("{$db: 'foo'}"); + ASSERT_EQ(msg.getDatabase(), "foo"); + + msg.body = fromjson("{before: 1, $db: 'foo'}"); + ASSERT_EQ(msg.getDatabase(), "foo"); + + msg.body = fromjson("{before: 1, $db: 'foo', after: 1}"); + ASSERT_EQ(msg.getDatabase(), "foo"); +} + +TEST(OpMsgRequest, GetDatabaseThrowsWrongType) { + OpMsgRequest msg; + msg.body = fromjson("{$db: 1}"); + ASSERT_THROWS(msg.getDatabase(), DBException); +} + +TEST(OpMsgRequest, GetDatabaseThrowsMissing) { + OpMsgRequest msg; + msg.body = fromjson("{}"); + ASSERT_THROWS(msg.getDatabase(), AssertionException); + + msg.body = fromjson("{$notdb: 'foo'}"); + ASSERT_THROWS(msg.getDatabase(), AssertionException); +} + +TEST(OpMsgRequest, FromDbAndBodyDoesNotCopy) { + auto body = fromjson("{ping: 1}"); + const void* const bodyPtr = body.objdata(); + auto msg = OpMsgRequest::fromDBAndBody("db", std::move(body)); + + ASSERT_BSONOBJ_EQ(msg.body, fromjson("{ping: 1, $db: 'db'}")); + ASSERT_EQ(static_cast<const void*>(msg.body.objdata()), bodyPtr); +} +} // namespace +} // namespace mongo diff --git a/src/mongo/rpc/protocol.h b/src/mongo/rpc/protocol.h index c7500761e98..10e770271d3 100644 --- a/src/mongo/rpc/protocol.h +++ b/src/mongo/rpc/protocol.h @@ -34,7 +34,7 @@ #include "mongo/base/status_with.h" #include "mongo/db/wire_version.h" -#include "mongo/util/net/message.h" +#include "mongo/rpc/message.h" namespace mongo { class BSONObj; diff --git a/src/mongo/rpc/reply_builder_interface.cpp b/src/mongo/rpc/reply_builder_interface.cpp index 9c5b2998e2f..461e3edc6a3 100644 --- a/src/mongo/rpc/reply_builder_interface.cpp +++ b/src/mongo/rpc/reply_builder_interface.cpp @@ -34,8 +34,6 @@ #include "mongo/base/status_with.h" #include "mongo/db/jsobj.h" -#include "mongo/util/net/message.h" - namespace mongo { namespace rpc { diff --git a/src/mongo/rpc/unique_message.h b/src/mongo/rpc/unique_message.h index dd23729445a..b98d806dc0a 100644 --- a/src/mongo/rpc/unique_message.h +++ b/src/mongo/rpc/unique_message.h @@ -32,8 +32,8 @@ #include <utility> #include "mongo/base/disallow_copying.h" +#include "mongo/rpc/message.h" #include "mongo/rpc/reply_interface.h" -#include "mongo/util/net/message.h" namespace mongo { namespace rpc { |