summaryrefslogtreecommitdiff
path: root/src/mongo/util/net/sock_test.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/mongo/util/net/sock_test.cpp')
-rw-r--r--src/mongo/util/net/sock_test.cpp508
1 files changed, 252 insertions, 256 deletions
diff --git a/src/mongo/util/net/sock_test.cpp b/src/mongo/util/net/sock_test.cpp
index 0a823e15f23..26f0c2c821a 100644
--- a/src/mongo/util/net/sock_test.cpp
+++ b/src/mongo/util/net/sock_test.cpp
@@ -44,305 +44,301 @@
namespace {
- using namespace mongo;
- using std::shared_ptr;
+using namespace mongo;
+using std::shared_ptr;
- typedef std::shared_ptr<Socket> SocketPtr;
- typedef std::pair<SocketPtr, SocketPtr> SocketPair;
+typedef std::shared_ptr<Socket> SocketPtr;
+typedef std::pair<SocketPtr, SocketPtr> SocketPair;
- // On UNIX, make a connected pair of PF_LOCAL (aka PF_UNIX) sockets via the native 'socketpair'
- // call. The 'type' parameter should be one of SOCK_STREAM, SOCK_DGRAM, SOCK_SEQPACKET, etc.
- // For Win32, we don't have a native socketpair function, so we hack up a connected PF_INET
- // pair on a random port.
- SocketPair socketPair(const int type, const int protocol = 0);
+// On UNIX, make a connected pair of PF_LOCAL (aka PF_UNIX) sockets via the native 'socketpair'
+// call. The 'type' parameter should be one of SOCK_STREAM, SOCK_DGRAM, SOCK_SEQPACKET, etc.
+// For Win32, we don't have a native socketpair function, so we hack up a connected PF_INET
+// pair on a random port.
+SocketPair socketPair(const int type, const int protocol = 0);
#if defined(_WIN32)
- namespace detail {
- void awaitAccept(SOCKET* acceptSock, SOCKET listenSock, Notification& notify) {
- *acceptSock = INVALID_SOCKET;
- const SOCKET result = ::accept(listenSock, NULL, 0);
- if (result != INVALID_SOCKET) {
- *acceptSock = result;
- }
- notify.notifyOne();
+namespace detail {
+void awaitAccept(SOCKET* acceptSock, SOCKET listenSock, Notification& notify) {
+ *acceptSock = INVALID_SOCKET;
+ const SOCKET result = ::accept(listenSock, NULL, 0);
+ if (result != INVALID_SOCKET) {
+ *acceptSock = result;
+ }
+ notify.notifyOne();
+}
+
+void awaitConnect(SOCKET* connectSock, const struct addrinfo& where, Notification& notify) {
+ *connectSock = INVALID_SOCKET;
+ SOCKET newSock = ::socket(where.ai_family, where.ai_socktype, where.ai_protocol);
+ if (newSock != INVALID_SOCKET) {
+ int result = ::connect(newSock, where.ai_addr, where.ai_addrlen);
+ if (result == 0) {
+ *connectSock = newSock;
}
+ }
+ notify.notifyOne();
+}
+} // namespace detail
- void awaitConnect(SOCKET* connectSock, const struct addrinfo& where, Notification& notify) {
- *connectSock = INVALID_SOCKET;
- SOCKET newSock = ::socket(where.ai_family, where.ai_socktype, where.ai_protocol);
- if (newSock != INVALID_SOCKET) {
- int result = ::connect(newSock, where.ai_addr, where.ai_addrlen);
- if (result == 0) {
- *connectSock = newSock;
- }
- }
- notify.notifyOne();
- }
- } // namespace detail
+SocketPair socketPair(const int type, const int protocol) {
+ const int domain = PF_INET;
- SocketPair socketPair(const int type, const int protocol) {
+ // Create a listen socket and a connect socket.
+ const SOCKET listenSock = ::socket(domain, type, protocol);
+ if (listenSock == INVALID_SOCKET)
+ return SocketPair();
- const int domain = PF_INET;
+ // Bind the listen socket on port zero, it will pick one for us, and start it listening
+ // for connections.
+ struct addrinfo hints, *res;
+ ::memset(&hints, 0, sizeof(hints));
+ hints.ai_family = PF_INET;
+ hints.ai_socktype = type;
+ hints.ai_flags = AI_PASSIVE;
- // Create a listen socket and a connect socket.
- const SOCKET listenSock = ::socket(domain, type, protocol);
- if (listenSock == INVALID_SOCKET)
- return SocketPair();
+ int result = ::getaddrinfo(NULL, "0", &hints, &res);
+ if (result != 0) {
+ closesocket(listenSock);
+ return SocketPair();
+ }
- // Bind the listen socket on port zero, it will pick one for us, and start it listening
- // for connections.
- struct addrinfo hints, *res;
- ::memset(&hints, 0, sizeof(hints));
- hints.ai_family = PF_INET;
- hints.ai_socktype = type;
- hints.ai_flags = AI_PASSIVE;
+ result = ::bind(listenSock, res->ai_addr, res->ai_addrlen);
+ if (result != 0) {
+ closesocket(listenSock);
+ ::freeaddrinfo(res);
+ return SocketPair();
+ }
- int result = ::getaddrinfo(NULL, "0", &hints, &res);
- if (result != 0) {
- closesocket(listenSock);
- return SocketPair();
- }
+ // Read out the port to which we bound.
+ sockaddr_in bindAddr;
+ ::socklen_t len = sizeof(bindAddr);
+ ::memset(&bindAddr, 0, sizeof(bindAddr));
+ result = ::getsockname(listenSock, reinterpret_cast<struct sockaddr*>(&bindAddr), &len);
+ if (result != 0) {
+ closesocket(listenSock);
+ ::freeaddrinfo(res);
+ return SocketPair();
+ }
- result = ::bind(listenSock, res->ai_addr, res->ai_addrlen);
- if (result != 0) {
- closesocket(listenSock);
- ::freeaddrinfo(res);
- return SocketPair();
- }
+ result = ::listen(listenSock, 1);
+ if (result != 0) {
+ closesocket(listenSock);
+ ::freeaddrinfo(res);
+ return SocketPair();
+ }
- // Read out the port to which we bound.
- sockaddr_in bindAddr;
- ::socklen_t len = sizeof(bindAddr);
- ::memset(&bindAddr, 0, sizeof(bindAddr));
- result = ::getsockname(listenSock, reinterpret_cast<struct sockaddr*>(&bindAddr), &len);
- if (result != 0) {
- closesocket(listenSock);
- ::freeaddrinfo(res);
- return SocketPair();
- }
+ struct addrinfo connectHints, *connectRes;
+ ::memset(&connectHints, 0, sizeof(connectHints));
+ connectHints.ai_family = PF_INET;
+ connectHints.ai_socktype = type;
+ std::stringstream portStream;
+ portStream << ntohs(bindAddr.sin_port);
+ result = ::getaddrinfo(NULL, portStream.str().c_str(), &connectHints, &connectRes);
+ if (result != 0) {
+ closesocket(listenSock);
+ ::freeaddrinfo(res);
+ return SocketPair();
+ }
- result = ::listen(listenSock, 1);
- if (result != 0) {
- closesocket(listenSock);
- ::freeaddrinfo(res);
- return SocketPair();
- }
+ // I'd prefer to avoid trying to do this non-blocking on Windows. Just spin up some
+ // threads to do the connect and acccept.
- struct addrinfo connectHints, *connectRes;
- ::memset(&connectHints, 0, sizeof(connectHints));
- connectHints.ai_family = PF_INET;
- connectHints.ai_socktype = type;
- std::stringstream portStream;
- portStream << ntohs(bindAddr.sin_port);
- result = ::getaddrinfo(NULL, portStream.str().c_str(), &connectHints, &connectRes);
- if (result != 0) {
- closesocket(listenSock);
- ::freeaddrinfo(res);
- return SocketPair();
- }
+ Notification accepted;
+ SOCKET acceptSock = INVALID_SOCKET;
+ stdx::thread acceptor(
+ stdx::bind(&detail::awaitAccept, &acceptSock, listenSock, boost::ref(accepted)));
- // I'd prefer to avoid trying to do this non-blocking on Windows. Just spin up some
- // threads to do the connect and acccept.
-
- Notification accepted;
- SOCKET acceptSock = INVALID_SOCKET;
- stdx::thread acceptor(
- stdx::bind(&detail::awaitAccept, &acceptSock, listenSock, boost::ref(accepted)));
-
- Notification connected;
- SOCKET connectSock = INVALID_SOCKET;
- stdx::thread connector(
- stdx::bind(&detail::awaitConnect, &connectSock, *connectRes, boost::ref(connected)));
-
- connected.waitToBeNotified();
- if (connectSock == INVALID_SOCKET) {
- closesocket(listenSock);
- ::freeaddrinfo(res);
- ::freeaddrinfo(connectRes);
- closesocket(acceptSock);
- closesocket(connectSock);
- return SocketPair();
- }
+ Notification connected;
+ SOCKET connectSock = INVALID_SOCKET;
+ stdx::thread connector(
+ stdx::bind(&detail::awaitConnect, &connectSock, *connectRes, boost::ref(connected)));
- accepted.waitToBeNotified();
- if (acceptSock == INVALID_SOCKET) {
- closesocket(listenSock);
- ::freeaddrinfo(res);
- ::freeaddrinfo(connectRes);
- closesocket(acceptSock);
- closesocket(connectSock);
- return SocketPair();
- }
+ connected.waitToBeNotified();
+ if (connectSock == INVALID_SOCKET) {
+ closesocket(listenSock);
+ ::freeaddrinfo(res);
+ ::freeaddrinfo(connectRes);
+ closesocket(acceptSock);
+ closesocket(connectSock);
+ return SocketPair();
+ }
+ accepted.waitToBeNotified();
+ if (acceptSock == INVALID_SOCKET) {
closesocket(listenSock);
::freeaddrinfo(res);
::freeaddrinfo(connectRes);
+ closesocket(acceptSock);
+ closesocket(connectSock);
+ return SocketPair();
+ }
- SocketPtr first(new Socket(static_cast<int>(acceptSock), SockAddr()));
- SocketPtr second(new Socket(static_cast<int>(connectSock), SockAddr()));
+ closesocket(listenSock);
+ ::freeaddrinfo(res);
+ ::freeaddrinfo(connectRes);
- return SocketPair(first, second);
- }
+ SocketPtr first(new Socket(static_cast<int>(acceptSock), SockAddr()));
+ SocketPtr second(new Socket(static_cast<int>(connectSock), SockAddr()));
+
+ return SocketPair(first, second);
+}
#else
- // We can just use ::socketpair and wrap up the result in a Socket.
- SocketPair socketPair(const int type, const int protocol) {
- // PF_LOCAL is the POSIX name for Unix domain sockets, while PF_UNIX
- // is the name that BSD used. We use the BSD name because it is more
- // widely supported (e.g. Solaris 10).
- const int domain = PF_UNIX;
-
- int socks[2];
- const int result = ::socketpair(domain, type, protocol, socks);
- if (result == 0) {
- return SocketPair(
- SocketPtr(new Socket(socks[0], SockAddr())),
- SocketPtr(new Socket(socks[1], SockAddr())));
- }
- return SocketPair();
+// We can just use ::socketpair and wrap up the result in a Socket.
+SocketPair socketPair(const int type, const int protocol) {
+ // PF_LOCAL is the POSIX name for Unix domain sockets, while PF_UNIX
+ // is the name that BSD used. We use the BSD name because it is more
+ // widely supported (e.g. Solaris 10).
+ const int domain = PF_UNIX;
+
+ int socks[2];
+ const int result = ::socketpair(domain, type, protocol, socks);
+ if (result == 0) {
+ return SocketPair(SocketPtr(new Socket(socks[0], SockAddr())),
+ SocketPtr(new Socket(socks[1], SockAddr())));
}
+ return SocketPair();
+}
#endif
- // This should match the name of the fail point declared in sock.cpp.
- const char kSocketFailPointName[] = "throwSockExcep";
-
- class SocketFailPointTest : public unittest::Test {
- public:
+// This should match the name of the fail point declared in sock.cpp.
+const char kSocketFailPointName[] = "throwSockExcep";
+
+class SocketFailPointTest : public unittest::Test {
+public:
+ SocketFailPointTest()
+ : _failPoint(getGlobalFailPointRegistry()->getFailPoint(kSocketFailPointName)),
+ _sockets(socketPair(SOCK_STREAM)) {
+ ASSERT_TRUE(_failPoint != NULL);
+ ASSERT_TRUE(_sockets.first);
+ ASSERT_TRUE(_sockets.second);
+ }
- SocketFailPointTest()
- : _failPoint(getGlobalFailPointRegistry()->getFailPoint(kSocketFailPointName))
- , _sockets(socketPair(SOCK_STREAM)) {
- ASSERT_TRUE(_failPoint != NULL);
- ASSERT_TRUE(_sockets.first);
- ASSERT_TRUE(_sockets.second);
- }
+ ~SocketFailPointTest() {}
- ~SocketFailPointTest() {
- }
+ bool trySend() {
+ char byte = 'x';
+ _sockets.first->send(&byte, sizeof(byte), "SocketFailPointTest::trySend");
+ return true;
+ }
- bool trySend() {
- char byte = 'x';
- _sockets.first->send(&byte, sizeof(byte), "SocketFailPointTest::trySend");
- return true;
- }
+ bool trySendVector() {
+ std::vector<std::pair<char*, int>> data;
+ char byte = 'x';
+ data.push_back(std::make_pair(&byte, sizeof(byte)));
+ _sockets.first->send(data, "SocketFailPointTest::trySendVector");
+ return true;
+ }
- bool trySendVector() {
- std::vector<std::pair<char*, int> > data;
- char byte = 'x';
- data.push_back(std::make_pair(&byte, sizeof(byte)));
- _sockets.first->send(data, "SocketFailPointTest::trySendVector");
- return true;
- }
+ bool tryRecv() {
+ char byte;
+ _sockets.second->recv(&byte, sizeof(byte));
+ return true;
+ }
- bool tryRecv() {
- char byte;
- _sockets.second->recv(&byte, sizeof(byte));
- return true;
- }
+ // You must queue at least one byte on the send socket before calling this function.
+ size_t countRecvable(size_t max) {
+ std::vector<char> buf(max);
+ // This isn't great, because we don't have a guarantee that multiple sends will be
+ // captured in one recv. However, sock doesn't let us pass flags into recv, so we
+ // can't make this non blocking, and therefore can't risk another call.
+ return _sockets.second->unsafe_recv(&buf[0], max);
+ }
- // You must queue at least one byte on the send socket before calling this function.
- size_t countRecvable(size_t max) {
- std::vector<char> buf(max);
- // This isn't great, because we don't have a guarantee that multiple sends will be
- // captured in one recv. However, sock doesn't let us pass flags into recv, so we
- // can't make this non blocking, and therefore can't risk another call.
- return _sockets.second->unsafe_recv(&buf[0], max);
- }
+ FailPoint* const _failPoint;
+ const SocketPair _sockets;
+};
- FailPoint* const _failPoint;
- const SocketPair _sockets;
- };
+class ScopedFailPointEnabler {
+public:
+ ScopedFailPointEnabler(FailPoint& fp) : _fp(fp) {
+ _fp.setMode(FailPoint::alwaysOn);
+ }
- class ScopedFailPointEnabler {
- public:
- ScopedFailPointEnabler(FailPoint& fp)
- : _fp(fp) {
- _fp.setMode(FailPoint::alwaysOn);
- }
+ ~ScopedFailPointEnabler() {
+ _fp.setMode(FailPoint::off);
+ }
- ~ScopedFailPointEnabler() {
- _fp.setMode(FailPoint::off);
- }
- private:
- FailPoint& _fp;
- };
+private:
+ FailPoint& _fp;
+};
- TEST_F(SocketFailPointTest, TestSend) {
- ASSERT_TRUE(trySend());
- ASSERT_TRUE(tryRecv());
- {
- const ScopedFailPointEnabler enabled(*_failPoint);
- ASSERT_THROWS(trySend(), SocketException);
- }
- // Channel should be working again
- ASSERT_TRUE(trySend());
- ASSERT_TRUE(tryRecv());
+TEST_F(SocketFailPointTest, TestSend) {
+ ASSERT_TRUE(trySend());
+ ASSERT_TRUE(tryRecv());
+ {
+ const ScopedFailPointEnabler enabled(*_failPoint);
+ ASSERT_THROWS(trySend(), SocketException);
}
-
- TEST_F(SocketFailPointTest, TestSendVector) {
- ASSERT_TRUE(trySendVector());
- ASSERT_TRUE(tryRecv());
- {
- const ScopedFailPointEnabler enabled(*_failPoint);
- ASSERT_THROWS(trySendVector(), SocketException);
- }
- ASSERT_TRUE(trySendVector());
- ASSERT_TRUE(tryRecv());
+ // Channel should be working again
+ ASSERT_TRUE(trySend());
+ ASSERT_TRUE(tryRecv());
+}
+
+TEST_F(SocketFailPointTest, TestSendVector) {
+ ASSERT_TRUE(trySendVector());
+ ASSERT_TRUE(tryRecv());
+ {
+ const ScopedFailPointEnabler enabled(*_failPoint);
+ ASSERT_THROWS(trySendVector(), SocketException);
}
-
- TEST_F(SocketFailPointTest, TestRecv) {
- ASSERT_TRUE(trySend()); // data for recv
- ASSERT_TRUE(tryRecv());
- {
- ASSERT_TRUE(trySend()); // data for recv
- const ScopedFailPointEnabler enabled(*_failPoint);
- ASSERT_THROWS(tryRecv(), SocketException);
- }
- ASSERT_TRUE(trySend()); // data for recv
- ASSERT_TRUE(tryRecv());
+ ASSERT_TRUE(trySendVector());
+ ASSERT_TRUE(tryRecv());
+}
+
+TEST_F(SocketFailPointTest, TestRecv) {
+ ASSERT_TRUE(trySend()); // data for recv
+ ASSERT_TRUE(tryRecv());
+ {
+ ASSERT_TRUE(trySend()); // data for recv
+ const ScopedFailPointEnabler enabled(*_failPoint);
+ ASSERT_THROWS(tryRecv(), SocketException);
}
-
- TEST_F(SocketFailPointTest, TestFailedSendsDontSend) {
- ASSERT_TRUE(trySend());
- ASSERT_TRUE(tryRecv());
- {
- ASSERT_TRUE(trySend()); // queue 1 byte
- const ScopedFailPointEnabler enabled(*_failPoint);
- // Fail to queue another byte
- ASSERT_THROWS(trySend(), SocketException);
- }
- // Failed byte should not have been transmitted.
- ASSERT_EQUALS(size_t(1), countRecvable(2));
+ ASSERT_TRUE(trySend()); // data for recv
+ ASSERT_TRUE(tryRecv());
+}
+
+TEST_F(SocketFailPointTest, TestFailedSendsDontSend) {
+ ASSERT_TRUE(trySend());
+ ASSERT_TRUE(tryRecv());
+ {
+ ASSERT_TRUE(trySend()); // queue 1 byte
+ const ScopedFailPointEnabler enabled(*_failPoint);
+ // Fail to queue another byte
+ ASSERT_THROWS(trySend(), SocketException);
}
-
- // Ensure that calling send doesn't actually enqueue data to the socket
- TEST_F(SocketFailPointTest, TestFailedVectorSendsDontSend) {
- ASSERT_TRUE(trySend());
- ASSERT_TRUE(tryRecv());
- {
- ASSERT_TRUE(trySend()); // queue 1 byte
- const ScopedFailPointEnabler enabled(*_failPoint);
- // Fail to queue another byte
- ASSERT_THROWS(trySendVector(), SocketException);
- }
- // Failed byte should not have been transmitted.
- ASSERT_EQUALS(size_t(1), countRecvable(2));
+ // Failed byte should not have been transmitted.
+ ASSERT_EQUALS(size_t(1), countRecvable(2));
+}
+
+// Ensure that calling send doesn't actually enqueue data to the socket
+TEST_F(SocketFailPointTest, TestFailedVectorSendsDontSend) {
+ ASSERT_TRUE(trySend());
+ ASSERT_TRUE(tryRecv());
+ {
+ ASSERT_TRUE(trySend()); // queue 1 byte
+ const ScopedFailPointEnabler enabled(*_failPoint);
+ // Fail to queue another byte
+ ASSERT_THROWS(trySendVector(), SocketException);
}
-
- TEST_F(SocketFailPointTest, TestFailedRecvsDontRecv) {
- ASSERT_TRUE(trySend());
- ASSERT_TRUE(tryRecv());
- {
- ASSERT_TRUE(trySend());
- const ScopedFailPointEnabler enabled(*_failPoint);
- // Fail to recv that byte
- ASSERT_THROWS(tryRecv(), SocketException);
- }
- // Failed byte should still be queued to recv.
- ASSERT_EQUALS(size_t(1), countRecvable(1));
- // Channel should be working again
+ // Failed byte should not have been transmitted.
+ ASSERT_EQUALS(size_t(1), countRecvable(2));
+}
+
+TEST_F(SocketFailPointTest, TestFailedRecvsDontRecv) {
+ ASSERT_TRUE(trySend());
+ ASSERT_TRUE(tryRecv());
+ {
ASSERT_TRUE(trySend());
- ASSERT_TRUE(tryRecv());
+ const ScopedFailPointEnabler enabled(*_failPoint);
+ // Fail to recv that byte
+ ASSERT_THROWS(tryRecv(), SocketException);
}
+ // Failed byte should still be queued to recv.
+ ASSERT_EQUALS(size_t(1), countRecvable(1));
+ // Channel should be working again
+ ASSERT_TRUE(trySend());
+ ASSERT_TRUE(tryRecv());
+}
-} // namespace
+} // namespace