diff options
Diffstat (limited to 'src/mongo/crypto/mechanism_scram.cpp')
-rw-r--r-- | src/mongo/crypto/mechanism_scram.cpp | 120 |
1 files changed, 54 insertions, 66 deletions
diff --git a/src/mongo/crypto/mechanism_scram.cpp b/src/mongo/crypto/mechanism_scram.cpp index 57a2eed3dd0..a1b3c69d8fc 100644 --- a/src/mongo/crypto/mechanism_scram.cpp +++ b/src/mongo/crypto/mechanism_scram.cpp @@ -32,7 +32,6 @@ #include <vector> -#include "mongo/crypto/crypto.h" #include "mongo/platform/random.h" #include "mongo/util/base64.h" #include "mongo/util/secure_zero_memory.h" @@ -67,18 +66,14 @@ bool consttimeMemEqual(volatile const unsigned char* s1, // NOLINT - using vola } } // namespace -std::string hashToBase64(const SecureHandle<SHA1Hash>& hash) { - return base64::encode(reinterpret_cast<const char*>(hash->data()), hash->size()); -} - // Compute the SCRAM step Hi() as defined in RFC5802 -static SHA1Hash HMACIteration(const unsigned char input[], - size_t inputLen, - const unsigned char salt[], - size_t saltLen, - unsigned int iterationCount) { - SHA1Hash output; - SHA1Hash intermediateDigest; +static SHA1Block HMACIteration(const unsigned char input[], + size_t inputLen, + const unsigned char salt[], + size_t saltLen, + unsigned int iterationCount) { + SHA1Block output; + SHA1Block intermediateDigest; // Reserve a 20 byte block for the initial key. We use 16 byte salts, and must reserve an extra // 4 bytes for a suffix mandated by RFC5802. std::array<std::uint8_t, 20> startKey; @@ -92,25 +87,23 @@ static SHA1Hash HMACIteration(const unsigned char input[], startKey[saltLen + 3] = 1; // U1 = HMAC(input, salt + 0001) - output = crypto::hmacSha1(input, inputLen, startKey.data(), startKey.size()); + output = SHA1Block::computeHmac(input, inputLen, startKey.data(), startKey.size()); intermediateDigest = output; // intermediateDigest contains Ui and output contains the accumulated XOR:ed result for (size_t i = 2; i <= iterationCount; i++) { - intermediateDigest = - crypto::hmacSha1(input, inputLen, intermediateDigest.data(), intermediateDigest.size()); - for (size_t k = 0; k < output.size(); k++) { - output[k] ^= intermediateDigest[k]; - } + intermediateDigest = SHA1Block::computeHmac( + input, inputLen, intermediateDigest.data(), intermediateDigest.size()); + output.xorInline(intermediateDigest); } return output; } // Iterate the hash function to generate SaltedPassword -SHA1Hash generateSaltedPassword(const SCRAMPresecrets& presecrets) { +SHA1Block generateSaltedPassword(const SCRAMPresecrets& presecrets) { // saltedPassword = Hi(hashedPassword, salt) - SHA1Hash saltedPassword = + SHA1Block saltedPassword = HMACIteration(reinterpret_cast<const unsigned char*>(presecrets.hashedPassword.c_str()), presecrets.hashedPassword.size(), presecrets.salt.data(), @@ -121,30 +114,30 @@ SHA1Hash generateSaltedPassword(const SCRAMPresecrets& presecrets) { } SCRAMSecrets generateSecrets(const SCRAMPresecrets& presecrets) { - SHA1Hash saltedPassword = generateSaltedPassword(presecrets); + SHA1Block saltedPassword = generateSaltedPassword(presecrets); return generateSecrets(saltedPassword); } -SCRAMSecrets generateSecrets(const SHA1Hash& saltedPassword) { +SCRAMSecrets generateSecrets(const SHA1Block& saltedPassword) { SCRAMSecrets credentials; // ClientKey := HMAC(saltedPassword, "Client Key") credentials.clientKey = - crypto::hmacSha1(saltedPassword.data(), - saltedPassword.size(), - reinterpret_cast<const unsigned char*>(clientKeyConst.data()), - clientKeyConst.size()); + SHA1Block::computeHmac(saltedPassword.data(), + saltedPassword.size(), + reinterpret_cast<const unsigned char*>(clientKeyConst.data()), + clientKeyConst.size()); // StoredKey := H(clientKey) credentials.storedKey = - crypto::sha1(credentials.clientKey->data(), credentials.clientKey->size()); + SHA1Block::computeHash(credentials.clientKey->data(), credentials.clientKey->size()); // ServerKey := HMAC(SaltedPassword, "Server Key") credentials.serverKey = - crypto::hmacSha1(saltedPassword.data(), - saltedPassword.size(), - reinterpret_cast<const unsigned char*>(serverKeyConst.data()), - serverKeyConst.size()); + SHA1Block::computeHmac(saltedPassword.data(), + saltedPassword.size(), + reinterpret_cast<const unsigned char*>(serverKeyConst.data()), + serverKeyConst.size()); return credentials; } @@ -171,8 +164,8 @@ BSONObj generateCredentials(const std::string& hashedPassword, int iterationCoun saltLenQWords * sizeof(uint64_t)), iterationCount)); - std::string encodedStoredKey = hashToBase64(secrets.storedKey); - std::string encodedServerKey = hashToBase64(secrets.serverKey); + std::string encodedStoredKey = secrets.storedKey->toString(); + std::string encodedServerKey = secrets.serverKey->toString(); return BSON(iterationCountFieldName << iterationCount << saltFieldName << encodedUserSalt << storedKeyFieldName @@ -184,32 +177,27 @@ BSONObj generateCredentials(const std::string& hashedPassword, int iterationCoun std::string generateClientProof(const SCRAMSecrets& clientCredentials, const std::string& authMessage) { // ClientSignature := HMAC(StoredKey, AuthMessage) - SHA1Hash clientSignature = - crypto::hmacSha1(clientCredentials.storedKey->data(), - clientCredentials.storedKey->size(), - reinterpret_cast<const unsigned char*>(authMessage.c_str()), - authMessage.size()); - - // ClientProof := ClientKey XOR ClientSignature - SHA1Hash clientProof; - for (size_t i = 0; i < clientCredentials.clientKey->size(); i++) { - clientProof[i] = (*clientCredentials.clientKey)[i] ^ clientSignature[i]; - } - - return hashToBase64(clientProof); + SHA1Block clientSignature = + SHA1Block::computeHmac(clientCredentials.storedKey->data(), + clientCredentials.storedKey->size(), + reinterpret_cast<const unsigned char*>(authMessage.c_str()), + authMessage.size()); + + clientSignature.xorInline(*clientCredentials.clientKey); + return clientSignature.toString(); } bool verifyServerSignature(const SCRAMSecrets& clientCredentials, const std::string& authMessage, const std::string& receivedServerSignature) { // ServerSignature := HMAC(ServerKey, AuthMessage) - SHA1Hash serverSignature = - crypto::hmacSha1(clientCredentials.serverKey->data(), - clientCredentials.serverKey->size(), - reinterpret_cast<const unsigned char*>(authMessage.c_str()), - authMessage.size()); + SHA1Block serverSignature = + SHA1Block::computeHmac(clientCredentials.serverKey->data(), + clientCredentials.serverKey->size(), + reinterpret_cast<const unsigned char*>(authMessage.c_str()), + authMessage.size()); - std::string encodedServerSignature = hashToBase64(serverSignature); + std::string encodedServerSignature = serverSignature.toString(); if (encodedServerSignature.size() != receivedServerSignature.size()) { return false; @@ -223,20 +211,20 @@ bool verifyServerSignature(const SCRAMSecrets& clientCredentials, bool verifyClientProof(StringData clientProof, StringData storedKey, StringData authMessage) { // ClientSignature := HMAC(StoredKey, AuthMessage) - SHA1Hash clientSignature = - crypto::hmacSha1(reinterpret_cast<const unsigned char*>(storedKey.rawData()), - storedKey.size(), - reinterpret_cast<const unsigned char*>(authMessage.rawData()), - authMessage.size()); - - // ClientKey := ClientSignature XOR ClientProof - SHA1Hash clientKey; - for (size_t i = 0; i < clientKey.size(); i++) { - clientKey[i] = clientSignature[i] ^ clientProof.rawData()[i]; - } - - // StoredKey := H(ClientKey) - SHA1Hash computedStoredKey = crypto::sha1(clientKey.data(), clientKey.size()); + SHA1Block clientSignature = + SHA1Block::computeHmac(reinterpret_cast<const unsigned char*>(storedKey.rawData()), + storedKey.size(), + reinterpret_cast<const unsigned char*>(authMessage.rawData()), + authMessage.size()); + + auto clientProofSHA1Status = SHA1Block::fromBuffer( + reinterpret_cast<const uint8_t*>(clientProof.rawData()), clientProof.size()); + uassertStatusOK(clientProofSHA1Status); + clientSignature.xorInline(clientProofSHA1Status.getValue()); + + // StoredKey := H(clientSignature) + SHA1Block computedStoredKey = + SHA1Block::computeHash(clientSignature.data(), clientSignature.size()); return consttimeMemEqual(reinterpret_cast<const unsigned char*>(storedKey.rawData()), computedStoredKey.data(), |