/**
* Copyright (C) 2014 10gen Inc.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License, version 3,
* as published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*
* As a special exception, the copyright holders give permission to link the
* code of portions of this program with the OpenSSL library under certain
* conditions as described in each individual source file and distribute
* linked combinations including the program with the OpenSSL library. You
* must comply with the GNU Affero General Public License in all respects for
* all of the code used other than as permitted herein. If you modify file(s)
* with this exception, you may extend this exception to your version of the
* file(s), but you are not obligated to do so. If you do not wish to do so,
* delete this exception statement from your version. If you delete this
* exception statement from all source files in the program, then also delete
* it in the license file.
*/
#pragma once
#include
#include
#include
#include
#include "mongo/base/secure_allocator.h"
#include "mongo/base/status.h"
#include "mongo/crypto/sha1_block.h"
#include "mongo/db/auth/authorization_manager.h"
#include "mongo/db/jsobj.h"
#include "mongo/platform/random.h"
#include "mongo/util/assert_util.h"
namespace mongo {
namespace scram {
// Labels fed into the HMAC computations in Secrets that derive the server and client keys.
constexpr auto kServerKeyConst = "Server Key"_sd;
constexpr auto kClientKeyConst = "Client Key"_sd;

// Field names of the credentials object assembled by Secrets::generateCredentials().
constexpr auto kIterationCountFieldName = "iterationCount"_sd;
constexpr auto kSaltFieldName = "salt"_sd;
constexpr auto kStoredKeyFieldName = "storedKey"_sd;
constexpr auto kServerKeyFieldName = "serverKey"_sd;

// Lowest iteration count a Presecrets instance accepts.
constexpr int kIterationCountMinimum = 4096;
/* The precursors necessary to perform the computation which produces SCRAMSecrets.
* These are the original password, its salt, and the number of times it must be
* hashed to produce the SaltedPassword used to generate the rest of the SCRAMSecrets.
*
* The constructor checks the salt length and the iteration count; see the notes there.
* Instances compare with operator== / operator!=, which look at all three members.
*/
template
class Presecrets {
public:
// Stores the password, salt and iteration count.
// Fails uassert 17450 unless salt.size() == saltLength(), and uassert 50662 unless
// iterationCount >= kIterationCountMinimum.
Presecrets(std::string password, std::vector salt, size_t iterationCount)
: _password(std::move(password)), _salt(std::move(salt)), _iterationCount(iterationCount) {
uassert(17450, "invalid salt length provided", _salt.size() == saltLength());
uassert(50662, "invalid iteration count", _iterationCount >= kIterationCountMinimum);
}
// Computes the salted password: the XOR of _iterationCount chained HMAC results, where the
// first HMAC covers the salt followed by the bytes 0,0,0,1 and every later HMAC covers the
// previous result. The password is the first argument pair of every computeHmac call.
HashBlock generateSaltedPassword() const noexcept {
// saltedPassword = Hi(hashedPassword, salt)
// Reserve a HashBlock::kHashLength block for the initial key.
// We use saltLength() salts, and reserve the extra for a suffix mandated by RFC5802.
std::array startKey;
std::copy(_salt.cbegin(), _salt.cend(), startKey.begin());
// Append the 4-byte suffix 0,0,0,1 directly after the salt bytes.
startKey[_salt.size() + 0] = 0;
startKey[_salt.size() + 1] = 0;
startKey[_salt.size() + 2] = 0;
startKey[_salt.size() + 3] = 1;
// U1 = HMAC(input, salt + 0001)
auto output =
HashBlock::computeHmac(reinterpret_cast(_password.c_str()),
_password.size(),
startKey.data(),
startKey.size());
auto intermediate = output;
// intermediate contains Ui and output contains the accumulated XOR:ed result
// Same lower bound the constructor enforces, re-checked before iterating.
invariant(_iterationCount >= kIterationCountMinimum);
// The HMAC above is round 1, so the loop runs the remaining _iterationCount - 1 rounds.
for (size_t i = 1; i < _iterationCount; ++i) {
intermediate =
HashBlock::computeHmac(reinterpret_cast(_password.c_str()),
_password.size(),
intermediate.data(),
intermediate.size());
output.xorInline(intermediate);
}
return output;
}
// Returns saltLength() bytes of salt taken from values produced by SecureRandom::nextInt64().
static std::vector generateSecureRandomSalt() {
// Express salt length as a number of quad words, rounded up.
constexpr auto qwords = (saltLength() + sizeof(std::int64_t) - 1) / sizeof(std::int64_t);
std::array userSalt;
std::unique_ptr sr(SecureRandom::create());
std::generate(userSalt.begin(), userSalt.end(), [&sr] { return sr->nextInt64(); });
// Only the first saltLength() bytes of the generated words are returned.
return std::vector(reinterpret_cast(userSalt.data()),
reinterpret_cast(userSalt.data()) +
saltLength());
}
private:
template
friend bool operator==(const Presecrets&, const Presecrets&);
// Tuple of references to every member; operator== compares these tuples.
auto equalityLens() const {
return std::tie(_password, _salt, _iterationCount);
}
// Required salt size: four less than the hash length, which leaves room in a
// kHashLength-sized buffer for the 4-byte suffix written by generateSaltedPassword().
static constexpr auto saltLength() {
return HashBlock::kHashLength - 4;
}
std::string _password;
std::vector _salt;
size_t _iterationCount;
};
// Two Presecrets are equal when their equalityLens() tuples compare equal.
template
bool operator==(const Presecrets& lhs, const Presecrets& rhs) {
    const auto lhsFields = lhs.equalityLens();
    const auto rhsFields = rhs.equalityLens();
    return lhsFields == rhsFields;
}
// Inequality for Presecrets, defined as the negation of operator==.
template
bool operator!=(const Presecrets& lhs, const Presecrets& rhs) {
    if (lhs == rhs) {
        return false;
    }
    return true;
}
/* Stores all of the keys, generated from a password, needed for a client or server to perform a
 * SCRAM handshake.
 * These keys are reference counted, and allocated using the SecureAllocator.
 * May be unpopulated. SCRAMSecrets created via the default constructor are unpopulated.
 * The behavior is undefined if the accessors are called when unpopulated; operator bool reports
 * whether the object is populated.
 */
template
class Secrets {
private:
    // The three keys kept together behind the shared pointer.
    struct SecretsHolder {
        HashBlock clientKey;
        HashBlock storedKey;
        HashBlock serverKey;
    };
    using SecureSecrets = SecureAllocatorAuthDomain::SecureHandle;

public:
    // Creates an unpopulated object: no holder is allocated and operator bool yields false.
    Secrets() = default;

