summaryrefslogtreecommitdiff
path: root/src/mongo/db
diff options
context:
space:
mode:
authorsamantharitter <samantha.ritter@10gen.com>2017-06-27 12:09:40 -0400
committerJason Carey <jcarey@argv.me>2017-07-13 17:40:53 -0400
commite1cae24805e3e7282958ee67a01555dd6ce40039 (patch)
treeebce77d9a502a193784483b2201b65e1a5010d98 /src/mongo/db
parent9a49ee3a03e02597086e577f06a71a0723bc0582 (diff)
downloadmongo-e1cae24805e3e7282958ee67a01555dd6ce40039.tar.gz
SERVER-29610 Allow LogicalSessionIds to contain signed user information
Diffstat (limited to 'src/mongo/db')
-rw-r--r--src/mongo/db/SConscript85
-rw-r--r--src/mongo/db/db.cpp3
-rw-r--r--src/mongo/db/keys_collection_manager.cpp309
-rw-r--r--src/mongo/db/keys_collection_manager.h135
-rw-r--r--src/mongo/db/keys_collection_manager_direct.cpp137
-rw-r--r--src/mongo/db/keys_collection_manager_direct.h73
-rw-r--r--src/mongo/db/keys_collection_manager_sharding.cpp345
-rw-r--r--src/mongo/db/keys_collection_manager_sharding.h197
-rw-r--r--src/mongo/db/keys_collection_manager_sharding_test.cpp (renamed from src/mongo/db/keys_collection_manager_test.cpp)53
-rw-r--r--src/mongo/db/keys_collection_manager_zero.cpp57
-rw-r--r--src/mongo/db/keys_collection_manager_zero.h69
-rw-r--r--src/mongo/db/logical_session_cache.cpp78
-rw-r--r--src/mongo/db/logical_session_cache.h53
-rw-r--r--src/mongo/db/logical_session_cache_test.cpp205
-rw-r--r--src/mongo/db/logical_session_id.h10
-rw-r--r--src/mongo/db/logical_session_id_test.cpp6
-rw-r--r--src/mongo/db/logical_session_record.cpp33
-rw-r--r--src/mongo/db/logical_session_record.h41
-rw-r--r--src/mongo/db/logical_session_record.idl30
-rw-r--r--src/mongo/db/logical_session_record_test.cpp20
-rw-r--r--src/mongo/db/logical_time_validator.cpp13
-rw-r--r--src/mongo/db/logical_time_validator.h10
-rw-r--r--src/mongo/db/logical_time_validator_test.cpp14
-rw-r--r--src/mongo/db/service_context.cpp9
-rw-r--r--src/mongo/db/service_context.h48
-rw-r--r--src/mongo/db/service_liason.cpp71
-rw-r--r--src/mongo/db/service_liason.h24
-rw-r--r--src/mongo/db/service_liason_mock.h8
-rw-r--r--src/mongo/db/service_liason_mongod.cpp4
-rw-r--r--src/mongo/db/service_liason_mongod.h6
-rw-r--r--src/mongo/db/sessions_collection.h5
-rw-r--r--src/mongo/db/sessions_collection_mock.cpp14
-rw-r--r--src/mongo/db/sessions_collection_mock.h10
-rw-r--r--src/mongo/db/signed_logical_session_id.cpp73
-rw-r--r--src/mongo/db/signed_logical_session_id.h115
-rw-r--r--src/mongo/db/signed_logical_session_id.idl50
-rw-r--r--src/mongo/db/signed_logical_session_id_test.cpp81
37 files changed, 1707 insertions, 787 deletions
diff --git a/src/mongo/db/SConscript b/src/mongo/db/SConscript
index eba59679fc9..ae96b374c7f 100644
--- a/src/mongo/db/SConscript
+++ b/src/mongo/db/SConscript
@@ -448,6 +448,7 @@ env.Library(
'operation_context_group.cpp'
],
LIBDEPS=[
+ '$BUILD_DIR/mongo/db/logical_session_id',
'$BUILD_DIR/mongo/util/clock_sources',
'$BUILD_DIR/mongo/util/concurrency/spin_lock',
'$BUILD_DIR/mongo/util/decorable',
@@ -455,7 +456,6 @@ env.Library(
'$BUILD_DIR/mongo/util/net/network',
'$BUILD_DIR/mongo/util/periodic_runner',
'$BUILD_DIR/mongo/transport/transport_layer_common',
- 'logical_session_cache',
],
)
@@ -894,6 +894,32 @@ env.CppUnitTest(
)
env.Library(
+ target='signed_logical_session_id',
+ source=[
+ 'signed_logical_session_id.cpp',
+ env.Idlc('signed_logical_session_id.idl')[0],
+ ],
+ LIBDEPS=[
+ '$BUILD_DIR/mongo/base',
+ '$BUILD_DIR/mongo/crypto/sha1_block',
+ '$BUILD_DIR/mongo/idl/idl_parser',
+ '$BUILD_DIR/mongo/util/uuid',
+ 'logical_session_id'
+ ],
+)
+
+env.CppUnitTest(
+ target='signed_logical_session_id_test',
+ source=[
+ 'signed_logical_session_id_test.cpp',
+ ],
+ LIBDEPS=[
+ '$BUILD_DIR/mongo/base',
+ 'signed_logical_session_id',
+ ],
+)
+
+env.Library(
target='logical_session_record',
source=[
'logical_session_record.cpp',
@@ -903,7 +929,8 @@ env.Library(
'$BUILD_DIR/mongo/base',
'$BUILD_DIR/mongo/db/auth/user_name',
'$BUILD_DIR/mongo/idl/idl_parser',
- 'logical_session_id'
+ 'logical_session_id',
+ 'signed_logical_session_id',
],
)
@@ -925,6 +952,12 @@ env.Library(
'service_liason.cpp',
],
LIBDEPS=[
+ '$BUILD_DIR/mongo/crypto/sha1_block',
+ 'keys_collection_document',
+ 'keys_collection_manager',
+ 'keys_collection_manager_zero',
+ 'logical_clock',
+ 'signed_logical_session_id',
],
)
@@ -985,6 +1018,7 @@ env.Library(
'logical_session_cache.cpp',
],
LIBDEPS=[
+ 'logical_session_id',
'logical_session_record',
'sessions_collection',
'server_parameters',
@@ -1000,9 +1034,13 @@ envWithAsio.CppUnitTest(
LIBDEPS=[
'$BUILD_DIR/mongo/db/service_context_noop_init',
'$BUILD_DIR/mongo/executor/async_timer_mock',
+ 'keys_collection_manager',
+ 'keys_collection_document',
+ 'logical_clock',
'logical_session_cache',
'service_liason_mock',
'sessions_collection_mock',
+ 'signed_logical_session_id',
],
)
@@ -1079,12 +1117,46 @@ env.Library(
target='keys_collection_manager',
source=[
'keys_collection_manager.cpp',
+ ],
+ LIBDEPS=[
+ ],
+)
+
+env.Library(
+ target='keys_collection_manager_direct',
+ source=[
+ 'keys_collection_manager_direct.cpp',
+ ],
+ LIBDEPS=[
+ 'dbdirectclient',
+ 'keys_collection_manager',
+ 'logical_clock',
+ 'logical_time',
+ ],
+)
+
+env.Library(
+ target='keys_collection_manager_zero',
+ source=[
+ 'keys_collection_manager_zero.cpp',
+ ],
+ LIBDEPS=[
+ 'keys_collection_manager',
+ 'logical_time',
+ ],
+)
+
+env.Library(
+ target='keys_collection_manager_sharding',
+ source=[
+ 'keys_collection_manager_sharding.cpp',
'keys_collection_cache_reader.cpp',
'keys_collection_cache_reader_and_updater.cpp',
],
LIBDEPS=[
'logical_clock',
'keys_collection_document',
+ 'keys_collection_manager',
'logical_time',
'server_options',
'$BUILD_DIR/mongo/s/catalog/sharding_catalog_client',
@@ -1109,7 +1181,7 @@ env.Library(
'logical_time_validator.cpp',
],
LIBDEPS=[
- 'keys_collection_manager',
+ 'keys_collection_manager_sharding',
'service_context',
'signed_logical_time',
'time_proof_service',
@@ -1167,6 +1239,7 @@ env.CppUnitTest(
'logical_time_validator_test.cpp',
],
LIBDEPS=[
+ 'keys_collection_manager_sharding',
'logical_time_validator',
'$BUILD_DIR/mongo/s/config_server_test_fixture',
'$BUILD_DIR/mongo/s/coreshard',
@@ -1233,12 +1306,12 @@ env.CppUnitTest(
)
env.CppUnitTest(
- target='keys_collection_manager_test',
+ target='keys_collection_manager_sharding_test',
source=[
- 'keys_collection_manager_test.cpp',
+ 'keys_collection_manager_sharding_test.cpp',
],
LIBDEPS=[
- 'keys_collection_manager',
+ 'keys_collection_manager_sharding',
'$BUILD_DIR/mongo/s/config_server_test_fixture',
],
)
diff --git a/src/mongo/db/db.cpp b/src/mongo/db/db.cpp
index 64446785910..569958ccbad 100644
--- a/src/mongo/db/db.cpp
+++ b/src/mongo/db/db.cpp
@@ -73,6 +73,7 @@
#include "mongo/db/initialize_snmp.h"
#include "mongo/db/introspect.h"
#include "mongo/db/json.h"
+#include "mongo/db/keys_collection_manager.h"
#include "mongo/db/log_process_details.h"
#include "mongo/db/logical_clock.h"
#include "mongo/db/logical_session_cache.h"
@@ -724,7 +725,7 @@ ExitCode _initAndListen(int listenPort) {
}
auto sessionCache = makeLogicalSessionCacheD(kind);
- globalServiceContext->setLogicalSessionCache(std::move(sessionCache));
+ LogicalSessionCache::set(globalServiceContext, std::move(sessionCache));
// MessageServer::run will return when exit code closes its socket and we don't need the
// operation context anymore
diff --git a/src/mongo/db/keys_collection_manager.cpp b/src/mongo/db/keys_collection_manager.cpp
index e91d9af8dad..bed56f87847 100644
--- a/src/mongo/db/keys_collection_manager.cpp
+++ b/src/mongo/db/keys_collection_manager.cpp
@@ -30,317 +30,10 @@
#include "mongo/db/keys_collection_manager.h"
-#include "mongo/db/keys_collection_cache_reader.h"
-#include "mongo/db/keys_collection_cache_reader_and_updater.h"
-#include "mongo/db/logical_clock.h"
-#include "mongo/db/logical_time.h"
-#include "mongo/db/operation_context.h"
-#include "mongo/db/server_options.h"
-#include "mongo/db/service_context.h"
-#include "mongo/stdx/memory.h"
-#include "mongo/util/concurrency/idle_thread_block.h"
-#include "mongo/util/fail_point_service.h"
-#include "mongo/util/mongoutils/str.h"
-#include "mongo/util/time_support.h"
-
namespace mongo {
const Seconds KeysCollectionManager::kKeyValidInterval{3 * 30 * 24 * 60 * 60}; // ~3 months
-namespace {
-
-Milliseconds kDefaultRefreshWaitTime(30 * 1000);
-Milliseconds kRefreshIntervalIfErrored(200);
-Milliseconds kMaxRefreshWaitTime(10 * 60 * 1000);
-
-// Prevents the refresher thread from waiting longer than the given number of milliseconds, even on
-// a successful refresh.
-MONGO_FP_DECLARE(maxKeyRefreshWaitTimeOverrideMS);
-
-/**
- * Returns the amount of time to wait until the monitoring thread should attempt to refresh again.
- */
-Milliseconds howMuchSleepNeedFor(const LogicalTime& currentTime,
- const LogicalTime& latestExpiredAt,
- const Milliseconds& interval) {
- auto currentSecs = currentTime.asTimestamp().getSecs();
- auto expiredSecs = latestExpiredAt.asTimestamp().getSecs();
-
- if (currentSecs >= expiredSecs) {
- // This means that the last round didn't generate a usable key for the current time.
- // However, we don't want to poll too hard as well, so use a low interval.
- return kRefreshIntervalIfErrored;
- }
-
- auto millisBeforeExpire = 1000 * (expiredSecs - currentSecs);
-
- if (interval.count() <= millisBeforeExpire) {
- return interval;
- }
-
- return Milliseconds(millisBeforeExpire);
-}
-
-} // unnamed namespace
-
-KeysCollectionManager::KeysCollectionManager(std::string purpose,
- ShardingCatalogClient* client,
- Seconds keyValidForInterval)
- : _purpose(std::move(purpose)),
- _keyValidForInterval(keyValidForInterval),
- _catalogClient(client),
- _keysCache(_purpose, client) {}
-
-StatusWith<KeysCollectionDocument> KeysCollectionManager::getKeyForValidation(
- OperationContext* opCtx, long long keyId, const LogicalTime& forThisTime) {
- auto keyStatus = _getKeyWithKeyIdCheck(keyId, forThisTime);
-
- if (keyStatus != ErrorCodes::KeyNotFound) {
- return keyStatus;
- }
-
- _refresher.refreshNow(opCtx);
-
- return _getKeyWithKeyIdCheck(keyId, forThisTime);
-}
-
-StatusWith<KeysCollectionDocument> KeysCollectionManager::getKeyForSigning(
- const LogicalTime& forThisTime) {
- return _getKey(forThisTime);
-}
-
-StatusWith<KeysCollectionDocument> KeysCollectionManager::_getKeyWithKeyIdCheck(
- long long keyId, const LogicalTime& forThisTime) {
- auto keyStatus = _keysCache.getKeyById(keyId, forThisTime);
-
- if (!keyStatus.isOK()) {
- return keyStatus;
- }
-
- return keyStatus.getValue();
-}
-
-StatusWith<KeysCollectionDocument> KeysCollectionManager::_getKey(const LogicalTime& forThisTime) {
- auto keyStatus = _keysCache.getKey(forThisTime);
-
- if (!keyStatus.isOK()) {
- return keyStatus;
- }
-
- const auto& key = keyStatus.getValue();
-
- if (key.getExpiresAt() < forThisTime) {
- return {ErrorCodes::KeyNotFound,
- str::stream() << "No keys found for " << _purpose << " that is valid for "
- << forThisTime.toString()};
- }
-
- return key;
-}
-
-void KeysCollectionManager::refreshNow(OperationContext* opCtx) {
- _refresher.refreshNow(opCtx);
-}
-
-void KeysCollectionManager::startMonitoring(ServiceContext* service) {
- _refresher.setFunc([this](OperationContext* opCtx) { return _keysCache.refresh(opCtx); });
- _refresher.start(
- service, str::stream() << "monitoring keys for " << _purpose, _keyValidForInterval);
-}
-
-void KeysCollectionManager::stopMonitoring() {
- _refresher.stop();
-}
-
-void KeysCollectionManager::enableKeyGenerator(OperationContext* opCtx, bool doEnable) {
- if (doEnable) {
- _refresher.switchFunc(opCtx, [this](OperationContext* opCtx) {
- KeysCollectionCacheReaderAndUpdater keyGenerator(
- _purpose, _catalogClient, _keyValidForInterval);
- auto keyGenerationStatus = keyGenerator.refresh(opCtx);
-
- if (ErrorCodes::isShutdownError(keyGenerationStatus.getStatus().code())) {
- return keyGenerationStatus;
- }
-
- // An error encountered by the keyGenerator should not prevent refreshing the cache
- auto cacheRefreshStatus = _keysCache.refresh(opCtx);
-
- if (!keyGenerationStatus.isOK()) {
- return keyGenerationStatus;
- }
-
- return cacheRefreshStatus;
- });
- } else {
- _refresher.switchFunc(
- opCtx, [this](OperationContext* opCtx) { return _keysCache.refresh(opCtx); });
- }
-}
-
-bool KeysCollectionManager::hasSeenKeys() {
- return _refresher.hasSeenKeys();
-}
-
-void KeysCollectionManager::PeriodicRunner::refreshNow(OperationContext* opCtx) {
- auto refreshRequest = [this]() {
- stdx::lock_guard<stdx::mutex> lk(_mutex);
-
- if (_inShutdown) {
- throw DBException("aborting keys cache refresh because node is shutting down",
- ErrorCodes::ShutdownInProgress);
- }
-
- if (_refreshRequest) {
- return _refreshRequest;
- }
-
- _refreshNeededCV.notify_all();
- _refreshRequest = std::make_shared<Notification<void>>();
- return _refreshRequest;
- }();
-
- // note: waitFor waits min(maxTimeMS, kDefaultRefreshWaitTime).
- // waitFor also throws if timeout, so also throw when notification was not satisfied after
- // waiting.
- if (!refreshRequest->waitFor(opCtx, kDefaultRefreshWaitTime)) {
- throw DBException("timed out waiting for refresh", ErrorCodes::ExceededTimeLimit);
- }
-}
-
-void KeysCollectionManager::PeriodicRunner::_doPeriodicRefresh(ServiceContext* service,
- std::string threadName,
- Milliseconds refreshInterval) {
- Client::initThreadIfNotAlready(threadName);
-
- while (true) {
- auto opCtx = cc().makeOperationContext();
-
- bool hasRefreshRequestInitially = false;
- unsigned errorCount = 0;
- std::shared_ptr<RefreshFunc> doRefresh;
- {
- stdx::lock_guard<stdx::mutex> lock(_mutex);
-
- if (_inShutdown) {
- break;
- }
-
- invariant(_doRefresh.get() != nullptr);
- doRefresh = _doRefresh;
- hasRefreshRequestInitially = _refreshRequest.get() != nullptr;
- }
-
- Milliseconds nextWakeup = kRefreshIntervalIfErrored;
-
- // No need to refresh keys in FCV 3.4, since key generation will be disabled.
- if (serverGlobalParams.featureCompatibility.version.load() !=
- ServerGlobalParams::FeatureCompatibility::Version::k34) {
- auto latestKeyStatusWith = (*doRefresh)(opCtx.get());
- if (latestKeyStatusWith.getStatus().isOK()) {
- errorCount = 0;
- const auto& latestKey = latestKeyStatusWith.getValue();
- auto currentTime = LogicalClock::get(service)->getClusterTime();
-
- {
- stdx::unique_lock<stdx::mutex> lock(_mutex);
- _hasSeenKeys = true;
- }
-
- nextWakeup =
- howMuchSleepNeedFor(currentTime, latestKey.getExpiresAt(), refreshInterval);
- } else {
- errorCount += 1;
- nextWakeup = Milliseconds(kRefreshIntervalIfErrored.count() * errorCount);
- if (nextWakeup > kMaxRefreshWaitTime) {
- nextWakeup = kMaxRefreshWaitTime;
- }
- }
- } else {
- nextWakeup = kDefaultRefreshWaitTime;
- }
-
- MONGO_FAIL_POINT_BLOCK(maxKeyRefreshWaitTimeOverrideMS, data) {
- const BSONObj& dataObj = data.getData();
- auto overrideMS = Milliseconds(dataObj["overrideMS"].numberInt());
- if (nextWakeup > overrideMS) {
- nextWakeup = overrideMS;
- }
- }
-
- stdx::unique_lock<stdx::mutex> lock(_mutex);
-
- if (_refreshRequest) {
- if (!hasRefreshRequestInitially) {
- // A fresh request came in, fulfill the request before going to sleep.
- continue;
- }
-
- _refreshRequest->set();
- _refreshRequest.reset();
- }
-
- if (_inShutdown) {
- break;
- }
-
- MONGO_IDLE_THREAD_BLOCK;
- auto sleepStatus = opCtx->waitForConditionOrInterruptNoAssertUntil(
- _refreshNeededCV, lock, Date_t::now() + nextWakeup);
-
- if (ErrorCodes::isShutdownError(sleepStatus.getStatus().code())) {
- break;
- }
- }
-
- stdx::unique_lock<stdx::mutex> lock(_mutex);
- if (_refreshRequest) {
- _refreshRequest->set();
- _refreshRequest.reset();
- }
-}
-
-void KeysCollectionManager::PeriodicRunner::setFunc(RefreshFunc newRefreshStrategy) {
- stdx::lock_guard<stdx::mutex> lock(_mutex);
- _doRefresh = std::make_shared<RefreshFunc>(std::move(newRefreshStrategy));
-}
-
-void KeysCollectionManager::PeriodicRunner::switchFunc(OperationContext* opCtx,
- RefreshFunc newRefreshStrategy) {
- setFunc(newRefreshStrategy);
-}
-
-void KeysCollectionManager::PeriodicRunner::start(ServiceContext* service,
- const std::string& threadName,
- Milliseconds refreshInterval) {
- stdx::lock_guard<stdx::mutex> lock(_mutex);
- invariant(!_backgroundThread.joinable());
- invariant(!_inShutdown);
-
- _backgroundThread =
- stdx::thread(stdx::bind(&KeysCollectionManager::PeriodicRunner::_doPeriodicRefresh,
- this,
- service,
- threadName,
- refreshInterval));
-}
-
-void KeysCollectionManager::PeriodicRunner::stop() {
- {
- stdx::lock_guard<stdx::mutex> lock(_mutex);
- if (!_backgroundThread.joinable()) {
- return;
- }
-
- _inShutdown = true;
- _refreshNeededCV.notify_all();
- }
-
- _backgroundThread.join();
-}
-bool KeysCollectionManager::PeriodicRunner::hasSeenKeys() {
- stdx::lock_guard<stdx::mutex> lock(_mutex);
- return _hasSeenKeys;
-}
+KeysCollectionManager::~KeysCollectionManager() = default;
} // namespace mongo
diff --git a/src/mongo/db/keys_collection_manager.h b/src/mongo/db/keys_collection_manager.h
index 53de6257b32..87b57ef1250 100644
--- a/src/mongo/db/keys_collection_manager.h
+++ b/src/mongo/db/keys_collection_manager.h
@@ -31,21 +31,13 @@
#include <memory>
#include "mongo/base/status_with.h"
-#include "mongo/db/keys_collection_cache_reader.h"
-#include "mongo/db/keys_collection_cache_reader_and_updater.h"
#include "mongo/db/keys_collection_document.h"
-#include "mongo/stdx/functional.h"
-#include "mongo/stdx/mutex.h"
-#include "mongo/stdx/thread.h"
-#include "mongo/util/concurrency/notification.h"
-#include "mongo/util/duration.h"
namespace mongo {
class OperationContext;
class LogicalTime;
class ServiceContext;
-class ShardingCatalogClient;
/**
* This is responsible for providing keys that can be used for HMAC computation. This also supports
@@ -55,19 +47,16 @@ class KeysCollectionManager {
public:
static const Seconds kKeyValidInterval;
- KeysCollectionManager(std::string purpose,
- ShardingCatalogClient* client,
- Seconds keyValidForInterval);
+ virtual ~KeysCollectionManager();
/**
* Return a key that is valid for the given time and also matches the keyId. Note that this call
- * can block if it will need to do a refresh.
+ * can block if it will need to do a refresh and we are on a sharded cluster.
*
* Throws ErrorCode::ExceededTimeLimit if it times out.
*/
- StatusWith<KeysCollectionDocument> getKeyForValidation(OperationContext* opCtx,
- long long keyId,
- const LogicalTime& forThisTime);
+ virtual StatusWith<KeysCollectionDocument> getKeyForValidation(
+ OperationContext* opCtx, long long keyId, const LogicalTime& forThisTime) = 0;
/**
* Returns a key that is valid for the given time. Note that unlike getKeyForValidation, this
@@ -75,120 +64,8 @@ public:
*
* Throws ErrorCode::ExceededTimeLimit if it times out.
*/
- StatusWith<KeysCollectionDocument> getKeyForSigning(const LogicalTime& forThisTime);
-
- /**
- * Request this manager to perform a refresh.
- */
- void refreshNow(OperationContext* opCtx);
-
- /**
- * Starts a background thread that will constantly update the internal cache of keys.
- *
- * Cannot call this after stopMonitoring was called at least once.
- */
- void startMonitoring(ServiceContext* service);
-
- /**
- * Stops the background thread updating the cache.
- */
- void stopMonitoring();
-
- /**
- * Enable writing new keys to the config server primary. Should only be called if current node
- * is the config primary.
- */
- void enableKeyGenerator(OperationContext* opCtx, bool doEnable);
-
- /**
- * Returns true if the refresher has ever successfully returned keys from the config server.
- */
- bool hasSeenKeys();
-
-private:
- /**
- * This is responsible for periodically performing refresh in the background.
- */
- class PeriodicRunner {
- public:
- using RefreshFunc = stdx::function<StatusWith<KeysCollectionDocument>(OperationContext*)>;
-
- /**
- * Preemptively inform the monitoring thread it needs to perform a refresh. Returns an
- * object
- * that gets notified after the current round of refresh is over. Note that being notified
- * can
- * mean either of these things:
- *
- * 1. An error occurred and refresh was not performed.
- * 2. No error occurred but no new key was found.
- * 3. No error occurred and new keys were found.
- */
- void refreshNow(OperationContext* opCtx);
-
- /**
- * Sets the refresh function to use.
- * Should only be used to bootstrap this refresher with initial strategy.
- */
- void setFunc(RefreshFunc newRefreshStrategy);
-
- /**
- * Switches the current strategy with a new one. This also waits to make sure that the old
- * strategy is not being used and will no longer be used after this call.
- */
- void switchFunc(OperationContext* opCtx, RefreshFunc newRefreshStrategy);
-
- /**
- * Starts the refresh thread.
- */
- void start(ServiceContext* service,
- const std::string& threadName,
- Milliseconds refreshInterval);
-
- /**
- * Stops the refresh thread.
- */
- void stop();
-
- /**
- * Returns true if keys have ever successfully been returned from the config server.
- */
- bool hasSeenKeys();
-
- private:
- void _doPeriodicRefresh(ServiceContext* service,
- std::string threadName,
- Milliseconds refreshInterval);
-
- stdx::mutex _mutex; // protects all the member variables below.
- std::shared_ptr<Notification<void>> _refreshRequest;
- stdx::condition_variable _refreshNeededCV;
-
- stdx::thread _backgroundThread;
- std::shared_ptr<RefreshFunc> _doRefresh;
-
- bool _hasSeenKeys = false;
- bool _inShutdown = false;
- };
-
- /**
- * Return a key that is valid for the given time and also matches the keyId.
- */
- StatusWith<KeysCollectionDocument> _getKeyWithKeyIdCheck(long long keyId,
- const LogicalTime& forThisTime);
-
- /**
- * Return a key that is valid for the given time.
- */
- StatusWith<KeysCollectionDocument> _getKey(const LogicalTime& forThisTime);
-
- const std::string _purpose;
- const Seconds _keyValidForInterval;
- ShardingCatalogClient* _catalogClient;
-
- // No mutex needed since the members below have their own mutexes.
- KeysCollectionCacheReader _keysCache;
- PeriodicRunner _refresher;
+ virtual StatusWith<KeysCollectionDocument> getKeyForSigning(OperationContext* opCtx,
+ const LogicalTime& forThisTime) = 0;
};
} // namespace mongo
diff --git a/src/mongo/db/keys_collection_manager_direct.cpp b/src/mongo/db/keys_collection_manager_direct.cpp
new file mode 100644
index 00000000000..60cf5672e6e
--- /dev/null
+++ b/src/mongo/db/keys_collection_manager_direct.cpp
@@ -0,0 +1,137 @@
+/**
+ * Copyright (C) 2017 MongoDB Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License, version 3,
+ * as published by the Free Software Foundation.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the GNU Affero General Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/platform/basic.h"
+
+#include "mongo/db/keys_collection_manager_direct.h"
+
+#include "mongo/db/dbdirectclient.h"
+#include "mongo/db/logical_clock.h"
+#include "mongo/db/logical_time.h"
+#include "mongo/db/operation_context.h"
+#include "mongo/db/server_options.h"
+#include "mongo/db/service_context.h"
+
+namespace mongo {
+
+namespace {
+const char kLogicalTimeKeysCollection[] = "admin.system.keys";
+const int kMaxCachedKeys = 20;
+} // namespace
+
+KeysCollectionManagerDirect::KeysCollectionManagerDirect(std::string purpose,
+ Seconds keyValidForInterval)
+ : _purpose(std::move(purpose)),
+ _keyValidForInterval(keyValidForInterval),
+ _cache(kMaxCachedKeys) {}
+
+StatusWith<KeysCollectionDocument> KeysCollectionManagerDirect::getKeyForValidation(
+ OperationContext* opCtx, long long keyId, const LogicalTime& forThisTime) {
+ // First, attempt to find the key in our cache.
+ {
+ stdx::lock_guard<stdx::mutex> lk(_mutex);
+ auto it = _cache.find(keyId);
+ if (it != _cache.end()) {
+ return it->second;
+ }
+ }
+
+ // Query admin.system.keys for an active key with this id.
+ DBDirectClient client(opCtx);
+
+ BSONObjBuilder queryBuilder;
+ queryBuilder.append("purpose", _purpose);
+ queryBuilder.append("_id", keyId);
+ queryBuilder.append("expiresAt", BSON("$gt" << forThisTime.asTimestamp()));
+
+ auto cursor = client.query(KeysCollectionDocument::ConfigNS, queryBuilder.obj());
+
+ if (!cursor->more()) {
+ return {ErrorCodes::KeyNotFound, "Could not find matching key"};
+ }
+
+ // Parse the key.
+ auto res = KeysCollectionDocument::fromBSON(cursor->next());
+ if (!res.isOK()) {
+ return res.getStatus();
+ }
+
+ // Add to our cache.
+ {
+ stdx::lock_guard<stdx::mutex> lk(_mutex);
+ _cache.add(keyId, res.getValue());
+ }
+
+ return res.getValue();
+}
+
+StatusWith<KeysCollectionDocument> KeysCollectionManagerDirect::getKeyForSigning(
+ OperationContext* opCtx, const LogicalTime& forThisTime) {
+ // Search through the cache for active keys.
+ {
+ stdx::lock_guard<stdx::mutex> lk(_mutex);
+ for (auto& it : _cache) {
+ auto keyDoc = it.second;
+ auto expiration = keyDoc.getExpiresAt();
+ if (expiration > forThisTime) {
+ return keyDoc;
+ }
+ }
+ }
+
+ // Query admin.system.keys for active keys.
+ DBDirectClient client(opCtx);
+
+ BSONObjBuilder queryBuilder;
+ queryBuilder.append("purpose", _purpose);
+ queryBuilder.append("expiresAt", BSON("$gt" << forThisTime.asTimestamp()));
+
+ auto cursor = client.query(KeysCollectionDocument::ConfigNS, queryBuilder.obj());
+
+ if (!cursor->more()) {
+ return {ErrorCodes::KeyNotFound, "Could not find an active key for signing"};
+ }
+
+ // Parse and return the key.
+ auto res = KeysCollectionDocument::fromBSON(cursor->next());
+ if (!res.isOK()) {
+ return res.getStatus();
+ }
+
+ auto keyDoc = res.getValue();
+
+ // Add to our cache.
+ {
+ stdx::lock_guard<stdx::mutex> lk(_mutex);
+ _cache.add(keyDoc.getKeyId(), keyDoc);
+ }
+
+ return keyDoc;
+}
+
+} // namespace mongo
diff --git a/src/mongo/db/keys_collection_manager_direct.h b/src/mongo/db/keys_collection_manager_direct.h
new file mode 100644
index 00000000000..2f1228ff056
--- /dev/null
+++ b/src/mongo/db/keys_collection_manager_direct.h
@@ -0,0 +1,73 @@
+/**
+ * Copyright (C) 2017 MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License, version 3,
+ * as published by the Free Software Foundation.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the GNU Affero General Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include <memory>
+
+#include "mongo/db/keys_collection_document.h"
+#include "mongo/db/keys_collection_manager.h"
+#include "mongo/stdx/mutex.h"
+#include "mongo/util/lru_cache.h"
+
+namespace mongo {
+
+class OperationContext;
+class LogicalTime;
+class ServiceContext;
+
+/**
+ * This implementation of the KeysCollectionManager uses DBDirectclient to query the
+ * keys collection local to this server.
+ */
+class KeysCollectionManagerDirect : public KeysCollectionManager {
+public:
+ KeysCollectionManagerDirect(std::string purpose, Seconds keyValidForInterval);
+
+ /**
+ * Return a key that is valid for the given time and also matches the keyId.
+ */
+ StatusWith<KeysCollectionDocument> getKeyForValidation(OperationContext* opCtx,
+ long long keyId,
+ const LogicalTime& forThisTime) override;
+
+ /**
+ * Returns a key that is valid for the given time.
+ */
+ StatusWith<KeysCollectionDocument> getKeyForSigning(OperationContext* opCtx,
+ const LogicalTime& forThisTime) override;
+
+private:
+ const std::string _purpose;
+ const Seconds _keyValidForInterval;
+
+ stdx::mutex _mutex;
+ LRUCache<long long, KeysCollectionDocument> _cache;
+};
+
+} // namespace mongo
diff --git a/src/mongo/db/keys_collection_manager_sharding.cpp b/src/mongo/db/keys_collection_manager_sharding.cpp
new file mode 100644
index 00000000000..c43779cbcbe
--- /dev/null
+++ b/src/mongo/db/keys_collection_manager_sharding.cpp
@@ -0,0 +1,345 @@
+/**
+ * Copyright (C) 2017 MongoDB Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License, version 3,
+ * as published by the Free Software Foundation.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the GNU Affero General Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/platform/basic.h"
+
+#include "mongo/db/keys_collection_manager_sharding.h"
+
+#include "mongo/db/keys_collection_cache_reader.h"
+#include "mongo/db/keys_collection_cache_reader_and_updater.h"
+#include "mongo/db/logical_clock.h"
+#include "mongo/db/logical_time.h"
+#include "mongo/db/operation_context.h"
+#include "mongo/db/server_options.h"
+#include "mongo/db/service_context.h"
+#include "mongo/stdx/memory.h"
+#include "mongo/util/concurrency/idle_thread_block.h"
+#include "mongo/util/fail_point_service.h"
+#include "mongo/util/mongoutils/str.h"
+#include "mongo/util/time_support.h"
+
+namespace mongo {
+
+namespace {
+
+Milliseconds kDefaultRefreshWaitTime(30 * 1000);
+Milliseconds kRefreshIntervalIfErrored(200);
+Milliseconds kMaxRefreshWaitTime(10 * 60 * 1000);
+
+// Prevents the refresher thread from waiting longer than the given number of milliseconds, even on
+// a successful refresh.
+MONGO_FP_DECLARE(maxKeyRefreshWaitTimeOverrideMS);
+
+/**
+ * Returns the amount of time to wait until the monitoring thread should attempt to refresh again.
+ */
+Milliseconds howMuchSleepNeedFor(const LogicalTime& currentTime,
+ const LogicalTime& latestExpiredAt,
+ const Milliseconds& interval) {
+ auto currentSecs = currentTime.asTimestamp().getSecs();
+ auto expiredSecs = latestExpiredAt.asTimestamp().getSecs();
+
+ if (currentSecs >= expiredSecs) {
+ // This means that the last round didn't generate a usable key for the current time.
+ // However, we don't want to poll too hard as well, so use a low interval.
+ return kRefreshIntervalIfErrored;
+ }
+
+ auto millisBeforeExpire = 1000 * (expiredSecs - currentSecs);
+
+ if (interval.count() <= millisBeforeExpire) {
+ return interval;
+ }
+
+ return Milliseconds(millisBeforeExpire);
+}
+
+} // unnamed namespace
+
+KeysCollectionManagerSharding::KeysCollectionManagerSharding(std::string purpose,
+ ShardingCatalogClient* client,
+ Seconds keyValidForInterval)
+ : _purpose(std::move(purpose)),
+ _keyValidForInterval(keyValidForInterval),
+ _catalogClient(client),
+ _keysCache(_purpose, client) {}
+
+StatusWith<KeysCollectionDocument> KeysCollectionManagerSharding::getKeyForValidation(
+ OperationContext* opCtx, long long keyId, const LogicalTime& forThisTime) {
+ auto keyStatus = _getKeyWithKeyIdCheck(keyId, forThisTime);
+
+ if (keyStatus != ErrorCodes::KeyNotFound) {
+ return keyStatus;
+ }
+
+ _refresher.refreshNow(opCtx);
+
+ return _getKeyWithKeyIdCheck(keyId, forThisTime);
+}
+
+StatusWith<KeysCollectionDocument> KeysCollectionManagerSharding::getKeyForSigning(
+ OperationContext* opCtx, const LogicalTime& forThisTime) {
+ return _getKey(forThisTime);
+}
+
+StatusWith<KeysCollectionDocument> KeysCollectionManagerSharding::_getKeyWithKeyIdCheck(
+ long long keyId, const LogicalTime& forThisTime) {
+ auto keyStatus = _keysCache.getKeyById(keyId, forThisTime);
+
+ if (!keyStatus.isOK()) {
+ return keyStatus;
+ }
+
+ return keyStatus.getValue();
+}
+
+StatusWith<KeysCollectionDocument> KeysCollectionManagerSharding::_getKey(
+ const LogicalTime& forThisTime) {
+ auto keyStatus = _keysCache.getKey(forThisTime);
+
+ if (!keyStatus.isOK()) {
+ return keyStatus;
+ }
+
+ const auto& key = keyStatus.getValue();
+
+ if (key.getExpiresAt() < forThisTime) {
+ return {ErrorCodes::KeyNotFound,
+ str::stream() << "No keys found for " << _purpose << " that is valid for "
+ << forThisTime.toString()};
+ }
+
+ return key;
+}
+
+void KeysCollectionManagerSharding::refreshNow(OperationContext* opCtx) {
+ _refresher.refreshNow(opCtx);
+}
+
+void KeysCollectionManagerSharding::startMonitoring(ServiceContext* service) {
+ _refresher.setFunc([this](OperationContext* opCtx) { return _keysCache.refresh(opCtx); });
+ _refresher.start(
+ service, str::stream() << "monitoring keys for " << _purpose, _keyValidForInterval);
+}
+
+void KeysCollectionManagerSharding::stopMonitoring() {
+ _refresher.stop();
+}
+
+void KeysCollectionManagerSharding::enableKeyGenerator(OperationContext* opCtx, bool doEnable) {
+ if (doEnable) {
+ _refresher.switchFunc(opCtx, [this](OperationContext* opCtx) {
+ KeysCollectionCacheReaderAndUpdater keyGenerator(
+ _purpose, _catalogClient, _keyValidForInterval);
+ auto keyGenerationStatus = keyGenerator.refresh(opCtx);
+
+ if (ErrorCodes::isShutdownError(keyGenerationStatus.getStatus().code())) {
+ return keyGenerationStatus;
+ }
+
+ // An error encountered by the keyGenerator should not prevent refreshing the cache
+ auto cacheRefreshStatus = _keysCache.refresh(opCtx);
+
+ if (!keyGenerationStatus.isOK()) {
+ return keyGenerationStatus;
+ }
+
+ return cacheRefreshStatus;
+ });
+ } else {
+ _refresher.switchFunc(
+ opCtx, [this](OperationContext* opCtx) { return _keysCache.refresh(opCtx); });
+ }
+}
+
+bool KeysCollectionManagerSharding::hasSeenKeys() {
+ return _refresher.hasSeenKeys();
+}
+
+void KeysCollectionManagerSharding::PeriodicRunner::refreshNow(OperationContext* opCtx) {
+ auto refreshRequest = [this]() {
+ stdx::lock_guard<stdx::mutex> lk(_mutex);
+
+ if (_inShutdown) {
+ throw DBException("aborting keys cache refresh because node is shutting down",
+ ErrorCodes::ShutdownInProgress);
+ }
+
+ if (_refreshRequest) {
+ return _refreshRequest;
+ }
+
+ _refreshNeededCV.notify_all();
+ _refreshRequest = std::make_shared<Notification<void>>();
+ return _refreshRequest;
+ }();
+
+ // note: waitFor waits min(maxTimeMS, kDefaultRefreshWaitTime).
+ // waitFor also throws if timeout, so also throw when notification was not satisfied after
+ // waiting.
+ if (!refreshRequest->waitFor(opCtx, kDefaultRefreshWaitTime)) {
+ throw DBException("timed out waiting for refresh", ErrorCodes::ExceededTimeLimit);
+ }
+}
+
+void KeysCollectionManagerSharding::PeriodicRunner::_doPeriodicRefresh(
+ ServiceContext* service, std::string threadName, Milliseconds refreshInterval) {
+ Client::initThreadIfNotAlready(threadName);
+
+ while (true) {
+ auto opCtx = cc().makeOperationContext();
+
+ bool hasRefreshRequestInitially = false;
+ unsigned errorCount = 0;
+ std::shared_ptr<RefreshFunc> doRefresh;
+ {
+ stdx::lock_guard<stdx::mutex> lock(_mutex);
+
+ if (_inShutdown) {
+ break;
+ }
+
+ invariant(_doRefresh.get() != nullptr);
+ doRefresh = _doRefresh;
+ hasRefreshRequestInitially = _refreshRequest.get() != nullptr;
+ }
+
+ Milliseconds nextWakeup = kRefreshIntervalIfErrored;
+
+ // No need to refresh keys in FCV 3.4, since key generation will be disabled.
+ if (serverGlobalParams.featureCompatibility.version.load() !=
+ ServerGlobalParams::FeatureCompatibility::Version::k34) {
+ auto latestKeyStatusWith = (*doRefresh)(opCtx.get());
+ if (latestKeyStatusWith.getStatus().isOK()) {
+ errorCount = 0;
+ const auto& latestKey = latestKeyStatusWith.getValue();
+ auto currentTime = LogicalClock::get(service)->getClusterTime();
+
+ {
+ stdx::unique_lock<stdx::mutex> lock(_mutex);
+ _hasSeenKeys = true;
+ }
+
+ nextWakeup =
+ howMuchSleepNeedFor(currentTime, latestKey.getExpiresAt(), refreshInterval);
+ } else {
+ errorCount += 1;
+ nextWakeup = Milliseconds(kRefreshIntervalIfErrored.count() * errorCount);
+ if (nextWakeup > kMaxRefreshWaitTime) {
+ nextWakeup = kMaxRefreshWaitTime;
+ }
+ }
+ } else {
+ nextWakeup = kDefaultRefreshWaitTime;
+ }
+
+ MONGO_FAIL_POINT_BLOCK(maxKeyRefreshWaitTimeOverrideMS, data) {
+ const BSONObj& dataObj = data.getData();
+ auto overrideMS = Milliseconds(dataObj["overrideMS"].numberInt());
+ if (nextWakeup > overrideMS) {
+ nextWakeup = overrideMS;
+ }
+ }
+
+ stdx::unique_lock<stdx::mutex> lock(_mutex);
+
+ if (_refreshRequest) {
+ if (!hasRefreshRequestInitially) {
+ // A fresh request came in, fulfill the request before going to sleep.
+ continue;
+ }
+
+ _refreshRequest->set();
+ _refreshRequest.reset();
+ }
+
+ if (_inShutdown) {
+ break;
+ }
+
+ MONGO_IDLE_THREAD_BLOCK;
+ auto sleepStatus = opCtx->waitForConditionOrInterruptNoAssertUntil(
+ _refreshNeededCV, lock, Date_t::now() + nextWakeup);
+
+ if (ErrorCodes::isShutdownError(sleepStatus.getStatus().code())) {
+ break;
+ }
+ }
+
+ stdx::unique_lock<stdx::mutex> lock(_mutex);
+ if (_refreshRequest) {
+ _refreshRequest->set();
+ _refreshRequest.reset();
+ }
+}
+
+void KeysCollectionManagerSharding::PeriodicRunner::setFunc(RefreshFunc newRefreshStrategy) {
+ stdx::lock_guard<stdx::mutex> lock(_mutex);
+ _doRefresh = std::make_shared<RefreshFunc>(std::move(newRefreshStrategy));
+}
+
+void KeysCollectionManagerSharding::PeriodicRunner::switchFunc(OperationContext* opCtx,
+ RefreshFunc newRefreshStrategy) {
+ setFunc(newRefreshStrategy);
+}
+
+void KeysCollectionManagerSharding::PeriodicRunner::start(ServiceContext* service,
+ const std::string& threadName,
+ Milliseconds refreshInterval) {
+ stdx::lock_guard<stdx::mutex> lock(_mutex);
+ invariant(!_backgroundThread.joinable());
+ invariant(!_inShutdown);
+
+ _backgroundThread =
+ stdx::thread(stdx::bind(&KeysCollectionManagerSharding::PeriodicRunner::_doPeriodicRefresh,
+ this,
+ service,
+ threadName,
+ refreshInterval));
+}
+
+void KeysCollectionManagerSharding::PeriodicRunner::stop() {
+ {
+ stdx::lock_guard<stdx::mutex> lock(_mutex);
+ if (!_backgroundThread.joinable()) {
+ return;
+ }
+
+ _inShutdown = true;
+ _refreshNeededCV.notify_all();
+ }
+
+ _backgroundThread.join();
+}
+
+bool KeysCollectionManagerSharding::PeriodicRunner::hasSeenKeys() {
+ stdx::lock_guard<stdx::mutex> lock(_mutex);
+ return _hasSeenKeys;
+}
+
+} // namespace mongo
diff --git a/src/mongo/db/keys_collection_manager_sharding.h b/src/mongo/db/keys_collection_manager_sharding.h
new file mode 100644
index 00000000000..82f358ebe18
--- /dev/null
+++ b/src/mongo/db/keys_collection_manager_sharding.h
@@ -0,0 +1,197 @@
+/**
+ * Copyright (C) 2017 MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License, version 3,
+ * as published by the Free Software Foundation.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the GNU Affero General Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include <memory>
+
+#include "mongo/base/status_with.h"
+#include "mongo/db/keys_collection_cache_reader.h"
+#include "mongo/db/keys_collection_cache_reader_and_updater.h"
+#include "mongo/db/keys_collection_document.h"
+#include "mongo/db/keys_collection_manager.h"
+#include "mongo/stdx/functional.h"
+#include "mongo/stdx/mutex.h"
+#include "mongo/stdx/thread.h"
+#include "mongo/util/concurrency/notification.h"
+#include "mongo/util/duration.h"
+
+namespace mongo {
+
+class OperationContext;
+class LogicalTime;
+class ServiceContext;
+class ShardingCatalogClient;
+
+/**
+ * This implementation of the KeysCollectionManager queries the config servers for keys.
+ * It maintains in internal background thread that is used to periodically refresh
+ * the local key cache against the keys collection stored on the config servers.
+ */
+class KeysCollectionManagerSharding : public KeysCollectionManager {
+public:
+ static const Seconds kKeyValidInterval;
+
+ KeysCollectionManagerSharding(std::string purpose,
+ ShardingCatalogClient* client,
+ Seconds keyValidForInterval);
+
+ /**
+ * Return a key that is valid for the given time and also matches the keyId. Note that this call
+ * can block if it will need to do a refresh.
+ *
+ * Throws ErrorCode::ExceededTimeLimit if it times out.
+ */
+ StatusWith<KeysCollectionDocument> getKeyForValidation(OperationContext* opCtx,
+ long long keyId,
+ const LogicalTime& forThisTime) override;
+
+ /**
+ * Returns a key that is valid for the given time. Note that unlike getKeyForValidation, this
+ * will never do a refresh.
+ *
+ * Throws ErrorCode::ExceededTimeLimit if it times out.
+ */
+ StatusWith<KeysCollectionDocument> getKeyForSigning(OperationContext* opCtx,
+ const LogicalTime& forThisTime) override;
+
+ /**
+ * Request this manager to perform a refresh.
+ */
+ void refreshNow(OperationContext* opCtx);
+
+ /**
+ * Starts a background thread that will constantly update the internal cache of keys.
+ *
+ * Cannot call this after stopMonitoring was called at least once.
+ */
+ void startMonitoring(ServiceContext* service);
+
+ /**
+ * Stops the background thread updating the cache.
+ */
+ void stopMonitoring();
+
+ /**
+ * Enable writing new keys to the config server primary. Should only be called if current node
+ * is the config primary.
+ */
+ void enableKeyGenerator(OperationContext* opCtx, bool doEnable);
+
+ /**
+ * Returns true if the refresher has ever successfully returned keys from the config server.
+ */
+ bool hasSeenKeys();
+
+private:
+ /**
+ * This is responsible for periodically performing refresh in the background.
+ */
+ class PeriodicRunner {
+ public:
+ using RefreshFunc = stdx::function<StatusWith<KeysCollectionDocument>(OperationContext*)>;
+
+ /**
+ * Preemptively inform the monitoring thread it needs to perform a refresh. Returns an
+ * object
+ * that gets notified after the current round of refresh is over. Note that being notified
+ * can
+ * mean either of these things:
+ *
+ * 1. An error occurred and refresh was not performed.
+ * 2. No error occurred but no new key was found.
+ * 3. No error occurred and new keys were found.
+ */
+ void refreshNow(OperationContext* opCtx);
+
+ /**
+ * Sets the refresh function to use.
+ * Should only be used to bootstrap this refresher with initial strategy.
+ */
+ void setFunc(RefreshFunc newRefreshStrategy);
+
+ /**
+ * Switches the current strategy with a new one. This also waits to make sure that the old
+ * strategy is not being used and will no longer be used after this call.
+ */
+ void switchFunc(OperationContext* opCtx, RefreshFunc newRefreshStrategy);
+
+ /**
+ * Starts the refresh thread.
+ */
+ void start(ServiceContext* service,
+ const std::string& threadName,
+ Milliseconds refreshInterval);
+
+ /**
+ * Stops the refresh thread.
+ */
+ void stop();
+
+ /**
+ * Returns true if keys have ever successfully been returned from the config server.
+ */
+ bool hasSeenKeys();
+
+ private:
+ void _doPeriodicRefresh(ServiceContext* service,
+ std::string threadName,
+ Milliseconds refreshInterval);
+
+ stdx::mutex _mutex; // protects all the member variables below.
+ std::shared_ptr<Notification<void>> _refreshRequest;
+ stdx::condition_variable _refreshNeededCV;
+
+ stdx::thread _backgroundThread;
+ std::shared_ptr<RefreshFunc> _doRefresh;
+
+ bool _hasSeenKeys = false;
+ bool _inShutdown = false;
+ };
+
+ /**
+ * Return a key that is valid for the given time and also matches the keyId.
+ */
+ StatusWith<KeysCollectionDocument> _getKeyWithKeyIdCheck(long long keyId,
+ const LogicalTime& forThisTime);
+
+ /**
+ * Return a key that is valid for the given time.
+ */
+ StatusWith<KeysCollectionDocument> _getKey(const LogicalTime& forThisTime);
+
+ const std::string _purpose;
+ const Seconds _keyValidForInterval;
+ ShardingCatalogClient* _catalogClient;
+
+ // No mutex needed since the members below have their own mutexes.
+ KeysCollectionCacheReader _keysCache;
+ PeriodicRunner _refresher;
+};
+
+} // namespace mongo
diff --git a/src/mongo/db/keys_collection_manager_test.cpp b/src/mongo/db/keys_collection_manager_sharding_test.cpp
index e68d9285cc0..d7af13a09fc 100644
--- a/src/mongo/db/keys_collection_manager_test.cpp
+++ b/src/mongo/db/keys_collection_manager_sharding_test.cpp
@@ -33,7 +33,7 @@
#include "mongo/db/jsobj.h"
#include "mongo/db/keys_collection_document.h"
-#include "mongo/db/keys_collection_manager.h"
+#include "mongo/db/keys_collection_manager_sharding.h"
#include "mongo/db/logical_clock.h"
#include "mongo/db/namespace_string.h"
#include "mongo/db/operation_context.h"
@@ -48,9 +48,9 @@
namespace mongo {
-class KeysManagerTest : public ConfigServerTestFixture {
+class KeysManagerShardedTest : public ConfigServerTestFixture {
public:
- KeysCollectionManager* keyManager() {
+ KeysCollectionManagerSharding* keyManager() {
return _keyManager.get();
}
@@ -65,7 +65,8 @@ protected:
auto clockSource = stdx::make_unique<ClockSourceMock>();
operationContext()->getServiceContext()->setFastClockSource(std::move(clockSource));
auto catalogClient = Grid::get(operationContext())->catalogClient();
- _keyManager = stdx::make_unique<KeysCollectionManager>("dummy", catalogClient, Seconds(1));
+ _keyManager =
+ stdx::make_unique<KeysCollectionManagerSharding>("dummy", catalogClient, Seconds(1));
}
void tearDown() override {
@@ -81,10 +82,10 @@ protected:
}
private:
- std::unique_ptr<KeysCollectionManager> _keyManager;
+ std::unique_ptr<KeysCollectionManagerSharding> _keyManager;
};
-TEST_F(KeysManagerTest, GetKeyForValidationTimesOutIfRefresherIsNotRunning) {
+TEST_F(KeysManagerShardedTest, GetKeyForValidationTimesOutIfRefresherIsNotRunning) {
operationContext()->setDeadlineAfterNowBy(Microseconds(250 * 1000));
ASSERT_THROWS(keyManager()
@@ -93,7 +94,7 @@ TEST_F(KeysManagerTest, GetKeyForValidationTimesOutIfRefresherIsNotRunning) {
DBException);
}
-TEST_F(KeysManagerTest, GetKeyForValidationErrorsIfKeyDoesntExist) {
+TEST_F(KeysManagerShardedTest, GetKeyForValidationErrorsIfKeyDoesntExist) {
keyManager()->startMonitoring(getServiceContext());
auto keyStatus =
@@ -101,7 +102,7 @@ TEST_F(KeysManagerTest, GetKeyForValidationErrorsIfKeyDoesntExist) {
ASSERT_EQ(ErrorCodes::KeyNotFound, keyStatus.getStatus());
}
-TEST_F(KeysManagerTest, GetKeyWithSingleKey) {
+TEST_F(KeysManagerShardedTest, GetKeyWithSingleKey) {
keyManager()->startMonitoring(getServiceContext());
KeysCollectionDocument origKey1(
@@ -119,7 +120,7 @@ TEST_F(KeysManagerTest, GetKeyWithSingleKey) {
ASSERT_EQ(Timestamp(105, 0), key.getExpiresAt().asTimestamp());
}
-TEST_F(KeysManagerTest, GetKeyWithMultipleKeys) {
+TEST_F(KeysManagerShardedTest, GetKeyWithMultipleKeys) {
keyManager()->startMonitoring(getServiceContext());
KeysCollectionDocument origKey1(
@@ -151,7 +152,7 @@ TEST_F(KeysManagerTest, GetKeyWithMultipleKeys) {
ASSERT_EQ(Timestamp(205, 0), key.getExpiresAt().asTimestamp());
}
-TEST_F(KeysManagerTest, GetKeyShouldErrorIfKeyIdMismatchKey) {
+TEST_F(KeysManagerShardedTest, GetKeyShouldErrorIfKeyIdMismatchKey) {
keyManager()->startMonitoring(getServiceContext());
KeysCollectionDocument origKey1(
@@ -164,7 +165,7 @@ TEST_F(KeysManagerTest, GetKeyShouldErrorIfKeyIdMismatchKey) {
ASSERT_EQ(ErrorCodes::KeyNotFound, keyStatus.getStatus());
}
-TEST_F(KeysManagerTest, GetKeyWithoutRefreshShouldReturnRightKey) {
+TEST_F(KeysManagerShardedTest, GetKeyWithoutRefreshShouldReturnRightKey) {
keyManager()->startMonitoring(getServiceContext());
KeysCollectionDocument origKey1(
@@ -199,7 +200,7 @@ TEST_F(KeysManagerTest, GetKeyWithoutRefreshShouldReturnRightKey) {
}
}
-TEST_F(KeysManagerTest, GetKeyForSigningShouldReturnRightKey) {
+TEST_F(KeysManagerShardedTest, GetKeyForSigningShouldReturnRightKey) {
keyManager()->startMonitoring(getServiceContext());
KeysCollectionDocument origKey1(
@@ -209,7 +210,7 @@ TEST_F(KeysManagerTest, GetKeyForSigningShouldReturnRightKey) {
keyManager()->refreshNow(operationContext());
- auto keyStatus = keyManager()->getKeyForSigning(LogicalTime(Timestamp(100, 0)));
+ auto keyStatus = keyManager()->getKeyForSigning(nullptr, LogicalTime(Timestamp(100, 0)));
ASSERT_OK(keyStatus.getStatus());
auto key = keyStatus.getValue();
@@ -218,7 +219,7 @@ TEST_F(KeysManagerTest, GetKeyForSigningShouldReturnRightKey) {
ASSERT_EQ(Timestamp(105, 0), key.getExpiresAt().asTimestamp());
}
-TEST_F(KeysManagerTest, GetKeyForSigningShouldReturnRightOldKey) {
+TEST_F(KeysManagerShardedTest, GetKeyForSigningShouldReturnRightOldKey) {
keyManager()->startMonitoring(getServiceContext());
KeysCollectionDocument origKey1(
@@ -233,7 +234,7 @@ TEST_F(KeysManagerTest, GetKeyForSigningShouldReturnRightOldKey) {
keyManager()->refreshNow(operationContext());
{
- auto keyStatus = keyManager()->getKeyForSigning(LogicalTime(Timestamp(100, 0)));
+ auto keyStatus = keyManager()->getKeyForSigning(nullptr, LogicalTime(Timestamp(100, 0)));
ASSERT_OK(keyStatus.getStatus());
auto key = keyStatus.getValue();
@@ -243,7 +244,7 @@ TEST_F(KeysManagerTest, GetKeyForSigningShouldReturnRightOldKey) {
}
{
- auto keyStatus = keyManager()->getKeyForSigning(LogicalTime(Timestamp(105, 0)));
+ auto keyStatus = keyManager()->getKeyForSigning(nullptr, LogicalTime(Timestamp(105, 0)));
ASSERT_OK(keyStatus.getStatus());
auto key = keyStatus.getValue();
@@ -253,7 +254,7 @@ TEST_F(KeysManagerTest, GetKeyForSigningShouldReturnRightOldKey) {
}
}
-TEST_F(KeysManagerTest, ShouldCreateKeysIfKeyGeneratorEnabled) {
+TEST_F(KeysManagerShardedTest, ShouldCreateKeysIfKeyGeneratorEnabled) {
keyManager()->startMonitoring(getServiceContext());
const LogicalTime currentTime(LogicalTime(Timestamp(100, 0)));
@@ -262,14 +263,14 @@ TEST_F(KeysManagerTest, ShouldCreateKeysIfKeyGeneratorEnabled) {
keyManager()->enableKeyGenerator(operationContext(), true);
keyManager()->refreshNow(operationContext());
- auto keyStatus = keyManager()->getKeyForSigning(LogicalTime(Timestamp(100, 100)));
+ auto keyStatus = keyManager()->getKeyForSigning(nullptr, LogicalTime(Timestamp(100, 100)));
ASSERT_OK(keyStatus.getStatus());
auto key = keyStatus.getValue();
ASSERT_EQ(Timestamp(101, 0), key.getExpiresAt().asTimestamp());
}
-TEST_F(KeysManagerTest, EnableModeFlipFlopStressTest) {
+TEST_F(KeysManagerShardedTest, EnableModeFlipFlopStressTest) {
keyManager()->startMonitoring(getServiceContext());
const LogicalTime currentTime(LogicalTime(Timestamp(100, 0)));
@@ -281,7 +282,7 @@ TEST_F(KeysManagerTest, EnableModeFlipFlopStressTest) {
keyManager()->enableKeyGenerator(operationContext(), doEnable);
keyManager()->refreshNow(operationContext());
- auto keyStatus = keyManager()->getKeyForSigning(LogicalTime(Timestamp(100, 100)));
+ auto keyStatus = keyManager()->getKeyForSigning(nullptr, LogicalTime(Timestamp(100, 100)));
ASSERT_OK(keyStatus.getStatus());
auto key = keyStatus.getValue();
@@ -291,7 +292,7 @@ TEST_F(KeysManagerTest, EnableModeFlipFlopStressTest) {
}
}
-TEST_F(KeysManagerTest, ShouldStillBeAbleToUpdateCacheEvenIfItCantCreateKeys) {
+TEST_F(KeysManagerShardedTest, ShouldStillBeAbleToUpdateCacheEvenIfItCantCreateKeys) {
KeysCollectionDocument origKey1(
1, "dummy", TimeProofService::generateRandomKey(), LogicalTime(Timestamp(105, 0)));
ASSERT_OK(insertToConfigCollection(
@@ -319,7 +320,7 @@ TEST_F(KeysManagerTest, ShouldStillBeAbleToUpdateCacheEvenIfItCantCreateKeys) {
ASSERT_EQ(Timestamp(105, 0), key.getExpiresAt().asTimestamp());
}
-TEST_F(KeysManagerTest, ShouldNotCreateKeysWithDisableKeyGenerationFailPoint) {
+TEST_F(KeysManagerShardedTest, ShouldNotCreateKeysWithDisableKeyGenerationFailPoint) {
const LogicalTime currentTime(Timestamp(100, 0));
LogicalClock::get(operationContext())->setClusterTimeFromTrustedSource(currentTime);
@@ -336,11 +337,11 @@ TEST_F(KeysManagerTest, ShouldNotCreateKeysWithDisableKeyGenerationFailPoint) {
// Once the failpoint is disabled, the generator can make keys again.
keyManager()->refreshNow(operationContext());
- auto keyStatus = keyManager()->getKeyForSigning(LogicalTime(Timestamp(100, 0)));
+ auto keyStatus = keyManager()->getKeyForSigning(nullptr, LogicalTime(Timestamp(100, 0)));
ASSERT_OK(keyStatus.getStatus());
}
-TEST_F(KeysManagerTest, HasSeenKeysIsFalseUntilKeysAreFound) {
+TEST_F(KeysManagerShardedTest, HasSeenKeysIsFalseUntilKeysAreFound) {
const LogicalTime currentTime(Timestamp(100, 0));
LogicalClock::get(operationContext())->setClusterTimeFromTrustedSource(currentTime);
@@ -361,13 +362,13 @@ TEST_F(KeysManagerTest, HasSeenKeysIsFalseUntilKeysAreFound) {
// Once the failpoint is disabled, the generator can make keys again.
keyManager()->refreshNow(operationContext());
- auto keyStatus = keyManager()->getKeyForSigning(LogicalTime(Timestamp(100, 0)));
+ auto keyStatus = keyManager()->getKeyForSigning(nullptr, LogicalTime(Timestamp(100, 0)));
ASSERT_OK(keyStatus.getStatus());
ASSERT_EQ(true, keyManager()->hasSeenKeys());
}
-TEST_F(KeysManagerTest, ShouldNotReturnKeysInFeatureCompatibilityVersion34) {
+TEST_F(KeysManagerShardedTest, ShouldNotReturnKeysInFeatureCompatibilityVersion34) {
serverGlobalParams.featureCompatibility.version.store(
ServerGlobalParams::FeatureCompatibility::Version::k34);
diff --git a/src/mongo/db/keys_collection_manager_zero.cpp b/src/mongo/db/keys_collection_manager_zero.cpp
new file mode 100644
index 00000000000..5f773b0152d
--- /dev/null
+++ b/src/mongo/db/keys_collection_manager_zero.cpp
@@ -0,0 +1,57 @@
+/**
+ * Copyright (C) 2017 MongoDB Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License, version 3,
+ * as published by the Free Software Foundation.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the GNU Affero General Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/platform/basic.h"
+
+#include "mongo/db/keys_collection_manager_zero.h"
+
+#include "mongo/db/keys_collection_document.h"
+
+namespace mongo {
+
+namespace {
+
+const TimeProofService::Key kTimeProofServiceKey{};
+const long long kKeyId = 1;
+
+} // namespace
+
+KeysCollectionManagerZero::KeysCollectionManagerZero(std::string purpose)
+ : _purpose(std::move(purpose)) {}
+
+StatusWith<KeysCollectionDocument> KeysCollectionManagerZero::getKeyForValidation(
+ OperationContext* opCtx, long long keyId, const LogicalTime& forThisTime) {
+ return KeysCollectionDocument(keyId, _purpose, kTimeProofServiceKey, forThisTime);
+}
+
+StatusWith<KeysCollectionDocument> KeysCollectionManagerZero::getKeyForSigning(
+ OperationContext* opCtx, const LogicalTime& forThisTime) {
+ return KeysCollectionDocument(kKeyId, _purpose, kTimeProofServiceKey, forThisTime);
+}
+
+} // namespace mongo
diff --git a/src/mongo/db/keys_collection_manager_zero.h b/src/mongo/db/keys_collection_manager_zero.h
new file mode 100644
index 00000000000..71f2dcc89ee
--- /dev/null
+++ b/src/mongo/db/keys_collection_manager_zero.h
@@ -0,0 +1,69 @@
+/**
+ * Copyright (C) 2017 MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License, version 3,
+ * as published by the Free Software Foundation.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the GNU Affero General Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include <memory>
+#include <string>
+
+#include "mongo/db/keys_collection_document.h"
+#include "mongo/db/keys_collection_manager.h"
+
+namespace mongo {
+
+class OperationContext;
+class LogicalTime;
+class ServiceContext;
+
+/**
+ * This implementation of the KeysCollectionManager always returns a zeroed key. This is a
+ * transitional type to bridge the gap until we decide how to store and rotate keys in standalones
+ * and non-clustered replica sets.
+ */
+class KeysCollectionManagerZero : public KeysCollectionManager {
+public:
+ KeysCollectionManagerZero(std::string purpose);
+
+ /**
+ * Return a key that is valid for the given time and also matches the keyId.
+ */
+ StatusWith<KeysCollectionDocument> getKeyForValidation(OperationContext* opCtx,
+ long long keyId,
+ const LogicalTime& forThisTime) override;
+
+ /**
+ * Returns a key that is valid for the given time.
+ */
+ StatusWith<KeysCollectionDocument> getKeyForSigning(OperationContext* opCtx,
+ const LogicalTime& forThisTime) override;
+
+private:
+ std::string _purpose;
+};
+
+} // namespace mongo
diff --git a/src/mongo/db/logical_session_cache.cpp b/src/mongo/db/logical_session_cache.cpp
index 8aa1eaba802..7474c398159 100644
--- a/src/mongo/db/logical_session_cache.cpp
+++ b/src/mongo/db/logical_session_cache.cpp
@@ -32,13 +32,20 @@
#include "mongo/db/logical_session_cache.h"
+#include "mongo/db/operation_context.h"
#include "mongo/db/server_parameters.h"
+#include "mongo/db/service_context.h"
#include "mongo/util/duration.h"
#include "mongo/util/log.h"
#include "mongo/util/periodic_runner.h"
namespace mongo {
+namespace {
+const auto getLogicalSessionCache =
+ ServiceContext::declareDecoration<std::unique_ptr<LogicalSessionCache>>();
+} // namespace
+
MONGO_EXPORT_STARTUP_SERVER_PARAMETER(logicalSessionRecordCacheSize,
int,
LogicalSessionCache::kLogicalSessionCacheDefaultCapacity);
@@ -55,6 +62,20 @@ constexpr int LogicalSessionCache::kLogicalSessionCacheDefaultCapacity;
constexpr Minutes LogicalSessionCache::kLogicalSessionDefaultTimeout;
constexpr Minutes LogicalSessionCache::kLogicalSessionDefaultRefresh;
+LogicalSessionCache* LogicalSessionCache::get(ServiceContext* service) {
+ return getLogicalSessionCache(service).get();
+}
+
+LogicalSessionCache* LogicalSessionCache::get(OperationContext* ctx) {
+ return get(ctx->getClient()->getServiceContext());
+}
+
+void LogicalSessionCache::set(ServiceContext* service,
+ std::unique_ptr<LogicalSessionCache> sessionCache) {
+ auto& cache = getLogicalSessionCache(service);
+ cache = std::move(sessionCache);
+}
+
LogicalSessionCache::LogicalSessionCache(std::unique_ptr<ServiceLiason> service,
std::unique_ptr<SessionsCollection> collection,
Options options)
@@ -78,42 +99,34 @@ LogicalSessionCache::~LogicalSessionCache() {
}
}
-StatusWith<LogicalSessionRecord::Owner> LogicalSessionCache::getOwner(LogicalSessionId lsid) {
+Status LogicalSessionCache::fetchAndPromote(SignedLogicalSessionId slsid) {
// Search our local cache first
- auto owner = getOwnerFromCache(lsid);
- if (owner.isOK()) {
- return owner;
+ auto promoteRes = promote(slsid);
+ if (promoteRes.isOK()) {
+ return promoteRes;
}
// Cache miss, must fetch from the sessions collection.
- auto res = _sessionsColl->fetchRecord(lsid);
+ auto res = _sessionsColl->fetchRecord(slsid);
// If we got a valid record, add it to our cache.
if (res.isOK()) {
auto& record = res.getValue();
record.setLastUse(_service->now());
+ // Any duplicate records here are actually the same record with different
+ // lastUse times, ignore them.
auto oldRecord = _addToCache(record);
-
- // If we had a conflicting record for this id, and they aren't the same record,
- // it could mean that an interloper called endSession and startSession for the
- // same lsid while we were fetching its record from the sessions collection.
- // This means our session has been written over, do not allow the caller to use it.
- // Note: we could find expired versions of our same record here, but they'll compare equal.
- if (oldRecord && *oldRecord != record) {
- return {ErrorCodes::NoSuchSession, "no matching session record found"};
- }
-
- return record.getSessionOwner();
+ return Status::OK();
}
+ // If we could not get a valid record, return the error.
return res.getStatus();
}
-StatusWith<LogicalSessionRecord::Owner> LogicalSessionCache::getOwnerFromCache(
- LogicalSessionId lsid) {
+Status LogicalSessionCache::promote(SignedLogicalSessionId slsid) {
stdx::unique_lock<stdx::mutex> lk(_cacheMutex);
- auto it = _cache.find(lsid);
+ auto it = _cache.find(slsid.getLsid());
if (it == _cache.end()) {
return {ErrorCodes::NoSuchSession, "no matching session record found in the cache"};
}
@@ -126,13 +139,13 @@ StatusWith<LogicalSessionRecord::Owner> LogicalSessionCache::getOwnerFromCache(
// Update the last use time before returning.
it->second.setLastUse(now);
- return it->second.getSessionOwner();
+ return Status::OK();
}
-Status LogicalSessionCache::startSession(LogicalSessionRecord authoritativeRecord) {
+Status LogicalSessionCache::startSession(SignedLogicalSessionId slsid) {
// Make sure the timestamp makes sense
- auto now = _service->now();
- authoritativeRecord.setLastUse(now);
+ auto authoritativeRecord =
+ LogicalSessionRecord::makeAuthoritativeRecord(slsid, _service->now());
// Attempt to insert into the sessions collection first. This collection enforces
// unique session ids, so it will act as concurrency control for us.
@@ -148,7 +161,7 @@ Status LogicalSessionCache::startSession(LogicalSessionRecord authoritativeRecor
auto oldRecord = _addToCache(authoritativeRecord);
if (oldRecord) {
if (*oldRecord != authoritativeRecord) {
- if (!_isDead(*oldRecord, now)) {
+ if (!_isDead(*oldRecord, _service->now())) {
return {ErrorCodes::DuplicateSession, "session with this id already exists"};
}
}
@@ -157,6 +170,17 @@ Status LogicalSessionCache::startSession(LogicalSessionRecord authoritativeRecor
return Status::OK();
}
+StatusWith<SignedLogicalSessionId> LogicalSessionCache::signLsid(OperationContext* opCtx,
+ LogicalSessionId* id,
+ boost::optional<OID> userId) {
+ return _service->signLsid(opCtx, id, std::move(userId));
+}
+
+Status LogicalSessionCache::validateLsid(OperationContext* opCtx,
+ const SignedLogicalSessionId& slsid) {
+ return _service->validateLsid(opCtx, slsid);
+}
+
void LogicalSessionCache::_refresh() {
LogicalSessionIdSet activeSessions;
LogicalSessionIdSet deadSessions;
@@ -177,9 +201,9 @@ void LogicalSessionCache::_refresh() {
for (auto& it : cacheCopy) {
auto record = it.second;
if (!_isDead(record, now)) {
- activeSessions.insert(record.getLsid());
+ activeSessions.insert(record.getSignedLsid().getLsid());
} else {
- deadSessions.insert(record.getLsid());
+ deadSessions.insert(record.getSignedLsid().getLsid());
}
}
@@ -236,7 +260,7 @@ bool LogicalSessionCache::_isDead(const LogicalSessionRecord& record, Date_t now
boost::optional<LogicalSessionRecord> LogicalSessionCache::_addToCache(
LogicalSessionRecord record) {
stdx::unique_lock<stdx::mutex> lk(_cacheMutex);
- return _cache.add(record.getLsid(), std::move(record));
+ return _cache.add(record.getSignedLsid().getLsid(), std::move(record));
}
} // namespace mongo
diff --git a/src/mongo/db/logical_session_cache.h b/src/mongo/db/logical_session_cache.h
index a88bb3f31d0..56ce96eec34 100644
--- a/src/mongo/db/logical_session_cache.h
+++ b/src/mongo/db/logical_session_cache.h
@@ -33,12 +33,17 @@
#include "mongo/db/logical_session_record.h"
#include "mongo/db/service_liason.h"
#include "mongo/db/sessions_collection.h"
+#include "mongo/db/signed_logical_session_id.h"
+#include "mongo/db/time_proof_service.h"
#include "mongo/platform/atomic_word.h"
#include "mongo/stdx/thread.h"
#include "mongo/util/lru_cache.h"
namespace mongo {
+class OperationContext;
+class ServiceContext;
+
extern int logicalSessionRecordCacheSize;
extern int localLogicalSessionTimeoutMinutes;
extern int logicalSessionRefreshMinutes;
@@ -51,6 +56,13 @@ extern int logicalSessionRefreshMinutes;
*/
class LogicalSessionCache {
public:
+ /**
+ * Decorate the ServiceContext with a LogicalSessionCache instance.
+ */
+ static LogicalSessionCache* get(ServiceContext* service);
+ static LogicalSessionCache* get(OperationContext* opCtx);
+ static void set(ServiceContext* service, std::unique_ptr<LogicalSessionCache> sessionCache);
+
static constexpr int kLogicalSessionCacheDefaultCapacity = 10000;
static constexpr Minutes kLogicalSessionDefaultTimeout = Minutes(30);
static constexpr Minutes kLogicalSessionDefaultRefresh = Minutes(5);
@@ -102,26 +114,22 @@ public:
~LogicalSessionCache();
/**
- * Returns the owner for the given session, or return an error if there
- * is no authoritative record for this session.
- *
- * If the cache does not already contain a record for this session, this
- * method may issue networking operations to obtain the record. Afterwards,
- * the cache will keep the record for future use.
+ * If the cache contains a record for this LogicalSessionId, promotes that lsid
+ * to be the most recently used and updates its lastUse date to be the current
+ * time. Otherwise, returns an error.
*
- * This call will promote any record it touches to be the most-recently-used
- * record in the cache.
+ * This method does not issue networking calls.
*/
- StatusWith<LogicalSessionRecord::Owner> getOwner(LogicalSessionId lsid);
+ Status promote(SignedLogicalSessionId lsid);
/**
- * Returns the owner for the given session if we already have its record in the
- * cache. Do not fetch the record from the network if we do not already have it.
+ * If the cache contains a record for this LogicalSessionId, promotes it.
+ * Otherwise, attempts to fetch the record for this LogicalSessionId from the
+ * sessions collection, and returns the record if found. Otherwise, returns an error.
*
- * This call will promote any record it touches to be the most-recently-used
- * record in the cache.
+ * This method may issue networking calls.
*/
- StatusWith<LogicalSessionRecord::Owner> getOwnerFromCache(LogicalSessionId lsid);
+ Status fetchAndPromote(SignedLogicalSessionId lsid);
/**
* Inserts a new authoritative session record into the cache. This method will
@@ -129,7 +137,22 @@ public:
* should only be used when starting new sessions and should not be used to
* insert records for existing sessions.
*/
- Status startSession(LogicalSessionRecord authoritativeRecord);
+ Status startSession(SignedLogicalSessionId lsid);
+
+ /**
+ * Generates and sets a signature for the fields in this LogicalSessionId.
+ *
+ * If this method is not able to acquire a key to perform the signature
+ * this call will return an error.
+ */
+ StatusWith<SignedLogicalSessionId> signLsid(OperationContext* opCtx,
+ LogicalSessionId* id,
+ boost::optional<OID> userId);
+
+ /**
+ * Validates that this LogicalSessionId was signed with the correct key.
+ */
+ Status validateLsid(OperationContext* opCtx, const SignedLogicalSessionId& lsid);
/**
* Removes all local records in this cache. Does not remove the corresponding
diff --git a/src/mongo/db/logical_session_cache_test.cpp b/src/mongo/db/logical_session_cache_test.cpp
index a9d3c5eb1c1..baba6814b34 100644
--- a/src/mongo/db/logical_session_cache_test.cpp
+++ b/src/mongo/db/logical_session_cache_test.cpp
@@ -57,9 +57,7 @@ class LogicalSessionCacheTest : public unittest::Test {
public:
LogicalSessionCacheTest()
: _service(std::make_shared<MockServiceLiasonImpl>()),
- _sessions(std::make_shared<MockSessionsCollectionImpl>()),
- _user("sam", "test"),
- _userId(OID::gen()) {}
+ _sessions(std::make_shared<MockSessionsCollectionImpl>()) {}
void setUp() override {
auto mockService = stdx::make_unique<MockServiceLiason>(_service);
@@ -78,11 +76,6 @@ public:
}
}
- LogicalSessionRecord newRecord() {
- return LogicalSessionRecord::makeAuthoritativeRecord(
- LogicalSessionId::gen(), _user, _userId, _service->now());
- }
-
std::unique_ptr<LogicalSessionCache>& cache() {
return _cache;
}
@@ -100,142 +93,132 @@ private:
std::shared_ptr<MockSessionsCollectionImpl> _sessions;
std::unique_ptr<LogicalSessionCache> _cache;
-
- UserName _user;
- boost::optional<OID> _userId;
};
// Test that session cache fetches new records from the sessions collection
TEST_F(LogicalSessionCacheTest, CacheFetchesNewRecords) {
- auto record = newRecord();
- auto lsid = record.getLsid();
+ auto signedLsid = SignedLogicalSessionId::gen();
// When the record is not present (and not in the sessions collection) returns an error
- auto res = cache()->getOwner(lsid);
+ auto res = cache()->fetchAndPromote(signedLsid);
ASSERT(!res.isOK());
// When the record is not present (but is in the sessions collection) returns it
- sessions()->add(record);
- res = cache()->getOwner(lsid);
+ sessions()->add(LogicalSessionRecord::makeAuthoritativeRecord(signedLsid, service()->now()));
+ res = cache()->fetchAndPromote(signedLsid);
ASSERT(res.isOK());
- ASSERT(res.getValue() == record.getSessionOwner());
// When the record is present in the cache, returns it
- sessions()->setFetchHook([](LogicalSessionId id) -> StatusWith<LogicalSessionRecord> {
+ sessions()->setFetchHook([](SignedLogicalSessionId id) -> StatusWith<LogicalSessionRecord> {
// We should not be querying the sessions collection on the next call
ASSERT(false);
return {ErrorCodes::NoSuchSession, "no such session"};
});
- res = cache()->getOwner(lsid);
+ res = cache()->fetchAndPromote(signedLsid);
ASSERT(res.isOK());
- ASSERT(res.getValue() == record.getSessionOwner());
}
// Test that the getFromCache method does not make calls to the sessions collection
TEST_F(LogicalSessionCacheTest, TestCacheHitsOnly) {
- auto record = newRecord();
- auto lsid = record.getLsid();
+ auto signedLsid = SignedLogicalSessionId::gen();
// When the record is not present (and not in the sessions collection), returns an error
- auto res = cache()->getOwnerFromCache(lsid);
+ auto res = cache()->promote(signedLsid);
ASSERT(!res.isOK());
// When the record is not present (but is in the sessions collection), returns an error
- sessions()->add(record);
- res = cache()->getOwnerFromCache(lsid);
+ sessions()->add(LogicalSessionRecord::makeAuthoritativeRecord(signedLsid, service()->now()));
+ res = cache()->promote(signedLsid);
ASSERT(!res.isOK());
// When the record is present, returns the owner
- cache()->getOwner(lsid).status_with_transitional_ignore();
- res = cache()->getOwnerFromCache(lsid);
+ cache()->fetchAndPromote(signedLsid).transitional_ignore();
+ res = cache()->promote(signedLsid);
ASSERT(res.isOK());
- auto fetched = res.getValue();
- ASSERT(res.getValue() == record.getSessionOwner());
}
// Test that fetching from the cache updates the lastUse date of records
TEST_F(LogicalSessionCacheTest, FetchUpdatesLastUse) {
- auto record = newRecord();
- auto lsid = record.getLsid();
+ auto signedLsid = SignedLogicalSessionId::gen();
auto start = service()->now();
// Insert the record into the sessions collection with 'start'
- record.setLastUse(start);
- sessions()->add(record);
+ sessions()->add(LogicalSessionRecord::makeAuthoritativeRecord(signedLsid, start));
// Fast forward time and fetch
service()->fastForward(Milliseconds(500));
ASSERT(start != service()->now());
- auto res = cache()->getOwner(lsid);
+ auto res = cache()->fetchAndPromote(signedLsid);
ASSERT(res.isOK());
// Now that we fetched, lifetime of session should be extended
service()->fastForward(kSessionTimeout - Milliseconds(500));
- res = cache()->getOwner(lsid);
+ res = cache()->fetchAndPromote(signedLsid);
ASSERT(res.isOK());
// We fetched again, so lifetime extended again
service()->fastForward(kSessionTimeout - Milliseconds(10));
- res = cache()->getOwner(lsid);
+ res = cache()->fetchAndPromote(signedLsid);
ASSERT(res.isOK());
// Fast forward and hit-only fetch
service()->fastForward(kSessionTimeout - Milliseconds(10));
- res = cache()->getOwnerFromCache(lsid);
+ res = cache()->promote(signedLsid);
ASSERT(res.isOK());
// Lifetime extended again
service()->fastForward(Milliseconds(11));
- res = cache()->getOwnerFromCache(lsid);
+ res = cache()->promote(signedLsid);
ASSERT(res.isOK());
// Let record expire, we should not be able to get it from the cache
service()->fastForward(kSessionTimeout + Milliseconds(1));
- res = cache()->getOwnerFromCache(lsid);
+ res = cache()->promote(signedLsid);
ASSERT(!res.isOK());
}
// Test the startSession method
TEST_F(LogicalSessionCacheTest, StartSession) {
- auto record = newRecord();
- auto lsid = record.getLsid();
+ auto signedLsid = SignedLogicalSessionId::gen();
// Test starting a new session
- auto res = cache()->startSession(record);
+ auto res = cache()->startSession(signedLsid);
ASSERT(res.isOK());
- ASSERT(sessions()->has(lsid));
+ ASSERT(sessions()->has(signedLsid.getLsid()));
// Try to start a session that is already in the sessions collection and our
// local cache, should fail
- res = cache()->startSession(record);
+ res = cache()->startSession(signedLsid);
ASSERT(!res.isOK());
// Try to start a session that is already in the sessions collection but
// is not in our local cache, should fail
- auto record2 = newRecord();
- sessions()->add(record2);
- res = cache()->startSession(record2);
+ auto record2 = LogicalSessionRecord::makeAuthoritativeRecord(SignedLogicalSessionId::gen(),
+ service()->now());
+ auto signedLsid2 = record2.getSignedLsid();
+ sessions()->add(std::move(record2));
+ res = cache()->startSession(signedLsid2);
ASSERT(!res.isOK());
// Try to start a session that has expired from our cache, and is no
// longer in the sessions collection, should succeed
service()->fastForward(Milliseconds(kSessionTimeout.count() + 5));
- sessions()->remove(lsid);
- ASSERT(!sessions()->has(lsid));
- res = cache()->startSession(record);
+ sessions()->remove(signedLsid.getLsid());
+ ASSERT(!sessions()->has(signedLsid.getLsid()));
+ res = cache()->startSession(signedLsid);
ASSERT(res.isOK());
- ASSERT(sessions()->has(lsid));
+ ASSERT(sessions()->has(signedLsid.getLsid()));
}
// Test that records in the cache are properly refreshed until they expire
TEST_F(LogicalSessionCacheTest, CacheRefreshesOwnRecords) {
// Insert two records into the cache
- auto record1 = newRecord();
- auto record2 = newRecord();
- cache()->startSession(record1).transitional_ignore();
- cache()->startSession(record2).transitional_ignore();
+ auto signedLsid1 = SignedLogicalSessionId::gen();
+ auto signedLsid2 = SignedLogicalSessionId::gen();
+ cache()->startSession(signedLsid1).transitional_ignore();
+ cache()->startSession(signedLsid2).transitional_ignore();
stdx::promise<int> hitRefresh;
auto refreshFuture = hitRefresh.get_future();
@@ -258,8 +241,7 @@ TEST_F(LogicalSessionCacheTest, CacheRefreshesOwnRecords) {
auto refresh2Future = refresh2.get_future();
// Use one of the records
- auto lsid = record1.getLsid();
- auto res = cache()->getOwner(lsid);
+ auto res = cache()->fetchAndPromote(signedLsid1);
ASSERT(res.isOK());
// Advance time so that one record expires
@@ -276,25 +258,25 @@ TEST_F(LogicalSessionCacheTest, CacheRefreshesOwnRecords) {
service()->fastForward(kSessionTimeout - kForceRefresh + Milliseconds(1));
refresh2Future.wait();
- ASSERT_EQ(refresh2Future.get(), lsid);
+ ASSERT_EQ(refresh2Future.get(), signedLsid1.getLsid());
}
// Test that cache deletes records that fail to refresh
TEST_F(LogicalSessionCacheTest, CacheDeletesRecordsThatFailToRefresh) {
// Put two sessions into the cache
- auto record1 = newRecord();
- auto record2 = newRecord();
- cache()->startSession(record1).transitional_ignore();
- cache()->startSession(record2).transitional_ignore();
+ auto signedLsid1 = SignedLogicalSessionId::gen();
+ auto signedLsid2 = SignedLogicalSessionId::gen();
+ cache()->startSession(signedLsid1).transitional_ignore();
+ cache()->startSession(signedLsid2).transitional_ignore();
stdx::promise<void> hitRefresh;
auto refreshFuture = hitRefresh.get_future();
// Record 1 fails to refresh
- sessions()->setRefreshHook([&hitRefresh, &record1](LogicalSessionIdSet sessions) {
+ sessions()->setRefreshHook([&hitRefresh, &signedLsid1](LogicalSessionIdSet sessions) {
ASSERT_EQ(sessions.size(), size_t(2));
hitRefresh.set_value();
- return LogicalSessionIdSet{record1.getLsid()};
+ return LogicalSessionIdSet{signedLsid1.getLsid()};
});
// Force a refresh
@@ -302,72 +284,71 @@ TEST_F(LogicalSessionCacheTest, CacheDeletesRecordsThatFailToRefresh) {
refreshFuture.wait();
// Ensure that one record is still there and the other is gone
- auto res = cache()->getOwnerFromCache(record1.getLsid());
+ auto res = cache()->promote(signedLsid1);
ASSERT(!res.isOK());
- res = cache()->getOwnerFromCache(record2.getLsid());
+ res = cache()->promote(signedLsid2);
ASSERT(res.isOK());
}
// Test that we don't remove records that fail to refresh if they are active on the service
TEST_F(LogicalSessionCacheTest, KeepActiveSessionAliveEvenIfRefreshFails) {
// Put two sessions into the cache, one into the service
- auto record1 = newRecord();
- auto record2 = newRecord();
- cache()->startSession(record1).transitional_ignore();
- service()->add(record1.getLsid());
- cache()->startSession(record2).transitional_ignore();
+ auto signedLsid1 = SignedLogicalSessionId::gen();
+ auto signedLsid2 = SignedLogicalSessionId::gen();
+ cache()->startSession(signedLsid1).transitional_ignore();
+ service()->add(signedLsid1.getLsid());
+ cache()->startSession(signedLsid2).transitional_ignore();
stdx::promise<void> hitRefresh;
auto refreshFuture = hitRefresh.get_future();
- // Record 1 fails to refresh
- sessions()->setRefreshHook([&hitRefresh, &record1](LogicalSessionIdSet sessions) {
+ // SignedLsid 1 fails to refresh
+ sessions()->setRefreshHook([&hitRefresh, &signedLsid1](LogicalSessionIdSet sessions) {
ASSERT_EQ(sessions.size(), size_t(2));
hitRefresh.set_value();
- return LogicalSessionIdSet{record1.getLsid()};
+ return LogicalSessionIdSet{signedLsid1.getLsid()};
});
// Force a refresh
service()->fastForward(kForceRefresh);
refreshFuture.wait();
- // Ensure that both records are still there
- auto res = cache()->getOwnerFromCache(record1.getLsid());
+ // Ensure that both signedLsids are still there
+ auto res = cache()->promote(signedLsid1);
ASSERT(res.isOK());
- res = cache()->getOwnerFromCache(record2.getLsid());
+ res = cache()->promote(signedLsid2);
ASSERT(res.isOK());
}
-// Test that session cache properly expires records after 30 minutes of no use
+// Test that session cache properly expires signedLsids after 30 minutes of no use
TEST_F(LogicalSessionCacheTest, BasicSessionExpiration) {
- // Insert a record
- auto record = newRecord();
- cache()->startSession(record).transitional_ignore();
- auto res = cache()->getOwnerFromCache(record.getLsid());
+ // Insert a signedLsid
+ auto signedLsid = SignedLogicalSessionId::gen();
+ cache()->startSession(signedLsid).transitional_ignore();
+ auto res = cache()->promote(signedLsid);
ASSERT(res.isOK());
// Force it to expire
service()->fastForward(Milliseconds(kSessionTimeout.count() + 5));
// Check that it is no longer in the cache
- res = cache()->getOwnerFromCache(record.getLsid());
+ res = cache()->promote(signedLsid);
ASSERT(!res.isOK());
}
// Test that we keep refreshing sessions that are active on the service
TEST_F(LogicalSessionCacheTest, LongRunningQueriesAreRefreshed) {
- auto record = newRecord();
- auto lsid = record.getLsid();
+ auto signedLsid = SignedLogicalSessionId::gen();
- // Insert one active record on the service, none in the cache
- service()->add(lsid);
+ // Insert one active signedLsid on the service, none in the cache
+ service()->add(signedLsid.getLsid());
stdx::mutex mutex;
stdx::condition_variable cv;
int count = 0;
- sessions()->setRefreshHook([&cv, &mutex, &count, &lsid](LogicalSessionIdSet sessions) {
- ASSERT_EQ(*(sessions.begin()), lsid);
+ sessions()->setRefreshHook([&cv, &mutex, &count, &signedLsid](LogicalSessionIdSet sessions) {
+ ASSERT_EQ(*(sessions.begin()), signedLsid.getLsid());
{
stdx::unique_lock<stdx::mutex> lk(mutex);
count++;
@@ -397,7 +378,7 @@ TEST_F(LogicalSessionCacheTest, LongRunningQueriesAreRefreshed) {
// Wait until the next job has been scheduled
waitUntilRefreshScheduled();
- // Force another refresh, check that it refreshes that active record again
+ // Force another refresh, check that it refreshes that active signedLsid again
service()->fastForward(kForceRefresh);
{
stdx::unique_lock<stdx::mutex> lk(mutex);
@@ -405,18 +386,18 @@ TEST_F(LogicalSessionCacheTest, LongRunningQueriesAreRefreshed) {
}
}
-// Test that the set of records we refresh is a sum of cached + active records
-TEST_F(LogicalSessionCacheTest, RefreshCachedAndServiceRecordsTogether) {
+// Test that the set of signedLsids we refresh is a sum of cached + active signedLsids
+TEST_F(LogicalSessionCacheTest, RefreshCachedAndServiceSignedLsidsTogether) {
// Put one session into the cache, one into the service
- auto record1 = newRecord();
- service()->add(record1.getLsid());
- auto record2 = newRecord();
- cache()->startSession(record2).transitional_ignore();
+ auto signedLsid1 = SignedLogicalSessionId::gen();
+ service()->add(signedLsid1.getLsid());
+ auto signedLsid2 = SignedLogicalSessionId::gen();
+ cache()->startSession(signedLsid2).transitional_ignore();
stdx::promise<void> hitRefresh;
auto refreshFuture = hitRefresh.get_future();
- // Both records refresh
+ // Both signedLsids refresh
sessions()->setRefreshHook([&hitRefresh](LogicalSessionIdSet sessions) {
ASSERT_EQ(sessions.size(), size_t(2));
hitRefresh.set_value();
@@ -428,18 +409,18 @@ TEST_F(LogicalSessionCacheTest, RefreshCachedAndServiceRecordsTogether) {
refreshFuture.wait();
}
-// Test large sets of cache-only session records
-TEST_F(LogicalSessionCacheTest, ManyRecordsInCacheRefresh) {
+// Test large sets of cache-only session signedLsids
+TEST_F(LogicalSessionCacheTest, ManySignedLsidsInCacheRefresh) {
int count = LogicalSessionCache::kLogicalSessionCacheDefaultCapacity;
for (int i = 0; i < count; i++) {
- auto record = newRecord();
- cache()->startSession(record).transitional_ignore();
+ auto signedLsid = SignedLogicalSessionId::gen();
+ cache()->startSession(signedLsid).transitional_ignore();
}
stdx::promise<void> hitRefresh;
auto refreshFuture = hitRefresh.get_future();
- // Check that all records refresh
+ // Check that all signedLsids refresh
sessions()->setRefreshHook([&hitRefresh, &count](LogicalSessionIdSet sessions) {
ASSERT_EQ(sessions.size(), size_t(count));
hitRefresh.set_value();
@@ -451,18 +432,18 @@ TEST_F(LogicalSessionCacheTest, ManyRecordsInCacheRefresh) {
refreshFuture.wait();
}
-// Test larger sets of service-only session records
+// Test larger sets of service-only session signedLsids
TEST_F(LogicalSessionCacheTest, ManyLongRunningSessionsRefresh) {
int count = LogicalSessionCache::kLogicalSessionCacheDefaultCapacity;
for (int i = 0; i < count; i++) {
- auto record = newRecord();
- service()->add(record.getLsid());
+ auto lsid = LogicalSessionId::gen();
+ service()->add(lsid);
}
stdx::promise<void> hitRefresh;
auto refreshFuture = hitRefresh.get_future();
- // Check that all records refresh
+ // Check that all signedLsids refresh
sessions()->setRefreshHook([&hitRefresh, &count](LogicalSessionIdSet sessions) {
ASSERT_EQ(sessions.size(), size_t(count));
hitRefresh.set_value();
@@ -478,11 +459,11 @@ TEST_F(LogicalSessionCacheTest, ManyLongRunningSessionsRefresh) {
TEST_F(LogicalSessionCacheTest, ManySessionsRefreshComboDeluxe) {
int count = LogicalSessionCache::kLogicalSessionCacheDefaultCapacity;
for (int i = 0; i < count; i++) {
- auto record = newRecord();
- service()->add(record.getLsid());
+ auto lsid = LogicalSessionId::gen();
+ service()->add(lsid);
- auto record2 = newRecord();
- cache()->startSession(record2).transitional_ignore();
+ auto lsid2 = SignedLogicalSessionId::gen();
+ cache()->startSession(lsid2).transitional_ignore();
}
stdx::mutex mutex;
@@ -490,7 +471,7 @@ TEST_F(LogicalSessionCacheTest, ManySessionsRefreshComboDeluxe) {
int refreshes = 0;
int nRefreshed = 0;
- // Check that all records refresh successfully
+ // Check that all signedLsids refresh successfully
sessions()->setRefreshHook(
[&refreshes, &mutex, &cv, &nRefreshed](LogicalSessionIdSet sessions) {
{
@@ -550,7 +531,7 @@ TEST_F(LogicalSessionCacheTest, ManySessionsRefreshComboDeluxe) {
cv.wait(lk, [&refreshes] { return refreshes == 3; });
}
- // Since all but one record failed to refresh, third set should just have one record
+ // Since all but one signedLsid failed to refresh, third set should just have one signedLsid
ASSERT_EQ(nRefreshed, 1);
}
diff --git a/src/mongo/db/logical_session_id.h b/src/mongo/db/logical_session_id.h
index ef73bc0fcd1..230b892d934 100644
--- a/src/mongo/db/logical_session_id.h
+++ b/src/mongo/db/logical_session_id.h
@@ -46,21 +46,25 @@ const TxnNumber kUninitializedTxnNumber = -1;
class BSONObjBuilder;
/**
- * A 128-bit identifier for a logical session.
+ * A 128-bit unique identifier for a logical session.
*/
class LogicalSessionId : public Logical_session_id {
public:
LogicalSessionId();
LogicalSessionId(Logical_session_id&& lsid);
+ friend class Logical_session_id;
+ friend class Logical_session_record;
+ friend class SignedLogicalSessionId;
+
/**
* Create and return a new LogicalSessionId with a random UUID.
*/
static LogicalSessionId gen();
/**
- * If the given string represents a valid LogicalSessionId, constructs and returns,
- * the id, otherwise returns an error.
+ * If the given string represents a valid UUID, constructs and returns
+ * a new LogicalSessionId. Otherwise returns an error.
*/
static StatusWith<LogicalSessionId> parse(const std::string& s);
diff --git a/src/mongo/db/logical_session_id_test.cpp b/src/mongo/db/logical_session_id_test.cpp
index 13f3375ab33..c9e86639c4a 100644
--- a/src/mongo/db/logical_session_id_test.cpp
+++ b/src/mongo/db/logical_session_id_test.cpp
@@ -59,7 +59,10 @@ TEST(LogicalSessionIdTest, ToAndFromStringTest) {
TEST(LogicalSessionIdTest, FromBSONTest) {
auto uuid = UUID::gen();
- auto bson = BSON("id" << uuid.toBSON());
+
+ BSONObjBuilder b;
+ b.append("id", uuid.toBSON());
+ auto bson = b.done();
auto lsid = LogicalSessionId::parse(bson);
ASSERT_EQUALS(lsid.toString(), uuid.toString());
@@ -73,7 +76,6 @@ TEST(LogicalSessionIdTest, FromBSONTest) {
<< "there")),
UserException);
- // TODO: use these and add more once there is bindata
ASSERT_THROWS(LogicalSessionId::parse(BSON("id"
<< "not a session id!")),
UserException);
diff --git a/src/mongo/db/logical_session_record.cpp b/src/mongo/db/logical_session_record.cpp
index d6904a36f5f..51ca629962b 100644
--- a/src/mongo/db/logical_session_record.cpp
+++ b/src/mongo/db/logical_session_record.cpp
@@ -39,25 +39,17 @@ namespace mongo {
StatusWith<LogicalSessionRecord> LogicalSessionRecord::parse(const BSONObj& bson) {
try {
IDLParserErrorContext ctxt("logical session record");
-
LogicalSessionRecord record;
record.parseProtected(ctxt, bson);
-
- auto owner = record.getOwner();
- UserName user{owner.getUserName(), owner.getDbName()};
- record._owner = std::make_pair(user, owner.getUserId());
-
return record;
} catch (std::exception e) {
return exceptionToStatus();
}
}
-LogicalSessionRecord LogicalSessionRecord::makeAuthoritativeRecord(LogicalSessionId id,
- UserName user,
- boost::optional<OID> userId,
+LogicalSessionRecord LogicalSessionRecord::makeAuthoritativeRecord(SignedLogicalSessionId id,
Date_t now) {
- return LogicalSessionRecord(std::move(id), std::move(user), std::move(userId), std::move(now));
+ return LogicalSessionRecord(std::move(id), std::move(now));
}
BSONObj LogicalSessionRecord::toBSON() const {
@@ -68,28 +60,13 @@ BSONObj LogicalSessionRecord::toBSON() const {
std::string LogicalSessionRecord::toString() const {
return str::stream() << "LogicalSessionRecord"
- << " Id: '" << getLsid() << "'"
- << " Owner name: '" << getOwner().getUserName() << "'"
+ << " Id: '" << getSignedLsid() << "'"
<< " Last-use: " << getLastUse().toString();
}
-LogicalSessionRecord::LogicalSessionRecord(LogicalSessionId id,
- UserName user,
- boost::optional<OID> userId,
- Date_t now)
- : _owner(std::make_pair(std::move(user), std::move(userId))) {
- setLsid(std::move(id));
+LogicalSessionRecord::LogicalSessionRecord(SignedLogicalSessionId id, Date_t now) {
+ setSignedLsid(std::move(id));
setLastUse(now);
-
- Session_owner owner;
- owner.setUserName(_owner.first.getUser());
- owner.setDbName(_owner.first.getDB());
- owner.setUserId(_owner.second);
- setOwner(std::move(owner));
-}
-
-LogicalSessionRecord::Owner LogicalSessionRecord::getSessionOwner() const {
- return _owner;
}
} // namespace mongo
diff --git a/src/mongo/db/logical_session_record.h b/src/mongo/db/logical_session_record.h
index e576cd82eb5..6d8662327ab 100644
--- a/src/mongo/db/logical_session_record.h
+++ b/src/mongo/db/logical_session_record.h
@@ -32,10 +32,8 @@
#include <utility>
#include "mongo/bson/bsonobj.h"
-#include "mongo/bson/oid.h"
-#include "mongo/db/auth/user_name.h"
-#include "mongo/db/logical_session_id.h"
#include "mongo/db/logical_session_record_gen.h"
+#include "mongo/db/signed_logical_session_id.h"
#include "mongo/util/time_support.h"
namespace mongo {
@@ -46,17 +44,12 @@ namespace mongo {
* The BSON representation of a session record follows this form:
*
* {
- * lsid : LogicalSessionId,
- * lastUse : Date_t,
- * owner : {
- * user : UserName,
- * userId : OID
- * }
+ * lsid : SignedLogicalSessionId,
+ * lastUse : Date_t
+ * }
*/
class LogicalSessionRecord : public Logical_session_record {
public:
- using Owner = std::pair<UserName, boost::optional<OID>>;
-
/**
* Constructs and returns a LogicalSessionRecord from a BSON representation,
* or throws an error. For IDL.
@@ -69,10 +62,7 @@ public:
* only be used when the caller is intending to make a new authoritative
* record and subsequently insert that record into the sessions collection.
*/
- static LogicalSessionRecord makeAuthoritativeRecord(LogicalSessionId id,
- UserName user,
- boost::optional<OID> userId,
- Date_t now);
+ static LogicalSessionRecord makeAuthoritativeRecord(SignedLogicalSessionId id, Date_t now);
/**
* Return a BSON representation of this session record.
@@ -85,34 +75,17 @@ public:
std::string toString() const;
inline bool operator==(const LogicalSessionRecord& rhs) const {
- return getLsid() == rhs.getLsid() && getSessionOwner() == rhs.getSessionOwner();
+ return getSignedLsid() == rhs.getSignedLsid();
}
inline bool operator!=(const LogicalSessionRecord& rhs) const {
return !(*this == rhs);
}
- /**
- * Return the username and id of the User who owns this session. Only a User
- * that matches both the name and id returned by this method should be
- * permitted to use this session.
- *
- * Note: if the returned optional OID is set to boost::none, this implies that
- * the owning user is a pre-3.6 user that has no id. In this case, only a User
- * with a matching UserName who also has an unset optional id should be
- * permitted to use this session.
- */
- Owner getSessionOwner() const;
-
private:
LogicalSessionRecord() = default;
- LogicalSessionRecord(LogicalSessionId id,
- UserName user,
- boost::optional<OID> userId,
- Date_t now);
-
- Owner _owner;
+ LogicalSessionRecord(SignedLogicalSessionId id, Date_t now);
};
inline std::ostream& operator<<(std::ostream& s, const LogicalSessionRecord& record) {
diff --git a/src/mongo/db/logical_session_record.idl b/src/mongo/db/logical_session_record.idl
index ac4df1c3b6c..25e2ac65b76 100644
--- a/src/mongo/db/logical_session_record.idl
+++ b/src/mongo/db/logical_session_record.idl
@@ -21,40 +21,28 @@ global:
cpp_namespace: "mongo"
cpp_includes:
- "mongo/db/logical_session_id.h"
+ - "mongo/db/signed_logical_session_id.h"
imports:
- "mongo/idl/basic_types.idl"
- - "mongo/db/logical_session_id.idl"
+ - "mongo/db/signed_logical_session_id.idl"
types:
-
- LogicalSessionIdIDL:
- description: "IDL representation of the LogicalSessionId cpp type"
+ SignedLogicalSessionIdIDL:
+ description: "IDL representation of the SignedLogicalSessionId cpp type"
bson_serialization_type: object
- cpp_type: "mongo::LogicalSessionId"
- deserializer: "mongo::LogicalSessionId::parse"
- serializer: "mongo::LogicalSessionId::toBSON"
+ cpp_type: "mongo::SignedLogicalSessionId"
+ deserializer: "mongo::SignedLogicalSessionId::parse"
+ serializer: "mongo::SignedLogicalSessionId::toBSON"
structs:
- session_owner:
- description: "A sub-document of a session record with owner info"
- fields:
- userName: string
- dbName: string
- userId:
- type: objectid
- optional: true
-
logical_session_record:
description: "A struct representing a LogicalSessionRecord"
fields:
- lsid:
+ signedLsid:
description: "The id for this session record"
- type: LogicalSessionIdIDL
+ type: SignedLogicalSessionIdIDL
lastUse:
description: "The time at which this record was last used. Note: the date expressed in this in-memory record object may not match the date on the authoritative record for this session, which is stored in the sessions collection."
type: date
- owner:
- description: "The username and id of the User who owns this session."
- type: session_owner
diff --git a/src/mongo/db/logical_session_record_test.cpp b/src/mongo/db/logical_session_record_test.cpp
index a307095a21b..659d23efa5b 100644
--- a/src/mongo/db/logical_session_record_test.cpp
+++ b/src/mongo/db/logical_session_record_test.cpp
@@ -44,32 +44,18 @@ namespace {
TEST(LogicalSessionRecordTest, ToAndFromBSONTest) {
// Round-trip a BSON obj
auto lsid = LogicalSessionId::gen();
+ SignedLogicalSessionId slsid{lsid, boost::none, 1, SHA1Block{}};
auto lastUse = Date_t::now();
- auto oid = OID::gen();
-
- auto owner = BSON("userName"
- << "Sam"
- << "dbName"
- << "test"
- << "userId"
- << oid);
- auto bson = BSON("lsid" << lsid.toBSON() << "lastUse" << lastUse << "owner" << owner);
+ auto bson = BSON("signedLsid" << slsid.toBSON() << "lastUse" << lastUse);
// Make a session record out of this
auto res = LogicalSessionRecord::parse(bson);
ASSERT(res.isOK());
auto record = res.getValue();
- ASSERT_EQ(record.getLsid(), lsid);
+ ASSERT_EQ(record.getSignedLsid(), slsid);
ASSERT_EQ(record.getLastUse(), lastUse);
- auto recordOwner = record.getSessionOwner();
-
- ASSERT_EQ(recordOwner.first.getUser(), "Sam");
- ASSERT_EQ(recordOwner.first.getDB(), "test");
- ASSERT(recordOwner.second);
- ASSERT_EQ(*(recordOwner.second), oid);
-
// Dump back to bson, make sure we get the same thing
auto bsonDump = record.toBSON();
ASSERT_EQ(bsonDump.woCompare(bson), 0);
diff --git a/src/mongo/db/logical_time_validator.cpp b/src/mongo/db/logical_time_validator.cpp
index d17a728435c..98e4f740ff6 100644
--- a/src/mongo/db/logical_time_validator.cpp
+++ b/src/mongo/db/logical_time_validator.cpp
@@ -35,7 +35,7 @@
#include "mongo/db/auth/action_type.h"
#include "mongo/db/auth/authorization_session.h"
#include "mongo/db/auth/privilege.h"
-#include "mongo/db/keys_collection_manager.h"
+#include "mongo/db/keys_collection_manager_sharding.h"
#include "mongo/db/operation_context.h"
#include "mongo/db/service_context.h"
#include "mongo/util/assert_util.h"
@@ -77,8 +77,9 @@ void LogicalTimeValidator::set(ServiceContext* service,
validator = std::move(newValidator);
}
-LogicalTimeValidator::LogicalTimeValidator(std::unique_ptr<KeysCollectionManager> keyManager)
- : _keyManager(std::move(keyManager)) {}
+LogicalTimeValidator::LogicalTimeValidator(
+ std::shared_ptr<KeysCollectionManagerSharding> keyManager)
+ : _keyManager(keyManager) {}
SignedLogicalTime LogicalTimeValidator::_getProof(const KeysCollectionDocument& keyDoc,
LogicalTime newTime) {
@@ -103,7 +104,7 @@ SignedLogicalTime LogicalTimeValidator::_getProof(const KeysCollectionDocument&
}
SignedLogicalTime LogicalTimeValidator::trySignLogicalTime(const LogicalTime& newTime) {
- auto keyStatusWith = _keyManager->getKeyForSigning(newTime);
+ auto keyStatusWith = _keyManager->getKeyForSigning(nullptr, newTime);
auto keyStatus = keyStatusWith.getStatus();
if (keyStatus == ErrorCodes::KeyNotFound) {
@@ -117,13 +118,13 @@ SignedLogicalTime LogicalTimeValidator::trySignLogicalTime(const LogicalTime& ne
SignedLogicalTime LogicalTimeValidator::signLogicalTime(OperationContext* opCtx,
const LogicalTime& newTime) {
- auto keyStatusWith = _keyManager->getKeyForSigning(newTime);
+ auto keyStatusWith = _keyManager->getKeyForSigning(nullptr, newTime);
auto keyStatus = keyStatusWith.getStatus();
while (keyStatus == ErrorCodes::KeyNotFound) {
_keyManager->refreshNow(opCtx);
- keyStatusWith = _keyManager->getKeyForSigning(newTime);
+ keyStatusWith = _keyManager->getKeyForSigning(nullptr, newTime);
keyStatus = keyStatusWith.getStatus();
if (keyStatus == ErrorCodes::KeyNotFound) {
diff --git a/src/mongo/db/logical_time_validator.h b/src/mongo/db/logical_time_validator.h
index 44bb34922e5..7d11450bf1c 100644
--- a/src/mongo/db/logical_time_validator.h
+++ b/src/mongo/db/logical_time_validator.h
@@ -39,7 +39,7 @@ namespace mongo {
class OperationContext;
class ServiceContext;
class KeysCollectionDocument;
-class KeysCollectionManager;
+class KeysCollectionManagerSharding;
/**
* This is responsible for signing cluster times that can be used to sent to other servers and
@@ -52,7 +52,11 @@ public:
static LogicalTimeValidator* get(OperationContext* ctx);
static void set(ServiceContext* service, std::unique_ptr<LogicalTimeValidator> validator);
- explicit LogicalTimeValidator(std::unique_ptr<KeysCollectionManager> keyManager);
+ /**
+ * Constructs a new LogicalTimeValidator that uses the given key manager. The passed-in
+ * key manager must outlive this object.
+ */
+ explicit LogicalTimeValidator(std::shared_ptr<KeysCollectionManagerSharding> keyManager);
/**
* Tries to sign the newTime with a valid signature. Can return an empty signature and keyId
@@ -103,7 +107,7 @@ private:
stdx::mutex _mutex;
SignedLogicalTime _lastSeenValidTime;
TimeProofService _timeProofService;
- std::unique_ptr<KeysCollectionManager> _keyManager;
+ std::shared_ptr<KeysCollectionManagerSharding> _keyManager;
};
} // namespace mongo
diff --git a/src/mongo/db/logical_time_validator_test.cpp b/src/mongo/db/logical_time_validator_test.cpp
index 9a4c289d983..8292ee75db4 100644
--- a/src/mongo/db/logical_time_validator_test.cpp
+++ b/src/mongo/db/logical_time_validator_test.cpp
@@ -30,6 +30,7 @@
#include "mongo/bson/timestamp.h"
#include "mongo/db/keys_collection_manager.h"
+#include "mongo/db/keys_collection_manager_sharding.h"
#include "mongo/db/logical_clock.h"
#include "mongo/db/logical_time.h"
#include "mongo/db/logical_time_validator.h"
@@ -70,10 +71,9 @@ protected:
const LogicalTime currentTime(LogicalTime(Timestamp(1, 0)));
LogicalClock::get(operationContext())->setClusterTimeFromTrustedSource(currentTime);
- auto keyManager =
- stdx::make_unique<KeysCollectionManager>("dummy", catalogClient, Seconds(1000));
- _keyManager = keyManager.get();
- _validator = stdx::make_unique<LogicalTimeValidator>(std::move(keyManager));
+ _keyManager =
+ std::make_shared<KeysCollectionManagerSharding>("dummy", catalogClient, Seconds(1000));
+ _validator = stdx::make_unique<LogicalTimeValidator>(_keyManager);
_validator->init(operationContext()->getServiceContext());
}
@@ -97,7 +97,7 @@ protected:
private:
std::unique_ptr<LogicalTimeValidator> _validator;
- KeysCollectionManager* _keyManager;
+ std::shared_ptr<KeysCollectionManagerSharding> _keyManager;
};
TEST_F(LogicalTimeValidatorTest, GetTimeWithIncreasingTimes) {
@@ -133,7 +133,7 @@ TEST_F(LogicalTimeValidatorTest, ValidateErrorsOnInvalidTime) {
refreshKeyManager();
auto newTime = validator()->trySignLogicalTime(t1);
- TimeProofService::TimeProof invalidProof = {{1, 2, 3}};
+ TimeProofService::TimeProof invalidProof = {{{1, 2, 3}}};
SignedLogicalTime invalidTime(LogicalTime(Timestamp(30, 0)), invalidProof, newTime.getKeyId());
// ASSERT_THROWS_CODE(validator()->validate(operationContext(), invalidTime), DBException,
// ErrorCodes::TimeProofMismatch);
@@ -156,7 +156,7 @@ TEST_F(LogicalTimeValidatorTest, ValidateErrorsOnInvalidTimeWithImplicitRefresh)
LogicalTime t1(Timestamp(20, 0));
auto newTime = validator()->signLogicalTime(operationContext(), t1);
- TimeProofService::TimeProof invalidProof = {{1, 2, 3}};
+ TimeProofService::TimeProof invalidProof = {{{1, 2, 3}}};
SignedLogicalTime invalidTime(LogicalTime(Timestamp(30, 0)), invalidProof, newTime.getKeyId());
// ASSERT_THROWS_CODE(validator()->validate(operationContext(), invalidTime), DBException,
// ErrorCodes::TimeProofMismatch);
diff --git a/src/mongo/db/service_context.cpp b/src/mongo/db/service_context.cpp
index f4083be9901..05d89c4cc6f 100644
--- a/src/mongo/db/service_context.cpp
+++ b/src/mongo/db/service_context.cpp
@@ -161,15 +161,6 @@ PeriodicRunner* ServiceContext::getPeriodicRunner() const {
return _runner.get();
}
-void ServiceContext::setLogicalSessionCache(std::unique_ptr<LogicalSessionCache> cache)& {
- invariant(!_sessionCache);
- _sessionCache = std::move(cache);
-}
-
-LogicalSessionCache* ServiceContext::getLogicalSessionCache() const& {
- return _sessionCache.get();
-}
-
transport::TransportLayer* ServiceContext::getTransportLayer() const {
return _transportLayer.get();
}
diff --git a/src/mongo/db/service_context.h b/src/mongo/db/service_context.h
index 3cf830c3fe7..9f97445d69f 100644
--- a/src/mongo/db/service_context.h
+++ b/src/mongo/db/service_context.h
@@ -31,7 +31,8 @@
#include <vector>
#include "mongo/base/disallow_copying.h"
-#include "mongo/db/logical_session_cache.h"
+#include "mongo/db/keys_collection_manager.h"
+#include "mongo/db/logical_session_id.h"
#include "mongo/db/storage/storage_engine.h"
#include "mongo/platform/atomic_word.h"
#include "mongo/platform/unordered_set.h"
@@ -266,6 +267,28 @@ public:
virtual StorageEngine* getGlobalStorageEngine() = 0;
//
+ // Key manager, for HMAC keys.
+ //
+
+ /**
+ * Sets the key manager on this service context.
+ */
+ void setKeyManager(std::shared_ptr<KeysCollectionManager> keyManager) & {
+ stdx::lock_guard<stdx::mutex> lk(_mutex);
+ _keyManager = std::move(keyManager);
+ }
+
+ /**
+ * Returns a pointer to the keys collection manager owned by this service context.
+ */
+ std::shared_ptr<KeysCollectionManager> getKeyManager() & {
+ stdx::lock_guard<stdx::mutex> lk(_mutex);
+ return _keyManager;
+ }
+
+ std::shared_ptr<KeysCollectionManager> getKeyManager() && = delete;
+
+ //
// Global operation management. This may not belong here and there may be too many methods
// here.
//
@@ -329,21 +352,6 @@ public:
PeriodicRunner* getPeriodicRunner() const;
//
- // Logical sessions.
- //
-
- /**
- * Set the logical session cache on this service context.
- */
- void setLogicalSessionCache(std::unique_ptr<LogicalSessionCache> cache) &;
-
- /**
- * Return a pointer to the logical session cache on this service context.
- */
- LogicalSessionCache* getLogicalSessionCache() const&;
- LogicalSessionCache* getLogicalSessionCache() && = delete;
-
- //
// Transport.
//
@@ -459,14 +467,14 @@ private:
void _killOperation_inlock(OperationContext* opCtx, ErrorCodes::Error killCode);
/**
- * The periodic runner.
+ * The key manager.
*/
- std::unique_ptr<PeriodicRunner> _runner;
+ std::shared_ptr<KeysCollectionManager> _keyManager;
/**
- * The logical session cache.
+ * The periodic runner.
*/
- std::unique_ptr<LogicalSessionCache> _sessionCache;
+ std::unique_ptr<PeriodicRunner> _runner;
/**
* The TransportLayer.
diff --git a/src/mongo/db/service_liason.cpp b/src/mongo/db/service_liason.cpp
index e12a243b6e2..a4a576c17f6 100644
--- a/src/mongo/db/service_liason.cpp
+++ b/src/mongo/db/service_liason.cpp
@@ -30,8 +30,79 @@
#include "mongo/db/service_liason.h"
+#include "mongo/db/keys_collection_manager_zero.h"
+#include "mongo/db/logical_clock.h"
+#include "mongo/db/service_context.h"
+
namespace mongo {
+namespace {
+
+const int kSignatureSize = sizeof(UUID) + sizeof(OID);
+
+SHA1Block computeSignature(const SignedLogicalSessionId* id, TimeProofService::Key key) {
+ // Write the uuid and user id to a block for signing.
+ char signatureBlock[kSignatureSize] = {0};
+ DataRangeCursor cursor(signatureBlock, signatureBlock + kSignatureSize);
+ auto res = cursor.writeAndAdvance<ConstDataRange>(id->getLsid().getId().toCDR());
+ invariant(res.isOK());
+ if (auto userId = id->getUserId()) {
+ res = cursor.writeAndAdvance<ConstDataRange>(userId->toCDR());
+ invariant(res.isOK());
+ }
+
+ // Compute the signature.
+ return SHA1Block::computeHmac(
+ key.data(), key.size(), reinterpret_cast<uint8_t*>(signatureBlock), kSignatureSize);
+}
+
+KeysCollectionManagerZero kKeysCollectionManagerZero{"HMAC"};
+
+} // namespace
+
ServiceLiason::~ServiceLiason() = default;
+StatusWith<SignedLogicalSessionId> ServiceLiason::signLsid(OperationContext* opCtx,
+ LogicalSessionId* lsid,
+ boost::optional<OID> userId) {
+ auto& keyManager = kKeysCollectionManagerZero;
+
+ auto logicalTime = LogicalClock::get(_context())->getClusterTime();
+ auto res = keyManager.getKeyForSigning(opCtx, logicalTime);
+ if (!res.isOK()) {
+ return res.getStatus();
+ }
+
+ SignedLogicalSessionId signedLsid;
+ signedLsid.setUserId(std::move(userId));
+ signedLsid.setLsid(*lsid);
+
+ auto keyDoc = res.getValue();
+ signedLsid.setKeyId(keyDoc.getKeyId());
+
+ auto signature = computeSignature(&signedLsid, keyDoc.getKey());
+ signedLsid.setSignature(std::move(signature));
+
+ return signedLsid;
+}
+
+Status ServiceLiason::validateLsid(OperationContext* opCtx, const SignedLogicalSessionId& id) {
+ auto& keyManager = kKeysCollectionManagerZero;
+
+ // Attempt to get the correct key.
+ auto logicalTime = LogicalClock::get(_context())->getClusterTime();
+ auto res = keyManager.getKeyForValidation(opCtx, id.getKeyId(), logicalTime);
+ if (!res.isOK()) {
+ return res.getStatus();
+ }
+
+ // Re-compute the signature, and see that it matches.
+ auto signature = computeSignature(&id, res.getValue().getKey());
+ if (signature != id.getSignature()) {
+ return {ErrorCodes::NoSuchSession, "Signature validation failed."};
+ }
+
+ return Status::OK();
+}
+
} // namespace mongo
diff --git a/src/mongo/db/service_liason.h b/src/mongo/db/service_liason.h
index 2389245f6a5..dca11f50182 100644
--- a/src/mongo/db/service_liason.h
+++ b/src/mongo/db/service_liason.h
@@ -29,12 +29,15 @@
#pragma once
#include "mongo/db/logical_session_id.h"
+#include "mongo/db/signed_logical_session_id.h"
#include "mongo/stdx/functional.h"
#include "mongo/util/periodic_runner.h"
#include "mongo/util/time_support.h"
namespace mongo {
+class ServiceContext;
+
/**
* A service-dependent type for the LogicalSessionCache to use to find the
* current time, schedule periodic refresh jobs, and get a list of sessions
@@ -73,6 +76,27 @@ public:
* Return the current time.
*/
virtual Date_t now() const = 0;
+
+ /**
+ * Generates and sets a signature for the fields in this LogicalSessionId.
+ *
+ * If this method is not able to acquire a key to perform the signature
+ * this call will return an error.
+ */
+ StatusWith<SignedLogicalSessionId> signLsid(OperationContext* opCtx,
+ LogicalSessionId* lsid,
+ boost::optional<OID> userId);
+
+ /**
+ * Validates that this LogicalSessionId was signed with the correct key.
+ */
+ Status validateLsid(OperationContext* opCtx, const SignedLogicalSessionId& id);
+
+protected:
+ /**
+ * Returns the service context.
+ */
+ virtual ServiceContext* _context() = 0;
};
} // namespace mongo
diff --git a/src/mongo/db/service_liason_mock.h b/src/mongo/db/service_liason_mock.h
index bfd2d281ca1..39323057e59 100644
--- a/src/mongo/db/service_liason_mock.h
+++ b/src/mongo/db/service_liason_mock.h
@@ -28,6 +28,8 @@
#pragma once
+#include "mongo/db/service_context.h"
+#include "mongo/db/service_context_noop.h"
#include "mongo/db/service_liason.h"
#include "mongo/executor/async_timer_mock.h"
#include "mongo/platform/atomic_word.h"
@@ -101,8 +103,14 @@ public:
return _impl->join();
}
+protected:
+ ServiceContext* _context() override {
+ return _serviceContext.get();
+ }
+
private:
std::shared_ptr<MockServiceLiasonImpl> _impl;
+ std::unique_ptr<ServiceContextNoop> _serviceContext;
};
} // namespace mongo
diff --git a/src/mongo/db/service_liason_mongod.cpp b/src/mongo/db/service_liason_mongod.cpp
index fdbf811de18..08c49127936 100644
--- a/src/mongo/db/service_liason_mongod.cpp
+++ b/src/mongo/db/service_liason_mongod.cpp
@@ -88,4 +88,8 @@ Date_t ServiceLiasonMongod::now() const {
return getGlobalServiceContext()->getFastClockSource()->now();
}
+ServiceContext* ServiceLiasonMongod::_context() {
+ return getGlobalServiceContext();
+}
+
} // namespace mongo
diff --git a/src/mongo/db/service_liason_mongod.h b/src/mongo/db/service_liason_mongod.h
index 21c677feb8c..31304b94573 100644
--- a/src/mongo/db/service_liason_mongod.h
+++ b/src/mongo/db/service_liason_mongod.h
@@ -57,6 +57,12 @@ public:
void join() override;
Date_t now() const override;
+
+protected:
+ /**
+ * Returns the service context.
+ */
+ ServiceContext* _context() override;
};
} // namespace mongo
diff --git a/src/mongo/db/sessions_collection.h b/src/mongo/db/sessions_collection.h
index ec59f28c597..7692e634275 100644
--- a/src/mongo/db/sessions_collection.h
+++ b/src/mongo/db/sessions_collection.h
@@ -30,6 +30,7 @@
#include "mongo/db/logical_session_id.h"
#include "mongo/db/logical_session_record.h"
+#include "mongo/db/signed_logical_session_id.h"
namespace mongo {
@@ -44,10 +45,10 @@ public:
virtual ~SessionsCollection();
/**
- * Returns a LogicalSessionRecord for the given LogicalSessionId. This method
+ * Returns a LogicalSessionRecord for the given session id. This method
* may run networking operations on the calling thread.
*/
- virtual StatusWith<LogicalSessionRecord> fetchRecord(LogicalSessionId lsid) = 0;
+ virtual StatusWith<LogicalSessionRecord> fetchRecord(SignedLogicalSessionId id) = 0;
/**
* Inserts the given record into the sessions collection. This method may run
diff --git a/src/mongo/db/sessions_collection_mock.cpp b/src/mongo/db/sessions_collection_mock.cpp
index 4abc5a326d1..1079605badd 100644
--- a/src/mongo/db/sessions_collection_mock.cpp
+++ b/src/mongo/db/sessions_collection_mock.cpp
@@ -65,8 +65,9 @@ void MockSessionsCollectionImpl::clearHooks() {
_remove = stdx::bind(&MockSessionsCollectionImpl::_removeRecords, this, stdx::placeholders::_1);
}
-StatusWith<LogicalSessionRecord> MockSessionsCollectionImpl::fetchRecord(LogicalSessionId lsid) {
- return _fetch(std::move(lsid));
+StatusWith<LogicalSessionRecord> MockSessionsCollectionImpl::fetchRecord(
+ SignedLogicalSessionId id) {
+ return _fetch(std::move(id));
}
Status MockSessionsCollectionImpl::insertRecord(LogicalSessionRecord record) {
@@ -83,7 +84,7 @@ void MockSessionsCollectionImpl::removeRecords(LogicalSessionIdSet sessions) {
void MockSessionsCollectionImpl::add(LogicalSessionRecord record) {
stdx::unique_lock<stdx::mutex> lk(_mutex);
- _sessions.insert({record.getLsid(), std::move(record)});
+ _sessions.insert({record.getSignedLsid().getLsid(), std::move(record)});
}
void MockSessionsCollectionImpl::remove(LogicalSessionId lsid) {
@@ -105,11 +106,12 @@ const MockSessionsCollectionImpl::SessionMap& MockSessionsCollectionImpl::sessio
return _sessions;
}
-StatusWith<LogicalSessionRecord> MockSessionsCollectionImpl::_fetchRecord(LogicalSessionId lsid) {
+StatusWith<LogicalSessionRecord> MockSessionsCollectionImpl::_fetchRecord(
+ SignedLogicalSessionId id) {
stdx::unique_lock<stdx::mutex> lk(_mutex);
// If we do not have this record, return an error
- auto it = _sessions.find(lsid);
+ auto it = _sessions.find(id.getLsid());
if (it == _sessions.end()) {
return {ErrorCodes::NoSuchSession, "No matching record in the sessions collection"};
}
@@ -119,7 +121,7 @@ StatusWith<LogicalSessionRecord> MockSessionsCollectionImpl::_fetchRecord(Logica
Status MockSessionsCollectionImpl::_insertRecord(LogicalSessionRecord record) {
stdx::unique_lock<stdx::mutex> lk(_mutex);
- auto res = _sessions.insert({record.getLsid(), std::move(record)});
+ auto res = _sessions.insert({record.getSignedLsid().getLsid(), std::move(record)});
// We should never try to insert the same record twice. In theory this could
// happen because of a UUID conflict.
diff --git a/src/mongo/db/sessions_collection_mock.h b/src/mongo/db/sessions_collection_mock.h
index 64e2a76fc7b..afa04bfd9db 100644
--- a/src/mongo/db/sessions_collection_mock.h
+++ b/src/mongo/db/sessions_collection_mock.h
@@ -59,7 +59,7 @@ public:
MockSessionsCollectionImpl();
- using FetchHook = stdx::function<StatusWith<LogicalSessionRecord>(LogicalSessionId)>;
+ using FetchHook = stdx::function<StatusWith<LogicalSessionRecord>(SignedLogicalSessionId)>;
using InsertHook = stdx::function<Status(LogicalSessionRecord)>;
using RefreshHook = stdx::function<LogicalSessionIdSet(LogicalSessionIdSet)>;
using RemoveHook = stdx::function<void(LogicalSessionIdSet)>;
@@ -74,7 +74,7 @@ public:
void clearHooks();
// Forwarding methods from the MockSessionsCollection
- StatusWith<LogicalSessionRecord> fetchRecord(LogicalSessionId lsid);
+ StatusWith<LogicalSessionRecord> fetchRecord(SignedLogicalSessionId id);
Status insertRecord(LogicalSessionRecord record);
LogicalSessionIdSet refreshSessions(LogicalSessionIdSet sessions);
void removeRecords(LogicalSessionIdSet sessions);
@@ -88,7 +88,7 @@ public:
private:
// Default implementations, may be overridden with custom hooks.
- StatusWith<LogicalSessionRecord> _fetchRecord(LogicalSessionId lsid);
+ StatusWith<LogicalSessionRecord> _fetchRecord(SignedLogicalSessionId id);
Status _insertRecord(LogicalSessionRecord record);
LogicalSessionIdSet _refreshSessions(LogicalSessionIdSet sessions);
void _removeRecords(LogicalSessionIdSet sessions);
@@ -112,8 +112,8 @@ public:
explicit MockSessionsCollection(std::shared_ptr<MockSessionsCollectionImpl> impl)
: _impl(std::move(impl)) {}
- StatusWith<LogicalSessionRecord> fetchRecord(LogicalSessionId lsid) override {
- return _impl->fetchRecord(std::move(lsid));
+ StatusWith<LogicalSessionRecord> fetchRecord(SignedLogicalSessionId id) override {
+ return _impl->fetchRecord(std::move(id));
}
Status insertRecord(LogicalSessionRecord record) override {
diff --git a/src/mongo/db/signed_logical_session_id.cpp b/src/mongo/db/signed_logical_session_id.cpp
new file mode 100644
index 00000000000..a284fb6344b
--- /dev/null
+++ b/src/mongo/db/signed_logical_session_id.cpp
@@ -0,0 +1,73 @@
+/**
+ * Copyright (C) 2017 MongoDB Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License, version 3,
+ * as published by the Free Software Foundation.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the GNU Affero General Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/platform/basic.h"
+
+#include "mongo/db/signed_logical_session_id.h"
+
+#include "mongo/bson/bsonobjbuilder.h"
+#include "mongo/util/assert_util.h"
+
+namespace mongo {
+
+SignedLogicalSessionId SignedLogicalSessionId::gen() {
+ return SignedLogicalSessionId();
+}
+
+SignedLogicalSessionId::SignedLogicalSessionId() {
+ setLsid(LogicalSessionId::gen());
+}
+
+SignedLogicalSessionId::SignedLogicalSessionId(LogicalSessionId lsid,
+ boost::optional<OID> userId,
+ long long keyId,
+ SHA1Block signature) {
+ setLsid(std::move(lsid));
+ setUserId(std::move(userId));
+ setKeyId(std::move(keyId));
+ setSignature(std::move(signature));
+}
+
+SignedLogicalSessionId SignedLogicalSessionId::parse(const BSONObj& doc) {
+ IDLParserErrorContext ctx("signed logical session id");
+ SignedLogicalSessionId lsid;
+ lsid.parseProtected(ctx, doc);
+ return lsid;
+}
+
+BSONObj SignedLogicalSessionId::toBSON() const {
+ BSONObjBuilder builder;
+ serialize(&builder);
+ return builder.obj();
+}
+
+std::string SignedLogicalSessionId::toString() const {
+ return getLsid().toString();
+}
+
+} // namespace mongo
diff --git a/src/mongo/db/signed_logical_session_id.h b/src/mongo/db/signed_logical_session_id.h
new file mode 100644
index 00000000000..632f645c623
--- /dev/null
+++ b/src/mongo/db/signed_logical_session_id.h
@@ -0,0 +1,115 @@
+/**
+ * Copyright (C) 2017 MongoDB Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License, version 3,
+ * as published by the Free Software Foundation.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the GNU Affero General Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include <string>
+
+#include "mongo/base/status_with.h"
+#include "mongo/bson/oid.h"
+#include "mongo/db/logical_session_id.h"
+#include "mongo/db/signed_logical_session_id_gen.h"
+#include "mongo/stdx/unordered_set.h"
+#include "mongo/util/uuid.h"
+
+namespace mongo {
+
+class BSONObjBuilder;
+class OperationContext;
+
+/**
+ * An identifier for a logical session. A LogicalSessionId has the following components:
+ *
+ * - A 128-bit unique identifier (UUID)
+ * - An optional user id (ObjectId)
+ * - A key id (long long)
+ * - An HMAC signature (SHA1Block)
+ */
+class SignedLogicalSessionId : public Signed_logical_session_id {
+public:
+ using Owner = boost::optional<OID>;
+
+ friend class Logical_session_id;
+ friend class Logical_session_record;
+
+ using keyIdType = long long;
+
+ /**
+ * Create and return a new LogicalSessionId with a random UUID. This method
+ * should be used for testing only. The generated SignedLogicalSessionId will
+ * not be signed, and will have no owner.
+ */
+ static SignedLogicalSessionId gen();
+
+ /**
+ * Creates a new SignedLogicalSessionId.
+ */
+ SignedLogicalSessionId(LogicalSessionId lsid,
+ boost::optional<OID> userId,
+ long long keyId,
+ SHA1Block signature);
+
+ /**
+ * Constructs a new LogicalSessionId out of a BSONObj. For IDL.
+ */
+ static SignedLogicalSessionId parse(const BSONObj& doc);
+
+ /**
+ * Returns a string representation of this session id.
+ */
+ std::string toString() const;
+
+ /**
+ * Serialize this object to BSON.
+ */
+ BSONObj toBSON() const;
+
+ inline bool operator==(const SignedLogicalSessionId& rhs) const {
+ return getLsid() == rhs.getLsid() && getUserId() == rhs.getUserId() &&
+ getKeyId() == rhs.getKeyId() && getSignature() == rhs.getSignature();
+ }
+
+ inline bool operator!=(const SignedLogicalSessionId& rhs) const {
+ return !(*this == rhs);
+ }
+
+ /**
+ * This constructor exists for IDL only.
+ */
+ SignedLogicalSessionId();
+};
+
+inline std::ostream& operator<<(std::ostream& s, const SignedLogicalSessionId& lsid) {
+ return (s << lsid.toString());
+}
+
+inline StringBuilder& operator<<(StringBuilder& s, const SignedLogicalSessionId& lsid) {
+ return (s << lsid.toString());
+}
+
+} // namespace mongo
diff --git a/src/mongo/db/signed_logical_session_id.idl b/src/mongo/db/signed_logical_session_id.idl
new file mode 100644
index 00000000000..634d4731a52
--- /dev/null
+++ b/src/mongo/db/signed_logical_session_id.idl
@@ -0,0 +1,50 @@
+# Copyright (C) 2017 MongoDB Inc.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License, version 3,
+# as published by the Free Software Foundation.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+#
+
+# This IDL file describes the BSON format for a LogicalSessionId, and
+# handles the serialization to and deserialization from its BSON representation
+# for that class.
+
+global:
+ cpp_namespace: "mongo"
+ cpp_includes:
+ - "mongo/util/uuid.h"
+ - "mongo/db/logical_session_id.h"
+
+imports:
+ - "mongo/crypto/sha1_block.idl"
+ - "mongo/db/logical_session_id.idl"
+ - "mongo/idl/basic_types.idl"
+
+types:
+ LogicalSessionIdIdl:
+ description: "IDL representation of the LogicalSessionId cpp type"
+ bson_serialization_type: object
+ cpp_type: "mongo::LogicalSessionId"
+ deserializer: "mongo::LogicalSessionId::parse"
+ serializer: "mongo::LogicalSessionId::toBSON"
+
+structs:
+
+ signed_logical_session_id:
+ description: "A struct representing a SignedLogicalSessionId"
+ strict: true
+ fields:
+ lsid: LogicalSessionIdIdl
+ userId:
+ optional: true
+ type: objectid
+ keyId: long
+ signature: sha1Block
diff --git a/src/mongo/db/signed_logical_session_id_test.cpp b/src/mongo/db/signed_logical_session_id_test.cpp
new file mode 100644
index 00000000000..2622d65f2f8
--- /dev/null
+++ b/src/mongo/db/signed_logical_session_id_test.cpp
@@ -0,0 +1,81 @@
+/**
+ * Copyright (C) 2017 MongoDB Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License, version 3,
+ * as published by the Free Software Foundation.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the GNU Affero General Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/platform/basic.h"
+
+#include <string>
+
+#include "mongo/bson/bsonmisc.h"
+#include "mongo/bson/bsonobj.h"
+#include "mongo/bson/bsonobjbuilder.h"
+#include "mongo/bson/bsontypes.h"
+#include "mongo/crypto/sha1_block.h"
+#include "mongo/db/logical_session_id.h"
+#include "mongo/db/signed_logical_session_id.h"
+#include "mongo/unittest/unittest.h"
+#include "mongo/util/uuid.h"
+
+namespace mongo {
+namespace {
+
+TEST(SignedLogicalSessionIdTest, ConstructWithLsid) {
+ auto lsid = LogicalSessionId::gen();
+ SignedLogicalSessionId slsid(lsid, boost::none, 1, SHA1Block{});
+ ASSERT_EQ(slsid.getLsid(), lsid);
+}
+
+TEST(SignedLogicalSessionIdTest, FromBSONTest) {
+ auto lsid = LogicalSessionId::gen();
+
+ BSONObjBuilder b;
+ b.append("lsid", lsid.toBSON());
+ b.append("keyId", 4ll);
+ char buffer[SHA1Block::kHashLength] = {0};
+ b.appendBinData("signature", SHA1Block::kHashLength, BinDataGeneral, buffer);
+ auto bson = b.done();
+
+ auto slsid = SignedLogicalSessionId::parse(bson);
+ ASSERT_EQ(slsid.getLsid(), lsid);
+
+ // Dump back to BSON, make sure we get the same thing
+ auto bsonDump = slsid.toBSON();
+ ASSERT_EQ(bsonDump.woCompare(bson), 0);
+
+ // Try parsing mal-formatted bson objs
+ ASSERT_THROWS(SignedLogicalSessionId::parse(BSON("hi"
+ << "there")),
+ UserException);
+
+ ASSERT_THROWS(SignedLogicalSessionId::parse(BSON("lsid"
+ << "not a session id!")),
+ UserException);
+ ASSERT_THROWS(SignedLogicalSessionId::parse(BSON("lsid" << 14)), UserException);
+}
+
+} // namespace
+} // namespace mongo