summaryrefslogtreecommitdiff
path: root/qpid/cpp/src/qpid/sys/ssl/SslSocket.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'qpid/cpp/src/qpid/sys/ssl/SslSocket.cpp')
-rw-r--r--qpid/cpp/src/qpid/sys/ssl/SslSocket.cpp159
1 files changed, 97 insertions, 62 deletions
diff --git a/qpid/cpp/src/qpid/sys/ssl/SslSocket.cpp b/qpid/cpp/src/qpid/sys/ssl/SslSocket.cpp
index f7483a220c..30234bb686 100644
--- a/qpid/cpp/src/qpid/sys/ssl/SslSocket.cpp
+++ b/qpid/cpp/src/qpid/sys/ssl/SslSocket.cpp
@@ -25,11 +25,13 @@
#include "qpid/Exception.h"
#include "qpid/sys/posix/check.h"
#include "qpid/sys/posix/PrivatePosix.h"
+#include "qpid/log/Statement.h"
#include <fcntl.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/errno.h>
+#include <poll.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <netdb.h>
@@ -50,36 +52,6 @@ namespace sys {
namespace ssl {
namespace {
-std::string getName(int fd, bool local, bool includeService = false)
-{
- ::sockaddr_storage name; // big enough for any socket address
- ::socklen_t namelen = sizeof(name);
-
- int result = -1;
- if (local) {
- result = ::getsockname(fd, (::sockaddr*)&name, &namelen);
- } else {
- result = ::getpeername(fd, (::sockaddr*)&name, &namelen);
- }
-
- QPID_POSIX_CHECK(result);
-
- char servName[NI_MAXSERV];
- char dispName[NI_MAXHOST];
- if (includeService) {
- if (int rc=::getnameinfo((::sockaddr*)&name, namelen, dispName, sizeof(dispName),
- servName, sizeof(servName),
- NI_NUMERICHOST | NI_NUMERICSERV) != 0)
- throw QPID_POSIX_ERROR(rc);
- return std::string(dispName) + ":" + std::string(servName);
-
- } else {
- if (int rc=::getnameinfo((::sockaddr*)&name, namelen, dispName, sizeof(dispName), 0, 0, NI_NUMERICHOST) != 0)
- throw QPID_POSIX_ERROR(rc);
- return dispName;
- }
-}
-
std::string getService(int fd, bool local)
{
::sockaddr_storage name; // big enough for any socket address
@@ -132,7 +104,7 @@ std::string getDomainFromSubject(std::string subject)
}
-SslSocket::SslSocket() : IOHandle(new IOHandlePrivate()), socket(0), prototype(0)
+SslSocket::SslSocket() : socket(0), prototype(0)
{
impl->fd = ::socket (PF_INET, SOCK_STREAM, 0);
if (impl->fd < 0) throw QPID_POSIX_ERROR(errno);
@@ -144,7 +116,7 @@ SslSocket::SslSocket() : IOHandle(new IOHandlePrivate()), socket(0), prototype(0
* returned from accept. Because we use posix accept rather than
* PR_Accept, we have to reset the handshake.
*/
-SslSocket::SslSocket(IOHandlePrivate* ioph, PRFileDesc* model) : IOHandle(ioph), socket(0), prototype(0)
+SslSocket::SslSocket(IOHandlePrivate* ioph, PRFileDesc* model) : Socket(ioph), socket(0), prototype(0)
{
socket = SSL_ImportFD(model, PR_ImportTCPSocket(impl->fd));
NSS_CHECK(SSL_ResetHandshake(socket, true));
@@ -238,6 +210,7 @@ int SslSocket::listen(uint16_t port, int backlog, const std::string& certName, b
SslSocket* SslSocket::accept() const
{
+ QPID_LOG(trace, "Accepting SSL connection.");
int afd = ::accept(impl->fd, 0, 0);
if ( afd >= 0) {
return new SslSocket(new IOHandlePrivate(afd), prototype);
@@ -248,36 +221,109 @@ SslSocket* SslSocket::accept() const
}
}
-int SslSocket::read(void *buf, size_t count) const
-{
- return PR_Read(socket, buf, count);
-}
+#define SSL_STREAM_MAX_WAIT_ms 20
+#define SSL_STREAM_MAX_RETRIES 2
-int SslSocket::write(const void *buf, size_t count) const
-{
- return PR_Write(socket, buf, count);
-}
+static bool isSslStream(int afd) {
+ int retries = SSL_STREAM_MAX_RETRIES;
+ unsigned char buf[5] = {};
-std::string SslSocket::getSockname() const
-{
- return getName(impl->fd, true);
+ do {
+ struct pollfd fd = {afd, POLLIN, 0};
+
+ /*
+ * Note that this is blocking the accept thread, so connections that
+ * send no data can limit the rate at which we can accept new
+ * connections.
+ */
+ if (::poll(&fd, 1, SSL_STREAM_MAX_WAIT_ms) > 0) {
+ errno = 0;
+ int result = recv(afd, buf, sizeof(buf), MSG_PEEK | MSG_DONTWAIT);
+ if (result == sizeof(buf)) {
+ break;
+ }
+ if (errno && errno != EAGAIN) {
+ int err = errno;
+ ::close(afd);
+ throw QPID_POSIX_ERROR(err);
+ }
+ }
+ } while (retries-- > 0);
+
+ if (retries < 0) {
+ return false;
+ }
+
+ /*
+ * SSLv2 Client Hello format
+ * http://www.mozilla.org/projects/security/pki/nss/ssl/draft02.html
+ *
+ * Bytes 0-1: RECORD-LENGTH
+ * Byte 2: MSG-CLIENT-HELLO (1)
+ * Byte 3: CLIENT-VERSION-MSB
+ * Byte 4: CLIENT-VERSION-LSB
+ *
+ * Allowed versions:
+ * 2.0 - SSLv2
+ * 3.0 - SSLv3
+ * 3.1 - TLS 1.0
+ * 3.2 - TLS 1.1
+ * 3.3 - TLS 1.2
+ *
+ * The version sent in the Client-Hello is the latest version supported by
+ * the client. NSS may send version 3.x in an SSLv2 header for
+ * maximum compatibility.
+ */
+ bool isSSL2Handshake = buf[2] == 1 && // MSG-CLIENT-HELLO
+ ((buf[3] == 3 && buf[4] <= 3) || // SSL 3.0 & TLS 1.0-1.2 (v3.1-3.3)
+ (buf[3] == 2 && buf[4] == 0)); // SSL 2
+
+ /*
+ * SSLv3/TLS Client Hello format
+ * RFC 2246
+ *
+ * Byte 0: ContentType (handshake - 22)
+ * Bytes 1-2: ProtocolVersion {major, minor}
+ *
+ * Allowed versions:
+ * 3.0 - SSLv3
+ * 3.1 - TLS 1.0
+ * 3.2 - TLS 1.1
+ * 3.3 - TLS 1.2
+ */
+ bool isSSL3Handshake = buf[0] == 22 && // handshake
+ (buf[1] == 3 && buf[2] <= 3); // SSL 3.0 & TLS 1.0-1.2 (v3.1-3.3)
+
+ return isSSL2Handshake || isSSL3Handshake;
}
-std::string SslSocket::getPeername() const
+Socket* SslMuxSocket::accept() const
{
- return getName(impl->fd, false);
+ int afd = ::accept(impl->fd, 0, 0);
+ if (afd >= 0) {
+ QPID_LOG(trace, "Accepting connection with optional SSL wrapper.");
+ if (isSslStream(afd)) {
+ QPID_LOG(trace, "Accepted SSL connection.");
+ return new SslSocket(new IOHandlePrivate(afd), prototype);
+ } else {
+ QPID_LOG(trace, "Accepted Plaintext connection.");
+ return new Socket(new IOHandlePrivate(afd));
+ }
+ } else if (errno == EAGAIN) {
+ return 0;
+ } else {
+ throw QPID_POSIX_ERROR(errno);
+ }
}
-std::string SslSocket::getPeerAddress() const
+int SslSocket::read(void *buf, size_t count) const
{
- if (!connectname.empty())
- return connectname;
- return getName(impl->fd, false, true);
+ return PR_Read(socket, buf, count);
}
-std::string SslSocket::getLocalAddress() const
+int SslSocket::write(const void *buf, size_t count) const
{
- return getName(impl->fd, true, true);
+ return PR_Write(socket, buf, count);
}
uint16_t SslSocket::getLocalPort() const
@@ -290,17 +336,6 @@ uint16_t SslSocket::getRemotePort() const
return atoi(getService(impl->fd, true).c_str());
}
-int SslSocket::getError() const
-{
- int result;
- socklen_t rSize = sizeof (result);
-
- if (::getsockopt(impl->fd, SOL_SOCKET, SO_ERROR, &result, &rSize) < 0)
- throw QPID_POSIX_ERROR(errno);
-
- return result;
-}
-
void SslSocket::setTcpNoDelay(bool nodelay) const
{
if (nodelay) {