From d94b263ca72b874440757259a56bd5676b5ee4a3 Mon Sep 17 00:00:00 2001 From: Mathias Stearn Date: Mon, 17 Jul 2017 11:13:13 -0400 Subject: SERVER-30118 always check replyTo field in DBClient --- src/mongo/client/dbclient.cpp | 6 +++++- src/mongo/client/dbclient_rs.cpp | 4 ++-- src/mongo/client/dbclient_rs.h | 2 +- src/mongo/client/dbclientcursor.cpp | 29 ++++++++++++++++++++--------- src/mongo/client/dbclientcursor.h | 2 ++ src/mongo/client/dbclientinterface.h | 4 ++-- 6 files changed, 32 insertions(+), 15 deletions(-) diff --git a/src/mongo/client/dbclient.cpp b/src/mongo/client/dbclient.cpp index ff2dfd86cb8..8fedea65d05 100644 --- a/src/mongo/client/dbclient.cpp +++ b/src/mongo/client/dbclient.cpp @@ -1243,12 +1243,16 @@ void DBClientConnection::say(Message& toSend, bool isRetry, string* actualServer } } -bool DBClientConnection::recv(Message& m) { +bool DBClientConnection::recv(Message& m, int lastRequestId) { if (!port().recv(m)) { _failed = true; return false; } + uassert(40570, + "Response ID did not match the sent message ID.", + m.header().getResponseToMsgId() == lastRequestId); + if (m.operation() == dbCompressed) { auto swm = _compressorManager.decompressMessage(m); uassertStatusOK(swm.getStatus()); diff --git a/src/mongo/client/dbclient_rs.cpp b/src/mongo/client/dbclient_rs.cpp index 4b589fbdb22..1af22471a22 100644 --- a/src/mongo/client/dbclient_rs.cpp +++ b/src/mongo/client/dbclient_rs.cpp @@ -819,12 +819,12 @@ void DBClientReplicaSet::say(Message& toSend, bool isRetry, string* actualServer return; } -bool DBClientReplicaSet::recv(Message& m) { +bool DBClientReplicaSet::recv(Message& m, int lastRequestId) { verify(_lazyState._lastClient); // TODO: It would be nice if we could easily wrap a conn error as a result error try { - return _lazyState._lastClient->recv(m); + return _lazyState._lastClient->recv(m, lastRequestId); } catch (DBException& e) { log() << "could not receive data from " << _lazyState._lastClient->toString() << causedBy(redact(e)); diff --git a/src/mongo/client/dbclient_rs.h b/src/mongo/client/dbclient_rs.h index 784b03bef3a..6c6361f11ba 100644 --- a/src/mongo/client/dbclient_rs.h +++ b/src/mongo/client/dbclient_rs.h @@ -132,7 +132,7 @@ public: // ---- callback pieces ------- virtual void say(Message& toSend, bool isRetry = false, std::string* actualServer = 0); - virtual bool recv(Message& toRecv); + virtual bool recv(Message& toRecv, int lastRequestId); virtual void checkResponse(const std::vector& batch, bool networkError, bool* retry = NULL, diff --git a/src/mongo/client/dbclientcursor.cpp b/src/mongo/client/dbclientcursor.cpp index 832c384cbdf..e9abd6c770a 100644 --- a/src/mongo/client/dbclientcursor.cpp +++ b/src/mongo/client/dbclientcursor.cpp @@ -114,6 +114,7 @@ void DBClientCursor::_assembleInit(Message& toSend) { } bool DBClientCursor::init() { + invariant(!_connectionHasPendingReplies); Message toSend; _assembleInit(toSend); verify(_client); @@ -139,11 +140,15 @@ void DBClientCursor::initLazy(bool isRetry) { Message toSend; _assembleInit(toSend); _client->say(toSend, isRetry, &_originalHost); + _lastRequestId = toSend.header().getId(); + _connectionHasPendingReplies = true; } bool DBClientCursor::initLazyFinish(bool& retry) { + invariant(_connectionHasPendingReplies); Message reply; - bool recvd = _client->recv(reply); + bool recvd = _client->recv(reply, _lastRequestId); + _connectionHasPendingReplies = false; // If we get a bad response, return false if (!recvd || reply.empty()) { @@ -163,6 +168,7 @@ bool DBClientCursor::initLazyFinish(bool& retry) { } void DBClientCursor::requestMore() { + invariant(!_connectionHasPendingReplies); verify(cursorId && batch.pos == batch.objs.size()); if (haveLimit) { @@ -193,7 +199,7 @@ void DBClientCursor::exhaustReceiveMore() { verify(!haveLimit); Message response; verify(_client); - if (!_client->recv(response)) { + if (!_client->recv(response, _lastRequestId)) { uasserted(16465, "recv failed while exhausting cursor"); } dataReceived(response); @@ -259,6 +265,13 @@ void DBClientCursor::dataReceived(const Message& reply, bool& retry, string& hos cursorId = qr.getCursorId(); } + if (opts & QueryOption_Exhaust) { + // With exhaust mode, each reply after the first claims to be a reply to the previous one + // rather than the initial request. + _connectionHasPendingReplies = (cursorId != 0); + _lastRequestId = reply.header().getId(); + } + batch.pos = 0; batch.objs.clear(); batch.objs.reserve(qr.getNReturned()); @@ -460,20 +473,18 @@ DBClientCursor::~DBClientCursor() { } void DBClientCursor::kill() { - DESTRUCTOR_GUARD( - + DESTRUCTOR_GUARD({ if (cursorId && _ownCursor && !globalInShutdownDeprecated()) { - if (_client) { + if (_client && !_connectionHasPendingReplies) { _client->killCursor(cursorId); } else { - verify(_scopedHost.size()); - ScopedDbConnection conn(_scopedHost); + verify(_scopedHost.size() || (_client && _connectionHasPendingReplies)); + ScopedDbConnection conn(_client ? _client->getServerAddress() : _scopedHost); conn->killCursor(cursorId); conn.done(); } } - - ); + }); // Mark this cursor as dead since we can't do any getMores. cursorId = 0; diff --git a/src/mongo/client/dbclientcursor.h b/src/mongo/client/dbclientcursor.h index 7cd394f2e13..91d4e9ad95c 100644 --- a/src/mongo/client/dbclientcursor.h +++ b/src/mongo/client/dbclientcursor.h @@ -248,6 +248,8 @@ private: std::string _lazyHost; bool wasError; BSONVersion _enabledBSONVersion; + bool _connectionHasPendingReplies = false; + int _lastRequestId = 0; void dataReceived(const Message& reply) { bool retry; diff --git a/src/mongo/client/dbclientinterface.h b/src/mongo/client/dbclientinterface.h index d6bd3389901..60de0894123 100644 --- a/src/mongo/client/dbclientinterface.h +++ b/src/mongo/client/dbclientinterface.h @@ -227,7 +227,7 @@ public: std::string* actualServer = nullptr) = 0; /* used by QueryOption_Exhaust. To use that your subclass must implement this. */ - virtual bool recv(Message& m) { + virtual bool recv(Message& m, int lastRequestId) { verify(false); return false; } @@ -948,7 +948,7 @@ public: } virtual void say(Message& toSend, bool isRetry = false, std::string* actualServer = 0); - virtual bool recv(Message& m); + virtual bool recv(Message& m, int lastRequestId); virtual void checkResponse(const std::vector& batch, bool networkError, bool* retry = NULL, -- cgit v1.2.1