diff options
Diffstat (limited to 'src/mongo/client/dbclient.cpp')
-rw-r--r-- | src/mongo/client/dbclient.cpp | 254 |
1 files changed, 142 insertions, 112 deletions
diff --git a/src/mongo/client/dbclient.cpp b/src/mongo/client/dbclient.cpp index 2d05e518b7b..d7f6c164552 100644 --- a/src/mongo/client/dbclient.cpp +++ b/src/mongo/client/dbclient.cpp @@ -70,12 +70,12 @@ #include "mongo/util/debug_util.h" #include "mongo/util/fail_point_service.h" #include "mongo/util/log.h" -#include "mongo/util/net/message_port.h" +#include "mongo/util/net/sock.h" #include "mongo/util/net/socket_exception.h" #include "mongo/util/net/ssl_manager.h" #include "mongo/util/net/ssl_options.h" #include "mongo/util/password_digest.h" -#include "mongo/util/represent_as.h" +#include "mongo/util/scopeguard.h" #include "mongo/util/time_support.h" #include "mongo/util/version.h" @@ -851,7 +851,7 @@ Status DBClientConnection::connect(const HostAndPort& serverAddress, StringData auto swIsMasterReply = initWireVersion(this, _applicationName); if (!swIsMasterReply.isOK()) { - _failed = true; + _markFailed(kSetFlag); return swIsMasterReply.status; } @@ -922,8 +922,7 @@ Status DBClientConnection::connect(const HostAndPort& serverAddress, StringData auto validationStatus = _hook(swIsMasterReply); if (!validationStatus.isOK()) { // Disconnect and mark failed. - _failed = true; - _port.reset(); + _markFailed(kReleaseSession); return validationStatus; } } @@ -931,27 +930,9 @@ Status DBClientConnection::connect(const HostAndPort& serverAddress, StringData return Status::OK(); } -namespace { -const auto kMaxMillisCount = Milliseconds::max().count(); -} // namespace - Status DBClientConnection::connectSocketOnly(const HostAndPort& serverAddress) { _serverAddress = serverAddress; - _failed = true; - - // We need to construct a SockAddr so we can resolve the address. - SockAddr osAddr{serverAddress.host().c_str(), - serverAddress.port(), - static_cast<sa_family_t>(IPv6Enabled() ? AF_UNSPEC : AF_INET)}; - - if (!osAddr.isValid()) { - return Status(ErrorCodes::InvalidOptions, - str::stream() << "couldn't initialize connection to host " - << serverAddress.host() - << ", address is invalid"); - } - - _port.reset(new MessagingPort(_so_timeout, _logLevel)); + _markFailed(kReleaseSession); if (serverAddress.host().empty()) { return Status(ErrorCodes::InvalidOptions, @@ -959,47 +940,45 @@ Status DBClientConnection::connectSocketOnly(const HostAndPort& serverAddress) { << ", host is empty"); } - if (osAddr.getAddr() == "0.0.0.0") { + if (serverAddress.host() == "0.0.0.0") { return Status(ErrorCodes::InvalidOptions, str::stream() << "couldn't connect to server " << _serverAddress.toString() << ", address resolved to 0.0.0.0"); } - _resolvedAddress = osAddr.getAddr(); - - if (!_port->connect(osAddr)) { - return Status(ErrorCodes::HostUnreachable, - str::stream() << "couldn't connect to server " << _serverAddress.toString() - << ", connection attempt failed"); - } - + transport::ConnectSSLMode sslMode = transport::kGlobalSSLMode; #ifdef MONGO_CONFIG_SSL // Prefer to get SSL mode directly from our URI, but if it is not set, fall back to // checking global SSL params. DBClientConnections create through the shell will have a // meaningful URI set, but DBClientConnections created from within the server may not. - int sslMode; auto options = _uri.getOptions(); auto iter = options.find("ssl"); if (iter != options.end()) { if (iter->second == "true") { - sslMode = SSLParams::SSLMode_requireSSL; + sslMode = transport::kEnableSSL; } else { - sslMode = SSLParams::SSLMode_disabled; + sslMode = transport::kDisableSSL; } - } else { - sslMode = sslGlobalParams.sslMode.load(); } - if (sslMode == SSLParams::SSLMode_preferSSL || sslMode == SSLParams::SSLMode_requireSSL) { - uassert(40312, "SSL is not enabled; cannot create an SSL connection", sslManager()); - if (!_port->secure(sslManager(), serverAddress.host())) { - return Status(ErrorCodes::SSLHandshakeFailed, "Failed to initialize SSL on connection"); - } - } #endif + auto tl = getGlobalServiceContext()->getTransportLayer(); + auto sws = tl->connect(serverAddress, sslMode, _socketTimeout.value_or(Milliseconds{5000})); + if (!sws.isOK()) { + return Status(ErrorCodes::HostUnreachable, + str::stream() << "couldn't connect to server " << _serverAddress.toString() + << ", connection attempt failed: " + << sws.getStatus().toString()); + } + + _session = std::move(sws.getValue()); + _sessionCreationMicros = curTimeMicros64(); + _lastConnectivityCheck = Date_t::now(); + _session->setTimeout(_socketTimeout); + _session->setTags(_tagMask); _failed = false; - LOG(1) << "connected to server " << toString() << endl; + LOG(1) << "connected to server " << toString(); return Status::OK(); } @@ -1040,13 +1019,62 @@ rpc::UniqueReply DBClientConnection::parseCommandReplyMessage(const std::string& return DBClientBase::parseCommandReplyMessage(host, std::move(replyMsg)); } catch (const DBException& ex) { if (ErrorCodes::isConnectionFatalMessageParseError(ex.code())) { - _port->shutdown(); - _failed = true; + _markFailed(kEndSession); } throw; } } +void DBClientConnection::_markFailed(FailAction action) { + _failed = true; + if (_session) { + if (action == kEndSession) { + _session->end(); + } else if (action == kReleaseSession) { + _session.reset(); + } + } +} + +bool DBClientConnection::isStillConnected() { + // This method tries to figure out whether the connection is still open, but with several + // caveats. + + // If we don't have a _session then we may have hit an error, or we may just not have + // connected yet - the _failed flag should indicate which. + // + // Otherwise, return false if we know we've had an error (_failed is true) + if (!_session) { + return !_failed; + } else if (_failed) { + return false; + } + + // Checking whether the socket actually has an error by calling _session->isConnected() + // is actually pretty expensive, so we cache the result for 5 seconds + auto now = getGlobalServiceContext()->getFastClockSource()->now(); + if (now - _lastConnectivityCheck < Seconds{5}) { + return true; + } + + _lastConnectivityCheck = now; + + // This will poll() the underlying socket and do a 1 byte recv to see if the connection + // has been closed. + return _session->isConnected(); +} + +void DBClientConnection::setTags(transport::Session::TagMask tags) { + _tagMask = tags; + if (!_session) + return; + _session->setTags(tags); +} + +void DBClientConnection::shutdown() { + _markFailed(kEndSession); +} + void DBClientConnection::_checkConnection() { if (!_failed) return; @@ -1062,7 +1090,7 @@ void DBClientConnection::_checkConnection() { _failed = false; auto connectStatus = connect(_serverAddress, _applicationName); if (!connectStatus.isOK()) { - _failed = true; + _markFailed(kSetFlag); LOG(_logLevel) << "reconnect " << toString() << " failed " << errmsg << endl; if (connectStatus == ErrorCodes::IncompatibleCatalogManager) { uassertStatusOK(connectStatus); // Will always throw @@ -1087,24 +1115,29 @@ void DBClientConnection::_checkConnection() { } void DBClientConnection::setSoTimeout(double timeout) { - _so_timeout = timeout; - if (_port) { - // `timeout` is in seconds. - auto ms = representAs<int64_t>(std::floor(timeout * 1000)).value_or(kMaxMillisCount); - _port->setTimeout(ms > kMaxMillisCount ? Milliseconds::max() : Milliseconds(ms)); + Milliseconds::rep timeoutMs = std::floor(timeout * 1000); + if (timeout <= 0) { + _socketTimeout = boost::none; + } else if (timeoutMs >= Milliseconds::max().count()) { + _socketTimeout = Milliseconds::max(); + } else { + _socketTimeout = Milliseconds{timeoutMs}; + } + + if (_session) { + _session->setTimeout(_socketTimeout); } } uint64_t DBClientConnection::getSockCreationMicroSec() const { - if (_port) { - return _port->getSockCreationMicroSec(); + if (_session) { + return _sessionCreationMicros; } else { return INVALID_SOCK_CREATION_TIME; } } -const uint64_t DBClientBase::INVALID_SOCK_CREATION_TIME = - static_cast<uint64_t>(0xFFFFFFFFFFFFFFFFULL); +const uint64_t DBClientBase::INVALID_SOCK_CREATION_TIME = std::numeric_limits<uint64_t>::max(); unique_ptr<DBClientCursor> DBClientBase::query(const string& ns, Query query, @@ -1206,8 +1239,7 @@ unsigned long long DBClientConnection::query(stdx::function<void(DBClientCursorB /* connection CANNOT be used anymore as more data may be on the way from the server. we have to reconnect. */ - _failed = true; - _port->shutdown(); + _markFailed(kEndSession); throw; } @@ -1378,7 +1410,6 @@ DBClientConnection::DBClientConnection(bool _autoReconnect, : _failed(false), autoReconnect(_autoReconnect), autoReconnectBackoff(1000, 2000), - _so_timeout(so_timeout), _hook(hook), _uri(std::move(uri)) { _numConnections.fetchAndAdd(1); @@ -1386,74 +1417,73 @@ DBClientConnection::DBClientConnection(bool _autoReconnect, void DBClientConnection::say(Message& toSend, bool isRetry, string* actualServer) { checkConnection(); - try { - toSend.header().setId(nextMessageId()); - toSend.header().setResponseToMsgId(0); - auto swm = _compressorManager.compressMessage(toSend); - uassertStatusOK(swm.getStatus()); - port().say(swm.getValue()); - } catch (const DBException&) { - _failed = true; - _port->shutdown(); - throw; - } + auto killSessionOnError = MakeGuard([this] { _markFailed(kEndSession); }); + + toSend.header().setId(nextMessageId()); + toSend.header().setResponseToMsgId(0); + uassertStatusOK( + _session->sinkMessage(uassertStatusOK(_compressorManager.compressMessage(toSend)))); + killSessionOnError.Dismiss(); } bool DBClientConnection::recv(Message& m, int lastRequestId) { - if (!port().recv(m)) { - _failed = true; + auto killSessionOnError = MakeGuard([this] { _markFailed(kEndSession); }); + auto swm = _session->sourceMessage(); + if (!swm.isOK()) { return false; } - try { - uassert(40570, - "Response ID did not match the sent message ID.", - m.header().getResponseToMsgId() == lastRequestId); - - if (m.operation() == dbCompressed) { - m = uassertStatusOK(_compressorManager.decompressMessage(m)); - } + m = std::move(swm.getValue()); + uassert(40570, + "Response ID did not match the sent message ID.", + m.header().getResponseToMsgId() == lastRequestId); - return true; - } catch (const DBException&) { - _failed = true; - _port->shutdown(); - throw; + if (m.operation() == dbCompressed) { + m = uassertStatusOK(_compressorManager.decompressMessage(m)); } + + killSessionOnError.Dismiss(); + return true; } bool DBClientConnection::call(Message& toSend, Message& response, bool assertOk, string* actualServer) { - /* todo: this is very ugly messagingport::call returns an error code AND can throw - an exception. we should make it return void and just throw an exception anytime - it fails - */ checkConnection(); - try { - toSend.header().setId(nextMessageId()); - toSend.header().setResponseToMsgId(0); - auto swm = _compressorManager.compressMessage(toSend); - uassertStatusOK(swm.getStatus()); - - if (!port().call(swm.getValue(), response)) { - _failed = true; - if (assertOk) - uasserted(10278, - str::stream() << "dbclient error communicating with server: " - << getServerAddress()); - return false; - } + auto killSessionOnError = MakeGuard([this] { _markFailed(kEndSession); }); + auto maybeThrow = [&](const auto& errStatus) { + if (assertOk) + uasserted(10278, + str::stream() << "dbclient error communicating with server " + << getServerAddress() + << ": " + << redact(errStatus)); + return false; + }; - if (response.operation() == dbCompressed) { - response = uassertStatusOK(_compressorManager.decompressMessage(response)); - } - } catch (const DBException&) { - _failed = true; - _port->shutdown(); - throw; + toSend.header().setId(nextMessageId()); + toSend.header().setResponseToMsgId(0); + auto swm = _compressorManager.compressMessage(toSend); + uassertStatusOK(swm.getStatus()); + + auto sinkStatus = _session->sinkMessage(swm.getValue()); + if (!sinkStatus.isOK()) { + return maybeThrow(sinkStatus); + } + + swm = _session->sourceMessage(); + if (swm.isOK()) { + response = std::move(swm.getValue()); + } else { + return maybeThrow(swm.getStatus()); } + + if (response.operation() == dbCompressed) { + response = uassertStatusOK(_compressorManager.decompressMessage(response)); + } + + killSessionOnError.Dismiss(); return true; } @@ -1504,7 +1534,7 @@ void DBClientConnection::handleNotMasterResponse(const BSONObj& replyBody, << _parentReplSetName}); } - _failed = true; + _markFailed(kSetFlag); } AtomicInt32 DBClientConnection::_numConnections; |