From f9cc7dd7258d49c8e9e22d8761e585d75ed36bf0 Mon Sep 17 00:00:00 2001 From: Gordon Sim Date: Wed, 28 Nov 2012 14:14:03 +0000 Subject: QPID-4477: make sasl logic a bit smarter, to handle case where we transition input to tunnelled layer while output still has work for sasl git-svn-id: https://svn.apache.org/repos/asf/qpid/branches/0.20@1414712 13f79535-47bb-0310-9956-ffa450edef68 --- qpid/cpp/src/qpid/amqp/Decoder.cpp | 1 + qpid/cpp/src/qpid/amqp/Decoder.h | 1 + qpid/cpp/src/qpid/amqp/Sasl.cpp | 48 ++++++++++++--------- qpid/cpp/src/qpid/broker/amqp/Message.cpp | 5 +++ qpid/cpp/src/qpid/broker/amqp/Message.h | 1 + qpid/cpp/src/qpid/broker/amqp/Translation.cpp | 3 ++ .../src/qpid/messaging/amqp/ConnectionContext.cpp | 50 +++++++++++++++++++--- .../src/qpid/messaging/amqp/ConnectionContext.h | 11 +++++ qpid/cpp/src/qpid/messaging/amqp/Sasl.cpp | 11 ++--- qpid/cpp/src/qpid/messaging/amqp/Sasl.h | 2 +- 10 files changed, 96 insertions(+), 37 deletions(-) diff --git a/qpid/cpp/src/qpid/amqp/Decoder.cpp b/qpid/cpp/src/qpid/amqp/Decoder.cpp index 4c14c8e4d9..9c577e6c92 100644 --- a/qpid/cpp/src/qpid/amqp/Decoder.cpp +++ b/qpid/cpp/src/qpid/amqp/Decoder.cpp @@ -540,5 +540,6 @@ CharSequence Decoder::readRawUuid() } size_t Decoder::getPosition() const { return position; } +size_t Decoder::getSize() const { return size; } void Decoder::resetSize(size_t s) { size = s; } }} // namespace qpid::amqp diff --git a/qpid/cpp/src/qpid/amqp/Decoder.h b/qpid/cpp/src/qpid/amqp/Decoder.h index cf3e2d36d1..7ddfe0f17f 100644 --- a/qpid/cpp/src/qpid/amqp/Decoder.h +++ b/qpid/cpp/src/qpid/amqp/Decoder.h @@ -71,6 +71,7 @@ class Decoder QPID_COMMON_EXTERN void advance(size_t); QPID_COMMON_EXTERN size_t getPosition() const; QPID_COMMON_EXTERN void resetSize(size_t size); + QPID_COMMON_EXTERN size_t getSize() const; private: const char* const start; diff --git a/qpid/cpp/src/qpid/amqp/Sasl.cpp b/qpid/cpp/src/qpid/amqp/Sasl.cpp index 6d0a7ccb1f..7b0779fe94 100644 --- a/qpid/cpp/src/qpid/amqp/Sasl.cpp +++ b/qpid/cpp/src/qpid/amqp/Sasl.cpp @@ -58,29 +58,35 @@ void Sasl::endFrame(void* frame) std::size_t Sasl::read(const char* data, size_t available) { - Decoder decoder(data, available); - //read frame-header - uint32_t frameSize = decoder.readUInt(); - QPID_LOG(trace, "Reading SASL frame of size " << frameSize); - decoder.resetSize(frameSize); - uint8_t dataOffset = decoder.readUByte(); - uint8_t frameType = decoder.readUByte(); - if (frameType != 0x01) { - QPID_LOG(error, "Expected SASL frame; got type " << frameType); - } - uint16_t ignored = decoder.readUShort(); - if (ignored) { - QPID_LOG(info, "Got non null bytes at end of SASL frame header"); - } + size_t consumed = 0; + while (available - consumed > 4/*framesize*/) { + Decoder decoder(data+consumed, available-consumed); + //read frame-header + uint32_t frameSize = decoder.readUInt(); + if (frameSize > decoder.getSize()) break;//don't have all the data for this frame yet + + QPID_LOG(trace, "Reading SASL frame of size " << frameSize); + decoder.resetSize(frameSize); + uint8_t dataOffset = decoder.readUByte(); + uint8_t frameType = decoder.readUByte(); + if (frameType != 0x01) { + QPID_LOG(error, "Expected SASL frame; got type " << frameType); + } + uint16_t ignored = decoder.readUShort(); + if (ignored) { + QPID_LOG(info, "Got non null bytes at end of SASL frame header"); + } - //body is at offset 4*dataOffset from the start - size_t skip = dataOffset*4 - 8; - if (skip) { - QPID_LOG(info, "Offset for sasl frame was not as expected"); - decoder.advance(skip); + //body is at offset 4*dataOffset from the start + size_t skip = dataOffset*4 - 8; + if (skip) { + QPID_LOG(info, "Offset for sasl frame was not as expected"); + decoder.advance(skip); + } + decoder.read(*this); + consumed += decoder.getPosition(); } - decoder.read(*this); - return decoder.getPosition(); + return consumed; } std::size_t Sasl::write(char* data, size_t size) diff --git a/qpid/cpp/src/qpid/broker/amqp/Message.cpp b/qpid/cpp/src/qpid/broker/amqp/Message.cpp index af67f2ce22..a4c346e131 100644 --- a/qpid/cpp/src/qpid/broker/amqp/Message.cpp +++ b/qpid/cpp/src/qpid/broker/amqp/Message.cpp @@ -94,6 +94,7 @@ Message::Message(size_t size) : data(size) applicationProperties.init(); body.init(); + footer.init(); } char* Message::getData() { return &data[0]; } const char* Message::getData() const { return &data[0]; } @@ -140,6 +141,10 @@ qpid::amqp::CharSequence Message::getBody() const { return body; } +qpid::amqp::CharSequence Message::getFooter() const +{ + return footer; +} void Message::scan() { diff --git a/qpid/cpp/src/qpid/broker/amqp/Message.h b/qpid/cpp/src/qpid/broker/amqp/Message.h index d4a97c928a..cc3406f72a 100644 --- a/qpid/cpp/src/qpid/broker/amqp/Message.h +++ b/qpid/cpp/src/qpid/broker/amqp/Message.h @@ -63,6 +63,7 @@ class Message : public qpid::broker::Message::Encoding, private qpid::amqp::Mess qpid::amqp::CharSequence getApplicationProperties() const; qpid::amqp::CharSequence getBareMessage() const; qpid::amqp::CharSequence getBody() const; + qpid::amqp::CharSequence getFooter() const; Message(size_t size); char* getData(); diff --git a/qpid/cpp/src/qpid/broker/amqp/Translation.cpp b/qpid/cpp/src/qpid/broker/amqp/Translation.cpp index 551b4182e0..ca2094b965 100644 --- a/qpid/cpp/src/qpid/broker/amqp/Translation.cpp +++ b/qpid/cpp/src/qpid/broker/amqp/Translation.cpp @@ -215,6 +215,9 @@ void Translation::write(Outgoing& out) //write bare message qpid::amqp::CharSequence bareMessage = message->getBareMessage(); if (bareMessage.size) out.write(bareMessage.data, bareMessage.size); + //write footer: + qpid::amqp::CharSequence footer = message->getFooter(); + if (footer.size) out.write(footer.data, footer.size); } else { const qpid::broker::amqp_0_10::MessageTransfer* transfer = dynamic_cast(&original.getEncoding()); if (transfer) { diff --git a/qpid/cpp/src/qpid/messaging/amqp/ConnectionContext.cpp b/qpid/cpp/src/qpid/messaging/amqp/ConnectionContext.cpp index 173fcba552..b300fee450 100644 --- a/qpid/cpp/src/qpid/messaging/amqp/ConnectionContext.cpp +++ b/qpid/cpp/src/qpid/messaging/amqp/ConnectionContext.cpp @@ -53,7 +53,8 @@ ConnectionContext::ConnectionContext(const std::string& u, const qpid::types::Va writeHeader(false), readHeader(false), haveOutput(false), - state(DISCONNECTED) + state(DISCONNECTED), + codecSwitch(*this) { if (pn_transport_bind(engine, connection)) { //error @@ -563,13 +564,48 @@ bool ConnectionContext::useSasl() qpid::sys::Codec& ConnectionContext::getCodec() { - qpid::sys::ScopedLock l(lock); - if (sasl.get()) { - qpid::sys::Codec* c = sasl->getCodec(); - if (c) return *c; - lock.notifyAll(); + return codecSwitch; +} + +ConnectionContext::CodecSwitch::CodecSwitch(ConnectionContext& p) : parent(p) {} +std::size_t ConnectionContext::CodecSwitch::decode(const char* buffer, std::size_t size) +{ + qpid::sys::ScopedLock l(parent.lock); + size_t decoded = 0; + if (parent.sasl.get() && !parent.sasl->authenticated()) { + decoded = parent.sasl->decode(buffer, size); + if (!parent.sasl->authenticated()) return decoded; } - return *this; + if (decoded < size) { + if (parent.sasl.get() && parent.sasl->getSecurityLayer()) decoded += parent.sasl->getSecurityLayer()->decode(buffer+decoded, size-decoded); + else decoded += parent.decode(buffer+decoded, size-decoded); + } + return decoded; } +std::size_t ConnectionContext::CodecSwitch::encode(char* buffer, std::size_t size) +{ + qpid::sys::ScopedLock l(parent.lock); + size_t encoded = 0; + if (parent.sasl.get() && parent.sasl->canEncode()) { + encoded += parent.sasl->encode(buffer, size); + if (!parent.sasl->authenticated()) return encoded; + } + if (encoded < size) { + if (parent.sasl.get() && parent.sasl->getSecurityLayer()) encoded += parent.sasl->getSecurityLayer()->encode(buffer+encoded, size-encoded); + else encoded += parent.encode(buffer+encoded, size-encoded); + } + return encoded; +} +bool ConnectionContext::CodecSwitch::canEncode() +{ + qpid::sys::ScopedLock l(parent.lock); + if (parent.sasl.get()) { + if (parent.sasl->canEncode()) return true; + else if (!parent.sasl->authenticated()) return false; + else if (parent.sasl->getSecurityLayer()) return parent.sasl->getSecurityLayer()->canEncode(); + } + return parent.canEncode(); +} + }}} // namespace qpid::messaging::amqp diff --git a/qpid/cpp/src/qpid/messaging/amqp/ConnectionContext.h b/qpid/cpp/src/qpid/messaging/amqp/ConnectionContext.h index d9da6551b3..3718184365 100644 --- a/qpid/cpp/src/qpid/messaging/amqp/ConnectionContext.h +++ b/qpid/cpp/src/qpid/messaging/amqp/ConnectionContext.h @@ -123,6 +123,17 @@ class ConnectionContext : public qpid::sys::ConnectionCodec, public qpid::messag CONNECTED } state; std::auto_ptr sasl; + class CodecSwitch : public qpid::sys::Codec + { + public: + CodecSwitch(ConnectionContext&); + std::size_t decode(const char* buffer, std::size_t size); + std::size_t encode(char* buffer, std::size_t size); + bool canEncode(); + private: + ConnectionContext& parent; + }; + CodecSwitch codecSwitch; void wait(); void wakeupDriver(); diff --git a/qpid/cpp/src/qpid/messaging/amqp/Sasl.cpp b/qpid/cpp/src/qpid/messaging/amqp/Sasl.cpp index af13697c20..a8bae1adda 100644 --- a/qpid/cpp/src/qpid/messaging/amqp/Sasl.cpp +++ b/qpid/cpp/src/qpid/messaging/amqp/Sasl.cpp @@ -58,7 +58,7 @@ std::size_t Sasl::encode(char* buffer, std::size_t size) encoded += writeProtocolHeader(buffer, size); writeHeader = !encoded; } - if (state == NONE && encoded < size) { + if (encoded < size) { encoded += write(buffer + encoded, size - encoded); } haveOutput = (encoded == size); @@ -135,14 +135,9 @@ void Sasl::outcome(uint8_t result) context.activateOutput(); } -qpid::sys::Codec* Sasl::getCodec() +qpid::sys::Codec* Sasl::getSecurityLayer() { - switch (state) { - case SUCCEEDED: return static_cast(securityLayer.get()); - case FAILED: throw qpid::messaging::UnauthorizedAccess("Failed to authenticate"); - case NONE: return static_cast(this); - } - return 0; + return securityLayer.get(); } bool Sasl::authenticated() diff --git a/qpid/cpp/src/qpid/messaging/amqp/Sasl.h b/qpid/cpp/src/qpid/messaging/amqp/Sasl.h index 3a2f2e9ffc..6657779fdc 100644 --- a/qpid/cpp/src/qpid/messaging/amqp/Sasl.h +++ b/qpid/cpp/src/qpid/messaging/amqp/Sasl.h @@ -47,7 +47,7 @@ class Sasl : public qpid::sys::Codec, qpid::amqp::SaslClient bool canEncode(); bool authenticated(); - qpid::sys::Codec* getCodec(); + qpid::sys::Codec* getSecurityLayer(); std::string getAuthenticatedUsername(); private: ConnectionContext& context; -- cgit v1.2.1