diff options
author | Jonathan Reams <jbreams@mongodb.com> | 2018-02-20 14:33:42 -0500 |
---|---|---|
committer | Jonathan Reams <jbreams@mongodb.com> | 2018-03-02 11:07:01 -0500 |
commit | b2d8bd06318e1fddf4f1579084bbda4fb556c176 (patch) | |
tree | f591c41a0100dc85b51177396e80b946822aa712 | |
parent | 975d539ae068bd27ebb478b6f3673b89d2ad6beb (diff) | |
download | mongo-b2d8bd06318e1fddf4f1579084bbda4fb556c176.tar.gz |
SERVER-33300 Integrate TransportLayer with DBClient
33 files changed, 890 insertions, 251 deletions
diff --git a/src/mongo/SConscript b/src/mongo/SConscript index b1030f3a380..55ff0660453 100644 --- a/src/mongo/SConscript +++ b/src/mongo/SConscript @@ -466,6 +466,7 @@ if not has_option('noshell') and usemozjs: 'scripting/scripting', 'shell/mongojs', 'transport/message_compressor', + 'transport/transport_layer_manager', 'util/net/network', 'util/options_parser/options_parser_init', 'util/processinfo', @@ -508,6 +509,7 @@ if not has_option('noshell') and usemozjs: "$BUILD_DIR/third_party/shim_pcrecpp", "shell_core", "db/server_options_core", + "db/service_context_noop_init", "client/clientdriver", "$BUILD_DIR/mongo/util/password", ], diff --git a/src/mongo/client/SConscript b/src/mongo/client/SConscript index 3a48533419e..92785bdae90 100644 --- a/src/mongo/client/SConscript +++ b/src/mongo/client/SConscript @@ -47,6 +47,7 @@ env.CppUnitTest( ], LIBDEPS=[ 'clientdriver', + '$BUILD_DIR/mongo/transport/transport_layer_egress_init', ] ) @@ -282,6 +283,7 @@ env.CppUnitTest( '$BUILD_DIR/mongo/db/auth/authorization_manager_mock_init', '$BUILD_DIR/mongo/db/service_context_noop_init', '$BUILD_DIR/mongo/transport/transport_layer', + '$BUILD_DIR/mongo/transport/transport_layer_egress_init', '$BUILD_DIR/mongo/util/net/network', '$BUILD_DIR/mongo/util/version_impl', ], diff --git a/src/mongo/client/connection_pool.cpp b/src/mongo/client/connection_pool.cpp index 965ce7d790c..dbccb711d8d 100644 --- a/src/mongo/client/connection_pool.cpp +++ b/src/mongo/client/connection_pool.cpp @@ -105,7 +105,7 @@ void ConnectionPool::closeAllInUseConnections() { stdx::lock_guard<stdx::mutex> lk(_mutex); for (ConnectionList::iterator iter = _inUseConnections.begin(); iter != _inUseConnections.end(); ++iter) { - iter->conn->port().shutdown(); + iter->conn->shutdown(); } } @@ -189,7 +189,7 @@ ConnectionPool::ConnectionList::iterator ConnectionPool::acquireConnection( conn->setSoTimeout(durationCount<Milliseconds>(timeout) / 1000.0); uassertStatusOK(conn->connect(target, StringData())); - conn->port().setTag(conn->port().getTag() | _messagingPortTags); + conn->setTags(_messagingPortTags); if (isInternalAuthSet()) { conn->auth(getInternalUserAuthParams()); diff --git a/src/mongo/client/dbclient.cpp b/src/mongo/client/dbclient.cpp index 2d05e518b7b..d7f6c164552 100644 --- a/src/mongo/client/dbclient.cpp +++ b/src/mongo/client/dbclient.cpp @@ -70,12 +70,12 @@ #include "mongo/util/debug_util.h" #include "mongo/util/fail_point_service.h" #include "mongo/util/log.h" -#include "mongo/util/net/message_port.h" +#include "mongo/util/net/sock.h" #include "mongo/util/net/socket_exception.h" #include "mongo/util/net/ssl_manager.h" #include "mongo/util/net/ssl_options.h" #include "mongo/util/password_digest.h" -#include "mongo/util/represent_as.h" +#include "mongo/util/scopeguard.h" #include "mongo/util/time_support.h" #include "mongo/util/version.h" @@ -851,7 +851,7 @@ Status DBClientConnection::connect(const HostAndPort& serverAddress, StringData auto swIsMasterReply = initWireVersion(this, _applicationName); if (!swIsMasterReply.isOK()) { - _failed = true; + _markFailed(kSetFlag); return swIsMasterReply.status; } @@ -922,8 +922,7 @@ Status DBClientConnection::connect(const HostAndPort& serverAddress, StringData auto validationStatus = _hook(swIsMasterReply); if (!validationStatus.isOK()) { // Disconnect and mark failed. - _failed = true; - _port.reset(); + _markFailed(kReleaseSession); return validationStatus; } } @@ -931,27 +930,9 @@ Status DBClientConnection::connect(const HostAndPort& serverAddress, StringData return Status::OK(); } -namespace { -const auto kMaxMillisCount = Milliseconds::max().count(); -} // namespace - Status DBClientConnection::connectSocketOnly(const HostAndPort& serverAddress) { _serverAddress = serverAddress; - _failed = true; - - // We need to construct a SockAddr so we can resolve the address. - SockAddr osAddr{serverAddress.host().c_str(), - serverAddress.port(), - static_cast<sa_family_t>(IPv6Enabled() ? AF_UNSPEC : AF_INET)}; - - if (!osAddr.isValid()) { - return Status(ErrorCodes::InvalidOptions, - str::stream() << "couldn't initialize connection to host " - << serverAddress.host() - << ", address is invalid"); - } - - _port.reset(new MessagingPort(_so_timeout, _logLevel)); + _markFailed(kReleaseSession); if (serverAddress.host().empty()) { return Status(ErrorCodes::InvalidOptions, @@ -959,47 +940,45 @@ Status DBClientConnection::connectSocketOnly(const HostAndPort& serverAddress) { << ", host is empty"); } - if (osAddr.getAddr() == "0.0.0.0") { + if (serverAddress.host() == "0.0.0.0") { return Status(ErrorCodes::InvalidOptions, str::stream() << "couldn't connect to server " << _serverAddress.toString() << ", address resolved to 0.0.0.0"); } - _resolvedAddress = osAddr.getAddr(); - - if (!_port->connect(osAddr)) { - return Status(ErrorCodes::HostUnreachable, - str::stream() << "couldn't connect to server " << _serverAddress.toString() - << ", connection attempt failed"); - } - + transport::ConnectSSLMode sslMode = transport::kGlobalSSLMode; #ifdef MONGO_CONFIG_SSL // Prefer to get SSL mode directly from our URI, but if it is not set, fall back to // checking global SSL params. DBClientConnections create through the shell will have a // meaningful URI set, but DBClientConnections created from within the server may not. - int sslMode; auto options = _uri.getOptions(); auto iter = options.find("ssl"); if (iter != options.end()) { if (iter->second == "true") { - sslMode = SSLParams::SSLMode_requireSSL; + sslMode = transport::kEnableSSL; } else { - sslMode = SSLParams::SSLMode_disabled; + sslMode = transport::kDisableSSL; } - } else { - sslMode = sslGlobalParams.sslMode.load(); } - if (sslMode == SSLParams::SSLMode_preferSSL || sslMode == SSLParams::SSLMode_requireSSL) { - uassert(40312, "SSL is not enabled; cannot create an SSL connection", sslManager()); - if (!_port->secure(sslManager(), serverAddress.host())) { - return Status(ErrorCodes::SSLHandshakeFailed, "Failed to initialize SSL on connection"); - } - } #endif + auto tl = getGlobalServiceContext()->getTransportLayer(); + auto sws = tl->connect(serverAddress, sslMode, _socketTimeout.value_or(Milliseconds{5000})); + if (!sws.isOK()) { + return Status(ErrorCodes::HostUnreachable, + str::stream() << "couldn't connect to server " << _serverAddress.toString() + << ", connection attempt failed: " + << sws.getStatus().toString()); + } + + _session = std::move(sws.getValue()); + _sessionCreationMicros = curTimeMicros64(); + _lastConnectivityCheck = Date_t::now(); + _session->setTimeout(_socketTimeout); + _session->setTags(_tagMask); _failed = false; - LOG(1) << "connected to server " << toString() << endl; + LOG(1) << "connected to server " << toString(); return Status::OK(); } @@ -1040,13 +1019,62 @@ rpc::UniqueReply DBClientConnection::parseCommandReplyMessage(const std::string& return DBClientBase::parseCommandReplyMessage(host, std::move(replyMsg)); } catch (const DBException& ex) { if (ErrorCodes::isConnectionFatalMessageParseError(ex.code())) { - _port->shutdown(); - _failed = true; + _markFailed(kEndSession); } throw; } } +void DBClientConnection::_markFailed(FailAction action) { + _failed = true; + if (_session) { + if (action == kEndSession) { + _session->end(); + } else if (action == kReleaseSession) { + _session.reset(); + } + } +} + +bool DBClientConnection::isStillConnected() { + // This method tries to figure out whether the connection is still open, but with several + // caveats. + + // If we don't have a _session then we may have hit an error, or we may just not have + // connected yet - the _failed flag should indicate which. + // + // Otherwise, return false if we know we've had an error (_failed is true) + if (!_session) { + return !_failed; + } else if (_failed) { + return false; + } + + // Checking whether the socket actually has an error by calling _session->isConnected() + // is actually pretty expensive, so we cache the result for 5 seconds + auto now = getGlobalServiceContext()->getFastClockSource()->now(); + if (now - _lastConnectivityCheck < Seconds{5}) { + return true; + } + + _lastConnectivityCheck = now; + + // This will poll() the underlying socket and do a 1 byte recv to see if the connection + // has been closed. + return _session->isConnected(); +} + +void DBClientConnection::setTags(transport::Session::TagMask tags) { + _tagMask = tags; + if (!_session) + return; + _session->setTags(tags); +} + +void DBClientConnection::shutdown() { + _markFailed(kEndSession); +} + void DBClientConnection::_checkConnection() { if (!_failed) return; @@ -1062,7 +1090,7 @@ void DBClientConnection::_checkConnection() { _failed = false; auto connectStatus = connect(_serverAddress, _applicationName); if (!connectStatus.isOK()) { - _failed = true; + _markFailed(kSetFlag); LOG(_logLevel) << "reconnect " << toString() << " failed " << errmsg << endl; if (connectStatus == ErrorCodes::IncompatibleCatalogManager) { uassertStatusOK(connectStatus); // Will always throw @@ -1087,24 +1115,29 @@ void DBClientConnection::_checkConnection() { } void DBClientConnection::setSoTimeout(double timeout) { - _so_timeout = timeout; - if (_port) { - // `timeout` is in seconds. - auto ms = representAs<int64_t>(std::floor(timeout * 1000)).value_or(kMaxMillisCount); - _port->setTimeout(ms > kMaxMillisCount ? Milliseconds::max() : Milliseconds(ms)); + Milliseconds::rep timeoutMs = std::floor(timeout * 1000); + if (timeout <= 0) { + _socketTimeout = boost::none; + } else if (timeoutMs >= Milliseconds::max().count()) { + _socketTimeout = Milliseconds::max(); + } else { + _socketTimeout = Milliseconds{timeoutMs}; + } + + if (_session) { + _session->setTimeout(_socketTimeout); } } uint64_t DBClientConnection::getSockCreationMicroSec() const { - if (_port) { - return _port->getSockCreationMicroSec(); + if (_session) { + return _sessionCreationMicros; } else { return INVALID_SOCK_CREATION_TIME; } } -const uint64_t DBClientBase::INVALID_SOCK_CREATION_TIME = - static_cast<uint64_t>(0xFFFFFFFFFFFFFFFFULL); +const uint64_t DBClientBase::INVALID_SOCK_CREATION_TIME = std::numeric_limits<uint64_t>::max(); unique_ptr<DBClientCursor> DBClientBase::query(const string& ns, Query query, @@ -1206,8 +1239,7 @@ unsigned long long DBClientConnection::query(stdx::function<void(DBClientCursorB /* connection CANNOT be used anymore as more data may be on the way from the server. we have to reconnect. */ - _failed = true; - _port->shutdown(); + _markFailed(kEndSession); throw; } @@ -1378,7 +1410,6 @@ DBClientConnection::DBClientConnection(bool _autoReconnect, : _failed(false), autoReconnect(_autoReconnect), autoReconnectBackoff(1000, 2000), - _so_timeout(so_timeout), _hook(hook), _uri(std::move(uri)) { _numConnections.fetchAndAdd(1); @@ -1386,74 +1417,73 @@ DBClientConnection::DBClientConnection(bool _autoReconnect, void DBClientConnection::say(Message& toSend, bool isRetry, string* actualServer) { checkConnection(); - try { - toSend.header().setId(nextMessageId()); - toSend.header().setResponseToMsgId(0); - auto swm = _compressorManager.compressMessage(toSend); - uassertStatusOK(swm.getStatus()); - port().say(swm.getValue()); - } catch (const DBException&) { - _failed = true; - _port->shutdown(); - throw; - } + auto killSessionOnError = MakeGuard([this] { _markFailed(kEndSession); }); + + toSend.header().setId(nextMessageId()); + toSend.header().setResponseToMsgId(0); + uassertStatusOK( + _session->sinkMessage(uassertStatusOK(_compressorManager.compressMessage(toSend)))); + killSessionOnError.Dismiss(); } bool DBClientConnection::recv(Message& m, int lastRequestId) { - if (!port().recv(m)) { - _failed = true; + auto killSessionOnError = MakeGuard([this] { _markFailed(kEndSession); }); + auto swm = _session->sourceMessage(); + if (!swm.isOK()) { return false; } - try { - uassert(40570, - "Response ID did not match the sent message ID.", - m.header().getResponseToMsgId() == lastRequestId); - - if (m.operation() == dbCompressed) { - m = uassertStatusOK(_compressorManager.decompressMessage(m)); - } + m = std::move(swm.getValue()); + uassert(40570, + "Response ID did not match the sent message ID.", + m.header().getResponseToMsgId() == lastRequestId); - return true; - } catch (const DBException&) { - _failed = true; - _port->shutdown(); - throw; + if (m.operation() == dbCompressed) { + m = uassertStatusOK(_compressorManager.decompressMessage(m)); } + + killSessionOnError.Dismiss(); + return true; } bool DBClientConnection::call(Message& toSend, Message& response, bool assertOk, string* actualServer) { - /* todo: this is very ugly messagingport::call returns an error code AND can throw - an exception. we should make it return void and just throw an exception anytime - it fails - */ checkConnection(); - try { - toSend.header().setId(nextMessageId()); - toSend.header().setResponseToMsgId(0); - auto swm = _compressorManager.compressMessage(toSend); - uassertStatusOK(swm.getStatus()); - - if (!port().call(swm.getValue(), response)) { - _failed = true; - if (assertOk) - uasserted(10278, - str::stream() << "dbclient error communicating with server: " - << getServerAddress()); - return false; - } + auto killSessionOnError = MakeGuard([this] { _markFailed(kEndSession); }); + auto maybeThrow = [&](const auto& errStatus) { + if (assertOk) + uasserted(10278, + str::stream() << "dbclient error communicating with server " + << getServerAddress() + << ": " + << redact(errStatus)); + return false; + }; - if (response.operation() == dbCompressed) { - response = uassertStatusOK(_compressorManager.decompressMessage(response)); - } - } catch (const DBException&) { - _failed = true; - _port->shutdown(); - throw; + toSend.header().setId(nextMessageId()); + toSend.header().setResponseToMsgId(0); + auto swm = _compressorManager.compressMessage(toSend); + uassertStatusOK(swm.getStatus()); + + auto sinkStatus = _session->sinkMessage(swm.getValue()); + if (!sinkStatus.isOK()) { + return maybeThrow(sinkStatus); + } + + swm = _session->sourceMessage(); + if (swm.isOK()) { + response = std::move(swm.getValue()); + } else { + return maybeThrow(swm.getStatus()); } + + if (response.operation() == dbCompressed) { + response = uassertStatusOK(_compressorManager.decompressMessage(response)); + } + + killSessionOnError.Dismiss(); return true; } @@ -1504,7 +1534,7 @@ void DBClientConnection::handleNotMasterResponse(const BSONObj& replyBody, << _parentReplSetName}); } - _failed = true; + _markFailed(kSetFlag); } AtomicInt32 DBClientConnection::_numConnections; diff --git a/src/mongo/client/dbclientinterface.h b/src/mongo/client/dbclientinterface.h index 86848f9ed93..457839a9fba 100644 --- a/src/mongo/client/dbclientinterface.h +++ b/src/mongo/client/dbclientinterface.h @@ -45,8 +45,9 @@ #include "mongo/rpc/unique_message.h" #include "mongo/stdx/functional.h" #include "mongo/transport/message_compressor_manager.h" +#include "mongo/transport/session.h" +#include "mongo/transport/transport_layer.h" #include "mongo/util/mongoutils/str.h" -#include "mongo/util/net/abstract_message_port.h" #include "mongo/util/net/message.h" #include "mongo/util/net/op_msg.h" @@ -969,9 +970,11 @@ public: return _failed; } - bool isStillConnected() { - return _port ? _port->isStillConnected() : true; - } + bool isStillConnected(); + + void setTags(transport::Session::TagMask tag); + + void shutdown(); void setWireVersions(int minWireVersion, int maxWireVersion) { _minWireVersion = minWireVersion; @@ -986,11 +989,6 @@ public: return _maxWireVersion; } - AbstractMessagingPort& port() { - verify(_port); - return *_port; - } - std::string toString() const { std::stringstream ss; ss << _serverAddress; @@ -1020,7 +1018,7 @@ public: } void setSoTimeout(double timeout); double getSoTimeout() const { - return _so_timeout; + return _socketTimeout.value_or(Milliseconds{0}).count() / 1000.0; } virtual bool lazySupported() const { @@ -1037,7 +1035,7 @@ public: */ void setParentReplSetName(const std::string& replSetName); - uint64_t getSockCreationMicroSec() const; + uint64_t getSockCreationMicroSec() const override; MessageCompressorManager& getCompressorManager() { return _compressorManager; @@ -1065,9 +1063,13 @@ protected: virtual void _auth(const BSONObj& params); - std::unique_ptr<AbstractMessagingPort> _port; + transport::SessionHandle _session; + boost::optional<Milliseconds> _socketTimeout; + transport::Session::TagMask _tagMask = transport::Session::kEmptyTagMask; + uint64_t _sessionCreationMicros = INVALID_SOCK_CREATION_TIME; + Date_t _lastConnectivityCheck; - bool _failed; + bool _failed = false; const bool autoReconnect; Backoff autoReconnectBackoff; @@ -1078,7 +1080,6 @@ protected: void _checkConnection(); std::map<std::string, BSONObj> authCache; - double _so_timeout; static AtomicInt32 _numConnections; @@ -1089,6 +1090,8 @@ private: * returned. */ void handleNotMasterResponse(const BSONObj& replyBody, StringData errorMsgFieldName); + enum FailAction { kSetFlag, kEndSession, kReleaseSession }; + void _markFailed(FailAction action); // Contains the string for the replica set name of the host this is connected to. // Should be empty if this connection is not pointing to a replica set member. diff --git a/src/mongo/db/client.h b/src/mongo/db/client.h index 86629690af2..40e77ac04ab 100644 --- a/src/mongo/db/client.h +++ b/src/mongo/db/client.h @@ -52,7 +52,6 @@ namespace mongo { -class AbstractMessagingPort; class Collection; class OperationContext; diff --git a/src/mongo/db/repl/oplogreader.cpp b/src/mongo/db/repl/oplogreader.cpp index 659e061c554..9f5f01021bf 100644 --- a/src/mongo/db/repl/oplogreader.cpp +++ b/src/mongo/db/repl/oplogreader.cpp @@ -82,8 +82,7 @@ bool OplogReader::connect(const HostAndPort& host) { error() << errmsg << endl; return false; } - _conn->port().setTag(_conn->port().getTag() | - executor::NetworkInterface::kMessagingPortKeepOpen); + _conn->setTags(executor::NetworkInterface::kMessagingPortKeepOpen); _host = host; } return true; diff --git a/src/mongo/dbtests/repltests.cpp b/src/mongo/dbtests/repltests.cpp index 6306c4ce0db..4e4937b2e71 100644 --- a/src/mongo/dbtests/repltests.cpp +++ b/src/mongo/dbtests/repltests.cpp @@ -52,6 +52,7 @@ #include "mongo/db/repl/replication_coordinator_mock.h" #include "mongo/db/repl/sync_tail.h" #include "mongo/dbtests/dbtests.h" +#include "mongo/transport/transport_layer_asio.h" #include "mongo/util/log.h" using namespace mongo::repl; @@ -109,6 +110,15 @@ public: if (mongo::storageGlobalParams.engine == "mobile") { return; } + + transport::TransportLayerASIO::Options opts; + opts.mode = transport::TransportLayerASIO::Options::kEgress; + auto sc = getGlobalServiceContext(); + + sc->setTransportLayer(std::make_unique<transport::TransportLayerASIO>(opts, nullptr)); + ASSERT_OK(sc->getTransportLayer()->setup()); + ASSERT_OK(sc->getTransportLayer()->start()); + ReplSettings replSettings; replSettings.setOplogSizeBytes(10 * 1024 * 1024); replSettings.setMaster(true); @@ -154,6 +164,7 @@ public: ->setFollowerMode(repl::MemberState::RS_PRIMARY) .ignore(); + getGlobalServiceContext()->getTransportLayer()->shutdown(); } catch (...) { FAIL("Exception while cleaning up test"); diff --git a/src/mongo/s/SConscript b/src/mongo/s/SConscript index 722599088bb..83c784ff0be 100644 --- a/src/mongo/s/SConscript +++ b/src/mongo/s/SConscript @@ -409,6 +409,7 @@ env.CppUnitTest( LIBDEPS=[ '$BUILD_DIR/mongo/db/service_context_noop_init', '$BUILD_DIR/mongo/dbtests/mocklib', + '$BUILD_DIR/mongo/transport/transport_layer_egress_init', '$BUILD_DIR/mongo/util/net/network', 'client/sharding_connection_hook', 'sharding_legacy_api', diff --git a/src/mongo/shell/dbshell.cpp b/src/mongo/shell/dbshell.cpp index eef4859804a..c6e4724efc6 100644 --- a/src/mongo/shell/dbshell.cpp +++ b/src/mongo/shell/dbshell.cpp @@ -57,6 +57,7 @@ #include "mongo/shell/shell_utils.h" #include "mongo/shell/shell_utils_launcher.h" #include "mongo/stdx/utility.h" +#include "mongo/transport/transport_layer_asio.h" #include "mongo/util/exit.h" #include "mongo/util/file.h" #include "mongo/util/log.h" @@ -741,6 +742,18 @@ int _main(int argc, char* argv[], char** envp) { mongo::runGlobalInitializersOrDie(argc, argv, envp); + // TODO This should use a TransportLayerManager or TransportLayerFactory + auto serviceContext = getGlobalServiceContext(); + transport::TransportLayerASIO::Options opts; + opts.enableIPv6 = shellGlobalParams.enableIPv6; + opts.mode = transport::TransportLayerASIO::Options::kEgress; + + serviceContext->setTransportLayer( + std::make_unique<transport::TransportLayerASIO>(opts, nullptr)); + auto tlPtr = serviceContext->getTransportLayer(); + uassertStatusOK(tlPtr->setup()); + uassertStatusOK(tlPtr->start()); + // hide password from ps output for (int i = 0; i < (argc - 1); ++i) { if (!strcmp(argv[i], "-p") || !strcmp(argv[i], "--password")) { diff --git a/src/mongo/shell/shell_options.cpp b/src/mongo/shell/shell_options.cpp index 537b61c87cf..8e8ca5f4b3e 100644 --- a/src/mongo/shell/shell_options.cpp +++ b/src/mongo/shell/shell_options.cpp @@ -281,6 +281,7 @@ Status storeMongoShellOptions(const moe::Environment& params, #endif if (params.count("ipv6")) { mongo::enableIPv6(); + shellGlobalParams.enableIPv6 = true; } if (params.count("verbose")) { logger::globalLogDomain()->setMinimumLoggedSeverity(logger::LogSeverity::Debug(1)); diff --git a/src/mongo/shell/shell_options.h b/src/mongo/shell/shell_options.h index 4092c6a74df..1fe68e571a4 100644 --- a/src/mongo/shell/shell_options.h +++ b/src/mongo/shell/shell_options.h @@ -63,6 +63,7 @@ struct ShellGlobalParams { bool norc; bool nojit = true; bool javascriptProtection = true; + bool enableIPv6 = false; std::string script; diff --git a/src/mongo/tools/bridge.cpp b/src/mongo/tools/bridge.cpp index 5f9eac1e4e0..375c88bbc1b 100644 --- a/src/mongo/tools/bridge.cpp +++ b/src/mongo/tools/bridge.cpp @@ -49,6 +49,7 @@ #include "mongo/stdx/thread.h" #include "mongo/tools/bridge_commands.h" #include "mongo/tools/mongobridge_options.h" +#include "mongo/transport/transport_layer_asio.h" #include "mongo/util/assert_util.h" #include "mongo/util/exit.h" #include "mongo/util/log.h" @@ -91,33 +92,33 @@ public: : _mp(mp), _settingsMutex(settingsMutex), _settings(settings), _prng(seed) {} void operator()() { - DBClientConnection dest; - - { + transport::SessionHandle dest = []() -> transport::SessionHandle { HostAndPort destAddr{mongoBridgeGlobalParams.destUri}; const Seconds kConnectTimeout(30); - Timer connectTimer; - while (true) { - // DBClientConnection::connectSocketOnly() is used instead of - // DBClientConnection::connect() to avoid sending an isMaster command when the - // connection is established. We'd otherwise trigger a socket timeout when - // forwarding an _isSelf command because dest's replication subsystem hasn't been - // initialized yet and so it cannot respond to the isMaster command. - auto status = dest.connectSocketOnly(destAddr); - if (status.isOK()) { - break; - } - Seconds elapsed{connectTimer.seconds()}; - if (elapsed >= kConnectTimeout) { - warning() << "Unable to establish connection to " - << mongoBridgeGlobalParams.destUri << " after " << elapsed - << " seconds: " << status; - log() << "end connection " << _mp->remote().toString(); - _mp->shutdown(); - return; + auto now = getGlobalServiceContext()->getFastClockSource()->now(); + const auto connectExpiration = now + kConnectTimeout; + while (now < connectExpiration) { + auto tl = getGlobalServiceContext()->getTransportLayer(); + auto sws = + tl->connect(destAddr, transport::kGlobalSSLMode, connectExpiration - now); + auto status = sws.getStatus(); + if (!status.isOK()) { + warning() << "Unable to establish connection to " << destAddr << ": " << status; + now = getGlobalServiceContext()->getFastClockSource()->now(); + } else { + return std::move(sws.getValue()); } + sleepmillis(500); } + + return nullptr; + }(); + + if (!dest) { + log() << "end connection " << _mp->remote(); + _mp->shutdown(); + return; } bool receivingFirstMessage = true; @@ -231,8 +232,11 @@ public: request.operation() == dbCommand || request.operation() == dbMsg)) { // TODO dbMsg moreToCome // Forward the message to 'dest' and receive its reply in 'response'. - response.reset(); - dest.port().call(request, response); + uassertStatusOK(dest->sinkMessage(request)); + response = uassertStatusOK(dest->sourceMessage()); + uassert(50727, + "Response ID did not match the sent message ID.", + response.header().getResponseToMsgId() == request.header().getId()); // If there's nothing to respond back to '_mp' with, then close the connection. if (response.empty()) { @@ -280,15 +284,14 @@ public: MsgData::View header = response.header(); QueryResult::View qr = header.view2ptr(); if (qr.getCursorId()) { - response.reset(); - dest.port().recv(response); + response = uassertStatusOK(dest->sourceMessage()); _mp->say(response); } else { exhaust = false; } } } else { - dest.port().say(request); + uassertStatusOK(dest->sinkMessage(request)); } } catch (const DBException& ex) { error() << "Caught DBException in Forwarder: " << ex << ", end connection " @@ -409,6 +412,25 @@ int bridgeMain(int argc, char** argv, char** envp) { runGlobalInitializersOrDie(argc, argv, envp); startSignalProcessingThread(LogFileStatus::kNoLogFileToRotate); + auto serviceContext = getGlobalServiceContext(); + transport::TransportLayerASIO::Options opts; + opts.mode = mongo::transport::TransportLayerASIO::Options::kEgress; + + serviceContext->setTransportLayer( + std::make_unique<mongo::transport::TransportLayerASIO>(opts, nullptr)); + auto tl = serviceContext->getTransportLayer(); + if (!tl->setup().isOK()) { + log() << "Error setting up transport layer"; + return EXIT_NET_ERROR; + } + + if (!tl->start().isOK()) { + log() << "Error starting transport layer"; + return EXIT_NET_ERROR; + } + + serviceContext->notifyStartupComplete(); + listener = stdx::make_unique<BridgeListener>(); listener->setupSockets(); listener->initAndListen(); diff --git a/src/mongo/transport/SConscript b/src/mongo/transport/SConscript index 8527328b08c..f5cce983eac 100644 --- a/src/mongo/transport/SConscript +++ b/src/mongo/transport/SConscript @@ -67,6 +67,19 @@ tlEnv.Library( ], ) +# This library will initialize an egress transport layer in a mongo initializer +# for C++ tests that require networking. +env.Library( + target='transport_layer_egress_init', + source=[ + 'transport_layer_egress_init.cpp', + ], + LIBDEPS_PRIVATE=[ + 'transport_layer', + '$BUILD_DIR/mongo/db/service_context_noop_init', + ] +) + tlEnv.CppUnitTest( target='transport_layer_asio_test', source=[ diff --git a/src/mongo/transport/asio_utils.h b/src/mongo/transport/asio_utils.h index 89e7821f5b8..e5647a19264 100644 --- a/src/mongo/transport/asio_utils.h +++ b/src/mongo/transport/asio_utils.h @@ -30,9 +30,14 @@ #include "mongo/base/status.h" #include "mongo/base/system_error.h" +#include "mongo/util/errno_util.h" #include "mongo/util/net/hostandport.h" #include "mongo/util/net/sockaddr.h" +#ifndef _WIN32 +#include <sys/poll.h> +#endif // ndef _WIN32 + #include <asio.hpp> namespace mongo { @@ -60,6 +65,9 @@ inline Status errorCodeToStatus(const std::error_code& ec) { if (ec == asio::error::try_again || ec == asio::error::would_block) { #endif return {ErrorCodes::NetworkTimeout, "Socket operation timed out"}; + } else if (ec == asio::error::eof || ec == asio::error::connection_reset || + ec == asio::error::network_reset) { + return {ErrorCodes::HostUnreachable, "Connection was closed"}; } // If the ec.category() is a mongoErrorCategory() then this error was propogated from @@ -73,5 +81,90 @@ inline Status errorCodeToStatus(const std::error_code& ec) { return {errorCode, ec.message()}; } +/* + * The ASIO implementation of poll (i.e. socket.wait()) cannot poll for a mask of events, and + * doesn't support timeouts. + * + * This wraps up ::select/::poll for Windows/POSIX for a single socket and handles EINTR on POSIX + * + * - On timeout: it returns Status(ErrorCodes::NetworkTimeout) + * - On poll returning with an event: it returns the EventsMask for the socket, the caller must + * check whether it matches the expected events mask. + * - On error: it returns a Status(ErrorCodes::InternalError) + */ +template <typename Socket, typename EventsMask> +StatusWith<EventsMask> pollASIOSocket(Socket& socket, EventsMask mask, Milliseconds timeout) { +#ifdef _WIN32 + fd_set readfds; + fd_set writefds; + fd_set errfds; + + FD_ZERO(&readfds); + FD_ZERO(&writefds); + FD_ZERO(&errfds); + + auto fd = socket.native_handle(); + if (mask & POLLIN) { + FD_SET(fd, &readfds); + } + if (mask & POLLOUT) { + FD_SET(fd, &writefds); + } + FD_SET(fd, &errfds); + + timeval timeoutTv{}; + auto timeoutUs = duration_cast<Microseconds>(timeout); + if (timeoutUs >= Seconds{1}) { + auto timeoutSec = duration_cast<Seconds>(timeoutUs); + timeoutTv.tv_sec = timeoutSec.count(); + timeoutUs -= timeoutSec; + } + timeoutTv.tv_usec = timeoutUs.count(); + int result = ::select(1, &readfds, &writefds, &errfds, &timeoutTv); + if (result == SOCKET_ERROR) { + auto errDesc = errnoWithDescription(WSAGetLastError()); + return {ErrorCodes::InternalError, errDesc}; + } + int revents = (FD_ISSET(fd, &readfds) ? POLLIN : 0) | (FD_ISSET(fd, &writefds) ? POLLOUT : 0) | + (FD_ISSET(fd, &errfds) ? POLLERR : 0); +#else + pollfd pollItem; + pollItem.fd = socket.native_handle(); + pollItem.events = mask; + + int result; + boost::optional<Date_t> expiration; + if (timeout.count() > 0) { + expiration = Date_t::now() + timeout; + } + do { + Milliseconds curTimeout; + if (expiration) { + curTimeout = *expiration - Date_t::now(); + if (curTimeout.count() <= 0) { + result = 0; + break; + } + } else { + curTimeout = timeout; + } + result = ::poll(&pollItem, 1, curTimeout.count()); + } while (result == -1 && errno == EINTR); + + if (result == -1) { + int errCode = errno; + return {ErrorCodes::InternalError, errnoWithDescription(errCode)}; + } + int revents = pollItem.revents; +#endif + + if (result == 0) { + return {ErrorCodes::NetworkTimeout, "Timed out waiting for poll"}; + } else { + return revents; + } +} + + } // namespace transport } // namespace mongo diff --git a/src/mongo/transport/mock_session.h b/src/mongo/transport/mock_session.h index 411e57bd0ec..c9d519f6288 100644 --- a/src/mongo/transport/mock_session.h +++ b/src/mongo/transport/mock_session.h @@ -105,6 +105,10 @@ public: void setTimeout(boost::optional<Milliseconds>) override {} + bool isConnected() override { + return true; + } + explicit MockSession(TransportLayer* tl) : _tl(checked_cast<TransportLayerMock*>(tl)), _remote(), _local() {} explicit MockSession(HostAndPort remote, HostAndPort local, TransportLayer* tl) diff --git a/src/mongo/transport/session.h b/src/mongo/transport/session.h index db80aa0814d..c9438269ec7 100644 --- a/src/mongo/transport/session.h +++ b/src/mongo/transport/session.h @@ -122,6 +122,16 @@ public: */ virtual void setTimeout(boost::optional<Milliseconds> timeout) = 0; + /** + * This will return whether calling sourceMessage()/sinkMessage() will fail with an EOF error. + * + * Implementations may actually perform some I/O or call syscalls to determine this, rather + * than just checking a flag. + * + * This must not be called while the session is currently sourcing or sinking a message. + */ + virtual bool isConnected() = 0; + virtual const HostAndPort& remote() const = 0; virtual const HostAndPort& local() const = 0; diff --git a/src/mongo/transport/session_asio.h b/src/mongo/transport/session_asio.h index f18dfac6fe1..a1100eed535 100644 --- a/src/mongo/transport/session_asio.h +++ b/src/mongo/transport/session_asio.h @@ -94,7 +94,7 @@ public: std::error_code ec; getSocket().cancel(); getSocket().shutdown(GenericSocket::shutdown_both, ec); - if (ec) { + if ((ec) && (ec != asio::error::not_connected)) { error() << "Error shutting down socket: " << ec.message(); } } @@ -163,6 +163,120 @@ public: _configuredTimeout = timeout; } + bool isConnected() override { + // socket.is_open() only returns whether the socket is a valid file descriptor and + // if we haven't marked this socket as closed already. + if (!getSocket().is_open()) + return false; + + auto swPollEvents = pollASIOSocket(getSocket(), POLLIN, Milliseconds{0}); + if (!swPollEvents.isOK()) { + if (swPollEvents != ErrorCodes::NetworkTimeout) { + warning() << "Failed to poll socket for connectivity check: " + << swPollEvents.getStatus(); + return false; + } + return true; + } + + auto revents = swPollEvents.getValue(); + if (revents & POLLIN) { + char testByte; + int size = ::recv(getSocket().native_handle(), &testByte, sizeof(testByte), MSG_PEEK); + if (size == sizeof(testByte)) { + return true; + } else if (size == -1) { + auto errDesc = errnoWithDescription(errno); + warning() << "Failed to check socket connectivity: " << errDesc; + } + // If size == 0 then we got disconnected and we should return false. + } + + return false; + } + +protected: + friend class TransportLayerASIO; + +#ifdef MONGO_CONFIG_SSL + template <typename HandshakeCb> + void handshakeSSLForEgress(HostAndPort target, HandshakeCb onComplete) { + if (!_tl->_egressSSLContext) { + return onComplete( + {ErrorCodes::SSLHandshakeFailed, "SSL requested but SSL support is disabled"}); + } + + _sslSocket.emplace(std::move(_socket), *_tl->_egressSSLContext); + auto handshakeCompleteCb = + [ this, target = std::move(target), onComplete = std::move(onComplete) ]( + const std::error_code& ec) { + _ranHandshake = true; + if (ec) { + onComplete(errorCodeToStatus(ec)); + return; + } + + auto sslManager = getSSLManager(); + auto swPeerInfo = sslManager->parseAndValidatePeerCertificate( + _sslSocket->native_handle(), target.host()); + if (!swPeerInfo.isOK()) { + onComplete(swPeerInfo.getStatus()); + return; + } + + if (swPeerInfo.getValue()) { + SSLPeerInfo::forSession(shared_from_this()) = std::move(*swPeerInfo.getValue()); + } + + onComplete(Status::OK()); + }; + if (_blockingMode == Sync) { + std::error_code ec; + _sslSocket->handshake(asio::ssl::stream_base::client, ec); + handshakeCompleteCb(ec); + } else { + return _sslSocket->async_handshake(asio::ssl::stream_base::client, + std::move(handshakeCompleteCb)); + } + } +#endif + + void ensureSync() { + asio::error_code ec; + if (_blockingMode != Sync) { + getSocket().non_blocking(false, ec); + fassertStatusOK(40490, errorCodeToStatus(ec)); + _blockingMode = Sync; + } + + if (_socketTimeout != _configuredTimeout) { + // Change boost::none (which means no timeout) into a zero value for the socket option, + // which also means no timeout. + auto timeout = _configuredTimeout.value_or(Milliseconds{0}); + getSocket().set_option(ASIOSocketTimeoutOption<SO_SNDTIMEO>(timeout), ec); + uassertStatusOK(errorCodeToStatus(ec)); + + getSocket().set_option(ASIOSocketTimeoutOption<SO_RCVTIMEO>(timeout), ec); + uassertStatusOK(errorCodeToStatus(ec)); + + _socketTimeout = _configuredTimeout; + } + } + + void ensureAsync() { + if (_blockingMode == Async) + return; + + // Socket timeouts currently only effect synchronous calls, so make sure the caller isn't + // expecting a socket timeout when they do an async operation. + invariant(!_configuredTimeout); + + asio::error_code ec; + getSocket().non_blocking(true, ec); + fassertStatusOK(50706, errorCodeToStatus(ec)); + _blockingMode = Async; + } + private: template <int Name> class ASIOSocketTimeoutOption { @@ -215,14 +329,6 @@ private: return _socket; } - bool isOpen() const { -#ifdef MONGO_CONFIG_SSL - return _sslSocket ? _sslSocket->lowest_layer().is_open() : _socket.is_open(); -#else - return _socket.is_open(); -#endif - } - template <typename Callback> void sourceMessageImpl(Callback&& cb) { static constexpr auto kHeaderSize = sizeof(MSGHEADER::Value); @@ -300,7 +406,7 @@ private: return; } - maybeHandshakeSSL(buffers, std::move(postHandshakeCb)); + maybeHandshakeSSLForIngress(buffers, std::move(postHandshakeCb)); }; return opportunisticRead(_socket, buffers, std::move(handshakeRecvCb)); @@ -312,6 +418,7 @@ private: template <typename ConstBufferSequence, typename CompleteHandler> void write(const ConstBufferSequence& buffers, CompleteHandler&& handler) { #ifdef MONGO_CONFIG_SSL + _ranHandshake = true; if (_sslSocket) { return opportunisticWrite(*_sslSocket, buffers, std::forward<CompleteHandler>(handler)); } @@ -319,42 +426,6 @@ private: return opportunisticWrite(_socket, buffers, std::forward<CompleteHandler>(handler)); } - void ensureSync() { - asio::error_code ec; - if (_blockingMode != Sync) { - getSocket().non_blocking(false, ec); - fassertStatusOK(40490, errorCodeToStatus(ec)); - _blockingMode = Sync; - } - - if (_socketTimeout != _configuredTimeout) { - // Change boost::none (which means no timeout) into a zero value for the socket option, - // which also means no timeout. - auto timeout = _configuredTimeout.value_or(Milliseconds{0}); - getSocket().set_option(ASIOSocketTimeoutOption<SO_SNDTIMEO>(timeout), ec); - uassertStatusOK(errorCodeToStatus(ec)); - - getSocket().set_option(ASIOSocketTimeoutOption<SO_RCVTIMEO>(timeout), ec); - uassertStatusOK(errorCodeToStatus(ec)); - - _socketTimeout = _configuredTimeout; - } - } - - void ensureAsync() { - if (_blockingMode == Async) - return; - - // Socket timeouts currently only effect synchronous calls, so make sure the caller isn't - // expecting a socket timeout when they do an async operation. - invariant(!_configuredTimeout); - - asio::error_code ec; - getSocket().non_blocking(true, ec); - fassertStatusOK(50706, errorCodeToStatus(ec)); - _blockingMode = Async; - } - template <typename Stream, typename MutableBufferSequence, typename CompleteHandler> void opportunisticRead(Stream& stream, const MutableBufferSequence& buffers, @@ -411,7 +482,7 @@ private: #ifdef MONGO_CONFIG_SSL template <typename MutableBufferSequence, typename HandshakeCb> - void maybeHandshakeSSL(const MutableBufferSequence& buffer, HandshakeCb onComplete) { + void maybeHandshakeSSLForIngress(const MutableBufferSequence& buffer, HandshakeCb onComplete) { invariant(asio::buffer_size(buffer) >= sizeof(MSGHEADER::Value)); MSGHEADER::ConstView headerView(asio::buffer_cast<char*>(buffer)); auto responseTo = headerView.getResponseToMsgId(); @@ -424,15 +495,14 @@ private: // protocol message needs to be 0 or -1. Otherwise the connection is either sending // garbage or a TLS Hello packet which will be caught by the TLS handshake. if (responseTo != 0 && responseTo != -1) { - if (!_tl->_sslContext) { + if (!_tl->_ingressSSLContext) { return onComplete( {ErrorCodes::SSLHandshakeFailed, "SSL handshake received but server is started without SSL support"}, false); } - _sslSocket.emplace(std::move(_socket), *_tl->_sslContext); - + _sslSocket.emplace(std::move(_socket), *_tl->_ingressSSLContext); auto handshakeCompleteCb = [ this, onComplete = std::move(onComplete) ]( const std::error_code& ec, size_t size) mutable { auto& sslPeerInfo = SSLPeerInfo::forSession(shared_from_this()); diff --git a/src/mongo/transport/transport_layer.h b/src/mongo/transport/transport_layer.h index 053aa5f6c15..29b4a7bb72c 100644 --- a/src/mongo/transport/transport_layer.h +++ b/src/mongo/transport/transport_layer.h @@ -37,6 +37,8 @@ namespace mongo { namespace transport { +enum ConnectSSLMode { kGlobalSSLMode, kEnableSSL, kDisableSSL }; + /** * The TransportLayer moves Messages between transport::Endpoints and the database. * This class owns an Acceptor that generates new endpoints from which it can @@ -63,6 +65,15 @@ public: virtual ~TransportLayer() = default; + virtual StatusWith<SessionHandle> connect(HostAndPort peer, + ConnectSSLMode sslMode, + Milliseconds timeout) = 0; + + virtual void asyncConnect(HostAndPort peer, + ConnectSSLMode sslMode, + Milliseconds timeout, + std::function<void(StatusWith<SessionHandle>)> callback) = 0; + /** * Start the TransportLayer. After this point, the TransportLayer will begin accepting active * sessions from new transport::Endpoints. diff --git a/src/mongo/transport/transport_layer_asio.cpp b/src/mongo/transport/transport_layer_asio.cpp index 8e437ef949e..1f5d45e053b 100644 --- a/src/mongo/transport/transport_layer_asio.cpp +++ b/src/mongo/transport/transport_layer_asio.cpp @@ -36,9 +36,6 @@ #include <boost/algorithm/string.hpp> #include "mongo/config.h" -#ifdef MONGO_CONFIG_SSL -#include "mongo/util/net/ssl.hpp" -#endif #include "mongo/base/system_error.h" #include "mongo/db/server_options.h" @@ -53,6 +50,10 @@ #include "mongo/util/net/ssl_manager.h" #include "mongo/util/net/ssl_options.h" +#ifdef MONGO_CONFIG_SSL +#include "mongo/util/net/ssl.hpp" +#endif + // session_asio.h has some header dependencies that require it to be the last header. #include "mongo/transport/session_asio.h" @@ -74,7 +75,8 @@ TransportLayerASIO::TransportLayerASIO(const TransportLayerASIO::Options& opts, : _workerIOContext(std::make_shared<asio::io_context>()), _acceptorIOContext(stdx::make_unique<asio::io_context>()), #ifdef MONGO_CONFIG_SSL - _sslContext(nullptr), + _ingressSSLContext(nullptr), + _egressSSLContext(nullptr), #endif _sep(sep), _listenerOptions(opts) { @@ -82,24 +84,164 @@ TransportLayerASIO::TransportLayerASIO(const TransportLayerASIO::Options& opts, TransportLayerASIO::~TransportLayerASIO() = default; +StatusWith<SessionHandle> TransportLayerASIO::connect(HostAndPort peer, + ConnectSSLMode sslMode, + Milliseconds timeout) { + std::error_code ec; + GenericSocket sock(*_workerIOContext); +#ifndef _WIN32 + if (mongoutils::str::contains(peer.host(), '/')) { + invariant(!peer.hasPort()); + auto res = + _doSyncConnect(asio::local::stream_protocol::endpoint(peer.host()), peer, timeout); + if (!res.isOK()) { + return res.getStatus(); + } else { + return static_cast<SessionHandle>(std::move(res.getValue())); + } + } +#endif + + using Resolver = asio::ip::tcp::resolver; + Resolver resolver(*_workerIOContext); + std::string portNumberStr = std::to_string(peer.port()); + auto doResolve = [&](auto resolverFlags) -> StatusWith<Resolver::iterator> { + // If IPv6 is disabled, then we should specify that we only want IPv4 addresses, otherwise + // we should do a normal AF_UNSPEC resolution to get both IPv4/IPv6 + Resolver::iterator resolverIt; + if (_listenerOptions.enableIPv6) { + resolverIt = resolver.resolve(peer.host(), portNumberStr, resolverFlags, ec); + } else { + resolverIt = resolver.resolve( + asio::ip::tcp::v4(), peer.host(), portNumberStr, resolverFlags, ec); + } + + if (ec) { + return {ErrorCodes::HostNotFound, + str::stream() << "Could not find address for " << peer.host() << ": " + << ec.message()}; + } else if (resolverIt == Resolver::iterator()) { + return {ErrorCodes::HostNotFound, + str::stream() << "Could not find address for " << peer.host()}; + } + + return resolverIt; + }; + + // We always want to resolve the "service" (port number) as a numeric. + // + // We intentionally don't set the Resolver::address_configured flag because it might prevent us + // from connecting to localhost on hosts with only a loopback interface (see SERVER-1579). + const auto resolverFlags = Resolver::numeric_service; + + // We resolve in two steps, the first step tries to resolve the hostname as an IP address - + // that way if there's a DNS timeout, we can still connect to IP addresses quickly. + // (See SERVER-1709) + // + // Then, if the numeric (IP address) lookup failed, we fall back to DNS or return the error + // from the resolver. + auto swResolverIt = doResolve(resolverFlags | Resolver::numeric_host); + if (!swResolverIt.isOK()) { + if (swResolverIt == ErrorCodes::HostNotFound) { + swResolverIt = doResolve(resolverFlags); + if (!swResolverIt.isOK()) { + return swResolverIt.getStatus(); + } + } else { + return swResolverIt.getStatus(); + } + } + + auto& resolverIt = swResolverIt.getValue(); + auto sws = _doSyncConnect(resolverIt->endpoint(), peer, timeout); + if (!sws.isOK()) { + return sws.getStatus(); + } + + auto session = std::move(sws.getValue()); + session->ensureSync(); + +#ifndef MONGO_CONFIG_SSL + if (sslMode == kEnableSSL) { + return {ErrorCodes::InvalidSSLConfiguration, "SSL requested but not supported"}; + } +#else + auto globalSSLMode = _sslMode(); + if (sslMode == kEnableSSL || + (sslMode == kGlobalSSLMode && ((globalSSLMode == SSLParams::SSLMode_preferSSL) || + (globalSSLMode == SSLParams::SSLMode_requireSSL)))) { + Status sslStatus = Status::OK(); + auto onComplete = [&sslStatus](Status status) { sslStatus = status; }; + session->handshakeSSLForEgress(peer, std::move(onComplete)); + if (!sslStatus.isOK()) { + return sslStatus; + } + } +#endif + + return static_cast<SessionHandle>(std::move(session)); +} + +template <typename Endpoint> +StatusWith<TransportLayerASIO::ASIOSessionHandle> TransportLayerASIO::_doSyncConnect( + Endpoint endpoint, const HostAndPort& peer, const Milliseconds& timeout) { + GenericSocket sock(*_workerIOContext); + std::error_code ec; + sock.open(endpoint.protocol()); + sock.non_blocking(true); + + auto now = Date_t::now(); + auto expiration = now + timeout; + do { + auto curTimeout = expiration - now; + sock.connect(endpoint, curTimeout.toSystemDuration(), ec); + if (ec) { + now = Date_t::now(); + } + // We loop below if ec == interrupted to deal with EINTR failures, otherwise we handle + // the error/timeout below. + } while (ec == asio::error::interrupted && now < expiration); + + if (ec) { + return errorCodeToStatus(ec); + } else if (now >= expiration) { + return {ErrorCodes::NetworkTimeout, str::stream() << "Timed out connecting to " << peer}; + } + + sock.non_blocking(false); + return std::make_shared<ASIOSession>(this, std::move(sock)); +} + +void TransportLayerASIO::asyncConnect(HostAndPort peer, + ConnectSSLMode sslMode, + Milliseconds timeout, + std::function<void(StatusWith<SessionHandle>)> callback) { + MONGO_UNREACHABLE; +} + Status TransportLayerASIO::setup() { std::vector<std::string> listenAddrs; - if (_listenerOptions.ipList.empty()) { + if (_listenerOptions.ipList.empty() && _listenerOptions.isIngress()) { listenAddrs = {"127.0.0.1"}; if (_listenerOptions.enableIPv6) { listenAddrs.emplace_back("::1"); } - } else { + } else if (!_listenerOptions.ipList.empty()) { boost::split( listenAddrs, _listenerOptions.ipList, boost::is_any_of(","), boost::token_compress_on); } #ifndef _WIN32 - if (_listenerOptions.useUnixSockets) { + if (_listenerOptions.useUnixSockets && _listenerOptions.isIngress()) { listenAddrs.emplace_back(makeUnixSockPath(_listenerOptions.port)); } #endif + if (!(_listenerOptions.isIngress()) && !listenAddrs.empty()) { + return {ErrorCodes::BadValue, + "Cannot bind to listening sockets with ingress networking is disabled"}; + } + _listenerPort = _listenerOptions.port; for (auto& ip : listenAddrs) { @@ -177,20 +319,32 @@ Status TransportLayerASIO::setup() { } } - if (_acceptors.empty()) { + if (_acceptors.empty() && _listenerOptions.isIngress()) { return Status(ErrorCodes::SocketException, "No available addresses/ports to bind to"); } #ifdef MONGO_CONFIG_SSL const auto& sslParams = getSSLGlobalParams(); + auto sslManager = getSSLManager(); - if (_sslMode() != SSLParams::SSLMode_disabled) { - _sslContext = stdx::make_unique<asio::ssl::context>(asio::ssl::context::sslv23); + if (_sslMode() != SSLParams::SSLMode_disabled && _listenerOptions.isIngress()) { + _ingressSSLContext = stdx::make_unique<asio::ssl::context>(asio::ssl::context::sslv23); Status status = - getSSLManager()->initSSLContext(_sslContext->native_handle(), - sslParams, - SSLManagerInterface::ConnectionDirection::kIncoming); + sslManager->initSSLContext(_ingressSSLContext->native_handle(), + sslParams, + SSLManagerInterface::ConnectionDirection::kIncoming); + if (!status.isOK()) { + return status; + } + } + + if (_listenerOptions.isEgress() && sslManager) { + _egressSSLContext = stdx::make_unique<asio::ssl::context>(asio::ssl::context::sslv23); + Status status = + sslManager->initSSLContext(_egressSSLContext->native_handle(), + sslParams, + SSLManagerInterface::ConnectionDirection::kOutgoing); if (!status.isOK()) { return status; } @@ -204,31 +358,35 @@ Status TransportLayerASIO::start() { stdx::lock_guard<stdx::mutex> lk(_mutex); _running.store(true); - _listenerThread = stdx::thread([this] { - setThreadName("listener"); - while (_running.load()) { - asio::io_context::work work(*_acceptorIOContext); - try { - _acceptorIOContext->run(); - } catch (...) { - severe() << "Uncaught exception in the listener: " << exceptionToStatus(); - fassertFailed(40491); - } + if (_listenerOptions.isIngress()) { + for (auto& acceptor : _acceptors) { + acceptor.second.listen(serverGlobalParams.listenBacklog); + _acceptConnection(acceptor.second); } - }); - for (auto& acceptor : _acceptors) { - acceptor.second.listen(serverGlobalParams.listenBacklog); - _acceptConnection(acceptor.second); - } + _listenerThread = stdx::thread([this] { + setThreadName("listener"); + while (_running.load()) { + asio::io_context::work work(*_acceptorIOContext); + try { + _acceptorIOContext->run(); + } catch (...) { + severe() << "Uncaught exception in the listener: " << exceptionToStatus(); + fassertFailed(40491); + } + } + }); - const char* ssl = ""; + const char* ssl = ""; #ifdef MONGO_CONFIG_SSL - if (_sslMode() != SSLParams::SSLMode_disabled) { - ssl = " ssl"; - } + if (_sslMode() != SSLParams::SSLMode_disabled) { + ssl = " ssl"; + } #endif - log() << "waiting for connections on port " << _listenerPort << ssl; + log() << "waiting for connections on port " << _listenerPort << ssl; + } else { + invariant(_acceptors.empty()); + } return Status::OK(); } diff --git a/src/mongo/transport/transport_layer_asio.h b/src/mongo/transport/transport_layer_asio.h index 62c0e02edae..da2bdab547e 100644 --- a/src/mongo/transport/transport_layer_asio.h +++ b/src/mongo/transport/transport_layer_asio.h @@ -77,6 +77,19 @@ public: explicit Options(const ServerGlobalParams* params); Options() = default; + constexpr static auto kIngress = 0x1; + constexpr static auto kEgress = 0x10; + + int mode = kIngress | kEgress; + + bool isIngress() const { + return mode & kIngress; + } + + bool isEgress() const { + return mode & kEgress; + } + int port = ServerGlobalParams::DefaultDBPort; // port to bind to std::string ipList; // addresses to bind to #ifndef _WIN32 @@ -92,6 +105,14 @@ public: virtual ~TransportLayerASIO(); + StatusWith<SessionHandle> connect(HostAndPort peer, + ConnectSSLMode sslMode, + Milliseconds timeout) final; + void asyncConnect(HostAndPort peer, + ConnectSSLMode sslMode, + Milliseconds timeout, + std::function<void(StatusWith<SessionHandle>)> callback) final; + Status setup() final; Status start() final; @@ -111,6 +132,12 @@ private: using GenericAcceptor = asio::basic_socket_acceptor<asio::generic::stream_protocol>; void _acceptConnection(GenericAcceptor& acceptor); + + template <typename Endpoint> + StatusWith<ASIOSessionHandle> _doSyncConnect(Endpoint endpoint, + const HostAndPort& peer, + const Milliseconds& timeout); + #ifdef MONGO_CONFIG_SSL SSLParams::SSLModes _sslMode() const; #endif @@ -144,7 +171,8 @@ private: std::unique_ptr<asio::io_context> _acceptorIOContext; #ifdef MONGO_CONFIG_SSL - std::unique_ptr<asio::ssl::context> _sslContext; + std::unique_ptr<asio::ssl::context> _ingressSSLContext; + std::unique_ptr<asio::ssl::context> _egressSSLContext; #endif std::vector<std::pair<SockAddr, GenericAcceptor>> _acceptors; diff --git a/src/mongo/transport/transport_layer_egress_init.cpp b/src/mongo/transport/transport_layer_egress_init.cpp new file mode 100644 index 00000000000..20a4fe5a47a --- /dev/null +++ b/src/mongo/transport/transport_layer_egress_init.cpp @@ -0,0 +1,58 @@ +/** + * Copyright 2018 MongoDB Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, version 3, + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the GNU Affero General Public License in all respects + * for all of the code used other than as permitted herein. If you modify + * file(s) with this exception, you may extend this exception to your + * version of the file(s), but you are not obligated to do so. If you do not + * wish to do so, delete this exception statement from your version. If you + * delete this exception statement from all source files in the program, + * then also delete it in the license file. + */ + +#define MONGO_LOG_DEFAULT_COMPONENT ::mongo::logger::LogComponent::kDefault + +#include "mongo/platform/basic.h" + +#include "mongo/base/init.h" +#include "mongo/db/service_context.h" +#include "mongo/transport/transport_layer_asio.h" + +namespace mongo { +namespace { +// Linking with this file will configure an egress-only TransportLayer on a ServiceContextNoop. +// Use this for unit/integration tests that require only egress networking. +MONGO_INITIALIZER_WITH_PREREQUISITES(ConfigureEgressTransportLayer, ("SetGlobalEnvironment")) +(InitializerContext* context) { + auto sc = getGlobalServiceContext(); + invariant(!sc->getTransportLayer()); + + transport::TransportLayerASIO::Options opts; + opts.mode = transport::TransportLayerASIO::Options::kEgress; + sc->setTransportLayer(std::make_unique<transport::TransportLayerASIO>(opts, nullptr)); + auto status = sc->getTransportLayer()->setup(); + if (!status.isOK()) { + return status; + } + + return sc->getTransportLayer()->start(); +} + +} // namespace +} // namespace diff --git a/src/mongo/transport/transport_layer_manager.cpp b/src/mongo/transport/transport_layer_manager.cpp index 28bcf351e51..f53250e91bd 100644 --- a/src/mongo/transport/transport_layer_manager.cpp +++ b/src/mongo/transport/transport_layer_manager.cpp @@ -59,6 +59,19 @@ void TransportLayerManager::_foreach(Callable&& cb) const { } } +StatusWith<SessionHandle> TransportLayerManager::connect(HostAndPort peer, + ConnectSSLMode sslMode, + Milliseconds timeout) { + return _tls.front()->connect(peer, sslMode, timeout); +} + +void TransportLayerManager::asyncConnect(HostAndPort peer, + ConnectSSLMode sslMode, + Milliseconds timeout, + std::function<void(StatusWith<SessionHandle>)> callback) { + MONGO_UNREACHABLE; +} + // TODO Right now this and setup() leave TLs started if there's an error. In practice the server // exits with an error and this isn't an issue, but we should make this more robust. Status TransportLayerManager::start() { diff --git a/src/mongo/transport/transport_layer_manager.h b/src/mongo/transport/transport_layer_manager.h index 7c80935f00a..70f54b10652 100644 --- a/src/mongo/transport/transport_layer_manager.h +++ b/src/mongo/transport/transport_layer_manager.h @@ -57,6 +57,14 @@ public: : _tls(std::move(tls)) {} TransportLayerManager(); + StatusWith<SessionHandle> connect(HostAndPort peer, + ConnectSSLMode sslMode, + Milliseconds timeout) override; + void asyncConnect(HostAndPort peer, + ConnectSSLMode sslMode, + Milliseconds timeout, + std::function<void(StatusWith<SessionHandle>)> callback) override; + Status start() override; void shutdown() override; Status setup() override; diff --git a/src/mongo/transport/transport_layer_mock.cpp b/src/mongo/transport/transport_layer_mock.cpp index ec1a28e6777..0fc6ce7964c 100644 --- a/src/mongo/transport/transport_layer_mock.cpp +++ b/src/mongo/transport/transport_layer_mock.cpp @@ -62,6 +62,19 @@ bool TransportLayerMock::owns(Session::Id id) { return _sessions.count(id) > 0; } +StatusWith<SessionHandle> TransportLayerMock::connect(HostAndPort peer, + ConnectSSLMode sslMode, + Milliseconds timeout) { + MONGO_UNREACHABLE; +} + +void TransportLayerMock::asyncConnect(HostAndPort peer, + ConnectSSLMode sslMode, + Milliseconds timeout, + std::function<void(StatusWith<SessionHandle>)> callback) { + MONGO_UNREACHABLE; +} + Status TransportLayerMock::setup() { return Status::OK(); } diff --git a/src/mongo/transport/transport_layer_mock.h b/src/mongo/transport/transport_layer_mock.h index 7a861df129a..06a9cd3f37c 100644 --- a/src/mongo/transport/transport_layer_mock.h +++ b/src/mongo/transport/transport_layer_mock.h @@ -53,6 +53,14 @@ public: SessionHandle get(Session::Id id); bool owns(Session::Id id); + StatusWith<SessionHandle> connect(HostAndPort peer, + ConnectSSLMode sslMode, + Milliseconds timeout) override; + void asyncConnect(HostAndPort peer, + ConnectSSLMode sslMode, + Milliseconds timeout, + std::function<void(StatusWith<SessionHandle>)> callback) override; + Status setup() override; Status start() override; void shutdown() override; diff --git a/src/mongo/unittest/SConscript b/src/mongo/unittest/SConscript index 60315dcc5bd..e00502d94f7 100644 --- a/src/mongo/unittest/SConscript +++ b/src/mongo/unittest/SConscript @@ -32,6 +32,9 @@ env.Library(target="integration_test_main", '$BUILD_DIR/mongo/client/connection_string', '$BUILD_DIR/mongo/util/options_parser/options_parser_init', ], + LIBDEPS_PRIVATE=[ + '$BUILD_DIR/mongo/transport/transport_layer_egress_init', + ], ) bmEnv = env.Clone() diff --git a/src/mongo/unittest/integration_test_main.cpp b/src/mongo/unittest/integration_test_main.cpp index 69743f6c4e2..aaaf12f8c23 100644 --- a/src/mongo/unittest/integration_test_main.cpp +++ b/src/mongo/unittest/integration_test_main.cpp @@ -36,6 +36,8 @@ #include "mongo/base/initializer.h" #include "mongo/client/connection_string.h" +#include "mongo/db/service_context.h" +#include "mongo/transport/transport_layer_asio.h" #include "mongo/unittest/unittest.h" #include "mongo/util/log.h" #include "mongo/util/options_parser/environment.h" diff --git a/src/third_party/asio-master/asio/include/asio/basic_socket.hpp b/src/third_party/asio-master/asio/include/asio/basic_socket.hpp index 43430161270..d224bac0a83 100644 --- a/src/third_party/asio-master/asio/include/asio/basic_socket.hpp +++ b/src/third_party/asio-master/asio/include/asio/basic_socket.hpp @@ -760,7 +760,7 @@ public: peer_endpoint.protocol(), ec); asio::detail::throw_error(ec, "connect"); } - this->get_service().connect(this->get_implementation(), peer_endpoint, ec); + this->get_service().connect(this->get_implementation(), peer_endpoint, -1, ec); asio::detail::throw_error(ec, "connect"); } @@ -805,7 +805,63 @@ public: } } - this->get_service().connect(this->get_implementation(), peer_endpoint, ec); + this->get_service().connect(this->get_implementation(), peer_endpoint, -1, ec); + ASIO_SYNC_OP_VOID_RETURN(ec); + } + + /// Connect the socket to the specified endpoint with a timeout. + /** + * This function is used to connect a socket to the specified remote endpoint. + * The function call will block until the connection is successfully made or + * an error occurs. + * + * The socket is automatically opened if it is not already open. If the + * connect fails, and the socket was automatically opened, the socket is + * not returned to the closed state. + * + * Passing a timeout of less than zero will return an invalid_argument error. + * + * @param peer_endpoint The remote endpoint to which the socket will be + * connected. + * + * @param timeout The time to wait for the connection before failing + * + * @param ec Set to indicate what error occurred, if any. + * + * @par Example + * @code + * asio::ip::tcp::socket socket(io_context); + * asio::ip::tcp::endpoint endpoint( + * asio::ip::address::from_string("1.2.3.4"), 12345); + * asio::error_code ec; + * socket.connect(endpoint, std::chrono::seconds{30}, ec); + * if (ec) + * { + * // An error occurred. + * } + * @endcode + */ + template <typename Duration> + ASIO_SYNC_OP_VOID connect(const endpoint_type& peer_endpoint, + Duration timeout, asio::error_code& ec) + { + if (!is_open()) + { + this->get_service().open(this->get_implementation(), + peer_endpoint.protocol(), ec); + if (ec) + { + ASIO_SYNC_OP_VOID_RETURN(ec); + } + } + + auto timeout_ms = std::chrono::duration_cast<std::chrono::milliseconds>(timeout); + if (timeout_ms.count() < 0) + { + ec = asio::error::invalid_argument; + ASIO_SYNC_OP_VOID_RETURN(ec); + } + this->get_service().connect(this->get_implementation(), peer_endpoint, timeout_ms.count(), ec); ASIO_SYNC_OP_VOID_RETURN(ec); } diff --git a/src/third_party/asio-master/asio/include/asio/detail/impl/socket_ops.ipp b/src/third_party/asio-master/asio/include/asio/detail/impl/socket_ops.ipp index 2f89889fac8..58ad04ea66f 100644 --- a/src/third_party/asio-master/asio/include/asio/detail/impl/socket_ops.ipp +++ b/src/third_party/asio-master/asio/include/asio/detail/impl/socket_ops.ipp @@ -491,7 +491,7 @@ int connect(socket_type s, const socket_addr_type* addr, } void sync_connect(socket_type s, const socket_addr_type* addr, - std::size_t addrlen, asio::error_code& ec) + std::size_t addrlen, int timeout_ms, asio::error_code& ec) { // Perform the connect operation. socket_ops::connect(s, addr, addrlen, ec); @@ -503,8 +503,15 @@ void sync_connect(socket_type s, const socket_addr_type* addr, } // Wait for socket to become ready. - if (socket_ops::poll_connect(s, -1, ec) < 0) + int res = socket_ops::poll_connect(s, timeout_ms, ec); + if (res < 0) + return; + + if (res == 0) + { + ec = asio::error::timed_out; return; + } // Get the error code from the connect operation. int connect_error = 0; diff --git a/src/third_party/asio-master/asio/include/asio/detail/reactive_socket_service.hpp b/src/third_party/asio-master/asio/include/asio/detail/reactive_socket_service.hpp index b7b264806a9..ef9a9366a85 100644 --- a/src/third_party/asio-master/asio/include/asio/detail/reactive_socket_service.hpp +++ b/src/third_party/asio-master/asio/include/asio/detail/reactive_socket_service.hpp @@ -486,10 +486,10 @@ public: // Connect the socket to the specified endpoint. asio::error_code connect(implementation_type& impl, - const endpoint_type& peer_endpoint, asio::error_code& ec) + const endpoint_type& peer_endpoint, int timeout_ms, asio::error_code& ec) { socket_ops::sync_connect(impl.socket_, - peer_endpoint.data(), peer_endpoint.size(), ec); + peer_endpoint.data(), peer_endpoint.size(), timeout_ms, ec); return ec; } diff --git a/src/third_party/asio-master/asio/include/asio/detail/socket_ops.hpp b/src/third_party/asio-master/asio/include/asio/detail/socket_ops.hpp index b1fe32af429..2f2a1c38552 100644 --- a/src/third_party/asio-master/asio/include/asio/detail/socket_ops.hpp +++ b/src/third_party/asio-master/asio/include/asio/detail/socket_ops.hpp @@ -104,7 +104,7 @@ ASIO_DECL int connect(socket_type s, const socket_addr_type* addr, std::size_t addrlen, asio::error_code& ec); ASIO_DECL void sync_connect(socket_type s, const socket_addr_type* addr, - std::size_t addrlen, asio::error_code& ec); + std::size_t addrlen, int timeout_ms, asio::error_code& ec); #if defined(ASIO_HAS_IOCP) diff --git a/src/third_party/asio-master/asio/include/asio/detail/win_iocp_socket_service.hpp b/src/third_party/asio-master/asio/include/asio/detail/win_iocp_socket_service.hpp index ab099f6eab1..21d3f24fa77 100644 --- a/src/third_party/asio-master/asio/include/asio/detail/win_iocp_socket_service.hpp +++ b/src/third_party/asio-master/asio/include/asio/detail/win_iocp_socket_service.hpp @@ -562,10 +562,10 @@ public: // Connect the socket to the specified endpoint. asio::error_code connect(implementation_type& impl, - const endpoint_type& peer_endpoint, asio::error_code& ec) + const endpoint_type& peer_endpoint, int timeout_ms, asio::error_code& ec) { socket_ops::sync_connect(impl.socket_, - peer_endpoint.data(), peer_endpoint.size(), ec); + peer_endpoint.data(), peer_endpoint.size(), timeout_ms, ec); return ec; } |