diff options
Diffstat (limited to 'qpid/cpp/src/qpid/sys/ssl/SslSocket.cpp')
-rw-r--r-- | qpid/cpp/src/qpid/sys/ssl/SslSocket.cpp | 47 |
1 files changed, 30 insertions, 17 deletions
diff --git a/qpid/cpp/src/qpid/sys/ssl/SslSocket.cpp b/qpid/cpp/src/qpid/sys/ssl/SslSocket.cpp index 8ebc5937d2..01e2658877 100644 --- a/qpid/cpp/src/qpid/sys/ssl/SslSocket.cpp +++ b/qpid/cpp/src/qpid/sys/ssl/SslSocket.cpp @@ -7,9 +7,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -52,9 +52,9 @@ namespace ssl { namespace { std::string getName(int fd, bool local, bool includeService = false) { - ::sockaddr_storage name; // big enough for any socket address + ::sockaddr_storage name; // big enough for any socket address ::socklen_t namelen = sizeof(name); - + int result = -1; if (local) { result = ::getsockname(fd, (::sockaddr*)&name, &namelen); @@ -67,8 +67,8 @@ std::string getName(int fd, bool local, bool includeService = false) char servName[NI_MAXSERV]; char dispName[NI_MAXHOST]; if (includeService) { - if (int rc=::getnameinfo((::sockaddr*)&name, namelen, dispName, sizeof(dispName), - servName, sizeof(servName), + if (int rc=::getnameinfo((::sockaddr*)&name, namelen, dispName, sizeof(dispName), + servName, sizeof(servName), NI_NUMERICHOST | NI_NUMERICSERV) != 0) throw QPID_POSIX_ERROR(rc); return std::string(dispName) + ":" + std::string(servName); @@ -82,9 +82,9 @@ std::string getName(int fd, bool local, bool includeService = false) std::string getService(int fd, bool local) { - ::sockaddr_storage name; // big enough for any socket address + ::sockaddr_storage name; // big enough for any socket address ::socklen_t namelen = sizeof(name); - + int result = -1; if (local) { result = ::getsockname(fd, (::sockaddr*)&name, &namelen); @@ -95,8 +95,8 @@ std::string getService(int fd, bool local) QPID_POSIX_CHECK(result); char servName[NI_MAXSERV]; - if (int rc=::getnameinfo((::sockaddr*)&name, namelen, 0, 0, - servName, sizeof(servName), + if (int rc=::getnameinfo((::sockaddr*)&name, namelen, 0, 0, + servName, sizeof(servName), NI_NUMERICHOST | NI_NUMERICSERV) != 0) throw QPID_POSIX_ERROR(rc); return servName; @@ -132,8 +132,8 @@ std::string getDomainFromSubject(std::string subject) } -SslSocket::SslSocket() : IOHandle(new IOHandlePrivate()), socket(0), prototype(0) -{ +SslSocket::SslSocket() : IOHandle(new IOHandlePrivate()), socket(0), prototype(0) +{ impl->fd = ::socket (PF_INET, SOCK_STREAM, 0); if (impl->fd < 0) throw QPID_POSIX_ERROR(errno); socket = SSL_ImportFD(0, PR_ImportTCPSocket(impl->fd)); @@ -145,12 +145,12 @@ SslSocket::SslSocket() : IOHandle(new IOHandlePrivate()), socket(0), prototype(0 * PR_Accept, we have to reset the handshake. */ SslSocket::SslSocket(IOHandlePrivate* ioph, PRFileDesc* model) : IOHandle(ioph), socket(0), prototype(0) -{ +{ socket = SSL_ImportFD(model, PR_ImportTCPSocket(impl->fd)); NSS_CHECK(SSL_ResetHandshake(socket, true)); } -void SslSocket::setNonblocking() const +void SslSocket::setNonblocking() const { PRSocketOptionData option; option.option = PR_SockOpt_Nonblocking; @@ -164,7 +164,15 @@ void SslSocket::connect(const std::string& host, uint16_t port) const namestream << host << ":" << port; connectname = namestream.str(); - void* arg = SslOptions::global.certName.empty() ? 0 : const_cast<char*>(SslOptions::global.certName.c_str()); + void* arg; + // Use the connection's cert-name if it has one; else use global cert-name + if (certname != "") { + arg = const_cast<char*>(certname.c_str()); + } else if (SslOptions::global.certName.empty()) { + arg = 0; + } else { + arg = const_cast<char*>(SslOptions::global.certName.c_str()); + } NSS_CHECK(SSL_GetClientAuthDataHook(socket, NSS_GetClientAuthData, arg)); NSS_CHECK(SSL_SetURL(socket, host.data())); @@ -220,7 +228,7 @@ int SslSocket::listen(uint16_t port, int backlog, const std::string& certName, b throw Exception(QPID_MSG("Can't bind to port " << port << ": " << strError(errno))); if (::listen(socket, backlog) < 0) throw Exception(QPID_MSG("Can't listen on port " << port << ": " << strError(errno))); - + socklen_t namelen = sizeof(name); if (::getsockname(socket, (struct sockaddr*)&name, &namelen) < 0) throw QPID_POSIX_ERROR(errno); @@ -235,7 +243,7 @@ SslSocket* SslSocket::accept() const return new SslSocket(new IOHandlePrivate(afd), prototype); } else if (errno == EAGAIN) { return 0; - } else { + } else { throw QPID_POSIX_ERROR(errno); } } @@ -303,6 +311,11 @@ void SslSocket::setTcpNoDelay(bool nodelay) const } } +void SslSocket::setCertName(const std::string& name) +{ + certname = name; +} + /** get the bit length of the current cipher's key */ int SslSocket::getKeyLen() const |