diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/mongo/transport/transport_layer_asio.cpp | 265 | ||||
-rw-r--r-- | src/mongo/util/net/sockaddr.cpp | 2 | ||||
-rw-r--r-- | src/mongo/util/net/sockaddr.h | 2 | ||||
-rw-r--r-- | src/third_party/asio-master/asio/include/asio/local/detail/impl/endpoint.ipp | 13 | ||||
-rw-r--r-- | src/third_party/asio-master/patches/0004-Fix-UNIX-endpoint-length.patch | 25 |
5 files changed, 207 insertions, 100 deletions
diff --git a/src/mongo/transport/transport_layer_asio.cpp b/src/mongo/transport/transport_layer_asio.cpp index 87750d6b506..e5f0be0be52 100644 --- a/src/mongo/transport/transport_layer_asio.cpp +++ b/src/mongo/transport/transport_layer_asio.cpp @@ -293,19 +293,106 @@ TransportLayerASIO::TransportLayerASIO(const TransportLayerASIO::Options& opts, TransportLayerASIO::~TransportLayerASIO() = default; +class WrappedEndpoint { +public: + using Endpoint = asio::generic::stream_protocol::endpoint; + + explicit WrappedEndpoint(const asio::ip::basic_resolver_entry<asio::ip::tcp>& source) + : _str(str::stream() << source.endpoint().address().to_string() << ":" + << source.service_name()), + _endpoint(source.endpoint()) {} + +#ifndef _WIN32 + explicit WrappedEndpoint(const asio::local::stream_protocol::endpoint& source) + : _str(source.path()), _endpoint(source) {} +#endif + + WrappedEndpoint() = default; + + Endpoint* operator->() noexcept { + return &_endpoint; + } + + Endpoint& operator*() noexcept { + return _endpoint; + } + + const std::string& toString() const { + return _str; + } + + sa_family_t family() const { + return _endpoint.data()->sa_family; + } + +private: + std::string _str; + Endpoint _endpoint; +}; + using Resolver = asio::ip::tcp::resolver; class WrappedResolver { public: using Flags = Resolver::flags; - using Results = Resolver::results_type; + using EndpointVector = std::vector<WrappedEndpoint>; explicit WrappedResolver(asio::io_context& ioCtx) : _resolver(ioCtx) {} - Future<Results> resolve(const HostAndPort& peer, Flags flags, bool enableIPv6) { - Results results; + StatusWith<EndpointVector> resolve(const HostAndPort& peer, bool enableIPv6) { + if (auto unixEp = _checkForUnixSocket(peer)) { + return *unixEp; + } + // We always want to resolve the "service" (port number) as a numeric. + // + // We intentionally don't set the Resolver::address_configured flag because it might prevent + // us from connecting to localhost on hosts with only a loopback interface + // (see SERVER-1579). + const auto flags = Resolver::numeric_service; + + // We resolve in two steps, the first step tries to resolve the hostname as an IP address - + // that way if there's a DNS timeout, we can still connect to IP addresses quickly. + // (See SERVER-1709) + // + // Then, if the numeric (IP address) lookup failed, we fall back to DNS or return the error + // from the resolver. + return _resolve(peer, flags | Resolver::numeric_host, enableIPv6) + .onError([=](Status) { return _resolve(peer, flags, enableIPv6); }) + .getNoThrow(); + } + + Future<EndpointVector> asyncResolve(const HostAndPort& peer, bool enableIPv6) { + if (auto unixEp = _checkForUnixSocket(peer)) { + return *unixEp; + } + + // We follow the same numeric -> hostname fallback procedure as the synchronous resolver + // function for setting resolver flags (see above). + const auto flags = Resolver::numeric_service; + return _asyncResolve(peer, flags | Resolver::numeric_host, enableIPv6).onError([=](Status) { + return _asyncResolve(peer, flags, enableIPv6); + }); + } + + void cancel() { + _resolver.cancel(); + } + +private: + boost::optional<EndpointVector> _checkForUnixSocket(const HostAndPort& peer) { +#ifndef _WIN32 + if (mongoutils::str::contains(peer.host(), '/')) { + asio::local::stream_protocol::endpoint ep(peer.host()); + return EndpointVector{WrappedEndpoint(ep)}; + } +#endif + return boost::none; + } + + Future<EndpointVector> _resolve(const HostAndPort& peer, Flags flags, bool enableIPv6) { std::error_code ec; auto port = std::to_string(peer.port()); + Results results; if (enableIPv6) { results = _resolver.resolve(peer.host(), port, flags, ec); } else { @@ -319,7 +406,7 @@ public: } } - Future<Results> asyncResolve(const HostAndPort& peer, Flags flags, bool enableIPv6) { + Future<EndpointVector> _asyncResolve(const HostAndPort& peer, Flags flags, bool enableIPv6) { auto port = std::to_string(peer.port()); Future<Results> ret; if (enableIPv6) { @@ -330,16 +417,12 @@ public: } return std::move(ret) - .onError([this, peer](Status status) { return _makeFuture(status, peer); }) + .onError([this, peer](Status status) { return _checkResults(status, peer); }) .then([this, peer](Results results) { return _makeFuture(results, peer); }); } - void cancel() { - _resolver.cancel(); - } - -private: - Future<Results> _makeFuture(StatusWith<Results> results, const HostAndPort& peer) { + using Results = Resolver::results_type; + StatusWith<Results> _checkResults(StatusWith<Results> results, const HostAndPort& peer) { if (!results.isOK()) { return Status{ErrorCodes::HostNotFound, str::stream() << "Could not find address for " << peer << ": " @@ -348,63 +431,50 @@ private: return Status{ErrorCodes::HostNotFound, str::stream() << "Could not find address for " << peer}; } else { - return std::move(results.getValue()); + return results; + } + } + + Future<EndpointVector> _makeFuture(StatusWith<Results> results, const HostAndPort& peer) { + results = _checkResults(std::move(results), peer); + if (!results.isOK()) { + return results.getStatus(); + } else { + auto& epl = results.getValue(); + return EndpointVector(epl.begin(), epl.end()); } } Resolver _resolver; }; +Status makeConnectError(Status status, const HostAndPort& peer, const WrappedEndpoint& endpoint) { + std::string errmsg; + if (peer.toString() != endpoint.toString()) { + errmsg = str::stream() << "Error connecting to " << peer << " (" << endpoint.toString() + << ")"; + } else { + errmsg = str::stream() << "Error connecting to " << peer; + } + + return status.withContext(errmsg); +} + StatusWith<SessionHandle> TransportLayerASIO::connect(HostAndPort peer, ConnectSSLMode sslMode, Milliseconds timeout) { std::error_code ec; GenericSocket sock(*_egressReactor); -#ifndef _WIN32 - if (mongoutils::str::contains(peer.host(), '/')) { - invariant(!peer.hasPort()); - auto res = - _doSyncConnect(asio::local::stream_protocol::endpoint(peer.host()), peer, timeout); - if (!res.isOK()) { - return res.getStatus(); - } else { - return static_cast<SessionHandle>(std::move(res.getValue())); - } - } -#endif - WrappedResolver resolver(*_egressReactor); - // We always want to resolve the "service" (port number) as a numeric. - // - // We intentionally don't set the Resolver::address_configured flag because it might prevent us - // from connecting to localhost on hosts with only a loopback interface (see SERVER-1579). - const auto resolverFlags = Resolver::numeric_service; - - // We resolve in two steps, the first step tries to resolve the hostname as an IP address - - // that way if there's a DNS timeout, we can still connect to IP addresses quickly. - // (See SERVER-1709) - // - // Then, if the numeric (IP address) lookup failed, we fall back to DNS or return the error - // from the resolver. - auto swResolverIt = - resolver.resolve(peer, resolverFlags | Resolver::numeric_host, _listenerOptions.enableIPv6) - .getNoThrow(); - if (!swResolverIt.isOK()) { - if (swResolverIt == ErrorCodes::HostNotFound) { - swResolverIt = - resolver.resolve(peer, resolverFlags, _listenerOptions.enableIPv6).getNoThrow(); - if (!swResolverIt.isOK()) { - return swResolverIt.getStatus(); - } - } else { - return swResolverIt.getStatus(); - } + auto swEndpoints = resolver.resolve(peer, _listenerOptions.enableIPv6); + if (!swEndpoints.isOK()) { + return swEndpoints.getStatus(); } - auto& resolverIt = swResolverIt.getValue(); - auto sws = _doSyncConnect(resolverIt->endpoint(), peer, timeout); + auto endpoints = std::move(swEndpoints.getValue()); + auto sws = _doSyncConnect(endpoints.front(), peer, timeout); if (!sws.isOK()) { return sws.getStatus(); } @@ -412,6 +482,12 @@ StatusWith<SessionHandle> TransportLayerASIO::connect(HostAndPort peer, auto session = std::move(sws.getValue()); session->ensureSync(); +#ifndef _WIN32 + if (endpoints.front().family() == AF_UNIX) { + return static_cast<SessionHandle>(std::move(session)); + } +#endif + #ifndef MONGO_CONFIG_SSL if (sslMode == kEnableSSL) { return {ErrorCodes::InvalidSSLConfiguration, "SSL requested but not supported"}; @@ -436,14 +512,14 @@ StatusWith<TransportLayerASIO::ASIOSessionHandle> TransportLayerASIO::_doSyncCon Endpoint endpoint, const HostAndPort& peer, const Milliseconds& timeout) { GenericSocket sock(*_egressReactor); std::error_code ec; - sock.open(endpoint.protocol()); + sock.open(endpoint->protocol()); sock.non_blocking(true); auto now = Date_t::now(); auto expiration = now + timeout; do { auto curTimeout = expiration - now; - sock.connect(endpoint, curTimeout.toSystemDuration(), ec); + sock.connect(*endpoint, curTimeout.toSystemDuration(), ec); if (ec) { now = Date_t::now(); } @@ -451,10 +527,18 @@ StatusWith<TransportLayerASIO::ASIOSessionHandle> TransportLayerASIO::_doSyncCon // the error/timeout below. } while (ec == asio::error::interrupted && now < expiration); - if (ec) { - return errorCodeToStatus(ec); - } else if (now >= expiration) { - return {ErrorCodes::NetworkTimeout, str::stream() << "Timed out connecting to " << peer}; + auto status = [&] { + if (ec) { + return errorCodeToStatus(ec); + } else if (now >= expiration) { + return Status(ErrorCodes::NetworkTimeout, "Timed out"); + } else { + return Status::OK(); + } + }(); + + if (!status.isOK()) { + return makeConnectError(status, peer, endpoint); } sock.non_blocking(false); @@ -478,6 +562,7 @@ Future<SessionHandle> TransportLayerASIO::asyncConnect(HostAndPort peer, GenericSocket socket; WrappedResolver resolver; + WrappedEndpoint resolvedEndpoint; const HostAndPort peer; TransportLayerASIO::ASIOSessionHandle session; }; @@ -489,22 +574,12 @@ Future<SessionHandle> TransportLayerASIO::asyncConnect(HostAndPort peer, return Status{ErrorCodes::HostNotFound, "Hostname or IP address to connect to is empty"}; } - // We always want to resolve the "service" (port number) as a numeric. - // - // We intentionally don't set the Resolver::address_configured flag because it might prevent us - // from connecting to localhost on hosts with only a loopback interface (see SERVER-1579). - const auto resolverFlags = Resolver::numeric_service; - return connector->resolver - .asyncResolve( - connector->peer, resolverFlags | Resolver::numeric_host, _listenerOptions.enableIPv6) - .onError([this, connector, resolverFlags](Status status) { - return connector->resolver.asyncResolve( - connector->peer, resolverFlags, _listenerOptions.enableIPv6); - }) - .then([connector](WrappedResolver::Results results) { - connector->socket.open(results->endpoint().protocol()); + return connector->resolver.asyncResolve(connector->peer, _listenerOptions.enableIPv6) + .then([connector](WrappedResolver::EndpointVector results) { + connector->resolvedEndpoint = results.front(); + connector->socket.open(connector->resolvedEndpoint->protocol()); connector->socket.non_blocking(true); - return connector->socket.async_connect(results->endpoint(), UseFuture{}); + return connector->socket.async_connect(*connector->resolvedEndpoint, UseFuture{}); }) .then([this, connector, sslMode]() { connector->session = @@ -527,7 +602,7 @@ Future<SessionHandle> TransportLayerASIO::asyncConnect(HostAndPort peer, return connector->finish(); }) .onError([connector](Status status) -> Future<SessionHandle> { - return status.withContext(str::stream() << "Error connecting to " << connector->peer); + return makeConnectError(status, connector->peer, connector->resolvedEndpoint); }); } @@ -555,6 +630,7 @@ Status TransportLayerASIO::setup() { } _listenerPort = _listenerOptions.port; + WrappedResolver resolver(*_acceptorReactor); for (auto& ip : listenAddrs) { std::error_code ec; @@ -563,34 +639,33 @@ Status TransportLayerASIO::setup() { continue; } - const auto addrs = SockAddr::createAll( - ip, _listenerOptions.port, _listenerOptions.enableIPv6 ? AF_UNSPEC : AF_INET); - if (addrs.empty()) { - warning() << "Found no addresses for " << ip; + auto swAddrs = + resolver.resolve(HostAndPort(ip, _listenerPort), _listenerOptions.enableIPv6); + if (!swAddrs.isOK()) { + warning() << "Found no addresses for " << swAddrs.getStatus(); continue; } + auto& addrs = swAddrs.getValue(); - for (const auto& addr : addrs) { - asio::generic::stream_protocol::endpoint endpoint(addr.raw(), addr.addressSize); - + for (auto& addr : addrs) { #ifndef _WIN32 - if (addr.getType() == AF_UNIX) { - if (::unlink(ip.c_str()) == -1 && errno != ENOENT) { - error() << "Failed to unlink socket file " << ip << " " + if (addr.family() == AF_UNIX) { + if (::unlink(addr.toString().c_str()) == -1 && errno != ENOENT) { + error() << "Failed to unlink socket file " << addr.toString().c_str() << " " << errnoWithDescription(errno); fassertFailedNoTrace(40486); } } #endif - if (addr.getType() == AF_INET6 && !_listenerOptions.enableIPv6) { + if (addr.family() == AF_INET6 && !_listenerOptions.enableIPv6) { error() << "Specified ipv6 bind address, but ipv6 is disabled"; fassertFailedNoTrace(40488); } GenericAcceptor acceptor(*_acceptorReactor); - acceptor.open(endpoint.protocol()); + acceptor.open(addr->protocol()); acceptor.set_option(GenericAcceptor::reuse_address(true)); - if (addr.getType() == AF_INET6) { + if (addr.family() == AF_INET6) { acceptor.set_option(asio::ip::v6_only(true)); } @@ -599,22 +674,23 @@ Status TransportLayerASIO::setup() { return errorCodeToStatus(ec); } - acceptor.bind(endpoint, ec); + acceptor.bind(*addr, ec); if (ec) { return errorCodeToStatus(ec); } #ifndef _WIN32 - if (addr.getType() == AF_UNIX) { - if (::chmod(ip.c_str(), serverGlobalParams.unixSocketPermissions) == -1) { - error() << "Failed to chmod socket file " << ip << " " + if (addr.family() == AF_UNIX) { + if (::chmod(addr.toString().c_str(), serverGlobalParams.unixSocketPermissions) == + -1) { + error() << "Failed to chmod socket file " << addr.toString().c_str() << " " << errnoWithDescription(errno); fassertFailedNoTrace(40487); } } #endif if (_listenerOptions.port == 0 && - (addr.getType() == AF_INET || addr.getType() == AF_INET6)) { + (addr.family() == AF_INET || addr.family() == AF_INET6)) { if (_listenerPort != _listenerOptions.port) { return Status(ErrorCodes::BadValue, "Port 0 (ephemeral port) is not allowed when" @@ -627,7 +703,10 @@ Status TransportLayerASIO::setup() { } _listenerPort = endpointToHostAndPort(endpoint).port(); } - _acceptors.emplace_back(std::move(addr), std::move(acceptor)); + + sockaddr_storage sa; + memcpy(&sa, addr->data(), addr->size()); + _acceptors.emplace_back(SockAddr(sa, addr->size()), std::move(acceptor)); } } diff --git a/src/mongo/util/net/sockaddr.cpp b/src/mongo/util/net/sockaddr.cpp index 137b1d1639b..0cf024b6f44 100644 --- a/src/mongo/util/net/sockaddr.cpp +++ b/src/mongo/util/net/sockaddr.cpp @@ -193,7 +193,7 @@ std::vector<SockAddr> SockAddr::createAll(StringData target, int port, sa_family return std::vector<SockAddr>(ret.begin(), ret.end()); } -SockAddr::SockAddr(struct sockaddr_storage& other, socklen_t size) +SockAddr::SockAddr(const sockaddr_storage& other, socklen_t size) : addressSize(size), _hostOrIp(), sa(other), _isValid(true) { _hostOrIp = toString(true); } diff --git a/src/mongo/util/net/sockaddr.h b/src/mongo/util/net/sockaddr.h index 371e7f8c17c..840aa520fa1 100644 --- a/src/mongo/util/net/sockaddr.h +++ b/src/mongo/util/net/sockaddr.h @@ -87,7 +87,7 @@ struct SockAddr { */ explicit SockAddr(StringData target, int port, sa_family_t familyHint); - explicit SockAddr(struct sockaddr_storage& other, socklen_t size); + explicit SockAddr(const sockaddr_storage& other, socklen_t size); /** * Resolve an ip or hostname to a vector of SockAddr objects. diff --git a/src/third_party/asio-master/asio/include/asio/local/detail/impl/endpoint.ipp b/src/third_party/asio-master/asio/include/asio/local/detail/impl/endpoint.ipp index af02feada17..e7e2f2e9611 100644 --- a/src/third_party/asio-master/asio/include/asio/local/detail/impl/endpoint.ipp +++ b/src/third_party/asio-master/asio/include/asio/local/detail/impl/endpoint.ipp @@ -110,12 +110,15 @@ void endpoint::init(const char* path_name, std::size_t path_length) data_.local.sun_family = AF_UNIX; if (path_length > 0) memcpy(data_.local.sun_path, path_name, path_length); - path_length_ = path_length; - // NUL-terminate normal path names. Names that start with a NUL are in the - // UNIX domain protocol's "abstract namespace" and are not NUL-terminated. - if (path_length > 0 && data_.local.sun_path[0] == 0) - data_.local.sun_path[path_length] = 0; + // For anonymous (zero-length path) or abstract namespace sockets, the path_length_ is just + // the length of the buffer passed in. + path_length_ = path_length; + // Otherwise it's a normal UNIX path, and the size must include the null terminator. + if (path_length > 0 && data_.local.sun_path[0] != 0) + { + path_length_ += 1; + } } } // namespace detail diff --git a/src/third_party/asio-master/patches/0004-Fix-UNIX-endpoint-length.patch b/src/third_party/asio-master/patches/0004-Fix-UNIX-endpoint-length.patch new file mode 100644 index 00000000000..6c217f7180f --- /dev/null +++ b/src/third_party/asio-master/patches/0004-Fix-UNIX-endpoint-length.patch @@ -0,0 +1,25 @@ +diff --git a/src/third_party/asio-master/asio/include/asio/local/detail/impl/endpoint.ipp b/src/third_party/asio-master/asio/include/asio/local/detail/impl/endpoint.ipp +index af02feada1..e7e2f2e961 100644 +--- a/src/third_party/asio-master/asio/include/asio/local/detail/impl/endpoint.ipp ++++ b/src/third_party/asio-master/asio/include/asio/local/detail/impl/endpoint.ipp +@@ -110,12 +110,15 @@ void endpoint::init(const char* path_name, std::size_t path_length) + data_.local.sun_family = AF_UNIX; + if (path_length > 0) + memcpy(data_.local.sun_path, path_name, path_length); +- path_length_ = path_length; + +- // NUL-terminate normal path names. Names that start with a NUL are in the +- // UNIX domain protocol's "abstract namespace" and are not NUL-terminated. +- if (path_length > 0 && data_.local.sun_path[0] == 0) +- data_.local.sun_path[path_length] = 0; ++ // For anonymous (zero-length path) or abstract namespace sockets, the path_length_ is just ++ // the length of the buffer passed in. ++ path_length_ = path_length; ++ // Otherwise it's a normal UNIX path, and the size must include the null terminator. ++ if (path_length > 0 && data_.local.sun_path[0] != 0) ++ { ++ path_length_ += 1; ++ } + } + + } // namespace detail |