/**
* Copyright (C) 2017 MongoDB Inc.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License, version 3,
* as published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*
* As a special exception, the copyright holders give permission to link the
* code of portions of this program with the OpenSSL library under certain
* conditions as described in each individual source file and distribute
* linked combinations including the program with the OpenSSL library. You
* must comply with the GNU Affero General Public License in all respects for
* all of the code used other than as permitted herein. If you modify file(s)
* with this exception, you may extend this exception to your version of the
* file(s), but you are not obligated to do so. If you do not wish to do so,
* delete this exception statement from your version. If you delete this
* exception statement from all source files in the program, then also delete
* it in the license file.
*/
#include "mongo/platform/basic.h"
#include "mongo/db/keys_collection_manager.h"
#include "mongo/db/keys_collection_cache_reader.h"
#include "mongo/db/keys_collection_cache_reader_and_updater.h"
#include "mongo/db/logical_clock.h"
#include "mongo/db/logical_time.h"
#include "mongo/db/operation_context.h"
#include "mongo/db/service_context.h"
#include "mongo/stdx/memory.h"
#include "mongo/util/concurrency/idle_thread_block.h"
#include "mongo/util/mongoutils/str.h"
#include "mongo/util/time_support.h"
namespace mongo {
namespace {
// Longest time refreshNow() waits for the background thread to satisfy a refresh request.
const Milliseconds kDefaultRefreshWaitTime(30 * 1000);

// Short retry interval used whenever a refresh round failed or did not yield a key that is usable
// for the current time.
const Milliseconds kRefreshIntervalIfErrored(200);
/**
 * Returns the amount of time to wait until the monitoring thread should attempt to refresh again.
 *
 * Both times are compared at whole-second granularity:
 *  - latestExpiredAt is not later than currentTime: kRefreshIntervalIfErrored.
 *  - the latest key remains valid for at least 'interval': interval.
 *  - otherwise (the key expires before a full interval elapses): kRefreshIntervalIfErrored.
 */
Milliseconds howMuchSleepNeedFor(const LogicalTime& currentTime,
                                 const LogicalTime& latestExpiredAt,
                                 const Milliseconds& interval) {
    const auto currentSecs = currentTime.asTimestamp().getSecs();
    const auto expiredSecs = latestExpiredAt.asTimestamp().getSecs();

    if (currentSecs >= expiredSecs) {
        // This means that the last round didn't generate a usable key for the current time.
        // However, we don't want to poll too hard as well, so use a low interval.
        return kRefreshIntervalIfErrored;
    }

    // Widen to 64 bits *before* scaling to milliseconds. The multiplication used to be carried out
    // in the type returned by getSecs() (presumably a 32-bit unsigned - confirm against
    // Timestamp), which wraps around once the key is more than ~49.7 days away from expiring and
    // made the comparison below pick the short retry interval by mistake.
    // The subtraction itself cannot underflow: expiredSecs > currentSecs was established above.
    const long long millisBeforeExpire = 1000LL * static_cast<long long>(expiredSecs - currentSecs);

    if (interval.count() <= millisBeforeExpire) {
        return interval;
    }

    return kRefreshIntervalIfErrored;
}
} // unnamed namespace
/**
 * Creates a manager for the keys identified by 'purpose'.
 *
 * 'keyValidForInterval' is stored as is; within this file it is later handed to the refresher as
 * its refresh interval (startMonitoring) and to the key generator (enableKeyGenerator).
 *
 * The cache is built from the _purpose member rather than from the moved-from 'purpose' parameter.
 * NOTE(review): this relies on _purpose being declared before _keysCache in the class so that it
 * is initialized first - confirm against the header.
 */
KeysCollectionManager::KeysCollectionManager(std::string purpose, Seconds keyValidForInterval)
: _purpose(std::move(purpose)),
_keyValidForInterval(keyValidForInterval),
_keysCache(_purpose) {}
/**
 * Returns the key with id 'keyId' that is valid at 'forThisTime'.
 *
 * The cache is consulted first. Only when that lookup reports KeyNotFound does this block on a
 * single refresh and look the key up once more; every other result (success or a different error)
 * is returned as is.
 *
 * The refresh may throw DBException (node shutting down, or the wait for the refresh timed out).
 */
StatusWith<KeysCollectionDocument> KeysCollectionManager::getKeyForValidation(
    OperationContext* opCtx, long long keyId, const LogicalTime& forThisTime) {
    auto cachedLookup = _getKeyWithKeyIdCheck(keyId, forThisTime);

    if (cachedLookup != ErrorCodes::KeyNotFound) {
        return cachedLookup;
    }

    // The key may simply not have been loaded yet: refresh once, then retry exactly once.
    _refresher.refreshNow(opCtx);

    return _getKeyWithKeyIdCheck(keyId, forThisTime);
}
/**
 * Returns a key that is valid for signing at 'forThisTime'.
 *
 * As long as the lookup reports KeyNotFound this keeps forcing a refresh and retrying, pausing
 * for kRefreshIntervalIfErrored after every unsuccessful attempt. Any other lookup result
 * (success or a different error) ends the loop and is returned.
 *
 * The only other way out of the loop is an exception from refreshNow(), which throws DBException
 * when the node is shutting down or when the wait for the refresh times out.
 */
StatusWith<KeysCollectionDocument> KeysCollectionManager::getKeyForSigning(
    OperationContext* opCtx, const LogicalTime& forThisTime) {
    auto keyStatusWith = _getKey(forThisTime);

    while (keyStatusWith.getStatus() == ErrorCodes::KeyNotFound) {
        _refresher.refreshNow(opCtx);

        keyStatusWith = _getKey(forThisTime);

        if (keyStatusWith.getStatus() == ErrorCodes::KeyNotFound) {
            // Still nothing usable: back off briefly before asking for another refresh.
            sleepFor(kRefreshIntervalIfErrored);
        }
    }

    return keyStatusWith;
}
/**
 * Looks up the key that is valid at 'forThisTime' and additionally requires its id to be 'keyId'.
 *
 * Returns the error from _getKey() unchanged when the lookup fails, and KeyNotFound when a valid
 * key exists for that time but carries a different id.
 */
StatusWith<KeysCollectionDocument> KeysCollectionManager::_getKeyWithKeyIdCheck(
    long long keyId, const LogicalTime& forThisTime) {
    auto keyStatus = _getKey(forThisTime);

    if (!keyStatus.isOK()) {
        return keyStatus;
    }

    // Inspect the document in place; there is no need to copy it just to compare ids.
    const auto& key = keyStatus.getValue();
    if (keyId == key.getKeyId()) {
        return keyStatus;
    }

    // Key not expired but keyId does not match!
    return {ErrorCodes::KeyNotFound,
            str::stream() << "No keys found for " << _purpose << " that is valid for time: "
                          << forThisTime.toString()
                          << " with id: "
                          << keyId};
}
/**
 * Fetches from the cache the key to use at 'forThisTime'.
 *
 * A cache error is passed through unchanged. A key whose expiration is earlier than 'forThisTime'
 * is rejected with KeyNotFound.
 */
StatusWith<KeysCollectionDocument> KeysCollectionManager::_getKey(const LogicalTime& forThisTime) {
    auto keyStatus = _keysCache.getKey(forThisTime);

    if (!keyStatus.isOK()) {
        return keyStatus;
    }

    const auto& key = keyStatus.getValue();
    if (key.getExpiresAt() < forThisTime) {
        return {ErrorCodes::KeyNotFound,
                str::stream() << "No keys found for " << _purpose << " that is valid for "
                              << forThisTime.toString()};
    }

    // Hand back the result we already hold instead of copying the document into a new one.
    return keyStatus;
}
/**
 * Installs the cache-only refresh strategy and launches the background refresher thread, which
 * uses _keyValidForInterval as its regular refresh interval.
 */
void KeysCollectionManager::startMonitoring(ServiceContext* service) {
    // Initially the refresher merely reloads the cache; enableKeyGenerator() can swap this later.
    auto reloadCache = [this](OperationContext* opCtx) { return _keysCache.refresh(opCtx); };
    _refresher.setFunc(reloadCache);

    const std::string threadName = str::stream() << "monitoring keys for " << _purpose;
    _refresher.start(service, threadName, _keyValidForInterval);
}
/**
 * Stops the background refresher by delegating to PeriodicRunner::stop(). As defined in this file,
 * that flags shutdown, wakes the refresher thread and joins it, and returns immediately when the
 * thread is not joinable.
 */
void KeysCollectionManager::stopMonitoring() {
_refresher.stop();
}
/**
 * Selects what the background refresher does on every round and waits for one refresh to finish
 * with the new strategy (see PeriodicRunner::switchFunc).
 *
 * doEnable == true:  run the key generator, then reload the cache.
 * doEnable == false: only reload the cache.
 */
void KeysCollectionManager::enableKeyGenerator(OperationContext* opCtx, bool doEnable) {
    if (!doEnable) {
        _refresher.switchFunc(opCtx, [this](OperationContext* refreshOpCtx) {
            return _keysCache.refresh(refreshOpCtx);
        });
        return;
    }

    _refresher.switchFunc(opCtx, [this](OperationContext* refreshOpCtx) {
        KeysCollectionCacheReaderAndUpdater generator(_purpose, _keyValidForInterval);
        auto generatedStatus = generator.refresh(refreshOpCtx);

        // On shutdown there is no point in touching the cache any more.
        const bool shuttingDown = ErrorCodes::isShutdownError(generatedStatus.getStatus().code());
        if (shuttingDown) {
            return generatedStatus;
        }

        // The cache is reloaded even when key generation failed; the generator's error, if any,
        // is what gets reported in that case.
        auto cachedStatus = _keysCache.refresh(refreshOpCtx);

        if (generatedStatus.isOK()) {
            return cachedStatus;
        }

        return generatedStatus;
    });
}
/**
 * Asks the background thread to refresh and blocks until that refresh has completed.
 *
 * Concurrent callers share one pending request: if one already exists it is reused, otherwise a
 * new one is registered and the background thread is woken up.
 *
 * Throws DBException with ShutdownInProgress when the runner is shutting down, and with
 * ExceededTimeLimit when the request was not satisfied within kDefaultRefreshWaitTime.
 */
void KeysCollectionManager::PeriodicRunner::refreshNow(OperationContext* opCtx) {
    auto refreshRequest = [this]() {
        stdx::lock_guard<stdx::mutex> lk(_mutex);

        if (_inShutdown) {
            throw DBException("aborting keys cache refresh because node is shutting down",
                              ErrorCodes::ShutdownInProgress);
        }

        // Somebody else already asked for a refresh: wait on the same notification.
        if (_refreshRequest) {
            return _refreshRequest;
        }

        _refreshNeededCV.notify_all();
        _refreshRequest = std::make_shared<Notification<void>>();
        return _refreshRequest;
    }();

    // note: waitFor waits min(maxTimeMS, kDefaultRefreshWaitTime).
    // waitFor also throws if timeout, so also throw when notification was not satisfied after
    // waiting.
    if (!refreshRequest->waitFor(opCtx, kDefaultRefreshWaitTime)) {
        throw DBException("timed out waiting for refresh", ErrorCodes::ExceededTimeLimit);
    }
}
/**
 * Body of the background refresher thread.
 *
 * Each round: snapshot the current refresh strategy under the mutex, run it without holding the
 * mutex, work out how long to sleep, release any refreshNow() callers waiting on the pending
 * request, then sleep until the deadline or until _refreshNeededCV is notified.
 *
 * The loop ends once _inShutdown is observed or the sleep is interrupted with a shutdown error.
 * Before returning, a still-pending request is signalled so that no waiter is left behind.
 */
void KeysCollectionManager::PeriodicRunner::_doPeriodicRefresh(ServiceContext* service,
                                                               std::string threadName,
                                                               Milliseconds refreshInterval) {
    Client::initThreadIfNotAlready(threadName);

    // Completes and clears the pending refresh request, if there is one.
    // Must only be called while _mutex is held.
    const auto signalPendingRequest = [this] {
        if (_refreshRequest) {
            _refreshRequest->set();
            _refreshRequest.reset();
        }
    };

    while (true) {
        auto opCtx = cc().makeOperationContext();

        // Take a reference to the strategy so that a concurrent setFunc() cannot destroy it while
        // it is running.
        std::shared_ptr<RefreshFunc> doRefresh;
        {
            stdx::lock_guard<stdx::mutex> lock(_mutex);

            if (_inShutdown) {
                break;
            }

            invariant(_doRefresh.get() != nullptr);
            doRefresh = _doRefresh;
        }

        // Unless the refresh succeeds, retry after the short error interval.
        Milliseconds nextWakeup = kRefreshIntervalIfErrored;

        auto latestKeyStatusWith = (*doRefresh)(opCtx.get());
        if (latestKeyStatusWith.getStatus().isOK()) {
            const auto& latestKey = latestKeyStatusWith.getValue();
            auto currentTime = LogicalClock::get(service)->getClusterTime();

            nextWakeup =
                howMuchSleepNeedFor(currentTime, latestKey.getExpiresAt(), refreshInterval);
        }

        // TODO: Add backoff to nextWakeup if it has a very small value in a row to avoid spinning.

        stdx::unique_lock<stdx::mutex> lock(_mutex);

        signalPendingRequest();

        if (_inShutdown) {
            break;
        }

        MONGO_IDLE_THREAD_BLOCK;
        auto sleepStatus = opCtx->waitForConditionOrInterruptNoAssertUntil(
            _refreshNeededCV, lock, Date_t::now() + nextWakeup);

        if (ErrorCodes::isShutdownError(sleepStatus.getStatus().code())) {
            break;
        }
    }

    // A request may have been registered after the last round signalled; release it too.
    stdx::lock_guard<stdx::mutex> lock(_mutex);
    signalPendingRequest();
}
/**
 * Replaces the refresh strategy used by subsequent rounds of the background thread. A round that
 * is already running keeps its own reference to the previous strategy and is not affected.
 */
void KeysCollectionManager::PeriodicRunner::setFunc(RefreshFunc newRefreshStrategy) {
    // Allocate outside the critical section; only the pointer swap needs the mutex.
    auto replacement = std::make_shared<RefreshFunc>(std::move(newRefreshStrategy));

    stdx::lock_guard<stdx::mutex> lock(_mutex);
    _doRefresh = std::move(replacement);
}
/**
 * Installs 'newRefreshStrategy' and then blocks until the background thread has completed a
 * refresh. May throw whatever refreshNow() throws (shutdown in progress, timed out).
 */
void KeysCollectionManager::PeriodicRunner::switchFunc(OperationContext* opCtx,
                                                       RefreshFunc newRefreshStrategy) {
    // The parameter is a sink: hand it over instead of copying the callable a second time.
    setFunc(std::move(newRefreshStrategy));

    // Note: calling refreshNow will ensure that if there is an ongoing method call to the original
    // refreshStrategy, it will be finished after this.
    refreshNow(opCtx);
}
/**
 * Launches the background refresher thread, which runs _doPeriodicRefresh() with the given
 * service context, thread name and refresh interval.
 *
 * Must not be called while a thread is already running, nor after stop() has flagged shutdown;
 * both conditions are enforced with invariants.
 */
void KeysCollectionManager::PeriodicRunner::start(ServiceContext* service,
                                                  const std::string& threadName,
                                                  Milliseconds refreshInterval) {
    stdx::lock_guard<stdx::mutex> lock(_mutex);

    invariant(!_backgroundThread.joinable());
    invariant(!_inShutdown);

    // threadName is captured by value: the caller's string may be gone before the thread uses it.
    _backgroundThread = stdx::thread([this, service, threadName, refreshInterval] {
        _doPeriodicRefresh(service, threadName, refreshInterval);
    });
}
void KeysCollectionManager::PeriodicRunner::stop() {
{
stdx::lock_guard lock(_mutex);
if (!_backgroundThread.joinable()) {
return;
}
_inShutdown = true;
_refreshNeededCV.notify_all();
}
_backgroundThread.join();
}
} // namespace mongo