summaryrefslogtreecommitdiff
path: root/src/mongo/db/auth/sasl_scram_test.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/mongo/db/auth/sasl_scram_test.cpp')
-rw-r--r--src/mongo/db/auth/sasl_scram_test.cpp649
1 files changed, 649 insertions, 0 deletions
diff --git a/src/mongo/db/auth/sasl_scram_test.cpp b/src/mongo/db/auth/sasl_scram_test.cpp
new file mode 100644
index 00000000000..c75cae5f260
--- /dev/null
+++ b/src/mongo/db/auth/sasl_scram_test.cpp
@@ -0,0 +1,649 @@
+/*
+ * Copyright (C) 2016 MongoDB Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License, version 3,
+ * as published by the Free Software Foundation.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the GNU Affero General Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#define MONGO_LOG_DEFAULT_COMPONENT ::mongo::logger::LogComponent::kDefault
+
+#include "mongo/platform/basic.h"
+
+#include "mongo/client/native_sasl_client_session.h"
+#include "mongo/client/scram_client_cache.h"
+#include "mongo/crypto/mechanism_scram.h"
+#include "mongo/crypto/sha1_block.h"
+#include "mongo/crypto/sha256_block.h"
+#include "mongo/db/auth/authorization_manager.h"
+#include "mongo/db/auth/authz_manager_external_state_mock.h"
+#include "mongo/db/auth/authz_session_external_state_mock.h"
+#include "mongo/db/auth/native_sasl_authentication_session.h"
+#include "mongo/db/auth/sasl_scram_server_conversation.h"
+#include "mongo/db/service_context_noop.h"
+#include "mongo/stdx/memory.h"
+#include "mongo/unittest/unittest.h"
+#include "mongo/util/base64.h"
+#include "mongo/util/log.h"
+#include "mongo/util/password_digest.h"
+
+namespace mongo {
+namespace {
+
+BSONObj generateSCRAMUserDocument(StringData username, StringData password) {
+ const auto database = "test"_sd;
+
+ const auto digested = createPasswordDigest(username, password);
+ const auto sha1Cred = scram::Secrets<SHA1Block>::generateCredentials(digested, 10000);
+ const auto sha256Cred =
+ scram::Secrets<SHA256Block>::generateCredentials(password.toString(), 15000);
+ return BSON("_id" << (str::stream() << database << "." << username).operator StringData()
+ << AuthorizationManager::USER_NAME_FIELD_NAME
+ << username
+ << AuthorizationManager::USER_DB_FIELD_NAME
+ << database
+ << "credentials"
+ << BSON("SCRAM-SHA-1" << sha1Cred << "SCRAM-SHA-256" << sha256Cred)
+ << "roles"
+ << BSONArray()
+ << "privileges"
+ << BSONArray());
+}
+
+std::string corruptEncodedPayload(const std::string& message,
+ std::string::const_iterator begin,
+ std::string::const_iterator end) {
+ std::string raw = base64::decode(
+ message.substr(std::distance(message.begin(), begin), std::distance(begin, end)));
+ if (raw[0] == std::numeric_limits<char>::max()) {
+ raw[0] -= 1;
+ } else {
+ raw[0] += 1;
+ }
+ return base64::encode(raw);
+}
+
+class SaslTestState {
+public:
+ enum Participant { kClient, kServer };
+ SaslTestState() : SaslTestState(kClient, 0) {}
+ SaslTestState(Participant participant, size_t stage) : participant(participant), stage(stage) {}
+
+private:
+ // Define members here, so that they can be used in declaration of lens(). In C++14, lens()
+ // can be declared with a return of decltype(auto), without a trailing return type, and these
+ // members can go at the end of the class.
+ Participant participant;
+ size_t stage;
+
+public:
+ auto lens() const -> decltype(std::tie(this->stage, this->participant)) {
+ return std::tie(stage, participant);
+ }
+
+ friend bool operator==(const SaslTestState& lhs, const SaslTestState& rhs) {
+ return lhs.lens() == rhs.lens();
+ }
+
+ friend bool operator<(const SaslTestState& lhs, const SaslTestState& rhs) {
+ return lhs.lens() < rhs.lens();
+ }
+
+ void next() {
+ if (participant == kClient) {
+ participant = kServer;
+ } else {
+ participant = kClient;
+ stage++;
+ }
+ }
+
+ std::string toString() const {
+ std::stringstream ss;
+ if (participant == kClient) {
+ ss << "Client";
+ } else {
+ ss << "Server";
+ }
+ ss << "Step" << stage;
+
+ return ss.str();
+ }
+};
+
+class SCRAMMutators {
+public:
+ SCRAMMutators() {}
+
+ void setMutator(SaslTestState state, stdx::function<void(std::string&)> fun) {
+ mutators.insert(std::make_pair(state, fun));
+ }
+
+ void execute(SaslTestState state, std::string& str) {
+ auto it = mutators.find(state);
+ if (it != mutators.end()) {
+ it->second(str);
+ }
+ }
+
+private:
+ std::map<SaslTestState, stdx::function<void(std::string&)>> mutators;
+};
+
+struct SCRAMStepsResult {
+ SCRAMStepsResult() : outcome(SaslTestState::kClient, 1), status(Status::OK()) {}
+ SCRAMStepsResult(SaslTestState outcome, Status status) : outcome(outcome), status(status) {}
+ bool operator==(const SCRAMStepsResult& other) const {
+ return outcome == other.outcome && status.code() == other.status.code() &&
+ status.reason() == other.status.reason();
+ }
+ SaslTestState outcome;
+ Status status;
+
+ friend std::ostream& operator<<(std::ostream& os, const SCRAMStepsResult& result) {
+ return os << "{outcome: " << result.outcome.toString() << ", status: " << result.status
+ << "}";
+ }
+};
+
+SCRAMStepsResult runSteps(NativeSaslAuthenticationSession* saslServerSession,
+ NativeSaslClientSession* saslClientSession,
+ SCRAMMutators interposers = SCRAMMutators{}) {
+ SCRAMStepsResult result{};
+ std::string clientOutput = "";
+ std::string serverOutput = "";
+
+ for (size_t step = 1; step <= 3; step++) {
+ ASSERT_FALSE(saslClientSession->isDone());
+ ASSERT_FALSE(saslServerSession->isDone());
+
+ // Client step
+ result.status = saslClientSession->step(serverOutput, &clientOutput);
+ if (result.status != Status::OK()) {
+ return result;
+ }
+ interposers.execute(result.outcome, clientOutput);
+ std::cout << result.outcome.toString() << ": " << clientOutput << std::endl;
+ result.outcome.next();
+
+ // Server step
+ result.status = saslServerSession->step(clientOutput, &serverOutput);
+ if (result.status != Status::OK()) {
+ return result;
+ }
+ interposers.execute(result.outcome, serverOutput);
+ std::cout << result.outcome.toString() << ": " << serverOutput << std::endl;
+ result.outcome.next();
+ }
+ ASSERT_TRUE(saslClientSession->isDone());
+ ASSERT_TRUE(saslServerSession->isDone());
+
+ return result;
+}
+
+class SCRAMFixture : public mongo::unittest::Test {
+protected:
+ const SCRAMStepsResult goalState =
+ SCRAMStepsResult(SaslTestState(SaslTestState::kClient, 4), Status::OK());
+
+ ServiceContextNoop serviceContext;
+ ServiceContextNoop::UniqueClient client;
+ ServiceContextNoop::UniqueOperationContext opCtx;
+
+ AuthzManagerExternalStateMock* authzManagerExternalState;
+ std::unique_ptr<AuthorizationManager> authzManager;
+ std::unique_ptr<AuthorizationSession> authzSession;
+
+ std::unique_ptr<NativeSaslAuthenticationSession> saslServerSession;
+ std::unique_ptr<NativeSaslClientSession> saslClientSession;
+
+ void setUp() final {
+ client = serviceContext.makeClient("test");
+ opCtx = serviceContext.makeOperationContext(client.get());
+
+ auto uniqueAuthzManagerExternalStateMock =
+ stdx::make_unique<AuthzManagerExternalStateMock>();
+ authzManagerExternalState = uniqueAuthzManagerExternalStateMock.get();
+ authzManager =
+ stdx::make_unique<AuthorizationManager>(std::move(uniqueAuthzManagerExternalStateMock));
+ authzSession = stdx::make_unique<AuthorizationSession>(
+ stdx::make_unique<AuthzSessionExternalStateMock>(authzManager.get()));
+
+ saslServerSession = stdx::make_unique<NativeSaslAuthenticationSession>(authzSession.get());
+ saslServerSession->setOpCtxt(opCtx.get());
+ ASSERT_OK(
+ saslServerSession->start("test", _mechanism, "mongodb", "MockServer.test", 1, false));
+ saslClientSession = stdx::make_unique<NativeSaslClientSession>();
+ saslClientSession->setParameter(NativeSaslClientSession::parameterMechanism, _mechanism);
+ saslClientSession->setParameter(NativeSaslClientSession::parameterServiceName, "mongodb");
+ saslClientSession->setParameter(NativeSaslClientSession::parameterServiceHostname,
+ "MockServer.test");
+ saslClientSession->setParameter(NativeSaslClientSession::parameterServiceHostAndPort,
+ "MockServer.test:27017");
+ }
+
+ void tearDown() final {
+ saslClientSession.reset();
+ saslServerSession.reset();
+ authzSession.reset();
+ authzManager.reset();
+ authzManagerExternalState = nullptr;
+ opCtx.reset();
+ client.reset();
+ }
+
+ std::string createPasswordDigest(StringData username, StringData password) {
+ if (_digestPassword) {
+ return mongo::createPasswordDigest(username, password);
+ } else {
+ return password.toString();
+ }
+ }
+
+ std::string _mechanism;
+ bool _digestPassword;
+
+public:
+ void run() {
+ log() << "SCRAM-SHA-1 variant";
+ _mechanism = "SCRAM-SHA-1";
+ _digestPassword = true;
+ Test::run();
+
+ log() << "SCRAM-SHA-256 variant";
+ _mechanism = "SCRAM-SHA-256";
+ _digestPassword = false;
+ Test::run();
+ }
+};
+
+TEST_F(SCRAMFixture, testServerStep1DoesNotIncludeNonceFromClientStep1) {
+ ASSERT_OK(authzManagerExternalState->insertPrivilegeDocument(
+ opCtx.get(), generateSCRAMUserDocument("sajack", "sajack"), BSONObj()));
+
+ saslClientSession->setParameter(NativeSaslClientSession::parameterUser, "sajack");
+ saslClientSession->setParameter(NativeSaslClientSession::parameterPassword,
+ createPasswordDigest("sajack", "sajack"));
+
+ ASSERT_OK(saslClientSession->initialize());
+
+ SCRAMMutators mutator;
+ mutator.setMutator(SaslTestState(SaslTestState::kServer, 1), [](std::string& serverMessage) {
+ std::string::iterator nonceBegin = serverMessage.begin() + serverMessage.find("r=");
+ std::string::iterator nonceEnd = std::find(nonceBegin, serverMessage.end(), ',');
+ serverMessage = serverMessage.replace(nonceBegin, nonceEnd, "r=");
+
+ });
+ ASSERT_EQ(
+ SCRAMStepsResult(SaslTestState(SaslTestState::kClient, 2),
+ Status(ErrorCodes::BadValue, "Incorrect SCRAM client|server nonce: r=")),
+ runSteps(saslServerSession.get(), saslClientSession.get(), mutator));
+}
+
+TEST_F(SCRAMFixture, testClientStep2DoesNotIncludeNonceFromServerStep1) {
+ ASSERT_OK(authzManagerExternalState->insertPrivilegeDocument(
+ opCtx.get(), generateSCRAMUserDocument("sajack", "sajack"), BSONObj()));
+
+ saslClientSession->setParameter(NativeSaslClientSession::parameterUser, "sajack");
+ saslClientSession->setParameter(NativeSaslClientSession::parameterPassword,
+ createPasswordDigest("sajack", "sajack"));
+
+ ASSERT_OK(saslClientSession->initialize());
+
+ SCRAMMutators mutator;
+ mutator.setMutator(SaslTestState(SaslTestState::kClient, 2), [](std::string& clientMessage) {
+ std::string::iterator nonceBegin = clientMessage.begin() + clientMessage.find("r=");
+ std::string::iterator nonceEnd = std::find(nonceBegin, clientMessage.end(), ',');
+ clientMessage = clientMessage.replace(nonceBegin, nonceEnd, "r=");
+ });
+ ASSERT_EQ(
+ SCRAMStepsResult(SaslTestState(SaslTestState::kServer, 2),
+ Status(ErrorCodes::BadValue, "Incorrect SCRAM client|server nonce: r=")),
+ runSteps(saslServerSession.get(), saslClientSession.get(), mutator));
+}
+
+TEST_F(SCRAMFixture, testClientStep2GivesBadProof) {
+ ASSERT_OK(authzManagerExternalState->insertPrivilegeDocument(
+ opCtx.get(), generateSCRAMUserDocument("sajack", "sajack"), BSONObj()));
+
+ saslClientSession->setParameter(NativeSaslClientSession::parameterUser, "sajack");
+ saslClientSession->setParameter(NativeSaslClientSession::parameterPassword,
+ createPasswordDigest("sajack", "sajack"));
+
+ ASSERT_OK(saslClientSession->initialize());
+
+ SCRAMMutators mutator;
+ mutator.setMutator(SaslTestState(SaslTestState::kClient, 2), [](std::string& clientMessage) {
+ std::string::iterator proofBegin = clientMessage.begin() + clientMessage.find("p=") + 2;
+ std::string::iterator proofEnd = std::find(proofBegin, clientMessage.end(), ',');
+ clientMessage = clientMessage.replace(
+ proofBegin, proofEnd, corruptEncodedPayload(clientMessage, proofBegin, proofEnd));
+
+ });
+
+ ASSERT_EQ(SCRAMStepsResult(SaslTestState(SaslTestState::kServer, 2),
+ Status(ErrorCodes::AuthenticationFailed,
+ "SCRAM authentication failed, storedKey mismatch")),
+
+ runSteps(saslServerSession.get(), saslClientSession.get(), mutator));
+}
+
+TEST_F(SCRAMFixture, testServerStep2GivesBadVerifier) {
+ ASSERT_OK(authzManagerExternalState->insertPrivilegeDocument(
+ opCtx.get(), generateSCRAMUserDocument("sajack", "sajack"), BSONObj()));
+
+ saslClientSession->setParameter(NativeSaslClientSession::parameterUser, "sajack");
+ saslClientSession->setParameter(NativeSaslClientSession::parameterPassword,
+ createPasswordDigest("sajack", "sajack"));
+
+ ASSERT_OK(saslClientSession->initialize());
+
+ std::string encodedVerifier;
+ SCRAMMutators mutator;
+ mutator.setMutator(
+ SaslTestState(SaslTestState::kServer, 2), [&encodedVerifier](std::string& serverMessage) {
+ std::string::iterator verifierBegin =
+ serverMessage.begin() + serverMessage.find("v=") + 2;
+ std::string::iterator verifierEnd = std::find(verifierBegin, serverMessage.end(), ',');
+ encodedVerifier = corruptEncodedPayload(serverMessage, verifierBegin, verifierEnd);
+
+ serverMessage = serverMessage.replace(verifierBegin, verifierEnd, encodedVerifier);
+
+ });
+
+ auto result = runSteps(saslServerSession.get(), saslClientSession.get(), mutator);
+
+ ASSERT_EQ(SCRAMStepsResult(
+ SaslTestState(SaslTestState::kClient, 3),
+ Status(ErrorCodes::BadValue,
+ str::stream() << "Client failed to verify SCRAM ServerSignature, received "
+ << encodedVerifier)),
+ result);
+}
+
+
+TEST_F(SCRAMFixture, testSCRAM) {
+ ASSERT_OK(authzManagerExternalState->insertPrivilegeDocument(
+ opCtx.get(), generateSCRAMUserDocument("sajack", "sajack"), BSONObj()));
+
+ saslClientSession->setParameter(NativeSaslClientSession::parameterUser, "sajack");
+ saslClientSession->setParameter(NativeSaslClientSession::parameterPassword,
+ createPasswordDigest("sajack", "sajack"));
+
+ ASSERT_OK(saslClientSession->initialize());
+
+ ASSERT_EQ(goalState, runSteps(saslServerSession.get(), saslClientSession.get()));
+}
+
+TEST_F(SCRAMFixture, testSCRAMWithChannelBindingSupportedByClient) {
+ ASSERT_OK(authzManagerExternalState->insertPrivilegeDocument(
+ opCtx.get(), generateSCRAMUserDocument("sajack", "sajack"), BSONObj()));
+
+ saslClientSession->setParameter(NativeSaslClientSession::parameterUser, "sajack");
+ saslClientSession->setParameter(NativeSaslClientSession::parameterPassword,
+ createPasswordDigest("sajack", "sajack"));
+
+ ASSERT_OK(saslClientSession->initialize());
+
+ SCRAMMutators mutator;
+ mutator.setMutator(SaslTestState(SaslTestState::kClient, 1), [](std::string& clientMessage) {
+ clientMessage.replace(clientMessage.begin(), clientMessage.begin() + 1, "y");
+ });
+
+ ASSERT_EQ(goalState, runSteps(saslServerSession.get(), saslClientSession.get(), mutator));
+}
+
+TEST_F(SCRAMFixture, testSCRAMWithChannelBindingRequiredByClient) {
+ ASSERT_OK(authzManagerExternalState->insertPrivilegeDocument(
+ opCtx.get(), generateSCRAMUserDocument("sajack", "sajack"), BSONObj()));
+
+ saslClientSession->setParameter(NativeSaslClientSession::parameterUser, "sajack");
+ saslClientSession->setParameter(NativeSaslClientSession::parameterPassword,
+ createPasswordDigest("sajack", "sajack"));
+
+ ASSERT_OK(saslClientSession->initialize());
+
+ SCRAMMutators mutator;
+ mutator.setMutator(SaslTestState(SaslTestState::kClient, 1), [](std::string& clientMessage) {
+ clientMessage.replace(clientMessage.begin(), clientMessage.begin() + 1, "p=tls-unique");
+ });
+
+ ASSERT_EQ(
+ SCRAMStepsResult(SaslTestState(SaslTestState::kServer, 1),
+ Status(ErrorCodes::BadValue, "Server does not support channel binding")),
+ runSteps(saslServerSession.get(), saslClientSession.get(), mutator));
+}
+
+TEST_F(SCRAMFixture, testSCRAMWithInvalidChannelBinding) {
+ ASSERT_OK(authzManagerExternalState->insertPrivilegeDocument(
+ opCtx.get(), generateSCRAMUserDocument("sajack", "sajack"), BSONObj()));
+
+ saslClientSession->setParameter(NativeSaslClientSession::parameterUser, "sajack");
+ saslClientSession->setParameter(NativeSaslClientSession::parameterPassword,
+ createPasswordDigest("sajack", "sajack"));
+
+ ASSERT_OK(saslClientSession->initialize());
+
+ SCRAMMutators mutator;
+ mutator.setMutator(SaslTestState(SaslTestState::kClient, 1), [](std::string& clientMessage) {
+ clientMessage.replace(clientMessage.begin(), clientMessage.begin() + 1, "v=illegalGarbage");
+ });
+
+ ASSERT_EQ(SCRAMStepsResult(SaslTestState(SaslTestState::kServer, 1),
+ Status(ErrorCodes::BadValue,
+ "Incorrect SCRAM client message prefix: v=illegalGarbage")),
+ runSteps(saslServerSession.get(), saslClientSession.get(), mutator));
+}
+
+TEST_F(SCRAMFixture, testNULLInPassword) {
+ ASSERT_OK(authzManagerExternalState->insertPrivilegeDocument(
+ opCtx.get(), generateSCRAMUserDocument("sajack", "saj\0ack"), BSONObj()));
+
+ saslClientSession->setParameter(NativeSaslClientSession::parameterUser, "sajack");
+ saslClientSession->setParameter(NativeSaslClientSession::parameterPassword,
+ createPasswordDigest("sajack", "saj\0ack"));
+
+ ASSERT_OK(saslClientSession->initialize());
+
+ ASSERT_EQ(goalState, runSteps(saslServerSession.get(), saslClientSession.get()));
+}
+
+
+TEST_F(SCRAMFixture, testCommasInUsernameAndPassword) {
+ ASSERT_OK(authzManagerExternalState->insertPrivilegeDocument(
+ opCtx.get(), generateSCRAMUserDocument("s,a,jack", "s,a,jack"), BSONObj()));
+
+ saslClientSession->setParameter(NativeSaslClientSession::parameterUser, "s,a,jack");
+ saslClientSession->setParameter(NativeSaslClientSession::parameterPassword,
+ createPasswordDigest("s,a,jack", "s,a,jack"));
+
+ ASSERT_OK(saslClientSession->initialize());
+
+ ASSERT_EQ(goalState, runSteps(saslServerSession.get(), saslClientSession.get()));
+}
+
+TEST_F(SCRAMFixture, testIncorrectUser) {
+ saslClientSession->setParameter(NativeSaslClientSession::parameterUser, "sajack");
+ saslClientSession->setParameter(NativeSaslClientSession::parameterPassword,
+ createPasswordDigest("sajack", "sajack"));
+
+ ASSERT_OK(saslClientSession->initialize());
+
+ ASSERT_EQ(SCRAMStepsResult(SaslTestState(SaslTestState::kServer, 1),
+ Status(ErrorCodes::UserNotFound, "Could not find user sajack@test")),
+ runSteps(saslServerSession.get(), saslClientSession.get()));
+}
+
+TEST_F(SCRAMFixture, testIncorrectPassword) {
+ ASSERT_OK(authzManagerExternalState->insertPrivilegeDocument(
+ opCtx.get(), generateSCRAMUserDocument("sajack", "sajack"), BSONObj()));
+
+ saslClientSession->setParameter(NativeSaslClientSession::parameterUser, "sajack");
+ saslClientSession->setParameter(NativeSaslClientSession::parameterPassword,
+ createPasswordDigest("sajack", "invalidPassword"));
+
+ ASSERT_OK(saslClientSession->initialize());
+
+ ASSERT_EQ(SCRAMStepsResult(SaslTestState(SaslTestState::kServer, 2),
+ Status(ErrorCodes::AuthenticationFailed,
+ "SCRAM authentication failed, storedKey mismatch")),
+ runSteps(saslServerSession.get(), saslClientSession.get()));
+}
+
+TEST_F(SCRAMFixture, testOptionalClientExtensions) {
+ // Verify server ignores unknown/optional extensions sent by client.
+ ASSERT_OK(authzManagerExternalState->insertPrivilegeDocument(
+ opCtx.get(), generateSCRAMUserDocument("sajack", "sajack"), BSONObj()));
+
+ saslClientSession->setParameter(NativeSaslClientSession::parameterUser, "sajack");
+ saslClientSession->setParameter(NativeSaslClientSession::parameterPassword,
+ createPasswordDigest("sajack", "sajack"));
+
+ ASSERT_OK(saslClientSession->initialize());
+
+ SCRAMMutators mutator;
+ mutator.setMutator(SaslTestState(SaslTestState::kClient, 1), [](std::string& clientMessage) {
+ clientMessage += ",x=unsupported-extension";
+ });
+
+ // Optional client extension is successfully ignored, or we'd have failed in step 1.
+ // We still fail at step 2, because client was unaware of the injected extension.
+ ASSERT_EQ(SCRAMStepsResult(SaslTestState(SaslTestState::kServer, 2),
+ Status(ErrorCodes::AuthenticationFailed,
+ "SCRAM authentication failed, storedKey mismatch")),
+ runSteps(saslServerSession.get(), saslClientSession.get(), mutator));
+}
+
+TEST_F(SCRAMFixture, testOptionalServerExtensions) {
+ // Verify client errors on unknown/optional extensions sent by server.
+ ASSERT_OK(authzManagerExternalState->insertPrivilegeDocument(
+ opCtx.get(), generateSCRAMUserDocument("sajack", "sajack"), BSONObj()));
+
+ saslClientSession->setParameter(NativeSaslClientSession::parameterUser, "sajack");
+ saslClientSession->setParameter(NativeSaslClientSession::parameterPassword,
+ createPasswordDigest("sajack", "sajack"));
+
+ ASSERT_OK(saslClientSession->initialize());
+
+ SCRAMMutators mutator;
+ mutator.setMutator(SaslTestState(SaslTestState::kServer, 1), [](std::string& serverMessage) {
+ serverMessage += ",x=unsupported-extension";
+ });
+
+ // As with testOptionalClientExtensions, we can be confident that the optionality
+ // is respected because we would have failed at client step 2.
+ // We do still fail at server step 2 because server was unaware of injected extension.
+ ASSERT_EQ(SCRAMStepsResult(SaslTestState(SaslTestState::kServer, 2),
+ Status(ErrorCodes::AuthenticationFailed,
+ "SCRAM authentication failed, storedKey mismatch")),
+ runSteps(saslServerSession.get(), saslClientSession.get(), mutator));
+}
+
+template <typename HashBlock>
+void testGetFromEmptyCache() {
+ SCRAMClientCache<HashBlock> cache;
+ const auto salt = scram::Presecrets<HashBlock>::generateSecureRandomSalt();
+ HostAndPort host("localhost:27017");
+ ASSERT_FALSE(cache.getCachedSecrets(host, scram::Presecrets<HashBlock>("aaa", salt, 10000)));
+}
+
+TEST(SCRAMCache, testGetFromEmptyCache) {
+ testGetFromEmptyCache<SHA1Block>();
+ testGetFromEmptyCache<SHA256Block>();
+}
+
+template <typename HashBlock>
+void testSetAndGet() {
+ SCRAMClientCache<HashBlock> cache;
+ const auto salt = scram::Presecrets<HashBlock>::generateSecureRandomSalt();
+ HostAndPort host("localhost:27017");
+
+ const auto presecrets = scram::Presecrets<HashBlock>("aaa", salt, 10000);
+ const auto secrets = scram::Secrets<HashBlock>(presecrets);
+ cache.setCachedSecrets(host, presecrets, secrets);
+ const auto cachedSecrets = cache.getCachedSecrets(host, presecrets);
+
+ ASSERT_TRUE(cachedSecrets);
+ ASSERT_TRUE(secrets.clientKey() == cachedSecrets.clientKey());
+ ASSERT_TRUE(secrets.serverKey() == cachedSecrets.serverKey());
+ ASSERT_TRUE(secrets.storedKey() == cachedSecrets.storedKey());
+}
+
+TEST(SCRAMCache, testSetAndGet) {
+ testSetAndGet<SHA1Block>();
+ testSetAndGet<SHA256Block>();
+}
+
+template <typename HashBlock>
+void testSetAndGetWithDifferentParameters() {
+ SCRAMClientCache<HashBlock> cache;
+ const auto salt = scram::Presecrets<HashBlock>::generateSecureRandomSalt();
+ HostAndPort host("localhost:27017");
+
+ const auto presecrets = scram::Presecrets<HashBlock>("aaa", salt, 10000);
+ const auto secrets = scram::Secrets<HashBlock>(presecrets);
+ cache.setCachedSecrets(host, presecrets, secrets);
+ ASSERT_TRUE(cache.getCachedSecrets(host, presecrets));
+
+ // Alter each of: host, password, salt, iterationCount.
+ // Any one of which should fail to retreive from cache.
+ ASSERT_FALSE(cache.getCachedSecrets(HostAndPort("localhost:27018"), presecrets));
+ ASSERT_FALSE(cache.getCachedSecrets(host, scram::Presecrets<HashBlock>("aab", salt, 10000)));
+ const auto badSalt = scram::Presecrets<HashBlock>::generateSecureRandomSalt();
+ ASSERT_FALSE(cache.getCachedSecrets(host, scram::Presecrets<HashBlock>("aaa", badSalt, 10000)));
+ ASSERT_FALSE(cache.getCachedSecrets(host, scram::Presecrets<HashBlock>("aaa", salt, 10001)));
+}
+
+TEST(SCRAMCache, testSetAndGetWithDifferentParameters) {
+ testSetAndGetWithDifferentParameters<SHA1Block>();
+ testSetAndGetWithDifferentParameters<SHA256Block>();
+}
+
+template <typename HashBlock>
+void testSetAndReset() {
+ SCRAMClientCache<HashBlock> cache;
+ const auto salt = scram::Presecrets<HashBlock>::generateSecureRandomSalt();
+ HostAndPort host("localhost:27017");
+
+ const auto presecretsA = scram::Presecrets<HashBlock>("aaa", salt, 10000);
+ const auto secretsA = scram::Secrets<HashBlock>(presecretsA);
+ cache.setCachedSecrets(host, presecretsA, secretsA);
+ const auto presecretsB = scram::Presecrets<HashBlock>("aab", salt, 10000);
+ const auto secretsB = scram::Secrets<HashBlock>(presecretsB);
+ cache.setCachedSecrets(host, presecretsB, secretsB);
+
+ ASSERT_FALSE(cache.getCachedSecrets(host, presecretsA));
+ const auto cachedSecret = cache.getCachedSecrets(host, presecretsB);
+ ASSERT_TRUE(cachedSecret);
+ ASSERT_TRUE(secretsB.clientKey() == cachedSecret.clientKey());
+ ASSERT_TRUE(secretsB.serverKey() == cachedSecret.serverKey());
+ ASSERT_TRUE(secretsB.storedKey() == cachedSecret.storedKey());
+}
+
+TEST(SCRAMCache, testSetAndReset) {
+ testSetAndReset<SHA1Block>();
+ testSetAndReset<SHA256Block>();
+}
+
+} // namespace
+} // namespace mongo