diff options
Diffstat (limited to 'cpp/src/qpid/client')
-rw-r--r-- | cpp/src/qpid/client/ConnectionHandler.cpp | 56 | ||||
-rw-r--r-- | cpp/src/qpid/client/ConnectionHandler.h | 7 | ||||
-rw-r--r-- | cpp/src/qpid/client/ConnectionImpl.cpp | 8 | ||||
-rw-r--r-- | cpp/src/qpid/client/ConnectionSettings.cpp | 9 | ||||
-rw-r--r-- | cpp/src/qpid/client/ConnectionSettings.h | 17 | ||||
-rw-r--r-- | cpp/src/qpid/client/Connector.cpp | 197 | ||||
-rw-r--r-- | cpp/src/qpid/client/Connector.h | 8 | ||||
-rw-r--r-- | cpp/src/qpid/client/RdmaConnector.cpp | 182 | ||||
-rw-r--r-- | cpp/src/qpid/client/Sasl.h | 52 | ||||
-rw-r--r-- | cpp/src/qpid/client/SaslFactory.cpp | 345 | ||||
-rw-r--r-- | cpp/src/qpid/client/SaslFactory.h | 48 |
11 files changed, 709 insertions, 220 deletions
diff --git a/cpp/src/qpid/client/ConnectionHandler.cpp b/cpp/src/qpid/client/ConnectionHandler.cpp index db5d006a17..2a070ebcff 100644 --- a/cpp/src/qpid/client/ConnectionHandler.cpp +++ b/cpp/src/qpid/client/ConnectionHandler.cpp @@ -21,16 +21,18 @@ #include "ConnectionHandler.h" -#include "qpid/log/Statement.h" +#include "SaslFactory.h" #include "qpid/framing/amqp_framing.h" #include "qpid/framing/all_method_bodies.h" #include "qpid/framing/ClientInvoker.h" #include "qpid/framing/reply_exceptions.h" #include "qpid/log/Helpers.h" +#include "qpid/log/Statement.h" using namespace qpid::client; using namespace qpid::framing; using namespace qpid::framing::connection; +using qpid::sys::SecurityLayer; namespace { const std::string OK("OK"); @@ -146,18 +148,50 @@ void ConnectionHandler::fail(const std::string& message) setState(FAILED); } -void ConnectionHandler::start(const FieldTable& /*serverProps*/, const Array& /*mechanisms*/, const Array& /*locales*/) +namespace { +std::string SPACE(" "); +} + +void ConnectionHandler::start(const FieldTable& /*serverProps*/, const Array& mechanisms, const Array& /*locales*/) { checkState(NOT_STARTED, INVALID_STATE_START); setState(NEGOTIATING); - //TODO: verify that desired mechanism and locale are supported - string response = ((char)0) + username + ((char)0) + password; - proxy.startOk(properties, mechanism, response, locale); + sasl = SaslFactory::getInstance().create(*this); + + std::string mechlist; + bool chosenMechanismSupported = mechanism.empty(); + for (Array::const_iterator i = mechanisms.begin(); i != mechanisms.end(); ++i) { + if (!mechanism.empty() && mechanism == (*i)->get<std::string>()) { + chosenMechanismSupported = true; + mechlist = (*i)->get<std::string>() + SPACE + mechlist; + } else { + if (i != mechanisms.begin()) mechlist += SPACE; + mechlist += (*i)->get<std::string>(); + } + } + + if (!chosenMechanismSupported) { + fail("Selected mechanism not supported: " + mechanism); + } + + if (sasl.get()) { + string response = sasl->start(mechanism.empty() ? mechlist : mechanism); + proxy.startOk(properties, sasl->getMechanism(), response, locale); + } else { + //TODO: verify that desired mechanism and locale are supported + string response = ((char)0) + username + ((char)0) + password; + proxy.startOk(properties, mechanism, response, locale); + } } -void ConnectionHandler::secure(const std::string& /*challenge*/) +void ConnectionHandler::secure(const std::string& challenge) { - throw NotImplementedException("Challenge-response cycle not yet implemented in client"); + if (sasl.get()) { + string response = sasl->step(challenge); + proxy.secureOk(response); + } else { + throw NotImplementedException("Challenge-response cycle not yet implemented in client"); + } } void ConnectionHandler::tune(uint16_t maxChannelsProposed, uint16_t maxFrameSizeProposed, @@ -179,6 +213,9 @@ void ConnectionHandler::openOk ( const Array& knownBrokers ) framing::Array::ValueVector::const_iterator i; for ( i = knownBrokers.begin(); i != knownBrokers.end(); ++i ) knownBrokersUrls.push_back(Url((*i)->get<std::string>())); + if (sasl.get()) { + securityLayer = sasl->getSecurityLayer(maxFrameSize); + } setState(OPEN); QPID_LOG(debug, "Known-brokers for connection: " << log::formatList(knownBrokersUrls)); } @@ -224,3 +261,8 @@ bool ConnectionHandler::isClosed() const } bool ConnectionHandler::isClosing() const { return getState() == CLOSING; } + +std::auto_ptr<qpid::sys::SecurityLayer> ConnectionHandler::getSecurityLayer() +{ + return securityLayer; +} diff --git a/cpp/src/qpid/client/ConnectionHandler.h b/cpp/src/qpid/client/ConnectionHandler.h index 12323684a5..ec9278626f 100644 --- a/cpp/src/qpid/client/ConnectionHandler.h +++ b/cpp/src/qpid/client/ConnectionHandler.h @@ -23,6 +23,7 @@ #include "ChainableFrameHandler.h" #include "ConnectionSettings.h" +#include "Sasl.h" #include "StateManager.h" #include "qpid/framing/AMQMethodBody.h" #include "qpid/framing/AMQP_HighestVersion.h" @@ -33,7 +34,9 @@ #include "qpid/framing/FieldTable.h" #include "qpid/framing/FrameHandler.h" #include "qpid/framing/InputHandler.h" +#include "qpid/sys/SecurityLayer.h" #include "qpid/Url.h" +#include <memory> namespace qpid { namespace client { @@ -64,6 +67,8 @@ class ConnectionHandler : private StateManager, framing::ProtocolVersion version; framing::Array capabilities; framing::FieldTable properties; + std::auto_ptr<Sasl> sasl; + std::auto_ptr<qpid::sys::SecurityLayer> securityLayer; void checkState(STATES s, const std::string& msg); @@ -103,6 +108,8 @@ public: bool isClosed() const; bool isClosing() const; + std::auto_ptr<qpid::sys::SecurityLayer> getSecurityLayer(); + CloseListener onClose; ErrorListener onError; diff --git a/cpp/src/qpid/client/ConnectionImpl.cpp b/cpp/src/qpid/client/ConnectionImpl.cpp index 0d7ffa0288..aa9eeb7489 100644 --- a/cpp/src/qpid/client/ConnectionImpl.cpp +++ b/cpp/src/qpid/client/ConnectionImpl.cpp @@ -110,6 +110,14 @@ void ConnectionImpl::open() connector->connect(host, port); connector->init(); handler.waitForOpen(); + //enable security layer if one has been negotiated: + std::auto_ptr<SecurityLayer> securityLayer = handler.getSecurityLayer(); + if (securityLayer.get()) { + QPID_LOG(debug, "Activating security layer"); + connector->activateSecurityLayer(securityLayer); + } else { + QPID_LOG(debug, "No security layer in place"); + } failover.reset(new FailoverListener(shared_from_this(), handler.knownBrokersUrls)); } diff --git a/cpp/src/qpid/client/ConnectionSettings.cpp b/cpp/src/qpid/client/ConnectionSettings.cpp index f5fc62dad2..5851917da6 100644 --- a/cpp/src/qpid/client/ConnectionSettings.cpp +++ b/cpp/src/qpid/client/ConnectionSettings.cpp @@ -22,6 +22,7 @@ #include "qpid/log/Logger.h" #include "qpid/sys/Socket.h" +#include "qpid/Version.h" namespace qpid { namespace client { @@ -30,15 +31,15 @@ ConnectionSettings::ConnectionSettings() : protocol("tcp"), host("localhost"), port(TcpAddress::DEFAULT_PORT), - username("guest"), - password("guest"), - mechanism("PLAIN"), locale("en_US"), heartbeat(0), maxChannels(32767), maxFrameSize(65535), bounds(2), - tcpNoDelay(false) + tcpNoDelay(false), + service(qpid::saslName), + minSsf(0), + maxSsf(256) {} ConnectionSettings::~ConnectionSettings() {} diff --git a/cpp/src/qpid/client/ConnectionSettings.h b/cpp/src/qpid/client/ConnectionSettings.h index 1b994a6da3..c7725e19f0 100644 --- a/cpp/src/qpid/client/ConnectionSettings.h +++ b/cpp/src/qpid/client/ConnectionSettings.h @@ -71,7 +71,8 @@ struct ConnectionSettings { std::string virtualhost; /** - * The username to use when authenticating the connection. + * The username to use when authenticating the connection. If not + * specified the current users login is used if available. */ std::string username; /** @@ -111,6 +112,20 @@ struct ConnectionSettings { * If true, TCP_NODELAY will be set for the connection. */ bool tcpNoDelay; + /** + * SASL service name + */ + std::string service; + /** + * Minimum acceptable strength of any SASL negotiated security + * layer. 0 means no security layer required. + */ + uint minSsf; + /** + * Maximum acceptable strength of any SASL negotiated security + * layer. 0 means no security layer allowed. + */ + uint maxSsf; }; }} // namespace qpid::client diff --git a/cpp/src/qpid/client/Connector.cpp b/cpp/src/qpid/client/Connector.cpp index bef98863a1..0e11b920e1 100644 --- a/cpp/src/qpid/client/Connector.cpp +++ b/cpp/src/qpid/client/Connector.cpp @@ -24,15 +24,18 @@ #include "ConnectionImpl.h" #include "ConnectionSettings.h" #include "qpid/log/Statement.h" +#include "qpid/sys/Codec.h" #include "qpid/sys/Time.h" #include "qpid/framing/AMQFrame.h" #include "qpid/sys/AsynchIO.h" #include "qpid/sys/Dispatcher.h" #include "qpid/sys/Poller.h" +#include "qpid/sys/SecurityLayer.h" #include "qpid/Msg.h" #include <iostream> #include <map> +#include <deque> #include <boost/bind.hpp> #include <boost/format.hpp> #include <boost/weak_ptr.hpp> @@ -74,39 +77,19 @@ void Connector::registerFactory(const std::string& proto, Factory* connectorFact theProtocolRegistry()[proto] = connectorFactory; } -class TCPConnector : public Connector, private sys::Runnable +class TCPConnector : public Connector, public sys::Codec, private sys::Runnable { + typedef std::deque<framing::AMQFrame> Frames; struct Buff; - /** Batch up frames for writing to aio. */ - class Writer : public framing::FrameHandler { - typedef sys::AsynchIOBufferBase BufferBase; - typedef std::vector<framing::AMQFrame> Frames; - - const uint16_t maxFrameSize; - sys::Mutex lock; - sys::AsynchIO* aio; - BufferBase* buffer; - Frames frames; - size_t lastEof; // Position after last EOF in frames - framing::Buffer encode; - size_t framesEncoded; - std::string identifier; - Bounds* bounds; - - void writeOne(); - void newBuffer(); + const uint16_t maxFrameSize; - public: - - Writer(uint16_t maxFrameSize, Bounds*); - ~Writer(); - void init(std::string id, sys::AsynchIO*); - void handle(framing::AMQFrame&); - void write(sys::AsynchIO&); - }; + sys::Mutex lock; + Frames frames; // Outgoing frame queue + size_t lastEof; // Position after last EOF in frames + uint64_t currentSize; + Bounds* bounds; - const uint16_t maxFrameSize; framing::ProtocolVersion version; bool initiated; @@ -119,14 +102,14 @@ class TCPConnector : public Connector, private sys::Runnable framing::InitiationHandler* initialiser; framing::OutputHandler* output; - Writer writer; - sys::Thread receiver; sys::Socket socket; sys::AsynchIO* aio; + std::string identifier; boost::shared_ptr<sys::Poller> poller; + std::auto_ptr<qpid::sys::SecurityLayer> securityLayer; ~TCPConnector(); @@ -139,8 +122,6 @@ class TCPConnector : public Connector, private sys::Runnable void writeDataBlock(const framing::AMQDataBlock& data); void eof(qpid::sys::AsynchIO&); - std::string identifier; - boost::weak_ptr<ConnectionImpl> impl; void connect(const std::string& host, int port); @@ -153,6 +134,12 @@ class TCPConnector : public Connector, private sys::Runnable sys::ShutdownHandler* getShutdownHandler() const; framing::OutputHandler* getOutputHandler(); const std::string& getIdentifier() const; + void activateSecurityLayer(std::auto_ptr<qpid::sys::SecurityLayer>); + + size_t decode(const char* buffer, size_t size); + size_t encode(const char* buffer, size_t size); + bool canEncode(); + public: TCPConnector(framing::ProtocolVersion pVersion, @@ -177,12 +164,14 @@ TCPConnector::TCPConnector(ProtocolVersion ver, const ConnectionSettings& settings, ConnectionImpl* cimpl) : maxFrameSize(settings.maxFrameSize), + lastEof(0), + currentSize(0), + bounds(cimpl), version(ver), initiated(false), closed(true), joined(true), shutdownHandler(0), - writer(maxFrameSize, cimpl), aio(0), impl(cimpl->shared_from_this()) { @@ -214,7 +203,6 @@ void TCPConnector::connect(const std::string& host, int port){ 0, // closed 0, // nobuffs boost::bind(&TCPConnector::writebuff, this, _1)); - writer.init(identifier, aio); } void TCPConnector::init(){ @@ -266,7 +254,21 @@ const std::string& TCPConnector::getIdentifier() const { } void TCPConnector::send(AMQFrame& frame) { - writer.handle(frame); + bool notifyWrite = false; + { + Mutex::ScopedLock l(lock); + frames.push_back(frame); + //only ask to write if this is the end of a frameset or if we + //already have a buffers worth of data + currentSize += frame.encodedSize(); + if (frame.getEof()) { + lastEof = frames.size(); + notifyWrite = true; + } else { + notifyWrite = (currentSize >= maxFrameSize); + } + } + if (notifyWrite) aio->notifyPendingWrite(); } void TCPConnector::handleClosed() { @@ -279,70 +281,70 @@ struct TCPConnector::Buff : public AsynchIO::BufferBase { ~Buff() { delete [] bytes;} }; -TCPConnector::Writer::Writer(uint16_t s, Bounds* b) : maxFrameSize(s), aio(0), buffer(0), lastEof(0), bounds(b) +void TCPConnector::writebuff(AsynchIO& /*aio*/) { -} - -TCPConnector::Writer::~Writer() { delete buffer; } + Codec* codec = securityLayer.get() ? (Codec*) securityLayer.get() : (Codec*) this; + if (codec->canEncode()) { + std::auto_ptr<AsynchIO::BufferBase> buffer = std::auto_ptr<AsynchIO::BufferBase>(aio->getQueuedBuffer()); + if (!buffer.get()) buffer = std::auto_ptr<AsynchIO::BufferBase>(new Buff(maxFrameSize)); + + size_t encoded = codec->encode(buffer->bytes, buffer->byteCount); -void TCPConnector::Writer::init(std::string id, sys::AsynchIO* a) { - Mutex::ScopedLock l(lock); - identifier = id; - aio = a; - newBuffer(); -} -void TCPConnector::Writer::handle(framing::AMQFrame& frame) { - Mutex::ScopedLock l(lock); - frames.push_back(frame); - //only try to write if this is the end of a frameset or if we - //already have a buffers worth of data - if (frame.getEof() || (bounds && bounds->getCurrentSize() >= maxFrameSize)) { - lastEof = frames.size(); - aio->notifyPendingWrite(); + buffer->dataStart = 0; + buffer->dataCount = encoded; + aio->queueWrite(buffer.release()); } - QPID_LOG(trace, "SENT " << identifier << ": " << frame); -} - -void TCPConnector::Writer::writeOne() { - assert(buffer); - framesEncoded = 0; - - buffer->dataStart = 0; - buffer->dataCount = encode.getPosition(); - aio->queueWrite(buffer); - newBuffer(); } -void TCPConnector::Writer::newBuffer() { - buffer = aio->getQueuedBuffer(); - if (!buffer) buffer = new Buff(maxFrameSize); - encode = framing::Buffer(buffer->bytes, buffer->byteCount); - framesEncoded = 0; +// Called in IO thread. +bool TCPConnector::canEncode() +{ + Mutex::ScopedLock l(lock); + //have at least one full frameset or a whole buffers worth of data + return lastEof || currentSize >= maxFrameSize; } // Called in IO thread. -void TCPConnector::Writer::write(sys::AsynchIO&) { - Mutex::ScopedLock l(lock); - assert(buffer); +size_t TCPConnector::encode(const char* buffer, size_t size) +{ + framing::Buffer out(const_cast<char*>(buffer), size); size_t bytesWritten(0); - for (size_t i = 0; i < lastEof; ++i) { - AMQFrame& frame = frames[i]; - uint32_t size = frame.encodedSize(); - if (size > encode.available()) writeOne(); - assert(size <= encode.available()); - frame.encode(encode); - ++framesEncoded; - bytesWritten += size; + { + Mutex::ScopedLock l(lock); + while (!frames.empty() && out.available() >= frames.front().encodedSize() ) { + frames.front().encode(out); + QPID_LOG(trace, "SENT " << identifier << ": " << frames.front()); + frames.pop_front(); + if (lastEof) --lastEof; + } + bytesWritten = size - out.available(); + currentSize -= bytesWritten; } - frames.erase(frames.begin(), frames.begin()+lastEof); - lastEof = 0; if (bounds) bounds->reduce(bytesWritten); - if (encode.getPosition() > 0) writeOne(); + return bytesWritten; } -bool TCPConnector::readbuff(AsynchIO& aio, AsynchIO::BufferBase* buff) { - framing::Buffer in(buff->bytes+buff->dataStart, buff->dataCount); +bool TCPConnector::readbuff(AsynchIO& aio, AsynchIO::BufferBase* buff) +{ + Codec* codec = securityLayer.get() ? (Codec*) securityLayer.get() : (Codec*) this; + int32_t decoded = codec->decode(buff->bytes+buff->dataStart, buff->dataCount); + // TODO: unreading needs to go away, and when we can cope + // with multiple sub-buffers in the general buffer scheme, it will + if (decoded < buff->dataCount) { + // Adjust buffer for used bytes and then "unread them" + buff->dataStart += decoded; + buff->dataCount -= decoded; + aio.unread(buff); + } else { + // Give whole buffer back to aio subsystem + aio.queueReadBuffer(buff); + } + return true; +} +size_t TCPConnector::decode(const char* buffer, size_t size) +{ + framing::Buffer in(const_cast<char*>(buffer), size); if (!initiated) { framing::ProtocolInitiation protocolInit; if (protocolInit.decode(in)) { @@ -356,22 +358,7 @@ bool TCPConnector::readbuff(AsynchIO& aio, AsynchIO::BufferBase* buff) { QPID_LOG(trace, "RECV " << identifier << ": " << frame); input->received(frame); } - // TODO: unreading needs to go away, and when we can cope - // with multiple sub-buffers in the general buffer scheme, it will - if (in.available() != 0) { - // Adjust buffer for used bytes and then "unread them" - buff->dataStart += buff->dataCount-in.available(); - buff->dataCount = in.available(); - aio.unread(buff); - } else { - // Give whole buffer back to aio subsystem - aio.queueReadBuffer(buff); - } - return true; -} - -void TCPConnector::writebuff(AsynchIO& aio_) { - writer.write(aio_); + return size - in.available(); } void TCPConnector::writeDataBlock(const AMQDataBlock& data) { @@ -388,7 +375,7 @@ void TCPConnector::eof(AsynchIO&) { // TODO: astitcher 20070908 This version of the code can never time out, so the idle processing // will never be called -void TCPConnector::run(){ +void TCPConnector::run() { // Keep the connection impl in memory until run() completes. boost::shared_ptr<ConnectionImpl> protect = impl.lock(); assert(protect); @@ -409,5 +396,11 @@ void TCPConnector::run(){ } } +void TCPConnector::activateSecurityLayer(std::auto_ptr<qpid::sys::SecurityLayer> sl) +{ + securityLayer = sl; + securityLayer->init(this); +} + }} // namespace qpid::client diff --git a/cpp/src/qpid/client/Connector.h b/cpp/src/qpid/client/Connector.h index 5c37d95300..e23fb8875b 100644 --- a/cpp/src/qpid/client/Connector.h +++ b/cpp/src/qpid/client/Connector.h @@ -40,6 +40,11 @@ #include <boost/shared_ptr.hpp> namespace qpid { + +namespace sys { +class SecurityLayer; +} + namespace client { struct ConnectionSettings; @@ -65,6 +70,9 @@ class Connector : public framing::OutputHandler virtual sys::ShutdownHandler* getShutdownHandler() const = 0; virtual framing::OutputHandler* getOutputHandler() = 0; virtual const std::string& getIdentifier() const = 0; + + virtual void activateSecurityLayer(std::auto_ptr<qpid::sys::SecurityLayer>) {} + }; }} diff --git a/cpp/src/qpid/client/RdmaConnector.cpp b/cpp/src/qpid/client/RdmaConnector.cpp index 98fe762f31..3cc8961eea 100644 --- a/cpp/src/qpid/client/RdmaConnector.cpp +++ b/cpp/src/qpid/client/RdmaConnector.cpp @@ -29,6 +29,7 @@ #include "qpid/sys/rdma/RdmaIO.h" #include "qpid/sys/Dispatcher.h" #include "qpid/sys/Poller.h" +#include "qpid/sys/SecurityLayer.h" #include "qpid/Msg.h" #include <iostream> @@ -47,39 +48,21 @@ using namespace qpid::framing; using boost::format; using boost::str; -class RdmaConnector : public Connector, private sys::Runnable + class RdmaConnector : public Connector, public sys::Codec, private sys::Runnable { struct Buff; - /** Batch up frames for writing to aio. */ - class Writer : public framing::FrameHandler { - typedef Rdma::Buffer BufferBase; - typedef std::deque<framing::AMQFrame> Frames; - - const uint16_t maxFrameSize; - sys::Mutex lock; - Rdma::AsynchIO* aio; - BufferBase* buffer; - Frames frames; - size_t lastEof; // Position after last EOF in frames - framing::Buffer encode; - size_t framesEncoded; - std::string identifier; - Bounds* bounds; - - void writeOne(); - void newBuffer(); + typedef Rdma::Buffer BufferBase; + typedef std::deque<framing::AMQFrame> Frames; - public: - - Writer(uint16_t maxFrameSize, Bounds*); - ~Writer(); - void init(std::string id, Rdma::AsynchIO*); - void handle(framing::AMQFrame&); - void write(Rdma::AsynchIO&); - }; - const uint16_t maxFrameSize; + sys::Mutex lock; + Frames frames; + size_t lastEof; // Position after last EOF in frames + uint64_t currentSize; + Bounds* bounds; + + framing::ProtocolVersion version; bool initiated; @@ -92,12 +75,11 @@ class RdmaConnector : public Connector, private sys::Runnable framing::InitiationHandler* initialiser; framing::OutputHandler* output; - Writer writer; - sys::Thread receiver; Rdma::AsynchIO* aio; sys::Poller::shared_ptr poller; + std::auto_ptr<qpid::sys::SecurityLayer> securityLayer; ~RdmaConnector(); @@ -129,6 +111,11 @@ class RdmaConnector : public Connector, private sys::Runnable sys::ShutdownHandler* getShutdownHandler() const; framing::OutputHandler* getOutputHandler(); const std::string& getIdentifier() const; + void activateSecurityLayer(std::auto_ptr<qpid::sys::SecurityLayer>); + + size_t decode(const char* buffer, size_t size); + size_t encode(const char* buffer, size_t size); + bool canEncode(); public: RdmaConnector(framing::ProtocolVersion pVersion, @@ -155,12 +142,14 @@ RdmaConnector::RdmaConnector(ProtocolVersion ver, const ConnectionSettings& settings, ConnectionImpl* cimpl) : maxFrameSize(settings.maxFrameSize), + lastEof(0), + currentSize(0), + bounds(cimpl), version(ver), initiated(false), polling(false), joined(true), shutdownHandler(0), - writer(maxFrameSize, cimpl), aio(0), impl(cimpl) { @@ -216,7 +205,6 @@ void RdmaConnector::connected(Poller::shared_ptr poller, Rdma::Connection::intru aio->start(poller); identifier = str(format("[%1% %2%]") % ci->getLocalName() % ci->getPeerName()); - writer.init(identifier, aio); ProtocolInitiation init(version); writeDataBlock(init); } @@ -279,7 +267,21 @@ const std::string& RdmaConnector::getIdentifier() const { } void RdmaConnector::send(AMQFrame& frame) { - writer.handle(frame); + bool notifyWrite = false; + { + Mutex::ScopedLock l(lock); + frames.push_back(frame); + //only ask to write if this is the end of a frameset or if we + //already have a buffers worth of data + currentSize += frame.encodedSize(); + if (frame.getEof()) { + lastEof = frames.size(); + notifyWrite = true; + } else { + notifyWrite = (currentSize >= maxFrameSize); + } + } + if (notifyWrite) aio->notifyPendingWrite(); } void RdmaConnector::handleClosed() { @@ -287,88 +289,54 @@ void RdmaConnector::handleClosed() { shutdownHandler->shutdown(); } -RdmaConnector::Writer::Writer(uint16_t s, Bounds* b) : - maxFrameSize(s), - aio(0), - buffer(0), - lastEof(0), - bounds(b) -{ -} - -RdmaConnector::Writer::~Writer() { - if (aio) - aio->returnBuffer(buffer); -} - -void RdmaConnector::Writer::init(std::string id, Rdma::AsynchIO* a) { - Mutex::ScopedLock l(lock); - identifier = id; - aio = a; - assert(aio->bufferAvailable()); - newBuffer(); -} -void RdmaConnector::Writer::handle(framing::AMQFrame& frame) { - Mutex::ScopedLock l(lock); - frames.push_back(frame); - // Don't bother to send anything unless we're at the end of a frameset (assembly in 0-10 terminology) - if (frame.getEof()) { - lastEof = frames.size(); - QPID_LOG(debug, "Requesting write: lastEof=" << lastEof); - aio->notifyPendingWrite(); +// Called in IO thread. (write idle routine) +// This is NOT only called in response to previously calling notifyPendingWrite +void RdmaConnector::writebuff(Rdma::AsynchIO&) { + Codec* codec = securityLayer.get() ? (Codec*) securityLayer.get() : (Codec*) this; + if (codec->canEncode()) { + std::auto_ptr<BufferBase> buffer = std::auto_ptr<BufferBase>(aio->getBuffer()); + size_t encoded = codec->encode(buffer->bytes, buffer->byteCount); + + buffer->dataStart = 0; + buffer->dataCount = encoded; + aio->queueWrite(buffer.release()); } - QPID_LOG(trace, "SENT " << identifier << ": " << frame); } -void RdmaConnector::Writer::writeOne() { - assert(buffer); - QPID_LOG(trace, "Write buffer " << encode.getPosition() - << " bytes " << framesEncoded << " frames "); - framesEncoded = 0; - - buffer->dataStart = 0; - buffer->dataCount = encode.getPosition(); - aio->queueWrite(buffer); - newBuffer(); -} - -void RdmaConnector::Writer::newBuffer() { - buffer = aio->getBuffer(); - encode = framing::Buffer(buffer->bytes, buffer->byteCount); - framesEncoded = 0; +bool RdmaConnector::canEncode() +{ + Mutex::ScopedLock l(lock); + //have at least one full frameset or a whole buffers worth of data + return aio->writable() && aio->bufferAvailable() && (lastEof || currentSize >= maxFrameSize); } -// Called in IO thread. (write idle routine) -// This is NOT only called in response to previously calling notifyPendingWrite -void RdmaConnector::Writer::write(Rdma::AsynchIO&) { - Mutex::ScopedLock l(lock); - assert(buffer); - // If nothing to do return immediately - if (lastEof==0) - return; - size_t bytesWritten = 0; - while (aio->writable() && aio->bufferAvailable() && !frames.empty()) { - const AMQFrame* frame = &frames.front(); - uint32_t size = frame->encodedSize(); - while (size <= encode.available()) { - frame->encode(encode); +size_t RdmaConnector::encode(const char* buffer, size_t size) +{ + framing::Buffer out(const_cast<char*>(buffer), size); + size_t bytesWritten(0); + { + Mutex::ScopedLock l(lock); + while (!frames.empty() && out.available() >= frames.front().encodedSize() ) { + frames.front().encode(out); + QPID_LOG(trace, "SENT " << identifier << ": " << frames.front()); frames.pop_front(); - ++framesEncoded; - bytesWritten += size; - if (frames.empty()) - break; - frame = &frames.front(); - size = frame->encodedSize(); + if (lastEof) --lastEof; } - lastEof -= framesEncoded; - writeOne(); + bytesWritten = size - out.available(); + currentSize -= bytesWritten; } if (bounds) bounds->reduce(bytesWritten); + return bytesWritten; } void RdmaConnector::readbuff(Rdma::AsynchIO&, Rdma::Buffer* buff) { - framing::Buffer in(buff->bytes+buff->dataStart, buff->dataCount); + Codec* codec = securityLayer.get() ? (Codec*) securityLayer.get() : (Codec*) this; + codec->decode(buff->bytes+buff->dataStart, buff->dataCount); +} +size_t RdmaConnector::decode(const char* buffer, size_t size) +{ + framing::Buffer in(const_cast<char*>(buffer), size); if (!initiated) { framing::ProtocolInitiation protocolInit; if (protocolInit.decode(in)) { @@ -382,10 +350,7 @@ void RdmaConnector::readbuff(Rdma::AsynchIO&, Rdma::Buffer* buff) { QPID_LOG(trace, "RECV " << identifier << ": " << frame); input->received(frame); } -} - -void RdmaConnector::writebuff(Rdma::AsynchIO& aio_) { - writer.write(aio_); + return size - in.available(); } void RdmaConnector::writeDataBlock(const AMQDataBlock& data) { @@ -424,5 +389,10 @@ void RdmaConnector::run(){ } } +void RdmaConnector::activateSecurityLayer(std::auto_ptr<qpid::sys::SecurityLayer> sl) +{ + securityLayer = sl; + securityLayer->init(this); +} }} // namespace qpid::client diff --git a/cpp/src/qpid/client/Sasl.h b/cpp/src/qpid/client/Sasl.h new file mode 100644 index 0000000000..e7a911ebce --- /dev/null +++ b/cpp/src/qpid/client/Sasl.h @@ -0,0 +1,52 @@ +#ifndef QPID_CLIENT_SASL_H +#define QPID_CLIENT_SASL_H + +/* + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#include <memory> +#include <string> + +namespace qpid { + +namespace sys { +class SecurityLayer; +} + +namespace client { + +class ConnectionSettings; + +/** + * Interface to SASL support + */ +class Sasl +{ + public: + virtual std::string start(const std::string& mechanisms) = 0; + virtual std::string step(const std::string& challenge) = 0; + virtual std::string getMechanism() = 0; + virtual std::auto_ptr<qpid::sys::SecurityLayer> getSecurityLayer(uint16_t maxFrameSize) = 0; + virtual ~Sasl() {} +}; +}} // namespace qpid::client + +#endif /*!QPID_CLIENT_SASL_H*/ diff --git a/cpp/src/qpid/client/SaslFactory.cpp b/cpp/src/qpid/client/SaslFactory.cpp new file mode 100644 index 0000000000..d6edc6501d --- /dev/null +++ b/cpp/src/qpid/client/SaslFactory.cpp @@ -0,0 +1,345 @@ +/* + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ +#include "SaslFactory.h" +#include "ConnectionSettings.h" + +#ifdef HAVE_CONFIG_H +# include "config.h" +#endif + +#ifndef HAVE_SASL + +namespace qpid { +namespace client { + +//Null implementation + +SaslFactory::SaslFactory() {} + +SaslFactory::~SaslFactory() {} + +SaslFactory& SaslFactory::getInstance() +{ + qpid::sys::Mutex::ScopedLock l(lock); + if (!instance.get()) { + instance = std::auto_ptr<SaslFactory>(new SaslFactory()); + } + return *instance; +} + +std::auto_ptr<Sasl> SaslFactory::create(const ConnectionSettings&) +{ + return std::auto_ptr<Sasl>(); +} + +qpid::sys::Mutex SaslFactory::lock; +std::auto_ptr<SaslFactory> SaslFactory::instance; + +}} // namespace qpid::client + +#else + +#include "qpid/Exception.h" +#include "qpid/framing/reply_exceptions.h" +#include "qpid/sys/SecurityLayer.h" +#include "qpid/sys/cyrus/CyrusSecurityLayer.h" +#include "qpid/log/Statement.h" +#include <sasl/sasl.h> +#include <strings.h> + +namespace qpid { +namespace client { + +using qpid::sys::SecurityLayer; +using qpid::sys::cyrus::CyrusSecurityLayer; +using qpid::framing::InternalErrorException; + +const size_t MAX_LOGIN_LENGTH = 50; + +class CyrusSasl : public Sasl +{ + public: + CyrusSasl(const ConnectionSettings&); + ~CyrusSasl(); + std::string start(const std::string& mechanisms); + std::string step(const std::string& challenge); + std::string getMechanism(); + std::auto_ptr<SecurityLayer> getSecurityLayer(uint16_t maxFrameSize); + private: + sasl_conn_t* conn; + sasl_callback_t callbacks[5];//realm, user, authname, password, end-of-list + ConnectionSettings settings; + std::string input; + std::string mechanism; + char login[MAX_LOGIN_LENGTH]; + + void interact(sasl_interact_t* client_interact); +}; + +//sasl callback functions +int getLogin(void *context, int id, const char **result, unsigned *len); +int getUserFromSettings(void *context, int id, const char **result, unsigned *len); +int getPasswordFromSettings(sasl_conn_t *conn, void *context, int id, sasl_secret_t **psecret); +typedef int CallbackProc(); + +qpid::sys::Mutex SaslFactory::lock; +std::auto_ptr<SaslFactory> SaslFactory::instance; + +SaslFactory::SaslFactory() +{ + sasl_callback_t* callbacks = 0; + int result = sasl_client_init(callbacks); + if (result != SASL_OK) { + throw InternalErrorException(QPID_MSG("Sasl error: " << sasl_errstring(result, 0, 0))); + } +} + +SaslFactory::~SaslFactory() +{ + sasl_done(); +} + +SaslFactory& SaslFactory::getInstance() +{ + qpid::sys::Mutex::ScopedLock l(lock); + if (!instance.get()) { + instance = std::auto_ptr<SaslFactory>(new SaslFactory()); + } + return *instance; +} + +std::auto_ptr<Sasl> SaslFactory::create(const ConnectionSettings& settings) +{ + std::auto_ptr<Sasl> sasl(new CyrusSasl(settings)); + return sasl; +} + +CyrusSasl::CyrusSasl(const ConnectionSettings& s) : conn(0), settings(s) +{ + size_t i = 0; + + callbacks[i].id = SASL_CB_GETREALM; + callbacks[i].proc = 0; + callbacks[i++].context = 0; + + if (settings.username.empty()) { + callbacks[i].id = SASL_CB_USER; + callbacks[i].proc = (CallbackProc*) &getLogin; + callbacks[i++].context = &login; + + callbacks[i].id = SASL_CB_AUTHNAME; + callbacks[i].proc = (CallbackProc*) &getLogin; + callbacks[i++].context = &login; + } else { + callbacks[i].id = SASL_CB_USER; + callbacks[i].proc = (CallbackProc*) &getUserFromSettings; + callbacks[i++].context = &settings; + + callbacks[i].id = SASL_CB_AUTHNAME; + callbacks[i].proc = (CallbackProc*) &getUserFromSettings; + callbacks[i++].context = &settings; + } + + callbacks[i].id = SASL_CB_PASS; + callbacks[i].proc = (CallbackProc*) &getPasswordFromSettings; + callbacks[i++].context = &settings; + + callbacks[i].id = SASL_CB_LIST_END; + callbacks[i].proc = 0; + callbacks[i++].context = 0; +} + +CyrusSasl::~CyrusSasl() +{ + if (conn) { + sasl_dispose(&conn); + } +} + +namespace { + const std::string SSL("ssl"); +} + +std::string CyrusSasl::start(const std::string& mechanisms) +{ + QPID_LOG(debug, "CyrusSasl::start(" << mechanisms << ")"); + int result = sasl_client_new(settings.service.c_str(), + settings.host.c_str(), + 0, 0, /* Local and remote IP address strings */ + callbacks, + 0, /* security flags */ + &conn); + + if (result != SASL_OK) throw InternalErrorException(QPID_MSG("Sasl error: " << sasl_errdetail(conn))); + + sasl_security_properties_t secprops; + + secprops.min_ssf = settings.minSsf; + secprops.max_ssf = settings.maxSsf; + secprops.maxbufsize = 65535; + + QPID_LOG(debug, "min_ssf: " << secprops.min_ssf << ", max_ssf: " << secprops.max_ssf); + + secprops.property_names = 0; + secprops.property_values = 0; + secprops.security_flags = 0;//TODO: provide means for application to configure these + + result = sasl_setprop(conn, SASL_SEC_PROPS, &secprops); + if (result != SASL_OK) { + throw framing::InternalErrorException(QPID_MSG("SASL error: " << sasl_errdetail(conn))); + } + + + sasl_interact_t* client_interact = 0; + const char *out = 0; + unsigned outlen = 0; + const char *chosenMechanism = 0; + + do { + result = sasl_client_start(conn, + mechanisms.c_str(), + &client_interact, + &out, + &outlen, + &chosenMechanism); + + if (result == SASL_INTERACT) { + interact(client_interact); + } + } while (result == SASL_INTERACT); + + if (result != SASL_CONTINUE && result != SASL_OK) { + throw InternalErrorException(QPID_MSG("Sasl error: " << sasl_errdetail(conn))); + } + + mechanism = std::string(chosenMechanism); + QPID_LOG(debug, "CyrusSasl::start(" << mechanisms << "): selected " + << mechanism << " response: '" << std::string(out, outlen) << "'"); + return std::string(out, outlen); +} + +std::string CyrusSasl::step(const std::string& challenge) +{ + sasl_interact_t* client_interact = 0; + const char *out = 0; + unsigned outlen = 0; + int result = 0; + do { + result = sasl_client_step(conn, /* our context */ + challenge.data(), /* the data from the server */ + challenge.size(), /* it's length */ + &client_interact, /* this should be + unallocated and NULL */ + &out, /* filled in on success */ + &outlen); /* filled in on success */ + + if (result == SASL_INTERACT) { + interact(client_interact); + } + } while (result == SASL_INTERACT); + + std::string response; + if (result == SASL_CONTINUE || result == SASL_OK) response = std::string(out, outlen); + else if (result != SASL_OK) { + throw InternalErrorException(QPID_MSG("Sasl error: " << sasl_errdetail(conn))); + } + QPID_LOG(debug, "CyrusSasl::step(" << challenge << "): " << response); + return response; +} + +std::string CyrusSasl::getMechanism() +{ + return mechanism; +} + +void CyrusSasl::interact(sasl_interact_t* client_interact) +{ + std::cout << "[" << client_interact->id << "] " << client_interact->challenge << " " << client_interact->prompt; + if (client_interact->defresult) std::cout << " (" << client_interact->defresult << ")"; + std::cout << std::endl; + if (std::cin >> input) { + client_interact->result = input.data(); + client_interact->len = input.size(); + } +} + +std::auto_ptr<SecurityLayer> CyrusSasl::getSecurityLayer(uint16_t maxFrameSize) +{ + const void* value(0); + int result = sasl_getprop(conn, SASL_SSF, &value); + if (result != SASL_OK) { + throw framing::InternalErrorException(QPID_MSG("SASL error: " << sasl_errdetail(conn))); + } + uint ssf = *(reinterpret_cast<const unsigned*>(value)); + std::auto_ptr<SecurityLayer> securityLayer; + if (ssf) { + QPID_LOG(info, "Installing security layer, SSF: "<< ssf); + securityLayer = std::auto_ptr<SecurityLayer>(new CyrusSecurityLayer(conn, maxFrameSize)); + } + return securityLayer; +} + +int getLogin(void* context, int /*id*/, const char** result, unsigned* /*len*/) +{ + if (context) { + char* login = (char*) context; + int status = getlogin_r(login, MAX_LOGIN_LENGTH); + if (status == 0) { + *result = login; + QPID_LOG(debug, "getLogin(): " << (*result)); + } else { + strcpy(login, "guest"); + QPID_LOG(error, "getlogin_r() failed with " << status << "; defaulting to " << login); + } + return SASL_OK; + } else { + return SASL_FAIL; + } +} + +int getUserFromSettings(void* context, int /*id*/, const char** result, unsigned* /*len*/) +{ + if (context) { + *result = ((ConnectionSettings*) context)->username.c_str(); + QPID_LOG(debug, "getUserFromSettings(): " << (*result)); + return SASL_OK; + } else { + return SASL_FAIL; + } +} + +int getPasswordFromSettings(sasl_conn_t* /*conn*/, void* context, int /*id*/, sasl_secret_t** psecret) +{ + if (context) { + size_t length = ((ConnectionSettings*) context)->password.size(); + sasl_secret_t* secret = (sasl_secret_t*) malloc(sizeof(sasl_secret_t) + length); + secret->len = length; + memcpy(secret->data, ((ConnectionSettings*) context)->password.data(), length); + *psecret = secret; + return SASL_OK; + } else { + return SASL_FAIL; + } +} + +}} // namespace qpid::client + +#endif diff --git a/cpp/src/qpid/client/SaslFactory.h b/cpp/src/qpid/client/SaslFactory.h new file mode 100644 index 0000000000..60a1d60ff3 --- /dev/null +++ b/cpp/src/qpid/client/SaslFactory.h @@ -0,0 +1,48 @@ +#ifndef QPID_CLIENT_SASLFACTORY_H +#define QPID_CLIENT_SASLFACTORY_H + +/* + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ +#include "Sasl.h" +#include "qpid/sys/Mutex.h" +#include <memory> + +namespace qpid { +namespace client { + +/** + * Factory for instances of the Sasl interface through which Sasl + * support is provided to a ConnectionHandler. + */ +class SaslFactory +{ + public: + std::auto_ptr<Sasl> create(const ConnectionSettings&); + static SaslFactory& getInstance(); + ~SaslFactory(); + private: + SaslFactory(); + static qpid::sys::Mutex lock; + static std::auto_ptr<SaslFactory> instance; +}; +}} // namespace qpid::client + +#endif /*!QPID_CLIENT_SASLFACTORY_H*/ |