summaryrefslogtreecommitdiff
path: root/cpp
diff options
context:
space:
mode:
authorAndrew Stitcher <astitcher@apache.org>2012-10-24 05:51:19 +0000
committerAndrew Stitcher <astitcher@apache.org>2012-10-24 05:51:19 +0000
commitc61f7c0c79717b0eb842d0c4c88deeda9f7672e6 (patch)
treebaabe1be2b4389b3412154598fe0ea543ddd63f5 /cpp
parentf47c28c9ee2fb8f2967a221a28912060edcba749 (diff)
downloadqpid-python-c61f7c0c79717b0eb842d0c4c88deeda9f7672e6.tar.gz
QPID-4272: Large amounts of code are duplicated between the SSL and TCP transports
Refactor to unify the various SSL and TCP interfaces: - Make ssl listen have the same signature as regular sockets - Give ssl connect same interface as tcp - Unify the SSL and TCP IO interfaces some more git-svn-id: https://svn.apache.org/repos/asf/qpid/trunk/qpid@1401558 13f79535-47bb-0310-9956-ffa450edef68
Diffstat (limited to 'cpp')
-rw-r--r--cpp/src/qpid/client/SslConnector.cpp22
-rw-r--r--cpp/src/qpid/sys/Socket.h17
-rw-r--r--cpp/src/qpid/sys/SocketAddress.h3
-rw-r--r--cpp/src/qpid/sys/SslPlugin.cpp35
-rw-r--r--cpp/src/qpid/sys/posix/Socket.cpp12
-rw-r--r--cpp/src/qpid/sys/posix/SocketAddress.cpp5
-rw-r--r--cpp/src/qpid/sys/ssl/SslHandler.cpp16
-rw-r--r--cpp/src/qpid/sys/ssl/SslHandler.h18
-rw-r--r--cpp/src/qpid/sys/ssl/SslIo.cpp23
-rw-r--r--cpp/src/qpid/sys/ssl/SslIo.h54
-rw-r--r--cpp/src/qpid/sys/ssl/SslSocket.cpp123
-rw-r--r--cpp/src/qpid/sys/ssl/SslSocket.h16
-rw-r--r--cpp/src/qpid/sys/windows/AsynchIO.cpp3
-rw-r--r--cpp/src/qpid/sys/windows/Socket.cpp12
14 files changed, 153 insertions, 206 deletions
diff --git a/cpp/src/qpid/client/SslConnector.cpp b/cpp/src/qpid/client/SslConnector.cpp
index 0b07d14f35..c49deaa279 100644
--- a/cpp/src/qpid/client/SslConnector.cpp
+++ b/cpp/src/qpid/client/SslConnector.cpp
@@ -79,11 +79,11 @@ class SslConnector : public Connector
~SslConnector();
- void readbuff(qpid::sys::ssl::SslIO&, qpid::sys::ssl::SslIOBufferBase*);
- void writebuff(qpid::sys::ssl::SslIO&);
+ void readbuff(AsynchIO&, AsynchIOBufferBase*);
+ void writebuff(AsynchIO&);
void writeDataBlock(const framing::AMQDataBlock& data);
- void eof(qpid::sys::ssl::SslIO&);
- void disconnected(qpid::sys::ssl::SslIO&);
+ void eof(AsynchIO&);
+ void disconnected(AsynchIO&);
void connect(const std::string& host, const std::string& port);
void close();
@@ -96,7 +96,7 @@ class SslConnector : public Connector
framing::OutputHandler* getOutputHandler();
const std::string& getIdentifier() const;
const SecuritySettings* getSecuritySettings();
- void socketClosed(qpid::sys::ssl::SslIO&, const qpid::sys::ssl::SslSocket&);
+ void socketClosed(AsynchIO&, const Socket&);
size_t decode(const char* buffer, size_t size);
size_t encode(char* buffer, size_t size);
@@ -168,7 +168,7 @@ void SslConnector::connect(const std::string& host, const std::string& port){
Mutex::ScopedLock l(lock);
assert(closed);
try {
- socket.connect(host, port);
+ socket.connect(SocketAddress(host, port));
} catch (const std::exception& e) {
socket.close();
throw TransportFailure(e.what());
@@ -199,7 +199,7 @@ void SslConnector::close() {
}
}
-void SslConnector::socketClosed(SslIO&, const SslSocket&) {
+void SslConnector::socketClosed(AsynchIO&, const Socket&) {
if (aio)
aio->queueForDeletion();
if (shutdownHandler)
@@ -255,7 +255,7 @@ void SslConnector::send(AMQFrame& frame) {
}
}
-void SslConnector::writebuff(SslIO& /*aio*/)
+void SslConnector::writebuff(AsynchIO& /*aio*/)
{
// It's possible to be disconnected and be writable
if (closed)
@@ -304,7 +304,7 @@ size_t SslConnector::encode(char* buffer, size_t size)
return bytesWritten;
}
-void SslConnector::readbuff(SslIO& aio, SslIO::BufferBase* buff)
+void SslConnector::readbuff(AsynchIO& aio, AsynchIOBufferBase* buff)
{
int32_t decoded = decode(buff->bytes+buff->dataStart, buff->dataCount);
// TODO: unreading needs to go away, and when we can cope
@@ -351,11 +351,11 @@ void SslConnector::writeDataBlock(const AMQDataBlock& data) {
aio->queueWrite(buff);
}
-void SslConnector::eof(SslIO&) {
+void SslConnector::eof(AsynchIO&) {
close();
}
-void SslConnector::disconnected(SslIO&) {
+void SslConnector::disconnected(AsynchIO&) {
close();
socketClosed(*aio, socket);
}
diff --git a/cpp/src/qpid/sys/Socket.h b/cpp/src/qpid/sys/Socket.h
index defec4879c..aa8a8a31d9 100644
--- a/cpp/src/qpid/sys/Socket.h
+++ b/cpp/src/qpid/sys/Socket.h
@@ -33,6 +33,10 @@ namespace sys {
class Duration;
class SocketAddress;
+namespace ssl {
+class SslMuxSocket;
+}
+
class QPID_COMMON_CLASS_EXTERN Socket : public IOHandle
{
public:
@@ -47,7 +51,6 @@ public:
QPID_COMMON_EXTERN void setTcpNoDelay() const;
- QPID_COMMON_EXTERN void connect(const std::string& host, const std::string& port) const;
QPID_COMMON_EXTERN void connect(const SocketAddress&) const;
QPID_COMMON_EXTERN void close() const;
@@ -57,7 +60,6 @@ public:
*@param backlog maximum number of pending connections.
*@return The bound port.
*/
- QPID_COMMON_EXTERN int listen(const std::string& host = "", const std::string& port = "0", int backlog = 10) const;
QPID_COMMON_EXTERN int listen(const SocketAddress&, int backlog = 10) const;
/**
@@ -91,19 +93,18 @@ public:
QPID_COMMON_EXTERN int read(void *buf, size_t count) const;
QPID_COMMON_EXTERN int write(const void *buf, size_t count) const;
-private:
+protected:
/** Create socket */
void createSocket(const SocketAddress&) const;
-public:
- /** Construct socket with existing handle */
- Socket(IOHandlePrivate*);
-
-protected:
mutable std::string localname;
mutable std::string peername;
mutable bool nonblocking;
mutable bool nodelay;
+
+ /** Construct socket with existing handle */
+ Socket(IOHandlePrivate*);
+ friend class qpid::sys::ssl::SslMuxSocket;
};
}}
diff --git a/cpp/src/qpid/sys/SocketAddress.h b/cpp/src/qpid/sys/SocketAddress.h
index dcca109d94..a4da5cca79 100644
--- a/cpp/src/qpid/sys/SocketAddress.h
+++ b/cpp/src/qpid/sys/SocketAddress.h
@@ -44,11 +44,12 @@ public:
QPID_COMMON_EXTERN bool nextAddress();
QPID_COMMON_EXTERN std::string asString(bool numeric=true) const;
+ QPID_COMMON_EXTERN std::string getHost() const;
QPID_COMMON_EXTERN void setAddrInfoPort(uint16_t port);
QPID_COMMON_EXTERN static std::string asString(::sockaddr const * const addr, size_t addrlen);
QPID_COMMON_EXTERN static uint16_t getPort(::sockaddr const * const addr);
-
+
private:
std::string host;
diff --git a/cpp/src/qpid/sys/SslPlugin.cpp b/cpp/src/qpid/sys/SslPlugin.cpp
index c14cb5f016..1cebadeab3 100644
--- a/cpp/src/qpid/sys/SslPlugin.cpp
+++ b/cpp/src/qpid/sys/SslPlugin.cpp
@@ -29,6 +29,7 @@
#include "qpid/sys/AsynchIO.h"
#include "qpid/sys/ssl/SslIo.h"
#include "qpid/sys/ssl/SslSocket.h"
+#include "qpid/sys/SocketAddress.h"
#include "qpid/broker/Broker.h"
#include "qpid/log/Statement.h"
@@ -68,8 +69,6 @@ template <class T>
class SslProtocolFactoryTmpl : public ProtocolFactory {
private:
- typedef SslAcceptorTmpl<T> SslAcceptor;
-
Timer& brokerTimer;
uint32_t maxNegotiateTime;
const bool tcpNoDelay;
@@ -79,7 +78,10 @@ class SslProtocolFactoryTmpl : public ProtocolFactory {
bool nodict;
public:
- SslProtocolFactoryTmpl(const SslServerOptions&, int backlog, bool nodelay, Timer& timer, uint32_t maxTime);
+ SslProtocolFactoryTmpl(const std::string& host, const std::string& port,
+ const SslServerOptions&,
+ int backlog, bool nodelay,
+ Timer& timer, uint32_t maxTime);
void accept(Poller::shared_ptr, ConnectionCodec::Factory*);
void connect(Poller::shared_ptr, const std::string& host, const std::string& port,
ConnectionCodec::Factory*,
@@ -139,14 +141,16 @@ static struct SslPlugin : public Plugin {
const broker::Broker::Options& opts = broker->getOptions();
ProtocolFactory::shared_ptr protocol(options.multiplex ?
- static_cast<ProtocolFactory*>(new SslMuxProtocolFactory(options,
- opts.connectionBacklog,
- opts.tcpNoDelay,
- broker->getTimer(), opts.maxNegotiateTime)) :
- static_cast<ProtocolFactory*>(new SslProtocolFactory(options,
- opts.connectionBacklog,
- opts.tcpNoDelay,
- broker->getTimer(), opts.maxNegotiateTime)));
+ static_cast<ProtocolFactory*>(new SslMuxProtocolFactory("", boost::lexical_cast<std::string>(options.port),
+ options,
+ opts.connectionBacklog,
+ opts.tcpNoDelay,
+ broker->getTimer(), opts.maxNegotiateTime)) :
+ static_cast<ProtocolFactory*>(new SslProtocolFactory("", boost::lexical_cast<std::string>(options.port),
+ options,
+ opts.connectionBacklog,
+ opts.tcpNoDelay,
+ broker->getTimer(), opts.maxNegotiateTime)));
QPID_LOG(notice, "Listening for " <<
(options.multiplex ? "SSL or TCP" : "SSL") <<
" connections on TCP port " <<
@@ -161,10 +165,15 @@ static struct SslPlugin : public Plugin {
} sslPlugin;
template <class T>
-SslProtocolFactoryTmpl<T>::SslProtocolFactoryTmpl(const SslServerOptions& options, int backlog, bool nodelay, Timer& timer, uint32_t maxTime) :
+SslProtocolFactoryTmpl<T>::SslProtocolFactoryTmpl(const std::string& host, const std::string& port,
+ const SslServerOptions& options,
+ int backlog, bool nodelay,
+ Timer& timer, uint32_t maxTime) :
brokerTimer(timer),
maxNegotiateTime(maxTime),
- tcpNoDelay(nodelay), listeningPort(listener.listen(options.port, backlog, options.certName, options.clientAuth)),
+ tcpNoDelay(nodelay),
+ listener(options.certName, options.clientAuth),
+ listeningPort(listener.listen(SocketAddress(host, port), backlog)),
nodict(options.nodict)
{}
diff --git a/cpp/src/qpid/sys/posix/Socket.cpp b/cpp/src/qpid/sys/posix/Socket.cpp
index 77ae1af60c..0c01374369 100644
--- a/cpp/src/qpid/sys/posix/Socket.cpp
+++ b/cpp/src/qpid/sys/posix/Socket.cpp
@@ -135,12 +135,6 @@ void Socket::setTcpNoDelay() const
}
}
-void Socket::connect(const std::string& host, const std::string& port) const
-{
- SocketAddress sa(host, port);
- connect(sa);
-}
-
void Socket::connect(const SocketAddress& addr) const
{
// The display name for an outbound connection needs to be the name that was specified
@@ -188,12 +182,6 @@ Socket::close() const
socket = -1;
}
-int Socket::listen(const std::string& host, const std::string& port, int backlog) const
-{
- SocketAddress sa(host, port);
- return listen(sa, backlog);
-}
-
int Socket::listen(const SocketAddress& sa, int backlog) const
{
createSocket(sa);
diff --git a/cpp/src/qpid/sys/posix/SocketAddress.cpp b/cpp/src/qpid/sys/posix/SocketAddress.cpp
index 344bd28669..cd23442226 100644
--- a/cpp/src/qpid/sys/posix/SocketAddress.cpp
+++ b/cpp/src/qpid/sys/posix/SocketAddress.cpp
@@ -102,6 +102,11 @@ std::string SocketAddress::asString(bool numeric) const
return asString(ai.ai_addr, ai.ai_addrlen);
}
+std::string SocketAddress::getHost() const
+{
+ return host;
+}
+
bool SocketAddress::nextAddress() {
bool r = currentAddrInfo->ai_next != 0;
if (r)
diff --git a/cpp/src/qpid/sys/ssl/SslHandler.cpp b/cpp/src/qpid/sys/ssl/SslHandler.cpp
index 8668c7d8d0..6e079a8094 100644
--- a/cpp/src/qpid/sys/ssl/SslHandler.cpp
+++ b/cpp/src/qpid/sys/ssl/SslHandler.cpp
@@ -83,7 +83,7 @@ void SslHandler::init(SslIO* a, Timer& timer, uint32_t maxTime) {
void SslHandler::write(const framing::ProtocolInitiation& data)
{
QPID_LOG(debug, "SENT [" << identifier << "]: INIT(" << data << ")");
- SslIO::BufferBase* buff = aio->getQueuedBuffer();
+ AsynchIOBufferBase* buff = aio->getQueuedBuffer();
assert(buff);
framing::Buffer out(buff->bytes, buff->byteCount);
data.encode(out);
@@ -106,7 +106,7 @@ void SslHandler::giveReadCredit(int32_t) {
}
// Input side
-void SslHandler::readbuff(SslIO& , SslIO::BufferBase* buff) {
+void SslHandler::readbuff(AsynchIO& , AsynchIOBufferBase* buff) {
if (readError) {
return;
}
@@ -160,13 +160,13 @@ void SslHandler::readbuff(SslIO& , SslIO::BufferBase* buff) {
}
}
-void SslHandler::eof(SslIO&) {
+void SslHandler::eof(AsynchIO&) {
QPID_LOG(debug, "DISCONNECTED [" << identifier << "]");
if (codec) codec->closed();
aio->queueWriteClose();
}
-void SslHandler::closedSocket(SslIO&, const SslSocket& s) {
+void SslHandler::closedSocket(AsynchIO&, const Socket& s) {
// If we closed with data still to send log a warning
if (!aio->writeQueueEmpty()) {
QPID_LOG(warning, "CLOSING [" << identifier << "] unsent data (probably due to client disconnect)");
@@ -176,16 +176,16 @@ void SslHandler::closedSocket(SslIO&, const SslSocket& s) {
delete this;
}
-void SslHandler::disconnect(SslIO& a) {
+void SslHandler::disconnect(AsynchIO& a) {
// treat the same as eof
eof(a);
}
// Notifications
-void SslHandler::nobuffs(SslIO&) {
+void SslHandler::nobuffs(AsynchIO&) {
}
-void SslHandler::idle(SslIO&){
+void SslHandler::idle(AsynchIO&){
if (isClient && codec == 0) {
codec = factory->create(*this, identifier, getSecuritySettings(aio));
write(framing::ProtocolInitiation(codec->getVersion()));
@@ -199,7 +199,7 @@ void SslHandler::idle(SslIO&){
if (!codec->canEncode()) {
return;
}
- SslIO::BufferBase* buff = aio->getQueuedBuffer();
+ AsynchIOBufferBase* buff = aio->getQueuedBuffer();
if (buff) {
size_t encoded=codec->encode(buff->bytes, buff->byteCount);
buff->dataCount = encoded;
diff --git a/cpp/src/qpid/sys/ssl/SslHandler.h b/cpp/src/qpid/sys/ssl/SslHandler.h
index 14814b0281..d25304b37e 100644
--- a/cpp/src/qpid/sys/ssl/SslHandler.h
+++ b/cpp/src/qpid/sys/ssl/SslHandler.h
@@ -24,6 +24,7 @@
#include "qpid/sys/ConnectionCodec.h"
#include "qpid/sys/OutputControl.h"
+#include "qpid/sys/SecuritySettings.h"
#include <boost/intrusive_ptr.hpp>
@@ -35,14 +36,15 @@ namespace framing {
namespace sys {
+class AsynchIO;
+struct AsynchIOBufferBase;
+class Socket;
class Timer;
class TimerTask;
namespace ssl {
class SslIO;
-struct SslIOBufferBase;
-class SslSocket;
class SslHandler : public OutputControl {
std::string identifier;
@@ -70,14 +72,14 @@ class SslHandler : public OutputControl {
void giveReadCredit(int32_t);
// Input side
- void readbuff(SslIO& aio, SslIOBufferBase* buff);
- void eof(SslIO& aio);
- void disconnect(SslIO& aio);
+ void readbuff(qpid::sys::AsynchIO&, qpid::sys::AsynchIOBufferBase* buff);
+ void eof(qpid::sys::AsynchIO&);
+ void disconnect(qpid::sys::AsynchIO& a);
// Notifications
- void nobuffs(SslIO& aio);
- void idle(SslIO& aio);
- void closedSocket(SslIO& aio, const SslSocket& s);
+ void nobuffs(qpid::sys::AsynchIO&);
+ void idle(qpid::sys::AsynchIO&);
+ void closedSocket(qpid::sys::AsynchIO&, const qpid::sys::Socket& s);
};
}}} // namespace qpid::sys::ssl
diff --git a/cpp/src/qpid/sys/ssl/SslIo.cpp b/cpp/src/qpid/sys/ssl/SslIo.cpp
index bbfb703170..92e51a2234 100644
--- a/cpp/src/qpid/sys/ssl/SslIo.cpp
+++ b/cpp/src/qpid/sys/ssl/SslIo.cpp
@@ -68,32 +68,28 @@ __thread int64_t threadMaxIoTimeNs = 2 * 1000000; // start at 2ms
* Asynch Acceptor
*/
-template <class T>
-SslAcceptorTmpl<T>::SslAcceptorTmpl(const T& s, Callback callback) :
+SslAcceptor::SslAcceptor(const Socket& s, Callback callback) :
acceptedCallback(callback),
- handle(s, boost::bind(&SslAcceptorTmpl<T>::readable, this, _1), 0, 0),
+ handle(s, boost::bind(&SslAcceptor::readable, this, _1), 0, 0),
socket(s) {
s.setNonblocking();
ignoreSigpipe();
}
-template <class T>
-SslAcceptorTmpl<T>::~SslAcceptorTmpl()
+SslAcceptor::~SslAcceptor()
{
handle.stopWatch();
}
-template <class T>
-void SslAcceptorTmpl<T>::start(Poller::shared_ptr poller) {
+void SslAcceptor::start(Poller::shared_ptr poller) {
handle.startWatch(poller);
}
/*
* We keep on accepting as long as there is something to accept
*/
-template <class T>
-void SslAcceptorTmpl<T>::readable(DispatchHandle& h) {
+void SslAcceptor::readable(DispatchHandle& h) {
Socket* s;
do {
errno = 0;
@@ -114,10 +110,6 @@ void SslAcceptorTmpl<T>::readable(DispatchHandle& h) {
h.rewatch();
}
-// Explicitly instantiate the templates we need
-template class SslAcceptorTmpl<SslSocket>;
-template class SslAcceptorTmpl<SslMuxSocket>;
-
/*
* Asynch Connector
*/
@@ -134,13 +126,14 @@ SslConnector::SslConnector(const SslSocket& s,
boost::bind(&SslConnector::connComplete, this, _1)),
connCallback(connCb),
failCallback(failCb),
- socket(s)
+ socket(s),
+ sa(hostname, port)
{
//TODO: would be better for connect to be performed on a
//non-blocking socket, but that doesn't work at present so connect
//blocks until complete
try {
- socket.connect(hostname, port);
+ socket.connect(sa);
socket.setNonblocking();
startWatch(poller);
} catch(std::exception& e) {
diff --git a/cpp/src/qpid/sys/ssl/SslIo.h b/cpp/src/qpid/sys/ssl/SslIo.h
index f3112bfa65..a72cd7c76c 100644
--- a/cpp/src/qpid/sys/ssl/SslIo.h
+++ b/cpp/src/qpid/sys/ssl/SslIo.h
@@ -21,8 +21,10 @@
*
*/
+#include <qpid/sys/AsynchIO.h>
#include "qpid/sys/DispatchHandle.h"
#include "qpid/sys/SecuritySettings.h"
+#include "qpid/sys/SocketAddress.h"
#include <boost/function.hpp>
#include <boost/shared_array.hpp>
@@ -41,19 +43,18 @@ class SslSocket;
* Asynchronous ssl acceptor: accepts connections then does a callback
* with the accepted fd
*/
-template <class T>
-class SslAcceptorTmpl {
+class SslAcceptor {
public:
typedef boost::function1<void, const Socket&> Callback;
private:
Callback acceptedCallback;
qpid::sys::DispatchHandle handle;
- const T& socket;
+ const Socket& socket;
public:
- SslAcceptorTmpl(const T& s, Callback callback);
- ~SslAcceptorTmpl();
+ SslAcceptor(const Socket& s, Callback callback);
+ ~SslAcceptor();
void start(qpid::sys::Poller::shared_ptr poller);
private:
@@ -73,6 +74,7 @@ private:
ConnectedCallback connCallback;
FailedCallback failCallback;
const SslSocket& socket;
+ SocketAddress sa;
public:
SslConnector(const SslSocket& socket,
@@ -87,23 +89,6 @@ private:
void failure(int, std::string);
};
-struct SslIOBufferBase {
- char* bytes;
- int32_t byteCount;
- int32_t dataStart;
- int32_t dataCount;
-
- SslIOBufferBase(char* const b, const int32_t s) :
- bytes(b),
- byteCount(s),
- dataStart(0),
- dataCount(0)
- {}
-
- virtual ~SslIOBufferBase()
- {}
-};
-
/*
* Asychronous reader/writer:
* Reader accepts buffers to read into; reads into the provided buffers
@@ -116,18 +101,8 @@ struct SslIOBufferBase {
* The class is implemented in terms of DispatchHandle to allow it to be deleted by deleting
* the contained DispatchHandle
*/
-class SslIO : private qpid::sys::DispatchHandle {
+class SslIO : public AsynchIO, private qpid::sys::DispatchHandle {
public:
- typedef SslIOBufferBase BufferBase;
-
- typedef boost::function2<void, SslIO&, BufferBase*> ReadCallback;
- typedef boost::function1<void, SslIO&> EofCallback;
- typedef boost::function1<void, SslIO&> DisconnectCallback;
- typedef boost::function2<void, SslIO&, const SslSocket&> ClosedCallback;
- typedef boost::function1<void, SslIO&> BuffersEmptyCallback;
- typedef boost::function1<void, SslIO&> IdleCallback;
- typedef boost::function1<void, SslIO&> RequestCallback;
-
SslIO(const SslSocket& s,
ReadCallback rCb, EofCallback eofCb, DisconnectCallback disCb,
ClosedCallback cCb = 0, BuffersEmptyCallback eCb = 0, IdleCallback iCb = 0);
@@ -153,17 +128,6 @@ private:
volatile bool writePending;
public:
- /*
- * Size of IO buffers - this is the maximum possible frame size + 1
- */
- const static uint32_t MaxBufferSize = 65536;
-
- /*
- * Number of IO buffers allocated - I think the code can only use 2 -
- * 1 for reading and 1 for writing, allocate 4 for safety
- */
- const static uint32_t BufferCount = 4;
-
void queueForDeletion();
void start(qpid::sys::Poller::shared_ptr poller);
@@ -174,6 +138,8 @@ public:
void notifyPendingWrite();
void queueWriteClose();
bool writeQueueEmpty() { return writeQueue.empty(); }
+ void startReading() {};
+ void stopReading() {};
void requestCallback(RequestCallback);
BufferBase* getQueuedBuffer();
diff --git a/cpp/src/qpid/sys/ssl/SslSocket.cpp b/cpp/src/qpid/sys/ssl/SslSocket.cpp
index 0568ed8350..6b6f326492 100644
--- a/cpp/src/qpid/sys/ssl/SslSocket.cpp
+++ b/cpp/src/qpid/sys/ssl/SslSocket.cpp
@@ -20,6 +20,7 @@
*/
#include "qpid/sys/ssl/SslSocket.h"
+#include "qpid/sys/SocketAddress.h"
#include "qpid/sys/ssl/check.h"
#include "qpid/sys/ssl/util.h"
#include "qpid/Exception.h"
@@ -81,11 +82,15 @@ std::string getDomainFromSubject(std::string subject)
}
}
-SslSocket::SslSocket() : socket(0), prototype(0)
+SslSocket::SslSocket(const std::string& certName, bool clientAuth) :
+ nssSocket(0), certname(certName), prototype(0)
{
- impl->fd = ::socket (PF_INET, SOCK_STREAM, 0);
- if (impl->fd < 0) throw QPID_POSIX_ERROR(errno);
- socket = SSL_ImportFD(0, PR_ImportTCPSocket(impl->fd));
+ //configure prototype socket:
+ prototype = SSL_ImportFD(0, PR_NewTCPSocket());
+ if (clientAuth) {
+ NSS_CHECK(SSL_OptionSet(prototype, SSL_REQUEST_CERTIFICATE, PR_TRUE));
+ NSS_CHECK(SSL_OptionSet(prototype, SSL_REQUIRE_CERTIFICATE, PR_TRUE));
+ }
}
/**
@@ -93,25 +98,41 @@ SslSocket::SslSocket() : socket(0), prototype(0)
* returned from accept. Because we use posix accept rather than
* PR_Accept, we have to reset the handshake.
*/
-SslSocket::SslSocket(IOHandlePrivate* ioph, PRFileDesc* model) : Socket(ioph), socket(0), prototype(0)
+SslSocket::SslSocket(IOHandlePrivate* ioph, PRFileDesc* model) : Socket(ioph), nssSocket(0), prototype(0)
{
- socket = SSL_ImportFD(model, PR_ImportTCPSocket(impl->fd));
- NSS_CHECK(SSL_ResetHandshake(socket, true));
+ nssSocket = SSL_ImportFD(model, PR_ImportTCPSocket(impl->fd));
+ NSS_CHECK(SSL_ResetHandshake(nssSocket, PR_TRUE));
}
void SslSocket::setNonblocking() const
{
+ if (!nssSocket) {
+ Socket::setNonblocking();
+ return;
+ }
PRSocketOptionData option;
option.option = PR_SockOpt_Nonblocking;
option.value.non_blocking = true;
- PR_SetSocketOption(socket, &option);
+ PR_SetSocketOption(nssSocket, &option);
+}
+
+void SslSocket::setTcpNoDelay() const
+{
+ if (!nssSocket) {
+ Socket::setTcpNoDelay();
+ return;
+ }
+ PRSocketOptionData option;
+ option.option = PR_SockOpt_NoDelay;
+ option.value.no_delay = true;
+ PR_SetSocketOption(nssSocket, &option);
}
-void SslSocket::connect(const std::string& host, const std::string& port) const
+void SslSocket::connect(const SocketAddress& addr) const
{
- std::stringstream namestream;
- namestream << host << ":" << port;
- connectname = namestream.str();
+ Socket::connect(addr);
+
+ nssSocket = SSL_ImportFD(0, PR_ImportTCPSocket(impl->fd));
void* arg;
// Use the connection's cert-name if it has one; else use global cert-name
@@ -122,41 +143,31 @@ void SslSocket::connect(const std::string& host, const std::string& port) const
} else {
arg = const_cast<char*>(SslOptions::global.certName.c_str());
}
- NSS_CHECK(SSL_GetClientAuthDataHook(socket, NSS_GetClientAuthData, arg));
- NSS_CHECK(SSL_SetURL(socket, host.data()));
-
- char hostBuffer[PR_NETDB_BUF_SIZE];
- PRHostEnt hostEntry;
- PR_CHECK(PR_GetHostByName(host.data(), hostBuffer, PR_NETDB_BUF_SIZE, &hostEntry));
- PRNetAddr address;
- int value = PR_EnumerateHostEnt(0, &hostEntry, boost::lexical_cast<PRUint16>(port), &address);
- if (value < 0) {
- throw Exception(QPID_MSG("Error getting address for host: " << ErrorString()));
- } else if (value == 0) {
- throw Exception(QPID_MSG("Could not resolve address for host."));
- }
- PR_CHECK(PR_Connect(socket, &address, PR_INTERVAL_NO_TIMEOUT));
- NSS_CHECK(SSL_ForceHandshake(socket));
+ NSS_CHECK(SSL_GetClientAuthDataHook(nssSocket, NSS_GetClientAuthData, arg));
+
+ url = addr.getHost();
+ NSS_CHECK(SSL_SetURL(nssSocket, url.data()));
+
+ NSS_CHECK(SSL_ResetHandshake(nssSocket, PR_FALSE));
+ NSS_CHECK(SSL_ForceHandshake(nssSocket));
}
void SslSocket::close() const
{
+ if (!nssSocket) {
+ Socket::close();
+ return;
+ }
if (impl->fd > 0) {
- PR_Close(socket);
+ PR_Close(nssSocket);
impl->fd = -1;
}
}
-int SslSocket::listen(uint16_t port, int backlog, const std::string& certName, bool clientAuth) const
+int SslSocket::listen(const SocketAddress& sa, int backlog) const
{
- //configure prototype socket:
- prototype = SSL_ImportFD(0, PR_NewTCPSocket());
- if (clientAuth) {
- NSS_CHECK(SSL_OptionSet(prototype, SSL_REQUEST_CERTIFICATE, PR_TRUE));
- NSS_CHECK(SSL_OptionSet(prototype, SSL_REQUIRE_CERTIFICATE, PR_TRUE));
- }
-
//get certificate and key (is this the correct way?)
+ std::string certName( (certname == "") ? "localhost.localdomain" : certname);
CERTCertificate *cert = PK11_FindCertFromNickname(const_cast<char*>(certName.c_str()), 0);
if (!cert) throw Exception(QPID_MSG("Failed to load certificate '" << certName << "'"));
SECKEYPrivateKey *key = PK11_FindKeyByAnyCert(cert, 0);
@@ -165,24 +176,7 @@ int SslSocket::listen(uint16_t port, int backlog, const std::string& certName, b
SECKEY_DestroyPrivateKey(key);
CERT_DestroyCertificate(cert);
- //bind and listen
- const int& socket = impl->fd;
- int yes=1;
- QPID_POSIX_CHECK(setsockopt(socket,SOL_SOCKET,SO_REUSEADDR,&yes,sizeof(yes)));
- struct sockaddr_in name;
- name.sin_family = AF_INET;
- name.sin_port = htons(port);
- name.sin_addr.s_addr = 0;
- if (::bind(socket, (struct sockaddr*)&name, sizeof(name)) < 0)
- throw Exception(QPID_MSG("Can't bind to port " << port << ": " << strError(errno)));
- if (::listen(socket, backlog) < 0)
- throw Exception(QPID_MSG("Can't listen on port " << port << ": " << strError(errno)));
-
- socklen_t namelen = sizeof(name);
- if (::getsockname(socket, (struct sockaddr*)&name, &namelen) < 0)
- throw QPID_POSIX_ERROR(errno);
-
- return ntohs(name.sin_port);
+ return Socket::listen(sa, backlog);
}
SslSocket* SslSocket::accept() const
@@ -274,6 +268,11 @@ static bool isSslStream(int afd) {
return isSSL2Handshake || isSSL3Handshake;
}
+SslMuxSocket::SslMuxSocket(const std::string& certName, bool clientAuth) :
+ SslSocket(certName, clientAuth)
+{
+}
+
Socket* SslMuxSocket::accept() const
{
int afd = ::accept(impl->fd, 0, 0);
@@ -295,20 +294,12 @@ Socket* SslMuxSocket::accept() const
int SslSocket::read(void *buf, size_t count) const
{
- return PR_Read(socket, buf, count);
+ return PR_Read(nssSocket, buf, count);
}
int SslSocket::write(const void *buf, size_t count) const
{
- return PR_Write(socket, buf, count);
-}
-
-void SslSocket::setTcpNoDelay() const
-{
- PRSocketOptionData option;
- option.option = PR_SockOpt_NoDelay;
- option.value.no_delay = true;
- PR_SetSocketOption(socket, &option);
+ return PR_Write(nssSocket, buf, count);
}
void SslSocket::setCertName(const std::string& name)
@@ -324,7 +315,7 @@ int SslSocket::getKeyLen() const
int keySize = 0;
SECStatus rc;
- rc = SSL_SecurityStatus( socket,
+ rc = SSL_SecurityStatus( nssSocket,
&enabled,
NULL,
NULL,
@@ -339,7 +330,7 @@ int SslSocket::getKeyLen() const
std::string SslSocket::getClientAuthId() const
{
std::string authId;
- CERTCertificate* cert = SSL_PeerCertificate(socket);
+ CERTCertificate* cert = SSL_PeerCertificate(nssSocket);
if (cert) {
authId = CERT_GetCommonName(&(cert->subject));
/*
diff --git a/cpp/src/qpid/sys/ssl/SslSocket.h b/cpp/src/qpid/sys/ssl/SslSocket.h
index 0f7e74f977..1b5424cfeb 100644
--- a/cpp/src/qpid/sys/ssl/SslSocket.h
+++ b/cpp/src/qpid/sys/ssl/SslSocket.h
@@ -40,8 +40,10 @@ namespace ssl {
class SslSocket : public qpid::sys::Socket
{
public:
- /** Create a socket wrapper for descriptor. */
- SslSocket();
+ /** Create a socket wrapper for descriptor.
+ *@param certName name of certificate to use to identify the socket
+ */
+ SslSocket(const std::string& certName = "", bool clientAuth = false);
/** Set socket non blocking */
void setNonblocking() const;
@@ -54,17 +56,16 @@ public:
* NSSInit().*/
void setCertName(const std::string& certName);
- void connect(const std::string& host, const std::string& port) const;
+ void connect(const SocketAddress&) const;
void close() const;
/** Bind to a port and start listening.
*@param port 0 means choose an available port.
*@param backlog maximum number of pending connections.
- *@param certName name of certificate to use to identify the server
*@return The bound port.
*/
- int listen(uint16_t port = 0, int backlog = 10, const std::string& certName = "localhost.localdomain", bool clientAuth = false) const;
+ int listen(const SocketAddress&, int backlog = 10) const;
/**
* Accept a connection from a socket that is already listening
@@ -80,9 +81,9 @@ public:
std::string getClientAuthId() const;
protected:
- mutable std::string connectname;
- mutable PRFileDesc* socket;
+ mutable PRFileDesc* nssSocket;
std::string certname;
+ mutable std::string url;
/**
* 'model' socket, with configuration to use when importing
@@ -98,6 +99,7 @@ protected:
class SslMuxSocket : public SslSocket
{
public:
+ SslMuxSocket(const std::string& certName = "", bool clientAuth = false);
Socket* accept() const;
};
diff --git a/cpp/src/qpid/sys/windows/AsynchIO.cpp b/cpp/src/qpid/sys/windows/AsynchIO.cpp
index 618d0e14c7..9fdf89c83b 100644
--- a/cpp/src/qpid/sys/windows/AsynchIO.cpp
+++ b/cpp/src/qpid/sys/windows/AsynchIO.cpp
@@ -24,6 +24,7 @@
#include "qpid/sys/AsynchIO.h"
#include "qpid/sys/Mutex.h"
#include "qpid/sys/Socket.h"
+#include "qpid/sys/SocketAddress.h"
#include "qpid/sys/Poller.h"
#include "qpid/sys/Thread.h"
#include "qpid/sys/Time.h"
@@ -195,7 +196,7 @@ AsynchConnector::AsynchConnector(const Socket& sock,
void AsynchConnector::start(Poller::shared_ptr)
{
try {
- socket.connect(hostname, port);
+ socket.connect(SocketAddress(hostname, port));
socket.setNonblocking();
connCallback(socket);
} catch(std::exception& e) {
diff --git a/cpp/src/qpid/sys/windows/Socket.cpp b/cpp/src/qpid/sys/windows/Socket.cpp
index 17e3212a46..0c74b3a725 100644
--- a/cpp/src/qpid/sys/windows/Socket.cpp
+++ b/cpp/src/qpid/sys/windows/Socket.cpp
@@ -160,12 +160,6 @@ void Socket::setNonblocking() const {
QPID_WINSOCK_CHECK(ioctlsocket(impl->fd, FIONBIO, &nonblock));
}
-void Socket::connect(const std::string& host, const std::string& port) const
-{
- SocketAddress sa(host, port);
- connect(sa);
-}
-
void
Socket::connect(const SocketAddress& addr) const
{
@@ -209,12 +203,6 @@ int Socket::read(void *buf, size_t count) const
return received;
}
-int Socket::listen(const std::string& host, const std::string& port, int backlog) const
-{
- SocketAddress sa(host, port);
- return listen(sa, backlog);
-}
-
int Socket::listen(const SocketAddress& addr, int backlog) const
{
createSocket(addr);