diff options
Diffstat (limited to 'cpp/src')
-rw-r--r-- | cpp/src/qpid/messaging/amqp/ConnectionContext.cpp | 47 | ||||
-rw-r--r-- | cpp/src/qpid/messaging/amqp/Sasl.cpp | 49 | ||||
-rw-r--r-- | cpp/src/qpid/messaging/amqp/Sasl.h | 2 | ||||
-rw-r--r-- | cpp/src/qpid/messaging/exceptions.cpp | 1 |
4 files changed, 69 insertions, 30 deletions
diff --git a/cpp/src/qpid/messaging/amqp/ConnectionContext.cpp b/cpp/src/qpid/messaging/amqp/ConnectionContext.cpp index 9ed3713920..5aba00fc50 100644 --- a/cpp/src/qpid/messaging/amqp/ConnectionContext.cpp +++ b/cpp/src/qpid/messaging/amqp/ConnectionContext.cpp @@ -851,13 +851,17 @@ std::size_t ConnectionContext::decode(const char* buffer, std::size_t size) { qpid::sys::ScopedLock<qpid::sys::Monitor> l(lock); size_t decoded = 0; - if (sasl.get() && !sasl->authenticated()) { - decoded = sasl->decode(buffer, size); - if (!sasl->authenticated()) return decoded; - } - if (decoded < size) { - if (sasl.get() && sasl->getSecurityLayer()) decoded += sasl->getSecurityLayer()->decode(buffer+decoded, size-decoded); - else decoded += decodePlain(buffer+decoded, size-decoded); + try { + if (sasl.get() && !sasl->authenticated()) { + decoded = sasl->decode(buffer, size); + if (!sasl->authenticated()) return decoded; + } + if (decoded < size) { + if (sasl.get() && sasl->getSecurityLayer()) decoded += sasl->getSecurityLayer()->decode(buffer+decoded, size-decoded); + else decoded += decodePlain(buffer+decoded, size-decoded); + } + } catch (const AuthenticationFailure&) { + transport->close(); } return decoded; } @@ -865,13 +869,17 @@ std::size_t ConnectionContext::encode(char* buffer, std::size_t size) { qpid::sys::ScopedLock<qpid::sys::Monitor> l(lock); size_t encoded = 0; - if (sasl.get() && sasl->canEncode()) { - encoded += sasl->encode(buffer, size); - if (!sasl->authenticated()) return encoded; - } - if (encoded < size) { - if (sasl.get() && sasl->getSecurityLayer()) encoded += sasl->getSecurityLayer()->encode(buffer+encoded, size-encoded); - else encoded += encodePlain(buffer+encoded, size-encoded); + try { + if (sasl.get() && sasl->canEncode()) { + encoded += sasl->encode(buffer, size); + if (!sasl->authenticated()) return encoded; + } + if (encoded < size) { + if (sasl.get() && sasl->getSecurityLayer()) encoded += sasl->getSecurityLayer()->encode(buffer+encoded, size-encoded); + else encoded += encodePlain(buffer+encoded, size-encoded); + } + } catch (const AuthenticationFailure&) { + transport->close(); } return encoded; } @@ -879,9 +887,14 @@ bool ConnectionContext::canEncode() { qpid::sys::ScopedLock<qpid::sys::Monitor> l(lock); if (sasl.get()) { - if (sasl->canEncode()) return true; - else if (!sasl->authenticated()) return false; - else if (sasl->getSecurityLayer()) return sasl->getSecurityLayer()->canEncode(); + try { + if (sasl->canEncode()) return true; + else if (!sasl->authenticated()) return false; + else if (sasl->getSecurityLayer()) return sasl->getSecurityLayer()->canEncode(); + } catch (const AuthenticationFailure&) { + transport->close(); + return false; + } } return canEncodePlain(); } diff --git a/cpp/src/qpid/messaging/amqp/Sasl.cpp b/cpp/src/qpid/messaging/amqp/Sasl.cpp index 9c198f81af..e1c15c2c22 100644 --- a/cpp/src/qpid/messaging/amqp/Sasl.cpp +++ b/cpp/src/qpid/messaging/amqp/Sasl.cpp @@ -93,21 +93,29 @@ void Sasl::mechanisms(const std::string& offered) mechanisms = offered; } - if (sasl->start(mechanisms, response, context.getTransportSecuritySettings())) { - init(sasl->getMechanism(), &response, hostname.size() ? &hostname : 0); - } else { - init(sasl->getMechanism(), 0, hostname.size() ? &hostname : 0); + try { + if (sasl->start(mechanisms, response, context.getTransportSecuritySettings())) { + init(sasl->getMechanism(), &response, hostname.size() ? &hostname : 0); + } else { + init(sasl->getMechanism(), 0, hostname.size() ? &hostname : 0); + } + haveOutput = true; + context.activateOutput(); + } catch (const std::exception& e) { + failed(e.what()); } - haveOutput = true; - context.activateOutput(); } void Sasl::challenge(const std::string& challenge) { QPID_LOG_CAT(debug, protocol, id << " Received SASL-CHALLENGE(" << challenge.size() << " bytes)"); - std::string r = sasl->step(challenge); - response(&r); - haveOutput = true; - context.activateOutput(); + try { + std::string r = sasl->step(challenge); + response(&r); + haveOutput = true; + context.activateOutput(); + } catch (const std::exception& e) { + failed(e.what()); + } } namespace { const std::string EMPTY; @@ -115,8 +123,12 @@ const std::string EMPTY; void Sasl::challenge() { QPID_LOG_CAT(debug, protocol, id << " Received SASL-CHALLENGE(null)"); - std::string r = sasl->step(EMPTY); - response(&r); + try { + std::string r = sasl->step(EMPTY); + response(&r); + } catch (const std::exception& e) { + failed(e.what()); + } } void Sasl::outcome(uint8_t result, const std::string& extra) { @@ -146,15 +158,26 @@ qpid::sys::Codec* Sasl::getSecurityLayer() return securityLayer.get(); } +namespace { +const std::string DEFAULT_ERROR("Authentication failed"); +} + bool Sasl::authenticated() { switch (state) { case SUCCEEDED: return true; - case FAILED: throw qpid::messaging::UnauthorizedAccess("Failed to authenticate"); + case FAILED: throw qpid::messaging::AuthenticationFailure(error.size() ? error : DEFAULT_ERROR); case NONE: default: return false; } } +void Sasl::failed(const std::string& text) +{ + QPID_LOG_CAT(info, client, id << " Failure during authentication: " << text); + error = text; + state = FAILED; +} + std::string Sasl::getAuthenticatedUsername() { return sasl->getUserId(); diff --git a/cpp/src/qpid/messaging/amqp/Sasl.h b/cpp/src/qpid/messaging/amqp/Sasl.h index 6de36bd7f2..a836e2e465 100644 --- a/cpp/src/qpid/messaging/amqp/Sasl.h +++ b/cpp/src/qpid/messaging/amqp/Sasl.h @@ -61,12 +61,14 @@ class Sasl : public qpid::sys::Codec, qpid::amqp::SaslClient NONE, FAILED, SUCCEEDED } state; std::auto_ptr<qpid::sys::SecurityLayer> securityLayer; + std::string error; void mechanisms(const std::string&); void challenge(const std::string&); void challenge(); //null != empty string void outcome(uint8_t result, const std::string&); void outcome(uint8_t result); + void failed(const std::string&); protected: bool stopReading(); }; diff --git a/cpp/src/qpid/messaging/exceptions.cpp b/cpp/src/qpid/messaging/exceptions.cpp index 11b0eb33f7..d21477b494 100644 --- a/cpp/src/qpid/messaging/exceptions.cpp +++ b/cpp/src/qpid/messaging/exceptions.cpp @@ -56,6 +56,7 @@ TransactionAborted::TransactionAborted(const std::string& msg) : TransactionErro UnauthorizedAccess::UnauthorizedAccess(const std::string& msg) : SessionError(msg) {} ConnectionError::ConnectionError(const std::string& msg) : MessagingException(msg) {} +AuthenticationFailure::AuthenticationFailure(const std::string& msg) : ConnectionError(msg) {} TransportFailure::TransportFailure(const std::string& msg) : MessagingException(msg) {} |