diff options
31 files changed, 1428 insertions, 163 deletions
diff --git a/SConstruct b/SConstruct index 8228a103036..9489f53bd31 100644 --- a/SConstruct +++ b/SConstruct @@ -1413,7 +1413,7 @@ else: env.AppendUnique( CPPDEFINES=[ 'NDEBUG' ] ) if env.TargetOSIs('linux'): - env.Append( LIBS=['m'] ) + env.Append( LIBS=['m',"resolv"] ) elif env.TargetOSIs('solaris'): env.Append( LIBS=["socket","resolv","lgrp"] ) @@ -1422,6 +1422,9 @@ elif env.TargetOSIs('freebsd'): env.Append( LIBS=[ "kvm" ] ) env.Append( CCFLAGS=[ "-fno-omit-frame-pointer" ] ) +elif env.TargetOSIs('darwin'): + env.Append( LIBS=["resolv"] ) + elif env.TargetOSIs('openbsd'): env.Append( LIBS=[ "kvm" ] ) @@ -1582,6 +1585,7 @@ elif env.TargetOSIs('windows'): 'advapi32.lib', 'bcrypt.lib', 'crypt32.lib', + 'dnsapi.lib', 'kernel32.lib', 'shell32.lib', 'pdh.lib', diff --git a/jstests/serial_run/srv-uri.js b/jstests/serial_run/srv-uri.js new file mode 100644 index 00000000000..b550c540d6c --- /dev/null +++ b/jstests/serial_run/srv-uri.js @@ -0,0 +1,8 @@ +(function() { + "use strict"; + const md = MongoRunner.runMongod({port: "27017", dbpath: MongoRunner.dataPath}); + assert.neq(null, md, "unable to start mongod"); + const exitCode = runMongoProgram('mongo', 'mongodb+srv://test1.test.build.10gen.cc.', '--eval', ';'); + assert.eq(exitCode, 0, "Failed to connect with a `mongodb+srv://` style URI."); + MongoRunner.stopMongod(md); +}()); diff --git a/src/mongo/base/error_codes.err b/src/mongo/base/error_codes.err index 7cade582bd4..01133e3c7cc 100644 --- a/src/mongo/base/error_codes.err +++ b/src/mongo/base/error_codes.err @@ -227,6 +227,8 @@ error_code("AtomicityFailure", 226) error_code("CannotImplicitlyCreateCollection", 227); error_code("SessionTransferIncomplete", 228) error_code("MustDowngrade", 229) +error_code("DNSHostNotFound", 230) +error_code("DNSProtocolError", 231) # Error codes 4000-8999 are reserved. diff --git a/src/mongo/client/SConscript b/src/mongo/client/SConscript index 16dd55dafa0..1956ae58109 100644 --- a/src/mongo/client/SConscript +++ b/src/mongo/client/SConscript @@ -24,6 +24,9 @@ env.Library( ], LIBDEPS=[ '$BUILD_DIR/mongo/util/net/network', + ], + LIBDEPS_PRIVATE=[ + '$BUILD_DIR/mongo/util/dns_query', ] ) diff --git a/src/mongo/client/connection_string.cpp b/src/mongo/client/connection_string.cpp index 16a6e78440d..5e8a5f4e0bb 100644 --- a/src/mongo/client/connection_string.cpp +++ b/src/mongo/client/connection_string.cpp @@ -121,9 +121,6 @@ void ConnectionString::_finishInit() { uassert(ErrorCodes::FailedToParse, "Cannot specify a replica set name for a ConnectionString of type MASTER", _setName.empty()); - uassert(ErrorCodes::FailedToParse, - "ConnectionStrings of type MASTER must contain exactly one server", - _servers.size() == 1); break; case SET: uassert(ErrorCodes::FailedToParse, diff --git a/src/mongo/client/connection_string.h b/src/mongo/client/connection_string.h index 2e888b56dad..0467a77dce3 100644 --- a/src/mongo/client/connection_string.h +++ b/src/mongo/client/connection_string.h @@ -28,6 +28,7 @@ #pragma once +#include <memory> #include <sstream> #include <string> #include <vector> @@ -119,10 +120,10 @@ public: bool operator==(const ConnectionString& other) const; bool operator!=(const ConnectionString& other) const; - DBClientBase* connect(StringData applicationName, - std::string& errmsg, - double socketTimeout = 0, - const MongoURI* uri = nullptr) const; + std::unique_ptr<DBClientBase> connect(StringData applicationName, + std::string& errmsg, + double socketTimeout = 0, + const MongoURI* uri = nullptr) const; static StatusWith<ConnectionString> parse(const std::string& url); @@ -139,9 +140,9 @@ public: virtual ~ConnectionHook() {} // Returns an alternative connection object for a string - virtual DBClientBase* connect(const ConnectionString& c, - std::string& errmsg, - double socketTimeout) = 0; + virtual std::unique_ptr<DBClientBase> connect(const ConnectionString& c, + std::string& errmsg, + double socketTimeout) = 0; }; static void setConnectionHook(ConnectionHook* hook) { diff --git a/src/mongo/client/connection_string_connect.cpp b/src/mongo/client/connection_string_connect.cpp index e4c99dd37e1..d824fa6be78 100644 --- a/src/mongo/client/connection_string_connect.cpp +++ b/src/mongo/client/connection_string_connect.cpp @@ -46,10 +46,10 @@ namespace mongo { stdx::mutex ConnectionString::_connectHookMutex; ConnectionString::ConnectionHook* ConnectionString::_connectHook = NULL; -DBClientBase* ConnectionString::connect(StringData applicationName, - std::string& errmsg, - double socketTimeout, - const MongoURI* uri) const { +std::unique_ptr<DBClientBase> ConnectionString::connect(StringData applicationName, + std::string& errmsg, + double socketTimeout, + const MongoURI* uri) const { MongoURI newURI{}; if (uri) { newURI = *uri; @@ -57,15 +57,18 @@ DBClientBase* ConnectionString::connect(StringData applicationName, switch (_type) { case MASTER: { - auto c = stdx::make_unique<DBClientConnection>(true, 0, std::move(newURI)); - - c->setSoTimeout(socketTimeout); - LOG(1) << "creating new connection to:" << _servers[0]; - if (!c->connect(_servers[0], applicationName, errmsg)) { - return 0; + for (const auto& server : _servers) { + auto c = stdx::make_unique<DBClientConnection>(true, 0, newURI); + + c->setSoTimeout(socketTimeout); + LOG(1) << "creating new connection to:" << server; + if (!c->connect(server, applicationName, errmsg)) { + continue; + } + LOG(1) << "connected connection!"; + return std::move(c); } - LOG(1) << "connected connection!"; - return c.release(); + return nullptr; } case SET: { @@ -74,9 +77,9 @@ DBClientBase* ConnectionString::connect(StringData applicationName, if (!set->connect()) { errmsg = "connect failed to replica set "; errmsg += toString(); - return 0; + return nullptr; } - return set.release(); + return std::move(set); } case CUSTOM: { @@ -91,7 +94,7 @@ DBClientBase* ConnectionString::connect(StringData applicationName, _connectHook); // Double-checked lock, since this will never be active during normal operation - DBClientBase* replacementConn = _connectHook->connect(*this, errmsg, socketTimeout); + auto replacementConn = _connectHook->connect(*this, errmsg, socketTimeout); log() << "replacing connection to " << this->toString() << " with " << (replacementConn ? replacementConn->getServerAddress() : "(empty)"); diff --git a/src/mongo/client/connpool.cpp b/src/mongo/client/connpool.cpp index 30325079fab..6c5a3d9cdfa 100644 --- a/src/mongo/client/connpool.cpp +++ b/src/mongo/client/connpool.cpp @@ -256,7 +256,7 @@ DBClientBase* DBConnectionPool::get(const ConnectionString& url, double socketTi // If no connections for this host are available in the PoolForHost (that is, all the // connections have been checked out, or none have been created yet), create a new connection. string errmsg; - c = url.connect(StringData(), errmsg, socketTimeout); + c = url.connect(StringData(), errmsg, socketTimeout).release(); uassert(13328, _name + ": connect failed " + url.toString() + " : " + errmsg, c); return _finishCreate(url.toString(), socketTimeout, c); @@ -277,7 +277,7 @@ DBClientBase* DBConnectionPool::get(const string& host, double socketTimeout) { const ConnectionString cs(uassertStatusOK(ConnectionString::parse(host))); string errmsg; - c = cs.connect(StringData(), errmsg, socketTimeout); + c = cs.connect(StringData(), errmsg, socketTimeout).release(); if (!c) throw SocketException(SocketException::CONNECT_ERROR, host, diff --git a/src/mongo/client/mongo_uri.cpp b/src/mongo/client/mongo_uri.cpp index e2356881290..e7c500d719b 100644 --- a/src/mongo/client/mongo_uri.cpp +++ b/src/mongo/client/mongo_uri.cpp @@ -34,24 +34,26 @@ #include <utility> +#include <boost/algorithm/string/case_conv.hpp> +#include <boost/algorithm/string/classification.hpp> +#include <boost/algorithm/string/find_iterator.hpp> +#include <boost/algorithm/string/predicate.hpp> + #include "mongo/base/status_with.h" #include "mongo/bson/bsonobjbuilder.h" #include "mongo/client/dbclientinterface.h" #include "mongo/client/sasl_client_authenticate.h" #include "mongo/db/namespace_string.h" +#include "mongo/util/dns_query.h" #include "mongo/util/hex.h" #include "mongo/util/mongoutils/str.h" -#include <boost/algorithm/string/case_conv.hpp> -#include <boost/algorithm/string/classification.hpp> -#include <boost/algorithm/string/find_iterator.hpp> -#include <boost/algorithm/string/predicate.hpp> - namespace { constexpr std::array<char, 16> hexits{ '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F'}; const mongo::StringData kURIPrefix{"mongodb://"}; -} +const mongo::StringData kURISRVPrefix{"mongodb+srv://"}; +} // namespace /** * RFC 3986 Section 2.1 - Percent Encoding @@ -94,7 +96,8 @@ namespace mongo { namespace { /** - * Helper Method for MongoURI::parse() to split a string into exactly 2 pieces by a char delimeter + * Helper Method for MongoURI::parse() to split a string into exactly 2 pieces by a char + * delimeter. */ std::pair<StringData, StringData> partitionForward(StringData str, const char c) { const auto delim = str.find(c); @@ -105,8 +108,8 @@ std::pair<StringData, StringData> partitionForward(StringData str, const char c) } /** - * Helper method for MongoURI::parse() to split a string into exactly 2 pieces by a char delimiter - * searching backward from the end of the string. + * Helper method for MongoURI::parse() to split a string into exactly 2 pieces by a char + * delimiter searching backward from the end of the string. */ std::pair<StringData, StringData> partitionBackward(StringData str, const char c) { const auto delim = str.rfind(c); @@ -121,17 +124,17 @@ std::pair<StringData, StringData> partitionBackward(StringData str, const char c * * foo=bar&baz=qux&... */ -StatusWith<MongoURI::OptionsMap> parseOptions(StringData options, StringData url) { +MongoURI::OptionsMap parseOptions(StringData options, StringData url) { MongoURI::OptionsMap ret; if (options.empty()) { return ret; } if (options.find('?') != std::string::npos) { - return Status(ErrorCodes::FailedToParse, - str::stream() - << "URI Cannot Contain multiple questions marks for mongodb:// URL: " - << url); + throw DBException(ErrorCodes::FailedToParse, + str::stream() + << "URI Cannot Contain multiple questions marks for mongodb:// URL: " + << url); } const auto optionsStr = options.toString(); @@ -141,16 +144,16 @@ StatusWith<MongoURI::OptionsMap> parseOptions(StringData options, StringData url ++i) { const auto opt = boost::copy_range<std::string>(*i); if (opt.empty()) { - return Status(ErrorCodes::FailedToParse, - str::stream() - << "Missing a key/value pair in the options for mongodb:// URL: " - << url); + throw DBException(ErrorCodes::FailedToParse, + str::stream() + << "Missing a key/value pair in the options for mongodb:// URL: " + << url); } const auto kvPair = partitionForward(opt, '='); const auto keyRaw = kvPair.first; if (keyRaw.empty()) { - return Status( + throw DBException( ErrorCodes::FailedToParse, str::stream() << "Missing a key for key/value pair in the options for mongodb:// URL: " @@ -158,7 +161,7 @@ StatusWith<MongoURI::OptionsMap> parseOptions(StringData options, StringData url } const auto key = uriDecode(keyRaw); if (!key.isOK()) { - return Status( + throw DBException( ErrorCodes::FailedToParse, str::stream() << "Key '" << keyRaw << "' in options cannot properly be URL decoded for mongodb:// URL: " @@ -166,14 +169,14 @@ StatusWith<MongoURI::OptionsMap> parseOptions(StringData options, StringData url } const auto valRaw = kvPair.second; if (valRaw.empty()) { - return Status(ErrorCodes::FailedToParse, - str::stream() << "Missing value for key '" << keyRaw - << "' in the options for mongodb:// URL: " - << url); + throw DBException(ErrorCodes::FailedToParse, + str::stream() << "Missing value for key '" << keyRaw + << "' in the options for mongodb:// URL: " + << url); } const auto val = uriDecode(valRaw); if (!val.isOK()) { - return Status( + throw DBException( ErrorCodes::FailedToParse, str::stream() << "Value '" << valRaw << "' for key '" << keyRaw << "' in options cannot properly be URL decoded for mongodb:// URL: " @@ -186,20 +189,38 @@ StatusWith<MongoURI::OptionsMap> parseOptions(StringData options, StringData url return ret; } +MongoURI::OptionsMap addTXTOptions(MongoURI::OptionsMap options, + const std::string& host, + const StringData url, + const bool isSeedlist) { + // If there is no seedlist mode, then don't add any TXT options. + if (!isSeedlist) + return options; + + // Get all TXT records and parse them as options, adding them to the options set. + const auto txtRecords = dns::getTXTRecords(host); + + for (const auto& record : txtRecords) { + auto txtOptions = parseOptions(record, url); + // Note that, `std::map` and `std::unordered_map` insert does not replace existing + // values -- this gives the desired behavior that user-specified values override TXT + // record specified values. + options.insert(begin(txtOptions), end(txtOptions)); + } + + return options; +} } // namespace -StatusWith<MongoURI> MongoURI::parse(const std::string& url) { +MongoURI MongoURI::parseImpl(const std::string& url) { const StringData urlSD(url); // 1. Validate and remove the scheme prefix mongodb:// - if (!urlSD.startsWith(kURIPrefix)) { - const auto cs_status = ConnectionString::parse(url); - if (!cs_status.isOK()) { - return cs_status.getStatus(); - } - return MongoURI(cs_status.getValue()); + const bool isSeedlist = urlSD.startsWith(kURISRVPrefix); + if (!(urlSD.startsWith(kURIPrefix) || isSeedlist)) { + return MongoURI(uassertStatusOK(ConnectionString::parse(url))); } - const auto uriWithoutPrefix = urlSD.substr(kURIPrefix.size()); + const auto uriWithoutPrefix = urlSD.substr(urlSD.find("://") + 3); // 2. Split the string by the first, unescaped / (if any), yielding: // split[0]: User information and host identifers @@ -211,14 +232,15 @@ StatusWith<MongoURI> MongoURI::parse(const std::string& url) { // 2.b Make sure that there are no question marks in the left side of the / // as any options after the ? must still have the / delimeter if (databaseAndOptions.empty() && userAndHostInfo.find('?') != std::string::npos) { - return Status( + throw DBException( ErrorCodes::FailedToParse, str::stream() << "URI must contain slash delimeter between hosts and options for mongodb:// URL: " << url); } - // 3. Split the user information and host identifiers string by the last, unescaped @, yielding: + // 3. Split the user information and host identifiers string by the last, unescaped @, + // yielding: // split[0]: User information // split[1]: Host identifiers; const auto userAndHost = partitionBackward(userAndHostInfo, '@'); @@ -237,18 +259,21 @@ StatusWith<MongoURI> MongoURI::parse(const std::string& url) { }; if (containsColonOrAt(usernameSD)) { - return Status(ErrorCodes::FailedToParse, - str::stream() << "Username must be URL Encoded for mongodb:// URL: " << url); + throw DBException(ErrorCodes::FailedToParse, + str::stream() << "Username must be URL Encoded for mongodb:// URL: " + << url); } + if (containsColonOrAt(passwordSD)) { - return Status(ErrorCodes::FailedToParse, - str::stream() << "Password must be URL Encoded for mongodb:// URL: " << url); + throw DBException(ErrorCodes::FailedToParse, + str::stream() << "Password must be URL Encoded for mongodb:// URL: " + << url); } // Get the username and make sure it did not fail to decode const auto usernameWithStatus = uriDecode(usernameSD); if (!usernameWithStatus.isOK()) { - return Status( + throw DBException( ErrorCodes::FailedToParse, str::stream() << "Username cannot properly be URL decoded for mongodb:// URL: " << url); } @@ -257,7 +282,7 @@ StatusWith<MongoURI> MongoURI::parse(const std::string& url) { // Get the password and make sure it did not fail to decode const auto passwordWithStatus = uriDecode(passwordSD); if (!passwordWithStatus.isOK()) - return Status( + throw DBException( ErrorCodes::FailedToParse, str::stream() << "Password cannot properly be URL decoded for mongodb:// URL: " << url); const auto password = passwordWithStatus.getValue(); @@ -271,7 +296,7 @@ StatusWith<MongoURI> MongoURI::parse(const std::string& url) { ++i) { const auto hostWithStatus = uriDecode(boost::copy_range<std::string>(*i)); if (!hostWithStatus.isOK()) { - return Status( + throw DBException( ErrorCodes::FailedToParse, str::stream() << "Host cannot properly be URL decoded for mongodb:// URL: " << url); } @@ -282,23 +307,35 @@ StatusWith<MongoURI> MongoURI::parse(const std::string& url) { } if ((host.find('/') != std::string::npos) && !StringData(host).endsWith(".sock")) { - return Status( + throw DBException( ErrorCodes::FailedToParse, str::stream() << "'" << host << "' in '" << url << "' appears to be a unix socket, but does not end in '.sock'"); } - const auto statusHostAndPort = HostAndPort::parse(host); - if (!statusHostAndPort.isOK()) { - return statusHostAndPort.getStatus(); - } - servers.push_back(statusHostAndPort.getValue()); + servers.push_back(uassertStatusOK(HostAndPort::parse(host))); } if (servers.empty()) { - return Status(ErrorCodes::FailedToParse, "No server(s) specified"); + throw DBException(ErrorCodes::FailedToParse, "No server(s) specified"); } - // 6. Split the auth database and connection options string by the first, unescaped ?, yielding: + const std::string canonicalHost = servers.front().host(); + // If we're in seedlist mode, lookup the SRV record for `_mongodb._tcp` on the specified + // domain name. Take that list of servers as the new list of servers. + if (isSeedlist) { + if (servers.size() > 1) { + throw DBException(ErrorCodes::FailedToParse, + "Only a single server may be specified with a mongo+srv:// url."); + } + auto srvEntries = dns::lookupSRVRecords("_mongodb._tcp." + canonicalHost); + servers.clear(); + std::transform(begin(srvEntries), end(srvEntries), back_inserter(servers), [](auto& srv) { + return HostAndPort(std::move(srv.host), srv.port); + }); + } + + // 6. Split the auth database and connection options string by the first, unescaped ?, + // yielding: // split[0] = auth database // split[1] = connection options const auto dbAndOpts = partitionForward(databaseAndOptions, '?'); @@ -306,10 +343,10 @@ StatusWith<MongoURI> MongoURI::parse(const std::string& url) { const auto connectionOptions = dbAndOpts.second; const auto databaseWithStatus = uriDecode(databaseSD); if (!databaseWithStatus.isOK()) { - return Status(ErrorCodes::FailedToParse, - str::stream() - << "Database name cannot properly be URL decoded for mongodb:// URL: " - << url); + throw DBException(ErrorCodes::FailedToParse, + str::stream() << "Database name cannot properly be URL " + "decoded for mongodb:// URL: " + << url); } const auto database = databaseWithStatus.getValue(); @@ -320,35 +357,32 @@ StatusWith<MongoURI> MongoURI::parse(const std::string& url) { if (!database.empty() && !NamespaceString::validDBName(database, NamespaceString::DollarInDbNameBehavior::Disallow)) { - return Status(ErrorCodes::FailedToParse, - str::stream() - << "Database name cannot have reserved characters for mongodb:// URL: " - << url); + throw DBException(ErrorCodes::FailedToParse, + str::stream() << "Database name cannot have reserved " + "characters for mongodb:// URL: " + << url); } // 8. Validate, split, and URL decode the connection options - const auto optsWith = parseOptions(connectionOptions, url); - if (!optsWith.isOK()) { - return optsWith.getStatus(); - } - const auto options = optsWith.getValue(); + auto options = + addTXTOptions(parseOptions(connectionOptions, url), canonicalHost, url, isSeedlist); // If a replica set option was specified, store it in the 'setName' field. const auto optIter = options.find("replicaSet"); std::string setName; - if (optIter != options.end()) { + if (optIter != end(options)) { setName = optIter->second; invariant(!setName.empty()); } - if ((servers.size() > 1) && setName.empty()) { - return Status(ErrorCodes::FailedToParse, - "Cannot list multiple servers in URL without 'replicaSet' option"); - } - ConnectionString cs( setName.empty() ? ConnectionString::MASTER : ConnectionString::SET, servers, setName); return MongoURI(std::move(cs), username, password, database, std::move(options)); } +StatusWith<MongoURI> MongoURI::parse(const std::string& url) try { + return parseImpl(url); +} catch (const std::exception&) { + return exceptionToStatus(); +} } // namespace mongo diff --git a/src/mongo/client/mongo_uri.h b/src/mongo/client/mongo_uri.h index f7a7c4aeb81..11d948a526c 100644 --- a/src/mongo/client/mongo_uri.h +++ b/src/mongo/client/mongo_uri.h @@ -1,7 +1,7 @@ /** * Copyright (C) 2015 MongoDB Inc. * - * This program is free software: you can redistribute it and/or modify + * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License, version 3, * as published by the Free Software Foundation. * @@ -43,7 +43,6 @@ #include "mongo/util/net/hostandport.h" namespace mongo { - /** * Encode a string for embedding in a URI. * Replaces reserved bytes with %xx sequences. @@ -51,6 +50,7 @@ namespace mongo { * Optionally allows passthrough characters to remain unescaped. */ void uriEncode(std::ostream& ss, StringData str, StringData passthrough = ""_sd); + inline std::string uriEncode(StringData str, StringData passthrough = ""_sd) { std::ostringstream ss; uriEncode(ss, str, passthrough); @@ -66,9 +66,13 @@ StatusWith<std::string> uriDecode(StringData str); /** * MongoURI handles parsing of URIs for mongodb, and falls back to old-style * ConnectionString parsing. It's used primarily by the shell. - * It parses URIs with the following format: + * It parses URIs with the following formats: * * mongodb://[usr:pwd@]host1[:port1]...[,hostN[:portN]]][/[db][?options]] + * mongodb+srv://[usr:pwd@]host[/[db][?options]] + * + * `mongodb+srv://` URIs will perform DNS SRV and TXT lookups and expand per the DNS Seedlist + * specification. * * While this format is generally RFC 3986 compliant, some exceptions do exist: * 1. The 'host' field, as defined by section 3.2.2 is expanded in the following ways: @@ -96,6 +100,11 @@ StatusWith<std::string> uriDecode(StringData str); */ class MongoURI { public: + // Note that, because this map is used for DNS TXT record injection on options, there is a + // requirement on its behavior for `insert`: insert must not replace or update existing values + // -- this gives the desired behavior that user-specified values override TXT record specified + // values. `std::map` and `std::unordered_map` satisfy this requirement. Make sure that + // whichever map type is used provides that guarantee. using OptionsMap = std::map<std::string, std::string>; static StatusWith<MongoURI> parse(const std::string& url); @@ -140,20 +149,21 @@ public: // server (say a member of a replica-set), you can pass in its HostAndPort information to // get a new URI with the same info, except type() will be MASTER and getServers() will // be the single host you pass in. - MongoURI cloneURIForServer(const HostAndPort& hostAndPort) const { - return MongoURI(ConnectionString(hostAndPort), _user, _password, _database, _options); + MongoURI cloneURIForServer(HostAndPort hostAndPort) const { + return MongoURI( + ConnectionString(std::move(hostAndPort)), _user, _password, _database, _options); } ConnectionString::ConnectionType type() const { return _connectString.type(); } - explicit MongoURI(const ConnectionString connectString) - : _connectString(std::move(connectString)){}; + explicit MongoURI(const ConnectionString& connectString) : _connectString(connectString){}; MongoURI() = default; friend std::ostream& operator<<(std::ostream&, const MongoURI&); + friend StringBuilder& operator<<(StringBuilder&, const MongoURI&); private: @@ -166,10 +176,12 @@ private: _user(user), _password(password), _database(database), - _options(std::move(options)){}; + _options(std::move(options)) {} BSONObj _makeAuthObjFromOptions(int maxWireVersion) const; + static MongoURI parseImpl(const std::string& url); + ConnectionString _connectString; std::string _user; std::string _password; @@ -178,13 +190,10 @@ private: }; inline std::ostream& operator<<(std::ostream& ss, const MongoURI& uri) { - ss << uri._connectString; - return ss; + return ss << uri._connectString; } inline StringBuilder& operator<<(StringBuilder& sb, const MongoURI& uri) { - sb << uri._connectString; - return sb; + return sb << uri._connectString; } - } // namespace mongo diff --git a/src/mongo/client/mongo_uri_test.cpp b/src/mongo/client/mongo_uri_test.cpp index 00edcb110f4..5153452f9ec 100644 --- a/src/mongo/client/mongo_uri_test.cpp +++ b/src/mongo/client/mongo_uri_test.cpp @@ -1,7 +1,7 @@ /** * Copyright (C) 2009-2015 MongoDB Inc. * - * This program is free software: you can redistribute it and/or modify + * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License, version 3, * as published by the Free Software Foundation. * @@ -506,8 +506,8 @@ TEST(MongoURI, CloneURIForServer) { /** * These tests come from the Mongo Uri Specifications for the drivers found at: * https://github.com/mongodb/specifications/tree/master/source/connection-string/tests - * They have been slighly altered as the Drivers specification is slighly different - * from the server specification. + * They have been altered as the Drivers specification is somewhat different from the shell + * implementation. */ TEST(MongoURI, specTests) { const std::string files[] = { @@ -599,4 +599,156 @@ TEST(MongoURI, specTests) { } } +TEST(MongoURI, srvRecordTest) { + using namespace mongo; + const struct { + std::string uri; + std::string user; + std::string password; + std::string database; + std::vector<HostAndPort> hosts; + std::map<std::string, std::string> options; + } tests[] = { + // Test some non-SRV URIs to make sure that they do not perform expansions + {"mongodb://test1.test.build.10gen.cc:12345/", + "", + "", + "", + {{"test1.test.build.10gen.cc", 12345}}, + {}}, + {"mongodb://test6.test.build.10gen.cc:12345/", + "", + "", + "", + {{"test6.test.build.10gen.cc", 12345}}, + {}}, + + // Test a sample URI against each provided testing DNS entry + {"mongodb+srv://test1.test.build.10gen.cc/", + "", + "", + "", + {{"localhost.build.10gen.cc.", 27017}, {"localhost.build.10gen.cc.", 27018}}, + {}}, + + {"mongodb+srv://user:password@test2.test.build.10gen.cc/" + "database?someOption=someValue&someOtherOption=someOtherValue", + "user", + "password", + "database", + {{"localhost.build.10gen.cc.", 27018}, {"localhost.build.10gen.cc.", 27019}}, + {{"someOption", "someValue"}, {"someOtherOption", "someOtherValue"}}}, + + + {"mongodb+srv://user:password@test3.test.build.10gen.cc/" + "database?someOption=someValue&someOtherOption=someOtherValue", + "user", + "password", + "database", + {{"localhost.build.10gen.cc.", 27017}}, + {{"someOption", "someValue"}, {"someOtherOption", "someOtherValue"}}}, + + + {"mongodb+srv://user:password@test5.test.build.10gen.cc/" + "database?someOption=someValue&someOtherOption=someOtherValue", + "user", + "password", + "database", + {{"localhost.build.10gen.cc.", 27017}}, + {{"someOption", "someValue"}, + {"someOtherOption", "someOtherValue"}, + {"connectTimeoutMS", "300000"}, + {"socketTimeoutMS", "300000"}}}, + + {"mongodb+srv://user:password@test5.test.build.10gen.cc/" + "database?someOption=someValue&socketTimeoutMS=100&someOtherOption=someOtherValue", + "user", + "password", + "database", + {{"localhost.build.10gen.cc.", 27017}}, + {{"someOption", "someValue"}, + {"someOtherOption", "someOtherValue"}, + {"connectTimeoutMS", "300000"}, + {"socketTimeoutMS", "100"}}}, + + {"mongodb+srv://test6.test.build.10gen.cc/", + "", + "", + "", + {{"localhost.build.10gen.cc.", 27017}}, + {{"connectTimeoutMS", "200000"}, {"socketTimeoutMS", "200000"}}}, + + {"mongodb+srv://test6.test.build.10gen.cc/database", + "", + "", + "database", + {{"localhost.build.10gen.cc.", 27017}}, + {{"connectTimeoutMS", "200000"}, {"socketTimeoutMS", "200000"}}}, + + {"mongodb+srv://test6.test.build.10gen.cc/?connectTimeoutMS=300000", + "", + "", + "", + {{"localhost.build.10gen.cc.", 27017}}, + {{"connectTimeoutMS", "300000"}, {"socketTimeoutMS", "200000"}}}, + + {"mongodb+srv://test6.test.build.10gen.cc/?irrelevantOption=irrelevantValue", + "", + "", + "", + {{"localhost.build.10gen.cc.", 27017}}, + {{"connectTimeoutMS", "200000"}, + {"socketTimeoutMS", "200000"}, + {"irrelevantOption", "irrelevantValue"}}}, + + + {"mongodb+srv://test6.test.build.10gen.cc/" + "?irrelevantOption=irrelevantValue&connectTimeoutMS=300000", + "", + "", + "", + {{"localhost.build.10gen.cc.", 27017}}, + {{"connectTimeoutMS", "300000"}, + {"socketTimeoutMS", "200000"}, + {"irrelevantOption", "irrelevantValue"}}}, + }; + + for (const auto& test : tests) { + auto rs = MongoURI::parse(test.uri); + ASSERT_OK(rs.getStatus()); + auto rv = rs.getValue(); + ASSERT_EQ(rv.getUser(), test.user); + ASSERT_EQ(rv.getPassword(), test.password); + ASSERT_EQ(rv.getDatabase(), test.database); + std::vector<std::pair<std::string, std::string>> options(begin(rv.getOptions()), + end(rv.getOptions())); + std::sort(begin(options), end(options)); + std::vector<std::pair<std::string, std::string>> expectedOptions(begin(test.options), + end(test.options)); + std::sort(begin(expectedOptions), end(expectedOptions)); + + for (std::size_t i = 0; i < std::min(options.size(), expectedOptions.size()); ++i) { + if (options[i] != expectedOptions[i]) { + mongo::unittest::log() << "Option: \"" << options[i].first << "=" + << options[i].second << "\" doesn't equal: \"" + << expectedOptions[i].first << "=" + << expectedOptions[i].second << "\"" << std::endl; + std::cerr << "Failing URI: \"" << test.uri << "\"" << std::endl; + ASSERT(false); + } + } + ASSERT_EQ(options.size(), expectedOptions.size()); + + std::vector<HostAndPort> hosts(begin(rv.getServers()), end(rv.getServers())); + std::sort(begin(hosts), end(hosts)); + auto expectedHosts = test.hosts; + std::sort(begin(expectedHosts), end(expectedHosts)); + + for (std::size_t i = 0; i < std::min(hosts.size(), expectedHosts.size()); ++i) { + ASSERT_EQ(hosts[i], expectedHosts[i]); + } + ASSERT_TRUE(hosts.size() == expectedHosts.size()); + } +} + } // namespace diff --git a/src/mongo/client/mongo_uri_tests/mongo-uri-host-identifiers.json b/src/mongo/client/mongo_uri_tests/mongo-uri-host-identifiers.json index 8a1fda580a8..5af3797ccdb 100644 --- a/src/mongo/client/mongo_uri_tests/mongo-uri-host-identifiers.json +++ b/src/mongo/client/mongo_uri_tests/mongo-uri-host-identifiers.json @@ -108,7 +108,7 @@ { "description": "Multiple hosts (mixed formats)", "uri": "mongodb://127.0.0.1,[::1]:27018,example.com:27019", - "valid": false, + "valid": true, "warning": false, "hosts": [ { @@ -160,7 +160,7 @@ { "description": "UTF-8 hosts", "uri": "mongodb://bücher.example.com,umläut.example.com/", - "valid": false, + "valid": true, "warning": false, "hosts": [ { @@ -200,4 +200,4 @@ } } ] -}
\ No newline at end of file +} diff --git a/src/mongo/client/mongo_uri_tests/mongo-uri-unix-sockets-absolute.json b/src/mongo/client/mongo_uri_tests/mongo-uri-unix-sockets-absolute.json index d167c315bbc..f4f78883298 100644 --- a/src/mongo/client/mongo_uri_tests/mongo-uri-unix-sockets-absolute.json +++ b/src/mongo/client/mongo_uri_tests/mongo-uri-unix-sockets-absolute.json @@ -62,7 +62,7 @@ ], "options": null, "uri": "mongodb://%2Ftmp%2Fmongodb-27017.sock,%2Ftmp%2Fmongodb-27018.sock", - "valid": false, + "valid": true, "warning": false }, { @@ -82,7 +82,7 @@ ], "options": null, "uri": "mongodb://127.0.0.1:27017,%2Ftmp%2Fmongodb-27017.sock", - "valid": false, + "valid": true, "warning": false }, { @@ -102,7 +102,7 @@ ], "options": null, "uri": "mongodb://mongodb-27017.sock,%2Ftmp%2Fmongodb-27018.sock", - "valid": false, + "valid": true, "warning": false }, { @@ -260,7 +260,7 @@ ], "options": null, "uri": "mongodb://%2Ftmp%2Fmongodb-27017.sock,%2Ftmp%2Fmongodb-27018.sock/admin", - "valid": false, + "valid": true, "warning": false }, { @@ -312,7 +312,7 @@ "w": 1 }, "uri": "mongodb://bob:bar@%2Ftmp%2Fmongodb-27017.sock,%2Ftmp%2Fmongodb-27018.sock/admin?w=1", - "valid": false, + "valid": true, "warning": false }, { diff --git a/src/mongo/client/mongo_uri_tests/mongo-uri-unix-sockets-relative.json b/src/mongo/client/mongo_uri_tests/mongo-uri-unix-sockets-relative.json index 8da45c4e1b9..a2d737ba72a 100644 --- a/src/mongo/client/mongo_uri_tests/mongo-uri-unix-sockets-relative.json +++ b/src/mongo/client/mongo_uri_tests/mongo-uri-unix-sockets-relative.json @@ -62,7 +62,7 @@ ], "options": null, "uri": "mongodb://rel%2Fmongodb-27017.sock,rel%2Fmongodb-27018.sock", - "valid": false, + "valid": true, "warning": false }, { @@ -104,7 +104,7 @@ ], "options": null, "uri": "mongodb://rel%2Fmongodb-27017.sock,%2Ftmp%2Fmongodb-27018.sock", - "valid": false, + "valid": true, "warning": false }, { @@ -146,7 +146,7 @@ ], "options": null, "uri": "mongodb://127.0.0.1:27017,rel%2Fmongodb-27017.sock", - "valid": false, + "valid": true, "warning": false }, { @@ -188,7 +188,7 @@ ], "options": null, "uri": "mongodb://mongodb-27017.sock,rel%2Fmongodb-27018.sock", - "valid": false, + "valid": true, "warning": false }, { @@ -302,7 +302,7 @@ ], "options": null, "uri": "mongodb://rel%2Fmongodb-27017.sock,rel%2Fmongodb-27018.sock/admin", - "valid": false, + "valid": true, "warning": false }, { @@ -352,7 +352,7 @@ ], "options": null, "uri": "mongodb://rel%2Fmongodb-27017.sock,rel%2Fmongodb-27018.sock/admin", - "valid": false, + "valid": true, "warning": false }, { @@ -404,7 +404,7 @@ "w": 1 }, "uri": "mongodb://bob:bar@rel%2Fmongodb-27017.sock,rel%2Fmongodb-27018.sock/admin?w=1", - "valid": false, + "valid": true, "warning": false }, { diff --git a/src/mongo/client/mongo_uri_tests/mongo-uri-valid-auth.json b/src/mongo/client/mongo_uri_tests/mongo-uri-valid-auth.json index a44236e7b62..68228c6ddec 100644 --- a/src/mongo/client/mongo_uri_tests/mongo-uri-valid-auth.json +++ b/src/mongo/client/mongo_uri_tests/mongo-uri-valid-auth.json @@ -135,7 +135,7 @@ ], "options": null, "uri": "mongodb://alice:secret@127.0.0.1,example.com:27018", - "valid": false, + "valid": true, "warning": false }, { @@ -159,7 +159,7 @@ ], "options": null, "uri": "mongodb://alice:secret@example.com,[::1]:27019/admin", - "valid": false, + "valid": true, "warning": false }, { diff --git a/src/mongo/db/cloner.h b/src/mongo/db/cloner.h index 6c9abb64b12..6c8cc255b12 100644 --- a/src/mongo/db/cloner.h +++ b/src/mongo/db/cloner.h @@ -50,8 +50,8 @@ class Cloner { public: Cloner(); - void setConnection(DBClientBase* c) { - _conn.reset(c); + void setConnection(std::unique_ptr<DBClientBase> c) { + _conn = std::move(c); } /** diff --git a/src/mongo/db/commands/clone_collection.cpp b/src/mongo/db/commands/clone_collection.cpp index 0ec44950793..e5d1d7196f6 100644 --- a/src/mongo/db/commands/clone_collection.cpp +++ b/src/mongo/db/commands/clone_collection.cpp @@ -144,12 +144,11 @@ public: << (copyIndexes ? "" : ", not copying indexes"); Cloner cloner; - unique_ptr<DBClientConnection> myconn; - myconn.reset(new DBClientConnection()); + auto myconn = stdx::make_unique<DBClientConnection>(); if (!myconn->connect(HostAndPort(fromhost), StringData(), errmsg)) return false; - cloner.setConnection(myconn.release()); + cloner.setConnection(std::move(myconn)); return cloner.copyCollection(opCtx, collection, query, errmsg, copyIndexes); } diff --git a/src/mongo/db/commands/copydb.cpp b/src/mongo/db/commands/copydb.cpp index 0695a60382e..26fbf053bf9 100644 --- a/src/mongo/db/commands/copydb.cpp +++ b/src/mongo/db/commands/copydb.cpp @@ -187,7 +187,7 @@ public: return false; } } - cloner.setConnection(authConn.release()); + cloner.setConnection(std::move(authConn)); } else if (cmdObj.hasField(saslCommandConversationIdFieldName) && cmdObj.hasField(saslCommandPayloadFieldName)) { uassert(25487, "must call copydbsaslstart first", authConn.get()); @@ -208,16 +208,16 @@ public: } result.append("done", true); - cloner.setConnection(authConn.release()); + cloner.setConnection(std::move(authConn)); } else if (!fromSelf) { // If fromSelf leave the cloner's conn empty, it will use a DBDirectClient instead. const ConnectionString cs(uassertStatusOK(ConnectionString::parse(fromhost))); - DBClientBase* conn = cs.connect(StringData(), errmsg); + auto conn = cs.connect(StringData(), errmsg); if (!conn) { return false; } - cloner.setConnection(conn); + cloner.setConnection(std::move(conn)); } // Either we didn't need the authConn (if we even had one), or we already moved it diff --git a/src/mongo/db/commands/copydb_start_commands.cpp b/src/mongo/db/commands/copydb_start_commands.cpp index d537d621977..32b63d3bfbd 100644 --- a/src/mongo/db/commands/copydb_start_commands.cpp +++ b/src/mongo/db/commands/copydb_start_commands.cpp @@ -112,7 +112,7 @@ public: const ConnectionString cs(uassertStatusOK(ConnectionString::parse(fromhost))); auto& authConn = CopyDbAuthConnection::forClient(opCtx->getClient()); - authConn.reset(cs.connect(StringData(), errmsg)); + authConn = cs.connect(StringData(), errmsg); if (!authConn) { return false; } @@ -208,7 +208,7 @@ public: } auto& authConn = CopyDbAuthConnection::forClient(opCtx->getClient()); - authConn.reset(cs.connect(StringData(), errmsg)); + authConn = cs.connect(StringData(), errmsg); if (!authConn.get()) { return false; } diff --git a/src/mongo/db/repl/rollback_source_impl.cpp b/src/mongo/db/repl/rollback_source_impl.cpp index f0f5667c25d..5de7dde2337 100644 --- a/src/mongo/db/repl/rollback_source_impl.cpp +++ b/src/mongo/db/repl/rollback_source_impl.cpp @@ -82,14 +82,14 @@ std::pair<BSONObj, NamespaceString> RollbackSourceImpl::findOneByUUID(const std: void RollbackSourceImpl::copyCollectionFromRemote(OperationContext* opCtx, const NamespaceString& nss) const { std::string errmsg; - std::unique_ptr<DBClientConnection> tmpConn(new DBClientConnection()); + auto tmpConn = stdx::make_unique<DBClientConnection>(); uassert(15908, errmsg, tmpConn->connect(_source, StringData(), errmsg) && replAuthenticate(tmpConn.get())); // cloner owns _conn in unique_ptr Cloner cloner; - cloner.setConnection(tmpConn.release()); + cloner.setConnection(std::move(tmpConn)); uassert(15909, str::stream() << "replSet rollback error resyncing collection " << nss.ns() << ' ' << errmsg, diff --git a/src/mongo/dbtests/mock/mock_conn_registry.cpp b/src/mongo/dbtests/mock/mock_conn_registry.cpp index 29fd5e0be9b..8d01e8b38d2 100644 --- a/src/mongo/dbtests/mock/mock_conn_registry.cpp +++ b/src/mongo/dbtests/mock/mock_conn_registry.cpp @@ -77,31 +77,27 @@ void MockConnRegistry::clear() { _registry.clear(); } -MockDBClientConnection* MockConnRegistry::connect(const std::string& connStr) { +std::unique_ptr<MockDBClientConnection> MockConnRegistry::connect(const std::string& connStr) { stdx::lock_guard<stdx::mutex> sl(_registryMutex); fassert(16534, _registry.count(connStr) == 1); - return new MockDBClientConnection(_registry[connStr], true); + return stdx::make_unique<MockDBClientConnection>(_registry[connStr], true); } MockConnRegistry::MockConnHook::MockConnHook(MockConnRegistry* registry) : _registry(registry) {} MockConnRegistry::MockConnHook::~MockConnHook() {} -mongo::DBClientBase* MockConnRegistry::MockConnHook::connect(const ConnectionString& connString, - std::string& errmsg, - double socketTimeout) { +std::unique_ptr<mongo::DBClientBase> MockConnRegistry::MockConnHook::connect( + const ConnectionString& connString, std::string& errmsg, double socketTimeout) { const string hostName(connString.toString()); - MockDBClientConnection* conn = _registry->connect(hostName); + auto conn = _registry->connect(hostName); if (!conn->connect(hostName.c_str(), StringData(), errmsg)) { - // Assumption: connect never throws, so no leak. - delete conn; - // mimic ConnectionString::connect for MASTER type connection to return NULL // if the destination is unreachable. - return NULL; + return nullptr; } - return conn; -} + return std::move(conn); } +} // namespace mongo diff --git a/src/mongo/dbtests/mock/mock_conn_registry.h b/src/mongo/dbtests/mock/mock_conn_registry.h index 33764885a69..d72a702aea5 100644 --- a/src/mongo/dbtests/mock/mock_conn_registry.h +++ b/src/mongo/dbtests/mock/mock_conn_registry.h @@ -79,7 +79,7 @@ public: /** * @return a new mocked connection to a server with the given hostName. */ - MockDBClientConnection* connect(const std::string& hostName); + std::unique_ptr<MockDBClientConnection> connect(const std::string& hostName); /** * @return the hook that can be used with ConnectionString. @@ -100,9 +100,9 @@ private: MockConnHook(MockConnRegistry* registry); ~MockConnHook(); - mongo::DBClientBase* connect(const mongo::ConnectionString& connString, - std::string& errmsg, - double socketTimeout); + std::unique_ptr<mongo::DBClientBase> connect(const mongo::ConnectionString& connString, + std::string& errmsg, + double socketTimeout); private: MockConnRegistry* _registry; diff --git a/src/mongo/shell/dbshell.cpp b/src/mongo/shell/dbshell.cpp index c3c30ed5b09..03ca476e744 100644 --- a/src/mongo/shell/dbshell.cpp +++ b/src/mongo/shell/dbshell.cpp @@ -223,11 +223,13 @@ string getURIFromArgs(const std::string& arg, const std::string& host, const std return kDefaultMongoURL.toString(); } - if (str::startsWith(arg, "mongodb://") && host.empty() && port.empty()) { + if ((str::startsWith(arg, "mongodb://") || str::startsWith(arg, "mongodb+srv://")) && + host.empty() && port.empty()) { // mongo mongodb://blah return arg; } - if (str::startsWith(host, "mongodb://") && arg.empty() && port.empty()) { + if ((str::startsWith(host, "mongodb://") || str::startsWith(arg, "mongodb+srv://")) && + arg.empty() && port.empty()) { // mongo --host mongodb://blah return host; } diff --git a/src/mongo/shell/mongo.js b/src/mongo/shell/mongo.js index 76d0d53dfb0..9ad8829e1ff 100644 --- a/src/mongo/shell/mongo.js +++ b/src/mongo/shell/mongo.js @@ -231,7 +231,7 @@ connect = function(url, user, pass) { throw Error("Empty connection string"); } - if (!url.startsWith("mongodb://")) { + if (!url.startsWith("mongodb://") && !url.startsWith("mongodb+srv://")) { const colon = url.lastIndexOf(":"); const slash = url.lastIndexOf("/"); if (url.split("/").length > 1) { diff --git a/src/mongo/shell/shell_options.cpp b/src/mongo/shell/shell_options.cpp index 7b549fc57c5..32dcf89bef0 100644 --- a/src/mongo/shell/shell_options.cpp +++ b/src/mongo/shell/shell_options.cpp @@ -418,7 +418,8 @@ Status storeMongoShellOptions(const moe::Environment& params, return Status(ErrorCodes::BadValue, sb.str()); } - if (shellGlobalParams.url.find("mongodb://") == 0) { + if ((shellGlobalParams.url.find("mongodb://") == 0) && + (shellGlobalParams.url.find("mongodb+srv://") == 0)) { auto cs_status = MongoURI::parse(shellGlobalParams.url); if (!cs_status.isOK()) { return cs_status.getStatus(); diff --git a/src/mongo/util/SConscript b/src/mongo/util/SConscript index 6e9416443ff..bdcc18e55fc 100644 --- a/src/mongo/util/SConscript +++ b/src/mongo/util/SConscript @@ -494,6 +494,25 @@ env.Library( ) env.Library( + target='dns_query', + source=[ + 'dns_query.cpp', + ], + LIBDEPS_PRIVATE=[ + "$BUILD_DIR/mongo/base", + ], +) + +env.CppUnitTest( + target='dns_query_test', + source=['dns_query_test.cpp'], + LIBDEPS=[ + 'dns_query', + "$BUILD_DIR/mongo/base", + ] +) + +env.Library( target="secure_zero_memory", source=[ 'secure_zero_memory.cpp', diff --git a/src/mongo/util/dns_query.cpp b/src/mongo/util/dns_query.cpp new file mode 100644 index 00000000000..9ce9e8743c6 --- /dev/null +++ b/src/mongo/util/dns_query.cpp @@ -0,0 +1,119 @@ +/** + * Copyright (C) 2017 MongoDB Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, version 3, + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the GNU Affero General Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/platform/basic.h" + +#include "mongo/util/dns_query.h" + +#include <array> +#include <cassert> +#include <cstdint> +#include <exception> +#include <iostream> +#include <memory> +#include <sstream> +#include <stdexcept> +#include <string> +#include <vector> + +#include <boost/noncopyable.hpp> + +// It is safe to include the implementation "headers" in an anonymous namespace, as the code is +// meant to live in a single TU -- this one. Include one of these headers last. +#define MONGO_UTIL_DNS_QUERY_PLATFORM_INCLUDE_WHITELIST +#ifdef WIN32 +#include "mongo/util/dns_query_windows-impl.h" +#else +#include "mongo/util/dns_query_posix-impl.h" +#endif +#undef MONGO_UTIL_DNS_QUERY_PLATFORM_INCLUDE_WHITELIST + +using std::begin; +using std::end; +using namespace std::literals::string_literals; + +namespace mongo { + +/** + * Returns a string with the IP address or domain name listed... + */ +std::vector<std::string> dns::lookupARecords(const std::string& service) { + DNSQueryState dnsQuery; + auto response = dnsQuery.lookup(service, DNSQueryClass::kInternet, DNSQueryType::kAddress); + + if (response.size() == 0) { + throw DBException(ErrorCodes::DNSProtocolError, + "Looking up " + service + " A record yielded no results."); + } + + std::vector<std::string> rv; + std::transform(begin(response), end(response), back_inserter(rv), [](const auto& entry) { + return entry.addressEntry(); + }); + + return rv; +} + +std::vector<dns::SRVHostEntry> dns::lookupSRVRecords(const std::string& service) { + DNSQueryState dnsQuery; + + auto response = dnsQuery.lookup(service, DNSQueryClass::kInternet, DNSQueryType::kSRV); + + std::vector<SRVHostEntry> rv; + + std::transform(begin(response), end(response), back_inserter(rv), [](const auto& entry) { + return entry.srvHostEntry(); + }); + return rv; +} + +std::vector<std::string> dns::lookupTXTRecords(const std::string& service) { + DNSQueryState dnsQuery; + + auto response = dnsQuery.lookup(service, DNSQueryClass::kInternet, DNSQueryType::kTXT); + + std::vector<std::string> rv; + + for (auto& entry : response) { + auto txtEntry = entry.txtEntry(); + rv.insert(end(rv), + std::make_move_iterator(begin(txtEntry)), + std::make_move_iterator(end(txtEntry))); + } + return rv; +} + +std::vector<std::string> dns::getTXTRecords(const std::string& service) try { + return lookupTXTRecords(service); +} catch (const DBException& ex) { + if (ex.code() == ErrorCodes::DNSHostNotFound) { + return {}; + } + throw; +} +} // namespace mongo diff --git a/src/mongo/util/dns_query.h b/src/mongo/util/dns_query.h new file mode 100644 index 00000000000..e8b203c8f71 --- /dev/null +++ b/src/mongo/util/dns_query.h @@ -0,0 +1,113 @@ +/** + * Copyright (C) 2017 MongoDB Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, version 3, + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the GNU Affero General Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include <array> +#include <cstdint> +#include <exception> +#include <iomanip> +#include <iostream> +#include <stdexcept> +#include <string> +#include <vector> + +#include <boost/noncopyable.hpp> + +#include "mongo/util/assert_util.h" + +namespace mongo { +namespace dns { +/** + * An `SRVHostEntry` object represents the information received from a DNS lookup of an SRV record. + * NOTE: An SRV DNS record has several fields, such as priority and weight. This structure lacks + * those fields at this time. They should be safe to add in the future. The code using this + * structure does not need access to those fields at this time. + */ +struct SRVHostEntry { + std::string host; + std::uint16_t port; + + SRVHostEntry(std::string initialHost, const std::uint16_t initialPort) + : host(std::move(initialHost)), port(initialPort) {} + + inline auto makeRelopsLens() const { + return std::tie(host, port); + } + + inline friend std::ostream& operator<<(std::ostream& os, const SRVHostEntry& entry) { + return os << entry.host << ':' << entry.port; + } + + inline friend bool operator==(const SRVHostEntry& lhs, const SRVHostEntry& rhs) { + return lhs.makeRelopsLens() == rhs.makeRelopsLens(); + } + + inline friend bool operator!=(const SRVHostEntry& lhs, const SRVHostEntry& rhs) { + return !(lhs == rhs); + } + + inline friend bool operator<(const SRVHostEntry& lhs, const SRVHostEntry& rhs) { + return lhs.makeRelopsLens() < rhs.makeRelopsLens(); + } +}; + +/** + * Returns a vector containing SRVHost entries for the specified `service`. + * THROWS: `DBException` with `ErrorCodes::DNSHostNotFound` as the status value if the entry is not + * found and `ErrorCodes::DNSProtocolError` as the status value if the DNS lookup fails, for any + * other reason + */ +std::vector<SRVHostEntry> lookupSRVRecords(const std::string& service); + +/** + * Returns a group of strings containing text from DNS TXT entries for a specified service. + * THROWS: `DBException` with `ErrorCodes::DNSHostNotFound` as the status value if the entry is not + * found and `ErrorCodes::DNSProtocolError` as the status value if the DNS lookup fails, for any + * other reason + */ +std::vector<std::string> lookupTXTRecords(const std::string& service); + +/** + * Returns a group of strings containing text from DNS TXT entries for a specified service. + * If the lookup fails because the record doesn't exist, an empty vector is returned. + * THROWS: `DBException` with `ErrorCodes::DNSProtocolError` as th status value if the DNS lookup + * fails for any other reason. + */ +std::vector<std::string> getTXTRecords(const std::string& service); + +/** + * Returns a group of strings containing Address entries for a specified service. + * THROWS: `DBException` with `ErrorCodes::DNSHostNotFound` as the status value if the entry is not + * found and `ErrorCodes::DNSProtocolError` as the status value if the DNS lookup fails, for any + * other reason + * NOTE: This function mostly exists to provide an easy test of the OS DNS APIs in our test driver. + */ +std::vector<std::string> lookupARecords(const std::string& service); +} // namespace dns +} // namespace mongo diff --git a/src/mongo/util/dns_query_posix-impl.h b/src/mongo/util/dns_query_posix-impl.h new file mode 100644 index 00000000000..259135e2735 --- /dev/null +++ b/src/mongo/util/dns_query_posix-impl.h @@ -0,0 +1,352 @@ +/** + * Copyright (C) 2017 MongoDB Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, version 3, + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the GNU Affero General Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#ifndef MONGO_UTIL_DNS_QUERY_PLATFORM_INCLUDE_WHITELIST +#error Do not include the DNS Query platform implementation headers. Please use "mongo/util/dns_query.h" instead. +#endif + +// DNS Headers for POSIX/libresolv have to be included in a specific order +// clang-format off +#include <sys/types.h> +#include <netinet/in.h> +#include <arpa/nameser.h> +#include <resolv.h> +// clang-format on + +#include <stdio.h> + +#include <iostream> +#include <cassert> +#include <sstream> +#include <string> +#include <cstdint> +#include <vector> +#include <array> +#include <stdexcept> +#include <memory> +#include <exception> + +#include <boost/noncopyable.hpp> + +namespace mongo { +namespace dns { +// The anonymous namespace is safe, in this header, as it is not really a header. It is only used +// in the `dns_query.cpp` TU. +namespace { + +using std::begin; +using std::end; +using namespace std::literals::string_literals; + +const std::size_t kMaxExpectedDNSResponseSize = 65536; +const std::size_t kMaxSRVHostNameSize = 8192; + +enum class DNSQueryClass { + kInternet = ns_c_in, +}; + +enum class DNSQueryType { + kSRV = ns_t_srv, + kTXT = ns_t_txt, + kAddress = ns_t_a, +}; + +/** + * A `ResourceRecord` represents a single DNS entry as parsed by the resolver API. + * It can be viewed as one of various record types, using the member functions. + * It roughly corresponds to the DNS RR data structure + */ +class ResourceRecord { +public: + explicit ResourceRecord() = default; + + explicit ResourceRecord(std::string initialService, ns_msg& ns_answer, const int initialPos) + : _service(std::move(initialService)), + _answerStart(ns_msg_base(ns_answer)), + _answerEnd(ns_msg_end(ns_answer)), + _pos(initialPos) { + if (ns_parserr(&ns_answer, ns_s_an, initialPos, &this->_resource_record)) + this->_badRecord(); + } + + /** + * View this record as a DNS TXT record. + */ + std::vector<std::string> txtEntry() const { + const auto data = this->_rawData(); + if (data.empty()) { + throw DBException(ErrorCodes::DNSProtocolError, + "DNS TXT Record is not correctly sized"); + } + const std::size_t amount = data.front(); + const auto first = begin(data) + 1; + std::vector<std::string> rv; + if (data.size() - 1 < amount) { + throw DBException(ErrorCodes::DNSProtocolError, + "DNS TXT Record is not correctly sized"); + } + rv.emplace_back(first, first + amount); + return rv; + } + + /** + * View this record as a DNS A record. + */ + std::string addressEntry() const { + std::string rv; + + auto data = _rawData(); + if (data.size() != 4) { + throw DBException(ErrorCodes::DNSProtocolError, "DNS A Record is not correctly sized"); + } + for (const std::uint8_t& ch : data) { + std::ostringstream oss; + oss << int(ch); + rv += oss.str() + "."; + } + rv.pop_back(); + return rv; + } + + /** + * View this record as a DNS SRV record. + */ + SRVHostEntry srvHostEntry() const { + const std::size_t kPortOffsetInPacket = 4; + + const std::uint8_t* const data = ns_rr_rdata(this->_resource_record); + if (data < this->_answerStart || + data + kPortOffsetInPacket + sizeof(std::uint16_t) > this->_answerEnd) { + std::ostringstream oss; + oss << "Invalid record " << this->_pos << " of SRV answer for \"" << this->_service + << "\": Incorrect result size"; + throw DBException(ErrorCodes::DNSProtocolError, oss.str()); + } + const std::uint16_t port = [data] { + std::uint16_t tmp; + memcpy(&tmp, data + kPortOffsetInPacket, sizeof(tmp)); + return ntohs(tmp); + }(); + + // The '@' is an impermissible character in a host name, so we populate the string we'll + // return with it, such that a failure in string manipulation or corrupted dns packets will + // cause an illegal hostname. + std::string name(kMaxSRVHostNameSize, '@'); + + const auto size = dn_expand(this->_answerStart, + this->_answerEnd, + data + kPortOffsetInPacket + sizeof(port), + &name[0], + name.size()); + + if (size < 1) + this->_badRecord(); + + // Trim the expanded name + name.resize(name.find('\0')); + name += '.'; + + // return by copy is equivalent to a `shrink_to_fit` and `move`. + return {name, port}; + } + +private: + void _badRecord() const { + std::ostringstream oss; + oss << "Invalid record " << this->_pos << " of DNS answer for \"" << this->_service + << "\": \"" << strerror(errno) << "\""; + throw DBException(ErrorCodes::DNSProtocolError, oss.str()); + }; + + std::vector<std::uint8_t> _rawData() const { + const std::uint8_t* const data = ns_rr_rdata(this->_resource_record); + const std::size_t length = ns_rr_rdlen(this->_resource_record); + + return {data, data + length}; + } + + std::string _service; + ns_rr _resource_record; + const std::uint8_t* _answerStart; + const std::uint8_t* _answerEnd; + int _pos; +}; + +/** + * The `DNSResponse` class represents a response to a DNS query. + * It has STL-compatible iterators to view individual DNS Resource Records within a response. + */ +class DNSResponse { +public: + explicit DNSResponse(std::string initialService, std::vector<std::uint8_t> initialData) + : _service(std::move(initialService)), _data(std::move(initialData)) { + if (ns_initparse(this->_data.data(), this->_data.size(), &this->_ns_answer)) { + std::ostringstream oss; + oss << "Invalid SRV answer for \"" << this->_service << "\""; + throw DBException(ErrorCodes::DNSProtocolError, oss.str()); + } + + this->_nRecords = ns_msg_count(this->_ns_answer, ns_s_an); + + if (!this->_nRecords) { + std::ostringstream oss; + oss << "No SRV records for \"" << this->_service << "\""; + throw DBException(ErrorCodes::DNSProtocolError, oss.str()); + } + } + + class iterator { + public: + auto makeRelopsLens() const { + return std::tie(this->_response, this->_pos); + } + + inline friend bool operator==(const iterator& lhs, const iterator& rhs) { + return lhs.makeRelopsLens() == rhs.makeRelopsLens(); + } + + inline friend bool operator<(const iterator& lhs, const iterator& rhs) { + return lhs.makeRelopsLens() < rhs.makeRelopsLens(); + } + + inline friend bool operator!=(const iterator& lhs, const iterator& rhs) { + return !(lhs == rhs); + } + + const ResourceRecord& operator*() { + this->_populate(); + return this->_record; + } + + const ResourceRecord* operator->() { + this->_populate(); + return &this->_record; + } + + iterator& operator++() { + this->_advance(); + return *this; + } + + iterator operator++(int) { + iterator tmp = *this; + this->_advance(); + return tmp; + } + + private: + friend DNSResponse; + + explicit iterator(DNSResponse* const r) + : _response(r), _record(this->_response->_service, this->_response->_ns_answer, 0) {} + + explicit iterator(DNSResponse* const initialResponse, const int initialPos) + : _response(initialResponse), _pos(initialPos) {} + + void _populate() { + if (this->_ready) { + return; + } + this->_record = + ResourceRecord(this->_response->_service, this->_response->_ns_answer, this->_pos); + this->_ready = true; + } + + void _advance() { + ++this->_pos; + this->_ready = false; + } + + DNSResponse* _response; + int _pos = 0; + ResourceRecord _record; + bool _ready = false; + }; + + auto begin() { + return iterator(this); + } + + auto end() { + return iterator(this, this->_nRecords); + } + + std::size_t size() const { + return this->_nRecords; + } + +private: + std::string _service; + std::vector<std::uint8_t> _data; + ns_msg _ns_answer; + std::size_t _nRecords; +}; + +/** + * The `DNSQueryState` object represents the state of a DNS query interface, on Unix-like systems. + */ +class DNSQueryState : boost::noncopyable { +public: + std::vector<std::uint8_t> raw_lookup(const std::string& service, + const DNSQueryClass class_, + const DNSQueryType type) { + std::vector<std::uint8_t> result(kMaxExpectedDNSResponseSize); + const int size = res_nsearch( + &_state, service.c_str(), int(class_), int(type), &result[0], result.size()); + + if (size < 0) { + std::ostringstream oss; + oss << "Failed to look up service \"" << service << "\": " << strerror(errno); + throw DBException(ErrorCodes::DNSHostNotFound, oss.str()); + } + result.resize(size); + + return result; + } + + DNSResponse lookup(const std::string& service, + const DNSQueryClass class_, + const DNSQueryType type) { + return DNSResponse(service, raw_lookup(service, class_, type)); + } + +public: + ~DNSQueryState() { + res_nclose(&_state); + } + + DNSQueryState() : _state() { + res_ninit(&_state); + } + +private: + struct __res_state _state; +}; +} // namespace +} // namespace dns +} // namespace mongo diff --git a/src/mongo/util/dns_query_test.cpp b/src/mongo/util/dns_query_test.cpp new file mode 100644 index 00000000000..8f26ab8ec62 --- /dev/null +++ b/src/mongo/util/dns_query_test.cpp @@ -0,0 +1,196 @@ +/** + * Copyright (C) 2017 MongoDB Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, version 3, + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the GNU Affero General Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ +#include "mongo/util/dns_query.h" + +#include "mongo/unittest/unittest.h" + +using namespace std::literals::string_literals; + +namespace { +std::string getFirstARecord(const std::string& service) { + auto res = mongo::dns::lookupARecords(service); + if (res.empty()) + return ""; + return res.front(); +} + +TEST(MongoDnsQuery, basic) { + // We only require 50% of the records to pass, because it is possible that some large scale + // outages could cause some of these records to fail. + const double kPassingPercentage = 0.50; + std::size_t resolution_count = 0; + const struct { + std::string dns; + std::string ip; + } tests[] = + // The reason for a vast number of tests over basic DNS query calls is to provide a + // redundancy in testing. We'd like to make sure that this test always passes. Lazy + // maintanance will cause some references to be commented out. Our belief is that all 13 + // root servers and both of Google's public servers will all be unresolvable (when + // connections are available) only when a major problem occurs. This test only fails if + // more than half of the resolved names fail. + { + // These can be kept up to date by checking the root-servers.org listings. Note that + // root name servers are located in the `root-servers.net.` domain, NOT in the + // `root-servers.org.` domain. The `.org` domain is for webpages with statistics on + // these servers. The `.net` domain is the domain with the canonical addresses for + // these servers. + {"a.root-servers.net.", "198.41.0.4"}, + {"b.root-servers.net.", "192.228.79.201"}, + {"c.root-servers.net.", "192.33.4.12"}, + {"d.root-servers.net.", "199.7.91.13"}, + {"e.root-servers.net.", "192.203.230.10"}, + {"f.root-servers.net.", "192.5.5.241"}, + {"g.root-servers.net.", "192.112.36.4"}, + {"h.root-servers.net.", "198.97.190.53"}, + {"i.root-servers.net.", "192.36.148.17"}, + {"j.root-servers.net.", "192.58.128.30"}, + {"k.root-servers.net.", "193.0.14.129"}, + {"l.root-servers.net.", "199.7.83.42"}, + {"m.root-servers.net.", "202.12.27.33"}, + + // These can be kept up to date by checking with Google's public dns service. + {"google-public-dns-a.google.com.", "8.8.8.8"}, + {"google-public-dns-b.google.com.", "8.8.4.4"}, + }; + for (const auto& test : tests) { + try { + const auto witness = getFirstARecord(test.dns); + std::cout << "Resolved " << test.dns << " to: " << witness << std::endl; + + const bool resolution = (witness == test.ip); + if (!resolution) + std::cerr << "Warning: Did not correctly resolve " << test.dns << std::endl; + resolution_count += resolution; + } + // Failure to resolve is okay, but not great -- print a warning + catch (const mongo::DBException& ex) { + std::cerr << "Warning: Did not resolve " << test.dns << " at all: " << ex.what() + << std::endl; + } + } + + // As long as enough tests pass, we're okay -- this means that a single DNS name server drift + // won't cause a BF -- when enough fail, then we can rebuild the list in one pass. + const std::size_t kPassingRate = sizeof(tests) / sizeof(tests[0]) * kPassingPercentage; + ASSERT_GTE(resolution_count, kPassingRate); +} + +TEST(MongoDnsQuery, srvRecords) { + const auto kMongodbSRVPrefix = "_mongodb._tcp."s; + const struct { + std::string query; + std::vector<mongo::dns::SRVHostEntry> result; + } tests[] = { + {"test1.test.build.10gen.cc.", + { + {"localhost.build.10gen.cc.", 27017}, {"localhost.build.10gen.cc.", 27018}, + }}, + {"test2.test.build.10gen.cc.", + { + {"localhost.build.10gen.cc.", 27018}, {"localhost.build.10gen.cc.", 27019}, + }}, + {"test3.test.build.10gen.cc.", + { + {"localhost.build.10gen.cc.", 27017}, + }}, + + // Test case 4 does not exist in the expected DNS records. + {"test4.test.build.10gen.cc.", {}}, + + {"test5.test.build.10gen.cc.", + { + {"localhost.build.10gen.cc.", 27017}, + }}, + {"test6.test.build.10gen.cc.", + { + {"localhost.build.10gen.cc.", 27017}, + }}, + }; + for (const auto& test : tests) { + const auto& expected = test.result; + if (expected.empty()) { + ASSERT_THROWS_CODE(mongo::dns::lookupSRVRecords(kMongodbSRVPrefix + test.query), + mongo::DBException, + mongo::ErrorCodes::DNSHostNotFound); + continue; + } + + auto witness = mongo::dns::lookupSRVRecords(kMongodbSRVPrefix + test.query); + std::sort(begin(witness), end(witness)); + + for (const auto& entry : witness) { + std::cout << "Entry: " << entry << std::endl; + } + + for (std::size_t i = 0; i < witness.size() && i < expected.size(); ++i) { + std::cout << "Expected: " << expected.at(i) << std::endl; + std::cout << "Witness: " << witness.at(i) << std::endl; + ASSERT_EQ(witness.at(i), expected.at(i)); + } + + ASSERT_TRUE(std::equal(begin(witness), end(witness), begin(expected), end(expected))); + ASSERT_TRUE(witness.size() == expected.size()); + } +} + +TEST(MongoDnsQuery, txtRecords) { + const struct { + std::string query; + std::vector<std::string> result; + } tests[] = { + // Test case 4 does not exist in the expected DNS records. + {"test4.test.build.10gen.cc.", {}}, + + {"test5.test.build.10gen.cc", + { + "connectTimeoutMS=300000&socketTimeoutMS=300000", + }}, + {"test6.test.build.10gen.cc", + { + "connectTimeoutMS=200000", "socketTimeoutMS=200000", + }}, + }; + + for (const auto& test : tests) { + try { + auto witness = mongo::dns::getTXTRecords(test.query); + std::sort(begin(witness), end(witness)); + + const auto& expected = test.result; + + ASSERT_TRUE(std::equal(begin(witness), end(witness), begin(expected), end(expected))); + ASSERT_EQ(witness.size(), expected.size()); + } catch (const mongo::DBException& ex) { + if (ex.code() != mongo::ErrorCodes::DNSHostNotFound) + throw; + ASSERT_TRUE(test.result.empty()); + } + } +} +} // namespace diff --git a/src/mongo/util/dns_query_windows-impl.h b/src/mongo/util/dns_query_windows-impl.h new file mode 100644 index 00000000000..04fe778c5eb --- /dev/null +++ b/src/mongo/util/dns_query_windows-impl.h @@ -0,0 +1,255 @@ +/** + * Copyright (C) 2017 MongoDB Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, version 3, + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the GNU Affero General Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#ifndef MONGO_UTIL_DNS_QUERY_PLATFORM_INCLUDE_WHITELIST +#error Do not include the DNS Query platform implementation headers. Please use "mongo/util/dns_query.h" instead. +#endif + +#include <windns.h> + +#include <stdio.h> + +#include <array> +#include <cassert> +#include <cstdint> +#include <exception> +#include <iostream> +#include <memory> +#include <sstream> +#include <stdexcept> +#include <string> +#include <vector> + +#include <boost/noncopyable.hpp> + +#include "mongo/util/errno_util.h" + +using std::begin; +using std::end; +using namespace std::literals::string_literals; + +namespace mongo { +namespace dns { +// The anonymous namespace is safe, in this header, as it is not really a header. It is only used +// in the `dns_query.cpp` TU. +namespace { +enum class DNSQueryClass { kInternet }; + +enum class DNSQueryType { kSRV = DNS_TYPE_SRV, kTXT = DNS_TYPE_TEXT, kAddress = DNS_TYPE_A }; + +/** + * A `ResourceRecord` represents a single DNS entry as parsed by the resolver API. + * It can be viewed as one of various record types, using the member functions. + * It roughly corresponds to the DNS RR data structure + */ +class ResourceRecord { +public: + explicit ResourceRecord(std::shared_ptr<DNS_RECORDA> initialRecord) + : _record(std::move(initialRecord)) {} + explicit ResourceRecord() = default; + + /** + * View this record as a DNS TXT record. + */ + std::vector<std::string> txtEntry() const { + if (this->_record->wType != DNS_TYPE_TEXT) { + std::ostringstream oss; + oss << "Incorrect record format for \"" << this->_service + << "\": expected TXT record, found something else"; + throw DBException(ErrorCodes::DNSProtocolError, oss.str()); + } + + std::vector<std::string> rv; + + const auto start = this->_record->Data.TXT.pStringArray; + const auto count = this->_record->Data.TXT.dwStringCount; + std::copy(start, start + count, back_inserter(rv)); + return rv; + } + + /** + * View this record as a DNS A record. + */ + std::string addressEntry() const { + if (this->_record->wType != DNS_TYPE_A) { + std::ostringstream oss; + oss << "Incorrect record format for \"" << this->_service + << "\": expected A record, found something else"; + throw DBException(ErrorCodes::DNSProtocolError, oss.str()); + } + + std::string rv; + const auto& data = this->_record->Data.A.IpAddress; + + for (int i = 0; i < 4; ++i) { + std::ostringstream oss; + oss << int(data >> (i * CHAR_BIT) & 0xFF); + rv += oss.str() + "."; + } + rv.pop_back(); + + return rv; + } + + /** + * View this record as a DNS SRV record. + */ + SRVHostEntry srvHostEntry() const { + if (this->_record->wType != DNS_TYPE_SRV) { + std::ostringstream oss; + oss << "Incorrect record format for \"" << this->_service + << "\": expected SRV record, found something else"; + throw DBException(ErrorCodes::DNSProtocolError, oss.str()); + } + + const auto& data = this->_record->Data.SRV; + return {data.pNameTarget + "."s, data.wPort}; + } + +private: + std::string _service; + std::shared_ptr<DNS_RECORDA> _record; +}; + +void freeDnsRecord(PDNS_RECORDA record) { + DnsRecordListFree(record, DnsFreeRecordList); +} + +/** + * The `DNSResponse` class represents a response to a DNS query. + * It has STL-compatible iterators to view individual DNS Resource Records within a response. + */ +class DNSResponse { +public: + explicit DNSResponse(PDNS_RECORDA initialResults) : _results(initialResults, freeDnsRecord) {} + + class iterator : public std::iterator<std::forward_iterator_tag, ResourceRecord> { + public: + const ResourceRecord& operator*() { + this->_populate(); + return this->_storage; + } + + const ResourceRecord* operator->() { + this->_populate(); + return &this->_storage; + } + + iterator& operator++() { + this->_advance(); + return *this; + } + + iterator operator++(int) { + iterator tmp = *this; + this->_advance(); + return tmp; + } + + auto makeRelopsLens() const { + return this->_record.get(); + } + + inline friend bool operator==(const iterator& lhs, const iterator& rhs) { + return lhs.makeRelopsLens() == rhs.makeRelopsLens(); + } + + inline friend bool operator<(const iterator& lhs, const iterator& rhs) { + return lhs.makeRelopsLens() < rhs.makeRelopsLens(); + } + + inline friend bool operator!=(const iterator& lhs, const iterator& rhs) { + return !(lhs == rhs); + } + + private: + friend DNSResponse; + + explicit iterator(std::shared_ptr<DNS_RECORDA> initialRecord) + : _record(std::move(initialRecord)) {} + + void _advance() { + this->_record = {this->_record, this->_record->pNext}; + this->_ready = false; + } + + void _populate() { + if (this->_ready) { + return; + } + this->_storage = ResourceRecord{this->_record}; + this->_ready = true; + } + + std::shared_ptr<DNS_RECORDA> _record; + ResourceRecord _storage; + bool _ready = false; + }; + + iterator begin() const { + return iterator{this->_results}; + } + + iterator end() const { + return iterator{nullptr}; + } + + std::size_t size() const { + return std::distance(this->begin(), this->end()); + } + +private: + std::shared_ptr<std::remove_pointer<PDNS_RECORDA>::type> _results; +}; + +/** + * The `DNSQueryState` object represents the state of a DNS query interface, on Windows systems. + */ +class DNSQueryState { +public: + DNSResponse lookup(const std::string& service, + const DNSQueryClass class_, + const DNSQueryType type) { + PDNS_RECORDA queryResults; + auto ec = DnsQuery_UTF8(service.c_str(), + WORD(type), + DNS_QUERY_BYPASS_CACHE, + nullptr, + reinterpret_cast<PDNS_RECORD*>(&queryResults), + nullptr); + + if (ec) { + throw DBException(ErrorCodes::DNSHostNotFound, + "Failed to look up service \""s + "\":"s + errnoWithDescription(ec)); + } + return DNSResponse{queryResults}; + } +}; +} // namespace +} // namespace dns +} // namespace mongo |