/**
* Copyright (C) 2013 10gen Inc.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License, version 3,
* as published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see .
*
* As a special exception, the copyright holders give permission to link the
* code of portions of this program with the OpenSSL library under certain
* conditions as described in each individual source file and distribute
* linked combinations including the program with the OpenSSL library. You
* must comply with the GNU Affero General Public License in all respects
* for all of the code used other than as permitted herein. If you modify
* file(s) with this exception, you may extend this exception to your
* version of the file(s), but you are not obligated to do so. If you do not
* wish to do so, delete this exception statement from your version. If you
* delete this exception statement from all source files in the program,
* then also delete it in the license file.
*/
#include "mongo/platform/basic.h"
#include "mongo/util/net/sock.h"
#include
#ifndef _WIN32
#include
#include
#include
#endif
#include "mongo/db/server_options.h"
#include "mongo/unittest/unittest.h"
#include "mongo/util/concurrency/synchronization.h"
#include "mongo/util/fail_point_service.h"
namespace {
using namespace mongo;
typedef boost::shared_ptr SocketPtr;
typedef std::pair SocketPair;
// On UNIX, make a connected pair of PF_LOCAL (aka PF_UNIX) sockets via the native 'socketpair'
// call. The 'type' parameter should be one of SOCK_STREAM, SOCK_DGRAM, SOCK_SEQPACKET, etc.
// For Win32, we don't have a native socketpair function, so we hack up a connected PF_INET
// pair on a random port.
SocketPair socketPair(const int type, const int protocol = 0);
#if defined(_WIN32)
namespace detail {
void awaitAccept(SOCKET* acceptSock, SOCKET listenSock, Notification& notify) {
*acceptSock = INVALID_SOCKET;
const SOCKET result = ::accept(listenSock, NULL, 0);
if (result != INVALID_SOCKET) {
*acceptSock = result;
}
notify.notifyOne();
}
void awaitConnect(SOCKET* connectSock, const struct addrinfo& where, Notification& notify) {
*connectSock = INVALID_SOCKET;
SOCKET newSock = ::socket(where.ai_family, where.ai_socktype, where.ai_protocol);
if (newSock != INVALID_SOCKET) {
int result = ::connect(newSock, where.ai_addr, where.ai_addrlen);
if (result == 0) {
*connectSock = newSock;
}
}
notify.notifyOne();
}
} // namespace detail
SocketPair socketPair(const int type, const int protocol) {
const int domain = PF_INET;
// Create a listen socket and a connect socket.
const SOCKET listenSock = ::socket(domain, type, protocol);
if (listenSock == INVALID_SOCKET)
return SocketPair();
// Bind the listen socket on port zero, it will pick one for us, and start it listening
// for connections.
struct addrinfo hints, *res;
::memset(&hints, 0, sizeof(hints));
hints.ai_family = PF_INET;
hints.ai_socktype = type;
hints.ai_flags = AI_PASSIVE;
int result = ::getaddrinfo(NULL, "0", &hints, &res);
if (result != 0) {
closesocket(listenSock);
return SocketPair();
}
result = ::bind(listenSock, res->ai_addr, res->ai_addrlen);
if (result != 0) {
closesocket(listenSock);
::freeaddrinfo(res);
return SocketPair();
}
// Read out the port to which we bound.
sockaddr_in bindAddr;
::socklen_t len = sizeof(bindAddr);
::memset(&bindAddr, 0, sizeof(bindAddr));
result = ::getsockname(listenSock, reinterpret_cast(&bindAddr), &len);
if (result != 0) {
closesocket(listenSock);
::freeaddrinfo(res);
return SocketPair();
}
result = ::listen(listenSock, 1);
if (result != 0) {
closesocket(listenSock);
::freeaddrinfo(res);
return SocketPair();
}
struct addrinfo connectHints, *connectRes;
::memset(&connectHints, 0, sizeof(connectHints));
connectHints.ai_family = PF_INET;
connectHints.ai_socktype = type;
std::stringstream portStream;
portStream << ntohs(bindAddr.sin_port);
result = ::getaddrinfo(NULL, portStream.str().c_str(), &connectHints, &connectRes);
if (result != 0) {
closesocket(listenSock);
::freeaddrinfo(res);
return SocketPair();
}
// I'd prefer to avoid trying to do this non-blocking on Windows. Just spin up some
// threads to do the connect and acccept.
Notification accepted;
SOCKET acceptSock = INVALID_SOCKET;
boost::thread acceptor(
boost::bind(&detail::awaitAccept, &acceptSock, listenSock, boost::ref(accepted)));
Notification connected;
SOCKET connectSock = INVALID_SOCKET;
boost::thread connector(
boost::bind(&detail::awaitConnect, &connectSock, *connectRes, boost::ref(connected)));
connected.waitToBeNotified();
if (connectSock == INVALID_SOCKET) {
closesocket(listenSock);
::freeaddrinfo(res);
::freeaddrinfo(connectRes);
closesocket(acceptSock);
closesocket(connectSock);
return SocketPair();
}
accepted.waitToBeNotified();
if (acceptSock == INVALID_SOCKET) {
closesocket(listenSock);
::freeaddrinfo(res);
::freeaddrinfo(connectRes);
closesocket(acceptSock);
closesocket(connectSock);
return SocketPair();
}
closesocket(listenSock);
::freeaddrinfo(res);
::freeaddrinfo(connectRes);
SocketPtr first(new Socket(static_cast(acceptSock), SockAddr()));
SocketPtr second(new Socket(static_cast(connectSock), SockAddr()));
return SocketPair(first, second);
}
#else
// We can just use ::socketpair and wrap up the result in a Socket.
SocketPair socketPair(const int type, const int protocol) {
// PF_LOCAL is the POSIX name for Unix domain sockets, while PF_UNIX
// is the name that BSD used. We use the BSD name because it is more
// widely supported (e.g. Solaris 10).
const int domain = PF_UNIX;
int socks[2];
const int result = ::socketpair(domain, type, protocol, socks);
if (result == 0) {
return SocketPair(
SocketPtr(new Socket(socks[0], SockAddr())),
SocketPtr(new Socket(socks[1], SockAddr())));
}
return SocketPair();
}
#endif
// This should match the name of the fail point declared in sock.cpp.
const char kSocketFailPointName[] = "throwSockExcep";
class SocketFailPointTest : public unittest::Test {
public:
SocketFailPointTest()
: _failPoint(getGlobalFailPointRegistry()->getFailPoint(kSocketFailPointName))
, _sockets(socketPair(SOCK_STREAM)) {
ASSERT_TRUE(_failPoint != NULL);
ASSERT_TRUE(_sockets.first);
ASSERT_TRUE(_sockets.second);
}
~SocketFailPointTest() {
}
bool trySend() {
char byte = 'x';
_sockets.first->send(&byte, sizeof(byte), "SocketFailPointTest::trySend");
return true;
}
bool trySendVector() {
std::vector > data;
char byte = 'x';
data.push_back(std::make_pair(&byte, sizeof(byte)));
_sockets.first->send(data, "SocketFailPointTest::trySendVector");
return true;
}
bool tryRecv() {
char byte;
_sockets.second->recv(&byte, sizeof(byte));
return true;
}
// You must queue at least one byte on the send socket before calling this function.
size_t countRecvable(size_t max) {
std::vector buf(max);
// This isn't great, because we don't have a guarantee that multiple sends will be
// captured in one recv. However, sock doesn't let us pass flags into recv, so we
// can't make this non blocking, and therefore can't risk another call.
return _sockets.second->unsafe_recv(&buf[0], max);
}
FailPoint* const _failPoint;
const SocketPair _sockets;
};
class ScopedFailPointEnabler {
public:
ScopedFailPointEnabler(FailPoint& fp)
: _fp(fp) {
_fp.setMode(FailPoint::alwaysOn);
}
~ScopedFailPointEnabler() {
_fp.setMode(FailPoint::off);
}
private:
FailPoint& _fp;
};
TEST_F(SocketFailPointTest, TestSend) {
ASSERT_TRUE(trySend());
ASSERT_TRUE(tryRecv());
{
const ScopedFailPointEnabler enabled(*_failPoint);
ASSERT_THROWS(trySend(), SocketException);
}
// Channel should be working again
ASSERT_TRUE(trySend());
ASSERT_TRUE(tryRecv());
}
TEST_F(SocketFailPointTest, TestSendVector) {
ASSERT_TRUE(trySendVector());
ASSERT_TRUE(tryRecv());
{
const ScopedFailPointEnabler enabled(*_failPoint);
ASSERT_THROWS(trySendVector(), SocketException);
}
ASSERT_TRUE(trySendVector());
ASSERT_TRUE(tryRecv());
}
TEST_F(SocketFailPointTest, TestRecv) {
ASSERT_TRUE(trySend()); // data for recv
ASSERT_TRUE(tryRecv());
{
ASSERT_TRUE(trySend()); // data for recv
const ScopedFailPointEnabler enabled(*_failPoint);
ASSERT_THROWS(tryRecv(), SocketException);
}
ASSERT_TRUE(trySend()); // data for recv
ASSERT_TRUE(tryRecv());
}
TEST_F(SocketFailPointTest, TestFailedSendsDontSend) {
ASSERT_TRUE(trySend());
ASSERT_TRUE(tryRecv());
{
ASSERT_TRUE(trySend()); // queue 1 byte
const ScopedFailPointEnabler enabled(*_failPoint);
// Fail to queue another byte
ASSERT_THROWS(trySend(), SocketException);
}
// Failed byte should not have been transmitted.
ASSERT_EQUALS(size_t(1), countRecvable(2));
}
// Ensure that calling send doesn't actually enqueue data to the socket
TEST_F(SocketFailPointTest, TestFailedVectorSendsDontSend) {
ASSERT_TRUE(trySend());
ASSERT_TRUE(tryRecv());
{
ASSERT_TRUE(trySend()); // queue 1 byte
const ScopedFailPointEnabler enabled(*_failPoint);
// Fail to queue another byte
ASSERT_THROWS(trySendVector(), SocketException);
}
// Failed byte should not have been transmitted.
ASSERT_EQUALS(size_t(1), countRecvable(2));
}
TEST_F(SocketFailPointTest, TestFailedRecvsDontRecv) {
ASSERT_TRUE(trySend());
ASSERT_TRUE(tryRecv());
{
ASSERT_TRUE(trySend());
const ScopedFailPointEnabler enabled(*_failPoint);
// Fail to recv that byte
ASSERT_THROWS(tryRecv(), SocketException);
}
// Failed byte should still be queued to recv.
ASSERT_EQUALS(size_t(1), countRecvable(1));
// Channel should be working again
ASSERT_TRUE(trySend());
ASSERT_TRUE(tryRecv());
}
} // namespace