diff options
Diffstat (limited to 'src/mongo/util/net')
26 files changed, 4709 insertions, 4692 deletions
diff --git a/src/mongo/util/net/hostandport.cpp b/src/mongo/util/net/hostandport.cpp index 25455bd7b6f..ea1d0f93467 100644 --- a/src/mongo/util/net/hostandport.cpp +++ b/src/mongo/util/net/hostandport.cpp @@ -41,141 +41,134 @@ namespace mongo { - StatusWith<HostAndPort> HostAndPort::parse(StringData text) { - HostAndPort result; - Status status = result.initialize(text); - if (!status.isOK()) { - return StatusWith<HostAndPort>(status); - } - return StatusWith<HostAndPort>(result); +StatusWith<HostAndPort> HostAndPort::parse(StringData text) { + HostAndPort result; + Status status = result.initialize(text); + if (!status.isOK()) { + return StatusWith<HostAndPort>(status); } + return StatusWith<HostAndPort>(result); +} - HostAndPort::HostAndPort() : _port(-1) {} +HostAndPort::HostAndPort() : _port(-1) {} - HostAndPort::HostAndPort(StringData text) { - uassertStatusOK(initialize(text)); - } +HostAndPort::HostAndPort(StringData text) { + uassertStatusOK(initialize(text)); +} - HostAndPort::HostAndPort(const std::string& h, int p) : _host(h), _port(p) {} +HostAndPort::HostAndPort(const std::string& h, int p) : _host(h), _port(p) {} - bool HostAndPort::operator<(const HostAndPort& r) const { - const int cmp = host().compare(r.host()); - if (cmp) - return cmp < 0; - return port() < r.port(); - } +bool HostAndPort::operator<(const HostAndPort& r) const { + const int cmp = host().compare(r.host()); + if (cmp) + return cmp < 0; + return port() < r.port(); +} - bool HostAndPort::operator==(const HostAndPort& r) const { - return host() == r.host() && port() == r.port(); - } +bool HostAndPort::operator==(const HostAndPort& r) const { + return host() == r.host() && port() == r.port(); +} - int HostAndPort::port() const { - if (hasPort()) - return _port; - return ServerGlobalParams::DefaultDBPort; - } +int HostAndPort::port() const { + if (hasPort()) + return _port; + return ServerGlobalParams::DefaultDBPort; +} - bool HostAndPort::isLocalHost() const { - return ( _host == "localhost" - || str::startsWith(_host.c_str(), "127.") - || _host == "::1" - || _host == "anonymous unix socket" - || _host.c_str()[0] == '/' // unix socket - ); - } +bool HostAndPort::isLocalHost() const { + return (_host == "localhost" || str::startsWith(_host.c_str(), "127.") || _host == "::1" || + _host == "anonymous unix socket" || _host.c_str()[0] == '/' // unix socket + ); +} - std::string HostAndPort::toString() const { - StringBuilder ss; - append( ss ); - return ss.str(); - } +std::string HostAndPort::toString() const { + StringBuilder ss; + append(ss); + return ss.str(); +} - void HostAndPort::append(StringBuilder& ss) const { - // wrap ipv6 addresses in []s for roundtrip-ability - if (host().find(':') != std::string::npos) { - ss << '[' << host() << ']'; - } - else { - ss << host(); - } - ss << ':' << port(); +void HostAndPort::append(StringBuilder& ss) const { + // wrap ipv6 addresses in []s for roundtrip-ability + if (host().find(':') != std::string::npos) { + ss << '[' << host() << ']'; + } else { + ss << host(); } + ss << ':' << port(); +} - bool HostAndPort::empty() const { - return _host.empty() && _port < 0; - } +bool HostAndPort::empty() const { + return _host.empty() && _port < 0; +} - Status HostAndPort::initialize(StringData s) { - size_t colonPos = s.rfind(':'); - StringData hostPart = s.substr(0, colonPos); - - // handle ipv6 hostPart (which we require to be wrapped in []s) - const size_t openBracketPos = s.find('['); - const size_t closeBracketPos = s.find(']'); - if (openBracketPos != std::string::npos) { - if (openBracketPos != 0) { - return Status(ErrorCodes::FailedToParse, - str::stream() << "'[' present, but not first character in " - << s.toString()); - } - if (closeBracketPos == std::string::npos) { - return Status(ErrorCodes::FailedToParse, - str::stream() << "ipv6 address is missing closing ']' in hostname in " - << s.toString()); - } - - hostPart = s.substr(openBracketPos+1, closeBracketPos-openBracketPos-1); - // prevent accidental assignment of port to the value of the final portion of hostPart - if (colonPos < closeBracketPos) { - colonPos = std::string::npos; - } - else if (colonPos != closeBracketPos+1) { - return Status(ErrorCodes::FailedToParse, - str::stream() << "Extraneous characters between ']' and pre-port ':'" - << " in " << s.toString()); - } - } - else if (closeBracketPos != std::string::npos) { +Status HostAndPort::initialize(StringData s) { + size_t colonPos = s.rfind(':'); + StringData hostPart = s.substr(0, colonPos); + + // handle ipv6 hostPart (which we require to be wrapped in []s) + const size_t openBracketPos = s.find('['); + const size_t closeBracketPos = s.find(']'); + if (openBracketPos != std::string::npos) { + if (openBracketPos != 0) { return Status(ErrorCodes::FailedToParse, - str::stream() << "']' present without '[' in " << s.toString()); + str::stream() << "'[' present, but not first character in " + << s.toString()); } - else if (s.find(':') != colonPos) { + if (closeBracketPos == std::string::npos) { return Status(ErrorCodes::FailedToParse, - str::stream() << "More than one ':' detected. If this is an ipv6 address," - << " it needs to be surrounded by '[' and ']'; " + str::stream() << "ipv6 address is missing closing ']' in hostname in " << s.toString()); } - if (hostPart.empty()) { - return Status(ErrorCodes::FailedToParse, str::stream() << - "Empty host component parsing HostAndPort from \"" << - escape(s.toString()) << "\""); + hostPart = s.substr(openBracketPos + 1, closeBracketPos - openBracketPos - 1); + // prevent accidental assignment of port to the value of the final portion of hostPart + if (colonPos < closeBracketPos) { + colonPos = std::string::npos; + } else if (colonPos != closeBracketPos + 1) { + return Status(ErrorCodes::FailedToParse, + str::stream() << "Extraneous characters between ']' and pre-port ':'" + << " in " << s.toString()); } + } else if (closeBracketPos != std::string::npos) { + return Status(ErrorCodes::FailedToParse, + str::stream() << "']' present without '[' in " << s.toString()); + } else if (s.find(':') != colonPos) { + return Status(ErrorCodes::FailedToParse, + str::stream() << "More than one ':' detected. If this is an ipv6 address," + << " it needs to be surrounded by '[' and ']'; " + << s.toString()); + } + + if (hostPart.empty()) { + return Status(ErrorCodes::FailedToParse, + str::stream() << "Empty host component parsing HostAndPort from \"" + << escape(s.toString()) << "\""); + } - int port; - if (colonPos != std::string::npos) { - const StringData portPart = s.substr(colonPos + 1); - Status status = parseNumberFromStringWithBase(portPart, 10, &port); - if (!status.isOK()) { - return status; - } - if (port <= 0) { - return Status(ErrorCodes::FailedToParse, str::stream() << "Port number " << port << - " out of range parsing HostAndPort from \"" << escape(s.toString()) << - "\""); - } + int port; + if (colonPos != std::string::npos) { + const StringData portPart = s.substr(colonPos + 1); + Status status = parseNumberFromStringWithBase(portPart, 10, &port); + if (!status.isOK()) { + return status; } - else { - port = -1; + if (port <= 0) { + return Status(ErrorCodes::FailedToParse, + str::stream() << "Port number " << port + << " out of range parsing HostAndPort from \"" + << escape(s.toString()) << "\""); } - _host = hostPart.toString(); - _port = port; - return Status::OK(); + } else { + port = -1; } + _host = hostPart.toString(); + _port = port; + return Status::OK(); +} - std::ostream& operator<<(std::ostream& os, const HostAndPort& hp) { - return os << hp.toString(); - } +std::ostream& operator<<(std::ostream& os, const HostAndPort& hp) { + return os << hp.toString(); +} } // namespace mongo diff --git a/src/mongo/util/net/hostandport.h b/src/mongo/util/net/hostandport.h index 7ea3e84ddbf..6f7bd46b43c 100644 --- a/src/mongo/util/net/hostandport.h +++ b/src/mongo/util/net/hostandport.h @@ -34,93 +34,95 @@ #include "mongo/platform/hash_namespace.h" namespace mongo { - class Status; - class StringData; - template <typename T> class StatusWith; +class Status; +class StringData; +template <typename T> +class StatusWith; + +/** + * Name of a process on the network. + * + * Composed of some name component, followed optionally by a colon and a numeric port. The name + * might be an IPv4 or IPv6 address or a relative or fully qualified host name, or an absolute + * path to a unix socket. + */ +struct HostAndPort { + /** + * Parses "text" to produce a HostAndPort. Returns either that or an error + * status describing the parse failure. + */ + static StatusWith<HostAndPort> parse(StringData text); /** - * Name of a process on the network. + * Construct an empty/invalid HostAndPort. + */ + HostAndPort(); + + /** + * Constructs a HostAndPort by parsing "text" of the form hostname[:portnumber] + * Throws an AssertionException if bad config std::string or bad port #. + */ + explicit HostAndPort(StringData text); + + /** + * Constructs a HostAndPort with the hostname "h" and port "p". * - * Composed of some name component, followed optionally by a colon and a numeric port. The name - * might be an IPv4 or IPv6 address or a relative or fully qualified host name, or an absolute - * path to a unix socket. + * If "p" is -1, port() returns ServerGlobalParams::DefaultDBPort. + */ + HostAndPort(const std::string& h, int p); + + /** + * (Re-)initializes this HostAndPort by parsing "s". Returns + * Status::OK on success. The state of this HostAndPort is unspecified + * after initialize() returns a non-OK status, though it is safe to + * assign to it or re-initialize it. */ - struct HostAndPort { - - /** - * Parses "text" to produce a HostAndPort. Returns either that or an error - * status describing the parse failure. - */ - static StatusWith<HostAndPort> parse(StringData text); - - /** - * Construct an empty/invalid HostAndPort. - */ - HostAndPort(); - - /** - * Constructs a HostAndPort by parsing "text" of the form hostname[:portnumber] - * Throws an AssertionException if bad config std::string or bad port #. - */ - explicit HostAndPort(StringData text); - - /** - * Constructs a HostAndPort with the hostname "h" and port "p". - * - * If "p" is -1, port() returns ServerGlobalParams::DefaultDBPort. - */ - HostAndPort(const std::string& h, int p); - - /** - * (Re-)initializes this HostAndPort by parsing "s". Returns - * Status::OK on success. The state of this HostAndPort is unspecified - * after initialize() returns a non-OK status, though it is safe to - * assign to it or re-initialize it. - */ - Status initialize(StringData s); - - bool operator<(const HostAndPort& r) const; - bool operator==(const HostAndPort& r) const; - bool operator!=(const HostAndPort& r) const { return !(*this == r); } - - /** - * Returns true if the hostname looks localhost-y. - * - * TODO: Make a more rigorous implementation, perhaps elsewhere in - * the networking library. - */ - bool isLocalHost() const; - - /** - * Returns a string representation of "host:port". - */ - std::string toString() const; - - /** - * Like toString(), above, but writes to "ss", instead. - */ - void append( StringBuilder& ss ) const; - - /** - * Returns true if this object represents no valid HostAndPort. - */ - bool empty() const; - - const std::string& host() const { - return _host; - } - int port() const; - - bool hasPort() const { - return _port >= 0; - } - - private: - std::string _host; - int _port; // -1 indicates unspecified - }; - - std::ostream& operator<<(std::ostream& os, const HostAndPort& hp); + Status initialize(StringData s); + + bool operator<(const HostAndPort& r) const; + bool operator==(const HostAndPort& r) const; + bool operator!=(const HostAndPort& r) const { + return !(*this == r); + } + + /** + * Returns true if the hostname looks localhost-y. + * + * TODO: Make a more rigorous implementation, perhaps elsewhere in + * the networking library. + */ + bool isLocalHost() const; + + /** + * Returns a string representation of "host:port". + */ + std::string toString() const; + + /** + * Like toString(), above, but writes to "ss", instead. + */ + void append(StringBuilder& ss) const; + + /** + * Returns true if this object represents no valid HostAndPort. + */ + bool empty() const; + + const std::string& host() const { + return _host; + } + int port() const; + + bool hasPort() const { + return _port >= 0; + } + +private: + std::string _host; + int _port; // -1 indicates unspecified +}; + +std::ostream& operator<<(std::ostream& os, const HostAndPort& hp); } // namespace mongo diff --git a/src/mongo/util/net/hostandport_test.cpp b/src/mongo/util/net/hostandport_test.cpp index 38f2192541e..c325a7d111a 100644 --- a/src/mongo/util/net/hostandport_test.cpp +++ b/src/mongo/util/net/hostandport_test.cpp @@ -33,83 +33,81 @@ namespace mongo { namespace { - TEST(HostAndPort, BasicLessThanComparison) { - // Not less than self. - ASSERT_FALSE(HostAndPort("a", 1) < HostAndPort("a", 1)); +TEST(HostAndPort, BasicLessThanComparison) { + // Not less than self. + ASSERT_FALSE(HostAndPort("a", 1) < HostAndPort("a", 1)); - // Lex order by name. - ASSERT_LESS_THAN(HostAndPort("a", 1), HostAndPort("b", 1)); - ASSERT_FALSE(HostAndPort("b", 1) < HostAndPort("a", 1)); + // Lex order by name. + ASSERT_LESS_THAN(HostAndPort("a", 1), HostAndPort("b", 1)); + ASSERT_FALSE(HostAndPort("b", 1) < HostAndPort("a", 1)); - // Then, order by port number. - ASSERT_LESS_THAN(HostAndPort("a", 1), HostAndPort("a", 2)); - ASSERT_FALSE(HostAndPort("a", 2) < HostAndPort("a", 1)); - } + // Then, order by port number. + ASSERT_LESS_THAN(HostAndPort("a", 1), HostAndPort("a", 2)); + ASSERT_FALSE(HostAndPort("a", 2) < HostAndPort("a", 1)); +} - TEST(HostAndPort, BasicEquality) { - // Comparison on host field - ASSERT_EQUALS(HostAndPort("a", 1), HostAndPort("a", 1)); - ASSERT_FALSE(HostAndPort("b", 1) == HostAndPort("a", 1)); - ASSERT_FALSE(HostAndPort("a", 1) != HostAndPort("a", 1)); - ASSERT_NOT_EQUALS(HostAndPort("b", 1), HostAndPort("a", 1)); +TEST(HostAndPort, BasicEquality) { + // Comparison on host field + ASSERT_EQUALS(HostAndPort("a", 1), HostAndPort("a", 1)); + ASSERT_FALSE(HostAndPort("b", 1) == HostAndPort("a", 1)); + ASSERT_FALSE(HostAndPort("a", 1) != HostAndPort("a", 1)); + ASSERT_NOT_EQUALS(HostAndPort("b", 1), HostAndPort("a", 1)); - // Comparison on port field - ASSERT_FALSE(HostAndPort("a", 1) == HostAndPort("a", 2)); - ASSERT_NOT_EQUALS(HostAndPort("a", 1), HostAndPort("a", 2)); - } + // Comparison on port field + ASSERT_FALSE(HostAndPort("a", 1) == HostAndPort("a", 2)); + ASSERT_NOT_EQUALS(HostAndPort("a", 1), HostAndPort("a", 2)); +} - TEST(HostAndPort, ImplicitPortSelection) { - ASSERT_EQUALS(HostAndPort("a", -1), - HostAndPort("a", int(ServerGlobalParams::DefaultDBPort))); - ASSERT_EQUALS(int(ServerGlobalParams::DefaultDBPort), HostAndPort("a", -1).port()); - ASSERT_FALSE(HostAndPort("a", -1).empty()); - } +TEST(HostAndPort, ImplicitPortSelection) { + ASSERT_EQUALS(HostAndPort("a", -1), HostAndPort("a", int(ServerGlobalParams::DefaultDBPort))); + ASSERT_EQUALS(int(ServerGlobalParams::DefaultDBPort), HostAndPort("a", -1).port()); + ASSERT_FALSE(HostAndPort("a", -1).empty()); +} - TEST(HostAndPort, ConstructorParsing) { - ASSERT_THROWS(HostAndPort(""), AssertionException); - ASSERT_THROWS(HostAndPort("a:"), AssertionException); - ASSERT_THROWS(HostAndPort("a:0xa"), AssertionException); - ASSERT_THROWS(HostAndPort(":123"), AssertionException); - ASSERT_THROWS(HostAndPort("[124d:"), AssertionException); - ASSERT_THROWS(HostAndPort("[124d:]asdf:34"), AssertionException); - ASSERT_THROWS(HostAndPort("frim[124d:]:34"), AssertionException); - ASSERT_THROWS(HostAndPort("[124d:]12:34"), AssertionException); - ASSERT_THROWS(HostAndPort("124d:12:34"), AssertionException); +TEST(HostAndPort, ConstructorParsing) { + ASSERT_THROWS(HostAndPort(""), AssertionException); + ASSERT_THROWS(HostAndPort("a:"), AssertionException); + ASSERT_THROWS(HostAndPort("a:0xa"), AssertionException); + ASSERT_THROWS(HostAndPort(":123"), AssertionException); + ASSERT_THROWS(HostAndPort("[124d:"), AssertionException); + ASSERT_THROWS(HostAndPort("[124d:]asdf:34"), AssertionException); + ASSERT_THROWS(HostAndPort("frim[124d:]:34"), AssertionException); + ASSERT_THROWS(HostAndPort("[124d:]12:34"), AssertionException); + ASSERT_THROWS(HostAndPort("124d:12:34"), AssertionException); - ASSERT_EQUALS(HostAndPort("abc"), HostAndPort("abc", -1)); - ASSERT_EQUALS(HostAndPort("abc.def:3421"), HostAndPort("abc.def", 3421)); - ASSERT_EQUALS(HostAndPort("[124d:]:34"), HostAndPort("124d:", 34)); - ASSERT_EQUALS(HostAndPort("[124d:efg]:34"), HostAndPort("124d:efg", 34)); - ASSERT_EQUALS(HostAndPort("[124d:]"), HostAndPort("124d:", -1)); - } + ASSERT_EQUALS(HostAndPort("abc"), HostAndPort("abc", -1)); + ASSERT_EQUALS(HostAndPort("abc.def:3421"), HostAndPort("abc.def", 3421)); + ASSERT_EQUALS(HostAndPort("[124d:]:34"), HostAndPort("124d:", 34)); + ASSERT_EQUALS(HostAndPort("[124d:efg]:34"), HostAndPort("124d:efg", 34)); + ASSERT_EQUALS(HostAndPort("[124d:]"), HostAndPort("124d:", -1)); +} - TEST(HostAndPort, StaticParseFunction) { - ASSERT_EQUALS(ErrorCodes::FailedToParse, HostAndPort::parse("").getStatus()); - ASSERT_EQUALS(ErrorCodes::FailedToParse, HostAndPort::parse("a:").getStatus()); - ASSERT_EQUALS(ErrorCodes::FailedToParse, HostAndPort::parse("a:0").getStatus()); - ASSERT_EQUALS(ErrorCodes::FailedToParse, HostAndPort::parse("a:0xa").getStatus()); - ASSERT_EQUALS(ErrorCodes::FailedToParse, HostAndPort::parse(":123").getStatus()); - ASSERT_EQUALS(ErrorCodes::FailedToParse, HostAndPort::parse("[124d:").getStatus()); - ASSERT_EQUALS(ErrorCodes::FailedToParse, HostAndPort::parse("[124d:]asdf:34").getStatus()); - ASSERT_EQUALS(ErrorCodes::FailedToParse, HostAndPort::parse("124d:asdf:34").getStatus()); - ASSERT_EQUALS(ErrorCodes::FailedToParse, HostAndPort::parse("1234:").getStatus()); - ASSERT_EQUALS(ErrorCodes::FailedToParse, HostAndPort::parse("[[124d]]").getStatus()); - ASSERT_EQUALS(ErrorCodes::FailedToParse, HostAndPort::parse("[[124d]:34]").getStatus()); +TEST(HostAndPort, StaticParseFunction) { + ASSERT_EQUALS(ErrorCodes::FailedToParse, HostAndPort::parse("").getStatus()); + ASSERT_EQUALS(ErrorCodes::FailedToParse, HostAndPort::parse("a:").getStatus()); + ASSERT_EQUALS(ErrorCodes::FailedToParse, HostAndPort::parse("a:0").getStatus()); + ASSERT_EQUALS(ErrorCodes::FailedToParse, HostAndPort::parse("a:0xa").getStatus()); + ASSERT_EQUALS(ErrorCodes::FailedToParse, HostAndPort::parse(":123").getStatus()); + ASSERT_EQUALS(ErrorCodes::FailedToParse, HostAndPort::parse("[124d:").getStatus()); + ASSERT_EQUALS(ErrorCodes::FailedToParse, HostAndPort::parse("[124d:]asdf:34").getStatus()); + ASSERT_EQUALS(ErrorCodes::FailedToParse, HostAndPort::parse("124d:asdf:34").getStatus()); + ASSERT_EQUALS(ErrorCodes::FailedToParse, HostAndPort::parse("1234:").getStatus()); + ASSERT_EQUALS(ErrorCodes::FailedToParse, HostAndPort::parse("[[124d]]").getStatus()); + ASSERT_EQUALS(ErrorCodes::FailedToParse, HostAndPort::parse("[[124d]:34]").getStatus()); - ASSERT_EQUALS(unittest::assertGet(HostAndPort::parse("abc")), HostAndPort("abc", -1)); - ASSERT_EQUALS(unittest::assertGet(HostAndPort::parse("abc.def:3421")), - HostAndPort("abc.def", 3421)); - ASSERT_EQUALS(unittest::assertGet(HostAndPort::parse("[243:1bc]:21")), - HostAndPort("243:1bc", 21)); - } + ASSERT_EQUALS(unittest::assertGet(HostAndPort::parse("abc")), HostAndPort("abc", -1)); + ASSERT_EQUALS(unittest::assertGet(HostAndPort::parse("abc.def:3421")), + HostAndPort("abc.def", 3421)); + ASSERT_EQUALS(unittest::assertGet(HostAndPort::parse("[243:1bc]:21")), + HostAndPort("243:1bc", 21)); +} - TEST(HostAndPort, RoundTripAbility) { - ASSERT_EQUALS(HostAndPort("abc"), HostAndPort(HostAndPort("abc").toString())); - ASSERT_EQUALS(HostAndPort("abc.def:3421"), - HostAndPort(HostAndPort("abc.def:3421").toString())); - ASSERT_EQUALS(HostAndPort("[124d:]:34"), HostAndPort(HostAndPort("[124d:]:34").toString())); - ASSERT_EQUALS(HostAndPort("[124d:]"), HostAndPort(HostAndPort("[124d:]").toString())); - } +TEST(HostAndPort, RoundTripAbility) { + ASSERT_EQUALS(HostAndPort("abc"), HostAndPort(HostAndPort("abc").toString())); + ASSERT_EQUALS(HostAndPort("abc.def:3421"), HostAndPort(HostAndPort("abc.def:3421").toString())); + ASSERT_EQUALS(HostAndPort("[124d:]:34"), HostAndPort(HostAndPort("[124d:]:34").toString())); + ASSERT_EQUALS(HostAndPort("[124d:]"), HostAndPort(HostAndPort("[124d:]").toString())); +} } // namespace } // namespace mongo diff --git a/src/mongo/util/net/httpclient.cpp b/src/mongo/util/net/httpclient.cpp index 4f7e9260555..a3aa63651da 100644 --- a/src/mongo/util/net/httpclient.cpp +++ b/src/mongo/util/net/httpclient.cpp @@ -41,158 +41,156 @@ namespace mongo { - using std::string; - using std::stringstream; +using std::string; +using std::stringstream; - //#define HD(x) cout << x << endl; +//#define HD(x) cout << x << endl; #define HD(x) - int HttpClient::get( const std::string& url , Result * result ) { - return _go( "GET" , url , 0 , result ); +int HttpClient::get(const std::string& url, Result* result) { + return _go("GET", url, 0, result); +} + +int HttpClient::post(const std::string& url, const std::string& data, Result* result) { + return _go("POST", url, data.c_str(), result); +} + +int HttpClient::_go(const char* command, string url, const char* body, Result* result) { + bool ssl = false; + if (url.find("https://") == 0) { + ssl = true; + url = url.substr(8); + } else { + uassert(10271, "invalid url", url.find("http://") == 0); + url = url.substr(7); } - int HttpClient::post( const std::string& url , const std::string& data , Result * result ) { - return _go( "POST" , url , data.c_str() , result ); + string host, path; + if (url.find("/") == string::npos) { + host = url; + path = "/"; + } else { + host = url.substr(0, url.find("/")); + path = url.substr(url.find("/")); } - int HttpClient::_go( const char * command , string url , const char * body , Result * result ) { - bool ssl = false; - if ( url.find( "https://" ) == 0 ) { - ssl = true; - url = url.substr( 8 ); - } - else { - uassert( 10271 , "invalid url" , url.find( "http://" ) == 0 ); - url = url.substr( 7 ); - } - string host , path; - if ( url.find( "/" ) == string::npos ) { - host = url; - path = "/"; - } - else { - host = url.substr( 0 , url.find( "/" ) ); - path = url.substr( url.find( "/" ) ); - } + HD("host [" << host << "]"); + HD("path [" << path << "]"); + string server = host; + int port = ssl ? 443 : 80; - HD( "host [" << host << "]" ); - HD( "path [" << path << "]" ); - - string server = host; - int port = ssl ? 443 : 80; + string::size_type idx = host.find(":"); + if (idx != string::npos) { + server = host.substr(0, idx); + string t = host.substr(idx + 1); + port = atoi(t.c_str()); + } - string::size_type idx = host.find( ":" ); - if ( idx != string::npos ) { - server = host.substr( 0 , idx ); - string t = host.substr( idx + 1 ); - port = atoi( t.c_str() ); + HD("server [" << server << "]"); + HD("port [" << port << "]"); + + string req; + { + stringstream ss; + ss << command << " " << path << " HTTP/1.1\r\n"; + ss << "Host: " << host << "\r\n"; + ss << "Connection: Close\r\n"; + ss << "User-Agent: mongodb http client\r\n"; + if (body) { + ss << "Content-Length: " << strlen(body) << "\r\n"; } - - HD( "server [" << server << "]" ); - HD( "port [" << port << "]" ); - - string req; - { - stringstream ss; - ss << command << " " << path << " HTTP/1.1\r\n"; - ss << "Host: " << host << "\r\n"; - ss << "Connection: Close\r\n"; - ss << "User-Agent: mongodb http client\r\n"; - if ( body ) { - ss << "Content-Length: " << strlen( body ) << "\r\n"; - } - ss << "\r\n"; - if ( body ) { - ss << body; - } - - req = ss.str(); + ss << "\r\n"; + if (body) { + ss << body; } - SockAddr addr( server.c_str() , port ); - uassert( 15000 , "server socket addr is invalid" , addr.isValid() ); - HD( "addr: " << addr.toString() ); + req = ss.str(); + } - Socket sock; - if ( ! sock.connect( addr ) ) - return -1; - - if ( ssl ) { + SockAddr addr(server.c_str(), port); + uassert(15000, "server socket addr is invalid", addr.isValid()); + HD("addr: " << addr.toString()); + + Socket sock; + if (!sock.connect(addr)) + return -1; + + if (ssl) { #ifdef MONGO_CONFIG_SSL - // pointer to global singleton instance - SSLManagerInterface* mgr = getSSLManager(); + // pointer to global singleton instance + SSLManagerInterface* mgr = getSSLManager(); - sock.secure(mgr, ""); + sock.secure(mgr, ""); #else - uasserted( 15862 , "no ssl support" ); + uasserted(15862, "no ssl support"); #endif - } - - { - const char * out = req.c_str(); - int toSend = req.size(); - sock.send( out , toSend, "_go" ); - } - - char buf[4097]; - int got = sock.unsafe_recv( buf , 4096 ); - buf[got] = 0; - - int rc; - char version[32]; - verify( sscanf( buf , "%s %d" , version , &rc ) == 2 ); - HD( "rc: " << rc ); - - StringBuilder sb; - if ( result ) - sb << buf; - - // SERVER-8864, unsafe_recv will throw when recv returns 0 indicating closed socket. - try { - while ( ( got = sock.unsafe_recv( buf , 4096 ) ) > 0) { - buf[got] = 0; - if ( result ) - sb << buf; - } - } catch (const SocketException&) {} + } + { + const char* out = req.c_str(); + int toSend = req.size(); + sock.send(out, toSend, "_go"); + } - if ( result ) { - result->_init( rc , sb.str() ); + char buf[4097]; + int got = sock.unsafe_recv(buf, 4096); + buf[got] = 0; + + int rc; + char version[32]; + verify(sscanf(buf, "%s %d", version, &rc) == 2); + HD("rc: " << rc); + + StringBuilder sb; + if (result) + sb << buf; + + // SERVER-8864, unsafe_recv will throw when recv returns 0 indicating closed socket. + try { + while ((got = sock.unsafe_recv(buf, 4096)) > 0) { + buf[got] = 0; + if (result) + sb << buf; } + } catch (const SocketException&) { + } - return rc; + + if (result) { + result->_init(rc, sb.str()); } - void HttpClient::Result::_init( int code , string entire ) { - _code = code; - _entireResponse = entire; + return rc; +} - while ( true ) { - size_t i = entire.find( '\n' ); - if ( i == string::npos ) { - // invalid - break; - } +void HttpClient::Result::_init(int code, string entire) { + _code = code; + _entireResponse = entire; - string h = entire.substr( 0 , i ); - entire = entire.substr( i + 1 ); + while (true) { + size_t i = entire.find('\n'); + if (i == string::npos) { + // invalid + break; + } - if ( h.size() && h[h.size()-1] == '\r' ) - h = h.substr( 0 , h.size() - 1 ); + string h = entire.substr(0, i); + entire = entire.substr(i + 1); - if ( h.size() == 0 ) - break; + if (h.size() && h[h.size() - 1] == '\r') + h = h.substr(0, h.size() - 1); - i = h.find( ':' ); - if ( i != string::npos ) - _headers[h.substr(0,i)] = str::ltrim(h.substr(i+1)); - } + if (h.size() == 0) + break; - _body = entire; + i = h.find(':'); + if (i != string::npos) + _headers[h.substr(0, i)] = str::ltrim(h.substr(i + 1)); } + _body = entire; +} } diff --git a/src/mongo/util/net/httpclient.h b/src/mongo/util/net/httpclient.h index 18c66aee25c..8c7c1af8ce6 100644 --- a/src/mongo/util/net/httpclient.h +++ b/src/mongo/util/net/httpclient.h @@ -37,52 +37,51 @@ namespace mongo { - class HttpClient { - MONGO_DISALLOW_COPYING(HttpClient); - public: - - typedef std::map<std::string,std::string> Headers; +class HttpClient { + MONGO_DISALLOW_COPYING(HttpClient); - class Result { - public: - Result() {} +public: + typedef std::map<std::string, std::string> Headers; - const std::string& getEntireResponse() const { - return _entireResponse; - } + class Result { + public: + Result() {} - Headers getHeaders() const { - return _headers; - } + const std::string& getEntireResponse() const { + return _entireResponse; + } - const std::string& getBody() const { - return _body; - } + Headers getHeaders() const { + return _headers; + } - private: + const std::string& getBody() const { + return _body; + } - void _init( int code , std::string entire ); + private: + void _init(int code, std::string entire); - int _code; - std::string _entireResponse; + int _code; + std::string _entireResponse; - Headers _headers; - std::string _body; + Headers _headers; + std::string _body; - friend class HttpClient; - }; + friend class HttpClient; + }; - /** - * @return response code - */ - int get( const std::string& url , Result * result = 0 ); + /** + * @return response code + */ + int get(const std::string& url, Result* result = 0); - /** - * @return response code - */ - int post( const std::string& url , const std::string& body , Result * result = 0 ); + /** + * @return response code + */ + int post(const std::string& url, const std::string& body, Result* result = 0); - private: - int _go( const char * command , std::string url , const char * body , Result * result ); - }; +private: + int _go(const char* command, std::string url, const char* body, Result* result); +}; } diff --git a/src/mongo/util/net/listen.cpp b/src/mongo/util/net/listen.cpp index 249fe6f878e..be77c6fc485 100644 --- a/src/mongo/util/net/listen.cpp +++ b/src/mongo/util/net/listen.cpp @@ -46,11 +46,11 @@ #ifndef _WIN32 -# ifndef __sun -# include <ifaddrs.h> -# endif -# include <sys/resource.h> -# include <sys/stat.h> +#ifndef __sun +#include <ifaddrs.h> +#endif +#include <sys/resource.h> +#include <sys/stat.h> #include <sys/types.h> #include <sys/socket.h> @@ -61,7 +61,7 @@ #include <errno.h> #include <netdb.h> #ifdef __OpenBSD__ -# include <sys/uio.h> +#include <sys/uio.h> #endif #else @@ -74,594 +74,593 @@ namespace mongo { - using std::shared_ptr; - using std::endl; - using std::string; - using std::vector; +using std::shared_ptr; +using std::endl; +using std::string; +using std::vector; - // ----- Listener ------- +// ----- Listener ------- - const Listener* Listener::_timeTracker; +const Listener* Listener::_timeTracker; - vector<SockAddr> ipToAddrs(const char* ips, int port, bool useUnixSockets) { - vector<SockAddr> out; - if (*ips == '\0') { - out.push_back(SockAddr("0.0.0.0", port)); // IPv4 all +vector<SockAddr> ipToAddrs(const char* ips, int port, bool useUnixSockets) { + vector<SockAddr> out; + if (*ips == '\0') { + out.push_back(SockAddr("0.0.0.0", port)); // IPv4 all - if (IPv6Enabled()) - out.push_back(SockAddr("::", port)); // IPv6 all + if (IPv6Enabled()) + out.push_back(SockAddr("::", port)); // IPv6 all #ifndef _WIN32 - if (useUnixSockets) - out.push_back(SockAddr(makeUnixSockPath(port).c_str(), port)); // Unix socket + if (useUnixSockets) + out.push_back(SockAddr(makeUnixSockPath(port).c_str(), port)); // Unix socket #endif - return out; - } + return out; + } - while(*ips) { - string ip; - const char * comma = strchr(ips, ','); - if (comma) { - ip = string(ips, comma - ips); - ips = comma + 1; - } - else { - ip = string(ips); - ips = ""; - } + while (*ips) { + string ip; + const char* comma = strchr(ips, ','); + if (comma) { + ip = string(ips, comma - ips); + ips = comma + 1; + } else { + ip = string(ips); + ips = ""; + } - SockAddr sa(ip.c_str(), port); - out.push_back(sa); + SockAddr sa(ip.c_str(), port); + out.push_back(sa); #ifndef _WIN32 - if (sa.isValid() && useUnixSockets && - (sa.getAddr() == "127.0.0.1" || sa.getAddr() == "0.0.0.0")) // only IPv4 - out.push_back(SockAddr(makeUnixSockPath(port).c_str(), port)); + if (sa.isValid() && useUnixSockets && + (sa.getAddr() == "127.0.0.1" || sa.getAddr() == "0.0.0.0")) // only IPv4 + out.push_back(SockAddr(makeUnixSockPath(port).c_str(), port)); #endif - } - return out; - } - - Listener::Listener(const string& name, const string &ip, int port, bool logConnect ) - : _port(port), _name(name), _ip(ip), _setupSocketsSuccessful(false), - _logConnect(logConnect), _elapsedTime(0) { + return out; +} + +Listener::Listener(const string& name, const string& ip, int port, bool logConnect) + : _port(port), + _name(name), + _ip(ip), + _setupSocketsSuccessful(false), + _logConnect(logConnect), + _elapsedTime(0) { #ifdef MONGO_CONFIG_SSL - _ssl = getSSLManager(); + _ssl = getSSLManager(); #endif - } - - Listener::~Listener() { - if ( _timeTracker == this ) - _timeTracker = 0; - } +} + +Listener::~Listener() { + if (_timeTracker == this) + _timeTracker = 0; +} - void Listener::setupSockets() { - checkTicketNumbers(); +void Listener::setupSockets() { + checkTicketNumbers(); #if !defined(_WIN32) - _mine = ipToAddrs(_ip.c_str(), _port, (!serverGlobalParams.noUnixSocket && - useUnixSockets())); + _mine = ipToAddrs(_ip.c_str(), _port, (!serverGlobalParams.noUnixSocket && useUnixSockets())); #else - _mine = ipToAddrs(_ip.c_str(), _port, false); + _mine = ipToAddrs(_ip.c_str(), _port, false); #endif - for (std::vector<SockAddr>::const_iterator it=_mine.begin(), end=_mine.end(); - it != end; - ++it) { - - const SockAddr& me = *it; + for (std::vector<SockAddr>::const_iterator it = _mine.begin(), end = _mine.end(); it != end; + ++it) { + const SockAddr& me = *it; - if (!me.isValid()) { - error() << "listen(): socket is invalid." << endl; - return; - } + if (!me.isValid()) { + error() << "listen(): socket is invalid." << endl; + return; + } - SOCKET sock = ::socket(me.getType(), SOCK_STREAM, 0); - ScopeGuard socketGuard = MakeGuard(&closesocket, sock); - massert( 15863 , str::stream() << "listen(): invalid socket? " << errnoWithDescription() , sock >= 0 ); + SOCKET sock = ::socket(me.getType(), SOCK_STREAM, 0); + ScopeGuard socketGuard = MakeGuard(&closesocket, sock); + massert(15863, + str::stream() << "listen(): invalid socket? " << errnoWithDescription(), + sock >= 0); - if (me.getType() == AF_UNIX) { + if (me.getType() == AF_UNIX) { #if !defined(_WIN32) - if (unlink(me.getAddr().c_str()) == -1) { - if (errno != ENOENT) { - error() << "Failed to unlink socket file " << me << " " - << errnoWithDescription(errno); - fassertFailedNoTrace(28578); - } + if (unlink(me.getAddr().c_str()) == -1) { + if (errno != ENOENT) { + error() << "Failed to unlink socket file " << me << " " + << errnoWithDescription(errno); + fassertFailedNoTrace(28578); } -#endif - } - else if (me.getType() == AF_INET6) { - // IPv6 can also accept IPv4 connections as mapped addresses (::ffff:127.0.0.1) - // That causes a conflict if we don't do set it to IPV6_ONLY - const int one = 1; - setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY, (const char*) &one, sizeof(one)); } +#endif + } else if (me.getType() == AF_INET6) { + // IPv6 can also accept IPv4 connections as mapped addresses (::ffff:127.0.0.1) + // That causes a conflict if we don't do set it to IPV6_ONLY + const int one = 1; + setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY, (const char*)&one, sizeof(one)); + } #if !defined(_WIN32) - { - const int one = 1; - if ( setsockopt( sock , SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one)) < 0 ) - log() << "Failed to set socket opt, SO_REUSEADDR" << endl; - } + { + const int one = 1; + if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one)) < 0) + log() << "Failed to set socket opt, SO_REUSEADDR" << endl; + } #endif - if ( ::bind(sock, me.raw(), me.addressSize) != 0 ) { - int x = errno; - error() << "listen(): bind() failed " << errnoWithDescription(x) << " for socket: " << me.toString() << endl; - if ( x == EADDRINUSE ) - error() << " addr already in use" << endl; - return; - } + if (::bind(sock, me.raw(), me.addressSize) != 0) { + int x = errno; + error() << "listen(): bind() failed " << errnoWithDescription(x) + << " for socket: " << me.toString() << endl; + if (x == EADDRINUSE) + error() << " addr already in use" << endl; + return; + } #if !defined(_WIN32) - if (me.getType() == AF_UNIX) { - if (chmod(me.getAddr().c_str(), serverGlobalParams.unixSocketPermissions) == -1) { - error() << "Failed to chmod socket file " << me << " " - << errnoWithDescription(errno); - fassertFailedNoTrace(28582); - } - ListeningSockets::get()->addPath( me.getAddr() ); + if (me.getType() == AF_UNIX) { + if (chmod(me.getAddr().c_str(), serverGlobalParams.unixSocketPermissions) == -1) { + error() << "Failed to chmod socket file " << me << " " + << errnoWithDescription(errno); + fassertFailedNoTrace(28582); } + ListeningSockets::get()->addPath(me.getAddr()); + } #endif - _socks.push_back(sock); - socketGuard.Dismiss(); - } - - _setupSocketsSuccessful = true; + _socks.push_back(sock); + socketGuard.Dismiss(); } - - + + _setupSocketsSuccessful = true; +} + + #if !defined(_WIN32) - void Listener::initAndListen() { - if (!_setupSocketsSuccessful) { +void Listener::initAndListen() { + if (!_setupSocketsSuccessful) { + return; + } + + SOCKET maxfd = 0; // needed for select() + for (unsigned i = 0; i < _socks.size(); i++) { + if (::listen(_socks[i], 128) != 0) { + error() << "listen(): listen() failed " << errnoWithDescription() << endl; return; } - SOCKET maxfd = 0; // needed for select() - for (unsigned i = 0; i < _socks.size(); i++) { - if (::listen(_socks[i], 128) != 0) { - error() << "listen(): listen() failed " << errnoWithDescription() << endl; - return; - } - - ListeningSockets::get()->add(_socks[i]); + ListeningSockets::get()->add(_socks[i]); - if (_socks[i] > maxfd) { - maxfd = _socks[i]; - } + if (_socks[i] > maxfd) { + maxfd = _socks[i]; } + } - if ( maxfd >= FD_SETSIZE ) { - error() << "socket " << maxfd << " is higher than " << FD_SETSIZE-1 << - "; not supported" << warnings; - return; - } + if (maxfd >= FD_SETSIZE) { + error() << "socket " << maxfd << " is higher than " << FD_SETSIZE - 1 << "; not supported" + << warnings; + return; + } #ifdef MONGO_CONFIG_SSL - _logListen(_port, _ssl); + _logListen(_port, _ssl); #else - _logListen(_port, false); + _logListen(_port, false); #endif - { - // Wake up any threads blocked in waitUntilListening() - stdx::lock_guard<stdx::mutex> lock(_readyMutex); - _ready = true; - _readyCondition.notify_all(); - } + { + // Wake up any threads blocked in waitUntilListening() + stdx::lock_guard<stdx::mutex> lock(_readyMutex); + _ready = true; + _readyCondition.notify_all(); + } - struct timeval maxSelectTime; - while ( ! inShutdown() ) { - fd_set fds[1]; - FD_ZERO(fds); - - for (vector<SOCKET>::iterator it=_socks.begin(), end=_socks.end(); it != end; ++it) { - FD_SET(*it, fds); - } + struct timeval maxSelectTime; + while (!inShutdown()) { + fd_set fds[1]; + FD_ZERO(fds); + + for (vector<SOCKET>::iterator it = _socks.begin(), end = _socks.end(); it != end; ++it) { + FD_SET(*it, fds); + } - maxSelectTime.tv_sec = 0; - maxSelectTime.tv_usec = 10000; - const int ret = select(maxfd+1, fds, NULL, NULL, &maxSelectTime); + maxSelectTime.tv_sec = 0; + maxSelectTime.tv_usec = 10000; + const int ret = select(maxfd + 1, fds, NULL, NULL, &maxSelectTime); - if (ret == 0) { + if (ret == 0) { #if defined(__linux__) - _elapsedTime += ( 10000 - maxSelectTime.tv_usec ) / 1000; + _elapsedTime += (10000 - maxSelectTime.tv_usec) / 1000; #else - _elapsedTime += 10; + _elapsedTime += 10; #endif - continue; - } + continue; + } - if (ret < 0) { - int x = errno; + if (ret < 0) { + int x = errno; #ifdef EINTR - if ( x == EINTR ) { - log() << "select() signal caught, continuing" << endl; - continue; - } -#endif - if ( ! inShutdown() ) - log() << "select() failure: ret=" << ret << " " << errnoWithDescription(x) << endl; - return; + if (x == EINTR) { + log() << "select() signal caught, continuing" << endl; + continue; } +#endif + if (!inShutdown()) + log() << "select() failure: ret=" << ret << " " << errnoWithDescription(x) << endl; + return; + } #if defined(__linux__) - _elapsedTime += std::max(ret, (int)(( 10000 - maxSelectTime.tv_usec ) / 1000)); + _elapsedTime += std::max(ret, (int)((10000 - maxSelectTime.tv_usec) / 1000)); #else - _elapsedTime += ret; // assume 1ms to grab connection. very rough + _elapsedTime += ret; // assume 1ms to grab connection. very rough #endif - for (vector<SOCKET>::iterator it=_socks.begin(), end=_socks.end(); it != end; ++it) { - if (! (FD_ISSET(*it, fds))) + for (vector<SOCKET>::iterator it = _socks.begin(), end = _socks.end(); it != end; ++it) { + if (!(FD_ISSET(*it, fds))) + continue; + SockAddr from; + int s = accept(*it, from.raw(), &from.addressSize); + if (s < 0) { + int x = errno; // so no global issues + if (x == EBADF) { + log() << "Port " << _port << " is no longer valid" << endl; + return; + } else if (x == ECONNABORTED) { + log() << "Connection on port " << _port << " aborted" << endl; continue; - SockAddr from; - int s = accept(*it, from.raw(), &from.addressSize); - if ( s < 0 ) { - int x = errno; // so no global issues - if (x == EBADF) { - log() << "Port " << _port << " is no longer valid" << endl; - return; - } - else if (x == ECONNABORTED) { - log() << "Connection on port " << _port << " aborted" << endl; - continue; - } - if ( x == 0 && inShutdown() ) { - return; // socket closed - } - if( !inShutdown() ) { - log() << "Listener: accept() returns " << s << " " << errnoWithDescription(x) << endl; - if (x == EMFILE || x == ENFILE) { - // Connection still in listen queue but we can't accept it yet - error() << "Out of file descriptors. Waiting one second before trying to accept more connections." << warnings; - sleepsecs(1); - } + } + if (x == 0 && inShutdown()) { + return; // socket closed + } + if (!inShutdown()) { + log() << "Listener: accept() returns " << s << " " << errnoWithDescription(x) + << endl; + if (x == EMFILE || x == ENFILE) { + // Connection still in listen queue but we can't accept it yet + error() << "Out of file descriptors. Waiting one second before trying to " + "accept more connections." << warnings; + sleepsecs(1); } - continue; } - if (from.getType() != AF_UNIX) - disableNagle(s); + continue; + } + if (from.getType() != AF_UNIX) + disableNagle(s); #ifdef SO_NOSIGPIPE - // ignore SIGPIPE signals on osx, to avoid process exit - const int one = 1; - setsockopt( s , SOL_SOCKET, SO_NOSIGPIPE, &one, sizeof(int)); + // ignore SIGPIPE signals on osx, to avoid process exit + const int one = 1; + setsockopt(s, SOL_SOCKET, SO_NOSIGPIPE, &one, sizeof(int)); #endif - long long myConnectionNumber = globalConnectionNumber.addAndFetch(1); + long long myConnectionNumber = globalConnectionNumber.addAndFetch(1); + + if (_logConnect && !serverGlobalParams.quiet) { + int conns = globalTicketHolder.used() + 1; + const char* word = (conns == 1 ? " connection" : " connections"); + log() << "connection accepted from " << from.toString() << " #" + << myConnectionNumber << " (" << conns << word << " now open)" << endl; + } - if (_logConnect && !serverGlobalParams.quiet) { - int conns = globalTicketHolder.used()+1; - const char* word = (conns == 1 ? " connection" : " connections"); - log() << "connection accepted from " << from.toString() << " #" << myConnectionNumber << " (" << conns << word << " now open)" << endl; - } - - std::shared_ptr<Socket> pnewSock( new Socket(s, from) ); + std::shared_ptr<Socket> pnewSock(new Socket(s, from)); #ifdef MONGO_CONFIG_SSL - if (_ssl) { - pnewSock->secureAccepted(_ssl); - } -#endif - accepted( pnewSock , myConnectionNumber ); + if (_ssl) { + pnewSock->secureAccepted(_ssl); } +#endif + accepted(pnewSock, myConnectionNumber); } } +} -#else - // Windows - - // Given a SOCKET, turns off nonblocking mode - static void disableNonblockingMode(SOCKET socket) { - unsigned long resultBuffer = 0; - unsigned long resultBufferBytesWritten = 0; - unsigned long newNonblockingEnabled = 0; - const int status = WSAIoctl(socket, - FIONBIO, - &newNonblockingEnabled, - sizeof(unsigned long), - &resultBuffer, - sizeof(resultBuffer), - &resultBufferBytesWritten, - NULL, - NULL); - if (status == SOCKET_ERROR) { - const int mongo_errno = WSAGetLastError(); - error() << "Windows WSAIoctl returned " << errnoWithDescription(mongo_errno) << endl; - fassertFailed(16726); - } +#else +// Windows + +// Given a SOCKET, turns off nonblocking mode +static void disableNonblockingMode(SOCKET socket) { + unsigned long resultBuffer = 0; + unsigned long resultBufferBytesWritten = 0; + unsigned long newNonblockingEnabled = 0; + const int status = WSAIoctl(socket, + FIONBIO, + &newNonblockingEnabled, + sizeof(unsigned long), + &resultBuffer, + sizeof(resultBuffer), + &resultBufferBytesWritten, + NULL, + NULL); + if (status == SOCKET_ERROR) { + const int mongo_errno = WSAGetLastError(); + error() << "Windows WSAIoctl returned " << errnoWithDescription(mongo_errno) << endl; + fassertFailed(16726); } +} - // RAII wrapper class to ensure we do not leak WSAEVENTs. - class EventHolder { - WSAEVENT _socketEventHandle; - public: - EventHolder() { - _socketEventHandle = WSACreateEvent(); - if (_socketEventHandle == WSA_INVALID_EVENT) { - const int mongo_errno = WSAGetLastError(); - error() << "Windows WSACreateEvent returned " << errnoWithDescription(mongo_errno) +// RAII wrapper class to ensure we do not leak WSAEVENTs. +class EventHolder { + WSAEVENT _socketEventHandle; + +public: + EventHolder() { + _socketEventHandle = WSACreateEvent(); + if (_socketEventHandle == WSA_INVALID_EVENT) { + const int mongo_errno = WSAGetLastError(); + error() << "Windows WSACreateEvent returned " << errnoWithDescription(mongo_errno) << endl; - fassertFailed(16728); - } + fassertFailed(16728); } - ~EventHolder() { - BOOL bstatus = WSACloseEvent(_socketEventHandle); - if (bstatus == FALSE) { - const int mongo_errno = WSAGetLastError(); - error() << "Windows WSACloseEvent returned " << errnoWithDescription(mongo_errno) + } + ~EventHolder() { + BOOL bstatus = WSACloseEvent(_socketEventHandle); + if (bstatus == FALSE) { + const int mongo_errno = WSAGetLastError(); + error() << "Windows WSACloseEvent returned " << errnoWithDescription(mongo_errno) << endl; - fassertFailed(16725); - } - } - WSAEVENT get() { - return _socketEventHandle; - } - }; - - void Listener::initAndListen() { - if (!_setupSocketsSuccessful) { - return; + fassertFailed(16725); } + } + WSAEVENT get() { + return _socketEventHandle; + } +}; - for (unsigned i = 0; i < _socks.size(); i++) { - if (::listen(_socks[i], 128) != 0) { - error() << "listen(): listen() failed " << errnoWithDescription() << endl; - return; - } +void Listener::initAndListen() { + if (!_setupSocketsSuccessful) { + return; + } - ListeningSockets::get()->add(_socks[i]); + for (unsigned i = 0; i < _socks.size(); i++) { + if (::listen(_socks[i], 128) != 0) { + error() << "listen(): listen() failed " << errnoWithDescription() << endl; + return; } + ListeningSockets::get()->add(_socks[i]); + } + #ifdef MONGO_CONFIG_SSL - _logListen(_port, _ssl); + _logListen(_port, _ssl); #else - _logListen(_port, false); + _logListen(_port, false); #endif - { - // Wake up any threads blocked in waitUntilListening() - stdx::lock_guard<stdx::mutex> lock(_readyMutex); - _ready = true; - _readyCondition.notify_all(); - } + { + // Wake up any threads blocked in waitUntilListening() + stdx::lock_guard<stdx::mutex> lock(_readyMutex); + _ready = true; + _readyCondition.notify_all(); + } + + OwnedPointerVector<EventHolder> eventHolders; + std::unique_ptr<WSAEVENT[]> events(new WSAEVENT[_socks.size()]); - OwnedPointerVector<EventHolder> eventHolders; - std::unique_ptr<WSAEVENT[]> events(new WSAEVENT[_socks.size()]); - - - // Populate events array with an event for each socket we are watching + + // Populate events array with an event for each socket we are watching + for (size_t count = 0; count < _socks.size(); ++count) { + EventHolder* ev(new EventHolder); + eventHolders.mutableVector().push_back(ev); + events[count] = ev->get(); + } + + while (!inShutdown()) { + // Turn on listening for accept-ready sockets for (size_t count = 0; count < _socks.size(); ++count) { - EventHolder* ev(new EventHolder); - eventHolders.mutableVector().push_back(ev); - events[count] = ev->get(); - } - - while ( ! inShutdown() ) { - // Turn on listening for accept-ready sockets - for (size_t count = 0; count < _socks.size(); ++count) { - int status = WSAEventSelect(_socks[count], events[count], FD_ACCEPT | FD_CLOSE); - if (status == SOCKET_ERROR) { - const int mongo_errno = WSAGetLastError(); - - // During shutdown, we may fail to listen on the socket if it has already - // been closed - if (inShutdown()) { - return; - } + int status = WSAEventSelect(_socks[count], events[count], FD_ACCEPT | FD_CLOSE); + if (status == SOCKET_ERROR) { + const int mongo_errno = WSAGetLastError(); - error() << "Windows WSAEventSelect returned " - << errnoWithDescription(mongo_errno) << endl; - fassertFailed(16727); + // During shutdown, we may fail to listen on the socket if it has already + // been closed + if (inShutdown()) { + return; } + + error() << "Windows WSAEventSelect returned " << errnoWithDescription(mongo_errno) + << endl; + fassertFailed(16727); } - - // Wait till one of them goes active, or we time out - DWORD result = WSAWaitForMultipleEvents(_socks.size(), - events.get(), - FALSE, // don't wait for all the events - 10, // timeout, in ms - FALSE); // do not allow I/O interruptions - if (result == WSA_WAIT_FAILED) { - const int mongo_errno = WSAGetLastError(); - error() << "Windows WSAWaitForMultipleEvents returned " - << errnoWithDescription(mongo_errno) << endl; - fassertFailed(16723); - } - - if (result == WSA_WAIT_TIMEOUT) { - _elapsedTime += 10; - continue; - } - _elapsedTime += 1; // assume 1ms to grab connection. very rough - - // Determine which socket is ready - DWORD eventIndex = result - WSA_WAIT_EVENT_0; - WSANETWORKEVENTS networkEvents; - // Extract event details, and clear event for next pass - int status = WSAEnumNetworkEvents(_socks[eventIndex], - events[eventIndex], - &networkEvents); - if (status == SOCKET_ERROR) { - const int mongo_errno = WSAGetLastError(); - error() << "Windows WSAEnumNetworkEvents returned " + } + + // Wait till one of them goes active, or we time out + DWORD result = WSAWaitForMultipleEvents(_socks.size(), + events.get(), + FALSE, // don't wait for all the events + 10, // timeout, in ms + FALSE); // do not allow I/O interruptions + if (result == WSA_WAIT_FAILED) { + const int mongo_errno = WSAGetLastError(); + error() << "Windows WSAWaitForMultipleEvents returned " << errnoWithDescription(mongo_errno) << endl; + fassertFailed(16723); + } + + if (result == WSA_WAIT_TIMEOUT) { + _elapsedTime += 10; + continue; + } + _elapsedTime += 1; // assume 1ms to grab connection. very rough + + // Determine which socket is ready + DWORD eventIndex = result - WSA_WAIT_EVENT_0; + WSANETWORKEVENTS networkEvents; + // Extract event details, and clear event for next pass + int status = WSAEnumNetworkEvents(_socks[eventIndex], events[eventIndex], &networkEvents); + if (status == SOCKET_ERROR) { + const int mongo_errno = WSAGetLastError(); + error() << "Windows WSAEnumNetworkEvents returned " << errnoWithDescription(mongo_errno) + << endl; + continue; + } + + if (networkEvents.lNetworkEvents & FD_CLOSE) { + log() << "listen socket closed" << endl; + break; + } + + if (!(networkEvents.lNetworkEvents & FD_ACCEPT)) { + error() << "Unexpected network event: " << networkEvents.lNetworkEvents << endl; + continue; + } + + int iec = networkEvents.iErrorCode[FD_ACCEPT_BIT]; + if (iec != 0) { + error() << "Windows socket accept did not work:" << errnoWithDescription(iec) << endl; + continue; + } + + status = WSAEventSelect(_socks[eventIndex], NULL, 0); + if (status == SOCKET_ERROR) { + const int mongo_errno = WSAGetLastError(); + error() << "Windows WSAEventSelect returned " << errnoWithDescription(mongo_errno) + << endl; + continue; + } + + disableNonblockingMode(_socks[eventIndex]); + + SockAddr from; + int s = accept(_socks[eventIndex], from.raw(), &from.addressSize); + if (s < 0) { + int x = errno; // so no global issues + if (x == EBADF) { + log() << "Port " << _port << " is no longer valid" << endl; continue; - } - - if (networkEvents.lNetworkEvents & FD_CLOSE) { - log() << "listen socket closed" << endl; - break; - } - - if (!(networkEvents.lNetworkEvents & FD_ACCEPT)) { - error() << "Unexpected network event: " << networkEvents.lNetworkEvents << endl; - continue; - } - - int iec = networkEvents.iErrorCode[FD_ACCEPT_BIT]; - if (iec != 0) { - error() << "Windows socket accept did not work:" << errnoWithDescription(iec) - << endl; + } else if (x == ECONNABORTED) { + log() << "Listener on port " << _port << " aborted" << endl; continue; } - - status = WSAEventSelect(_socks[eventIndex], NULL, 0); - if (status == SOCKET_ERROR) { - const int mongo_errno = WSAGetLastError(); - error() << "Windows WSAEventSelect returned " - << errnoWithDescription(mongo_errno) << endl; - continue; + if (x == 0 && inShutdown()) { + return; // socket closed } - - disableNonblockingMode(_socks[eventIndex]); - - SockAddr from; - int s = accept(_socks[eventIndex], from.raw(), &from.addressSize); - if ( s < 0 ) { - int x = errno; // so no global issues - if (x == EBADF) { - log() << "Port " << _port << " is no longer valid" << endl; - continue; - } - else if (x == ECONNABORTED) { - log() << "Listener on port " << _port << " aborted" << endl; - continue; - } - if ( x == 0 && inShutdown() ) { - return; // socket closed - } - if( !inShutdown() ) { - log() << "Listener: accept() returns " << s << " " - << errnoWithDescription(x) << endl; - if (x == EMFILE || x == ENFILE) { - // Connection still in listen queue but we can't accept it yet - error() << "Out of file descriptors. Waiting one second before" - " trying to accept more connections." << warnings; - sleepsecs(1); - } + if (!inShutdown()) { + log() << "Listener: accept() returns " << s << " " << errnoWithDescription(x) + << endl; + if (x == EMFILE || x == ENFILE) { + // Connection still in listen queue but we can't accept it yet + error() << "Out of file descriptors. Waiting one second before" + " trying to accept more connections." << warnings; + sleepsecs(1); } - continue; } - if (from.getType() != AF_UNIX) - disableNagle(s); + continue; + } + if (from.getType() != AF_UNIX) + disableNagle(s); - long long myConnectionNumber = globalConnectionNumber.addAndFetch(1); + long long myConnectionNumber = globalConnectionNumber.addAndFetch(1); - if (_logConnect && !serverGlobalParams.quiet) { - int conns = globalTicketHolder.used()+1; - const char* word = (conns == 1 ? " connection" : " connections"); - log() << "connection accepted from " << from.toString() << " #" << myConnectionNumber << " (" << conns << word << " now open)" << endl; - } - - std::shared_ptr<Socket> pnewSock( new Socket(s, from) ); + if (_logConnect && !serverGlobalParams.quiet) { + int conns = globalTicketHolder.used() + 1; + const char* word = (conns == 1 ? " connection" : " connections"); + log() << "connection accepted from " << from.toString() << " #" << myConnectionNumber + << " (" << conns << word << " now open)" << endl; + } + + std::shared_ptr<Socket> pnewSock(new Socket(s, from)); #ifdef MONGO_CONFIG_SSL - if (_ssl) { - pnewSock->secureAccepted(_ssl); - } -#endif - accepted( pnewSock , myConnectionNumber ); + if (_ssl) { + pnewSock->secureAccepted(_ssl); } +#endif + accepted(pnewSock, myConnectionNumber); } +} #endif - void Listener::_logListen( int port , bool ssl ) { - log() << _name << ( _name.size() ? " " : "" ) << "waiting for connections on port " << port << ( ssl ? " ssl" : "" ) << endl; - } +void Listener::_logListen(int port, bool ssl) { + log() << _name << (_name.size() ? " " : "") << "waiting for connections on port " << port + << (ssl ? " ssl" : "") << endl; +} - void Listener::waitUntilListening() const { - stdx::unique_lock<stdx::mutex> lock(_readyMutex); - while (!_ready) { - _readyCondition.wait(lock); - } +void Listener::waitUntilListening() const { + stdx::unique_lock<stdx::mutex> lock(_readyMutex); + while (!_ready) { + _readyCondition.wait(lock); } +} - void Listener::accepted(std::shared_ptr<Socket> psocket, long long connectionId ) { - MessagingPort* port = new MessagingPort(psocket); - port->setConnectionId( connectionId ); - acceptedMP( port ); - } - - void Listener::acceptedMP(MessagingPort *mp) { - verify(!"You must overwrite one of the accepted methods"); - } +void Listener::accepted(std::shared_ptr<Socket> psocket, long long connectionId) { + MessagingPort* port = new MessagingPort(psocket); + port->setConnectionId(connectionId); + acceptedMP(port); +} - // ----- ListeningSockets ------- +void Listener::acceptedMP(MessagingPort* mp) { + verify(!"You must overwrite one of the accepted methods"); +} - ListeningSockets* ListeningSockets::_instance = new ListeningSockets(); +// ----- ListeningSockets ------- - ListeningSockets* ListeningSockets::get() { - return _instance; - } +ListeningSockets* ListeningSockets::_instance = new ListeningSockets(); - // ------ connection ticket and control ------ +ListeningSockets* ListeningSockets::get() { + return _instance; +} + +// ------ connection ticket and control ------ - int getMaxConnections() { +int getMaxConnections() { #ifdef _WIN32 - return DEFAULT_MAX_CONN; + return DEFAULT_MAX_CONN; #else - struct rlimit limit; - verify( getrlimit(RLIMIT_NOFILE,&limit) == 0 ); + struct rlimit limit; + verify(getrlimit(RLIMIT_NOFILE, &limit) == 0); - int max = (int)(limit.rlim_cur * .8); + int max = (int)(limit.rlim_cur * .8); - LOG(1) << "fd limit" - << " hard:" << limit.rlim_max - << " soft:" << limit.rlim_cur - << " max conn: " << max - << endl; + LOG(1) << "fd limit" + << " hard:" << limit.rlim_max << " soft:" << limit.rlim_cur << " max conn: " << max + << endl; - return max; + return max; #endif - } +} - void Listener::checkTicketNumbers() { - int want = getMaxConnections(); - int current = globalTicketHolder.outof(); - if ( current != DEFAULT_MAX_CONN ) { - if ( current < want ) { - // they want fewer than they can handle - // which is fine - LOG(1) << " only allowing " << current << " connections" << endl; - return; - } - if ( current > want ) { - log() << " --maxConns too high, can only handle " << want << endl; - } +void Listener::checkTicketNumbers() { + int want = getMaxConnections(); + int current = globalTicketHolder.outof(); + if (current != DEFAULT_MAX_CONN) { + if (current < want) { + // they want fewer than they can handle + // which is fine + LOG(1) << " only allowing " << current << " connections" << endl; + return; + } + if (current > want) { + log() << " --maxConns too high, can only handle " << want << endl; } - globalTicketHolder.resize( want ); } + globalTicketHolder.resize(want); +} - TicketHolder Listener::globalTicketHolder(DEFAULT_MAX_CONN); - AtomicInt64 Listener::globalConnectionNumber; - - void ListeningSockets::closeAll() { - std::set<int>* sockets; - std::set<std::string>* paths; +TicketHolder Listener::globalTicketHolder(DEFAULT_MAX_CONN); +AtomicInt64 Listener::globalConnectionNumber; - { - stdx::lock_guard<stdx::mutex> lk( _mutex ); - sockets = _sockets; - _sockets = new std::set<int>(); - paths = _socketPaths; - _socketPaths = new std::set<std::string>(); - } +void ListeningSockets::closeAll() { + std::set<int>* sockets; + std::set<std::string>* paths; - for ( std::set<int>::iterator i=sockets->begin(); i!=sockets->end(); i++ ) { - int sock = *i; - log() << "closing listening socket: " << sock << std::endl; - closesocket( sock ); - } - delete sockets; + { + stdx::lock_guard<stdx::mutex> lk(_mutex); + sockets = _sockets; + _sockets = new std::set<int>(); + paths = _socketPaths; + _socketPaths = new std::set<std::string>(); + } - for ( std::set<std::string>::iterator i=paths->begin(); i!=paths->end(); i++ ) { - std::string path = *i; - log() << "removing socket file: " << path << std::endl; - ::remove( path.c_str() ); - } - delete paths; + for (std::set<int>::iterator i = sockets->begin(); i != sockets->end(); i++) { + int sock = *i; + log() << "closing listening socket: " << sock << std::endl; + closesocket(sock); } + delete sockets; + for (std::set<std::string>::iterator i = paths->begin(); i != paths->end(); i++) { + std::string path = *i; + log() << "removing socket file: " << path << std::endl; + ::remove(path.c_str()); + } + delete paths; +} } diff --git a/src/mongo/util/net/listen.h b/src/mongo/util/net/listen.h index 546ef223bf4..390c23bc443 100644 --- a/src/mongo/util/net/listen.h +++ b/src/mongo/util/net/listen.h @@ -42,128 +42,129 @@ namespace mongo { - const int DEFAULT_MAX_CONN = 1000000; - - class MessagingPort; - - class Listener { - MONGO_DISALLOW_COPYING(Listener); - public: - - Listener(const std::string& name, const std::string &ip, int port, bool logConnect=true ); - - virtual ~Listener(); - - void initAndListen(); // never returns unless error (start a thread) - - /* spawn a thread, etc., then return */ - virtual void accepted(std::shared_ptr<Socket> psocket, long long connectionId ); - virtual void acceptedMP(MessagingPort *mp); - - const int _port; - - /** - * @return a rough estimate of elapsed time since the server started - todo: - 1) consider adding some sort of relaxedLoad semantic to the reading here of - _elapsedTime - 2) curTimeMillis() implementations have gotten faster. consider eliminating - this code? would have to measure it first. if eliminated be careful if - syscall used isn't skewable. Note also if #2 is done, listen() doesn't - then have to keep waking up and maybe that helps on a developer's laptop - battery usage... - */ - long long getMyElapsedTimeMillis() const { return _elapsedTime; } - - /** - * Allocate sockets for the listener and set _setupSocketsSuccessful to true - * iff the process was successful. - */ - void setupSockets(); - - void setAsTimeTracker() { - _timeTracker = this; - } - - // TODO(spencer): Remove this and get the global Listener via the - // globalEnvironmentExperiment - static const Listener* getTimeTracker() { - return _timeTracker; - } - - static long long getElapsedTimeMillis() { - if ( _timeTracker ) - return _timeTracker->getMyElapsedTimeMillis(); - - // should this assert or throw? seems like callers may not expect to get zero back, certainly not forever. - return 0; - } - - /** - * Blocks until initAndListen has been called on this instance and gotten far enough that - * it is ready to receive incoming network requests. - */ - void waitUntilListening() const; - - private: - std::vector<SockAddr> _mine; - std::vector<SOCKET> _socks; - std::string _name; - std::string _ip; - bool _setupSocketsSuccessful; - bool _logConnect; - long long _elapsedTime; - mutable stdx::mutex _readyMutex; // Protects _ready - mutable stdx::condition_variable _readyCondition; // Used to wait for changes to _ready - // Boolean that indicates whether this Listener is ready to accept incoming network requests - bool _ready; +const int DEFAULT_MAX_CONN = 1000000; + +class MessagingPort; + +class Listener { + MONGO_DISALLOW_COPYING(Listener); + +public: + Listener(const std::string& name, const std::string& ip, int port, bool logConnect = true); + + virtual ~Listener(); + + void initAndListen(); // never returns unless error (start a thread) + + /* spawn a thread, etc., then return */ + virtual void accepted(std::shared_ptr<Socket> psocket, long long connectionId); + virtual void acceptedMP(MessagingPort* mp); + + const int _port; + + /** + * @return a rough estimate of elapsed time since the server started + todo: + 1) consider adding some sort of relaxedLoad semantic to the reading here of + _elapsedTime + 2) curTimeMillis() implementations have gotten faster. consider eliminating + this code? would have to measure it first. if eliminated be careful if + syscall used isn't skewable. Note also if #2 is done, listen() doesn't + then have to keep waking up and maybe that helps on a developer's laptop + battery usage... + */ + long long getMyElapsedTimeMillis() const { + return _elapsedTime; + } + + /** + * Allocate sockets for the listener and set _setupSocketsSuccessful to true + * iff the process was successful. + */ + void setupSockets(); + + void setAsTimeTracker() { + _timeTracker = this; + } + + // TODO(spencer): Remove this and get the global Listener via the + // globalEnvironmentExperiment + static const Listener* getTimeTracker() { + return _timeTracker; + } + + static long long getElapsedTimeMillis() { + if (_timeTracker) + return _timeTracker->getMyElapsedTimeMillis(); + + // should this assert or throw? seems like callers may not expect to get zero back, certainly not forever. + return 0; + } + + /** + * Blocks until initAndListen has been called on this instance and gotten far enough that + * it is ready to receive incoming network requests. + */ + void waitUntilListening() const; + +private: + std::vector<SockAddr> _mine; + std::vector<SOCKET> _socks; + std::string _name; + std::string _ip; + bool _setupSocketsSuccessful; + bool _logConnect; + long long _elapsedTime; + mutable stdx::mutex _readyMutex; // Protects _ready + mutable stdx::condition_variable _readyCondition; // Used to wait for changes to _ready + // Boolean that indicates whether this Listener is ready to accept incoming network requests + bool _ready; #ifdef MONGO_CONFIG_SSL - SSLManagerInterface* _ssl; + SSLManagerInterface* _ssl; #endif - void _logListen( int port , bool ssl ); - - static const Listener* _timeTracker; - - virtual bool useUnixSockets() const { return false; } - - public: - /** the "next" connection number. every connection to this process has a unique number */ - static AtomicInt64 globalConnectionNumber; - - /** keeps track of how many allowed connections there are and how many are being used*/ - static TicketHolder globalTicketHolder; - - /** makes sure user input is sane */ - static void checkTicketNumbers(); - }; - - class ListeningSockets { - public: - ListeningSockets() - : _sockets( new std::set<int>() ) - , _socketPaths( new std::set<std::string>() ) - { } - void add( int sock ) { - stdx::lock_guard<stdx::mutex> lk( _mutex ); - _sockets->insert( sock ); - } - void addPath( const std::string& path ) { - stdx::lock_guard<stdx::mutex> lk( _mutex ); - _socketPaths->insert( path ); - } - void remove( int sock ) { - stdx::lock_guard<stdx::mutex> lk( _mutex ); - _sockets->erase( sock ); - } - void closeAll(); - static ListeningSockets* get(); - private: - stdx::mutex _mutex; - std::set<int>* _sockets; - std::set<std::string>* _socketPaths; // for unix domain sockets - static ListeningSockets* _instance; - }; - + void _logListen(int port, bool ssl); + + static const Listener* _timeTracker; + + virtual bool useUnixSockets() const { + return false; + } + +public: + /** the "next" connection number. every connection to this process has a unique number */ + static AtomicInt64 globalConnectionNumber; + + /** keeps track of how many allowed connections there are and how many are being used*/ + static TicketHolder globalTicketHolder; + + /** makes sure user input is sane */ + static void checkTicketNumbers(); +}; + +class ListeningSockets { +public: + ListeningSockets() : _sockets(new std::set<int>()), _socketPaths(new std::set<std::string>()) {} + void add(int sock) { + stdx::lock_guard<stdx::mutex> lk(_mutex); + _sockets->insert(sock); + } + void addPath(const std::string& path) { + stdx::lock_guard<stdx::mutex> lk(_mutex); + _socketPaths->insert(path); + } + void remove(int sock) { + stdx::lock_guard<stdx::mutex> lk(_mutex); + _sockets->erase(sock); + } + void closeAll(); + static ListeningSockets* get(); + +private: + stdx::mutex _mutex; + std::set<int>* _sockets; + std::set<std::string>* _socketPaths; // for unix domain sockets + static ListeningSockets* _instance; +}; } diff --git a/src/mongo/util/net/message.cpp b/src/mongo/util/net/message.cpp index 4d68aa01177..f2777cb7972 100644 --- a/src/mongo/util/net/message.cpp +++ b/src/mongo/util/net/message.cpp @@ -40,34 +40,33 @@ namespace mongo { - void Message::send( MessagingPort &p, const char *context ) { - if ( empty() ) { - return; - } - if ( _buf != 0 ) { - p.send( _buf, MsgData::ConstView(_buf).getLen(), context ); - } - else { - p.send( _data, context ); - } +void Message::send(MessagingPort& p, const char* context) { + if (empty()) { + return; } + if (_buf != 0) { + p.send(_buf, MsgData::ConstView(_buf).getLen(), context); + } else { + p.send(_data, context); + } +} - AtomicWord<MSGID> NextMsgId; - - /*struct MsgStart { - MsgStart() { - NextMsgId = (((unsigned) time(0)) << 16) ^ curTimeMillis(); - verify(MsgDataHeaderSize == 16); - } - } msgstart;*/ +AtomicWord<MSGID> NextMsgId; - MSGID nextMessageId() { - return NextMsgId.fetchAndAdd(1); +/*struct MsgStart { + MsgStart() { + NextMsgId = (((unsigned) time(0)) << 16) ^ curTimeMillis(); + verify(MsgDataHeaderSize == 16); } +} msgstart;*/ - bool doesOpGetAResponse( int op ) { - return op == dbQuery || op == dbGetMore; - } +MSGID nextMessageId() { + return NextMsgId.fetchAndAdd(1); +} + +bool doesOpGetAResponse(int op) { + return op == dbQuery || op == dbGetMore; +} -} // namespace mongo +} // namespace mongo diff --git a/src/mongo/util/net/message.h b/src/mongo/util/net/message.h index b3573026f84..d83492e519c 100644 --- a/src/mongo/util/net/message.h +++ b/src/mongo/util/net/message.h @@ -44,55 +44,65 @@ namespace mongo { - /** - * Maximum accepted message size on the wire protocol. - */ - const size_t MaxMessageSizeBytes = 48 * 1000 * 1000; - - class Message; - class MessagingPort; - class PiggyBackData; - - typedef uint32_t MSGID; - - enum Operations { - opReply = 1, /* reply. responseTo is set. */ - dbMsg = 1000, /* generic msg command followed by a std::string */ - dbUpdate = 2001, /* update object */ - dbInsert = 2002, - //dbGetByOID = 2003, - dbQuery = 2004, - dbGetMore = 2005, - dbDelete = 2006, - dbKillCursors = 2007, - dbCommand = 2008, - dbCommandReply = 2009, - }; - - bool doesOpGetAResponse( int op ); - - inline const char * opToString( int op ) { - switch ( op ) { - case 0: return "none"; - case opReply: return "reply"; - case dbMsg: return "msg"; - case dbUpdate: return "update"; - case dbInsert: return "insert"; - case dbQuery: return "query"; - case dbGetMore: return "getmore"; - case dbDelete: return "remove"; - case dbKillCursors: return "killcursors"; - case dbCommand: return "command"; - case dbCommandReply: return "commandReply"; +/** + * Maximum accepted message size on the wire protocol. + */ +const size_t MaxMessageSizeBytes = 48 * 1000 * 1000; + +class Message; +class MessagingPort; +class PiggyBackData; + +typedef uint32_t MSGID; + +enum Operations { + opReply = 1, /* reply. responseTo is set. */ + dbMsg = 1000, /* generic msg command followed by a std::string */ + dbUpdate = 2001, /* update object */ + dbInsert = 2002, + // dbGetByOID = 2003, + dbQuery = 2004, + dbGetMore = 2005, + dbDelete = 2006, + dbKillCursors = 2007, + dbCommand = 2008, + dbCommandReply = 2009, +}; + +bool doesOpGetAResponse(int op); + +inline const char* opToString(int op) { + switch (op) { + case 0: + return "none"; + case opReply: + return "reply"; + case dbMsg: + return "msg"; + case dbUpdate: + return "update"; + case dbInsert: + return "insert"; + case dbQuery: + return "query"; + case dbGetMore: + return "getmore"; + case dbDelete: + return "remove"; + case dbKillCursors: + return "killcursors"; + case dbCommand: + return "command"; + case dbCommandReply: + return "commandReply"; default: - massert( 16141, str::stream() << "cannot translate opcode " << op, !op ); + massert(16141, str::stream() << "cannot translate opcode " << op, !op); return ""; - } } +} - inline bool opIsWrite( int op ) { - switch ( op ) { - +inline bool opIsWrite(int op) { + switch (op) { case 0: case opReply: case dbMsg: @@ -110,383 +120,387 @@ namespace mongo { PRINT(op); verify(0); return ""; - } - } +} - namespace MSGHEADER { +namespace MSGHEADER { #pragma pack(1) - /* see http://dochub.mongodb.org/core/mongowireprotocol - */ - struct Layout { - int32_t messageLength; // total message size, including this - int32_t requestID; // identifier for this message - int32_t responseTo; // requestID from the original request - // (used in responses from db) - int32_t opCode; - }; +/* see http://dochub.mongodb.org/core/mongowireprotocol +*/ +struct Layout { + int32_t messageLength; // total message size, including this + int32_t requestID; // identifier for this message + int32_t responseTo; // requestID from the original request + // (used in responses from db) + int32_t opCode; +}; #pragma pack() - class ConstView { - public: - typedef ConstDataView view_type; +class ConstView { +public: + typedef ConstDataView view_type; - ConstView(const char* data) : _data(data) { } + ConstView(const char* data) : _data(data) {} - const char* view2ptr() const { - return data().view(); - } + const char* view2ptr() const { + return data().view(); + } - int32_t getMessageLength() const { - return data().read<LittleEndian<int32_t>>(offsetof(Layout, messageLength)); - } + int32_t getMessageLength() const { + return data().read<LittleEndian<int32_t>>(offsetof(Layout, messageLength)); + } - int32_t getRequestID() const { - return data().read<LittleEndian<int32_t>>(offsetof(Layout, requestID)); - } + int32_t getRequestID() const { + return data().read<LittleEndian<int32_t>>(offsetof(Layout, requestID)); + } - int32_t getResponseTo() const { - return data().read<LittleEndian<int32_t>>(offsetof(Layout, responseTo)); - } + int32_t getResponseTo() const { + return data().read<LittleEndian<int32_t>>(offsetof(Layout, responseTo)); + } - int32_t getOpCode() const { - return data().read<LittleEndian<int32_t>>(offsetof(Layout, opCode)); - } + int32_t getOpCode() const { + return data().read<LittleEndian<int32_t>>(offsetof(Layout, opCode)); + } - protected: - const view_type& data() const { - return _data; - } +protected: + const view_type& data() const { + return _data; + } - private: - view_type _data; - }; +private: + view_type _data; +}; - class View : public ConstView { - public: - typedef DataView view_type; +class View : public ConstView { +public: + typedef DataView view_type; - View(char* data) : ConstView(data) {} + View(char* data) : ConstView(data) {} - using ConstView::view2ptr; - char* view2ptr() { - return data().view(); - } + using ConstView::view2ptr; + char* view2ptr() { + return data().view(); + } - void setMessageLength(int32_t value) { - data().write(tagLittleEndian(value), offsetof(Layout, messageLength)); - } + void setMessageLength(int32_t value) { + data().write(tagLittleEndian(value), offsetof(Layout, messageLength)); + } - void setRequestID(int32_t value) { - data().write(tagLittleEndian(value), offsetof(Layout, requestID)); - } + void setRequestID(int32_t value) { + data().write(tagLittleEndian(value), offsetof(Layout, requestID)); + } - void setResponseTo(int32_t value) { - data().write(tagLittleEndian(value), offsetof(Layout, responseTo)); - } + void setResponseTo(int32_t value) { + data().write(tagLittleEndian(value), offsetof(Layout, responseTo)); + } - void setOpCode(int32_t value) { - data().write(tagLittleEndian(value), offsetof(Layout, opCode)); - } + void setOpCode(int32_t value) { + data().write(tagLittleEndian(value), offsetof(Layout, opCode)); + } - private: - view_type data() const { - return const_cast<char *>(ConstView::view2ptr()); - } - }; +private: + view_type data() const { + return const_cast<char*>(ConstView::view2ptr()); + } +}; - class Value : public EncodedValueStorage<Layout, ConstView, View> { - public: - Value() { - BOOST_STATIC_ASSERT(sizeof(Value) == sizeof(Layout)); - } +class Value : public EncodedValueStorage<Layout, ConstView, View> { +public: + Value() { + BOOST_STATIC_ASSERT(sizeof(Value) == sizeof(Layout)); + } - Value(ZeroInitTag_t zit) : EncodedValueStorage<Layout, ConstView, View>(zit) {} - }; + Value(ZeroInitTag_t zit) : EncodedValueStorage<Layout, ConstView, View>(zit) {} +}; - } // namespace MSGHEADER +} // namespace MSGHEADER - namespace MsgData { +namespace MsgData { #pragma pack(1) - struct Layout { - MSGHEADER::Layout header; - char data[4]; - }; +struct Layout { + MSGHEADER::Layout header; + char data[4]; +}; #pragma pack() - class ConstView { - public: - ConstView(const char* storage) : _storage(storage) { } +class ConstView { +public: + ConstView(const char* storage) : _storage(storage) {} - const char* view2ptr() const { - return storage().view(); - } - - int32_t getLen() const { - return header().getMessageLength(); - } + const char* view2ptr() const { + return storage().view(); + } - MSGID getId() const { - return header().getRequestID(); - } + int32_t getLen() const { + return header().getMessageLength(); + } - MSGID getResponseTo() const { - return header().getResponseTo(); - } + MSGID getId() const { + return header().getRequestID(); + } - int32_t getOperation() const { - return header().getOpCode(); - } + MSGID getResponseTo() const { + return header().getResponseTo(); + } - const char* data() const { - return storage().view(offsetof(Layout, data)); - } + int32_t getOperation() const { + return header().getOpCode(); + } - bool valid() const { - if ( getLen() <= 0 || getLen() > ( 4 * BSONObjMaxInternalSize ) ) - return false; - if ( getOperation() < 0 || getOperation() > 30000 ) - return false; - return true; - } + const char* data() const { + return storage().view(offsetof(Layout, data)); + } - int64_t getCursor() const { - verify( getResponseTo() > 0 ); - verify( getOperation() == opReply ); - return ConstDataView(data() + sizeof(int32_t)).read<LittleEndian<int64_t>>(); - } + bool valid() const { + if (getLen() <= 0 || getLen() > (4 * BSONObjMaxInternalSize)) + return false; + if (getOperation() < 0 || getOperation() > 30000) + return false; + return true; + } - int dataLen() const; // len without header + int64_t getCursor() const { + verify(getResponseTo() > 0); + verify(getOperation() == opReply); + return ConstDataView(data() + sizeof(int32_t)).read<LittleEndian<int64_t>>(); + } - protected: - const ConstDataView& storage() const { - return _storage; - } + int dataLen() const; // len without header - MSGHEADER::ConstView header() const { - return storage().view(offsetof(Layout, header)); - } +protected: + const ConstDataView& storage() const { + return _storage; + } - private: - ConstDataView _storage; - }; + MSGHEADER::ConstView header() const { + return storage().view(offsetof(Layout, header)); + } - class View : public ConstView { - public: - View(char* storage) : ConstView(storage) {} +private: + ConstDataView _storage; +}; - using ConstView::view2ptr; - char* view2ptr() { - return storage().view(); - } +class View : public ConstView { +public: + View(char* storage) : ConstView(storage) {} - void setLen(int value) { - return header().setMessageLength(value); - } + using ConstView::view2ptr; + char* view2ptr() { + return storage().view(); + } - void setId(MSGID value) { - return header().setRequestID(value); - } + void setLen(int value) { + return header().setMessageLength(value); + } - void setResponseTo(MSGID value) { - return header().setResponseTo(value); - } + void setId(MSGID value) { + return header().setRequestID(value); + } - void setOperation(int value) { - return header().setOpCode(value); - } + void setResponseTo(MSGID value) { + return header().setResponseTo(value); + } - using ConstView::data; - char* data() { - return storage().view(offsetof(Layout, data)); - } + void setOperation(int value) { + return header().setOpCode(value); + } - private: - DataView storage() const { - return const_cast<char *>(ConstView::view2ptr()); - } + using ConstView::data; + char* data() { + return storage().view(offsetof(Layout, data)); + } - MSGHEADER::View header() const { - return storage().view(offsetof(Layout, header)); - } - }; +private: + DataView storage() const { + return const_cast<char*>(ConstView::view2ptr()); + } - class Value : public EncodedValueStorage<Layout, ConstView, View> { - public: - Value() { - BOOST_STATIC_ASSERT(sizeof(Value) == sizeof(Layout)); - } + MSGHEADER::View header() const { + return storage().view(offsetof(Layout, header)); + } +}; - Value(ZeroInitTag_t zit) : EncodedValueStorage<Layout, ConstView, View>(zit) {} - }; +class Value : public EncodedValueStorage<Layout, ConstView, View> { +public: + Value() { + BOOST_STATIC_ASSERT(sizeof(Value) == sizeof(Layout)); + } - const int MsgDataHeaderSize = sizeof(Value) - 4; - inline int ConstView::dataLen() const { - return getLen() - MsgDataHeaderSize; - } - } // namespace MsgData - - class Message { - public: - // we assume here that a vector with initial size 0 does no allocation (0 is the default, but wanted to make it explicit). - Message() : _buf( 0 ), _data( 0 ), _freeIt( false ) {} - Message( void * data , bool freeIt ) : - _buf( 0 ), _data( 0 ), _freeIt( false ) { - _setData( reinterpret_cast< char* >( data ), freeIt ); - }; - Message(Message& r) : _buf( 0 ), _data( 0 ), _freeIt( false ) { - *this = r; - } - ~Message() { - reset(); - } + Value(ZeroInitTag_t zit) : EncodedValueStorage<Layout, ConstView, View>(zit) {} +}; + +const int MsgDataHeaderSize = sizeof(Value) - 4; +inline int ConstView::dataLen() const { + return getLen() - MsgDataHeaderSize; +} +} // namespace MsgData + +class Message { +public: + // we assume here that a vector with initial size 0 does no allocation (0 is the default, but wanted to make it explicit). + Message() : _buf(0), _data(0), _freeIt(false) {} + Message(void* data, bool freeIt) : _buf(0), _data(0), _freeIt(false) { + _setData(reinterpret_cast<char*>(data), freeIt); + }; + Message(Message& r) : _buf(0), _data(0), _freeIt(false) { + *this = r; + } + ~Message() { + reset(); + } - SockAddr _from; + SockAddr _from; - MsgData::View header() const { - verify( !empty() ); - return _buf ? _buf : _data[ 0 ].first; - } + MsgData::View header() const { + verify(!empty()); + return _buf ? _buf : _data[0].first; + } - int operation() const { return header().getOperation(); } + int operation() const { + return header().getOperation(); + } - MsgData::View singleData() const { - massert( 13273, "single data buffer expected", _buf ); - return header(); - } + MsgData::View singleData() const { + massert(13273, "single data buffer expected", _buf); + return header(); + } - bool empty() const { return !_buf && _data.empty(); } + bool empty() const { + return !_buf && _data.empty(); + } - int size() const { - int res = 0; - if ( _buf ) { - res = MsgData::ConstView(_buf).getLen(); - } - else { - for (MsgVec::const_iterator it = _data.begin(); it != _data.end(); ++it) { - res += it->second; - } + int size() const { + int res = 0; + if (_buf) { + res = MsgData::ConstView(_buf).getLen(); + } else { + for (MsgVec::const_iterator it = _data.begin(); it != _data.end(); ++it) { + res += it->second; } - return res; } + return res; + } - int dataSize() const { return size() - sizeof(MSGHEADER::Value); } - - // concat multiple buffers - noop if <2 buffers already, otherwise can be expensive copy - // can get rid of this if we make response handling smarter - void concat() { - if ( _buf || empty() ) { - return; - } + int dataSize() const { + return size() - sizeof(MSGHEADER::Value); + } - verify( _freeIt ); - int totalSize = 0; - for (std::vector< std::pair< char *, int > >::const_iterator i = _data.begin(); - i != _data.end(); ++i) { - totalSize += i->second; - } - char *buf = (char*)mongoMalloc( totalSize ); - char *p = buf; - for (std::vector< std::pair< char *, int > >::const_iterator i = _data.begin(); - i != _data.end(); ++i) { - memcpy( p, i->first, i->second ); - p += i->second; - } - reset(); - _setData( buf, true ); + // concat multiple buffers - noop if <2 buffers already, otherwise can be expensive copy + // can get rid of this if we make response handling smarter + void concat() { + if (_buf || empty()) { + return; } - // vector swap() so this is fast - Message& operator=(Message& r) { - verify( empty() ); - verify( r._freeIt ); - _buf = r._buf; - r._buf = 0; - if ( r._data.size() > 0 ) { - _data.swap( r._data ); - } - r._freeIt = false; - _freeIt = true; - return *this; + verify(_freeIt); + int totalSize = 0; + for (std::vector<std::pair<char*, int>>::const_iterator i = _data.begin(); i != _data.end(); + ++i) { + totalSize += i->second; + } + char* buf = (char*)mongoMalloc(totalSize); + char* p = buf; + for (std::vector<std::pair<char*, int>>::const_iterator i = _data.begin(); i != _data.end(); + ++i) { + memcpy(p, i->first, i->second); + p += i->second; } + reset(); + _setData(buf, true); + } - void reset() { - if ( _freeIt ) { - if ( _buf ) { - free( _buf ); - } - for (std::vector< std::pair< char *, int > >::const_iterator i = _data.begin(); - i != _data.end(); ++i) { - free(i->first); - } - } - _buf = 0; - _data.clear(); - _freeIt = false; + // vector swap() so this is fast + Message& operator=(Message& r) { + verify(empty()); + verify(r._freeIt); + _buf = r._buf; + r._buf = 0; + if (r._data.size() > 0) { + _data.swap(r._data); } + r._freeIt = false; + _freeIt = true; + return *this; + } - // use to add a buffer - // assumes message will free everything - void appendData(char *d, int size) { - if ( size <= 0 ) { - return; - } - if ( empty() ) { - MsgData::View md = d; - md.setLen(size); // can be updated later if more buffers added - _setData( md.view2ptr(), true ); - return; + void reset() { + if (_freeIt) { + if (_buf) { + free(_buf); } - verify( _freeIt ); - if ( _buf ) { - _data.push_back(std::make_pair(_buf, MsgData::ConstView(_buf).getLen())); - _buf = 0; + for (std::vector<std::pair<char*, int>>::const_iterator i = _data.begin(); + i != _data.end(); + ++i) { + free(i->first); } - _data.push_back(std::make_pair(d, size)); - header().setLen(header().getLen() + size); } + _buf = 0; + _data.clear(); + _freeIt = false; + } - // use to set first buffer if empty - void setData(char* d, bool freeIt) { - verify( empty() ); - _setData( d, freeIt ); + // use to add a buffer + // assumes message will free everything + void appendData(char* d, int size) { + if (size <= 0) { + return; } - void setData(int operation, const char *msgtxt) { - setData(operation, msgtxt, strlen(msgtxt)+1); + if (empty()) { + MsgData::View md = d; + md.setLen(size); // can be updated later if more buffers added + _setData(md.view2ptr(), true); + return; } - void setData(int operation, const char *msgdata, size_t len) { - verify( empty() ); - size_t dataLen = len + sizeof(MsgData::Value) - 4; - MsgData::View d = reinterpret_cast<char *>(mongoMalloc(dataLen)); - memcpy(d.data(), msgdata, len); - d.setLen(dataLen); - d.setOperation(operation); - _setData( d.view2ptr(), true ); + verify(_freeIt); + if (_buf) { + _data.push_back(std::make_pair(_buf, MsgData::ConstView(_buf).getLen())); + _buf = 0; } + _data.push_back(std::make_pair(d, size)); + header().setLen(header().getLen() + size); + } - bool doIFreeIt() { - return _freeIt; - } + // use to set first buffer if empty + void setData(char* d, bool freeIt) { + verify(empty()); + _setData(d, freeIt); + } + void setData(int operation, const char* msgtxt) { + setData(operation, msgtxt, strlen(msgtxt) + 1); + } + void setData(int operation, const char* msgdata, size_t len) { + verify(empty()); + size_t dataLen = len + sizeof(MsgData::Value) - 4; + MsgData::View d = reinterpret_cast<char*>(mongoMalloc(dataLen)); + memcpy(d.data(), msgdata, len); + d.setLen(dataLen); + d.setOperation(operation); + _setData(d.view2ptr(), true); + } - void send( MessagingPort &p, const char *context ); - - std::string toString() const; + bool doIFreeIt() { + return _freeIt; + } - private: - void _setData( char* d, bool freeIt ) { - _freeIt = freeIt; - _buf = d; - } - // if just one buffer, keep it in _buf, otherwise keep a sequence of buffers in _data - char* _buf; - // byte buffer(s) - the first must contain at least a full MsgData unless using _buf for storage instead - typedef std::vector< std::pair< char*, int > > MsgVec; - MsgVec _data; - bool _freeIt; - }; + void send(MessagingPort& p, const char* context); + + std::string toString() const; + +private: + void _setData(char* d, bool freeIt) { + _freeIt = freeIt; + _buf = d; + } + // if just one buffer, keep it in _buf, otherwise keep a sequence of buffers in _data + char* _buf; + // byte buffer(s) - the first must contain at least a full MsgData unless using _buf for storage instead + typedef std::vector<std::pair<char*, int>> MsgVec; + MsgVec _data; + bool _freeIt; +}; - MSGID nextMessageId(); +MSGID nextMessageId(); -} // namespace mongo +} // namespace mongo diff --git a/src/mongo/util/net/message_port.cpp b/src/mongo/util/net/message_port.cpp index b4e062456c9..8fae1b34a82 100644 --- a/src/mongo/util/net/message_port.cpp +++ b/src/mongo/util/net/message_port.cpp @@ -48,308 +48,304 @@ #include "mongo/util/time_support.h" #ifndef _WIN32 -# ifndef __sun -# include <ifaddrs.h> -# endif -# include <sys/resource.h> -# include <sys/stat.h> +#ifndef __sun +#include <ifaddrs.h> +#endif +#include <sys/resource.h> +#include <sys/stat.h> #endif namespace mongo { - using std::shared_ptr; - using std::string; +using std::shared_ptr; +using std::string; // if you want trace output: #define mmm(x) - void AbstractMessagingPort::setConnectionId( long long connectionId ) { - verify( _connectionId == 0 ); - _connectionId = connectionId; - } - - /* messagingport -------------------------------------------------------------- */ - - class PiggyBackData { - public: - PiggyBackData( MessagingPort * port ) { - _port = port; - _buf = new char[1300]; - _cur = _buf; - } - - ~PiggyBackData() { - DESTRUCTOR_GUARD ( - flush(); - delete[]( _cur ); - ); - } - - void append( Message& m ) { - verify( m.header().getLen() <= 1300 ); +void AbstractMessagingPort::setConnectionId(long long connectionId) { + verify(_connectionId == 0); + _connectionId = connectionId; +} - if ( len() + m.header().getLen() > 1300 ) - flush(); +/* messagingport -------------------------------------------------------------- */ - memcpy( _cur , m.singleData().view2ptr() , m.header().getLen() ); - _cur += m.header().getLen(); - } - - void flush() { - if ( _buf == _cur ) - return; +class PiggyBackData { +public: + PiggyBackData(MessagingPort* port) { + _port = port; + _buf = new char[1300]; + _cur = _buf; + } - _port->send( _buf , len(), "flush" ); - _cur = _buf; - } + ~PiggyBackData() { + DESTRUCTOR_GUARD(flush(); delete[](_cur);); + } - int len() const { return _cur - _buf; } - - private: - MessagingPort* _port; - char * _buf; - char * _cur; - }; - - class Ports { - std::set<MessagingPort*> ports; - stdx::mutex m; - public: - Ports() : ports() {} - void closeAll(unsigned skip_mask) { - stdx::lock_guard<stdx::mutex> bl(m); - for ( std::set<MessagingPort*>::iterator i = ports.begin(); i != ports.end(); i++ ) { - if( (*i)->tag & skip_mask ) - continue; - (*i)->shutdown(); - } - } - void insert(MessagingPort* p) { - stdx::lock_guard<stdx::mutex> bl(m); - ports.insert(p); - } - void erase(MessagingPort* p) { - stdx::lock_guard<stdx::mutex> bl(m); - ports.erase(p); - } - }; + void append(Message& m) { + verify(m.header().getLen() <= 1300); - // we "new" this so it is still be around when other automatic global vars - // are being destructed during termination. - Ports& ports = *(new Ports()); + if (len() + m.header().getLen() > 1300) + flush(); - void MessagingPort::closeAllSockets(unsigned mask) { - ports.closeAll(mask); + memcpy(_cur, m.singleData().view2ptr(), m.header().getLen()); + _cur += m.header().getLen(); } - MessagingPort::MessagingPort(int fd, const SockAddr& remote) - : psock( new Socket( fd , remote ) ) , piggyBackData(0) { - ports.insert(this); - } + void flush() { + if (_buf == _cur) + return; - MessagingPort::MessagingPort( double timeout, logger::LogSeverity ll ) - : psock( new Socket( timeout, ll ) ) { - ports.insert(this); - piggyBackData = 0; + _port->send(_buf, len(), "flush"); + _cur = _buf; } - MessagingPort::MessagingPort( std::shared_ptr<Socket> sock ) - : psock( sock ), piggyBackData( 0 ) { - ports.insert(this); + int len() const { + return _cur - _buf; } - void MessagingPort::setSocketTimeout(double timeout) { - psock->setTimeout(timeout); +private: + MessagingPort* _port; + char* _buf; + char* _cur; +}; + +class Ports { + std::set<MessagingPort*> ports; + stdx::mutex m; + +public: + Ports() : ports() {} + void closeAll(unsigned skip_mask) { + stdx::lock_guard<stdx::mutex> bl(m); + for (std::set<MessagingPort*>::iterator i = ports.begin(); i != ports.end(); i++) { + if ((*i)->tag & skip_mask) + continue; + (*i)->shutdown(); + } } - - void MessagingPort::shutdown() { - psock->close(); + void insert(MessagingPort* p) { + stdx::lock_guard<stdx::mutex> bl(m); + ports.insert(p); } - - MessagingPort::~MessagingPort() { - if ( piggyBackData ) - delete( piggyBackData ); - shutdown(); - ports.erase(this); + void erase(MessagingPort* p) { + stdx::lock_guard<stdx::mutex> bl(m); + ports.erase(p); } - - bool MessagingPort::recv(Message& m) { - try { -again: - //mmm( log() << "* recv() sock:" << this->sock << endl; ) - MSGHEADER::Value header; - int headerLen = sizeof(MSGHEADER::Value); - psock->recv( (char *)&header, headerLen ); - int len = header.constView().getMessageLength(); - - if ( len == 542393671 ) { - // an http GET - string msg = "It looks like you are trying to access MongoDB over HTTP on the native driver port.\n"; - LOG( psock->getLogLevel() ) << msg; - std::stringstream ss; - ss << "HTTP/1.0 200 OK\r\nConnection: close\r\nContent-Type: text/plain\r\nContent-Length: " << msg.size() << "\r\n\r\n" << msg; - string s = ss.str(); - send( s.c_str(), s.size(), "http" ); - return false; +}; + +// we "new" this so it is still be around when other automatic global vars +// are being destructed during termination. +Ports& ports = *(new Ports()); + +void MessagingPort::closeAllSockets(unsigned mask) { + ports.closeAll(mask); +} + +MessagingPort::MessagingPort(int fd, const SockAddr& remote) + : psock(new Socket(fd, remote)), piggyBackData(0) { + ports.insert(this); +} + +MessagingPort::MessagingPort(double timeout, logger::LogSeverity ll) + : psock(new Socket(timeout, ll)) { + ports.insert(this); + piggyBackData = 0; +} + +MessagingPort::MessagingPort(std::shared_ptr<Socket> sock) : psock(sock), piggyBackData(0) { + ports.insert(this); +} + +void MessagingPort::setSocketTimeout(double timeout) { + psock->setTimeout(timeout); +} + +void MessagingPort::shutdown() { + psock->close(); +} + +MessagingPort::~MessagingPort() { + if (piggyBackData) + delete (piggyBackData); + shutdown(); + ports.erase(this); +} + +bool MessagingPort::recv(Message& m) { + try { + again: + // mmm( log() << "* recv() sock:" << this->sock << endl; ) + MSGHEADER::Value header; + int headerLen = sizeof(MSGHEADER::Value); + psock->recv((char*)&header, headerLen); + int len = header.constView().getMessageLength(); + + if (len == 542393671) { + // an http GET + string msg = + "It looks like you are trying to access MongoDB over HTTP on the native driver " + "port.\n"; + LOG(psock->getLogLevel()) << msg; + std::stringstream ss; + ss << "HTTP/1.0 200 OK\r\nConnection: close\r\nContent-Type: " + "text/plain\r\nContent-Length: " << msg.size() << "\r\n\r\n" << msg; + string s = ss.str(); + send(s.c_str(), s.size(), "http"); + return false; + } else if (len == -1) { + // Endian check from the client, after connecting, to see what mode server is running in. + unsigned foo = 0x10203040; + send((char*)&foo, 4, "endian"); + psock->setHandshakeReceived(); + goto again; + } + // If responseTo is not 0 or -1 for first packet assume SSL + else if (psock->isAwaitingHandshake()) { +#ifndef MONGO_CONFIG_SSL + if (header.constView().getResponseTo() != 0 && + header.constView().getResponseTo() != -1) { + uasserted(17133, + "SSL handshake requested, SSL feature not available in this build"); } - else if ( len == -1 ) { - // Endian check from the client, after connecting, to see what mode server is running in. - unsigned foo = 0x10203040; - send( (char *) &foo, 4, "endian" ); +#else + if (header.constView().getResponseTo() != 0 && + header.constView().getResponseTo() != -1) { + uassert(17132, + "SSL handshake received but server is started without SSL support", + sslGlobalParams.sslMode.load() != SSLParams::SSLMode_disabled); + setX509SubjectName( + psock->doSSLHandshake(reinterpret_cast<const char*>(&header), sizeof(header))); psock->setHandshakeReceived(); goto again; } - // If responseTo is not 0 or -1 for first packet assume SSL - else if (psock->isAwaitingHandshake()) { -#ifndef MONGO_CONFIG_SSL - if (header.constView().getResponseTo() != 0 - && header.constView().getResponseTo() != -1) { - uasserted(17133, - "SSL handshake requested, SSL feature not available in this build"); - } -#else - if (header.constView().getResponseTo() != 0 - && header.constView().getResponseTo() != -1) { - uassert(17132, - "SSL handshake received but server is started without SSL support", - sslGlobalParams.sslMode.load() != SSLParams::SSLMode_disabled); - setX509SubjectName(psock->doSSLHandshake( - reinterpret_cast<const char*>(&header), sizeof(header))); - psock->setHandshakeReceived(); - goto again; - } - uassert(17189, "The server is configured to only allow SSL connections", - sslGlobalParams.sslMode.load() != SSLParams::SSLMode_requireSSL); -#endif // MONGO_CONFIG_SSL - } - if ( static_cast<size_t>(len) < sizeof(MSGHEADER::Value) || - static_cast<size_t>(len) > MaxMessageSizeBytes ) { - LOG(0) << "recv(): message len " << len << " is invalid. " - << "Min " << sizeof(MSGHEADER::Value) << " Max: " << MaxMessageSizeBytes; - return false; - } - - psock->setHandshakeReceived(); - int z = (len+1023)&0xfffffc00; - verify(z>=len); - MsgData::View md = reinterpret_cast<char *>(mongoMalloc(z)); - ScopeGuard guard = MakeGuard(free, md.view2ptr()); - verify(md.view2ptr()); - - memcpy(md.view2ptr(), &header, headerLen); - int left = len - headerLen; - - psock->recv( md.data(), left ); - - guard.Dismiss(); - m.setData(md.view2ptr(), true); - return true; - + uassert(17189, + "The server is configured to only allow SSL connections", + sslGlobalParams.sslMode.load() != SSLParams::SSLMode_requireSSL); +#endif // MONGO_CONFIG_SSL } - catch ( const SocketException & e ) { - logger::LogSeverity severity = psock->getLogLevel(); - if (!e.shouldPrint()) - severity = severity.lessSevere(); - LOG(severity) << "SocketException: remote: " << remote() << " error: " << e; - m.reset(); + if (static_cast<size_t>(len) < sizeof(MSGHEADER::Value) || + static_cast<size_t>(len) > MaxMessageSizeBytes) { + LOG(0) << "recv(): message len " << len << " is invalid. " + << "Min " << sizeof(MSGHEADER::Value) << " Max: " << MaxMessageSizeBytes; return false; } - } - void MessagingPort::reply(Message& received, Message& response) { - say(/*received.from, */response, received.header().getId()); - } + psock->setHandshakeReceived(); + int z = (len + 1023) & 0xfffffc00; + verify(z >= len); + MsgData::View md = reinterpret_cast<char*>(mongoMalloc(z)); + ScopeGuard guard = MakeGuard(free, md.view2ptr()); + verify(md.view2ptr()); - void MessagingPort::reply(Message& received, Message& response, MSGID responseTo) { - say(/*received.from, */response, responseTo); - } + memcpy(md.view2ptr(), &header, headerLen); + int left = len - headerLen; - bool MessagingPort::call(Message& toSend, Message& response) { - mmm( log() << "*call()" << endl; ) - say(toSend); - return recv( toSend , response ); - } + psock->recv(md.data(), left); - bool MessagingPort::recv( const Message& toSend , Message& response ) { - while ( 1 ) { - bool ok = recv(response); - if ( !ok ) { - mmm( log() << "recv not ok" << endl; ) - return false; - } - //log() << "got response: " << response.data->responseTo << endl; - if ( response.header().getResponseTo() == toSend.header().getId() ) - break; - error() << "MessagingPort::call() wrong id got:" - << std::hex << (unsigned)response.header().getResponseTo() - << " expect:" << (unsigned)toSend.header().getId() << '\n' - << std::dec - << " toSend op: " << (unsigned)toSend.operation() << '\n' - << " response msgid:" << (unsigned)response.header().getId() << '\n' - << " response len: " << (unsigned)response.header().getLen() << '\n' - << " response op: " << response.operation() << '\n' - << " remote: " << psock->remoteString(); - verify(false); - response.reset(); - } - mmm( log() << "*call() end" << endl; ) + guard.Dismiss(); + m.setData(md.view2ptr(), true); return true; + + } catch (const SocketException& e) { + logger::LogSeverity severity = psock->getLogLevel(); + if (!e.shouldPrint()) + severity = severity.lessSevere(); + LOG(severity) << "SocketException: remote: " << remote() << " error: " << e; + m.reset(); + return false; + } +} + +void MessagingPort::reply(Message& received, Message& response) { + say(/*received.from, */ response, received.header().getId()); +} + +void MessagingPort::reply(Message& received, Message& response, MSGID responseTo) { + say(/*received.from, */ response, responseTo); +} + +bool MessagingPort::call(Message& toSend, Message& response) { + mmm(log() << "*call()" << endl;) say(toSend); + return recv(toSend, response); +} + +bool MessagingPort::recv(const Message& toSend, Message& response) { + while (1) { + bool ok = recv(response); + if (!ok) { + mmm(log() << "recv not ok" << endl;) return false; + } + // log() << "got response: " << response.data->responseTo << endl; + if (response.header().getResponseTo() == toSend.header().getId()) + break; + error() << "MessagingPort::call() wrong id got:" << std::hex + << (unsigned)response.header().getResponseTo() + << " expect:" << (unsigned)toSend.header().getId() << '\n' << std::dec + << " toSend op: " << (unsigned)toSend.operation() << '\n' + << " response msgid:" << (unsigned)response.header().getId() << '\n' + << " response len: " << (unsigned)response.header().getLen() << '\n' + << " response op: " << response.operation() << '\n' + << " remote: " << psock->remoteString(); + verify(false); + response.reset(); } + mmm(log() << "*call() end" << endl;) return true; +} - void MessagingPort::say(Message& toSend, int responseTo) { - verify( !toSend.empty() ); - mmm( log() << "* say() thr:" << GetCurrentThreadId() << endl; ) +void MessagingPort::say(Message& toSend, int responseTo) { + verify(!toSend.empty()); + mmm(log() << "* say() thr:" << GetCurrentThreadId() << endl;) toSend.header().setId(nextMessageId()); - toSend.header().setResponseTo(responseTo); + toSend.header().setResponseTo(responseTo); - if ( piggyBackData && piggyBackData->len() ) { - mmm( log() << "* have piggy back" << endl; ) - if ( ( piggyBackData->len() + toSend.header().getLen() ) > 1300 ) { - // won't fit in a packet - so just send it off - piggyBackData->flush(); - } - else { - piggyBackData->append( toSend ); - piggyBackData->flush(); - return; - } + if (piggyBackData && piggyBackData->len()) { + mmm(log() << "* have piggy back" + << endl;) if ((piggyBackData->len() + toSend.header().getLen()) > 1300) { + // won't fit in a packet - so just send it off + piggyBackData->flush(); + } + else { + piggyBackData->append(toSend); + piggyBackData->flush(); + return; } - - toSend.send( *this, "say" ); } - void MessagingPort::piggyBack( Message& toSend , int responseTo ) { + toSend.send(*this, "say"); +} - if ( toSend.header().getLen() > 1300 ) { - // not worth saving because its almost an entire packet - say( toSend ); - return; - } +void MessagingPort::piggyBack(Message& toSend, int responseTo) { + if (toSend.header().getLen() > 1300) { + // not worth saving because its almost an entire packet + say(toSend); + return; + } - // we're going to be storing this, so need to set it up - toSend.header().setId(nextMessageId()); - toSend.header().setResponseTo(responseTo); + // we're going to be storing this, so need to set it up + toSend.header().setId(nextMessageId()); + toSend.header().setResponseTo(responseTo); - if ( ! piggyBackData ) - piggyBackData = new PiggyBackData( this ); + if (!piggyBackData) + piggyBackData = new PiggyBackData(this); - piggyBackData->append( toSend ); - } + piggyBackData->append(toSend); +} - HostAndPort MessagingPort::remote() const { - if ( ! _remoteParsed.hasPort() ) { - SockAddr sa = psock->remoteAddr(); - _remoteParsed = HostAndPort( sa.getAddr(), sa.getPort()); - } - return _remoteParsed; +HostAndPort MessagingPort::remote() const { + if (!_remoteParsed.hasPort()) { + SockAddr sa = psock->remoteAddr(); + _remoteParsed = HostAndPort(sa.getAddr(), sa.getPort()); } + return _remoteParsed; +} - SockAddr MessagingPort::remoteAddr() const { - return psock->remoteAddr(); - } +SockAddr MessagingPort::remoteAddr() const { + return psock->remoteAddr(); +} - SockAddr MessagingPort::localAddr() const { - return psock->localAddr(); - } +SockAddr MessagingPort::localAddr() const { + return psock->localAddr(); +} -} // namespace mongo +} // namespace mongo diff --git a/src/mongo/util/net/message_port.h b/src/mongo/util/net/message_port.h index 7843621bf5e..6af98b160a9 100644 --- a/src/mongo/util/net/message_port.h +++ b/src/mongo/util/net/message_port.h @@ -37,135 +37,142 @@ namespace mongo { - class MessagingPort; - class PiggyBackData; - - class AbstractMessagingPort { - MONGO_DISALLOW_COPYING(AbstractMessagingPort); - public: - AbstractMessagingPort() : tag(0), _connectionId(0) {} - virtual ~AbstractMessagingPort() { } - virtual void reply(Message& received, Message& response, MSGID responseTo) = 0; // like the reply below, but doesn't rely on received.data still being available - virtual void reply(Message& received, Message& response) = 0; - - virtual HostAndPort remote() const = 0; - virtual unsigned remotePort() const = 0; - virtual SockAddr remoteAddr() const = 0; - virtual SockAddr localAddr() const = 0; - - void setX509SubjectName(const std::string& x509SubjectName) { - _x509SubjectName = x509SubjectName; - } - - std::string getX509SubjectName() { - return _x509SubjectName; - } - - long long connectionId() const { return _connectionId; } - void setConnectionId( long long connectionId ); - - public: - // TODO make this private with some helpers - - /* ports can be tagged with various classes. see closeAllSockets(tag). defaults to 0. */ - unsigned tag; - - private: - long long _connectionId; - std::string _x509SubjectName; - }; - - class MessagingPort : public AbstractMessagingPort { - public: - MessagingPort(int fd, const SockAddr& remote); - - // in some cases the timeout will actually be 2x this value - eg we do a partial send, - // then the timeout fires, then we try to send again, then the timeout fires again with - // no data sent, then we detect that the other side is down - MessagingPort(double so_timeout = 0, - logger::LogSeverity logLevel = logger::LogSeverity::Log() ); - - MessagingPort(std::shared_ptr<Socket> socket); - - virtual ~MessagingPort(); - - void setSocketTimeout(double timeout); - - void shutdown(); - - /* it's assumed if you reuse a message object, that it doesn't cross MessagingPort's. - also, the Message data will go out of scope on the subsequent recv call. - */ - bool recv(Message& m); - void reply(Message& received, Message& response, MSGID responseTo); - void reply(Message& received, Message& response); - bool call(Message& toSend, Message& response); - - void say(Message& toSend, int responseTo = 0); - - /** - * this is used for doing 'async' queries - * instead of doing call( to , from ) - * you would do - * say( to ) - * recv( from ) - * Note: if you fail to call recv and someone else uses this port, - * horrible things will happen - */ - bool recv( const Message& sent , Message& response ); - - void piggyBack( Message& toSend , int responseTo = 0 ); - - unsigned remotePort() const { return psock->remotePort(); } - virtual HostAndPort remote() const; - virtual SockAddr remoteAddr() const; - virtual SockAddr localAddr() const; - - std::shared_ptr<Socket> psock; - - void send( const char * data , int len, const char *context ) { - psock->send( data, len, context ); - } - void send(const std::vector< std::pair< char *, int > > &data, const char *context) { - psock->send( data, context ); - } - bool connect(SockAddr& farEnd) { - return psock->connect( farEnd ); - } +class MessagingPort; +class PiggyBackData; + +class AbstractMessagingPort { + MONGO_DISALLOW_COPYING(AbstractMessagingPort); + +public: + AbstractMessagingPort() : tag(0), _connectionId(0) {} + virtual ~AbstractMessagingPort() {} + virtual void reply( + Message& received, + Message& response, + MSGID + responseTo) = 0; // like the reply below, but doesn't rely on received.data still being available + virtual void reply(Message& received, Message& response) = 0; + + virtual HostAndPort remote() const = 0; + virtual unsigned remotePort() const = 0; + virtual SockAddr remoteAddr() const = 0; + virtual SockAddr localAddr() const = 0; + + void setX509SubjectName(const std::string& x509SubjectName) { + _x509SubjectName = x509SubjectName; + } + + std::string getX509SubjectName() { + return _x509SubjectName; + } + + long long connectionId() const { + return _connectionId; + } + void setConnectionId(long long connectionId); + +public: + // TODO make this private with some helpers + + /* ports can be tagged with various classes. see closeAllSockets(tag). defaults to 0. */ + unsigned tag; + +private: + long long _connectionId; + std::string _x509SubjectName; +}; + +class MessagingPort : public AbstractMessagingPort { +public: + MessagingPort(int fd, const SockAddr& remote); + + // in some cases the timeout will actually be 2x this value - eg we do a partial send, + // then the timeout fires, then we try to send again, then the timeout fires again with + // no data sent, then we detect that the other side is down + MessagingPort(double so_timeout = 0, logger::LogSeverity logLevel = logger::LogSeverity::Log()); + + MessagingPort(std::shared_ptr<Socket> socket); + + virtual ~MessagingPort(); + + void setSocketTimeout(double timeout); + + void shutdown(); + + /* it's assumed if you reuse a message object, that it doesn't cross MessagingPort's. + also, the Message data will go out of scope on the subsequent recv call. + */ + bool recv(Message& m); + void reply(Message& received, Message& response, MSGID responseTo); + void reply(Message& received, Message& response); + bool call(Message& toSend, Message& response); + + void say(Message& toSend, int responseTo = 0); + + /** + * this is used for doing 'async' queries + * instead of doing call( to , from ) + * you would do + * say( to ) + * recv( from ) + * Note: if you fail to call recv and someone else uses this port, + * horrible things will happen + */ + bool recv(const Message& sent, Message& response); + + void piggyBack(Message& toSend, int responseTo = 0); + + unsigned remotePort() const { + return psock->remotePort(); + } + virtual HostAndPort remote() const; + virtual SockAddr remoteAddr() const; + virtual SockAddr localAddr() const; + + std::shared_ptr<Socket> psock; + + void send(const char* data, int len, const char* context) { + psock->send(data, len, context); + } + void send(const std::vector<std::pair<char*, int>>& data, const char* context) { + psock->send(data, context); + } + bool connect(SockAddr& farEnd) { + return psock->connect(farEnd); + } #ifdef MONGO_CONFIG_SSL - /** - * Initiates the TLS/SSL handshake on this MessagingPort. - * When this function returns, further communication on this - * MessagingPort will be encrypted. - * ssl - Pointer to the global SSLManager. - * remoteHost - The hostname of the remote server. - */ - bool secure( SSLManagerInterface* ssl, const std::string& remoteHost ) { - return psock->secure( ssl, remoteHost ); - } + /** + * Initiates the TLS/SSL handshake on this MessagingPort. + * When this function returns, further communication on this + * MessagingPort will be encrypted. + * ssl - Pointer to the global SSLManager. + * remoteHost - The hostname of the remote server. + */ + bool secure(SSLManagerInterface* ssl, const std::string& remoteHost) { + return psock->secure(ssl, remoteHost); + } #endif - bool isStillConnected() { - return psock->isStillConnected(); - } + bool isStillConnected() { + return psock->isStillConnected(); + } - uint64_t getSockCreationMicroSec() const { - return psock->getSockCreationMicroSec(); - } + uint64_t getSockCreationMicroSec() const { + return psock->getSockCreationMicroSec(); + } - private: - - PiggyBackData * piggyBackData; +private: + PiggyBackData* piggyBackData; - // this is the parsed version of remote - // mutable because its initialized only on call to remote() - mutable HostAndPort _remoteParsed; + // this is the parsed version of remote + // mutable because its initialized only on call to remote() + mutable HostAndPort _remoteParsed; - public: - static void closeAllSockets(unsigned tagMask = 0xffffffff); +public: + static void closeAllSockets(unsigned tagMask = 0xffffffff); - friend class PiggyBackData; - }; + friend class PiggyBackData; +}; -} // namespace mongo +} // namespace mongo diff --git a/src/mongo/util/net/message_server.h b/src/mongo/util/net/message_server.h index 7bb9759eeff..8807b6c0cb7 100644 --- a/src/mongo/util/net/message_server.h +++ b/src/mongo/util/net/message_server.h @@ -38,37 +38,37 @@ namespace mongo { - class MessageHandler { - public: - virtual ~MessageHandler() {} - - /** - * called once when a socket is connected - */ - virtual void connected( AbstractMessagingPort* p ) = 0; +class MessageHandler { +public: + virtual ~MessageHandler() {} - /** - * called every time a message comes in - * handler is responsible for responding to client - */ - virtual void process(Message& m, AbstractMessagingPort* p) = 0; - }; + /** + * called once when a socket is connected + */ + virtual void connected(AbstractMessagingPort* p) = 0; - class MessageServer { - public: - struct Options { - int port; // port to bind to - std::string ipList; // addresses to bind to + /** + * called every time a message comes in + * handler is responsible for responding to client + */ + virtual void process(Message& m, AbstractMessagingPort* p) = 0; +}; - Options() : port(0), ipList("") {} - }; +class MessageServer { +public: + struct Options { + int port; // port to bind to + std::string ipList; // addresses to bind to - virtual ~MessageServer() {} - virtual void run() = 0; - virtual void setAsTimeTracker() = 0; - virtual void setupSockets() = 0; + Options() : port(0), ipList("") {} }; - // TODO use a factory here to decide between port and asio variations - MessageServer * createServer( const MessageServer::Options& opts , MessageHandler * handler ); + virtual ~MessageServer() {} + virtual void run() = 0; + virtual void setAsTimeTracker() = 0; + virtual void setupSockets() = 0; +}; + +// TODO use a factory here to decide between port and asio variations +MessageServer* createServer(const MessageServer::Options& opts, MessageHandler* handler); } diff --git a/src/mongo/util/net/message_server_port.cpp b/src/mongo/util/net/message_server_port.cpp index 7c176c66764..0a56c9203c1 100644 --- a/src/mongo/util/net/message_server_port.cpp +++ b/src/mongo/util/net/message_server_port.cpp @@ -55,7 +55,7 @@ #include "mongo/util/scopeguard.h" #ifdef __linux__ // TODO: consider making this ifndef _WIN32 -# include <sys/resource.h> +#include <sys/resource.h> #endif #if !defined(__has_feature) @@ -64,206 +64,205 @@ namespace mongo { - using std::unique_ptr; - using std::endl; +using std::unique_ptr; +using std::endl; namespace { - class MessagingPortWithHandler : public MessagingPort { - MONGO_DISALLOW_COPYING(MessagingPortWithHandler); +class MessagingPortWithHandler : public MessagingPort { + MONGO_DISALLOW_COPYING(MessagingPortWithHandler); - public: - MessagingPortWithHandler(const std::shared_ptr<Socket>& socket, - MessageHandler* handler, - long long connectionId) - : MessagingPort(socket), _handler(handler) { - setConnectionId(connectionId); - } +public: + MessagingPortWithHandler(const std::shared_ptr<Socket>& socket, + MessageHandler* handler, + long long connectionId) + : MessagingPort(socket), _handler(handler) { + setConnectionId(connectionId); + } - MessageHandler* getHandler() const { return _handler; } + MessageHandler* getHandler() const { + return _handler; + } - private: - // Not owned. - MessageHandler* const _handler; - }; +private: + // Not owned. + MessageHandler* const _handler; +}; } // namespace - class PortMessageServer : public MessageServer , public Listener { - public: - /** - * Creates a new message server. - * - * @param opts - * @param handler the handler to use. Caller is responsible for managing this object - * and should make sure that it lives longer than this server. - */ - PortMessageServer( const MessageServer::Options& opts, MessageHandler * handler ) : - Listener( "" , opts.ipList, opts.port ), _handler(handler) { +class PortMessageServer : public MessageServer, public Listener { +public: + /** + * Creates a new message server. + * + * @param opts + * @param handler the handler to use. Caller is responsible for managing this object + * and should make sure that it lives longer than this server. + */ + PortMessageServer(const MessageServer::Options& opts, MessageHandler* handler) + : Listener("", opts.ipList, opts.port), _handler(handler) {} + + virtual void accepted(std::shared_ptr<Socket> psocket, long long connectionId) { + ScopeGuard sleepAfterClosingPort = MakeGuard(sleepmillis, 2); + std::unique_ptr<MessagingPortWithHandler> portWithHandler( + new MessagingPortWithHandler(psocket, _handler, connectionId)); + + if (!Listener::globalTicketHolder.tryAcquire()) { + log() << "connection refused because too many open connections: " + << Listener::globalTicketHolder.used() << endl; + return; } - virtual void accepted(std::shared_ptr<Socket> psocket, long long connectionId ) { - ScopeGuard sleepAfterClosingPort = MakeGuard(sleepmillis, 2); - std::unique_ptr<MessagingPortWithHandler> portWithHandler( - new MessagingPortWithHandler(psocket, _handler, connectionId)); - - if ( ! Listener::globalTicketHolder.tryAcquire() ) { - log() << "connection refused because too many open connections: " << Listener::globalTicketHolder.used() << endl; - return; - } - - try { + try { #ifndef __linux__ // TODO: consider making this ifdef _WIN32 - { - stdx::thread thr(stdx::bind(&handleIncomingMsg, portWithHandler.get())); - } + { stdx::thread thr(stdx::bind(&handleIncomingMsg, portWithHandler.get())); } #else - pthread_attr_t attrs; - pthread_attr_init(&attrs); - pthread_attr_setdetachstate(&attrs, PTHREAD_CREATE_DETACHED); + pthread_attr_t attrs; + pthread_attr_init(&attrs); + pthread_attr_setdetachstate(&attrs, PTHREAD_CREATE_DETACHED); - static const size_t STACK_SIZE = 1024*1024; // if we change this we need to update the warning + static const size_t STACK_SIZE = + 1024 * 1024; // if we change this we need to update the warning - struct rlimit limits; - verify(getrlimit(RLIMIT_STACK, &limits) == 0); - if (limits.rlim_cur > STACK_SIZE) { - size_t stackSizeToSet = STACK_SIZE; + struct rlimit limits; + verify(getrlimit(RLIMIT_STACK, &limits) == 0); + if (limits.rlim_cur > STACK_SIZE) { + size_t stackSizeToSet = STACK_SIZE; #if !__has_feature(address_sanitizer) - if (kDebugBuild) - stackSizeToSet /= 2; + if (kDebugBuild) + stackSizeToSet /= 2; #endif - pthread_attr_setstacksize(&attrs, stackSizeToSet); - } else if (limits.rlim_cur < 1024*1024) { - warning() << "Stack size set to " << (limits.rlim_cur/1024) << "KB. We suggest 1MB" << endl; - } + pthread_attr_setstacksize(&attrs, stackSizeToSet); + } else if (limits.rlim_cur < 1024 * 1024) { + warning() << "Stack size set to " << (limits.rlim_cur / 1024) + << "KB. We suggest 1MB" << endl; + } - pthread_t thread; - int failed = - pthread_create(&thread, &attrs, &handleIncomingMsg, portWithHandler.get()); + pthread_t thread; + int failed = pthread_create(&thread, &attrs, &handleIncomingMsg, portWithHandler.get()); - pthread_attr_destroy(&attrs); + pthread_attr_destroy(&attrs); - if (failed) { - log() << "pthread_create failed: " << errnoWithDescription(failed) << endl; - throw boost::thread_resource_error(); // for consistency with boost::thread - } + if (failed) { + log() << "pthread_create failed: " << errnoWithDescription(failed) << endl; + throw boost::thread_resource_error(); // for consistency with boost::thread + } #endif // __linux__ - portWithHandler.release(); - sleepAfterClosingPort.Dismiss(); - } - catch ( boost::thread_resource_error& ) { - Listener::globalTicketHolder.release(); - log() << "can't create new thread, closing connection" << endl; - } - catch ( ... ) { - Listener::globalTicketHolder.release(); - log() << "unknown error accepting new socket" << endl; - } + portWithHandler.release(); + sleepAfterClosingPort.Dismiss(); + } catch (boost::thread_resource_error&) { + Listener::globalTicketHolder.release(); + log() << "can't create new thread, closing connection" << endl; + } catch (...) { + Listener::globalTicketHolder.release(); + log() << "unknown error accepting new socket" << endl; } + } - virtual void setAsTimeTracker() { - Listener::setAsTimeTracker(); - } + virtual void setAsTimeTracker() { + Listener::setAsTimeTracker(); + } - virtual void setupSockets() { - Listener::setupSockets(); - } + virtual void setupSockets() { + Listener::setupSockets(); + } - void run() { - initAndListen(); - } + void run() { + initAndListen(); + } - virtual bool useUnixSockets() const { return true; } - - private: - MessageHandler* _handler; - - /** - * Handles incoming messages from a given socket. - * - * Terminating conditions: - * 1. Assertions while handling the request. - * 2. Socket is closed. - * 3. Server is shutting down (based on inShutdown) - * - * @param arg this method is in charge of cleaning up the arg object. - * - * @return NULL - */ - static void* handleIncomingMsg(void* arg) { - TicketHolderReleaser connTicketReleaser( &Listener::globalTicketHolder ); - - invariant(arg); - unique_ptr<MessagingPortWithHandler> portWithHandler( - static_cast<MessagingPortWithHandler*>(arg)); - MessageHandler* const handler = portWithHandler->getHandler(); - - setThreadName(std::string(str::stream() << "conn" << portWithHandler->connectionId())); - portWithHandler->psock->setLogLevel(logger::LogSeverity::Debug(1)); - - Message m; - int64_t counter = 0; - try { - handler->connected(portWithHandler.get()); - - while ( ! inShutdown() ) { - m.reset(); - portWithHandler->psock->clearCounters(); - - if (!portWithHandler->recv(m)) { - if (!serverGlobalParams.quiet) { - int conns = Listener::globalTicketHolder.used()-1; - const char* word = (conns == 1 ? " connection" : " connections"); - log() << "end connection " << portWithHandler->psock->remoteString() - << " (" << conns << word << " now open)" << endl; - } - portWithHandler->shutdown(); - break; + virtual bool useUnixSockets() const { + return true; + } + +private: + MessageHandler* _handler; + + /** + * Handles incoming messages from a given socket. + * + * Terminating conditions: + * 1. Assertions while handling the request. + * 2. Socket is closed. + * 3. Server is shutting down (based on inShutdown) + * + * @param arg this method is in charge of cleaning up the arg object. + * + * @return NULL + */ + static void* handleIncomingMsg(void* arg) { + TicketHolderReleaser connTicketReleaser(&Listener::globalTicketHolder); + + invariant(arg); + unique_ptr<MessagingPortWithHandler> portWithHandler( + static_cast<MessagingPortWithHandler*>(arg)); + MessageHandler* const handler = portWithHandler->getHandler(); + + setThreadName(std::string(str::stream() << "conn" << portWithHandler->connectionId())); + portWithHandler->psock->setLogLevel(logger::LogSeverity::Debug(1)); + + Message m; + int64_t counter = 0; + try { + handler->connected(portWithHandler.get()); + + while (!inShutdown()) { + m.reset(); + portWithHandler->psock->clearCounters(); + + if (!portWithHandler->recv(m)) { + if (!serverGlobalParams.quiet) { + int conns = Listener::globalTicketHolder.used() - 1; + const char* word = (conns == 1 ? " connection" : " connections"); + log() << "end connection " << portWithHandler->psock->remoteString() << " (" + << conns << word << " now open)" << endl; } + portWithHandler->shutdown(); + break; + } - handler->process(m, portWithHandler.get()); - networkCounter.hit(portWithHandler->psock->getBytesIn(), - portWithHandler->psock->getBytesOut()); + handler->process(m, portWithHandler.get()); + networkCounter.hit(portWithHandler->psock->getBytesIn(), + portWithHandler->psock->getBytesOut()); - // Occasionally we want to see if we're using too much memory. - if ((counter++ & 0xf) == 0) { - markThreadIdle(); - } + // Occasionally we want to see if we're using too much memory. + if ((counter++ & 0xf) == 0) { + markThreadIdle(); } } - catch ( AssertionException& e ) { - log() << "AssertionException handling request, closing client connection: " << e << endl; - portWithHandler->shutdown(); - } - catch ( SocketException& e ) { - log() << "SocketException handling request, closing client connection: " << e << endl; - portWithHandler->shutdown(); - } - catch ( const DBException& e ) { // must be right above std::exception to avoid catching subclasses - log() << "DBException handling request, closing client connection: " << e << endl; - portWithHandler->shutdown(); - } - catch ( std::exception &e ) { - error() << "Uncaught std::exception: " << e.what() << ", terminating" << endl; - dbexit( EXIT_UNCAUGHT ); - } + } catch (AssertionException& e) { + log() << "AssertionException handling request, closing client connection: " << e + << endl; + portWithHandler->shutdown(); + } catch (SocketException& e) { + log() << "SocketException handling request, closing client connection: " << e << endl; + portWithHandler->shutdown(); + } catch (const DBException& + e) { // must be right above std::exception to avoid catching subclasses + log() << "DBException handling request, closing client connection: " << e << endl; + portWithHandler->shutdown(); + } catch (std::exception& e) { + error() << "Uncaught std::exception: " << e.what() << ", terminating" << endl; + dbexit(EXIT_UNCAUGHT); + } - // Normal disconnect path. +// Normal disconnect path. #ifdef MONGO_CONFIG_SSL - SSLManagerInterface* manager = getSSLManager(); - if (manager) - manager->cleanupThreadLocals(); + SSLManagerInterface* manager = getSSLManager(); + if (manager) + manager->cleanupThreadLocals(); #endif - return NULL; - } - }; + return NULL; + } +}; - MessageServer * createServer( const MessageServer::Options& opts , MessageHandler * handler ) { - return new PortMessageServer( opts , handler ); - } +MessageServer* createServer(const MessageServer::Options& opts, MessageHandler* handler) { + return new PortMessageServer(opts, handler); +} } // namespace mongo diff --git a/src/mongo/util/net/miniwebserver.cpp b/src/mongo/util/net/miniwebserver.cpp index 888ca70e4de..fc86f95b24f 100644 --- a/src/mongo/util/net/miniwebserver.cpp +++ b/src/mongo/util/net/miniwebserver.cpp @@ -41,205 +41,196 @@ namespace mongo { - using std::shared_ptr; - using std::endl; - using std::stringstream; - using std::vector; - - MiniWebServer::MiniWebServer(const string& name, const string &ip, int port) - : Listener(name, ip, port, false) - {} - - string MiniWebServer::parseURL( const char * buf ) { - const char * urlStart = strchr( buf , ' ' ); - if ( ! urlStart ) - return "/"; - - urlStart++; - - const char * end = strchr( urlStart , ' ' ); - if ( ! end ) { - end = strchr( urlStart , '\r' ); - if ( ! end ) { - end = strchr( urlStart , '\n' ); - } +using std::shared_ptr; +using std::endl; +using std::stringstream; +using std::vector; + +MiniWebServer::MiniWebServer(const string& name, const string& ip, int port) + : Listener(name, ip, port, false) {} + +string MiniWebServer::parseURL(const char* buf) { + const char* urlStart = strchr(buf, ' '); + if (!urlStart) + return "/"; + + urlStart++; + + const char* end = strchr(urlStart, ' '); + if (!end) { + end = strchr(urlStart, '\r'); + if (!end) { + end = strchr(urlStart, '\n'); } - - if ( ! end ) - return "/"; - - int diff = (int)(end-urlStart); - if ( diff < 0 || diff > 255 ) - return "/"; - - return string( urlStart , (int)(end-urlStart) ); } - void MiniWebServer::parseParams( BSONObj & params , string query ) { - if ( query.size() == 0 ) - return; + if (!end) + return "/"; - BSONObjBuilder b; - while ( query.size() ) { + int diff = (int)(end - urlStart); + if (diff < 0 || diff > 255) + return "/"; - string::size_type amp = query.find( "&" ); + return string(urlStart, (int)(end - urlStart)); +} - string cur; - if ( amp == string::npos ) { - cur = query; - query = ""; - } - else { - cur = query.substr( 0 , amp ); - query = query.substr( amp + 1 ); - } +void MiniWebServer::parseParams(BSONObj& params, string query) { + if (query.size() == 0) + return; - string::size_type eq = cur.find( "=" ); - if ( eq == string::npos ) - continue; + BSONObjBuilder b; + while (query.size()) { + string::size_type amp = query.find("&"); - b.append( urlDecode(cur.substr(0,eq)) , urlDecode(cur.substr(eq+1) ) ); + string cur; + if (amp == string::npos) { + cur = query; + query = ""; + } else { + cur = query.substr(0, amp); + query = query.substr(amp + 1); } - params = b.obj(); - } + string::size_type eq = cur.find("="); + if (eq == string::npos) + continue; - string MiniWebServer::parseMethod( const char * headers ) { - const char * end = strchr( headers , ' ' ); - if ( ! end ) - return "GET"; - return string( headers , (int)(end-headers) ); + b.append(urlDecode(cur.substr(0, eq)), urlDecode(cur.substr(eq + 1))); } - const char *MiniWebServer::body( const char *buf ) { - const char *ret = strstr( buf, "\r\n\r\n" ); - return ret ? ret + 4 : ret; - } + params = b.obj(); +} - bool MiniWebServer::fullReceive( const char *buf ) { - const char *bod = body( buf ); - if ( !bod ) - return false; - const char *lenString = "Content-Length:"; - const char *lengthLoc = strstr( buf, lenString ); - if ( !lengthLoc ) - return true; - lengthLoc += strlen( lenString ); - long len = strtol( lengthLoc, 0, 10 ); - if ( long( strlen( bod ) ) == len ) - return true; - return false; - } +string MiniWebServer::parseMethod(const char* headers) { + const char* end = strchr(headers, ' '); + if (!end) + return "GET"; + return string(headers, (int)(end - headers)); +} - void MiniWebServer::accepted(std::shared_ptr<Socket> psock, long long connectionId ) { - char buf[4096]; - int len = 0; - try { +const char* MiniWebServer::body(const char* buf) { + const char* ret = strstr(buf, "\r\n\r\n"); + return ret ? ret + 4 : ret; +} + +bool MiniWebServer::fullReceive(const char* buf) { + const char* bod = body(buf); + if (!bod) + return false; + const char* lenString = "Content-Length:"; + const char* lengthLoc = strstr(buf, lenString); + if (!lengthLoc) + return true; + lengthLoc += strlen(lenString); + long len = strtol(lengthLoc, 0, 10); + if (long(strlen(bod)) == len) + return true; + return false; +} + +void MiniWebServer::accepted(std::shared_ptr<Socket> psock, long long connectionId) { + char buf[4096]; + int len = 0; + try { #ifdef MONGO_CONFIG_SSL - psock->doSSLHandshake(); + psock->doSSLHandshake(); #endif - psock->setTimeout(8); - while ( 1 ) { - int left = sizeof(buf) - 1 - len; - if( left == 0 ) - break; - int x; - try { - x = psock->unsafe_recv( buf + len , left ); - } catch (const SocketException&) { - psock->close(); - return; - } - len += x; - buf[ len ] = 0; - if ( fullReceive( buf ) ) { - break; - } + psock->setTimeout(8); + while (1) { + int left = sizeof(buf) - 1 - len; + if (left == 0) + break; + int x; + try { + x = psock->unsafe_recv(buf + len, left); + } catch (const SocketException&) { + psock->close(); + return; } - } - catch (const SocketException& e) { - LOG(1) << "couldn't recv data via http client: " << e << endl; - return; - } - buf[len] = 0; - - string responseMsg; - int responseCode = 599; - vector<string> headers; - - try { - doRequest(buf, parseURL( buf ), responseMsg, responseCode, headers, psock->remoteAddr() ); - } - catch ( std::exception& e ) { - responseCode = 500; - responseMsg = "error loading page: "; - responseMsg += e.what(); - } - catch ( ... ) { - responseCode = 500; - responseMsg = "unknown error loading page"; - } - - stringstream ss; - ss << "HTTP/1.0 " << responseCode; - if ( responseCode == 200 ) ss << " OK"; - ss << "\r\n"; - if ( headers.empty() ) { - ss << "Content-Type: text/html\r\n"; - } - else { - for ( vector<string>::iterator i = headers.begin(); i != headers.end(); i++ ) { - verify( strncmp("Content-Length", i->c_str(), 14) ); - ss << *i << "\r\n"; + len += x; + buf[len] = 0; + if (fullReceive(buf)) { + break; } } - ss << "Connection: close\r\n"; - ss << "Content-Length: " << responseMsg.size() << "\r\n"; - ss << "\r\n"; - ss << responseMsg; - string response = ss.str(); - - try { - psock->send( response.c_str(), response.size() , "http response" ); - psock->close(); - } - catch ( SocketException& e ) { - LOG(1) << "couldn't send data to http client: " << e << endl; - } + } catch (const SocketException& e) { + LOG(1) << "couldn't recv data via http client: " << e << endl; + return; + } + buf[len] = 0; + + string responseMsg; + int responseCode = 599; + vector<string> headers; + + try { + doRequest(buf, parseURL(buf), responseMsg, responseCode, headers, psock->remoteAddr()); + } catch (std::exception& e) { + responseCode = 500; + responseMsg = "error loading page: "; + responseMsg += e.what(); + } catch (...) { + responseCode = 500; + responseMsg = "unknown error loading page"; } - string MiniWebServer::getHeader( const char * req , const std::string& wanted ) { - const char * headers = strchr( req , '\n' ); - if ( ! headers ) - return ""; - pcrecpp::StringPiece input( headers + 1 ); - - string name; - string val; - pcrecpp::RE re("([\\w\\-]+): (.*?)\r?\n"); - while ( re.Consume( &input, &name, &val) ) { - if ( name == wanted ) - return val; + stringstream ss; + ss << "HTTP/1.0 " << responseCode; + if (responseCode == 200) + ss << " OK"; + ss << "\r\n"; + if (headers.empty()) { + ss << "Content-Type: text/html\r\n"; + } else { + for (vector<string>::iterator i = headers.begin(); i != headers.end(); i++) { + verify(strncmp("Content-Length", i->c_str(), 14)); + ss << *i << "\r\n"; } - return ""; } + ss << "Connection: close\r\n"; + ss << "Content-Length: " << responseMsg.size() << "\r\n"; + ss << "\r\n"; + ss << responseMsg; + string response = ss.str(); + + try { + psock->send(response.c_str(), response.size(), "http response"); + psock->close(); + } catch (SocketException& e) { + LOG(1) << "couldn't send data to http client: " << e << endl; + } +} - string MiniWebServer::urlDecode(const char* s) { - stringstream out; - while(*s) { - if (*s == '+') { - out << ' '; - } - else if (*s == '%') { - out << fromHex(s+1); - s+=2; - } - else { - out << *s; - } - s++; +string MiniWebServer::getHeader(const char* req, const std::string& wanted) { + const char* headers = strchr(req, '\n'); + if (!headers) + return ""; + pcrecpp::StringPiece input(headers + 1); + + string name; + string val; + pcrecpp::RE re("([\\w\\-]+): (.*?)\r?\n"); + while (re.Consume(&input, &name, &val)) { + if (name == wanted) + return val; + } + return ""; +} + +string MiniWebServer::urlDecode(const char* s) { + stringstream out; + while (*s) { + if (*s == '+') { + out << ' '; + } else if (*s == '%') { + out << fromHex(s + 1); + s += 2; + } else { + out << *s; } - return out.str(); + s++; } + return out.str(); +} -} // namespace mongo +} // namespace mongo diff --git a/src/mongo/util/net/miniwebserver.h b/src/mongo/util/net/miniwebserver.h index d0eb06f445a..fa026e6711d 100644 --- a/src/mongo/util/net/miniwebserver.h +++ b/src/mongo/util/net/miniwebserver.h @@ -39,36 +39,38 @@ namespace mongo { - class MiniWebServer : public Listener { - public: - MiniWebServer(const std::string& name, const std::string &ip, int _port); - virtual ~MiniWebServer() {} +class MiniWebServer : public Listener { +public: + MiniWebServer(const std::string& name, const std::string& ip, int _port); + virtual ~MiniWebServer() {} - virtual void doRequest( - const char *rq, // the full request - std::string url, - // set these and return them: - std::string& responseMsg, - int& responseCode, - std::vector<std::string>& headers, // if completely empty, content-type: text/html will be added - const SockAddr &from - ) = 0; + virtual void doRequest( + const char* rq, // the full request + std::string url, + // set these and return them: + std::string& responseMsg, + int& responseCode, + std::vector<std::string>& + headers, // if completely empty, content-type: text/html will be added + const SockAddr& from) = 0; - // --- static helpers ---- + // --- static helpers ---- - static void parseParams( BSONObj & params , std::string query ); + static void parseParams(BSONObj& params, std::string query); - static std::string parseURL( const char * buf ); - static std::string parseMethod( const char * headers ); - static std::string getHeader( const char * headers , const std::string& name ); - static const char *body( const char *buf ); + static std::string parseURL(const char* buf); + static std::string parseMethod(const char* headers); + static std::string getHeader(const char* headers, const std::string& name); + static const char* body(const char* buf); - static std::string urlDecode(const char* s); - static std::string urlDecode(const std::string& s) {return urlDecode(s.c_str());} + static std::string urlDecode(const char* s); + static std::string urlDecode(const std::string& s) { + return urlDecode(s.c_str()); + } - private: - void accepted(std::shared_ptr<Socket> psocket, long long connectionId ); - static bool fullReceive( const char *buf ); - }; +private: + void accepted(std::shared_ptr<Socket> psocket, long long connectionId); + static bool fullReceive(const char* buf); +}; -} // namespace mongo +} // namespace mongo diff --git a/src/mongo/util/net/sock.cpp b/src/mongo/util/net/sock.cpp index 42a5dcdba04..d9deed4036e 100644 --- a/src/mongo/util/net/sock.cpp +++ b/src/mongo/util/net/sock.cpp @@ -34,17 +34,17 @@ #include "mongo/util/net/sock.h" #if !defined(_WIN32) -# include <sys/socket.h> -# include <sys/types.h> -# include <sys/un.h> -# include <netinet/in.h> -# include <netinet/tcp.h> -# include <arpa/inet.h> -# include <errno.h> -# include <netdb.h> -# if defined(__OpenBSD__) -# include <sys/uio.h> -# endif +#include <sys/socket.h> +#include <sys/types.h> +#include <sys/un.h> +#include <netinet/in.h> +#include <netinet/tcp.h> +#include <arpa/inet.h> +#include <errno.h> +#include <netdb.h> +#if defined(__OpenBSD__) +#include <sys/uio.h> +#endif #endif #include "mongo/config.h" @@ -63,930 +63,928 @@ namespace mongo { - using std::endl; - using std::pair; - using std::string; - using std::stringstream; - using std::vector; - - MONGO_FP_DECLARE(throwSockExcep); - - static bool ipv6 = false; - void enableIPv6(bool state) { ipv6 = state; } - bool IPv6Enabled() { return ipv6; } - - void setSockTimeouts(int sock, double secs) { - bool report = shouldLog(logger::LogSeverity::Debug(4)); - DEV report = true; +using std::endl; +using std::pair; +using std::string; +using std::stringstream; +using std::vector; + +MONGO_FP_DECLARE(throwSockExcep); + +static bool ipv6 = false; +void enableIPv6(bool state) { + ipv6 = state; +} +bool IPv6Enabled() { + return ipv6; +} + +void setSockTimeouts(int sock, double secs) { + bool report = shouldLog(logger::LogSeverity::Debug(4)); + DEV report = true; #if defined(_WIN32) - DWORD timeout = secs * 1000; // Windows timeout is a DWORD, in milliseconds. - int status = - setsockopt( sock, SOL_SOCKET, SO_RCVTIMEO, - reinterpret_cast<char*>(&timeout), sizeof(DWORD) ); - if (report && (status == SOCKET_ERROR)) - log() << "unable to set SO_RCVTIMEO: " - << errnoWithDescription(WSAGetLastError()) << endl; - status = setsockopt( sock, SOL_SOCKET, SO_SNDTIMEO, - reinterpret_cast<char*>(&timeout), sizeof(DWORD) ); - DEV if (report && (status == SOCKET_ERROR)) - log() << "unable to set SO_SNDTIMEO: " - << errnoWithDescription(WSAGetLastError()) << endl; + DWORD timeout = secs * 1000; // Windows timeout is a DWORD, in milliseconds. + int status = + setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast<char*>(&timeout), sizeof(DWORD)); + if (report && (status == SOCKET_ERROR)) + log() << "unable to set SO_RCVTIMEO: " << errnoWithDescription(WSAGetLastError()) << endl; + status = + setsockopt(sock, SOL_SOCKET, SO_SNDTIMEO, reinterpret_cast<char*>(&timeout), sizeof(DWORD)); + DEV if (report && (status == SOCKET_ERROR)) log() + << "unable to set SO_SNDTIMEO: " << errnoWithDescription(WSAGetLastError()) << endl; #else - struct timeval tv; - tv.tv_sec = (int)secs; - tv.tv_usec = (int)((long long)(secs*1000*1000) % (1000*1000)); - bool ok = setsockopt( sock, SOL_SOCKET, SO_RCVTIMEO, (char *) &tv, sizeof(tv) ) == 0; - if( report && !ok ) log() << "unable to set SO_RCVTIMEO" << endl; - ok = setsockopt( sock, SOL_SOCKET, SO_SNDTIMEO, (char *) &tv, sizeof(tv) ) == 0; - DEV if( report && !ok ) log() << "unable to set SO_SNDTIMEO" << endl; + struct timeval tv; + tv.tv_sec = (int)secs; + tv.tv_usec = (int)((long long)(secs * 1000 * 1000) % (1000 * 1000)); + bool ok = setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, (char*)&tv, sizeof(tv)) == 0; + if (report && !ok) + log() << "unable to set SO_RCVTIMEO" << endl; + ok = setsockopt(sock, SOL_SOCKET, SO_SNDTIMEO, (char*)&tv, sizeof(tv)) == 0; + DEV if (report && !ok) log() << "unable to set SO_SNDTIMEO" << endl; #endif - } +} #if defined(_WIN32) - void disableNagle(int sock) { - int x = 1; - if ( setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, (char *) &x, sizeof(x)) ) - error() << "disableNagle failed: " << errnoWithDescription() << endl; - if ( setsockopt(sock, SOL_SOCKET, SO_KEEPALIVE, (char *) &x, sizeof(x)) ) - error() << "SO_KEEPALIVE failed: " << errnoWithDescription() << endl; - } +void disableNagle(int sock) { + int x = 1; + if (setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, (char*)&x, sizeof(x))) + error() << "disableNagle failed: " << errnoWithDescription() << endl; + if (setsockopt(sock, SOL_SOCKET, SO_KEEPALIVE, (char*)&x, sizeof(x))) + error() << "SO_KEEPALIVE failed: " << errnoWithDescription() << endl; +} #else - - void disableNagle(int sock) { - int x = 1; + +void disableNagle(int sock) { + int x = 1; #ifdef SOL_TCP - int level = SOL_TCP; + int level = SOL_TCP; #else - int level = SOL_SOCKET; + int level = SOL_SOCKET; #endif - if ( setsockopt(sock, level, TCP_NODELAY, (char *) &x, sizeof(x)) ) - error() << "disableNagle failed: " << errnoWithDescription() << endl; + if (setsockopt(sock, level, TCP_NODELAY, (char*)&x, sizeof(x))) + error() << "disableNagle failed: " << errnoWithDescription() << endl; #ifdef SO_KEEPALIVE - if ( setsockopt(sock, SOL_SOCKET, SO_KEEPALIVE, (char *) &x, sizeof(x)) ) - error() << "SO_KEEPALIVE failed: " << errnoWithDescription() << endl; - -# ifdef __linux__ - socklen_t len = sizeof(x); - if ( getsockopt(sock, level, TCP_KEEPIDLE, (char *) &x, &len) ) - error() << "can't get TCP_KEEPIDLE: " << errnoWithDescription() << endl; - - if (x > 300) { - x = 300; - if ( setsockopt(sock, level, TCP_KEEPIDLE, (char *) &x, sizeof(x)) ) { - error() << "can't set TCP_KEEPIDLE: " << errnoWithDescription() << endl; - } - } + if (setsockopt(sock, SOL_SOCKET, SO_KEEPALIVE, (char*)&x, sizeof(x))) + error() << "SO_KEEPALIVE failed: " << errnoWithDescription() << endl; - len = sizeof(x); // just in case it changed - if ( getsockopt(sock, level, TCP_KEEPINTVL, (char *) &x, &len) ) - error() << "can't get TCP_KEEPINTVL: " << errnoWithDescription() << endl; +#ifdef __linux__ + socklen_t len = sizeof(x); + if (getsockopt(sock, level, TCP_KEEPIDLE, (char*)&x, &len)) + error() << "can't get TCP_KEEPIDLE: " << errnoWithDescription() << endl; - if (x > 300) { - x = 300; - if ( setsockopt(sock, level, TCP_KEEPINTVL, (char *) &x, sizeof(x)) ) { - error() << "can't set TCP_KEEPINTVL: " << errnoWithDescription() << endl; - } + if (x > 300) { + x = 300; + if (setsockopt(sock, level, TCP_KEEPIDLE, (char*)&x, sizeof(x))) { + error() << "can't set TCP_KEEPIDLE: " << errnoWithDescription() << endl; } -# endif -#endif + } + + len = sizeof(x); // just in case it changed + if (getsockopt(sock, level, TCP_KEEPINTVL, (char*)&x, &len)) + error() << "can't get TCP_KEEPINTVL: " << errnoWithDescription() << endl; + if (x > 300) { + x = 300; + if (setsockopt(sock, level, TCP_KEEPINTVL, (char*)&x, sizeof(x))) { + error() << "can't set TCP_KEEPINTVL: " << errnoWithDescription() << endl; + } } +#endif +#endif +} #endif - string getAddrInfoStrError(int code) { +string getAddrInfoStrError(int code) { #if !defined(_WIN32) - return gai_strerror(code); + return gai_strerror(code); #else - /* gai_strerrorA is not threadsafe on windows. don't use it. */ - return errnoWithDescription(code); + /* gai_strerrorA is not threadsafe on windows. don't use it. */ + return errnoWithDescription(code); #endif - } - - // --- SockAddr - SockAddr::SockAddr() { - addressSize = sizeof(sa); - memset(&sa, 0, sizeof(sa)); - sa.ss_family = AF_UNSPEC; - _isValid = true; - } - - SockAddr::SockAddr(int sourcePort) { - memset(as<sockaddr_in>().sin_zero, 0, sizeof(as<sockaddr_in>().sin_zero)); - as<sockaddr_in>().sin_family = AF_INET; - as<sockaddr_in>().sin_port = htons(sourcePort); - as<sockaddr_in>().sin_addr.s_addr = htonl(INADDR_ANY); - addressSize = sizeof(sockaddr_in); - _isValid = true; - } - - SockAddr::SockAddr(const char * _iporhost , int port) { - string target = _iporhost; - if( target == "localhost" ) { - target = "127.0.0.1"; - } - - if( mongoutils::str::contains(target, '/') ) { +} + +// --- SockAddr +SockAddr::SockAddr() { + addressSize = sizeof(sa); + memset(&sa, 0, sizeof(sa)); + sa.ss_family = AF_UNSPEC; + _isValid = true; +} + +SockAddr::SockAddr(int sourcePort) { + memset(as<sockaddr_in>().sin_zero, 0, sizeof(as<sockaddr_in>().sin_zero)); + as<sockaddr_in>().sin_family = AF_INET; + as<sockaddr_in>().sin_port = htons(sourcePort); + as<sockaddr_in>().sin_addr.s_addr = htonl(INADDR_ANY); + addressSize = sizeof(sockaddr_in); + _isValid = true; +} + +SockAddr::SockAddr(const char* _iporhost, int port) { + string target = _iporhost; + if (target == "localhost") { + target = "127.0.0.1"; + } + + if (mongoutils::str::contains(target, '/')) { #ifdef _WIN32 - uassert(13080, "no unix socket support on windows", false); + uassert(13080, "no unix socket support on windows", false); #endif - uassert(13079, "path to unix socket too long", - target.size() < sizeof(as<sockaddr_un>().sun_path)); - as<sockaddr_un>().sun_family = AF_UNIX; - strcpy(as<sockaddr_un>().sun_path, target.c_str()); - addressSize = sizeof(sockaddr_un); - _isValid = true; - return; - } + uassert(13079, + "path to unix socket too long", + target.size() < sizeof(as<sockaddr_un>().sun_path)); + as<sockaddr_un>().sun_family = AF_UNIX; + strcpy(as<sockaddr_un>().sun_path, target.c_str()); + addressSize = sizeof(sockaddr_un); + _isValid = true; + return; + } - addrinfo* addrs = NULL; - addrinfo hints; - memset(&hints, 0, sizeof(addrinfo)); - hints.ai_socktype = SOCK_STREAM; - //hints.ai_flags = AI_ADDRCONFIG; // This is often recommended but don't do it. - // SERVER-1579 - hints.ai_flags |= AI_NUMERICHOST; // first pass tries w/o DNS lookup - hints.ai_family = (IPv6Enabled() ? AF_UNSPEC : AF_INET); + addrinfo* addrs = NULL; + addrinfo hints; + memset(&hints, 0, sizeof(addrinfo)); + hints.ai_socktype = SOCK_STREAM; + // hints.ai_flags = AI_ADDRCONFIG; // This is often recommended but don't do it. + // SERVER-1579 + hints.ai_flags |= AI_NUMERICHOST; // first pass tries w/o DNS lookup + hints.ai_family = (IPv6Enabled() ? AF_UNSPEC : AF_INET); - StringBuilder ss; - ss << port; - int ret = getaddrinfo(target.c_str(), ss.str().c_str(), &hints, &addrs); + StringBuilder ss; + ss << port; + int ret = getaddrinfo(target.c_str(), ss.str().c_str(), &hints, &addrs); - // old C compilers on IPv6-capable hosts return EAI_NODATA error +// old C compilers on IPv6-capable hosts return EAI_NODATA error #ifdef EAI_NODATA - int nodata = (ret == EAI_NODATA); + int nodata = (ret == EAI_NODATA); #else - int nodata = false; + int nodata = false; #endif - if ( (ret == EAI_NONAME || nodata) ) { - // iporhost isn't an IP address, allow DNS lookup - hints.ai_flags &= ~AI_NUMERICHOST; - ret = getaddrinfo(target.c_str(), ss.str().c_str(), &hints, &addrs); - } - - if (ret) { - // we were unsuccessful - if( target != "0.0.0.0" ) { // don't log if this as it is a - // CRT construction and log() may not work yet. - log() << "getaddrinfo(\"" << target << "\") failed: " << - getAddrInfoStrError(ret) << endl; - _isValid = false; - return; - } - *this = SockAddr(port); + if ((ret == EAI_NONAME || nodata)) { + // iporhost isn't an IP address, allow DNS lookup + hints.ai_flags &= ~AI_NUMERICHOST; + ret = getaddrinfo(target.c_str(), ss.str().c_str(), &hints, &addrs); + } + + if (ret) { + // we were unsuccessful + if (target != "0.0.0.0") { // don't log if this as it is a + // CRT construction and log() may not work yet. + log() << "getaddrinfo(\"" << target << "\") failed: " << getAddrInfoStrError(ret) + << endl; + _isValid = false; return; } - - //TODO: handle other addresses in linked list; - fassert(16501, addrs->ai_addrlen <= sizeof(sa)); - memcpy(&sa, addrs->ai_addr, addrs->ai_addrlen); - addressSize = addrs->ai_addrlen; - freeaddrinfo(addrs); - _isValid = true; + *this = SockAddr(port); + return; } - bool SockAddr::isLocalHost() const { - switch (getType()) { - case AF_INET: return getAddr() == "127.0.0.1"; - case AF_INET6: return getAddr() == "::1"; - case AF_UNIX: return true; - default: return false; - } - fassert(16502, false); - return false; + // TODO: handle other addresses in linked list; + fassert(16501, addrs->ai_addrlen <= sizeof(sa)); + memcpy(&sa, addrs->ai_addr, addrs->ai_addrlen); + addressSize = addrs->ai_addrlen; + freeaddrinfo(addrs); + _isValid = true; +} + +bool SockAddr::isLocalHost() const { + switch (getType()) { + case AF_INET: + return getAddr() == "127.0.0.1"; + case AF_INET6: + return getAddr() == "::1"; + case AF_UNIX: + return true; + default: + return false; } + fassert(16502, false); + return false; +} - string SockAddr::toString(bool includePort) const { - string out = getAddr(); - if (includePort && getType() != AF_UNIX && getType() != AF_UNSPEC) - out += mongoutils::str::stream() << ':' << getPort(); - return out; - } - - sa_family_t SockAddr::getType() const { - return sa.ss_family; - } - - unsigned SockAddr::getPort() const { - switch (getType()) { - case AF_INET: return ntohs(as<sockaddr_in>().sin_port); - case AF_INET6: return ntohs(as<sockaddr_in6>().sin6_port); - case AF_UNIX: return 0; - case AF_UNSPEC: return 0; - default: massert(SOCK_FAMILY_UNKNOWN_ERROR, "unsupported address family", false); return 0; - } +string SockAddr::toString(bool includePort) const { + string out = getAddr(); + if (includePort && getType() != AF_UNIX && getType() != AF_UNSPEC) + out += mongoutils::str::stream() << ':' << getPort(); + return out; +} + +sa_family_t SockAddr::getType() const { + return sa.ss_family; +} + +unsigned SockAddr::getPort() const { + switch (getType()) { + case AF_INET: + return ntohs(as<sockaddr_in>().sin_port); + case AF_INET6: + return ntohs(as<sockaddr_in6>().sin6_port); + case AF_UNIX: + return 0; + case AF_UNSPEC: + return 0; + default: + massert(SOCK_FAMILY_UNKNOWN_ERROR, "unsupported address family", false); + return 0; } - - std::string SockAddr::getAddr() const { - switch (getType()) { +} + +std::string SockAddr::getAddr() const { + switch (getType()) { case AF_INET: case AF_INET6: { - const int buflen=128; + const int buflen = 128; char buffer[buflen]; int ret = getnameinfo(raw(), addressSize, buffer, buflen, NULL, 0, NI_NUMERICHOST); - massert(13082, mongoutils::str::stream() << "getnameinfo error " - << getAddrInfoStrError(ret), ret == 0); + massert(13082, + mongoutils::str::stream() << "getnameinfo error " << getAddrInfoStrError(ret), + ret == 0); return buffer; } - - case AF_UNIX: - return (as<sockaddr_un>().sun_path[0] != '\0' ? as<sockaddr_un>().sun_path : - "anonymous unix socket"); - case AF_UNSPEC: + + case AF_UNIX: + return (as<sockaddr_un>().sun_path[0] != '\0' ? as<sockaddr_un>().sun_path + : "anonymous unix socket"); + case AF_UNSPEC: return "(NONE)"; - default: - massert(SOCK_FAMILY_UNKNOWN_ERROR, "unsupported address family", false); return ""; - } + default: + massert(SOCK_FAMILY_UNKNOWN_ERROR, "unsupported address family", false); + return ""; } +} - bool SockAddr::operator==(const SockAddr& r) const { - if (getType() != r.getType()) - return false; - - if (getPort() != r.getPort()) - return false; - - switch (getType()) { - case AF_INET: +bool SockAddr::operator==(const SockAddr& r) const { + if (getType() != r.getType()) + return false; + + if (getPort() != r.getPort()) + return false; + + switch (getType()) { + case AF_INET: return as<sockaddr_in>().sin_addr.s_addr == r.as<sockaddr_in>().sin_addr.s_addr; - case AF_INET6: - return memcmp(as<sockaddr_in6>().sin6_addr.s6_addr, - r.as<sockaddr_in6>().sin6_addr.s6_addr, + case AF_INET6: + return memcmp(as<sockaddr_in6>().sin6_addr.s6_addr, + r.as<sockaddr_in6>().sin6_addr.s6_addr, sizeof(in6_addr)) == 0; - case AF_UNIX: + case AF_UNIX: return strcmp(as<sockaddr_un>().sun_path, r.as<sockaddr_un>().sun_path) == 0; - case AF_UNSPEC: - return true; // assume all unspecified addresses are the same - default: + case AF_UNSPEC: + return true; // assume all unspecified addresses are the same + default: massert(SOCK_FAMILY_UNKNOWN_ERROR, "unsupported address family", false); - } - return false; } - - bool SockAddr::operator!=(const SockAddr& r) const { - return !(*this == r); - } - - bool SockAddr::operator<(const SockAddr& r) const { - if (getType() < r.getType()) - return true; - else if (getType() > r.getType()) - return false; - - if (getPort() < r.getPort()) - return true; - else if (getPort() > r.getPort()) - return false; - - switch (getType()) { - case AF_INET: + return false; +} + +bool SockAddr::operator!=(const SockAddr& r) const { + return !(*this == r); +} + +bool SockAddr::operator<(const SockAddr& r) const { + if (getType() < r.getType()) + return true; + else if (getType() > r.getType()) + return false; + + if (getPort() < r.getPort()) + return true; + else if (getPort() > r.getPort()) + return false; + + switch (getType()) { + case AF_INET: return as<sockaddr_in>().sin_addr.s_addr < r.as<sockaddr_in>().sin_addr.s_addr; - case AF_INET6: - return memcmp(as<sockaddr_in6>().sin6_addr.s6_addr, - r.as<sockaddr_in6>().sin6_addr.s6_addr, + case AF_INET6: + return memcmp(as<sockaddr_in6>().sin6_addr.s6_addr, + r.as<sockaddr_in6>().sin6_addr.s6_addr, sizeof(in6_addr)) < 0; - case AF_UNIX: + case AF_UNIX: return strcmp(as<sockaddr_un>().sun_path, r.as<sockaddr_un>().sun_path) < 0; - case AF_UNSPEC: + case AF_UNSPEC: return false; - default: + default: massert(SOCK_FAMILY_UNKNOWN_ERROR, "unsupported address family", false); - } - return false; } + return false; +} - string makeUnixSockPath(int port) { - return mongoutils::str::stream() << serverGlobalParams.socket << "/mongodb-" << port - << ".sock"; - } +string makeUnixSockPath(int port) { + return mongoutils::str::stream() << serverGlobalParams.socket << "/mongodb-" << port << ".sock"; +} - // If an ip address is passed in, just return that. If a hostname is passed - // in, look up its ip and return that. Returns "" on failure. - string hostbyname(const char *hostname) { - SockAddr sockAddr(hostname, 0); - if (!sockAddr.isValid() || sockAddr.getAddr() == "0.0.0.0") - return ""; - else - return sockAddr.getAddr(); - } - - // --- my -- +// If an ip address is passed in, just return that. If a hostname is passed +// in, look up its ip and return that. Returns "" on failure. +string hostbyname(const char* hostname) { + SockAddr sockAddr(hostname, 0); + if (!sockAddr.isValid() || sockAddr.getAddr() == "0.0.0.0") + return ""; + else + return sockAddr.getAddr(); +} - DiagStr& _hostNameCached = *(new DiagStr); // this is also written to from commands/cloud.cpp +// --- my -- - string getHostName() { - char buf[256]; - int ec = gethostname(buf, 127); - if ( ec || *buf == 0 ) { - log() << "can't get this server's hostname " << errnoWithDescription() << endl; - return ""; - } - return buf; - } +DiagStr& _hostNameCached = *(new DiagStr); // this is also written to from commands/cloud.cpp - /** we store our host name once */ - string getHostNameCached() { - string temp = _hostNameCached.get(); - if (_hostNameCached.empty()) { - temp = getHostName(); - _hostNameCached = temp; - } - return temp; +string getHostName() { + char buf[256]; + int ec = gethostname(buf, 127); + if (ec || *buf == 0) { + log() << "can't get this server's hostname " << errnoWithDescription() << endl; + return ""; } + return buf; +} - string prettyHostName() { - StringBuilder s; - s << getHostNameCached(); - if (serverGlobalParams.port != ServerGlobalParams::DefaultDBPort) - s << ':' << mongo::serverGlobalParams.port; - return s.str(); +/** we store our host name once */ +string getHostNameCached() { + string temp = _hostNameCached.get(); + if (_hostNameCached.empty()) { + temp = getHostName(); + _hostNameCached = temp; } + return temp; +} + +string prettyHostName() { + StringBuilder s; + s << getHostNameCached(); + if (serverGlobalParams.port != ServerGlobalParams::DefaultDBPort) + s << ':' << mongo::serverGlobalParams.port; + return s.str(); +} - // --------- SocketException ---------- +// --------- SocketException ---------- #ifdef MSG_NOSIGNAL - const int portSendFlags = MSG_NOSIGNAL; - const int portRecvFlags = MSG_NOSIGNAL; +const int portSendFlags = MSG_NOSIGNAL; +const int portRecvFlags = MSG_NOSIGNAL; #else - const int portSendFlags = 0; - const int portRecvFlags = 0; +const int portSendFlags = 0; +const int portRecvFlags = 0; #endif - string SocketException::toString() const { - stringstream ss; - ss << _ei.code << " socket exception [" << _getStringType(_type) << "] "; - - if ( _server.size() ) - ss << "server [" << _server << "] "; - - if ( _extra.size() ) - ss << _extra; - - return ss.str(); - } +string SocketException::toString() const { + stringstream ss; + ss << _ei.code << " socket exception [" << _getStringType(_type) << "] "; + + if (_server.size()) + ss << "server [" << _server << "] "; - // ------------ Socket ----------------- + if (_extra.size()) + ss << _extra; - static int socketGetLastError() { + return ss.str(); +} + +// ------------ Socket ----------------- + +static int socketGetLastError() { #ifdef _WIN32 - return WSAGetLastError(); + return WSAGetLastError(); #else - return errno; + return errno; #endif +} + +static SockAddr getLocalAddrForBoundSocketFd(int fd) { + SockAddr result; + int rc = getsockname(fd, result.raw(), &result.addressSize); + if (rc != 0) { + warning() << "Could not resolve local address for socket with fd " << fd << ": " + << getAddrInfoStrError(socketGetLastError()); + result = SockAddr(); + } + return result; +} + +Socket::Socket(int fd, const SockAddr& remote) + : _fd(fd), + _remote(remote), + _timeout(0), + _lastValidityCheckAtSecs(time(0)), + _logLevel(logger::LogSeverity::Log()) { + _init(); + if (fd >= 0) { + _local = getLocalAddrForBoundSocketFd(_fd); } +} - static SockAddr getLocalAddrForBoundSocketFd(int fd) { - SockAddr result; - int rc = getsockname(fd, result.raw(), &result.addressSize); - if (rc != 0) { - warning() << "Could not resolve local address for socket with fd " << fd << ": " << - getAddrInfoStrError(socketGetLastError()); - result = SockAddr(); - } - return result; - } - - Socket::Socket(int fd , const SockAddr& remote) : - _fd(fd), _remote(remote), _timeout(0), _lastValidityCheckAtSecs(time(0)), - _logLevel(logger::LogSeverity::Log()) { - _init(); - if (fd >= 0) { - _local = getLocalAddrForBoundSocketFd(_fd); - } - } +Socket::Socket(double timeout, logger::LogSeverity ll) : _logLevel(ll) { + _fd = -1; + _timeout = timeout; + _lastValidityCheckAtSecs = time(0); + _init(); +} - Socket::Socket( double timeout, logger::LogSeverity ll ) : _logLevel(ll) { - _fd = -1; - _timeout = timeout; - _lastValidityCheckAtSecs = time(0); - _init(); - } +Socket::~Socket() { + close(); +} - Socket::~Socket() { - close(); - } - - void Socket::_init() { - _bytesOut = 0; - _bytesIn = 0; - _awaitingHandshake = true; +void Socket::_init() { + _bytesOut = 0; + _bytesIn = 0; + _awaitingHandshake = true; #ifdef MONGO_CONFIG_SSL - _sslManager = 0; + _sslManager = 0; #endif - } +} - void Socket::close() { - if ( _fd >= 0 ) { - // Stop any blocking reads/writes, and prevent new reads/writes +void Socket::close() { + if (_fd >= 0) { +// Stop any blocking reads/writes, and prevent new reads/writes #if defined(_WIN32) - shutdown( _fd, SD_BOTH ); + shutdown(_fd, SD_BOTH); #else - shutdown( _fd, SHUT_RDWR ); + shutdown(_fd, SHUT_RDWR); #endif - closesocket( _fd ); - _fd = -1; - } + closesocket(_fd); + _fd = -1; } +} #ifdef MONGO_CONFIG_SSL - bool Socket::secure(SSLManagerInterface* mgr, const std::string& remoteHost) { - fassert(16503, mgr); - if ( _fd < 0 ) { - return false; - } - _sslManager = mgr; - _sslConnection.reset(_sslManager->connect(this)); - mgr->parseAndValidatePeerCertificate(_sslConnection.get(), remoteHost); - return true; - } - - void Socket::secureAccepted( SSLManagerInterface* ssl ) { - _sslManager = ssl; +bool Socket::secure(SSLManagerInterface* mgr, const std::string& remoteHost) { + fassert(16503, mgr); + if (_fd < 0) { + return false; } - - std::string Socket::doSSLHandshake(const char* firstBytes, int len) { - if (!_sslManager) return ""; - fassert(16506, _fd); - if (_sslConnection.get()) { - throw SocketException(SocketException::RECV_ERROR, - "Attempt to call SSL_accept on already secure Socket from " + + _sslManager = mgr; + _sslConnection.reset(_sslManager->connect(this)); + mgr->parseAndValidatePeerCertificate(_sslConnection.get(), remoteHost); + return true; +} + +void Socket::secureAccepted(SSLManagerInterface* ssl) { + _sslManager = ssl; +} + +std::string Socket::doSSLHandshake(const char* firstBytes, int len) { + if (!_sslManager) + return ""; + fassert(16506, _fd); + if (_sslConnection.get()) { + throw SocketException(SocketException::RECV_ERROR, + "Attempt to call SSL_accept on already secure Socket from " + remoteString()); - } - _sslConnection.reset(_sslManager->accept(this, firstBytes, len)); - return _sslManager->parseAndValidatePeerCertificate(_sslConnection.get(), ""); } + _sslConnection.reset(_sslManager->accept(this, firstBytes, len)); + return _sslManager->parseAndValidatePeerCertificate(_sslConnection.get(), ""); +} #endif - class ConnectBG : public BackgroundJob { - public: - ConnectBG(int sock, SockAddr remote) : _sock(sock), _remote(remote) { } +class ConnectBG : public BackgroundJob { +public: + ConnectBG(int sock, SockAddr remote) : _sock(sock), _remote(remote) {} - void run() { + void run() { #if defined(_WIN32) - if ((_res = _connect()) == SOCKET_ERROR) { - _errnoWithDescription = errnoWithDescription(); - } + if ((_res = _connect()) == SOCKET_ERROR) { + _errnoWithDescription = errnoWithDescription(); + } #else - while ((_res = _connect()) == -1) { - const int error = errno; - if (error != EINTR) { - _errnoWithDescription = errnoWithDescription(error); - break; - } + while ((_res = _connect()) == -1) { + const int error = errno; + if (error != EINTR) { + _errnoWithDescription = errnoWithDescription(error); + break; } -#endif } +#endif + } - std::string name() const { return "ConnectBG"; } - std::string getErrnoWithDescription() const { return _errnoWithDescription; } - int inError() const { return _res; } + std::string name() const { + return "ConnectBG"; + } + std::string getErrnoWithDescription() const { + return _errnoWithDescription; + } + int inError() const { + return _res; + } - private: - int _connect() const { - return ::connect(_sock, _remote.raw(), _remote.addressSize); - } +private: + int _connect() const { + return ::connect(_sock, _remote.raw(), _remote.addressSize); + } - int _sock; - int _res; - SockAddr _remote; - std::string _errnoWithDescription; - }; + int _sock; + int _res; + SockAddr _remote; + std::string _errnoWithDescription; +}; - bool Socket::connect(SockAddr& remote) { - _remote = remote; +bool Socket::connect(SockAddr& remote) { + _remote = remote; - _fd = socket(remote.getType(), SOCK_STREAM, 0); - if ( _fd == INVALID_SOCKET ) { - LOG(_logLevel) << "ERROR: connect invalid socket " << errnoWithDescription() << endl; - return false; - } + _fd = socket(remote.getType(), SOCK_STREAM, 0); + if (_fd == INVALID_SOCKET) { + LOG(_logLevel) << "ERROR: connect invalid socket " << errnoWithDescription() << endl; + return false; + } - if ( _timeout > 0 ) { - setTimeout( _timeout ); - } + if (_timeout > 0) { + setTimeout(_timeout); + } - static const unsigned int connectTimeoutMillis = 5000; - ConnectBG bg(_fd, remote); - bg.go(); - if ( bg.wait(connectTimeoutMillis) ) { - if ( bg.inError() ) { - warning() << "Failed to connect to " - << _remote.getAddr() << ":" << _remote.getPort() - << ", reason: " << bg.getErrnoWithDescription() << endl; - close(); - return false; - } - } - else { - // time out the connect + static const unsigned int connectTimeoutMillis = 5000; + ConnectBG bg(_fd, remote); + bg.go(); + if (bg.wait(connectTimeoutMillis)) { + if (bg.inError()) { + warning() << "Failed to connect to " << _remote.getAddr() << ":" << _remote.getPort() + << ", reason: " << bg.getErrnoWithDescription() << endl; close(); - bg.wait(); // so bg stays in scope until bg thread terminates - warning() << "Failed to connect to " - << _remote.getAddr() << ":" << _remote.getPort() - << " after " << connectTimeoutMillis << " milliseconds, giving up." << endl; return false; } + } else { + // time out the connect + close(); + bg.wait(); // so bg stays in scope until bg thread terminates + warning() << "Failed to connect to " << _remote.getAddr() << ":" << _remote.getPort() + << " after " << connectTimeoutMillis << " milliseconds, giving up." << endl; + return false; + } - if (remote.getType() != AF_UNIX) - disableNagle(_fd); + if (remote.getType() != AF_UNIX) + disableNagle(_fd); #ifdef SO_NOSIGPIPE - // ignore SIGPIPE signals on osx, to avoid process exit - const int one = 1; - setsockopt( _fd , SOL_SOCKET, SO_NOSIGPIPE, &one, sizeof(int)); + // ignore SIGPIPE signals on osx, to avoid process exit + const int one = 1; + setsockopt(_fd, SOL_SOCKET, SO_NOSIGPIPE, &one, sizeof(int)); #endif - _local = getLocalAddrForBoundSocketFd(_fd); + _local = getLocalAddrForBoundSocketFd(_fd); - _fdCreationMicroSec = curTimeMicros64(); + _fdCreationMicroSec = curTimeMicros64(); - _awaitingHandshake = false; + _awaitingHandshake = false; - return true; - } + return true; +} - // throws if SSL_write or send fails - int Socket::_send( const char * data , int len, const char * context ) { +// throws if SSL_write or send fails +int Socket::_send(const char* data, int len, const char* context) { #ifdef MONGO_CONFIG_SSL - if ( _sslConnection.get() ) { - return _sslManager->SSL_write( _sslConnection.get() , data , len ); - } -#endif - int ret = ::send( _fd , data , len , portSendFlags ); - if (ret < 0) { - handleSendError(ret, context); - } - return ret; + if (_sslConnection.get()) { + return _sslManager->SSL_write(_sslConnection.get(), data, len); } - - // sends all data or throws an exception - void Socket::send( const char * data , int len, const char *context ) { - while( len > 0 ) { - int ret = -1; - if (MONGO_FAIL_POINT(throwSockExcep)) { +#endif + int ret = ::send(_fd, data, len, portSendFlags); + if (ret < 0) { + handleSendError(ret, context); + } + return ret; +} + +// sends all data or throws an exception +void Socket::send(const char* data, int len, const char* context) { + while (len > 0) { + int ret = -1; + if (MONGO_FAIL_POINT(throwSockExcep)) { #if defined(_WIN32) - WSASetLastError(WSAENETUNREACH); + WSASetLastError(WSAENETUNREACH); #else - errno = ENETUNREACH; + errno = ENETUNREACH; #endif - handleSendError(ret, context); - } - else { - ret = _send(data, len, context); - } - - _bytesOut += ret; + handleSendError(ret, context); + } else { + ret = _send(data, len, context); + } - fassert(16507, ret <= len); - len -= ret; - data += ret; + _bytesOut += ret; - } + fassert(16507, ret <= len); + len -= ret; + data += ret; } +} - void Socket::_send( const vector< pair< char *, int > > &data, const char *context ) { - for (vector< pair<char *, int> >::const_iterator i = data.begin(); - i != data.end(); - ++i) { - char * data = i->first; - int len = i->second; - send( data, len, context ); - } +void Socket::_send(const vector<pair<char*, int>>& data, const char* context) { + for (vector<pair<char*, int>>::const_iterator i = data.begin(); i != data.end(); ++i) { + char* data = i->first; + int len = i->second; + send(data, len, context); } +} - /** sends all data or throws an exception - * @param context descriptive for logging - */ - void Socket::send( const vector< pair< char *, int > > &data, const char *context ) { - +/** sends all data or throws an exception + * @param context descriptive for logging + */ +void Socket::send(const vector<pair<char*, int>>& data, const char* context) { #ifdef MONGO_CONFIG_SSL - if ( _sslConnection.get() ) { - _send( data , context ); - return; - } + if (_sslConnection.get()) { + _send(data, context); + return; + } #endif #if defined(_WIN32) - // TODO use scatter/gather api - _send( data , context ); + // TODO use scatter/gather api + _send(data, context); #else - vector<struct iovec> d( data.size() ); - int i = 0; - for (vector< pair<char *, int> >::const_iterator j = data.begin(); - j != data.end(); - ++j) { - if ( j->second > 0 ) { - d[ i ].iov_base = j->first; - d[ i ].iov_len = j->second; - ++i; - _bytesOut += j->second; - } - } - struct msghdr meta; - memset( &meta, 0, sizeof( meta ) ); - meta.msg_iov = &d[ 0 ]; - meta.msg_iovlen = d.size(); - - while( meta.msg_iovlen > 0 ) { - int ret = -1; - if (MONGO_FAIL_POINT(throwSockExcep)) { + vector<struct iovec> d(data.size()); + int i = 0; + for (vector<pair<char*, int>>::const_iterator j = data.begin(); j != data.end(); ++j) { + if (j->second > 0) { + d[i].iov_base = j->first; + d[i].iov_len = j->second; + ++i; + _bytesOut += j->second; + } + } + struct msghdr meta; + memset(&meta, 0, sizeof(meta)); + meta.msg_iov = &d[0]; + meta.msg_iovlen = d.size(); + + while (meta.msg_iovlen > 0) { + int ret = -1; + if (MONGO_FAIL_POINT(throwSockExcep)) { #if defined(_WIN32) - WSASetLastError(WSAENETUNREACH); + WSASetLastError(WSAENETUNREACH); #else - errno = ENETUNREACH; + errno = ENETUNREACH; #endif + } else { + ret = ::sendmsg(_fd, &meta, portSendFlags); + } + + if (ret == -1) { + if (errno != EAGAIN || _timeout == 0) { + LOG(_logLevel) << "Socket " << context << " send() " << errnoWithDescription() + << ' ' << remoteString() << endl; + throw SocketException(SocketException::SEND_ERROR, remoteString()); + } else { + LOG(_logLevel) << "Socket " << context << " send() remote timeout " + << remoteString() << endl; + throw SocketException(SocketException::SEND_TIMEOUT, remoteString()); } - else { - ret = ::sendmsg(_fd, &meta, portSendFlags); - } - - if (ret == -1) { - if ( errno != EAGAIN || _timeout == 0 ) { - LOG(_logLevel) << "Socket " << context << - " send() " << errnoWithDescription() << ' ' << remoteString() << endl; - throw SocketException( SocketException::SEND_ERROR , remoteString() ); - } - else { - LOG(_logLevel) << "Socket " << context << - " send() remote timeout " << remoteString() << endl; - throw SocketException( SocketException::SEND_TIMEOUT , remoteString() ); - } - } - else { - struct iovec *& i = meta.msg_iov; - while( ret > 0 ) { - if ( i->iov_len > unsigned( ret ) ) { - i->iov_len -= ret; - i->iov_base = (char*)(i->iov_base) + ret; - ret = 0; - } - else { - ret -= i->iov_len; - ++i; - --(meta.msg_iovlen); - } + } else { + struct iovec*& i = meta.msg_iov; + while (ret > 0) { + if (i->iov_len > unsigned(ret)) { + i->iov_len -= ret; + i->iov_base = (char*)(i->iov_base) + ret; + ret = 0; + } else { + ret -= i->iov_len; + ++i; + --(meta.msg_iovlen); } } } -#endif } +#endif +} - void Socket::recv( char * buf , int len ) { - while( len > 0 ) { - int ret = -1; - if (MONGO_FAIL_POINT(throwSockExcep)) { +void Socket::recv(char* buf, int len) { + while (len > 0) { + int ret = -1; + if (MONGO_FAIL_POINT(throwSockExcep)) { #if defined(_WIN32) - WSASetLastError(WSAENETUNREACH); + WSASetLastError(WSAENETUNREACH); #else - errno = ENETUNREACH; + errno = ENETUNREACH; #endif - if (ret <= 0) { - handleRecvError(ret, len); - continue; - } - } - else { - ret = unsafe_recv(buf, len); + if (ret <= 0) { + handleRecvError(ret, len); + continue; } - - fassert(16508, ret <= len); - len -= ret; - buf += ret; + } else { + ret = unsafe_recv(buf, len); } - } - int Socket::unsafe_recv( char *buf, int max ) { - int x = _recv( buf , max ); - _bytesIn += x; - return x; + fassert(16508, ret <= len); + len -= ret; + buf += ret; } +} + +int Socket::unsafe_recv(char* buf, int max) { + int x = _recv(buf, max); + _bytesIn += x; + return x; +} - // throws if SSL_read fails or recv returns an error - int Socket::_recv( char *buf, int max ) { +// throws if SSL_read fails or recv returns an error +int Socket::_recv(char* buf, int max) { #ifdef MONGO_CONFIG_SSL - if ( _sslConnection.get() ){ - return _sslManager->SSL_read( _sslConnection.get() , buf , max ); - } + if (_sslConnection.get()) { + return _sslManager->SSL_read(_sslConnection.get(), buf, max); + } #endif - int ret = ::recv( _fd , buf , max , portRecvFlags ); - if (ret <= 0) { - handleRecvError(ret, max); // If no throw return and call _recv again - return 0; - } - return ret; + int ret = ::recv(_fd, buf, max, portRecvFlags); + if (ret <= 0) { + handleRecvError(ret, max); // If no throw return and call _recv again + return 0; } + return ret; +} - void Socket::handleSendError(int ret, const char* context) { - +void Socket::handleSendError(int ret, const char* context) { #if defined(_WIN32) - const int mongo_errno = WSAGetLastError(); - if ( mongo_errno == WSAETIMEDOUT && _timeout != 0 ) { + const int mongo_errno = WSAGetLastError(); + if (mongo_errno == WSAETIMEDOUT && _timeout != 0) { #else - const int mongo_errno = errno; - if ( ( mongo_errno == EAGAIN || mongo_errno == EWOULDBLOCK ) && _timeout != 0 ) { + const int mongo_errno = errno; + if ((mongo_errno == EAGAIN || mongo_errno == EWOULDBLOCK) && _timeout != 0) { #endif - LOG(_logLevel) << "Socket " << context << - " send() timed out " << remoteString() << endl; - throw SocketException(SocketException::SEND_TIMEOUT , remoteString()); - } - else { - LOG(_logLevel) << "Socket " << context << " send() " - << errnoWithDescription(mongo_errno) << ' ' << remoteString() << endl; - throw SocketException(SocketException::SEND_ERROR , remoteString()); - } + LOG(_logLevel) << "Socket " << context << " send() timed out " << remoteString() << endl; + throw SocketException(SocketException::SEND_TIMEOUT, remoteString()); + } else { + LOG(_logLevel) << "Socket " << context << " send() " << errnoWithDescription(mongo_errno) + << ' ' << remoteString() << endl; + throw SocketException(SocketException::SEND_ERROR, remoteString()); } +} - void Socket::handleRecvError(int ret, int len) { - if (ret == 0) { - LOG(3) << "Socket recv() conn closed? " << remoteString() << endl; - throw SocketException(SocketException::CLOSED , remoteString()); - } - - // ret < 0 +void Socket::handleRecvError(int ret, int len) { + if (ret == 0) { + LOG(3) << "Socket recv() conn closed? " << remoteString() << endl; + throw SocketException(SocketException::CLOSED, remoteString()); + } + +// ret < 0 #if defined(_WIN32) - int e = WSAGetLastError(); + int e = WSAGetLastError(); #else - int e = errno; -# if defined(EINTR) - if (e == EINTR) { - LOG(_logLevel) << "EINTR returned from recv(), retrying"; - return; - } -# endif + int e = errno; +#if defined(EINTR) + if (e == EINTR) { + LOG(_logLevel) << "EINTR returned from recv(), retrying"; + return; + } +#endif #endif #if defined(_WIN32) - // Windows - if ((e == EAGAIN || e == WSAETIMEDOUT) && _timeout > 0) { + // Windows + if ((e == EAGAIN || e == WSAETIMEDOUT) && _timeout > 0) { #else - if (e == EAGAIN && _timeout > 0) { + if (e == EAGAIN && _timeout > 0) { #endif - // this is a timeout - LOG(_logLevel) << "Socket recv() timeout " << remoteString() <<endl; - throw SocketException(SocketException::RECV_TIMEOUT, remoteString()); - } - - LOG(_logLevel) << "Socket recv() " << - errnoWithDescription(e) << " " << remoteString() <<endl; - throw SocketException(SocketException::RECV_ERROR , remoteString()); - } - - void Socket::setTimeout( double secs ) { - setSockTimeouts( _fd, secs ); - } - - // TODO: allow modification? - // - // <positive value> : secs to wait between stillConnected checks - // 0 : always check - // -1 : never check - const int Socket::errorPollIntervalSecs( 5 ); - - // Patch to allow better tolerance of flaky network connections that get broken - // while we aren't looking. - // TODO: Remove when better async changes come. - // - // isStillConnected() polls the socket at max every Socket::errorPollIntervalSecs to determine - // if any disconnection-type events have happened on the socket. - bool Socket::isStillConnected() { - if (_fd == -1) { - // According to the man page, poll will respond with POLLVNAL for invalid or - // unopened descriptors, but it doesn't seem to be properly implemented in - // some platforms - it can return 0 events and 0 for revent. Hence this workaround. - return false; - } - - if ( errorPollIntervalSecs < 0 ) return true; - if ( ! isPollSupported() ) return true; // nothing we can do + // this is a timeout + LOG(_logLevel) << "Socket recv() timeout " << remoteString() << endl; + throw SocketException(SocketException::RECV_TIMEOUT, remoteString()); + } + + LOG(_logLevel) << "Socket recv() " << errnoWithDescription(e) << " " << remoteString() << endl; + throw SocketException(SocketException::RECV_ERROR, remoteString()); +} + +void Socket::setTimeout(double secs) { + setSockTimeouts(_fd, secs); +} + +// TODO: allow modification? +// +// <positive value> : secs to wait between stillConnected checks +// 0 : always check +// -1 : never check +const int Socket::errorPollIntervalSecs(5); + +// Patch to allow better tolerance of flaky network connections that get broken +// while we aren't looking. +// TODO: Remove when better async changes come. +// +// isStillConnected() polls the socket at max every Socket::errorPollIntervalSecs to determine +// if any disconnection-type events have happened on the socket. +bool Socket::isStillConnected() { + if (_fd == -1) { + // According to the man page, poll will respond with POLLVNAL for invalid or + // unopened descriptors, but it doesn't seem to be properly implemented in + // some platforms - it can return 0 events and 0 for revent. Hence this workaround. + return false; + } - time_t now = time( 0 ); - time_t idleTimeSecs = now - _lastValidityCheckAtSecs; + if (errorPollIntervalSecs < 0) + return true; + if (!isPollSupported()) + return true; // nothing we can do - // Only check once every 5 secs - if ( idleTimeSecs < errorPollIntervalSecs ) return true; - // Reset our timer, we're checking the connection - _lastValidityCheckAtSecs = now; + time_t now = time(0); + time_t idleTimeSecs = now - _lastValidityCheckAtSecs; - // It's been long enough, poll to see if our socket is still connected + // Only check once every 5 secs + if (idleTimeSecs < errorPollIntervalSecs) + return true; + // Reset our timer, we're checking the connection + _lastValidityCheckAtSecs = now; - pollfd pollInfo; - pollInfo.fd = _fd; - // We only care about reading the EOF message on clean close (and errors) - pollInfo.events = POLLIN; + // It's been long enough, poll to see if our socket is still connected - // Poll( info[], size, timeout ) - timeout == 0 => nonblocking - int nEvents = socketPoll( &pollInfo, 1, 0 ); + pollfd pollInfo; + pollInfo.fd = _fd; + // We only care about reading the EOF message on clean close (and errors) + pollInfo.events = POLLIN; - LOG( 2 ) << "polling for status of connection to " << remoteString() - << ", " << ( nEvents == 0 ? "no events" : - nEvents == -1 ? "error detected" : - "event detected" ) << endl; + // Poll( info[], size, timeout ) - timeout == 0 => nonblocking + int nEvents = socketPoll(&pollInfo, 1, 0); - if ( nEvents == 0 ) { - // No events incoming, return still connected AFAWK - return true; - } - else if ( nEvents < 0 ) { - // Poll itself failed, this is weird, warn and log errno - warning() << "Socket poll() failed during connectivity check" - << " (idle " << idleTimeSecs << " secs," - << " remote host " << remoteString() << ")" - << causedBy(errnoWithDescription()) << endl; + LOG(2) << "polling for status of connection to " << remoteString() << ", " + << (nEvents == 0 ? "no events" : nEvents == -1 ? "error detected" : "event detected") + << endl; - // Return true since it's not clear that we're disconnected. - return true; - } - - dassert( nEvents == 1 ); - dassert( pollInfo.revents > 0 ); + if (nEvents == 0) { + // No events incoming, return still connected AFAWK + return true; + } else if (nEvents < 0) { + // Poll itself failed, this is weird, warn and log errno + warning() << "Socket poll() failed during connectivity check" + << " (idle " << idleTimeSecs << " secs," + << " remote host " << remoteString() << ")" << causedBy(errnoWithDescription()) + << endl; + + // Return true since it's not clear that we're disconnected. + return true; + } - // Return false at this point, some event happened on the socket, but log what the - // actual event was. + dassert(nEvents == 1); + dassert(pollInfo.revents > 0); - if ( pollInfo.revents & POLLIN ) { + // Return false at this point, some event happened on the socket, but log what the + // actual event was. - // There shouldn't really be any data to recv here, so make sure this - // is a clean hangup. + if (pollInfo.revents & POLLIN) { + // There shouldn't really be any data to recv here, so make sure this + // is a clean hangup. - const int testBufLength = 1024; - char testBuf[testBufLength]; + const int testBufLength = 1024; + char testBuf[testBufLength]; - int recvd = ::recv( _fd, testBuf, testBufLength, portRecvFlags ); + int recvd = ::recv(_fd, testBuf, testBufLength, portRecvFlags); - if ( recvd < 0 ) { - // An error occurred during recv, warn and log errno - warning() << "Socket recv() failed during connectivity check" - << " (idle " << idleTimeSecs << " secs," - << " remote host " << remoteString() << ")" - << causedBy(errnoWithDescription()) << endl; - } - else if ( recvd > 0 ) { - // We got nonzero data from this socket, very weird? - // Log and warn at runtime, log and abort at devtime - // TODO: Dump the data to the log somehow? - error() << "Socket found pending " << recvd - << " bytes of data during connectivity check" - << " (idle " << idleTimeSecs << " secs," - << " remote host " << remoteString() << ")" << endl; - DEV { - std::string hex = hexdump(testBuf, recvd); - error() << "Hex dump of stale log data: " << hex << endl; - } - dassert( false ); - } - else { - // recvd == 0, socket closed remotely, just return false - LOG( 0 ) << "Socket closed remotely, no longer connected" - << " (idle " << idleTimeSecs << " secs," - << " remote host " << remoteString() << ")" << endl; - } - } - else if ( pollInfo.revents & POLLHUP ) { - // A hangup has occurred on this socket - LOG( 0 ) << "Socket hangup detected, no longer connected" << " (idle " - << idleTimeSecs << " secs," << " remote host " << remoteString() << ")" - << endl; - } - else if ( pollInfo.revents & POLLERR ) { - // An error has occurred on this socket - LOG( 0 ) << "Socket error detected, no longer connected" << " (idle " - << idleTimeSecs << " secs," << " remote host " << remoteString() << ")" - << endl; - } - else if ( pollInfo.revents & POLLNVAL ) { - // Socket descriptor itself is weird - // Log and warn at runtime, log and abort at devtime - error() << "Socket descriptor detected as invalid" - << " (idle " << idleTimeSecs << " secs," - << " remote host " << remoteString() << ")" << endl; - dassert( false ); - } - else { - // Don't know what poll is saying here + if (recvd < 0) { + // An error occurred during recv, warn and log errno + warning() << "Socket recv() failed during connectivity check" + << " (idle " << idleTimeSecs << " secs," + << " remote host " << remoteString() << ")" + << causedBy(errnoWithDescription()) << endl; + } else if (recvd > 0) { + // We got nonzero data from this socket, very weird? // Log and warn at runtime, log and abort at devtime - error() << "Socket had unknown event (" << static_cast<int>(pollInfo.revents) << ")" + // TODO: Dump the data to the log somehow? + error() << "Socket found pending " << recvd + << " bytes of data during connectivity check" << " (idle " << idleTimeSecs << " secs," << " remote host " << remoteString() << ")" << endl; - dassert( false ); - } - - return false; - } + DEV { + std::string hex = hexdump(testBuf, recvd); + error() << "Hex dump of stale log data: " << hex << endl; + } + dassert(false); + } else { + // recvd == 0, socket closed remotely, just return false + LOG(0) << "Socket closed remotely, no longer connected" + << " (idle " << idleTimeSecs << " secs," + << " remote host " << remoteString() << ")" << endl; + } + } else if (pollInfo.revents & POLLHUP) { + // A hangup has occurred on this socket + LOG(0) << "Socket hangup detected, no longer connected" + << " (idle " << idleTimeSecs << " secs," + << " remote host " << remoteString() << ")" << endl; + } else if (pollInfo.revents & POLLERR) { + // An error has occurred on this socket + LOG(0) << "Socket error detected, no longer connected" + << " (idle " << idleTimeSecs << " secs," + << " remote host " << remoteString() << ")" << endl; + } else if (pollInfo.revents & POLLNVAL) { + // Socket descriptor itself is weird + // Log and warn at runtime, log and abort at devtime + error() << "Socket descriptor detected as invalid" + << " (idle " << idleTimeSecs << " secs," + << " remote host " << remoteString() << ")" << endl; + dassert(false); + } else { + // Don't know what poll is saying here + // Log and warn at runtime, log and abort at devtime + error() << "Socket had unknown event (" << static_cast<int>(pollInfo.revents) << ")" + << " (idle " << idleTimeSecs << " secs," + << " remote host " << remoteString() << ")" << endl; + dassert(false); + } + + return false; +} #if defined(_WIN32) - struct WinsockInit { - WinsockInit() { - WSADATA d; - if ( WSAStartup(MAKEWORD(2,2), &d) != 0 ) { - log() << "ERROR: wsastartup failed " << errnoWithDescription() << endl; - quickExit(EXIT_NTSERVICE_ERROR); - } +struct WinsockInit { + WinsockInit() { + WSADATA d; + if (WSAStartup(MAKEWORD(2, 2), &d) != 0) { + log() << "ERROR: wsastartup failed " << errnoWithDescription() << endl; + quickExit(EXIT_NTSERVICE_ERROR); } - } winsock_init; + } +} winsock_init; #endif -} // namespace mongo +} // namespace mongo diff --git a/src/mongo/util/net/sock.h b/src/mongo/util/net/sock.h index a2aec13c388..03d751c70fb 100644 --- a/src/mongo/util/net/sock.h +++ b/src/mongo/util/net/sock.h @@ -39,10 +39,10 @@ #include <errno.h> #ifdef __OpenBSD__ -# include <sys/uio.h> +#include <sys/uio.h> #endif -#endif // not _WIN32 +#endif // not _WIN32 #include <string> #include <utility> @@ -58,266 +58,326 @@ namespace mongo { #ifdef MONGO_CONFIG_SSL - class SSLManagerInterface; - class SSLConnection; +class SSLManagerInterface; +class SSLConnection; #endif - extern const int portSendFlags; - extern const int portRecvFlags; +extern const int portSendFlags; +extern const int portRecvFlags; - const int SOCK_FAMILY_UNKNOWN_ERROR=13078; +const int SOCK_FAMILY_UNKNOWN_ERROR = 13078; - void disableNagle(int sock); +void disableNagle(int sock); #if defined(_WIN32) - typedef short sa_family_t; - typedef int socklen_t; +typedef short sa_family_t; +typedef int socklen_t; - // This won't actually be used on windows - struct sockaddr_un { - short sun_family; - char sun_path[108]; // length from unix header - }; +// This won't actually be used on windows +struct sockaddr_un { + short sun_family; + char sun_path[108]; // length from unix header +}; -#else // _WIN32 +#else // _WIN32 - inline void closesocket(int s) { close(s); } - const int INVALID_SOCKET = -1; - typedef int SOCKET; +inline void closesocket(int s) { + close(s); +} +const int INVALID_SOCKET = -1; +typedef int SOCKET; -#endif // _WIN32 +#endif // _WIN32 - std::string makeUnixSockPath(int port); +std::string makeUnixSockPath(int port); - // If an ip address is passed in, just return that. If a hostname is passed - // in, look up its ip and return that. Returns "" on failure. - std::string hostbyname(const char *hostname); +// If an ip address is passed in, just return that. If a hostname is passed +// in, look up its ip and return that. Returns "" on failure. +std::string hostbyname(const char* hostname); - void enableIPv6(bool state=true); - bool IPv6Enabled(); - void setSockTimeouts(int sock, double secs); +void enableIPv6(bool state = true); +bool IPv6Enabled(); +void setSockTimeouts(int sock, double secs); + +/** + * wrapped around os representation of network address + */ +struct SockAddr { + SockAddr(); + explicit SockAddr(int sourcePort); /* listener side */ + SockAddr( + const char* ip, + int port); /* EndPoint (remote) side, or if you want to specify which interface locally */ + + template <typename T> + T& as() { + return *(T*)(&sa); + } + template <typename T> + const T& as() const { + return *(const T*)(&sa); + } + + std::string toString(bool includePort = true) const; + + bool isValid() const { + return _isValid; + } /** - * wrapped around os representation of network address + * @return one of AF_INET, AF_INET6, or AF_UNIX */ - struct SockAddr { - SockAddr(); - explicit SockAddr(int sourcePort); /* listener side */ - SockAddr(const char *ip, int port); /* EndPoint (remote) side, or if you want to specify which interface locally */ + sa_family_t getType() const; - template <typename T> T& as() { return *(T*)(&sa); } - template <typename T> const T& as() const { return *(const T*)(&sa); } - - std::string toString(bool includePort=true) const; + unsigned getPort() const; - bool isValid() const { return _isValid; } + std::string getAddr() const; - /** - * @return one of AF_INET, AF_INET6, or AF_UNIX - */ - sa_family_t getType() const; + bool isLocalHost() const; - unsigned getPort() const; + bool operator==(const SockAddr& r) const; - std::string getAddr() const; + bool operator!=(const SockAddr& r) const; - bool isLocalHost() const; + bool operator<(const SockAddr& r) const; - bool operator==(const SockAddr& r) const; + const sockaddr* raw() const { + return (sockaddr*)&sa; + } + sockaddr* raw() { + return (sockaddr*)&sa; + } - bool operator!=(const SockAddr& r) const; + socklen_t addressSize; - bool operator<(const SockAddr& r) const; +private: + struct sockaddr_storage sa; + bool _isValid; +}; - const sockaddr* raw() const {return (sockaddr*)&sa;} - sockaddr* raw() {return (sockaddr*)&sa;} +/** this is not cache and does a syscall */ +std::string getHostName(); - socklen_t addressSize; - private: - struct sockaddr_storage sa; - bool _isValid; - }; +/** this is cached, so if changes during the process lifetime + * will be stale */ +std::string getHostNameCached(); - /** this is not cache and does a syscall */ - std::string getHostName(); - - /** this is cached, so if changes during the process lifetime - * will be stale */ - std::string getHostNameCached(); +std::string prettyHostName(); - std::string prettyHostName(); - - /** - * thrown by Socket and SockAddr - */ - class SocketException : public DBException { - public: - const enum Type { CLOSED , RECV_ERROR , SEND_ERROR, RECV_TIMEOUT, SEND_TIMEOUT, FAILED_STATE, CONNECT_ERROR } _type; - - SocketException( Type t , const std::string& server , int code = 9001 , const std::string& extra="" ) - : DBException( std::string("socket exception [") + _getStringType( t ) + "] for " + server, code ), - _type(t), - _server(server), - _extra(extra) - {} - - virtual ~SocketException() throw() {} - - bool shouldPrint() const { return _type != CLOSED; } - virtual std::string toString() const; - virtual const std::string* server() const { return &_server; } - private: - - // TODO: Allow exceptions better control over their messages - static std::string _getStringType( Type t ){ - switch (t) { - case CLOSED: return "CLOSED"; - case RECV_ERROR: return "RECV_ERROR"; - case SEND_ERROR: return "SEND_ERROR"; - case RECV_TIMEOUT: return "RECV_TIMEOUT"; - case SEND_TIMEOUT: return "SEND_TIMEOUT"; - case FAILED_STATE: return "FAILED_STATE"; - case CONNECT_ERROR: return "CONNECT_ERROR"; - default: return "UNKNOWN"; // should never happen - } +/** + * thrown by Socket and SockAddr + */ +class SocketException : public DBException { +public: + const enum Type { + CLOSED, + RECV_ERROR, + SEND_ERROR, + RECV_TIMEOUT, + SEND_TIMEOUT, + FAILED_STATE, + CONNECT_ERROR + } _type; + + SocketException(Type t, + const std::string& server, + int code = 9001, + const std::string& extra = "") + : DBException(std::string("socket exception [") + _getStringType(t) + "] for " + server, + code), + _type(t), + _server(server), + _extra(extra) {} + + virtual ~SocketException() throw() {} + + bool shouldPrint() const { + return _type != CLOSED; + } + virtual std::string toString() const; + virtual const std::string* server() const { + return &_server; + } + +private: + // TODO: Allow exceptions better control over their messages + static std::string _getStringType(Type t) { + switch (t) { + case CLOSED: + return "CLOSED"; + case RECV_ERROR: + return "RECV_ERROR"; + case SEND_ERROR: + return "SEND_ERROR"; + case RECV_TIMEOUT: + return "RECV_TIMEOUT"; + case SEND_TIMEOUT: + return "SEND_TIMEOUT"; + case FAILED_STATE: + return "FAILED_STATE"; + case CONNECT_ERROR: + return "CONNECT_ERROR"; + default: + return "UNKNOWN"; // should never happen } + } - std::string _server; - std::string _extra; - }; - - - /** - * thin wrapped around file descriptor and system calls - * todo: ssl - */ - class Socket { - MONGO_DISALLOW_COPYING(Socket); - public: - - static const int errorPollIntervalSecs; - - Socket(int sock, const SockAddr& farEnd); - - /** In some cases the timeout will actually be 2x this value - eg we do a partial send, - then the timeout fires, then we try to send again, then the timeout fires again with - no data sent, then we detect that the other side is down. - - Generally you don't want a timeout, you should be very prepared for errors if you set one. - */ - Socket(double so_timeout = 0, logger::LogSeverity logLevel = logger::LogSeverity::Log() ); + std::string _server; + std::string _extra; +}; - ~Socket(); - /** The correct way to initialize and connect to a socket is as follows: (1) construct the - * SockAddr, (2) check whether the SockAddr isValid(), (3) if the SockAddr is valid, a - * Socket may then try to connect to that SockAddr. It is critical to check the return - * value of connect as a false return indicates that there was an error, and the Socket - * failed to connect to the given SockAddr. This failure may be due to ConnectBG returning - * an error, or due to a timeout on connection, or due to the system socket deciding the - * socket is invalid. - */ - bool connect(SockAddr& farEnd); - - void close(); - void send( const char * data , int len, const char *context ); - void send( const std::vector< std::pair< char *, int > > &data, const char *context ); - - // recv len or throw SocketException - void recv( char * data , int len ); - int unsafe_recv( char *buf, int max ); - - logger::LogSeverity getLogLevel() const { return _logLevel; } - void setLogLevel( logger::LogSeverity ll ) { _logLevel = ll; } +/** + * thin wrapped around file descriptor and system calls + * todo: ssl + */ +class Socket { + MONGO_DISALLOW_COPYING(Socket); - SockAddr remoteAddr() const { return _remote; } - std::string remoteString() const { return _remote.toString(); } - unsigned remotePort() const { return _remote.getPort(); } +public: + static const int errorPollIntervalSecs; - SockAddr localAddr() const { return _local; } + Socket(int sock, const SockAddr& farEnd); - void clearCounters() { _bytesIn = 0; _bytesOut = 0; } - long long getBytesIn() const { return _bytesIn; } - long long getBytesOut() const { return _bytesOut; } - int rawFD() const { return _fd; } + /** In some cases the timeout will actually be 2x this value - eg we do a partial send, + then the timeout fires, then we try to send again, then the timeout fires again with + no data sent, then we detect that the other side is down. - void setTimeout( double secs ); - bool isStillConnected(); + Generally you don't want a timeout, you should be very prepared for errors if you set one. + */ + Socket(double so_timeout = 0, logger::LogSeverity logLevel = logger::LogSeverity::Log()); - void setHandshakeReceived() { - _awaitingHandshake = false; - } + ~Socket(); - bool isAwaitingHandshake() { - return _awaitingHandshake; - } + /** The correct way to initialize and connect to a socket is as follows: (1) construct the + * SockAddr, (2) check whether the SockAddr isValid(), (3) if the SockAddr is valid, a + * Socket may then try to connect to that SockAddr. It is critical to check the return + * value of connect as a false return indicates that there was an error, and the Socket + * failed to connect to the given SockAddr. This failure may be due to ConnectBG returning + * an error, or due to a timeout on connection, or due to the system socket deciding the + * socket is invalid. + */ + bool connect(SockAddr& farEnd); + + void close(); + void send(const char* data, int len, const char* context); + void send(const std::vector<std::pair<char*, int>>& data, const char* context); + + // recv len or throw SocketException + void recv(char* data, int len); + int unsafe_recv(char* buf, int max); + + logger::LogSeverity getLogLevel() const { + return _logLevel; + } + void setLogLevel(logger::LogSeverity ll) { + _logLevel = ll; + } + + SockAddr remoteAddr() const { + return _remote; + } + std::string remoteString() const { + return _remote.toString(); + } + unsigned remotePort() const { + return _remote.getPort(); + } + + SockAddr localAddr() const { + return _local; + } + + void clearCounters() { + _bytesIn = 0; + _bytesOut = 0; + } + long long getBytesIn() const { + return _bytesIn; + } + long long getBytesOut() const { + return _bytesOut; + } + int rawFD() const { + return _fd; + } + + void setTimeout(double secs); + bool isStillConnected(); + + void setHandshakeReceived() { + _awaitingHandshake = false; + } + + bool isAwaitingHandshake() { + return _awaitingHandshake; + } #ifdef MONGO_CONFIG_SSL - /** secures inline - * ssl - Pointer to the global SSLManager. - * remoteHost - The hostname of the remote server. - */ - bool secure( SSLManagerInterface* ssl, const std::string& remoteHost); + /** secures inline + * ssl - Pointer to the global SSLManager. + * remoteHost - The hostname of the remote server. + */ + bool secure(SSLManagerInterface* ssl, const std::string& remoteHost); - void secureAccepted( SSLManagerInterface* ssl ); + void secureAccepted(SSLManagerInterface* ssl); #endif - - /** - * This function calls SSL_accept() if SSL-encrypted sockets - * are desired. SSL_accept() waits until the remote host calls - * SSL_connect(). The return value is the subject name of any - * client certificate provided during the handshake. - * - * @firstBytes is the first bytes received on the socket used - * to detect the connection SSL, @len is the number of bytes - * - * This function may throw SocketException. - */ - std::string doSSLHandshake(const char* firstBytes = NULL, int len = 0); - - /** - * @return the time when the socket was opened. - */ - uint64_t getSockCreationMicroSec() const { - return _fdCreationMicroSec; - } - void handleRecvError(int ret, int len); - MONGO_COMPILER_NORETURN void handleSendError(int ret, const char* context); + /** + * This function calls SSL_accept() if SSL-encrypted sockets + * are desired. SSL_accept() waits until the remote host calls + * SSL_connect(). The return value is the subject name of any + * client certificate provided during the handshake. + * + * @firstBytes is the first bytes received on the socket used + * to detect the connection SSL, @len is the number of bytes + * + * This function may throw SocketException. + */ + std::string doSSLHandshake(const char* firstBytes = NULL, int len = 0); + + /** + * @return the time when the socket was opened. + */ + uint64_t getSockCreationMicroSec() const { + return _fdCreationMicroSec; + } + + void handleRecvError(int ret, int len); + MONGO_COMPILER_NORETURN void handleSendError(int ret, const char* context); - private: - void _init(); +private: + void _init(); - /** sends dumbly, just each buffer at a time */ - void _send( const std::vector< std::pair< char *, int > > &data, const char *context ); + /** sends dumbly, just each buffer at a time */ + void _send(const std::vector<std::pair<char*, int>>& data, const char* context); - /** raw send, same semantics as ::send with an additional context parameter */ - int _send( const char * data , int len , const char * context ); + /** raw send, same semantics as ::send with an additional context parameter */ + int _send(const char* data, int len, const char* context); - /** raw recv, same semantics as ::recv */ - int _recv( char * buf , int max ); + /** raw recv, same semantics as ::recv */ + int _recv(char* buf, int max); - int _fd; - uint64_t _fdCreationMicroSec; - SockAddr _local; - SockAddr _remote; - double _timeout; + int _fd; + uint64_t _fdCreationMicroSec; + SockAddr _local; + SockAddr _remote; + double _timeout; - long long _bytesIn; - long long _bytesOut; - time_t _lastValidityCheckAtSecs; + long long _bytesIn; + long long _bytesOut; + time_t _lastValidityCheckAtSecs; #ifdef MONGO_CONFIG_SSL - std::unique_ptr<SSLConnection> _sslConnection; - SSLManagerInterface* _sslManager; + std::unique_ptr<SSLConnection> _sslConnection; + SSLManagerInterface* _sslManager; #endif - logger::LogSeverity _logLevel; // passed to log() when logging errors - - /** true until the first packet has been received or an outgoing connect has been made */ - bool _awaitingHandshake; + logger::LogSeverity _logLevel; // passed to log() when logging errors - }; + /** true until the first packet has been received or an outgoing connect has been made */ + bool _awaitingHandshake; +}; -} // namespace mongo +} // namespace mongo diff --git a/src/mongo/util/net/sock_test.cpp b/src/mongo/util/net/sock_test.cpp index 0a823e15f23..26f0c2c821a 100644 --- a/src/mongo/util/net/sock_test.cpp +++ b/src/mongo/util/net/sock_test.cpp @@ -44,305 +44,301 @@ namespace { - using namespace mongo; - using std::shared_ptr; +using namespace mongo; +using std::shared_ptr; - typedef std::shared_ptr<Socket> SocketPtr; - typedef std::pair<SocketPtr, SocketPtr> SocketPair; +typedef std::shared_ptr<Socket> SocketPtr; +typedef std::pair<SocketPtr, SocketPtr> SocketPair; - // On UNIX, make a connected pair of PF_LOCAL (aka PF_UNIX) sockets via the native 'socketpair' - // call. The 'type' parameter should be one of SOCK_STREAM, SOCK_DGRAM, SOCK_SEQPACKET, etc. - // For Win32, we don't have a native socketpair function, so we hack up a connected PF_INET - // pair on a random port. - SocketPair socketPair(const int type, const int protocol = 0); +// On UNIX, make a connected pair of PF_LOCAL (aka PF_UNIX) sockets via the native 'socketpair' +// call. The 'type' parameter should be one of SOCK_STREAM, SOCK_DGRAM, SOCK_SEQPACKET, etc. +// For Win32, we don't have a native socketpair function, so we hack up a connected PF_INET +// pair on a random port. +SocketPair socketPair(const int type, const int protocol = 0); #if defined(_WIN32) - namespace detail { - void awaitAccept(SOCKET* acceptSock, SOCKET listenSock, Notification& notify) { - *acceptSock = INVALID_SOCKET; - const SOCKET result = ::accept(listenSock, NULL, 0); - if (result != INVALID_SOCKET) { - *acceptSock = result; - } - notify.notifyOne(); +namespace detail { +void awaitAccept(SOCKET* acceptSock, SOCKET listenSock, Notification& notify) { + *acceptSock = INVALID_SOCKET; + const SOCKET result = ::accept(listenSock, NULL, 0); + if (result != INVALID_SOCKET) { + *acceptSock = result; + } + notify.notifyOne(); +} + +void awaitConnect(SOCKET* connectSock, const struct addrinfo& where, Notification& notify) { + *connectSock = INVALID_SOCKET; + SOCKET newSock = ::socket(where.ai_family, where.ai_socktype, where.ai_protocol); + if (newSock != INVALID_SOCKET) { + int result = ::connect(newSock, where.ai_addr, where.ai_addrlen); + if (result == 0) { + *connectSock = newSock; } + } + notify.notifyOne(); +} +} // namespace detail - void awaitConnect(SOCKET* connectSock, const struct addrinfo& where, Notification& notify) { - *connectSock = INVALID_SOCKET; - SOCKET newSock = ::socket(where.ai_family, where.ai_socktype, where.ai_protocol); - if (newSock != INVALID_SOCKET) { - int result = ::connect(newSock, where.ai_addr, where.ai_addrlen); - if (result == 0) { - *connectSock = newSock; - } - } - notify.notifyOne(); - } - } // namespace detail +SocketPair socketPair(const int type, const int protocol) { + const int domain = PF_INET; - SocketPair socketPair(const int type, const int protocol) { + // Create a listen socket and a connect socket. + const SOCKET listenSock = ::socket(domain, type, protocol); + if (listenSock == INVALID_SOCKET) + return SocketPair(); - const int domain = PF_INET; + // Bind the listen socket on port zero, it will pick one for us, and start it listening + // for connections. + struct addrinfo hints, *res; + ::memset(&hints, 0, sizeof(hints)); + hints.ai_family = PF_INET; + hints.ai_socktype = type; + hints.ai_flags = AI_PASSIVE; - // Create a listen socket and a connect socket. - const SOCKET listenSock = ::socket(domain, type, protocol); - if (listenSock == INVALID_SOCKET) - return SocketPair(); + int result = ::getaddrinfo(NULL, "0", &hints, &res); + if (result != 0) { + closesocket(listenSock); + return SocketPair(); + } - // Bind the listen socket on port zero, it will pick one for us, and start it listening - // for connections. - struct addrinfo hints, *res; - ::memset(&hints, 0, sizeof(hints)); - hints.ai_family = PF_INET; - hints.ai_socktype = type; - hints.ai_flags = AI_PASSIVE; + result = ::bind(listenSock, res->ai_addr, res->ai_addrlen); + if (result != 0) { + closesocket(listenSock); + ::freeaddrinfo(res); + return SocketPair(); + } - int result = ::getaddrinfo(NULL, "0", &hints, &res); - if (result != 0) { - closesocket(listenSock); - return SocketPair(); - } + // Read out the port to which we bound. + sockaddr_in bindAddr; + ::socklen_t len = sizeof(bindAddr); + ::memset(&bindAddr, 0, sizeof(bindAddr)); + result = ::getsockname(listenSock, reinterpret_cast<struct sockaddr*>(&bindAddr), &len); + if (result != 0) { + closesocket(listenSock); + ::freeaddrinfo(res); + return SocketPair(); + } - result = ::bind(listenSock, res->ai_addr, res->ai_addrlen); - if (result != 0) { - closesocket(listenSock); - ::freeaddrinfo(res); - return SocketPair(); - } + result = ::listen(listenSock, 1); + if (result != 0) { + closesocket(listenSock); + ::freeaddrinfo(res); + return SocketPair(); + } - // Read out the port to which we bound. - sockaddr_in bindAddr; - ::socklen_t len = sizeof(bindAddr); - ::memset(&bindAddr, 0, sizeof(bindAddr)); - result = ::getsockname(listenSock, reinterpret_cast<struct sockaddr*>(&bindAddr), &len); - if (result != 0) { - closesocket(listenSock); - ::freeaddrinfo(res); - return SocketPair(); - } + struct addrinfo connectHints, *connectRes; + ::memset(&connectHints, 0, sizeof(connectHints)); + connectHints.ai_family = PF_INET; + connectHints.ai_socktype = type; + std::stringstream portStream; + portStream << ntohs(bindAddr.sin_port); + result = ::getaddrinfo(NULL, portStream.str().c_str(), &connectHints, &connectRes); + if (result != 0) { + closesocket(listenSock); + ::freeaddrinfo(res); + return SocketPair(); + } - result = ::listen(listenSock, 1); - if (result != 0) { - closesocket(listenSock); - ::freeaddrinfo(res); - return SocketPair(); - } + // I'd prefer to avoid trying to do this non-blocking on Windows. Just spin up some + // threads to do the connect and acccept. - struct addrinfo connectHints, *connectRes; - ::memset(&connectHints, 0, sizeof(connectHints)); - connectHints.ai_family = PF_INET; - connectHints.ai_socktype = type; - std::stringstream portStream; - portStream << ntohs(bindAddr.sin_port); - result = ::getaddrinfo(NULL, portStream.str().c_str(), &connectHints, &connectRes); - if (result != 0) { - closesocket(listenSock); - ::freeaddrinfo(res); - return SocketPair(); - } + Notification accepted; + SOCKET acceptSock = INVALID_SOCKET; + stdx::thread acceptor( + stdx::bind(&detail::awaitAccept, &acceptSock, listenSock, boost::ref(accepted))); - // I'd prefer to avoid trying to do this non-blocking on Windows. Just spin up some - // threads to do the connect and acccept. - - Notification accepted; - SOCKET acceptSock = INVALID_SOCKET; - stdx::thread acceptor( - stdx::bind(&detail::awaitAccept, &acceptSock, listenSock, boost::ref(accepted))); - - Notification connected; - SOCKET connectSock = INVALID_SOCKET; - stdx::thread connector( - stdx::bind(&detail::awaitConnect, &connectSock, *connectRes, boost::ref(connected))); - - connected.waitToBeNotified(); - if (connectSock == INVALID_SOCKET) { - closesocket(listenSock); - ::freeaddrinfo(res); - ::freeaddrinfo(connectRes); - closesocket(acceptSock); - closesocket(connectSock); - return SocketPair(); - } + Notification connected; + SOCKET connectSock = INVALID_SOCKET; + stdx::thread connector( + stdx::bind(&detail::awaitConnect, &connectSock, *connectRes, boost::ref(connected))); - accepted.waitToBeNotified(); - if (acceptSock == INVALID_SOCKET) { - closesocket(listenSock); - ::freeaddrinfo(res); - ::freeaddrinfo(connectRes); - closesocket(acceptSock); - closesocket(connectSock); - return SocketPair(); - } + connected.waitToBeNotified(); + if (connectSock == INVALID_SOCKET) { + closesocket(listenSock); + ::freeaddrinfo(res); + ::freeaddrinfo(connectRes); + closesocket(acceptSock); + closesocket(connectSock); + return SocketPair(); + } + accepted.waitToBeNotified(); + if (acceptSock == INVALID_SOCKET) { closesocket(listenSock); ::freeaddrinfo(res); ::freeaddrinfo(connectRes); + closesocket(acceptSock); + closesocket(connectSock); + return SocketPair(); + } - SocketPtr first(new Socket(static_cast<int>(acceptSock), SockAddr())); - SocketPtr second(new Socket(static_cast<int>(connectSock), SockAddr())); + closesocket(listenSock); + ::freeaddrinfo(res); + ::freeaddrinfo(connectRes); - return SocketPair(first, second); - } + SocketPtr first(new Socket(static_cast<int>(acceptSock), SockAddr())); + SocketPtr second(new Socket(static_cast<int>(connectSock), SockAddr())); + + return SocketPair(first, second); +} #else - // We can just use ::socketpair and wrap up the result in a Socket. - SocketPair socketPair(const int type, const int protocol) { - // PF_LOCAL is the POSIX name for Unix domain sockets, while PF_UNIX - // is the name that BSD used. We use the BSD name because it is more - // widely supported (e.g. Solaris 10). - const int domain = PF_UNIX; - - int socks[2]; - const int result = ::socketpair(domain, type, protocol, socks); - if (result == 0) { - return SocketPair( - SocketPtr(new Socket(socks[0], SockAddr())), - SocketPtr(new Socket(socks[1], SockAddr()))); - } - return SocketPair(); +// We can just use ::socketpair and wrap up the result in a Socket. +SocketPair socketPair(const int type, const int protocol) { + // PF_LOCAL is the POSIX name for Unix domain sockets, while PF_UNIX + // is the name that BSD used. We use the BSD name because it is more + // widely supported (e.g. Solaris 10). + const int domain = PF_UNIX; + + int socks[2]; + const int result = ::socketpair(domain, type, protocol, socks); + if (result == 0) { + return SocketPair(SocketPtr(new Socket(socks[0], SockAddr())), + SocketPtr(new Socket(socks[1], SockAddr()))); } + return SocketPair(); +} #endif - // This should match the name of the fail point declared in sock.cpp. - const char kSocketFailPointName[] = "throwSockExcep"; - - class SocketFailPointTest : public unittest::Test { - public: +// This should match the name of the fail point declared in sock.cpp. +const char kSocketFailPointName[] = "throwSockExcep"; + +class SocketFailPointTest : public unittest::Test { +public: + SocketFailPointTest() + : _failPoint(getGlobalFailPointRegistry()->getFailPoint(kSocketFailPointName)), + _sockets(socketPair(SOCK_STREAM)) { + ASSERT_TRUE(_failPoint != NULL); + ASSERT_TRUE(_sockets.first); + ASSERT_TRUE(_sockets.second); + } - SocketFailPointTest() - : _failPoint(getGlobalFailPointRegistry()->getFailPoint(kSocketFailPointName)) - , _sockets(socketPair(SOCK_STREAM)) { - ASSERT_TRUE(_failPoint != NULL); - ASSERT_TRUE(_sockets.first); - ASSERT_TRUE(_sockets.second); - } + ~SocketFailPointTest() {} - ~SocketFailPointTest() { - } + bool trySend() { + char byte = 'x'; + _sockets.first->send(&byte, sizeof(byte), "SocketFailPointTest::trySend"); + return true; + } - bool trySend() { - char byte = 'x'; - _sockets.first->send(&byte, sizeof(byte), "SocketFailPointTest::trySend"); - return true; - } + bool trySendVector() { + std::vector<std::pair<char*, int>> data; + char byte = 'x'; + data.push_back(std::make_pair(&byte, sizeof(byte))); + _sockets.first->send(data, "SocketFailPointTest::trySendVector"); + return true; + } - bool trySendVector() { - std::vector<std::pair<char*, int> > data; - char byte = 'x'; - data.push_back(std::make_pair(&byte, sizeof(byte))); - _sockets.first->send(data, "SocketFailPointTest::trySendVector"); - return true; - } + bool tryRecv() { + char byte; + _sockets.second->recv(&byte, sizeof(byte)); + return true; + } - bool tryRecv() { - char byte; - _sockets.second->recv(&byte, sizeof(byte)); - return true; - } + // You must queue at least one byte on the send socket before calling this function. + size_t countRecvable(size_t max) { + std::vector<char> buf(max); + // This isn't great, because we don't have a guarantee that multiple sends will be + // captured in one recv. However, sock doesn't let us pass flags into recv, so we + // can't make this non blocking, and therefore can't risk another call. + return _sockets.second->unsafe_recv(&buf[0], max); + } - // You must queue at least one byte on the send socket before calling this function. - size_t countRecvable(size_t max) { - std::vector<char> buf(max); - // This isn't great, because we don't have a guarantee that multiple sends will be - // captured in one recv. However, sock doesn't let us pass flags into recv, so we - // can't make this non blocking, and therefore can't risk another call. - return _sockets.second->unsafe_recv(&buf[0], max); - } + FailPoint* const _failPoint; + const SocketPair _sockets; +}; - FailPoint* const _failPoint; - const SocketPair _sockets; - }; +class ScopedFailPointEnabler { +public: + ScopedFailPointEnabler(FailPoint& fp) : _fp(fp) { + _fp.setMode(FailPoint::alwaysOn); + } - class ScopedFailPointEnabler { - public: - ScopedFailPointEnabler(FailPoint& fp) - : _fp(fp) { - _fp.setMode(FailPoint::alwaysOn); - } + ~ScopedFailPointEnabler() { + _fp.setMode(FailPoint::off); + } - ~ScopedFailPointEnabler() { - _fp.setMode(FailPoint::off); - } - private: - FailPoint& _fp; - }; +private: + FailPoint& _fp; +}; - TEST_F(SocketFailPointTest, TestSend) { - ASSERT_TRUE(trySend()); - ASSERT_TRUE(tryRecv()); - { - const ScopedFailPointEnabler enabled(*_failPoint); - ASSERT_THROWS(trySend(), SocketException); - } - // Channel should be working again - ASSERT_TRUE(trySend()); - ASSERT_TRUE(tryRecv()); +TEST_F(SocketFailPointTest, TestSend) { + ASSERT_TRUE(trySend()); + ASSERT_TRUE(tryRecv()); + { + const ScopedFailPointEnabler enabled(*_failPoint); + ASSERT_THROWS(trySend(), SocketException); } - - TEST_F(SocketFailPointTest, TestSendVector) { - ASSERT_TRUE(trySendVector()); - ASSERT_TRUE(tryRecv()); - { - const ScopedFailPointEnabler enabled(*_failPoint); - ASSERT_THROWS(trySendVector(), SocketException); - } - ASSERT_TRUE(trySendVector()); - ASSERT_TRUE(tryRecv()); + // Channel should be working again + ASSERT_TRUE(trySend()); + ASSERT_TRUE(tryRecv()); +} + +TEST_F(SocketFailPointTest, TestSendVector) { + ASSERT_TRUE(trySendVector()); + ASSERT_TRUE(tryRecv()); + { + const ScopedFailPointEnabler enabled(*_failPoint); + ASSERT_THROWS(trySendVector(), SocketException); } - - TEST_F(SocketFailPointTest, TestRecv) { - ASSERT_TRUE(trySend()); // data for recv - ASSERT_TRUE(tryRecv()); - { - ASSERT_TRUE(trySend()); // data for recv - const ScopedFailPointEnabler enabled(*_failPoint); - ASSERT_THROWS(tryRecv(), SocketException); - } - ASSERT_TRUE(trySend()); // data for recv - ASSERT_TRUE(tryRecv()); + ASSERT_TRUE(trySendVector()); + ASSERT_TRUE(tryRecv()); +} + +TEST_F(SocketFailPointTest, TestRecv) { + ASSERT_TRUE(trySend()); // data for recv + ASSERT_TRUE(tryRecv()); + { + ASSERT_TRUE(trySend()); // data for recv + const ScopedFailPointEnabler enabled(*_failPoint); + ASSERT_THROWS(tryRecv(), SocketException); } - - TEST_F(SocketFailPointTest, TestFailedSendsDontSend) { - ASSERT_TRUE(trySend()); - ASSERT_TRUE(tryRecv()); - { - ASSERT_TRUE(trySend()); // queue 1 byte - const ScopedFailPointEnabler enabled(*_failPoint); - // Fail to queue another byte - ASSERT_THROWS(trySend(), SocketException); - } - // Failed byte should not have been transmitted. - ASSERT_EQUALS(size_t(1), countRecvable(2)); + ASSERT_TRUE(trySend()); // data for recv + ASSERT_TRUE(tryRecv()); +} + +TEST_F(SocketFailPointTest, TestFailedSendsDontSend) { + ASSERT_TRUE(trySend()); + ASSERT_TRUE(tryRecv()); + { + ASSERT_TRUE(trySend()); // queue 1 byte + const ScopedFailPointEnabler enabled(*_failPoint); + // Fail to queue another byte + ASSERT_THROWS(trySend(), SocketException); } - - // Ensure that calling send doesn't actually enqueue data to the socket - TEST_F(SocketFailPointTest, TestFailedVectorSendsDontSend) { - ASSERT_TRUE(trySend()); - ASSERT_TRUE(tryRecv()); - { - ASSERT_TRUE(trySend()); // queue 1 byte - const ScopedFailPointEnabler enabled(*_failPoint); - // Fail to queue another byte - ASSERT_THROWS(trySendVector(), SocketException); - } - // Failed byte should not have been transmitted. - ASSERT_EQUALS(size_t(1), countRecvable(2)); + // Failed byte should not have been transmitted. + ASSERT_EQUALS(size_t(1), countRecvable(2)); +} + +// Ensure that calling send doesn't actually enqueue data to the socket +TEST_F(SocketFailPointTest, TestFailedVectorSendsDontSend) { + ASSERT_TRUE(trySend()); + ASSERT_TRUE(tryRecv()); + { + ASSERT_TRUE(trySend()); // queue 1 byte + const ScopedFailPointEnabler enabled(*_failPoint); + // Fail to queue another byte + ASSERT_THROWS(trySendVector(), SocketException); } - - TEST_F(SocketFailPointTest, TestFailedRecvsDontRecv) { - ASSERT_TRUE(trySend()); - ASSERT_TRUE(tryRecv()); - { - ASSERT_TRUE(trySend()); - const ScopedFailPointEnabler enabled(*_failPoint); - // Fail to recv that byte - ASSERT_THROWS(tryRecv(), SocketException); - } - // Failed byte should still be queued to recv. - ASSERT_EQUALS(size_t(1), countRecvable(1)); - // Channel should be working again + // Failed byte should not have been transmitted. + ASSERT_EQUALS(size_t(1), countRecvable(2)); +} + +TEST_F(SocketFailPointTest, TestFailedRecvsDontRecv) { + ASSERT_TRUE(trySend()); + ASSERT_TRUE(tryRecv()); + { ASSERT_TRUE(trySend()); - ASSERT_TRUE(tryRecv()); + const ScopedFailPointEnabler enabled(*_failPoint); + // Fail to recv that byte + ASSERT_THROWS(tryRecv(), SocketException); } + // Failed byte should still be queued to recv. + ASSERT_EQUALS(size_t(1), countRecvable(1)); + // Channel should be working again + ASSERT_TRUE(trySend()); + ASSERT_TRUE(tryRecv()); +} -} // namespace +} // namespace diff --git a/src/mongo/util/net/socket_poll.cpp b/src/mongo/util/net/socket_poll.cpp index 44ef84ff28a..f52da54355e 100644 --- a/src/mongo/util/net/socket_poll.cpp +++ b/src/mongo/util/net/socket_poll.cpp @@ -36,39 +36,37 @@ namespace mongo { #ifdef _WIN32 - typedef int (WSAAPI *WSAPollFunction)(pollfd* fdarray, ULONG nfds, INT timeout); +typedef int(WSAAPI* WSAPollFunction)(pollfd* fdarray, ULONG nfds, INT timeout); - static WSAPollFunction wsaPollFunction = NULL; +static WSAPollFunction wsaPollFunction = NULL; - MONGO_INITIALIZER(DynamicLinkWin32Poll)(InitializerContext* context) { - HINSTANCE wsaPollLib = LoadLibraryW( L"Ws2_32.dll" ); - if (wsaPollLib) { - wsaPollFunction = - reinterpret_cast<WSAPollFunction>(GetProcAddress(wsaPollLib, "WSAPoll")); - } - - return Status::OK(); +MONGO_INITIALIZER(DynamicLinkWin32Poll)(InitializerContext* context) { + HINSTANCE wsaPollLib = LoadLibraryW(L"Ws2_32.dll"); + if (wsaPollLib) { + wsaPollFunction = reinterpret_cast<WSAPollFunction>(GetProcAddress(wsaPollLib, "WSAPoll")); } - bool isPollSupported() { return wsaPollFunction != NULL; } + return Status::OK(); +} - int socketPoll( pollfd* fdarray, unsigned long nfds, int timeout ) { - fassert(17185, isPollSupported()); - return wsaPollFunction(fdarray, nfds, timeout); - } +bool isPollSupported() { + return wsaPollFunction != NULL; +} + +int socketPoll(pollfd* fdarray, unsigned long nfds, int timeout) { + fassert(17185, isPollSupported()); + return wsaPollFunction(fdarray, nfds, timeout); +} #else - bool isPollSupported() { return true; } +bool isPollSupported() { + return true; +} - int socketPoll( pollfd* fdarray, unsigned long nfds, int timeout ) { - return ::poll(fdarray, nfds, timeout); - } +int socketPoll(pollfd* fdarray, unsigned long nfds, int timeout) { + return ::poll(fdarray, nfds, timeout); +} #endif - } - - - - diff --git a/src/mongo/util/net/socket_poll.h b/src/mongo/util/net/socket_poll.h index b3e8ed2fcd3..379dc8226be 100644 --- a/src/mongo/util/net/socket_poll.h +++ b/src/mongo/util/net/socket_poll.h @@ -29,10 +29,10 @@ #pragma once #ifndef _WIN32 -# include <sys/poll.h> -#endif // ndef _WIN32 +#include <sys/poll.h> +#endif // ndef _WIN32 namespace mongo { - bool isPollSupported(); - int socketPoll(pollfd* fdarray, unsigned long nfds, int timeout); +bool isPollSupported(); +int socketPoll(pollfd* fdarray, unsigned long nfds, int timeout); } diff --git a/src/mongo/util/net/ssl_expiration.cpp b/src/mongo/util/net/ssl_expiration.cpp index d05462ce442..110568e272f 100644 --- a/src/mongo/util/net/ssl_expiration.cpp +++ b/src/mongo/util/net/ssl_expiration.cpp @@ -36,42 +36,39 @@ namespace mongo { - static const auto oneDay = stdx::chrono::hours(24); +static const auto oneDay = stdx::chrono::hours(24); - CertificateExpirationMonitor::CertificateExpirationMonitor(Date_t date) - : _certExpiration(date) - , _lastCheckTime(Date_t::now()) { - } +CertificateExpirationMonitor::CertificateExpirationMonitor(Date_t date) + : _certExpiration(date), _lastCheckTime(Date_t::now()) {} - std::string CertificateExpirationMonitor::taskName() const { - return "CertificateExpirationMonitor"; - } +std::string CertificateExpirationMonitor::taskName() const { + return "CertificateExpirationMonitor"; +} - void CertificateExpirationMonitor::taskDoWork() { - const Milliseconds timeSinceLastCheck = Date_t::now() - _lastCheckTime; +void CertificateExpirationMonitor::taskDoWork() { + const Milliseconds timeSinceLastCheck = Date_t::now() - _lastCheckTime; - if (timeSinceLastCheck < oneDay) - return; + if (timeSinceLastCheck < oneDay) + return; - const Date_t now = Date_t::now(); - _lastCheckTime = now; + const Date_t now = Date_t::now(); + _lastCheckTime = now; - if (_certExpiration <= now) { - // The certificate has expired. - warning() << "Server certificate is now invalid. It expired on " - << dateToISOStringUTC(_certExpiration); - return; - } + if (_certExpiration <= now) { + // The certificate has expired. + warning() << "Server certificate is now invalid. It expired on " + << dateToISOStringUTC(_certExpiration); + return; + } - const auto remainingValidDuration = _certExpiration - now; + const auto remainingValidDuration = _certExpiration - now; - if (remainingValidDuration <= 30 * oneDay) { - // The certificate will expire in the next 30 days. - warning() << "Server certificate will expire on " - << dateToISOStringUTC(_certExpiration) << " in " - << durationCount<stdx::chrono::hours>(remainingValidDuration) / 24 - << " days."; - } + if (remainingValidDuration <= 30 * oneDay) { + // The certificate will expire in the next 30 days. + warning() << "Server certificate will expire on " << dateToISOStringUTC(_certExpiration) + << " in " << durationCount<stdx::chrono::hours>(remainingValidDuration) / 24 + << " days."; } +} } // namespace mongo diff --git a/src/mongo/util/net/ssl_expiration.h b/src/mongo/util/net/ssl_expiration.h index ac6f30cd039..fc56c3968c7 100644 --- a/src/mongo/util/net/ssl_expiration.h +++ b/src/mongo/util/net/ssl_expiration.h @@ -32,27 +32,27 @@ namespace mongo { - class CertificateExpirationMonitor : public PeriodicTask { - public: - explicit CertificateExpirationMonitor(Date_t date); - - /** - * Gets the PeriodicTask's name. - * @return CertificateExpirationMonitor's name. - */ - virtual std::string taskName() const; - - /** - * Wakes up every minute as it is a PeriodicTask. - * Checks once a day if the server certificate has expired - * or will expire in the next 30 days and sends a warning - * to the log accordingly. - */ - virtual void taskDoWork(); - - private: - const Date_t _certExpiration; - Date_t _lastCheckTime; - }; +class CertificateExpirationMonitor : public PeriodicTask { +public: + explicit CertificateExpirationMonitor(Date_t date); + + /** + * Gets the PeriodicTask's name. + * @return CertificateExpirationMonitor's name. + */ + virtual std::string taskName() const; + + /** + * Wakes up every minute as it is a PeriodicTask. + * Checks once a day if the server certificate has expired + * or will expire in the next 30 days and sends a warning + * to the log accordingly. + */ + virtual void taskDoWork(); + +private: + const Date_t _certExpiration; + Date_t _lastCheckTime; +}; } // namespace mongo diff --git a/src/mongo/util/net/ssl_manager.cpp b/src/mongo/util/net/ssl_manager.cpp index be5165b97b8..8ab0ae55516 100644 --- a/src/mongo/util/net/ssl_manager.cpp +++ b/src/mongo/util/net/ssl_manager.cpp @@ -63,7 +63,7 @@ using std::endl; namespace mongo { - SSLParams sslGlobalParams; +SSLParams sslGlobalParams; #ifdef MONGO_CONFIG_SSL // Old copies of OpenSSL will not have constants to disable protocols they don't support. @@ -76,931 +76,902 @@ namespace mongo { #define SSL_OP_NO_TLSv1_2 0 #endif - namespace { +namespace { - /** - * Multithreaded Support for SSL. - * - * In order to allow OpenSSL to work in a multithreaded environment, you - * must provide some callbacks for it to use for locking. The following code - * sets up a vector of mutexes and uses thread-local storage to assign an id - * to each thread. - * The so-called SSLThreadInfo class encapsulates most of the logic required for - * OpenSSL multithreaded support. - */ +/** + * Multithreaded Support for SSL. + * + * In order to allow OpenSSL to work in a multithreaded environment, you + * must provide some callbacks for it to use for locking. The following code + * sets up a vector of mutexes and uses thread-local storage to assign an id + * to each thread. + * The so-called SSLThreadInfo class encapsulates most of the logic required for + * OpenSSL multithreaded support. + */ - unsigned long _ssl_id_callback(); - void _ssl_locking_callback(int mode, int type, const char *file, int line); +unsigned long _ssl_id_callback(); +void _ssl_locking_callback(int mode, int type, const char* file, int line); - class SSLThreadInfo { - public: +class SSLThreadInfo { +public: + SSLThreadInfo() { + _id = _next.fetchAndAdd(1); + } - SSLThreadInfo() { - _id = _next.fetchAndAdd(1); - } + ~SSLThreadInfo() {} - ~SSLThreadInfo() { - } + unsigned long id() const { + return _id; + } - unsigned long id() const { return _id; } + void lock_callback(int mode, int type, const char* file, int line) { + if (mode & CRYPTO_LOCK) { + _mutex[type]->lock(); + } else { + _mutex[type]->unlock(); + } + } - void lock_callback( int mode, int type, const char *file, int line ) { - if ( mode & CRYPTO_LOCK ) { - _mutex[type]->lock(); - } - else { - _mutex[type]->unlock(); - } - } + static void init() { + while ((int)_mutex.size() < CRYPTO_num_locks()) { + _mutex.emplace_back(stdx::make_unique<stdx::recursive_mutex>()); + } + } - static void init() { - while ( (int)_mutex.size() < CRYPTO_num_locks() ) { - _mutex.emplace_back( stdx::make_unique<stdx::recursive_mutex>() ); - } - } + static SSLThreadInfo* get() { + SSLThreadInfo* me = _thread.get(); + if (!me) { + me = new SSLThreadInfo(); + _thread.reset(me); + } + return me; + } - static SSLThreadInfo* get() { - SSLThreadInfo* me = _thread.get(); - if ( ! me ) { - me = new SSLThreadInfo(); - _thread.reset( me ); - } - return me; - } +private: + unsigned _id; - private: - unsigned _id; + static AtomicUInt32 _next; + // Note: see SERVER-8734 for why we are using a recursive mutex here. + // Once the deadlock fix in OpenSSL is incorporated into most distros of + // Linux, this can be changed back to a nonrecursive mutex. + static std::vector<std::unique_ptr<stdx::recursive_mutex>> _mutex; + static boost::thread_specific_ptr<SSLThreadInfo> _thread; +}; - static AtomicUInt32 _next; - // Note: see SERVER-8734 for why we are using a recursive mutex here. - // Once the deadlock fix in OpenSSL is incorporated into most distros of - // Linux, this can be changed back to a nonrecursive mutex. - static std::vector<std::unique_ptr<stdx::recursive_mutex>> _mutex; - static boost::thread_specific_ptr<SSLThreadInfo> _thread; - }; +unsigned long _ssl_id_callback() { + return SSLThreadInfo::get()->id(); +} - unsigned long _ssl_id_callback() { - return SSLThreadInfo::get()->id(); - } +void _ssl_locking_callback(int mode, int type, const char* file, int line) { + SSLThreadInfo::get()->lock_callback(mode, type, file, line); +} - void _ssl_locking_callback(int mode, int type, const char *file, int line) { - SSLThreadInfo::get()->lock_callback( mode , type , file , line ); - } +AtomicUInt32 SSLThreadInfo::_next; +std::vector<std::unique_ptr<stdx::recursive_mutex>> SSLThreadInfo::_mutex; +boost::thread_specific_ptr<SSLThreadInfo> SSLThreadInfo::_thread; - AtomicUInt32 SSLThreadInfo::_next; - std::vector<std::unique_ptr<stdx::recursive_mutex>> SSLThreadInfo::_mutex; - boost::thread_specific_ptr<SSLThreadInfo> SSLThreadInfo::_thread; +//////////////////////////////////////////////////////////////// - //////////////////////////////////////////////////////////////// +SimpleMutex sslManagerMtx; +SSLManagerInterface* theSSLManager = NULL; +static const int BUFFER_SIZE = 8 * 1024; +static const int DATE_LEN = 128; - SimpleMutex sslManagerMtx; - SSLManagerInterface* theSSLManager = NULL; - static const int BUFFER_SIZE = 8*1024; - static const int DATE_LEN = 128; +class SSLManager : public SSLManagerInterface { +public: + explicit SSLManager(const SSLParams& params, bool isServer); - class SSLManager : public SSLManagerInterface { - public: - explicit SSLManager(const SSLParams& params, bool isServer); + virtual ~SSLManager(); - virtual ~SSLManager(); + virtual SSLConnection* connect(Socket* socket); - virtual SSLConnection* connect(Socket* socket); + virtual SSLConnection* accept(Socket* socket, const char* initialBytes, int len); - virtual SSLConnection* accept(Socket* socket, const char* initialBytes, int len); + virtual std::string parseAndValidatePeerCertificate(const SSLConnection* conn, + const std::string& remoteHost); - virtual std::string parseAndValidatePeerCertificate(const SSLConnection* conn, - const std::string& remoteHost); + virtual void cleanupThreadLocals(); - virtual void cleanupThreadLocals(); + virtual const SSLConfiguration& getSSLConfiguration() const { + return _sslConfiguration; + } - virtual const SSLConfiguration& getSSLConfiguration() const { - return _sslConfiguration; - } + virtual int SSL_read(SSLConnection* conn, void* buf, int num); - virtual int SSL_read(SSLConnection* conn, void* buf, int num); - - virtual int SSL_write(SSLConnection* conn, const void* buf, int num); - - virtual unsigned long ERR_get_error(); - - virtual char* ERR_error_string(unsigned long e, char* buf); - - virtual int SSL_get_error(const SSLConnection* conn, int ret); - - virtual int SSL_shutdown(SSLConnection* conn); - - virtual void SSL_free(SSLConnection* conn); - - private: - SSL_CTX* _serverContext; // SSL context for incoming connections - SSL_CTX* _clientContext; // SSL context for outgoing connections - std::string _password; - bool _weakValidation; - bool _allowInvalidCertificates; - bool _allowInvalidHostnames; - SSLConfiguration _sslConfiguration; - - /** - * creates an SSL object to be used for this file descriptor. - * caller must SSL_free it. - */ - SSL* _secure(SSL_CTX* context, int fd); - - /** - * Given an error code from an SSL-type IO function, logs an - * appropriate message and throws a SocketException - */ - MONGO_COMPILER_NORETURN void _handleSSLError(int code, int ret); - - /* - * Init the SSL context using parameters provided in params. - */ - bool _initSSLContext(SSL_CTX** context, const SSLParams& params); - - /* - * Converts time from OpenSSL return value to unsigned long long - * representing the milliseconds since the epoch. - */ - unsigned long long _convertASN1ToMillis(ASN1_TIME* t); - - /* - * Parse and store x509 subject name from the PEM keyfile. - * For server instances check that PEM certificate is not expired - * and extract server certificate notAfter date. - * @param keyFile referencing the PEM file to be read. - * @param subjectName as a pointer to the subject name variable being set. - * @param serverNotAfter a Date_t object pointer that is valued if the - * date is to be checked (as for a server certificate) and null otherwise. - * @return bool showing if the function was successful. - */ - bool _parseAndValidateCertificate(const std::string& keyFile, - std::string* subjectName, - Date_t* serverNotAfter); + virtual int SSL_write(SSLConnection* conn, const void* buf, int num); - /** @return true if was successful, otherwise false */ - bool _setupPEM(SSL_CTX* context, - const std::string& keyFile, - const std::string& password); - - /* - * Set up an SSL context for certificate validation by loading a CA - */ - bool _setupCA(SSL_CTX* context, const std::string& caFile); - - /* - * Import a certificate revocation list into an SSL context - * for use with validating certificates - */ - bool _setupCRL(SSL_CTX* context, const std::string& crlFile); - - /* - * sub function for checking the result of an SSL operation - */ - bool _doneWithSSLOp(SSLConnection* conn, int status); - - /* - * Send and receive network data - */ - void _flushNetworkBIO(SSLConnection* conn); - - /* - * match a remote host name to an x.509 host name - */ - bool _hostNameMatch(const char* nameToMatch, const char* certHostName); - - /** - * Callbacks for SSL functions - */ - static int password_cb( char *buf,int num, int rwflag,void *userdata ); - static int verify_cb(int ok, X509_STORE_CTX *ctx); - - }; - - void setupFIPS() { - // Turn on FIPS mode if requested, OPENSSL_FIPS must be defined by the OpenSSL headers -#if defined(MONGO_CONFIG_HAVE_FIPS_MODE_SET) - int status = FIPS_mode_set(1); - if (!status) { - severe() << "can't activate FIPS mode: " << - SSLManagerInterface::getSSLErrorMessage(ERR_get_error()) << endl; - fassertFailedNoTrace(16703); - } - log() << "FIPS 140-2 mode activated" << endl; -#else - severe() << "this version of mongodb was not compiled with FIPS support"; - fassertFailedNoTrace(17089); -#endif - } - } // namespace + virtual unsigned long ERR_get_error(); - // Global variable indicating if this is a server or a client instance - bool isSSLServer = false; + virtual char* ERR_error_string(unsigned long e, char* buf); + virtual int SSL_get_error(const SSLConnection* conn, int ret); - MONGO_INITIALIZER(SetupOpenSSL) (InitializerContext*) { - SSL_library_init(); - SSL_load_error_strings(); - ERR_load_crypto_strings(); + virtual int SSL_shutdown(SSLConnection* conn); - if (sslGlobalParams.sslFIPSMode) { - setupFIPS(); - } + virtual void SSL_free(SSLConnection* conn); - // Add all digests and ciphers to OpenSSL's internal table - // so that encryption/decryption is backwards compatible - OpenSSL_add_all_algorithms(); +private: + SSL_CTX* _serverContext; // SSL context for incoming connections + SSL_CTX* _clientContext; // SSL context for outgoing connections + std::string _password; + bool _weakValidation; + bool _allowInvalidCertificates; + bool _allowInvalidHostnames; + SSLConfiguration _sslConfiguration; - // Setup OpenSSL multithreading callbacks - CRYPTO_set_id_callback(_ssl_id_callback); - CRYPTO_set_locking_callback(_ssl_locking_callback); + /** + * creates an SSL object to be used for this file descriptor. + * caller must SSL_free it. + */ + SSL* _secure(SSL_CTX* context, int fd); - SSLThreadInfo::init(); - SSLThreadInfo::get(); + /** + * Given an error code from an SSL-type IO function, logs an + * appropriate message and throws a SocketException + */ + MONGO_COMPILER_NORETURN void _handleSSLError(int code, int ret); - return Status::OK(); - } + /* + * Init the SSL context using parameters provided in params. + */ + bool _initSSLContext(SSL_CTX** context, const SSLParams& params); - MONGO_INITIALIZER_WITH_PREREQUISITES(SSLManager, - ("SetupOpenSSL")) - (InitializerContext*) { - stdx::lock_guard<SimpleMutex> lck(sslManagerMtx); - if (sslGlobalParams.sslMode.load() != SSLParams::SSLMode_disabled) { - theSSLManager = new SSLManager(sslGlobalParams, isSSLServer); - } - return Status::OK(); - } + /* + * Converts time from OpenSSL return value to unsigned long long + * representing the milliseconds since the epoch. + */ + unsigned long long _convertASN1ToMillis(ASN1_TIME* t); - std::unique_ptr<SSLManagerInterface> SSLManagerInterface::create(const SSLParams& params, - bool isServer) { - return stdx::make_unique<SSLManager>(params, isServer); - } + /* + * Parse and store x509 subject name from the PEM keyfile. + * For server instances check that PEM certificate is not expired + * and extract server certificate notAfter date. + * @param keyFile referencing the PEM file to be read. + * @param subjectName as a pointer to the subject name variable being set. + * @param serverNotAfter a Date_t object pointer that is valued if the + * date is to be checked (as for a server certificate) and null otherwise. + * @return bool showing if the function was successful. + */ + bool _parseAndValidateCertificate(const std::string& keyFile, + std::string* subjectName, + Date_t* serverNotAfter); - SSLManagerInterface* getSSLManager() { - stdx::lock_guard<SimpleMutex> lck(sslManagerMtx); - if (theSSLManager) - return theSSLManager; - return NULL; - } + /** @return true if was successful, otherwise false */ + bool _setupPEM(SSL_CTX* context, const std::string& keyFile, const std::string& password); - std::string getCertificateSubjectName(X509* cert) { - std::string result; + /* + * Set up an SSL context for certificate validation by loading a CA + */ + bool _setupCA(SSL_CTX* context, const std::string& caFile); - BIO* out = BIO_new(BIO_s_mem()); - uassert(16884, "unable to allocate BIO memory", NULL != out); - ON_BLOCK_EXIT(BIO_free, out); + /* + * Import a certificate revocation list into an SSL context + * for use with validating certificates + */ + bool _setupCRL(SSL_CTX* context, const std::string& crlFile); - if (X509_NAME_print_ex(out, - X509_get_subject_name(cert), - 0, - XN_FLAG_RFC2253) >= 0) { - if (BIO_number_written(out) > 0) { - result.resize(BIO_number_written(out)); - BIO_read(out, &result[0], result.size()); - } - } - else { - log() << "failed to convert subject name to RFC2253 format" << endl; - } + /* + * sub function for checking the result of an SSL operation + */ + bool _doneWithSSLOp(SSLConnection* conn, int status); + + /* + * Send and receive network data + */ + void _flushNetworkBIO(SSLConnection* conn); - return result; + /* + * match a remote host name to an x.509 host name + */ + bool _hostNameMatch(const char* nameToMatch, const char* certHostName); + + /** + * Callbacks for SSL functions + */ + static int password_cb(char* buf, int num, int rwflag, void* userdata); + static int verify_cb(int ok, X509_STORE_CTX* ctx); +}; + +void setupFIPS() { +// Turn on FIPS mode if requested, OPENSSL_FIPS must be defined by the OpenSSL headers +#if defined(MONGO_CONFIG_HAVE_FIPS_MODE_SET) + int status = FIPS_mode_set(1); + if (!status) { + severe() << "can't activate FIPS mode: " + << SSLManagerInterface::getSSLErrorMessage(ERR_get_error()) << endl; + fassertFailedNoTrace(16703); } + log() << "FIPS 140-2 mode activated" << endl; +#else + severe() << "this version of mongodb was not compiled with FIPS support"; + fassertFailedNoTrace(17089); +#endif +} +} // namespace - SSLConnection::SSLConnection(SSL_CTX* context, - Socket* sock, - const char* initialBytes, - int len) : socket(sock) { - // This just ensures that SSL multithreading support is set up for this thread, - // if it's not already. - SSLThreadInfo::get(); - - ssl = SSL_new(context); - - std::string sslErr = NULL != getSSLManager() ? - getSSLManager()->getSSLErrorMessage(ERR_get_error()) : ""; - massert(15861, "Error creating new SSL object " + sslErr, ssl); - - BIO_new_bio_pair(&internalBIO, BUFFER_SIZE, &networkBIO, BUFFER_SIZE); - SSL_set_bio(ssl, internalBIO, internalBIO); - - if (len > 0) { - int toBIO = BIO_write(networkBIO, initialBytes, len); - if (toBIO != len) { - LOG(3) << "Failed to write initial network data to the SSL BIO layer"; - throw SocketException(SocketException::RECV_ERROR , socket->remoteString()); - } - } - } +// Global variable indicating if this is a server or a client instance +bool isSSLServer = false; - SSLConnection::~SSLConnection() { - if (ssl) { // The internalBIO is automatically freed as part of SSL_free - SSL_free(ssl); - } - if (networkBIO) { - BIO_free(networkBIO); - } - } - BSONObj SSLConfiguration::getServerStatusBSON() const { - BSONObjBuilder security; - security.append("SSLServerSubjectName", - serverSubjectName); - security.appendBool("SSLServerHasCertificateAuthority", - hasCA); - security.appendDate("SSLServerCertificateExpirationDate", - serverCertificateExpirationDate); - return security.obj(); +MONGO_INITIALIZER(SetupOpenSSL)(InitializerContext*) { + SSL_library_init(); + SSL_load_error_strings(); + ERR_load_crypto_strings(); + + if (sslGlobalParams.sslFIPSMode) { + setupFIPS(); } - SSLManagerInterface::~SSLManagerInterface() {} + // Add all digests and ciphers to OpenSSL's internal table + // so that encryption/decryption is backwards compatible + OpenSSL_add_all_algorithms(); - SSLManager::SSLManager(const SSLParams& params, bool isServer) : - _serverContext(NULL), - _clientContext(NULL), - _weakValidation(params.sslWeakCertificateValidation), - _allowInvalidCertificates(params.sslAllowInvalidCertificates), - _allowInvalidHostnames(params.sslAllowInvalidHostnames) { + // Setup OpenSSL multithreading callbacks + CRYPTO_set_id_callback(_ssl_id_callback); + CRYPTO_set_locking_callback(_ssl_locking_callback); - if (!_initSSLContext(&_clientContext, params)) { - uasserted(16768, "ssl initialization problem"); - } + SSLThreadInfo::init(); + SSLThreadInfo::get(); - // pick the certificate for use in outgoing connections, - std::string clientPEM; - if (!isServer || params.sslClusterFile.empty()) { - // We are either a client, or a server without a cluster key, - // so use the PEM key file, if specified - clientPEM = params.sslPEMKeyFile; - } - else { - // We are a server with a cluster key, so use the cluster key file - clientPEM = params.sslClusterFile; - } + return Status::OK(); +} - if (!clientPEM.empty()) { - if (!_parseAndValidateCertificate(clientPEM, - &_sslConfiguration.clientSubjectName, NULL)) { - uasserted(16941, "ssl initialization problem"); - } - } - // SSL server specific initialization - if (isServer) { - if (!_initSSLContext(&_serverContext, params)) { - uasserted(16562, "ssl initialization problem"); - } +MONGO_INITIALIZER_WITH_PREREQUISITES(SSLManager, ("SetupOpenSSL")) +(InitializerContext*) { + stdx::lock_guard<SimpleMutex> lck(sslManagerMtx); + if (sslGlobalParams.sslMode.load() != SSLParams::SSLMode_disabled) { + theSSLManager = new SSLManager(sslGlobalParams, isSSLServer); + } + return Status::OK(); +} - if (!_parseAndValidateCertificate(params.sslPEMKeyFile, - &_sslConfiguration.serverSubjectName, - &_sslConfiguration.serverCertificateExpirationDate)) { - uasserted(16942, "ssl initialization problem"); - } +std::unique_ptr<SSLManagerInterface> SSLManagerInterface::create(const SSLParams& params, + bool isServer) { + return stdx::make_unique<SSLManager>(params, isServer); +} + +SSLManagerInterface* getSSLManager() { + stdx::lock_guard<SimpleMutex> lck(sslManagerMtx); + if (theSSLManager) + return theSSLManager; + return NULL; +} + +std::string getCertificateSubjectName(X509* cert) { + std::string result; - static CertificateExpirationMonitor task = - CertificateExpirationMonitor(_sslConfiguration.serverCertificateExpirationDate); + BIO* out = BIO_new(BIO_s_mem()); + uassert(16884, "unable to allocate BIO memory", NULL != out); + ON_BLOCK_EXIT(BIO_free, out); + + if (X509_NAME_print_ex(out, X509_get_subject_name(cert), 0, XN_FLAG_RFC2253) >= 0) { + if (BIO_number_written(out) > 0) { + result.resize(BIO_number_written(out)); + BIO_read(out, &result[0], result.size()); } + } else { + log() << "failed to convert subject name to RFC2253 format" << endl; } - SSLManager::~SSLManager() { - if (NULL != _serverContext) { - SSL_CTX_free(_serverContext); - } - if (NULL != _clientContext) { - SSL_CTX_free(_clientContext); + return result; +} + +SSLConnection::SSLConnection(SSL_CTX* context, Socket* sock, const char* initialBytes, int len) + : socket(sock) { + // This just ensures that SSL multithreading support is set up for this thread, + // if it's not already. + SSLThreadInfo::get(); + + ssl = SSL_new(context); + + std::string sslErr = + NULL != getSSLManager() ? getSSLManager()->getSSLErrorMessage(ERR_get_error()) : ""; + massert(15861, "Error creating new SSL object " + sslErr, ssl); + + BIO_new_bio_pair(&internalBIO, BUFFER_SIZE, &networkBIO, BUFFER_SIZE); + SSL_set_bio(ssl, internalBIO, internalBIO); + + if (len > 0) { + int toBIO = BIO_write(networkBIO, initialBytes, len); + if (toBIO != len) { + LOG(3) << "Failed to write initial network data to the SSL BIO layer"; + throw SocketException(SocketException::RECV_ERROR, socket->remoteString()); } } +} - int SSLManager::password_cb(char *buf,int num, int rwflag,void *userdata) { - // Unless OpenSSL misbehaves, num should always be positive - fassert(17314, num > 0); - SSLManager* sm = static_cast<SSLManager*>(userdata); - const size_t copied = sm->_password.copy(buf, num - 1); - buf[copied] = '\0'; - return copied; +SSLConnection::~SSLConnection() { + if (ssl) { // The internalBIO is automatically freed as part of SSL_free + SSL_free(ssl); } - - int SSLManager::verify_cb(int ok, X509_STORE_CTX *ctx) { - return 1; // always succeed; we will catch the error in our get_verify_result() call + if (networkBIO) { + BIO_free(networkBIO); } +} - int SSLManager::SSL_read(SSLConnection* conn, void* buf, int num) { - int status; - do { - status = ::SSL_read(conn->ssl, buf, num); - } while(!_doneWithSSLOp(conn, status)); - - if (status <= 0) - _handleSSLError(SSL_get_error(conn, status), status); - return status; - } +BSONObj SSLConfiguration::getServerStatusBSON() const { + BSONObjBuilder security; + security.append("SSLServerSubjectName", serverSubjectName); + security.appendBool("SSLServerHasCertificateAuthority", hasCA); + security.appendDate("SSLServerCertificateExpirationDate", serverCertificateExpirationDate); + return security.obj(); +} - int SSLManager::SSL_write(SSLConnection* conn, const void* buf, int num) { - int status; - do { - status = ::SSL_write(conn->ssl, buf, num); - } while(!_doneWithSSLOp(conn, status)); - - if (status <= 0) - _handleSSLError(SSL_get_error(conn, status), status); - return status; - } +SSLManagerInterface::~SSLManagerInterface() {} - unsigned long SSLManager::ERR_get_error() { - return ::ERR_get_error(); +SSLManager::SSLManager(const SSLParams& params, bool isServer) + : _serverContext(NULL), + _clientContext(NULL), + _weakValidation(params.sslWeakCertificateValidation), + _allowInvalidCertificates(params.sslAllowInvalidCertificates), + _allowInvalidHostnames(params.sslAllowInvalidHostnames) { + if (!_initSSLContext(&_clientContext, params)) { + uasserted(16768, "ssl initialization problem"); } - char* SSLManager::ERR_error_string(unsigned long e, char* buf) { - return ::ERR_error_string(e, buf); + // pick the certificate for use in outgoing connections, + std::string clientPEM; + if (!isServer || params.sslClusterFile.empty()) { + // We are either a client, or a server without a cluster key, + // so use the PEM key file, if specified + clientPEM = params.sslPEMKeyFile; + } else { + // We are a server with a cluster key, so use the cluster key file + clientPEM = params.sslClusterFile; } - int SSLManager::SSL_get_error(const SSLConnection* conn, int ret) { - return ::SSL_get_error(conn->ssl, ret); + if (!clientPEM.empty()) { + if (!_parseAndValidateCertificate(clientPEM, &_sslConfiguration.clientSubjectName, NULL)) { + uasserted(16941, "ssl initialization problem"); + } } + // SSL server specific initialization + if (isServer) { + if (!_initSSLContext(&_serverContext, params)) { + uasserted(16562, "ssl initialization problem"); + } - int SSLManager::SSL_shutdown(SSLConnection* conn) { - int status; - do { - status = ::SSL_shutdown(conn->ssl); - } while(!_doneWithSSLOp(conn, status)); - - if (status < 0) - _handleSSLError(SSL_get_error(conn, status), status); - return status; + if (!_parseAndValidateCertificate(params.sslPEMKeyFile, + &_sslConfiguration.serverSubjectName, + &_sslConfiguration.serverCertificateExpirationDate)) { + uasserted(16942, "ssl initialization problem"); + } + + static CertificateExpirationMonitor task = + CertificateExpirationMonitor(_sslConfiguration.serverCertificateExpirationDate); } +} - void SSLManager::SSL_free(SSLConnection* conn) { - return ::SSL_free(conn->ssl); +SSLManager::~SSLManager() { + if (NULL != _serverContext) { + SSL_CTX_free(_serverContext); + } + if (NULL != _clientContext) { + SSL_CTX_free(_clientContext); } +} - bool SSLManager::_initSSLContext(SSL_CTX** context, const SSLParams& params) { - *context = SSL_CTX_new(SSLv23_method()); - massert(15864, - mongoutils::str::stream() << "can't create SSL Context: " << - getSSLErrorMessage(ERR_get_error()), - context); - - // SSL_OP_ALL - Activate all bug workaround options, to support buggy client SSL's. - // SSL_OP_NO_SSLv2 - Disable SSL v2 support - // SSL_OP_NO_SSLv3 - Disable SSL v3 support - long supportedProtocols = SSL_OP_ALL|SSL_OP_NO_SSLv2|SSL_OP_NO_SSLv3; - - // Set the supported TLS protocols. Allow --sslDisabledProtocols to disable selected ciphers. - if (!params.sslDisabledProtocols.empty()) { - for (const SSLParams::Protocols& protocol : params.sslDisabledProtocols) { - if (protocol == SSLParams::Protocols::TLS1_0) { - supportedProtocols |= SSL_OP_NO_TLSv1; - } else if (protocol == SSLParams::Protocols::TLS1_1) { - supportedProtocols |= SSL_OP_NO_TLSv1_1; - } else if (protocol == SSLParams::Protocols::TLS1_2) { - supportedProtocols |= SSL_OP_NO_TLSv1_2; - } - } - } - SSL_CTX_set_options(*context, supportedProtocols); +int SSLManager::password_cb(char* buf, int num, int rwflag, void* userdata) { + // Unless OpenSSL misbehaves, num should always be positive + fassert(17314, num > 0); + SSLManager* sm = static_cast<SSLManager*>(userdata); + const size_t copied = sm->_password.copy(buf, num - 1); + buf[copied] = '\0'; + return copied; +} - // HIGH - Enable strong ciphers - // !EXPORT - Disable export ciphers (40/56 bit) - // !aNULL - Disable anonymous auth ciphers - // @STRENGTH - Sort ciphers based on strength - std::string cipherConfig = "HIGH:!EXPORT:!aNULL@STRENGTH"; +int SSLManager::verify_cb(int ok, X509_STORE_CTX* ctx) { + return 1; // always succeed; we will catch the error in our get_verify_result() call +} - // Allow the cipher configuration string to be overriden by --sslCipherConfig - if (!params.sslCipherConfig.empty()) { - cipherConfig = params.sslCipherConfig; - } +int SSLManager::SSL_read(SSLConnection* conn, void* buf, int num) { + int status; + do { + status = ::SSL_read(conn->ssl, buf, num); + } while (!_doneWithSSLOp(conn, status)); - massert(28615, mongoutils::str::stream() << "can't set supported cipher suites: " << - getSSLErrorMessage(ERR_get_error()), - SSL_CTX_set_cipher_list(*context, cipherConfig.c_str())); - - // If renegotiation is needed, don't return from recv() or send() until it's successful. - // Note: this is for blocking sockets only. - SSL_CTX_set_mode(*context, SSL_MODE_AUTO_RETRY); - - massert(28607, - mongoutils::str::stream() << "can't store ssl session id context: " << - getSSLErrorMessage(ERR_get_error()), - SSL_CTX_set_session_id_context( - *context, - static_cast<unsigned char*>(static_cast<void*>(context)), - sizeof(*context))); - - // Use the clusterfile for internal outgoing SSL connections if specified - if (context == &_clientContext && !params.sslClusterFile.empty()) { - EVP_set_pw_prompt("Enter cluster certificate passphrase"); - if (!_setupPEM(*context, params.sslClusterFile, params.sslClusterPassword)) { - return false; - } + if (status <= 0) + _handleSSLError(SSL_get_error(conn, status), status); + return status; +} + +int SSLManager::SSL_write(SSLConnection* conn, const void* buf, int num) { + int status; + do { + status = ::SSL_write(conn->ssl, buf, num); + } while (!_doneWithSSLOp(conn, status)); + + if (status <= 0) + _handleSSLError(SSL_get_error(conn, status), status); + return status; +} + +unsigned long SSLManager::ERR_get_error() { + return ::ERR_get_error(); +} + +char* SSLManager::ERR_error_string(unsigned long e, char* buf) { + return ::ERR_error_string(e, buf); +} + +int SSLManager::SSL_get_error(const SSLConnection* conn, int ret) { + return ::SSL_get_error(conn->ssl, ret); +} + +int SSLManager::SSL_shutdown(SSLConnection* conn) { + int status; + do { + status = ::SSL_shutdown(conn->ssl); + } while (!_doneWithSSLOp(conn, status)); + + if (status < 0) + _handleSSLError(SSL_get_error(conn, status), status); + return status; +} + +void SSLManager::SSL_free(SSLConnection* conn) { + return ::SSL_free(conn->ssl); +} + +bool SSLManager::_initSSLContext(SSL_CTX** context, const SSLParams& params) { + *context = SSL_CTX_new(SSLv23_method()); + massert(15864, + mongoutils::str::stream() + << "can't create SSL Context: " << getSSLErrorMessage(ERR_get_error()), + context); + + // SSL_OP_ALL - Activate all bug workaround options, to support buggy client SSL's. + // SSL_OP_NO_SSLv2 - Disable SSL v2 support + // SSL_OP_NO_SSLv3 - Disable SSL v3 support + long supportedProtocols = SSL_OP_ALL | SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3; + + // Set the supported TLS protocols. Allow --sslDisabledProtocols to disable selected ciphers. + if (!params.sslDisabledProtocols.empty()) { + for (const SSLParams::Protocols& protocol : params.sslDisabledProtocols) { + if (protocol == SSLParams::Protocols::TLS1_0) { + supportedProtocols |= SSL_OP_NO_TLSv1; + } else if (protocol == SSLParams::Protocols::TLS1_1) { + supportedProtocols |= SSL_OP_NO_TLSv1_1; + } else if (protocol == SSLParams::Protocols::TLS1_2) { + supportedProtocols |= SSL_OP_NO_TLSv1_2; + } + } + } + SSL_CTX_set_options(*context, supportedProtocols); + + // HIGH - Enable strong ciphers + // !EXPORT - Disable export ciphers (40/56 bit) + // !aNULL - Disable anonymous auth ciphers + // @STRENGTH - Sort ciphers based on strength + std::string cipherConfig = "HIGH:!EXPORT:!aNULL@STRENGTH"; + + // Allow the cipher configuration string to be overriden by --sslCipherConfig + if (!params.sslCipherConfig.empty()) { + cipherConfig = params.sslCipherConfig; + } + + massert(28615, + mongoutils::str::stream() + << "can't set supported cipher suites: " << getSSLErrorMessage(ERR_get_error()), + SSL_CTX_set_cipher_list(*context, cipherConfig.c_str())); + + // If renegotiation is needed, don't return from recv() or send() until it's successful. + // Note: this is for blocking sockets only. + SSL_CTX_set_mode(*context, SSL_MODE_AUTO_RETRY); + + massert(28607, + mongoutils::str::stream() + << "can't store ssl session id context: " << getSSLErrorMessage(ERR_get_error()), + SSL_CTX_set_session_id_context(*context, + static_cast<unsigned char*>(static_cast<void*>(context)), + sizeof(*context))); + + // Use the clusterfile for internal outgoing SSL connections if specified + if (context == &_clientContext && !params.sslClusterFile.empty()) { + EVP_set_pw_prompt("Enter cluster certificate passphrase"); + if (!_setupPEM(*context, params.sslClusterFile, params.sslClusterPassword)) { + return false; } - // Use the pemfile for everything else - else if (!params.sslPEMKeyFile.empty()) { - EVP_set_pw_prompt("Enter PEM passphrase"); - if (!_setupPEM(*context, params.sslPEMKeyFile, params.sslPEMKeyPassword)) { - return false; - } + } + // Use the pemfile for everything else + else if (!params.sslPEMKeyFile.empty()) { + EVP_set_pw_prompt("Enter PEM passphrase"); + if (!_setupPEM(*context, params.sslPEMKeyFile, params.sslPEMKeyPassword)) { + return false; } + } - if (!params.sslCAFile.empty()) { - // Set up certificate validation with a certificate authority - if (!_setupCA(*context, params.sslCAFile)) { - return false; - } + if (!params.sslCAFile.empty()) { + // Set up certificate validation with a certificate authority + if (!_setupCA(*context, params.sslCAFile)) { + return false; } + } - if (!params.sslCRLFile.empty()) { - if (!_setupCRL(*context, params.sslCRLFile)) { - return false; - } + if (!params.sslCRLFile.empty()) { + if (!_setupCRL(*context, params.sslCRLFile)) { + return false; } + } + + return true; +} + +unsigned long long SSLManager::_convertASN1ToMillis(ASN1_TIME* asn1time) { + BIO* outBIO = BIO_new(BIO_s_mem()); + int timeError = ASN1_TIME_print(outBIO, asn1time); + ON_BLOCK_EXIT(BIO_free, outBIO); - return true; + if (timeError <= 0) { + error() << "ASN1_TIME_print failed or wrote no data."; + return 0; } - unsigned long long SSLManager::_convertASN1ToMillis(ASN1_TIME* asn1time) { - BIO *outBIO = BIO_new(BIO_s_mem()); - int timeError = ASN1_TIME_print(outBIO, asn1time); - ON_BLOCK_EXIT(BIO_free, outBIO); + char dateChar[DATE_LEN]; + timeError = BIO_gets(outBIO, dateChar, DATE_LEN); + if (timeError <= 0) { + error() << "BIO_gets call failed to transfer contents to buf"; + return 0; + } - if (timeError <= 0) { - error() << "ASN1_TIME_print failed or wrote no data."; - return 0; - } + // Ensure that day format is two digits for parsing. + // Jun 8 17:00:03 2014 becomes Jun 08 17:00:03 2014. + if (dateChar[4] == ' ') { + dateChar[4] = '0'; + } - char dateChar[DATE_LEN]; - timeError = BIO_gets(outBIO, dateChar, DATE_LEN); - if (timeError <= 0) { - error() << "BIO_gets call failed to transfer contents to buf"; - return 0; - } + std::istringstream inStringStream((std::string(dateChar, 20))); + boost::posix_time::time_input_facet* inputFacet = + new boost::posix_time::time_input_facet("%b %d %H:%M:%S %Y"); - //Ensure that day format is two digits for parsing. - //Jun 8 17:00:03 2014 becomes Jun 08 17:00:03 2014. - if (dateChar[4] == ' ') { - dateChar[4] = '0'; - } + inStringStream.imbue(std::locale(std::cout.getloc(), inputFacet)); + boost::posix_time::ptime posixTime; + inStringStream >> posixTime; - std::istringstream inStringStream((std::string(dateChar,20))); - boost::posix_time::time_input_facet *inputFacet = - new boost::posix_time::time_input_facet("%b %d %H:%M:%S %Y"); + const boost::gregorian::date epoch = boost::gregorian::date(1970, boost::gregorian::Jan, 1); - inStringStream.imbue(std::locale(std::cout.getloc(), inputFacet)); - boost::posix_time::ptime posixTime; - inStringStream >> posixTime; + return (posixTime - boost::posix_time::ptime(epoch)).total_milliseconds(); +} - const boost::gregorian::date epoch = - boost::gregorian::date(1970, boost::gregorian::Jan, 1); +bool SSLManager::_parseAndValidateCertificate(const std::string& keyFile, + std::string* subjectName, + Date_t* serverCertificateExpirationDate) { + BIO* inBIO = BIO_new(BIO_s_file_internal()); + if (inBIO == NULL) { + error() << "failed to allocate BIO object: " << getSSLErrorMessage(ERR_get_error()); + return false; + } + + ON_BLOCK_EXIT(BIO_free, inBIO); + if (BIO_read_filename(inBIO, keyFile.c_str()) <= 0) { + error() << "cannot read key file when setting subject name: " << keyFile << ' ' + << getSSLErrorMessage(ERR_get_error()); + return false; + } - return (posixTime - boost::posix_time::ptime(epoch)).total_milliseconds(); + X509* x509 = PEM_read_bio_X509(inBIO, NULL, &SSLManager::password_cb, this); + if (x509 == NULL) { + error() << "cannot retrieve certificate from keyfile: " << keyFile << ' ' + << getSSLErrorMessage(ERR_get_error()); + return false; } + ON_BLOCK_EXIT(X509_free, x509); - bool SSLManager::_parseAndValidateCertificate(const std::string& keyFile, - std::string* subjectName, - Date_t* serverCertificateExpirationDate) { - BIO *inBIO = BIO_new(BIO_s_file_internal()); - if (inBIO == NULL) { - error() << "failed to allocate BIO object: " - << getSSLErrorMessage(ERR_get_error()); + *subjectName = getCertificateSubjectName(x509); + if (serverCertificateExpirationDate != NULL) { + unsigned long long notBeforeMillis = _convertASN1ToMillis(X509_get_notBefore(x509)); + if (notBeforeMillis == 0) { + error() << "date conversion failed"; return false; } - ON_BLOCK_EXIT(BIO_free, inBIO); - if (BIO_read_filename(inBIO, keyFile.c_str()) <= 0) { - error() << "cannot read key file when setting subject name: " - << keyFile << ' ' << getSSLErrorMessage(ERR_get_error()); + unsigned long long notAfterMillis = _convertASN1ToMillis(X509_get_notAfter(x509)); + if (notAfterMillis == 0) { + error() << "date conversion failed"; return false; } - X509* x509 = PEM_read_bio_X509(inBIO, NULL, &SSLManager::password_cb, this); - if (x509 == NULL) { - error() << "cannot retrieve certificate from keyfile: " - << keyFile << ' ' << getSSLErrorMessage(ERR_get_error()); - return false; + if ((notBeforeMillis > curTimeMillis64()) || (curTimeMillis64() > notAfterMillis)) { + severe() << "The provided SSL certificate is expired or not yet valid."; + fassertFailedNoTrace(28652); } - ON_BLOCK_EXIT(X509_free, x509); - *subjectName = getCertificateSubjectName(x509); - if (serverCertificateExpirationDate != NULL) { + *serverCertificateExpirationDate = Date_t::fromMillisSinceEpoch(notAfterMillis); + } - unsigned long long notBeforeMillis = _convertASN1ToMillis(X509_get_notBefore(x509)); - if (notBeforeMillis == 0) { - error() << "date conversion failed"; - return false; - } + return true; +} - unsigned long long notAfterMillis = _convertASN1ToMillis(X509_get_notAfter(x509)); - if (notAfterMillis == 0) { - error() << "date conversion failed"; - return false; - } +bool SSLManager::_setupPEM(SSL_CTX* context, + const std::string& keyFile, + const std::string& password) { + _password = password; - if ((notBeforeMillis > curTimeMillis64()) || (curTimeMillis64() > notAfterMillis)) { - severe() << "The provided SSL certificate is expired or not yet valid."; - fassertFailedNoTrace(28652); - } + if (SSL_CTX_use_certificate_chain_file(context, keyFile.c_str()) != 1) { + error() << "cannot read certificate file: " << keyFile << ' ' + << getSSLErrorMessage(ERR_get_error()) << endl; + return false; + } - *serverCertificateExpirationDate = Date_t::fromMillisSinceEpoch(notAfterMillis); - } + // If password is empty, use default OpenSSL callback, which uses the terminal + // to securely request the password interactively from the user. + if (!password.empty()) { + SSL_CTX_set_default_passwd_cb_userdata(context, this); + SSL_CTX_set_default_passwd_cb(context, &SSLManager::password_cb); + } - return true; + if (SSL_CTX_use_PrivateKey_file(context, keyFile.c_str(), SSL_FILETYPE_PEM) != 1) { + error() << "cannot read PEM key file: " << keyFile << ' ' + << getSSLErrorMessage(ERR_get_error()) << endl; + return false; } - bool SSLManager::_setupPEM(SSL_CTX* context, - const std::string& keyFile, - const std::string& password) { - _password = password; + // Verify that the certificate and the key go together. + if (SSL_CTX_check_private_key(context) != 1) { + error() << "SSL certificate validation: " << getSSLErrorMessage(ERR_get_error()) << endl; + return false; + } - if ( SSL_CTX_use_certificate_chain_file( context , keyFile.c_str() ) != 1 ) { - error() << "cannot read certificate file: " << keyFile << ' ' << - getSSLErrorMessage(ERR_get_error()) << endl; - return false; - } + return true; +} - // If password is empty, use default OpenSSL callback, which uses the terminal - // to securely request the password interactively from the user. - if (!password.empty()) { - SSL_CTX_set_default_passwd_cb_userdata( context , this ); - SSL_CTX_set_default_passwd_cb( context, &SSLManager::password_cb ); - } - - if ( SSL_CTX_use_PrivateKey_file( context , keyFile.c_str() , SSL_FILETYPE_PEM ) != 1 ) { - error() << "cannot read PEM key file: " << keyFile << ' ' << - getSSLErrorMessage(ERR_get_error()) << endl; - return false; - } - - // Verify that the certificate and the key go together. - if (SSL_CTX_check_private_key(context) != 1) { - error() << "SSL certificate validation: " << getSSLErrorMessage(ERR_get_error()) - << endl; - return false; +bool SSLManager::_setupCA(SSL_CTX* context, const std::string& caFile) { + // Set the list of CAs sent to clients + STACK_OF(X509_NAME)* certNames = SSL_load_client_CA_file(caFile.c_str()); + if (certNames == NULL) { + error() << "cannot read certificate authority file: " << caFile << " " + << getSSLErrorMessage(ERR_get_error()) << endl; + return false; + } + SSL_CTX_set_client_CA_list(context, certNames); + + // Load trusted CA + if (SSL_CTX_load_verify_locations(context, caFile.c_str(), NULL) != 1) { + error() << "cannot read certificate authority file: " << caFile << " " + << getSSLErrorMessage(ERR_get_error()) << endl; + return false; + } + // Set SSL to require peer (client) certificate verification + // if a certificate is presented + SSL_CTX_set_verify(context, SSL_VERIFY_PEER, &SSLManager::verify_cb); + _sslConfiguration.hasCA = true; + return true; +} + +bool SSLManager::_setupCRL(SSL_CTX* context, const std::string& crlFile) { + X509_STORE* store = SSL_CTX_get_cert_store(context); + fassert(16583, store); + + X509_STORE_set_flags(store, X509_V_FLAG_CRL_CHECK); + X509_LOOKUP* lookup = X509_STORE_add_lookup(store, X509_LOOKUP_file()); + fassert(16584, lookup); + + int status = X509_load_crl_file(lookup, crlFile.c_str(), X509_FILETYPE_PEM); + if (status == 0) { + error() << "cannot read CRL file: " << crlFile << ' ' << getSSLErrorMessage(ERR_get_error()) + << endl; + return false; + } + log() << "ssl imported " << status << " revoked certificate" << ((status == 1) ? "" : "s") + << " from the revocation list." << endl; + return true; +} + +/* +* The interface layer between network and BIO-pair. The BIO-pair buffers +* the data to/from the TLS layer. +*/ +void SSLManager::_flushNetworkBIO(SSLConnection* conn) { + char buffer[BUFFER_SIZE]; + int wantWrite; + + /* + * Write the complete contents of the buffer. Leaving the buffer + * unflushed could cause a deadlock. + */ + while ((wantWrite = BIO_ctrl_pending(conn->networkBIO)) > 0) { + if (wantWrite > BUFFER_SIZE) { + wantWrite = BUFFER_SIZE; } - - return true; + int fromBIO = BIO_read(conn->networkBIO, buffer, wantWrite); + + int writePos = 0; + do { + int numWrite = fromBIO - writePos; + numWrite = send(conn->socket->rawFD(), buffer + writePos, numWrite, portSendFlags); + if (numWrite < 0) { + conn->socket->handleSendError(numWrite, ""); + } + writePos += numWrite; + } while (writePos < fromBIO); } - bool SSLManager::_setupCA(SSL_CTX* context, const std::string& caFile) { - // Set the list of CAs sent to clients - STACK_OF (X509_NAME) * certNames = SSL_load_client_CA_file(caFile.c_str()); - if (certNames == NULL) { - error() << "cannot read certificate authority file: " << caFile << " " << - getSSLErrorMessage(ERR_get_error()) << endl; - return false; + int wantRead; + while ((wantRead = BIO_ctrl_get_read_request(conn->networkBIO)) > 0) { + if (wantRead > BUFFER_SIZE) { + wantRead = BUFFER_SIZE; } - SSL_CTX_set_client_CA_list(context, certNames); - // Load trusted CA - if (SSL_CTX_load_verify_locations(context, caFile.c_str(), NULL) != 1) { - error() << "cannot read certificate authority file: " << caFile << " " << - getSSLErrorMessage(ERR_get_error()) << endl; - return false; + int numRead = recv(conn->socket->rawFD(), buffer, wantRead, portRecvFlags); + if (numRead <= 0) { + conn->socket->handleRecvError(numRead, wantRead); + continue; + } + + int toBIO = BIO_write(conn->networkBIO, buffer, numRead); + if (toBIO != numRead) { + LOG(3) << "Failed to write network data to the SSL BIO layer"; + throw SocketException(SocketException::RECV_ERROR, conn->socket->remoteString()); } - // Set SSL to require peer (client) certificate verification - // if a certificate is presented - SSL_CTX_set_verify(context, SSL_VERIFY_PEER, &SSLManager::verify_cb); - _sslConfiguration.hasCA = true; - return true; } +} - bool SSLManager::_setupCRL(SSL_CTX* context, const std::string& crlFile) { - X509_STORE *store = SSL_CTX_get_cert_store(context); - fassert(16583, store); - - X509_STORE_set_flags(store, X509_V_FLAG_CRL_CHECK); - X509_LOOKUP *lookup = X509_STORE_add_lookup(store, X509_LOOKUP_file()); - fassert(16584, lookup); - - int status = X509_load_crl_file(lookup, crlFile.c_str(), X509_FILETYPE_PEM); - if (status == 0) { - error() << "cannot read CRL file: " << crlFile << ' ' << - getSSLErrorMessage(ERR_get_error()) << endl; +bool SSLManager::_doneWithSSLOp(SSLConnection* conn, int status) { + int sslErr = SSL_get_error(conn, status); + switch (sslErr) { + case SSL_ERROR_NONE: + _flushNetworkBIO(conn); // success, flush network BIO before leaving + return true; + case SSL_ERROR_WANT_WRITE: + case SSL_ERROR_WANT_READ: + _flushNetworkBIO(conn); // not ready, flush network BIO and try again return false; - } - log() << "ssl imported " << status << " revoked certificate" << - ((status == 1) ? "" : "s") << " from the revocation list." << - endl; - return true; + default: + return true; } +} - /* - * The interface layer between network and BIO-pair. The BIO-pair buffers - * the data to/from the TLS layer. - */ - void SSLManager::_flushNetworkBIO(SSLConnection* conn){ - char buffer[BUFFER_SIZE]; - int wantWrite; - - /* - * Write the complete contents of the buffer. Leaving the buffer - * unflushed could cause a deadlock. - */ - while ((wantWrite = BIO_ctrl_pending(conn->networkBIO)) > 0) { - if (wantWrite > BUFFER_SIZE) { - wantWrite = BUFFER_SIZE; - } - int fromBIO = BIO_read(conn->networkBIO, buffer, wantWrite); - - int writePos = 0; - do { - int numWrite = fromBIO - writePos; - numWrite = send(conn->socket->rawFD(), buffer + writePos, numWrite, portSendFlags); - if (numWrite < 0) { - conn->socket->handleSendError(numWrite, ""); - } - writePos += numWrite; - } while (writePos < fromBIO); - } +SSLConnection* SSLManager::connect(Socket* socket) { + std::unique_ptr<SSLConnection> sslConn = + stdx::make_unique<SSLConnection>(_clientContext, socket, (const char*)NULL, 0); - int wantRead; - while ((wantRead = BIO_ctrl_get_read_request(conn->networkBIO)) > 0) - { - if (wantRead > BUFFER_SIZE) { - wantRead = BUFFER_SIZE; - } + int ret; + do { + ret = ::SSL_connect(sslConn->ssl); + } while (!_doneWithSSLOp(sslConn.get(), ret)); - int numRead = recv(conn->socket->rawFD(), buffer, wantRead, portRecvFlags); - if (numRead <= 0) { - conn->socket->handleRecvError(numRead, wantRead); - continue; - } + if (ret != 1) + _handleSSLError(SSL_get_error(sslConn.get(), ret), ret); - int toBIO = BIO_write(conn->networkBIO, buffer, numRead); - if (toBIO != numRead) { - LOG(3) << "Failed to write network data to the SSL BIO layer"; - throw SocketException(SocketException::RECV_ERROR , conn->socket->remoteString()); - } - } - } + return sslConn.release(); +} - bool SSLManager::_doneWithSSLOp(SSLConnection* conn, int status) { - int sslErr = SSL_get_error(conn, status); - switch (sslErr) { - case SSL_ERROR_NONE: - _flushNetworkBIO(conn); // success, flush network BIO before leaving - return true; - case SSL_ERROR_WANT_WRITE: - case SSL_ERROR_WANT_READ: - _flushNetworkBIO(conn); // not ready, flush network BIO and try again - return false; - default: - return true; - } - } +SSLConnection* SSLManager::accept(Socket* socket, const char* initialBytes, int len) { + std::unique_ptr<SSLConnection> sslConn = + stdx::make_unique<SSLConnection>(_serverContext, socket, initialBytes, len); - SSLConnection* SSLManager::connect(Socket* socket) { - std::unique_ptr<SSLConnection> sslConn = stdx::make_unique<SSLConnection>(_clientContext, socket, (const char*)NULL, 0); - - int ret; - do { - ret = ::SSL_connect(sslConn->ssl); - } while(!_doneWithSSLOp(sslConn.get(), ret)); - - if (ret != 1) - _handleSSLError(SSL_get_error(sslConn.get(), ret), ret); - - return sslConn.release(); + int ret; + do { + ret = ::SSL_accept(sslConn->ssl); + } while (!_doneWithSSLOp(sslConn.get(), ret)); + + if (ret != 1) + _handleSSLError(SSL_get_error(sslConn.get(), ret), ret); + + return sslConn.release(); +} + +// TODO SERVER-11601 Use NFC Unicode canonicalization +bool SSLManager::_hostNameMatch(const char* nameToMatch, const char* certHostName) { + if (strlen(certHostName) < 2) { + return false; } - SSLConnection* SSLManager::accept(Socket* socket, const char* initialBytes, int len) { - std::unique_ptr<SSLConnection> sslConn = stdx::make_unique<SSLConnection>(_serverContext, socket, initialBytes, len); - - int ret; - do { - ret = ::SSL_accept(sslConn->ssl); - } while(!_doneWithSSLOp(sslConn.get(), ret)); - - if (ret != 1) - _handleSSLError(SSL_get_error(sslConn.get(), ret), ret); - - return sslConn.release(); + // match wildcard DNS names + if (certHostName[0] == '*' && certHostName[1] == '.') { + // allow name.example.com if the cert is *.example.com, '*' does not match '.' + const char* subName = strchr(nameToMatch, '.'); + return subName && !strcasecmp(certHostName + 1, subName); + } else { + return !strcasecmp(nameToMatch, certHostName); } +} - // TODO SERVER-11601 Use NFC Unicode canonicalization - bool SSLManager::_hostNameMatch(const char* nameToMatch, - const char* certHostName) { - if (strlen(certHostName) < 2) { - return false; - } - - // match wildcard DNS names - if (certHostName[0] == '*' && certHostName[1] == '.') { - // allow name.example.com if the cert is *.example.com, '*' does not match '.' - const char* subName = strchr(nameToMatch, '.'); - return subName && !strcasecmp(certHostName+1, subName); +std::string SSLManager::parseAndValidatePeerCertificate(const SSLConnection* conn, + const std::string& remoteHost) { + // only set if a CA cert has been provided + if (!_sslConfiguration.hasCA) + return ""; + + X509* peerCert = SSL_get_peer_certificate(conn->ssl); + + if (NULL == peerCert) { // no certificate presented by peer + if (_weakValidation) { + warning() << "no SSL certificate provided by peer" << endl; + } else { + error() << "no SSL certificate provided by peer; connection rejected" << endl; + throw SocketException(SocketException::CONNECT_ERROR, ""); } - else { - return !strcasecmp(nameToMatch, certHostName); + return ""; + } + ON_BLOCK_EXIT(X509_free, peerCert); + + long result = SSL_get_verify_result(conn->ssl); + + if (result != X509_V_OK) { + if (_allowInvalidCertificates) { + warning() << "SSL peer certificate validation failed:" + << X509_verify_cert_error_string(result); + } else { + error() << "SSL peer certificate validation failed:" + << X509_verify_cert_error_string(result); + throw SocketException(SocketException::CONNECT_ERROR, ""); } } - std::string SSLManager::parseAndValidatePeerCertificate(const SSLConnection* conn, - const std::string& remoteHost) { - // only set if a CA cert has been provided - if (!_sslConfiguration.hasCA) return ""; + // TODO: check optional cipher restriction, using cert. + std::string peerSubjectName = getCertificateSubjectName(peerCert); - X509* peerCert = SSL_get_peer_certificate(conn->ssl); + // If this is an SSL client context (on a MongoDB server or client) + // perform hostname validation of the remote server + if (remoteHost.empty()) { + return peerSubjectName; + } - if (NULL == peerCert) { // no certificate presented by peer - if (_weakValidation) { - warning() << "no SSL certificate provided by peer" << endl; - } - else { - error() << "no SSL certificate provided by peer; connection rejected" << endl; - throw SocketException(SocketException::CONNECT_ERROR, ""); - } - return ""; - } - ON_BLOCK_EXIT(X509_free, peerCert); + // Try to match using the Subject Alternate Name, if it exists. + // RFC-2818 requires the Subject Alternate Name to be used if present. + // Otherwise, the most specific Common Name field in the subject field + // must be used. - long result = SSL_get_verify_result(conn->ssl); + bool sanMatch = false; + bool cnMatch = false; - if (result != X509_V_OK) { - if (_allowInvalidCertificates) { - warning() << "SSL peer certificate validation failed:" << - X509_verify_cert_error_string(result); - } - else { - error() << "SSL peer certificate validation failed:" << - X509_verify_cert_error_string(result); - throw SocketException(SocketException::CONNECT_ERROR, ""); - } - } - - // TODO: check optional cipher restriction, using cert. - std::string peerSubjectName = getCertificateSubjectName(peerCert); - - // If this is an SSL client context (on a MongoDB server or client) - // perform hostname validation of the remote server - if (remoteHost.empty()) { - return peerSubjectName; - } + STACK_OF(GENERAL_NAME)* sanNames = static_cast<STACK_OF(GENERAL_NAME)*>( + X509_get_ext_d2i(peerCert, NID_subject_alt_name, NULL, NULL)); - // Try to match using the Subject Alternate Name, if it exists. - // RFC-2818 requires the Subject Alternate Name to be used if present. - // Otherwise, the most specific Common Name field in the subject field - // must be used. - - bool sanMatch = false; - bool cnMatch = false; - - STACK_OF(GENERAL_NAME)* sanNames = static_cast<STACK_OF(GENERAL_NAME)*> - (X509_get_ext_d2i(peerCert, NID_subject_alt_name, NULL, NULL)); - - if (sanNames != NULL) { - int sanNamesList = sk_GENERAL_NAME_num(sanNames); - for (int i = 0; i < sanNamesList; i++) { - const GENERAL_NAME* currentName = sk_GENERAL_NAME_value(sanNames, i); - if (currentName && currentName->type == GEN_DNS) { - char *dnsName = - reinterpret_cast<char *>(ASN1_STRING_data(currentName->d.dNSName)); - if (_hostNameMatch(remoteHost.c_str(), dnsName)) { - sanMatch = true; - break; - } + if (sanNames != NULL) { + int sanNamesList = sk_GENERAL_NAME_num(sanNames); + for (int i = 0; i < sanNamesList; i++) { + const GENERAL_NAME* currentName = sk_GENERAL_NAME_value(sanNames, i); + if (currentName && currentName->type == GEN_DNS) { + char* dnsName = reinterpret_cast<char*>(ASN1_STRING_data(currentName->d.dNSName)); + if (_hostNameMatch(remoteHost.c_str(), dnsName)) { + sanMatch = true; + break; } } - sk_GENERAL_NAME_pop_free(sanNames, GENERAL_NAME_free); - } - else { - // If Subject Alternate Name (SAN) didn't exist, check Common Name (CN). - int cnBegin = peerSubjectName.find("CN=") + 3; - int cnEnd = peerSubjectName.find(",", cnBegin); - std::string commonName = peerSubjectName.substr(cnBegin, cnEnd-cnBegin); - - if (_hostNameMatch(remoteHost.c_str(), commonName.c_str())) { - cnMatch = true; - } } + sk_GENERAL_NAME_pop_free(sanNames, GENERAL_NAME_free); + } else { + // If Subject Alternate Name (SAN) didn't exist, check Common Name (CN). + int cnBegin = peerSubjectName.find("CN=") + 3; + int cnEnd = peerSubjectName.find(",", cnBegin); + std::string commonName = peerSubjectName.substr(cnBegin, cnEnd - cnBegin); - if (!sanMatch && !cnMatch) { - if (_allowInvalidCertificates || _allowInvalidHostnames) { - warning() << "The server certificate does not match the host name " << - remoteHost; - } - else { - error() << "The server certificate does not match the host name " << - remoteHost; - throw SocketException(SocketException::CONNECT_ERROR, ""); - } + if (_hostNameMatch(remoteHost.c_str(), commonName.c_str())) { + cnMatch = true; } - - return peerSubjectName; } - void SSLManager::cleanupThreadLocals() { - ERR_remove_state(0); + if (!sanMatch && !cnMatch) { + if (_allowInvalidCertificates || _allowInvalidHostnames) { + warning() << "The server certificate does not match the host name " << remoteHost; + } else { + error() << "The server certificate does not match the host name " << remoteHost; + throw SocketException(SocketException::CONNECT_ERROR, ""); + } } - std::string SSLManagerInterface::getSSLErrorMessage(int code) { - // 120 from the SSL documentation for ERR_error_string - static const size_t msglen = 120; + return peerSubjectName; +} - char msg[msglen]; - ERR_error_string_n(code, msg, msglen); - return msg; - } +void SSLManager::cleanupThreadLocals() { + ERR_remove_state(0); +} + +std::string SSLManagerInterface::getSSLErrorMessage(int code) { + // 120 from the SSL documentation for ERR_error_string + static const size_t msglen = 120; + + char msg[msglen]; + ERR_error_string_n(code, msg, msglen); + return msg; +} + +void SSLManager::_handleSSLError(int code, int ret) { + int err = ERR_get_error(); - void SSLManager::_handleSSLError(int code, int ret) { - int err = ERR_get_error(); - - switch (code) { + switch (code) { case SSL_ERROR_WANT_READ: case SSL_ERROR_WANT_WRITE: // should not happen because we turned on AUTO_RETRY @@ -1010,7 +981,7 @@ namespace mongo { error() << "SSL: " << code << ", possibly timed out during connect"; break; - case SSL_ERROR_ZERO_RETURN: + case SSL_ERROR_ZERO_RETURN: // TODO: Check if we can avoid throwing an exception for this condition LOG(3) << "SSL network connection closed"; break; @@ -1019,25 +990,22 @@ namespace mongo { // check the return value of the actual SSL operation if (err != 0) { error() << "SSL: " << getSSLErrorMessage(err); - } - else if (ret == 0) { + } else if (ret == 0) { error() << "Unexpected EOF encountered during SSL communication"; - } - else { + } else { error() << "The SSL BIO reported an I/O error " << errnoWithDescription(); } break; - case SSL_ERROR_SSL: - { + case SSL_ERROR_SSL: { error() << "SSL: " << getSSLErrorMessage(err); break; } - + default: error() << "unrecognized SSL error"; break; - } - throw SocketException(SocketException::CONNECT_ERROR, ""); } -#endif // #ifdef MONGO_CONFIG_SSL + throw SocketException(SocketException::CONNECT_ERROR, ""); +} +#endif // #ifdef MONGO_CONFIG_SSL } diff --git a/src/mongo/util/net/ssl_manager.h b/src/mongo/util/net/ssl_manager.h index b9af6c424b4..1c6295ed517 100644 --- a/src/mongo/util/net/ssl_manager.h +++ b/src/mongo/util/net/ssl_manager.h @@ -42,117 +42,115 @@ #include <openssl/err.h> #include <openssl/ssl.h> -#endif // #ifdef MONGO_CONFIG_SSL +#endif // #ifdef MONGO_CONFIG_SSL namespace mongo { - /* - * @return the SSL version std::string prefixed with prefix and suffixed with suffix - */ - const std::string getSSLVersion(const std::string &prefix, const std::string &suffix); +/* + * @return the SSL version std::string prefixed with prefix and suffixed with suffix + */ +const std::string getSSLVersion(const std::string& prefix, const std::string& suffix); } #ifdef MONGO_CONFIG_SSL namespace mongo { - struct SSLParams; - - class SSLConnection { - public: - SSL* ssl; - BIO* networkBIO; - BIO* internalBIO; - Socket* socket; - - SSLConnection(SSL_CTX* ctx, Socket* sock, const char* initialBytes, int len); - - ~SSLConnection(); - }; - - struct SSLConfiguration { - SSLConfiguration() : - serverSubjectName(""), clientSubjectName(""), - hasCA(false) {} - SSLConfiguration(const std::string& serverSubjectName, - const std::string& clientSubjectName, - const Date_t& serverCertificateExpirationDate, - bool hasCA) : - serverSubjectName(serverSubjectName), - clientSubjectName(clientSubjectName), - serverCertificateExpirationDate(serverCertificateExpirationDate), - hasCA(hasCA) {} - - BSONObj getServerStatusBSON() const; - std::string serverSubjectName; - std::string clientSubjectName; - Date_t serverCertificateExpirationDate; - bool hasCA; - }; - - class SSLManagerInterface { - public: - static std::unique_ptr<SSLManagerInterface> create(const SSLParams& params, bool isServer); - - virtual ~SSLManagerInterface(); - - /** - * Initiates a TLS connection. - * Throws SocketException on failure. - * @return a pointer to an SSLConnection. Resources are freed in SSLConnection's destructor - */ - virtual SSLConnection* connect(Socket* socket) = 0; - - /** - * Waits for the other side to initiate a TLS connection. - * Throws SocketException on failure. - * @return a pointer to an SSLConnection. Resources are freed in SSLConnection's destructor - */ - virtual SSLConnection* accept(Socket* socket, const char* initialBytes, int len) = 0; - - /** - * Fetches a peer certificate and validates it if it exists - * Throws SocketException on failure - * @return a std::string containing the certificate's subject name. - */ - virtual std::string parseAndValidatePeerCertificate(const SSLConnection* conn, - const std::string& remoteHost) = 0; - - /** - * Cleans up SSL thread local memory; use at thread exit - * to avoid memory leaks - */ - virtual void cleanupThreadLocals() = 0; - - /** - * Gets the SSLConfiguration containing all information about the current SSL setup - * @return the SSLConfiguration - */ - virtual const SSLConfiguration& getSSLConfiguration() const = 0; - - /** - * Fetches the error text for an error code, in a thread-safe manner. - */ - static std::string getSSLErrorMessage(int code); - - /** - * ssl.h wrappers - */ - virtual int SSL_read(SSLConnection* conn, void* buf, int num) = 0; - - virtual int SSL_write(SSLConnection* conn, const void* buf, int num) = 0; - - virtual unsigned long ERR_get_error() = 0; - - virtual char* ERR_error_string(unsigned long e, char* buf) = 0; - - virtual int SSL_get_error(const SSLConnection* conn, int ret) = 0; - - virtual int SSL_shutdown(SSLConnection* conn) = 0; - - virtual void SSL_free(SSLConnection* conn) = 0; - }; - - // Access SSL functions through this instance. - SSLManagerInterface* getSSLManager(); - - extern bool isSSLServer; +struct SSLParams; + +class SSLConnection { +public: + SSL* ssl; + BIO* networkBIO; + BIO* internalBIO; + Socket* socket; + + SSLConnection(SSL_CTX* ctx, Socket* sock, const char* initialBytes, int len); + + ~SSLConnection(); +}; + +struct SSLConfiguration { + SSLConfiguration() : serverSubjectName(""), clientSubjectName(""), hasCA(false) {} + SSLConfiguration(const std::string& serverSubjectName, + const std::string& clientSubjectName, + const Date_t& serverCertificateExpirationDate, + bool hasCA) + : serverSubjectName(serverSubjectName), + clientSubjectName(clientSubjectName), + serverCertificateExpirationDate(serverCertificateExpirationDate), + hasCA(hasCA) {} + + BSONObj getServerStatusBSON() const; + std::string serverSubjectName; + std::string clientSubjectName; + Date_t serverCertificateExpirationDate; + bool hasCA; +}; + +class SSLManagerInterface { +public: + static std::unique_ptr<SSLManagerInterface> create(const SSLParams& params, bool isServer); + + virtual ~SSLManagerInterface(); + + /** + * Initiates a TLS connection. + * Throws SocketException on failure. + * @return a pointer to an SSLConnection. Resources are freed in SSLConnection's destructor + */ + virtual SSLConnection* connect(Socket* socket) = 0; + + /** + * Waits for the other side to initiate a TLS connection. + * Throws SocketException on failure. + * @return a pointer to an SSLConnection. Resources are freed in SSLConnection's destructor + */ + virtual SSLConnection* accept(Socket* socket, const char* initialBytes, int len) = 0; + + /** + * Fetches a peer certificate and validates it if it exists + * Throws SocketException on failure + * @return a std::string containing the certificate's subject name. + */ + virtual std::string parseAndValidatePeerCertificate(const SSLConnection* conn, + const std::string& remoteHost) = 0; + + /** + * Cleans up SSL thread local memory; use at thread exit + * to avoid memory leaks + */ + virtual void cleanupThreadLocals() = 0; + + /** + * Gets the SSLConfiguration containing all information about the current SSL setup + * @return the SSLConfiguration + */ + virtual const SSLConfiguration& getSSLConfiguration() const = 0; + + /** + * Fetches the error text for an error code, in a thread-safe manner. + */ + static std::string getSSLErrorMessage(int code); + + /** + * ssl.h wrappers + */ + virtual int SSL_read(SSLConnection* conn, void* buf, int num) = 0; + + virtual int SSL_write(SSLConnection* conn, const void* buf, int num) = 0; + + virtual unsigned long ERR_get_error() = 0; + + virtual char* ERR_error_string(unsigned long e, char* buf) = 0; + + virtual int SSL_get_error(const SSLConnection* conn, int ret) = 0; + + virtual int SSL_shutdown(SSLConnection* conn) = 0; + + virtual void SSL_free(SSLConnection* conn) = 0; +}; + +// Access SSL functions through this instance. +SSLManagerInterface* getSSLManager(); + +extern bool isSSLServer; } -#endif // #ifdef MONGO_CONFIG_SSL +#endif // #ifdef MONGO_CONFIG_SSL diff --git a/src/mongo/util/net/ssl_options.cpp b/src/mongo/util/net/ssl_options.cpp index a8cf9646bef..881718179a4 100644 --- a/src/mongo/util/net/ssl_options.cpp +++ b/src/mongo/util/net/ssl_options.cpp @@ -41,348 +41,356 @@ namespace mongo { - using std::string; - - Status addSSLServerOptions(moe::OptionSection* options) { - options->addOptionChaining("net.ssl.sslOnNormalPorts", "sslOnNormalPorts", moe::Switch, - "use ssl on configured ports") - .setSources(moe::SourceAllLegacy) - .incompatibleWith("net.ssl.mode"); - - options->addOptionChaining("net.ssl.mode", "sslMode", moe::String, - "set the SSL operation mode (disabled|allowSSL|preferSSL|requireSSL)"); - - options->addOptionChaining("net.ssl.PEMKeyFile", "sslPEMKeyFile", moe::String, - "PEM file for ssl"); - - options->addOptionChaining("net.ssl.PEMKeyPassword", "sslPEMKeyPassword", moe::String, - "PEM file password") - .setImplicit(moe::Value(std::string(""))); - - options->addOptionChaining("net.ssl.clusterFile", "sslClusterFile", moe::String, - "Key file for internal SSL authentication"); - - options->addOptionChaining("net.ssl.clusterPassword", "sslClusterPassword", moe::String, - "Internal authentication key file password") - .setImplicit(moe::Value(std::string(""))); - - options->addOptionChaining("net.ssl.CAFile", "sslCAFile", moe::String, - "Certificate Authority file for SSL"); - - options->addOptionChaining("net.ssl.CRLFile", "sslCRLFile", moe::String, - "Certificate Revocation List file for SSL"); - - options->addOptionChaining("net.ssl.sslCipherConfig", "sslCipherConfig", moe::String, - "OpenSSL cipher configuration string") - .hidden(); - - options->addOptionChaining("net.ssl.disabledProtocols", "sslDisabledProtocols", moe::String, - "Comma separated list of disabled protocols") - .hidden(); - - options->addOptionChaining("net.ssl.weakCertificateValidation", - "sslWeakCertificateValidation", moe::Switch, "allow client to connect without " - "presenting a certificate"); - - // Alias for --sslWeakCertificateValidation. - options->addOptionChaining("net.ssl.allowConnectionsWithoutCertificates", - "sslAllowConnectionsWithoutCertificates", moe::Switch, - "allow client to connect without presenting a certificate"); - - options->addOptionChaining("net.ssl.allowInvalidHostnames", "sslAllowInvalidHostnames", - moe::Switch, "Allow server certificates to provide non-matching hostnames"); - - options->addOptionChaining("net.ssl.allowInvalidCertificates", "sslAllowInvalidCertificates", - moe::Switch, "allow connections to servers with invalid certificates"); - - options->addOptionChaining("net.ssl.FIPSMode", "sslFIPSMode", moe::Switch, - "activate FIPS 140-2 mode at startup"); - - return Status::OK(); - } - - Status addSSLClientOptions(moe::OptionSection* options) { - options->addOptionChaining("ssl", "ssl", moe::Switch, "use SSL for all connections"); - - options->addOptionChaining("ssl.CAFile", "sslCAFile", moe::String, - "Certificate Authority file for SSL") - .requires("ssl"); - - options->addOptionChaining("ssl.PEMKeyFile", "sslPEMKeyFile", moe::String, - "PEM certificate/key file for SSL") - .requires("ssl"); - - options->addOptionChaining("ssl.PEMKeyPassword", "sslPEMKeyPassword", moe::String, - "password for key in PEM file for SSL") - .requires("ssl"); - - options->addOptionChaining("ssl.CRLFile", "sslCRLFile", moe::String, - "Certificate Revocation List file for SSL") - .requires("ssl") - .requires("ssl.CAFile"); - - options->addOptionChaining("net.ssl.disabledProtocols", "sslDisabledProtocols", moe::String, - "Comma separated list of disabled protocols") - .requires("ssl") - .hidden(); - - options->addOptionChaining("net.ssl.allowInvalidHostnames", "sslAllowInvalidHostnames", - moe::Switch, "allow connections to servers with non-matching hostnames") - .requires("ssl"); - - options->addOptionChaining("ssl.allowInvalidCertificates", "sslAllowInvalidCertificates", - moe::Switch, "allow connections to servers with invalid certificates") - .requires("ssl"); - - options->addOptionChaining("ssl.FIPSMode", "sslFIPSMode", moe::Switch, - "activate FIPS 140-2 mode at startup") - .requires("ssl"); - - return Status::OK(); - } - - Status validateSSLServerOptions(const moe::Environment& params) { +using std::string; + +Status addSSLServerOptions(moe::OptionSection* options) { + options->addOptionChaining("net.ssl.sslOnNormalPorts", + "sslOnNormalPorts", + moe::Switch, + "use ssl on configured ports") + .setSources(moe::SourceAllLegacy) + .incompatibleWith("net.ssl.mode"); + + options->addOptionChaining( + "net.ssl.mode", + "sslMode", + moe::String, + "set the SSL operation mode (disabled|allowSSL|preferSSL|requireSSL)"); + + options->addOptionChaining( + "net.ssl.PEMKeyFile", "sslPEMKeyFile", moe::String, "PEM file for ssl"); + + options->addOptionChaining( + "net.ssl.PEMKeyPassword", "sslPEMKeyPassword", moe::String, "PEM file password") + .setImplicit(moe::Value(std::string(""))); + + options->addOptionChaining("net.ssl.clusterFile", + "sslClusterFile", + moe::String, + "Key file for internal SSL authentication"); + + options->addOptionChaining("net.ssl.clusterPassword", + "sslClusterPassword", + moe::String, + "Internal authentication key file password") + .setImplicit(moe::Value(std::string(""))); + + options->addOptionChaining( + "net.ssl.CAFile", "sslCAFile", moe::String, "Certificate Authority file for SSL"); + + options->addOptionChaining( + "net.ssl.CRLFile", "sslCRLFile", moe::String, "Certificate Revocation List file for SSL"); + + options->addOptionChaining("net.ssl.sslCipherConfig", + "sslCipherConfig", + moe::String, + "OpenSSL cipher configuration string").hidden(); + + options->addOptionChaining("net.ssl.disabledProtocols", + "sslDisabledProtocols", + moe::String, + "Comma separated list of disabled protocols").hidden(); + + options->addOptionChaining("net.ssl.weakCertificateValidation", + "sslWeakCertificateValidation", + moe::Switch, + "allow client to connect without " + "presenting a certificate"); + + // Alias for --sslWeakCertificateValidation. + options->addOptionChaining("net.ssl.allowConnectionsWithoutCertificates", + "sslAllowConnectionsWithoutCertificates", + moe::Switch, + "allow client to connect without presenting a certificate"); + + options->addOptionChaining("net.ssl.allowInvalidHostnames", + "sslAllowInvalidHostnames", + moe::Switch, + "Allow server certificates to provide non-matching hostnames"); + + options->addOptionChaining("net.ssl.allowInvalidCertificates", + "sslAllowInvalidCertificates", + moe::Switch, + "allow connections to servers with invalid certificates"); + + options->addOptionChaining( + "net.ssl.FIPSMode", "sslFIPSMode", moe::Switch, "activate FIPS 140-2 mode at startup"); + + return Status::OK(); +} + +Status addSSLClientOptions(moe::OptionSection* options) { + options->addOptionChaining("ssl", "ssl", moe::Switch, "use SSL for all connections"); + + options->addOptionChaining( + "ssl.CAFile", "sslCAFile", moe::String, "Certificate Authority file for SSL") + .requires("ssl"); + + options->addOptionChaining( + "ssl.PEMKeyFile", "sslPEMKeyFile", moe::String, "PEM certificate/key file for SSL") + .requires("ssl"); + + options->addOptionChaining("ssl.PEMKeyPassword", + "sslPEMKeyPassword", + moe::String, + "password for key in PEM file for SSL").requires("ssl"); + + options->addOptionChaining("ssl.CRLFile", + "sslCRLFile", + moe::String, + "Certificate Revocation List file for SSL") + .requires("ssl") + .requires("ssl.CAFile"); + + options->addOptionChaining("net.ssl.disabledProtocols", + "sslDisabledProtocols", + moe::String, + "Comma separated list of disabled protocols") + .requires("ssl") + .hidden(); + + options->addOptionChaining("net.ssl.allowInvalidHostnames", + "sslAllowInvalidHostnames", + moe::Switch, + "allow connections to servers with non-matching hostnames") + .requires("ssl"); + + options->addOptionChaining("ssl.allowInvalidCertificates", + "sslAllowInvalidCertificates", + moe::Switch, + "allow connections to servers with invalid certificates") + .requires("ssl"); + + options->addOptionChaining( + "ssl.FIPSMode", "sslFIPSMode", moe::Switch, "activate FIPS 140-2 mode at startup") + .requires("ssl"); + + return Status::OK(); +} + +Status validateSSLServerOptions(const moe::Environment& params) { #ifdef _WIN32 - if (params.count("install") || params.count("reinstall")) { - if (params.count("net.ssl.PEMKeyFile") && - !boost::filesystem::path(params["net.ssl.PEMKeyFile"].as<string>()).is_absolute()) { - return Status(ErrorCodes::BadValue, - "PEMKeyFile requires an absolute file path with Windows services"); - } - - if (params.count("net.ssl.clusterFile") && - !boost::filesystem::path( - params["net.ssl.clusterFile"].as<string>()).is_absolute()) { - return Status(ErrorCodes::BadValue, - "clusterFile requires an absolute file path with Windows services"); - } - - if (params.count("net.ssl.CAFile") && - !boost::filesystem::path(params["net.ssl.CAFile"].as<string>()).is_absolute()) { - return Status(ErrorCodes::BadValue, - "CAFile requires an absolute file path with Windows services"); - } - - if (params.count("net.ssl.CRLFile") && - !boost::filesystem::path(params["net.ssl.CRLFile"].as<string>()).is_absolute()) { - return Status(ErrorCodes::BadValue, - "CRLFile requires an absolute file path with Windows services"); - } - + if (params.count("install") || params.count("reinstall")) { + if (params.count("net.ssl.PEMKeyFile") && + !boost::filesystem::path(params["net.ssl.PEMKeyFile"].as<string>()).is_absolute()) { + return Status(ErrorCodes::BadValue, + "PEMKeyFile requires an absolute file path with Windows services"); } -#endif - return Status::OK(); - } - - Status canonicalizeSSLServerOptions(moe::Environment* params) { + if (params.count("net.ssl.clusterFile") && + !boost::filesystem::path(params["net.ssl.clusterFile"].as<string>()).is_absolute()) { + return Status(ErrorCodes::BadValue, + "clusterFile requires an absolute file path with Windows services"); + } - if (params->count("net.ssl.sslOnNormalPorts") && - (*params)["net.ssl.sslOnNormalPorts"].as<bool>() == true) { - Status ret = params->set("net.ssl.mode", moe::Value(std::string("requireSSL"))); - if (!ret.isOK()) { - return ret; - } - ret = params->remove("net.ssl.sslOnNormalPorts"); - if (!ret.isOK()) { - return ret; - } + if (params.count("net.ssl.CAFile") && + !boost::filesystem::path(params["net.ssl.CAFile"].as<string>()).is_absolute()) { + return Status(ErrorCodes::BadValue, + "CAFile requires an absolute file path with Windows services"); } - return Status::OK(); + if (params.count("net.ssl.CRLFile") && + !boost::filesystem::path(params["net.ssl.CRLFile"].as<string>()).is_absolute()) { + return Status(ErrorCodes::BadValue, + "CRLFile requires an absolute file path with Windows services"); + } } +#endif - Status storeSSLServerOptions(const moe::Environment& params) { + return Status::OK(); +} - if (params.count("net.ssl.mode")) { - std::string sslModeParam = params["net.ssl.mode"].as<string>(); - if (sslModeParam == "disabled") { - sslGlobalParams.sslMode.store(SSLParams::SSLMode_disabled); - } - else if (sslModeParam == "allowSSL") { - sslGlobalParams.sslMode.store(SSLParams::SSLMode_allowSSL); - } - else if (sslModeParam == "preferSSL") { - sslGlobalParams.sslMode.store(SSLParams::SSLMode_preferSSL); - } - else if (sslModeParam == "requireSSL") { - sslGlobalParams.sslMode.store(SSLParams::SSLMode_requireSSL); - } - else { - return Status(ErrorCodes::BadValue, - "unsupported value for sslMode " + sslModeParam ); - } +Status canonicalizeSSLServerOptions(moe::Environment* params) { + if (params->count("net.ssl.sslOnNormalPorts") && + (*params)["net.ssl.sslOnNormalPorts"].as<bool>() == true) { + Status ret = params->set("net.ssl.mode", moe::Value(std::string("requireSSL"))); + if (!ret.isOK()) { + return ret; } - - if (params.count("net.ssl.PEMKeyFile")) { - sslGlobalParams.sslPEMKeyFile = boost::filesystem::absolute( - params["net.ssl.PEMKeyFile"].as<string>()).generic_string(); + ret = params->remove("net.ssl.sslOnNormalPorts"); + if (!ret.isOK()) { + return ret; } + } - if (params.count("net.ssl.PEMKeyPassword")) { - sslGlobalParams.sslPEMKeyPassword = params["net.ssl.PEMKeyPassword"].as<string>(); + return Status::OK(); +} + +Status storeSSLServerOptions(const moe::Environment& params) { + if (params.count("net.ssl.mode")) { + std::string sslModeParam = params["net.ssl.mode"].as<string>(); + if (sslModeParam == "disabled") { + sslGlobalParams.sslMode.store(SSLParams::SSLMode_disabled); + } else if (sslModeParam == "allowSSL") { + sslGlobalParams.sslMode.store(SSLParams::SSLMode_allowSSL); + } else if (sslModeParam == "preferSSL") { + sslGlobalParams.sslMode.store(SSLParams::SSLMode_preferSSL); + } else if (sslModeParam == "requireSSL") { + sslGlobalParams.sslMode.store(SSLParams::SSLMode_requireSSL); + } else { + return Status(ErrorCodes::BadValue, "unsupported value for sslMode " + sslModeParam); } + } - if (params.count("net.ssl.clusterFile")) { - sslGlobalParams.sslClusterFile = - boost::filesystem::absolute( - params["net.ssl.clusterFile"].as<string>()).generic_string(); - } + if (params.count("net.ssl.PEMKeyFile")) { + sslGlobalParams.sslPEMKeyFile = + boost::filesystem::absolute(params["net.ssl.PEMKeyFile"].as<string>()).generic_string(); + } - if (params.count("net.ssl.clusterPassword")) { - sslGlobalParams.sslClusterPassword = params["net.ssl.clusterPassword"].as<string>(); - } + if (params.count("net.ssl.PEMKeyPassword")) { + sslGlobalParams.sslPEMKeyPassword = params["net.ssl.PEMKeyPassword"].as<string>(); + } - if (params.count("net.ssl.CAFile")) { - sslGlobalParams.sslCAFile = boost::filesystem::absolute( - params["net.ssl.CAFile"].as<std::string>()).generic_string(); - } + if (params.count("net.ssl.clusterFile")) { + sslGlobalParams.sslClusterFile = + boost::filesystem::absolute(params["net.ssl.clusterFile"].as<string>()) + .generic_string(); + } - if (params.count("net.ssl.CRLFile")) { - sslGlobalParams.sslCRLFile = boost::filesystem::absolute( - params["net.ssl.CRLFile"].as<std::string>()).generic_string(); - } + if (params.count("net.ssl.clusterPassword")) { + sslGlobalParams.sslClusterPassword = params["net.ssl.clusterPassword"].as<string>(); + } - if (params.count("net.ssl.sslCipherConfig")) { - sslGlobalParams.sslCipherConfig = params["net.ssl.sslCipherConfig"].as<string>(); - } + if (params.count("net.ssl.CAFile")) { + sslGlobalParams.sslCAFile = + boost::filesystem::absolute(params["net.ssl.CAFile"].as<std::string>()) + .generic_string(); + } - if (params.count("net.ssl.disabledProtocols")) { - std::vector<std::string> tokens = StringSplitter::split( - params["net.ssl.disabledProtocols"].as<string>(), ","); - - const std::map<std::string, SSLParams::Protocols> validConfigs { - {"noTLS1_0", SSLParams::Protocols::TLS1_0}, - {"noTLS1_1", SSLParams::Protocols::TLS1_1}, - {"noTLS1_2", SSLParams::Protocols::TLS1_2} - }; - for (const std::string& token : tokens) { - auto mappedToken = validConfigs.find(token); - if (mappedToken != validConfigs.end()) { - sslGlobalParams.sslDisabledProtocols.push_back(mappedToken->second); - } else { - return Status(ErrorCodes::BadValue, - "Unrecognized disabledProtocols '" + token +"'"); - } - } - } + if (params.count("net.ssl.CRLFile")) { + sslGlobalParams.sslCRLFile = + boost::filesystem::absolute(params["net.ssl.CRLFile"].as<std::string>()) + .generic_string(); + } - if (params.count("net.ssl.weakCertificateValidation")) { - sslGlobalParams.sslWeakCertificateValidation = - params["net.ssl.weakCertificateValidation"].as<bool>(); - } - else if (params.count("net.ssl.allowConnectionsWithoutCertificates")) { - sslGlobalParams.sslWeakCertificateValidation = - params["net.ssl.allowConnectionsWithoutCertificates"].as<bool>(); - } - if (params.count("net.ssl.allowInvalidHostnames")) { - sslGlobalParams.sslAllowInvalidHostnames = - params["net.ssl.allowInvalidHostnames"].as<bool>(); - } - if (params.count("net.ssl.allowInvalidCertificates")) { - sslGlobalParams.sslAllowInvalidCertificates = - params["net.ssl.allowInvalidCertificates"].as<bool>(); - } - if (params.count("net.ssl.FIPSMode")) { - sslGlobalParams.sslFIPSMode = params["net.ssl.FIPSMode"].as<bool>(); - } + if (params.count("net.ssl.sslCipherConfig")) { + sslGlobalParams.sslCipherConfig = params["net.ssl.sslCipherConfig"].as<string>(); + } - int clusterAuthMode = serverGlobalParams.clusterAuthMode.load(); - if (sslGlobalParams.sslMode.load() != SSLParams::SSLMode_disabled) { - if (sslGlobalParams.sslPEMKeyFile.size() == 0) { - return Status(ErrorCodes::BadValue, - "need sslPEMKeyFile when SSL is enabled"); - } - if (sslGlobalParams.sslWeakCertificateValidation && - sslGlobalParams.sslCAFile.empty()) { + if (params.count("net.ssl.disabledProtocols")) { + std::vector<std::string> tokens = + StringSplitter::split(params["net.ssl.disabledProtocols"].as<string>(), ","); + + const std::map<std::string, SSLParams::Protocols> validConfigs{ + {"noTLS1_0", SSLParams::Protocols::TLS1_0}, + {"noTLS1_1", SSLParams::Protocols::TLS1_1}, + {"noTLS1_2", SSLParams::Protocols::TLS1_2}}; + for (const std::string& token : tokens) { + auto mappedToken = validConfigs.find(token); + if (mappedToken != validConfigs.end()) { + sslGlobalParams.sslDisabledProtocols.push_back(mappedToken->second); + } else { return Status(ErrorCodes::BadValue, - "need sslCAFile with sslWeakCertificateValidation"); - } - if (!sslGlobalParams.sslCRLFile.empty() && - sslGlobalParams.sslCAFile.empty()) { - return Status(ErrorCodes::BadValue, "need sslCAFile with sslCRLFile"); - } - std::string sslCANotFoundError("No SSL certificate validation can be performed since" - " no CA file has been provided; please specify an" - " sslCAFile parameter"); - - if (sslGlobalParams.sslCAFile.empty()) { - if (clusterAuthMode == ServerGlobalParams::ClusterAuthMode_x509) { - return Status(ErrorCodes::BadValue, sslCANotFoundError); - } - warning() << sslCANotFoundError; + "Unrecognized disabledProtocols '" + token + "'"); } } - else if (sslGlobalParams.sslPEMKeyFile.size() || - sslGlobalParams.sslPEMKeyPassword.size() || - sslGlobalParams.sslClusterFile.size() || - sslGlobalParams.sslClusterPassword.size() || - sslGlobalParams.sslCAFile.size() || - sslGlobalParams.sslCRLFile.size() || - sslGlobalParams.sslCipherConfig.size() || - sslGlobalParams.sslDisabledProtocols.size() || - sslGlobalParams.sslWeakCertificateValidation || - sslGlobalParams.sslFIPSMode) { - return Status(ErrorCodes::BadValue, - "need to enable SSL via the sslMode flag when " - "using SSL configuration parameters"); - } - if (clusterAuthMode == ServerGlobalParams::ClusterAuthMode_sendKeyFile || - clusterAuthMode == ServerGlobalParams::ClusterAuthMode_sendX509 || - clusterAuthMode == ServerGlobalParams::ClusterAuthMode_x509) { - if (sslGlobalParams.sslMode.load() == SSLParams::SSLMode_disabled) { - return Status(ErrorCodes::BadValue, "need to enable SSL via the sslMode flag"); - } - } - if (sslGlobalParams.sslMode.load() == SSLParams::SSLMode_allowSSL) { - if (clusterAuthMode == ServerGlobalParams::ClusterAuthMode_sendX509 || - clusterAuthMode == ServerGlobalParams::ClusterAuthMode_x509) { - return Status(ErrorCodes::BadValue, - "cannot have x.509 cluster authentication in allowSSL mode"); - } - } - return Status::OK(); } - Status storeSSLClientOptions(const moe::Environment& params) { - if (params.count("ssl") && params["ssl"].as<bool>() == true) { - sslGlobalParams.sslMode.store(SSLParams::SSLMode_requireSSL); - } - if (params.count("ssl.PEMKeyFile")) { - sslGlobalParams.sslPEMKeyFile = params["ssl.PEMKeyFile"].as<std::string>(); - } - if (params.count("ssl.PEMKeyPassword")) { - sslGlobalParams.sslPEMKeyPassword = params["ssl.PEMKeyPassword"].as<std::string>(); - } - if (params.count("ssl.CAFile")) { - sslGlobalParams.sslCAFile = params["ssl.CAFile"].as<std::string>(); + if (params.count("net.ssl.weakCertificateValidation")) { + sslGlobalParams.sslWeakCertificateValidation = + params["net.ssl.weakCertificateValidation"].as<bool>(); + } else if (params.count("net.ssl.allowConnectionsWithoutCertificates")) { + sslGlobalParams.sslWeakCertificateValidation = + params["net.ssl.allowConnectionsWithoutCertificates"].as<bool>(); + } + if (params.count("net.ssl.allowInvalidHostnames")) { + sslGlobalParams.sslAllowInvalidHostnames = + params["net.ssl.allowInvalidHostnames"].as<bool>(); + } + if (params.count("net.ssl.allowInvalidCertificates")) { + sslGlobalParams.sslAllowInvalidCertificates = + params["net.ssl.allowInvalidCertificates"].as<bool>(); + } + if (params.count("net.ssl.FIPSMode")) { + sslGlobalParams.sslFIPSMode = params["net.ssl.FIPSMode"].as<bool>(); + } + + int clusterAuthMode = serverGlobalParams.clusterAuthMode.load(); + if (sslGlobalParams.sslMode.load() != SSLParams::SSLMode_disabled) { + if (sslGlobalParams.sslPEMKeyFile.size() == 0) { + return Status(ErrorCodes::BadValue, "need sslPEMKeyFile when SSL is enabled"); } - if (params.count("ssl.CRLFile")) { - sslGlobalParams.sslCRLFile = params["ssl.CRLFile"].as<std::string>(); + if (sslGlobalParams.sslWeakCertificateValidation && sslGlobalParams.sslCAFile.empty()) { + return Status(ErrorCodes::BadValue, "need sslCAFile with sslWeakCertificateValidation"); } - if (params.count("net.ssl.allowInvalidHostnames")) { - sslGlobalParams.sslAllowInvalidHostnames = - params["net.ssl.allowInvalidHostnames"].as<bool>(); + if (!sslGlobalParams.sslCRLFile.empty() && sslGlobalParams.sslCAFile.empty()) { + return Status(ErrorCodes::BadValue, "need sslCAFile with sslCRLFile"); } - if (params.count("ssl.allowInvalidCertificates")) { - sslGlobalParams.sslAllowInvalidCertificates = true; + std::string sslCANotFoundError( + "No SSL certificate validation can be performed since" + " no CA file has been provided; please specify an" + " sslCAFile parameter"); + + if (sslGlobalParams.sslCAFile.empty()) { + if (clusterAuthMode == ServerGlobalParams::ClusterAuthMode_x509) { + return Status(ErrorCodes::BadValue, sslCANotFoundError); + } + warning() << sslCANotFoundError; } - if (params.count("ssl.FIPSMode")) { - sslGlobalParams.sslFIPSMode = true; + } else if (sslGlobalParams.sslPEMKeyFile.size() || sslGlobalParams.sslPEMKeyPassword.size() || + sslGlobalParams.sslClusterFile.size() || sslGlobalParams.sslClusterPassword.size() || + sslGlobalParams.sslCAFile.size() || sslGlobalParams.sslCRLFile.size() || + sslGlobalParams.sslCipherConfig.size() || + sslGlobalParams.sslDisabledProtocols.size() || + sslGlobalParams.sslWeakCertificateValidation || sslGlobalParams.sslFIPSMode) { + return Status(ErrorCodes::BadValue, + "need to enable SSL via the sslMode flag when " + "using SSL configuration parameters"); + } + if (clusterAuthMode == ServerGlobalParams::ClusterAuthMode_sendKeyFile || + clusterAuthMode == ServerGlobalParams::ClusterAuthMode_sendX509 || + clusterAuthMode == ServerGlobalParams::ClusterAuthMode_x509) { + if (sslGlobalParams.sslMode.load() == SSLParams::SSLMode_disabled) { + return Status(ErrorCodes::BadValue, "need to enable SSL via the sslMode flag"); } - return Status::OK(); } - - Status validateSSLMongoShellOptions(const moe::Environment& params) { - // Users must specify either a CAFile or allowInvalidCertificates if ssl=true. - if (params.count("ssl") && - params["ssl"].as<bool>() == true && - !params.count("ssl.CAFile") && - !params.count("ssl.allowInvalidCertificates")) { + if (sslGlobalParams.sslMode.load() == SSLParams::SSLMode_allowSSL) { + if (clusterAuthMode == ServerGlobalParams::ClusterAuthMode_sendX509 || + clusterAuthMode == ServerGlobalParams::ClusterAuthMode_x509) { return Status(ErrorCodes::BadValue, - "need to either provide sslCAFile or specify sslAllowInvalidCertificates"); + "cannot have x.509 cluster authentication in allowSSL mode"); } - return Status::OK(); } + return Status::OK(); +} + +Status storeSSLClientOptions(const moe::Environment& params) { + if (params.count("ssl") && params["ssl"].as<bool>() == true) { + sslGlobalParams.sslMode.store(SSLParams::SSLMode_requireSSL); + } + if (params.count("ssl.PEMKeyFile")) { + sslGlobalParams.sslPEMKeyFile = params["ssl.PEMKeyFile"].as<std::string>(); + } + if (params.count("ssl.PEMKeyPassword")) { + sslGlobalParams.sslPEMKeyPassword = params["ssl.PEMKeyPassword"].as<std::string>(); + } + if (params.count("ssl.CAFile")) { + sslGlobalParams.sslCAFile = params["ssl.CAFile"].as<std::string>(); + } + if (params.count("ssl.CRLFile")) { + sslGlobalParams.sslCRLFile = params["ssl.CRLFile"].as<std::string>(); + } + if (params.count("net.ssl.allowInvalidHostnames")) { + sslGlobalParams.sslAllowInvalidHostnames = + params["net.ssl.allowInvalidHostnames"].as<bool>(); + } + if (params.count("ssl.allowInvalidCertificates")) { + sslGlobalParams.sslAllowInvalidCertificates = true; + } + if (params.count("ssl.FIPSMode")) { + sslGlobalParams.sslFIPSMode = true; + } + return Status::OK(); +} + +Status validateSSLMongoShellOptions(const moe::Environment& params) { + // Users must specify either a CAFile or allowInvalidCertificates if ssl=true. + if (params.count("ssl") && params["ssl"].as<bool>() == true && !params.count("ssl.CAFile") && + !params.count("ssl.allowInvalidCertificates")) { + return Status(ErrorCodes::BadValue, + "need to either provide sslCAFile or specify sslAllowInvalidCertificates"); + } + return Status::OK(); +} -} // namespace mongo +} // namespace mongo diff --git a/src/mongo/util/net/ssl_options.h b/src/mongo/util/net/ssl_options.h index d348e004b4e..c5e39908bef 100644 --- a/src/mongo/util/net/ssl_options.h +++ b/src/mongo/util/net/ssl_options.h @@ -35,82 +35,78 @@ namespace mongo { - namespace optionenvironment { - class OptionSection; - class Environment; - } // namespace optionenvironment - - namespace moe = mongo::optionenvironment; - - struct SSLParams { - enum class Protocols { - TLS1_0, - TLS1_1, - TLS1_2 - }; - AtomicInt32 sslMode; // --sslMode - the SSL operation mode, see enum SSLModes - bool sslOnNormalPorts; // --sslOnNormalPorts (deprecated) - std::string sslPEMKeyFile; // --sslPEMKeyFile - std::string sslPEMKeyPassword; // --sslPEMKeyPassword - std::string sslClusterFile; // --sslInternalKeyFile - std::string sslClusterPassword; // --sslInternalKeyPassword - std::string sslCAFile; // --sslCAFile - std::string sslCRLFile; // --sslCRLFile - std::string sslCipherConfig; // --sslCipherConfig - std::vector<Protocols> sslDisabledProtocols; // --sslDisabledProtocols - bool sslWeakCertificateValidation; // --sslWeakCertificateValidation - bool sslFIPSMode; // --sslFIPSMode - bool sslAllowInvalidCertificates; // --sslAllowInvalidCertificates - bool sslAllowInvalidHostnames; // --sslAllowInvalidHostnames - - SSLParams() { - sslMode.store(SSLMode_disabled); - } - - enum SSLModes { - /** - * Make unencrypted outgoing connections and do not accept incoming SSL-connections - */ - SSLMode_disabled, - - /** - * Make unencrypted outgoing connections and accept both unencrypted and SSL-connections - */ - SSLMode_allowSSL, - - /** - * Make outgoing SSL-connections and accept both unecrypted and SSL-connections - */ - SSLMode_preferSSL, - - /** - * Make outgoing SSL-connections and only accept incoming SSL-connections - */ - SSLMode_requireSSL - }; +namespace optionenvironment { +class OptionSection; +class Environment; +} // namespace optionenvironment + +namespace moe = mongo::optionenvironment; + +struct SSLParams { + enum class Protocols { TLS1_0, TLS1_1, TLS1_2 }; + AtomicInt32 sslMode; // --sslMode - the SSL operation mode, see enum SSLModes + bool sslOnNormalPorts; // --sslOnNormalPorts (deprecated) + std::string sslPEMKeyFile; // --sslPEMKeyFile + std::string sslPEMKeyPassword; // --sslPEMKeyPassword + std::string sslClusterFile; // --sslInternalKeyFile + std::string sslClusterPassword; // --sslInternalKeyPassword + std::string sslCAFile; // --sslCAFile + std::string sslCRLFile; // --sslCRLFile + std::string sslCipherConfig; // --sslCipherConfig + std::vector<Protocols> sslDisabledProtocols; // --sslDisabledProtocols + bool sslWeakCertificateValidation; // --sslWeakCertificateValidation + bool sslFIPSMode; // --sslFIPSMode + bool sslAllowInvalidCertificates; // --sslAllowInvalidCertificates + bool sslAllowInvalidHostnames; // --sslAllowInvalidHostnames + + SSLParams() { + sslMode.store(SSLMode_disabled); + } + + enum SSLModes { + /** + * Make unencrypted outgoing connections and do not accept incoming SSL-connections + */ + SSLMode_disabled, + + /** + * Make unencrypted outgoing connections and accept both unencrypted and SSL-connections + */ + SSLMode_allowSSL, + + /** + * Make outgoing SSL-connections and accept both unecrypted and SSL-connections + */ + SSLMode_preferSSL, + + /** + * Make outgoing SSL-connections and only accept incoming SSL-connections + */ + SSLMode_requireSSL }; +}; - extern SSLParams sslGlobalParams; +extern SSLParams sslGlobalParams; - Status addSSLServerOptions(moe::OptionSection* options); +Status addSSLServerOptions(moe::OptionSection* options); - Status addSSLClientOptions(moe::OptionSection* options); +Status addSSLClientOptions(moe::OptionSection* options); - Status storeSSLServerOptions(const moe::Environment& params); +Status storeSSLServerOptions(const moe::Environment& params); - /** - * Canonicalize SSL options for the given environment that have different representations with - * the same logical meaning - */ - Status canonicalizeSSLServerOptions(moe::Environment* params); +/** + * Canonicalize SSL options for the given environment that have different representations with + * the same logical meaning + */ +Status canonicalizeSSLServerOptions(moe::Environment* params); - Status validateSSLServerOptions(const moe::Environment& params); +Status validateSSLServerOptions(const moe::Environment& params); - Status storeSSLClientOptions(const moe::Environment& params); +Status storeSSLClientOptions(const moe::Environment& params); - /** - * Used by the Mongo shell to validate that the SSL options passed are acceptable and - * do not conflict with one another. - */ - Status validateSSLMongoShellOptions(const moe::Environment& params); +/** + * Used by the Mongo shell to validate that the SSL options passed are acceptable and + * do not conflict with one another. + */ +Status validateSSLMongoShellOptions(const moe::Environment& params); } |