summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
author Jason Carey <jcarey@argv.me> 2016-12-01 14:54:59 -0500
committer Jason Carey <jcarey@argv.me> 2016-12-08 11:18:17 -0500
commit 80775cd49db61ec792313da630158a2904faa75f (patch)
tree 43566a1716adaa33bdc7f9291b3619bacd1285f4 /src
parent cea6463452fb138b4536aed6660be082343dd9de (diff)
download mongo-80775cd49db61ec792313da630158a2904faa75f.tar.gz
SERVER-27240 Replace ConnectBG with poll
It's unsafe to close a socket from another thread. Also, once connect() returns EINTR, the connection attempt continues asynchronously, and on non-Linux systems that requires falling back to poll/select to detect errors. Because of that, let's do the connect without the background thread at all: put the socket in non-blocking mode and wait for the connection with poll from the start. (cherry picked from commit a1baabeee5694aa8c4ffa1827233684d6c7fcc49)
Diffstat (limited to 'src')
-rw-r--r-- src/mongo/util/net/sock.cpp 166
1 files changed, 110 insertions, 56 deletions
diff --git a/src/mongo/util/net/sock.cpp b/src/mongo/util/net/sock.cpp
index 43426b6ff03..e2bbc28da83 100644
--- a/src/mongo/util/net/sock.cpp
+++ b/src/mongo/util/net/sock.cpp
@@ -33,7 +33,10 @@
#include "mongo/util/net/sock.h"
+#include <algorithm>
+
#if !defined(_WIN32)
+#include <fcntl.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <sys/un.h>
@@ -70,6 +73,38 @@ using std::vector;
MONGO_FP_DECLARE(throwSockExcep);
+namespace {
+
+// Switches a file descriptor/socket between blocking and non-blocking mode on any platform.
+//
+// fd:    the descriptor (or Winsock socket handle) to reconfigure.
+// block: true selects blocking mode, false selects non-blocking mode.
+//
+// Returns true on success and false if the underlying system call failed.
+bool setBlock(int fd, bool block) {
+#ifdef _WIN32
+    // For FIONBIO a non-zero argument enables non-blocking mode and zero disables it.
+    u_long ioMode = block ? 0 : 1;
+    return (NO_ERROR == ::ioctlsocket(fd, FIONBIO, &ioMode));
+#else
+    // F_GETFL takes no third argument.
+    const int flags = fcntl(fd, F_GETFL);
+    if (flags == -1) {
+        // Reading the flags failed; do not hand the -1 error value to F_SETFL as if it were a
+        // valid set of status flags.
+        return false;
+    }
+
+    // Only the O_NONBLOCK bit is changed; every other status flag is written back untouched.
+    const int newFlags = block ? (flags & ~O_NONBLOCK) : (flags | O_NONBLOCK);
+    return (-1 != fcntl(fd, F_SETFL, newFlags));
+#endif
+}
+
+// Logs a warning that a step of connecting 'socket' to its remote address failed.
+//
+// socket:    the socket whose remote address and port are reported.
+// call:      name of the step that failed; it is embedded in the message.
+// errorCode: error code to describe. On Windows the default of -1 is replaced with
+//            WSAGetLastError(); elsewhere the value is handed to errnoWithDescription() as is.
+void networkWarnWithDescription(const Socket& socket, StringData call, int errorCode = -1) {
+#ifdef _WIN32
+    // No explicit code was supplied, so pick up the most recent Winsock error.
+    if (errorCode == -1) {
+        errorCode = WSAGetLastError();
+    }
+#endif
+    // Capture the description first, before any other call is made.
+    const std::string reason(errnoWithDescription(errorCode));
+    const auto& peer = socket.remoteAddr();
+    warning() << "Failed to connect to " << peer.getAddr() << ":" << peer.getPort() << ", in("
+              << call << "), reason: " << reason;
+}
+
+const double kMaxConnectTimeoutMS = 5000;  // Cap, in ms, on how long Socket::connect() polls for the connection. NOTE(review): it is compared directly with _timeout there - confirm the units match.
+
+} // namespace
+
static bool ipv6 = false;
void enableIPv6(bool state) {
ipv6 = state;
@@ -539,53 +574,91 @@ std::string Socket::doSSLHandshake(const char* firstBytes, int len) {
}
#endif
-class ConnectBG : public BackgroundJob {
-public:
- ConnectBG(int sock, SockAddr remote) : _sock(sock), _remote(remote) {}
+bool Socket::connect(SockAddr& remote) {
+ _remote = remote;
- void run() {
-#if defined(_WIN32)
- if ((_res = _connect()) == SOCKET_ERROR) {
- _errnoWithDescription = errnoWithDescription();
+ _fd = ::socket(remote.getType(), SOCK_STREAM, 0);
+ if (_fd == INVALID_SOCKET) {
+ networkWarnWithDescription(*this, "socket");
+ return false;
+ }
+
+ if (!setBlock(_fd, false)) {
+ networkWarnWithDescription(*this, "set socket to non-blocking mode");
+ return false;
+ }
+
+ const uint64_t connectTimeoutMillis =
+ _timeout > 0 ? std::min(kMaxConnectTimeoutMS, _timeout) : kMaxConnectTimeoutMS;
+ const uint64_t expiration = curTimeMillis64() + connectTimeoutMillis;
+
+ bool connectSucceeded = ::connect(_fd, _remote.raw(), _remote.addressSize) == 0;
+
+ if (!connectSucceeded) {
+#ifdef _WIN32
+ if (WSAGetLastError() != WSAEWOULDBLOCK) {
+ networkWarnWithDescription(*this, "connect");
+ return false;
}
#else
- while ((_res = _connect()) == -1) {
- const int error = errno;
- if (error != EINTR) {
- _errnoWithDescription = errnoWithDescription(error);
- break;
- }
+ if (errno != EINTR && errno != EINPROGRESS) {
+ networkWarnWithDescription(*this, "connect");
+ return false;
}
#endif
- }
- std::string name() const {
- return "ConnectBG";
- }
- std::string getErrnoWithDescription() const {
- return _errnoWithDescription;
- }
- int inError() const {
- return _res;
- }
+ pollfd pfd;
+ pfd.fd = _fd;
+ pfd.events = POLLOUT;
-private:
- int _connect() const {
- return ::connect(_sock, _remote.raw(), _remote.addressSize);
- }
+ while (true) {
+ const uint64_t timeout = std::max(0ull, expiration - curTimeMillis64());
- int _sock;
- int _res;
- SockAddr _remote;
- std::string _errnoWithDescription;
-};
+ int pollReturn = socketPoll(&pfd, 1, timeout);
+#ifdef _WIN32
+ if (pollReturn == SOCKET_ERROR) {
+ networkWarnWithDescription(*this, "poll");
+ return false;
+ }
+#else
+ if (pollReturn == -1) {
+ if (errno != EINTR) {
+ networkWarnWithDescription(*this, "poll");
+ return false;
+ }
-bool Socket::connect(SockAddr& remote) {
- _remote = remote;
+ // EINTR in poll, try again
+ continue;
+ }
+#endif
+ // No activity for the full duration of the timeout.
+ if (pollReturn == 0) {
+ warning() << "Failed to connect to " << _remote.getAddr() << ":"
+ << _remote.getPort() << " after " << connectTimeoutMillis
+ << " milliseconds, giving up.";
+ return false;
+ }
- _fd = socket(remote.getType(), SOCK_STREAM, 0);
- if (_fd == INVALID_SOCKET) {
- LOG(_logLevel) << "ERROR: connect invalid socket " << errnoWithDescription() << endl;
+ // We had a result, see if there's an error on the socket.
+ int optVal;
+ socklen_t optLen = sizeof(optVal);
+ if (::getsockopt(
+ _fd, SOL_SOCKET, SO_ERROR, reinterpret_cast<char*>(&optVal), &optLen) == -1) {
+ networkWarnWithDescription(*this, "getsockopt");
+ return false;
+ }
+ if (optVal != 0) {
+ networkWarnWithDescription(*this, "checking socket for error after poll", optVal);
+ return false;
+ }
+
+ // We had activity and we don't have errors on the socket, we're connected.
+ break;
+ }
+ }
+
+ if (!setBlock(_fd, true)) {
+ networkWarnWithDescription(*this, "could not set socket to blocking mode");
return false;
}
@@ -593,25 +666,6 @@ bool Socket::connect(SockAddr& remote) {
setTimeout(_timeout);
}
- static const unsigned int connectTimeoutMillis = 5000;
- ConnectBG bg(_fd, remote);
- bg.go();
- if (bg.wait(connectTimeoutMillis)) {
- if (bg.inError()) {
- warning() << "Failed to connect to " << _remote.getAddr() << ":" << _remote.getPort()
- << ", reason: " << bg.getErrnoWithDescription() << endl;
- close();
- return false;
- }
- } else {
- // time out the connect
- close();
- bg.wait(); // so bg stays in scope until bg thread terminates
- warning() << "Failed to connect to " << _remote.getAddr() << ":" << _remote.getPort()
- << " after " << connectTimeoutMillis << " milliseconds, giving up." << endl;
- return false;
- }
-
if (remote.getType() != AF_UNIX)
disableNagle(_fd);