summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/mongo/db/SConscript2
-rw-r--r--src/mongo/db/commands/dbcommands.cpp1
-rw-r--r--src/mongo/db/db.cpp7
-rw-r--r--src/mongo/db/keys_collection_manager.cpp39
-rw-r--r--src/mongo/db/keys_collection_manager.h15
-rw-r--r--src/mongo/db/keys_collection_manager_test.cpp41
-rw-r--r--src/mongo/db/logical_time_metadata_hook.cpp7
-rw-r--r--src/mongo/db/logical_time_validator.cpp101
-rw-r--r--src/mongo/db/logical_time_validator.h38
-rw-r--r--src/mongo/db/logical_time_validator_test.cpp101
-rw-r--r--src/mongo/db/namespace_string.cpp3
-rw-r--r--src/mongo/db/repl/replication_coordinator_external_state_impl.cpp18
-rw-r--r--src/mongo/db/run_commands.cpp21
-rw-r--r--src/mongo/rpc/metadata.cpp42
-rw-r--r--src/mongo/s/commands/strategy.cpp7
-rw-r--r--src/mongo/s/server.cpp7
-rw-r--r--src/mongo/s/sharding_initialization.cpp20
17 files changed, 315 insertions, 155 deletions
diff --git a/src/mongo/db/SConscript b/src/mongo/db/SConscript
index 8a99cd03103..7e1d88abade 100644
--- a/src/mongo/db/SConscript
+++ b/src/mongo/db/SConscript
@@ -1058,6 +1058,8 @@ env.CppUnitTest(
],
LIBDEPS=[
'logical_time_validator',
+ '$BUILD_DIR/mongo/s/config_server_test_fixture',
+ '$BUILD_DIR/mongo/s/coreshard',
],
)
diff --git a/src/mongo/db/commands/dbcommands.cpp b/src/mongo/db/commands/dbcommands.cpp
index e1383becee1..949f6d717ff 100644
--- a/src/mongo/db/commands/dbcommands.cpp
+++ b/src/mongo/db/commands/dbcommands.cpp
@@ -1200,5 +1200,4 @@ public:
}
} availableQueryOptionsCmd;
-
} // namespace mongo
diff --git a/src/mongo/db/db.cpp b/src/mongo/db/db.cpp
index a76bdac50f5..c8ae94c4785 100644
--- a/src/mongo/db/db.cpp
+++ b/src/mongo/db/db.cpp
@@ -75,6 +75,7 @@
#include "mongo/db/log_process_details.h"
#include "mongo/db/logical_clock.h"
#include "mongo/db/logical_time_metadata_hook.h"
+#include "mongo/db/logical_time_validator.h"
#include "mongo/db/mongod_options.h"
#include "mongo/db/op_observer_impl.h"
#include "mongo/db/operation_context.h"
@@ -952,6 +953,12 @@ static void shutdownTask() {
sr->shutdown();
}
+ // Validator shutdown must be called after setKillAllOperations is called. Otherwise, this can
+ // deadlock.
+ if (auto validator = LogicalTimeValidator::get(serviceContext)) {
+ validator->shutDown();
+ }
+
#if __has_feature(address_sanitizer)
if (auto sep = checked_cast<ServiceEntryPointImpl*>(serviceContext->getServiceEntryPoint())) {
// When running under address sanitizer, we get false positive leaks due to disorder around
diff --git a/src/mongo/db/keys_collection_manager.cpp b/src/mongo/db/keys_collection_manager.cpp
index 12fd318df09..5dea71eb123 100644
--- a/src/mongo/db/keys_collection_manager.cpp
+++ b/src/mongo/db/keys_collection_manager.cpp
@@ -96,28 +96,8 @@ StatusWith<KeysCollectionDocument> KeysCollectionManager::getKeyForValidation(
}
StatusWith<KeysCollectionDocument> KeysCollectionManager::getKeyForSigning(
- OperationContext* opCtx, const LogicalTime& forThisTime) {
-
- auto keyStatusWith = _getKey(forThisTime);
- auto keyStatus = keyStatusWith.getStatus();
-
- if (keyStatus != ErrorCodes::KeyNotFound) {
- return keyStatusWith;
- }
-
- do {
- _refresher.refreshNow(opCtx);
-
- keyStatusWith = _getKey(forThisTime);
- keyStatus = keyStatusWith.getStatus();
-
- if (keyStatus == ErrorCodes::KeyNotFound) {
- sleepFor(kRefreshIntervalIfErrored);
- }
-
- } while (keyStatus == ErrorCodes::KeyNotFound);
-
- return keyStatusWith;
+ const LogicalTime& forThisTime) {
+ return _getKey(forThisTime);
}
StatusWith<KeysCollectionDocument> KeysCollectionManager::_getKeyWithKeyIdCheck(
@@ -160,6 +140,10 @@ StatusWith<KeysCollectionDocument> KeysCollectionManager::_getKey(const LogicalT
return key;
}
+void KeysCollectionManager::refreshNow(OperationContext* opCtx) {
+ _refresher.refreshNow(opCtx);
+}
+
void KeysCollectionManager::startMonitoring(ServiceContext* service) {
_refresher.setFunc([this](OperationContext* opCtx) { return _keysCache.refresh(opCtx); });
_refresher.start(
@@ -230,6 +214,7 @@ void KeysCollectionManager::PeriodicRunner::_doPeriodicRefresh(ServiceContext* s
while (true) {
auto opCtx = cc().makeOperationContext();
+ bool hasRefreshRequestInitially = false;
std::shared_ptr<RefreshFunc> doRefresh;
{
stdx::lock_guard<stdx::mutex> lock(_mutex);
@@ -240,6 +225,7 @@ void KeysCollectionManager::PeriodicRunner::_doPeriodicRefresh(ServiceContext* s
invariant(_doRefresh.get() != nullptr);
doRefresh = _doRefresh;
+ hasRefreshRequestInitially = _refreshRequest.get() != nullptr;
}
Milliseconds nextWakeup = kRefreshIntervalIfErrored;
@@ -258,6 +244,11 @@ void KeysCollectionManager::PeriodicRunner::_doPeriodicRefresh(ServiceContext* s
stdx::unique_lock<stdx::mutex> lock(_mutex);
if (_refreshRequest) {
+ if (!hasRefreshRequestInitially) {
+ // A fresh request came in, fulfill the request before going to sleep.
+ continue;
+ }
+
_refreshRequest->set();
_refreshRequest.reset();
}
@@ -290,10 +281,6 @@ void KeysCollectionManager::PeriodicRunner::setFunc(RefreshFunc newRefreshStrate
void KeysCollectionManager::PeriodicRunner::switchFunc(OperationContext* opCtx,
RefreshFunc newRefreshStrategy) {
setFunc(newRefreshStrategy);
-
- // Note: calling refreshNow will ensure that if there is an ongoing method call to the original
- // refreshStrategy, it will be finished after this.
- refreshNow(opCtx);
}
void KeysCollectionManager::PeriodicRunner::start(ServiceContext* service,
diff --git a/src/mongo/db/keys_collection_manager.h b/src/mongo/db/keys_collection_manager.h
index b05b5046a7c..bfe4a27a408 100644
--- a/src/mongo/db/keys_collection_manager.h
+++ b/src/mongo/db/keys_collection_manager.h
@@ -53,9 +53,6 @@ class ShardingCatalogClient;
*/
class KeysCollectionManager {
public:
- /**
- * Creates a new instance of key manager. This should outlive the client.
- */
KeysCollectionManager(std::string purpose,
ShardingCatalogClient* client,
Seconds keyValidForInterval);
@@ -71,13 +68,17 @@ public:
const LogicalTime& forThisTime);
/**
- * Return a key that is valid for the given time. Note that this call can block if it will need
- * to do a refresh.
+ * Returns a key that is valid for the given time. Note that unlike getKeyForValidation, this
+ * will never do a refresh.
*
* Throws ErrorCode::ExceededTimeLimit if it times out.
*/
- StatusWith<KeysCollectionDocument> getKeyForSigning(OperationContext* opCtx,
- const LogicalTime& forThisTime);
+ StatusWith<KeysCollectionDocument> getKeyForSigning(const LogicalTime& forThisTime);
+
+ /**
+ * Request this manager to perform a refresh.
+ */
+ void refreshNow(OperationContext* opCtx);
/**
* Starts a background thread that will constantly update the internal cache of keys.
diff --git a/src/mongo/db/keys_collection_manager_test.cpp b/src/mongo/db/keys_collection_manager_test.cpp
index a0f671aedeb..3abc415cf8d 100644
--- a/src/mongo/db/keys_collection_manager_test.cpp
+++ b/src/mongo/db/keys_collection_manager_test.cpp
@@ -161,24 +161,6 @@ TEST_F(KeysManagerTest, GetKeyWithoutRefreshShouldReturnRightKey) {
}
}
-TEST_F(KeysManagerTest, GetKeyForSigningTimesOutIfRefresherIsNotRunning) {
- operationContext()->setDeadlineAfterNowBy(Microseconds(250 * 1000));
-
- ASSERT_THROWS(
- keyManager()->getKeyForSigning(operationContext(), LogicalTime(Timestamp(100, 0))),
- DBException);
-}
-
-TEST_F(KeysManagerTest, GetKeyForSigningTimesOutIfKeyDoesntExist) {
- keyManager()->startMonitoring(getServiceContext());
-
- operationContext()->setDeadlineAfterNowBy(Microseconds(250 * 1000));
-
- ASSERT_THROWS(
- keyManager()->getKeyForSigning(operationContext(), LogicalTime(Timestamp(100, 0))),
- DBException);
-}
-
TEST_F(KeysManagerTest, GetKeyForSigningShouldReturnRightKey) {
keyManager()->startMonitoring(getServiceContext());
@@ -187,8 +169,9 @@ TEST_F(KeysManagerTest, GetKeyForSigningShouldReturnRightKey) {
ASSERT_OK(insertToConfigCollection(
operationContext(), NamespaceString(KeysCollectionDocument::ConfigNS), origKey1.toBSON()));
- auto keyStatus =
- keyManager()->getKeyForSigning(operationContext(), LogicalTime(Timestamp(100, 0)));
+ keyManager()->refreshNow(operationContext());
+
+ auto keyStatus = keyManager()->getKeyForSigning(LogicalTime(Timestamp(100, 0)));
ASSERT_OK(keyStatus.getStatus());
auto key = keyStatus.getValue();
@@ -197,7 +180,7 @@ TEST_F(KeysManagerTest, GetKeyForSigningShouldReturnRightKey) {
ASSERT_EQ(Timestamp(105, 0), key.getExpiresAt().asTimestamp());
}
-TEST_F(KeysManagerTest, GetKeyForSigningWithoutRefreshShouldReturnRightKey) {
+TEST_F(KeysManagerTest, GetKeyForSigningShouldReturnRightOldKey) {
keyManager()->startMonitoring(getServiceContext());
KeysCollectionDocument origKey1(
@@ -209,9 +192,10 @@ TEST_F(KeysManagerTest, GetKeyForSigningWithoutRefreshShouldReturnRightKey) {
ASSERT_OK(insertToConfigCollection(
operationContext(), NamespaceString(KeysCollectionDocument::ConfigNS), origKey2.toBSON()));
+ keyManager()->refreshNow(operationContext());
+
{
- auto keyStatus =
- keyManager()->getKeyForSigning(operationContext(), LogicalTime(Timestamp(100, 0)));
+ auto keyStatus = keyManager()->getKeyForSigning(LogicalTime(Timestamp(100, 0)));
ASSERT_OK(keyStatus.getStatus());
auto key = keyStatus.getValue();
@@ -221,8 +205,7 @@ TEST_F(KeysManagerTest, GetKeyForSigningWithoutRefreshShouldReturnRightKey) {
}
{
- auto keyStatus =
- keyManager()->getKeyForSigning(operationContext(), LogicalTime(Timestamp(105, 0)));
+ auto keyStatus = keyManager()->getKeyForSigning(LogicalTime(Timestamp(105, 0)));
ASSERT_OK(keyStatus.getStatus());
auto key = keyStatus.getValue();
@@ -239,9 +222,9 @@ TEST_F(KeysManagerTest, ShouldCreateKeysIfKeyGeneratorEnabled) {
LogicalClock::get(operationContext())->setClusterTimeFromTrustedSource(currentTime);
keyManager()->enableKeyGenerator(operationContext(), true);
+ keyManager()->refreshNow(operationContext());
- auto keyStatus =
- keyManager()->getKeyForSigning(operationContext(), LogicalTime(Timestamp(100, 100)));
+ auto keyStatus = keyManager()->getKeyForSigning(LogicalTime(Timestamp(100, 100)));
ASSERT_OK(keyStatus.getStatus());
auto key = keyStatus.getValue();
@@ -258,9 +241,9 @@ TEST_F(KeysManagerTest, EnableModeFlipFlopStressTest) {
for (int x = 0; x < 10; x++) {
keyManager()->enableKeyGenerator(operationContext(), doEnable);
+ keyManager()->refreshNow(operationContext());
- auto keyStatus =
- keyManager()->getKeyForSigning(operationContext(), LogicalTime(Timestamp(100, 100)));
+ auto keyStatus = keyManager()->getKeyForSigning(LogicalTime(Timestamp(100, 100)));
ASSERT_OK(keyStatus.getStatus());
auto key = keyStatus.getValue();
diff --git a/src/mongo/db/logical_time_metadata_hook.cpp b/src/mongo/db/logical_time_metadata_hook.cpp
index 29949e36c82..561765dbf03 100644
--- a/src/mongo/db/logical_time_metadata_hook.cpp
+++ b/src/mongo/db/logical_time_metadata_hook.cpp
@@ -49,7 +49,7 @@ Status LogicalTimeMetadataHook::writeRequestMetadata(OperationContext* opCtx,
}
auto newTime = LogicalClock::get(_service)->getClusterTime();
- LogicalTimeMetadata metadata(validator->signLogicalTime(newTime));
+ LogicalTimeMetadata metadata(validator->trySignLogicalTime(newTime));
metadata.writeToMetadata(metadataBob);
return Status::OK();
}
@@ -69,11 +69,6 @@ Status LogicalTimeMetadataHook::readReplyMetadata(StringData replySource,
return Status::OK();
}
- auto validator = LogicalTimeValidator::get(_service);
- if (validator) {
- validator->updateCacheTrustedSource(signedTime);
- }
-
return LogicalClock::get(_service)->advanceClusterTime(signedTime.getTime());
}
diff --git a/src/mongo/db/logical_time_validator.cpp b/src/mongo/db/logical_time_validator.cpp
index 6a3de2ae9d4..8387198f6f9 100644
--- a/src/mongo/db/logical_time_validator.cpp
+++ b/src/mongo/db/logical_time_validator.cpp
@@ -30,21 +30,37 @@
#include "mongo/db/logical_time_validator.h"
+#include "mongo/base/init.h"
+#include "mongo/db/auth/action_set.h"
+#include "mongo/db/auth/action_type.h"
+#include "mongo/db/auth/authorization_session.h"
+#include "mongo/db/auth/privilege.h"
+#include "mongo/db/keys_collection_manager.h"
#include "mongo/db/operation_context.h"
#include "mongo/db/service_context.h"
+#include "mongo/util/assert_util.h"
namespace mongo {
namespace {
const auto getLogicalClockValidator =
ServiceContext::declareDecoration<std::unique_ptr<LogicalTimeValidator>>();
-stdx::mutex validatorMutex;
-// TODO: SERVER-28127 Implement KeysCollectionManager
-// Remove _tempKey and its uses from logical clock, and pass actual key from key manager.
-TimeProofService::Key tempKey = {};
+stdx::mutex validatorMutex; // protects access to decoration instance of LogicalTimeValidator.
+
+std::vector<Privilege> advanceLogicalClockPrivilege;
+
+MONGO_INITIALIZER(InitializeAdvanceLogicalClockPrivilegeVector)(InitializerContext* const) {
+ ActionSet actions;
+ actions.addAction(ActionType::internal);
+ advanceLogicalClockPrivilege.emplace_back(ResourcePattern::forClusterResource(), actions);
+ return Status::OK();
}
+Milliseconds kRefreshIntervalIfErrored(200);
+
+} // unnamed namespace
+
LogicalTimeValidator* LogicalTimeValidator::get(ServiceContext* service) {
stdx::lock_guard<stdx::mutex> lk(validatorMutex);
return getLogicalClockValidator(service).get();
@@ -61,7 +77,13 @@ void LogicalTimeValidator::set(ServiceContext* service,
validator = std::move(newValidator);
}
-SignedLogicalTime LogicalTimeValidator::signLogicalTime(const LogicalTime& newTime) {
+LogicalTimeValidator::LogicalTimeValidator(std::unique_ptr<KeysCollectionManager> keyManager)
+ : _keyManager(std::move(keyManager)) {}
+
+SignedLogicalTime LogicalTimeValidator::_getProof(const KeysCollectionDocument& keyDoc,
+ LogicalTime newTime) {
+ auto key = keyDoc.getKey();
+
// Compare and calculate HMAC inside mutex to prevent multiple threads computing HMAC for the
// same logical time.
stdx::lock_guard<stdx::mutex> lk(_mutex);
@@ -70,8 +92,8 @@ SignedLogicalTime LogicalTimeValidator::signLogicalTime(const LogicalTime& newTi
return _lastSeenValidTime;
}
- auto signature = _timeProofService.getProof(newTime, tempKey);
- SignedLogicalTime newSignedTime(newTime, std::move(signature), 0);
+ auto signature = _timeProofService.getProof(newTime, key);
+ SignedLogicalTime newSignedTime(newTime, std::move(signature), keyDoc.getKeyId());
if (newTime > _lastSeenValidTime.getTime() || !_lastSeenValidTime.getProof()) {
_lastSeenValidTime = newSignedTime;
@@ -80,7 +102,40 @@ SignedLogicalTime LogicalTimeValidator::signLogicalTime(const LogicalTime& newTi
return newSignedTime;
}
-Status LogicalTimeValidator::validate(const SignedLogicalTime& newTime) {
+SignedLogicalTime LogicalTimeValidator::trySignLogicalTime(const LogicalTime& newTime) {
+ auto keyStatusWith = _keyManager->getKeyForSigning(newTime);
+ auto keyStatus = keyStatusWith.getStatus();
+
+ if (keyStatus == ErrorCodes::KeyNotFound) {
+ // Attach invalid signature and keyId if we don't have the right keys to sign it.
+ return SignedLogicalTime(newTime, TimeProofService::TimeProof(), 0);
+ }
+
+ uassertStatusOK(keyStatus);
+ return _getProof(keyStatusWith.getValue(), newTime);
+}
+
+SignedLogicalTime LogicalTimeValidator::signLogicalTime(OperationContext* opCtx,
+ const LogicalTime& newTime) {
+ auto keyStatusWith = _keyManager->getKeyForSigning(newTime);
+ auto keyStatus = keyStatusWith.getStatus();
+
+ while (keyStatus == ErrorCodes::KeyNotFound) {
+ _keyManager->refreshNow(opCtx);
+
+ keyStatusWith = _keyManager->getKeyForSigning(newTime);
+ keyStatus = keyStatusWith.getStatus();
+
+ if (keyStatus == ErrorCodes::KeyNotFound) {
+ sleepFor(kRefreshIntervalIfErrored);
+ }
+ }
+
+ uassertStatusOK(keyStatus);
+ return _getProof(keyStatusWith.getValue(), newTime);
+}
+
+Status LogicalTimeValidator::validate(OperationContext* opCtx, const SignedLogicalTime& newTime) {
{
stdx::lock_guard<stdx::mutex> lk(_mutex);
if (newTime.getTime() == _lastSeenValidTime.getTime()) {
@@ -88,12 +143,17 @@ Status LogicalTimeValidator::validate(const SignedLogicalTime& newTime) {
}
}
+ auto keyStatus = _keyManager->getKeyForValidation(opCtx, newTime.getKeyId(), newTime.getTime());
+ uassertStatusOK(keyStatus.getStatus());
+
+ const auto& key = keyStatus.getValue().getKey();
+
const auto newProof = newTime.getProof();
// Logical time is only sent if a server's clock can verify and sign logical times, so any
// received logical times should have proofs.
invariant(newProof);
- auto res = _timeProofService.checkProof(newTime.getTime(), newProof.get(), tempKey);
+ auto res = _timeProofService.checkProof(newTime.getTime(), newProof.get(), key);
if (res != Status::OK()) {
return res;
}
@@ -101,11 +161,24 @@ Status LogicalTimeValidator::validate(const SignedLogicalTime& newTime) {
return Status::OK();
}
-void LogicalTimeValidator::updateCacheTrustedSource(const SignedLogicalTime& newTime) {
- stdx::lock_guard<stdx::mutex> lk(_mutex);
- if (newTime.getTime() > _lastSeenValidTime.getTime()) {
- _lastSeenValidTime = newTime;
- }
+void LogicalTimeValidator::init(ServiceContext* service) {
+ _keyManager->startMonitoring(service);
+}
+
+void LogicalTimeValidator::shutDown() {
+ _keyManager->stopMonitoring();
+}
+
+void LogicalTimeValidator::enableKeyGenerator(OperationContext* opCtx, bool doEnable) {
+ _keyManager->enableKeyGenerator(opCtx, doEnable);
+}
+
+bool LogicalTimeValidator::isAuthorizedToAdvanceClock(OperationContext* opCtx) {
+ auto client = opCtx->getClient();
+ // Note: returns true if auth is off, courtesy of
+ // AuthzSessionExternalStateServerCommon::shouldIgnoreAuthChecks.
+ return AuthorizationSession::get(client)->isAuthorizedForPrivileges(
+ advanceLogicalClockPrivilege);
}
} // namespace mongo
diff --git a/src/mongo/db/logical_time_validator.h b/src/mongo/db/logical_time_validator.h
index 74bf93077bf..24be173d2dd 100644
--- a/src/mongo/db/logical_time_validator.h
+++ b/src/mongo/db/logical_time_validator.h
@@ -38,6 +38,8 @@ namespace mongo {
class OperationContext;
class ServiceContext;
+class KeysCollectionDocument;
+class KeysCollectionManager;
/**
* This is responsible for signing logical times that can be used to sent to other servers and
@@ -50,26 +52,52 @@ public:
static LogicalTimeValidator* get(OperationContext* ctx);
static void set(ServiceContext* service, std::unique_ptr<LogicalTimeValidator> validator);
+ explicit LogicalTimeValidator(std::unique_ptr<KeysCollectionManager> keyManager);
+
+ /**
+ * Tries to sign the newTime with a valid signature. Can return an empty signature and keyId
+ * of 0 if it cannot find valid key for newTime.
+ */
+ SignedLogicalTime trySignLogicalTime(const LogicalTime& newTime);
+
/**
* Returns the newTime with a valid signature.
*/
- SignedLogicalTime signLogicalTime(const LogicalTime& newTime);
+ SignedLogicalTime signLogicalTime(OperationContext* opCtx, const LogicalTime& newTime);
/**
* Returns true if the signature of newTime is valid.
*/
- Status validate(const SignedLogicalTime& newTime);
+ Status validate(OperationContext* opCtx, const SignedLogicalTime& newTime);
+
+ /**
+ * Initializes this validator. This should be called first before the other methods can be used.
+ */
+ void init(ServiceContext* service);
/**
- * Saves the newTime if it is newer than the last seen valid LogicalTime without performing
- * validation.
+ * Cleans up this validator. This will no longer be usable after this is called.
*/
- void updateCacheTrustedSource(const SignedLogicalTime& newTime);
+ void shutDown();
+
+ /**
+ * Enable writing new keys to the config server primary. Should only be called if current node
+ * is the config primary.
+ */
+ void enableKeyGenerator(OperationContext* opCtx, bool doEnable);
+
+ /**
+ * Returns true if client has sufficient privilege to advance clock.
+ */
+ static bool isAuthorizedToAdvanceClock(OperationContext* opCtx);
private:
+ SignedLogicalTime _getProof(const KeysCollectionDocument& keyDoc, LogicalTime newTime);
+
stdx::mutex _mutex;
SignedLogicalTime _lastSeenValidTime;
TimeProofService _timeProofService;
+ std::unique_ptr<KeysCollectionManager> _keyManager;
};
} // namespace mongo
diff --git a/src/mongo/db/logical_time_validator_test.cpp b/src/mongo/db/logical_time_validator_test.cpp
index 3eb328c7044..2aeb11462a8 100644
--- a/src/mongo/db/logical_time_validator_test.cpp
+++ b/src/mongo/db/logical_time_validator_test.cpp
@@ -29,13 +29,18 @@
#include "mongo/platform/basic.h"
#include "mongo/bson/timestamp.h"
+#include "mongo/db/keys_collection_manager.h"
#include "mongo/db/logical_clock.h"
#include "mongo/db/logical_time.h"
#include "mongo/db/logical_time_validator.h"
-#include "mongo/db/service_context_noop.h"
+#include "mongo/db/operation_context.h"
+#include "mongo/db/service_context.h"
#include "mongo/db/signed_logical_time.h"
#include "mongo/db/time_proof_service.h"
#include "mongo/platform/basic.h"
+#include "mongo/s/catalog/dist_lock_manager_mock.h"
+#include "mongo/s/config_server_test_fixture.h"
+#include "mongo/s/grid.h"
#include "mongo/stdx/memory.h"
#include "mongo/unittest/unittest.h"
#include "mongo/util/clock_source_mock.h"
@@ -43,41 +48,105 @@
namespace mongo {
namespace {
-TEST(LogicalTimeValidator, GetTimeWithIncreasingTimes) {
- LogicalTimeValidator validator;
-
+class LogicalTimeValidatorTest : public ConfigServerTestFixture {
+public:
+ LogicalTimeValidator* validator() {
+ return _validator.get();
+ }
+
+protected:
+ void setUp() override {
+ ConfigServerTestFixture::setUp();
+
+ auto clockSource = stdx::make_unique<ClockSourceMock>();
+ operationContext()->getServiceContext()->setFastClockSource(std::move(clockSource));
+ auto catalogClient = Grid::get(operationContext())->catalogClient(operationContext());
+
+ const LogicalTime currentTime(LogicalTime(Timestamp(1, 0)));
+ LogicalClock::get(operationContext())->setClusterTimeFromTrustedSource(currentTime);
+
+ auto keyManager =
+ stdx::make_unique<KeysCollectionManager>("dummy", catalogClient, Seconds(1000));
+ _keyManager = keyManager.get();
+ _validator = stdx::make_unique<LogicalTimeValidator>(std::move(keyManager));
+ _validator->init(operationContext()->getServiceContext());
+ _validator->enableKeyGenerator(operationContext(), true);
+ }
+
+ void tearDown() override {
+ _validator->shutDown();
+ ConfigServerTestFixture::tearDown();
+ }
+
+ std::unique_ptr<DistLockManager> makeDistLockManager(
+ std::unique_ptr<DistLockCatalog> distLockCatalog) override {
+ invariant(distLockCatalog);
+ return stdx::make_unique<DistLockManagerMock>(std::move(distLockCatalog));
+ }
+
+ /**
+ * Forces KeyManager to refresh cache and generate new keys.
+ */
+ void refreshKeyManager() {
+ _keyManager->refreshNow(operationContext());
+ }
+
+private:
+ std::unique_ptr<LogicalTimeValidator> _validator;
+ KeysCollectionManager* _keyManager;
+};
+
+TEST_F(LogicalTimeValidatorTest, GetTimeWithIncreasingTimes) {
LogicalTime t1(Timestamp(10, 0));
- auto newTime = validator.signLogicalTime(t1);
+ auto newTime = validator()->trySignLogicalTime(t1);
ASSERT_EQ(t1.asTimestamp(), newTime.getTime().asTimestamp());
ASSERT_TRUE(newTime.getProof());
LogicalTime t2(Timestamp(20, 0));
- auto newTime2 = validator.signLogicalTime(t2);
+ auto newTime2 = validator()->trySignLogicalTime(t2);
ASSERT_EQ(t2.asTimestamp(), newTime2.getTime().asTimestamp());
ASSERT_TRUE(newTime2.getProof());
}
-TEST(LogicalTimeValidator, ValidateReturnsOkForValidSignature) {
- LogicalTimeValidator validator;
+TEST_F(LogicalTimeValidatorTest, ValidateReturnsOkForValidSignature) {
+ LogicalTime t1(Timestamp(20, 0));
+ refreshKeyManager();
+ auto newTime = validator()->trySignLogicalTime(t1);
+ ASSERT_OK(validator()->validate(operationContext(), newTime));
+}
+
+TEST_F(LogicalTimeValidatorTest, ValidateErrorsOnInvalidTime) {
LogicalTime t1(Timestamp(20, 0));
- auto newTime = validator.signLogicalTime(t1);
+ refreshKeyManager();
+ auto newTime = validator()->trySignLogicalTime(t1);
- ASSERT_OK(validator.validate(newTime));
+ TimeProofService::TimeProof invalidProof = {{1, 2, 3}};
+ SignedLogicalTime invalidTime(LogicalTime(Timestamp(30, 0)), invalidProof, newTime.getKeyId());
+ // ASSERT_THROWS_CODE(validator()->validate(operationContext(), invalidTime), DBException,
+ // ErrorCodes::TimeProofMismatch);
+ auto status = validator()->validate(operationContext(), invalidTime);
+ ASSERT_EQ(ErrorCodes::TimeProofMismatch, status);
}
-TEST(LogicalTimeValidator, ValidateErrorsOnInvalidTime) {
- LogicalTimeValidator validator;
+TEST_F(LogicalTimeValidatorTest, ValidateReturnsOkForValidSignatureWithImplicitRefresh) {
+ LogicalTime t1(Timestamp(20, 0));
+ auto newTime = validator()->signLogicalTime(operationContext(), t1);
+ ASSERT_OK(validator()->validate(operationContext(), newTime));
+}
+
+TEST_F(LogicalTimeValidatorTest, ValidateErrorsOnInvalidTimeWithImplicitRefresh) {
LogicalTime t1(Timestamp(20, 0));
- auto newTime = validator.signLogicalTime(t1);
+ auto newTime = validator()->signLogicalTime(operationContext(), t1);
TimeProofService::TimeProof invalidProof = {{1, 2, 3}};
- SignedLogicalTime invalidTime(LogicalTime(Timestamp(30, 0)), invalidProof, 0);
-
- auto status = validator.validate(invalidTime);
+ SignedLogicalTime invalidTime(LogicalTime(Timestamp(30, 0)), invalidProof, newTime.getKeyId());
+ // ASSERT_THROWS_CODE(validator()->validate(operationContext(), invalidTime), DBException,
+ // ErrorCodes::TimeProofMismatch);
+ auto status = validator()->validate(operationContext(), invalidTime);
ASSERT_EQ(ErrorCodes::TimeProofMismatch, status);
}
diff --git a/src/mongo/db/namespace_string.cpp b/src/mongo/db/namespace_string.cpp
index d86878e4aef..f16e9a936de 100644
--- a/src/mongo/db/namespace_string.cpp
+++ b/src/mongo/db/namespace_string.cpp
@@ -72,6 +72,7 @@ const string escapeTable[256] = {
".252", ".253", ".254", ".255"};
const char kConfigCollection[] = "admin.system.version";
+const char kLogicalTimeKeysCollection[] = "admin.system.keys";
constexpr auto listCollectionsCursorCol = "$cmd.listCollections"_sd;
constexpr auto listIndexesCursorNSPrefix = "$cmd.listIndexes."_sd;
@@ -90,6 +91,8 @@ bool legalClientSystemNS(StringData ns) {
return true;
if (ns == kConfigCollection)
return true;
+ if (ns == kLogicalTimeKeysCollection)
+ return true;
if (ns == "admin.system.new_users")
return true;
if (ns == "admin.system.backup_users")
diff --git a/src/mongo/db/repl/replication_coordinator_external_state_impl.cpp b/src/mongo/db/repl/replication_coordinator_external_state_impl.cpp
index b0c4a14e175..32c750d2827 100644
--- a/src/mongo/db/repl/replication_coordinator_external_state_impl.cpp
+++ b/src/mongo/db/repl/replication_coordinator_external_state_impl.cpp
@@ -48,6 +48,7 @@
#include "mongo/db/dbhelpers.h"
#include "mongo/db/jsobj.h"
#include "mongo/db/logical_time_metadata_hook.h"
+#include "mongo/db/logical_time_validator.h"
#include "mongo/db/op_observer.h"
#include "mongo/db/repair_database.h"
#include "mongo/db/repl/bgsync.h"
@@ -688,6 +689,19 @@ void ReplicationCoordinatorExternalStateImpl::shardingOnStepDownHook() {
}
ShardingState::get(_service)->markCollectionsNotShardedAtStepdown();
+
+ if (serverGlobalParams.clusterRole == ClusterRole::ConfigServer) {
+ if (auto validator = LogicalTimeValidator::get(_service)) {
+ auto opCtx = cc().getOperationContext();
+
+ if (opCtx != nullptr) {
+ validator->enableKeyGenerator(opCtx, false);
+ } else {
+ auto opCtxPtr = cc().makeOperationContext();
+ validator->enableKeyGenerator(opCtxPtr.get(), false);
+ }
+ }
+ }
}
void ReplicationCoordinatorExternalStateImpl::_shardingOnTransitionToPrimaryHook(
@@ -743,6 +757,10 @@ void ReplicationCoordinatorExternalStateImpl::_shardingOnTransitionToPrimaryHook
// If this is a config server node becoming a primary, start the balancer
Balancer::get(opCtx)->initiateBalancer(opCtx);
+
+ if (auto validator = LogicalTimeValidator::get(_service)) {
+ validator->enableKeyGenerator(opCtx, true);
+ }
} else if (ShardingState::get(opCtx)->enabled()) {
invariant(serverGlobalParams.clusterRole == ClusterRole::ShardServer);
diff --git a/src/mongo/db/run_commands.cpp b/src/mongo/db/run_commands.cpp
index 6e2f0b26388..704549bf868 100644
--- a/src/mongo/db/run_commands.cpp
+++ b/src/mongo/db/run_commands.cpp
@@ -226,15 +226,18 @@ void appendReplyMetadata(OperationContext* opCtx,
if (isShardingAware || isConfig) {
rpc::ShardingMetadata(lastOpTimeFromClient, replCoord->getElectionId())
.writeToMetadata(metadataBob);
- }
-
- auto validator = LogicalTimeValidator::get(opCtx);
- if (validator) {
-
- auto currentTime =
- validator->signLogicalTime(LogicalClock::get(opCtx)->getClusterTime());
- rpc::LogicalTimeMetadata logicalTimeMetadata(currentTime);
- logicalTimeMetadata.writeToMetadata(metadataBob);
+ if (LogicalTimeValidator::isAuthorizedToAdvanceClock(opCtx)) {
+ // No need to sign logical times for internal clients.
+ SignedLogicalTime currentTime(
+ LogicalClock::get(opCtx)->getClusterTime(), TimeProofService::TimeProof(), 0);
+ rpc::LogicalTimeMetadata logicalTimeMetadata(currentTime);
+ logicalTimeMetadata.writeToMetadata(metadataBob);
+ } else if (auto validator = LogicalTimeValidator::get(opCtx)) {
+ auto currentTime =
+ validator->trySignLogicalTime(LogicalClock::get(opCtx)->getClusterTime());
+ rpc::LogicalTimeMetadata logicalTimeMetadata(currentTime);
+ logicalTimeMetadata.writeToMetadata(metadataBob);
+ }
}
}
diff --git a/src/mongo/rpc/metadata.cpp b/src/mongo/rpc/metadata.cpp
index 80046bdad58..7b35b2c7821 100644
--- a/src/mongo/rpc/metadata.cpp
+++ b/src/mongo/rpc/metadata.cpp
@@ -30,12 +30,7 @@
#include "mongo/rpc/metadata.h"
-#include "mongo/base/init.h"
#include "mongo/client/dbclientinterface.h"
-#include "mongo/db/auth/action_set.h"
-#include "mongo/db/auth/action_type.h"
-#include "mongo/db/auth/authorization_session.h"
-#include "mongo/db/auth/privilege.h"
#include "mongo/db/jsobj.h"
#include "mongo/db/logical_clock.h"
#include "mongo/db/logical_time_validator.h"
@@ -49,27 +44,6 @@
namespace mongo {
namespace rpc {
-namespace {
-
-std::vector<Privilege> advanceLogicalClockPrivilege;
-
-MONGO_INITIALIZER(InitializeAdvanceLogicalClockPrivilegeVector)(InitializerContext* const) {
- ActionSet actions;
- actions.addAction(ActionType::internal);
- advanceLogicalClockPrivilege.emplace_back(ResourcePattern::forClusterResource(), actions);
- return Status::OK();
-}
-
-bool isAuthorizedToAdvanceClock(OperationContext* opCtx) {
- auto client = opCtx->getClient();
- // Note: returns true if auth is off, courtesy of
- // AuthzSessionExternalStateServerCommon::shouldIgnoreAuthChecks.
- return AuthorizationSession::get(client)->isAuthorizedForPrivileges(
- advanceLogicalClockPrivilege);
-}
-
-} // unnamed namespace
-
BSONObj makeEmptyMetadata() {
return BSONObj();
}
@@ -124,16 +98,14 @@ void readRequestMetadata(OperationContext* opCtx, const BSONObj& metadataObj) {
// default constructed SignedLogicalTime should be ignored.
if (signedTime.getTime() != LogicalTime::kUninitialized) {
auto logicalTimeValidator = LogicalTimeValidator::get(opCtx);
- if (isAuthorizedToAdvanceClock(opCtx)) {
- if (logicalTimeValidator) {
- logicalTimeValidator->updateCacheTrustedSource(signedTime);
+ if (!LogicalTimeValidator::isAuthorizedToAdvanceClock(opCtx)) {
+ if (!logicalTimeValidator) {
+ uasserted(ErrorCodes::CannotVerifyAndSignLogicalTime,
+ "Cannot accept logicalTime: " + signedTime.getTime().toString() +
+ ". May not be a part of a sharded cluster");
+ } else {
+ uassertStatusOK(logicalTimeValidator->validate(opCtx, signedTime));
}
- } else if (!logicalTimeValidator) {
- uasserted(ErrorCodes::CannotVerifyAndSignLogicalTime,
- "Cannot accept logicalTime: " + signedTime.getTime().toString() +
- ". May not be a part of a sharded cluster");
- } else {
- uassertStatusOK(logicalTimeValidator->validate(signedTime));
}
uassertStatusOK(logicalClock->advanceClusterTime(signedTime.getTime()));
diff --git a/src/mongo/s/commands/strategy.cpp b/src/mongo/s/commands/strategy.cpp
index c91cd861b78..1748d6380f0 100644
--- a/src/mongo/s/commands/strategy.cpp
+++ b/src/mongo/s/commands/strategy.cpp
@@ -112,13 +112,11 @@ Status processCommandMetadata(OperationContext* opCtx, const BSONObj& cmdObj) {
}
if (authSession->getAuthorizationManager().isAuthEnabled()) {
- auto advanceClockStatus = logicalTimeValidator->validate(signedTime);
+ auto advanceClockStatus = logicalTimeValidator->validate(opCtx, signedTime);
if (!advanceClockStatus.isOK()) {
return advanceClockStatus;
}
- } else {
- logicalTimeValidator->updateCacheTrustedSource(signedTime);
}
return logicalClock->advanceClusterTime(signedTime.getTime());
@@ -129,7 +127,8 @@ Status processCommandMetadata(OperationContext* opCtx, const BSONObj& cmdObj) {
*/
void appendRequiredFieldsToResponse(OperationContext* opCtx, BSONObjBuilder* responseBuilder) {
auto validator = LogicalTimeValidator::get(opCtx);
- auto currentTime = validator->signLogicalTime(LogicalClock::get(opCtx)->getClusterTime());
+ auto currentTime =
+ validator->signLogicalTime(opCtx, LogicalClock::get(opCtx)->getClusterTime());
rpc::LogicalTimeMetadata logicalTimeMetadata(currentTime);
logicalTimeMetadata.writeToMetadata(responseBuilder);
auto tracker = OperationTimeTracker::get(opCtx);
diff --git a/src/mongo/s/server.cpp b/src/mongo/s/server.cpp
index 7e4c84bd01f..791dbae5e91 100644
--- a/src/mongo/s/server.cpp
+++ b/src/mongo/s/server.cpp
@@ -55,6 +55,7 @@
#include "mongo/db/log_process_details.h"
#include "mongo/db/logical_clock.h"
#include "mongo/db/logical_time_metadata_hook.h"
+#include "mongo/db/logical_time_validator.h"
#include "mongo/db/operation_context.h"
#include "mongo/db/server_options.h"
#include "mongo/db/service_context.h"
@@ -147,6 +148,12 @@ static void cleanupTask() {
if (serviceContext)
serviceContext->setKillAllOperations();
+ // Validator shutdown must be called after setKillAllOperations is called. Otherwise, this
+ // can deadlock.
+ if (auto validator = LogicalTimeValidator::get(serviceContext)) {
+ validator->shutDown();
+ }
+
if (auto cursorManager = Grid::get(opCtx)->getCursorManager()) {
cursorManager->shutdown();
}
diff --git a/src/mongo/s/sharding_initialization.cpp b/src/mongo/s/sharding_initialization.cpp
index 570440f817b..68b5538aef8 100644
--- a/src/mongo/s/sharding_initialization.cpp
+++ b/src/mongo/s/sharding_initialization.cpp
@@ -37,8 +37,10 @@
#include "mongo/base/status.h"
#include "mongo/client/remote_command_targeter_factory_impl.h"
#include "mongo/db/audit.h"
+#include "mongo/db/keys_collection_manager.h"
#include "mongo/db/logical_clock.h"
#include "mongo/db/logical_time_validator.h"
+#include "mongo/db/repl/replication_coordinator.h"
#include "mongo/db/s/sharding_task_executor.h"
#include "mongo/db/server_options.h"
#include "mongo/db/server_parameters.h"
@@ -96,6 +98,8 @@ using executor::ThreadPoolTaskExecutor;
using executor::ShardingTaskExecutor;
static constexpr auto kRetryInterval = Seconds{2};
+const std::string kKeyManagerPurposeString = "SigningClusterTime";
+const Seconds kKeyValidInterval(3 * 30 * 24 * 60 * 60); // ~3 months
auto makeTaskExecutor(std::unique_ptr<NetworkInterface> net) {
auto netPtr = net.get();
@@ -200,7 +204,8 @@ Status initializeGlobalShardingState(OperationContext* opCtx,
makeTaskExecutor(executor::makeNetworkInterface("AddShard-TaskExecutor")));
auto rawCatalogManager = catalogManager.get();
- grid.init(
+ auto grid = Grid::get(opCtx);
+ grid->init(
std::move(catalogClient),
std::move(catalogManager),
std::move(catalogCache),
@@ -211,7 +216,7 @@ Status initializeGlobalShardingState(OperationContext* opCtx,
networkPtr);
// must be started once the grid is initialized
- grid.shardRegistry()->startup(opCtx);
+ grid->shardRegistry()->startup(opCtx);
auto status = rawCatalogClient->startup();
if (!status.isOK()) {
@@ -226,9 +231,18 @@ Status initializeGlobalShardingState(OperationContext* opCtx,
}
}
+ auto keyManager = stdx::make_unique<KeysCollectionManager>(
+ kKeyManagerPurposeString, grid->catalogClient(opCtx), kKeyValidInterval);
+ keyManager->startMonitoring(opCtx->getServiceContext());
+
LogicalTimeValidator::set(opCtx->getServiceContext(),
- stdx::make_unique<LogicalTimeValidator>());
+ stdx::make_unique<LogicalTimeValidator>(std::move(keyManager)));
+ auto replCoord = repl::ReplicationCoordinator::get(opCtx->getClient()->getServiceContext());
+ if (serverGlobalParams.clusterRole == ClusterRole::ConfigServer &&
+ replCoord->getMemberState().primary()) {
+ LogicalTimeValidator::get(opCtx)->enableKeyGenerator(opCtx, true);
+ }
return Status::OK();
}