diff options
author | Spencer Jackson <spencer.jackson@mongodb.com> | 2017-04-28 18:34:21 -0400 |
---|---|---|
committer | Spencer Jackson <spencer.jackson@mongodb.com> | 2017-07-11 16:20:07 -0400 |
commit | 764b75a48f57c84ea8c0b867b3128e1d8760086a (patch) | |
tree | f13685f83219fa0c59b6b3e3368bb1a6899685ed | |
parent | b686a69d6ed4653c6973dc62b50eb7b40df87fd4 (diff) | |
download | mongo-764b75a48f57c84ea8c0b867b3128e1d8760086a.tar.gz |
SERVER-28997: Limit SCRAM-SHA-1 cache's use of Secure Memory
(cherry picked from commit 7ca9cebf2623865fd0077f90baf61132d866a674)
(cherry picked from commit 8a4d00991cd1721240f13c8713d7d819baa1763e)
-rw-r--r-- | src/mongo/client/sasl_scramsha1_client_conversation.cpp | 6 | ||||
-rw-r--r-- | src/mongo/client/scram_sha1_client_cache.cpp | 2 | ||||
-rw-r--r-- | src/mongo/client/scram_sha1_client_cache.h | 4 | ||||
-rw-r--r-- | src/mongo/crypto/mechanism_scram.cpp | 56 | ||||
-rw-r--r-- | src/mongo/crypto/mechanism_scram.h | 46 | ||||
-rw-r--r-- | src/mongo/crypto/sha1_block.h | 16 | ||||
-rw-r--r-- | src/mongo/crypto/sha1_block_openssl.cpp | 16 | ||||
-rw-r--r-- | src/mongo/crypto/sha1_block_test.cpp | 136 | ||||
-rw-r--r-- | src/mongo/crypto/sha1_block_tom.cpp | 13 | ||||
-rw-r--r-- | src/mongo/db/auth/SConscript | 7 | ||||
-rw-r--r-- | src/mongo/db/auth/sasl_plain_server_conversation.cpp | 4 | ||||
-rw-r--r-- | src/mongo/db/auth/sasl_scramsha1_test.cpp | 501 |
12 files changed, 735 insertions, 72 deletions
diff --git a/src/mongo/client/sasl_scramsha1_client_conversation.cpp b/src/mongo/client/sasl_scramsha1_client_conversation.cpp index 87d51d27d2a..fb6fb1702ea 100644 --- a/src/mongo/client/sasl_scramsha1_client_conversation.cpp +++ b/src/mongo/client/sasl_scramsha1_client_conversation.cpp @@ -187,11 +187,9 @@ StatusWith<bool> SaslSCRAMSHA1ClientConversation::_secondStep(const std::vector< _saslClientSession->getParameter(SaslClientSession::parameterServiceHostAndPort)); if (targetHost.isOK()) { - auto cachedSecrets = _clientCache->getCachedSecrets(targetHost.getValue(), presecrets); + _credentials = _clientCache->getCachedSecrets(targetHost.getValue(), presecrets); - if (cachedSecrets) { - _credentials = *cachedSecrets; - } else { + if (!_credentials) { _credentials = scram::generateSecrets(presecrets); _clientCache->setCachedSecrets( diff --git a/src/mongo/client/scram_sha1_client_cache.cpp b/src/mongo/client/scram_sha1_client_cache.cpp index a087f27ca0a..ca3bf9eca14 100644 --- a/src/mongo/client/scram_sha1_client_cache.cpp +++ b/src/mongo/client/scram_sha1_client_cache.cpp @@ -32,7 +32,7 @@ namespace mongo { -boost::optional<scram::SCRAMSecrets> SCRAMSHA1ClientCache::getCachedSecrets( +scram::SCRAMSecrets SCRAMSHA1ClientCache::getCachedSecrets( const HostAndPort& target, const scram::SCRAMPresecrets& presecrets) const { const stdx::lock_guard<stdx::mutex> lock(_hostToSecretsMutex); diff --git a/src/mongo/client/scram_sha1_client_cache.h b/src/mongo/client/scram_sha1_client_cache.h index b3fba8734e2..ee595094ec5 100644 --- a/src/mongo/client/scram_sha1_client_cache.h +++ b/src/mongo/client/scram_sha1_client_cache.h @@ -68,8 +68,8 @@ public: * match those recorded for the hostname. Otherwise, no secrets * are returned. */ - boost::optional<scram::SCRAMSecrets> getCachedSecrets( - const HostAndPort& target, const scram::SCRAMPresecrets& presecrets) const; + scram::SCRAMSecrets getCachedSecrets(const HostAndPort& target, + const scram::SCRAMPresecrets& presecrets) const; /** * Records a set of precomputed SCRAMSecrets for the specified diff --git a/src/mongo/crypto/mechanism_scram.cpp b/src/mongo/crypto/mechanism_scram.cpp index b4a543c6a54..c18bc39cc51 100644 --- a/src/mongo/crypto/mechanism_scram.cpp +++ b/src/mongo/crypto/mechanism_scram.cpp @@ -119,27 +119,27 @@ SCRAMSecrets generateSecrets(const SCRAMPresecrets& presecrets) { } SCRAMSecrets generateSecrets(const SHA1Block& saltedPassword) { - SCRAMSecrets credentials; - - // ClientKey := HMAC(saltedPassword, "Client Key") - credentials.clientKey = - SHA1Block::computeHmac(saltedPassword.data(), - saltedPassword.size(), - reinterpret_cast<const unsigned char*>(clientKeyConst.data()), - clientKeyConst.size()); - - // StoredKey := H(clientKey) - credentials.storedKey = - SHA1Block::computeHash(credentials.clientKey->data(), credentials.clientKey->size()); - - // ServerKey := HMAC(SaltedPassword, "Server Key") - credentials.serverKey = - SHA1Block::computeHmac(saltedPassword.data(), - saltedPassword.size(), - reinterpret_cast<const unsigned char*>(serverKeyConst.data()), - serverKeyConst.size()); - - return credentials; + auto generateAndStoreSecrets = [&saltedPassword]( + SHA1Block& clientKey, SHA1Block& storedKey, SHA1Block& serverKey) { + + // ClientKey := HMAC(saltedPassword, "Client Key") + clientKey = + SHA1Block::computeHmac(saltedPassword.data(), + saltedPassword.size(), + reinterpret_cast<const unsigned char*>(clientKeyConst.data()), + clientKeyConst.size()); + + // StoredKey := H(clientKey) + storedKey = SHA1Block::computeHash(clientKey.data(), clientKey.size()); + + // ServerKey := HMAC(SaltedPassword, "Server Key") + serverKey = + SHA1Block::computeHmac(saltedPassword.data(), + saltedPassword.size(), + reinterpret_cast<const unsigned char*>(serverKeyConst.data()), + serverKeyConst.size()); + }; + return SCRAMSecrets(std::move(generateAndStoreSecrets)); } @@ -164,8 +164,8 @@ BSONObj generateCredentials(const std::string& hashedPassword, int iterationCoun saltLenQWords * sizeof(uint64_t)), iterationCount)); - std::string encodedStoredKey = secrets.storedKey->toString(); - std::string encodedServerKey = secrets.serverKey->toString(); + std::string encodedStoredKey = secrets->storedKey.toString(); + std::string encodedServerKey = secrets->serverKey.toString(); return BSON(iterationCountFieldName << iterationCount << saltFieldName << encodedUserSalt << storedKeyFieldName << encodedStoredKey @@ -176,12 +176,12 @@ std::string generateClientProof(const SCRAMSecrets& clientCredentials, const std::string& authMessage) { // ClientSignature := HMAC(StoredKey, AuthMessage) SHA1Block clientSignature = - SHA1Block::computeHmac(clientCredentials.storedKey->data(), - clientCredentials.storedKey->size(), + SHA1Block::computeHmac(clientCredentials->storedKey.data(), + clientCredentials->storedKey.size(), reinterpret_cast<const unsigned char*>(authMessage.c_str()), authMessage.size()); - clientSignature.xorInline(*clientCredentials.clientKey); + clientSignature.xorInline(clientCredentials->clientKey); return clientSignature.toString(); } @@ -190,8 +190,8 @@ bool verifyServerSignature(const SCRAMSecrets& clientCredentials, const std::string& receivedServerSignature) { // ServerSignature := HMAC(ServerKey, AuthMessage) SHA1Block serverSignature = - SHA1Block::computeHmac(clientCredentials.serverKey->data(), - clientCredentials.serverKey->size(), + SHA1Block::computeHmac(clientCredentials->serverKey.data(), + clientCredentials->serverKey.size(), reinterpret_cast<const unsigned char*>(authMessage.c_str()), authMessage.size()); diff --git a/src/mongo/crypto/mechanism_scram.h b/src/mongo/crypto/mechanism_scram.h index fb070e8162c..92b79efffb9 100644 --- a/src/mongo/crypto/mechanism_scram.h +++ b/src/mongo/crypto/mechanism_scram.h @@ -76,12 +76,48 @@ SHA1Block generateSaltedPassword(const SCRAMPresecrets& presecrets); /* * Stores all of the keys, generated from a password, needed for a client or server to perform a - * SCRAM handshake. This structure will secureZeroMemory itself on destruction. + * SCRAM handshake. + * These keys are reference counted, and allocated using the SecureAllocator. + * May be unpopulated. SCRAMSecrets created via the default constructor are unpopulated. + * The behavior is undefined if the accessors are called when unpopulated. */ -struct SCRAMSecrets { - SecureHandle<SHA1Block> clientKey; - SecureHandle<SHA1Block> storedKey; - SecureHandle<SHA1Block> serverKey; +class SCRAMSecrets { +private: + struct SCRAMSecretsHolder { + SHA1Block clientKey; + SHA1Block storedKey; + SHA1Block serverKey; + }; + +public: + // Creates an unpopulated SCRAMSecrets object. + SCRAMSecrets() = default; + + // Creates a populated SCRAMSecrets object. First, allocates secure storage, then provides it + // to a callback, which fills the memory. + template <typename T> + explicit SCRAMSecrets(T initializationFun) + : _ptr(std::make_shared<SecureHandle<SCRAMSecretsHolder>>()) { + initializationFun((*this)->clientKey, (*this)->storedKey, (*this)->serverKey); + } + + // Returns true if the underlying shared_pointer is populated. + explicit operator bool() const { + return static_cast<bool>(_ptr); + } + + const SecureHandle<SCRAMSecretsHolder>& operator*() const { + invariant(_ptr); + return *_ptr; + } + + const SecureHandle<SCRAMSecretsHolder>& operator->() const { + invariant(_ptr); + return *_ptr; + } + +private: + std::shared_ptr<SecureHandle<SCRAMSecretsHolder>> _ptr; }; /* diff --git a/src/mongo/crypto/sha1_block.h b/src/mongo/crypto/sha1_block.h index 280177de309..266f6602e6c 100644 --- a/src/mongo/crypto/sha1_block.h +++ b/src/mongo/crypto/sha1_block.h @@ -63,7 +63,21 @@ public: static SHA1Block computeHmac(const uint8_t* key, size_t keyLen, const uint8_t* input, - size_t inputLen); + size_t inputLen) { + SHA1Block output; + SHA1Block::computeHmac(key, keyLen, input, inputLen, &output); + return output; + } + + /** + * Computes a HMAC SHA-1 keyed hash of 'input' using the key 'key'. Writes the results into + * a pre-allocated SHA1Block. This lets us allocate SHA1Blocks with the SecureAllocator. + */ + static void computeHmac(const uint8_t* key, + size_t keyLen, + const uint8_t* input, + size_t inputLen, + SHA1Block* const output); const uint8_t* data() const { return _hash.data(); diff --git a/src/mongo/crypto/sha1_block_openssl.cpp b/src/mongo/crypto/sha1_block_openssl.cpp index 50d18b9bace..ce0cab8e9e6 100644 --- a/src/mongo/crypto/sha1_block_openssl.cpp +++ b/src/mongo/crypto/sha1_block_openssl.cpp @@ -82,15 +82,15 @@ SHA1Block SHA1Block::computeHash(const uint8_t* input, size_t inputLen) { } /* - * Computes a HMAC SHA-1 keyed hash of 'input' using the key 'key' + * Computes a HMAC SHA-1 keyed hash of 'input' using the key 'key', writes output into 'output'. */ -SHA1Block SHA1Block::computeHmac(const uint8_t* key, - size_t keyLen, - const uint8_t* input, - size_t inputLen) { - HashType output; - fassert(40380, HMAC(EVP_sha1(), key, keyLen, input, inputLen, output.data(), NULL) != NULL); - return SHA1Block(output); +void SHA1Block::computeHmac(const uint8_t* key, + size_t keyLen, + const uint8_t* input, + size_t inputLen, + SHA1Block* const output) { + fassert(40380, + HMAC(EVP_sha1(), key, keyLen, input, inputLen, output->_hash.data(), NULL) != NULL); } } // namespace mongo diff --git a/src/mongo/crypto/sha1_block_test.cpp b/src/mongo/crypto/sha1_block_test.cpp index 111132ff388..651b2e52bfe 100644 --- a/src/mongo/crypto/sha1_block_test.cpp +++ b/src/mongo/crypto/sha1_block_test.cpp @@ -40,13 +40,49 @@ namespace { const struct { const char* msg; SHA1Block hash; -} sha1Tests[] = { - {"abc", SHA1Block::HashType{0xa9, 0x99, 0x3e, 0x36, 0x47, 0x06, 0x81, 0x6a, 0xba, 0x3e, - 0x25, 0x71, 0x78, 0x50, 0xc2, 0x6c, 0x9c, 0xd0, 0xd8, 0x9d}}, +} sha1Tests[] = {{"abc", + SHA1Block::HashType{0xa9, + 0x99, + 0x3e, + 0x36, + 0x47, + 0x06, + 0x81, + 0x6a, + 0xba, + 0x3e, + 0x25, + 0x71, + 0x78, + 0x50, + 0xc2, + 0x6c, + 0x9c, + 0xd0, + 0xd8, + 0x9d}}, - {"abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq", - SHA1Block::HashType{0x84, 0x98, 0x3E, 0x44, 0x1C, 0x3B, 0xD2, 0x6E, 0xBA, 0xAE, - 0x4A, 0xA1, 0xF9, 0x51, 0x29, 0xE5, 0xE5, 0x46, 0x70, 0xF1}}}; + {"abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq", + SHA1Block::HashType{0x84, + 0x98, + 0x3E, + 0x44, + 0x1C, + 0x3B, + 0xD2, + 0x6E, + 0xBA, + 0xAE, + 0x4A, + 0xA1, + 0xF9, + 0x51, + 0x29, + 0xE5, + 0xE5, + 0x46, + 0x70, + 0xF1}}}; TEST(CryptoVectors, SHA1) { size_t numTests = sizeof(sha1Tests) / sizeof(sha1Tests[0]); @@ -91,8 +127,26 @@ const struct { 20, {0x48, 0x69, 0x20, 0x54, 0x68, 0x65, 0x72, 0x65}, 8, - SHA1Block::HashType{0xb6, 0x17, 0x31, 0x86, 0x55, 0x05, 0x72, 0x64, 0xe2, 0x8b, - 0xc0, 0xb6, 0xfb, 0x37, 0x8c, 0x8e, 0xf1, 0x46, 0xbe, 0x00}}, + SHA1Block::HashType{0xb6, + 0x17, + 0x31, + 0x86, + 0x55, + 0x05, + 0x72, + 0x64, + 0xe2, + 0x8b, + 0xc0, + 0xb6, + 0xfb, + 0x37, + 0x8c, + 0x8e, + 0xf1, + 0x46, + 0xbe, + 0x00}}, // RFC test case 3 {{0xaa, @@ -167,8 +221,26 @@ const struct { 0xdd, 0xdd}, 50, - SHA1Block::HashType{0x12, 0x5d, 0x73, 0x42, 0xb9, 0xac, 0x11, 0xcd, 0x91, 0xa3, - 0x9a, 0xf4, 0x8a, 0xa1, 0x7b, 0x4f, 0x63, 0xf1, 0x75, 0xd3}}, + SHA1Block::HashType{0x12, + 0x5d, + 0x73, + 0x42, + 0xb9, + 0xac, + 0x11, + 0xcd, + 0x91, + 0xa3, + 0x9a, + 0xf4, + 0x8a, + 0xa1, + 0x7b, + 0x4f, + 0x63, + 0xf1, + 0x75, + 0xd3}}, // RFC test case 4 {{0x01, @@ -248,8 +320,26 @@ const struct { 0xcd, 0xcd}, 50, - SHA1Block::HashType{0x4c, 0x90, 0x07, 0xf4, 0x02, 0x62, 0x50, 0xc6, 0xbc, 0x84, - 0x14, 0xf9, 0xbf, 0x50, 0xc8, 0x6c, 0x2d, 0x72, 0x35, 0xda}}, + SHA1Block::HashType{0x4c, + 0x90, + 0x07, + 0xf4, + 0x02, + 0x62, + 0x50, + 0xc6, + 0xbc, + 0x84, + 0x14, + 0xf9, + 0xbf, + 0x50, + 0xc8, + 0x6c, + 0x2d, + 0x72, + 0x35, + 0xda}}, // RFC test case 6 {{0xaa, @@ -388,8 +478,26 @@ const struct { 0x73, 0x74}, 54, - SHA1Block::HashType{0xaa, 0x4a, 0xe5, 0xe1, 0x52, 0x72, 0xd0, 0x0e, 0x95, 0x70, - 0x56, 0x37, 0xce, 0x8a, 0x3b, 0x55, 0xed, 0x40, 0x21, 0x12}}}; + SHA1Block::HashType{0xaa, + 0x4a, + 0xe5, + 0xe1, + 0x52, + 0x72, + 0xd0, + 0x0e, + 0x95, + 0x70, + 0x56, + 0x37, + 0xce, + 0x8a, + 0x3b, + 0x55, + 0xed, + 0x40, + 0x21, + 0x12}}}; TEST(CryptoVectors, HMACSHA1) { size_t numTests = sizeof(hmacSha1Tests) / sizeof(hmacSha1Tests[0]); diff --git a/src/mongo/crypto/sha1_block_tom.cpp b/src/mongo/crypto/sha1_block_tom.cpp index d1078e60458..7d32dbc2f53 100644 --- a/src/mongo/crypto/sha1_block_tom.cpp +++ b/src/mongo/crypto/sha1_block_tom.cpp @@ -58,12 +58,12 @@ SHA1Block SHA1Block::computeHash(const uint8_t* input, size_t inputLen) { /* * Computes a HMAC SHA-1 keyed hash of 'input' using the key 'key' */ -SHA1Block SHA1Block::computeHmac(const uint8_t* key, - size_t keyLen, - const uint8_t* input, - size_t inputLen) { +void SHA1Block::computeHmac(const uint8_t* key, + size_t keyLen, + const uint8_t* input, + size_t inputLen, + SHA1Block* const output) { invariant(key && input); - HashType output; static int hashId = -1; if (hashId == -1) { @@ -73,9 +73,8 @@ SHA1Block SHA1Block::computeHmac(const uint8_t* key, unsigned long sha1HashLen = 20; fassert(40382, - hmac_memory(hashId, key, keyLen, input, inputLen, output.data(), &sha1HashLen) == + hmac_memory(hashId, key, keyLen, input, inputLen, output->_hash.data(), &sha1HashLen) == CRYPT_OK); - return SHA1Block(output); } } // namespace mongo diff --git a/src/mongo/db/auth/SConscript b/src/mongo/db/auth/SConscript index c16caede0ae..1201d5e8d0e 100644 --- a/src/mongo/db/auth/SConscript +++ b/src/mongo/db/auth/SConscript @@ -188,3 +188,10 @@ env.Library( 'authcore', ] ) + +env.CppUnitTest('sasl_scramsha1_test', + 'sasl_scramsha1_test.cpp', + LIBDEPS=[ + 'saslauth', + '$BUILD_DIR/mongo/client/sasl_client', + ]) diff --git a/src/mongo/db/auth/sasl_plain_server_conversation.cpp b/src/mongo/db/auth/sasl_plain_server_conversation.cpp index d68a1ab92ad..d1e993db4a1 100644 --- a/src/mongo/db/auth/sasl_plain_server_conversation.cpp +++ b/src/mongo/db/auth/sasl_plain_server_conversation.cpp @@ -88,8 +88,8 @@ StatusWith<bool> SaslPLAINServerConversation::step(StringData inputData, std::st 16), creds.scram.iterationCount)); if (creds.scram.storedKey != - base64::encode(reinterpret_cast<const char*>(secrets.storedKey->data()), - secrets.storedKey->size())) { + base64::encode(reinterpret_cast<const char*>(secrets->storedKey.data()), + secrets->storedKey.size())) { return StatusWith<bool>(ErrorCodes::AuthenticationFailed, mongoutils::str::stream() << "Incorrect user name or password"); } diff --git a/src/mongo/db/auth/sasl_scramsha1_test.cpp b/src/mongo/db/auth/sasl_scramsha1_test.cpp new file mode 100644 index 00000000000..66cddabe02e --- /dev/null +++ b/src/mongo/db/auth/sasl_scramsha1_test.cpp @@ -0,0 +1,501 @@ +/* + * Copyright (C) 2016 MongoDB Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, version 3, + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the GNU Affero General Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/platform/basic.h" + +#include "mongo/client/native_sasl_client_session.h" +#include "mongo/client/scram_sha1_client_cache.h" +#include "mongo/crypto/mechanism_scram.h" +#include "mongo/db/auth/authorization_manager.h" +#include "mongo/db/auth/authz_manager_external_state_mock.h" +#include "mongo/db/auth/authz_session_external_state_mock.h" +#include "mongo/db/auth/native_sasl_authentication_session.h" +#include "mongo/db/auth/sasl_scramsha1_server_conversation.h" +#include "mongo/db/service_context_noop.h" +#include "mongo/stdx/memory.h" +#include "mongo/unittest/unittest.h" +#include "mongo/util/base64.h" +#include "mongo/util/password_digest.h" + +namespace mongo { + +BSONObj generateSCRAMUserDocument(StringData username, StringData password) { + const size_t scramIterationCount = 10000; + std::string database = "test"; + + std::string digested = createPasswordDigest(username, password); + BSONObj scramCred = scram::generateCredentials(digested, scramIterationCount); + return BSON("_id" << (str::stream() << database << "." << username).operator std::string() + << AuthorizationManager::USER_NAME_FIELD_NAME << username + << AuthorizationManager::USER_DB_FIELD_NAME << database << "credentials" + << BSON("SCRAM-SHA-1" << scramCred) << "roles" << BSONArray() << "privileges" + << BSONArray()); +} + +BSONObj generateMONGODBCRUserDocument(StringData username, StringData password) { + std::string database = "test"; + + std::string digested = createPasswordDigest(username, password); + return BSON("_id" << (str::stream() << database << "." << username).operator std::string() + << AuthorizationManager::USER_NAME_FIELD_NAME << username + << AuthorizationManager::USER_DB_FIELD_NAME << database << "credentials" + << BSON("MONGODB-CR" << digested) << "roles" << BSONArray() << "privileges" + << BSONArray()); +} + +std::string corruptEncodedPayload(const std::string& message, + std::string::const_iterator begin, + std::string::const_iterator end) { + std::string raw = base64::decode( + message.substr(std::distance(message.begin(), begin), std::distance(begin, end))); + if (raw[0] == std::numeric_limits<char>::max()) { + raw[0] -= 1; + } else { + raw[0] += 1; + } + return base64::encode(raw); +} + +class SaslTestState { +public: + enum Participant { kClient, kServer }; + SaslTestState() : SaslTestState(kClient, 0) {} + SaslTestState(Participant participant, size_t stage) : participant(participant), stage(stage) {} + +private: + // Define members here, so that they can be used in declaration of lens(). In C++14, lens() + // can be declared with a return of decltype(auto), without a trailing return type, and these + // members can go at the end of the class. + Participant participant; + size_t stage; + +public: + std::tuple<size_t, Participant> lens() const { + return std::tie(stage, participant); + } + + friend bool operator==(const SaslTestState& lhs, const SaslTestState& rhs) { + return lhs.lens() == rhs.lens(); + } + + friend bool operator<(const SaslTestState& lhs, const SaslTestState& rhs) { + return lhs.lens() < rhs.lens(); + } + + void next() { + if (participant == kClient) { + participant = kServer; + } else { + participant = kClient; + stage++; + } + } + + std::string toString() const { + std::stringstream ss; + if (participant == kClient) { + ss << "Client"; + } else { + ss << "Server"; + } + ss << "Step" << stage; + + return ss.str(); + } +}; + +class SCRAMMutators { +public: + SCRAMMutators() {} + + void setMutator(SaslTestState state, stdx::function<void(std::string&)> fun) { + mutators.insert(std::make_pair(state, fun)); + } + + void execute(SaslTestState state, std::string& str) { + auto it = mutators.find(state); + if (it != mutators.end()) { + it->second(str); + } + } + +private: + std::map<SaslTestState, stdx::function<void(std::string&)>> mutators; +}; + +struct SCRAMStepsResult { + SCRAMStepsResult() : outcome(SaslTestState::kClient, 1), status(Status::OK()) {} + SCRAMStepsResult(SaslTestState outcome, Status status) : outcome(outcome), status(status) {} + bool operator==(const SCRAMStepsResult& other) const { + return outcome == other.outcome && status.code() == other.status.code() && + status.reason() == other.status.reason(); + } + SaslTestState outcome; + Status status; + + friend std::ostream& operator<<(std::ostream& os, const SCRAMStepsResult& result) { + return os << "{outcome: " << result.outcome.toString() << ", status: " << result.status + << "}"; + } +}; + +SCRAMStepsResult runSteps(NativeSaslAuthenticationSession* saslServerSession, + NativeSaslClientSession* saslClientSession, + SCRAMMutators interposers = SCRAMMutators{}) { + SCRAMStepsResult result{}; + std::string clientOutput = ""; + std::string serverOutput = ""; + + for (size_t step = 1; step <= 3; step++) { + ASSERT_FALSE(saslClientSession->isDone()); + ASSERT_FALSE(saslServerSession->isDone()); + + // Client step + result.status = saslClientSession->step(serverOutput, &clientOutput); + if (result.status != Status::OK()) { + return result; + } + std::cout << result.outcome.toString() << ": " << clientOutput << std::endl; + interposers.execute(result.outcome, clientOutput); + result.outcome.next(); + + // Server step + result.status = saslServerSession->step(clientOutput, &serverOutput); + if (result.status != Status::OK()) { + return result; + } + interposers.execute(result.outcome, serverOutput); + std::cout << result.outcome.toString() << ": " << serverOutput << std::endl; + result.outcome.next(); + } + ASSERT_TRUE(saslClientSession->isDone()); + ASSERT_TRUE(saslServerSession->isDone()); + + return result; +} + +class SCRAMSHA1Fixture : public mongo::unittest::Test { +protected: + const SCRAMStepsResult goalState = + SCRAMStepsResult(SaslTestState(SaslTestState::kClient, 4), Status::OK()); + + ServiceContextNoop serviceContext; + ServiceContextNoop::UniqueClient client; + ServiceContextNoop::UniqueOperationContext txn; + + AuthzManagerExternalStateMock* authzManagerExternalState; + std::unique_ptr<AuthorizationManager> authzManager; + std::unique_ptr<AuthorizationSession> authzSession; + + std::unique_ptr<NativeSaslAuthenticationSession> saslServerSession; + std::unique_ptr<NativeSaslClientSession> saslClientSession; + + void setUp() { + client = serviceContext.makeClient("test"); + txn = serviceContext.makeOperationContext(client.get()); + + auto uniqueAuthzManagerExternalStateMock = + stdx::make_unique<AuthzManagerExternalStateMock>(); + authzManagerExternalState = uniqueAuthzManagerExternalStateMock.get(); + authzManager = + stdx::make_unique<AuthorizationManager>(std::move(uniqueAuthzManagerExternalStateMock)); + authzSession = stdx::make_unique<AuthorizationSession>( + stdx::make_unique<AuthzSessionExternalStateMock>(authzManager.get())); + + saslServerSession = stdx::make_unique<NativeSaslAuthenticationSession>(authzSession.get()); + saslServerSession->setOpCtxt(txn.get()); + saslServerSession->start("test", "SCRAM-SHA-1", "mongodb", "MockServer.test", 1, false); + saslClientSession = stdx::make_unique<NativeSaslClientSession>(); + saslClientSession->setParameter(NativeSaslClientSession::parameterMechanism, "SCRAM-SHA-1"); + saslClientSession->setParameter(NativeSaslClientSession::parameterServiceName, "mongodb"); + saslClientSession->setParameter(NativeSaslClientSession::parameterServiceHostname, + "MockServer.test"); + saslClientSession->setParameter(NativeSaslClientSession::parameterServiceHostAndPort, + "MockServer.test:27017"); + } +}; + +/*TEST_F(SCRAMSHA1Fixture, testServerStep1DoesNotIncludeNonceFromClientStep1) { + authzManagerExternalState->insertPrivilegeDocument( + txn.get(), generateSCRAMUserDocument("sajack", "sajack"), BSONObj()); + + saslClientSession->setParameter(NativeSaslClientSession::parameterUser, "sajack"); + saslClientSession->setParameter(NativeSaslClientSession::parameterPassword, + createPasswordDigest("sajack", "sajack")); + + ASSERT_OK(saslClientSession->initialize()); + + SCRAMMutators mutator; + mutator.setMutator(SaslTestState(SaslTestState::kServer, 1), [](std::string& serverMessage) { + std::string::iterator nonceBegin = serverMessage.begin() + serverMessage.find("r="); + std::string::iterator nonceEnd = std::find(nonceBegin, serverMessage.end(), ','); + serverMessage = serverMessage.replace(nonceBegin, nonceEnd, "r="); + + }); + ASSERT_EQ(SCRAMStepsResult(SaslTestState(SaslTestState::kClient, 2), + Status(ErrorCodes::BadValue, + "Server SCRAM-SHA-1 nonce does not match client nonce: r=")), + runSteps(saslServerSession.get(), saslClientSession.get(), mutator)); +}*/ + +TEST_F(SCRAMSHA1Fixture, testClientStep2DoesNotIncludeNonceFromServerStep1) { + authzManagerExternalState->insertPrivilegeDocument( + txn.get(), generateSCRAMUserDocument("sajack", "sajack"), BSONObj()); + + saslClientSession->setParameter(NativeSaslClientSession::parameterUser, "sajack"); + saslClientSession->setParameter(NativeSaslClientSession::parameterPassword, + createPasswordDigest("sajack", "sajack")); + + ASSERT_OK(saslClientSession->initialize()); + + SCRAMMutators mutator; + mutator.setMutator(SaslTestState(SaslTestState::kClient, 2), + [](std::string& clientMessage) { + std::string::iterator nonceBegin = + clientMessage.begin() + clientMessage.find("r="); + std::string::iterator nonceEnd = + std::find(nonceBegin, clientMessage.end(), ','); + clientMessage = clientMessage.replace(nonceBegin, nonceEnd, "r="); + }); + ASSERT_EQ(SCRAMStepsResult( + SaslTestState(SaslTestState::kServer, 2), + Status(ErrorCodes::BadValue, "Incorrect SCRAM-SHA-1 client|server nonce: r=")), + runSteps(saslServerSession.get(), saslClientSession.get(), mutator)); +} + +TEST_F(SCRAMSHA1Fixture, testClientStep2GivesBadProof) { + authzManagerExternalState->insertPrivilegeDocument( + txn.get(), generateSCRAMUserDocument("sajack", "sajack"), BSONObj()); + + saslClientSession->setParameter(NativeSaslClientSession::parameterUser, "sajack"); + saslClientSession->setParameter(NativeSaslClientSession::parameterPassword, + createPasswordDigest("sajack", "sajack")); + + ASSERT_OK(saslClientSession->initialize()); + + SCRAMMutators mutator; + mutator.setMutator( + SaslTestState(SaslTestState::kClient, 2), + [](std::string& clientMessage) { + std::string::iterator proofBegin = clientMessage.begin() + clientMessage.find("p=") + 2; + std::string::iterator proofEnd = std::find(proofBegin, clientMessage.end(), ','); + clientMessage = clientMessage.replace( + proofBegin, proofEnd, corruptEncodedPayload(clientMessage, proofBegin, proofEnd)); + + }); + + ASSERT_EQ(SCRAMStepsResult(SaslTestState(SaslTestState::kServer, 2), + Status(ErrorCodes::AuthenticationFailed, + "SCRAM-SHA-1 authentication failed, storedKey mismatch")), + runSteps(saslServerSession.get(), saslClientSession.get(), mutator)); +} + +TEST_F(SCRAMSHA1Fixture, testServerStep2GivesBadVerifier) { + authzManagerExternalState->insertPrivilegeDocument( + txn.get(), generateSCRAMUserDocument("sajack", "sajack"), BSONObj()); + + saslClientSession->setParameter(NativeSaslClientSession::parameterUser, "sajack"); + saslClientSession->setParameter(NativeSaslClientSession::parameterPassword, + createPasswordDigest("sajack", "sajack")); + + ASSERT_OK(saslClientSession->initialize()); + + std::string encodedVerifier; + SCRAMMutators mutator; + mutator.setMutator( + SaslTestState(SaslTestState::kServer, 2), + [&encodedVerifier](std::string& serverMessage) { + std::string::iterator verifierBegin = + serverMessage.begin() + serverMessage.find("v=") + 2; + std::string::iterator verifierEnd = std::find(verifierBegin, serverMessage.end(), ','); + encodedVerifier = corruptEncodedPayload(serverMessage, verifierBegin, verifierEnd); + + serverMessage = serverMessage.replace(verifierBegin, verifierEnd, encodedVerifier); + + }); + + auto result = runSteps(saslServerSession.get(), saslClientSession.get(), mutator); + + ASSERT_EQ( + SCRAMStepsResult( + SaslTestState(SaslTestState::kClient, 3), + Status(ErrorCodes::BadValue, + str::stream() << "Client failed to verify SCRAM-SHA-1 ServerSignature, received " + << encodedVerifier)), + result); +} + + +TEST_F(SCRAMSHA1Fixture, testSCRAM) { + authzManagerExternalState->insertPrivilegeDocument( + txn.get(), generateSCRAMUserDocument("sajack", "sajack"), BSONObj()); + + saslClientSession->setParameter(NativeSaslClientSession::parameterUser, "sajack"); + saslClientSession->setParameter(NativeSaslClientSession::parameterPassword, + createPasswordDigest("sajack", "sajack")); + + ASSERT_OK(saslClientSession->initialize()); + + ASSERT_EQ(goalState, runSteps(saslServerSession.get(), saslClientSession.get())); +} + +TEST_F(SCRAMSHA1Fixture, testNULLInPassword) { + authzManagerExternalState->insertPrivilegeDocument( + txn.get(), generateSCRAMUserDocument("sajack", "saj\0ack"), BSONObj()); + + saslClientSession->setParameter(NativeSaslClientSession::parameterUser, "sajack"); + saslClientSession->setParameter(NativeSaslClientSession::parameterPassword, + createPasswordDigest("sajack", "saj\0ack")); + + ASSERT_OK(saslClientSession->initialize()); + + ASSERT_EQ(goalState, runSteps(saslServerSession.get(), saslClientSession.get())); +} + + +TEST_F(SCRAMSHA1Fixture, testCommasInUsernameAndPassword) { + authzManagerExternalState->insertPrivilegeDocument( + txn.get(), generateSCRAMUserDocument("s,a,jack", "s,a,jack"), BSONObj()); + + saslClientSession->setParameter(NativeSaslClientSession::parameterUser, "s,a,jack"); + saslClientSession->setParameter(NativeSaslClientSession::parameterPassword, + createPasswordDigest("s,a,jack", "s,a,jack")); + + ASSERT_OK(saslClientSession->initialize()); + + ASSERT_EQ(goalState, runSteps(saslServerSession.get(), saslClientSession.get())); +} + +TEST_F(SCRAMSHA1Fixture, testIncorrectUser) { + saslClientSession->setParameter(NativeSaslClientSession::parameterUser, "sajack"); + saslClientSession->setParameter(NativeSaslClientSession::parameterPassword, + createPasswordDigest("sajack", "sajack")); + + ASSERT_OK(saslClientSession->initialize()); + + ASSERT_EQ(SCRAMStepsResult(SaslTestState(SaslTestState::kServer, 1), + Status(ErrorCodes::UserNotFound, "Could not find user sajack@test")), + runSteps(saslServerSession.get(), saslClientSession.get())); +} + +TEST_F(SCRAMSHA1Fixture, testIncorrectPassword) { + authzManagerExternalState->insertPrivilegeDocument( + txn.get(), generateSCRAMUserDocument("sajack", "sajack"), BSONObj()); + + saslClientSession->setParameter(NativeSaslClientSession::parameterUser, "sajack"); + saslClientSession->setParameter(NativeSaslClientSession::parameterPassword, + createPasswordDigest("sajack", "invalidPassword")); + + ASSERT_OK(saslClientSession->initialize()); + + ASSERT_EQ(SCRAMStepsResult(SaslTestState(SaslTestState::kServer, 2), + Status(ErrorCodes::AuthenticationFailed, + "SCRAM-SHA-1 authentication failed, storedKey mismatch")), + runSteps(saslServerSession.get(), saslClientSession.get())); +} + +TEST_F(SCRAMSHA1Fixture, testMONGODBCR) { + authzManagerExternalState->insertPrivilegeDocument( + txn.get(), generateMONGODBCRUserDocument("sajack", "sajack"), BSONObj()); + + saslClientSession->setParameter(NativeSaslClientSession::parameterUser, "sajack"); + saslClientSession->setParameter(NativeSaslClientSession::parameterPassword, + createPasswordDigest("sajack", "sajack")); + + ASSERT_OK(saslClientSession->initialize()); + + ASSERT_EQ(goalState, runSteps(saslServerSession.get(), saslClientSession.get())); +} + +TEST(SCRAMSHA1Cache, testGetFromEmptyCache) { + SCRAMSHA1ClientCache cache; + std::string saltStr("saltsaltsaltsalt"); + std::vector<std::uint8_t> salt(saltStr.begin(), saltStr.end()); + HostAndPort host("localhost:27017"); + + ASSERT_FALSE(cache.getCachedSecrets(host, scram::SCRAMPresecrets("aaa", salt, 10000))); +} + + +TEST(SCRAMSHA1Cache, testSetAndGet) { + SCRAMSHA1ClientCache cache; + std::string saltStr("saltsaltsaltsalt"); + std::string badSaltStr("s@lts@lts@lts@lt"); + std::vector<std::uint8_t> salt(saltStr.begin(), saltStr.end()); + std::vector<std::uint8_t> badSalt(badSaltStr.begin(), badSaltStr.end()); + HostAndPort host("localhost:27017"); + + auto secret = scram::generateSecrets(scram::SCRAMPresecrets("aaa", salt, 10000)); + cache.setCachedSecrets(host, scram::SCRAMPresecrets("aaa", salt, 10000), secret); + auto cachedSecret = cache.getCachedSecrets(host, scram::SCRAMPresecrets("aaa", salt, 10000)); + ASSERT_TRUE(cachedSecret); + ASSERT_TRUE(secret->clientKey == cachedSecret->clientKey); + ASSERT_TRUE(secret->serverKey == cachedSecret->serverKey); + ASSERT_TRUE(secret->storedKey == cachedSecret->storedKey); +} + + +TEST(SCRAMSHA1Cache, testSetAndGetWithDifferentParameters) { + SCRAMSHA1ClientCache cache; + std::string saltStr("saltsaltsaltsalt"); + std::string badSaltStr("s@lts@lts@lts@lt"); + std::vector<std::uint8_t> salt(saltStr.begin(), saltStr.end()); + std::vector<std::uint8_t> badSalt(badSaltStr.begin(), badSaltStr.end()); + HostAndPort host("localhost:27017"); + + auto secret = scram::generateSecrets(scram::SCRAMPresecrets("aaa", salt, 10000)); + cache.setCachedSecrets(host, scram::SCRAMPresecrets("aaa", salt, 10000), secret); + + ASSERT_FALSE(cache.getCachedSecrets(HostAndPort("localhost:27018"), + scram::SCRAMPresecrets("aaa", salt, 10000))); + ASSERT_FALSE(cache.getCachedSecrets(host, scram::SCRAMPresecrets("aab", salt, 10000))); + ASSERT_FALSE(cache.getCachedSecrets(host, scram::SCRAMPresecrets("aaa", badSalt, 10000))); + ASSERT_FALSE(cache.getCachedSecrets(host, scram::SCRAMPresecrets("aaa", salt, 10001))); +} + + +TEST(SCRAMSHA1Cache, testSetAndReset) { + SCRAMSHA1ClientCache cache; + StringData saltStr("saltsaltsaltsalt"); + std::vector<std::uint8_t> salt(saltStr.begin(), saltStr.end()); + HostAndPort host("localhost:27017"); + + auto secret = scram::generateSecrets(scram::SCRAMPresecrets("aaa", salt, 10000)); + cache.setCachedSecrets(host, scram::SCRAMPresecrets("aaa", salt, 10000), secret); + auto newSecret = scram::generateSecrets(scram::SCRAMPresecrets("aab", salt, 10000)); + cache.setCachedSecrets(host, scram::SCRAMPresecrets("aab", salt, 10000), newSecret); + + ASSERT_FALSE(cache.getCachedSecrets(host, scram::SCRAMPresecrets("aaa", salt, 10000))); + auto cachedSecret = cache.getCachedSecrets(host, scram::SCRAMPresecrets("aab", salt, 10000)); + ASSERT_TRUE(cachedSecret); + ASSERT_TRUE(newSecret->clientKey == cachedSecret->clientKey); + ASSERT_TRUE(newSecret->serverKey == cachedSecret->serverKey); + ASSERT_TRUE(newSecret->storedKey == cachedSecret->storedKey); +} + +} // namespace mongo |