diff options
author | Andrew Stitcher <astitcher@apache.org> | 2012-10-24 05:51:19 +0000 |
---|---|---|
committer | Andrew Stitcher <astitcher@apache.org> | 2012-10-24 05:51:19 +0000 |
commit | c61f7c0c79717b0eb842d0c4c88deeda9f7672e6 (patch) | |
tree | baabe1be2b4389b3412154598fe0ea543ddd63f5 /cpp/src | |
parent | f47c28c9ee2fb8f2967a221a28912060edcba749 (diff) | |
download | qpid-python-c61f7c0c79717b0eb842d0c4c88deeda9f7672e6.tar.gz |
QPID-4272: Large amounts of code are duplicated between the SSL and TCP transports
Refactor to unify the various SSL and TCP interfaces:
- Make ssl listen have the same signature as regular sockets
- Give ssl connect same interface as tcp
- Unify the SSL and TCP IO interfaces some more
git-svn-id: https://svn.apache.org/repos/asf/qpid/trunk/qpid@1401558 13f79535-47bb-0310-9956-ffa450edef68
Diffstat (limited to 'cpp/src')
-rw-r--r-- | cpp/src/qpid/client/SslConnector.cpp | 22 | ||||
-rw-r--r-- | cpp/src/qpid/sys/Socket.h | 17 | ||||
-rw-r--r-- | cpp/src/qpid/sys/SocketAddress.h | 3 | ||||
-rw-r--r-- | cpp/src/qpid/sys/SslPlugin.cpp | 35 | ||||
-rw-r--r-- | cpp/src/qpid/sys/posix/Socket.cpp | 12 | ||||
-rw-r--r-- | cpp/src/qpid/sys/posix/SocketAddress.cpp | 5 | ||||
-rw-r--r-- | cpp/src/qpid/sys/ssl/SslHandler.cpp | 16 | ||||
-rw-r--r-- | cpp/src/qpid/sys/ssl/SslHandler.h | 18 | ||||
-rw-r--r-- | cpp/src/qpid/sys/ssl/SslIo.cpp | 23 | ||||
-rw-r--r-- | cpp/src/qpid/sys/ssl/SslIo.h | 54 | ||||
-rw-r--r-- | cpp/src/qpid/sys/ssl/SslSocket.cpp | 123 | ||||
-rw-r--r-- | cpp/src/qpid/sys/ssl/SslSocket.h | 16 | ||||
-rw-r--r-- | cpp/src/qpid/sys/windows/AsynchIO.cpp | 3 | ||||
-rw-r--r-- | cpp/src/qpid/sys/windows/Socket.cpp | 12 |
14 files changed, 153 insertions, 206 deletions
diff --git a/cpp/src/qpid/client/SslConnector.cpp b/cpp/src/qpid/client/SslConnector.cpp index 0b07d14f35..c49deaa279 100644 --- a/cpp/src/qpid/client/SslConnector.cpp +++ b/cpp/src/qpid/client/SslConnector.cpp @@ -79,11 +79,11 @@ class SslConnector : public Connector ~SslConnector(); - void readbuff(qpid::sys::ssl::SslIO&, qpid::sys::ssl::SslIOBufferBase*); - void writebuff(qpid::sys::ssl::SslIO&); + void readbuff(AsynchIO&, AsynchIOBufferBase*); + void writebuff(AsynchIO&); void writeDataBlock(const framing::AMQDataBlock& data); - void eof(qpid::sys::ssl::SslIO&); - void disconnected(qpid::sys::ssl::SslIO&); + void eof(AsynchIO&); + void disconnected(AsynchIO&); void connect(const std::string& host, const std::string& port); void close(); @@ -96,7 +96,7 @@ class SslConnector : public Connector framing::OutputHandler* getOutputHandler(); const std::string& getIdentifier() const; const SecuritySettings* getSecuritySettings(); - void socketClosed(qpid::sys::ssl::SslIO&, const qpid::sys::ssl::SslSocket&); + void socketClosed(AsynchIO&, const Socket&); size_t decode(const char* buffer, size_t size); size_t encode(char* buffer, size_t size); @@ -168,7 +168,7 @@ void SslConnector::connect(const std::string& host, const std::string& port){ Mutex::ScopedLock l(lock); assert(closed); try { - socket.connect(host, port); + socket.connect(SocketAddress(host, port)); } catch (const std::exception& e) { socket.close(); throw TransportFailure(e.what()); @@ -199,7 +199,7 @@ void SslConnector::close() { } } -void SslConnector::socketClosed(SslIO&, const SslSocket&) { +void SslConnector::socketClosed(AsynchIO&, const Socket&) { if (aio) aio->queueForDeletion(); if (shutdownHandler) @@ -255,7 +255,7 @@ void SslConnector::send(AMQFrame& frame) { } } -void SslConnector::writebuff(SslIO& /*aio*/) +void SslConnector::writebuff(AsynchIO& /*aio*/) { // It's possible to be disconnected and be writable if (closed) @@ -304,7 +304,7 @@ size_t SslConnector::encode(char* buffer, size_t size) return bytesWritten; } -void SslConnector::readbuff(SslIO& aio, SslIO::BufferBase* buff) +void SslConnector::readbuff(AsynchIO& aio, AsynchIOBufferBase* buff) { int32_t decoded = decode(buff->bytes+buff->dataStart, buff->dataCount); // TODO: unreading needs to go away, and when we can cope @@ -351,11 +351,11 @@ void SslConnector::writeDataBlock(const AMQDataBlock& data) { aio->queueWrite(buff); } -void SslConnector::eof(SslIO&) { +void SslConnector::eof(AsynchIO&) { close(); } -void SslConnector::disconnected(SslIO&) { +void SslConnector::disconnected(AsynchIO&) { close(); socketClosed(*aio, socket); } diff --git a/cpp/src/qpid/sys/Socket.h b/cpp/src/qpid/sys/Socket.h index defec4879c..aa8a8a31d9 100644 --- a/cpp/src/qpid/sys/Socket.h +++ b/cpp/src/qpid/sys/Socket.h @@ -33,6 +33,10 @@ namespace sys { class Duration; class SocketAddress; +namespace ssl { +class SslMuxSocket; +} + class QPID_COMMON_CLASS_EXTERN Socket : public IOHandle { public: @@ -47,7 +51,6 @@ public: QPID_COMMON_EXTERN void setTcpNoDelay() const; - QPID_COMMON_EXTERN void connect(const std::string& host, const std::string& port) const; QPID_COMMON_EXTERN void connect(const SocketAddress&) const; QPID_COMMON_EXTERN void close() const; @@ -57,7 +60,6 @@ public: *@param backlog maximum number of pending connections. *@return The bound port. */ - QPID_COMMON_EXTERN int listen(const std::string& host = "", const std::string& port = "0", int backlog = 10) const; QPID_COMMON_EXTERN int listen(const SocketAddress&, int backlog = 10) const; /** @@ -91,19 +93,18 @@ public: QPID_COMMON_EXTERN int read(void *buf, size_t count) const; QPID_COMMON_EXTERN int write(const void *buf, size_t count) const; -private: +protected: /** Create socket */ void createSocket(const SocketAddress&) const; -public: - /** Construct socket with existing handle */ - Socket(IOHandlePrivate*); - -protected: mutable std::string localname; mutable std::string peername; mutable bool nonblocking; mutable bool nodelay; + + /** Construct socket with existing handle */ + Socket(IOHandlePrivate*); + friend class qpid::sys::ssl::SslMuxSocket; }; }} diff --git a/cpp/src/qpid/sys/SocketAddress.h b/cpp/src/qpid/sys/SocketAddress.h index dcca109d94..a4da5cca79 100644 --- a/cpp/src/qpid/sys/SocketAddress.h +++ b/cpp/src/qpid/sys/SocketAddress.h @@ -44,11 +44,12 @@ public: QPID_COMMON_EXTERN bool nextAddress(); QPID_COMMON_EXTERN std::string asString(bool numeric=true) const; + QPID_COMMON_EXTERN std::string getHost() const; QPID_COMMON_EXTERN void setAddrInfoPort(uint16_t port); QPID_COMMON_EXTERN static std::string asString(::sockaddr const * const addr, size_t addrlen); QPID_COMMON_EXTERN static uint16_t getPort(::sockaddr const * const addr); - + private: std::string host; diff --git a/cpp/src/qpid/sys/SslPlugin.cpp b/cpp/src/qpid/sys/SslPlugin.cpp index c14cb5f016..1cebadeab3 100644 --- a/cpp/src/qpid/sys/SslPlugin.cpp +++ b/cpp/src/qpid/sys/SslPlugin.cpp @@ -29,6 +29,7 @@ #include "qpid/sys/AsynchIO.h" #include "qpid/sys/ssl/SslIo.h" #include "qpid/sys/ssl/SslSocket.h" +#include "qpid/sys/SocketAddress.h" #include "qpid/broker/Broker.h" #include "qpid/log/Statement.h" @@ -68,8 +69,6 @@ template <class T> class SslProtocolFactoryTmpl : public ProtocolFactory { private: - typedef SslAcceptorTmpl<T> SslAcceptor; - Timer& brokerTimer; uint32_t maxNegotiateTime; const bool tcpNoDelay; @@ -79,7 +78,10 @@ class SslProtocolFactoryTmpl : public ProtocolFactory { bool nodict; public: - SslProtocolFactoryTmpl(const SslServerOptions&, int backlog, bool nodelay, Timer& timer, uint32_t maxTime); + SslProtocolFactoryTmpl(const std::string& host, const std::string& port, + const SslServerOptions&, + int backlog, bool nodelay, + Timer& timer, uint32_t maxTime); void accept(Poller::shared_ptr, ConnectionCodec::Factory*); void connect(Poller::shared_ptr, const std::string& host, const std::string& port, ConnectionCodec::Factory*, @@ -139,14 +141,16 @@ static struct SslPlugin : public Plugin { const broker::Broker::Options& opts = broker->getOptions(); ProtocolFactory::shared_ptr protocol(options.multiplex ? - static_cast<ProtocolFactory*>(new SslMuxProtocolFactory(options, - opts.connectionBacklog, - opts.tcpNoDelay, - broker->getTimer(), opts.maxNegotiateTime)) : - static_cast<ProtocolFactory*>(new SslProtocolFactory(options, - opts.connectionBacklog, - opts.tcpNoDelay, - broker->getTimer(), opts.maxNegotiateTime))); + static_cast<ProtocolFactory*>(new SslMuxProtocolFactory("", boost::lexical_cast<std::string>(options.port), + options, + opts.connectionBacklog, + opts.tcpNoDelay, + broker->getTimer(), opts.maxNegotiateTime)) : + static_cast<ProtocolFactory*>(new SslProtocolFactory("", boost::lexical_cast<std::string>(options.port), + options, + opts.connectionBacklog, + opts.tcpNoDelay, + broker->getTimer(), opts.maxNegotiateTime))); QPID_LOG(notice, "Listening for " << (options.multiplex ? "SSL or TCP" : "SSL") << " connections on TCP port " << @@ -161,10 +165,15 @@ static struct SslPlugin : public Plugin { } sslPlugin; template <class T> -SslProtocolFactoryTmpl<T>::SslProtocolFactoryTmpl(const SslServerOptions& options, int backlog, bool nodelay, Timer& timer, uint32_t maxTime) : +SslProtocolFactoryTmpl<T>::SslProtocolFactoryTmpl(const std::string& host, const std::string& port, + const SslServerOptions& options, + int backlog, bool nodelay, + Timer& timer, uint32_t maxTime) : brokerTimer(timer), maxNegotiateTime(maxTime), - tcpNoDelay(nodelay), listeningPort(listener.listen(options.port, backlog, options.certName, options.clientAuth)), + tcpNoDelay(nodelay), + listener(options.certName, options.clientAuth), + listeningPort(listener.listen(SocketAddress(host, port), backlog)), nodict(options.nodict) {} diff --git a/cpp/src/qpid/sys/posix/Socket.cpp b/cpp/src/qpid/sys/posix/Socket.cpp index 77ae1af60c..0c01374369 100644 --- a/cpp/src/qpid/sys/posix/Socket.cpp +++ b/cpp/src/qpid/sys/posix/Socket.cpp @@ -135,12 +135,6 @@ void Socket::setTcpNoDelay() const } } -void Socket::connect(const std::string& host, const std::string& port) const -{ - SocketAddress sa(host, port); - connect(sa); -} - void Socket::connect(const SocketAddress& addr) const { // The display name for an outbound connection needs to be the name that was specified @@ -188,12 +182,6 @@ Socket::close() const socket = -1; } -int Socket::listen(const std::string& host, const std::string& port, int backlog) const -{ - SocketAddress sa(host, port); - return listen(sa, backlog); -} - int Socket::listen(const SocketAddress& sa, int backlog) const { createSocket(sa); diff --git a/cpp/src/qpid/sys/posix/SocketAddress.cpp b/cpp/src/qpid/sys/posix/SocketAddress.cpp index 344bd28669..cd23442226 100644 --- a/cpp/src/qpid/sys/posix/SocketAddress.cpp +++ b/cpp/src/qpid/sys/posix/SocketAddress.cpp @@ -102,6 +102,11 @@ std::string SocketAddress::asString(bool numeric) const return asString(ai.ai_addr, ai.ai_addrlen); } +std::string SocketAddress::getHost() const +{ + return host; +} + bool SocketAddress::nextAddress() { bool r = currentAddrInfo->ai_next != 0; if (r) diff --git a/cpp/src/qpid/sys/ssl/SslHandler.cpp b/cpp/src/qpid/sys/ssl/SslHandler.cpp index 8668c7d8d0..6e079a8094 100644 --- a/cpp/src/qpid/sys/ssl/SslHandler.cpp +++ b/cpp/src/qpid/sys/ssl/SslHandler.cpp @@ -83,7 +83,7 @@ void SslHandler::init(SslIO* a, Timer& timer, uint32_t maxTime) { void SslHandler::write(const framing::ProtocolInitiation& data) { QPID_LOG(debug, "SENT [" << identifier << "]: INIT(" << data << ")"); - SslIO::BufferBase* buff = aio->getQueuedBuffer(); + AsynchIOBufferBase* buff = aio->getQueuedBuffer(); assert(buff); framing::Buffer out(buff->bytes, buff->byteCount); data.encode(out); @@ -106,7 +106,7 @@ void SslHandler::giveReadCredit(int32_t) { } // Input side -void SslHandler::readbuff(SslIO& , SslIO::BufferBase* buff) { +void SslHandler::readbuff(AsynchIO& , AsynchIOBufferBase* buff) { if (readError) { return; } @@ -160,13 +160,13 @@ void SslHandler::readbuff(SslIO& , SslIO::BufferBase* buff) { } } -void SslHandler::eof(SslIO&) { +void SslHandler::eof(AsynchIO&) { QPID_LOG(debug, "DISCONNECTED [" << identifier << "]"); if (codec) codec->closed(); aio->queueWriteClose(); } -void SslHandler::closedSocket(SslIO&, const SslSocket& s) { +void SslHandler::closedSocket(AsynchIO&, const Socket& s) { // If we closed with data still to send log a warning if (!aio->writeQueueEmpty()) { QPID_LOG(warning, "CLOSING [" << identifier << "] unsent data (probably due to client disconnect)"); @@ -176,16 +176,16 @@ void SslHandler::closedSocket(SslIO&, const SslSocket& s) { delete this; } -void SslHandler::disconnect(SslIO& a) { +void SslHandler::disconnect(AsynchIO& a) { // treat the same as eof eof(a); } // Notifications -void SslHandler::nobuffs(SslIO&) { +void SslHandler::nobuffs(AsynchIO&) { } -void SslHandler::idle(SslIO&){ +void SslHandler::idle(AsynchIO&){ if (isClient && codec == 0) { codec = factory->create(*this, identifier, getSecuritySettings(aio)); write(framing::ProtocolInitiation(codec->getVersion())); @@ -199,7 +199,7 @@ void SslHandler::idle(SslIO&){ if (!codec->canEncode()) { return; } - SslIO::BufferBase* buff = aio->getQueuedBuffer(); + AsynchIOBufferBase* buff = aio->getQueuedBuffer(); if (buff) { size_t encoded=codec->encode(buff->bytes, buff->byteCount); buff->dataCount = encoded; diff --git a/cpp/src/qpid/sys/ssl/SslHandler.h b/cpp/src/qpid/sys/ssl/SslHandler.h index 14814b0281..d25304b37e 100644 --- a/cpp/src/qpid/sys/ssl/SslHandler.h +++ b/cpp/src/qpid/sys/ssl/SslHandler.h @@ -24,6 +24,7 @@ #include "qpid/sys/ConnectionCodec.h" #include "qpid/sys/OutputControl.h" +#include "qpid/sys/SecuritySettings.h" #include <boost/intrusive_ptr.hpp> @@ -35,14 +36,15 @@ namespace framing { namespace sys { +class AsynchIO; +struct AsynchIOBufferBase; +class Socket; class Timer; class TimerTask; namespace ssl { class SslIO; -struct SslIOBufferBase; -class SslSocket; class SslHandler : public OutputControl { std::string identifier; @@ -70,14 +72,14 @@ class SslHandler : public OutputControl { void giveReadCredit(int32_t); // Input side - void readbuff(SslIO& aio, SslIOBufferBase* buff); - void eof(SslIO& aio); - void disconnect(SslIO& aio); + void readbuff(qpid::sys::AsynchIO&, qpid::sys::AsynchIOBufferBase* buff); + void eof(qpid::sys::AsynchIO&); + void disconnect(qpid::sys::AsynchIO& a); // Notifications - void nobuffs(SslIO& aio); - void idle(SslIO& aio); - void closedSocket(SslIO& aio, const SslSocket& s); + void nobuffs(qpid::sys::AsynchIO&); + void idle(qpid::sys::AsynchIO&); + void closedSocket(qpid::sys::AsynchIO&, const qpid::sys::Socket& s); }; }}} // namespace qpid::sys::ssl diff --git a/cpp/src/qpid/sys/ssl/SslIo.cpp b/cpp/src/qpid/sys/ssl/SslIo.cpp index bbfb703170..92e51a2234 100644 --- a/cpp/src/qpid/sys/ssl/SslIo.cpp +++ b/cpp/src/qpid/sys/ssl/SslIo.cpp @@ -68,32 +68,28 @@ __thread int64_t threadMaxIoTimeNs = 2 * 1000000; // start at 2ms * Asynch Acceptor */ -template <class T> -SslAcceptorTmpl<T>::SslAcceptorTmpl(const T& s, Callback callback) : +SslAcceptor::SslAcceptor(const Socket& s, Callback callback) : acceptedCallback(callback), - handle(s, boost::bind(&SslAcceptorTmpl<T>::readable, this, _1), 0, 0), + handle(s, boost::bind(&SslAcceptor::readable, this, _1), 0, 0), socket(s) { s.setNonblocking(); ignoreSigpipe(); } -template <class T> -SslAcceptorTmpl<T>::~SslAcceptorTmpl() +SslAcceptor::~SslAcceptor() { handle.stopWatch(); } -template <class T> -void SslAcceptorTmpl<T>::start(Poller::shared_ptr poller) { +void SslAcceptor::start(Poller::shared_ptr poller) { handle.startWatch(poller); } /* * We keep on accepting as long as there is something to accept */ -template <class T> -void SslAcceptorTmpl<T>::readable(DispatchHandle& h) { +void SslAcceptor::readable(DispatchHandle& h) { Socket* s; do { errno = 0; @@ -114,10 +110,6 @@ void SslAcceptorTmpl<T>::readable(DispatchHandle& h) { h.rewatch(); } -// Explicitly instantiate the templates we need -template class SslAcceptorTmpl<SslSocket>; -template class SslAcceptorTmpl<SslMuxSocket>; - /* * Asynch Connector */ @@ -134,13 +126,14 @@ SslConnector::SslConnector(const SslSocket& s, boost::bind(&SslConnector::connComplete, this, _1)), connCallback(connCb), failCallback(failCb), - socket(s) + socket(s), + sa(hostname, port) { //TODO: would be better for connect to be performed on a //non-blocking socket, but that doesn't work at present so connect //blocks until complete try { - socket.connect(hostname, port); + socket.connect(sa); socket.setNonblocking(); startWatch(poller); } catch(std::exception& e) { diff --git a/cpp/src/qpid/sys/ssl/SslIo.h b/cpp/src/qpid/sys/ssl/SslIo.h index f3112bfa65..a72cd7c76c 100644 --- a/cpp/src/qpid/sys/ssl/SslIo.h +++ b/cpp/src/qpid/sys/ssl/SslIo.h @@ -21,8 +21,10 @@ * */ +#include <qpid/sys/AsynchIO.h> #include "qpid/sys/DispatchHandle.h" #include "qpid/sys/SecuritySettings.h" +#include "qpid/sys/SocketAddress.h" #include <boost/function.hpp> #include <boost/shared_array.hpp> @@ -41,19 +43,18 @@ class SslSocket; * Asynchronous ssl acceptor: accepts connections then does a callback * with the accepted fd */ -template <class T> -class SslAcceptorTmpl { +class SslAcceptor { public: typedef boost::function1<void, const Socket&> Callback; private: Callback acceptedCallback; qpid::sys::DispatchHandle handle; - const T& socket; + const Socket& socket; public: - SslAcceptorTmpl(const T& s, Callback callback); - ~SslAcceptorTmpl(); + SslAcceptor(const Socket& s, Callback callback); + ~SslAcceptor(); void start(qpid::sys::Poller::shared_ptr poller); private: @@ -73,6 +74,7 @@ private: ConnectedCallback connCallback; FailedCallback failCallback; const SslSocket& socket; + SocketAddress sa; public: SslConnector(const SslSocket& socket, @@ -87,23 +89,6 @@ private: void failure(int, std::string); }; -struct SslIOBufferBase { - char* bytes; - int32_t byteCount; - int32_t dataStart; - int32_t dataCount; - - SslIOBufferBase(char* const b, const int32_t s) : - bytes(b), - byteCount(s), - dataStart(0), - dataCount(0) - {} - - virtual ~SslIOBufferBase() - {} -}; - /* * Asychronous reader/writer: * Reader accepts buffers to read into; reads into the provided buffers @@ -116,18 +101,8 @@ struct SslIOBufferBase { * The class is implemented in terms of DispatchHandle to allow it to be deleted by deleting * the contained DispatchHandle */ -class SslIO : private qpid::sys::DispatchHandle { +class SslIO : public AsynchIO, private qpid::sys::DispatchHandle { public: - typedef SslIOBufferBase BufferBase; - - typedef boost::function2<void, SslIO&, BufferBase*> ReadCallback; - typedef boost::function1<void, SslIO&> EofCallback; - typedef boost::function1<void, SslIO&> DisconnectCallback; - typedef boost::function2<void, SslIO&, const SslSocket&> ClosedCallback; - typedef boost::function1<void, SslIO&> BuffersEmptyCallback; - typedef boost::function1<void, SslIO&> IdleCallback; - typedef boost::function1<void, SslIO&> RequestCallback; - SslIO(const SslSocket& s, ReadCallback rCb, EofCallback eofCb, DisconnectCallback disCb, ClosedCallback cCb = 0, BuffersEmptyCallback eCb = 0, IdleCallback iCb = 0); @@ -153,17 +128,6 @@ private: volatile bool writePending; public: - /* - * Size of IO buffers - this is the maximum possible frame size + 1 - */ - const static uint32_t MaxBufferSize = 65536; - - /* - * Number of IO buffers allocated - I think the code can only use 2 - - * 1 for reading and 1 for writing, allocate 4 for safety - */ - const static uint32_t BufferCount = 4; - void queueForDeletion(); void start(qpid::sys::Poller::shared_ptr poller); @@ -174,6 +138,8 @@ public: void notifyPendingWrite(); void queueWriteClose(); bool writeQueueEmpty() { return writeQueue.empty(); } + void startReading() {}; + void stopReading() {}; void requestCallback(RequestCallback); BufferBase* getQueuedBuffer(); diff --git a/cpp/src/qpid/sys/ssl/SslSocket.cpp b/cpp/src/qpid/sys/ssl/SslSocket.cpp index 0568ed8350..6b6f326492 100644 --- a/cpp/src/qpid/sys/ssl/SslSocket.cpp +++ b/cpp/src/qpid/sys/ssl/SslSocket.cpp @@ -20,6 +20,7 @@ */ #include "qpid/sys/ssl/SslSocket.h" +#include "qpid/sys/SocketAddress.h" #include "qpid/sys/ssl/check.h" #include "qpid/sys/ssl/util.h" #include "qpid/Exception.h" @@ -81,11 +82,15 @@ std::string getDomainFromSubject(std::string subject) } } -SslSocket::SslSocket() : socket(0), prototype(0) +SslSocket::SslSocket(const std::string& certName, bool clientAuth) : + nssSocket(0), certname(certName), prototype(0) { - impl->fd = ::socket (PF_INET, SOCK_STREAM, 0); - if (impl->fd < 0) throw QPID_POSIX_ERROR(errno); - socket = SSL_ImportFD(0, PR_ImportTCPSocket(impl->fd)); + //configure prototype socket: + prototype = SSL_ImportFD(0, PR_NewTCPSocket()); + if (clientAuth) { + NSS_CHECK(SSL_OptionSet(prototype, SSL_REQUEST_CERTIFICATE, PR_TRUE)); + NSS_CHECK(SSL_OptionSet(prototype, SSL_REQUIRE_CERTIFICATE, PR_TRUE)); + } } /** @@ -93,25 +98,41 @@ SslSocket::SslSocket() : socket(0), prototype(0) * returned from accept. Because we use posix accept rather than * PR_Accept, we have to reset the handshake. */ -SslSocket::SslSocket(IOHandlePrivate* ioph, PRFileDesc* model) : Socket(ioph), socket(0), prototype(0) +SslSocket::SslSocket(IOHandlePrivate* ioph, PRFileDesc* model) : Socket(ioph), nssSocket(0), prototype(0) { - socket = SSL_ImportFD(model, PR_ImportTCPSocket(impl->fd)); - NSS_CHECK(SSL_ResetHandshake(socket, true)); + nssSocket = SSL_ImportFD(model, PR_ImportTCPSocket(impl->fd)); + NSS_CHECK(SSL_ResetHandshake(nssSocket, PR_TRUE)); } void SslSocket::setNonblocking() const { + if (!nssSocket) { + Socket::setNonblocking(); + return; + } PRSocketOptionData option; option.option = PR_SockOpt_Nonblocking; option.value.non_blocking = true; - PR_SetSocketOption(socket, &option); + PR_SetSocketOption(nssSocket, &option); +} + +void SslSocket::setTcpNoDelay() const +{ + if (!nssSocket) { + Socket::setTcpNoDelay(); + return; + } + PRSocketOptionData option; + option.option = PR_SockOpt_NoDelay; + option.value.no_delay = true; + PR_SetSocketOption(nssSocket, &option); } -void SslSocket::connect(const std::string& host, const std::string& port) const +void SslSocket::connect(const SocketAddress& addr) const { - std::stringstream namestream; - namestream << host << ":" << port; - connectname = namestream.str(); + Socket::connect(addr); + + nssSocket = SSL_ImportFD(0, PR_ImportTCPSocket(impl->fd)); void* arg; // Use the connection's cert-name if it has one; else use global cert-name @@ -122,41 +143,31 @@ void SslSocket::connect(const std::string& host, const std::string& port) const } else { arg = const_cast<char*>(SslOptions::global.certName.c_str()); } - NSS_CHECK(SSL_GetClientAuthDataHook(socket, NSS_GetClientAuthData, arg)); - NSS_CHECK(SSL_SetURL(socket, host.data())); - - char hostBuffer[PR_NETDB_BUF_SIZE]; - PRHostEnt hostEntry; - PR_CHECK(PR_GetHostByName(host.data(), hostBuffer, PR_NETDB_BUF_SIZE, &hostEntry)); - PRNetAddr address; - int value = PR_EnumerateHostEnt(0, &hostEntry, boost::lexical_cast<PRUint16>(port), &address); - if (value < 0) { - throw Exception(QPID_MSG("Error getting address for host: " << ErrorString())); - } else if (value == 0) { - throw Exception(QPID_MSG("Could not resolve address for host.")); - } - PR_CHECK(PR_Connect(socket, &address, PR_INTERVAL_NO_TIMEOUT)); - NSS_CHECK(SSL_ForceHandshake(socket)); + NSS_CHECK(SSL_GetClientAuthDataHook(nssSocket, NSS_GetClientAuthData, arg)); + + url = addr.getHost(); + NSS_CHECK(SSL_SetURL(nssSocket, url.data())); + + NSS_CHECK(SSL_ResetHandshake(nssSocket, PR_FALSE)); + NSS_CHECK(SSL_ForceHandshake(nssSocket)); } void SslSocket::close() const { + if (!nssSocket) { + Socket::close(); + return; + } if (impl->fd > 0) { - PR_Close(socket); + PR_Close(nssSocket); impl->fd = -1; } } -int SslSocket::listen(uint16_t port, int backlog, const std::string& certName, bool clientAuth) const +int SslSocket::listen(const SocketAddress& sa, int backlog) const { - //configure prototype socket: - prototype = SSL_ImportFD(0, PR_NewTCPSocket()); - if (clientAuth) { - NSS_CHECK(SSL_OptionSet(prototype, SSL_REQUEST_CERTIFICATE, PR_TRUE)); - NSS_CHECK(SSL_OptionSet(prototype, SSL_REQUIRE_CERTIFICATE, PR_TRUE)); - } - //get certificate and key (is this the correct way?) + std::string certName( (certname == "") ? "localhost.localdomain" : certname); CERTCertificate *cert = PK11_FindCertFromNickname(const_cast<char*>(certName.c_str()), 0); if (!cert) throw Exception(QPID_MSG("Failed to load certificate '" << certName << "'")); SECKEYPrivateKey *key = PK11_FindKeyByAnyCert(cert, 0); @@ -165,24 +176,7 @@ int SslSocket::listen(uint16_t port, int backlog, const std::string& certName, b SECKEY_DestroyPrivateKey(key); CERT_DestroyCertificate(cert); - //bind and listen - const int& socket = impl->fd; - int yes=1; - QPID_POSIX_CHECK(setsockopt(socket,SOL_SOCKET,SO_REUSEADDR,&yes,sizeof(yes))); - struct sockaddr_in name; - name.sin_family = AF_INET; - name.sin_port = htons(port); - name.sin_addr.s_addr = 0; - if (::bind(socket, (struct sockaddr*)&name, sizeof(name)) < 0) - throw Exception(QPID_MSG("Can't bind to port " << port << ": " << strError(errno))); - if (::listen(socket, backlog) < 0) - throw Exception(QPID_MSG("Can't listen on port " << port << ": " << strError(errno))); - - socklen_t namelen = sizeof(name); - if (::getsockname(socket, (struct sockaddr*)&name, &namelen) < 0) - throw QPID_POSIX_ERROR(errno); - - return ntohs(name.sin_port); + return Socket::listen(sa, backlog); } SslSocket* SslSocket::accept() const @@ -274,6 +268,11 @@ static bool isSslStream(int afd) { return isSSL2Handshake || isSSL3Handshake; } +SslMuxSocket::SslMuxSocket(const std::string& certName, bool clientAuth) : + SslSocket(certName, clientAuth) +{ +} + Socket* SslMuxSocket::accept() const { int afd = ::accept(impl->fd, 0, 0); @@ -295,20 +294,12 @@ Socket* SslMuxSocket::accept() const int SslSocket::read(void *buf, size_t count) const { - return PR_Read(socket, buf, count); + return PR_Read(nssSocket, buf, count); } int SslSocket::write(const void *buf, size_t count) const { - return PR_Write(socket, buf, count); -} - -void SslSocket::setTcpNoDelay() const -{ - PRSocketOptionData option; - option.option = PR_SockOpt_NoDelay; - option.value.no_delay = true; - PR_SetSocketOption(socket, &option); + return PR_Write(nssSocket, buf, count); } void SslSocket::setCertName(const std::string& name) @@ -324,7 +315,7 @@ int SslSocket::getKeyLen() const int keySize = 0; SECStatus rc; - rc = SSL_SecurityStatus( socket, + rc = SSL_SecurityStatus( nssSocket, &enabled, NULL, NULL, @@ -339,7 +330,7 @@ int SslSocket::getKeyLen() const std::string SslSocket::getClientAuthId() const { std::string authId; - CERTCertificate* cert = SSL_PeerCertificate(socket); + CERTCertificate* cert = SSL_PeerCertificate(nssSocket); if (cert) { authId = CERT_GetCommonName(&(cert->subject)); /* diff --git a/cpp/src/qpid/sys/ssl/SslSocket.h b/cpp/src/qpid/sys/ssl/SslSocket.h index 0f7e74f977..1b5424cfeb 100644 --- a/cpp/src/qpid/sys/ssl/SslSocket.h +++ b/cpp/src/qpid/sys/ssl/SslSocket.h @@ -40,8 +40,10 @@ namespace ssl { class SslSocket : public qpid::sys::Socket { public: - /** Create a socket wrapper for descriptor. */ - SslSocket(); + /** Create a socket wrapper for descriptor. + *@param certName name of certificate to use to identify the socket + */ + SslSocket(const std::string& certName = "", bool clientAuth = false); /** Set socket non blocking */ void setNonblocking() const; @@ -54,17 +56,16 @@ public: * NSSInit().*/ void setCertName(const std::string& certName); - void connect(const std::string& host, const std::string& port) const; + void connect(const SocketAddress&) const; void close() const; /** Bind to a port and start listening. *@param port 0 means choose an available port. *@param backlog maximum number of pending connections. - *@param certName name of certificate to use to identify the server *@return The bound port. */ - int listen(uint16_t port = 0, int backlog = 10, const std::string& certName = "localhost.localdomain", bool clientAuth = false) const; + int listen(const SocketAddress&, int backlog = 10) const; /** * Accept a connection from a socket that is already listening @@ -80,9 +81,9 @@ public: std::string getClientAuthId() const; protected: - mutable std::string connectname; - mutable PRFileDesc* socket; + mutable PRFileDesc* nssSocket; std::string certname; + mutable std::string url; /** * 'model' socket, with configuration to use when importing @@ -98,6 +99,7 @@ protected: class SslMuxSocket : public SslSocket { public: + SslMuxSocket(const std::string& certName = "", bool clientAuth = false); Socket* accept() const; }; diff --git a/cpp/src/qpid/sys/windows/AsynchIO.cpp b/cpp/src/qpid/sys/windows/AsynchIO.cpp index 618d0e14c7..9fdf89c83b 100644 --- a/cpp/src/qpid/sys/windows/AsynchIO.cpp +++ b/cpp/src/qpid/sys/windows/AsynchIO.cpp @@ -24,6 +24,7 @@ #include "qpid/sys/AsynchIO.h" #include "qpid/sys/Mutex.h" #include "qpid/sys/Socket.h" +#include "qpid/sys/SocketAddress.h" #include "qpid/sys/Poller.h" #include "qpid/sys/Thread.h" #include "qpid/sys/Time.h" @@ -195,7 +196,7 @@ AsynchConnector::AsynchConnector(const Socket& sock, void AsynchConnector::start(Poller::shared_ptr) { try { - socket.connect(hostname, port); + socket.connect(SocketAddress(hostname, port)); socket.setNonblocking(); connCallback(socket); } catch(std::exception& e) { diff --git a/cpp/src/qpid/sys/windows/Socket.cpp b/cpp/src/qpid/sys/windows/Socket.cpp index 17e3212a46..0c74b3a725 100644 --- a/cpp/src/qpid/sys/windows/Socket.cpp +++ b/cpp/src/qpid/sys/windows/Socket.cpp @@ -160,12 +160,6 @@ void Socket::setNonblocking() const { QPID_WINSOCK_CHECK(ioctlsocket(impl->fd, FIONBIO, &nonblock)); } -void Socket::connect(const std::string& host, const std::string& port) const -{ - SocketAddress sa(host, port); - connect(sa); -} - void Socket::connect(const SocketAddress& addr) const { @@ -209,12 +203,6 @@ int Socket::read(void *buf, size_t count) const return received; } -int Socket::listen(const std::string& host, const std::string& port, int backlog) const -{ - SocketAddress sa(host, port); - return listen(sa, backlog); -} - int Socket::listen(const SocketAddress& addr, int backlog) const { createSocket(addr); |