    // Populates the secrets from raw key bytes.
    // `client` may be empty, in which case the client key is not assigned.
    // Every buffer that is used is converted with HashBlock::fromBuffer, and the outcome of that
    // conversion is passed through uassertStatusOK.
    Secrets(StringData client, StringData stored, StringData server)
        : _ptr(std::make_shared()) {
        if (!client.empty()) {
            (*_ptr)->clientKey = uassertStatusOK(HashBlock::fromBuffer(
                reinterpret_cast(client.rawData()), client.size()));
        }
        (*_ptr)->storedKey = uassertStatusOK(HashBlock::fromBuffer(
            reinterpret_cast(stored.rawData()), stored.size()));
        // The server key is sized from its own buffer, so fromBuffer sees the true length of
        // `server` rather than the length of `stored`.
        (*_ptr)->serverKey = uassertStatusOK(HashBlock::fromBuffer(
            reinterpret_cast(server.rawData()), server.size()));
    }

    // Derives all three keys from an already computed salted password.
    Secrets(const HashBlock& saltedPassword) : _ptr(std::make_shared()) {
        // ClientKey := HMAC(saltedPassword, "Client Key")
        (*_ptr)->clientKey = HashBlock::computeHmac(
            saltedPassword.data(),
            saltedPassword.size(),
            reinterpret_cast(kClientKeyConst.rawData()),
            kClientKeyConst.size());

        // StoredKey := H(clientKey)
        (*_ptr)->storedKey = HashBlock::computeHash(clientKey().data(), clientKey().size());

        // ServerKey := HMAC(SaltedPassword, "Server Key")
        (*_ptr)->serverKey = HashBlock::computeHmac(
            saltedPassword.data(),
            saltedPassword.size(),
            reinterpret_cast(kServerKeyConst.rawData()),
            kServerKeyConst.size());
    }

    // Derives the keys from presecrets by first generating the salted password.
    Secrets(const Presecrets& presecrets)
        : Secrets(presecrets.generateSaltedPassword()) {}

    // Returns the string form of the client proof for `authMessage`.
    std::string generateClientProof(StringData authMessage) const {
        // ClientProof := HMAC(StoredKey, AuthMessage) ^ ClientKey
        auto proof =
            HashBlock::computeHmac(storedKey().data(),
                                   storedKey().size(),
                                   reinterpret_cast(authMessage.rawData()),
                                   authMessage.size());
        proof.xorInline(clientKey());
        return proof.toString();
    }

