summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSpencer Jackson <spencer.jackson@mongodb.com>2017-04-28 18:34:21 -0400
committerSpencer Jackson <spencer.jackson@mongodb.com>2017-07-11 16:20:07 -0400
commit764b75a48f57c84ea8c0b867b3128e1d8760086a (patch)
treef13685f83219fa0c59b6b3e3368bb1a6899685ed
parentb686a69d6ed4653c6973dc62b50eb7b40df87fd4 (diff)
downloadmongo-764b75a48f57c84ea8c0b867b3128e1d8760086a.tar.gz
SERVER-28997: Limit SCRAM-SHA-1 cache's use of Secure Memory
(cherry picked from commit 7ca9cebf2623865fd0077f90baf61132d866a674) (cherry picked from commit 8a4d00991cd1721240f13c8713d7d819baa1763e)
-rw-r--r--src/mongo/client/sasl_scramsha1_client_conversation.cpp6
-rw-r--r--src/mongo/client/scram_sha1_client_cache.cpp2
-rw-r--r--src/mongo/client/scram_sha1_client_cache.h4
-rw-r--r--src/mongo/crypto/mechanism_scram.cpp56
-rw-r--r--src/mongo/crypto/mechanism_scram.h46
-rw-r--r--src/mongo/crypto/sha1_block.h16
-rw-r--r--src/mongo/crypto/sha1_block_openssl.cpp16
-rw-r--r--src/mongo/crypto/sha1_block_test.cpp136
-rw-r--r--src/mongo/crypto/sha1_block_tom.cpp13
-rw-r--r--src/mongo/db/auth/SConscript7
-rw-r--r--src/mongo/db/auth/sasl_plain_server_conversation.cpp4
-rw-r--r--src/mongo/db/auth/sasl_scramsha1_test.cpp501
12 files changed, 735 insertions, 72 deletions
diff --git a/src/mongo/client/sasl_scramsha1_client_conversation.cpp b/src/mongo/client/sasl_scramsha1_client_conversation.cpp
index 87d51d27d2a..fb6fb1702ea 100644
--- a/src/mongo/client/sasl_scramsha1_client_conversation.cpp
+++ b/src/mongo/client/sasl_scramsha1_client_conversation.cpp
@@ -187,11 +187,9 @@ StatusWith<bool> SaslSCRAMSHA1ClientConversation::_secondStep(const std::vector<
_saslClientSession->getParameter(SaslClientSession::parameterServiceHostAndPort));
if (targetHost.isOK()) {
- auto cachedSecrets = _clientCache->getCachedSecrets(targetHost.getValue(), presecrets);
+ _credentials = _clientCache->getCachedSecrets(targetHost.getValue(), presecrets);
- if (cachedSecrets) {
- _credentials = *cachedSecrets;
- } else {
+ if (!_credentials) {
_credentials = scram::generateSecrets(presecrets);
_clientCache->setCachedSecrets(
diff --git a/src/mongo/client/scram_sha1_client_cache.cpp b/src/mongo/client/scram_sha1_client_cache.cpp
index a087f27ca0a..ca3bf9eca14 100644
--- a/src/mongo/client/scram_sha1_client_cache.cpp
+++ b/src/mongo/client/scram_sha1_client_cache.cpp
@@ -32,7 +32,7 @@
namespace mongo {
-boost::optional<scram::SCRAMSecrets> SCRAMSHA1ClientCache::getCachedSecrets(
+scram::SCRAMSecrets SCRAMSHA1ClientCache::getCachedSecrets(
const HostAndPort& target, const scram::SCRAMPresecrets& presecrets) const {
const stdx::lock_guard<stdx::mutex> lock(_hostToSecretsMutex);
diff --git a/src/mongo/client/scram_sha1_client_cache.h b/src/mongo/client/scram_sha1_client_cache.h
index b3fba8734e2..ee595094ec5 100644
--- a/src/mongo/client/scram_sha1_client_cache.h
+++ b/src/mongo/client/scram_sha1_client_cache.h
@@ -68,8 +68,8 @@ public:
* match those recorded for the hostname. Otherwise, no secrets
* are returned.
*/
- boost::optional<scram::SCRAMSecrets> getCachedSecrets(
- const HostAndPort& target, const scram::SCRAMPresecrets& presecrets) const;
+ scram::SCRAMSecrets getCachedSecrets(const HostAndPort& target,
+ const scram::SCRAMPresecrets& presecrets) const;
/**
* Records a set of precomputed SCRAMSecrets for the specified
diff --git a/src/mongo/crypto/mechanism_scram.cpp b/src/mongo/crypto/mechanism_scram.cpp
index b4a543c6a54..c18bc39cc51 100644
--- a/src/mongo/crypto/mechanism_scram.cpp
+++ b/src/mongo/crypto/mechanism_scram.cpp
@@ -119,27 +119,27 @@ SCRAMSecrets generateSecrets(const SCRAMPresecrets& presecrets) {
}
SCRAMSecrets generateSecrets(const SHA1Block& saltedPassword) {
- SCRAMSecrets credentials;
-
- // ClientKey := HMAC(saltedPassword, "Client Key")
- credentials.clientKey =
- SHA1Block::computeHmac(saltedPassword.data(),
- saltedPassword.size(),
- reinterpret_cast<const unsigned char*>(clientKeyConst.data()),
- clientKeyConst.size());
-
- // StoredKey := H(clientKey)
- credentials.storedKey =
- SHA1Block::computeHash(credentials.clientKey->data(), credentials.clientKey->size());
-
- // ServerKey := HMAC(SaltedPassword, "Server Key")
- credentials.serverKey =
- SHA1Block::computeHmac(saltedPassword.data(),
- saltedPassword.size(),
- reinterpret_cast<const unsigned char*>(serverKeyConst.data()),
- serverKeyConst.size());
-
- return credentials;
+ auto generateAndStoreSecrets = [&saltedPassword](
+ SHA1Block& clientKey, SHA1Block& storedKey, SHA1Block& serverKey) {
+
+ // ClientKey := HMAC(saltedPassword, "Client Key")
+ clientKey =
+ SHA1Block::computeHmac(saltedPassword.data(),
+ saltedPassword.size(),
+ reinterpret_cast<const unsigned char*>(clientKeyConst.data()),
+ clientKeyConst.size());
+
+ // StoredKey := H(clientKey)
+ storedKey = SHA1Block::computeHash(clientKey.data(), clientKey.size());
+
+ // ServerKey := HMAC(SaltedPassword, "Server Key")
+ serverKey =
+ SHA1Block::computeHmac(saltedPassword.data(),
+ saltedPassword.size(),
+ reinterpret_cast<const unsigned char*>(serverKeyConst.data()),
+ serverKeyConst.size());
+ };
+ return SCRAMSecrets(std::move(generateAndStoreSecrets));
}
@@ -164,8 +164,8 @@ BSONObj generateCredentials(const std::string& hashedPassword, int iterationCoun
saltLenQWords * sizeof(uint64_t)),
iterationCount));
- std::string encodedStoredKey = secrets.storedKey->toString();
- std::string encodedServerKey = secrets.serverKey->toString();
+ std::string encodedStoredKey = secrets->storedKey.toString();
+ std::string encodedServerKey = secrets->serverKey.toString();
return BSON(iterationCountFieldName << iterationCount << saltFieldName << encodedUserSalt
<< storedKeyFieldName << encodedStoredKey
@@ -176,12 +176,12 @@ std::string generateClientProof(const SCRAMSecrets& clientCredentials,
const std::string& authMessage) {
// ClientSignature := HMAC(StoredKey, AuthMessage)
SHA1Block clientSignature =
- SHA1Block::computeHmac(clientCredentials.storedKey->data(),
- clientCredentials.storedKey->size(),
+ SHA1Block::computeHmac(clientCredentials->storedKey.data(),
+ clientCredentials->storedKey.size(),
reinterpret_cast<const unsigned char*>(authMessage.c_str()),
authMessage.size());
- clientSignature.xorInline(*clientCredentials.clientKey);
+ clientSignature.xorInline(clientCredentials->clientKey);
return clientSignature.toString();
}
@@ -190,8 +190,8 @@ bool verifyServerSignature(const SCRAMSecrets& clientCredentials,
const std::string& receivedServerSignature) {
// ServerSignature := HMAC(ServerKey, AuthMessage)
SHA1Block serverSignature =
- SHA1Block::computeHmac(clientCredentials.serverKey->data(),
- clientCredentials.serverKey->size(),
+ SHA1Block::computeHmac(clientCredentials->serverKey.data(),
+ clientCredentials->serverKey.size(),
reinterpret_cast<const unsigned char*>(authMessage.c_str()),
authMessage.size());
diff --git a/src/mongo/crypto/mechanism_scram.h b/src/mongo/crypto/mechanism_scram.h
index fb070e8162c..92b79efffb9 100644
--- a/src/mongo/crypto/mechanism_scram.h
+++ b/src/mongo/crypto/mechanism_scram.h
@@ -76,12 +76,48 @@ SHA1Block generateSaltedPassword(const SCRAMPresecrets& presecrets);
/*
* Stores all of the keys, generated from a password, needed for a client or server to perform a
- * SCRAM handshake. This structure will secureZeroMemory itself on destruction.
+ * SCRAM handshake.
+ * These keys are reference counted, and allocated using the SecureAllocator.
+ * May be unpopulated. SCRAMSecrets created via the default constructor are unpopulated.
+ * The behavior is undefined if the accessors are called when unpopulated.
*/
-struct SCRAMSecrets {
- SecureHandle<SHA1Block> clientKey;
- SecureHandle<SHA1Block> storedKey;
- SecureHandle<SHA1Block> serverKey;
+class SCRAMSecrets {
+private:
+ struct SCRAMSecretsHolder {
+ SHA1Block clientKey;
+ SHA1Block storedKey;
+ SHA1Block serverKey;
+ };
+
+public:
+ // Creates an unpopulated SCRAMSecrets object.
+ SCRAMSecrets() = default;
+
+ // Creates a populated SCRAMSecrets object. First, allocates secure storage, then provides it
+ // to a callback, which fills the memory.
+ template <typename T>
+ explicit SCRAMSecrets(T initializationFun)
+ : _ptr(std::make_shared<SecureHandle<SCRAMSecretsHolder>>()) {
+ initializationFun((*this)->clientKey, (*this)->storedKey, (*this)->serverKey);
+ }
+
+ // Returns true if the underlying shared_pointer is populated.
+ explicit operator bool() const {
+ return static_cast<bool>(_ptr);
+ }
+
+ const SecureHandle<SCRAMSecretsHolder>& operator*() const {
+ invariant(_ptr);
+ return *_ptr;
+ }
+
+ const SecureHandle<SCRAMSecretsHolder>& operator->() const {
+ invariant(_ptr);
+ return *_ptr;
+ }
+
+private:
+ std::shared_ptr<SecureHandle<SCRAMSecretsHolder>> _ptr;
};
/*
diff --git a/src/mongo/crypto/sha1_block.h b/src/mongo/crypto/sha1_block.h
index 280177de309..266f6602e6c 100644
--- a/src/mongo/crypto/sha1_block.h
+++ b/src/mongo/crypto/sha1_block.h
@@ -63,7 +63,21 @@ public:
static SHA1Block computeHmac(const uint8_t* key,
size_t keyLen,
const uint8_t* input,
- size_t inputLen);
+ size_t inputLen) {
+ SHA1Block output;
+ SHA1Block::computeHmac(key, keyLen, input, inputLen, &output);
+ return output;
+ }
+
+ /**
+ * Computes a HMAC SHA-1 keyed hash of 'input' using the key 'key'. Writes the results into
+ * a pre-allocated SHA1Block. This lets us allocate SHA1Blocks with the SecureAllocator.
+ */
+ static void computeHmac(const uint8_t* key,
+ size_t keyLen,
+ const uint8_t* input,
+ size_t inputLen,
+ SHA1Block* const output);
const uint8_t* data() const {
return _hash.data();
diff --git a/src/mongo/crypto/sha1_block_openssl.cpp b/src/mongo/crypto/sha1_block_openssl.cpp
index 50d18b9bace..ce0cab8e9e6 100644
--- a/src/mongo/crypto/sha1_block_openssl.cpp
+++ b/src/mongo/crypto/sha1_block_openssl.cpp
@@ -82,15 +82,15 @@ SHA1Block SHA1Block::computeHash(const uint8_t* input, size_t inputLen) {
}
/*
- * Computes a HMAC SHA-1 keyed hash of 'input' using the key 'key'
+ * Computes a HMAC SHA-1 keyed hash of 'input' using the key 'key', writes output into 'output'.
*/
-SHA1Block SHA1Block::computeHmac(const uint8_t* key,
- size_t keyLen,
- const uint8_t* input,
- size_t inputLen) {
- HashType output;
- fassert(40380, HMAC(EVP_sha1(), key, keyLen, input, inputLen, output.data(), NULL) != NULL);
- return SHA1Block(output);
+void SHA1Block::computeHmac(const uint8_t* key,
+ size_t keyLen,
+ const uint8_t* input,
+ size_t inputLen,
+ SHA1Block* const output) {
+ fassert(40380,
+ HMAC(EVP_sha1(), key, keyLen, input, inputLen, output->_hash.data(), NULL) != NULL);
}
} // namespace mongo
diff --git a/src/mongo/crypto/sha1_block_test.cpp b/src/mongo/crypto/sha1_block_test.cpp
index 111132ff388..651b2e52bfe 100644
--- a/src/mongo/crypto/sha1_block_test.cpp
+++ b/src/mongo/crypto/sha1_block_test.cpp
@@ -40,13 +40,49 @@ namespace {
const struct {
const char* msg;
SHA1Block hash;
-} sha1Tests[] = {
- {"abc", SHA1Block::HashType{0xa9, 0x99, 0x3e, 0x36, 0x47, 0x06, 0x81, 0x6a, 0xba, 0x3e,
- 0x25, 0x71, 0x78, 0x50, 0xc2, 0x6c, 0x9c, 0xd0, 0xd8, 0x9d}},
+} sha1Tests[] = {{"abc",
+ SHA1Block::HashType{0xa9,
+ 0x99,
+ 0x3e,
+ 0x36,
+ 0x47,
+ 0x06,
+ 0x81,
+ 0x6a,
+ 0xba,
+ 0x3e,
+ 0x25,
+ 0x71,
+ 0x78,
+ 0x50,
+ 0xc2,
+ 0x6c,
+ 0x9c,
+ 0xd0,
+ 0xd8,
+ 0x9d}},
- {"abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq",
- SHA1Block::HashType{0x84, 0x98, 0x3E, 0x44, 0x1C, 0x3B, 0xD2, 0x6E, 0xBA, 0xAE,
- 0x4A, 0xA1, 0xF9, 0x51, 0x29, 0xE5, 0xE5, 0x46, 0x70, 0xF1}}};
+ {"abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq",
+ SHA1Block::HashType{0x84,
+ 0x98,
+ 0x3E,
+ 0x44,
+ 0x1C,
+ 0x3B,
+ 0xD2,
+ 0x6E,
+ 0xBA,
+ 0xAE,
+ 0x4A,
+ 0xA1,
+ 0xF9,
+ 0x51,
+ 0x29,
+ 0xE5,
+ 0xE5,
+ 0x46,
+ 0x70,
+ 0xF1}}};
TEST(CryptoVectors, SHA1) {
size_t numTests = sizeof(sha1Tests) / sizeof(sha1Tests[0]);
@@ -91,8 +127,26 @@ const struct {
20,
{0x48, 0x69, 0x20, 0x54, 0x68, 0x65, 0x72, 0x65},
8,
- SHA1Block::HashType{0xb6, 0x17, 0x31, 0x86, 0x55, 0x05, 0x72, 0x64, 0xe2, 0x8b,
- 0xc0, 0xb6, 0xfb, 0x37, 0x8c, 0x8e, 0xf1, 0x46, 0xbe, 0x00}},
+ SHA1Block::HashType{0xb6,
+ 0x17,
+ 0x31,
+ 0x86,
+ 0x55,
+ 0x05,
+ 0x72,
+ 0x64,
+ 0xe2,
+ 0x8b,
+ 0xc0,
+ 0xb6,
+ 0xfb,
+ 0x37,
+ 0x8c,
+ 0x8e,
+ 0xf1,
+ 0x46,
+ 0xbe,
+ 0x00}},
// RFC test case 3
{{0xaa,
@@ -167,8 +221,26 @@ const struct {
0xdd,
0xdd},
50,
- SHA1Block::HashType{0x12, 0x5d, 0x73, 0x42, 0xb9, 0xac, 0x11, 0xcd, 0x91, 0xa3,
- 0x9a, 0xf4, 0x8a, 0xa1, 0x7b, 0x4f, 0x63, 0xf1, 0x75, 0xd3}},
+ SHA1Block::HashType{0x12,
+ 0x5d,
+ 0x73,
+ 0x42,
+ 0xb9,
+ 0xac,
+ 0x11,
+ 0xcd,
+ 0x91,
+ 0xa3,
+ 0x9a,
+ 0xf4,
+ 0x8a,
+ 0xa1,
+ 0x7b,
+ 0x4f,
+ 0x63,
+ 0xf1,
+ 0x75,
+ 0xd3}},
// RFC test case 4
{{0x01,
@@ -248,8 +320,26 @@ const struct {
0xcd,
0xcd},
50,
- SHA1Block::HashType{0x4c, 0x90, 0x07, 0xf4, 0x02, 0x62, 0x50, 0xc6, 0xbc, 0x84,
- 0x14, 0xf9, 0xbf, 0x50, 0xc8, 0x6c, 0x2d, 0x72, 0x35, 0xda}},
+ SHA1Block::HashType{0x4c,
+ 0x90,
+ 0x07,
+ 0xf4,
+ 0x02,
+ 0x62,
+ 0x50,
+ 0xc6,
+ 0xbc,
+ 0x84,
+ 0x14,
+ 0xf9,
+ 0xbf,
+ 0x50,
+ 0xc8,
+ 0x6c,
+ 0x2d,
+ 0x72,
+ 0x35,
+ 0xda}},
// RFC test case 6
{{0xaa,
@@ -388,8 +478,26 @@ const struct {
0x73,
0x74},
54,
- SHA1Block::HashType{0xaa, 0x4a, 0xe5, 0xe1, 0x52, 0x72, 0xd0, 0x0e, 0x95, 0x70,
- 0x56, 0x37, 0xce, 0x8a, 0x3b, 0x55, 0xed, 0x40, 0x21, 0x12}}};
+ SHA1Block::HashType{0xaa,
+ 0x4a,
+ 0xe5,
+ 0xe1,
+ 0x52,
+ 0x72,
+ 0xd0,
+ 0x0e,
+ 0x95,
+ 0x70,
+ 0x56,
+ 0x37,
+ 0xce,
+ 0x8a,
+ 0x3b,
+ 0x55,
+ 0xed,
+ 0x40,
+ 0x21,
+ 0x12}}};
TEST(CryptoVectors, HMACSHA1) {
size_t numTests = sizeof(hmacSha1Tests) / sizeof(hmacSha1Tests[0]);
diff --git a/src/mongo/crypto/sha1_block_tom.cpp b/src/mongo/crypto/sha1_block_tom.cpp
index d1078e60458..7d32dbc2f53 100644
--- a/src/mongo/crypto/sha1_block_tom.cpp
+++ b/src/mongo/crypto/sha1_block_tom.cpp
@@ -58,12 +58,12 @@ SHA1Block SHA1Block::computeHash(const uint8_t* input, size_t inputLen) {
/*
* Computes a HMAC SHA-1 keyed hash of 'input' using the key 'key'
*/
-SHA1Block SHA1Block::computeHmac(const uint8_t* key,
- size_t keyLen,
- const uint8_t* input,
- size_t inputLen) {
+void SHA1Block::computeHmac(const uint8_t* key,
+ size_t keyLen,
+ const uint8_t* input,
+ size_t inputLen,
+ SHA1Block* const output) {
invariant(key && input);
- HashType output;
static int hashId = -1;
if (hashId == -1) {
@@ -73,9 +73,8 @@ SHA1Block SHA1Block::computeHmac(const uint8_t* key,
unsigned long sha1HashLen = 20;
fassert(40382,
- hmac_memory(hashId, key, keyLen, input, inputLen, output.data(), &sha1HashLen) ==
+ hmac_memory(hashId, key, keyLen, input, inputLen, output->_hash.data(), &sha1HashLen) ==
CRYPT_OK);
- return SHA1Block(output);
}
} // namespace mongo
diff --git a/src/mongo/db/auth/SConscript b/src/mongo/db/auth/SConscript
index c16caede0ae..1201d5e8d0e 100644
--- a/src/mongo/db/auth/SConscript
+++ b/src/mongo/db/auth/SConscript
@@ -188,3 +188,10 @@ env.Library(
'authcore',
]
)
+
+env.CppUnitTest('sasl_scramsha1_test',
+ 'sasl_scramsha1_test.cpp',
+ LIBDEPS=[
+ 'saslauth',
+ '$BUILD_DIR/mongo/client/sasl_client',
+ ])
diff --git a/src/mongo/db/auth/sasl_plain_server_conversation.cpp b/src/mongo/db/auth/sasl_plain_server_conversation.cpp
index d68a1ab92ad..d1e993db4a1 100644
--- a/src/mongo/db/auth/sasl_plain_server_conversation.cpp
+++ b/src/mongo/db/auth/sasl_plain_server_conversation.cpp
@@ -88,8 +88,8 @@ StatusWith<bool> SaslPLAINServerConversation::step(StringData inputData, std::st
16),
creds.scram.iterationCount));
if (creds.scram.storedKey !=
- base64::encode(reinterpret_cast<const char*>(secrets.storedKey->data()),
- secrets.storedKey->size())) {
+ base64::encode(reinterpret_cast<const char*>(secrets->storedKey.data()),
+ secrets->storedKey.size())) {
return StatusWith<bool>(ErrorCodes::AuthenticationFailed,
mongoutils::str::stream() << "Incorrect user name or password");
}
diff --git a/src/mongo/db/auth/sasl_scramsha1_test.cpp b/src/mongo/db/auth/sasl_scramsha1_test.cpp
new file mode 100644
index 00000000000..66cddabe02e
--- /dev/null
+++ b/src/mongo/db/auth/sasl_scramsha1_test.cpp
@@ -0,0 +1,501 @@
+/*
+ * Copyright (C) 2016 MongoDB Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License, version 3,
+ * as published by the Free Software Foundation.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the GNU Affero General Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/platform/basic.h"
+
+#include "mongo/client/native_sasl_client_session.h"
+#include "mongo/client/scram_sha1_client_cache.h"
+#include "mongo/crypto/mechanism_scram.h"
+#include "mongo/db/auth/authorization_manager.h"
+#include "mongo/db/auth/authz_manager_external_state_mock.h"
+#include "mongo/db/auth/authz_session_external_state_mock.h"
+#include "mongo/db/auth/native_sasl_authentication_session.h"
+#include "mongo/db/auth/sasl_scramsha1_server_conversation.h"
+#include "mongo/db/service_context_noop.h"
+#include "mongo/stdx/memory.h"
+#include "mongo/unittest/unittest.h"
+#include "mongo/util/base64.h"
+#include "mongo/util/password_digest.h"
+
+namespace mongo {
+
+BSONObj generateSCRAMUserDocument(StringData username, StringData password) {
+ const size_t scramIterationCount = 10000;
+ std::string database = "test";
+
+ std::string digested = createPasswordDigest(username, password);
+ BSONObj scramCred = scram::generateCredentials(digested, scramIterationCount);
+ return BSON("_id" << (str::stream() << database << "." << username).operator std::string()
+ << AuthorizationManager::USER_NAME_FIELD_NAME << username
+ << AuthorizationManager::USER_DB_FIELD_NAME << database << "credentials"
+ << BSON("SCRAM-SHA-1" << scramCred) << "roles" << BSONArray() << "privileges"
+ << BSONArray());
+}
+
+BSONObj generateMONGODBCRUserDocument(StringData username, StringData password) {
+ std::string database = "test";
+
+ std::string digested = createPasswordDigest(username, password);
+ return BSON("_id" << (str::stream() << database << "." << username).operator std::string()
+ << AuthorizationManager::USER_NAME_FIELD_NAME << username
+ << AuthorizationManager::USER_DB_FIELD_NAME << database << "credentials"
+ << BSON("MONGODB-CR" << digested) << "roles" << BSONArray() << "privileges"
+ << BSONArray());
+}
+
+std::string corruptEncodedPayload(const std::string& message,
+ std::string::const_iterator begin,
+ std::string::const_iterator end) {
+ std::string raw = base64::decode(
+ message.substr(std::distance(message.begin(), begin), std::distance(begin, end)));
+ if (raw[0] == std::numeric_limits<char>::max()) {
+ raw[0] -= 1;
+ } else {
+ raw[0] += 1;
+ }
+ return base64::encode(raw);
+}
+
+class SaslTestState {
+public:
+ enum Participant { kClient, kServer };
+ SaslTestState() : SaslTestState(kClient, 0) {}
+ SaslTestState(Participant participant, size_t stage) : participant(participant), stage(stage) {}
+
+private:
+ // Define members here, so that they can be used in declaration of lens(). In C++14, lens()
+ // can be declared with a return of decltype(auto), without a trailing return type, and these
+ // members can go at the end of the class.
+ Participant participant;
+ size_t stage;
+
+public:
+ std::tuple<size_t, Participant> lens() const {
+ return std::tie(stage, participant);
+ }
+
+ friend bool operator==(const SaslTestState& lhs, const SaslTestState& rhs) {
+ return lhs.lens() == rhs.lens();
+ }
+
+ friend bool operator<(const SaslTestState& lhs, const SaslTestState& rhs) {
+ return lhs.lens() < rhs.lens();
+ }
+
+ void next() {
+ if (participant == kClient) {
+ participant = kServer;
+ } else {
+ participant = kClient;
+ stage++;
+ }
+ }
+
+ std::string toString() const {
+ std::stringstream ss;
+ if (participant == kClient) {
+ ss << "Client";
+ } else {
+ ss << "Server";
+ }
+ ss << "Step" << stage;
+
+ return ss.str();
+ }
+};
+
+class SCRAMMutators {
+public:
+ SCRAMMutators() {}
+
+ void setMutator(SaslTestState state, stdx::function<void(std::string&)> fun) {
+ mutators.insert(std::make_pair(state, fun));
+ }
+
+ void execute(SaslTestState state, std::string& str) {
+ auto it = mutators.find(state);
+ if (it != mutators.end()) {
+ it->second(str);
+ }
+ }
+
+private:
+ std::map<SaslTestState, stdx::function<void(std::string&)>> mutators;
+};
+
+struct SCRAMStepsResult {
+ SCRAMStepsResult() : outcome(SaslTestState::kClient, 1), status(Status::OK()) {}
+ SCRAMStepsResult(SaslTestState outcome, Status status) : outcome(outcome), status(status) {}
+ bool operator==(const SCRAMStepsResult& other) const {
+ return outcome == other.outcome && status.code() == other.status.code() &&
+ status.reason() == other.status.reason();
+ }
+ SaslTestState outcome;
+ Status status;
+
+ friend std::ostream& operator<<(std::ostream& os, const SCRAMStepsResult& result) {
+ return os << "{outcome: " << result.outcome.toString() << ", status: " << result.status
+ << "}";
+ }
+};
+
+SCRAMStepsResult runSteps(NativeSaslAuthenticationSession* saslServerSession,
+ NativeSaslClientSession* saslClientSession,
+ SCRAMMutators interposers = SCRAMMutators{}) {
+ SCRAMStepsResult result{};
+ std::string clientOutput = "";
+ std::string serverOutput = "";
+
+ for (size_t step = 1; step <= 3; step++) {
+ ASSERT_FALSE(saslClientSession->isDone());
+ ASSERT_FALSE(saslServerSession->isDone());
+
+ // Client step
+ result.status = saslClientSession->step(serverOutput, &clientOutput);
+ if (result.status != Status::OK()) {
+ return result;
+ }
+ std::cout << result.outcome.toString() << ": " << clientOutput << std::endl;
+ interposers.execute(result.outcome, clientOutput);
+ result.outcome.next();
+
+ // Server step
+ result.status = saslServerSession->step(clientOutput, &serverOutput);
+ if (result.status != Status::OK()) {
+ return result;
+ }
+ interposers.execute(result.outcome, serverOutput);
+ std::cout << result.outcome.toString() << ": " << serverOutput << std::endl;
+ result.outcome.next();
+ }
+ ASSERT_TRUE(saslClientSession->isDone());
+ ASSERT_TRUE(saslServerSession->isDone());
+
+ return result;
+}
+
+class SCRAMSHA1Fixture : public mongo::unittest::Test {
+protected:
+ const SCRAMStepsResult goalState =
+ SCRAMStepsResult(SaslTestState(SaslTestState::kClient, 4), Status::OK());
+
+ ServiceContextNoop serviceContext;
+ ServiceContextNoop::UniqueClient client;
+ ServiceContextNoop::UniqueOperationContext txn;
+
+ AuthzManagerExternalStateMock* authzManagerExternalState;
+ std::unique_ptr<AuthorizationManager> authzManager;
+ std::unique_ptr<AuthorizationSession> authzSession;
+
+ std::unique_ptr<NativeSaslAuthenticationSession> saslServerSession;
+ std::unique_ptr<NativeSaslClientSession> saslClientSession;
+
+ void setUp() {
+ client = serviceContext.makeClient("test");
+ txn = serviceContext.makeOperationContext(client.get());
+
+ auto uniqueAuthzManagerExternalStateMock =
+ stdx::make_unique<AuthzManagerExternalStateMock>();
+ authzManagerExternalState = uniqueAuthzManagerExternalStateMock.get();
+ authzManager =
+ stdx::make_unique<AuthorizationManager>(std::move(uniqueAuthzManagerExternalStateMock));
+ authzSession = stdx::make_unique<AuthorizationSession>(
+ stdx::make_unique<AuthzSessionExternalStateMock>(authzManager.get()));
+
+ saslServerSession = stdx::make_unique<NativeSaslAuthenticationSession>(authzSession.get());
+ saslServerSession->setOpCtxt(txn.get());
+ saslServerSession->start("test", "SCRAM-SHA-1", "mongodb", "MockServer.test", 1, false);
+ saslClientSession = stdx::make_unique<NativeSaslClientSession>();
+ saslClientSession->setParameter(NativeSaslClientSession::parameterMechanism, "SCRAM-SHA-1");
+ saslClientSession->setParameter(NativeSaslClientSession::parameterServiceName, "mongodb");
+ saslClientSession->setParameter(NativeSaslClientSession::parameterServiceHostname,
+ "MockServer.test");
+ saslClientSession->setParameter(NativeSaslClientSession::parameterServiceHostAndPort,
+ "MockServer.test:27017");
+ }
+};
+
+/*TEST_F(SCRAMSHA1Fixture, testServerStep1DoesNotIncludeNonceFromClientStep1) {
+ authzManagerExternalState->insertPrivilegeDocument(
+ txn.get(), generateSCRAMUserDocument("sajack", "sajack"), BSONObj());
+
+ saslClientSession->setParameter(NativeSaslClientSession::parameterUser, "sajack");
+ saslClientSession->setParameter(NativeSaslClientSession::parameterPassword,
+ createPasswordDigest("sajack", "sajack"));
+
+ ASSERT_OK(saslClientSession->initialize());
+
+ SCRAMMutators mutator;
+ mutator.setMutator(SaslTestState(SaslTestState::kServer, 1), [](std::string& serverMessage) {
+ std::string::iterator nonceBegin = serverMessage.begin() + serverMessage.find("r=");
+ std::string::iterator nonceEnd = std::find(nonceBegin, serverMessage.end(), ',');
+ serverMessage = serverMessage.replace(nonceBegin, nonceEnd, "r=");
+
+ });
+ ASSERT_EQ(SCRAMStepsResult(SaslTestState(SaslTestState::kClient, 2),
+ Status(ErrorCodes::BadValue,
+ "Server SCRAM-SHA-1 nonce does not match client nonce: r=")),
+ runSteps(saslServerSession.get(), saslClientSession.get(), mutator));
+}*/
+
+TEST_F(SCRAMSHA1Fixture, testClientStep2DoesNotIncludeNonceFromServerStep1) {
+ authzManagerExternalState->insertPrivilegeDocument(
+ txn.get(), generateSCRAMUserDocument("sajack", "sajack"), BSONObj());
+
+ saslClientSession->setParameter(NativeSaslClientSession::parameterUser, "sajack");
+ saslClientSession->setParameter(NativeSaslClientSession::parameterPassword,
+ createPasswordDigest("sajack", "sajack"));
+
+ ASSERT_OK(saslClientSession->initialize());
+
+ SCRAMMutators mutator;
+ mutator.setMutator(SaslTestState(SaslTestState::kClient, 2),
+ [](std::string& clientMessage) {
+ std::string::iterator nonceBegin =
+ clientMessage.begin() + clientMessage.find("r=");
+ std::string::iterator nonceEnd =
+ std::find(nonceBegin, clientMessage.end(), ',');
+ clientMessage = clientMessage.replace(nonceBegin, nonceEnd, "r=");
+ });
+ ASSERT_EQ(SCRAMStepsResult(
+ SaslTestState(SaslTestState::kServer, 2),
+ Status(ErrorCodes::BadValue, "Incorrect SCRAM-SHA-1 client|server nonce: r=")),
+ runSteps(saslServerSession.get(), saslClientSession.get(), mutator));
+}
+
+TEST_F(SCRAMSHA1Fixture, testClientStep2GivesBadProof) {
+ authzManagerExternalState->insertPrivilegeDocument(
+ txn.get(), generateSCRAMUserDocument("sajack", "sajack"), BSONObj());
+
+ saslClientSession->setParameter(NativeSaslClientSession::parameterUser, "sajack");
+ saslClientSession->setParameter(NativeSaslClientSession::parameterPassword,
+ createPasswordDigest("sajack", "sajack"));
+
+ ASSERT_OK(saslClientSession->initialize());
+
+ SCRAMMutators mutator;
+ mutator.setMutator(
+ SaslTestState(SaslTestState::kClient, 2),
+ [](std::string& clientMessage) {
+ std::string::iterator proofBegin = clientMessage.begin() + clientMessage.find("p=") + 2;
+ std::string::iterator proofEnd = std::find(proofBegin, clientMessage.end(), ',');
+ clientMessage = clientMessage.replace(
+ proofBegin, proofEnd, corruptEncodedPayload(clientMessage, proofBegin, proofEnd));
+
+ });
+
+ ASSERT_EQ(SCRAMStepsResult(SaslTestState(SaslTestState::kServer, 2),
+ Status(ErrorCodes::AuthenticationFailed,
+ "SCRAM-SHA-1 authentication failed, storedKey mismatch")),
+ runSteps(saslServerSession.get(), saslClientSession.get(), mutator));
+}
+
+TEST_F(SCRAMSHA1Fixture, testServerStep2GivesBadVerifier) {
+ authzManagerExternalState->insertPrivilegeDocument(
+ txn.get(), generateSCRAMUserDocument("sajack", "sajack"), BSONObj());
+
+ saslClientSession->setParameter(NativeSaslClientSession::parameterUser, "sajack");
+ saslClientSession->setParameter(NativeSaslClientSession::parameterPassword,
+ createPasswordDigest("sajack", "sajack"));
+
+ ASSERT_OK(saslClientSession->initialize());
+
+ std::string encodedVerifier;
+ SCRAMMutators mutator;
+ mutator.setMutator(
+ SaslTestState(SaslTestState::kServer, 2),
+ [&encodedVerifier](std::string& serverMessage) {
+ std::string::iterator verifierBegin =
+ serverMessage.begin() + serverMessage.find("v=") + 2;
+ std::string::iterator verifierEnd = std::find(verifierBegin, serverMessage.end(), ',');
+ encodedVerifier = corruptEncodedPayload(serverMessage, verifierBegin, verifierEnd);
+
+ serverMessage = serverMessage.replace(verifierBegin, verifierEnd, encodedVerifier);
+
+ });
+
+ auto result = runSteps(saslServerSession.get(), saslClientSession.get(), mutator);
+
+ ASSERT_EQ(
+ SCRAMStepsResult(
+ SaslTestState(SaslTestState::kClient, 3),
+ Status(ErrorCodes::BadValue,
+ str::stream() << "Client failed to verify SCRAM-SHA-1 ServerSignature, received "
+ << encodedVerifier)),
+ result);
+}
+
+
+TEST_F(SCRAMSHA1Fixture, testSCRAM) {
+ authzManagerExternalState->insertPrivilegeDocument(
+ txn.get(), generateSCRAMUserDocument("sajack", "sajack"), BSONObj());
+
+ saslClientSession->setParameter(NativeSaslClientSession::parameterUser, "sajack");
+ saslClientSession->setParameter(NativeSaslClientSession::parameterPassword,
+ createPasswordDigest("sajack", "sajack"));
+
+ ASSERT_OK(saslClientSession->initialize());
+
+ ASSERT_EQ(goalState, runSteps(saslServerSession.get(), saslClientSession.get()));
+}
+
+TEST_F(SCRAMSHA1Fixture, testNULLInPassword) {
+ authzManagerExternalState->insertPrivilegeDocument(
+ txn.get(), generateSCRAMUserDocument("sajack", "saj\0ack"), BSONObj());
+
+ saslClientSession->setParameter(NativeSaslClientSession::parameterUser, "sajack");
+ saslClientSession->setParameter(NativeSaslClientSession::parameterPassword,
+ createPasswordDigest("sajack", "saj\0ack"));
+
+ ASSERT_OK(saslClientSession->initialize());
+
+ ASSERT_EQ(goalState, runSteps(saslServerSession.get(), saslClientSession.get()));
+}
+
+
+TEST_F(SCRAMSHA1Fixture, testCommasInUsernameAndPassword) {
+ authzManagerExternalState->insertPrivilegeDocument(
+ txn.get(), generateSCRAMUserDocument("s,a,jack", "s,a,jack"), BSONObj());
+
+ saslClientSession->setParameter(NativeSaslClientSession::parameterUser, "s,a,jack");
+ saslClientSession->setParameter(NativeSaslClientSession::parameterPassword,
+ createPasswordDigest("s,a,jack", "s,a,jack"));
+
+ ASSERT_OK(saslClientSession->initialize());
+
+ ASSERT_EQ(goalState, runSteps(saslServerSession.get(), saslClientSession.get()));
+}
+
+TEST_F(SCRAMSHA1Fixture, testIncorrectUser) {
+ saslClientSession->setParameter(NativeSaslClientSession::parameterUser, "sajack");
+ saslClientSession->setParameter(NativeSaslClientSession::parameterPassword,
+ createPasswordDigest("sajack", "sajack"));
+
+ ASSERT_OK(saslClientSession->initialize());
+
+ ASSERT_EQ(SCRAMStepsResult(SaslTestState(SaslTestState::kServer, 1),
+ Status(ErrorCodes::UserNotFound, "Could not find user sajack@test")),
+ runSteps(saslServerSession.get(), saslClientSession.get()));
+}
+
+TEST_F(SCRAMSHA1Fixture, testIncorrectPassword) {
+ authzManagerExternalState->insertPrivilegeDocument(
+ txn.get(), generateSCRAMUserDocument("sajack", "sajack"), BSONObj());
+
+ saslClientSession->setParameter(NativeSaslClientSession::parameterUser, "sajack");
+ saslClientSession->setParameter(NativeSaslClientSession::parameterPassword,
+ createPasswordDigest("sajack", "invalidPassword"));
+
+ ASSERT_OK(saslClientSession->initialize());
+
+ ASSERT_EQ(SCRAMStepsResult(SaslTestState(SaslTestState::kServer, 2),
+ Status(ErrorCodes::AuthenticationFailed,
+ "SCRAM-SHA-1 authentication failed, storedKey mismatch")),
+ runSteps(saslServerSession.get(), saslClientSession.get()));
+}
+
+TEST_F(SCRAMSHA1Fixture, testMONGODBCR) {
+ authzManagerExternalState->insertPrivilegeDocument(
+ txn.get(), generateMONGODBCRUserDocument("sajack", "sajack"), BSONObj());
+
+ saslClientSession->setParameter(NativeSaslClientSession::parameterUser, "sajack");
+ saslClientSession->setParameter(NativeSaslClientSession::parameterPassword,
+ createPasswordDigest("sajack", "sajack"));
+
+ ASSERT_OK(saslClientSession->initialize());
+
+ ASSERT_EQ(goalState, runSteps(saslServerSession.get(), saslClientSession.get()));
+}
+
+TEST(SCRAMSHA1Cache, testGetFromEmptyCache) {
+ SCRAMSHA1ClientCache cache;
+ std::string saltStr("saltsaltsaltsalt");
+ std::vector<std::uint8_t> salt(saltStr.begin(), saltStr.end());
+ HostAndPort host("localhost:27017");
+
+ ASSERT_FALSE(cache.getCachedSecrets(host, scram::SCRAMPresecrets("aaa", salt, 10000)));
+}
+
+
+TEST(SCRAMSHA1Cache, testSetAndGet) {
+ SCRAMSHA1ClientCache cache;
+ std::string saltStr("saltsaltsaltsalt");
+ std::string badSaltStr("s@lts@lts@lts@lt");
+ std::vector<std::uint8_t> salt(saltStr.begin(), saltStr.end());
+ std::vector<std::uint8_t> badSalt(badSaltStr.begin(), badSaltStr.end());
+ HostAndPort host("localhost:27017");
+
+ auto secret = scram::generateSecrets(scram::SCRAMPresecrets("aaa", salt, 10000));
+ cache.setCachedSecrets(host, scram::SCRAMPresecrets("aaa", salt, 10000), secret);
+ auto cachedSecret = cache.getCachedSecrets(host, scram::SCRAMPresecrets("aaa", salt, 10000));
+ ASSERT_TRUE(cachedSecret);
+ ASSERT_TRUE(secret->clientKey == cachedSecret->clientKey);
+ ASSERT_TRUE(secret->serverKey == cachedSecret->serverKey);
+ ASSERT_TRUE(secret->storedKey == cachedSecret->storedKey);
+}
+
+
+TEST(SCRAMSHA1Cache, testSetAndGetWithDifferentParameters) {
+ SCRAMSHA1ClientCache cache;
+ std::string saltStr("saltsaltsaltsalt");
+ std::string badSaltStr("s@lts@lts@lts@lt");
+ std::vector<std::uint8_t> salt(saltStr.begin(), saltStr.end());
+ std::vector<std::uint8_t> badSalt(badSaltStr.begin(), badSaltStr.end());
+ HostAndPort host("localhost:27017");
+
+ auto secret = scram::generateSecrets(scram::SCRAMPresecrets("aaa", salt, 10000));
+ cache.setCachedSecrets(host, scram::SCRAMPresecrets("aaa", salt, 10000), secret);
+
+ ASSERT_FALSE(cache.getCachedSecrets(HostAndPort("localhost:27018"),
+ scram::SCRAMPresecrets("aaa", salt, 10000)));
+ ASSERT_FALSE(cache.getCachedSecrets(host, scram::SCRAMPresecrets("aab", salt, 10000)));
+ ASSERT_FALSE(cache.getCachedSecrets(host, scram::SCRAMPresecrets("aaa", badSalt, 10000)));
+ ASSERT_FALSE(cache.getCachedSecrets(host, scram::SCRAMPresecrets("aaa", salt, 10001)));
+}
+
+
+TEST(SCRAMSHA1Cache, testSetAndReset) {
+ SCRAMSHA1ClientCache cache;
+ StringData saltStr("saltsaltsaltsalt");
+ std::vector<std::uint8_t> salt(saltStr.begin(), saltStr.end());
+ HostAndPort host("localhost:27017");
+
+ auto secret = scram::generateSecrets(scram::SCRAMPresecrets("aaa", salt, 10000));
+ cache.setCachedSecrets(host, scram::SCRAMPresecrets("aaa", salt, 10000), secret);
+ auto newSecret = scram::generateSecrets(scram::SCRAMPresecrets("aab", salt, 10000));
+ cache.setCachedSecrets(host, scram::SCRAMPresecrets("aab", salt, 10000), newSecret);
+
+ ASSERT_FALSE(cache.getCachedSecrets(host, scram::SCRAMPresecrets("aaa", salt, 10000)));
+ auto cachedSecret = cache.getCachedSecrets(host, scram::SCRAMPresecrets("aab", salt, 10000));
+ ASSERT_TRUE(cachedSecret);
+ ASSERT_TRUE(newSecret->clientKey == cachedSecret->clientKey);
+ ASSERT_TRUE(newSecret->serverKey == cachedSecret->serverKey);
+ ASSERT_TRUE(newSecret->storedKey == cachedSecret->storedKey);
+}
+
+} // namespace mongo