summaryrefslogtreecommitdiff
path: root/src/mongo/crypto/mechanism_scram.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/mongo/crypto/mechanism_scram.cpp')
-rw-r--r--src/mongo/crypto/mechanism_scram.cpp120
1 files changed, 54 insertions, 66 deletions
diff --git a/src/mongo/crypto/mechanism_scram.cpp b/src/mongo/crypto/mechanism_scram.cpp
index 57a2eed3dd0..a1b3c69d8fc 100644
--- a/src/mongo/crypto/mechanism_scram.cpp
+++ b/src/mongo/crypto/mechanism_scram.cpp
@@ -32,7 +32,6 @@
#include <vector>
-#include "mongo/crypto/crypto.h"
#include "mongo/platform/random.h"
#include "mongo/util/base64.h"
#include "mongo/util/secure_zero_memory.h"
@@ -67,18 +66,14 @@ bool consttimeMemEqual(volatile const unsigned char* s1, // NOLINT - using vola
}
} // namespace
-std::string hashToBase64(const SecureHandle<SHA1Hash>& hash) {
- return base64::encode(reinterpret_cast<const char*>(hash->data()), hash->size());
-}
-
// Compute the SCRAM step Hi() as defined in RFC5802
-static SHA1Hash HMACIteration(const unsigned char input[],
- size_t inputLen,
- const unsigned char salt[],
- size_t saltLen,
- unsigned int iterationCount) {
- SHA1Hash output;
- SHA1Hash intermediateDigest;
+static SHA1Block HMACIteration(const unsigned char input[],
+ size_t inputLen,
+ const unsigned char salt[],
+ size_t saltLen,
+ unsigned int iterationCount) {
+ SHA1Block output;
+ SHA1Block intermediateDigest;
// Reserve a 20 byte block for the initial key. We use 16 byte salts, and must reserve an extra
// 4 bytes for a suffix mandated by RFC5802.
std::array<std::uint8_t, 20> startKey;
@@ -92,25 +87,23 @@ static SHA1Hash HMACIteration(const unsigned char input[],
startKey[saltLen + 3] = 1;
// U1 = HMAC(input, salt + 0001)
- output = crypto::hmacSha1(input, inputLen, startKey.data(), startKey.size());
+ output = SHA1Block::computeHmac(input, inputLen, startKey.data(), startKey.size());
intermediateDigest = output;
// intermediateDigest contains Ui and output contains the accumulated XOR:ed result
for (size_t i = 2; i <= iterationCount; i++) {
- intermediateDigest =
- crypto::hmacSha1(input, inputLen, intermediateDigest.data(), intermediateDigest.size());
- for (size_t k = 0; k < output.size(); k++) {
- output[k] ^= intermediateDigest[k];
- }
+ intermediateDigest = SHA1Block::computeHmac(
+ input, inputLen, intermediateDigest.data(), intermediateDigest.size());
+ output.xorInline(intermediateDigest);
}
return output;
}
// Iterate the hash function to generate SaltedPassword
-SHA1Hash generateSaltedPassword(const SCRAMPresecrets& presecrets) {
+SHA1Block generateSaltedPassword(const SCRAMPresecrets& presecrets) {
// saltedPassword = Hi(hashedPassword, salt)
- SHA1Hash saltedPassword =
+ SHA1Block saltedPassword =
HMACIteration(reinterpret_cast<const unsigned char*>(presecrets.hashedPassword.c_str()),
presecrets.hashedPassword.size(),
presecrets.salt.data(),
@@ -121,30 +114,30 @@ SHA1Hash generateSaltedPassword(const SCRAMPresecrets& presecrets) {
}
SCRAMSecrets generateSecrets(const SCRAMPresecrets& presecrets) {
- SHA1Hash saltedPassword = generateSaltedPassword(presecrets);
+ SHA1Block saltedPassword = generateSaltedPassword(presecrets);
return generateSecrets(saltedPassword);
}
-SCRAMSecrets generateSecrets(const SHA1Hash& saltedPassword) {
+SCRAMSecrets generateSecrets(const SHA1Block& saltedPassword) {
SCRAMSecrets credentials;
// ClientKey := HMAC(saltedPassword, "Client Key")
credentials.clientKey =
- crypto::hmacSha1(saltedPassword.data(),
- saltedPassword.size(),
- reinterpret_cast<const unsigned char*>(clientKeyConst.data()),
- clientKeyConst.size());
+ SHA1Block::computeHmac(saltedPassword.data(),
+ saltedPassword.size(),
+ reinterpret_cast<const unsigned char*>(clientKeyConst.data()),
+ clientKeyConst.size());
// StoredKey := H(clientKey)
credentials.storedKey =
- crypto::sha1(credentials.clientKey->data(), credentials.clientKey->size());
+ SHA1Block::computeHash(credentials.clientKey->data(), credentials.clientKey->size());
// ServerKey := HMAC(SaltedPassword, "Server Key")
credentials.serverKey =
- crypto::hmacSha1(saltedPassword.data(),
- saltedPassword.size(),
- reinterpret_cast<const unsigned char*>(serverKeyConst.data()),
- serverKeyConst.size());
+ SHA1Block::computeHmac(saltedPassword.data(),
+ saltedPassword.size(),
+ reinterpret_cast<const unsigned char*>(serverKeyConst.data()),
+ serverKeyConst.size());
return credentials;
}
@@ -171,8 +164,8 @@ BSONObj generateCredentials(const std::string& hashedPassword, int iterationCoun
saltLenQWords * sizeof(uint64_t)),
iterationCount));
- std::string encodedStoredKey = hashToBase64(secrets.storedKey);
- std::string encodedServerKey = hashToBase64(secrets.serverKey);
+ std::string encodedStoredKey = secrets.storedKey->toString();
+ std::string encodedServerKey = secrets.serverKey->toString();
return BSON(iterationCountFieldName << iterationCount << saltFieldName << encodedUserSalt
<< storedKeyFieldName
@@ -184,32 +177,27 @@ BSONObj generateCredentials(const std::string& hashedPassword, int iterationCoun
std::string generateClientProof(const SCRAMSecrets& clientCredentials,
const std::string& authMessage) {
// ClientSignature := HMAC(StoredKey, AuthMessage)
- SHA1Hash clientSignature =
- crypto::hmacSha1(clientCredentials.storedKey->data(),
- clientCredentials.storedKey->size(),
- reinterpret_cast<const unsigned char*>(authMessage.c_str()),
- authMessage.size());
-
- // ClientProof := ClientKey XOR ClientSignature
- SHA1Hash clientProof;
- for (size_t i = 0; i < clientCredentials.clientKey->size(); i++) {
- clientProof[i] = (*clientCredentials.clientKey)[i] ^ clientSignature[i];
- }
-
- return hashToBase64(clientProof);
+ SHA1Block clientSignature =
+ SHA1Block::computeHmac(clientCredentials.storedKey->data(),
+ clientCredentials.storedKey->size(),
+ reinterpret_cast<const unsigned char*>(authMessage.c_str()),
+ authMessage.size());
+
+ clientSignature.xorInline(*clientCredentials.clientKey);
+ return clientSignature.toString();
}
bool verifyServerSignature(const SCRAMSecrets& clientCredentials,
const std::string& authMessage,
const std::string& receivedServerSignature) {
// ServerSignature := HMAC(ServerKey, AuthMessage)
- SHA1Hash serverSignature =
- crypto::hmacSha1(clientCredentials.serverKey->data(),
- clientCredentials.serverKey->size(),
- reinterpret_cast<const unsigned char*>(authMessage.c_str()),
- authMessage.size());
+ SHA1Block serverSignature =
+ SHA1Block::computeHmac(clientCredentials.serverKey->data(),
+ clientCredentials.serverKey->size(),
+ reinterpret_cast<const unsigned char*>(authMessage.c_str()),
+ authMessage.size());
- std::string encodedServerSignature = hashToBase64(serverSignature);
+ std::string encodedServerSignature = serverSignature.toString();
if (encodedServerSignature.size() != receivedServerSignature.size()) {
return false;
@@ -223,20 +211,20 @@ bool verifyServerSignature(const SCRAMSecrets& clientCredentials,
bool verifyClientProof(StringData clientProof, StringData storedKey, StringData authMessage) {
// ClientSignature := HMAC(StoredKey, AuthMessage)
- SHA1Hash clientSignature =
- crypto::hmacSha1(reinterpret_cast<const unsigned char*>(storedKey.rawData()),
- storedKey.size(),
- reinterpret_cast<const unsigned char*>(authMessage.rawData()),
- authMessage.size());
-
- // ClientKey := ClientSignature XOR ClientProof
- SHA1Hash clientKey;
- for (size_t i = 0; i < clientKey.size(); i++) {
- clientKey[i] = clientSignature[i] ^ clientProof.rawData()[i];
- }
-
- // StoredKey := H(ClientKey)
- SHA1Hash computedStoredKey = crypto::sha1(clientKey.data(), clientKey.size());
+ SHA1Block clientSignature =
+ SHA1Block::computeHmac(reinterpret_cast<const unsigned char*>(storedKey.rawData()),
+ storedKey.size(),
+ reinterpret_cast<const unsigned char*>(authMessage.rawData()),
+ authMessage.size());
+
+ auto clientProofSHA1Status = SHA1Block::fromBuffer(
+ reinterpret_cast<const uint8_t*>(clientProof.rawData()), clientProof.size());
+ uassertStatusOK(clientProofSHA1Status);
+ clientSignature.xorInline(clientProofSHA1Status.getValue());
+
+ // StoredKey := H(clientSignature)
+ SHA1Block computedStoredKey =
+ SHA1Block::computeHash(clientSignature.data(), clientSignature.size());
return consttimeMemEqual(reinterpret_cast<const unsigned char*>(storedKey.rawData()),
computedStoredKey.data(),