diff options
Diffstat (limited to 'src/mongo/util')
-rw-r--r-- | src/mongo/util/itoa.h | 4 | ||||
-rw-r--r-- | src/mongo/util/net/sockaddr.cpp | 131 | ||||
-rw-r--r-- | src/mongo/util/net/sockaddr.h | 24 | ||||
-rw-r--r-- | src/mongo/util/net/socket_utils.cpp | 13 |
4 files changed, 96 insertions, 76 deletions
diff --git a/src/mongo/util/itoa.h b/src/mongo/util/itoa.h index 878bd62d3b6..76edcb3dd47 100644 --- a/src/mongo/util/itoa.h +++ b/src/mongo/util/itoa.h @@ -49,6 +49,10 @@ public: ItoA(const ItoA&) = delete; ItoA& operator=(const ItoA&) = delete; + std::string toString() const { + return _str.toString(); + } + operator StringData() const { return _str; } diff --git a/src/mongo/util/net/sockaddr.cpp b/src/mongo/util/net/sockaddr.cpp index 7a96087b896..1a2b058cbf3 100644 --- a/src/mongo/util/net/sockaddr.cpp +++ b/src/mongo/util/net/sockaddr.cpp @@ -53,6 +53,7 @@ #endif #endif +#include "mongo/base/status.h" #include "mongo/bson/bsonobjbuilder.h" #include "mongo/bson/util/builder.h" #include "mongo/logv2/log.h" @@ -69,24 +70,36 @@ struct AddrInfoDeleter { }; using AddrInfoPtr = std::unique_ptr<addrinfo, AddrInfoDeleter>; -struct AddrErr { - AddrInfoPtr addr; - int err; -}; +AddrInfoPtr resolveAddrInfo(StringData hostOrIp, int port, sa_family_t familyHint) { + struct AddrError { + AddrInfoPtr addr; + int err; + }; + + // Convert to std::string to ensure null-termination. + const auto hostString = std::string(hostOrIp); + const auto portString = ItoA(port).toString(); + + auto tryResolve = [&](bool allowDns) noexcept { + AddrError result; -AddrErr resolveAddrInfo(const std::string& hostOrIp, int port, sa_family_t familyHint) { - const std::string portStr{StringData{ItoA(port)}}; - auto tryResolve = [&](bool allowDns) noexcept->AddrErr { addrinfo hints; memset(&hints, 0, sizeof(addrinfo)); hints.ai_socktype = SOCK_STREAM; if (!allowDns) hints.ai_flags |= AI_NUMERICHOST; hints.ai_family = familyHint; + addrinfo* addrs = nullptr; - int ret = getaddrinfo(hostOrIp.c_str(), portStr.c_str(), &hints, &addrs); - AddrInfoPtr rvPtr(addrs); - return {std::move(rvPtr), ret}; + result.err = getaddrinfo(hostString.c_str(), portString.c_str(), &hints, &addrs); + result.addr = AddrInfoPtr(addrs); + return result; + }; + + auto validateResolution = [](AddrError addrErr) -> AddrInfoPtr { + uassert(ErrorCodes::HostUnreachable, getAddrInfoStrError(addrErr.err), addrErr.err == 0); + + return std::move(addrErr.addr); }; switch (auto r = tryResolve(false); r.err) { @@ -96,9 +109,9 @@ AddrErr resolveAddrInfo(const std::string& hostOrIp, int port, sa_family_t famil case EAI_NODATA: // Old IPv6-capable hosts can return EAI_NODATA. #endif #endif - return tryResolve(true); // Not an IP address. Retry with DNS. + return validateResolution(tryResolve(true)); // Not an IP address. Retry with DNS. default: - return r; + return validateResolution(std::move(r)); } } @@ -129,90 +142,88 @@ SockAddr::SockAddr(int sourcePort) { _isValid = true; } -void SockAddr::initUnixDomainSocket(const std::string& path, int port) { +void SockAddr::initUnixDomainSocket(StringData path, int port) { #ifdef _WIN32 uassert(13080, "no unix socket support on windows", false); #endif uassert( 13079, "path to unix socket too long", path.size() < sizeof(as<sockaddr_un>().sun_path)); as<sockaddr_un>().sun_family = AF_UNIX; - strcpy(as<sockaddr_un>().sun_path, path.c_str()); + path.copyTo(as<sockaddr_un>().sun_path, /* includeEndingNull =*/true); addressSize = sizeof(sockaddr_un); _isValid = true; } -SockAddr::SockAddr(StringData target, int port, sa_family_t familyHint) - : _hostOrIp(target.toString()) { - if (_hostOrIp == "localhost") { - _hostOrIp = "127.0.0.1"; +SockAddr SockAddr::create(StringData target, int port, sa_family_t familyHint) { + if (target == "localhost") { + target = "127.0.0.1"_sd; } - if (str::contains(_hostOrIp, '/') || familyHint == AF_UNIX) { - initUnixDomainSocket(_hostOrIp, port); - return; + if (str::contains(target, '/') || familyHint == AF_UNIX) { + SockAddr ret; + ret.initUnixDomainSocket(target, port); + return ret; } - auto addrErr = resolveAddrInfo(_hostOrIp, port, familyHint); + try { + const auto ownedAddrs = resolveAddrInfo(target, port, familyHint); - if (addrErr.err) { + // This throws away all but the first address. + // Use SockAddr::createAll() to get all addresses. + const auto* addrs = ownedAddrs.get(); + fassert(16501, static_cast<size_t>(addrs->ai_addrlen) <= sizeof(struct sockaddr_storage)); + return SockAddr(addrs->ai_addr, addrs->ai_addrlen, target); + } catch (const DBException&) { // we were unsuccessful - if (_hostOrIp != "0.0.0.0") { // don't log if this as it is a - // CRT construction and log() may not work yet. - LOGV2(23175, - "getaddrinfo(\"{host}\") failed: {error}", - "Command getaddrinfo failed", - "host"_attr = _hostOrIp, - "error"_attr = getAddrInfoStrError(addrErr.err)); - _isValid = false; - return; + + if (target == "0.0.0.0") { + // don't log if this as it is a CRT construction and log() may not work yet. + return SockAddr(port); } - *this = SockAddr(port); - return; - } - // This throws away all but the first address. - // Use SockAddr::createAll() to get all addresses. - const auto* addrs = addrErr.addr.get(); - fassert(16501, static_cast<size_t>(addrs->ai_addrlen) <= sizeof(sa)); - memcpy(&sa, addrs->ai_addr, addrs->ai_addrlen); - addressSize = addrs->ai_addrlen; - _isValid = true; + throw; + } } std::vector<SockAddr> SockAddr::createAll(StringData target, int port, sa_family_t familyHint) { - std::string hostOrIp = target.toString(); - if (str::contains(hostOrIp, '/')) { + if (str::contains(target, '/')) { std::vector<SockAddr> ret = {SockAddr()}; - ret[0].initUnixDomainSocket(hostOrIp, port); + ret[0].initUnixDomainSocket(target, port); // Currently, this is always valid since initUnixDomainSocket() // will uassert() on failure. Be defensive against future changes. return ret[0].isValid() ? ret : std::vector<SockAddr>(); } - auto addrErr = resolveAddrInfo(hostOrIp, port, familyHint); - if (addrErr.err) { + try { + const auto ownedAddrs = resolveAddrInfo(target, port, familyHint); + + std::set<SockAddr> ret; + for (const auto* addrs = ownedAddrs.get(); addrs; addrs = addrs->ai_next) { + fassert(40594, + static_cast<size_t>(addrs->ai_addrlen) <= sizeof(struct sockaddr_storage)); + ret.emplace(addrs->ai_addr, addrs->ai_addrlen, target); + } + return std::vector<SockAddr>(ret.begin(), ret.end()); + } catch (const DBException& ex) { LOGV2(23176, "getaddrinfo(\"{host}\") failed: {error}", "getaddrinfo invocation failed", - "host"_attr = hostOrIp, - "error"_attr = getAddrInfoStrError(addrErr.err)); + "host"_attr = target, + "error"_attr = ex.toStatus()); return {}; } - - std::set<SockAddr> ret; - struct sockaddr_storage storage; - memset(&storage, 0, sizeof(storage)); - for (const auto* addrs = addrErr.addr.get(); addrs; addrs = addrs->ai_next) { - fassert(40594, static_cast<size_t>(addrs->ai_addrlen) <= sizeof(struct sockaddr_storage)); - ret.emplace(addrs->ai_addr, addrs->ai_addrlen); - } - return std::vector<SockAddr>(ret.begin(), ret.end()); } -SockAddr::SockAddr(const sockaddr* other, socklen_t size) - : addressSize(size), _hostOrIp(), sa(), _isValid(true) { +SockAddr::SockAddr(const sockaddr* other, socklen_t size) : addressSize(size), _hostOrIp(), sa() { memcpy(&sa, other, size); _hostOrIp = toString(true); + _isValid = true; +} + +SockAddr::SockAddr(const sockaddr* other, socklen_t size, StringData hostOrIp) + : addressSize(size), _hostOrIp(hostOrIp.toString()), sa() { + memcpy(&sa, other, size); + _isValid = true; } bool SockAddr::isIP() const { diff --git a/src/mongo/util/net/sockaddr.h b/src/mongo/util/net/sockaddr.h index 1ee56d93514..66b985274c2 100644 --- a/src/mongo/util/net/sockaddr.h +++ b/src/mongo/util/net/sockaddr.h @@ -73,23 +73,23 @@ struct SockAddr { SockAddr(); explicit SockAddr(int sourcePort); /* listener side */ + explicit SockAddr(const sockaddr* other, socklen_t size); + explicit SockAddr(const sockaddr* other, socklen_t size, StringData hostOrIp); + /** * Initialize a SockAddr for a given IP or Hostname. * - * If target fails to resolve/parse, SockAddr.isValid() may return false, - * or the resulting SockAddr may be equivalent to SockAddr(port). + * If target fails to resolve/parse, this function may throw or the resulting SockAddr may be + * equivalent to SockAddr(port). * - * If target is a unix domain socket, a uassert() exception will be thrown - * on windows or if addr exceeds maximum path length. + * If target is a unix domain socket, a uassert() exception will be thrown on windows or if addr + * exceeds maximum path length. * - * If target resolves to more than one address, only the first address - * will be used. Others will be discarded. - * SockAddr::createAll() is recommended for capturing all addresses. + * If target resolves to more than one address, only the first address will be used. Others will + * be discarded. SockAddr::createAll() is recommended for capturing all addresses. */ - explicit SockAddr(StringData target, int port, sa_family_t familyHint); - - explicit SockAddr(const sockaddr* other, socklen_t size); + static SockAddr create(StringData target, int port, sa_family_t familyHint); /** * Resolve an ip or hostname to a vector of SockAddr objects. @@ -153,11 +153,11 @@ struct SockAddr { void serializeToBSON(StringData fieldName, BSONObjBuilder* builder) const; private: - void initUnixDomainSocket(const std::string& path, int port); + void initUnixDomainSocket(StringData path, int port); std::string _hostOrIp; struct sockaddr_storage sa; - bool _isValid; + bool _isValid = false; }; } // namespace mongo diff --git a/src/mongo/util/net/socket_utils.cpp b/src/mongo/util/net/socket_utils.cpp index 5e1909969a1..685928a7667 100644 --- a/src/mongo/util/net/socket_utils.cpp +++ b/src/mongo/util/net/socket_utils.cpp @@ -199,11 +199,16 @@ std::string makeUnixSockPath(int port) { // If an ip address is passed in, just return that. If a hostname is passed // in, look up its ip and return that. Returns "" on failure. std::string hostbyname(const char* hostname) { - SockAddr sockAddr(hostname, 0, IPv6Enabled() ? AF_UNSPEC : AF_INET); - if (!sockAddr.isValid() || sockAddr.getAddr() == "0.0.0.0") + try { + auto addr = SockAddr::create(hostname, 0, IPv6Enabled() ? AF_UNSPEC : AF_INET).getAddr(); + if (addr == "0.0.0.0") { + return ""; + } + + return addr; + } catch (const DBException&) { return ""; - else - return sockAddr.getAddr(); + } } // --- my -- |