summaryrefslogtreecommitdiff
path: root/cpp/src/qpid/sys/ssl/SslSocket.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'cpp/src/qpid/sys/ssl/SslSocket.cpp')
-rw-r--r--cpp/src/qpid/sys/ssl/SslSocket.cpp182
1 files changed, 71 insertions, 111 deletions
diff --git a/cpp/src/qpid/sys/ssl/SslSocket.cpp b/cpp/src/qpid/sys/ssl/SslSocket.cpp
index 30234bb686..a328e49c13 100644
--- a/cpp/src/qpid/sys/ssl/SslSocket.cpp
+++ b/cpp/src/qpid/sys/ssl/SslSocket.cpp
@@ -20,6 +20,7 @@
*/
#include "qpid/sys/ssl/SslSocket.h"
+#include "qpid/sys/SocketAddress.h"
#include "qpid/sys/ssl/check.h"
#include "qpid/sys/ssl/util.h"
#include "qpid/Exception.h"
@@ -52,28 +53,6 @@ namespace sys {
namespace ssl {
namespace {
-std::string getService(int fd, bool local)
-{
- ::sockaddr_storage name; // big enough for any socket address
- ::socklen_t namelen = sizeof(name);
-
- int result = -1;
- if (local) {
- result = ::getsockname(fd, (::sockaddr*)&name, &namelen);
- } else {
- result = ::getpeername(fd, (::sockaddr*)&name, &namelen);
- }
-
- QPID_POSIX_CHECK(result);
-
- char servName[NI_MAXSERV];
- if (int rc=::getnameinfo((::sockaddr*)&name, namelen, 0, 0,
- servName, sizeof(servName),
- NI_NUMERICHOST | NI_NUMERICSERV) != 0)
- throw QPID_POSIX_ERROR(rc);
- return servName;
-}
-
const std::string DOMAIN_SEPARATOR("@");
const std::string DC_SEPARATOR(".");
const std::string DC("DC");
@@ -101,14 +80,18 @@ std::string getDomainFromSubject(std::string subject)
}
return domain;
}
-
}
-SslSocket::SslSocket() : socket(0), prototype(0)
+SslSocket::SslSocket(const std::string& certName, bool clientAuth) :
+ nssSocket(0), certname(certName), prototype(0)
{
- impl->fd = ::socket (PF_INET, SOCK_STREAM, 0);
- if (impl->fd < 0) throw QPID_POSIX_ERROR(errno);
- socket = SSL_ImportFD(0, PR_ImportTCPSocket(impl->fd));
+ //configure prototype socket:
+ prototype = SSL_ImportFD(0, PR_NewTCPSocket());
+
+ if (clientAuth) {
+ NSS_CHECK(SSL_OptionSet(prototype, SSL_REQUEST_CERTIFICATE, PR_TRUE));
+ NSS_CHECK(SSL_OptionSet(prototype, SSL_REQUIRE_CERTIFICATE, PR_TRUE));
+ }
}
/**
@@ -116,25 +99,44 @@ SslSocket::SslSocket() : socket(0), prototype(0)
* returned from accept. Because we use posix accept rather than
* PR_Accept, we have to reset the handshake.
*/
-SslSocket::SslSocket(IOHandlePrivate* ioph, PRFileDesc* model) : Socket(ioph), socket(0), prototype(0)
+SslSocket::SslSocket(int fd, PRFileDesc* model) : BSDSocket(fd), nssSocket(0), prototype(0)
{
- socket = SSL_ImportFD(model, PR_ImportTCPSocket(impl->fd));
- NSS_CHECK(SSL_ResetHandshake(socket, true));
+ nssSocket = SSL_ImportFD(model, PR_ImportTCPSocket(fd));
+ NSS_CHECK(SSL_ResetHandshake(nssSocket, PR_TRUE));
}
void SslSocket::setNonblocking() const
{
+ if (!nssSocket) {
+ BSDSocket::setNonblocking();
+ return;
+ }
PRSocketOptionData option;
option.option = PR_SockOpt_Nonblocking;
option.value.non_blocking = true;
- PR_SetSocketOption(socket, &option);
+ PR_SetSocketOption(nssSocket, &option);
}
-void SslSocket::connect(const std::string& host, const std::string& port) const
+void SslSocket::setTcpNoDelay() const
{
- std::stringstream namestream;
- namestream << host << ":" << port;
- connectname = namestream.str();
+ if (!nssSocket) {
+ BSDSocket::setTcpNoDelay();
+ return;
+ }
+ PRSocketOptionData option;
+ option.option = PR_SockOpt_NoDelay;
+ option.value.no_delay = true;
+ PR_SetSocketOption(nssSocket, &option);
+}
+
+void SslSocket::connect(const SocketAddress& addr) const
+{
+ BSDSocket::connect(addr);
+}
+
+void SslSocket::finishConnect(const SocketAddress& addr) const
+{
+ nssSocket = SSL_ImportFD(0, PR_ImportTCPSocket(fd));
void* arg;
// Use the connection's cert-name if it has one; else use global cert-name
@@ -145,75 +147,48 @@ void SslSocket::connect(const std::string& host, const std::string& port) const
} else {
arg = const_cast<char*>(SslOptions::global.certName.c_str());
}
- NSS_CHECK(SSL_GetClientAuthDataHook(socket, NSS_GetClientAuthData, arg));
- NSS_CHECK(SSL_SetURL(socket, host.data()));
-
- char hostBuffer[PR_NETDB_BUF_SIZE];
- PRHostEnt hostEntry;
- PR_CHECK(PR_GetHostByName(host.data(), hostBuffer, PR_NETDB_BUF_SIZE, &hostEntry));
- PRNetAddr address;
- int value = PR_EnumerateHostEnt(0, &hostEntry, boost::lexical_cast<PRUint16>(port), &address);
- if (value < 0) {
- throw Exception(QPID_MSG("Error getting address for host: " << ErrorString()));
- } else if (value == 0) {
- throw Exception(QPID_MSG("Could not resolve address for host."));
- }
- PR_CHECK(PR_Connect(socket, &address, PR_INTERVAL_NO_TIMEOUT));
- NSS_CHECK(SSL_ForceHandshake(socket));
+ NSS_CHECK(SSL_GetClientAuthDataHook(nssSocket, NSS_GetClientAuthData, arg));
+
+ url = addr.getHost();
+ NSS_CHECK(SSL_SetURL(nssSocket, url.data()));
+
+ NSS_CHECK(SSL_ResetHandshake(nssSocket, PR_FALSE));
+ NSS_CHECK(SSL_ForceHandshake(nssSocket));
}
void SslSocket::close() const
{
- if (impl->fd > 0) {
- PR_Close(socket);
- impl->fd = -1;
+ if (!nssSocket) {
+ BSDSocket::close();
+ return;
+ }
+ if (fd > 0) {
+ PR_Close(nssSocket);
+ fd = -1;
}
}
-int SslSocket::listen(uint16_t port, int backlog, const std::string& certName, bool clientAuth) const
+int SslSocket::listen(const SocketAddress& sa, int backlog) const
{
- //configure prototype socket:
- prototype = SSL_ImportFD(0, PR_NewTCPSocket());
- if (clientAuth) {
- NSS_CHECK(SSL_OptionSet(prototype, SSL_REQUEST_CERTIFICATE, PR_TRUE));
- NSS_CHECK(SSL_OptionSet(prototype, SSL_REQUIRE_CERTIFICATE, PR_TRUE));
- }
-
//get certificate and key (is this the correct way?)
- CERTCertificate *cert = PK11_FindCertFromNickname(const_cast<char*>(certName.c_str()), 0);
- if (!cert) throw Exception(QPID_MSG("Failed to load certificate '" << certName << "'"));
+ std::string cName( (certname == "") ? "localhost.localdomain" : certname);
+ CERTCertificate *cert = PK11_FindCertFromNickname(const_cast<char*>(cName.c_str()), 0);
+ if (!cert) throw Exception(QPID_MSG("Failed to load certificate '" << cName << "'"));
SECKEYPrivateKey *key = PK11_FindKeyByAnyCert(cert, 0);
if (!key) throw Exception(QPID_MSG("Failed to retrieve private key from certificate"));
NSS_CHECK(SSL_ConfigSecureServer(prototype, cert, key, NSS_FindCertKEAType(cert)));
SECKEY_DestroyPrivateKey(key);
CERT_DestroyCertificate(cert);
- //bind and listen
- const int& socket = impl->fd;
- int yes=1;
- QPID_POSIX_CHECK(setsockopt(socket,SOL_SOCKET,SO_REUSEADDR,&yes,sizeof(yes)));
- struct sockaddr_in name;
- name.sin_family = AF_INET;
- name.sin_port = htons(port);
- name.sin_addr.s_addr = 0;
- if (::bind(socket, (struct sockaddr*)&name, sizeof(name)) < 0)
- throw Exception(QPID_MSG("Can't bind to port " << port << ": " << strError(errno)));
- if (::listen(socket, backlog) < 0)
- throw Exception(QPID_MSG("Can't listen on port " << port << ": " << strError(errno)));
-
- socklen_t namelen = sizeof(name);
- if (::getsockname(socket, (struct sockaddr*)&name, &namelen) < 0)
- throw QPID_POSIX_ERROR(errno);
-
- return ntohs(name.sin_port);
+ return BSDSocket::listen(sa, backlog);
}
-SslSocket* SslSocket::accept() const
+Socket* SslSocket::accept() const
{
QPID_LOG(trace, "Accepting SSL connection.");
- int afd = ::accept(impl->fd, 0, 0);
+ int afd = ::accept(fd, 0, 0);
if ( afd >= 0) {
- return new SslSocket(new IOHandlePrivate(afd), prototype);
+ return new SslSocket(afd, prototype);
} else if (errno == EAGAIN) {
return 0;
} else {
@@ -297,17 +272,22 @@ static bool isSslStream(int afd) {
return isSSL2Handshake || isSSL3Handshake;
}
+SslMuxSocket::SslMuxSocket(const std::string& certName, bool clientAuth) :
+ SslSocket(certName, clientAuth)
+{
+}
+
Socket* SslMuxSocket::accept() const
{
- int afd = ::accept(impl->fd, 0, 0);
+ int afd = ::accept(fd, 0, 0);
if (afd >= 0) {
QPID_LOG(trace, "Accepting connection with optional SSL wrapper.");
if (isSslStream(afd)) {
QPID_LOG(trace, "Accepted SSL connection.");
- return new SslSocket(new IOHandlePrivate(afd), prototype);
+ return new SslSocket(afd, prototype);
} else {
QPID_LOG(trace, "Accepted Plaintext connection.");
- return new Socket(new IOHandlePrivate(afd));
+ return new BSDSocket(afd);
}
} else if (errno == EAGAIN) {
return 0;
@@ -318,32 +298,12 @@ Socket* SslMuxSocket::accept() const
int SslSocket::read(void *buf, size_t count) const
{
- return PR_Read(socket, buf, count);
+ return PR_Read(nssSocket, buf, count);
}
int SslSocket::write(const void *buf, size_t count) const
{
- return PR_Write(socket, buf, count);
-}
-
-uint16_t SslSocket::getLocalPort() const
-{
- return std::atoi(getService(impl->fd, true).c_str());
-}
-
-uint16_t SslSocket::getRemotePort() const
-{
- return atoi(getService(impl->fd, true).c_str());
-}
-
-void SslSocket::setTcpNoDelay(bool nodelay) const
-{
- if (nodelay) {
- PRSocketOptionData option;
- option.option = PR_SockOpt_NoDelay;
- option.value.no_delay = true;
- PR_SetSocketOption(socket, &option);
- }
+ return PR_Write(nssSocket, buf, count);
}
void SslSocket::setCertName(const std::string& name)
@@ -359,7 +319,7 @@ int SslSocket::getKeyLen() const
int keySize = 0;
SECStatus rc;
- rc = SSL_SecurityStatus( socket,
+ rc = SSL_SecurityStatus( nssSocket,
&enabled,
NULL,
NULL,
@@ -374,7 +334,7 @@ int SslSocket::getKeyLen() const
std::string SslSocket::getClientAuthId() const
{
std::string authId;
- CERTCertificate* cert = SSL_PeerCertificate(socket);
+ CERTCertificate* cert = SSL_PeerCertificate(nssSocket);
if (cert) {
authId = CERT_GetCommonName(&(cert->subject));
/*