diff options
Diffstat (limited to 'src/mongo/util/net/sock_test.cpp')
-rw-r--r-- | src/mongo/util/net/sock_test.cpp | 508 |
1 files changed, 252 insertions, 256 deletions
diff --git a/src/mongo/util/net/sock_test.cpp b/src/mongo/util/net/sock_test.cpp index 0a823e15f23..26f0c2c821a 100644 --- a/src/mongo/util/net/sock_test.cpp +++ b/src/mongo/util/net/sock_test.cpp @@ -44,305 +44,301 @@ namespace { - using namespace mongo; - using std::shared_ptr; +using namespace mongo; +using std::shared_ptr; - typedef std::shared_ptr<Socket> SocketPtr; - typedef std::pair<SocketPtr, SocketPtr> SocketPair; +typedef std::shared_ptr<Socket> SocketPtr; +typedef std::pair<SocketPtr, SocketPtr> SocketPair; - // On UNIX, make a connected pair of PF_LOCAL (aka PF_UNIX) sockets via the native 'socketpair' - // call. The 'type' parameter should be one of SOCK_STREAM, SOCK_DGRAM, SOCK_SEQPACKET, etc. - // For Win32, we don't have a native socketpair function, so we hack up a connected PF_INET - // pair on a random port. - SocketPair socketPair(const int type, const int protocol = 0); +// On UNIX, make a connected pair of PF_LOCAL (aka PF_UNIX) sockets via the native 'socketpair' +// call. The 'type' parameter should be one of SOCK_STREAM, SOCK_DGRAM, SOCK_SEQPACKET, etc. +// For Win32, we don't have a native socketpair function, so we hack up a connected PF_INET +// pair on a random port. +SocketPair socketPair(const int type, const int protocol = 0); #if defined(_WIN32) - namespace detail { - void awaitAccept(SOCKET* acceptSock, SOCKET listenSock, Notification& notify) { - *acceptSock = INVALID_SOCKET; - const SOCKET result = ::accept(listenSock, NULL, 0); - if (result != INVALID_SOCKET) { - *acceptSock = result; - } - notify.notifyOne(); +namespace detail { +void awaitAccept(SOCKET* acceptSock, SOCKET listenSock, Notification& notify) { + *acceptSock = INVALID_SOCKET; + const SOCKET result = ::accept(listenSock, NULL, 0); + if (result != INVALID_SOCKET) { + *acceptSock = result; + } + notify.notifyOne(); +} + +void awaitConnect(SOCKET* connectSock, const struct addrinfo& where, Notification& notify) { + *connectSock = INVALID_SOCKET; + SOCKET newSock = ::socket(where.ai_family, where.ai_socktype, where.ai_protocol); + if (newSock != INVALID_SOCKET) { + int result = ::connect(newSock, where.ai_addr, where.ai_addrlen); + if (result == 0) { + *connectSock = newSock; } + } + notify.notifyOne(); +} +} // namespace detail - void awaitConnect(SOCKET* connectSock, const struct addrinfo& where, Notification& notify) { - *connectSock = INVALID_SOCKET; - SOCKET newSock = ::socket(where.ai_family, where.ai_socktype, where.ai_protocol); - if (newSock != INVALID_SOCKET) { - int result = ::connect(newSock, where.ai_addr, where.ai_addrlen); - if (result == 0) { - *connectSock = newSock; - } - } - notify.notifyOne(); - } - } // namespace detail +SocketPair socketPair(const int type, const int protocol) { + const int domain = PF_INET; - SocketPair socketPair(const int type, const int protocol) { + // Create a listen socket and a connect socket. + const SOCKET listenSock = ::socket(domain, type, protocol); + if (listenSock == INVALID_SOCKET) + return SocketPair(); - const int domain = PF_INET; + // Bind the listen socket on port zero, it will pick one for us, and start it listening + // for connections. + struct addrinfo hints, *res; + ::memset(&hints, 0, sizeof(hints)); + hints.ai_family = PF_INET; + hints.ai_socktype = type; + hints.ai_flags = AI_PASSIVE; - // Create a listen socket and a connect socket. - const SOCKET listenSock = ::socket(domain, type, protocol); - if (listenSock == INVALID_SOCKET) - return SocketPair(); + int result = ::getaddrinfo(NULL, "0", &hints, &res); + if (result != 0) { + closesocket(listenSock); + return SocketPair(); + } - // Bind the listen socket on port zero, it will pick one for us, and start it listening - // for connections. - struct addrinfo hints, *res; - ::memset(&hints, 0, sizeof(hints)); - hints.ai_family = PF_INET; - hints.ai_socktype = type; - hints.ai_flags = AI_PASSIVE; + result = ::bind(listenSock, res->ai_addr, res->ai_addrlen); + if (result != 0) { + closesocket(listenSock); + ::freeaddrinfo(res); + return SocketPair(); + } - int result = ::getaddrinfo(NULL, "0", &hints, &res); - if (result != 0) { - closesocket(listenSock); - return SocketPair(); - } + // Read out the port to which we bound. + sockaddr_in bindAddr; + ::socklen_t len = sizeof(bindAddr); + ::memset(&bindAddr, 0, sizeof(bindAddr)); + result = ::getsockname(listenSock, reinterpret_cast<struct sockaddr*>(&bindAddr), &len); + if (result != 0) { + closesocket(listenSock); + ::freeaddrinfo(res); + return SocketPair(); + } - result = ::bind(listenSock, res->ai_addr, res->ai_addrlen); - if (result != 0) { - closesocket(listenSock); - ::freeaddrinfo(res); - return SocketPair(); - } + result = ::listen(listenSock, 1); + if (result != 0) { + closesocket(listenSock); + ::freeaddrinfo(res); + return SocketPair(); + } - // Read out the port to which we bound. - sockaddr_in bindAddr; - ::socklen_t len = sizeof(bindAddr); - ::memset(&bindAddr, 0, sizeof(bindAddr)); - result = ::getsockname(listenSock, reinterpret_cast<struct sockaddr*>(&bindAddr), &len); - if (result != 0) { - closesocket(listenSock); - ::freeaddrinfo(res); - return SocketPair(); - } + struct addrinfo connectHints, *connectRes; + ::memset(&connectHints, 0, sizeof(connectHints)); + connectHints.ai_family = PF_INET; + connectHints.ai_socktype = type; + std::stringstream portStream; + portStream << ntohs(bindAddr.sin_port); + result = ::getaddrinfo(NULL, portStream.str().c_str(), &connectHints, &connectRes); + if (result != 0) { + closesocket(listenSock); + ::freeaddrinfo(res); + return SocketPair(); + } - result = ::listen(listenSock, 1); - if (result != 0) { - closesocket(listenSock); - ::freeaddrinfo(res); - return SocketPair(); - } + // I'd prefer to avoid trying to do this non-blocking on Windows. Just spin up some + // threads to do the connect and acccept. - struct addrinfo connectHints, *connectRes; - ::memset(&connectHints, 0, sizeof(connectHints)); - connectHints.ai_family = PF_INET; - connectHints.ai_socktype = type; - std::stringstream portStream; - portStream << ntohs(bindAddr.sin_port); - result = ::getaddrinfo(NULL, portStream.str().c_str(), &connectHints, &connectRes); - if (result != 0) { - closesocket(listenSock); - ::freeaddrinfo(res); - return SocketPair(); - } + Notification accepted; + SOCKET acceptSock = INVALID_SOCKET; + stdx::thread acceptor( + stdx::bind(&detail::awaitAccept, &acceptSock, listenSock, boost::ref(accepted))); - // I'd prefer to avoid trying to do this non-blocking on Windows. Just spin up some - // threads to do the connect and acccept. - - Notification accepted; - SOCKET acceptSock = INVALID_SOCKET; - stdx::thread acceptor( - stdx::bind(&detail::awaitAccept, &acceptSock, listenSock, boost::ref(accepted))); - - Notification connected; - SOCKET connectSock = INVALID_SOCKET; - stdx::thread connector( - stdx::bind(&detail::awaitConnect, &connectSock, *connectRes, boost::ref(connected))); - - connected.waitToBeNotified(); - if (connectSock == INVALID_SOCKET) { - closesocket(listenSock); - ::freeaddrinfo(res); - ::freeaddrinfo(connectRes); - closesocket(acceptSock); - closesocket(connectSock); - return SocketPair(); - } + Notification connected; + SOCKET connectSock = INVALID_SOCKET; + stdx::thread connector( + stdx::bind(&detail::awaitConnect, &connectSock, *connectRes, boost::ref(connected))); - accepted.waitToBeNotified(); - if (acceptSock == INVALID_SOCKET) { - closesocket(listenSock); - ::freeaddrinfo(res); - ::freeaddrinfo(connectRes); - closesocket(acceptSock); - closesocket(connectSock); - return SocketPair(); - } + connected.waitToBeNotified(); + if (connectSock == INVALID_SOCKET) { + closesocket(listenSock); + ::freeaddrinfo(res); + ::freeaddrinfo(connectRes); + closesocket(acceptSock); + closesocket(connectSock); + return SocketPair(); + } + accepted.waitToBeNotified(); + if (acceptSock == INVALID_SOCKET) { closesocket(listenSock); ::freeaddrinfo(res); ::freeaddrinfo(connectRes); + closesocket(acceptSock); + closesocket(connectSock); + return SocketPair(); + } - SocketPtr first(new Socket(static_cast<int>(acceptSock), SockAddr())); - SocketPtr second(new Socket(static_cast<int>(connectSock), SockAddr())); + closesocket(listenSock); + ::freeaddrinfo(res); + ::freeaddrinfo(connectRes); - return SocketPair(first, second); - } + SocketPtr first(new Socket(static_cast<int>(acceptSock), SockAddr())); + SocketPtr second(new Socket(static_cast<int>(connectSock), SockAddr())); + + return SocketPair(first, second); +} #else - // We can just use ::socketpair and wrap up the result in a Socket. - SocketPair socketPair(const int type, const int protocol) { - // PF_LOCAL is the POSIX name for Unix domain sockets, while PF_UNIX - // is the name that BSD used. We use the BSD name because it is more - // widely supported (e.g. Solaris 10). - const int domain = PF_UNIX; - - int socks[2]; - const int result = ::socketpair(domain, type, protocol, socks); - if (result == 0) { - return SocketPair( - SocketPtr(new Socket(socks[0], SockAddr())), - SocketPtr(new Socket(socks[1], SockAddr()))); - } - return SocketPair(); +// We can just use ::socketpair and wrap up the result in a Socket. +SocketPair socketPair(const int type, const int protocol) { + // PF_LOCAL is the POSIX name for Unix domain sockets, while PF_UNIX + // is the name that BSD used. We use the BSD name because it is more + // widely supported (e.g. Solaris 10). + const int domain = PF_UNIX; + + int socks[2]; + const int result = ::socketpair(domain, type, protocol, socks); + if (result == 0) { + return SocketPair(SocketPtr(new Socket(socks[0], SockAddr())), + SocketPtr(new Socket(socks[1], SockAddr()))); } + return SocketPair(); +} #endif - // This should match the name of the fail point declared in sock.cpp. - const char kSocketFailPointName[] = "throwSockExcep"; - - class SocketFailPointTest : public unittest::Test { - public: +// This should match the name of the fail point declared in sock.cpp. +const char kSocketFailPointName[] = "throwSockExcep"; + +class SocketFailPointTest : public unittest::Test { +public: + SocketFailPointTest() + : _failPoint(getGlobalFailPointRegistry()->getFailPoint(kSocketFailPointName)), + _sockets(socketPair(SOCK_STREAM)) { + ASSERT_TRUE(_failPoint != NULL); + ASSERT_TRUE(_sockets.first); + ASSERT_TRUE(_sockets.second); + } - SocketFailPointTest() - : _failPoint(getGlobalFailPointRegistry()->getFailPoint(kSocketFailPointName)) - , _sockets(socketPair(SOCK_STREAM)) { - ASSERT_TRUE(_failPoint != NULL); - ASSERT_TRUE(_sockets.first); - ASSERT_TRUE(_sockets.second); - } + ~SocketFailPointTest() {} - ~SocketFailPointTest() { - } + bool trySend() { + char byte = 'x'; + _sockets.first->send(&byte, sizeof(byte), "SocketFailPointTest::trySend"); + return true; + } - bool trySend() { - char byte = 'x'; - _sockets.first->send(&byte, sizeof(byte), "SocketFailPointTest::trySend"); - return true; - } + bool trySendVector() { + std::vector<std::pair<char*, int>> data; + char byte = 'x'; + data.push_back(std::make_pair(&byte, sizeof(byte))); + _sockets.first->send(data, "SocketFailPointTest::trySendVector"); + return true; + } - bool trySendVector() { - std::vector<std::pair<char*, int> > data; - char byte = 'x'; - data.push_back(std::make_pair(&byte, sizeof(byte))); - _sockets.first->send(data, "SocketFailPointTest::trySendVector"); - return true; - } + bool tryRecv() { + char byte; + _sockets.second->recv(&byte, sizeof(byte)); + return true; + } - bool tryRecv() { - char byte; - _sockets.second->recv(&byte, sizeof(byte)); - return true; - } + // You must queue at least one byte on the send socket before calling this function. + size_t countRecvable(size_t max) { + std::vector<char> buf(max); + // This isn't great, because we don't have a guarantee that multiple sends will be + // captured in one recv. However, sock doesn't let us pass flags into recv, so we + // can't make this non blocking, and therefore can't risk another call. + return _sockets.second->unsafe_recv(&buf[0], max); + } - // You must queue at least one byte on the send socket before calling this function. - size_t countRecvable(size_t max) { - std::vector<char> buf(max); - // This isn't great, because we don't have a guarantee that multiple sends will be - // captured in one recv. However, sock doesn't let us pass flags into recv, so we - // can't make this non blocking, and therefore can't risk another call. - return _sockets.second->unsafe_recv(&buf[0], max); - } + FailPoint* const _failPoint; + const SocketPair _sockets; +}; - FailPoint* const _failPoint; - const SocketPair _sockets; - }; +class ScopedFailPointEnabler { +public: + ScopedFailPointEnabler(FailPoint& fp) : _fp(fp) { + _fp.setMode(FailPoint::alwaysOn); + } - class ScopedFailPointEnabler { - public: - ScopedFailPointEnabler(FailPoint& fp) - : _fp(fp) { - _fp.setMode(FailPoint::alwaysOn); - } + ~ScopedFailPointEnabler() { + _fp.setMode(FailPoint::off); + } - ~ScopedFailPointEnabler() { - _fp.setMode(FailPoint::off); - } - private: - FailPoint& _fp; - }; +private: + FailPoint& _fp; +}; - TEST_F(SocketFailPointTest, TestSend) { - ASSERT_TRUE(trySend()); - ASSERT_TRUE(tryRecv()); - { - const ScopedFailPointEnabler enabled(*_failPoint); - ASSERT_THROWS(trySend(), SocketException); - } - // Channel should be working again - ASSERT_TRUE(trySend()); - ASSERT_TRUE(tryRecv()); +TEST_F(SocketFailPointTest, TestSend) { + ASSERT_TRUE(trySend()); + ASSERT_TRUE(tryRecv()); + { + const ScopedFailPointEnabler enabled(*_failPoint); + ASSERT_THROWS(trySend(), SocketException); } - - TEST_F(SocketFailPointTest, TestSendVector) { - ASSERT_TRUE(trySendVector()); - ASSERT_TRUE(tryRecv()); - { - const ScopedFailPointEnabler enabled(*_failPoint); - ASSERT_THROWS(trySendVector(), SocketException); - } - ASSERT_TRUE(trySendVector()); - ASSERT_TRUE(tryRecv()); + // Channel should be working again + ASSERT_TRUE(trySend()); + ASSERT_TRUE(tryRecv()); +} + +TEST_F(SocketFailPointTest, TestSendVector) { + ASSERT_TRUE(trySendVector()); + ASSERT_TRUE(tryRecv()); + { + const ScopedFailPointEnabler enabled(*_failPoint); + ASSERT_THROWS(trySendVector(), SocketException); } - - TEST_F(SocketFailPointTest, TestRecv) { - ASSERT_TRUE(trySend()); // data for recv - ASSERT_TRUE(tryRecv()); - { - ASSERT_TRUE(trySend()); // data for recv - const ScopedFailPointEnabler enabled(*_failPoint); - ASSERT_THROWS(tryRecv(), SocketException); - } - ASSERT_TRUE(trySend()); // data for recv - ASSERT_TRUE(tryRecv()); + ASSERT_TRUE(trySendVector()); + ASSERT_TRUE(tryRecv()); +} + +TEST_F(SocketFailPointTest, TestRecv) { + ASSERT_TRUE(trySend()); // data for recv + ASSERT_TRUE(tryRecv()); + { + ASSERT_TRUE(trySend()); // data for recv + const ScopedFailPointEnabler enabled(*_failPoint); + ASSERT_THROWS(tryRecv(), SocketException); } - - TEST_F(SocketFailPointTest, TestFailedSendsDontSend) { - ASSERT_TRUE(trySend()); - ASSERT_TRUE(tryRecv()); - { - ASSERT_TRUE(trySend()); // queue 1 byte - const ScopedFailPointEnabler enabled(*_failPoint); - // Fail to queue another byte - ASSERT_THROWS(trySend(), SocketException); - } - // Failed byte should not have been transmitted. - ASSERT_EQUALS(size_t(1), countRecvable(2)); + ASSERT_TRUE(trySend()); // data for recv + ASSERT_TRUE(tryRecv()); +} + +TEST_F(SocketFailPointTest, TestFailedSendsDontSend) { + ASSERT_TRUE(trySend()); + ASSERT_TRUE(tryRecv()); + { + ASSERT_TRUE(trySend()); // queue 1 byte + const ScopedFailPointEnabler enabled(*_failPoint); + // Fail to queue another byte + ASSERT_THROWS(trySend(), SocketException); } - - // Ensure that calling send doesn't actually enqueue data to the socket - TEST_F(SocketFailPointTest, TestFailedVectorSendsDontSend) { - ASSERT_TRUE(trySend()); - ASSERT_TRUE(tryRecv()); - { - ASSERT_TRUE(trySend()); // queue 1 byte - const ScopedFailPointEnabler enabled(*_failPoint); - // Fail to queue another byte - ASSERT_THROWS(trySendVector(), SocketException); - } - // Failed byte should not have been transmitted. - ASSERT_EQUALS(size_t(1), countRecvable(2)); + // Failed byte should not have been transmitted. + ASSERT_EQUALS(size_t(1), countRecvable(2)); +} + +// Ensure that calling send doesn't actually enqueue data to the socket +TEST_F(SocketFailPointTest, TestFailedVectorSendsDontSend) { + ASSERT_TRUE(trySend()); + ASSERT_TRUE(tryRecv()); + { + ASSERT_TRUE(trySend()); // queue 1 byte + const ScopedFailPointEnabler enabled(*_failPoint); + // Fail to queue another byte + ASSERT_THROWS(trySendVector(), SocketException); } - - TEST_F(SocketFailPointTest, TestFailedRecvsDontRecv) { - ASSERT_TRUE(trySend()); - ASSERT_TRUE(tryRecv()); - { - ASSERT_TRUE(trySend()); - const ScopedFailPointEnabler enabled(*_failPoint); - // Fail to recv that byte - ASSERT_THROWS(tryRecv(), SocketException); - } - // Failed byte should still be queued to recv. - ASSERT_EQUALS(size_t(1), countRecvable(1)); - // Channel should be working again + // Failed byte should not have been transmitted. + ASSERT_EQUALS(size_t(1), countRecvable(2)); +} + +TEST_F(SocketFailPointTest, TestFailedRecvsDontRecv) { + ASSERT_TRUE(trySend()); + ASSERT_TRUE(tryRecv()); + { ASSERT_TRUE(trySend()); - ASSERT_TRUE(tryRecv()); + const ScopedFailPointEnabler enabled(*_failPoint); + // Fail to recv that byte + ASSERT_THROWS(tryRecv(), SocketException); } + // Failed byte should still be queued to recv. + ASSERT_EQUALS(size_t(1), countRecvable(1)); + // Channel should be working again + ASSERT_TRUE(trySend()); + ASSERT_TRUE(tryRecv()); +} -} // namespace +} // namespace |