    // Checks `proof` against the stored key: recovers a candidate client key from the proof,
    // hashes it, and compares the hash with the stored key using consttimeMemEqual.
    // The proof bytes are converted with HashBlock::fromBuffer via uassertStatusOK.
    // Returns false if either compared block is not HashBlock::kHashLength bytes long.
    bool verifyClientProof(StringData authMessage, StringData proof) const {
        // ClientKey := HMAC(StoredKey, AuthMessage) ^ ClientProof
        auto key =
            HashBlock::computeHmac(storedKey().data(),
                                   storedKey().size(),
                                   reinterpret_cast(authMessage.rawData()),
                                   authMessage.size());
        key.xorInline(uassertStatusOK(HashBlock::fromBuffer(
            reinterpret_cast(proof.rawData()), proof.size())));

        // StoredKey := H(ClientKey)
        auto exp = HashBlock::computeHash(key.data(), key.size());

        if ((exp.size() != HashBlock::kHashLength) ||
            (storedKey().size() != HashBlock::kHashLength)) {
            return false;
        }

        return consttimeMemEqual(reinterpret_cast(exp.data()),
                                 storedKey().data(),
                                 HashBlock::kHashLength);
    }

    // Returns the string form of the server signature for `authMessage`.
    std::string generateServerSignature(StringData authMessage) const {
        // ServerSignature := HMAC(ServerKey, AuthMessage)
        return HashBlock::computeHmac(serverKey().data(),
                                      serverKey().size(),
                                      reinterpret_cast(authMessage.rawData()),
                                      authMessage.size())
            .toString();
    }

    // Recomputes the server signature for `authMessage` and compares it with `sig` using
    // consttimeMemEqual. Returns false if `sig` or the recomputed value is not
    // HashBlock::kHashLength bytes long.
    bool verifyServerSignature(StringData authMessage, StringData sig) const {
        // ServerSignature := HMAC(ServerKey, AuthMessage)
        const auto exp =
            HashBlock::computeHmac(serverKey().data(),
                                   serverKey().size(),
                                   reinterpret_cast(authMessage.rawData()),
                                   authMessage.size());

        if ((sig.size() != HashBlock::kHashLength) || (exp.size() != HashBlock::kHashLength)) {
            return false;
        }

        return consttimeMemEqual(reinterpret_cast(sig.rawData()),
                                 reinterpret_cast(exp.data()),
                                 HashBlock::kHashLength);
    }

    // Builds a credentials object for `password`: generates a fresh random salt, derives the
    // secrets, and returns a BSON object holding the iteration count, the base64-encoded salt,
    // and the toString() forms of the stored key and the server key.
    static BSONObj generateCredentials(std::string password, int iterationCount) {
        auto salt = Presecrets::generateSecureRandomSalt();
        Secrets secrets(Presecrets(password, salt, iterationCount));
        const auto encodedSalt = base64::encode(reinterpret_cast(salt.data()), salt.size());
        return BSON(kIterationCountFieldName << iterationCount << kSaltFieldName << encodedSalt
                                             << kStoredKeyFieldName
                                             << secrets.storedKey().toString()
                                             << kServerKeyFieldName
                                             << secrets.serverKey().toString());
    }

    // Accessors. Each dereferences the holder without a null check (see the class comment) and
    // fails a BadValue uassert if the key's size differs from HashBlock::kHashLength.
    const HashBlock& clientKey() const {
        auto& ret = (*_ptr)->clientKey;
        uassert(
            ErrorCodes::BadValue, "Invalid SCRAM client key", ret.size() == HashBlock::kHashLength);
        return ret;
    }

    const HashBlock& storedKey() const {
        auto& ret = (*_ptr)->storedKey;
        uassert(
            ErrorCodes::BadValue, "Invalid SCRAM stored key", ret.size() == HashBlock::kHashLength);
        return ret;
    }

    const HashBlock& serverKey() const {
        auto& ret = (*_ptr)->serverKey;
        uassert(
            ErrorCodes::BadValue, "Invalid SCRAM server key", ret.size() == HashBlock::kHashLength);
        return ret;
    }

    // True when this object holds keys, i.e. the holder pointer is non-null.
    operator bool() const {
        return static_cast<bool>(_ptr);
    }

private:
    std::shared_ptr _ptr;
};
// Convenience aliases for the Presecrets and Secrets templates.
// NOTE(review): by name these should bind the SHA-1 hash block declared in
// mongo/crypto/sha1_block.h; the template argument is not visible here, so confirm it.
using SHA1Presecrets = Presecrets;
using SHA1Secrets = Secrets;
} // namespace scram
} // namespace mongo