// socketft.cpp - written and placed in the public domain by Wei Dai #include "pch.h" #include "socketft.h" #ifdef SOCKETS_AVAILABLE #include "wait.h" #ifdef USE_BERKELEY_STYLE_SOCKETS #include #include #include #include #include #include #endif NAMESPACE_BEGIN(CryptoPP) #ifdef USE_WINDOWS_STYLE_SOCKETS const int SOCKET_EINVAL = WSAEINVAL; const int SOCKET_EWOULDBLOCK = WSAEWOULDBLOCK; typedef int socklen_t; #else const int SOCKET_EINVAL = EINVAL; const int SOCKET_EWOULDBLOCK = EWOULDBLOCK; #endif Socket::Err::Err(socket_t s, const std::string& operation, int error) : OS_Error(IO_ERROR, "Socket: " + operation + " operation failed with error " + IntToString(error), operation, error) , m_s(s) { } Socket::~Socket() { if (m_own) { try { CloseSocket(); } catch (...) { } } } void Socket::AttachSocket(socket_t s, bool own) { if (m_own) CloseSocket(); m_s = s; m_own = own; SocketChanged(); } socket_t Socket::DetachSocket() { socket_t s = m_s; m_s = INVALID_SOCKET; SocketChanged(); return s; } void Socket::Create(int nType) { assert(m_s == INVALID_SOCKET); m_s = socket(AF_INET, nType, 0); CheckAndHandleError("socket", m_s); m_own = true; SocketChanged(); } void Socket::CloseSocket() { if (m_s != INVALID_SOCKET) { #ifdef USE_WINDOWS_STYLE_SOCKETS CancelIo((HANDLE) m_s); CheckAndHandleError_int("closesocket", closesocket(m_s)); #else CheckAndHandleError_int("close", close(m_s)); #endif m_s = INVALID_SOCKET; SocketChanged(); } } void Socket::Bind(unsigned int port, const char *addr) { sockaddr_in sa; memset(&sa, 0, sizeof(sa)); sa.sin_family = AF_INET; if (addr == NULL) sa.sin_addr.s_addr = htonl(INADDR_ANY); else { unsigned long result = inet_addr(addr); if (result == -1) // Solaris doesn't have INADDR_NONE { SetLastError(SOCKET_EINVAL); CheckAndHandleError_int("inet_addr", SOCKET_ERROR); } sa.sin_addr.s_addr = result; } sa.sin_port = htons((u_short)port); Bind((sockaddr *)&sa, sizeof(sa)); } void Socket::Bind(const sockaddr *psa, socklen_t saLen) { assert(m_s != INVALID_SOCKET); // cygwin workaround: needs const_cast CheckAndHandleError_int("bind", bind(m_s, const_cast(psa), saLen)); } void Socket::Listen(int backlog) { assert(m_s != INVALID_SOCKET); CheckAndHandleError_int("listen", listen(m_s, backlog)); } bool Socket::Connect(const char *addr, unsigned int port) { assert(addr != NULL); sockaddr_in sa; memset(&sa, 0, sizeof(sa)); sa.sin_family = AF_INET; sa.sin_addr.s_addr = inet_addr(addr); if (sa.sin_addr.s_addr == -1) // Solaris doesn't have INADDR_NONE { hostent *lphost = gethostbyname(addr); if (lphost == NULL) { SetLastError(SOCKET_EINVAL); CheckAndHandleError_int("gethostbyname", SOCKET_ERROR); } sa.sin_addr.s_addr = ((in_addr *)lphost->h_addr)->s_addr; } sa.sin_port = htons((u_short)port); return Connect((const sockaddr *)&sa, sizeof(sa)); } bool Socket::Connect(const sockaddr* psa, socklen_t saLen) { assert(m_s != INVALID_SOCKET); int result = connect(m_s, const_cast(psa), saLen); if (result == SOCKET_ERROR && GetLastError() == SOCKET_EWOULDBLOCK) return false; CheckAndHandleError_int("connect", result); return true; } bool Socket::Accept(Socket& target, sockaddr *psa, socklen_t *psaLen) { assert(m_s != INVALID_SOCKET); socket_t s = accept(m_s, psa, psaLen); if (s == INVALID_SOCKET && GetLastError() == SOCKET_EWOULDBLOCK) return false; CheckAndHandleError("accept", s); target.AttachSocket(s, true); return true; } void Socket::GetSockName(sockaddr *psa, socklen_t *psaLen) { assert(m_s != INVALID_SOCKET); CheckAndHandleError_int("getsockname", getsockname(m_s, psa, psaLen)); } void Socket::GetPeerName(sockaddr *psa, socklen_t *psaLen) { assert(m_s != INVALID_SOCKET); CheckAndHandleError_int("getpeername", getpeername(m_s, psa, psaLen)); } unsigned int Socket::Send(const byte* buf, size_t bufLen, int flags) { assert(m_s != INVALID_SOCKET); int result = send(m_s, (const char *)buf, UnsignedMin(INT_MAX, bufLen), flags); CheckAndHandleError_int("send", result); return result; } unsigned int Socket::Receive(byte* buf, size_t bufLen, int flags) { assert(m_s != INVALID_SOCKET); int result = recv(m_s, (char *)buf, UnsignedMin(INT_MAX, bufLen), flags); CheckAndHandleError_int("recv", result); return result; } void Socket::ShutDown(int how) { assert(m_s != INVALID_SOCKET); int result = shutdown(m_s, how); CheckAndHandleError_int("shutdown", result); } void Socket::IOCtl(long cmd, unsigned long *argp) { assert(m_s != INVALID_SOCKET); #ifdef USE_WINDOWS_STYLE_SOCKETS CheckAndHandleError_int("ioctlsocket", ioctlsocket(m_s, cmd, argp)); #else CheckAndHandleError_int("ioctl", ioctl(m_s, cmd, argp)); #endif } bool Socket::SendReady(const timeval *timeout) { fd_set fds; FD_ZERO(&fds); FD_SET(m_s, &fds); int ready; if (timeout == NULL) ready = select((int)m_s+1, NULL, &fds, NULL, NULL); else { timeval timeoutCopy = *timeout; // select() modified timeout on Linux ready = select((int)m_s+1, NULL, &fds, NULL, &timeoutCopy); } CheckAndHandleError_int("select", ready); return ready > 0; } bool Socket::ReceiveReady(const timeval *timeout) { fd_set fds; FD_ZERO(&fds); FD_SET(m_s, &fds); int ready; if (timeout == NULL) ready = select((int)m_s+1, &fds, NULL, NULL, NULL); else { timeval timeoutCopy = *timeout; // select() modified timeout on Linux ready = select((int)m_s+1, &fds, NULL, NULL, &timeoutCopy); } CheckAndHandleError_int("select", ready); return ready > 0; } unsigned int Socket::PortNameToNumber(const char *name, const char *protocol) { int port = atoi(name); if (IntToString(port) == name) return port; servent *se = getservbyname(name, protocol); if (!se) throw Err(INVALID_SOCKET, "getservbyname", SOCKET_EINVAL); return ntohs(se->s_port); } void Socket::StartSockets() { #ifdef USE_WINDOWS_STYLE_SOCKETS WSADATA wsd; int result = WSAStartup(0x0202, &wsd); if (result != 0) throw Err(INVALID_SOCKET, "WSAStartup", result); #endif } void Socket::ShutdownSockets() { #ifdef USE_WINDOWS_STYLE_SOCKETS int result = WSACleanup(); if (result != 0) throw Err(INVALID_SOCKET, "WSACleanup", result); #endif } int Socket::GetLastError() { #ifdef USE_WINDOWS_STYLE_SOCKETS return WSAGetLastError(); #else return errno; #endif } void Socket::SetLastError(int errorCode) { #ifdef USE_WINDOWS_STYLE_SOCKETS WSASetLastError(errorCode); #else errno = errorCode; #endif } void Socket::HandleError(const char *operation) const { int err = GetLastError(); throw Err(m_s, operation, err); } #ifdef USE_WINDOWS_STYLE_SOCKETS SocketReceiver::SocketReceiver(Socket &s) : m_s(s), m_resultPending(false), m_eofReceived(false) { m_event.AttachHandle(CreateEvent(NULL, true, false, NULL), true); m_s.CheckAndHandleError("CreateEvent", m_event.HandleValid()); memset(&m_overlapped, 0, sizeof(m_overlapped)); m_overlapped.hEvent = m_event; } SocketReceiver::~SocketReceiver() { #ifdef USE_WINDOWS_STYLE_SOCKETS CancelIo((HANDLE) m_s.GetSocket()); #endif } bool SocketReceiver::Receive(byte* buf, size_t bufLen) { assert(!m_resultPending && !m_eofReceived); DWORD flags = 0; // don't queue too much at once, or we might use up non-paged memory WSABUF wsabuf = {UnsignedMin((u_long)128*1024, bufLen), (char *)buf}; if (WSARecv(m_s, &wsabuf, 1, &m_lastResult, &flags, &m_overlapped, NULL) == 0) { if (m_lastResult == 0) m_eofReceived = true; } else { switch (WSAGetLastError()) { default: m_s.CheckAndHandleError_int("WSARecv", SOCKET_ERROR); case WSAEDISCON: m_lastResult = 0; m_eofReceived = true; break; case WSA_IO_PENDING: m_resultPending = true; } } return !m_resultPending; } void SocketReceiver::GetWaitObjects(WaitObjectContainer &container, CallStack const& callStack) { if (m_resultPending) container.AddHandle(m_event, CallStack("SocketReceiver::GetWaitObjects() - result pending", &callStack)); else if (!m_eofReceived) container.SetNoWait(CallStack("SocketReceiver::GetWaitObjects() - result ready", &callStack)); } unsigned int SocketReceiver::GetReceiveResult() { if (m_resultPending) { DWORD flags = 0; if (WSAGetOverlappedResult(m_s, &m_overlapped, &m_lastResult, false, &flags)) { if (m_lastResult == 0) m_eofReceived = true; } else { switch (WSAGetLastError()) { default: m_s.CheckAndHandleError("WSAGetOverlappedResult", FALSE); case WSAEDISCON: m_lastResult = 0; m_eofReceived = true; } } m_resultPending = false; } return m_lastResult; } // ************************************************************* SocketSender::SocketSender(Socket &s) : m_s(s), m_resultPending(false), m_lastResult(0) { m_event.AttachHandle(CreateEvent(NULL, true, false, NULL), true); m_s.CheckAndHandleError("CreateEvent", m_event.HandleValid()); memset(&m_overlapped, 0, sizeof(m_overlapped)); m_overlapped.hEvent = m_event; } SocketSender::~SocketSender() { #ifdef USE_WINDOWS_STYLE_SOCKETS CancelIo((HANDLE) m_s.GetSocket()); #endif } void SocketSender::Send(const byte* buf, size_t bufLen) { assert(!m_resultPending); DWORD written = 0; // don't queue too much at once, or we might use up non-paged memory WSABUF wsabuf = {UnsignedMin((u_long)128*1024, bufLen), (char *)buf}; if (WSASend(m_s, &wsabuf, 1, &written, 0, &m_overlapped, NULL) == 0) { m_resultPending = false; m_lastResult = written; } else { if (WSAGetLastError() != WSA_IO_PENDING) m_s.CheckAndHandleError_int("WSASend", SOCKET_ERROR); m_resultPending = true; } } void SocketSender::SendEof() { assert(!m_resultPending); m_s.ShutDown(SD_SEND); m_s.CheckAndHandleError("ResetEvent", ResetEvent(m_event)); m_s.CheckAndHandleError_int("WSAEventSelect", WSAEventSelect(m_s, m_event, FD_CLOSE)); m_resultPending = true; } bool SocketSender::EofSent() { if (m_resultPending) { WSANETWORKEVENTS events; m_s.CheckAndHandleError_int("WSAEnumNetworkEvents", WSAEnumNetworkEvents(m_s, m_event, &events)); if ((events.lNetworkEvents & FD_CLOSE) != FD_CLOSE) throw Socket::Err(m_s, "WSAEnumNetworkEvents (FD_CLOSE not present)", E_FAIL); if (events.iErrorCode[FD_CLOSE_BIT] != 0) throw Socket::Err(m_s, "FD_CLOSE (via WSAEnumNetworkEvents)", events.iErrorCode[FD_CLOSE_BIT]); m_resultPending = false; } return m_lastResult != 0; } void SocketSender::GetWaitObjects(WaitObjectContainer &container, CallStack const& callStack) { if (m_resultPending) container.AddHandle(m_event, CallStack("SocketSender::GetWaitObjects() - result pending", &callStack)); else container.SetNoWait(CallStack("SocketSender::GetWaitObjects() - result ready", &callStack)); } unsigned int SocketSender::GetSendResult() { if (m_resultPending) { DWORD flags = 0; BOOL result = WSAGetOverlappedResult(m_s, &m_overlapped, &m_lastResult, false, &flags); m_s.CheckAndHandleError("WSAGetOverlappedResult", result); m_resultPending = false; } return m_lastResult; } #endif #ifdef USE_BERKELEY_STYLE_SOCKETS SocketReceiver::SocketReceiver(Socket &s) : m_s(s), m_lastResult(0), m_eofReceived(false) { } void SocketReceiver::GetWaitObjects(WaitObjectContainer &container, CallStack const& callStack) { if (!m_eofReceived) container.AddReadFd(m_s, CallStack("SocketReceiver::GetWaitObjects()", &callStack)); } bool SocketReceiver::Receive(byte* buf, size_t bufLen) { m_lastResult = m_s.Receive(buf, bufLen); if (bufLen > 0 && m_lastResult == 0) m_eofReceived = true; return true; } unsigned int SocketReceiver::GetReceiveResult() { return m_lastResult; } SocketSender::SocketSender(Socket &s) : m_s(s), m_lastResult(0) { } void SocketSender::Send(const byte* buf, size_t bufLen) { m_lastResult = m_s.Send(buf, bufLen); } void SocketSender::SendEof() { m_s.ShutDown(SD_SEND); } unsigned int SocketSender::GetSendResult() { return m_lastResult; } void SocketSender::GetWaitObjects(WaitObjectContainer &container, CallStack const& callStack) { container.AddWriteFd(m_s, CallStack("SocketSender::GetWaitObjects()", &callStack)); } #endif NAMESPACE_END #endif // #ifdef SOCKETS_AVAILABLE